Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / custom_class_detail.h

#pragma once

#include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h>
#include <ATen/core/function.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeTraits.h>
#include <c10/util/irange.h>

namespace torch {

namespace detail {
/**
 * Optional recording of which custom classes get loaded.
 *
 * The recording path is compiled in only when
 * ENABLE_RECORD_KERNEL_FUNCTION_DTYPE is defined. In the Facebook internal
 * build (using BUCK), this macro is enabled by passing in
 * -c pt.enable_record_kernel_dtype=1 when building the tracer binary.
 */
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
// Declared here, defined elsewhere. Receives the (already shortened) class
// name by value.
TORCH_API void record_custom_class(std::string name);

/**
 * Record an instance of a custom class being loaded.
 *
 * Grabs the portion of the string after the final '.' of the qualified name,
 * as this seemingly aligns with how users name their custom classes.
 * Example: __torch__.torch.classes.xnnpack.Conv2dOpContext
 *          is recorded as "Conv2dOpContext".
 * If NAME contains no '.', find_last_of returns npos, npos + 1 wraps to 0,
 * and the whole string is recorded.
 *
 * Usage notes that follow from the expansion below:
 *  - It expands to two statements and declares a local variable called
 *    `name` in the invoking scope, so it cannot be expanded twice in the
 *    same scope or in a scope that already declares `name`.
 *  - It refers to `detail::record_custom_class`, so it must be expanded
 *    somewhere `detail` resolves to this namespace (e.g. inside `torch`).
 *  - The expansion already ends with ';'.
 */
#define RECORD_CUSTOM_CLASS(NAME) \
  auto name = std::string(NAME);  \
  detail::record_custom_class(name.substr(name.find_last_of(".") + 1));
#else
// Recording is disabled: the macro expands to nothing and NAME is not
// evaluated.
#define RECORD_CUSTOM_CLASS(NAME)
#endif
} // namespace detail

/// Describes one named argument, optionally carrying a default value, for
/// use when registering methods of custom classes:
///     static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
///       .def("myMethod", &Foo::myMethod, {torch::arg("name") = name});
struct arg {
  // Returns the IValue that represents a default of None. Intended usage:
  //     torch::arg("name") = torch::arg::none()
  // which is the same as writing:
  //     torch::arg("name") = IValue()
  static c10::IValue none() {
    return c10::IValue();
  }

  // Creates an argument that has a name but no default value: `value_` is
  // left default-constructed, i.e. empty. The constructor is explicit.
  explicit arg(std::string name) : name_(std::move(name)) {}

  // Records `rhs` as this argument's default value and returns *this, which
  // enables the pybind-like spelling torch::arg("name") = value.
  arg& operator=(const c10::IValue& rhs) {
    value_ = rhs;
    return *this;
  }

  // Name of the argument. It is copied into the schema, because argument
  // names cannot be recovered from the C++ declaration.
  std::string name_;
  // Default value, if one was supplied. A default-constructed IValue is None
  // and therefore indistinguishable from a user-provided default of None, so
  // the optional's engaged/empty state is what tells "no default" apart from
  // "default is None".
  c10::optional<c10::IValue> value_;
};

namespace detail {

// Argument type utilities
/// Compile-time list of types: a leading type `R` followed by any number of
/// further types. The nested `type` alias names this same specialization.
template <class R, class... Rest>
struct types {
  using type = types<R, Rest...>;
};

/// Adapts a pointer-to-member-function into a callable object whose first
/// parameter is the receiver, held in a c10::intrusive_ptr, followed by the
/// method's own parameters. Only the specializations are defined.
template <typename Method>
struct WrapMethod;

/// Specialization for non-const member functions `R CurrClass::f(Args...)`.
template <typename R, typename CurrClass, typename... Args>
struct WrapMethod<R (CurrClass::*)(Args...)> {
  // Intentionally non-explicit, as in the original interface. A
  // member-function pointer is trivially copyable, so it is simply copied.
  WrapMethod(R (CurrClass::*method)(Args...)) : m(method) {}

  /// Invokes the wrapped method on `*cur` and returns its result.
  ///
  /// Each argument is forwarded with its declared value category: by-value
  /// parameters are moved into the method instead of being copied a second
  /// time, and methods that take rvalue-reference parameters become callable.
  R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
    return c10::guts::invoke(m, *cur, std::forward<Args>(args)...);
  }

  // The wrapped pointer-to-member-function.
  R (CurrClass::*m)(Args...);
};

/// Specialization of WrapMethod for const member functions
/// `R CurrClass::f(Args...) const`. The resulting callable takes the receiver
/// as a c10::intrusive_ptr followed by the method's own parameters.
template <typename R, typename CurrClass, typename... Args>
struct WrapMethod<R (CurrClass::*)(Args...) const> {
  // Intentionally non-explicit, as in the original interface. A
  // member-function pointer is trivially copyable, so it is simply copied.
  WrapMethod(R (CurrClass::*method)(Args...) const) : m(method) {}

  /// Invokes the wrapped const method on `*cur` and returns its result.
  ///
  /// Each argument is forwarded with its declared value category: by-value
  /// parameters are moved into the method instead of being copied a second
  /// time, and methods that take rvalue-reference parameters become callable.
  R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
    return c10::guts::invoke(m, *cur, std::forward<Args>(args)...);
  }

  // The wrapped pointer-to-const-member-function.
  R (CurrClass::*m)(Args...) const;
};

// Adapter for different callable types
/// wrap_func overload selected when `Func` (after decay) is a pointer to a
/// member function: the pointer is wrapped in a WrapMethod<Func> so it can
/// be called like a free function taking the receiver first.
/// `CurClass` is not used in the body; callers supply it explicitly.
template <
    typename CurClass,
    typename Func,
    std::enable_if_t<
        std::is_member_function_pointer<std::decay_t<Func>>::value,
        bool> = false>
WrapMethod<Func> wrap_func(Func method) {
  return WrapMethod<Func>{std::move(method)};
}

/// wrap_func overload selected when `Func` (after decay) is NOT a pointer to
/// a member function (lambdas, functors, plain function pointers): the
/// callable is handed back unchanged.
/// `CurClass` is not used in the body; callers supply it explicitly.
template <
    typename CurClass,
    typename Func,
    std::enable_if_t<
        !std::is_member_function_pointer<std::decay_t<Func>>::value,
        bool> = false>
Func wrap_func(Func callable) {
  return callable;
}

/// Calls `functor` with arguments read from `stack`.
///
/// One argument is produced per index in `ivalue_arg_indices`: the IValue is
/// obtained with torch::jit::peek(stack, index, <number of indices>) and
/// converted with c10::impl::ivalue_to_arg to the functor's parameter type at
/// that index. The stack is only read here; nothing is popped.
/// Returns whatever `functor` returns.
template <
    class Functor,
    bool AllowDeprecatedTypes,
    size_t... ivalue_arg_indices>
typename c10::guts::infer_function_traits_t<Functor>::return_type
call_torchbind_method_from_stack(
    Functor& functor,
    jit::Stack& stack,
    std::index_sequence<ivalue_arg_indices...>) {
  // With an empty index pack `stack` is never referenced below, so mark it
  // as used to keep the compiler from warning.
  static_cast<void>(stack);

  using FunctorTraits = c10::guts::infer_function_traits_t<Functor>;
  using ParamTypes = typename FunctorTraits::parameter_types;
  constexpr size_t kArgCount = sizeof...(ivalue_arg_indices);

  // TODO We shouldn't use c10::impl stuff directly here. We should use the
  // KernelFunction API instead.
  return functor(
      c10::impl::ivalue_to_arg<
          typename c10::impl::decay_if_not_tensor<
              c10::guts::typelist::
                  element_t<ivalue_arg_indices, ParamTypes>>::type,
          AllowDeprecatedTypes>::
          call(torch::jit::peek(stack, ivalue_arg_indices, kArgCount))...);
}

/// Convenience overload: builds the index sequence 0..N-1, where N is the
/// functor's parameter count as reported by infer_function_traits_t, and
/// delegates to the index_sequence overload. Returns the functor's result.
template <class Functor, bool AllowDeprecatedTypes>
typename c10::guts::infer_function_traits_t<Functor>::return_type
call_torchbind_method_from_stack(Functor& functor, jit::Stack& stack) {
  using FunctorTraits = c10::guts::infer_function_traits_t<Functor>;
  constexpr size_t kArgCount = FunctorTraits::number_of_parameters;
  using ArgIndices = std::make_index_sequence<kArgCount>;
  return call_torchbind_method_from_stack<Functor, AllowDeprecatedTypes>(
      functor, stack, ArgIndices());
}

template <class RetType, class Func>
struct BoxedProxy;

template <class RetType, class Func>
struct BoxedProxy {
  void operator()(jit::Stack& stack, Func& func) {
    auto retval = call_torchbind_method_from_stack<Func, false>(func, stack);
    constexpr size_t num_ivalue_args =
        c10::guts::infer_function_traits_t<Func>::number_of_parameters;
    torch::jit::drop(stack, num_ivalue_args);
    stack.emplace_back(c10::ivalue::from(std::move(retval)));
  }
};

template <class Func>
struct BoxedProxy<void, Func> {
  void operator()(jit::Stack& stack, Func& func) {
    call_torchbind_method_from_stack<Func, false>(func, stack);
    constexpr size_t num_ivalue_args =
        c10::guts::infer_function_traits_t<Func>::number_of_parameters;
    torch::jit::drop(stack, num_ivalue_args);
    stack.emplace_back();
  }
};

/// Returns true if character `n` is allowed at position `i` of an
/// identifier: a letter or '_' at any position, a digit only when `i > 0`.
///
/// The character is converted to `unsigned char` before being classified:
/// passing a negative `char` (any non-ASCII byte where `char` is signed) to
/// isalpha/isdigit is undefined behaviour.
inline bool validIdent(size_t i, char n) {
  const auto ch = static_cast<unsigned char>(n);
  if (isalpha(ch) || n == '_') {
    return true;
  }
  return i > 0 && isdigit(ch);
}

/// Checks every character of `str` with validIdent and fails via TORCH_CHECK
/// on the first character that is not allowed at its position.
///
/// `type` is the leading text of the error message and names what is being
/// validated. An empty `str` performs no checks.
inline void checkValidIdent(const std::string& str, const char* type) {
  for (const auto idx : c10::irange(str.size())) {
    const char ch = str[idx];
    TORCH_CHECK(
        validIdent(idx, ch),
        type,
        " must be a valid Python/C++ identifier."
        " Character '",
        ch,
        "' at index ",
        idx,
        " is illegal.");
  }
}

/// Base class holding the state for a custom-class registration helper.
///
/// Everything is protected, including the constructor, so it can only be
/// constructed and used through a derived class. Only declarations are
/// visible here; the definitions live elsewhere.
class TORCH_API class_base {
 protected:
  // Takes the namespace and class name as separate strings, a doc string,
  // and two std::type_info references.
  // NOTE(review): presumably this builds `qualClassName` / `classTypePtr`
  // and associates the two type_infos with the new class type — confirm
  // against the definition.
  explicit class_base(
      const std::string& namespaceName,
      const std::string& className,
      std::string doc_string,
      const std::type_info& intrusivePtrClassTypeid,
      const std::type_info& taggedCapsuleClass);

  // Produces a FunctionSchema from an existing `schema` and a list of `arg`
  // descriptors (each a name plus an optional default value).
  // NOTE(review): presumably the result is `schema` with argument names and
  // defaults replaced from `default_args` — confirm against the definition.
  static c10::FunctionSchema withNewArguments(
      const c10::FunctionSchema& schema,
      std::initializer_list<arg> default_args);
  // Qualified name of the class being registered.
  // NOTE(review): exact format is set by the constructor, which is not
  // visible here.
  std::string qualClassName;
  // Class type object for the class being registered.
  at::ClassTypePtr classTypePtr;
};

} // namespace detail

// Registers `class_type` as a custom class.
// NOTE(review): presumably this is what makes the type discoverable through
// getCustomClass below — confirm against the definition.
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
// Registers a method of a custom class. The function object is passed as a
// std::unique_ptr by value, so ownership is transferred to the callee.
TORCH_API void registerCustomClassMethod(std::unique_ptr<jit::Function> method);

// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
// the ClassType pointer to the Type that describes that custom class,
// or nullptr if no class by that name was found.
TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);

// Given an IValue, return true if the object contained in that IValue
// is a custom C++ class, otherwise return false.
TORCH_API bool isCustomClass(const c10::IValue& v);

// Returns a list of function schemas for custom classes, intended for a
// backward-compatibility ("BC") check.
// This API is for testing purposes ONLY. It should not be used in
// any load-bearing code.
TORCH_API std::vector<c10::FunctionSchema> customClassSchemasForBCCheck();

// Make the two registration functions declared in `torch` also reachable as
// torch::jit::registerCustomClass and torch::jit::registerCustomClassMethod.
// NOTE(review): presumably kept so existing callers that use the torch::jit
// spelling keep compiling — confirm.
namespace jit {
using ::torch::registerCustomClass;
using ::torch::registerCustomClassMethod;
} // namespace jit

} // namespace torch