forked from OSchip/llvm-project
255 lines
8.8 KiB
C++
255 lines
8.8 KiB
C++
//===-- lib/Semantics/check-case.cpp --------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "check-case.h"
|
|
#include "flang/Common/idioms.h"
|
|
#include "flang/Common/reference.h"
|
|
#include "flang/Common/template.h"
|
|
#include "flang/Evaluate/fold.h"
|
|
#include "flang/Evaluate/type.h"
|
|
#include "flang/Parser/parse-tree.h"
|
|
#include "flang/Semantics/semantics.h"
|
|
#include "flang/Semantics/tools.h"
|
|
#include <tuple>
|
|
|
|
namespace Fortran::semantics {
|
|
|
|
template <typename T> class CaseValues {
|
|
public:
|
|
CaseValues(SemanticsContext &c, const evaluate::DynamicType &t)
|
|
: context_{c}, caseExprType_{t} {}
|
|
|
|
void Check(const std::list<parser::CaseConstruct::Case> &cases) {
|
|
for (const parser::CaseConstruct::Case &c : cases) {
|
|
AddCase(c);
|
|
}
|
|
if (!hasErrors_) {
|
|
cases_.sort(Comparator{});
|
|
if (!AreCasesDisjoint()) { // C1149
|
|
ReportConflictingCases();
|
|
}
|
|
}
|
|
}
|
|
|
|
private:
|
|
using Value = evaluate::Scalar<T>;
|
|
|
|
void AddCase(const parser::CaseConstruct::Case &c) {
|
|
const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)};
|
|
const parser::CaseStmt &caseStmt{stmt.statement};
|
|
const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)};
|
|
std::visit(
|
|
common::visitors{
|
|
[&](const std::list<parser::CaseValueRange> &ranges) {
|
|
for (const auto &range : ranges) {
|
|
auto pair{ComputeBounds(range)};
|
|
if (pair.first && pair.second && *pair.first > *pair.second) {
|
|
context_.Say(stmt.source,
|
|
"CASE has lower bound greater than upper bound"_en_US);
|
|
} else {
|
|
if constexpr (T::category == TypeCategory::Logical) { // C1148
|
|
if ((pair.first || pair.second) &&
|
|
(!pair.first || !pair.second ||
|
|
*pair.first != *pair.second)) {
|
|
context_.Say(stmt.source,
|
|
"CASE range is not allowed for LOGICAL"_err_en_US);
|
|
}
|
|
}
|
|
cases_.emplace_back(stmt);
|
|
cases_.back().lower = std::move(pair.first);
|
|
cases_.back().upper = std::move(pair.second);
|
|
}
|
|
}
|
|
},
|
|
[&](const parser::Default &) { cases_.emplace_front(stmt); },
|
|
},
|
|
selector.u);
|
|
}
|
|
|
|
std::optional<Value> GetValue(const parser::CaseValue &caseValue) {
|
|
const parser::Expr &expr{caseValue.thing.thing.value()};
|
|
auto *x{expr.typedExpr.get()};
|
|
if (x && x->v) { // C1147
|
|
auto type{x->v->GetType()};
|
|
if (type && type->category() == caseExprType_.category() &&
|
|
(type->category() != TypeCategory::Character ||
|
|
type->kind() == caseExprType_.kind())) {
|
|
x->v = evaluate::Fold(context_.foldingContext(),
|
|
evaluate::ConvertToType(T::GetType(), std::move(*x->v)));
|
|
if (x->v) {
|
|
if (auto value{evaluate::GetScalarConstantValue<T>(*x->v)}) {
|
|
return *value;
|
|
}
|
|
}
|
|
context_.Say(
|
|
expr.source, "CASE value must be a constant scalar"_err_en_US);
|
|
} else {
|
|
std::string typeStr{type ? type->AsFortran() : "typeless"s};
|
|
context_.Say(expr.source,
|
|
"CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
|
|
typeStr, caseExprType_.AsFortran());
|
|
}
|
|
hasErrors_ = true;
|
|
}
|
|
return std::nullopt;
|
|
}
|
|
|
|
using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>;
|
|
PairOfValues ComputeBounds(const parser::CaseValueRange &range) {
|
|
return std::visit(common::visitors{
|
|
[&](const parser::CaseValue &x) {
|
|
auto value{GetValue(x)};
|
|
return PairOfValues{value, value};
|
|
},
|
|
[&](const parser::CaseValueRange::Range &x) {
|
|
std::optional<Value> lo, hi;
|
|
if (x.lower) {
|
|
lo = GetValue(*x.lower);
|
|
}
|
|
if (x.upper) {
|
|
hi = GetValue(*x.upper);
|
|
}
|
|
if ((x.lower && !lo) || (x.upper && !hi)) {
|
|
return PairOfValues{}; // error case
|
|
}
|
|
return PairOfValues{std::move(lo), std::move(hi)};
|
|
},
|
|
},
|
|
range.u);
|
|
}
|
|
|
|
struct Case {
|
|
explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {}
|
|
bool IsDefault() const { return !lower && !upper; }
|
|
std::string AsFortran() const {
|
|
std::string result;
|
|
{
|
|
llvm::raw_string_ostream bs{result};
|
|
if (lower) {
|
|
evaluate::Constant<T>{*lower}.AsFortran(bs << '(');
|
|
if (!upper) {
|
|
bs << ':';
|
|
} else if (*lower != *upper) {
|
|
evaluate::Constant<T>{*upper}.AsFortran(bs << ':');
|
|
}
|
|
bs << ')';
|
|
} else if (upper) {
|
|
evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')';
|
|
} else {
|
|
bs << "DEFAULT";
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
const parser::Statement<parser::CaseStmt> &stmt;
|
|
std::optional<Value> lower, upper;
|
|
};
|
|
|
|
// Defines a comparator for use with std::list<>::sort().
|
|
// Returns true if and only if the highest value in range x is less
|
|
// than the least value in range y. The DEFAULT case is arbitrarily
|
|
// defined to be less than all others. When two ranges overlap,
|
|
// neither is less than the other.
|
|
struct Comparator {
|
|
bool operator()(const Case &x, const Case &y) const {
|
|
if (x.IsDefault()) {
|
|
return !y.IsDefault();
|
|
} else {
|
|
return x.upper && y.lower && *x.upper < *y.lower;
|
|
}
|
|
}
|
|
};
|
|
|
|
bool AreCasesDisjoint() const {
|
|
auto endIter{cases_.end()};
|
|
for (auto iter{cases_.begin()}; iter != endIter; ++iter) {
|
|
auto next{iter};
|
|
if (++next != endIter && !Comparator{}(*iter, *next)) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// This has quadratic time, but only runs in error cases
|
|
void ReportConflictingCases() {
|
|
for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) {
|
|
parser::Message *msg{nullptr};
|
|
for (auto p{cases_.begin()}; p != cases_.end(); ++p) {
|
|
if (p->stmt.source.begin() < iter->stmt.source.begin() &&
|
|
!Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) {
|
|
if (!msg) {
|
|
msg = &context_.Say(iter->stmt.source,
|
|
"CASE %s conflicts with previous cases"_err_en_US,
|
|
iter->AsFortran());
|
|
}
|
|
msg->Attach(
|
|
p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
SemanticsContext &context_;
|
|
const evaluate::DynamicType &caseExprType_;
|
|
std::list<Case> cases_;
|
|
bool hasErrors_{false};
|
|
};
|
|
|
|
template <TypeCategory CAT> struct TypeVisitor {
|
|
using Result = bool;
|
|
using Types = evaluate::CategoryTypes<CAT>;
|
|
template <typename T> Result Test() {
|
|
if (T::kind == exprType.kind()) {
|
|
CaseValues<T>(context, exprType).Check(caseList);
|
|
return true;
|
|
} else {
|
|
return false;
|
|
}
|
|
}
|
|
SemanticsContext &context;
|
|
const evaluate::DynamicType &exprType;
|
|
const std::list<parser::CaseConstruct::Case> &caseList;
|
|
};
|
|
|
|
// Checks one SELECT CASE construct.  Dispatches on the category of the
// selector expression's type to the matching CaseValues instantiation, and
// reports an error when the selector has no type or its category is not
// INTEGER, LOGICAL, or CHARACTER.
void CaseChecker::Enter(const parser::CaseConstruct &construct) {
  const auto &selectStmt{
      std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)
          .statement};
  const auto &selector{
      std::get<parser::Scalar<parser::Expr>>(selectStmt.t).thing};

  const auto *analyzed{GetExpr(selector)};
  if (!analyzed) {
    return; // expression semantics failed
  }

  const auto selectorType{analyzed->GetType()};
  if (selectorType) {
    const auto &caseList{
        std::get<std::list<parser::CaseConstruct::Case>>(construct.t)};
    const auto category{selectorType->category()};
    if (category == TypeCategory::Integer) {
      common::SearchTypes(TypeVisitor<TypeCategory::Integer>{
          context_, *selectorType, caseList});
      return;
    }
    if (category == TypeCategory::Logical) {
      // LOGICAL values are always checked as kind 1, whatever the
      // selector's kind.
      using LogicalKind1 = evaluate::Type<TypeCategory::Logical, 1>;
      CaseValues<LogicalKind1>{context_, *selectorType}.Check(caseList);
      return;
    }
    if (category == TypeCategory::Character) {
      common::SearchTypes(TypeVisitor<TypeCategory::Character>{
          context_, *selectorType, caseList});
      return;
    }
  }

  // Reached when the selector is typeless or of any other category.
  context_.Say(selector.source,
      "SELECT CASE expression must be integer, logical, or character"_err_en_US);
}
|
|
} // namespace Fortran::semantics
|