Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / core / workspace.h

#ifndef CAFFE2_CORE_WORKSPACE_H_
#define CAFFE2_CORE_WORKSPACE_H_

#include "caffe2/core/common.h"
#include "caffe2/core/observer.h"

#include <atomic>
#include <climits>
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "c10/util/Registry.h"
#include "caffe2/core/blob.h"
#include "caffe2/core/net.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/signal_handler.h"
#include "caffe2/utils/threadpool/ThreadPool.h"

// Boolean flag declared here; its definition lives in a source file not shown.
// ~Workspace() reads it: when true, PrintBlobSizes() is called as each
// workspace is destroyed.
C10_DECLARE_bool(caffe2_print_blob_sizes_at_exit);

namespace caffe2 {

// Forward declaration of the network type owned through Workspace::NetMap.
// NOTE(review): probably redundant since "caffe2/core/net.h" is included
// above; harmless to keep.
class NetBase;

struct TORCH_API StopOnSignal {
  StopOnSignal()
      : handler_(std::make_shared<SignalHandler>(
            SignalHandler::Action::STOP,
            SignalHandler::Action::STOP)) {}

  StopOnSignal(const StopOnSignal& other) : handler_(other.handler_) {}

  bool operator()(int /*iter*/) {
    return handler_->CheckForSignals() != SignalHandler::Action::STOP;
  }

  std::shared_ptr<SignalHandler> handler_;
};

/**
 * Workspace is a class that holds all the related objects created during
 * runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of
 * all these objects and deals with the scaffolding logistics.
 *
 * Blob lookups (see HasBlob) resolve in this order: blobs owned locally, then
 * blobs forwarded by name from a parent workspace, then blobs of the optional
 * shared workspace. Every live Workspace registers itself with the bookkeeper
 * returned by bookkeeper(), which is what ForEach() iterates over.
 * Workspaces are neither copyable nor assignable.
 */
class TORCH_API Workspace {
 public:
  // Predicate polled while running a plan; returning false stops execution.
  typedef std::function<bool(int)> ShouldContinue;
  typedef CaffeMap<string, unique_ptr<Blob> > BlobMap;
  typedef CaffeMap<string, unique_ptr<NetBase> > NetMap;
  /**
   * Initializes an empty workspace rooted at ".".
   */
  Workspace() : Workspace(".", nullptr) {}

  /**
   * Initializes an empty workspace with the given root folder.
   *
   * For any operators that are going to interface with the file system, such
   * as load operators, they will write things under this root folder given
   * by the workspace.
   */
  explicit Workspace(const string& root_folder)
      : Workspace(root_folder, nullptr) {}

  /**
   * Initializes a workspace with a shared workspace.
   *
   * When we access a Blob, we will first try to access the blob that exists
   * in the local workspace, and if not, access the blob that exists in the
   * shared workspace. The caller keeps the ownership of the shared workspace
   * and is responsible for making sure that its lifetime is longer than the
   * created workspace.
   */
  explicit Workspace(const Workspace* shared) : Workspace(".", shared) {}

  /**
   * Initializes a workspace with a parent workspace and a blob name remapping
   * (new name -> parent blob name). No other blobs are inherited from the
   * parent workspace.
   *
   * Throws (via CAFFE_ENFORCE) if `shared` is null or if any mapped parent
   * blob does not exist in `shared`. The caller keeps ownership of `shared`
   * and must keep it alive for as long as this workspace uses the mappings.
   */
  Workspace(
      const Workspace* shared,
      const std::unordered_map<string, string>& forwarded_blobs)
      : Workspace(".", nullptr) {
    CAFFE_ENFORCE(shared, "Parent workspace must be specified");
    for (const auto& forwarded : forwarded_blobs) {
      CAFFE_ENFORCE(
          shared->HasBlob(forwarded.second),
          "Invalid parent workspace blob: ",
          forwarded.second);
      forwarded_blobs_[forwarded.first] =
          std::make_pair(shared, forwarded.second);
    }
  }

  /**
   * Initializes a workspace with a root folder and a shared workspace, and
   * registers it with the bookkeeper so ForEach() can see it.
   */
  Workspace(const string& root_folder, const Workspace* shared)
      : root_folder_(root_folder), shared_(shared), bookkeeper_(bookkeeper()) {
    std::lock_guard<std::mutex> guard(bookkeeper_->wsmutex);
    bookkeeper_->workspaces.insert(this);
  }

  ~Workspace() {
    if (FLAGS_caffe2_print_blob_sizes_at_exit) {
      PrintBlobSizes();
    }
    // This is why we have a bookkeeper_ shared_ptr instead of a naked static! A
    // naked static makes us vulnerable to out-of-order static destructor bugs.
    std::lock_guard<std::mutex> guard(bookkeeper_->wsmutex);
    bookkeeper_->workspaces.erase(this);
  }

  /**
   * Adds blob mappings from this workspace to blobs of a parent workspace.
   * Creates blobs under possibly new names that redirect read/write operations
   * to the blobs in the parent workspace.
   *
   * @param parent             pointer to the parent workspace
   * @param forwarded_blobs    map from new blob name to the blob name in the
   *                           parent's workspace
   * @param skip_defined_blobs if set, skips names that already exist in this
   *                           workspace; otherwise such names throw
   */
  void AddBlobMapping(
      const Workspace* parent,
      const std::unordered_map<string, string>& forwarded_blobs,
      bool skip_defined_blobs = false);

  /**
   * Converts previously mapped tensor blobs to local blobs, copying the values
   * from the parent workspace blobs into new local blobs. Names in `blobs`
   * that are not currently forwarded are ignored.
   *
   * Throws (via CAFFE_ENFORCE) if a forwarded parent blob is missing or does
   * not hold a Tensor.
   */
  template <class Context>
  void CopyForwardedTensors(const std::unordered_set<std::string>& blobs) {
    for (const auto& blob : blobs) {
      auto it = forwarded_blobs_.find(blob);
      if (it == forwarded_blobs_.end()) {
        continue;
      }
      // Copy the mapping out first: the entry is erased below, and nothing
      // after that point may refer into the destroyed map node.
      const Workspace* parent_ws = it->second.first;
      const string parent_name = it->second.second;
      const Blob* from_blob = parent_ws->GetBlob(parent_name);
      CAFFE_ENFORCE(from_blob);
      CAFFE_ENFORCE(
          from_blob->template IsType<Tensor>(),
          "Expected blob with tensor value: ",
          parent_name);
      forwarded_blobs_.erase(it);
      // CreateLocalBlob rather than CreateBlob: CreateBlob returns an already
      // existing blob of this name from the shared workspace, in which case
      // the copy below would overwrite that shared blob instead of producing
      // a new local one.
      auto* to_blob = CreateLocalBlob(blob);
      CAFFE_ENFORCE(to_blob);
      const auto& from_tensor = from_blob->template Get<Tensor>();
      auto* to_tensor = BlobGetMutableTensor(to_blob, Context::GetDeviceType());
      to_tensor->CopyFrom(from_tensor);
    }
  }

  /**
   * Return list of blobs owned by this Workspace, not including blobs
   * shared from parent workspace.
   */
  vector<string> LocalBlobs() const;

  /**
   * Return a list of blob names. This may be a bit slow since it will involve
   * creation of multiple temp variables. For best performance, simply use
   * HasBlob() and GetBlob().
   */
  vector<string> Blobs() const;

  /**
   * Return the root folder of the workspace.
   */
  const string& RootFolder() const { return root_folder_; }
  /**
   * Checks if a blob with the given name is visible from this workspace:
   * first among local blobs, then among forwarded blobs (resolved in their
   * parent workspace), then in the shared workspace, if any. A forwarded name
   * is answered by its parent alone and never falls through to shared_.
   */
  inline bool HasBlob(const string& name) const {
    if (blob_map_.count(name)) {
      return true;
    }
    const auto fwd = forwarded_blobs_.find(name);
    if (fwd != forwarded_blobs_.end()) {
      return fwd->second.first->HasBlob(fwd->second.second);
    }
    return shared_ != nullptr && shared_->HasBlob(name);
  }

  void PrintBlobSizes();

  /**
   * Creates a blob of the given name. The pointer to the blob is returned, but
   * the workspace keeps ownership of the pointer. If a blob of the given name
   * already exists, the creation is skipped and the existing blob is returned.
   */
  Blob* CreateBlob(const string& name);
  /**
   * Similar to CreateBlob(), but it creates a blob in the local workspace even
   * if another blob with the same name already exists in the parent workspace
   * -- in such case the new blob hides the blob in parent workspace. If a blob
   * of the given name already exists in the local workspace, the creation is
   * skipped and the existing blob is returned.
   */
  Blob* CreateLocalBlob(const string& name);
  /**
   * Remove the blob of the given name. Return true if removed and false if
   * not exist.
   * Will NOT remove from the shared workspace.
   */
  bool RemoveBlob(const string& name);
  /**
   * Gets the blob with the given name as a const pointer. If the blob does not
   * exist, a nullptr is returned.
   */
  const Blob* GetBlob(const string& name) const;
  /**
   * Gets the blob with the given name as a mutable pointer. If the blob does
   * not exist, a nullptr is returned.
   */
  Blob* GetBlob(const string& name);

  /**
   * Renames a local workspace blob. If blob is not found in the local blob list
   * or if the target name is already present in local or any parent blob list
   * the function will throw.
   */
  Blob* RenameBlob(const string& old_name, const string& new_name);

  /**
   * Creates a network with the given NetDef, and returns the pointer to the
   * network. If there is anything wrong during the creation of the network, a
   * nullptr is returned. The Workspace keeps ownership of the pointer.
   *
   * If there is already a net created in the workspace with the given name,
   * CreateNet will overwrite it if overwrite=true is specified. Otherwise, an
   * exception is thrown.
   */
  NetBase* CreateNet(const NetDef& net_def, bool overwrite = false);
  NetBase* CreateNet(
      const std::shared_ptr<const NetDef>& net_def,
      bool overwrite = false);
  /**
   * Gets the pointer to a created net. The workspace keeps ownership of the
   * network.
   */
  NetBase* GetNet(const string& net_name);
  /**
   * Deletes the instantiated network with the given name.
   */
  void DeleteNet(const string& net_name);
  /**
   * Finds and runs the instantiated network with the given name. If the network
   * does not exist or there are errors running the network, the function
   * returns false.
   */
  bool RunNet(const string& net_name);

  /**
   * Returns a list of names of the currently instantiated networks.
   */
  vector<string> Nets() const {
    vector<string> names;
    names.reserve(net_map_.size());
    for (const auto& entry : net_map_) {
      names.push_back(entry.first);
    }
    return names;
  }

  /**
   * Runs a plan that has multiple nets and execution steps.
   *
   * @param plan_def        the plan to execute
   * @param should_continue predicate polled during execution (it receives an
   *                        int, presumably an iteration index); execution
   *                        stops when it returns false. Defaults to
   *                        StopOnSignal, which returns false once a STOP
   *                        signal action is reported.
   */
  bool RunPlan(const PlanDef& plan_def,
               ShouldContinue should_continue = StopOnSignal{});

  /**
   * Returns a CPU threadpool instance for parallel execution of
   * work. The threadpool is created lazily; if no operators use it,
   * then no threadpool will be created.
   */
  ThreadPool* GetThreadPool();

  // RunOperatorOnce and RunNetOnce run an operator or net once. The difference
  // between RunNet and RunNetOnce lies in the fact that RunNet allows you to
  // have a persistent net object, while RunNetOnce creates a net and discards
  // it on the fly - this may make things like database read and random number
  // generators repeat the same thing over multiple calls.
  bool RunOperatorOnce(const OperatorDef& op_def);
  bool RunNetOnce(const NetDef& net_def);

  /**
   * Applies a function f on each workspace that currently exists.
   *
   * This function is thread safe and there is no race condition between
   * workspaces being passed to f in this thread and destroyed in another:
   * the bookkeeper mutex is held for the whole iteration.
   *
   * NOTE: since Workspace constructors and the destructor lock the same
   * (non-recursive) bookkeeper mutex, f must not create or destroy a
   * Workspace, or it will deadlock.
   */
  template <typename F>
  static void ForEach(F f) {
    auto bk = bookkeeper();
    std::lock_guard<std::mutex> guard(bk->wsmutex);
    for (Workspace* ws : bk->workspaces) {
      f(ws);
    }
  }

 public:
  // Presumably the position of the last failed operator within its net;
  // written by code outside this header. TODO confirm semantics.
  std::atomic<int> last_failed_op_net_position{};

 private:
  // Registry of live workspaces; guarded by wsmutex.
  struct Bookkeeper {
    std::mutex wsmutex;
    std::unordered_set<Workspace*> workspaces;
  };

  static std::shared_ptr<Bookkeeper> bookkeeper();

  // Blobs owned by this workspace.
  BlobMap blob_map_;
  const string root_folder_;
  // Optional fallback workspace for lookups; not owned.
  const Workspace* shared_;
  // Local name -> (parent workspace, parent blob name); parents not owned.
  std::unordered_map<string, std::pair<const Workspace*, string>>
      forwarded_blobs_;
  std::unique_ptr<ThreadPool> thread_pool_;
  std::mutex thread_pool_creation_mutex_;
  // Held as a shared_ptr so the registry outlives this workspace even during
  // static destruction (see ~Workspace).
  std::shared_ptr<Bookkeeper> bookkeeper_;
  NetMap net_map_;

  C10_DISABLE_COPY_AND_ASSIGN(Workspace);
};

}  // namespace caffe2

#endif  // CAFFE2_CORE_WORKSPACE_H_