# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable
import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._ops import HigherOrderOperator
from torch.fx.experimental.proxy_tensor import make_fx
from torch.multiprocessing.reductions import StorageWeakRef


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


def autograd_not_implemented_inner(
    operator: HigherOrderOperator, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad, this will either
    raise an error or return a DelayedError-wrapped result depending on the value of
    delayed_error.

    Args:
        operator: The HigherOrderOperator to call with *args and **kwargs
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the HigherOrderOperator
        kwargs: The keyword arguments to the HigherOrderOperator

    Raises:
        RuntimeError: If autograd is enabled, any Tensor argument requires grad,
            and delayed_error is False.
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result


def autograd_not_implemented(op: HigherOrderOperator, deferred_error: bool) -> Callable:
    def inner(*args, **kwargs):
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner
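

# Example (hypothetical sketch, not part of this module): how this fallback is
# typically installed as the Autograd kernel of a higher-order operator.
# `my_hop` is a stand-in for a real HigherOrderOperator defined elsewhere.
def _example_register_autograd_fallback(my_hop: HigherOrderOperator) -> None:
    # With deferred_error=True the call still succeeds under autograd, but
    # backpropagating through its outputs raises a DelayedError; with False
    # the call itself raises eagerly when any input requires grad.
    my_hop.py_impl(torch._C.DispatchKey.Autograd)(
        autograd_not_implemented(my_hop, deferred_error=True)
    )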


def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running graph with interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter

    return maybe_interpreted_fn
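

# Example (hypothetical sketch, not part of this module): when node-meta
# preservation is active, a GraphModule is re-executed through an
# fx.Interpreter so stack traces propagate to nodes created during re-tracing.
def _example_maybe_run_with_interpreter() -> None:
    gm = torch.fx.symbolic_trace(lambda x: x * 2)
    with fx_traceback.preserve_node_meta():
        run = _maybe_run_with_interpreter(gm)  # Interpreter-backed wrapper here
        out = run(torch.ones(2))
    assert torch.equal(out, torch.full((2,), 2.0))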


def reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped
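

# Example (hypothetical sketch, not part of this module): higher-order ops
# re-enter tracing from their proxy-mode handlers to turn a Python callable
# into a GraphModule while the outer make_fx session is still live.
# `body_fn` and `operands` are stand-in names.
def _example_trace_body_subgraph(body_fn: Callable, *operands) -> torch.fx.GraphModule:
    # Only valid while an enclosing make_fx trace is running; otherwise the
    # assertion inside `wrapped` fires.
    return reenter_make_fx(body_fn)(*operands)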


@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    _old_is_inlining = torch._dynamo.config.inline_inbuilt_nn_modules
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        # TODO(anijain2305, export-team) For non-strict export with module
        # stack info, the codepath forces the nn module __getattr__ to
        # ProxyAttr __getattr__ downstream. To circumvent the issue for now,
        # skip inlining inbuilt nn modules for cond.
        torch._dynamo.config.inline_inbuilt_nn_modules = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
        torch._dynamo.config.inline_inbuilt_nn_modules = _old_is_inlining
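

# Example (hypothetical sketch, not part of this module): higher-order ops such
# as cond wrap their torch.compile re-entry in this context manager so the
# fx-tracing flag and inbuilt-nn-module inlining are restored afterwards.
def _example_set_compilation_env(fn, *args):
    with _set_compilation_env():
        # backend="eager" is just an illustrative choice here.
        return torch.compile(fn, backend="eager")(*args)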


def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if the
    produced graph has a mutable op on the input. This is a
    bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_mutation(gm):
        input_nodes = set()
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                input_nodes.add(node)
            if node.op == "call_function":
                target = node.target
                if (
                    isinstance(target, torch._ops.OpOverload)
                    and target._schema.is_mutable
                ):
                    for arg in node.args:
                        if arg in input_nodes:
                            return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule):
                if _detect_input_mutation(module):
                    return True

        return False

    return _detect_input_mutation(gm)
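

# Example (hypothetical sketch, not part of this module): the check should
# flag a branch that mutates its input in place and pass a purely
# functional branch.
def _example_input_mutation_check() -> None:
    x = torch.randn(3)
    assert _has_potential_branch_input_mutation(lambda t: t.add_(1), [x])
    assert not _has_potential_branch_input_mutation(lambda t: t + 1, [x])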


def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if the
    produced graph has an output aliasing the branch input. This is a
    bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_alias(gm):
        input_storages = set()
        for node in gm.graph.nodes:
            # We need to check existence of "val" because we reuse the logic here
            # for map operator, where num_mapped_args is a scalar
            # and doesn't have a "val" meta.
            if node.op == "placeholder" and "val" in node.meta:
                input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
            if node.op == "output":

                def check_alias(out):
                    if out is not None and "val" in out.meta:
                        out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                        return out_storage in input_storages
                    return False

                if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
                    return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
                return True

        return False

    return _detect_input_alias(gm)
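

# Example (hypothetical sketch, not part of this module): the check should
# flag a branch whose output is its input (a trivial alias) and pass one
# that returns a freshly allocated tensor.
def _example_input_alias_check() -> None:
    x = torch.randn(4)
    assert _has_potential_branch_input_alias(lambda t: t, [x])
    assert not _has_potential_branch_input_alias(lambda t: t.clone(), [x])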


def unique_graph_id(proxy_mode, prefix):
    """Returns a unique id and name for a graph to be added to a proxy_mode tracer"""
    # There are probably better ways - I know that create_arg has some self-incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{prefix}_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate
    return i, next_name
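

# Example (hypothetical sketch, not part of this module): with a tracer root
# that already owns an attribute "body_graph_0", the helper skips ahead to the
# next free suffix. `_FakeProxyMode` is a minimal stand-in for a real
# proxy mode object.
def _example_unique_graph_id() -> None:
    class _FakeTracer:
        root = torch.nn.Module()

    class _FakeProxyMode:
        tracer = _FakeTracer()

    mode = _FakeProxyMode()
    mode.tracer.root.register_module("body_graph_0", torch.nn.Module())
    assert unique_graph_id(mode, "body_graph") == (1, "body_graph_1")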