Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
import typing as t

from snsql._ast.ast import (
    AliasedSubquery,
    BareFunction,
    Column,
    FuncName,
    Identifier,
    Literal,
    MathFunction,
    NamedExpression,
    Query,
    Seq,
    Sql,
    Table,
    UserDefinedFunction,
)
from snsql._ast.expressions.string import (
    CoalesceFunction,
    DecodeFunction,
    EncodeFunction,
    UnhexFunction,
)

from ..utils import is_quoted, split_quote
from .base import BaseTranslator

# Upper-case names of user-defined functions that
# MySQLTranslator.transform_to_dialect refuses to translate: a
# UserDefinedFunction whose upper-cased name is in this tuple raises
# NotImplementedError (the error message calls them "Aggregate functions").
USER_DEFINED_AGGS = (
    "BIT_XOR",
    "GROUP_CONCAT",
    "JSON_ARRAYAGG",
    "JSON_OBJECTAGG",
    "STDDEV_POP",
    "STDDEV_SAMP",
    "VAR_POP",
    "VAR_SAMP",
)


class MySQLTranslator(BaseTranslator):
    """Translator for Postgres <-> MySQL (5)

    Current supported translations:

    - LN <-> LOG
      LN and LOG in postgres are respectively the natural log and the
      base 10 log. LOG in MySQL is the natural log.
    - RANDOM <-> RAND
    - Encode(Decode(source, 'hex'), 'escape') <-> UNHEX(source)
    - COALESCE <-> IFNULL
    - TIMEZONE(tz1, TIMEZONE(tz2, ts))
      <-> FROM_UTC_TIMESTAMP / TO_UTC_TIMESTAMP
    - we quote the table and subquery names (and aliases) with `
    """

    def __init__(self, query: Query, **kwargs: t.Any) -> None:
        """Initialise the base translator and set the MySQL quote char."""
        super().__init__(query, **kwargs)
        self.quote = "`"

    def transform_to_dialect(self, expr: Sql) -> Sql:
        """Translate a single Postgres AST node to its MySQL form.

        Identifier-carrying nodes (Column, Table, aliases) are mutated in
        place to use backquotes; function nodes that need renaming are
        replaced by freshly built nodes. Anything not handled here is
        delegated to ``self.walk_tree``.

        Raises:
            NotImplementedError: for a ``TIMEZONE`` call that is not of the
                form ``TIMEZONE(tz1, TIMEZONE(tz2, expr))`` and for any
                user-defined function listed in ``USER_DEFINED_AGGS``.
            AssertionError: if a ``TIMEZONE`` call does not have exactly
                two arguments.
        """
        if isinstance(expr, Column):
            # Change or add backquoting; only the last two dotted parts
            # of the column name are kept.
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts[-2:])
            return expr

        elif isinstance(expr, Table):
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts)
            if expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, self.quote)
                )
            return expr

        elif isinstance(expr, AliasedSubquery):
            # Only the alias is rewritten here; the subquery itself is
            # handled by the final walk_tree call.
            if expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, self.quote)
                )

        if isinstance(expr, MathFunction) and expr.name == FuncName("LN"):
            return MathFunction(
                name=FuncName("LOG"),
                expression=self.walk_tree(expr.expression),
            )

        elif isinstance(expr, NamedExpression):
            # Add backquoting to aliases
            if expr.name:
                expr.name = Identifier(
                    self.rename_add_quotes(expr.name.text, self.quote)
                )

        if isinstance(expr, BareFunction) and expr.name == FuncName("RANDOM"):
            return BareFunction(name=FuncName("RAND"))

        # Encode(Decode(source, 'hex'), 'escape') -> Unhex(source)
        if (
            isinstance(expr, EncodeFunction)
            and expr.format == Literal(None, "'escape'")
            and isinstance(expr.source, DecodeFunction)
            and expr.source.format == Literal(None, "'hex'")
        ):
            return UnhexFunction(self.walk_tree(expr.source.source))

        if isinstance(expr, CoalesceFunction):
            return UserDefinedFunction(
                name="IFNULL",
                expressions=[self.walk_tree(e) for e in expr.expressions],
            )

        if isinstance(expr, UserDefinedFunction):
            if expr.name == "TIMEZONE":
                assert len(expr.expressions) == 2
                tz1 = expr.expressions[0]
                timestamp1 = expr.expressions[1]

                # TIMEZONE(t1, TIMEZONE(tz, expr))
                if (
                    isinstance(timestamp1, UserDefinedFunction)
                    and timestamp1.name == "TIMEZONE"
                ):
                    assert len(timestamp1.expressions) == 2
                    tz2 = timestamp1.expressions[0]
                    timestamp2 = timestamp1.expressions[1]

                    # TIMEZONE('UTC', TIMEZONE(tz2, expr))
                    # -> TO_UTC_TIMESTAMP(expr, tz2)
                    if tz1.text == "'UTC'":
                        return UserDefinedFunction(
                            name="TO_UTC_TIMESTAMP",
                            expressions=Seq(
                                [
                                    self.transform(timestamp2),
                                    self.transform(tz2),
                                ]
                            ),
                        )

                    # TIMEZONE(tz1, TIMEZONE('UTC', expr))
                    # -> FROM_UTC_TIMESTAMP(expr, tz1)
                    if tz2.text == "'UTC'":
                        return UserDefinedFunction(
                            name="FROM_UTC_TIMESTAMP",
                            expressions=Seq(
                                [
                                    self.transform(timestamp2),
                                    self.transform(tz1),
                                ]
                            ),
                        )

                    # TIMEZONE(tz1, TIMEZONE(tz2, expr))
                    # -> FROM_UTC_TIMESTAMP(TO_UTC_TIMESTAMP(expr, tz2), tz1)
                    return UserDefinedFunction(
                        name="FROM_UTC_TIMESTAMP",
                        expressions=Seq(
                            [
                                UserDefinedFunction(
                                    name="TO_UTC_TIMESTAMP",
                                    expressions=Seq(
                                        [
                                            self.transform(timestamp2),
                                            self.transform(tz2),
                                        ]
                                    ),
                                ),
                                self.transform(tz1),
                            ]
                        ),
                    )
                # A single (non-nested) TIMEZONE call has no translation.
                raise NotImplementedError(
                    "We can translate only `TIMEZONE(tz1, TIMEZONE(tz2, exp))`"
                    f" Postgres -> MySQL.\nGot {expr} "
                )
            if expr.name.upper() in USER_DEFINED_AGGS:
                raise NotImplementedError(
                    f"Aggregate functions {expr.name}"
                    f" not supported: {expr}"
                )

        return self.walk_tree(expr)

    def transform_to_postgres(self, expr: Sql) -> Sql:
        """Translate a single MySQL AST node to its Postgres form.

        Column/Table names and aliases are re-quoted with double quotes
        (mutating the node in place); MySQL-specific functions are replaced
        by freshly built Postgres equivalents. Anything not handled here is
        delegated to ``self.walk_tree``.

        Raises:
            AssertionError: if a ``TO_UTC_TIMESTAMP`` or
                ``FROM_UTC_TIMESTAMP`` call (or the nested one) does not
                have exactly two arguments.
        """
        if isinstance(expr, (Column, Table)):
            ident_quote = '"'
            parts = split_quote(expr.name, ident_quote)
            expr.name = ".".join(parts)
            if isinstance(expr, Table) and expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, ident_quote)
                )
            return expr

        if isinstance(expr, NamedExpression) and expr.name:
            expr.name = Identifier(self.rename_add_quotes(expr.name.text, '"'))

        elif isinstance(expr, AliasedSubquery) and expr.alias:
            expr.alias = Identifier(
                self.rename_add_quotes(expr.alias.text, '"')
            )

        if isinstance(expr, MathFunction) and expr.name == FuncName("LOG"):
            return MathFunction(
                name=FuncName("LN"), expression=self.transform(expr.expression)
            )

        if isinstance(expr, BareFunction) and expr.name == FuncName("RAND"):
            return BareFunction(name=FuncName("RANDOM"))

        # unhex(source) -> encode(decode(source, 'hex'), 'escape')
        # see issue: https://gitlab.com/sarus-tech/sarus-sql/-/issues/107
        if isinstance(expr, UnhexFunction):
            return EncodeFunction(
                source=DecodeFunction(
                    source=self.transform(expr.source),
                    format=Literal(None, "'hex'"),
                ),
                format=Literal(None, "'escape'"),
            )

        # ENCODE() and DECODE() functions are deprecated after MySQL 5.7 and
        # should no longer be used:
        # https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_decode # noqa: E501
        # ENCODE/DECODE from mysql will be removed in postgres.
        if isinstance(expr, (DecodeFunction, EncodeFunction)):
            # A literal format other than the three below is dropped
            # together with the call: only the translated source is kept.
            if isinstance(expr.format, Literal) and expr.format not in [
                Literal(None, "'escape'"),
                Literal(None, "'hex'"),
                Literal(None, "'base64'"),
            ]:
                return self.transform(expr.source)

        # TO_UTC
        if isinstance(expr, UserDefinedFunction):
            if expr.name == "TO_UTC_TIMESTAMP":
                assert len(expr.expressions) == 2
                child_expr = expr.expressions[0]
                tz = self.transform(expr.expressions[1])

                # TO_UTC_TIMESTAMP (FROM_UTC_TIMESTAMP(ts, tz1), tz2)
                # -> ts
                if (
                    isinstance(child_expr, UserDefinedFunction)
                    and child_expr.name == "FROM_UTC_TIMESTAMP"
                ):
                    # The comparison must be applied to the length, not
                    # passed to len().
                    assert len(child_expr.expressions) == 2
                    return self.transform(child_expr.expressions[0])

                # TO_UTC_TIMESTAMP (ts, tz)
                # -> TIMEZONE('UTC', TIMEZONE(tz, ts))
                return UserDefinedFunction(
                    name="TIMEZONE",
                    expressions=Seq(
                        [
                            Literal("'UTC'"),
                            UserDefinedFunction(
                                name="TIMEZONE",
                                expressions=Seq(
                                    [
                                        tz,
                                        self.transform(expr.expressions[0]),
                                    ]
                                ),
                            ),
                        ]
                    ),
                )

            if expr.name == "FROM_UTC_TIMESTAMP":
                assert len(expr.expressions) == 2
                child_expr = expr.expressions[0]

                #  FROM_UTC_TIMESTAMP(TO_UTC_TIMESTAMP(ts, tz1), tz2)
                # -> TIMEZONE(tz2, TIMEZONE(tz1, ts))
                if (
                    isinstance(child_expr, UserDefinedFunction)
                    and child_expr.name == "TO_UTC_TIMESTAMP"
                ):
                    assert len(child_expr.expressions) == 2
                    return UserDefinedFunction(
                        name="TIMEZONE",
                        expressions=Seq(
                            [
                                self.transform(expr.expressions[1]),
                                UserDefinedFunction(
                                    name="TIMEZONE",
                                    expressions=Seq(
                                        [
                                            self.transform(
                                                child_expr.expressions[1]
                                            ),
                                            self.transform(
                                                child_expr.expressions[0]
                                            ),
                                        ]
                                    ),
                                ),
                            ]
                        ),
                    )

                # FROM_UTC_TIMESTAMP (ts, tz)
                # -> TIMEZONE(tz, TIMEZONE('UTC', ts))
                return UserDefinedFunction(
                    name="TIMEZONE",
                    expressions=Seq(
                        [
                            self.transform(expr.expressions[1]),
                            UserDefinedFunction(
                                name="TIMEZONE",
                                expressions=Seq(
                                    [
                                        Literal("'UTC'"),
                                        self.transform(expr.expressions[0]),
                                    ]
                                ),
                            ),
                        ]
                    ),
                )

            if expr.name == "IFNULL":
                return CoalesceFunction(
                    expressions=Seq(
                        [
                            self.transform(udf_expr)
                            for udf_expr in expr.expressions
                        ]
                    ),
                )

        return self.walk_tree(expr)

    def rename_add_quotes(self, alias: str, quote: str) -> str:
        """Return ``alias`` wrapped in ``quote``.

        If ``is_quoted(alias)`` is true, the first and last characters of
        ``alias`` are dropped before the new quote character is applied;
        otherwise ``alias`` is wrapped as is.
        """
        if is_quoted(alias):
            alias = f"{quote}{alias[1:-1]}{quote}"
        else:
            alias = f"{quote}{alias}{quote}"
        return alias