diff --git a/flang/lib/evaluate/expression.cc b/flang/lib/evaluate/expression.cc index 4d07e9bf00b8..0d412c975b8c 100644 --- a/flang/lib/evaluate/expression.cc +++ b/flang/lib/evaluate/expression.cc @@ -85,11 +85,11 @@ std::optional ExpressionBase::GetType() const { return Result::GetType(); } else { return std::visit( - [&](const auto &x) { + [&](const auto &x) -> std::optional { if constexpr (!common::HasMember) { return x.GetType(); } else { - return std::optional{}; + return std::nullopt; } }, derived().u); diff --git a/flang/lib/evaluate/shape.cc b/flang/lib/evaluate/shape.cc index d7d2b1a9225f..6f03378e4003 100644 --- a/flang/lib/evaluate/shape.cc +++ b/flang/lib/evaluate/shape.cc @@ -357,7 +357,31 @@ Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) { } void GetShapeVisitor::Handle(const Symbol &symbol) { - Handle(NamedEntity{symbol}); + std::visit( + common::visitors{ + [&](const semantics::ObjectEntityDetails &object) { + Handle(NamedEntity{symbol}); + }, + [&](const semantics::AssocEntityDetails &assoc) { + Nested(assoc.expr()); + }, + [&](const semantics::SubprogramDetails &subp) { + if (subp.isFunction()) { + Handle(subp.result()); + } else { + Return(); + } + }, + [&](const semantics::ProcBindingDetails &binding) { + Handle(binding.symbol()); + }, + [&](const semantics::UseDetails &use) { Handle(use.symbol()); }, + [&](const semantics::HostAssocDetails &assoc) { + Handle(assoc.symbol()); + }, + [&](const auto &) { Return(); }, + }, + symbol.details()); } void GetShapeVisitor::Handle(const Component &component) { @@ -369,23 +393,21 @@ void GetShapeVisitor::Handle(const Component &component) { } void GetShapeVisitor::Handle(const NamedEntity &base) { - const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())}; - if (const auto *details{symbol.detailsIf()}) { + const Symbol &symbol{base.GetLastSymbol()}; + if (const auto *object{symbol.detailsIf()}) { if (IsImpliedShape(symbol)) { - Nested(details->init()); + Nested(object->init()); } else { Shape result; - int n{static_cast(details->shape().size())}; + int n{static_cast(object->shape().size())}; for (int dimension{0}; dimension < n; ++dimension) { result.emplace_back(GetExtent(context_, base, dimension)); } Return(std::move(result)); } - } else if (const auto *details{ - symbol.detailsIf()}) { - Nested(details->expr()); + } else { + Return(); // error recovery } - Return(); } void GetShapeVisitor::Handle(const Substring &substring) { diff --git a/flang/test/evaluate/folding08.f90 b/flang/test/evaluate/folding08.f90 index 7c9dfe1e80ad..770dfa82e9d8 100644 --- a/flang/test/evaluate/folding08.f90 +++ b/flang/test/evaluate/folding08.f90 @@ -15,38 +15,48 @@ ! Test folding of LBOUND and UBOUND -subroutine test(n1,a1,a2) - integer, intent(in) :: n1 - real, intent(in) :: a1(0:n1), a2(0:*) - type :: t - real :: a - end type - type(t) :: ta(0:2) - character(len=2) :: ca(-1:1) - integer, parameter :: lba1(:) = lbound(a1) - logical, parameter :: test_lba1 = all(lba1 == [0]) - integer, parameter :: lba2(:) = lbound(a2) - logical, parameter :: test_lba2 = all(lba2 == [0]) - integer, parameter :: uba2(:) = ubound(a2) - logical, parameter :: test_uba2 = all(uba2 == [-1]) - integer, parameter :: lbta1(:) = lbound(ta) - logical, parameter :: test_lbta1 = all(lbta1 == [0]) - integer, parameter :: ubta1(:) = ubound(ta) - logical, parameter :: test_ubta1 = all(ubta1 == [2]) - integer, parameter :: lbta2(:) = lbound(ta(:)) - logical, parameter :: test_lbta2 = all(lbta2 == [1]) - integer, parameter :: ubta2(:) = ubound(ta(:)) - logical, parameter :: test_ubta2 = all(ubta2 == [3]) - integer, parameter :: lbta3(:) = lbound(ta%a) - logical, parameter :: test_lbta3 = all(lbta3 == [1]) - integer, parameter :: ubta3(:) = ubound(ta%a) - logical, parameter :: test_ubta3 = all(ubta3 == [3]) - integer, parameter :: lbca1(:) = lbound(ca) - logical, parameter :: test_lbca1 = all(lbca1 == [-1]) - integer, parameter :: ubca1(:) = ubound(ca) - logical, parameter :: test_ubca1 = all(ubca1 == [1]) - integer, parameter :: lbca2(:) = lbound(ca(:)(1:1)) - logical, parameter :: test_lbca2 = all(lbca2 == [1]) - integer, parameter :: ubca2(:) = ubound(ca(:)(1:1)) - logical, parameter :: test_ubca2 = all(ubca2 == [3]) +module m + contains + function foo() + real :: foo(2:3,4:6) + end function + subroutine test(n1,a1,a2) + integer, intent(in) :: n1 + real, intent(in) :: a1(0:n1), a2(0:*) + type :: t + real :: a + end type + type(t) :: ta(0:2) + character(len=2) :: ca(-1:1) + integer, parameter :: lba1(:) = lbound(a1) + logical, parameter :: test_lba1 = all(lba1 == [0]) + integer, parameter :: lba2(:) = lbound(a2) + logical, parameter :: test_lba2 = all(lba2 == [0]) + integer, parameter :: uba2(:) = ubound(a2) + logical, parameter :: test_uba2 = all(uba2 == [-1]) + integer, parameter :: lbta1(:) = lbound(ta) + logical, parameter :: test_lbta1 = all(lbta1 == [0]) + integer, parameter :: ubta1(:) = ubound(ta) + logical, parameter :: test_ubta1 = all(ubta1 == [2]) + integer, parameter :: lbta2(:) = lbound(ta(:)) + logical, parameter :: test_lbta2 = all(lbta2 == [1]) + integer, parameter :: ubta2(:) = ubound(ta(:)) + logical, parameter :: test_ubta2 = all(ubta2 == [3]) + integer, parameter :: lbta3(:) = lbound(ta%a) + logical, parameter :: test_lbta3 = all(lbta3 == [1]) + integer, parameter :: ubta3(:) = ubound(ta%a) + logical, parameter :: test_ubta3 = all(ubta3 == [3]) + integer, parameter :: lbca1(:) = lbound(ca) + logical, parameter :: test_lbca1 = all(lbca1 == [-1]) + integer, parameter :: ubca1(:) = ubound(ca) + logical, parameter :: test_ubca1 = all(ubca1 == [1]) + integer, parameter :: lbca2(:) = lbound(ca(:)(1:1)) + logical, parameter :: test_lbca2 = all(lbca2 == [1]) + integer, parameter :: ubca2(:) = ubound(ca(:)(1:1)) + logical, parameter :: test_ubca2 = all(ubca2 == [3]) + integer, parameter :: lbfoo(:) = lbound(foo()) + logical, parameter :: test_lbfoo = all(lbfoo == [1,1]) + integer, parameter :: ubfoo(:) = ubound(foo()) + logical, parameter :: test_ubfoo = all(ubfoo == [2,3]) + end subroutine end