forked from OSchip/llvm-project
[HIP] Math Headers to use type promotion
Similar to libcxx implementation of cmath function overloads, use type promotion templates to determine return types of multi-argument math functions. Fixes: SWDEV-256825 Reviewed By: tra, yaxunl Differential Revision: https://reviews.llvm.org/D90409
This commit is contained in:
parent
cdbf6bfdc7
commit
ca5b31502c
|
@ -16,6 +16,8 @@
|
|||
|
||||
#if defined(__cplusplus)
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#endif
|
||||
#include <limits.h>
|
||||
#include <stdint.h>
|
||||
|
@ -205,6 +207,72 @@ template <bool __B, class __T = void> struct __hip_enable_if {};
|
|||
|
||||
// Partial specialization for a true condition: exposes the requested type as
// the nested `type`. The primary template has no members, so naming `::type`
// with a false condition is a substitution failure.
template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; };
|
||||
|
||||
// decltype is only available in C++11 and above.
|
||||
#if __cplusplus >= 201103L
|
||||
// __hip_promote
|
||||
namespace __hip {
|
||||
|
||||
template <class _Tp> struct __numeric_type {
|
||||
static void __test(...);
|
||||
static _Float16 __test(_Float16);
|
||||
static float __test(float);
|
||||
static double __test(char);
|
||||
static double __test(int);
|
||||
static double __test(unsigned);
|
||||
static double __test(long);
|
||||
static double __test(unsigned long);
|
||||
static double __test(long long);
|
||||
static double __test(unsigned long long);
|
||||
static double __test(double);
|
||||
|
||||
typedef decltype(__test(std::declval<_Tp>())) type;
|
||||
static const bool value = !std::is_same<type, void>::value;
|
||||
};
|
||||
|
||||
template <> struct __numeric_type<void> { static const bool value = true; };
|
||||
|
||||
template <class _A1, class _A2 = void, class _A3 = void,
|
||||
bool = __numeric_type<_A1>::value &&__numeric_type<_A2>::value
|
||||
&&__numeric_type<_A3>::value>
|
||||
class __promote_imp {
|
||||
public:
|
||||
static const bool value = false;
|
||||
};
|
||||
|
||||
template <class _A1, class _A2, class _A3>
|
||||
class __promote_imp<_A1, _A2, _A3, true> {
|
||||
private:
|
||||
typedef typename __promote_imp<_A1>::type __type1;
|
||||
typedef typename __promote_imp<_A2>::type __type2;
|
||||
typedef typename __promote_imp<_A3>::type __type3;
|
||||
|
||||
public:
|
||||
typedef decltype(__type1() + __type2() + __type3()) type;
|
||||
static const bool value = true;
|
||||
};
|
||||
|
||||
template <class _A1, class _A2> class __promote_imp<_A1, _A2, void, true> {
|
||||
private:
|
||||
typedef typename __promote_imp<_A1>::type __type1;
|
||||
typedef typename __promote_imp<_A2>::type __type2;
|
||||
|
||||
public:
|
||||
typedef decltype(__type1() + __type2()) type;
|
||||
static const bool value = true;
|
||||
};
|
||||
|
||||
template <class _A1> class __promote_imp<_A1, void, void, true> {
|
||||
public:
|
||||
typedef typename __numeric_type<_A1>::type type;
|
||||
static const bool value = true;
|
||||
};
|
||||
|
||||
template <class _A1, class _A2 = void, class _A3 = void>
|
||||
class __promote : public __promote_imp<_A1, _A2, _A3> {};
|
||||
|
||||
} // namespace __hip
|
||||
#endif //__cplusplus >= 201103L
|
||||
|
||||
// __HIP_OVERLOAD1 is used to resolve function calls with integer argument to
|
||||
// avoid compilation error due to ambiguity. e.g. floor(5) is resolved with
|
||||
// floor(double).
|
||||
|
@ -219,6 +287,18 @@ template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; };
|
|||
// __HIP_OVERLOAD2 is used to resolve function calls with mixed float/double
|
||||
// or integer argument to avoid compilation error due to ambiguity. e.g.
|
||||
// max(5.0f, 6.0) is resolved with max(double, double).
|
||||
#if __cplusplus >= 201103L
|
||||
// C++11 and later: defines a two-argument template overload of __fn that is
// enabled when std::numeric_limits is specialized for both argument types.
// It returns __hip::__promote<__T1, __T2>::type and calls __fn again with
// both arguments converted to that type. __retty is not used by this variant.
#define __HIP_OVERLOAD2(__retty, __fn)                                         \
  template <typename __T1, typename __T2>                                      \
  __DEVICE__ typename __hip_enable_if<                                         \
      std::numeric_limits<__T1>::is_specialized &&                             \
          std::numeric_limits<__T2>::is_specialized,                           \
      typename __hip::__promote<__T1, __T2>::type>::type                       \
  __fn(__T1 __x, __T2 __y) {                                                   \
    using __promoted = typename __hip::__promote<__T1, __T2>::type;            \
    return __fn(static_cast<__promoted>(__x), static_cast<__promoted>(__y));   \
  }
|
||||
#else
|
||||
#define __HIP_OVERLOAD2(__retty, __fn) \
|
||||
template <typename __T1, typename __T2> \
|
||||
__DEVICE__ \
|
||||
|
@ -228,6 +308,7 @@ template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; };
|
|||
__fn(__T1 __x, __T2 __y) { \
|
||||
return __fn((double)__x, (double)__y); \
|
||||
}
|
||||
#endif
|
||||
|
||||
__HIP_OVERLOAD1(double, abs)
|
||||
__HIP_OVERLOAD1(double, acos)
|
||||
|
@ -296,6 +377,18 @@ __HIP_OVERLOAD2(double, max)
|
|||
__HIP_OVERLOAD2(double, min)
|
||||
|
||||
// Additional Overloads that don't quite match HIP_OVERLOAD.
|
||||
#if __cplusplus >= 201103L
|
||||
template <typename __T1, typename __T2, typename __T3>
|
||||
__DEVICE__ typename __hip_enable_if<
|
||||
std::numeric_limits<__T1>::is_specialized &&
|
||||
std::numeric_limits<__T2>::is_specialized &&
|
||||
std::numeric_limits<__T3>::is_specialized,
|
||||
typename __hip::__promote<__T1, __T2, __T3>::type>::type
|
||||
fma(__T1 __x, __T2 __y, __T3 __z) {
|
||||
typedef typename __hip::__promote<__T1, __T2, __T3>::type __result_type;
|
||||
return ::fma((__result_type)__x, (__result_type)__y, (__result_type)__z);
|
||||
}
|
||||
#else
|
||||
template <typename __T1, typename __T2, typename __T3>
|
||||
__DEVICE__
|
||||
typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&
|
||||
|
@ -305,6 +398,7 @@ __DEVICE__
|
|||
fma(__T1 __x, __T2 __y, __T3 __z) {
|
||||
return ::fma((double)__x, (double)__y, (double)__z);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename __T>
|
||||
__DEVICE__
|
||||
|
@ -327,6 +421,17 @@ __DEVICE__
|
|||
return ::modf((double)__x, __exp);
|
||||
}
|
||||
|
||||
#if __cplusplus >= 201103L
|
||||
template <typename __T1, typename __T2>
|
||||
__DEVICE__
|
||||
typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&
|
||||
std::numeric_limits<__T2>::is_specialized,
|
||||
typename __hip::__promote<__T1, __T2>::type>::type
|
||||
remquo(__T1 __x, __T2 __y, int *__quo) {
|
||||
typedef typename __hip::__promote<__T1, __T2>::type __result_type;
|
||||
return ::remquo((__result_type)__x, (__result_type)__y, __quo);
|
||||
}
|
||||
#else
|
||||
template <typename __T1, typename __T2>
|
||||
__DEVICE__
|
||||
typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&
|
||||
|
@ -335,6 +440,7 @@ __DEVICE__
|
|||
remquo(__T1 __x, __T2 __y, int *__quo) {
|
||||
return ::remquo((double)__x, (double)__y, __quo);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename __T>
|
||||
__DEVICE__
|
||||
|
|
Loading…
Reference in New Issue