1 //===- unittest/AST/RecursiveASTVisitorTest.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/Frontend/FrontendAction.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/STLExtras.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <cassert>
#include <ostream>
20 
21 using namespace clang;
22 using ::testing::ElementsAre;
23 
24 namespace {
25 class ProcessASTAction : public clang::ASTFrontendAction {
26 public:
ProcessASTAction(llvm::unique_function<void (clang::ASTContext &)> Process)27   ProcessASTAction(llvm::unique_function<void(clang::ASTContext &)> Process)
28       : Process(std::move(Process)) {
29     assert(this->Process);
30   }
31 
CreateASTConsumer(CompilerInstance & CI,StringRef InFile)32   std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
33                                                  StringRef InFile) {
34     class Consumer : public ASTConsumer {
35     public:
36       Consumer(llvm::function_ref<void(ASTContext &CTx)> Process)
37           : Process(Process) {}
38 
39       void HandleTranslationUnit(ASTContext &Ctx) override { Process(Ctx); }
40 
41     private:
42       llvm::function_ref<void(ASTContext &CTx)> Process;
43     };
44 
45     return std::make_unique<Consumer>(Process);
46   }
47 
48 private:
49   llvm::unique_function<void(clang::ASTContext &)> Process;
50 };
51 
/// Events recorded by a visitor to mark where a traversal starts and ends.
enum class VisitEvent {
  StartTraverseFunction,
  EndTraverseFunction,
  StartTraverseAttr,
  EndTraverseAttr
};

/// Lets gtest print a VisitEvent by name. gtest cannot otherwise format a
/// scoped enum and falls back to dumping raw bytes in failure messages.
/// gtest finds this overload through argument-dependent lookup.
void PrintTo(VisitEvent E, std::ostream *OS) {
  switch (E) {
  case VisitEvent::StartTraverseFunction:
    *OS << "StartTraverseFunction";
    return;
  case VisitEvent::EndTraverseFunction:
    *OS << "EndTraverseFunction";
    return;
  case VisitEvent::StartTraverseAttr:
    *OS << "StartTraverseAttr";
    return;
  case VisitEvent::EndTraverseAttr:
    *OS << "EndTraverseAttr";
    return;
  }
  // Only reachable for a value outside the enumerators; print it numerically.
  *OS << "VisitEvent(" << static_cast<int>(E) << ")";
}
58 
59 class CollectInterestingEvents
60     : public RecursiveASTVisitor<CollectInterestingEvents> {
61 public:
TraverseFunctionDecl(FunctionDecl * D)62   bool TraverseFunctionDecl(FunctionDecl *D) {
63     Events.push_back(VisitEvent::StartTraverseFunction);
64     bool Ret = RecursiveASTVisitor::TraverseFunctionDecl(D);
65     Events.push_back(VisitEvent::EndTraverseFunction);
66 
67     return Ret;
68   }
69 
TraverseAttr(Attr * A)70   bool TraverseAttr(Attr *A) {
71     Events.push_back(VisitEvent::StartTraverseAttr);
72     bool Ret = RecursiveASTVisitor::TraverseAttr(A);
73     Events.push_back(VisitEvent::EndTraverseAttr);
74 
75     return Ret;
76   }
77 
takeEvents()78   std::vector<VisitEvent> takeEvents() && { return std::move(Events); }
79 
80 private:
81   std::vector<VisitEvent> Events;
82 };
83 
collectEvents(llvm::StringRef Code)84 std::vector<VisitEvent> collectEvents(llvm::StringRef Code) {
85   CollectInterestingEvents Visitor;
86   clang::tooling::runToolOnCode(
87       std::make_unique<ProcessASTAction>(
88           [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }),
89       Code);
90   return std::move(Visitor).takeEvents();
91 }
92 } // namespace
93 
// An attribute on a function must be traversed from inside
// TraverseFunctionDecl: its start/end events have to nest within the
// function's start/end events.
TEST(RecursiveASTVisitorTest, AttributesInsideDecls) {
  const llvm::StringRef Code = R"cpp(
__attribute__((annotate("something"))) int foo() { return 10; }
  )cpp";

  const std::vector<VisitEvent> Events = collectEvents(Code);
  EXPECT_THAT(Events, ElementsAre(VisitEvent::StartTraverseFunction,
                                  VisitEvent::StartTraverseAttr,
                                  VisitEvent::EndTraverseAttr,
                                  VisitEvent::EndTraverseFunction));
}
106