1 //===--- USRFindingAction.cpp - Clang refactoring library -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// Provides an action to find USR for the symbol at <offset>, as well as 11 /// all additional USRs. 12 /// 13 //===----------------------------------------------------------------------===// 14 15 #include "clang/Tooling/Refactoring/Rename/USRFindingAction.h" 16 #include "clang/AST/AST.h" 17 #include "clang/AST/ASTConsumer.h" 18 #include "clang/AST/ASTContext.h" 19 #include "clang/AST/Decl.h" 20 #include "clang/AST/RecursiveASTVisitor.h" 21 #include "clang/Basic/FileManager.h" 22 #include "clang/Frontend/CompilerInstance.h" 23 #include "clang/Frontend/FrontendAction.h" 24 #include "clang/Lex/Lexer.h" 25 #include "clang/Lex/Preprocessor.h" 26 #include "clang/Tooling/CommonOptionsParser.h" 27 #include "clang/Tooling/Refactoring.h" 28 #include "clang/Tooling/Refactoring/Rename/USRFinder.h" 29 #include "clang/Tooling/Tooling.h" 30 31 #include <algorithm> 32 #include <set> 33 #include <string> 34 #include <vector> 35 36 using namespace llvm; 37 38 namespace clang { 39 namespace tooling { 40 41 const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) { 42 // If FoundDecl is a constructor or destructor, we want to instead take 43 // the Decl of the corresponding class. 44 if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl)) 45 FoundDecl = CtorDecl->getParent(); 46 else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl)) 47 FoundDecl = DtorDecl->getParent(); 48 // FIXME: (Alex L): Canonicalize implicit template instantions, just like 49 // the indexer does it. 50 51 // Note: please update the declaration's doc comment every time the 52 // canonicalization rules are changed. 53 return FoundDecl; 54 } 55 56 namespace { 57 // NamedDeclFindingConsumer should delegate finding USRs of given Decl to 58 // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given 59 // Decl refers to class and adds USRs of all overridden methods if Decl refers 60 // to virtual method. 61 class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> { 62 public: 63 AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context) 64 : FoundDecl(FoundDecl), Context(Context) {} 65 66 std::vector<std::string> Find() { 67 // Fill OverriddenMethods and PartialSpecs storages. 68 TraverseDecl(Context.getTranslationUnitDecl()); 69 if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) { 70 addUSRsOfOverridenFunctions(MethodDecl); 71 for (const auto &OverriddenMethod : OverriddenMethods) { 72 if (checkIfOverriddenFunctionAscends(OverriddenMethod)) 73 USRSet.insert(getUSRForDecl(OverriddenMethod)); 74 } 75 addUSRsOfInstantiatedMethods(MethodDecl); 76 } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) { 77 handleCXXRecordDecl(RecordDecl); 78 } else if (const auto *TemplateDecl = 79 dyn_cast<ClassTemplateDecl>(FoundDecl)) { 80 handleClassTemplateDecl(TemplateDecl); 81 } else { 82 USRSet.insert(getUSRForDecl(FoundDecl)); 83 } 84 return std::vector<std::string>(USRSet.begin(), USRSet.end()); 85 } 86 87 bool shouldVisitTemplateInstantiations() const { return true; } 88 89 bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) { 90 if (MethodDecl->isVirtual()) 91 OverriddenMethods.push_back(MethodDecl); 92 if (MethodDecl->getInstantiatedFromMemberFunction()) 93 InstantiatedMethods.push_back(MethodDecl); 94 return true; 95 } 96 97 bool VisitClassTemplatePartialSpecializationDecl( 98 const ClassTemplatePartialSpecializationDecl *PartialSpec) { 99 PartialSpecs.push_back(PartialSpec); 100 return true; 101 } 102 103 private: 104 void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) { 105 RecordDecl = RecordDecl->getDefinition(); 106 if (const auto *ClassTemplateSpecDecl = 107 dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl)) 108 handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate()); 109 addUSRsOfCtorDtors(RecordDecl); 110 } 111 112 void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) { 113 for (const auto *Specialization : TemplateDecl->specializations()) 114 addUSRsOfCtorDtors(Specialization); 115 116 for (const auto *PartialSpec : PartialSpecs) { 117 if (PartialSpec->getSpecializedTemplate() == TemplateDecl) 118 addUSRsOfCtorDtors(PartialSpec); 119 } 120 addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl()); 121 } 122 123 void addUSRsOfCtorDtors(const CXXRecordDecl *RecordDecl) { 124 RecordDecl = RecordDecl->getDefinition(); 125 126 // Skip if the CXXRecordDecl doesn't have definition. 127 if (!RecordDecl) 128 return; 129 130 for (const auto *CtorDecl : RecordDecl->ctors()) 131 USRSet.insert(getUSRForDecl(CtorDecl)); 132 133 USRSet.insert(getUSRForDecl(RecordDecl->getDestructor())); 134 USRSet.insert(getUSRForDecl(RecordDecl)); 135 } 136 137 void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) { 138 USRSet.insert(getUSRForDecl(MethodDecl)); 139 // Recursively visit each OverridenMethod. 140 for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) 141 addUSRsOfOverridenFunctions(OverriddenMethod); 142 } 143 144 void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) { 145 // For renaming a class template method, all references of the instantiated 146 // member methods should be renamed too, so add USRs of the instantiated 147 // methods to the USR set. 148 USRSet.insert(getUSRForDecl(MethodDecl)); 149 if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction()) 150 USRSet.insert(getUSRForDecl(FT)); 151 for (const auto *Method : InstantiatedMethods) { 152 if (USRSet.find(getUSRForDecl( 153 Method->getInstantiatedFromMemberFunction())) != USRSet.end()) 154 USRSet.insert(getUSRForDecl(Method)); 155 } 156 } 157 158 bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) { 159 for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) { 160 if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end()) 161 return true; 162 return checkIfOverriddenFunctionAscends(OverriddenMethod); 163 } 164 return false; 165 } 166 167 const Decl *FoundDecl; 168 ASTContext &Context; 169 std::set<std::string> USRSet; 170 std::vector<const CXXMethodDecl *> OverriddenMethods; 171 std::vector<const CXXMethodDecl *> InstantiatedMethods; 172 std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs; 173 }; 174 } // namespace 175 176 std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND, 177 ASTContext &Context) { 178 AdditionalUSRFinder Finder(ND, Context); 179 return Finder.Find(); 180 } 181 182 class NamedDeclFindingConsumer : public ASTConsumer { 183 public: 184 NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets, 185 ArrayRef<std::string> QualifiedNames, 186 std::vector<std::string> &SpellingNames, 187 std::vector<std::vector<std::string>> &USRList, 188 bool Force, bool &ErrorOccurred) 189 : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames), 190 SpellingNames(SpellingNames), USRList(USRList), Force(Force), 191 ErrorOccurred(ErrorOccurred) {} 192 193 private: 194 bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr, 195 unsigned SymbolOffset, const std::string &QualifiedName) { 196 DiagnosticsEngine &Engine = Context.getDiagnostics(); 197 const FileID MainFileID = SourceMgr.getMainFileID(); 198 199 if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) { 200 ErrorOccurred = true; 201 unsigned InvalidOffset = Engine.getCustomDiagID( 202 DiagnosticsEngine::Error, 203 "SourceLocation in file %0 at offset %1 is invalid"); 204 Engine.Report(SourceLocation(), InvalidOffset) 205 << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset; 206 return false; 207 } 208 209 const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID) 210 .getLocWithOffset(SymbolOffset); 211 const NamedDecl *FoundDecl = QualifiedName.empty() 212 ? getNamedDeclAt(Context, Point) 213 : getNamedDeclFor(Context, QualifiedName); 214 215 if (FoundDecl == nullptr) { 216 if (QualifiedName.empty()) { 217 FullSourceLoc FullLoc(Point, SourceMgr); 218 unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID( 219 DiagnosticsEngine::Error, 220 "clang-rename could not find symbol (offset %0)"); 221 Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset; 222 ErrorOccurred = true; 223 return false; 224 } 225 226 if (Force) { 227 SpellingNames.push_back(std::string()); 228 USRList.push_back(std::vector<std::string>()); 229 return true; 230 } 231 232 unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID( 233 DiagnosticsEngine::Error, "clang-rename could not find symbol %0"); 234 Engine.Report(CouldNotFindSymbolNamed) << QualifiedName; 235 ErrorOccurred = true; 236 return false; 237 } 238 239 FoundDecl = getCanonicalSymbolDeclaration(FoundDecl); 240 SpellingNames.push_back(FoundDecl->getNameAsString()); 241 AdditionalUSRFinder Finder(FoundDecl, Context); 242 USRList.push_back(Finder.Find()); 243 return true; 244 } 245 246 void HandleTranslationUnit(ASTContext &Context) override { 247 const SourceManager &SourceMgr = Context.getSourceManager(); 248 for (unsigned Offset : SymbolOffsets) { 249 if (!FindSymbol(Context, SourceMgr, Offset, "")) 250 return; 251 } 252 for (const std::string &QualifiedName : QualifiedNames) { 253 if (!FindSymbol(Context, SourceMgr, 0, QualifiedName)) 254 return; 255 } 256 } 257 258 ArrayRef<unsigned> SymbolOffsets; 259 ArrayRef<std::string> QualifiedNames; 260 std::vector<std::string> &SpellingNames; 261 std::vector<std::vector<std::string>> &USRList; 262 bool Force; 263 bool &ErrorOccurred; 264 }; 265 266 std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() { 267 return llvm::make_unique<NamedDeclFindingConsumer>( 268 SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force, 269 ErrorOccurred); 270 } 271 272 } // end namespace tooling 273 } // end namespace clang 274