#pragma once
/// \file
///
/// This header provides an API for extending PyTorch's core library
/// of operators with user-defined operators and data types. This
/// API can be used in a few ways:
///
/// * You can define new custom operators and classes with TORCH_LIBRARY(),
/// making them available for use in both eager Python as well as in
/// TorchScript. This API is modeled off of pybind11's `PYBIND11_MODULE`
/// macro, as the provided functionality is similar (pybind11 lets you bind
/// C++ to Python only; `torch/library.h` lets you bind C++ simultaneously to
/// Python and TorchScript).
///
/// * You can override existing operators with TORCH_LIBRARY_IMPL(),
/// providing a new implementation for these operators for a custom
/// backend (e.g., XLA). When you pass operators with tensors of your custom
/// backend, your overridden implementations will be called instead
/// of the standard implementations.
///
/// * You can use both capabilities at the same time, allowing you
/// to write custom operators that register CPU/CUDA/Autograd
/// implementations without having to write the boilerplate
/// conditionals yourself.
///
/// For a tutorial style introduction to the library API, check
/// out the [Extending TorchScript with Custom C++
/// Operators](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html)
/// tutorial.
///
/// ```
/// // Define a library whose operators live in the namespace 'myops'.
/// // You must define all of the operators for this library in
/// // this namespace.
/// TORCH_LIBRARY(myops, m) {
/// // Define an operator with exactly one implementation for all backends.
/// m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);
///
/// // Define a schema for an operator, but provide no implementation
/// // (use this syntax if you want to use the dispatcher)
/// m.def("mul(Tensor self, Tensor other) -> Tensor");
///
/// // Provide an implementation for a defined operator (you can
/// // provide multiple; one per backend). The dispatcher takes care of
/// // calling the correct implementation depending on whether we get a CPU
/// // tensor or a CUDA tensor.
/// m.impl("mul", torch::kCPU, &mul_cpu_impl);
/// m.impl("mul", torch::kCUDA, &mul_cuda_impl);
/// }
///
/// // Define implementations for operators for a non-standard backend,
/// // e.g., XLA (valid values are entries of DispatchKey). This can
/// // be used to define operators in a different file than the initial
/// // TORCH_LIBRARY definition (e.g., if it is in an external library)
/// TORCH_LIBRARY_IMPL(myops, XLA, m) {
/// m.impl("mul", &mul_xla_impl);
/// }
/// ```
#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/op_registration/op_allowlist.h>
#include <c10/core/DispatchKey.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>
// Just for inferFunctionSchemaFromFunctor
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/enum_tag.h>
namespace torch {
#if defined C10_MOBILE
/**
 * NoInferSchemaTag is a tag type used to indicate that this call to the
 * CppFunction constructor should not trigger schema inference from the
 * functor. Schema inference from a functor relies on template
 * meta-programming and is costly from a binary size perspective. Ideally,
 * schema inference would require very little binary size, since most of
 * the computation can be done by the compiler at build time, but that
 * isn't necessarily the case.
 *
 * Schema inference is elided only for mobile use-cases, where we don't
 * need the additional runtime cost or size overhead on client devices.
 */
struct NoInferSchemaTag {};
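// A minimal usage sketch (mobile builds only; `add_impl` stands in for any
// kernel function):
//
//   torch::CppFunction f(&add_impl, torch::NoInferSchemaTag());
//
// The resulting CppFunction carries no inferred schema, so the operator's
// schema must be supplied elsewhere (e.g., as a string passed to m.def()).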
#endif
// For the multipy/torchdeploy use case
enum class _RegisterOrVerify {
REGISTER,
VERIFY
};
template <class CurClass>
class class_;
/// Represents a C++ function that implements an operator. Most users won't
/// interact directly with this class, except via error messages: the
/// constructors of this class define the set of permissible "function"-like
/// things you can bind via the interface.
///
/// This class erases the type of the passed-in function, but durably records
/// the type via an inferred schema for the function.
class TORCH_API CppFunction final {
// TODO: This is morally the same thing as KernelRegistrationConfig, but it's
// opaque to the user.
public:
/// This overload accepts function pointers, e.g., `CppFunction(&add_impl)`
template <typename Func>
explicit CppFunction(
Func* f,
std::enable_if_t<
c10::guts::is_function_type<Func>::value,
std::nullptr_t> = nullptr)
: func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
cpp_signature_(c10::impl::CppSignature::make<Func>()),
schema_(
c10::detail::inferFunctionSchemaFromFunctor<std::decay_t<Func>>()),
debug_() {}
/// This overload accepts compile time function pointers, e.g.,
/// `CppFunction(TORCH_FN(add_impl))`
template <typename FuncPtr>
explicit CppFunction(
FuncPtr f,
std::enable_if_t<
c10::is_compile_time_function_pointer<FuncPtr>::value,
std::nullptr_t> = nullptr)
: func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
cpp_signature_(
c10::impl::CppSignature::make<typename FuncPtr::FuncType>()),
schema_(c10::detail::inferFunctionSchemaFromFunctor<
typename FuncPtr::FuncType>()),
debug_() {}
/// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
/// ... })`
template <typename Lambda>
explicit CppFunction(
Lambda&& f,
std::enable_if_t<
c10::guts::is_functor<std::decay_t<Lambda>>::value,
std::nullptr_t> = nullptr)
: func_(c10::KernelFunction::makeFromUnboxedLambda(
std::forward<Lambda>(f))),
cpp_signature_(c10::impl::CppSignature::make<Lambda>()),
schema_(c10::detail::inferFunctionSchemaFromFunctor<
std::decay_t<Lambda>>()),
debug_() {}
#if defined C10_MOBILE
/// This overload accepts function pointers, e.g., `CppFunction(&add_impl,
/// NoInferSchemaTag())`
template <typename Func>
explicit CppFunction(
Func* f,
NoInferSchemaTag,
std::enable_if_t<
c10::guts::is_function_type<Func>::value,
std::nullptr_t> = nullptr)
: func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        cpp_signature_(c10::impl::CppSignature::make<Func>()),
        // TODO: Don't go through WrapRuntimeKernelFunctor
        schema_(nullptr),
        debug_() {}
/// This overload accepts compile time function pointers, e.g.,
/// `CppFunction(TORCH_FN(add_impl), NoInferSchemaTag())`
template <typename FuncPtr>
explicit CppFunction(
FuncPtr f,
NoInferSchemaTag,
std::enable_if_t<
c10::is_compile_time_function_pointer<FuncPtr>::value,
std::nullptr_t> = nullptr)
: func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
        cpp_signature_(
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>()),
        // TODO: Don't go through WrapRuntimeKernelFunctor
        schema_(nullptr),
        debug_() {}
/// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
/// ... }, NoInferSchemaTag())`
template <typename Lambda>
explicit CppFunction(
Lambda&& f,
NoInferSchemaTag,
std::enable_if_t<
c10::guts::is_functor<std::decay_t<Lambda>>::value,
std::nullptr_t> = nullptr)
: func_(c10::KernelFunction::makeFromUnboxedLambda(
std::forward<Lambda>(f))),
        cpp_signature_(c10::impl::CppSignature::make<Lambda>()),
        // TODO: Don't go through WrapRuntimeKernelFunctor
        schema_(nullptr),
        debug_() {}
#endif
~CppFunction();
CppFunction(CppFunction&&) noexcept = default;
CppFunction& operator=(CppFunction&&) = default;
/// \private
/// Creates a function from a type-erased boxed kernel.
static CppFunction makeFromBoxedKernel(c10::BoxedKernel kernel) {
return CppFunction(
c10::KernelFunction::makeFromBoxedKernel(std::move(kernel)),
/* cpp_signature */ c10::nullopt, // not known for boxed functions
/* schema */ nullptr);
}
/// This creates a fallthrough function. Fallthrough functions
/// immediately redispatch to the next available dispatch key,
/// but are implemented more efficiently than a hand-written
/// function that does the same thing.
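///
/// A minimal sketch (the library namespace and dispatch key below are
/// illustrative, reusing the header example above):
///
/// ```
/// TORCH_LIBRARY_IMPL(myops, Autograd, m) {
///   // Skip autograd handling for this op and redispatch to the next
///   // available kernel.
///   m.impl("mul", torch::CppFunction::makeFallthrough());
/// }
/// ```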
static CppFunction makeFallthrough() {
return makeFromBoxedKernel(c10::BoxedKernel::makeFallthrough());
}
/// \private
///
/// Creates a function that raises an error saying that named tensors
/// are not supported when called.
static CppFunction makeNamedNotSupported() {
return makeFromBoxedKernel(c10::BoxedKernel::makeNamedNotSupported());
}
/// Create a function from a boxed kernel function with signature
/// `void(const OperatorHandle&, Stack*)`; i.e., it receives a
/// stack of arguments in a boxed calling convention, rather than
/// in the native C++ calling convention. Boxed functions are
/// typically only used to register backend fallbacks via
/// torch::Library::fallback().
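///
/// A hedged sketch (`my_boxed_fallback` and the `MyBackend` key are
/// illustrative; substitute a real c10::DispatchKey such as XLA):
///
/// ```
/// void my_boxed_fallback(
///     const c10::OperatorHandle& op,
///     torch::jit::Stack* stack) {
///   // Inspect or rewrite the boxed arguments on `stack`, then
///   // redispatch (details elided).
/// }
///
/// TORCH_LIBRARY_IMPL(_, MyBackend, m) {
///   m.fallback(
///       torch::CppFunction::makeFromBoxedFunction<&my_boxed_fallback>());
/// }
/// ```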
template <c10::BoxedKernel::BoxedKernelFunction* func>
static CppFunction makeFromBoxedFunction() {
return makeFromBoxedKernel(
c10::BoxedKernel::makeFromFunction<func>());
}
// Variant that takes in a boxed kernel function with a plumbed
// DispatchKeySet. See Note [Plumbing Keys Through The Dispatcher] for
// details.
template <c10::BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
static CppFunction makeFromBoxedFunction() {
return makeFromBoxedKernel(
c10::BoxedKernel::makeFromFunction<func>());
}
/// Create a function from a boxed kernel functor which defines
/// `operator()(const OperatorHandle&, DispatchKeySet, Stack*)`
/// (receiving arguments from boxed calling convention) and inherits
/// from `c10::OperatorKernel`. Unlike makeFromBoxedFunction, functions
/// registered in this way can also carry additional state which
/// is managed by the functor; this is useful if you're writing an
/// adapter to some other implementation, e.g., a Python callable, which
/// is dynamically associated with the registered kernel.
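///
/// A minimal sketch (`CountingKernel` is illustrative):
///
/// ```
/// struct CountingKernel : c10::OperatorKernel {
///   int64_t calls = 0;
///   void operator()(
///       const c10::OperatorHandle& op,
///       c10::DispatchKeySet ks,
///       torch::jit::Stack* stack) {
///     ++calls; // state carried by the functor across invocations
///     // ... compute the result and leave it on `stack` ...
///   }
/// };
///
/// m.impl("mul", torch::CppFunction::makeFromBoxedFunctor(
///                   std::make_unique<CountingKernel>()));
/// ```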
template <class KernelFunctor>
static CppFunction makeFromBoxedFunctor(
std::unique_ptr<KernelFunctor> kernelFunctor) {
return makeFromBoxedKernel(
c10::BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
}
/// Create a function from an unboxed kernel function.
/// This is typically used to register common operators.
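///
/// For example, using `add_impl` from the header example above:
///
/// ```
/// m.impl("add", torch::CppFunction::makeFromUnboxedFunction(&add_impl));
/// ```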
template <
typename FuncPtr,
std::enable_if_t<
c10::guts::is_function_type<FuncPtr>::value,
std::nullptr_t> = nullptr>
static CppFunction makeFromUnboxedFunction(FuncPtr* f) {
return CppFunction(f);
}
/// Create a function from a compile time unboxed kernel function pointer.
/// This is typically used to register common operators.
/// Compile time function pointers can be used to allow the compiler
/// to optimize (e.g. inline) calls to it.
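///
/// For example, using `add_impl` from the header example above:
///
/// ```
/// m.impl("add",
///        torch::CppFunction::makeFromUnboxedFunction(TORCH_FN(add_impl)));
/// ```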
template <
typename FuncPtr,
std::enable_if_t<
c10::is_compile_time_function_pointer<FuncPtr>::value,
std::nullptr_t> = nullptr>
static CppFunction makeFromUnboxedFunction(FuncPtr f) {
return CppFunction(f);
}
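  /// Attaches a human-readable debug string to this function, which may be
  /// surfaced in dispatcher diagnostics. Rvalue-qualified, so call it on a
  /// temporary, e.g. `torch::CppFunction(&add_impl).debug("my add kernel")`.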
CppFunction&& debug(std::string d) && {
debug_ = std::move(d);
return std::move(*this);
}
private:
c10::optional<c10::DispatchKey> dispatch_key_;
c10::KernelFunction func_;
c10::optional<c10::impl::CppSignature> cpp_signature_;
std::unique_ptr<c10::FunctionSchema> schema_;
std::string debug_;
// The "setter" for dispatch_key_
template <typename Func>
friend CppFunction dispatch(c10::DispatchKey, Func&&);
// The only class which actually pulls out values from CppFunction (does so
// destructively, felt too lazy to write accessors that I don't even
// want users to use)
friend class Library;
CppFunction(
c10::KernelFunction func,
c10::optional<c10::impl::CppSignature> cpp_signature,
std::unique_ptr<c10::FunctionSchema> schema);
};
/// \defgroup torch-dispatch-overloads torch::dispatch overloads
/// Create a torch::CppFunction which is associated with a specific
/// dispatch key. torch::CppFunctions that are tagged with a
/// c10::DispatchKey don't get invoked unless the dispatcher determines
/// that this particular c10::DispatchKey is the one that should be
/// dispatched to.
///
/// This function is generally not used directly; instead, prefer using
/// TORCH_LIBRARY_IMPL(), which will implicitly set the c10::DispatchKey
/// for all registration calls inside of its body.
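///
/// If you do need to call it directly, a sketch reusing the kernels from
/// the header example above:
///
/// ```
/// TORCH_LIBRARY(myops, m) {
///   m.def("mul(Tensor self, Tensor other) -> Tensor");
///   m.impl("mul", torch::dispatch(c10::DispatchKey::CPU, &mul_cpu_impl));
/// }
/// ```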
///
/// \ingroup torch-dispatch-overloads
template <typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
CppFunction f(std::forward<Func>(raw_f));
if (k == c10::DispatchKey::CatchAll) {
f.dispatch_key_ = c10::nullopt;
} else {
f.dispatch_key_ = k;
}
return f;
}