#ifndef CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_
#define CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_
#include <chrono>
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/reducer_functors.h"
#include "caffe2/perfkernels/math.h"
#include "caffe2/utils/math.h"
#ifdef CAFFE2_USE_MKL
#include <mkl.h>
#define FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
#endif
namespace caffe2 {
template <class Context>
class FloatToFusedRandRowwiseQuantizedOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit FloatToFusedRandRowwiseQuantizedOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
bitwidth_(OperatorBase::GetSingleArgument<int32_t>("bitwidth", 8)),
random_(OperatorBase::GetSingleArgument<bool>("random", true)) {
CAFFE_ENFORCE(
bitwidth_ == 1 || bitwidth_ == 2 || bitwidth_ == 4 || bitwidth_ == 8,
"Unsupported bitwidth");
if (random_) {
#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
int status = vslNewStream(
&vslStream_,
VSL_BRNG_MT19937,
std::chrono::system_clock::now().time_since_epoch().count());
if (status != VSL_STATUS_OK) {
LOG(WARNING) << "vslNewStream returns " << status;
}
#else
gen_.seed(std::chrono::system_clock::now().time_since_epoch().count());
dis_.reset(new std::uniform_real_distribution<float>(0.0f, 1.0f));
#endif
}
}
~FloatToFusedRandRowwiseQuantizedOp() {
if (random_) {
#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
int status = vslDeleteStream(&vslStream_);
if (status != VSL_STATUS_OK) {
LOG(WARNING) << "vslDeleteStream returns " << status;
}
#endif
}
}
bool RunOnDevice() override;
private:
INPUT_TAGS(DATA_FLOAT);
OUTPUT_TAGS(DATA_FUSED_QUANTIZED);
protected:
size_t bitwidth_{8};
bool random_{true};
std::vector<float> random_buffer_;
#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
VSLStreamStatePtr vslStream_;
#else
std::unique_ptr<std::uniform_real_distribution<float>> dis_;
std::minstd_rand gen_;
#endif
};
template <class Context>
class FusedRandRowwiseQuantizedToFloatOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_SIMPLE_CTOR_DTOR(FusedRandRowwiseQuantizedToFloatOp)
bool RunOnDevice() override;
private:
INPUT_TAGS(DATA_FUSED_QUANTIZED);
OUTPUT_TAGS(DATA_FLOAT);
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_