#pragma once
// Math headers are included at global scope, *before* the namespace opens.
// Including a system header inside at::native would declare its symbols in
// that namespace whenever this file happens to be the first includer, and the
// ::-qualified calls below would then fail to resolve.
#if defined(USE_ROCM)
#include <math.h>
#else
#include <cmath>
#endif

namespace at { namespace native {

#if defined(USE_ROCM)
// ROCm path: explicit per-type specializations that call the global C math
// functions. Remove these once ROCm implements the std:: math functions.

// Primary template: declared only. device_sqrt is usable on ROCm solely for
// the scalar types that have an explicit specialization below (float, double).
template <typename scalar_t>
static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);

// Single-precision square root; returns float.
template <>
__forceinline__ __device__ float device_sqrt(float val) {
  return ::sqrtf(val);
}

// Double-precision square root; returns double.
template <>
__forceinline__ __device__ double device_sqrt(double val) {
  return ::sqrt(val);
}

#else
// CUDA path: one generic template that forwards to the std::sqrt overload set.
//
// The return type is `double` for every scalar_t, unlike the ROCm branch above
// which returns scalar_t. For a float argument the result is therefore widened
// to double on return.
// NOTE(review): the return type is intentionally left unchanged here because
// callers may depend on the promoted result type -- confirm before narrowing
// it to scalar_t.
template <typename scalar_t>
__forceinline__ __device__ double device_sqrt(scalar_t val) {
  return std::sqrt(val);
}

#endif

}} // namespace at::native