edgify / torch (Python), version 2.0.1+cpu
File: /_inductor/graph.py

import logging
import operator
import os
import re
import sys
import time
from typing import Dict, List, Optional, Set

import sympy

import torch
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import dynamo_timed
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._mode_utils import no_dispatch

from .._dynamo import config as dynamo_config

from . import config, ir
from .codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen
from .exc import (
    LoweringException,
    MissingOperatorWithDecomp,
    MissingOperatorWithoutDecomp,
)
from .ir import Constant, FixedLayout, InputBuffer, Pointwise, Reduction, TensorBox
from .lowering import (
    FALLBACK_ALLOW_LIST,
    layout_constraints,
    lowerings,
    make_fallback,
    needs_realized_inputs,
)
from .sizevars import CppSizeVarAllocator, SizeVarAllocator
from .utils import (
    convert_shape_to_inductor,
    gather_origins,
    get_dtype_size,
    sympy_product,
)
from .virtualized import V

log = logging.getLogger(__name__)


def supported_dtype_of_cpp_wrapper(dtype):
    supported_dtype = {
        torch.float32,
        torch.float64,
        torch.int64,
        torch.int32,
        torch.int16,
        torch.int8,
        torch.uint8,
        torch.bool,
        # torch.float16, # TODO: implement this
        # torch.bfloat16, # TODO: implement this
    }
    return dtype in supported_dtype


class GraphLowering(torch.fx.Interpreter):
    def symbolic_sizes_strides(self, ex: torch.Tensor):
        """
        Support dynamic shapes and dynamic strides by assigning variables
        to each dimension.  We duck-shape tensors, so if two tensors
        have the same size they get assigned the same symbolic variable.
        """
        if self.reuse_shape_env:
            return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor(
                ex.stride()
            )
        else:
            from torch._dynamo.source import ConstantSource

            # TODO: this should not be needed once #93059 lands
            # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
            # TODO: make a dedicated UnknownSource for this?
            source = ConstantSource(
                f"__unknown_tensor_{len(self._shape_env.var_to_val)}"
            )
            (
                size,
                stride,
                _,
            ) = self._shape_env.create_symbolic_sizes_strides_storage_offset(ex, source)

        size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size]
        stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
        return size, stride

    def static_sizes_strides(self, ex: torch.Tensor):
        """
        Primarily used for weights; their sizes and strides are taken as fixed integers.
        """
        size = [sympy.Integer(i) for i in ex.size()]
        stride = [sympy.Integer(i) for i in ex.stride()]
        return size, stride

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        shape_env=None,
        num_static_inputs=None,
        graph_id=None,
    ):
        super().__init__(gm)
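        # Reuse the ShapeEnv provided by the caller (e.g. dynamo) when there is
        # one, so symbolic sizes stay consistent between tracing and lowering;
        # otherwise create a fresh one. `reuse_shape_env` controls which path
        # symbolic_sizes_strides takes.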
        if shape_env is None:
            shape_env = ShapeEnv()
            self.reuse_shape_env = False
        else:
            self._shape_env = shape_env
            self.reuse_shape_env = True
        self._shape_env = shape_env
        self.sizevars = SizeVarAllocator(shape_env)
        self.graph_inputs: Dict[str, TensorBox] = {}
        self.graph_inputs_original: Dict[str, InputBuffer] = {}
        self.graph_outputs: Optional[List[ir.IRNode]] = None
        self.device_types: Set[str] = set()
        self.buffers: List[ir.ComputedBuffer] = []
        self.constants: Dict[str, torch.Tensor] = {}
        self.removed_buffers: Set[str] = set()
        self.inplaced_to_remove: Set[str] = set()
        self.wrapper_code = None
        self.num_static_inputs = num_static_inputs
        self.mutated_inputs: Set[str] = set()
        self.unaligned_buffers: Set[str] = set()
        self.randomness_offset = sympy.Integer(0)
        self.randomness_seeds: List[str] = []
        self.name_to_buffer: Dict[str, ir.ComputedBuffer] = {}
        self.creation_time = time.time()
        self.name = "GraphLowering"
        self._can_use_cpp_wrapper = config.cpp_wrapper
        self.graph_id = graph_id
        self.scheduler = None
        self._warned_fallback = {"aten.convolution_backward"}

    def warn_fallback(self, name):
        if name not in self._warned_fallback:
            self._warned_fallback.add(name)
            log.info(f"Using FallbackKernel: {name}")

    @property
    def fake_mode(self):
        return V.fake_mode

    def get_dtype(self, buffer_name: str):
        if buffer_name in self.constants:
            return self.constants[buffer_name].dtype
        if buffer_name in self.name_to_buffer:
            return self.name_to_buffer[buffer_name].get_dtype()
        if buffer_name in self.graph_inputs:
            return self.graph_inputs[buffer_name].get_dtype()
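        # Names of the form "as_strided(buf0, ...)" resolve to the dtype of the
        # underlying buffer.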
        m = re.match(r"as_strided\(([a-zA-Z0-9_]+),", buffer_name)
        if m:
            return self.get_dtype(m.group(1))
        raise KeyError(f"could not find {buffer_name}")

    def random_seed_buffer(self, device: torch.device):
        """
        Return a device-unique 1-element tensor storing our RNG seed.
        This will get initialized at the start of each graph in
        `wrapper.py`.

        Note this is only used by cuda backends.  The CPU backend handles
        RNG seeds as a sizevar.
        """
        name = f"seed_{device.type}_{device.index}"
        if name not in self.constants:
            self.constants[name] = torch.zeros((), device=device, dtype=torch.int64)
            self.randomness_seeds.append(name)

        return ir.RandSeedBuffer(
            name=name,
            layout=ir.FixedLayout(
                device=device,
                dtype=torch.int64,
                size=[],
                stride=[],
            ),
        )

    def increment_randomness_offset(self, numel):
        """
        Advance the global counter of how many random numbers we have handed
        out so far, and return the previous offset.
        """
        offset = self.randomness_offset
        self.randomness_offset = offset + numel
        return offset

    @dynamo_timed
    def run(self, *args):
        return super().run(*args)

    def disable_cpp_wrapper(self, cond):
        self._can_use_cpp_wrapper = False
        log.debug("Set _can_use_cpp_wrapper to False due to %s", cond)

    def check_buffer_for_cpp_wrapper(self, buffer: ir.ComputedBuffer):
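        # Extern kernels that do not define a C++ kernel cannot be emitted by
        # the C++ wrapper codegen, so force the Python wrapper in that case.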
        if isinstance(buffer, ir.ExternKernel):
            if not getattr(buffer, "cpp_kernel", False):
                self.disable_cpp_wrapper("ExternKernel")

    def register_buffer(self, buffer: ir.ComputedBuffer):
        if config.cpp_wrapper:
            self.check_buffer_for_cpp_wrapper(buffer)

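        # Buffers are named sequentially ("buf0", "buf1", ...) in creation order.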
        name = f"buf{len(self.buffers)}"
        self.buffers.append(buffer)
        self.name_to_buffer[name] = buffer
        return name

    def realize_users_of(self, name: str):
        """
        When a buffer is mutated we need to make sure all the reads to
        the old version are realized before the mutation happens.
        """
        assert isinstance(name, str)

        def visit(value):
            if isinstance(value, (list, tuple)):
                return [visit(x) for x in value]
            if isinstance(value, ir.IRNode):
                if value.is_user_of(name):
                    value.realize()
            return value

        for key, value in self.env.items():
            try:
                visit(value)
            except Exception:
                log.warning("error in realize_users_of", exc_info=True)

    def add_tensor_constant(self, data):
        def allocate():
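            # Reuse an existing constant if an identical tensor (same size,
            # stride, dtype, device, and values) has already been registered.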
            for name, value in self.constants.items():
                if (
                    data.size() == value.size()
                    and data.stride() == value.stride()
                    and data.dtype == value.dtype
                    and data.device == value.device
                    and torch.eq(data, value).all()
                ):
                    return name
            name = f"constant{len(self.constants)}"
            self.constants[name] = data
            return name

        return TensorBox.create(
            ir.ConstantBuffer(
                allocate(),
                FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)),
            )
        )

    def constant_name(self, name: str, device_override: torch.device):
        """
        We AOT copy constants to the devices they are needed on.
        If device_override doesn't match the constant's device, then
        copy it and return a different name.
        """
        if self.constants[name].device == device_override or device_override is None:
            return name
        alt_name = f"{name}_{device_override.type}{device_override.index or 0}"
        if alt_name not in self.constants:
            self.constants[alt_name] = self.constants[name].to(device_override)
        return alt_name

    def placeholder(self, target: str, args, kwargs):
        example: torch.Tensor = super().placeholder(target, args, kwargs)
        # todo(chilli): We can remove the last check once we turn buffers into
        # static shape tensors. That check is a hack to work around Inductor
        # believing the buffer should be static while we pass in a fake tensor
        # with symbolic shapes.
        if (
            config.static_weight_shapes
            and (
                len(self.graph_inputs) < self.num_static_inputs
                or not dynamo_config.dynamic_shapes
            )
            and not example._has_symbolic_sizes_strides
        ):
            # the first N inputs are weights
            sizes, strides = self.static_sizes_strides(example)
        else:
            sizes, strides = self.symbolic_sizes_strides(example)
        # TODO(jansel): handle input aliasing
        tensor = TensorBox.create(
            InputBuffer(
                target,
                FixedLayout(example.device, example.dtype, sizes, strides),
            )
        )
        self.graph_inputs[target] = tensor
        self.graph_inputs_original[target] = tensor.data.data
        self.device_types.add(example.device.type)
        return tensor

    def call_function(self, target, args, kwargs):
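        # Record which FX nodes these args/kwargs originate from so the
        # resulting IR can be attributed back to the original graph.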
        with ir.IRNode.current_origins(gather_origins(args, kwargs)):
            if target is operator.getitem and isinstance(args[0], (list, tuple)):
                return super().call_function(target, args, kwargs)

            if hasattr(target, "_inductor_lowering_function"):
                # passthrough lowerings from .pattern_matcher
                return target(*args, **kwargs)

            if target not in lowerings:
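                # No registered lowering for this op: either register an eager
                # fallback (allow-listed or implicit) or raise an error that
                # tells the user whether a decomposition exists.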
                base_name = target.name().split(".")[0]
                if base_name in FALLBACK_ALLOW_LIST:
                    make_fallback(target)
                elif config.implicit_fallbacks:
                    error = (
                        MissingOperatorWithDecomp
                        if get_decompositions([target])
                        else MissingOperatorWithoutDecomp
                    )
                    log.info(
                        "Creating implicit fallback for:\n%s",
                        error.operator_str(target, args, kwargs),
                    )
                    make_fallback(target)
                elif get_decompositions([target]):
                    # There isn't a good way to dynamically patch this in
                    # since AOT Autograd already ran.  The error message tells
                    # the user how to fix it.
                    raise MissingOperatorWithDecomp(target, args, kwargs)
                else:
                    raise MissingOperatorWithoutDecomp(target, args, kwargs)

            try:
                out = lowerings[target](*args, **kwargs)
                return out
            except Exception as e:
                log.exception("Error from lowering")
                raise LoweringException(e, target, args, kwargs) from e

    def get_attr(self, target, args, kwargs):
        # this is a constant
        value = getattr(self.module, target)
        with no_dispatch():
            if value.shape == ():
                return Constant(value.item(), value.dtype, value.device)
            if len(value.shape) == 1 and value.shape[0] <= 8: