Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
sarus-flask-auth / sarus_flask_auth / dialects / HIVEJDBCDialect.py
Size: Mime:
""" Copyright (C) Sarus Technologies SAS - All Rights Reserved
Unauthorized copying of this file, via any medium is strictly prohibited
Proprietary and confidential
Write to contact@sarus.tech for more information about purchasing a licence
"""
import datetime
import decimal
import re

from pyhive.sqlalchemy_hive import HiveDialect
from sqlalchemy import exc, inspect, text, types, util
from sqlalchemy_jdbcapi.base import BaseDialect

from sarus_flask_auth.dialects import JARS_LIST

# Maps the base Hive type name reported by ``DESCRIBE`` (the leading word,
# e.g. 'decimal' for 'decimal(10,1)') to the SQLAlchemy type used when
# reflecting columns. Complex/unstructured Hive types are exposed as String.
_type_map = {
    "boolean": types.Boolean,
    "tinyint": types.SmallInteger,  # TODO: maybe we need tiny data type here?
    "smallint": types.SmallInteger,
    "int": types.Integer,
    "bigint": types.BigInteger,
    "float": types.Float,
    "double": types.Float,
    "string": types.String,
    "varchar": types.String,
    "char": types.String,
    # Hive DATE columns; previously missing, so they were reflected as
    # NullType with a "Did not recognize type" warning.
    "date": types.Date,
    "timestamp": types.DateTime,
    "binary": types.String,
    "array": types.String,
    "map": types.String,
    "struct": types.String,
    "uniontype": types.String,
    "decimal": types.DECIMAL,
}


class HIVEJDBCDialect(BaseDialect, HiveDialect):  # pragma: no cover
    """
    SQLAlchemy dialect for Hive reached through the Hive JDBC driver.

    Combines ``sqlalchemy_jdbcapi``'s ``BaseDialect`` with pyhive's
    ``HiveDialect``; the methods below override the parts that need
    JDBC-specific handling (connection arguments, reflection, result
    type conversion, transactions).
    """

    name = "HIVEJDBCDialect"
    driver = "org.apache.hive.jdbc.HiveDriver"
    jdbc_db_name = "hive2"
    jdbc_driver_name = "org.apache.hive.jdbc.HiveDriver"
    supports_statement_cache = True

    def initialize(self, connection):
        """
        Initialize the dialect for a first connection.

        ``super(HiveDialect, self)`` starts the MRO lookup *after*
        ``HiveDialect``, so any ``initialize`` defined on ``BaseDialect`` or
        ``HiveDialect`` is skipped. NOTE(review): presumably deliberate;
        confirm before changing.
        """
        super(HiveDialect, self).initialize(connection)

    def create_connect_args(self, url):
        """
        Build the jaydebeapi connection arguments from a SQLAlchemy URL.

        ``<scheme>://user:password@host:port/db?key=value`` becomes the JDBC
        URL ``jdbc:hive2://host:port/db;key=value``. The credentials are
        passed separately (decoded) through ``driver_args``.

        :param url: SQLAlchemy ``URL`` object, or ``None``.
        :returns: ``None`` when ``url`` is ``None``; otherwise a
            ``((), kwargs)`` tuple with ``jclassname``, ``url``,
            ``driver_args`` and ``jars``.
        """
        if url is None:
            return

        # On SQLAlchemy 2.x ``str(url)`` masks the password as "***", so
        # render with the password visible when the API is available.
        render = getattr(url, "render_as_string", None)
        rendered = render(hide_password=False) if render else str(url)

        # Drop the "<scheme>://" prefix.
        address = rendered.split("//", 1)[-1]
        if url.username is not None:
            # SQLAlchemy percent-encodes ':', '@' and '/' inside the
            # rendered username/password, so the first '@' ends the
            # credentials section.
            address = address.split("@", 1)[-1]
        # NOTE(review): only the '?' separator is rewritten; several query
        # parameters stay '&'-separated. Confirm the driver expects this.
        address = address.replace("?", ";")

        # Use the decoded credentials from the URL object rather than the
        # percent-encoded text of the rendered string.
        driver_args = {}
        if url.username is not None:
            driver_args["user"] = url.username
        if url.password is not None:
            driver_args["password"] = str(url.password)

        kwargs = {
            "jclassname": self.jdbc_driver_name,
            "url": f"jdbc:{self.jdbc_db_name}://{address}",
            "driver_args": driver_args,
            "jars": JARS_LIST,
        }
        return ((), kwargs)

    def do_commit(self, connection):
        """
        Do nothing: Hive does not support transactions, so commits are
        suppressed.
        """
        pass

    def has_table(self, connection, table_name, schema=None, **kw):
        """
        Return whether ``table_name`` exists in ``schema``.

        The JDBC driver raises an error when a table does not exist, so
        existence is checked against the reflected table names instead.
        This is needed by ``create_all()`` and other SQLAlchemy helpers.
        """
        return table_name in inspect(connection).get_table_names(
            schema=schema
        )

    def _get_table_columns(self, connection, table_name, schema):
        """
        Return the raw ``DESCRIBE`` rows for a table.

        Taken from the pyhive package: https://github.com/dropbox/PyHive

        :raises NoSuchTableError: when the error message or result indicates
            that the table does not exist.
        """
        full_table = table_name
        if schema:
            full_table = schema + "." + table_name
        try:
            # This needs the table name to be unescaped (no backticks).
            rows = connection.execute(
                text("DESCRIBE {}".format(full_table))
            ).fetchall()
        except exc.OperationalError as e:
            # Check if table exists
            regex_fmt = (
                r"TExecuteStatementResp.*SemanticException.*Table not found {}"
            )
            regex = regex_fmt.format(re.escape(full_table))
            if re.search(regex, e.args[0]):
                raise exc.NoSuchTableError(full_table)
            else:
                raise
        else:
            # Some servers answer with a single "Table ... does not exist"
            # row instead of raising.
            regex = r"Table .* does not exist"
            if len(rows) == 1 and re.match(regex, rows[0].col_name):
                raise exc.NoSuchTableError(full_table)
            return rows

    def get_columns(self, connection, table_name, schema=None, **kw):
        """
        Reflect the columns of ``table_name`` from ``DESCRIBE`` output.

        Adapted from the pyhive package: https://github.com/dropbox/PyHive

        Every column is reported as nullable with no default. Types missing
        from ``_type_map`` (or with an unparseable type cell) are reported
        as ``NullType`` and a warning is emitted.
        """
        rows = self._get_table_columns(connection, table_name, schema)
        # Strip whitespace; empty cells become None
        rows = [[col.strip() if col else None for col in row] for row in rows]
        # Filter out empty rows and the "# col_name" header
        rows = [row for row in rows if row[0] and row[0] != "# col_name"]
        result = []
        for col_name, col_type, _comment in rows:
            # Stop at the partition section header
            if col_name == "# Partition Information":
                break
            # Take out the more detailed type information
            # e.g. 'map<int,int>' -> 'map'
            #      'decimal(10,1)' -> decimal
            # An empty/non-word type cell is kept as-is and falls through
            # to the "unrecognized type" warning instead of crashing.
            match = re.match(r"\w+", col_type) if col_type else None
            if match:
                col_type = match.group(0)
            try:
                coltype = _type_map[col_type]
            except KeyError:
                util.warn(
                    "Did not recognize type '%s' of column '%s'"
                    % (col_type, col_name)
                )
                coltype = types.NullType

            result.append(
                {
                    "name": col_name,
                    "type": coltype,
                    "nullable": True,
                    "default": None,
                }
            )
        return result

    def on_connect(self):
        """
        Install Hive-friendly result converters into the JDBC DBAPI.

        jaydebeapi converts JDBC column values to Python through
        ``_DEFAULT_CONVERTERS`` (see https://github.com/baztian/jaydebeapi).
        This updates that mapping in place on ``self.dbapi``.
        NOTE(review): the mapping lives on the DBAPI module, so the change
        is presumably process-wide; confirm.

        :returns: ``None`` (no per-connection hook).
        """

        # TODO: check if java binary type needs to be transformed
        def _to_binary(rs, col):
            """Return the raw Java object unchanged."""
            return rs.getObject(col)

        def _to_bool(rs, col):
            """Return boolean values as ``bool`` (not strings); NULL -> None."""
            java_val = rs.getObject(col)
            if java_val is None:
                return None
            return bool(java_val)

        def _to_date(rs, col):
            """Convert a Java date value to ``datetime.date``; NULL -> None."""
            java_val = rs.getDate(col)
            if not java_val:
                return None
            # The first 10 characters of the string form are YYYY-MM-DD.
            parsed = datetime.datetime.strptime(str(java_val)[:10], "%Y-%m-%d")
            return parsed.date()

        def _to_datetime(rs, col):
            """
            Convert a Java timestamp to ``datetime.datetime`` (not a string),
            keeping microsecond precision; NULL -> None.
            """
            java_val = rs.getTimestamp(col)
            if not java_val:
                return None
            whole_seconds = datetime.datetime.strptime(
                str(java_val)[:19], "%Y-%m-%d %H:%M:%S"
            )
            # getNanos() is the fractional second in nanoseconds; truncate
            # to microseconds. datetime.replace() returns a NEW object, so
            # its result must be returned (it was previously discarded).
            micros = int(java_val.getNanos()) // 1000
            return whole_seconds.replace(microsecond=micros)

        def _to_double(rs, col):
            """Convert a Java numeric value to Python ``float``; NULL -> None."""
            java_val = rs.getObject(col)
            if java_val is None:
                return None
            if hasattr(java_val, "scale"):
                # BigDecimal-like: doubleValue() is correct for every scale,
                # whereas longValue() wraps around for values >= 2**63.
                return float(java_val.doubleValue())
            return float(java_val)

        def _to_decimal(rs, col):
            """Convert a Java decimal value to ``decimal.Decimal``; NULL -> None."""
            java_val = rs.getObject(col)
            if java_val is None:
                return None
            if hasattr(java_val, "scale"):
                # BigDecimal-like: use its exact decimal string rather than
                # doubleValue() (binary-float noise) or longValue()
                # (overflow for large values).
                return decimal.Decimal(str(java_val.toString()))
            # NOTE(review): non-BigDecimal values are returned as float,
            # not Decimal; kept unchanged for compatibility.
            return float(java_val)

        def _to_int(rs, col):
            """Convert a Java integer value to Python ``int``; NULL -> None."""
            java_val = rs.getObject(col)
            if java_val is None:
                return None
            return int(java_val)

        # replace dbapi type converters with the local ones
        # see
        # http://download.oracle.com/javase/8/docs/api/java/sql/Types.html
        # for possible keys
        self.dbapi._DEFAULT_CONVERTERS.update(
            {
                "BIGINT": _to_int,
                "BINARY": _to_binary,
                "BIT": _to_bool,
                "BOOLEAN": _to_bool,
                "DATE": _to_date,
                "DECIMAL": _to_decimal,
                "DOUBLE": _to_double,
                "INTEGER": _to_int,
                "FLOAT": _to_double,
                "NUMERIC": _to_decimal,
                "SMALLINT": _to_int,
                "TINYINT": _to_int,
                "TIME": _to_datetime,
                "TIMESTAMP": _to_datetime,
            }
        )

        return