import collections
import copy
import functools
import itertools
import logging
import operator
import re
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union
import torch.nn
from torch import fx
from torch._guards import (
Checkpointable,
Guard,
GuardsCheckpointState,
tracing,
TracingContext,
)
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from . import config, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import create_instruction, Instruction, unique_id
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented
from .guards import GuardBuilder
from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import (
ConstantSource,
is_constant_source,
LocalInputSource,
LocalSource,
ShapeEnvSource,
)
from .utils import (
assert_no_fake_params_or_buffers,
checkpoint_params,
CleanupHook,
clone_inputs,
count_calls,
counters,
dynamo_timed,
format_graph_tabular,
same,
)
from .variables.base import VariableTracker
from .variables.builder import GraphArg, TrackedFake, VariableBuilder, wrap_fx_proxy
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
SymNodeVariable,
TensorVariable,
UnspecializedPythonVariable,
)
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
class OutputGraphState(NamedTuple):
    """Checkpointed portion of an OutputGraph (see ``OutputGraph.copy_graphstate``)."""

    graphargs: List[GraphArg]
    tracked_fakes: List[TrackedFake]
    guard_state: GuardsCheckpointState
    nn_modules: Optional[Dict[str, torch.nn.Module]]
    side_effects: SideEffects
    timestamp: int

    def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]:
        """Describe the first field that differs between ``self`` and ``other``.

        ``guard_state`` and ``side_effects`` are compared through their own
        ``diff`` methods, whose result is returned as-is (``prefix`` is not
        applied to it). Every other field is compared with ``!=``, and a
        mismatch is reported as ``"<prefix><field> mismatch: <a> != <b>"``.

        Returns:
            The description of the first difference, or ``None`` if the two
            states match field for field.
        """
        delegated_fields = ("guard_state", "side_effects")
        for k in self._fields:
            if k in delegated_fields:
                nested = getattr(self, k).diff(getattr(other, k))
                if nested is not None:
                    return nested
                continue
            sv, ov = getattr(self, k), getattr(other, k)
            if sv != ov:
                return f"{prefix}{k} mismatch: {sv} != {ov}"
        return None

    # Back compat .guards api: expose the dynamo guards held in guard_state.
    @property
    def guards(self):
        return self.guard_state.dynamo_guards
@functools.lru_cache(maxsize=None)
def _step_logger():
    """Return the dynamo step logger for this module's ``log``.

    The result is memoized, so every call after the first returns the same
    object.
    """
    step_logger = torchdynamo_logging.get_step_logger(log)
    return step_logger
@dataclass
class GraphCompileReason:
    """Explains why an output graph was compiled, i.e. what caused the graph break.

    Holds a human-readable ``reason`` together with the user-code stack
    frames (``user_stack``) associated with it.
    """

    # Human-readable description of what triggered compilation.
    reason: str
    # User-code frames associated with the reason above.
    user_stack: List[traceback.FrameSummary]
def _get_gen_rand_values_fn(random_calls):
def _gen_rand_values():
return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]
return _gen_rand_values
class FakeRootModule(torch.nn.Module):
    """Trick the constructor of fx.GraphModule

    Turns a flat ``name -> module`` mapping into an ``nn.Module`` whose
    attributes are those modules, so it can be handed to ``fx.GraphModule``
    as the root that owns the referenced submodules.
    """

    def __repr__(self):
        # Keep repr short; the wrapped modules can be arbitrarily large.
        return "FakeRootModule(...)"

    def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
        super().__init__()
        for attr_name, submodule in nn_modules.items():
            setattr(self, attr_name, submodule)
class WrapperBackend:
    """Wrap a backend compiler, optionally checking its output against eager FX.

    When ``config.verify_correctness`` is set, the compiled candidate and the
    original ``GraphModule.forward`` are both run on cloned example inputs and
    their results compared with ``same``.
    """

    def __init__(self, backend: CompilerFn, original_example_inputs):
        self.backend: CompilerFn = backend
        self.original_example_inputs = original_example_inputs

    @property
    def example_inputs(self):
        # A fresh clone on every access, so the reference run and the
        # candidate run each get their own copy of the inputs.
        return clone_inputs(self.original_example_inputs)

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        """Compile ``gm`` with the wrapped backend and return the callable to use.

        Args:
            gm: graph module to compile. The backend receives a deep copy.
            example_inputs: not used here. The backend is invoked with
                ``self.original_example_inputs`` instead.

        Returns:
            ``gm.forward`` if the backend declined (returned ``None`` or
            ``gm.forward``); otherwise the backend's compiled callable.

        Raises:
            RuntimeError: if verification is enabled and the candidate's
                results differ from ``gm.forward``'s. Any exception raised
                during verification is logged and re-raised.
        """
        self.restore = checkpoint_params(gm)
        self.gm = gm
        copy_gm = copy.deepcopy(self.gm)
        self.candidate = self.backend(copy_gm, self.original_example_inputs)

        # Accessing ``self.gm.forward`` builds a new bound-method object each
        # time, so an ``is`` check could never match. ``==`` compares bound
        # methods by their ``__self__`` and ``__func__``.
        if self.candidate is None or self.candidate == self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        # if verify_correctness=True
        try:
            correct = self.gm.forward(*self.example_inputs)
            result = self.candidate(*self.example_inputs)

            # TODO: replace `same` function with the one in testing
            if same(correct, result):
                return self.candidate

            raise RuntimeError(f"incorrect results of backend {self}")
        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            # Presumably undoes any parameter changes made by the eager
            # reference run of gm (see checkpoint_params).
            self.restore()
class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
"""
Wrapper class to hold outputs of InstructionTranslator. Mainly the
generated fx.Graph.
"""
def __init__(
    self,
    f_globals: Dict[str, Any],
    code_options: Dict[str, Any],
    compiler_fn: CompilerFn,
    root_tx,
):
    """Create an empty FX graph together with the tracing bookkeeping.

    Args:
        f_globals: globals of the root frame, stored as ``self.root_globals``.
        code_options: code-object options. They are copied, so later updates
            (e.g. ``new_var`` extending ``co_varnames``) do not touch the
            caller's dict.
        compiler_fn: backend compiler, stored as ``self.compiler_fn``.
        root_tx: root instruction translator. ``current_tx`` falls back to it
            and the ``fake_mode`` property reads from it.
    """
    super().__init__()
    self.graph = torch.fx.Graph()
    self.graphargs: List[GraphArg] = []
    # Fake tensor mode for this graph. It only gets a ShapeEnv when dynamic
    # shapes are enabled.
    fake_mode = torch._subclasses.FakeTensorMode(
        shape_env=ShapeEnv() if config.dynamic_shapes else None,
    )
    # Must be set before the ``self.guards`` access below: the ``guards``
    # property reads through ``self.tracing_context``.
    self.tracing_context: TracingContext = TracingContext(fake_mode)
    if config.dynamic_shapes:
        # Register a SHAPE_ENV guard to make sure we setup shape guards
        # that show up in ShapeEnv
        self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

    # tracked_fakes says where any tensor that was wrapped to fake came
    # from. It is similar to GraphArg, in that all GraphArgs will get
    # added to TrackedFakes, but TrackedFakes also contains
    # GraphArgs that got pruned, and things like Tensor attributes which
    # aren't explicit graph inputs. Used by shape guard
    self.tracked_fakes: List[TrackedFake] = []

    # Although we prune unused graphargs before sending graphs to
    # compilers, we may have legitimately triggered shape guards
    # on "unused" inputs that we must keep track of. So after
    # remove_unused_graphargs is called, orig_graphargs and
    # graphargs no longer alias; orig_graphargs is the original
    # graphargs, and graphargs is the pruned list. Guard creation
    # should use original graphargs.
    self.orig_graphargs: List[GraphArg] = self.graphargs
    self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
    self.side_effects = SideEffects()
    # Private copy of the caller's code options.
    self.code_options = dict(code_options)
    self.output_instructions: List[Instruction] = []
    # used to track nodes that are added between calls of copy_graphstate
    # and restore_graphstate
    self.timestamp = 0

    # Node => computed real value (see utils.get_real_value)
    self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}

    # Not checkpointed (these fields are not part of OutputGraphState)
    self.compiler_fn: CompilerFn = compiler_fn
    self.root_globals = f_globals
    self.root_tx = root_tx
    # Local import; presumably avoids a circular import with symbolic_convert.
    from torch._dynamo.symbolic_convert import InstructionTranslatorBase

    # Stack of active translators (see push_tx / pop_tx / current_tx).
    self._current_tx: List[InstructionTranslatorBase] = []
    self.cleanups: List[CleanupHook] = []
    self.should_exit = False
    self.random_values_var = None
    self.initial_random_state = ()
    self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}
    # Maps the source arg position to the grapharg position
    self.pos_to_arg: Dict[int, int] = {}

    # Enables creating unique node names by tracking
    # all current placeholder node names. Values are the placeholder
    # fx.Nodes: create_graph_input stores ``proxy.node`` and later passes it
    # to ``graph.inserting_after``.
    self.name_to_input: OrderedDict[
        str, Optional[fx.Node]
    ] = collections.OrderedDict()
@property
def output(self):
    """Return this OutputGraph itself, for code that reaches it via an ``.output`` attribute."""
    return self
@property
def fake_mode(self):
    """Fake tensor mode, read from the root translator (``root_tx.fake_mode``)."""
    root = self.root_tx
    return root.fake_mode
@property
def shape_env(self):
    """ShapeEnv of the tracing context's fake mode (``None`` unless dynamic shapes were enabled)."""
    tracing_fake_mode = self.tracing_context.fake_mode
    return tracing_fake_mode.shape_env
@property
def guards(self) -> Set[Guard]:
    """Dynamo guard set held by the tracing context's guards context."""
    guards_ctx = self.tracing_context.guards_context
    return guards_ctx.dynamo_guards
def push_tx(self, tx):
    """Push ``tx`` onto the translator stack, making it ``current_tx``."""
    stack = self._current_tx
    stack.append(tx)
def pop_tx(self):
    """Remove and return the most recently pushed translator (IndexError if none)."""
    popped = self._current_tx.pop()
    return popped
@property
def current_tx(self):
    """Top of the translator stack, or ``root_tx`` when the stack is empty."""
    if self._current_tx:
        return self._current_tx[-1]
    return self.root_tx
def copy_graphstate(self) -> OutputGraphState:
    """Create a checkpoint of the current state by copying everything

    Lists and the module dict are shallow-copied, side effects are cloned,
    and the guards context produces its own checkpoint. ``self.timestamp``
    is then bumped, so ``restore_graphstate`` can tell which graph nodes
    were created after this checkpoint by their ``creation_timestamp``.
    """
    assert self.nn_modules is not None
    guard_snapshot = self.tracing_context.guards_context.copy_graphstate()
    snapshot = OutputGraphState(
        graphargs=list(self.graphargs),
        tracked_fakes=list(self.tracked_fakes),
        guard_state=guard_snapshot,
        nn_modules=dict(self.nn_modules),
        side_effects=self.side_effects.clone(),
        timestamp=self.timestamp,
    )
    self.timestamp += 1
    return snapshot
def restore_graphstate(self, state: OutputGraphState):
    """Restore a checkpoint created by self.copy_graphstate()

    Rolls the checkpointed fields back to ``state``, restores the guards
    context, and erases every graph node whose ``creation_timestamp`` is
    newer than the restored timestamp. For each erased node it also drops
    the node's ``example_value`` meta and any cached real value.

    ``orig_graphargs`` aliases ``graphargs`` until unused args are pruned
    (see ``__init__``), and guard creation is meant to read
    ``orig_graphargs``. Restoring rebinds ``graphargs`` to the checkpoint's
    list, so if the two were aliased beforehand the alias is re-established.
    Otherwise, args added after the restore would never reach
    ``orig_graphargs``.
    """
    graphargs_aliased = self.orig_graphargs is self.graphargs
    (
        self.graphargs,
        self.tracked_fakes,
        guards_state,
        self.nn_modules,
        self.side_effects,
        self.timestamp,
    ) = state
    if graphargs_aliased:
        self.orig_graphargs = self.graphargs
    self.tracing_context.guards_context.restore_graphstate(guards_state)

    # FX deepcopy doesn't work for a partially created graph, so just remove new nodes.
    # Walk newest-first, presumably so users are erased before the nodes they consume.
    removed_nodes = 0
    for node in reversed(list(self.graph.nodes)):
        if node.meta["creation_timestamp"] > self.timestamp:
            # Erasing node alone does not remove the meta information
            # So, remove the help tensor explicitly
            if "example_value" in node.meta:
                del node.meta["example_value"]
            self.remove_node(node)
            self.real_value_cache.pop(node, None)
            removed_nodes += 1
    log.debug("restore_graphstate: removed %s nodes", removed_nodes)
def add_grapharg(self, arg: GraphArg):
    """Append ``arg`` to ``self.graphargs``.

    If ``arg`` comes from a ``LocalInputSource``, also record in
    ``self.pos_to_arg`` that the source's position maps to ``arg``'s index.
    """
    index = len(self.graphargs)
    self.graphargs.append(arg)
    source = arg.source
    if isinstance(source, LocalInputSource):
        self.pos_to_arg[source.pos] = index
def count_calls(self):
    """Return the module-level ``count_calls`` helper's result for ``self.graph``."""
    graph = self.graph
    return count_calls(graph)
def get_submodule(self, keys):
    """Resolve a dotted path such as ``"a.b.c"``, starting from ``self.nn_modules``.

    Each segment is looked up by key while the current object is a dict,
    and by attribute otherwise. ``keys`` must be non-empty.
    """
    assert keys
    target = self.nn_modules
    for part in keys.split("."):
        target = target[part] if isinstance(target, dict) else getattr(target, part)
    return target
def create_graph_input(self, name, type_expr=None):
    """Add a placeholder node called ``name`` (made unique) and return its proxy.

    If ``name`` is already a tracked placeholder, the first unused
    ``f"{name}_{i}"`` (i = 0, 1, ...) is chosen instead. The new placeholder
    is inserted right after the most recently tracked one, or at the very
    start of the graph if none exist yet, so placeholders stay together in
    creation order.
    """
    if name in self.name_to_input:
        base = name
        name = next(
            f"{base}_{i}"
            for i in itertools.count()
            if f"{base}_{i}" not in self.name_to_input
        )

    if self.name_to_input:
        last_placeholder = next(reversed(self.name_to_input.values()))
        insertion_point = self.graph.inserting_after(last_placeholder)
    else:
        insertion_point = self.graph.inserting_before(None)

    with insertion_point:
        proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
        self.name_to_input[name] = proxy.node
    return proxy
def new_var(self, name="tmp"):
existing = set(self.code_options["co_varnames"])
for i in itertools.count():
var = f"___{name}_{i}"
if var not in existing:
self.code_options["co_varnames"] = self.code_options["co_varnames"] + (
var,
Loading ...