diff --git a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp
index b65d7f0c1a39..e0368975ea3e 100644
--- a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp
+++ b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp
@@ -12,6 +12,7 @@
#include "clang/Lex/Lexer.h"
#include "llvm/ADT/Optional.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Path.h"
namespace clang {
namespace tooling {
@@ -174,12 +175,22 @@ inline StringRef trimInclude(StringRef IncludeName) {
const char IncludeRegexPattern[] =
R"(^[\t\ ]*#[\t\ ]*(import|include)[^"<]*(["<][^">]*[">]))";
+// The filename of Path excluding extension.
+// Used to match implementation with headers, this differs from sys::path::stem:
+// - in names with multiple dots (foo.cu.cc) it terminates at the *first*
+// - an empty stem is never returned: /foo/.bar.x => .bar
+// - we don't bother to handle . and .. specially
+StringRef matchingStem(llvm::StringRef Path) {
+ StringRef Name = llvm::sys::path::filename(Path);
+ return Name.substr(0, Name.find('.', 1));
+}
+
} // anonymous namespace
IncludeCategoryManager::IncludeCategoryManager(const IncludeStyle &Style,
StringRef FileName)
: Style(Style), FileName(FileName) {
- FileStem = llvm::sys::path::stem(FileName);
+ FileStem = matchingStem(FileName);
for (const auto &Category : Style.IncludeCategories)
CategoryRegexs.emplace_back(Category.Regex, llvm::Regex::IgnoreCase);
IsMainFile = FileName.endswith(".c") || FileName.endswith(".cc") ||
@@ -222,8 +233,7 @@ int IncludeCategoryManager::getSortIncludePriority(StringRef IncludeName,
bool IncludeCategoryManager::isMainHeader(StringRef IncludeName) const {
if (!IncludeName.startswith("\""))
return false;
- StringRef HeaderStem =
- llvm::sys::path::stem(IncludeName.drop_front(1).drop_back(1));
+ StringRef HeaderStem = matchingStem(IncludeName.drop_front(1).drop_back(1));
if (FileStem.startswith(HeaderStem) ||
FileStem.startswith_lower(HeaderStem)) {
llvm::Regex MainIncludeRegex(HeaderStem.str() + Style.IncludeIsMainRegex,
diff --git a/clang/unittests/Tooling/HeaderIncludesTest.cpp b/clang/unittests/Tooling/HeaderIncludesTest.cpp
index d38104fe40ec..37007fbfb65e 100644
--- a/clang/unittests/Tooling/HeaderIncludesTest.cpp
+++ b/clang/unittests/Tooling/HeaderIncludesTest.cpp
@@ -40,7 +40,7 @@ protected:
return *Result;
}
- const std::string FileName = "fix.cpp";
+ std::string FileName = "fix.cpp";
IncludeStyle Style = format::getLLVMStyle().IncludeStyle;
};
@@ -102,6 +102,15 @@ TEST_F(HeaderIncludesTest, InsertAfterMainHeader) {
Style = format::getGoogleStyle(format::FormatStyle::LanguageKind::LK_Cpp)
.IncludeStyle;
EXPECT_EQ(Expected, insert(Code, ""));
+
+ FileName = "fix.cu.cpp";
+ EXPECT_EQ(Expected, insert(Code, ""));
+
+ FileName = "fix_test.cu.cpp";
+ EXPECT_EQ(Expected, insert(Code, ""));
+
+ FileName = "bar.cpp";
+ EXPECT_NE(Expected, insert(Code, "")) << "Not main header";
}
TEST_F(HeaderIncludesTest, InsertBeforeSystemHeaderLLVM) {