#ifndef QUANT_DECODE_OP_H_
#define QUANT_DECODE_OP_H_
#include <c10/util/typeid.h>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
namespace caffe2 {
namespace {
template <class CodebookT, class CodeT>
void Decode(
const Tensor& codebook,
const Tensor& codes,
/* optional */ const Tensor* const decoded_grad,
Tensor* const output,
bool resizeOnly) {
CAFFE_ENFORCE(codebook.IsType<CodebookT>());
auto* cb_ptr = codebook.data<CodebookT>();
int cb_size = codebook.numel();
CAFFE_ENFORCE(codes.IsType<CodeT>());
auto* code_ptr = codes.data<CodeT>();
if (decoded_grad == nullptr) {
// Forward pass: decode and store codebook values in output.
output->ResizeLike(codes);
auto* out_ptr = output->template mutable_data<CodebookT>();
if (resizeOnly) {
return;
}
int sz = output->numel();
for (int i = 0; i < sz; i++) {
DCHECK_LE(*code_ptr, cb_size);
*out_ptr++ = cb_ptr[*code_ptr++];
}
} else {
// Backward pass: decode and accumulate gradient w.r.t. codebook values.
CAFFE_ENFORCE_EQ(codes.numel(), decoded_grad->numel());
auto* gradient_ptr = decoded_grad->data<CodebookT>();
auto* const gradient_end = gradient_ptr + decoded_grad->numel();
CAFFE_ENFORCE_EQ(cb_size, output->numel());
auto* out_ptr = output->template mutable_data<CodebookT>();
while (gradient_ptr < gradient_end) {
DCHECK_LE(*code_ptr, cb_size);
out_ptr[*code_ptr++] += *gradient_ptr++;
}
}
}
// Expands to one {key, decoder} entry for the dispatch table built in
// DecodeGeneral. The key is the pair of type ids for (codebookType,
// codesType); the decoder is a capture-less lambda that forwards its
// arguments to Decode<codebookType, codesType>.
// Keep comments out of the macro body: a `//` comment ahead of a line
// continuation would swallow the rest of the expansion.
#define REGISTER_DECODER(codebookType, codesType)                     \
  {                                                                   \
    {TypeMeta::Id<codebookType>(), TypeMeta::Id<codesType>()},        \
        [](const Tensor& cb_in,                                       \
           const Tensor& codes_in,                                    \
           const Tensor* grad_in,                                     \
           Tensor* decoded_out,                                       \
           bool resize_only_in) {                                     \
          Decode<codebookType, codesType>(                            \
              cb_in, codes_in, grad_in, decoded_out, resize_only_in); \
        }                                                             \
  }
inline void DecodeGeneral(
const Tensor& codebook,
const Tensor& codes,
const Tensor* gradient,
Tensor* outDecoded,
bool resizeOnly) {
const static std::map<
std::pair<TypeIdentifier, TypeIdentifier>,
std::function<void(
const Tensor& codebook,
const Tensor& codes,
const Tensor* gradient,
Tensor* outDecoded,
bool resizeOnly)>>
gDecoderMapper = {REGISTER_DECODER(float, uint8_t),
REGISTER_DECODER(float, uint16_t),
REGISTER_DECODER(float, int32_t)};
gDecoderMapper.at({codebook.dtype().id(), codes.dtype().id()})(
codebook, codes, gradient, outDecoded, resizeOnly);
}
} // namespace
// QuantDecodeOp decodes tensors based on a given codebook.
// The codebook is generated by model_quantize.py.
// Compile-time policy (template argument of QuantDecodeOp) controlling how
// often the decode work is actually performed.
enum class QuantDecodeRunTy {
// Decode the inputs on every run of the operator.
RUN_ALWAYS,
// Decode on the first run only; on later runs the operator passes
// resizeOnly=true, so outputs are resized/allocated but not re-decoded.
RUN_ONCE,
};
/**
 * CPU operator that decodes one or more tensors of codes through a shared
 * codebook.
 *
 * Inputs:  Input(0) is the codebook (enforced to be float); Input(i + 1) is
 *          the i-th tensor of codes.
 * Outputs: Output(i) receives the decoded form of Input(i + 1).
 *
 * @tparam QuantDecodeRun  RUN_ALWAYS decodes on every run. RUN_ONCE decodes
 *                         on the first run; afterwards the outputs are only
 *                         resized (resizeOnly is passed as true).
 */
template <QuantDecodeRunTy QuantDecodeRun>
class QuantDecodeOp final : public Operator<CPUContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);

  template <class... Args>
  explicit QuantDecodeOp(Args&&... args)
      : Operator<CPUContext>(std::forward<Args>(args)...) {}

  ~QuantDecodeOp() {}

  // Decodes every code tensor into its matching output. Always returns true;
  // failures are reported through the enforce macros.
  bool RunOnDevice() override {
    CAFFE_ENFORCE_GT(InputSize(), 1);
    // first input is the codebook
    CAFFE_ENFORCE_EQ(InputSize(), OutputSize() + 1);
    const auto& codebook = Input(0);
    CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());

    // hasRun_ only changes after the loop, so this flag is the same for
    // every output of the current run.
    const bool skip_decoding =
        (QuantDecodeRun == QuantDecodeRunTy::RUN_ONCE) && hasRun_;

    const auto output_count = OutputSize();
    for (int out_idx = 0; out_idx < output_count; ++out_idx) {
      const auto& encoded = Input(out_idx + 1);
      auto* decoded = Output(out_idx);
      DecodeGeneral(
          codebook,
          encoded,
          /*gradient=*/nullptr,
          decoded,
          /*resizeOnly=*/skip_decoding);
    }

    hasRun_ = true;
    return true;
  }

 private:
  // Set once RunOnDevice has completed at least one full pass.
  bool hasRun_{false};
};
/**
 * CPU operator computing the gradient of QuantDecode with respect to the
 * codebook.
 *
 * Inputs (2n + 1 in total, n >= 1):
 *   Input(0)              the codebook (enforced to be float),
 *   Input(1) .. Input(n)  the n tensors of codes,
 *   Input(n + 1) .. Input(2n)  the gradients matching each tensor of codes.
 * Output:
 *   Output(0)  float tensor with the codebook's sizes. It is zero-filled and
 *              then every input gradient is accumulated into it through
 *              DecodeGeneral.
 */
class QuantDecodeGradientOp final : public Operator<CPUContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);

  template <class... Args>
  explicit QuantDecodeGradientOp(Args&&... args)
      : Operator<CPUContext>(std::forward<Args>(args)...) {}

  ~QuantDecodeGradientOp() {}

  // Always returns true; failures are reported through the enforce macros.
  bool RunOnDevice() override {
    // Inputs: 1 codebook, n tensors of codes, and n corresponding gradients.
    CAFFE_ENFORCE(InputSize() >= 3 && InputSize() % 2 == 1);
    const int num_code_tensors = (InputSize() - 1) / 2;
    CAFFE_ENFORCE_EQ(OutputSize(), 1);
    const auto& codebook = Input(0);
    CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());

    // Start from a zeroed codebook gradient; each pair below adds into it.
    auto* codebook_grad = Output(0, codebook.sizes(), at::dtype<float>());
    auto* grad_data = codebook_grad->template mutable_data<float>();
    std::fill_n(grad_data, codebook_grad->numel(), 0.0f);

    // Input slot k holds codes; slot k + num_code_tensors holds its gradient.
    for (int slot = 1; slot <= num_code_tensors; ++slot) {
      const auto& codes = Input(slot);
      const auto& decoded_grad = Input(slot + num_code_tensors);
      DecodeGeneral(
          codebook,
          codes,
          &decoded_grad,
          codebook_grad,
          /*resizeOnly=*/false);
    }
    return true;
  }
};
} // namespace caffe2
#endif // QUANT_DECODE_OP_H_