#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec256/intrinsics.h>
#include <ATen/cpu/vec256/vec256_base.h>
#include <ATen/native/quantized/affine_quantizer_base.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint8.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <cstring>
#include <limits>
// This file defines Vec256<> for the quantized types.
//
// Currently, we simply use these classes as efficient converters between
// the quantized types and Vec256<float>, usually in bandwidth-bound cases
// where doing the arithmetic in full precision is acceptable (e.g.
// elementwise operators).
//
// Conversions are as follows:
// Vec256<qint8> -> 4x Vec256<float>
// Vec256<quint8> -> 4x Vec256<float>
// Vec256<qint32> -> 1x Vec256<float>
//
// The size of the returned float vector is specified by the special
// constexpr function float_num_vecs. The type of the value returned
// from dequantize (and expected as an argument to quantize) is
// specified by float_vec_return_type.
//
// When writing kernels with these vectors, it is expected that floating-
// point operations will be carried out in a loop over Vec256<T>::float_num_vecs
// iterations.
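//
// For example, an elementwise kernel built on these classes follows this
// rough pattern (a sketch only; `op`, the pointers, and the scale /
// zero-point values are placeholders, not part of this header):
//
//   auto qx = Vec256<c10::qint32>::loadu(in_ptr);
//   auto fx = qx.dequantize(scale_v, zero_point_v, scale_zp_premul_v);
//   for (int j = 0; j < Vec256<c10::qint32>::float_num_vecs(); ++j) {
//     fx[j] = op(fx[j]);  // arbitrary Vec256<float> arithmetic
//   }
//   Vec256<c10::qint32>::quantize(fx, scale, zero_point, inverse_scale)
//       .store(out_ptr);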
namespace at {
namespace vec256 {
namespace {
#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
struct Vec256qi {
protected:
__m256i vals __attribute__((aligned(64)));
public:
Vec256qi() {}
Vec256qi(__m256i v) : vals(v) {}
operator __m256i() const {
return vals;
}
};
#if defined(CPU_CAPABILITY_AVX2)
template <typename T>
__m256i pack_saturate_and_clamp(
__m256i first,
__m256i second,
T min_val,
T max_val);
template <>
__m256i pack_saturate_and_clamp<int32_t>(
__m256i first,
__m256i second,
int32_t min_val,
int32_t max_val) {
  // This specialization exists only to satisfy linkage; it must never be called.
AT_ERROR("pack_saturate_and_clamp<int32_t> is not supported");
}
template <>
__m256i pack_saturate_and_clamp<int8_t>(
__m256i first,
__m256i second,
int8_t min_val,
int8_t max_val) {
__m256i packed_and_sat = _mm256_packs_epi16(first, second);
return _mm256_max_epi8(
_mm256_set1_epi8(min_val),
_mm256_min_epi8(packed_and_sat, _mm256_set1_epi8(max_val)));
}
template <>
__m256i pack_saturate_and_clamp<uint8_t>(
__m256i first,
__m256i second,
uint8_t min_val,
uint8_t max_val) {
__m256i packed_and_sat = _mm256_packus_epi16(first, second);
return _mm256_max_epu8(
_mm256_set1_epi8(min_val),
_mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val)));
}
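// Note on lane order: the AVX2 pack instructions used above interleave the
// two 128-bit lanes instead of concatenating their inputs. For example, for
// 16-bit inputs a = [a0..a15] and b = [b0..b15], _mm256_packs_epi16(a, b)
// produces (with signed saturation to int8)
//   [a0..a7, b0..b7 | a8..a15, b8..b15]
// rather than [a0..a15, b0..b15]. QuantizeAvx2 below compensates with a
// final _mm256_permutevar8x32_epi32.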
#endif
template <typename T>
inline void __attribute__((always_inline)) QuantizeAvx2(
const float* src,
typename T::underlying* dst,
int len,
float inverse_scale,
int64_t zero_point) {
#if defined(CPU_CAPABILITY_AVX2)
constexpr int VLEN = 8;
constexpr auto min_val = std::numeric_limits<typename T::underlying>::min();
constexpr auto max_val = std::numeric_limits<typename T::underlying>::max();
const __m256i min_v = _mm256_set1_epi32(min_val);
const __m256i max_v = _mm256_set1_epi32(max_val);
// This is the largest int32 value < int32_max exactly representable in float
constexpr int32_t int32_float_max_val =
std::numeric_limits<int32_t>::max() - 127;
int i = 0;
__m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
// clang-format off
static const __m256i shuffle_mask_v = _mm256_set_epi8(
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0x0c, 0x08, 0x04, 0x00,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0x0c, 0x08, 0x04, 0x00);
// clang-format on
__m256i permute_mask_v =
_mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
__m256i permute_mask_l8_v =
_mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
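  // How the masks above are used:
  // - permute_mask_v undoes the lane interleaving left by the two pack steps
  //   in the main loop below. After packing, the eight 4-byte groups sit in
  //   the order [x_lo y_lo z_lo w_lo | x_hi y_hi z_hi w_hi]; gathering dwords
  //   (0, 4, 1, 5, 2, 6, 3, 7) restores [x_lo x_hi y_lo y_hi z_lo z_hi w_lo w_hi],
  //   i.e. the source element order.
  // - shuffle_mask_v serves the 8-element tail loop: within each 128-bit lane
  //   it gathers byte 0 of every int32 (indices 0x00, 0x04, 0x08, 0x0c), while
  //   the 0xff entries have their high bit set, which makes
  //   _mm256_shuffle_epi8 write zeros there. permute_mask_l8_v then moves the
  //   two 4-byte groups next to each other so one 8-byte store suffices.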
int len_aligned = len / (VLEN * 4) * (VLEN * 4);
for (; i < len_aligned; i += 4 * VLEN) {
// x
    // Use unaligned loads throughout: src carries no 32-byte alignment guarantee.
    __m256 x_vals = _mm256_loadu_ps(src + i);
__m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v);
    // If the float value is greater than int32_max, _mm256_cvtps_epi32
    // converts it to a negative value. Clamp at int32_float_max_val to
    // avoid this.
x_transformed_v =
_mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val));
// y
    __m256 y_vals = _mm256_loadu_ps(src + i + VLEN);
__m256 y_transformed_v = _mm256_mul_ps(y_vals, inverse_scale_v);
y_transformed_v =
_mm256_min_ps(y_transformed_v, _mm256_set1_ps(int32_float_max_val));
// z
    __m256 z_vals = _mm256_loadu_ps(src + i + 2 * VLEN);
__m256 z_transformed_v = _mm256_mul_ps(z_vals, inverse_scale_v);
z_transformed_v =
_mm256_min_ps(z_transformed_v, _mm256_set1_ps(int32_float_max_val));
// w
    __m256 w_vals = _mm256_loadu_ps(src + i + 3 * VLEN);
__m256 w_transformed_v = _mm256_mul_ps(w_vals, inverse_scale_v);
w_transformed_v =
_mm256_min_ps(w_transformed_v, _mm256_set1_ps(int32_float_max_val));
__m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
__m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v);
__m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v);
__m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v);
// add zero point
x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point));
y_rounded_v = _mm256_add_epi32(y_rounded_v, _mm256_set1_epi32(zero_point));
z_rounded_v = _mm256_add_epi32(z_rounded_v, _mm256_set1_epi32(zero_point));
w_rounded_v = _mm256_add_epi32(w_rounded_v, _mm256_set1_epi32(zero_point));
__m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
__m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
__m256i xyzw_clamped_v = pack_saturate_and_clamp<typename T::underlying>(
xy_packed_v, zw_packed_v, min_val, max_val);
xyzw_clamped_v =
_mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), xyzw_clamped_v);
}
  // Additional 8-lane AVX2 path to handle remaining elements when len is not
  // a multiple of 32; based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM)
for (; i < len / VLEN * VLEN; i += VLEN) {
    __m256 x_vals = _mm256_loadu_ps(src + i);
__m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v);
x_transformed_v =
_mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val));
__m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point));
__m256i x_clipped_v =
_mm256_max_epi32(min_v, _mm256_min_epi32(max_v, x_rounded_v));
x_clipped_v = _mm256_shuffle_epi8(x_clipped_v, shuffle_mask_v);
x_clipped_v = _mm256_permutevar8x32_epi32(x_clipped_v, permute_mask_l8_v);
_mm_storel_epi64(
reinterpret_cast<__m128i*>(dst + i),
_mm256_castsi256_si128(x_clipped_v));
}
for (; i < len; ++i) {
float transformed = src[i] * inverse_scale;
// Not exactly the same behavior as the vectorized code.
// The vectorized code above always rounds to even in halfway cases
// (https://software.intel.com/en-us/node/523819), but std::nearbyint
// does the same only when the current rounding mode is FE_TONEAREST.
// However, in practice, this should not be a problem because most cases
// use the default rounding mode FE_TONEAREST.
// Note that we cannot implement the same behavior as the vectorized code
// using std::round because it does rounding away from zero in halfway
// cases.
    transformed = zero_point + std::nearbyint(transformed);
float clipped =
std::min(std::max(transformed, float(min_val)), float(max_val));
dst[i] = clipped;
}
#else
at::native::quantize_vec<T>(
1.0f / inverse_scale, zero_point, src, reinterpret_cast<T*>(dst), len);
#endif
}
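// Scalar reference for QuantizeAvx2 (both code paths above compute this,
// modulo the rounding-mode caveat discussed in the tail loop):
//   q = clamp(nearbyint(x * inverse_scale) + zero_point, min_val, max_val)
// e.g. for quint8 with scale = 0.1f (inverse_scale = 10.0f) and
// zero_point = 128: x = 1.0f -> nearbyint(10.0f) + 128 = 138.
//
// A minimal usage sketch (buffer names are placeholders):
//   // float src[len]; uint8_t dst[len];
//   QuantizeAvx2<c10::quint8>(
//       src, dst, len, /*inverse_scale=*/10.0f, /*zero_point=*/128);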
template<>
struct Vec256<c10::qint32> : public Vec256qi {
static constexpr int size() {
return 8;
}
static constexpr int float_num_vecs() {
return 1;
}
static constexpr int int_num_vecs() {
return 1;
}
using float_vec_return_type = std::array<Vec256<float>, 1>;
using int_vec_return_type = std::array<Vec256<c10::qint32>, 1>;
using value_type = c10::qint32::underlying;
public:
using Vec256qi::Vec256qi;
Vec256() {}
  Vec256(__m256i vals_) { vals = vals_; }
// Broadcast constructor
Vec256(const c10::qint32& val) {
value_type uw = val.val_;
vals = _mm256_set1_epi32(uw);
}
void store(void* ptr, int count = size()) const {
if (count != size()) {
memcpy(ptr, &vals, count * sizeof(value_type));
} else {
_mm256_storeu_si256((__m256i*)ptr, vals);
}
}
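  // For tails shorter than one vector, pass an explicit count and only that
  // many elements are written, e.g. (a sketch):
  //   int32_t tail[3];
  //   v.store(tail, 3);  // memcpy path: writes 3 * sizeof(value_type) bytes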
  static Vec256<c10::qint32> loadu(const void* ptr) {
    return Vec256<c10::qint32>(
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr)));
  }
float_vec_return_type dequantize(
Vec256<float> scale,
Vec256<float> zero_point,
Vec256<float> scale_zp_premul) const {
__m256 float_vals = _mm256_cvtepi32_ps(vals);
#if defined(CPU_CAPABILITY_AVX2)
return {vec256::fmadd(scale, Vec256<float>(float_vals), scale_zp_premul)};
#else
return {scale * (Vec256<float>(float_vals) - zero_point)};
#endif
}
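  // Both branches above compute the affine dequantization
  //   x_f = scale * (x_q - zero_point).
  // The AVX2 branch assumes the caller passes
  //   scale_zp_premul = -zero_point * scale,
  // so that fmadd(scale, x_q, scale_zp_premul) == scale * x_q - scale * zero_point.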
static Vec256<c10::qint32> quantize(
const float_vec_return_type& rhs,
float scale,
int32_t zero_point,
float inverse_scale) {
Vec256<c10::qint32> retval;
auto rhs_data = (__m256)rhs[0];
at::native::quantize_vec<c10::qint32, /*precision=*/32>(
scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 8);
return retval;
}
Vec256<c10::qint32> maximum(Vec256<c10::qint32> b) const {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_max_epi32(vals, b.vals);
#else
// Pray the compiler can autovectorize this
std::array<int32_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(int_vals.data()), vals);
std::array<int32_t, size()> b_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(b_vals.data()), b.vals);
std::array<int32_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::max<int32_t>(int_vals[i], b_vals[i]);
}
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
}
Vec256<c10::qint32> minimum(Vec256<c10::qint32> b) const {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_min_epi32(vals, b.vals);
#else
// Pray the compiler can autovectorize this
std::array<int32_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
std::array<int32_t, size()> b_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(&b_vals), b.vals);
std::array<int32_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::min<int32_t>(int_vals[i], b_vals[i]);
}
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
}
Vec256<c10::qint32> relu(Vec256<c10::qint32> zero_point) const {
return maximum(zero_point);
}
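  // Since x_f = scale * (x_q - zero_point), x_f <= 0.0f exactly when
  // x_q <= zero_point, so quantized ReLU reduces to maximum(zero_point).
  // relu6 below additionally clamps from above with q_six, i.e. 6.0f
  // quantized under the same scale and zero point.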
Vec256<c10::qint32> relu6(
Vec256<c10::qint32> zero_point,
Vec256<c10::qint32> q_six) {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_min_epi32(
_mm256_max_epi32(vals, zero_point.vals), q_six.vals);
#else
// Pray the compiler can autovectorize this
std::array<int32_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
std::array<int32_t, size()> zero_point_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
    std::array<int32_t, size()> q_six_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
std::array<int32_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::min<int32_t>(
std::max<int32_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
    }
    return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
  }