import itertools
import logging
import time
from typing import Dict, Optional, Type, TypeVar

from django.db import OperationalError, connection, models, transaction
from django.db.models import Model, QuerySet, Subquery
from django.db.models.signals import post_migrate
from django.db.utils import DatabaseError, ProgrammingError
from django.dispatch import receiver

logger = logging.getLogger(__name__)


class SQCount(Subquery):
    template = '(SELECT count(*) FROM (%(subquery)s) _count)'
    output_field = models.IntegerField()


ModelType = TypeVar('ModelType', bound=Model)
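
# Illustrative sketch (not part of the original module): SQCount is typically used to
# annotate a queryset with the row count of a correlated subquery in a single query.
# `Project` and `Task` are hypothetical models used only for this example; OuterRef
# would need to be imported from django.db.models.
#
#   tasks = Task.objects.filter(project=OuterRef('id')).values('id')
#   projects = Project.objects.annotate(task_count=SQCount(tasks))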


def fast_first(queryset: QuerySet[ModelType]) -> Optional[ModelType]:
    """Replacement for queryset.first() when you don't need ordering;
    queryset.first() can be slow in some cases.
    """

    if result := queryset[:1]:
        return result[0]
    return None


def fast_first_or_create(model: Type[ModelType], **model_params) -> ModelType:
    """Like get_or_create, but using fast_first() instead of first().

    Additionally, unlike get_or_create, this will not raise an exception if more than
    one instance matches the given params, making it a safer choice than get_or_create
    for models that don't have a uniqueness constraint on the fields used.
    """
    if instance := fast_first(model.objects.filter(**model_params)):
        return instance
    return model.objects.create(**model_params)
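
# Illustrative usage sketch (`OrganizationMember` is a hypothetical model, not part of
# this module): safe even if duplicate rows already exist for the given params, because
# the first match is returned instead of raising MultipleObjectsReturned.
#
#   member = fast_first_or_create(OrganizationMember, user=user, organization=org)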


def batch_update_with_retry(queryset, batch_size=500, max_retries=3, **update_fields):
    """
    Update objects in batches with retry logic to handle deadlocks.

    Args:
        queryset: QuerySet of objects to update
        batch_size: Number of objects to update in each batch
        max_retries: Maximum number of retry attempts for each batch
        **update_fields: Fields to update (e.g., overlap=1)
    """
    object_ids = list(queryset.values_list('id', flat=True))
    total_objects = len(object_ids)

    for i in range(0, total_objects, batch_size):
        batch_ids = object_ids[i : i + batch_size]
        retry_count = 0
        last_error = None

        while retry_count < max_retries:
            try:
                with transaction.atomic():
                    queryset.model.objects.filter(id__in=batch_ids).update(**update_fields)
                break
            except OperationalError as e:
                last_error = e
                if 'deadlock detected' in str(e):
                    retry_count += 1
                    wait_time = 0.1 * (2**retry_count)  # Exponential backoff
                    logger.warning(
                        f'Deadlock detected, retry {retry_count}/{max_retries} '
                        f'for batch {i}-{i+len(batch_ids)}. Waiting {wait_time}s...'
                    )
                    time.sleep(wait_time)
                else:
                    raise
        else:
            logger.error(f'Failed to update batch after {max_retries} retries. Batch: {i}-{i+len(batch_ids)}')
            raise last_error
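
# Illustrative usage sketch (hypothetical `Task` model with an `overlap` field, following
# the docstring's example): update in batches of 200, retrying each batch on deadlock.
#
#   batch_update_with_retry(Task.objects.filter(project=project), batch_size=200, overlap=1)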


def batch_delete(queryset, batch_size=500):
    """
    Delete objects in batches to minimize memory usage and transaction size.

    Args:
        queryset: The queryset to delete
        batch_size: Number of objects to delete in each batch

    Returns:
        int: Total number of deleted objects
    """
    total_deleted = 0

    # Create a database cursor that yields primary keys without loading all into memory.
    # The iterator position is maintained between calls to islice.
    # Example: if the queryset has 1500 records and batch_size=500:
    #   - First iteration will get records 1-500
    #   - Second iteration will get records 501-1000
    #   - Third iteration will get records 1001-1500
    #   - Fourth iteration will get an empty list (no more records)
    pks_to_delete = queryset.values_list('pk', flat=True).iterator(chunk_size=batch_size)

    # Delete in batches
    while True:
        # Get the next batch of primary keys from where the iterator left off;
        # islice advances the iterator's position after taking batch_size items
        batch_iterator = itertools.islice(pks_to_delete, batch_size)

        # Convert the slice iterator to a list we can use.
        # This only loads batch_size items into memory at a time
        batch = list(batch_iterator)

        # If no more items to process, we're done.
        # This happens when the iterator is exhausted
        if not batch:
            break

        # Delete the batch in a transaction
        with transaction.atomic():
            # Delete all objects whose primary keys are in this batch
            deleted = queryset.model.objects.filter(pk__in=batch).delete()[0]
            total_deleted += deleted

    return total_deleted
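
# Illustrative usage sketch (hypothetical `Annotation` model): delete a large queryset
# without holding one huge transaction or materializing every primary key at once.
#
#   removed = batch_delete(Annotation.objects.filter(task__project=project), batch_size=1000)
#   logger.info('Deleted %s annotations', removed)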


# =====================
# Schema helpers
# =====================

_column_presence_cache: Dict[str, Dict[str, Dict[str, bool]]] = {}


def current_db_key() -> str:
    """Return a process-stable identifier for the current DB connection.

    Using vendor + NAME isolates caches between sqlite test DBs and postgres runs,
    avoiding stale lookups across pytest sessions or multi-DB setups.
    """
    try:
        name = str(connection.settings_dict.get('NAME'))
    except Exception as e:
        name = 'unknown'
        logger.error(f'Error getting current DB key: {e}')
    return f'{connection.vendor}:{name}'


def has_column_cached(table_name: str, column_name: str) -> bool:
    """Check if a DB column exists for the given table, with per-process memoization.

    Notes:
        - Uses Django introspection; caches per (table, column) with case-insensitive column keys.
        - Safe during early migrations; returns False on any error.
        - Works in both sync and async contexts by temporarily allowing async-unsafe operations.
    """
    col_key = column_name.lower()
    db_cache = _column_presence_cache.get(current_db_key())
    table_cache = db_cache.get(table_name) if db_cache else None
    if table_cache and col_key in table_cache:
        return table_cache[col_key]

    # Check if we're in an async context and need to allow async-unsafe operations
    import asyncio
    import os

    in_async_context = False
    previous_value = os.environ.get('DJANGO_ALLOW_ASYNC_UNSAFE')

    try:
        asyncio.get_running_loop()
        in_async_context = True
        # Temporarily allow async-unsafe operations for schema introspection.
        # This is safe because:
        # 1. We're only reading DB schema metadata, not data
        # 2. This happens once per process during startup and is cached
        # 3. No concurrent access to the same data
        os.environ['DJANGO_ALLOW_ASYNC_UNSAFE'] = 'true'
    except RuntimeError:
        pass  # Not in async context

    try:
        with connection.cursor() as cursor:
            cols = connection.introspection.get_table_description(cursor, table_name)
            present = any(getattr(col, 'name', '').lower() == col_key for col in cols)
    except (DatabaseError, ProgrammingError):
        present = False
    finally:
        # Restore the previous value
        if in_async_context:
            if previous_value is None:
                os.environ.pop('DJANGO_ALLOW_ASYNC_UNSAFE', None)
            else:
                os.environ['DJANGO_ALLOW_ASYNC_UNSAFE'] = previous_value

    _column_presence_cache.setdefault(current_db_key(), {}).setdefault(table_name, {})[col_key] = present
    return present
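
# Illustrative usage sketch (hypothetical table and column names): guard code paths that
# touch a column which may not exist yet while migrations are still being applied.
#
#   if has_column_cached('task', 'overlap'):
#       queryset = queryset.filter(overlap__gt=1)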


@receiver(post_migrate)
def signal_clear_column_presence_cache(**_kwargs):
    """If a migration adds a column, clear the column presence cache so that later
    migrations (and other callers) can introspect the new column via has_column_cached()."""
    logger.debug('Clearing column presence cache in post_migrate signal')
    _column_presence_cache.clear()
|