1 //===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Provides an action to find USR for the symbol at <offset>, as well as
11 /// all additional USRs.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
16 #include "clang/AST/AST.h"
17 #include "clang/AST/ASTConsumer.h"
18 #include "clang/AST/ASTContext.h"
19 #include "clang/AST/Decl.h"
20 #include "clang/AST/RecursiveASTVisitor.h"
21 #include "clang/Basic/FileManager.h"
22 #include "clang/Frontend/CompilerInstance.h"
23 #include "clang/Frontend/FrontendAction.h"
24 #include "clang/Lex/Lexer.h"
25 #include "clang/Lex/Preprocessor.h"
26 #include "clang/Tooling/CommonOptionsParser.h"
27 #include "clang/Tooling/Refactoring.h"
28 #include "clang/Tooling/Refactoring/Rename/USRFinder.h"
29 #include "clang/Tooling/Tooling.h"
30 
31 #include <algorithm>
32 #include <set>
33 #include <string>
34 #include <vector>
35 
36 using namespace llvm;
37 
38 namespace clang {
39 namespace tooling {
40 
41 const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) {
42   if (!FoundDecl)
43     return nullptr;
44   // If FoundDecl is a constructor or destructor, we want to instead take
45   // the Decl of the corresponding class.
46   if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl))
47     FoundDecl = CtorDecl->getParent();
48   else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl))
49     FoundDecl = DtorDecl->getParent();
50   // FIXME: (Alex L): Canonicalize implicit template instantions, just like
51   // the indexer does it.
52 
53   // Note: please update the declaration's doc comment every time the
54   // canonicalization rules are changed.
55   return FoundDecl;
56 }
57 
58 namespace {
59 // NamedDeclFindingConsumer should delegate finding USRs of given Decl to
60 // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
61 // Decl refers to class and adds USRs of all overridden methods if Decl refers
62 // to virtual method.
63 class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
64 public:
65   AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
66       : FoundDecl(FoundDecl), Context(Context) {}
67 
68   std::vector<std::string> Find() {
69     // Fill OverriddenMethods and PartialSpecs storages.
70     TraverseAST(Context);
71     if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
72       addUSRsOfOverridenFunctions(MethodDecl);
73       for (const auto &OverriddenMethod : OverriddenMethods) {
74         if (checkIfOverriddenFunctionAscends(OverriddenMethod))
75           USRSet.insert(getUSRForDecl(OverriddenMethod));
76       }
77       addUSRsOfInstantiatedMethods(MethodDecl);
78     } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
79       handleCXXRecordDecl(RecordDecl);
80     } else if (const auto *TemplateDecl =
81                    dyn_cast<ClassTemplateDecl>(FoundDecl)) {
82       handleClassTemplateDecl(TemplateDecl);
83     } else {
84       USRSet.insert(getUSRForDecl(FoundDecl));
85     }
86     return std::vector<std::string>(USRSet.begin(), USRSet.end());
87   }
88 
89   bool shouldVisitTemplateInstantiations() const { return true; }
90 
91   bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
92     if (MethodDecl->isVirtual())
93       OverriddenMethods.push_back(MethodDecl);
94     if (MethodDecl->getInstantiatedFromMemberFunction())
95       InstantiatedMethods.push_back(MethodDecl);
96     return true;
97   }
98 
99   bool VisitClassTemplatePartialSpecializationDecl(
100       const ClassTemplatePartialSpecializationDecl *PartialSpec) {
101     PartialSpecs.push_back(PartialSpec);
102     return true;
103   }
104 
105 private:
106   void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
107     if (!RecordDecl->getDefinition()) {
108       USRSet.insert(getUSRForDecl(RecordDecl));
109       return;
110     }
111     RecordDecl = RecordDecl->getDefinition();
112     if (const auto *ClassTemplateSpecDecl =
113             dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
114       handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
115     addUSRsOfCtorDtors(RecordDecl);
116   }
117 
118   void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
119     for (const auto *Specialization : TemplateDecl->specializations())
120       addUSRsOfCtorDtors(Specialization);
121 
122     for (const auto *PartialSpec : PartialSpecs) {
123       if (PartialSpec->getSpecializedTemplate() == TemplateDecl)
124         addUSRsOfCtorDtors(PartialSpec);
125     }
126     addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
127   }
128 
129   void addUSRsOfCtorDtors(const CXXRecordDecl *RecordDecl) {
130     RecordDecl = RecordDecl->getDefinition();
131 
132     // Skip if the CXXRecordDecl doesn't have definition.
133     if (!RecordDecl)
134       return;
135 
136     for (const auto *CtorDecl : RecordDecl->ctors())
137       USRSet.insert(getUSRForDecl(CtorDecl));
138 
139     USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
140     USRSet.insert(getUSRForDecl(RecordDecl));
141   }
142 
143   void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
144     USRSet.insert(getUSRForDecl(MethodDecl));
145     // Recursively visit each OverridenMethod.
146     for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
147       addUSRsOfOverridenFunctions(OverriddenMethod);
148   }
149 
150   void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) {
151     // For renaming a class template method, all references of the instantiated
152     // member methods should be renamed too, so add USRs of the instantiated
153     // methods to the USR set.
154     USRSet.insert(getUSRForDecl(MethodDecl));
155     if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction())
156       USRSet.insert(getUSRForDecl(FT));
157     for (const auto *Method : InstantiatedMethods) {
158       if (USRSet.find(getUSRForDecl(
159               Method->getInstantiatedFromMemberFunction())) != USRSet.end())
160         USRSet.insert(getUSRForDecl(Method));
161     }
162   }
163 
164   bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
165     for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
166       if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
167         return true;
168       return checkIfOverriddenFunctionAscends(OverriddenMethod);
169     }
170     return false;
171   }
172 
173   const Decl *FoundDecl;
174   ASTContext &Context;
175   std::set<std::string> USRSet;
176   std::vector<const CXXMethodDecl *> OverriddenMethods;
177   std::vector<const CXXMethodDecl *> InstantiatedMethods;
178   std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs;
179 };
180 } // namespace
181 
182 std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND,
183                                                ASTContext &Context) {
184   AdditionalUSRFinder Finder(ND, Context);
185   return Finder.Find();
186 }
187 
188 class NamedDeclFindingConsumer : public ASTConsumer {
189 public:
190   NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
191                            ArrayRef<std::string> QualifiedNames,
192                            std::vector<std::string> &SpellingNames,
193                            std::vector<std::vector<std::string>> &USRList,
194                            bool Force, bool &ErrorOccurred)
195       : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
196         SpellingNames(SpellingNames), USRList(USRList), Force(Force),
197         ErrorOccurred(ErrorOccurred) {}
198 
199 private:
200   bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
201                   unsigned SymbolOffset, const std::string &QualifiedName) {
202     DiagnosticsEngine &Engine = Context.getDiagnostics();
203     const FileID MainFileID = SourceMgr.getMainFileID();
204 
205     if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
206       ErrorOccurred = true;
207       unsigned InvalidOffset = Engine.getCustomDiagID(
208           DiagnosticsEngine::Error,
209           "SourceLocation in file %0 at offset %1 is invalid");
210       Engine.Report(SourceLocation(), InvalidOffset)
211           << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset;
212       return false;
213     }
214 
215     const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
216                                      .getLocWithOffset(SymbolOffset);
217     const NamedDecl *FoundDecl = QualifiedName.empty()
218                                      ? getNamedDeclAt(Context, Point)
219                                      : getNamedDeclFor(Context, QualifiedName);
220 
221     if (FoundDecl == nullptr) {
222       if (QualifiedName.empty()) {
223         FullSourceLoc FullLoc(Point, SourceMgr);
224         unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
225             DiagnosticsEngine::Error,
226             "clang-rename could not find symbol (offset %0)");
227         Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
228         ErrorOccurred = true;
229         return false;
230       }
231 
232       if (Force) {
233         SpellingNames.push_back(std::string());
234         USRList.push_back(std::vector<std::string>());
235         return true;
236       }
237 
238       unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
239           DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
240       Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
241       ErrorOccurred = true;
242       return false;
243     }
244 
245     FoundDecl = getCanonicalSymbolDeclaration(FoundDecl);
246     SpellingNames.push_back(FoundDecl->getNameAsString());
247     AdditionalUSRFinder Finder(FoundDecl, Context);
248     USRList.push_back(Finder.Find());
249     return true;
250   }
251 
252   void HandleTranslationUnit(ASTContext &Context) override {
253     const SourceManager &SourceMgr = Context.getSourceManager();
254     for (unsigned Offset : SymbolOffsets) {
255       if (!FindSymbol(Context, SourceMgr, Offset, ""))
256         return;
257     }
258     for (const std::string &QualifiedName : QualifiedNames) {
259       if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
260         return;
261     }
262   }
263 
264   ArrayRef<unsigned> SymbolOffsets;
265   ArrayRef<std::string> QualifiedNames;
266   std::vector<std::string> &SpellingNames;
267   std::vector<std::vector<std::string>> &USRList;
268   bool Force;
269   bool &ErrorOccurred;
270 };
271 
272 std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
273   return std::make_unique<NamedDeclFindingConsumer>(
274       SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
275       ErrorOccurred);
276 }
277 
278 } // end namespace tooling
279 } // end namespace clang
280