From b65572d5a029c9655da3c7623392ee0c1e9abfff Mon Sep 17 00:00:00 2001 From: peter klausler Date: Wed, 3 Apr 2019 16:04:13 -0700 Subject: [PATCH] [flang] fix original failure (reshape intrinsic argument check) Original-commit: flang-compiler/f18@8bba330b32d928a5cf5d581d139c0cec02294b58 Reviewed-on: https://github.com/flang-compiler/f18/pull/386 Tree-same-pre-rewrite: false --- flang/lib/evaluate/call.cc | 8 -- flang/lib/evaluate/call.h | 1 - flang/lib/evaluate/common.h | 1 + flang/lib/evaluate/fold.cc | 7 +- flang/lib/evaluate/intrinsics.cc | 28 +++-- flang/lib/evaluate/shape.cc | 207 +++++++++++++++++++++++-------- flang/lib/evaluate/shape.h | 92 +++++++++----- flang/lib/evaluate/variable.cc | 31 +++-- flang/lib/evaluate/variable.h | 5 +- flang/lib/parser/message.h | 1 + 10 files changed, 272 insertions(+), 109 deletions(-) diff --git a/flang/lib/evaluate/call.cc b/flang/lib/evaluate/call.cc index 102b9fefc92e..5313069a8992 100644 --- a/flang/lib/evaluate/call.cc +++ b/flang/lib/evaluate/call.cc @@ -29,14 +29,6 @@ bool ActualArgument::operator==(const ActualArgument &that) const { isAlternateReturn == that.isAlternateReturn && value() == that.value(); } -std::optional ActualArgument::VectorSize() const { - if (Rank() != 1) { - return std::nullopt; - } - // TODO: get shape vector of value, return its length - return std::nullopt; -} - bool SpecificIntrinsic::operator==(const SpecificIntrinsic &that) const { return name == that.name && type == that.type && rank == that.rank && attrs == that.attrs; diff --git a/flang/lib/evaluate/call.h b/flang/lib/evaluate/call.h index 5b11df1ca15a..4dab1380419d 100644 --- a/flang/lib/evaluate/call.h +++ b/flang/lib/evaluate/call.h @@ -45,7 +45,6 @@ public: int Rank() const; bool operator==(const ActualArgument &) const; std::ostream &AsFortran(std::ostream &) const; - std::optional VectorSize() const; std::optional keyword; bool isAlternateReturn{false}; // when true, "value" is a label number diff --git a/flang/lib/evaluate/common.h b/flang/lib/evaluate/common.h index d91da130d971..63f9950f6c88 100644 --- a/flang/lib/evaluate/common.h +++ b/flang/lib/evaluate/common.h @@ -201,6 +201,7 @@ template class Expr; class FoldingContext { public: + FoldingContext() = default; explicit FoldingContext(const parser::ContextualMessages &m, Rounding round = defaultRounding, bool flush = false) : messages_{m}, rounding_{round}, flushSubnormalsToZero_{flush} {} diff --git a/flang/lib/evaluate/fold.cc b/flang/lib/evaluate/fold.cc index dd4e13ec10c2..1d194f1d14f8 100644 --- a/flang/lib/evaluate/fold.cc +++ b/flang/lib/evaluate/fold.cc @@ -19,6 +19,7 @@ #include "host.h" #include "int-power.h" #include "intrinsics-library-templates.h" +#include "shape.h" #include "tools.h" #include "traversal.h" #include "type.h" @@ -473,13 +474,17 @@ Expr> FoldOperation(FoldingContext &context, } return FoldElementalIntrinsic( context, std::move(funcRef), &Scalar::MERGE_BITS); + } else if (name == "rank") { + // TODO pmk: get rank + } else if (name == "shape") { + // TODO pmk: call GetShape on argument, massage result } // TODO: // ceiling, count, cshift, dot_product, eoshift, // findloc, floor, iachar, iall, iany, iparity, ibits, ichar, image_status, // index, ishftc, lbound, len_trim, matmul, max, maxloc, maxval, merge, min, // minloc, minval, mod, modulo, nint, not, pack, product, reduce, reshape, - // scan, selected_char_kind, selected_int_kind, selected_real_kind, shape, + // scan, selected_char_kind, selected_int_kind, selected_real_kind, // sign, size, spread, sum, transfer, transpose, ubound, unpack, verify } return Expr{std::move(funcRef)}; diff --git a/flang/lib/evaluate/intrinsics.cc b/flang/lib/evaluate/intrinsics.cc index b645385d84c3..f42052b273f1 100644 --- a/flang/lib/evaluate/intrinsics.cc +++ b/flang/lib/evaluate/intrinsics.cc @@ -15,6 +15,7 @@ #include "intrinsics.h" #include "expression.h" #include "fold.h" +#include "shape.h" #include "tools.h" #include "type.h" #include "../common/Fortran.h" @@ -502,6 +503,7 @@ static const IntrinsicInterface genericIntrinsicFunction[]{ {"product", {{"array", SameNumeric, Rank::array}, OptionalDIM, OptionalMASK}, SameNumeric, Rank::dimReduced}, + // TODO pmk: "rank" {"real", {{"a", AnyNumeric, Rank::elementalOrBOZ}, DefaultingKIND}, KINDReal}, {"reduce", @@ -607,7 +609,7 @@ static const IntrinsicInterface genericIntrinsicFunction[]{ // COSHAPE // TODO: Object characteristic inquiry functions // ALLOCATED, ASSOCIATED, EXTENDS_TYPE_OF, IS_CONTIGUOUS, -// PRESENT, RANK, SAME_TYPE, STORAGE_SIZE +// PRESENT, SAME_TYPE, STORAGE_SIZE // TODO: Type inquiry intrinsic functions - these return constants // BIT_SIZE, DIGITS, EPSILON, HUGE, KIND, MAXEXPONENT, MINEXPONENT, // NEW_LINE, PRECISION, RADIX, RANGE, TINY @@ -939,7 +941,7 @@ std::optional IntrinsicInterface::Match( // Check the ranks of the arguments against the intrinsic's interface. const ActualArgument *arrayArg{nullptr}; const ActualArgument *knownArg{nullptr}; - const ActualArgument *shapeArg{nullptr}; + std::optional shapeArgSize; int elementalRank{0}; for (std::size_t j{0}; j < dummies; ++j) { const IntrinsicDummyArgument &d{dummy[std::min(j, dummyArgPatterns - 1)]}; @@ -963,9 +965,21 @@ std::optional IntrinsicInterface::Match( case Rank::scalar: argOk = rank == 0; break; case Rank::vector: argOk = rank == 1; break; case Rank::shape: - CHECK(shapeArg == nullptr); - shapeArg = arg; - argOk = rank == 1 && arg->VectorSize().has_value(); + CHECK(!shapeArgSize.has_value()); + if (rank == 1) { + if (auto shape{GetShape(*arg)}) { + CHECK(shape->size() == 1); + if (auto value{ToInt64(shape->at(0))}) { + shapeArgSize = *value; + argOk = *value >= 0; + } + } + } + if (!argOk) { + messages.Say( + "'shape=' argument must be a vector of known size"_err_en_US); + return std::nullopt; + } break; case Rank::matrix: argOk = rank == 2; break; case Rank::array: @@ -1134,8 +1148,8 @@ std::optional IntrinsicInterface::Match( resultRank = knownArg->Rank() + 1; break; case Rank::shaped: - CHECK(shapeArg != nullptr); - resultRank = shapeArg->VectorSize().value(); + CHECK(shapeArgSize.has_value()); + resultRank = *shapeArgSize; break; case Rank::elementalOrBOZ: case Rank::shape: diff --git a/flang/lib/evaluate/shape.cc b/flang/lib/evaluate/shape.cc index 63b1663ee86a..e17ff1a557a8 100644 --- a/flang/lib/evaluate/shape.cc +++ b/flang/lib/evaluate/shape.cc @@ -19,25 +19,102 @@ #include "../semantics/symbol.h" namespace Fortran::evaluate { + +static Extent GetLowerBound(const semantics::Symbol &symbol, + const Component *component, int dimension) { + if (const auto *details{symbol.detailsIf()}) { + int j{0}; + for (const auto &shapeSpec : details->shape()) { + if (j++ == dimension) { + if (const auto &bound{shapeSpec.lbound().GetExplicit()}) { + return *bound; + } else if (component != nullptr) { + return Expr{DescriptorInquiry{ + *component, DescriptorInquiry::Field::LowerBound, dimension}}; + } else { + return Expr{DescriptorInquiry{ + symbol, DescriptorInquiry::Field::LowerBound, dimension}}; + } + } + } + } + return std::nullopt; +} + +static Extent GetExtent(const semantics::Symbol &symbol, + const Component *component, int dimension) { + if (const auto *details{symbol.detailsIf()}) { + int j{0}; + for (const auto &shapeSpec : details->shape()) { + if (j++ == dimension) { + if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) { + if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) { + FoldingContext noFoldingContext; + return Fold(noFoldingContext, + common::Clone(ubound.value()) - common::Clone(lbound.value()) + + Expr{1}); + } + } + if (component != nullptr) { + return Expr{DescriptorInquiry{ + *component, DescriptorInquiry::Field::Extent, dimension}}; + } else { + return Expr{DescriptorInquiry{ + &symbol, DescriptorInquiry::Field::Extent, dimension}}; + } + } + } + } + return std::nullopt; +} + +static Extent GetExtent(const Subscript &subscript, const Symbol &symbol, + const Component *component, int dimension) { + return std::visit( + common::visitors{ + [&](const Triplet &triplet) -> Extent { + Extent upper{triplet.upper()}; + if (!upper.has_value()) { + upper = GetExtent(symbol, component, dimension); + } + if (upper.has_value()) { + Extent lower{triplet.lower()}; + if (!lower.has_value()) { + lower = GetLowerBound(symbol, component, dimension); + } + if (lower.has_value()) { + auto span{ + (std::move(*upper) - std::move(*lower) + triplet.stride()) / + triplet.stride()}; + Expr extent{ + Extremum{std::move(span), + Expr{0}, Ordering::Greater}}; + FoldingContext noFoldingContext; + return Fold(noFoldingContext, std::move(extent)); + } + } + return std::nullopt; + }, + [](const IndirectSubscriptIntegerExpr &subs) -> Extent { + if (auto shape{GetShape(subs.value())}) { + if (shape->size() > 0) { + CHECK(shape->size() == 1); // vector-valued subscript + return std::move(shape->at(0)); + } + } + return std::nullopt; + }, + }, + subscript.u); +} + std::optional GetShape( const semantics::Symbol &symbol, const Component *component) { if (const auto *details{symbol.detailsIf()}) { Shape result; - int dimension{1}; - for (const auto &shapeSpec : details->shape()) { - if (shapeSpec.isExplicit()) { - result.emplace_back( - common::Clone(shapeSpec.ubound().GetExplicit().value()) - - common::Clone(shapeSpec.lbound().GetExplicit().value()) + - Expr{1}); - } else if (component != nullptr) { - result.emplace_back(Expr{DescriptorInquiry{ - *component, DescriptorInquiry::Field::Extent, dimension}}); - } else { - result.emplace_back(Expr{DescriptorInquiry{ - symbol, DescriptorInquiry::Field::Extent, dimension}}); - } - ++dimension; + int n = details->shape().size(); + for (int dimension{0}; dimension < n; ++dimension) { + result.emplace_back(GetExtent(symbol, component, dimension++)); } return result; } else { @@ -45,6 +122,14 @@ std::optional GetShape( } } +std::optional GetShape(const BaseObject &object) { + if (const Symbol * symbol{object.symbol()}) { + return GetShape(*symbol); + } else { + return Shape{}; + } +} + std::optional GetShape(const Component &component) { const Symbol &symbol{component.GetLastSymbol()}; if (symbol.Rank() > 0) { @@ -53,45 +138,17 @@ std::optional GetShape(const Component &component) { return GetShape(component.base()); } } -static Extent GetExtent(const Subscript &subscript) { - return std::visit( - common::visitors{ - [](const Triplet &triplet) -> Extent { - if (auto lower{triplet.lower()}) { - if (auto lowerValue{ToInt64(*lower)}) { - if (auto upper{triplet.upper()}) { - if (auto upperValue{ToInt64(*upper)}) { - if (auto strideValue{ToInt64(triplet.stride())}) { - if (*strideValue != 0) { - std::int64_t extent{ - (*upperValue - *lowerValue + *strideValue) / - *strideValue}; - return Expr{extent > 0 ? extent : 0}; - } - } - } - } - } - } - return std::nullopt; - }, - [](const IndirectSubscriptIntegerExpr &subs) -> Extent { - if (auto shape{GetShape(subs.value())}) { - if (shape->size() == 1) { - return std::move(shape->at(0)); - } - } - return std::nullopt; - }, - }, - subscript.u); -} + std::optional GetShape(const ArrayRef &arrayRef) { Shape shape; + const Symbol &symbol{arrayRef.GetLastSymbol()}; + const Component *component{std::get_if(&arrayRef.base())}; + int dimension{0}; for (const Subscript &ss : arrayRef.subscript()) { if (ss.Rank() > 0) { - shape.emplace_back(GetExtent(ss)); + shape.emplace_back(GetExtent(ss, symbol, component, dimension)); } + ++dimension; } if (shape.empty()) { return GetShape(arrayRef.base()); @@ -99,12 +156,18 @@ std::optional GetShape(const ArrayRef &arrayRef) { return shape; } } + std::optional GetShape(const CoarrayRef &coarrayRef) { Shape shape; + SymbolOrComponent base{coarrayRef.GetBaseSymbolOrComponent()}; + const Symbol &symbol{coarrayRef.GetLastSymbol()}; + const Component *component{std::get_if(&base)}; + int dimension{0}; for (const Subscript &ss : coarrayRef.subscript()) { if (ss.Rank() > 0) { - shape.emplace_back(GetExtent(ss)); + shape.emplace_back(GetExtent(ss, symbol, component, dimension)); } + ++dimension; } if (shape.empty()) { return GetShape(coarrayRef.GetLastSymbol()); @@ -112,9 +175,11 @@ std::optional GetShape(const CoarrayRef &coarrayRef) { return shape; } } + std::optional GetShape(const DataRef &dataRef) { return std::visit([](const auto &x) { return GetShape(x); }, dataRef.u); } + std::optional GetShape(const Substring &substring) { if (const auto *dataRef{substring.GetParentIf()}) { return GetShape(*dataRef); @@ -122,7 +187,49 @@ std::optional GetShape(const Substring &substring) { return std::nullopt; } } + std::optional GetShape(const ComplexPart &part) { return GetShape(part.complex()); } + +std::optional GetShape(const ActualArgument &arg) { + return GetShape(arg.value()); +} + +std::optional GetShape(const ProcedureRef &call) { + if (call.Rank() == 0) { + return Shape{}; + } else if (call.IsElemental()) { + for (const auto &arg : call.arguments()) { + if (arg.has_value() && arg->Rank() > 0) { + return GetShape(*arg); + } + } + } else if (const Symbol * symbol{call.proc().GetSymbol()}) { + return GetShape(*symbol); + } else if (const auto *intrinsic{ + std::get_if(&call.proc().u)}) { + if (intrinsic->name == "shape" || intrinsic->name == "lbound" || + intrinsic->name == "ubound") { + return Shape{Extent{Expr{ + call.arguments().front().value().value().Rank()}}}; + } + // TODO: shapes of other non-elemental intrinsic results + // esp. reshape, where shape is value of second argument + } + return std::nullopt; +} + +std::optional GetShape(const StructureConstructor &) { + return Shape{}; // always scalar +} + +std::optional GetShape(const BOZLiteralConstant &) { + return Shape{}; // always scalar +} + +std::optional GetShape(const NullPointer &) { + return {}; // not an object +} + } diff --git a/flang/lib/evaluate/shape.h b/flang/lib/evaluate/shape.h index 30e5e92f8903..44829b66ad59 100644 --- a/flang/lib/evaluate/shape.h +++ b/flang/lib/evaluate/shape.h @@ -30,10 +30,69 @@ using Extent = std::optional>; using Shape = std::vector; template std::optional GetShape(const A &) { - return std::nullopt; + return std::nullopt; // default case } -template std::optional GetShape(const Expr &); +// Forward declarations +template +std::optional GetShape(const std::variant &); +template +std::optional GetShape(const common::Indirection &); +template std::optional GetShape(const std::optional &); + +template std::optional GetShape(const Expr &expr) { + return GetShape(expr.u); +} + +std::optional GetShape( + const semantics::Symbol &, const Component * = nullptr); +std::optional GetShape(const BaseObject &); +std::optional GetShape(const Component &); +std::optional GetShape(const ArrayRef &); +std::optional GetShape(const CoarrayRef &); +std::optional GetShape(const DataRef &); +std::optional GetShape(const Substring &); +std::optional GetShape(const ComplexPart &); +std::optional GetShape(const ActualArgument &); +std::optional GetShape(const ProcedureRef &); +std::optional GetShape(const StructureConstructor &); +std::optional GetShape(const BOZLiteralConstant &); +std::optional GetShape(const NullPointer &); + +template +std::optional GetShape(const Designator &designator) { + return GetShape(designator.u); +} + +template +std::optional GetShape(const Variable &variable) { + return GetShape(variable.u); +} + +template +std::optional GetShape(const Operation &operation) { + if constexpr (operation.operands > 1) { + if (operation.right().Rank() > 0) { + return GetShape(operation.right()); + } + } + return GetShape(operation.left()); +} + +template +std::optional GetShape(const TypeParamInquiry &) { + return Shape{}; // always scalar +} + +template +std::optional GetShape(const ArrayConstructorValues &aconst) { + return std::nullopt; // TODO pmk much more here!! +} + +template +std::optional GetShape(const std::variant &u) { + return std::visit([](const auto &x) { return GetShape(x); }, u); +} template std::optional GetShape(const common::Indirection &p) { @@ -47,34 +106,5 @@ template std::optional GetShape(const std::optional &x) { return std::nullopt; } } - -template -std::optional GetShape(const std::variant &u) { - return std::visit([](const auto &x) { return GetShape(x); }, u); -} - -std::optional GetShape( - const semantics::Symbol &, const Component * = nullptr); -std::optional GetShape(const DataRef &); -std::optional GetShape(const ComplexPart &); -std::optional GetShape(const Substring &); -std::optional GetShape(const Component &); -std::optional GetShape(const ArrayRef &); -std::optional GetShape(const CoarrayRef &); - -template -std::optional GetShape(const Designator &designator) { - return std::visit([](const auto &x) { return GetShape(x); }, designator.u); -} - -template std::optional GetShape(const Expr &expr) { - return std::visit( - common::visitors{ - [](const BOZLiteralConstant &) { return Shape{}; }, - [](const NullPointer &) { return std::nullopt; }, - [](const auto &x) { return GetShape(x); }, - }, - expr.u); -} } #endif // FORTRAN_EVALUATE_SHAPE_H_ diff --git a/flang/lib/evaluate/variable.cc b/flang/lib/evaluate/variable.cc index 8f1a7e88b914..abcb8b2a3d48 100644 --- a/flang/lib/evaluate/variable.cc +++ b/flang/lib/evaluate/variable.cc @@ -57,9 +57,7 @@ std::optional> Triplet::upper() const { return std::nullopt; } -const Expr &Triplet::stride() const { - return stride_.value(); -} +Expr Triplet::stride() const { return stride_.value(); } bool Triplet::IsStrideOne() const { if (auto stride{ToInt64(stride_.value())}) { @@ -359,18 +357,18 @@ int ArrayRef::Rank() const { } return std::visit( common::visitors{ - [=](const Symbol *s) { return s->Rank(); }, + [=](const Symbol *s) { return 0; }, [=](const Component &c) { return c.Rank(); }, }, base_); } int CoarrayRef::Rank() const { - int rank{0}; - for (const auto &expr : subscript_) { - rank += expr.Rank(); - } - if (rank > 0) { + if (!subscript_.empty()) { + int rank{0}; + for (const auto &expr : subscript_) { + rank += expr.Rank(); + } return rank; } else { return base_.back()->Rank(); @@ -519,6 +517,21 @@ template std::optional Designator::GetType() const { } } +SymbolOrComponent CoarrayRef::GetBaseSymbolOrComponent() const { + SymbolOrComponent base{base_.front()}; + int j{0}; + for (const Symbol *symbol : base_) { + if (j == 0) { // X - already captured the symbol above + } else if (j == 1) { // X%Y + base = Component{DataRef{std::get(base)}, *symbol}; + } else { // X%Y%Z or more + base = Component{DataRef{std::move(std::get(base))}, *symbol}; + } + ++j; + } + return base; +} + // Equality testing bool BaseObject::operator==(const BaseObject &that) const { diff --git a/flang/lib/evaluate/variable.h b/flang/lib/evaluate/variable.h index 14ddf5f53926..468145db322e 100644 --- a/flang/lib/evaluate/variable.h +++ b/flang/lib/evaluate/variable.h @@ -142,7 +142,7 @@ public: std::optional> &&); std::optional> lower() const; std::optional> upper() const; - const Expr &stride() const; + Expr stride() const; bool operator==(const Triplet &) const; bool IsStrideOne() const; std::ostream &AsFortran(std::ostream &) const; @@ -237,6 +237,7 @@ public: int Rank() const; const Symbol &GetFirstSymbol() const; const Symbol &GetLastSymbol() const; + SymbolOrComponent GetBaseSymbolOrComponent() const; Expr LEN() const; bool operator==(const CoarrayRef &) const; std::ostream &AsFortran(std::ostream &) const; @@ -404,7 +405,7 @@ public: private: SymbolOrComponent base_{nullptr}; Field field_; - int dimension_{0}; + int dimension_{0}; // zero-based }; #define INSTANTIATE_VARIABLE_TEMPLATES \ diff --git a/flang/lib/parser/message.h b/flang/lib/parser/message.h index 771bc8b6d986..b26b8c39fd59 100644 --- a/flang/lib/parser/message.h +++ b/flang/lib/parser/message.h @@ -241,6 +241,7 @@ private: class ContextualMessages { public: + ContextualMessages() = default; ContextualMessages(CharBlock at, Messages *m) : at_{at}, messages_{m} {} ContextualMessages(const ContextualMessages &that) : at_{that.at_}, messages_{that.messages_} {}