1 //===----------------------------------------------------------------------===//
2 //
3 // Copyright (c) 2012, 2013 The University of Utah
4 // All rights reserved.
5 //
6 // This file is distributed under the University of Illinois Open Source
7 // License.  See the file COPYING for details.
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #ifndef COMMON_STATEMENT_VISITOR_H
12 #define COMMON_STATEMENT_VISITOR_H
13 
14 #include "clang/AST/RecursiveASTVisitor.h"
15 
16 template<typename T>
17 class CommonStatementVisitor : public clang::RecursiveASTVisitor<T> {
18 public:
19 
CommonStatementVisitor()20   CommonStatementVisitor()
21     : CurrentFuncDecl(NULL),
22       CurrentStmt(NULL),
23       NeedParen(false)
24   { }
25 
getDerived()26   T &getDerived() { return *static_cast<T*>(this); };
27 
setCurrentFunctionDecl(clang::FunctionDecl * FD)28   void setCurrentFunctionDecl(clang::FunctionDecl *FD) {
29     CurrentFuncDecl = FD;
30   }
31 
32   bool VisitCompoundStmt(clang::CompoundStmt *S);
33 
34   bool VisitIfStmt(clang::IfStmt *IS);
35 
36   bool VisitForStmt(clang::ForStmt *FS);
37 
38   bool VisitWhileStmt(clang::WhileStmt *WS);
39 
40   bool VisitDoStmt(clang::DoStmt *DS);
41 
42   bool VisitCaseStmt(clang::CaseStmt *CS);
43 
44   bool VisitDefaultStmt(clang::DefaultStmt *DS);
45 
46   bool VisitCXXTryStmt(clang::CXXTryStmt *DS);
47 
48   void visitNonCompoundStmt(clang::Stmt *S);
49 
50 protected:
51 
52   clang::FunctionDecl *CurrentFuncDecl;
53 
54   clang::Stmt *CurrentStmt;
55 
56   bool NeedParen;
57 
58 };
59 
60 template<typename T>
VisitCompoundStmt(clang::CompoundStmt * CS)61 bool CommonStatementVisitor<T>::VisitCompoundStmt(clang::CompoundStmt *CS)
62 {
63   for (clang::CompoundStmt::body_iterator I = CS->body_begin(),
64        E = CS->body_end(); I != E; ++I) {
65     CurrentStmt = (*I);
66     getDerived().TraverseStmt(*I);
67   }
68   return false;
69 }
70 
71 template<typename T>
visitNonCompoundStmt(clang::Stmt * S)72 void CommonStatementVisitor<T>::visitNonCompoundStmt(clang::Stmt *S)
73 {
74   if (!S)
75     return;
76 
77   clang::CompoundStmt *CS = llvm::dyn_cast<clang::CompoundStmt>(S);
78   if (CS) {
79     VisitCompoundStmt(CS);
80     return;
81   }
82 
83   CurrentStmt = (S);
84   NeedParen = true;
85   getDerived().TraverseStmt(S);
86   NeedParen = false;
87 }
88 
89 // It is used to handle the case where if-then or else branch
90 // is not treated as a CompoundStmt. So it cannot be traversed
91 // from VisitCompoundStmt, e.g.,
92 //   if (x)
93 //     foo(bar())
94 template<typename T>
VisitIfStmt(clang::IfStmt * IS)95 bool CommonStatementVisitor<T>::VisitIfStmt(clang::IfStmt *IS)
96 {
97   clang::Expr *E = IS->getCond();
98   getDerived().TraverseStmt(E);
99 
100   clang::Stmt *ThenB = IS->getThen();
101   visitNonCompoundStmt(ThenB);
102 
103   clang::Stmt *ElseB = IS->getElse();
104   visitNonCompoundStmt(ElseB);
105 
106   return false;
107 }
108 
109 // It causes unsound transformation because
110 // the semantics of loop execution has been changed.
111 // For example,
112 //   int foo(int x)
113 //   {
114 //     int i;
115 //     for(i = 0; i < bar(bar(x)); i++)
116 //       ...
117 //   }
118 // will be transformed to:
119 //   int foo(int x)
120 //   {
121 //     int i;
122 //     int tmp_var = bar(x);
123 //     for(i = 0; i < bar(tmp_var); i++)
124 //       ...
125 //   }
126 template<typename T>
VisitForStmt(clang::ForStmt * FS)127 bool CommonStatementVisitor<T>::VisitForStmt(clang::ForStmt *FS)
128 {
129   clang::Stmt *Init = FS->getInit();
130   getDerived().TraverseStmt(Init);
131 
132   clang::Expr *Cond = FS->getCond();
133   getDerived().TraverseStmt(Cond);
134 
135   clang::Expr *Inc = FS->getInc();
136   getDerived().TraverseStmt(Inc);
137 
138   clang::Stmt *Body = FS->getBody();
139   visitNonCompoundStmt(Body);
140   return false;
141 }
142 
143 template<typename T>
VisitWhileStmt(clang::WhileStmt * WS)144 bool CommonStatementVisitor<T>::VisitWhileStmt(clang::WhileStmt *WS)
145 {
146   clang::Expr *E = WS->getCond();
147   getDerived().TraverseStmt(E);
148 
149   clang::Stmt *Body = WS->getBody();
150   visitNonCompoundStmt(Body);
151   return false;
152 }
153 
154 template<typename T>
VisitDoStmt(clang::DoStmt * DS)155 bool CommonStatementVisitor<T>::VisitDoStmt(clang::DoStmt *DS)
156 {
157   clang::Expr *E = DS->getCond();
158   getDerived().TraverseStmt(E);
159 
160   clang::Stmt *Body = DS->getBody();
161   visitNonCompoundStmt(Body);
162   return false;
163 }
164 
165 template<typename T>
VisitCaseStmt(clang::CaseStmt * CS)166 bool CommonStatementVisitor<T>::VisitCaseStmt(clang::CaseStmt *CS)
167 {
168   clang::Stmt *Body = CS->getSubStmt();
169   visitNonCompoundStmt(Body);
170   return false;
171 }
172 
173 template<typename T>
VisitDefaultStmt(clang::DefaultStmt * DS)174 bool CommonStatementVisitor<T>::VisitDefaultStmt(clang::DefaultStmt *DS)
175 {
176   clang::Stmt *Body = DS->getSubStmt();
177   visitNonCompoundStmt(Body);
178   return false;
179 }
180 
181 template<typename T>
VisitCXXTryStmt(clang::CXXTryStmt * CS)182 bool CommonStatementVisitor<T>::VisitCXXTryStmt(clang::CXXTryStmt *CS)
183 {
184   clang::CompoundStmt *TryBlock = CS->getTryBlock();
185   visitNonCompoundStmt(TryBlock);
186 
187   for (unsigned I = 0; I < CS->getNumHandlers(); ++I) {
188     clang::CXXCatchStmt *CatchStmt = CS->getHandler(I);
189     clang::Stmt *CatchBlock = CatchStmt->getHandlerBlock();
190     visitNonCompoundStmt(CatchBlock);
191   }
192   return false;
193 }
194 
195 #endif
196