Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
from enum import Enum
import copy
import os
import typing as t

from snsql._ast.ast import (
    Aggregate,
    BareFunction,
    BooleanCompare,
    Column,
    FuncName,
    GroupingExpression,
    Limit,
    Literal,
    MathFunction,
    Order,
    Query,
    Seq,
    SortItem,
    Sql,
    Table,
    Token,
    Top,
    UnifiedQuery,
    UserDefinedFunction,
)
from snsql._ast.expressions.date import ExtractFunction
from snsql._ast.expressions.string import (
    CharLengthFunction,
    CoalesceFunction,
    ConcatFunction,
    DecodeFunction,
    EncodeFunction,
    RegexpSplitToTableFunction,
)
from snsql._ast.expressions.types import CastFunction

from sarus_sql.dialects import SQLDialect

from .base import BaseTranslator
from .charset_inferring_functions import char_inference_query, custom_charset

# Upper-case names of aggregate functions that are refused during
# translation: a UserDefinedFunction whose upper-cased name appears here
# makes ``SqlServerTranslator.transform_to_dialect`` raise
# NotImplementedError.
#   - COUNT_BIG: TODO, this should be translated into COUNT.
#   - STDEVP:    TODO, equivalent STDEVPOP.
USER_DEFINED_AGGS = tuple(
    "APPROX_COUNT_DISTINCT CHECKSUM_AGG COUNT_BIG GROUPING GROUPING_ID"
    " STDEVP STRING_AGG VARP".split()
)


class MSSQL_CHARSET(Enum):
    """Selects how the charset inference query is generated for SqlServer.

    The member values are the strings compared against the
    ``MSSQL_CHARSET`` environment variable by the translator.
    """

    # Generate a query for the Basic Multilingual Plane (BMP):
    # code points 0-65535.
    BMP = "BMP"

    # Generate a query for the "ASCII" range as used here: code points
    # 0-255 (strictly speaking, 7-bit ASCII stops at 127).
    ASCII = "ASCII"

    # Infer the charset from the data, assuming that a single row of text
    # holds at most one million characters.
    INFER = "INFER"


class SqlServerTranslator(BaseTranslator):
    """Translate Postgres <-> SqlServer (Azure, Synapse) by modifying
    the query AST.

    Currently supported translations:

    - LN <-> LOG
    - RANDOM <-> RAND
    - LIMIT n <-> TOP n (on the top-level query)
    - EXTRACT(epoch FROM x) <-> DATEDIFF(SECOND, '19700101', x)
    - MD5(x) <-> CONVERT(VARCHAR(MAX), HASHBYTES('MD5', x), 2)
    - ENCODE(DECODE(x, 'hex'), 'escape') <-> nested CONVERT calls
    - character-length function -> LEN
    - CAST(x AS varchar/text) -> CONVERT(VARCHAR, x, 126)
    - CAST(x AS boolean) -> CAST(x AS BIT); boolean literals -> 1 / 0
    - IFNULL -> COALESCE (SqlServer -> Postgres only)

    LN and LOG in postgres are respectively the natural log and the base 10
    log, while LOG in MSSQL is the natural log.
    """

    def __init__(
        self, query: t.Union[Query, UnifiedQuery], **kwargs: t.Any
    ) -> None:
        """Drop unused CTEs, then delegate to ``BaseTranslator``.

        In Azure SQL, a query that declares CTEs which the main query does
        not use raises an error from the db. For a plain ``Query`` with no
        FROM source whose select list references none of the CTE names, the
        CTEs are therefore removed (``query`` is mutated in place).

        Args:
            query: the parsed query to translate.
            **kwargs: forwarded unchanged to ``BaseTranslator.__init__``.
        """
        if isinstance(query, Query):
            ctes = query.select.ctes
            cte_names = (
                {cte_table.name for cte_table in ctes}
                if ctes is not None
                else set()
            )
            tables_in_the_main_select = {
                tab.name
                for tab in query.select.namedExpressions.find_nodes(Table)
            }
            if (
                cte_names.isdisjoint(tables_in_the_main_select)
                and query.source is None
            ):
                query.select.ctes = None
        super().__init__(query, **kwargs)

    def transform_to_dialect(self, expr: Sql) -> Sql:
        """Return ``expr`` rewritten from Postgres into the SqlServer dialect.

        Side effect: a LIMIT on the translator's top-level query
        (``self.query``) is moved into a ``TOP`` quantifier on its SELECT.

        Args:
            expr: the AST node to translate.

        Returns:
            The translated node (possibly a new object). Nodes without a
            specific rule are returned with their children walked.

        Raises:
            ValueError: if a REGEXP_SPLIT_TO_TABLE query does not have the
                single-column shape this translator knows how to rewrite.
            NotImplementedError: for user-defined functions whose
                upper-cased name is listed in ``USER_DEFINED_AGGS``.
        """
        # SqlServer expresses the top-level row limit as SELECT TOP n.
        if self.query.limit:
            self.query.select.quantifier = Top(self.query.limit.n)
            self.query.limit = None

        if isinstance(expr, Column):
            return expr

        if isinstance(expr, MathFunction) and expr.name == FuncName("LN"):
            # Postgres LN (natural log) is LOG in MSSQL.
            return MathFunction(
                name=FuncName("LOG"),
                expression=self.transform(expr.expression),
            )

        if isinstance(expr, BareFunction) and expr.name == FuncName("RANDOM"):
            # RAND in MSSQL will generate only 1 random number
            return BareFunction(name=FuncName("RAND"))

        if isinstance(expr, CharLengthFunction):
            return CharLengthFunction(
                expression=self.walk_tree(expr.expression), token=Token("LEN")
            )

        if isinstance(expr, ExtractFunction) and expr.date_part == "epoch":
            # EXTRACT(epoch FROM column) -> DATEDIFF(SECOND, '19700101', column)
            # used to compute bigdata bounds for date and datetimes.
            return UserDefinedFunction(
                name="DATEDIFF",
                expressions=Seq(
                    [
                        Token("SECOND"),
                        Literal("'19700101'"),
                        self.walk_tree(expr.expression),
                    ]
                ),
            )

        if (
            isinstance(expr, EncodeFunction)
            and expr.format == Literal(None, "'escape'")
            and isinstance(expr.source, DecodeFunction)
            and expr.source.format == Literal(None, "'hex'")
        ):
            # ENCODE(DECODE(source, 'hex'), 'escape') ->
            # CONVERT(
            #   VARCHAR(MAX),
            #   CONVERT(VARBINARY(MAX), source, 2),
            #   0
            # )
            source = self.walk_tree(expr.source.source)
            hex_to_binary = UserDefinedFunction(
                name="CONVERT",
                expressions=Seq(["VARBINARY(MAX)", source, Literal(2)]),
            )
            return UserDefinedFunction(
                name="CONVERT",
                expressions=Seq(["VARCHAR(MAX)", hex_to_binary, Literal(0)]),
            )

        if isinstance(expr, ConcatFunction):
            # CONCAT in MSSQL has to take at least 2 args, so an empty
            # string literal is always appended.
            exprs = [self.transform(e) for e in expr.expressions]
            exprs.append(Literal(None, "''"))
            return ConcatFunction(exprs)

        if isinstance(expr, Query) and any(
            isinstance(named.expression, RegexpSplitToTableFunction)
            for named in expr.select.namedExpressions
        ):
            return self._translate_charset_query(expr)

        if isinstance(expr, UserDefinedFunction):
            # Function names are matched case-insensitively.
            func_name = expr.name.upper()
            if func_name == "MD5":
                # MD5(X) -> CONVERT(VARCHAR(MAX), HASHBYTES('MD5', X), 2)
                hashbytes = UserDefinedFunction(
                    name="HASHBYTES",
                    expressions=Seq(
                        [Literal("'MD5'"), self.transform(expr.expressions[0])]
                    ),
                )
                return UserDefinedFunction(
                    name="CONVERT",
                    expressions=Seq(["VARCHAR(MAX)", hashbytes, Literal(2)]),
                )

            if func_name in USER_DEFINED_AGGS:
                raise NotImplementedError(
                    f"Aggregate functions {expr.name}"
                    f" not supported: {expr}"
                )

        if isinstance(expr, Query):
            return self.walk_tree(self._rewrite_sampling_query(expr))

        if isinstance(expr, CastFunction) and expr.dbtype in ("varchar", "text"):
            # CAST(x AS varchar/text) -> CONVERT(VARCHAR, x, 126)
            return UserDefinedFunction(
                name="CONVERT",
                expressions=Seq(
                    ["VARCHAR", self.walk_tree(expr.expression), Literal(126)]
                ),
            )

        if isinstance(expr, CastFunction) and expr.dbtype == "boolean":
            # Booleans are cast to BIT in MSSQL.
            return CastFunction(self.walk_tree(expr.expression), dbtype="BIT")

        if isinstance(expr, Literal) and isinstance(expr.value, bool):
            # Boolean literals become the integers 1 / 0.
            return Literal(int(expr.value))

        return self.walk_tree(expr)

    def _translate_charset_query(self, expr: Query) -> Sql:
        """Rewrite a REGEXP_SPLIT_TO_TABLE charset query for MSSQL.

        Expects a query like this one::

            SELECT
               DISTINCT REGEXP_SPLIT_TO_TABLE(anon_2.name ,'')
               AS "regexp_split"
            FROM (
               SELECT "tab1"."name" AS "name"
               FROM "tab1" WHERE "tab1"."name" IS NOT NULL
            ) AS "anon_2"

        The replacement depends on the ``MSSQL_CHARSET`` environment
        variable (matched case-insensitively, default ``ASCII``):

        - ``INFER`` builds a query inferring the charset from the data.
        - ``ASCII`` uses the fixed range 0-255.
        - Any other value uses the BMP range 0-65535.

        Raises:
            ValueError: if ``expr`` does not have the expected shape.
        """

        def unsupported() -> ValueError:
            return ValueError(
                f"We don't know how to translate {expr} "
                "into MSSQL dialect."
            )

        named_expressions = expr.select.namedExpressions
        if len(named_expressions) != 1 or not isinstance(
            named_expressions[0].expression, RegexpSplitToTableFunction
        ):
            raise unsupported()

        try:
            # Source of the column on which the charset is computed.
            source_relation = expr.source.relations[0]
        except (AttributeError, IndexError, TypeError) as err:
            raise unsupported() from err

        # Alias given to the charset column ("regexp_split" above).
        alias = named_expressions[0].name

        # ASCII is the default for efficiency. If it is known in advance
        # that the dataset has other characters, MSSQL_CHARSET should be
        # set accordingly.
        charset = os.environ.get("MSSQL_CHARSET", "ASCII").strip().upper()

        if charset == MSSQL_CHARSET.INFER.value:
            return char_inference_query(
                # The column whose charset we want (anon_2.name above).
                column=named_expressions[0].expression.source,
                col_alias=alias,
                source_relation=source_relation,
                max_string_length=1000000,
                dialect=SQLDialect.SQL_SERVER,
            )
        if charset == MSSQL_CHARSET.ASCII.value:
            return custom_charset(alias, 255, SQLDialect.SQL_SERVER)
        return custom_charset(alias, 65535, SQLDialect.SQL_SERVER)

    def _rewrite_sampling_query(self, expr: Query) -> Query:
        """Return a deep copy of ``expr`` adapted for MSSQL.

        The input is not modified. The copy is changed as follows:

        - ``LIMIT n`` becomes ``TOP n``. For ``SELECT DISTINCT``, DISTINCT
          is replaced by a GROUP BY on the first select expression.
        - ``WHERE RANDOM() <cmp> x`` is removed.
        - ``ORDER BY RANDOM() ...`` becomes ``ORDER BY NEWID()``.
        """
        # This fixes a sampling query (SELECT TOP x .. ORDER BY RAND())
        # when optimized with a WHERE clause: such a query gives unstable
        # results (zero rows sometimes). Tried without success:
        #  WHERE RAND () < x
        #  WHERE RAND( CHECKSUM(NEWID()) ) < x
        #  WHERE RAND( ABS(CAST(CAST(new_id AS VARBINARY) AS INT))) < x
        # https://www.sqlservercentral.com/forums/topic/whats-the-best-way-to-get-a-sample-set-of-a-big-table-without-primary-key#post-1948778  # noqa: E501
        new_query = copy.deepcopy(expr)

        limit = new_query.limit
        if limit:
            if new_query.select.quantifier == Token("DISTINCT"):
                # SELECT DISTINCT "some_col", ... LIMIT x is rewritten as
                # SELECT TOP x "some_col", ... GROUP BY "some_col".
                # Only the first column is grouped on because of queries
                # like:
                #   SELECT DISTINCT "col" AS "sarus_privacy_unit",
                #          RAND () AS random_1
                #   FROM "rejpsxndax" ORDER BY RAND () LIMIT 3256
                # where random_1 is not a valid GROUP BY column.
                first_expr = new_query.select.namedExpressions[0]
                group_key = (
                    first_expr.expression
                    if first_expr.name is None
                    else first_expr.name
                )
                new_query.agg = Aggregate(
                    groupingExpressions=[GroupingExpression(group_key)]
                )
            new_query.select.quantifier = Top(limit.n)
            new_query.limit = None

        where = new_query.where
        if (
            where
            and isinstance(where.condition, BooleanCompare)
            and isinstance(where.condition.left, BareFunction)
            and where.condition.left.name == FuncName("RANDOM")
        ):
            new_query.where = None

        if new_query.order:
            first_sort_expr = new_query.order.sortItems[0].expression
            if isinstance(
                first_sort_expr, BareFunction
            ) and first_sort_expr.name == FuncName("RANDOM"):
                new_query.order = Order(
                    [
                        SortItem(
                            expression=UserDefinedFunction(
                                name="NEWID", expressions=Seq([])
                            ),
                            order=None,
                        )
                    ]
                )
        return new_query

    def transform_to_postgres(self, expr: Sql) -> Sql:
        """Return ``expr`` rewritten from SqlServer into the Postgres dialect.

        Side effect: a ``TOP n`` quantifier on the translator's top-level
        query (``self.query``) is moved into a LIMIT clause.

        Args:
            expr: the AST node to translate.

        Returns:
            The translated node (possibly a new object). Nodes without a
            specific rule are returned with their children walked.
        """
        # Only TOP carries a row count to turn into LIMIT. Any other
        # quantifier (e.g. DISTINCT) has no ``n`` and is left in place.
        quantifier = self.query.select.quantifier
        if isinstance(quantifier, Top):
            self.query.limit = Limit(quantifier.n)
            self.query.select.quantifier = None

        if isinstance(expr, Column):
            return expr

        if isinstance(expr, MathFunction) and expr.name == FuncName("LOG"):
            # MSSQL LOG (natural log) is LN in Postgres.
            return MathFunction(
                name=FuncName("LN"), expression=self.transform(expr.expression)
            )

        if isinstance(expr, BareFunction) and expr.name == FuncName("RAND"):
            return BareFunction(name=FuncName("RANDOM"))

        if isinstance(expr, UserDefinedFunction):
            if (
                expr.name == "DATEDIFF"
                and expr.expressions[0] == Token("SECOND")
                and expr.expressions[1] == Literal("'19700101'")
            ):
                # DATEDIFF(SECOND, '19700101', column) ->
                # EXTRACT(epoch FROM column); the column expression is
                # walked like in the reverse translation.
                return ExtractFunction(
                    "epoch", self.walk_tree(expr.expressions[2])
                )

            if (
                expr.name == "CONVERT"
                and isinstance(expr.expressions[1], UserDefinedFunction)
                and expr.expressions[1].name == "HASHBYTES"
            ):
                # CONVERT(VARCHAR(MAX), HASHBYTES('MD5', X), 2) -> MD5(X)
                return UserDefinedFunction(
                    name="MD5",
                    expressions=Seq(
                        [self.walk_tree(expr.expressions[1].expressions[1])]
                    ),
                )

            if (
                expr.name == "CONVERT"
                and isinstance(expr.expressions[1], UserDefinedFunction)
                and expr.expressions[1].name == "CONVERT"
            ):
                # CONVERT(
                #   VARCHAR(MAX),
                #   CONVERT(VARBINARY(MAX), source, 2),
                #   0
                # ) -> ENCODE(DECODE(source, 'hex'), 'escape')
                source = self.walk_tree(expr.expressions[1].expressions[1])
                return EncodeFunction(
                    source=DecodeFunction(
                        source=source,
                        format=Literal(None, "'hex'"),
                    ),
                    format=Literal(None, "'escape'"),
                )

            if expr.name == "IFNULL":
                return CoalesceFunction(
                    expressions=Seq(
                        [
                            self.transform(udf_expr)
                            for udf_expr in expr.expressions
                        ]
                    ),
                )

        return self.walk_tree(expr)