Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / do_op.h

#ifndef CAFFE2_OPERATORS_DO_OP_H_
#define CAFFE2_OPERATORS_DO_OP_H_

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/create_scope_op.h"
#include "caffe2/proto/caffe2_pb.h"

namespace caffe2 {

/**
 * Do operator: runs the subnet given by the "net" argument inside a workspace
 * obtained from a detail::WorkspaceStack.
 *
 * Blob layout: the operator's last input and last output must name the same
 * blob; RunOnDevice reads it as a detail::WorkspaceStack via
 * Output(OutputSize() - 1). The remaining input names followed by the
 * remaining output names form the list of "outer" blob names.
 *
 * Arguments:
 *   net                 (NetDef, required) subnet to execute.
 *   inner_blobs         (repeated string) blob names used inside the subnet.
 *   outer_blobs_idx     (repeated int) for each inner blob, an index into the
 *                       outer-name list; each pair forms one blob binding.
 *                       Every outer name must be bound exactly once.
 *   copy_external_blobs (bool, default false) after a forward run, call
 *                       CopyForwardedTensors<Context> on the bound inner blobs.
 *   reuse_workspace     (bool, default false) reuse the last forward workspace
 *                       on the stack (if any) instead of pushing a new one.
 *
 * Gradient Do ops (OperatorDef::is_gradient_op) obtain their workspace via
 * WorkspaceStack::popGradientWorkspace and may not set reuse_workspace.
 * reuse_workspace and copy_external_blobs are mutually exclusive.
 */
template <class Context>
class DoOp final : public Operator<Context> {
 public:
  explicit DoOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws), parent_ws_(ws) {
    CAFFE_ENFORCE(
        this->template HasSingleArgumentOfType<NetDef>("net"),
        "net must be specified in Do operator");
    net_def_ = this->template GetSingleArgument<NetDef>("net", NetDef());
    is_gradient_op_ = operator_def.is_gradient_op();
    copy_external_blobs_ =
        this->template GetSingleArgument<bool>("copy_external_blobs", false);
    reuse_workspace_ =
        this->template GetSingleArgument<bool>("reuse_workspace", false);
    CAFFE_ENFORCE(
        !(is_gradient_op_ && reuse_workspace_),
        "Gradient Do op requires use of stacked workspaces");
    CAFFE_ENFORCE(
        !(copy_external_blobs_ && reuse_workspace_),
        "Reuse workspace and copy external blobs simultaneously in Do op");

    const auto inner_blobs =
        this->template GetRepeatedArgument<std::string>("inner_blobs");
    const auto outer_blobs_idx =
        this->template GetRepeatedArgument<int>("outer_blobs_idx");
    CAFFE_ENFORCE_EQ(
        inner_blobs.size(),
        outer_blobs_idx.size(),
        "Invalid blob bindings: different inner/outer blobs lengths");

    const auto outer_blob_names = checkAndGetOuterNames(operator_def);
    std::unordered_set<std::string> used_outer_names;
    for (size_t blob_idx = 0; blob_idx < inner_blobs.size(); ++blob_idx) {
      const std::string& inner_name = inner_blobs[blob_idx];
      const int outer_idx = outer_blobs_idx[blob_idx];
      CAFFE_ENFORCE(
          !blob_bindings_.count(inner_name),
          "Invalid blob bindings: redefinition of inner blob " + inner_name);
      // The cast is safe because it is only evaluated once outer_idx >= 0.
      // The bound is printed as the half-open range [0, size) so the message
      // stays correct when there are no outer blobs (size() - 1 would wrap).
      CAFFE_ENFORCE(
          outer_idx >= 0 &&
              static_cast<size_t>(outer_idx) < outer_blob_names.size(),
          "Invalid blob bindings: outer blob index (" +
              c10::to_string(outer_idx) + ", inner name: " + inner_name +
              ") is out of bounds [0, " +
              c10::to_string(outer_blob_names.size()) + ")");
      const std::string& outer_name = outer_blob_names[outer_idx];
      CAFFE_ENFORCE(
          !used_outer_names.count(outer_name),
          "Reusage of outer name: " + outer_name);
      used_outer_names.insert(outer_name);
      blob_bindings_[inner_name] = outer_name;
      forwarded_inner_blobs_.insert(inner_name);
    }
    // Compare against the deduplicated set: a blob name may appear both among
    // the inputs and among the outputs, and only needs to be bound once.
    const std::unordered_set<std::string> all_outer_names(
        outer_blob_names.begin(), outer_blob_names.end());
    CAFFE_ENFORCE_EQ(
        used_outer_names.size(),
        all_outer_names.size(),
        "Not all outer names are used in blob bindings");
  }

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  /// Runs the subnet once.
  ///
  /// The workspace is selected from the WorkspaceStack held in the last
  /// output blob:
  ///   - gradient ops: popGradientWorkspace;
  ///   - forward ops with reuse_workspace and a non-empty stack:
  ///     reuseLastForwardWorkspace;
  ///   - otherwise: pushForwardWorkspace.
  /// The subnet is looked up by name in that workspace and created there if
  /// absent. For forward ops with copy_external_blobs set,
  /// CopyForwardedTensors<Context> is called on the bound inner blobs after
  /// the run (regardless of the run's result).
  /// @return the value returned by the subnet's Run().
  bool RunOnDevice() override {
    auto* ws_stack =
        this->template Output<detail::WorkspaceStack>(OutputSize() - 1);
    std::shared_ptr<Workspace> net_workspace;
    if (is_gradient_op_) {
      net_workspace =
          ws_stack->popGradientWorkspace(parent_ws_, blob_bindings_);
    } else if (reuse_workspace_ && !ws_stack->empty()) {
      net_workspace =
          ws_stack->reuseLastForwardWorkspace(parent_ws_, blob_bindings_);
    } else {
      net_workspace =
          ws_stack->pushForwardWorkspace(parent_ws_, blob_bindings_);
    }
    CAFFE_ENFORCE(net_workspace, "Failed to initialize Do op workspace");

    // TODO(iliacher): figure how to reuse existing net with a new workspace
    auto* net = net_workspace->GetNet(net_def_.name());
    if (!net) {
      net = net_workspace->CreateNet(net_def_, true);
    }
    CAFFE_ENFORCE(net, "Failed to initialize subnet");
    const bool success = net->Run();
    if (!is_gradient_op_ && copy_external_blobs_) {
      net_workspace->template CopyForwardedTensors<Context>(
          forwarded_inner_blobs_);
    }
    return success;
  }

 private:
  // Returns the operator's input blob names followed by its output blob names,
  // in definition order, with the trailing workspace-stack blob dropped from
  // each list. Enforces that:
  //   - there is at least one input and at least one output;
  //   - the last input and the last output name the same blob;
  //   - the remaining inputs contain no duplicates;
  //   - the remaining outputs contain no duplicates.
  std::vector<std::string> checkAndGetOuterNames(
      const OperatorDef& operator_def) const {
    auto input_names = getInputBlobNames(operator_def);
    CAFFE_ENFORCE(!input_names.empty(), "Expected at least one input blob");
    // Copy (not a reference): the element is removed from the vector below.
    const std::string input_ws_blob = input_names.back();
    // Drop the blob that holds the pointer to the op workspace.
    input_names.pop_back();

    const std::unordered_set<std::string> all_input_names(
        input_names.begin(), input_names.end());
    CAFFE_ENFORCE_EQ(
        input_names.size(), all_input_names.size(), "Duplicate input blobs");

    auto output_names = getOutputBlobNames(operator_def);
    CAFFE_ENFORCE(!output_names.empty(), "Expected at least one output blob");
    CAFFE_ENFORCE_EQ(
        input_ws_blob,
        output_names.back(),
        "Expected same input/output workspace blob");
    // Drop the blob that holds the pointer to the op workspace.
    output_names.pop_back();

    const std::unordered_set<std::string> all_output_names(
        output_names.begin(), output_names.end());
    CAFFE_ENFORCE_EQ(
        output_names.size(), all_output_names.size(), "Duplicate output blobs");

    std::vector<std::string> outer_blob_names;
    outer_blob_names.reserve(input_names.size() + output_names.size());
    outer_blob_names.insert(
        outer_blob_names.end(), input_names.begin(), input_names.end());
    outer_blob_names.insert(
        outer_blob_names.end(), output_names.begin(), output_names.end());
    return outer_blob_names;
  }

  // Copies the operator's input blob names, in definition order.
  std::vector<std::string> getInputBlobNames(
      const OperatorDef& operator_def) const {
    return std::vector<std::string>(
        operator_def.input().begin(), operator_def.input().end());
  }

  // Copies the operator's output blob names, in definition order.
  std::vector<std::string> getOutputBlobNames(
      const OperatorDef& operator_def) const {
    return std::vector<std::string>(
        operator_def.output().begin(), operator_def.output().end());
  }

  // Inner (subnet) blob name -> outer (operator input/output) blob name.
  std::unordered_map<std::string, std::string> blob_bindings_;
  // Inner blob names that have a binding; passed to CopyForwardedTensors.
  std::unordered_set<std::string> forwarded_inner_blobs_;
  bool is_gradient_op_ = false;
  bool copy_external_blobs_ = false;
  bool reuse_workspace_ = false;
  // Subnet definition taken from the "net" argument.
  NetDef net_def_;
  // Workspace the operator was constructed in; not owned.
  Workspace* parent_ws_ = nullptr;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_DO_OP_H_