import contextlib
import logging
from typing import Dict, List, Optional
import torch._C
import torch.fx
import torch.nn
import torch.onnx.operators
from torch._dynamo.utils import deepcopy_to_fake_tensor, get_fake_value, get_real_value
from torch._dynamo.variables.base import VariableTracker
from torch._dynamo.variables.tensor import SymNodeVariable
from torch._guards import Source
from torch.utils import _pytree as pytree
from ..exc import unimplemented, Unsupported, UserError, UserErrorType
from ..guards import GuardBuilder
from ..source import FSDPNNModuleSource, GetItemSource, NNModuleSource
from ..utils import proxy_args_kwargs
from .lists import ListVariable, TupleVariable
log = logging.getLogger(__name__)
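# Speculatively runs `f` on `sub_args` purely to check that it can be traced
# soundly; the graph and graphstate are unconditionally rolled back to the
# given checkpoints afterwards, and any tracing error is propagated.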
def safe_or_raise_always_restore(tx, graph_checkpoint, checkpoint, f, sub_args):
# Will raise if not sound
try:
f.call_function(tx, sub_args, {})
finally:
tx.output.graph = graph_checkpoint
tx.restore_graphstate(checkpoint)
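# Context manager used when speculating a subgraph with enable_grad=True:
# turns grad mode on (via GradModeVariable, so Dynamo tracks the change) and
# restores the original grad-mode setting on exit.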
@contextlib.contextmanager
def dynamo_enable_grad(tx):
from . import GradModeVariable
org_value = torch.is_grad_enabled()
try:
GradModeVariable.create(tx, True)
yield
finally:
GradModeVariable.create(tx, org_value)
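# Returns True if `var` is a TensorVariable or a (possibly nested) tuple/list
# consisting solely of TensorVariables.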
def are_tensors(var):
from . import TensorVariable
if isinstance(var, TensorVariable):
return True
if isinstance(var, (TupleVariable, ListVariable)):
return all(are_tensors(item) for item in var.items)
return False
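# Validates that each sub_arg is a supported VariableTracker (constants,
# tensors, or an autograd.Function context). When
# manually_set_subgraph_inputs=True, also creates matching placeholder inputs
# on the subgraph tracer. Returns the (possibly rewrapped) argument list.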
def validate_args_and_maybe_create_graph_inputs(
sub_args, tracer, tx, manually_set_subgraph_inputs
):
from . import AutogradFunctionContextVariable, ConstantVariable, TensorVariable
from .builder import wrap_fx_proxy
assert tracer.parent is not None
args = []
for a in sub_args:
assert isinstance(a, VariableTracker)
if isinstance(a, ConstantVariable):
# Ensures that we recompile when the constant value changes
a.add_guard(GuardBuilder.CONSTANT_MATCH)
if manually_set_subgraph_inputs:
# This arg is not used in the body of the higher order op.
                # Currently, this new input is added to keep the calls,
                # which expect a fixed number of arguments, happy. In the
                # future, we can clean this up.
tracer.create_graph_input("const")
new_arg = a
elif isinstance(a, TensorVariable):
if manually_set_subgraph_inputs:
new_proxy = tracer.create_graph_input(a.as_proxy().node.name)
example_value = a.as_proxy().node.meta["example_value"]
new_arg = wrap_fx_proxy(
tx=tx, proxy=new_proxy, example_value=example_value
)
else:
new_arg = a
elif isinstance(a, AutogradFunctionContextVariable):
if manually_set_subgraph_inputs:
tracer.create_graph_input(a.as_proxy().node.name)
new_arg = a
else:
raise unimplemented(
f"HigherOrderOperator with body that accepts non-Tensors as input. "
f"Got: {a.python_type()}"
)
args.append(new_arg)
return args
# See NOTE [HigherOrderOperator tracing design] for details of the design
def speculate_subgraph(
tx,
f,
sub_args,
sub_kwargs,
graph_checkpoint,
checkpoint,
description,
*,
always_restore=False,
enable_grad=False,
# NOTE [Temporary argument `manually_set_subgraph_inputs`]
    # If manually_set_subgraph_inputs=True, then we manually add
    # the `sub_args` to `subgraph`; if False, then we rely on the
    # tracer's lifting mechanism to lift these args.
    # NOTE: The default of `True` is temporary; the plan is to
    # always lift args in the future and remove this argument.
manually_set_subgraph_inputs=True,
restore_side_effects=True,
):
if sub_kwargs is None:
sub_kwargs = {}
# See NOTE [Temporary argument `manually_set_subgraph_inputs`]
if sub_kwargs and manually_set_subgraph_inputs:
unimplemented(
"Use `manually_set_subgraph_inputs=False` when passing `sub_kwargs`."
)
try:
with tx.output.new_subtracer() as tracer:
args = validate_args_and_maybe_create_graph_inputs(
sub_args, tracer, tx, manually_set_subgraph_inputs
)
validate_args_and_maybe_create_graph_inputs(
sub_kwargs.values(), tracer, tx, manually_set_subgraph_inputs=False
)
autograd_ctx = (
dynamo_enable_grad(tx) if enable_grad else contextlib.nullcontext()
)
if restore_side_effects:
prev_side_effects = tx.output.side_effects.clone()
with autograd_ctx:
output = f.call_function(tx, args, sub_kwargs)
if restore_side_effects:
                # Captured variables are tracked in side effects
                # and would otherwise show up in the output graph
                # incorrectly. It is ok to undo this side-effect
                # tracking, as speculate_subgraph only allows
                # pure functions.
tx.output.side_effects = prev_side_effects
# Register output to graph
# Modeled off of compile_and_call_fx_graph
# TODO: support pytree output
            # We check always_restore because we don't use the output or side
            # effects of always_restore code, like bwd.
if always_restore:
# Nothing left to do here
return output, tx.output.graph, tracer.lifted_freevars
else:
if not are_tensors(output):
unimplemented(
"HigherOrderOperator body's output must consist of tensors only"
)
tx.output.guards.update(output.guards)
# The output proxies might not belong to this SubgraphTracer
# (if they are free variables that were never lifted)
# so lift them here.
output_proxies = output.as_proxy()
output_proxies = pytree.tree_map(
tracer.maybe_lift_tracked_freevar_to_input, output_proxies
)
tx.output.create_node(
"output",
"output",
(tracer.create_arg((output_proxies,))),
{},
)
graph = tx.output.graph
graph.lint()
lifted_freevars = tracer.lifted_freevars
return (
output,
graph,
lifted_freevars,
)
except Unsupported as ex:
msg = (
f"speculate_subgraph: while introspecting {description}, we were unable "
f"to trace function `{f.get_name()}` into a single graph. This means "
f"that Dynamo was unable to prove safety for this API and will "
f"fall back to eager-mode PyTorch, which could lead to a slowdown."
)
log.warning(msg)
log.exception(ex)
tx.output.graph = graph_checkpoint
tx.restore_graphstate(checkpoint)
raise Unsupported(
f"{msg} Scroll up for the stack trace "
f"of the initial exception. The reason was: {ex.msg}"
) from ex
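# Creates a `get_attr` proxy node that refers to an attribute/submodule which
# has already been registered on the output graph (e.g. a speculated subgraph).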
def make_attr(tx, name):
node = tx.output.create_proxy(
"get_attr",
name,
(),
{},
)
return node
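# Registers `gm` (a speculated subgraph) as a submodule of the output graph
# under a fresh name derived from `name`, and returns that name.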
def add_subgraph(tx, source, name, gm):
next_name = None
i = 0
while not next_name:
candidate = f"{name}_{i}"
if candidate in tx.output.nn_modules:
i += 1
else:
next_name = candidate
gm.__name__ = next_name
if source.guard_source().is_fsdp_module():
src = FSDPNNModuleSource(GetItemSource(source, next_name))
else:
src = NNModuleSource(GetItemSource(source, next_name))
gm.torchdynamo_force_dynamic = False
tx.output.register_attr_or_module(gm, next_name, source=src)
return next_name
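# Base VariableTracker for HigherOrderOperators. `make` dispatches on the
# operator's name (or, for the functorch transforms, on its identity) to the
# specialized subclasses below and graph-breaks on anything unrecognized.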
class TorchHigherOrderOperatorVariable(VariableTracker):
def __init__(self, value, source: Optional[Source] = None, **kwargs):
super().__init__(**kwargs)
self.value = value
self.source = source
@staticmethod
def make(value, source=None, **kwargs):
if value.__name__ == "cond":
return CondHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "map":
return MapHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "executorch_call_delegate":
return ExecutorchCallDelegateHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "out_dtype":
return OutDtypeHigherOrderVariable(value, source, **kwargs)
elif value is torch._functorch.eager_transforms.grad_impl:
return FunctorchGradHigherOrderVariable(value, source, **kwargs)
elif value is torch._functorch.vmap.vmap_impl:
return FunctorchVmapHigherOrderVariable(value, source, **kwargs)
elif value.__name__ in (
"trampoline_autograd_fwd",
"trampoline_autograd_bwd",
"trampoline_autograd_apply",
):
return AutogradFunctionMethodHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "wrap":
return WrapHigherOrderVariable(value, source, **kwargs)
elif value.__name__ in (
"wrap_activation_checkpoint",
"tag_activation_checkpoint",
):
return CheckpointHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "_export_tracepoint":
return ExportTracepointHigherOrderVariable(value, source, **kwargs)
else:
unimplemented(f"HigherOrderOperator {value.__name__}")
def check_kwargs(self, kwargs, supported_types):
assert (
all(isinstance(value, supported_types) for value in kwargs.values())
or not kwargs
), "only constant kwargs are supported"
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
unimplemented(f"HigherOrderOperator {self.value.__name__}")
class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import (
ConstantVariable,
ListVariable,
NestedUserFunctionVariable,
TensorVariable,
UserFunctionVariable,
)
from .builder import wrap_fx_proxy
self.check_kwargs(kwargs, ConstantVariable)
# TODO(voz): Support fake tensor dispatch for recursive
# ops - see torch/dispatch/_dispatcher.py
if len(args) != 4:
raise UserError(
UserErrorType.DYNAMIC_CONTROL_FLOW,
f"Expected 4 arguments but got {len(args)}.\n"
f"Usage: cond(pred, true_fn, false_fn, operands)",
)
# predicate
if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
raise UserError(
UserErrorType.DYNAMIC_CONTROL_FLOW,
f"Expected pred to be bool/int or a tensor with single "
f"item but got {str(type(args[0]))} "
f"with original python type {str(args[0].python_type())}.",
)
tx.output.guards.update(args[0].guards)
# operands
if not isinstance(args[3], (ListVariable, TupleVariable)):
raise UserError(
UserErrorType.DYNAMIC_CONTROL_FLOW,
f"Expected a tuple but got {args[3].python_type()}",
)
operands = args[3].unpack_var_sequence(tx)
if not all(
isinstance(operand, (TensorVariable, torch.Tensor)) for operand in operands
):
raise UserError(
UserErrorType.DYNAMIC_CONTROL_FLOW,
"Expected a tuple of tensors but got {actual_args}".format( # noqa: UP032
actual_args=[
str(operand.python_type())
if isinstance(operand, VariableTracker)
else str(type(operand))
for operand in operands
],
),
)
# branches
assert isinstance(
args[1], (UserFunctionVariable, NestedUserFunctionVariable)
), str(
type(args[1])
) # true_fn
assert isinstance(
args[2], (UserFunctionVariable, NestedUserFunctionVariable)
), str(
type(args[2])
) # false_fn
        # Our strategy for tracing the true/false branches of cond
        # is to checkpoint our graphstate, run the true branch,
        # roll it back to the checkpoint, run the false branch,
        # and then merge the graphstates. Well, perhaps
        # "merge" is too strong a word: we mostly assert that
        # the resulting graphstates have to be the same.
#
# We only permit guards to diverge (we union the guards from
# both branches). In particular, this means that side
# effects are NOT permitted inside true/false branches; this
# would be difficult to implement, because of the path
# explosion problem.
graph_checkpoint, checkpoint = tx.output.graph, tx.copy_graphstate()
def speculate_branch(branch):
# NB: 0 is predicate
ix = 1 if branch else 2
try:
# TODO: Support kwargs
ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
tx,
args[ix],
operands,
{},
graph_checkpoint,
checkpoint,
"cond",
)
# Reraise because we want to suggest workarounds
except Unsupported as e:
raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e)) from e
if not isinstance(ret_val, TensorVariable):
raise UserError(
UserErrorType.DYNAMIC_CONTROL_FLOW,
"Expected branch out type to be a single tensor",
)
return ret_val, ret_graph, ret_lifted_freevars
(true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
true_nn_modules = tx.copy_graphstate().output.nn_modules
(false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
false_nn_modules = tx.copy_graphstate().output.nn_modules
# TODO (tmanlaibaatar) deduplicate this later
# Let's say we capture cond(pred, true_fn, false_fn, x)
# and true_fn has lifted variables a, b, c
# and false_fn has lifted variables a, b, d
# Then each branch graph will receive:
# true_fn(x, a, b, c, a_false, b_false, d_false)
# false_fn(x, a_true, b_true, c_true, a, b, d)
# https://github.com/pytorch/pytorch/issues/103530
def fixup_branch_inps(graph, add_after, new_args, suffix) -> None:
inp_count = 0
for node in graph.nodes:
if node.op == "placeholder":
if inp_count == add_after:
with graph.inserting_after(node):
for inp_node in new_args:
new_node_name = inp_node.node.name + suffix
graph.placeholder(new_node_name)
break
inp_count += 1
fixup_branch_inps(
true_graph,
len(operands) + len(true_lifted_freevars) - 1,
false_lifted_freevars,
"_false_branch",
)
fixup_branch_inps(
false_graph, len(operands) - 1, true_lifted_freevars, "_true_branch"
)
true_name = add_subgraph(
tx,
self.source,
"cond_true",
torch.fx.GraphModule(true_nn_modules.nn_modules, true_graph),
)
false_name = add_subgraph(
tx,
self.source,
"cond_false",
torch.fx.GraphModule(false_nn_modules.nn_modules, false_graph),
)
true_node = make_attr(tx, true_name)
false_node = make_attr(tx, false_name)
p_args = (
args[0].as_proxy(),
true_node,
false_node,
[a.as_proxy() for a in operands]
+ list(true_lifted_freevars.keys())
+ list(false_lifted_freevars.keys()),
)
# TODO: assert that the true/false return values are
# consistent
example_value = true_r.as_proxy().node.meta["example_value"]
_, p_kwargs = proxy_args_kwargs([], kwargs)
# Store the invocation as a call
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(p_args),
kwargs=p_kwargs,
),
example_value=example_value,
)
def non_single_tensor_return_unsupported(api, ret):
from . import TensorVariable
if not isinstance(ret, TensorVariable):
raise Unsupported(
f"{api} over function that returns something " f"other than one Tensor"
)
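# Traces the `map` higher-order operator. A minimal, illustrative sketch of a
# user-level call that ends up here (hedged: the exact import path may differ
# between releases):
#
#     from functorch.experimental.control_flow import map
#
#     def body(xi, y):
#         return xi + y
#
#     out = map(body, xs, y)  # body is traced once with xs[0], then mapped over dim 0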
class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
from . import (
ConstantVariable,
NestedUserFunctionVariable,
TensorVariable,
UserFunctionVariable,
)
from .builder import wrap_fx_proxy
self.check_kwargs(kwargs, ConstantVariable)
assert type(args[0]) in (UserFunctionVariable, NestedUserFunctionVariable)
assert type(args[1]) is TensorVariable
sample_shape = args[1].get_real_value().size()
if len(sample_shape) < 1 or sample_shape[0] == 0:
unimplemented(
"map() operator doesn't support scalar or zero-sized tensors during tracing."
)
checkpoint = tx.copy_graphstate()
        # To get the example output from map(), we need to provide at least one
        # sample to the loop body. In our case we always use xs[0], and our
        # map() won't support zero-sized tensors during tracing.
first_dim = args[1].call_method(
tx, "__getitem__", args=[ConstantVariable(0)], kwargs={}
)
# TODO: Support kwargs
(
body_r,
body_graph,
body_lifted_freevars,
) = speculate_subgraph(
tx,
args[0],
[
first_dim,
*args[2:],
],
{},
tx.output.graph,
checkpoint,
"torch.ops.higher_order.map",
)
body_nn_modules = tx.copy_graphstate().output.nn_modules
body_name = add_subgraph(
tx,
self.source,
"map_body",
torch.fx.GraphModule(body_nn_modules.nn_modules, body_graph),
)
body_node = make_attr(tx, body_name)
p_args = (
body_node,
*(arg.as_proxy() for arg in args[1:]),
*(arg for arg in body_lifted_freevars.keys()),
)
non_single_tensor_return_unsupported("torch.ops.higher_order.map", body_r)
r = body_r.as_proxy().node.meta["example_value"]
example_value = r.new_empty(
[get_fake_value(args[1].as_proxy().node, tx).shape[0], *r.shape]
)
_, p_kwargs = proxy_args_kwargs([], kwargs)
# Store the invocation as a call
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(p_args),
kwargs=p_kwargs,
),
example_value=example_value,
)
class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import ConstantVariable
from .builder import wrap_fx_proxy
self.check_kwargs(kwargs, ConstantVariable)
        # This is the operator for delegation within Executorch, which calls a
        # specific function in the given lowered module with the given
        # operators. The actual operator is defined in the Executorch codebase.
# This is a bad hierarchical violation since
# executorch_call_delegate sits at a higher level than dynamo, but
# there's no real solution to this issue yet.
lowered_module = tx.output.get_submodule(args[0].module_key)
lowered_node = make_attr(tx, args[0].module_key)
p_args = tuple(arg.as_proxy() for arg in args[1:])
real_sub_args = pytree.tree_map_only(
torch.fx.Proxy, lambda a: get_real_value(a.node, tx.output), p_args
)
example_res = lowered_module.original_module(*real_sub_args)
example_value = deepcopy_to_fake_tensor(example_res, tx.fake_mode)
p_args = (lowered_node,) + p_args
_, p_kwargs = proxy_args_kwargs([], kwargs)
# Store the invocation as a call
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(p_args),
kwargs=p_kwargs,
),
example_value=example_value,
)
class FunctorchGradHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import ConstantVariable
from .builder import wrap_fx_proxy
self.check_kwargs(kwargs, ConstantVariable)
# TODO: Support `fn` with kwargs.
if not torch._dynamo.config.capture_func_transforms:
unimplemented(
"torch.func.grad capture is disabled, "
"it can be turned on by setting "
"`torch._dynamo.config.capture_func_transforms=True`"
)
# [NOTE] Here we are (roughly) modelling the following
#
# grad_fn = torch.func.grad(fn, argnums=.., has_aux=..)
# grad_output = grad_fn(x)
checkpoint = tx.copy_graphstate()
graph_checkpoint = tx.output.graph
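        # `args` mirrors the signature of grad_impl(func, argnums, has_aux,
        # args, kwargs): args[0]=func, args[1]=argnums, args[2]=has_aux,
        # args[3]=positional args for `func`, args[4]=kwargs for `func`.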
grad_args = (args[0], args[1], args[2])
# get arguments
func, argnums, has_aux = grad_args
kwargs = args[4].items
if len(kwargs) > 0:
# Since speculate_subgraph doesn't support kwargs, we can't handle this for now.
unimplemented(
"torch.func.grad: kwargs arguments are currently unsupported."
)
# Trace through the `func`
# NOTE [HACK: Enable autograd while tracing function]
# `torch.func.grad` should not be affected by `no_grad` outside of `grad`.
        # So, we enable grad right before the function to which `grad` is applied
# (the parts explicitly disabled with `no_grad` inside the function are still disabled).
# Eg.
# def f(x):
# with no_grad(): # This will disable grad tracking under it.
# y = x * 2
#
# return x ** 2 - y # grad tracking should be enabled irrespective of outside `no_grad`.
#
# with no_grad(): # This will not disable grad tracking inside of grad(f).
# grad_o = torch.func.grad(f)(x)
# TODO: Support kwargs
body_r, body_graph, body_lifted_freevars = speculate_subgraph(
tx,
func,
args[3].items,
{},
graph_checkpoint,
checkpoint,
"torch.func.grad",
# See NOTE [HACK: Enable autograd while tracing function]
enable_grad=True,
)
body_name = add_subgraph(
tx,
self.source,
"grad_body",
torch.fx.GraphModule(tx.output.nn_modules, body_graph),
)
body_node = make_attr(tx, body_name)
grad_proxy_args = (
body_node,
*(arg.as_proxy() for arg in grad_args[1:]),
)
# Model `grad_fn = grad(fn, *grad_args, **grad_kwargs)`
grad_fn = tx.output.create_proxy(
"call_function",
torch.func.grad,
args=tuple(grad_proxy_args),
kwargs={},
name="grad_proxy",
)
# Pass lifted freevars to the call to `grad_fn`
args = args[3].items
grad_fn_args = tuple(arg.as_proxy() for arg in args) + tuple(
body_lifted_freevars
)
# Call grad_fn with inputs.
# grad_output = grad_fn(*grad_fn_args, **grad_fn_kwargs)
grad_output = grad_fn(*grad_fn_args)
# `grad_fn(*grad_fn_args, **grad_fn_kwargs)`
# Output of grad_fn is
# For has_aux=False, Tuple[gradients of inputs indicated by argnums].
# For has_aux=True, Tuple[Tuple[gradients of inputs indicated by argnums], aux values]
# NOTE: example_value should match `grad_output`.
def _from_args(idx):
return args[idx].as_proxy().node.meta["example_value"].contiguous()
def to_python_ints(argnums):
if not isinstance(argnums, (ConstantVariable, TupleVariable)):
raise UserError(
UserErrorType.INVALID_INPUT,
f"argnums is expected to be int or tuple of ints. Got {argnums}.",
)
if isinstance(argnums, ConstantVariable):
if not isinstance(argnums.value, (int, tuple)):
raise UserError(
UserErrorType.INVALID_INPUT,
f"argnums is expected to be int or tuple of ints. Got {argnums}.",
)
return argnums.value
else:
const_vars = argnums.unpack_var_sequence(tx)
if not all(
isinstance(var, ConstantVariable) and isinstance(var.value, int)
for var in const_vars
):
raise UserError(
UserErrorType.INVALID_INPUT,
f"argnums is expected to contain int only. Got {const_vars}.",
)
return tuple(var.value for var in const_vars)
argnums_v = to_python_ints(argnums)
example_value = pytree.tree_map(_from_args, argnums_v)
if has_aux.value:
# case : has_aux = True
            # NOTE: Currently speculate_subgraph allows body_r to be
            # a Tensor or a Tuple/List of Tensors.
            # Since `grad` with has_aux expects the output
            # to be (output, aux), the only valid output currently is
            # (output, some_tensor).
body_r_proxy = body_r.as_proxy()
aux = body_r_proxy[1].node.meta["example_value"]
example_value = (example_value, aux)
fx_proxy = wrap_fx_proxy(tx=tx, proxy=grad_output, example_value=example_value)
# Call contiguous on all the computed grads.
if not has_aux.value:
if isinstance(argnums_v, int):
return fx_proxy.call_method(tx, "contiguous", (), {})
else:
grads = fx_proxy
items = []
for idx in range(len(argnums_v)):
proxy = grads.call_method(
tx, "__getitem__", (ConstantVariable(idx),), {}
).call_method(tx, "contiguous", (), {})
items.append(proxy)
return TupleVariable(items)
else: # case: has_aux.value = True
# fx_proxy -> Tuple(grads, aux)
grads = fx_proxy.call_method(tx, "__getitem__", (ConstantVariable(0),), {})
aux = fx_proxy.call_method(tx, "__getitem__", (ConstantVariable(1),), {})
if isinstance(argnums_v, int):
return TupleVariable([grads.call_method(tx, "contiguous", (), {}), aux])
else:
items = []
for idx in range(len(argnums_v)):
proxy = grads.call_method(
tx, "__getitem__", (ConstantVariable(idx),), {}
).call_method(tx, "contiguous", (), {})
items.append(proxy)
return TupleVariable([TupleVariable(items), aux])
class FunctorchVmapHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import ConstantVariable, TensorVariable
from .builder import wrap_fx_proxy
if not torch._dynamo.config.capture_func_transforms:
unimplemented(
"torch.func.vmap capture is disabled, "
"it can be turned on by setting "
"`torch._dynamo.config.capture_func_transforms=True`"
)
checkpoint = tx.copy_graphstate()
graph_checkpoint = tx.output.graph
# unpack args
fn = args[0]
in_dims = args[1]
out_dims = args[2]
randomness = args[3]
chunk_size = args[4]
batch_input_args = args[5:]
if not isinstance(in_dims, (ConstantVariable, TupleVariable)):
unimplemented("torch.func.vmap: in_dims is not an int or tuple variable.")
if not isinstance(out_dims, (ConstantVariable, TupleVariable)):
unimplemented("torch.func.vmap: out_dims is not an int or tuple variable.")
if kwargs:
unimplemented(
"NYI - torch.func.vmap: kwargs arguments are currently unsupported."
)
if chunk_size.value is not None:
unimplemented(
"NYI - torch.func.vmap is not implemented when chunk_size is passed"
)
flat_args, arg_spec = torch.utils._pytree.tree_flatten(batch_input_args)
in_dims_v = in_dims.as_python_constant()
in_dims_v = in_dims_v if isinstance(in_dims_v, int) else list(in_dims_v)
broadcasted_in_dims = torch._functorch.vmap._broadcast_to_and_flatten(
in_dims_v, arg_spec
)
# We want to pass unbatched input to speculate subgraph.
# So we loop through the inputs and select only one sample
# from the batch.
unbatched_input_args = []
for arg, in_dim in zip(flat_args, broadcasted_in_dims):
if in_dim is not None:
assert isinstance(arg, TensorVariable)
unbatched_arg = arg.call_method(
tx, "select", (ConstantVariable(in_dim), ConstantVariable(0)), {}
)
unbatched_input_args.append(unbatched_arg)
else:
unbatched_input_args.append(arg)
# Ban ops like `stride`, `storage_offset` in the traced functions.
# NOTE: We are conservatively banning more ops (vmap should be able
# to handle a few of them).
with tx.strict_translation_mode():
# trace through the function with unbatched inputs.
_, body_graph, body_lifted_freevars = speculate_subgraph(
tx,
fn,
torch.utils._pytree.tree_unflatten(unbatched_input_args, arg_spec),
{},
graph_checkpoint,
checkpoint,
"torch.vmap",
)
body_name = add_subgraph(
tx,
self.source,
"vmap_body",
torch.fx.GraphModule(tx.output.nn_modules, body_graph),
)
body_node = make_attr(tx, body_name)
        # body_lifted_freevars should not be treated as batched,
        # so here we update `in_dims` to reflect that.
        # NOTE: updated_in_dims is a flat list; that is ok for now,
        # as speculate_subgraph does not support functions with
        # non-Tensor args (so we graph-break above).
updated_in_dims = TupleVariable(
[ConstantVariable(dim) for dim in broadcasted_in_dims]
+ [
ConstantVariable(None),
]
* len(body_lifted_freevars)
)
vmap_proxy_args = (
body_node,
*(arg.as_proxy() for arg in (updated_in_dims, out_dims, randomness)),
)
# vmap_proxy corresponds to `vmap_proxy = vmap(fn, *vmap_args, **vmap_kwargs)`
vmap_proxy = tx.output.create_proxy(
"call_function",
torch.func.vmap,
args=tuple(vmap_proxy_args),
kwargs={},
name="vmap_proxy",
)
proxy_batched_fn_args = tuple(
arg.as_proxy() for arg in batch_input_args
) + tuple(body_lifted_freevars)
# We compute the example_value by actually calling
# `vmap` with FakeTensors.
fake_batched_fn_args = tuple(
arg.as_proxy().node.meta["example_value"] for arg in batch_input_args
) + tuple(arg.node.meta["example_value"] for arg in body_lifted_freevars)
actual_in_dims = tuple(
pytree.tree_map(lambda x: x.value, updated_in_dims.items)
)
# NOTE: `body_graph` might have operators which
# will create new tensors. So it is required
# that we run `vmap` under FakeMode.
with tx.fake_mode:
example_value = torch._functorch.vmap.vmap_impl(
torch.fx.GraphModule(tx.output.nn_modules, body_graph),
actual_in_dims,
out_dims.as_python_constant(),
randomness.value,
chunk_size.value,
*fake_batched_fn_args,
)
# proxy corresponds to `call = vmap_proxy(*batched_fn_args, **batched_fn_kwargs)`
proxy = vmap_proxy(*proxy_batched_fn_args)
return wrap_fx_proxy(
tx=tx,
proxy=proxy,
example_value=example_value,
)
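# Traces the fwd/bwd/apply trampolines of a user-defined autograd.Function
# (trampoline_autograd_fwd / _bwd / _apply, see `make` above). The backward
# trampoline is only speculated for soundness (always_restore=True) and its
# output is never kept in the graph.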
class AutogradFunctionMethodHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import ConstantVariable, UserFunctionVariable
from .builder import wrap_fx_proxy
self.check_kwargs(kwargs, ConstantVariable)
from . import TorchVariable
always_restore = self.value.__name__ == "trampoline_autograd_bwd"
if (
self.value.__name__ == "trampoline_autograd_bwd"
or self.value.__name__ == "trampoline_autograd_fwd"
):
fn = UserFunctionVariable(self.value, source=self.source)
else:
fn = TorchVariable(self.value)
checkpoint = tx.copy_graphstate()
pre_guards = tx.output.guards
graph_checkpoint = tx.output.graph
# TODO: Support kwargs
(
body_r,
body_graph,
body_lifted_freevars,
) = speculate_subgraph(
tx,
fn,
[
*args,
],
{},
graph_checkpoint,
checkpoint,
"the user-defined autograd.Function",
# Backwards should never, ever be stored!
always_restore=always_restore,
restore_side_effects=False,
)
post_guards = tx.output.guards
if body_lifted_freevars:
for freevar in body_lifted_freevars.keys():
if "saved_tensor_marked" not in freevar.node.meta:
unimplemented("NYI - freevars in autograd function.")
if always_restore:
if post_guards - pre_guards:
unimplemented("NYI - New guards discovered in a restoring state")
# Nothing left to do here
return None
p_args = (
*(arg.as_proxy() for arg in args),
*(arg for arg in body_lifted_freevars.keys()),
)
non_single_tensor_return_unsupported("autograd.Function forward", body_r)
r = body_r.as_proxy().node.meta["example_value"]
example_value = r
_, p_kwargs = proxy_args_kwargs([], kwargs)
# Store the invocation as a call
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(p_args),
kwargs=p_kwargs,
),
example_value=example_value,
)
class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
def create_wrapped_node(self, tx, args, kwargs, description):
# See NOTE [HigherOrderOperator tracing design] for more details
checkpoint = tx.copy_graphstate()
graph_checkpoint = tx.output.graph
(
body_r,
body_graph,
body_lifted_freevars,
) = speculate_subgraph(
tx,
args[0], # function
[*args[1:]],
kwargs,
graph_checkpoint,
checkpoint,
description,
manually_set_subgraph_inputs=False,
)
body_name = add_subgraph(
tx,
self.source,
"wrap_body",
torch.fx.GraphModule(tx.output.nn_modules, body_graph),
)
body_node = make_attr(tx, body_name)
        # Since we call `speculate_subgraph` with `manually_set_subgraph_inputs=False`,
        # all the arguments are lifted.
lifted_args = tuple(arg for arg in body_lifted_freevars.keys())
proxy_args = (body_node,) + lifted_args
example_value = pytree.tree_map_only(
torch.fx.Proxy,
lambda a: a.node.meta["example_value"],
body_r.as_proxy(),
)
return proxy_args, {}, example_value
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .builder import wrap_fx_proxy
p_args, p_kwargs, example_value = self.create_wrapped_node(
tx, args, kwargs, "wrap"
)
# Store the invocation as a call
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(p_args),
kwargs=p_kwargs,
),
example_value=example_value,
)
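# Traces the `out_dtype` higher-order operator. An illustrative sketch of the
# call shape, inferred from the argument handling below (hedged, not an API
# reference):
#
#     out = out_dtype(torch.ops.aten.mm.default, torch.int32, x, y)
#
# i.e. (op, output_dtype, *op_args); the example_value is computed by running
# `op` on fake tensors and casting the result to `output_dtype`.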
class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .builder import wrap_fx_proxy
if len(kwargs) != 0:
unimplemented("out_dtype does not handle kwargs")
p_args = tuple(arg.as_proxy() for arg in args)
op = p_args[0]
output_dtype = p_args[1]
fake_sub_args = pytree.tree_map_only(
torch.fx.Proxy, lambda a: get_fake_value(a.node, tx), p_args[2:]
)
# This is a simplified implementation of this operator just for tracing.
# Actual implementation may also first promote the arguments
example_value = op(*fake_sub_args).to(dtype=output_dtype)
# Store the invocation as a call
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(p_args),
kwargs={},
),
example_value=example_value,
)
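# Traces activation checkpointing (wrap_activation_checkpoint /
# tag_activation_checkpoint, see `make` above). Kwargs meant for the
# checkpoint call itself stay on the emitted call_function node, while kwargs
# destined for the checkpointed function are traced into the wrap_body
# subgraph.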
class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
from torch._higher_order_ops.wrap import TagActivationCheckpoint
from .builder import wrap_fx_proxy
checkpoint_kwargs, gmod_kwargs = TagActivationCheckpoint.divide_kwargs(kwargs)
# Here we use checkpoint_kwargs (and not gmod kwargs). gmod_kwargs are
# already flattened above and managed inside the fx graph.
p_args, _, example_value = self.create_wrapped_node(
tx, args, gmod_kwargs, "torch.utils.checkpoint.checkpoint"
)
_, checkpoint_kwargs = proxy_args_kwargs([], checkpoint_kwargs)
# Store the invocation as a call
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(p_args),
kwargs=checkpoint_kwargs,
),
example_value=example_value,
)
class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .builder import wrap_fx_proxy
p_args = tuple(arg.as_proxy() for arg in args)
p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=p_args,
kwargs=p_kwargs,
),
example_value=None,
)