Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / torch / csrc / jit / api / module.h

#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/object.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/named_value.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/api/include/torch/ordered_dict.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/utils/memory.h>

#include <ATen/core/function_schema.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>

#include <functional>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// This file contains classes which assist in desugaring Python-style
// modules and their methods into flattened graphs that do not contain any
// function calls.

namespace torch {
namespace jit {

// Re-export the c10 schema/name types so they can be spelled unqualified
// inside torch::jit.
using ::c10::Argument;
using ::c10::FunctionSchema;
using ::c10::QualifiedName;
// Map which stores filename to content.
using ExtraFilesMap = std::unordered_map<std::string, std::string>;

// Reference-counted handle to the c10::ivalue::Object; this is the pointer
// type accepted by Module's converting constructor.
using ModulePtr = c10::intrusive_ptr<c10::ivalue::Object>;

// Forward declaration; the full definition follows later in this header.
struct Module;

// Forward declaration of the slot-list template; the `*_list` aliases in
// this header instantiate it with a policy type.
template <typename T>
struct slot_list_impl;

// Pairs a value of type T with the string name it is stored under.
// Plain aggregate: `name` comes first, then `value`, so it can be
// brace-initialized as Named<T>{name, value}.
template <class T>
struct Named {
  std::string name; // name associated with `value`
  T value; // the payload itself
};

// Convenience spellings of Named<T> for the value kinds used in this header.
using NameModule = Named<Module>;
using NameValue = Named<IValue>;
using NameTensor = Named<at::Tensor>;

// Policy types used to instantiate slot_list_impl. They are only
// forward-declared here; their definitions are not needed to form the
// aliases below.
namespace detail {
struct TORCH_API ModulePolicy;
struct TORCH_API ParameterPolicy;
struct TORCH_API AttributePolicy;
struct TORCH_API BufferPolicy;
// Wraps another policy P; used for the `named_*` list aliases below.
template <typename P>
struct NamedPolicy;
} // namespace detail

// Return types of Module::children()/modules() and their named variants.
using module_list = slot_list_impl<detail::ModulePolicy>;
using named_module_list =
    slot_list_impl<detail::NamedPolicy<detail::ModulePolicy>>;

// Return types of Module::parameters()/named_parameters().
using parameter_list = slot_list_impl<detail::ParameterPolicy>;
using named_parameter_list =
    slot_list_impl<detail::NamedPolicy<detail::ParameterPolicy>>;

// Return types of Module::attributes()/named_attributes().
using attribute_list = slot_list_impl<detail::AttributePolicy>;
using named_attribute_list =
    slot_list_impl<detail::NamedPolicy<detail::AttributePolicy>>;

// Return types of Module::buffers()/named_buffers().
using buffer_list = slot_list_impl<detail::BufferPolicy>;
using named_buffer_list =
    slot_list_impl<detail::NamedPolicy<detail::BufferPolicy>>;

// Callable that maps a list of strings to a Module.
// NOTE(review): the strings are presumably the components of a qualified
// module path -- confirm against the callers.
using ModuleLookup = std::function<Module(const std::vector<std::string>&)>;

/// A script module: an `Object` extended with entry points for registering
/// slots (parameters, buffers, attributes, submodules), enumerating them,
/// switching training mode, converting device/dtype, cloning and saving.
///
/// Only the short helpers below carry a body in this header; every other
/// member is declared here without a body.
struct TORCH_API Module : public Object {
  // ---- Construction --------------------------------------------------------
  explicit Module(c10::QualifiedName class_name);
  Module(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
  Module() = default;
  Module(
      c10::QualifiedName,
      std::shared_ptr<CompilationUnit> cu,
      bool shouldMangle = false);
  /// Wraps an existing object pointer; the pointer is moved into the
  /// `Object` base. Intentionally non-explicit, so a `ModulePtr` converts
  /// implicitly to a `Module`.
  Module(ModulePtr module_value) : Object(std::move(module_value)) {}
  ~Module() = default;

  // ---- Deprecated optimization switches ------------------------------------
  /// Deprecated no-op: only emits a warning; the argument is ignored.
  void set_optimized(bool o) {
    TORCH_WARN(
        "Module::set_optimized() is deprecated and has no effect. "
        "Please use setGraphExecutorOptimize()");
  }

  /// Deprecated: emits a warning and unconditionally reports `true`.
  bool is_optimized() const {
    TORCH_WARN(
        "Module::is_optimized() is deprecated and always returns true. "
        "Please use getGraphExecutorOptimize()");
    return true;
  }

  /// Looks up the method named "forward" and invokes it with `inputs`,
  /// returning whatever that method returns.
  IValue forward(std::vector<IValue> inputs) {
    auto forward_method = get_method("forward");
    return forward_method(std::move(inputs));
  }

  // ---- Slot registration ---------------------------------------------------
  // In script modules, buffers are Tensor attributes that are _not_
  // registered as parameters. This differs from nn.Module, which has a
  // dedicated register_buffer method. With this simplification, we only need
  // to track whether a slot is a parameter to be able to classify it.

  /// Declares (or checks) a Tensor-typed attribute `name` on the module's
  /// type, flagged as a buffer and not as a parameter, then stores `v` in it.
  void register_buffer(const std::string& name, at::Tensor v) {
    type()->addOrCheckAttribute(
        name, TensorType::get(), /*is_param=*/false, /*is_buffer=*/true);
    _ivalue()->setAttr(name, std::move(v));
  }

  /// Declares (or checks) a Tensor-typed attribute `name` and stores `v`.
  /// The slot is flagged as a parameter exactly when `is_buffer` is false.
  void register_parameter(
      const std::string& name,
      at::Tensor v,
      bool is_buffer) {
    const bool is_param = !is_buffer;
    type()->addOrCheckAttribute(name, TensorType::get(), is_param, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  /// Declares (or checks) an attribute `name` of type `t` with the given
  /// parameter/buffer flags, then stores `v` in it.
  void register_attribute(
      const std::string& name,
      const TypePtr& t,
      IValue v,
      bool is_param = false,
      bool is_buffer = false) {
    type()->addOrCheckAttribute(name, t, is_param, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  /// Declares (or checks) an attribute `name` whose type is `module`'s type
  /// and stores `module`'s underlying object in it.
  void register_module(const std::string& name, const Module& module) {
    type()->addOrCheckAttribute(name, module.type());
    _ivalue()->setAttr(name, module._ivalue());
  }

  // NOTE(review): presumably calls `fn` on this module and its submodules;
  // confirm against the out-of-line definition.
  void apply(const std::function<void(Module&)>& fn);

  // ---- Slot enumeration ----------------------------------------------------
  buffer_list buffers(bool recurse = true) const;
  named_buffer_list named_buffers(bool recurse = true) const;

  module_list children() const; // direct modules
  named_module_list named_children() const;
  module_list modules() const; // all modules, including this one, recursively
  named_module_list named_modules() const;

  // all tensors involved in gradient optimization
  parameter_list parameters(bool recurse = true) const;
  named_parameter_list named_parameters(bool recurse = true) const;

  // all members of the object, similar to iterating over dir(obj) in python
  attribute_list attributes(bool recurse = true) const;
  named_attribute_list named_attributes(bool recurse = true) const;

  // ---- Debug printing ------------------------------------------------------
  // The three flags select which details are included in the output.
  void dump(
      bool print_method_bodies,
      bool print_attr_values,
      bool print_param_values) const;

  // Same flags as dump(), but the text is returned as a string.
  // NOTE(review): `level` is presumably the nesting/indentation depth --
  // confirm against the definition.
  std::string dump_to_str(
      bool print_method_bodies,
      bool print_attr_values,
      bool print_param_values,
      int level) const;

  // ---- Training mode -------------------------------------------------------
  /// Enables "training" mode.
  void train(bool on = true);
  /// Calls train(false) to enable "eval" mode.
  /// Do not override this method, override `train()` instead.
  void eval() {
    train(false);
  }
  /// True if the module is in training mode. Reads the "training" attribute,
  /// passing `true` as the second argument to `attr()` (presumably the
  /// fallback used when the attribute is absent).
  bool is_training() const {
    const IValue training_flag = attr("training", true);
    return training_flag.toBool();
  }

  // ---- Device / dtype conversion -------------------------------------------
  /// Recursively casts all parameters to the given `dtype` and `device`.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  void to(at::Device device, at::ScalarType dtype, bool non_blocking = false);

  /// Recursively casts all parameters to the given dtype.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  void to(at::ScalarType dtype, bool non_blocking = false);

  /// Recursively moves all parameters to the given device.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  void to(at::Device device, bool non_blocking = false);

  // ---- Serialization -------------------------------------------------------
  // Each writer comes in two flavours: one targeting an output stream and one
  // targeting a file name. `extra_files` maps filename to content (see
  // ExtraFilesMap) and defaults to an empty map.
  void save(
      std::ostream& out,
      const ExtraFilesMap& extra_files = ExtraFilesMap()) const;

  void save(
      const std::string& filename,
      const ExtraFilesMap& extra_files = ExtraFilesMap()) const;

  void _save_for_mobile(
      std::ostream& out,
      const ExtraFilesMap& extra_files = ExtraFilesMap(),
      bool save_mobile_debug_info = false) const;

  void _save_for_mobile(
      const std::string& filename,
      const ExtraFilesMap& extra_files = ExtraFilesMap(),
      bool save_mobile_debug_info = false) const;

  // ---- Copying -------------------------------------------------------------
  Module copy() const;

  Module deepcopy() const;

  // Clones both the underlying `ClassType` and the module instance (data).
  // This function creates a new `ClassType` and returns a new instance that
  // has the same data as the current instance but with the new type; shared
  // ClassTypes are preserved as well.
  Module clone(bool inplace = false) const;

  void clone_method(const Module& orig, const std::string& name);

  // ---- Invocation ----------------------------------------------------------
  IValue operator()(std::vector<IValue> inputs);

  /// Convenience overload: wraps each argument in an `IValue` (perfectly
  /// forwarded) and delegates to the `Stack` overload below.
  template <typename... Types>
  IValue create_class(const c10::QualifiedName& name, Types&&... args) const {
    return create_class(name, {IValue(std::forward<Types>(args))...});
  }

  IValue create_class(const c10::QualifiedName& name, Stack stack) const;

  /// Two modules compare equal when their underlying object pointers compare
  /// equal.
  inline bool operator==(const Module& y) const noexcept {
    return _ivalue() == y._ivalue();
  }

 private:
  Module clone_impl(
      std::unordered_map<TypePtr, TypePtr>& type_remap,
      bool inplace,
      IValue::HashAliasedIValueMap memo) const;

  void clone_method(
      const Module& orig,
      const Function& method,
      const std::unordered_map<TypePtr, TypePtr>& type_remap);

  // Builds the qualified name of a method by using this module's type name
  // as the prefix and `basename` as the final component.
  c10::QualifiedName getNameForMethod(std::string basename) const {
    return QualifiedName(*type()->name(), std::move(basename));
  }

  // Takes an optional device and an optional dtype.
  // NOTE(review): presumably the shared implementation behind the public
  // to() overloads -- confirm against the definition.
  void to_impl(
      const c10::optional<at::Device>& device,
      const c10::optional<at::ScalarType>& dtype,
      bool non_blocking);
};

// C++ equivalent of the `torch.jit.freeze` API. See the documentation there
// for details.
//
// Takes the module to freeze by const reference and returns a Module by
// value. `preserved_attrs` defaults to c10::nullopt and `optimize_numerics`
// defaults to true.
// NOTE(review): `preserved_attrs` presumably names attributes that must be
// kept as-is by freezing -- confirm against the torch.jit.freeze docs.
TORCH_API Module freeze(
    const Module& module,
    c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt,
    bool optimize_numerics = true);

namespace detail {

// One position in a module traversal: a module plus an offset into its
// slots. slot_iterator_impl keeps a collection of these (`cursors_`) and
// brace-initializes them as SlotCursor{module, index}, so the member order
// must not change.
struct TORCH_API SlotCursor {
  Module module_; // the module this cursor points into
  int64_t i_; // slot offset, -1 indicates the module itself
};

} // namespace detail

// This iterator allows the (optionally recursive) enumeration of
// the members of a Module. It performs a depth-first pre-order
// traversal of the module. The Policy template parameter determines
// which slots of the object should be included. For instance,
// when iterating parameters, we return the parameter tensors,
// but skip modules, buffers, and other attributes.
// See ModulePolicy for comments about the Policy object's API.
template <typename Policy>
struct slot_iterator_impl {
  using SlotCursor = detail::SlotCursor;
  using value_type = typename Policy::value_type;
  slot_iterator_impl(
      Module root,
      bool recurse, // if true, do a depth-first search, otherwise, just look at
                    // slots of root
      bool return_module) // if true include root itself as the first thing
                          // visited (used in modules())
      : cursors_({SlotCursor{root, return_module ? -1 : 0}}),
        recurse_(recurse) {
    // advance iterator to first valid element (or the end, if empty)
    while_not_valid_next();
  }
  // empty cursors_, represents end of iteration
  slot_iterator_impl() : recurse_(false) {}
  value_type operator*() const {
    return Policy::create(cursors_, cur());
  }
  value_type operator->() const {
    return **this;
  }
  slot_iterator_impl& operator++() {
    next_valid();
    return *this;
  }
  slot_iterator_impl operator++(int) {
    // this is really expensive, should we delete it so people don't use it
    // instead of prefix?
    slot_iterator_impl old = *this;
    ++(*this);
    return old;
  }

 private:
  // return_module() is a corner case where instead of returning a submodule
  // of root, we are returning root itself, because we are iterating modules(),
  // which contains the root module itself.
  // It is represented with a single SlotCursor whose index is -1.
  bool return_module() const {
    return top().i_ == -1;
  }
  const SlotCursor& top() const {
Loading ...