[flang] Do not lose call in shape inquiry on function reference

Currently, something like `print *, size(foo(n,m))` was rewritten
to `print *, size(foo_result_symbol)` when foo result is a non constant
shape array. This cannot be processed by lowering or reprocessed by a
Fortran compiler since the syntax is wrong (`foo_result_symbol` is
unknown on the caller side) and the arguments are lost when they might
be required to compute the result shape.

It is not possible (and probably not desired) to make GetShape fail in
general in such case since returning nullopt seems only expected for
scalars or assumed rank (see GetRank usage in lib/Semantics/check-call.cpp),
and returning a vector with nullopt extent may trigger some checks to
believe they are facing an assumed size (like here in intrinsic argument
checks: 196204c72c/flang/lib/Evaluate/intrinsics.cpp (L1530)).

Hence, I went for a solution that limits the rewrite change to folding
(where the original expression is returned if the shape depends on a non
constant shape from a call).

I added a non default option to GetShapeHelper that prevents the rewrite
of shape inquiry on calls to descriptor inquiries. At first I wanted to
avoid touching GetShapeHelper, but it would require to re-implement all
its logic to determine if the shape comes from a function call or not
(the expression could be `size(1+foo(n,m))`). So added an alternate
entry point to GetShapeHelper seemed the cleanest solution to me.

Differential Revision: https://reviews.llvm.org/D116933
This commit is contained in:
Jean Perier 2022-01-10 19:09:45 +01:00
parent 7f1955dc96
commit fb3faa8b32
4 changed files with 83 additions and 5 deletions

View File

@ -104,6 +104,9 @@ public:
using Base::operator();
GetShapeHelper() : Base{*this} {}
explicit GetShapeHelper(FoldingContext &c) : Base{*this}, context_{&c} {}
explicit GetShapeHelper(FoldingContext &c, bool useResultSymbolShape)
: Base{*this}, context_{&c}, useResultSymbolShape_{useResultSymbolShape} {
}
Result operator()(const ImpliedDoIndex &) const { return ScalarShape(); }
Result operator()(const DescriptorInquiry &) const { return ScalarShape(); }
@ -197,6 +200,7 @@ private:
}
FoldingContext *context_{nullptr};
bool useResultSymbolShape_{true};
};
template <typename A>
@ -241,6 +245,15 @@ std::optional<ConstantSubscripts> GetConstantExtents(
}
}
// Get shape that does not depends on callee scope symbols if the expression
// contains calls. Return std::nullopt if it is not possible to build such shape
// (e.g. for calls to array functions whose result shape depends on the
// arguments).
template <typename A>
std::optional<Shape> GetContextFreeShape(FoldingContext &context, const A &x) {
return GetShapeHelper{context, false}(x);
}
// Compilation-time shape conformance checking, when corresponding extents
// are or should be known. The result is an optional Boolean:
// - nullopt: no error found or reported, but conformance cannot

View File

@ -158,7 +158,7 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
}
}
if (takeBoundsFromShape) {
if (auto shape{GetShape(context, *array)}) {
if (auto shape{GetContextFreeShape(context, *array)}) {
if (dim) {
if (auto &dimSize{shape->at(*dim)}) {
return Fold(context,
@ -851,7 +851,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
}
}
} else if (name == "shape") {
if (auto shape{GetShape(context, args[0])}) {
if (auto shape{GetContextFreeShape(context, args[0])}) {
if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
}
@ -894,7 +894,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
return result.value;
}));
} else if (name == "size") {
if (auto shape{GetShape(context, args[0])}) {
if (auto shape{GetContextFreeShape(context, args[0])}) {
if (auto &dimArg{args[1]}) { // DIM= is present, get one extent
if (auto dim{GetInt64Arg(args[1])}) {
int rank{GetRank(*shape)};

View File

@ -556,9 +556,22 @@ auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
return (*this)(assoc.expr());
}
},
[&](const semantics::SubprogramDetails &subp) {
[&](const semantics::SubprogramDetails &subp) -> Result {
if (subp.isFunction()) {
return (*this)(subp.result());
auto resultShape{(*this)(subp.result())};
if (resultShape && !useResultSymbolShape_) {
// Ensure the shape does not contain descriptor inquiries, they
// may refer to symbols belonging to the called subprogram scope
// that are meaningless on the caller side without the related
// call expression.
for (auto extent : *resultShape) {
if (extent &&
std::holds_alternative<DescriptorInquiry>(extent->u)) {
return std::nullopt;
}
}
}
return resultShape;
} else {
return Result{};
}

View File

@ -0,0 +1,52 @@
! Test expression rewrites, in case where the expression cannot be
! folded to constant values.
! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
! Test rewrites of inquiry intrinsics with arguments whose shape depends
! on a function reference with non constant shape. The function reference
! must be retained.
module some_mod
contains
function returns_array(n, m)
integer :: returns_array(10:n+10,10:m+10)
returns_array = 0
end function
subroutine ubound_test(x, n, m)
integer :: x(n, m)
!CHECK: PRINT *, [INTEGER(4)::int(size(x,dim=1),kind=4),int(size(x,dim=2),kind=4)]
print *, ubound(x)
!CHECK: PRINT *, ubound(returns_array(n,m))
print *, ubound(returns_array(n, m))
!CHECK: PRINT *, ubound(returns_array(n,m),dim=1_4)
print *, ubound(returns_array(n, m), dim=1)
end subroutine
subroutine size_test(x, n, m)
integer :: x(n, m)
!CHECK: PRINT *, int(size(x,dim=1)*size(x,dim=2),kind=4)
print *, size(x)
!CHECK: PRINT *, size(returns_array(n,m))
print *, size(returns_array(n, m))
!CHECK: PRINT *, size(returns_array(n,m),dim=1_4)
print *, size(returns_array(n, m), dim=1)
end subroutine
subroutine shape_test(x, n, m)
integer :: x(n, m)
!CHECK: PRINT *, [INTEGER(4)::int(size(x,dim=1),kind=4),int(size(x,dim=2),kind=4)]
print *, shape(x)
!CHECK: PRINT *, shape(returns_array(n,m))
print *, shape(returns_array(n, m))
end subroutine
subroutine lbound_test(x, n, m)
integer :: x(n, m)
!CHECK: PRINT *, [INTEGER(4)::1_4,1_4]
print *, lbound(x)
!CHECK: PRINT *, [INTEGER(4)::1_4,1_4]
print *, lbound(returns_array(n, m))
!CHECK: PRINT *, 1_4
print *, lbound(returns_array(n, m), dim=1)
end subroutine
end module