edgify / torch (Python)
Version: 2.0.1+cpu
_dynamo/symbolic_convert.py

import collections
import dataclasses
import dis
import functools
import importlib
import inspect
import itertools
import logging
import operator
import sys
import traceback
import types
import typing
import weakref
from collections.abc import Sized
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from unittest.mock import patch

import torch
from torch._guards import Checkpointable

from . import (
    allowed_functions,
    config,
    exc,
    logging as torchdynamo_logging,
    side_effects,
    skipfiles,
    variables,
)
from .allowed_functions import is_allowed, is_builtin_callable, is_builtin_constant
from .bytecode_analysis import JUMP_OPNAMES, livevars_analysis
from .bytecode_transformation import (
    cleaned_instructions,
    create_instruction,
    create_jump_absolute,
    Instruction,
    is_generator,
    unique_id,
)
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented, Unsupported
from .guards import GuardBuilder
from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
from .replay_record import DummyModule, ExecutionRecorder
from .resume_execution import ContinueExecutionCache, ReenterWith
from .source import (
    AttrSource,
    GetItemSource,
    GlobalSource,
    GlobalWeakRefSource,
    LocalInputSource,
    LocalSource,
)
from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs
from .variables.base import MutableLocal, typestr, VariableTracker
from .variables.builder import VariableBuilder, wrap_fx_proxy
from .variables.builtin import BuiltinVariable
from .variables.constant import ConstantVariable, EnumVariable
from .variables.dicts import ConstDictVariable
from .variables.functions import (
    BaseUserFunctionVariable,
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
from .variables.lists import (
    BaseListVariable,
    ListIteratorVariable,
    ListVariable,
    SliceVariable,
    TupleVariable,
)
from .variables.misc import (
    ClosureVariable,
    ContextWrappingVariable,
    GetAttrVariable,
    GradModeVariable,
    PythonModuleVariable,
    UnknownVariable,
    WithExitFunctionVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    supported_const_comparison_ops,
    supported_tensor_comparison_ops,
    SymNodeVariable,
    TensorVariable,
)
from .variables.torch import TorchVariable
from .variables.user_defined import UserDefinedObjectVariable, UserDefinedVariable

log = logging.getLogger(__name__)


@functools.lru_cache(None)
def _step_logger():
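    """Return the step logger for this module (created once and cached via ``lru_cache``)."""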
    return torchdynamo_logging.get_step_logger(log)


@dataclasses.dataclass
class BlockStackEntry:
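    """One entry on the simulated Python block stack.

    ``with_context`` is set when the block comes from a ``with`` statement; in
    that case the block can be re-entered after a graph break via
    ``resume_fn()``, which builds a ``ReenterWith`` for ``stack_index``.
    """
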
    target: Instruction
    stack_index: Optional[int] = None
    with_context: Optional[ContextWrappingVariable] = None

    def can_restore(self):
        return self.with_context is not None

    def resume_fn(self):
        assert self.stack_index is not None
        if self.with_context and self.with_context.target_values:
            return ReenterWith(self.stack_index, tuple(self.with_context.target_values))
        else:
            return ReenterWith(self.stack_index)

    def exit(self, tx):
        return self.with_context.exit(tx)


class InstructionTranslatorGraphState(NamedTuple):
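    """Snapshot of a translator's mutable state (output graph, symbolic locals,
    stack, block stack, and instruction position), used to checkpoint and roll
    back execution via ``copy_graphstate``/``restore_graphstate``.
    """
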
    output: OutputGraphState
    symbolic_locals: Dict[str, VariableTracker]
    stack: List[VariableTracker]
    block_stack: List[BlockStackEntry]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
    next_instruction: Optional[Instruction]
    lineno: int

    def diff(self, other: "InstructionTranslatorGraphState") -> Optional[str]:
        for k in self._fields:
            if k == "output":
                # a mismatch in the nested output state short-circuits the diff
                out_diff = self.output.diff(other.output, prefix=f"{k}.")
                if out_diff is not None:
                    return out_diff
                continue
            sv = getattr(self, k)
            ov = getattr(other, k)
            if sv != ov:
                return f"{k} mismatch: {sv} != {ov}"
        return None


def stack_op(fn: typing.Callable[..., object]):
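    """Build an opcode handler that pops ``fn``'s arity worth of operands,
    applies ``fn`` as a ``BuiltinVariable`` call, and pushes the result
    (e.g. ``stack_op(operator.add)`` handles a binary add).
    """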
    nargs = len(inspect.signature(fn).parameters)
    fn_var = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslatorBase", inst: Instruction):
        self.push(fn_var.call_function(self, self.popn(nargs), {}))

    return impl


def _detect_and_normalize_assert_statement(
    self: "InstructionTranslatorBase",
    truth_fn: typing.Callable[[object], bool],
    push: bool,
):
    # Detect whether this jump instruction belongs to an assert statement, and
    # normalize the assert by pushing a dummy error message when none is given.
    #
    # A Python 3.9 assertion compiles to the following bytecode:
    # 18 POP_JUMP_IF_TRUE        28
    # 20 LOAD_ASSERTION_ERROR
    # 22 LOAD_CONST               3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION            1                    -> optional instruction
    # 26 RAISE_VARARGS            1
    #
    # A Python 3.8 assertion compiles to the following bytecode:
    # 18 POP_JUMP_IF_TRUE        28
    # 20 LOAD_GLOBAL              0 (AssertionError)
    # 22 LOAD_CONST               3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION            1                    -> optional instruction
    # 26 RAISE_VARARGS            1
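    #
    # For example, `assert x, "Assert message"` compiles to the full sequence
    # above, while a bare `assert x` omits the two optional instructions
    # (LOAD_CONST and CALL_FUNCTION).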

    if (truth_fn is not operator.truth) or push:
        return False

    assert isinstance(self.instruction_pointer, int)
    current_instruction_pointer = self.instruction_pointer
    inst = self.instructions[current_instruction_pointer]
    # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
    if sys.version_info < (3, 9):
        if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
            return False
    else:
        if inst.opname != "LOAD_ASSERTION_ERROR":
            return False

    current_instruction_pointer += 1

    if current_instruction_pointer >= len(self.instructions):
        return False

    inst = self.instructions[current_instruction_pointer]
    has_error_msg = False
    # Detect RAISE_VARARGS or LOAD_CONST
    if inst.opname == "LOAD_CONST":
        if not isinstance(inst.argval, str):
            return False
        self.LOAD_CONST(inst)
        has_error_msg = True

        # if it is LOAD_CONST, it must be followed by CALL_FUNCTION
        current_instruction_pointer += 1
        if current_instruction_pointer >= len(self.instructions):
            return False
        inst = self.instructions[current_instruction_pointer]
        if inst.opname != "CALL_FUNCTION":
            return False

        # CALL_FUNCTION should be followed by RAISE_VARARGS
        current_instruction_pointer += 1
        if current_instruction_pointer >= len(self.instructions):
            return False
        inst = self.instructions[current_instruction_pointer]

    if inst.opname != "RAISE_VARARGS":
        return False

    if not has_error_msg:
        # Push dummy value instead of error message
        self.push(ConstantVariable("assertion error"))

    return True


def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
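    """Build a handler for a conditional-jump opcode.

    ``truth_fn`` decides whether the branch is taken for the popped value, and
    ``push`` indicates whether the value is kept on the stack when jumping
    (as with ``JUMP_IF_TRUE_OR_POP``-style opcodes).
    """
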
    def inner(self: "InstructionTranslatorBase", inst: Instruction):
        value: VariableTracker = self.pop()
        self.output.guards.update(value.guards)
        if (
            config.rewrite_assert_with_torch_assert
            and _detect_and_normalize_assert_statement(self, truth_fn, push)
        ):
            error_msg: VariableTracker = self.pop()
            self.output.guards.update(error_msg.guards)
            # Skip over things like `assert True`
            if value.is_python_constant() and bool(value.as_python_constant()):
                self.jump(inst)
                return

            # Manually insert torch._assert instead of the Python assert and jump
            # over the assert-related instructions, as we don't need them anymore.
            self.output.create_proxy(
                "call_function",
                torch._assert,
                *proxy_args_kwargs((value, error_msg), {}),
            )
            self.jump(inst)
            return

        if value.is_python_constant():
            if truth_fn(value.as_python_constant()):
                push and self.push(value)
                self.jump(inst)
        elif (
            isinstance(value, TensorVariable) and self.should_compile_partial_graph()
        ):
            # compile a partial subgraph prefix then jump into user code
            if self.has_backedge():
                msg = (
                    "Skipping frame because there is a graph break in a for/while loop"
                )
                log.debug(msg)
                raise exc.SkipFrame(msg)

            self.push(value)
            log.debug("generic_jump triggered compile")
            self.output.compile_subgraph(
                self,
                reason=GraphCompileReason(
                    f"generic_jump {typestr(value)}", [self.frame_summary()]
                ),
            )
            self.pop()

            if_next = self.create_call_resume_at(self.next_instruction)
            push and self.push(value)
            if_jump = self.create_call_resume_at(inst.target)

            self.output.add_output_instructions(
                [(create_instruction(inst.opname, target=if_jump[0]))]
                + if_next
                + if_jump
            )
        elif isinstance(value, NNModuleVariable):
            # Equivalent of "self.nn_module is not None"
            if truth_fn(value):
                push and self.push(value)
                self.jump(inst)
        elif isinstance(value, UserDefinedObjectVariable):
            x = value.var_getattr(self, "__bool__")
            # __bool__ is a function
            if isinstance(x, UserMethodVariable):
                state = self.copy_graphstate()
                result = x.call_function(self, [], {})
                if isinstance(result, ConstantVariable) and isinstance(
                    result.value, bool
                ):
                    self.output.guards.update(result.guards)
                    if truth_fn(result.value):
                        push and self.push(value)
                        self.jump(inst)
                else:
                    # roll back to the state before inlining __bool__
                    self.restore_graphstate(state)
                    unimplemented(
                        "generic_jump on UserDefined with __bool__ returning non-constant"
                    )
            # __bool__ is not a function or does not exist on the user-defined object
            else:
                if truth_fn(True):
                    push and self.push(value)
                    self.jump(inst)
        elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
            self
        ):
            if truth_fn(len(value.unpack_var_sequence(self))):
                push and self.push(value)
                self.jump(inst)
        elif isinstance(value, SymNodeVariable):
            eval_result = value.evaluate_expr(self.output)
            if truth_fn(eval_result):
                push and self.push(value)
                self.jump(inst)
        else:
            unimplemented(f"generic_jump {typestr(value)}")

    return inner


explain = False


def break_graph_if_unsupported(*, push):
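    """Decorator factory for opcode handlers that may hit unsupported code:
    the graph state is checkpointed before calling the handler so that an
    ``Unsupported`` exception can be handled as a graph break instead of a
    hard failure.
    """
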
    def decorator(inner_fn):
        @functools.wraps(inner_fn)
        def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
            state = self.copy_graphstate()
            reason = None
            try:
                return inner_fn(self, inst)
            except Unsupported as excp:
                if self.has_backedge() and self.should_compile_partial_graph():
                    msg = "Skipping frame because there is a graph break in a for/while loop"