# Repository URL to install this package:
# Version: 2.1.2+cpu
import dataclasses
import itertools
import sympy
from sympy.logic.boolalg import BooleanAtom, Boolean as SympyBoolean
import operator
import math
import logging
import torch
from typing import Union, Dict, Optional
from torch._prims_common import dtype_to_type
from .interp import sympy_interp
# Public API of this module.
__all__ = [
    "ValueRanges",
    "ValueRangeAnalysis",
    "bound_sympy",
]
# Module logger (used to report unhandled ops in ValueRangeAnalysis).
log = logging.getLogger(__name__)
class ValueRangeError(RuntimeError):
    """Raised when a ValueRanges is constructed with ``lower`` not <= ``upper``."""
def simple_sympify(e):
    """Convert a Python scalar or a constant SymPy value into a SymPy atom.

    Like ``sympify``, but supports fewer inputs, and it also checks that
    SymPy expressions passed in directly have no free variables.

    - ``bool`` -> ``sympy.true`` / ``sympy.false`` (checked before ``int``,
      because ``bool`` is a subclass of ``int``)
    - ``int`` -> ``sympy.Integer``
    - ``float`` -> ``sympy.Float``, with +/-inf mapped to ``sympy.oo`` / ``-sympy.oo``
    - constant ``sympy.Expr`` (not NaN) and ``BooleanAtom`` -> returned unchanged

    Raises:
        AssertionError: for any other type, a non-constant expression, or NaN.
    """
    if isinstance(e, bool):
        return sympy.true if e else sympy.false
    if isinstance(e, int):
        return sympy.Integer(e)
    if isinstance(e, float):
        if not math.isinf(e):
            return sympy.Float(e)
        # Infinity is special: it is also used to bracket integer ranges,
        # so it maps to SymPy's oo rather than to a Float.
        return sympy.oo if e > 0 else -sympy.oo
    if isinstance(e, sympy.Expr):
        assert e.is_constant(), e
        # NaN can arise from things like 0 * oo. The operator should handle
        # that case itself, because NaN is sometimes the wrong answer (e.g. for
        # ints, [-oo, oo] multiplied by [0, 0] should give zero).
        assert e != sympy.nan
        return e
    if isinstance(e, BooleanAtom):
        return e
    raise AssertionError(f"not simple sympy type {type(e)}: {e}")
def sympy_generic_le(lower, upper):
    """Return whether ``lower <= upper`` for SymPy atoms, including booleans.

    Numeric expressions are compared with ``<=``. Booleans use the order
    ``False < True``, so the only failing pair is ``lower=True, upper=False``.
    """
    if not isinstance(lower, sympy.Expr):
        assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean)
        # De Morgan form of `not (lower and not upper)`.
        return (not lower) or bool(upper)
    assert isinstance(upper, sympy.Expr)
    return lower <= upper
@dataclasses.dataclass(frozen=True)
class ValueRanges:
    """An immutable closed interval ``[lower, upper]`` of constant SymPy values.

    A range is either numeric (SymPy numbers, possibly ``-oo``/``oo``) or
    boolean (SymPy booleans ordered as ``False < True``). ``is_bool`` records
    which kind it is; both bounds must be of the same kind.
    """

    # Although the type signature here suggests you can pass any
    # sympy expression, in practice the analysis here only works
    # with constant sympy expressions
    lower: Union[sympy.Expr, SympyBoolean]
    upper: Union[sympy.Expr, SympyBoolean]
    is_bool: bool

    def __init__(self, lower, upper):
        """Build ``[lower, upper]`` after converting both bounds with ``simple_sympify``.

        Raises:
            ValueRangeError: if ``lower`` is not <= ``upper``.
        """
        lower = simple_sympify(lower)
        upper = simple_sympify(upper)
        # TODO: when the bounds have free variables, this may be
        # nontrivial to actually verify
        if not sympy_generic_le(lower, upper):
            raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
        # Because this is a frozen class
        object.__setattr__(self, "lower", lower)
        object.__setattr__(self, "upper", upper)
        object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean))
        assert isinstance(upper, SympyBoolean) == self.is_bool

    def __contains__(self, x):
        """Return whether the scalar ``x`` lies in the closed interval."""
        x = simple_sympify(x)
        return sympy_generic_le(self.lower, x) and sympy_generic_le(x, self.upper)

    def tighten(self, other) -> "ValueRanges":
        """Given two ValueRanges, returns their intersection"""
        return self & other

    def __and__(self, other) -> "ValueRanges":
        """Intersection. ``unknown()`` acts as the identity.

        Raises ValueRangeError if the two ranges are disjoint.
        """
        if other == ValueRanges.unknown():
            return self
        if self == ValueRanges.unknown():
            return other
        assert self.is_bool == other.is_bool, (self, other)
        if self.is_bool:
            # For booleans, max is Or and min is And.
            range = ValueRanges(sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper))
        else:
            range = ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper))
        return range

    def __or__(self, other) -> "ValueRanges":
        """Union (convex hull). ``unknown()`` absorbs everything."""
        if ValueRanges.unknown() in (self, other):
            return ValueRanges.unknown()
        assert self.is_bool == other.is_bool, (self, other)
        if self.is_bool:
            # For booleans, min is And and max is Or.
            range = ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper))
        else:
            range = ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper))
        return range

    def is_singleton(self) -> bool:
        """Return whether the range contains exactly one value."""
        return self.lower == self.upper

    # TODO: this doesn't work with bools but arguably it should
    @classmethod
    def unknown(cls):
        """The numeric range ``[-oo, oo]``, meaning nothing is known."""
        return cls(-sympy.oo, sympy.oo)

    @classmethod
    def wrap(cls, arg):
        """Return ``arg`` if it is already a ValueRanges, else the singleton ``[arg, arg]``."""
        if isinstance(arg, ValueRanges):
            return arg
        return ValueRanges(arg, arg)

    @classmethod
    def increasing_map(cls, x, fn):
        """Increasing: x <= y => f(x) <= f(y)"""
        x = cls.wrap(x)
        return ValueRanges(fn(x.lower), fn(x.upper))

    @classmethod
    def decreasing_map(cls, x, fn):
        """Decreasing: x <= y => f(x) >= f(y)"""
        x = cls.wrap(x)
        return ValueRanges(fn(x.upper), fn(x.lower))

    @classmethod
    def monotone_map(cls, x, fn):
        """It's increasing or decreasing"""
        x = cls.wrap(x)
        l = fn(x.lower)
        u = fn(x.upper)
        return ValueRanges(min(l, u), max(l, u))

    @classmethod
    def convex_min_zero_map(cls, x, fn):
        """fn is convex and has a minimum at 0 (where it is assumed to equal 0).

        If 0 is in the range, the image is ``[0, max(fn(lower), fn(upper))]``;
        otherwise fn is monotone on the range.
        """
        x = cls.wrap(x)
        if 0 in x:
            return ValueRanges(0, max(fn(x.lower), fn(x.upper)))
        else:
            return cls.monotone_map(x, fn)

    @classmethod
    def coordinatewise_increasing_map(cls, x, y, fn):
        """
        Increasing on each coordinate. Mathematically:
        For every 1 <= i <= n and x_i <= y_i we have that
        f(x_1, ..., x_i, ..., x_n) <= f(x_1, ..., y_i, ..., x_n)
        """
        x, y = cls.wrap(x), cls.wrap(y)
        return ValueRanges(
            fn(x.lower, y.lower),
            fn(x.upper, y.upper),
        )

    @classmethod
    def coordinatewise_monotone_map(cls, x, y, fn):
        """It's increasing or decreasing on each coordinate.

        The extrema are then attained at corners, so fn is evaluated on all
        four combinations of bounds.
        """
        x, y = cls.wrap(x), cls.wrap(y)
        products = [
            fn(a, b)
            for a, b in itertools.product([x.lower, x.upper], [y.lower, y.upper])
        ]
        return ValueRanges(min(products), max(products))
class SymPyValueRangeAnalysis:
    """
    It gives bounds on a SymPy operator given bounds on its arguments
    See the function `bound_sympy` for a function that applies this logic to a full SymPy expression

    Handlers accept ValueRanges or raw constants (wrapped into singleton
    ranges) and return a ValueRanges enclosing every possible result;
    ``ValueRanges.unknown()`` is returned when no useful bound is derived.
    """

    @staticmethod
    def constant(value, dtype):
        """Singleton range for the constant ``value`` of torch ``dtype``.

        Python scalars are cast to the Python type matching ``dtype``. SymPy
        values are only type-checked, on a best-effort basis. NaN yields
        ``unknown()``.
        """
        # NB: value is NOT a sympy expression, it's a constant!
        is_python = isinstance(value, (int, float, bool))
        assert is_python or isinstance(value, (BooleanAtom, sympy.Integer, sympy.Number))

        # using nan makes subsequent computation throw, and for the purposes of optimization
        # returning [-inf, inf] is equivalent to giving up.
        # Booleans are never NaN. They are skipped because math.isnan cannot
        # convert a SymPy BooleanAtom to float and would raise TypeError.
        if not isinstance(value, (bool, BooleanAtom)) and math.isnan(value):
            return ValueRanges.unknown()

        if is_python:
            type_ = dtype_to_type(dtype)
            value = type_(value)
        else:
            # We do a type check on a best-effort basis
            # We don't want to force a cast to sympy.Float if the value is Rational to avoid losing precision
            if dtype == torch.bool:
                assert isinstance(value, BooleanAtom)
            elif dtype.is_floating_point:
                assert not value.is_finite or value.is_real
            else:
                # dtype is intXX
                assert value.is_integer

        return ValueRanges.wrap(value)

    @staticmethod
    def not_(a):
        """Logical negation of a boolean range."""
        a = ValueRanges.wrap(a)
        assert a.is_bool
        return ValueRanges.decreasing_map(a, sympy.Not)

    @staticmethod
    def or_(a, b):
        """Logical or of boolean ranges."""
        return ValueRanges.coordinatewise_increasing_map(a, b, sympy.Or)

    @staticmethod
    def and_(a, b):
        """Logical and of boolean ranges."""
        return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And)

    @staticmethod
    def eq(a, b):
        """Range of ``a == b``. Works for both numeric and boolean ranges."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if a.is_singleton() and b.is_singleton() and a.lower == b.lower:
            return ValueRanges.wrap(sympy.true)
        # Ranges are disjoint. Tested with sympy_generic_le because SymPy
        # booleans do not support `>`.
        elif not sympy_generic_le(a.lower, b.upper) or not sympy_generic_le(b.lower, a.upper):
            return ValueRanges.wrap(sympy.false)
        return ValueRanges(sympy.false, sympy.true)

    @classmethod
    def ne(cls, a, b):
        """Range of ``a != b``."""
        return cls.not_(cls.eq(a, b))

    @classmethod
    def lt(cls, a, b):
        """Range of ``a < b``. For booleans, ``a < b`` iff ``not a and b``."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        assert a.is_bool == b.is_bool
        if a.is_bool:
            return cls.and_(cls.not_(a), b)
        else:
            if a.upper < b.lower:
                return ValueRanges.wrap(sympy.true)
            elif a.lower >= b.upper:
                return ValueRanges.wrap(sympy.false)
            return ValueRanges(sympy.false, sympy.true)

    @classmethod
    def gt(cls, a, b):
        """Range of ``a > b``."""
        return cls.lt(b, a)

    @classmethod
    def le(cls, a, b):
        """Range of ``a <= b``."""
        return cls.not_(cls.gt(a, b))

    @classmethod
    def ge(cls, a, b):
        """Range of ``a >= b``."""
        return cls.not_(cls.lt(a, b))

    @staticmethod
    def add(a, b):
        """Range of ``a + b``."""
        return ValueRanges.coordinatewise_increasing_map(a, b, operator.add)

    @classmethod
    def mul(cls, a, b):
        """Range of ``a * b``. For booleans this is logical and."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)

        assert a.is_bool == b.is_bool
        if a.is_bool:
            return cls.and_(a, b)

        def safe_mul(a, b):
            # Make unknown() * wrap(0) == wrap(0)
            if a == 0:
                return a
            elif b == 0:
                return b
            else:
                return a * b

        return ValueRanges.coordinatewise_monotone_map(a, b, safe_mul)

    @classmethod
    def div(cls, a, b):
        """Alias of ``truediv``."""
        return cls.truediv(a, b)

    @staticmethod
    def truediv(a, b):
        """Range of ``a / b``. Unknown if b may be 0 or if both may be infinite."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if 0 in b or ((-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)):
            return ValueRanges.unknown()
        else:
            return ValueRanges.coordinatewise_monotone_map(a, b, operator.truediv)

    @staticmethod
    def floordiv(a, b):
        """Range of ``a // b``. Unknown if b may be 0 or if both may be infinite."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if 0 in b or ((-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)):
            return ValueRanges.unknown()
        else:
            return ValueRanges.coordinatewise_monotone_map(a, b, operator.floordiv)

    @staticmethod
    def mod(x, y):
        """Range of ``x % y``.

        Singletons are computed exactly. Otherwise, for a strictly positive
        divisor the result lies in ``[0, y.upper]``; any other divisor gives up.
        """
        x = ValueRanges.wrap(x)
        y = ValueRanges.wrap(y)
        if x.is_singleton() and y.is_singleton() and y.lower != 0:
            return ValueRanges.wrap(x.lower % y.lower)
        if y.lower <= 0:
            return ValueRanges.unknown()
        return ValueRanges(0, y.upper)

    @classmethod
    def modular_indexing(cls, a, b, c):
        """Range of ``(a // b) % c``."""
        return cls.mod(cls.floordiv(a, b), c)

    @classmethod
    def pow(cls, a, b):
        """Range of ``a ** b``. Only a singleton exponent ``b`` is supported."""
        def is_integer(val):
            return isinstance(val, int) or (
                hasattr(val, "is_integer") and val.is_integer
            )

        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        # Not implemented yet. It's a bit tricky
        # If you want to implement it, compute the partial derivatives of a ** b
        # and check the ranges where the function is increasing / decreasing
        # Another non-tight way of doing this is noting that for a > 0, a ** b == exp(b * log(a))
        # If this second option is implemented, be careful about the types and possible infinities here and there.
        if not b.is_singleton():
            return ValueRanges.unknown()

        b = b.lower
        if a.is_singleton():
            a = a.lower
            r = a ** b
            if not r.is_finite:
                return ValueRanges.unknown()
            return ValueRanges.wrap(r)

        if b == 0:
            if not a.lower.is_finite:
                return ValueRanges.unknown()
            # x ** 0 == 1. Return a Float only when the base is a Float;
            # `is_real` is also true for Integers, so it cannot tell them apart.
            type_ = sympy.Float if a.lower.is_Float else sympy.Integer
            return ValueRanges.wrap(type_(1))

        if b < 0:
            a = cls.reciprocal(a)
            b = -b

        if a == ValueRanges.unknown():
            return ValueRanges.unknown()

        # Here b > 0
        if not is_integer(b):
            # If the base is positive, then we're good, otherwise nothing's defined
            if a.lower >= 0:
                return ValueRanges.increasing_map(a, lambda x: x ** b)
            else:
                return ValueRanges.unknown()
        else:
            # b > 0 integer
            if b % 2 == 0:
                # x^n where n is even
                return ValueRanges.convex_min_zero_map(a, lambda x: x ** b)
            else:
                # x^n where n is odd
                return ValueRanges.increasing_map(a, lambda x: x ** b)

    @staticmethod
    def reciprocal(x):
        """ Needed as it's used in pow, but it won't appear on a SymPy expression """
        x = ValueRanges.wrap(x)
        if 0 in x:
            return ValueRanges.unknown()
        else:
            return ValueRanges.decreasing_map(x, lambda y: 1 / y)

    @staticmethod
    def abs(x):
        """Range of ``abs(x)``."""
        return ValueRanges.convex_min_zero_map(x, abs)

    @staticmethod
    def exp(x):
        """Range of ``exp(x)``."""
        return ValueRanges.increasing_map(x, sympy.functions.elementary.exponential.exp)

    @staticmethod
    def log(x):
        """Range of ``log(x)``. Unknown unless x is strictly positive."""
        x = ValueRanges.wrap(x)
        if x.lower <= 0:
            return ValueRanges.unknown()
        return ValueRanges.increasing_map(x, sympy.log)

    @classmethod
    def minimum(cls, a, b):
        """Range of ``min(a, b)``."""
        return cls.min_or_max(a, b, sympy.Min)

    @classmethod
    def maximum(cls, a, b):
        """Range of ``max(a, b)``."""
        return cls.min_or_max(a, b, sympy.Max)

    @staticmethod
    def min_or_max(a, b, fn):
        """Apply the coordinatewise-increasing ``fn`` (Min/Max) after upcasting bounds."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)

        # Performs upcasting first
        def fn_(x, y):
            # Poorman's version of upcasting in Sympy
            # Inf is not a float...
            if x.is_Integer and y.is_Integer:
                result_type = sympy.Integer
            elif x.is_rational and y.is_rational:
                result_type = sympy.Rational
            else:
                assert x.is_real or not x.is_finite or y.is_real or not y.is_finite
                result_type = sympy.Float
            return fn(result_type(x), result_type(y))

        return ValueRanges.coordinatewise_increasing_map(a, b, fn_)

    @classmethod
    def floor(cls, x):
        """Range of ``floor(x)``."""
        return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor)

    @classmethod
    def ceil(cls, x):
        """Range of ``ceil(x)``."""
        return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.ceiling)

    # It's used in some models on symints
    @staticmethod
    def sqrt(x):
        """Range of ``sqrt(x)``. Unknown if x may be negative."""
        x = ValueRanges.wrap(x)
        if x.lower < 0:
            return ValueRanges.unknown()
        return ValueRanges.increasing_map(x, sympy.sqrt)

    @staticmethod
    def where(a, b, c):
        """Range of ``where(a, b, c)``: the union of the ranges of b and c.

        ``a`` must be a boolean range (raw constants are wrapped like every
        other argument); it is not used to narrow the result.
        """
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        c = ValueRanges.wrap(c)
        assert a.is_bool
        assert b.is_bool == c.is_bool
        if b.is_bool:
            return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
        else:
            return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))
class ValueRangeAnalysis(SymPyValueRangeAnalysis):
    """Extends SymPyValueRangeAnalysis with handlers for tensor-level ops
    (load/store/reduction/index_expr/to_dtype, ...).

    Boolean ops listed in ``__init__`` return ``[False, True]``. Any op
    without a handler falls back to ``default_handler`` (``unknown()``) via
    ``__getattr__``.
    """

    def __init__(self):
        self.name = "ValueRangeAnalysis"
        boolean_operators = (
            "xor",
            "logical_and",
            "logical_or",
            "logical_not",
        )
        for op in boolean_operators:
            setattr(self, op, self.bool_handler)

    @staticmethod
    def bool_handler(*args, **kwargs):
        # just assuming bools can have both values
        return ValueRanges(sympy.false, sympy.true)  # type: ignore[arg-type]

    @staticmethod
    def default_handler(*args, **kwargs):
        # many ops are unlikely to show up in optimizable indexing compute,
        # so we dont have full coverage
        return ValueRanges.unknown()

    def load(self, name: str, index: sympy.Expr):
        """Loaded values are unconstrained."""
        return ValueRanges.unknown()

    def store(self, name, index, value, mode=None):
        """Stores produce no value."""
        return

    def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
        """Reduction results are unconstrained."""
        return ValueRanges.unknown()

    def index_expr(self, index, dtype):
        """``index`` must already be a ValueRanges; it is returned as-is."""
        assert isinstance(index, ValueRanges)
        return index

    @staticmethod
    def to_dtype(x, dtype: torch.dtype):
        """Range of ``x`` after conversion to ``dtype``."""
        x = ValueRanges.wrap(x)

        if dtype == torch.bool:
            if x.is_bool:
                # bool -> bool is the identity. The numeric checks below are
                # wrong for booleans (sympy.false != 0 is True) or would
                # assert inside `0 in x`.
                return x
            if 0 not in x:
                return ValueRanges.wrap(sympy.true)
            if x.is_singleton():
                # The only possible value is zero.
                return ValueRanges.wrap(sympy.false)
            return ValueRanges(sympy.false, sympy.true)

        def cast(x, dtype):
            # dtype is int or float
            if dtype.is_floating_point:
                return sympy.Float(x)
            else:
                try:
                    return sympy.Integer(x)
                except TypeError:
                    # inf cannot be cast to Integer
                    return x

        if x.is_bool:
            if x.is_singleton():
                val = 1 if x.lower else 0
                return ValueRanges.wrap(cast(val, dtype))
            else:
                return ValueRanges(cast(0, dtype), cast(1, dtype))
        else:
            # int to float or float to int
            return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))

    @staticmethod
    def square(x):
        """Range of ``x * x``."""
        return ValueRanges.convex_min_zero_map(x, lambda y: y * y)

    @staticmethod
    def neg(x):
        """Range of ``-x``."""
        return ValueRanges.decreasing_map(x, operator.neg)

    @classmethod
    def truncdiv(cls, a, b):
        """Range of ``a / b`` truncated toward zero (via ``sympy.Integer``)."""
        x = cls.truediv(a, b)
        if x == ValueRanges.unknown():
            return x

        def trunc(x):
            return sympy.Integer(x) if x.is_finite else x

        return ValueRanges.increasing_map(x, trunc)

    @classmethod
    def sub(cls, a, b):
        """Range of ``a - b``."""
        return cls.add(a, cls.neg(b))

    def __getattr__(self, name):
        log.debug("unhandled ValueRange op %s", name)
        return self.default_handler
def bound_sympy(expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None) -> ValueRanges:
    """Return a ValueRanges enclosing every value ``expr`` can take.

    Args:
        expr: SymPy expression to bound.
        ranges: known ranges for (some of) the free symbols of ``expr``. These
            take precedence over ranges from an active tracing context's shape
            environment. The mapping is not mutated.

    Free symbols with no known range must be integers. They are bounded by
    their SymPy assumptions: [1, oo] if positive, [0, oo] if nonnegative,
    else [-oo, oo].
    """
    if isinstance(expr, sympy.Number):
        return ValueRanges.wrap(expr)

    ranges = ranges or {}

    # If there's a tracing context, augment available constrained ranges.
    # Explicitly passed ranges win over the shape environment's ranges for
    # the same symbol, so the context only fills in missing symbols.
    context = torch._guards.TracingContext.get()
    if context and context.fake_mode.shape_env:
        ranges = {**context.fake_mode.shape_env.var_to_range, **ranges}

    unbounded_vars = expr.free_symbols - ranges.keys()
    if unbounded_vars:
        # Give some bounds to the free variables via their SymPy assumptions
        # TODO A better way of doing this would be to assign them a range upon creation, as
        #      size variables can come with a lower bound of 2, as we specialise on 0 and 1
        unbounded_ranges: Dict[sympy.Symbol, ValueRanges] = {}
        for s in unbounded_vars:
            assert s.is_integer  # type: ignore[attr-defined]
            if s.is_positive:  # type: ignore[attr-defined]
                lower = 1
            elif s.is_nonnegative:  # type: ignore[attr-defined]
                lower = 0
            else:
                lower = -math.inf  # type: ignore[assignment]
            unbounded_ranges[s] = ValueRanges(lower, math.inf)  # type: ignore[index]
        ranges = {**ranges, **unbounded_ranges}

    return sympy_interp(SymPyValueRangeAnalysis, ranges, expr)