Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
from snsql._ast.tokens import *

"""
    SQL-specific expressions
"""


class AllColumns(SqlExpr):
    """Wildcard column reference in a SELECT list: ``*`` or ``Table.*``.

    Attributes:
        table: optional table/alias qualifier; ``None`` for a bare ``*``.
    """

    def __init__(self, table=None):
        self.table = table

    def __str__(self):
        if self.table is None:
            return "*"
        return self.table + "." + "*"

    def __hash__(self):
        # Hash on the rendered SQL text, e.g. "*" or "t.*".
        return hash(str(self))

    def symbol_name(self):
        return 'star'

    def all_symbols(self, relations):
        """Expand the wildcard into the symbols of every matching relation.

        Args:
            relations: candidate relations; each one whose ``alias_match``
                accepts this wildcard's text contributes its ``all_symbols``.

        Returns:
            The flattened list of symbols from all matching relations.

        Raises:
            ValueError: if no relation matches.
        """
        text = str(self)
        matches = []
        for relation in relations:
            if relation.alias_match(text):
                matches.append(relation.all_symbols(self))
        if not matches:
            raise ValueError("Column cannot be found " + text)
        return flatten(matches)

    @property
    def is_key_count(self):
        # Only a key count when metadata specifies row_privacy, which is
        # applied by load_symbols; a raw wildcard never is.
        return False

    @property
    def is_count(self):
        # A wildcard is always reported as a count expression.
        return True


class AggFunction(SqlExpr):
    """Aggregate function call such as ``SUM(x)``, ``COUNT(DISTINCT x)`` or ``AVG(x)``.

    Attributes:
        name: function name; compared against upper-case literals below
            (presumably upper-cased by the parser — confirm).
        quantifier: optional quantifier text placed before the argument, or ``None``.
        expression: the aggregated argument expression.
    """

    def __init__(self, name, quantifier, expression):
        self.name = name
        self.quantifier = quantifier
        self.expression = expression

    def symbol_name(self):
        """Derive an identifier like ``sum_x`` or ``count_distinct_x``."""
        parts = [self.name.lower(), "_"]
        if self.quantifier is not None:
            parts.append(self.quantifier.lower() + "_")
        return self.prepend("".join(parts), self.expression.symbol_name())

    def prepend(self, prefix, value):
        """Prefix ``value`` while honoring identifier escaping.

        Unquoted (or empty) values are simply concatenated. Values quoted with
        ``"``, backticks or ``[...]`` are normalized to double quotes; a single
        quoted identifier keeps its quotes around ``prefix + name``, anything
        more complex is flattened to ``prefix_x_<text>`` with quotes and spaces
        removed.
        """
        if value == "" or value[0] not in ('"', "`", "["):
            return prefix + value
        normalized = value
        for quote_char in ("`", "[", "]"):
            normalized = normalized.replace(quote_char, '"')
        pieces = normalized.split('"')
        if len(pieces) != 3:
            return prefix + "_x_" + normalized.replace('"', "").replace(" ", "")
        return '"' + prefix + pieces[1] + '"'

    def is_aggregate(self):
        return self.name in ("SUM", "COUNT", "MIN", "MAX", "AVG", "VAR")

    def symbol(self, relations):
        return AggFunction(self.name, self.quantifier, self.expression.symbol(relations))

    def type(self):
        """Result type name of the aggregate."""
        name = self.name
        if name in ("SUM", "MIN", "MAX"):
            # These aggregates keep the argument's type.
            return self.expression.type()
        if name == "COUNT":
            return "int"
        if name in ("VAR", "VARIANCE", "AVG", "STD", "STDDEV", "STDEV"):
            return "float"
        return "unknown"

    def sensitivity(self):
        """Sensitivity of the aggregate, or ``None`` when not defined here."""
        if self.name == "COUNT":
            return 1
        if self.name in ("SUM", "AVG", "MIN", "MAX"):
            return self.expression.sensitivity()
        return None

    def __str__(self):
        quantifier = "" if self.quantifier is None else self.quantifier + " "
        return self.name + "( " + quantifier + str(self.expression) + " )"

    def children(self):
        return [self.name, Token("("), self.quantifier, self.expression, Token(")")]

    def evaluate(self, bindings):
        # NOTE: evaluates the argument only; aggregation semantics undecided.
        return self.expression.evaluate(bindings)

    @property
    def is_key_count(self):
        if self.name in ("SUM", "COUNT"):
            return self.expression.is_key_count
        return False

    @property
    def is_count(self):
        if self.name == "COUNT":
            return True
        if self.name == "SUM":
            return self.expression.is_count
        return False


class RankingFunction(SqlExpr):
    """Window function with an empty argument list, rendered as ``name() OVER (...)``.

    Attributes:
        name: function name emitted verbatim before ``()``.
        over: the window (``OVER``) clause node.
    """

    def __init__(self, name, over):
        self.name = name
        self.over = over

    def children(self):
        open_paren, close_paren = Token("("), Token(")")
        return [self.name, open_paren, close_paren, self.over]

    def symbol(self, relations):
        resolved_over = self.over.symbol(relations)
        return RankingFunction(self.name, resolved_over)


class LeadFunction(SqlExpr):
    """``LEAD(source [, offset [, default]]) OVER (...)`` window function.

    Grammar:
        LEAD '(' sourceExpr=expression (',' offset=expression (',' default=expression)? )? ')' overClause

    ``offset`` and ``default`` are optional and may be ``None``.
    """

    def __init__(self, source, offset, default, over):
        self.source = source
        self.offset = offset
        self.default = default
        self.over = over

    def children(self):
        tokens = [Token("LEAD"), Token("("), self.source]
        # Each optional argument contributes ", <expr>" only when present.
        for optional in (self.offset, self.default):
            if optional is not None:
                tokens.extend([Token(","), optional])
        tokens.extend([Token(")"), self.over])
        return tokens

    def symbol(self, relations):
        def resolve(node):
            return None if node is None else node.symbol(relations)

        return LeadFunction(
            self.source.symbol(relations),
            resolve(self.offset),
            resolve(self.default),
            self.over.symbol(relations),
        )

class GreatestFunction(SqlExpr):
    """``GREATEST(e1, e2, ...)``: the largest non-``None`` argument value."""

    def __init__(self, expressions):
        self.expressions = expressions

    def children(self):
        return [Token("GREATEST"), Token("("), Seq(self.expressions), Token(")")]

    def evaluate(self, bindings):
        """Return the maximum of the non-``None`` argument values.

        Returns ``None`` when every argument evaluates to ``None`` or there
        are no arguments.
        """
        results = [expr.evaluate(bindings) for expr in self.expressions]
        present = [value for value in results if value is not None]
        return max(present) if present else None

    def symbol(self, relations):
        return GreatestFunction([expr.symbol(relations) for expr in self.expressions])


class LeastFunction(SqlExpr):
    """``LEAST(e1, e2, ...)``: the smallest non-``None`` argument value."""

    def __init__(self, expressions):
        self.expressions = expressions

    def children(self):
        return [Token("LEAST"), Token("("), Seq(self.expressions), Token(")")]

    def evaluate(self, bindings):
        """Return the minimum of the non-``None`` argument values.

        Returns ``None`` when every argument evaluates to ``None`` or there
        are no arguments.
        """
        results = [expr.evaluate(bindings) for expr in self.expressions]
        present = [value for value in results if value is not None]
        return min(present) if present else None

    def symbol(self, relations):
        return LeastFunction([expr.symbol(relations) for expr in self.expressions])

class NotFuction(SqlExpr):
    """``NOT(<boolean expression>)`` written in function-call form.

    Similar to LogicalNot, but usable inside a SELECT list. The class name's
    spelling is kept unchanged since other modules may refer to it.
    """

    def __init__(self, boolean_expression):
        self.boolean_expression = boolean_expression

    def children(self):
        return [
            Token("NOT"),
            Token("("),
            self.boolean_expression,
            Token(")"),
        ]

    def symbol(self, relations):
        inner = self.boolean_expression.symbol(relations)
        return NotFuction(inner)


class PartitionByClause(SqlExpr):
    """``PARTITION BY e1, e2, ...`` part of a window specification."""

    def __init__(self, expressions):
        self.expressions = expressions

    def children(self):
        return [Token("PARTITION"), Token("BY"), Seq(self.expressions)]

    def symbol(self, relations):
        # None entries are dropped rather than resolved.
        resolved = []
        for expr in self.expressions:
            if expr is None:
                continue
            resolved.append(expr.symbol(relations))
        return PartitionByClause(resolved)


class OverClause(SqlExpr):
    """Window specification ``OVER ( [partition] [order] )``.

    Attributes:
        partition: partitioning clause node, or ``None`` when absent.
        order: ordering clause node, or ``None`` when absent.
    """

    def __init__(self, partition, order):
        self.partition = partition
        self.order = order

    def children(self):
        pre = [Token("OVER"), Token("(")]
        post = [Token(")")]
        partition = [] if self.partition is None else [self.partition]
        order = [] if self.order is None else [self.order]
        return pre + partition + order + post

    def symbol(self, relations):
        """Resolve symbols in both parts, keeping absent parts as ``None``.

        Either part may legitimately be missing (``children`` already handles
        that), so each one is guarded before calling ``symbol`` on it;
        otherwise a window with only an ordering or only a partitioning
        clause would raise AttributeError here.
        """
        partition = None if self.partition is None else self.partition.symbol(relations)
        order = None if self.order is None else self.order.symbol(relations)
        return OverClause(partition, order)


class GroupingExpression(SqlExpr):
    """Wrapper for a single expression listed in a GROUP BY clause."""

    def __init__(self, expression):
        self.expression = expression

    def children(self):
        return [self.expression]

    def type(self):
        # Grouping reports the wrapped expression's type unchanged.
        inner = self.expression
        return inner.type()

    def symbol(self, relations):
        resolved = self.expression.symbol(relations)
        return GroupingExpression(resolved)


class SortItem(SqlExpr):
    """One entry of a query's sort list: an expression plus an optional order.

    Attributes:
        expression: the expression to sort by.
        order: optional order node, or ``None``.
    """

    def __init__(self, expression, order):
        self.expression = expression
        self.order = order

    def type(self):
        return self.expression.type()

    def children(self):
        return [self.expression, self.order]

    def symbol(self, relations):
        resolved_expression = self.expression.symbol(relations)
        if self.order is None:
            resolved_order = None
        else:
            resolved_order = self.order.symbol(relations)
        return SortItem(resolved_expression, resolved_order)


class BooleanJoinCriteria(SqlExpr):
    """Join condition given as ``ON <boolean expression>``."""

    def __init__(self, expression):
        self.expression = expression

    def children(self):
        on_keyword = Token("ON")
        return [on_keyword, self.expression]

    def symbol(self, relations):
        resolved = self.expression.symbol(relations)
        return BooleanJoinCriteria(resolved)


class UsingJoinCriteria(SqlExpr):
    """Join condition given as ``USING ( id1, id2, ... )``."""

    def __init__(self, identifiers):
        # Stored wrapped in Seq.
        self.identifiers = Seq(identifiers)

    def children(self):
        return [
            Token("USING"),
            Token("("),
            self.identifiers,
            Token(")"),
        ]

    def symbol(self, relations):
        # NOTE(review): the result of Seq.symbol is passed back through
        # __init__, which wraps it in Seq again — confirm nesting is intended.
        resolved = self.identifiers.symbol(relations)
        return UsingJoinCriteria(resolved)

class UserDefinedFunction(SqlExpr):
    """Generic function call rendered as ``NAME(arg, ...)``.

    The function name is upper-cased on construction; the arguments are
    stored wrapped in Seq.
    """

    def __init__(self, name, expressions):
        self.name = name.upper()
        self.expressions = Seq(expressions)

    def children(self):
        return [Token(self.name), Token("("), self.expressions, Token(")")]

    def symbol(self, relations):
        # None arguments are dropped rather than resolved.
        resolved = []
        for expr in self.expressions:
            if expr is not None:
                resolved.append(expr.symbol(relations))
        return UserDefinedFunction(self.name, resolved)