Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
smartnoise-sql / snsql / _ast / expressions / string.py
Size: Mime:
import re
import base64
import string
from snsql._ast.tokens import *

"""
    string processing expressions
"""

class LowerFunction(SqlExpr):
    """LOWER(expression): lower-case the string form of an expression."""
    def __init__(self, expression):
        self.expression = expression
    def children(self):
        """Return the tokens and the operand that make up this node, in order."""
        return [Token("LOWER"), Token("("), self.expression, Token(")")]
    def evaluate(self, bindings):
        """Evaluate the operand and return its lower-cased string form.

        A None operand yields None instead of the literal text "none",
        which is what ``str(None).lower()`` would produce.
        """
        exp = self.expression.evaluate(bindings)
        if exp is None:
            return None
        return str(exp).lower()
    def symbol(self, relations):
        """Return a new LowerFunction wrapping the operand's ``symbol(relations)``."""
        return LowerFunction(self.expression.symbol(relations))

class UpperFunction(SqlExpr):
    """UPPER(expression): upper-case the string form of an expression."""
    def __init__(self, expression):
        self.expression = expression
    def children(self):
        """Return the tokens and the operand that make up this node, in order."""
        return [Token("UPPER"), Token("("), self.expression, Token(")")]
    def evaluate(self, bindings):
        """Evaluate the operand and return its upper-cased string form.

        A None operand yields None instead of the literal text "NONE",
        which is what ``str(None).upper()`` would produce.
        """
        exp = self.expression.evaluate(bindings)
        if exp is None:
            return None
        return str(exp).upper()
    def symbol(self, relations):
        """Return a new UpperFunction wrapping the operand's ``symbol(relations)``."""
        return UpperFunction(self.expression.symbol(relations))

class TrimFunction(SqlExpr):
    """TRIM(expression): strip leading and trailing whitespace from a value."""
    def __init__(self, expression):
        self.expression = expression
    def children(self):
        """Return the tokens and the operand that make up this node, in order."""
        opening = [Token("TRIM"), Token("(")]
        closing = [Token(")")]
        return opening + [self.expression] + closing
    def evaluate(self, bindings):
        """Evaluate the operand and strip its string form; None stays None."""
        value = self.expression.evaluate(bindings)
        return None if value is None else str(value).strip()
    def symbol(self, relations):
        """Return a new TrimFunction wrapping the operand's ``symbol(relations)``."""
        resolved = self.expression.symbol(relations)
        return TrimFunction(resolved)
    def __str__(self):
        return f"TRIM( {self.expression})"

class CharLengthFunction(SqlExpr):
    """Character-length function; the keyword emitted is the supplied token.

    Args:
        expression: operand whose string length is measured.
        token: keyword token placed before the parenthesis
            (defaults to ``CHAR_LENGTH``).
    """
    def __init__(self, expression, token=Token("CHAR_LENGTH")):
        self.expression = expression
        self.token = token
    def children(self):
        """Return the keyword token, parentheses and operand, in order."""
        return [self.token, Token("("), self.expression, Token(")")]
    def evaluate(self, bindings):
        """Return the length of the operand's string form, or None for None."""
        exp = self.expression.evaluate(bindings)
        if exp is None:
            return None
        return len(str(exp))
    def symbol(self, relations):
        """Return a new CharLengthFunction over the operand's ``symbol(relations)``.

        The keyword token is carried over so the node keeps the same
        spelling it was created with.
        """
        # Previously this built a TrimFunction, turning the node into TRIM(...).
        return CharLengthFunction(self.expression.symbol(relations), self.token)

class PositionFunction(SqlExpr):
    """POSITION(search IN source): 1-based index of the first occurrence."""
    def __init__(self, search, source):
        self.search = search
        self.source = source
    def children(self):
        """Return the tokens and operands that make up this node, in order."""
        head = [Token("POSITION"), Token("("), self.search]
        tail = [Token("IN"), self.source, Token(")")]
        return head + tail
    def evaluate(self, bindings):
        """Return the 1-based position of ``search`` in ``source``.

        Returns None when either operand evaluates to None, and 0 when
        ``search`` does not occur in ``source``.
        """
        needle = self.search.evaluate(bindings)
        haystack = self.source.evaluate(bindings)
        if needle is None or haystack is None:
            return None
        if needle in haystack:
            # index() is 0-based; the result is reported 1-based
            return haystack.index(needle) + 1
        return 0
    def symbol(self, relations):
        """Return a new PositionFunction over both operands' ``symbol(relations)``."""
        search_symbol = self.search.symbol(relations)
        source_symbol = self.source.symbol(relations)
        return PositionFunction(search_symbol, source_symbol)
    def __str__(self):
        return f"POSITION( {self.search} IN {self.source})"


class SubstringFunction(SqlExpr):
    """SUBSTRING(source FROM start [FOR length]).

    Args:
        source: expression producing the value to slice.
        start: expression producing the start position.
        length: optional expression producing the number of characters;
            None when the FOR clause is absent.
        tokens: mapping holding the "FROM" and "FOR" separator tokens.
    """
    def __init__(self, source, start, length, tokens):
        self.source = source
        self.start = start
        self.length = length
        self.tokens = tokens
    def children(self):
        """Return the tokens and operands in order; the FOR part is optional."""
        start = [Token("SUBSTRING"), Token("("), self.source, self.tokens["FROM"], self.start]
        middle = [] if self.length is None else [self.tokens["FOR"], self.length]
        end = [Token(")")]
        return start + middle + end
    def evaluate(self, bindings):
        """Slice the string form of ``source``.

        Returns:
            None when source or start evaluates to None, or when an integer
            length is below 1. Otherwise the substring; a missing, None or
            zero length returns everything from the start position on.

        Raises:
            ValueError: if start, or a non-empty length, is not an integer.
        """
        source = self.source.evaluate(bindings) if self.source else None
        start = self.start.evaluate(bindings) if self.start else None
        length = self.length.evaluate(bindings) if self.length else None
        if source is None or start is None:
            return None
        if not isinstance(start, int):
            raise ValueError(f"Start position must evaluate to an integer: {str(self.start)}")
        source = str(source)
        if start == 0:
            start = 1
        if start > 0:
            start = start - 1  # python indices start at 0
        if start < 0:
            start = len(source) + start  # negative start is applied relative to the end
        if not length:
            return source[start:]
        if not isinstance(length, int):
            # the old message said "must be positive", which described a different failure
            raise ValueError("Substring length if provided must be an integer")
        if length < 1:
            return None
        return source[start:start+length]
    def symbol(self, relations):
        """Return a new SubstringFunction over the operands' ``symbol(relations)``.

        ``length`` is optional, so it is only resolved when present.
        """
        length = None if self.length is None else self.length.symbol(relations)
        return SubstringFunction(
            self.source.symbol(relations),
            self.start.symbol(relations),
            length,
            self.tokens,
        )
    def __str__(self):
        middle = "" if self.length is None else f"{self.tokens['FOR']} {self.length} "
        return f"SUBSTRING( {self.source} {self.tokens['FROM']} {self.start} {middle})"


class SubstrBigQueryFunction(SqlExpr):
    """SUBSTR(source, start[, length]) in the comma-separated form.

    Args:
        source: expression producing the value to slice.
        start: expression producing the start position.
        length: optional expression producing the number of characters.
    """
    def __init__(self, source, start, length):
        self.source = source
        self.start = start
        self.length = length
    def children(self):
        """Return the tokens and operands in order; the length part is optional."""
        required = [Token("SUBSTR"), Token("("), self.source, Token(","), self.start]
        middle = [] if self.length is None else [Token(","), self.length]
        end = [Token(")")]
        return required + middle + end
    def evaluate(self, bindings):
        """Slice the string form of ``source``.

        Returns:
            None when source or start evaluates to None. Otherwise the
            substring: without a length, everything from the start position
            on; with a length of 0, the empty string.

        Raises:
            ValueError: if start or length is not an integer, or length is
                negative.
        """
        source = self.source.evaluate(bindings)
        start = self.start.evaluate(bindings)
        length = self.length.evaluate(bindings) if self.length else None
        # Propagate NULL instead of slicing the text "None" or rejecting a None start.
        if source is None or start is None:
            return None
        if not isinstance(start, int):
            raise ValueError(f"Start position must evaluate to an integer: {str(self.start)}")
        source = str(source) # source can be also byte
        # Position 0, or a negative position reaching before the first
        # character, is treated as position 1.
        if start == 0 or start < - len(source):
            start = 1
        if start > 0:
            start = start - 1  # python indices start at 0
        if start < 0:
            start = len(source) + start  # negative start counts from the end
        # Test against None explicitly: a length of 0 must give an empty
        # string, not the whole remainder.
        if length is None:
            return source[start:]
        if not isinstance(length, int):
            raise ValueError("Substring length if provided must be an integer")
        if length < 0:
            raise ValueError("Substring length if provided must be positive")
        return source[start:start+length]
    def symbol(self, relations):
        """Return a new SubstrBigQueryFunction over the operands' ``symbol(relations)``."""
        return SubstrBigQueryFunction(
            self.source.symbol(relations),
            self.start.symbol(relations),
            None if self.length is None else self.length.symbol(relations)
        )

class RegexpContainsFunction(SqlExpr):
    """REGEXP_CONTAINS(source, regexp): does the pattern match anywhere?"""
    def __init__(self, source, regexp):
        self.source = source
        self.regexp = regexp
    def children(self):
        """Return the tokens and operands that make up this node, in order."""
        return [
            Token("REGEXP_CONTAINS"),
            Token("("),
            self.source,
            Token(","),
            self.regexp,
            Token(")"),
        ]
    def evaluate(self, bindings):
        """Return True if the pattern is found in the string form of ``source``.

        Raises:
            ValueError: if the pattern cannot be compiled.
        """
        text = self.source.evaluate(bindings)
        expr = self.regexp.evaluate(bindings)
        try:
            compiled = re.compile(expr)
        except re.error:
            raise ValueError("REGEXP_CONTAINS expression has an ivalid regex!")
        return compiled.search(str(text)) is not None
    def symbol(self, relations):
        """Return a new RegexpContainsFunction over both operands' ``symbol(relations)``."""
        source_symbol = self.source.symbol(relations)
        regexp_symbol = self.regexp.symbol(relations)
        return RegexpContainsFunction(source_symbol, regexp_symbol)

class RegexpExtractFunction(SqlExpr):
    """REGEXP_EXTRACT(source, regexp[, position[, occurrence]]).

    https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#regexp_extract

    Args:
        source: expression producing the text to search.
        regexp: expression producing the pattern.
        position: optional expression giving the 1-based position at which
            the search starts (1 when absent).
        occurrence: optional expression giving which match to return,
            counted from 1 (1 when absent).
    """
    def __init__(self, source, regexp, position, occurrence):
        self.source = source
        self.regexp = regexp
        self.position = position
        self.occurrence = occurrence
    def children(self):
        """Return the tokens and operands in order; optional arguments are skipped when None."""
        required = [Token("REGEXP_EXTRACT"), Token("("), self.source,  Token(","), self.regexp]
        position = [] if self.position is None else [Token(","), self.position]
        occurrence = [] if self.occurrence is None else [Token(","), self.occurrence]
        end = [Token(")")]
        return required + position + occurrence + end
    def evaluate(self, bindings):
        """Return the requested match of the pattern in ``source``.

        Returns:
            The ``occurrence``-th result of ``re.findall`` on the source
            starting at ``position``, or None when ``position`` lies beyond
            the end of the source or there are fewer matches than requested.

        Raises:
            ValueError: if the pattern cannot be compiled, or position or
                occurrence is not a positive integer.
        """
        source = self.source.evaluate(bindings)
        regexp = self.regexp.evaluate(bindings)
        position = self.position.evaluate(bindings) if self.position else 1
        occurrence = self.occurrence.evaluate(bindings) if self.occurrence else 1
        try:
            pattern = re.compile(regexp)
        except re.error:
            raise ValueError("REGEXP_EXTRACT expression has an ivalid regex!")
        if not isinstance(position, int) or position <= 0:
            raise ValueError("REGEXP_EXTRACT position should be a positive interger!")
        if position > len(source)+1:
            return None
        # position is 1-based; convert to a 0-based slice offset
        values = re.findall(pattern, source[position - 1:])
        # occurrence 0 used to pass this check and then index values[-1],
        # silently returning the last match; it is now rejected like position 0.
        if not isinstance(occurrence, int) or occurrence <= 0:
            raise ValueError("REGEXP_EXTRACT occurrence should be a positive interger!")
        if occurrence > len(values):
            return None
        return values[occurrence - 1]
    def symbol(self, relations):
        """Return a new RegexpExtractFunction over the operands' ``symbol(relations)``."""
        return RegexpExtractFunction(
            self.source.symbol(relations),
            self.regexp.symbol(relations),
            None if self.position is None else self.position.symbol(relations),
            None if self.occurrence is None else self.occurrence.symbol(relations),
        )

class RegexpReplaceFunction(SqlExpr):
    """REGEXP_REPLACE(source, regexp, replacement)."""
    def __init__(self, source, regexp, replacement):
        self.source = source
        self.regexp = regexp
        self.replacement = replacement
    def children(self):
        """Return the tokens and operands that make up this node, in order."""
        return [
            Token("REGEXP_REPLACE"),
            Token("("),
            self.source,
            Token(","),
            self.regexp,
            Token(","),
            self.replacement,
            Token(")"),
        ]
    def evaluate(self, bindings):
        """Replace every match of the pattern in the string form of ``source``.

        Raises:
            ValueError: if the pattern cannot be compiled.
        """
        text = self.source.evaluate(bindings)
        expr = self.regexp.evaluate(bindings)
        substitute = self.replacement.evaluate(bindings)
        try:
            compiled = re.compile(expr)
        except re.error:
            raise ValueError("REGEXP_REPLACE expression has an ivalid regex!")
        return compiled.sub(str(substitute), str(text))

    def symbol(self, relations):
        """Return a new RegexpReplaceFunction over the operands' ``symbol(relations)``."""
        source_symbol = self.source.symbol(relations)
        regexp_symbol = self.regexp.symbol(relations)
        replacement_symbol = self.replacement.symbol(relations)
        return RegexpReplaceFunction(source_symbol, regexp_symbol, replacement_symbol)

class ConcatFunction(SqlExpr):
    """CONCAT(expr, ...): join the string forms of all operands."""
    def __init__(self, expressions):
        self.expressions = expressions
    def children(self):
        """Return the tokens and the operand sequence that make up this node."""
        operands = Seq(self.expressions)
        return [Token("CONCAT"), Token("("), operands, Token(")")]
    def evaluate(self, bindings):
        """Return the concatenation, or None if any operand evaluates to None."""
        values = [expr.evaluate(bindings) for expr in self.expressions]
        for value in values:
            if value is None:
                return None
        return ''.join(map(str, values))
    def symbol(self, relations):
        """Return a new ConcatFunction over each operand's ``symbol(relations)``."""
        return ConcatFunction([expr.symbol(relations) for expr in self.expressions])

class CoalesceFunction(SqlExpr):
    """COALESCE(expr, ...): first operand value that is not None."""
    def __init__(self, expressions):
        self.expressions = expressions
    def children(self):
        """Return the tokens and the operand sequence that make up this node."""
        operands = Seq(self.expressions)
        return [Token("COALESCE"), Token("("), operands, Token(")")]
    def evaluate(self, bindings):
        """Evaluate every operand, then return the first non-None value (or None)."""
        # All operands are evaluated up front, before the first non-None is picked.
        values = [expr.evaluate(bindings) for expr in self.expressions]
        return next((value for value in values if value is not None), None)
    def symbol(self, relations):
        """Return a new CoalesceFunction over each operand's ``symbol(relations)``."""
        return CoalesceFunction([expr.symbol(relations) for expr in self.expressions])

class EncodeFunction(SqlExpr):
    """ENCODE(source, format): textual representation of the source's UTF-8 bytes."""
    def __init__(self, source, format):
        """Encode binary data into a textual representation

        Args:
            source: string (convertible to bytes)
            format: format of the text, one of ['base64', 'hex', 'escape']
        """
        self.source = source
        self.format = format
    def children(self):
        """Return the tokens and operands that make up this node, in order."""
        return [
            Token("ENCODE"),
            Token("("),
            self.source,
            Token(","),
            self.format,
            Token(")")
        ]
    def evaluate(self, bindings):
        """Encode the string form of ``source`` in the requested format.

        Returns:
            A ``str`` for every supported format.

        Raises:
            ValueError: if the format is not 'base64', 'hex' or 'escape'.
        """
        source = str(self.source.evaluate(bindings))
        format = str(self.format.evaluate(bindings))

        if format == "hex":
            return source.encode('utf-8').hex()
        elif format == "escape":
            return source.encode('utf-8').decode('utf-8')
        elif format == "base64":
            # b64encode returns bytes; decode so this branch yields text like
            # the other two (base64 output is pure ASCII).
            return base64.b64encode(source.encode('utf-8')).decode('ascii')
        else:
            raise ValueError(f"ENCODE format can only be one of ('base64', 'hex', 'escape'). Got {format}")
    def symbol(self, relations):
        """Return a new EncodeFunction over both operands' ``symbol(relations)``."""
        return EncodeFunction(
            self.source.symbol(relations),
            self.format.symbol(relations),
        )

class DecodeFunction(SqlExpr):
    """DECODE(source, format): hexadecimal text of the bytes encoded in source."""
    def __init__(self, source, format):
        """Decode binary data from its textual representation.

        Args:
            source: textual representation of the data
            format: format of the text, one of ['base64', 'hex', 'escape']
        """
        self.source = source
        self.format = format
    def children(self):
        """Return the tokens and operands that make up this node, in order."""
        opening = [Token("DECODE"), Token("(")]
        arguments = [self.source, Token(","), self.format]
        return opening + arguments + [Token(")")]
    def evaluate(self, bindings):
        """Return the decoded data as a hexadecimal string.

        Raises:
            ValueError: for malformed hex input, base64 input whose length
                is not a multiple of 4, or an unsupported format.
        """
        text = str(self.source.evaluate(bindings))
        fmt = str(self.format.evaluate(bindings))

        if fmt == "hex":
            # the input is already hex: validate it and hand it back unchanged
            if any(ch not in string.hexdigits for ch in text):
                raise ValueError("Source string is not a hexadecimal textual representation")
            if len(text) % 2:
                raise ValueError("Invalid hexadecimal source data: odd number of digits")
            return text

        if fmt == "escape":
            return text.encode('utf-8').hex()

        if fmt == "base64":
            if len(text) % 4:
                raise ValueError("Invalid base64 end sequence")
            return base64.b64decode(text.encode()).hex()

        raise ValueError(f"DECODE format can only be one of ('base64', 'hex', 'escape'). Got {fmt}")
    def symbol(self, relations):
        """Return a new DecodeFunction over both operands' ``symbol(relations)``."""
        source_symbol = self.source.symbol(relations)
        format_symbol = self.format.symbol(relations)
        return DecodeFunction(source_symbol, format_symbol)

class UnhexFunction(SqlExpr):
    """Hex-to-text function; the keyword emitted is the supplied token."""
    def __init__(self, source, token=Token("UNHEX")):
        """Convert a hexadecimal string into the text its bytes represent

        Args:
            source: hexadecimal textual representation
            token: keyword token placed before the parenthesis
                (defaults to ``UNHEX``)
        """
        self.source = source
        self.token = token
    def children(self):
        """Return the keyword token, parentheses and operand, in order."""
        return [
            self.token,
            Token("("),
            self.source,
            Token(")")
        ]
    def evaluate(self, bindings):
        """Decode the hexadecimal string form of ``source`` into text.

        Raises:
            ValueError: if the source contains non-hex characters or has an
                odd number of digits.
        """
        # source is a hexadecimal textual representation
        source = str(self.source.evaluate(bindings))

        if not all(c in string.hexdigits for c in source):
            raise ValueError("Source string is not a hexadecimal textual representation")

        if len(source)%2!=0:
            raise ValueError("Invalid hexadecimal source data: odd number of digits")

        return bytearray.fromhex(source).decode()
    def symbol(self, relations):
        """Return a new UnhexFunction over the operand's ``symbol(relations)``.

        The keyword token is carried over; previously a custom token was
        silently replaced by the default.
        """
        return UnhexFunction(
            self.source.symbol(relations),
            self.token,
        )

class HexFunction(SqlExpr):
    """Hex-encoding function node; the keyword emitted is the supplied token."""
    def __init__(self, source, token=Token("HEX")):
        """Convert bytes to a hexadecimal number

        Args:
            source: expression whose value is to be hex-encoded
            token: keyword token placed before the parenthesis
                (defaults to ``HEX``)
        """
        self.source = source
        self.token = token
    def children(self):
        """Return the keyword token, parentheses and operand, in order."""
        return [
            self.token,
            Token("("),
            self.source,
            Token(")")
        ]
    def evaluate(self, bindings):
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError("'evaluate' is not implemented for 'HEX'")

    def symbol(self, relations):
        """Return a new HexFunction over the operand's ``symbol(relations)``.

        The keyword token is carried over; previously a custom token was
        silently replaced by the default.
        """
        return HexFunction(
            self.source.symbol(relations),
            self.token,
        )


class RegexpSplitToTableFunction(SqlExpr):
    """REGEXP_SPLIT_TO_TABLE(source, pattern) node.

    Args:
        source: expression producing the text to split.
        pattern: expression producing the delimiter pattern.
    """
    def __init__(self, source, pattern):
        self.source = source
        self.pattern = pattern
    def children(self):
        """Return the tokens and operands that make up this node, in order."""
        return [
            Token("REGEXP_SPLIT_TO_TABLE"),
            Token("("),
            self.source,
            Token(","),
            self.pattern,
            Token(")")
        ]
    def evaluate(self, bindings):
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError("'evaluate' is not implemented for 'REGEXP_SPLIT_TO_TABLE'")

    def symbol(self, relations):
        """Return a new RegexpSplitToTableFunction over both operands' ``symbol(relations)``."""
        # The constructor requires both arguments; passing only the source
        # raised TypeError on every call.
        return RegexpSplitToTableFunction(
            self.source.symbol(relations),
            self.pattern.symbol(relations),
        )