[flang] Fold NEAREST() and its relatives

Implement constant folding for the intrinsic function NEAREST()
and the related functions IEEE_NEXT_AFTER(), IEEE_NEXT_UP(), and
IEEE_NEXT_DOWN().

Differential Revision: https://reviews.llvm.org/D122510
This commit is contained in:
Peter Klausler 2022-03-24 09:03:07 -07:00
parent fceea4e110
commit e619c07d16
4 changed files with 205 additions and 2 deletions

View File

@ -120,6 +120,9 @@ public:
ValueWithRealFlags<Real> SQRT(Rounding rounding = defaultRounding) const; ValueWithRealFlags<Real> SQRT(Rounding rounding = defaultRounding) const;
// NEAREST(), IEEE_NEXT_AFTER(), IEEE_NEXT_UP(), and IEEE_NEXT_DOWN()
ValueWithRealFlags<Real> NEAREST(bool upward) const;
// HYPOT(x,y)=SQRT(x**2 + y**2) computed so as to avoid spurious // HYPOT(x,y)=SQRT(x**2 + y**2) computed so as to avoid spurious
// intermediate overflows. // intermediate overflows.
ValueWithRealFlags<Real> HYPOT( ValueWithRealFlags<Real> HYPOT(

View File

@ -119,6 +119,31 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
RelationalOperator::GT, T::Scalar::HUGE().Negate()); RelationalOperator::GT, T::Scalar::HUGE().Negate());
} else if (name == "merge") { } else if (name == "merge") {
return FoldMerge<T>(context, std::move(funcRef)); return FoldMerge<T>(context, std::move(funcRef));
} else if (name == "nearest") {
if (const auto *sExpr{UnwrapExpr<Expr<SomeReal>>(args[1])}) {
return std::visit(
[&](const auto &sVal) {
using TS = ResultType<decltype(sVal)>;
return FoldElementalIntrinsic<T, T, TS>(context, std::move(funcRef),
ScalarFunc<T, T, TS>([&](const Scalar<T> &x,
const Scalar<TS> &s) -> Scalar<T> {
if (s.IsZero()) {
context.messages().Say(
"NEAREST: S argument is zero"_warn_en_US);
}
auto result{x.NEAREST(!s.IsNegative())};
if (result.flags.test(RealFlag::Overflow)) {
context.messages().Say(
"NEAREST intrinsic folding overflow"_warn_en_US);
} else if (result.flags.test(RealFlag::InvalidArgument)) {
context.messages().Say(
"NEAREST intrinsic folding: bad argument"_warn_en_US);
}
return result.value;
}));
},
sExpr->u);
}
} else if (name == "min") { } else if (name == "min") {
return FoldMINorMAX(context, std::move(funcRef), Ordering::Less); return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
} else if (name == "minval") { } else if (name == "minval") {
@ -167,10 +192,58 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return FoldSum<T>(context, std::move(funcRef)); return FoldSum<T>(context, std::move(funcRef));
} else if (name == "tiny") { } else if (name == "tiny") {
return Expr<T>{Scalar<T>::TINY()}; return Expr<T>{Scalar<T>::TINY()};
} else if (name == "__builtin_ieee_next_after") {
if (const auto *yExpr{UnwrapExpr<Expr<SomeReal>>(args[1])}) {
return std::visit(
[&](const auto &yVal) {
using TY = ResultType<decltype(yVal)>;
return FoldElementalIntrinsic<T, T, TY>(context, std::move(funcRef),
ScalarFunc<T, T, TY>([&](const Scalar<T> &x,
const Scalar<TY> &y) -> Scalar<T> {
bool upward{true};
switch (x.Compare(Scalar<T>::Convert(y).value)) {
case Relation::Unordered:
context.messages().Say(
"IEEE_NEXT_AFTER intrinsic folding: bad argument"_warn_en_US);
return x;
case Relation::Equal:
return x;
case Relation::Less:
upward = true;
break;
case Relation::Greater:
upward = false;
break;
}
auto result{x.NEAREST(upward)};
if (result.flags.test(RealFlag::Overflow)) {
context.messages().Say(
"IEEE_NEXT_AFTER intrinsic folding overflow"_warn_en_US);
}
return result.value;
}));
},
yExpr->u);
}
} else if (name == "__builtin_ieee_next_up" ||
name == "__builtin_ieee_next_down") {
bool upward{name == "__builtin_ieee_next_up"};
const char *iName{upward ? "IEEE_NEXT_UP" : "IEEE_NEXT_DOWN"};
return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
ScalarFunc<T, T>([&](const Scalar<T> &x) -> Scalar<T> {
auto result{x.NEAREST(upward)};
if (result.flags.test(RealFlag::Overflow)) {
context.messages().Say(
"%s intrinsic folding overflow"_warn_en_US, iName);
} else if (result.flags.test(RealFlag::InvalidArgument)) {
context.messages().Say(
"%s intrinsic folding: bad argument"_warn_en_US, iName);
}
return result.value;
}));
} }
// TODO: dim, dot_product, fraction, matmul, // TODO: dim, dot_product, fraction, matmul,
// modulo, nearest, norm2, rrspacing, // modulo, norm2, rrspacing,
// __builtin_next_after/down/up,
// set_exponent, spacing, transfer, // set_exponent, spacing, transfer,
// bessel_jn (transformational) and bessel_yn (transformational) // bessel_jn (transformational) and bessel_yn (transformational)
return Expr<T>{std::move(funcRef)}; return Expr<T>{std::move(funcRef)};

View File

@ -346,6 +346,45 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
return result; return result;
} }
template <typename W, int P>
ValueWithRealFlags<Real<W, P>> Real<W, P>::NEAREST(bool upward) const {
ValueWithRealFlags<Real> result;
if (IsFinite()) {
Fraction fraction{GetFraction()};
int expo{Exponent()};
Fraction one{1};
Fraction nearest;
bool isNegative{IsNegative()};
if (upward != isNegative) { // upward in magnitude
auto next{fraction.AddUnsigned(one)};
if (next.carry) {
++expo;
nearest = Fraction::Least(); // MSB only
} else {
nearest = next.value;
}
} else { // downward in magnitude
if (IsZero()) {
nearest = 1; // smallest magnitude negative subnormal
isNegative = !isNegative;
} else {
auto sub1{fraction.SubtractSigned(one)};
if (sub1.overflow) {
nearest = Fraction{0}.NOT();
--expo;
} else {
nearest = sub1.value;
}
}
}
result.flags = result.value.Normalize(isNegative, expo, nearest);
} else {
result.flags.set(RealFlag::InvalidArgument);
result.value = *this;
}
return result;
}
// HYPOT(x,y) = SQRT(x**2 + y**2) by definition, but those squared intermediate // HYPOT(x,y) = SQRT(x**2 + y**2) by definition, but those squared intermediate
// values are susceptible to over/underflow when computed naively. // values are susceptible to over/underflow when computed naively.
// Assuming that x>=y, calculate instead: // Assuming that x>=y, calculate instead:

View File

@ -0,0 +1,88 @@
! RUN: %python %S/test_folding.py %s %flang_fc1
! Tests folding of NEAREST() and its relatives
module m1
real, parameter :: minSubnormal = 1.e-45
logical, parameter :: test_1 = nearest(0., 1.) == minSubnormal
logical, parameter :: test_2 = nearest(minSubnormal, -1.) == 0
logical, parameter :: test_3 = nearest(1., 1.) == 1.0000001
logical, parameter :: test_4 = nearest(1.0000001, -1.) == 1
!WARN: warning: NEAREST intrinsic folding overflow
real, parameter :: inf = nearest(huge(1.), 1.)
!WARN: warning: NEAREST intrinsic folding: bad argument
logical, parameter :: test_5 = nearest(inf, 1.) == inf
!WARN: warning: NEAREST intrinsic folding: bad argument
logical, parameter :: test_6 = nearest(-inf, -1.) == -inf
logical, parameter :: test_7 = nearest(1.9999999, 1.) == 2.
logical, parameter :: test_8 = nearest(2., -1.) == 1.9999999
logical, parameter :: test_9 = nearest(1.9999999999999999999_10, 1.) == 2._10
logical, parameter :: test_10 = nearest(-1., 1.) == -.99999994
logical, parameter :: test_11 = nearest(-1., -2.) == -1.0000001
real, parameter :: negZero = sign(0., -1.)
logical, parameter :: test_12 = nearest(negZero, 1.) == minSubnormal
logical, parameter :: test_13 = nearest(negZero, -1.) == -minSubnormal
!WARN: warning: NEAREST: S argument is zero
logical, parameter :: test_14 = nearest(0., negZero) == -minSubnormal
!WARN: warning: NEAREST: S argument is zero
logical, parameter :: test_15 = nearest(negZero, 0.) == minSubnormal
end module
module m2
use ieee_arithmetic, only: ieee_next_after
real, parameter :: minSubnormal = 1.e-45
logical, parameter :: test_0 = ieee_next_after(0., 0.) == 0.
logical, parameter :: test_1 = ieee_next_after(0., 1.) == minSubnormal
logical, parameter :: test_2 = ieee_next_after(minSubnormal, -1.) == 0
logical, parameter :: test_3 = ieee_next_after(1., 2.) == 1.0000001
logical, parameter :: test_4 = ieee_next_after(1.0000001, -1.) == 1
!WARN: warning: division by zero
real, parameter :: inf = 1. / 0.
logical, parameter :: test_5 = ieee_next_after(inf, inf) == inf
logical, parameter :: test_6 = ieee_next_after(inf, -inf) == inf
logical, parameter :: test_7 = ieee_next_after(-inf, inf) == -inf
logical, parameter :: test_8 = ieee_next_after(-inf, -1.) == -inf
logical, parameter :: test_9 = ieee_next_after(1.9999999, 3.) == 2.
logical, parameter :: test_10 = ieee_next_after(2., 1.) == 1.9999999
logical, parameter :: test_11 = ieee_next_after(1.9999999999999999999_10, 3.) == 2._10
logical, parameter :: test_12 = ieee_next_after(1., 1.) == 1.
!WARN: warning: invalid argument on division
real, parameter :: nan = 0. / 0.
!WARN: warning: IEEE_NEXT_AFTER intrinsic folding: bad argument
real, parameter :: x13 = ieee_next_after(nan, nan)
logical, parameter :: test_13 = .not. (x13 == x13)
!WARN: warning: IEEE_NEXT_AFTER intrinsic folding: bad argument
real, parameter :: x14 = ieee_next_after(nan, 0.)
logical, parameter :: test_14 = .not. (x14 == x14)
end module
module m3
use ieee_arithmetic, only: ieee_next_up, ieee_next_down
real(kind(0.d0)), parameter :: minSubnormal = 5.d-324
logical, parameter :: test_1 = ieee_next_up(0.d0) == minSubnormal
logical, parameter :: test_2 = ieee_next_down(0.d0) == -minSubnormal
logical, parameter :: test_3 = ieee_next_up(1.d0) == 1.0000000000000002d0
logical, parameter :: test_4 = ieee_next_down(1.0000000000000002d0) == 1.d0
!WARN: warning: division by zero
real(kind(0.d0)), parameter :: inf = 1.d0 / 0.d0
!WARN: warning: IEEE_NEXT_UP intrinsic folding overflow
logical, parameter :: test_5 = ieee_next_up(huge(0.d0)) == inf
!WARN: warning: IEEE_NEXT_DOWN intrinsic folding overflow
logical, parameter :: test_6 = ieee_next_down(-huge(0.d0)) == -inf
!WARN: warning: IEEE_NEXT_UP intrinsic folding: bad argument
logical, parameter :: test_7 = ieee_next_up(inf) == inf
!WARN: warning: IEEE_NEXT_DOWN intrinsic folding: bad argument
logical, parameter :: test_8 = ieee_next_down(inf) == inf
!WARN: warning: IEEE_NEXT_UP intrinsic folding: bad argument
logical, parameter :: test_9 = ieee_next_up(-inf) == -inf
!WARN: warning: IEEE_NEXT_DOWN intrinsic folding: bad argument
logical, parameter :: test_10 = ieee_next_down(-inf) == -inf
logical, parameter :: test_11 = ieee_next_up(1.9999999999999997d0) == 2.d0
logical, parameter :: test_12 = ieee_next_down(2.d0) == 1.9999999999999997d0
!WARN: warning: invalid argument on division
real(kind(0.d0)), parameter :: nan = 0.d0 / 0.d0
!WARN: warning: IEEE_NEXT_UP intrinsic folding: bad argument
real(kind(0.d0)), parameter :: x13 = ieee_next_up(nan)
logical, parameter :: test_13 = .not. (x13 == x13)
!WARN: warning: IEEE_NEXT_DOWN intrinsic folding: bad argument
real(kind(0.d0)), parameter :: x14 = ieee_next_down(nan)
logical, parameter :: test_14 = .not. (x14 == x14)
end module