Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / library.h

#pragma once

/// \file
///
/// This header provides an API for extending PyTorch's core library
/// of operators with user defined operators and data types.  This
/// API can be used in a few ways:
///
/// * You can define new custom operators and classes with TORCH_LIBRARY(),
///   making them available for use in both eager Python as well as in
///   TorchScript. This API is modeled off of pybind11's `PYBIND11_MODULE`
///   macro, as the provided functionality is similar (pybind11 lets you bind
///   C++ to Python only; `torch/library.h` lets you bind C++ simultaneously to
///   Python and TorchScript).
///
/// * You can override existing operators with TORCH_LIBRARY_IMPL(),
///   providing a new implementation for these operators for a custom
///   backend (e.g., XLA).  When you pass operators with tensors of your custom
///   backend, your overridden implementations will be called instead
///   of the standard implementations.
///
/// * You can use both capabilities at the same time, allowing you
///   to write custom operators that register CPU/CUDA/Autograd
///   implementations without having to write the boilerplate
///   conditionals yourself.
///
/// For a tutorial style introduction to the library API, check
/// out the [Extending TorchScript with Custom C++
/// Operators](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html)
/// tutorial.
///
/// ```
/// // Define a library whose operators live in the namespace 'myops'.
/// // You must define all of the operators for this library in
/// // this namespace.
/// TORCH_LIBRARY(myops, m) {
///   // Define an operator with exactly one implementation for all backends.
///   m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);
///
///   // Define a schema for an operator, but provide no implementation
///   // (use this syntax if you want to use the dispatcher)
///   m.def("mul(Tensor self, Tensor other) -> Tensor");
///
///   // Provide an implementation for a defined operator (you can
///   // provide multiple; one per backend).  The dispatcher takes care of
///   // calling the correct implementation depending on if we get a CPU
///   // tensor or a CUDA tensor
///   m.impl("mul", torch::kCPU, &mul_cpu_impl);
///   m.impl("mul", torch::kCUDA, &mul_cuda_impl);
/// }
///
/// // Define implementations for operators for a non-standard backend,
/// // e.g., XLA (valid values are entries of DispatchKey).  This can
/// // be used to define operators in a different file than the initial
/// // TORCH_LIBRARY definition (e.g., if it is in an external library)
/// TORCH_LIBRARY_IMPL(myops, XLA, m) {
///   m.impl("mul", &mul_xla_impl);
/// }
/// ```

#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/op_registration/op_allowlist.h>
#include <c10/core/DispatchKey.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>

// Just for inferFunctionSchemaFromFunctor
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/enum_tag.h>

namespace torch {

#if defined C10_MOBILE
/**
 * NoInferSchemaTag is an empty tag type. Passing an instance of it to a
 * CppFunction constructor indicates that the call should not trigger schema
 * inference from the functor, e.g. `CppFunction(&add_impl, NoInferSchemaTag())`.
 *
 * Schema inference from a functor utilizes template meta-programming, and is
 * costly from a binary-size perspective. Ideally, one would expect that the
 * schema inference would require very little binary size since most of the
 * computation can be done by the compiler at build time, but that isn't
 * necessarily the case.
 *
 * Schema inference is elided only for mobile use-cases, where we don't need
 * the additional runtime cost or size overhead on client devices.
 */
struct NoInferSchemaTag {};
#endif

// Two-state mode selector for the multipy/torchdeploy use case.
// NOTE(review): presumably picks between registering operators and merely
// verifying an existing registration -- confirm against the code that
// consumes this enum.
// NOTE(review): the name starts with an underscore followed by an uppercase
// letter, which is a reserved identifier in C++; kept as-is because callers
// depend on it.
enum class _RegisterOrVerify { REGISTER = 0, VERIFY = 1 };

// Forward declaration only; the definition of class_ is not part of this
// section of the header.
template <typename CurClass>
class class_;

/// Represents a C++ function that implements an operator.  Most users won't
/// interact directly with this class, except via error messages: the
/// constructors of this class define the set of permissible "function"-like
/// things you can bind via the interface.
///
/// This class erases the type of the passed in function, but durably records
/// the type via an inferred schema for the function.
class TORCH_API CppFunction final {
  // TODO: This is morally the same thing as KernelRegistrationConfig, but it's
  // opaque to the user.

 public:
  /// This overload accepts function pointers, e.g., `CppFunction(&add_impl)`
  ///
  /// The kernel is built with `makeFromUnboxedRuntimeFunction`, the C++
  /// signature is recorded from `Func`, and a schema is inferred from the
  /// decayed function type.  The trailing defaulted `std::nullptr_t`
  /// parameter is never passed by callers; it exists only so that
  /// `std::enable_if_t` removes this overload when `Func` is not a function
  /// type.
  template <typename Func>
  explicit CppFunction(
      Func* f,
      std::enable_if_t<
          c10::guts::is_function_type<Func>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        cpp_signature_(c10::impl::CppSignature::make<Func>()),
        schema_(
            c10::detail::inferFunctionSchemaFromFunctor<std::decay_t<Func>>()),
        debug_() {}

  /// This overload accepts compile time function pointers, e.g.,
  /// `CppFunction(TORCH_FN(add_impl))`
  ///
  /// The signature and the inferred schema are both derived from
  /// `FuncPtr::FuncType`.  The overload participates only when
  /// `c10::is_compile_time_function_pointer<FuncPtr>` is true.
  template <typename FuncPtr>
  explicit CppFunction(
      FuncPtr f,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
        cpp_signature_(
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>()),
        schema_(c10::detail::inferFunctionSchemaFromFunctor<
                typename FuncPtr::FuncType>()),
        debug_() {}

  /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
  /// ... })`
  ///
  /// The lambda is perfectly forwarded into the kernel; the schema is
  /// inferred from the decayed lambda type.  The overload participates only
  /// when `c10::guts::is_functor` holds for the decayed type.
  template <typename Lambda>
  explicit CppFunction(
      Lambda&& f,
      std::enable_if_t<
          c10::guts::is_functor<std::decay_t<Lambda>>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedLambda(
            std::forward<Lambda>(f))),
        cpp_signature_(c10::impl::CppSignature::make<Lambda>()),
        schema_(c10::detail::inferFunctionSchemaFromFunctor<
                std::decay_t<Lambda>>()),
        debug_() {}

#if defined C10_MOBILE
  // The three overloads below mirror the three constructors above but take an
  // extra NoInferSchemaTag argument and leave schema_ null instead of
  // inferring a schema from the functor.

  /// This overload accepts function pointers, e.g., `CppFunction(&add_impl,
  /// NoInferSchemaTag())`
  template <typename Func>
  explicit CppFunction(
      Func* f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::guts::is_function_type<Func>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        cpp_signature_(c10::impl::CppSignature::make<Func>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}

  /// This overload accepts compile time function pointers, e.g.,
  /// `CppFunction(TORCH_FN(add_impl), NoInferSchemaTag())`
  template <typename FuncPtr>
  explicit CppFunction(
      FuncPtr f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
        cpp_signature_(
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}

  /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
  /// ... }, NoInferSchemaTag())`
  template <typename Lambda>
  explicit CppFunction(
      Lambda&& f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::guts::is_functor<std::decay_t<Lambda>>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedLambda(
            std::forward<Lambda>(f))),
        cpp_signature_(c10::impl::CppSignature::make<Lambda>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}
#endif

  // Declared here only; the definition is not part of this header section.
  ~CppFunction();

  // CppFunction is move-only: declaring the move operations suppresses the
  // implicit copy constructor and copy assignment.
  CppFunction(CppFunction&&) noexcept = default;

  CppFunction& operator=(CppFunction&&) = default;

  /// \private
  /// Creates a function from a type-erased boxed kernel.
  /// No C++ signature and no schema are recorded for the result.
  static CppFunction makeFromBoxedKernel(c10::BoxedKernel kernel) {
    return CppFunction(
        c10::KernelFunction::makeFromBoxedKernel(std::move(kernel)),
        /* cpp_signature */ c10::nullopt, // not known for boxed functions
        /* schema */ nullptr);
  }

  /// This creates a fallthrough function.  Fallthrough functions
  /// immediately redispatch to the next available dispatch key,
  /// but are implemented more efficiently than a hand written
  /// function done in the same way.
  static CppFunction makeFallthrough() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeFallthrough());
  }

  /// \private
  ///
  /// Creates a function that raises an error saying that named tensors
  /// are not supported when called.
  static CppFunction makeNamedNotSupported() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeNamedNotSupported());
  }

  /// Create a function from a boxed kernel function with signature
  /// `void(const OperatorHandle&, Stack*)`; i.e., they receive a
  /// stack of arguments in a boxed calling convention, rather than
  /// in the native C++ calling convention.  Boxed functions are
  /// typically only used to register backend fallbacks via
  /// torch::Library::fallback().
  template <c10::BoxedKernel::BoxedKernelFunction* func>
  static CppFunction makeFromBoxedFunction() {
    return makeFromBoxedKernel(
        c10::BoxedKernel::makeFromFunction<func>());
  }

  // Variant that takes in a boxed kernel function with a plumbed
  // DispatchKeySet. See Note [Plumbing Keys Through The Dispatcher] for
  // details.
  template <c10::BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
  static CppFunction makeFromBoxedFunction() {
    return makeFromBoxedKernel(
        c10::BoxedKernel::makeFromFunction<func>());
  }

  /// Create a function from a boxed kernel functor which defines
  /// `operator()(const OperatorHandle&, DispatchKeySet, Stack*)`
  /// (receiving arguments from boxed calling convention) and inherits
  /// from `c10::OperatorKernel`.  Unlike makeFromBoxedFunction, functions
  /// registered in this way can also carry additional state which
  /// is managed by the functor; this is useful if you're writing an
  /// adapter to some other implementation, e.g., a Python callable, which
  /// is dynamically associated with the registered kernel.
  ///
  /// Ownership of `kernelFunctor` is transferred into the created kernel.
  template <class KernelFunctor>
  static CppFunction makeFromBoxedFunctor(
      std::unique_ptr<KernelFunctor> kernelFunctor) {
    return makeFromBoxedKernel(
        c10::BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
  }

  /// Create a function from an unboxed kernel function.
  /// This is typically used to register common operators.
  /// Forwards to the function-pointer constructor of CppFunction.
  template <
      typename FuncPtr,
      std::enable_if_t<
          c10::guts::is_function_type<FuncPtr>::value,
          std::nullptr_t> = nullptr>
  static CppFunction makeFromUnboxedFunction(FuncPtr* f) {
    return CppFunction(f);
  }

  /// Create a function from a compile time unboxed kernel function pointer.
  /// This is typically used to register common operators.
  /// Compile time function pointers can be used to allow the compiler
  /// to optimize (e.g. inline) calls to it.
  /// Forwards to the compile-time-function-pointer constructor of CppFunction.
  template <
      typename FuncPtr,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr>
  static CppFunction makeFromUnboxedFunction(FuncPtr f) {
    return CppFunction(f);
  }

  /// Attach a debug string to this function and return the function again as
  /// an rvalue so the call can be chained.  Only callable on rvalues (note
  /// the `&&` qualifier).
  /// NOTE(review): how the debug string is consumed is not visible here --
  /// presumably it is read by Library during registration; confirm.
  CppFunction&& debug(std::string d) && {
    debug_ = std::move(d);
    return std::move(*this);
  }

 private:
  // Dispatch key this function is tagged with; empty unless set by
  // torch::dispatch().
  c10::optional<c10::DispatchKey> dispatch_key_;
  // The type-erased kernel wrapping the user's callable.
  c10::KernelFunction func_;
  // C++ signature of the callable; empty for boxed kernels.
  c10::optional<c10::impl::CppSignature> cpp_signature_;
  // Schema inferred from the callable; null for boxed kernels and for the
  // NoInferSchemaTag constructors.
  std::unique_ptr<c10::FunctionSchema> schema_;
  // Free-form debug string, set via debug().
  std::string debug_;

  // The "setter" for dispatch_key_
  template <typename Func>
  friend CppFunction dispatch(c10::DispatchKey, Func&&);

  // The only class which actually pulls out values from CppFunction (does so
  // destructively, felt too lazy to write accessors that I don't even
  // want users to use)
  friend class Library;

  // Private constructor used by makeFromBoxedKernel(); declared here only,
  // the definition is not part of this header section.
  CppFunction(
      c10::KernelFunction func,
      c10::optional<c10::impl::CppSignature> cpp_signature,
      std::unique_ptr<c10::FunctionSchema> schema);
};

/// \defgroup torch-dispatch-overloads torch::dispatch overloads

/// Wrap a callable in a torch::CppFunction and tag it with a dispatch key.
/// A torch::CppFunction that carries a c10::DispatchKey is only invoked
/// when the dispatcher decides that this particular c10::DispatchKey is
/// the one to dispatch to.
///
/// Passing c10::DispatchKey::CatchAll leaves the function untagged (the
/// stored key is an empty optional); any other key is stored as given.
///
/// You rarely need to call this directly: prefer TORCH_LIBRARY_IMPL(),
/// which implicitly sets the c10::DispatchKey for every registration
/// call inside of its body.
///
/// \ingroup torch-dispatch-overloads
template <typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
  CppFunction wrapped(std::forward<Func>(raw_f));
  const bool is_catch_all = (k == c10::DispatchKey::CatchAll);
  if (!is_catch_all) {
    // A concrete key was requested: record it on the function.
    wrapped.dispatch_key_ = k;
  } else {
    // CatchAll is represented by the absence of a key.
    wrapped.dispatch_key_ = c10::nullopt;
  }
  return wrapped;
}
Loading ...