#ifndef CAFFE2_OPERATORS_SEGMENT_REDUCTION_OP_H_
#define CAFFE2_OPERATORS_SEGMENT_REDUCTION_OP_H_
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/reducer_functors.h"
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(LengthsSum);
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(LengthsMean);
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(LengthsMax);
namespace caffe2 {
template <typename TData>
class BaseInputAccessor {
public:
BaseInputAccessor() {}
bool observeInput(const Tensor& dataInput) {
data_ = dataInput.raw_data();
return dataInput.template IsType<TData>();
}
inline const TData*
getBlockPtr(int64_t in_block_size, int64_t idx, int64_t /* blocks */ = 1) {
return static_cast<const TData*>(data_) + in_block_size * idx;
}
protected:
const void* data_ = nullptr;
};
////////////////////////////////////////////////////////////////////////////////
// Range reducer ops: leverage the fact that each input segment is contiguous
// and allow reducer functors to do something special.
// Note: there are no real use cases for this yet :)
// Also, additional arguments are not supported for now.
////////////////////////////////////////////////////////////////////////////////
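//
// Illustrative example (hypothetical values, not part of any op schema): with
//   DATA        = [[1, 2], [3, 4], [5, 6], [7, 8]]   (shape 4x2)
//   SEGMENT_IDS = [0, 0, 1, 1]
// a "Sum" range reducer would produce
//   OUTPUT      = [[4, 6], [12, 14]]                  (shape 2x2)
// i.e. consecutive rows of DATA that share a segment id are reduced into a
// single output row, and the number of output rows is SEGMENT_IDS[-1] + 1.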
/**
 * Base implementation for segment reduction op that leverages contiguity of
 * the data
*
* Assumes that segments are sorted and there are no skip indices
*/
template <
typename T,
typename SIndex,
class Context,
class RangeReducer,
class InputAccessor = BaseInputAccessor<T>>
class AbstractSortedSegmentRangeOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_SIMPLE_CTOR_DTOR(AbstractSortedSegmentRangeOp);
bool RunOnDevice() override {
auto& dataInput = Input(DATA);
auto& segment_ids = Input(SEGMENT_IDS);
CAFFE_ENFORCE_EQ(1, segment_ids.dim(), "SEGMENT_IDS must be a vector");
auto N = segment_ids.size(0);
CAFFE_ENFORCE_EQ(
N,
dataInput.size(0),
"SEGMENT_IDS must have the same length as outer dimension of DATA");
OPERATOR_NEEDS_FEATURE(
inputAccessor_.observeInput(dataInput),
"Unsupported input type: ",
dataInput.dtype().name(),
".");
const SIndex* s_ids = segment_ids.template data<SIndex>();
const SIndex K = N > 0 ? s_ids[N - 1] + 1 : 0;
auto shape = dataInput.sizes().vec();
shape[0] = K;
auto* output = Output(0, shape, at::dtype<T>());
T* out = output->template mutable_data<T>();
if (N == 0) {
return true;
}
int64_t block_size = dataInput.numel() / N;
// Assume the segments are sorted and there are no gaps
CAFFE_ENFORCE_EQ(0, s_ids[0], "Indices must be sorted and not have gaps");
for (int64_t i = 0; i < N;) {
int64_t start = i;
for (++i; i < N && s_ids[start] == s_ids[i]; ++i)
;
RangeReducer()(
block_size,
i - start,
inputAccessor_.getBlockPtr(block_size, start, i - start),
out + block_size * s_ids[start],
&context_);
// check correctness of the next segment
if (i < N) {
CAFFE_ENFORCE_EQ(
s_ids[start] + 1,
s_ids[i],
"Indices must be sorted and not have gaps");
}
}
return true;
}
static constexpr int kNumInputs = 2;
INPUT_TAGS(DATA, SEGMENT_IDS);
private:
InputAccessor inputAccessor_;
};
template <
typename T,
typename SIndex,
class Context,
class RangeReducerGradient>
class AbstractSortedSegmentRangeGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_SIMPLE_CTOR_DTOR(AbstractSortedSegmentRangeGradientOp);
bool RunOnDevice() override {
// TODO(azzolini): avoid using input/output if not used by a particular op
auto& data_in = Input(DATA_IN);
auto& data_out = Input(DATA_OUT);
auto& segment_grads = Input(SEGMENT_GRADS);
auto& segment_ids = Input(SEGMENT_IDS);
CAFFE_ENFORCE_EQ(1, segment_ids.dim(), "SEGMENT_IDS must be a vector");
int64_t N = segment_ids.size(0);
const SIndex* s_ids = segment_ids.template data<SIndex>();
const T* s_grads = segment_grads.template data<T>();
const T* d_in = data_in.template data<T>();
const T* d_out = data_out.template data<T>();
auto shape = segment_grads.sizes().vec();
shape[0] = N;
auto* data_grads = Output(0, shape, at::dtype<T>());
const SIndex K = segment_grads.size(0);
T* out = data_grads->template mutable_data<T>();
if (N == 0) {
return true;
}
int64_t block_size = segment_grads.size_from_dim(1);
// Assume the segments are sorted and there are no gaps
CAFFE_ENFORCE_EQ(0, s_ids[0], "Indices must be sorted and not have gaps");
// repeat the check from forward op
CAFFE_ENFORCE_EQ(
K - 1, s_ids[N - 1], "Indices must be sorted and not have gaps");
for (int64_t i = 0; i < N;) {
int64_t start = i;
for (++i; i < N && s_ids[start] == s_ids[i]; ++i)
;
auto expanded_idx = block_size * start;
auto reduced_idx = block_size * s_ids[start];
RangeReducerGradient()(
block_size,
i - start,
s_grads + reduced_idx,
out + expanded_idx,
d_in + expanded_idx,
d_out + reduced_idx,
&context_);
// check correctness of the next segment
if (i < N) {
CAFFE_ENFORCE_EQ(
s_ids[start] + 1,
s_ids[i],
"Indices must be sorted and not have gaps");
}
}
return true;
}
static constexpr int kNumInputs = 4;
INPUT_TAGS(DATA_IN, DATA_OUT, SEGMENT_GRADS, SEGMENT_IDS);
};
template <typename T, typename SIndex, typename Context, typename ReducerDef>
struct AbstractSortedSegmentRangeDef {
using OpDef = ReducerDef;
static constexpr const char* basename = "SortedSegmentRange";
static constexpr const char* doc = R"DOC(
Applies '{op}' to each segment of the input tensor. In order to allow for more
efficient implementation of '{op}', the input segments have to be contiguous
and non-empty.
SEGMENT_IDS is a vector that maps each of the first dimension slices of the
DATA to a particular group (segment). Values belonging to the same segment are
aggregated together.
The first dimension of the output is equal to the number of input segments,
i.e. `SEGMENT_IDS[-1]+1`. Other dimensions are inherited from the input tensor.
{op_doc}
)DOC";
static void PopulateSchema(OpSchema& schema) {
schema.Input(0, "DATA", "Input tensor to be aggregated");
schema.Input(
1,
"SEGMENT_IDS",
"Vector with the same length as the first dimension of DATA "
"and values in the range 0..K-1 and in increasing order that "
"maps each slice of DATA to one of the segments");
schema.Output(
0,
"OUTPUT",
"Aggregated tensor with the first dimension of K and the "
"other dimentsions inherited from DATA");
}
using ForwardOp = AbstractSortedSegmentRangeOp<
T,
SIndex,
Context,
typename ReducerDef::template Reducer<T, Context>>;
using BackwardOp = AbstractSortedSegmentRangeGradientOp<
T,
SIndex,
Context,
typename ReducerDef::template ReducerGradient<T, Context>>;
struct GetGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
string(basename) + ReducerDef::name + "Gradient",
"",
vector<string>{I(0), O(0), GO(0), I(1)},
// no gradient on segment_ids!
vector<string>{GI(0)});
}
};
};
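// A minimal sketch of how a def like the one above is typically wired up
// (hypothetical; the actual registrations live in the corresponding .cc file
// and may go through dedicated helper macros):
//
//   using SortedSegmentRangeSumDef = AbstractSortedSegmentRangeDef<
//       float, int, CPUContext, SumRangeReducerDef>;
//   REGISTER_CPU_OPERATOR(
//       SortedSegmentRangeSum, SortedSegmentRangeSumDef::ForwardOp);
//   REGISTER_CPU_OPERATOR(
//       SortedSegmentRangeSumGradient, SortedSegmentRangeSumDef::BackwardOp);
//   REGISTER_GRADIENT(
//       SortedSegmentRangeSum, SortedSegmentRangeSumDef::GetGradient);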
////////////////////////////////////////////////////////////////////////////////
// Incremental reducer ops: assume that the reducer consumes pieces of data one
// by one. Also supports additional arguments passed to the reducer, e.g.
// scalars for weighted sum.
//
// Note: in the current implementation additional inputs are considered
// auxiliary constants and have limitations:
// - there is no gradient computation for auxiliary inputs
// - auxiliary inputs aren't affected by fused embedding lookup in operations
// like sparse_sorted_segment
////////////////////////////////////////////////////////////////////////////////
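//
// Illustrative example of an auxiliary input (hypothetical values): a
// weighted-sum reducer may take a per-row SCALARS input alongside DATA, e.g.
//   DATA    = [[1, 1], [2, 2], [3, 3]]
//   SCALARS = [0.1, 0.2, 0.3]
// and reduce the rows as 0.1 * [1, 1] + 0.2 * [2, 2] + 0.3 * [3, 3] = [1.4, 1.4].
// Per the note above, such auxiliary inputs are treated as constants here.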
/**
 * @brief Simple non-segmented reduction over the first (or last) few
 * dimensions of the tensor
*
* Inputs:
 * 0: DATA - input tensor to be reduced
* 1..P: AUX_ARG_<I> - optional additional arguments to be passed to the
* reducer
*
* Args:
 * num_reduce_dim (default 1) - the number of dims at the front (or back) of
 *                              the tensor to reduce
*
* Output:
 * Tensor without the first (or last) `num_reduce_dim` dimensions of DATA
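 *
 * Illustrative shape example (hypothetical): with DATA of shape (2, 3, 4) and
 * num_reduce_dim = 1, the front-reducing variant produces an output of shape
 * (3, 4), while the back-reducing variant produces (2, 3)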
*/
template <
typename T,
class Context,
class Reducer,
bool FirstDim,
class InputAccessor = BaseInputAccessor<T>>
class AbstractReduceFrontOrBackOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit AbstractReduceFrontOrBackOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(int, "num_reduce_dim", num_reduce_dims_, 1) {}
bool RunOnDevice() override {
auto& data = Input(0);
// If more complicated fixed size logic becomes necessary, it can be moved
// to the reducer class
int64_t in_block_size = FirstDim
? data.size_from_dim(num_reduce_dims_)
: data.size_to_dim(data.dim() - num_reduce_dims_);
return DispatchHelper<typename Reducer::FixedDispatch>::call(
this, in_block_size);
}
template <int FixedSize>
bool DoRunWithValue() {
auto& data = Input(0);
CAFFE_ENFORCE_LE(num_reduce_dims_, data.dim());
typename Reducer::Meta ctx(FirstDim);
ctx.observeInput(0, data, num_reduce_dims_);
for (int i = 1; i < Reducer::kInputCount; ++i) {
auto& aux_in = Input(i);
ctx.observeInput(i, aux_in, num_reduce_dims_);
}
OPERATOR_NEEDS_FEATURE(
inputAccessor_.observeInput(data),
"Unsupported input type: ",
data.dtype().name(),
".");
vector<int64_t> shape;
ctx.appendOutputShape(&shape);
auto* output = Output(0, shape, at::dtype<T>());
T* out = output->template mutable_data<T>();
const int block_size = FirstDim
? data.size_from_dim(num_reduce_dims_)
: data.size_from_dim(data.dim() - num_reduce_dims_);
const int num_blocks = block_size > 0 ? data.numel() / block_size : 0;
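    // Illustrative example (hypothetical shapes): for DATA of shape (2, 3, 4)
    // with num_reduce_dim = 1, the front-reducing variant uses block_size = 12
    // and num_blocks = 2, while the back-reducing variant uses block_size = 4
    // and num_blocks = 6; each block is handed to the reducer in turn below.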
Reducer r(ctx, out, &context_);
for (int64_t i = 0; i < num_blocks; ++i) {
r.template process<FixedSize>(
ctx, inputAccessor_.getBlockPtr(block_size, i), i, &context_);
}
r.template finish<FixedSize>(ctx, &context_);
return true;
}