Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
torch / include / ATen / cpu / vec / vec256 / vec256_half_neon.h
Size: Mime:
#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_float_neon.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/Half.h>
#include <c10/util/irange.h>

namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

// Right now contains only aarch64 implementation.
// Due to follow two reasons aarch32 is not currently supported.
// 1. Due to difference in ISA been aarch32 and aarch64, intrinsics
//    that work for aarch64 dont work for aarch32.
// 2. Android NDK r21 has problems with compiling aarch32.
//    Clang seg faults.
//    https://github.com/android/ndk/issues/1248
//    https://bugs.llvm.org/show_bug.cgi?id=45824
// Most likely we will do aarch32 support with inline asm.
#if !defined(C10_MOBILE) && defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

#ifdef __BIG_ENDIAN__
#error "Big endian is not supported."
#endif

template <int index, bool mask_val>
struct BlendHalfRegs {
  static float16x8_t impl(
      const float16x8_t& a,
      const float16x8_t& b,
      float16x8_t& res);
};

template <int index>
struct BlendHalfRegs<index, true> {
  static float16x8_t impl(
      const float16x8_t& a,
      const float16x8_t& b,
      float16x8_t& res) {
    return vsetq_lane_f16(vgetq_lane_f16(b, index), res, index);
  }
};

template <int index>
struct BlendHalfRegs<index, false> {
  static float16x8_t impl(
      const float16x8_t& a,
      const float16x8_t& b,
      float16x8_t& res) {
    return vsetq_lane_f16(vgetq_lane_f16(a, index), res, index);
  }
};

// On ARM, Half type supports float16_t->Half constructor and Half->float16_t
// conversion
template <>
class Vectorized<c10::Half> {
 private:
  float16x8x2_t values;

 public:
  // value_type should be c10::Half to fit interface with vec_base.h
  using value_type = c10::Half;
  using size_type = int;
  static constexpr size_type size() {
    static_assert(sizeof(float16x8x2_t) == 16 * sizeof(value_type));
    return 16;
  }

 private:
  // We use these private map functions to implement various methods
  Vectorized<c10::Half> map2(
      const Vectorized<c10::Half>& second,
      c10::Half (*const f)(c10::Half, c10::Half)) const {
    __at_align__ c10::Half tmp_first[size()];
    __at_align__ c10::Half tmp_second[size()];
    store(tmp_first); // store this to tmp_first
    second.store(tmp_second);
    for (const auto i : c10::irange(size())) {
      tmp_first[i] = f(tmp_first[i], tmp_second[i]);
    }
    return loadu(tmp_first);
  }

  Vectorized<c10::Half> map_with_vec_float_method(
      Vectorized<float> (Vectorized<float>::*m)() const) const {
    // Convert low float16x8_t to 2 float32x4_t variables, apply m, and convert
    // back
    float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values.val[0]));
    float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values.val[0]));
    Vectorized<float> mv0 = (Vectorized<float>(v00, v01).*m)();
    float16x4_t r00 = vcvt_f16_f32(mv0.get_low());
    float16x4_t r01 = vcvt_f16_f32(mv0.get_high());

    // Convert high float16x8_t to 2 float32x4_t variables, apply m, and convert
    // back
    float32x4_t v10 = vcvt_f32_f16(vget_low_f16(values.val[1]));
    float32x4_t v11 = vcvt_f32_f16(vget_high_f16(values.val[1]));
    Vectorized<float> mv1 = (Vectorized<float>(v10, v11).*m)();
    float16x4_t r10 = vcvt_f16_f32(mv1.get_low());
    float16x4_t r11 = vcvt_f16_f32(mv1.get_high());

    // Pack result into Vectorized<c10::Half>
    return Vectorized<c10::Half>(
        vcombine_f16(r00, r01), vcombine_f16(r10, r11));
  }

  Vectorized<c10::Half> map2_with_vec_float_method(
      const Vectorized<c10::Half>& second,
      Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
          const) const {
    // Convert low float16x8_t to 2 float32x4_t variables, apply m, and convert
    // back
    float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values.val[0]));
    float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values.val[0]));
    float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.get_low()));
    float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.get_low()));
    Vectorized<float> mv0 = (Vectorized<float>(v00, v01).*m)(
        Vectorized<float>(second_v00, second_v01));
    float16x4_t r00 = vcvt_f16_f32(mv0.get_low());
    float16x4_t r01 = vcvt_f16_f32(mv0.get_high());

    // Convert high float16x8_t to 2 float32x4_t variables, apply m, and convert
    // back
    float32x4_t v10 = vcvt_f32_f16(vget_low_f16(values.val[1]));
    float32x4_t v11 = vcvt_f32_f16(vget_high_f16(values.val[1]));
    float32x4_t second_v10 = vcvt_f32_f16(vget_low_f16(second.get_high()));
    float32x4_t second_v11 = vcvt_f32_f16(vget_high_f16(second.get_high()));
    Vectorized<float> mv1 = (Vectorized<float>(v10, v11).*m)(
        Vectorized<float>(second_v10, second_v11));
    float16x4_t r10 = vcvt_f16_f32(mv1.get_low());
    float16x4_t r11 = vcvt_f16_f32(mv1.get_high());

    // Pack result into Vectorized<c10::Half>
    return Vectorized<c10::Half>(
        vcombine_f16(r00, r01), vcombine_f16(r10, r11));
  }

 public:
   // constructor
  Vectorized() {}
  Vectorized(float16x8x2_t v) : values(v) {}

  // A ctor that accepts c10::Half is needed to fit interface with vec_base.h
  // A second constructor that takes float16_t is also included
  Vectorized(c10::Half val)
      : values{vdupq_n_f16((float16_t)val), vdupq_n_f16((float16_t)val)} {
  }
  Vectorized(float16_t val) : values{vdupq_n_f16(val), vdupq_n_f16(val)} {}
  Vectorized(
      float16_t val0,
      float16_t val1,
      float16_t val2,
      float16_t val3,
      float16_t val4,
      float16_t val5,
      float16_t val6,
      float16_t val7,
      float16_t val8,
      float16_t val9,
      float16_t val10,
      float16_t val11,
      float16_t val12,
      float16_t val13,
      float16_t val14,
      float16_t val15)
      : values{
            val0,
            val1,
            val2,
            val3,
            val4,
            val5,
            val6,
            val7,
            val8,
            val9,
            val10,
            val11,
            val12,
            val13,
            val14,
            val15} {}
  Vectorized(float16x8_t val0, float16x8_t val1) : values{val0, val1} {}
  operator float16x8x2_t() const {
    return values;
  }
  template <int64_t mask>
  static Vectorized<c10::Half> blend(
      const Vectorized<c10::Half>& a,
      const Vectorized<c10::Half>& b) {
    Vectorized<c10::Half> vec;
    // 0.
    vec.values.val[0] = BlendHalfRegs<0, (mask & 0x01) != 0>::impl(
        a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] = BlendHalfRegs<1, (mask & 0x02) != 0>::impl(
        a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] = BlendHalfRegs<2, (mask & 0x04) != 0>::impl(
        a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] = BlendHalfRegs<3, (mask & 0x08) != 0>::impl(
        a.values.val[0], b.values.val[0], vec.values.val[0]);

    vec.values.val[0] = BlendHalfRegs<4, (mask & 0x10) != 0>::impl(
        a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] = BlendHalfRegs<5, (mask & 0x20) != 0>::impl(
        a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] = BlendHalfRegs<6, (mask & 0x40) != 0>::impl(
        a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] = BlendHalfRegs<7, (mask & 0x80) != 0>::impl(
        a.values.val[0], b.values.val[0], vec.values.val[0]);

    // 1.
    vec.values.val[1] = BlendHalfRegs<0, (mask & 0x10) != 0>::impl(
        a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] = BlendHalfRegs<1, (mask & 0x20) != 0>::impl(
        a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] = BlendHalfRegs<2, (mask & 0x40) != 0>::impl(
        a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] = BlendHalfRegs<3, (mask & 0x80) != 0>::impl(
        a.values.val[1], b.values.val[1], vec.values.val[1]);

    vec.values.val[1] = BlendHalfRegs<4, (mask & 0x10) != 0>::impl(
        a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] = BlendHalfRegs<5, (mask & 0x20) != 0>::impl(
        a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] = BlendHalfRegs<6, (mask & 0x40) != 0>::impl(
        a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] = BlendHalfRegs<7, (mask & 0x80) != 0>::impl(
        a.values.val[1], b.values.val[1], vec.values.val[1]);

    return vec;
  }
  static Vectorized<c10::Half> blendv(
      const Vectorized<c10::Half>& a,
      const Vectorized<c10::Half>& b,
      const Vectorized<c10::Half>& mask) {
    // Note: using blendv is very awkward because 0xFFFF is one of many NaN's in
    // FP16 It's unfortunate that the mask has type Half (required from
    // vec_base)

    // TODO
    // NB: This requires that each value, i.e., each uint value,
    // of the mask either all be zeros or all be 1s.
    // We perhaps need some kind of an assert?
    // But that will affect performance.
    Vectorized<c10::Half> vec(mask.values);
    vec.values.val[0] = vbslq_f16(
        vreinterpretq_u16_f16(vec.values.val[0]),
        b.values.val[0],
        a.values.val[0]);
    vec.values.val[1] = vbslq_f16(
        vreinterpretq_u16_f16(vec.values.val[1]),
        b.values.val[1],
        a.values.val[1]);
    return vec;
  }
  template <typename step_t>
  static Vectorized<c10::Half> arange(
      c10::Half base = 0.0,
      step_t step = static_cast<step_t>(1)) {
    const Vectorized<c10::Half> base_vec(base);
    const Vectorized<c10::Half> step_vec(step);
    const Vectorized<c10::Half> step_sizes(
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
    return fmadd(step_sizes, step_vec, base_vec);
  }
  static Vectorized<c10::Half> set(
      const Vectorized<c10::Half>& a,
      const Vectorized<c10::Half>& b,
      int64_t count = size()) {
    uint16_t pre_mask[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
    for (int i = 0; i < count; i++) {
      pre_mask[i] = 0xFFFF;
    }
    uint16x8x2_t mask = vld1q_u16_x2(pre_mask);

    // Using blendv is awkward because 0xFFFF is one of many NaN's in FP16
    // so we directly use vbslq_f16 instead
    Vectorized<c10::Half> vec(
        vbslq_f16(
            // Low bits
            mask.val[0],
            b.values.val[0],
            a.values.val[0]),
        // High bits
        vbslq_f16(mask.val[1], b.values.val[1], a.values.val[1]));

    return vec;
  }
  static Vectorized<c10::Half> loadu(const void* ptr, int64_t count = size()) {
    if (count == size()) {
      return vld1q_f16_x2(reinterpret_cast<const float16_t*>(ptr));
    } else if (count == (size() >> 1)) {
      Vectorized<c10::Half> res;
      res.values.val[0] = vld1q_f16(reinterpret_cast<const float16_t*>(ptr));
      std::memset(&res.values.val[1], 0, sizeof(res.values.val[1]));
      return res;
    }
    __at_align__ float16_t tmp_values[size()];
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(
        tmp_values,
        reinterpret_cast<const float16_t*>(ptr),
        count * sizeof(float16_t));
    return vld1q_f16_x2(reinterpret_cast<const float16_t*>(tmp_values));
  }
  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      vst1q_f16_x2(reinterpret_cast<float16_t*>(ptr), values);
      return;
    } else if (count == (size() >> 1)) {
      vst1q_f16(reinterpret_cast<float16_t*>(ptr), values.val[0]);
    } else {
      float16_t tmp_values[size()];
      vst1q_f16_x2(reinterpret_cast<float16_t*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(float16_t));
    }
  }
  inline const float16x8_t& get_low() const {
    return values.val[0];
  }
  inline float16x8_t& get_low() {
    return values.val[0];
  }
  inline const float16x8_t& get_high() const {
    return values.val[1];
  }
  inline float16x8_t& get_high() {
    return values.val[1];
  }
  // Very slow implementation of indexing.
  // Only required because vec256_qint refers to this.
  // Once we specialize that implementation for ARM
  // this should be removed. TODO (kimishpatel)
  c10::Half operator[](int idx) const {
    __at_align__ c10::Half tmp[size()];
    store(tmp);
    return tmp[idx];
  }
  c10::Half operator[](int idx) {
    __at_align__ c10::Half tmp[size()];
    store(tmp);
    return tmp[idx];
  }
  // For boolean version where we want to if any 1/all zero
  // etc. can be done faster in a different way.
  int zero_mask() const {
    __at_align__ c10::Half tmp[size()];
    store(tmp);
    int mask = 0;
    for (int i = 0; i < size(); ++i) {
      if (tmp[i] == 0) {
        mask |= (1 << i);
      }
    }
    return mask;
  }
  Vectorized<c10::Half> isnan() const {
    __at_align__ c10::Half tmp[size()];
    __at_align__ c10::Half res[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i])) {
        std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(c10::Half));
      } else {
        std::memset(static_cast<void*>(&res[i]), 0, sizeof(c10::Half));
      }
    }
    return loadu(res);
  };
  bool has_inf_nan() const {
    __at_align__ c10::Half tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i]) || _isinf(tmp[i])) {
        return true;
      }
    }
    return false;
  }
  Vectorized<c10::Half> map(c10::Half (*const f)(c10::Half)) const {
    __at_align__ c10::Half tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
  Vectorized<c10::Half> abs() const {
    return Vectorized<c10::Half>(
        vabsq_f16(values.val[0]), vabsq_f16(values.val[1]));
  }
  Vectorized<c10::Half> angle() const {
    auto zero = Vectorized<c10::Half>(0);
    auto pi = Vectorized<c10::Half>(c10::pi<c10::Half>);
    auto tmp = blendv(zero, pi, *this < zero);
    return blendv(tmp, *this, isnan());
  }
  Vectorized<c10::Half> real() const {
    return *this;
  }
  Vectorized<c10::Half> imag() const {
    return Vectorized<c10::Half>(0);
  }
  Vectorized<c10::Half> conj() const {
    return *this;
  }

  // Sleef does not support FP16, so many math functions are applied by
  // converting to FP32, applying the math function, and then converting back to
  // FP16.
  Vectorized<c10::Half> acos() const {
    return map_with_vec_float_method(&Vectorized<float>::acos);
  }
  Vectorized<c10::Half> acosh() const {
    return map_with_vec_float_method(&Vectorized<float>::acosh);
  }
  Vectorized<c10::Half> asin() const {
    return map_with_vec_float_method(&Vectorized<float>::asin);
  }
  Vectorized<c10::Half> atan() const {
    return map_with_vec_float_method(&Vectorized<float>::atan);
  }
  Vectorized<c10::Half> atanh() const {
    return map_with_vec_float_method(&Vectorized<float>::atanh);
  }
  Vectorized<c10::Half> atan2(const Vectorized<c10::Half>& exp) const {
    return map2_with_vec_float_method(exp, &Vectorized<float>::atan2);
  }
  Vectorized<c10::Half> copysign(const Vectorized<c10::Half>& sign) const {
    return map2_with_vec_float_method(sign, &Vectorized<float>::copysign);
  }
  Vectorized<c10::Half> erf() const {
    return map_with_vec_float_method(&Vectorized<float>::erf);
  }
  Vectorized<c10::Half> erfc() const {
    return map_with_vec_float_method(&Vectorized<float>::erfc);
  }
  Vectorized<c10::Half> erfinv() const {
    return map_with_vec_float_method(&Vectorized<float>::erfinv);
  }
  Vectorized<c10::Half> exp() const {
    return map_with_vec_float_method(&Vectorized<float>::exp);
  }
  Vectorized<c10::Half> exp2() const {
    return map_with_vec_float_method(&Vectorized<float>::exp2);
  }
  Vectorized<c10::Half> expm1() const {
    return map_with_vec_float_method(&Vectorized<float>::expm1);
  }
  Vectorized<c10::Half> exp_u20() const {
    return map_with_vec_float_method(&Vectorized<float>::exp_u20);
  }
  Vectorized<c10::Half> fmod(const Vectorized<c10::Half>& q) const {
    // This function is questionable with a conversion, so we use map2
    return map2(q, std::fmod);
  }
  Vectorized<c10::Half> hypot(const Vectorized<c10::Half>& b) const {
    return map2_with_vec_float_method(b, &Vectorized<float>::hypot);
  }
  Vectorized<c10::Half> i0() const {
    return map_with_vec_float_method(&Vectorized<float>::i0);
  }
  Vectorized<c10::Half> i0e() const {
    return map_with_vec_float_method(&Vectorized<float>::i0e);
  }
  Vectorized<c10::Half> digamma() const {
    return map_with_vec_float_method(&Vectorized<float>::digamma);
  }
  Vectorized<c10::Half> igamma(const Vectorized<c10::Half>& x) const {
    return map2_with_vec_float_method(x, &Vectorized<float>::igamma);
  }
  Vectorized<c10::Half> igammac(const Vectorized<c10::Half>& x) const {
    return map2_with_vec_float_method(x, &Vectorized<float>::igammac);
  }
  Vectorized<c10::Half> log() const {
    return map_with_vec_float_method(&Vectorized<float>::log);
  }
  Vectorized<c10::Half> log10() const {
    return map_with_vec_float_method(&Vectorized<float>::log10);
  }
  Vectorized<c10::Half> log1p() const {
    return map_with_vec_float_method(&Vectorized<float>::log1p);
  }
  Vectorized<c10::Half> log2() const {
    return map_with_vec_float_method(&Vectorized<float>::log2);
  }
  Vectorized<c10::Half> nextafter(const Vectorized<c10::Half>& b) const {
    // This function does not make sense with conversion, so we use map2
    return map2(b, std::nextafter);
  }
  Vectorized<c10::Half> frac() const;
  Vectorized<c10::Half> sin() const {
    return map_with_vec_float_method(&Vectorized<float>::sin);
  }
  Vectorized<c10::Half> sinh() const {
    return map_with_vec_float_method(&Vectorized<float>::sinh);
  }
  Vectorized<c10::Half> cos() const {
    return map_with_vec_float_method(&Vectorized<float>::cos);
  }
  Vectorized<c10::Half> cosh() const {
    return map_with_vec_float_method(&Vectorized<float>::cosh);
  }
  Vectorized<c10::Half> ceil() const {
    // This function is questionable with a conversion, so we use map
    return map(at::native::ceil_impl);
  }
  Vectorized<c10::Half> floor() const {
    // This function is questionable with a conversion, so we use map
    return map(at::native::floor_impl);
  }
  Vectorized<c10::Half> neg() const {
    return Vectorized<c10::Half>(
        vnegq_f16(values.val[0]), vnegq_f16(values.val[1]));
  }
  inline Vectorized<c10::Half> round() const {
    // This function is questionable with a conversion, so we use map
    return map(at::native::round_impl);
  }
  inline Vectorized<c10::Half> tan() const {
    return map_with_vec_float_method(&Vectorized<float>::tan);
  }
  inline Vectorized<c10::Half> tanh() const {
    return map_with_vec_float_method(&Vectorized<float>::tanh);
  }
  Vectorized<c10::Half> trunc() const {
    float16x8_t r0 = vrndq_f16(values.val[0]);
    float16x8_t r1 = vrndq_f16(values.val[1]);
    return Vectorized<c10::Half>(r0, r1);
  }
  Vectorized<c10::Half> lgamma() const {
    return map_with_vec_float_method(&Vectorized<float>::lgamma);
  }
  Vectorized<c10::Half> sqrt() const {
    return Vectorized<c10::Half>(
        vsqrtq_f16(values.val[0]), vsqrtq_f16(values.val[1]));
  }
  Vectorized<c10::Half> reciprocal() const {
    auto ones = vdupq_n_f16(1.0f);
    auto r0 = vdivq_f16(ones, values.val[0]);
    auto r1 = vdivq_f16(ones, values.val[1]);
    return Vectorized<c10::Half>(r0, r1);
  }
  Vectorized<c10::Half> rsqrt() const {
    return this->sqrt().reciprocal();
  }
  Vectorized<c10::Half> pow(const Vectorized<c10::Half>& exp) const {
    return map2_with_vec_float_method(exp, &Vectorized<float>::pow);
  }
  Vectorized<c10::Half> operator==(const Vectorized<c10::Half>& other) const {
    float16x8_t r0 =
        vreinterpretq_f16_u16(vceqq_f16(values.val[0], other.values.val[0]));
    float16x8_t r1 =
        vreinterpretq_f16_u16(vceqq_f16(values.val[1], other.values.val[1]));
    return Vectorized<c10::Half>(r0, r1);
  }

  Vectorized<c10::Half> operator!=(const Vectorized<c10::Half>& other) const {
    float16x8_t r0 = vreinterpretq_f16_u16(
        vmvnq_u16(vceqq_f16(values.val[0], other.values.val[0])));
    float16x8_t r1 = vreinterpretq_f16_u16(
        vmvnq_u16(vceqq_f16(values.val[1], other.values.val[1])));
    return Vectorized<c10::Half>(r0, r1);
  }

  Vectorized<c10::Half> operator<(const Vectorized<c10::Half>& other) const {
    float16x8_t r0 =
        vreinterpretq_f16_u16(vcltq_f16(values.val[0], other.values.val[0]));
    float16x8_t r1 =
        vreinterpretq_f16_u16(vcltq_f16(values.val[1], other.values.val[1]));
    return Vectorized<c10::Half>(r0, r1);
  }

  Vectorized<c10::Half> operator<=(const Vectorized<c10::Half>& other) const {
    float16x8_t r0 =
        vreinterpretq_f16_u16(vcleq_f16(values.val[0], other.values.val[0]));
    float16x8_t r1 =
        vreinterpretq_f16_u16(vcleq_f16(values.val[1], other.values.val[1]));
    return Vectorized<c10::Half>(r0, r1);
  }

  Vectorized<c10::Half> operator>(const Vectorized<c10::Half>& other) const {
    float16x8_t r0 =
        vreinterpretq_f16_u16(vcgtq_f16(values.val[0], other.values.val[0]));
    float16x8_t r1 =
        vreinterpretq_f16_u16(vcgtq_f16(values.val[1], other.values.val[1]));
    return Vectorized<c10::Half>(r0, r1);
  }

  Vectorized<c10::Half> operator>=(const Vectorized<c10::Half>& other) const {
    float16x8_t r0 =
        vreinterpretq_f16_u16(vcgeq_f16(values.val[0], other.values.val[0]));
    float16x8_t r1 =
        vreinterpretq_f16_u16(vcgeq_f16(values.val[1], other.values.val[1]));
    return Vectorized<c10::Half>(r0, r1);
  }

  Vectorized<c10::Half> eq(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> ne(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> gt(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> ge(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> lt(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> le(const Vectorized<c10::Half>& other) const;
}; // Vectorized<Half>

template <>
Vectorized<c10::Half> inline operator+(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& b) {
  float16x8_t r0 = vaddq_f16(a.get_low(), b.get_low());
  float16x8_t r1 = vaddq_f16(a.get_high(), b.get_high());
  return Vectorized<c10::Half>(r0, r1);
}

template <>
Vectorized<c10::Half> inline operator-(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& b) {
  float16x8_t r0 = vsubq_f16(a.get_low(), b.get_low());
  float16x8_t r1 = vsubq_f16(a.get_high(), b.get_high());
  return Vectorized<c10::Half>(r0, r1);
}

template <>
Vectorized<c10::Half> inline operator*(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& b) {
  float16x8_t r0 = vmulq_f16(a.get_low(), b.get_low());
  float16x8_t r1 = vmulq_f16(a.get_high(), b.get_high());
  return Vectorized<c10::Half>(r0, r1);
}

template <>
Vectorized<c10::Half> inline operator/(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& b) {
  float16x8_t r0 = vdivq_f16(a.get_low(), b.get_low());
  float16x8_t r1 = vdivq_f16(a.get_high(), b.get_high());
  return Vectorized<c10::Half>(r0, r1);
}

// frac. Implement this here so we can use subtraction
inline Vectorized<c10::Half> Vectorized<c10::Half>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<c10::Half> inline maximum(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& b) {
  float16x8_t r0 = vmaxq_f16(a.get_low(), b.get_low());
  float16x8_t r1 = vmaxq_f16(a.get_high(), b.get_high());
  return Vectorized<c10::Half>(r0, r1);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<c10::Half> inline minimum(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& b) {
  float16x8_t r0 = vminq_f16(a.get_low(), b.get_low());
  float16x8_t r1 = vminq_f16(a.get_high(), b.get_high());
  return Vectorized<c10::Half>(r0, r1);
}

template <>
Vectorized<c10::Half> inline clamp(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& min,
    const Vectorized<c10::Half>& max) {
  return minimum(max, maximum(min, a));
}

template <>
Vectorized<c10::Half> inline clamp_max(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& max) {
  return minimum(max, a);
}

template <>
Vectorized<c10::Half> inline clamp_min(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& min) {
  return maximum(min, a);
}

template <>
Vectorized<c10::Half> inline operator&(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& b) {
  float16x8_t r0 = vreinterpretq_f16_u16(vandq_u16(
      vreinterpretq_u16_f16(a.get_low()), vreinterpretq_u16_f16(b.get_low())));
  float16x8_t r1 = vreinterpretq_f16_u16(vandq_u16(
      vreinterpretq_u16_f16(a.get_high()),
      vreinterpretq_u16_f16(b.get_high())));
  return Vectorized<c10::Half>(r0, r1);
}

template <>
Vectorized<c10::Half> inline operator|(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& b) {
  float16x8_t r0 = vreinterpretq_f16_u16(vorrq_u16(
      vreinterpretq_u16_f16(a.get_low()), vreinterpretq_u16_f16(b.get_low())));
  float16x8_t r1 = vreinterpretq_f16_u16(vorrq_u16(
      vreinterpretq_u16_f16(a.get_high()),
      vreinterpretq_u16_f16(b.get_high())));
  return Vectorized<c10::Half>(r0, r1);
}

template <>
Vectorized<c10::Half> inline operator^(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& b) {
  float16x8_t r0 = vreinterpretq_f16_u16(veorq_u16(
      vreinterpretq_u16_f16(a.get_low()), vreinterpretq_u16_f16(b.get_low())));
  float16x8_t r1 = vreinterpretq_f16_u16(veorq_u16(
      vreinterpretq_u16_f16(a.get_high()),
      vreinterpretq_u16_f16(b.get_high())));
  return Vectorized<c10::Half>(r0, r1);
}

inline Vectorized<c10::Half> Vectorized<c10::Half>::eq(
    const Vectorized<c10::Half>& other) const {
  return (*this == other) & Vectorized<c10::Half>(1);
}

inline Vectorized<c10::Half> Vectorized<c10::Half>::ne(
    const Vectorized<c10::Half>& other) const {
  return (*this != other) & Vectorized<c10::Half>(1);
}

inline Vectorized<c10::Half> Vectorized<c10::Half>::gt(
    const Vectorized<c10::Half>& other) const {
  return (*this > other) & Vectorized<c10::Half>(1);
}

inline Vectorized<c10::Half> Vectorized<c10::Half>::ge(
    const Vectorized<c10::Half>& other) const {
  return (*this >= other) & Vectorized<c10::Half>(1);
}

inline Vectorized<c10::Half> Vectorized<c10::Half>::lt(
    const Vectorized<c10::Half>& other) const {
  return (*this < other) & Vectorized<c10::Half>(1);
}

inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
    const Vectorized<c10::Half>& other) const {
  return (*this <= other) & Vectorized<c10::Half>(1);
}

template <>
inline void convert(const float16_t* src, int16_t* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<c10::Half>::size());
       i += Vectorized<c10::Half>::size()) {
    vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i)));
    vst1q_s16(dst + i + 8, vcvtq_s16_f16(vld1q_f16(src + i + 8)));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = static_cast<int16_t>(src[i]);
  }
}

template <>
inline void convert(const int16_t* src, float16_t* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<c10::Half>::size());
       i += Vectorized<c10::Half>::size()) {
    vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i)));
    vst1q_f16(dst + i + 8, vcvtq_f16_s16(vld1q_s16(src + i + 8)));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = static_cast<float16_t>(src[i]);
  }
}

template <>
Vectorized<c10::Half> inline fmadd(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& b,
    const Vectorized<c10::Half>& c) {
  float16x8_t r0 = vfmaq_f16(c.get_low(), a.get_low(), b.get_low());
  float16x8_t r1 = vfmaq_f16(c.get_high(), a.get_high(), b.get_high());
  return Vectorized<c10::Half>(r0, r1);
}

template <>
Vectorized<c10::Half> inline fmsub(
    const Vectorized<c10::Half>& a,
    const Vectorized<c10::Half>& b,
    const Vectorized<c10::Half>& c) {
  float16x8_t r0 = vfmsq_f16(c.get_low(), a.get_low(), b.get_low());
  float16x8_t r1 = vfmsq_f16(c.get_high(), a.get_high(), b.get_high());
  return Vectorized<c10::Half>(r0, r1);
}

#endif /* defined(aarch64) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(C10_MOBILE) */

} // namespace CPU_CAPABILITY
} // namespace at::vec