[flang] Enforce accessibility requirement on type-bound generic operators, &c.

Type-bound generics like operator(+) and assignment(=) need to not be
PRIVATE if they are used outside the module in which they are declared.

Differential Revision: https://reviews.llvm.org/D139123
This commit is contained in:
Peter Klausler 2022-11-30 17:40:48 -08:00
parent 00272ad8d3
commit d7a1351bb8
8 changed files with 157 additions and 92 deletions

View File

@ -117,6 +117,7 @@ public:
const Scope *GetDerivedTypeParent() const;
const Scope &GetDerivedTypeBase() const;
inline std::optional<SourceName> GetName() const;
// Returns true if this scope contains, or is, another scope.
bool Contains(const Scope &) const;
/// Make a scope nested in this one
Scope &MakeScope(Kind kind, Symbol *symbol = nullptr);

View File

@ -85,8 +85,13 @@ bool IsIntrinsicConcat(
bool IsGenericDefinedOp(const Symbol &);
bool IsDefinedOperator(SourceName);
std::string MakeOpName(SourceName);
// Returns true if maybeAncestor exists and is a proper ancestor of a
// descendent scope (or symbol owner). Will be false, unlike Scope::Contains(),
// if maybeAncestor *is* the descendent.
bool DoesScopeContain(const Scope *maybeAncestor, const Scope &maybeDescendent);
bool DoesScopeContain(const Scope *, const Symbol &);
bool IsUseAssociated(const Symbol &, const Scope &);
bool IsHostAssociated(const Symbol &, const Scope &);
bool IsHostAssociatedIntoSubprogram(const Symbol &, const Scope &);
@ -182,8 +187,9 @@ bool HasCoarray(const parser::Expr &);
bool IsAssumedType(const Symbol &);
bool IsPolymorphic(const Symbol &);
bool IsPolymorphicAllocatable(const Symbol &);
// Return an error if component symbol is not accessible from scope (7.5.4.8(2))
std::optional<parser::MessageFormattedText> CheckAccessibleComponent(
// Return an error if a symbol is not accessible from a scope
std::optional<parser::MessageFormattedText> CheckAccessibleSymbol(
const semantics::Scope &, const Symbol &);
// Analysis of image control statements

View File

@ -155,8 +155,10 @@ public:
// Find and return a user-defined operator or report an error.
// The provided message is used if there is no such operator.
MaybeExpr TryDefinedOp(const char *, parser::MessageFixedText,
const Symbol **definedOpSymbolPtr = nullptr, bool isUserOp = false);
// If a definedOpSymbolPtr is provided, the caller must check
// for its accessibility.
MaybeExpr TryDefinedOp(
const char *, parser::MessageFixedText, bool isUserOp = false);
template <typename E>
MaybeExpr TryDefinedOp(E opr, parser::MessageFixedText msg) {
return TryDefinedOp(
@ -175,7 +177,7 @@ private:
MaybeExpr AnalyzeExprOrWholeAssumedSizeArray(const parser::Expr &);
bool AreConformable() const;
const Symbol *FindBoundOp(parser::CharBlock, int passIndex,
const Symbol *&definedOp, bool isSubroutine);
const Symbol *&generic, bool isSubroutine);
void AddAssignmentConversion(
const DynamicType &lhsType, const DynamicType &rhsType);
bool OkLogicalIntegerAssignment(TypeCategory lhs, TypeCategory rhs);
@ -1778,10 +1780,9 @@ MaybeExpr ExpressionAnalyzer::Analyze(
}
}
if (symbol) {
if (const auto *currScope{context_.globalScope().FindScope(source)}) {
if (auto msg{CheckAccessibleComponent(*currScope, *symbol)}) {
Say(source, *msg);
}
const semantics::Scope &innermost{context_.FindScope(expr.source)};
if (auto msg{CheckAccessibleSymbol(innermost, *symbol)}) {
Say(expr.source, std::move(*msg));
}
if (checkConflicts) {
auto componentIter{
@ -1809,7 +1810,6 @@ MaybeExpr ExpressionAnalyzer::Analyze(
}
unavailable.insert(symbol->name());
if (value) {
const auto &innermost{context_.FindScope(expr.source)};
if (symbol->has<semantics::ProcEntityDetails>()) {
CHECK(IsPointer(*symbol));
} else if (symbol->has<semantics::ObjectEntityDetails>()) {
@ -2869,7 +2869,7 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::DefinedUnary &x) {
ArgumentAnalyzer analyzer{*this, name.source};
analyzer.Analyze(std::get<1>(x.t));
return analyzer.TryDefinedOp(name.source.ToString().c_str(),
"No operator %s defined for %s"_err_en_US, nullptr, true);
"No operator %s defined for %s"_err_en_US, true);
}
// Binary (dyadic) operations
@ -3053,7 +3053,7 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::DefinedBinary &x) {
analyzer.Analyze(std::get<1>(x.t));
analyzer.Analyze(std::get<2>(x.t));
return analyzer.TryDefinedOp(name.source.ToString().c_str(),
"No operator %s defined for %s and %s"_err_en_US, nullptr, true);
"No operator %s defined for %s and %s"_err_en_US, true);
}
// Returns true if a parsed function reference should be converted
@ -3635,63 +3635,100 @@ bool ArgumentAnalyzer::CheckForNullPointer(const char *where) {
return true;
}
MaybeExpr ArgumentAnalyzer::TryDefinedOp(const char *opr,
parser::MessageFixedText error, const Symbol **definedOpSymbolPtr,
bool isUserOp) {
MaybeExpr ArgumentAnalyzer::TryDefinedOp(
const char *opr, parser::MessageFixedText error, bool isUserOp) {
if (AnyUntypedOrMissingOperand()) {
context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1));
return std::nullopt;
}
const Symbol *localDefinedOpSymbolPtr{nullptr};
if (!definedOpSymbolPtr) {
definedOpSymbolPtr = &localDefinedOpSymbolPtr;
}
MaybeExpr result;
bool anyPossibilities{false};
std::optional<parser::MessageFormattedText> inaccessible;
std::vector<const Symbol *> hit;
std::string oprNameString{
isUserOp ? std::string{opr} : "operator("s + opr + ')'};
parser::CharBlock oprName{oprNameString};
{
auto restorer{context_.GetContextualMessages().DiscardMessages()};
std::string oprNameString{
isUserOp ? std::string{opr} : "operator("s + opr + ')'};
parser::CharBlock oprName{oprNameString};
const auto &scope{context_.context().FindScope(source_)};
if (Symbol *symbol{scope.FindSymbol(oprName)}) {
*definedOpSymbolPtr = symbol;
anyPossibilities = true;
parser::Name name{symbol->name(), symbol};
if (auto result{context_.AnalyzeDefinedOp(name, GetActuals())}) {
return result;
result = context_.AnalyzeDefinedOp(name, GetActuals());
if (result) {
inaccessible = CheckAccessibleSymbol(scope, *symbol);
if (inaccessible) {
result.reset();
} else {
hit.push_back(symbol);
}
}
}
for (std::size_t passIndex{0}; passIndex < actuals_.size(); ++passIndex) {
if (const Symbol *
symbol{FindBoundOp(oprName, passIndex, *definedOpSymbolPtr, false)}) {
if (MaybeExpr result{TryBoundOp(*symbol, passIndex)}) {
return result;
const Symbol *generic{nullptr};
if (const Symbol *binding{
FindBoundOp(oprName, passIndex, generic, false)}) {
anyPossibilities = true;
if (MaybeExpr thisResult{TryBoundOp(*binding, passIndex)}) {
if (auto thisInaccessible{
CheckAccessibleSymbol(scope, DEREF(generic))}) {
inaccessible = thisInaccessible;
} else {
result = std::move(thisResult);
hit.push_back(binding);
}
}
}
}
}
if (*definedOpSymbolPtr) {
SayNoMatch(ToUpperCase((*definedOpSymbolPtr)->name().ToString()));
} else if (actuals_.size() == 1 || AreConformable()) {
if (CheckForNullPointer()) {
context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1));
if (result) {
if (hit.size() > 1) {
if (auto *msg{context_.Say(
"%zd matching accessible generic interfaces for %s were found"_err_en_US,
hit.size(), ToUpperCase(opr))}) {
for (const Symbol *symbol : hit) {
AttachDeclaration(*msg, *symbol);
}
}
}
} else {
} else if (inaccessible) {
context_.Say(source_, std::move(*inaccessible));
} else if (anyPossibilities) {
SayNoMatch(ToUpperCase(oprNameString), false);
} else if (actuals_.size() == 2 && !AreConformable()) {
context_.Say(
"Operands of %s are not conformable; have rank %d and rank %d"_err_en_US,
ToUpperCase(opr), actuals_[0]->Rank(), actuals_[1]->Rank());
} else if (CheckForNullPointer()) {
context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1));
}
return std::nullopt;
return result;
}
MaybeExpr ArgumentAnalyzer::TryDefinedOp(
std::vector<const char *> oprs, parser::MessageFixedText error) {
const Symbol *definedOpSymbolPtr{nullptr};
for (std::size_t i{1}; i < oprs.size(); ++i) {
if (oprs.size() == 1) {
return TryDefinedOp(oprs[0], error);
}
MaybeExpr result;
std::vector<const char *> hit;
{
auto restorer{context_.GetContextualMessages().DiscardMessages()};
if (auto result{TryDefinedOp(oprs[i], error, &definedOpSymbolPtr)}) {
return result;
for (std::size_t i{0}; i < oprs.size(); ++i) {
if (MaybeExpr thisResult{TryDefinedOp(oprs[i], error)}) {
result = std::move(thisResult);
hit.push_back(oprs[i]);
}
}
}
return TryDefinedOp(oprs[0], error, &definedOpSymbolPtr);
if (hit.empty()) { // for the error
result = TryDefinedOp(oprs[0], error);
} else if (hit.size() > 1) {
context_.Say(
"Matching accessible definitions were found with %zd variant spellings of the generic operator ('%s', '%s')"_err_en_US,
hit.size(), ToUpperCase(hit[0]), ToUpperCase(hit[1]));
}
return result;
}
MaybeExpr ArgumentAnalyzer::TryBoundOp(const Symbol &symbol, int passIndex) {
@ -3768,31 +3805,34 @@ bool ArgumentAnalyzer::OkLogicalIntegerAssignment(
}
std::optional<ProcedureRef> ArgumentAnalyzer::GetDefinedAssignmentProc() {
auto restorer{context_.GetContextualMessages().DiscardMessages()};
const Symbol *proc{nullptr};
int passedObjectIndex{-1};
std::string oprNameString{"assignment(=)"};
parser::CharBlock oprName{oprNameString};
const Symbol *proc{nullptr};
const auto &scope{context_.context().FindScope(source_)};
if (const Symbol *symbol{scope.FindSymbol(oprName)}) {
ExpressionAnalyzer::AdjustActuals noAdjustment;
auto pair{context_.ResolveGeneric(*symbol, actuals_, noAdjustment, true)};
if (pair.first) {
proc = pair.first;
} else {
context_.EmitGenericResolutionError(*symbol, pair.second, true);
}
}
int passedObjectIndex{-1};
const Symbol *definedOpSymbol{nullptr};
for (std::size_t i{0}; i < actuals_.size(); ++i) {
if (const Symbol *
specific{FindBoundOp(oprName, i, definedOpSymbol, true)}) {
if (const Symbol *
resolution{GetBindingResolution(GetType(i), *specific)}) {
proc = resolution;
// If multiple resolutions were possible, they will have been already
// diagnosed.
{
auto restorer{context_.GetContextualMessages().DiscardMessages()};
if (const Symbol *symbol{scope.FindSymbol(oprName)}) {
ExpressionAnalyzer::AdjustActuals noAdjustment;
auto pair{context_.ResolveGeneric(*symbol, actuals_, noAdjustment, true)};
if (pair.first) {
proc = pair.first;
} else {
proc = specific;
passedObjectIndex = i;
context_.EmitGenericResolutionError(*symbol, pair.second, true);
}
}
for (std::size_t i{0}; i < actuals_.size(); ++i) {
const Symbol *generic{nullptr};
if (const Symbol *specific{FindBoundOp(oprName, i, generic, true)}) {
if (const Symbol *resolution{
GetBindingResolution(GetType(i), *specific)}) {
proc = resolution;
} else {
proc = specific;
passedObjectIndex = i;
}
}
}
}
@ -3871,24 +3911,23 @@ bool ArgumentAnalyzer::AreConformable() const {
// Look for a type-bound operator in the type of arg number passIndex.
const Symbol *ArgumentAnalyzer::FindBoundOp(parser::CharBlock oprName,
int passIndex, const Symbol *&definedOp, bool isSubroutine) {
int passIndex, const Symbol *&generic, bool isSubroutine) {
const auto *type{GetDerivedTypeSpec(GetType(passIndex))};
if (!type || !type->scope()) {
return nullptr;
}
const Symbol *symbol{type->scope()->FindComponent(oprName)};
if (!symbol) {
generic = type->scope()->FindComponent(oprName);
if (!generic) {
return nullptr;
}
definedOp = symbol;
ExpressionAnalyzer::AdjustActuals adjustment{
[&](const Symbol &proc, ActualArguments &) {
return passIndex == GetPassIndex(proc);
}};
auto pair{
context_.ResolveGeneric(*symbol, actuals_, adjustment, isSubroutine)};
context_.ResolveGeneric(*generic, actuals_, adjustment, isSubroutine)};
if (!pair.first) {
context_.EmitGenericResolutionError(*symbol, pair.second, isSubroutine);
context_.EmitGenericResolutionError(*generic, pair.second, isSubroutine);
}
return pair.first;
}

View File

@ -5989,7 +5989,7 @@ bool DeclarationVisitor::OkToAddComponent(
} else if (extends) {
msg = "Type cannot be extended as it has a component named"
" '%s'"_err_en_US;
} else if (CheckAccessibleComponent(currScope(), *prev)) {
} else if (CheckAccessibleSymbol(currScope(), *prev)) {
// inaccessible component -- redeclaration is ok
msg = "Component '%s' is inaccessibly declared in or as a "
"parent of this derived type"_warn_en_US;
@ -6864,8 +6864,7 @@ const parser::Name *DeclarationVisitor::FindComponent(
derived->Instantiate(currScope()); // in case of forward referenced type
if (const Scope * scope{derived->scope()}) {
if (Resolve(component, scope->FindComponent(component.source))) {
if (auto msg{
CheckAccessibleComponent(currScope(), *component.symbol)}) {
if (auto msg{CheckAccessibleSymbol(currScope(), *component.symbol)}) {
context().Say(component.source, *msg);
}
return &component;

View File

@ -991,9 +991,8 @@ bool IsPolymorphicAllocatable(const Symbol &symbol) {
return IsAllocatable(symbol) && IsPolymorphic(symbol);
}
std::optional<parser::MessageFormattedText> CheckAccessibleComponent(
std::optional<parser::MessageFormattedText> CheckAccessibleSymbol(
const Scope &scope, const Symbol &symbol) {
CHECK(symbol.owner().IsDerivedType()); // symbol must be a component
if (symbol.attrs().test(Attr::PRIVATE)) {
if (FindModuleFileContaining(scope)) {
// Don't enforce component accessibility checks in module files;
@ -1003,7 +1002,7 @@ std::optional<parser::MessageFormattedText> CheckAccessibleComponent(
moduleScope{FindModuleContaining(symbol.owner())}) {
if (!moduleScope->Contains(scope)) {
return parser::MessageFormattedText{
"PRIVATE component '%s' is only accessible within module '%s'"_err_en_US,
"PRIVATE name '%s' is only accessible within module '%s'"_err_en_US,
symbol.name(), moduleScope->GetName().value()};
}
}

View File

@ -91,9 +91,9 @@ subroutine s7
type(t2) :: x
integer :: j
j = x%i2
!ERROR: PRIVATE component 'i3' is only accessible within module 'm7'
!ERROR: PRIVATE name 'i3' is only accessible within module 'm7'
j = x%i3
!ERROR: PRIVATE component 't1' is only accessible within module 'm7'
!ERROR: PRIVATE name 't1' is only accessible within module 'm7'
j = x%t1%i1
end
@ -117,11 +117,11 @@ end
subroutine s8
use m8
type(t) :: x
!ERROR: PRIVATE component 'i2' is only accessible within module 'm8'
!ERROR: PRIVATE name 'i2' is only accessible within module 'm8'
x = t(2, 5)
!ERROR: PRIVATE component 'i2' is only accessible within module 'm8'
!ERROR: PRIVATE name 'i2' is only accessible within module 'm8'
x = t(i1=2, i2=5)
!ERROR: PRIVATE component 'i2' is only accessible within module 'm8'
!ERROR: PRIVATE name 'i2' is only accessible within module 'm8'
a = [y%i2]
end
@ -143,3 +143,24 @@ contains
x = t(i1=2, i2=5) !OK
end
end
module m10
type t
integer n
contains
procedure :: f
generic, private :: operator(+) => f
end type
contains
type(t) function f(x,y)
class(t), intent(in) :: x, y
f = t(x%n + y%n)
end function
end module
subroutine s10
use m10
type(t) x
x = t(1)
!ERROR: PRIVATE name 'operator(+)' is only accessible within module 'm10'
x = x + x
end subroutine

View File

@ -58,15 +58,15 @@ contains
l = z'fe' == r !OK
l = cVar == z'fe' !OK
l = z'fe' == cVar !OK
!ERROR: No intrinsic or user-defined OPERATOR(==) matches operand types CHARACTER(KIND=1) and INTEGER(4)
!ERROR: Operands of .EQ. must have comparable types; have CHARACTER(KIND=1) and INTEGER(4)
l = charVar == z'fe'
!ERROR: No intrinsic or user-defined OPERATOR(==) matches operand types INTEGER(4) and CHARACTER(KIND=1)
!ERROR: Operands of .EQ. must have comparable types; have INTEGER(4) and CHARACTER(KIND=1)
l = z'fe' == charVar
!ERROR: No intrinsic or user-defined OPERATOR(==) matches operand types LOGICAL(4) and INTEGER(4)
l = l == z'fe' !OK
!ERROR: No intrinsic or user-defined OPERATOR(==) matches operand types INTEGER(4) and LOGICAL(4)
l = z'fe' == l !OK
!ERROR: No intrinsic or user-defined OPERATOR(==) matches operand types TYPE(t) and REAL(4)
!ERROR: Operands of .EQ. must have comparable types; have LOGICAL(4) and INTEGER(4)
l = l == z'fe'
!ERROR: Operands of .EQ. must have comparable types; have INTEGER(4) and LOGICAL(4)
l = z'fe' == l
!ERROR: Operands of .EQ. must have comparable types; have TYPE(t) and REAL(4)
l = x == r
lVar = z'a' == b'1010' !OK
@ -265,9 +265,9 @@ contains
i = x + y
i = x + i
i = y + i
!ERROR: No intrinsic or user-defined OPERATOR(+) matches operand types CLASS(t2) and CLASS(t1)
!ERROR: Operands of + must be numeric; have CLASS(t2) and CLASS(t1)
i = y + x
!ERROR: No intrinsic or user-defined OPERATOR(+) matches operand types INTEGER(4) and CLASS(t1)
!ERROR: Operands of + must be numeric; have INTEGER(4) and CLASS(t1)
i = i + x
end
end
@ -307,9 +307,9 @@ module m7
j = null() - null(mold=x1)
j = null(mold=x1) - null()
j = null() - null()
!ERROR: No intrinsic or user-defined OPERATOR(/) matches operand types untyped and TYPE(t1)
!ERROR: A NULL() pointer is not allowed as an operand here
j = null() / null(mold=x1)
!ERROR: No intrinsic or user-defined OPERATOR(/) matches operand types TYPE(t1) and untyped
!ERROR: A NULL() pointer is not allowed as an operand here
j = null(mold=x1) / null()
!ERROR: A NULL() pointer is not allowed as an operand here
j = null() / null()

View File

@ -37,9 +37,9 @@ contains
subroutine s1(x, y, z)
logical :: x
complex :: y, z
!ERROR: No intrinsic or user-defined OPERATOR(.A.) matches operand types COMPLEX(4) and COMPLEX(4)
!ERROR: Operands of .AND. must be LOGICAL; have COMPLEX(4) and COMPLEX(4)
x = y .and. z
!ERROR: No intrinsic or user-defined OPERATOR(.A.) matches operand types COMPLEX(4) and COMPLEX(4)
!ERROR: Operands of .AND. must be LOGICAL; have COMPLEX(4) and COMPLEX(4)
x = y .a. z
end
end