import hashlib
import logging
from django.db.backends.utils import truncate_name
from django.db.transaction import atomic
from django.utils import six
from django.utils.encoding import force_bytes
logger = logging.getLogger('django.db.backends.schema')
def _related_non_m2m_objects(old_field, new_field):
# Filters out m2m objects from reverse relations.
# Returns (old_relation, new_relation) tuples.
return zip(
(obj for obj in old_field.model._meta.related_objects if not obj.field.many_to_many),
(obj for obj in new_field.model._meta.related_objects if not obj.field.many_to_many)
)
class BaseDatabaseSchemaEditor(object):
"""
This class (and its subclasses) are responsible for emitting schema-changing
statements to the databases - model creation/removal/alteration, field
renaming, index fiddling, and so on.
It is intended to eventually completely replace DatabaseCreation.
This class should be used by creating an instance for each set of schema
changes (e.g. a migration file) and using it as a context manager: enter
the "with" block, perform the relevant actions, then leave the block.
This is necessary to allow things like circular foreign key references -
deferred statements such as FK constraints are only executed when the
block exits without an error.
"""
# Overrideable SQL templates.
# Each value is a %-style format string with named placeholders that the
# editor fills in from a dict (see create_model() and delete_model() for
# examples); subclasses may replace individual attributes.

# Table-level statements
sql_create_table = "CREATE TABLE %(table)s (%(definition)s)"
sql_rename_table = "ALTER TABLE %(old_table)s RENAME TO %(new_table)s"
sql_retablespace_table = "ALTER TABLE %(table)s SET TABLESPACE %(new_tablespace)s"
sql_delete_table = "DROP TABLE %(table)s CASCADE"

# Column-level statements. The sql_alter_column_* entries carry no
# "ALTER TABLE" prefix of their own; presumably they are substituted into
# %(changes)s of sql_alter_column -- TODO confirm against the alter-field
# code, which is not in this part of the file.
sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(column)s %(definition)s"
sql_alter_column = "ALTER TABLE %(table)s %(changes)s"
sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
sql_alter_column_null = "ALTER COLUMN %(column)s DROP NOT NULL"
sql_alter_column_not_null = "ALTER COLUMN %(column)s SET NOT NULL"
sql_alter_column_default = "ALTER COLUMN %(column)s SET DEFAULT %(default)s"
sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT"
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"

# Check and unique constraints
sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"

sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)"
sql_delete_unique = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"

# Foreign keys
sql_create_fk = (
"ALTER TABLE %(table)s ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) "
"REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
)
# Inline FK clause appended to a column definition. create_model() only uses
# it when connection.features.supports_foreign_keys is false and this
# attribute is set; the default None means no inline clause is emitted.
sql_create_inline_fk = None
sql_delete_fk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"

# Indexes
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
sql_delete_index = "DROP INDEX %(name)s"

# Primary keys
sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
sql_delete_pk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
def __init__(self, connection, collect_sql=False):
self.connection = connection
self.collect_sql = collect_sql
if self.collect_sql:
self.collected_sql = []
# State-managing methods
def __enter__(self):
self.deferred_sql = []
if self.connection.features.can_rollback_ddl:
self.atomic = atomic(self.connection.alias)
self.atomic.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
for sql in self.deferred_sql:
self.execute(sql)
if self.connection.features.can_rollback_ddl:
self.atomic.__exit__(exc_type, exc_value, traceback)
# Core utility functions
def execute(self, sql, params=[]):
"""
Executes the given SQL statement, with optional parameters.
"""
# Log the command we're running, then run it
logger.debug("%s; (params %r)" % (sql, params))
if self.collect_sql:
ending = "" if sql.endswith(";") else ";"
if params is not None:
self.collected_sql.append((sql % tuple(map(self.quote_value, params))) + ending)
else:
self.collected_sql.append(sql + ending)
else:
with self.connection.cursor() as cursor:
cursor.execute(sql, params)
def quote_name(self, name):
return self.connection.ops.quote_name(name)
@classmethod
def _digest(cls, *args):
"""
Generates a 32-bit digest of a set of arguments that can be used to
shorten identifying names.
"""
h = hashlib.md5()
for arg in args:
h.update(force_bytes(arg))
return h.hexdigest()[:8]
# Field <-> database mapping functions
def column_sql(self, model, field, include_default=False):
"""
Takes a field and returns its column definition.
The field must already have had set_attributes_from_name called.
"""
# Get the column's type and use that as the basis of the SQL
db_params = field.db_parameters(connection=self.connection)
sql = db_params['type']
params = []
# Check for fields that aren't actually columns (e.g. M2M)
if sql is None:
return None, None
# Work out nullability
null = field.null
# If we were told to include a default value, do so
include_default = include_default and not self.skip_default(field)
if include_default:
default_value = self.effective_default(field)
if default_value is not None:
if self.connection.features.requires_literal_defaults:
# Some databases can't take defaults as a parameter (oracle)
# If this is the case, the individual schema backend should
# implement prepare_default
sql += " DEFAULT %s" % self.prepare_default(default_value)
else:
sql += " DEFAULT %s"
params += [default_value]
# Oracle treats the empty string ('') as null, so coerce the null
# option whenever '' is a possible value.
if (field.empty_strings_allowed and not field.primary_key and
self.connection.features.interprets_empty_strings_as_nulls):
null = True
if null and not self.connection.features.implied_column_null:
sql += " NULL"
elif not null:
sql += " NOT NULL"
# Primary key/unique outputs
if field.primary_key:
sql += " PRIMARY KEY"
elif field.unique:
sql += " UNIQUE"
# Optionally add the tablespace if it's an implicitly indexed column
tablespace = field.db_tablespace or model._meta.db_tablespace
if tablespace and self.connection.features.supports_tablespaces and field.unique:
sql += " %s" % self.connection.ops.tablespace_sql(tablespace, inline=True)
# Return the sql
return sql, params
def skip_default(self, field):
"""
Some backends don't accept default values for certain columns types
(i.e. MySQL longtext and longblob).
"""
return False
def prepare_default(self, value):
"""
Only used for backends which have requires_literal_defaults feature
"""
raise NotImplementedError(
'subclasses of BaseDatabaseSchemaEditor for backends which have '
'requires_literal_defaults must provide a prepare_default() method'
)
def effective_default(self, field):
    """
    Returns a field's effective database default value, already prepared
    for saving through the field's get_db_prep_save().
    """
    if field.has_default():
        value = field.get_default()
    elif not field.null and field.blank and field.empty_strings_allowed:
        # No explicit default: fall back to an empty value of the string
        # type that matches the field (bytes for BinaryField, text otherwise)
        is_binary = field.get_internal_type() == "BinaryField"
        value = six.binary_type() if is_binary else six.text_type()
    else:
        value = None
    # If it's a callable, call it
    if six.callable(value):
        value = value()
    # Convert to the form that can be sent to the database
    return field.get_db_prep_save(value, self.connection)
def quote_value(self, value):
"""
Returns a quoted version of the value so it's safe to use in an SQL
string. This is not safe against injection from user code; it is
intended only for use in making SQL scripts or preparing default values
for particularly tricky backends (defaults are not user-defined, though,
so this is safe).
"""
raise NotImplementedError()
# Actions
def create_model(self, model):
    """
    Takes a model and creates a table for it in the database.
    Will also create any accompanying indexes or unique constraints, and
    the tables of any auto-created M2M intermediary models.
    """
    opts = model._meta
    column_definitions = []
    sql_params = []
    for field in opts.local_fields:
        definition, field_params = self.column_sql(model, field)
        # Fields without a column of their own are skipped
        if definition is None:
            continue
        # Check constraints can go on the column SQL here
        check = field.db_parameters(connection=self.connection)['check']
        if check:
            definition += " CHECK (%s)" % check
        # Autoincrement SQL (for backends with inline variant)
        type_suffix = field.db_type_suffix(connection=self.connection)
        if type_suffix:
            definition += " %s" % type_suffix
        sql_params.extend(field_params)
        # FK: deferred as a separate statement when the backend supports
        # foreign keys, otherwise written inline if a template exists
        if field.remote_field and field.db_constraint:
            remote_opts = field.remote_field.model._meta
            to_table = remote_opts.db_table
            to_column = remote_opts.get_field(field.remote_field.field_name).column
            if self.connection.features.supports_foreign_keys:
                self.deferred_sql.append(self._create_fk_sql(model, field, "_fk_%(to_table)s_%(to_column)s"))
            elif self.sql_create_inline_fk:
                inline_fk = self.sql_create_inline_fk % {
                    "to_table": self.quote_name(to_table),
                    "to_column": self.quote_name(to_column),
                }
                definition += " " + inline_fk
        # Add the SQL to our big list
        column_definitions.append("%s %s" % (self.quote_name(field.column), definition))
        # Autoincrement SQL (for backends with post table definition variant)
        if field.get_internal_type() == "AutoField":
            autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table, field.column)
            if autoinc_sql:
                self.deferred_sql.extend(autoinc_sql)
    # Add any unique_togethers (always deferred, as some fields might be
    # created afterwards, like geometry fields with some backends)
    for field_names in opts.unique_together:
        columns = [opts.get_field(name).column for name in field_names]
        self.deferred_sql.append(self._create_unique_sql(model, columns))
    # Make the table
    table_sql = self.sql_create_table % {
        "table": self.quote_name(opts.db_table),
        "definition": ", ".join(column_definitions),
    }
    if opts.db_tablespace:
        tablespace_sql = self.connection.ops.tablespace_sql(opts.db_tablespace)
        if tablespace_sql:
            table_sql += ' ' + tablespace_sql
    # Prevent using [] as params, in the case a literal '%' is used in the definition
    self.execute(table_sql, sql_params or None)
    # Add any field index and index_together's (deferred as SQLite3 _remake_table needs it)
    self.deferred_sql.extend(self._model_indexes_sql(model))
    # Make M2M tables
    for m2m_field in opts.local_many_to_many:
        through = m2m_field.remote_field.through
        if through._meta.auto_created:
            self.create_model(through)
def delete_model(self, model):
"""
Deletes a model from the database.
"""
# Handle auto-created intermediary models
for field in model._meta.local_many_to_many:
if field.remote_field.through._meta.auto_created:
self.delete_model(field.remote_field.through)
# Delete the table
self.execute(self.sql_delete_table % {
"table": self.quote_name(model._meta.db_table),
})
def alter_unique_together(self, model, old_unique_together, new_unique_together):
    """
    Deals with a model changing its unique_together: constraints present
    only in the old value are dropped, then constraints present only in the
    new value are created.
    Note: The input unique_togethers must be doubly-nested, not the single-
    nested ["foo", "bar"] format.
    """
    previous = {tuple(names) for names in old_unique_together}
    current = {tuple(names) for names in new_unique_together}
    # Deleted uniques
    for names in previous - current:
        self._delete_composed_index(model, names, {'unique': True}, self.sql_delete_unique)
    # Created uniques
    for names in current - previous:
        columns = [model._meta.get_field(name).column for name in names]
        self.execute(self._create_unique_sql(model, columns))
def alter_index_together(self, model, old_index_together, new_index_together):
    """
    Deals with a model changing its index_together: indexes present only in
    the old value are dropped, then indexes present only in the new value
    are created with an "_idx" suffix.
    Note: The input index_togethers must be doubly-nested, not the single-
    nested ["foo", "bar"] format.
    """
    previous = {tuple(names) for names in old_index_together}
    current = {tuple(names) for names in new_index_together}
    # Deleted indexes
    for names in previous - current:
        self._delete_composed_index(model, names, {'index': True}, self.sql_delete_index)
    # Created indexes
    for names in current - previous:
        index_fields = [model._meta.get_field(name) for name in names]
        self.execute(self._create_index_sql(model, index_fields, suffix="_idx"))
def _delete_composed_index(self, model, fields, constraint_kwargs, sql):
columns = [model._meta.get_field(field).column for field in fields]
constraint_names = self._constraint_names(model, columns, **constraint_kwargs)
if len(constraint_names) != 1:
raise ValueError("Found wrong number (%s) of constraints for %s(%s)" % (
len(constraint_names),
Loading ...