[HIP] Math Headers to use type promotion

Similar to libcxx implementation of cmath function
overloads, use type promotion templates to determine
return types of multi-argument math functions.

Fixes: SWDEV-256825

Reviewed By: tra, yaxunl

Differential Revision: https://reviews.llvm.org/D90409
This commit is contained in:
Aaron En Ye Shi 2020-10-28 17:51:55 +00:00
parent cdbf6bfdc7
commit ca5b31502c
1 changed files with 106 additions and 0 deletions

View File

@ -16,6 +16,8 @@
#if defined(__cplusplus)
#include <limits>
#include <type_traits>
#include <utility>
#endif
#include <limits.h>
#include <stdint.h>
@ -205,6 +207,72 @@ template <bool __B, class __T = void> struct __hip_enable_if {};
template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; };
// decltype is only available in C++11 and above.
#if __cplusplus >= 201103L
// __hip_promote
namespace __hip {
template <class _Tp> struct __numeric_type {
static void __test(...);
static _Float16 __test(_Float16);
static float __test(float);
static double __test(char);
static double __test(int);
static double __test(unsigned);
static double __test(long);
static double __test(unsigned long);
static double __test(long long);
static double __test(unsigned long long);
static double __test(double);
typedef decltype(__test(std::declval<_Tp>())) type;
static const bool value = !std::is_same<type, void>::value;
};
template <> struct __numeric_type<void> { static const bool value = true; };
template <class _A1, class _A2 = void, class _A3 = void,
bool = __numeric_type<_A1>::value &&__numeric_type<_A2>::value
&&__numeric_type<_A3>::value>
class __promote_imp {
public:
static const bool value = false;
};
template <class _A1, class _A2, class _A3>
class __promote_imp<_A1, _A2, _A3, true> {
private:
typedef typename __promote_imp<_A1>::type __type1;
typedef typename __promote_imp<_A2>::type __type2;
typedef typename __promote_imp<_A3>::type __type3;
public:
typedef decltype(__type1() + __type2() + __type3()) type;
static const bool value = true;
};
template <class _A1, class _A2> class __promote_imp<_A1, _A2, void, true> {
private:
typedef typename __promote_imp<_A1>::type __type1;
typedef typename __promote_imp<_A2>::type __type2;
public:
typedef decltype(__type1() + __type2()) type;
static const bool value = true;
};
template <class _A1> class __promote_imp<_A1, void, void, true> {
public:
typedef typename __numeric_type<_A1>::type type;
static const bool value = true;
};
template <class _A1, class _A2 = void, class _A3 = void>
class __promote : public __promote_imp<_A1, _A2, _A3> {};
} // namespace __hip
#endif //__cplusplus >= 201103L
// __HIP_OVERLOAD1 is used to resolve function calls with integer argument to
// avoid compilation error due to ambibuity. e.g. floor(5) is resolved with
// floor(double).
@ -219,6 +287,18 @@ template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; };
// __HIP_OVERLOAD2 is used to resolve function calls with mixed float/double
// or integer argument to avoid compilation error due to ambibuity. e.g.
// max(5.0f, 6.0) is resolved with max(double, double).
#if __cplusplus >= 201103L
#define __HIP_OVERLOAD2(__retty, __fn) \
template <typename __T1, typename __T2> \
__DEVICE__ typename __hip_enable_if< \
std::numeric_limits<__T1>::is_specialized && \
std::numeric_limits<__T2>::is_specialized, \
typename __hip::__promote<__T1, __T2>::type>::type \
__fn(__T1 __x, __T2 __y) { \
typedef typename __hip::__promote<__T1, __T2>::type __result_type; \
return __fn((__result_type)__x, (__result_type)__y); \
}
#else
#define __HIP_OVERLOAD2(__retty, __fn) \
template <typename __T1, typename __T2> \
__DEVICE__ \
@ -228,6 +308,7 @@ template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; };
__fn(__T1 __x, __T2 __y) { \
return __fn((double)__x, (double)__y); \
}
#endif
__HIP_OVERLOAD1(double, abs)
__HIP_OVERLOAD1(double, acos)
@ -296,6 +377,18 @@ __HIP_OVERLOAD2(double, max)
__HIP_OVERLOAD2(double, min)
// Additional Overloads that don't quite match HIP_OVERLOAD.
#if __cplusplus >= 201103L
template <typename __T1, typename __T2, typename __T3>
__DEVICE__ typename __hip_enable_if<
std::numeric_limits<__T1>::is_specialized &&
std::numeric_limits<__T2>::is_specialized &&
std::numeric_limits<__T3>::is_specialized,
typename __hip::__promote<__T1, __T2, __T3>::type>::type
fma(__T1 __x, __T2 __y, __T3 __z) {
typedef typename __hip::__promote<__T1, __T2, __T3>::type __result_type;
return ::fma((__result_type)__x, (__result_type)__y, (__result_type)__z);
}
#else
template <typename __T1, typename __T2, typename __T3>
__DEVICE__
typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&
@ -305,6 +398,7 @@ __DEVICE__
fma(__T1 __x, __T2 __y, __T3 __z) {
return ::fma((double)__x, (double)__y, (double)__z);
}
#endif
template <typename __T>
__DEVICE__
@ -327,6 +421,17 @@ __DEVICE__
return ::modf((double)__x, __exp);
}
#if __cplusplus >= 201103L
template <typename __T1, typename __T2>
__DEVICE__
typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&
std::numeric_limits<__T2>::is_specialized,
typename __hip::__promote<__T1, __T2>::type>::type
remquo(__T1 __x, __T2 __y, int *__quo) {
typedef typename __hip::__promote<__T1, __T2>::type __result_type;
return ::remquo((__result_type)__x, (__result_type)__y, __quo);
}
#else
template <typename __T1, typename __T2>
__DEVICE__
typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&
@ -335,6 +440,7 @@ __DEVICE__
remquo(__T1 __x, __T2 __y, int *__quo) {
return ::remquo((double)__x, (double)__y, __quo);
}
#endif
template <typename __T>
__DEVICE__