diff --git a/clang-tools-extra/clangd/XRefs.cpp b/clang-tools-extra/clangd/XRefs.cpp index b6a0be4a0a0a..c4a73ebad3e8 100644 --- a/clang-tools-extra/clangd/XRefs.cpp +++ b/clang-tools-extra/clangd/XRefs.cpp @@ -1281,6 +1281,21 @@ std::vector findImplementations(ParsedAST &AST, Position Pos, return findImplementors(std::move(IDs), QueryKind, Index, *MainFilePath); } +namespace { +// Recursively finds all the overridden methods of `CMD` in complete type +// hierarchy. +void getOverriddenMethods(const CXXMethodDecl *CMD, + llvm::DenseSet &OverriddenMethods) { + if (!CMD) + return; + for (const CXXMethodDecl *Base : CMD->overridden_methods()) { + if (auto ID = getSymbolID(Base)) + OverriddenMethods.insert(ID); + getOverriddenMethods(Base, OverriddenMethods); + } +} +} // namespace + ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit, const SymbolIndex *Index) { if (!Limit) @@ -1300,7 +1315,7 @@ ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit, return {}; } - llvm::DenseSet IDs; + llvm::DenseSet IDs, OverriddenMethods; const auto *IdentifierAtCursor = syntax::spelledIdentifierTouching(*CurLoc, AST.getTokens()); @@ -1343,13 +1358,16 @@ ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit, if (Index) { OverriddenBy.Predicate = RelationKind::OverriddenBy; for (const NamedDecl *ND : Decls) { - // Special case: Inlcude declaration of overridding methods. + // Special case: For virtual methods, report decl/def of overrides and + // references to all overridden methods in complete type hierarchy. if (const auto *CMD = llvm::dyn_cast(ND)) { if (CMD->isVirtual()) if (IdentifierAtCursor && SM.getSpellingLoc(CMD->getLocation()) == - IdentifierAtCursor->location()) + IdentifierAtCursor->location()) { if (auto ID = getSymbolID(CMD)) OverriddenBy.Subjects.insert(ID); + getOverriddenMethods(CMD, OverriddenMethods); + } } } } @@ -1415,7 +1433,8 @@ ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit, } } // Now query the index for references from other files. - auto QueryIndex = [&](llvm::DenseSet IDs, bool AllowAttributes) { + auto QueryIndex = [&](llvm::DenseSet IDs, bool AllowAttributes, + bool AllowMainFileSymbols) { RefsRequest Req; Req.IDs = std::move(IDs); Req.Limit = Limit; @@ -1427,7 +1446,8 @@ ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit, return; auto LSPLoc = toLSPLocation(R.Location, *MainFilePath); // Avoid indexed results for the main file - the AST is authoritative. - if (!LSPLoc || LSPLoc->uri.file() == *MainFilePath) + if (!LSPLoc || + (!AllowMainFileSymbols && LSPLoc->uri.file() == *MainFilePath)) return; ReferencesResult::Reference Result; Result.Loc = std::move(*LSPLoc); @@ -1442,12 +1462,17 @@ ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit, Results.References.push_back(std::move(Result)); }); }; - QueryIndex(std::move(IDs), /*AllowAttributes=*/true); + QueryIndex(std::move(IDs), /*AllowAttributes=*/true, + /*AllowMainFileSymbols=*/false); + // For a virtual method: Occurrences of BaseMethod should be treated as refs + // and not as decl/def. Allow symbols from main file since AST does not report + // these. + QueryIndex(std::move(OverriddenMethods), /*AllowAttributes=*/false, + /*AllowMainFileSymbols=*/true); if (Results.References.size() > Limit) { Results.HasMore = true; Results.References.resize(Limit); } - // FIXME: Report refs of base methods. return Results; } diff --git a/clang-tools-extra/clangd/unittests/XRefsTests.cpp b/clang-tools-extra/clangd/unittests/XRefsTests.cpp index 966656e47d8f..06bdb9fc4c8a 100644 --- a/clang-tools-extra/clangd/unittests/XRefsTests.cpp +++ b/clang-tools-extra/clangd/unittests/XRefsTests.cpp @@ -1889,6 +1889,30 @@ TEST(FindReferences, IncludeOverrides) { checkFindRefs(Test, /*UseIndex=*/true); } +TEST(FindReferences, RefsToBaseMethod) { + llvm::StringRef Test = + R"cpp( + class BaseBase { + public: + virtual void [[func]](); + }; + class Base : public BaseBase { + public: + void [[func]]() override; + }; + class Derived : public Base { + public: + void $decl[[fu^nc]]() override; + }; + void test(BaseBase* BB, Base* B, Derived* D) { + // refs to overridden methods in complete type hierarchy are reported. + BB->[[func]](); + B->[[func]](); + D->[[func]](); + })cpp"; + checkFindRefs(Test, /*UseIndex=*/true); +} + TEST(FindReferences, MainFileReferencesOnly) { llvm::StringRef Test = R"cpp(