# Repository URL to install this package: (not captured in this extract)
# Version: 3.0.0.dev3
import copy
import typing as t
from snsql._ast.ast import (
AliasedSubquery,
BareFunction,
Column,
From,
FuncName,
Identifier,
Literal,
MathFunction,
NamedExpression,
Query,
Relation,
Select,
Seq,
Sql,
SqlExpr,
Table,
Token,
UnifiedQuery,
Unnest,
UserDefinedFunction,
)
from snsql._ast.expressions.date import ExtractFunction
from snsql._ast.expressions.string import (
CoalesceFunction,
DecodeFunction,
EncodeFunction,
HexFunction,
RegexpSplitToTableFunction,
SubstrBigQueryFunction,
SubstringFunction,
UnhexFunction,
)
from snsql._ast.expressions.types import CastFunction
from ..table_renamer import TableRenamer
from ..utils import is_quoted, split_quote
from .base import BaseTranslator
# BigQuery aggregate functions that are not common with Postgres.
# Queries using any of these are rejected with NotImplementedError
# (see transform_to_dialect).
USER_DEFINED_AGGS: t.Tuple[str, ...] = (
    "ANY_VALUE",
    "ARRAY_CONCAT_AGG",
    "BIT_XOR",
    "COUNTIF",
    "LOGICAL_AND",
    "LOGICAL_OR",
)
class BigQueryTranslator(BaseTranslator):
    """Translator for BigQuery <-> Postgres.

    Bigquery quoting:
    https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical
    We quote all aliases to avoid conflicts with BQ keywords.
    Moreover, aliases can't contain special characters like dot etc.
    """

    def __init__(
        self, query: t.Union[Query, UnifiedQuery], **kwargs: t.Any
    ) -> None:
        """Initialize the translator and normalize aliases in *query*.

        Args:
            query: parsed AST of the query to translate.
            **kwargs: forwarded to :class:`BaseTranslator`.
        """
        super().__init__(query, **kwargs)
        # BigQuery identifiers are quoted with backticks.
        self.quote = "`"
        # NOTE: alias_map must be built before rename_tables_columns(),
        # which consumes it to rewrite the AST.
        self.alias_map = self.alias_mapping()
        self.rename_tables_columns()
def alias_mapping(self) -> t.Dict[str, str]:
"""it generates a dict mapping actual aliases into normalized aliases.
In normalized aliases, spaces and dots are replaced with underscores
Returns:
Dict[str, str]: actual_alias: normalized alias
"""
tables = self.query.find_nodes(Table)
exprs = self.query.find_nodes(NamedExpression)
aliased_sub = self.query.find_nodes(AliasedSubquery)
aliases = {}
for expr in [*tables, *aliased_sub]:
if expr.alias:
aliases[expr.alias.text] = expr.alias.text.replace(
" ", "_"
).replace(".", "_")
for expr in exprs:
if expr.name:
aliases[expr.name.text] = expr.name.text.replace(
" ", "_"
).replace(".", "_")
return aliases
def rename_tables_columns(self) -> None:
"""Using TableRenamer to rename columns and tables. There are aliases
left to be renamed.
"""
# map_dict should contain escaped names
compatible_alias_map = {
split_quote(key, ""): split_quote(val, "")
for key, val in self.alias_map.items()
}
renamer = TableRenamer(copy.deepcopy(self.query), compatible_alias_map)
renamer.rename()
self.query = renamer.query
def rename_add_quotes(self, alias: str, quote: str) -> str:
"""renames aliases. To avoid conflicts with possible bigquery
keywords we backquote all aliases.
"""
if alias in self.alias_map:
alias = self.alias_map[alias]
if is_quoted(alias):
alias = f"{quote}{alias[1:-1]}{quote}"
else:
alias = f"{quote}{alias}{quote}"
return alias
    def transform_to_dialect(self, expr: Sql) -> Sql:
        """Convert one AST node from the Postgres dialect to BigQuery.

        Invoked while walking the query tree; nodes that match no rule
        fall through to ``self.walk_tree(expr)`` at the end, so children
        are still visited.

        Args:
            expr (Sql): the AST node to translate.

        Returns:
            Sql: the translated (possibly newly built) AST node.

        Raises:
            NotImplementedError: for BigQuery-only aggregates listed in
                USER_DEFINED_AGGS.
        """
        # LN(x) -> LOG(x)
        if isinstance(expr, MathFunction) and expr.name == FuncName("LN"):
            return MathFunction(
                name=FuncName("LOG"),
                expression=self.transform(expr.expression),
            )
        # RANDOM() -> RAND()
        if isinstance(expr, BareFunction) and expr.name == FuncName("RANDOM"):
            return BareFunction(name=FuncName("RAND"))
        if isinstance(expr, Column):
            """Changes or add backquoting quoting"""
            # Keep at most the trailing two parts (table.column) of the name.
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts[-2:])
            return expr
        elif isinstance(expr, Table):
            # Re-quote every part of the (possibly dotted) table name.
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts)
            if expr.alias:
                # Aliases are normalized and backquoted to dodge BQ keywords.
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, self.quote)
                )
            return expr
        elif isinstance(expr, AliasedSubquery):
            if expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, self.quote)
                )
        elif isinstance(expr, SubstringFunction):
            # SUBSTRING(s FROM a FOR b) -> SUBSTR(s, a, b)
            return SubstrBigQueryFunction(
                self.walk_tree(expr.source),
                self.walk_tree(expr.start),
                self.walk_tree(expr.length),
            )
        elif isinstance(expr, NamedExpression):
            """Add backquoting to aliases"""
            if expr.name:
                expr.name = Identifier(
                    self.rename_add_quotes(expr.name.text, self.quote)
                )
        # Encode(Decode(source, 'hex'), 'escape') -> Unhex(source)
        elif (
            isinstance(expr, EncodeFunction)
            and expr.format == Literal(None, "'escape'")
            and isinstance(expr.source, DecodeFunction)
            and expr.source.format == Literal(None, "'hex'")
        ):
            # The FROM_HEX result is cast back to string for Postgres parity.
            return CastFunction(
                UnhexFunction(
                    self.walk_tree(expr.source.source), token=Token("FROM_HEX")
                ),
                "string",
            )
        if isinstance(expr, CastFunction):
            # All Postgres character types map to BigQuery's single STRING.
            if (
                expr.dbtype.startswith("varchar")
                or expr.dbtype.startswith("char")
                or expr.dbtype == "text"
            ):
                expr.dbtype = "string"
            # CAST(TIMEZONE(tz, ts) AS ...) -> CAST(DATETIME(ts, tz) AS ...);
            # the [::-1] swaps the two arguments' order.
            if (
                isinstance(expr.expression, UserDefinedFunction)
                and expr.expression.name == "TIMEZONE"
            ):
                assert len(expr.expression.expressions) == 2
                expr.expression = UserDefinedFunction(
                    name="DATETIME",
                    expressions=self.walk_tree(expr.expression.expressions)[
                        ::-1
                    ],
                )
        if isinstance(expr, Query) and any(
            [
                isinstance(name_expr.expression, RegexpSplitToTableFunction)
                for name_expr in expr.select.namedExpressions
            ]
        ):
            # REGEXP_SPLIT_TO_TABLE is translated with a SPLIT and UNNEST:
            # example:
            #
            # SELECT DISTINCT REGEXP_SPLIT_TO_TABLE ( tab.name , '' ) AS col
            # FROM (
            #   SELECT name AS name FROM mytable WHERE name IS NOT NULL
            # ) AS tab
            #
            # it becomes:
            #
            # SELECT DISTINCT unnested AS col
            # FROM (
            #   SELECT SPLIT(tab.name , '') AS splitted
            #   FROM ( SELECT name AS name FROM mytable WHERE name IS NOT NULL
            #   ) AS tab
            # ) AS split_subq, UNNEST(split_subq.splitted) AS unnested
            #
            main_named_exprs = []
            for named_expr in expr.select.namedExpressions:
                if isinstance(
                    named_expr.expression, RegexpSplitToTableFunction
                ):
                    # Remember the regexp call's pieces; the any() guard above
                    # guarantees at least one match binds these names.
                    regexp_alias = named_expr.name
                    regexp_source_expr = self.walk_tree(
                        named_expr.expression.source
                    )
                    regexp_pattern = self.walk_tree(
                        named_expr.expression.pattern
                    )
                else:
                    main_named_exprs.append(named_expr)
            # building the subquery
            split = UserDefinedFunction(
                name="SPLIT",
                expressions=Seq([regexp_source_expr, regexp_pattern]),
            )
            namedexpr = NamedExpression(
                expression=split, name=Identifier("splitted")
            )
            select = Select(quantifier=None, namedExpressions=Seq([namedexpr]))
            subquery = Query(
                select=select,
                source=expr.source,
                where=None,
                agg=None,
                having=None,
                order=None,
                limit=None,
            )
            aliased_subquery = AliasedSubquery(
                query=subquery, alias=Identifier("split_subq")
            )
            relation_with_subquery = Relation(
                primary=aliased_subquery, joins=None
            )
            relation_with_unnest = Relation(
                primary=Unnest(
                    name=Column("split_subq.splitted"), alias="unnested"
                ),
                joins=None,
            )
            # The unnested column takes over the original regexp alias.
            main_nexpr = NamedExpression(
                expression=Column("unnested"), name=regexp_alias
            )
            main_named_exprs.append(main_nexpr)
            main_select = Select(
                quantifier=expr.select.quantifier,
                namedExpressions=Seq(main_named_exprs),
            )
            main_source = Seq([relation_with_subquery, relation_with_unnest])
            # Mutate the query in place to point at the rebuilt pieces.
            expr.select = main_select
            expr.source = From(main_source)
        if isinstance(expr, ExtractFunction) and expr.date_part == "epoch":
            # It converts EXTRACT(epoch FROM column) into
            # UNIX_SECONDS(CAST(col AS TIMESTAMP))
            # used to compute bigdata bounds for date and datetimes
            cast = CastFunction(
                self.walk_tree(expr.expression), dbtype="timestamp"
            )
            return UserDefinedFunction(
                name="UNIX_SECONDS",
                expressions=Seq([cast]),
            )
        # COALESCE(...) -> IFNULL(...)
        if isinstance(expr, CoalesceFunction):
            return UserDefinedFunction(
                name="IFNULL",
                expressions=[self.walk_tree(e) for e in expr.expressions],
            )
        if isinstance(expr, UserDefinedFunction):
            # convert MD5(x) -> TO_HEX(MD5(X))
            if expr.name == "MD5":
                new_md5 = UserDefinedFunction(
                    name=expr.name,
                    expressions=self.walk_tree(expr.expressions),
                )
                return HexFunction(
                    token=Token("TO_HEX"),
                    source=new_md5,
                )
            # convert TIMEZONE(tz, ts)
            # -> CAST(DATETIME(ts, tz)AS TIMESTAMP)
            if expr.name == "TIMEZONE":
                assert len(expr.expressions) == 2
                return CastFunction(
                    expression=UserDefinedFunction(
                        name="DATETIME",
                        expressions=self.walk_tree(expr.expressions)[::-1],
                    ),
                    dbtype="timestamp",
                )
            # BigQuery-only aggregates have no Postgres counterpart: reject.
            if expr.name.upper() in USER_DEFINED_AGGS:
                raise NotImplementedError(
                    f"Aggregate functions {expr.name}"
                    f" not supported: {expr}"
                )
        # Default: no rewrite here; keep walking the subtree.
        return self.walk_tree(expr)
    def transform_to_postgres(self, expr: SqlExpr) -> SqlExpr:
        """Convert one AST node from the BigQuery dialect back to Postgres.

        Inverse of :meth:`transform_to_dialect`; nodes matching no rule
        fall through to ``self.walk_tree(expr)`` at the end.

        Args:
            expr (SqlExpr): the AST node to translate.

        Returns:
            SqlExpr: the translated (possibly newly built) AST node.
        """
        # LOG(x) -> LN(x)
        if isinstance(expr, MathFunction) and expr.name == FuncName("LOG"):
            return MathFunction(
                name=FuncName("LN"), expression=self.transform(expr.expression)
            )
        # RAND() -> RANDOM()
        if isinstance(expr, BareFunction) and expr.name == FuncName("RAND"):
            return BareFunction(name=FuncName("RANDOM"))
        if isinstance(expr, Column) or isinstance(expr, Table):
            # Re-quote identifiers with Postgres double quotes.
            ident_quote = '"'
            parts = split_quote(expr.name, ident_quote)
            expr.name = ".".join(parts)
            if isinstance(expr, Table) and expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, ident_quote)
                )
            return expr
        elif isinstance(expr, SubstrBigQueryFunction):
            # SUBSTR(s, a, b) -> SUBSTRING(s, a, b)
            return SubstringFunction(
                self.walk_tree(expr.source),
                self.walk_tree(expr.start),
                self.walk_tree(expr.length),
                {"FROM": ",", "FOR": ","},  # tokens
            )
        elif isinstance(expr, NamedExpression) and expr.name:
            expr.name = Identifier(self.rename_add_quotes(expr.name.text, '"'))
        elif isinstance(expr, AliasedSubquery) and expr.alias:
            expr.alias = Identifier(
                self.rename_add_quotes(expr.alias.text, '"')
            )
        elif isinstance(expr, CastFunction):
            # CAST(FROM_HEX(source) AS STRING)
            # -> Encode(Decode(source, 'escape'),'hex')
            if isinstance(expr.expression, UnhexFunction):
                return EncodeFunction(
                    source=DecodeFunction(
                        source=self.walk_tree(expr.expression.source),
                        format=Literal(None, "'hex'"),
                    ),
                    format=Literal(None, "'escape'"),
                )
            # CAST(DATETIME(ts, tz) AS TIMESTAMP)
            # -> TIMEZONE(tz, ts)
            if (
                expr.dbtype == "timestamp"
                and isinstance(expr.expression, UserDefinedFunction)
                and expr.expression.name == "DATETIME"
            ):
                assert len(expr.expression.expressions) == 2
                # [::-1]: the two functions take their arguments in
                # opposite order.
                return UserDefinedFunction(
                    name="TIMEZONE",
                    expressions=self.walk_tree(expr.expression.expressions)[
                        ::-1
                    ],
                )
            # BigQuery STRING maps back to Postgres varchar.
            if expr.dbtype.startswith("string"):
                expr.dbtype = "varchar"
        # TO_HEX(MD5(x)) -> MD5(x): Postgres' MD5 already yields hex text.
        if isinstance(expr, HexFunction) and isinstance(
            expr.source, UserDefinedFunction
        ):
            return UserDefinedFunction(
                name="MD5", expressions=self.walk_tree(expr.source.expressions)
            )
        if isinstance(expr, UserDefinedFunction):
            # it translates UNIX_SECONDS(CAST(col AS TIMESTAMP))
            # to EXTRACT(epoch FROM column)
            if (
                expr.name == "UNIX_SECONDS"
                and isinstance(expr.expressions[0], CastFunction)
                and expr.expressions[0].dbtype == "timestamp"
            ):
                return ExtractFunction(
                    "epoch", self.walk_tree(expr.expressions[0].expression)
                )
            # convert DATETIME(ts, tz)
            # -> TIMEZONE(tz, ts)
            if expr.name == "DATETIME":
                assert len(expr.expressions) == 2
                return UserDefinedFunction(
                    name="TIMEZONE",
                    expressions=self.walk_tree(expr.expressions)[::-1],
                )
            # IFNULL(...) -> COALESCE(...)
            if expr.name == "IFNULL":
                return CoalesceFunction(
                    expressions=Seq(
                        [
                            self.transform(udf_expr)
                            for udf_expr in expr.expressions
                        ]
                    ),
                )
        # Default: no rewrite here; keep walking the subtree.
        return self.walk_tree(expr)