Repository URL to install this package:
|
Version:
0.6.6.dev0 ▾
|
""" Copyright (C) Sarus Technologies SAS - All Rights Reserved
Unauthorized copying of this file, via any medium is strictly prohibited
Proprietary and confidential
Write to contact@sarus.tech for more information about purchasing a licence
"""
import datetime
import decimal
import re
from pyhive.sqlalchemy_hive import HiveDialect
from sqlalchemy import exc, inspect, text, types, util
from sqlalchemy_jdbcapi.base import BaseDialect
from sarus_flask_auth.dialects import JARS_LIST
# Hive column type name (as reported by DESCRIBE, reduced to its leading word)
# -> SQLAlchemy type used when reflecting table columns in get_columns().
# Complex/parameterized Hive types (array, map, struct, uniontype) and binary
# are reflected as plain strings.
_type_map = {
    "boolean": types.Boolean,
    "tinyint": types.SmallInteger,  # TODO: maybe we need tiny data type here?
    "smallint": types.SmallInteger,
    "int": types.Integer,
    "bigint": types.BigInteger,
    "float": types.Float,
    "double": types.Float,
    "string": types.String,
    "varchar": types.String,
    "char": types.String,
    "date": types.Date,
    "timestamp": types.DateTime,
    "binary": types.String,
    "array": types.String,
    "map": types.String,
    "struct": types.String,
    "uniontype": types.String,
    "decimal": types.DECIMAL,
}
class HIVEJDBCDialect(BaseDialect, HiveDialect):  # pragma: no cover
    """
    SQLAlchemy dialect connecting to Hive through the Hive JDBC driver.

    Connection arguments are built for jaydebeapi (see create_connect_args
    and on_connect); Hive-specific behavior is inherited from pyhive's
    HiveDialect, with table reflection re-implemented here.
    """

    name = "HIVEJDBCDialect"
    driver = "org.apache.hive.jdbc.HiveDriver"
    jdbc_db_name = "hive2"
    jdbc_driver_name = "org.apache.hive.jdbc.HiveDriver"
    supports_statement_cache = True

    def initialize(self, connection):
        """
        Run the dialect initialization found after HiveDialect in the MRO.

        super(HiveDialect, self) skips the ``initialize`` of every class up
        to and including HiveDialect in this class's MRO.
        NOTE(review): presumably to avoid the HiveDialect/BaseDialect
        versions; confirm intent.
        """
        super(HiveDialect, self).initialize(connection)

    def create_connect_args(self, url):
        """
        Build the JDBC connection arguments from a sqlalchemy URL.

        A URL such as ``hive_jdbc://user:pass@host:10000/db?opt=val`` becomes
        the Hive JDBC URL ``jdbc:hive2://host:10000/db;opt=val``. The
        credentials are passed separately as ``driver_args``.

        :param url: sqlalchemy URL object, or None
        :return: ``((), kwargs)`` for the DBAPI connect call, or None when
            ``url`` is None
        """
        if url is None:
            return
        # str(url) masks the password on SQLAlchemy >= 2.0. Render it
        # explicitly when the API exists (SQLAlchemy >= 1.4).
        if hasattr(url, "render_as_string"):
            rendered = url.render_as_string(hide_password=False)
        else:
            rendered = str(url)
        # Drop the "<dialect>://" prefix.
        address = rendered.split("//", 1)[-1]
        if url.username is not None or url.password is not None:
            # The rendered credentials are percent-encoded, so the first "@"
            # ends them. Later "@"s (e.g. in a query option) are kept.
            address = address.split("@", 1)[-1]
        # The Hive JDBC driver separates options with ";" instead of "?".
        address = address.replace("?", ";")
        jdbc_url = f"jdbc:{self.jdbc_db_name}://{address}"
        # Use the URL attributes: they hold the decoded user name/password.
        driver_args = {}
        if url.username is not None:
            driver_args["user"] = url.username
        if url.password is not None:
            driver_args["password"] = url.password
        kwargs = {
            "jclassname": self.jdbc_driver_name,
            "url": jdbc_url,
            "driver_args": driver_args,
            "jars": JARS_LIST,
        }
        return ((), kwargs)

    def do_commit(self, connection):
        """
        Hive does not support transactions, so we suppress this function
        """
        pass

    def has_table(self, connection, table_name, schema=None, **kw):
        """
        Tell whether ``table_name`` exists in ``schema``.

        The JDBC driver raises an error when the table does not exist, so
        the check lists the schema's tables instead. This is needed for
        create_all() and other sqlalchemy functions.
        """
        return table_name in inspect(connection).get_table_names(schema=schema)

    def _get_table_columns(self, connection, table_name, schema):
        """
        Return the raw ``DESCRIBE`` rows for ``[schema.]table_name``.

        Adapted from the pyhive package: https://github.com/dropbox/PyHive

        :raises exc.NoSuchTableError: when the table is reported missing
        """
        full_table = table_name
        if schema:
            full_table = schema + "." + table_name
        try:
            # This needs the table name to be unescaped (no backticks).
            rows = connection.execute(
                text("DESCRIBE {}".format(full_table))
            ).fetchall()
        except exc.OperationalError as e:
            # Translate Hive's "table not found" error into sqlalchemy's.
            regex_fmt = (
                r"TExecuteStatementResp.*SemanticException.*Table not found {}"
            )
            regex = regex_fmt.format(re.escape(full_table))
            message = str(e.args[0]) if e.args else ""
            if re.search(regex, message):
                raise exc.NoSuchTableError(full_table) from e
            raise
        # A missing table may also come back as a single
        # "Table ... does not exist" row instead of an error.
        regex = r"Table .* does not exist"
        if len(rows) == 1 and re.match(regex, rows[0].col_name):
            raise exc.NoSuchTableError(full_table)
        return rows

    def get_columns(self, connection, table_name, schema=None, **kw):
        """
        Reflect the columns of ``[schema.]table_name``.

        Adapted from the pyhive package: https://github.com/dropbox/PyHive

        :return: one dict per column with ``name``, ``type``, ``nullable``
            (always True) and ``default`` (always None) keys; rows after the
            "# Partition Information" marker are ignored
        """
        rows = self._get_table_columns(connection, table_name, schema)
        # Strip whitespace; falsy cells become None.
        rows = [[col.strip() if col else None for col in row] for row in rows]
        # Filter out empty rows and the "# col_name" header row.
        rows = [row for row in rows if row[0] and row[0] != "# col_name"]
        result = []
        for col_name, col_type, _comment in rows:
            if col_name == "# Partition Information":
                break
            # Keep only the base type name, like pyhive does:
            # 'map<int,int>' -> 'map', 'decimal(10,1)' -> 'decimal'.
            # An empty type cell (None) falls through to the warning below.
            match = re.match(r"\w+", col_type or "")
            if match:
                col_type = match.group(0)
            coltype = _type_map.get(col_type)
            if coltype is None:
                util.warn(
                    "Did not recognize type '%s' of column '%s'"
                    % (col_type, col_name)
                )
                coltype = types.NullType
            result.append(
                {
                    "name": col_name,
                    "type": coltype,
                    "nullable": True,
                    "default": None,
                }
            )
        return result

    def on_connect(self):
        """
        Install Python converters for JDBC result values in jaydebeapi.

        See https://github.com/baztian/jaydebeapi

        NOTE(review): the converters are written into
        ``self.dbapi._DEFAULT_CONVERTERS``, which presumably is module-level
        state shared by every connection of that DBAPI module.

        :return: None (no per-connection hook)
        """

        # TODO: check if java binary type needs to be transformed
        def _to_binary(rs, col):
            java_val = rs.getObject(col)
            return java_val

        def _to_bool(rs, col):
            """
            Ensure to return boolean values as bool,
            not strings
            """
            java_val = rs.getObject(col)
            if java_val is None:
                return
            return bool(java_val)

        def _to_date(rs, col):
            """
            Transform java date type to python date
            """
            java_val = rs.getDate(col)
            if not java_val:
                return
            d = datetime.datetime.strptime(str(java_val)[:10], "%Y-%m-%d")
            return d.date()

        def _to_datetime(rs, col):
            """
            Return timestamps as datetime.datetime (not strings), keeping
            microsecond precision.
            """
            java_val = rs.getTimestamp(col)
            if not java_val:
                return
            value = datetime.datetime.strptime(
                str(java_val)[:19], "%Y-%m-%d %H:%M:%S"
            )
            # datetime is immutable, so the replace() result must be kept.
            # getNanos() is the fractional second in nanoseconds.
            return value.replace(microsecond=int(java_val.getNanos()) // 1000)

        def _to_double(rs, col):
            """
            Transform java double type to python float
            """
            java_val = rs.getObject(col)
            if java_val is None:
                return
            if hasattr(java_val, "scale"):
                # BigDecimal: doubleValue() works for every scale. Unlike
                # longValue(), it does not overflow for large integral values.
                return float(java_val.doubleValue())
            return float(java_val)

        def _to_decimal(rs, col):
            """
            Parse java decimal type into decimal.Decimal without losing digits
            """
            java_val = rs.getObject(col)
            if java_val is None:
                return
            if hasattr(java_val, "scale"):
                # Go through the textual form. longValue() overflows past
                # 64 bits, and doubleValue() rounds to binary floating point.
                return decimal.Decimal(str(java_val.toString()))
            return decimal.Decimal(str(java_val))

        def _to_int(rs, col):
            """
            Parse java integer type
            """
            java_val = rs.getObject(col)
            if java_val is None:
                return
            return int(java_val)

        # replace dbapi type converters with the local ones
        # see
        # http://download.oracle.com/javase/8/docs/api/java/sql/Types.html
        # for possible keys
        self.dbapi._DEFAULT_CONVERTERS.update(
            {
                "BIGINT": _to_int,
                "BINARY": _to_binary,
                "BIT": _to_bool,
                "BOOLEAN": _to_bool,
                "DATE": _to_date,
                "DECIMAL": _to_decimal,
                "DOUBLE": _to_double,
                "INTEGER": _to_int,
                "FLOAT": _to_double,
                "NUMERIC": _to_decimal,
                "SMALLINT": _to_int,
                "TINYINT": _to_int,
                "TIME": _to_datetime,
                "TIMESTAMP": _to_datetime,
            }
        )
        return