Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
import copy
import typing as t

from snsql._ast.ast import (
    AliasedSubquery,
    BareFunction,
    Column,
    From,
    FuncName,
    Identifier,
    Literal,
    MathFunction,
    NamedExpression,
    Query,
    Relation,
    Select,
    Seq,
    Sql,
    SqlExpr,
    Table,
    Token,
    UnifiedQuery,
    Unnest,
    UserDefinedFunction,
)
from snsql._ast.expressions.date import ExtractFunction
from snsql._ast.expressions.string import (
    CoalesceFunction,
    DecodeFunction,
    EncodeFunction,
    HexFunction,
    RegexpSplitToTableFunction,
    SubstrBigQueryFunction,
    SubstringFunction,
    UnhexFunction,
)
from snsql._ast.expressions.types import CastFunction

from ..table_renamer import TableRenamer
from ..utils import is_quoted, split_quote
from .base import BaseTranslator

# Aggregate functions that exist in BigQuery but are not shared with
# Postgres; translating a query that uses one of them to the BigQuery
# dialect is refused with NotImplementedError.
USER_DEFINED_AGGS = (
    "ANY_VALUE", "ARRAY_CONCAT_AGG", "BIT_XOR",
    "COUNTIF", "LOGICAL_AND", "LOGICAL_OR",
)


class BigQueryTranslator(BaseTranslator):
    """Translator for BigQuery <-> Postgres.

    BigQuery quoting:
    https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical

    All aliases are backquoted to avoid conflicts with BigQuery keywords.
    Moreover, BigQuery aliases can't contain special characters such as
    dots, so aliases are normalized first (see ``alias_mapping``).
    """

    def __init__(
        self, query: t.Union[Query, UnifiedQuery], **kwargs: t.Any
    ) -> None:
        super().__init__(query, **kwargs)
        # Identifier quote character used on the BigQuery side.
        self.quote = "`"
        # alias_map must exist before rename_tables_columns, which reads it.
        self.alias_map = self.alias_mapping()
        self.rename_tables_columns()

    @staticmethod
    def _normalize_alias(alias: str) -> str:
        """Return ``alias`` with spaces and dots replaced by underscores."""
        return alias.replace(" ", "_").replace(".", "_")

    def alias_mapping(self) -> t.Dict[str, str]:
        """Map every alias found in the query to its normalized form.

        Aliases are collected from the ``alias`` of ``Table`` and
        ``AliasedSubquery`` nodes and from the ``name`` of
        ``NamedExpression`` nodes. In the normalized form, spaces and dots
        are replaced with underscores.

        Returns:
            Dict[str, str]: actual alias -> normalized alias
        """
        tables = self.query.find_nodes(Table)
        exprs = self.query.find_nodes(NamedExpression)
        aliased_sub = self.query.find_nodes(AliasedSubquery)

        aliases = {}
        for expr in [*tables, *aliased_sub]:
            if expr.alias:
                text = expr.alias.text
                aliases[text] = self._normalize_alias(text)

        for expr in exprs:
            if expr.name:
                text = expr.name.text
                aliases[text] = self._normalize_alias(text)

        return aliases

    def rename_tables_columns(self) -> None:
        """Rename tables and columns of ``self.query`` using ``TableRenamer``.

        ``self.query`` is replaced by a renamed deep copy. Alias definitions
        are not renamed here; they are renamed during translation by
        ``rename_add_quotes``.
        """
        # The renamer is given names passed through split_quote with an
        # empty quote character (presumably the escaped/unquoted form
        # TableRenamer expects -- confirm against TableRenamer).
        compatible_alias_map = {
            split_quote(key, ""): split_quote(val, "")
            for key, val in self.alias_map.items()
        }
        renamer = TableRenamer(copy.deepcopy(self.query), compatible_alias_map)
        renamer.rename()
        self.query = renamer.query

    def rename_add_quotes(self, alias: str, quote: str) -> str:
        """Normalize ``alias`` via ``self.alias_map`` and wrap it in ``quote``.

        To avoid conflicts with possible BigQuery keywords every alias is
        quoted. An alias that is already quoted (per ``is_quoted``) has its
        outer quote characters replaced by ``quote``.
        """
        if alias in self.alias_map:
            alias = self.alias_map[alias]

        if is_quoted(alias):
            alias = f"{quote}{alias[1:-1]}{quote}"
        else:
            alias = f"{quote}{alias}{quote}"
        return alias

    def transform_to_dialect(self, expr: Sql) -> Sql:
        """Translate the Postgres node ``expr`` into its BigQuery form.

        Nodes without a dedicated rule, and nodes modified in place below,
        fall through to ``self.walk_tree`` so their children get translated.

        Raises:
            NotImplementedError: for aggregates listed in
                ``USER_DEFINED_AGGS``, and for queries selecting more than
                one ``REGEXP_SPLIT_TO_TABLE`` expression.
        """
        if isinstance(expr, MathFunction) and expr.name == FuncName("LN"):
            return MathFunction(
                name=FuncName("LOG"),
                expression=self.transform(expr.expression),
            )

        if isinstance(expr, BareFunction) and expr.name == FuncName("RANDOM"):
            return BareFunction(name=FuncName("RAND"))

        if isinstance(expr, Column):
            # Change or add backquoting; only the last two name parts are
            # kept.
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts[-2:])
            return expr

        elif isinstance(expr, Table):
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts)
            if expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, self.quote)
                )
            return expr

        elif isinstance(expr, AliasedSubquery):
            # Backquote the alias, then fall through so the subquery is
            # translated by walk_tree.
            if expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, self.quote)
                )

        elif isinstance(expr, SubstringFunction):
            return SubstrBigQueryFunction(
                self.walk_tree(expr.source),
                self.walk_tree(expr.start),
                self.walk_tree(expr.length),
            )

        elif isinstance(expr, NamedExpression):
            # Backquote the alias; the expression itself is translated by
            # the final walk_tree.
            if expr.name:
                expr.name = Identifier(
                    self.rename_add_quotes(expr.name.text, self.quote)
                )

        # ENCODE(DECODE(source, 'hex'), 'escape')
        # -> CAST(FROM_HEX(source) AS string)
        elif (
            isinstance(expr, EncodeFunction)
            and expr.format == Literal(None, "'escape'")
            and isinstance(expr.source, DecodeFunction)
            and expr.source.format == Literal(None, "'hex'")
        ):
            return CastFunction(
                UnhexFunction(
                    self.walk_tree(expr.source.source), token=Token("FROM_HEX")
                ),
                "string",
            )

        if isinstance(expr, CastFunction):
            if (
                expr.dbtype.startswith("varchar")
                or expr.dbtype.startswith("char")
                or expr.dbtype == "text"
            ):
                expr.dbtype = "string"

            # CAST(TIMEZONE(tz, ts) AS ...) -> CAST(DATETIME(ts, tz) AS ...)
            if (
                isinstance(expr.expression, UserDefinedFunction)
                and expr.expression.name == "TIMEZONE"
            ):
                assert len(expr.expression.expressions) == 2
                args = self.walk_tree(expr.expression.expressions)
                expr.expression = UserDefinedFunction(
                    name="DATETIME", expressions=args[::-1]
                )
                # The arguments are already translated: return now rather
                # than falling through to walk_tree, which would translate
                # them a second time (the standalone TIMEZONE rule below
                # also returns right after translating).
                return expr

        if isinstance(expr, Query) and any(
            isinstance(name_expr.expression, RegexpSplitToTableFunction)
            for name_expr in expr.select.namedExpressions
        ):
            # Rewritten in place; the final walk_tree translates the result.
            self._rewrite_regexp_split_to_table(expr)

        if isinstance(expr, ExtractFunction) and expr.date_part == "epoch":
            # It converts EXTRACT(epoch FROM column) into
            # UNIX_SECONDS(CAST(col AS TIMESTAMP))
            # used to compute bigdata bounds for date and datetimes
            cast = CastFunction(
                self.walk_tree(expr.expression), dbtype="timestamp"
            )
            return UserDefinedFunction(
                name="UNIX_SECONDS",
                expressions=Seq([cast]),
            )

        if isinstance(expr, CoalesceFunction):
            coalesce_args = list(expr.expressions)
            # BigQuery's IFNULL takes exactly two arguments. Any other
            # arity keeps COALESCE (supported by BigQuery) and falls
            # through to walk_tree.
            if len(coalesce_args) == 2:
                return UserDefinedFunction(
                    name="IFNULL",
                    expressions=Seq(
                        [self.walk_tree(arg) for arg in coalesce_args]
                    ),
                )

        if isinstance(expr, UserDefinedFunction):
            # convert MD5(x) -> TO_HEX(MD5(X))
            if expr.name == "MD5":
                new_md5 = UserDefinedFunction(
                    name=expr.name,
                    expressions=self.walk_tree(expr.expressions),
                )
                return HexFunction(
                    token=Token("TO_HEX"),
                    source=new_md5,
                )

            # convert TIMEZONE(tz, ts)
            # -> CAST(DATETIME(ts, tz) AS TIMESTAMP)
            if expr.name == "TIMEZONE":
                assert len(expr.expressions) == 2
                return CastFunction(
                    expression=UserDefinedFunction(
                        name="DATETIME",
                        expressions=self.walk_tree(expr.expressions)[::-1],
                    ),
                    dbtype="timestamp",
                )

            if expr.name.upper() in USER_DEFINED_AGGS:
                raise NotImplementedError(
                    f"Aggregate functions {expr.name}"
                    f" not supported: {expr}"
                )

        return self.walk_tree(expr)

    def _rewrite_regexp_split_to_table(self, expr: Query) -> None:
        """Rewrite, in place, a query that selects REGEXP_SPLIT_TO_TABLE.

        REGEXP_SPLIT_TO_TABLE is emulated with SPLIT in a subquery followed
        by UNNEST. Example::

            SELECT DISTINCT REGEXP_SPLIT_TO_TABLE ( tab.name , '' ) AS col
            FROM (
              SELECT name AS name FROM mytable WHERE name IS NOT NULL
            ) AS tab

        becomes::

            SELECT DISTINCT unnested AS col
            FROM (
               SELECT SPLIT(tab.name , '') AS splitted
               FROM ( SELECT name AS name FROM mytable WHERE name IS NOT NULL
               ) AS tab
            ) AS split_subq, UNNEST(split_subq.splitted) AS unnested

        The new ``unnested`` select expression takes the position of the
        REGEXP_SPLIT_TO_TABLE expression, so column order is preserved.
        The SPLIT arguments are not translated here: the caller walks the
        rewritten query afterwards, which translates them exactly once.
        Expects at least one REGEXP_SPLIT_TO_TABLE select expression.

        Raises:
            NotImplementedError: if more than one select expression is a
                REGEXP_SPLIT_TO_TABLE; a single SPLIT/UNNEST pair can only
                emulate one of them.
        """
        regexp_named_exprs = [
            named_expr
            for named_expr in expr.select.namedExpressions
            if isinstance(named_expr.expression, RegexpSplitToTableFunction)
        ]
        if len(regexp_named_exprs) > 1:
            raise NotImplementedError(
                "Only one REGEXP_SPLIT_TO_TABLE per SELECT is supported:"
                f" {expr}"
            )
        regexp_named_expr = regexp_named_exprs[0]
        regexp_fn = regexp_named_expr.expression

        # building the subquery: SELECT SPLIT(source, pattern) AS splitted
        split = UserDefinedFunction(
            name="SPLIT",
            expressions=Seq([regexp_fn.source, regexp_fn.pattern]),
        )
        namedexpr = NamedExpression(
            expression=split, name=Identifier("splitted")
        )
        select = Select(quantifier=None, namedExpressions=Seq([namedexpr]))
        subquery = Query(
            select=select,
            source=expr.source,
            where=None,
            agg=None,
            having=None,
            order=None,
            limit=None,
        )
        aliased_subquery = AliasedSubquery(
            query=subquery, alias=Identifier("split_subq")
        )
        relation_with_subquery = Relation(
            primary=aliased_subquery, joins=None
        )
        relation_with_unnest = Relation(
            primary=Unnest(
                name=Column("split_subq.splitted"), alias="unnested"
            ),
            joins=None,
        )

        main_nexpr = NamedExpression(
            expression=Column("unnested"), name=regexp_named_expr.name
        )
        # Substitute in place so the result keeps its column order.
        main_named_exprs = [
            main_nexpr if named_expr is regexp_named_expr else named_expr
            for named_expr in expr.select.namedExpressions
        ]
        expr.select = Select(
            quantifier=expr.select.quantifier,
            namedExpressions=Seq(main_named_exprs),
        )
        expr.source = From(
            Seq([relation_with_subquery, relation_with_unnest])
        )

    def transform_to_postgres(self, expr: SqlExpr) -> SqlExpr:
        """Translate the BigQuery node ``expr`` into its Postgres form.

        Nodes without a dedicated rule, and nodes modified in place below,
        fall through to ``self.walk_tree`` so their children get translated.
        """
        if isinstance(expr, MathFunction) and expr.name == FuncName("LOG"):
            return MathFunction(
                name=FuncName("LN"), expression=self.transform(expr.expression)
            )

        if isinstance(expr, BareFunction) and expr.name == FuncName("RAND"):
            return BareFunction(name=FuncName("RANDOM"))

        if isinstance(expr, Column) or isinstance(expr, Table):
            # Re-quote identifiers with Postgres double quotes.
            ident_quote = '"'
            parts = split_quote(expr.name, ident_quote)
            expr.name = ".".join(parts)
            if isinstance(expr, Table) and expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, ident_quote)
                )
            return expr

        elif isinstance(expr, SubstrBigQueryFunction):
            return SubstringFunction(
                self.walk_tree(expr.source),
                self.walk_tree(expr.start),
                self.walk_tree(expr.length),
                {"FROM": ",", "FOR": ","},  # tokens
            )

        elif isinstance(expr, NamedExpression) and expr.name:
            expr.name = Identifier(self.rename_add_quotes(expr.name.text, '"'))

        elif isinstance(expr, AliasedSubquery) and expr.alias:
            expr.alias = Identifier(
                self.rename_add_quotes(expr.alias.text, '"')
            )

        elif isinstance(expr, CastFunction):
            # CAST(FROM_HEX(source) AS STRING)
            # -> ENCODE(DECODE(source, 'hex'), 'escape')
            if isinstance(expr.expression, UnhexFunction):
                return EncodeFunction(
                    source=DecodeFunction(
                        source=self.walk_tree(expr.expression.source),
                        format=Literal(None, "'hex'"),
                    ),
                    format=Literal(None, "'escape'"),
                )

            # CAST(DATETIME(ts, tz) AS TIMESTAMP)
            # -> TIMEZONE(tz, ts)
            if (
                expr.dbtype == "timestamp"
                and isinstance(expr.expression, UserDefinedFunction)
                and expr.expression.name == "DATETIME"
            ):
                assert len(expr.expression.expressions) == 2
                return UserDefinedFunction(
                    name="TIMEZONE",
                    expressions=self.walk_tree(expr.expression.expressions)[
                        ::-1
                    ],
                )

            if expr.dbtype.startswith("string"):
                expr.dbtype = "varchar"

        # TO_HEX(MD5(x)) -> MD5(x): inverse of the MD5(x) -> TO_HEX(MD5(x))
        # rule in transform_to_dialect. Other functions wrapped in TO_HEX
        # (e.g. SHA256) must not be turned into MD5.
        if (
            isinstance(expr, HexFunction)
            and isinstance(expr.source, UserDefinedFunction)
            and expr.source.name.upper() == "MD5"
        ):
            return UserDefinedFunction(
                name="MD5", expressions=self.walk_tree(expr.source.expressions)
            )

        if isinstance(expr, UserDefinedFunction):
            # it translates UNIX_SECONDS(CAST(col AS TIMESTAMP))
            # to EXTRACT(epoch FROM column)
            if (
                expr.name == "UNIX_SECONDS"
                and isinstance(expr.expressions[0], CastFunction)
                and expr.expressions[0].dbtype == "timestamp"
            ):
                return ExtractFunction(
                    "epoch", self.walk_tree(expr.expressions[0].expression)
                )

            # convert DATETIME(ts, tz)
            # -> TIMEZONE(tz, ts)
            if expr.name == "DATETIME":
                assert len(expr.expressions) == 2
                return UserDefinedFunction(
                    name="TIMEZONE",
                    expressions=self.walk_tree(expr.expressions)[::-1],
                )

            # IFNULL(a, b) -> COALESCE(a, b)
            if expr.name == "IFNULL":
                return CoalesceFunction(
                    expressions=Seq(
                        [
                            self.transform(udf_expr)
                            for udf_expr in expr.expressions
                        ]
                    ),
                )

        return self.walk_tree(expr)