[flang] Fold array operations

Original-commit: flang-compiler/f18@e6c86ecfd1
Reviewed-on: https://github.com/flang-compiler/f18/pull/420
This commit is contained in:
peter klausler 2019-04-18 14:11:15 -07:00
parent 567480a4d7
commit 146e13ce22
16 changed files with 1036 additions and 394 deletions

View File

@ -14,19 +14,45 @@
#include "call.h"
#include "expression.h"
#include "tools.h"
#include "../common/idioms.h"
#include "../semantics/symbol.h"
namespace Fortran::evaluate {
std::optional<DynamicType> ActualArgument::GetType() const {
return value().GetType();
ActualArgument::AssumedType::AssumedType(const semantics::Symbol &symbol)
: symbol_{&symbol} {
const semantics::DeclTypeSpec *type{symbol.GetType()};
CHECK(
type != nullptr && type->category() == semantics::DeclTypeSpec::TypeStar);
}
int ActualArgument::Rank() const { return value().Rank(); }
int ActualArgument::AssumedType::Rank() const { return symbol_->Rank(); }
ActualArgument &ActualArgument::operator=(Expr<SomeType> &&expr) {
u_ = std::move(expr);
return *this;
}
std::optional<DynamicType> ActualArgument::GetType() const {
if (const auto *expr{GetExpr()}) {
return expr->GetType();
} else {
return std::nullopt;
}
}
int ActualArgument::Rank() const {
if (const auto *expr{GetExpr()}) {
return expr->Rank();
} else {
return std::get<AssumedType>(u_).Rank();
}
}
bool ActualArgument::operator==(const ActualArgument &that) const {
return keyword == that.keyword &&
isAlternateReturn == that.isAlternateReturn && value() == that.value();
isAlternateReturn == that.isAlternateReturn && u_ == that.u_;
}
bool SpecificIntrinsic::operator==(const SpecificIntrinsic &that) const {
@ -80,8 +106,28 @@ const Symbol *ProcedureDesignator::GetSymbol() const {
}
Expr<SubscriptInteger> ProcedureRef::LEN() const {
// TODO: the results of the intrinsic functions REPEAT and TRIM have
// unpredictable lengths; maybe the concept of LEN() has to become dynamic
if (const auto *intrinsic{std::get_if<SpecificIntrinsic>(&proc_.u)}) {
if (intrinsic->name == "repeat") {
// LEN(REPEAT(ch,n)) == LEN(ch) * n
CHECK(arguments_.size() == 2);
const auto *stringArg{
UnwrapExpr<Expr<SomeCharacter>>(arguments_[0].value())};
const auto *nCopiesArg{
UnwrapExpr<Expr<SomeInteger>>(arguments_[1].value())};
CHECK(stringArg != nullptr && nCopiesArg != nullptr);
auto stringLen{stringArg->LEN()};
return std::move(stringLen) *
ConvertTo(stringLen, common::Clone(*nCopiesArg));
}
if (intrinsic->name == "trim") {
// LEN(TRIM(ch)) is unknown without execution.
CHECK(arguments_.size() == 1);
const auto *stringArg{
UnwrapExpr<Expr<SomeCharacter>>(arguments_[0].value())};
CHECK(stringArg != nullptr);
return stringArg->LEN();
}
}
return proc_.LEN();
}

View File

@ -41,12 +41,55 @@ namespace Fortran::evaluate {
class ActualArgument {
public:
explicit ActualArgument(Expr<SomeType> &&x) : value_{std::move(x)} {}
explicit ActualArgument(common::CopyableIndirection<Expr<SomeType>> &&v)
: value_{std::move(v)} {}
// Dummy arguments that are TYPE(*) can be forwarded as actual arguments.
// Since that's the only thing one may do with them in Fortran, they're
// represented in expressions as a special case of an actual argument.
class AssumedType {
public:
explicit AssumedType(const semantics::Symbol &);
DEFAULT_CONSTRUCTORS_AND_ASSIGNMENTS(AssumedType)
const semantics::Symbol &symbol() const { return *symbol_; }
int Rank() const;
bool operator==(const AssumedType &that) const {
return symbol_ == that.symbol_;
}
std::ostream &AsFortran(std::ostream &) const;
Expr<SomeType> &value() { return value_.value(); }
const Expr<SomeType> &value() const { return value_.value(); }
private:
const semantics::Symbol *symbol_;
};
explicit ActualArgument(Expr<SomeType> &&x) : u_{std::move(x)} {}
explicit ActualArgument(common::CopyableIndirection<Expr<SomeType>> &&v)
: u_{std::move(v)} {}
explicit ActualArgument(AssumedType x) : u_{x} {}
ActualArgument &operator=(Expr<SomeType> &&);
Expr<SomeType> *GetExpr() {
if (auto *p{
std::get_if<common::CopyableIndirection<Expr<SomeType>>>(&u_)}) {
return &p->value();
} else {
return nullptr;
}
}
const Expr<SomeType> *GetExpr() const {
if (const auto *p{
std::get_if<common::CopyableIndirection<Expr<SomeType>>>(&u_)}) {
return &p->value();
} else {
return nullptr;
}
}
const semantics::Symbol *GetAssumedTypeDummy() const {
if (const AssumedType * aType{std::get_if<AssumedType>(&u_)}) {
return &aType->symbol();
} else {
return nullptr;
}
}
std::optional<DynamicType> GetType() const;
int Rank() const;
@ -64,7 +107,7 @@ private:
// e.g. between X and (X). The parser attempts to parse each argument
// first as a variable, then as an expression, and the distinction appears
// in the parse tree.
common::CopyableIndirection<Expr<SomeType>> value_;
std::variant<common::CopyableIndirection<Expr<SomeType>>, AssumedType> u_;
};
using ActualArguments = std::vector<std::optional<ActualArgument>>;

View File

@ -19,16 +19,41 @@
namespace Fortran::evaluate {
std::size_t TotalElementCount(const ConstantSubscripts &shape) {
std::size_t size{1};
for (auto dim : shape) {
CHECK(dim >= 0);
size *= dim;
}
return size;
}
bool IncrementSubscripts(
ConstantSubscripts &indices, const ConstantSubscripts &shape) {
auto rank{shape.size()};
CHECK(indices.size() == rank);
for (std::size_t j{0}; j < rank; ++j) {
CHECK(indices[j] >= 1);
if (++indices[j] <= shape[j]) {
return true;
} else {
CHECK(indices[j] == shape[j] + 1);
indices[j] = 1;
}
}
return false; // all done
}
template<typename RESULT, typename VALUE>
ConstantBase<RESULT, VALUE>::~ConstantBase() {}
static std::int64_t SubscriptsToOffset(const std::vector<std::int64_t> &index,
const std::vector<std::int64_t> &shape) {
static ConstantSubscript SubscriptsToOffset(
const ConstantSubscripts &index, const ConstantSubscripts &shape) {
CHECK(index.size() == shape.size());
std::int64_t stride{1}, offset{0};
ConstantSubscript stride{1}, offset{0};
int dim{0};
for (std::int64_t j : index) {
std::int64_t bound{shape[dim++]};
for (auto j : index) {
auto bound{shape[dim++]};
CHECK(j >= 1 && j <= bound);
offset += stride * (j - 1);
stride *= bound;
@ -37,26 +62,26 @@ static std::int64_t SubscriptsToOffset(const std::vector<std::int64_t> &index,
}
template<typename RESULT, typename VALUE>
auto ConstantBase<RESULT, VALUE>::At(
const std::vector<std::int64_t> &index) const -> ScalarValue {
auto ConstantBase<RESULT, VALUE>::At(const ConstantSubscripts &index) const
-> ScalarValue {
return values_.at(SubscriptsToOffset(index, shape_));
}
template<typename RESULT, typename VALUE>
auto ConstantBase<RESULT, VALUE>::At(std::vector<std::int64_t> &&index) const
auto ConstantBase<RESULT, VALUE>::At(ConstantSubscripts &&index) const
-> ScalarValue {
return values_.at(SubscriptsToOffset(index, shape_));
}
static Constant<SubscriptInteger> ShapeAsConstant(
const std::vector<std::int64_t> &shape) {
const ConstantSubscripts &shape) {
using IntType = Scalar<SubscriptInteger>;
std::vector<IntType> result;
for (std::int64_t dim : shape) {
for (auto dim : shape) {
result.emplace_back(dim);
}
return {std::move(result),
std::vector<std::int64_t>{static_cast<std::int64_t>(shape.size())}};
ConstantSubscripts{static_cast<std::int64_t>(shape.size())}};
}
template<typename RESULT, typename VALUE>
@ -76,7 +101,7 @@ Constant<Type<TypeCategory::Character, KIND>>::Constant(ScalarValue &&str)
template<int KIND>
Constant<Type<TypeCategory::Character, KIND>>::Constant(std::int64_t len,
std::vector<ScalarValue> &&strings, std::vector<std::int64_t> &&dims)
std::vector<ScalarValue> &&strings, ConstantSubscripts &&dims)
: length_{len}, shape_{std::move(dims)} {
values_.assign(strings.size() * length_,
static_cast<typename ScalarValue::value_type>(' '));
@ -95,8 +120,8 @@ Constant<Type<TypeCategory::Character, KIND>>::Constant(std::int64_t len,
template<int KIND> Constant<Type<TypeCategory::Character, KIND>>::~Constant() {}
static std::int64_t ShapeElements(const std::vector<std::int64_t> &shape) {
std::int64_t elements{1};
static ConstantSubscript ShapeElements(const ConstantSubscripts &shape) {
ConstantSubscript elements{1};
for (auto dim : shape) {
elements *= dim;
}
@ -119,7 +144,7 @@ std::size_t Constant<Type<TypeCategory::Character, KIND>>::size() const {
template<int KIND>
auto Constant<Type<TypeCategory::Character, KIND>>::At(
const std::vector<std::int64_t> &index) const -> ScalarValue {
const ConstantSubscripts &index) const -> ScalarValue {
auto offset{SubscriptsToOffset(index, shape_)};
return values_.substr(offset, length_);
}
@ -138,10 +163,10 @@ Constant<SomeDerived>::Constant(StructureConstructor &&x)
: Base{std::move(x.values())}, derivedTypeSpec_{&x.derivedTypeSpec()} {}
Constant<SomeDerived>::Constant(const semantics::DerivedTypeSpec &spec,
std::vector<StructureConstructorValues> &&x, std::vector<std::int64_t> &&s)
std::vector<StructureConstructorValues> &&x, ConstantSubscripts &&s)
: Base{std::move(x), std::move(s)}, derivedTypeSpec_{&spec} {}
static std::vector<StructureConstructorValues> GetValues(
static std::vector<StructureConstructorValues> AcquireValues(
std::vector<StructureConstructor> &&x) {
std::vector<StructureConstructorValues> result;
for (auto &&structure : std::move(x)) {
@ -151,8 +176,8 @@ static std::vector<StructureConstructorValues> GetValues(
}
Constant<SomeDerived>::Constant(const semantics::DerivedTypeSpec &spec,
std::vector<StructureConstructor> &&x, std::vector<std::int64_t> &&s)
: Base{GetValues(std::move(x)), std::move(s)}, derivedTypeSpec_{&spec} {}
std::vector<StructureConstructor> &&x, ConstantSubscripts &&s)
: Base{AcquireValues(std::move(x)), std::move(s)}, derivedTypeSpec_{&spec} {}
INSTANTIATE_CONSTANT_TEMPLATES
}

View File

@ -32,6 +32,24 @@ namespace Fortran::evaluate {
template<typename> class Constant;
// When describing shapes of constants or specifying 1-based subscript
// values as indices into constants, use a vector of integers.
using ConstantSubscript = std::int64_t;
using ConstantSubscripts = std::vector<ConstantSubscript>;
std::size_t TotalElementCount(const ConstantSubscripts &);
inline ConstantSubscripts InitialSubscripts(int rank) {
return ConstantSubscripts(rank, 1); // parens, not braces: "rank" copies of 1
}
inline ConstantSubscripts InitialSubscripts(const ConstantSubscripts &shape) {
return InitialSubscripts(static_cast<int>(shape.size()));
}
// Increments a vector of subscripts in Fortran array order (first dimension
// varying most quickly). Returns false when last element was visited.
bool IncrementSubscripts(ConstantSubscripts &, const ConstantSubscripts &shape);
// Constant<> is specialized for Character kinds and SomeDerived.
// The non-Character intrinsic types, and SomeDerived, share enough
// common behavior that they use this common base class.
@ -45,7 +63,7 @@ public:
template<typename A> ConstantBase(const A &x) : values_{x} {}
template<typename A, typename = common::NoLvalue<A>>
ConstantBase(A &&x) : values_{std::move(x)} {}
ConstantBase(std::vector<ScalarValue> &&x, std::vector<std::int64_t> &&dims)
ConstantBase(std::vector<ScalarValue> &&x, ConstantSubscripts &&dims)
: values_(std::move(x)), shape_(std::move(dims)) {}
DEFAULT_CONSTRUCTORS_AND_ASSIGNMENTS(ConstantBase)
~ConstantBase();
@ -57,7 +75,8 @@ public:
bool empty() const { return values_.empty(); }
std::size_t size() const { return values_.size(); }
const std::vector<ScalarValue> &values() const { return values_; }
const std::vector<std::int64_t> &shape() const { return shape_; }
const ConstantSubscripts &shape() const { return shape_; }
ConstantSubscripts &shape() { return shape_; }
ScalarValue operator*() const {
CHECK(values_.size() == 1);
@ -65,15 +84,15 @@ public:
}
// Apply 1-based subscripts
ScalarValue At(const std::vector<std::int64_t> &) const;
ScalarValue At(std::vector<std::int64_t> &&) const;
ScalarValue At(const ConstantSubscripts &) const;
ScalarValue At(ConstantSubscripts &&) const;
Constant<SubscriptInteger> SHAPE() const;
std::ostream &AsFortran(std::ostream &) const;
protected:
std::vector<ScalarValue> values_;
std::vector<std::int64_t> shape_;
ConstantSubscripts shape_;
private:
const Constant<Result> &AsConstant() const {
@ -96,11 +115,11 @@ template<int KIND> class Constant<Type<TypeCategory::Character, KIND>> {
public:
using Result = Type<TypeCategory::Character, KIND>;
using ScalarValue = Scalar<Result>;
CLASS_BOILERPLATE(Constant)
explicit Constant(const ScalarValue &);
explicit Constant(ScalarValue &&);
Constant(
std::int64_t, std::vector<ScalarValue> &&, std::vector<std::int64_t> &&);
Constant(std::int64_t, std::vector<ScalarValue> &&, ConstantSubscripts &&);
~Constant();
int Rank() const { return static_cast<int>(shape_.size()); }
@ -109,7 +128,8 @@ public:
}
bool empty() const;
std::size_t size() const;
const std::vector<std::int64_t> &shape() const { return shape_; }
const ConstantSubscripts &shape() const { return shape_; }
ConstantSubscripts &shape() { return shape_; }
std::int64_t LEN() const { return length_; }
@ -119,7 +139,7 @@ public:
}
// Apply 1-based subscripts
ScalarValue At(const std::vector<std::int64_t> &) const;
ScalarValue At(const ConstantSubscripts &) const;
Constant<SubscriptInteger> SHAPE() const;
std::ostream &AsFortran(std::ostream &) const;
@ -128,7 +148,7 @@ public:
private:
ScalarValue values_; // one contiguous string
std::int64_t length_;
std::vector<std::int64_t> shape_;
ConstantSubscripts shape_;
};
using StructureConstructorValues = std::map<const semantics::Symbol *,
@ -140,12 +160,13 @@ class Constant<SomeDerived>
public:
using Result = SomeDerived;
using Base = ConstantBase<Result, StructureConstructorValues>;
Constant(const StructureConstructor &);
Constant(StructureConstructor &&);
Constant(const semantics::DerivedTypeSpec &, std::vector<ScalarValue> &&,
std::vector<std::int64_t> &&);
ConstantSubscripts &&);
Constant(const semantics::DerivedTypeSpec &,
std::vector<StructureConstructor> &&, std::vector<std::int64_t> &&);
std::vector<StructureConstructor> &&, ConstantSubscripts &&);
CLASS_BOILERPLATE(Constant)
const semantics::DerivedTypeSpec &derivedTypeSpec() const {

View File

@ -117,10 +117,14 @@ public:
}
template<typename R> void Descend(const ArrayConstructorValues<R> &avs) {
Visit(avs.values());
for (const auto &x : avs) {
Visit(x);
}
}
template<typename R> void Descend(ArrayConstructorValues<R> &avs) {
Visit(avs.values());
for (auto &x : avs) {
Visit(x);
}
}
template<int KIND>
@ -155,13 +159,13 @@ public:
void Descend(const StructureConstructor &sc) {
Visit(sc.derivedTypeSpec());
for (const auto &pair : sc.values()) {
for (const auto &pair : sc) {
Visit(pair.second);
}
}
void Descend(StructureConstructor &sc) {
Visit(sc.derivedTypeSpec());
for (const auto &pair : sc.values()) {
for (const auto &pair : sc) {
Visit(pair.second);
}
}
@ -241,8 +245,22 @@ public:
template<typename T> void Descend(const Variable<T> &var) { Visit(var.u); }
template<typename T> void Descend(Variable<T> &var) { Visit(var.u); }
void Descend(const ActualArgument &arg) { Visit(arg.value()); }
void Descend(ActualArgument &arg) { Visit(arg.value()); }
void Descend(const ActualArgument &arg) {
if (const auto *expr{arg.GetExpr()}) {
Visit(*expr);
} else {
const semantics::Symbol *aType{arg.GetAssumedTypeDummy()};
Visit(*aType);
}
}
void Descend(ActualArgument &arg) {
if (auto *expr{arg.GetExpr()}) {
Visit(*expr);
} else {
const semantics::Symbol *aType{arg.GetAssumedTypeDummy()};
Visit(*aType);
}
}
void Descend(const ProcedureDesignator &p) { Visit(p.u); }
void Descend(ProcedureDesignator &p) { Visit(p.u); }

View File

@ -450,18 +450,26 @@ public:
using Values = std::vector<ArrayConstructorValue<Result>>;
DEFAULT_CONSTRUCTORS_AND_ASSIGNMENTS(ArrayConstructorValues)
ArrayConstructorValues() {}
bool operator==(const ArrayConstructorValues &) const;
static constexpr int Rank() { return 1; }
template<typename A> common::NoLvalue<A> Push(A &&x) {
values_.emplace_back(std::move(x));
}
Values &values() { return values_; }
const Values &values() const { return values_; }
typename Values::iterator begin() { return values_.begin(); }
typename Values::const_iterator begin() const { return values_.begin(); }
typename Values::iterator end() { return values_.end(); }
typename Values::const_iterator end() const { return values_.end(); }
protected:
Values values_;
};
// Note that there are specializations of ArrayConstructor for character
// and derived types, since they must carry additional type information,
// but that an empty ArrayConstructor can be constructed for any type
// given an expression from which such type information may be gleaned.
template<typename RESULT>
class ArrayConstructor : public ArrayConstructorValues<RESULT> {
public:
@ -469,6 +477,7 @@ public:
using Base = ArrayConstructorValues<Result>;
DEFAULT_CONSTRUCTORS_AND_ASSIGNMENTS(ArrayConstructor)
explicit ArrayConstructor(Base &&values) : Base{std::move(values)} {}
template<typename T> explicit ArrayConstructor(const Expr<T> &) {}
static constexpr DynamicType GetType() { return Result::GetType(); }
std::ostream &AsFortran(std::ostream &) const;
};
@ -482,6 +491,8 @@ public:
CLASS_BOILERPLATE(ArrayConstructor)
ArrayConstructor(Expr<SubscriptInteger> &&len, Base &&v)
: Base{std::move(v)}, length_{std::move(len)} {}
template<typename T>
explicit ArrayConstructor(const Expr<T> &proto) : length_{proto.LEN()} {}
bool operator==(const ArrayConstructor &) const;
static constexpr DynamicType GetType() { return Result::GetType(); }
std::ostream &AsFortran(std::ostream &) const;
@ -500,6 +511,11 @@ public:
CLASS_BOILERPLATE(ArrayConstructor)
ArrayConstructor(const semantics::DerivedTypeSpec &spec, Base &&v)
: Base{std::move(v)}, derivedTypeSpec_{&spec} {}
template<typename T>
explicit ArrayConstructor(const Expr<T> &proto)
: derivedTypeSpec_{GetType(proto).derived} {
CHECK(derivedTypeSpec_ != nullptr);
}
bool operator==(const ArrayConstructor &) const;
const semantics::DerivedTypeSpec &derivedTypeSpec() const {
return *derivedTypeSpec_;
@ -715,8 +731,18 @@ public:
}
StructureConstructorValues &values() { return values_; }
const StructureConstructorValues &values() const { return values_; }
bool operator==(const StructureConstructor &) const;
StructureConstructorValues::iterator begin() { return values_.begin(); }
StructureConstructorValues::const_iterator begin() const {
return values_.begin();
}
StructureConstructorValues::iterator end() { return values_.end(); }
StructureConstructorValues::const_iterator end() const {
return values_.end();
}
StructureConstructor &Add(const semantics::Symbol &, Expr<SomeType> &&);
int Rank() const { return 0; }
DynamicType GetType() const;

View File

@ -193,12 +193,12 @@ static inline Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
(... && IsSpecificIntrinsicType<TA>)); // TODO derived types for MERGE?
static_assert(sizeof...(TA) > 0);
std::tuple<const Constant<TA> *...> args{
UnwrapExpr<Constant<TA>>(funcRef.arguments()[I].value().value())...};
UnwrapExpr<Constant<TA>>(*funcRef.arguments()[I].value().GetExpr())...};
if ((... && (std::get<I>(args) != nullptr))) {
// Compute the shape of the result based on shapes of arguments
std::vector<std::int64_t> shape;
ConstantSubscripts shape;
int rank{0};
const std::vector<std::int64_t> *shapes[sizeof...(TA)]{
const ConstantSubscripts *shapes[sizeof...(TA)]{
&std::get<I>(args)->shape()...};
const int ranks[sizeof...(TA)]{std::get<I>(args)->Rank()...};
for (unsigned int i{0}; i < sizeof...(TA); ++i) {
@ -222,29 +222,21 @@ static inline Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
CHECK(rank == static_cast<int>(shape.size()));
// Compute all the scalar values of the results
std::size_t size{1};
for (std::int64_t dim : shape) {
size *= dim;
}
std::vector<Scalar<TR>> results;
std::vector<std::int64_t> index(rank, 1);
for (std::size_t n{size}; n-- > 0;) {
if constexpr (std::is_same_v<WrapperType<TR, TA...>,
ScalarFuncWithContext<TR, TA...>>) {
results.emplace_back(func(context,
(ranks[I] ? std::get<I>(args)->At(index)
: **std::get<I>(args))...));
} else if constexpr (std::is_same_v<WrapperType<TR, TA...>,
ScalarFunc<TR, TA...>>) {
results.emplace_back(func((
ranks[I] ? std::get<I>(args)->At(index) : **std::get<I>(args))...));
}
for (int d{0}; d < rank; ++d) {
if (++index[d] <= shape[d]) {
break;
if (TotalElementCount(shape) > 0) {
ConstantSubscripts index{InitialSubscripts(rank)};
do {
if constexpr (std::is_same_v<WrapperType<TR, TA...>,
ScalarFuncWithContext<TR, TA...>>) {
results.emplace_back(func(context,
(ranks[I] ? std::get<I>(args)->At(index)
: **std::get<I>(args))...));
} else if constexpr (std::is_same_v<WrapperType<TR, TA...>,
ScalarFunc<TR, TA...>>) {
results.emplace_back(func((ranks[I] ? std::get<I>(args)->At(index)
: **std::get<I>(args))...));
}
index[d] = 1;
}
} while (IncrementSubscripts(index, shape));
}
// Build and return constant result
if constexpr (TR::category == TypeCategory::Character) {
@ -273,12 +265,21 @@ static Expr<TR> FoldElementalIntrinsic(FoldingContext &context,
template<typename T>
static Expr<T> *UnwrapArgument(std::optional<ActualArgument> &arg) {
return UnwrapExpr<Expr<T>>(arg.value().value());
if (arg.has_value()) {
if (Expr<SomeType> * expr{arg->GetExpr()}) {
return UnwrapExpr<Expr<T>>(*expr);
}
}
return nullptr;
}
static BOZLiteralConstant *UnwrapBozArgument(
std::optional<ActualArgument> &arg) {
return std::get_if<BOZLiteralConstant>(&arg.value().value().u);
if (auto *expr{UnwrapArgument<SomeType>(arg)}) {
return std::get_if<BOZLiteralConstant>(&expr->u);
} else {
return nullptr;
}
}
template<int KIND>
@ -287,9 +288,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
using T = Type<TypeCategory::Integer, KIND>;
ActualArguments &args{funcRef.arguments()};
for (std::optional<ActualArgument> &arg : args) {
if (arg.has_value()) {
arg.value().value() =
FoldOperation(context, std::move(arg.value().value()));
if (auto *expr{UnwrapArgument<SomeType>(arg)}) {
*expr = FoldOperation(context, std::move(*expr));
}
}
if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
@ -311,8 +311,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
// convert boz
for (int i{0}; i <= 1; ++i) {
if (auto *x{UnwrapBozArgument(args[i])}) {
args[i].value().value() =
Fold(context, ConvertToType<T>(std::move(*x)));
*args[i] =
AsGenericExpr(Fold(context, ConvertToType<T>(std::move(*x))));
}
}
// Third argument can be of any kind. However, it must be smaller or equal
@ -320,8 +320,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
using Int4 = Type<TypeCategory::Integer, 4>;
if (auto *n{UnwrapArgument<SomeInteger>(args[2])}) {
if (n->GetType()->kind != 4) {
args[2].value().value() =
Fold(context, ConvertToType<Int4>(std::move(*n)));
*args[2] =
AsGenericExpr(Fold(context, ConvertToType<Int4>(std::move(*n))));
}
}
const auto fptr{
@ -349,8 +349,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
// convert boz
for (int i{0}; i <= 1; ++i) {
if (auto *x{UnwrapBozArgument(args[i])}) {
args[i].value().value() =
Fold(context, ConvertToType<T>(std::move(*x)));
*args[i] =
AsGenericExpr(Fold(context, ConvertToType<T>(std::move(*x))));
}
}
auto fptr{&Scalar<T>::IAND};
@ -371,8 +371,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
using Int4 = Type<TypeCategory::Integer, 4>;
if (auto *n{UnwrapArgument<SomeInteger>(args[1])}) {
if (n->GetType()->kind != 4) {
args[1].value().value() =
Fold(context, ConvertToType<Int4>(std::move(*n)));
*args[1] =
AsGenericExpr(Fold(context, ConvertToType<Int4>(std::move(*n))));
}
}
auto fptr{&Scalar<T>::IBCLR};
@ -396,18 +396,20 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
return std::invoke(fptr, i, static_cast<int>(pos.ToInt64()));
}));
} else if (name == "int") {
return std::visit(
[&](auto &&x) -> Expr<T> {
using From = std::decay_t<decltype(x)>;
if constexpr (std::is_same_v<From, BOZLiteralConstant> ||
std::is_same_v<From, Expr<SomeReal>> ||
std::is_same_v<From, Expr<SomeInteger>> ||
std::is_same_v<From, Expr<SomeComplex>>) {
return Fold(context, ConvertToType<T>(std::move(x)));
}
common::die("int() argument type not valid");
},
std::move(args[0].value().value().u));
if (auto *expr{args[0].value().GetExpr()}) {
return std::visit(
[&](auto &&x) -> Expr<T> {
using From = std::decay_t<decltype(x)>;
if constexpr (std::is_same_v<From, BOZLiteralConstant> ||
std::is_same_v<From, Expr<SomeReal>> ||
std::is_same_v<From, Expr<SomeInteger>> ||
std::is_same_v<From, Expr<SomeComplex>>) {
return Fold(context, ConvertToType<T>(std::move(x)));
}
common::die("int() argument type not valid");
},
std::move(expr->u));
}
} else if (name == "kind") {
if constexpr (common::HasMember<T, IntegerTypes>) {
return Expr<T>{args[0].value().GetType()->kind};
@ -466,8 +468,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
using Int4 = Type<TypeCategory::Integer, 4>;
if (auto *n{UnwrapArgument<SomeInteger>(args[0])}) {
if (n->GetType()->kind != 4) {
args[0].value().value() =
Fold(context, ConvertToType<Int4>(std::move(*n)));
*args[0] =
AsGenericExpr(Fold(context, ConvertToType<Int4>(std::move(*n))));
}
}
const auto fptr{name == "maskl" ? &Scalar<T>::MASKL : &Scalar<T>::MASKR};
@ -479,8 +481,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
// convert boz
for (int i{0}; i <= 2; ++i) {
if (auto *x{UnwrapBozArgument(args[i])}) {
args[i].value().value() =
Fold(context, ConvertToType<T>(std::move(*x)));
*args[i] =
AsGenericExpr(Fold(context, ConvertToType<T>(std::move(*x))));
}
}
return FoldElementalIntrinsic<T, T, T, T>(
@ -489,24 +491,26 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
// TODO assumed-rank dummy argument
return Expr<T>{args[0].value().Rank()};
} else if (name == "shape") {
if (auto shape{GetShape(args[0].value())}) {
if (auto shapeExpr{AsShapeArrayExpr(*shape)}) {
if (auto shape{GetShape(context, args[0].value())}) {
if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
}
}
} else if (name == "size") {
if (auto shape{GetShape(args[0].value())}) {
if (auto shape{GetShape(context, args[0].value())}) {
if (auto &dimArg{args[1]}) { // DIM= is present, get one extent
if (auto dim{ToInt64(dimArg->value())}) {
std::int64_t rank = shape->size();
if (*dim >= 1 && *dim <= rank) {
if (auto &extent{shape->at(*dim - 1)}) {
return Fold(context, ConvertToType<T>(std::move(*extent)));
if (auto *expr{dimArg->GetExpr()}) {
if (auto dim{ToInt64(*expr)}) {
std::int64_t rank = shape->size();
if (*dim >= 1 && *dim <= rank) {
if (auto &extent{shape->at(*dim - 1)}) {
return Fold(context, ConvertToType<T>(std::move(*extent)));
}
} else {
context.messages().Say(
"size(array,dim=%jd) dimension is out of range for rank-%d array"_en_US,
static_cast<std::intmax_t>(*dim), static_cast<int>(rank));
}
} else {
context.messages().Say(
"size(array,dim=%jd) dimension is out of range for rank-%d array"_en_US,
static_cast<std::intmax_t>(*dim), static_cast<int>(rank));
}
}
} else if (auto extents{
@ -539,8 +543,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldOperation(FoldingContext &context,
ActualArguments &args{funcRef.arguments()};
for (std::optional<ActualArgument> &arg : args) {
if (arg.has_value()) {
arg.value().value() =
FoldOperation(context, std::move(arg.value().value()));
if (auto *expr{arg->GetExpr()}) {
*expr = FoldOperation(context, std::move(*expr));
}
}
}
if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
@ -584,8 +589,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldOperation(FoldingContext &context,
using Int4 = Type<TypeCategory::Integer, 4>;
if (auto *n{UnwrapArgument<SomeInteger>(args[0])}) {
if (n->GetType()->kind != 4) {
args[0].value().value() =
Fold(context, ConvertToType<Int4>(std::move(*n)));
*args[0] = AsGenericExpr(
Fold(context, ConvertToType<Int4>(std::move(*n))));
}
}
if (auto callable{
@ -624,8 +629,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldOperation(FoldingContext &context,
// Convert argument to the requested kind before calling aint
if (auto *x{UnwrapArgument<SomeReal>(args[0])}) {
if (!(x->GetType()->kind == T::kind)) {
args[0].value().value() =
Fold(context, ConvertToType<T>(std::move(*x)));
*args[0] =
AsGenericExpr(Fold(context, ConvertToType<T>(std::move(*x))));
}
}
return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
@ -649,25 +654,27 @@ Expr<Type<TypeCategory::Real, KIND>> FoldOperation(FoldingContext &context,
} else if (name == "epsilon") {
return Expr<T>{Constant<T>{Scalar<T>::EPSILON()}};
} else if (name == "real") {
return std::visit(
[&](auto &&x) -> Expr<T> {
using From = std::decay_t<decltype(x)>;
if constexpr (std::is_same_v<From, BOZLiteralConstant>) {
typename T::Scalar::Word::ValueWithOverflow result{
T::Scalar::Word::ConvertUnsigned(x)};
if (result.overflow) { // C1601
context.messages().Say(
"Non null truncated bits of boz literal constant in REAL intrinsic"_en_US);
if (auto *expr{args[0].value().GetExpr()}) {
return std::visit(
[&](auto &&x) -> Expr<T> {
using From = std::decay_t<decltype(x)>;
if constexpr (std::is_same_v<From, BOZLiteralConstant>) {
typename T::Scalar::Word::ValueWithOverflow result{
T::Scalar::Word::ConvertUnsigned(x)};
if (result.overflow) { // C1601
context.messages().Say(
"Non null truncated bits of boz literal constant in REAL intrinsic"_en_US);
}
return Expr<T>{Constant<T>{Scalar<T>(std::move(result.value))}};
} else if constexpr (std::is_same_v<From, Expr<SomeReal>> ||
std::is_same_v<From, Expr<SomeInteger>> ||
std::is_same_v<From, Expr<SomeComplex>>) {
return Fold(context, ConvertToType<T>(std::move(x)));
}
return Expr<T>{Constant<T>{Scalar<T>(std::move(result.value))}};
} else if constexpr (std::is_same_v<From, Expr<SomeReal>> ||
std::is_same_v<From, Expr<SomeInteger>> ||
std::is_same_v<From, Expr<SomeComplex>>) {
return Fold(context, ConvertToType<T>(std::move(x)));
}
common::die("real() argument type not valid");
},
std::move(args[0].value().value().u));
common::die("real() argument type not valid");
},
std::move(expr->u));
}
}
// TODO: anint, cshift, dim, dot_product, eoshift, fraction, huge, matmul,
// max, maxval, merge, min, minval, modulo, nearest, norm2, pack, product,
@ -685,8 +692,9 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldOperation(FoldingContext &context,
ActualArguments &args{funcRef.arguments()};
for (std::optional<ActualArgument> &arg : args) {
if (arg.has_value()) {
arg.value().value() =
FoldOperation(context, std::move(arg.value().value()));
if (auto *expr{arg->GetExpr()}) {
*expr = FoldOperation(context, std::move(*expr));
}
}
}
if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
@ -718,9 +726,9 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldOperation(FoldingContext &context,
CHECK(args.size() == 3);
using Part = typename T::Part;
Expr<SomeType> im{args[1].has_value()
? std::move(args[1].value().value())
? std::move(*args[1].value().GetExpr())
: AsGenericExpr(Constant<Part>{Scalar<Part>{}})};
Expr<SomeType> re{std::move(args[0].value().value())};
Expr<SomeType> re{std::move(*args[0].value().GetExpr())};
int reRank{re.Rank()};
int imRank{im.Rank()};
semantics::Attrs attrs;
@ -751,8 +759,9 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(FoldingContext &context,
ActualArguments &args{funcRef.arguments()};
for (std::optional<ActualArgument> &arg : args) {
if (arg.has_value()) {
arg.value().value() =
FoldOperation(context, std::move(arg.value().value()));
if (auto *expr{arg->GetExpr()}) {
*expr = FoldOperation(context, std::move(*expr));
}
}
}
if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
@ -765,11 +774,10 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(FoldingContext &context,
// simplify.
for (int i{0}; i <= 1; ++i) {
if (auto *x{UnwrapArgument<SomeInteger>(args[i])}) {
args[i].value().value() =
Fold(context, ConvertToType<LargestInt>(std::move(*x)));
*args[i] = AsGenericExpr(
Fold(context, ConvertToType<LargestInt>(std::move(*x))));
} else if (auto *x{UnwrapBozArgument(args[i])}) {
args[i].value().value() =
AsGenericExpr(Constant<LargestInt>{std::move(*x)});
*args[i] = AsGenericExpr(Constant<LargestInt>{std::move(*x)});
}
}
auto fptr{&Scalar<LargestInt>::BGE};
@ -844,16 +852,16 @@ public:
auto n{static_cast<std::int64_t>(elements_.size())};
if constexpr (std::is_same_v<T, SomeDerived>) {
return Expr<T>{Constant<T>{array.derivedTypeSpec(),
std::move(elements_), std::vector<std::int64_t>{n}}};
std::move(elements_), ConstantSubscripts{n}}};
} else if constexpr (T::category == TypeCategory::Character) {
auto length{Fold(context_, common::Clone(array.LEN()))};
if (std::optional<std::int64_t> lengthValue{ToInt64(length)}) {
return Expr<T>{Constant<T>{*lengthValue, std::move(elements_),
std::vector<std::int64_t>{n}}};
return Expr<T>{Constant<T>{
*lengthValue, std::move(elements_), ConstantSubscripts{n}}};
}
} else {
return Expr<T>{
Constant<T>{std::move(elements_), std::vector<std::int64_t>{n}}};
Constant<T>{std::move(elements_), ConstantSubscripts{n}}};
}
}
return Expr<T>{std::move(array)};
@ -864,9 +872,9 @@ private:
Expr<T> folded{Fold(context_, common::Clone(expr.value()))};
if (auto *c{UnwrapExpr<Constant<T>>(folded)}) {
// Copy elements in Fortran array element order
std::vector<std::int64_t> shape{c->shape()};
ConstantSubscripts shape{c->shape()};
int rank{c->Rank()};
std::vector<std::int64_t> index(shape.size(), 1);
ConstantSubscripts index(shape.size(), 1);
for (std::size_t n{c->size()}; n-- > 0;) {
if constexpr (std::is_same_v<T, SomeDerived>) {
elements_.emplace_back(c->derivedTypeSpec(), c->At(index));
@ -919,7 +927,7 @@ private:
return std::visit([&](const auto &y) { return FoldArray(y); }, x.u);
}
bool FoldArray(const ArrayConstructorValues<T> &xs) {
for (const auto &x : xs.values()) {
for (const auto &x : xs) {
if (!FoldArray(x)) {
return false;
}
@ -941,7 +949,7 @@ Expr<T> FoldOperation(FoldingContext &context, ArrayConstructor<T> &&array) {
Expr<SomeDerived> FoldOperation(
FoldingContext &context, StructureConstructor &&structure) {
StructureConstructor result{structure.derivedTypeSpec()};
for (auto &&[symbol, value] : std::move(structure.values())) {
for (auto &&[symbol, value] : std::move(structure)) {
result.Add(*symbol, Fold(context, std::move(value.value())));
}
return Expr<SomeDerived>{Constant<SomeDerived>{result}};
@ -984,6 +992,300 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(
return Expr<IntKIND>{std::move(inquiry)};
}
// Array operation elemental application: When all operands to an operation
// are constant arrays, array constructors without any implied DO loops,
// &/or expanded scalars, pull the operation "into" the array result by
// applying it in an elementwise fashion. For example, [A,1]+[B,2]
// is rewritten into [A+B,1+2] and then partially folded to [A+B,3].
// If possible, restructures an array expression into an array constructor
// that comprises a "flat" ArrayConstructorValues with no implied DO loops.
template<typename T>
bool ArrayConstructorIsFlat(const ArrayConstructorValues<T> &values) {
for (const ArrayConstructorValue<T> &x : values) {
if (!std::holds_alternative<Expr<T>>(x.u)) {
return false;
}
}
return true;
}
template<typename T>
std::optional<Expr<T>> AsFlatArrayConstructor(const Expr<T> &expr) {
if (const auto *c{UnwrapExpr<Constant<T>>(expr)}) {
ArrayConstructor<T> result{expr};
if (c->size() > 0) {
ConstantSubscripts at{InitialSubscripts(c->shape())};
do {
result.Push(Expr<T>{Constant<T>{c->At(at)}});
} while (IncrementSubscripts(at, c->shape()));
}
return std::make_optional<Expr<T>>(std::move(result));
} else if (const auto *a{UnwrapExpr<ArrayConstructor<T>>(expr)}) {
if (ArrayConstructorIsFlat(*a)) {
return std::make_optional<Expr<T>>(expr);
}
} else if (const auto *p{UnwrapExpr<Parentheses<T>>(expr)}) {
return AsFlatArrayConstructor(Expr<T>{p->left()});
}
return std::nullopt;
}
template<TypeCategory CAT>
std::optional<Expr<SomeKind<CAT>>> AsFlatArrayConstructor(
const Expr<SomeKind<CAT>> &expr) {
return std::visit(
[&](const auto &kindExpr) -> std::optional<Expr<SomeKind<CAT>>> {
if (auto flattened{AsFlatArrayConstructor(kindExpr)}) {
return Expr<SomeKind<CAT>>{std::move(*flattened)};
} else {
return std::nullopt;
}
},
expr.u);
}
// FromArrayConstructor is a subroutine for MapOperation() below.
// Given a flat ArrayConstructor<T> and a shape, it wraps the array
// into an Expr<T>, folds it, and returns the resulting wrapped
// array constructor or constant array value.
template<typename T>
Expr<T> FromArrayConstructor(FoldingContext &context,
ArrayConstructor<T> &&values, std::optional<ConstantSubscripts> &&shape) {
Expr<T> result{Fold(context, Expr<T>{std::move(values)})};
if (shape.has_value()) {
if (auto *constant{UnwrapExpr<Constant<T>>(result)}) {
constant->shape() = std::move(*shape);
} else {
auto resultShape{GetShape(context, result)};
CHECK(resultShape.has_value());
auto constantShape{AsConstantShape(*resultShape)};
CHECK(constantShape.has_value());
CHECK(*shape == AsConstantExtents(*constantShape));
}
}
return result;
}
// MapOperation is a utility for various specializations of ApplyElementwise()
// that follow. Given one or two flat ArrayConstructor<OPERAND> (wrapped in an
// Expr<OPERAND>) for some specific operand type(s), apply a given function f
// to each of their corresponding elements to produce a flat
// ArrayConstructor<RESULT> (wrapped in an Expr<RESULT>).
// Preserves shape.
// Unary case
template<typename RESULT, typename OPERAND>
Expr<RESULT> MapOperation(FoldingContext &context,
std::function<Expr<RESULT>(Expr<OPERAND> &&)> &&f, const Shape &shape,
Expr<OPERAND> &&values) {
ArrayConstructor<RESULT> result{values};
if constexpr (IsGenericIntrinsicCategoryType<OPERAND>) {
std::visit(
[&](auto &&kindExpr) {
using kindType = ResultType<decltype(kindExpr)>;
auto &aConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
for (auto &acValue : aConst) {
auto &scalar{std::get<Expr<kindType>>(acValue.u)};
result.Push(
FoldOperation(context, f(Expr<OPERAND>{std::move(scalar)})));
}
},
std::move(values.u));
} else {
auto &aConst{std::get<ArrayConstructor<OPERAND>>(values.u)};
for (auto &acValue : aConst) {
auto &scalar{std::get<Expr<OPERAND>>(acValue.u)};
result.Push(FoldOperation(context, f(std::move(scalar))));
}
}
return FromArrayConstructor(
context, std::move(result), AsConstantExtents(shape));
}
// array * array case
template<typename RESULT, typename LEFT, typename RIGHT>
Expr<RESULT> MapOperation(FoldingContext &context,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
const Shape &shape, Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
ArrayConstructor<RESULT> result{leftValues};
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
if constexpr (IsGenericIntrinsicCategoryType<RIGHT>) {
std::visit(
[&](auto &&kindExpr) {
using kindType = ResultType<decltype(kindExpr)>;
auto &rightArrConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
auto rightIter{rightArrConst.begin()};
for (auto &leftValue : leftArrConst) {
CHECK(rightIter != rightArrConst.end());
auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
auto &rightScalar{std::get<Expr<kindType>>(rightIter->u)};
result.Push(FoldOperation(context,
f(std::move(leftScalar), Expr<RIGHT>{std::move(rightScalar)})));
++rightIter;
}
},
std::move(rightValues.u));
} else {
auto &rightArrConst{std::get<ArrayConstructor<RIGHT>>(rightValues.u)};
auto rightIter{rightArrConst.begin()};
for (auto &leftValue : leftArrConst) {
CHECK(rightIter != rightArrConst.end());
auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
auto &rightScalar{std::get<Expr<RIGHT>>(rightIter->u)};
result.Push(FoldOperation(
context, f(std::move(leftScalar), std::move(rightScalar))));
++rightIter;
}
}
return FromArrayConstructor(
context, std::move(result), AsConstantExtents(shape));
}
// array * scalar case
template<typename RESULT, typename LEFT, typename RIGHT>
Expr<RESULT> MapOperation(FoldingContext &context,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
const Shape &shape, Expr<LEFT> &&leftValues,
const Expr<RIGHT> &rightScalar) {
ArrayConstructor<RESULT> result{leftValues};
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
for (auto &leftValue : leftArrConst) {
auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
result.Push(FoldOperation(
context, f(std::move(leftScalar), Expr<RIGHT>{rightScalar})));
}
return FromArrayConstructor(
context, std::move(result), AsConstantExtents(shape));
}
// scalar * array case
template<typename RESULT, typename LEFT, typename RIGHT>
Expr<RESULT> MapOperation(FoldingContext &context,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
const Shape &shape, const Expr<LEFT> &leftScalar,
Expr<RIGHT> &&rightValues) {
ArrayConstructor<RESULT> result{leftScalar};
if constexpr (IsGenericIntrinsicCategoryType<RIGHT>) {
std::visit(
[&](auto &&kindExpr) {
using kindType = ResultType<decltype(kindExpr)>;
auto &rightArrConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
for (auto &rightValue : rightArrConst) {
auto &rightScalar{std::get<Expr<kindType>>(rightValue.u)};
result.Push(FoldOperation(context,
f(Expr<LEFT>{leftScalar},
Expr<RIGHT>{std::move(rightScalar)})));
}
},
std::move(rightValues.u));
} else {
auto &rightArrConst{std::get<ArrayConstructor<RIGHT>>(rightValues.u)};
for (auto &rightValue : rightArrConst) {
auto &rightScalar{std::get<Expr<RIGHT>>(rightValue.u)};
result.Push(FoldOperation(
context, f(Expr<LEFT>{leftScalar}, std::move(rightScalar))));
}
}
return FromArrayConstructor(
context, std::move(result), AsConstantExtents(shape));
}
// ApplyElementwise() recursively folds the operand expression(s) of an
// operation, then attempts to apply the operation to the (corresponding)
// scalar element(s) of those operands. Returns std::nullopt for scalars
// or unlinearizable operands.
template<typename DERIVED, typename RESULT, typename OPERAND>
auto ApplyElementwise(FoldingContext &context,
Operation<DERIVED, RESULT, OPERAND> &operation,
std::function<Expr<RESULT>(Expr<OPERAND> &&)> &&f)
-> std::optional<Expr<RESULT>> {
auto &expr{operation.left()};
expr = Fold(context, std::move(expr));
if (expr.Rank() > 0) {
if (std::optional<Shape> shape{GetShape(context, expr)}) {
if (auto values{AsFlatArrayConstructor(expr)}) {
return MapOperation(context, std::move(f), *shape, std::move(*values));
}
}
}
return std::nullopt;
}
template<typename DERIVED, typename RESULT, typename OPERAND>
auto ApplyElementwise(
FoldingContext &context, Operation<DERIVED, RESULT, OPERAND> &operation)
-> std::optional<Expr<RESULT>> {
return ApplyElementwise(context, operation,
std::function<Expr<RESULT>(Expr<OPERAND> &&)>{
[](Expr<OPERAND> &&operand) {
return Expr<RESULT>{DERIVED{std::move(operand)}};
}});
}
// Predicate: is a scalar expression suitable for naive scalar expansion
// in the flattening of an array expression?
// TODO: capture such scalar expansions in temporaries, flatten everything
struct UnexpandabilityFindingVisitor : public virtual VisitorBase<bool> {
using Result = bool;
explicit UnexpandabilityFindingVisitor(int) { result() = false; }
template<typename T> void Handle(FunctionRef<T> &) { Return(true); }
template<typename T> void Handle(CoarrayRef &) { Return(true); }
};
template<typename T> bool IsExpandableScalar(const Expr<T> &expr) {
return Visitor<UnexpandabilityFindingVisitor>{0}.Traverse(expr);
}
template<typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
auto ApplyElementwise(FoldingContext &context,
Operation<DERIVED, RESULT, LEFT, RIGHT> &operation,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f)
-> std::optional<Expr<RESULT>> {
auto &leftExpr{operation.left()};
leftExpr = Fold(context, std::move(leftExpr));
auto &rightExpr{operation.right()};
rightExpr = Fold(context, std::move(rightExpr));
if (leftExpr.Rank() > 0) {
if (std::optional<Shape> leftShape{GetShape(context, leftExpr)}) {
if (auto left{AsFlatArrayConstructor(leftExpr)}) {
if (rightExpr.Rank() > 0) {
if (std::optional<Shape> rightShape{GetShape(context, rightExpr)}) {
if (auto right{AsFlatArrayConstructor(rightExpr)}) {
CheckConformance(context.messages(), *leftShape, *rightShape);
return MapOperation(context, std::move(f), *leftShape,
std::move(*left), std::move(*right));
}
}
} else if (IsExpandableScalar(rightExpr)) {
return MapOperation(
context, std::move(f), *leftShape, std::move(*left), rightExpr);
}
}
}
} else if (rightExpr.Rank() > 0 && IsExpandableScalar(leftExpr)) {
if (std::optional<Shape> shape{GetShape(context, rightExpr)}) {
if (auto right{AsFlatArrayConstructor(rightExpr)}) {
return MapOperation(
context, std::move(f), *shape, leftExpr, std::move(*right));
}
}
}
return std::nullopt;
}
template<typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
auto ApplyElementwise(
FoldingContext &context, Operation<DERIVED, RESULT, LEFT, RIGHT> &operation)
-> std::optional<Expr<RESULT>> {
return ApplyElementwise(context, operation,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)>{
[](Expr<LEFT> &&left, Expr<RIGHT> &&right) {
return Expr<RESULT>{DERIVED{std::move(left), std::move(right)}};
}});
}
// Unary operations
template<typename TO, typename FROM>
@ -1007,11 +1309,12 @@ common::IfNoLvalue<std::optional<TO>, FROM> ConvertString(FROM &&s) {
template<typename TO, TypeCategory FROMCAT>
Expr<TO> FoldOperation(
FoldingContext &context, Convert<TO, FROMCAT> &&convert) {
if (auto array{ApplyElementwise(context, convert)}) {
return *array;
}
return std::visit(
[&](auto &kindExpr) -> Expr<TO> {
kindExpr = Fold(context, std::move(kindExpr));
using Operand = ResultType<decltype(kindExpr)>;
// TODO pmk: conversion of array constructors (constant or not)
char buffer[64];
if (auto value{GetScalarConstantValue<Operand>(kindExpr)}) {
if constexpr (TO::category == TypeCategory::Integer) {
@ -1081,13 +1384,15 @@ Expr<T> FoldOperation(FoldingContext &context, Parentheses<T> &&x) {
// Preserve parentheses, even around constants.
return Expr<T>{Parentheses<T>{Expr<T>{Constant<T>{*value}}}};
}
return Expr<T>{std::move(x)};
return Expr<T>{Parentheses<T>{std::move(operand)}};
}
template<typename T>
Expr<T> FoldOperation(FoldingContext &context, Negate<T> &&x) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
auto &operand{x.left()};
operand = Fold(context, std::move(operand));
if (auto value{GetScalarConstantValue<T>(operand)}) {
if constexpr (T::category == TypeCategory::Integer) {
auto negated{value->Negate()};
@ -1108,9 +1413,17 @@ template<int KIND>
Expr<Type<TypeCategory::Real, KIND>> FoldOperation(
FoldingContext &context, ComplexComponent<KIND> &&x) {
using Operand = Type<TypeCategory::Complex, KIND>;
using Result = Type<TypeCategory::Real, KIND>;
if (auto array{ApplyElementwise(context, x,
std::function<Expr<Result>(Expr<Operand> &&)>{
[=](Expr<Operand> &&operand) {
return Expr<Result>{ComplexComponent<KIND>{
x.isImaginaryPart, std::move(operand)}};
}})}) {
return *array;
}
using Part = Type<TypeCategory::Real, KIND>;
auto &operand{x.left()};
operand = Fold(context, std::move(operand));
if (auto value{GetScalarConstantValue<Operand>(operand)}) {
if (x.isImaginaryPart) {
return Expr<Part>{Constant<Part>{value->AIMAG()}};
@ -1124,9 +1437,11 @@ Expr<Type<TypeCategory::Real, KIND>> FoldOperation(
template<int KIND>
Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
FoldingContext &context, Not<KIND> &&x) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
using Ty = Type<TypeCategory::Logical, KIND>;
auto &operand{x.left()};
operand = Fold(context, std::move(operand));
if (auto value{GetScalarConstantValue<Ty>(operand)}) {
return Expr<Ty>{Constant<Ty>{!value->IsTrue()}};
}
@ -1135,22 +1450,29 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
// Binary (dyadic) operations
template<typename T1, typename T2>
std::optional<std::pair<Scalar<T1>, Scalar<T2>>> FoldOperands(
FoldingContext &context, Expr<T1> &x, Expr<T2> &y) {
x = Fold(context, std::move(x)); // use of std::move() on &x is intentional
y = Fold(context, std::move(y));
if (auto xvalue{GetScalarConstantValue<T1>(x)}) {
if (auto yvalue{GetScalarConstantValue<T2>(y)}) {
template<typename LEFT, typename RIGHT>
std::optional<std::pair<Scalar<LEFT>, Scalar<RIGHT>>> OperandsAreConstants(
const Expr<LEFT> &x, const Expr<RIGHT> &y) {
if (auto xvalue{GetScalarConstantValue<LEFT>(x)}) {
if (auto yvalue{GetScalarConstantValue<RIGHT>(y)}) {
return {std::make_pair(*xvalue, *yvalue)};
}
}
return std::nullopt;
}
template<typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
std::optional<std::pair<Scalar<LEFT>, Scalar<RIGHT>>> OperandsAreConstants(
const Operation<DERIVED, RESULT, LEFT, RIGHT> &operation) {
return OperandsAreConstants(operation.left(), operation.right());
}
template<typename T>
Expr<T> FoldOperation(FoldingContext &context, Add<T> &&x) {
if (auto folded{FoldOperands(context, x.left(), x.right())}) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
if (auto folded{OperandsAreConstants(x)}) {
if constexpr (T::category == TypeCategory::Integer) {
auto sum{folded->first.AddSigned(folded->second)};
if (sum.overflow) {
@ -1172,7 +1494,10 @@ Expr<T> FoldOperation(FoldingContext &context, Add<T> &&x) {
template<typename T>
Expr<T> FoldOperation(FoldingContext &context, Subtract<T> &&x) {
if (auto folded{FoldOperands(context, x.left(), x.right())}) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
if (auto folded{OperandsAreConstants(x)}) {
if constexpr (T::category == TypeCategory::Integer) {
auto difference{folded->first.SubtractSigned(folded->second)};
if (difference.overflow) {
@ -1195,7 +1520,10 @@ Expr<T> FoldOperation(FoldingContext &context, Subtract<T> &&x) {
template<typename T>
Expr<T> FoldOperation(FoldingContext &context, Multiply<T> &&x) {
if (auto folded{FoldOperands(context, x.left(), x.right())}) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
if (auto folded{OperandsAreConstants(x)}) {
if constexpr (T::category == TypeCategory::Integer) {
auto product{folded->first.MultiplySigned(folded->second)};
if (product.SignedMultiplicationOverflowed()) {
@ -1217,7 +1545,10 @@ Expr<T> FoldOperation(FoldingContext &context, Multiply<T> &&x) {
template<typename T>
Expr<T> FoldOperation(FoldingContext &context, Divide<T> &&x) {
if (auto folded{FoldOperands(context, x.left(), x.right())}) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
if (auto folded{OperandsAreConstants(x)}) {
if constexpr (T::category == TypeCategory::Integer) {
auto quotAndRem{folded->first.DivideSigned(folded->second)};
if (quotAndRem.divisionByZero) {
@ -1242,7 +1573,10 @@ Expr<T> FoldOperation(FoldingContext &context, Divide<T> &&x) {
template<typename T>
Expr<T> FoldOperation(FoldingContext &context, Power<T> &&x) {
if (auto folded{FoldOperands(context, x.left(), x.right())}) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
if (auto folded{OperandsAreConstants(x)}) {
if constexpr (T::category == TypeCategory::Integer) {
auto power{folded->first.Power(folded->second)};
if (power.divisionByZero) {
@ -1264,9 +1598,12 @@ Expr<T> FoldOperation(FoldingContext &context, Power<T> &&x) {
template<typename T>
Expr<T> FoldOperation(FoldingContext &context, RealToIntPower<T> &&x) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
return std::visit(
[&](auto &y) -> Expr<T> {
if (auto folded{FoldOperands(context, x.left(), y)}) {
if (auto folded{OperandsAreConstants(x.left(), y)}) {
auto power{evaluate::IntPower(folded->first, folded->second)};
RealFlagWarnings(context, power.flags, "power with INTEGER exponent");
if (context.flushSubnormalsToZero()) {
@ -1282,7 +1619,10 @@ Expr<T> FoldOperation(FoldingContext &context, RealToIntPower<T> &&x) {
template<typename T>
Expr<T> FoldOperation(FoldingContext &context, Extremum<T> &&x) {
if (auto folded{FoldOperands(context, x.left(), x.right())}) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
if (auto folded{OperandsAreConstants(x)}) {
if constexpr (T::category == TypeCategory::Integer) {
if (folded->first.CompareSigned(folded->second) == x.ordering) {
return Expr<T>{Constant<T>{folded->first}};
@ -1306,8 +1646,11 @@ Expr<T> FoldOperation(FoldingContext &context, Extremum<T> &&x) {
template<int KIND>
Expr<Type<TypeCategory::Complex, KIND>> FoldOperation(
FoldingContext &context, ComplexConstructor<KIND> &&x) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
using Result = Type<TypeCategory::Complex, KIND>;
if (auto folded{FoldOperands(context, x.left(), x.right())}) {
if (auto folded{OperandsAreConstants(x)}) {
return Expr<Result>{
Constant<Result>{Scalar<Result>{folded->first, folded->second}}};
}
@ -1317,8 +1660,11 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldOperation(
template<int KIND>
Expr<Type<TypeCategory::Character, KIND>> FoldOperation(
FoldingContext &context, Concat<KIND> &&x) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
using Result = Type<TypeCategory::Character, KIND>;
if (auto folded{FoldOperands(context, x.left(), x.right())}) {
if (auto folded{OperandsAreConstants(x)}) {
return Expr<Result>{Constant<Result>{folded->first + folded->second}};
}
return Expr<Result>{std::move(x)};
@ -1327,8 +1673,11 @@ Expr<Type<TypeCategory::Character, KIND>> FoldOperation(
template<int KIND>
Expr<Type<TypeCategory::Character, KIND>> FoldOperation(
FoldingContext &context, SetLength<KIND> &&x) {
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
using Result = Type<TypeCategory::Character, KIND>;
if (auto folded{FoldOperands(context, x.left(), x.right())}) {
if (auto folded{OperandsAreConstants(x)}) {
auto oldLength{static_cast<std::int64_t>(folded->first.size())};
auto newLength{folded->second.ToInt64()};
if (newLength < oldLength) {
@ -1345,7 +1694,15 @@ Expr<Type<TypeCategory::Character, KIND>> FoldOperation(
template<typename T>
Expr<LogicalResult> FoldOperation(
FoldingContext &context, Relational<T> &&relation) {
if (auto folded{FoldOperands(context, relation.left(), relation.right())}) {
if (auto array{ApplyElementwise(context, relation,
std::function<Expr<LogicalResult>(Expr<T> &&, Expr<T> &&)>{
[=](Expr<T> &&x, Expr<T> &&y) {
return Expr<LogicalResult>{Relational<SomeType>{
Relational<T>{relation.opr, std::move(x), std::move(y)}}};
}})}) {
return *array;
}
if (auto folded{OperandsAreConstants(relation)}) {
bool result{};
if constexpr (T::category == TypeCategory::Integer) {
result =
@ -1374,11 +1731,19 @@ inline Expr<LogicalResult> FoldOperation(
template<int KIND>
Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
FoldingContext &context, LogicalOperation<KIND> &&x) {
FoldingContext &context, LogicalOperation<KIND> &&operation) {
using LOGICAL = Type<TypeCategory::Logical, KIND>;
if (auto folded{FoldOperands(context, x.left(), x.right())}) {
if (auto array{ApplyElementwise(context, operation,
std::function<Expr<LOGICAL>(Expr<LOGICAL> &&, Expr<LOGICAL> &&)>{
[=](Expr<LOGICAL> &&x, Expr<LOGICAL> &&y) {
return Expr<LOGICAL>{LogicalOperation<KIND>{
operation.logicalOperator, std::move(x), std::move(y)}};
}})}) {
return *array;
}
if (auto folded{OperandsAreConstants(operation)}) {
bool xt{folded->first.IsTrue()}, yt{folded->second.IsTrue()}, result{};
switch (x.logicalOperator) {
switch (operation.logicalOperator) {
case LogicalOperator::And: result = xt && yt; break;
case LogicalOperator::Or: result = xt || yt; break;
case LogicalOperator::Eqv: result = xt == yt; break;
@ -1386,7 +1751,7 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
}
return Expr<LOGICAL>{Constant<LOGICAL>{result}};
}
return Expr<LOGICAL>{std::move(x)};
return Expr<LOGICAL>{std::move(operation)};
}
// end per-operation folding functions
@ -1420,7 +1785,6 @@ FOR_EACH_TYPE_AND_KIND(template class ExpressionBase, )
// able to fold it (yet) into a known constant value; specifically,
// the expression may reference derived type kind parameters whose values
// are not yet known.
class IsConstantExprVisitor : public virtual VisitorBase<bool> {
public:
using Result = bool;

View File

@ -52,7 +52,7 @@ std::optional<Expr<T>> Fold(
template<typename T>
std::optional<Scalar<T>> GetScalarConstantValue(const Expr<T> &expr) {
if (const auto *c{UnwrapExpr<Constant<T>>(expr)}) {
if (c->size() == 1) {
if (c->Rank() == 0) {
return **c;
} else {
return std::nullopt;

View File

@ -22,8 +22,7 @@
namespace Fortran::evaluate {
static void ShapeAsFortran(
std::ostream &o, const std::vector<std::int64_t> &shape) {
static void ShapeAsFortran(std::ostream &o, const ConstantSubscripts &shape) {
if (shape.size() > 1) {
o << ",shape=";
char ch{'['};
@ -101,6 +100,10 @@ std::ostream &Constant<Type<TypeCategory::Character, KIND>>::AsFortran(
return o;
}
std::ostream &ActualArgument::AssumedType::AsFortran(std::ostream &o) const {
return o << symbol_->name().ToString();
}
std::ostream &ActualArgument::AsFortran(std::ostream &o) const {
if (keyword.has_value()) {
o << keyword->ToString() << '=';
@ -108,7 +111,11 @@ std::ostream &ActualArgument::AsFortran(std::ostream &o) const {
if (isAlternateReturn) {
o << '*';
}
return value().AsFortran(o);
if (const auto *expr{GetExpr()}) {
return expr->AsFortran(o);
} else {
return std::get<AssumedType>(u_).AsFortran(o);
}
}
std::ostream &SpecificIntrinsic::AsFortran(std::ostream &o) const {
@ -321,7 +328,7 @@ template<typename T>
std::ostream &EmitArray(
std::ostream &o, const ArrayConstructorValues<T> &values) {
const char *sep{""};
for (const auto &value : values.values()) {
for (const auto &value : values) {
o << sep;
std::visit([&](const auto &x) { EmitArray(o, x); }, value.u);
sep = ",";

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "intrinsics.h"
#include "common.h"
#include "expression.h"
#include "fold.h"
#include "shape.h"
@ -32,7 +33,7 @@ using namespace Fortran::parser::literals;
namespace Fortran::evaluate {
using common::TypeCategory;
class FoldingContext;
// This file defines the supported intrinsic procedures and implements
// their recognition and validation. It is largely table-driven. See
@ -53,6 +54,8 @@ using common::TypeCategory;
// INTEGER with a special "typeless" kind code. Arguments of intrinsic types
// that can also be be typeless values are encoded with an "elementalOrBOZ"
// rank pattern.
// Assumed-type (TYPE(*)) dummy arguments can be forwarded along to some
// intrinsic functions that accept AnyType + Rank::anyOrAssumedRank.
using CategorySet = common::EnumSet<TypeCategory, 8>;
static constexpr CategorySet IntType{TypeCategory::Integer};
static constexpr CategorySet RealType{TypeCategory::Real};
@ -72,12 +75,12 @@ ENUM_CLASS(KindCode, none, defaultIntegerKind,
defaultRealKind, // is also the default COMPLEX kind
doublePrecision, defaultCharKind, defaultLogicalKind,
any, // matches any kind value; each instance is independent
same, // match any kind, but all "same" kinds must be equal
typeless, // BOZ literals are INTEGER with this kind
teamType, // TEAM_TYPE from module ISO_FORTRAN_ENV (for coarrays)
kindArg, // this argument is KIND=
effectiveKind, // for function results: same "kindArg", possibly defaulted
dimArg, // this argument is DIM=
same, // match any kind; all "same" kinds must be equal
likeMultiply, // for DOT_PRODUCT and MATMUL
)
@ -153,7 +156,7 @@ ENUM_CLASS(Rank,
matrix,
array, // not scalar, rank is known and greater than zero
known, // rank is known and can be scalar
anyOrAssumedRank, // rank can be unknown
anyOrAssumedRank, // rank can be unknown; assumed-type TYPE(*) allowed
conformable, // scalar, or array of same rank & shape as "array" argument
reduceOperation, // a pure function with constraints for REDUCE
dimReduced, // scalar if no DIM= argument, else rank(array)-1
@ -207,7 +210,7 @@ struct IntrinsicInterface {
Rank rank{Rank::elemental};
std::optional<SpecificCall> Match(const CallCharacteristics &,
const common::IntrinsicTypeDefaultKinds &, ActualArguments &,
parser::ContextualMessages &messages) const;
FoldingContext &context) const;
int CountArguments() const;
std::ostream &Dump(std::ostream &) const;
};
@ -771,7 +774,8 @@ static const SpecificIntrinsicInterface specificIntrinsicFunction[]{
std::optional<SpecificCall> IntrinsicInterface::Match(
const CallCharacteristics &call,
const common::IntrinsicTypeDefaultKinds &defaults,
ActualArguments &arguments, parser::ContextualMessages &messages) const {
ActualArguments &arguments, FoldingContext &context) const {
auto &messages{context.messages()};
// Attempt to construct a 1-1 correspondence between the dummy arguments in
// a particular intrinsic procedure's generic interface and the actual
// arguments in a procedure reference.
@ -868,6 +872,18 @@ std::optional<SpecificCall> IntrinsicInterface::Match(
continue;
}
}
if (arg->GetAssumedTypeDummy()) {
// TYPE(*) assumed-type dummy argument forwarded to intrinsic
if (d.typePattern.categorySet == AnyType &&
d.typePattern.kindCode == KindCode::any &&
d.rank == Rank::anyOrAssumedRank) {
continue;
}
messages.Say("Assumed type TYPE(*) dummy argument not allowed "
"for '%s=' intrinsic argument"_err_en_US,
d.keyword);
return std::nullopt;
}
std::optional<DynamicType> type{arg->GetType()};
if (!type.has_value()) {
CHECK(arg->Rank() == 0);
@ -946,7 +962,7 @@ std::optional<SpecificCall> IntrinsicInterface::Match(
for (std::size_t j{0}; j < dummies; ++j) {
const IntrinsicDummyArgument &d{dummy[std::min(j, dummyArgPatterns - 1)]};
if (const ActualArgument * arg{actualForDummy[j]}) {
if (IsAssumedRank(arg->value()) && d.rank != Rank::anyOrAssumedRank) {
if (IsAssumedRank(*arg) && d.rank != Rank::anyOrAssumedRank) {
messages.Say("assumed-rank array cannot be forwarded to "
"'%s=' argument"_err_en_US,
d.keyword);
@ -967,7 +983,7 @@ std::optional<SpecificCall> IntrinsicInterface::Match(
case Rank::shape:
CHECK(!shapeArgSize.has_value());
if (rank == 1) {
if (auto shape{GetShape(*arg)}) {
if (auto shape{GetShape(context, *arg)}) {
if (auto constShape{AsConstantShape(*shape)}) {
shapeArgSize = (**constShape).ToInt64();
CHECK(shapeArgSize >= 0);
@ -1083,12 +1099,13 @@ std::optional<SpecificCall> IntrinsicInterface::Match(
CHECK(kindDummyArg != nullptr);
CHECK(result.categorySet == CategorySet{resultType->category});
if (kindArg != nullptr) {
auto &expr{kindArg->value()};
CHECK(expr.Rank() == 0);
if (auto code{ToInt64(expr)}) {
if (IsValidKindOfIntrinsicType(resultType->category, *code)) {
resultType->kind = *code;
break;
if (auto *expr{kindArg->GetExpr()}) {
CHECK(expr->Rank() == 0);
if (auto code{ToInt64(*expr)}) {
if (IsValidKindOfIntrinsicType(resultType->category, *code)) {
resultType->kind = *code;
break;
}
}
}
messages.Say("'kind=' argument must be a constant scalar integer "
@ -1196,8 +1213,8 @@ public:
bool IsIntrinsic(const std::string &) const;
std::optional<SpecificCall> Probe(const CallCharacteristics &,
ActualArguments &, parser::ContextualMessages *) const;
std::optional<SpecificCall> Probe(
const CallCharacteristics &, ActualArguments &, FoldingContext &) const;
std::optional<UnrestrictedSpecificIntrinsicFunctionInterface>
IsUnrestrictedSpecificIntrinsicFunction(const std::string &) const;
@ -1230,21 +1247,21 @@ bool IntrinsicProcTable::Implementation::IsIntrinsic(
// match for a given procedure reference.
std::optional<SpecificCall> IntrinsicProcTable::Implementation::Probe(
const CallCharacteristics &call, ActualArguments &arguments,
parser::ContextualMessages *messages) const {
FoldingContext &context) const {
if (call.isSubroutineCall) {
return std::nullopt; // TODO
}
parser::Messages *finalBuffer{messages ? messages->messages() : nullptr};
parser::Messages *finalBuffer{context.messages().messages()};
// Probe the specific intrinsic function table first.
parser::Messages specificBuffer;
parser::ContextualMessages specificErrors{
messages ? messages->at() : call.name,
finalBuffer ? &specificBuffer : nullptr};
call.name, finalBuffer ? &specificBuffer : nullptr};
FoldingContext specificContext{context, specificErrors};
std::string name{call.name.ToString()};
auto specificRange{specificFuncs_.equal_range(name)};
for (auto iter{specificRange.first}; iter != specificRange.second; ++iter) {
if (auto specificCall{
iter->second->Match(call, defaults_, arguments, specificErrors)}) {
iter->second->Match(call, defaults_, arguments, specificContext)}) {
if (const char *genericName{iter->second->generic}) {
specificCall->specificIntrinsic.name = genericName;
}
@ -1256,12 +1273,12 @@ std::optional<SpecificCall> IntrinsicProcTable::Implementation::Probe(
// Probe the generic intrinsic function table next.
parser::Messages genericBuffer;
parser::ContextualMessages genericErrors{
messages ? messages->at() : call.name,
finalBuffer ? &genericBuffer : nullptr};
call.name, finalBuffer ? &genericBuffer : nullptr};
FoldingContext genericContext{context, genericErrors};
auto genericRange{genericFuncs_.equal_range(name)};
for (auto iter{genericRange.first}; iter != genericRange.second; ++iter) {
if (auto specificCall{
iter->second->Match(call, defaults_, arguments, genericErrors)}) {
iter->second->Match(call, defaults_, arguments, genericContext)}) {
return specificCall;
}
}
@ -1277,20 +1294,20 @@ std::optional<SpecificCall> IntrinsicProcTable::Implementation::Probe(
genericErrors.Say("unknown argument '%s' to NULL()"_err_en_US,
arguments[0]->keyword->ToString().data());
} else {
Expr<SomeType> &mold{arguments[0]->value()};
if (IsPointerOrAllocatable(mold)) {
return std::make_optional<SpecificCall>(
SpecificIntrinsic{"null"s, mold.GetType(), mold.Rank(),
semantics::Attrs{semantics::Attr::POINTER}},
std::move(arguments));
} else {
genericErrors.Say("MOLD argument to NULL() must be a pointer "
"or allocatable"_err_en_US);
if (Expr<SomeType> * mold{arguments[0]->GetExpr()}) {
if (IsPointerOrAllocatable(*mold)) {
return std::make_optional<SpecificCall>(
SpecificIntrinsic{"null"s, mold->GetType(), mold->Rank(),
semantics::Attrs{semantics::Attr::POINTER}},
std::move(arguments));
}
}
genericErrors.Say("MOLD argument to NULL() must be a pointer "
"or allocatable"_err_en_US);
}
}
// No match
if (finalBuffer) {
if (finalBuffer != nullptr) {
if (genericBuffer.empty()) {
finalBuffer->Annex(std::move(specificBuffer));
} else {
@ -1358,9 +1375,9 @@ bool IntrinsicProcTable::IsIntrinsic(const std::string &name) const {
std::optional<SpecificCall> IntrinsicProcTable::Probe(
const CallCharacteristics &call, ActualArguments &arguments,
parser::ContextualMessages *messages) const {
FoldingContext &context) const {
CHECK(impl_ != nullptr || !"IntrinsicProcTable: not configured");
return impl_->Probe(call, arguments, messages);
return impl_->Probe(call, arguments, context);
}
std::optional<UnrestrictedSpecificIntrinsicFunctionInterface>

View File

@ -26,6 +26,8 @@
namespace Fortran::evaluate {
class FoldingContext;
struct CallCharacteristics {
parser::CharBlock name;
bool isSubroutineCall{false};
@ -61,8 +63,8 @@ public:
// Probe the intrinsics for a match against a specific call.
// On success, the actual arguments are transferred to the result
// in dummy argument order.
std::optional<SpecificCall> Probe(const CallCharacteristics &,
ActualArguments &, parser::ContextualMessages *messages = nullptr) const;
std::optional<SpecificCall> Probe(
const CallCharacteristics &, ActualArguments &, FoldingContext &) const;
// Probe the intrinsics with the name of a potential unrestricted specific
// intrinsic.

View File

@ -33,13 +33,15 @@ Shape AsShape(const Constant<ExtentType> &arrayConstant) {
return result;
}
std::optional<Shape> AsShape(ExtentExpr &&arrayExpr) {
std::optional<Shape> AsShape(FoldingContext &context, ExtentExpr &&arrayExpr) {
// Flatten any array expression into an array constructor if possible.
arrayExpr = Fold(context, std::move(arrayExpr));
if (auto *constArray{UnwrapExpr<Constant<ExtentType>>(arrayExpr)}) {
return AsShape(*constArray);
}
if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
Shape result;
for (auto &value : constructor->values()) {
for (auto &value : *constructor) {
if (auto *expr{std::get_if<ExtentExpr>(&value.u)}) {
if (expr->Rank() == 0) {
result.emplace_back(std::move(*expr));
@ -50,13 +52,10 @@ std::optional<Shape> AsShape(ExtentExpr &&arrayExpr) {
}
return result;
}
// TODO: linearize other array-valued expressions of known shape, e.g. A+B
// as well as conversions of arrays; this will be easier given a
// general-purpose array expression flattener (pmk)
return std::nullopt;
}
std::optional<ExtentExpr> AsShapeArrayExpr(const Shape &shape) {
std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) {
ArrayConstructorValues<ExtentType> values;
for (const auto &dim : shape) {
if (dim.has_value()) {
@ -69,7 +68,7 @@ std::optional<ExtentExpr> AsShapeArrayExpr(const Shape &shape) {
}
std::optional<Constant<ExtentType>> AsConstantShape(const Shape &shape) {
if (auto shapeArray{AsShapeArrayExpr(shape)}) {
if (auto shapeArray{AsExtentArrayExpr(shape)}) {
FoldingContext noFoldingContext;
auto folded{Fold(noFoldingContext, std::move(*shapeArray))};
if (auto *p{UnwrapExpr<Constant<ExtentType>>(folded)}) {
@ -79,6 +78,22 @@ std::optional<Constant<ExtentType>> AsConstantShape(const Shape &shape) {
return std::nullopt;
}
ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) {
ConstantSubscripts result;
for (const auto &extent : shape.values()) {
result.push_back(extent.ToInt64());
}
return result;
}
std::optional<ConstantSubscripts> AsConstantExtents(const Shape &shape) {
if (auto shapeConstant{AsConstantShape(shape)}) {
return AsConstantExtents(*shapeConstant);
} else {
return std::nullopt;
}
}
static ExtentExpr ComputeTripCount(
ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
ExtentExpr strideCopy{common::Clone(stride)};
@ -121,7 +136,16 @@ MaybeExtent GetSize(Shape &&shape) {
return extent;
}
static MaybeExtent GetLowerBound(
bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
struct MyVisitor : public virtual VisitorBase<bool> {
using Result = bool;
explicit MyVisitor(int) { result() = false; }
void Handle(const ImpliedDoIndex &) { Return(true); }
};
return Visitor<MyVisitor>{0}.Traverse(expr);
}
MaybeExtent GetShapeHelper::GetLowerBound(
const Symbol &symbol, const Component *component, int dimension) {
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
int j{0};
@ -142,7 +166,7 @@ static MaybeExtent GetLowerBound(
return std::nullopt;
}
static MaybeExtent GetExtent(
MaybeExtent GetShapeHelper::GetExtent(
const Symbol &symbol, const Component *component, int dimension) {
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
int j{0};
@ -169,8 +193,8 @@ static MaybeExtent GetExtent(
return std::nullopt;
}
static MaybeExtent GetExtent(const Subscript &subscript, const Symbol &symbol,
const Component *component, int dimension) {
MaybeExtent GetShapeHelper::GetExtent(const Subscript &subscript,
const Symbol &symbol, const Component *component, int dimension) {
return std::visit(
common::visitors{
[&](const Triplet &triplet) -> MaybeExtent {
@ -185,7 +209,7 @@ static MaybeExtent GetExtent(const Subscript &subscript, const Symbol &symbol,
return CountTrips(std::move(lower), std::move(upper),
MaybeExtent{triplet.stride()});
},
[](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtent {
[&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtent {
if (auto shape{GetShape(subs.value())}) {
if (shape->size() > 0) {
CHECK(shape->size() == 1); // vector-valued subscript
@ -198,16 +222,7 @@ static MaybeExtent GetExtent(const Subscript &subscript, const Symbol &symbol,
subscript.u);
}
bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
struct MyVisitor : public virtual VisitorBase<bool> {
using Result = bool;
explicit MyVisitor(int) { result() = false; }
void Handle(const ImpliedDoIndex &) { Return(true); }
};
return Visitor<MyVisitor>{0}.Traverse(expr);
}
std::optional<Shape> GetShape(
std::optional<Shape> GetShapeHelper::GetShape(
const Symbol &symbol, const Component *component) {
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
Shape result;
@ -221,7 +236,7 @@ std::optional<Shape> GetShape(
}
}
std::optional<Shape> GetShape(const Symbol *symbol) {
std::optional<Shape> GetShapeHelper::GetShape(const Symbol *symbol) {
if (symbol != nullptr) {
return GetShape(*symbol);
} else {
@ -229,7 +244,7 @@ std::optional<Shape> GetShape(const Symbol *symbol) {
}
}
std::optional<Shape> GetShape(const BaseObject &object) {
std::optional<Shape> GetShapeHelper::GetShape(const BaseObject &object) {
if (const Symbol * symbol{object.symbol()}) {
return GetShape(*symbol);
} else {
@ -237,7 +252,7 @@ std::optional<Shape> GetShape(const BaseObject &object) {
}
}
std::optional<Shape> GetShape(const Component &component) {
std::optional<Shape> GetShapeHelper::GetShape(const Component &component) {
const Symbol &symbol{component.GetLastSymbol()};
if (symbol.Rank() > 0) {
return GetShape(symbol, &component);
@ -246,7 +261,7 @@ std::optional<Shape> GetShape(const Component &component) {
}
}
std::optional<Shape> GetShape(const ArrayRef &arrayRef) {
std::optional<Shape> GetShapeHelper::GetShape(const ArrayRef &arrayRef) {
Shape shape;
const Symbol &symbol{arrayRef.GetLastSymbol()};
const Component *component{std::get_if<Component>(&arrayRef.base())};
@ -264,7 +279,7 @@ std::optional<Shape> GetShape(const ArrayRef &arrayRef) {
}
}
std::optional<Shape> GetShape(const CoarrayRef &coarrayRef) {
std::optional<Shape> GetShapeHelper::GetShape(const CoarrayRef &coarrayRef) {
Shape shape;
SymbolOrComponent base{coarrayRef.GetBaseSymbolOrComponent()};
const Symbol &symbol{coarrayRef.GetLastSymbol()};
@ -283,11 +298,11 @@ std::optional<Shape> GetShape(const CoarrayRef &coarrayRef) {
}
}
std::optional<Shape> GetShape(const DataRef &dataRef) {
std::optional<Shape> GetShapeHelper::GetShape(const DataRef &dataRef) {
return GetShape(dataRef.u);
}
std::optional<Shape> GetShape(const Substring &substring) {
std::optional<Shape> GetShapeHelper::GetShape(const Substring &substring) {
if (const auto *dataRef{substring.GetParentIf<DataRef>()}) {
return GetShape(*dataRef);
} else {
@ -295,15 +310,21 @@ std::optional<Shape> GetShape(const Substring &substring) {
}
}
std::optional<Shape> GetShape(const ComplexPart &part) {
std::optional<Shape> GetShapeHelper::GetShape(const ComplexPart &part) {
return GetShape(part.complex());
}
std::optional<Shape> GetShape(const ActualArgument &arg) {
return GetShape(arg.value());
std::optional<Shape> GetShapeHelper::GetShape(const ActualArgument &arg) {
if (const auto *expr{arg.GetExpr()}) {
return GetShape(*expr);
} else {
const Symbol *aType{arg.GetAssumedTypeDummy()};
CHECK(aType != nullptr);
return GetShape(*aType);
}
}
std::optional<Shape> GetShape(const ProcedureRef &call) {
std::optional<Shape> GetShapeHelper::GetShape(const ProcedureRef &call) {
if (call.Rank() == 0) {
return Shape{};
} else if (call.IsElemental()) {
@ -318,14 +339,16 @@ std::optional<Shape> GetShape(const ProcedureRef &call) {
std::get_if<SpecificIntrinsic>(&call.proc().u)}) {
if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
intrinsic->name == "ubound") {
return Shape{MaybeExtent{
ExtentExpr{call.arguments().front().value().value().Rank()}}};
const auto *expr{call.arguments().front().value().GetExpr()};
CHECK(expr != nullptr);
return Shape{MaybeExtent{ExtentExpr{expr->Rank()}}};
} else if (intrinsic->name == "reshape") {
if (call.arguments().size() >= 2 && call.arguments().at(1).has_value()) {
// SHAPE(RESHAPE(array,shape)) -> shape
const Expr<SomeType> &shapeExpr{call.arguments().at(1)->value()};
Expr<SomeInteger> shape{std::get<Expr<SomeInteger>>(shapeExpr.u)};
return AsShape(ConvertToType<ExtentType>(std::move(shape)));
const auto *shapeExpr{call.arguments().at(1).value().GetExpr()};
CHECK(shapeExpr != nullptr);
Expr<SomeInteger> shape{std::get<Expr<SomeInteger>>(shapeExpr->u)};
return AsShape(context_, ConvertToType<ExtentType>(std::move(shape)));
}
} else {
// TODO: shapes of other non-elemental intrinsic results
@ -334,28 +357,54 @@ std::optional<Shape> GetShape(const ProcedureRef &call) {
return std::nullopt;
}
std::optional<Shape> GetShape(const Relational<SomeType> &relation) {
std::optional<Shape> GetShapeHelper::GetShape(
const Relational<SomeType> &relation) {
return GetShape(relation.u);
}
std::optional<Shape> GetShape(const StructureConstructor &) {
std::optional<Shape> GetShapeHelper::GetShape(const StructureConstructor &) {
return Shape{}; // always scalar
}
std::optional<Shape> GetShape(const ImpliedDoIndex &) {
std::optional<Shape> GetShapeHelper::GetShape(const ImpliedDoIndex &) {
return Shape{}; // always scalar
}
std::optional<Shape> GetShape(const DescriptorInquiry &) {
std::optional<Shape> GetShapeHelper::GetShape(const DescriptorInquiry &) {
return Shape{}; // always scalar
}
std::optional<Shape> GetShape(const BOZLiteralConstant &) {
std::optional<Shape> GetShapeHelper::GetShape(const BOZLiteralConstant &) {
return Shape{}; // always scalar
}
std::optional<Shape> GetShape(const NullPointer &) {
std::optional<Shape> GetShapeHelper::GetShape(const NullPointer &) {
return {}; // not an object
}
void CheckConformance(parser::ContextualMessages &messages, const Shape &left,
const Shape &right) {
if (!left.empty() && !right.empty()) {
int n{static_cast<int>(left.size())};
int rn{static_cast<int>(right.size())};
if (n != rn) {
messages.Say(
"Left operand has rank %d, but right operand has rank %d"_err_en_US,
n, rn);
} else {
for (int j{0}; j < n; ++j) {
if (auto leftDim{ToInt64(left[j])}) {
if (auto rightDim{ToInt64(right[j])}) {
if (*leftDim != *rightDim) {
messages.Say("Dimension %d of left operand has extent %jd, "
"but right operand has extent %jd"_err_en_US,
j + 1, static_cast<std::intmax_t>(*leftDim),
static_cast<std::intmax_t>(*rightDim));
}
}
}
}
}
}
}
}

View File

@ -27,16 +27,20 @@
namespace Fortran::evaluate {
class FoldingContext;
using ExtentType = SubscriptInteger;
using ExtentExpr = Expr<ExtentType>;
using MaybeExtent = std::optional<ExtentExpr>;
using Shape = std::vector<MaybeExtent>;
// Convert between various representations of shapes
// Conversions between various representations of shapes.
Shape AsShape(const Constant<ExtentType> &arrayConstant);
std::optional<Shape> AsShape(ExtentExpr &&arrayExpr);
std::optional<ExtentExpr> AsShapeArrayExpr(const Shape &); // array constructor
std::optional<Shape> AsShape(FoldingContext &, ExtentExpr &&arrayExpr);
std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &);
std::optional<Constant<ExtentType>> AsConstantShape(const Shape &);
ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &);
std::optional<ConstantSubscripts> AsConstantExtents(const Shape &);
// Compute an element count for a triplet or trip count for a DO.
ExtentExpr CountTrips(
@ -49,132 +53,150 @@ MaybeExtent CountTrips(
// Computes SIZE() == PRODUCT(shape)
MaybeExtent GetSize(Shape &&);
// Forward declarations
template<typename... A>
std::optional<Shape> GetShape(const std::variant<A...> &);
template<typename A, bool COPY>
std::optional<Shape> GetShape(const common::Indirection<A, COPY> &);
template<typename A> std::optional<Shape> GetShape(const std::optional<A> &);
template<typename T> std::optional<Shape> GetShape(const Expr<T> &expr) {
return GetShape(expr.u);
}
std::optional<Shape> GetShape(const Symbol &, const Component * = nullptr);
std::optional<Shape> GetShape(const Symbol *);
std::optional<Shape> GetShape(const BaseObject &);
std::optional<Shape> GetShape(const Component &);
std::optional<Shape> GetShape(const ArrayRef &);
std::optional<Shape> GetShape(const CoarrayRef &);
std::optional<Shape> GetShape(const DataRef &);
std::optional<Shape> GetShape(const Substring &);
std::optional<Shape> GetShape(const ComplexPart &);
std::optional<Shape> GetShape(const ActualArgument &);
std::optional<Shape> GetShape(const ProcedureRef &);
std::optional<Shape> GetShape(const ImpliedDoIndex &);
std::optional<Shape> GetShape(const Relational<SomeType> &);
std::optional<Shape> GetShape(const StructureConstructor &);
std::optional<Shape> GetShape(const DescriptorInquiry &);
std::optional<Shape> GetShape(const BOZLiteralConstant &);
std::optional<Shape> GetShape(const NullPointer &);
template<typename T> std::optional<Shape> GetShape(const Constant<T> &c) {
Constant<ExtentType> shape{c.SHAPE()};
return AsShape(shape);
}
template<typename T>
std::optional<Shape> GetShape(const Designator<T> &designator) {
return GetShape(designator.u);
}
template<typename T>
std::optional<Shape> GetShape(const Variable<T> &variable) {
return GetShape(variable.u);
}
template<typename D, typename R, typename... O>
std::optional<Shape> GetShape(const Operation<D, R, O...> &operation) {
if constexpr (sizeof...(O) > 1) {
if (operation.right().Rank() > 0) {
return GetShape(operation.right());
}
}
return GetShape(operation.left());
}
template<int KIND>
std::optional<Shape> GetShape(const TypeParamInquiry<KIND> &) {
return Shape{}; // always scalar, even when applied to an array
}
// Utility predicate: does an expression reference any implied DO index?
bool ContainsAnyImpliedDoIndex(const ExtentExpr &);
template<typename T> MaybeExtent GetExtent(const ArrayConstructorValues<T> &);
// Compilation-time shape conformance checking, when corresponding extents
// are known.
void CheckConformance(
parser::ContextualMessages &, const Shape &, const Shape &);
template<typename T>
MaybeExtent GetExtent(const ArrayConstructorValue<T> &value) {
return std::visit(
common::visitors{
[](const Expr<T> &x) -> MaybeExtent {
if (std::optional<Shape> xShape{GetShape(x)}) {
// Array values in array constructors get linearized.
return GetSize(std::move(*xShape));
}
return std::nullopt;
},
[](const ImpliedDo<T> &ido) -> MaybeExtent {
// Don't be heroic and try to figure out triangular implied DO
// nests.
if (!ContainsAnyImpliedDoIndex(ido.lower()) &&
!ContainsAnyImpliedDoIndex(ido.upper()) &&
!ContainsAnyImpliedDoIndex(ido.stride())) {
if (auto nValues{GetExtent(ido.values())}) {
return std::move(*nValues) *
CountTrips(ido.lower(), ido.upper(), ido.stride());
}
}
return std::nullopt;
},
},
value.u);
}
// The implementation of GetShape() is wrapped in a helper class
// so that the member functions may mutually recurse without prototypes.
class GetShapeHelper {
public:
explicit GetShapeHelper(FoldingContext &context) : context_{context} {}
template<typename T>
MaybeExtent GetExtent(const ArrayConstructorValues<T> &values) {
ExtentExpr result{0};
for (const auto &value : values.values()) {
if (MaybeExtent n{GetExtent(value)}) {
result = std::move(result) + std::move(*n);
template<typename T> std::optional<Shape> GetShape(const Expr<T> &expr) {
return GetShape(expr.u);
}
std::optional<Shape> GetShape(const Symbol &, const Component * = nullptr);
std::optional<Shape> GetShape(const Symbol *);
std::optional<Shape> GetShape(const BaseObject &);
std::optional<Shape> GetShape(const Component &);
std::optional<Shape> GetShape(const ArrayRef &);
std::optional<Shape> GetShape(const CoarrayRef &);
std::optional<Shape> GetShape(const DataRef &);
std::optional<Shape> GetShape(const Substring &);
std::optional<Shape> GetShape(const ComplexPart &);
std::optional<Shape> GetShape(const ActualArgument &);
std::optional<Shape> GetShape(const ProcedureRef &);
std::optional<Shape> GetShape(const ImpliedDoIndex &);
std::optional<Shape> GetShape(const Relational<SomeType> &);
std::optional<Shape> GetShape(const StructureConstructor &);
std::optional<Shape> GetShape(const DescriptorInquiry &);
std::optional<Shape> GetShape(const BOZLiteralConstant &);
std::optional<Shape> GetShape(const NullPointer &);
template<typename T> std::optional<Shape> GetShape(const Constant<T> &c) {
Constant<ExtentType> shape{c.SHAPE()};
return AsShape(shape);
}
template<typename T>
std::optional<Shape> GetShape(const Designator<T> &designator) {
return GetShape(designator.u);
}
template<typename T>
std::optional<Shape> GetShape(const Variable<T> &variable) {
return GetShape(variable.u);
}
template<typename D, typename R, typename... O>
std::optional<Shape> GetShape(const Operation<D, R, O...> &operation) {
if constexpr (sizeof...(O) > 1) {
if (operation.right().Rank() > 0) {
return GetShape(operation.right());
}
}
return GetShape(operation.left());
}
template<int KIND>
std::optional<Shape> GetShape(const TypeParamInquiry<KIND> &) {
return Shape{}; // always scalar, even when applied to an array
}
template<typename T>
std::optional<Shape> GetShape(const ArrayConstructor<T> &aconst) {
return Shape{GetExtent(aconst)};
}
template<typename... A>
std::optional<Shape> GetShape(const std::variant<A...> &u) {
return std::visit([&](const auto &x) { return GetShape(x); }, u);
}
template<typename A, bool COPY>
std::optional<Shape> GetShape(const common::Indirection<A, COPY> &p) {
return GetShape(p.value());
}
template<typename A>
std::optional<Shape> GetShape(const std::optional<A> &x) {
if (x.has_value()) {
return GetShape(*x);
} else {
return std::nullopt;
}
}
return result;
}
template<typename T>
std::optional<Shape> GetShape(const ArrayConstructor<T> &aconst) {
return Shape{GetExtent(aconst)};
}
private:
MaybeExtent GetLowerBound(const Symbol &, const Component *, int dimension);
template<typename... A>
std::optional<Shape> GetShape(const std::variant<A...> &u) {
return std::visit([](const auto &x) { return GetShape(x); }, u);
}
template<typename A, bool COPY>
std::optional<Shape> GetShape(const common::Indirection<A, COPY> &p) {
return GetShape(p.value());
}
template<typename A> std::optional<Shape> GetShape(const std::optional<A> &x) {
if (x.has_value()) {
return GetShape(*x);
} else {
return std::nullopt;
template<typename T>
MaybeExtent GetExtent(const ArrayConstructorValue<T> &value) {
return std::visit(
common::visitors{
[&](const Expr<T> &x) -> MaybeExtent {
if (std::optional<Shape> xShape{GetShape(x)}) {
// Array values in array constructors get linearized.
return GetSize(std::move(*xShape));
}
return std::nullopt;
},
[&](const ImpliedDo<T> &ido) -> MaybeExtent {
// Don't be heroic and try to figure out triangular implied DO
// nests.
if (!ContainsAnyImpliedDoIndex(ido.lower()) &&
!ContainsAnyImpliedDoIndex(ido.upper()) &&
!ContainsAnyImpliedDoIndex(ido.stride())) {
if (auto nValues{GetExtent(ido.values())}) {
return std::move(*nValues) *
CountTrips(ido.lower(), ido.upper(), ido.stride());
}
}
return std::nullopt;
},
},
value.u);
}
template<typename T>
MaybeExtent GetExtent(const ArrayConstructorValues<T> &values) {
ExtentExpr result{0};
for (const auto &value : values) {
if (MaybeExtent n{GetExtent(value)}) {
result = std::move(result) + std::move(*n);
} else {
return std::nullopt;
}
}
return result;
}
MaybeExtent GetExtent(const Symbol &, const Component *, int dimension);
MaybeExtent GetExtent(
const Subscript &, const Symbol &, const Component *, int dimension);
FoldingContext &context_;
};
template<typename A>
std::optional<Shape> GetShape(FoldingContext &context, const A &x) {
return GetShapeHelper{context}.GetShape(x);
}
}
#endif // FORTRAN_EVALUATE_SHAPE_H_

View File

@ -1161,7 +1161,7 @@ template<typename T>
ArrayConstructorValues<T> MakeSpecific(
ArrayConstructorValues<SomeType> &&from) {
ArrayConstructorValues<T> to;
for (ArrayConstructorValue<SomeType> &x : from.values()) {
for (ArrayConstructorValue<SomeType> &x : from) {
std::visit(
common::visitors{
[&](common::CopyableIndirection<Expr<SomeType>> &&expr) {
@ -1456,7 +1456,7 @@ auto ExpressionAnalyzer::Procedure(const parser::ProcedureDesignator &pd,
CallCharacteristics cc{n.source};
if (std::optional<SpecificCall> specificCall{
context().intrinsics().Probe(
cc, arguments, &GetContextualMessages())}) {
cc, arguments, GetFoldingContext())}) {
return {
CallAndArguments{ProcedureDesignator{std::move(
specificCall->specificIntrinsic)},

View File

@ -14,6 +14,7 @@
#include "../../lib/evaluate/intrinsics.h"
#include "testing.h"
#include "../../lib/evaluate/common.h"
#include "../../lib/evaluate/expression.h"
#include "../../lib/evaluate/tools.h"
#include "../../lib/parser/provenance.h"
@ -111,7 +112,8 @@ struct TestCall {
std::cout << ")\n";
CallCharacteristics call{fName};
auto messages{strings.Messages(buffer)};
std::optional<SpecificCall> si{table.Probe(call, args, &messages)};
FoldingContext context{messages};
std::optional<SpecificCall> si{table.Probe(call, args, context)};
if (resultType.has_value()) {
TEST(si.has_value());
TEST(buffer.empty());

View File

@ -29,7 +29,7 @@ module m1
integer(8), parameter :: ac3bs(:) = shape([(1,j=4,1,-1)])
integer(8), parameter :: ac4s(:) = shape([((j,k,j*k,k=1,3),j=1,4)])
integer(8), parameter :: ac5s(:) = shape([((0,k=5,1,-2),j=9,2,-3)])
integer(8), parameter :: rss(:) = shape(reshape([(0,j=1,90)], [10_8,9_8]))
integer(8), parameter :: rss(:) = shape(reshape([(0,j=1,90)], -[2,3]*(-[5_8,3_8])))
contains
subroutine subr(x,n1,n2)
real, intent(in) :: x(:,:)