Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / qnnpack_func.h

#pragma once

#include <cstdlib>
#include <qnnpack/operator.h>

namespace qnnpack {
// Owner of a buffer of pre-packed convolution weights.
//
// The buffer pointer is stored in `packed_weights_` and released with
// std::free() when the object is destroyed, so any pointer obtained through
// getPackedWeights() must not be used after this object goes away.
// Copying is disabled so that the buffer is never freed twice.
class PrePackConvWeights final {
 public:
  // Defined out of line (not visible in this header).
  // NOTE(review): presumably packs `kernel` and `bias` for `convolution`
  // into `packed_weights_` and records the output channel count - confirm
  // against the implementation.
  PrePackConvWeights(
      const pytorch_qnnp_operator_t convolution,
      const uint8_t* kernel_zero_points,
      const uint8_t* kernel,
      const int32_t* bias);

  PrePackConvWeights() = delete;
  PrePackConvWeights(const PrePackConvWeights&) = delete;
  PrePackConvWeights& operator=(const PrePackConvWeights&) = delete;

  // Releases the packed buffer. std::free(nullptr) is a no-op, so no
  // explicit null check is required.
  ~PrePackConvWeights()
  {
    std::free(packed_weights_);
  }

  // Returns the packed buffer without transferring ownership; the pointer
  // stays valid only while this object is alive.
  void* getPackedWeights() const
  {
    return packed_weights_;
  }

  // Returns the stored output channel count (0 until the constructor
  // assigns it).
  int64_t getOutputChannels() const
  {
    return output_channels_;
  }

 private:
  // Both members carry default initializers so that the accessors never
  // read an indeterminate value, whatever path the constructor takes.
  void* packed_weights_ = nullptr;
  int64_t output_channels_ = 0;
};

// Owner of a packed weight ("B") matrix buffer together with the input and
// output channel counts it was created for.
//
// The buffer pointer is stored in `packed_weights_` and released with
// std::free() when the object is destroyed, so any pointer obtained through
// getPackedWeights() must not be used after this object goes away.
// Copying is disabled so that the buffer is never freed twice.
class PackBMatrix final {
 public:
  // Constructor taking pointers to the kernel zero points and the
  // requantization scales. Defined out of line (not visible in this header).
  // NOTE(review): presumably one zero point / scale per output channel -
  // confirm against the implementation.
  PackBMatrix(
      size_t input_channels,
      size_t output_channels,
      const uint8_t* kernel_zero_points,
      const float* requantization_scale,
      const uint8_t* kernel,
      const int32_t* bias);

  // This constructor is to be used for dynamic mode
  // quantization. In dynamic mode, we don't yet support
  // per channel quantization, and paying the cost of
  // memory allocation for per channel zero point and
  // requant scale will hurt performance.
  PackBMatrix(
      size_t input_channels,
      size_t output_channels,
      const uint8_t kernel_zero_point,
      const float requantization_scale,
      const uint8_t* kernel,
      const int32_t* bias);

  PackBMatrix() = delete;
  PackBMatrix(const PackBMatrix&) = delete;
  PackBMatrix& operator=(const PackBMatrix&) = delete;

  // Releases the packed buffer. std::free(nullptr) is a no-op, so no
  // explicit null check is required.
  ~PackBMatrix()
  {
    std::free(packed_weights_);
  }

  // Returns the packed buffer without transferring ownership; the pointer
  // stays valid only while this object is alive.
  void* getPackedWeights() const
  {
    return packed_weights_;
  }

  // Defined out of line (not visible in this header).
  // NOTE(review): presumably reconstructs the original kernel from the
  // packed buffer; ownership of the returned pointer is not visible from
  // here - confirm against the implementation before freeing or caching it.
  uint8_t* unpackWeights(
      const uint8_t* kernel_zero_points,
      int n_elements
    ) const;

  // Returns the stored input channel count (0 until a constructor assigns
  // it).
  size_t getInputChannels() const
  {
    return input_channels_;
  }

  // Returns the stored output channel count (0 until a constructor assigns
  // it).
  size_t getOutputChannels() const
  {
    return output_channels_;
  }

 private:
  // All members carry default initializers so that the accessors never
  // read an indeterminate value, whatever path a constructor takes.
  void* packed_weights_ = nullptr;
  size_t input_channels_ = 0;
  size_t output_channels_ = 0;
};

// Declaration only; the definition is not part of this header.
//
// Takes a uint8 `input` and an opaque `packed_weights` buffer and writes a
// uint8 `output`, returning a pytorch_qnnp_status code.
// NOTE(review): judging by the name and parameters this is the quantized
// fully-connected (linear) operator, with `packed_weights` coming from
// PackBMatrix::getPackedWeights() - confirm against the implementation.
// NOTE(review): `input_stride` / `output_stride` are presumably row strides
// counted in elements, and `output_min` / `output_max` a clamping range for
// the result - verify before relying on either.
enum pytorch_qnnp_status qnnpackLinear(
    const size_t batch_size,
    const size_t input_channels,
    const size_t output_channels,
    const uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const float* requantization_scales,
    const uint8_t output_zero_point,
    const uint8_t output_min,
    const uint8_t output_max,
    const uint8_t* input,
    const size_t input_stride,
    void* packed_weights,
    uint8_t* output,
    const size_t output_stride,
    pthreadpool_t threadpool);

// Declaration only; the definition is not part of this header.
//
// Takes a convolution operator handle, an opaque `packed_weights` buffer and
// a uint8 `input` described by batch size, depth, height and width, and
// writes a uint8 `output`, returning a pytorch_qnnp_status code.
// NOTE(review): judging by the name this runs the quantized convolution
// described by `convolution`, with `packed_weights` coming from
// PrePackConvWeights::getPackedWeights() - confirm against the
// implementation.
// NOTE(review): the memory layout expected for `input` / `output` and the
// meaning of `output_min` / `output_max` (presumably a clamping range) are
// not visible from this header.
enum pytorch_qnnp_status qnnpackConv(
    const pytorch_qnnp_operator_t convolution,
    void* packed_weights,
    const size_t batch_size,
    const size_t input_depth,
    const size_t input_height,
    const size_t input_width,
    const uint8_t input_zero_point,
    const uint8_t* input,
    const uint8_t* kernel_zero_points,
    const float* requantization_scales,
    const uint8_t output_zero_point,
    const uint8_t output_min,
    const uint8_t output_max,
    uint8_t* output,
    pthreadpool_t threadpool);

// Declaration only; the definition is not part of this header.
//
// Takes a deconvolution operator handle, an opaque `packed_weights` buffer
// and a uint8 `input` described by batch size, height and width (there is no
// depth parameter here, unlike qnnpackConv), and writes a uint8 `output`,
// returning a pytorch_qnnp_status code.
// NOTE(review): judging by the name this runs the quantized transposed
// convolution described by `deconvolution` - confirm against the
// implementation.
// NOTE(review): the memory layout expected for `input` / `output` and the
// meaning of `output_min` / `output_max` (presumably a clamping range) are
// not visible from this header.
enum pytorch_qnnp_status qnnpackDeConv(
    const pytorch_qnnp_operator_t deconvolution,
    void* packed_weights,
    const size_t batch_size,
    const size_t input_height,
    const size_t input_width,
    const uint8_t input_zero_point,
    const uint8_t* input,
    const uint8_t* kernel_zero_points,
    const float* requantization_scales,
    const uint8_t output_zero_point,
    const uint8_t output_min,
    const uint8_t output_max,
    uint8_t* output,
    pthreadpool_t threadpool);

// Declaration only; the definition is not part of this header.
//
// Takes a uint8 `input`, an opaque `packed_weights` buffer and a float
// `bias`, and writes a float `output`, returning a pytorch_qnnp_status code.
// Unlike qnnpackLinear, there are no output zero point / min / max
// parameters, and the scales are named `dequantization_scales`.
// NOTE(review): judging by the name this is the dynamically quantized
// linear operator, with `packed_weights` coming from
// PackBMatrix::getPackedWeights() - confirm against the implementation.
// NOTE(review): `input_stride` / `output_stride` are presumably row strides
// counted in elements - verify before relying on it.
enum pytorch_qnnp_status qnnpackLinearDynamic(
    const size_t batch_size,
    const size_t input_channels,
    const size_t output_channels,
    const uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const float* dequantization_scales,
    const uint8_t* input,
    const size_t input_stride,
    void* packed_weights,
    const float* bias,
    float* output,
    const size_t output_stride,
    pthreadpool_t threadpool);

} // namespace qnnpack