diff --git a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp index b65d7f0c1a39..e0368975ea3e 100644 --- a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp +++ b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp @@ -12,6 +12,7 @@ #include "clang/Lex/Lexer.h" #include "llvm/ADT/Optional.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Path.h" namespace clang { namespace tooling { @@ -174,12 +175,22 @@ inline StringRef trimInclude(StringRef IncludeName) { const char IncludeRegexPattern[] = R"(^[\t\ ]*#[\t\ ]*(import|include)[^"<]*(["<][^">]*[">]))"; +// The filename of Path excluding extension. +// Used to match implementation with headers, this differs from sys::path::stem: +// - in names with multiple dots (foo.cu.cc) it terminates at the *first* +// - an empty stem is never returned: /foo/.bar.x => .bar +// - we don't bother to handle . and .. specially +StringRef matchingStem(llvm::StringRef Path) { + StringRef Name = llvm::sys::path::filename(Path); + return Name.substr(0, Name.find('.', 1)); +} + } // anonymous namespace IncludeCategoryManager::IncludeCategoryManager(const IncludeStyle &Style, StringRef FileName) : Style(Style), FileName(FileName) { - FileStem = llvm::sys::path::stem(FileName); + FileStem = matchingStem(FileName); for (const auto &Category : Style.IncludeCategories) CategoryRegexs.emplace_back(Category.Regex, llvm::Regex::IgnoreCase); IsMainFile = FileName.endswith(".c") || FileName.endswith(".cc") || @@ -222,8 +233,7 @@ int IncludeCategoryManager::getSortIncludePriority(StringRef IncludeName, bool IncludeCategoryManager::isMainHeader(StringRef IncludeName) const { if (!IncludeName.startswith("\"")) return false; - StringRef HeaderStem = - llvm::sys::path::stem(IncludeName.drop_front(1).drop_back(1)); + StringRef HeaderStem = matchingStem(IncludeName.drop_front(1).drop_back(1)); if (FileStem.startswith(HeaderStem) || FileStem.startswith_lower(HeaderStem)) { llvm::Regex MainIncludeRegex(HeaderStem.str() + Style.IncludeIsMainRegex, diff --git a/clang/unittests/Tooling/HeaderIncludesTest.cpp b/clang/unittests/Tooling/HeaderIncludesTest.cpp index d38104fe40ec..37007fbfb65e 100644 --- a/clang/unittests/Tooling/HeaderIncludesTest.cpp +++ b/clang/unittests/Tooling/HeaderIncludesTest.cpp @@ -40,7 +40,7 @@ protected: return *Result; } - const std::string FileName = "fix.cpp"; + std::string FileName = "fix.cpp"; IncludeStyle Style = format::getLLVMStyle().IncludeStyle; }; @@ -102,6 +102,15 @@ TEST_F(HeaderIncludesTest, InsertAfterMainHeader) { Style = format::getGoogleStyle(format::FormatStyle::LanguageKind::LK_Cpp) .IncludeStyle; EXPECT_EQ(Expected, insert(Code, "")); + + FileName = "fix.cu.cpp"; + EXPECT_EQ(Expected, insert(Code, "")); + + FileName = "fix_test.cu.cpp"; + EXPECT_EQ(Expected, insert(Code, "")); + + FileName = "bar.cpp"; + EXPECT_NE(Expected, insert(Code, "")) << "Not main header"; } TEST_F(HeaderIncludesTest, InsertBeforeSystemHeaderLLVM) {