From 9245f35580ca0cce147ec9cebfa431fa5b7feac4 Mon Sep 17 00:00:00 2001
From: peter klausler <pklausler@nvidia.com>
Date: Mon, 13 Sep 2021 13:45:30 -0700
Subject: [PATCH] [flang] Validate SIZE(x,DIM=n) dimension for assumed-size
 array x

Catch invalid attempts to extract the unknowable extent of the last
dimension of an assumed-size array dummy argument, and clean up
problems with assumed-rank arguments in similar circumstances
exposed by testing the fix.

Differential Revision: https://reviews.llvm.org/D109918
---
 flang/include/flang/Evaluate/tools.h      |  3 +++
 flang/include/flang/Semantics/tools.h     |  4 ---
 flang/lib/Evaluate/fold-integer.cpp       | 12 ++++++---
 flang/lib/Evaluate/formatting.cpp         |  2 +-
 flang/lib/Evaluate/shape.cpp              | 32 ++++++++++++++++-------
 flang/lib/Evaluate/tools.cpp              | 14 ++++++++++
 flang/lib/Evaluate/variable.cpp           |  2 +-
 flang/lib/Semantics/check-select-rank.cpp |  2 +-
 flang/test/Semantics/select-rank.f90      |  4 ++-
 9 files changed, 55 insertions(+), 20 deletions(-)

diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 50a9e265a606..5ebf3cd3b1fd 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -298,6 +298,9 @@ std::optional<DataRef> ExtractDataRef(const A *p, bool intoSubstring = false) {
     return std::nullopt;
   }
 }
+std::optional<DataRef> ExtractDataRef(
+    const ActualArgument &, bool intoSubstring = false);
+
 std::optional<DataRef> ExtractSubstringBase(const Substring &);
 
 // Predicate: is an expression is an array element reference?
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index d969dc914b03..5c23bc3ad853 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -179,10 +179,6 @@ inline bool IsAssumedSizeArray(const Symbol &symbol) {
   const auto *details{symbol.detailsIf<ObjectEntityDetails>()};
   return details && details->IsAssumedSize();
 }
-inline bool IsAssumedRankArray(const Symbol &symbol) {
-  const auto *details{symbol.detailsIf<ObjectEntityDetails>()};
-  return details && details->IsAssumedRank();
-}
 bool IsAssumedLengthCharacter(const Symbol &);
 bool IsExternal(const Symbol &);
 bool IsModuleProcedure(const Symbol &);
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 3fdf252407e9..c69ce32da188 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -612,7 +612,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
     if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
       if (auto named{ExtractNamedEntity(*array)}) {
         const Symbol &symbol{named->GetLastSymbol()};
-        if (semantics::IsAssumedRankArray(symbol)) {
+        if (IsAssumedRank(symbol)) {
           // DescriptorInquiry can only be placed in expression of kind
           // DescriptorInquiry::Result::kind.
           return ConvertToType<T>(Expr<
@@ -667,7 +667,13 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
         if (auto dim{GetInt64Arg(args[1])}) {
           int rank{GetRank(*shape)};
           if (*dim >= 1 && *dim <= rank) {
-            if (auto &extent{shape->at(*dim - 1)}) {
+            const Symbol *symbol{UnwrapWholeSymbolDataRef(args[0])};
+            if (symbol && IsAssumedSizeArray(*symbol) && *dim == rank) {
+              context.messages().Say(
+                  "size(array,dim=%jd) of last dimension is not available for rank-%d assumed-size array dummy argument"_err_en_US,
+                  *dim, rank);
+              return MakeInvalidIntrinsic<T>(std::move(funcRef));
+            } else if (auto &extent{shape->at(*dim - 1)}) {
               return Fold(context, ConvertToType<T>(std::move(*extent)));
             }
           } else {
@@ -705,7 +711,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
   } else if (name == "ubound") {
     return UBOUND(context, std::move(funcRef));
   }
-  // TODO: count(w/ dim), dot_product, findloc, ibits, image_status, ishftc,
+  // TODO: dot_product, findloc, ibits, image_status, ishftc,
   // matmul, maxloc, minloc, sign, transfer
   return Expr<T>{std::move(funcRef)};
 }
diff --git a/flang/lib/Evaluate/formatting.cpp b/flang/lib/Evaluate/formatting.cpp
index 5b5ae258d8b0..2569c85e345d 100644
--- a/flang/lib/Evaluate/formatting.cpp
+++ b/flang/lib/Evaluate/formatting.cpp
@@ -739,7 +739,7 @@ llvm::raw_ostream &DescriptorInquiry::AsFortran(llvm::raw_ostream &o) const {
   if (field_ == Field::Len) {
     return o << "%len";
   } else {
-    if (dimension_ >= 0) {
+    if (field_ != Field::Rank && dimension_ >= 0) {
       o << ",dim=" << (dimension_ + 1);
     }
     return o << ')';
diff --git a/flang/lib/Evaluate/shape.cpp b/flang/lib/Evaluate/shape.cpp
index 8919038d7ec9..0d2bb504a84a 100644
--- a/flang/lib/Evaluate/shape.cpp
+++ b/flang/lib/Evaluate/shape.cpp
@@ -260,7 +260,15 @@ auto GetLowerBoundHelper::operator()(const Symbol &symbol0) -> Result {
     }
   } else if (const auto *assoc{
                  symbol.detailsIf<semantics::AssocEntityDetails>()}) {
-    return (*this)(assoc->expr());
+    if (assoc->rank()) { // SELECT RANK case
+      const Symbol &resolved{ResolveAssociations(symbol)};
+      if (IsDescriptor(resolved) && dimension_ < *assoc->rank()) {
+        return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0},
+            DescriptorInquiry::Field::LowerBound, dimension_}};
+      }
+    } else {
+      return (*this)(assoc->expr());
+    }
   }
   return Default();
 }
@@ -338,7 +346,20 @@ static MaybeExtentExpr GetNonNegativeExtent(
 
 MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) {
   CHECK(dimension >= 0);
-  const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
+  const Symbol &last{base.GetLastSymbol()};
+  const Symbol &symbol{ResolveAssociations(last)};
+  if (const auto *assoc{last.detailsIf<semantics::AssocEntityDetails>()}) {
+    if (assoc->rank()) { // SELECT RANK case
+      if (semantics::IsDescriptor(symbol) && dimension < *assoc->rank()) {
+        return ExtentExpr{DescriptorInquiry{
+            NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
+      }
+    } else if (auto shape{GetShape(assoc->expr())}) {
+      if (dimension < static_cast<int>(shape->size())) {
+        return std::move(shape->at(dimension));
+      }
+    }
+  }
   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
     if (IsImpliedShape(symbol) && details->init()) {
       if (auto shape{GetShape(symbol)}) {
@@ -369,13 +390,6 @@ MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) {
         }
       }
     }
-  } else if (const auto *assoc{
-                 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
-    if (auto shape{GetShape(assoc->expr())}) {
-      if (dimension < static_cast<int>(shape->size())) {
-        return std::move(shape->at(dimension));
-      }
-    }
   }
   return std::nullopt;
 }
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index fde6089b2196..dd66259789ff 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -50,6 +50,15 @@ Expr<SomeType> Parenthesize(Expr<SomeType> &&expr) {
       std::move(expr.u));
 }
 
+std::optional<DataRef> ExtractDataRef(
+    const ActualArgument &arg, bool intoSubstring) {
+  if (const Expr<SomeType> *expr{arg.UnwrapExpr()}) {
+    return ExtractDataRef(*expr, intoSubstring);
+  } else {
+    return std::nullopt;
+  }
+}
+
 std::optional<DataRef> ExtractSubstringBase(const Substring &substring) {
   return std::visit(
       common::visitors{
@@ -665,6 +674,11 @@ std::optional<Expr<SomeType>> ConvertToType(
 }
 
 bool IsAssumedRank(const Symbol &original) {
+  if (const auto *assoc{original.detailsIf<semantics::AssocEntityDetails>()}) {
+    if (assoc->rank()) {
+      return false; // in SELECT RANK case
+    }
+  }
   const Symbol &symbol{semantics::ResolveAssociations(original)};
   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
     return details->IsAssumedRank();
diff --git a/flang/lib/Evaluate/variable.cpp b/flang/lib/Evaluate/variable.cpp
index 6b5f4caeb884..6a9fced879e0 100644
--- a/flang/lib/Evaluate/variable.cpp
+++ b/flang/lib/Evaluate/variable.cpp
@@ -245,7 +245,7 @@ DescriptorInquiry::DescriptorInquiry(
     : base_{base}, field_{field}, dimension_{dim} {
   const Symbol &last{base_.GetLastSymbol()};
   CHECK(IsDescriptor(last));
-  CHECK((field == Field::Len && dim == 0) ||
+  CHECK(((field == Field::Len || field == Field::Rank) && dim == 0) ||
       (field != Field::Len && dim >= 0 && dim < last.Rank()));
 }
 
diff --git a/flang/lib/Semantics/check-select-rank.cpp b/flang/lib/Semantics/check-select-rank.cpp
index 3487fb564df0..595c17fa7211 100644
--- a/flang/lib/Semantics/check-select-rank.cpp
+++ b/flang/lib/Semantics/check-select-rank.cpp
@@ -32,7 +32,7 @@ void SelectRankConstructChecker::Leave(
   const Symbol *saveSelSymbol{nullptr};
   if (const auto selExpr{GetExprFromSelector(selectRankStmtSel)}) {
     if (const Symbol * sel{evaluate::UnwrapWholeSymbolDataRef(*selExpr)}) {
-      if (!IsAssumedRankArray(*sel)) { // C1150
+      if (!evaluate::IsAssumedRank(*sel)) { // C1150
         context_.Say(parser::FindSourceLocation(selectRankStmtSel),
             "Selector '%s' is not an assumed-rank array variable"_err_en_US,
             sel->name().ToString());
diff --git a/flang/test/Semantics/select-rank.f90 b/flang/test/Semantics/select-rank.f90
index d0cd93195501..3e21e4860521 100644
--- a/flang/test/Semantics/select-rank.f90
+++ b/flang/test/Semantics/select-rank.f90
@@ -145,11 +145,13 @@ contains
     Rank(2)
       print *, "Now it's rank 2 "
     RANK (*)
-      print *, "Going for a other rank"
+      print *, "Going for another rank"
+      !ERROR: 'kind=' argument must be a constant scalar integer whose value is a supported kind for the intrinsic result type
       j = INT(0, KIND=MERGE(KIND(0), -1, RANK(x) == 1))
     !ERROR: Not more than one of the selectors of SELECT RANK statement may be '*'
     RANK (*)
       print *, "This is Wrong"
+      !ERROR: 'kind=' argument must be a constant scalar integer whose value is a supported kind for the intrinsic result type
       j = INT(0, KIND=MERGE(KIND(0), -1, RANK(x) == 1))
     END SELECT
    end subroutine