# torch/_dynamo/backends/nvfuser.py

import logging
from functools import partial

import torch
from ..backends.common import aot_autograd, mem_efficient_fusion_kwargs
from .registry import register_backend, register_debug_backend

log = logging.getLogger(__name__)


def prims_executor(gm, inputs, *, executor):
    from functorch.compile import make_boxed_func

    # This function is called once per forward/backward pass of a graph in AOT
    # Autograd. We use it to set up the nvFuser-specific FX graph and return
    # the execute function.
    from torch._prims.context import TorchRefsNvfuserCapabilityMode
    from torch._prims.executor import execute
    from torch.fx.experimental.proxy_tensor import make_fx

    # AOT Autograd might not use the partitioner, so we need to make sure that
    # the graph is transformed to use nvFuser-compatible nodes.
    if not getattr(gm, "_nvprim_transformed", False):
        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(gm)(*inputs)

    # Then we return a callable that executes the "gm" graph
    return make_boxed_func(partial(execute, gm, executor=executor))


def nvprims_fw_bw_partition_fn(joint_module, joint_inputs, *, num_fwd_outputs):
    # This function is called once per forward+backward pass of a graph in AOT
    # Autograd. We use it to set up the nvFuser-specific FX graph that is later
    # passed to the executor.
    from functorch.compile import min_cut_rematerialization_partition

    from torch._prims.context import TorchRefsNvfuserCapabilityMode
    from torch.fx.experimental.proxy_tensor import make_fx

    # AOT Autograd expects arguments of the traced function to be named exactly
    # "primals, tangents"
    def func(primals, tangents):
        return joint_module(primals, tangents)

    # First we trace the graph conditionally decomposing nodes
    # that can be sent to the nvfuser executor
    with TorchRefsNvfuserCapabilityMode():
        prim_gm = make_fx(func)(*joint_inputs)

    # Ops the min-cut partitioner is allowed to recompute in the backward pass:
    # for now, every nvprim overload packet that reports is_recomputable.
    recomputable_ops = {
        getattr(torch.ops.nvprims, prim)
        for prim in dir(torch.ops.nvprims)
        if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket)
        and getattr(torch.ops.nvprims, prim).is_recomputable
    }

    fw_gm, bw_gm = min_cut_rematerialization_partition(
        prim_gm,
        joint_inputs,
        recomputable_ops=recomputable_ops,
        num_fwd_outputs=num_fwd_outputs,
    )
    # AOT Autograd might not use the partitioner, so we need to make sure that
    # the graph is marked as already transformed to use nvFuser-compatible nodes
    fw_gm._nvprim_transformed = True
    bw_gm._nvprim_transformed = True
    return fw_gm, bw_gm


def create_nvprims_backend(*, executor):
    return aot_autograd(
        fw_compiler=partial(prims_executor, executor=executor),
        bw_compiler=partial(prims_executor, executor=executor),
        partition_fn=nvprims_fw_bw_partition_fn,
    )


aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser")
aot_nvprims_aten = create_nvprims_backend(executor="aten")

# "nvprims" is a subset of PrimTorch primitives that are guaranteed to be
# supported by nvFuser. This is the preferred backend for nvFuser+PrimTorch.
register_backend(name="nvprims_nvfuser", compiler_fn=aot_nvprims_nvfuser)
# This is useful for debugging. Can be removed later.
register_debug_backend(name="nvprims_aten", compiler_fn=aot_nvprims_aten)
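
# Illustrative usage (a minimal sketch, not part of this module): once the
# backends above are registered, they can be selected by name through the
# dynamo entry point. The function `fn` below is a hypothetical example.
#
#   def fn(x):
#       return torch.nn.functional.gelu(x) * 2.0
#
#   opt_fn = torch._dynamo.optimize("nvprims_nvfuser")(fn)
#   out = opt_fn(torch.randn(8, 8, device="cuda"))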


# aot_ts_nvfuser: AOT Autograd's memory-efficient fusion configuration.
# It partitions the joint graph with the min-cut rematerialization algorithm,
# uses TorchScript as the frontend, and nvFuser ("fuser2") as the compiler
# backend.
aot_mem_efficient_fusion = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))
aot_mem_efficient_fusion.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
register_backend(name="aot_ts_nvfuser", compiler_fn=aot_mem_efficient_fusion)