#pragma once
#include <chrono>
#include <string>
#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/stats.h"
#include "caffe2/queue/blobs_queue.h"
namespace caffe2 {
namespace db {
namespace {
const std::string& GetStringFromBlob(Blob* blob) {
if (blob->template IsType<string>()) {
return blob->template Get<string>();
} else if (blob->template IsType<Tensor>()) {
return *blob->template Get<Tensor>().template data<string>();
} else {
CAFFE_THROW("Unsupported Blob type");
}
}
}
class BlobsQueueDBCursor : public Cursor {
public:
explicit BlobsQueueDBCursor(
std::shared_ptr<BlobsQueue> queue,
int key_blob_index,
int value_blob_index,
float timeout_secs)
: queue_(queue),
key_blob_index_(key_blob_index),
value_blob_index_(value_blob_index),
timeout_secs_(timeout_secs),
inited_(false),
valid_(false) {
LOG(INFO) << "BlobsQueueDBCursor constructed";
CAFFE_ENFORCE(queue_ != nullptr, "queue is null");
CAFFE_ENFORCE(value_blob_index_ >= 0, "value_blob_index < 0");
}
virtual ~BlobsQueueDBCursor() {}
void Seek(const string& /* unused */) override {
CAFFE_THROW("Seek is not supported.");
}
bool SupportsSeek() override {
return false;
}
void SeekToFirst() override {
// not applicable
}
void Next() override {
unique_ptr<Blob> blob = make_unique<Blob>();
vector<Blob*> blob_vector{blob.get()};
auto success = queue_->blockingRead(blob_vector, timeout_secs_);
if (!success) {
LOG(ERROR) << "Timed out reading from BlobsQueue or it is closed";
valid_ = false;
return;
}
if (key_blob_index_ >= 0) {
key_ = GetStringFromBlob(blob_vector[key_blob_index_]);
}
value_ = GetStringFromBlob(blob_vector[value_blob_index_]);
valid_ = true;
}
string key() override {
if (!inited_) {
Next();
inited_ = true;
}
return key_;
}
string value() override {
if (!inited_) {
Next();
inited_ = true;
}
return value_;
}
bool Valid() override {
return valid_;
}
private:
std::shared_ptr<BlobsQueue> queue_;
int key_blob_index_;
int value_blob_index_;
float timeout_secs_;
bool inited_;
string key_;
string value_;
bool valid_;
};
class BlobsQueueDB : public DB {
public:
BlobsQueueDB(
const string& source,
Mode mode,
std::shared_ptr<BlobsQueue> queue,
int key_blob_index = -1,
int value_blob_index = 0,
float timeout_secs = 0.0)
: DB(source, mode),
queue_(queue),
key_blob_index_(key_blob_index),
value_blob_index_(value_blob_index),
timeout_secs_(timeout_secs) {
LOG(INFO) << "BlobsQueueDB constructed";
}
virtual ~BlobsQueueDB() {
Close();
}
void Close() override {}
unique_ptr<Cursor> NewCursor() override {
return make_unique<BlobsQueueDBCursor>(
queue_, key_blob_index_, value_blob_index_, timeout_secs_);
}
unique_ptr<Transaction> NewTransaction() override {
CAFFE_THROW("Not implemented.");
}
private:
std::shared_ptr<BlobsQueue> queue_;
int key_blob_index_;
int value_blob_index_;
float timeout_secs_;
};
} // namespace db
} // namespace caffe2