[libc] Extend MPFRMatcher to handle multiple-input-multiple-output functions.

Tests for frexp[f|l] now use the new capability. Not all input-output
combinations have been addressed by this change. Support for newer combinations
can be added in future as needed.

Reviewed By: lntue

Differential Revision: https://reviews.llvm.org/D86506
This commit is contained in:
Siva Chandra Reddy 2020-08-20 22:36:53 -07:00
parent 75e0b58668
commit 3f4674a557
6 changed files with 515 additions and 98 deletions

View File

@ -333,6 +333,7 @@ add_fp_unittest(
add_fp_unittest(
frexp_test
NEED_MPFR
SUITE
libc_math_unittests
SRCS
@ -345,6 +346,7 @@ add_fp_unittest(
add_fp_unittest(
frexpf_test
NEED_MPFR
SUITE
libc_math_unittests
SRCS
@ -357,6 +359,7 @@ add_fp_unittest(
add_fp_unittest(
frexpl_test
NEED_MPFR
SUITE
libc_math_unittests
SRCS

View File

@ -11,13 +11,18 @@
#include "utils/FPUtil/BasicOperations.h"
#include "utils/FPUtil/BitPatterns.h"
#include "utils/FPUtil/ClassificationFunctions.h"
#include "utils/FPUtil/FPBits.h"
#include "utils/FPUtil/FloatOperations.h"
#include "utils/FPUtil/FloatProperties.h"
#include "utils/MPFRWrapper/MPFRUtils.h"
#include "utils/UnitTest/Test.h"
using FPBits = __llvm_libc::fputil::FPBits<double>;
using __llvm_libc::fputil::valueAsBits;
using __llvm_libc::fputil::valueFromBits;
namespace mpfr = __llvm_libc::testing::mpfr;
using BitPatterns = __llvm_libc::fputil::BitPatterns<double>;
using Properties = __llvm_libc::fputil::FloatProperties<double>;
@ -127,17 +132,19 @@ TEST(FrexpTest, SomeIntegers) {
}
TEST(FrexpTest, InDoubleRange) {
using BitsType = Properties::BitsType;
constexpr BitsType count = 1000000;
constexpr BitsType step = UINT64_MAX / count;
for (BitsType i = 0, v = 0; i <= count; ++i, v += step) {
double x = valueFromBits(v);
using UIntType = FPBits::UIntType;
constexpr UIntType count = 1000001;
constexpr UIntType step = UIntType(-1) / count;
for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
double x = FPBits(v);
if (isnan(x) || isinf(x) || x == 0.0)
continue;
int exponent;
double frac = __llvm_libc::frexp(x, &exponent);
ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0);
ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5);
mpfr::BinaryOutput<double> result;
result.f = __llvm_libc::frexp(x, &result.i);
ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0);
ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5);
ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0);
}
}

View File

@ -11,14 +11,18 @@
#include "utils/FPUtil/BasicOperations.h"
#include "utils/FPUtil/BitPatterns.h"
#include "utils/FPUtil/ClassificationFunctions.h"
#include "utils/FPUtil/FPBits.h"
#include "utils/FPUtil/FloatOperations.h"
#include "utils/FPUtil/FloatProperties.h"
#include "utils/MPFRWrapper/MPFRUtils.h"
#include "utils/UnitTest/Test.h"
using FPBits = __llvm_libc::fputil::FPBits<float>;
using __llvm_libc::fputil::valueAsBits;
using __llvm_libc::fputil::valueFromBits;
namespace mpfr = __llvm_libc::testing::mpfr;
using BitPatterns = __llvm_libc::fputil::BitPatterns<float>;
using Properties = __llvm_libc::fputil::FloatProperties<float>;
@ -109,7 +113,7 @@ TEST(FrexpfTest, PowersOfTwo) {
EXPECT_EQ(exponent, 7);
}
TEST(FrexpTest, SomeIntegers) {
TEST(FrexpfTest, SomeIntegers) {
int exponent;
EXPECT_EQ(valueAsBits(0.75f),
@ -135,17 +139,19 @@ TEST(FrexpTest, SomeIntegers) {
}
TEST(FrexpfTest, InFloatRange) {
using BitsType = Properties::BitsType;
constexpr BitsType count = 1000000;
constexpr BitsType step = UINT32_MAX / count;
for (BitsType i = 0, v = 0; i <= count; ++i, v += step) {
float x = valueFromBits(v);
using UIntType = FPBits::UIntType;
constexpr UIntType count = 1000001;
constexpr UIntType step = UIntType(-1) / count;
for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
float x = FPBits(v);
if (isnan(x) || isinf(x) || x == 0.0)
continue;
int exponent;
float frac = __llvm_libc::frexpf(x, &exponent);
ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0f);
ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5f);
mpfr::BinaryOutput<float> result;
result.f = __llvm_libc::frexpf(x, &result.i);
ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0);
ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5);
ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0);
}
}

View File

@ -10,10 +10,13 @@
#include "src/math/frexpl.h"
#include "utils/FPUtil/BasicOperations.h"
#include "utils/FPUtil/FPBits.h"
#include "utils/MPFRWrapper/MPFRUtils.h"
#include "utils/UnitTest/Test.h"
using FPBits = __llvm_libc::fputil::FPBits<long double>;
namespace mpfr = __llvm_libc::testing::mpfr;
TEST(FrexplTest, SpecialNumbers) {
int exponent;
@ -94,10 +97,11 @@ TEST(FrexplTest, LongDoubleRange) {
if (isnan(x) || isinf(x) || x == 0.0l)
continue;
int exponent;
long double frac = __llvm_libc::frexpl(x, &exponent);
mpfr::BinaryOutput<long double> result;
result.f = __llvm_libc::frexpl(x, &result.i);
ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0l);
ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5l);
ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0);
ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5);
ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0);
}
}

View File

@ -14,6 +14,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <mpfr.h>
#include <stdint.h>
#include <string>
@ -65,50 +66,90 @@ public:
mpfr_set_sj(value, x, MPFR_RNDN);
}
template <typename XType,
cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0>
MPFRNumber(Operation op, XType rawValue) {
mpfr_init2(value, mpfrPrecision);
MPFRNumber mpfrInput(rawValue);
switch (op) {
case Operation::Abs:
mpfr_abs(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Ceil:
mpfr_ceil(value, mpfrInput.value);
break;
case Operation::Cos:
mpfr_cos(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Exp:
mpfr_exp(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Exp2:
mpfr_exp2(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Floor:
mpfr_floor(value, mpfrInput.value);
break;
case Operation::Round:
mpfr_round(value, mpfrInput.value);
break;
case Operation::Sin:
mpfr_sin(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Sqrt:
mpfr_sqrt(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Trunc:
mpfr_trunc(value, mpfrInput.value);
break;
}
}
MPFRNumber(const MPFRNumber &other) {
mpfr_set(value, other.value, MPFR_RNDN);
}
~MPFRNumber() { mpfr_clear(value); }
MPFRNumber &operator=(const MPFRNumber &rhs) {
mpfr_set(value, rhs.value, MPFR_RNDN);
return *this;
}
MPFRNumber abs() const {
MPFRNumber result;
mpfr_abs(result.value, value, MPFR_RNDN);
return result;
}
MPFRNumber ceil() const {
MPFRNumber result;
mpfr_ceil(result.value, value);
return result;
}
MPFRNumber cos() const {
MPFRNumber result;
mpfr_cos(result.value, value, MPFR_RNDN);
return result;
}
MPFRNumber exp() const {
MPFRNumber result;
mpfr_exp(result.value, value, MPFR_RNDN);
return result;
}
MPFRNumber exp2() const {
MPFRNumber result;
mpfr_exp2(result.value, value, MPFR_RNDN);
return result;
}
MPFRNumber floor() const {
MPFRNumber result;
mpfr_floor(result.value, value);
return result;
}
MPFRNumber frexp(int &exp) {
MPFRNumber result;
mpfr_exp_t resultExp;
mpfr_frexp(&resultExp, result.value, value, MPFR_RNDN);
exp = resultExp;
return result;
}
MPFRNumber remquo(const MPFRNumber &divisor, int &quotient) {
MPFRNumber remainder;
long q;
mpfr_remquo(remainder.value, &q, value, divisor.value, MPFR_RNDN);
quotient = q;
return remainder;
}
MPFRNumber round() const {
MPFRNumber result;
mpfr_round(result.value, value);
return result;
}
MPFRNumber sin() const {
MPFRNumber result;
mpfr_sin(result.value, value, MPFR_RNDN);
return result;
}
MPFRNumber sqrt() const {
MPFRNumber result;
mpfr_sqrt(result.value, value, MPFR_RNDN);
return result;
}
MPFRNumber trunc() const {
MPFRNumber result;
mpfr_trunc(result.value, value);
return result;
}
std::string str() const {
// 200 bytes should be more than sufficient to hold a 100-digit number
@ -179,10 +220,65 @@ public:
namespace internal {
template <typename T>
void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
MPFRNumber mpfrResult(operation, input);
template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
unaryOperation(Operation op, InputType input) {
MPFRNumber mpfrInput(input);
switch (op) {
case Operation::Abs:
return mpfrInput.abs();
case Operation::Ceil:
return mpfrInput.ceil();
case Operation::Cos:
return mpfrInput.cos();
case Operation::Exp:
return mpfrInput.exp();
case Operation::Exp2:
return mpfrInput.exp2();
case Operation::Floor:
return mpfrInput.floor();
case Operation::Round:
return mpfrInput.round();
case Operation::Sin:
return mpfrInput.sin();
case Operation::Sqrt:
return mpfrInput.sqrt();
case Operation::Trunc:
return mpfrInput.trunc();
default:
__builtin_unreachable();
}
}
template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
unaryOperationTwoOutputs(Operation op, InputType input, int &output) {
MPFRNumber mpfrInput(input);
switch (op) {
case Operation::Frexp:
return mpfrInput.frexp(output);
default:
__builtin_unreachable();
}
}
template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) {
MPFRNumber inputX(x), inputY(y);
switch (op) {
case Operation::RemQuo:
return inputX.remquo(inputY, output);
default:
__builtin_unreachable();
}
}
template <typename T>
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
testutils::StreamWrapper &OS) {
MPFRNumber mpfrInput(input);
MPFRNumber mpfrResult = unaryOperation(op, input);
MPFRNumber mpfrMatchValue(matchValue);
FPBits<T> inputBits(input);
FPBits<T> matchBits(matchValue);
@ -201,25 +297,174 @@ void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
<< '\n';
}
template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &);
template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &);
template void
MPFRMatcher<long double>::explainError(testutils::StreamWrapper &);
explainUnaryOperationSingleOutputError<float>(Operation op, float, float,
testutils::StreamWrapper &);
template void
explainUnaryOperationSingleOutputError<double>(Operation op, double, double,
testutils::StreamWrapper &);
template void explainUnaryOperationSingleOutputError<long double>(
Operation op, long double, long double, testutils::StreamWrapper &);
template <typename T>
bool compare(Operation op, T input, T libcResult, double ulpError) {
void explainUnaryOperationTwoOutputsError(Operation op, T input,
const BinaryOutput<T> &libcResult,
testutils::StreamWrapper &OS) {
MPFRNumber mpfrInput(input);
FPBits<T> inputBits(input);
int mpfrIntResult;
MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);
if (mpfrIntResult != libcResult.i) {
OS << "MPFR integral result: " << mpfrIntResult << '\n'
<< "Libc integral result: " << libcResult.i << '\n';
} else {
OS << "Integral result from libc matches integral result from MPFR.\n";
}
MPFRNumber mpfrMatchValue(libcResult.f);
OS << "Libc floating point result is not within tolerance value of the MPFR "
<< "result.\n\n";
OS << " Input decimal: " << mpfrInput.str() << "\n\n";
OS << "Libc floating point value: " << mpfrMatchValue.str() << '\n';
__llvm_libc::fputil::testing::describeValue(
" Libc floating point bits: ", libcResult.f, OS);
OS << "\n\n";
OS << " MPFR result: " << mpfrResult.str() << '\n';
__llvm_libc::fputil::testing::describeValue(
" MPFR rounded: ", mpfrResult.as<T>(), OS);
OS << '\n'
<< " ULP error: "
<< std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
}
template void explainUnaryOperationTwoOutputsError<float>(
Operation, float, const BinaryOutput<float> &, testutils::StreamWrapper &);
template void
explainUnaryOperationTwoOutputsError<double>(Operation, double,
const BinaryOutput<double> &,
testutils::StreamWrapper &);
template void explainUnaryOperationTwoOutputsError<long double>(
Operation, long double, const BinaryOutput<long double> &,
testutils::StreamWrapper &);
template <typename T>
void explainBinaryOperationTwoOutputsError(Operation op,
const BinaryInput<T> &input,
const BinaryOutput<T> &libcResult,
testutils::StreamWrapper &OS) {
MPFRNumber mpfrX(input.x);
MPFRNumber mpfrY(input.y);
FPBits<T> xbits(input.x);
FPBits<T> ybits(input.y);
int mpfrIntResult;
MPFRNumber mpfrResult =
binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
MPFRNumber mpfrMatchValue(libcResult.f);
OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n'
<< "MPFR integral result: " << mpfrIntResult << '\n'
<< "Libc integral result: " << libcResult.i << '\n'
<< "Libc floating point result: " << mpfrMatchValue.str() << '\n'
<< " MPFR result: " << mpfrResult.str() << '\n';
__llvm_libc::fputil::testing::describeValue(
"Libc floating point result bits: ", libcResult.f, OS);
__llvm_libc::fputil::testing::describeValue(
" MPFR rounded bits: ", mpfrResult.as<T>(), OS);
OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
}
template void explainBinaryOperationTwoOutputsError<float>(
Operation, const BinaryInput<float> &, const BinaryOutput<float> &,
testutils::StreamWrapper &);
template void explainBinaryOperationTwoOutputsError<double>(
Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
testutils::StreamWrapper &);
template void explainBinaryOperationTwoOutputsError<long double>(
Operation, const BinaryInput<long double> &,
const BinaryOutput<long double> &, testutils::StreamWrapper &);
template <typename T>
bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult,
double ulpError) {
// If the ulp error is exactly 0.5 (i.e a tie), we would check that the result
// is rounded to the nearest even.
MPFRNumber mpfrResult(op, input);
MPFRNumber mpfrResult = unaryOperation(op, input);
double ulp = mpfrResult.ulp(libcResult);
bool bitsAreEven = ((FPBits<T>(libcResult).bitsAsUInt() & 1) == 0);
return (ulp < ulpError) ||
((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
}
template bool compare<float>(Operation, float, float, double);
template bool compare<double>(Operation, double, double, double);
template bool compare<long double>(Operation, long double, long double, double);
template bool compareUnaryOperationSingleOutput<float>(Operation, float, float,
double);
template bool compareUnaryOperationSingleOutput<double>(Operation, double,
double, double);
template bool compareUnaryOperationSingleOutput<long double>(Operation,
long double,
long double,
double);
template <typename T>
bool compareUnaryOperationTwoOutputs(Operation op, T input,
const BinaryOutput<T> &libcResult,
double ulpError) {
int mpfrIntResult;
MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);
double ulp = mpfrResult.ulp(libcResult.f);
if (mpfrIntResult != libcResult.i)
return false;
bool bitsAreEven = ((FPBits<T>(libcResult.f).bitsAsUInt() & 1) == 0);
return (ulp < ulpError) ||
((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
}
template bool
compareUnaryOperationTwoOutputs<float>(Operation, float,
const BinaryOutput<float> &, double);
template bool
compareUnaryOperationTwoOutputs<double>(Operation, double,
const BinaryOutput<double> &, double);
template bool compareUnaryOperationTwoOutputs<long double>(
Operation, long double, const BinaryOutput<long double> &, double);
template <typename T>
bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
const BinaryOutput<T> &libcResult,
double ulpError) {
int mpfrIntResult;
MPFRNumber mpfrResult =
binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
double ulp = mpfrResult.ulp(libcResult.f);
if (mpfrIntResult != libcResult.i) {
if (op == Operation::RemQuo) {
if ((0x7 & mpfrIntResult) != libcResult.i)
return false;
} else {
return false;
}
}
bool bitsAreEven = ((FPBits<T>(libcResult.f).bitsAsUInt() & 1) == 0);
return (ulp < ulpError) ||
((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
}
template bool
compareBinaryOperationTwoOutputs<float>(Operation, const BinaryInput<float> &,
const BinaryOutput<float> &, double);
template bool
compareBinaryOperationTwoOutputs<double>(Operation, const BinaryInput<double> &,
const BinaryOutput<double> &, double);
template bool compareBinaryOperationTwoOutputs<long double>(
Operation, const BinaryInput<long double> &,
const BinaryOutput<long double> &, double);
} // namespace internal

View File

@ -19,6 +19,10 @@ namespace testing {
namespace mpfr {
enum class Operation : int {
// Operations with take a single floating point number as input
// and produce a single floating point number as output. The input
// and output floating point numbers are of the same kind.
BeginUnaryOperationsSingleOutput,
Abs,
Ceil,
Cos,
@ -28,45 +32,193 @@ enum class Operation : int {
Round,
Sin,
Sqrt,
Trunc
Trunc,
EndUnaryOperationsSingleOutput,
// Operations which take a single floating point nubmer as input
// but produce two outputs. The first ouput is a floating point
// number of the same type as the input. The second output is of type
// 'int'.
BeginUnaryOperationsTwoOutputs,
Frexp, // Floating point output, the first output, is the fractional part.
EndUnaryOperationsTwoOutputs,
// Operations wich take two floating point nubmers of the same type as
// input and produce a single floating point number of the same type as
// output.
BeginBinaryOperationsSingleOutput,
// TODO: Add operations like hypot.
EndBinaryOperationsSingleOutput,
// Operations which take two floating point numbers of the same type as
// input and produce two outputs. The first output is a floating nubmer of
// the same type as the inputs. The second output is af type 'int'.
BeginBinaryOperationsTwoOutputs,
RemQuo, // The first output, the floating point output, is the remainder.
EndBinaryOperationsTwoOutputs,
BeginTernaryOperationsSingleOuput,
// TODO: Add operations like fma.
EndTernaryOperationsSingleOutput,
};
template <typename T> struct BinaryInput {
static_assert(
__llvm_libc::cpp::IsFloatingPointType<T>::Value,
"Template parameter of BinaryInput must be a floating point type.");
using Type = T;
T x, y;
};
template <typename T> struct TernaryInput {
static_assert(
__llvm_libc::cpp::IsFloatingPointType<T>::Value,
"Template parameter of TernaryInput must be a floating point type.");
using Type = T;
T x, y, z;
};
template <typename T> struct BinaryOutput {
T f;
int i;
};
namespace internal {
template <typename T1, typename T2>
struct AreMatchingBinaryInputAndBinaryOutput {
static constexpr bool value = false;
};
template <typename T>
bool compare(Operation op, T input, T libcOutput, double t);
struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
static constexpr bool value = cpp::IsFloatingPointType<T>::Value;
};
template <typename T> class MPFRMatcher : public testing::Matcher<T> {
static_assert(__llvm_libc::cpp::IsFloatingPointType<T>::Value,
"MPFRMatcher can only be used with floating point values.");
template <typename T>
bool compareUnaryOperationSingleOutput(Operation op, T input, T libcOutput,
double t);
template <typename T>
bool compareUnaryOperationTwoOutputs(Operation op, T input,
const BinaryOutput<T> &libcOutput,
double t);
template <typename T>
bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
const BinaryOutput<T> &libcOutput,
double t);
Operation operation;
T input;
T matchValue;
template <typename T>
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
testutils::StreamWrapper &OS);
template <typename T>
void explainUnaryOperationTwoOutputsError(Operation op, T input,
const BinaryOutput<T> &matchValue,
testutils::StreamWrapper &OS);
template <typename T>
void explainBinaryOperationTwoOutputsError(Operation op,
const BinaryInput<T> &input,
const BinaryOutput<T> &matchValue,
testutils::StreamWrapper &OS);
template <Operation op, typename InputType, typename OutputType>
class MPFRMatcher : public testing::Matcher<OutputType> {
InputType input;
OutputType matchValue;
double ulpTolerance;
public:
MPFRMatcher(Operation op, T testInput, double ulpTolerance)
: operation(op), input(testInput), ulpTolerance(ulpTolerance) {}
MPFRMatcher(InputType testInput, double ulpTolerance)
: input(testInput), ulpTolerance(ulpTolerance) {}
bool match(T libcResult) {
bool match(OutputType libcResult) {
matchValue = libcResult;
return internal::compare(operation, input, libcResult, ulpTolerance);
return match(input, matchValue, ulpTolerance);
}
void explainError(testutils::StreamWrapper &OS) override;
void explainError(testutils::StreamWrapper &OS) override {
explainError(input, matchValue, OS);
}
private:
template <typename T> static bool match(T in, T out, double tolerance) {
return compareUnaryOperationSingleOutput(op, in, out, tolerance);
}
template <typename T>
static bool match(T in, const BinaryOutput<T> &out, double tolerance) {
return compareUnaryOperationTwoOutputs(op, in, out, tolerance);
}
template <typename T>
static bool match(const BinaryInput<T> &in, T out, double tolerance) {
// TODO: Implement the comparision function and error reporter.
}
template <typename T>
static bool match(BinaryInput<T> in, const BinaryOutput<T> &out,
double tolerance) {
return compareBinaryOperationTwoOutputs(op, in, out, tolerance);
}
template <typename T>
static bool match(const TernaryInput<T> &in, T out, double tolerance) {
// TODO: Implement the comparision function and error reporter.
}
template <typename T>
static void explainError(T in, T out, testutils::StreamWrapper &OS) {
explainUnaryOperationSingleOutputError(op, in, out, OS);
}
template <typename T>
static void explainError(T in, const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explainUnaryOperationTwoOutputsError(op, in, out, OS);
}
template <typename T>
static void explainError(const BinaryInput<T> &in, const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explainBinaryOperationTwoOutputsError(op, in, out, OS);
}
};
} // namespace internal
template <typename T, typename U>
// Return true if the input and ouput types for the operation op are valid
// types.
template <Operation op, typename InputType, typename OutputType>
constexpr bool isValidOperation() {
return (Operation::BeginUnaryOperationsSingleOutput < op &&
op < Operation::EndUnaryOperationsSingleOutput &&
cpp::IsSame<InputType, OutputType>::Value &&
cpp::IsFloatingPointType<InputType>::Value) ||
(Operation::BeginUnaryOperationsTwoOutputs < op &&
op < Operation::EndUnaryOperationsTwoOutputs &&
cpp::IsFloatingPointType<InputType>::Value &&
cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) ||
(Operation::BeginBinaryOperationsSingleOutput < op &&
op < Operation::EndBinaryOperationsSingleOutput &&
cpp::IsFloatingPointType<OutputType>::Value &&
cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) ||
(Operation::BeginBinaryOperationsTwoOutputs < op &&
op < Operation::EndBinaryOperationsTwoOutputs &&
internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
OutputType>::value) ||
(Operation::BeginTernaryOperationsSingleOuput < op &&
op < Operation::EndTernaryOperationsSingleOutput &&
cpp::IsFloatingPointType<OutputType>::Value &&
cpp::IsSame<InputType, TernaryInput<OutputType>>::Value);
}
template <Operation op, typename InputType, typename OutputType>
__attribute__((no_sanitize("address")))
typename cpp::EnableIfType<cpp::IsSameV<U, double>, internal::MPFRMatcher<T>>
getMPFRMatcher(Operation op, T input, U t) {
static_assert(
__llvm_libc::cpp::IsFloatingPointType<T>::Value,
"getMPFRMatcher can only be used to match floating point results.");
return internal::MPFRMatcher<T>(op, input, t);
cpp::EnableIfType<isValidOperation<op, InputType, OutputType>(),
internal::MPFRMatcher<op, InputType, OutputType>>
getMPFRMatcher(InputType input, OutputType outputUnused, double t) {
return internal::MPFRMatcher<op, InputType, OutputType>(input, t);
}
} // namespace mpfr
@ -74,11 +226,11 @@ getMPFRMatcher(Operation op, T input, U t) {
} // namespace __llvm_libc
#define EXPECT_MPFR_MATCH(op, input, matchValue, tolerance) \
EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher( \
op, input, tolerance))
EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \
input, matchValue, tolerance))
#define ASSERT_MPFR_MATCH(op, input, matchValue, tolerance) \
ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher( \
op, input, tolerance))
ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \
input, matchValue, tolerance))
#endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H