forked from OSchip/llvm-project
[flang] evaluate: Fold SQRT, HYPOT, & CABS
Implement IEEE Real::SQRT() operation, then use it to also implement Real::HYPOT(), which can then be used directly to implement Complex::ABS(). Differential Revision: https://reviews.llvm.org/D109250
This commit is contained in:
parent
ea04bf302c
commit
c9e9635ffe
|
@ -77,6 +77,11 @@ public:
|
|||
ValueWithRealFlags<Complex> Divide(
|
||||
const Complex &, Rounding rounding = defaultRounding) const;
|
||||
|
||||
// ABS/CABS = HYPOT(re_, imag_) = SQRT(re_**2 + im_**2)
|
||||
ValueWithRealFlags<Part> ABS(Rounding rounding = defaultRounding) const {
|
||||
return re_.HYPOT(im_, rounding);
|
||||
}
|
||||
|
||||
constexpr Complex FlushSubnormalToZero() const {
|
||||
return {re_.FlushSubnormalToZero(), im_.FlushSubnormalToZero()};
|
||||
}
|
||||
|
@ -88,7 +93,6 @@ public:
|
|||
std::string DumpHexadecimal() const;
|
||||
llvm::raw_ostream &AsFortran(llvm::raw_ostream &, int kind) const;
|
||||
|
||||
// TODO: (C)ABS once Real::HYPOT is done
|
||||
// TODO: unit testing
|
||||
|
||||
private:
|
||||
|
|
|
@ -115,8 +115,10 @@ public:
|
|||
ValueWithRealFlags<Real> Divide(
|
||||
const Real &, Rounding rounding = defaultRounding) const;
|
||||
|
||||
// SQRT(x**2 + y**2) but computed so as to avoid spurious overflow
|
||||
// TODO: not yet implemented; needed for CABS
|
||||
ValueWithRealFlags<Real> SQRT(Rounding rounding = defaultRounding) const;
|
||||
|
||||
// HYPOT(x,y)=SQRT(x**2 + y**2) computed so as to avoid spurious
|
||||
// intermediate overflows.
|
||||
ValueWithRealFlags<Real> HYPOT(
|
||||
const Real &, Rounding rounding = defaultRounding) const;
|
||||
|
||||
|
|
|
@ -27,8 +27,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
|
|||
name == "bessel_y1" || name == "cos" || name == "cosh" || name == "erf" ||
|
||||
name == "erfc" || name == "erfc_scaled" || name == "exp" ||
|
||||
name == "gamma" || name == "log" || name == "log10" ||
|
||||
name == "log_gamma" || name == "sin" || name == "sinh" ||
|
||||
name == "sqrt" || name == "tan" || name == "tanh") {
|
||||
name == "log_gamma" || name == "sin" || name == "sinh" || name == "tan" ||
|
||||
name == "tanh") {
|
||||
CHECK(args.size() == 1);
|
||||
if (auto callable{GetHostRuntimeWrapper<T, T>(name)}) {
|
||||
return FoldElementalIntrinsic<T, T>(
|
||||
|
@ -40,8 +40,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
|
|||
} else if (name == "amax0" || name == "amin0" || name == "amin1" ||
|
||||
name == "amax1" || name == "dmin1" || name == "dmax1") {
|
||||
return RewriteSpecificMINorMAX(context, std::move(funcRef));
|
||||
} else if (name == "atan" || name == "atan2" || name == "hypot" ||
|
||||
name == "mod") {
|
||||
} else if (name == "atan" || name == "atan2" || name == "mod") {
|
||||
std::string localName{name == "atan" ? "atan2" : name};
|
||||
CHECK(args.size() == 2);
|
||||
if (auto callable{GetHostRuntimeWrapper<T, T, T>(localName)}) {
|
||||
|
@ -71,13 +70,10 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
|
|||
return FoldElementalIntrinsic<T, T>(
|
||||
context, std::move(funcRef), &Scalar<T>::ABS);
|
||||
} else if (auto *z{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
|
||||
if (auto callable{GetHostRuntimeWrapper<T, ComplexT>("abs")}) {
|
||||
return FoldElementalIntrinsic<T, ComplexT>(
|
||||
context, std::move(funcRef), *callable);
|
||||
} else {
|
||||
context.messages().Say(
|
||||
"abs(complex(kind=%d)) cannot be folded on host"_en_US, KIND);
|
||||
}
|
||||
return FoldElementalIntrinsic<T, ComplexT>(context, std::move(funcRef),
|
||||
ScalarFunc<T, ComplexT>([](const Scalar<ComplexT> &z) -> Scalar<T> {
|
||||
return z.ABS().value;
|
||||
}));
|
||||
} else {
|
||||
common::die(" unexpected argument type inside abs");
|
||||
}
|
||||
|
@ -108,6 +104,13 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
|
|||
return Expr<T>{Scalar<T>::EPSILON()};
|
||||
} else if (name == "huge") {
|
||||
return Expr<T>{Scalar<T>::HUGE()};
|
||||
} else if (name == "hypot") {
|
||||
CHECK(args.size() == 2);
|
||||
return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
|
||||
ScalarFunc<T, T, T>(
|
||||
[](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> {
|
||||
return x.HYPOT(y).value;
|
||||
}));
|
||||
} else if (name == "max") {
|
||||
return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
|
||||
} else if (name == "maxval") {
|
||||
|
@ -130,6 +133,10 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
|
|||
} else if (name == "sign") {
|
||||
return FoldElementalIntrinsic<T, T, T>(
|
||||
context, std::move(funcRef), &Scalar<T>::SIGN);
|
||||
} else if (name == "sqrt") {
|
||||
return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
|
||||
ScalarFunc<T, T>(
|
||||
[](const Scalar<T> &x) -> Scalar<T> { return x.SQRT().value; }));
|
||||
} else if (name == "sum") {
|
||||
return FoldSum<T>(context, std::move(funcRef));
|
||||
} else if (name == "tiny") {
|
||||
|
|
|
@ -222,7 +222,6 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
|
|||
FolderFactory<F, F{std::erfc}>::Create("erfc"),
|
||||
FolderFactory<F, F{std::exp}>::Create("exp"),
|
||||
FolderFactory<F, F{std::tgamma}>::Create("gamma"),
|
||||
FolderFactory<F2, F2{std::hypot}>::Create("hypot"),
|
||||
FolderFactory<F, F{std::log}>::Create("log"),
|
||||
FolderFactory<F, F{std::log10}>::Create("log10"),
|
||||
FolderFactory<F, F{std::lgamma}>::Create("log_gamma"),
|
||||
|
@ -230,7 +229,6 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
|
|||
FolderFactory<F2, F2{std::pow}>::Create("pow"),
|
||||
FolderFactory<F, F{std::sin}>::Create("sin"),
|
||||
FolderFactory<F, F{std::sinh}>::Create("sinh"),
|
||||
FolderFactory<F, F{std::sqrt}>::Create("sqrt"),
|
||||
FolderFactory<F, F{std::tan}>::Create("tan"),
|
||||
FolderFactory<F, F{std::tanh}>::Create("tanh"),
|
||||
};
|
||||
|
|
|
@ -261,6 +261,107 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::Divide(
|
|||
return result;
|
||||
}
|
||||
|
||||
template <typename W, int P>
|
||||
ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
|
||||
ValueWithRealFlags<Real> result;
|
||||
if (IsNotANumber()) {
|
||||
result.value = NotANumber();
|
||||
if (IsSignalingNaN()) {
|
||||
result.flags.set(RealFlag::InvalidArgument);
|
||||
}
|
||||
} else if (IsNegative()) {
|
||||
if (IsZero()) {
|
||||
// SQRT(-0) == -0 in IEEE-754.
|
||||
result.value.word_ = result.value.word_.IBSET(bits - 1);
|
||||
} else {
|
||||
result.value = NotANumber();
|
||||
}
|
||||
} else if (IsInfinite()) {
|
||||
// SQRT(+Inf) == +Inf
|
||||
result.value = Infinity(false);
|
||||
} else {
|
||||
// Slow but reliable bit-at-a-time method. Start with a clear significand
|
||||
// and half the unbiased exponent, and then try to set significand bits
|
||||
// in descending order of magnitude without exceeding the exact result.
|
||||
int expo{UnbiasedExponent()};
|
||||
if (IsSubnormal()) {
|
||||
expo -= GetFraction().LEADZ();
|
||||
}
|
||||
expo = expo / 2 + exponentBias;
|
||||
result.value.Normalize(false, expo, Fraction::MASKL(1));
|
||||
for (int bit{significandBits - 1}; bit >= 0; --bit) {
|
||||
Word word{result.value.word_};
|
||||
result.value.word_ = word.IBSET(bit);
|
||||
auto squared{result.value.Multiply(result.value, rounding)};
|
||||
if (squared.flags.test(RealFlag::Overflow) ||
|
||||
squared.flags.test(RealFlag::Underflow) ||
|
||||
Compare(squared.value) == Relation::Less) {
|
||||
result.value.word_ = word;
|
||||
}
|
||||
}
|
||||
// The computed square root, when squared, has a square that's not greater
|
||||
// than the original argument. Check this square against the square of the
|
||||
// next Real value, and return that one if its square is closer in magnitude
|
||||
// to the original argument.
|
||||
Real resultSq{result.value.Multiply(result.value).value};
|
||||
Real diff{Subtract(resultSq).value.ABS()};
|
||||
if (diff.IsZero()) {
|
||||
return result; // exact
|
||||
}
|
||||
Real ulp;
|
||||
ulp.Normalize(false, expo, Fraction::MASKR(1));
|
||||
Real nextAfter{result.value.Add(ulp).value};
|
||||
auto nextAfterSq{nextAfter.Multiply(nextAfter)};
|
||||
if (!nextAfterSq.flags.test(RealFlag::Overflow) &&
|
||||
!nextAfterSq.flags.test(RealFlag::Underflow)) {
|
||||
Real nextAfterDiff{Subtract(nextAfterSq.value).value.ABS()};
|
||||
if (nextAfterDiff.Compare(diff) == Relation::Less) {
|
||||
result.value = nextAfter;
|
||||
if (nextAfterDiff.IsZero()) {
|
||||
return result; // exact
|
||||
}
|
||||
}
|
||||
}
|
||||
result.flags.set(RealFlag::Inexact);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// HYPOT(x,y) = SQRT(x**2 + y**2) by definition, but those squared intermediate
|
||||
// values are susceptible to over/underflow when computed naively.
|
||||
// Assuming that x>=y, calculate instead:
|
||||
// HYPOT(x,y) = SQRT(x**2 * (1+(y/x)**2))
|
||||
// = ABS(x) * SQRT(1+(y/x)**2)
|
||||
template <typename W, int P>
|
||||
ValueWithRealFlags<Real<W, P>> Real<W, P>::HYPOT(
|
||||
const Real &y, Rounding rounding) const {
|
||||
ValueWithRealFlags<Real> result;
|
||||
if (IsNotANumber() || y.IsNotANumber()) {
|
||||
result.flags.set(RealFlag::InvalidArgument);
|
||||
result.value = NotANumber();
|
||||
} else if (ABS().Compare(y.ABS()) == Relation::Less) {
|
||||
return y.HYPOT(*this);
|
||||
} else if (IsZero()) {
|
||||
return result; // x==y==0
|
||||
} else {
|
||||
auto yOverX{y.Divide(*this, rounding)}; // y/x
|
||||
bool inexact{yOverX.flags.test(RealFlag::Inexact)};
|
||||
auto squared{yOverX.value.Multiply(yOverX.value, rounding)}; // (y/x)**2
|
||||
inexact |= squared.flags.test(RealFlag::Inexact);
|
||||
Real one;
|
||||
one.Normalize(false, exponentBias, Fraction::MASKL(1)); // 1.0
|
||||
auto sum{squared.value.Add(one, rounding)}; // 1.0 + (y/x)**2
|
||||
inexact |= sum.flags.test(RealFlag::Inexact);
|
||||
auto sqrt{sum.value.SQRT()};
|
||||
inexact |= sqrt.flags.test(RealFlag::Inexact);
|
||||
result = sqrt.value.Multiply(ABS(), rounding);
|
||||
if (inexact) {
|
||||
result.flags.set(RealFlag::Inexact);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename W, int P>
|
||||
ValueWithRealFlags<Real<W, P>> Real<W, P>::ToWholeNumber(
|
||||
common::RoundingMode mode) const {
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
! RUN: %S/test_folding.sh %s %t %flang_fc1
|
||||
! REQUIRES: shell
|
||||
! Tests folding of SQRT()
|
||||
module m
|
||||
implicit none
|
||||
! +Inf
|
||||
real(8), parameter :: inf8 = z'7ff0000000000000'
|
||||
logical, parameter :: test_inf8 = sqrt(inf8) == inf8
|
||||
! max finite
|
||||
real(8), parameter :: h8 = huge(1.0_8), h8z = z'7fefffffffffffff'
|
||||
logical, parameter :: test_h8 = h8 == h8z
|
||||
real(8), parameter :: sqrt_h8 = sqrt(h8), sqrt_h8z = z'5fefffffffffffff'
|
||||
logical, parameter :: test_sqrt_h8 = sqrt_h8 == sqrt_h8z
|
||||
real(8), parameter :: sqr_sqrt_h8 = sqrt_h8 * sqrt_h8, sqr_sqrt_h8z = z'7feffffffffffffe'
|
||||
logical, parameter :: test_sqr_sqrt_h8 = sqr_sqrt_h8 == sqr_sqrt_h8z
|
||||
! -0 (sqrt is -0)
|
||||
real(8), parameter :: n08 = z'8000000000000000'
|
||||
real(8), parameter :: sqrt_n08 = sqrt(n08)
|
||||
!WARN: division by zero
|
||||
real(8), parameter :: inf_n08 = 1.0_8 / sqrt_n08, inf_n08z = z'fff0000000000000'
|
||||
logical, parameter :: test_n08 = inf_n08 == inf_n08z
|
||||
! min normal
|
||||
real(8), parameter :: t8 = tiny(1.0_8), t8z = z'0010000000000000'
|
||||
logical, parameter :: test_t8 = t8 == t8z
|
||||
real(8), parameter :: sqrt_t8 = sqrt(t8), sqrt_t8z = z'2000000000000000'
|
||||
logical, parameter :: test_sqrt_t8 = sqrt_t8 == sqrt_t8z
|
||||
real(8), parameter :: sqr_sqrt_t8 = sqrt_t8 * sqrt_t8
|
||||
logical, parameter :: test_sqr_sqrt_t8 = sqr_sqrt_t8 == t8
|
||||
! max subnormal
|
||||
real(8), parameter :: maxs8 = z'000fffffffffffff'
|
||||
real(8), parameter :: sqrt_maxs8 = sqrt(maxs8), sqrt_maxs8z = z'2000000000000000'
|
||||
logical, parameter :: test_sqrt_maxs8 = sqrt_maxs8 == sqrt_maxs8z
|
||||
! min subnormal
|
||||
real(8), parameter :: mins8 = z'1'
|
||||
real(8), parameter :: sqrt_mins8 = sqrt(mins8), sqrt_mins8z = z'1e60000000000000'
|
||||
logical, parameter :: test_sqrt_mins8 = sqrt_mins8 == sqrt_mins8z
|
||||
real(8), parameter :: sqr_sqrt_mins8 = sqrt_mins8 * sqrt_mins8
|
||||
logical, parameter :: test_sqr_sqrt_mins8 = sqr_sqrt_mins8 == mins8
|
||||
end module
|
||||
|
Loading…
Reference in New Issue