#ifndef GATHER_OP_H_
#define GATHER_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"

namespace caffe2 {
// This namespace contains index-mapping helper functions shared by the Gather
// and BatchGather ops.
namespace gather_helper {
// New shape is concatenation:
// [data dims before axis] + [indices dims] + [data dims after axis]
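// For example, data dims [2, 3, 4] and axis = 1 with indices dims [5, 6]
// give output dims [2, 5, 6, 4]; with match_outer = true and indices dims
// [2, 6] (outer dims matching DATA), the output dims are [2, 6, 4].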
template <typename IndexType, typename DataDimsVec, typename IndexDimsVec>
static vector<IndexType> calc_output_shape_vector(
const DataDimsVec& data_dims,
const IndexDimsVec& indices_dims,
int axis,
bool match_outer) {
vector<IndexType> shape;
  // If the dimension we are indexing is empty, just use data_dims as the shape.
  // This replicates the behavior of https://github.com/pytorch/pytorch/pull/13781,
  // which is needed to allow workflows with an empty batch to succeed.
if (data_dims[axis] == 0) {
shape.insert(shape.end(), data_dims.begin(), data_dims.end());
} else {
shape.insert(shape.end(), data_dims.begin(), data_dims.begin() + axis);
if (match_outer) {
shape.insert(
shape.end(), indices_dims.begin() + axis, indices_dims.end());
} else {
shape.insert(shape.end(), indices_dims.begin(), indices_dims.end());
}
shape.insert(shape.end(), data_dims.begin() + axis + 1, data_dims.end());
}
return shape;
}
// Check that indices fall within dimension array size with CAFFE_ENFORCE.
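// For example, with indexing_axis_dim = 4 and wrap_indices = true, an index of
// -1 is treated as 3, while -5 or 4 would trigger the enforce failure below.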
template <typename IndexType>
static void check_indexarray_range(
const IndexType* indices,
int64_t n,
IndexType indexing_axis_dim,
bool wrap_indices) {
  for (int64_t i = 0; i < n; ++i) {
auto idx = indices[i];
if (wrap_indices && idx < 0) {
idx = idx + indexing_axis_dim;
}
CAFFE_ENFORCE(
0 <= idx && idx < indexing_axis_dim,
"INDICES element is out of DATA bounds, id=",
idx,
" axis_dim=",
indexing_axis_dim);
}
}
// Actual gather implementation - resizes output and copies indexed data.
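// dataIdx, indicesIdx and outputIdx are the input/output slots on `op` from
// which DATA and INDICES are read and to which the gathered result is written.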
template <typename Index, typename Context>
static bool gather_impl(
Operator<Context>* op,
int dataIdx,
int indicesIdx,
int outputIdx,
int axis,
bool wrap_indices,
bool match_outer) {
  // If we end up using this on GPU, doing O(N) memcpys is probably not ideal :)
// TODO: implement prefetching if it starts mattering (TF does it)
const Tensor& data = op->Input(dataIdx);
const Tensor& indices = op->Input(indicesIdx);
const TypeMeta dataType = data.dtype();
size_t item_bytesize = dataType.itemsize();
  // ONNX allows a negative axis to index from the back; valid range: [-r, r-1].
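  // For example, with r = 3, axis = -1 maps to 2 after the adjustment below.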
if (axis < 0) {
axis = data.dim() + axis;
}
CAFFE_ENFORCE_GE(data.dim(), axis + 1, "DATA should be at least [axis+1]-D");
CAFFE_ENFORCE_GE(axis, 0, "Axis should be non-negative");
CAFFE_ENFORCE_LT(axis, data.dim(), "Axis out of range");
// New shape:
// [data dims before axis] + [indices dims] + [data dims after axis]
vector<int64_t> shape = calc_output_shape_vector<int64_t>(
data.sizes(), indices.sizes(), axis, match_outer);
Tensor* output = op->Output(outputIdx, shape, at::dtype(dataType));
auto out = static_cast<char*>(output->raw_mutable_data(dataType));
  // Succeed if the output is empty, which can happen for an empty batch where
  // the indexed data dimension has size 0.
  // This *must* be done AFTER the output->raw_mutable_data() call above, since
  // that call has an important allocation side effect that we rely on.
if (output->numel() == 0) {
return true;
}
const Index* idxs = indices.template data<Index>();
auto src_base = static_cast<const char*>(data.raw_data());
auto outer_dims_product = data.size_to_dim(axis);
auto block_size = data.size_from_dim(axis + 1);
auto block_bytesize = block_size * item_bytesize;
auto src_indexing_axis_dim = data.size(axis);
auto src_batch_bytesize = data.size_from_dim(axis) * item_bytesize;
// Treat indices as a single block even if they have multiple dimensions.
// The "gathered batch" is a cumulative result combining indexed blocks.
auto idx_inner_dims_product = indices.size_from_dim(axis);
auto N = indices.numel();
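  // Illustrative example of the bookkeeping above: for float data with dims
  // [2, 3, 4], indices with dims [5] and axis = 1, we get
  // outer_dims_product = 2, block_size = 4, block_bytesize = 16,
  // src_indexing_axis_dim = 3, src_batch_bytesize = 48 and N = 5, producing
  // an output of shape [2, 5, 4].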
if (match_outer) {
CAFFE_ENFORCE_GE(axis, 1, "Axis should be at least 1");
for (auto i = 0; i < axis; i++) {
CAFFE_ENFORCE_EQ(
data.size(i),
indices.size(i),
"INDICES must have the same outer dims as DATA (before dim AXIS)");
}
N = idx_inner_dims_product;
}
auto gathered_batch_bytesize = N * block_size * item_bytesize;
check_indexarray_range<Index>(idxs, N, src_indexing_axis_dim, wrap_indices);
// Special-case single-float copy for efficiency
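  // (block_size == 1 means there are no non-trivial dimensions after axis, so
  // each gathered block is a single element and we can copy floats directly.)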
if (data.template IsType<float>() && block_size == 1) {
    for (int64_t batch = 0; batch < outer_dims_product; ++batch) {
const float* src_floats =
(const float*)(src_base + batch * src_batch_bytesize);
float* dst_floats = (float*)(out + batch * gathered_batch_bytesize);
      for (int64_t i = 0; i < N; ++i) {
auto idx = idxs[i];
if (match_outer) {
idx = idxs[batch * idx_inner_dims_product + i];
}
if (wrap_indices && idx < 0) {
idx = idx + src_indexing_axis_dim;
}
dst_floats[i] = src_floats[idx];
}
}
} else {
    // outer_dims_product is the number of outer blocks (the product of the
    // dims before axis), so iterating over it covers all outer dimensions.
    for (int64_t batch = 0; batch < outer_dims_product; ++batch) {
      for (int64_t i = 0; i < N; ++i) {
auto idx = idxs[i];
if (match_outer) {
idx = idxs[batch * idx_inner_dims_product + i];
}
if (wrap_indices && idx < 0) {
idx = idx + src_indexing_axis_dim;
}
auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize;
auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize;
op->getContext()->CopyItemsSameDevice(dataType, block_size, src, dst);
}
}
}
return true;
}
} // namespace gather_helper

template <class Context>
class GatherOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit GatherOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(int, "axis", axis_, 0),
OP_SINGLE_ARG(bool, "match_outer", match_outer_, false) {
    // TBD: We may want to fix the old index-wrapping behaviour once we have
    // operator versioning, and only apply it when requested, since otherwise
    // it is likely an error.
    // Right now, we apply index wrapping by default only when axis == 0,
    // since we have ONNX conversion code that relies on it. For other axis
    // values it must be requested explicitly via the "wrap_indices" argument.
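    // For example, with the default (axis == 0 and no explicit "wrap_indices"
    // argument), an INDICES value of -1 selects the last slice of DATA.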
if (OperatorBase::HasArgument("wrap_indices")) {
      wrap_indices_ = Operator<Context>::template GetSingleArgument<bool>(
          "wrap_indices", false);
    } else {
      wrap_indices_ = (axis_ == 0);
}
}
virtual ~GatherOp() noexcept {}
bool RunOnDevice() override {
return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
this, this->template Input<Tensor>(INDICES, CPU));
}
template <typename Index>
bool DoRunWithType() {
return gather_helper::gather_impl<Index, Context>(
this, DATA, INDICES, 0, axis_, wrap_indices_, match_outer_);
}
INPUT_TAGS(DATA, INDICES);
protected:
int axis_;
bool wrap_indices_;
bool match_outer_;
};
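
// Usage sketch (illustrative; the actual registration and schema live in
// gather_op.cc): the CPU operator is registered roughly as
//   REGISTER_CPU_OPERATOR(Gather, GatherOp<CPUContext>);
// Given DATA of shape [3, 4] and INDICES = [2, 0] with the default axis = 0,
// it produces an OUTPUT of shape [2, 4] holding rows 2 and 0 of DATA.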
} // namespace caffe2
#endif // GATHER_OP_H_