import operator
from copy import deepcopy
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from warnings import warn
import torch
import torch.overrides
from torch._prims_common import (
_torch_dtype_to_nvfuser_dtype_map,
getnvFuserDtype,
Number,
number_type,
)
from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
if torch.cuda.is_available():
from nvfuser._C import ( # type: ignore[import]
DataType,
Fusion,
FusionDefinition,
Tensor,
)
else:
DataType = None
import os
@lru_cache(None)
def get_nvprim_dump_nvtx():
return os.getenv("PYTORCH_NVFUSER_DUMP_NVTX")
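# A hedged usage note: any non-empty value enables the NVTX ranges emitted in
# nvfuser_execute below, e.g.
#   PYTORCH_NVFUSER_DUMP_NVTX=1 python script.py
# The lru_cache means the environment variable is read once per process.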
DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType(
{
"use_python_fusion_cache": True,
"allow_single_op_fusion": False,
}
)
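# A hedged usage note: these defaults can be overridden per call, e.g.
#   nvfuser_execute(gm, *args,
#                   executor_parameters={"use_python_fusion_cache": False})
# which makes nvfuser_execute bypass the lru_cache on make_nvfuser_fusion
# (see the __wrapped__ call below).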
# nvFuserTensorTemplate and nvFuserScalarTemplate are helper objects
# for cached construction of nvFuser's Fusion
# TODO: change what is stored in the cache for nvFuser's Tensor objects
# https://github.com/pytorch/pytorch/issues/80551
@dataclass(frozen=True)
class nvFuserTensorTemplate:
symbolic_shape: tuple
contiguity: tuple
dtype: DataType
is_cpu: bool
@dataclass(frozen=True)
class nvFuserScalarTemplate:
dtype: DataType
@lru_cache(maxsize=2048)
def compute_symbolic_shape(shape):
"""Computes the symbolic shape of a tensor.
    nvFuser specializes size-1 dimensions as broadcast dimensions;
    -1 represents any other size."""
return tuple(1 if s == 1 else -1 for s in shape)
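# For example, this follows directly from the rule above:
#   compute_symbolic_shape((1, 3, 1, 5))  ->  (1, -1, 1, -1)
# so tensors of shapes (1, 3, 1, 5) and (1, 7, 1, 2) share one cached fusion,
# while a tensor of shape (2, 3, 1, 5) maps to (-1, -1, 1, -1) and does not.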
@lru_cache(maxsize=2048)
def compute_contiguity(shape, strides):
"""Computes the contiguity information to simplify internal indexing.
Contiguous dimensions are represented by True, strided dimensions
are represented by False.
"""
from nvfuser._C import compute_contiguity
return compute_contiguity(shape, strides)
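# A hedged illustration (the actual computation is delegated to nvfuser._C):
# a row-major tensor of shape (2, 3) with strides (3, 1) is contiguous in
# every dimension, so the expected result is (True, True); its transpose,
# shape (3, 2) with strides (1, 3), is not.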
def to_nvfuser_template_args(args):
def to_nvfuser(arg):
if isinstance(arg, torch.Tensor):
return nvFuserTensorTemplate(
compute_symbolic_shape(arg.size()),
compute_contiguity(arg.size(), arg.stride()),
getnvFuserDtype(arg.dtype),
arg.is_cpu, # type: ignore[attr-defined]
)
elif isinstance(arg, Number):
return nvFuserScalarTemplate(getnvFuserDtype(number_type(arg)))
else:
return arg
return tree_map(to_nvfuser, args)
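# A hedged illustration: a contiguous CUDA float32 tensor of shape (1, 4)
# would map to roughly
#   nvFuserTensorTemplate(symbolic_shape=(1, -1),
#                         contiguity=<computed by nvfuser._C>,
#                         dtype=DataType.Float, is_cpu=False)
# and a Python float argument likely maps to
# nvFuserScalarTemplate(DataType.Double), since Python scalars are treated as
# double precision.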
def _any_get_attr_used(call_function_nodes):
    # True if any call_function node consumes a get_attr node, i.e. a constant
    # tensor stored on the GraphModule
    return any(
        any(a.op == "get_attr" for a in n.args if isinstance(a, torch.fx.Node))
        for n in call_function_nodes
    )
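# For instance, a graph fragment like
#   %weight = get_attr[target=weight]
#   %out = call_function[target=torch.ops.nvprims.add](%x, %weight)
# returns True here, and make_nvfuser_fusion below rejects such graphs.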
# MyPy bug: https://github.com/python/mypy/issues/5107
@lru_cache(maxsize=1024) # type: ignore[arg-type]
def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
if not torch.cuda.is_available():
raise RuntimeError(
"Attempting to use nvFuser trace executor but CUDA is not available!"
)
# Everything in the graph must support nvfuser
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == operator.getitem:
continue
if (
node.op == "call_function"
and getattr(node.target, "impl_nvfuser", None) is None
):
raise ValueError(
"All call_function nodes in the graph must support nvfuser. "
f"Node {node} with target {node.target} does not support nvfuser"
)
graph_input_nodes = list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))
call_function_nodes = list(
filter(lambda n: n.op == "call_function", gm.graph.nodes)
)
assert len(graph_input_nodes) == len(
nv_args_templates
), "Number of placeholder nodes in the graph must match number of args"
assert len(nv_args_templates) > 0, "There must be at least one argument"
assert (
len(call_function_nodes) > 0
), "Graph must contain at least one call_function node"
assert not _any_get_attr_used(
call_function_nodes
), "Constant tensors that are saved in the graph and used as arguments are not supported yet"
    # Collect the original flat outputs so their dtypes can be restored below
output_node = next(filter(lambda n: n.op == "output", gm.graph.nodes))
orig_flat_out, _ = tree_flatten(output_node.args[0])
fusion = Fusion()
with FusionDefinition(fusion) as fd:
def _to_nvfuser_constant(arg):
if isinstance(arg, Number):
return fd.define_constant(arg)
else:
return arg
class FusionInterpreter(torch.fx.Interpreter):
def run_node(self, node):
# Squeeze requires original shape of args[0]
if node.target in [
torch.ops.nvprims.squeeze,
torch.ops.nvprims.squeeze.default,
]:
original_shape = list(node.args[0].meta["tensor_meta"].shape)
assert len(node.args) == 2
args, kwargs = self.fetch_args_kwargs_from_env(node)
args = [args[0], original_shape, args[1]]
return self.call_function(node.target, args, node.kwargs)
if node.target in [
torch.ops.nvprims.native_batch_norm,
torch.ops.nvprims.native_batch_norm.default,
]:
args, kwargs = self.fetch_args_kwargs_from_env(node)
assert len(args) == 8
training = args[5]
args6_end = tuple(map(_to_nvfuser_constant, args[6:]))
args = args[:5] + (training,) + args6_end
return node.target.impl_nvfuser(fd, *args, **kwargs)
return super().run_node(node)
def call_function(self, target, args, kwargs):
# This handles tuple unpacking
if target == operator.getitem:
assert isinstance(args[0], tuple)
return target(*args, **kwargs)
args = tuple(map(_to_nvfuser_constant, args))
target = target.impl_nvfuser
args = (fd,) + args
return target(*args, **kwargs)
def output(self, target, args, kwargs):
flat_out, unflatten_spec = tree_flatten(args[0])
for o, orig_o in zip(flat_out, orig_flat_out):
                    # Cast each output back to its original dtype so the
                    # fused results always agree with the original GraphModule
out_dtype = _torch_dtype_to_nvfuser_dtype_map.get(orig_o.meta["tensor_meta"].dtype) # type: ignore[union-attr]
assert isinstance(
o, Tensor
), "output from codegen has to be tensor type"
fd.add_output(fd.ops.cast(o, dtype=out_dtype))
return args[0]
def templates_to_nvfuser_inputs(arg):
if isinstance(arg, nvFuserTensorTemplate):
x = fd.define_tensor(
arg.symbolic_shape, arg.contiguity, arg.dtype, arg.is_cpu
)
return x
elif isinstance(arg, nvFuserScalarTemplate):
x = fd.define_scalar(arg.dtype)
return x
else:
return arg
# Transforms graph to call nvfuser lowerings
nv_args = tuple(map(templates_to_nvfuser_inputs, nv_args_templates))
out = FusionInterpreter(gm).run(*nv_args)
flat_out, unflatten_spec = tree_flatten(out)
return fusion, unflatten_spec
def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None):
executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG
flat_args, _ = tree_flatten(args)
    # nvFuser fusions require at least one CUDA tensor input; the only CPU
    # tensors accepted are 0-dim (scalar) tensors
if any(isinstance(arg, torch.Tensor) and arg.is_cuda for arg in flat_args) and all( # type: ignore[attr-defined]
(
not isinstance(arg, torch.Tensor)
or (arg.is_cpu and arg.ndim == 0) # type: ignore[attr-defined]
or arg.is_cuda # type: ignore[attr-defined]
)
for arg in flat_args
):
# Construction of the fusion is expensive and cached based on the GraphModule
# and symbolic nvFuser args.
nv_template_args = to_nvfuser_template_args(flat_args)
use_cache = executor_parameters.get(
"use_python_fusion_cache",
DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"],
)
if use_cache:
fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args) # type: ignore[misc]
else:
fusion, unflatten_spec = make_nvfuser_fusion.__wrapped__(gm, *nv_template_args) # type: ignore[misc]
# Inputs to fusion.execute correspond to the same template/symbolic inputs
# marked with `define_tensor/scalar`
concrete_fusion_inputs = tuple(
arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number))
)
if get_nvprim_dump_nvtx():
torch.cuda.nvtx.range_push(
"fusion: {0}, graph: {1}".format(
fusion.id(),
str(
[
{
"op": n.op,
"name": n.name,
"args": n.args,
"kwargs": n.kwargs,
}
for n in gm.graph.nodes
]
),
)
)
result = tree_unflatten(
fusion.execute(concrete_fusion_inputs), # type: ignore[has-type]
unflatten_spec, # type: ignore[has-type]
)
if get_nvprim_dump_nvtx():
torch.cuda.nvtx.range_pop()
return result
else:
        warn(
            "nvfuser_executor called with non-CUDA args; falling back to the ATen executor"
        )
return gm.forward(*args)
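# A minimal end-to-end sketch, assuming the nvprims tracing entry points of
# this era (make_fx and TorchRefsNvfuserCapabilityMode); illustrative, not a
# supported public API:
#
#   from torch._prims.context import TorchRefsNvfuserCapabilityMode
#   from torch.fx.experimental.proxy_tensor import make_fx
#
#   def fn(a, b):
#       return torch.sigmoid(a + b)
#
#   a = torch.randn(4, 4, device="cuda")
#   b = torch.randn(4, 4, device="cuda")
#   with TorchRefsNvfuserCapabilityMode():
#       gm = make_fx(fn)(a, b)       # FX graph of nvprims ops
#   out = nvfuser_execute(gm, a, b)  # compiled and run by nvFuser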
class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # Special case: stop lowering to nvprims when converting to a dtype nvFuser does not support
if (
node.op == "call_function"
and node.target == torch.ops.nvprims.convert_element_type.default
):
return (
_torch_dtype_to_nvfuser_dtype_map.get(node.args[1]) is not None
and _torch_dtype_to_nvfuser_dtype_map.get(
node.args[0].meta["tensor_meta"].dtype # type: ignore[union-attr]
)
is not None
)
return node.op == "call_function" and (
getattr(node.target, "impl_nvfuser", None) is not None
or node.target == operator.getitem
)
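# A hedged sketch of how this support class feeds the FX partitioner
# (partition_and_fuse is assumed from CapabilityBasedPartitioner's API):
#
#   partitioner = CapabilityBasedPartitioner(
#       gm, NvfuserPrimOperatorSupport(), allows_single_node_partition=True
#   )
#   partitioned_gm = partitioner.partition_and_fuse()
#
# The fused submodules are named "fused_<id>", which PartitionedInterpreter
# below relies on to dispatch them to nvfuser_execute.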
class PartitionedInterpreter(torch.fx.Interpreter):
def call_module(self, target, args, kwargs):
assert isinstance(target, str)
assert len(kwargs) == 0
submod = self.fetch_attr(target)
        # CapabilityBasedPartitioner names each subgraph of supported ops "fused_" + subgraph id
if target.startswith("fused_"):
return nvfuser_execute(submod, *args)
else:
return super().call_module(target, args, kwargs)
class NvfuserGraphModule(torch.nn.Module):
def __init__(self, gm, use_python_fusion_cache):
super().__init__()
self.gm = gm
self.executor_parameters = {"use_python_fusion_cache": use_python_fusion_cache}
def __call__(self, *args):
return nvfuser_execute(
self.gm, *args, executor_parameters=self.executor_parameters
)
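# A hedged sketch of the intended wiring: after partitioning, each "fused_*"
# submodule can be swapped for an NvfuserGraphModule so that calling the
# partitioned graph dispatches straight to nvfuser_execute:
#
#   partitioned_gm.delete_submodule("fused_0")
#   partitioned_gm.add_submodule(
#       "fused_0", NvfuserGraphModule(fused_0, use_python_fusion_cache=True)
#   )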
# A set of operators that are supported by nvFuser
# but should not form a fusion group solely on their own
_non_compute_ops = [
"torch.ops." + str(getattr(torch.ops.nvprims, prim).default)
for prim in dir(torch.ops.nvprims)
if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket)
and getattr(torch.ops.nvprims, prim).return_type
== torch._prims_common.RETURN_TYPE.VIEW
]
_allowed_single_node_partition_ops = [