Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / dataset_ops.h

#ifndef CAFFE2_OPERATORS_DATASET_OPS_H_
#define CAFFE2_OPERATORS_DATASET_OPS_H_

#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "caffe2/core/blob.h"
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/tensor.h"

namespace caffe2 {
namespace dataset_ops {

// Integer type used for lengths tensors in the dataset.
using TLength = int32_t;
// Integer type used for all internal dataset operations (offsets, sizes to
// read, etc.).
using TOffset = int64_t;

/**
 * Provides functionality to iterate across a list of tensors where some
 * of those tensors represent lengths in a hierarchical structure.
 *
 * The constructor and advance() are defined out of line. All accessors below
 * are read-only and const-qualified, so they can be used through a
 * `const TreeIterator&`.
 */
class TreeIterator {
 public:
  // Description of a single field.
  struct FieldDesc {
    // NOTE(review): presumably this field's index within fields(); it is
    // assigned by the out-of-line constructor -- confirm there.
    int id;
    // Index into lengthFieldIds() of the length field associated with this
    // field, or -1 when there is none (see lengthFieldFor()).
    int lengthFieldId = -1;
    std::string name;
  };

  explicit TreeIterator(const std::vector<std::string>& fields);

  // Defined out of line. `lengths` holds one pointer per length field;
  // `offsets`, `sizes` and `limits` are in/out vectors and `num` is the
  // requested amount. NOTE(review): exact semantics live in the .cc file.
  void advance(
      const std::vector<const TLength*>& lengths,
      std::vector<TOffset>& offsets,
      std::vector<TOffset>& sizes,
      std::vector<TOffset>& limits,
      TOffset num);

  // Corresponds to the number of fields that have "length" as its last name
  int numLengthFields() const {
    // Explicit cast: size() is unsigned and wider than int.
    return static_cast<int>(lengthFieldIds_.size());
  }

  // Corresponds to the number of length fields + 1 (for the top-level domain)
  int numOffsetFields() const {
    return numLengthFields() + 1;
  }

  // Get lengthField description for the given field.
  // Returns nullptr when desc.lengthFieldId is -1; otherwise a pointer into
  // the internal field list. Throws std::out_of_range (from vector::at) if
  // the id is not a valid length field index.
  const FieldDesc* lengthFieldFor(const FieldDesc& desc) const {
    return (desc.lengthFieldId == -1)
        ? nullptr
        : &fields_.at(lengthFieldIds_.at(desc.lengthFieldId));
  }

  // Get lengthField description for the given lengthFieldId, where
  // 0 <= lengthFieldId < numLengthFields(). Out-of-range ids throw
  // std::out_of_range (from vector::at).
  const FieldDesc& lengthField(int lengthFieldId) const {
    return fields_.at(lengthFieldIds_.at(lengthFieldId));
  }

  // Returns the index into the 'offset' vector for the given field.
  // A field without a length field (lengthFieldId == -1) maps to index 0.
  int offsetFieldIdFor(const FieldDesc& fieldDesc) const {
    return fieldDesc.lengthFieldId + 1;
  }

  // Returns the field description for all fields.
  const std::vector<FieldDesc>& fields() const {
    return fields_;
  }

  // Returns, for each length field, its index into fields().
  const std::vector<int>& lengthFieldIds() const {
    return lengthFieldIds_;
  }

 private:
  // Description of each field
  std::vector<FieldDesc> fields_;
  // Index into fields_ above for the fields that are lengths.
  std::vector<int> lengthFieldIds_;
};

// Bundles a private copy of a TreeIterator with a vector of offsets and a
// mutex. All members are public and accessed directly by users of the cursor.
class TreeCursor {
 public:
  // Starts out empty; nothing in this class fills it in.
  std::vector<TOffset> offsets;
  // NOTE(review): presumably callers lock this to guard `offsets`; no locking
  // happens inside this class -- confirm against the operators that use it.
  std::mutex mutex_;
  // Copy of the iterator passed to the constructor.
  TreeIterator it;

  // Stores a copy of `iterator`; `offsets` is left empty.
  explicit TreeCursor(const TreeIterator& iterator) : it{iterator} {}
};

/**
 * Simple wrapper class allowing an easy traversal of the tensors representing
 * the hierarchical structure.
 *
 * Holds references to the input blobs and to the cursor handed to the
 * constructor; it does not copy either of them.
 */
class TreeWalker {
 public:
  // Simple proxy class exposing a nicer API for per-field access. Every call
  // is forwarded to the owning TreeWalker using the stored field id.
  class Field {
   public:
    Field(TreeWalker& walker, int fieldId)
        : walker_(walker), fieldId_(fieldId) {}

    // Dimensions of this field as reported by the walker.
    std::vector<int64_t> dim() const {
      return walker_.fieldDim(fieldId_);
    }

    // Product of all entries of dim(); 1 when dim() is empty.
    int64_t size() const {
      int64_t numel = 1;
      for (const auto extent : dim()) {
        numel *= extent;
      }
      return numel;
    }

    // dtype of the input tensor backing this field.
    const TypeMeta meta() const {
      return walker_.input(fieldId_).dtype();
    }

    // Pointer to this field's data as reported by the walker.
    void* ptr() const {
      return walker_.fieldPtr(fieldId_);
    }

    int fieldId() const {
      return fieldId_;
    }

    // Offset of this field, taken from the walker's previous offsets.
    TOffset offset() const {
      return walker_.offset(fieldId_);
    }

   private:
    const TreeWalker& walker_;
    const int fieldId_;
  };

  TreeWalker(const vector<const Blob*>& inputs, TreeCursor& cursor);

  // Returns the number of records in a dataset
  TOffset size() const {
    return limits_.at(0);
  }

  void advance();

  // Notice that a reference is returned. If advance() is called the fields will
  // be updated to represent the new state.
  const std::vector<Field>& fields() const {
    return fields_;
  }

 private:
  // Tensor stored in the idx-th input blob (no bounds check on idx).
  const TensorCPU& input(int32_t idx) const {
    return inputs_[idx]->Get<TensorCPU>();
  }

  // TODO: Change to fieldDesc
  // Bounds-checked lookup of the idx-th field description from the cursor's
  // iterator.
  const TreeIterator::FieldDesc& field(int idx) const {
    return cursor_.it.fields().at(idx);
  }

  // Index into the offset vectors for the given field: its lengthFieldId + 1.
  int lengthIdx(int fieldId) const {
    return field(fieldId).lengthFieldId + 1;
  }

  // Previous offset recorded for the given field.
  TOffset offset(int fieldId) const {
    return prevOffsets_[lengthIdx(fieldId)];
  }

  std::vector<int64_t> fieldDim(int fieldId) const;

  void* fieldPtr(int fieldId) const;

  void gatherLengthData();

  void gatherSizeLimits();

  const vector<const Blob*>& inputs_;
  TreeCursor& cursor_;
  std::vector<Field> fields_;

  std::vector<const TLength*> lengths_;
  std::vector<TOffset> limits_;
  std::vector<TOffset> sizes_;
  std::vector<TOffset> offsets_;
  std::vector<TOffset> prevOffsets_;
};

// Shared-ownership pointer to a vector of CPU tensors.
using SharedTensorVectorPtr = std::shared_ptr<std::vector<TensorCPU>>;

// Shared-ownership pointer to a vector of vectors of CPU tensors.
using Shared2DTensorVectorPtr =
    std::shared_ptr<std::vector<std::vector<caffe2::TensorCPU>>>;

// Vector of vectors of CPU tensors, held by value.
using Tensor2DVector = std::vector<std::vector<caffe2::TensorCPU>>;

// Unique-ownership pointer to a vector of tensors.
using TensorVectorPtr = std::unique_ptr<std::vector<Tensor>>;

// BlobSerializerBase implementation whose Serialize() override is defined out
// of line.
// NOTE(review): going by the class name, this presumably handles blobs holding
// a SharedTensorVectorPtr -- confirm against the definition and registration.
class SharedTensorVectorPtrSerializer : public BlobSerializerBase {
 public:
  // Overrides BlobSerializerBase::Serialize. Takes an untyped pointer to the
  // object, its TypeMeta, the blob name, and an acceptor callback.
  void Serialize(
      const void* pointer,
      TypeMeta typeMeta,
      const string& name,
      BlobSerializerBase::SerializationAcceptor acceptor) override;
};

// BlobDeserializerBase implementation whose Deserialize() override is defined
// out of line.
// NOTE(review): going by the class name, this presumably reconstructs a
// SharedTensorVectorPtr into the blob -- confirm against the definition.
class SharedTensorVectorPtrDeserializer : public BlobDeserializerBase {
 public:
  // Overrides BlobDeserializerBase::Deserialize. Reads from `proto` and writes
  // the result into `blob`.
  void Deserialize(const BlobProto& proto, Blob* blob) override;
};

} // namespace dataset_ops
} // namespace caffe2

#endif // CAFFE2_OPERATORS_DATASET_OPS_H_