# mypy: allow-untyped-defs
"""
Utils for caching the outputs of AOTAutograd
"""
from __future__ import annotations

import functools
import logging
import os
from typing import TYPE_CHECKING

import torch
from torch._functorch import config
from torch._inductor.codecache import (
    _ident,
    BypassFxGraphCache,
    FxGraphCachePickler,
    FxGraphHashDetails,
    get_code_hash,
)

from .schemas import AOTConfig  # noqa: F401

if TYPE_CHECKING:
    from torch.fx.node import Node

log = logging.getLogger(__name__)


class BypassAOTAutogradCache(Exception):
    pass


def check_node_safe(node: Node):
    """
    Checks that the node only uses supported operators. We are starting with very
    conservative cacheability constraints, and incrementally adding more support as we expand.
    """

    def is_torch_function(target):
        is_builtin_fun_or_type = type(target).__name__ == "builtin_function_or_method"
        # TODO: handle torch.nn.functional and other non-inlined targets, which don't compile down to a builtin
        return is_builtin_fun_or_type

    def is_tensor(target):
        # Tensors always have example values in the meta field
        return "example_value" in target.meta

    # I'd love to use a match statement here, but it wasn't introduced until py3.10
    if node.op == "call_function":
        # We support only torch.* functions for now
        # We can probably add an allowlist of safe non-torch implementations as well
        if not is_torch_function(node.target):
            raise BypassAOTAutogradCache(
                f"Unsupported call_function target {node.target}"
            )
    elif node.op == "call_method":
        method_name = node.target
        method_target = node.args[0]
        # Only support method calls on base tensors
        if not is_tensor(method_target):
            raise BypassAOTAutogradCache(
                f"Unsupported call_method target {method_target}"
            )
        if (
            type(method_name) != str
            and type(method_name).__name__ != "method_descriptor"
        ):
            raise BypassAOTAutogradCache(
                f"Unsupported call_method method {node.target}: {method_name}"
            )
    # Cache safe
    elif node.op in ("placeholder", "get_attr", "call_module", "output"):
        # Assumptions for call_module being a safe op today:
        # (1) The only call_module ops that can show up in a graph come from built-in nn modules
        # that dynamo assumes are safe to trace. If dynamo assumes they are safe to blindly trace,
        # then they should be safe to cache as well.
        # (2) In the steady state (some time in H2?) we shouldn't see these anymore, once we
        # inline builtin nn modules by default.
        # (3) We do not allow user-made nn modules in the graph today, only function calls.
        pass
    else:
        raise BypassAOTAutogradCache(f"Unsupported node op {node.op}")


@functools.lru_cache(None)
def get_autograd_code_hash():
    autograd_root = os.path.dirname(__file__)
    return get_code_hash([autograd_root])


def check_cacheable(gm: torch.fx.GraphModule):
    """
    Checks that the graph module only uses supported operators
    """
    nodes = gm.graph.nodes
    if torch._dynamo.compiled_autograd.compiled_autograd_enabled_count:
        raise BypassAOTAutogradCache(
            "Cannot cache a graph with compiled autograd enabled"
        )
    for node in nodes:
        check_node_safe(node)
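
# Minimal usage sketch (hypothetical `gm`): callers treat BypassAOTAutogradCache
# as "skip the cache for this graph", not as a hard error.
#
#   try:
#       check_cacheable(gm)
#   except BypassAOTAutogradCache as e:
#       log.debug("not caching: %s", e)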


class AOTAutogradCacheDetails(FxGraphHashDetails):
    """
    Object to capture all the details for a dynamo graph module relevant to computing
    a safe and stable cache key for AOTAutograd.
    """

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs,
        aot_config: AOTConfig,
    ):
        check_cacheable(gm)
        # FxGraphHashDetails contains all the keys related to inductor. Also includes some system info
        self.aot_config = aot_config
        self.grad_enabled = torch.is_grad_enabled()
        self.disable_amp = torch._C._is_any_autocast_enabled()
        self.deterministic_algorithms = torch.are_deterministic_algorithms_enabled()
        self.code_hash = get_autograd_code_hash()
        self.autograd_config = config.save_config()
        try:
            super().__init__(gm, example_inputs, {}, [])
        except BypassFxGraphCache as e:
            # Sometimes inductor configs are unpickleable and can fail
            raise BypassAOTAutogradCache from e

    def debug_str(self) -> str:
        return AOTAutogradCachePickler.debug_str(self)


def _reduce_aot_config(aot_config: AOTConfig):
    """
    Reduce the config to a stable key for caching.
    """
    return (
        _ident,
        (
            aot_config.num_params_buffers,
            aot_config.keep_inference_input_mutations,
            aot_config.is_export,
            aot_config.no_tangents,
            aot_config.dynamic_shapes,
            aot_config.aot_autograd_arg_pos_to_source,
            aot_config.enable_log,
            aot_config.pre_dispatch,
        ),
    )


class AOTAutogradCachePickler(FxGraphCachePickler):
    dispatch_table = FxGraphCachePickler.dispatch_table.copy()
    dispatch_table[AOTConfig] = _reduce_aot_config
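
# A sketch of the stdlib pickle mechanism this relies on (not code from this
# module): a pickle.Pickler consults its dispatch_table to find a reducer for a
# type, so AOTConfig instances are serialized as _ident applied to the tuple
# from _reduce_aot_config, and only those listed fields reach the hash.
# `aot_config` below is assumed to be an existing AOTConfig.
#
#   import io
#   import pickle
#
#   class StablePickler(pickle.Pickler):
#       dispatch_table = {AOTConfig: _reduce_aot_config}
#
#   buf = io.BytesIO()
#   StablePickler(buf).dump(aot_config)
#   stable_bytes = buf.getvalue()  # depends only on the reduced fields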


def autograd_cache_hash(
    gm: torch.fx.GraphModule,
    example_inputs,
    config: AOTConfig,
    # TODO: add args and parameters
) -> str:
    """
    Generate a unique hash of the FX graph for caching.
    """
    details = AOTAutogradCacheDetails(gm, example_inputs, config)
    # The prefix distinguishes among the other kinds of objects we cache
    key = "a" + AOTAutogradCachePickler.get_hash(details)
    log.debug("FX graph cache hash details for key %s:\n%s", key, details.debug_str())
    return key
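
# End-to-end sketch (hypothetical `gm`, `example_inputs`, `aot_config`): the key
# also folds in global state captured by AOTAutogradCacheDetails, so toggling
# grad mode, for example, should produce a different hash.
#
#   key = autograd_cache_hash(gm, example_inputs, aot_config)
#   with torch.no_grad():
#       other = autograd_cache_hash(gm, example_inputs, aot_config)
#   assert key != other  # grad_enabled is part of the cache key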