import contextlib
import ctypes
import inspect
import sys
import types
from abc import ABC
from typing import Any, Dict
import torch._C
from torch import _utils_internal
from torch._functorch.pyfunctorch import dispatch_functorch
# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
@contextlib.contextmanager
def dl_open_guard():
"""
Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
shared library to load custom operators.
"""
if not _SET_GLOBAL_FLAGS:
yield
return
old_flags = sys.getdlopenflags()
sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
try:
yield
finally:
sys.setdlopenflags(old_flags)
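# A minimal usage sketch for dl_open_guard (illustrative only; the library path
# below is hypothetical):
#
#     with dl_open_guard():
#         ctypes.CDLL("/path/to/libcustom_ops.so")
#
# RTLD_GLOBAL makes the loaded library's symbols visible to shared objects
# opened afterwards, which custom-operator libraries typically rely on. On
# platforms without sys.setdlopenflags (e.g. Windows) the guard is a no-op.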
def has_key(op, k):
return (
torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), k)
or k in op.py_kernels
)
# TODO(voz) We are missing an entire axis of registration - Modes for the python key
class PyOperatorABC(ABC):
def __call__(self, *args, **kwargs):
pass
def py_impl(self, dispatch_key, fn):
pass
def name(self):
pass
is_included_in_alias = torch._C._dispatch_is_included_in_alias
DispatchKey = torch._C.DispatchKey
# Equivalent to computeDispatchTableEntryWithDebug
def resolve_key(op: PyOperatorABC, k: DispatchKey): # type: ignore[valid-type]
# 1. (Direct) operator registration
if has_key(op, k):
return k
# 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
cand = DispatchKey.CompositeExplicitAutogradNonFunctional
if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
op, cand
):
return cand
# 2.2 Use CompositeExplicitAutograd kernel if available
cand = DispatchKey.CompositeExplicitAutograd
if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
op, cand
):
return cand
has_backend_kernel = torch._C._dispatch_has_kernel_for_any_dispatch_key(
op.name(), torch._C._dispatch_get_backend_keyset_from_autograd(k)
) or has_key(op, DispatchKey.CompositeExplicitAutograd)
    # 2.3. Use CompositeImplicitAutograd kernel if available
    #      (try the NestedTensor variant first, then the generic one)
cand = DispatchKey.CompositeImplicitAutogradNestedTensor
if (
(k != DispatchKey.Undefined and is_included_in_alias(k, cand))
and has_key(op, cand)
and not has_backend_kernel
):
return cand
cand = DispatchKey.CompositeImplicitAutograd
if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
op, cand
):
if (
k == DispatchKey.AutogradOther
and torch._C._dispatch_has_kernel_for_any_dispatch_key(
op.name(), torch._C._dispatch_autogradother_backends
)
):
raise RuntimeError("ambiguous autogradother kernel")
elif not has_backend_kernel:
return cand
# 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
cand = DispatchKey.Autograd
if is_included_in_alias(k, cand) and has_key(op, cand):
return cand
# Backend fallback
if torch._C._dispatch_has_backend_fallback(k):
# The dispatch key itself will implicitly route to backend fallback.
# This is probably not great for the pure Python implementation.
return k
raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
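# Illustrative resolve_key sketch (assumes the usual aten registrations are
# present):
#
#     op = torch.ops.aten.add.Tensor
#     resolve_key(op, DispatchKey.CPU)  # -> DispatchKey.CPU (direct backend kernel)
#
# An overload registered only as CompositeImplicitAutograd would instead fall
# through steps 2.1-2.3 and resolve CPU to DispatchKey.CompositeImplicitAutograd.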
pyop_namespace = {}
class PyOperator(PyOperatorABC):
def __init__(self, name):
self._name = name
self.table = {}
self.python_key_mode_table = {}
self.functorch_table = {}
        # Make _OPNamespace not scream; this whole name-based association needs
        # a good hard look.
self.__name__ = name
pyop_namespace[name] = self
def fallthrough(self, dispatch_key):
self.table[dispatch_key] = self._fallthrough_fn(self, dispatch_key)
def py_impl(self, dispatch_key_or_mode_or_transform):
def inner(fn):
if inspect.isclass(dispatch_key_or_mode_or_transform) and issubclass(
dispatch_key_or_mode_or_transform,
torch.utils._python_dispatch.TorchDispatchMode,
):
mode = dispatch_key_or_mode_or_transform
assert mode not in self.python_key_mode_table
# TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
self.python_key_mode_table[mode] = fn
return fn
if isinstance(
dispatch_key_or_mode_or_transform, torch._C._functorch.TransformType
):
transform = dispatch_key_or_mode_or_transform
self.functorch_table[transform] = fn
return fn
dispatch_key = dispatch_key_or_mode_or_transform
assert (
dispatch_key != torch._C.DispatchKey.Python
), "Please register a mode for the torch._C.DispatchKey.Python key instead."
assert isinstance(dispatch_key, torch._C.DispatchKey)
assert dispatch_key not in self.table
self.table[dispatch_key] = fn
return fn
return inner
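    # py_impl registration sketch (hypothetical operator, illustration only):
    #
    #     my_op = PyOperator("my_op")
    #
    #     @my_op.py_impl(torch._C.DispatchKey.CPU)
    #     def my_op_cpu(x):
    #         return x.sin()
    #
    # In practice every key that key extraction can produce for the inputs
    # (e.g. AutogradCPU) needs its own py_impl or a fallthrough() entry.
    # py_impl also accepts a TorchDispatchMode subclass (stored in
    # python_key_mode_table) or a functorch TransformType (stored in
    # functorch_table).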
def dispatch(self, dispatch_key, *args, **kwargs):
from torch.utils._python_dispatch import _get_current_dispatch_mode
if dispatch_key == torch._C.DispatchKey.FuncTorchDynamicLayerFrontMode:
return dispatch_functorch(self, args, kwargs)
if dispatch_key == torch._C.DispatchKey.Python:
# TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
curr_mode = _get_current_dispatch_mode()
assert (
curr_mode is not None
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
assert (
type(curr_mode) in self.python_key_mode_table
), f"Current active mode {curr_mode} not registered"
# TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)
assert dispatch_key in self.table, dispatch_key
return self.table[dispatch_key](*args, **kwargs)
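    # __call__ first gives __torch_function__ a chance to intercept, then
    # computes the dispatch keyset from the tensor arguments (plus the TLS
    # include/exclude sets) and dispatches on the highest-priority key.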
def __call__(self, *args, **kwargs):
flat_args = _to_flat_tuple(args, kwargs)
if torch.overrides.has_torch_function(flat_args):
return torch.overrides.handle_torch_function(
self, flat_args, *args, **kwargs
)
dispatch_key_set = _compute_keyset(args, kwargs)
return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
    def name(self):
        return self._name
# TODO(voz): Should rewrite fallthrough register as the impl for keys we do not specify
# as opposed to being this sort of explicit thing where ops are a little too key aware...
def _fallthrough_fn(self, operator, dispatch_key):
def inner(*args, **kwargs):
all_keys_after_current = torch._C._dispatch_keyset_full_after(dispatch_key)
all_keys_after_current_masked = all_keys_after_current & _compute_keyset(
args, kwargs
)
return self.dispatch(
all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
)
return inner
def _to_flat_tuple(args, kwargs):
flat_args, _ = torch.utils._pytree.tree_flatten(args)
flat_kwargs, _ = torch.utils._pytree.tree_flatten(kwargs)
flat_all = flat_args + flat_kwargs
return flat_all
def _compute_keyset(args, kwargs):
tensors = _get_tensors(args, kwargs)
return key_extractor(tensors)
def _get_tensors(args, kwargs):
flat_all = _to_flat_tuple(args, kwargs)
tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
return tuple(tensor_args)
# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors):
key_set = torch._C._dispatch_tls_local_include_set()
for tensor in tensors:
key_set = key_set | torch._C._dispatch_keys(tensor)
key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
return key_set
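# Rough sketch of the result for a single plain CPU tensor with no active modes:
# the keyset contains the backend key (CPU) together with functionality keys
# such as AutogradCPU, and highestPriorityTypeId() picks the autograd key first,
# which is why autograd runs before the backend kernel.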
# Each OpOverload object contains a pointer to a specific operator overload and
# a pointer to the parent `OpOverloadPacket` object. You can obtain an
# OpOverload object through an attribute query on an OpOverloadPacket.
class OpOverload(PyOperatorABC):
def __init__(self, overloadpacket, op, op_dk, schema, tags):
self._op = op
self._op_dk = op_dk
self._schema = schema
self._overloadpacket = overloadpacket
self._tags = tags
self._overloadname = (
"default" if schema.overload_name == "" else schema.overload_name
)
self._name = self._schema.name
if schema.overload_name:
self._name += "." + schema.overload_name
self.py_kernels: Dict[torch._C.DispatchKey, Any] = {} # type: ignore[name-defined]
self.__name__ = "{}.{}".format(
self._schema.name.split("::")[1], self._overloadname
)
# TODO(voz): Lots of shared logic around python_key_mode_table, maybe pull into base...
self.python_key_mode_table = {}
self.__module__ = overloadpacket.__module__
op.__module__ = overloadpacket.__module__
self.__qualname__ = self._name
self.__annotations__ = {}
# NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
self._dispatch_cache = {}
# Logic replicated from aten/src/ATen/native/MathBitsFallback.h
is_write = None
for a in self._schema.arguments:
if a.alias_info is None:
continue
if is_write is None:
is_write = a.alias_info.is_write
else:
# We will conservatively call mixed mutable/non-mutable
# aliased inputs as NOT a view
is_write = a.alias_info.is_write or is_write
self.is_view = is_write is not None and not is_write
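        # For example (assuming standard aten schemas): aten::view's output
        # aliases self without writing to it, so is_view is True; aten::add_
        # writes through the alias, so is_view is False; aten::add has no
        # alias_info at all, so is_view is also False.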
    # __deepcopy__ is a no-op since the OpOverload object is immutable and must
    # be unique for a given op overload.
def __deepcopy__(self, memo=None):
return self
def __repr__(self):
return "<OpOverload(op='{}.{}', overload='{}')>".format(
*self._schema.name.split("::"), self._overloadname
)
def __call__(self, *args, **kwargs):
return self._op(*args, **kwargs or {})
def __hash__(self):
return hash(self._op)
# `my_namespace.my_op_name.overload_name`
def __str__(self):
return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
@property
def namespace(self):
return self._schema.name.split("::")[0]
def decompose(self, *args, **kwargs):
dk = torch._C.DispatchKey.CompositeImplicitAutograd
if dk in self.py_kernels:
# NB: This branch is not too necessary anymore, because we can
# apply Python CompositeImplicitAutograd *before* tracing
# using Python dispatcher (also taking advantage of the autograd
# formula). But it's included for completeness
return self.py_kernels[dk](*args, **kwargs)
elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
return self._op_dk(dk, *args, **kwargs)
else:
return NotImplemented
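    # decompose sketch (assumes the op keeps a CompositeImplicitAutograd
    # registration, e.g. aten::linear, whose reference implementation decomposes
    # into matmul/addmm):
    #
    #     out = torch.ops.aten.linear.default.decompose(x, weight, bias)
    #
    # If no CompositeImplicitAutograd kernel exists, NotImplemented is returned.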
def py_impl(self, dispatch_key_or_mode):
def inner(fn):
if inspect.isclass(dispatch_key_or_mode) and issubclass(
dispatch_key_or_mode, torch.utils._python_dispatch.TorchDispatchMode
):
mode = dispatch_key_or_mode
assert mode not in self.python_key_mode_table
# TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
self.python_key_mode_table[mode] = fn
self._dispatch_cache.clear()
return fn
assert isinstance(dispatch_key_or_mode, torch._C.DispatchKey)
assert (
dispatch_key_or_mode != torch._C.DispatchKey.Python
), "Please register a mode for the torch._C.DispatchKey.Python key instead."
if dispatch_key_or_mode in self.py_kernels:
raise RuntimeError(
f"Trying to override a python impl for {dispatch_key_or_mode} on operator {self._name}"
)
self.py_kernels[dispatch_key_or_mode] = fn
self._dispatch_cache.clear()
return fn
return inner
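    # py_impl registration sketch for the Python dispatcher (the kernel below is
    # a hypothetical override, illustration only):
    #
    #     @torch.ops.aten.sub.Tensor.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
    #     def my_sub(x, y, *, alpha=1):
    #         return torch.add(x, torch.neg(y), alpha=alpha)
    #
    # Registrations land in py_kernels and clear _dispatch_cache; they are
    # typically consulted when the Python dispatcher is enabled.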
    # Remove a dispatch key from the dispatch cache. This will force it to get
    # recomputed the next time. Does nothing if the key was never cached.
    # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
    # calling _del_dispatch on that key is NOT sufficient to apply your change,
    # because a single registration may affect MULTIPLE dispatch keys (e.g.,
    # registering Autograd affects AutogradCPU). del_dispatch is to be used
    # only if you are specifically modifying how get_dispatch handles a
    # particular input key.