# NOTE: package-index header preserved from the original page capture.
# Repository URL to install this package: (not captured)
# Version: 3.0.0.dev3
import copy
import typing as t
from snsql._ast.ast import (
AliasedSubquery,
Column,
From,
Identifier,
Literal,
NamedExpression,
Query,
Relation,
Select,
Seq,
Sql,
SqlExpr,
Table,
Token,
UnifiedQuery,
UserDefinedFunction,
)
from snsql._ast.expressions.date import ExtractFunction
from snsql._ast.expressions.string import (
DecodeFunction,
EncodeFunction,
RegexpSplitToTableFunction,
SubstrBigQueryFunction,
SubstringFunction,
UnhexFunction,
)
from snsql._ast.expressions.types import CastFunction
from ..table_renamer import TableRenamer
from ..utils import is_quoted, split_quote
from .base import BaseTranslator
# Databricks aggregate functions that are not common with Postgres.
# Queries using any of these cannot be translated back to Postgres, so
# DatabricksTranslator.transform_to_dialect raises NotImplementedError
# when it encounters one (matched case-insensitively via .upper()).
USER_DEFINED_AGGS = (
    "ANY_VALUE",
    "APPROX_COUNT_DISTINCT",
    "APPROX_PERCENTILE",
    "APPROX_TOP_K",
    "BITMAP_CONSTRUCT_AGG",
    "BITMAP_OR_AGG",
    "COLLECT_LIST",
    "COLLECT_SET",
    "COUNT_IF",
    "COUNT_MIN_SKETCH",
    "FIRST",
    "HLL_SKETCH_AGG",
    "HLL_UNION_AGG",
    "KURTOSIS",
    "LAST",
    "MAX_BY",
    "MEAN",
    "MEDIAN",
    "MIN_BY",
    "PERCENTILE",
    "PERCENTILE_APPROX",
    "SKEWNESS",
    "STD",
    "TRY_AVG",
    "TRY_SUM",
)
class DatabricksTranslator(BaseTranslator):
    """Translator for Databricks <-> Postgres.

    Databricks does not support the double quote '"' in queries; it must
    be replaced with the backquote '`'.  All aliases are quoted to avoid
    conflicts with Databricks keywords; moreover, aliases cannot contain
    special characters such as dots or spaces, so those are normalized to
    underscores (see :meth:`alias_mapping`).
    """

    def __init__(
        self, query: t.Union[Query, UnifiedQuery], **kwargs: t.Any
    ) -> None:
        """Build the translator and pre-normalize the query's aliases.

        Args:
            query: the parsed AST to translate.
            **kwargs: forwarded unchanged to BaseTranslator.
        """
        super().__init__(query, **kwargs)
        # Databricks identifier quote character (Postgres uses '"').
        self.quote = "`"
        # original alias text -> normalized (underscore-safe) alias text
        self.alias_map = self.alias_mapping()
        # Rewrites self.query in place using the alias map above.
        self.rename_tables_columns()

    def alias_mapping(self) -> t.Dict[str, str]:
        """Generate a dict mapping actual aliases to normalized aliases.

        In normalized aliases, spaces and dots are replaced with
        underscores (Databricks aliases can't contain them).

        Returns:
            Dict[str, str]: actual alias text -> normalized alias text.
        """
        tables = self.query.find_nodes(Table)
        exprs = self.query.find_nodes(NamedExpression)
        aliased_sub = self.query.find_nodes(AliasedSubquery)
        aliases = {}
        # Tables and aliased subqueries expose their alias via `.alias`.
        for expr in [*tables, *aliased_sub]:
            if expr.alias:
                aliases[expr.alias.text] = expr.alias.text.replace(
                    " ", "_"
                ).replace(".", "_")
        # Named (SELECT ... AS name) expressions expose it via `.name`.
        for expr in exprs:
            if expr.name:
                aliases[expr.name.text] = expr.name.text.replace(
                    " ", "_"
                ).replace(".", "_")
        return aliases

    def rename_tables_columns(self) -> None:
        """Use TableRenamer to rename columns and tables.

        Aliases themselves are renamed later, at emit time, by
        :meth:`rename_add_quotes`.
        """
        # map_dict should contain escaped names
        # NOTE(review): split_quote presumably strips/normalizes the quote
        # characters so keys match what TableRenamer sees — confirm against
        # ..utils.split_quote.
        compatible_alias_map = {
            split_quote(key, "`"): split_quote(val, "`")
            for key, val in self.alias_map.items()
        }
        # Deep-copy so the renamer never mutates the original AST.
        renamer = TableRenamer(copy.deepcopy(self.query), compatible_alias_map)
        renamer.rename()
        self.query = renamer.query

    def rename_add_quotes(self, alias: str, quote: str) -> str:
        """Normalize (via alias_map) and quote an alias.

        To avoid conflicts with possible Databricks keywords we quote all
        aliases with `quote`, replacing any pre-existing quoting.

        Args:
            alias: the alias text, possibly already quoted.
            quote: the quote character to wrap the alias in.

        Returns:
            str: the normalized alias wrapped in `quote`.
        """
        if alias in self.alias_map:
            alias = self.alias_map[alias]
        if is_quoted(alias):
            # Strip the existing surrounding quotes before re-quoting.
            alias = f"{quote}{alias[1:-1]}{quote}"
        else:
            alias = f"{quote}{alias}{quote}"
        return alias

    def transform_to_dialect(self, expr: Sql) -> Sql:
        """Rewrite one AST node from Postgres form into Databricks form.

        Called recursively through walk_tree; some branches return a brand
        new node, others mutate `expr` in place and fall through to the
        final `walk_tree` call.
        """
        if isinstance(expr, Column):
            """Changes or add backquoting quoting"""
            # Keep at most `table.column` (last two dotted parts).
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts[-2:])
            return expr
        elif isinstance(expr, Table):
            # Re-join the full dotted table path with Databricks quoting.
            parts = split_quote(expr.name, self.quote)
            expr.name = ".".join(parts)
            if expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, self.quote)
                )
            return expr
        elif isinstance(expr, AliasedSubquery):
            # Mutates in place; deliberately falls through to walk_tree below
            # so the subquery body is also translated.
            if expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, self.quote)
                )
        elif isinstance(expr, NamedExpression):
            """Add backquoting to aliases"""
            if expr.name:
                expr.name = Identifier(
                    self.rename_add_quotes(expr.name.text, self.quote)
                )
        # Encode(Decode(source, 'hex'), 'escape') -> Unhex(source)
        elif (
            isinstance(expr, EncodeFunction)
            and expr.format == Literal(None, "'escape'")
            and isinstance(expr.source, DecodeFunction)
            and expr.source.format == Literal(None, "'hex'")
        ):
            # Cast to string so the result type matches the Postgres form.
            return CastFunction(
                UnhexFunction(
                    self.walk_tree(expr.source.source), token=Token("UNHEX")
                ),
                "string",
            )
        if isinstance(expr, CastFunction):
            # Databricks has no varchar/char/text types — use `string`.
            if (
                expr.dbtype.startswith("varchar")
                or expr.dbtype.startswith("char")
                or expr.dbtype == "text"
            ):
                expr.dbtype = "string"
        if isinstance(expr, UserDefinedFunction) and expr.name == "TIMEZONE":
            # Postgres TIMEZONE(...) -> Databricks CONVERT_TIMEZONE(...).
            timezone = UserDefinedFunction(
                name="CONVERT_TIMEZONE",
                expressions=self.walk_tree(expr.expressions),
            )
            return timezone
        if (
            isinstance(expr, UserDefinedFunction)
            and expr.name.upper() in USER_DEFINED_AGGS
        ):
            # Databricks-only aggregates have no Postgres counterpart.
            raise NotImplementedError(
                f"Aggregate functions {expr.name}" f" not supported: {expr}"
            )
        if isinstance(expr, (SubstringFunction, SubstrBigQueryFunction)):
            # Databricks uses SUBSTRING(src, start, len) — comma-separated,
            # not the SQL-standard FROM/FOR keywords.
            return SubstringFunction(
                source=self.walk_tree(
                    expr.source,
                ),
                start=expr.start,
                length=expr.length,
                tokens={"FROM": Token(","), "FOR": Token(",")},
            )
        if isinstance(expr, Query) and any(
            [
                isinstance(name_expr.expression, RegexpSplitToTableFunction)
                for name_expr in expr.select.namedExpressions
            ]
        ):
            # REGEXP_SPLIT_TO_TABLE be translated as with a SPLIT and UNNEST:
            # example:
            #
            # SELECT DISTINCT REGEXP_SPLIT_TO_TABLE ( tab.name , '' ) AS col
            # FROM (
            #     SELECT name AS name FROM mytable WHERE name IS NOT NULL
            # ) AS tab
            #
            # it becomes:
            #
            # SELECT DISTINCT `exploded` AS `col`
            # FROM (
            #     SELECT EXPLODE( SPLIT( `tab`.`name`,'')) AS `exploded`
            #     FROM(
            #         SELECT `name` AS `name` FROM `mytable` WHERE `name`
            #         IS NOT NULL
            #     ) AS`tab`
            # ) AS `split_subq`
            #
            main_named_exprs = []
            # Split the SELECT list: the regexp call (guaranteed present by
            # the any(...) guard above) vs. everything else.
            for named_expr in expr.select.namedExpressions:
                if isinstance(
                    named_expr.expression, RegexpSplitToTableFunction
                ):
                    regexp_alias = named_expr.name
                    regexp_source_expr = self.walk_tree(
                        named_expr.expression.source
                    )
                    regexp_pattern = self.walk_tree(
                        named_expr.expression.pattern
                    )
                else:
                    main_named_exprs.append(named_expr)
            # building the subquery
            split = UserDefinedFunction(
                name="SPLIT",
                expressions=Seq([regexp_source_expr, regexp_pattern]),
            )
            explode = UserDefinedFunction(
                name="EXPLODE",
                expressions=Seq([split]),
            )
            namedexpr = NamedExpression(
                expression=explode, name=Identifier("exploded")
            )
            select = Select(quantifier=None, namedExpressions=Seq([namedexpr]))
            # Inner query: EXPLODE(SPLIT(...)) AS exploded over the
            # original FROM clause (no quantifier/filters — they stay on
            # the outer query).
            subquery = Query(
                select=select,
                source=expr.source,
                where=None,
                agg=None,
                having=None,
                order=None,
                limit=None,
            )
            aliased_subquery = AliasedSubquery(
                query=subquery, alias=Identifier("split_subq")
            )
            relation_with_subquery = Relation(
                primary=aliased_subquery, joins=None
            )
            # Outer query selects the exploded column under the original
            # regexp alias, alongside the untouched expressions.
            main_nexpr = NamedExpression(
                expression=Column("exploded"), name=regexp_alias
            )
            main_named_exprs.append(main_nexpr)
            main_select = Select(
                quantifier=expr.select.quantifier,
                namedExpressions=Seq(main_named_exprs),
            )
            main_source = Seq([relation_with_subquery])
            # Rewire the query in place; falls through to walk_tree below.
            expr.select = main_select
            expr.source = From(main_source)
        if isinstance(expr, ExtractFunction) and expr.date_part == "epoch":
            # It converts EXTRACT(epoch FROM column) into
            # DATEDIFF(SECOND, '1970-01-01', column)
            # used to compute bigdata bounds for date and datetimes
            return UserDefinedFunction(
                name="DATEDIFF",
                expressions=Seq(
                    [
                        Token("SECOND"),
                        Literal("'1970-01-01'"),
                        self.walk_tree(expr.expression),
                    ]
                ),
            )
        # Default: recurse into children without changing this node.
        return self.walk_tree(expr)

    def transform_to_postgres(self, expr: SqlExpr) -> SqlExpr:
        """Rewrite one AST node from Databricks form back into Postgres form.

        Mirror image of :meth:`transform_to_dialect`: identifiers get the
        Postgres '"' quote, and the Databricks-specific function forms are
        mapped back to their Postgres equivalents.
        """
        ident_quote = '"'
        if isinstance(expr, NamedExpression):
            """Add backquoting to aliases"""
            if expr.name:
                expr.name = Identifier(
                    self.rename_add_quotes(expr.name.text, ident_quote)
                )
        if isinstance(expr, Column) or isinstance(expr, Table):
            # Re-join dotted parts with Postgres quoting.
            parts = split_quote(expr.name, ident_quote)
            expr.name = ".".join(parts)
            if isinstance(expr, Table) and expr.alias:
                expr.alias = Identifier(
                    self.rename_add_quotes(expr.alias.text, ident_quote)
                )
            return expr
        elif isinstance(expr, AliasedSubquery) and expr.alias:
            # Mutates in place; falls through to walk_tree below.
            expr.alias = Identifier(
                self.rename_add_quotes(expr.alias.text, ident_quote)
            )
        elif isinstance(expr, CastFunction):
            # UNHEX(source)
            # -> Encode(Decode(source, 'escape'),'hex')
            if isinstance(expr.expression, UnhexFunction):
                return EncodeFunction(
                    source=DecodeFunction(
                        source=self.walk_tree(expr.expression.source),
                        format=Literal(None, "'hex'"),
                    ),
                    format=Literal(None, "'escape'"),
                )
            # Databricks `string` type -> Postgres varchar.
            if expr.dbtype.startswith("string"):
                expr.dbtype = "varchar"
        if isinstance(expr, SubstringFunction):
            # NOTE(review): tokens are set to commas here too, same as the
            # dialect direction — presumably the emitter handles the
            # Postgres form downstream; confirm against SubstringFunction.
            return SubstringFunction(
                source=self.walk_tree(
                    expr.source,
                ),
                start=expr.start,
                length=expr.length,
                tokens={"FROM": Token(","), "FOR": Token(",")},
            )
        if isinstance(expr, UserDefinedFunction):
            # it translates DATEDIFF(SECOND, '1970-01-01', col)
            # to EXTRACT(epoch FROM col)
            if (
                expr.name == "DATEDIFF"
                and expr.expressions[0] == Token("SECOND")
                and expr.expressions[1] == Literal("'1970-01-01'")
            ):
                return ExtractFunction(
                    "epoch", self.walk_tree(expr.expressions[2])
                )
            if expr.name == "UNIX_SECONDS":
                return ExtractFunction(
                    "epoch", self.walk_tree(expr.expressions)
                )
            # convert CONVERT_TIMEZONE(ts, tz)
            # -> TIMEZONE(ts, tz)
            if expr.name == "CONVERT_TIMEZONE":
                assert len(expr.expressions) == 2
                return UserDefinedFunction(
                    name="TIMEZONE",
                    expressions=self.walk_tree(expr.expressions),
                )
        # Default: recurse into children without changing this node.
        return self.walk_tree(expr)