// Package: neilisaac / torch (python), version 1.8.0
// File: include/caffe2/operators/fused_rowwise_random_quantization_ops.h

#ifndef CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_
#define CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_

#include <chrono>

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/reducer_functors.h"
#include "caffe2/perfkernels/math.h"
#include "caffe2/utils/math.h"

#ifdef CAFFE2_USE_MKL
#include <mkl.h>
#define FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
#endif

namespace caffe2 {

/**
 * Operator declaration with one input tag (DATA_FLOAT) and one output tag
 * (DATA_FUSED_QUANTIZED). RunOnDevice() is only declared here; its definition
 * lives outside this header.
 *
 * Operator arguments read at construction:
 *   - "bitwidth" (int32, default 8): must be 1, 2, 4 or 8, otherwise
 *     CAFFE_ENFORCE fails with "Unsupported bitwidth".
 *   - "random" (bool, default true): when true, a random number source is
 *     set up and seeded from the current system-clock tick count.
 *
 * Random number source:
 *   - With FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL defined, an MKL VSL
 *     stream (VSL_BRNG_MT19937) is created in the constructor and released
 *     in the destructor.
 *   - Otherwise a std::minstd_rand engine plus a
 *     std::uniform_real_distribution<float> over [0, 1) is used.
 */
template <class Context>
class FloatToFusedRandRowwiseQuantizedOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  /**
   * Forwards all arguments to Operator<Context>, reads the "bitwidth" and
   * "random" arguments, validates the bit width, and (when random_ is true)
   * initializes the random number source.
   */
  template <class... Args>
  explicit FloatToFusedRandRowwiseQuantizedOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        bitwidth_(OperatorBase::GetSingleArgument<int32_t>("bitwidth", 8)),
        random_(OperatorBase::GetSingleArgument<bool>("random", true)) {
    CAFFE_ENFORCE(
        bitwidth_ == 1 || bitwidth_ == 2 || bitwidth_ == 4 || bitwidth_ == 8,
        "Unsupported bitwidth");
    if (random_) {
      // The tick count is a wide integer; both seeding APIs take a narrower
      // unsigned type, so the narrowing is spelled out explicitly below.
      const auto clock_ticks =
          std::chrono::system_clock::now().time_since_epoch().count();
#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
      int status = vslNewStream(
          &vslStream_, VSL_BRNG_MT19937, static_cast<MKL_UINT>(clock_ticks));
      if (status != VSL_STATUS_OK) {
        LOG(WARNING) << "vslNewStream returns " << status;
        // Do not keep whatever the failed call may have left behind; a null
        // handle lets the destructor skip the release safely.
        vslStream_ = nullptr;
      }
#else
      gen_.seed(static_cast<std::minstd_rand::result_type>(clock_ticks));
      dis_.reset(new std::uniform_real_distribution<float>(0.0f, 1.0f));
#endif
    }
  }

  /**
   * Releases the MKL VSL stream if one was successfully created. Nothing to
   * release in the non-MKL build.
   */
  ~FloatToFusedRandRowwiseQuantizedOp() {
#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
    // Only delete a stream that actually exists: creation is skipped when
    // random_ is false and may have failed when it is true.
    if (random_ && vslStream_ != nullptr) {
      int status = vslDeleteStream(&vslStream_);
      if (status != VSL_STATUS_OK) {
        LOG(WARNING) << "vslDeleteStream returns " << status;
      }
    }
#endif
  }

  bool RunOnDevice() override;

 private:
  INPUT_TAGS(DATA_FLOAT);
  OUTPUT_TAGS(DATA_FUSED_QUANTIZED);

 protected:
  // Value of the "bitwidth" argument; one of 1, 2, 4, 8 after construction.
  size_t bitwidth_{8};
  // Value of the "random" argument; gates creation of the random source.
  bool random_{true};
  // Float buffer; not touched in this header (presumably scratch space for
  // RunOnDevice — confirm against the implementation file).
  std::vector<float> random_buffer_;

#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
  // MKL VSL stream handle; null until successfully created in the constructor.
  VSLStreamStatePtr vslStream_{nullptr};
#else
  // Uniform [0, 1) distribution; only allocated when random_ is true.
  std::unique_ptr<std::uniform_real_distribution<float>> dis_;
  // Engine seeded from the system clock when random_ is true.
  std::minstd_rand gen_;
#endif
};

/**
 * Operator declaration with one input tag (DATA_FUSED_QUANTIZED) and one
 * output tag (DATA_FLOAT).
 *
 * Construction and destruction come from USE_SIMPLE_CTOR_DTOR; the class
 * holds no state of its own. RunOnDevice() is only declared here and is
 * defined outside this header.
 *
 * NOTE(review): judging by the name and tags, this is the inverse of
 * FloatToFusedRandRowwiseQuantizedOp — confirm against the implementation.
 */
template <typename Context>
class FusedRandRowwiseQuantizedToFloatOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(FusedRandRowwiseQuantizedToFloatOp)

  // Operator entry point; body is provided elsewhere.
  bool RunOnDevice() override;

 private:
  // Index names for the single input and the single output.
  INPUT_TAGS(DATA_FUSED_QUANTIZED);
  OUTPUT_TAGS(DATA_FLOAT);
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_