#pragma once
#include <complex>
#include <iostream>
#include <c10/macros/Macros.h>
#if defined(__CUDACC__) || defined(__HIPCC__)
#include <thrust/complex.h>
#endif
namespace c10 {
// c10::complex is an implementation of complex numbers that aims
// to work on all devices supported by PyTorch
//
// Most of the APIs duplicate those of std::complex
// Reference: https://en.cppreference.com/w/cpp/numeric/complex
//
// [NOTE: Complex Operator Unification]
// Operators currently use a mix of std::complex, thrust::complex, and c10::complex internally.
// The end state is that all operators will use c10::complex internally. Until then, there may
// be some hacks to support all variants.
//
//
// [Note on Constructors]
//
// The APIs of constructors are mostly copied from C++ standard:
// https://en.cppreference.com/w/cpp/numeric/complex/complex
//
// Since C++14, all constructors are constexpr in std::complex
//
// There are three types of constructors:
// - initializing from real and imag:
// `constexpr complex( const T& re = T(), const T& im = T() );`
// - implicitly-declared copy constructor
// - converting constructors
//
// Converting constructors:
// - std::complex defines converting constructor between float/double/long double,
// while we define converting constructor between float/double.
// - For these converting constructors, upcasting is implicit, downcasting is
// explicit.
// - We also define explicit casting from std::complex/thrust::complex
// - Note that the conversion from thrust is not constexpr, because
// thrust does not seem to define the members it relies on as constexpr (unconfirmed)
//
//
// [Operator =]
//
// The APIs of operator = are mostly copied from C++ standard:
// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
//
// Since C++20, all operator= are constexpr. Although we are not building with
// C++20, we also obey this behavior.
//
// There are three types of assign operator:
// - Assign a real value from the same scalar type
// - In std, this is templated as complex& operator=(const T& x)
// with specialization `complex& operator=(T x)` for float/double/long double
// Since we only support float and double, we will use `complex& operator=(T x)`
// - Copy assignment operator and converting assignment operator
// - There is no specialization of the converting assignment operators; which types are
// convertible depends solely on whether the scalar types are convertible
//
// In addition to the standard assignment, we also provide assignment operators with std and thrust
//
//
// [Casting operators]
//
// std::complex does not have casting operators. We define casting operators casting to std::complex and thrust::complex
//
//
// [Operator ""]
//
// std::complex has custom literals `i`, `if` and `il` defined in namespace `std::literals::complex_literals`.
// We define our own custom literals in the namespace `c10::complex_literals`. Our custom literals do not
// follow the same behavior as in std::complex; instead, we define _if and _id to construct float/double
// complex literals.
//
//
// [real() and imag()]
//
// In C++20, there are two overloads of each of these functions: one returns the real/imag part, the other
// sets it. Both are constexpr. We follow this design.
//
//
// [Operator +=,-=,*=,/=]
//
// Since C++20, these operators become constexpr. In our implementation, they are also constexpr.
//
// There are two types of such operators: operating with a real number, or operating with another complex number.
// For the operating with a real number, the generic template form has argument type `const T &`, while the overload
// for float/double/long double has `T`. We will follow the same type as float/double/long double in std.
//
// [Unary operator +-]
//
// Since C++20, they are constexpr. We also make them constexpr
//
// [Binary operators +-*/]
//
// Each operator has three versions (taking + as example):
// - complex + complex
// - complex + real
// - real + complex
//
// [Operator ==, !=]
//
// Each operator has three versions (taking == as example):
// - complex == complex
// - complex == real
// - real == complex
//
// Some of them are removed in C++20, but we decided to keep them
//
// [Operator <<, >>]
//
// These are implemented by casting to std::complex
//
//
//
// TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported, because:
// - lots of members and functions of c10::Half are not constexpr
// - thrust::complex only supports float and double
// Complex number whose real and imaginary parts are both of scalar type T.
// The two parts are stored as adjacent members and the struct is aligned to
// twice the size of T.
template<typename T>
struct alignas(sizeof(T) * 2) complex {
  using value_type = T;

  // Real and imaginary parts; both default to zero.
  T real_ = T(0);
  T imag_ = T(0);

  // Zero value, produced by the default member initializers above.
  constexpr complex() = default;

  // Builds `re + im*i`; the imaginary part defaults to T().
  constexpr complex(const T& re, const T& im = T()): real_(re), imag_(im) {}

  // Explicit conversion from a std::complex of any scalar type U.
  template<typename U>
  explicit constexpr complex(const std::complex<U> &other): complex(other.real(), other.imag()) {}

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Explicit conversion from thrust::complex; only available under CUDA/HIP compilers.
  template<typename U>
  explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other): real_(other.real()), imag_(other.imag()) {}
  // NOTE can not be implemented as follow due to ROCm bug:
  // explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other): complex(other.real(), other.imag()) {}
#endif

  // Use SFINAE to specialize casting constructor for c10::complex<float> and c10::complex<double>
  // double -> float loses precision, so this direction is explicit.
  template<typename U = T>
  explicit constexpr complex(const std::enable_if_t<std::is_same<U, float>::value, complex<double>> &other):
    real_(other.real_), imag_(other.imag_) {}
  // float -> double is allowed implicitly.
  template<typename U = T>
  constexpr complex(const std::enable_if_t<std::is_same<U, double>::value, complex<float>> &other):
    real_(other.real_), imag_(other.imag_) {}

  // Assigns a real value: the imaginary part is reset to zero.
  constexpr complex<T> &operator =(T re) {
    real_ = re;
    imag_ = 0;
    return *this;
  }

  // Adding or subtracting a real number only touches the real part.
  constexpr complex<T> &operator +=(T re) {
    real_ += re;
    return *this;
  }
  constexpr complex<T> &operator -=(T re) {
    real_ -= re;
    return *this;
  }

  // Scaling by a real number applies to both parts.
  constexpr complex<T> &operator *=(T re) {
    real_ *= re;
    imag_ *= re;
    return *this;
  }
  constexpr complex<T> &operator /=(T re) {
    real_ /= re;
    imag_ /= re;
    return *this;
  }

  // Converting assignment from a c10::complex of another scalar type.
  template<typename U>
  constexpr complex<T> &operator =(const complex<U> &rhs) {
    real_ = rhs.real();
    imag_ = rhs.imag();
    return *this;
  }

  // Component-wise addition and subtraction.
  template<typename U>
  constexpr complex<T> &operator +=(const complex<U> &rhs) {
    real_ += rhs.real();
    imag_ += rhs.imag();
    return *this;
  }
  template<typename U>
  constexpr complex<T> &operator -=(const complex<U> &rhs) {
    real_ -= rhs.real();
    imag_ -= rhs.imag();
    return *this;
  }

  template<typename U>
  constexpr complex<T> &operator *=(const complex<U> &rhs) {
    // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
    // Both operands are copied into locals first so that `x *= x` reads the
    // old components for every product.
    T a = real_;
    T b = imag_;
    U c = rhs.real();
    U d = rhs.imag();
    real_ = a * c - b * d;
    imag_ = a * d + b * c;
    return *this;
  }

#ifdef __APPLE__
#define FORCE_INLINE_APPLE __attribute__((always_inline))
#else
#define FORCE_INLINE_APPLE
#endif
  template<typename U>
  constexpr FORCE_INLINE_APPLE complex<T> &operator /=(const complex<U> &rhs) __ubsan_ignore_float_divide_by_zero__ {
    // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
    //
    // Evaluating c^2 + d^2 literally overflows (or underflows to zero) for
    // divisors whose quotient is perfectly representable, which turns the
    // result into NaN. Instead, numerator and denominator are first divided
    // by whichever divisor component has the larger magnitude (Smith's
    // method), so no intermediate is much larger than the operands.
    // A zero divisor still produces NaN in both components.
    T a = real_;
    T b = imag_;
    U c = rhs.real();
    U d = rhs.imag();
    // Hand-written magnitudes: std::abs is not guaranteed to be constexpr.
    const U abs_c = c < U(0) ? -c : c;
    const U abs_d = d < U(0) ? -d : d;
    if (abs_c >= abs_d) {
      // Divide through by c: ratio = d/c has magnitude <= 1.
      const auto ratio = d / c;
      const auto denominator = c + d * ratio;
      real_ = (a + b * ratio) / denominator;
      imag_ = (b - a * ratio) / denominator;
    } else {
      // Divide through by d: ratio = c/d has magnitude < 1.
      const auto ratio = c / d;
      const auto denominator = d + c * ratio;
      real_ = (a * ratio + b) / denominator;
      imag_ = (b * ratio - a) / denominator;
    }
    return *this;
  }
#undef FORCE_INLINE_APPLE

  // Assignment from std::complex of any scalar type.
  template<typename U>
  constexpr complex<T> &operator =(const std::complex<U> &rhs) {
    real_ = rhs.real();
    imag_ = rhs.imag();
    return *this;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Assignment from thrust::complex; only available under CUDA/HIP compilers.
  template<typename U>
  C10_HOST_DEVICE complex<T> &operator =(const thrust::complex<U> &rhs) {
    real_ = rhs.real();
    imag_ = rhs.imag();
    return *this;
  }
#endif

  // Explicit cast to std::complex<U>, going through std::complex<T>.
  template<typename U>
  explicit constexpr operator std::complex<U>() const {
    return std::complex<U>(std::complex<T>(real(), imag()));
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Explicit cast to thrust::complex<U>, going through thrust::complex<T>.
  template<typename U>
  C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
    return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
  }
#endif

  // consistent with NumPy behavior
  // True when either part is non-zero.
  explicit constexpr operator bool() const {
    return real() || imag();
  }

  // Getter / setter pair for the real part.
  C10_HOST_DEVICE constexpr T real() const {
    return real_;
  }
  constexpr void real(T value) {
    real_ = value;
  }

  // Getter / setter pair for the imaginary part. The getter carries the same
  // C10_HOST_DEVICE annotation as real() because the host/device thrust cast
  // operator above calls both.
  C10_HOST_DEVICE constexpr T imag() const {
    return imag_;
  }
  constexpr void imag(T value) {
    imag_ = value;
  }
};
// User-defined literals for purely imaginary constants: `_if` yields a
// complex<float> and `_id` a complex<double>, each with a zero real part.
// Overloads are provided for both floating-point and integer literal tokens.
//
// The operators are spelled `operator""_x` with no space before the suffix;
// the whitespace-separated form is deprecated.
namespace complex_literals {

constexpr complex<float> operator""_if(long double imag) {
  return complex<float>(0.0f, static_cast<float>(imag));
}

constexpr complex<double> operator""_id(long double imag) {
  return complex<double>(0.0, static_cast<double>(imag));
}

constexpr complex<float> operator""_if(unsigned long long imag) {
  return complex<float>(0.0f, static_cast<float>(imag));
}

constexpr complex<double> operator""_id(unsigned long long imag) {
  return complex<double>(0.0, static_cast<double>(imag));
}

} // namespace complex_literals
// Unary plus: yields a copy of the operand.
template<typename T>
constexpr complex<T> operator+(const complex<T>& val) {
  return complex<T>(val);
}
// Unary minus: negates the real and the imaginary part.
template<typename T>
constexpr complex<T> operator-(const complex<T>& val) {
  const T negated_real = -val.real();
  const T negated_imag = -val.imag();
  return complex<T>(negated_real, negated_imag);
}
// complex + complex: component-wise sum.
template<typename T>
constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
  return complex<T>(lhs.real() + rhs.real(), lhs.imag() + rhs.imag());
}
// complex + real: the real number is added to the real part only.
template<typename T>
constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
  return complex<T>(lhs.real() + rhs, lhs.imag());
}
// real + complex: the real number is added to the real part only.
template<typename T>
constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
  const T summed_real = lhs + rhs.real();
  return complex<T>(summed_real, rhs.imag());
}
// complex - complex: component-wise difference.
template<typename T>
constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
  return complex<T>(lhs.real() - rhs.real(), lhs.imag() - rhs.imag());
}
// complex - real: the real number is subtracted from the real part only.
template<typename T>
constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
  return complex<T>(lhs.real() - rhs, lhs.imag());
}
// real - complex: computed as (-rhs) + lhs, so the imaginary part is negated
// and the real part is `-rhs.real() + lhs`.
template<typename T>
constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
  return complex<T>(-rhs.real() + lhs, -rhs.imag());
}
template<typename T>
Loading ...