1 //===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Provides an action to find USR for the symbol at <offset>, as well as
11 /// all additional USRs.
12 ///
13 //===----------------------------------------------------------------------===//
14
15 #include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
16 #include "clang/AST/AST.h"
17 #include "clang/AST/ASTConsumer.h"
18 #include "clang/AST/ASTContext.h"
19 #include "clang/AST/Decl.h"
20 #include "clang/AST/RecursiveASTVisitor.h"
21 #include "clang/Basic/FileManager.h"
22 #include "clang/Frontend/CompilerInstance.h"
23 #include "clang/Frontend/FrontendAction.h"
24 #include "clang/Lex/Lexer.h"
25 #include "clang/Lex/Preprocessor.h"
26 #include "clang/Tooling/CommonOptionsParser.h"
27 #include "clang/Tooling/Refactoring.h"
28 #include "clang/Tooling/Refactoring/Rename/USRFinder.h"
29 #include "clang/Tooling/Tooling.h"
30
31 #include <algorithm>
32 #include <set>
33 #include <string>
34 #include <vector>
35
36 using namespace llvm;
37
38 namespace clang {
39 namespace tooling {
40
/// Maps the declaration found at the rename point to the declaration whose
/// name is actually being renamed.
///
/// Constructors and destructors are spelled with their class's name, so they
/// are replaced by the enclosing class. Every other declaration is returned
/// unchanged; a null input yields null.
const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) {
  if (FoundDecl == nullptr)
    return nullptr;

  // A constructor or destructor stands for the class it belongs to.
  if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(FoundDecl))
    return Constructor->getParent();
  if (const auto *Destructor = dyn_cast<CXXDestructorDecl>(FoundDecl))
    return Destructor->getParent();

  // FIXME: (Alex L): Canonicalize implicit template instantiations, just like
  // the indexer does it.

  // Note: please update the declaration's doc comment every time the
  // canonicalization rules are changed.
  return FoundDecl;
}
57
58 namespace {
59 // NamedDeclFindingConsumer should delegate finding USRs of given Decl to
60 // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
61 // Decl refers to class and adds USRs of all overridden methods if Decl refers
62 // to virtual method.
63 class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
64 public:
AdditionalUSRFinder(const Decl * FoundDecl,ASTContext & Context)65 AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
66 : FoundDecl(FoundDecl), Context(Context) {}
67
Find()68 std::vector<std::string> Find() {
69 // Fill OverriddenMethods and PartialSpecs storages.
70 TraverseAST(Context);
71 if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
72 addUSRsOfOverridenFunctions(MethodDecl);
73 for (const auto &OverriddenMethod : OverriddenMethods) {
74 if (checkIfOverriddenFunctionAscends(OverriddenMethod))
75 USRSet.insert(getUSRForDecl(OverriddenMethod));
76 }
77 addUSRsOfInstantiatedMethods(MethodDecl);
78 } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
79 handleCXXRecordDecl(RecordDecl);
80 } else if (const auto *TemplateDecl =
81 dyn_cast<ClassTemplateDecl>(FoundDecl)) {
82 handleClassTemplateDecl(TemplateDecl);
83 } else {
84 USRSet.insert(getUSRForDecl(FoundDecl));
85 }
86 return std::vector<std::string>(USRSet.begin(), USRSet.end());
87 }
88
shouldVisitTemplateInstantiations() const89 bool shouldVisitTemplateInstantiations() const { return true; }
90
VisitCXXMethodDecl(const CXXMethodDecl * MethodDecl)91 bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
92 if (MethodDecl->isVirtual())
93 OverriddenMethods.push_back(MethodDecl);
94 if (MethodDecl->getInstantiatedFromMemberFunction())
95 InstantiatedMethods.push_back(MethodDecl);
96 return true;
97 }
98
VisitClassTemplatePartialSpecializationDecl(const ClassTemplatePartialSpecializationDecl * PartialSpec)99 bool VisitClassTemplatePartialSpecializationDecl(
100 const ClassTemplatePartialSpecializationDecl *PartialSpec) {
101 PartialSpecs.push_back(PartialSpec);
102 return true;
103 }
104
105 private:
handleCXXRecordDecl(const CXXRecordDecl * RecordDecl)106 void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
107 if (!RecordDecl->getDefinition()) {
108 USRSet.insert(getUSRForDecl(RecordDecl));
109 return;
110 }
111 RecordDecl = RecordDecl->getDefinition();
112 if (const auto *ClassTemplateSpecDecl =
113 dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
114 handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
115 addUSRsOfCtorDtors(RecordDecl);
116 }
117
handleClassTemplateDecl(const ClassTemplateDecl * TemplateDecl)118 void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
119 for (const auto *Specialization : TemplateDecl->specializations())
120 addUSRsOfCtorDtors(Specialization);
121
122 for (const auto *PartialSpec : PartialSpecs) {
123 if (PartialSpec->getSpecializedTemplate() == TemplateDecl)
124 addUSRsOfCtorDtors(PartialSpec);
125 }
126 addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
127 }
128
addUSRsOfCtorDtors(const CXXRecordDecl * RD)129 void addUSRsOfCtorDtors(const CXXRecordDecl *RD) {
130 const auto* RecordDecl = RD->getDefinition();
131
132 // Skip if the CXXRecordDecl doesn't have definition.
133 if (!RecordDecl) {
134 USRSet.insert(getUSRForDecl(RD));
135 return;
136 }
137
138 for (const auto *CtorDecl : RecordDecl->ctors())
139 USRSet.insert(getUSRForDecl(CtorDecl));
140 // Add template constructor decls, they are not in ctors() unfortunately.
141 if (RecordDecl->hasUserDeclaredConstructor())
142 for (const auto *D : RecordDecl->decls())
143 if (const auto *FTD = dyn_cast<FunctionTemplateDecl>(D))
144 if (const auto *Ctor =
145 dyn_cast<CXXConstructorDecl>(FTD->getTemplatedDecl()))
146 USRSet.insert(getUSRForDecl(Ctor));
147
148 USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
149 USRSet.insert(getUSRForDecl(RecordDecl));
150 }
151
addUSRsOfOverridenFunctions(const CXXMethodDecl * MethodDecl)152 void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
153 USRSet.insert(getUSRForDecl(MethodDecl));
154 // Recursively visit each OverridenMethod.
155 for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
156 addUSRsOfOverridenFunctions(OverriddenMethod);
157 }
158
addUSRsOfInstantiatedMethods(const CXXMethodDecl * MethodDecl)159 void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) {
160 // For renaming a class template method, all references of the instantiated
161 // member methods should be renamed too, so add USRs of the instantiated
162 // methods to the USR set.
163 USRSet.insert(getUSRForDecl(MethodDecl));
164 if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction())
165 USRSet.insert(getUSRForDecl(FT));
166 for (const auto *Method : InstantiatedMethods) {
167 if (USRSet.find(getUSRForDecl(
168 Method->getInstantiatedFromMemberFunction())) != USRSet.end())
169 USRSet.insert(getUSRForDecl(Method));
170 }
171 }
172
checkIfOverriddenFunctionAscends(const CXXMethodDecl * MethodDecl)173 bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
174 for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
175 if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
176 return true;
177 return checkIfOverriddenFunctionAscends(OverriddenMethod);
178 }
179 return false;
180 }
181
182 const Decl *FoundDecl;
183 ASTContext &Context;
184 std::set<std::string> USRSet;
185 std::vector<const CXXMethodDecl *> OverriddenMethods;
186 std::vector<const CXXMethodDecl *> InstantiatedMethods;
187 std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs;
188 };
189 } // namespace
190
getUSRsForDeclaration(const NamedDecl * ND,ASTContext & Context)191 std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND,
192 ASTContext &Context) {
193 AdditionalUSRFinder Finder(ND, Context);
194 return Finder.Find();
195 }
196
197 class NamedDeclFindingConsumer : public ASTConsumer {
198 public:
NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,ArrayRef<std::string> QualifiedNames,std::vector<std::string> & SpellingNames,std::vector<std::vector<std::string>> & USRList,bool Force,bool & ErrorOccurred)199 NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
200 ArrayRef<std::string> QualifiedNames,
201 std::vector<std::string> &SpellingNames,
202 std::vector<std::vector<std::string>> &USRList,
203 bool Force, bool &ErrorOccurred)
204 : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
205 SpellingNames(SpellingNames), USRList(USRList), Force(Force),
206 ErrorOccurred(ErrorOccurred) {}
207
208 private:
FindSymbol(ASTContext & Context,const SourceManager & SourceMgr,unsigned SymbolOffset,const std::string & QualifiedName)209 bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
210 unsigned SymbolOffset, const std::string &QualifiedName) {
211 DiagnosticsEngine &Engine = Context.getDiagnostics();
212 const FileID MainFileID = SourceMgr.getMainFileID();
213
214 if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
215 ErrorOccurred = true;
216 unsigned InvalidOffset = Engine.getCustomDiagID(
217 DiagnosticsEngine::Error,
218 "SourceLocation in file %0 at offset %1 is invalid");
219 Engine.Report(SourceLocation(), InvalidOffset)
220 << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset;
221 return false;
222 }
223
224 const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
225 .getLocWithOffset(SymbolOffset);
226 const NamedDecl *FoundDecl = QualifiedName.empty()
227 ? getNamedDeclAt(Context, Point)
228 : getNamedDeclFor(Context, QualifiedName);
229
230 if (FoundDecl == nullptr) {
231 if (QualifiedName.empty()) {
232 FullSourceLoc FullLoc(Point, SourceMgr);
233 unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
234 DiagnosticsEngine::Error,
235 "clang-rename could not find symbol (offset %0)");
236 Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
237 ErrorOccurred = true;
238 return false;
239 }
240
241 if (Force) {
242 SpellingNames.push_back(std::string());
243 USRList.push_back(std::vector<std::string>());
244 return true;
245 }
246
247 unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
248 DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
249 Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
250 ErrorOccurred = true;
251 return false;
252 }
253
254 FoundDecl = getCanonicalSymbolDeclaration(FoundDecl);
255 SpellingNames.push_back(FoundDecl->getNameAsString());
256 AdditionalUSRFinder Finder(FoundDecl, Context);
257 USRList.push_back(Finder.Find());
258 return true;
259 }
260
HandleTranslationUnit(ASTContext & Context)261 void HandleTranslationUnit(ASTContext &Context) override {
262 const SourceManager &SourceMgr = Context.getSourceManager();
263 for (unsigned Offset : SymbolOffsets) {
264 if (!FindSymbol(Context, SourceMgr, Offset, ""))
265 return;
266 }
267 for (const std::string &QualifiedName : QualifiedNames) {
268 if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
269 return;
270 }
271 }
272
273 ArrayRef<unsigned> SymbolOffsets;
274 ArrayRef<std::string> QualifiedNames;
275 std::vector<std::string> &SpellingNames;
276 std::vector<std::vector<std::string>> &USRList;
277 bool Force;
278 bool &ErrorOccurred;
279 };
280
newASTConsumer()281 std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
282 return std::make_unique<NamedDeclFindingConsumer>(
283 SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
284 ErrorOccurred);
285 }
286
287 } // end namespace tooling
288 } // end namespace clang
289