1 //===--- RecursiveSymbolVisitor.h - Clang refactoring library -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
/// A wrapper class around \c RecursiveASTVisitor that visits each
/// occurrence of a named symbol.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_CLANG_TOOLING_REFACTORING_RECURSIVESYMBOLVISITOR_H
16 #define LLVM_CLANG_TOOLING_REFACTORING_RECURSIVESYMBOLVISITOR_H
17 
18 #include "clang/AST/AST.h"
19 #include "clang/AST/RecursiveASTVisitor.h"
20 #include "clang/Lex/Lexer.h"
21 
22 namespace clang {
23 namespace tooling {
24 
25 /// Traverses the AST and visits the occurrence of each named symbol in the
26 /// given nodes.
27 template <typename T>
28 class RecursiveSymbolVisitor
29     : public RecursiveASTVisitor<RecursiveSymbolVisitor<T>> {
30   using BaseType = RecursiveASTVisitor<RecursiveSymbolVisitor<T>>;
31 
32 public:
33   RecursiveSymbolVisitor(const SourceManager &SM, const LangOptions &LangOpts)
34       : SM(SM), LangOpts(LangOpts) {}
35 
36   bool visitSymbolOccurrence(const NamedDecl *ND,
37                              ArrayRef<SourceRange> NameRanges) {
38     return true;
39   }
40 
41   // Declaration visitors:
42 
43   bool VisitNamedDecl(const NamedDecl *D) {
44     return isa<CXXConversionDecl>(D) ? true : visit(D, D->getLocation());
45   }
46 
47   bool VisitCXXConstructorDecl(const CXXConstructorDecl *CD) {
48     for (const auto *Initializer : CD->inits()) {
49       // Ignore implicit initializers.
50       if (!Initializer->isWritten())
51         continue;
52       if (const FieldDecl *FD = Initializer->getMember()) {
53         if (!visit(FD, Initializer->getSourceLocation(),
54                    Lexer::getLocForEndOfToken(Initializer->getSourceLocation(),
55                                               0, SM, LangOpts)))
56           return false;
57       }
58     }
59     return true;
60   }
61 
62   // Expression visitors:
63 
64   bool VisitDeclRefExpr(const DeclRefExpr *Expr) {
65     return visit(Expr->getFoundDecl(), Expr->getLocation());
66   }
67 
68   bool VisitMemberExpr(const MemberExpr *Expr) {
69     return visit(Expr->getFoundDecl().getDecl(), Expr->getMemberLoc());
70   }
71 
72   bool VisitOffsetOfExpr(const OffsetOfExpr *S) {
73     for (unsigned I = 0, E = S->getNumComponents(); I != E; ++I) {
74       const OffsetOfNode &Component = S->getComponent(I);
75       if (Component.getKind() == OffsetOfNode::Field) {
76         if (!visit(Component.getField(), Component.getEndLoc()))
77           return false;
78       }
79       // FIXME: Try to resolve dependent field references.
80     }
81     return true;
82   }
83 
84   // Other visitors:
85 
86   bool VisitTypeLoc(const TypeLoc Loc) {
87     const SourceLocation TypeBeginLoc = Loc.getBeginLoc();
88     const SourceLocation TypeEndLoc =
89         Lexer::getLocForEndOfToken(TypeBeginLoc, 0, SM, LangOpts);
90     if (const auto *TemplateTypeParm =
91             dyn_cast<TemplateTypeParmType>(Loc.getType())) {
92       if (!visit(TemplateTypeParm->getDecl(), TypeBeginLoc, TypeEndLoc))
93         return false;
94     }
95     if (const auto *TemplateSpecType =
96             dyn_cast<TemplateSpecializationType>(Loc.getType())) {
97       if (!visit(TemplateSpecType->getTemplateName().getAsTemplateDecl(),
98                  TypeBeginLoc, TypeEndLoc))
99         return false;
100     }
101     if (const Type *TP = Loc.getTypePtr()) {
102       if (TP->getTypeClass() == clang::Type::Record)
103         return visit(TP->getAsCXXRecordDecl(), TypeBeginLoc, TypeEndLoc);
104     }
105     return true;
106   }
107 
108   bool VisitTypedefTypeLoc(TypedefTypeLoc TL) {
109     const SourceLocation TypeEndLoc =
110         Lexer::getLocForEndOfToken(TL.getBeginLoc(), 0, SM, LangOpts);
111     return visit(TL.getTypedefNameDecl(), TL.getBeginLoc(), TypeEndLoc);
112   }
113 
114   bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) {
115     // The base visitor will visit NNSL prefixes, so we should only look at
116     // the current NNS.
117     if (NNS) {
118       const NamespaceDecl *ND = NNS.getNestedNameSpecifier()->getAsNamespace();
119       if (!visit(ND, NNS.getLocalBeginLoc(), NNS.getLocalEndLoc()))
120         return false;
121     }
122     return BaseType::TraverseNestedNameSpecifierLoc(NNS);
123   }
124 
125   bool VisitDesignatedInitExpr(const DesignatedInitExpr *E) {
126     for (const DesignatedInitExpr::Designator &D : E->designators()) {
127       if (D.isFieldDesignator()) {
128         if (const FieldDecl *Decl = D.getFieldDecl()) {
129           if (!visit(Decl, D.getFieldLoc(), D.getFieldLoc()))
130             return false;
131         }
132       }
133     }
134     return true;
135   }
136 
137 private:
138   const SourceManager &SM;
139   const LangOptions &LangOpts;
140 
141   bool visit(const NamedDecl *ND, SourceLocation BeginLoc,
142              SourceLocation EndLoc) {
143     return static_cast<T *>(this)->visitSymbolOccurrence(
144         ND, SourceRange(BeginLoc, EndLoc));
145   }
146   bool visit(const NamedDecl *ND, SourceLocation Loc) {
147     return visit(ND, Loc, Lexer::getLocForEndOfToken(Loc, 0, SM, LangOpts));
148   }
149 };
150 
151 } // end namespace tooling
152 } // end namespace clang
153 
154 #endif // LLVM_CLANG_TOOLING_REFACTORING_RECURSIVESYMBOLVISITOR_H
155