Repository URL to install this package:
|
Version: 0.2.102
|
import re
import base64
import string
from snsql._ast.tokens import *
"""
string processing expressions
"""
class LowerFunction(SqlExpr):
    """SQL ``LOWER(expression)``: lower-case the string form of ``expression``.

    A ``None`` (NULL) operand yields ``None``, the same convention used by
    TrimFunction and CharLengthFunction in this module.
    """

    def __init__(self, expression):
        self.expression = expression

    def children(self):
        return [Token("LOWER"), Token("("), self.expression, Token(")")]

    def evaluate(self, bindings):
        """Evaluate the operand against ``bindings``.

        Returns:
            The lower-cased ``str`` form of the operand, or ``None`` when the
            operand evaluates to ``None``.
        """
        exp = self.expression.evaluate(bindings)
        if exp is None:
            # Without this guard str(None) would yield the literal text "none".
            return None
        return str(exp).lower()

    def symbol(self, relations):
        """Return a LowerFunction whose operand has been resolved against ``relations``."""
        return LowerFunction(self.expression.symbol(relations))
class UpperFunction(SqlExpr):
    """SQL ``UPPER(expression)``: upper-case the string form of ``expression``.

    A ``None`` (NULL) operand yields ``None``, the same convention used by
    TrimFunction and CharLengthFunction in this module.
    """

    def __init__(self, expression):
        self.expression = expression

    def children(self):
        return [Token("UPPER"), Token("("), self.expression, Token(")")]

    def evaluate(self, bindings):
        """Evaluate the operand against ``bindings``.

        Returns:
            The upper-cased ``str`` form of the operand, or ``None`` when the
            operand evaluates to ``None``.
        """
        exp = self.expression.evaluate(bindings)
        if exp is None:
            # Without this guard str(None) would yield the literal text "NONE".
            return None
        return str(exp).upper()

    def symbol(self, relations):
        """Return an UpperFunction whose operand has been resolved against ``relations``."""
        return UpperFunction(self.expression.symbol(relations))
class TrimFunction(SqlExpr):
    """SQL ``TRIM(expression)``: strip leading and trailing whitespace.

    The operand is converted with ``str`` before stripping; a ``None``
    operand propagates as ``None``.
    """

    def __init__(self, expression):
        self.expression = expression

    def children(self):
        return [Token("TRIM"), Token("("), self.expression, Token(")")]

    def evaluate(self, bindings):
        value = self.expression.evaluate(bindings)
        return None if value is None else str(value).strip()

    def symbol(self, relations):
        resolved = self.expression.symbol(relations)
        return TrimFunction(resolved)

    def __str__(self):
        return f"TRIM( {self.expression})"
class CharLengthFunction(SqlExpr):
    """SQL ``CHAR_LENGTH(expression)``: character count of the operand's string form.

    ``token`` is the function-name token emitted by ``children()``
    (``CHAR_LENGTH`` by default). It is preserved by ``symbol`` so the
    resolved expression renders with the same spelling.
    """

    def __init__(self, expression, token=Token("CHAR_LENGTH")):
        self.expression = expression
        self.token = token

    def children(self):
        return [self.token, Token("("), self.expression, Token(")")]

    def evaluate(self, bindings):
        """Return ``len(str(value))`` for the evaluated operand, or ``None`` if it is ``None``."""
        exp = self.expression.evaluate(bindings)
        if exp is None:
            return None
        return len(str(exp))

    def symbol(self, relations):
        """Return a CharLengthFunction over the resolved operand, keeping the same token."""
        # Must rebuild this class (it previously returned a TrimFunction) and
        # carry the token so alternate spellings survive symbol resolution.
        return CharLengthFunction(self.expression.symbol(relations), self.token)
class PositionFunction(SqlExpr):
    """SQL ``POSITION(search IN source)``.

    Evaluates to the 1-based index of the first occurrence of ``search`` in
    ``source``, 0 when it does not occur, or ``None`` if either operand is
    ``None``.
    """

    def __init__(self, search, source):
        self.search = search
        self.source = source

    def children(self):
        return [
            Token("POSITION"),
            Token("("),
            self.search,
            Token("IN"),
            self.source,
            Token(")"),
        ]

    def evaluate(self, bindings):
        needle = self.search.evaluate(bindings)
        haystack = self.source.evaluate(bindings)
        if needle is None or haystack is None:
            return None
        # Python indices are 0-based; SQL positions are 1-based with 0 meaning "absent".
        return haystack.index(needle) + 1 if needle in haystack else 0

    def symbol(self, relations):
        return PositionFunction(
            self.search.symbol(relations),
            self.source.symbol(relations),
        )

    def __str__(self):
        return f"POSITION( {self.search} IN {self.source})"
class SubstringFunction(SqlExpr):
    """SQL ``SUBSTRING(source FROM start [FOR length])``.

    ``start`` is 1-based. A start of 0 is treated as 1, and a negative start
    counts back from the end of the string; positions before the beginning
    are clamped to the first character. ``tokens`` maps ``"FROM"`` and
    ``"FOR"`` to the tokens rendered between the arguments. ``length`` may be
    ``None`` when no FOR clause is present.
    """

    def __init__(self, source, start, length, tokens):
        self.source = source
        self.start = start
        self.length = length
        self.tokens = tokens

    def children(self):
        start = [Token("SUBSTRING"), Token("("), self.source, self.tokens["FROM"], self.start]
        middle = [] if self.length is None else [self.tokens["FOR"], self.length]
        end = [Token(")")]
        return start + middle + end

    def evaluate(self, bindings):
        """Evaluate the substring against ``bindings``.

        Returns:
            The substring of ``str(source)``; ``None`` if source, start, or a
            given length evaluates to ``None``, or if length is negative.

        Raises:
            ValueError: if start or a given length is not an ``int``.
        """
        source = self.source.evaluate(bindings) if self.source is not None else None
        start = self.start.evaluate(bindings) if self.start is not None else None
        has_length = self.length is not None
        length = self.length.evaluate(bindings) if has_length else None
        if source is None or start is None or (has_length and length is None):
            return None
        if not isinstance(start, int):
            raise ValueError(f"Start position must evaluate to an integer: {str(self.start)}")
        source = str(source)
        if start == 0:
            start = 1
        if start > 0:
            start -= 1  # python indices start at 0
        else:
            # Negative start counts from the end; clamp so a start before the
            # beginning of the string does not wrap into another negative slice.
            start = max(len(source) + start, 0)
        if not has_length:
            return source[start:]
        if not isinstance(length, int):
            raise ValueError("Substring length if provided must be an integer")
        if length < 0:
            return None
        # FOR 0 yields an empty string (the old truthiness test returned the whole remainder).
        return source[start:start + length]

    def symbol(self, relations):
        """Return a SubstringFunction with every present argument resolved against ``relations``."""
        return SubstringFunction(
            self.source.symbol(relations),
            self.start.symbol(relations),
            None if self.length is None else self.length.symbol(relations),
            self.tokens,
        )

    def __str__(self):
        middle = "" if self.length is None else f"{self.tokens['FOR']} {self.length} "
        return f"SUBSTRING( {self.source} {self.tokens['FROM']} {self.start} {middle})"
class SubstrBigQueryFunction(SqlExpr):
    """BigQuery ``SUBSTR(source, start [, length])``.

    ``start`` is 1-based; negative values count from the end, and 0 or a
    position before the beginning of the value starts at position 1.
    ``length`` is optional (``None`` when omitted). A ``bytes``/``bytearray``
    source is sliced as bytes; anything else is converted with ``str``.
    """

    def __init__(self, source, start, length):
        self.source = source
        self.start = start
        self.length = length

    def children(self):
        required = [Token("SUBSTR"), Token("("), self.source, Token(","), self.start]
        middle = [] if self.length is None else [Token(","), self.length]
        end = [Token(")")]
        return required + middle + end

    def evaluate(self, bindings):
        """Evaluate the substring against ``bindings``.

        Returns:
            The slice of the source (``str`` or ``bytes``), or ``None`` if any
            supplied argument evaluates to ``None``.

        Raises:
            ValueError: if start is not an ``int``, or a given length is not a
                non-negative ``int``.
        """
        source = self.source.evaluate(bindings)
        start = self.start.evaluate(bindings)
        has_length = self.length is not None
        length = self.length.evaluate(bindings) if has_length else None
        # NULL in any argument yields NULL instead of operating on the text "None".
        if source is None or start is None or (has_length and length is None):
            return None
        if not isinstance(start, int):
            raise ValueError(f"Start position must evaluate to an integer: {str(self.start)}")
        if not isinstance(source, (bytes, bytearray)):
            # str() on bytes would produce the "b'...'" repr, so only convert non-bytes.
            source = str(source)
        if start == 0 or start < -len(source):
            start = 1
        if start > 0:
            start -= 1  # python indices start at 0
        else:
            start = len(source) + start
        if not has_length:
            return source[start:]
        if not isinstance(length, int):
            raise ValueError("Substring length if provided must be an integer")
        if length < 0:
            raise ValueError("Substring length if provided must be positive")
        # A length of 0 yields an empty value (the old truthiness test returned the remainder).
        return source[start:start + length]

    def symbol(self, relations):
        return SubstrBigQueryFunction(
            self.source.symbol(relations),
            self.start.symbol(relations),
            None if self.length is None else self.length.symbol(relations)
        )
class RegexpContainsFunction(SqlExpr):
    """BigQuery ``REGEXP_CONTAINS(source, regexp)``.

    Evaluates to ``True`` when ``regexp`` matches anywhere in ``str(source)``,
    ``False`` otherwise, and ``None`` if either operand is ``None``.
    """

    def __init__(self, source, regexp):
        self.source = source
        self.regexp = regexp

    def children(self):
        return [Token("REGEXP_CONTAINS"), Token("("), self.source, Token(","), self.regexp, Token(")")]

    def evaluate(self, bindings):
        """Return whether the pattern matches, or ``None`` for a NULL operand.

        Raises:
            ValueError: if the pattern does not compile.
        """
        source = self.source.evaluate(bindings)
        regexp = self.regexp.evaluate(bindings)
        if source is None or regexp is None:
            # Avoid matching against the text "None" or calling re.compile(None).
            return None
        try:
            pattern = re.compile(regexp)
        except re.error as err:
            raise ValueError("REGEXP_CONTAINS expression has an invalid regex!") from err
        return pattern.search(str(source)) is not None

    def symbol(self, relations):
        return RegexpContainsFunction(self.source.symbol(relations), self.regexp.symbol(relations))
class RegexpExtractFunction(SqlExpr):
    """https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#regexp_extract

    ``REGEXP_EXTRACT(source, regexp [, position [, occurrence]])``: return the
    ``occurrence``-th match (1-based) found by ``re.findall`` in ``source``
    starting at 1-based ``position``. Omitted position/occurrence default to 1.
    """

    def __init__(self, source, regexp, position, occurrence):
        self.source = source
        self.regexp = regexp
        self.position = position
        self.occurrence = occurrence

    def children(self):
        required = [Token("REGEXP_EXTRACT"), Token("("), self.source, Token(","), self.regexp]
        position = [] if self.position is None else [Token(","), self.position]
        occurrence = [] if self.occurrence is None else [Token(","), self.occurrence]
        end = [Token(")")]
        return required + position + occurrence + end

    def evaluate(self, bindings):
        """Evaluate the extraction against ``bindings``.

        Returns:
            The selected ``re.findall`` result; ``None`` if any argument
            evaluates to ``None``, if position is past the end of source, or
            if fewer than ``occurrence`` matches exist.

        Raises:
            ValueError: if the pattern does not compile, or position /
                occurrence is not a positive ``int``.
        """
        source = self.source.evaluate(bindings)
        regexp = self.regexp.evaluate(bindings)
        position = self.position.evaluate(bindings) if self.position is not None else 1
        occurrence = self.occurrence.evaluate(bindings) if self.occurrence is not None else 1
        if source is None or regexp is None or position is None or occurrence is None:
            return None
        try:
            pattern = re.compile(regexp)
        except re.error as err:
            raise ValueError("REGEXP_EXTRACT expression has an invalid regex!") from err
        if not isinstance(position, int) or position <= 0:
            raise ValueError("REGEXP_EXTRACT position should be a positive integer!")
        # occurrence 0 used to slip through and index values[-1] (the last match).
        if not isinstance(occurrence, int) or occurrence <= 0:
            raise ValueError("REGEXP_EXTRACT occurrence should be a positive integer!")
        if position > len(source) + 1:
            return None
        values = re.findall(pattern, source[position - 1:])
        if occurrence > len(values):
            return None
        return values[occurrence - 1]

    def symbol(self, relations):
        return RegexpExtractFunction(
            self.source.symbol(relations),
            self.regexp.symbol(relations),
            None if self.position is None else self.position.symbol(relations),
            None if self.occurrence is None else self.occurrence.symbol(relations),
        )
class RegexpReplaceFunction(SqlExpr):
    """BigQuery ``REGEXP_REPLACE(source, regexp, replacement)``.

    Replaces every match of ``regexp`` in ``str(source)`` with
    ``str(replacement)`` using ``re.sub``. Evaluates to ``None`` if any
    operand is ``None``.
    """

    def __init__(self, source, regexp, replacement):
        self.source = source
        self.regexp = regexp
        self.replacement = replacement

    def children(self):
        return [Token("REGEXP_REPLACE"), Token("("), self.source, Token(","), self.regexp, Token(","), self.replacement, Token(")")]

    def evaluate(self, bindings):
        """Return the substituted string, or ``None`` for a NULL operand.

        Raises:
            ValueError: if the pattern does not compile.
        """
        source = self.source.evaluate(bindings)
        regexp = self.regexp.evaluate(bindings)
        replacement = self.replacement.evaluate(bindings)
        if source is None or regexp is None or replacement is None:
            # Avoid substituting into (or with) the literal text "None".
            return None
        try:
            pattern = re.compile(regexp)
        except re.error as err:
            raise ValueError("REGEXP_REPLACE expression has an invalid regex!") from err
        return pattern.sub(str(replacement), str(source))

    def symbol(self, relations):
        return RegexpReplaceFunction(
            self.source.symbol(relations),
            self.regexp.symbol(relations),
            self.replacement.symbol(relations),
        )
class ConcatFunction(SqlExpr):
    """SQL ``CONCAT(e1, e2, ...)``: join the string forms of all arguments.

    Every argument is evaluated; if any of them is ``None`` the result is
    ``None``.
    """

    def __init__(self, expressions):
        self.expressions = expressions

    def children(self):
        return [Token("CONCAT"), Token("("), Seq(self.expressions), Token(")")]

    def evaluate(self, bindings):
        values = [expr.evaluate(bindings) for expr in self.expressions]
        if any(value is None for value in values):
            return None
        return "".join(map(str, values))

    def symbol(self, relations):
        return ConcatFunction([expr.symbol(relations) for expr in self.expressions])
class CoalesceFunction(SqlExpr):
    """SQL ``COALESCE(e1, e2, ...)``: the first argument that is not ``None``.

    Arguments are evaluated left to right and evaluation stops at the first
    non-``None`` value, so later arguments are never evaluated once a result
    is found.
    """

    def __init__(self, expressions):
        self.expressions = expressions

    def children(self):
        return [Token("COALESCE"), Token("("), Seq(self.expressions), Token(")")]

    def evaluate(self, bindings):
        """Return the first non-``None`` argument value, or ``None`` if all are ``None``."""
        # Short-circuit: the previous version evaluated every argument up front,
        # so an error in a later argument surfaced even when an earlier one was non-NULL.
        for expr in self.expressions:
            value = expr.evaluate(bindings)
            if value is not None:
                return value
        return None

    def symbol(self, relations):
        symbols = [e.symbol(relations) for e in self.expressions]
        return CoalesceFunction(symbols)
class EncodeFunction(SqlExpr):
    """SQL ``ENCODE(source, format)``: render the UTF-8 bytes of ``source`` as text.

    Every supported format returns a ``str``. Evaluates to ``None`` if either
    argument is ``None``.
    """

    def __init__(self, source, format):
        """Encode binary data into a textual representation.

        Args:
            source: expression whose value is converted with ``str`` and
                encoded as UTF-8.
            format: expression naming the output format, one of
                ``'base64'``, ``'hex'`` or ``'escape'``.
        """
        self.source = source
        self.format = format

    def children(self):
        return [
            Token("ENCODE"),
            Token("("),
            self.source,
            Token(","),
            self.format,
            Token(")")
        ]

    def evaluate(self, bindings):
        """Return the encoded text, or ``None`` for a NULL argument.

        Raises:
            ValueError: if the format is not one of the supported names.
        """
        value = self.source.evaluate(bindings)
        fmt_value = self.format.evaluate(bindings)
        if value is None or fmt_value is None:
            return None
        source = str(value)
        fmt = str(fmt_value)
        if fmt == "hex":
            return source.encode('utf-8').hex()
        elif fmt == "escape":
            # NOTE(review): this round-trips the text unchanged; PostgreSQL's
            # 'escape' format octal-escapes NUL/high-bit bytes — confirm intent.
            return source.encode('utf-8').decode('utf-8')
        elif fmt == "base64":
            # b64encode returns bytes; decode so this branch returns str like the others.
            return base64.b64encode(source.encode('utf-8')).decode('ascii')
        else:
            raise ValueError(f"ENCODE format can only be one of ('base64', 'hex', 'escape'). Got {fmt}")

    def symbol(self, relations):
        return EncodeFunction(
            self.source.symbol(relations),
            self.format.symbol(relations),
        )
class DecodeFunction(SqlExpr):
    """SQL ``DECODE(source, format)``.

    The decoded binary data is returned as a lowercase-hex ``str`` for the
    ``'escape'`` and ``'base64'`` formats. For ``'hex'`` the validated input
    is returned unchanged. Evaluates to ``None`` if either argument is
    ``None``.
    """

    def __init__(self, source, format):
        """Decode binary data from a textual representation.

        Args:
            source: expression whose value (converted with ``str``) holds the
                encoded text.
            format: expression naming the input format, one of ``'base64'``,
                ``'hex'`` or ``'escape'``.
        """
        self.source = source
        self.format = format

    def children(self):
        return [
            Token("DECODE"),
            Token("("),
            self.source,
            Token(","),
            self.format,
            Token(")")
        ]

    def evaluate(self, bindings):
        """Return the hex text of the decoded bytes, or ``None`` for a NULL argument.

        Raises:
            ValueError: on malformed hex/base64 input or an unsupported format.
        """
        value = self.source.evaluate(bindings)
        fmt_value = self.format.evaluate(bindings)
        if value is None or fmt_value is None:
            # Avoid decoding the literal text "None".
            return None
        source = str(value)
        fmt = str(fmt_value)
        if fmt == "hex":
            if not all(c in string.hexdigits for c in source):
                raise ValueError("Source string is not a hexadecimal textual representation")
            if len(source) % 2 != 0:
                raise ValueError("Invalid hexadecimal source data: odd number of digits")
            return source
        elif fmt == "escape":
            return source.encode('utf-8').hex()
        elif fmt == "base64":
            if len(source) % 4 != 0:
                raise ValueError("Invalid base64 end sequence")
            # binascii.Error raised by b64decode is itself a ValueError subclass.
            return base64.b64decode(source.encode()).hex()
        else:
            raise ValueError(f"DECODE format can only be one of ('base64', 'hex', 'escape'). Got {fmt}")

    def symbol(self, relations):
        return DecodeFunction(
            self.source.symbol(relations),
            self.format.symbol(relations),
        )
class UnhexFunction(SqlExpr):
    """SQL ``UNHEX(source)``: interpret hex text as bytes and decode them as UTF-8.

    ``token`` is the function-name token emitted by ``children()`` (``UNHEX``
    by default) and is preserved by ``symbol``.
    """

    def __init__(self, source, token=Token("UNHEX")):
        """Convert a hexadecimal textual representation into the text it encodes.

        Args:
            source: expression whose value (converted with ``str``) is a
                hexadecimal textual representation.
            token: function-name token rendered by ``children()``.
        """
        self.source = source
        self.token = token

    def children(self):
        return [
            self.token,
            Token("("),
            self.source,
            Token(")")
        ]

    def evaluate(self, bindings):
        """Return the UTF-8 decoding of the hex-encoded bytes, or ``None`` for NULL.

        Raises:
            ValueError: if the text is not valid hex, has an odd number of
                digits, or the bytes are not valid UTF-8.
        """
        value = self.source.evaluate(bindings)
        if value is None:
            # Previously reported as "not hexadecimal" because str(None) == "None".
            return None
        source = str(value)
        if not all(c in string.hexdigits for c in source):
            raise ValueError("Source string is not a hexadecimal textual representation")
        if len(source) % 2 != 0:
            raise ValueError("Invalid hexadecimal source data: odd number of digits")
        return bytearray.fromhex(source).decode()

    def symbol(self, relations):
        # Keep the token so alternate spellings survive symbol resolution.
        return UnhexFunction(
            self.source.symbol(relations),
            self.token,
        )
class HexFunction(SqlExpr):
    """SQL ``HEX(source)``: render a value as hexadecimal text.

    Only rendering (``children``) and symbol resolution are supported;
    ``evaluate`` raises NotImplementedError. ``token`` is the function-name
    token emitted by ``children()`` (``HEX`` by default) and is preserved by
    ``symbol``.
    """

    def __init__(self, source, token=Token("HEX")):
        """Convert a value to its hexadecimal textual representation.

        Args:
            source: expression whose value is to be rendered as hexadecimal.
            token: function-name token rendered by ``children()``.
        """
        self.source = source
        self.token = token

    def children(self):
        return [
            self.token,
            Token("("),
            self.source,
            Token(")")
        ]

    def evaluate(self, bindings):
        raise NotImplementedError("'evaluate' is not implemented for 'HEX'")

    def symbol(self, relations):
        # Keep the token so alternate spellings survive symbol resolution.
        return HexFunction(
            self.source.symbol(relations),
            self.token,
        )
class RegexpSplitToTableFunction(SqlExpr):
    """SQL ``REGEXP_SPLIT_TO_TABLE(source, pattern)``.

    Only rendering (``children``) and symbol resolution are supported;
    ``evaluate`` raises NotImplementedError.
    """

    def __init__(self, source, pattern):
        self.source = source
        self.pattern = pattern

    def children(self):
        return [
            Token("REGEXP_SPLIT_TO_TABLE"),
            Token("("),
            self.source,
            Token(","),
            self.pattern,
            Token(")")
        ]

    def evaluate(self, bindings):
        raise NotImplementedError("'evaluate' is not implemented for 'REGEXP_SPLIT_TO_TABLE'")

    def symbol(self, relations):
        # Both constructor arguments are required; the pattern was previously
        # omitted, so every call raised TypeError.
        return RegexpSplitToTableFunction(
            self.source.symbol(relations),
            self.pattern.symbol(relations),
        )