import contextlib
import functools
import inspect
import logging
import os
import sys
import textwrap
import threading
import traceback
import types
import warnings
from enum import Enum
from typing import Optional, Tuple, TYPE_CHECKING, Union
from unittest.mock import patch
import torch
import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.nn.parallel.distributed import DistributedDataParallel
from .backends.registry import CompilerFn, lookup_backend
from .hooks import Hooks
if TYPE_CHECKING:
from torch._C._dynamo.eval_frame import ( # noqa: F401
reset_code,
set_eval_frame,
set_guard_error_hook,
set_guard_fail_hook,
skip_code,
unsupported,
)
else:
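    # Mirror the C extension's public names (set_eval_frame, reset_code,
    # skip_code, ...) into this module's namespace at runtime; the
    # TYPE_CHECKING block above gives type checkers the same view of them.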
for name in dir(torch._C._dynamo.eval_frame):
if name.startswith("__"):
continue
globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
from . import config, convert_frame, skipfiles, utils
from .exc import ResetRequired
from .mutation_guard import install_generation_tagging_init
from .types import DynamoCallback
from .utils import compile_times
from torch.fx.experimental import proxy_tensor

log = logging.getLogger(__name__)
always_optimize_code_objects = utils.ExactWeakKeyDictionary()
null_context = contextlib.nullcontext
# See https://github.com/python/typing/pull/240
class Unset(Enum):
token = 0
unset = Unset.token
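# `unset` is a sentinel distinct from every real DynamoCallback value, so a
# field like `prior: Union[Unset, DynamoCallback]` can distinguish "never
# entered" from a saved callback that is legitimately None or False (see
# _TorchDynamoContext below):
#
#     prior: Union[Unset, DynamoCallback] = unset
#     ...
#     assert prior is not unset  # only holds after set_eval_frame() ran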
compile_lock = threading.RLock()
most_recent_backend: Optional[CompilerFn] = None
class OptimizedModule(torch.nn.Module):
"""
    Wraps the original nn.Module object and routes its forward method
    through the dynamo-optimized self.forward method.
"""
def __init__(self, mod, dynamo_ctx):
super().__init__()
        # Installs the params/buffers (mod is registered as a submodule)
self._orig_mod = mod
self.dynamo_ctx = dynamo_ctx
def __getattr__(self, name):
if name == "_orig_mod":
return self._modules["_orig_mod"]
return getattr(self._orig_mod, name)
def forward(self, *args, **kwargs):
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
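# A minimal usage sketch (hedged; assumes "eager" is a registered backend).
# Wrapping an nn.Module yields an OptimizedModule whose forward runs under
# the dynamo context, while attribute access still resolves to the original
# module via __getattr__:
#
#     mod = torch.nn.Linear(4, 4)
#     opt_mod = torch._dynamo.optimize("eager")(mod)
#     assert isinstance(opt_mod, OptimizedModule)
#     assert opt_mod.weight is mod.weight  # forwarded to _orig_mod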
def remove_from_cache(f):
"""
    Make sure f.__code__ is not cached, forcing a recompile on the next call
"""
if isinstance(f, types.CodeType):
reset_code(f)
elif hasattr(f, "__code__"):
reset_code(f.__code__)
elif hasattr(getattr(f, "forward", None), "__code__"):
reset_code(f.forward.__code__)
else:
from . import reset
reset()
log.warning("could not determine __code__ for %s", f)
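# Illustrative use (hedged; `my_fn` is a hypothetical compiled function):
#
#     remove_from_cache(my_fn)           # resets the cache on my_fn.__code__
#     remove_from_cache(my_fn.__code__)  # equivalent, given the code object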
def nothing():
pass
def innermost_fn(fn):
"""
    In case of nested _TorchDynamoContext calls, find the innermost
    function. TorchDynamo caches on fn.__code__, so it's necessary to find
    the innermost function to pass to optimize, run, disable, etc.
"""
unaltered_fn = fn
while hasattr(unaltered_fn, "_torchdynamo_orig_callable"):
unaltered_fn = unaltered_fn._torchdynamo_orig_callable
assert callable(unaltered_fn)
return unaltered_fn
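# Example of the unwrapping this performs (hedged; assumes "eager" is a
# registered backend). Each optimize() layer records its input in
# _torchdynamo_orig_callable, so the chain walks back to the bare function:
#
#     def f(x):
#         return x + 1
#
#     g = torch._dynamo.optimize("eager")(f)
#     h = torch._dynamo.optimize("eager")(g)
#     assert innermost_fn(h) is f  # both decorator layers are peeled off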
@contextlib.contextmanager
def enable_dynamic(enable: bool = True):
if not enable:
yield
return
with config.patch(dynamic_shapes=True, specialize_int_float=False):
yield
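# Hedged sketch of the effect: inside the block, tracing uses symbolic
# (dynamic) shapes and does not specialize ints/floats into the graph:
#
#     with enable_dynamic():
#         assert config.dynamic_shapes and not config.specialize_int_float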
class _TorchDynamoContext:
def __init__(
self,
callback: DynamoCallback,
on_enter=nothing,
backend_ctx_ctor=null_context,
patch_fn=nothing,
first_ctx=False,
*,
dynamic=False,
):
super().__init__()
assert callable(callback) or callback is False or callback is None
self.callback: DynamoCallback = callback
self.prior: Union[Unset, DynamoCallback] = unset
self.on_enter = on_enter
self.extra_ctx_ctor = backend_ctx_ctor
self.first_ctx = first_ctx
self.dynamic = dynamic
patch_fn()
def __enter__(self):
if config.raise_on_ctx_manager_usage:
raise RuntimeError(
"torch._dynamo.optimize(...) is used with a context manager. "
"Please refer to https://github.com/pytorch/torchdynamo#usage-example "
"to use torch._dynamo.optimize(...) as an annotation/decorator. "
)
self.on_enter()
self.prior = set_eval_frame(self.callback)
self.backend_ctx = self.extra_ctx_ctor()
self.backend_ctx.__enter__()
self.dynamic_ctx = enable_dynamic(self.dynamic)
self.dynamic_ctx.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
assert self.prior is not unset
set_eval_frame(self.prior)
self.prior = unset
# TODO: This is totally not the right way to chain contexts manually
self.dynamic_ctx.__exit__(exc_type, exc_val, exc_tb)
self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)
def __call__(self, fn):
fn = innermost_fn(fn)
        # Optimize the forward method of a torch.nn.Module object
if isinstance(fn, torch.nn.Module):
mod = fn
new_mod = OptimizedModule(mod, self)
            # Save a pointer to the original forward so innermost_fn can
            # unwrap nested decorators.
new_mod._torchdynamo_orig_callable = mod.forward
return new_mod
assert callable(fn)
callback = self.callback
on_enter = self.on_enter
backend_ctx_ctor = self.extra_ctx_ctor
@functools.wraps(fn)
def _fn(*args, **kwargs):
if (
not isinstance(self, DisableContext)
and torch.fx._symbolic_trace.is_fx_tracing()
):
if config.error_on_nested_fx_trace:
raise RuntimeError(
"Detected that you are using FX to symbolically trace "
"a dynamo-optimized function. This is not supported at the moment."
)
else:
return fn(*args, **kwargs)
on_enter()
prior = set_eval_frame(callback)
backend_ctx = backend_ctx_ctor()
backend_ctx.__enter__()
dynamic_ctx = enable_dynamic(self.dynamic)
dynamic_ctx.__enter__()
try:
return fn(*args, **kwargs)
finally:
set_eval_frame(prior)
dynamic_ctx.__exit__(None, None, None)
backend_ctx.__exit__(None, None, None)
        # Attribute markers so inlining handles these wrappers properly
if isinstance(self, DisableContext):
_fn._torchdynamo_disable = True # type: ignore[attr-defined]
else:
_fn._torchdynamo_inline = fn # type: ignore[attr-defined]
        # Save a pointer to the original callable so innermost_fn can unwrap
        # nested decorators.
_fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
        # If the function is wrapped by the torch._dynamo.optimize decorator,
        # prevent it from ever being skipped.
if callback not in (None, False):
if not hasattr(fn, "__code__"):
raise RuntimeError(
textwrap.dedent(
"""
                    torch._dynamo.optimize is called on a non-function object.
If this is a callable class, please wrap the relevant code into a function and optimize the
wrapper function.
>> class CallableClass:
>> def __init__(self):
>> super().__init__()
>> self.relu = torch.nn.ReLU()
>>
>> def __call__(self, x):
>> return self.relu(torch.sin(x))
>>
>> def print_hello(self):
>> print("Hello world")
>>
>> mod = CallableClass()
If you want to optimize the __call__ function and other code, wrap that up in a function
>> def wrapper_fn(x):
>> y = mod(x)
>> return y.sum()
and then optimize the wrapper_fn
                    >> opt_wrapper_fn = torch._dynamo.optimize("eager")(wrapper_fn)
"""
)
)
always_optimize_code_objects[fn.__code__] = True
return _fn
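# Hedged sketch of what __call__ produces for a plain function (assumes
# "eager" is a registered backend): a wrapper that enables the eval-frame
# callback for the duration of the call, plus bookkeeping attributes used
# elsewhere in this module:
#
#     def f(x):
#         return x.cos()
#
#     opt_f = torch._dynamo.optimize("eager")(f)
#     assert opt_f._torchdynamo_orig_callable is f
#     # f.__code__ is also registered in always_optimize_code_objects,
#     # so frame conversion will never skip it.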
class OptimizeContext(_TorchDynamoContext):
@staticmethod
def _different_backend(old, new):
return not (old == new or old is None)
def __init__(self, callback, backend_ctx_ctor, first_ctx=False, *, dynamic=False):
def on_enter():
global most_recent_backend
if OptimizeContext._different_backend(most_recent_backend, compiler_fn):
if config.raise_on_backend_change:
raise ResetRequired()
else:
warnings.warn(
"changing options to `torch.compile()` may require "
"calling `torch._dynamo.reset()` to take effect"
)
most_recent_backend = compiler_fn
install_generation_tagging_init()
compiler_fn = innermost_fn(callback)
super().__init__(
callback=callback,
on_enter=on_enter,
backend_ctx_ctor=backend_ctx_ctor,
patch_fn=TorchPatcher.patch,
first_ctx=first_ctx,
dynamic=dynamic,
)
class RunOnlyContext(_TorchDynamoContext):
def __init__(self):
super().__init__(callback=False)
class DisableContext(_TorchDynamoContext):
def __init__(self):
super().__init__(callback=None)
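# How the callback sentinels map to behavior (per the assert in
# _TorchDynamoContext.__init__): a callable compiles new frames, False runs
# only previously compiled code, and None disables TorchDynamo entirely.
# A hedged sketch of the public helpers built on these contexts:
#
#     torch._dynamo.run(fn)      # RunOnlyContext: callback=False
#     torch._dynamo.disable(fn)  # DisableContext: callback=None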
def catch_errors_wrapper(callback, hooks: Hooks):
@functools.wraps(callback)
def catch_errors(frame, cache_size):
if (
frame.f_lasti >= 0
or skipfiles.check(frame.f_code.co_filename)
or config.disable
):
log.debug(f"skipping {frame.f_code.co_name} {frame.f_code.co_filename}")
return None
if frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__":
            # namedtuple constructor
return None
if config.optimize_ddp:
ddp_module = DistributedDataParallel._get_active_ddp_module()
if ddp_module:
with compile_lock:
from torch._dynamo.backends.distributed import DDPOptimizer
ddp_optimizer = DDPOptimizer(
bucket_bytes_cap=ddp_module.bucket_bytes_cap,
backend_compile_fn=callback._torchdynamo_orig_callable,
)
hijacked_callback = convert_frame.convert_frame(
ddp_optimizer.compile_fn,
hooks=hooks,
)
return hijacked_callback(frame, cache_size, hooks)
with compile_lock:
return callback(frame, cache_size, hooks)
catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined]
return catch_errors
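# Hedged sketch of how this wrapper is installed (mirroring
# _optimize_catch_errors below; `compiler_fn` is a hypothetical backend): the
# wrapped callback becomes the eval-frame hook, and returning None from
# catch_errors makes the frame run unmodified:
#
#     callback = catch_errors_wrapper(
#         convert_frame.convert_frame(compiler_fn, hooks=hooks), hooks
#     )
#     prior = set_eval_frame(callback)  # as _TorchDynamoContext.__enter__ does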
def _optimize_catch_errors(
compile_fn, hooks: Hooks, backend_ctx_ctor=null_context, dynamic=False
):
Loading ...