[ASTMatchers] Fix traversal below range-for elements

Differential Revision: https://reviews.llvm.org/D95562
This commit is contained in:
Stephen Kelly 2021-01-27 22:03:23 +00:00
parent b01b964d37
commit 79125085f1
2 changed files with 84 additions and 10 deletions

View File

@ -243,10 +243,14 @@ public:
return true;
ScopedIncrement ScopedDepth(&CurrentDepth);
if (auto *Init = Node->getInit())
if (!match(*Init))
if (!traverse(*Init))
return false;
if (!match(*Node->getLoopVariable()) || !match(*Node->getRangeInit()) ||
!match(*Node->getBody()))
if (!match(*Node->getLoopVariable()))
return false;
if (match(*Node->getRangeInit()))
if (!VisitorBase::TraverseStmt(Node->getRangeInit()))
return false;
if (!match(*Node->getBody()))
return false;
return VisitorBase::TraverseStmt(Node->getBody());
}
@ -488,15 +492,21 @@ public:
bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) {
if (auto *RF = dyn_cast<CXXForRangeStmt>(S)) {
for (auto *SubStmt : RF->children()) {
if (SubStmt == RF->getInit() || SubStmt == RF->getLoopVarStmt() ||
SubStmt == RF->getRangeInit() || SubStmt == RF->getBody()) {
TraverseStmt(SubStmt, Queue);
} else {
ASTNodeNotSpelledInSourceScope RAII(this, true);
TraverseStmt(SubStmt, Queue);
{
ASTNodeNotAsIsSourceScope RAII(this, true);
TraverseStmt(RF->getInit());
// Don't traverse under the loop variable
match(*RF->getLoopVariable());
TraverseStmt(RF->getRangeInit());
}
{
ASTNodeNotSpelledInSourceScope RAII(this, true);
for (auto *SubStmt : RF->children()) {
if (SubStmt != RF->getBody())
TraverseStmt(SubStmt);
}
}
TraverseStmt(RF->getBody());
return true;
} else if (auto *RBO = dyn_cast<CXXRewrittenBinaryOperator>(S)) {
{

View File

@ -2820,6 +2820,36 @@ struct CtorInitsNonTrivial : NonTrivial
EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
}
Code = R"cpp(
struct Range {
int* begin() const;
int* end() const;
};
Range getRange(int);
void rangeFor()
{
for (auto i : getRange(42))
{
}
}
)cpp";
{
auto M = integerLiteral(equals(42));
EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
}
{
auto M = callExpr(hasDescendant(integerLiteral(equals(42))));
EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
}
{
auto M = compoundStmt(hasDescendant(integerLiteral(equals(42))));
EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
}
Code = R"cpp(
void rangeFor()
{
@ -2891,6 +2921,40 @@ struct CtorInitsNonTrivial : NonTrivial
matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
true, {"-std=c++20"}));
}
Code = R"cpp(
struct Range {
int* begin() const;
int* end() const;
};
Range getRange(int);
int getNum(int);
void rangeFor()
{
for (auto j = getNum(42); auto i : getRange(j))
{
}
}
)cpp";
{
auto M = integerLiteral(equals(42));
EXPECT_TRUE(
matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
EXPECT_TRUE(
matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
true, {"-std=c++20"}));
}
{
auto M = compoundStmt(hasDescendant(integerLiteral(equals(42))));
EXPECT_TRUE(
matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
EXPECT_TRUE(
matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
true, {"-std=c++20"}));
}
Code = R"cpp(
void hasDefaultArg(int i, int j = 0)
{