[flang] Fold LBOUND and UBOUND; do not insert empty triplets into whole array expressions

Original-commit: flang-compiler/f18@82fba68a66
Reviewed-on: https://github.com/flang-compiler/f18/pull/611
Tree-same-pre-rewrite: false
This commit is contained in:
peter klausler 2019-07-30 16:51:25 -07:00
parent a7041f3a78
commit 43b3e49490
10 changed files with 242 additions and 52 deletions

View File

@ -31,6 +31,7 @@
#include "../parser/message.h"
#include "../semantics/scope.h"
#include "../semantics/symbol.h"
#include "../semantics/tools.h"
#include <cmath>
#include <complex>
#include <cstdio>
@ -206,15 +207,13 @@ using ScalarFuncWithContext =
template<typename T>
static inline Constant<T> *FoldConvertedArg(
FoldingContext &context, std::optional<ActualArgument> &arg) {
if (arg.has_value()) {
if (auto *expr{arg->UnwrapExpr()}) {
if (UnwrapExpr<Expr<T>>(*expr) == nullptr) {
if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
*expr = Fold(context, std::move(*converted));
}
if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
if (UnwrapExpr<Expr<T>>(*expr) == nullptr) {
if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
*expr = Fold(context, std::move(*converted));
}
return UnwrapConstantValue<T>(*expr);
}
return UnwrapConstantValue<T>(*expr);
}
return nullptr;
}
@ -533,7 +532,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
return std::invoke(fptr, i, static_cast<int>(pos.ToInt64()));
}));
} else if (name == "int") {
if (auto *expr{args[0].value().UnwrapExpr()}) {
if (auto *expr{UnwrapExpr<Expr<SomeType>>(args[0])}) {
return std::visit(
[&](auto &&x) -> Expr<T> {
using From = std::decay_t<decltype(x)>;
@ -551,6 +550,52 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else {
common::die("kind() result not integral");
}
} else if (name == "lbound") {
if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
if (int rank{array->Rank()}) {
std::optional<std::int64_t> dim;
if (args[1].has_value()) {
dim = GetInt64Arg(args[1]);
if (!dim.has_value()) {
// DIM= is present but not constant
return Expr<T>{std::move(funcRef)};
} else if (*dim < 1 || *dim > rank) {
context.messages().Say(
"LBOUND(array,dim=%jd) dimension is out of range for rank-%d array"_en_US,
static_cast<std::intmax_t>(*dim), rank);
return Expr<T>(std::move(funcRef));
}
}
bool lowerBoundsAreOne{true};
if (auto named{ExtractNamedEntity(*array)}) {
const Symbol &symbol{named->GetLastSymbol()};
if (symbol.Rank() == rank) {
lowerBoundsAreOne = false;
if (dim.has_value()) {
if (auto lb{
GetLowerBound(context, *named, static_cast<int>(*dim))}) {
return Fold(context, ConvertToType<T>(std::move(*lb)));
}
} else if (auto lbounds{
AsConstantShape(GetLowerBounds(context, *named))}) {
return Fold(context,
ConvertToType<T>(Expr<ExtentType>{std::move(*lbounds)}));
}
} else {
lowerBoundsAreOne = symbol.Rank() == 0; // component
}
}
if (lowerBoundsAreOne) {
if (dim.has_value()) {
return Expr<T>{1};
} else {
std::vector<Scalar<T>> ones(rank, Scalar<T>{1});
return Expr<T>{
Constant<T>{std::move(ones), ConstantSubscripts{rank}}};
}
}
}
}
} else if (name == "leadz" || name == "trailz" || name == "poppar" ||
name == "popcnt") {
if (auto *sn{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
@ -688,16 +733,16 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
}
}
} else if (name == "shape") {
if (auto shape{GetShape(context, args[0].value())}) {
if (auto shape{GetShape(context, args[0])}) {
if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
}
}
} else if (name == "size") {
if (auto shape{GetShape(context, args[0].value())}) {
if (auto shape{GetShape(context, args[0])}) {
if (auto &dimArg{args[1]}) { // DIM= is present, get one extent
if (auto dim{GetInt64Arg(args[1])}) {
int rank = GetRank(*shape);
int rank{GetRank(*shape)};
if (*dim >= 1 && *dim <= rank) {
if (auto &extent{shape->at(*dim - 1)}) {
return Fold(context, ConvertToType<T>(std::move(*extent)));
@ -717,13 +762,70 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
return Expr<T>{ConvertToType<T>(Fold(context, std::move(product)))};
}
}
} else if (name == "ubound") {
if (auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
if (int rank{array->Rank()}; rank > 0) {
std::optional<std::int64_t> dim;
if (args[1].has_value()) {
dim = GetInt64Arg(args[1]);
if (!dim.has_value()) {
// DIM= is present but not constant
return Expr<T>{std::move(funcRef)};
} else if (*dim < 1 || *dim > rank) {
context.messages().Say(
"UBOUND(array,dim=%jd) dimension is out of range for rank-%d array"_en_US,
static_cast<std::intmax_t>(*dim), rank);
return Expr<T>(std::move(funcRef));
}
}
bool takeBoundsFromShape{true};
if (auto named{ExtractNamedEntity(*array)}) {
const Symbol &symbol{named->GetLastSymbol()};
if (symbol.Rank() == rank) {
takeBoundsFromShape = false;
if (dim.has_value()) {
if (semantics::IsAssumedSizeArray(symbol) && *dim == rank) {
return Expr<T>{-1};
} else if (auto ub{GetUpperBound(
context, *named, static_cast<int>(*dim))}) {
return Fold(context, ConvertToType<T>(std::move(*ub)));
}
} else {
Shape ubounds{GetUpperBounds(context, *named)};
if (semantics::IsAssumedSizeArray(symbol)) {
CHECK(!ubounds.back().has_value());
ubounds.back() = ExtentExpr{-1};
}
if (auto constant{AsConstantShape(ubounds)}) {
return Fold(context,
ConvertToType<T>(Expr<ExtentType>{std::move(*constant)}));
}
}
} else {
takeBoundsFromShape = symbol.Rank() == 0; // component
}
}
if (takeBoundsFromShape) {
if (auto shape{GetShape(context, *array)}) {
if (dim.has_value()) {
if (auto &dimSize{shape->at(*dim)}) {
return Fold(context,
ConvertToType<T>(Expr<ExtentType>{std::move(*dimSize)}));
}
} else if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
}
}
}
}
}
}
// TODO:
// ceiling, count, cshift, dot_product, eoshift,
// findloc, floor, iall, iany, iparity, ibits, image_status, index, ishftc,
// lbound, len_trim, matmul, max, maxloc, maxval, merge, min,
// len_trim, matmul, max, maxloc, maxval, merge, min,
// minloc, minval, mod, modulo, nint, not, pack, product, reduce,
// scan, sign, spread, sum, transfer, transpose, ubound, unpack, verify
// scan, sign, spread, sum, transfer, transpose, unpack, verify
return Expr<T>{std::move(funcRef)};
}

View File

@ -266,7 +266,7 @@ MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
[&](const Triplet &triplet) -> MaybeExtentExpr {
MaybeExtentExpr upper{triplet.upper()};
if (!upper.has_value()) {
upper = GetExtent(context, base, dimension);
upper = GetUpperBound(context, base, dimension);
}
MaybeExtentExpr lower{triplet.lower()};
if (!lower.has_value()) {
@ -298,12 +298,46 @@ MaybeExtentExpr GetUpperBound(FoldingContext &context, MaybeExtentExpr &&lower,
}
}
MaybeExtentExpr GetUpperBound(
FoldingContext &context, const NamedEntity &base, int dimension) {
const Symbol &symbol{base.GetLastSymbol()};
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
int j{0};
for (const auto &shapeSpec : details->shape()) {
if (j++ == dimension) {
if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
return Fold(context, common::Clone(*bound));
} else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
break;
} else {
return GetUpperBound(context, GetLowerBound(context, base, dimension),
GetExtent(context, base, dimension));
}
}
}
}
return std::nullopt;
}
Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) {
int rank{base.GetLastSymbol().Rank()};
Shape result;
for (int dim{0}; dim < rank; ++dim) {
result.emplace_back(GetUpperBound(context, base, dim));
}
return result;
}
void GetShapeVisitor::Handle(const Symbol &symbol) {
Handle(NamedEntity{symbol});
}
void GetShapeVisitor::Handle(const Component &component) {
Handle(NamedEntity{Component{component}});
if (component.GetLastSymbol().Rank() > 0) {
Handle(NamedEntity{Component{component}});
} else {
Nested(component.base());
}
}
void GetShapeVisitor::Handle(const NamedEntity &base) {
@ -326,6 +360,10 @@ void GetShapeVisitor::Handle(const NamedEntity &base) {
Return();
}
void GetShapeVisitor::Handle(const Substring &substring) {
Nested(substring.parent());
}
void GetShapeVisitor::Handle(const ArrayRef &arrayRef) {
Shape shape;
int dimension{0};

View File

@ -69,6 +69,9 @@ MaybeExtentExpr GetExtent(
FoldingContext &, const Subscript &, const NamedEntity &, int dimension);
MaybeExtentExpr GetUpperBound(
FoldingContext &, MaybeExtentExpr &&lower, MaybeExtentExpr &&extent);
MaybeExtentExpr GetUpperBound(
FoldingContext &, const NamedEntity &, int dimension);
Shape GetUpperBounds(FoldingContext &, const NamedEntity &);
// Compute an element count for a triplet or trip count for a DO.
ExtentExpr CountTrips(
@ -104,6 +107,7 @@ public:
void Handle(const StaticDataObject::Pointer &) { Scalar(); }
void Handle(const ArrayRef &);
void Handle(const CoarrayRef &);
void Handle(const Substring &);
void Handle(const ProcedureRef &);
void Handle(const StructureConstructor &) { Scalar(); }
template<typename T> void Handle(const ArrayConstructor<T> &aconst) {

View File

@ -209,7 +209,7 @@ template<typename A, typename B> A *UnwrapExpr(std::optional<B> &x) {
// If an expression simply wraps a DataRef, extract and return it.
template<typename A>
common::IfNoLvalue<std::optional<DataRef>, A> ExtractDataRef(const A &) {
return std::nullopt; // default base casec
return std::nullopt; // default base case
}
template<typename T>
std::optional<DataRef> ExtractDataRef(const Designator<T> &d) {
@ -235,6 +235,24 @@ std::optional<DataRef> ExtractDataRef(const std::optional<A> &x) {
}
}
template<typename A> std::optional<NamedEntity> ExtractNamedEntity(const A &x) {
if (auto dataRef{ExtractDataRef(x)}) {
return std::visit(
common::visitors{
[](const Symbol *symbol) -> std::optional<NamedEntity> {
return NamedEntity{*symbol};
},
[](Component &&component) -> std::optional<NamedEntity> {
return NamedEntity{std::move(component)};
},
[](auto &&) -> std::optional<NamedEntity> { return std::nullopt; },
},
std::move(dataRef->u));
} else {
return std::nullopt;
}
}
// If an expression is simply a whole symbol data designator,
// extract and return that symbol, else null.
template<typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {

View File

@ -340,7 +340,7 @@ public:
private:
void SetBounds(std::optional<Expr<SubscriptInteger>> &,
std::optional<Expr<SubscriptInteger>> &);
std::variant<DataRef, StaticDataObject::Pointer> parent_;
Parent parent_;
std::optional<IndirectSubscriptIntegerExpr> lower_, upper_;
};

View File

@ -153,25 +153,15 @@ MaybeExpr ExpressionAnalyzer::Designate(DataRef &&ref) {
// subscripts are in hand.
MaybeExpr ExpressionAnalyzer::CompleteSubscripts(ArrayRef &&ref) {
const Symbol &symbol{ref.GetLastSymbol().GetUltimate()};
const auto *object{symbol.detailsIf<semantics::ObjectEntityDetails>()};
int symbolRank{symbol.Rank()};
int subscripts{static_cast<int>(ref.size())};
if (subscripts == 0) {
if (semantics::IsAssumedSizeArray(symbol)) {
// Don't introduce a triplet that would later be caught
// as being invalid.
return Designate(DataRef{std::move(ref)});
}
// A -> A(:,:)
for (; subscripts < symbolRank; ++subscripts) {
ref.emplace_back(Triplet{});
}
}
if (subscripts != symbolRank) {
// nothing to check
} else if (subscripts != symbolRank) {
Say("Reference to rank-%d object '%s' has %d subscripts"_err_en_US,
symbolRank, symbol.name(), subscripts);
return std::nullopt;
} else if (subscripts == 0) {
// nothing to check
} else if (Component * component{ref.base().UnwrapComponent()}) {
int baseRank{component->base().Rank()};
if (baseRank > 0) {
@ -186,11 +176,10 @@ MaybeExpr ExpressionAnalyzer::CompleteSubscripts(ArrayRef &&ref) {
return std::nullopt;
}
}
} else if (const auto *details{
symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
} else if (object != nullptr) {
// C928 & C1002
if (Triplet * last{std::get_if<Triplet>(&ref.subscript().back().u)}) {
if (!last->upper().has_value() && details->IsAssumedSize()) {
if (!last->upper().has_value() && object->IsAssumedSize()) {
Say("Assumed-size array '%s' must have explicit final "
"subscript upper bound value"_err_en_US,
symbol.name());
@ -221,10 +210,8 @@ MaybeExpr ExpressionAnalyzer::ApplySubscripts(
std::move(dataRef.u));
}
// Top-level checks for data references. Unsubscripted whole array references
// get expanded -- e.g., MATRIX becomes MATRIX(:,:).
// Top-level checks for data references.
MaybeExpr ExpressionAnalyzer::TopLevelChecks(DataRef &&dataRef) {
bool addSubscripts{false};
if (Component * component{std::get_if<Component>(&dataRef.u)}) {
const Symbol &symbol{component->GetLastSymbol()};
int componentRank{symbol.Rank()};
@ -234,18 +221,8 @@ MaybeExpr ExpressionAnalyzer::TopLevelChecks(DataRef &&dataRef) {
Say("Reference to whole rank-%d component '%%%s' of "
"rank-%d array of derived type is not allowed"_err_en_US,
componentRank, symbol.name(), baseRank);
} else {
addSubscripts = true;
}
}
} else if (const Symbol **symbol{std::get_if<const Symbol *>(&dataRef.u)}) {
addSubscripts = (*symbol)->Rank() > 0;
}
if (addSubscripts) {
if (MaybeExpr subscripted{
ApplySubscripts(std::move(dataRef), std::vector<Subscript>{})}) {
return subscripted;
}
}
return Designate(std::move(dataRef));
}

View File

@ -478,11 +478,6 @@ bool IsFinalizable(const Symbol &symbol) {
bool IsCoarray(const Symbol &symbol) { return symbol.Corank() > 0; }
bool IsAssumedSizeArray(const Symbol &symbol) {
const auto *details{symbol.detailsIf<ObjectEntityDetails>()};
return details && details->IsAssumedSize();
}
bool IsExternalInPureContext(const Symbol &symbol, const Scope &scope) {
if (const auto *pureProc{semantics::FindPureProcedureContaining(&scope)}) {
if (const Symbol * root{GetAssociationRoot(symbol)}) {

View File

@ -110,7 +110,10 @@ inline bool IsProtected(const Symbol &symbol) {
}
bool IsFinalizable(const Symbol &symbol);
bool IsCoarray(const Symbol &symbol);
bool IsAssumedSizeArray(const Symbol &symbol);
inline bool IsAssumedSizeArray(const Symbol &symbol) {
const auto *details{symbol.detailsIf<ObjectEntityDetails>()};
return details && details->IsAssumedSize();
}
std::optional<parser::MessageFixedText> WhyNotModifiable(
const Symbol &symbol, const Scope &scope);
// Is the symbol modifiable in this scope

View File

@ -129,6 +129,7 @@ set(FOLDING_TESTS
folding05.f90
folding06.f90
folding07.f90
folding08.f90
)

View File

@ -0,0 +1,52 @@
! Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
!
! Licensed under the Apache License, Version 2.0 (the "License");
! you may not use this file except in compliance with the License.
! You may obtain a copy of the License at
!
! http://www.apache.org/licenses/LICENSE-2.0
!
! Unless required by applicable law or agreed to in writing, software
! distributed under the License is distributed on an "AS IS" BASIS,
! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
! implied.
! See the License for the specific language governing permissions and
! limitations under the License.
! Test folding of LBOUND and UBOUND
subroutine testlbound(n1,a1,a2)
integer, intent(in) :: n1
real, intent(in) :: a1(0:n1), a2(0:*)
type :: t
real :: a
end type
type(t) :: ta(0:2)
character(len=2) :: ca(-1:1)
integer, parameter :: lba1(:) = lbound(a1)
logical, parameter :: test_lba1 = all(lba1 == [0])
integer, parameter :: lba2(:) = lbound(a2)
logical, parameter :: test_lba2 = all(lba2 == [0])
integer, parameter :: uba2(:) = ubound(a2)
logical, parameter :: test_uba2 = all(uba2 == [-1])
integer, parameter :: lbta1(:) = lbound(ta)
logical, parameter :: test_lbta1 = all(lbta1 == [0])
integer, parameter :: ubta1(:) = ubound(ta)
logical, parameter :: test_ubta1 = all(ubta1 == [2])
integer, parameter :: lbta2(:) = lbound(ta(:))
logical, parameter :: test_lbta2 = all(lbta2 == [1])
integer, parameter :: ubta2(:) = ubound(ta(:))
logical, parameter :: test_ubta2 = all(ubta2 == [3])
integer, parameter :: lbta3(:) = lbound(ta%a)
logical, parameter :: test_lbta3 = all(lbta3 == [1])
integer, parameter :: ubta3(:) = ubound(ta%a)
logical, parameter :: test_ubta3 = all(ubta3 == [3])
integer, parameter :: lbca1(:) = lbound(ca)
logical, parameter :: test_lbca1 = all(lbca1 == [-1])
integer, parameter :: ubca1(:) = ubound(ca)
logical, parameter :: test_ubca1 = all(ubca1 == [1])
integer, parameter :: lbca2(:) = lbound(ca(:)(1:1))
logical, parameter :: test_lbca2 = all(lbca2 == [1])
integer, parameter :: ubca2(:) = ubound(ca(:)(1:1))
logical, parameter :: test_ubca2 = all(ubca2 == [3])
end