forked from OSchip/llvm-project
[flang][runtime] Enable real/complex kind 10 and 16 variants of dot_product.
HasCppTypeFor<> used to evaluate to false always, so kind 10 and 16 variants of dot_product were not instantiated even though the host supported 80- and 128-bit real and complex data types. In addition, HAS_FLOAT128 was not enabling complex kind=16 variant of dot_product. This is fixed now. Note that the change for HasCppTypeFor<> may also affect other functions such as matmul, i.e. kind 10 and 16 variants of them may become available now (depending on the build host). Differential Revision: https://reviews.llvm.org/D133051
This commit is contained in:
parent
7ea643c06d
commit
f8a9f43ef7
|
@ -23,14 +23,14 @@ namespace Fortran::runtime {
|
|||
|
||||
using common::TypeCategory;
|
||||
|
||||
template <TypeCategory CAT, int KIND> struct CppTypeForHelper {};
|
||||
template <TypeCategory CAT, int KIND> struct CppTypeForHelper {
|
||||
using type = void;
|
||||
};
|
||||
template <TypeCategory CAT, int KIND>
|
||||
using CppTypeFor = typename CppTypeForHelper<CAT, KIND>::type;
|
||||
|
||||
template <TypeCategory CAT, int KIND, bool SFINAE = false>
|
||||
constexpr bool HasCppTypeFor{false};
|
||||
template <TypeCategory CAT, int KIND>
|
||||
constexpr bool HasCppTypeFor<CAT, KIND, true>{
|
||||
constexpr bool HasCppTypeFor{
|
||||
!std::is_void_v<typename CppTypeForHelper<CAT, KIND>::type>};
|
||||
|
||||
template <int KIND> struct CppTypeForHelper<TypeCategory::Integer, KIND> {
|
||||
|
|
|
@ -147,24 +147,24 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
|
|||
};
|
||||
|
||||
extern "C" {
|
||||
std::int8_t RTNAME(DotProductInteger1)(
|
||||
CppTypeFor<TypeCategory::Integer, 1> RTNAME(DotProductInteger1)(
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
|
||||
}
|
||||
std::int16_t RTNAME(DotProductInteger2)(
|
||||
CppTypeFor<TypeCategory::Integer, 2> RTNAME(DotProductInteger2)(
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
|
||||
}
|
||||
std::int32_t RTNAME(DotProductInteger4)(
|
||||
CppTypeFor<TypeCategory::Integer, 4> RTNAME(DotProductInteger4)(
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
|
||||
}
|
||||
std::int64_t RTNAME(DotProductInteger8)(
|
||||
CppTypeFor<TypeCategory::Integer, 8> RTNAME(DotProductInteger8)(
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
|
||||
}
|
||||
#ifdef __SIZEOF_INT128__
|
||||
common::int128_t RTNAME(DotProductInteger16)(
|
||||
CppTypeFor<TypeCategory::Integer, 16> RTNAME(DotProductInteger16)(
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
|
||||
}
|
||||
|
@ -172,16 +172,16 @@ common::int128_t RTNAME(DotProductInteger16)(
|
|||
|
||||
// TODO: REAL/COMPLEX(2 & 3)
|
||||
// Intermediate results and operations are at least 64 bits
|
||||
float RTNAME(DotProductReal4)(
|
||||
CppTypeFor<TypeCategory::Real, 4> RTNAME(DotProductReal4)(
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
|
||||
}
|
||||
double RTNAME(DotProductReal8)(
|
||||
CppTypeFor<TypeCategory::Real, 8> RTNAME(DotProductReal8)(
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
|
||||
}
|
||||
#if LDBL_MANT_DIG == 64
|
||||
long double RTNAME(DotProductReal10)(
|
||||
CppTypeFor<TypeCategory::Real, 10> RTNAME(DotProductReal10)(
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
|
||||
}
|
||||
|
@ -193,24 +193,25 @@ CppTypeFor<TypeCategory::Real, 16> RTNAME(DotProductReal16)(
|
|||
}
|
||||
#endif
|
||||
|
||||
void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
|
||||
void RTNAME(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)};
|
||||
result = std::complex<float>{
|
||||
static_cast<float>(z.real()), static_cast<float>(z.imag())};
|
||||
result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
|
||||
}
|
||||
void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
|
||||
void RTNAME(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
|
||||
}
|
||||
#if LDBL_MANT_DIG == 64
|
||||
void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
void RTNAME(CppDotProductComplex10)(
|
||||
CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
|
||||
const Descriptor &y, const char *source, int line) {
|
||||
result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
|
||||
}
|
||||
#elif LDBL_MANT_DIG == 113
|
||||
void RTNAME(CppDotProductComplex16)(std::complex<CppFloat128Type> &result,
|
||||
const Descriptor &x, const Descriptor &y, const char *source, int line) {
|
||||
#endif
|
||||
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
|
||||
void RTNAME(CppDotProductComplex16)(
|
||||
CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
|
||||
const Descriptor &y, const char *source, int line) {
|
||||
result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
|
||||
}
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue