# _dynamo/output_graph.py

import collections
import copy
import functools
import itertools
import logging
import operator
import re
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union

import torch.nn
from torch import fx
from torch._guards import (
    Checkpointable,
    Guard,
    GuardsCheckpointState,
    tracing,
    TracingContext,
)
from torch.fx.experimental.symbolic_shapes import ShapeEnv

from . import config, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import create_instruction, Instruction, unique_id
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented
from .guards import GuardBuilder
from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import (
    ConstantSource,
    is_constant_source,
    LocalInputSource,
    LocalSource,
    ShapeEnvSource,
)
from .utils import (
    assert_no_fake_params_or_buffers,
    checkpoint_params,
    CleanupHook,
    clone_inputs,
    count_calls,
    counters,
    dynamo_timed,
    format_graph_tabular,
    same,
)
from .variables.base import VariableTracker
from .variables.builder import GraphArg, TrackedFake, VariableBuilder, wrap_fx_proxy
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)

log = logging.getLogger(__name__)


class OutputGraphState(NamedTuple):
    graphargs: List[GraphArg]
    tracked_fakes: List[TrackedFake]
    guard_state: GuardsCheckpointState
    nn_modules: Optional[Dict[str, torch.nn.Module]]
    side_effects: SideEffects
    timestamp: int

    def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]:
        for k in self._fields:
            if k == "guard_state":
                r = self.guard_state.diff(other.guard_state)
                if r is not None:
                    return r
                continue
            elif k == "side_effects":
                r = self.side_effects.diff(other.side_effects)
                if r is not None:
                    return r
                continue

            sv = getattr(self, k)
            ov = getattr(other, k)
            if sv != ov:
                return f"{prefix}{k} mismatch: {sv} != {ov}"
        return None

    # Backwards-compatible .guards API
    @property
    def guards(self):
        return self.guard_state.dynamo_guards
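
    # Illustrative (hypothetical `output_graph` variable, not in this file):
    # `diff` returns a message for the first mismatching field between two
    # checkpoints, or None if they agree:
    #   s0 = output_graph.copy_graphstate()   # timestamp 0
    #   s1 = output_graph.copy_graphstate()   # timestamp 1
    #   s0.diff(s1, prefix="speculate: ")     # -> "speculate: timestamp mismatch: 0 != 1"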


@functools.lru_cache(None)
def _step_logger():
    return torchdynamo_logging.get_step_logger(log)


@dataclass
class GraphCompileReason:
    """Stores why a given output graph was compiled; i.e. what caused the graph break."""

    reason: str
    user_stack: List[traceback.FrameSummary]


def _get_gen_rand_values_fn(random_calls):
    def _gen_rand_values():
        return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]

    return _gen_rand_values
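
# Usage sketch (hypothetical values): each entry of `random_calls` is a
# (fn, args, kwargs) triple recorded during tracing, so the returned closure
# replays the calls in order on every invocation:
#   import random
#   gen = _get_gen_rand_values_fn([(random.random, (), {}), (random.uniform, (0, 1), {})])
#   gen()  # -> [<float>, <float>], freshly drawn each time gen() is called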


class FakeRootModule(torch.nn.Module):
    """Trick the constructor of fx.GraphModule"""

    def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
        super().__init__()
        for k, v in nn_modules.items():
            setattr(self, k, v)

    def __repr__(self):
        return "FakeRootModule(...)"


class WrapperBackend:
    def __init__(self, backend: CompilerFn, original_example_inputs):
        self.backend: CompilerFn = backend
        self.original_example_inputs = original_example_inputs

    @property
    def example_inputs(self):
        return clone_inputs(self.original_example_inputs)

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.restore = checkpoint_params(gm)
        self.gm = gm
        copy_gm = copy.deepcopy(self.gm)
        self.candidate = self.backend(copy_gm, self.original_example_inputs)

        if self.candidate is None or self.candidate is self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        # if verify_correctness=True
        try:
            correct = self.gm.forward(*self.example_inputs)
            result = self.candidate(*self.example_inputs)

            # TODO: replace `same` function with the one in testing
            if same(correct, result):
                return self.candidate

            raise RuntimeError(f"incorrect results of backend {self}")

        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            self.restore()
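
# Usage sketch (hypothetical backend and inputs): WrapperBackend wraps a
# compiler fn so that, when config.verify_correctness is set, the compiled
# candidate is checked against eager forward() on cloned example inputs:
#   wrapped = WrapperBackend(my_backend, example_inputs)
#   compiled_fn = wrapped(gm, example_inputs)  # raises if results diverge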


class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
    """
    Wrapper class to hold outputs of InstructionTranslator.  Mainly the
    generated fx.Graph.
    """

    def __init__(
        self,
        f_globals: Dict[str, Any],
        code_options: Dict[str, Any],
        compiler_fn: CompilerFn,
        root_tx,
    ):
        super().__init__()
        self.graph = torch.fx.Graph()
        self.graphargs: List[GraphArg] = []
        fake_mode = torch._subclasses.FakeTensorMode(
            shape_env=ShapeEnv() if config.dynamic_shapes else None,
        )
        self.tracing_context: TracingContext = TracingContext(fake_mode)
        if config.dynamic_shapes:
            # Register a SHAPE_ENV guard to make sure we set up shape guards
            # that show up in ShapeEnv
            self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

        # tracked_fakes says where any tensor that was wrapped to fake came
        # from.  It is similar to GraphArg, in that all GraphArgs will get
        # added to TrackedFakes, but TrackedFakes also contains GraphArgs
        # that got pruned, and things like Tensor attributes which aren't
        # explicit graph inputs.  Used by shape guard creation.
        self.tracked_fakes: List[TrackedFake] = []
        # Although we prune unused graphargs before sending graphs to
        # compilers, we may have legitimately triggered shape guards
        # on "unused" inputs that we must keep track of.  So after
        # remove_unused_graphargs is called, orig_graphargs and
        # graphargs no longer alias; orig_graphargs is the original
        # graphargs, and graphargs is the pruned list.  Guard creation
        # should use original graphargs.
        self.orig_graphargs: List[GraphArg] = self.graphargs
        self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
        self.side_effects = SideEffects()
        self.code_options = dict(code_options)
        self.output_instructions: List[Instruction] = []
        # used to track nodes that are added between calls of copy_graphstate
        # and restore_graphstate
        self.timestamp = 0
        # Node => computed real value (see utils.get_real_value)
        self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}

        # Not checkpointed
        self.compiler_fn: CompilerFn = compiler_fn
        self.root_globals = f_globals
        self.root_tx = root_tx
        from torch._dynamo.symbolic_convert import InstructionTranslatorBase

        self._current_tx: List[InstructionTranslatorBase] = []
        self.cleanups: List[CleanupHook] = []
        self.should_exit = False
        self.random_values_var = None
        self.initial_random_state = ()
        self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}
        # Maps the source arg position to the grapharg position
        self.pos_to_arg: Dict[int, int] = {}

        # Enables creating unique node names by tracking
        # all current placeholder node names
        self.name_to_input: OrderedDict[
            str, Optional[fx.Proxy]
        ] = collections.OrderedDict()

    @property
    def output(self):
        return self

    @property
    def fake_mode(self):
        return self.root_tx.fake_mode

    @property
    def shape_env(self):
        return self.tracing_context.fake_mode.shape_env

    @property
    def guards(self) -> Set[Guard]:
        return self.tracing_context.guards_context.dynamo_guards

    def push_tx(self, tx):
        self._current_tx.append(tx)

    def pop_tx(self):
        return self._current_tx.pop()

    @property
    def current_tx(self):
        return self.root_tx if not self._current_tx else self._current_tx[-1]

    def copy_graphstate(self) -> OutputGraphState:
        """Create a checkpoint of the current state by copying everything"""
        assert self.nn_modules is not None
        guards_graph_state = self.tracing_context.guards_context.copy_graphstate()
        state = OutputGraphState(
            list(self.graphargs),
            list(self.tracked_fakes),
            guards_graph_state,
            dict(self.nn_modules),
            self.side_effects.clone(),
            self.timestamp,
        )
        self.timestamp += 1
        return state

    def restore_graphstate(self, state: OutputGraphState):
        """Restore a checkpoint created by self.copy_graphstate()"""
        (
            self.graphargs,
            self.tracked_fakes,
            guards_state,
            self.nn_modules,
            self.side_effects,
            self.timestamp,
        ) = state
        self.tracing_context.guards_context.restore_graphstate(guards_state)
        # FX deepcopy doesn't work for a partially created graph, so just remove new nodes
        removed_nodes = 0
        for node in reversed(list(self.graph.nodes)):
            if node.meta["creation_timestamp"] > self.timestamp:
                # Erasing the node alone does not remove its meta information,
                # so drop the cached example tensor explicitly
                if "example_value" in node.meta:
                    del node.meta["example_value"]
                self.remove_node(node)
                self.real_value_cache.pop(node, None)
                removed_nodes += 1
        log.debug("restore_graphstate: removed %s nodes", removed_nodes)
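
    # Round-trip sketch (illustrative): callers take a checkpoint before
    # speculative tracing and roll back on failure, e.g.
    #   state = self.copy_graphstate()
    #   try:
    #       ...  # speculatively trace
    #   except SomeUnsupportedError:  # hypothetical exception type
    #       self.restore_graphstate(state)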

    def add_grapharg(self, arg: GraphArg):
        curr_pos = len(self.graphargs)
        self.graphargs.append(arg)
        if isinstance(arg.source, LocalInputSource):
            self.pos_to_arg[arg.source.pos] = curr_pos
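
    # Illustrative: tracing `def f(a, b)` where only `b` becomes a graph input
    # records pos_to_arg[1] == 0, i.e. source argument position 1 maps to
    # grapharg position 0.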

    def count_calls(self):
        return count_calls(self.graph)

    def get_submodule(self, keys):
        assert keys
        obj = self.nn_modules
        for k in keys.split("."):
            if isinstance(obj, dict):
                obj = obj[k]
            else:
                obj = getattr(obj, k)
        return obj
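
    # Example (hypothetical keys): for "layer1.conv", the first key indexes the
    # nn_modules dict (obj["layer1"]); once obj is an nn.Module rather than a
    # dict, remaining keys fall through to getattr(obj, "conv").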

    def create_graph_input(self, name, type_expr=None):
        # de-duplicate the requested name against existing placeholders
        if name in self.name_to_input:
            for i in itertools.count():
                if f"{name}_{i}" not in self.name_to_input:
                    name = f"{name}_{i}"
                    break

        if self.name_to_input:
            prev_name = next(reversed(self.name_to_input))
            ctx = self.graph.inserting_after(self.name_to_input[prev_name])
        else:
            ctx = self.graph.inserting_before(None)
        with ctx:
            proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
            self.name_to_input[name] = proxy.node
            return proxy
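
    # Naming sketch: requesting "x" three times yields placeholders named
    # "x", "x_0", "x_1"; each new placeholder is inserted after the most
    # recently created one so graph inputs stay grouped at the front.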

    def new_var(self, name="tmp"):
        existing = set(self.code_options["co_varnames"])
        for i in itertools.count():
            var = f"___{name}_{i}"
            if var not in existing:
                self.code_options["co_varnames"] = self.code_options["co_varnames"] + (
                    var,
                )
                return var