# NOTE(review): the lines that originally preceded the license header were
# package-index page residue (Gemfury listing header for
# sarus-flask-auth / sarus_flask_auth / dialects / DATABRICKSDialect.py)
# captured during extraction; they are preserved here as a comment so the
# file parses as valid Python.
""" Copyright (C) Sarus Technologies SAS - All Rights Reserved
Unauthorized copying of this file, via any medium is strictly prohibited
Proprietary and confidential
Write to contact@sarus.tech for more information about purchasing a licence
"""
import datetime
import decimal
import re

from databricks.sqlalchemy import DatabricksDialect
from sqlalchemy import sql, text, types
from sqlalchemy.engine import processors, reflection


class DatabricksDecimal(types.TypeDecorator):  # pragma: no cover
    """Coerce DBAPI result values to :class:`decimal.Decimal`.

    Adapted from the databricks-sql-python library; ``as_generic`` is
    re-assigned here to fix generic-type detection during schema
    reflection.
    """

    impl = types.DECIMAL

    def process_result_value(self, value, dialect):
        # None passes through untouched, already-converted Decimals are
        # kept as-is, and anything else (typically a string) is parsed.
        if value is None:
            return None
        if isinstance(value, decimal.Decimal):
            return value
        return decimal.Decimal(value)

    def as_generic(self):
        """Fixes the error in schema detection."""
        return types.DECIMAL()


class DatabricksTimestamp(types.TypeDecorator):  # pragma: no cover
    """Coerce DBAPI result values to :class:`datetime.datetime`.

    Adapted from the databricks-sql-python library; ``as_generic`` is
    re-assigned here to fix generic-type detection during schema
    reflection.
    """

    impl = types.TIMESTAMP

    def process_result_value(self, value, dialect):
        # None passes through; datetimes are kept as-is; strings are
        # parsed; any other type is rejected.
        if value is None:
            return None
        if isinstance(value, datetime.datetime):
            return value
        if isinstance(value, str):
            return processors.str_to_datetime(value)
        raise ValueError("Wrong datetime format")

    def adapt(self, impltype, **kwargs):
        # Always adapt to the concrete TIMESTAMP implementation.
        return self.impl

    def as_generic(self):
        """Fixes the error in schema detection."""
        return types.DateTime()


class DatabricksDate(types.TypeDecorator):  # pragma: no cover
    """Translates date strings to date objects.

    Class is taken from databricks-sql-python library; ``as_generic``
    is re-assigned here to fix generic-type detection during schema
    reflection.
    """

    impl = types.DATE

    def process_result_value(self, value, dialect):
        """Convert a DBAPI result value to a date object.

        ``None`` passes through, strings are parsed, and existing
        date/datetime objects are returned unchanged; any other type
        raises ``ValueError``.
        """
        if value is None:
            return None
        elif isinstance(value, str):
            return processors.str_to_date(value)
        elif isinstance(value, (datetime.datetime, datetime.date)):
            return value
        else:
            raise ValueError("Wrong date format")
        # NOTE: an unreachable duplicate `return processors.str_to_date(value)`
        # that followed the raise has been removed (dead code).

    def adapt(self, impltype, **kwargs):
        # Always adapt to the concrete DATE implementation.
        return self.impl

    def as_generic(self):
        """Fixes the error in schema detection."""
        return types.Date()


class DATABRICKSDialect(DatabricksDialect):  # pragma: no cover
    """Databricks dialect with custom reflection support.

    Extends the upstream ``DatabricksDialect`` with a column-type map
    that routes decimal/timestamp/date values through the
    TypeDecorators defined in this module, and implements primary- and
    foreign-key reflection via ``information_schema`` (only available
    when the Unity Catalog is enabled on the workspace).
    """

    def get_columns(self, connection, table_name, schema=None, **kwargs):
        """Return column reflection info for `table_name`.

        This function is taken from databricks-sql-python library,
        provided by Databricks:
        https://github.com/databricks/databricks-sql-python

        Each returned dict carries the keys ``name``, ``type``,
        ``nullable``, ``default`` and ``autoincrement``.

        Raises:
            KeyError: if a column reports a type name that is missing
                from the internal type map.
        """
        # Databricks type keyword -> SQLAlchemy type.  Complex types
        # (binary/array/map/struct/uniontype) are surfaced as strings;
        # decimal/timestamp/date use the TypeDecorators above so result
        # values are coerced into proper Python objects.
        _type_map = {
            "boolean": types.Boolean,
            "smallint": types.SmallInteger,
            "int": types.Integer,
            "bigint": types.BigInteger,
            "float": types.Float,
            "double": types.Float,
            "string": types.String,
            "varchar": types.String,
            "char": types.String,
            "binary": types.String,
            "array": types.String,
            "map": types.String,
            "struct": types.String,
            "uniontype": types.String,
            "decimal": DatabricksDecimal,
            "timestamp": DatabricksTimestamp,
            "date": DatabricksDate,
        }

        # Reach through SQLAlchemy's wrappers to the raw DBAPI
        # connection in order to use the driver's metadata cursor.
        with self.get_driver_connection(
            connection
        )._dbapi_connection.dbapi_connection.cursor() as cur:
            resp = cur.columns(
                catalog_name=self.catalog,
                schema_name=schema or self.schema,
                table_name=table_name,
            ).fetchall()

        columns = []

        for col in resp:
            # Taken from PyHive.  This removes added type info from
            # decimals and maps, e.g. "decimal(10,2)" -> "decimal".
            _col_type = re.search(r"^\w+", col.TYPE_NAME).group(0)
            this_column = {
                "name": col.COLUMN_NAME,
                "type": _type_map[_col_type.lower()],
                "nullable": bool(col.NULLABLE),
                "default": col.COLUMN_DEF,
                # The driver reports "YES"/"NO"; anything but "NO"
                # counts as auto-increment (same truth table as the
                # previous `False if ... == "NO" else True` ternary).
                "autoincrement": col.IS_AUTO_INCREMENT != "NO",
            }
            columns.append(this_column)

        return columns

    @reflection.cache
    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """Return information about foreign_keys in `table_name`.
        Given a :class:`_engine.Connection`, a string
        `table_name`, and an optional string `schema`, return foreign
        key information as a list of dicts with these keys:
        name
          the constraint's name
        constrained_columns
          a list of column names that make up the foreign key
        referred_schema
          the name of the referred schema
        referred_table
          the name of the referred table
        referred_columns
          a list of column names in the referred table that correspond to
          constrained_columns
        """
        if not self.has_table(
            connection=connection,
            table_name="information_schema.table_constraints",
        ):
            # Union catalog is not enabled and no constraints are defined
            return []

        # There is some information about constraints.
        if not schema:
            schema = connection.dialect.default_schema_name

        query = text(
            """\
SELECT
    tc.constraint_name AS constraint_name,
    kcu.column_name AS constrained_column,
    ccu.table_schema AS referred_schema,
    ccu.table_name AS referred_table,
    ccu.column_name AS referred_column
FROM
    information_schema.table_constraints AS tc
    JOIN information_schema.key_column_usage AS kcu
      ON tc.constraint_name = kcu.constraint_name
    JOIN information_schema.constraint_column_usage AS ccu
      ON ccu.constraint_name = tc.constraint_name
WHERE
    constraint_type = 'FOREIGN KEY'
    AND tc.table_name=:tablename
    AND tc.table_schema=:schema
                """
        ).bindparams(
            sql.bindparam("tablename", table_name),
            sql.bindparam("schema", schema),
        )

        results = connection.execute(query)

        # Group rows by constraint name so a composite foreign key is
        # returned as ONE constraint with full column lists, instead of
        # several same-named single-column constraints (the previous
        # behavior).  Duplicate column pairs produced by the joins are
        # skipped.
        constraints = {}
        seen = set()
        for (
            constraint_name,
            constrained_column,
            referred_schema,
            referred_table,
            referred_column,
        ) in results:
            pair = (constraint_name, constrained_column, referred_column)
            if pair in seen:
                continue
            seen.add(pair)
            entry = constraints.setdefault(
                constraint_name,
                {
                    "name": constraint_name,
                    "constrained_columns": [],
                    "referred_schema": referred_schema,
                    "referred_table": referred_table,
                    "referred_columns": [],
                },
            )
            entry["constrained_columns"].append(constrained_column)
            entry["referred_columns"].append(referred_column)

        return list(constraints.values())

    @reflection.cache
    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """Return information about the primary key constraint on
        table_name`.
        Given a :class:`_engine.Connection`, a string
        `table_name`, and an optional string `schema`, return primary
        key information as a dictionary with these keys:
        constrained_columns
          a list of column names that make up the primary key
        name
          optional name of the primary key constraint.
        """
        if not self.has_table(
            connection=connection,
            table_name="information_schema.table_constraints",
        ):
            # Union catalog is not enabled and no constraints are defined
            return {
                "name": None,
                "constrained_columns": [],
            }

        # There is some information about constraints.
        if not schema:
            schema = connection.dialect.default_schema_name
        query = text(
            """\
SELECT
    tc.constraint_name AS constraint_name,
    kcu.column_name AS constrained_column
FROM
    information_schema.table_constraints AS tc
    JOIN information_schema.key_column_usage AS kcu
      ON tc.constraint_name = kcu.constraint_name
WHERE
    constraint_type = 'PRIMARY KEY'
    AND tc.table_name=:tablename
    AND tc.table_schema=:schema
                """
        ).bindparams(
            sql.bindparam("tablename", table_name),
            sql.bindparam("schema", schema),
        )
        results = connection.execute(query).fetchall()
        if not results:
            return {
                "name": None,
                "constrained_columns": [],
            }

        # Only one primary key constraint is accepted, but it may span
        # several columns: collect every column belonging to the first
        # constraint so composite primary keys are fully reported (the
        # previous code returned only the first column).
        constraint_name = results[0][0]
        constrained_columns = [
            column for name, column in results if name == constraint_name
        ]
        return {
            "constrained_columns": constrained_columns,
            "name": constraint_name,
        }