// Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages
//
// edgify / torch   python
//
// Repository URL to install this package:
//
// / include / c10 / util / Registry.h

#ifndef C10_UTIL_REGISTRY_H_
#define C10_UTIL_REGISTRY_H_

/**
 * Simple registry implementation that uses static variables to
 * register object creators during program initialization time.
 */

// NB: This Registry works poorly when you have other namespaces.
// Make all macro invocations from inside the at namespace.

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/Type.h>

namespace c10 {

/**
 * Fallback textual form of a registry key, used in Registry's diagnostic
 * messages. The key itself is ignored; a fixed placeholder is returned for
 * every key type that has no dedicated specialization.
 */
template <typename KeyType>
inline std::string KeyStrRepr(const KeyType& /*key*/) {
  const char* const placeholder = "[key type printing not supported]";
  return std::string(placeholder);
}

template <>
inline std::string KeyStrRepr(const std::string& key) {
  return key;
}

/// Priority attached to every Registry entry. When a key is registered more
/// than once, Registry::Register keeps the entry with the strictly higher
/// value and rejects a second registration at the same value.
/// REGISTRY_DEFAULT is what Register uses when no priority is given.
enum RegistryPriority {
  REGISTRY_FALLBACK = 1,
  REGISTRY_DEFAULT, // == 2
  REGISTRY_PREFERRED, // == 3
};

/**
 * @brief A template class that allows one to register classes by keys.
 *
 * The keys are usually a std::string specifying the name, but can be anything
 * that can be used in a std::map.
 *
 * You should most likely not use the Registry class explicitly, but use the
 * helper macros below to declare specific registries as well as registering
 * objects.
 */
template <class SrcType, class ObjectPtrType, class... Args>
class Registry {
 public:
  typedef std::function<ObjectPtrType(Args...)> Creator;

  Registry(bool warning = true)
      : registry_(), priority_(), terminate_(true), warning_(warning) {}

  void Register(
      const SrcType& key,
      Creator creator,
      const RegistryPriority priority = REGISTRY_DEFAULT) {
    std::lock_guard<std::mutex> lock(register_mutex_);
    // The if statement below is essentially the same as the following line:
    // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key
    //                                   << " registered twice.";
    // However, TORCH_CHECK_EQ depends on google logging, and since registration
    // is carried out at static initialization time, we do not want to have an
    // explicit dependency on glog's initialization function.
    if (registry_.count(key) != 0) {
      auto cur_priority = priority_[key];
      if (priority > cur_priority) {
#ifdef DEBUG
        std::string warn_msg =
            "Overwriting already registered item for key " + KeyStrRepr(key);
        fprintf(stderr, "%s\n", warn_msg.c_str());
#endif
        registry_[key] = creator;
        priority_[key] = priority;
      } else if (priority == cur_priority) {
        std::string err_msg =
            "Key already registered with the same priority: " + KeyStrRepr(key);
        fprintf(stderr, "%s\n", err_msg.c_str());
        if (terminate_) {
          std::exit(1);
        } else {
          throw std::runtime_error(err_msg);
        }
      } else if (warning_) {
        std::string warn_msg =
            "Higher priority item already registered, skipping registration of " +
            KeyStrRepr(key);
        fprintf(stderr, "%s\n", warn_msg.c_str());
      }
    } else {
      registry_[key] = creator;
      priority_[key] = priority;
    }
  }

  void Register(
      const SrcType& key,
      Creator creator,
      const std::string& help_msg,
      const RegistryPriority priority = REGISTRY_DEFAULT) {
    Register(key, creator, priority);
    help_message_[key] = help_msg;
  }

  inline bool Has(const SrcType& key) {
    return (registry_.count(key) != 0);
  }

  ObjectPtrType Create(const SrcType& key, Args... args) {
    auto it = registry_.find(key);
    if (it == registry_.end()) {
      // Returns nullptr if the key is not registered.
      return nullptr;
    }
    return it->second(args...);
  }

  /**
   * Returns the keys currently registered as a std::vector.
   */
  std::vector<SrcType> Keys() const {
    std::vector<SrcType> keys;
    keys.reserve(registry_.size());
    for (const auto& it : registry_) {
      keys.push_back(it.first);
    }
    return keys;
  }

  inline const std::unordered_map<SrcType, std::string>& HelpMessage() const {
    return help_message_;
  }

  const char* HelpMessage(const SrcType& key) const {
    auto it = help_message_.find(key);
    if (it == help_message_.end()) {
      return nullptr;
    }
    return it->second.c_str();
  }

  // Used for testing, if terminate is unset, Registry throws instead of
  // calling std::exit
  void SetTerminate(bool terminate) {
    terminate_ = terminate;
  }

 private:
  std::unordered_map<SrcType, Creator> registry_;
  std::unordered_map<SrcType, RegistryPriority> priority_;
  bool terminate_;
  const bool warning_;
  std::unordered_map<SrcType, std::string> help_message_;
  std::mutex register_mutex_;

  C10_DISABLE_COPY_AND_ASSIGN(Registry);
};

/**
 * Helper whose constructor registers a creator into a Registry. The
 * registration macros below declare static Registerer objects so that the
 * registration runs when those objects are initialized.
 */
template <class SrcType, class ObjectPtrType, class... Args>
class Registerer {
 public:
  /**
   * Registers `creator` under `key` in `registry` at Register()'s default
   * priority (REGISTRY_DEFAULT), recording `help_msg`.
   */
  explicit Registerer(
      const SrcType& key,
      Registry<SrcType, ObjectPtrType, Args...>* registry,
      typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
      const std::string& help_msg = "") {
    // `creator` is already our own copy; move it rather than copying the
    // std::function (and any state it holds) a second time.
    registry->Register(key, std::move(creator), help_msg);
  }

  /**
   * Registers `creator` under `key` in `registry` with an explicit
   * `priority`, recording `help_msg`.
   */
  explicit Registerer(
      const SrcType& key,
      const RegistryPriority priority,
      Registry<SrcType, ObjectPtrType, Args...>* registry,
      typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
      const std::string& help_msg = "") {
    registry->Register(key, std::move(creator), help_msg, priority);
  }

  /**
   * Creator that heap-allocates `DerivedType(args...)` and hands ownership to
   * an ObjectPtrType (e.g. std::unique_ptr / std::shared_ptr of the base).
   */
  template <class DerivedType>
  static ObjectPtrType DefaultCreator(Args... args) {
    return ObjectPtrType(new DerivedType(args...));
  }
};

/**
 * C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function
 * declaration, as well as creating a convenient typename for its corresponding
 * registerer.
 */
// Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE
// as import and DEFINE as export, because these registry macros will be used
// in downstream shared libraries as well, and one cannot use *_API - the API
// macro will be defined on a per-shared-library basis. Semantically, when one
// declares a typed registry it is always going to be IMPORT, and when one
// defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE),
// the instantiation unit is always going to be exported.
//
// The only unique condition is when in the same file one does DECLARE and
// DEFINE - in Windows compilers, this generates a warning that dllimport and
// dllexport are mixed, but the warning is fine and linker will be properly
// exporting the symbol. The same thing happens with gflags flag declarations
// and definitions.
// Declares `RegistryName()`, which returns a pointer to the
// ::c10::Registry<SrcType, PtrType<ObjectType>, ...> instance, and the typedef
// `Registerer##RegistryName` for the matching ::c10::Registerer. Trailing
// arguments become the creator's argument types. The expansion ends without a
// semicolon, so the invocation site supplies it.
#define C10_DECLARE_TYPED_REGISTRY(                                        \
    RegistryName, SrcType, ObjectType, PtrType, ...)                       \
  C10_IMPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  RegistryName();                                                          \
  typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>   \
      Registerer##RegistryName

// Defines `RegistryName()` for a registry declared with
// C10_DECLARE_TYPED_REGISTRY and the same arguments. The Registry is created
// with `new` on the first call (function-local static) and is never deleted.
// NOTE(review): presumably leaked on purpose so it stays usable during static
// destruction; confirm before changing. Uses Registry's default warning=true.
#define C10_DEFINE_TYPED_REGISTRY(                                         \
    RegistryName, SrcType, ObjectType, PtrType, ...)                       \
  C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  RegistryName() {                                                         \
    static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*   \
        registry = new ::c10::                                             \
            Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>();       \
    return registry;                                                       \
  }

// Same as C10_DEFINE_TYPED_REGISTRY, but constructs the Registry with
// warning=false, so skipped lower-priority registrations are not reported on
// stderr.
#define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(                            \
    RegistryName, SrcType, ObjectType, PtrType, ...)                          \
  C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*    \
  RegistryName() {                                                            \
    static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*      \
        registry =                                                            \
            new ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>( \
                false);                                                       \
    return registry;                                                          \
  }

// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated
// creator with comma in its templated arguments.
// Registers a creator under `key` in RegistryName() at the default priority.
// It declares a static Registerer named via
// C10_ANONYMOUS_VARIABLE(g_<RegistryName>), presumably unique per use. The
// trailing arguments go to the Registerer constructor: the creator, then
// optionally a help message.
#define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...)                  \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key, RegistryName(), ##__VA_ARGS__);

// Same as C10_REGISTER_TYPED_CREATOR, but with an explicit RegistryPriority.
#define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY(                           \
    RegistryName, key, priority, ...)                                       \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key, priority, RegistryName(), ##__VA_ARGS__);

// Registers the class given in the trailing arguments under `key`, using
// Registerer::DefaultCreator<Class> as the creator. The help message is
// ::c10::demangle_type<Class>(), presumably the readable type name. The class
// is taken through __VA_ARGS__ so template names containing commas work.
#define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...)                    \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key,                                                                  \
      RegistryName(),                                                       \
      Registerer##RegistryName::DefaultCreator<__VA_ARGS__>,                \
      ::c10::demangle_type<__VA_ARGS__>());

// Same as C10_REGISTER_TYPED_CLASS, but with an explicit RegistryPriority.
#define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(                             \
    RegistryName, key, priority, ...)                                       \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key,                                                                  \
      priority,                                                             \
      RegistryName(),                                                       \
      Registerer##RegistryName::DefaultCreator<__VA_ARGS__>,                \
      ::c10::demangle_type<__VA_ARGS__>());

// C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use
// std::string as the key type, because that is the most common case. The
// *_SHARED_REGISTRY variants hold std::shared_ptr instead of std::unique_ptr.
// Declares a std::string-keyed registry whose creators return
// std::unique_ptr<ObjectType>; trailing arguments are the creator's argument
// types.
#define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DECLARE_TYPED_REGISTRY(                               \
      RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

// Defines the registry declared by C10_DECLARE_REGISTRY.
#define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DEFINE_TYPED_REGISTRY(                               \
      RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

// Like C10_DEFINE_REGISTRY, but without warnings for skipped lower-priority
// registrations.
#define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \
  C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(                               \
      RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

// std::shared_ptr<ObjectType> counterparts of the three macros above.
#define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DECLARE_TYPED_REGISTRY(                                      \
      RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)

#define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DEFINE_TYPED_REGISTRY(                                      \
      RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)

#define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \
    RegistryName, ObjectType, ...)                  \
  C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(        \
      RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)

// C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use std::string
// as the key type, because that is the most common case. The key argument is
// stringified with #key, so pass a bare identifier, not a string literal.
// Registers a creator (optionally followed by a help message) under the
// stringified `key` at the default priority.
#define C10_REGISTER_CREATOR(RegistryName, key, ...) \
  C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__)

// Same as C10_REGISTER_CREATOR, but with an explicit RegistryPriority.
#define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \
  C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY(                                  \
      RegistryName, #key, priority, __VA_ARGS__)

// Registers the class given in the trailing arguments under the stringified
// `key` at the default priority.
#define C10_REGISTER_CLASS(RegistryName, key, ...) \
  C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)

// Same as C10_REGISTER_CLASS, but with an explicit RegistryPriority.
#define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \
  C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(                                  \
      RegistryName, #key, priority, __VA_ARGS__)

} // namespace c10

#endif // C10_UTIL_REGISTRY_H_