Repository URL to install this package:
|
Version:
1.11.0 ▾
|
ccc-model-manager
/
lib
/
python3.9
/
site-packages
/
torch
/
fx
/
experimental
/
symbolic_shapes.py
|
|---|
import torch
import torch.utils._pytree as pytree
from typing import Set, Dict, List, Type, Optional, cast
import operator
import math
import functools
from functools import lru_cache, partial
import traceback
import collections
import textwrap
from torch._subclasses.meta_utils import MetaConverter
try:
import sympy # type: ignore[import]
HAS_SYMPY = True
except ImportError:
HAS_SYMPY = False
# Shorthand for the ATen operator namespace
aten = torch.ops.aten # type: ignore[has-type]
# Names exported by a star-import of this module
__all__ = [
    "has_symbolic_sizes_strides", "create_contiguous", "PySymInt", "ShapeEnv",
    "SymDispatchMode", "PySymFloat", "sym_float", "FloorDiv"
]
# The SymDispatchMode that is currently active, or None when no mode is
# installed. Written by SymDispatchMode.__enter__/__exit__ and temporarily
# swapped to the mode's `inner` by _handle_sym_dispatch.
SYM_FUNCTION_MODE = None
# We don't bother with the metaclass as all of the dispatching logic happens
# entirely from Python
#
# Didn't bother with ancestors for now, unlikely to have multiple modes for
# symints right now
# SymDispatchMode gets invoked whenever an operation is processed on
# a PySymInt. When this occurs, you get called at __sym_dispatch__
# with the operation in question. This is symmetric to TorchDispatchMode
# but with some caveats:
#
# - In TorchDispatchMode, you get the same arguments as what a user
# invoked your API with; e.g., if you call torch.ops.aten.foo(a, b),
# you get (a, b) as args to your call. In SymDispatchMode, if
# you call a + b (where a and b are SymInts), you will get
# (a.get_pyobj(), b.get_pyobj()) as your args (these are PySymInts)
#
# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor).
# So you have to manually call Tracer/create_node to write into
# the graph. See ProxySymDispatchMode for an example
#
class SymDispatchMode:
    """
    Base class for modes that intercept operations on PySymInts.

    Subclasses override ``__sym_dispatch__``. Entering the mode installs it as
    the module-level ``SYM_FUNCTION_MODE`` and remembers the previously active
    mode in ``self.inner``; leaving the mode restores that previous mode.
    An instance can only be entered once.
    """

    def __sym_dispatch__(self, func, types, args, kwargs):
        raise NotImplementedError()

    def __enter__(self):
        global SYM_FUNCTION_MODE
        # ``inner`` is only ever assigned below, so its presence means this
        # instance has been entered before.
        if hasattr(self, "inner"):
            raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version")
        self.inner = SYM_FUNCTION_MODE
        SYM_FUNCTION_MODE = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global SYM_FUNCTION_MODE
        previous_mode = self.inner
        SYM_FUNCTION_MODE = previous_mode
def has_symbolic_sizes_strides(elem):
    """Return the ``_has_symbolic_sizes_strides`` attribute of ``elem``."""
    flag = elem._has_symbolic_sizes_strides
    return flag
def create_contiguous(shape):
    """
    Compute the row-major (contiguous) strides for ``shape``.

    The stride of the last dimension is 1 and the stride of every other
    dimension is the product of the sizes of all dimensions after it, e.g.
    ``(2, 3, 4) -> [12, 4, 1]``.

    Args:
        shape: sequence of dimension sizes (anything that supports ``len``,
            slicing and ``*`` on its elements).

    Returns:
        A list of strides with the same length as ``shape``; an empty list
        for an empty shape.
    """
    if len(shape) == 0:
        return []
    strides = [1]
    # Walk the trailing dimensions from innermost to outermost: each one
    # scales the stride of the dimension in front of it. The first dimension's
    # own size never contributes to any stride, hence shape[1:].
    for dim in reversed(shape[1:]):
        strides.append(dim * strides[-1])
    strides.reverse()
    return strides
def _handle_sym_dispatch(func, args, kwargs):
    """
    Forward ``func(*args, **kwargs)`` to the active SymDispatchMode.

    While the mode's ``__sym_dispatch__`` runs, ``SYM_FUNCTION_MODE`` is set
    to the mode's ``inner`` mode; the active mode is reinstated afterwards,
    even if the handler raises.
    """
    global SYM_FUNCTION_MODE
    active_mode = SYM_FUNCTION_MODE
    assert active_mode
    # TODO: properly compute types
    types: List[Type] = []
    SYM_FUNCTION_MODE = active_mode.inner
    try:
        return active_mode.__sym_dispatch__(func, types, args, kwargs)
    finally:
        SYM_FUNCTION_MODE = active_mode
def sym_float(a):
    """
    Convert ``a`` to a float-like value.

    Objects exposing ``__sym_float__`` are converted through that hook,
    ``torch._C.SymFloatNode`` instances are returned unchanged, and anything
    else goes through the builtin ``float``.
    """
    if hasattr(a, '__sym_float__'):
        return a.__sym_float__()
    already_symbolic = isinstance(a, torch._C.SymFloatNode)
    return a if already_symbolic else float(a)
# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class PySymInt:
    """
    PySymInt objects are the primary "symbolic shape" objects that flow through
    our program. They're what sit under FakeTensor, and contains our primary
    implementation of symbolic shapes.

    Each instance pairs an expression (``expr``) with the ``shape_env`` used to
    evaluate it, plus an optional ``constant`` value.
    """

    def __init__(self, expr, shape_env, constant=None):
        self.expr = expr
        self.shape_env = shape_env
        self.constant = constant

    def wrap(self, num):
        """Build a constant PySymInt for ``num`` sharing this node's shape_env."""
        as_sympy = sympy.Integer(num)
        return PySymInt(as_sympy, self.shape_env, constant=num)

    def clone(self):
        """Return a new PySymInt with the same expr, shape_env and constant."""
        return PySymInt(self.expr, self.shape_env, constant=self.constant)

    def __str__(self):
        return "{}".format(self.expr)

    def __repr__(self):
        return "{}".format(self.expr)

    # Today we error on calling int on a symbolic shape, as this is a very accessible footgun.
    def __int__(self):
        raise RuntimeError("Trying to extract a concrete int out of a symbolic int")

    # You can manually trigger a guard with this function
    def guard_int(self, file, line):
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        evaluated = self.shape_env.evaluate_expr(self.expr)
        return int(evaluated)

    def __sym_float__(self):
        """Return the float counterpart of this node (dispatching to the active mode, if any)."""
        if not SYM_FUNCTION_MODE:
            # TODO: consider constant prop here
            # TODO: wrapping the expr with sympy.Float doesn't seem to work, why
            # not?
            return PySymFloat(self.expr, self.shape_env)
        return _handle_sym_dispatch(sym_float, (self,), {})

    def __bool__(self):
        env = self.shape_env
        substituted = env.replace(self.expr)
        return bool(env.evaluate_expr(substituted))
class PySymFloat:
    """
    Float counterpart of PySymInt: an expression (``expr``) paired with the
    ``shape_env`` it belongs to and an optional ``constant`` value.
    """

    def __init__(self, expr, shape_env, constant=None):
        self.expr = expr
        self.shape_env = shape_env
        self.constant = constant

    def wrap(self, num):
        """Build a constant PySymFloat for ``num`` sharing this node's shape_env."""
        as_sympy = sympy.Float(num)
        return PySymFloat(as_sympy, self.shape_env, constant=num)

    def __str__(self):
        return "{}".format(self.expr)
# These classes subclass sympy.Function, so they can only be defined when
# sympy was importable.
if HAS_SYMPY:
    class FloorDiv(sympy.Function):
        """
        Symbolic floor division of ``base`` by ``divisor``.

        We maintain this so that:
        1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
        2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
        """
        nargs = (2,)

        @classmethod
        def eval(cls, base, divisor):
            # Rewrite rules, tried in order. Falling off the end returns None;
            # presumably sympy then keeps the call as an unevaluated
            # FloorDiv(base, divisor) node -- confirm against sympy's
            # Function.eval documentation.
            if base == 0:
                return sympy.Integer(0)
            if divisor == 1:
                return base
            # Two concrete integers: compute the floor division directly
            if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
                return base // divisor
            # Fold a nested division, FloorDiv(FloorDiv(a, b), c), into
            # FloorDiv(a, b * c)
            if isinstance(base, FloorDiv):
                return FloorDiv(base.args[0], base.args[1] * divisor)
            # Cancel a common factor shared by base and divisor
            gcd = sympy.gcd(base, divisor)
            if gcd != 1:
                return FloorDiv(
                    sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
                )

    class Ceil(sympy.Function):
        """
        Symbolic ceiling used to implement ``math.ceil`` on symbolic values.

        Only the argument kinds handled in ``eval`` are supported: a
        sympy.Integer, a scalar sympy Symbol, or a sympy.Rational. Any other
        argument type raises NotImplementedError.

        NOTE(review): sympy ships ``sympy.ceiling``; confirm whether this
        custom function is still required.
        """
        nargs = (1,)

        @classmethod
        def eval(cls, a):
            # The Integer check must come before the Rational check below
            # (sympy.Integer is a kind of sympy.Rational), so that branch only
            # sees non-integer rationals.
            if isinstance(a, sympy.Integer):
                return a
            elif isinstance(a, sympy.core.symbol.Symbol) and a.is_scalar:
                # TODO: do we need to simplify expr's first? (e.g. if we have 3/3), is is_scalar() true?
                return a
            elif isinstance(a, sympy.Rational):
                # Non-integer rational: the ceiling is one above the floor
                return a.floor() + 1
            else:
                raise NotImplementedError("math.ceil() not supported for type: " + str(type(a)))
# Binary operators that get both a `__foo__` and a reflected `__rfoo__` method
reflectable_magic_methods = dict(
    add=operator.add,
    sub=operator.sub,
    mul=operator.mul,
    mod=operator.mod,
    truediv=operator.truediv,
    # Stays a lambda so that FloorDiv is only looked up when it is called
    floordiv=lambda a, b: FloorDiv(a, b),
)
# Every operator installed on the symbolic types. The sympy-backed entries
# stay lambdas so that sympy is only looked up when they are called.
magic_methods = dict(
    reflectable_magic_methods,
    eq=lambda a, b: sympy.Eq(a, b),
    gt=lambda a, b: sympy.Gt(a, b),
    lt=lambda a, b: sympy.Lt(a, b),
    le=lambda a, b: sympy.Le(a, b),
    ge=lambda a, b: sympy.Ge(a, b),
    ceil=lambda a: Ceil(a),
)
# Operators that take a single operand
unary_magic_methods = {'ceil'}
# Subset of magic_methods that is also installed on PySymFloat
float_magic_methods = {"add", "sub", "mul", "truediv", "ceil"}
def _make_magic(method, func, py_type):
    """
    Install the operator ``method`` on ``py_type``.

    The implementation is attached under the plain name (``method``) and the
    dunder name (``__{method}__``). Binary operators listed in
    ``reflectable_magic_methods`` additionally get ``__r{method}__``, which
    evaluates ``other OP self`` as required for reflected operands.

    Args:
        method: operator name, a key of ``magic_methods`` (e.g. ``"add"``).
        func: callable that computes the operator on the underlying
            expressions; it is wrapped in an ``lru_cache`` of size 256.
        py_type: class that receives the methods (PySymInt or PySymFloat).
    """
    func = lru_cache(256)(func)

    def _binary_impl(self, other, reflected):
        # Shared body of the forward and reflected binary operators.
        # ``reflected`` is True when ``self`` is the right-hand operand.
        if SYM_FUNCTION_MODE:
            # Hand the operands to the mode in the order the user wrote them
            operands = (other, self) if reflected else (self, other)
            return _handle_sym_dispatch(getattr(operator, method), operands, {})
        if isinstance(other, py_type):
            other = other.expr
        # TODO: consider constant prop here
        expr = self.shape_env.replace(self.expr)
        other = self.shape_env.replace(other)
        out = func(other, expr) if reflected else func(expr, other)
        out = sympy.expand(out)
        if method in ["truediv"]:
            return PySymFloat(out, self.shape_env)
        else:
            # TODO: relational operators actually technically return a
            # PySymBool, this is a type error
            return py_type(out, self.shape_env)

    def magic_impl(self, other):
        return _binary_impl(self, other, reflected=False)

    def rmagic_impl(self, other):
        # Python calls __rfoo__(self, other) to evaluate ``other OP self``, so
        # the operands must be swapped for non-commutative operators.
        return _binary_impl(self, other, reflected=True)

    def unary_magic_impl(self):
        if SYM_FUNCTION_MODE:
            # ceil lives in the math module rather than in operator
            if method in ["ceil"]:
                op = getattr(math, method)
            else:
                op = getattr(operator, method)
            return _handle_sym_dispatch(op, (self,), {})
        # TODO: consider constant prop here
        expr = self.shape_env.replace(self.expr)
        out = func(expr)
        out = sympy.expand(out)
        if method in ["ceil"]:
            return PySymInt(out, self.shape_env)
        else:
            return py_type(out, self.shape_env)

    # this should be wrapped transparently into torch.SymIntNode
    if method in unary_magic_methods:
        setattr(py_type, method, unary_magic_impl)
        setattr(py_type, f"__{method}__", unary_magic_impl)
    else:
        setattr(py_type, method, magic_impl)
        setattr(py_type, f"__{method}__", magic_impl)
        if method in reflectable_magic_methods:
            setattr(py_type, f"__r{method}__", rmagic_impl)
# Install every magic method on PySymInt, and the float-compatible subset on
# PySymFloat as well.
for method, func in magic_methods.items():
    _make_magic(method, func, PySymInt)
    if method in float_magic_methods:
        _make_magic(method, func, PySymFloat)
# Don't leave the loop variables behind in the module namespace
del method, func
def _lru_cache(fn, maxsize=None):
"""
Wrapper around lru_cache that clears when new info about shapes has been
updated.
Use lru_cache if the output is always the same, regardless of the
constraints we know now (i.e. evaluate_expr)
Use _lru_cache otherwise.
"""
fn_cache = lru_cache(maxsize)(fn)
prior_key = None
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
nonlocal prior_key
if prior_key != self._get_key():
prior_key = self._get_key()
fn_cache.cache_clear()
return fn_cache(self, *args, **kwargs)
wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined]
return wrapper
class ShapeEnv(object):
    """
    Holds the symbolic shape variables created for input sizes together with
    everything learned about them: the guards recorded by `evaluate_expr`,
    symbol replacements derived from equality guards, and the set of
    `a % b` expressions known to be zero.
    """

    def __init__(self):
        # (expr, concrete_val, stack) tuples appended by evaluate_expr
        self.guards = []
        # Maps symbolic ints to their original concrete values
        # Currently populated from tensors
        self.var_to_val: Dict["sympy.Symbol", "sympy.Integer"] = {}
        # Maps from sympy ints to expressions representing them
        # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
        self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {}
        # Set holds a % b expressions that evaluate to 0.
        self.divisible: Set["sympy.Expr"] = set()
        # Duck-shaping says that if two input tensors have the same size,
        # they get assigned the same symbolic variable
        self.val_to_symint: Dict[int, torch.SymIntNode] = {}

    def _get_key(self):
        """
        Defines the current "state" of the guards we've accumulated in this ShapeEnv.
        Determines when we need to invalidate our cache

        The key is built from the sizes of `replacements` and `divisible`
        only, so it changes when entries are added to either of them.
        """
        return (len(self.replacements), len(self.divisible))

    # NB: This is only called for input symbolic sizes; intermediate symbolic
    # sizes are allocated via a different mechanism
    def create_symint(self, name, val):
        """
        Return the symbolic int for an input size with concrete value `val`.

        Sizes 0 and 1 are returned as plain values. Otherwise a positive
        integer sympy symbol called `name` is created, wrapped in a PySymInt
        and a torch.SymIntNode, and remembered so that later calls with the
        same `val` return the same node.

        Raises:
            RuntimeError: if sympy is not installed.
        """
        assert val >= 0
        if not HAS_SYMPY:
            raise RuntimeError("Need sympy installed to create symbolic shapes")
        # TODO: Put 0/1 specialization in guards
        if val == 0 or val == 1:
            return val
        # This implements duck-shaping: input sizes that match are assigned
        # the same symint
        # TODO: Create a guard whenever this happens
        # TODO: But how do I represent the guard in this case?
        if val in self.val_to_symint:
            return self.val_to_symint[val]
        sympy_expr = sympy.Symbol(name, positive=True, integer=True)
        py_sym_int = PySymInt(sympy_expr, self)
        cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined]
        self.var_to_val[sympy_expr] = sympy.Integer(val)
        self.val_to_symint[val] = cpp_sym_int
        return cpp_sym_int

    def evaluate_guards_for_args(self, *args):
        """
        Return True if every recorded guard still evaluates to its recorded
        value for the tensors found in `args`.

        A fresh ShapeEnv is filled by running MetaConverter over the tensors;
        each guard is then evaluated with that environment's `var_to_val`
        substituted in.
        """
        new_env = ShapeEnv()
        # NB: This must be kept in sync with create_aot_dispatcher_function
        # and wrap_fake_symbolic
        meta_converter = MetaConverter()
        # Presumably the conversion registers each tensor's sizes in new_env
        # through create_symint -- verify against MetaConverter.
        pytree.tree_map_only(torch.Tensor, partial(meta_converter, shape_env=new_env), args)
        return all(guard.xreplace(new_env.var_to_val) == value for guard, value, _ in self.guards)

    def get_nontrivial_guards(self):
        """
        Return (simplified guard, value) pairs for the guards that
        `_maybe_evaluate_static` cannot resolve without a guard.
        """
        return [(self.simplify(guard), val) for guard, val, _ in self.guards if self._maybe_evaluate_static(guard) is None]

    def format_guards(self, verbose=False):
        """
        Render the recorded guards as text, one " - ..." line per guard.
        With `verbose=True` the stack captured when the guard was recorded is
        appended to each entry.
        """
        def format_val(guard, val):
            # True/False guards are printed as the expression or its
            # negation, everything else as an equality with the value
            if val is sympy.true:
                return str(guard)
            elif val is sympy.false:
                return f"Not({guard})"
            else:
                return f"Eq({guard}, {val})"

        def format_tb(tb):
            if not verbose:
                return ""
            return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}"
        return '\n'.join(f" - {format_val(guard, val)}{format_tb(tb)}" for guard, val, tb in self.guards)

    def get_shape_groups(self):
        """
        Group the replaced symbols by the expression they are replaced with.
        Returns a defaultdict mapping each replacement expression to the list
        of symbols that map to it.
        """
        shape_groups = collections.defaultdict(list)
        for k, v in self.replacements.items():
            shape_groups[v].append(k)
        return shape_groups

    @_lru_cache
    def _maybe_evaluate_static(self, expr: "sympy.Expr") -> "Optional[sympy.Expr]":
        """
        Tries to evaluate expr without introducing guards

        Returns the evaluated expression when no free symbols remain after
        simplification, otherwise None.
        """
        expr = self.simplify(expr)
        # Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values)
        symbols = list(expr.free_symbols)
        # Each symbol is swapped for (fresh positive integer symbol + 1)
        new_shape_env = {
            k: sympy.Symbol(f"shape_{idx}", positive=True, integer=True) + 1
            for idx, k in enumerate(symbols)
        }
        new_expr = expr.xreplace(new_shape_env)
        # Express FloorDiv through sympy.floor for this evaluation
        floor_div_replace = {}
        for atom in new_expr.atoms(FloorDiv):
            floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
        new_expr = sympy.expand(new_expr.xreplace(floor_div_replace))
        if len(list(new_expr.free_symbols)) == 0:
            return new_expr
        return None

    @_lru_cache
    def replace(self, expr: "sympy.Expr") -> "sympy.Expr":
        """
        Substitute every free symbol of `expr` with the result of `_find` for
        that symbol and return the expanded expression.
        """
        replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols}
        return sympy.expand(expr.xreplace(replacements))

    @_lru_cache
    def _update_divisible(self):
        """
        Rebuild `self.divisible`, keeping only the entries that still contain
        free symbols after `replace`.
        """
        new_divisible = set()
        for k in self.divisible:
            res = self.replace(k)
            if len(res.free_symbols) > 0:
                new_divisible.add(k)
        self.divisible = new_divisible

    @_lru_cache
    def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
        """
        Apply the known replacements to `expr` and turn FloorDiv(base, divisor)
        into base / divisor wherever `base % divisor` is recorded as divisible.
        """
        expr = self.replace(expr)
        if expr.has(FloorDiv):
            self._update_divisible()
            div_replacements = {}
            for atom in expr.atoms(FloorDiv):
                base, divisor = atom.args
                if self.replace(base % divisor) in self.divisible:
                    div_replacements[atom] = base / divisor
            expr = expr.xreplace(div_replacements)
            expr = sympy.expand(expr)
        return expr

    @lru_cache(256)
    def size_hint(self, expr: "sympy.Expr"):
        """
        Gets a size hint for a given expression from the underlying shapes we had.
        Does not introduce a guard, so only use this when you can guarantee that
        your code is still valid for arbitrary shapes (such as optimization decisions)
        """
        result_expr = sympy.expand(expr).xreplace(self.var_to_val)
        assert len(result_expr.free_symbols) == 0, "Size hint has variables we don't have underlying values for"
        return result_expr

    @_lru_cache
    def _find(self, a: "sympy.Symbol") -> "sympy.Expr":
        """
        Implements a DSU-like algorithm to find the variable that represents a
        Also handles transitive non-identity replacements.
        a: b + c
        c: d

        The resolved expression is written back into `self.replacements[a]`.
        """
        if a not in self.replacements:
            return a
        res = self.replacements[a]
        # Resolve the symbols of the replacement recursively, then store the
        # fully resolved form
        cur_replace = {s: self._find(s) for s in res.free_symbols}
        self.replacements[a] = self.replacements[a].xreplace(cur_replace)
        return self.replacements[a]

    @lru_cache(256)
    def _maybe_guard_eq(self, expr: "sympy.Eq") -> None:
        """
        Evaluates the result of an eq call. If true, uses information to
        simplify shapes (i.e. a == b or a % 5 == 0)
        """
        # Nothing can be learned when the equality does not hold for the
        # concrete sizes
        concrete_bool = bool(self.size_hint(expr))
        if not concrete_bool:
            return
        free = list(expr.free_symbols)
        assert len(free) > 0, "The expression should not be static by this point"
        # In case of really gnarly expression, we don't blow up
        if len(free) > 5:
            return
        # Solve for the symbol that sorts last by (size hint, name)
        free = sorted(free, key=lambda x: (self.size_hint(x), x.name), reverse=True) # type: ignore[attr-defined]
        lhs = expr.lhs
        rhs = expr.rhs
        try:
            solutions = sympy.solve(lhs - rhs, free[0], dict=True)
            if len(solutions) != 1:
                return
            solution = solutions[0][free[0]]
            # Only record the replacement when every sub-expression of the
            # solution is known to be an integer
            if all(t.is_integer for t in sympy.preorder_traversal(solution)):
                new_var = self._find(solution)
                self.replacements[cast(sympy.Symbol, free[0])] = new_var
        except NotImplementedError:
            # Could not solve for the symbol; if the expression contains a
            # Mod term, try to record that term as divisible instead
            if expr.has(sympy.Mod):
                mod_expr = tuple(expr.atoms(sympy.Mod))[0]
                try:
                    solutions = sympy.solve(lhs - rhs, mod_expr, dict=True)
                    if len(solutions) == 1 and solutions[0][mod_expr] == 0:
                        self.divisible.add(mod_expr)
                except NotImplementedError:
                    pass
            return

    @lru_cache(256)
    def evaluate_expr(self, expr: "sympy.Expr"):
        """
        Given an expression, evaluates it, adding guards if necessary

        Expressions without free symbols, and expressions that
        `_maybe_evaluate_static` can resolve, are returned without recording a
        guard. Otherwise the value from `size_hint` is returned and
        (expr, value, stack) is appended to `self.guards`.
        """
        if len(expr.free_symbols) == 0:
            return expr
        expr = self.simplify(expr)
        static_expr = self._maybe_evaluate_static(expr)
        if static_expr is not None:
            return static_expr
        if isinstance(expr, sympy.Eq):
            self._maybe_guard_eq(expr)
        concrete_val = self.size_hint(expr)
        # TODO: optimize this; avoid formatting traces until we need them
        # NB: drop two frames; evaluate_expr and the Sym* function that
        # actually called us
        stack = ''.join(traceback.format_list(traceback.extract_stack()[:-2]))
        self.guards.append((expr, concrete_val, stack))
        return concrete_val