# torch/_ops.py

import contextlib
import ctypes
import inspect
import sys
import types
from abc import ABC
from typing import Any, Dict

import torch._C

from torch import _utils_internal
from torch._functorch.pyfunctorch import dispatch_functorch

# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")


@contextlib.contextmanager
def dl_open_guard():
    """
    Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
    shared library to load custom operators.
    """
    if not _SET_GLOBAL_FLAGS:
        yield
        return
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
    try:
        yield
    finally:
        sys.setdlopenflags(old_flags)
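

# Hedged usage sketch (not part of the original module; the library path is
# purely illustrative): dl_open_guard is meant to wrap the dlopen of a shared
# library containing custom operators so that its symbols are resolved with
# RTLD_GLOBAL against the already-loaded torch libraries.
#
#     with dl_open_guard():
#         ctypes.CDLL("/path/to/libcustom_ops.so")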


def has_key(op, k):
    return (
        torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), k)
        or k in op.py_kernels
    )


# TODO(voz) We are missing an entire axis of registration - Modes for the python key
class PyOperatorABC(ABC):
    def __call__(self, *args, **kwargs):
        pass

    def py_impl(self, dispatch_key, fn):
        pass

    def name(self):
        pass


is_included_in_alias = torch._C._dispatch_is_included_in_alias

DispatchKey = torch._C.DispatchKey

# Equivalent to computeDispatchTableEntryWithDebug
def resolve_key(op: PyOperatorABC, k: DispatchKey):  # type: ignore[valid-type]
    # 1. (Direct) operator registration
    if has_key(op, k):
        return k
    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
        op, cand
    ):
        return cand
    # 2.2 Use CompositeExplicitAutograd kernel if available
    cand = DispatchKey.CompositeExplicitAutograd
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
        op, cand
    ):
        return cand
    has_backend_kernel = torch._C._dispatch_has_kernel_for_any_dispatch_key(
        op.name(), torch._C._dispatch_get_backend_keyset_from_autograd(k)
    ) or has_key(op, DispatchKey.CompositeExplicitAutograd)
    # 2.3. Use CompositeImplicitAutogradNestedTensor kernel, then CompositeImplicitAutograd kernel, if available
    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
    if (
        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))
        and has_key(op, cand)
        and not has_backend_kernel
    ):
        return cand
    cand = DispatchKey.CompositeImplicitAutograd
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
        op, cand
    ):
        if (
            k == DispatchKey.AutogradOther
            and torch._C._dispatch_has_kernel_for_any_dispatch_key(
                op.name(), torch._C._dispatch_autogradother_backends
            )
        ):
            raise RuntimeError("ambiguous autogradother kernel")
        elif not has_backend_kernel:
            return cand
    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
    cand = DispatchKey.Autograd
    if is_included_in_alias(k, cand) and has_key(op, cand):
        return cand
    # Backend fallback
    if torch._C._dispatch_has_backend_fallback(k):
        # The dispatch key itself will implicitly route to backend fallback.
        # This is probably not great for the pure Python implementation.
        return k
    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
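

# Illustrative walk-through of resolve_key (an exposition-only assumption about
# a hypothetical operator, not a claim about any specific op): if an operator
# registers only a CompositeExplicitAutograd kernel, then for a concrete
# backend key such as DispatchKey.CPU step 1 finds no direct kernel, step 2.1
# finds no CompositeExplicitAutogradNonFunctional kernel, and step 2.2 returns
# DispatchKey.CompositeExplicitAutograd, which is the key the caller then looks
# up in the operator's kernel table.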


pyop_namespace = {}


class PyOperator(PyOperatorABC):
    def __init__(self, name):
        self._name = name
        self.table = {}
        self.python_key_mode_table = {}
        self.functorch_table = {}

        # Make _OPNamespace not scream; this whole name-based association needs a good hard look
        self.__name__ = name
        pyop_namespace[name] = self

    def fallthrough(self, dispatch_key):
        self.table[dispatch_key] = self._fallthrough_fn(self, dispatch_key)

    def py_impl(self, dispatch_key_or_mode_or_transform):
        def inner(fn):
            if inspect.isclass(dispatch_key_or_mode_or_transform) and issubclass(
                dispatch_key_or_mode_or_transform,
                torch.utils._python_dispatch.TorchDispatchMode,
            ):
                mode = dispatch_key_or_mode_or_transform
                assert mode not in self.python_key_mode_table
                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
                self.python_key_mode_table[mode] = fn
                return fn

            if isinstance(
                dispatch_key_or_mode_or_transform, torch._C._functorch.TransformType
            ):
                transform = dispatch_key_or_mode_or_transform
                self.functorch_table[transform] = fn
                return fn

            dispatch_key = dispatch_key_or_mode_or_transform
            assert (
                dispatch_key != torch._C.DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
            assert isinstance(dispatch_key, torch._C.DispatchKey)
            assert dispatch_key not in self.table
            self.table[dispatch_key] = fn
            return fn

        return inner

    def dispatch(self, dispatch_key, *args, **kwargs):
        from torch.utils._python_dispatch import _get_current_dispatch_mode

        if dispatch_key == torch._C.DispatchKey.FuncTorchDynamicLayerFrontMode:
            return dispatch_functorch(self, args, kwargs)

        if dispatch_key == torch._C.DispatchKey.Python:
            # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
            curr_mode = _get_current_dispatch_mode()
            assert (
                curr_mode is not None
            ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
            assert (
                type(curr_mode) in self.python_key_mode_table
            ), f"Current active mode {curr_mode} not registered"
            # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
            return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)

        assert dispatch_key in self.table, dispatch_key
        return self.table[dispatch_key](*args, **kwargs)

    def __call__(self, *args, **kwargs):
        flat_args = _to_flat_tuple(args, kwargs)
        if torch.overrides.has_torch_function(flat_args):
            return torch.overrides.handle_torch_function(
                self, flat_args, *args, **kwargs
            )

        dispatch_key_set = _compute_keyset(args, kwargs)
        return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)

    def name(self):
        return self._name

    # TODO(voz): Should rewrite fallthrough register as the impl for keys we do not specify
    # as opposed to being this sort of explicit thing where ops are a little too key aware...
    def _fallthrough_fn(self, operator, dispatch_key):
        def inner(*args, **kwargs):
            all_keys_after_current = torch._C._dispatch_keyset_full_after(dispatch_key)
            all_keys_after_current_masked = all_keys_after_current & _compute_keyset(
                args, kwargs
            )
            return self.dispatch(
                all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
            )

        return inner
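

# Hedged usage sketch for PyOperator (the operator and kernels below are
# hypothetical, not defined in this file): py_impl registers a Python kernel
# for a DispatchKey, a TorchDispatchMode subclass, or a functorch TransformType,
# and fallthrough marks a key as "skip me and keep dispatching". Which kernel
# actually runs for a call depends on the key set extracted from the tensor
# arguments (see key_extractor below).
#
#     my_op = PyOperator("my_op")
#
#     @my_op.py_impl(torch._C.DispatchKey.CPU)
#     def my_op_cpu(x):
#         return x * 2
#
#     my_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
#     my_op.fallthrough(torch._C.DispatchKey.ADInplaceOrView)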


def _to_flat_tuple(args, kwargs):
    flat_args, _ = torch.utils._pytree.tree_flatten(args)
    flat_kwargs, _ = torch.utils._pytree.tree_flatten(kwargs)
    flat_all = flat_args + flat_kwargs
    return flat_all


def _compute_keyset(args, kwargs):
    tensors = _get_tensors(args, kwargs)
    return key_extractor(tensors)


def _get_tensors(args, kwargs):
    flat_all = _to_flat_tuple(args, kwargs)
    tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
    return tuple(tensor_args)


# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors):
    key_set = torch._C._dispatch_tls_local_include_set()
    for tensor in tensors:
        key_set = key_set | torch._C._dispatch_keys(tensor)
    key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
    return key_set
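

# Worked example for key_extractor (values are illustrative and assume a plain
# dense CPU tensor with default thread-local dispatch state): the tensor
# contributes roughly {CPU, ADInplaceOrView, AutogradCPU}, the TLS include set
# is empty, and the TLS exclude set removes nothing, so the resulting key set's
# highestPriorityTypeId() is AutogradCPU.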


# Each OpOverload object contains a pointer to a specific operator overload and a pointer to its parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on an OpOverloadPacket.
class OpOverload(PyOperatorABC):
    def __init__(self, overloadpacket, op, op_dk, schema, tags):
        self._op = op
        self._op_dk = op_dk
        self._schema = schema
        self._overloadpacket = overloadpacket
        self._tags = tags
        self._overloadname = (
            "default" if schema.overload_name == "" else schema.overload_name
        )
        self._name = self._schema.name
        if schema.overload_name:
            self._name += "." + schema.overload_name
        self.py_kernels: Dict[torch._C.DispatchKey, Any] = {}  # type: ignore[name-defined]
        self.__name__ = "{}.{}".format(
            self._schema.name.split("::")[1], self._overloadname
        )
        # TODO(voz): Lots of shared logic around python_key_mode_table, maybe pull into base...
        self.python_key_mode_table = {}
        self.__module__ = overloadpacket.__module__
        op.__module__ = overloadpacket.__module__
        self.__qualname__ = self._name
        self.__annotations__ = {}
        # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
        self._dispatch_cache = {}

        # Logic replicated from aten/src/ATen/native/MathBitsFallback.h
        is_write = None
        for a in self._schema.arguments:
            if a.alias_info is None:
                continue
            if is_write is None:
                is_write = a.alias_info.is_write
            else:
                # Conservatively treat mixed mutable/non-mutable aliased
                # inputs as NOT a view
                is_write = a.alias_info.is_write or is_write
        self.is_view = is_write is not None and not is_write
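
        # Illustrative note (exposition only, not from the original file): for
        # an overload whose schema aliases an argument without writing to it,
        # e.g. view(Tensor(a) self, ...) -> Tensor(a), is_write stays False and
        # is_view is True; for an in-place overload such as add_.Tensor, the
        # aliased self argument is written, so is_view is False.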

    # `__deepcopy__` is a no-op since the OpOverload object is immutable and must be unique for a given op overload.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        return "<OpOverload(op='{}.{}', overload='{}')>".format(
            *self._schema.name.split("::"), self._overloadname
        )

    def __call__(self, *args, **kwargs):
        return self._op(*args, **kwargs or {})

    def __hash__(self):
        return hash(self._op)

    # `my_namespace.my_op_name.overload_name`
    def __str__(self):
        return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)

    @property
    def namespace(self):
        return self._schema.name.split("::")[0]

    def decompose(self, *args, **kwargs):
        dk = torch._C.DispatchKey.CompositeImplicitAutograd
        if dk in self.py_kernels:
            # NB: This branch is not too necessary anymore, because we can
            # apply Python CompositeImplicitAutograd *before* tracing
            # using Python dispatcher (also taking advantage of the autograd
            # formula).  But it's included for completeness
            return self.py_kernels[dk](*args, **kwargs)
        elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
            return self._op_dk(dk, *args, **kwargs)
        else:
            return NotImplemented
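
    # Hedged usage note (an assumption about typical use, not stated in this
    # file): callers such as tracing or functionalization layers can try the
    # CompositeImplicitAutograd decomposition first and fall back to the normal
    # overload call when none is registered (op_overload is hypothetical):
    #
    #     out = op_overload.decompose(*args, **kwargs)
    #     if out is NotImplemented:
    #         out = op_overload(*args, **kwargs)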

    def py_impl(self, dispatch_key_or_mode):
        def inner(fn):
            if inspect.isclass(dispatch_key_or_mode) and issubclass(
                dispatch_key_or_mode, torch.utils._python_dispatch.TorchDispatchMode
            ):
                mode = dispatch_key_or_mode
                assert mode not in self.python_key_mode_table
                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
                self.python_key_mode_table[mode] = fn
                self._dispatch_cache.clear()
                return fn

            assert isinstance(dispatch_key_or_mode, torch._C.DispatchKey)
            assert (
                dispatch_key_or_mode != torch._C.DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."

            if dispatch_key_or_mode in self.py_kernels:
                raise RuntimeError(
                    f"Trying to override a python impl for {dispatch_key_or_mode} on operator {self._name}"
                )
            self.py_kernels[dispatch_key_or_mode] = fn
            self._dispatch_cache.clear()
            return fn

        return inner
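
    # Hedged usage sketch (hypothetical registration, not from this file):
    # py_impl on an OpOverload installs a Python kernel for a single dispatch
    # key or for a TorchDispatchMode subclass, clearing the dispatch cache so
    # the next call re-resolves its key, e.g.
    #
    #     @torch.ops.aten.add.Tensor.py_impl(torch._C.DispatchKey.CPU)
    #     def my_add_cpu(x, y, *, alpha=1):
    #         ...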

    # Remove a dispatch key from the dispatch cache.  This will force it to get
    # recomputed the next time.  Does nothing if the key is not in the cache.
    # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
    # calling _del_dispatch on that key is NOT sufficient to apply your change,
    # because a single registration may affect MULTIPLE dispatch keys (e.g.,
    # registering Autograd affects AutogradCPU).  del_dispatch is to be used