Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ include / ATen / cpu / vec256 / vec256_qint.h

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec256/intrinsics.h>
#include <ATen/cpu/vec256/vec256_base.h>
#include <ATen/native/quantized/affine_quantizer_base.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint8.h>

#include <array>

// This file defines Vec256<> for the quantized types.
//
//
// Currently, we simply use these classes as efficient converters between
// the quantized types and Vec256<float>, usually in bandwidth-bound cases
// where doing the arithmetic in full-precision is acceptable (e.g.
// elementwise operators).
//
//
// Conversions are as follows:
//  Vec256<qint8> -> 4x Vec256<float>
//  Vec256<quint8> -> 4x Vec256<float>
//  Vec256<qint32> -> 1x Vec256<float>
//
// The number of Vec256<float> vectors produced is given by the static
// constexpr function float_num_vecs. The type of the value returned
// from dequantize (and expected as an argument to quantize) is
// specified by float_vec_return_type.
//
// When writing kernels with these vectors, it is expected that floating-
// point operations will be carried out in a loop over Vec256<T>::float_num_vecs
// iterations.

namespace at {
namespace vec256 {
namespace {

#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)

// Thin wrapper around one 256-bit integer register. It converts implicitly
// to and from __m256i and serves as the base class of Vec256<c10::qint32>.
struct Vec256qi {
 public:
  // Leaves `vals` uninitialized.
  Vec256qi() {}

  // Wraps an existing register value; intentionally implicit so intrinsic
  // results can be returned where a vector object is expected.
  Vec256qi(__m256i v) : vals(v) {}

  // Exposes the wrapped register so the object can be handed to intrinsics.
  operator __m256i() const { return vals; }

 protected:
  __m256i vals __attribute__((aligned(64)));
};

#if defined(CPU_CAPABILITY_AVX2)
// Primary template, declared only. Explicit specializations for int32_t,
// int8_t and uint8_t are provided right after this declaration.
//
// The 8-bit specializations pack the 16-bit lanes of `first` and `second`
// into 8-bit lanes with saturation and then clamp every lane to
// [min_val, max_val]; the int32_t specialization only raises an error.
template <typename T>
__m256i pack_saturate_and_clamp(
    __m256i first,
    __m256i second,
    T min_val,
    T max_val);

// int32_t specialization. It exists only so that code instantiated with a
// 32-bit underlying type links; calling it always raises an error.
template <>
__m256i pack_saturate_and_clamp<int32_t>(
    __m256i /* first */,
    __m256i /* second */,
    int32_t /* min_val */,
    int32_t /* max_val */) {
  // Never expected to run: AT_ERROR reports the unsupported instantiation.
  AT_ERROR("pack_saturate_and_clamp<int32_t> is not supported");
}

// int8_t specialization: packs the signed 16-bit lanes of `first` and
// `second` into signed 8-bit lanes with saturation, then clamps every lane
// to [min_val, max_val].
template <>
__m256i pack_saturate_and_clamp<int8_t>(
    __m256i first,
    __m256i second,
    int8_t min_val,
    int8_t max_val) {
  // Narrow 16-bit -> signed 8-bit with saturation.
  const __m256i narrowed = _mm256_packs_epi16(first, second);
  // Broadcast both bounds to every byte lane.
  const __m256i lower_bound = _mm256_set1_epi8(min_val);
  const __m256i upper_bound = _mm256_set1_epi8(max_val);
  // Apply the upper bound first, then the lower bound.
  const __m256i capped = _mm256_min_epi8(narrowed, upper_bound);
  return _mm256_max_epi8(lower_bound, capped);
}

// uint8_t specialization: packs the signed 16-bit lanes of `first` and
// `second` into unsigned 8-bit lanes with saturation, then clamps every lane
// to [min_val, max_val] using unsigned comparisons.
template <>
__m256i pack_saturate_and_clamp<uint8_t>(
    __m256i first,
    __m256i second,
    uint8_t min_val,
    uint8_t max_val) {
  // Narrow 16-bit -> unsigned 8-bit with saturation.
  const __m256i narrowed = _mm256_packus_epi16(first, second);
  // Broadcast both bounds to every byte lane.
  const __m256i lower_bound = _mm256_set1_epi8(min_val);
  const __m256i upper_bound = _mm256_set1_epi8(max_val);
  // Apply the upper bound first, then the lower bound.
  const __m256i capped = _mm256_min_epu8(narrowed, upper_bound);
  return _mm256_max_epu8(lower_bound, capped);
}
#endif

// Quantizes `len` floats from `src` into `dst`. Each output element is
//   clamp(round(src[i] * inverse_scale) + zero_point, min, max)
// where min/max are the numeric limits of T::underlying.
//
// With CPU_CAPABILITY_AVX2 the work is split into three stages: 32 elements
// per iteration, then 8 elements per iteration, then a scalar tail. Both
// vector stages write one byte per element, so this path assumes
// T::underlying is an 8-bit type. Without AVX2 the call is forwarded to
// at::native::quantize_vec<T> with scale = 1 / inverse_scale.
//
// Neither `src` nor `dst` has to be aligned: every vector load and store
// uses the unaligned form.
template <typename T>
inline void __attribute__((always_inline)) QuantizeAvx2(
    const float* src,
    typename T::underlying* dst,
    int len,
    float inverse_scale,
    int64_t zero_point) {
#if defined(CPU_CAPABILITY_AVX2)
  constexpr int VLEN = 8;
  constexpr auto min_val = std::numeric_limits<typename T::underlying>::min();
  constexpr auto max_val = std::numeric_limits<typename T::underlying>::max();
  const __m256i min_v = _mm256_set1_epi32(min_val);
  const __m256i max_v = _mm256_set1_epi32(max_val);
  // This is the largest int32 value < int32_max exactly representable in float
  constexpr int32_t int32_float_max_val =
      std::numeric_limits<int32_t>::max() - 127;
  int i = 0;

  // Loop-invariant broadcasts, built once up front.
  const __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
  const __m256 float_max_v = _mm256_set1_ps(int32_float_max_val);
  const __m256i zero_point_v =
      _mm256_set1_epi32(static_cast<int>(zero_point));

  // Byte shuffle that gathers the low byte of each 32-bit lane into the low
  // four bytes of its 128-bit half. Deliberately a plain local rather than a
  // function-local static: this header must not define static data.
  // clang-format off
  const __m256i shuffle_mask_v = _mm256_set_epi8(
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0x0c, 0x08, 0x04, 0x00,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0x0c, 0x08, 0x04, 0x00);
  // clang-format on
  const __m256i permute_mask_v =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
  const __m256i permute_mask_l8_v =
      _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);

  // Stage 1: 4 x 8 floats per iteration.
  const int len_aligned = len / (VLEN * 4) * (VLEN * 4);
  for (; i < len_aligned; i += 4 * VLEN) {
    // x
    __m256 x_vals = _mm256_loadu_ps(src + i);
    __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v);
    // If the floating point value is greater than int32_max,
    // _mm256_cvtps_epi32 converts it to a negative value. Clip at
    // int32_float_max_val to avoid this.
    x_transformed_v = _mm256_min_ps(x_transformed_v, float_max_v);
    // y
    __m256 y_vals = _mm256_loadu_ps(src + i + VLEN);
    __m256 y_transformed_v = _mm256_mul_ps(y_vals, inverse_scale_v);
    y_transformed_v = _mm256_min_ps(y_transformed_v, float_max_v);
    // z
    __m256 z_vals = _mm256_loadu_ps(src + i + 2 * VLEN);
    __m256 z_transformed_v = _mm256_mul_ps(z_vals, inverse_scale_v);
    z_transformed_v = _mm256_min_ps(z_transformed_v, float_max_v);
    // w
    __m256 w_vals = _mm256_loadu_ps(src + i + 3 * VLEN);
    __m256 w_transformed_v = _mm256_mul_ps(w_vals, inverse_scale_v);
    w_transformed_v = _mm256_min_ps(w_transformed_v, float_max_v);

    __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
    __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v);
    __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v);
    __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v);

    // add zero point
    x_rounded_v = _mm256_add_epi32(x_rounded_v, zero_point_v);
    y_rounded_v = _mm256_add_epi32(y_rounded_v, zero_point_v);
    z_rounded_v = _mm256_add_epi32(z_rounded_v, zero_point_v);
    w_rounded_v = _mm256_add_epi32(w_rounded_v, zero_point_v);

    // 32-bit -> 16-bit -> 8-bit with saturation, then clamp to the limits of
    // the underlying type.
    __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
    __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
    __m256i xyzw_clamped_v = pack_saturate_and_clamp<typename T::underlying>(
        xy_packed_v, zw_packed_v, min_val, max_val);

    // The pack instructions work per 128-bit half; reorder the 32-bit groups
    // so the bytes end up in source order.
    xyzw_clamped_v =
        _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), xyzw_clamped_v);
  }

  // Stage 2: additional 8-lane AVX2 version to take advantage when len is
  // smaller, based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM)
  const int len_vlen = len / VLEN * VLEN;
  for (; i < len_vlen; i += VLEN) {
    __m256 x_vals = _mm256_loadu_ps(src + i);
    __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v);
    x_transformed_v = _mm256_min_ps(x_transformed_v, float_max_v);
    __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
    x_rounded_v = _mm256_add_epi32(x_rounded_v, zero_point_v);
    __m256i x_clipped_v =
        _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, x_rounded_v));

    // Collect the low byte of every 32-bit lane into the low 8 bytes and
    // store just those 8 bytes.
    x_clipped_v = _mm256_shuffle_epi8(x_clipped_v, shuffle_mask_v);
    x_clipped_v = _mm256_permutevar8x32_epi32(x_clipped_v, permute_mask_l8_v);
    _mm_storel_epi64(
        reinterpret_cast<__m128i*>(dst + i),
        _mm256_castsi256_si128(x_clipped_v));
  }

  // Stage 3: scalar tail for the remaining (< 8) elements.
  for (; i < len; ++i) {
    float transformed = src[i] * inverse_scale;

    // Not exactly the same behavior as the vectorized code.
    // The vectorized code above always rounds to even in halfway cases
    // (https://software.intel.com/en-us/node/523819), but std::nearbyint
    // does the same only when the current rounding mode is FE_TONEAREST.
    // However, in practice, this should not be a problem because most cases
    // use the default rounding mode FE_TONEAREST.
    // Note that we cannot implement the same behavior as the vectorized code
    // using std::round because it does rounding away from zero in halfway
    // cases.
    transformed = zero_point + nearbyint(transformed);
    float clipped =
        std::min(std::max(transformed, float(min_val)), float(max_val));
    dst[i] = clipped;
  }
#else
  at::native::quantize_vec<T>(
      1.0f / inverse_scale, zero_point, src, reinterpret_cast<T*>(dst), len);
#endif
}

template<>
struct Vec256<c10::qint32> : public Vec256qi {
    // Number of elements held by one vector.
    static constexpr int size() { return 8; }

    // Number of Vec256<float> values in float_vec_return_type.
    static constexpr int float_num_vecs() { return 1; }

    // Number of Vec256<c10::qint32> values in int_vec_return_type.
    static constexpr int int_num_vecs() { return 1; }

    // Element type and the array types used by dequantize/quantize.
    using value_type = c10::qint32::underlying;
    using float_vec_return_type = std::array<Vec256<float>, 1>;
    using int_vec_return_type = std::array<Vec256<c10::qint32>, 1>;

 public:
    using Vec256qi::Vec256qi;
    // Default construction does not initialize the register.
    Vec256() {}

    // Adopts a raw 256-bit integer register.
    Vec256(__m256i vals_) {
      vals = vals_;
    }

    // Broadcast constructor: copies val into every 32-bit lane.
    Vec256(const c10::qint32& val) {
      const value_type lane = val.val_;
      vals = _mm256_set1_epi32(lane);
    }

    // Writes the vector to `ptr`. When `count` equals size() the whole
    // register is written with one unaligned store; otherwise only the first
    // `count` elements are copied.
    void store(void* ptr, int count = size()) const {
      if (count == size()) {
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), vals);
        return;
      }
      memcpy(ptr, &vals, count * sizeof(value_type));
    }

    // Builds a vector from the memory at `ptr` by forwarding the pointer to a
    // Vec256 constructor. None of the constructors declared above this method
    // takes a pointer, so the one used here is declared elsewhere in the
    // class. NOTE(review): presumably an unaligned 256-bit load — confirm at
    // that constructor's definition.
    static Vec256<c10::qint32> loadu(const void* ptr) {
        return Vec256<c10::qint32>(ptr);
    }

    // Converts the stored 32-bit integers to float and applies the affine
    // dequantization, returning a one-element array.
    //   AVX2 build:  fmadd(scale, values, scale_zp_premul); zero_point unused.
    //   otherwise:   scale * (values - zero_point); scale_zp_premul unused.
    // NOTE(review): the two paths agree only if scale_zp_premul equals
    // -scale * zero_point — confirm against callers.
    float_vec_return_type dequantize(
        Vec256<float> scale,
        Vec256<float> zero_point,
        Vec256<float> scale_zp_premul) const {
      const __m256 as_float = _mm256_cvtepi32_ps(vals);
      const Vec256<float> widened(as_float);
#if defined(CPU_CAPABILITY_AVX2)
      return {vec256::fmadd(scale, widened, scale_zp_premul)};
#else
      return {scale * (widened - zero_point)};
#endif
    }

    // Quantizes the eight floats in rhs[0] to qint32 by calling
    // at::native::quantize_vec with 32-bit precision. `inverse_scale` is
    // accepted but not used by this specialization.
    static Vec256<c10::qint32> quantize(
        const float_vec_return_type& rhs,
        float scale,
        int32_t zero_point,
        float inverse_scale) {
      // Copy the input register to a local so it can be read as a float array.
      __m256 input = static_cast<__m256>(rhs[0]);
      Vec256<c10::qint32> result;
      at::native::quantize_vec<c10::qint32, /*precision=*/32>(
          scale,
          zero_point,
          reinterpret_cast<float*>(&input),
          reinterpret_cast<c10::qint32*>(&result.vals),
          8);
      return result;
    }

    // Lane-wise signed 32-bit maximum of this vector and b.
    Vec256<c10::qint32> maximum(Vec256<c10::qint32> b) const {
#ifdef CPU_CAPABILITY_AVX2
      return _mm256_max_epi32(vals, b.vals);
#else
      // Non-AVX2 build: spill both operands to arrays, combine them lane by
      // lane, and reload the result. The loop is left for the compiler to
      // vectorize if it can.
      std::array<int32_t, size()> lhs;
      std::array<int32_t, size()> rhs;
      std::array<int32_t, size()> out;
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(lhs.data()), vals);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs.data()), b.vals);
      for (size_t lane = 0; lane < size(); ++lane) {
        out[lane] = std::max<int32_t>(lhs[lane], rhs[lane]);
      }
      return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&out));
#endif
    }

    // Lane-wise signed 32-bit minimum of this vector and b.
    Vec256<c10::qint32> minimum(Vec256<c10::qint32> b) const {
#ifdef CPU_CAPABILITY_AVX2
      return _mm256_min_epi32(vals, b.vals);
#else
      // Non-AVX2 build: spill both operands to arrays, combine them lane by
      // lane, and reload the result. The loop is left for the compiler to
      // vectorize if it can.
      std::array<int32_t, size()> lhs;
      std::array<int32_t, size()> rhs;
      std::array<int32_t, size()> out;
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(lhs.data()), vals);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs.data()), b.vals);
      for (size_t lane = 0; lane < size(); ++lane) {
        out[lane] = std::min<int32_t>(lhs[lane], rhs[lane]);
      }
      return _mm256_loadu_si256(reinterpret_cast<__m256i*>(out.data()));
#endif
    }

    // ReLU in the quantized domain: lane-wise maximum with zero_point.
    Vec256<c10::qint32> relu(Vec256<c10::qint32> zero_point) const {
      return this->maximum(zero_point);
    }

    Vec256<c10::qint32> relu6(
        Vec256<c10::qint32> zero_point,
        Vec256<c10::qint32> q_six) {
#ifdef CPU_CAPABILITY_AVX2
      return _mm256_min_epi32(
          _mm256_max_epi32(vals, zero_point.vals), q_six.vals);
#else
      // Pray the compiler can autovectorize this
      std::array<int32_t, size()> int_vals;
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
      std::array<int32_t, size()> zero_point_vals;
      _mm256_storeu_si256(
          reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
      std::array<int32_t,size()> q_six_vals;
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
      std::array<int32_t, size()> result_vals;
      for (size_t i = 0; i < size(); ++i) {
        result_vals[i] = std::min<int32_t>(
            std::max<int32_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
Loading ...