from collections import namedtuple, UserString
import itertools
from django.db import connections
from django.db import models
from django.db.models.sql.compiler import SQLCompiler
from django.utils import timezone
from psycopg2.extensions import AsIs
class UpdateField(UserString):
def __init__(self, field, expression=None):
self.data = field
self.expression = expression
class UpsertResult(list):
"""
Returned by the upsert operation.
Wraps a list and provides properties to access created, updated,
untouched, and deleted elements
"""
@property
def created(self):
return (i for i in self if i.status_ == 'c')
@property
def updated(self):
return (i for i in self if i.status_ == 'u')
@property
def untouched(self):
return (i for i in self if i.status_ == 'n')
@property
def deleted(self):
return (i for i in self if i.status_ == 'd')
def _quote(field):
return '"{0}"'.format(field)
def _get_update_fields(queryset, to_update, exclude=None):
"""
Get the fields to be updated in an upsert.
Always exclude auto_now_add, auto_created fields, and unique fields in an
update
"""
exclude = exclude or []
model = queryset.model
fields = {
**{field.attname: field for field in model._meta.fields},
**{field.name: field for field in model._meta.fields},
}
if to_update is None:
to_update = [field.attname for field in model._meta.fields]
to_update = [
attname
for attname in to_update
if (
attname not in exclude
and not getattr(fields[attname], 'auto_now_add', False)
and not fields[attname].auto_created
)
]
return to_update
def _fill_auto_fields(queryset, values):
"""
Given a list of models, fill in auto_now and auto_now_add fields
for upserts. Since django manager utils passes Django's ORM, these values
have to be automatically constructed
"""
model = queryset.model
auto_field_names = [
f.attname
for f in model._meta.fields
if getattr(f, 'auto_now', False) or getattr(f, 'auto_now_add', False)
]
now = timezone.now()
for value in values:
for f in auto_field_names:
setattr(value, f, now)
return values
def _prep_sql_args(queryset, connection, cursor, sql_args):
compiler = SQLCompiler(query=queryset.query, connection=connection, using=queryset.using)
return [
AsIs(cursor.mogrify(*sql_arg.as_sql(compiler, connection)).decode('utf-8'))
if hasattr(sql_arg, 'as_sql')
else sql_arg
for sql_arg in sql_args
]
def _get_field_db_val(queryset, field, value, connection):
if hasattr(value, 'resolve_expression'): # pragma: no cover
# Handle cases when the field is of type "Func" and other expressions.
# This is useful for libraries like django-rdkit that can't easily be tested
return value.resolve_expression(queryset.query, allow_joins=False, for_save=True)
else:
return field.get_db_prep_save(value, connection)
def _sort_by_unique_fields(queryset, model_objs, unique_fields):
"""
Sort a list of models by their unique fields.
Sorting models in an upsert greatly reduces the chances of deadlock
when doing concurrent upserts
"""
model = queryset.model
connection = connections[queryset.db]
unique_fields = [field for field in model._meta.fields if field.attname in unique_fields]
def sort_key(model_obj):
return tuple(
_get_field_db_val(queryset, field, getattr(model_obj, field.attname), connection)
for field in unique_fields
)
return sorted(model_objs, key=sort_key)
def _get_values_for_row(queryset, model_obj, all_fields):
connection = connections[queryset.db]
return [
# Convert field value to db value
# Use attname here to support fields with custom db_column names
_get_field_db_val(queryset, field, getattr(model_obj, field.attname), connection)
for field in all_fields
]
def _get_values_for_rows(queryset, model_objs, all_fields):
connection = connections[queryset.db]
row_values = []
sql_args = []
for i, model_obj in enumerate(model_objs):
sql_args.extend(_get_values_for_row(queryset, model_obj, all_fields))
if i == 0:
row_values.append(
'({0})'.format(
', '.join(['%s::{0}'.format(f.db_type(connection)) for f in all_fields])
)
)
else:
row_values.append('({0})'.format(', '.join(['%s'] * len(all_fields))))
return row_values, sql_args
def _get_return_fields_sql(returning, return_status=False, alias=None):
if alias:
return_fields_sql = ', '.join(
'{0}.{1}'.format(alias, _quote(field)) for field in returning
)
else:
return_fields_sql = ', '.join(_quote(field) for field in returning)
if return_status:
return_fields_sql += ', CASE WHEN xmax = 0 THEN \'c\' ELSE \'u\' END AS status_'
return return_fields_sql
def _get_upsert_sql(
queryset,
model_objs,
unique_fields,
update_fields,
returning,
ignore_duplicate_updates=True,
return_untouched=False,
):
"""
Generates the postgres specific sql necessary to perform an upsert
(ON CONFLICT) INSERT INTO table_name (field1, field2)
VALUES (1, 'two')
ON CONFLICT (unique_field) DO UPDATE SET field2 = EXCLUDED.field2;
"""
model = queryset.model
update_expressions = {f: f.expression for f in update_fields if getattr(f, 'expression', None)}
# Use all fields except pk unless the uniqueness constraint is the pk field
all_fields = [
field
for field in model._meta.fields
if field.column in unique_fields or not isinstance(field, models.AutoField)
]
all_field_names = [field.column for field in all_fields]
returning = returning if returning is not True else [f.column for f in model._meta.fields]
all_field_names_sql = ', '.join([_quote(field) for field in all_field_names])
# Convert field names to db column names
unique_fields = [model._meta.get_field(unique_field) for unique_field in unique_fields]
update_fields = [model._meta.get_field(update_field) for update_field in update_fields]
row_values, sql_args = _get_values_for_rows(queryset, model_objs, all_fields)
unique_field_names_sql = ', '.join([_quote(field.column) for field in unique_fields])
update_fields_expressions = {
field.column: f'EXCLUDED.{_quote(field.column)}' for field in update_fields
}
if update_expressions:
connection = connections[queryset.db]
compiler = SQLCompiler(query=queryset.query, connection=connection, using=queryset.using)
with connection.cursor() as cursor:
for field_name, expr in update_expressions.items():
expr = expr.resolve_expression(queryset.query, allow_joins=False, for_save=True)
update_fields_expressions[
model._meta.get_field(field_name).column
] = cursor.mogrify(*expr.as_sql(compiler, connection)).decode('utf-8')
update_fields_sql = ', '.join(
f'{_quote(field.column)} = {update_fields_expressions[field.column]}'
for field in update_fields
)
return_sql = (
'RETURNING ' + _get_return_fields_sql(returning, return_status=True) if returning else ''
)
ignore_duplicates_sql = ''
if ignore_duplicate_updates:
ignore_duplicates_sql = (
' WHERE ({update_fields_sql}) IS DISTINCT FROM ({excluded_update_fields_sql}) '
).format(
update_fields_sql=', '.join(
'{0}.{1}'.format(model._meta.db_table, _quote(field.column))
for field in update_fields
),
excluded_update_fields_sql=', '.join(update_fields_expressions.values()),
)
on_conflict = (
'DO UPDATE SET {0} {1}'.format(update_fields_sql, ignore_duplicates_sql)
if update_fields
else 'DO NOTHING'
)
if return_untouched:
row_values_sql = ', '.join(
['(\'{0}\', {1})'.format(i, row_value[1:-1]) for i, row_value in enumerate(row_values)]
)
sql = (
' WITH input_rows("temp_id_", {all_field_names_sql}) AS ('
' VALUES {row_values_sql}'
' ), ins AS ( '
' INSERT INTO {table_name} ({all_field_names_sql})'
' SELECT {all_field_names_sql} FROM input_rows ORDER BY temp_id_'
' ON CONFLICT ({unique_field_names_sql}) {on_conflict} {return_sql}'
' )'
' SELECT DISTINCT ON ({table_pk_name}) * FROM ('
' SELECT status_, {return_fields_sql}'
' FROM ins'
' UNION ALL'
' SELECT \'n\' AS status_, {aliased_return_fields_sql}'
' FROM input_rows'
' JOIN {table_name} c USING ({unique_field_names_sql})'
' ) as results'
' ORDER BY results."{table_pk_name}", CASE WHEN(status_ = \'n\') THEN 1 ELSE 0 END;'
).format(
all_field_names_sql=all_field_names_sql,
row_values_sql=row_values_sql,
table_name=model._meta.db_table,
unique_field_names_sql=unique_field_names_sql,
on_conflict=on_conflict,
return_sql=return_sql,
table_pk_name=model._meta.pk.name,
return_fields_sql=_get_return_fields_sql(returning),
aliased_return_fields_sql=_get_return_fields_sql(returning, alias='c'),
)
else:
row_values_sql = ', '.join(row_values)
sql = (
' INSERT INTO {table_name} ({all_field_names_sql})'
' VALUES {row_values_sql}'
' ON CONFLICT ({unique_field_names_sql}) {on_conflict} {return_sql}'
).format(
table_name=model._meta.db_table,
all_field_names_sql=all_field_names_sql,
row_values_sql=row_values_sql,
unique_field_names_sql=unique_field_names_sql,
on_conflict=on_conflict,
return_sql=return_sql,
)
return sql, sql_args
def _fetch(
queryset,
model_objs,
unique_fields,
update_fields,
returning,
sync,
ignore_duplicate_updates=True,
return_untouched=False,
):
"""
Perfom the upsert and do an optional sync operation
"""
model = queryset.model
connection = connections[queryset.db]
if (return_untouched or sync) and returning is not True:
returning = set(returning) if returning else set()
returning.add(model._meta.pk.name)
upserted = []
deleted = []
# We must return untouched rows when doing a sync operation
return_untouched = True if sync else return_untouched
if model_objs:
sql, sql_args = _get_upsert_sql(
queryset,
model_objs,
unique_fields,
update_fields,
returning,
ignore_duplicate_updates=ignore_duplicate_updates,
return_untouched=return_untouched,
)
with connection.cursor() as cursor:
sql_args = _prep_sql_args(queryset, connection, cursor, sql_args)
cursor.execute(sql, sql_args)
if cursor.description:
nt_result = namedtuple('Result', [col[0] for col in cursor.description])
upserted = [nt_result(*row) for row in cursor.fetchall()]
pk_field = model._meta.pk.name
if sync:
orig_ids = queryset.values_list(pk_field, flat=True)
deleted = set(orig_ids) - {getattr(r, pk_field) for r in upserted}
model.objects.filter(pk__in=deleted).delete()
nt_deleted_result = namedtuple('DeletedResult', [model._meta.pk.name, 'status_'])
return UpsertResult(
upserted + [nt_deleted_result(**{pk_field: d, 'status_': 'd'}) for d in deleted]
)
def _upsert(
queryset,
model_objs,
unique_fields,
update_fields=None,
returning=False,
sync=False,
ignore_duplicate_updates=True,
return_untouched=False,
):
"""
Perform a bulk upsert on a table, optionally syncing the results.
Args:
queryset (Model|QuerySet): A model or a queryset that defines the
collection to sync
model_objs (List[Model]): A list of Django models to sync. All models
in this list will be bulk upserted and any models not in the table
(or queryset) will be deleted if sync=True.
unique_fields (List[str]): A list of fields that define the uniqueness
of the model. The model must have a unique constraint on these
fields
update_fields (List[str], default=None): A list of fields to update
whenever objects already exist. If an empty list is provided, it
is equivalent to doing a bulk insert on the objects that don't
exist. If `None`, all fields will be updated. If you want to
perform an expression such as an ``F`` object on a field when
it is updated, use the ``pgbulk.UpdateField`` class. See
examples below.
returning (bool|List[str]): If True, returns all fields. If a list,
only returns fields in the list
sync (bool, default=False): Perform a sync operation on the queryset
ignore_duplicate_updates (bool, default=True): Don't perform an update
if all columns are identical to the row in the database.
return_untouched (bool, default=False): When
``ignore_duplicate_updates`` is ``True``, untouched rows will not
be returned in upsert results. Set this to ``True`` to return
untouched rows.
"""
queryset = queryset if isinstance(queryset, models.QuerySet) else queryset.objects.all()
# Populate automatically generated fields in the rows like date times
_fill_auto_fields(queryset, model_objs)
# Sort the rows to reduce the chances of deadlock during concurrent upserts
model_objs = _sort_by_unique_fields(queryset, model_objs, unique_fields)
update_fields = _get_update_fields(queryset, update_fields, exclude=unique_fields)
return _fetch(
queryset,
model_objs,
unique_fields,
update_fields,
returning,
sync,
ignore_duplicate_updates=ignore_duplicate_updates,
return_untouched=return_untouched,
)
[docs]def update(queryset, model_objs, update_fields=None):
"""
Bulk updates a list of model objects that are already saved.
Args:
queryset (QuerySet): The queryset to use when bulk updating
model_objs (List[Model]): Model object values to use for the update
update_fields (List[str], default=None): A list of fields on the
model objects to update. If ``None``, all fields will be updated.
Example:
Update an attribute of multiple models in bulk::
import pgbulk
pgbulk.update(
MyModel,
[
MyModel(id=1, some_attr='some_val1'),
MyModel(id=2, some_attr='some_val2')
],
# These are the fields that will be updated. If not provided,
# all fields will be updated
['some_attr']
)
"""
queryset = queryset if isinstance(queryset, models.QuerySet) else queryset.objects.all()
connection = connections[queryset.db]
model = queryset.model
update_fields = _get_update_fields(queryset, update_fields)
# Add the pk to the value fields so we can join during the update
value_fields = [model._meta.pk.attname] + update_fields
row_values = [
[
_get_field_db_val(
queryset,
model_obj._meta.get_field(field),
getattr(model_obj, model_obj._meta.get_field(field).attname),
connection,
)
for field in value_fields
]
for model_obj in model_objs
]
# If we do not have any values or fields to update, just return
if len(row_values) == 0 or len(update_fields) == 0:
return []
db_types = [model._meta.get_field(field).db_type(connection) for field in value_fields]
value_fields_sql = ', '.join(
'"{field}"'.format(field=model._meta.get_field(field).column) for field in value_fields
)
update_fields_sql = ', '.join(
[
'"{field}" = "new_values"."{field}"'.format(field=model._meta.get_field(field).column)
for field in update_fields
]
)
values_sql = ', '.join(
[
'({0})'.format(
', '.join(
[
'%s::{0}'.format(db_types[i]) if not row_number and i else '%s'
for i, _ in enumerate(row)
]
)
)
for row_number, row in enumerate(row_values)
]
)
update_sql = (
'UPDATE {table} '
'SET {update_fields_sql} '
'FROM (VALUES {values_sql}) AS new_values ({value_fields_sql}) '
'WHERE "{table}"."{pk_field}" = "new_values"."{pk_field}"'
).format(
table=model._meta.db_table,
pk_field=model._meta.pk.column,
update_fields_sql=update_fields_sql,
values_sql=values_sql,
value_fields_sql=value_fields_sql,
)
update_sql_params = list(itertools.chain(*row_values))
with connection.cursor() as cursor:
update_sql_params = _prep_sql_args(queryset, connection, cursor, update_sql_params)
return cursor.execute(update_sql, update_sql_params)
[docs]def upsert(
queryset,
model_objs,
unique_fields,
update_fields=None,
returning=False,
ignore_duplicate_updates=True,
return_untouched=False,
):
"""
Perform a bulk upsert on a table.
Args:
queryset (Model|QuerySet): A model or a queryset that defines the
collection to upsert
model_objs (List[Model]): A list of Django models to upsert. All models
in this list will be bulk upserted.
unique_fields (List[str]): A list of fields that define the uniqueness
of the model. The model must have a unique constraint on these
fields
update_fields (List[str], default=None): A list of fields to update
whenever objects already exist. If an empty list is provided, it
is equivalent to doing a bulk insert on the objects that don't
exist. If `None`, all fields will be updated. If you want to
perform an expression such as an ``F`` object on a field when
it is updated, use the ``pgbulk.UpdateField`` class. See
examples below.
returning (bool|List[str], default=False): If True, returns all fields.
If a list, only returns fields in the list. If False, do not
return results from the upsert.
ignore_duplicate_updates (bool, default=True): Don't perform an update
if all columns are identical to the row in the database.
return_untouched (bool, default=False): When
``ignore_duplicate_updates`` is ``True``, untouched rows will not
be returned in upsert results. Set this to ``True`` to return
untouched rows.
Examples:
A basic bulk upsert on a model::
import pgbulk
pgbulk.upsert(
MyModel,
[
MyModel(int_field=1, some_attr='some_val1'),
MyModel(int_field=2, some_attr='some_val2')
],
# These are the fields that identify the uniqueness constraint.
['int_field'],
# These are the fields that will be updated if the row already
# exists. If not provided, all fields will be updated
['some_attr']
)
Return the results of an upsert::
results = pgbulk.upsert(
MyModel,
[
MyModel(int_field=1, some_attr='some_val1'),
MyModel(int_field=2, some_attr='some_val2')
],
['int_field'],
['some_attr'],
# ``True`` will return all columns. One can also explicitly
# list which columns will be returned
returning=True
)
# Print which results were created
print(results.created)
# Print which results were updated.
# By default, if an update results in no changes, it will not
# be updated and will not be returned.
print(results.updated)
Upsert values and update rows even when the update is meaningless
(i.e. a duplicate update). This is turned off by default, but it
can be enabled in case postgres triggers or other processes
need to happen as a result of an update::
pgbulk.upsert(
MyModel,
[
MyModel(int_field=1, some_attr='some_val1'),
MyModel(int_field=2, some_attr='some_val2')
],
['int_field'],
['some_attr'],
# Perform updates in the database even if it's a duplicate
# update.
ignore_duplicate_updates=False
)
Upsert values and ignore duplicate updates, but still return
the rows that were untouched by the upsert::
results = pgbulk.upsert(
MyModel,
[
MyModel(int_field=1, some_attr='some_val1'),
MyModel(int_field=2, some_attr='some_val2')
],
['int_field'],
['some_attr'],
returning=True,
# Even though we don't perform an update on a duplicate,
# return the row that was part of the upsert but untouched
# This option is only meaningful when
# ignore_duplicate_updates=True (the default)
return_untouched=True,
)
# Print untouched rows
print(results.untouched)
Use an expression for a field if an update happens. In the example
below, we increment ``some_int_field`` by one whenever an update happens.
Otherwise it defaults to zero::
results = pgbulk.upsert(
MyModel,
[
MyModel(some_int_field=0, some_key='a'),
MyModel(some_int_field=0, some_key='b')
],
['some_key'],
[
# Use the UpdateField class to specify an expression for the update.
# If provided, the expression will be executed when an update happens.
pgbulk.UpdateField('some_int_field', expression=models.F('some_int_field') + 1)
],
)
"""
return _upsert(
queryset,
model_objs,
unique_fields,
update_fields=update_fields,
returning=returning,
ignore_duplicate_updates=ignore_duplicate_updates,
return_untouched=return_untouched,
sync=False,
)
[docs]def sync(
queryset,
model_objs,
unique_fields,
update_fields=None,
returning=False,
ignore_duplicate_updates=True,
return_untouched=False,
):
"""
Perform a bulk sync on a table. A sync is an upsert on a list of model
objects followed by a delete on the queryset whose elements were not
in the list of models.
Args:
queryset (Model|QuerySet): A model or a queryset that defines the
collection to sync. If a model is provided, all rows are
candidates for the sync operation
model_objs (List[Model]): A list of Django models to sync. All models
in this list will be bulk upserted. Any models in the queryset
that are not present in this list will be deleted.
unique_fields (List[str]): A list of fields that define the uniqueness
of the model. The model must have a unique constraint on these
fields
update_fields (List[str], default=None): A list of fields to update
whenever objects already exist. If an empty list is provided, it
is equivalent to doing a bulk insert on the objects that don't
exist. If `None`, all fields will be updated.
returning (bool|List[str]): If True, returns all fields. If a list,
only returns fields in the list
ignore_duplicate_updates (bool, default=True): Don't perform an update
if the row is a duplicate.
Examples:
Sync two elements to a table with three elements. The sync will
upsert two elements and delete the third::
# Assume the MyModel table has objects with int_field=1,2,3
# We will sync this table to two elements
results = pgbulk.sync(
MyModel.objects.all(),
[
MyModel(int_field=1, some_attr='some_val1'),
MyModel(int_field=2, some_attr='some_val2')
],
['int_field'],
['some_attr'],
)
# Print created, updated, untouched, and deleted rows from the sync
print(results.created)
print(results.updated)
print(results.untouched)
print(results.deleted)
"""
return _upsert(
queryset,
model_objs,
unique_fields,
update_fields=update_fields,
returning=returning,
ignore_duplicate_updates=ignore_duplicate_updates,
sync=True,
)