1 //===----------------------------------------------------------------------===// 2 // 3 // Copyright (c) 2012, 2013 The University of Utah 4 // All rights reserved. 5 // 6 // This file is distributed under the University of Illinois Open Source 7 // License. See the file COPYING for details. 8 // 9 //===----------------------------------------------------------------------===// 10 11 #ifndef COMMON_STATEMENT_VISITOR_H 12 #define COMMON_STATEMENT_VISITOR_H 13 14 #include "clang/AST/RecursiveASTVisitor.h" 15 16 template<typename T> 17 class CommonStatementVisitor : public clang::RecursiveASTVisitor<T> { 18 public: 19 CommonStatementVisitor()20 CommonStatementVisitor() 21 : CurrentFuncDecl(NULL), 22 CurrentStmt(NULL), 23 NeedParen(false) 24 { } 25 getDerived()26 T &getDerived() { return *static_cast<T*>(this); }; 27 setCurrentFunctionDecl(clang::FunctionDecl * FD)28 void setCurrentFunctionDecl(clang::FunctionDecl *FD) { 29 CurrentFuncDecl = FD; 30 } 31 32 bool VisitCompoundStmt(clang::CompoundStmt *S); 33 34 bool VisitIfStmt(clang::IfStmt *IS); 35 36 bool VisitForStmt(clang::ForStmt *FS); 37 38 bool VisitWhileStmt(clang::WhileStmt *WS); 39 40 bool VisitDoStmt(clang::DoStmt *DS); 41 42 bool VisitCaseStmt(clang::CaseStmt *CS); 43 44 bool VisitDefaultStmt(clang::DefaultStmt *DS); 45 46 bool VisitCXXTryStmt(clang::CXXTryStmt *DS); 47 48 void visitNonCompoundStmt(clang::Stmt *S); 49 50 protected: 51 52 clang::FunctionDecl *CurrentFuncDecl; 53 54 clang::Stmt *CurrentStmt; 55 56 bool NeedParen; 57 58 }; 59 60 template<typename T> VisitCompoundStmt(clang::CompoundStmt * CS)61bool CommonStatementVisitor<T>::VisitCompoundStmt(clang::CompoundStmt *CS) 62 { 63 for (clang::CompoundStmt::body_iterator I = CS->body_begin(), 64 E = CS->body_end(); I != E; ++I) { 65 CurrentStmt = (*I); 66 getDerived().TraverseStmt(*I); 67 } 68 return false; 69 } 70 71 template<typename T> visitNonCompoundStmt(clang::Stmt * S)72void CommonStatementVisitor<T>::visitNonCompoundStmt(clang::Stmt *S) 73 { 74 if (!S) 75 return; 76 77 clang::CompoundStmt *CS = llvm::dyn_cast<clang::CompoundStmt>(S); 78 if (CS) { 79 VisitCompoundStmt(CS); 80 return; 81 } 82 83 CurrentStmt = (S); 84 NeedParen = true; 85 getDerived().TraverseStmt(S); 86 NeedParen = false; 87 } 88 89 // It is used to handle the case where if-then or else branch 90 // is not treated as a CompoundStmt. So it cannot be traversed 91 // from VisitCompoundStmt, e.g., 92 // if (x) 93 // foo(bar()) 94 template<typename T> VisitIfStmt(clang::IfStmt * IS)95bool CommonStatementVisitor<T>::VisitIfStmt(clang::IfStmt *IS) 96 { 97 clang::Expr *E = IS->getCond(); 98 getDerived().TraverseStmt(E); 99 100 clang::Stmt *ThenB = IS->getThen(); 101 visitNonCompoundStmt(ThenB); 102 103 clang::Stmt *ElseB = IS->getElse(); 104 visitNonCompoundStmt(ElseB); 105 106 return false; 107 } 108 109 // It causes unsound transformation because 110 // the semantics of loop execution has been changed. 111 // For example, 112 // int foo(int x) 113 // { 114 // int i; 115 // for(i = 0; i < bar(bar(x)); i++) 116 // ... 117 // } 118 // will be transformed to: 119 // int foo(int x) 120 // { 121 // int i; 122 // int tmp_var = bar(x); 123 // for(i = 0; i < bar(tmp_var); i++) 124 // ... 125 // } 126 template<typename T> VisitForStmt(clang::ForStmt * FS)127bool CommonStatementVisitor<T>::VisitForStmt(clang::ForStmt *FS) 128 { 129 clang::Stmt *Init = FS->getInit(); 130 getDerived().TraverseStmt(Init); 131 132 clang::Expr *Cond = FS->getCond(); 133 getDerived().TraverseStmt(Cond); 134 135 clang::Expr *Inc = FS->getInc(); 136 getDerived().TraverseStmt(Inc); 137 138 clang::Stmt *Body = FS->getBody(); 139 visitNonCompoundStmt(Body); 140 return false; 141 } 142 143 template<typename T> VisitWhileStmt(clang::WhileStmt * WS)144bool CommonStatementVisitor<T>::VisitWhileStmt(clang::WhileStmt *WS) 145 { 146 clang::Expr *E = WS->getCond(); 147 getDerived().TraverseStmt(E); 148 149 clang::Stmt *Body = WS->getBody(); 150 visitNonCompoundStmt(Body); 151 return false; 152 } 153 154 template<typename T> VisitDoStmt(clang::DoStmt * DS)155bool CommonStatementVisitor<T>::VisitDoStmt(clang::DoStmt *DS) 156 { 157 clang::Expr *E = DS->getCond(); 158 getDerived().TraverseStmt(E); 159 160 clang::Stmt *Body = DS->getBody(); 161 visitNonCompoundStmt(Body); 162 return false; 163 } 164 165 template<typename T> VisitCaseStmt(clang::CaseStmt * CS)166bool CommonStatementVisitor<T>::VisitCaseStmt(clang::CaseStmt *CS) 167 { 168 clang::Stmt *Body = CS->getSubStmt(); 169 visitNonCompoundStmt(Body); 170 return false; 171 } 172 173 template<typename T> VisitDefaultStmt(clang::DefaultStmt * DS)174bool CommonStatementVisitor<T>::VisitDefaultStmt(clang::DefaultStmt *DS) 175 { 176 clang::Stmt *Body = DS->getSubStmt(); 177 visitNonCompoundStmt(Body); 178 return false; 179 } 180 181 template<typename T> VisitCXXTryStmt(clang::CXXTryStmt * CS)182bool CommonStatementVisitor<T>::VisitCXXTryStmt(clang::CXXTryStmt *CS) 183 { 184 clang::CompoundStmt *TryBlock = CS->getTryBlock(); 185 visitNonCompoundStmt(TryBlock); 186 187 for (unsigned I = 0; I < CS->getNumHandlers(); ++I) { 188 clang::CXXCatchStmt *CatchStmt = CS->getHandler(I); 189 clang::Stmt *CatchBlock = CatchStmt->getHandlerBlock(); 190 visitNonCompoundStmt(CatchBlock); 191 } 192 return false; 193 } 194 195 #endif 196