# Repository URL to install this package: (not specified)
# Version: 0.6.6.dev0
""" Copyright (C) Sarus Technologies SAS - All Rights Reserved
Unauthorized copying of this file, via any medium is strictly prohibited
Proprietary and confidential
Write to contact@sarus.tech for more information about purchasing a licence
"""
import datetime
import decimal
import re
from databricks.sqlalchemy import DatabricksDialect
from sqlalchemy import sql, text, types
from sqlalchemy.engine import processors, reflection
class DatabricksDecimal(types.TypeDecorator): # pragma: no cover
    """Coerce string result values to :class:`decimal.Decimal`.

    Adapted from the databricks-sql-python library; ``as_generic`` is
    re-assigned here to fix schema detection.
    """

    impl = types.DECIMAL

    def process_result_value(self, value, dialect):
        # NULLs pass through untouched.
        if value is None:
            return None
        # Already-parsed decimals are returned as-is; anything else
        # (typically a string from the driver) is parsed.
        if isinstance(value, decimal.Decimal):
            return value
        return decimal.Decimal(value)

    def as_generic(self):
        """Fixes the error in schema detection."""
        return types.DECIMAL()
class DatabricksTimestamp(types.TypeDecorator): # pragma: no cover
    """Coerce timestamp result values to :class:`datetime.datetime`.

    Adapted from the databricks-sql-python library; ``as_generic`` is
    re-assigned here to fix schema detection.
    """

    impl = types.TIMESTAMP

    def process_result_value(self, value, dialect):
        # NULLs pass through; datetimes are already parsed; strings are
        # parsed by SQLAlchemy's standard converter. Anything else is an
        # unexpected driver value.
        if value is None:
            return None
        if isinstance(value, datetime.datetime):
            return value
        if isinstance(value, str):
            return processors.str_to_datetime(value)
        raise ValueError("Wrong datetime format")

    def adapt(self, impltype, **kwargs):
        # Always adapt to the declared implementation type.
        return self.impl

    def as_generic(self):
        """Fixes the error in schema detection."""
        return types.DateTime()
class DatabricksDate(types.TypeDecorator): # pragma: no cover
    """Translates date strings to date objects.

    Class is taken from databricks-sql-python library,
    as_generic function is re-assigned here.
    """

    impl = types.DATE

    def process_result_value(self, value, dialect):
        """Convert a raw result ``value`` to a date.

        Returns ``None`` for NULLs, parses strings with
        ``processors.str_to_date``, and passes ``date``/``datetime``
        instances through unchanged.

        Raises:
            ValueError: if ``value`` is of any other type.
        """
        # BUGFIX: the original had an unreachable duplicate
        # `return processors.str_to_date(value)` after this if/elif/else
        # (every branch already returns or raises); the dead line is removed.
        if value is None:
            return None
        elif isinstance(value, str):
            return processors.str_to_date(value)
        elif isinstance(value, (datetime.datetime, datetime.date)):
            return value
        else:
            raise ValueError("Wrong date format")

    def adapt(self, impltype, **kwargs):
        # Always adapt to the declared implementation type.
        return self.impl

    def as_generic(self):
        """Fixes the error in schema detection."""
        return types.Date()
class DATABRICKSDialect(DatabricksDialect): # pragma: no cover
    """Databricks dialect extended with column and constraint reflection.

    NOTE(review): constraint reflection reads the ``information_schema``
    tables; both constraint methods fall back to "no constraints" when
    ``information_schema.table_constraints`` is absent — presumably when
    Unity Catalog is not enabled (see the inline comments below).
    """

    def get_columns(self, connection, table_name, schema=None, **kwargs):
        """This function is taken from databricks-sql-python library,
        provided by Databricks:
        https://github.com/databricks/databricks-sql-python
        """
        # Databricks type name (lower-cased, parameters stripped) to
        # SQLAlchemy type. Complex types (array/map/struct/uniontype) and
        # binary are surfaced as String; decimal/timestamp/date use the
        # custom decorators defined above.
        _type_map = {
            "boolean": types.Boolean,
            "smallint": types.SmallInteger,
            "int": types.Integer,
            "bigint": types.BigInteger,
            "float": types.Float,
            "double": types.Float,
            "string": types.String,
            "varchar": types.String,
            "char": types.String,
            "binary": types.String,
            "array": types.String,
            "map": types.String,
            "struct": types.String,
            "uniontype": types.String,
            "decimal": DatabricksDecimal,
            "timestamp": DatabricksTimestamp,
            "date": DatabricksDate,
        }
        # Unwrap SQLAlchemy's connection layers down to the raw DBAPI
        # connection so we can call the driver's `cursor.columns` metadata
        # API directly.
        with self.get_driver_connection(
            connection
        )._dbapi_connection.dbapi_connection.cursor() as cur:
            resp = cur.columns(
                catalog_name=self.catalog,
                schema_name=schema or self.schema,
                table_name=table_name,
            ).fetchall()
        columns = []
        for col in resp:
            # Taken from PyHive. This removes added type info from decimals
            # and maps (e.g. "decimal(10,2)" -> "decimal").
            _col_type = re.search(r"^\w+", col.TYPE_NAME).group(0)
            this_column = {
                "name": col.COLUMN_NAME,
                "type": _type_map[_col_type.lower()],
                "nullable": bool(col.NULLABLE),
                # Driver reports auto-increment as the string "NO"/"YES".
                "default": col.COLUMN_DEF,
                "autoincrement": False
                if col.IS_AUTO_INCREMENT == "NO"
                else True,
            }
            columns.append(this_column)
        return columns

    @reflection.cache
    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """Return information about foreign_keys in `table_name`.

        Given a :class:`_engine.Connection`, a string
        `table_name`, and an optional string `schema`, return foreign
        key information as a list of dicts with these keys:

        name
          the constraint's name
        constrained_columns
          a list of column names that make up the foreign key
        referred_schema
          the name of the referred schema
        referred_table
          the name of the referred table
        referred_columns
          a list of column names in the referred table that correspond to
          constrained_columns
        """
        if not self.has_table(
            connection=connection,
            table_name="information_schema.table_constraints",
        ):
            # Unity Catalog is not enabled and no constraints are defined
            return []
        else:
            # there is some information about constraints
            if not schema:
                schema = connection.dialect.default_schema_name
            # Join the three information_schema views to resolve, for each
            # FOREIGN KEY constraint on this table, the local column and the
            # referred schema/table/column.
            query = text(
                """\
                SELECT
                    tc.constraint_name AS constraint_name,
                    kcu.column_name AS constrained_column,
                    ccu.table_schema AS referred_schema,
                    ccu.table_name AS referred_table,
                    ccu.column_name AS referred_column
                FROM
                    information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                        ON tc.constraint_name = kcu.constraint_name
                    JOIN information_schema.constraint_column_usage AS ccu
                        ON ccu.constraint_name = tc.constraint_name
                WHERE
                    constraint_type = 'FOREIGN KEY'
                    AND tc.table_name=:tablename
                    AND tc.table_schema=:schema
                """
            ).bindparams(
                sql.bindparam("tablename", table_name),
                sql.bindparam("schema", schema),
            )
            results = connection.execute(query)
            constraint_list = list()
            # One result row per (constraint, column) pair; each becomes its
            # own single-column entry in the reflected list.
            for (
                constraint_name,
                constrained_column,
                referred_schema,
                referred_table,
                referred_column,
            ) in results:
                current_constraint = {
                    "name": constraint_name,
                    "constrained_columns": [constrained_column],
                    "referred_schema": referred_schema,
                    "referred_table": referred_table,
                    "referred_columns": [referred_column],
                }
                constraint_list.append(current_constraint)
            return constraint_list if constraint_list else []

    @reflection.cache
    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """Return information about the primary key constraint on
        table_name`.

        Given a :class:`_engine.Connection`, a string
        `table_name`, and an optional string `schema`, return primary
        key information as a dictionary with these keys:

        constrained_columns
          a list of column names that make up the primary key
        name
          optional name of the primary key constraint.
        """
        if not self.has_table(
            connection=connection,
            table_name="information_schema.table_constraints",
        ):
            # Unity Catalog is not enabled and no constraints are defined
            return {
                "name": None,
                "constrained_columns": [],
            }
        else:
            # there is some information about constraints
            if not schema:
                schema = connection.dialect.default_schema_name
            # Resolve the PRIMARY KEY constraint name and its column for
            # this table from the information_schema views.
            query = text(
                """\
                SELECT
                    tc.constraint_name AS constraint_name,
                    kcu.column_name AS constrained_column
                FROM
                    information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                        ON tc.constraint_name = kcu.constraint_name
                WHERE
                    constraint_type = 'PRIMARY KEY'
                    AND tc.table_name=:tablename
                    AND tc.table_schema=:schema
                """
            ).bindparams(
                sql.bindparam("tablename", table_name),
                sql.bindparam("schema", schema),
            )
            results = connection.execute(query).fetchall()
            if results:
                # only one primary key is accepted
                constraint_name, constrained_column = results[0]
                return {
                    "constrained_columns": [constrained_column],
                    "name": constraint_name,
                }
            else:
                return {
                    "name": None,
                    "constrained_columns": [],
                }