"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
"""
Main module with the bulk_update function.
"""
import itertools
from collections import defaultdict

from django.db import connections, models
from django.db.models.sql import UpdateQuery


def _get_db_type(field, connection):
    """Return the SQL type of ``field`` to use as a ``CAST(... AS <type>)`` target.

    For ``PositiveSmallIntegerField`` / ``PositiveIntegerField`` everything
    after the first space of ``db_type`` is dropped (presumably backend
    modifiers that are not valid inside a CAST target -- TODO confirm per
    backend).
    """
    if isinstance(field, (models.PositiveSmallIntegerField, models.PositiveIntegerField)):
        return field.db_type(connection).split(' ', 1)[0]
    return field.db_type(connection)


def _as_sql(obj, field, query, compiler, connection):
    """Return ``(params, placeholder)`` for the value of ``field`` on ``obj``.

    Values exposing ``resolve_expression`` are resolved against ``query``;
    plain values are converted with ``field.get_db_prep_save``. If the result
    can render itself (has ``as_sql``) it is compiled: the placeholder is the
    compiled SQL fragment and its params are returned as a tuple. Otherwise the
    value is bound through a single ``%s`` placeholder.
    """
    value = getattr(obj, field.attname)
    if hasattr(value, 'resolve_expression'):
        value = value.resolve_expression(query, allow_joins=False, for_save=True)
    else:
        value = field.get_db_prep_save(value, connection=connection)

    if not hasattr(value, 'as_sql'):
        return value, '%s'

    placeholder, value = compiler.compile(value)
    if isinstance(value, list):
        # Tuples are what callers flatten (see ``flatten(..., types=tuple)``).
        value = tuple(value)
    return value, placeholder


def flatten(l, types=(list,)):  # noqa: E741
    """Flatten one level of nesting of ``l`` into a single list.

    Items that are instances of ``types`` are spliced in element by element;
    every other item is kept as one element. ``float`` is intentionally not
    part of the default: floats are not iterable and cannot be spliced.
    """
    result = []
    for item in l:
        if isinstance(item, types):
            result.extend(item)
        else:
            result.append(item)
    return result


def grouper(iterable, size):
    """Yield consecutive tuples of at most ``size`` items from ``iterable``.

    The final tuple may be shorter. Adapted from
    http://stackoverflow.com/a/8991553
    """
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, size))
        if not chunk:
            return
        yield chunk


def validate_fields(meta, fields):
    """Raise ``TypeError`` if ``fields`` contains a name ``meta`` cannot update.

    Accepted names are the ``name`` and the ``attname`` of every
    non-primary-key field in ``meta.fields``.
    """
    fields = frozenset(fields)
    field_names = set()
    for field in meta.fields:
        if not field.primary_key:
            field_names.add(field.name)
            field_names.add(field.attname)
    non_model_fields = fields.difference(field_names)
    if non_model_fields:
        # Sorted so the error message is deterministic.
        raise TypeError('These fields are not present in '
                        'current meta: {}'.format(', '.join(sorted(non_model_fields))))


def get_fields(update_fields, exclude_fields, meta, obj=None):
    """Return the concrete non-primary-key fields of ``meta`` to write.

    :param update_fields: optional iterable of field names/attnames; when
        given, only those fields are returned and it is validated against
        ``meta`` (``obj``'s deferred fields are then not consulted).
    :param exclude_fields: optional iterable of field names/attnames to skip;
        validated against ``meta``.
    :param meta: the model's ``_meta`` options.
    :param obj: optional model instance; when ``update_fields`` is None the
        fields reported by ``obj.get_deferred_fields()`` are skipped too.
    :returns: list of fields from ``meta.concrete_fields``.
    :raises TypeError: if ``update_fields`` or ``exclude_fields`` names an
        unknown or primary-key field.
    """
    deferred_fields = set()
    if update_fields is not None:
        # Materialize first: validate_fields() consumes its argument, so a
        # one-shot iterator would be empty for the membership tests below.
        update_fields = frozenset(update_fields)
        validate_fields(meta, update_fields)
    elif obj is not None:
        deferred_fields = obj.get_deferred_fields()

    if exclude_fields is None:
        exclude_fields = set()
    else:
        exclude_fields = set(exclude_fields)
        validate_fields(meta, exclude_fields)
    exclude_fields |= deferred_fields

    return [
        field
        for field in meta.concrete_fields
        if (
            not field.primary_key
            and field.attname not in exclude_fields
            and field.name not in exclude_fields
            and (
                update_fields is None
                or field.attname in update_fields
                or field.name in update_fields
            )
        )
    ]


def bulk_update(
    objs, meta=None, update_fields=None, exclude_fields=None,
    using='default', batch_size=None, pk_field='pk'
):
    """Update many model instances with one ``UPDATE`` statement per batch.

    Each statement has the shape::

        UPDATE <table>
        SET <col> = CAST(CASE <pk> WHEN %s THEN <value> ... ELSE <col> END AS <type>), ...
        WHERE <pk> in (%s, ...)

    Rows whose object did not contribute a value for a column keep the
    column's current value via the ``ELSE`` branch. Identifiers are quoted
    with the backend's ``connection.ops.quote_name``.

    :param objs: iterable of model instances; materialized into a list first.
    :param meta: model ``_meta``. When given, the field list is computed once
        without looking at per-object deferred fields. Defaults to
        ``objs[0]._meta``.
    :param update_fields: optional names of the fields to write.
    :param exclude_fields: optional names of fields to skip.
    :param using: database alias looked up in ``django.db.connections``.
    :param batch_size: maximum objects per statement; ``None`` means a single
        batch. Must be positive.
    :param pk_field: name of the field used to match rows; ``'pk'`` means the
        model's primary key.
    :returns: number of objects included in executed ``UPDATE`` statements,
        or ``None`` when ``objs`` is empty or the resolved field list is empty.
    :raises TypeError: if ``update_fields``/``exclude_fields`` name unknown
        fields (see ``validate_fields``).
    """
    assert batch_size is None or batch_size > 0

    # Force evaluation once (e.g. of a queryset) to avoid repeated queries.
    objs = list(objs)
    if not objs:
        return
    batch_size = batch_size or len(objs)

    if meta:
        fields = get_fields(update_fields, exclude_fields, meta)
    else:
        meta = objs[0]._meta
        if update_fields is not None:
            fields = get_fields(update_fields, exclude_fields, meta, objs[0])
        else:
            # Resolved per object below so each object's deferred fields
            # are honoured.
            fields = None

    if fields is not None and not fields:
        return

    if pk_field == 'pk':
        pk_field = meta.get_field(meta.pk.name)
    else:
        pk_field = meta.get_field(pk_field)

    connection = connections[using]
    quote_name = connection.ops.quote_name
    query = UpdateQuery(meta.model)
    compiler = query.get_compiler(connection=connection)

    template = '{column} = CAST(CASE {pk_column} {cases}ELSE {column} END AS {type})'
    case_template = 'WHEN %s THEN {} '
    dbtable = quote_name(meta.db_table)
    pk_column = quote_name(pk_field.column)

    updated = 0
    # A single cursor for all batches, closed deterministically on exit.
    with connection.cursor() as cursor:
        for objs_batch in grouper(objs, batch_size):
            pks = []
            parameters = defaultdict(list)
            placeholders = defaultdict(list)
            for obj in objs_batch:
                pk_value, _ = _as_sql(obj, pk_field, query, compiler, connection)
                pks.append(pk_value)
                loaded_fields = fields or get_fields(update_fields, exclude_fields, meta, obj)
                for field in loaded_fields:
                    value, placeholder = _as_sql(obj, field, query, compiler, connection)
                    # One (pk, value...) group per WHEN clause, in order.
                    parameters[field].extend(flatten([pk_value, value], types=tuple))
                    placeholders[field].append(placeholder)

            if not parameters:
                # Nothing to write for any object in this batch; an empty
                # SET clause would be invalid SQL.
                continue

            values = ', '.join(
                template.format(
                    column=quote_name(field.column),
                    pk_column=pk_column,
                    cases=(case_template * len(placeholders[field])).format(*placeholders[field]),
                    type=_get_db_type(field, connection=connection),
                )
                for field in parameters
            )
            # Same dict iteration order as the SET clauses above.
            params = flatten(parameters.values(), types=list)
            params.extend(pks)

            in_clause = '{pk_column} in ({pks})'.format(
                pk_column=pk_column,
                pks=', '.join(itertools.repeat('%s', len(pks))),
            )
            sql = 'UPDATE {dbtable} SET {values} WHERE {in_clause}'.format(  # nosec
                dbtable=dbtable,
                values=values,
                in_clause=in_clause,
            )
            cursor.execute(sql, params)
            updated += len(pks)
    return updated