Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
import copy
import typing as t

from snsql._ast.ast import (
    AliasedSubquery,
    Column,
    From,
    Identifier,
    Literal,
    NamedExpression,
    Query,
    Relation,
    Select,
    Seq,
    Sql,
    SqlExpr,
    Table,
    Token,
    UnifiedQuery,
    UserDefinedFunction,
)
from snsql._ast.expressions.date import ExtractFunction
from snsql._ast.expressions.string import (
    DecodeFunction,
    EncodeFunction,
    RegexpSplitToTableFunction,
    SubstrBigQueryFunction,
    SubstringFunction,
    UnhexFunction,
)
from snsql._ast.expressions.types import CastFunction

from ..table_renamer import TableRenamer
from ..utils import is_quoted, split_quote
from .base import BaseTranslator

# Databricks aggregate functions with no common Postgres counterpart.
# Names are upper-case and kept in alphabetical order; the translator
# compares them against ``UserDefinedFunction.name.upper()``.
USER_DEFINED_AGGS = (
    "ANY_VALUE", "APPROX_COUNT_DISTINCT", "APPROX_PERCENTILE",
    "APPROX_TOP_K",
    "BITMAP_CONSTRUCT_AGG", "BITMAP_OR_AGG",
    "COLLECT_LIST", "COLLECT_SET", "COUNT_IF", "COUNT_MIN_SKETCH",
    "FIRST",
    "HLL_SKETCH_AGG", "HLL_UNION_AGG",
    "KURTOSIS",
    "LAST",
    "MAX_BY", "MEAN", "MEDIAN", "MIN_BY",
    "PERCENTILE", "PERCENTILE_APPROX",
    "SKEWNESS", "STD",
    "TRY_AVG", "TRY_SUM",
)


class DatabricksTranslator(BaseTranslator):
    """Translator for Databricks <-> Postgres.

    Databricks does not support the double quote '"' as an identifier
    quote in queries; it is replaced with the backquote '`'.

    All aliases are quoted to avoid conflicts with Databricks keywords.
    Moreover, aliases can't contain special characters such as dots or
    spaces, so those are normalized to underscores (see ``alias_mapping``).
    """

    def __init__(
        self, query: t.Union[Query, UnifiedQuery], **kwargs: t.Any
    ) -> None:
        super().__init__(query, **kwargs)
        self.quote = "`"
        # Computed from the query before any renaming takes place.
        self.alias_map = self.alias_mapping()
        self.rename_tables_columns()

    def alias_mapping(self) -> t.Dict[str, str]:
        """Map every alias in the query to its normalized form.

        Aliases of tables, aliased subqueries and named (selected)
        expressions are collected. In the normalized alias, spaces and
        dots are replaced with underscores.

        Returns:
            Dict[str, str]: actual alias -> normalized alias
        """

        def normalize(alias: str) -> str:
            return alias.replace(" ", "_").replace(".", "_")

        tables = self.query.find_nodes(Table)
        named_exprs = self.query.find_nodes(NamedExpression)
        aliased_subqueries = self.query.find_nodes(AliasedSubquery)

        aliases: t.Dict[str, str] = {}
        for node in [*tables, *aliased_subqueries]:
            if node.alias:
                aliases[node.alias.text] = normalize(node.alias.text)

        for named_expr in named_exprs:
            if named_expr.name:
                name = named_expr.name.text
                aliases[name] = normalize(name)

        return aliases

    def rename_tables_columns(self) -> None:
        """Rename alias references in columns and tables via TableRenamer.

        The renamer works on a deep copy of the query, which then replaces
        ``self.query``. Alias definitions themselves are left to be renamed
        by ``rename_add_quotes`` during translation.
        """
        # TableRenamer expects the map to contain quote-split (escaped) names.
        compatible_alias_map = {
            split_quote(key, "`"): split_quote(val, "`")
            for key, val in self.alias_map.items()
        }
        renamer = TableRenamer(copy.deepcopy(self.query), compatible_alias_map)
        renamer.rename()
        self.query = renamer.query

    def rename_add_quotes(self, alias: str, quote: str) -> str:
        """Normalize ``alias`` through ``alias_map`` and wrap it in ``quote``.

        Every alias is quoted to avoid conflicts with possible Databricks
        keywords. If the alias is already quoted (per ``is_quoted``), its
        outer quote characters are replaced with ``quote``.
        """
        alias = self.alias_map.get(alias, alias)
        bare = alias[1:-1] if is_quoted(alias) else alias
        return f"{quote}{bare}{quote}"

    def transform_to_dialect(self, expr: Sql) -> Sql:
        """Translate a Postgres AST node into its Databricks equivalent.

        Identifiers are re-quoted with backquotes, aliases are normalized
        and quoted, and some Postgres-specific constructs are rewritten:
        ENCODE(DECODE(.., 'hex'), 'escape'), text-like casts, TIMEZONE,
        SUBSTRING, REGEXP_SPLIT_TO_TABLE and EXTRACT(epoch FROM ..).

        Raises:
            NotImplementedError: for aggregate functions listed in
                ``USER_DEFINED_AGGS``, or for a SELECT containing more than
                one REGEXP_SPLIT_TO_TABLE expression.
        """
        if isinstance(expr, Column):
            # Change or add backquote quoting; keep at most the last two
            # name parts (table.column).
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts[-2:])
            return expr

        elif isinstance(expr, Table):
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts)
            if expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, self.quote)
                )
            return expr

        elif isinstance(expr, AliasedSubquery):
            if expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, self.quote)
                )

        elif isinstance(expr, NamedExpression):
            # Add backquoting to aliases.
            if expr.name:
                expr.name = Identifier(
                    self.rename_add_quotes(expr.name.text, self.quote)
                )

        # ENCODE(DECODE(source, 'hex'), 'escape')
        #   -> CAST(UNHEX(source) AS string)
        elif (
            isinstance(expr, EncodeFunction)
            and expr.format == Literal(None, "'escape'")
            and isinstance(expr.source, DecodeFunction)
            and expr.source.format == Literal(None, "'hex'")
        ):
            return CastFunction(
                UnhexFunction(
                    self.walk_tree(expr.source.source), token=Token("UNHEX")
                ),
                "string",
            )

        if isinstance(expr, CastFunction):
            # Text-like Postgres types map to Databricks `string`; SQL type
            # names are case-insensitive.
            dbtype = expr.dbtype.lower()
            if dbtype.startswith(("varchar", "char")) or dbtype == "text":
                expr.dbtype = "string"

        if (
            isinstance(expr, UserDefinedFunction)
            and expr.name.upper() == "TIMEZONE"
        ):
            # TIMEZONE(zone, ts) -> CONVERT_TIMEZONE(zone, ts); arguments
            # are passed through in the same order.
            return UserDefinedFunction(
                name="CONVERT_TIMEZONE",
                expressions=self.walk_tree(expr.expressions),
            )

        if (
            isinstance(expr, UserDefinedFunction)
            and expr.name.upper() in USER_DEFINED_AGGS
        ):
            raise NotImplementedError(
                f"Aggregate functions {expr.name} not supported: {expr}"
            )

        if isinstance(expr, (SubstringFunction, SubstrBigQueryFunction)):
            # Emit SUBSTRING with commas instead of FROM/FOR keywords.
            return SubstringFunction(
                source=self.walk_tree(expr.source),
                start=expr.start,
                length=expr.length,
                tokens={"FROM": Token(","), "FOR": Token(",")},
            )

        if isinstance(expr, Query) and any(
            isinstance(named_expr.expression, RegexpSplitToTableFunction)
            for named_expr in expr.select.namedExpressions
        ):
            self._explode_regexp_split(expr)

        if isinstance(expr, ExtractFunction) and expr.date_part == "epoch":
            # EXTRACT(epoch FROM column)
            #   -> DATEDIFF(SECOND, '1970-01-01', column)
            # used to compute bigdata bounds for dates and datetimes.
            return UserDefinedFunction(
                name="DATEDIFF",
                expressions=Seq(
                    [
                        Token("SECOND"),
                        Literal("'1970-01-01'"),
                        self.walk_tree(expr.expression),
                    ]
                ),
            )

        return self.walk_tree(expr)

    def _explode_regexp_split(self, query: Query) -> None:
        """Rewrite, in place, a query that selects REGEXP_SPLIT_TO_TABLE.

        The split is moved into a subquery built with EXPLODE(SPLIT(..)),
        for example::

            SELECT DISTINCT REGEXP_SPLIT_TO_TABLE ( tab.name , '' ) AS col
            FROM (
              SELECT name AS name FROM mytable WHERE name IS NOT NULL
            ) AS tab

        becomes::

            SELECT DISTINCT `exploded` AS `col`
            FROM (
              SELECT EXPLODE( SPLIT( `tab`.`name`,'')) AS `exploded`
              FROM (
                SELECT `name` AS `name` FROM `mytable` WHERE `name`
                   IS NOT NULL
              ) AS `tab`
            ) AS `split_subq`

        The exploded column keeps the position that the
        REGEXP_SPLIT_TO_TABLE expression had in the SELECT list.

        NOTE(review): the other selected expressions stay in the outer
        SELECT, while the original FROM moves into the subquery, which only
        exposes `exploded`. Expressions (or WHERE/GROUP BY/...) that reference
        columns of the original source will not resolve.

        Raises:
            NotImplementedError: if the SELECT contains more than one
                REGEXP_SPLIT_TO_TABLE expression (previously all but the
                last one were silently dropped).
        """
        named_exprs = list(query.select.namedExpressions)
        split_positions = [
            position
            for position, named_expr in enumerate(named_exprs)
            if isinstance(named_expr.expression, RegexpSplitToTableFunction)
        ]
        if len(split_positions) > 1:
            raise NotImplementedError(
                "Only one REGEXP_SPLIT_TO_TABLE per SELECT is supported: "
                f"{query}"
            )
        position = split_positions[0]
        split_named_expr = named_exprs[position]
        regexp_func = split_named_expr.expression

        # Inner query: SELECT EXPLODE(SPLIT(source, pattern)) AS exploded
        split = UserDefinedFunction(
            name="SPLIT",
            expressions=Seq(
                [
                    self.walk_tree(regexp_func.source),
                    self.walk_tree(regexp_func.pattern),
                ]
            ),
        )
        explode = UserDefinedFunction(
            name="EXPLODE",
            expressions=Seq([split]),
        )
        inner_select = Select(
            quantifier=None,
            namedExpressions=Seq(
                [
                    NamedExpression(
                        expression=explode, name=Identifier("exploded")
                    )
                ]
            ),
        )
        subquery = Query(
            select=inner_select,
            source=query.source,
            where=None,
            agg=None,
            having=None,
            order=None,
            limit=None,
        )
        relation_with_subquery = Relation(
            primary=AliasedSubquery(
                query=subquery, alias=Identifier("split_subq")
            ),
            joins=None,
        )

        # Outer query: the exploded column replaces the split expression at
        # the same position, keeping the original alias.
        named_exprs[position] = NamedExpression(
            expression=Column("exploded"), name=split_named_expr.name
        )
        query.select = Select(
            quantifier=query.select.quantifier,
            namedExpressions=Seq(named_exprs),
        )
        query.source = From(Seq([relation_with_subquery]))

    def transform_to_postgres(self, expr: SqlExpr) -> SqlExpr:
        """Translate a Databricks AST node back into its Postgres equivalent.

        Identifiers and aliases are re-quoted with double quotes, and the
        constructs produced by ``transform_to_dialect`` are mapped back:
        CAST(UNHEX(..)), `string` casts, SUBSTRING, DATEDIFF(SECOND, ..),
        UNIX_SECONDS and CONVERT_TIMEZONE.

        Raises:
            NotImplementedError: for CONVERT_TIMEZONE calls that do not have
                exactly two arguments.
        """
        ident_quote = '"'

        if isinstance(expr, NamedExpression):
            # Add double quoting to aliases.
            if expr.name:
                expr.name = Identifier(
                    self.rename_add_quotes(expr.name.text, ident_quote)
                )

        if isinstance(expr, (Column, Table)):
            parts = split_quote(expr.name, ident_quote)
            expr.name = ".".join(parts)
            if isinstance(expr, Table) and expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, ident_quote)
                )
            return expr

        elif isinstance(expr, AliasedSubquery) and expr.alias:
            expr.alias = Identifier(
                self.rename_add_quotes(expr.alias.text, ident_quote)
            )

        elif isinstance(expr, CastFunction):
            # CAST(UNHEX(source) AS ...)
            #   -> ENCODE(DECODE(source, 'hex'), 'escape')
            if isinstance(expr.expression, UnhexFunction):
                return EncodeFunction(
                    source=DecodeFunction(
                        source=self.walk_tree(expr.expression.source),
                        format=Literal(None, "'hex'"),
                    ),
                    format=Literal(None, "'escape'"),
                )

            # SQL type names are case-insensitive.
            if expr.dbtype.lower().startswith("string"):
                expr.dbtype = "varchar"

        if isinstance(expr, SubstringFunction):
            return SubstringFunction(
                source=self.walk_tree(expr.source),
                start=expr.start,
                length=expr.length,
                tokens={"FROM": Token(","), "FOR": Token(",")},
            )

        if isinstance(expr, UserDefinedFunction):
            # Function names are case-insensitive in SQL.
            name = expr.name.upper()
            args = expr.expressions

            # DATEDIFF(SECOND, '1970-01-01', col) -> EXTRACT(epoch FROM col)
            if (
                name == "DATEDIFF"
                and len(args) == 3
                and args[0] == Token("SECOND")
                and args[1] == Literal("'1970-01-01'")
            ):
                return ExtractFunction("epoch", self.walk_tree(args[2]))

            if name == "UNIX_SECONDS":
                return ExtractFunction("epoch", self.walk_tree(args))

            # CONVERT_TIMEZONE(zone, ts) -> TIMEZONE(zone, ts)
            if name == "CONVERT_TIMEZONE":
                # Explicit check instead of `assert`, which is stripped
                # under `python -O`.
                if len(args) != 2:
                    raise NotImplementedError(
                        "CONVERT_TIMEZONE is only supported with two "
                        f"arguments: {expr}"
                    )
                return UserDefinedFunction(
                    name="TIMEZONE",
                    expressions=self.walk_tree(args),
                )

        return self.walk_tree(expr)