// Package: neilisaac / torch (python), version 1.8.0
// File: include/caffe2/queue/blobs_queue_db.h


#pragma once

#include <chrono>
#include <string>

#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/stats.h"
#include "caffe2/queue/blobs_queue.h"

namespace caffe2 {
namespace db {

namespace {
/**
 * Returns the string payload held by `blob`.
 *
 * Two blob contents are accepted:
 *  - a plain `string`, which is returned directly;
 *  - a `Tensor`, in which case the first element of its `string` data is
 *    returned.
 *
 * The returned reference aliases storage owned by the blob, so it must be
 * copied before the blob is modified or destroyed.
 *
 * @param blob Blob to read from; must not be null.
 * @return Reference to the string stored in the blob.
 * @throws An enforce failure if the tensor has no elements, or if the blob
 *         holds neither a `string` nor a `Tensor`.
 */
const std::string& GetStringFromBlob(Blob* blob) {
  if (blob->IsType<string>()) {
    return blob->Get<string>();
  }
  if (blob->IsType<Tensor>()) {
    const auto& tensor = blob->Get<Tensor>();
    // Dereferencing the data pointer of an empty tensor is undefined
    // behaviour, so reject it explicitly before reading element 0.
    CAFFE_ENFORCE(
        tensor.numel() > 0, "Cannot read a string from an empty Tensor");
    return *tensor.data<string>();
  }
  CAFFE_THROW("Unsupported Blob type");
}
} // namespace

class BlobsQueueDBCursor : public Cursor {
 public:
  explicit BlobsQueueDBCursor(
      std::shared_ptr<BlobsQueue> queue,
      int key_blob_index,
      int value_blob_index,
      float timeout_secs)
      : queue_(queue),
        key_blob_index_(key_blob_index),
        value_blob_index_(value_blob_index),
        timeout_secs_(timeout_secs),
        inited_(false),
        valid_(false) {
    LOG(INFO) << "BlobsQueueDBCursor constructed";
    CAFFE_ENFORCE(queue_ != nullptr, "queue is null");
    CAFFE_ENFORCE(value_blob_index_ >= 0, "value_blob_index < 0");
  }

  virtual ~BlobsQueueDBCursor() {}

  void Seek(const string& /* unused */) override {
    CAFFE_THROW("Seek is not supported.");
  }

  bool SupportsSeek() override {
    return false;
  }

  void SeekToFirst() override {
    // not applicable
  }

  void Next() override {
    unique_ptr<Blob> blob = make_unique<Blob>();
    vector<Blob*> blob_vector{blob.get()};
    auto success = queue_->blockingRead(blob_vector, timeout_secs_);
    if (!success) {
      LOG(ERROR) << "Timed out reading from BlobsQueue or it is closed";
      valid_ = false;
      return;
    }

    if (key_blob_index_ >= 0) {
      key_ = GetStringFromBlob(blob_vector[key_blob_index_]);
    }
    value_ = GetStringFromBlob(blob_vector[value_blob_index_]);
    valid_ = true;
  }

  string key() override {
    if (!inited_) {
      Next();
      inited_ = true;
    }
    return key_;
  }

  string value() override {
    if (!inited_) {
      Next();
      inited_ = true;
    }
    return value_;
  }

  bool Valid() override {
    return valid_;
  }

 private:
  std::shared_ptr<BlobsQueue> queue_;
  int key_blob_index_;
  int value_blob_index_;
  float timeout_secs_;
  bool inited_;
  string key_;
  string value_;
  bool valid_;
};

class BlobsQueueDB : public DB {
 public:
  BlobsQueueDB(
      const string& source,
      Mode mode,
      std::shared_ptr<BlobsQueue> queue,
      int key_blob_index = -1,
      int value_blob_index = 0,
      float timeout_secs = 0.0)
      : DB(source, mode),
        queue_(queue),
        key_blob_index_(key_blob_index),
        value_blob_index_(value_blob_index),
        timeout_secs_(timeout_secs) {
    LOG(INFO) << "BlobsQueueDB constructed";
  }

  virtual ~BlobsQueueDB() {
    Close();
  }

  void Close() override {}
  unique_ptr<Cursor> NewCursor() override {
    return make_unique<BlobsQueueDBCursor>(
        queue_, key_blob_index_, value_blob_index_, timeout_secs_);
  }

  unique_ptr<Transaction> NewTransaction() override {
    CAFFE_THROW("Not implemented.");
  }

 private:
  std::shared_ptr<BlobsQueue> queue_;
  int key_blob_index_;
  int value_blob_index_;
  float timeout_secs_;
};
} // namespace db
} // namespace caffe2