diff --git a/flang/lib/evaluate/call.cc b/flang/lib/evaluate/call.cc index ffb73335098f..7f9f1556f3ff 100644 --- a/flang/lib/evaluate/call.cc +++ b/flang/lib/evaluate/call.cc @@ -14,19 +14,45 @@ #include "call.h" #include "expression.h" +#include "tools.h" +#include "../common/idioms.h" #include "../semantics/symbol.h" namespace Fortran::evaluate { -std::optional ActualArgument::GetType() const { - return value().GetType(); +ActualArgument::AssumedType::AssumedType(const semantics::Symbol &symbol) + : symbol_{&symbol} { + const semantics::DeclTypeSpec *type{symbol.GetType()}; + CHECK( + type != nullptr && type->category() == semantics::DeclTypeSpec::TypeStar); } -int ActualArgument::Rank() const { return value().Rank(); } +int ActualArgument::AssumedType::Rank() const { return symbol_->Rank(); } + +ActualArgument &ActualArgument::operator=(Expr &&expr) { + u_ = std::move(expr); + return *this; +} + +std::optional ActualArgument::GetType() const { + if (const auto *expr{GetExpr()}) { + return expr->GetType(); + } else { + return std::nullopt; + } +} + +int ActualArgument::Rank() const { + if (const auto *expr{GetExpr()}) { + return expr->Rank(); + } else { + return std::get(u_).Rank(); + } +} bool ActualArgument::operator==(const ActualArgument &that) const { return keyword == that.keyword && - isAlternateReturn == that.isAlternateReturn && value() == that.value(); + isAlternateReturn == that.isAlternateReturn && u_ == that.u_; } bool SpecificIntrinsic::operator==(const SpecificIntrinsic &that) const { @@ -80,8 +106,28 @@ const Symbol *ProcedureDesignator::GetSymbol() const { } Expr ProcedureRef::LEN() const { - // TODO: the results of the intrinsic functions REPEAT and TRIM have - // unpredictable lengths; maybe the concept of LEN() has to become dynamic + if (const auto *intrinsic{std::get_if(&proc_.u)}) { + if (intrinsic->name == "repeat") { + // LEN(REPEAT(ch,n)) == LEN(ch) * n + CHECK(arguments_.size() == 2); + const auto *stringArg{ + UnwrapExpr>(arguments_[0].value())}; + const auto *nCopiesArg{ + UnwrapExpr>(arguments_[1].value())}; + CHECK(stringArg != nullptr && nCopiesArg != nullptr); + auto stringLen{stringArg->LEN()}; + return std::move(stringLen) * + ConvertTo(stringLen, common::Clone(*nCopiesArg)); + } + if (intrinsic->name == "trim") { + // LEN(TRIM(ch)) is unknown without execution. + CHECK(arguments_.size() == 1); + const auto *stringArg{ + UnwrapExpr>(arguments_[0].value())}; + CHECK(stringArg != nullptr); + return stringArg->LEN(); + } + } return proc_.LEN(); } diff --git a/flang/lib/evaluate/call.h b/flang/lib/evaluate/call.h index 1c9bc396cc43..d9dc8968d263 100644 --- a/flang/lib/evaluate/call.h +++ b/flang/lib/evaluate/call.h @@ -41,12 +41,55 @@ namespace Fortran::evaluate { class ActualArgument { public: - explicit ActualArgument(Expr &&x) : value_{std::move(x)} {} - explicit ActualArgument(common::CopyableIndirection> &&v) - : value_{std::move(v)} {} + // Dummy arguments that are TYPE(*) can be forwarded as actual arguments. + // Since that's the only thing one may do with them in Fortran, they're + // represented in expressions as a special case of an actual argument. + class AssumedType { + public: + explicit AssumedType(const semantics::Symbol &); + DEFAULT_CONSTRUCTORS_AND_ASSIGNMENTS(AssumedType) + const semantics::Symbol &symbol() const { return *symbol_; } + int Rank() const; + bool operator==(const AssumedType &that) const { + return symbol_ == that.symbol_; + } + std::ostream &AsFortran(std::ostream &) const; - Expr &value() { return value_.value(); } - const Expr &value() const { return value_.value(); } + private: + const semantics::Symbol *symbol_; + }; + + explicit ActualArgument(Expr &&x) : u_{std::move(x)} {} + explicit ActualArgument(common::CopyableIndirection> &&v) + : u_{std::move(v)} {} + explicit ActualArgument(AssumedType x) : u_{x} {} + + ActualArgument &operator=(Expr &&); + + Expr *GetExpr() { + if (auto *p{ + std::get_if>>(&u_)}) { + return &p->value(); + } else { + return nullptr; + } + } + const Expr *GetExpr() const { + if (const auto *p{ + std::get_if>>(&u_)}) { + return &p->value(); + } else { + return nullptr; + } + } + + const semantics::Symbol *GetAssumedTypeDummy() const { + if (const AssumedType * aType{std::get_if(&u_)}) { + return &aType->symbol(); + } else { + return nullptr; + } + } std::optional GetType() const; int Rank() const; @@ -64,7 +107,7 @@ private: // e.g. between X and (X). The parser attempts to parse each argument // first as a variable, then as an expression, and the distinction appears // in the parse tree. - common::CopyableIndirection> value_; + std::variant>, AssumedType> u_; }; using ActualArguments = std::vector>; diff --git a/flang/lib/evaluate/constant.cc b/flang/lib/evaluate/constant.cc index 05c4f0356deb..90eb113fbd0c 100644 --- a/flang/lib/evaluate/constant.cc +++ b/flang/lib/evaluate/constant.cc @@ -19,16 +19,41 @@ namespace Fortran::evaluate { +std::size_t TotalElementCount(const ConstantSubscripts &shape) { + std::size_t size{1}; + for (auto dim : shape) { + CHECK(dim >= 0); + size *= dim; + } + return size; +} + +bool IncrementSubscripts( + ConstantSubscripts &indices, const ConstantSubscripts &shape) { + auto rank{shape.size()}; + CHECK(indices.size() == rank); + for (std::size_t j{0}; j < rank; ++j) { + CHECK(indices[j] >= 1); + if (++indices[j] <= shape[j]) { + return true; + } else { + CHECK(indices[j] == shape[j] + 1); + indices[j] = 1; + } + } + return false; // all done +} + template ConstantBase::~ConstantBase() {} -static std::int64_t SubscriptsToOffset(const std::vector &index, - const std::vector &shape) { +static ConstantSubscript SubscriptsToOffset( + const ConstantSubscripts &index, const ConstantSubscripts &shape) { CHECK(index.size() == shape.size()); - std::int64_t stride{1}, offset{0}; + ConstantSubscript stride{1}, offset{0}; int dim{0}; - for (std::int64_t j : index) { - std::int64_t bound{shape[dim++]}; + for (auto j : index) { + auto bound{shape[dim++]}; CHECK(j >= 1 && j <= bound); offset += stride * (j - 1); stride *= bound; @@ -37,26 +62,26 @@ static std::int64_t SubscriptsToOffset(const std::vector &index, } template -auto ConstantBase::At( - const std::vector &index) const -> ScalarValue { +auto ConstantBase::At(const ConstantSubscripts &index) const + -> ScalarValue { return values_.at(SubscriptsToOffset(index, shape_)); } template -auto ConstantBase::At(std::vector &&index) const +auto ConstantBase::At(ConstantSubscripts &&index) const -> ScalarValue { return values_.at(SubscriptsToOffset(index, shape_)); } static Constant ShapeAsConstant( - const std::vector &shape) { + const ConstantSubscripts &shape) { using IntType = Scalar; std::vector result; - for (std::int64_t dim : shape) { + for (auto dim : shape) { result.emplace_back(dim); } return {std::move(result), - std::vector{static_cast(shape.size())}}; + ConstantSubscripts{static_cast(shape.size())}}; } template @@ -76,7 +101,7 @@ Constant>::Constant(ScalarValue &&str) template Constant>::Constant(std::int64_t len, - std::vector &&strings, std::vector &&dims) + std::vector &&strings, ConstantSubscripts &&dims) : length_{len}, shape_{std::move(dims)} { values_.assign(strings.size() * length_, static_cast(' ')); @@ -95,8 +120,8 @@ Constant>::Constant(std::int64_t len, template Constant>::~Constant() {} -static std::int64_t ShapeElements(const std::vector &shape) { - std::int64_t elements{1}; +static ConstantSubscript ShapeElements(const ConstantSubscripts &shape) { + ConstantSubscript elements{1}; for (auto dim : shape) { elements *= dim; } @@ -119,7 +144,7 @@ std::size_t Constant>::size() const { template auto Constant>::At( - const std::vector &index) const -> ScalarValue { + const ConstantSubscripts &index) const -> ScalarValue { auto offset{SubscriptsToOffset(index, shape_)}; return values_.substr(offset, length_); } @@ -138,10 +163,10 @@ Constant::Constant(StructureConstructor &&x) : Base{std::move(x.values())}, derivedTypeSpec_{&x.derivedTypeSpec()} {} Constant::Constant(const semantics::DerivedTypeSpec &spec, - std::vector &&x, std::vector &&s) + std::vector &&x, ConstantSubscripts &&s) : Base{std::move(x), std::move(s)}, derivedTypeSpec_{&spec} {} -static std::vector GetValues( +static std::vector AcquireValues( std::vector &&x) { std::vector result; for (auto &&structure : std::move(x)) { @@ -151,8 +176,8 @@ static std::vector GetValues( } Constant::Constant(const semantics::DerivedTypeSpec &spec, - std::vector &&x, std::vector &&s) - : Base{GetValues(std::move(x)), std::move(s)}, derivedTypeSpec_{&spec} {} + std::vector &&x, ConstantSubscripts &&s) + : Base{AcquireValues(std::move(x)), std::move(s)}, derivedTypeSpec_{&spec} {} INSTANTIATE_CONSTANT_TEMPLATES } diff --git a/flang/lib/evaluate/constant.h b/flang/lib/evaluate/constant.h index 4ef0fd8c563d..80a19f940d52 100644 --- a/flang/lib/evaluate/constant.h +++ b/flang/lib/evaluate/constant.h @@ -32,6 +32,24 @@ namespace Fortran::evaluate { template class Constant; +// When describing shapes of constants or specifying 1-based subscript +// values as indices into constants, use a vector of integers. +using ConstantSubscript = std::int64_t; +using ConstantSubscripts = std::vector; + +std::size_t TotalElementCount(const ConstantSubscripts &); + +inline ConstantSubscripts InitialSubscripts(int rank) { + return ConstantSubscripts(rank, 1); // parens, not braces: "rank" copies of 1 +} +inline ConstantSubscripts InitialSubscripts(const ConstantSubscripts &shape) { + return InitialSubscripts(static_cast(shape.size())); +} + +// Increments a vector of subscripts in Fortran array order (first dimension +// varying most quickly). Returns false when last element was visited. +bool IncrementSubscripts(ConstantSubscripts &, const ConstantSubscripts &shape); + // Constant<> is specialized for Character kinds and SomeDerived. // The non-Character intrinsic types, and SomeDerived, share enough // common behavior that they use this common base class. @@ -45,7 +63,7 @@ public: template ConstantBase(const A &x) : values_{x} {} template> ConstantBase(A &&x) : values_{std::move(x)} {} - ConstantBase(std::vector &&x, std::vector &&dims) + ConstantBase(std::vector &&x, ConstantSubscripts &&dims) : values_(std::move(x)), shape_(std::move(dims)) {} DEFAULT_CONSTRUCTORS_AND_ASSIGNMENTS(ConstantBase) ~ConstantBase(); @@ -57,7 +75,8 @@ public: bool empty() const { return values_.empty(); } std::size_t size() const { return values_.size(); } const std::vector &values() const { return values_; } - const std::vector &shape() const { return shape_; } + const ConstantSubscripts &shape() const { return shape_; } + ConstantSubscripts &shape() { return shape_; } ScalarValue operator*() const { CHECK(values_.size() == 1); @@ -65,15 +84,15 @@ public: } // Apply 1-based subscripts - ScalarValue At(const std::vector &) const; - ScalarValue At(std::vector &&) const; + ScalarValue At(const ConstantSubscripts &) const; + ScalarValue At(ConstantSubscripts &&) const; Constant SHAPE() const; std::ostream &AsFortran(std::ostream &) const; protected: std::vector values_; - std::vector shape_; + ConstantSubscripts shape_; private: const Constant &AsConstant() const { @@ -96,11 +115,11 @@ template class Constant> { public: using Result = Type; using ScalarValue = Scalar; + CLASS_BOILERPLATE(Constant) explicit Constant(const ScalarValue &); explicit Constant(ScalarValue &&); - Constant( - std::int64_t, std::vector &&, std::vector &&); + Constant(std::int64_t, std::vector &&, ConstantSubscripts &&); ~Constant(); int Rank() const { return static_cast(shape_.size()); } @@ -109,7 +128,8 @@ public: } bool empty() const; std::size_t size() const; - const std::vector &shape() const { return shape_; } + const ConstantSubscripts &shape() const { return shape_; } + ConstantSubscripts &shape() { return shape_; } std::int64_t LEN() const { return length_; } @@ -119,7 +139,7 @@ public: } // Apply 1-based subscripts - ScalarValue At(const std::vector &) const; + ScalarValue At(const ConstantSubscripts &) const; Constant SHAPE() const; std::ostream &AsFortran(std::ostream &) const; @@ -128,7 +148,7 @@ public: private: ScalarValue values_; // one contiguous string std::int64_t length_; - std::vector shape_; + ConstantSubscripts shape_; }; using StructureConstructorValues = std::map public: using Result = SomeDerived; using Base = ConstantBase; + Constant(const StructureConstructor &); Constant(StructureConstructor &&); Constant(const semantics::DerivedTypeSpec &, std::vector &&, - std::vector &&); + ConstantSubscripts &&); Constant(const semantics::DerivedTypeSpec &, - std::vector &&, std::vector &&); + std::vector &&, ConstantSubscripts &&); CLASS_BOILERPLATE(Constant) const semantics::DerivedTypeSpec &derivedTypeSpec() const { diff --git a/flang/lib/evaluate/descender.h b/flang/lib/evaluate/descender.h index d4d8c255a13a..7a46c2cb895c 100644 --- a/flang/lib/evaluate/descender.h +++ b/flang/lib/evaluate/descender.h @@ -117,10 +117,14 @@ public: } template void Descend(const ArrayConstructorValues &avs) { - Visit(avs.values()); + for (const auto &x : avs) { + Visit(x); + } } template void Descend(ArrayConstructorValues &avs) { - Visit(avs.values()); + for (auto &x : avs) { + Visit(x); + } } template @@ -155,13 +159,13 @@ public: void Descend(const StructureConstructor &sc) { Visit(sc.derivedTypeSpec()); - for (const auto &pair : sc.values()) { + for (const auto &pair : sc) { Visit(pair.second); } } void Descend(StructureConstructor &sc) { Visit(sc.derivedTypeSpec()); - for (const auto &pair : sc.values()) { + for (const auto &pair : sc) { Visit(pair.second); } } @@ -241,8 +245,22 @@ public: template void Descend(const Variable &var) { Visit(var.u); } template void Descend(Variable &var) { Visit(var.u); } - void Descend(const ActualArgument &arg) { Visit(arg.value()); } - void Descend(ActualArgument &arg) { Visit(arg.value()); } + void Descend(const ActualArgument &arg) { + if (const auto *expr{arg.GetExpr()}) { + Visit(*expr); + } else { + const semantics::Symbol *aType{arg.GetAssumedTypeDummy()}; + Visit(*aType); + } + } + void Descend(ActualArgument &arg) { + if (auto *expr{arg.GetExpr()}) { + Visit(*expr); + } else { + const semantics::Symbol *aType{arg.GetAssumedTypeDummy()}; + Visit(*aType); + } + } void Descend(const ProcedureDesignator &p) { Visit(p.u); } void Descend(ProcedureDesignator &p) { Visit(p.u); } diff --git a/flang/lib/evaluate/expression.h b/flang/lib/evaluate/expression.h index dd54778a779e..6bdc513f7d8c 100644 --- a/flang/lib/evaluate/expression.h +++ b/flang/lib/evaluate/expression.h @@ -450,18 +450,26 @@ public: using Values = std::vector>; DEFAULT_CONSTRUCTORS_AND_ASSIGNMENTS(ArrayConstructorValues) ArrayConstructorValues() {} + bool operator==(const ArrayConstructorValues &) const; static constexpr int Rank() { return 1; } template common::NoLvalue Push(A &&x) { values_.emplace_back(std::move(x)); } - Values &values() { return values_; } - const Values &values() const { return values_; } + + typename Values::iterator begin() { return values_.begin(); } + typename Values::const_iterator begin() const { return values_.begin(); } + typename Values::iterator end() { return values_.end(); } + typename Values::const_iterator end() const { return values_.end(); } protected: Values values_; }; +// Note that there are specializations of ArrayConstructor for character +// and derived types, since they must carry additional type information, +// but that an empty ArrayConstructor can be constructed for any type +// given an expression from which such type information may be gleaned. template class ArrayConstructor : public ArrayConstructorValues { public: @@ -469,6 +477,7 @@ public: using Base = ArrayConstructorValues; DEFAULT_CONSTRUCTORS_AND_ASSIGNMENTS(ArrayConstructor) explicit ArrayConstructor(Base &&values) : Base{std::move(values)} {} + template explicit ArrayConstructor(const Expr &) {} static constexpr DynamicType GetType() { return Result::GetType(); } std::ostream &AsFortran(std::ostream &) const; }; @@ -482,6 +491,8 @@ public: CLASS_BOILERPLATE(ArrayConstructor) ArrayConstructor(Expr &&len, Base &&v) : Base{std::move(v)}, length_{std::move(len)} {} + template + explicit ArrayConstructor(const Expr &proto) : length_{proto.LEN()} {} bool operator==(const ArrayConstructor &) const; static constexpr DynamicType GetType() { return Result::GetType(); } std::ostream &AsFortran(std::ostream &) const; @@ -500,6 +511,11 @@ public: CLASS_BOILERPLATE(ArrayConstructor) ArrayConstructor(const semantics::DerivedTypeSpec &spec, Base &&v) : Base{std::move(v)}, derivedTypeSpec_{&spec} {} + template + explicit ArrayConstructor(const Expr &proto) + : derivedTypeSpec_{GetType(proto).derived} { + CHECK(derivedTypeSpec_ != nullptr); + } bool operator==(const ArrayConstructor &) const; const semantics::DerivedTypeSpec &derivedTypeSpec() const { return *derivedTypeSpec_; @@ -715,8 +731,18 @@ public: } StructureConstructorValues &values() { return values_; } const StructureConstructorValues &values() const { return values_; } + bool operator==(const StructureConstructor &) const; + StructureConstructorValues::iterator begin() { return values_.begin(); } + StructureConstructorValues::const_iterator begin() const { + return values_.begin(); + } + StructureConstructorValues::iterator end() { return values_.end(); } + StructureConstructorValues::const_iterator end() const { + return values_.end(); + } + StructureConstructor &Add(const semantics::Symbol &, Expr &&); int Rank() const { return 0; } DynamicType GetType() const; diff --git a/flang/lib/evaluate/fold.cc b/flang/lib/evaluate/fold.cc index 44d5831865ca..ea73249e588c 100644 --- a/flang/lib/evaluate/fold.cc +++ b/flang/lib/evaluate/fold.cc @@ -193,12 +193,12 @@ static inline Expr FoldElementalIntrinsicHelper(FoldingContext &context, (... && IsSpecificIntrinsicType)); // TODO derived types for MERGE? static_assert(sizeof...(TA) > 0); std::tuple *...> args{ - UnwrapExpr>(funcRef.arguments()[I].value().value())...}; + UnwrapExpr>(*funcRef.arguments()[I].value().GetExpr())...}; if ((... && (std::get(args) != nullptr))) { // Compute the shape of the result based on shapes of arguments - std::vector shape; + ConstantSubscripts shape; int rank{0}; - const std::vector *shapes[sizeof...(TA)]{ + const ConstantSubscripts *shapes[sizeof...(TA)]{ &std::get(args)->shape()...}; const int ranks[sizeof...(TA)]{std::get(args)->Rank()...}; for (unsigned int i{0}; i < sizeof...(TA); ++i) { @@ -222,29 +222,21 @@ static inline Expr FoldElementalIntrinsicHelper(FoldingContext &context, CHECK(rank == static_cast(shape.size())); // Compute all the scalar values of the results - std::size_t size{1}; - for (std::int64_t dim : shape) { - size *= dim; - } std::vector> results; - std::vector index(rank, 1); - for (std::size_t n{size}; n-- > 0;) { - if constexpr (std::is_same_v, - ScalarFuncWithContext>) { - results.emplace_back(func(context, - (ranks[I] ? std::get(args)->At(index) - : **std::get(args))...)); - } else if constexpr (std::is_same_v, - ScalarFunc>) { - results.emplace_back(func(( - ranks[I] ? std::get(args)->At(index) : **std::get(args))...)); - } - for (int d{0}; d < rank; ++d) { - if (++index[d] <= shape[d]) { - break; + if (TotalElementCount(shape) > 0) { + ConstantSubscripts index{InitialSubscripts(rank)}; + do { + if constexpr (std::is_same_v, + ScalarFuncWithContext>) { + results.emplace_back(func(context, + (ranks[I] ? std::get(args)->At(index) + : **std::get(args))...)); + } else if constexpr (std::is_same_v, + ScalarFunc>) { + results.emplace_back(func((ranks[I] ? std::get(args)->At(index) + : **std::get(args))...)); } - index[d] = 1; - } + } while (IncrementSubscripts(index, shape)); } // Build and return constant result if constexpr (TR::category == TypeCategory::Character) { @@ -273,12 +265,21 @@ static Expr FoldElementalIntrinsic(FoldingContext &context, template static Expr *UnwrapArgument(std::optional &arg) { - return UnwrapExpr>(arg.value().value()); + if (arg.has_value()) { + if (Expr * expr{arg->GetExpr()}) { + return UnwrapExpr>(*expr); + } + } + return nullptr; } static BOZLiteralConstant *UnwrapBozArgument( std::optional &arg) { - return std::get_if(&arg.value().value().u); + if (auto *expr{UnwrapArgument(arg)}) { + return std::get_if(&expr->u); + } else { + return nullptr; + } } template @@ -287,9 +288,8 @@ Expr> FoldOperation(FoldingContext &context, using T = Type; ActualArguments &args{funcRef.arguments()}; for (std::optional &arg : args) { - if (arg.has_value()) { - arg.value().value() = - FoldOperation(context, std::move(arg.value().value())); + if (auto *expr{UnwrapArgument(arg)}) { + *expr = FoldOperation(context, std::move(*expr)); } } if (auto *intrinsic{std::get_if(&funcRef.proc().u)}) { @@ -311,8 +311,8 @@ Expr> FoldOperation(FoldingContext &context, // convert boz for (int i{0}; i <= 1; ++i) { if (auto *x{UnwrapBozArgument(args[i])}) { - args[i].value().value() = - Fold(context, ConvertToType(std::move(*x))); + *args[i] = + AsGenericExpr(Fold(context, ConvertToType(std::move(*x)))); } } // Third argument can be of any kind. However, it must be smaller or equal @@ -320,8 +320,8 @@ Expr> FoldOperation(FoldingContext &context, using Int4 = Type; if (auto *n{UnwrapArgument(args[2])}) { if (n->GetType()->kind != 4) { - args[2].value().value() = - Fold(context, ConvertToType(std::move(*n))); + *args[2] = + AsGenericExpr(Fold(context, ConvertToType(std::move(*n)))); } } const auto fptr{ @@ -349,8 +349,8 @@ Expr> FoldOperation(FoldingContext &context, // convert boz for (int i{0}; i <= 1; ++i) { if (auto *x{UnwrapBozArgument(args[i])}) { - args[i].value().value() = - Fold(context, ConvertToType(std::move(*x))); + *args[i] = + AsGenericExpr(Fold(context, ConvertToType(std::move(*x)))); } } auto fptr{&Scalar::IAND}; @@ -371,8 +371,8 @@ Expr> FoldOperation(FoldingContext &context, using Int4 = Type; if (auto *n{UnwrapArgument(args[1])}) { if (n->GetType()->kind != 4) { - args[1].value().value() = - Fold(context, ConvertToType(std::move(*n))); + *args[1] = + AsGenericExpr(Fold(context, ConvertToType(std::move(*n)))); } } auto fptr{&Scalar::IBCLR}; @@ -396,18 +396,20 @@ Expr> FoldOperation(FoldingContext &context, return std::invoke(fptr, i, static_cast(pos.ToInt64())); })); } else if (name == "int") { - return std::visit( - [&](auto &&x) -> Expr { - using From = std::decay_t; - if constexpr (std::is_same_v || - std::is_same_v> || - std::is_same_v> || - std::is_same_v>) { - return Fold(context, ConvertToType(std::move(x))); - } - common::die("int() argument type not valid"); - }, - std::move(args[0].value().value().u)); + if (auto *expr{args[0].value().GetExpr()}) { + return std::visit( + [&](auto &&x) -> Expr { + using From = std::decay_t; + if constexpr (std::is_same_v || + std::is_same_v> || + std::is_same_v> || + std::is_same_v>) { + return Fold(context, ConvertToType(std::move(x))); + } + common::die("int() argument type not valid"); + }, + std::move(expr->u)); + } } else if (name == "kind") { if constexpr (common::HasMember) { return Expr{args[0].value().GetType()->kind}; @@ -466,8 +468,8 @@ Expr> FoldOperation(FoldingContext &context, using Int4 = Type; if (auto *n{UnwrapArgument(args[0])}) { if (n->GetType()->kind != 4) { - args[0].value().value() = - Fold(context, ConvertToType(std::move(*n))); + *args[0] = + AsGenericExpr(Fold(context, ConvertToType(std::move(*n)))); } } const auto fptr{name == "maskl" ? &Scalar::MASKL : &Scalar::MASKR}; @@ -479,8 +481,8 @@ Expr> FoldOperation(FoldingContext &context, // convert boz for (int i{0}; i <= 2; ++i) { if (auto *x{UnwrapBozArgument(args[i])}) { - args[i].value().value() = - Fold(context, ConvertToType(std::move(*x))); + *args[i] = + AsGenericExpr(Fold(context, ConvertToType(std::move(*x)))); } } return FoldElementalIntrinsic( @@ -489,24 +491,26 @@ Expr> FoldOperation(FoldingContext &context, // TODO assumed-rank dummy argument return Expr{args[0].value().Rank()}; } else if (name == "shape") { - if (auto shape{GetShape(args[0].value())}) { - if (auto shapeExpr{AsShapeArrayExpr(*shape)}) { + if (auto shape{GetShape(context, args[0].value())}) { + if (auto shapeExpr{AsExtentArrayExpr(*shape)}) { return Fold(context, ConvertToType(std::move(*shapeExpr))); } } } else if (name == "size") { - if (auto shape{GetShape(args[0].value())}) { + if (auto shape{GetShape(context, args[0].value())}) { if (auto &dimArg{args[1]}) { // DIM= is present, get one extent - if (auto dim{ToInt64(dimArg->value())}) { - std::int64_t rank = shape->size(); - if (*dim >= 1 && *dim <= rank) { - if (auto &extent{shape->at(*dim - 1)}) { - return Fold(context, ConvertToType(std::move(*extent))); + if (auto *expr{dimArg->GetExpr()}) { + if (auto dim{ToInt64(*expr)}) { + std::int64_t rank = shape->size(); + if (*dim >= 1 && *dim <= rank) { + if (auto &extent{shape->at(*dim - 1)}) { + return Fold(context, ConvertToType(std::move(*extent))); + } + } else { + context.messages().Say( + "size(array,dim=%jd) dimension is out of range for rank-%d array"_en_US, + static_cast(*dim), static_cast(rank)); } - } else { - context.messages().Say( - "size(array,dim=%jd) dimension is out of range for rank-%d array"_en_US, - static_cast(*dim), static_cast(rank)); } } } else if (auto extents{ @@ -539,8 +543,9 @@ Expr> FoldOperation(FoldingContext &context, ActualArguments &args{funcRef.arguments()}; for (std::optional &arg : args) { if (arg.has_value()) { - arg.value().value() = - FoldOperation(context, std::move(arg.value().value())); + if (auto *expr{arg->GetExpr()}) { + *expr = FoldOperation(context, std::move(*expr)); + } } } if (auto *intrinsic{std::get_if(&funcRef.proc().u)}) { @@ -584,8 +589,8 @@ Expr> FoldOperation(FoldingContext &context, using Int4 = Type; if (auto *n{UnwrapArgument(args[0])}) { if (n->GetType()->kind != 4) { - args[0].value().value() = - Fold(context, ConvertToType(std::move(*n))); + *args[0] = AsGenericExpr( + Fold(context, ConvertToType(std::move(*n)))); } } if (auto callable{ @@ -624,8 +629,8 @@ Expr> FoldOperation(FoldingContext &context, // Convert argument to the requested kind before calling aint if (auto *x{UnwrapArgument(args[0])}) { if (!(x->GetType()->kind == T::kind)) { - args[0].value().value() = - Fold(context, ConvertToType(std::move(*x))); + *args[0] = + AsGenericExpr(Fold(context, ConvertToType(std::move(*x)))); } } return FoldElementalIntrinsic(context, std::move(funcRef), @@ -649,25 +654,27 @@ Expr> FoldOperation(FoldingContext &context, } else if (name == "epsilon") { return Expr{Constant{Scalar::EPSILON()}}; } else if (name == "real") { - return std::visit( - [&](auto &&x) -> Expr { - using From = std::decay_t; - if constexpr (std::is_same_v) { - typename T::Scalar::Word::ValueWithOverflow result{ - T::Scalar::Word::ConvertUnsigned(x)}; - if (result.overflow) { // C1601 - context.messages().Say( - "Non null truncated bits of boz literal constant in REAL intrinsic"_en_US); + if (auto *expr{args[0].value().GetExpr()}) { + return std::visit( + [&](auto &&x) -> Expr { + using From = std::decay_t; + if constexpr (std::is_same_v) { + typename T::Scalar::Word::ValueWithOverflow result{ + T::Scalar::Word::ConvertUnsigned(x)}; + if (result.overflow) { // C1601 + context.messages().Say( + "Non null truncated bits of boz literal constant in REAL intrinsic"_en_US); + } + return Expr{Constant{Scalar(std::move(result.value))}}; + } else if constexpr (std::is_same_v> || + std::is_same_v> || + std::is_same_v>) { + return Fold(context, ConvertToType(std::move(x))); } - return Expr{Constant{Scalar(std::move(result.value))}}; - } else if constexpr (std::is_same_v> || - std::is_same_v> || - std::is_same_v>) { - return Fold(context, ConvertToType(std::move(x))); - } - common::die("real() argument type not valid"); - }, - std::move(args[0].value().value().u)); + common::die("real() argument type not valid"); + }, + std::move(expr->u)); + } } // TODO: anint, cshift, dim, dot_product, eoshift, fraction, huge, matmul, // max, maxval, merge, min, minval, modulo, nearest, norm2, pack, product, @@ -685,8 +692,9 @@ Expr> FoldOperation(FoldingContext &context, ActualArguments &args{funcRef.arguments()}; for (std::optional &arg : args) { if (arg.has_value()) { - arg.value().value() = - FoldOperation(context, std::move(arg.value().value())); + if (auto *expr{arg->GetExpr()}) { + *expr = FoldOperation(context, std::move(*expr)); + } } } if (auto *intrinsic{std::get_if(&funcRef.proc().u)}) { @@ -718,9 +726,9 @@ Expr> FoldOperation(FoldingContext &context, CHECK(args.size() == 3); using Part = typename T::Part; Expr im{args[1].has_value() - ? std::move(args[1].value().value()) + ? std::move(*args[1].value().GetExpr()) : AsGenericExpr(Constant{Scalar{}})}; - Expr re{std::move(args[0].value().value())}; + Expr re{std::move(*args[0].value().GetExpr())}; int reRank{re.Rank()}; int imRank{im.Rank()}; semantics::Attrs attrs; @@ -751,8 +759,9 @@ Expr> FoldOperation(FoldingContext &context, ActualArguments &args{funcRef.arguments()}; for (std::optional &arg : args) { if (arg.has_value()) { - arg.value().value() = - FoldOperation(context, std::move(arg.value().value())); + if (auto *expr{arg->GetExpr()}) { + *expr = FoldOperation(context, std::move(*expr)); + } } } if (auto *intrinsic{std::get_if(&funcRef.proc().u)}) { @@ -765,11 +774,10 @@ Expr> FoldOperation(FoldingContext &context, // simplify. for (int i{0}; i <= 1; ++i) { if (auto *x{UnwrapArgument(args[i])}) { - args[i].value().value() = - Fold(context, ConvertToType(std::move(*x))); + *args[i] = AsGenericExpr( + Fold(context, ConvertToType(std::move(*x)))); } else if (auto *x{UnwrapBozArgument(args[i])}) { - args[i].value().value() = - AsGenericExpr(Constant{std::move(*x)}); + *args[i] = AsGenericExpr(Constant{std::move(*x)}); } } auto fptr{&Scalar::BGE}; @@ -844,16 +852,16 @@ public: auto n{static_cast(elements_.size())}; if constexpr (std::is_same_v) { return Expr{Constant{array.derivedTypeSpec(), - std::move(elements_), std::vector{n}}}; + std::move(elements_), ConstantSubscripts{n}}}; } else if constexpr (T::category == TypeCategory::Character) { auto length{Fold(context_, common::Clone(array.LEN()))}; if (std::optional lengthValue{ToInt64(length)}) { - return Expr{Constant{*lengthValue, std::move(elements_), - std::vector{n}}}; + return Expr{Constant{ + *lengthValue, std::move(elements_), ConstantSubscripts{n}}}; } } else { return Expr{ - Constant{std::move(elements_), std::vector{n}}}; + Constant{std::move(elements_), ConstantSubscripts{n}}}; } } return Expr{std::move(array)}; @@ -864,9 +872,9 @@ private: Expr folded{Fold(context_, common::Clone(expr.value()))}; if (auto *c{UnwrapExpr>(folded)}) { // Copy elements in Fortran array element order - std::vector shape{c->shape()}; + ConstantSubscripts shape{c->shape()}; int rank{c->Rank()}; - std::vector index(shape.size(), 1); + ConstantSubscripts index(shape.size(), 1); for (std::size_t n{c->size()}; n-- > 0;) { if constexpr (std::is_same_v) { elements_.emplace_back(c->derivedTypeSpec(), c->At(index)); @@ -919,7 +927,7 @@ private: return std::visit([&](const auto &y) { return FoldArray(y); }, x.u); } bool FoldArray(const ArrayConstructorValues &xs) { - for (const auto &x : xs.values()) { + for (const auto &x : xs) { if (!FoldArray(x)) { return false; } @@ -941,7 +949,7 @@ Expr FoldOperation(FoldingContext &context, ArrayConstructor &&array) { Expr FoldOperation( FoldingContext &context, StructureConstructor &&structure) { StructureConstructor result{structure.derivedTypeSpec()}; - for (auto &&[symbol, value] : std::move(structure.values())) { + for (auto &&[symbol, value] : std::move(structure)) { result.Add(*symbol, Fold(context, std::move(value.value()))); } return Expr{Constant{result}}; @@ -984,6 +992,300 @@ Expr> FoldOperation( return Expr{std::move(inquiry)}; } +// Array operation elemental application: When all operands to an operation +// are constant arrays, array constructors without any implied DO loops, +// &/or expanded scalars, pull the operation "into" the array result by +// applying it in an elementwise fashion. For example, [A,1]+[B,2] +// is rewritten into [A+B,1+2] and then partially folded to [A+B,3]. + +// If possible, restructures an array expression into an array constructor +// that comprises a "flat" ArrayConstructorValues with no implied DO loops. +template +bool ArrayConstructorIsFlat(const ArrayConstructorValues &values) { + for (const ArrayConstructorValue &x : values) { + if (!std::holds_alternative>(x.u)) { + return false; + } + } + return true; +} + +template +std::optional> AsFlatArrayConstructor(const Expr &expr) { + if (const auto *c{UnwrapExpr>(expr)}) { + ArrayConstructor result{expr}; + if (c->size() > 0) { + ConstantSubscripts at{InitialSubscripts(c->shape())}; + do { + result.Push(Expr{Constant{c->At(at)}}); + } while (IncrementSubscripts(at, c->shape())); + } + return std::make_optional>(std::move(result)); + } else if (const auto *a{UnwrapExpr>(expr)}) { + if (ArrayConstructorIsFlat(*a)) { + return std::make_optional>(expr); + } + } else if (const auto *p{UnwrapExpr>(expr)}) { + return AsFlatArrayConstructor(Expr{p->left()}); + } + return std::nullopt; +} + +template +std::optional>> AsFlatArrayConstructor( + const Expr> &expr) { + return std::visit( + [&](const auto &kindExpr) -> std::optional>> { + if (auto flattened{AsFlatArrayConstructor(kindExpr)}) { + return Expr>{std::move(*flattened)}; + } else { + return std::nullopt; + } + }, + expr.u); +} + +// FromArrayConstructor is a subroutine for MapOperation() below. +// Given a flat ArrayConstructor and a shape, it wraps the array +// into an Expr, folds it, and returns the resulting wrapped +// array constructor or constant array value. +template +Expr FromArrayConstructor(FoldingContext &context, + ArrayConstructor &&values, std::optional &&shape) { + Expr result{Fold(context, Expr{std::move(values)})}; + if (shape.has_value()) { + if (auto *constant{UnwrapExpr>(result)}) { + constant->shape() = std::move(*shape); + } else { + auto resultShape{GetShape(context, result)}; + CHECK(resultShape.has_value()); + auto constantShape{AsConstantShape(*resultShape)}; + CHECK(constantShape.has_value()); + CHECK(*shape == AsConstantExtents(*constantShape)); + } + } + return result; +} + +// MapOperation is a utility for various specializations of ApplyElementwise() +// that follow. Given one or two flat ArrayConstructor (wrapped in an +// Expr) for some specific operand type(s), apply a given function f +// to each of their corresponding elements to produce a flat +// ArrayConstructor (wrapped in an Expr). +// Preserves shape. + +// Unary case +template +Expr MapOperation(FoldingContext &context, + std::function(Expr &&)> &&f, const Shape &shape, + Expr &&values) { + ArrayConstructor result{values}; + if constexpr (IsGenericIntrinsicCategoryType) { + std::visit( + [&](auto &&kindExpr) { + using kindType = ResultType; + auto &aConst{std::get>(kindExpr.u)}; + for (auto &acValue : aConst) { + auto &scalar{std::get>(acValue.u)}; + result.Push( + FoldOperation(context, f(Expr{std::move(scalar)}))); + } + }, + std::move(values.u)); + } else { + auto &aConst{std::get>(values.u)}; + for (auto &acValue : aConst) { + auto &scalar{std::get>(acValue.u)}; + result.Push(FoldOperation(context, f(std::move(scalar)))); + } + } + return FromArrayConstructor( + context, std::move(result), AsConstantExtents(shape)); +} + +// array * array case +template +Expr MapOperation(FoldingContext &context, + std::function(Expr &&, Expr &&)> &&f, + const Shape &shape, Expr &&leftValues, Expr &&rightValues) { + ArrayConstructor result{leftValues}; + auto &leftArrConst{std::get>(leftValues.u)}; + if constexpr (IsGenericIntrinsicCategoryType) { + std::visit( + [&](auto &&kindExpr) { + using kindType = ResultType; + + auto &rightArrConst{std::get>(kindExpr.u)}; + auto rightIter{rightArrConst.begin()}; + for (auto &leftValue : leftArrConst) { + CHECK(rightIter != rightArrConst.end()); + auto &leftScalar{std::get>(leftValue.u)}; + auto &rightScalar{std::get>(rightIter->u)}; + result.Push(FoldOperation(context, + f(std::move(leftScalar), Expr{std::move(rightScalar)}))); + ++rightIter; + } + }, + std::move(rightValues.u)); + } else { + auto &rightArrConst{std::get>(rightValues.u)}; + auto rightIter{rightArrConst.begin()}; + for (auto &leftValue : leftArrConst) { + CHECK(rightIter != rightArrConst.end()); + auto &leftScalar{std::get>(leftValue.u)}; + auto &rightScalar{std::get>(rightIter->u)}; + result.Push(FoldOperation( + context, f(std::move(leftScalar), std::move(rightScalar)))); + ++rightIter; + } + } + return FromArrayConstructor( + context, std::move(result), AsConstantExtents(shape)); +} + +// array * scalar case +template +Expr MapOperation(FoldingContext &context, + std::function(Expr &&, Expr &&)> &&f, + const Shape &shape, Expr &&leftValues, + const Expr &rightScalar) { + ArrayConstructor result{leftValues}; + auto &leftArrConst{std::get>(leftValues.u)}; + for (auto &leftValue : leftArrConst) { + auto &leftScalar{std::get>(leftValue.u)}; + result.Push(FoldOperation( + context, f(std::move(leftScalar), Expr{rightScalar}))); + } + return FromArrayConstructor( + context, std::move(result), AsConstantExtents(shape)); +} + +// scalar * array case +template +Expr MapOperation(FoldingContext &context, + std::function(Expr &&, Expr &&)> &&f, + const Shape &shape, const Expr &leftScalar, + Expr &&rightValues) { + ArrayConstructor result{leftScalar}; + if constexpr (IsGenericIntrinsicCategoryType) { + std::visit( + [&](auto &&kindExpr) { + using kindType = ResultType; + auto &rightArrConst{std::get>(kindExpr.u)}; + for (auto &rightValue : rightArrConst) { + auto &rightScalar{std::get>(rightValue.u)}; + result.Push(FoldOperation(context, + f(Expr{leftScalar}, + Expr{std::move(rightScalar)}))); + } + }, + std::move(rightValues.u)); + } else { + auto &rightArrConst{std::get>(rightValues.u)}; + for (auto &rightValue : rightArrConst) { + auto &rightScalar{std::get>(rightValue.u)}; + result.Push(FoldOperation( + context, f(Expr{leftScalar}, std::move(rightScalar)))); + } + } + return FromArrayConstructor( + context, std::move(result), AsConstantExtents(shape)); +} + +// ApplyElementwise() recursively folds the operand expression(s) of an +// operation, then attempts to apply the operation to the (corresponding) +// scalar element(s) of those operands. Returns std::nullopt for scalars +// or unlinearizable operands. +template +auto ApplyElementwise(FoldingContext &context, + Operation &operation, + std::function(Expr &&)> &&f) + -> std::optional> { + auto &expr{operation.left()}; + expr = Fold(context, std::move(expr)); + if (expr.Rank() > 0) { + if (std::optional shape{GetShape(context, expr)}) { + if (auto values{AsFlatArrayConstructor(expr)}) { + return MapOperation(context, std::move(f), *shape, std::move(*values)); + } + } + } + return std::nullopt; +} + +template +auto ApplyElementwise( + FoldingContext &context, Operation &operation) + -> std::optional> { + return ApplyElementwise(context, operation, + std::function(Expr &&)>{ + [](Expr &&operand) { + return Expr{DERIVED{std::move(operand)}}; + }}); +} + +// Predicate: is a scalar expression suitable for naive scalar expansion +// in the flattening of an array expression? +// TODO: capture such scalar expansions in temporaries, flatten everything +struct UnexpandabilityFindingVisitor : public virtual VisitorBase { + using Result = bool; + explicit UnexpandabilityFindingVisitor(int) { result() = false; } + template void Handle(FunctionRef &) { Return(true); } + template void Handle(CoarrayRef &) { Return(true); } +}; + +template bool IsExpandableScalar(const Expr &expr) { + return Visitor{0}.Traverse(expr); +} + +template +auto ApplyElementwise(FoldingContext &context, + Operation &operation, + std::function(Expr &&, Expr &&)> &&f) + -> std::optional> { + auto &leftExpr{operation.left()}; + leftExpr = Fold(context, std::move(leftExpr)); + auto &rightExpr{operation.right()}; + rightExpr = Fold(context, std::move(rightExpr)); + if (leftExpr.Rank() > 0) { + if (std::optional leftShape{GetShape(context, leftExpr)}) { + if (auto left{AsFlatArrayConstructor(leftExpr)}) { + if (rightExpr.Rank() > 0) { + if (std::optional rightShape{GetShape(context, rightExpr)}) { + if (auto right{AsFlatArrayConstructor(rightExpr)}) { + CheckConformance(context.messages(), *leftShape, *rightShape); + return MapOperation(context, std::move(f), *leftShape, + std::move(*left), std::move(*right)); + } + } + } else if (IsExpandableScalar(rightExpr)) { + return MapOperation( + context, std::move(f), *leftShape, std::move(*left), rightExpr); + } + } + } + } else if (rightExpr.Rank() > 0 && IsExpandableScalar(leftExpr)) { + if (std::optional shape{GetShape(context, rightExpr)}) { + if (auto right{AsFlatArrayConstructor(rightExpr)}) { + return MapOperation( + context, std::move(f), *shape, leftExpr, std::move(*right)); + } + } + } + return std::nullopt; +} + +template +auto ApplyElementwise( + FoldingContext &context, Operation &operation) + -> std::optional> { + return ApplyElementwise(context, operation, + std::function(Expr &&, Expr &&)>{ + [](Expr &&left, Expr &&right) { + return Expr{DERIVED{std::move(left), std::move(right)}}; + }}); +} + // Unary operations template @@ -1007,11 +1309,12 @@ common::IfNoLvalue, FROM> ConvertString(FROM &&s) { template Expr FoldOperation( FoldingContext &context, Convert &&convert) { + if (auto array{ApplyElementwise(context, convert)}) { + return *array; + } return std::visit( [&](auto &kindExpr) -> Expr { - kindExpr = Fold(context, std::move(kindExpr)); using Operand = ResultType; - // TODO pmk: conversion of array constructors (constant or not) char buffer[64]; if (auto value{GetScalarConstantValue(kindExpr)}) { if constexpr (TO::category == TypeCategory::Integer) { @@ -1081,13 +1384,15 @@ Expr FoldOperation(FoldingContext &context, Parentheses &&x) { // Preserve parentheses, even around constants. return Expr{Parentheses{Expr{Constant{*value}}}}; } - return Expr{std::move(x)}; + return Expr{Parentheses{std::move(operand)}}; } template Expr FoldOperation(FoldingContext &context, Negate &&x) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } auto &operand{x.left()}; - operand = Fold(context, std::move(operand)); if (auto value{GetScalarConstantValue(operand)}) { if constexpr (T::category == TypeCategory::Integer) { auto negated{value->Negate()}; @@ -1108,9 +1413,17 @@ template Expr> FoldOperation( FoldingContext &context, ComplexComponent &&x) { using Operand = Type; + using Result = Type; + if (auto array{ApplyElementwise(context, x, + std::function(Expr &&)>{ + [=](Expr &&operand) { + return Expr{ComplexComponent{ + x.isImaginaryPart, std::move(operand)}}; + }})}) { + return *array; + } using Part = Type; auto &operand{x.left()}; - operand = Fold(context, std::move(operand)); if (auto value{GetScalarConstantValue(operand)}) { if (x.isImaginaryPart) { return Expr{Constant{value->AIMAG()}}; @@ -1124,9 +1437,11 @@ Expr> FoldOperation( template Expr> FoldOperation( FoldingContext &context, Not &&x) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } using Ty = Type; auto &operand{x.left()}; - operand = Fold(context, std::move(operand)); if (auto value{GetScalarConstantValue(operand)}) { return Expr{Constant{!value->IsTrue()}}; } @@ -1135,22 +1450,29 @@ Expr> FoldOperation( // Binary (dyadic) operations -template -std::optional, Scalar>> FoldOperands( - FoldingContext &context, Expr &x, Expr &y) { - x = Fold(context, std::move(x)); // use of std::move() on &x is intentional - y = Fold(context, std::move(y)); - if (auto xvalue{GetScalarConstantValue(x)}) { - if (auto yvalue{GetScalarConstantValue(y)}) { +template +std::optional, Scalar>> OperandsAreConstants( + const Expr &x, const Expr &y) { + if (auto xvalue{GetScalarConstantValue(x)}) { + if (auto yvalue{GetScalarConstantValue(y)}) { return {std::make_pair(*xvalue, *yvalue)}; } } return std::nullopt; } +template +std::optional, Scalar>> OperandsAreConstants( + const Operation &operation) { + return OperandsAreConstants(operation.left(), operation.right()); +} + template Expr FoldOperation(FoldingContext &context, Add &&x) { - if (auto folded{FoldOperands(context, x.left(), x.right())}) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } + if (auto folded{OperandsAreConstants(x)}) { if constexpr (T::category == TypeCategory::Integer) { auto sum{folded->first.AddSigned(folded->second)}; if (sum.overflow) { @@ -1172,7 +1494,10 @@ Expr FoldOperation(FoldingContext &context, Add &&x) { template Expr FoldOperation(FoldingContext &context, Subtract &&x) { - if (auto folded{FoldOperands(context, x.left(), x.right())}) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } + if (auto folded{OperandsAreConstants(x)}) { if constexpr (T::category == TypeCategory::Integer) { auto difference{folded->first.SubtractSigned(folded->second)}; if (difference.overflow) { @@ -1195,7 +1520,10 @@ Expr FoldOperation(FoldingContext &context, Subtract &&x) { template Expr FoldOperation(FoldingContext &context, Multiply &&x) { - if (auto folded{FoldOperands(context, x.left(), x.right())}) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } + if (auto folded{OperandsAreConstants(x)}) { if constexpr (T::category == TypeCategory::Integer) { auto product{folded->first.MultiplySigned(folded->second)}; if (product.SignedMultiplicationOverflowed()) { @@ -1217,7 +1545,10 @@ Expr FoldOperation(FoldingContext &context, Multiply &&x) { template Expr FoldOperation(FoldingContext &context, Divide &&x) { - if (auto folded{FoldOperands(context, x.left(), x.right())}) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } + if (auto folded{OperandsAreConstants(x)}) { if constexpr (T::category == TypeCategory::Integer) { auto quotAndRem{folded->first.DivideSigned(folded->second)}; if (quotAndRem.divisionByZero) { @@ -1242,7 +1573,10 @@ Expr FoldOperation(FoldingContext &context, Divide &&x) { template Expr FoldOperation(FoldingContext &context, Power &&x) { - if (auto folded{FoldOperands(context, x.left(), x.right())}) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } + if (auto folded{OperandsAreConstants(x)}) { if constexpr (T::category == TypeCategory::Integer) { auto power{folded->first.Power(folded->second)}; if (power.divisionByZero) { @@ -1264,9 +1598,12 @@ Expr FoldOperation(FoldingContext &context, Power &&x) { template Expr FoldOperation(FoldingContext &context, RealToIntPower &&x) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } return std::visit( [&](auto &y) -> Expr { - if (auto folded{FoldOperands(context, x.left(), y)}) { + if (auto folded{OperandsAreConstants(x.left(), y)}) { auto power{evaluate::IntPower(folded->first, folded->second)}; RealFlagWarnings(context, power.flags, "power with INTEGER exponent"); if (context.flushSubnormalsToZero()) { @@ -1282,7 +1619,10 @@ Expr FoldOperation(FoldingContext &context, RealToIntPower &&x) { template Expr FoldOperation(FoldingContext &context, Extremum &&x) { - if (auto folded{FoldOperands(context, x.left(), x.right())}) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } + if (auto folded{OperandsAreConstants(x)}) { if constexpr (T::category == TypeCategory::Integer) { if (folded->first.CompareSigned(folded->second) == x.ordering) { return Expr{Constant{folded->first}}; @@ -1306,8 +1646,11 @@ Expr FoldOperation(FoldingContext &context, Extremum &&x) { template Expr> FoldOperation( FoldingContext &context, ComplexConstructor &&x) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } using Result = Type; - if (auto folded{FoldOperands(context, x.left(), x.right())}) { + if (auto folded{OperandsAreConstants(x)}) { return Expr{ Constant{Scalar{folded->first, folded->second}}}; } @@ -1317,8 +1660,11 @@ Expr> FoldOperation( template Expr> FoldOperation( FoldingContext &context, Concat &&x) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } using Result = Type; - if (auto folded{FoldOperands(context, x.left(), x.right())}) { + if (auto folded{OperandsAreConstants(x)}) { return Expr{Constant{folded->first + folded->second}}; } return Expr{std::move(x)}; @@ -1327,8 +1673,11 @@ Expr> FoldOperation( template Expr> FoldOperation( FoldingContext &context, SetLength &&x) { + if (auto array{ApplyElementwise(context, x)}) { + return *array; + } using Result = Type; - if (auto folded{FoldOperands(context, x.left(), x.right())}) { + if (auto folded{OperandsAreConstants(x)}) { auto oldLength{static_cast(folded->first.size())}; auto newLength{folded->second.ToInt64()}; if (newLength < oldLength) { @@ -1345,7 +1694,15 @@ Expr> FoldOperation( template Expr FoldOperation( FoldingContext &context, Relational &&relation) { - if (auto folded{FoldOperands(context, relation.left(), relation.right())}) { + if (auto array{ApplyElementwise(context, relation, + std::function(Expr &&, Expr &&)>{ + [=](Expr &&x, Expr &&y) { + return Expr{Relational{ + Relational{relation.opr, std::move(x), std::move(y)}}}; + }})}) { + return *array; + } + if (auto folded{OperandsAreConstants(relation)}) { bool result{}; if constexpr (T::category == TypeCategory::Integer) { result = @@ -1374,11 +1731,19 @@ inline Expr FoldOperation( template Expr> FoldOperation( - FoldingContext &context, LogicalOperation &&x) { + FoldingContext &context, LogicalOperation &&operation) { using LOGICAL = Type; - if (auto folded{FoldOperands(context, x.left(), x.right())}) { + if (auto array{ApplyElementwise(context, operation, + std::function(Expr &&, Expr &&)>{ + [=](Expr &&x, Expr &&y) { + return Expr{LogicalOperation{ + operation.logicalOperator, std::move(x), std::move(y)}}; + }})}) { + return *array; + } + if (auto folded{OperandsAreConstants(operation)}) { bool xt{folded->first.IsTrue()}, yt{folded->second.IsTrue()}, result{}; - switch (x.logicalOperator) { + switch (operation.logicalOperator) { case LogicalOperator::And: result = xt && yt; break; case LogicalOperator::Or: result = xt || yt; break; case LogicalOperator::Eqv: result = xt == yt; break; @@ -1386,7 +1751,7 @@ Expr> FoldOperation( } return Expr{Constant{result}}; } - return Expr{std::move(x)}; + return Expr{std::move(operation)}; } // end per-operation folding functions @@ -1420,7 +1785,6 @@ FOR_EACH_TYPE_AND_KIND(template class ExpressionBase, ) // able to fold it (yet) into a known constant value; specifically, // the expression may reference derived type kind parameters whose values // are not yet known. - class IsConstantExprVisitor : public virtual VisitorBase { public: using Result = bool; diff --git a/flang/lib/evaluate/fold.h b/flang/lib/evaluate/fold.h index 469091d5822c..eada99e9c8f1 100644 --- a/flang/lib/evaluate/fold.h +++ b/flang/lib/evaluate/fold.h @@ -52,7 +52,7 @@ std::optional> Fold( template std::optional> GetScalarConstantValue(const Expr &expr) { if (const auto *c{UnwrapExpr>(expr)}) { - if (c->size() == 1) { + if (c->Rank() == 0) { return **c; } else { return std::nullopt; diff --git a/flang/lib/evaluate/formatting.cc b/flang/lib/evaluate/formatting.cc index 6db94b6f6a77..f9afc94dda43 100644 --- a/flang/lib/evaluate/formatting.cc +++ b/flang/lib/evaluate/formatting.cc @@ -22,8 +22,7 @@ namespace Fortran::evaluate { -static void ShapeAsFortran( - std::ostream &o, const std::vector &shape) { +static void ShapeAsFortran(std::ostream &o, const ConstantSubscripts &shape) { if (shape.size() > 1) { o << ",shape="; char ch{'['}; @@ -101,6 +100,10 @@ std::ostream &Constant>::AsFortran( return o; } +std::ostream &ActualArgument::AssumedType::AsFortran(std::ostream &o) const { + return o << symbol_->name().ToString(); +} + std::ostream &ActualArgument::AsFortran(std::ostream &o) const { if (keyword.has_value()) { o << keyword->ToString() << '='; @@ -108,7 +111,11 @@ std::ostream &ActualArgument::AsFortran(std::ostream &o) const { if (isAlternateReturn) { o << '*'; } - return value().AsFortran(o); + if (const auto *expr{GetExpr()}) { + return expr->AsFortran(o); + } else { + return std::get(u_).AsFortran(o); + } } std::ostream &SpecificIntrinsic::AsFortran(std::ostream &o) const { @@ -321,7 +328,7 @@ template std::ostream &EmitArray( std::ostream &o, const ArrayConstructorValues &values) { const char *sep{""}; - for (const auto &value : values.values()) { + for (const auto &value : values) { o << sep; std::visit([&](const auto &x) { EmitArray(o, x); }, value.u); sep = ","; diff --git a/flang/lib/evaluate/intrinsics.cc b/flang/lib/evaluate/intrinsics.cc index 5c65f75c94a3..2373e5b74523 100644 --- a/flang/lib/evaluate/intrinsics.cc +++ b/flang/lib/evaluate/intrinsics.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "intrinsics.h" +#include "common.h" #include "expression.h" #include "fold.h" #include "shape.h" @@ -32,7 +33,7 @@ using namespace Fortran::parser::literals; namespace Fortran::evaluate { -using common::TypeCategory; +class FoldingContext; // This file defines the supported intrinsic procedures and implements // their recognition and validation. It is largely table-driven. See @@ -53,6 +54,8 @@ using common::TypeCategory; // INTEGER with a special "typeless" kind code. Arguments of intrinsic types // that can also be be typeless values are encoded with an "elementalOrBOZ" // rank pattern. +// Assumed-type (TYPE(*)) dummy arguments can be forwarded along to some +// intrinsic functions that accept AnyType + Rank::anyOrAssumedRank. using CategorySet = common::EnumSet; static constexpr CategorySet IntType{TypeCategory::Integer}; static constexpr CategorySet RealType{TypeCategory::Real}; @@ -72,12 +75,12 @@ ENUM_CLASS(KindCode, none, defaultIntegerKind, defaultRealKind, // is also the default COMPLEX kind doublePrecision, defaultCharKind, defaultLogicalKind, any, // matches any kind value; each instance is independent + same, // match any kind, but all "same" kinds must be equal typeless, // BOZ literals are INTEGER with this kind teamType, // TEAM_TYPE from module ISO_FORTRAN_ENV (for coarrays) kindArg, // this argument is KIND= effectiveKind, // for function results: same "kindArg", possibly defaulted dimArg, // this argument is DIM= - same, // match any kind; all "same" kinds must be equal likeMultiply, // for DOT_PRODUCT and MATMUL ) @@ -153,7 +156,7 @@ ENUM_CLASS(Rank, matrix, array, // not scalar, rank is known and greater than zero known, // rank is known and can be scalar - anyOrAssumedRank, // rank can be unknown + anyOrAssumedRank, // rank can be unknown; assumed-type TYPE(*) allowed conformable, // scalar, or array of same rank & shape as "array" argument reduceOperation, // a pure function with constraints for REDUCE dimReduced, // scalar if no DIM= argument, else rank(array)-1 @@ -207,7 +210,7 @@ struct IntrinsicInterface { Rank rank{Rank::elemental}; std::optional Match(const CallCharacteristics &, const common::IntrinsicTypeDefaultKinds &, ActualArguments &, - parser::ContextualMessages &messages) const; + FoldingContext &context) const; int CountArguments() const; std::ostream &Dump(std::ostream &) const; }; @@ -771,7 +774,8 @@ static const SpecificIntrinsicInterface specificIntrinsicFunction[]{ std::optional IntrinsicInterface::Match( const CallCharacteristics &call, const common::IntrinsicTypeDefaultKinds &defaults, - ActualArguments &arguments, parser::ContextualMessages &messages) const { + ActualArguments &arguments, FoldingContext &context) const { + auto &messages{context.messages()}; // Attempt to construct a 1-1 correspondence between the dummy arguments in // a particular intrinsic procedure's generic interface and the actual // arguments in a procedure reference. @@ -868,6 +872,18 @@ std::optional IntrinsicInterface::Match( continue; } } + if (arg->GetAssumedTypeDummy()) { + // TYPE(*) assumed-type dummy argument forwarded to intrinsic + if (d.typePattern.categorySet == AnyType && + d.typePattern.kindCode == KindCode::any && + d.rank == Rank::anyOrAssumedRank) { + continue; + } + messages.Say("Assumed type TYPE(*) dummy argument not allowed " + "for '%s=' intrinsic argument"_err_en_US, + d.keyword); + return std::nullopt; + } std::optional type{arg->GetType()}; if (!type.has_value()) { CHECK(arg->Rank() == 0); @@ -946,7 +962,7 @@ std::optional IntrinsicInterface::Match( for (std::size_t j{0}; j < dummies; ++j) { const IntrinsicDummyArgument &d{dummy[std::min(j, dummyArgPatterns - 1)]}; if (const ActualArgument * arg{actualForDummy[j]}) { - if (IsAssumedRank(arg->value()) && d.rank != Rank::anyOrAssumedRank) { + if (IsAssumedRank(*arg) && d.rank != Rank::anyOrAssumedRank) { messages.Say("assumed-rank array cannot be forwarded to " "'%s=' argument"_err_en_US, d.keyword); @@ -967,7 +983,7 @@ std::optional IntrinsicInterface::Match( case Rank::shape: CHECK(!shapeArgSize.has_value()); if (rank == 1) { - if (auto shape{GetShape(*arg)}) { + if (auto shape{GetShape(context, *arg)}) { if (auto constShape{AsConstantShape(*shape)}) { shapeArgSize = (**constShape).ToInt64(); CHECK(shapeArgSize >= 0); @@ -1083,12 +1099,13 @@ std::optional IntrinsicInterface::Match( CHECK(kindDummyArg != nullptr); CHECK(result.categorySet == CategorySet{resultType->category}); if (kindArg != nullptr) { - auto &expr{kindArg->value()}; - CHECK(expr.Rank() == 0); - if (auto code{ToInt64(expr)}) { - if (IsValidKindOfIntrinsicType(resultType->category, *code)) { - resultType->kind = *code; - break; + if (auto *expr{kindArg->GetExpr()}) { + CHECK(expr->Rank() == 0); + if (auto code{ToInt64(*expr)}) { + if (IsValidKindOfIntrinsicType(resultType->category, *code)) { + resultType->kind = *code; + break; + } } } messages.Say("'kind=' argument must be a constant scalar integer " @@ -1196,8 +1213,8 @@ public: bool IsIntrinsic(const std::string &) const; - std::optional Probe(const CallCharacteristics &, - ActualArguments &, parser::ContextualMessages *) const; + std::optional Probe( + const CallCharacteristics &, ActualArguments &, FoldingContext &) const; std::optional IsUnrestrictedSpecificIntrinsicFunction(const std::string &) const; @@ -1230,21 +1247,21 @@ bool IntrinsicProcTable::Implementation::IsIntrinsic( // match for a given procedure reference. std::optional IntrinsicProcTable::Implementation::Probe( const CallCharacteristics &call, ActualArguments &arguments, - parser::ContextualMessages *messages) const { + FoldingContext &context) const { if (call.isSubroutineCall) { return std::nullopt; // TODO } - parser::Messages *finalBuffer{messages ? messages->messages() : nullptr}; + parser::Messages *finalBuffer{context.messages().messages()}; // Probe the specific intrinsic function table first. parser::Messages specificBuffer; parser::ContextualMessages specificErrors{ - messages ? messages->at() : call.name, - finalBuffer ? &specificBuffer : nullptr}; + call.name, finalBuffer ? &specificBuffer : nullptr}; + FoldingContext specificContext{context, specificErrors}; std::string name{call.name.ToString()}; auto specificRange{specificFuncs_.equal_range(name)}; for (auto iter{specificRange.first}; iter != specificRange.second; ++iter) { if (auto specificCall{ - iter->second->Match(call, defaults_, arguments, specificErrors)}) { + iter->second->Match(call, defaults_, arguments, specificContext)}) { if (const char *genericName{iter->second->generic}) { specificCall->specificIntrinsic.name = genericName; } @@ -1256,12 +1273,12 @@ std::optional IntrinsicProcTable::Implementation::Probe( // Probe the generic intrinsic function table next. parser::Messages genericBuffer; parser::ContextualMessages genericErrors{ - messages ? messages->at() : call.name, - finalBuffer ? &genericBuffer : nullptr}; + call.name, finalBuffer ? &genericBuffer : nullptr}; + FoldingContext genericContext{context, genericErrors}; auto genericRange{genericFuncs_.equal_range(name)}; for (auto iter{genericRange.first}; iter != genericRange.second; ++iter) { if (auto specificCall{ - iter->second->Match(call, defaults_, arguments, genericErrors)}) { + iter->second->Match(call, defaults_, arguments, genericContext)}) { return specificCall; } } @@ -1277,20 +1294,20 @@ std::optional IntrinsicProcTable::Implementation::Probe( genericErrors.Say("unknown argument '%s' to NULL()"_err_en_US, arguments[0]->keyword->ToString().data()); } else { - Expr &mold{arguments[0]->value()}; - if (IsPointerOrAllocatable(mold)) { - return std::make_optional( - SpecificIntrinsic{"null"s, mold.GetType(), mold.Rank(), - semantics::Attrs{semantics::Attr::POINTER}}, - std::move(arguments)); - } else { - genericErrors.Say("MOLD argument to NULL() must be a pointer " - "or allocatable"_err_en_US); + if (Expr * mold{arguments[0]->GetExpr()}) { + if (IsPointerOrAllocatable(*mold)) { + return std::make_optional( + SpecificIntrinsic{"null"s, mold->GetType(), mold->Rank(), + semantics::Attrs{semantics::Attr::POINTER}}, + std::move(arguments)); + } } + genericErrors.Say("MOLD argument to NULL() must be a pointer " + "or allocatable"_err_en_US); } } // No match - if (finalBuffer) { + if (finalBuffer != nullptr) { if (genericBuffer.empty()) { finalBuffer->Annex(std::move(specificBuffer)); } else { @@ -1358,9 +1375,9 @@ bool IntrinsicProcTable::IsIntrinsic(const std::string &name) const { std::optional IntrinsicProcTable::Probe( const CallCharacteristics &call, ActualArguments &arguments, - parser::ContextualMessages *messages) const { + FoldingContext &context) const { CHECK(impl_ != nullptr || !"IntrinsicProcTable: not configured"); - return impl_->Probe(call, arguments, messages); + return impl_->Probe(call, arguments, context); } std::optional diff --git a/flang/lib/evaluate/intrinsics.h b/flang/lib/evaluate/intrinsics.h index 7d9729b47e78..890ab6082b09 100644 --- a/flang/lib/evaluate/intrinsics.h +++ b/flang/lib/evaluate/intrinsics.h @@ -26,6 +26,8 @@ namespace Fortran::evaluate { +class FoldingContext; + struct CallCharacteristics { parser::CharBlock name; bool isSubroutineCall{false}; @@ -61,8 +63,8 @@ public: // Probe the intrinsics for a match against a specific call. // On success, the actual arguments are transferred to the result // in dummy argument order. - std::optional Probe(const CallCharacteristics &, - ActualArguments &, parser::ContextualMessages *messages = nullptr) const; + std::optional Probe( + const CallCharacteristics &, ActualArguments &, FoldingContext &) const; // Probe the intrinsics with the name of a potential unrestricted specific // intrinsic. diff --git a/flang/lib/evaluate/shape.cc b/flang/lib/evaluate/shape.cc index 3a9d1827eba6..04e18bb3026d 100644 --- a/flang/lib/evaluate/shape.cc +++ b/flang/lib/evaluate/shape.cc @@ -33,13 +33,15 @@ Shape AsShape(const Constant &arrayConstant) { return result; } -std::optional AsShape(ExtentExpr &&arrayExpr) { +std::optional AsShape(FoldingContext &context, ExtentExpr &&arrayExpr) { + // Flatten any array expression into an array constructor if possible. + arrayExpr = Fold(context, std::move(arrayExpr)); if (auto *constArray{UnwrapExpr>(arrayExpr)}) { return AsShape(*constArray); } if (auto *constructor{UnwrapExpr>(arrayExpr)}) { Shape result; - for (auto &value : constructor->values()) { + for (auto &value : *constructor) { if (auto *expr{std::get_if(&value.u)}) { if (expr->Rank() == 0) { result.emplace_back(std::move(*expr)); @@ -50,13 +52,10 @@ std::optional AsShape(ExtentExpr &&arrayExpr) { } return result; } - // TODO: linearize other array-valued expressions of known shape, e.g. A+B - // as well as conversions of arrays; this will be easier given a - // general-purpose array expression flattener (pmk) return std::nullopt; } -std::optional AsShapeArrayExpr(const Shape &shape) { +std::optional AsExtentArrayExpr(const Shape &shape) { ArrayConstructorValues values; for (const auto &dim : shape) { if (dim.has_value()) { @@ -69,7 +68,7 @@ std::optional AsShapeArrayExpr(const Shape &shape) { } std::optional> AsConstantShape(const Shape &shape) { - if (auto shapeArray{AsShapeArrayExpr(shape)}) { + if (auto shapeArray{AsExtentArrayExpr(shape)}) { FoldingContext noFoldingContext; auto folded{Fold(noFoldingContext, std::move(*shapeArray))}; if (auto *p{UnwrapExpr>(folded)}) { @@ -79,6 +78,22 @@ std::optional> AsConstantShape(const Shape &shape) { return std::nullopt; } +ConstantSubscripts AsConstantExtents(const Constant &shape) { + ConstantSubscripts result; + for (const auto &extent : shape.values()) { + result.push_back(extent.ToInt64()); + } + return result; +} + +std::optional AsConstantExtents(const Shape &shape) { + if (auto shapeConstant{AsConstantShape(shape)}) { + return AsConstantExtents(*shapeConstant); + } else { + return std::nullopt; + } +} + static ExtentExpr ComputeTripCount( ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) { ExtentExpr strideCopy{common::Clone(stride)}; @@ -121,7 +136,16 @@ MaybeExtent GetSize(Shape &&shape) { return extent; } -static MaybeExtent GetLowerBound( +bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) { + struct MyVisitor : public virtual VisitorBase { + using Result = bool; + explicit MyVisitor(int) { result() = false; } + void Handle(const ImpliedDoIndex &) { Return(true); } + }; + return Visitor{0}.Traverse(expr); +} + +MaybeExtent GetShapeHelper::GetLowerBound( const Symbol &symbol, const Component *component, int dimension) { if (const auto *details{symbol.detailsIf()}) { int j{0}; @@ -142,7 +166,7 @@ static MaybeExtent GetLowerBound( return std::nullopt; } -static MaybeExtent GetExtent( +MaybeExtent GetShapeHelper::GetExtent( const Symbol &symbol, const Component *component, int dimension) { if (const auto *details{symbol.detailsIf()}) { int j{0}; @@ -169,8 +193,8 @@ static MaybeExtent GetExtent( return std::nullopt; } -static MaybeExtent GetExtent(const Subscript &subscript, const Symbol &symbol, - const Component *component, int dimension) { +MaybeExtent GetShapeHelper::GetExtent(const Subscript &subscript, + const Symbol &symbol, const Component *component, int dimension) { return std::visit( common::visitors{ [&](const Triplet &triplet) -> MaybeExtent { @@ -185,7 +209,7 @@ static MaybeExtent GetExtent(const Subscript &subscript, const Symbol &symbol, return CountTrips(std::move(lower), std::move(upper), MaybeExtent{triplet.stride()}); }, - [](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtent { + [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtent { if (auto shape{GetShape(subs.value())}) { if (shape->size() > 0) { CHECK(shape->size() == 1); // vector-valued subscript @@ -198,16 +222,7 @@ static MaybeExtent GetExtent(const Subscript &subscript, const Symbol &symbol, subscript.u); } -bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) { - struct MyVisitor : public virtual VisitorBase { - using Result = bool; - explicit MyVisitor(int) { result() = false; } - void Handle(const ImpliedDoIndex &) { Return(true); } - }; - return Visitor{0}.Traverse(expr); -} - -std::optional GetShape( +std::optional GetShapeHelper::GetShape( const Symbol &symbol, const Component *component) { if (const auto *details{symbol.detailsIf()}) { Shape result; @@ -221,7 +236,7 @@ std::optional GetShape( } } -std::optional GetShape(const Symbol *symbol) { +std::optional GetShapeHelper::GetShape(const Symbol *symbol) { if (symbol != nullptr) { return GetShape(*symbol); } else { @@ -229,7 +244,7 @@ std::optional GetShape(const Symbol *symbol) { } } -std::optional GetShape(const BaseObject &object) { +std::optional GetShapeHelper::GetShape(const BaseObject &object) { if (const Symbol * symbol{object.symbol()}) { return GetShape(*symbol); } else { @@ -237,7 +252,7 @@ std::optional GetShape(const BaseObject &object) { } } -std::optional GetShape(const Component &component) { +std::optional GetShapeHelper::GetShape(const Component &component) { const Symbol &symbol{component.GetLastSymbol()}; if (symbol.Rank() > 0) { return GetShape(symbol, &component); @@ -246,7 +261,7 @@ std::optional GetShape(const Component &component) { } } -std::optional GetShape(const ArrayRef &arrayRef) { +std::optional GetShapeHelper::GetShape(const ArrayRef &arrayRef) { Shape shape; const Symbol &symbol{arrayRef.GetLastSymbol()}; const Component *component{std::get_if(&arrayRef.base())}; @@ -264,7 +279,7 @@ std::optional GetShape(const ArrayRef &arrayRef) { } } -std::optional GetShape(const CoarrayRef &coarrayRef) { +std::optional GetShapeHelper::GetShape(const CoarrayRef &coarrayRef) { Shape shape; SymbolOrComponent base{coarrayRef.GetBaseSymbolOrComponent()}; const Symbol &symbol{coarrayRef.GetLastSymbol()}; @@ -283,11 +298,11 @@ std::optional GetShape(const CoarrayRef &coarrayRef) { } } -std::optional GetShape(const DataRef &dataRef) { +std::optional GetShapeHelper::GetShape(const DataRef &dataRef) { return GetShape(dataRef.u); } -std::optional GetShape(const Substring &substring) { +std::optional GetShapeHelper::GetShape(const Substring &substring) { if (const auto *dataRef{substring.GetParentIf()}) { return GetShape(*dataRef); } else { @@ -295,15 +310,21 @@ std::optional GetShape(const Substring &substring) { } } -std::optional GetShape(const ComplexPart &part) { +std::optional GetShapeHelper::GetShape(const ComplexPart &part) { return GetShape(part.complex()); } -std::optional GetShape(const ActualArgument &arg) { - return GetShape(arg.value()); +std::optional GetShapeHelper::GetShape(const ActualArgument &arg) { + if (const auto *expr{arg.GetExpr()}) { + return GetShape(*expr); + } else { + const Symbol *aType{arg.GetAssumedTypeDummy()}; + CHECK(aType != nullptr); + return GetShape(*aType); + } } -std::optional GetShape(const ProcedureRef &call) { +std::optional GetShapeHelper::GetShape(const ProcedureRef &call) { if (call.Rank() == 0) { return Shape{}; } else if (call.IsElemental()) { @@ -318,14 +339,16 @@ std::optional GetShape(const ProcedureRef &call) { std::get_if(&call.proc().u)}) { if (intrinsic->name == "shape" || intrinsic->name == "lbound" || intrinsic->name == "ubound") { - return Shape{MaybeExtent{ - ExtentExpr{call.arguments().front().value().value().Rank()}}}; + const auto *expr{call.arguments().front().value().GetExpr()}; + CHECK(expr != nullptr); + return Shape{MaybeExtent{ExtentExpr{expr->Rank()}}}; } else if (intrinsic->name == "reshape") { if (call.arguments().size() >= 2 && call.arguments().at(1).has_value()) { // SHAPE(RESHAPE(array,shape)) -> shape - const Expr &shapeExpr{call.arguments().at(1)->value()}; - Expr shape{std::get>(shapeExpr.u)}; - return AsShape(ConvertToType(std::move(shape))); + const auto *shapeExpr{call.arguments().at(1).value().GetExpr()}; + CHECK(shapeExpr != nullptr); + Expr shape{std::get>(shapeExpr->u)}; + return AsShape(context_, ConvertToType(std::move(shape))); } } else { // TODO: shapes of other non-elemental intrinsic results @@ -334,28 +357,54 @@ std::optional GetShape(const ProcedureRef &call) { return std::nullopt; } -std::optional GetShape(const Relational &relation) { +std::optional GetShapeHelper::GetShape( + const Relational &relation) { return GetShape(relation.u); } -std::optional GetShape(const StructureConstructor &) { +std::optional GetShapeHelper::GetShape(const StructureConstructor &) { return Shape{}; // always scalar } -std::optional GetShape(const ImpliedDoIndex &) { +std::optional GetShapeHelper::GetShape(const ImpliedDoIndex &) { return Shape{}; // always scalar } -std::optional GetShape(const DescriptorInquiry &) { +std::optional GetShapeHelper::GetShape(const DescriptorInquiry &) { return Shape{}; // always scalar } -std::optional GetShape(const BOZLiteralConstant &) { +std::optional GetShapeHelper::GetShape(const BOZLiteralConstant &) { return Shape{}; // always scalar } -std::optional GetShape(const NullPointer &) { +std::optional GetShapeHelper::GetShape(const NullPointer &) { return {}; // not an object } +void CheckConformance(parser::ContextualMessages &messages, const Shape &left, + const Shape &right) { + if (!left.empty() && !right.empty()) { + int n{static_cast(left.size())}; + int rn{static_cast(right.size())}; + if (n != rn) { + messages.Say( + "Left operand has rank %d, but right operand has rank %d"_err_en_US, + n, rn); + } else { + for (int j{0}; j < n; ++j) { + if (auto leftDim{ToInt64(left[j])}) { + if (auto rightDim{ToInt64(right[j])}) { + if (*leftDim != *rightDim) { + messages.Say("Dimension %d of left operand has extent %jd, " + "but right operand has extent %jd"_err_en_US, + j + 1, static_cast(*leftDim), + static_cast(*rightDim)); + } + } + } + } + } + } +} } diff --git a/flang/lib/evaluate/shape.h b/flang/lib/evaluate/shape.h index 21df38db05dd..85caf6eebac2 100644 --- a/flang/lib/evaluate/shape.h +++ b/flang/lib/evaluate/shape.h @@ -27,16 +27,20 @@ namespace Fortran::evaluate { +class FoldingContext; + using ExtentType = SubscriptInteger; using ExtentExpr = Expr; using MaybeExtent = std::optional; using Shape = std::vector; -// Convert between various representations of shapes +// Conversions between various representations of shapes. Shape AsShape(const Constant &arrayConstant); -std::optional AsShape(ExtentExpr &&arrayExpr); -std::optional AsShapeArrayExpr(const Shape &); // array constructor +std::optional AsShape(FoldingContext &, ExtentExpr &&arrayExpr); +std::optional AsExtentArrayExpr(const Shape &); std::optional> AsConstantShape(const Shape &); +ConstantSubscripts AsConstantExtents(const Constant &); +std::optional AsConstantExtents(const Shape &); // Compute an element count for a triplet or trip count for a DO. ExtentExpr CountTrips( @@ -49,132 +53,150 @@ MaybeExtent CountTrips( // Computes SIZE() == PRODUCT(shape) MaybeExtent GetSize(Shape &&); -// Forward declarations -template -std::optional GetShape(const std::variant &); -template -std::optional GetShape(const common::Indirection &); -template std::optional GetShape(const std::optional &); - -template std::optional GetShape(const Expr &expr) { - return GetShape(expr.u); -} - -std::optional GetShape(const Symbol &, const Component * = nullptr); -std::optional GetShape(const Symbol *); -std::optional GetShape(const BaseObject &); -std::optional GetShape(const Component &); -std::optional GetShape(const ArrayRef &); -std::optional GetShape(const CoarrayRef &); -std::optional GetShape(const DataRef &); -std::optional GetShape(const Substring &); -std::optional GetShape(const ComplexPart &); -std::optional GetShape(const ActualArgument &); -std::optional GetShape(const ProcedureRef &); -std::optional GetShape(const ImpliedDoIndex &); -std::optional GetShape(const Relational &); -std::optional GetShape(const StructureConstructor &); -std::optional GetShape(const DescriptorInquiry &); -std::optional GetShape(const BOZLiteralConstant &); -std::optional GetShape(const NullPointer &); - -template std::optional GetShape(const Constant &c) { - Constant shape{c.SHAPE()}; - return AsShape(shape); -} - -template -std::optional GetShape(const Designator &designator) { - return GetShape(designator.u); -} - -template -std::optional GetShape(const Variable &variable) { - return GetShape(variable.u); -} - -template -std::optional GetShape(const Operation &operation) { - if constexpr (sizeof...(O) > 1) { - if (operation.right().Rank() > 0) { - return GetShape(operation.right()); - } - } - return GetShape(operation.left()); -} - -template -std::optional GetShape(const TypeParamInquiry &) { - return Shape{}; // always scalar, even when applied to an array -} - // Utility predicate: does an expression reference any implied DO index? bool ContainsAnyImpliedDoIndex(const ExtentExpr &); -template MaybeExtent GetExtent(const ArrayConstructorValues &); +// Compilation-time shape conformance checking, when corresponding extents +// are known. +void CheckConformance( + parser::ContextualMessages &, const Shape &, const Shape &); -template -MaybeExtent GetExtent(const ArrayConstructorValue &value) { - return std::visit( - common::visitors{ - [](const Expr &x) -> MaybeExtent { - if (std::optional xShape{GetShape(x)}) { - // Array values in array constructors get linearized. - return GetSize(std::move(*xShape)); - } - return std::nullopt; - }, - [](const ImpliedDo &ido) -> MaybeExtent { - // Don't be heroic and try to figure out triangular implied DO - // nests. - if (!ContainsAnyImpliedDoIndex(ido.lower()) && - !ContainsAnyImpliedDoIndex(ido.upper()) && - !ContainsAnyImpliedDoIndex(ido.stride())) { - if (auto nValues{GetExtent(ido.values())}) { - return std::move(*nValues) * - CountTrips(ido.lower(), ido.upper(), ido.stride()); - } - } - return std::nullopt; - }, - }, - value.u); -} +// The implementation of GetShape() is wrapped in a helper class +// so that the member functions may mutually recurse without prototypes. +class GetShapeHelper { +public: + explicit GetShapeHelper(FoldingContext &context) : context_{context} {} -template -MaybeExtent GetExtent(const ArrayConstructorValues &values) { - ExtentExpr result{0}; - for (const auto &value : values.values()) { - if (MaybeExtent n{GetExtent(value)}) { - result = std::move(result) + std::move(*n); + template std::optional GetShape(const Expr &expr) { + return GetShape(expr.u); + } + + std::optional GetShape(const Symbol &, const Component * = nullptr); + std::optional GetShape(const Symbol *); + std::optional GetShape(const BaseObject &); + std::optional GetShape(const Component &); + std::optional GetShape(const ArrayRef &); + std::optional GetShape(const CoarrayRef &); + std::optional GetShape(const DataRef &); + std::optional GetShape(const Substring &); + std::optional GetShape(const ComplexPart &); + std::optional GetShape(const ActualArgument &); + std::optional GetShape(const ProcedureRef &); + std::optional GetShape(const ImpliedDoIndex &); + std::optional GetShape(const Relational &); + std::optional GetShape(const StructureConstructor &); + std::optional GetShape(const DescriptorInquiry &); + std::optional GetShape(const BOZLiteralConstant &); + std::optional GetShape(const NullPointer &); + + template std::optional GetShape(const Constant &c) { + Constant shape{c.SHAPE()}; + return AsShape(shape); + } + + template + std::optional GetShape(const Designator &designator) { + return GetShape(designator.u); + } + + template + std::optional GetShape(const Variable &variable) { + return GetShape(variable.u); + } + + template + std::optional GetShape(const Operation &operation) { + if constexpr (sizeof...(O) > 1) { + if (operation.right().Rank() > 0) { + return GetShape(operation.right()); + } + } + return GetShape(operation.left()); + } + + template + std::optional GetShape(const TypeParamInquiry &) { + return Shape{}; // always scalar, even when applied to an array + } + + template + std::optional GetShape(const ArrayConstructor &aconst) { + return Shape{GetExtent(aconst)}; + } + + template + std::optional GetShape(const std::variant &u) { + return std::visit([&](const auto &x) { return GetShape(x); }, u); + } + + template + std::optional GetShape(const common::Indirection &p) { + return GetShape(p.value()); + } + + template + std::optional GetShape(const std::optional &x) { + if (x.has_value()) { + return GetShape(*x); } else { return std::nullopt; } } - return result; -} -template -std::optional GetShape(const ArrayConstructor &aconst) { - return Shape{GetExtent(aconst)}; -} +private: + MaybeExtent GetLowerBound(const Symbol &, const Component *, int dimension); -template -std::optional GetShape(const std::variant &u) { - return std::visit([](const auto &x) { return GetShape(x); }, u); -} - -template -std::optional GetShape(const common::Indirection &p) { - return GetShape(p.value()); -} - -template std::optional GetShape(const std::optional &x) { - if (x.has_value()) { - return GetShape(*x); - } else { - return std::nullopt; + template + MaybeExtent GetExtent(const ArrayConstructorValue &value) { + return std::visit( + common::visitors{ + [&](const Expr &x) -> MaybeExtent { + if (std::optional xShape{GetShape(x)}) { + // Array values in array constructors get linearized. + return GetSize(std::move(*xShape)); + } + return std::nullopt; + }, + [&](const ImpliedDo &ido) -> MaybeExtent { + // Don't be heroic and try to figure out triangular implied DO + // nests. + if (!ContainsAnyImpliedDoIndex(ido.lower()) && + !ContainsAnyImpliedDoIndex(ido.upper()) && + !ContainsAnyImpliedDoIndex(ido.stride())) { + if (auto nValues{GetExtent(ido.values())}) { + return std::move(*nValues) * + CountTrips(ido.lower(), ido.upper(), ido.stride()); + } + } + return std::nullopt; + }, + }, + value.u); } + + template + MaybeExtent GetExtent(const ArrayConstructorValues &values) { + ExtentExpr result{0}; + for (const auto &value : values) { + if (MaybeExtent n{GetExtent(value)}) { + result = std::move(result) + std::move(*n); + } else { + return std::nullopt; + } + } + return result; + } + + MaybeExtent GetExtent(const Symbol &, const Component *, int dimension); + MaybeExtent GetExtent( + const Subscript &, const Symbol &, const Component *, int dimension); + + FoldingContext &context_; +}; + +template +std::optional GetShape(FoldingContext &context, const A &x) { + return GetShapeHelper{context}.GetShape(x); } } #endif // FORTRAN_EVALUATE_SHAPE_H_ diff --git a/flang/lib/semantics/expression.cc b/flang/lib/semantics/expression.cc index 751e4fb8ef50..485496c1480d 100644 --- a/flang/lib/semantics/expression.cc +++ b/flang/lib/semantics/expression.cc @@ -1161,7 +1161,7 @@ template ArrayConstructorValues MakeSpecific( ArrayConstructorValues &&from) { ArrayConstructorValues to; - for (ArrayConstructorValue &x : from.values()) { + for (ArrayConstructorValue &x : from) { std::visit( common::visitors{ [&](common::CopyableIndirection> &&expr) { @@ -1456,7 +1456,7 @@ auto ExpressionAnalyzer::Procedure(const parser::ProcedureDesignator &pd, CallCharacteristics cc{n.source}; if (std::optional specificCall{ context().intrinsics().Probe( - cc, arguments, &GetContextualMessages())}) { + cc, arguments, GetFoldingContext())}) { return { CallAndArguments{ProcedureDesignator{std::move( specificCall->specificIntrinsic)}, diff --git a/flang/test/evaluate/intrinsics.cc b/flang/test/evaluate/intrinsics.cc index e9a6a36551e0..53d80eede51c 100644 --- a/flang/test/evaluate/intrinsics.cc +++ b/flang/test/evaluate/intrinsics.cc @@ -14,6 +14,7 @@ #include "../../lib/evaluate/intrinsics.h" #include "testing.h" +#include "../../lib/evaluate/common.h" #include "../../lib/evaluate/expression.h" #include "../../lib/evaluate/tools.h" #include "../../lib/parser/provenance.h" @@ -111,7 +112,8 @@ struct TestCall { std::cout << ")\n"; CallCharacteristics call{fName}; auto messages{strings.Messages(buffer)}; - std::optional si{table.Probe(call, args, &messages)}; + FoldingContext context{messages}; + std::optional si{table.Probe(call, args, context)}; if (resultType.has_value()) { TEST(si.has_value()); TEST(buffer.empty()); diff --git a/flang/test/semantics/modfile25.f90 b/flang/test/semantics/modfile25.f90 index 3d23e639622a..87eb8d8df7da 100644 --- a/flang/test/semantics/modfile25.f90 +++ b/flang/test/semantics/modfile25.f90 @@ -29,7 +29,7 @@ module m1 integer(8), parameter :: ac3bs(:) = shape([(1,j=4,1,-1)]) integer(8), parameter :: ac4s(:) = shape([((j,k,j*k,k=1,3),j=1,4)]) integer(8), parameter :: ac5s(:) = shape([((0,k=5,1,-2),j=9,2,-3)]) - integer(8), parameter :: rss(:) = shape(reshape([(0,j=1,90)], [10_8,9_8])) + integer(8), parameter :: rss(:) = shape(reshape([(0,j=1,90)], -[2,3]*(-[5_8,3_8]))) contains subroutine subr(x,n1,n2) real, intent(in) :: x(:,:)