Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ _inductor / sizevars.py

import dataclasses
import functools
import itertools
import logging
from typing import Callable, Dict, List, Tuple

import sympy
from sympy import Expr

from torch.fx.experimental.symbolic_shapes import ShapeEnv

from . import ir
from .codegen.common import IndentedBuffer
from .utils import sympy_subs, sympy_symbol, VarRanges
from .virtualized import V

log = logging.getLogger(__name__)


@dataclasses.dataclass
class ZeroGuard:
    """
    An expression we should check equals zero.
    Guards are currently not checked.  Plan to add this later.
    """

    # the expression that is expected to evaluate to zero
    expr: Expr


@dataclasses.dataclass
class PositiveGuard:
    """
    An expression we should check for > 0
    Guards are currently not checked.  Plan to add this later.

    SizeVarAllocator.guard_lt() appends instances of this class to its
    ``guards`` list.
    """

    # the expression that is expected to be strictly positive
    expr: Expr


class SizeVarAllocator:
    def __init__(self, shape_env=None):
        """
        Set up the allocator around a ShapeEnv.

        A new ShapeEnv is created when ``shape_env`` is None.  ``var_to_val``
        and ``replacements`` are taken directly from the ShapeEnv; no copy is
        made here.
        """
        super().__init__()
        self.shape_env = ShapeEnv() if shape_env is None else shape_env
        env = self.shape_env
        self.var_to_val = env.var_to_val
        self.guards = []
        self.replacements: Dict[sympy.Symbol, Expr] = env.replacements
        # dynamic sizes that have to be precomputed on the host, mapped to the
        # kernel args; the second dict is the inverse direction
        self.precomputed_replacements: Dict[Expr, sympy.Symbol] = {}
        self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = {}
        self.need_seed = False
        # the cache builders below run once the attributes above exist
        self.stride_vars = self.make_stride_vars_cache()
        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
        self._simplify_loops = self.make_simplify_loops_cache()
        self.declare = ""
        self.ending = ""
        self.as_strided = "as_strided"

    def seed(self):
        """
        Return the symbol named "seed" and record that a seed is needed.

        The seed is a special variable holding the rng seed for a graph.
        Only the CPU backend uses it; for the CUDA backend seeds are put in
        a 1-element tensor instead.
        """
        self.need_seed = True
        seed_symbol = sympy_symbol("seed")
        return seed_symbol

    def simplify(self, expr: Expr):
        """Expand ``expr``, then apply self.replacements to the result."""
        expanded = sympy.expand(expr)
        return expanded.xreplace(self.replacements)

    def make_simplify_with_ranges_cache(self):
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache = dict()
        replacement_count = len(self.replacements)

        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
            return result

        return simplify_with_ranges

    def make_simplify_loops_cache(self):
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache = dict()
        replacement_count = len(self.replacements)

        def simplify_loops(index_vars, sizes, index_formulas):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (*index_vars, *sizes, *index_formulas)
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
            return result

        return simplify_loops

    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges):
        """
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.

        The expression is first passed through self.simplify() and
        join_dimensions().  ModularIndexing and FloorDiv sub-expressions are
        then rewritten by the helpers below, and the whole procedure is
        repeated (recursively) until a pass leaves the expression unchanged.

        Args:
            expr: the indexing expression to simplify.
            var_ranges: mapping from iteration variable to its range.  The
                values are used as sizes: a variable is taken to run from 0
                to size - 1.

        Returns:
            The simplified expression.

        The maybe_guard_* calls made here can add guards as a side effect.
        """
        from .ir import FloorDiv, ModularIndexing

        expr = join_dimensions(self.simplify(expr))
        # remembered so we can tell at the end whether this pass changed anything
        original_expr = expr

        def remove_zero_terms(base, divisor):
            """Symbols smaller than the divisor are zero"""
            for v in base.free_symbols:
                if v in var_ranges:
                    # var smaller than divisor can be removed
                    # if the rest is guaranteed to be multiple of divisor
                    rest = sympy.Wild("_rest", exclude=[v])
                    # try to split base into `v + <something not containing v>`
                    m = base.match(v + rest)
                    if m and v not in m[rest].free_symbols:
                        gcd = sympy.gcd(m[rest], divisor)
                        if gcd == divisor:
                            # rest is a multiple of divisor; v can be dropped
                            # when its range size is <= divisor
                            if self.maybe_guard_leq(var_ranges[v], divisor):
                                base = m[rest]
            return base

        def visit_indexing_div(base, divisor):
            # rebuild the FloorDiv with droppable terms removed from its base
            return FloorDiv(remove_zero_terms(base, divisor), divisor)

        def visit_modular_indexing(base, divisor, modulus):
            base = remove_zero_terms(base, divisor)
            # base_s is the value compared against modulus * divisor below
            if isinstance(base, ModularIndexing):
                # for modular indexing, biggest values from the ranges don't necessarily result in
                # the biggest result, the biggest result is modulus - 1
                base_s = base.args[2] - 1
            elif not base.has(ModularIndexing):
                # actual iteration range is to size-1
                iter_ranges_zero = {k: 0 for k, v in var_ranges.items()}
                # value of base with every iteration variable set to 0
                base_lowest = sympy_subs(base, iter_ranges_zero)
                if self.maybe_guard_lt(base_lowest, 0):
                    # can't replace with indexing div if base can be negative
                    return ModularIndexing(base, divisor, modulus)
                iter_ranges = {k: v - 1 for k, v in var_ranges.items()}
                # value of base with every iteration variable at size - 1
                base_s = sympy_subs(base, iter_ranges)
            else:
                # base contains a nested ModularIndexing; no range
                # substitution is attempted and base is used as-is
                base_s = base
            if self.maybe_guard_lt(base_s, modulus * divisor):
                # base_s < modulus * divisor, so the modulus is dropped and
                # only the division is kept
                return FloorDiv(base, divisor)
            return ModularIndexing(base, divisor, modulus)

        if expr.has(ModularIndexing):
            # rewrite every ModularIndexing(base, divisor, modulus) occurrence
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                    sympy.Wild("modulus"),
                ),
                visit_modular_indexing,
            )

        if expr.has(FloorDiv):
            # rewrite every FloorDiv(base, divisor) occurrence
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                ),
                visit_indexing_div,
            )

        if expr != original_expr:
            # something changed; run another pass on the new expression
            return self._simplify_with_ranges(expr, var_ranges)
        return expr

    def _simplify_loops_impl(self, index_vars, sizes, index_formulas):
        """
        Try to remove as many axis from loop iterations as possible, by:
            1) removing size==1 dimensions
            2) fuse contiguous dimensions into a single loop

        NOTE(review): this docstring used to mention a ``channel_last``
        option; this function takes no such argument.

        Args:
            index_vars: the iteration variables, one per entry of ``sizes``.
            sizes: the size of each dimension.
            index_formulas: indexing expressions written in terms of
                ``index_vars``; at least one is required, since
                ``strides[0]`` is read below.

        Returns:
            A tuple ``(new_sizes, reindex, prune)``: ``new_sizes`` lists the
            sizes of the dimensions that remain, ``reindex`` expands an index
            over the remaining dimensions to one over all original
            dimensions (0 for removed ones), and ``prune`` drops the entries
            of removed dimensions from a full-length index.
        """
        sizes = list(map(self.simplify, sizes))

        # one list of per-variable strides for each formula
        # (stride_vars is defined outside this view -- presumably the stride
        # of each index var within the formula; confirm)
        strides = [self.stride_vars(x, index_vars) for x in index_formulas]
        assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))

        # from here on, sizes[i] is None marks dimension i as removed
        for i in range(len(sizes)):
            if sizes[i] == 1:
                # remove dim
                sizes[i] = None

        def can_merge_dims(a, b):
            # True only when every formula passes both tests below
            for k in range(len(strides)):
                if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
                    strides[k][b]
                ):
                    # approximate test passed, try sound version
                    va = index_vars[a]
                    vb = index_vars[b]
                    v = sympy_symbol("_merge_tester")
                    # stepping a by v * sizes[a] must give the same formula
                    # as stepping b by v
                    expr1 = sympy_subs(index_formulas[k], {va: v * sizes[a], vb: 0})
                    expr2 = sympy_subs(index_formulas[k], {va: 0, vb: v})
                    if self.simplify(expr1) == self.simplify(expr2):
                        continue
                # reached when either test failed for formula k
                return False
            return True

        # keep sweeping over all ordered pairs until a full sweep merges nothing
        changed = True
        while changed:
            changed = False
            for i, j in itertools.product(
                reversed(range(len(sizes))), reversed(range(len(sizes)))
            ):
                if i == j or sizes[i] is None or sizes[j] is None:
                    continue
                if can_merge_dims(i, j):
                    changed = True
                    # dimension j is folded into dimension i
                    sizes[i] = sizes[i] * sizes[j]
                    sizes[j] = None

        def reindex(index):
            # reversed so that pop() hands back the entries in original order
            it = list(reversed(index))
            new_index = []
            for size in sizes:
                if size is None:
                    new_index.append(sympy.Integer(0))
                else:
                    new_index.append(it.pop())
            # every entry of index must have been consumed
            assert not it
            return new_index

        def prune(index):
            assert len(index) == len(sizes)
            return [i for i, s in zip(index, sizes) if s is not None]

        return [x for x in sizes if x is not None], reindex, prune

    def guard_equals(self, left: Expr, right: Expr) -> Expr:
        """
        Have the ShapeEnv evaluate ``left == right`` and return ``left``.

        Args:
            left: first expression.
            right: second expression.

        Returns:
            ``left``, unchanged.

        Raises:
            AssertionError: if the ShapeEnv evaluates the equality as false.
        """
        # The evaluation is done outside of an ``assert`` statement on
        # purpose: asserts are removed under ``python -O``, which would skip
        # the evaluate_expr() call (and anything it records) entirely.
        holds = self.shape_env.evaluate_expr(sympy.Eq(left, right))
        if not holds:
            raise AssertionError(f"guard_equals failed: {left} != {right}")
        return left

    def maybe_guard_equals(self, left: Expr, right: Expr) -> bool:
        """if left==right, guard on that fact and return true"""
        if left == right:
            return True
        if self.size_hint(left - right) == 0:
            self.guard_equals(left, right)
            return True
        return False

    def maybe_guard_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
        """if left==right, guard on that fact and return true"""
        if len(left) != len(right):
            return False
        if all(self.size_hint(a - b) == 0 for a, b in zip(left, right)):
            for a, b in zip(left, right):
                self.guard_equals(a, b)
            return True
        return False

    def maybe_guard_leq(self, left: Expr, right: Expr) -> bool:
        try:
            if self.size_hint(left) > self.size_hint(right):
                return False
        except TypeError:
            return False
        self.guard_leq(left, right)
        return True

    def maybe_guard_lt(self, left: Expr, right: Expr) -> bool:
        try:
            if self.size_hint(left) >= self.size_hint(right):
                return False
        except TypeError:
            return False
        self.guard_lt(left, right)
        return True

    def guard_leq(self, left: Expr, right: Expr) -> None:
        return self.guard_lt(left, right + 1)

    def guard_lt(self, left: Expr, right: Expr) -> None:
        """
        Guard on ``left < right``.

        The simplified difference ``right - left`` must have a positive size
        hint (asserted).  A PositiveGuard for it is appended to self.guards
        only when it still contains symbols and its string form has a minus
        sign in it.
        """
        gap = self.simplify(right - left)
        assert self.size_hint(gap) > 0
        if not gap.free_symbols:
            # no symbols left, nothing to guard on
            return
        if "-" not in str(gap):
            # all vars are positive, so without a minus sign the expression
            # cannot take negative values
            return
        self.guards.append(PositiveGuard(gap))

    def guard_min(self, left: Expr, right: Expr) -> Expr:
        """return the smaller of left and right, and guard on that choice"""
        lv = self.size_hint(left)
        rv = self.size_hint(right)
        if lv == rv:
            return self.guard_equals(left, right)
        elif lv < rv:
            self.guard_lt(left, right)
            return left
        else:
            self.guard_lt(right, left)
            return right

    def guard_max(self, left: Expr, right: Expr) -> Expr:
        """return the larger of left and right, and guard on that choice"""
        return -self.guard_min(-left, -right)

    def maybe_guard_multiple_of(self, numerator: Expr, denominator: Expr) -> bool:
        """
        Return True when denominator divides numerator.

        A symbolic proof (gcd equal to the denominator) needs no guard.
        Failing that, the size hints decide; when they divide evenly the
        fact ``numerator % denominator == 0`` is guarded via guard_equals().
        """
        symbolic_gcd = sympy.gcd(numerator, denominator)
        if symbolic_gcd == denominator:
            # can prove it symbolically
            return True
        hint_remainder = self.size_hint(numerator) % self.size_hint(denominator)
        if hint_remainder != 0:
            return False
        self.guard_equals(numerator % denominator, 0)
        return True

    def guard_static_shape(self, left: Expr) -> int:
        """
        Guard that ``left`` equals its current size hint and return that
        hint as an int.
        """
        hint = self.size_hint(left)
        hint_as_sympy = sympy.Integer(hint)
        self.guard_equals(left, hint_as_sympy)
        return int(hint)

    def __getitem__(self, val: int) -> Expr:
        return self.shape_env.duck_int(val)

    def size_hint(self, expr: Expr) -> int:
        """
        Evaluate ``expr`` to an int: expand it, substitute self.var_to_val
        for its symbols, and convert the result with int().
        """
        expanded = sympy.expand(expr)
        substituted = sympy_subs(expanded, self.var_to_val)
        return int(substituted)

    def size_hints(self, exprs: List[Expr]) -> int:
        return tuple(self.size_hint(x) for x in exprs)
Loading ...