Why Gemfury? Push, build, and install: RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, NuGet packages.

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / ATen / cpu / vec / vec256 / vsx / vec256_common_vsx.h

#pragma once

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>

// Note: header order is important here
#include <ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h>

#include <ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h>

#include <ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h>

namespace at {
namespace vec {

inline namespace CPU_CAPABILITY {

// Instantiate the clamp helpers once per element type listed below
// (quantized types, signed integers, float and double).
// NOTE(review): DEFINE_CLAMP_FUNCS is defined outside this file (presumably
// in vsx_helpers.h, included above) -- confirm exactly which functions it
// emits before relying on them here.
DEFINE_CLAMP_FUNCS(c10::quint8)
DEFINE_CLAMP_FUNCS(c10::qint8)
DEFINE_CLAMP_FUNCS(c10::qint32)
DEFINE_CLAMP_FUNCS(int16_t)
DEFINE_CLAMP_FUNCS(int32_t)
DEFINE_CLAMP_FUNCS(int64_t)
DEFINE_CLAMP_FUNCS(float)
DEFINE_CLAMP_FUNCS(double)

// Multiply-add for Vectorized<double>: computes a * b + c lane by lane by
// applying vec_madd to each of the two halves (vec0 / vec1) independently.
template <>
Vectorized<double> C10_ALWAYS_INLINE fmadd(
    const Vectorized<double>& a,
    const Vectorized<double>& b,
    const Vectorized<double>& c) {
  const auto lo = vec_madd(a.vec0(), b.vec0(), c.vec0());
  const auto hi = vec_madd(a.vec1(), b.vec1(), c.vec1());
  return Vectorized<double>{lo, hi};
}

// Multiply-add for Vectorized<int64_t>: evaluates a * b + c lane by lane
// with the element-wise vector operators, one half (vec0 / vec1) at a time.
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE fmadd(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b,
    const Vectorized<int64_t>& c) {
  const auto lo = a.vec0() * b.vec0() + c.vec0();
  const auto hi = a.vec1() * b.vec1() + c.vec1();
  return Vectorized<int64_t>{lo, hi};
}
// Multiply-add for Vectorized<int32_t>: evaluates a * b + c lane by lane
// with the element-wise vector operators, one half (vec0 / vec1) at a time.
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE fmadd(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b,
    const Vectorized<int32_t>& c) {
  const auto lo = a.vec0() * b.vec0() + c.vec0();
  const auto hi = a.vec1() * b.vec1() + c.vec1();
  return Vectorized<int32_t>{lo, hi};
}
// Multiply-add for Vectorized<int16_t>: evaluates a * b + c lane by lane
// with the element-wise vector operators, one half (vec0 / vec1) at a time.
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE fmadd(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b,
    const Vectorized<int16_t>& c) {
  const auto lo = a.vec0() * b.vec0() + c.vec0();
  const auto hi = a.vec1() * b.vec1() + c.vec1();
  return Vectorized<int16_t>{lo, hi};
}

// Instantiate the reinterpret-cast helpers with each of the element types
// below as the macro argument.
// NOTE(review): DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS is defined outside this
// file (presumably in vsx_helpers.h); the name suggests it emits casts from
// the given type to every other supported type -- confirm against the macro.
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t)

// Converts every double lane of src to a 64-bit signed integer lane by
// running vec_signed over each half (vec0 / vec1) of the vector.
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE
convert_to_int_of_same_size<double>(const Vectorized<double>& src) {
  const auto lo = vec_signed(src.vec0());
  const auto hi = vec_signed(src.vec1());
  return Vectorized<int64_t>{lo, hi};
}

// Converts every float lane of src to a 32-bit signed integer lane by
// running vec_signed over each half (vec0 / vec1) of the vector.
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE
convert_to_int_of_same_size<float>(
    const Vectorized<float>& src) {
  const auto lo = vec_signed(src.vec0());
  const auto hi = vec_signed(src.vec1());
  return Vectorized<int32_t>{lo, hi};
}

// Converts n int32_t values read from src into floats written to dst.
// Whole groups of Vectorized<float>::size() elements go through the VSX
// load / vec_float / store path; leftover elements are converted one by one.
template <>
inline void convert(const int32_t* src, float* dst, int64_t n) {
  // int32_t and float have the same size, so one index addresses both arrays.
  const auto step = Vectorized<float>::size();
  int64_t pos = 0;

  // Vector path: each iteration handles two registers, located at
  // offset0 and offset16 from the current position.
  while (pos <= (n - step)) {
    const vint32* in = reinterpret_cast<const vint32*>(src + pos);
    float* out = dst + pos;
    // Both loads are issued before any store, as in the scalar-free path
    // callers already rely on.
    const auto raw_lo = vec_vsx_ld(offset0, in);
    const auto raw_hi = vec_vsx_ld(offset16, in);
    vec_vsx_st(vec_float(raw_lo), offset0, out);
    vec_vsx_st(vec_float(raw_hi), offset16, out);
    pos += step;
  }

  // Scalar tail for the elements that do not fill a whole vector.
  while (pos < n) {
    dst[pos] = static_cast<float>(src[pos]);
    ++pos;
  }
}

// Converts n int64_t values read from src into doubles written to dst.
// Whole groups of Vectorized<double>::size() elements go through the VSX
// load / vec_double / store path; leftover elements are converted one by one.
template <>
inline void convert(const int64_t* src, double* dst, int64_t n) {
  const auto step = Vectorized<double>::size();
  int64_t pos = 0;

  // Vector path: each iteration handles two registers, located at
  // offset0 and offset16 from the current position.
  while (pos <= (n - step)) {
    const vint64* in = reinterpret_cast<const vint64*>(src + pos);
    double* out = dst + pos;
    // Both loads are issued before any store.
    const auto raw_lo = vec_vsx_ld(offset0, in);
    const auto raw_hi = vec_vsx_ld(offset16, in);
    vec_vsx_st(vec_double(raw_lo), offset0, out);
    vec_vsx_st(vec_double(raw_hi), offset16, out);
    pos += step;
  }

  // Scalar tail for the elements that do not fill a whole vector.
  while (pos < n) {
    dst[pos] = static_cast<double>(src[pos]);
    ++pos;
  }
}

// Interleaves the elements of two double vectors.
//   inputs:  a = {a0, a1, a2, a3}
//            b = {b0, b1, b2, b3}
//   returns: first  = {a0, b0, a1, b1}
//            second = {a2, b2, a3, b3}
template <>
std::pair<Vectorized<double>, Vectorized<double>> inline interleave2<double>(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  // Selector 0 pairs the first element of each operand, selector 3 the second.
  const vfloat64 a0b0 = vec_xxpermdi(a.vec0(), b.vec0(), 0);
  const vfloat64 a1b1 = vec_xxpermdi(a.vec0(), b.vec0(), 3);
  const vfloat64 a2b2 = vec_xxpermdi(a.vec1(), b.vec1(), 0);
  const vfloat64 a3b3 = vec_xxpermdi(a.vec1(), b.vec1(), 3);

  return std::make_pair(
      Vectorized<double>{a0b0, a1b1}, Vectorized<double>{a2b2, a3b3});
}

// Undoes interleave2<double>: separates the alternating a/b elements.
//   inputs:  a = {a0, b0, a1, b1}
//            b = {a2, b2, a3, b3}
//   returns: first  = {a0, a1, a2, a3}
//            second = {b0, b1, b2, b3}
template <>
std::pair<Vectorized<double>, Vectorized<double>> inline deinterleave2<double>(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  // Selector 0 gathers the first element of each half (the "a" values).
  const vfloat64 a_lo = vec_xxpermdi(a.vec0(), a.vec1(), 0);
  const vfloat64 a_hi = vec_xxpermdi(b.vec0(), b.vec1(), 0);

  // Selector 3 gathers the second element of each half (the "b" values).
  const vfloat64 b_lo = vec_xxpermdi(a.vec0(), a.vec1(), 3);
  const vfloat64 b_hi = vec_xxpermdi(b.vec0(), b.vec1(), 3);

  return std::make_pair(
      Vectorized<double>{a_lo, a_hi}, Vectorized<double>{b_lo, b_hi});
}

// Interleaves the elements of two float vectors ("|" marks the vec0 / vec1
// boundary inside each Vectorized<float>).
//   inputs:  a = {a0, a1, a2, a3 | a4, a5, a6, a7}
//            b = {b0, b1, b2, b3 | b4, b5, b6, b7}
//   returns: first  = {a0, b0, a1, b1 | a2, b2, a3, b3}
//            second = {a4, b4, a5, b5 | a6, b6, a7, b7}
template <>
std::pair<Vectorized<float>, Vectorized<float>> inline interleave2<float>(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  // First halves of a and b: merge-high gives elements 0-1, merge-low 2-3.
  const vfloat32 first_lo = vec_mergeh(a.vec0(), b.vec0());
  const vfloat32 first_hi = vec_mergel(a.vec0(), b.vec0());

  // Second halves of a and b: same pattern for elements 4-5 and 6-7.
  const vfloat32 second_lo = vec_mergeh(a.vec1(), b.vec1());
  const vfloat32 second_hi = vec_mergel(a.vec1(), b.vec1());

  return std::make_pair(
      Vectorized<float>{first_lo, first_hi},
      Vectorized<float>{second_lo, second_hi});
}

// Undoes interleave2<float>: separates the alternating a/b elements ("|"
// marks the vec0 / vec1 boundary inside each Vectorized<float>).
//   inputs:  a = {a0, b0, a1, b1 | a2, b2, a3, b3}
//            b = {a4, b4, a5, b5 | a6, b6, a7, b7}
//   returns: first  = {a0, a1, a2, a3 | a4, a5, a6, a7}
//            second = {b0, b1, b2, b3 | b4, b5, b6, b7}
// Each input is untangled with two rounds of merge-high / merge-low; the same
// shuffle could also be expressed with vec_perm.
template <>
std::pair<Vectorized<float>, Vectorized<float>> inline deinterleave2<float>(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  // Round 1 on a: {a0, a2, b0, b2} and {a1, a3, b1, b3}.
  const vfloat32 a_even = vec_mergeh(a.vec0(), a.vec1());
  const vfloat32 a_odd = vec_mergel(a.vec0(), a.vec1());

  // Round 2 on a: {a0, a1, a2, a3} and {b0, b1, b2, b3}.
  const vfloat32 a_first = vec_mergeh(a_even, a_odd);
  const vfloat32 b_first = vec_mergel(a_even, a_odd);

  // Round 1 on b: {a4, a6, b4, b6} and {a5, a7, b5, b7}.
  const vfloat32 b_even = vec_mergeh(b.vec0(), b.vec1());
  const vfloat32 b_odd = vec_mergel(b.vec0(), b.vec1());

  // Round 2 on b: {a4, a5, a6, a7} and {b4, b5, b6, b7}.
  const vfloat32 a_second = vec_mergeh(b_even, b_odd);
  const vfloat32 b_second = vec_mergel(b_even, b_odd);

  return std::make_pair(
      Vectorized<float>{a_first, a_second},
      Vectorized<float>{b_first, b_second});
}

} // namespace
} // namespace vec
} // namespace at