1 //===--- ASTMatchFinder.cpp - Structural query framework ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  Implements an algorithm to efficiently search for matches on AST nodes.
10 //  Uses memoization to support recursive matches like HasDescendant.
11 //
12 //  The general idea is to visit all AST nodes with a RecursiveASTVisitor,
13 //  calling the Matches(...) method of each matcher we are running on each
14 //  AST node. The matcher can recurse via the ASTMatchFinder interface.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "clang/ASTMatchers/ASTMatchFinder.h"
19 #include "clang/AST/ASTConsumer.h"
20 #include "clang/AST/ASTContext.h"
21 #include "clang/AST/RecursiveASTVisitor.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/Support/Timer.h"
25 #include <deque>
26 #include <memory>
27 #include <set>
28 
29 namespace clang {
30 namespace ast_matchers {
31 namespace internal {
32 namespace {
33 
34 typedef MatchFinder::MatchCallback MatchCallback;
35 
36 // The maximum number of memoization entries to store.
37 // 10k has been experimentally found to give a good trade-off
38 // of performance vs. memory consumption by running matcher
39 // that match on every statement over a very large codebase.
40 //
41 // FIXME: Do some performance optimization in general and
42 // revisit this number; also, put up micro-benchmarks that we can
43 // optimize this on.
44 static const unsigned MaxMemoizationEntries = 10000;
45 
46 enum class MatchType {
47   Ancestors,
48 
49   Descendants,
50   Child,
51 };
52 
53 // We use memoization to avoid running the same matcher on the same
54 // AST node twice.  This struct is the key for looking up match
55 // result.  It consists of an ID of the MatcherInterface (for
56 // identifying the matcher), a pointer to the AST node and the
57 // bound nodes before the matcher was executed.
58 //
59 // We currently only memoize on nodes whose pointers identify the
60 // nodes (\c Stmt and \c Decl, but not \c QualType or \c TypeLoc).
61 // For \c QualType and \c TypeLoc it is possible to implement
62 // generation of keys for each type.
63 // FIXME: Benchmark whether memoization of non-pointer typed nodes
64 // provides enough benefit for the additional amount of code.
65 struct MatchKey {
66   DynTypedMatcher::MatcherIDType MatcherID;
67   DynTypedNode Node;
68   BoundNodesTreeBuilder BoundNodes;
69   TraversalKind Traversal = TK_AsIs;
70   MatchType Type;
71 
operator <clang::ast_matchers::internal::__anon5573cc030111::MatchKey72   bool operator<(const MatchKey &Other) const {
73     return std::tie(Traversal, Type, MatcherID, Node, BoundNodes) <
74            std::tie(Other.Traversal, Other.Type, Other.MatcherID, Other.Node,
75                     Other.BoundNodes);
76   }
77 };
78 
79 // Used to store the result of a match and possibly bound nodes.
80 struct MemoizedMatchResult {
81   bool ResultOfMatch;
82   BoundNodesTreeBuilder Nodes;
83 };
84 
85 // A RecursiveASTVisitor that traverses all children or all descendants of
86 // a node.
87 class MatchChildASTVisitor
88     : public RecursiveASTVisitor<MatchChildASTVisitor> {
89 public:
90   typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase;
91 
92   // Creates an AST visitor that matches 'matcher' on all children or
93   // descendants of a traversed node. max_depth is the maximum depth
94   // to traverse: use 1 for matching the children and INT_MAX for
95   // matching the descendants.
MatchChildASTVisitor(const DynTypedMatcher * Matcher,ASTMatchFinder * Finder,BoundNodesTreeBuilder * Builder,int MaxDepth,TraversalKind Traversal,ASTMatchFinder::BindKind Bind)96   MatchChildASTVisitor(const DynTypedMatcher *Matcher, ASTMatchFinder *Finder,
97                        BoundNodesTreeBuilder *Builder, int MaxDepth,
98                        TraversalKind Traversal, ASTMatchFinder::BindKind Bind)
99       : Matcher(Matcher), Finder(Finder), Builder(Builder), CurrentDepth(0),
100         MaxDepth(MaxDepth), Traversal(Traversal), Bind(Bind), Matches(false) {}
101 
102   // Returns true if a match is found in the subtree rooted at the
103   // given AST node. This is done via a set of mutually recursive
104   // functions. Here's how the recursion is done (the  *wildcard can
105   // actually be Decl, Stmt, or Type):
106   //
107   //   - Traverse(node) calls BaseTraverse(node) when it needs
108   //     to visit the descendants of node.
109   //   - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node))
110   //     Traverse*(c) for each child c of 'node'.
111   //   - Traverse*(c) in turn calls Traverse(c), completing the
112   //     recursion.
findMatch(const DynTypedNode & DynNode)113   bool findMatch(const DynTypedNode &DynNode) {
114     reset();
115     if (const Decl *D = DynNode.get<Decl>())
116       traverse(*D);
117     else if (const Stmt *S = DynNode.get<Stmt>())
118       traverse(*S);
119     else if (const NestedNameSpecifier *NNS =
120              DynNode.get<NestedNameSpecifier>())
121       traverse(*NNS);
122     else if (const NestedNameSpecifierLoc *NNSLoc =
123              DynNode.get<NestedNameSpecifierLoc>())
124       traverse(*NNSLoc);
125     else if (const QualType *Q = DynNode.get<QualType>())
126       traverse(*Q);
127     else if (const TypeLoc *T = DynNode.get<TypeLoc>())
128       traverse(*T);
129     else if (const auto *C = DynNode.get<CXXCtorInitializer>())
130       traverse(*C);
131     // FIXME: Add other base types after adding tests.
132 
133     // It's OK to always overwrite the bound nodes, as if there was
134     // no match in this recursive branch, the result set is empty
135     // anyway.
136     *Builder = ResultBindings;
137 
138     return Matches;
139   }
140 
141   // The following are overriding methods from the base visitor class.
142   // They are public only to allow CRTP to work. They are *not *part
143   // of the public API of this class.
TraverseDecl(Decl * DeclNode)144   bool TraverseDecl(Decl *DeclNode) {
145     ScopedIncrement ScopedDepth(&CurrentDepth);
146     return (DeclNode == nullptr) || traverse(*DeclNode);
147   }
148 
getStmtToTraverse(Stmt * StmtNode)149   Stmt *getStmtToTraverse(Stmt *StmtNode) {
150     Stmt *StmtToTraverse = StmtNode;
151     if (auto *ExprNode = dyn_cast_or_null<Expr>(StmtNode)) {
152       auto *LambdaNode = dyn_cast_or_null<LambdaExpr>(StmtNode);
153       if (LambdaNode &&
154           Finder->getASTContext().getParentMapContext().getTraversalKind() ==
155               TK_IgnoreUnlessSpelledInSource)
156         StmtToTraverse = LambdaNode;
157       else
158         StmtToTraverse =
159             Finder->getASTContext().getParentMapContext().traverseIgnored(
160                 ExprNode);
161     }
162     if (Traversal == TraversalKind::TK_IgnoreImplicitCastsAndParentheses) {
163       if (Expr *ExprNode = dyn_cast_or_null<Expr>(StmtNode))
164         StmtToTraverse = ExprNode->IgnoreParenImpCasts();
165     }
166     return StmtToTraverse;
167   }
168 
TraverseStmt(Stmt * StmtNode,DataRecursionQueue * Queue=nullptr)169   bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr) {
170     // If we need to keep track of the depth, we can't perform data recursion.
171     if (CurrentDepth == 0 || (CurrentDepth <= MaxDepth && MaxDepth < INT_MAX))
172       Queue = nullptr;
173 
174     ScopedIncrement ScopedDepth(&CurrentDepth);
175     Stmt *StmtToTraverse = getStmtToTraverse(StmtNode);
176     if (!StmtToTraverse)
177       return true;
178     if (!match(*StmtToTraverse))
179       return false;
180     return VisitorBase::TraverseStmt(StmtToTraverse, Queue);
181   }
182   // We assume that the QualType and the contained type are on the same
183   // hierarchy level. Thus, we try to match either of them.
TraverseType(QualType TypeNode)184   bool TraverseType(QualType TypeNode) {
185     if (TypeNode.isNull())
186       return true;
187     ScopedIncrement ScopedDepth(&CurrentDepth);
188     // Match the Type.
189     if (!match(*TypeNode))
190       return false;
191     // The QualType is matched inside traverse.
192     return traverse(TypeNode);
193   }
194   // We assume that the TypeLoc, contained QualType and contained Type all are
195   // on the same hierarchy level. Thus, we try to match all of them.
TraverseTypeLoc(TypeLoc TypeLocNode)196   bool TraverseTypeLoc(TypeLoc TypeLocNode) {
197     if (TypeLocNode.isNull())
198       return true;
199     ScopedIncrement ScopedDepth(&CurrentDepth);
200     // Match the Type.
201     if (!match(*TypeLocNode.getType()))
202       return false;
203     // Match the QualType.
204     if (!match(TypeLocNode.getType()))
205       return false;
206     // The TypeLoc is matched inside traverse.
207     return traverse(TypeLocNode);
208   }
TraverseNestedNameSpecifier(NestedNameSpecifier * NNS)209   bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
210     ScopedIncrement ScopedDepth(&CurrentDepth);
211     return (NNS == nullptr) || traverse(*NNS);
212   }
TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS)213   bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) {
214     if (!NNS)
215       return true;
216     ScopedIncrement ScopedDepth(&CurrentDepth);
217     if (!match(*NNS.getNestedNameSpecifier()))
218       return false;
219     return traverse(NNS);
220   }
TraverseConstructorInitializer(CXXCtorInitializer * CtorInit)221   bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit) {
222     if (!CtorInit)
223       return true;
224     ScopedIncrement ScopedDepth(&CurrentDepth);
225     return traverse(*CtorInit);
226   }
TraverseLambdaExpr(LambdaExpr * Node)227   bool TraverseLambdaExpr(LambdaExpr *Node) {
228     if (Finder->getASTContext().getParentMapContext().getTraversalKind() !=
229         TK_IgnoreUnlessSpelledInSource)
230       return VisitorBase::TraverseLambdaExpr(Node);
231     if (!Node)
232       return true;
233     ScopedIncrement ScopedDepth(&CurrentDepth);
234 
235     for (unsigned I = 0, N = Node->capture_size(); I != N; ++I) {
236       const auto *C = Node->capture_begin() + I;
237       if (!C->isExplicit())
238         continue;
239       if (Node->isInitCapture(C) && !match(*C->getCapturedVar()))
240         return false;
241       if (!match(*Node->capture_init_begin()[I]))
242         return false;
243     }
244 
245     if (const auto *TPL = Node->getTemplateParameterList()) {
246       for (const auto *TP : *TPL) {
247         if (!match(*TP))
248           return false;
249       }
250     }
251 
252     for (const auto *P : Node->getCallOperator()->parameters()) {
253       if (!match(*P))
254         return false;
255     }
256 
257     if (!match(*Node->getBody()))
258       return false;
259 
260     return true;
261   }
262 
shouldVisitTemplateInstantiations() const263   bool shouldVisitTemplateInstantiations() const { return true; }
shouldVisitImplicitCode() const264   bool shouldVisitImplicitCode() const { return true; }
265 
266 private:
267   // Used for updating the depth during traversal.
268   struct ScopedIncrement {
ScopedIncrementclang::ast_matchers::internal::__anon5573cc030111::MatchChildASTVisitor::ScopedIncrement269     explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); }
~ScopedIncrementclang::ast_matchers::internal::__anon5573cc030111::MatchChildASTVisitor::ScopedIncrement270     ~ScopedIncrement() { --(*Depth); }
271 
272    private:
273     int *Depth;
274   };
275 
276   // Resets the state of this object.
reset()277   void reset() {
278     Matches = false;
279     CurrentDepth = 0;
280   }
281 
282   // Forwards the call to the corresponding Traverse*() method in the
283   // base visitor class.
baseTraverse(const Decl & DeclNode)284   bool baseTraverse(const Decl &DeclNode) {
285     return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode));
286   }
baseTraverse(const Stmt & StmtNode)287   bool baseTraverse(const Stmt &StmtNode) {
288     return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode));
289   }
baseTraverse(QualType TypeNode)290   bool baseTraverse(QualType TypeNode) {
291     return VisitorBase::TraverseType(TypeNode);
292   }
baseTraverse(TypeLoc TypeLocNode)293   bool baseTraverse(TypeLoc TypeLocNode) {
294     return VisitorBase::TraverseTypeLoc(TypeLocNode);
295   }
baseTraverse(const NestedNameSpecifier & NNS)296   bool baseTraverse(const NestedNameSpecifier &NNS) {
297     return VisitorBase::TraverseNestedNameSpecifier(
298         const_cast<NestedNameSpecifier*>(&NNS));
299   }
baseTraverse(NestedNameSpecifierLoc NNS)300   bool baseTraverse(NestedNameSpecifierLoc NNS) {
301     return VisitorBase::TraverseNestedNameSpecifierLoc(NNS);
302   }
baseTraverse(const CXXCtorInitializer & CtorInit)303   bool baseTraverse(const CXXCtorInitializer &CtorInit) {
304     return VisitorBase::TraverseConstructorInitializer(
305         const_cast<CXXCtorInitializer *>(&CtorInit));
306   }
307 
308   // Sets 'Matched' to true if 'Matcher' matches 'Node' and:
309   //   0 < CurrentDepth <= MaxDepth.
310   //
311   // Returns 'true' if traversal should continue after this function
312   // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
313   template <typename T>
match(const T & Node)314   bool match(const T &Node) {
315     if (CurrentDepth == 0 || CurrentDepth > MaxDepth) {
316       return true;
317     }
318     if (Bind != ASTMatchFinder::BK_All) {
319       BoundNodesTreeBuilder RecursiveBuilder(*Builder);
320       if (Matcher->matches(DynTypedNode::create(Node), Finder,
321                            &RecursiveBuilder)) {
322         Matches = true;
323         ResultBindings.addMatch(RecursiveBuilder);
324         return false; // Abort as soon as a match is found.
325       }
326     } else {
327       BoundNodesTreeBuilder RecursiveBuilder(*Builder);
328       if (Matcher->matches(DynTypedNode::create(Node), Finder,
329                            &RecursiveBuilder)) {
330         // After the first match the matcher succeeds.
331         Matches = true;
332         ResultBindings.addMatch(RecursiveBuilder);
333       }
334     }
335     return true;
336   }
337 
338   // Traverses the subtree rooted at 'Node'; returns true if the
339   // traversal should continue after this function returns.
340   template <typename T>
traverse(const T & Node)341   bool traverse(const T &Node) {
342     static_assert(IsBaseType<T>::value,
343                   "traverse can only be instantiated with base type");
344     if (!match(Node))
345       return false;
346     return baseTraverse(Node);
347   }
348 
349   const DynTypedMatcher *const Matcher;
350   ASTMatchFinder *const Finder;
351   BoundNodesTreeBuilder *const Builder;
352   BoundNodesTreeBuilder ResultBindings;
353   int CurrentDepth;
354   const int MaxDepth;
355   const TraversalKind Traversal;
356   const ASTMatchFinder::BindKind Bind;
357   bool Matches;
358 };
359 
360 // Controls the outermost traversal of the AST and allows to match multiple
361 // matchers.
362 class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
363                         public ASTMatchFinder {
364 public:
MatchASTVisitor(const MatchFinder::MatchersByType * Matchers,const MatchFinder::MatchFinderOptions & Options)365   MatchASTVisitor(const MatchFinder::MatchersByType *Matchers,
366                   const MatchFinder::MatchFinderOptions &Options)
367       : Matchers(Matchers), Options(Options), ActiveASTContext(nullptr) {}
368 
~MatchASTVisitor()369   ~MatchASTVisitor() override {
370     if (Options.CheckProfiling) {
371       Options.CheckProfiling->Records = std::move(TimeByBucket);
372     }
373   }
374 
onStartOfTranslationUnit()375   void onStartOfTranslationUnit() {
376     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
377     TimeBucketRegion Timer;
378     for (MatchCallback *MC : Matchers->AllCallbacks) {
379       if (EnableCheckProfiling)
380         Timer.setBucket(&TimeByBucket[MC->getID()]);
381       MC->onStartOfTranslationUnit();
382     }
383   }
384 
onEndOfTranslationUnit()385   void onEndOfTranslationUnit() {
386     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
387     TimeBucketRegion Timer;
388     for (MatchCallback *MC : Matchers->AllCallbacks) {
389       if (EnableCheckProfiling)
390         Timer.setBucket(&TimeByBucket[MC->getID()]);
391       MC->onEndOfTranslationUnit();
392     }
393   }
394 
set_active_ast_context(ASTContext * NewActiveASTContext)395   void set_active_ast_context(ASTContext *NewActiveASTContext) {
396     ActiveASTContext = NewActiveASTContext;
397   }
398 
399   // The following Visit*() and Traverse*() functions "override"
400   // methods in RecursiveASTVisitor.
401 
VisitTypedefNameDecl(TypedefNameDecl * DeclNode)402   bool VisitTypedefNameDecl(TypedefNameDecl *DeclNode) {
403     // When we see 'typedef A B', we add name 'B' to the set of names
404     // A's canonical type maps to.  This is necessary for implementing
405     // isDerivedFrom(x) properly, where x can be the name of the base
406     // class or any of its aliases.
407     //
408     // In general, the is-alias-of (as defined by typedefs) relation
409     // is tree-shaped, as you can typedef a type more than once.  For
410     // example,
411     //
412     //   typedef A B;
413     //   typedef A C;
414     //   typedef C D;
415     //   typedef C E;
416     //
417     // gives you
418     //
419     //   A
420     //   |- B
421     //   `- C
422     //      |- D
423     //      `- E
424     //
425     // It is wrong to assume that the relation is a chain.  A correct
426     // implementation of isDerivedFrom() needs to recognize that B and
427     // E are aliases, even though neither is a typedef of the other.
428     // Therefore, we cannot simply walk through one typedef chain to
429     // find out whether the type name matches.
430     const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr();
431     const Type *CanonicalType =  // root of the typedef tree
432         ActiveASTContext->getCanonicalType(TypeNode);
433     TypeAliases[CanonicalType].insert(DeclNode);
434     return true;
435   }
436 
VisitObjCCompatibleAliasDecl(ObjCCompatibleAliasDecl * CAD)437   bool VisitObjCCompatibleAliasDecl(ObjCCompatibleAliasDecl *CAD) {
438     const ObjCInterfaceDecl *InterfaceDecl = CAD->getClassInterface();
439     CompatibleAliases[InterfaceDecl].insert(CAD);
440     return true;
441   }
442 
443   bool TraverseDecl(Decl *DeclNode);
444   bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr);
445   bool TraverseType(QualType TypeNode);
446   bool TraverseTypeLoc(TypeLoc TypeNode);
447   bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS);
448   bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS);
449   bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit);
450 
451   // Matches children or descendants of 'Node' with 'BaseMatcher'.
memoizedMatchesRecursively(const DynTypedNode & Node,ASTContext & Ctx,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,int MaxDepth,TraversalKind Traversal,BindKind Bind)452   bool memoizedMatchesRecursively(const DynTypedNode &Node, ASTContext &Ctx,
453                                   const DynTypedMatcher &Matcher,
454                                   BoundNodesTreeBuilder *Builder, int MaxDepth,
455                                   TraversalKind Traversal, BindKind Bind) {
456     // For AST-nodes that don't have an identity, we can't memoize.
457     if (!Node.getMemoizationData() || !Builder->isComparable())
458       return matchesRecursively(Node, Matcher, Builder, MaxDepth, Traversal,
459                                 Bind);
460 
461     MatchKey Key;
462     Key.MatcherID = Matcher.getID();
463     Key.Node = Node;
464     // Note that we key on the bindings *before* the match.
465     Key.BoundNodes = *Builder;
466     Key.Traversal = Ctx.getParentMapContext().getTraversalKind();
467     // Memoize result even doing a single-level match, it might be expensive.
468     Key.Type = MaxDepth == 1 ? MatchType::Child : MatchType::Descendants;
469     MemoizationMap::iterator I = ResultCache.find(Key);
470     if (I != ResultCache.end()) {
471       *Builder = I->second.Nodes;
472       return I->second.ResultOfMatch;
473     }
474 
475     MemoizedMatchResult Result;
476     Result.Nodes = *Builder;
477     Result.ResultOfMatch = matchesRecursively(Node, Matcher, &Result.Nodes,
478                                               MaxDepth, Traversal, Bind);
479 
480     MemoizedMatchResult &CachedResult = ResultCache[Key];
481     CachedResult = std::move(Result);
482 
483     *Builder = CachedResult.Nodes;
484     return CachedResult.ResultOfMatch;
485   }
486 
487   // Matches children or descendants of 'Node' with 'BaseMatcher'.
matchesRecursively(const DynTypedNode & Node,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,int MaxDepth,TraversalKind Traversal,BindKind Bind)488   bool matchesRecursively(const DynTypedNode &Node,
489                           const DynTypedMatcher &Matcher,
490                           BoundNodesTreeBuilder *Builder, int MaxDepth,
491                           TraversalKind Traversal, BindKind Bind) {
492     MatchChildASTVisitor Visitor(
493       &Matcher, this, Builder, MaxDepth, Traversal, Bind);
494     return Visitor.findMatch(Node);
495   }
496 
497   bool classIsDerivedFrom(const CXXRecordDecl *Declaration,
498                           const Matcher<NamedDecl> &Base,
499                           BoundNodesTreeBuilder *Builder,
500                           bool Directly) override;
501 
502   bool objcClassIsDerivedFrom(const ObjCInterfaceDecl *Declaration,
503                               const Matcher<NamedDecl> &Base,
504                               BoundNodesTreeBuilder *Builder,
505                               bool Directly) override;
506 
507   // Implements ASTMatchFinder::matchesChildOf.
matchesChildOf(const DynTypedNode & Node,ASTContext & Ctx,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,TraversalKind Traversal,BindKind Bind)508   bool matchesChildOf(const DynTypedNode &Node, ASTContext &Ctx,
509                       const DynTypedMatcher &Matcher,
510                       BoundNodesTreeBuilder *Builder, TraversalKind Traversal,
511                       BindKind Bind) override {
512     if (ResultCache.size() > MaxMemoizationEntries)
513       ResultCache.clear();
514     return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, 1, Traversal,
515                                       Bind);
516   }
517   // Implements ASTMatchFinder::matchesDescendantOf.
matchesDescendantOf(const DynTypedNode & Node,ASTContext & Ctx,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,BindKind Bind)518   bool matchesDescendantOf(const DynTypedNode &Node, ASTContext &Ctx,
519                            const DynTypedMatcher &Matcher,
520                            BoundNodesTreeBuilder *Builder,
521                            BindKind Bind) override {
522     if (ResultCache.size() > MaxMemoizationEntries)
523       ResultCache.clear();
524     return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, INT_MAX,
525                                       TraversalKind::TK_AsIs, Bind);
526   }
527   // Implements ASTMatchFinder::matchesAncestorOf.
matchesAncestorOf(const DynTypedNode & Node,ASTContext & Ctx,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,AncestorMatchMode MatchMode)528   bool matchesAncestorOf(const DynTypedNode &Node, ASTContext &Ctx,
529                          const DynTypedMatcher &Matcher,
530                          BoundNodesTreeBuilder *Builder,
531                          AncestorMatchMode MatchMode) override {
532     // Reset the cache outside of the recursive call to make sure we
533     // don't invalidate any iterators.
534     if (ResultCache.size() > MaxMemoizationEntries)
535       ResultCache.clear();
536     return memoizedMatchesAncestorOfRecursively(Node, Ctx, Matcher, Builder,
537                                                 MatchMode);
538   }
539 
540   // Matches all registered matchers on the given node and calls the
541   // result callback for every node that matches.
match(const DynTypedNode & Node)542   void match(const DynTypedNode &Node) {
543     // FIXME: Improve this with a switch or a visitor pattern.
544     if (auto *N = Node.get<Decl>()) {
545       match(*N);
546     } else if (auto *N = Node.get<Stmt>()) {
547       match(*N);
548     } else if (auto *N = Node.get<Type>()) {
549       match(*N);
550     } else if (auto *N = Node.get<QualType>()) {
551       match(*N);
552     } else if (auto *N = Node.get<NestedNameSpecifier>()) {
553       match(*N);
554     } else if (auto *N = Node.get<NestedNameSpecifierLoc>()) {
555       match(*N);
556     } else if (auto *N = Node.get<TypeLoc>()) {
557       match(*N);
558     } else if (auto *N = Node.get<CXXCtorInitializer>()) {
559       match(*N);
560     }
561   }
562 
match(const T & Node)563   template <typename T> void match(const T &Node) {
564     matchDispatch(&Node);
565   }
566 
567   // Implements ASTMatchFinder::getASTContext.
getASTContext() const568   ASTContext &getASTContext() const override { return *ActiveASTContext; }
569 
shouldVisitTemplateInstantiations() const570   bool shouldVisitTemplateInstantiations() const { return true; }
shouldVisitImplicitCode() const571   bool shouldVisitImplicitCode() const { return true; }
572 
573 private:
574   class TimeBucketRegion {
575   public:
TimeBucketRegion()576     TimeBucketRegion() : Bucket(nullptr) {}
~TimeBucketRegion()577     ~TimeBucketRegion() { setBucket(nullptr); }
578 
579     /// Start timing for \p NewBucket.
580     ///
581     /// If there was a bucket already set, it will finish the timing for that
582     /// other bucket.
583     /// \p NewBucket will be timed until the next call to \c setBucket() or
584     /// until the \c TimeBucketRegion is destroyed.
585     /// If \p NewBucket is the same as the currently timed bucket, this call
586     /// does nothing.
setBucket(llvm::TimeRecord * NewBucket)587     void setBucket(llvm::TimeRecord *NewBucket) {
588       if (Bucket != NewBucket) {
589         auto Now = llvm::TimeRecord::getCurrentTime(true);
590         if (Bucket)
591           *Bucket += Now;
592         if (NewBucket)
593           *NewBucket -= Now;
594         Bucket = NewBucket;
595       }
596     }
597 
598   private:
599     llvm::TimeRecord *Bucket;
600   };
601 
602   /// Runs all the \p Matchers on \p Node.
603   ///
604   /// Used by \c matchDispatch() below.
605   template <typename T, typename MC>
matchWithoutFilter(const T & Node,const MC & Matchers)606   void matchWithoutFilter(const T &Node, const MC &Matchers) {
607     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
608     TimeBucketRegion Timer;
609     for (const auto &MP : Matchers) {
610       if (EnableCheckProfiling)
611         Timer.setBucket(&TimeByBucket[MP.second->getID()]);
612       BoundNodesTreeBuilder Builder;
613       if (MP.first.matches(Node, this, &Builder)) {
614         MatchVisitor Visitor(ActiveASTContext, MP.second);
615         Builder.visitMatches(&Visitor);
616       }
617     }
618   }
619 
matchWithFilter(const DynTypedNode & DynNode)620   void matchWithFilter(const DynTypedNode &DynNode) {
621     auto Kind = DynNode.getNodeKind();
622     auto it = MatcherFiltersMap.find(Kind);
623     const auto &Filter =
624         it != MatcherFiltersMap.end() ? it->second : getFilterForKind(Kind);
625 
626     if (Filter.empty())
627       return;
628 
629     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
630     TimeBucketRegion Timer;
631     auto &Matchers = this->Matchers->DeclOrStmt;
632     for (unsigned short I : Filter) {
633       auto &MP = Matchers[I];
634       if (EnableCheckProfiling)
635         Timer.setBucket(&TimeByBucket[MP.second->getID()]);
636       BoundNodesTreeBuilder Builder;
637       if (MP.first.matches(DynNode, this, &Builder)) {
638         MatchVisitor Visitor(ActiveASTContext, MP.second);
639         Builder.visitMatches(&Visitor);
640       }
641     }
642   }
643 
getFilterForKind(ASTNodeKind Kind)644   const std::vector<unsigned short> &getFilterForKind(ASTNodeKind Kind) {
645     auto &Filter = MatcherFiltersMap[Kind];
646     auto &Matchers = this->Matchers->DeclOrStmt;
647     assert((Matchers.size() < USHRT_MAX) && "Too many matchers.");
648     for (unsigned I = 0, E = Matchers.size(); I != E; ++I) {
649       if (Matchers[I].first.canMatchNodesOfKind(Kind)) {
650         Filter.push_back(I);
651       }
652     }
653     return Filter;
654   }
655 
656   /// @{
657   /// Overloads to pair the different node types to their matchers.
matchDispatch(const Decl * Node)658   void matchDispatch(const Decl *Node) {
659     return matchWithFilter(DynTypedNode::create(*Node));
660   }
matchDispatch(const Stmt * Node)661   void matchDispatch(const Stmt *Node) {
662     return matchWithFilter(DynTypedNode::create(*Node));
663   }
664 
matchDispatch(const Type * Node)665   void matchDispatch(const Type *Node) {
666     matchWithoutFilter(QualType(Node, 0), Matchers->Type);
667   }
matchDispatch(const TypeLoc * Node)668   void matchDispatch(const TypeLoc *Node) {
669     matchWithoutFilter(*Node, Matchers->TypeLoc);
670   }
matchDispatch(const QualType * Node)671   void matchDispatch(const QualType *Node) {
672     matchWithoutFilter(*Node, Matchers->Type);
673   }
matchDispatch(const NestedNameSpecifier * Node)674   void matchDispatch(const NestedNameSpecifier *Node) {
675     matchWithoutFilter(*Node, Matchers->NestedNameSpecifier);
676   }
matchDispatch(const NestedNameSpecifierLoc * Node)677   void matchDispatch(const NestedNameSpecifierLoc *Node) {
678     matchWithoutFilter(*Node, Matchers->NestedNameSpecifierLoc);
679   }
matchDispatch(const CXXCtorInitializer * Node)680   void matchDispatch(const CXXCtorInitializer *Node) {
681     matchWithoutFilter(*Node, Matchers->CtorInit);
682   }
matchDispatch(const void *)683   void matchDispatch(const void *) { /* Do nothing. */ }
684   /// @}
685 
686   // Returns whether an ancestor of \p Node matches \p Matcher.
687   //
688   // The order of matching ((which can lead to different nodes being bound in
689   // case there are multiple matches) is breadth first search.
690   //
691   // To allow memoization in the very common case of having deeply nested
692   // expressions inside a template function, we first walk up the AST, memoizing
693   // the result of the match along the way, as long as there is only a single
694   // parent.
695   //
696   // Once there are multiple parents, the breadth first search order does not
697   // allow simple memoization on the ancestors. Thus, we only memoize as long
698   // as there is a single parent.
memoizedMatchesAncestorOfRecursively(const DynTypedNode & Node,ASTContext & Ctx,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,AncestorMatchMode MatchMode)699   bool memoizedMatchesAncestorOfRecursively(const DynTypedNode &Node,
700                                             ASTContext &Ctx,
701                                             const DynTypedMatcher &Matcher,
702                                             BoundNodesTreeBuilder *Builder,
703                                             AncestorMatchMode MatchMode) {
704     // For AST-nodes that don't have an identity, we can't memoize.
705     // When doing a single-level match, we don't need to memoize because
706     // ParentMap (in ASTContext) already memoizes the result.
707     if (!Builder->isComparable() ||
708         MatchMode == AncestorMatchMode::AMM_ParentOnly)
709       return matchesAncestorOfRecursively(Node, Ctx, Matcher, Builder,
710                                           MatchMode);
711 
712     MatchKey Key;
713     Key.MatcherID = Matcher.getID();
714     Key.Node = Node;
715     Key.BoundNodes = *Builder;
716     Key.Traversal = Ctx.getParentMapContext().getTraversalKind();
717     Key.Type = MatchType::Ancestors;
718 
719     // Note that we cannot use insert and reuse the iterator, as recursive
720     // calls to match might invalidate the result cache iterators.
721     MemoizationMap::iterator I = ResultCache.find(Key);
722     if (I != ResultCache.end()) {
723       *Builder = I->second.Nodes;
724       return I->second.ResultOfMatch;
725     }
726 
727     MemoizedMatchResult Result;
728     Result.Nodes = *Builder;
729     Result.ResultOfMatch = matchesAncestorOfRecursively(
730         Node, Ctx, Matcher, &Result.Nodes, MatchMode);
731 
732     MemoizedMatchResult &CachedResult = ResultCache[Key];
733     CachedResult = std::move(Result);
734 
735     *Builder = CachedResult.Nodes;
736     return CachedResult.ResultOfMatch;
737   }
738 
matchesAncestorOfRecursively(const DynTypedNode & Node,ASTContext & Ctx,const DynTypedMatcher & Matcher,BoundNodesTreeBuilder * Builder,AncestorMatchMode MatchMode)739   bool matchesAncestorOfRecursively(const DynTypedNode &Node, ASTContext &Ctx,
740                                     const DynTypedMatcher &Matcher,
741                                     BoundNodesTreeBuilder *Builder,
742                                     AncestorMatchMode MatchMode) {
743     const auto &Parents = ActiveASTContext->getParents(Node);
744     if (Parents.empty()) {
745       // Nodes may have no parents if:
746       //  a) the node is the TranslationUnitDecl
747       //  b) we have a limited traversal scope that excludes the parent edges
748       //  c) there is a bug in the AST, and the node is not reachable
749       // Usually the traversal scope is the whole AST, which precludes b.
750       // Bugs are common enough that it's worthwhile asserting when we can.
751 #ifndef NDEBUG
752       if (!Node.get<TranslationUnitDecl>() &&
753           /* Traversal scope is full AST if any of the bounds are the TU */
754           llvm::any_of(ActiveASTContext->getTraversalScope(), [](Decl *D) {
755             return D->getKind() == Decl::TranslationUnit;
756           })) {
757         llvm::errs() << "Tried to match orphan node:\n";
758         Node.dump(llvm::errs(), *ActiveASTContext);
759         llvm_unreachable("Parent map should be complete!");
760       }
761 #endif
762       return false;
763     }
764     if (Parents.size() == 1) {
765       // Only one parent - do recursive memoization.
766       const DynTypedNode Parent = Parents[0];
767       BoundNodesTreeBuilder BuilderCopy = *Builder;
768       if (Matcher.matches(Parent, this, &BuilderCopy)) {
769         *Builder = std::move(BuilderCopy);
770         return true;
771       }
772       if (MatchMode != ASTMatchFinder::AMM_ParentOnly) {
773         return memoizedMatchesAncestorOfRecursively(Parent, Ctx, Matcher,
774                                                     Builder, MatchMode);
775         // Once we get back from the recursive call, the result will be the
776         // same as the parent's result.
777       }
778     } else {
779       // Multiple parents - BFS over the rest of the nodes.
780       llvm::DenseSet<const void *> Visited;
781       std::deque<DynTypedNode> Queue(Parents.begin(), Parents.end());
782       while (!Queue.empty()) {
783         BoundNodesTreeBuilder BuilderCopy = *Builder;
784         if (Matcher.matches(Queue.front(), this, &BuilderCopy)) {
785           *Builder = std::move(BuilderCopy);
786           return true;
787         }
788         if (MatchMode != ASTMatchFinder::AMM_ParentOnly) {
789           for (const auto &Parent :
790                ActiveASTContext->getParents(Queue.front())) {
791             // Make sure we do not visit the same node twice.
792             // Otherwise, we'll visit the common ancestors as often as there
793             // are splits on the way down.
794             if (Visited.insert(Parent.getMemoizationData()).second)
795               Queue.push_back(Parent);
796           }
797         }
798         Queue.pop_front();
799       }
800     }
801     return false;
802   }
803 
804   // Implements a BoundNodesTree::Visitor that calls a MatchCallback with
805   // the aggregated bound nodes for each match.
806   class MatchVisitor : public BoundNodesTreeBuilder::Visitor {
807   public:
MatchVisitor(ASTContext * Context,MatchFinder::MatchCallback * Callback)808     MatchVisitor(ASTContext* Context,
809                  MatchFinder::MatchCallback* Callback)
810       : Context(Context),
811         Callback(Callback) {}
812 
visitMatch(const BoundNodes & BoundNodesView)813     void visitMatch(const BoundNodes& BoundNodesView) override {
814       Callback->run(MatchFinder::MatchResult(BoundNodesView, Context));
815     }
816 
817   private:
818     ASTContext* Context;
819     MatchFinder::MatchCallback* Callback;
820   };
821 
822   // Returns true if 'TypeNode' has an alias that matches the given matcher.
typeHasMatchingAlias(const Type * TypeNode,const Matcher<NamedDecl> & Matcher,BoundNodesTreeBuilder * Builder)823   bool typeHasMatchingAlias(const Type *TypeNode,
824                             const Matcher<NamedDecl> &Matcher,
825                             BoundNodesTreeBuilder *Builder) {
826     const Type *const CanonicalType =
827       ActiveASTContext->getCanonicalType(TypeNode);
828     auto Aliases = TypeAliases.find(CanonicalType);
829     if (Aliases == TypeAliases.end())
830       return false;
831     for (const TypedefNameDecl *Alias : Aliases->second) {
832       BoundNodesTreeBuilder Result(*Builder);
833       if (Matcher.matches(*Alias, this, &Result)) {
834         *Builder = std::move(Result);
835         return true;
836       }
837     }
838     return false;
839   }
840 
841   bool
objcClassHasMatchingCompatibilityAlias(const ObjCInterfaceDecl * InterfaceDecl,const Matcher<NamedDecl> & Matcher,BoundNodesTreeBuilder * Builder)842   objcClassHasMatchingCompatibilityAlias(const ObjCInterfaceDecl *InterfaceDecl,
843                                          const Matcher<NamedDecl> &Matcher,
844                                          BoundNodesTreeBuilder *Builder) {
845     auto Aliases = CompatibleAliases.find(InterfaceDecl);
846     if (Aliases == CompatibleAliases.end())
847       return false;
848     for (const ObjCCompatibleAliasDecl *Alias : Aliases->second) {
849       BoundNodesTreeBuilder Result(*Builder);
850       if (Matcher.matches(*Alias, this, &Result)) {
851         *Builder = std::move(Result);
852         return true;
853       }
854     }
855     return false;
856   }
857 
858   /// Bucket to record map.
859   ///
860   /// Used to get the appropriate bucket for each matcher.
861   llvm::StringMap<llvm::TimeRecord> TimeByBucket;
862 
863   const MatchFinder::MatchersByType *Matchers;
864 
865   /// Filtered list of matcher indices for each matcher kind.
866   ///
867   /// \c Decl and \c Stmt toplevel matchers usually apply to a specific node
868   /// kind (and derived kinds) so it is a waste to try every matcher on every
869   /// node.
870   /// We precalculate a list of matchers that pass the toplevel restrict check.
871   llvm::DenseMap<ASTNodeKind, std::vector<unsigned short>> MatcherFiltersMap;
872 
873   const MatchFinder::MatchFinderOptions &Options;
874   ASTContext *ActiveASTContext;
875 
876   // Maps a canonical type to its TypedefDecls.
877   llvm::DenseMap<const Type*, std::set<const TypedefNameDecl*> > TypeAliases;
878 
879   // Maps an Objective-C interface to its ObjCCompatibleAliasDecls.
880   llvm::DenseMap<const ObjCInterfaceDecl *,
881                  llvm::SmallPtrSet<const ObjCCompatibleAliasDecl *, 2>>
882       CompatibleAliases;
883 
884   // Maps (matcher, node) -> the match result for memoization.
885   typedef std::map<MatchKey, MemoizedMatchResult> MemoizationMap;
886   MemoizationMap ResultCache;
887 };
888 
889 static CXXRecordDecl *
getAsCXXRecordDeclOrPrimaryTemplate(const Type * TypeNode)890 getAsCXXRecordDeclOrPrimaryTemplate(const Type *TypeNode) {
891   if (auto *RD = TypeNode->getAsCXXRecordDecl())
892     return RD;
893 
894   // Find the innermost TemplateSpecializationType that isn't an alias template.
895   auto *TemplateType = TypeNode->getAs<TemplateSpecializationType>();
896   while (TemplateType && TemplateType->isTypeAlias())
897     TemplateType =
898         TemplateType->getAliasedType()->getAs<TemplateSpecializationType>();
899 
900   // If this is the name of a (dependent) template specialization, use the
901   // definition of the template, even though it might be specialized later.
902   if (TemplateType)
903     if (auto *ClassTemplate = dyn_cast_or_null<ClassTemplateDecl>(
904           TemplateType->getTemplateName().getAsTemplateDecl()))
905       return ClassTemplate->getTemplatedDecl();
906 
907   return nullptr;
908 }
909 
910 // Returns true if the given C++ class is directly or indirectly derived
911 // from a base type with the given name.  A class is not considered to be
912 // derived from itself.
classIsDerivedFrom(const CXXRecordDecl * Declaration,const Matcher<NamedDecl> & Base,BoundNodesTreeBuilder * Builder,bool Directly)913 bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration,
914                                          const Matcher<NamedDecl> &Base,
915                                          BoundNodesTreeBuilder *Builder,
916                                          bool Directly) {
917   if (!Declaration->hasDefinition())
918     return false;
919   for (const auto &It : Declaration->bases()) {
920     const Type *TypeNode = It.getType().getTypePtr();
921 
922     if (typeHasMatchingAlias(TypeNode, Base, Builder))
923       return true;
924 
925     // FIXME: Going to the primary template here isn't really correct, but
926     // unfortunately we accept a Decl matcher for the base class not a Type
927     // matcher, so it's the best thing we can do with our current interface.
928     CXXRecordDecl *ClassDecl = getAsCXXRecordDeclOrPrimaryTemplate(TypeNode);
929     if (!ClassDecl)
930       continue;
931     if (ClassDecl == Declaration) {
932       // This can happen for recursive template definitions.
933       continue;
934     }
935     BoundNodesTreeBuilder Result(*Builder);
936     if (Base.matches(*ClassDecl, this, &Result)) {
937       *Builder = std::move(Result);
938       return true;
939     }
940     if (!Directly && classIsDerivedFrom(ClassDecl, Base, Builder, Directly))
941       return true;
942   }
943   return false;
944 }
945 
946 // Returns true if the given Objective-C class is directly or indirectly
947 // derived from a matching base class. A class is not considered to be derived
948 // from itself.
objcClassIsDerivedFrom(const ObjCInterfaceDecl * Declaration,const Matcher<NamedDecl> & Base,BoundNodesTreeBuilder * Builder,bool Directly)949 bool MatchASTVisitor::objcClassIsDerivedFrom(
950     const ObjCInterfaceDecl *Declaration, const Matcher<NamedDecl> &Base,
951     BoundNodesTreeBuilder *Builder, bool Directly) {
952   // Check if any of the superclasses of the class match.
953   for (const ObjCInterfaceDecl *ClassDecl = Declaration->getSuperClass();
954        ClassDecl != nullptr; ClassDecl = ClassDecl->getSuperClass()) {
955     // Check if there are any matching compatibility aliases.
956     if (objcClassHasMatchingCompatibilityAlias(ClassDecl, Base, Builder))
957       return true;
958 
959     // Check if there are any matching type aliases.
960     const Type *TypeNode = ClassDecl->getTypeForDecl();
961     if (typeHasMatchingAlias(TypeNode, Base, Builder))
962       return true;
963 
964     if (Base.matches(*ClassDecl, this, Builder))
965       return true;
966 
967     // Not `return false` as a temporary workaround for PR43879.
968     if (Directly)
969       break;
970   }
971 
972   return false;
973 }
974 
TraverseDecl(Decl * DeclNode)975 bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) {
976   if (!DeclNode) {
977     return true;
978   }
979   match(*DeclNode);
980   return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode);
981 }
982 
TraverseStmt(Stmt * StmtNode,DataRecursionQueue * Queue)983 bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue) {
984   if (!StmtNode) {
985     return true;
986   }
987   match(*StmtNode);
988   return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode, Queue);
989 }
990 
TraverseType(QualType TypeNode)991 bool MatchASTVisitor::TraverseType(QualType TypeNode) {
992   match(TypeNode);
993   return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode);
994 }
995 
TraverseTypeLoc(TypeLoc TypeLocNode)996 bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLocNode) {
997   // The RecursiveASTVisitor only visits types if they're not within TypeLocs.
998   // We still want to find those types via matchers, so we match them here. Note
999   // that the TypeLocs are structurally a shadow-hierarchy to the expressed
1000   // type, so we visit all involved parts of a compound type when matching on
1001   // each TypeLoc.
1002   match(TypeLocNode);
1003   match(TypeLocNode.getType());
1004   return RecursiveASTVisitor<MatchASTVisitor>::TraverseTypeLoc(TypeLocNode);
1005 }
1006 
TraverseNestedNameSpecifier(NestedNameSpecifier * NNS)1007 bool MatchASTVisitor::TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
1008   match(*NNS);
1009   return RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifier(NNS);
1010 }
1011 
TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS)1012 bool MatchASTVisitor::TraverseNestedNameSpecifierLoc(
1013     NestedNameSpecifierLoc NNS) {
1014   if (!NNS)
1015     return true;
1016 
1017   match(NNS);
1018 
1019   // We only match the nested name specifier here (as opposed to traversing it)
1020   // because the traversal is already done in the parallel "Loc"-hierarchy.
1021   if (NNS.hasQualifier())
1022     match(*NNS.getNestedNameSpecifier());
1023   return
1024       RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifierLoc(NNS);
1025 }
1026 
TraverseConstructorInitializer(CXXCtorInitializer * CtorInit)1027 bool MatchASTVisitor::TraverseConstructorInitializer(
1028     CXXCtorInitializer *CtorInit) {
1029   if (!CtorInit)
1030     return true;
1031 
1032   match(*CtorInit);
1033 
1034   return RecursiveASTVisitor<MatchASTVisitor>::TraverseConstructorInitializer(
1035       CtorInit);
1036 }
1037 
1038 class MatchASTConsumer : public ASTConsumer {
1039 public:
MatchASTConsumer(MatchFinder * Finder,MatchFinder::ParsingDoneTestCallback * ParsingDone)1040   MatchASTConsumer(MatchFinder *Finder,
1041                    MatchFinder::ParsingDoneTestCallback *ParsingDone)
1042       : Finder(Finder), ParsingDone(ParsingDone) {}
1043 
1044 private:
HandleTranslationUnit(ASTContext & Context)1045   void HandleTranslationUnit(ASTContext &Context) override {
1046     if (ParsingDone != nullptr) {
1047       ParsingDone->run();
1048     }
1049     Finder->matchAST(Context);
1050   }
1051 
1052   MatchFinder *Finder;
1053   MatchFinder::ParsingDoneTestCallback *ParsingDone;
1054 };
1055 
1056 } // end namespace
1057 } // end namespace internal
1058 
MatchResult(const BoundNodes & Nodes,ASTContext * Context)1059 MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes,
1060                                       ASTContext *Context)
1061   : Nodes(Nodes), Context(Context),
1062     SourceManager(&Context->getSourceManager()) {}
1063 
~MatchCallback()1064 MatchFinder::MatchCallback::~MatchCallback() {}
~ParsingDoneTestCallback()1065 MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {}
1066 
MatchFinder(MatchFinderOptions Options)1067 MatchFinder::MatchFinder(MatchFinderOptions Options)
1068     : Options(std::move(Options)), ParsingDone(nullptr) {}
1069 
~MatchFinder()1070 MatchFinder::~MatchFinder() {}
1071 
addMatcher(const DeclarationMatcher & NodeMatch,MatchCallback * Action)1072 void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch,
1073                              MatchCallback *Action) {
1074   Matchers.DeclOrStmt.emplace_back(NodeMatch, Action);
1075   Matchers.AllCallbacks.insert(Action);
1076 }
1077 
addMatcher(const TypeMatcher & NodeMatch,MatchCallback * Action)1078 void MatchFinder::addMatcher(const TypeMatcher &NodeMatch,
1079                              MatchCallback *Action) {
1080   Matchers.Type.emplace_back(NodeMatch, Action);
1081   Matchers.AllCallbacks.insert(Action);
1082 }
1083 
addMatcher(const StatementMatcher & NodeMatch,MatchCallback * Action)1084 void MatchFinder::addMatcher(const StatementMatcher &NodeMatch,
1085                              MatchCallback *Action) {
1086   Matchers.DeclOrStmt.emplace_back(NodeMatch, Action);
1087   Matchers.AllCallbacks.insert(Action);
1088 }
1089 
addMatcher(const NestedNameSpecifierMatcher & NodeMatch,MatchCallback * Action)1090 void MatchFinder::addMatcher(const NestedNameSpecifierMatcher &NodeMatch,
1091                              MatchCallback *Action) {
1092   Matchers.NestedNameSpecifier.emplace_back(NodeMatch, Action);
1093   Matchers.AllCallbacks.insert(Action);
1094 }
1095 
addMatcher(const NestedNameSpecifierLocMatcher & NodeMatch,MatchCallback * Action)1096 void MatchFinder::addMatcher(const NestedNameSpecifierLocMatcher &NodeMatch,
1097                              MatchCallback *Action) {
1098   Matchers.NestedNameSpecifierLoc.emplace_back(NodeMatch, Action);
1099   Matchers.AllCallbacks.insert(Action);
1100 }
1101 
addMatcher(const TypeLocMatcher & NodeMatch,MatchCallback * Action)1102 void MatchFinder::addMatcher(const TypeLocMatcher &NodeMatch,
1103                              MatchCallback *Action) {
1104   Matchers.TypeLoc.emplace_back(NodeMatch, Action);
1105   Matchers.AllCallbacks.insert(Action);
1106 }
1107 
addMatcher(const CXXCtorInitializerMatcher & NodeMatch,MatchCallback * Action)1108 void MatchFinder::addMatcher(const CXXCtorInitializerMatcher &NodeMatch,
1109                              MatchCallback *Action) {
1110   Matchers.CtorInit.emplace_back(NodeMatch, Action);
1111   Matchers.AllCallbacks.insert(Action);
1112 }
1113 
addDynamicMatcher(const internal::DynTypedMatcher & NodeMatch,MatchCallback * Action)1114 bool MatchFinder::addDynamicMatcher(const internal::DynTypedMatcher &NodeMatch,
1115                                     MatchCallback *Action) {
1116   if (NodeMatch.canConvertTo<Decl>()) {
1117     addMatcher(NodeMatch.convertTo<Decl>(), Action);
1118     return true;
1119   } else if (NodeMatch.canConvertTo<QualType>()) {
1120     addMatcher(NodeMatch.convertTo<QualType>(), Action);
1121     return true;
1122   } else if (NodeMatch.canConvertTo<Stmt>()) {
1123     addMatcher(NodeMatch.convertTo<Stmt>(), Action);
1124     return true;
1125   } else if (NodeMatch.canConvertTo<NestedNameSpecifier>()) {
1126     addMatcher(NodeMatch.convertTo<NestedNameSpecifier>(), Action);
1127     return true;
1128   } else if (NodeMatch.canConvertTo<NestedNameSpecifierLoc>()) {
1129     addMatcher(NodeMatch.convertTo<NestedNameSpecifierLoc>(), Action);
1130     return true;
1131   } else if (NodeMatch.canConvertTo<TypeLoc>()) {
1132     addMatcher(NodeMatch.convertTo<TypeLoc>(), Action);
1133     return true;
1134   } else if (NodeMatch.canConvertTo<CXXCtorInitializer>()) {
1135     addMatcher(NodeMatch.convertTo<CXXCtorInitializer>(), Action);
1136     return true;
1137   }
1138   return false;
1139 }
1140 
newASTConsumer()1141 std::unique_ptr<ASTConsumer> MatchFinder::newASTConsumer() {
1142   return std::make_unique<internal::MatchASTConsumer>(this, ParsingDone);
1143 }
1144 
match(const clang::DynTypedNode & Node,ASTContext & Context)1145 void MatchFinder::match(const clang::DynTypedNode &Node, ASTContext &Context) {
1146   internal::MatchASTVisitor Visitor(&Matchers, Options);
1147   Visitor.set_active_ast_context(&Context);
1148   Visitor.match(Node);
1149 }
1150 
matchAST(ASTContext & Context)1151 void MatchFinder::matchAST(ASTContext &Context) {
1152   internal::MatchASTVisitor Visitor(&Matchers, Options);
1153   Visitor.set_active_ast_context(&Context);
1154   Visitor.onStartOfTranslationUnit();
1155   Visitor.TraverseAST(Context);
1156   Visitor.onEndOfTranslationUnit();
1157 }
1158 
registerTestCallbackAfterParsing(MatchFinder::ParsingDoneTestCallback * NewParsingDone)1159 void MatchFinder::registerTestCallbackAfterParsing(
1160     MatchFinder::ParsingDoneTestCallback *NewParsingDone) {
1161   ParsingDone = NewParsingDone;
1162 }
1163 
getID() const1164 StringRef MatchFinder::MatchCallback::getID() const { return "<unknown>"; }
1165 
1166 } // end namespace ast_matchers
1167 } // end namespace clang
1168