import logging
import operator
import os
import re
import sys
import time
from typing import Dict, List, Optional, Set
import sympy
import torch
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import dynamo_timed
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._mode_utils import no_dispatch
from .._dynamo import config as dynamo_config
from . import config, ir
from .codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen
from .exc import (
LoweringException,
MissingOperatorWithDecomp,
MissingOperatorWithoutDecomp,
)
from .ir import Constant, FixedLayout, InputBuffer, Pointwise, Reduction, TensorBox
from .lowering import (
FALLBACK_ALLOW_LIST,
layout_constraints,
lowerings,
make_fallback,
needs_realized_inputs,
)
from .sizevars import CppSizeVarAllocator, SizeVarAllocator
from .utils import (
convert_shape_to_inductor,
gather_origins,
get_dtype_size,
sympy_product,
)
from .virtualized import V
log = logging.getLogger(__name__)
def supported_dtype_of_cpp_wrapper(dtype):
supported_dtype = {
torch.float32,
torch.float64,
torch.int64,
torch.int32,
torch.int16,
torch.int8,
torch.uint8,
torch.bool,
# torch.float16, # TODO: implement this
# torch.bfloat16, # TODO: implement this
}
return dtype in supported_dtype
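# Illustrative usage, following the table above: float32 is supported by the
# C++ wrapper while float16 is not yet implemented.
#   supported_dtype_of_cpp_wrapper(torch.float32)  # True
#   supported_dtype_of_cpp_wrapper(torch.float16)  # False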
class GraphLowering(torch.fx.Interpreter):
def symbolic_sizes_strides(self, ex: torch.Tensor):
"""
Support dynamic shapes and dynamic strides by assigning variables
        to each dimension. We duck-shape tensors, so if two dimensions
        have the same size they are assigned the same symbolic variable.
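
        An illustrative sketch (the symbol names are not guaranteed): for a
        contiguous ``torch.empty(3, 3)`` input, duck shaping maps both size-3
        dims and the stride of 3 onto one symbol, giving roughly
        ``([s0, s0], [s0, 1])``.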
"""
if self.reuse_shape_env:
return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor(
ex.stride()
)
else:
from torch._dynamo.source import ConstantSource
# TODO: this should not be needed once #93059 lands
# https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
# TODO: make a dedicated UnknownSource for this?
source = ConstantSource(
f"__unknown_tensor_{len(self._shape_env.var_to_val)}"
)
(
size,
stride,
_,
) = self._shape_env.create_symbolic_sizes_strides_storage_offset(ex, source)
size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size]
stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
return size, stride
def static_sizes_strides(self, ex: torch.Tensor):
"""
        Primarily used for weights.
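
        As an illustrative example, a contiguous ``(2, 3)`` tensor yields
        ``([2, 3], [3, 1])`` as lists of ``sympy.Integer``.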
"""
size = [sympy.Integer(i) for i in ex.size()]
stride = [sympy.Integer(i) for i in ex.stride()]
return size, stride
def __init__(
self,
gm: torch.fx.GraphModule,
shape_env=None,
num_static_inputs=None,
graph_id=None,
):
super().__init__(gm)
if shape_env is None:
shape_env = ShapeEnv()
self.reuse_shape_env = False
        else:
            self.reuse_shape_env = True
        self._shape_env = shape_env
self.sizevars = SizeVarAllocator(shape_env)
self.graph_inputs: Dict[str, TensorBox] = {}
self.graph_inputs_original: Dict[str, InputBuffer] = {}
self.graph_outputs: Optional[List[ir.IRNode]] = None
self.device_types: Set[str] = set()
self.buffers: List[ir.ComputedBuffer] = []
self.constants: Dict[str, torch.Tensor] = {}
self.removed_buffers: Set[str] = set()
self.inplaced_to_remove: Set[str] = set()
self.wrapper_code = None
self.num_static_inputs = num_static_inputs
self.mutated_inputs: Set[str] = set()
self.unaligned_buffers: Set[str] = set()
self.randomness_offset = sympy.Integer(0)
self.randomness_seeds: List[str] = []
self.name_to_buffer: Dict[str, ir.ComputedBuffer] = {}
self.creation_time = time.time()
self.name = "GraphLowering"
self._can_use_cpp_wrapper = config.cpp_wrapper
self.graph_id = graph_id
self.scheduler = None
self._warned_fallback = {"aten.convolution_backward"}
def warn_fallback(self, name):
if name not in self._warned_fallback:
self._warned_fallback.add(name)
log.info(f"Using FallbackKernel: {name}")
@property
def fake_mode(self):
return V.fake_mode
def get_dtype(self, buffer_name: str):
if buffer_name in self.constants:
return self.constants[buffer_name].dtype
if buffer_name in self.name_to_buffer:
return self.name_to_buffer[buffer_name].get_dtype()
if buffer_name in self.graph_inputs:
return self.graph_inputs[buffer_name].get_dtype()
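        # Buffer names like "as_strided(buf0, ...)" resolve to the dtype of
        # the wrapped buffer.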
m = re.match(r"as_strided\(([a-zA-Z0-9_]+),", buffer_name)
if m:
return self.get_dtype(m.group(1))
raise KeyError(f"could not find {buffer_name}")
def random_seed_buffer(self, device: torch.device):
"""
Return a device-unique 1-element tensor storing our RNG seed.
This will get initialized at the start of each graph in
`wrapper.py`.
        Note that this is only used by CUDA backends. The CPU backend
        handles RNG seeds as a sizevar.
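
        For example (following the name scheme below), ``torch.device("cuda", 0)``
        maps to a constant named ``seed_cuda_0``.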
"""
name = f"seed_{device.type}_{device.index}"
if name not in self.constants:
self.constants[name] = torch.zeros((), device=device, dtype=torch.int64)
self.randomness_seeds.append(name)
return ir.RandSeedBuffer(
name=name,
layout=ir.FixedLayout(
device=device,
dtype=torch.int64,
size=[],
stride=[],
),
)
def increment_randomness_offset(self, numel):
"""
        Reserve ``numel`` states from the global counter of random numbers
        handed out so far, returning the previous offset.
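
        For example (illustrative), two successive calls with ``numel=16``
        return offsets 0 and then 16.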
"""
offset = self.randomness_offset
self.randomness_offset = offset + numel
return offset
@dynamo_timed
def run(self, *args):
return super().run(*args)
def disable_cpp_wrapper(self, cond):
self._can_use_cpp_wrapper = False
log.debug("Set _can_use_cpp_wrapper to False due to %s", cond)
def check_buffer_for_cpp_wrapper(self, buffer: ir.ComputedBuffer):
if isinstance(buffer, ir.ExternKernel):
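            # Extern kernels without a C++ kernel entry point cannot be
            # emitted by the C++ wrapper codegen.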
if not getattr(buffer, "cpp_kernel", False):
self.disable_cpp_wrapper("ExternKernel")
def register_buffer(self, buffer: ir.ComputedBuffer):
if config.cpp_wrapper:
self.check_buffer_for_cpp_wrapper(buffer)
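        # Buffers are named positionally: the first registered buffer becomes
        # "buf0", the next "buf1", and so on.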
name = f"buf{len(self.buffers)}"
self.buffers.append(buffer)
self.name_to_buffer[name] = buffer
return name
def realize_users_of(self, name: str):
"""
When a buffer is mutated we need to make sure all the reads to
the old version are realized before the mutation happens.
"""
assert isinstance(name, str)
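        # Recursively walk (possibly nested) lists/tuples of IR nodes and
        # realize any node that reads from `name`.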
def visit(value):
if isinstance(value, (list, tuple)):
return [visit(x) for x in value]
if isinstance(value, ir.IRNode):
if value.is_user_of(name):
value.realize()
return value
for key, value in self.env.items():
try:
visit(value)
except Exception:
log.warning("error in realize_users_of", exc_info=True)
def add_tensor_constant(self, data):
def allocate():
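            # Deduplicate: reuse an existing constant if an identical tensor
            # (same size, stride, dtype, device, and values) is already
            # registered.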
for name, value in self.constants.items():
if (
data.size() == value.size()
and data.stride() == value.stride()
and data.dtype == value.dtype
and data.device == value.device
and torch.eq(data, value).all()
):
return name
name = f"constant{len(self.constants)}"
self.constants[name] = data
return name
return TensorBox.create(
ir.ConstantBuffer(
allocate(),
FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)),
)
)
def constant_name(self, name: str, device_override: torch.device):
"""
        We AOT-copy constants to the devices where they are needed.
        If device_override doesn't match the constant's device, copy
        it over and return a different name.
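
        For example (illustrative), overriding ``constant0`` to cuda:0
        registers a copy under the name ``constant0_cuda0``.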
"""
        if device_override is None or self.constants[name].device == device_override:
return name
alt_name = f"{name}_{device_override.type}{device_override.index or 0}"
if alt_name not in self.constants:
self.constants[alt_name] = self.constants[name].to(device_override)
return alt_name
def placeholder(self, target: str, args, kwargs):
example: torch.Tensor = super().placeholder(target, args, kwargs)
        # TODO(chilli): We can remove the last check once we turn buffers into
        # static-shape tensors. It is a hack to work around Inductor believing
        # the buffer should be static while we pass in a fake tensor with
        # symbolic shapes.
if (
config.static_weight_shapes
and (
len(self.graph_inputs) < self.num_static_inputs
or not dynamo_config.dynamic_shapes
)
and not example._has_symbolic_sizes_strides
):
# the first N inputs are weights
sizes, strides = self.static_sizes_strides(example)
else:
sizes, strides = self.symbolic_sizes_strides(example)
# TODO(jansel): handle input aliasing
tensor = TensorBox.create(
InputBuffer(
target,
FixedLayout(example.device, example.dtype, sizes, strides),
)
)
self.graph_inputs[target] = tensor
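        # tensor.data.data unwraps TensorBox -> StorageBox to recover the raw
        # InputBuffer created above.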
self.graph_inputs_original[target] = tensor.data.data
self.device_types.add(example.device.type)
return tensor
def call_function(self, target, args, kwargs):
with ir.IRNode.current_origins(gather_origins(args, kwargs)):
if target is operator.getitem and isinstance(args[0], (list, tuple)):
return super().call_function(target, args, kwargs)
if hasattr(target, "_inductor_lowering_function"):
# passthrough lowerings from .pattern_matcher
return target(*args, **kwargs)
if target not in lowerings:
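                # No lowering registered for this op: fall back to the eager
                # kernel when allowed, otherwise raise a MissingOperator error
                # explaining how to proceed.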
base_name = target.name().split(".")[0]
if base_name in FALLBACK_ALLOW_LIST:
make_fallback(target)
elif config.implicit_fallbacks:
error = (
MissingOperatorWithDecomp
if get_decompositions([target])
else MissingOperatorWithoutDecomp
)
log.info(
"Creating implicit fallback for:\n%s",
error.operator_str(target, args, kwargs),
)
make_fallback(target)
elif get_decompositions([target]):
# There isn't a good way to dynamically patch this in
# since AOT Autograd already ran. The error message tells
# the user how to fix it.
raise MissingOperatorWithDecomp(target, args, kwargs)
else:
raise MissingOperatorWithoutDecomp(target, args, kwargs)
try:
out = lowerings[target](*args, **kwargs)
return out
except Exception as e:
log.exception("Error from lowering")
raise LoweringException(e, target, args, kwargs) from e
def get_attr(self, target, args, kwargs):
# this is a constant
value = getattr(self.module, target)
with no_dispatch():
if value.shape == ():
return Constant(value.item(), value.dtype, value.device)
if len(value.shape) == 1 and value.shape[0] <= 8:
                # tensor lowering has constant inlining logic
                from .lowering import tensor

                return tensor(value.tolist(), dtype=value.dtype, device=value.device)

        return self.add_tensor_constant(value)