Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ include / ATen / cpu / vec256 / vec256_base.h

#pragma once

// See Note [Do not compile initializers with AVX]
// Note [Do not compile initializers with AVX]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// If you define a static initializer in this file, the initialization will use
// AVX instructions because these object files are compiled with AVX enabled.
// We need to avoid non-trivial global data in these architecture specific files
// because there's no way to guard the global initializers with CPU capability
// detection.
// See https://github.com/pytorch/pytorch/issues/37577 for an instance
// of this bug in the past.

#include <cstring>
#include <functional>
#include <cmath>
#include <type_traits>
#include <bitset>

#include <ATen/cpu/vec256/intrinsics.h>
#include <ATen/native/Math.h>
#include <ATen/NumericUtils.h>
#include <c10/util/C++17.h>
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/math_compat.h>
#include <ATen/native/cpu/zmath.h>
#include <c10/util/TypeCast.h>
#include <c10/macros/Macros.h>

// __at_align32__ requests 32-byte alignment for the declaration it annotates,
// using the compiler-specific spelling. On compilers matched by neither branch
// it expands to nothing (no alignment request is made).
#if defined(__GNUC__)
#define __at_align32__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align32__ __declspec(align(32))
#else
#define __at_align32__
#endif

namespace at {
namespace vec256 {
// See Note [Acceptable use of anonymous namespace in header]
namespace {
// at::Half should be treated as floating point
template <typename T>
struct is_floating_point:
      std::is_floating_point<T>::value ||
      std::is_same<T, at::Half>::value> {

// Maps a byte width n to a signed integer type of exactly that width. Only the
// widths specialized below (8, 4, 2 and 1 bytes) are defined; any other width
// leaves int_of_size<n> incomplete.
template<size_t n> struct int_of_size;

#define DEFINE_INT_OF_SIZE(int_t) \
template<> struct int_of_size<sizeof(int_t)> { using type = int_t; }

DEFINE_INT_OF_SIZE(int64_t);
DEFINE_INT_OF_SIZE(int32_t);
DEFINE_INT_OF_SIZE(int16_t);
DEFINE_INT_OF_SIZE(int8_t);

// The helper macro is only needed for the specializations above; do not leak it
// out of this header.
#undef DEFINE_INT_OF_SIZE

// Signed integer type with the same size as T (blendv uses it to inspect mask
// lanes as integers). Ill-formed for sizes that have no specialization above.
template <typename T>
using int_same_size_t = typename int_of_size<sizeof(T)>::type;

// NOTE: If you specialize on a type, you must define all operations!

// Generic fallback that emulates a 256-bit vector type with a plain array of
// 32 / sizeof(T) lanes of T.
template <class T>
struct Vec256 {
  using value_type = T;
  // Lane storage: 32 bytes' worth of T, requested to be 32-byte aligned.
  __at_align32__ T values[32 / sizeof(T)];
  // Note [constexpr static function to avoid odr-usage compiler bug]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Why, you might ask, is size defined to be a static constexpr function,
  // rather than a more ordinary 'static constexpr int size;' variable?
  // The problem lies within ODR rules for static constexpr members versus
  // static constexpr functions.  First, recall that this class (along with all
  // of its derivations) lives in an anonymous namespace: they are intended to be
  // *completely* inlined at their use-sites, because we need to compile it
  // multiple times for different instruction sets.
  // Because of this constraint, we CANNOT provide a single definition for
  // any static members in this class; since we want to compile the class
  // multiple times, there wouldn't actually be any good place to put the
  // definition.  Now here is the problem: if we ODR-use a static constexpr
  // member, we are *obligated* to provide a definition.  Without the
  // definition, you get a link error like:
  //    relocation R_X86_64_PC32 against undefined symbol
  //    `_ZN2at6vec25612_GLOBAL__N_16Vec256IdE4sizeE' can not be used when making
  //    a shared object; recompile with -fPIC
  // If this were C++17, we could replace a static constexpr variable with
  // an inline variable, which doesn't require a separate out-of-class
  // definition. But we are not C++17.  So the next best thing is to replace
  // the member with a static constexpr (and therefore inline) function, which
  // does not require an out-of-class definition either.
  // Also, technically according to the C++ standard, we don't have to define
  // a constexpr variable if we never odr-use it.  But it seems that some
  // versions of GCC/Clang have buggy determinations on whether or not an
  // identifier is odr-used or not, and in any case it's hard to tell if
  // a variable is odr-used or not.  So best to just cut the problem at the root.
  //
  // Returns the number of T lanes held in one 32-byte vector.
  static constexpr int size() {
    return 32 / sizeof(T);
  }
  Vec256() : values{0} {}
  Vec256(T val) {
    for (int i = 0; i != size(); i++) {
      values[i] = val;
  template<typename... Args,
           typename = std::enable_if_t<(sizeof...(Args) == size())>>
  Vec256(Args... vals) : values{vals...}{
  // This also implies const T& operator[](int idx) const
  inline operator const T*() const {
    return values;
  // This also implies T& operator[](int idx)
  inline operator T*() {
    return values;
  // Compile-time lane select. Bit i of `mask_` (the least significant bit maps
  // to lane 0) chooses b[i] when set and a[i] when clear. Mask bits at or
  // beyond size() are never examined.
  template <int64_t mask_>
  static Vec256<T> blend(const Vec256<T>& a, const Vec256<T>& b) {
    int64_t mask = mask_;
    Vec256 vec;
    for (int64_t i = 0; i < size(); i++) {
      if (mask & 0x01) {
        vec[i] = b[i];
      } else {
        vec[i] = a[i];
      }
      // Shift so the next lane's selector bit becomes the low bit.
      mask = mask >> 1;
    }
    return vec;
  }
  // Run-time lane select: lane i of the result is b[i] when the lowest bit of
  // mask[i]'s bit pattern is set, and a[i] otherwise.
  // NOTE(review): only the lowest bit of each mask lane is inspected, so this
  // presumably expects mask lanes to be all-ones or all-zeros; confirm with callers.
  static Vec256<T> blendv(const Vec256<T>& a, const Vec256<T>& b,
                          const Vec256<T>& mask) {
    Vec256 vec;
    // Same-width integer view of the mask, so its bits can be tested even
    // when T is a floating-point type.
    int_same_size_t<T> buffer[size()];
    // Copy the mask's raw bytes into the integer buffer. Without this the loop
    // below would read uninitialized memory.
    mask.store(buffer);
    for (int64_t i = 0; i < size(); i++) {
      if (buffer[i] & 0x01) {
        vec[i] = b[i];
      } else {
        vec[i] = a[i];
      }
    }
    return vec;
  }
  // Returns {base, base + step, base + 2*step, ...}, one value per lane.
  // Each lane is computed as `base + i * step` under the usual arithmetic
  // conversions (so in double for T=int, step_t=double) and then converted to T
  // on assignment.
  template<typename step_t>  // step sometimes requires a higher precision type (e.g., T=int, step_t=double)
  static Vec256<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
    Vec256 vec;
    for (int64_t i = 0; i < size(); i++) {
      vec.values[i] = base + i * step;
    }
    return vec;
  }
  // Returns a vector whose first `count` lanes come from b and whose remaining
  // lanes come from a. With the default count = size(), the result is a copy of b.
  static Vec256<T> set(const Vec256<T>& a, const Vec256<T>& b, int64_t count = size()) {
    Vec256 vec;
    for (int64_t i = 0; i < size(); i++) {
      if (i < count) {
        vec[i] = b[i];
      } else {
        vec[i] = a[i];
      }
    }
    return vec;
  }
  // Unaligned load of all lanes: copies exactly 32 bytes from ptr.
  static Vec256<T> loadu(const void* ptr) {
    Vec256 vec;
    std::memcpy(vec.values, ptr, 32);
    return vec;
  }
  // Unaligned partial load: copies the first `count` lanes from ptr. The other
  // lanes keep the default-constructed (zero) value.
  // NOTE(review): count is not range-checked; callers must keep it within
  // [0, size()] or the memcpy overruns vec.values.
  static Vec256<T> loadu(const void* ptr, int64_t count) {
    Vec256 vec;
    std::memcpy(vec.values, ptr, count * sizeof(T));
    return vec;
  }
  // Copies the first `count` lanes (all size() lanes by default) to ptr. Uses
  // memcpy, so ptr need not be aligned.
  void store(void* ptr, int count = size()) const {
    std::memcpy(ptr, values, count * sizeof(T));
  }
  // Returns an integer mask in which bit i is 1 if lane i compares equal to
  // zero and 0 otherwise.
  int zero_mask() const {
    int mask = 0;
    for (int i = 0; i < size(); ++i) {
      if (values[i] == static_cast<T>(0)) {
        mask |= (1 << i);
      }
    }
    return mask;
  }
  // Returns a new vector whose lane i is f(values[i]). The two overloads accept
  // unary functions taking T by value or by const reference.
  Vec256<T> map(T (*f)(T)) const {
    Vec256<T> ret;
    for (int64_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  Vec256<T> map(T (*f)(const T &)) const {
    Vec256<T> ret;
    for (int64_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  template <typename other_t_abs = T,
            typename std::enable_if<!is_floating_point<other_t_abs>::value && !c10::is_complex<other_t_abs>::value, int>::type = 0>
  Vec256<T> abs() const {
    // other_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_abs, T>::value, "other_t_abs must be T");
    return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
  template <typename float_t_abs = T,
            typename std::enable_if<is_floating_point<float_t_abs>::value, int>::type = 0>
  Vec256<T> abs() const {
    // float_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<float_t_abs, T>::value, "float_t_abs must be T");
    // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in
    // 0.0) properly.
    return map([](T x) -> T { return std::abs(x); });
  template <typename complex_t_abs = T,
            typename std::enable_if<c10::is_complex<complex_t_abs>::value, int>::type = 0>
  Vec256<T> abs() const {
    // complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_abs, T>::value, "complex_t_abs must be T");
    // Specifically map() does not perform the type conversion needed by abs.
    return map([](T x) { return static_cast<T>(std::abs(x)); });

  // sgn() is only available for complex T: maps each lane through
  // at::native::sgn_impl.
  template <typename complex_t_sgn = T,
            typename std::enable_if<c10::is_complex<complex_t_sgn>::value, int>::type = 0>
  Vec256<T> sgn() const {
    // complex_t_sgn is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_sgn, T>::value, "complex_t_sgn must be T");
    return map(at::native::sgn_impl);
  }

  // angle() for non-complex T: maps each lane through at::native::angle_impl<T>.
  template <typename other_t_angle = T,
            typename std::enable_if<!c10::is_complex<other_t_angle>::value, int>::type = 0>
  Vec256<T> angle() const {
    // other_t_angle is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_angle, T>::value, "other_t_angle must be T");
    return map(at::native::angle_impl<T>);  // compiler is unable to resolve the overload without <T>
  }
  // angle() for complex T: lane-wise std::arg, converted back to T.
  template <typename complex_t_angle = T,
            typename std::enable_if<c10::is_complex<complex_t_angle>::value, int>::type = 0>
  Vec256<T> angle() const {
    // complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_angle, T>::value, "complex_t_angle must be T");
    return map([](T x) { return static_cast<T>(std::arg(x)); });
  }
  // real() for non-complex T: returns a copy of this vector unchanged.
  template <typename other_t_real = T,
            typename std::enable_if<!c10::is_complex<other_t_real>::value, int>::type = 0>
  Vec256<T> real() const {
    // other_t_real is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_real, T>::value, "other_t_real must be T");
    return *this;
  }
  // real() for complex T: each lane becomes x.real(), converted back to T.
  template <typename complex_t_real = T,
            typename std::enable_if<c10::is_complex<complex_t_real>::value, int>::type = 0>
  Vec256<T> real() const {
    // complex_t_real is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_real, T>::value, "complex_t_real must be T");
    return map([](T x) { return static_cast<T>(x.real()); });
  }
  // imag() for non-complex T: returns a vector with every lane set to 0.
  template <typename other_t_imag = T,
            typename std::enable_if<!c10::is_complex<other_t_imag>::value, int>::type = 0>
  Vec256<T> imag() const {
    // other_t_imag is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_imag, T>::value, "other_t_imag must be T");
    return Vec256(0);
  }
  // imag() for complex T: each lane becomes x.imag(), converted back to T.
  template <typename complex_t_imag = T,
            typename std::enable_if<c10::is_complex<complex_t_imag>::value, int>::type = 0>
  Vec256<T> imag() const {
    // complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_imag, T>::value, "complex_t_imag must be T");
    return map([](T x) { return static_cast<T>(x.imag()); });
  }
  // conj() for non-complex T: returns a copy of this vector unchanged.
  template <typename other_t_conj = T,
            typename std::enable_if<!c10::is_complex<other_t_conj>::value, int>::type = 0>
  Vec256<T> conj() const {
    // other_t_conj is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_conj, T>::value, "other_t_conj must be T");
    return *this;
  }
  // conj() for complex T: lane-wise std::conj, converted back to T.
  template <typename complex_t_conj = T,
            typename std::enable_if<c10::is_complex<complex_t_conj>::value, int>::type = 0>
  Vec256<T> conj() const {
    // complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_conj, T>::value, "complex_t_conj must be T");
    return map([](T x) { return static_cast<T>(std::conj(x)); });
  }
  // Lane-wise inverse trigonometric functions (std::acos / std::asin / std::atan).
  Vec256<T> acos() const {
    return map(std::acos);
  }
  Vec256<T> asin() const {
    return map(std::asin);
  }
  Vec256<T> atan() const {
    return map(std::atan);
  }
  // Lane-wise two-argument arctangent: lane i is std::atan2(values[i], exp[i]).
  // Inside this function `exp` names the second operand vector, not the exp()
  // member function.
  Vec256<T> atan2(const Vec256<T> &exp) const {
    Vec256<T> ret;
    for (int64_t i = 0; i < size(); i++) {
      ret[i] = std::atan2(values[i], exp[i]);
    }
    return ret;
  }
  // Lane-wise std::erf and std::erfc. erfinv() maps each lane through calc_erfinv.
  Vec256<T> erf() const {
    return map(std::erf);
  }
  Vec256<T> erfc() const {
    return map(std::erfc);
  }
  Vec256<T> erfinv() const {
    return map(calc_erfinv);
  }
  // Lane-wise std::exp (e^x) and std::expm1 (e^x - 1).
  Vec256<T> exp() const {
    return map(std::exp);
  }
  Vec256<T> expm1() const {
    return map(std::expm1);
  }
  // Fractional part of each lane, computed as *this - trunc(). Relies on a
  // trunc() member defined elsewhere in this class.
  Vec256<T> frac() const {
    return *this - this->trunc();
  }
  template <
    typename U = T,
    typename std::enable_if_t<is_floating_point<U>::value, int> = 0>
  Vec256<T> fmod(const Vec256<T>& q) const {
    // U is for SFINAE purposes only. Make sure it is not changed.
    static_assert(std::is_same<U, T>::value, "U must be T");
    Vec256<T> ret;
    for (int64_t i = 0; i < size(); i++) {
      ret[i] = std::fmod(values[i], q[i]);
Loading ...