#ifndef CAFFE2_OPERATORS_DO_OP_H_
#define CAFFE2_OPERATORS_DO_OP_H_
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/create_scope_op.h"
#include "caffe2/proto/caffe2_pb.h"
namespace caffe2 {
template <class Context>
class DoOp final : public Operator<Context> {
public:
explicit DoOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws), parent_ws_(ws) {
CAFFE_ENFORCE(
this->template HasSingleArgumentOfType<NetDef>("net"),
"net must be specified in Do operator");
net_def_ = this->template GetSingleArgument<NetDef>("net", NetDef());
is_gradient_op_ = operator_def.is_gradient_op();
copy_external_blobs_ =
this->template GetSingleArgument<bool>("copy_external_blobs", false);
reuse_workspace_ =
this->template GetSingleArgument<bool>("reuse_workspace", false);
CAFFE_ENFORCE(
!(is_gradient_op_ && reuse_workspace_),
"Gradient Do op requires use of stacked workspaces");
CAFFE_ENFORCE(
!(copy_external_blobs_ && reuse_workspace_),
"Reuse workspace and copy external blobs simultaneously in Do op");
const auto& inner_blobs =
this->template GetRepeatedArgument<std::string>("inner_blobs");
const auto& outer_blobs_idx =
this->template GetRepeatedArgument<int>("outer_blobs_idx");
CAFFE_ENFORCE_EQ(
inner_blobs.size(),
outer_blobs_idx.size(),
"Invalid blob bindings: different inner/outer blobs lengths");
const auto& outer_blob_names = checkAndGetOuterNames(operator_def);
std::unordered_set<std::string> used_outer_names;
for (size_t blob_idx = 0; blob_idx < inner_blobs.size(); ++blob_idx) {
CAFFE_ENFORCE(
!blob_bindings_.count(inner_blobs[blob_idx]),
"Invalid blob bindings: redefinition of inner blob " +
inner_blobs[blob_idx]);
CAFFE_ENFORCE(
outer_blobs_idx[blob_idx] >= 0 &&
outer_blobs_idx[blob_idx] < outer_blob_names.size(),
"Invalid blob bindings: outer blob index (" +
c10::to_string(outer_blobs_idx[blob_idx]) + ", inner name: " +
inner_blobs[blob_idx] + ") is out of bounds [0, " +
c10::to_string(outer_blob_names.size() - 1) + "]");
const auto& outer_name = outer_blob_names[outer_blobs_idx[blob_idx]];
CAFFE_ENFORCE(
!used_outer_names.count(outer_name),
"Reusage of outer name: " + outer_name);
used_outer_names.insert(outer_name);
blob_bindings_[inner_blobs[blob_idx]] = outer_name;
forwarded_inner_blobs_.insert(inner_blobs[blob_idx]);
}
std::unordered_set<std::string> all_outer_names(
outer_blob_names.begin(), outer_blob_names.end());
CAFFE_ENFORCE_EQ(
used_outer_names.size(),
all_outer_names.size(),
"Not all outer names are used in blob bindings");
}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override {
auto* ws_stack =
this->template Output<detail::WorkspaceStack>(OutputSize() - 1);
std::shared_ptr<Workspace> net_workspace;
if (is_gradient_op_) {
net_workspace =
ws_stack->popGradientWorkspace(parent_ws_, blob_bindings_);
} else {
if (reuse_workspace_ && !ws_stack->empty()) {
net_workspace =
ws_stack->reuseLastForwardWorkspace(parent_ws_, blob_bindings_);
} else {
net_workspace =
ws_stack->pushForwardWorkspace(parent_ws_, blob_bindings_);
}
}
CAFFE_ENFORCE(net_workspace, "Failed to initialize Do op workspace");
// TODO(iliacher): figure how to reuse existing net with a new workspace
auto* net = net_workspace->GetNet(net_def_.name());
if (!net) {
net = net_workspace->CreateNet(net_def_, true);
}
CAFFE_ENFORCE(net, "Failed to initialize subnet");
auto success = net->Run();
if (!is_gradient_op_ && copy_external_blobs_) {
net_workspace->template CopyForwardedTensors<Context>(
forwarded_inner_blobs_);
}
return success;
}
private:
// returns vector of input blob names followed by output blob names in
// operator definition order; ensures that input (output) names are unique,
// checks number of input (output) blobs
std::vector<std::string> checkAndGetOuterNames(
const OperatorDef& operator_def) const {
auto input_names = getInputBlobNames(operator_def);
CAFFE_ENFORCE(!input_names.empty(), "Expected at least one input blob");
std::string input_ws_blob = input_names.back(); // copy
// removing blob that holds pointer op workspace
input_names.pop_back();
std::unordered_set<std::string> all_input_names(
input_names.begin(), input_names.end());
CAFFE_ENFORCE_EQ(
input_names.size(), all_input_names.size(), "Duplicate input blobs");
auto output_names = getOutputBlobNames(operator_def);
CAFFE_ENFORCE(!output_names.empty(), "Expected at least one output blob");
const auto& output_ws_blob = output_names.back();
CAFFE_ENFORCE_EQ(
input_ws_blob,
output_ws_blob,
"Expected same input/output workspace blob");
// remove blob that holds pointer to op workspace
output_names.pop_back();
std::unordered_set<std::string> all_output_names(
output_names.begin(), output_names.end());
CAFFE_ENFORCE_EQ(
output_names.size(), all_output_names.size(), "Duplicate output blobs");
std::vector<std::string> outer_blob_names;
outer_blob_names.reserve(input_names.size() + output_names.size());
outer_blob_names.insert(
outer_blob_names.end(), input_names.begin(), input_names.end());
outer_blob_names.insert(
outer_blob_names.end(), output_names.begin(), output_names.end());
return outer_blob_names;
}
std::vector<std::string> getInputBlobNames(
const OperatorDef& operator_def) const {
std::vector<std::string> names;
names.reserve(operator_def.input_size());
for (auto idx = 0; idx < operator_def.input_size(); ++idx) {
names.push_back(operator_def.input(idx));
}
return names;
}
std::vector<std::string> getOutputBlobNames(
const OperatorDef& operator_def) const {
std::vector<std::string> names;
names.reserve(operator_def.output_size());
for (auto idx = 0; idx < operator_def.output_size(); ++idx) {
names.push_back(operator_def.output(idx));
}
return names;
}
std::unordered_map<std::string, std::string> blob_bindings_;
std::unordered_set<std::string> forwarded_inner_blobs_;
bool is_gradient_op_;
bool copy_external_blobs_;
bool reuse_workspace_;
NetDef net_def_;
Workspace* parent_ws_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_DO_OP_H_