#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/object.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/named_value.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/api/include/torch/ordered_dict.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/utils/memory.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <functional>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
// This file contains classes which assist in desugaring Python style
// modules and their methods into flattened graphs which don't have any
// function calls.
namespace torch {
namespace jit {
using ::c10::Argument;
using ::c10::FunctionSchema;
using ::c10::QualifiedName;
// Map which stores filename to content.
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
using ModulePtr = c10::intrusive_ptr<c10::ivalue::Object>;
struct Module;
// A simple (name, value) pair. Instantiations of this template are what the
// named_* slot lists (named_modules(), named_parameters(), ...) yield, so a
// member can be returned together with its name.
template <typename T>
struct Named {
  std::string name;
  T value;
};
// Common instantiations used by Module's named accessors.
using NameModule = Named<Module>;
using NameValue = Named<IValue>;
using NameTensor = Named<at::Tensor>;
namespace detail {
// Policy types that parameterize slot_list_impl / slot_iterator_impl and
// select which kinds of slots an iteration visits (submodules, parameters,
// plain attributes, or buffers). Defined later in this file.
struct TORCH_API ModulePolicy;
struct TORCH_API ParameterPolicy;
struct TORCH_API AttributePolicy;
struct TORCH_API BufferPolicy;
// Adapts a policy P so the iteration yields the value together with its
// name (see the Named<T> helper and the named_* aliases below).
template <typename P>
struct NamedPolicy;
} // namespace detail
// Convenience aliases for the concrete list types returned by Module's
// accessors (modules(), named_parameters(), attributes(), ...).
using module_list = slot_list_impl<detail::ModulePolicy>;
using named_module_list =
    slot_list_impl<detail::NamedPolicy<detail::ModulePolicy>>;
using parameter_list = slot_list_impl<detail::ParameterPolicy>;
using named_parameter_list =
    slot_list_impl<detail::NamedPolicy<detail::ParameterPolicy>>;
using attribute_list = slot_list_impl<detail::AttributePolicy>;
using named_attribute_list =
    slot_list_impl<detail::NamedPolicy<detail::AttributePolicy>>;
using buffer_list = slot_list_impl<detail::BufferPolicy>;
using named_buffer_list =
    slot_list_impl<detail::NamedPolicy<detail::BufferPolicy>>;
// Callback that resolves a module from a path of submodule names
// (e.g. {"encoder", "layer0"}).
using ModuleLookup = std::function<Module(const std::vector<std::string>&)>;
// A scripted module: an Object (class instance) whose attributes are
// parameters, buffers, other attributes, and submodules, plus compiled
// methods. Mirrors the Python-side torch.jit.ScriptModule API.
struct TORCH_API Module : public Object {
  explicit Module(c10::QualifiedName class_name);
  Module(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
  Module() = default;
  Module(
      c10::QualifiedName,
      std::shared_ptr<CompilationUnit> cu,
      bool shouldMangle = false);
  // NOTE: intentionally non-explicit — a Module is a thin wrapper around its
  // underlying object value, so implicit conversion from ModulePtr is relied
  // upon by callers.
  Module(ModulePtr module_value) : Object(std::move(module_value)) {}
  ~Module() = default;

  /// Deprecated: has no effect. Optimization is now controlled globally via
  /// setGraphExecutorOptimize().
  void set_optimized(bool o) {
    (void)o; // unused; parameter kept for signature backward compatibility
    TORCH_WARN(
        "Module::set_optimized() is deprecated and has no effect. "
        "Please use setGraphExecutorOptimize()");
  }

  /// Deprecated: always returns true. Use getGraphExecutorOptimize().
  bool is_optimized() const {
    TORCH_WARN(
        "Module::is_optimized() is deprecated and always returns true. "
        "Please use getGraphExecutorOptimize()");
    return true;
  }

  /// Runs this module's `forward` method on `inputs` and returns its result.
  IValue forward(std::vector<IValue> inputs) {
    return get_method("forward")(std::move(inputs));
  }

  // In script modules, buffers are Tensors attribute that are _not_ registered
  // as parameters. This is different than in nn.Module where there is a special
  // register_buffer method. With this simplification, we only need to track
  // whether a slot is a parameter to be able to classify it.
  void register_buffer(const std::string& name, at::Tensor v) {
    bool is_param = false;
    bool is_buffer = true;
    type()->addOrCheckAttribute(name, TensorType::get(), is_param, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  /// Registers `v` as a parameter, or as a buffer when `is_buffer` is true
  /// (a tensor slot is either a parameter or a buffer, never both).
  void register_parameter(
      const std::string& name,
      at::Tensor v,
      bool is_buffer) {
    type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  /// Registers a generic attribute of type `t` with initial value `v`.
  void register_attribute(
      const std::string& name,
      const TypePtr& t,
      IValue v,
      bool is_param = false,
      bool is_buffer = false) {
    type()->addOrCheckAttribute(name, t, is_param, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  /// Registers `module` as a submodule under `name`.
  void register_module(const std::string& name, const Module& module) {
    type()->addOrCheckAttribute(name, module.type());
    _ivalue()->setAttr(name, module._ivalue());
  }

  /// Applies `fn` to this module and, recursively, to every submodule.
  void apply(const std::function<void(Module&)>& fn);

  buffer_list buffers(bool recurse = true) const;
  named_buffer_list named_buffers(bool recurse = true) const;

  module_list children() const; // direct modules
  named_module_list named_children() const;
  module_list modules() const; // all modules, including this one, recursively
  named_module_list named_modules() const;

  // all tensors involved in gradient optimization
  parameter_list parameters(bool recurse = true) const;
  named_parameter_list named_parameters(bool recurse = true) const;

  // all members of the object, similar to iterating over dir(obj) in python
  attribute_list attributes(bool recurse = true) const;
  named_attribute_list named_attributes(bool recurse = true) const;

  /// Prints a human-readable description of the module; the flags control
  /// how much detail (method bodies, attribute/parameter values) is shown.
  void dump(
      bool print_method_bodies,
      bool print_attr_values,
      bool print_param_values) const;

  /// Like dump(), but returns the description as a string; `level` is the
  /// indentation depth used for nested submodules.
  std::string dump_to_str(
      bool print_method_bodies,
      bool print_attr_values,
      bool print_param_values,
      int level) const;

  /// Enables "training" mode.
  void train(bool on = true);
  /// Calls train(false) to enable "eval" mode.
  /// Do not override this method, override `train()` instead.
  void eval() {
    train(/*on=*/false);
  }
  /// True if the module is in training mode.
  bool is_training() const {
    return attr("training", true).toBool();
  }

  /// Recursively casts all parameters to the given `dtype` and `device`.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  void to(at::Device device, at::ScalarType dtype, bool non_blocking = false);

  /// Recursively casts all parameters to the given dtype.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  void to(at::ScalarType dtype, bool non_blocking = false);

  /// Recursively moves all parameters to the given device.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  void to(at::Device device, bool non_blocking = false);

  /// Serializes the module (optionally with `extra_files` entries written
  /// into the archive).
  void save(
      std::ostream& out,
      const ExtraFilesMap& extra_files = ExtraFilesMap()) const;

  void save(
      const std::string& filename,
      const ExtraFilesMap& extra_files = ExtraFilesMap()) const;

  /// Serializes the module in the mobile (lite interpreter) format.
  void _save_for_mobile(
      std::ostream& out,
      const ExtraFilesMap& extra_files = ExtraFilesMap(),
      bool save_mobile_debug_info = false) const;

  void _save_for_mobile(
      const std::string& filename,
      const ExtraFilesMap& extra_files = ExtraFilesMap(),
      bool save_mobile_debug_info = false) const;

  Module copy() const;

  Module deepcopy() const;

  // Clones both the underlying `ClassType` and the module instance(data), this
  // function creates a new `ClassType` and returns a new instance that has the
  // same data as the current instance but with the new type, shared ClassType
  // will be preserved as well
  Module clone(bool inplace = false) const;

  /// Copies the method `name` from `orig` into this module.
  void clone_method(const Module& orig, const std::string& name);

  /// Shorthand for forward(): invokes the `forward` method.
  IValue operator()(std::vector<IValue> inputs);

  /// Constructs an instance of the class `name` (registered in this module's
  /// compilation unit) from the given constructor arguments.
  template <typename... Types>
  IValue create_class(const c10::QualifiedName& name, Types&&... args) const {
    return create_class(name, {IValue(std::forward<Types>(args))...});
  }

  IValue create_class(const c10::QualifiedName& name, Stack stack) const;

  /// Two Modules compare equal iff they wrap the same underlying object.
  inline bool operator==(const Module& y) const noexcept {
    return _ivalue() == y._ivalue();
  }

 private:
  Module clone_impl(
      std::unordered_map<TypePtr, TypePtr>& type_remap,
      bool inplace,
      IValue::HashAliasedIValueMap memo) const;

  void clone_method(
      const Module& orig,
      const Function& method,
      const std::unordered_map<TypePtr, TypePtr>& type_remap);

  // Qualifies `basename` with this module's class name.
  c10::QualifiedName getNameForMethod(std::string basename) const {
    return QualifiedName(*type()->name(), std::move(basename));
  }

  // Shared implementation backing the three public to() overloads.
  void to_impl(
      const c10::optional<at::Device>& device,
      const c10::optional<at::ScalarType>& dtype,
      bool non_blocking);
};
// C++ equivalent api of `torch.jit.freeze`. See documentation there for
// details.
// NOTE(review): parameters presumably mirror the Python arguments of the
// same names — `preserved_attrs` lists attributes to keep un-frozen
// (c10::nullopt for the default set), and `optimize_numerics` gates
// numerics-affecting optimization passes. Confirm against torch.jit.freeze
// docs.
TORCH_API Module freeze(
    const Module& module,
    c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt,
    bool optimize_numerics = true);
namespace detail {
// A (module, slot-offset) pair: one stack frame of the depth-first module
// traversal performed by slot_iterator_impl (defined below).
struct TORCH_API SlotCursor {
  Module module_;
  int64_t i_; // slot offset, -1 indicates the module itself
};
} // namespace detail
// This iterator allows the (optionally recursive) enumeration of
// the members of a Module. It performs a depth-first pre-order
// traversal of the module. The Policy template parameter determines
// which slots of the object should be included. For instance,
// when iterating parameters, we return the parameter tensors,
// but skip modules, buffers, and other attributes.
// See ModulePolicy for comments about Policy object's API.
template <typename Policy>
struct slot_iterator_impl {
using SlotCursor = detail::SlotCursor;
using value_type = typename Policy::value_type;
slot_iterator_impl(
Module root,
bool recurse, // if true, do a depth-first search, otherwise, just look at
// slots of root
bool return_module) // if true include root itself as the first thing
// visited (used in modules())
: cursors_({SlotCursor{root, return_module ? -1 : 0}}),
recurse_(recurse) {
// advance iterator to first valid element (or the end, if empty)
while_not_valid_next();
}
// empty cursors_, represents end of iteration
slot_iterator_impl() : recurse_(false) {}
value_type operator*() const {
return Policy::create(cursors_, cur());
}
value_type operator->() const {
return **this;
}
slot_iterator_impl& operator++() {
next_valid();
return *this;
}
slot_iterator_impl operator++(int) {
// this is really expensive, should we delete it so people don't use it
// instead of prefix?
slot_iterator_impl old = *this;
++(*this);
return old;
}
private:
// return_module() is a corner case where instead of returning a submodule
// of root, we are returning root itself, because we are iterating modules(),
// which contains the root module itself.
// It is represented with a single SlotCursor whose index is -1.
bool return_module() const {
return top().i_ == -1;
}
const SlotCursor& top() const {
Loading ...