[flang] Fold FLOOR, CEILING, NINT, and ANINT

Add GetUltimate() to ResolveAssociations(), fixing a UBOUND test case with use association

Fix folding of array-valued subscripts while I am in here

Original-commit: flang-compiler/f18@f663d4fef4
Reviewed-on: https://github.com/flang-compiler/f18/pull/905
This commit is contained in:
peter klausler 2020-01-03 11:34:16 -08:00
parent 8697c77bac
commit cc179ba749
11 changed files with 149 additions and 133 deletions

View File

@ -693,7 +693,7 @@ std::optional<Procedure> Procedure::Characterize(
const ProcedureDesignator &proc, const IntrinsicProcTable &intrinsics) {
if (const auto *symbol{proc.GetSymbol()}) {
if (auto result{characteristics::Procedure::Characterize(
symbol->GetUltimate(), intrinsics)}) {
ResolveAssociations(*symbol), intrinsics)}) {
return result;
}
} else if (const auto *intrinsic{proc.GetSpecificIntrinsic()}) {

View File

@ -67,6 +67,9 @@ private:
FoldingContext &context_;
};
std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
FoldingContext &, Subscript &, const NamedEntity &, int dim);
// FoldOperation() rewrites expression tree nodes.
// If there is any possibility that the rewritten node will
// not have the same representation type, the result of
@ -123,7 +126,7 @@ Expr<SomeDerived> FoldOperation(FoldingContext &, StructureConstructor &&);
template<typename T>
std::optional<Expr<T>> Folder<T>::GetNamedConstantValue(const Symbol &symbol0) {
const Symbol &symbol{ResolveAssociations(symbol0).GetUltimate()};
const Symbol &symbol{ResolveAssociations(symbol0)};
if (IsNamedConstant(symbol)) {
if (const auto *object{
symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
@ -214,50 +217,6 @@ std::optional<Constant<T>> Folder<T>::GetFoldedNamedConstantValue(
return std::nullopt;
}
static std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
ss = FoldOperation(context, std::move(ss));
return std::visit(
common::visitors{
[](IndirectSubscriptIntegerExpr &expr)
-> std::optional<Constant<SubscriptInteger>> {
if (auto constant{
GetScalarConstantValue<SubscriptInteger>(expr.value())}) {
return Constant<SubscriptInteger>{*constant};
} else {
return std::nullopt;
}
},
[&](Triplet &triplet) -> std::optional<Constant<SubscriptInteger>> {
auto lower{triplet.lower()}, upper{triplet.upper()};
std::optional<ConstantSubscript> stride{ToInt64(triplet.stride())};
if (!lower) {
lower = GetLowerBound(context, base, dim);
}
if (!upper) {
upper =
ComputeUpperBound(context, GetLowerBound(context, base, dim),
GetExtent(context, base, dim));
}
auto lbi{ToInt64(lower)}, ubi{ToInt64(upper)};
if (lbi && ubi && stride && *stride != 0) {
std::vector<SubscriptInteger::Scalar> values;
while ((*stride > 0 && *lbi <= *ubi) ||
(*stride < 0 && *lbi >= *ubi)) {
values.emplace_back(*lbi);
*lbi += *stride;
}
return Constant<SubscriptInteger>{std::move(values),
ConstantSubscripts{
static_cast<ConstantSubscript>(values.size())}};
} else {
return std::nullopt;
}
},
},
ss.u);
}
template<typename T>
std::optional<Constant<T>> Folder<T>::Folding(ArrayRef &aRef) {
std::vector<Constant<SubscriptInteger>> subscripts;
@ -307,11 +266,11 @@ std::optional<Constant<T>> Folder<T>::ApplySubscripts(const Constant<T> &array,
at[j] = subscripts[j].GetScalarValue().value().ToInt64();
} else {
CHECK(k < GetRank(resultShape));
tmp[0] = ssLB[j] + ssAt[j];
tmp[0] = ssLB.at(k) + ssAt.at(k);
at[j] = subscripts[j].At(tmp).ToInt64();
if (increment) {
if (++ssAt[j] == resultShape[k]) {
ssAt[j] = 0;
if (++ssAt[k] == resultShape[k]) {
ssAt[k] = 0;
} else {
increment = false;
}

View File

@ -154,6 +154,28 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
}));
} else if (name == "bit_size") {
return Expr<T>{Scalar<T>::bits};
} else if (name == "ceiling" || name == "floor" || name == "nint") {
if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
// NINT rounds ties away from zero, not to even
RoundingMode mode{name == "ceiling"
? RoundingMode::Up
: name == "floor" ? RoundingMode::Down
: RoundingMode::TiesAwayFromZero};
return std::visit(
[&](const auto &kx) {
using TR = ResultType<decltype(kx)>;
return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
ScalarFunc<T, TR>([&](const Scalar<TR> &x) {
auto y{x.template ToInteger<Scalar<T>>(mode)};
if (y.flags.test(RealFlag::Overflow)) {
context.messages().Say(
"%s intrinsic folding overflow"_en_US, name);
}
return y.value;
}));
},
cx->u);
}
} else if (name == "count") {
if (!args[1]) { // TODO: COUNT(x,DIM=d)
if (const auto *constant{UnwrapConstantValue<LogicalResult>(args[0])}) {
@ -503,10 +525,10 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
return UBOUND(context, std::move(funcRef));
}
// TODO:
// ceiling, cshift, dot_product, eoshift,
// findloc, floor, iall, iany, iparity, ibits, image_status, index, ishftc,
// cshift, dot_product, eoshift,
// findloc, iall, iany, iparity, ibits, image_status, index, ishftc,
// len_trim, matmul, maxloc, maxval,
// minloc, minval, nint, not, pack, product, reduce,
// minloc, minval, not, pack, product, reduce,
// scan, sign, spread, sum, transfer, transpose, unpack, verify
return Expr<T>{std::move(funcRef)};
}

View File

@ -87,10 +87,14 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
} else if (name == "aimag") {
return FoldElementalIntrinsic<T, ComplexT>(
context, std::move(funcRef), &Scalar<ComplexT>::AIMAG);
} else if (name == "aint") {
} else if (name == "aint" || name == "anint") {
// ANINT rounds ties away from zero, not to even
RoundingMode mode{
name == "aint" ? RoundingMode::ToZero : RoundingMode::TiesAwayFromZero};
return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
ScalarFunc<T, T>([&name, &context](const Scalar<T> &x) -> Scalar<T> {
ValueWithRealFlags<Scalar<T>> y{x.AINT()};
ScalarFunc<T, T>([&name, &context, mode](
const Scalar<T> &x) -> Scalar<T> {
ValueWithRealFlags<Scalar<T>> y{x.ToWholeNumber(mode)};
if (y.flags.test(RealFlag::Overflow)) {
context.messages().Say("%s intrinsic folding overflow"_en_US, name);
}
@ -125,7 +129,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
} else if (name == "tiny") {
return Expr<T>{Scalar<T>::TINY()};
}
// TODO: anint, cshift, dim, dot_product, eoshift, fraction, matmul,
// TODO: cshift, dim, dot_product, eoshift, fraction, matmul,
// maxval, minval, modulo, nearest, norm2, pack, product,
// reduce, rrspacing, scale, set_exponent, spacing, spread,
// sum, transfer, transpose, unpack, bessel_jn (transformational) and

View File

@ -11,6 +11,50 @@
namespace Fortran::evaluate {
std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
ss = FoldOperation(context, std::move(ss));
return std::visit(
common::visitors{
[](IndirectSubscriptIntegerExpr &expr)
-> std::optional<Constant<SubscriptInteger>> {
if (const auto *constant{
UnwrapConstantValue<SubscriptInteger>(expr.value())}) {
return *constant;
} else {
return std::nullopt;
}
},
[&](Triplet &triplet) -> std::optional<Constant<SubscriptInteger>> {
auto lower{triplet.lower()}, upper{triplet.upper()};
std::optional<ConstantSubscript> stride{ToInt64(triplet.stride())};
if (!lower) {
lower = GetLowerBound(context, base, dim);
}
if (!upper) {
upper =
ComputeUpperBound(context, GetLowerBound(context, base, dim),
GetExtent(context, base, dim));
}
auto lbi{ToInt64(lower)}, ubi{ToInt64(upper)};
if (lbi && ubi && stride && *stride != 0) {
std::vector<SubscriptInteger::Scalar> values;
while ((*stride > 0 && *lbi <= *ubi) ||
(*stride < 0 && *lbi >= *ubi)) {
values.emplace_back(*lbi);
*lbi += *stride;
}
return Constant<SubscriptInteger>{std::move(values),
ConstantSubscripts{
static_cast<ConstantSubscript>(values.size())}};
} else {
return std::nullopt;
}
},
},
ss.u);
}
Expr<SomeDerived> FoldOperation(
FoldingContext &context, StructureConstructor &&structure) {
StructureConstructor result{structure.derivedTypeSpec()};

View File

@ -261,6 +261,32 @@ ValueWithRealFlags<Real<W, P, IM>> Real<W, P, IM>::Divide(
return result;
}
template<typename W, int P, bool IM>
ValueWithRealFlags<Real<W, P, IM>> Real<W, P, IM>::ToWholeNumber(
RoundingMode mode) const {
ValueWithRealFlags<Real> result{*this};
if (IsNotANumber()) {
result.flags.set(RealFlag::InvalidArgument);
result.value = NotANumber();
} else if (IsInfinite()) {
result.flags.set(RealFlag::Overflow);
} else {
constexpr int noClipExponent{exponentBias + precision - 1};
if (Exponent() < noClipExponent) {
Real adjust; // ABS(EPSILON(adjust)) == 0.5
adjust.Normalize(IsSignBitSet(), noClipExponent, Fraction::MASKL(1));
// Compute ival=(*this + adjust), losing any fractional bits; keep flags
result = Add(adjust, Rounding{mode});
result.flags.reset(RealFlag::Inexact); // result *is* exact
// Return (ival-adjust) with original sign in case we've generated a zero.
result.value =
result.value.Subtract(adjust, Rounding{RoundingMode::ToZero})
.value.SIGN(*this);
}
}
return result;
}
template<typename W, int P, bool IM>
RealFlags Real<W, P, IM>::Normalize(bool negative, int exponent,
const Fraction &fraction, Rounding rounding, RoundingBits *roundingBits) {

View File

@ -61,7 +61,7 @@ public:
return word_ == that.word_;
}
// TODO: ANINT, CEILING, FLOOR, DIM, MAX, MIN, DPROD, FRACTION,
// TODO: DIM, MAX, MIN, DPROD, FRACTION,
// INT/NINT, NEAREST, OUT_OF_RANGE,
// RRSPACING/SPACING, SCALE, SET_EXPONENT
@ -205,86 +205,40 @@ public:
return result;
}
// Truncation to integer in same real format.
constexpr ValueWithRealFlags<Real> AINT() const {
ValueWithRealFlags<Real> result{*this};
if (IsNotANumber()) {
result.flags.set(RealFlag::InvalidArgument);
result.value = NotANumber();
} else if (IsInfinite()) {
result.flags.set(RealFlag::Overflow);
} else {
int exponent{Exponent()};
if (exponent < exponentBias) { // |x| < 1.0
result.value.Normalize(IsNegative(), 0, Fraction{}); // +/-0.0
} else {
constexpr int noClipExponent{exponentBias + precision - 1};
if (int clip = noClipExponent - exponent; clip > 0) {
result.value.word_ = result.value.word_.IAND(Word::MASKR(clip).NOT());
}
}
}
return result;
}
// Conversion to integer in the same real format (AINT(), ANINT())
ValueWithRealFlags<Real> ToWholeNumber(
RoundingMode = RoundingMode::ToZero) const;
template<typename INT> constexpr ValueWithRealFlags<INT> ToInteger() const {
// Conversion to an integer (INT(), NINT(), FLOOR(), CEILING())
template<typename INT>
constexpr ValueWithRealFlags<INT> ToInteger(
RoundingMode mode = RoundingMode::ToZero) const {
ValueWithRealFlags<INT> result;
if (IsNotANumber()) {
result.flags.set(RealFlag::InvalidArgument);
result.value = result.value.HUGE();
return result;
}
bool isNegative{IsNegative()};
int exponent{Exponent()};
Fraction fraction{GetFraction()};
if (exponent >= maxExponent || // +/-Inf
exponent >= exponentBias + result.value.bits) { // too big
if (isNegative) {
result.value = result.value.MASKL(1); // most negative integer value
} else {
result.value = result.value.HUGE(); // most positive integer value
}
result.flags.set(RealFlag::Overflow);
} else if (exponent < exponentBias) { // |x| < 1.0 -> 0
if (!fraction.IsZero()) {
result.flags.set(RealFlag::Underflow);
result.flags.set(RealFlag::Inexact);
}
} else {
// finite number |x| >= 1.0
constexpr int noShiftExponent{exponentBias + precision - 1};
if (exponent < noShiftExponent) {
int rshift = noShiftExponent - exponent;
if (!fraction.IBITS(0, rshift).IsZero()) {
result.flags.set(RealFlag::Inexact);
}
auto truncated{result.value.ConvertUnsigned(fraction.SHIFTR(rshift))};
if (truncated.overflow) {
result.flags.set(RealFlag::Overflow);
} else {
result.value = truncated.value;
}
} else {
int lshift = exponent - noShiftExponent;
if (lshift + precision >= result.value.bits) {
result.flags.set(RealFlag::Overflow);
} else {
result.value =
result.value.ConvertUnsigned(fraction).value.SHIFTL(lshift);
}
}
if (result.flags.test(RealFlag::Overflow)) {
result.value = result.value.HUGE();
} else if (isNegative) {
auto negated{result.value.Negate()};
if (negated.overflow) {
result.flags.set(RealFlag::Overflow);
result.value = result.value.HUGE();
} else {
result.value = negated.value;
}
ValueWithRealFlags<Real> intPart{ToWholeNumber(mode)};
int exponent{intPart.value.Exponent()};
result.flags.set(
RealFlag::Overflow, exponent >= exponentBias + result.value.bits);
result.flags |= intPart.flags;
int shift{exponent - exponentBias - precision + 1}; // positive -> left
result.value =
result.value.ConvertUnsigned(intPart.value.GetFraction().SHIFTR(-shift))
.value.SHIFTL(shift);
if (IsSignBitSet()) {
auto negated{result.value.Negate()};
result.value = negated.value;
if (negated.overflow) {
result.flags.set(RealFlag::Overflow);
}
}
if (result.flags.test(RealFlag::Overflow)) {
result.value =
IsSignBitSet() ? result.value.MASKL(1) : result.value.HUGE();
}
return result;
}

View File

@ -739,7 +739,7 @@ const Symbol &ResolveAssociations(const Symbol &symbol) {
return ResolveAssociations(*nested);
}
}
return symbol;
return symbol.GetUltimate();
}
struct CollectSymbolsHelper

View File

@ -803,7 +803,7 @@ template<typename A> SymbolVector GetSymbolVector(const A &x) {
// when none is found.
const Symbol *GetLastTarget(const SymbolVector &);
// Resolves any whole ASSOCIATE(B=>A) associations
// Resolves any whole ASSOCIATE(B=>A) associations, then returns GetUltimate()
const Symbol &ResolveAssociations(const Symbol &);
// Collects all of the Symbols in an expression

View File

@ -32,6 +32,11 @@ module m
TEST_R4(acos, acos(0.5_4), 1.0471975803375244140625_4)
TEST_R4(acosh, acosh(1.5_4), 0.96242368221282958984375_4)
logical, parameter :: test_aint1 = aint(2.783).EQ.(2.)
logical, parameter :: test_anint1 = anint(2.783).EQ.(3.)
logical, parameter :: test_floor1 = floor(-2.783).EQ.(-3.)
logical, parameter :: test_floor2 = floor(2.783).EQ.(2.)
logical, parameter :: test_ceiling1 = ceiling(-2.783).EQ.(-2.)
logical, parameter :: test_ceiling2 = ceiling(2.783).EQ.(3.)
TEST_R4(asin, asin(0.9_4), 1.11976945400238037109375_4)
TEST_R4(asinh, asinh(1._4), 0.881373584270477294921875_4)
TEST_R4(atan, atan(1.5_4), 0.982793748378753662109375_4)

View File

@ -161,7 +161,8 @@ template<typename R> void basicTests(int rm, Rounding rounding) {
TEST(vr.value.Compare(check.value) == Relation::Equal)(ldesc);
}
}
TEST(vr.value.AINT().value.Compare(vr.value) == Relation::Equal)(ldesc);
TEST(vr.value.ToWholeNumber().value.Compare(vr.value) == Relation::Equal)
(ldesc);
ix = ix.Negate().value;
TEST(ix.IsNegative())(ldesc);
x = -x;
@ -185,7 +186,8 @@ template<typename R> void basicTests(int rm, Rounding rounding) {
MATCH(x, ivf.value.ToUInt64())(ldesc);
MATCH(nx, ivf.value.ToInt64())(ldesc);
}
TEST(vr.value.AINT().value.Compare(vr.value) == Relation::Equal)(ldesc);
TEST(vr.value.ToWholeNumber().value.Compare(vr.value) == Relation::Equal)
(ldesc);
}
}
@ -368,7 +370,7 @@ void subsetTests(int pass, Rounding rounding, std::uint32_t opds) {
// unary operations
{
ValueWithRealFlags<REAL> aint{x.AINT()};
ValueWithRealFlags<REAL> aint{x.ToWholeNumber()};
#ifndef __clang__ // broken and also slow
fpenv.ClearFlags();
#endif