from abc import ABC, abstractmethod
import contextlib
from typing import Any
import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
TransformType,
RandomnessType,
CInterpreter,
CGradInterpreterPtr,
CFunctionalizeInterpreterPtr,
CVmapInterpreterPtr,
CJvpInterpreterPtr,
pop_dynamic_layer_stack,
push_dynamic_layer_stack,
)
from torch.autograd.forward_ad import _set_fwd_grad_enabled
"""
This file contains the functorch integration with PyDispatcher.
PyDispatcher does not understand functorch's DynamicLayerStack dispatching
logic because it is entirely implemented in C++ in the fallbacks for two
dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable
to directly reuse C++ boxed fallbacks).
Instead of trying to hammer PyDispatcher into understanding those fallbacks,
we re-implement the logic of peeking the top of the stack for an interpreter,
selecting the interpreter to dispatch on, etc, in Python. This leads to a
simpler design.
The main difference between C++ functorch and PyDispatcher's functorch logic
is that:
- C++ functorch needs to manually tweak dispatch keys to ping-pong between
DynamicLayerFrontMode and DynamicLayerBackMode.
- PyDispatcher's functorch logic pops an Interpreter from the top of the stack
and asks it to execute the rule associated with the Interpreter.
In C++ we do the ping-pong because e.g. vmap rules are associated with the
batched DispatchKey, but in PyDispatcher we are able to avoid this by asking
the user to register a batching rule directly to a transform that an
interpreter then invokes.
"""
# FuncTorchInterpreter is the Python version of Interpreter (recall that
# the DynamicLayerStack is a stack of interpreters).
# It is a wrapper around the actual C++ Interpreter object.
#
# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h
class FuncTorchInterpreter(ABC):
    """Python-side counterpart of a C++ functorch ``Interpreter``.

    Holds a handle to the underlying C++ interpreter object in ``self._cptr``
    and forwards ``level()``/``key()`` queries to it. Concrete subclasses must
    implement :meth:`process`.

    Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h
    """

    def __init__(self, cptr: Any):
        self._cptr = cptr

    @abstractmethod
    def process(self, op, args, kwargs):
        """Run ``op`` under this interpreter (e.g. for vmap, a batching rule).

        Conceptually this is analogous to Interpreter::process in C++.
        """

    def lower(self):
        """Return a context manager that sends work to the next interpreter.

        Concretely, it temporarily pops the current interpreter off the
        DynamicLayerStack for the duration of the ``with`` body.
        Conceptually this is analogous to Interpreter::sendToNextInterpreter
        in C++.
        """
        return temporarily_pop_interpreter_stack()

    def level(self):
        """Return the level reported by the wrapped C++ interpreter."""
        return self._cptr.level()

    def key(self):
        """Return the transform key reported by the wrapped C++ interpreter."""
        return self._cptr.key()
@contextlib.contextmanager
def temporarily_pop_interpreter_stack():
    """Pop the top DynamicLayerStack entry for the duration of the ``with`` body.

    The popped entry is pushed back when the body exits, including when the
    body raises.
    """
    # Pop *before* entering ``try``: if the pop itself fails there is nothing
    # to restore. With the pop inside ``try``, ``finally`` would reference an
    # unbound ``saved`` and raise UnboundLocalError, masking the real error.
    saved = pop_dynamic_layer_stack()
    try:
        yield
    finally:
        push_dynamic_layer_stack(saved)
class VmapInterpreter(FuncTorchInterpreter):
    """Python wrapper around a C++ interpreter whose key is TransformType.Vmap."""

    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Vmap
        # NOTE: [Interpreter cdata vs cptr]
        # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
        # so that we can access methods specific to the vmap interpreter
        self._cdata = cdata
        self._cptr = CVmapInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        """Invoke the Vmap kernel registered in ``op.functorch_table``.

        The kernel receives this interpreter as its first argument.
        """
        return op.functorch_table[TransformType.Vmap](self, *args, **kwargs)

    def batch_size(self):
        """Return the batch size reported by the C++ vmap interpreter."""
        return self._cptr.batchSize()

    def randomness(self):
        """Return the interpreter's RandomnessType as a lowercase string.

        Returns one of ``"error"``, ``"same"`` or ``"different"``.

        Raises:
            RuntimeError: if the C++ side reports an unrecognized value.
        """
        typ = self._cptr.randomness()
        labelled_modes = (
            (RandomnessType.Error, "error"),
            (RandomnessType.Same, "same"),
            (RandomnessType.Different, "different"),
        )
        for member, label in labelled_modes:
            if typ == member:
                return label
        raise RuntimeError(f"Unknown RandomnessType: {typ}")
@contextlib.contextmanager
def nested(*contexts):
    """Enter each context manager in ``contexts`` in order; exit them in reverse.

    Entry and exit are delegated to ``contextlib.ExitStack``. The value yielded
    to the ``with`` statement is the tuple of context managers themselves, not
    the values returned by their ``__enter__`` methods.
    """
    exit_stack = contextlib.ExitStack()
    with exit_stack:
        for manager in contexts:
            exit_stack.enter_context(manager)
        yield contexts
class GradInterpreter(FuncTorchInterpreter):
    """Python wrapper around a C++ interpreter whose key is TransformType.Grad."""

    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Grad
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CGradInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        """Apply the C++ interpreter's ``lift`` to every Tensor in args/kwargs.

        Non-Tensor leaves are left untouched. Returns ``(args, kwargs)``.
        """
        args, kwargs = pytree.tree_map_only(torch.Tensor, self._cptr.lift, [args, kwargs])
        return args, kwargs

    def process(self, op, args, kwargs):
        """Lift the inputs, then invoke the Grad kernel registered on ``op``.

        The kernel receives this interpreter as its first argument.
        """
        kernel = op.functorch_table[TransformType.Grad]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # GradInterpreter has custom lower because of the no_grad interaction
    # See NOTE [grad and vjp interaction with no_grad]
    # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter
    def lower(self):
        """Return a context manager that pops this interpreter off the stack.

        If ``prev_grad_mode()`` is False, the returned manager also enters
        ``torch.no_grad()`` while the interpreter is popped.
        """
        prev_grad_mode = self.prev_grad_mode()
        # Test the *value* just fetched. Testing ``self.prev_grad_mode`` would
        # check the bound method, which is always truthy, so no_grad would
        # never be applied.
        if not prev_grad_mode:
            return nested(torch.no_grad(), super().lower())
        return super().lower()

    def prev_grad_mode(self):
        """Return ``prevGradMode()`` from the C++ grad interpreter.

        Presumably this is the grad mode recorded when the transform was
        entered; TODO confirm against GradInterpreterPtr in C++.
        """
        return self._cptr.prevGradMode()
class JvpInterpreter(FuncTorchInterpreter):
    """Python wrapper around a C++ interpreter whose key is TransformType.Jvp."""

    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Jvp
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CJvpInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        """Apply the C++ interpreter's ``lift`` to every Tensor in args/kwargs.

        Non-Tensor leaves are left untouched. Returns ``(args, kwargs)``.
        """
        args, kwargs = pytree.tree_map_only(torch.Tensor, self._cptr.lift, [args, kwargs])
        return args, kwargs

    def process(self, op, args, kwargs):
        """Lift the inputs, then invoke the Jvp kernel registered on ``op``.

        The kernel receives this interpreter as its first argument.
        """
        kernel = op.functorch_table[TransformType.Jvp]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # Jvp has custom lower because of the no_fwd_grad interaction
    # See NOTE [grad and vjp interaction with no_grad] for related info.
    # This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter
    def lower(self):
        """Return a context manager that pops this interpreter off the stack.

        If ``prev_fwd_grad_mode()`` is False, the returned manager also
        disables forward-mode AD (``_set_fwd_grad_enabled(False)``) while the
        interpreter is popped.
        """
        prev_fwd_grad_mode = self.prev_fwd_grad_mode()
        # Test the *value* just fetched. Testing ``self.prev_fwd_grad_mode``
        # would check the bound method, which is always truthy, so forward
        # grad would never be disabled.
        if not prev_fwd_grad_mode:
            return nested(_set_fwd_grad_enabled(False), super().lower())
        return super().lower()

    def prev_fwd_grad_mode(self):
        """Return ``prevFwdGradMode()`` from the C++ jvp interpreter.

        Presumably this is the forward-grad mode recorded when the transform
        was entered; TODO confirm against JvpInterpreterPtr in C++.
        """
        return self._cptr.prevFwdGradMode()
class FunctionalizeInterpreter(FuncTorchInterpreter):
    """Python wrapper around a C++ interpreter whose key is TransformType.Functionalize."""

    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Functionalize
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CFunctionalizeInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        """Invoke the Functionalize kernel registered in ``op.functorch_table``.

        The kernel receives this interpreter as its first argument.
        """
        return op.functorch_table[TransformType.Functionalize](self, *args, **kwargs)

    def functionalize_add_back_views(self):
        """Return ``functionalizeAddBackViews()`` from the C++ interpreter."""
        return self._cptr.functionalizeAddBackViews()
def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    """Wrap a generic C++ interpreter in the matching Python interpreter class.

    Raises:
        RuntimeError: if the interpreter's transform key has no Python wrapper.
    """
    key = cinterpreter.key()
    wrapper_by_transform = (
        (TransformType.Grad, GradInterpreter),
        (TransformType.Vmap, VmapInterpreter),
        (TransformType.Jvp, JvpInterpreter),
        (TransformType.Functionalize, FunctionalizeInterpreter),
    )
    for transform, wrapper_cls in wrapper_by_transform:
        if key == transform:
            return wrapper_cls(cinterpreter)
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")
def retrieve_current_functorch_interpreter():
    """Return the interpreter on top of the DynamicLayerStack, wrapped for Python.

    Asserts that the stack is non-empty, i.e. that
    ``peek_interpreter_stack()`` did not return None.
    """
    top_of_stack = torch._C._functorch.peek_interpreter_stack()
    assert top_of_stack is not None
    return coerce_cinterpreter(top_of_stack)
def dispatch_functorch(op, args, kwargs):
    """Run ``op`` through the interpreter currently on top of the functorch stack.

    The current interpreter is looked up first. Then every Tensor in
    args/kwargs is passed through ``unwrap_if_dead``, and the result is
    handed to that interpreter's ``process``.
    """
    current = retrieve_current_functorch_interpreter()
    # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's
    # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers.
    # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch
    # transforms, so we manually unwrap the dead tensors here.
    # This logic won't need to exist when we have mode-only functorch.
    unwrap = torch._C._functorch.unwrap_if_dead
    live_args, live_kwargs = pytree.tree_map_only(torch.Tensor, unwrap, (args, kwargs))
    return current.process(op, live_args, live_kwargs)