#ifndef CAFFE2_OPERATORS_DATASET_OPS_H_
#define CAFFE2_OPERATORS_DATASET_OPS_H_
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include "caffe2/core/blob.h"
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/tensor.h"
namespace caffe2 {
namespace dataset_ops {
// Element type used for lengths tensors in the dataset (32-bit signed).
using TLength = std::int32_t;
// Type used for all internal dataset operations (offsets, sizes to read,
// etc.). 64-bit signed, i.e. wider than TLength.
using TOffset = std::int64_t;
/**
 * Provides functionality to iterate across a list of tensors where some
 * of those tensors represent lengths in a hierarchical structure.
 *
 * Every accessor defined inline here is read-only and const-qualified, so it
 * can be called through a const TreeIterator (or a const reference to one).
 */
class TreeIterator {
 public:
  // Description of a single field.
  struct FieldDesc {
    // Identifier of the field. Defaults to 0 so that a default-constructed
    // descriptor never carries an indeterminate value.
    int id = 0;
    // Position, within lengthFieldIds(), of the length field associated with
    // this field; -1 means there is no associated length field.
    int lengthFieldId = -1;
    std::string name;
  };

  // Builds the iterator from the list of field names. Defined outside this
  // header.
  explicit TreeIterator(const std::vector<std::string>& fields);

  // Defined outside this header.
  // NOTE(review): presumably moves `offsets` forward by `num` records and
  // reports what was covered through `sizes`, bounded by `limits` -- confirm
  // against the definition.
  void advance(
      const std::vector<const TLength*>& lengths,
      std::vector<TOffset>& offsets,
      std::vector<TOffset>& sizes,
      std::vector<TOffset>& limits,
      TOffset num);

  // Corresponds to the number of fields that have "length" as its last name
  int numLengthFields() const {
    // The container reports an unsigned size; make the narrowing to the int
    // used by this interface explicit.
    return static_cast<int>(lengthFieldIds_.size());
  }

  // Corresponds to the number of length fields + 1 (for the top-level domain)
  int numOffsetFields() const {
    return numLengthFields() + 1;
  }

  // Get lengthField description for the given field.
  // Returns nullptr when desc.lengthFieldId is -1. Otherwise returns a pointer
  // to an element of the internal field list, so the pointer refers to storage
  // owned by this iterator. Throws std::out_of_range (from vector::at) when
  // the id does not address a valid entry.
  const FieldDesc* lengthFieldFor(const FieldDesc& desc) const {
    if (desc.lengthFieldId == -1) {
      return nullptr;
    }
    return &fields_.at(lengthFieldIds_.at(desc.lengthFieldId));
  }

  // Get lengthField description for the given lengthFieldId, where
  // 0 <= lengthFieldId < numLengthFields().
  // Throws std::out_of_range (from vector::at) outside that range.
  const FieldDesc& lengthField(int lengthFieldId) const {
    return fields_.at(lengthFieldIds_.at(lengthFieldId));
  }

  // Returns the index into the 'offset' vector for the given field.
  // This is lengthFieldId + 1, so a field without a length field (-1) maps to
  // index 0.
  int offsetFieldIdFor(const FieldDesc& fieldDesc) const {
    return fieldDesc.lengthFieldId + 1;
  }

  // Returns the field description for all fields.
  const std::vector<FieldDesc>& fields() const {
    return fields_;
  }

  // Returns, for each length field, its index into fields().
  const std::vector<int>& lengthFieldIds() const {
    return lengthFieldIds_;
  }

 private:
  // Description of each field
  std::vector<FieldDesc> fields_;
  // Index into fields_ above for the fields that are lengths.
  std::vector<int> lengthFieldIds_;
};
/**
 * Bundles a TreeIterator with a vector of offsets and a mutex.
 *
 * The iterator is stored by value, i.e. the cursor keeps its own copy of the
 * one passed to the constructor. Because it holds a std::mutex, a TreeCursor
 * can be neither copied nor moved.
 */
class TreeCursor {
 public:
  // Copies `iterator` into `it`; `offsets` starts out empty.
  explicit TreeCursor(const TreeIterator& iterator) : it{iterator} {}

  std::vector<TOffset> offsets;
  // NOTE(review): presumably guards concurrent access to `offsets`; this
  // header never locks it -- confirm against the users of this class.
  std::mutex mutex_;
  TreeIterator it;
};
/**
* Simple wrapper class allowing an easy traversal of the tensors representing
* the hirerarchical structure.
*/
class TreeWalker {
public:
TreeWalker(const vector<const Blob*>& inputs, TreeCursor& cursor);
// Returns the number of records in a dataset
inline TOffset size() const {
return limits_.at(0);
}
void advance();
private:
inline const TensorCPU& input(int32_t idx) const {
return inputs_[idx]->Get<TensorCPU>();
}
// TODO: Change to fieldDesc
inline const TreeIterator::FieldDesc& field(int idx) const {
return cursor_.it.fields().at(idx);
}
inline int lengthIdx(int fieldId) const {
return field(fieldId).lengthFieldId + 1;
}
inline TOffset offset(int fieldId) const {
return prevOffsets_[lengthIdx(fieldId)];
}
std::vector<int64_t> fieldDim(int fieldId) const;
void* fieldPtr(int fieldId) const;
public:
// Simple Proxy class to expose nicer API for field access
class Field {
public:
Field(TreeWalker& walker, int fieldId)
: walker_(walker), fieldId_(fieldId) {}
inline std::vector<int64_t> dim() const {
return walker_.fieldDim(fieldId_);
}
inline int64_t size() const {
int64_t size = 1;
for (const auto d : dim()) {
size *= d;
}
return size;
}
inline const TypeMeta meta() const {
return walker_.input(fieldId_).dtype();
}
inline void* ptr() const {
return walker_.fieldPtr(fieldId_);
}
int fieldId() const {
return fieldId_;
}
inline TOffset offset() const {
return walker_.offset(fieldId_);
}
private:
const TreeWalker& walker_;
const int fieldId_;
};
// Notice that a reference is returned. If advance() is called the fields will
// be updated to represent the new state.
inline const std::vector<Field>& fields() const {
return fields_;
}
private:
void gatherLengthData();
void gatherSizeLimits();
const vector<const Blob*>& inputs_;
TreeCursor& cursor_;
std::vector<Field> fields_;
std::vector<const TLength*> lengths_;
std::vector<TOffset> limits_;
std::vector<TOffset> sizes_;
std::vector<TOffset> offsets_;
std::vector<TOffset> prevOffsets_;
};
using SharedTensorVectorPtr = std::shared_ptr<std::vector<TensorCPU>>;
using Shared2DTensorVectorPtr =
std::shared_ptr<std::vector<std::vector<caffe2::TensorCPU>>>;
using Tensor2DVector = std::vector<std::vector<caffe2::TensorCPU>>;
using TensorVectorPtr = std::unique_ptr<std::vector<Tensor>>;
// Blob serializer that overrides BlobSerializerBase::Serialize.
// NOTE(review): judging by the name, this handles blobs that hold a
// SharedTensorVectorPtr; the implementation is not part of this header --
// confirm there.
class SharedTensorVectorPtrSerializer : public BlobSerializerBase {
public:
// Declared here, defined outside this header.
// NOTE(review): `pointer` presumably addresses the object to serialize, whose
// type is described by `typeMeta`, and the result is presumably handed to
// `acceptor` under `name` -- confirm against BlobSerializerBase.
void Serialize(
const void* pointer,
TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) override;
};
// Blob deserializer that overrides BlobDeserializerBase::Deserialize.
// NOTE(review): judging by the name, this is the counterpart of
// SharedTensorVectorPtrSerializer; the implementation is not part of this
// header -- confirm there.
class SharedTensorVectorPtrDeserializer : public BlobDeserializerBase {
public:
// Declared here, defined outside this header.
// NOTE(review): presumably reconstructs the content of `blob` from `proto`.
void Deserialize(const BlobProto& proto, Blob* blob) override;
};
} // namespace dataset_ops
} // namespace caffe2
#endif // CAFFE2_OPERATORS_DATASET_OPS_H_