"""
PythonDispatcher is a thin Python binding to the C++ dispatcher. It is
designed to show how the dispatcher precomputes its dispatch table. In
particular, for an example op `foo`, it shows what the computed dispatch table
looks like after users register their kernels to certain dispatch keys.

The real C++ dispatcher supports many dispatch keys for different
functionalities. For simplicity, PythonDispatcher only supports the dispatch
keys for a single example of each use case. These use cases are listed below:

- CPU/AutogradCPU: represents in-tree backends for which we usually have dedicated
    inference and autograd kernels in the PyTorch core library.
    E.g. CPU, CUDA
- QuantizedCPU/AutogradOther: represents in-tree backends for which we usually have
    backend-specific inference kernels, but which share the same autograd kernel
    specified in AutogradOther.
    E.g. QuantizedCPU, QuantizedCUDA
- XLA/AutogradXLA: represents out-of-tree backends for which we have neither inference
    nor autograd kernels defined in the PyTorch core library. The backend owner is
    responsible for registering both inference and autograd kernels in their extension
    (e.g. torch-xla) for the operators they support.
    E.g. XLA, XPU, MLC
- DefaultBackend: alias key mapped to the inference kernels of all backends, like CPU,
    CUDA, XLA, etc. Kernels registered to this key MUST work for inference on all backends.
- Autograd: alias key mapped to the autograd kernels of all backends, like AutogradCPU,
    AutogradXLA, AutogradOther. Kernels registered to this key MUST work for autograd on
    all backends.
- Math: alias key, Math = DefaultBackend + Autograd.
    Kernels registered to this key MUST work for both inference and autograd on all backends.

Note that we only allow registrations to alias keys inside the PyTorch core library. E.g.
you shouldn't register a Math or DefaultBackend kernel from the torch-xla extension;
instead, upstream the kernel into the pytorch/pytorch repo so that it's available for all
backends and continuously tested even without the extension.

Usage:
  dispatcher = PythonDispatcher()
  dispatcher.register(["CPU", "XLA", "Math"])
  print(dispatcher.dispatchTable())  # Shows exactly which kernel is used for each backend.
  # For more debugging information:
  # print(dispatcher.keys())
  # print(dispatcher.registrations())
  # print(dispatcher.rawRegistrations())
  # print(dispatcher.rawDispatchTable())

PythonDispatcher calls the C++ dispatcher under the hood to precompute the dispatch table.
This file only provides a simplified API for developers; the relevant test code is located
in test/test_dispatch.py.
"""
import re
import torch._C as C


class PythonDispatcher:
    namespace = "__test__"
    name = "foo"
    runtime_keys = [
        "CPU", "AutogradCPU",
        "QuantizedCPU", "AutogradOther",
        "XLA", "AutogradXLA",
    ]
    alias_keys = [
        "DefaultBackend",
        "Autograd",
        "Math",
    ]
    supported_keys = runtime_keys + alias_keys

    def __init__(self):
        C._dispatch_check_invariants(self.name)  # type: ignore[attr-defined]
        self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")  # type: ignore[attr-defined]
        self.ref.def_("foo(Tensor x) -> Tensor")

    """
    Returns a list of dispatch keys supported by PythonDispatcher.
    You can register kernels to these keys.
    """
    def keys(self):
        return self.supported_keys

    """
    Register kernels to the target dispatchKeys.
    dispatchKeys(list[str]): a list of dispatch keys that you want to register
      your own kernel. Note that you don't need to write the kernel yourself in
      this PythonDispatcher.E.g. for CPU key, a kernel(e.g fn_CPU for CPU) is
      automatically generated and registered.
    """
    def register(self, dispatchKeys):
        # Overriden is not supported and triggers a warning in C++ dispatcher.
        if len(set(dispatchKeys)) != len(dispatchKeys):
            raise RuntimeError(f"Overriden is not allowed but found duplicates in {dispatchKeys}.")
        # We currently forbid this in codegen instead of C++ dispatcher.
        if 'Math' in dispatchKeys and 'DefaultBackend' in dispatchKeys:
            raise RuntimeError("Registration to both Math and DefaultBackend is not allowed.")
        for key in dispatchKeys:
            if key not in self.supported_keys:
                raise RuntimeError(f"{key} is not supported, please select a dispatch key in {self.supported_keys}.")
            self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key)

    """
    Helper function to format (key, kernel).
    """
    def _format_line(self, key, kernel):
        return "{:<15} {}\n".format(key, kernel)

    """
    Helper function to print a table header.
    """
    def _format_header(self, header):
        s = f"""
{header}
"""
        s += self._format_line("key", "kernel")
        s += "---------------------------\n"
        return s

    """
    Returns raw output of all registration info for debugging only.
    Use registrations() for a simplified version.
    """
    def rawRegistrations(self):
        return C._dispatch_dump("{}::{}".format(self.namespace, self.name))  # type: ignore[attr-defined]

    """
    Returns raw output of computed dispatch table for debugging only.
    Use dispatchTable() for a simplified version.
    """
    def rawDispatchTable(self):
        return C._dispatch_dump_table("{}::{}".format(self.namespace, self.name))  # type: ignore[attr-defined]

    """
    Returns a table(str) including all the registrations from users.
    Note this includes registrations to both runtime keys and alias keys.
    """
    def registrations(self):
        output = self._format_header("Registered Kernels")
        state = self.rawRegistrations()
        state_entries = state.split('\n')
        for line in state_entries:
            first = line.split(":")[0]
            if any(first.startswith(k) for k in self.supported_keys):
                kernel = line.split("::")[0].split(" ")[1]
                output += self._format_line(first, kernel)
        return output

    """
    Returns the computed dispatch table(str). Note this only include
    runtime keys, registrations to alias keys have been decoded to their
    mapped runtime keys.
    """
    def dispatchTable(self):
        output = self._format_header("Computed Dispatch Table")
        table = self.rawDispatchTable()
        table_entries = table.split('\n')
        regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")
        for line in table_entries:
            k = line.split(":")[0]
            if k in self.runtime_keys:
                entry = regex.sub('[', line)
                output += self._format_line(k, entry.split(": ")[1])
        return output
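

# --- Illustrative sketch (not part of the PythonDispatcher API) ---
# A simplified, pure-Python model of how each runtime key's entry in the
# computed dispatch table could be chosen from the registered keys. This is an
# assumption-laden approximation of the C++ precompute logic for the six
# runtime keys above, not the real implementation: it does NOT model backend
# fallbacks (e.g. "fallthrough [backend fallback]" entries) or the special
# ambiguity error for AutogradOther, and returns None where no registration
# applies. All names here are hypothetical; use dispatchTable() for the real answer.
_SKETCH_RUNTIME_KEYS = [
    "CPU", "AutogradCPU",
    "QuantizedCPU", "AutogradOther",
    "XLA", "AutogradXLA",
]
# Assumed pairing of each autograd key with its backend key(s).
_SKETCH_AUTOGRAD_TO_BACKENDS = {
    "AutogradCPU": ["CPU"],
    "AutogradOther": ["QuantizedCPU"],
    "AutogradXLA": ["XLA"],
}


def _sketch_compute_entry(runtime_key, registered):
    """Hypothetical: return the kernel name chosen for runtime_key, or None."""
    registered = set(registered)
    is_autograd = runtime_key in _SKETCH_AUTOGRAD_TO_BACKENDS
    # 1. A kernel registered directly to the runtime key wins.
    if runtime_key in registered:
        return "fn_" + runtime_key
    # 2. DefaultBackend applies to backend (inference) keys only.
    if not is_autograd and "DefaultBackend" in registered:
        return "fn_DefaultBackend"
    # 3. Math applies to backend keys; for an autograd key it is used only when
    #    neither the matching backend key nor DefaultBackend has a kernel.
    if "Math" in registered:
        if not is_autograd:
            return "fn_Math"
        backends = _SKETCH_AUTOGRAD_TO_BACKENDS[runtime_key]
        if not any(k in registered for k in backends + ["DefaultBackend"]):
            return "fn_Math"
    # 4. The Autograd alias key applies to autograd keys only.
    if is_autograd and "Autograd" in registered:
        return "fn_Autograd"
    # 5. Otherwise the C++ dispatcher would consult backend fallbacks (not modelled).
    return None


def _sketch_compute_table(registered):
    """Hypothetical: map every runtime key to its modelled kernel (or None)."""
    return {k: _sketch_compute_entry(k, registered) for k in _SKETCH_RUNTIME_KEYS}


# Usage: _sketch_compute_table(["CPU", "XLA", "Math"]) picks fn_CPU and fn_XLA
# for their own keys and fn_Math for QuantizedCPU and AutogradOther, while
# AutogradCPU/AutogradXLA get None because CPU/XLA have direct kernels.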