Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ _dynamo / backends / debugging.py

import functools
from importlib import import_module

from functorch.compile import min_cut_rematerialization_partition, nop

import torch
from torch._functorch.compilers import ts_compile
from .common import aot_autograd
from .registry import register_debug_backend as register_backend

"""
This file contains TorchDynamo backends intended for debugging uses.
"""


@register_backend
def eager(gm, fake_tensor_inputs):
    """Debug backend that performs no compilation.

    Returns ``gm`` exactly as received; ``fake_tensor_inputs`` is accepted
    but not used.
    """
    return gm


@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
    """Debug backend, registered under the name ``"ts"``, that scripts ``gm``.

    Returns the result of ``torch.jit.script(gm)``; ``fake_tensor_inputs`` is
    accepted but not used.
    """
    return torch.jit.script(gm)


# Debugging backend: AOT Autograd with the `nop` compiler passed as
# fw_compiler. Only fw_compiler is given; every other aot_autograd option is
# left at its default.
aot_eager = aot_autograd(fw_compiler=nop)
# Make the backend available under the name "aot_eager".
register_backend(name="aot_eager", compiler_fn=aot_eager)


def _inductor_decomp_table():
    """Return ``select_decomp_table()`` from ``torch._inductor.compile_fx``.

    The module is imported inside the function so that inductor is only
    imported when this callable is actually invoked.
    """
    compile_fx = import_module("torch._inductor.compile_fx")
    return compile_fx.select_decomp_table()


# Same decompositions and partitioner as TorchInductor's AOT Autograd setup,
# but with `nop` in place of the inductor compiler for both the forward and
# backward graphs. Helps isolate inductor errors from aot_eager errors.
aot_eager_decomp_partition = aot_autograd(
    # NB: a callable (not the table itself) is passed to delay import of inductor
    decompositions=_inductor_decomp_table,
    partition_fn=functools.partial(
        min_cut_rematerialization_partition,
        compiler="inductor",
    ),
    # these are taken from memory_efficient_fusion()
    fw_compiler=nop,
    bw_compiler=nop,
)
register_backend(
    compiler_fn=aot_eager_decomp_partition,
    name="aot_eager_decomp_partition",
)

# AOT Autograd with the TorchScript compiler (`ts_compile`) as fw_compiler and
# the default partitioner (no partition_fn is passed).
# Can be used with both nnc and nvfuser by selecting the relevant fuser with
# torch.jit.fuser(...).
aot_ts = aot_autograd(fw_compiler=ts_compile)
# Make the backend available under the name "aot_ts".
register_backend(name="aot_ts", compiler_fn=aot_ts)