# Source: package repository listing (the install URL was not captured).
# Version: 3.0.0.dev3
import typing as t
from snsql._ast.ast import (
AliasedSubquery,
BareFunction,
Column,
FuncName,
Identifier,
Literal,
MathFunction,
NamedExpression,
Query,
Seq,
Sql,
Table,
UserDefinedFunction,
)
from snsql._ast.expressions.string import (
CoalesceFunction,
DecodeFunction,
EncodeFunction,
UnhexFunction,
)
from ..utils import is_quoted, split_quote
from .base import BaseTranslator
# Upper-case names of aggregate functions that
# MySQLTranslator.transform_to_dialect refuses to translate: a
# UserDefinedFunction whose upper-cased name appears here triggers a
# NotImplementedError.
USER_DEFINED_AGGS: t.Tuple[str, ...] = (
    "BIT_XOR", "GROUP_CONCAT",
    "JSON_ARRAYAGG", "JSON_OBJECTAGG",
    "STDDEV_POP", "STDDEV_SAMP",
    "VAR_POP", "VAR_SAMP",
)
class MySQLTranslator(BaseTranslator):
    """Translate a query AST between the Postgres and MySQL (5) dialects.

    Currently supported translations (Postgres <-> MySQL):

    - LN <-> LOG. In Postgres, LN and LOG are respectively the natural
      log and the base 10 log; LOG in MySQL is the natural log.
    - RANDOM <-> RAND
    - ENCODE(DECODE(source, 'hex'), 'escape') <-> UNHEX(source)
    - COALESCE <-> IFNULL
    - TIMEZONE(tz1, TIMEZONE(tz2, ts))
      <-> FROM_UTC_TIMESTAMP / TO_UTC_TIMESTAMP combinations
    - Table names, column names and aliases are quoted with a backquote
      on the MySQL side and with a double quote on the Postgres side.
    """

    def __init__(self, query: Query, **kwargs: t.Any) -> None:
        """Initialise the base translator and set the MySQL quote char.

        Args:
            query: the query handed to ``BaseTranslator``.
            **kwargs: forwarded unchanged to ``BaseTranslator.__init__``.
        """
        super().__init__(query, **kwargs)
        self.quote = "`"

    @staticmethod
    def _udf(name: str, *args: Sql) -> UserDefinedFunction:
        """Build the user-defined function call node ``name(*args)``."""
        return UserDefinedFunction(name=name, expressions=Seq(list(args)))

    def _quoted_identifier(self, text: str, quote: str) -> Identifier:
        """Return an ``Identifier`` holding ``text`` wrapped in ``quote``."""
        return Identifier(self.rename_add_quotes(text, quote))

    def transform_to_dialect(self, expr: Sql) -> Sql:
        """Rewrite one Postgres-flavoured node into its MySQL form.

        Columns, tables and aliases are mutated in place to use backquote
        quoting; translated functions are returned as new nodes. Anything
        not handled explicitly is delegated to ``self.walk_tree``.

        Raises:
            NotImplementedError: for a ``TIMEZONE`` call that is not of the
                form ``TIMEZONE(tz1, TIMEZONE(tz2, expr))``, and for any
                user-defined function listed in ``USER_DEFINED_AGGS``.
        """
        if isinstance(expr, Column):
            # Change or add backquote quoting; only the last two name
            # parts are kept.
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts[-2:])
            return expr

        if isinstance(expr, Table):
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts)
            if expr.alias:
                expr.alias = self._quoted_identifier(
                    expr.alias.text, self.quote
                )
            return expr

        if isinstance(expr, AliasedSubquery):
            # No early return: the subquery still goes through the generic
            # handling at the end of the method.
            if expr.alias:
                expr.alias = self._quoted_identifier(
                    expr.alias.text, self.quote
                )

        if isinstance(expr, MathFunction) and expr.name == FuncName("LN"):
            return MathFunction(
                name=FuncName("LOG"),
                expression=self.walk_tree(expr.expression),
            )

        if isinstance(expr, NamedExpression):
            # Add backquoting to aliases.
            if expr.name:
                expr.name = self._quoted_identifier(
                    expr.name.text, self.quote
                )

        if isinstance(expr, BareFunction) and expr.name == FuncName("RANDOM"):
            return BareFunction(name=FuncName("RAND"))

        # Encode(Decode(source, 'hex'), 'escape') -> Unhex(source)
        if (
            isinstance(expr, EncodeFunction)
            and expr.format == Literal(None, "'escape'")
            and isinstance(expr.source, DecodeFunction)
            and expr.source.format == Literal(None, "'hex'")
        ):
            return UnhexFunction(self.walk_tree(expr.source.source))

        if isinstance(expr, CoalesceFunction):
            # NOTE(review): MySQL's IFNULL takes exactly two arguments,
            # while COALESCE is variadic -- confirm callers never pass a
            # COALESCE with a different number of arguments.
            return self._udf(
                "IFNULL", *[self.walk_tree(e) for e in expr.expressions]
            )

        if isinstance(expr, UserDefinedFunction):
            if expr.name == "TIMEZONE":
                return self._timezone_to_dialect(expr)
            if expr.name.upper() in USER_DEFINED_AGGS:
                raise NotImplementedError(
                    f"Aggregate functions {expr.name}"
                    f" not supported: {expr}"
                )

        return self.walk_tree(expr)

    def _timezone_to_dialect(self, expr: UserDefinedFunction) -> Sql:
        """Translate ``TIMEZONE(tz1, TIMEZONE(tz2, ts))`` for MySQL.

        Raises:
            NotImplementedError: when the second argument of ``expr`` is
                not itself a ``TIMEZONE`` call.
        """
        assert len(expr.expressions) == 2
        tz1 = expr.expressions[0]
        timestamp1 = expr.expressions[1]

        # Only the nested form TIMEZONE(tz1, TIMEZONE(tz2, expr)) is handled.
        if not (
            isinstance(timestamp1, UserDefinedFunction)
            and timestamp1.name == "TIMEZONE"
        ):
            raise NotImplementedError(
                "We can translate only `TIMEZONE(tz1, TIMEZONE(tz2, exp))`"
                f" Postgres -> MySQL.\nGot {expr} "
            )

        assert len(timestamp1.expressions) == 2
        tz2 = timestamp1.expressions[0]
        timestamp2 = timestamp1.expressions[1]

        # TIMEZONE('UTC', TIMEZONE(tz2, expr))
        # -> TO_UTC_TIMESTAMP(expr, tz2)
        if tz1.text == "'UTC'":
            return self._udf(
                "TO_UTC_TIMESTAMP",
                self.transform(timestamp2),
                self.transform(tz2),
            )

        # TIMEZONE(tz1, TIMEZONE('UTC', expr))
        # -> FROM_UTC_TIMESTAMP(expr, tz1)
        if tz2.text == "'UTC'":
            return self._udf(
                "FROM_UTC_TIMESTAMP",
                self.transform(timestamp2),
                self.transform(tz1),
            )

        # TIMEZONE(tz1, TIMEZONE(tz2, expr))
        # -> FROM_UTC_TIMESTAMP(TO_UTC_TIMESTAMP(expr, tz2), tz1)
        return self._udf(
            "FROM_UTC_TIMESTAMP",
            self._udf(
                "TO_UTC_TIMESTAMP",
                self.transform(timestamp2),
                self.transform(tz2),
            ),
            self.transform(tz1),
        )

    def transform_to_postgres(self, expr: Sql) -> Sql:
        """Rewrite one MySQL-flavoured node into its Postgres form.

        Columns, tables and aliases are mutated in place to use double
        quote quoting; translated functions are returned as new nodes.
        Anything not handled explicitly is delegated to ``self.walk_tree``.
        """
        ident_quote = '"'

        if isinstance(expr, (Column, Table)):
            parts = split_quote(expr.name, ident_quote)
            expr.name = ".".join(parts)
            if isinstance(expr, Table) and expr.alias:
                expr.alias = self._quoted_identifier(
                    expr.alias.text, ident_quote
                )
            return expr

        if isinstance(expr, NamedExpression) and expr.name:
            expr.name = self._quoted_identifier(expr.name.text, ident_quote)
        elif isinstance(expr, AliasedSubquery) and expr.alias:
            expr.alias = self._quoted_identifier(
                expr.alias.text, ident_quote
            )

        if isinstance(expr, MathFunction) and expr.name == FuncName("LOG"):
            return MathFunction(
                name=FuncName("LN"), expression=self.transform(expr.expression)
            )

        if isinstance(expr, BareFunction) and expr.name == FuncName("RAND"):
            return BareFunction(name=FuncName("RANDOM"))

        # unhex(source) -> encode(decode(source, 'hex'), 'escape')
        # see issue: https://gitlab.com/sarus-tech/sarus-sql/-/issues/107
        if isinstance(expr, UnhexFunction):
            return EncodeFunction(
                source=DecodeFunction(
                    source=self.transform(expr.source),
                    format=Literal(None, "'hex'"),
                ),
                format=Literal(None, "'escape'"),
            )

        # ENCODE() and DECODE() functions are deprecated after MySQL 5.7 and
        # should no longer be used:
        # https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_decode # noqa: E501
        # ENCODE/DECODE from mysql will be removed in postgres.
        if isinstance(expr, (DecodeFunction, EncodeFunction)):
            # A literal format other than the three below is treated as the
            # MySQL variant: the call is dropped and only its source kept.
            if isinstance(expr.format, Literal) and expr.format not in [
                Literal(None, "'escape'"),
                Literal(None, "'hex'"),
                Literal(None, "'base64'"),
            ]:
                return self.transform(expr.source)

        if isinstance(expr, UserDefinedFunction):
            if expr.name == "TO_UTC_TIMESTAMP":
                assert len(expr.expressions) == 2
                child_expr = expr.expressions[0]
                tz = self.transform(expr.expressions[1])
                # TO_UTC_TIMESTAMP (FROM_UTC_TIMESTAMP(ts, tz1), tz2)
                # -> ts
                # NOTE(review): this collapses the pair without comparing
                # tz1 and tz2 -- confirm that is intended.
                if (
                    isinstance(child_expr, UserDefinedFunction)
                    and child_expr.name == "FROM_UTC_TIMESTAMP"
                ):
                    assert len(child_expr.expressions) == 2
                    return self.transform(child_expr.expressions[0])
                # TO_UTC_TIMESTAMP (ts, tz)
                # -> TIMEZONE('UTC', TIMEZONE(tz, ts))
                return self._udf(
                    "TIMEZONE",
                    Literal("'UTC'"),
                    self._udf(
                        "TIMEZONE",
                        tz,
                        self.transform(expr.expressions[0]),
                    ),
                )

            if expr.name == "FROM_UTC_TIMESTAMP":
                assert len(expr.expressions) == 2
                child_expr = expr.expressions[0]
                # FROM_UTC_TIMESTAMP(TO_UTC_TIMESTAMP(ts, tz1), tz2)
                # -> TIMEZONE(tz2, TIMEZONE(tz1, ts))
                if (
                    isinstance(child_expr, UserDefinedFunction)
                    and child_expr.name == "TO_UTC_TIMESTAMP"
                ):
                    assert len(child_expr.expressions) == 2
                    return self._udf(
                        "TIMEZONE",
                        self.transform(expr.expressions[1]),
                        self._udf(
                            "TIMEZONE",
                            self.transform(child_expr.expressions[1]),
                            self.transform(child_expr.expressions[0]),
                        ),
                    )
                # FROM_UTC_TIMESTAMP (ts, tz)
                # -> TIMEZONE(tz, TIMEZONE('UTC', ts))
                return self._udf(
                    "TIMEZONE",
                    self.transform(expr.expressions[1]),
                    self._udf(
                        "TIMEZONE",
                        Literal("'UTC'"),
                        self.transform(expr.expressions[0]),
                    ),
                )

            if expr.name == "IFNULL":
                return CoalesceFunction(
                    expressions=Seq(
                        [
                            self.transform(udf_expr)
                            for udf_expr in expr.expressions
                        ]
                    ),
                )

        return self.walk_tree(expr)

    def rename_add_quotes(self, alias: str, quote: str) -> str:
        """Return ``alias`` wrapped in ``quote``.

        When ``is_quoted(alias)`` is true, the first and last characters of
        ``alias`` are dropped before the new quote character is applied.
        """
        bare = alias[1:-1] if is_quoted(alias) else alias
        return f"{quote}{bare}{quote}"