Fix pow and ldexp in HIP header

This commit is contained in:
Yaxun (Sam) Liu 2020-07-21 16:54:41 -04:00
parent f659c44016
commit ce04d4e39c
2 changed files with 20 additions and 0 deletions

View File

@ -78,6 +78,7 @@ __device__ __attribute__((const)) float __ocml_len4_f32(float, float, float,
__device__ __attribute__((pure)) float __ocml_ncdf_f32(float);
__device__ __attribute__((pure)) float __ocml_ncdfinv_f32(float);
__device__ __attribute__((pure)) float __ocml_pow_f32(float, float);
__device__ __attribute__((pure)) float __ocml_pown_f32(float, int);
__device__ __attribute__((pure)) float __ocml_rcbrt_f32(float);
__device__ __attribute__((const)) float __ocml_remainder_f32(float, float);
__device__ float __ocml_remquo_f32(float, float,
@ -205,6 +206,7 @@ __device__ __attribute__((const)) double __ocml_len4_f64(double, double, double,
__device__ __attribute__((pure)) double __ocml_ncdf_f64(double);
__device__ __attribute__((pure)) double __ocml_ncdfinv_f64(double);
__device__ __attribute__((pure)) double __ocml_pow_f64(double, double);
__device__ __attribute__((pure)) double __ocml_pown_f64(double, int);
__device__ __attribute__((pure)) double __ocml_rcbrt_f64(double);
__device__ __attribute__((const)) double __ocml_remainder_f64(double, double);
__device__ double __ocml_remquo_f64(double, double,
@ -290,6 +292,7 @@ __device__ __attribute__((const)) _Float16 __ocml_rsqrt_f16(_Float16);
__device__ _Float16 __ocml_sin_f16(_Float16);
__device__ __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16);
__device__ __attribute__((const)) _Float16 __ocml_trunc_f16(_Float16);
__device__ __attribute__((pure)) _Float16 __ocml_pown_f16(_Float16, int);
typedef _Float16 __2f16 __attribute__((ext_vector_type(2)));
typedef short __2i16 __attribute__((ext_vector_type(2)));
@ -320,6 +323,7 @@ __device__ __attribute__((const)) __2f16 __ocml_rsqrt_2f16(__2f16);
__device__ __2f16 __ocml_sin_2f16(__2f16);
__device__ __attribute__((const)) __2f16 __ocml_sqrt_2f16(__2f16);
__device__ __attribute__((const)) __2f16 __ocml_trunc_2f16(__2f16);
__device__ __attribute__((const)) __2f16 __ocml_pown_2f16(__2f16, __2i16);
} // extern "C"

View File

@ -294,6 +294,8 @@ normf(int __dim,
__DEVICE__
inline float powf(float __x, float __y) { return __ocml_pow_f32(__x, __y); }
__DEVICE__
inline float powif(float __x, int __y) { return __ocml_pown_f32(__x, __y); }
__DEVICE__
inline float rcbrtf(float __x) { return __ocml_rcbrt_f32(__x); }
__DEVICE__
inline float remainderf(float __x, float __y) {
@ -759,6 +761,8 @@ inline double normcdfinv(double __x) { return __ocml_ncdfinv_f64(__x); }
__DEVICE__
inline double pow(double __x, double __y) { return __ocml_pow_f64(__x, __y); }
__DEVICE__
inline double powi(double __x, int __y) { return __ocml_pown_f64(__x, __y); }
__DEVICE__
inline double rcbrt(double __x) { return __ocml_rcbrt_f64(__x); }
__DEVICE__
inline double remainder(double __x, double __y) {
@ -1134,6 +1138,7 @@ __DEF_FUN1(double, trunc);
__DEVICE__ \
inline float __func(float __x, int __y) { return __func##f(__x, __y); }
__DEF_FLOAT_FUN2I(scalbn)
__DEF_FLOAT_FUN2I(ldexp)
template <class T> __DEVICE__ inline T min(T __arg1, T __arg2) {
return (__arg1 < __arg2) ? __arg1 : __arg2;
@ -1173,6 +1178,17 @@ __host__ inline static int max(int __arg1, int __arg2) {
return std::max(__arg1, __arg2);
}
__DEVICE__
inline float pow(float __base, int __iexp) { return powif(__base, __iexp); }
__DEVICE__
inline double pow(double __base, int __iexp) { return powi(__base, __iexp); }
__DEVICE__
inline _Float16 pow(_Float16 __base, int __iexp) {
return __ocml_pown_f16(__base, __iexp);
}
#pragma pop_macro("__DEF_FUN1")
#pragma pop_macro("__DEF_FUN2")
#pragma pop_macro("__DEF_FUNI")