#ifndef CAFFE2_CORE_DB_H_
#define CAFFE2_CORE_DB_H_
#include <mutex>
#include "c10/util/Registry.h"
#include "caffe2/core/blob_serialization.h"
#include "caffe2/proto/caffe2_pb.h"
namespace caffe2 {
namespace db {
/**
* The mode in which a database is opened: READ for reading, WRITE for
* writing, or NEW for creating a new database.
*
* This is an unscoped enum, so the enumerators are visible directly in the
* enclosing namespace (e.g. caffe2::db::READ).
*
* NOTE(review): the exact difference between WRITE and NEW (for instance,
* whether WRITE appends to an existing database) is decided by each DB
* implementation and is not visible in this header - confirm per backend.
*/
enum Mode { READ, WRITE, NEW };
/**
 * Abstract interface for iterating over the key-value pairs of a database
 * that is being read.
 */
class TORCH_API Cursor {
 public:
  Cursor() = default;
  virtual ~Cursor() = default;

  /**
   * Positions the cursor at `key`; when `key` is not present, positions it at
   * the immediately following key instead. Seeking is an optional capability
   * for a db: callers should consult SupportsSeek(), which answers false
   * unless an implementation overrides it.
   */
  virtual void Seek(const string& key) = 0;

  /**
   * Tells whether this cursor supports Seek(). The base implementation
   * reports that it does not.
   */
  virtual bool SupportsSeek() {
    return false;
  }

  /**
   * Positions the cursor at the first key of the database.
   */
  virtual void SeekToFirst() = 0;

  /**
   * Advances the cursor to the next location in the database.
   */
  virtual void Next() = 0;

  /**
   * Returns the key at the current location.
   */
  virtual string key() = 0;

  /**
   * Returns the value at the current location.
   */
  virtual string value() = 0;

  /**
   * Tells whether the current location is valid. For example, false is
   * returned once the end of the database has been reached.
   */
  virtual bool Valid() = 0;

  C10_DISABLE_COPY_AND_ASSIGN(Cursor);
};
/**
 * Abstract interface for the transaction through which a database is
 * currently being written.
 */
class TORCH_API Transaction {
 public:
  Transaction() = default;
  virtual ~Transaction() = default;

  /**
   * Adds the pair (`key`, `value`) to the database.
   */
  virtual void Put(const string& key, const string& value) = 0;

  /**
   * Commits the writes made so far through this transaction.
   */
  virtual void Commit() = 0;

  C10_DISABLE_COPY_AND_ASSIGN(Transaction);
};
/**
* An abstract class for accessing a database of key-value pairs.
*/
class TORCH_API DB {
public:
DB(const string& /*source*/, Mode mode) : mode_(mode) {}
virtual ~DB() {}
/**
* Closes the database.
*/
virtual void Close() = 0;
/**
* Returns a cursor to read the database. The caller takes the ownership of
* the pointer.
*/
virtual std::unique_ptr<Cursor> NewCursor() = 0;
/**
* Returns a transaction to write data to the database. The caller takes the
* ownership of the pointer.
*/
virtual std::unique_ptr<Transaction> NewTransaction() = 0;
protected:
Mode mode_;
C10_DISABLE_COPY_AND_ASSIGN(DB);
};
// Database classes are registered by their names so we can do optional
// dependencies: CreateDB() looks the requested db type up in this registry at
// run time and passes (source, mode) to the registered class.
C10_DECLARE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode);
// Registers a DB implementation in Caffe2DBRegistry under `name`. The
// remaining macro arguments are forwarded unchanged to C10_REGISTER_CLASS.
// NOTE(review): presumably the forwarded argument is the implementing class -
// confirm against the C10_REGISTER_CLASS definition.
#define REGISTER_CAFFE2_DB(name, ...) \
C10_REGISTER_CLASS(Caffe2DBRegistry, name, __VA_ARGS__)
/**
 * Builds a database object for the requested database type, source and mode,
 * using the implementation registered under `db_type`. Ownership of the
 * result belongs to the caller. When the database type is not supported the
 * result is a nullptr, so the caller must check the pointer before using it.
 */
inline unique_ptr<DB>
CreateDB(const string& db_type, const string& source, Mode mode) {
  auto db = Caffe2DBRegistry()->Create(db_type, source, mode);
  // Pick the log prefix up front; a null pointer means no implementation is
  // registered under db_type.
  const char* const outcome = db ? "found db " : "not found db ";
  VLOG(1) << outcome << db_type;
  return db;
}
/**
* Returns whether or not a database exists given the database type and path.
*/
inline bool DBExists(const string& db_type, const string& full_db_name) {
// Warning! We assume that creating a DB throws an exception if the DB
// does not exist. If the DB constructor does not follow this design
// pattern,
// the returned output (the existence tensor) can be wrong.
try {
std::unique_ptr<DB> db(
caffe2::db::CreateDB(db_type, full_db_name, caffe2::db::READ));
return true;
} catch (...) {
return false;
}
}
/**
* A reader wrapper for DB that also allows us to serialize it.
*/
class TORCH_API DBReader {
public:
friend class DBReaderSerializer;
DBReader() {}
DBReader(
const string& db_type,
const string& source,
const int32_t num_shards = 1,
const int32_t shard_id = 0) {
Open(db_type, source, num_shards, shard_id);
}
explicit DBReader(const DBReaderProto& proto) {
Open(proto.db_type(), proto.source());
if (proto.has_key()) {
CAFFE_ENFORCE(
cursor_->SupportsSeek(),
"Encountering a proto that needs seeking but the db type "
"does not support it.");
cursor_->Seek(proto.key());
}
num_shards_ = 1;
shard_id_ = 0;
}
explicit DBReader(std::unique_ptr<DB> db)
: db_type_("<memory-type>"),
source_("<memory-source>"),
db_(std::move(db)) {
CAFFE_ENFORCE(db_.get(), "Passed null db");
cursor_ = db_->NewCursor();
}
void Open(
const string& db_type,
const string& source,
const int32_t num_shards = 1,
const int32_t shard_id = 0) {
// Note(jiayq): resetting is needed when we re-open e.g. leveldb where no
// concurrent access is allowed.
cursor_.reset();
db_.reset();
db_type_ = db_type;
source_ = source;
db_ = CreateDB(db_type_, source_, READ);
CAFFE_ENFORCE(
db_,
"Cannot find db implementation of type ",
db_type,
" (while trying to open ",
source_,
")");
InitializeCursor(num_shards, shard_id);
}
void Open(
unique_ptr<DB>&& db,
const int32_t num_shards = 1,
const int32_t shard_id = 0) {
cursor_.reset();
db_.reset();
db_ = std::move(db);
CAFFE_ENFORCE(db_.get(), "Passed null db");
InitializeCursor(num_shards, shard_id);
}
public:
/**
* Read a set of key and value from the db and move to next. Thread safe.
*
* The string objects key and value must be created by the caller and
* explicitly passed in to this function. This saves one additional object
* copy.
*
* If the cursor reaches its end, the reader will go back to the head of
* the db. This function can be used to enable multiple input ops to read
* the same db.
*
* Note(jiayq): we loosen the definition of a const function here a little
* bit: the state of the cursor is actually changed. However, this allows
* us to pass in a DBReader to an Operator without the need of a duplicated
* output blob.
*/
void Read(string* key, string* value) const {
CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
*key = cursor_->key();
*value = cursor_->value();
// In sharded mode, each read skips num_shards_ records
for (uint32_t s = 0; s < num_shards_; s++) {
cursor_->Next();
if (!cursor_->Valid()) {
MoveToBeginning();
break;
}
}
}
/**
* @brief Seeks to the first key. Thread safe.
*/
void SeekToFirst() const {
CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
MoveToBeginning();
}
/**
* Returns the underlying cursor of the db reader.
*
* Note that if you directly use the cursor, the read will not be thread
* safe, because there is no mechanism to stop multiple threads from
* accessing the same cursor. You should consider using Read() explicitly.
*/
inline Cursor* cursor() const {
VLOG(1) << "Usually for a DBReader you should use Read() to be "
"thread safe. Consider refactoring your code.";
return cursor_.get();
}
private:
void InitializeCursor(const int32_t num_shards, const int32_t shard_id) {
CAFFE_ENFORCE(num_shards >= 1);
CAFFE_ENFORCE(shard_id >= 0);
CAFFE_ENFORCE(shard_id < num_shards);
num_shards_ = num_shards;
shard_id_ = shard_id;
cursor_ = db_->NewCursor();
SeekToFirst();
}
void MoveToBeginning() const {
cursor_->SeekToFirst();
for (uint32_t s = 0; s < shard_id_; s++) {
cursor_->Next();
CAFFE_ENFORCE(
cursor_->Valid(), "Db has fewer rows than shard id: ", s, shard_id_);
}
}
string db_type_;
string source_;
unique_ptr<DB> db_;
unique_ptr<Cursor> cursor_;
mutable std::mutex reader_mutex_;
uint32_t num_shards_{};
uint32_t shard_id_{};
C10_DISABLE_COPY_AND_ASSIGN(DBReader);
};
/**
* Blob serializer for DBReader, implementing the BlobSerializerBase
* interface. DBReader declares this class a friend, which gives it access to
* the reader's private state. The definition of Serialize() is not in this
* header.
*/
class TORCH_API DBReaderSerializer : public BlobSerializerBase {
public:
/**
* Serializes a DBReader. Note that this blob has to contain DBReader,
* otherwise this function produces a fatal error.
*
* NOTE(review): parameter roles are inferred from their names and types
* only - presumably `pointer` addresses the DBReader, `typeMeta` describes
* its type, `name` is the blob name and `acceptor` receives the serialized
* output. Confirm against the definition and BlobSerializerBase.
*/
void Serialize(
const void* pointer,
TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) override;
};
/**
* Blob deserializer counterpart of DBReaderSerializer, implementing the
* BlobDeserializerBase interface. The definition of Deserialize() is not in
* this header.
*/
class TORCH_API DBReaderDeserializer : public BlobDeserializerBase {
public:
/**
* Overrides BlobDeserializerBase::Deserialize.
*
* NOTE(review): presumably rebuilds a DBReader inside `blob` from the
* serialized content of `proto` - confirm against the definition.
*/
void Deserialize(const BlobProto& proto, Blob* blob) override;
};
} // namespace db
} // namespace caffe2
#endif // CAFFE2_CORE_DB_H_