Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / pack_rnn_sequence_op.h

#ifndef CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
#define CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_

#include <algorithm>
#include <vector>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

/**
 * Shared implementation for converting between a flat "sequence" layout and a
 * zero-padded "pack" layout, driven by a 1-D int32 LENGTHS input.
 *
 * Let cols = LENGTHS.numel() and rows = max(LENGTHS) (0 if LENGTHS is empty).
 *
 *  - Forward == true  (pack):
 *      INPUTVALUE  [N, feature...]          with N >= sum(LENGTHS)
 *      OUTPUTVALUE [rows, cols, feature...]
 *  - Forward == false (unpack):
 *      INPUTVALUE  [R, cols, feature...]    with R >= rows
 *      OUTPUTVALUE [sum(LENGTHS), feature...]
 *
 * Entry c of LENGTHS describes a run of LENGTHS[c] consecutive feature blocks
 * in the sequence layout; element r of that run lives at (r, c) in the pack
 * layout. Pack positions with r >= LENGTHS[c] are zero in the packed output
 * and are ignored when unpacking.
 *
 * The element type is chosen at run time from INPUTVALUE (int32_t, int64_t,
 * float or double). LENGTHS is dereferenced directly on the host.
 *
 * Enforce failures are raised when:
 *  - INPUTVALUE has too few dimensions for the selected direction,
 *  - LENGTHS is not 1-D or contains a negative entry,
 *  - INPUTVALUE is too small for (or, when unpacking, has a column count
 *    inconsistent with) what LENGTHS describes.
 */
template <class Context, bool Forward>
class PackRNNSequenceOpBase : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit PackRNNSequenceOpBase(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  // Dispatches to DoRunWithType<ValT>() based on the value tensor's dtype.
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t, float, double>>::call(
        this, Input(INPUTVALUE));
  }

  template <typename ValT>
  bool DoRunWithType() {
    // Number of leading "layout" dimensions on the input: one for the flat
    // sequence (pack direction), two (rows, cols) for the packed tensor.
    const int dim_offset = Forward ? 1 : 2;
    const auto& values = Input(INPUTVALUE);
    CAFFE_ENFORCE_GT(values.dim(), dim_offset);

    // block_size is the number of elements in one feature block.
    const int64_t block_size = values.size_from_dim(dim_offset);
    const ValT* values_vec = values.template data<ValT>();

    const auto& lengths = Input(LENGTHS);
    CAFFE_ENFORCE_EQ(lengths.dim(), 1);
    const int64_t cols = lengths.numel();
    const int32_t* lengths_vec = lengths.template data<int32_t>();

    // Single pass over lengths: validate every entry and accumulate the
    // maximum (rows) and the total (length_sum) in 64 bits. A negative entry
    // would shrink length_sum/offset and make the copy loop below address
    // memory outside the input or output buffers, so it is rejected here.
    // With empty lengths both rows and length_sum stay 0.
    int64_t rows = 0;
    int64_t length_sum = 0;
    for (int64_t c = 0; c < cols; ++c) {
      const int32_t len = lengths_vec[c];
      CAFFE_ENFORCE_GE(len, 0, "lengths[", c, "] must be non-negative");
      rows = std::max<int64_t>(rows, len);
      length_sum += len;
    }

    // Make sure every element the copy loop will read actually exists.
    if (Forward) {
      CAFFE_ENFORCE_GE(
          values.size(0),
          length_sum,
          "values has fewer rows than the sum of lengths");
    } else {
      CAFFE_ENFORCE_GE(
          values.size(0),
          rows,
          "packed values has fewer rows than the maximum length");
      CAFFE_ENFORCE_EQ(
          values.size(1),
          cols,
          "second dimension of packed values must match the number of lengths");
    }

    // Output shape: [rows, cols, feature...] for the pack,
    // [length_sum, feature...] for the sequence.
    std::vector<int64_t> shape;
    if (Forward) {
      shape.push_back(rows);
      shape.push_back(cols);
    } else {
      shape.push_back(length_sum);
    }
    shape.insert(
        shape.end(), values.sizes().begin() + dim_offset, values.sizes().end());

    auto* output = Output(OUTPUTVALUE, shape, at::dtype<ValT>());

    auto output_data = output->template mutable_data<ValT>();
    // Zero is the padding value for pack positions beyond a sequence's length.
    math::Set<ValT, Context>(output->numel(), 0, output_data, &context_);

    // offset is the index of the first block of run c in the sequence layout.
    int64_t offset = 0;
    for (int64_t c = 0; c < cols; ++c) {
      const int64_t len = lengths_vec[c];
      for (int64_t r = 0; r < len; ++r) {
        const int64_t seq_index = offset + r;
        const int64_t pack_index = r * cols + c;
        const int64_t input_offset = Forward ? seq_index : pack_index;
        const int64_t output_offset = Forward ? pack_index : seq_index;
        context_.CopyItemsSameDevice(
            values.dtype(),
            block_size,
            values_vec + input_offset * block_size,
            output_data + output_offset * block_size);
      }
      offset += len;
    }
    return true;
  }

 private:
  INPUT_TAGS(INPUTVALUE, LENGTHS);
  OUTPUT_TAGS(OUTPUTVALUE);
};
} // namespace caffe2

#endif // CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_