[flang][runtime] Enable real/complex kind 10 and 16 variants of dot_product.

HasCppTypeFor<> used to evaluate to false always, so kind 10 and 16
variants of dot_product were not instantiated even though the host
supported 80- and 128-bit real and complex data types.
In addition, HAS_FLOAT128 was not enabling complex kind=16 variant
of dot_product. This is fixed now.

Note that the change for HasCppTypeFor<> may also affect other
functions such as matmul, i.e. kind 10 and 16 variants of them
may become available now (depending on the build host).

Differential Revision: https://reviews.llvm.org/D133051
This commit is contained in:
Slava Zakharin 2022-08-31 13:42:39 -07:00
parent 7ea643c06d
commit f8a9f43ef7
2 changed files with 23 additions and 22 deletions

View File

@ -23,14 +23,14 @@ namespace Fortran::runtime {
using common::TypeCategory;
template <TypeCategory CAT, int KIND> struct CppTypeForHelper {};
template <TypeCategory CAT, int KIND> struct CppTypeForHelper {
using type = void;
};
template <TypeCategory CAT, int KIND>
using CppTypeFor = typename CppTypeForHelper<CAT, KIND>::type;
template <TypeCategory CAT, int KIND, bool SFINAE = false>
constexpr bool HasCppTypeFor{false};
template <TypeCategory CAT, int KIND>
constexpr bool HasCppTypeFor<CAT, KIND, true>{
constexpr bool HasCppTypeFor{
!std::is_void_v<typename CppTypeForHelper<CAT, KIND>::type>};
template <int KIND> struct CppTypeForHelper<TypeCategory::Integer, KIND> {

View File

@ -147,24 +147,24 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
};
extern "C" {
std::int8_t RTNAME(DotProductInteger1)(
CppTypeFor<TypeCategory::Integer, 1> RTNAME(DotProductInteger1)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
}
std::int16_t RTNAME(DotProductInteger2)(
CppTypeFor<TypeCategory::Integer, 2> RTNAME(DotProductInteger2)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
}
std::int32_t RTNAME(DotProductInteger4)(
CppTypeFor<TypeCategory::Integer, 4> RTNAME(DotProductInteger4)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
}
std::int64_t RTNAME(DotProductInteger8)(
CppTypeFor<TypeCategory::Integer, 8> RTNAME(DotProductInteger8)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
}
#ifdef __SIZEOF_INT128__
common::int128_t RTNAME(DotProductInteger16)(
CppTypeFor<TypeCategory::Integer, 16> RTNAME(DotProductInteger16)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
}
@ -172,16 +172,16 @@ common::int128_t RTNAME(DotProductInteger16)(
// TODO: REAL/COMPLEX(2 & 3)
// Intermediate results and operations are at least 64 bits
float RTNAME(DotProductReal4)(
CppTypeFor<TypeCategory::Real, 4> RTNAME(DotProductReal4)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
}
double RTNAME(DotProductReal8)(
CppTypeFor<TypeCategory::Real, 8> RTNAME(DotProductReal8)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
}
#if LDBL_MANT_DIG == 64
long double RTNAME(DotProductReal10)(
CppTypeFor<TypeCategory::Real, 10> RTNAME(DotProductReal10)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
}
@ -193,24 +193,25 @@ CppTypeFor<TypeCategory::Real, 16> RTNAME(DotProductReal16)(
}
#endif
void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
void RTNAME(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
const Descriptor &x, const Descriptor &y, const char *source, int line) {
auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)};
result = std::complex<float>{
static_cast<float>(z.real()), static_cast<float>(z.imag())};
result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
}
void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
void RTNAME(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
const Descriptor &x, const Descriptor &y, const char *source, int line) {
result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
}
#if LDBL_MANT_DIG == 64
void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
const Descriptor &x, const Descriptor &y, const char *source, int line) {
void RTNAME(CppDotProductComplex10)(
CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
const Descriptor &y, const char *source, int line) {
result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
}
#elif LDBL_MANT_DIG == 113
void RTNAME(CppDotProductComplex16)(std::complex<CppFloat128Type> &result,
const Descriptor &x, const Descriptor &y, const char *source, int line) {
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
void RTNAME(CppDotProductComplex16)(
CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
const Descriptor &y, const char *source, int line) {
result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
}
#endif