# _prims/nvfuser_executor.py

import operator
from copy import deepcopy
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from warnings import warn

import torch
import torch.overrides
from torch._prims_common import (
    _torch_dtype_to_nvfuser_dtype_map,
    getnvFuserDtype,
    Number,
    number_type,
)

from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

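# The nvFuser Python bindings are importable only when CUDA is available;
# otherwise DataType is left as a None sentinel so this module can still be imported.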
if torch.cuda.is_available():
    from nvfuser._C import (  # type: ignore[import]
        DataType,
        Fusion,
        FusionDefinition,
        Tensor,
    )
else:
    DataType = None

import os


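# Cached lookup of the PYTORCH_NVFUSER_DUMP_NVTX environment variable; when it is
# set, nvfuser_execute wraps each fusion execution in an NVTX range for profiling.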
@lru_cache(None)
def get_nvprim_dump_nvtx():
    return os.getenv("PYTORCH_NVFUSER_DUMP_NVTX")


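# Default executor knobs: "use_python_fusion_cache" toggles the lru_cache around
# make_nvfuser_fusion below; "allow_single_op_fusion" is presumably read by the
# graph-partitioning logic further down in this file (not shown here) to decide
# whether a lone supported op may form its own fusion group.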
DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType(
    {
        "use_python_fusion_cache": True,
        "allow_single_op_fusion": False,
    }
)

# nvFuserTensorTemplate and nvFuserScalarTemplate are helper objects
# for cached construction of nvFuser's Fusion object
# TODO: change what is stored in the cache for nvFuser's Tensor objects
# https://github.com/pytorch/pytorch/issues/80551
@dataclass(frozen=True)
class nvFuserTensorTemplate:
    symbolic_shape: tuple
    contiguity: tuple
    dtype: DataType
    is_cpu: bool


@dataclass(frozen=True)
class nvFuserScalarTemplate:
    dtype: DataType


@lru_cache(maxsize=2048)
def compute_symbolic_shape(shape):
    """Computes the symbolic shape of a tensor.
    nvFuser specializes on size-1 dimensions as broadcasted dimensions.
    -1 is used to represent any size."""
    return tuple(1 if s == 1 else -1 for s in shape)


@lru_cache(maxsize=2048)
def compute_contiguity(shape, strides):
    """Computes the contiguity information to simplify internal indexing.
    Contiguous dimensions are represented by True, strided dimensions
    are represented by False.
    """
    from nvfuser._C import compute_contiguity

    return compute_contiguity(shape, strides)


def to_nvfuser_template_args(args):
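    """Maps concrete tensors and Python numbers in ``args`` to hashable nvFuser
    templates (so they can be used as lru_cache keys); other values pass through."""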
    def to_nvfuser(arg):
        if isinstance(arg, torch.Tensor):
            return nvFuserTensorTemplate(
                compute_symbolic_shape(arg.size()),
                compute_contiguity(arg.size(), arg.stride()),
                getnvFuserDtype(arg.dtype),
                arg.is_cpu,  # type: ignore[attr-defined]
            )
        elif isinstance(arg, Number):
            return nvFuserScalarTemplate(getnvFuserDtype(number_type(arg)))
        else:
            return arg

    return tree_map(to_nvfuser, args)


def _any_get_attr_used(call_function_nodes):
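    """Returns True if any call_function node takes a get_attr node
    (i.e. a constant tensor stored on the module) as an argument."""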
    return any(
        filter(
            # bug in mypy https://github.com/python/mypy/issues/12682
            lambda n: any(  # type: ignore[arg-type]
                a.op == "get_attr" for a in n.args if isinstance(a, torch.fx.Node)  # type: ignore[attr-defined]
            ),
            call_function_nodes,
        )
    )


# MyPy bug: https://github.com/python/mypy/issues/5107
@lru_cache(maxsize=1024)  # type: ignore[arg-type]
def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
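    """Traces ``gm`` (an FX graph of nvprims) into an nvFuser Fusion and returns it
    together with the pytree spec needed to unflatten the fusion's outputs.
    Results are cached on the GraphModule and the symbolic argument templates."""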
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Attempting to use nvFuser trace executor but CUDA is not available!"
        )

    # Everything in the graph must support nvfuser
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == operator.getitem:
            continue
        if (
            node.op == "call_function"
            and getattr(node.target, "impl_nvfuser", None) is None
        ):
            raise ValueError(
                "All call_function nodes in the graph must support nvfuser. "
                f"Node {node} with target {node.target} does not support nvfuser"
            )

    graph_input_nodes = list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))
    call_function_nodes = list(
        filter(lambda n: n.op == "call_function", gm.graph.nodes)
    )
    assert len(graph_input_nodes) == len(
        nv_args_templates
    ), "Number of placeholder nodes in the graph must match number of args"
    assert len(nv_args_templates) > 0, "There must be at least one argument"
    assert (
        len(call_function_nodes) > 0
    ), "Graph must contain at least one call_function node"
    assert not _any_get_attr_used(
        call_function_nodes
    ), "Constant tensors that are saved in the graph and used as arguments are not supported yet"

    # Checking output dtypes
    output_node = next(filter(lambda n: n.op == "output", gm.graph.nodes))
    orig_flat_out, _ = tree_flatten(output_node.args[0])

    fusion = Fusion()
    with FusionDefinition(fusion) as fd:

        def _to_nvfuser_constant(arg):
            if isinstance(arg, Number):
                return fd.define_constant(arg)
            else:
                return arg

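        # Replays the FX graph node-by-node, lowering each nvprim through its
        # ``impl_nvfuser`` implementation against the FusionDefinition ``fd``.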
        class FusionInterpreter(torch.fx.Interpreter):
            def run_node(self, node):
                # Squeeze requires original shape of args[0]
                if node.target in [
                    torch.ops.nvprims.squeeze,
                    torch.ops.nvprims.squeeze.default,
                ]:
                    original_shape = list(node.args[0].meta["tensor_meta"].shape)
                    assert len(node.args) == 2
                    args, kwargs = self.fetch_args_kwargs_from_env(node)
                    args = [args[0], original_shape, args[1]]
                    return self.call_function(node.target, args, node.kwargs)

                if node.target in [
                    torch.ops.nvprims.native_batch_norm,
                    torch.ops.nvprims.native_batch_norm.default,
                ]:
                    args, kwargs = self.fetch_args_kwargs_from_env(node)
                    assert len(args) == 8
                    training = args[5]
                    args6_end = tuple(map(_to_nvfuser_constant, args[6:]))
                    args = args[:5] + (training,) + args6_end
                    return node.target.impl_nvfuser(fd, *args, **kwargs)

                return super().run_node(node)

            def call_function(self, target, args, kwargs):
                # This handles tuple unpacking
                if target == operator.getitem:
                    assert isinstance(args[0], tuple)
                    return target(*args, **kwargs)
                args = tuple(map(_to_nvfuser_constant, args))
                target = target.impl_nvfuser
                args = (fd,) + args
                return target(*args, **kwargs)

            def output(self, target, args, kwargs):
                flat_out, unflatten_spec = tree_flatten(args[0])
                for o, orig_o in zip(flat_out, orig_flat_out):
                    # Cast each output back to its original dtype so results produced
                    # by the fusion always agree with the original GraphModule.
                    out_dtype = _torch_dtype_to_nvfuser_dtype_map.get(orig_o.meta["tensor_meta"].dtype)  # type: ignore[union-attr]
                    assert isinstance(
                        o, Tensor
                    ), "output from codegen has to be tensor type"
                    fd.add_output(fd.ops.cast(o, dtype=out_dtype))
                return args[0]

        def templates_to_nvfuser_inputs(arg):
            if isinstance(arg, nvFuserTensorTemplate):
                x = fd.define_tensor(
                    arg.symbolic_shape, arg.contiguity, arg.dtype, arg.is_cpu
                )
                return x
            elif isinstance(arg, nvFuserScalarTemplate):
                x = fd.define_scalar(arg.dtype)
                return x
            else:
                return arg

        # Transforms graph to call nvfuser lowerings
        nv_args = tuple(map(templates_to_nvfuser_inputs, nv_args_templates))
        out = FusionInterpreter(gm).run(*nv_args)
        flat_out, unflatten_spec = tree_flatten(out)

    return fusion, unflatten_spec


def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None):
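    """Executes ``gm`` with nvFuser when at least one argument is a CUDA tensor and
    every tensor argument is either on CUDA or a zero-dim CPU tensor; otherwise
    warns and falls back to ``gm.forward``."""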
    executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG
    flat_args, _ = tree_flatten(args)

    # Check that this is a CUDA-only fusion: at least one input must be a CUDA
    # tensor, and every other tensor input must be CUDA or a zero-dim CPU tensor.
    if any(isinstance(arg, torch.Tensor) and arg.is_cuda for arg in flat_args) and all(  # type: ignore[attr-defined]
        (
            not isinstance(arg, torch.Tensor)
            or (arg.is_cpu and arg.ndim == 0)  # type: ignore[attr-defined]
            or arg.is_cuda  # type: ignore[attr-defined]
        )
        for arg in flat_args
    ):

        # Construction of the fusion is expensive and cached based on the GraphModule
        # and symbolic nvFuser args.
        nv_template_args = to_nvfuser_template_args(flat_args)
        use_cache = executor_parameters.get(
            "use_python_fusion_cache",
            DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"],
        )
        if use_cache:
            fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args)  # type: ignore[misc]
        else:
            fusion, unflatten_spec = make_nvfuser_fusion.__wrapped__(gm, *nv_template_args)  # type: ignore[misc]

        # Inputs to fusion.execute correspond to the same template/symbolic inputs
        # marked with `define_tensor/scalar`
        concrete_fusion_inputs = tuple(
            arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number))
        )

        if get_nvprim_dump_nvtx():
            torch.cuda.nvtx.range_push(
                "fusion: {0}, graph: {1}".format(
                    fusion.id(),
                    str(
                        [
                            {
                                "op": n.op,
                                "name": n.name,
                                "args": n.args,
                                "kwargs": n.kwargs,
                            }
                            for n in gm.graph.nodes
                        ]
                    ),
                )
            )
        result = tree_unflatten(
            fusion.execute(concrete_fusion_inputs),  # type: ignore[has-type]
            unflatten_spec,  # type: ignore[has-type]
        )
        if get_nvprim_dump_nvtx():
            torch.cuda.nvtx.range_pop()
        return result
    else:
        warn(
            "nvfuser_executor was called with non-CUDA args; falling back to the ATen executor"
        )
        return gm.forward(*args)

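# Usage sketch (the names below are illustrative, not part of this module): given an
# FX GraphModule ``gm`` whose call_function nodes are all nvprims exposing
# ``impl_nvfuser``, and CUDA inputs,
#
#     a = torch.randn(4, 4, device="cuda")
#     b = torch.randn(4, 4, device="cuda")
#     out = nvfuser_execute(gm, a, b)
#
# builds (or fetches from the cache) a Fusion for the graph and executes it on a, b.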

class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
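    """Operator-support policy for CapabilityBasedPartitioner: a node is claimable by
    nvFuser if it is a call_function whose target exposes ``impl_nvfuser`` (plus
    ``operator.getitem`` for tuple unpacking), with a carve-out for dtype conversions
    that nvFuser cannot represent."""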
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # special case to stop lowering to nvprim when converting to an unsupported type
        if (
            node.op == "call_function"
            and node.target == torch.ops.nvprims.convert_element_type.default
        ):
            return (
                _torch_dtype_to_nvfuser_dtype_map.get(node.args[1]) is not None
                and _torch_dtype_to_nvfuser_dtype_map.get(
                    node.args[0].meta["tensor_meta"].dtype  # type: ignore[union-attr]
                )
                is not None
            )
        return node.op == "call_function" and (
            getattr(node.target, "impl_nvfuser", None) is not None
            or node.target == operator.getitem
        )


class PartitionedInterpreter(torch.fx.Interpreter):
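    """Runs a partitioned GraphModule, routing the fused submodules produced by
    CapabilityBasedPartitioner (named "fused_<id>") through nvfuser_execute and
    everything else through the default FX interpreter."""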
    def call_module(self, target, args, kwargs):
        assert isinstance(target, str)
        assert len(kwargs) == 0
        submod = self.fetch_attr(target)
        # CapabilityBasedPartitioner hardcodes the name of the subgraphs with supported_ops as "fused_" + subgraph id
        if target.startswith("fused_"):
            return nvfuser_execute(submod, *args)
        else:
            return super().call_module(target, args, kwargs)


class NvfuserGraphModule(torch.nn.Module):
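    """Callable wrapper around a fused subgraph so the partitioned graph can invoke it
    like a regular module while execution is routed through nvfuser_execute."""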
    def __init__(self, gm, use_python_fusion_cache):
        super().__init__()
        self.gm = gm
        self.executor_parameters = {"use_python_fusion_cache": use_python_fusion_cache}

    def __call__(self, *args):
        return nvfuser_execute(
            self.gm, *args, executor_parameters=self.executor_parameters
        )


# A set of operators that are supported by nvFuser
# but should not form a fusion group solely on their own
_non_compute_ops = [
    "torch.ops." + str(getattr(torch.ops.nvprims, prim).default)
    for prim in dir(torch.ops.nvprims)
    if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket)
    and getattr(torch.ops.nvprims, prim).return_type
    == torch._prims_common.RETURN_TYPE.VIEW
]

_allowed_single_node_partition_ops = [