Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / torch / csrc / jit / api / compilation_unit.h

#pragma once
#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/name_mangler.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/utils/memory.h>

#include <ATen/core/function_schema.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>

#include <functional>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {

struct Def;
struct Property;
struct ClassDef;
struct SugaredValue;
struct Resolver;

using ResolverPtr = std::shared_ptr<Resolver>;
struct Self {
  virtual ~Self() = default;
  virtual std::shared_ptr<SugaredValue> makeSugared(Value* v) const = 0;
  virtual ClassTypePtr getClassType() const = 0;
};

// A CompilationUnit is a list of named Functions
// with helper methods to iterate the list or invoke the function.
// Classes have a CompilationUnit holding the class methods,
// and Modules have a CompilationUnit holding the Functions that
// are used to implement their Methods

struct TORCH_API CompilationUnit {
  enum class FunctionType { Method, Hook, PreHook };
  // constructor that takes a set of functions to compile using the native
  // resolver
  explicit CompilationUnit(const std::string& source);
  CompilationUnit() = default;

  CompilationUnit& operator=(CompilationUnit&&) = default;
  CompilationUnit(CompilationUnit&&) = default;
  CompilationUnit& operator=(const CompilationUnit&) = delete;
  CompilationUnit(const CompilationUnit&) = delete;

  Function* find_function(const c10::QualifiedName& name) const {
    auto it = dict_.find(name);
    if (it == dict_.end()) {
      return nullptr;
    }
    return functions_[it->second].get();
  }

  Function& get_function(const c10::QualifiedName& name) const {
    if (auto r = find_function(name)) {
      return *r;
    }
    TORCH_CHECK(false, "attempted to get undefined function ", name.name());
  }

  void set_optimized(bool o) {
    TORCH_WARN(
        "CompilationUnit::set_optimized() is deprecated and has no effect. "
        "Please use setGraphExecutorOptimize()");
  }

  bool is_optimized() const {
    TORCH_WARN(
        "CompilationUnit::is_optimized() is deprecated and always returns true. "
        "Please use getGraphExecutorOptimize()");
    return true;
  }

  // for historic reasons, these are defined in ir_emitter.cpp
  // Returns the list of Functions just defined.
  std::vector<Function*> define(
      const c10::optional<c10::QualifiedName>& prefix,
      const std::vector<Property>& properties,
      const std::vector<ResolverPtr>& propResolvers,
      const std::vector<Def>& definitions,
      const std::vector<ResolverPtr>&
          defResolvers, /* determines how we handle free
                     variables in each definition*/
      // if non-null, the first argument to each def, is bound to this value
      const Self* self,
      // see [name mangling]
      bool shouldMangle = false);

  void define_hooks(
      const c10::optional<c10::QualifiedName>& prefix,
      const std::vector<Def>& hookDefs,
      const std::vector<ResolverPtr>& hookResolvers,
      const std::vector<Def>& preHookDefs,
      const std::vector<ResolverPtr>& preHookResolvers,
      const Self* self,
      bool shouldMangle = false);

  // same as above but parse the definitions from source
  // Returns the list of Functions just defined.
  std::vector<Function*> define(
      // prefix namespace to put all the defined functions into
      const c10::optional<c10::QualifiedName>& prefix,
      const std::string& source,
      const ResolverPtr& resolver,
      const Self* self);

  void define_interface(
      const c10::QualifiedName& qualifiedName,
      const ClassDef& classDef,
      ResolverPtr rcb,
      bool is_module = false);

  Function* create_function(
      c10::QualifiedName name,
      std::shared_ptr<Graph> graph,
      bool shouldMangle = false) {
    if (shouldMangle) {
      name = mangle(name);
    }
    auto fn = torch::make_unique<GraphFunction>(
        std::move(name), std::move(graph), nullptr);
    auto ret = fn.get();
    register_function(std::move(fn));
    return ret;
  }

  std::vector<Function*> get_functions() const {
    return fmap(functions_, [](const std::unique_ptr<Function>& fn) {
      return fn.get();
    });
  }

  /// Run a method from this compilation.
  ///
  /// For example:
  /// @code
  ///   IValue output = module->run("relu_script", a, b);
  /// @endcode
  ///
  /// To get a compile a module from a source string, see torch::jit::compile
  ///
  /// @param method_name The name of the method to run
  /// @param args Arguments to be passed to the method
  /// @return An IValue containing the return value (or values if it is a tuple)
  /// from the method
  template <typename... Types>
  IValue run_method(const c10::QualifiedName& method_name, Types&&... args) {
    return get_function(method_name)({IValue(std::forward<Types>(args))...});
  }

  void drop_all_functions() {
    dict_.clear();
    functions_.clear();
  }

  /**
   * Register a class as being owned by this compilation unit.
   */
  void register_type(c10::NamedTypePtr namedType) {
    // TODO: class types cannot be redefined because we have no way right now
    // of invalidating their methods. NamedTuples are fine though, since they
    // don't have methods.
    TORCH_CHECK(
        0 == classDict_.count(*namedType->name()),
        "class '",
        namedType->name()->qualifiedName(),
        "' already defined.");
    classes_.push_back(std::move(namedType));
    classDict_[*classes_.back()->name()] = classes_.size() - 1;
  };

  c10::ClassTypePtr get_class(const c10::QualifiedName& name) const {
    auto type = get_type(name);
    if (!type) {
      return nullptr;
    }
    return type->cast<c10::ClassType>();
  }

  c10::InterfaceTypePtr get_interface(const c10::QualifiedName& name) const {
    auto type = get_type(name);
    if (!type) {
      return nullptr;
    }
    return type->cast<c10::InterfaceType>();
  }

  c10::TupleTypePtr get_named_tuple(const c10::QualifiedName& name) const {
    for (const auto& cls : classes_) {
      if (cls->name()->qualifiedName() == name.qualifiedName()) {
        return cls->expect<TupleType>();
      }
    }
    return nullptr;
  }

  c10::NamedTypePtr get_type(const c10::QualifiedName& name) const {
    auto it = classDict_.find(name);
    if (it == classDict_.end()) {
      return nullptr;
    }
    return classes_[it->second];
  }

  // For testing: clear all Python-defined classes to ensure that unit tests
  // have isolation.
  void _clear_python_cu() {
    // Delete all the associated class methods
    for (const auto& type : classes_) {
      if (auto cls = type->cast<ClassType>()) {
        for (auto method : cls->methods()) {
          // Tombstone the method in the compilation unit.
          // Don't erase because the dict_
          auto it = dict_.find(method->qualname());
          TORCH_INTERNAL_ASSERT(it != dict_.end());
          functions_[it->second] = nullptr;
          // Erase in our big lookup table
          dict_.erase(it);
        }
        // Classes can have multiple pointers to the same hook,
        // need to make sure to not delete it twice
        std::unordered_set<Function*> hooks_to_delete;
        for (const auto& hook : cls->getForwardHooks()) {
          hooks_to_delete.insert(hook);
        }
        for (const auto& pre_hook : cls->getForwardPreHooks()) {
          hooks_to_delete.insert(pre_hook);
        }
        for (const auto& hook : hooks_to_delete) {
          // Tombstone the hook in the compilation unit.
          auto it = dict_.find(hook->qualname());
          if (it != dict_.end()) {
            functions_[it->second] = nullptr;
            // Erase in our big lookup table
            dict_.erase(it);
          }
        }
      }
    }
    classes_.clear();
    classDict_.clear();
  }

  // [Internal Only] Remove method.
  // Note Used for freezing.
  void unsafeRemoveMethod(const c10::QualifiedName& method_name) {
    auto it = dict_.find(method_name);
    TORCH_CHECK(
        it != dict_.end(),
        "method '",
        method_name.qualifiedName(),
        "' does not exist.");
    functions_[it->second] = nullptr;
    dict_.erase(it);
  }

  // [name mangling] All code objects must have a unique qualified name in a
  // CompilationUnit. In Python, sometimes functions won't have unique qualified
  // name (for example, nested functions). So we mangle Python functions to
  // ensure that they are uniquely named.
  //
  // We also use mangling to distinguish different Module instances. Since each
  // Module is a singleton class instance, different instances of the same
  // Python Module will have different types but the same qualified name.
  c10::QualifiedName mangle(const c10::QualifiedName& name) const {
    auto mangled = name;
    while (get_type(mangled) || find_function(mangled)) {
      mangled = mangler_.mangle(mangled);
    }
    return mangled;
  }

 private:
  std::unique_ptr<Function> define(
      const c10::optional<c10::QualifiedName>& prefix,
      const Def& def,
      const ResolverPtr& resolver,
      const Self* self,
      const std::unordered_map<std::string, Function*>& function_table,
      bool shouldMangle = false,
      FunctionType type = FunctionType::Method) const;

  // Define a property on \p self.
  struct PropertyPair;
  PropertyPair define_property(
      const c10::optional<c10::QualifiedName>& prefix,
      const Property& prop,
      const ResolverPtr& resolver,
      const Self* self,
      const std::unordered_map<std::string, Function*>& function_table,
      bool shouldMangle = false) const;

  Function& register_function(std::unique_ptr<Function> fn) {
    TORCH_CHECK(
        0 == dict_.count(fn->qualname().qualifiedName()),
        "method '",
        fn->qualname().qualifiedName(),
        "' already defined.");
    functions_.emplace_back(std::move(fn));
    dict_[functions_.back()->qualname()] = functions_.size() - 1;
    return *functions_.back();
  }
  std::vector<std::unique_ptr<Function>> functions_;
  // for fast lookup
  std::unordered_map<c10::QualifiedName, size_t> dict_;
  std::unordered_map<c10::QualifiedName, size_t> classDict_;

  // [class ownership] Right now there aree two relationships between classes
  // and compilation units:
  // 1. Classes have compilation units internally that hold their methods.
  // 2. On load, the TypePtrs of any imported classes are owned by the main
  // module's compilation unit.
  std::vector<c10::NamedTypePtr> classes_;

  mutable NameMangler mangler_;
};

// An owning pointer to a Function. Just a pair of a raw Function ptr and it's
// owning CU. We need this because pybind requires a ref-counted way to refer to
// Functions.
struct StrongFunctionPtr {
  StrongFunctionPtr(std::shared_ptr<CompilationUnit> cu, Function* function)
      : cu_(std::move(cu)), function_(function) {
    TORCH_INTERNAL_ASSERT(cu_);
    TORCH_INTERNAL_ASSERT(function_);
  }
  std::shared_ptr<CompilationUnit> cu_;
  Function* function_;
};

namespace script {
Loading ...