1 //===--- ASTMatchFinder.cpp - Structural query framework ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  Implements an algorithm to efficiently search for matches on AST nodes.
10 //  Uses memoization to support recursive matches like HasDescendant.
11 //
12 //  The general idea is to visit all AST nodes with a RecursiveASTVisitor,
13 //  calling the Matches(...) method of each matcher we are running on each
14 //  AST node. The matcher can recurse via the ASTMatchFinder interface.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "clang/ASTMatchers/ASTMatchFinder.h"
19 #include "clang/AST/ASTConsumer.h"
20 #include "clang/AST/ASTContext.h"
21 #include "clang/AST/DeclCXX.h"
22 #include "clang/AST/RecursiveASTVisitor.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/SmallPtrSet.h"
25 #include "llvm/ADT/StringMap.h"
26 #include "llvm/Support/PrettyStackTrace.h"
27 #include "llvm/Support/Timer.h"
28 #include <deque>
29 #include <memory>
30 #include <set>
31 
32 namespace clang {
33 namespace ast_matchers {
34 namespace internal {
35 namespace {
36 
37 typedef MatchFinder::MatchCallback MatchCallback;
38 
// Upper bound on the size of the memoization cache: matchesChildOf,
// matchesDescendantOf and matchesAncestorOf clear the cache once it holds
// more than this many entries. 10k has been experimentally found to give a
// good trade-off of performance vs. memory consumption when running matchers
// that match on every statement over a very large codebase.
//
// FIXME: Do some performance optimization in general and
// revisit this number; also, put up micro-benchmarks that we can
// optimize this on.
static constexpr unsigned MaxMemoizationEntries = 10000;
48 
// Kind of recursive match a memoized result belongs to. It is part of
// MatchKey, so the enumerator order takes part in MatchKey's ordering.
enum class MatchType {
  // NOTE(review): not assigned by the memoization code in this section;
  // presumably reserved for ancestor matches.
  Ancestors = 0,
  // Search of all descendants (MaxDepth != 1).
  Descendants = 1,
  // Search of direct children only (MaxDepth == 1).
  Child = 2,
};
55 
56 // We use memoization to avoid running the same matcher on the same
57 // AST node twice.  This struct is the key for looking up match
58 // result.  It consists of an ID of the MatcherInterface (for
59 // identifying the matcher), a pointer to the AST node and the
60 // bound nodes before the matcher was executed.
61 //
62 // We currently only memoize on nodes whose pointers identify the
63 // nodes (\c Stmt and \c Decl, but not \c QualType or \c TypeLoc).
64 // For \c QualType and \c TypeLoc it is possible to implement
65 // generation of keys for each type.
66 // FIXME: Benchmark whether memoization of non-pointer typed nodes
67 // provides enough benefit for the additional amount of code.
68 struct MatchKey {
69   DynTypedMatcher::MatcherIDType MatcherID;
70   DynTypedNode Node;
71   BoundNodesTreeBuilder BoundNodes;
72   TraversalKind Traversal = TK_AsIs;
73   MatchType Type;
74 
operator <clang::ast_matchers::internal::__anon55449d7d0111::MatchKey75   bool operator<(const MatchKey &Other) const {
76     return std::tie(Traversal, Type, MatcherID, Node, BoundNodes) <
77            std::tie(Other.Traversal, Other.Type, Other.MatcherID, Other.Node,
78                     Other.BoundNodes);
79   }
80 };
81 
82 // Used to store the result of a match and possibly bound nodes.
83 struct MemoizedMatchResult {
84   bool ResultOfMatch;
85   BoundNodesTreeBuilder Nodes;
86 };
87 
88 // A RecursiveASTVisitor that traverses all children or all descendants of
89 // a node.
90 class MatchChildASTVisitor
91     : public RecursiveASTVisitor<MatchChildASTVisitor> {
92 public:
93   typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase;
94 
95   // Creates an AST visitor that matches 'matcher' on all children or
96   // descendants of a traversed node. max_depth is the maximum depth
97   // to traverse: use 1 for matching the children and INT_MAX for
98   // matching the descendants.
MatchChildASTVisitor(const DynTypedMatcher * Matcher,ASTMatchFinder * Finder,BoundNodesTreeBuilder * Builder,int MaxDepth,bool IgnoreImplicitChildren,ASTMatchFinder::BindKind Bind)99   MatchChildASTVisitor(const DynTypedMatcher *Matcher, ASTMatchFinder *Finder,
100                        BoundNodesTreeBuilder *Builder, int MaxDepth,
101                        bool IgnoreImplicitChildren,
102                        ASTMatchFinder::BindKind Bind)
103       : Matcher(Matcher), Finder(Finder), Builder(Builder), CurrentDepth(0),
104         MaxDepth(MaxDepth), IgnoreImplicitChildren(IgnoreImplicitChildren),
105         Bind(Bind), Matches(false) {}
106 
107   // Returns true if a match is found in the subtree rooted at the
108   // given AST node. This is done via a set of mutually recursive
109   // functions. Here's how the recursion is done (the  *wildcard can
110   // actually be Decl, Stmt, or Type):
111   //
112   //   - Traverse(node) calls BaseTraverse(node) when it needs
113   //     to visit the descendants of node.
114   //   - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node))
115   //     Traverse*(c) for each child c of 'node'.
116   //   - Traverse*(c) in turn calls Traverse(c), completing the
117   //     recursion.
findMatch(const DynTypedNode & DynNode)118   bool findMatch(const DynTypedNode &DynNode) {
119     reset();
120     if (const Decl *D = DynNode.get<Decl>())
121       traverse(*D);
122     else if (const Stmt *S = DynNode.get<Stmt>())
123       traverse(*S);
124     else if (const NestedNameSpecifier *NNS =
125              DynNode.get<NestedNameSpecifier>())
126       traverse(*NNS);
127     else if (const NestedNameSpecifierLoc *NNSLoc =
128              DynNode.get<NestedNameSpecifierLoc>())
129       traverse(*NNSLoc);
130     else if (const QualType *Q = DynNode.get<QualType>())
131       traverse(*Q);
132     else if (const TypeLoc *T = DynNode.get<TypeLoc>())
133       traverse(*T);
134     else if (const auto *C = DynNode.get<CXXCtorInitializer>())
135       traverse(*C);
136     else if (const TemplateArgumentLoc *TALoc =
137                  DynNode.get<TemplateArgumentLoc>())
138       traverse(*TALoc);
139     else if (const Attr *A = DynNode.get<Attr>())
140       traverse(*A);
141     // FIXME: Add other base types after adding tests.
142 
143     // It's OK to always overwrite the bound nodes, as if there was
144     // no match in this recursive branch, the result set is empty
145     // anyway.
146     *Builder = ResultBindings;
147 
148     return Matches;
149   }
150 
151   // The following are overriding methods from the base visitor class.
152   // They are public only to allow CRTP to work. They are *not *part
153   // of the public API of this class.
TraverseDecl(Decl * DeclNode)154   bool TraverseDecl(Decl *DeclNode) {
155 
156     if (DeclNode && DeclNode->isImplicit() &&
157         Finder->isTraversalIgnoringImplicitNodes())
158       return baseTraverse(*DeclNode);
159 
160     ScopedIncrement ScopedDepth(&CurrentDepth);
161     return (DeclNode == nullptr) || traverse(*DeclNode);
162   }
163 
getStmtToTraverse(Stmt * StmtNode)164   Stmt *getStmtToTraverse(Stmt *StmtNode) {
165     Stmt *StmtToTraverse = StmtNode;
166     if (auto *ExprNode = dyn_cast_or_null<Expr>(StmtNode)) {
167       auto *LambdaNode = dyn_cast_or_null<LambdaExpr>(StmtNode);
168       if (LambdaNode && Finder->isTraversalIgnoringImplicitNodes())
169         StmtToTraverse = LambdaNode;
170       else
171         StmtToTraverse =
172             Finder->getASTContext().getParentMapContext().traverseIgnored(
173                 ExprNode);
174     }
175     return StmtToTraverse;
176   }
177 
TraverseStmt(Stmt * StmtNode,DataRecursionQueue * Queue=nullptr)178   bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr) {
179     // If we need to keep track of the depth, we can't perform data recursion.
180     if (CurrentDepth == 0 || (CurrentDepth <= MaxDepth && MaxDepth < INT_MAX))
181       Queue = nullptr;
182 
183     ScopedIncrement ScopedDepth(&CurrentDepth);
184     Stmt *StmtToTraverse = getStmtToTraverse(StmtNode);
185     if (!StmtToTraverse)
186       return true;
187 
188     if (IgnoreImplicitChildren && isa<CXXDefaultArgExpr>(StmtNode))
189       return true;
190 
191     if (!match(*StmtToTraverse))
192       return false;
193     return VisitorBase::TraverseStmt(StmtToTraverse, Queue);
194   }
195   // We assume that the QualType and the contained type are on the same
196   // hierarchy level. Thus, we try to match either of them.
TraverseType(QualType TypeNode)197   bool TraverseType(QualType TypeNode) {
198     if (TypeNode.isNull())
199       return true;
200     ScopedIncrement ScopedDepth(&CurrentDepth);
201     // Match the Type.
202     if (!match(*TypeNode))
203       return false;
204     // The QualType is matched inside traverse.
205     return traverse(TypeNode);
206   }
207   // We assume that the TypeLoc, contained QualType and contained Type all are
208   // on the same hierarchy level. Thus, we try to match all of them.
TraverseTypeLoc(TypeLoc TypeLocNode)209   bool TraverseTypeLoc(TypeLoc TypeLocNode) {
210     if (TypeLocNode.isNull())
211       return true;
212     ScopedIncrement ScopedDepth(&CurrentDepth);
213     // Match the Type.
214     if (!match(*TypeLocNode.getType()))
215       return false;
216     // Match the QualType.
217     if (!match(TypeLocNode.getType()))
218       return false;
219     // The TypeLoc is matched inside traverse.
220     return traverse(TypeLocNode);
221   }
TraverseNestedNameSpecifier(NestedNameSpecifier * NNS)222   bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
223     ScopedIncrement ScopedDepth(&CurrentDepth);
224     return (NNS == nullptr) || traverse(*NNS);
225   }
TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS)226   bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) {
227     if (!NNS)
228       return true;
229     ScopedIncrement ScopedDepth(&CurrentDepth);
230     if (!match(*NNS.getNestedNameSpecifier()))
231       return false;
232     return traverse(NNS);
233   }
TraverseConstructorInitializer(CXXCtorInitializer * CtorInit)234   bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit) {
235     if (!CtorInit)
236       return true;
237     ScopedIncrement ScopedDepth(&CurrentDepth);
238     return traverse(*CtorInit);
239   }
TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL)240   bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL) {
241     ScopedIncrement ScopedDepth(&CurrentDepth);
242     return traverse(TAL);
243   }
TraverseCXXForRangeStmt(CXXForRangeStmt * Node)244   bool TraverseCXXForRangeStmt(CXXForRangeStmt *Node) {
245     if (!Finder->isTraversalIgnoringImplicitNodes())
246       return VisitorBase::TraverseCXXForRangeStmt(Node);
247     if (!Node)
248       return true;
249     ScopedIncrement ScopedDepth(&CurrentDepth);
250     if (auto *Init = Node->getInit())
251       if (!traverse(*Init))
252         return false;
253     if (!match(*Node->getLoopVariable()))
254       return false;
255     if (match(*Node->getRangeInit()))
256       if (!VisitorBase::TraverseStmt(Node->getRangeInit()))
257         return false;
258     if (!match(*Node->getBody()))
259       return false;
260     return VisitorBase::TraverseStmt(Node->getBody());
261   }
TraverseCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator * Node)262   bool TraverseCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *Node) {
263     if (!Finder->isTraversalIgnoringImplicitNodes())
264       return VisitorBase::TraverseCXXRewrittenBinaryOperator(Node);
265     if (!Node)
266       return true;
267     ScopedIncrement ScopedDepth(&CurrentDepth);
268 
269     return match(*Node->getLHS()) && match(*Node->getRHS());
270   }
TraverseAttr(Attr * A)271   bool TraverseAttr(Attr *A) {
272     if (A == nullptr ||
273         (A->isImplicit() &&
274          Finder->getASTContext().getParentMapContext().getTraversalKind() ==
275              TK_IgnoreUnlessSpelledInSource))
276       return true;
277     ScopedIncrement ScopedDepth(&CurrentDepth);
278     return traverse(*A);
279   }
TraverseLambdaExpr(LambdaExpr * Node)280   bool TraverseLambdaExpr(LambdaExpr *Node) {
281     if (!Finder->isTraversalIgnoringImplicitNodes())
282       return VisitorBase::TraverseLambdaExpr(Node);
283     if (!Node)
284       return true;
285     ScopedIncrement ScopedDepth(&CurrentDepth);
286 
287     for (unsigned I = 0, N = Node->capture_size(); I != N; ++I) {
288       const auto *C = Node->capture_begin() + I;
289       if (!C->isExplicit())
290         continue;
291       if (Node->isInitCapture(C) && !match(*C->getCapturedVar()))
292         return false;
293       if (!match(*Node->capture_init_begin()[I]))
294         return false;
295     }
296 
297     if (const auto *TPL = Node->getTemplateParameterList()) {
298       for (const auto *TP : *TPL) {
299         if (!match(*TP))
300           return false;
301       }
302     }
303 
304     for (const auto *P : Node->getCallOperator()->parameters()) {
305       if (!match(*P))
306         return false;
307     }
308 
309     if (!match(*Node->getBody()))
310       return false;
311 
312     return VisitorBase::TraverseStmt(Node->getBody());
313   }
314 
shouldVisitTemplateInstantiations() const315   bool shouldVisitTemplateInstantiations() const { return true; }
shouldVisitImplicitCode() const316   bool shouldVisitImplicitCode() const { return !IgnoreImplicitChildren; }
317 
318 private:
319   // Used for updating the depth during traversal.
320   struct ScopedIncrement {
ScopedIncrementclang::ast_matchers::internal::__anon55449d7d0111::MatchChildASTVisitor::ScopedIncrement321     explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); }
~ScopedIncrementclang::ast_matchers::internal::__anon55449d7d0111::MatchChildASTVisitor::ScopedIncrement322     ~ScopedIncrement() { --(*Depth); }
323 
324    private:
325     int *Depth;
326   };
327 
328   // Resets the state of this object.
reset()329   void reset() {
330     Matches = false;
331     CurrentDepth = 0;
332   }
333 
334   // Forwards the call to the corresponding Traverse*() method in the
335   // base visitor class.
baseTraverse(const Decl & DeclNode)336   bool baseTraverse(const Decl &DeclNode) {
337     return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode));
338   }
baseTraverse(const Stmt & StmtNode)339   bool baseTraverse(const Stmt &StmtNode) {
340     return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode));
341   }
baseTraverse(QualType TypeNode)342   bool baseTraverse(QualType TypeNode) {
343     return VisitorBase::TraverseType(TypeNode);
344   }
baseTraverse(TypeLoc TypeLocNode)345   bool baseTraverse(TypeLoc TypeLocNode) {
346     return VisitorBase::TraverseTypeLoc(TypeLocNode);
347   }
baseTraverse(const NestedNameSpecifier & NNS)348   bool baseTraverse(const NestedNameSpecifier &NNS) {
349     return VisitorBase::TraverseNestedNameSpecifier(
350         const_cast<NestedNameSpecifier*>(&NNS));
351   }
baseTraverse(NestedNameSpecifierLoc NNS)352   bool baseTraverse(NestedNameSpecifierLoc NNS) {
353     return VisitorBase::TraverseNestedNameSpecifierLoc(NNS);
354   }
baseTraverse(const CXXCtorInitializer & CtorInit)355   bool baseTraverse(const CXXCtorInitializer &CtorInit) {
356     return VisitorBase::TraverseConstructorInitializer(
357         const_cast<CXXCtorInitializer *>(&CtorInit));
358   }
baseTraverse(TemplateArgumentLoc TAL)359   bool baseTraverse(TemplateArgumentLoc TAL) {
360     return VisitorBase::TraverseTemplateArgumentLoc(TAL);
361   }
baseTraverse(const Attr & AttrNode)362   bool baseTraverse(const Attr &AttrNode) {
363     return VisitorBase::TraverseAttr(const_cast<Attr *>(&AttrNode));
364   }
365 
366   // Sets 'Matched' to true if 'Matcher' matches 'Node' and:
367   //   0 < CurrentDepth <= MaxDepth.
368   //
369   // Returns 'true' if traversal should continue after this function
370   // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
371   template <typename T>
match(const T & Node)372   bool match(const T &Node) {
373     if (CurrentDepth == 0 || CurrentDepth > MaxDepth) {
374       return true;
375     }
376     if (Bind != ASTMatchFinder::BK_All) {
377       BoundNodesTreeBuilder RecursiveBuilder(*Builder);
378       if (Matcher->matches(DynTypedNode::create(Node), Finder,
379                            &RecursiveBuilder)) {
380         Matches = true;
381         ResultBindings.addMatch(RecursiveBuilder);
382         return false; // Abort as soon as a match is found.
383       }
384     } else {
385       BoundNodesTreeBuilder RecursiveBuilder(*Builder);
386       if (Matcher->matches(DynTypedNode::create(Node), Finder,
387                            &RecursiveBuilder)) {
388         // After the first match the matcher succeeds.
389         Matches = true;
390         ResultBindings.addMatch(RecursiveBuilder);
391       }
392     }
393     return true;
394   }
395 
396   // Traverses the subtree rooted at 'Node'; returns true if the
397   // traversal should continue after this function returns.
398   template <typename T>
traverse(const T & Node)399   bool traverse(const T &Node) {
400     static_assert(IsBaseType<T>::value,
401                   "traverse can only be instantiated with base type");
402     if (!match(Node))
403       return false;
404     return baseTraverse(Node);
405   }
406 
407   const DynTypedMatcher *const Matcher;
408   ASTMatchFinder *const Finder;
409   BoundNodesTreeBuilder *const Builder;
410   BoundNodesTreeBuilder ResultBindings;
411   int CurrentDepth;
412   const int MaxDepth;
413   const bool IgnoreImplicitChildren;
414   const ASTMatchFinder::BindKind Bind;
415   bool Matches;
416 };
417 
418 // Controls the outermost traversal of the AST and allows to match multiple
419 // matchers.
420 class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
421                         public ASTMatchFinder {
422 public:
// Stores the registered matchers and options. The AST context is attached
// later through set_active_ast_context().
MatchASTVisitor(const MatchFinder::MatchersByType *Matchers,
                const MatchFinder::MatchFinderOptions &Options)
    : Matchers(Matchers), Options(Options),
      ActiveASTContext(nullptr) {}
426 
// If check profiling was requested, hands the per-callback timing buckets
// over to the caller's profiling record.
~MatchASTVisitor() override {
  if (!Options.CheckProfiling)
    return;
  Options.CheckProfiling->Records = std::move(TimeByBucket);
}
432 
onStartOfTranslationUnit()433   void onStartOfTranslationUnit() {
434     const bool EnableCheckProfiling = Options.CheckProfiling.has_value();
435     TimeBucketRegion Timer;
436     for (MatchCallback *MC : Matchers->AllCallbacks) {
437       if (EnableCheckProfiling)
438         Timer.setBucket(&TimeByBucket[MC->getID()]);
439       MC->onStartOfTranslationUnit();
440     }
441   }
442 
onEndOfTranslationUnit()443   void onEndOfTranslationUnit() {
444     const bool EnableCheckProfiling = Options.CheckProfiling.has_value();
445     TimeBucketRegion Timer;
446     for (MatchCallback *MC : Matchers->AllCallbacks) {
447       if (EnableCheckProfiling)
448         Timer.setBucket(&TimeByBucket[MC->getID()]);
449       MC->onEndOfTranslationUnit();
450     }
451   }
452 
set_active_ast_context(ASTContext * NewActiveASTContext)453   void set_active_ast_context(ASTContext *NewActiveASTContext) {
454     ActiveASTContext = NewActiveASTContext;
455   }
456 
457   // The following Visit*() and Traverse*() functions "override"
458   // methods in RecursiveASTVisitor.
459 
VisitTypedefNameDecl(TypedefNameDecl * DeclNode)460   bool VisitTypedefNameDecl(TypedefNameDecl *DeclNode) {
461     // When we see 'typedef A B', we add name 'B' to the set of names
462     // A's canonical type maps to.  This is necessary for implementing
463     // isDerivedFrom(x) properly, where x can be the name of the base
464     // class or any of its aliases.
465     //
466     // In general, the is-alias-of (as defined by typedefs) relation
467     // is tree-shaped, as you can typedef a type more than once.  For
468     // example,
469     //
470     //   typedef A B;
471     //   typedef A C;
472     //   typedef C D;
473     //   typedef C E;
474     //
475     // gives you
476     //
477     //   A
478     //   |- B
479     //   `- C
480     //      |- D
481     //      `- E
482     //
483     // It is wrong to assume that the relation is a chain.  A correct
484     // implementation of isDerivedFrom() needs to recognize that B and
485     // E are aliases, even though neither is a typedef of the other.
486     // Therefore, we cannot simply walk through one typedef chain to
487     // find out whether the type name matches.
488     const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr();
489     const Type *CanonicalType =  // root of the typedef tree
490         ActiveASTContext->getCanonicalType(TypeNode);
491     TypeAliases[CanonicalType].insert(DeclNode);
492     return true;
493   }
494 
VisitObjCCompatibleAliasDecl(ObjCCompatibleAliasDecl * CAD)495   bool VisitObjCCompatibleAliasDecl(ObjCCompatibleAliasDecl *CAD) {
496     const ObjCInterfaceDecl *InterfaceDecl = CAD->getClassInterface();
497     CompatibleAliases[InterfaceDecl].insert(CAD);
498     return true;
499   }
500 
501   bool TraverseDecl(Decl *DeclNode);
502   bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr);
503   bool TraverseType(QualType TypeNode);
504   bool TraverseTypeLoc(TypeLoc TypeNode);
505   bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS);
506   bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS);
507   bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit);
508   bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL);
509   bool TraverseAttr(Attr *AttrNode);
510 
dataTraverseNode(Stmt * S,DataRecursionQueue * Queue)511   bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) {
512     if (auto *RF = dyn_cast<CXXForRangeStmt>(S)) {
513       {
514         ASTNodeNotAsIsSourceScope RAII(this, true);
515         TraverseStmt(RF->getInit());
516         // Don't traverse under the loop variable
517         match(*RF->getLoopVariable());
518         TraverseStmt(RF->getRangeInit());
519       }
520       {
521         ASTNodeNotSpelledInSourceScope RAII(this, true);
522         for (auto *SubStmt : RF->children()) {
523           if (SubStmt != RF->getBody())
524             TraverseStmt(SubStmt);
525         }
526       }
527       TraverseStmt(RF->getBody());
528       return true;
529     } else if (auto *RBO = dyn_cast<CXXRewrittenBinaryOperator>(S)) {
530       {
531         ASTNodeNotAsIsSourceScope RAII(this, true);
532         TraverseStmt(const_cast<Expr *>(RBO->getLHS()));
533         TraverseStmt(const_cast<Expr *>(RBO->getRHS()));
534       }
535       {
536         ASTNodeNotSpelledInSourceScope RAII(this, true);
537         for (auto *SubStmt : RBO->children()) {
538           TraverseStmt(SubStmt);
539         }
540       }
541       return true;
542     } else if (auto *LE = dyn_cast<LambdaExpr>(S)) {
543       for (auto I : llvm::zip(LE->captures(), LE->capture_inits())) {
544         auto C = std::get<0>(I);
545         ASTNodeNotSpelledInSourceScope RAII(
546             this, TraversingASTNodeNotSpelledInSource || !C.isExplicit());
547         TraverseLambdaCapture(LE, &C, std::get<1>(I));
548       }
549 
550       {
551         ASTNodeNotSpelledInSourceScope RAII(this, true);
552         TraverseDecl(LE->getLambdaClass());
553       }
554       {
555         ASTNodeNotAsIsSourceScope RAII(this, true);
556 
557         // We need to poke around to find the bits that might be explicitly
558         // written.
559         TypeLoc TL = LE->getCallOperator()->getTypeSourceInfo()->getTypeLoc();
560         FunctionProtoTypeLoc Proto = TL.getAsAdjusted<FunctionProtoTypeLoc>();
561 
562         if (auto *TPL = LE->getTemplateParameterList()) {
563           for (NamedDecl *D : *TPL) {
564             TraverseDecl(D);
565           }
566           if (Expr *RequiresClause = TPL->getRequiresClause()) {
567             TraverseStmt(RequiresClause);
568           }
569         }
570 
571         if (LE->hasExplicitParameters()) {
572           // Visit parameters.
573           for (ParmVarDecl *Param : Proto.getParams())
574             TraverseDecl(Param);
575         }
576 
577         const auto *T = Proto.getTypePtr();
578         for (const auto &E : T->exceptions())
579           TraverseType(E);
580 
581         if (Expr *NE = T->getNoexceptExpr())
582           TraverseStmt(NE, Queue);
583 
584         if (LE->hasExplicitResultType())
585           TraverseTypeLoc(Proto.getReturnLoc());
586         TraverseStmt(LE->getTrailingRequiresClause());
587       }
588 
589       TraverseStmt(LE->getBody());
590       return true;
591     }
592     return RecursiveASTVisitor<MatchASTVisitor>::dataTraverseNode(S, Queue);
593   }
594 
  // Matches children or descendants of 'Node' with 'Matcher', reusing a
  // cached result when 'Node' has an identity and the bindings are comparable.
memoizedMatchesRecursively(const DynTypedNode & Node,ASTContext & Ctx,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,int MaxDepth,BindKind Bind)596   bool memoizedMatchesRecursively(const DynTypedNode &Node, ASTContext &Ctx,
597                                   const DynTypedMatcher &Matcher,
598                                   BoundNodesTreeBuilder *Builder, int MaxDepth,
599                                   BindKind Bind) {
600     // For AST-nodes that don't have an identity, we can't memoize.
601     if (!Node.getMemoizationData() || !Builder->isComparable())
602       return matchesRecursively(Node, Matcher, Builder, MaxDepth, Bind);
603 
604     MatchKey Key;
605     Key.MatcherID = Matcher.getID();
606     Key.Node = Node;
607     // Note that we key on the bindings *before* the match.
608     Key.BoundNodes = *Builder;
609     Key.Traversal = Ctx.getParentMapContext().getTraversalKind();
610     // Memoize result even doing a single-level match, it might be expensive.
611     Key.Type = MaxDepth == 1 ? MatchType::Child : MatchType::Descendants;
612     MemoizationMap::iterator I = ResultCache.find(Key);
613     if (I != ResultCache.end()) {
614       *Builder = I->second.Nodes;
615       return I->second.ResultOfMatch;
616     }
617 
618     MemoizedMatchResult Result;
619     Result.Nodes = *Builder;
620     Result.ResultOfMatch =
621         matchesRecursively(Node, Matcher, &Result.Nodes, MaxDepth, Bind);
622 
623     MemoizedMatchResult &CachedResult = ResultCache[Key];
624     CachedResult = std::move(Result);
625 
626     *Builder = CachedResult.Nodes;
627     return CachedResult.ResultOfMatch;
628   }
629 
  // Matches children or descendants of 'Node' with 'Matcher' (no memoization).
matchesRecursively(const DynTypedNode & Node,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,int MaxDepth,BindKind Bind)631   bool matchesRecursively(const DynTypedNode &Node,
632                           const DynTypedMatcher &Matcher,
633                           BoundNodesTreeBuilder *Builder, int MaxDepth,
634                           BindKind Bind) {
635     bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
636                            TraversingASTChildrenNotSpelledInSource;
637 
638     bool IgnoreImplicitChildren = false;
639 
640     if (isTraversalIgnoringImplicitNodes()) {
641       IgnoreImplicitChildren = true;
642     }
643 
644     ASTNodeNotSpelledInSourceScope RAII(this, ScopedTraversal);
645 
646     MatchChildASTVisitor Visitor(&Matcher, this, Builder, MaxDepth,
647                                  IgnoreImplicitChildren, Bind);
648     return Visitor.findMatch(Node);
649   }
650 
651   bool classIsDerivedFrom(const CXXRecordDecl *Declaration,
652                           const Matcher<NamedDecl> &Base,
653                           BoundNodesTreeBuilder *Builder,
654                           bool Directly) override;
655 
656 private:
657   bool
658   classIsDerivedFromImpl(const CXXRecordDecl *Declaration,
659                          const Matcher<NamedDecl> &Base,
660                          BoundNodesTreeBuilder *Builder, bool Directly,
661                          llvm::SmallPtrSetImpl<const CXXRecordDecl *> &Visited);
662 
663 public:
664   bool objcClassIsDerivedFrom(const ObjCInterfaceDecl *Declaration,
665                               const Matcher<NamedDecl> &Base,
666                               BoundNodesTreeBuilder *Builder,
667                               bool Directly) override;
668 
669 public:
670   // Implements ASTMatchFinder::matchesChildOf.
matchesChildOf(const DynTypedNode & Node,ASTContext & Ctx,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,BindKind Bind)671   bool matchesChildOf(const DynTypedNode &Node, ASTContext &Ctx,
672                       const DynTypedMatcher &Matcher,
673                       BoundNodesTreeBuilder *Builder, BindKind Bind) override {
674     if (ResultCache.size() > MaxMemoizationEntries)
675       ResultCache.clear();
676     return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, 1, Bind);
677   }
678   // Implements ASTMatchFinder::matchesDescendantOf.
matchesDescendantOf(const DynTypedNode & Node,ASTContext & Ctx,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,BindKind Bind)679   bool matchesDescendantOf(const DynTypedNode &Node, ASTContext &Ctx,
680                            const DynTypedMatcher &Matcher,
681                            BoundNodesTreeBuilder *Builder,
682                            BindKind Bind) override {
683     if (ResultCache.size() > MaxMemoizationEntries)
684       ResultCache.clear();
685     return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, INT_MAX,
686                                       Bind);
687   }
688   // Implements ASTMatchFinder::matchesAncestorOf.
matchesAncestorOf(const DynTypedNode & Node,ASTContext & Ctx,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,AncestorMatchMode MatchMode)689   bool matchesAncestorOf(const DynTypedNode &Node, ASTContext &Ctx,
690                          const DynTypedMatcher &Matcher,
691                          BoundNodesTreeBuilder *Builder,
692                          AncestorMatchMode MatchMode) override {
693     // Reset the cache outside of the recursive call to make sure we
694     // don't invalidate any iterators.
695     if (ResultCache.size() > MaxMemoizationEntries)
696       ResultCache.clear();
697     if (MatchMode == AncestorMatchMode::AMM_ParentOnly)
698       return matchesParentOf(Node, Matcher, Builder);
699     return matchesAnyAncestorOf(Node, Ctx, Matcher, Builder);
700   }
701 
702   // Matches all registered matchers on the given node and calls the
703   // result callback for every node that matches.
match(const DynTypedNode & Node)704   void match(const DynTypedNode &Node) {
705     // FIXME: Improve this with a switch or a visitor pattern.
706     if (auto *N = Node.get<Decl>()) {
707       match(*N);
708     } else if (auto *N = Node.get<Stmt>()) {
709       match(*N);
710     } else if (auto *N = Node.get<Type>()) {
711       match(*N);
712     } else if (auto *N = Node.get<QualType>()) {
713       match(*N);
714     } else if (auto *N = Node.get<NestedNameSpecifier>()) {
715       match(*N);
716     } else if (auto *N = Node.get<NestedNameSpecifierLoc>()) {
717       match(*N);
718     } else if (auto *N = Node.get<TypeLoc>()) {
719       match(*N);
720     } else if (auto *N = Node.get<CXXCtorInitializer>()) {
721       match(*N);
722     } else if (auto *N = Node.get<TemplateArgumentLoc>()) {
723       match(*N);
724     } else if (auto *N = Node.get<Attr>()) {
725       match(*N);
726     }
727   }
728 
match(const T & Node)729   template <typename T> void match(const T &Node) {
730     matchDispatch(&Node);
731   }
732 
733   // Implements ASTMatchFinder::getASTContext.
getASTContext() const734   ASTContext &getASTContext() const override { return *ActiveASTContext; }
735 
shouldVisitTemplateInstantiations() const736   bool shouldVisitTemplateInstantiations() const { return true; }
shouldVisitImplicitCode() const737   bool shouldVisitImplicitCode() const { return true; }
738 
739   // We visit the lambda body explicitly, so instruct the RAV
740   // to not visit it on our behalf too.
shouldVisitLambdaBody() const741   bool shouldVisitLambdaBody() const { return false; }
742 
IsMatchingInASTNodeNotSpelledInSource() const743   bool IsMatchingInASTNodeNotSpelledInSource() const override {
744     return TraversingASTNodeNotSpelledInSource;
745   }
isMatchingChildrenNotSpelledInSource() const746   bool isMatchingChildrenNotSpelledInSource() const override {
747     return TraversingASTChildrenNotSpelledInSource;
748   }
setMatchingChildrenNotSpelledInSource(bool Set)749   void setMatchingChildrenNotSpelledInSource(bool Set) override {
750     TraversingASTChildrenNotSpelledInSource = Set;
751   }
752 
IsMatchingInASTNodeNotAsIs() const753   bool IsMatchingInASTNodeNotAsIs() const override {
754     return TraversingASTNodeNotAsIs;
755   }
756 
TraverseTemplateInstantiations(ClassTemplateDecl * D)757   bool TraverseTemplateInstantiations(ClassTemplateDecl *D) {
758     ASTNodeNotSpelledInSourceScope RAII(this, true);
759     return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
760         D);
761   }
762 
TraverseTemplateInstantiations(VarTemplateDecl * D)763   bool TraverseTemplateInstantiations(VarTemplateDecl *D) {
764     ASTNodeNotSpelledInSourceScope RAII(this, true);
765     return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
766         D);
767   }
768 
TraverseTemplateInstantiations(FunctionTemplateDecl * D)769   bool TraverseTemplateInstantiations(FunctionTemplateDecl *D) {
770     ASTNodeNotSpelledInSourceScope RAII(this, true);
771     return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
772         D);
773   }
774 
775 private:
776   bool TraversingASTNodeNotSpelledInSource = false;
777   bool TraversingASTNodeNotAsIs = false;
778   bool TraversingASTChildrenNotSpelledInSource = false;
779 
780   class CurMatchData {
781 // We don't have enough free low bits in 32bit builds to discriminate 8 pointer
782 // types in PointerUnion. so split the union in 2 using a free bit from the
783 // callback pointer.
784 #define CMD_TYPES_0                                                            \
785   const QualType *, const TypeLoc *, const NestedNameSpecifier *,              \
786       const NestedNameSpecifierLoc *
787 #define CMD_TYPES_1                                                            \
788   const CXXCtorInitializer *, const TemplateArgumentLoc *, const Attr *,       \
789       const DynTypedNode *
790 
791 #define IMPL(Index)                                                            \
792   template <typename NodeType>                                                 \
793   std::enable_if_t<                                                            \
794       llvm::is_one_of<const NodeType *, CMD_TYPES_##Index>::value>             \
795   SetCallbackAndRawNode(const MatchCallback *CB, const NodeType &N) {          \
796     assertEmpty();                                                             \
797     Callback.setPointerAndInt(CB, Index);                                      \
798     Node##Index = &N;                                                          \
799   }                                                                            \
800                                                                                \
801   template <typename T>                                                        \
802   std::enable_if_t<llvm::is_one_of<const T *, CMD_TYPES_##Index>::value,       \
803                    const T *>                                                  \
804   getNode() const {                                                            \
805     assertHoldsState();                                                        \
806     return Callback.getInt() == (Index) ? Node##Index.dyn_cast<const T *>()    \
807                                         : nullptr;                             \
808   }
809 
810   public:
CurMatchData()811     CurMatchData() : Node0(nullptr) {}
812 
813     IMPL(0)
814     IMPL(1)
815 
getCallback() const816     const MatchCallback *getCallback() const { return Callback.getPointer(); }
817 
SetBoundNodes(const BoundNodes & BN)818     void SetBoundNodes(const BoundNodes &BN) {
819       assertHoldsState();
820       BNodes = &BN;
821     }
822 
clearBoundNodes()823     void clearBoundNodes() {
824       assertHoldsState();
825       BNodes = nullptr;
826     }
827 
getBoundNodes() const828     const BoundNodes *getBoundNodes() const {
829       assertHoldsState();
830       return BNodes;
831     }
832 
reset()833     void reset() {
834       assertHoldsState();
835       Callback.setPointerAndInt(nullptr, 0);
836       Node0 = nullptr;
837     }
838 
839   private:
assertHoldsState() const840     void assertHoldsState() const {
841       assert(Callback.getPointer() != nullptr && !Node0.isNull());
842     }
843 
assertEmpty() const844     void assertEmpty() const {
845       assert(Callback.getPointer() == nullptr && Node0.isNull() &&
846              BNodes == nullptr);
847     }
848 
849     llvm::PointerIntPair<const MatchCallback *, 1> Callback;
850     union {
851       llvm::PointerUnion<CMD_TYPES_0> Node0;
852       llvm::PointerUnion<CMD_TYPES_1> Node1;
853     };
854     const BoundNodes *BNodes = nullptr;
855 
856 #undef CMD_TYPES_0
857 #undef CMD_TYPES_1
858 #undef IMPL
859   } CurMatchState;
860 
861   struct CurMatchRAII {
862     template <typename NodeType>
CurMatchRAIIclang::ast_matchers::internal::__anon55449d7d0111::MatchASTVisitor::CurMatchRAII863     CurMatchRAII(MatchASTVisitor &MV, const MatchCallback *CB,
864                  const NodeType &NT)
865         : MV(MV) {
866       MV.CurMatchState.SetCallbackAndRawNode(CB, NT);
867     }
868 
~CurMatchRAIIclang::ast_matchers::internal::__anon55449d7d0111::MatchASTVisitor::CurMatchRAII869     ~CurMatchRAII() { MV.CurMatchState.reset(); }
870 
871   private:
872     MatchASTVisitor &MV;
873   };
874 
875 public:
876   class TraceReporter : llvm::PrettyStackTraceEntry {
dumpNode(const ASTContext & Ctx,const DynTypedNode & Node,raw_ostream & OS)877     static void dumpNode(const ASTContext &Ctx, const DynTypedNode &Node,
878                          raw_ostream &OS) {
879       if (const auto *D = Node.get<Decl>()) {
880         OS << D->getDeclKindName() << "Decl ";
881         if (const auto *ND = dyn_cast<NamedDecl>(D)) {
882           ND->printQualifiedName(OS);
883           OS << " : ";
884         } else
885           OS << ": ";
886         D->getSourceRange().print(OS, Ctx.getSourceManager());
887       } else if (const auto *S = Node.get<Stmt>()) {
888         OS << S->getStmtClassName() << " : ";
889         S->getSourceRange().print(OS, Ctx.getSourceManager());
890       } else if (const auto *T = Node.get<Type>()) {
891         OS << T->getTypeClassName() << "Type : ";
892         QualType(T, 0).print(OS, Ctx.getPrintingPolicy());
893       } else if (const auto *QT = Node.get<QualType>()) {
894         OS << "QualType : ";
895         QT->print(OS, Ctx.getPrintingPolicy());
896       } else {
897         OS << Node.getNodeKind().asStringRef() << " : ";
898         Node.getSourceRange().print(OS, Ctx.getSourceManager());
899       }
900     }
901 
dumpNodeFromState(const ASTContext & Ctx,const CurMatchData & State,raw_ostream & OS)902     static void dumpNodeFromState(const ASTContext &Ctx,
903                                   const CurMatchData &State, raw_ostream &OS) {
904       if (const DynTypedNode *MatchNode = State.getNode<DynTypedNode>()) {
905         dumpNode(Ctx, *MatchNode, OS);
906       } else if (const auto *QT = State.getNode<QualType>()) {
907         dumpNode(Ctx, DynTypedNode::create(*QT), OS);
908       } else if (const auto *TL = State.getNode<TypeLoc>()) {
909         dumpNode(Ctx, DynTypedNode::create(*TL), OS);
910       } else if (const auto *NNS = State.getNode<NestedNameSpecifier>()) {
911         dumpNode(Ctx, DynTypedNode::create(*NNS), OS);
912       } else if (const auto *NNSL = State.getNode<NestedNameSpecifierLoc>()) {
913         dumpNode(Ctx, DynTypedNode::create(*NNSL), OS);
914       } else if (const auto *CtorInit = State.getNode<CXXCtorInitializer>()) {
915         dumpNode(Ctx, DynTypedNode::create(*CtorInit), OS);
916       } else if (const auto *TAL = State.getNode<TemplateArgumentLoc>()) {
917         dumpNode(Ctx, DynTypedNode::create(*TAL), OS);
918       } else if (const auto *At = State.getNode<Attr>()) {
919         dumpNode(Ctx, DynTypedNode::create(*At), OS);
920       }
921     }
922 
923   public:
TraceReporter(const MatchASTVisitor & MV)924     TraceReporter(const MatchASTVisitor &MV) : MV(MV) {}
print(raw_ostream & OS) const925     void print(raw_ostream &OS) const override {
926       const CurMatchData &State = MV.CurMatchState;
927       const MatchCallback *CB = State.getCallback();
928       if (!CB) {
929         OS << "ASTMatcher: Not currently matching\n";
930         return;
931       }
932 
933       assert(MV.ActiveASTContext &&
934              "ActiveASTContext should be set if there is a matched callback");
935 
936       ASTContext &Ctx = MV.getASTContext();
937 
938       if (const BoundNodes *Nodes = State.getBoundNodes()) {
939         OS << "ASTMatcher: Processing '" << CB->getID() << "' against:\n\t";
940         dumpNodeFromState(Ctx, State, OS);
941         const BoundNodes::IDToNodeMap &Map = Nodes->getMap();
942         if (Map.empty()) {
943           OS << "\nNo bound nodes\n";
944           return;
945         }
946         OS << "\n--- Bound Nodes Begin ---\n";
947         for (const auto &Item : Map) {
948           OS << "    " << Item.first << " - { ";
949           dumpNode(Ctx, Item.second, OS);
950           OS << " }\n";
951         }
952         OS << "--- Bound Nodes End ---\n";
953       } else {
954         OS << "ASTMatcher: Matching '" << CB->getID() << "' against:\n\t";
955         dumpNodeFromState(Ctx, State, OS);
956         OS << '\n';
957       }
958     }
959 
960   private:
961     const MatchASTVisitor &MV;
962   };
963 
964 private:
965   struct ASTNodeNotSpelledInSourceScope {
ASTNodeNotSpelledInSourceScopeclang::ast_matchers::internal::__anon55449d7d0111::MatchASTVisitor::ASTNodeNotSpelledInSourceScope966     ASTNodeNotSpelledInSourceScope(MatchASTVisitor *V, bool B)
967         : MV(V), MB(V->TraversingASTNodeNotSpelledInSource) {
968       V->TraversingASTNodeNotSpelledInSource = B;
969     }
~ASTNodeNotSpelledInSourceScopeclang::ast_matchers::internal::__anon55449d7d0111::MatchASTVisitor::ASTNodeNotSpelledInSourceScope970     ~ASTNodeNotSpelledInSourceScope() {
971       MV->TraversingASTNodeNotSpelledInSource = MB;
972     }
973 
974   private:
975     MatchASTVisitor *MV;
976     bool MB;
977   };
978 
979   struct ASTNodeNotAsIsSourceScope {
ASTNodeNotAsIsSourceScopeclang::ast_matchers::internal::__anon55449d7d0111::MatchASTVisitor::ASTNodeNotAsIsSourceScope980     ASTNodeNotAsIsSourceScope(MatchASTVisitor *V, bool B)
981         : MV(V), MB(V->TraversingASTNodeNotAsIs) {
982       V->TraversingASTNodeNotAsIs = B;
983     }
~ASTNodeNotAsIsSourceScopeclang::ast_matchers::internal::__anon55449d7d0111::MatchASTVisitor::ASTNodeNotAsIsSourceScope984     ~ASTNodeNotAsIsSourceScope() { MV->TraversingASTNodeNotAsIs = MB; }
985 
986   private:
987     MatchASTVisitor *MV;
988     bool MB;
989   };
990 
991   class TimeBucketRegion {
992   public:
993     TimeBucketRegion() = default;
~TimeBucketRegion()994     ~TimeBucketRegion() { setBucket(nullptr); }
995 
996     /// Start timing for \p NewBucket.
997     ///
998     /// If there was a bucket already set, it will finish the timing for that
999     /// other bucket.
1000     /// \p NewBucket will be timed until the next call to \c setBucket() or
1001     /// until the \c TimeBucketRegion is destroyed.
1002     /// If \p NewBucket is the same as the currently timed bucket, this call
1003     /// does nothing.
setBucket(llvm::TimeRecord * NewBucket)1004     void setBucket(llvm::TimeRecord *NewBucket) {
1005       if (Bucket != NewBucket) {
1006         auto Now = llvm::TimeRecord::getCurrentTime(true);
1007         if (Bucket)
1008           *Bucket += Now;
1009         if (NewBucket)
1010           *NewBucket -= Now;
1011         Bucket = NewBucket;
1012       }
1013     }
1014 
1015   private:
1016     llvm::TimeRecord *Bucket = nullptr;
1017   };
1018 
1019   /// Runs all the \p Matchers on \p Node.
1020   ///
1021   /// Used by \c matchDispatch() below.
1022   template <typename T, typename MC>
matchWithoutFilter(const T & Node,const MC & Matchers)1023   void matchWithoutFilter(const T &Node, const MC &Matchers) {
1024     const bool EnableCheckProfiling = Options.CheckProfiling.has_value();
1025     TimeBucketRegion Timer;
1026     for (const auto &MP : Matchers) {
1027       if (EnableCheckProfiling)
1028         Timer.setBucket(&TimeByBucket[MP.second->getID()]);
1029       BoundNodesTreeBuilder Builder;
1030       CurMatchRAII RAII(*this, MP.second, Node);
1031       if (MP.first.matches(Node, this, &Builder)) {
1032         MatchVisitor Visitor(*this, ActiveASTContext, MP.second);
1033         Builder.visitMatches(&Visitor);
1034       }
1035     }
1036   }
1037 
matchWithFilter(const DynTypedNode & DynNode)1038   void matchWithFilter(const DynTypedNode &DynNode) {
1039     auto Kind = DynNode.getNodeKind();
1040     auto it = MatcherFiltersMap.find(Kind);
1041     const auto &Filter =
1042         it != MatcherFiltersMap.end() ? it->second : getFilterForKind(Kind);
1043 
1044     if (Filter.empty())
1045       return;
1046 
1047     const bool EnableCheckProfiling = Options.CheckProfiling.has_value();
1048     TimeBucketRegion Timer;
1049     auto &Matchers = this->Matchers->DeclOrStmt;
1050     for (unsigned short I : Filter) {
1051       auto &MP = Matchers[I];
1052       if (EnableCheckProfiling)
1053         Timer.setBucket(&TimeByBucket[MP.second->getID()]);
1054       BoundNodesTreeBuilder Builder;
1055 
1056       {
1057         TraversalKindScope RAII(getASTContext(), MP.first.getTraversalKind());
1058         if (getASTContext().getParentMapContext().traverseIgnored(DynNode) !=
1059             DynNode)
1060           continue;
1061       }
1062 
1063       CurMatchRAII RAII(*this, MP.second, DynNode);
1064       if (MP.first.matches(DynNode, this, &Builder)) {
1065         MatchVisitor Visitor(*this, ActiveASTContext, MP.second);
1066         Builder.visitMatches(&Visitor);
1067       }
1068     }
1069   }
1070 
getFilterForKind(ASTNodeKind Kind)1071   const std::vector<unsigned short> &getFilterForKind(ASTNodeKind Kind) {
1072     auto &Filter = MatcherFiltersMap[Kind];
1073     auto &Matchers = this->Matchers->DeclOrStmt;
1074     assert((Matchers.size() < USHRT_MAX) && "Too many matchers.");
1075     for (unsigned I = 0, E = Matchers.size(); I != E; ++I) {
1076       if (Matchers[I].first.canMatchNodesOfKind(Kind)) {
1077         Filter.push_back(I);
1078       }
1079     }
1080     return Filter;
1081   }
1082 
1083   /// @{
1084   /// Overloads to pair the different node types to their matchers.
matchDispatch(const Decl * Node)1085   void matchDispatch(const Decl *Node) {
1086     return matchWithFilter(DynTypedNode::create(*Node));
1087   }
matchDispatch(const Stmt * Node)1088   void matchDispatch(const Stmt *Node) {
1089     return matchWithFilter(DynTypedNode::create(*Node));
1090   }
1091 
matchDispatch(const Type * Node)1092   void matchDispatch(const Type *Node) {
1093     matchWithoutFilter(QualType(Node, 0), Matchers->Type);
1094   }
matchDispatch(const TypeLoc * Node)1095   void matchDispatch(const TypeLoc *Node) {
1096     matchWithoutFilter(*Node, Matchers->TypeLoc);
1097   }
matchDispatch(const QualType * Node)1098   void matchDispatch(const QualType *Node) {
1099     matchWithoutFilter(*Node, Matchers->Type);
1100   }
matchDispatch(const NestedNameSpecifier * Node)1101   void matchDispatch(const NestedNameSpecifier *Node) {
1102     matchWithoutFilter(*Node, Matchers->NestedNameSpecifier);
1103   }
matchDispatch(const NestedNameSpecifierLoc * Node)1104   void matchDispatch(const NestedNameSpecifierLoc *Node) {
1105     matchWithoutFilter(*Node, Matchers->NestedNameSpecifierLoc);
1106   }
matchDispatch(const CXXCtorInitializer * Node)1107   void matchDispatch(const CXXCtorInitializer *Node) {
1108     matchWithoutFilter(*Node, Matchers->CtorInit);
1109   }
matchDispatch(const TemplateArgumentLoc * Node)1110   void matchDispatch(const TemplateArgumentLoc *Node) {
1111     matchWithoutFilter(*Node, Matchers->TemplateArgumentLoc);
1112   }
matchDispatch(const Attr * Node)1113   void matchDispatch(const Attr *Node) {
1114     matchWithoutFilter(*Node, Matchers->Attr);
1115   }
matchDispatch(const void *)1116   void matchDispatch(const void *) { /* Do nothing. */ }
1117   /// @}
1118 
1119   // Returns whether a direct parent of \p Node matches \p Matcher.
1120   // Unlike matchesAnyAncestorOf there's no memoization: it doesn't save much.
matchesParentOf(const DynTypedNode & Node,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder)1121   bool matchesParentOf(const DynTypedNode &Node, const DynTypedMatcher &Matcher,
1122                        BoundNodesTreeBuilder *Builder) {
1123     for (const auto &Parent : ActiveASTContext->getParents(Node)) {
1124       BoundNodesTreeBuilder BuilderCopy = *Builder;
1125       if (Matcher.matches(Parent, this, &BuilderCopy)) {
1126         *Builder = std::move(BuilderCopy);
1127         return true;
1128       }
1129     }
1130     return false;
1131   }
1132 
1133   // Returns whether an ancestor of \p Node matches \p Matcher.
1134   //
1135   // The order of matching (which can lead to different nodes being bound in
1136   // case there are multiple matches) is breadth first search.
1137   //
1138   // To allow memoization in the very common case of having deeply nested
1139   // expressions inside a template function, we first walk up the AST, memoizing
1140   // the result of the match along the way, as long as there is only a single
1141   // parent.
1142   //
1143   // Once there are multiple parents, the breadth first search order does not
1144   // allow simple memoization on the ancestors. Thus, we only memoize as long
1145   // as there is a single parent.
1146   //
1147   // We avoid a recursive implementation to prevent excessive stack use on
1148   // very deep ASTs (similarly to RecursiveASTVisitor's data recursion).
matchesAnyAncestorOf(DynTypedNode Node,ASTContext & Ctx,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder)1149   bool matchesAnyAncestorOf(DynTypedNode Node, ASTContext &Ctx,
1150                             const DynTypedMatcher &Matcher,
1151                             BoundNodesTreeBuilder *Builder) {
1152 
1153     // Memoization keys that can be updated with the result.
1154     // These are the memoizable nodes in the chain of unique parents, which
1155     // terminates when a node has multiple parents, or matches, or is the root.
1156     std::vector<MatchKey> Keys;
1157     // When returning, update the memoization cache.
1158     auto Finish = [&](bool Matched) {
1159       for (const auto &Key : Keys) {
1160         MemoizedMatchResult &CachedResult = ResultCache[Key];
1161         CachedResult.ResultOfMatch = Matched;
1162         CachedResult.Nodes = *Builder;
1163       }
1164       return Matched;
1165     };
1166 
1167     // Loop while there's a single parent and we want to attempt memoization.
1168     DynTypedNodeList Parents{ArrayRef<DynTypedNode>()}; // after loop: size != 1
1169     for (;;) {
1170       // A cache key only makes sense if memoization is possible.
1171       if (Builder->isComparable()) {
1172         Keys.emplace_back();
1173         Keys.back().MatcherID = Matcher.getID();
1174         Keys.back().Node = Node;
1175         Keys.back().BoundNodes = *Builder;
1176         Keys.back().Traversal = Ctx.getParentMapContext().getTraversalKind();
1177         Keys.back().Type = MatchType::Ancestors;
1178 
1179         // Check the cache.
1180         MemoizationMap::iterator I = ResultCache.find(Keys.back());
1181         if (I != ResultCache.end()) {
1182           Keys.pop_back(); // Don't populate the cache for the matching node!
1183           *Builder = I->second.Nodes;
1184           return Finish(I->second.ResultOfMatch);
1185         }
1186       }
1187 
1188       Parents = ActiveASTContext->getParents(Node);
1189       // Either no parents or multiple parents: leave chain+memoize mode and
1190       // enter bfs+forgetful mode.
1191       if (Parents.size() != 1)
1192         break;
1193 
1194       // Check the next parent.
1195       Node = *Parents.begin();
1196       BoundNodesTreeBuilder BuilderCopy = *Builder;
1197       if (Matcher.matches(Node, this, &BuilderCopy)) {
1198         *Builder = std::move(BuilderCopy);
1199         return Finish(true);
1200       }
1201     }
1202     // We reached the end of the chain.
1203 
1204     if (Parents.empty()) {
1205       // Nodes may have no parents if:
1206       //  a) the node is the TranslationUnitDecl
1207       //  b) we have a limited traversal scope that excludes the parent edges
1208       //  c) there is a bug in the AST, and the node is not reachable
1209       // Usually the traversal scope is the whole AST, which precludes b.
1210       // Bugs are common enough that it's worthwhile asserting when we can.
1211 #ifndef NDEBUG
1212       if (!Node.get<TranslationUnitDecl>() &&
1213           /* Traversal scope is full AST if any of the bounds are the TU */
1214           llvm::any_of(ActiveASTContext->getTraversalScope(), [](Decl *D) {
1215             return D->getKind() == Decl::TranslationUnit;
1216           })) {
1217         llvm::errs() << "Tried to match orphan node:\n";
1218         Node.dump(llvm::errs(), *ActiveASTContext);
1219         llvm_unreachable("Parent map should be complete!");
1220       }
1221 #endif
1222     } else {
1223       assert(Parents.size() > 1);
1224       // BFS starting from the parents not yet considered.
1225       // Memoization of newly visited nodes is not possible (but we still update
1226       // results for the elements in the chain we found above).
1227       std::deque<DynTypedNode> Queue(Parents.begin(), Parents.end());
1228       llvm::DenseSet<const void *> Visited;
1229       while (!Queue.empty()) {
1230         BoundNodesTreeBuilder BuilderCopy = *Builder;
1231         if (Matcher.matches(Queue.front(), this, &BuilderCopy)) {
1232           *Builder = std::move(BuilderCopy);
1233           return Finish(true);
1234         }
1235         for (const auto &Parent : ActiveASTContext->getParents(Queue.front())) {
1236           // Make sure we do not visit the same node twice.
1237           // Otherwise, we'll visit the common ancestors as often as there
1238           // are splits on the way down.
1239           if (Visited.insert(Parent.getMemoizationData()).second)
1240             Queue.push_back(Parent);
1241         }
1242         Queue.pop_front();
1243       }
1244     }
1245     return Finish(false);
1246   }
1247 
1248   // Implements a BoundNodesTree::Visitor that calls a MatchCallback with
1249   // the aggregated bound nodes for each match.
1250   class MatchVisitor : public BoundNodesTreeBuilder::Visitor {
1251     struct CurBoundScope {
CurBoundScopeclang::ast_matchers::internal::__anon55449d7d0111::MatchASTVisitor::MatchVisitor::CurBoundScope1252       CurBoundScope(MatchASTVisitor::CurMatchData &State, const BoundNodes &BN)
1253           : State(State) {
1254         State.SetBoundNodes(BN);
1255       }
1256 
~CurBoundScopeclang::ast_matchers::internal::__anon55449d7d0111::MatchASTVisitor::MatchVisitor::CurBoundScope1257       ~CurBoundScope() { State.clearBoundNodes(); }
1258 
1259     private:
1260       MatchASTVisitor::CurMatchData &State;
1261     };
1262 
1263   public:
MatchVisitor(MatchASTVisitor & MV,ASTContext * Context,MatchFinder::MatchCallback * Callback)1264     MatchVisitor(MatchASTVisitor &MV, ASTContext *Context,
1265                  MatchFinder::MatchCallback *Callback)
1266         : State(MV.CurMatchState), Context(Context), Callback(Callback) {}
1267 
visitMatch(const BoundNodes & BoundNodesView)1268     void visitMatch(const BoundNodes& BoundNodesView) override {
1269       TraversalKindScope RAII(*Context, Callback->getCheckTraversalKind());
1270       CurBoundScope RAII2(State, BoundNodesView);
1271       Callback->run(MatchFinder::MatchResult(BoundNodesView, Context));
1272     }
1273 
1274   private:
1275     MatchASTVisitor::CurMatchData &State;
1276     ASTContext* Context;
1277     MatchFinder::MatchCallback* Callback;
1278   };
1279 
1280   // Returns true if 'TypeNode' has an alias that matches the given matcher.
typeHasMatchingAlias(const Type * TypeNode,const Matcher<NamedDecl> & Matcher,BoundNodesTreeBuilder * Builder)1281   bool typeHasMatchingAlias(const Type *TypeNode,
1282                             const Matcher<NamedDecl> &Matcher,
1283                             BoundNodesTreeBuilder *Builder) {
1284     const Type *const CanonicalType =
1285       ActiveASTContext->getCanonicalType(TypeNode);
1286     auto Aliases = TypeAliases.find(CanonicalType);
1287     if (Aliases == TypeAliases.end())
1288       return false;
1289     for (const TypedefNameDecl *Alias : Aliases->second) {
1290       BoundNodesTreeBuilder Result(*Builder);
1291       if (Matcher.matches(*Alias, this, &Result)) {
1292         *Builder = std::move(Result);
1293         return true;
1294       }
1295     }
1296     return false;
1297   }
1298 
1299   bool
objcClassHasMatchingCompatibilityAlias(const ObjCInterfaceDecl * InterfaceDecl,const Matcher<NamedDecl> & Matcher,BoundNodesTreeBuilder * Builder)1300   objcClassHasMatchingCompatibilityAlias(const ObjCInterfaceDecl *InterfaceDecl,
1301                                          const Matcher<NamedDecl> &Matcher,
1302                                          BoundNodesTreeBuilder *Builder) {
1303     auto Aliases = CompatibleAliases.find(InterfaceDecl);
1304     if (Aliases == CompatibleAliases.end())
1305       return false;
1306     for (const ObjCCompatibleAliasDecl *Alias : Aliases->second) {
1307       BoundNodesTreeBuilder Result(*Builder);
1308       if (Matcher.matches(*Alias, this, &Result)) {
1309         *Builder = std::move(Result);
1310         return true;
1311       }
1312     }
1313     return false;
1314   }
1315 
1316   /// Bucket to record map.
1317   ///
1318   /// Used to get the appropriate bucket for each matcher.
1319   llvm::StringMap<llvm::TimeRecord> TimeByBucket;
1320 
1321   const MatchFinder::MatchersByType *Matchers;
1322 
1323   /// Filtered list of matcher indices for each matcher kind.
1324   ///
1325   /// \c Decl and \c Stmt toplevel matchers usually apply to a specific node
1326   /// kind (and derived kinds) so it is a waste to try every matcher on every
1327   /// node.
1328   /// We precalculate a list of matchers that pass the toplevel restrict check.
1329   llvm::DenseMap<ASTNodeKind, std::vector<unsigned short>> MatcherFiltersMap;
1330 
1331   const MatchFinder::MatchFinderOptions &Options;
1332   ASTContext *ActiveASTContext;
1333 
1334   // Maps a canonical type to its TypedefDecls.
1335   llvm::DenseMap<const Type*, std::set<const TypedefNameDecl*> > TypeAliases;
1336 
1337   // Maps an Objective-C interface to its ObjCCompatibleAliasDecls.
1338   llvm::DenseMap<const ObjCInterfaceDecl *,
1339                  llvm::SmallPtrSet<const ObjCCompatibleAliasDecl *, 2>>
1340       CompatibleAliases;
1341 
1342   // Maps (matcher, node) -> the match result for memoization.
1343   typedef std::map<MatchKey, MemoizedMatchResult> MemoizationMap;
1344   MemoizationMap ResultCache;
1345 };
1346 
1347 static CXXRecordDecl *
getAsCXXRecordDeclOrPrimaryTemplate(const Type * TypeNode)1348 getAsCXXRecordDeclOrPrimaryTemplate(const Type *TypeNode) {
1349   if (auto *RD = TypeNode->getAsCXXRecordDecl())
1350     return RD;
1351 
1352   // Find the innermost TemplateSpecializationType that isn't an alias template.
1353   auto *TemplateType = TypeNode->getAs<TemplateSpecializationType>();
1354   while (TemplateType && TemplateType->isTypeAlias())
1355     TemplateType =
1356         TemplateType->getAliasedType()->getAs<TemplateSpecializationType>();
1357 
1358   // If this is the name of a (dependent) template specialization, use the
1359   // definition of the template, even though it might be specialized later.
1360   if (TemplateType)
1361     if (auto *ClassTemplate = dyn_cast_or_null<ClassTemplateDecl>(
1362           TemplateType->getTemplateName().getAsTemplateDecl()))
1363       return ClassTemplate->getTemplatedDecl();
1364 
1365   return nullptr;
1366 }
1367 
1368 // Returns true if the given C++ class is directly or indirectly derived
1369 // from a base type with the given name.  A class is not considered to be
1370 // derived from itself.
classIsDerivedFrom(const CXXRecordDecl * Declaration,const Matcher<NamedDecl> & Base,BoundNodesTreeBuilder * Builder,bool Directly)1371 bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration,
1372                                          const Matcher<NamedDecl> &Base,
1373                                          BoundNodesTreeBuilder *Builder,
1374                                          bool Directly) {
1375   llvm::SmallPtrSet<const CXXRecordDecl *, 8> Visited;
1376   return classIsDerivedFromImpl(Declaration, Base, Builder, Directly, Visited);
1377 }
1378 
classIsDerivedFromImpl(const CXXRecordDecl * Declaration,const Matcher<NamedDecl> & Base,BoundNodesTreeBuilder * Builder,bool Directly,llvm::SmallPtrSetImpl<const CXXRecordDecl * > & Visited)1379 bool MatchASTVisitor::classIsDerivedFromImpl(
1380     const CXXRecordDecl *Declaration, const Matcher<NamedDecl> &Base,
1381     BoundNodesTreeBuilder *Builder, bool Directly,
1382     llvm::SmallPtrSetImpl<const CXXRecordDecl *> &Visited) {
1383   if (!Declaration->hasDefinition())
1384     return false;
1385   if (!Visited.insert(Declaration).second)
1386     return false;
1387   for (const auto &It : Declaration->bases()) {
1388     const Type *TypeNode = It.getType().getTypePtr();
1389 
1390     if (typeHasMatchingAlias(TypeNode, Base, Builder))
1391       return true;
1392 
1393     // FIXME: Going to the primary template here isn't really correct, but
1394     // unfortunately we accept a Decl matcher for the base class not a Type
1395     // matcher, so it's the best thing we can do with our current interface.
1396     CXXRecordDecl *ClassDecl = getAsCXXRecordDeclOrPrimaryTemplate(TypeNode);
1397     if (!ClassDecl)
1398       continue;
1399     if (ClassDecl == Declaration) {
1400       // This can happen for recursive template definitions.
1401       continue;
1402     }
1403     BoundNodesTreeBuilder Result(*Builder);
1404     if (Base.matches(*ClassDecl, this, &Result)) {
1405       *Builder = std::move(Result);
1406       return true;
1407     }
1408     if (!Directly &&
1409         classIsDerivedFromImpl(ClassDecl, Base, Builder, Directly, Visited))
1410       return true;
1411   }
1412   return false;
1413 }
1414 
1415 // Returns true if the given Objective-C class is directly or indirectly
1416 // derived from a matching base class. A class is not considered to be derived
1417 // from itself.
objcClassIsDerivedFrom(const ObjCInterfaceDecl * Declaration,const Matcher<NamedDecl> & Base,BoundNodesTreeBuilder * Builder,bool Directly)1418 bool MatchASTVisitor::objcClassIsDerivedFrom(
1419     const ObjCInterfaceDecl *Declaration, const Matcher<NamedDecl> &Base,
1420     BoundNodesTreeBuilder *Builder, bool Directly) {
1421   // Check if any of the superclasses of the class match.
1422   for (const ObjCInterfaceDecl *ClassDecl = Declaration->getSuperClass();
1423        ClassDecl != nullptr; ClassDecl = ClassDecl->getSuperClass()) {
1424     // Check if there are any matching compatibility aliases.
1425     if (objcClassHasMatchingCompatibilityAlias(ClassDecl, Base, Builder))
1426       return true;
1427 
1428     // Check if there are any matching type aliases.
1429     const Type *TypeNode = ClassDecl->getTypeForDecl();
1430     if (typeHasMatchingAlias(TypeNode, Base, Builder))
1431       return true;
1432 
1433     if (Base.matches(*ClassDecl, this, Builder))
1434       return true;
1435 
1436     // Not `return false` as a temporary workaround for PR43879.
1437     if (Directly)
1438       break;
1439   }
1440 
1441   return false;
1442 }
1443 
TraverseDecl(Decl * DeclNode)1444 bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) {
1445   if (!DeclNode) {
1446     return true;
1447   }
1448 
1449   bool ScopedTraversal =
1450       TraversingASTNodeNotSpelledInSource || DeclNode->isImplicit();
1451   bool ScopedChildren = TraversingASTChildrenNotSpelledInSource;
1452 
1453   if (const auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(DeclNode)) {
1454     auto SK = CTSD->getSpecializationKind();
1455     if (SK == TSK_ExplicitInstantiationDeclaration ||
1456         SK == TSK_ExplicitInstantiationDefinition)
1457       ScopedChildren = true;
1458   } else if (const auto *FD = dyn_cast<FunctionDecl>(DeclNode)) {
1459     if (FD->isDefaulted())
1460       ScopedChildren = true;
1461     if (FD->isTemplateInstantiation())
1462       ScopedTraversal = true;
1463   } else if (isa<BindingDecl>(DeclNode)) {
1464     ScopedChildren = true;
1465   }
1466 
1467   ASTNodeNotSpelledInSourceScope RAII1(this, ScopedTraversal);
1468   ASTChildrenNotSpelledInSourceScope RAII2(this, ScopedChildren);
1469 
1470   match(*DeclNode);
1471   return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode);
1472 }
1473 
TraverseStmt(Stmt * StmtNode,DataRecursionQueue * Queue)1474 bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue) {
1475   if (!StmtNode) {
1476     return true;
1477   }
1478   bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
1479                          TraversingASTChildrenNotSpelledInSource;
1480 
1481   ASTNodeNotSpelledInSourceScope RAII(this, ScopedTraversal);
1482   match(*StmtNode);
1483   return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode, Queue);
1484 }
1485 
TraverseType(QualType TypeNode)1486 bool MatchASTVisitor::TraverseType(QualType TypeNode) {
1487   match(TypeNode);
1488   return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode);
1489 }
1490 
TraverseTypeLoc(TypeLoc TypeLocNode)1491 bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLocNode) {
1492   // The RecursiveASTVisitor only visits types if they're not within TypeLocs.
1493   // We still want to find those types via matchers, so we match them here. Note
1494   // that the TypeLocs are structurally a shadow-hierarchy to the expressed
1495   // type, so we visit all involved parts of a compound type when matching on
1496   // each TypeLoc.
1497   match(TypeLocNode);
1498   match(TypeLocNode.getType());
1499   return RecursiveASTVisitor<MatchASTVisitor>::TraverseTypeLoc(TypeLocNode);
1500 }
1501 
TraverseNestedNameSpecifier(NestedNameSpecifier * NNS)1502 bool MatchASTVisitor::TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
1503   match(*NNS);
1504   return RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifier(NNS);
1505 }
1506 
TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS)1507 bool MatchASTVisitor::TraverseNestedNameSpecifierLoc(
1508     NestedNameSpecifierLoc NNS) {
1509   if (!NNS)
1510     return true;
1511 
1512   match(NNS);
1513 
1514   // We only match the nested name specifier here (as opposed to traversing it)
1515   // because the traversal is already done in the parallel "Loc"-hierarchy.
1516   if (NNS.hasQualifier())
1517     match(*NNS.getNestedNameSpecifier());
1518   return
1519       RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifierLoc(NNS);
1520 }
1521 
TraverseConstructorInitializer(CXXCtorInitializer * CtorInit)1522 bool MatchASTVisitor::TraverseConstructorInitializer(
1523     CXXCtorInitializer *CtorInit) {
1524   if (!CtorInit)
1525     return true;
1526 
1527   bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
1528                          TraversingASTChildrenNotSpelledInSource;
1529 
1530   if (!CtorInit->isWritten())
1531     ScopedTraversal = true;
1532 
1533   ASTNodeNotSpelledInSourceScope RAII1(this, ScopedTraversal);
1534 
1535   match(*CtorInit);
1536 
1537   return RecursiveASTVisitor<MatchASTVisitor>::TraverseConstructorInitializer(
1538       CtorInit);
1539 }
1540 
TraverseTemplateArgumentLoc(TemplateArgumentLoc Loc)1541 bool MatchASTVisitor::TraverseTemplateArgumentLoc(TemplateArgumentLoc Loc) {
1542   match(Loc);
1543   return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateArgumentLoc(Loc);
1544 }
1545 
TraverseAttr(Attr * AttrNode)1546 bool MatchASTVisitor::TraverseAttr(Attr *AttrNode) {
1547   match(*AttrNode);
1548   return RecursiveASTVisitor<MatchASTVisitor>::TraverseAttr(AttrNode);
1549 }
1550 
1551 class MatchASTConsumer : public ASTConsumer {
1552 public:
MatchASTConsumer(MatchFinder * Finder,MatchFinder::ParsingDoneTestCallback * ParsingDone)1553   MatchASTConsumer(MatchFinder *Finder,
1554                    MatchFinder::ParsingDoneTestCallback *ParsingDone)
1555       : Finder(Finder), ParsingDone(ParsingDone) {}
1556 
1557 private:
HandleTranslationUnit(ASTContext & Context)1558   void HandleTranslationUnit(ASTContext &Context) override {
1559     if (ParsingDone != nullptr) {
1560       ParsingDone->run();
1561     }
1562     Finder->matchAST(Context);
1563   }
1564 
1565   MatchFinder *Finder;
1566   MatchFinder::ParsingDoneTestCallback *ParsingDone;
1567 };
1568 
1569 } // end namespace
1570 } // end namespace internal
1571 
MatchResult(const BoundNodes & Nodes,ASTContext * Context)1572 MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes,
1573                                       ASTContext *Context)
1574   : Nodes(Nodes), Context(Context),
1575     SourceManager(&Context->getSourceManager()) {}
1576 
~MatchCallback()1577 MatchFinder::MatchCallback::~MatchCallback() {}
~ParsingDoneTestCallback()1578 MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {}
1579 
MatchFinder(MatchFinderOptions Options)1580 MatchFinder::MatchFinder(MatchFinderOptions Options)
1581     : Options(std::move(Options)), ParsingDone(nullptr) {}
1582 
~MatchFinder()1583 MatchFinder::~MatchFinder() {}
1584 
addMatcher(const DeclarationMatcher & NodeMatch,MatchCallback * Action)1585 void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch,
1586                              MatchCallback *Action) {
1587   std::optional<TraversalKind> TK;
1588   if (Action)
1589     TK = Action->getCheckTraversalKind();
1590   if (TK)
1591     Matchers.DeclOrStmt.emplace_back(traverse(*TK, NodeMatch), Action);
1592   else
1593     Matchers.DeclOrStmt.emplace_back(NodeMatch, Action);
1594   Matchers.AllCallbacks.insert(Action);
1595 }
1596 
addMatcher(const TypeMatcher & NodeMatch,MatchCallback * Action)1597 void MatchFinder::addMatcher(const TypeMatcher &NodeMatch,
1598                              MatchCallback *Action) {
1599   Matchers.Type.emplace_back(NodeMatch, Action);
1600   Matchers.AllCallbacks.insert(Action);
1601 }
1602 
addMatcher(const StatementMatcher & NodeMatch,MatchCallback * Action)1603 void MatchFinder::addMatcher(const StatementMatcher &NodeMatch,
1604                              MatchCallback *Action) {
1605   std::optional<TraversalKind> TK;
1606   if (Action)
1607     TK = Action->getCheckTraversalKind();
1608   if (TK)
1609     Matchers.DeclOrStmt.emplace_back(traverse(*TK, NodeMatch), Action);
1610   else
1611     Matchers.DeclOrStmt.emplace_back(NodeMatch, Action);
1612   Matchers.AllCallbacks.insert(Action);
1613 }
1614 
addMatcher(const NestedNameSpecifierMatcher & NodeMatch,MatchCallback * Action)1615 void MatchFinder::addMatcher(const NestedNameSpecifierMatcher &NodeMatch,
1616                              MatchCallback *Action) {
1617   Matchers.NestedNameSpecifier.emplace_back(NodeMatch, Action);
1618   Matchers.AllCallbacks.insert(Action);
1619 }
1620 
addMatcher(const NestedNameSpecifierLocMatcher & NodeMatch,MatchCallback * Action)1621 void MatchFinder::addMatcher(const NestedNameSpecifierLocMatcher &NodeMatch,
1622                              MatchCallback *Action) {
1623   Matchers.NestedNameSpecifierLoc.emplace_back(NodeMatch, Action);
1624   Matchers.AllCallbacks.insert(Action);
1625 }
1626 
addMatcher(const TypeLocMatcher & NodeMatch,MatchCallback * Action)1627 void MatchFinder::addMatcher(const TypeLocMatcher &NodeMatch,
1628                              MatchCallback *Action) {
1629   Matchers.TypeLoc.emplace_back(NodeMatch, Action);
1630   Matchers.AllCallbacks.insert(Action);
1631 }
1632 
addMatcher(const CXXCtorInitializerMatcher & NodeMatch,MatchCallback * Action)1633 void MatchFinder::addMatcher(const CXXCtorInitializerMatcher &NodeMatch,
1634                              MatchCallback *Action) {
1635   Matchers.CtorInit.emplace_back(NodeMatch, Action);
1636   Matchers.AllCallbacks.insert(Action);
1637 }
1638 
addMatcher(const TemplateArgumentLocMatcher & NodeMatch,MatchCallback * Action)1639 void MatchFinder::addMatcher(const TemplateArgumentLocMatcher &NodeMatch,
1640                              MatchCallback *Action) {
1641   Matchers.TemplateArgumentLoc.emplace_back(NodeMatch, Action);
1642   Matchers.AllCallbacks.insert(Action);
1643 }
1644 
addMatcher(const AttrMatcher & AttrMatch,MatchCallback * Action)1645 void MatchFinder::addMatcher(const AttrMatcher &AttrMatch,
1646                              MatchCallback *Action) {
1647   Matchers.Attr.emplace_back(AttrMatch, Action);
1648   Matchers.AllCallbacks.insert(Action);
1649 }
1650 
addDynamicMatcher(const internal::DynTypedMatcher & NodeMatch,MatchCallback * Action)1651 bool MatchFinder::addDynamicMatcher(const internal::DynTypedMatcher &NodeMatch,
1652                                     MatchCallback *Action) {
1653   if (NodeMatch.canConvertTo<Decl>()) {
1654     addMatcher(NodeMatch.convertTo<Decl>(), Action);
1655     return true;
1656   } else if (NodeMatch.canConvertTo<QualType>()) {
1657     addMatcher(NodeMatch.convertTo<QualType>(), Action);
1658     return true;
1659   } else if (NodeMatch.canConvertTo<Stmt>()) {
1660     addMatcher(NodeMatch.convertTo<Stmt>(), Action);
1661     return true;
1662   } else if (NodeMatch.canConvertTo<NestedNameSpecifier>()) {
1663     addMatcher(NodeMatch.convertTo<NestedNameSpecifier>(), Action);
1664     return true;
1665   } else if (NodeMatch.canConvertTo<NestedNameSpecifierLoc>()) {
1666     addMatcher(NodeMatch.convertTo<NestedNameSpecifierLoc>(), Action);
1667     return true;
1668   } else if (NodeMatch.canConvertTo<TypeLoc>()) {
1669     addMatcher(NodeMatch.convertTo<TypeLoc>(), Action);
1670     return true;
1671   } else if (NodeMatch.canConvertTo<CXXCtorInitializer>()) {
1672     addMatcher(NodeMatch.convertTo<CXXCtorInitializer>(), Action);
1673     return true;
1674   } else if (NodeMatch.canConvertTo<TemplateArgumentLoc>()) {
1675     addMatcher(NodeMatch.convertTo<TemplateArgumentLoc>(), Action);
1676     return true;
1677   } else if (NodeMatch.canConvertTo<Attr>()) {
1678     addMatcher(NodeMatch.convertTo<Attr>(), Action);
1679     return true;
1680   }
1681   return false;
1682 }
1683 
newASTConsumer()1684 std::unique_ptr<ASTConsumer> MatchFinder::newASTConsumer() {
1685   return std::make_unique<internal::MatchASTConsumer>(this, ParsingDone);
1686 }
1687 
match(const clang::DynTypedNode & Node,ASTContext & Context)1688 void MatchFinder::match(const clang::DynTypedNode &Node, ASTContext &Context) {
1689   internal::MatchASTVisitor Visitor(&Matchers, Options);
1690   Visitor.set_active_ast_context(&Context);
1691   Visitor.match(Node);
1692 }
1693 
matchAST(ASTContext & Context)1694 void MatchFinder::matchAST(ASTContext &Context) {
1695   internal::MatchASTVisitor Visitor(&Matchers, Options);
1696   internal::MatchASTVisitor::TraceReporter StackTrace(Visitor);
1697   Visitor.set_active_ast_context(&Context);
1698   Visitor.onStartOfTranslationUnit();
1699   Visitor.TraverseAST(Context);
1700   Visitor.onEndOfTranslationUnit();
1701 }
1702 
registerTestCallbackAfterParsing(MatchFinder::ParsingDoneTestCallback * NewParsingDone)1703 void MatchFinder::registerTestCallbackAfterParsing(
1704     MatchFinder::ParsingDoneTestCallback *NewParsingDone) {
1705   ParsingDone = NewParsingDone;
1706 }
1707 
getID() const1708 StringRef MatchFinder::MatchCallback::getID() const { return "<unknown>"; }
1709 
1710 std::optional<TraversalKind>
getCheckTraversalKind() const1711 MatchFinder::MatchCallback::getCheckTraversalKind() const {
1712   return std::nullopt;
1713 }
1714 
1715 } // end namespace ast_matchers
1716 } // end namespace clang
1717