#pragma once
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/perfkernels/embedding_lookup.h"
#ifdef USE_FBGEMM
#include "fbgemm/Fbgemm.h"
#endif
#include <algorithm>
#include <functional>
namespace caffe2 {
// A templated class that implements SparseLengths[Sum,WeightedSum,Mean].
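// As an illustrative example of the Sum reduction (no weights): with
//   DATA    = [[1, 2], [3, 4], [5, 6]]   (N = 3 rows, block size D = 2),
//   INDICES = [0, 2, 1, 0], and
//   LENGTHS = [2, 2],
// the first segment pools rows {0, 2} and the second pools rows {1, 0}, so
//   OUTPUT  = [[1 + 5, 2 + 6], [3 + 1, 4 + 2]] = [[6, 8], [4, 6]].
// WeightedSum scales each gathered row by its weight before pooling, and
// Mean divides each pooled row by the corresponding length.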
template <
typename T, // output type
class InputTypes, // supported input types, such as TensorTypes<float>
    bool USE_WEIGHT = false, // Whether this is SparseLengthsWeightedSum
    bool USE_MEAN = false, // Whether this is SparseLengthsMean
bool USE_POSITIONAL_WEIGHT = false
// USE_WEIGHT = true and USE_POSITIONAL_WEIGHT = true
// -> SparseLengthsPositionalWeightedSum
>
class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
public:
USE_OPERATOR_FUNCTIONS(CPUContext);
template <class... Args>
explicit CPUSparseLengthsReductionOp(Args&&... args)
: Operator<CPUContext>(std::forward<Args>(args)...) {
    static_assert(
        !(USE_WEIGHT && USE_MEAN), "Cannot both specify weight and mean.");
}
~CPUSparseLengthsReductionOp() {}
  // Currently, we support float and at::Half as the input data type, and
  // int32_t and int64_t as the index type.
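  // RunOnDevice dispatches on the DATA element type, and DoRunWithType then
  // dispatches on the INDICES element type, so DoRunWithType2 is instantiated
  // for every supported (data, index) type combination.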
bool RunOnDevice() override {
return DispatchHelper<InputTypes>::call(this, Input(DATA));
}
template <typename InputType>
bool DoRunWithType() {
return DispatchHelper<TensorTypes2<int32_t, int64_t>, InputType>::call(
this, Input(INDICES));
}
template <typename InputType, typename IndexType>
bool DoRunWithType2() {
auto& dataInput = Input(DATA);
auto& indicesInput = Input(INDICES);
auto& lengthsInput = Input(LENGTHS);
const int64_t M = lengthsInput.size(0);
const int64_t indices_size = indicesInput.numel();
auto shape = dataInput.sizes().vec();
shape[0] = M;
auto* output = Output(0, shape, at::dtype<T>());
T* out_data = output->template mutable_data<T>();
if (indices_size == 0) {
if (M > 0) {
memset(out_data, 0, output->numel() * sizeof(T));
}
return true;
}
CAFFE_ENFORCE_EQ(1, indicesInput.dim(), "INDICES must be a vector");
CAFFE_ENFORCE_EQ(1, lengthsInput.dim(), "LENGTHS must be a vector");
const int64_t N = dataInput.size(0);
const int D = dataInput.size_from_dim(1);
const InputType* in_data = dataInput.template data<InputType>();
const IndexType* indices = indicesInput.template data<IndexType>();
const int* lengths = lengthsInput.template data<int>();
const T* in_weight = nullptr;
if (USE_WEIGHT) {
// static if
auto& weightInput = Input(WEIGHT);
CAFFE_ENFORCE_EQ(1, weightInput.dim(), "WEIGHT must be a vector");
if (!USE_POSITIONAL_WEIGHT) {
CAFFE_ENFORCE_EQ(
weightInput.numel(),
indices_size,
"Weight should have the same length as indices.");
}
in_weight = weightInput.template data<T>();
}
#ifdef USE_FBGEMM
    // If this is the first call, or the block size has changed (which should
    // never actually happen), generate a kernel.
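    // GenerateEmbeddingSpMDM produces a kernel specialized for this block
    // size D; caching it in a member avoids regenerating it on every call to
    // the operator.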
if (D != last_block_size) {
last_block_size = D;
if (std::is_same<InputType, float>::value) {
if (std::is_same<IndexType, std::int32_t>::value) {
kernel_fp32_i32_ =
fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
} else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
kernel_fp32_i64_ =
fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
}
} else {
CAFFE_ENFORCE((std::is_same<InputType, at::Half>::value));
if (std::is_same<IndexType, std::int32_t>::value) {
kernel_fp16_i32_ =
fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int32_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
} else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
kernel_fp16_i64_ =
fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int64_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
}
}
}
bool success;
if (std::is_same<InputType, float>::value) {
if (std::is_same<IndexType, std::int32_t>::value) {
success = kernel_fp32_i32_(
M,
indices_size,
N,
reinterpret_cast<const float*>(in_data),
indicesInput.template data<std::int32_t>(),
lengths,
in_weight,
out_data);
} else {
success = kernel_fp32_i64_(
M,
indices_size,
N,
reinterpret_cast<const float*>(in_data),
indicesInput.template data<std::int64_t>(),
lengths,
in_weight,
out_data);
}
} else {
if (std::is_same<IndexType, std::int32_t>::value) {
success = kernel_fp16_i32_(
M,
indices_size,
N,
reinterpret_cast<const fbgemm::float16*>(in_data),
indicesInput.template data<std::int32_t>(),
lengths,
in_weight,
out_data);
} else {
success = kernel_fp16_i64_(
M,
indices_size,
N,
reinterpret_cast<const fbgemm::float16*>(in_data),
indicesInput.template data<std::int64_t>(),
lengths,
in_weight,
out_data);
}
}
if (success) {
return true;
}
int64_t current = 0;
for (int m = 0; m < M; ++m) {
for (int i = 0; i < lengths[m]; ++i) {
CAFFE_ENFORCE_LT(
current,
indices_size,
"Your input seems to be incorrect: the sum of lengths values "
"should be the size of the indices tensor, but it appears not.");
IndexType idx = indices[current];
CAFFE_ENFORCE(
0 <= idx && idx < N,
"Index ",
current,
" is out of bounds: ",
idx,
", range 0 to ",
N);
++current;
}
}
CAFFE_ENFORCE_EQ(
current,
indices_size,
"Your input seems to be incorrect: the sum of lengths values should be "
"the size of the indices tensor, but it appears not.");
return false;
#endif
    // When FBGEMM is not available, delegate the work to a perfkernel that
    // branches based on architecture.
EmbeddingLookup<IndexType, InputType, T, USE_POSITIONAL_WEIGHT>(
D,
M,
indices_size,
N,
in_data,
indices,
lengths,
in_weight,
nullptr, // scale_bias field is only used in SparseLengths8BitsRowwiseOp
USE_MEAN,
out_data);
return true;
}
enum {
DATA = 0, // Data input.
WEIGHT = 1, // Weight input used in SparseLengthsWeightedSum
INDICES = 1 + USE_WEIGHT, // 1 in SparseLengths[Sum,Mean] and
// 2 in SparseLengthsWeightedSum
LENGTHS = 2 + USE_WEIGHT, // 2 in SparseLengths[Sum, Mean],
// 3 in SparseLengthsWeightedSum
};
#ifdef USE_FBGEMM
private:
std::int64_t last_block_size{-1};
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type
kernel_fp32_i32_;
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type
kernel_fp32_i64_;
fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int32_t>::Type
kernel_fp16_i32_;
fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int64_t>::Type
kernel_fp16_i64_;
#endif
};
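// A minimal sketch, for illustration only, of how this template might be
// instantiated (the real operator registrations live elsewhere in Caffe2, and
// the alias names below are hypothetical):
//
//   using SparseLengthsSumOp =
//       CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>>;
//   using SparseLengthsWeightedSumOp =
//       CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>, true>;
//   using SparseLengthsMeanOp =
//       CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>, false, true>;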
template <typename T, class Context, class Engine = DefaultEngine>
class TTSparseLengthsSumOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit TTSparseLengthsSumOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
factor_i(this->template GetRepeatedArgument<int>(
"factor_i",
vector<int>{1, 1, 1})),
factor_j(this->template GetRepeatedArgument<int>(
"factor_j",
vector<int>{1, 1, 1})),
ranks(this->template GetRepeatedArgument<int>(
"ranks",
vector<int>{1, 1, 1, 1})),
emb_size(this->template GetSingleArgument<int>("emb_size", 64)) {
    // Cumulative product of factor_i, used to decompose a flat row index into
    // per-core indices (see Ind2Sub below).
l_cumprod.push_back(1);
for (size_t i = 1; i < factor_i.size(); ++i) {
l_cumprod.push_back(l_cumprod[i - 1] * factor_i[i - 1]);
}
}
~TTSparseLengthsSumOp() {}
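  // Decompose each flat row index into one index per tensor core using the
  // cumulative products in l_cumprod (a mixed-radix decomposition). For
  // example, with factor_i = {3, 4, 5} we get l_cumprod = {1, 3, 12}, and a
  // flat index of 37 decomposes into the per-core indices {1, 0, 3}, since
  // 37 = 1 * 1 + 0 * 3 + 3 * 12.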
void Ind2Sub(int64_t* out_factor_index, const int64_t* indices, int len) {
// TODO: vectorization
auto N = factor_i.size();
for (int j = 0; j < len; j++) {
auto idx = indices[j];
for (int i = N; i > 0; i--) {
out_factor_index[j * N + i - 1] = idx / l_cumprod[i - 1];
idx = idx % l_cumprod[i - 1];
}
}
}
bool GetSlice(
std::vector<std::vector<T>>& tgt_slice,
const T* core,
const vector<int64_t>& ind_slice,
int bs,
int idx) {
    // Implements the functionality of index_select(core, 1, ind_slice).
auto num_of_elements = ranks[idx] * factor_j[idx] * ranks[idx + 1];
for (int i = 0; i < bs; i++) {
memcpy(
tgt_slice[i].data(),
core + ind_slice[i] * num_of_elements,
num_of_elements * sizeof(T));
}
return true;
}
  // ind: stores the per-core index for each tensor core
  // bs: the number of indices
  // GatherAllRows computes the LengthsSum in two steps: 1) it uses the tensor
  // train decomposition to compute the embedding for each index, and 2) it
  // sums the embeddings within each bag.
  // In step 1), all indices are batched together. For every index, the
  // pre-computed per-core index (ind) is used to extract the corresponding
  // slice of each core, and GEMMs are applied sequentially to those slices to
  // produce the embedding for that index.
  // In step 2), the embeddings computed in step 1) are summed within each bag.
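  // As an illustrative sketch of the shapes involved: core k is stored as a
  // [factor_i[k], ranks[k] * factor_j[k] * ranks[k + 1]] matrix (see GetSlice
  // above). For example, with factor_i = {10, 10, 10}, factor_j = {4, 4, 4},
  // and ranks = {1, 8, 8, 1}, the three cores have row widths
  // 1 * 4 * 8 = 32, 8 * 4 * 8 = 256, and 8 * 4 * 1 = 32, and the final
  // embedding width is emb_size = 4 * 4 * 4 = 64.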
bool GatherAllRows(
int64_t* ind,
int bs,
int x_len,
vector<const T*> cores,
int segments,
const int* lengths,
T* out_data) {
    // Compute the largest intermediate-result buffer size needed.
    // TODO: dynamic allocation size: cur_rows*factor_j[i]*ranks[i+1],
    // and also explore contiguous memory storage for res and int_res.
int max_rank = *max_element(ranks.begin(), ranks.end());
std::vector<std::vector<T>> res(bs, std::vector<T>(emb_size * max_rank, 0));
std::vector<std::vector<T>> int_res(
bs, std::vector<T>(emb_size * max_rank, 0));
// Store the matrix A
vector<T*> Y_ptr(bs);
// Store the intermediate result in each layer
vector<T*> Z_ptr(bs);