Bin
2025-12-17 d616898802dfe7e5dd648bcf53c6d1f86b6d3642
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import itertools
import logging
import time
from typing import Dict, Optional, TypeVar
 
from django.db import OperationalError, connection, models, transaction
from django.db.models import Model, QuerySet, Subquery
from django.db.models.signals import post_migrate
from django.db.utils import DatabaseError, ProgrammingError
from django.dispatch import receiver
 
logger = logging.getLogger(__name__)
 
 
class SQCount(Subquery):
    template = '(SELECT count(*) FROM (%(subquery)s) _count)'
    output_field = models.IntegerField()
 
 
ModelType = TypeVar('ModelType', bound=Model)
 
 
def fast_first(queryset: QuerySet[ModelType]) -> Optional[ModelType]:
    """Replacement for queryset.first() when you don't need ordering,
    queryset.first() works slowly in some cases
    """
 
    if result := queryset[:1]:
        return result[0]
    return None
 
 
def fast_first_or_create(model, **model_params) -> Optional[ModelType]:
    """Like get_or_create, but using fast_first instead of first(). Additionally, unlike get_or_create, this method will not raise an exception if more than one model instance matching the given params is returned, making it a safer choice than get_or_create for models that don't have a uniqueness constraint on the fields used."""
    if instance := fast_first(model.objects.filter(**model_params)):
        return instance
    return model.objects.create(**model_params)
 
 
def batch_update_with_retry(queryset, batch_size=500, max_retries=3, **update_fields):
    """
    Update objects in batches with retry logic to handle deadlocks.
 
    Args:
        queryset: QuerySet of objects to update
        batch_size: Number of objects to update in each batch
        max_retries: Maximum number of retry attempts for each batch
        **update_fields: Fields to update (e.g., overlap=1)
    """
    object_ids = list(queryset.values_list('id', flat=True))
    total_objects = len(object_ids)
 
    for i in range(0, total_objects, batch_size):
        batch_ids = object_ids[i : i + batch_size]
        retry_count = 0
        last_error = None
 
        while retry_count < max_retries:
            try:
                with transaction.atomic():
                    queryset.model.objects.filter(id__in=batch_ids).update(**update_fields)
                break
            except OperationalError as e:
                last_error = e
                if 'deadlock detected' in str(e):
                    retry_count += 1
                    wait_time = 0.1 * (2**retry_count)  # Exponential backoff
                    logger.warning(
                        f'Deadlock detected, retry {retry_count}/{max_retries} '
                        f'for batch {i}-{i+len(batch_ids)}. Waiting {wait_time}s...'
                    )
                    time.sleep(wait_time)
                else:
                    raise
        else:
            logger.error(f'Failed to update batch after {max_retries} retries. ' f'Batch: {i}-{i+len(batch_ids)}')
            raise last_error
 
 
def batch_delete(queryset, batch_size=500):
    """
    Delete objects in batches to minimize memory usage and transaction size.
 
    Args:
        queryset: The queryset to delete
        batch_size: Number of objects to delete in each batch
 
    Returns:
        int: Total number of deleted objects
    """
    total_deleted = 0
 
    # Create a database cursor that yields primary keys without loading all into memory
    # The iterator position is maintained between calls to islice
    # Example: if queryset has 1500 records and batch_size=500:
    # - First iteration will get records 1-500
    # - Second iteration will get records 501-1000
    # - Third iteration will get records 1001-1500
    # - Fourth iteration will get empty list (no more records)
    pks_to_delete = queryset.values_list('pk', flat=True).iterator(chunk_size=batch_size)
 
    # Delete in batches
    while True:
        # Get the next batch of primary keys from where the iterator left off
        # islice advances the iterator's position after taking batch_size items
        batch_iterator = itertools.islice(pks_to_delete, batch_size)
 
        # Convert the slice iterator to a list we can use
        # This only loads batch_size items into memory at a time
        batch = list(batch_iterator)
 
        # If no more items to process, we're done
        # This happens when the iterator is exhausted
        if not batch:
            break
 
        # Delete the batch in a transaction
        with transaction.atomic():
            # Delete all objects whose primary keys are in this batch
            deleted = queryset.model.objects.filter(pk__in=batch).delete()[0]
            total_deleted += deleted
 
    return total_deleted
 
 
# =====================
# Schema helpers
# =====================
 
_column_presence_cache: Dict[str, Dict[str, Dict[str, bool]]] = {}
 
 
def current_db_key() -> str:
    """Return a process-stable identifier for the current DB connection.
 
    Using vendor + NAME isolates caches between sqlite test DBs and postgres runs,
    avoiding stale lookups across pytest sessions or multi-DB setups.
    """
    try:
        name = str(connection.settings_dict.get('NAME'))
    except Exception as e:
        name = 'unknown'
        logger.error(f'Error getting current DB key: {e}')
    return f'{connection.vendor}:{name}'
 
 
def has_column_cached(table_name: str, column_name: str) -> bool:
    """Check if a DB column exists for the given table, with per-process memoization.
 
    Notes:
    - Uses Django introspection; caches per (table, column) with case-insensitive column keys.
    - Safe during early migrations; returns False on any error.
    - Works in both sync and async contexts by temporarily allowing async-unsafe operations.
    """
    col_key = column_name.lower()
    db_cache = _column_presence_cache.get(current_db_key())
    table_cache = db_cache.get(table_name) if db_cache else None
    if table_cache and col_key in table_cache:
        return table_cache[col_key]
 
    # Check if we're in an async context and need to allow async-unsafe operations
    import asyncio
    import os
 
    in_async_context = False
    previous_value = os.environ.get('DJANGO_ALLOW_ASYNC_UNSAFE')
 
    try:
        asyncio.get_running_loop()
        in_async_context = True
        # Temporarily allow async-unsafe operations for schema introspection
        # This is safe because:
        # 1. We're only reading DB schema metadata, not data
        # 2. This happens once per process during startup and is cached
        # 3. No concurrent access to the same data
        os.environ['DJANGO_ALLOW_ASYNC_UNSAFE'] = 'true'
    except RuntimeError:
        pass  # Not in async context
 
    try:
        with connection.cursor() as cursor:
            cols = connection.introspection.get_table_description(cursor, table_name)
        present = any(getattr(col, 'name', '').lower() == col_key for col in cols)
    except (DatabaseError, ProgrammingError):
        present = False
    finally:
        # Restore the previous value
        if in_async_context:
            if previous_value is None:
                os.environ.pop('DJANGO_ALLOW_ASYNC_UNSAFE', None)
            else:
                os.environ['DJANGO_ALLOW_ASYNC_UNSAFE'] = previous_value
 
    _column_presence_cache.setdefault(current_db_key(), {}).setdefault(table_name, {})[col_key] = present
    return present
 
 
@receiver(post_migrate)
def signal_clear_column_presence_cache(**_kwargs):
    """If some migration adds a column, we need to clear the column_presence_cache
    so that the next migration can introspect the new column using has_column_cached()."""
    logger.debug('Clearing column presence cache in post_migrate signal')
    _column_presence_cache.clear()