1 //===----------------------------------------------------------------------===// 2 // 3 // Copyright (c) 2016, 2017 The University of Utah 4 // All rights reserved. 5 // 6 // This file is distributed under the University of Illinois Open Source 7 // License. See the file COPYING for details. 8 // 9 //===----------------------------------------------------------------------===// 10 11 #if HAVE_CONFIG_H 12 # include <config.h> 13 #endif 14 15 #include "ReplaceArrayAccessWithIndex.h" 16 17 #include "clang/AST/RecursiveASTVisitor.h" 18 #include "clang/AST/ASTContext.h" 19 #include "clang/Basic/SourceManager.h" 20 21 #include "TransformationManager.h" 22 23 #include <iostream> 24 25 using namespace clang; 26 27 28 static const char *Description = 29 "Replace array accesses with the index expression."; 30 31 static RegisterTransformation<ReplaceArrayAccessWithIndex> 32 Trans("replace-array-access-with-index", Description); 33 34 class ReplaceArrayAccessWithIndex::IndexCollector 35 : public RecursiveASTVisitor<ReplaceArrayAccessWithIndex::IndexCollector> 36 { 37 public: 38 explicit IndexCollector(ReplaceArrayAccessWithIndex *instance); 39 bool VisitArraySubscriptExpr(ArraySubscriptExpr *ASE); 40 41 private: 42 const VarDecl *getVarDeclFromExpr(const Expr *E); 43 ReplaceArrayAccessWithIndex *ConsumerInstance; 44 }; 45 IndexCollector(ReplaceArrayAccessWithIndex * instance)46ReplaceArrayAccessWithIndex::IndexCollector::IndexCollector( 47 ReplaceArrayAccessWithIndex *instance) 48 : ConsumerInstance(instance) 49 { 50 // No further initialization needed. 51 } 52 VisitArraySubscriptExpr(ArraySubscriptExpr * ASE)53bool ReplaceArrayAccessWithIndex::IndexCollector::VisitArraySubscriptExpr( 54 ArraySubscriptExpr *ASE) 55 { 56 // Skip expressions in included files. 57 if (ConsumerInstance->isInIncludedFile(ASE)) 58 return true; 59 60 const VarDecl *BaseVD = getVarDeclFromExpr(ASE->getBase()); 61 62 if (!BaseVD) 63 return true; 64 65 ArrayType const *ArrayTy = dyn_cast<ArrayType>(BaseVD->getType().getTypePtr()); 66 // Only apply the transformation to one-dimensional arrays of scalars. 67 if (!ArrayTy || !ArrayTy->getElementType().getTypePtr()->isScalarType()) 68 return true; 69 70 ConsumerInstance->ASEs.push_back(ASE); 71 ConsumerInstance->ValidInstanceNum++; 72 73 return true; 74 } 75 getVarDeclFromExpr(const Expr * E)76const VarDecl *ReplaceArrayAccessWithIndex::IndexCollector::getVarDeclFromExpr( 77 const Expr *E) 78 { 79 TransAssert(E && "NULL Expr!"); 80 const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParenCasts()); 81 if (!DRE) 82 return NULL; 83 const ValueDecl *OrigDecl = DRE->getDecl(); 84 const VarDecl *VD = dyn_cast<VarDecl>(OrigDecl); 85 if (!VD) 86 return NULL; 87 const VarDecl *CanonicalVD = VD->getCanonicalDecl(); 88 return CanonicalVD; 89 } 90 91 92 ~ReplaceArrayAccessWithIndex(void)93ReplaceArrayAccessWithIndex::~ReplaceArrayAccessWithIndex(void) 94 { 95 delete Collector; 96 } 97 98 Initialize(clang::ASTContext & context)99void ReplaceArrayAccessWithIndex::Initialize(clang::ASTContext &context) 100 { 101 Transformation::Initialize(context); 102 Collector = new IndexCollector(this); 103 } 104 HandleTranslationUnit(clang::ASTContext & Ctx)105void ReplaceArrayAccessWithIndex::HandleTranslationUnit(clang::ASTContext &Ctx) 106 { 107 TransAssert(Collector && "NULL Collector"); 108 Collector->TraverseDecl(Ctx.getTranslationUnitDecl()); 109 110 if (QueryInstanceOnly) 111 return; 112 113 if (TransformationCounter > ValidInstanceNum) { 114 TransError = TransMaxInstanceError; 115 return; 116 } 117 118 Ctx.getDiagnostics().setSuppressAllDiagnostics(false); 119 doRewrite(); 120 121 if (Ctx.getDiagnostics().hasErrorOccurred() || 122 Ctx.getDiagnostics().hasFatalErrorOccurred()) 123 TransError = TransInternalError; 124 } 125 doRewrite(void)126void ReplaceArrayAccessWithIndex::doRewrite(void) 127 { 128 ArraySubscriptExpr const *ASE = ASEs[TransformationCounter - 1]; 129 Expr const *Idx = ASE->getIdx(); 130 131 TransAssert(Idx && "Bad Idx!"); 132 133 std::string IdxStr; 134 RewriteHelper->getExprString(Idx, IdxStr); 135 136 QualType ASEType = ASE->getType().getCanonicalType(); 137 QualType IdxType = Idx->getType().getCanonicalType(); 138 139 if (ASEType != IdxType) { 140 IdxStr = std::string("(") + ASEType.getAsString() + std::string(")")+ 141 std::string("(") + IdxStr + std::string(")"); 142 } 143 144 RewriteHelper->replaceExpr(ASE, IdxStr); 145 } 146