Bin
2025-12-16 9e0b2ba2c317b1a86212f24cbae3195ad1f3dbfa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
"""
Main module with the bulk_update function.
"""
import itertools
from collections import defaultdict
 
from django.db import connections, models
from django.db.models.sql import UpdateQuery
 
 
def _get_db_type(field, connection):
    if isinstance(field, (models.PositiveSmallIntegerField, models.PositiveIntegerField)):
        return field.db_type(connection).split(' ', 1)[0]
 
    return field.db_type(connection)
 
 
def _as_sql(obj, field, query, compiler, connection):
    value = getattr(obj, field.attname)
 
    if hasattr(value, 'resolve_expression'):
        value = value.resolve_expression(query, allow_joins=False, for_save=True)
    else:
        value = field.get_db_prep_save(value, connection=connection)
 
    if hasattr(value, 'as_sql'):
        placeholder, value = compiler.compile(value)
        if isinstance(value, list):
            value = tuple(value)
    else:
        placeholder = '%s'
 
    return value, placeholder
 
 
def flatten(l, types=(list, float)):  # noqa: E741
    """
    Flat nested list of lists into a single list.
    """
    l = [item if isinstance(item, types) else [item] for item in l]  # noqa: E741
    return [item for sublist in l for item in sublist]
 
 
def grouper(iterable, size):
    # http://stackoverflow.com/a/8991553
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, size))
        if not chunk:
            return
        yield chunk
 
 
def validate_fields(meta, fields):
 
    fields = frozenset(fields)
    field_names = set()
 
    for field in meta.fields:
        if not field.primary_key:
            field_names.add(field.name)
 
            if field.name != field.attname:
                field_names.add(field.attname)
 
    non_model_fields = fields.difference(field_names)
 
    if non_model_fields:
        raise TypeError('These fields are not present in ' 'current meta: {}'.format(', '.join(non_model_fields)))
 
 
def get_fields(update_fields, exclude_fields, meta, obj=None):
 
    deferred_fields = set()
 
    if update_fields is not None:
        validate_fields(meta, update_fields)
    elif obj:
        deferred_fields = obj.get_deferred_fields()
 
    if exclude_fields is None:
        exclude_fields = set()
    else:
        exclude_fields = set(exclude_fields)
        validate_fields(meta, exclude_fields)
 
    exclude_fields |= deferred_fields
 
    fields = [
        field
        for field in meta.concrete_fields
        if (
            not field.primary_key
            and field.attname not in deferred_fields
            and field.attname not in exclude_fields
            and field.name not in exclude_fields
            and (update_fields is None or field.attname in update_fields or field.name in update_fields)
        )
    ]
 
    return fields
 
 
def bulk_update(
    objs, meta=None, update_fields=None, exclude_fields=None, using='default', batch_size=None, pk_field='pk'
):
    assert batch_size is None or batch_size > 0
 
    # force to retrieve objs from the DB at the beginning,
    # to avoid multiple subsequent queries
    objs = list(objs)
    if not objs:
        return
    batch_size = batch_size or len(objs)
 
    if meta:
        fields = get_fields(update_fields, exclude_fields, meta)
    else:
        meta = objs[0]._meta
        if update_fields is not None:
            fields = get_fields(update_fields, exclude_fields, meta, objs[0])
        else:
            fields = None
 
    if fields is not None and len(fields) == 0:
        return
 
    if pk_field == 'pk':
        pk_field = meta.get_field(meta.pk.name)
    else:
        pk_field = meta.get_field(pk_field)
 
    connection = connections[using]
    query = UpdateQuery(meta.model)
    compiler = query.get_compiler(connection=connection)
 
    template = '"{column}" = CAST(CASE "{pk_column}" {cases}ELSE "{column}" END AS {type})'
 
    case_template = 'WHEN %s THEN {} '
 
    lenpks = 0
    for objs_batch in grouper(objs, batch_size):
 
        pks = []
        parameters = defaultdict(list)
        placeholders = defaultdict(list)
 
        for obj in objs_batch:
 
            pk_value, _ = _as_sql(obj, pk_field, query, compiler, connection)
            pks.append(pk_value)
 
            loaded_fields = fields or get_fields(update_fields, exclude_fields, meta, obj)
 
            for field in loaded_fields:
                value, placeholder = _as_sql(obj, field, query, compiler, connection)
                parameters[field].extend(flatten([pk_value, value], types=tuple))
                placeholders[field].append(placeholder)
 
        values = ', '.join(
            template.format(
                column=field.column,
                pk_column=pk_field.column,
                cases=(case_template * len(placeholders[field])).format(*placeholders[field]),
                type=_get_db_type(field, connection=connection),
            )
            for field in parameters.keys()
        )
 
        parameters = flatten(parameters.values(), types=list)
        parameters.extend(pks)
 
        n_pks = len(pks)
        del pks
 
        dbtable = '"{}"'.format(meta.db_table)
 
        in_clause = '"{pk_column}" in ({pks})'.format(
            pk_column=pk_field.column,
            pks=', '.join(itertools.repeat('%s', n_pks)),
        )
 
        sql = 'UPDATE {dbtable} SET {values} WHERE {in_clause}'.format(  # nosec
            dbtable=dbtable,
            values=values,
            in_clause=in_clause,
        )
        del values
 
        lenpks += n_pks
 
        connection.cursor().execute(sql, parameters)
 
    return lenpks