Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / gather_op.h

#ifndef GATHER_OP_H_
#define GATHER_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

// This maintains index-mapping functions shared by Gather and BatchGather ops.
namespace gather_helper {

// Computes the output shape of a gather along `axis`:
//  [data dims before axis] + [indices dims] + [data dims after axis]
// When `match_outer` is set, the first `axis` dims of `indices_dims` are
// assumed to duplicate DATA's outer dims and are skipped, so only
// indices_dims[axis:] is spliced in.
// `axis` is not range-checked here; callers must validate it against
// data_dims (and, with match_outer, against indices_dims) beforehand.
template <typename IndexType, typename DataDimsVec, typename IndexDimsVec>
static vector<IndexType> calc_output_shape_vector(
    const DataDimsVec& data_dims,
    const IndexDimsVec& indices_dims,
    int axis,
    bool match_outer) {
  const auto data_first = data_dims.begin();
  const auto data_last = data_dims.end();

  // An empty indexed dimension yields DATA's own shape unchanged. This mirrors
  // https://github.com/pytorch/pytorch/pull/13781 so that workflows with an
  // empty batch keep succeeding.
  if (data_dims[axis] == 0) {
    return vector<IndexType>(data_first, data_last);
  }

  const auto axis_pos = data_first + axis;
  // Leading (outer) DATA dims.
  vector<IndexType> result(data_first, axis_pos);
  // Index dims: all of them, or only those after the shared outer dims.
  const auto index_start = indices_dims.begin() + (match_outer ? axis : 0);
  result.insert(result.end(), index_start, indices_dims.end());
  // Trailing (inner) DATA dims, skipping the indexed one.
  result.insert(result.end(), axis_pos + 1, data_last);
  return result;
}

// Check that indices fall within dimension array size with CAFFE_ENFORCE.
//
// Validates each of the `n` entries of `indices` against an axis of size
// `indexing_axis_dim`. When `wrap_indices` is true, a negative index i is
// first mapped to i + indexing_axis_dim (Python-style wrap). After that
// mapping every index must satisfy 0 <= idx < indexing_axis_dim; otherwise
// CAFFE_ENFORCE fails. The failure reports the index exactly as it appears
// in `indices`.
template <typename IndexType>
static void check_indexarray_range(
    const IndexType* indices,
    int64_t n,
    IndexType indexing_axis_dim,
    bool wrap_indices) {
  // Use a 64-bit counter to match `n`. `auto i = 0` would deduce `int` and
  // overflow (UB) for index arrays with more than INT_MAX entries.
  for (int64_t i = 0; i < n; ++i) {
    const IndexType raw_idx = indices[i];
    IndexType idx = raw_idx;
    if (wrap_indices && idx < 0) {
      idx = idx + indexing_axis_dim;
    }
    CAFFE_ENFORCE(
        0 <= idx && idx < indexing_axis_dim,
        "INDICES element is out of DATA bounds, id=",
        // Report the caller-supplied value, not the wrapped one, so the
        // message points at an index that actually exists in INDICES.
        raw_idx,
        " axis_dim=",
        indexing_axis_dim);
  }
}

// Actual gather implementation - resizes output and copies indexed data.
//
// Gathers slices of input `dataIdx` (DATA) along `axis`, selected by the
// integer tensor at input `indicesIdx` (INDICES, elements read as `Index`),
// and writes them to output `outputIdx`, resized to the shape returned by
// calc_output_shape_vector().
//
//   axis         - dimension of DATA to index; a negative value counts from
//                  the back (ONNX convention).
//   wrap_indices - when true, a negative index i is treated as i + dim(axis).
//   match_outer  - when true, INDICES must share DATA's dims before `axis`,
//                  and each outer ("batch") position uses its own row of
//                  indices instead of one index list shared by all batches.
//
// Returns true. Invalid axis/shapes or out-of-range indices fail via
// CAFFE_ENFORCE.
template <typename Index, typename Context>
static bool gather_impl(
    Operator<Context>* op,
    int dataIdx,
    int indicesIdx,
    int outputIdx,
    int axis,
    bool wrap_indices,
    bool match_outer) {
  // If we endup using it on GPU doing O(N) memcpy is probably not best :)
  // TODO: implement prefetching if it starts mattering (TF does it)

  const Tensor& data = op->Input(dataIdx);
  const Tensor& indices = op->Input(indicesIdx);
  const TypeMeta dataType = data.dtype();
  size_t item_bytesize = dataType.itemsize();

  // ONNX allows negative axis to index from the back, valid range: [-r, r].
  if (axis < 0) {
    axis = data.dim() + axis;
  }
  CAFFE_ENFORCE_GE(data.dim(), axis + 1, "DATA should be at least [axis+1]-D");
  CAFFE_ENFORCE_GE(axis, 0, "Axis should be non-negative");
  CAFFE_ENFORCE_LT(axis, data.dim(), "Axis out of range");
  if (match_outer) {
    // In this mode calc_output_shape_vector() skips the first `axis` dims of
    // INDICES. If INDICES has fewer dims, that would step past the end of
    // indices.sizes(), so reject it before computing the shape.
    CAFFE_ENFORCE_GE(
        indices.dim(),
        axis,
        "INDICES should be at least [axis]-D when match_outer is set");
  }

  // New shape:
  //  [data dims before axis] + [indices dims] + [data dims after axis]
  vector<int64_t> shape = calc_output_shape_vector<int64_t>(
      data.sizes(), indices.sizes(), axis, match_outer);
  Tensor* output = op->Output(outputIdx, shape, at::dtype(dataType));
  auto out = static_cast<char*>(output->raw_mutable_data(dataType));

  // Succeed if size of output is zero, which can happen for empty batch which
  // would have data dimension size of 0.
  // This *must* be done AFTER output->raw_mutable_data() above as that has
  // important allocation side effect that we must see.
  if (output->numel() == 0) {
    return true;
  }

  const Index* idxs = indices.template data<Index>();
  auto src_base = static_cast<const char*>(data.raw_data());

  const int64_t outer_dims_product = data.size_to_dim(axis);
  const int64_t block_size = data.size_from_dim(axis + 1);
  const auto block_bytesize = block_size * item_bytesize;

  const int64_t src_indexing_axis_dim = data.size(axis);
  const auto src_batch_bytesize = data.size_from_dim(axis) * item_bytesize;
  // Treat indices as a single block even if they have multiple dimensions.
  // The "gathered batch" is a cumulative result combining indexed blocks.
  const int64_t idx_inner_dims_product = indices.size_from_dim(axis);
  // Number of indices gathered per outer (batch) position.
  int64_t N = indices.numel();
  if (match_outer) {
    CAFFE_ENFORCE_GE(axis, 1, "Axis should be at least 1");
    for (int i = 0; i < axis; i++) {
      CAFFE_ENFORCE_EQ(
          data.size(i),
          indices.size(i),
          "INDICES must have the same outer dims as DATA (before dim AXIS)");
    }
    N = idx_inner_dims_product;
  }

  const auto gathered_batch_bytesize = N * block_size * item_bytesize;

  // Validate every index the copy loops below will read. With match_outer the
  // loops read one row of N indices per batch, i.e. outer_dims_product * N ==
  // indices.numel() entries in total. Checking only the first N would leave
  // the rows for batches > 0 unchecked and allow out-of-bounds reads of DATA.
  // Without match_outer, N already equals indices.numel().
  check_indexarray_range<Index>(
      idxs, indices.numel(), src_indexing_axis_dim, wrap_indices);

  // Resolves the i-th index of the given batch to a non-negative position
  // along `axis` (already range-checked above).
  auto resolve_index = [&](int64_t batch, int64_t i) -> int64_t {
    int64_t idx = match_outer ? idxs[batch * idx_inner_dims_product + i]
                              : idxs[i];
    if (wrap_indices && idx < 0) {
      idx += src_indexing_axis_dim;
    }
    return idx;
  };

  // Loop counters are int64_t to match the int64_t extents; an `int` counter
  // would overflow on large tensors.
  if (data.template IsType<float>() && block_size == 1) {
    // Special-case single-float copy for efficiency
    for (int64_t batch = 0; batch < outer_dims_product; ++batch) {
      const float* src_floats = reinterpret_cast<const float*>(
          src_base + batch * src_batch_bytesize);
      float* dst_floats =
          reinterpret_cast<float*>(out + batch * gathered_batch_bytesize);
      for (int64_t i = 0; i < N; ++i) {
        dst_floats[i] = src_floats[resolve_index(batch, i)];
      }
    }
  } else {
    // outer_dims_product specifies how many times we repeat inner dimensions,
    // so we just iterate over it to cover all outer dimensions.
    for (int64_t batch = 0; batch < outer_dims_product; ++batch) {
      for (int64_t i = 0; i < N; ++i) {
        auto src = src_base + batch * src_batch_bytesize +
            resolve_index(batch, i) * block_bytesize;
        auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize;
        op->getContext()->CopyItemsSameDevice(dataType, block_size, src, dst);
      }
    }
  }
  return true;
}

} // namespace gather_helper

// Gather operator: copies slices of input DATA selected along `axis` by the
// integer tensor INDICES (int32_t or int64_t) into output 0, via
// gather_helper::gather_impl.
//
// Arguments read by the constructor:
//   axis         (int, default 0)
//   match_outer  (bool, default false)
//   wrap_indices (bool) - if absent, enabled only when axis == 0.
template <class Context>
class GatherOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit GatherOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "axis", axis_, 0),
        // TBD: We may want to fix the old index wrap behaviour once we have
        // operator versioning, to only apply it when needed as otherwise its
        // likely an error.
        // Right now, index wrapping is on by default only for axis == 0,
        // since ONNX conversion code relies on it; for any other axis it must
        // be requested explicitly via the argument.
        // axis_ is declared (and so initialized) before wrap_indices_, which
        // makes it safe to read here.
        wrap_indices_(
            OperatorBase::HasArgument("wrap_indices")
                ? Operator<Context>::template GetSingleArgument<bool>(
                      "wrap_indices", false)
                : axis_ == 0),
        OP_SINGLE_ARG(bool, "match_outer", match_outer_, false) {}

  ~GatherOp() noexcept override {}

  bool RunOnDevice() override {
    // Dispatch on the element type of INDICES (int32_t or int64_t).
    const auto& index_tensor = this->template Input<Tensor>(INDICES, CPU);
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, index_tensor);
  }

  template <typename Index>
  bool DoRunWithType() {
    constexpr int kOutputSlot = 0;
    return gather_helper::gather_impl<Index, Context>(
        this,
        DATA,
        INDICES,
        kOutputSlot,
        axis_,
        wrap_indices_,
        match_outer_);
  }

  INPUT_TAGS(DATA, INDICES);

 protected:
  int axis_;
  bool wrap_indices_;
  bool match_outer_;
};

} // namespace caffe2
#endif // GATHER_OP_H_