#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
//
// Note [Do not compile initializers with AVX]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// If you define a static initializer in this file, the initialization will use
// AVX instructions because these object files are compiled with AVX enabled.
// We need to avoid non-trivial global data in these architecture-specific files
// because there's no way to guard the global initializers with CPU capability
// detection.
//
// See https://github.com/pytorch/pytorch/issues/37577 for an instance
// of this bug in the past.
#include <cstring>
#include <functional>
#include <cmath>
#include <type_traits>
#include <bitset>
#include <ATen/cpu/vec256/intrinsics.h>
#include <ATen/native/Math.h>
#include <ATen/NumericUtils.h>
#include <c10/util/C++17.h>
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/math_compat.h>
#include <ATen/native/cpu/zmath.h>
#include <c10/util/TypeCast.h>
#include <c10/macros/Macros.h>
#if defined(__GNUC__)
#define __at_align32__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align32__ __declspec(align(32))
#else
#define __at_align32__
#endif
namespace at {
namespace vec256 {
// See Note [Acceptable use of anonymous namespace in header]
namespace {
// Project-local variant of std::is_floating_point that additionally treats
// at::Half as a floating-point type. The SFINAE-selected members of Vec256
// (e.g. the floating-point abs() overload) therefore handle Half like
// float/double.
// NOTE(review): at::BFloat16 is not listed here even though this header
// includes c10/util/BFloat16.h — confirm that this is intentional.
template <typename T>
struct is_floating_point
    : std::conditional_t<std::is_floating_point<T>::value ||
                             std::is_same<T, at::Half>::value,
                         std::true_type,
                         std::false_type> {};
// Maps a byte width n to the signed fixed-width integer type of exactly that
// width. Only widths 1, 2, 4 and 8 are provided. The primary template is
// declared but never defined, so any other width is a compile-time error.
template<size_t n> struct int_of_size;
template<> struct int_of_size<sizeof(int8_t)>  { using type = int8_t;  };
template<> struct int_of_size<sizeof(int16_t)> { using type = int16_t; };
template<> struct int_of_size<sizeof(int32_t)> { using type = int32_t; };
template<> struct int_of_size<sizeof(int64_t)> { using type = int64_t; };

// Signed integer type with the same size as T (e.g. int32_t for float).
// blendv() uses it to reinterpret mask lanes as raw bits.
template <typename T>
using int_same_size_t = typename int_of_size<sizeof(T)>::type;
// NOTE: If you specialize on a type, you must define all operations!
// Generic fallback: emulates a vectorized type with a 32-byte-aligned array of T.
template <class T>
struct Vec256 {
private:
__at_align32__ T values[32 / sizeof(T)];
public:
using value_type = T;
// Note [constexpr static function to avoid odr-usage compiler bug]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Why, you might ask, is size defined to be a static constexpr function,
// rather than a more ordinary 'static constexpr int size;' variable?
// The problem lies within ODR rules for static constexpr members versus
// static constexpr functions. First, recall that this class and all of its
// derivations live in an anonymous namespace: they are intended to be
// *completely* inlined at their use-sites, because we need to compile it
// multiple times for different instruction sets.
//
// Because of this constraint, we CANNOT provide a single definition for
// any static members in this class; since we want to compile the class
// multiple times, there wouldn't actually be any good place to put the
// definition. Now here is the problem: if we ODR-use a static constexpr
// member, we are *obligated* to provide a definition. Without the
// definition, you get a compile error like:
//
// relocation R_X86_64_PC32 against undefined symbol
// `_ZN2at6vec25612_GLOBAL__N_16Vec256IdE4sizeE' can not be used when making
// a shared object; recompile with -fPIC
//
// If this were C++17, we could replace the static constexpr variable with
// an inline variable, which doesn't require a separate out-of-line
// definition. But we are not C++17, so the next best thing is to replace the
// member with a static constexpr (and therefore inline) function, which
// doesn't require an out-of-line definition either.
//
// Also, technically, according to the C++ standard we don't have to define
// a constexpr variable if we never odr-use it. But some versions of
// GCC/Clang appear to be buggy in determining whether an identifier is
// odr-used, and in any case it's hard to tell by inspection whether a
// variable is odr-used. So it's best to cut the problem off at the root.
// Number of T lanes held in one 32-byte vector.
// This is a static constexpr function rather than a static constexpr data
// member, so that using it never requires an out-of-line definition.
static constexpr int size() {
  return static_cast<int>(32 / sizeof(T));
}
// Default: every lane is zero-initialized.
Vec256() : values{0} {}
// Broadcast: copies `val` into every lane.
Vec256(T val) {
  for (auto& lane : values) {
    lane = val;
  }
}
// Element-wise: takes exactly size() arguments, which fill lanes 0..size()-1 in order.
template<typename... Args,
         typename = std::enable_if_t<(sizeof...(Args) == size())>>
Vec256(Args... vals) : values{vals...} {}
// Read-only decay to a pointer to lane 0; this is what makes `vec[i]` work
// on a const Vec256 (it stands in for `const T& operator[](int idx) const`).
inline operator const T*() const {
  return &values[0];
}
// Mutable decay to a pointer to lane 0; this is what makes `vec[i] = x`
// work (it stands in for `T& operator[](int idx)`).
inline operator T*() {
  return &values[0];
}
// Compile-time blend: lane i is taken from `b` when bit i of `mask_` is set,
// and from `a` otherwise.
template <int64_t mask_>
static Vec256<T> blend(const Vec256<T>& a, const Vec256<T>& b) {
  Vec256 result;
  for (int64_t lane = 0; lane < size(); ++lane) {
    const bool take_b = ((mask_ >> lane) & 0x01) != 0;
    result[lane] = take_b ? b[lane] : a[lane];
  }
  return result;
}
// Runtime blend. Each lane of `mask` is reinterpreted as a same-width signed
// integer. Lane i comes from `b` when the lowest bit of that integer is set,
// and from `a` otherwise.
// NOTE(review): only bit 0 is inspected; masks are presumably all-ones /
// all-zeros per lane.
static Vec256<T> blendv(const Vec256<T>& a, const Vec256<T>& b,
                        const Vec256<T>& mask) {
  int_same_size_t<T> mask_bits[size()];
  mask.store(mask_bits);
  Vec256 result;
  for (int64_t lane = 0; lane < size(); ++lane) {
    result[lane] = (mask_bits[lane] & 0x01) ? b[lane] : a[lane];
  }
  return result;
}
// Lane i holds `base + i * step`.
// step_t may be wider than T (e.g. T=int, step_t=double), so the expression is
// evaluated in the wider type and only then assigned into the T lane.
template<typename step_t>
static Vec256<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
  Vec256 result;
  for (int64_t lane = 0; lane < size(); ++lane) {
    result.values[lane] = base + lane * step;
  }
  return result;
}
// Lanes [0, count) are taken from `b`; the remaining lanes are taken from `a`.
// With the default count, the result is a copy of `b`.
static Vec256<T> set(const Vec256<T>& a, const Vec256<T>& b, int64_t count = size()) {
  Vec256 result;
  for (int64_t lane = 0; lane < size(); ++lane) {
    result[lane] = lane < count ? b[lane] : a[lane];
  }
  return result;
}
// Unaligned full load: copies exactly 32 bytes from `ptr` into the lanes.
static Vec256<T> loadu(const void* ptr) {
  Vec256 result;
  std::memcpy(result.values, ptr, 32);
  return result;
}
// Unaligned partial load: copies the first `count` elements from `ptr`.
// Lanes beyond `count` keep the zero value set by the default constructor.
// `count` is not bounds-checked; a value above size() would overrun the lanes.
static Vec256<T> loadu(const void* ptr, int64_t count) {
  Vec256 result;
  const auto num_bytes = count * sizeof(T);
  std::memcpy(result.values, ptr, num_bytes);
  return result;
}
// Writes the first `count` lanes (all lanes by default) to `ptr`, which need not be aligned.
void store(void* ptr, int count = size()) const {
  const auto num_bytes = count * sizeof(T);
  std::memcpy(ptr, values, num_bytes);
}
// Returns an integer mask in which bit i is 1 iff lane i compares equal to
// T(0); all other bits are 0.
int zero_mask() const {
  int mask = 0;
  for (int lane = 0; lane < size(); ++lane) {
    const bool is_zero = values[lane] == static_cast<T>(0);
    if (is_zero) {
      mask |= (1 << lane);
    }
  }
  return mask;
}
// Applies a scalar function to every lane and returns the results.
// There are two overloads so that both `T f(T)` and `T f(const T&)` style
// functions can be passed directly.
Vec256<T> map(T (*f)(T)) const {
  Vec256<T> out;
  for (int64_t lane = 0; lane < size(); ++lane) {
    out[lane] = f(values[lane]);
  }
  return out;
}
Vec256<T> map(T (*f)(const T &)) const {
  Vec256<T> out;
  for (int64_t lane = 0; lane < size(); ++lane) {
    out[lane] = f(values[lane]);
  }
  return out;
}
// abs() for T that is neither floating point nor complex: lane-wise compare-and-negate.
template <typename other_t_abs = T,
          typename std::enable_if<!is_floating_point<other_t_abs>::value && !c10::is_complex<other_t_abs>::value, int>::type = 0>
Vec256<T> abs() const {
  // other_t_abs is for SFINAE and clarity. Make sure it is not changed.
  static_assert(std::is_same<other_t_abs, T>::value, "other_t_abs must be T");
  return map([](T v) -> T {
    if (v < static_cast<T>(0)) {
      return -v;
    }
    return v;
  });
}
// abs() for floating-point T (including at::Half): lane-wise std::abs.
template <typename float_t_abs = T,
          typename std::enable_if<is_floating_point<float_t_abs>::value, int>::type = 0>
Vec256<T> abs() const {
  // float_t_abs is for SFINAE and clarity. Make sure it is not changed.
  static_assert(std::is_same<float_t_abs, T>::value, "float_t_abs must be T");
  // Compare-and-negate would leave -0.0 as -0.0 (since -0.0 < 0 is false);
  // std::abs produces 0.0 as required.
  return map([](T v) -> T { return std::abs(v); });
}
// abs() for complex T: lane-wise magnitude, converted back into T.
template <typename complex_t_abs = T,
          typename std::enable_if<c10::is_complex<complex_t_abs>::value, int>::type = 0>
Vec256<T> abs() const {
  // complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
  static_assert(std::is_same<complex_t_abs, T>::value, "complex_t_abs must be T");
  // std::abs of a complex value is not itself a T, and map() will not convert
  // it, so the cast back to T is explicit.
  return map([](T z) { return static_cast<T>(std::abs(z)); });
}
// sgn() for complex T only; it is removed by SFINAE when T is not complex.
// Applies at::native::sgn_impl to every lane (presumably z / |z|, with 0 for
// z == 0 — confirm against zmath.h).
template <typename complex_t_sgn = T,
          typename std::enable_if<c10::is_complex<complex_t_sgn>::value, int>::type = 0>
Vec256<T> sgn() const {
  // complex_t_sgn is for SFINAE and clarity. Make sure it is not changed.
  // Matches the guard used by every other SFINAE-dispatched member: it rejects
  // an explicit sgn<U>() call with U != T.
  static_assert(std::is_same<complex_t_sgn, T>::value, "complex_t_sgn must be T");
  return map(at::native::sgn_impl);
}
// angle() for non-complex T: lane-wise at::native::angle_impl<T>.
template <typename other_t_angle = T,
          typename std::enable_if<!c10::is_complex<other_t_angle>::value, int>::type = 0>
Vec256<T> angle() const {
  // other_t_angle is for SFINAE and clarity. Make sure it is not changed.
  static_assert(std::is_same<other_t_angle, T>::value, "other_t_angle must be T");
  // The explicit <T> is required: without it the compiler cannot resolve the overload.
  return this->map(at::native::angle_impl<T>);
}
// angle() for complex T: lane-wise phase via std::arg, converted back into T.
template <typename complex_t_angle = T,
          typename std::enable_if<c10::is_complex<complex_t_angle>::value, int>::type = 0>
Vec256<T> angle() const {
  // complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
  static_assert(std::is_same<complex_t_angle, T>::value, "complex_t_angle must be T");
  return map([](T z) { return static_cast<T>(std::arg(z)); });
}
// real() for non-complex T: the values are already real, so the vector is returned unchanged.
template <typename other_t_real = T,
          typename std::enable_if<!c10::is_complex<other_t_real>::value, int>::type = 0>
Vec256<T> real() const {
  // other_t_real is for SFINAE and clarity. Make sure it is not changed.
  static_assert(std::is_same<other_t_real, T>::value, "other_t_real must be T");
  const Vec256<T>& self = *this;
  return self;
}
// real() for complex T: lane-wise real part, converted back into T.
template <typename complex_t_real = T,
          typename std::enable_if<c10::is_complex<complex_t_real>::value, int>::type = 0>
Vec256<T> real() const {
  // complex_t_real is for SFINAE and clarity. Make sure it is not changed.
  static_assert(std::is_same<complex_t_real, T>::value, "complex_t_real must be T");
  return map([](T z) { return static_cast<T>(z.real()); });
}
// imag() for non-complex T: there is no imaginary component, so every lane is zero.
template <typename other_t_imag = T,
          typename std::enable_if<!c10::is_complex<other_t_imag>::value, int>::type = 0>
Vec256<T> imag() const {
  // other_t_imag is for SFINAE and clarity. Make sure it is not changed.
  static_assert(std::is_same<other_t_imag, T>::value, "other_t_imag must be T");
  const Vec256 zeros = Vec256(0);
  return zeros;
}
// imag() for complex T: lane-wise imaginary part, converted back into T.
template <typename complex_t_imag = T,
          typename std::enable_if<c10::is_complex<complex_t_imag>::value, int>::type = 0>
Vec256<T> imag() const {
  // complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
  static_assert(std::is_same<complex_t_imag, T>::value, "complex_t_imag must be T");
  return map([](T z) { return static_cast<T>(z.imag()); });
}
// conj() for non-complex T: conjugation is the identity, so the vector is returned unchanged.
template <typename other_t_conj = T,
          typename std::enable_if<!c10::is_complex<other_t_conj>::value, int>::type = 0>
Vec256<T> conj() const {
  // other_t_conj is for SFINAE and clarity. Make sure it is not changed.
  static_assert(std::is_same<other_t_conj, T>::value, "other_t_conj must be T");
  const Vec256<T>& self = *this;
  return self;
}
// conj() for complex T: lane-wise complex conjugate via std::conj, converted back into T.
template <typename complex_t_conj = T,
          typename std::enable_if<c10::is_complex<complex_t_conj>::value, int>::type = 0>
Vec256<T> conj() const {
  // complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
  static_assert(std::is_same<complex_t_conj, T>::value, "complex_t_conj must be T");
  return map([](T z) { return static_cast<T>(std::conj(z)); });
}
// Lane-wise inverse cosine.
Vec256<T> acos() const {
  return this->map(std::acos);
}
// Lane-wise inverse sine.
Vec256<T> asin() const {
  return this->map(std::asin);
}
// Lane-wise inverse tangent (single-argument form; see atan2 for the two-argument form).
Vec256<T> atan() const {
  return this->map(std::atan);
}
// Lane-wise two-argument arctangent: result[i] = std::atan2((*this)[i], other[i]).
// *this supplies the y coordinates and `other` supplies the x coordinates.
// The parameter is deliberately not called `exp`: that name would shadow the
// member function exp() inside this body.
Vec256<T> atan2(const Vec256<T> &other) const {
  Vec256<T> ret;
  for (int64_t i = 0; i < size(); i++) {
    ret[i] = std::atan2(values[i], other[i]);
  }
  return ret;
}
// Lane-wise error function.
Vec256<T> erf() const {
  return this->map(std::erf);
}
// Lane-wise complementary error function, 1 - erf(x).
Vec256<T> erfc() const {
  return this->map(std::erfc);
}
// Lane-wise inverse error function via calc_erfinv (presumably declared in
// ATen/native/Math.h, which this header includes).
Vec256<T> erfinv() const {
  return this->map(calc_erfinv);
}
// Lane-wise e^x.
Vec256<T> exp() const {
  return this->map(std::exp);
}
// Lane-wise e^x - 1. std::expm1 keeps precision for x near 0.
Vec256<T> expm1() const {
  return this->map(std::expm1);
}
// Lane-wise fractional part: x - trunc(x), using this class's trunc() member.
Vec256<T> frac() const {
  const Vec256<T> whole = this->trunc();
  return *this - whole;
}
template <
typename U = T,
typename std::enable_if_t<is_floating_point<U>::value, int> = 0>
Vec256<T> fmod(const Vec256<T>& q) const {
// U is for SFINAE purposes only. Make sure it is not changed.
static_assert(std::is_same<U, T>::value, "U must be T");
Vec256<T> ret;
for (int64_t i = 0; i < size(); i++) {
ret[i] = std::fmod(values[i], q[i]);
Loading ...