[flang] Fold FLOOR, CEILING, NINT, and ANINT

Add GetUltimate() to ResolveAssociations(), fixing a UBOUND test case with use association

Fix folding of array-valued subscripts while I am in here

Original-commit: flang-compiler/f18@f663d4fef4
Reviewed-on: https://github.com/flang-compiler/f18/pull/905
This commit is contained in:
peter klausler 2020-01-03 11:34:16 -08:00
parent 8697c77bac
commit cc179ba749
11 changed files with 149 additions and 133 deletions

View File

@ -693,7 +693,7 @@ std::optional<Procedure> Procedure::Characterize(
const ProcedureDesignator &proc, const IntrinsicProcTable &intrinsics) { const ProcedureDesignator &proc, const IntrinsicProcTable &intrinsics) {
if (const auto *symbol{proc.GetSymbol()}) { if (const auto *symbol{proc.GetSymbol()}) {
if (auto result{characteristics::Procedure::Characterize( if (auto result{characteristics::Procedure::Characterize(
symbol->GetUltimate(), intrinsics)}) { ResolveAssociations(*symbol), intrinsics)}) {
return result; return result;
} }
} else if (const auto *intrinsic{proc.GetSpecificIntrinsic()}) { } else if (const auto *intrinsic{proc.GetSpecificIntrinsic()}) {

View File

@ -67,6 +67,9 @@ private:
FoldingContext &context_; FoldingContext &context_;
}; };
std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
FoldingContext &, Subscript &, const NamedEntity &, int dim);
// FoldOperation() rewrites expression tree nodes. // FoldOperation() rewrites expression tree nodes.
// If there is any possibility that the rewritten node will // If there is any possibility that the rewritten node will
// not have the same representation type, the result of // not have the same representation type, the result of
@ -123,7 +126,7 @@ Expr<SomeDerived> FoldOperation(FoldingContext &, StructureConstructor &&);
template<typename T> template<typename T>
std::optional<Expr<T>> Folder<T>::GetNamedConstantValue(const Symbol &symbol0) { std::optional<Expr<T>> Folder<T>::GetNamedConstantValue(const Symbol &symbol0) {
const Symbol &symbol{ResolveAssociations(symbol0).GetUltimate()}; const Symbol &symbol{ResolveAssociations(symbol0)};
if (IsNamedConstant(symbol)) { if (IsNamedConstant(symbol)) {
if (const auto *object{ if (const auto *object{
symbol.detailsIf<semantics::ObjectEntityDetails>()}) { symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
@ -214,50 +217,6 @@ std::optional<Constant<T>> Folder<T>::GetFoldedNamedConstantValue(
return std::nullopt; return std::nullopt;
} }
static std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
ss = FoldOperation(context, std::move(ss));
return std::visit(
common::visitors{
[](IndirectSubscriptIntegerExpr &expr)
-> std::optional<Constant<SubscriptInteger>> {
if (auto constant{
GetScalarConstantValue<SubscriptInteger>(expr.value())}) {
return Constant<SubscriptInteger>{*constant};
} else {
return std::nullopt;
}
},
[&](Triplet &triplet) -> std::optional<Constant<SubscriptInteger>> {
auto lower{triplet.lower()}, upper{triplet.upper()};
std::optional<ConstantSubscript> stride{ToInt64(triplet.stride())};
if (!lower) {
lower = GetLowerBound(context, base, dim);
}
if (!upper) {
upper =
ComputeUpperBound(context, GetLowerBound(context, base, dim),
GetExtent(context, base, dim));
}
auto lbi{ToInt64(lower)}, ubi{ToInt64(upper)};
if (lbi && ubi && stride && *stride != 0) {
std::vector<SubscriptInteger::Scalar> values;
while ((*stride > 0 && *lbi <= *ubi) ||
(*stride < 0 && *lbi >= *ubi)) {
values.emplace_back(*lbi);
*lbi += *stride;
}
return Constant<SubscriptInteger>{std::move(values),
ConstantSubscripts{
static_cast<ConstantSubscript>(values.size())}};
} else {
return std::nullopt;
}
},
},
ss.u);
}
template<typename T> template<typename T>
std::optional<Constant<T>> Folder<T>::Folding(ArrayRef &aRef) { std::optional<Constant<T>> Folder<T>::Folding(ArrayRef &aRef) {
std::vector<Constant<SubscriptInteger>> subscripts; std::vector<Constant<SubscriptInteger>> subscripts;
@ -307,11 +266,11 @@ std::optional<Constant<T>> Folder<T>::ApplySubscripts(const Constant<T> &array,
at[j] = subscripts[j].GetScalarValue().value().ToInt64(); at[j] = subscripts[j].GetScalarValue().value().ToInt64();
} else { } else {
CHECK(k < GetRank(resultShape)); CHECK(k < GetRank(resultShape));
tmp[0] = ssLB[j] + ssAt[j]; tmp[0] = ssLB.at(k) + ssAt.at(k);
at[j] = subscripts[j].At(tmp).ToInt64(); at[j] = subscripts[j].At(tmp).ToInt64();
if (increment) { if (increment) {
if (++ssAt[j] == resultShape[k]) { if (++ssAt[k] == resultShape[k]) {
ssAt[j] = 0; ssAt[k] = 0;
} else { } else {
increment = false; increment = false;
} }

View File

@ -154,6 +154,28 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
})); }));
} else if (name == "bit_size") { } else if (name == "bit_size") {
return Expr<T>{Scalar<T>::bits}; return Expr<T>{Scalar<T>::bits};
} else if (name == "ceiling" || name == "floor" || name == "nint") {
if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
// NINT rounds ties away from zero, not to even
RoundingMode mode{name == "ceiling"
? RoundingMode::Up
: name == "floor" ? RoundingMode::Down
: RoundingMode::TiesAwayFromZero};
return std::visit(
[&](const auto &kx) {
using TR = ResultType<decltype(kx)>;
return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
ScalarFunc<T, TR>([&](const Scalar<TR> &x) {
auto y{x.template ToInteger<Scalar<T>>(mode)};
if (y.flags.test(RealFlag::Overflow)) {
context.messages().Say(
"%s intrinsic folding overflow"_en_US, name);
}
return y.value;
}));
},
cx->u);
}
} else if (name == "count") { } else if (name == "count") {
if (!args[1]) { // TODO: COUNT(x,DIM=d) if (!args[1]) { // TODO: COUNT(x,DIM=d)
if (const auto *constant{UnwrapConstantValue<LogicalResult>(args[0])}) { if (const auto *constant{UnwrapConstantValue<LogicalResult>(args[0])}) {
@ -503,10 +525,10 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
return UBOUND(context, std::move(funcRef)); return UBOUND(context, std::move(funcRef));
} }
// TODO: // TODO:
// ceiling, cshift, dot_product, eoshift, // cshift, dot_product, eoshift,
// findloc, floor, iall, iany, iparity, ibits, image_status, index, ishftc, // findloc, iall, iany, iparity, ibits, image_status, index, ishftc,
// len_trim, matmul, maxloc, maxval, // len_trim, matmul, maxloc, maxval,
// minloc, minval, nint, not, pack, product, reduce, // minloc, minval, not, pack, product, reduce,
// scan, sign, spread, sum, transfer, transpose, unpack, verify // scan, sign, spread, sum, transfer, transpose, unpack, verify
return Expr<T>{std::move(funcRef)}; return Expr<T>{std::move(funcRef)};
} }

View File

@ -87,10 +87,14 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
} else if (name == "aimag") { } else if (name == "aimag") {
return FoldElementalIntrinsic<T, ComplexT>( return FoldElementalIntrinsic<T, ComplexT>(
context, std::move(funcRef), &Scalar<ComplexT>::AIMAG); context, std::move(funcRef), &Scalar<ComplexT>::AIMAG);
} else if (name == "aint") { } else if (name == "aint" || name == "anint") {
// ANINT rounds ties away from zero, not to even
RoundingMode mode{
name == "aint" ? RoundingMode::ToZero : RoundingMode::TiesAwayFromZero};
return FoldElementalIntrinsic<T, T>(context, std::move(funcRef), return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
ScalarFunc<T, T>([&name, &context](const Scalar<T> &x) -> Scalar<T> { ScalarFunc<T, T>([&name, &context, mode](
ValueWithRealFlags<Scalar<T>> y{x.AINT()}; const Scalar<T> &x) -> Scalar<T> {
ValueWithRealFlags<Scalar<T>> y{x.ToWholeNumber(mode)};
if (y.flags.test(RealFlag::Overflow)) { if (y.flags.test(RealFlag::Overflow)) {
context.messages().Say("%s intrinsic folding overflow"_en_US, name); context.messages().Say("%s intrinsic folding overflow"_en_US, name);
} }
@ -125,7 +129,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
} else if (name == "tiny") { } else if (name == "tiny") {
return Expr<T>{Scalar<T>::TINY()}; return Expr<T>{Scalar<T>::TINY()};
} }
// TODO: anint, cshift, dim, dot_product, eoshift, fraction, matmul, // TODO: cshift, dim, dot_product, eoshift, fraction, matmul,
// maxval, minval, modulo, nearest, norm2, pack, product, // maxval, minval, modulo, nearest, norm2, pack, product,
// reduce, rrspacing, scale, set_exponent, spacing, spread, // reduce, rrspacing, scale, set_exponent, spacing, spread,
// sum, transfer, transpose, unpack, bessel_jn (transformational) and // sum, transfer, transpose, unpack, bessel_jn (transformational) and

View File

@ -11,6 +11,50 @@
namespace Fortran::evaluate { namespace Fortran::evaluate {
std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
ss = FoldOperation(context, std::move(ss));
return std::visit(
common::visitors{
[](IndirectSubscriptIntegerExpr &expr)
-> std::optional<Constant<SubscriptInteger>> {
if (const auto *constant{
UnwrapConstantValue<SubscriptInteger>(expr.value())}) {
return *constant;
} else {
return std::nullopt;
}
},
[&](Triplet &triplet) -> std::optional<Constant<SubscriptInteger>> {
auto lower{triplet.lower()}, upper{triplet.upper()};
std::optional<ConstantSubscript> stride{ToInt64(triplet.stride())};
if (!lower) {
lower = GetLowerBound(context, base, dim);
}
if (!upper) {
upper =
ComputeUpperBound(context, GetLowerBound(context, base, dim),
GetExtent(context, base, dim));
}
auto lbi{ToInt64(lower)}, ubi{ToInt64(upper)};
if (lbi && ubi && stride && *stride != 0) {
std::vector<SubscriptInteger::Scalar> values;
while ((*stride > 0 && *lbi <= *ubi) ||
(*stride < 0 && *lbi >= *ubi)) {
values.emplace_back(*lbi);
*lbi += *stride;
}
return Constant<SubscriptInteger>{std::move(values),
ConstantSubscripts{
static_cast<ConstantSubscript>(values.size())}};
} else {
return std::nullopt;
}
},
},
ss.u);
}
Expr<SomeDerived> FoldOperation( Expr<SomeDerived> FoldOperation(
FoldingContext &context, StructureConstructor &&structure) { FoldingContext &context, StructureConstructor &&structure) {
StructureConstructor result{structure.derivedTypeSpec()}; StructureConstructor result{structure.derivedTypeSpec()};

View File

@ -261,6 +261,32 @@ ValueWithRealFlags<Real<W, P, IM>> Real<W, P, IM>::Divide(
return result; return result;
} }
template<typename W, int P, bool IM>
ValueWithRealFlags<Real<W, P, IM>> Real<W, P, IM>::ToWholeNumber(
RoundingMode mode) const {
ValueWithRealFlags<Real> result{*this};
if (IsNotANumber()) {
result.flags.set(RealFlag::InvalidArgument);
result.value = NotANumber();
} else if (IsInfinite()) {
result.flags.set(RealFlag::Overflow);
} else {
constexpr int noClipExponent{exponentBias + precision - 1};
if (Exponent() < noClipExponent) {
Real adjust; // ABS(EPSILON(adjust)) == 0.5
adjust.Normalize(IsSignBitSet(), noClipExponent, Fraction::MASKL(1));
// Compute ival=(*this + adjust), losing any fractional bits; keep flags
result = Add(adjust, Rounding{mode});
result.flags.reset(RealFlag::Inexact); // result *is* exact
// Return (ival-adjust) with original sign in case we've generated a zero.
result.value =
result.value.Subtract(adjust, Rounding{RoundingMode::ToZero})
.value.SIGN(*this);
}
}
return result;
}
template<typename W, int P, bool IM> template<typename W, int P, bool IM>
RealFlags Real<W, P, IM>::Normalize(bool negative, int exponent, RealFlags Real<W, P, IM>::Normalize(bool negative, int exponent,
const Fraction &fraction, Rounding rounding, RoundingBits *roundingBits) { const Fraction &fraction, Rounding rounding, RoundingBits *roundingBits) {

View File

@ -61,7 +61,7 @@ public:
return word_ == that.word_; return word_ == that.word_;
} }
// TODO: ANINT, CEILING, FLOOR, DIM, MAX, MIN, DPROD, FRACTION, // TODO: DIM, MAX, MIN, DPROD, FRACTION,
// INT/NINT, NEAREST, OUT_OF_RANGE, // INT/NINT, NEAREST, OUT_OF_RANGE,
// RRSPACING/SPACING, SCALE, SET_EXPONENT // RRSPACING/SPACING, SCALE, SET_EXPONENT
@ -205,86 +205,40 @@ public:
return result; return result;
} }
// Truncation to integer in same real format. // Conversion to integer in the same real format (AINT(), ANINT())
constexpr ValueWithRealFlags<Real> AINT() const { ValueWithRealFlags<Real> ToWholeNumber(
ValueWithRealFlags<Real> result{*this}; RoundingMode = RoundingMode::ToZero) const;
if (IsNotANumber()) {
result.flags.set(RealFlag::InvalidArgument);
result.value = NotANumber();
} else if (IsInfinite()) {
result.flags.set(RealFlag::Overflow);
} else {
int exponent{Exponent()};
if (exponent < exponentBias) { // |x| < 1.0
result.value.Normalize(IsNegative(), 0, Fraction{}); // +/-0.0
} else {
constexpr int noClipExponent{exponentBias + precision - 1};
if (int clip = noClipExponent - exponent; clip > 0) {
result.value.word_ = result.value.word_.IAND(Word::MASKR(clip).NOT());
}
}
}
return result;
}
template<typename INT> constexpr ValueWithRealFlags<INT> ToInteger() const { // Conversion to an integer (INT(), NINT(), FLOOR(), CEILING())
template<typename INT>
constexpr ValueWithRealFlags<INT> ToInteger(
RoundingMode mode = RoundingMode::ToZero) const {
ValueWithRealFlags<INT> result; ValueWithRealFlags<INT> result;
if (IsNotANumber()) { if (IsNotANumber()) {
result.flags.set(RealFlag::InvalidArgument); result.flags.set(RealFlag::InvalidArgument);
result.value = result.value.HUGE(); result.value = result.value.HUGE();
return result; return result;
} }
bool isNegative{IsNegative()}; ValueWithRealFlags<Real> intPart{ToWholeNumber(mode)};
int exponent{Exponent()}; int exponent{intPart.value.Exponent()};
Fraction fraction{GetFraction()}; result.flags.set(
if (exponent >= maxExponent || // +/-Inf RealFlag::Overflow, exponent >= exponentBias + result.value.bits);
exponent >= exponentBias + result.value.bits) { // too big result.flags |= intPart.flags;
if (isNegative) { int shift{exponent - exponentBias - precision + 1}; // positive -> left
result.value = result.value.MASKL(1); // most negative integer value result.value =
} else { result.value.ConvertUnsigned(intPart.value.GetFraction().SHIFTR(-shift))
result.value = result.value.HUGE(); // most positive integer value .value.SHIFTL(shift);
} if (IsSignBitSet()) {
result.flags.set(RealFlag::Overflow); auto negated{result.value.Negate()};
} else if (exponent < exponentBias) { // |x| < 1.0 -> 0 result.value = negated.value;
if (!fraction.IsZero()) { if (negated.overflow) {
result.flags.set(RealFlag::Underflow); result.flags.set(RealFlag::Overflow);
result.flags.set(RealFlag::Inexact);
}
} else {
// finite number |x| >= 1.0
constexpr int noShiftExponent{exponentBias + precision - 1};
if (exponent < noShiftExponent) {
int rshift = noShiftExponent - exponent;
if (!fraction.IBITS(0, rshift).IsZero()) {
result.flags.set(RealFlag::Inexact);
}
auto truncated{result.value.ConvertUnsigned(fraction.SHIFTR(rshift))};
if (truncated.overflow) {
result.flags.set(RealFlag::Overflow);
} else {
result.value = truncated.value;
}
} else {
int lshift = exponent - noShiftExponent;
if (lshift + precision >= result.value.bits) {
result.flags.set(RealFlag::Overflow);
} else {
result.value =
result.value.ConvertUnsigned(fraction).value.SHIFTL(lshift);
}
}
if (result.flags.test(RealFlag::Overflow)) {
result.value = result.value.HUGE();
} else if (isNegative) {
auto negated{result.value.Negate()};
if (negated.overflow) {
result.flags.set(RealFlag::Overflow);
result.value = result.value.HUGE();
} else {
result.value = negated.value;
}
} }
} }
if (result.flags.test(RealFlag::Overflow)) {
result.value =
IsSignBitSet() ? result.value.MASKL(1) : result.value.HUGE();
}
return result; return result;
} }

View File

@ -739,7 +739,7 @@ const Symbol &ResolveAssociations(const Symbol &symbol) {
return ResolveAssociations(*nested); return ResolveAssociations(*nested);
} }
} }
return symbol; return symbol.GetUltimate();
} }
struct CollectSymbolsHelper struct CollectSymbolsHelper

View File

@ -803,7 +803,7 @@ template<typename A> SymbolVector GetSymbolVector(const A &x) {
// when none is found. // when none is found.
const Symbol *GetLastTarget(const SymbolVector &); const Symbol *GetLastTarget(const SymbolVector &);
// Resolves any whole ASSOCIATE(B=>A) associations // Resolves any whole ASSOCIATE(B=>A) associations, then returns GetUltimate()
const Symbol &ResolveAssociations(const Symbol &); const Symbol &ResolveAssociations(const Symbol &);
// Collects all of the Symbols in an expression // Collects all of the Symbols in an expression

View File

@ -32,6 +32,11 @@ module m
TEST_R4(acos, acos(0.5_4), 1.0471975803375244140625_4) TEST_R4(acos, acos(0.5_4), 1.0471975803375244140625_4)
TEST_R4(acosh, acosh(1.5_4), 0.96242368221282958984375_4) TEST_R4(acosh, acosh(1.5_4), 0.96242368221282958984375_4)
logical, parameter :: test_aint1 = aint(2.783).EQ.(2.) logical, parameter :: test_aint1 = aint(2.783).EQ.(2.)
logical, parameter :: test_anint1 = anint(2.783).EQ.(3.)
logical, parameter :: test_floor1 = floor(-2.783).EQ.(-3.)
logical, parameter :: test_floor2 = floor(2.783).EQ.(2.)
logical, parameter :: test_ceiling1 = ceiling(-2.783).EQ.(-2.)
logical, parameter :: test_ceiling2 = ceiling(2.783).EQ.(3.)
TEST_R4(asin, asin(0.9_4), 1.11976945400238037109375_4) TEST_R4(asin, asin(0.9_4), 1.11976945400238037109375_4)
TEST_R4(asinh, asinh(1._4), 0.881373584270477294921875_4) TEST_R4(asinh, asinh(1._4), 0.881373584270477294921875_4)
TEST_R4(atan, atan(1.5_4), 0.982793748378753662109375_4) TEST_R4(atan, atan(1.5_4), 0.982793748378753662109375_4)

View File

@ -161,7 +161,8 @@ template<typename R> void basicTests(int rm, Rounding rounding) {
TEST(vr.value.Compare(check.value) == Relation::Equal)(ldesc); TEST(vr.value.Compare(check.value) == Relation::Equal)(ldesc);
} }
} }
TEST(vr.value.AINT().value.Compare(vr.value) == Relation::Equal)(ldesc); TEST(vr.value.ToWholeNumber().value.Compare(vr.value) == Relation::Equal)
(ldesc);
ix = ix.Negate().value; ix = ix.Negate().value;
TEST(ix.IsNegative())(ldesc); TEST(ix.IsNegative())(ldesc);
x = -x; x = -x;
@ -185,7 +186,8 @@ template<typename R> void basicTests(int rm, Rounding rounding) {
MATCH(x, ivf.value.ToUInt64())(ldesc); MATCH(x, ivf.value.ToUInt64())(ldesc);
MATCH(nx, ivf.value.ToInt64())(ldesc); MATCH(nx, ivf.value.ToInt64())(ldesc);
} }
TEST(vr.value.AINT().value.Compare(vr.value) == Relation::Equal)(ldesc); TEST(vr.value.ToWholeNumber().value.Compare(vr.value) == Relation::Equal)
(ldesc);
} }
} }
@ -368,7 +370,7 @@ void subsetTests(int pass, Rounding rounding, std::uint32_t opds) {
// unary operations // unary operations
{ {
ValueWithRealFlags<REAL> aint{x.AINT()}; ValueWithRealFlags<REAL> aint{x.ToWholeNumber()};
#ifndef __clang__ // broken and also slow #ifndef __clang__ // broken and also slow
fpenv.ClearFlags(); fpenv.ClearFlags();
#endif #endif