# Repository URL to install this package:
#
# Version: 3.0.0.dev3
from enum import Enum
import copy
import os
import typing as t
from snsql._ast.ast import (
Aggregate,
BareFunction,
BooleanCompare,
Column,
FuncName,
GroupingExpression,
Limit,
Literal,
MathFunction,
Order,
Query,
Seq,
SortItem,
Sql,
Table,
Token,
Top,
UnifiedQuery,
UserDefinedFunction,
)
from snsql._ast.expressions.date import ExtractFunction
from snsql._ast.expressions.string import (
CharLengthFunction,
CoalesceFunction,
ConcatFunction,
DecodeFunction,
EncodeFunction,
RegexpSplitToTableFunction,
)
from snsql._ast.expressions.types import CastFunction
from sarus_sql.dialects import SQLDialect
from .base import BaseTranslator
from .charset_inferring_functions import char_inference_query, custom_charset
# SQL Server aggregate functions with no implemented Postgres translation.
# Function names are matched case-insensitively (via ``expr.name.upper()``)
# in ``SqlServerTranslator.transform_to_dialect``; encountering one of these
# raises ``NotImplementedError``.
USER_DEFINED_AGGS = (
    "APPROX_COUNT_DISTINCT",
    "CHECKSUM_AGG",
    "COUNT_BIG",  # TODO: this should be translated into count
    "GROUPING",
    "GROUPING_ID",
    "STDEVP",  # TODO: equivalent STDEVPOP.
    "STRING_AGG",
    "VARP",
)
class MSSQL_CHARSET(Enum):
    """Enum to control the charset inference query.

    The active value is read from the ``MSSQL_CHARSET`` environment
    variable (defaulting to ``ASCII``) in
    ``SqlServerTranslator.transform_to_dialect``.
    """
    # It will generate a query for the Basic Multilingual Plane (BMP)
    # charset 0-65535
    BMP = "BMP"
    # It will generate a query for the ASCII charset (0-255)
    ASCII = "ASCII"
    # It will infer the charset. It assumes that a given row of text has
    # at most 1 million characters.
    INFER = "INFER"
class SqlServerTranslator(BaseTranslator):
    """Translate Postgres <-> SqlServer (Azure, Synapse) by modifying
    the query AST.

    Current supported translations:
      - LN <-> LOG
      - RANDOM <-> RAND

    LN and LOG in postgres are respectively the natural log and the base 10
    log. LOG in MSSQL is the natural log.
    """

    def __init__(
        self, query: t.Union[Query, UnifiedQuery], **kwargs: t.Any
    ) -> None:
        """Drop unused CTEs, then delegate to ``BaseTranslator``.

        In Azure SQL, if a query declares CTEs and the main query does not
        use any of them, they must be removed; otherwise the database
        raises an error.
        """
        if isinstance(query, Query):
            if query.select.ctes is not None:
                cte_names = {
                    cte_table.name for cte_table in query.select.ctes
                }
            else:
                cte_names = set()
            tables_in_the_main_select = {
                tab.name
                for tab in query.select.namedExpressions.find_nodes(Table)
            }
            if (
                cte_names.isdisjoint(tables_in_the_main_select)
                and query.source is None
            ):
                query.select.ctes = None
        super().__init__(query, **kwargs)

    def transform_to_dialect(self, expr: Sql) -> Sql:
        """Return ``expr`` transformed from Postgres into SQL Server syntax.

        Raises:
            ValueError: when a ``REGEXP_SPLIT_TO_TABLE`` query does not have
                the single expected shape.
            NotImplementedError: for aggregate functions listed in
                ``USER_DEFINED_AGGS``.
        """
        # T-SQL has no LIMIT clause: rewrite LIMIT n as SELECT TOP n.
        if self.query.limit:
            self.query.select.quantifier = Top(self.query.limit.n)
            self.query.limit = None
        if isinstance(expr, Column):
            return expr
        # Postgres LN (natural log) -> MSSQL LOG (also the natural log).
        if isinstance(expr, MathFunction) and expr.name == FuncName("LN"):
            return MathFunction(
                name=FuncName("LOG"),
                expression=self.transform(expr.expression),
            )
        if isinstance(expr, BareFunction) and expr.name == FuncName("RANDOM"):
            # RAND in MSSQL will generate only 1 random number
            return BareFunction(
                name=FuncName("RAND"),
            )
        # CHAR_LENGTH -> LEN.
        if isinstance(expr, CharLengthFunction):
            return CharLengthFunction(
                expression=self.walk_tree(expr.expression), token=Token("LEN")
            )
        if isinstance(expr, ExtractFunction) and expr.date_part == "epoch":
            # It converts EXTRACT(epoch FROM column) into
            # DATEDIFF(SECOND, '19700101', column)
            # used to compute bigdata bounds for date and datetimes
            return UserDefinedFunction(
                name="DATEDIFF",
                expressions=Seq(
                    [
                        Token("SECOND"),
                        Literal("'19700101'"),
                        self.walk_tree(expr.expression),
                    ]
                ),
            )
        if (
            isinstance(expr, EncodeFunction)
            and expr.format == Literal(None, "'escape'")
            and isinstance(expr.source, DecodeFunction)
            and expr.source.format == Literal(None, "'hex'")
        ):
            # Encode(Decode(source, 'hex'), 'escape') to
            # CONVERT(
            #     VARCHAR(MAX),
            #     CONVERT(VARBINARY(MAX), '50726976617465', 2),
            #     0
            # )
            source = self.walk_tree(expr.source.source)
            hex_to_binary = UserDefinedFunction(
                name="CONVERT",
                expressions=Seq(["VARBINARY(MAX)", source, Literal(2)]),
            )
            binary_to_ascii = UserDefinedFunction(
                name="CONVERT",
                expressions=Seq(["VARCHAR(MAX)", hex_to_binary, Literal(0)]),
            )
            return binary_to_ascii
        if isinstance(expr, ConcatFunction):
            # CONCAT in MSSQL has to take at least 2 args, so pad with an
            # empty-string literal.
            exprs = [self.transform(e) for e in expr.expressions] + [
                Literal(None, "''")
            ]
            return ConcatFunction(exprs)
        if isinstance(expr, Query) and any(
            isinstance(name_expr.expression, RegexpSplitToTableFunction)
            for name_expr in expr.select.namedExpressions
        ):
            # expected a query like this one:
            # SELECT
            #     DISTINCT REGEXP_SPLIT_TO_TABLE(anon_2.name ,'')
            #     AS "regexp_split"
            # FROM (
            #     SELECT "tab1"."name" AS "name"
            #     FROM "tab1" WHERE "tab1"."name" IS NOT NULL
            # ) AS "anon_2"
            # otherwise raise an error
            try:
                # some asserts to make sure we have the expected query.
                assert isinstance(
                    expr.select.namedExpressions[0].expression,
                    RegexpSplitToTableFunction,
                )
                assert len(expr.select.namedExpressions) == 1
                # it would be the anon_2.name. the column for which we
                # want to find the charset.
                col = expr.select.namedExpressions[0].expression.source
                # It is the alias of the charset
                alias = expr.select.namedExpressions[0].name
                # it is the source of the column on which we want to compute
                # the charset.
                source_relation = expr.source.relations[0]
            except (AssertionError, AttributeError, IndexError) as err:
                # NOTE: this previously caught RuntimeError, which none of
                # the guarded statements can raise, so the ValueError below
                # was unreachable.
                raise ValueError(
                    f"We don't know how to translate {expr} "
                    "into MSSQL dialect."
                ) from err
            # We set by default ASCII charset (0-255 unicode values)
            # for efficiency. If it is known in advance that the dataset
            # has characters other than ASCII, MSSQL_CHARSET should be set
            # accordingly.
            charset = os.environ.get(
                "MSSQL_CHARSET",
                default="ASCII",
            )
            if charset == MSSQL_CHARSET.INFER.value:
                return char_inference_query(
                    column=col,
                    col_alias=alias,
                    source_relation=source_relation,
                    max_string_length=1000000,
                    dialect=SQLDialect.SQL_SERVER,
                )
            elif charset == MSSQL_CHARSET.ASCII.value:
                return custom_charset(alias, 255, SQLDialect.SQL_SERVER)
            else:
                return custom_charset(alias, 65535, SQLDialect.SQL_SERVER)
        if isinstance(expr, UserDefinedFunction):
            if expr.name == "MD5":
                # convert MD5(X) to
                # CONVERT(VARCHAR(MAX), HASHBYTES('MD5', X), 2)
                expr_arg = self.transform(expr.expressions[0])
                hashbytes = UserDefinedFunction(
                    name="HASHBYTES",
                    expressions=Seq([Literal("'MD5'"), expr_arg]),
                )
                convert = UserDefinedFunction(
                    name="CONVERT",
                    expressions=Seq(["VARCHAR(MAX)", hashbytes, Literal(2)]),
                )
                return convert
            if expr.name.upper() in USER_DEFINED_AGGS:
                raise NotImplementedError(
                    f"Aggregate functions {expr.name}"
                    f" not supported: {expr}"
                )
        if isinstance(expr, Query):
            # it fixes a sampling query (SELECT TOP x .. ORDER BY RAND())
            # when optimized with a WHERE clause. such query gives unstable
            # results (zero rows sometimes):
            # I tried WHERE RAND () < x:
            # WHERE RAND( CHECKSUM(NEWID()) ) < x:
            # WHERE RAND( ABS(CAST(CAST(new_id AS VARBINARY) AS INT))) < x:
            # https://www.sqlservercentral.com/forums/topic/whats-the-best-way-to-get-a-sample-set-of-a-big-table-without-primary-key#post-1948778 # noqa: E501
            new_query = copy.deepcopy(expr)
            if new_query.limit:
                if new_query.select.quantifier == Token("DISTINCT"):
                    # If having SELECT DISTINCT "some_col", ... LIMIT x
                    # It will be rewritten as:
                    # SELECT TOP x "some_col", ... GROUP BY "some_col"
                    # The other columns will not be inserted in the group by
                    # because of this query:
                    # SELECT
                    #   DISTINCT "col" AS "sarus_privacy_unit",
                    #   RAND () AS random_1
                    # FROM
                    #   "rejpsxndax"
                    # ORDER BY
                    #   RAND ()
                    # LIMIT
                    #   3256
                    # In this case I can't put random_1 in the GROUP BY
                    # because isn't a valid column.
                    first_expr = new_query.select.namedExpressions[0]
                    group = (
                        GroupingExpression(first_expr.name)
                        if first_expr.name is not None
                        else GroupingExpression(first_expr.expression)
                    )
                    new_query.agg = Aggregate(groupingExpressions=[group])
                new_query.select.quantifier = Top(new_query.limit.n)
                new_query.limit = None
            # Drop a WHERE RANDOM() < x filter: it produces unstable
            # (sometimes empty) results on SQL Server.
            if (
                new_query.where
                and isinstance(new_query.where.condition, BooleanCompare)
                and isinstance(new_query.where.condition.left, BareFunction)
                and new_query.where.condition.left.name == FuncName("RANDOM")
            ):
                new_query.where = None
            # ORDER BY RANDOM() -> ORDER BY NEWID() for stable sampling.
            if new_query.order:
                if isinstance(
                    new_query.order.sortItems[0].expression, BareFunction
                ) and new_query.order.sortItems[0].expression.name == FuncName(
                    "RANDOM"
                ):
                    new_query.order = Order(
                        [
                            SortItem(
                                expression=UserDefinedFunction(
                                    name="NEWID", expressions=Seq([])
                                ),
                                order=None,
                            )
                        ]
                    )
            return self.walk_tree(new_query)
        # CAST(x AS varchar/text) -> CONVERT(VARCHAR, x, 126).
        if isinstance(expr, CastFunction) and (
            expr.dbtype == "varchar" or expr.dbtype == "text"
        ):
            return UserDefinedFunction(
                name="CONVERT",
                expressions=Seq(
                    ["VARCHAR", self.walk_tree(expr.expression), Literal(126)]
                ),
            )
        # SQL Server has no boolean column type: CAST to BIT instead.
        if isinstance(expr, CastFunction) and (expr.dbtype == "boolean"):
            return CastFunction(self.walk_tree(expr.expression), dbtype="BIT")
        # Boolean literals become the BIT-compatible integers 0/1.
        if isinstance(expr, Literal) and isinstance(expr.value, bool):
            return Literal(int(expr.value))
        return self.walk_tree(expr)

    def transform_to_postgres(self, expr: Sql) -> Sql:
        """Return ``expr`` transformed from SQL Server into Postgres syntax."""
        # SELECT TOP n -> LIMIT n.
        if self.query.select.quantifier:
            self.query.limit = Limit(self.query.select.quantifier.n)
            self.query.select.quantifier = None
        if isinstance(expr, Column):
            return expr
        # MSSQL LOG (natural log) -> Postgres LN (natural log).
        if isinstance(expr, MathFunction) and expr.name == FuncName("LOG"):
            return MathFunction(
                name=FuncName("LN"), expression=self.transform(expr.expression)
            )
        if isinstance(expr, BareFunction) and expr.name == FuncName("RAND"):
            return BareFunction(name=FuncName("RANDOM"))
        if isinstance(expr, UserDefinedFunction):
            if (
                expr.name == "DATEDIFF"
                and expr.expressions[0] == Token("SECOND")
                and expr.expressions[1] == Literal("'19700101'")
            ):
                # it translates DATEDIFF(SECOND, '19700101', column)
                # to EXTRACT(epoch FROM column)
                return ExtractFunction("epoch", expr.expressions[2])
            if (
                expr.name == "CONVERT"
                and isinstance(expr.expressions[1], UserDefinedFunction)
                and expr.expressions[1].name == "HASHBYTES"
            ):
                # it translates CONVERT(VARCHAR(MAX), HASHBYTES('MD5', X), 2)
                # back to MD5(X)
                return UserDefinedFunction(
                    name="MD5",
                    expressions=Seq(
                        [self.walk_tree(expr.expressions[1].expressions[1])]
                    ),
                )
            if (
                expr.name == "CONVERT"
                and isinstance(expr.expressions[1], UserDefinedFunction)
                and expr.expressions[1].name == "CONVERT"
            ):
                # CONVERT(
                #     VARCHAR(MAX),
                #     CONVERT(VARBINARY(MAX), source, 2),
                #     0
                # ) to
                # ENCODE(DECODE(source, 'hex'), 'escape')
                source = self.walk_tree(expr.expressions[1].expressions[1])
                return EncodeFunction(
                    source=DecodeFunction(
                        source=source,
                        format=Literal(None, "'hex'"),
                    ),
                    format=Literal(None, "'escape'"),
                )
            if expr.name == "IFNULL":
                return CoalesceFunction(
                    expressions=Seq(
                        [
                            self.transform(udf_expr)
                            for udf_expr in expr.expressions
                        ]
                    ),
                )
        return self.walk_tree(expr)