1 //===- unittest/AST/RecursiveASTVisitorTest.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "clang/AST/RecursiveASTVisitor.h"
10 #include "clang/AST/ASTConsumer.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/Attr.h"
13 #include "clang/Frontend/FrontendAction.h"
14 #include "clang/Tooling/Tooling.h"
15 #include "llvm/ADT/FunctionExtras.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "gmock/gmock.h"
18 #include "gtest/gtest.h"
19 #include <cassert>
20
21 using namespace clang;
22 using ::testing::ElementsAre;
23
24 namespace {
25 class ProcessASTAction : public clang::ASTFrontendAction {
26 public:
ProcessASTAction(llvm::unique_function<void (clang::ASTContext &)> Process)27 ProcessASTAction(llvm::unique_function<void(clang::ASTContext &)> Process)
28 : Process(std::move(Process)) {
29 assert(this->Process);
30 }
31
CreateASTConsumer(CompilerInstance & CI,StringRef InFile)32 std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
33 StringRef InFile) {
34 class Consumer : public ASTConsumer {
35 public:
36 Consumer(llvm::function_ref<void(ASTContext &CTx)> Process)
37 : Process(Process) {}
38
39 void HandleTranslationUnit(ASTContext &Ctx) override { Process(Ctx); }
40
41 private:
42 llvm::function_ref<void(ASTContext &CTx)> Process;
43 };
44
45 return std::make_unique<Consumer>(Process);
46 }
47
48 private:
49 llvm::unique_function<void(clang::ASTContext &)> Process;
50 };
51
/// Traversal milestones recorded by CollectInterestingEvents. Each Start*/End*
/// pair brackets one call to the matching Traverse* method.
enum class VisitEvent {
  StartTraverseFunction, // entering TraverseFunctionDecl
  EndTraverseFunction,   // leaving TraverseFunctionDecl
  StartTraverseAttr,     // entering TraverseAttr
  EndTraverseAttr,       // leaving TraverseAttr
};
58
59 class CollectInterestingEvents
60 : public RecursiveASTVisitor<CollectInterestingEvents> {
61 public:
TraverseFunctionDecl(FunctionDecl * D)62 bool TraverseFunctionDecl(FunctionDecl *D) {
63 Events.push_back(VisitEvent::StartTraverseFunction);
64 bool Ret = RecursiveASTVisitor::TraverseFunctionDecl(D);
65 Events.push_back(VisitEvent::EndTraverseFunction);
66
67 return Ret;
68 }
69
TraverseAttr(Attr * A)70 bool TraverseAttr(Attr *A) {
71 Events.push_back(VisitEvent::StartTraverseAttr);
72 bool Ret = RecursiveASTVisitor::TraverseAttr(A);
73 Events.push_back(VisitEvent::EndTraverseAttr);
74
75 return Ret;
76 }
77
takeEvents()78 std::vector<VisitEvent> takeEvents() && { return std::move(Events); }
79
80 private:
81 std::vector<VisitEvent> Events;
82 };
83
collectEvents(llvm::StringRef Code)84 std::vector<VisitEvent> collectEvents(llvm::StringRef Code) {
85 CollectInterestingEvents Visitor;
86 clang::tooling::runToolOnCode(
87 std::make_unique<ProcessASTAction>(
88 [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }),
89 Code);
90 return std::move(Visitor).takeEvents();
91 }
92 } // namespace
93
TEST(RecursiveASTVisitorTest, AttributesInsideDecls) {
  // The attribute on foo() must be traversed while foo's own
  // TraverseFunctionDecl is still running, i.e. its start/end events must be
  // nested between the function's start/end events.
  const llvm::StringRef Code = R"cpp(
__attribute__((annotate("something"))) int foo() { return 10; }
)cpp";

  const auto Events = collectEvents(Code);
  EXPECT_THAT(Events, ElementsAre(VisitEvent::StartTraverseFunction,
                                  VisitEvent::StartTraverseAttr,
                                  VisitEvent::EndTraverseAttr,
                                  VisitEvent::EndTraverseFunction));
}
106