forked from OSchip/llvm-project
[flang] Make SQRT folding exact
Replace the latter half of the SQRT() folding algorithm with code that calculates an exact root with extra rounding bits, and then lets the usual normalization and rounding code do the right thing. Extend tests to catch regressions.

Differential Revision: https://reviews.llvm.org/D128395
This commit is contained in:
parent
dfaa3880e1
commit
1ef5e6de76
|
@ -274,6 +274,7 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
|
||||||
// SQRT(-0) == -0 in IEEE-754.
|
// SQRT(-0) == -0 in IEEE-754.
|
||||||
result.value = NegativeZero();
|
result.value = NegativeZero();
|
||||||
} else {
|
} else {
|
||||||
|
result.flags.set(RealFlag::InvalidArgument);
|
||||||
result.value = NotANumber();
|
result.value = NotANumber();
|
||||||
}
|
}
|
||||||
} else if (IsInfinite()) {
|
} else if (IsInfinite()) {
|
||||||
|
@ -297,53 +298,31 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
|
||||||
result.value.GetFraction());
|
result.value.GetFraction());
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
// Compute the square root of the reduced value with the slow but
|
// (-1) <= expo <= 1; use it as a shift to set the desired square.
|
||||||
// reliable bit-at-a-time method. Start with a clear significand and
|
using Extended = typename value::Integer<(binaryPrecision + 2)>;
|
||||||
// half of the unbiased exponent, and then try to set significand bits
|
Extended goal{
|
||||||
// in descending order of magnitude without exceeding the exact result.
|
Extended::ConvertUnsigned(GetFraction()).value.SHIFTL(expo + 1)};
|
||||||
expo = expo / 2 + exponentBias;
|
// Calculate the exact square root by maximizing a value whose square
|
||||||
result.value.Normalize(false, expo, Fraction::MASKL(1));
|
// does not exceed the goal. Use two extra bits of precision for
|
||||||
Real initialSq{result.value.Multiply(result.value).value};
|
// rounding.
|
||||||
if (Compare(initialSq) == Relation::Less) {
|
bool sticky{true};
|
||||||
// Initial estimate is too large; this can happen for values just
|
Extended extFrac{};
|
||||||
// under 1.0.
|
for (int bit{Extended::bits - 1}; bit >= 0; --bit) {
|
||||||
--expo;
|
Extended next{extFrac.IBSET(bit)};
|
||||||
result.value.Normalize(false, expo, Fraction::MASKL(1));
|
auto squared{next.MultiplyUnsigned(next)};
|
||||||
}
|
auto cmp{squared.upper.CompareUnsigned(goal)};
|
||||||
for (int bit{significandBits - 1}; bit >= 0; --bit) {
|
if (cmp == Ordering::Less) {
|
||||||
Word word{result.value.word_};
|
extFrac = next;
|
||||||
result.value.word_ = word.IBSET(bit);
|
} else if (cmp == Ordering::Equal && squared.lower.IsZero()) {
|
||||||
auto squared{result.value.Multiply(result.value, rounding)};
|
extFrac = next;
|
||||||
if (squared.flags.test(RealFlag::Overflow) ||
|
sticky = false;
|
||||||
squared.flags.test(RealFlag::Underflow) ||
|
break; // exact result
|
||||||
Compare(squared.value) == Relation::Less) {
|
|
||||||
result.value.word_ = word;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// The computed square root has a square that's not greater than the
|
RoundingBits roundingBits{extFrac.BTEST(1), extFrac.BTEST(0), sticky};
|
||||||
// original argument. Check this square against the square of the next
|
NormalizeAndRound(result, false, exponentBias,
|
||||||
// larger Real and return that one if its square is closer in magnitude to
|
Fraction::ConvertUnsigned(extFrac.SHIFTR(2)).value, rounding,
|
||||||
// the original argument.
|
roundingBits);
|
||||||
Real resultSq{result.value.Multiply(result.value).value};
|
|
||||||
Real diff{Subtract(resultSq).value.ABS()};
|
|
||||||
if (diff.IsZero()) {
|
|
||||||
return result; // exact
|
|
||||||
}
|
|
||||||
Real ulp;
|
|
||||||
ulp.Normalize(false, expo, Fraction::MASKR(1));
|
|
||||||
Real nextAfter{result.value.Add(ulp).value};
|
|
||||||
auto nextAfterSq{nextAfter.Multiply(nextAfter)};
|
|
||||||
if (!nextAfterSq.flags.test(RealFlag::Overflow) &&
|
|
||||||
!nextAfterSq.flags.test(RealFlag::Underflow)) {
|
|
||||||
Real nextAfterDiff{Subtract(nextAfterSq.value).value.ABS()};
|
|
||||||
if (nextAfterDiff.Compare(diff) == Relation::Less) {
|
|
||||||
result.value = nextAfter;
|
|
||||||
if (nextAfterDiff.IsZero()) {
|
|
||||||
return result; // exact
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result.flags.set(RealFlag::Inexact);
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,4 +49,25 @@ module m
|
||||||
logical, parameter :: test_sqrt_zero_4 = sqrt_zero_4 == 0.0
|
logical, parameter :: test_sqrt_zero_4 = sqrt_zero_4 == 0.0
|
||||||
real(8), parameter :: sqrt_zero_8 = sqrt(0.0)
|
real(8), parameter :: sqrt_zero_8 = sqrt(0.0)
|
||||||
logical, parameter :: test_sqrt_zero_8 = sqrt_zero_8 == 0.0
|
logical, parameter :: test_sqrt_zero_8 = sqrt_zero_8 == 0.0
|
||||||
|
! Some common values to get right
|
||||||
|
real(8), parameter :: sqrt_1_8 = sqrt(1.d0)
|
||||||
|
logical, parameter :: test_sqrt_1_8 = sqrt_1_8 == 1.d0
|
||||||
|
real(8), parameter :: sqrt_2_8 = sqrt(2.d0)
|
||||||
|
logical, parameter :: test_sqrt_2_8 = sqrt_2_8 == 1.4142135623730951454746218587388284504413604736328125d0
|
||||||
|
real(8), parameter :: sqrt_3_8 = sqrt(3.d0)
|
||||||
|
logical, parameter :: test_sqrt_3_8 = sqrt_3_8 == 1.732050807568877193176604123436845839023590087890625d0
|
||||||
|
real(8), parameter :: sqrt_4_8 = sqrt(4.d0)
|
||||||
|
logical, parameter :: test_sqrt_4_8 = sqrt_4_8 == 2.d0
|
||||||
|
real(8), parameter :: sqrt_5_8 = sqrt(5.d0)
|
||||||
|
logical, parameter :: test_sqrt_5_8 = sqrt_5_8 == 2.236067977499789805051477742381393909454345703125d0
|
||||||
|
real(8), parameter :: sqrt_6_8 = sqrt(6.d0)
|
||||||
|
logical, parameter :: test_sqrt_6_8 = sqrt_6_8 == 2.44948974278317788133563226438127458095550537109375d0
|
||||||
|
real(8), parameter :: sqrt_7_8 = sqrt(7.d0)
|
||||||
|
logical, parameter :: test_sqrt_7_8 = sqrt_7_8 == 2.64575131106459071617109657381661236286163330078125d0
|
||||||
|
real(8), parameter :: sqrt_8_8 = sqrt(8.d0)
|
||||||
|
logical, parameter :: test_sqrt_8_8 = sqrt_8_8 == 2.828427124746190290949243717477656900882720947265625d0
|
||||||
|
real(8), parameter :: sqrt_9_8 = sqrt(9.d0)
|
||||||
|
logical, parameter :: test_sqrt_9_8 = sqrt_9_8 == 3.d0
|
||||||
|
real(8), parameter :: sqrt_10_8 = sqrt(10.d0)
|
||||||
|
logical, parameter :: test_sqrt_10_8 = sqrt_10_8 == 3.162277660168379522787063251598738133907318115234375d0
|
||||||
end module
|
end module
|
||||||
|
|
|
@ -392,6 +392,22 @@ void subsetTests(int pass, Rounding rounding, std::uint32_t opds) {
|
||||||
("%d AINT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
|
("%d AINT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
ValueWithRealFlags<REAL> root{x.SQRT(rounding)};
|
||||||
|
#ifndef __clang__ // broken and also slow
|
||||||
|
fpenv.ClearFlags();
|
||||||
|
#endif
|
||||||
|
FLT fcheck{std::sqrt(fj)};
|
||||||
|
auto actualFlags{FlagsToBits(fpenv.CurrentFlags())};
|
||||||
|
u.f = fcheck;
|
||||||
|
UINT rcheck{NormalizeNaN(u.ui)};
|
||||||
|
UINT check = root.value.RawBits().ToUInt64();
|
||||||
|
MATCH(rcheck, check)
|
||||||
|
("%d SQRT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
|
||||||
|
MATCH(actualFlags, FlagsToBits(root.flags))
|
||||||
|
("%d SQRT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
MATCH(IsNaN(rj), x.IsNotANumber())
|
MATCH(IsNaN(rj), x.IsNotANumber())
|
||||||
("%d IsNaN(0x%jx)", pass, static_cast<std::intmax_t>(rj));
|
("%d IsNaN(0x%jx)", pass, static_cast<std::intmax_t>(rj));
|
||||||
|
|
Loading…
Reference in New Issue