[libc] Add implementation of fmaf.

Differential Revision: https://reviews.llvm.org/D94018
This commit is contained in:
Tue Ly 2021-01-02 01:36:29 -05:00
parent 5acdae1f9a
commit 4726bec8f2
12 changed files with 363 additions and 17 deletions

View File

@ -65,6 +65,7 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.floor
libc.src.math.floorf
libc.src.math.floorl
libc.src.math.fmaf
libc.src.math.fmax
libc.src.math.fmaxf
libc.src.math.fmaxl

View File

@ -106,6 +106,7 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.floor
libc.src.math.floorf
libc.src.math.floorl
libc.src.math.fmaf
libc.src.math.fmin
libc.src.math.fminf
libc.src.math.fminl

View File

@ -322,6 +322,8 @@ def StdC : StandardSpec<"stdc"> {
FunctionSpec<"fmaxf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<FloatType>]>,
FunctionSpec<"fmaxl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>, ArgSpec<LongDoubleType>]>,
FunctionSpec<"fmaf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<FloatType>, ArgSpec<FloatType>]>,
FunctionSpec<"frexp", RetValSpec<DoubleType>, [ArgSpec<DoubleType>, ArgSpec<IntPtr>]>,
FunctionSpec<"frexpf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<IntPtr>]>,
FunctionSpec<"frexpl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>, ArgSpec<IntPtr>]>,

View File

@ -978,3 +978,14 @@ add_entrypoint_object(
-O2
)
# Entrypoint object for fmaf: built from fmaf.cpp / fmaf.h with -O2 and
# depending on the FPUtil helpers (the implementation uses fputil::FPBits).
add_entrypoint_object(
fmaf
SRCS
fmaf.cpp
HDRS
fmaf.h
DEPENDS
libc.utils.FPUtil.fputil
COMPILE_OPTIONS
-O2
)

64
libc/src/math/fmaf.cpp Normal file
View File

@ -0,0 +1,64 @@
//===-- Implementation of fmaf function -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "src/__support/common.h"
#include "utils/FPUtil/FEnv.h"
#include "utils/FPUtil/FPBits.h"
namespace __llvm_libc {
float LLVM_LIBC_ENTRYPOINT(fmaf)(float x, float y, float z) {
  // A product of two floats always fits in a double, so this step is exact.
  double product = static_cast<double>(x) * static_cast<double>(y);
  double addend = static_cast<double>(z);
  // The addition, however, may round in double precision.
  double rounded_sum = product + addend;

  fputil::FPBits<double> product_bits(product), addend_bits(addend),
      sum_bits(rounded_sum);

  // Infinities, NaNs and zeros need no correction; narrow them directly.
  if (sum_bits.isInfOrNaN() || sum_bits.isZero())
    return static_cast<float>(static_cast<double>(sum_bits));

  // Why a correction is needed at all:
  //   The double-precision sum can itself be rounded (for instance when
  //   addend.exponent > product.exponent + 5, or
  //   product.exponent > addend.exponent + 40). Narrowing that already
  //   rounded value to float then rounds a second time, and the two roundings
  //   together can differ from a single correctly rounded result.
  //
  //   Example: x = y = 1 + 2^(-12), z = 2^(-53).
  //     exact x*y + z          = 1 + 2^(-11) + 2^(-24) + 2^(-53)
  //     correct fmaf(x, y, z)  = 1 + 2^(-11) + 2^(-23)
  //     double(x*y + z)        = 1 + 2^(-11) + 2^(-24)   (default rounding)
  //     float(double(x*y + z)) = 1 + 2^(-11)
  //
  // The cure: recover the rounding error of the addition with Dekker's 2Sum,
  // i.e. find `error` with  sum - error == product + addend  exactly. This
  // assumes the (default) round-to-nearest, ties-to-even mode. The error is
  // smaller than eps(sum) (error.exponent < sum.exponent - 52), so when it is
  // non-zero it only has to nudge the last bit of the sum so that the sticky
  // bits seen by the double -> float rounding are right (when it matters).
  //
  // The operand with the larger exponent is subtracted first.
  double error_value;
  if (product_bits.exponent >= addend_bits.exponent)
    error_value = (static_cast<double>(sum_bits) - product_bits) - addend_bits;
  else
    error_value = (static_cast<double>(sum_bits) - addend_bits) - product_bits;
  fputil::FPBits<double> error(error_value);

  // Only when the low (52 - 23 - 1 = 28) mantissa bits are all zero can the
  // lost error change the float rounding decision.
  const bool low_bits_clear = (sum_bits.mantissa & 0xfff'ffffULL) == 0;
  if (!error.isZero() && low_bits_clear) {
    if (sum_bits.sign != error.sign) {
      // Opposite signs: the exact value is larger in magnitude than the sum.
      ++sum_bits.mantissa;
    } else if (sum_bits.mantissa) {
      // Same sign: the exact value is smaller in magnitude than the sum.
      --sum_bits.mantissa;
    }
  }

  return static_cast<float>(static_cast<double>(sum_bits));
}
} // namespace __llvm_libc

18
libc/src/math/fmaf.h Normal file
View File

@ -0,0 +1,18 @@
//===-- Implementation header for fmaf --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_LIBC_SRC_MATH_FMAF_H
#define LLVM_LIBC_SRC_MATH_FMAF_H
namespace __llvm_libc {
// Single-precision fused multiply-add: returns x * y + z, with the result
// rounded to float only once.
float fmaf(float x, float y, float z);
} // namespace __llvm_libc
#endif // LLVM_LIBC_SRC_MATH_FMAF_H

View File

@ -1049,3 +1049,16 @@ add_fp_unittest(
libc.src.math.nextafterl
libc.utils.FPUtil.fputil
)
# Unit test for fmaf, registered in the libc_math_unittests suite.
# NEED_MPFR: the test compares against MPFR results (see ASSERT_MPFR_MATCH in
# the test sources) -- presumably this flag gates the test on MPFR being
# available; confirm against the add_fp_unittest definition.
add_fp_unittest(
fmaf_test
NEED_MPFR
SUITE
libc_math_unittests
SRCS
fmaf_test.cpp
DEPENDS
libc.include.math
libc.src.math.fmaf
libc.utils.FPUtil.fputil
)

View File

@ -0,0 +1,94 @@
//===-- Utility class to test different flavors of fma --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
#define LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
#include "utils/FPUtil/FPBits.h"
#include "utils/FPUtil/TestHelpers.h"
#include "utils/MPFRWrapper/MPFRUtils.h"
#include "utils/UnitTest/Test.h"
#include <random>
namespace mpfr = __llvm_libc::testing::mpfr;
// Test fixture shared by the fma flavors. T is the floating point type under
// test; each public method takes the function under test as a plain function
// pointer with signature T(T, T, T).
template <typename T>
class FmaTestTemplate : public __llvm_libc::testing::Test {
private:
  using Func = T (*)(T, T, T);
  using FPBits = __llvm_libc::fputil::FPBits<T>;
  using UIntType = typename FPBits::UIntType;

  // Special values used by testSpecialNumbers.
  const T nan = FPBits::buildNaN(1);
  const T inf = FPBits::inf();
  const T negInf = FPBits::negInf();
  const T zero = FPBits::zero();
  const T negZero = FPBits::negZero();

  // Builds a pseudo-random bit pattern from sizeof(UIntType) / 2 calls to
  // std::rand(), folding the low 16 bits of each call into the accumulator.
  // NOTE(review): the accumulator is shifted by only 2 bits per 16-bit chunk,
  // so the chunks overlap and the upper bits of UIntType are never set --
  // confirm whether a 16-bit shift was intended.
  UIntType getRandomBitPattern() {
    constexpr size_t chunkCount = sizeof(UIntType) / 2;
    UIntType pattern{0};
    size_t chunk = 0;
    while (chunk < chunkCount) {
      pattern = (pattern << 2) + static_cast<uint16_t>(std::rand());
      ++chunk;
    }
    return pattern;
  }

public:
  // Checks zeros, infinities, NaNs and a few hand-picked underflow/overflow
  // cases.
  void testSpecialNumbers(Func func) {
    EXPECT_FP_EQ(func(zero, zero, zero), zero);
    EXPECT_FP_EQ(func(zero, negZero, negZero), negZero);
    EXPECT_FP_EQ(func(inf, inf, zero), inf);
    EXPECT_FP_EQ(func(negInf, inf, negInf), negInf);
    EXPECT_FP_EQ(func(inf, zero, zero), nan);
    EXPECT_FP_EQ(func(inf, negInf, inf), nan);
    EXPECT_FP_EQ(func(nan, zero, inf), nan);
    EXPECT_FP_EQ(func(inf, negInf, nan), nan);

    // Test underflow rounding up.
    EXPECT_FP_EQ(func(T(0.5), FPBits(FPBits::minSubnormal),
                      FPBits(FPBits::minSubnormal)),
                 FPBits(UIntType(2)));

    // Test underflow rounding down.
    FPBits aboveMinNormal(FPBits::minNormal + UIntType(1));
    EXPECT_FP_EQ(func(T(1) / T(FPBits::minNormal << 1), aboveMinNormal,
                      FPBits(FPBits::minNormal)),
                 aboveMinNormal);

    // Test overflow.
    FPBits largest(FPBits::maxNormal);
    EXPECT_FP_EQ(func(T(1.75), largest, -largest), T(0.75) * largest);
  }

  // Sweeps y upwards and z downwards through the subnormal bit patterns, with
  // x taken from getRandomBitPattern(), and requires each result to match MPFR
  // within 0.5 ULP.
  void testSubnormalRange(Func func) {
    constexpr UIntType count = 1000001;
    constexpr UIntType step =
        (FPBits::maxSubnormal - FPBits::minSubnormal) / count;

    UIntType rising = FPBits::minSubnormal;
    UIntType falling = FPBits::maxSubnormal;
    while (rising <= FPBits::maxSubnormal &&
           falling >= FPBits::minSubnormal) {
      T x = FPBits(getRandomBitPattern());
      T y = FPBits(rising);
      T z = FPBits(falling);
      T result = func(x, y, z);
      mpfr::TernaryInput<T> input{x, y, z};
      ASSERT_MPFR_MATCH(mpfr::Operation::Fma, input, result, 0.5);

      rising += step;
      falling -= step;
    }
  }

  // Sweeps x upwards and y downwards through the normal bit patterns, with z
  // taken from getRandomBitPattern(), and requires each result to match MPFR
  // within 0.5 ULP.
  void testNormalRange(Func func) {
    constexpr UIntType count = 1000001;
    constexpr UIntType step = (FPBits::maxNormal - FPBits::minNormal) / count;

    UIntType rising = FPBits::minNormal;
    UIntType falling = FPBits::maxNormal;
    while (rising <= FPBits::maxNormal && falling >= FPBits::minNormal) {
      T x = FPBits(rising);
      T y = FPBits(falling);
      T z = FPBits(getRandomBitPattern());
      T result = func(x, y, z);
      mpfr::TernaryInput<T> input{x, y, z};
      ASSERT_MPFR_MATCH(mpfr::Operation::Fma, input, result, 0.5);

      rising += step;
      falling -= step;
    }
  }
};
#endif // LLVM_LIBC_TEST_SRC_MATH_FMATEST_H

View File

@ -0,0 +1,19 @@
//===-- Unittests for fmaf ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "FmaTest.h"
#include "src/math/fmaf.h"
// Instantiate the shared fma test template for float and run each of its test
// routines against the libc fmaf implementation.
using FmaTest = FmaTestTemplate<float>;
TEST_F(FmaTest, SpecialNumbers) { testSpecialNumbers(&__llvm_libc::fmaf); }
TEST_F(FmaTest, SubnormalRange) { testSubnormalRange(&__llvm_libc::fmaf); }
TEST_F(FmaTest, NormalRange) { testNormalRange(&__llvm_libc::fmaf); }

View File

@ -84,7 +84,10 @@ template <typename T> struct __attribute__((packed)) FPBits {
// We don't want accidental type promotions/conversions so we require exact
// type match.
template <typename XType,
cpp::EnableIfType<cpp::IsSame<T, XType>::Value, int> = 0>
cpp::EnableIfType<cpp::IsSame<T, XType>::Value ||
(cpp::IsIntegral<XType>::Value &&
(sizeof(XType) == sizeof(UIntType))),
int> = 0>
explicit FPBits(XType x) {
*this = *reinterpret_cast<FPBits<T> *>(&x);
}
@ -106,13 +109,6 @@ template <typename T> struct __attribute__((packed)) FPBits {
// the potential software implementations of UIntType will not slow real
// code.
template <typename XType,
cpp::EnableIfType<cpp::IsSame<UIntType, XType>::Value, int> = 0>
explicit FPBits<long double>(XType x) {
// The last 4 bytes of v are ignored in case of i386.
*this = *reinterpret_cast<FPBits<T> *>(&x);
}
UIntType bitsAsUInt() const {
return *reinterpret_cast<const UIntType *>(this);
}

View File

@ -35,48 +35,69 @@ namespace __llvm_libc {
namespace testing {
namespace mpfr {
// Precision<T>::value is the number of significand bits (including the
// leading bit) of the floating point type T. It is used as the MPFR working
// precision when a result has to be rounded exactly once at T's own
// precision.
template <typename T> struct Precision;

template <> struct Precision<float> {
  static constexpr unsigned int value = 24;
};

template <> struct Precision<double> {
  static constexpr unsigned int value = 53;
};

#if defined(__x86_64__) || defined(__i386__)
// x86 long double is the x87 80-bit extended format: 64-bit significand.
template <> struct Precision<long double> {
  static constexpr unsigned int value = 64;
};
#else
// Other targets: assumes long double is IEEE-754 binary128 (113-bit
// significand) -- TODO confirm for targets where long double is double.
template <> struct Precision<long double> {
  static constexpr unsigned int value = 113;
};
#endif
class MPFRNumber {
// A precision value which allows sufficiently large additional
// precision even compared to quad-precision floating point values.
static constexpr unsigned int mpfrPrecision = 128;
unsigned int mpfrPrecision;
mpfr_t value;
public:
MPFRNumber() { mpfr_init2(value, mpfrPrecision); }
MPFRNumber() : mpfrPrecision(128) { mpfr_init2(value, mpfrPrecision); }
// We use explicit EnableIf specializations to disallow implicit
// conversions. Implicit conversions can potentially lead to loss of
// precision.
template <typename XType,
cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
explicit MPFRNumber(XType x) {
explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_flt(value, x, MPFR_RNDN);
}
template <typename XType,
cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
explicit MPFRNumber(XType x) {
explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_d(value, x, MPFR_RNDN);
}
template <typename XType,
cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
explicit MPFRNumber(XType x) {
explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_ld(value, x, MPFR_RNDN);
}
template <typename XType,
cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
explicit MPFRNumber(XType x) {
explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_sj(value, x, MPFR_RNDN);
}
MPFRNumber(const MPFRNumber &other) {
MPFRNumber(const MPFRNumber &other) : mpfrPrecision(other.mpfrPrecision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set(value, other.value, MPFR_RNDN);
}
@ -85,6 +106,7 @@ public:
}
// Copy-assignment: takes over both the precision and the value of rhs.
MPFRNumber &operator=(const MPFRNumber &rhs) {
  // mpfr_set_prec discards the stored value, so self-assignment must not
  // reach it.
  if (this != &rhs) {
    mpfrPrecision = rhs.mpfrPrecision;
    // Resize the underlying mpfr_t to the new precision before copying.
    // Without this, `value` keeps its old precision: mpfr_set would round to
    // that old precision while mpfrPrecision reports a different one.
    mpfr_set_prec(value, mpfrPrecision);
    mpfr_set(value, rhs.value, MPFR_RNDN);
  }
  return *this;
}
@ -193,6 +215,12 @@ public:
return result;
}
// Fused multiply-add: returns (*this) * b + c, rounded to nearest. The
// returned number is copy-constructed from *this and therefore carries this
// number's precision.
MPFRNumber fma(const MPFRNumber &b, const MPFRNumber &c) {
  MPFRNumber fused(*this);
  mpfr_fma(fused.value, value, b.value, c.value, MPFR_RNDN);
  return fused;
}
std::string str() const {
// 200 bytes should be more than sufficient to hold a 100-digit number
// plus additional bytes for the decimal point, '-' sign etc.
@ -328,6 +356,22 @@ binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) {
}
}
// Evaluates the three-input operation `op` on (x, y, z) with MPFR and returns
// the MPFR result. Only Operation::Fma is supported; any other operation is
// treated as unreachable.
template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
ternaryOperationOneOutput(Operation op, InputType x, InputType y, InputType z) {
  // The MPFR operands are created at exactly InputType's precision, so that
  // mpfr_fma rounds once at that precision. Going through a wider
  // intermediate result could make the comparison fail spuriously because of
  // double rounding.
  constexpr unsigned int typePrecision = Precision<InputType>::value;
  MPFRNumber mpfrX(x, typePrecision);
  MPFRNumber mpfrY(y, typePrecision);
  MPFRNumber mpfrZ(z, typePrecision);

  if (op == Operation::Fma)
    return mpfrX.fma(mpfrY, mpfrZ);

  __builtin_unreachable();
}
template <typename T>
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
testutils::StreamWrapper &OS) {
@ -476,6 +520,48 @@ template void explainBinaryOperationOneOutputError<long double>(
Operation, const BinaryInput<long double> &, long double,
testutils::StreamWrapper &);
// Writes a diagnostic for a failed three-input comparison to OS: the inputs
// in decimal and as bit patterns, the libc result, the MPFR result (and its
// value rounded to T), and the ULP error between the two.
//
//   op          the ternary operation that was evaluated
//   input       the three operands (input.x, input.y, input.z)
//   libcResult  the value produced by the libc function under test
//   OS          stream receiving the explanation
template <typename T>
void explainTernaryOperationOneOutputError(Operation op,
                                           const TernaryInput<T> &input,
                                           T libcResult,
                                           testutils::StreamWrapper &OS) {
  // MPFR copies of the inputs at T's precision, used only for decimal output.
  MPFRNumber mpfrX(input.x, Precision<T>::value);
  MPFRNumber mpfrY(input.y, Precision<T>::value);
  MPFRNumber mpfrZ(input.z, Precision<T>::value);

  MPFRNumber mpfrResult =
      ternaryOperationOneOutput(op, input.x, input.y, input.z);
  MPFRNumber mpfrMatchValue(libcResult);

  OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str()
     << " z: " << mpfrZ.str() << '\n';
  __llvm_libc::fputil::testing::describeValue("First input bits: ", input.x,
                                              OS);
  __llvm_libc::fputil::testing::describeValue("Second input bits: ", input.y,
                                              OS);
  __llvm_libc::fputil::testing::describeValue("Third input bits: ", input.z,
                                              OS);

  OS << "Libc result: " << mpfrMatchValue.str() << '\n'
     << "MPFR result: " << mpfrResult.str() << '\n';
  __llvm_libc::fputil::testing::describeValue(
      "Libc floating point result bits: ", libcResult, OS);
  __llvm_libc::fputil::testing::describeValue(
      " MPFR rounded bits: ", mpfrResult.as<T>(), OS);
  OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult)) << '\n';
}
// Explicit instantiations of explainTernaryOperationOneOutputError for float,
// double and long double.
template void explainTernaryOperationOneOutputError<float>(
Operation, const TernaryInput<float> &, float, testutils::StreamWrapper &);
template void explainTernaryOperationOneOutputError<double>(
Operation, const TernaryInput<double> &, double,
testutils::StreamWrapper &);
template void explainTernaryOperationOneOutputError<long double>(
Operation, const TernaryInput<long double> &, long double,
testutils::StreamWrapper &);
template <typename T>
bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult,
double ulpError) {
@ -575,6 +661,27 @@ compareBinaryOperationOneOutput<double>(Operation, const BinaryInput<double> &,
template bool compareBinaryOperationOneOutput<long double>(
Operation, const BinaryInput<long double> &, long double, double);
// Decides whether libcResult is an acceptable result of the three-input
// operation `op` on `input`. The ULP distance between the MPFR result and
// libcResult must be below ulpError; a distance exactly equal to ulpError is
// also accepted, except that a distance of exactly 0.5 ULP additionally
// requires the lowest bit of libcResult to be zero.
template <typename T>
bool compareTernaryOperationOneOutput(Operation op,
                                      const TernaryInput<T> &input,
                                      T libcResult, double ulpError) {
  MPFRNumber expected =
      ternaryOperationOneOutput(op, input.x, input.y, input.z);
  const double distance = expected.ulp(libcResult);
  const bool lowestBitClear = (FPBits<T>(libcResult).bitsAsUInt() & 1) == 0;

  // Strictly inside the tolerance: always a match.
  if (distance < ulpError)
    return true;
  // Outside the tolerance (or not comparable): never a match.
  if (distance != ulpError)
    return false;
  // Exactly on the tolerance boundary.
  if (distance != 0.5)
    return true;
  return lowestBitClear;
}
// Explicit instantiations of compareTernaryOperationOneOutput for float,
// double and long double.
template bool
compareTernaryOperationOneOutput<float>(Operation, const TernaryInput<float> &,
float, double);
template bool compareTernaryOperationOneOutput<double>(
Operation, const TernaryInput<double> &, double, double);
template bool compareTernaryOperationOneOutput<long double>(
Operation, const TernaryInput<long double> &, long double, double);
static mpfr_rnd_t getMPFRRoundingMode(RoundingMode mode) {
switch (mode) {
case RoundingMode::Upward:

View File

@ -57,8 +57,11 @@ enum class Operation : int {
RemQuo, // The first output, the floating point output, is the remainder.
EndBinaryOperationsTwoOutputs,
// Operations which take three floating point numbers of the same type as
// input and produce a single floating point number of the same type as
// output.
BeginTernaryOperationsSingleOuput,
// TODO: Add operations like fma.
Fma,
EndTernaryOperationsSingleOutput,
};
@ -113,6 +116,11 @@ template <typename T>
bool compareBinaryOperationOneOutput(Operation op, const BinaryInput<T> &input,
T libcOutput, double t);
// Compares libcOutput against the MPFR result of the three-input operation
// `op` applied to `input`. `t` is the tolerance in ULPs. Returns true when the
// libc output is acceptable.
template <typename T>
bool compareTernaryOperationOneOutput(Operation op,
const TernaryInput<T> &input,
T libcOutput, double t);
template <typename T>
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
testutils::StreamWrapper &OS);
@ -132,6 +140,12 @@ void explainBinaryOperationOneOutputError(Operation op,
T matchValue,
testutils::StreamWrapper &OS);
// Writes to OS a description of a mismatch between matchValue (the libc
// result) and the MPFR result of the three-input operation `op` applied to
// `input`.
template <typename T>
void explainTernaryOperationOneOutputError(Operation op,
const TernaryInput<T> &input,
T matchValue,
testutils::StreamWrapper &OS);
template <Operation op, typename InputType, typename OutputType>
class MPFRMatcher : public testing::Matcher<OutputType> {
InputType input;
@ -174,7 +188,7 @@ private:
// Matcher entry point for three-input operations: forwards to the MPFR-backed
// comparison with the given ULP tolerance.
template <typename T>
static bool match(const TernaryInput<T> &in, T out, double tolerance) {
// Compare `out` against the MPFR result for this matcher's operation.
return compareTernaryOperationOneOutput(op, in, out, tolerance);
}
template <typename T>
@ -199,6 +213,12 @@ private:
testutils::StreamWrapper &OS) {
explainBinaryOperationOneOutputError(op, in, out, OS);
}
// Error reporter for three-input operations: forwards to
// explainTernaryOperationOneOutputError, which writes the explanation to OS.
template <typename T>
static void explainError(const TernaryInput<T> &in, T out,
testutils::StreamWrapper &OS) {
explainTernaryOperationOneOutputError(op, in, out, OS);
}
};
} // namespace internal