Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
from enum import Enum
import os
import typing as t

from snsql._ast.ast import (
    BetweenCondition,
    CaseExpression,
    Column,
    Literal,
    NestedBoolean,
    PredicatedExpression,
    Query,
    Sql,
    SqlExpr,
    Table,
    UnifiedQuery,
    UserDefinedFunction,
)
from snsql._ast.expressions.string import (
    ConcatFunction,
    DecodeFunction,
    EncodeFunction,
    RegexpSplitToTableFunction,
)

from sarus_sql.dialects import SQLDialect

from .base import BaseTranslator
from .charset_inferring_functions import char_inference_query, custom_charset

# Aggregate functions that exist in Redshift but have no PostgreSQL
# counterpart. Names are stored upper-case so they can be matched against
# an upper-cased function name.
USER_DEFINED_AGGS = ("ANY_VALUE", "LISTAGG", "MEDIAN")


def encode_decode_redshift(pg_format: Literal) -> Literal:
    """
    Map a PostgreSQL ENCODE/DECODE format literal onto Redshift's.

    Accepted PostgreSQL formats and their Redshift counterparts:
        'escape' -> 'utf8'
        'hex'    -> 'hex'    (passed through unchanged)
        'base64' -> 'base64' (passed through unchanged)

    Args:
        pg_format: the format argument of a PostgreSQL ENCODE/DECODE call.

    Returns:
        The format literal to use in the Redshift VARBYTE function.

    Raises:
        ValueError: if ``pg_format`` is none of the three supported formats.
    """
    if pg_format == Literal(None, "'escape'"):
        return Literal(None, "'utf8'")

    passthrough = (Literal(None, "'hex'"), Literal(None, "'base64'"))
    if pg_format in passthrough:
        return pg_format

    raise ValueError(
        "ENCODE format can only be one of ('base64', 'hex',"
        f" 'escape'). Got {pg_format}"
    )


def encode_decode_postgres(rs_format: Literal) -> Literal:
    """
    Map a Redshift VARBYTE format literal back onto PostgreSQL's.

    Accepted Redshift formats and their PostgreSQL counterparts:
        'utf8'   -> 'escape'
        'hex'    -> 'hex'    (passed through unchanged)
        'base64' -> 'base64' (passed through unchanged)

    Args:
        rs_format: the format argument of a Redshift FROM_VARBYTE /
            TO_VARBYTE call.

    Returns:
        The format literal to use in the PostgreSQL ENCODE/DECODE call.

    Raises:
        ValueError: if ``rs_format`` is none of the three supported formats.
    """
    if rs_format == Literal(None, "'utf8'"):
        return Literal(None, "'escape'")

    passthrough = (Literal(None, "'hex'"), Literal(None, "'base64'"))
    if rs_format in passthrough:
        return rs_format

    raise ValueError(
        "ENCODE format can only be one of ('base64', 'hex',"
        f" 'utf8'). Got {rs_format}"
    )


class DETECT_CHARSET(Enum):
    """Modes that control how the charset-inference query is generated.

    The translator compares these values against the ``MSSQL_CHARSET``
    environment variable.
    """

    # Basic Multilingual Plane restricted to code points 0-55295 (0xD7FF)
    # instead of the full 0-65535 range: Redshift raises
    # "The UTF-8 character is reserved as a surrogate" for the surrogate
    # code points U+D800 through U+DFFF, which are invalid characters.
    REDSHIFT_BMP = "REDSHIFT_BMP"

    # Fixed charset of code points 0-255 (labelled ASCII, although 0-255
    # actually spans Latin-1).
    ASCII = "ASCII"

    # Infer the charset from the data itself; assumes a single text value
    # contains at most one million characters.
    INFER = "INFER"


class RedshiftTranslator(BaseTranslator):
    """
    Translator for Redshift <-> Postgres.

    ``transform_to_dialect`` rewrites a PostgreSQL AST into Redshift SQL and
    ``transform_to_postgres`` maps Redshift-specific constructs
    (FROM_VARBYTE / TO_VARBYTE) back to PostgreSQL ENCODE / DECODE.
    """

    # Smallest positive normal IEEE-754 double (== sys.float_info.min).
    # Float literals whose magnitude is below it (subnormals, and zero) are
    # replaced by a 0.0 literal when translating to Redshift.
    # NOTE(review): presumably because Redshift rejects subnormal doubles --
    # confirm.
    _MIN_NORMAL_DOUBLE = 2.2250738585072014e-308

    def __init__(
        self, query: t.Union[Query, UnifiedQuery], **kwargs: t.Any
    ) -> None:
        super().__init__(query, **kwargs)

    # ------------------------------------------------------------------
    # PostgreSQL -> Redshift
    # ------------------------------------------------------------------

    def transform_to_dialect(self, expr: Sql) -> Sql:
        """
        Walks through the tree of a parsed SQL query and translates its
        elements from PostgreSQL to Redshift.

        Raises:
            NotImplementedError: for the Redshift-only aggregates listed in
                ``USER_DEFINED_AGGS``.
            ValueError: for an unsupported ENCODE/DECODE format, or a
                REGEXP_SPLIT_TO_TABLE query that is not of the expected shape.
        """
        # Columns are returned as-is; walking into them caused a recursion
        # error.
        if isinstance(expr, Column):
            return expr

        if isinstance(expr, CaseExpression):
            self._translate_case(expr)
            return self.walk_tree(expr)

        if isinstance(expr, ConcatFunction):
            return self._translate_concat(expr)

        if isinstance(expr, DecodeFunction):
            return self._decode_to_redshift(expr)

        if isinstance(expr, EncodeFunction):
            return self._encode_to_redshift(expr)

        # Aggregates available only in Redshift cannot be handled.
        if (
            isinstance(expr, UserDefinedFunction)
            and expr.name.upper() in USER_DEFINED_AGGS
        ):
            raise NotImplementedError(
                f"Aggregate functions {expr.name}" f" not supported: {expr}"
            )

        if isinstance(expr, Query) and any(
            isinstance(named.expression, RegexpSplitToTableFunction)
            for named in expr.select.namedExpressions
        ):
            return self._translate_regexp_split_query(expr)

        return self.walk_tree(expr)

    @classmethod
    def _flush_subnormal(cls, node: t.Any) -> t.Any:
        """Return a 0.0 literal if ``node`` is a float literal whose magnitude
        is below the smallest normal double; otherwise return ``node``."""
        if (
            isinstance(node, Literal)
            and isinstance(node.value, float)
            # abs(): negative floats such as -5.0 must be left untouched.
            and abs(node.value) < cls._MIN_NORMAL_DOUBLE
        ):
            return Literal(value=0.0)
        return node

    def _translate_case(self, expr: CaseExpression) -> None:
        """Flush tiny float literals in a CASE expression to 0.0, in place.

        Covers the bounds of ``WHEN (x BETWEEN lo AND hi)`` conditions and
        every ``THEN`` value, then walks the ELSE branch.
        """
        when_exprs = list(expr.when_exprs)
        for when_expr in when_exprs:
            condition = when_expr.expression
            if (
                isinstance(condition, NestedBoolean)
                and isinstance(condition.expression, PredicatedExpression)
                and isinstance(
                    condition.expression.predicate, BetweenCondition
                )
            ):
                # The predicate object is shared with the AST, so updating
                # its bounds updates the WHEN condition directly.
                predicate = condition.expression.predicate
                predicate.lower = self._flush_subnormal(predicate.lower)
                predicate.upper = self._flush_subnormal(predicate.upper)

            when_expr.then = self._flush_subnormal(when_expr.then)

        expr.when_exprs = when_exprs
        expr.else_expr = self.walk_tree(expr.else_expr)

    def _translate_concat(self, expr: ConcatFunction) -> Sql:
        """Rewrite an n-ary CONCAT as left-nested binary CONCATs.

        ``CONCAT(a, b, c)`` -> ``CONCAT(CONCAT(a, b), c)``, preserving argument
        order; a single argument is padded with an empty string.

        Raises:
            ValueError: if the CONCAT call has no arguments.
        """
        # Copy so the AST's own argument list is never mutated.
        args = list(self.walk_tree(expr.expressions))
        if not args:
            raise ValueError("CONCAT requires at least one argument.")
        if len(args) == 1:
            return ConcatFunction([args[0], Literal(None, "''")])

        result = args[0]
        for arg in args[1:]:
            result = ConcatFunction([result, arg])
        return result

    def _decode_to_redshift(self, expr: DecodeFunction) -> Sql:
        """DECODE(src, fmt) -> TO_VARBYTE(src, fmt').

        A nested ENCODE source is folded directly into FROM_VARBYTE.
        """
        decode_format = encode_decode_redshift(expr.format)

        source = expr.source
        if isinstance(source, EncodeFunction):
            # DECODE(ENCODE(x, f1), f2) -> TO_VARBYTE(FROM_VARBYTE(x, f1'), f2')
            encode_format = encode_decode_redshift(source.format)
            decode_source = UserDefinedFunction(
                name="FROM_VARBYTE",
                expressions=[self.walk_tree(source.source), encode_format],
            )
        else:
            decode_source = self.walk_tree(source)

        return UserDefinedFunction(
            name="TO_VARBYTE",
            expressions=[decode_source, decode_format],
        )

    def _encode_to_redshift(self, expr: EncodeFunction) -> Sql:
        """ENCODE(src, fmt) -> FROM_VARBYTE(src, fmt').

        A nested DECODE source is folded directly into TO_VARBYTE.
        """
        encode_format = encode_decode_redshift(expr.format)

        source = expr.source
        if isinstance(source, DecodeFunction):
            # ENCODE(DECODE(x, f1), f2) -> FROM_VARBYTE(TO_VARBYTE(x, f1'), f2')
            decode_format = encode_decode_redshift(source.format)
            encode_source = UserDefinedFunction(
                name="TO_VARBYTE",
                expressions=[self.walk_tree(source.source), decode_format],
            )
        else:
            encode_source = self.walk_tree(source)

        return UserDefinedFunction(
            name="FROM_VARBYTE",
            expressions=[encode_source, encode_format],
        )

    def _translate_regexp_split_query(self, expr: Query) -> Sql:
        """Replace a charset-discovery query with a Redshift-compatible one.

        The only supported shape is::

            SELECT DISTINCT REGEXP_SPLIT_TO_TABLE(anon_2.name, '')
                AS "regexp_split"
            FROM (
                SELECT "tab1"."name" AS "name"
                FROM "tab1" WHERE "tab1"."name" IS NOT NULL
            ) AS "anon_2"

        The replacement is selected by the ``MSSQL_CHARSET`` environment
        variable, compared with ``DETECT_CHARSET`` values (default "ASCII").

        Raises:
            ValueError: if the query does not have the expected shape.
        """
        named_exprs = expr.select.namedExpressions
        # Explicit checks instead of asserts: asserts are stripped under -O
        # and raise AssertionError rather than the documented ValueError.
        if len(named_exprs) != 1 or not isinstance(
            named_exprs[0].expression, RegexpSplitToTableFunction
        ):
            raise ValueError(
                f"We don't know how to translate {expr} "
                "into Redshift dialect."
            )
        try:
            # Column whose charset is wanted (anon_2.name above).
            col = named_exprs[0].expression.source
            # Alias of the resulting charset column.
            alias = named_exprs[0].name
            # Relation the column is read from.
            source_relation = expr.source.relations[0]
        except (AttributeError, IndexError, TypeError) as err:
            raise ValueError(
                f"We don't know how to translate {expr} "
                "into Redshift dialect."
            ) from err

        # ASCII (code points 0-255) is the default for efficiency. When the
        # data is known to contain other characters, set MSSQL_CHARSET to
        # another DETECT_CHARSET value. The obsolete MSSQL_CHARSET variable
        # is reused on purpose, since this code path will be removed later.
        charset = os.environ.get("MSSQL_CHARSET", "ASCII")

        if charset == DETECT_CHARSET.INFER.value:
            return char_inference_query(
                column=col,
                col_alias=alias,
                source_relation=source_relation,
                max_string_length=1000000,
                dialect=SQLDialect.REDSHIFT,
            )
        if charset == DETECT_CHARSET.ASCII.value:
            return custom_charset(alias, 255, SQLDialect.REDSHIFT)
        # REDSHIFT_BMP (or any other value): code points below the
        # surrogate range.
        return custom_charset(alias, 55295, SQLDialect.REDSHIFT)

    # ------------------------------------------------------------------
    # Redshift -> PostgreSQL
    # ------------------------------------------------------------------

    def transform_to_postgres(self, expr: SqlExpr) -> SqlExpr:
        """
        Translate a SQL query back from Redshift to PostgreSQL.

        FROM_VARBYTE becomes ENCODE and TO_VARBYTE becomes DECODE (function
        names are matched case-insensitively); columns and tables are
        returned unchanged; everything else is walked recursively.

        Raises:
            ValueError: for an unsupported VARBYTE format.
        """
        if self._is_udf(expr, "FROM_VARBYTE"):
            return self._from_varbyte_to_postgres(expr)

        if self._is_udf(expr, "TO_VARBYTE"):
            return self._to_varbyte_to_postgres(expr)

        # Columns and tables are returned as-is to avoid a recursion error.
        if isinstance(expr, (Column, Table)):
            return expr

        return self.walk_tree(expr)

    @staticmethod
    def _is_udf(expr: t.Any, name: str) -> bool:
        """True if ``expr`` is a call to the function ``name`` (any case)."""
        return (
            isinstance(expr, UserDefinedFunction)
            and expr.name.upper() == name
        )

    def _from_varbyte_to_postgres(self, expr: UserDefinedFunction) -> SqlExpr:
        """FROM_VARBYTE(src, fmt) -> ENCODE(src, fmt').

        A nested TO_VARBYTE source is folded into DECODE.
        """
        encode_format = encode_decode_postgres(expr.expressions[1])

        source = expr.expressions[0]
        if self._is_udf(source, "TO_VARBYTE"):
            # FROM_VARBYTE(TO_VARBYTE(x, f1), f2) -> ENCODE(DECODE(x, f1'), f2')
            decode_format = encode_decode_postgres(source.expressions[1])
            encode_source = DecodeFunction(
                source=self.walk_tree(source.expressions[0]),
                format=decode_format,
            )
        else:
            encode_source = self.walk_tree(source)

        return EncodeFunction(source=encode_source, format=encode_format)

    def _to_varbyte_to_postgres(self, expr: UserDefinedFunction) -> SqlExpr:
        """TO_VARBYTE(src, fmt) -> DECODE(src, fmt').

        A nested FROM_VARBYTE source is folded into ENCODE.
        """
        decode_format = encode_decode_postgres(expr.expressions[1])

        source = expr.expressions[0]
        if self._is_udf(source, "FROM_VARBYTE"):
            # TO_VARBYTE(FROM_VARBYTE(x, f1), f2) -> DECODE(ENCODE(x, f1'), f2')
            # The inner format gets its own name so it cannot overwrite the
            # outer DECODE format.
            inner_encode_format = encode_decode_postgres(source.expressions[1])
            decode_source = EncodeFunction(
                source=self.walk_tree(source.expressions[0]),
                format=inner_encode_format,
            )
        else:
            decode_source = self.walk_tree(source)

        return DecodeFunction(source=decode_source, format=decode_format)