Repository URL to install this package:
|
Version:
3.0.0.dev3 ▾
|
from enum import Enum
import os
import typing as t
from snsql._ast.ast import (
BetweenCondition,
CaseExpression,
Column,
Literal,
NestedBoolean,
PredicatedExpression,
Query,
Sql,
SqlExpr,
Table,
UnifiedQuery,
UserDefinedFunction,
)
from snsql._ast.expressions.string import (
ConcatFunction,
DecodeFunction,
EncodeFunction,
RegexpSplitToTableFunction,
)
from sarus_sql.dialects import SQLDialect
from .base import BaseTranslator
from .charset_inferring_functions import char_inference_query, custom_charset
# Aggregate functions available in Redshift but not in PostgreSQL. Calls to
# these (matched case-insensitively by name) are rejected by the translator.
USER_DEFINED_AGGS = ("ANY_VALUE", "LISTAGG", "MEDIAN")
def encode_decode_redshift(pg_format: Literal) -> Literal:
    """
    Map a PostgreSQL ENCODE/DECODE format literal to its Redshift equivalent.

    Supported formats and their translations:
        'escape' -> 'utf8'
        'hex'    -> 'hex'    (unchanged)
        'base64' -> 'base64' (unchanged)

    Raises:
        ValueError: if ``pg_format`` is none of the formats above.
    """
    if pg_format == Literal(None, "'escape'"):
        return Literal(None, "'utf8'")
    same_in_both_dialects = (Literal(None, "'hex'"), Literal(None, "'base64'"))
    if pg_format in same_in_both_dialects:
        return pg_format
    raise ValueError(
        "ENCODE format can only be one of ('base64', 'hex',"
        f" 'escape'). Got {pg_format}"
    )
def encode_decode_postgres(rs_format: Literal) -> Literal:
    """
    Map a Redshift FROM_VARBYTE/TO_VARBYTE format literal back to PostgreSQL.

    Supported formats and their translations:
        'utf8'   -> 'escape'
        'hex'    -> 'hex'    (unchanged)
        'base64' -> 'base64' (unchanged)

    Raises:
        ValueError: if ``rs_format`` is none of the formats above.
    """
    if rs_format == Literal(None, "'utf8'"):
        return Literal(None, "'escape'")
    same_in_both_dialects = (Literal(None, "'hex'"), Literal(None, "'base64'"))
    if rs_format in same_in_both_dialects:
        return rs_format
    raise ValueError(
        "ENCODE format can only be one of ('base64', 'hex',"
        f" 'utf8'). Got {rs_format}"
    )
class DETECT_CHARSET(Enum):
    """Controls which charset-inference query is generated."""

    # Query over the Basic Multilingual Plane (code points 0-65535), reduced
    # to 0-55295. Redshift raises "The UTF-8 character is reserved as a
    # surrogate." for some code points above 55295: the surrogate code
    # points U+D800 through U+DFFF are invalid.
    REDSHIFT_BMP = "REDSHIFT_BMP"
    # Query over code points 0-255 (ASCII plus the Latin-1 supplement).
    ASCII = "ASCII"
    # Infer the charset from the data, assuming a single row of text holds
    # at most one million characters.
    INFER = "INFER"
class RedshiftTranslator(BaseTranslator):
    """
    Translator for Redshift <-> Postgres.

    ``transform_to_dialect`` rewrites a PostgreSQL AST for Redshift and
    ``transform_to_postgres`` performs the reverse rewrite for the constructs
    handled here: CONCAT, ENCODE/DECODE <-> FROM_VARBYTE/TO_VARBYTE, tiny
    float literals in CASE expressions, and the REGEXP_SPLIT_TO_TABLE
    charset query.
    """

    # Smallest positive normal IEEE-754 double (equal to sys.float_info.min).
    # Float literals in CASE expressions whose magnitude is below this are
    # flushed to 0.0 — presumably because Redshift rejects subnormal
    # DOUBLE PRECISION literals (TODO confirm).
    _MIN_NORMAL_FLOAT = 2.2250738585072014e-308

    def __init__(
        self, query: t.Union[Query, UnifiedQuery], **kwargs: t.Any
    ) -> None:
        super().__init__(query, **kwargs)

    # ------------------------------------------------------------------
    # PostgreSQL -> Redshift
    # ------------------------------------------------------------------

    def transform_to_dialect(self, expr: Sql) -> Sql:
        """
        Walk the parsed SQL tree and translate its elements from PostgreSQL
        to Redshift.

        Raises:
            NotImplementedError: if ``expr`` calls one of the Redshift-only
                aggregates listed in ``USER_DEFINED_AGGS``.
            ValueError: if an ENCODE/DECODE format is unsupported, or a
                REGEXP_SPLIT_TO_TABLE query does not have the expected shape.
        """
        # Columns are returned untouched; this also avoids a recursion error.
        if isinstance(expr, Column):
            return expr
        if isinstance(expr, CaseExpression):
            # Rewritten in place, then handed to the generic walk.
            self._flush_case_literals(expr)
            return self.walk_tree(expr)
        if isinstance(expr, ConcatFunction):
            return self._chain_concat(expr)
        if isinstance(expr, DecodeFunction):
            return self._decode_to_redshift(expr)
        if isinstance(expr, EncodeFunction):
            return self._encode_to_redshift(expr)
        # Aggregates that only exist in Redshift cannot come from PostgreSQL.
        if (
            isinstance(expr, UserDefinedFunction)
            and expr.name.upper() in USER_DEFINED_AGGS
        ):
            raise NotImplementedError(
                f"Aggregate functions {expr.name} not supported: {expr}"
            )
        if isinstance(expr, Query) and any(
            isinstance(named.expression, RegexpSplitToTableFunction)
            for named in expr.select.namedExpressions
        ):
            return self._charset_query(expr)
        return self.walk_tree(expr)

    @classmethod
    def _flush_tiny_float(cls, node: t.Any) -> t.Any:
        """Return ``Literal(value=0.0)`` if ``node`` is a float literal of
        magnitude below the smallest normal double, else ``node`` itself."""
        if (
            isinstance(node, Literal)
            and isinstance(node.value, float)
            # abs() so negative literals (e.g. a BETWEEN lower bound of
            # -10.0) are kept; only values of tiny magnitude are flushed.
            and abs(node.value) < cls._MIN_NORMAL_FLOAT
        ):
            return Literal(value=0.0)
        return node

    def _flush_case_literals(self, expr: CaseExpression) -> None:
        """Flush tiny float literals of a CASE expression in place.

        Inspects the bounds of ``NestedBoolean``-wrapped BETWEEN conditions
        and every THEN value; the ELSE branch is passed to ``walk_tree``.
        """
        for when_expr in expr.when_exprs:
            condition = when_expr.expression
            if isinstance(condition, NestedBoolean):
                inner = condition.expression
                if isinstance(inner, PredicatedExpression) and isinstance(
                    inner.predicate, BetweenCondition
                ):
                    between = inner.predicate
                    between.lower = self._flush_tiny_float(between.lower)
                    between.upper = self._flush_tiny_float(between.upper)
            when_expr.then = self._flush_tiny_float(when_expr.then)
        expr.when_exprs = list(expr.when_exprs)
        expr.else_expr = self.walk_tree(expr.else_expr)

    def _chain_concat(self, expr: ConcatFunction) -> Sql:
        """Rewrite an n-ary CONCAT as nested two-argument CONCAT calls.

        Redshift's CONCAT takes exactly two arguments. Operand order is
        preserved: CONCAT(a, b, c) -> CONCAT(CONCAT(a, b), c). A single
        operand is paired with the empty string ''.
        """
        # Copy so the list returned by walk_tree is not mutated.
        operands = list(self.walk_tree(expr.expressions))
        result, rest = operands[0], operands[1:]
        if not rest:
            return ConcatFunction([result, Literal(None, "''")])
        for operand in rest:
            result = ConcatFunction([result, operand])
        return result

    def _decode_to_redshift(self, expr: DecodeFunction) -> Sql:
        """Translate DECODE(src, fmt) into TO_VARBYTE(src, fmt').

        DECODE(ENCODE(x, f1), f2) becomes
        TO_VARBYTE(FROM_VARBYTE(x, f1'), f2'); formats are mapped by
        ``encode_decode_redshift``.
        """
        decode_format = encode_decode_redshift(expr.format)
        if isinstance(expr.source, EncodeFunction):
            encode_format = encode_decode_redshift(expr.source.format)
            decode_source = UserDefinedFunction(
                name="FROM_VARBYTE",
                expressions=[
                    self.walk_tree(expr.source.source),
                    encode_format,
                ],
            )
        else:
            decode_source = self.walk_tree(expr.source)
        return UserDefinedFunction(
            name="TO_VARBYTE",
            expressions=[decode_source, decode_format],
        )

    def _encode_to_redshift(self, expr: EncodeFunction) -> Sql:
        """Translate ENCODE(src, fmt) into FROM_VARBYTE(src, fmt').

        ENCODE(DECODE(x, f1), f2) becomes
        FROM_VARBYTE(TO_VARBYTE(x, f1'), f2'); formats are mapped by
        ``encode_decode_redshift``.
        """
        encode_format = encode_decode_redshift(expr.format)
        if isinstance(expr.source, DecodeFunction):
            decode_format = encode_decode_redshift(expr.source.format)
            encode_source = UserDefinedFunction(
                name="TO_VARBYTE",
                expressions=[
                    self.walk_tree(expr.source.source),
                    decode_format,
                ],
            )
        else:
            encode_source = self.walk_tree(expr.source)
        return UserDefinedFunction(
            name="FROM_VARBYTE",
            expressions=[encode_source, encode_format],
        )

    def _charset_query(self, expr: Query) -> Sql:
        """Replace a REGEXP_SPLIT_TO_TABLE charset query by a Redshift one.

        Expected input shape::

            SELECT DISTINCT REGEXP_SPLIT_TO_TABLE(anon_2.name, '')
                AS "regexp_split"
            FROM (
                SELECT "tab1"."name" AS "name"
                FROM "tab1" WHERE "tab1"."name" IS NOT NULL
            ) AS "anon_2"

        The generated query depends on the ``MSSQL_CHARSET`` environment
        variable (a ``DETECT_CHARSET`` value, default ``"ASCII"``).

        Raises:
            ValueError: if ``expr`` does not have the expected shape.
        """

        def unsupported() -> ValueError:
            return ValueError(
                f"We don't know how to translate {expr} "
                "into Redshift dialect."
            )

        named = expr.select.namedExpressions
        if len(named) != 1 or not isinstance(
            named[0].expression, RegexpSplitToTableFunction
        ):
            raise unsupported()
        try:
            # Column whose charset is computed (anon_2.name above).
            col = named[0].expression.source
            # Alias given to the charset column.
            alias = named[0].name
            # Relation the column is read from.
            source_relation = expr.source.relations[0]
        except (AttributeError, IndexError) as err:
            raise unsupported() from err

        # ASCII (0-255) is the default for efficiency; set the variable when
        # the data is known to contain other characters. MSSQL_CHARSET is an
        # obsolete name, kept because this code path is due for removal.
        charset = os.environ.get(
            "MSSQL_CHARSET",
            default="ASCII",
        )  # type: ignore
        if charset == DETECT_CHARSET.INFER.value:
            return char_inference_query(
                column=col,
                col_alias=alias,
                source_relation=source_relation,
                max_string_length=1000000,
                dialect=SQLDialect.REDSHIFT,
            )
        if charset == DETECT_CHARSET.ASCII.value:
            return custom_charset(alias, 255, SQLDialect.REDSHIFT)
        # REDSHIFT_BMP and any unrecognised value: code points 0-55295.
        return custom_charset(alias, 55295, SQLDialect.REDSHIFT)

    # ------------------------------------------------------------------
    # Redshift -> PostgreSQL
    # ------------------------------------------------------------------

    def transform_to_postgres(self, expr: SqlExpr) -> SqlExpr:
        """
        Translate a SQL query back from Redshift to PostgreSQL.

        FROM_VARBYTE/TO_VARBYTE calls (matched case-insensitively) become
        ENCODE/DECODE; other nodes are walked recursively.

        Raises:
            ValueError: if a FROM_VARBYTE/TO_VARBYTE format is unsupported.
        """
        if self._is_udf(expr, "FROM_VARBYTE"):
            return self._from_varbyte_to_postgres(expr)
        if self._is_udf(expr, "TO_VARBYTE"):
            return self._to_varbyte_to_postgres(expr)
        # Leaves are returned untouched; this also avoids a recursion error.
        if isinstance(expr, (Column, Table)):
            return expr
        return self.walk_tree(expr)

    @staticmethod
    def _is_udf(node: t.Any, name: str) -> bool:
        """Return True if ``node`` calls the user-defined function ``name``
        (``name`` given in upper case; comparison is case-insensitive)."""
        return isinstance(node, UserDefinedFunction) and node.name.upper() == name

    def _from_varbyte_to_postgres(self, expr: UserDefinedFunction) -> SqlExpr:
        """Translate FROM_VARBYTE(src, fmt) into ENCODE(src, fmt').

        FROM_VARBYTE(TO_VARBYTE(x, f1), f2) becomes
        ENCODE(DECODE(x, f1'), f2'); formats are mapped by
        ``encode_decode_postgres``.
        """
        source, fmt = expr.expressions[0], expr.expressions[1]
        encode_format = encode_decode_postgres(fmt)
        if self._is_udf(source, "TO_VARBYTE"):
            inner_format = encode_decode_postgres(source.expressions[1])
            encode_source = DecodeFunction(
                source=self.walk_tree(source.expressions[0]),
                format=inner_format,
            )
        else:
            encode_source = self.walk_tree(source)
        return EncodeFunction(source=encode_source, format=encode_format)

    def _to_varbyte_to_postgres(self, expr: UserDefinedFunction) -> SqlExpr:
        """Translate TO_VARBYTE(src, fmt) into DECODE(src, fmt').

        TO_VARBYTE(FROM_VARBYTE(x, f1), f2) becomes
        DECODE(ENCODE(x, f1'), f2'); the outer DECODE keeps its own format.
        """
        source, fmt = expr.expressions[0], expr.expressions[1]
        decode_format = encode_decode_postgres(fmt)
        if self._is_udf(source, "FROM_VARBYTE"):
            # Separate name so the outer decode_format is not overwritten.
            inner_format = encode_decode_postgres(source.expressions[1])
            decode_source = EncodeFunction(
                source=self.walk_tree(source.expressions[0]),
                format=inner_format,
            )
        else:
            decode_source = self.walk_tree(source)
        return DecodeFunction(source=decode_source, format=decode_format)