[flang] Make SQRT folding exact

Replace the latter half of the SQRT() folding algorithm with code that
calculates an exact root with extra rounding bits, and then lets the
usual normalization and rounding code do the right thing.  Extend
tests to catch regressions.

Differential Revision: https://reviews.llvm.org/D128395
This commit is contained in:
Peter Klausler 2022-06-20 17:22:33 -07:00
parent dfaa3880e1
commit 1ef5e6de76
3 changed files with 61 additions and 45 deletions

View File

@ -274,6 +274,7 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
// SQRT(-0) == -0 in IEEE-754.
result.value = NegativeZero();
} else {
result.flags.set(RealFlag::InvalidArgument);
result.value = NotANumber();
}
} else if (IsInfinite()) {
@ -297,53 +298,31 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
result.value.GetFraction());
return result;
}
// Compute the square root of the reduced value with the slow but
// reliable bit-at-a-time method. Start with a clear significand and
// half of the unbiased exponent, and then try to set significand bits
// in descending order of magnitude without exceeding the exact result.
expo = expo / 2 + exponentBias;
result.value.Normalize(false, expo, Fraction::MASKL(1));
Real initialSq{result.value.Multiply(result.value).value};
if (Compare(initialSq) == Relation::Less) {
// Initial estimate is too large; this can happen for values just
// under 1.0.
--expo;
result.value.Normalize(false, expo, Fraction::MASKL(1));
}
for (int bit{significandBits - 1}; bit >= 0; --bit) {
Word word{result.value.word_};
result.value.word_ = word.IBSET(bit);
auto squared{result.value.Multiply(result.value, rounding)};
if (squared.flags.test(RealFlag::Overflow) ||
squared.flags.test(RealFlag::Underflow) ||
Compare(squared.value) == Relation::Less) {
result.value.word_ = word;
// (-1) <= expo <= 1; use it as a shift to set the desired square.
using Extended = typename value::Integer<(binaryPrecision + 2)>;
Extended goal{
Extended::ConvertUnsigned(GetFraction()).value.SHIFTL(expo + 1)};
// Calculate the exact square root by maximizing a value whose square
// does not exceed the goal. Use two extra bits of precision for
// rounding.
bool sticky{true};
Extended extFrac{};
for (int bit{Extended::bits - 1}; bit >= 0; --bit) {
Extended next{extFrac.IBSET(bit)};
auto squared{next.MultiplyUnsigned(next)};
auto cmp{squared.upper.CompareUnsigned(goal)};
if (cmp == Ordering::Less) {
extFrac = next;
} else if (cmp == Ordering::Equal && squared.lower.IsZero()) {
extFrac = next;
sticky = false;
break; // exact result
}
}
// The computed square root has a square that's not greater than the
// original argument. Check this square against the square of the next
// larger Real and return that one if its square is closer in magnitude to
// the original argument.
Real resultSq{result.value.Multiply(result.value).value};
Real diff{Subtract(resultSq).value.ABS()};
if (diff.IsZero()) {
return result; // exact
}
Real ulp;
ulp.Normalize(false, expo, Fraction::MASKR(1));
Real nextAfter{result.value.Add(ulp).value};
auto nextAfterSq{nextAfter.Multiply(nextAfter)};
if (!nextAfterSq.flags.test(RealFlag::Overflow) &&
!nextAfterSq.flags.test(RealFlag::Underflow)) {
Real nextAfterDiff{Subtract(nextAfterSq.value).value.ABS()};
if (nextAfterDiff.Compare(diff) == Relation::Less) {
result.value = nextAfter;
if (nextAfterDiff.IsZero()) {
return result; // exact
}
}
}
result.flags.set(RealFlag::Inexact);
RoundingBits roundingBits{extFrac.BTEST(1), extFrac.BTEST(0), sticky};
NormalizeAndRound(result, false, exponentBias,
Fraction::ConvertUnsigned(extFrac.SHIFTR(2)).value, rounding,
roundingBits);
}
return result;
}

View File

@ -49,4 +49,25 @@ module m
logical, parameter :: test_sqrt_zero_4 = sqrt_zero_4 == 0.0
real(8), parameter :: sqrt_zero_8 = sqrt(0.0)
logical, parameter :: test_sqrt_zero_8 = sqrt_zero_8 == 0.0
! Some common values to get right
real(8), parameter :: sqrt_1_8 = sqrt(1.d0)
logical, parameter :: test_sqrt_1_8 = sqrt_1_8 == 1.d0
real(8), parameter :: sqrt_2_8 = sqrt(2.d0)
logical, parameter :: test_sqrt_2_8 = sqrt_2_8 == 1.4142135623730951454746218587388284504413604736328125d0
real(8), parameter :: sqrt_3_8 = sqrt(3.d0)
logical, parameter :: test_sqrt_3_8 = sqrt_3_8 == 1.732050807568877193176604123436845839023590087890625d0
real(8), parameter :: sqrt_4_8 = sqrt(4.d0)
logical, parameter :: test_sqrt_4_8 = sqrt_4_8 == 2.d0
real(8), parameter :: sqrt_5_8 = sqrt(5.d0)
logical, parameter :: test_sqrt_5_8 = sqrt_5_8 == 2.236067977499789805051477742381393909454345703125d0
real(8), parameter :: sqrt_6_8 = sqrt(6.d0)
logical, parameter :: test_sqrt_6_8 = sqrt_6_8 == 2.44948974278317788133563226438127458095550537109375d0
real(8), parameter :: sqrt_7_8 = sqrt(7.d0)
logical, parameter :: test_sqrt_7_8 = sqrt_7_8 == 2.64575131106459071617109657381661236286163330078125d0
real(8), parameter :: sqrt_8_8 = sqrt(8.d0)
logical, parameter :: test_sqrt_8_8 = sqrt_8_8 == 2.828427124746190290949243717477656900882720947265625d0
real(8), parameter :: sqrt_9_8 = sqrt(9.d0)
logical, parameter :: test_sqrt_9_8 = sqrt_9_8 == 3.d0
real(8), parameter :: sqrt_10_8 = sqrt(10.d0)
logical, parameter :: test_sqrt_10_8 = sqrt_10_8 == 3.162277660168379522787063251598738133907318115234375d0
end module

View File

@ -392,6 +392,22 @@ void subsetTests(int pass, Rounding rounding, std::uint32_t opds) {
("%d AINT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
}
{
ValueWithRealFlags<REAL> root{x.SQRT(rounding)};
#ifndef __clang__ // broken and also slow
fpenv.ClearFlags();
#endif
FLT fcheck{std::sqrt(fj)};
auto actualFlags{FlagsToBits(fpenv.CurrentFlags())};
u.f = fcheck;
UINT rcheck{NormalizeNaN(u.ui)};
UINT check = root.value.RawBits().ToUInt64();
MATCH(rcheck, check)
("%d SQRT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
MATCH(actualFlags, FlagsToBits(root.flags))
("%d SQRT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
}
{
MATCH(IsNaN(rj), x.IsNotANumber())
("%d IsNaN(0x%jx)", pass, static_cast<std::intmax_t>(rj));