import dataclasses
import functools
import itertools
import logging
from typing import Callable, Dict, List, Tuple
import sympy
from sympy import Expr
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from . import ir
from .codegen.common import IndentedBuffer
from .utils import sympy_subs, sympy_symbol, VarRanges
from .virtualized import V
log = logging.getLogger(__name__)
@dataclasses.dataclass
class ZeroGuard:
    """
    An expression we should check equals zero.

    Guards are currently not checked. Plan to add this later.
    This class only stores the expression; it performs no check itself.
    """

    # The expression that is expected to evaluate to zero.
    expr: Expr
@dataclasses.dataclass
class PositiveGuard:
    """
    An expression we should check for > 0

    Guards are currently not checked. Plan to add this later.
    This class only stores the expression; it performs no check itself.
    """

    # The expression that is expected to be strictly greater than zero.
    expr: Expr
class SizeVarAllocator:
def __init__(self, shape_env=None):
    """
    Initialize the allocator on top of a ShapeEnv.

    Args:
        shape_env: the ShapeEnv whose tables are shared with this object.
            When None, a new ShapeEnv() is created.
    """
    super().__init__()
    self.shape_env = ShapeEnv() if shape_env is None else shape_env

    # These two are references to the ShapeEnv's own tables, not copies.
    self.var_to_val = self.shape_env.var_to_val
    self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements

    self.guards = []

    # maps of dynamic sizes that have to be precomputed on the host to the kernel args
    self.precomputed_replacements: Dict[Expr, sympy.Symbol] = {}
    self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = {}

    self.need_seed = False

    # The cache builders below are called after self.replacements is set;
    # the two defined in this class read len(self.replacements) right away.
    self.stride_vars = self.make_stride_vars_cache()
    self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
    self._simplify_loops = self.make_simplify_loops_cache()

    self.declare = ""
    self.ending = ""
    self.as_strided = "as_strided"
def seed(self):
    """
    Return the symbol for the graph's rng seed and remember that it was
    requested.

    Seed is a special variable used to hold the rng seed for a graph.
    Note this is only used by the CPU backend, we put seeds in a
    1-element tensor for the CUDA backend.
    """
    # Flag first, so the request is recorded before the symbol is built.
    self.need_seed = True
    seed_symbol = sympy_symbol("seed")
    return seed_symbol
def simplify(self, expr: Expr):
    """
    Expand `expr` with sympy.expand(), then substitute the entries of
    self.replacements into the result via xreplace().
    """
    expanded = sympy.expand(expr)
    return expanded.xreplace(self.replacements)
def make_simplify_with_ranges_cache(self):
"""
self._simplify_with_ranges() can be expensive, cache its results
"""
cache = dict()
replacement_count = len(self.replacements)
def simplify_with_ranges(expr: Expr, var_ranges: VarRanges):
nonlocal replacement_count
if replacement_count != len(self.replacements):
# new replacements invalidates cached results
cache.clear()
replacement_count = len(self.replacements)
key = (expr, *var_ranges.items())
result = cache.get(key, None)
if result is None:
result = self._simplify_with_ranges(expr, var_ranges)
cache[key] = result
return result
return simplify_with_ranges
def make_simplify_loops_cache(self):
"""
self._simplify_with_ranges() can be expensive, cache its results
"""
cache = dict()
replacement_count = len(self.replacements)
def simplify_loops(index_vars, sizes, index_formulas):
nonlocal replacement_count
if replacement_count != len(self.replacements):
# new replacements invalidates cached results
cache.clear()
replacement_count = len(self.replacements)
key = (*index_vars, *sizes, *index_formulas)
result = cache.get(key, None)
if result is None:
result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
cache[key] = result
return result
return simplify_loops
def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges):
    """
    Simplify indexing expression with knowledge of the ranges of
    iteration variables.

    Args:
        expr: the indexing expression to simplify.
        var_ranges: mapping from iteration variable to its range (size);
            the code below treats a variable as running from 0 to
            range - 1.

    Returns:
        The simplified expression. The rewrite is re-applied recursively
        until a pass leaves the expression unchanged.

    Note: the maybe_guard_* helpers called here may record guards as a
    side effect when they return True.
    """
    from .ir import FloorDiv, ModularIndexing

    # join_dimensions is defined outside this method (not shown here).
    expr = join_dimensions(self.simplify(expr))
    original_expr = expr

    def remove_zero_terms(base, divisor):
        """
        Symbols smaller than the divisor are zero.

        Drops an iteration variable `v` from `base` when `base` has the
        form `v + rest`, `rest` is a multiple of `divisor`, and the range
        of `v` is <= `divisor`.
        """
        for v in base.free_symbols:
            if v in var_ranges:
                # var smaller than divisor can be removed
                # if the rest is guaranteed to be multiple of divisor
                rest = sympy.Wild("_rest", exclude=[v])
                m = base.match(v + rest)
                if m and v not in m[rest].free_symbols:
                    gcd = sympy.gcd(m[rest], divisor)
                    if gcd == divisor:
                        if self.maybe_guard_leq(var_ranges[v], divisor):
                            base = m[rest]
        return base

    def visit_indexing_div(base, divisor):
        # Rebuild the FloorDiv with a possibly reduced base.
        return FloorDiv(remove_zero_terms(base, divisor), divisor)

    def visit_modular_indexing(base, divisor, modulus):
        # Replace ModularIndexing by FloorDiv when an upper bound of `base`
        # is below modulus * divisor.
        # NOTE(review): presumably ModularIndexing(base, divisor, modulus)
        # means (base // divisor) % modulus -- confirm in .ir.
        base = remove_zero_terms(base, divisor)
        if isinstance(base, ModularIndexing):
            # for modular indexing, biggest values from the ranges don't necessarily result in
            # the biggest result, the biggest result is modulus - 1
            base_s = base.args[2] - 1
        elif not base.has(ModularIndexing):
            # actual iteration range is to size-1
            # Lowest value of base: every iteration variable set to 0.
            iter_ranges_zero = {k: 0 for k, v in var_ranges.items()}
            base_lowest = sympy_subs(base, iter_ranges_zero)
            if self.maybe_guard_lt(base_lowest, 0):
                # can't replace with indexing div if base can be negative
                return ModularIndexing(base, divisor, modulus)
            # Value of base with every iteration variable at range - 1.
            iter_ranges = {k: v - 1 for k, v in var_ranges.items()}
            base_s = sympy_subs(base, iter_ranges)
        else:
            # ModularIndexing is nested somewhere inside base: use base
            # itself as the bound.
            base_s = base
        if self.maybe_guard_lt(base_s, modulus * divisor):
            return FloorDiv(base, divisor)
        return ModularIndexing(base, divisor, modulus)

    # Rewrite every ModularIndexing(...) sub-expression.
    if expr.has(ModularIndexing):
        expr = expr.replace(
            ModularIndexing(
                sympy.Wild("base"),
                sympy.Wild("divisor"),
                sympy.Wild("modulus"),
            ),
            visit_modular_indexing,
        )

    # Rewrite every FloorDiv(...) sub-expression, including those the step
    # above has just produced.
    if expr.has(FloorDiv):
        expr = expr.replace(
            FloorDiv(
                sympy.Wild("base"),
                sympy.Wild("divisor"),
            ),
            visit_indexing_div,
        )

    # Something changed: run another pass, since the new form may allow
    # further simplification.
    if expr != original_expr:
        return self._simplify_with_ranges(expr, var_ranges)
    return expr
def _simplify_loops_impl(self, index_vars, sizes, index_formulas):
    """
    Try to remove as many axis from loop iterations as possible, by:
        1) removing size==1 dimensions
        2) fuse contiguous dimensions into a single loop

    This signature has no channel_last option; no dimension is exempt
    from merging.

    Args:
        index_vars: one iteration variable per dimension.
        sizes: one size per dimension, in the same order as index_vars.
        index_formulas: indexing expressions written in terms of
            index_vars. At least one is required, since strides[0] is
            read unconditionally.

    Returns:
        A 3-tuple (new_sizes, reindex, prune):
        new_sizes: sizes of the dimensions that remain.
        reindex: takes an index with one entry per remaining dimension
            and returns an index with one entry per original dimension,
            using 0 for every removed dimension.
        prune: takes a list with one entry per original dimension and
            returns only the entries of the remaining dimensions.
    """
    sizes = list(map(self.simplify, sizes))

    # strides[k][d]: stride of dimension d in index_formulas[k], as
    # reported by self.stride_vars.
    strides = [self.stride_vars(x, index_vars) for x in index_formulas]
    assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))

    # A removed dimension is marked by None, so positions keep lining up
    # with index_vars and strides.
    for i in range(len(sizes)):
        if sizes[i] == 1:
            # remove dim
            sizes[i] = None

    def can_merge_dims(a, b):
        # True only if, for every index formula, dimension b can be folded
        # into dimension a.
        for k in range(len(strides)):
            if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
                strides[k][b]
            ):
                # approximate test passed, try sound version
                va = index_vars[a]
                vb = index_vars[b]
                v = sympy_symbol("_merge_tester")
                # Stepping vb by v must give the same address as stepping
                # va by v * sizes[a].
                expr1 = sympy_subs(index_formulas[k], {va: v * sizes[a], vb: 0})
                expr2 = sympy_subs(index_formulas[k], {va: 0, vb: v})
                if self.simplify(expr1) == self.simplify(expr2):
                    continue
            return False
        return True

    # Keep merging pairs until a full sweep makes no change.
    changed = True
    while changed:
        changed = False
        for i, j in itertools.product(
            reversed(range(len(sizes))), reversed(range(len(sizes)))
        ):
            if i == j or sizes[i] is None or sizes[j] is None:
                continue
            if can_merge_dims(i, j):
                changed = True
                # Dimension i absorbs dimension j.
                sizes[i] = sizes[i] * sizes[j]
                sizes[j] = None

    def reindex(index):
        # Reversed copy, so pop() yields the entries in their original order.
        it = list(reversed(index))
        new_index = []
        for size in sizes:
            if size is None:
                new_index.append(sympy.Integer(0))
            else:
                new_index.append(it.pop())
        # Every entry of `index` must have been consumed.
        assert not it
        return new_index

    def prune(index):
        assert len(index) == len(sizes)
        return [i for i, s in zip(index, sizes) if s is not None]

    return [x for x in sizes if x is not None], reindex, prune
def guard_equals(self, left: Expr, right: Expr) -> Expr:
    """
    Guard on `left == right` and return `left`.

    The equality sympy.Eq(left, right) is passed to
    self.shape_env.evaluate_expr().

    Raises:
        AssertionError: if evaluate_expr() reports the equality as false.
    """
    # Call evaluate_expr() outside of an `assert` statement. Asserts are
    # removed when Python runs with -O, which would skip the call entirely,
    # together with anything the ShapeEnv does while evaluating it
    # (presumably recording the guard -- see the ShapeEnv documentation).
    holds = self.shape_env.evaluate_expr(sympy.Eq(left, right))
    if not holds:
        raise AssertionError(f"guard_equals failed: {left} != {right}")
    return left
def maybe_guard_equals(self, left: Expr, right: Expr) -> bool:
"""if left==right, guard on that fact and return true"""
if left == right:
return True
if self.size_hint(left - right) == 0:
self.guard_equals(left, right)
return True
return False
def maybe_guard_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
"""if left==right, guard on that fact and return true"""
if len(left) != len(right):
return False
if all(self.size_hint(a - b) == 0 for a, b in zip(left, right)):
for a, b in zip(left, right):
self.guard_equals(a, b)
return True
return False
def maybe_guard_leq(self, left: Expr, right: Expr) -> bool:
try:
if self.size_hint(left) > self.size_hint(right):
return False
except TypeError:
return False
self.guard_leq(left, right)
return True
def maybe_guard_lt(self, left: Expr, right: Expr) -> bool:
try:
if self.size_hint(left) >= self.size_hint(right):
return False
except TypeError:
return False
self.guard_lt(left, right)
return True
def guard_leq(self, left: Expr, right: Expr) -> None:
return self.guard_lt(left, right + 1)
def guard_lt(self, left: Expr, right: Expr) -> None:
expr = self.simplify(right - left)
assert self.size_hint(expr) > 0
if len(expr.free_symbols) == 0:
return
if "-" in str(expr):
# all vars are positive, so needs a minus sign to get negative values
self.guards.append(PositiveGuard(expr))
def guard_min(self, left: Expr, right: Expr) -> Expr:
"""return the smaller of left and right, and guard on that choice"""
lv = self.size_hint(left)
rv = self.size_hint(right)
if lv == rv:
return self.guard_equals(left, right)
elif lv < rv:
self.guard_lt(left, right)
return left
else:
self.guard_lt(right, left)
return right
def guard_max(self, left: Expr, right: Expr) -> Expr:
"""return the larger of left and right, and guard on that choice"""
return -self.guard_min(-left, -right)
def maybe_guard_multiple_of(self, numerator: Expr, denominator: Expr) -> bool:
    """
    if denominator divides numerator, return True and guard on that fact

    When sympy.gcd() already shows the divisibility, True is returned
    without a guard. Otherwise the size hints decide: a zero remainder
    guards `numerator % denominator == 0` and returns True; a non-zero
    remainder returns False.
    """
    if sympy.gcd(numerator, denominator) == denominator:
        # can prove it symbolically
        return True

    remainder_hint = self.size_hint(numerator) % self.size_hint(denominator)
    if remainder_hint != 0:
        return False
    self.guard_equals(numerator % denominator, 0)
    return True
def guard_static_shape(self, left: Expr) -> int:
    """
    Pin `left` to its current size hint: guard that `left` equals the
    hint and return the hint as a plain int.
    """
    hint = self.size_hint(left)
    concrete = sympy.Integer(hint)
    self.guard_equals(left, concrete)
    return int(hint)
def __getitem__(self, val: int) -> Expr:
return self.shape_env.duck_int(val)
def size_hint(self, expr: Expr) -> int:
    """
    Evaluate `expr` to a concrete int: expand it, substitute
    self.var_to_val into it, and convert the result with int().
    """
    expanded = sympy.expand(expr)
    return int(sympy_subs(expanded, self.var_to_val))
def size_hints(self, exprs: List[Expr]) -> int:
return tuple(self.size_hint(x) for x in exprs)
Loading ...