import collections
import dataclasses
import dis
import functools
import importlib
import inspect
import itertools
import logging
import operator
import sys
import traceback
import types
import typing
import weakref
from collections.abc import Sized
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from unittest.mock import patch
import torch
from torch._guards import Checkpointable
from . import (
allowed_functions,
config,
exc,
logging as torchdynamo_logging,
side_effects,
skipfiles,
variables,
)
from .allowed_functions import is_allowed, is_builtin_callable, is_builtin_constant
from .bytecode_analysis import JUMP_OPNAMES, livevars_analysis
from .bytecode_transformation import (
cleaned_instructions,
create_instruction,
create_jump_absolute,
Instruction,
is_generator,
unique_id,
)
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented, Unsupported
from .guards import GuardBuilder
from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
from .replay_record import DummyModule, ExecutionRecorder
from .resume_execution import ContinueExecutionCache, ReenterWith
from .source import (
AttrSource,
GetItemSource,
GlobalSource,
GlobalWeakRefSource,
LocalInputSource,
LocalSource,
)
from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs
from .variables.base import MutableLocal, typestr, VariableTracker
from .variables.builder import VariableBuilder, wrap_fx_proxy
from .variables.builtin import BuiltinVariable
from .variables.constant import ConstantVariable, EnumVariable
from .variables.dicts import ConstDictVariable
from .variables.functions import (
BaseUserFunctionVariable,
NestedUserFunctionVariable,
UserFunctionVariable,
UserMethodVariable,
)
from .variables.lists import (
BaseListVariable,
ListIteratorVariable,
ListVariable,
SliceVariable,
TupleVariable,
)
from .variables.misc import (
ClosureVariable,
ContextWrappingVariable,
GetAttrVariable,
GradModeVariable,
PythonModuleVariable,
UnknownVariable,
WithExitFunctionVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
supported_const_comparison_ops,
supported_tensor_comparison_ops,
SymNodeVariable,
TensorVariable,
)
from .variables.torch import TorchVariable
from .variables.user_defined import UserDefinedObjectVariable, UserDefinedVariable
log = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def _step_logger():
    """Return the dynamo step logger built around this module's ``log``.

    The result is memoized, so every caller receives the same logger object.
    """
    make_step_logger = torchdynamo_logging.get_step_logger
    return make_step_logger(log)
@dataclasses.dataclass
class BlockStackEntry:
    """One entry of the translator's block stack.

    Fields:
        target: the instruction associated with this block.
        stack_index: optional index used when building the ``ReenterWith``
            resume helper; must be set before ``resume_fn`` is called.
        with_context: the context-wrapping variable for this block, or None
            when the block carries no context manager.
    """

    target: Instruction
    stack_index: Optional[int] = None
    # Defaults to None, so the annotation must be Optional.
    with_context: Optional[ContextWrappingVariable] = None

    def can_restore(self):
        """Return True if a context-wrapping variable is attached to this entry."""
        return self.with_context is not None

    def resume_fn(self):
        """Build the ``ReenterWith`` helper for this entry.

        The context's ``target_values`` are forwarded when present.

        Raises:
            AssertionError: if ``stack_index`` has not been set.
        """
        assert self.stack_index is not None
        if self.with_context and self.with_context.target_values:
            return ReenterWith(self.stack_index, tuple(self.with_context.target_values))
        else:
            return ReenterWith(self.stack_index)

    def exit(self, tx):
        """Delegate to the attached context's ``exit`` with translator ``tx``.

        Only valid when ``can_restore()`` is True (``with_context`` is set).
        """
        return self.with_context.exit(tx)
class InstructionTranslatorGraphState(NamedTuple):
    """Immutable snapshot of an instruction translator's state fields."""

    output: OutputGraphState
    symbolic_locals: Dict[str, VariableTracker]
    stack: List[VariableTracker]
    block_stack: List[BlockStackEntry]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
    next_instruction: Optional[Instruction]
    lineno: int

    def diff(self, other: "InstructionTranslatorGraphState") -> Optional[str]:
        """Describe the first difference between ``self`` and ``other``.

        Fields are checked in declaration order. ``output`` is compared through
        its own ``diff`` method (messages are prefixed with ``"output."``); every
        other field is compared with ``!=``.

        Returns:
            A human-readable description of the first mismatch, or None when
            all fields match.
        """
        for k in self._fields:
            if k == "output":
                # Do not return unconditionally here: a matching output must
                # still let the remaining fields be compared.
                output_diff = self.output.diff(other.output, prefix=f"{k}.")
                if output_diff is not None:
                    return output_diff
                continue
            sv = getattr(self, k)
            ov = getattr(other, k)
            if sv != ov:
                return f"{k} mismatch: {sv} != {ov}"
        return None
def stack_op(fn: typing.Callable[..., object]):
    """Build an opcode handler that applies ``fn`` to the top of the stack.

    The handler pops as many values as ``fn`` declares parameters, calls
    ``fn`` through a ``BuiltinVariable`` wrapper, and pushes the result.
    """
    arity = len(inspect.signature(fn).parameters)
    builtin = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslatorBase", inst: Instruction):
        operands = self.popn(arity)
        result = builtin.call_function(self, operands, {})
        self.push(result)

    return impl
def _detect_and_normalize_assert_statement(
    self: "InstructionTranslatorBase",
    truth_fn: typing.Callable[[object], bool],
    push: bool,
):
    """Detect whether the conditional jump being handled guards an ``assert``.

    ``self.instruction_pointer`` is expected to point at the instruction that
    follows the jump. If the upcoming instructions match the bytecode emitted
    for ``assert cond`` or ``assert cond, "literal message"``, the assertion
    message is pushed onto the stack and True is returned. When no message is
    given, a dummy ``"assertion error"`` constant is pushed instead. When the
    pattern does not match, the stack is left untouched and False is returned.

    Python 3.9+ assertion bytecode:
        18 POP_JUMP_IF_TRUE 28
        20 LOAD_ASSERTION_ERROR
        22 LOAD_CONST 3 ('Assert message') -> optional instruction
        24 CALL_FUNCTION 1 -> optional instruction
        26 RAISE_VARARGS 1

    Python 3.8 assertion bytecode is the same, except that the exception type
    is loaded with ``LOAD_GLOBAL AssertionError``.

    Args:
        truth_fn: predicate used by the jump; only ``operator.truth`` is
            treated as an assert guard.
        push: whether the jump pushes the tested value back when jumping;
            such jumps are never treated as asserts.

    Returns:
        True if an assert pattern was recognized (and a message pushed),
        otherwise False.
    """
    if (truth_fn is not operator.truth) or push:
        return False
    assert isinstance(self.instruction_pointer, int)
    current_instruction_pointer = self.instruction_pointer
    if current_instruction_pointer >= len(self.instructions):
        return False
    inst = self.instructions[current_instruction_pointer]

    # Detect LOAD_ASSERTION_ERROR (3.9+) or LOAD_GLOBAL AssertionError (3.8)
    if sys.version_info < (3, 9):
        if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
            return False
    else:
        if inst.opname != "LOAD_ASSERTION_ERROR":
            return False

    current_instruction_pointer += 1
    if current_instruction_pointer >= len(self.instructions):
        return False
    inst = self.instructions[current_instruction_pointer]

    # Optional message: LOAD_CONST <str>, which must be followed by
    # CALL_FUNCTION. The LOAD_CONST is only remembered here and executed after
    # the whole pattern has matched. Otherwise a partial match (e.g. on Python
    # versions without CALL_FUNCTION) would leave a stray value on the stack.
    msg_inst = None
    if inst.opname == "LOAD_CONST":
        if not isinstance(inst.argval, str):
            return False
        msg_inst = inst

        current_instruction_pointer += 1
        if current_instruction_pointer >= len(self.instructions):
            return False
        inst = self.instructions[current_instruction_pointer]
        if inst.opname != "CALL_FUNCTION":
            return False

        # CALL_FUNCTION should be followed by RAISE_VARARGS
        current_instruction_pointer += 1
        if current_instruction_pointer >= len(self.instructions):
            return False
        inst = self.instructions[current_instruction_pointer]

    if inst.opname != "RAISE_VARARGS":
        return False

    if msg_inst is not None:
        self.LOAD_CONST(msg_inst)
    else:
        # Push dummy value instead of error message
        self.push(ConstantVariable("assertion error"))
    return True
def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
    """Build a handler for a conditional jump opcode.

    The returned handler pops the tested value and decides statically whether
    the jump to ``inst.target`` is taken, which happens when ``truth_fn``
    applied to the value is truthy. If the value is a tensor and partial-graph
    compilation is allowed, it instead breaks the graph and emits resume code
    for both branches. Asserts may be rewritten into ``torch._assert`` calls
    when ``config.rewrite_assert_with_torch_assert`` is set.

    Args:
        truth_fn: predicate deciding whether the jump is taken.
        push: when True, the tested value is pushed back onto the stack on the
            jump-taken path.
    """

    def inner(self: "InstructionTranslatorBase", inst: Instruction):
        value: VariableTracker = self.pop()
        self.output.guards.update(value.guards)
        if (
            config.rewrite_assert_with_torch_assert
            and _detect_and_normalize_assert_statement(self, truth_fn, push)
        ):
            error_msg: VariableTracker = self.pop()
            self.output.guards.update(error_msg.guards)
            # Skip over things like `assert True`
            if value.is_python_constant() and bool(value.as_python_constant()):
                self.jump(inst)
                return

            # Manually insert torch._assert instead of python assert and jump over
            # assert related instructions as we don't need them anymore.
            self.output.create_proxy(
                "call_function",
                torch._assert,
                *proxy_args_kwargs((value, error_msg), {}),
            )
            self.jump(inst)
            return

        if value.is_python_constant():
            if truth_fn(value.as_python_constant()):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, TensorVariable) and self.should_compile_partial_graph():
            # compile a partial subgraph prefix then jump into user code
            if self.has_backedge():
                msg = (
                    "Skipping frame because there is a graph break in a for/while loop"
                )
                log.debug(msg)
                raise exc.SkipFrame(msg)

            self.push(value)
            log.debug("generic_jump triggered compile")
            self.output.compile_subgraph(
                self,
                reason=GraphCompileReason(
                    f"generic_jump {typestr(value)}", [self.frame_summary()]
                ),
            )
            self.pop()

            # Order matters: the fall-through resume is created before the
            # value is re-pushed, the jump-target resume after it.
            if_next = self.create_call_resume_at(self.next_instruction)
            if push:
                self.push(value)
            if_jump = self.create_call_resume_at(inst.target)

            self.output.add_output_instructions(
                [(create_instruction(inst.opname, target=if_jump[0]))]
                + if_next
                + if_jump
            )
        elif isinstance(value, NNModuleVariable):
            # Equivalent of "self.nn_module is not None"
            if truth_fn(value):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, UserDefinedObjectVariable):
            bool_attr = value.var_getattr(self, "__bool__")
            # __bool__ is a user method: inline it and branch on its result
            if isinstance(bool_attr, UserMethodVariable):
                state = self.copy_graphstate()
                result = bool_attr.call_function(self, [], {})
                if isinstance(result, ConstantVariable) and isinstance(
                    result.value, bool
                ):
                    self.output.guards.update(result.guards)
                    if truth_fn(result.value):
                        if push:
                            self.push(value)
                        self.jump(inst)
                else:
                    # rollback to the state before the __bool__ inline
                    self.restore_graphstate(state)
                    unimplemented(
                        "generic_jump on UserDefined with __bool__ returning non-constant"
                    )
            # __bool__ is not a user method (or does not exist): treat as truthy
            else:
                if truth_fn(True):
                    if push:
                        self.push(value)
                    self.jump(inst)
        elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
            self
        ):
            if truth_fn(len(value.unpack_var_sequence(self))):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, SymNodeVariable):
            eval_result = value.evaluate_expr(self.output)
            if truth_fn(eval_result):
                if push:
                    self.push(value)
                self.jump(inst)
        else:
            unimplemented(f"generic_jump {typestr(value)}")

    return inner
# Module-level flag, False by default. Its readers are not visible in this
# section. NOTE(review): presumably consulted by the graph-break handling
# below (an "explain" mode); confirm before relying on it.
explain = False
def break_graph_if_unsupported(*, push):
def decorator(inner_fn):
@functools.wraps(inner_fn)
def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
state = self.copy_graphstate()
reason = None
try:
return inner_fn(self, inst)
except Unsupported as excp:
if self.has_backedge() and self.should_compile_partial_graph():
msg = "Skipping frame because there is a graph break in a for/while loop"
Loading ...