Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / load_save_op.h

#ifndef CAFFE2_OPERATORS_LOAD_SAVE_OP_H_
#define CAFFE2_OPERATORS_LOAD_SAVE_OP_H_

#include <cstdio>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/context.h"
#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/load_save_op_util.h"
#include "caffe2/utils/math.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

using db::Cursor;
using db::DB;
using db::Transaction;

// DBExistsOp: reports whether a DB of a given type exists.
//
// Arguments:
//   absolute_path - when nonzero, "db_name" is used verbatim; otherwise it is
//                   resolved as ws->RootFolder() + "/" + db_name.
//   db_name       - name/path of the DB to probe (default "").
//   db_type       - DB backend type passed to caffe2::db::DBExists
//                   (default "").
// Output 0 is resized with no dimensions and receives a single bool.
// No validation of db_name/db_type is performed here.
template <class Context>
class DBExistsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  explicit DBExistsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            this->template GetSingleArgument<int>("absolute_path", false)),
        db_name_(this->template GetSingleArgument<string>("db_name", "")),
        db_type_(this->template GetSingleArgument<string>("db_type", "")) {}

  bool RunOnDevice() override {
    // Relative names are anchored at the workspace root folder.
    string resolved_path = db_name_;
    if (!absolute_path_) {
      resolved_path = ws_->RootFolder() + "/" + db_name_;
    }

    auto* result = Output(0);
    result->Resize();
    bool* result_data = result->template mutable_data<bool>();
    *result_data = caffe2::db::DBExists(db_type_, resolved_path);
    return true;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  std::string db_name_;
  std::string db_type_;
};

// LoadOp: deserializes blobs stored in one or more DBs.
//
// Sources (chosen by whether the operator has inputs):
//   * InputSize() > 0: every input is a db::DBReader. Its cursor is iterated
//     from wherever it currently points; this op does not seek it.
//   * InputSize() == 0: the DB(s) named by "db" (single) or "dbs" (repeated)
//     are opened for READ with backend "db_type". Names are prefixed with
//     ws->RootFolder() + "/" unless "absolute_path" is set.
//
// Destinations:
//   * load_all == 0: only keys matching an output's source name are loaded,
//     into that output. Source names come from "source_blob_names" if given,
//     otherwise from the output names themselves.
//   * load_all != 0: every key is loaded into a workspace blob of that name
//     (ws->CreateBlob). Afterwards every op output must have been seen.
//
// DB keys are mapped to blob names with
// load_save_op_util::buildBlobNameFromDbKey using "strip_prefix" and
// "add_prefix". A key that appears in two different DBs is rejected.
// "keep_device" == 0 makes SetCurrentDevice rewrite each proto's device.
// "allow_incomplete" tolerates loading fewer blobs than there are outputs
// (load_all == 0 only). "shape" is read but not used in this class.
template <class Context>
class LoadOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  explicit LoadOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            this->template GetSingleArgument<int>("absolute_path", false)),
        add_prefix_(this->template GetSingleArgument<string>("add_prefix", "")),
        strip_prefix_(
            this->template GetSingleArgument<string>("strip_prefix", "")),
        db_name_(this->template GetSingleArgument<string>("db", "")),
        db_names_(this->template GetRepeatedArgument<string>("dbs")),
        db_type_(this->template GetSingleArgument<string>("db_type", "")),
        keep_device_(this->template GetSingleArgument<int>("keep_device", 0)),
        load_all_(this->template GetSingleArgument<int>("load_all", 0)),
        allow_incomplete_(
            this->template GetSingleArgument<bool>("allow_incomplete", false)),
        blob_names_(
            this->template GetRepeatedArgument<string>("source_blob_names")),
        shape_(this->template GetRepeatedArgument<int64_t>("shape")) {
    if (InputSize() == 0) {
      // Without DBReader inputs the op opens DBs itself, so it needs a type
      // and at least one name. "db" and "dbs" are normalized into db_names_
      // ("db" is ignored when "dbs" is given), and db_name_ is cleared.
      CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
      if (db_names_.empty()) {
        CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
        db_names_.push_back(db_name_);
      } else {
        std::set<std::string> db_name_set;
        for (const string& db_name : db_names_) {
          CAFFE_ENFORCE_GT(db_name.size(), 0, "Db name should not be empty.");
          CAFFE_ENFORCE(
              db_name_set.insert(db_name).second,
              "Duplicated db name: ",
              db_name);
        }
      }
      db_name_.clear();
    }
    CAFFE_ENFORCE(
        blob_names_.empty() || blob_names_.size() == OutputSize(),
        "Number of output blobs and source_blob_names mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and source_blob_names are mutually exclusive.");
    CAFFE_ENFORCE(
        blob_names_.empty() || !load_all_,
        "cannot load_all_ while using source_blob_names.");
    if (!load_all_) {
      // blob_names_[i] is the (source) name looked up in the DB(s) for output
      // i. If source_blob_names was not given, outputs are matched by their
      // own names.
      if (blob_names_.empty()) {
        for (const string& name : operator_def.output()) {
          blob_names_.push_back(name);
        }
      }
      // output_indices_ doubles as the duplicate detector: emplace() fails
      // when the name is already mapped.
      for (size_t i = 0; i < blob_names_.size(); ++i) {
        const string& name = blob_names_[i];
        CAFFE_ENFORCE(
            output_indices_.emplace(name, static_cast<int>(i)).second,
            "Duplicated source blob name: ",
            name);
      }
    }
  }

  // Defined outside this header; called on each proto when keep_device_ is
  // false.
  void SetCurrentDevice(BlobProto* proto);

  bool RunOnDevice() override {
    int total_loaded_blobs = 0;
    std::unordered_map<string, load_save_op_util::BlobState> blob_states;
    if (InputSize() > 0) {
      for (int i = 0; i < InputSize(); ++i) {
        const db::DBReader& reader = this->template Input<db::DBReader>(i);
        extract(i, reader.cursor(), &blob_states, &total_loaded_blobs);
      }
    } else {
      for (size_t i = 0; i < db_names_.size(); ++i) {
        const string full_db_name = absolute_path_
            ? db_names_[i]
            : (ws_->RootFolder() + "/" + db_names_[i]);
        // in_db is declared before cursor (which is created from it), so it
        // is destroyed after the cursor.
        std::unique_ptr<DB> in_db(
            caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::READ));
        CAFFE_ENFORCE(
            in_db.get(),
            "Cannot find db implementation of type ",
            db_type_,
            " (while trying to open ",
            full_db_name,
            ")");
        std::unique_ptr<Cursor> cursor(in_db->NewCursor());
        extract(
            static_cast<int>(i),
            cursor.get(),
            &blob_states,
            &total_loaded_blobs);
      }
    }

    load_save_op_util::validateBlobStates(blob_states);
    // Loaded all the needed blobs.
    if (!load_all_ && total_loaded_blobs == OutputSize()) {
      VLOG(1) << "Loaded " << total_loaded_blobs << " blobs fully from db(s)";
      return true;
    }

    if (load_all_) {
      // Every declared output must have been present in some DB.
      for (const string& name : this->debug_def().output()) {
        CAFFE_ENFORCE(
            blob_states.count(name),
            "Output blob name ",
            name,
            " does not exist in the db(s).");
      }
      return true;
    }

    // Only loaded a subset of the blobs.
    if (allow_incomplete_) {
      VLOG(1) << "Loaded " << total_loaded_blobs << " blobs out of "
              << OutputSize() << " blobs from db(s).";
    } else {
      for (const string& output_name : this->debug_def().output()) {
        if (blob_states.count(output_name) == 0) {
          LOG(ERROR) << "Failed to load blob: " << output_name;
        }
      }
      CAFFE_THROW(
          "Expected to load ",
          OutputSize(),
          " blobs, got ",
          total_loaded_blobs,
          " only.\n");
    }

    return true;
  }

 private:
  // Dispatches one DB cursor to the load-all or load-outputs path.
  void extract(
      int db_id,
      Cursor* cursor,
      std::unordered_map<string, load_save_op_util::BlobState>* blob_states,
      int* total_loaded_blobs) {
    if (load_all_) {
      extractAll(db_id, cursor, blob_states, total_loaded_blobs);
    } else {
      extractFrom(
          db_id,
          cursor,
          OperatorBase::Outputs(),
          blob_states,
          total_loaded_blobs);
    }
  }

  // Remembers which DB a key came from. Repeats of a key within the same DB
  // are accepted (presumably chunks of one blob tracked via BlobState; TODO
  // confirm against load_save_op_util), but a key seen in a different DB
  // throws.
  void recordKeySource(const string& key, int db_id) {
    const auto result = key_to_dbid_.emplace(key, db_id);
    if (!result.second && result.first->second != db_id) {
      CAFFE_THROW("Duplicate Key ", key, " is found!\n");
    }
  }

  // load_all path: every entry under the cursor is deserialized into a
  // workspace blob named after its (prefix-adjusted) key.
  void extractAll(
      int db_id,
      Cursor* cursor,
      std::unordered_map<string, load_save_op_util::BlobState>* blob_states,
      int* total_loaded_blobs) {
    CAFFE_ENFORCE(cursor, "cursor is not valid");
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = load_save_op_util::buildBlobNameFromDbKey(
          cursor->key(), strip_prefix_, add_prefix_);
      recordKeySource(key, db_id);

      BlobProto proto;
      CAFFE_ENFORCE(
          proto.ParseFromString(cursor->value()), "Couldn't parse Proto");
      if (!keep_device_) {
        // If we are not keeping the device as the one specified in the
        // proto, we will set the current device.
        SetCurrentDevice(&proto);
      }
      Blob* blob = ws_->CreateBlob(key);
      load_save_op_util::ProcessBlob(
          blob, proto, blob_states, key, &loaded_blobs);
    }
    *total_loaded_blobs += loaded_blobs;
  }

  // Output path: only keys present in output_indices_ are deserialized, into
  // the matching entry of `outputs`. Iteration stops early once the running
  // total reaches OutputSize().
  void extractFrom(
      int db_id,
      Cursor* cursor,
      const vector<Blob*>& outputs,
      std::unordered_map<string, load_save_op_util::BlobState>* blob_states,
      int* total_loaded_blobs) {
    CAFFE_ENFORCE(cursor, "cursor is not valid");
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = load_save_op_util::buildBlobNameFromDbKey(
          cursor->key(), strip_prefix_, add_prefix_);
      const auto index_it = output_indices_.find(key);
      if (index_it == output_indices_.end()) {
        VLOG(1) << "Key " << key << " not used. Skipping.";
        continue;
      }
      recordKeySource(key, db_id);

      VLOG(2) << "Deserializing blob " << key;
      BlobProto proto;
      CAFFE_ENFORCE(
          proto.ParseFromString(cursor->value()), "Couldn't parse Proto");
      if (!keep_device_) {
        // If we are not keeping the device as the one specified in the
        // proto, we will set the current device.
        SetCurrentDevice(&proto);
      }
      Blob* blob = outputs.at(index_it->second);
      load_save_op_util::ProcessBlob(
          blob, proto, blob_states, key, &loaded_blobs);

      if (*total_loaded_blobs + loaded_blobs == OutputSize()) {
        break;
      }
    }

    *total_loaded_blobs += loaded_blobs;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  string add_prefix_;
  string strip_prefix_;
  string db_name_;
  std::vector<std::string> db_names_;
  string db_type_;
  bool keep_device_;
  bool load_all_;
  bool allow_incomplete_;
  // Source blob name -> output index (load_all_ == false only).
  std::map<string, int> output_indices_;
  // Blob name -> index of the DB it was first read from.
  std::map<string, int> key_to_dbid_;
  std::vector<std::string> blob_names_;
  std::vector<int64_t> shape_;
};

template <class Context>
class SaveOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  explicit SaveOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            this->template GetSingleArgument<int>("absolute_path", false)),
        strip_prefix_(
            this->template GetSingleArgument<string>("strip_prefix", "")),
        db_name_(this->template GetSingleArgument<string>("db", "")),
        db_type_(this->template GetSingleArgument<string>("db_type", "")),
        blob_names_(
            this->template GetRepeatedArgument<string>("blob_name_overrides")),
        chunk_size_(this->template GetSingleArgument<int>(
            "chunk_size",
            kDefaultChunkSize)) {
    CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
    CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
    CAFFE_ENFORCE(
        blob_names_.empty() ||
            blob_names_.size() == OperatorBase::Inputs().size(),
        "Number of blobs and blob_name_overrides mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and blob_name_overrides are mutually exclusive.");

    if (blob_names_.empty()) {
      std::set<std::string> input_names;
      blob_names_.resize(OperatorBase::Inputs().size());
      for (int i = 0; i < blob_names_.size(); ++i) {
        std::string name;
        if (strip_prefix_.empty()) {
          name = operator_def.input(i);
        } else {
          auto match_pos = operator_def.input(i).find(strip_prefix_);
          if (match_pos == string::npos) {
            name = operator_def.input(i);
          } else {
            name = operator_def.input(i).substr(
                match_pos + strip_prefix_.size(), string::npos);
          }
        }
        CAFFE_ENFORCE(
            input_names.insert(name).second, "Duplicated input: ", name);
        blob_names_[i] = name;
Loading ...