From f8a9f43ef7affb7991e60cdd5ce93d2566f5b2e4 Mon Sep 17 00:00:00 2001 From: Slava Zakharin Date: Wed, 31 Aug 2022 13:42:39 -0700 Subject: [PATCH] [flang][runtime] Enable real/complex kind 10 and 16 variants of dot_product. HasCppTypeFor<> used to evaluate to false always, so kind 10 and 16 variants of dot_product were not instantiated even though the host supported 80- and 128-bit real and complex data types. In addition, HAS_FLOAT128 was not enabling complex kind=16 variant of dot_product. This is fixed now. Note that the change for HasCppTypeFor<> may also affect other functions such as matmul, i.e. kind 10 and 16 variants of them may become available now (depending on the build host). Differential Revision: https://reviews.llvm.org/D133051 --- flang/include/flang/Runtime/cpp-type.h | 8 +++--- flang/runtime/dot-product.cpp | 37 +++++++++++++------------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/flang/include/flang/Runtime/cpp-type.h b/flang/include/flang/Runtime/cpp-type.h index aa4b6f360124..00af2c115484 100644 --- a/flang/include/flang/Runtime/cpp-type.h +++ b/flang/include/flang/Runtime/cpp-type.h @@ -23,14 +23,14 @@ namespace Fortran::runtime { using common::TypeCategory; -template struct CppTypeForHelper {}; +template struct CppTypeForHelper { + using type = void; +}; template using CppTypeFor = typename CppTypeForHelper::type; -template -constexpr bool HasCppTypeFor{false}; template -constexpr bool HasCppTypeFor{ +constexpr bool HasCppTypeFor{ !std::is_void_v::type>}; template struct CppTypeForHelper { diff --git a/flang/runtime/dot-product.cpp b/flang/runtime/dot-product.cpp index 2f9debbfccaa..857ed6759817 100644 --- a/flang/runtime/dot-product.cpp +++ b/flang/runtime/dot-product.cpp @@ -147,24 +147,24 @@ template struct DotProduct { }; extern "C" { -std::int8_t RTNAME(DotProductInteger1)( +CppTypeFor RTNAME(DotProductInteger1)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } -std::int16_t RTNAME(DotProductInteger2)( +CppTypeFor RTNAME(DotProductInteger2)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } -std::int32_t RTNAME(DotProductInteger4)( +CppTypeFor RTNAME(DotProductInteger4)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } -std::int64_t RTNAME(DotProductInteger8)( +CppTypeFor RTNAME(DotProductInteger8)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } #ifdef __SIZEOF_INT128__ -common::int128_t RTNAME(DotProductInteger16)( +CppTypeFor RTNAME(DotProductInteger16)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } @@ -172,16 +172,16 @@ common::int128_t RTNAME(DotProductInteger16)( // TODO: REAL/COMPLEX(2 & 3) // Intermediate results and operations are at least 64 bits -float RTNAME(DotProductReal4)( +CppTypeFor RTNAME(DotProductReal4)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } -double RTNAME(DotProductReal8)( +CppTypeFor RTNAME(DotProductReal8)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } #if LDBL_MANT_DIG == 64 -long double RTNAME(DotProductReal10)( +CppTypeFor RTNAME(DotProductReal10)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } @@ -193,24 +193,25 @@ CppTypeFor RTNAME(DotProductReal16)( } #endif -void RTNAME(CppDotProductComplex4)(std::complex &result, +void RTNAME(CppDotProductComplex4)(CppTypeFor &result, const Descriptor &x, const Descriptor &y, const char *source, int line) { - auto z{DotProduct{}(x, y, source, line)}; - result = std::complex{ - static_cast(z.real()), static_cast(z.imag())}; + result = DotProduct{}(x, y, source, line); } -void RTNAME(CppDotProductComplex8)(std::complex &result, +void RTNAME(CppDotProductComplex8)(CppTypeFor &result, const Descriptor &x, const Descriptor &y, const char *source, int line) { result = DotProduct{}(x, y, source, line); } #if LDBL_MANT_DIG == 64 -void RTNAME(CppDotProductComplex10)(std::complex &result, - const Descriptor &x, const Descriptor &y, const char *source, int line) { +void RTNAME(CppDotProductComplex10)( + CppTypeFor &result, const Descriptor &x, + const Descriptor &y, const char *source, int line) { result = DotProduct{}(x, y, source, line); } -#elif LDBL_MANT_DIG == 113 -void RTNAME(CppDotProductComplex16)(std::complex &result, - const Descriptor &x, const Descriptor &y, const char *source, int line) { +#endif +#if LDBL_MANT_DIG == 113 || HAS_FLOAT128 +void RTNAME(CppDotProductComplex16)( + CppTypeFor &result, const Descriptor &x, + const Descriptor &y, const char *source, int line) { result = DotProduct{}(x, y, source, line); } #endif