[flang] Add more checks on WHERE and FORALL

Check that masks and LHS of assignments in WHERE statements and
constructs have consistent shapes. They must all have the same rank and
any extents that are compile-time constants must match.

Emit a warning for assignments in FORALL statements and constructs where
the LHS does not reference each of the index variables.

Original-commit: flang-compiler/f18@8b04dbebcf
Reviewed-on: https://github.com/flang-compiler/f18/pull/1009
This commit is contained in:
Tim Keith 2020-02-20 14:54:46 -08:00
parent 703c56132b
commit e0ba2b8783
7 changed files with 220 additions and 189 deletions

View File

@ -159,10 +159,13 @@ public:
void CheckIndexVarRedefine(const parser::Name &);
void ActivateIndexVar(const parser::Name &, IndexVarKind);
void DeactivateIndexVar(const parser::Name &);
SymbolVector GetIndexVars(IndexVarKind);
private:
void CheckIndexVarRedefine(
const parser::CharBlock &, const Symbol &, parser::MessageFixedText &&);
bool CheckError(bool);
const common::IntrinsicTypeDefaultKinds &defaultKinds_;
const common::LanguageFeatureControl languageFeatures_;
parser::AllSources &allSources_;
@ -176,8 +179,6 @@ private:
Scope globalScope_;
parser::Messages messages_;
evaluate::FoldingContext foldingContext_;
bool CheckError(bool);
ConstructStack constructStack_;
struct IndexVarInfo {
parser::CharBlock location;

View File

@ -29,192 +29,60 @@ using namespace Fortran::parser::literals;
namespace Fortran::semantics {
using ControlExpr = evaluate::Expr<evaluate::SubscriptInteger>;
using MaskExpr = evaluate::Expr<evaluate::LogicalResult>;
// The context tracks some number of active FORALL statements/constructs
// and some number of active WHERE statements/constructs. WHERE can nest
// in FORALL but not vice versa. Pointer assignments are allowed in
// FORALL but not in WHERE. These constraints are manifest in the grammar
// and don't need to be rechecked here, since errors cannot appear in the
// parse tree.
struct Control {
Symbol *name;
ControlExpr lower, upper, step;
};
struct ForallContext {
explicit ForallContext(const ForallContext *that) : outer{that} {}
const ForallContext *outer{nullptr};
std::optional<parser::CharBlock> constructName;
std::vector<Control> control;
std::optional<MaskExpr> maskExpr;
std::set<parser::CharBlock> activeNames;
};
struct WhereContext {
WhereContext(MaskExpr &&x, const WhereContext *o, const ForallContext *f)
: outer{o}, forall{f}, thisMaskExpr{std::move(x)} {}
const WhereContext *outer{nullptr};
const ForallContext *forall{nullptr}; // innermost enclosing FORALL
std::optional<parser::CharBlock> constructName;
MaskExpr thisMaskExpr; // independent of outer WHERE, if any
MaskExpr cumulativeMaskExpr{thisMaskExpr};
};
class AssignmentContext {
public:
explicit AssignmentContext(SemanticsContext &c) : context_{c} {}
AssignmentContext(const AssignmentContext &c, WhereContext &w)
: context_{c.context_}, where_{&w} {}
AssignmentContext(const AssignmentContext &c, ForallContext &f)
: context_{c.context_}, forall_{&f} {}
explicit AssignmentContext(SemanticsContext &context) : context_{context} {}
AssignmentContext(AssignmentContext &&) = default;
AssignmentContext(const AssignmentContext &) = delete;
bool operator==(const AssignmentContext &x) const { return this == &x; }
template<typename A> void PushWhereContext(const A &);
void PopWhereContext();
void Analyze(const parser::AssignmentStmt &);
void Analyze(const parser::PointerAssignmentStmt &);
void Analyze(const parser::WhereStmt &);
void Analyze(const parser::WhereConstruct &);
void Analyze(const parser::ForallConstruct &);
template<typename A> void Analyze(const parser::UnlabeledStatement<A> &stmt) {
context_.set_location(stmt.source);
Analyze(stmt.statement);
}
template<typename A> void Analyze(const common::Indirection<A> &x) {
Analyze(x.value());
}
template<typename A> std::enable_if_t<UnionTrait<A>> Analyze(const A &x) {
std::visit([&](const auto &y) { Analyze(y); }, x.u);
}
template<typename A> void Analyze(const std::list<A> &list) {
for (const auto &elem : list) {
Analyze(elem);
}
}
template<typename A> void Analyze(const std::optional<A> &x) {
if (x) {
Analyze(*x);
}
}
void Analyze(const parser::ConcurrentControl &);
private:
void Analyze(const parser::WhereConstruct::MaskedElsewhere &);
void Analyze(const parser::MaskedElsewhereStmt &);
void Analyze(const parser::WhereConstruct::Elsewhere &);
void CheckForPureContext(const SomeExpr &lhs, const SomeExpr &rhs,
parser::CharBlock rhsSource, bool isPointerAssignment);
MaskExpr GetMask(const parser::LogicalExpr &, bool defaultValue = true);
void CheckShape(parser::CharBlock, const SomeExpr *);
template<typename... A>
parser::Message *Say(parser::CharBlock at, A &&... args) {
return &context_.Say(at, std::forward<A>(args)...);
}
evaluate::FoldingContext &foldingContext() {
return context_.foldingContext();
}
SemanticsContext &context_;
WhereContext *where_{nullptr};
ForallContext *forall_{nullptr};
int whereDepth_{0}; // number of WHEREs currently nested in
// shape of masks in LHS of assignments in current WHERE:
std::vector<std::optional<std::int64_t>> whereExtents_;
};
void AssignmentContext::Analyze(const parser::AssignmentStmt &stmt) {
// Assignment statement analysis is in expression.cpp where user-defined
// assignments can be recognized and replaced.
if (const evaluate::Assignment * assignment{GetAssignment(stmt)}) {
if (forall_) {
// TODO: Warn if some name in forall_->activeNames or its outer
// contexts does not appear on LHS
const SomeExpr &lhs{assignment->lhs};
const SomeExpr &rhs{assignment->rhs};
auto lhsLoc{std::get<parser::Variable>(stmt.t).GetSource()};
auto rhsLoc{std::get<parser::Expr>(stmt.t).source};
if (whereDepth_ > 0) {
CheckShape(lhsLoc, &lhs);
}
CheckForPureContext(assignment->lhs, assignment->rhs,
std::get<parser::Expr>(stmt.t).source, false /* not => */);
CheckForPureContext(lhs, rhs, rhsLoc, false);
}
// TODO: Fortran 2003 ALLOCATABLE assignment semantics (automatic
// (re)allocation of LHS array when unallocated or nonconformable)
}
void AssignmentContext::Analyze(const parser::PointerAssignmentStmt &stmt) {
CHECK(!where_);
const evaluate::Assignment *assignment{GetAssignment(stmt)};
if (!assignment) {
return;
CHECK(whereDepth_ == 0);
if (const evaluate::Assignment * assignment{GetAssignment(stmt)}) {
const SomeExpr &lhs{assignment->lhs};
const SomeExpr &rhs{assignment->rhs};
CheckForPureContext(lhs, rhs, std::get<parser::Expr>(stmt.t).source, true);
auto restorer{
foldingContext().messages().SetLocation(context_.location().value())};
CheckPointerAssignment(foldingContext(), *assignment);
}
const SomeExpr &lhs{assignment->lhs};
const SomeExpr &rhs{assignment->rhs};
if (forall_) {
// TODO: Warn if some name in forall_->activeNames or its outer
// contexts does not appear on LHS
}
CheckForPureContext(lhs, rhs, std::get<parser::Expr>(stmt.t).source,
true /* isPointerAssignment */);
auto restorer{context_.foldingContext().messages().SetLocation(
context_.location().value())};
CheckPointerAssignment(context_.foldingContext(), *assignment);
}
void AssignmentContext::Analyze(const parser::WhereStmt &stmt) {
WhereContext where{
GetMask(std::get<parser::LogicalExpr>(stmt.t)), where_, forall_};
AssignmentContext nested{*this, where};
nested.Analyze(std::get<parser::AssignmentStmt>(stmt.t));
}
// N.B. Construct name matching is checked during label resolution.
void AssignmentContext::Analyze(const parser::WhereConstruct &construct) {
const auto &whereStmt{
std::get<parser::Statement<parser::WhereConstructStmt>>(construct.t)};
WhereContext where{
GetMask(std::get<parser::LogicalExpr>(whereStmt.statement.t)), where_,
forall_};
if (const auto &name{
std::get<std::optional<parser::Name>>(whereStmt.statement.t)}) {
where.constructName = name->source;
}
AssignmentContext nested{*this, where};
nested.Analyze(std::get<std::list<parser::WhereBodyConstruct>>(construct.t));
nested.Analyze(std::get<std::list<parser::WhereConstruct::MaskedElsewhere>>(
construct.t));
nested.Analyze(
std::get<std::optional<parser::WhereConstruct::Elsewhere>>(construct.t));
}
void AssignmentContext::Analyze(
const parser::WhereConstruct::MaskedElsewhere &elsewhere) {
CHECK(where_);
Analyze(
std::get<parser::Statement<parser::MaskedElsewhereStmt>>(elsewhere.t));
Analyze(std::get<std::list<parser::WhereBodyConstruct>>(elsewhere.t));
}
void AssignmentContext::Analyze(const parser::MaskedElsewhereStmt &elsewhere) {
MaskExpr mask{GetMask(std::get<parser::LogicalExpr>(elsewhere.t))};
MaskExpr copyCumulative{where_->cumulativeMaskExpr};
MaskExpr notOldMask{evaluate::LogicalNegation(std::move(copyCumulative))};
if (!evaluate::AreConformable(notOldMask, mask)) {
context_.Say("mask of ELSEWHERE statement is not conformable with "
"the prior mask(s) in its WHERE construct"_err_en_US);
}
MaskExpr copyMask{mask};
where_->cumulativeMaskExpr =
evaluate::BinaryLogicalOperation(evaluate::LogicalOperator::Or,
std::move(where_->cumulativeMaskExpr), std::move(copyMask));
where_->thisMaskExpr = evaluate::BinaryLogicalOperation(
evaluate::LogicalOperator::And, std::move(notOldMask), std::move(mask));
if (where_->outer &&
!evaluate::AreConformable(
where_->outer->thisMaskExpr, where_->thisMaskExpr)) {
context_.Say("effective mask of ELSEWHERE statement is not conformable "
"with the mask of the surrounding WHERE construct"_err_en_US);
}
}
void AssignmentContext::Analyze(
const parser::WhereConstruct::Elsewhere &elsewhere) {
MaskExpr copyCumulative{DEREF(where_).cumulativeMaskExpr};
where_->thisMaskExpr = evaluate::LogicalNegation(std::move(copyCumulative));
Analyze(std::get<std::list<parser::WhereBodyConstruct>>(elsewhere.t));
}
// C1594 checks
@ -333,14 +201,45 @@ void AssignmentContext::CheckForPureContext(const SomeExpr &lhs,
}
}
MaskExpr AssignmentContext::GetMask(
const parser::LogicalExpr &logicalExpr, bool defaultValue) {
MaskExpr mask{defaultValue};
if (const SomeExpr * expr{GetExpr(logicalExpr)}) {
auto *logical{std::get_if<evaluate::Expr<evaluate::SomeLogical>>(&expr->u)};
mask = evaluate::ConvertTo(mask, common::Clone(DEREF(logical)));
// 10.2.3.1(2) The masks and LHS of assignments must all have the same shape
void AssignmentContext::CheckShape(parser::CharBlock at, const SomeExpr *expr) {
if (auto shape{evaluate::GetShape(foldingContext(), expr)}) {
std::size_t size{shape->size()};
if (whereDepth_ == 0) {
whereExtents_.resize(size);
} else if (whereExtents_.size() != size) {
Say(at,
"Must have rank %zd to match prior mask or assignment of"
" WHERE construct"_err_en_US,
whereExtents_.size());
return;
}
for (std::size_t i{0}; i < size; ++i) {
if (std::optional<std::int64_t> extent{evaluate::ToInt64((*shape)[i])}) {
if (!whereExtents_[i]) {
whereExtents_[i] = *extent;
} else if (*whereExtents_[i] != *extent) {
Say(at,
"Dimension %d must have extent %jd to match prior mask or"
" assignment of WHERE construct"_err_en_US,
i + 1, static_cast<std::intmax_t>(*whereExtents_[i]));
}
}
}
}
}
template<typename A> void AssignmentContext::PushWhereContext(const A &x) {
const auto &expr{std::get<parser::LogicalExpr>(x.t)};
CheckShape(expr.thing.value().source, GetExpr(expr));
++whereDepth_;
}
void AssignmentContext::PopWhereContext() {
--whereDepth_;
if (whereDepth_ == 0) {
whereExtents_.clear();
}
return mask;
}
AssignmentChecker::~AssignmentChecker() {}
@ -354,10 +253,22 @@ void AssignmentChecker::Enter(const parser::PointerAssignmentStmt &x) {
context_.value().Analyze(x);
}
void AssignmentChecker::Enter(const parser::WhereStmt &x) {
context_.value().Analyze(x);
context_.value().PushWhereContext(x);
}
void AssignmentChecker::Enter(const parser::WhereConstruct &x) {
context_.value().Analyze(x);
void AssignmentChecker::Leave(const parser::WhereStmt &) {
context_.value().PopWhereContext();
}
void AssignmentChecker::Enter(const parser::WhereConstructStmt &x) {
context_.value().PushWhereContext(x);
}
void AssignmentChecker::Leave(const parser::EndWhereStmt &) {
context_.value().PopWhereContext();
}
void AssignmentChecker::Enter(const parser::MaskedElsewhereStmt &x) {
context_.value().PushWhereContext(x);
}
void AssignmentChecker::Leave(const parser::MaskedElsewhereStmt &) {
context_.value().PopWhereContext();
}
}

View File

@ -16,9 +16,11 @@
namespace Fortran::parser {
class ContextualMessages;
struct AssignmentStmt;
struct EndWhereStmt;
struct MaskedElsewhereStmt;
struct PointerAssignmentStmt;
struct WhereConstructStmt;
struct WhereStmt;
struct WhereConstruct;
}
namespace Fortran::semantics {
@ -41,7 +43,11 @@ public:
void Enter(const parser::AssignmentStmt &);
void Enter(const parser::PointerAssignmentStmt &);
void Enter(const parser::WhereStmt &);
void Enter(const parser::WhereConstruct &);
void Leave(const parser::WhereStmt &);
void Enter(const parser::WhereConstructStmt &);
void Leave(const parser::EndWhereStmt &);
void Enter(const parser::MaskedElsewhereStmt &);
void Leave(const parser::MaskedElsewhereStmt &);
private:
common::Indirection<AssignmentContext> context_;

View File

@ -452,6 +452,7 @@ public:
common::visitors{[&](const auto &x) { return GetAssignment(x); }},
stmt.u)};
if (assignment) {
CheckForallIndexesUsed(*assignment);
CheckForImpureCall(assignment->lhs);
CheckForImpureCall(assignment->rhs);
if (const auto *proc{
@ -753,6 +754,38 @@ private:
}
}
// Each index should be used on the LHS of each assignment in a FORALL
void CheckForallIndexesUsed(const evaluate::Assignment &assignment) {
SymbolVector indexVars{context_.GetIndexVars(IndexVarKind::FORALL)};
if (!indexVars.empty()) {
SymbolSet symbols{evaluate::CollectSymbols(assignment.lhs)};
std::visit(
common::visitors{
[&](const evaluate::Assignment::BoundsSpec &spec) {
for (const auto &bound : spec) {
symbols.merge(evaluate::CollectSymbols(bound));
}
},
[&](const evaluate::Assignment::BoundsRemapping &remapping) {
for (const auto &bounds : remapping) {
symbols.merge(evaluate::CollectSymbols(bounds.first));
symbols.merge(evaluate::CollectSymbols(bounds.second));
}
},
[](const auto &) {},
},
assignment.u);
for (const Symbol &index : indexVars) {
if (symbols.count(index) == 0) {
context_.Say(
"Warning: FORALL index variable '%s' not used on left-hand side"
" of assignment"_en_US,
index.name());
}
}
}
}
// For messages where the DO loop must be DO CONCURRENT, make that explicit.
const char *LoopKindName() const {
return kind_ == IndexVarKind::DO ? "DO CONCURRENT" : "FORALL";

View File

@ -123,7 +123,8 @@ static bool PerformStatementSemantics(
RewriteParseTree(context, program);
CheckDeclarations(context);
StatementSemanticsPass1{context}.Walk(program);
return StatementSemanticsPass2{context}.Walk(program);
StatementSemanticsPass2{context}.Walk(program);
return !context.AnyFatalError();
}
SemanticsContext::SemanticsContext(
@ -262,6 +263,16 @@ void SemanticsContext::DeactivateIndexVar(const parser::Name &name) {
}
}
SymbolVector SemanticsContext::GetIndexVars(IndexVarKind kind) {
SymbolVector result;
for (const auto &[symbol, info] : activeIndexVars_) {
if (info.kind == kind) {
result.push_back(symbol);
}
}
return result;
}
bool Semantics::Perform() {
return ValidateLabels(context_, program_) &&
parser::CanonicalizeDo(program_) && // force line break

View File

@ -1,14 +1,53 @@
integer :: a1(10), a2(10)
logical :: m1(10), m2(5,5)
m1 = .true.
m2 = .false.
a1 = [((i),i=1,10)]
where (m1)
a2 = 1
!ERROR: mask of ELSEWHERE statement is not conformable with the prior mask(s) in its WHERE construct
elsewhere (m2)
a2 = 2
elsewhere
a2 = 3
end where
! 10.2.3.1(2) All masks and LHS of assignments in a WHERE must conform
subroutine s1
integer :: a1(10), a2(10)
logical :: m1(10), m2(5,5)
m1 = .true.
m2 = .false.
a1 = [((i),i=1,10)]
where (m1)
a2 = 1
!ERROR: Must have rank 1 to match prior mask or assignment of WHERE construct
elsewhere (m2)
a2 = 2
elsewhere
a2 = 3
end where
end
subroutine s2
logical, allocatable :: m1(:), m4(:,:)
logical :: m2(2), m3(3)
where(m1)
where(m2)
end where
!ERROR: Dimension 1 must have extent 2 to match prior mask or assignment of WHERE construct
where(m3)
end where
!ERROR: Must have rank 1 to match prior mask or assignment of WHERE construct
where(m4)
end where
endwhere
where(m1)
where(m3)
end where
!ERROR: Dimension 1 must have extent 3 to match prior mask or assignment of WHERE construct
elsewhere(m2)
end where
end
subroutine s3
logical, allocatable :: m1(:,:)
logical :: m2(4,2)
real :: x(4,4), y(4,4)
real :: a(4,5), b(4,5)
where(m1)
x = y
!ERROR: Dimension 2 must have extent 4 to match prior mask or assignment of WHERE construct
a = b
!ERROR: Dimension 2 must have extent 4 to match prior mask or assignment of WHERE construct
where(m2)
end where
end where
end

View File

@ -16,7 +16,6 @@ subroutine forall1
end forall
end
subroutine forall2
integer, pointer :: a(:)
integer, target :: b(10,10)
@ -73,3 +72,34 @@ subroutine forall4
!ERROR: FORALL step expression may not be zero
forall(i=1:10:zero) a(i) = i
end
! Note: this gets warnings but not errors
subroutine forall5
real, target :: x(10), y(10)
forall(i=1:10)
x(i) = y(i)
end forall
forall(i=1:10)
x = y ! warning: i not used on LHS
forall(j=1:10)
x(i) = y(i) ! warning: j not used on LHS
x(j) = y(j) ! warning: i not used on LHS
endforall
endforall
do concurrent(i=1:10)
x = y
forall(i=1:10) x = y
end do
end
subroutine forall6
type t
real, pointer :: p
end type
type(t) :: a(10)
real, target :: b(10)
forall(i=1:10)
a(i)%p => b(i)
a(1)%p => b(i) ! warning: i not used on LHS
end forall
end