Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
torch / include / ATen / PythonTorchFunctionTLS.h
Size: Mime:
#pragma once

#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>

namespace at {
namespace impl {

enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };

struct TORCH_API PythonTorchFunctionTLS {
  static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
  static TorchFunctionDisabledState get_disabled_state();

  static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
  static const std::shared_ptr<SafePyObject> pop_stack();
  static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
  static int64_t stack_len();

  static const PythonTorchFunctionTLS& get_state();
  static void set_state(const PythonTorchFunctionTLS& state);

 private:
  // The mode TLS is split into
  //   - disabled_state, which says which part of torch function are disabled
  //   - stack_, which is a vector of modes representing the stack of user
  //   defined modes
  TorchFunctionDisabledState disabled_state_ =
      TorchFunctionDisabledState::ENABLED;
  std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
};

TORCH_API bool torch_function_mode_enabled();

} // namespace impl
} // namespace at