1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 //  Provides MatchVerifier, a base class to implement gtest matchers that
11 //  verify things that can be matched on the AST.
12 //
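//  Also implements matchers based on MatchVerifier:
//  LocationVerifier and RangeVerifier to verify whether a matched node has
//  the expected source location or source range, DumpVerifier to verify
//  that a node's AST dump contains a given substring, and PrintVerifier to
//  verify a node's pretty-printed form.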
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #include "clang/AST/ASTContext.h"
20 #include "clang/ASTMatchers/ASTMatchFinder.h"
21 #include "clang/ASTMatchers/ASTMatchers.h"
22 #include "clang/Tooling/Tooling.h"
23 #include "gtest/gtest.h"
24 
25 namespace clang {
26 namespace ast_matchers {
27 
/// \brief Languages (and language standards) that MatchVerifier::match() can
/// compile the test code as.
enum Language {
  Lang_C,
  Lang_C89,
  Lang_CXX,
  Lang_CXX11,
  Lang_OpenCL
};
29 
30 /// \brief Base class for verifying some property of nodes found by a matcher.
31 template <typename NodeType>
32 class MatchVerifier : public MatchFinder::MatchCallback {
33 public:
34   template <typename MatcherType>
35   testing::AssertionResult match(const std::string &Code,
36                                  const MatcherType &AMatcher) {
37     std::vector<std::string> Args;
38     return match(Code, AMatcher, Args, Lang_CXX);
39   }
40 
41   template <typename MatcherType>
42   testing::AssertionResult match(const std::string &Code,
43                                  const MatcherType &AMatcher,
44                                  Language L) {
45     std::vector<std::string> Args;
46     return match(Code, AMatcher, Args, L);
47   }
48 
49   template <typename MatcherType>
50   testing::AssertionResult match(const std::string &Code,
51                                  const MatcherType &AMatcher,
52                                  std::vector<std::string>& Args,
53                                  Language L);
54 
55 protected:
56   virtual void run(const MatchFinder::MatchResult &Result);
57   virtual void verify(const MatchFinder::MatchResult &Result,
58                       const NodeType &Node) {}
59 
60   void setFailure(const Twine &Result) {
61     Verified = false;
62     VerifyResult = Result.str();
63   }
64 
65   void setSuccess() {
66     Verified = true;
67   }
68 
69 private:
70   bool Verified;
71   std::string VerifyResult;
72 };
73 
74 /// \brief Runs a matcher over some code, and returns the result of the
75 /// verifier for the matched node.
76 template <typename NodeType> template <typename MatcherType>
77 testing::AssertionResult MatchVerifier<NodeType>::match(
78     const std::string &Code, const MatcherType &AMatcher,
79     std::vector<std::string>& Args, Language L) {
80   MatchFinder Finder;
81   Finder.addMatcher(AMatcher.bind(""), this);
82   OwningPtr<tooling::FrontendActionFactory> Factory(
83       tooling::newFrontendActionFactory(&Finder));
84 
85   StringRef FileName;
86   switch (L) {
87   case Lang_C:
88     Args.push_back("-std=c99");
89     FileName = "input.c";
90     break;
91   case Lang_C89:
92     Args.push_back("-std=c89");
93     FileName = "input.c";
94     break;
95   case Lang_CXX:
96     Args.push_back("-std=c++98");
97     FileName = "input.cc";
98     break;
99   case Lang_CXX11:
100     Args.push_back("-std=c++11");
101     FileName = "input.cc";
102     break;
103   case Lang_OpenCL:
104     FileName = "input.cl";
105   }
106 
107   // Default to failure in case callback is never called
108   setFailure("Could not find match");
109   if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
110     return testing::AssertionFailure() << "Parsing error";
111   if (!Verified)
112     return testing::AssertionFailure() << VerifyResult;
113   return testing::AssertionSuccess();
114 }
115 
116 template <typename NodeType>
117 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
118   const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
119   if (!Node) {
120     setFailure("Matched node has wrong type");
121   } else {
122     // Callback has been called, default to success.
123     setSuccess();
124     verify(Result, *Node);
125   }
126 }
127 
128 template <>
129 inline void MatchVerifier<ast_type_traits::DynTypedNode>::run(
130     const MatchFinder::MatchResult &Result) {
131   BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
132   BoundNodes::IDToNodeMap::const_iterator I = M.find("");
133   if (I == M.end()) {
134     setFailure("Node was not bound");
135   } else {
136     // Callback has been called, default to success.
137     setSuccess();
138     verify(Result, I->second);
139   }
140 }
141 
142 /// \brief Verify whether a node has the correct source location.
143 ///
144 /// By default, Node.getSourceLocation() is checked. This can be changed
145 /// by overriding getLocation().
146 template <typename NodeType>
147 class LocationVerifier : public MatchVerifier<NodeType> {
148 public:
149   void expectLocation(unsigned Line, unsigned Column) {
150     ExpectLine = Line;
151     ExpectColumn = Column;
152   }
153 
154 protected:
155   void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
156     SourceLocation Loc = getLocation(Node);
157     unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
158     unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
159     if (Line != ExpectLine || Column != ExpectColumn) {
160       std::string MsgStr;
161       llvm::raw_string_ostream Msg(MsgStr);
162       Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
163           << ">, found <";
164       Loc.print(Msg, *Result.SourceManager);
165       Msg << '>';
166       this->setFailure(Msg.str());
167     }
168   }
169 
170   virtual SourceLocation getLocation(const NodeType &Node) {
171     return Node.getLocation();
172   }
173 
174 private:
175   unsigned ExpectLine, ExpectColumn;
176 };
177 
178 /// \brief Verify whether a node has the correct source range.
179 ///
180 /// By default, Node.getSourceRange() is checked. This can be changed
181 /// by overriding getRange().
182 template <typename NodeType>
183 class RangeVerifier : public MatchVerifier<NodeType> {
184 public:
185   void expectRange(unsigned BeginLine, unsigned BeginColumn,
186                    unsigned EndLine, unsigned EndColumn) {
187     ExpectBeginLine = BeginLine;
188     ExpectBeginColumn = BeginColumn;
189     ExpectEndLine = EndLine;
190     ExpectEndColumn = EndColumn;
191   }
192 
193 protected:
194   void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
195     SourceRange R = getRange(Node);
196     SourceLocation Begin = R.getBegin();
197     SourceLocation End = R.getEnd();
198     unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
199     unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
200     unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
201     unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
202     if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
203         EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
204       std::string MsgStr;
205       llvm::raw_string_ostream Msg(MsgStr);
206       Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
207           << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
208       Begin.print(Msg, *Result.SourceManager);
209       Msg << '-';
210       End.print(Msg, *Result.SourceManager);
211       Msg << '>';
212       this->setFailure(Msg.str());
213     }
214   }
215 
216   virtual SourceRange getRange(const NodeType &Node) {
217     return Node.getSourceRange();
218   }
219 
220 private:
221   unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
222 };
223 
224 /// \brief Verify whether a node's dump contains a given substring.
225 class DumpVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
226 public:
227   void expectSubstring(const std::string &Str) {
228     ExpectSubstring = Str;
229   }
230 
231 protected:
232   void verify(const MatchFinder::MatchResult &Result,
233               const ast_type_traits::DynTypedNode &Node) {
234     std::string DumpStr;
235     llvm::raw_string_ostream Dump(DumpStr);
236     Node.dump(Dump, *Result.SourceManager);
237 
238     if (Dump.str().find(ExpectSubstring) == std::string::npos) {
239       std::string MsgStr;
240       llvm::raw_string_ostream Msg(MsgStr);
241       Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
242           << Dump.str() << '>';
243       this->setFailure(Msg.str());
244     }
245   }
246 
247 private:
248   std::string ExpectSubstring;
249 };
250 
251 /// \brief Verify whether a node's pretty print matches a given string.
252 class PrintVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
253 public:
254   void expectString(const std::string &Str) {
255     ExpectString = Str;
256   }
257 
258 protected:
259   void verify(const MatchFinder::MatchResult &Result,
260               const ast_type_traits::DynTypedNode &Node) {
261     std::string PrintStr;
262     llvm::raw_string_ostream Print(PrintStr);
263     Node.print(Print, Result.Context->getPrintingPolicy());
264 
265     if (Print.str() != ExpectString) {
266       std::string MsgStr;
267       llvm::raw_string_ostream Msg(MsgStr);
268       Msg << "Expected pretty print <" << ExpectString << ">, found <"
269           << Print.str() << '>';
270       this->setFailure(Msg.str());
271     }
272   }
273 
274 private:
275   std::string ExpectString;
276 };
277 
278 } // end namespace ast_matchers
279 } // end namespace clang
280