#ifndef CAFFE2_CORE_WORKSPACE_H_
#define CAFFE2_CORE_WORKSPACE_H_
#include "caffe2/core/common.h"
#include "caffe2/core/observer.h"
#include <climits>
#include <cstddef>
#include <mutex>
#include <typeinfo>
#include <unordered_set>
#include <vector>
#include "c10/util/Registry.h"
#include "caffe2/core/blob.h"
#include "caffe2/core/net.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/signal_handler.h"
#include "caffe2/utils/threadpool/ThreadPool.h"
C10_DECLARE_bool(caffe2_print_blob_sizes_at_exit);
namespace caffe2 {
class NetBase;
struct TORCH_API StopOnSignal {
StopOnSignal()
: handler_(std::make_shared<SignalHandler>(
SignalHandler::Action::STOP,
SignalHandler::Action::STOP)) {}
StopOnSignal(const StopOnSignal& other) : handler_(other.handler_) {}
bool operator()(int /*iter*/) {
return handler_->CheckForSignals() != SignalHandler::Action::STOP;
}
std::shared_ptr<SignalHandler> handler_;
};
/**
* Workspace is a class that holds all the related objects created during
* runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of
* all these objects and deals with the scaffolding logistics.
*/
class TORCH_API Workspace {
public:
typedef std::function<bool(int)> ShouldContinue;
typedef CaffeMap<string, unique_ptr<Blob> > BlobMap;
typedef CaffeMap<string, unique_ptr<NetBase> > NetMap;
/**
* Initializes an empty workspace.
*/
Workspace() : Workspace(".", nullptr) {}
/**
* Initializes an empty workspace with the given root folder.
*
* For any operators that are going to interface with the file system, such
* as load operators, they will write things under this root folder given
* by the workspace.
*/
explicit Workspace(const string& root_folder)
: Workspace(root_folder, nullptr) {}
/**
* Initializes a workspace with a shared workspace.
*
* When we access a Blob, we will first try to access the blob that exists
* in the local workspace, and if not, access the blob that exists in the
* shared workspace. The caller keeps the ownership of the shared workspace
* and is responsible for making sure that its lifetime is longer than the
* created workspace.
*/
explicit Workspace(const Workspace* shared) : Workspace(".", shared) {}
/**
* Initializes workspace with parent workspace, blob name remapping
* (new name -> parent blob name), no other blobs are inherited from
* parent workspace
*/
Workspace(
const Workspace* shared,
const std::unordered_map<string, string>& forwarded_blobs)
: Workspace(".", nullptr) {
CAFFE_ENFORCE(shared, "Parent workspace must be specified");
for (const auto& forwarded : forwarded_blobs) {
CAFFE_ENFORCE(
shared->HasBlob(forwarded.second),
"Invalid parent workspace blob: ",
forwarded.second);
forwarded_blobs_[forwarded.first] =
std::make_pair(shared, forwarded.second);
}
}
/**
* Initializes a workspace with a root folder and a shared workspace.
*/
Workspace(const string& root_folder, const Workspace* shared)
: root_folder_(root_folder), shared_(shared), bookkeeper_(bookkeeper()) {
std::lock_guard<std::mutex> guard(bookkeeper_->wsmutex);
bookkeeper_->workspaces.insert(this);
}
~Workspace() {
if (FLAGS_caffe2_print_blob_sizes_at_exit) {
PrintBlobSizes();
}
// This is why we have a bookkeeper_ shared_ptr instead of a naked static! A
// naked static makes us vulnerable to out-of-order static destructor bugs.
std::lock_guard<std::mutex> guard(bookkeeper_->wsmutex);
bookkeeper_->workspaces.erase(this);
}
/**
* Adds blob mappings from workspace to the blobs from parent workspace.
* Creates blobs under possibly new names that redirect read/write operations
* to the blobs in the parent workspace.
* Arguments:
* parent - pointer to parent workspace
* forwarded_blobs - map from new blob name to blob name in parent's
* workspace skip_defined_blob - if set skips blobs with names that already
* exist in the workspace, otherwise throws exception
*/
void AddBlobMapping(
const Workspace* parent,
const std::unordered_map<string, string>& forwarded_blobs,
bool skip_defined_blobs = false);
/**
* Converts previously mapped tensor blobs to local blobs, copies values from
* parent workspace blobs into new local blobs. Ignores undefined blobs.
*/
template <class Context>
void CopyForwardedTensors(const std::unordered_set<std::string>& blobs) {
for (const auto& blob : blobs) {
if (!forwarded_blobs_.count(blob)) {
continue;
}
const auto& ws_blob = forwarded_blobs_[blob];
const auto* parent_ws = ws_blob.first;
auto* from_blob = parent_ws->GetBlob(ws_blob.second);
CAFFE_ENFORCE(from_blob);
CAFFE_ENFORCE(
from_blob->template IsType<Tensor>(),
"Expected blob with tensor value",
ws_blob.second);
forwarded_blobs_.erase(blob);
auto* to_blob = CreateBlob(blob);
CAFFE_ENFORCE(to_blob);
const auto& from_tensor = from_blob->template Get<Tensor>();
auto* to_tensor = BlobGetMutableTensor(to_blob, Context::GetDeviceType());
to_tensor->CopyFrom(from_tensor);
}
}
/**
* Return list of blobs owned by this Workspace, not including blobs
* shared from parent workspace.
*/
vector<string> LocalBlobs() const;
/**
* Return a list of blob names. This may be a bit slow since it will involve
* creation of multiple temp variables. For best performance, simply use
* HasBlob() and GetBlob().
*/
vector<string> Blobs() const;
/**
* Return the root folder of the workspace.
*/
const string& RootFolder() { return root_folder_; }
/**
* Checks if a blob with the given name is present in the current workspace.
*/
inline bool HasBlob(const string& name) const {
// First, check the local workspace,
// Then, check the forwarding map, then the parent workspace
if (blob_map_.count(name)) {
return true;
} else if (forwarded_blobs_.count(name)) {
const auto parent_ws = forwarded_blobs_.at(name).first;
const auto& parent_name = forwarded_blobs_.at(name).second;
return parent_ws->HasBlob(parent_name);
} else if (shared_) {
return shared_->HasBlob(name);
}
return false;
}
void PrintBlobSizes();
/**
* Creates a blob of the given name. The pointer to the blob is returned, but
* the workspace keeps ownership of the pointer. If a blob of the given name
* already exists, the creation is skipped and the existing blob is returned.
*/
Blob* CreateBlob(const string& name);
/**
* Similar to CreateBlob(), but it creates a blob in the local workspace even
* if another blob with the same name already exists in the parent workspace
* -- in such case the new blob hides the blob in parent workspace. If a blob
* of the given name already exists in the local workspace, the creation is
* skipped and the existing blob is returned.
*/
Blob* CreateLocalBlob(const string& name);
/**
* Remove the blob of the given name. Return true if removed and false if
* not exist.
* Will NOT remove from the shared workspace.
*/
bool RemoveBlob(const string& name);
/**
* Gets the blob with the given name as a const pointer. If the blob does not
* exist, a nullptr is returned.
*/
const Blob* GetBlob(const string& name) const;
/**
* Gets the blob with the given name as a mutable pointer. If the blob does
* not exist, a nullptr is returned.
*/
Blob* GetBlob(const string& name);
/**
* Renames a local workspace blob. If blob is not found in the local blob list
* or if the target name is already present in local or any parent blob list
* the function will throw.
*/
Blob* RenameBlob(const string& old_name, const string& new_name);
/**
* Creates a network with the given NetDef, and returns the pointer to the
* network. If there is anything wrong during the creation of the network, a
* nullptr is returned. The Workspace keeps ownership of the pointer.
*
* If there is already a net created in the workspace with the given name,
* CreateNet will overwrite it if overwrite=true is specified. Otherwise, an
* exception is thrown.
*/
NetBase* CreateNet(const NetDef& net_def, bool overwrite = false);
NetBase* CreateNet(
const std::shared_ptr<const NetDef>& net_def,
bool overwrite = false);
/**
* Gets the pointer to a created net. The workspace keeps ownership of the
* network.
*/
NetBase* GetNet(const string& net_name);
/**
* Deletes the instantiated network with the given name.
*/
void DeleteNet(const string& net_name);
/**
* Finds and runs the instantiated network with the given name. If the network
* does not exist or there are errors running the network, the function
* returns false.
*/
bool RunNet(const string& net_name);
/**
* Returns a list of names of the currently instantiated networks.
*/
vector<string> Nets() const {
vector<string> names;
for (auto& entry : net_map_) {
names.push_back(entry.first);
}
return names;
}
/**
* Runs a plan that has multiple nets and execution steps.
*/
bool RunPlan(const PlanDef& plan_def,
ShouldContinue should_continue = StopOnSignal{});
/*
* Returns a CPU threadpool instance for parallel execution of
* work. The threadpool is created lazily; if no operators use it,
* then no threadpool will be created.
*/
ThreadPool* GetThreadPool();
// RunOperatorOnce and RunNetOnce runs an operator or net once. The difference
// between RunNet and RunNetOnce lies in the fact that RunNet allows you to
// have a persistent net object, while RunNetOnce creates a net and discards
// it on the fly - this may make things like database read and random number
// generators repeat the same thing over multiple calls.
bool RunOperatorOnce(const OperatorDef& op_def);
bool RunNetOnce(const NetDef& net_def);
/**
* Applies a function f on each workspace that currently exists.
*
* This function is thread safe and there is no race condition between
* workspaces being passed to f in this thread and destroyed in another.
*/
template <typename F>
static void ForEach(F f) {
auto bk = bookkeeper();
std::lock_guard<std::mutex> guard(bk->wsmutex);
for (Workspace* ws : bk->workspaces) {
f(ws);
}
}
public:
std::atomic<int> last_failed_op_net_position{};
private:
struct Bookkeeper {
std::mutex wsmutex;
std::unordered_set<Workspace*> workspaces;
};
static std::shared_ptr<Bookkeeper> bookkeeper();
BlobMap blob_map_;
const string root_folder_;
const Workspace* shared_;
std::unordered_map<string, std::pair<const Workspace*, string>>
forwarded_blobs_;
std::unique_ptr<ThreadPool> thread_pool_;
std::mutex thread_pool_creation_mutex_;
std::shared_ptr<Bookkeeper> bookkeeper_;
NetMap net_map_;
C10_DISABLE_COPY_AND_ASSIGN(Workspace);
};
} // namespace caffe2
#endif // CAFFE2_CORE_WORKSPACE_H_