1 //===--- RecursiveSymbolVisitor.h - Clang refactoring library -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// A wrapper class around \c RecursiveASTVisitor that visits each 11 /// occurrences of a named symbol. 12 /// 13 //===----------------------------------------------------------------------===// 14 15 #ifndef LLVM_CLANG_TOOLING_REFACTOR_RECURSIVE_SYMBOL_VISITOR_H 16 #define LLVM_CLANG_TOOLING_REFACTOR_RECURSIVE_SYMBOL_VISITOR_H 17 18 #include "clang/AST/AST.h" 19 #include "clang/AST/RecursiveASTVisitor.h" 20 #include "clang/Lex/Lexer.h" 21 22 namespace clang { 23 namespace tooling { 24 25 /// Traverses the AST and visits the occurrence of each named symbol in the 26 /// given nodes. 27 template <typename T> 28 class RecursiveSymbolVisitor 29 : public RecursiveASTVisitor<RecursiveSymbolVisitor<T>> { 30 using BaseType = RecursiveASTVisitor<RecursiveSymbolVisitor<T>>; 31 32 public: RecursiveSymbolVisitor(const SourceManager & SM,const LangOptions & LangOpts)33 RecursiveSymbolVisitor(const SourceManager &SM, const LangOptions &LangOpts) 34 : SM(SM), LangOpts(LangOpts) {} 35 visitSymbolOccurrence(const NamedDecl * ND,ArrayRef<SourceRange> NameRanges)36 bool visitSymbolOccurrence(const NamedDecl *ND, 37 ArrayRef<SourceRange> NameRanges) { 38 return true; 39 } 40 41 // Declaration visitors: 42 VisitNamedDecl(const NamedDecl * D)43 bool VisitNamedDecl(const NamedDecl *D) { 44 return isa<CXXConversionDecl>(D) ? true : visit(D, D->getLocation()); 45 } 46 VisitCXXConstructorDecl(const CXXConstructorDecl * CD)47 bool VisitCXXConstructorDecl(const CXXConstructorDecl *CD) { 48 for (const auto *Initializer : CD->inits()) { 49 // Ignore implicit initializers. 50 if (!Initializer->isWritten()) 51 continue; 52 if (const FieldDecl *FD = Initializer->getMember()) { 53 if (!visit(FD, Initializer->getSourceLocation(), 54 Lexer::getLocForEndOfToken(Initializer->getSourceLocation(), 55 0, SM, LangOpts))) 56 return false; 57 } 58 } 59 return true; 60 } 61 62 // Expression visitors: 63 VisitDeclRefExpr(const DeclRefExpr * Expr)64 bool VisitDeclRefExpr(const DeclRefExpr *Expr) { 65 return visit(Expr->getFoundDecl(), Expr->getLocation()); 66 } 67 VisitMemberExpr(const MemberExpr * Expr)68 bool VisitMemberExpr(const MemberExpr *Expr) { 69 return visit(Expr->getFoundDecl().getDecl(), Expr->getMemberLoc()); 70 } 71 VisitOffsetOfExpr(const OffsetOfExpr * S)72 bool VisitOffsetOfExpr(const OffsetOfExpr *S) { 73 for (unsigned I = 0, E = S->getNumComponents(); I != E; ++I) { 74 const OffsetOfNode &Component = S->getComponent(I); 75 if (Component.getKind() == OffsetOfNode::Field) { 76 if (!visit(Component.getField(), Component.getEndLoc())) 77 return false; 78 } 79 // FIXME: Try to resolve dependent field references. 80 } 81 return true; 82 } 83 84 // Other visitors: 85 VisitTypeLoc(const TypeLoc Loc)86 bool VisitTypeLoc(const TypeLoc Loc) { 87 const SourceLocation TypeBeginLoc = Loc.getBeginLoc(); 88 const SourceLocation TypeEndLoc = 89 Lexer::getLocForEndOfToken(TypeBeginLoc, 0, SM, LangOpts); 90 if (const auto *TemplateTypeParm = 91 dyn_cast<TemplateTypeParmType>(Loc.getType())) { 92 if (!visit(TemplateTypeParm->getDecl(), TypeBeginLoc, TypeEndLoc)) 93 return false; 94 } 95 if (const auto *TemplateSpecType = 96 dyn_cast<TemplateSpecializationType>(Loc.getType())) { 97 if (!visit(TemplateSpecType->getTemplateName().getAsTemplateDecl(), 98 TypeBeginLoc, TypeEndLoc)) 99 return false; 100 } 101 if (const Type *TP = Loc.getTypePtr()) { 102 if (TP->getTypeClass() == clang::Type::Record) 103 return visit(TP->getAsCXXRecordDecl(), TypeBeginLoc, TypeEndLoc); 104 } 105 return true; 106 } 107 VisitTypedefTypeLoc(TypedefTypeLoc TL)108 bool VisitTypedefTypeLoc(TypedefTypeLoc TL) { 109 const SourceLocation TypeEndLoc = 110 Lexer::getLocForEndOfToken(TL.getBeginLoc(), 0, SM, LangOpts); 111 return visit(TL.getTypedefNameDecl(), TL.getBeginLoc(), TypeEndLoc); 112 } 113 TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS)114 bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) { 115 // The base visitor will visit NNSL prefixes, so we should only look at 116 // the current NNS. 117 if (NNS) { 118 const NamespaceDecl *ND = NNS.getNestedNameSpecifier()->getAsNamespace(); 119 if (!visit(ND, NNS.getLocalBeginLoc(), NNS.getLocalEndLoc())) 120 return false; 121 } 122 return BaseType::TraverseNestedNameSpecifierLoc(NNS); 123 } 124 125 private: 126 const SourceManager &SM; 127 const LangOptions &LangOpts; 128 visit(const NamedDecl * ND,SourceLocation BeginLoc,SourceLocation EndLoc)129 bool visit(const NamedDecl *ND, SourceLocation BeginLoc, 130 SourceLocation EndLoc) { 131 return static_cast<T *>(this)->visitSymbolOccurrence( 132 ND, SourceRange(BeginLoc, EndLoc)); 133 } visit(const NamedDecl * ND,SourceLocation Loc)134 bool visit(const NamedDecl *ND, SourceLocation Loc) { 135 return visit(ND, Loc, Lexer::getLocForEndOfToken(Loc, 0, SM, LangOpts)); 136 } 137 }; 138 139 } // end namespace tooling 140 } // end namespace clang 141 142 #endif // LLVM_CLANG_TOOLING_REFACTOR_RECURSIVE_SYMBOL_VISITOR_H 143