#pragma once
#include <ATen/core/builtin_function.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/class_type.h>
#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/stack.h>
#include <c10/util/C++17.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/TypeTraits.h>
#include <torch/custom_class_detail.h>
#include <torch/library.h>
#include <iostream>
#include <sstream>
namespace torch {
/// This function is used in conjunction with `class_::def()` to register
/// a constructor for a given C++ class type. For example,
/// `torch::init<int, std::string>()` would register a two-argument constructor
/// taking an `int` and a `std::string` as argument.
template <class... Types>
detail::types<void, Types...> init() {
return detail::types<void, Types...>{};
}
/// Holder returned by `torch::init(callable)`. It stores the callable in `f`
/// and records, purely at the type level, the extra template arguments it was
/// instantiated with (used by `class_::def()` to recover the parameter types).
template <typename Callable, typename... Params>
struct InitLambda {
  /// The wrapped callable.
  Callable f;
};
template <typename Func>
decltype(auto) init(Func&& f) {
using InitTraits = c10::guts::infer_function_traits_t<std::decay_t<Func>>;
using ParameterTypeList = typename InitTraits::parameter_types;
InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
return init;
}
/// Entry point for custom C++ class registration. To register a C++ class
/// in PyTorch, instantiate `torch::class_` with the desired class as the
/// template parameter. Typically, this instantiation should be done in
/// the initialization of a global variable, so that the class will be
/// made available on dynamic library loading without any additional API
/// calls needed. For example, to register a class named Foo, you might
/// create a global variable like so:
///
/// static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
/// .def("myMethod", &Foo::myMethod)
/// .def("lambdaMethod", [](const c10::intrusive_ptr<Foo>& self) {
/// // Do something with `self`
/// });
///
/// In addition to registering the class, this registration also chains
/// `def()` calls to register methods. `myMethod()` is registered with
/// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()`
/// is registered with a C++ lambda expression.
template <class CurClass>
class class_ : public ::torch::detail::class_base {
static_assert(
std::is_base_of<CustomClassHolder, CurClass>::value,
"torch::class_<T> requires T to inherit from CustomClassHolder");
public:
/// This constructor actually registers the class type.
/// String argument `namespaceName` is an identifier for the
/// namespace you would like this class to appear in.
/// String argument `className` is the name you would like to
/// see this class exposed as in Python and TorchScript. For example, if
/// you pass `foo` as the namespace name and `Bar` as the className, the
/// class will appear as `torch.classes.foo.Bar` in Python and TorchScript
explicit class_(
const std::string& namespaceName,
const std::string& className,
std::string doc_string = "")
: class_base(
namespaceName,
className,
std::move(doc_string),
typeid(c10::intrusive_ptr<CurClass>),
typeid(c10::tagged_capsule<CurClass>)) {}
/// Registers a constructor described by a `torch::init<...>()` tag. For
/// example, passing `torch::init<int, std::string>()` registers `__init__`
/// as a two-argument constructor taking an `int` and a `std::string`.
///
/// `doc_string` and `default_args` are forwarded to the method registration.
template <typename... Types>
class_& def(
    torch::detail::types<void, Types...>,
    std::string doc_string = "",
    std::initializer_list<arg> default_args = {}) {
  // Thunk registered as `__init__`: build the C++ object from the arguments,
  // then store it as a capsule in slot 0 of the object behind `self`.
  auto ctor_thunk = [](c10::tagged_capsule<CurClass> self, Types... args) {
    auto instance = c10::make_intrusive<CurClass>(args...);
    self.ivalue.toObject()->setSlot(
        0, c10::IValue::make_capsule(std::move(instance)));
  };

  defineMethod(
      "__init__",
      std::move(ctor_thunk),
      std::move(doc_string),
      std::move(default_args));
  return *this;
}
/// Registers a constructor backed by a callable wrapped with
/// `torch::init(callable)`.
///
/// The registered `__init__` invokes the callable with the constructor
/// arguments (`ParameterTypes...`), converts its result to
/// `c10::intrusive_ptr<CurClass>`, and stores that pointer as a capsule in
/// slot 0 of the object behind `self`.
///
/// `doc_string` and `default_args` are forwarded to the method registration.
template <typename Func, typename... ParameterTypes>
class_& def(
    InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
    std::string doc_string = "",
    std::initializer_list<arg> default_args = {}) {
  auto init_lambda_wrapper = [func = std::move(init.f)](
                                 c10::tagged_capsule<CurClass> self,
                                 ParameterTypes... arg) {
    c10::intrusive_ptr<CurClass> classObj =
        at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
    auto object = self.ivalue.toObject();
    // `classObj` is not used again, so hand it over instead of copying it;
    // this avoids an extra reference-count round trip and matches the
    // `torch::init<...>()` overload of `def`.
    object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
  };
  defineMethod(
      "__init__",
      std::move(init_lambda_wrapper),
      std::move(doc_string),
      std::move(default_args));
  return *this;
}
/// General-purpose method registration. `name` is the name under which the
/// method becomes callable from Python and TorchScript, and `f` is the
/// callable implementing it. `f` is usually either a pointer to a member
/// function of `CurClass`, or a lambda whose first parameter is a
/// `c10::intrusive_ptr<CurClass>` (playing the role of `this`).
///
/// `doc_string` and `default_args` are forwarded to the method registration.
///
/// Examples:
///
/// // Makes `Foo::foo` available as `call_foo()` in Python and TorchScript
/// .def("call_foo", &Foo::foo)
///
/// // Makes the lambda available as `call_lambda()` in Python and
/// // TorchScript
/// .def("call_lambda", [](const c10::intrusive_ptr<Foo>& self) {
/// // do something
/// })
template <typename Func>
class_& def(
    std::string name,
    Func f,
    std::string doc_string = "",
    std::initializer_list<arg> default_args = {}) {
  // Adapt `f` to the calling convention expected by `defineMethod`, then
  // register the adapted callable under `name`.
  defineMethod(
      std::move(name),
      detail::wrap_func<CurClass, Func>(std::move(f)),
      std::move(doc_string),
      std::move(default_args));
  return *this;
}
/// Registers `func` as a static method of the class.
///
/// The function schema is inferred from `Func`, the method is given the
/// qualified name `<qualClassName>.<name>`, and calls are routed through
/// `detail::BoxedProxy`, which operates on the interpreter stack.
template <typename Func>
class_& def_static(std::string name, Func func, std::string doc_string = "") {
  using Result = typename c10::guts::infer_function_traits_t<Func>::return_type;

  // Build the qualified name first: `name` is moved into the schema below.
  auto qualified_name = qualClassName + "." + name;
  auto fn_schema =
      c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");

  auto boxed_call =
      [func = std::move(func)](jit::Stack& stack) mutable -> void {
    detail::BoxedProxy<Result, Func>()(stack, func);
  };

  auto static_fn = std::make_unique<jit::BuiltinOpFunction>(
      std::move(qualified_name),
      std::move(fn_schema),
      std::move(boxed_call),
      std::move(doc_string));

  // The class type gets a non-owning pointer; ownership is transferred to
  // the custom-class method registry afterwards.
  classTypePtr->addStaticMethod(static_fn.get());
  registerCustomClassMethod(std::move(static_fn));
  return *this;
}
/// Registers a property that has both a getter and a setter.
///
/// Each callable is adapted with `detail::wrap_func` and registered as a
/// method named `<name>_getter` / `<name>_setter` (both receive
/// `doc_string`). The two resulting functions are then attached to the class
/// type as property `name`.
template <typename GetterFunc, typename SetterFunc>
class_& def_property(
    const std::string& name,
    GetterFunc getter_func,
    SetterFunc setter_func,
    std::string doc_string = "") {
  // The getter is registered first and receives a copy of the doc string,
  // because the setter registration below still needs it.
  torch::jit::Function* getter = defineMethod(
      name + "_getter",
      detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func)),
      doc_string);
  // Last use of `doc_string`, so it can be moved rather than copied.
  torch::jit::Function* setter = defineMethod(
      name + "_setter",
      detail::wrap_func<CurClass, SetterFunc>(std::move(setter_func)),
      std::move(doc_string));
  classTypePtr->addProperty(name, getter, setter);
  return *this;
}
/// Registers a property that has only a getter.
///
/// The callable is adapted with `detail::wrap_func` and registered as a
/// method named `<name>_getter` (receiving `doc_string`). The resulting
/// function is attached to the class type as property `name`, with a null
/// setter.
template <typename GetterFunc>
class_& def_property(
    const std::string& name,
    GetterFunc getter_func,
    std::string doc_string = "") {
  // Neither the wrapped callable nor the doc string is used again, so both
  // are handed to `defineMethod` as rvalues instead of being copied.
  torch::jit::Function* getter = defineMethod(
      name + "_getter",
      detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func)),
      std::move(doc_string));
  classTypePtr->addProperty(name, getter, nullptr);
  return *this;
}
/// Exposes the data member `field` as a read-write property called `name`.
/// Reading the property returns a copy of the member; writing assigns the
/// supplied value to it.
template <typename T>
class_& def_readwrite(const std::string& name, T CurClass::*field) {
  auto read_member = [member = field](
                         const c10::intrusive_ptr<CurClass>& self) {
    CurClass* target = self.get();
    return target->*member;
  };
  auto write_member = [member = field](
                          const c10::intrusive_ptr<CurClass>& self, T value) {
    CurClass* target = self.get();
    target->*member = value;
  };
  return def_property(name, read_member, write_member);
}
/// Exposes the data member `field` as a read-only property called `name`.
/// Reading the property returns a copy of the member; no setter is
/// registered.
template <typename T>
class_& def_readonly(const std::string& name, T CurClass::*field) {
  auto read_member = [member = field](
                         const c10::intrusive_ptr<CurClass>& self) {
    CurClass* target = self.get();
    return target->*member;
  };
  return def_property(name, read_member);
}
/// Unsafe registration hook that exists to support custom JIT backends built
/// on custom C++ classes; it is not meant for general-purpose use.
///
/// `func` operates directly on the interpreter stack and `schema` is taken
/// as given rather than inferred. The method is registered under the
/// qualified name `<qualClassName>.<name>`.
class_& _def_unboxed(
    std::string name,
    std::function<void(jit::Stack&)> func,
    c10::FunctionSchema schema,
    std::string doc_string = "") {
  auto qualified_name = qualClassName + "." + name;
  auto unboxed_fn = std::make_unique<jit::BuiltinOpFunction>(
      std::move(qualified_name),
      std::move(schema),
      std::move(func),
      std::move(doc_string));

  // The class type gets a non-owning pointer; ownership is transferred to
  // the custom-class method registry afterwards.
  classTypePtr->addMethod(unboxed_fn.get());
  registerCustomClassMethod(std::move(unboxed_fn));
  return *this;
}
/// def_pickle() is used to define exactly what state gets serialized
/// or deserialized for a given instance of a custom C++ class in
/// Python or TorchScript. This protocol is equivalent to the Pickle
/// concept of `__getstate__` and `__setstate__` from Python
/// (https://docs.python.org/2/library/pickle.html#object.__getstate__)
///
/// Currently, both the `get_state` and `set_state` callables must be
/// C++ lambda expressions. They should have the following signatures,
/// where `CurClass` is the class you're registering and `T1` is some object
/// that encapsulates the state of the object.
///
/// __getstate__(intrusive_ptr<CurClass>) -> T1
/// __setstate__(T2) -> intrusive_ptr<CurClass>
///
/// `T1` must be an object that is convertible to IValue by the same rules
/// for custom op/method registration.
///
/// For the common case, T1 == T2. T1 can also be a subtype of T2. An
/// example where it makes sense for T1 and T2 to differ is if __setstate__
/// handles legacy formats in a backwards compatible way.
///
/// Example:
///
/// .def_pickle(
/// // __getstate__
/// [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
/// return self->stack_;
/// },
/// [](std::vector<std::string> state) { // __setstate__
/// return c10::make_intrusive<MyStackClass<std::string>>(
/// std::vector<std::string>{"i", "was", "deserialized"});
/// })
template <typename GetStateFn, typename SetStateFn>
class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
static_assert(
c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
"def_pickle() currently only supports lambdas as "
"__getstate__ and __setstate__ arguments.");
def("__getstate__", std::forward<GetStateFn>(get_state));
// __setstate__ needs to be registered with some custom handling:
// We need to wrap the invocation of the user-provided function
// such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>)
// and assign it to the `capsule` attribute.
using SetStateTraits =
c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
using SetStateArg = typename c10::guts::typelist::head_t<
typename SetStateTraits::parameter_types>;
auto setstate_wrapper = [set_state = std::forward<SetStateFn>(set_state)](
c10::tagged_capsule<CurClass> self,
SetStateArg&& arg) {
c10::intrusive_ptr<CurClass> classObj =
at::guts::invoke(set_state, std::forward<SetStateArg>(arg));
auto object = self.ivalue.toObject();
object->setSlot(0, c10::IValue::make_capsule(classObj));
};
defineMethod(
"__setstate__",
detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
std::move(setstate_wrapper)));
// type validation
auto getstate_schema = classTypePtr->getMethod("__getstate__").getSchema();
auto format_getstate_schema = [&getstate_schema]() {
std::stringstream ss;
ss << getstate_schema;
return ss.str();
};
TORCH_CHECK(
getstate_schema.arguments().size() == 1,
Loading ...