import dataclasses
import functools
import itertools
import logging
import sys
import warnings
from typing import Any, Dict, List, Optional
import functorch
from functorch.compile import min_cut_rematerialization_partition
import torch._dynamo.config as dynamo_config
import torch.fx
from torch._dynamo import logging as dynamo_logging, utils as dynamo_utils
from torch._dynamo.utils import fake_mode_from_tensors
from torch._functorch.aot_autograd import make_boxed_func
from torch._subclasses.fake_tensor import FakeTensor
from .._dynamo.backends.common import aot_autograd
from . import config, metrics, overrides, pattern_matcher
from .debug import DebugContext
from .decomposition import select_decomp_table
from .graph import GraphLowering
from .mkldnn import convert_outplace_to_inplace
from .utils import developer_warning, get_dtype_size, has_incompatible_cudagraph_ops
from .virtualized import V
log = logging.getLogger(__name__)
ALIGNMENT = 16
@dataclasses.dataclass
class BoxedBool:
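    """
    Mutable wrapper around a bool so that a shared flag (e.g. whether
    cudagraphs are in use) can be flipped off later via BoxedBool.disable.
    """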
value: bool
def __bool__(self):
return self.value
@staticmethod
def disable(obj):
if isinstance(obj, BoxedBool):
obj.value = False
return obj
return False
# copy_ fails when trying to write to tensors with memory overlap.
# For expanded dimensions (a dimension which used to have size 1 and was
# broadcast to a larger size via a zero stride), we can select a single
# element from that dimension and write to it; since every element of that
# dimension aliases the same storage, this writes to all of its values.
def get_expanded_dims(t):
return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]
def index_expanded_dims(t, expanded_dims):
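    """
    Slice each expanded dimension down to a single element; because every
    element of an expanded dim aliases the same storage, writing to that
    one element writes to the whole dim (see comment above).
    """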
for expanded_dim in expanded_dims:
t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
return t
def complex_memory_overlap(t):
# if torch._debug_has_internal_overlap thinks this tensor potentially has
# memory overlap internally, let's dig deeper to find out whether it's true.
if torch._debug_has_internal_overlap(t) != 0:
strides = t.stride()
sizes = t.shape
indices = list(range(len(strides)))
indices = [x for _, x in sorted(zip(strides, indices))]
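        # check dims in order of increasing stride: if a dim's stride is
        # smaller than the extent (stride * size) of the previous dim,
        # elements can alias and the overlap is real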
for i in range(len(strides)):
prev_stride = 1 if i == 0 else strides[indices[i - 1]]
prev_size = 1 if i == 0 else sizes[indices[i - 1]]
if strides[indices[i]] < prev_stride * prev_size:
return True
return False
@functools.lru_cache(None)
def _step_logger():
return dynamo_logging.get_step_logger(log)
@functools.lru_cache(None)
def _warn_tf32_disabled():
if (
torch.cuda.is_available()
and not torch.backends.cuda.matmul.allow_tf32
and torch.cuda.get_device_capability() >= (8, 0)
):
warnings.warn(
"TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
"Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
)
def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
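    """
    Return True if the graph contains a float32 CUDA matmul-like op
    (mm/addmm/bmm/baddbmm) that could benefit from TF32.
    """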
aten = torch.ops.aten
tf32_ops = {
aten.mm.default,
aten.addmm.default,
aten.bmm.default,
aten.baddbmm.default,
}
for node in gm.graph.nodes:
if (
node.op == "call_function"
and node.target in tf32_ops
and isinstance(node.meta.get("val", None), torch.Tensor)
and node.meta["val"].dtype == torch.float32
and node.meta["val"].device.type == "cuda"
):
return True
return False
@DebugContext.wrap
def count_bytes_inner(gm, example_inputs, num_fixed=0, **kwargs):
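    """
    Lower the graph only to count the bytes it would read/write and record
    the totals in `metrics`; the original gm.forward is returned (boxed).
    """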
shape_env = _shape_env_from_inputs(example_inputs)
graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
with V.set_graph_handler(graph):
graph.run(*example_inputs)
num_bytes, nodes_num_elem = graph.count_bytes()
metrics.num_bytes_accessed += num_bytes
metrics.nodes_num_elem += nodes_num_elem
return make_boxed_func(gm.forward)
@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
def compile_fx_inner(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
cudagraphs=None,
num_fixed=0,
is_backward=False,
graph_id=None,
):
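    """
    Compile a single forward or backward FX graph with inductor: lower it
    via GraphLowering, compile it to a callable, and optionally wrap the
    result with cudagraph recording and input-alignment handling.
    """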
if is_tf32_warning_applicable(gm):
_warn_tf32_disabled()
if dynamo_utils.count_calls(gm.graph) == 0:
return make_boxed_func(gm.forward)
    # lift the maximum depth of the Python interpreter stack
    # to accommodate large/deep models
sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))
_step_logger()(
logging.INFO,
"torchinductor compiling "
f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
f"graph {graph_id}",
)
V.debug.fx_graph(gm, example_inputs)
if cudagraphs is None:
cudagraphs = config.triton.cudagraphs
shape_env = _shape_env_from_inputs(example_inputs)
fake_mode = fake_mode_from_tensors(
example_inputs
) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
with V.set_fake_mode(fake_mode):
pattern_matcher.fx_passes(gm)
V.debug.fx_graph_transformed(gm, example_inputs)
graph = GraphLowering(
gm,
shape_env=shape_env,
num_static_inputs=num_fixed,
graph_id=graph_id,
)
with V.set_graph_handler(graph):
graph.run(*example_inputs)
compiled_fn = graph.compile_to_fn()
if cudagraphs:
complex_memory_overlap_inputs = any(
complex_memory_overlap(t) for t in example_inputs
)
if (
set(graph.device_types) == {"cuda"}
and not graph.mutated_inputs
and not has_incompatible_cudagraph_ops(gm)
and not complex_memory_overlap_inputs
):
compiled_fn = cudagraphify(
compiled_fn, example_inputs, static_input_idxs=range(num_fixed)
)
else:
BoxedBool.disable(cudagraphs)
if len(set(graph.device_types)) > 1:
developer_warning("skipping cudagraphs due to multiple devices")
elif set(graph.device_types) == {"cuda"}:
if graph.mutated_inputs:
developer_warning("skipping cudagraphs due to input mutation")
elif complex_memory_overlap_inputs:
developer_warning(
"skipping cudagraphs due to complex input striding"
)
result = align_inputs(compiled_fn, example_inputs, range(num_fixed))
_step_logger()(
logging.INFO,
"torchinductor done compiling "
f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
f"graph {graph_id}",
)
# aot autograd needs to know to pass in inputs as a list
result._boxed_call = True
return result
def clone_preserve_strides(x):
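    """
    Clone x into freshly allocated storage while keeping its exact sizes
    and strides (including any gaps in the layout).
    """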
needed_size = (
sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
)
buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
return torch.as_strided(buffer, x.size(), x.stride())
def align_inputs(model, inputs, static_input_idxs=()):
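    """
    Wrap `model` so that any CUDA input that may arrive unaligned (non-static,
    or static but unaligned at compile time) is cloned, preserving strides,
    whenever its data pointer is not ALIGNMENT-byte aligned at call time.
    """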
def is_aligned(storage_offset, dtype):
return (storage_offset * get_dtype_size(dtype)) % ALIGNMENT == 0
check_inputs = [
i
for i in range(len(inputs))
if (
i not in static_input_idxs
or not is_aligned(inputs[i].storage_offset(), inputs[i].dtype)
)
and inputs[i].device.type == "cuda"
]
if len(check_inputs) == 0:
return model
def run(new_inputs):
for i in check_inputs:
if new_inputs[i].data_ptr() % ALIGNMENT:
new_inputs[i] = clone_preserve_strides(new_inputs[i])
return model(new_inputs)
return run
@dynamo_utils.dynamo_timed
def cudagraphify(model, inputs, static_input_idxs=()):
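    """
    Wrap `model` with CUDA graph capture. With real inputs the graph is
    recorded immediately; with fake example inputs recording is deferred
    to the first call that provides real tensors.
    """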
# if using fake tensors, defer cudagraphs until we get real inputs at runtime
if not any(isinstance(inp, FakeTensor) for inp in inputs):
return cudagraphify_impl(model, inputs, static_input_idxs)
compiled_fn = None
def run(new_inputs):
nonlocal compiled_fn
if compiled_fn is None:
with dynamo_utils.preserve_rng_state():
compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)
return compiled_fn(new_inputs)
return run
def remove_unaligned_input_idxs(inputs, static_input_idxs):
"""
We require all inputs to be aligned, so introduce a copy for any
that aren't.
"""
aligned_static_input_idxs = {
idx for idx in static_input_idxs if (inputs[idx].data_ptr() % ALIGNMENT) == 0
}
if len(aligned_static_input_idxs) != len(static_input_idxs):
return aligned_static_input_idxs
return static_input_idxs
def cudagraphify_impl(model, inputs, static_input_idxs=()):
"""
    Assumes each inputs[static_input_idxs[i]] has the same memory address on every call
"""
static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
def static_input(x):
"""
        Copy an input while preserving strides
"""
# TODO(jansel): figure out why this version doesn't work:
# return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
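        # allocate a flat zero-filled buffer just large enough to hold the
        # strided layout, then view it with x's sizes and strides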
needed_size = (
sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
)
buffer = torch.zeros(needed_size, dtype=x.dtype, device=x.device)
return torch.as_strided(buffer, x.size(), x.stride())
assert isinstance(inputs, (list, tuple))
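    # static inputs are reused directly (their addresses are assumed stable);
    # all other inputs get freshly allocated placeholder buffers with
    # matching strides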
static_inputs = [
static_input(x) if idx not in static_input_idxs else x.detach()
for idx, x in enumerate(inputs)
]
inps_expanded_dims = [
get_expanded_dims(x) if idx not in static_input_idxs else []
for idx, x in enumerate(inputs)
]
# warmup
torch.cuda.synchronize()
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
    # pass a copy of static_inputs because the call clears the list it receives
with torch.cuda.stream(stream):
model(list(static_inputs))
stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()
# record
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
static_outputs = model(list(static_inputs))
if not isinstance(static_outputs, (list, tuple)):
static_outputs = (static_outputs,)
if config.size_asserts:
def run(new_inputs):
assert len(static_inputs) == len(new_inputs)
for idx, (dst, src, expanded_dims) in enumerate(
zip(static_inputs, new_inputs, inps_expanded_dims)
):
if idx in static_input_idxs:
assert dst.data_ptr() == src.data_ptr()
else:
# TODO - could make one single op of multiple slices
# and avoid dispatch.
# Could also pre-index the `dst` tensors
dst = index_expanded_dims(dst, expanded_dims)