llvm-project/clang-tools-extra/clang-rename/USRFindingAction.cpp

231 lines
8.3 KiB
C++
Raw Normal View History

//===--- tools/extra/clang-rename/USRFindingAction.cpp - Clang rename tool ===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
///
/// \file
/// \brief Provides an action to find the USR of the symbol at <offset> (or of
/// the symbol with a given qualified name), as well as all additional USRs.
///
//===----------------------------------------------------------------------===//
#include "USRFindingAction.h"
#include "USRFinder.h"
#include "clang/AST/AST.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/FileManager.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendAction.h"
#include "clang/Lex/Lexer.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Refactoring.h"
#include "clang/Tooling/Tooling.h"
#include <algorithm>
#include <set>
#include <string>
#include <vector>
using namespace llvm;
namespace clang {
namespace rename {
namespace {
/// \brief Collects every USR that has to be renamed together with a Decl.
///
/// NamedDeclFindingConsumer delegates finding the USRs of a given Decl to this
/// visitor. It adds the USRs of the constructors and destructor if the Decl
/// refers to a class (or class template), and the USRs of all related
/// overriding/overridden methods if the Decl refers to a method.
class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
public:
  /// \param FoundDecl The declaration whose USRs are requested.
  /// \param Context   AST context; its translation unit is traversed by Find().
  AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
      : FoundDecl(FoundDecl), Context(Context) {}

  /// \brief Traverses the translation unit and gathers the USRs.
  ///
  /// \returns The collected USRs. They come out of a std::set, so the vector
  /// is sorted and free of duplicates.
  std::vector<std::string> Find() {
    // Fill OverriddenMethods and PartialSpecs storages.
    TraverseDecl(Context.getTranslationUnitDecl());
    if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
      // First walk up: the method itself and everything it overrides.
      addUSRsOfOverridenFunctions(MethodDecl);
      // Then walk down: any virtual method that (transitively) overrides one
      // of the methods collected above.
      for (const auto &OverriddenMethod : OverriddenMethods) {
        if (checkIfOverriddenFunctionAscends(OverriddenMethod))
          USRSet.insert(getUSRForDecl(OverriddenMethod));
      }
    } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
      handleCXXRecordDecl(RecordDecl);
    } else if (const auto *TemplateDecl =
                   dyn_cast<ClassTemplateDecl>(FoundDecl)) {
      handleClassTemplateDecl(TemplateDecl);
    } else {
      USRSet.insert(getUSRForDecl(FoundDecl));
    }
    return std::vector<std::string>(USRSet.begin(), USRSet.end());
  }

  /// \brief Visitor hook: remembers every virtual method in the TU.
  bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
    if (MethodDecl->isVirtual())
      OverriddenMethods.push_back(MethodDecl);
    return true;
  }

  /// \brief Visitor hook: remembers every class template partial
  /// specialization in the TU.
  bool VisitClassTemplatePartialSpecializationDecl(
      const ClassTemplatePartialSpecializationDecl *PartialSpec) {
    PartialSpecs.push_back(PartialSpec);
    return true;
  }

private:
  /// \brief Adds the USRs for a class; if the class is a template
  /// specialization, the specialized template is handled as well.
  void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
    const CXXRecordDecl *Definition = RecordDecl->getDefinition();
    // A class that is only forward-declared has no definition. dyn_cast<> must
    // not be given a null pointer, so record the declaration's own USR and
    // stop here.
    if (!Definition) {
      USRSet.insert(getUSRForDecl(RecordDecl));
      return;
    }
    if (const auto *ClassTemplateSpecDecl =
            dyn_cast<ClassTemplateSpecializationDecl>(Definition))
      handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
    addUSRsOfCtorDtors(Definition);
  }

  /// \brief Adds the USRs for a class template: all of its specializations,
  /// the partial specializations seen during traversal that specialize it,
  /// and the templated class itself.
  void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
    for (const auto *Specialization : TemplateDecl->specializations())
      addUSRsOfCtorDtors(Specialization);
    for (const auto *PartialSpec : PartialSpecs) {
      if (PartialSpec->getSpecializedTemplate() == TemplateDecl)
        addUSRsOfCtorDtors(PartialSpec);
    }
    addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
  }

  /// \brief Adds the USRs of the record, its constructors and its destructor.
  void addUSRsOfCtorDtors(const CXXRecordDecl *RD) {
    const CXXRecordDecl *Definition = RD->getDefinition();
    // Without a definition there are no constructors or destructor to look
    // at, but the record itself is still a rename target.
    if (!Definition) {
      USRSet.insert(getUSRForDecl(RD));
      return;
    }
    for (const auto *CtorDecl : Definition->ctors())
      USRSet.insert(getUSRForDecl(CtorDecl));
    // getDestructor() may return null (no destructor declared); do not hand a
    // null Decl to getUSRForDecl.
    if (const auto *DtorDecl = Definition->getDestructor())
      USRSet.insert(getUSRForDecl(DtorDecl));
    USRSet.insert(getUSRForDecl(Definition));
  }

  /// \brief Adds the USR of MethodDecl and, recursively, of every method it
  /// overrides.
  void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
    USRSet.insert(getUSRForDecl(MethodDecl));
    // Recursively visit each OverridenMethod.
    for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
      addUSRsOfOverridenFunctions(OverriddenMethod);
  }

  /// \brief Returns true if any method transitively overridden by MethodDecl
  /// already has its USR in USRSet.
  bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
    for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
      if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
        return true;
      // Keep iterating when this branch finds nothing: with multiple
      // inheritance a later overridden method may still lead to a match.
      if (checkIfOverriddenFunctionAscends(OverriddenMethod))
        return true;
    }
    return false;
  }

  const Decl *FoundDecl;
  ASTContext &Context;
  // Ordered set: de-duplicates USRs and fixes the order of Find()'s result.
  std::set<std::string> USRSet;
  // Every virtual method encountered during traversal.
  std::vector<const CXXMethodDecl *> OverriddenMethods;
  // Every class template partial specialization encountered during traversal.
  std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs;
};
} // namespace
class NamedDeclFindingConsumer : public ASTConsumer {
public:
NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
ArrayRef<std::string> QualifiedNames,
std::vector<std::string> &SpellingNames,
std::vector<std::vector<std::string>> &USRList,
bool &ErrorOccurred)
: SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
SpellingNames(SpellingNames), USRList(USRList),
ErrorOccurred(ErrorOccurred) {}
private:
bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
unsigned SymbolOffset, const std::string &QualifiedName) {
DiagnosticsEngine &Engine = Context.getDiagnostics();
const FileID MainFileID = SourceMgr.getMainFileID();
if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
ErrorOccurred = true;
unsigned InvalidOffset = Engine.getCustomDiagID(
DiagnosticsEngine::Error,
"SourceLocation in file %0 at offset %1 is invalid");
Engine.Report(SourceLocation(), InvalidOffset)
<< SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset;
return false;
}
const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
.getLocWithOffset(SymbolOffset);
const NamedDecl *FoundDecl = QualifiedName.empty()
? getNamedDeclAt(Context, Point)
: getNamedDeclFor(Context, QualifiedName);
if (FoundDecl == nullptr) {
if (QualifiedName.empty()) {
FullSourceLoc FullLoc(Point, SourceMgr);
unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
DiagnosticsEngine::Error,
"clang-rename could not find symbol (offset %0)");
Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
ErrorOccurred = true;
return false;
}
unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
ErrorOccurred = true;
return false;
}
// If FoundDecl is a constructor or destructor, we want to instead take
// the Decl of the corresponding class.
if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl))
FoundDecl = CtorDecl->getParent();
else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl))
FoundDecl = DtorDecl->getParent();
SpellingNames.push_back(FoundDecl->getNameAsString());
AdditionalUSRFinder Finder(FoundDecl, Context);
USRList.push_back(Finder.Find());
return true;
}
void HandleTranslationUnit(ASTContext &Context) override {
const SourceManager &SourceMgr = Context.getSourceManager();
for (unsigned Offset : SymbolOffsets) {
if (!FindSymbol(Context, SourceMgr, Offset, ""))
return;
}
for (const std::string &QualifiedName : QualifiedNames) {
if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
return;
}
}
ArrayRef<unsigned> SymbolOffsets;
ArrayRef<std::string> QualifiedNames;
std::vector<std::string> &SpellingNames;
std::vector<std::vector<std::string>> &USRList;
bool &ErrorOccurred;
};
std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
return llvm::make_unique<NamedDeclFindingConsumer>(
SymbolOffsets, QualifiedNames, SpellingNames, USRList, ErrorOccurred);
}
} // namespace rename
} // namespace clang