Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

hemamaps / django-extensions   python

Repository URL to install this package:

Version: 1.6.7 

/ management / commands / sqldiff.py

# coding=utf-8
"""
sqldiff.py - Prints the (approximated) difference between models and database

TODO:
 - better support for relations
 - better support for constraints (mainly postgresql?)
 - support for table spaces with postgresql
 - when a table is not managed (meta.managed==False) then only do a one-way
   sqldiff: show differences from db->table but not the other way around, since
   it's not managed.

KNOWN ISSUES:
 - MySQL has by far the most problems with introspection. Please be
   careful when using MySQL with sqldiff.
   - Booleans are reported back as Integers, so there's no way to know if
     there was a real change.
   - Varchar sizes are reported back without unicode support so their size
     may change in comparison to the real length of the varchar.
   - Some of the 'fixes' to counter these problems might create false
     positives or false negatives.
"""

import sys

import django
import six
from django.core.management import CommandError, sql as _sql
from django.core.management.color import no_style
from django.db import connection, transaction
from django.db.models.fields import AutoField, IntegerField

from django_extensions.compat import get_app_models
from django_extensions.management.utils import signalcommand
from django_extensions.compat import CompatibilityBaseCommand as BaseCommand

try:
    from django.core.management.base import OutputWrapper
    HAS_OUTPUTWRAPPER = True
except ImportError:
    HAS_OUTPUTWRAPPER = False


# Stand-in model field for a nullable integer column named ``_order``.
# Presumably mirrors the column Django adds for ``Meta.order_with_respect_to``;
# its users are outside this chunk -- TODO confirm.
ORDERING_FIELD = IntegerField('_order', null=True)


def flatten(l, ltypes=(list, tuple)):
    """Return the leaves of an arbitrarily nested sequence, in order.

    Every element that is an instance of ``ltypes`` is expanded in place
    (recursively) and empty nested containers disappear.  The result is
    rebuilt with the type of the outermost argument, e.g.
    ``flatten(((1, 2), (3,))) == (1, 2, 3)``.
    """
    leaves = []
    # Explicit stack of iterators instead of recursion: when a nested
    # container is met, descend into it and resume the parent afterwards.
    pending = [iter(l)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, ltypes):
                pending.append(iter(item))
                break
            leaves.append(item)
        else:
            # current level exhausted -- go back up one level
            pending.pop()
    return type(l)(leaves)


def all_local_fields(meta):
    """Return the local fields of ``meta`` that map to a database column.

    For proxy models the fields are gathered (recursively) from the ``_meta``
    of every parent.  Otherwise ``meta.local_fields`` is filtered, dropping
    fields whose ``db_type()`` on the current connection is ``None``
    (presumably fields without a column of their own).
    """
    if not meta.proxy:
        return [
            field for field in meta.local_fields
            if field.db_type(connection=connection) is not None
        ]
    collected = []
    for parent in meta.parents:
        collected.extend(all_local_fields(parent._meta))
    return collected


class SQLDiff(object):
    # Per-backend overrides consulted by get_field_db_type before the
    # introspection module's ``data_types_reverse``.  Keys are DB-API type
    # codes; values may be a field class name, a dotted field class path, or
    # a ``(name, kwargs)`` tuple.  Empty in the base class.
    DATA_TYPES_REVERSE_OVERRIDE = {}

    # Migration bookkeeping tables (Django / South).  Presumably not reported
    # as missing; the code reading this list is outside this chunk -- TODO confirm.
    IGNORE_MISSING_TABLES = [
        "django_migrations",
        "south_migrationhistory",
    ]

    # Every difference type that add_difference accepts.
    DIFF_TYPES = [
        'error',
        'comment',
        'table-missing-in-db',
        'table-missing-in-model',
        'field-missing-in-db',
        'field-missing-in-model',
        'fkey-missing-in-db',
        'fkey-missing-in-model',
        'index-missing-in-db',
        'index-missing-in-model',
        'unique-missing-in-db',
        'unique-missing-in-model',
        'field-type-differ',
        'field-parameter-differ',
        'notnull-differ',
    ]
    # Human-readable message templates per difference type.  The ``%(0)s``,
    # ``%(1)s``, ... placeholders presumably refer to the positional args
    # passed to add_difference (the formatting code is outside this chunk).
    DIFF_TEXTS = {
        'error': 'error: %(0)s',
        'comment': 'comment: %(0)s',
        'table-missing-in-db': "table '%(0)s' missing in database",
        'table-missing-in-model': "table '%(0)s' missing in models",
        'field-missing-in-db': "field '%(1)s' defined in model but missing in database",
        'field-missing-in-model': "field '%(1)s' defined in database but missing in model",
        'fkey-missing-in-db': "field '%(1)s' FOREIGN KEY defined in model but missing in database",
        'fkey-missing-in-model': "field '%(1)s' FOREIGN KEY defined in database but missing in model",
        'index-missing-in-db': "field '%(1)s' INDEX defined in model but missing in database",
        'index-missing-in-model': "field '%(1)s' INDEX defined in database schema but missing in model",
        'unique-missing-in-db': "field '%(1)s' UNIQUE defined in model but missing in database",
        'unique-missing-in-model': "field '%(1)s' UNIQUE defined in database schema but missing in model",
        'field-type-differ': "field '%(1)s' not of same type: db='%(3)s', model='%(2)s'",
        'field-parameter-differ': "field '%(1)s' parameters differ: db='%(3)s', model='%(2)s'",
        'notnull-differ': "field '%(1)s' null constraint should be '%(2)s' in the database",
    }

    # SQL renderers, one per difference type; __init__ maps diff types to them
    # in ``self.DIFF_SQL``.  Each takes (self, style, qn, args): ``style``
    # supplies the SQL_* / NOTICE / ERROR colouring helpers, ``qn`` presumably
    # quotes an identifier, and ``args`` are the positional args recorded by
    # add_difference.
    # NOTE(review): the "MODIFY" renderers use MySQL-style ALTER TABLE syntax;
    # backends needing different syntax presumably override them.
    SQL_FIELD_MISSING_IN_DB = lambda self, style, qn, args: "%s %s\n\t%s %s %s;" % (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(args[0])), style.SQL_KEYWORD('ADD COLUMN'), style.SQL_FIELD(qn(args[1])), ' '.join(style.SQL_COLTYPE(a) if i == 0 else style.SQL_KEYWORD(a) for i, a in enumerate(args[2:])))
    SQL_FIELD_MISSING_IN_MODEL = lambda self, style, qn, args: "%s %s\n\t%s %s;" % (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(args[0])), style.SQL_KEYWORD('DROP COLUMN'), style.SQL_FIELD(qn(args[1])))
    SQL_FKEY_MISSING_IN_DB = lambda self, style, qn, args: "%s %s\n\t%s %s %s %s %s (%s)%s;" % (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(args[0])), style.SQL_KEYWORD('ADD COLUMN'), style.SQL_FIELD(qn(args[1])), ' '.join(style.SQL_COLTYPE(a) if i == 0 else style.SQL_KEYWORD(a) for i, a in enumerate(args[4:])), style.SQL_KEYWORD('REFERENCES'), style.SQL_TABLE(qn(args[2])), style.SQL_FIELD(qn(args[3])), connection.ops.deferrable_sql())
    # Index name is '_'.join of the non-empty args[0:3] (table, column,
    # optional suffix such as 'like'); args[3] is an optional operator class.
    SQL_INDEX_MISSING_IN_DB = lambda self, style, qn, args: "%s %s\n\t%s %s (%s%s);" % (style.SQL_KEYWORD('CREATE INDEX'), style.SQL_TABLE(qn("%s" % '_'.join(a for a in args[0:3] if a))), style.SQL_KEYWORD('ON'), style.SQL_TABLE(qn(args[0])), style.SQL_FIELD(qn(args[1])), style.SQL_KEYWORD(args[3]))
    # FIXME: need to lookup index name instead of just appending _idx to table + fieldname
    # NOTE(review): the code below actually joins the non-empty args[0:3] with
    # '_' (no '_idx' suffix is appended), mirroring SQL_INDEX_MISSING_IN_DB.
    SQL_INDEX_MISSING_IN_MODEL = lambda self, style, qn, args: "%s %s;" % (style.SQL_KEYWORD('DROP INDEX'), style.SQL_TABLE(qn("%s" % '_'.join(a for a in args[0:3] if a))))
    SQL_UNIQUE_MISSING_IN_DB = lambda self, style, qn, args: "%s %s\n\t%s %s (%s);" % (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(args[0])), style.SQL_KEYWORD('ADD'), style.SQL_KEYWORD('UNIQUE'), style.SQL_FIELD(qn(args[1])))
    # FIXME: need to lookup unique constraint name instead of appending _key to table + fieldname
    SQL_UNIQUE_MISSING_IN_MODEL = lambda self, style, qn, args: "%s %s\n\t%s %s %s;" % (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(args[0])), style.SQL_KEYWORD('DROP'), style.SQL_KEYWORD('CONSTRAINT'), style.SQL_TABLE(qn("%s_key" % ('_'.join(args[:2])))))
    SQL_FIELD_TYPE_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s;" % (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(args[0])), style.SQL_KEYWORD("MODIFY"), style.SQL_FIELD(qn(args[1])), style.SQL_COLTYPE(args[2]))
    SQL_FIELD_PARAMETER_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s;" % (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(args[0])), style.SQL_KEYWORD("MODIFY"), style.SQL_FIELD(qn(args[1])), style.SQL_COLTYPE(args[2]))
    SQL_NOTNULL_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s %s;" % (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(args[0])), style.SQL_KEYWORD('MODIFY'), style.SQL_FIELD(qn(args[1])), style.SQL_KEYWORD(args[2]), style.SQL_KEYWORD('NOT NULL'))
    # The remaining renderers only emit SQL comments (no executable SQL).
    SQL_ERROR = lambda self, style, qn, args: style.NOTICE('-- Error: %s' % style.ERROR(args[0]))
    SQL_COMMENT = lambda self, style, qn, args: style.NOTICE('-- Comment: %s' % style.SQL_TABLE(args[0]))
    SQL_TABLE_MISSING_IN_DB = lambda self, style, qn, args: style.NOTICE('-- Table missing: %s' % args[0])
    SQL_TABLE_MISSING_IN_MODEL = lambda self, style, qn, args: style.NOTICE('-- Model missing for table: %s' % args[0])

    # Backend capability flags: when True, __init__ calls load_null /
    # load_unsigned respectively (both must then be overridden).
    can_detect_notnull_differ = False
    can_detect_unsigned_differ = False
    # Appended to the reverse-mapped db type of columns found in self.unsigned.
    unsigned_suffix = None

    def __init__(self, app_models, options):
        """Collect the model and database state needed to compute the diff.

        :param app_models: models to compare against the database; stored as
            ``self.app_models``.
        :param options: dict-like command options; ``dense_output`` (default
            ``False``) and ``only_existing`` (default ``True``) are read here.

        Side effects: opens a cursor on the default connection (kept as
        ``self.cursor``), calls the backend introspection to list tables, and
        calls ``load_null`` / ``load_unsigned`` when the corresponding
        ``can_detect_*`` class flag is set.
        """
        # presumably updated once differences are found (not in this chunk)
        self.has_differences = None
        self.app_models = app_models
        self.options = options
        self.dense = options.get('dense_output', False)

        try:
            self.introspection = connection.introspection
        except AttributeError:
            # fallback for Django versions whose connection object has no
            # ``introspection`` attribute
            from django.db import get_introspection_module
            self.introspection = get_introspection_module()

        self.cursor = connection.cursor()
        self.django_tables = self.get_django_tables(options.get('only_existing', True))
        self.db_tables = self.introspection.get_table_list(self.cursor)
        if django.VERSION[:2] >= (1, 8):
            # TODO: We are losing information about tables which are views here
            # From Django 1.8 the table list holds objects with a ``.name``;
            # reduce it to plain table names.
            self.db_tables = [table_info.name for table_info in self.db_tables]
        # list of (app_label, model_name, [(diff_type, args), ...]) groups
        self.differences = []
        # keys of unknown db types already reported (see get_field_db_type)
        self.unknown_db_fields = {}
        self.new_db_fields = set()
        # (tablespace, table_name, column) -> nullability, filled by load_null
        self.null = {}
        # (tablespace, table_name, column) of unsigned columns, filled by load_unsigned
        self.unsigned = set()

        # Difference type -> SQL renderer.  Note that 'fkey-missing-in-model'
        # deliberately reuses the plain DROP COLUMN renderer.
        self.DIFF_SQL = {
            'error': self.SQL_ERROR,
            'comment': self.SQL_COMMENT,
            'table-missing-in-db': self.SQL_TABLE_MISSING_IN_DB,
            'table-missing-in-model': self.SQL_TABLE_MISSING_IN_MODEL,
            'field-missing-in-db': self.SQL_FIELD_MISSING_IN_DB,
            'field-missing-in-model': self.SQL_FIELD_MISSING_IN_MODEL,
            'fkey-missing-in-db': self.SQL_FKEY_MISSING_IN_DB,
            'fkey-missing-in-model': self.SQL_FIELD_MISSING_IN_MODEL,
            'index-missing-in-db': self.SQL_INDEX_MISSING_IN_DB,
            'index-missing-in-model': self.SQL_INDEX_MISSING_IN_MODEL,
            'unique-missing-in-db': self.SQL_UNIQUE_MISSING_IN_DB,
            'unique-missing-in-model': self.SQL_UNIQUE_MISSING_IN_MODEL,
            'field-type-differ': self.SQL_FIELD_TYPE_DIFFER,
            'field-parameter-differ': self.SQL_FIELD_PARAMETER_DIFFER,
            'notnull-differ': self.SQL_NOTNULL_DIFFER,
        }

        # Backend-specific loaders run last, once all state above exists.
        if self.can_detect_notnull_differ:
            self.load_null()

        if self.can_detect_unsigned_differ:
            self.load_unsigned()

    def load_null(self):
        raise NotImplementedError("load_null functions must be implemented if diff backend has 'can_detect_notnull_differ' set to True")

    def load_unsigned(self):
        raise NotImplementedError("load_unsigned function must be implemented if diff backend has 'can_detect_unsigned_differ' set to True")

    def add_app_model_marker(self, app_label, model_name):
        self.differences.append((app_label, model_name, []))

    def add_difference(self, diff_type, *args):
        assert diff_type in self.DIFF_TYPES, 'Unknown difference type'
        self.differences[-1][-1].append((diff_type, args))

    def get_django_tables(self, only_existing):
        try:
            django_tables = self.introspection.django_table_names(only_existing=only_existing)
        except AttributeError:
            # backwards compatibility for before introspection refactoring (r8296)
            try:
                django_tables = _sql.django_table_names(only_existing=only_existing)
            except AttributeError:
                # backwards compatibility for before svn r7568
                django_tables = _sql.django_table_list(only_existing=only_existing)
        return django_tables

    def sql_to_dict(self, query, param):
        """ sql_to_dict(query, param) -> list of dicts

        Execute ``query`` with ``param`` on a new cursor of the default
        connection and return one dict per result row, keyed by column name.
        The cursor is always closed afterwards.

        code from snippet at http://www.djangosnippets.org/snippets/1383/
        """
        cursor = connection.cursor()
        try:
            cursor.execute(query, param)
            # first element of each DB-API description entry is the column name
            fieldnames = [column[0] for column in cursor.description]
            return [dict(zip(fieldnames, row)) for row in cursor.fetchall()]
        finally:
            # Explicit close (instead of ``with``) so this does not rely on the
            # cursor supporting the context-manager protocol.
            cursor.close()

    def get_field_model_type(self, field):
        """Return the column type the model ``field`` declares for the current connection."""
        model_db_type = field.db_type(connection=connection)
        return model_db_type

    def get_field_db_type(self, description, field=None, table_name=None):
        """Translate an introspected column back into a model-style db type.

        The DB-API type code is mapped to a Django field class (via
        ``DATA_TYPES_REVERSE_OVERRIDE``, the backend's ``data_types_reverse``
        or ``get_field_db_type_lookup``), that field is instantiated with
        kwargs derived from ``description``, and its ``db_type()`` on the
        current connection is returned.

        :param description: one ``cursor.description`` entry,
            ``(name, type_code, display_size, internal_size, precision,
            scale, null_ok)``.
        :param field: the model field being compared, if known; used for GIS
            hints and for the unsigned-column lookup.
        :param table_name: table the column belongs to (unsigned lookup key).
        :returns: the db type string, or ``None`` when the type code is
            unknown (a ``comment`` difference is then recorded once per key).
        """
        from django.db import models
        # DB-API cursor.description
        # (name, type_code, display_size, internal_size, precision, scale, null_ok) = description
        type_code = description[1]
        if type_code in self.DATA_TYPES_REVERSE_OVERRIDE:
            reverse_type = self.DATA_TYPES_REVERSE_OVERRIDE[type_code]
        else:
            try:
                try:
                    reverse_type = self.introspection.data_types_reverse[type_code]
                except AttributeError:
                    # backwards compatibility for before introspection refactoring (r8296)
                    reverse_type = self.introspection.DATA_TYPES_REVERSE.get(type_code)
            except KeyError:
                reverse_type = self.get_field_db_type_lookup(type_code)
                if not reverse_type:
                    # type_code not found in data_types_reverse map
                    key = (self.differences[-1][:2], description[:2])
                    if key not in self.unknown_db_fields:
                        self.unknown_db_fields[key] = 1
                        self.add_difference('comment', "Unknown database type for field '%s' (%s)" % (description[0], type_code))
                    return None

        kwargs = {}
        # Type code 16946 is presumably a PostGIS geometry type; it is only
        # mapped to PointField when the model field says it is a point.
        if type_code == 16946 and field and getattr(field, 'geom_type', None) == 'POINT':
            reverse_type = 'django.contrib.gis.db.models.fields.PointField'

        # Reverse-map entries may be (field_class_name, extra_kwargs) pairs.
        if isinstance(reverse_type, tuple):
            kwargs.update(reverse_type[1])
            reverse_type = reverse_type[0]

        if reverse_type == "CharField" and description[3]:
            kwargs['max_length'] = description[3]

        if reverse_type == "DecimalField":
            kwargs['max_digits'] = description[4]
            # use the magnitude of the reported scale (presumably guards
            # against backends reporting it negative)
            scale = description[5]
            kwargs['decimal_places'] = abs(scale) if scale else scale

        if description[6]:
            kwargs['blank'] = True
            if reverse_type not in ('TextField', 'CharField'):
                kwargs['null'] = True

        if field and getattr(field, 'geography', False):
            kwargs['geography'] = True

        if '.' in reverse_type:
            # dotted path: import the field class from its module
            from django_extensions.compat import importlib
            module_path, package_name = reverse_type.rsplit('.', 1)
            module = importlib.import_module(module_path)
            field_db_type = getattr(module, package_name)(**kwargs).db_type(connection=connection)
        else:
            field_db_type = getattr(models, reverse_type)(**kwargs).db_type(connection=connection)

        # The unsigned lookup needs the model field.  ``field`` is optional,
        # so skip it when absent instead of failing with AttributeError.
        if field is not None:
            tablespace = field.db_tablespace or "public"
            if (tablespace, table_name, field.column) in self.unsigned:
                field_db_type = '%s %s' % (field_db_type, self.unsigned_suffix)

        return field_db_type

    def get_field_db_type_lookup(self, type_code):
        return None

    def get_field_db_nullable(self, field, table_name):
        tablespace = field.db_tablespace
        if tablespace == "":
            tablespace = "public"
        attname = field.db_column or field.attname
        return self.null.get((tablespace, table_name, attname), 'fixme')

    def strip_parameters(self, field_type):
        if field_type and field_type != 'double precision':
            return field_type.split(" ")[0].split("(")[0].lower()
        return field_type

    def find_unique_missing_in_db(self, meta, table_indexes, table_constraints, table_name):
        """Report unique model fields whose column is not unique in the database.

        Only managed models are considered.  A column counts as unique when
        its ``table_indexes`` entry says so or, failing that, when a unique
        constraint in ``table_constraints`` covers exactly that one column.
        The column must also be present in ``table_indexes`` to be accepted;
        otherwise a ``unique-missing-in-db`` difference is recorded.
        """
        for field in all_local_fields(meta):
            if not (field.unique and meta.managed):
                continue
            column = field.db_column or field.attname
            is_unique = table_indexes.get(column, {}).get('unique')
            if not is_unique and table_constraints:
                # fall back to single-column constraints
                is_unique = any(
                    constraint['unique']
                    for constraint in six.itervalues(table_constraints)
                    if [column] == constraint['columns']
                )
            already_unique = column in table_indexes and is_unique
            if not already_unique:
                self.add_difference('unique-missing-in-db', table_name, column)

    def find_unique_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
        """Report columns that are unique in the database but not in the model.

        :param meta: the model's ``_meta``.
        :param table_indexes: introspected index info keyed by column name;
            each value has at least a ``'unique'`` flag.
        :param table_constraints: constraint info keyed by constraint name,
            each value having ``'unique'`` and ``'columns'``; may be falsy.
        :param table_name: the table being compared.

        Columns covered by ``meta.unique_together`` are not reported.
        """
        # TODO: Postgresql does not list unique_togethers in table_indexes
        #       MySQL does
        # Key the model side by *column* name, as the introspected data is, so
        # that foreign keys (column ``<name>_id``) and fields with a custom
        # ``db_column`` are matched (previously keyed by ``field.name``).
        fields = {}
        name_to_column = {}
        for field in all_local_fields(meta):
            column = field.db_column or field.attname
            fields[column] = field.unique
            name_to_column[field.name] = column
            name_to_column[field.attname] = column
        # unique_together lists field names; translate them to columns.
        # Loop-invariant, so computed once rather than per index.
        unique_together_columns = set(
            name_to_column.get(name, name) for name in flatten(meta.unique_together)
        )
        for att_name, att_opts in six.iteritems(table_indexes):
            db_field_unique = att_opts['unique']
            if not db_field_unique and table_constraints:
                db_field_unique = any(
                    constraint['unique']
                    for constraint in six.itervalues(table_constraints)
                    if att_name in constraint['columns']
                )
            if db_field_unique and att_name in fields and not fields[att_name]:
                if att_name in unique_together_columns:
                    continue
                self.add_difference('unique-missing-in-model', table_name, att_name)

    def find_index_missing_in_db(self, meta, table_indexes, table_constraints, table_name):
        """Report ``db_index=True`` model fields whose column has no database index.

        For each such column absent from ``table_indexes`` an
        ``index-missing-in-db`` difference ``(table, column, '', '')`` is
        recorded.  varchar/text columns additionally get a ``'like'`` entry
        carrying a ``varchar_pattern_ops`` / ``text_pattern_ops`` operator
        class.  ``table_constraints`` is not consulted.
        """
        for field in all_local_fields(meta):
            if field.db_index:
                attname = field.db_column or field.attname
                if attname not in table_indexes:
                    self.add_difference('index-missing-in-db', table_name, attname, '', '')
                    db_type = field.db_type(connection=connection)
                    # NOTE(review): *_pattern_ops are PostgreSQL operator classes
                    # for LIKE lookups; other backends presumably override this.
                    if db_type.startswith('varchar'):
                        self.add_difference('index-missing-in-db', table_name, attname, 'like', ' varchar_pattern_ops')
                    if db_type.startswith('text'):
                        self.add_difference('index-missing-in-db', table_name, attname, 'like', ' text_pattern_ops')
Loading ...