1 //===----------------------------------------------------------------------===//
2 //
3 // Copyright (c) 2016, 2017 The University of Utah
4 // All rights reserved.
5 //
6 // This file is distributed under the University of Illinois Open Source
7 // License.  See the file COPYING for details.
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #if HAVE_CONFIG_H
12 #  include <config.h>
13 #endif
14 
15 #include "ReplaceArrayAccessWithIndex.h"
16 
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/ASTContext.h"
19 #include "clang/Basic/SourceManager.h"
20 
21 #include "TransformationManager.h"
22 
23 #include <iostream>
24 
25 using namespace clang;
26 
27 
28 static const char *Description =
29   "Replace array accesses with the index expression.";
30 
31 static RegisterTransformation<ReplaceArrayAccessWithIndex>
32 Trans("replace-array-access-with-index", Description);
33 
34 class ReplaceArrayAccessWithIndex::IndexCollector
35   : public RecursiveASTVisitor<ReplaceArrayAccessWithIndex::IndexCollector>
36 {
37 public:
38   explicit IndexCollector(ReplaceArrayAccessWithIndex *instance);
39   bool VisitArraySubscriptExpr(ArraySubscriptExpr *ASE);
40 
41 private:
42   const VarDecl *getVarDeclFromExpr(const Expr *E);
43   ReplaceArrayAccessWithIndex *ConsumerInstance;
44 };
45 
IndexCollector(ReplaceArrayAccessWithIndex * instance)46 ReplaceArrayAccessWithIndex::IndexCollector::IndexCollector(
47   ReplaceArrayAccessWithIndex *instance)
48   : ConsumerInstance(instance)
49 {
50   // No further initialization needed.
51 }
52 
VisitArraySubscriptExpr(ArraySubscriptExpr * ASE)53 bool ReplaceArrayAccessWithIndex::IndexCollector::VisitArraySubscriptExpr(
54   ArraySubscriptExpr *ASE)
55 {
56   // Skip expressions in included files.
57   if (ConsumerInstance->isInIncludedFile(ASE))
58     return true;
59 
60   const VarDecl *BaseVD = getVarDeclFromExpr(ASE->getBase());
61 
62   if (!BaseVD)
63     return true;
64 
65   ArrayType const *ArrayTy = dyn_cast<ArrayType>(BaseVD->getType().getTypePtr());
66   // Only apply the transformation to one-dimensional arrays of scalars.
67   if (!ArrayTy || !ArrayTy->getElementType().getTypePtr()->isScalarType())
68     return true;
69 
70   ConsumerInstance->ASEs.push_back(ASE);
71   ConsumerInstance->ValidInstanceNum++;
72 
73   return true;
74 }
75 
getVarDeclFromExpr(const Expr * E)76 const VarDecl *ReplaceArrayAccessWithIndex::IndexCollector::getVarDeclFromExpr(
77   const Expr *E)
78 {
79   TransAssert(E && "NULL Expr!");
80   const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParenCasts());
81   if (!DRE)
82     return NULL;
83   const ValueDecl *OrigDecl = DRE->getDecl();
84   const VarDecl *VD = dyn_cast<VarDecl>(OrigDecl);
85   if (!VD)
86     return NULL;
87   const VarDecl *CanonicalVD = VD->getCanonicalDecl();
88   return CanonicalVD;
89 }
90 
91 
92 
~ReplaceArrayAccessWithIndex(void)93 ReplaceArrayAccessWithIndex::~ReplaceArrayAccessWithIndex(void)
94 {
95   delete Collector;
96 }
97 
98 
Initialize(clang::ASTContext & context)99 void ReplaceArrayAccessWithIndex::Initialize(clang::ASTContext &context)
100 {
101   Transformation::Initialize(context);
102   Collector = new IndexCollector(this);
103 }
104 
HandleTranslationUnit(clang::ASTContext & Ctx)105 void ReplaceArrayAccessWithIndex::HandleTranslationUnit(clang::ASTContext &Ctx)
106 {
107   TransAssert(Collector && "NULL Collector");
108   Collector->TraverseDecl(Ctx.getTranslationUnitDecl());
109 
110   if (QueryInstanceOnly)
111     return;
112 
113   if (TransformationCounter > ValidInstanceNum) {
114     TransError = TransMaxInstanceError;
115     return;
116   }
117 
118   Ctx.getDiagnostics().setSuppressAllDiagnostics(false);
119   doRewrite();
120 
121   if (Ctx.getDiagnostics().hasErrorOccurred() ||
122       Ctx.getDiagnostics().hasFatalErrorOccurred())
123     TransError = TransInternalError;
124 }
125 
doRewrite(void)126 void ReplaceArrayAccessWithIndex::doRewrite(void)
127 {
128   ArraySubscriptExpr const *ASE = ASEs[TransformationCounter - 1];
129   Expr const *Idx = ASE->getIdx();
130 
131   TransAssert(Idx && "Bad Idx!");
132 
133   std::string IdxStr;
134   RewriteHelper->getExprString(Idx, IdxStr);
135 
136   QualType ASEType = ASE->getType().getCanonicalType();
137   QualType IdxType = Idx->getType().getCanonicalType();
138 
139   if (ASEType != IdxType) {
140     IdxStr = std::string("(") + ASEType.getAsString() + std::string(")")+
141       std::string("(") + IdxStr + std::string(")");
142   }
143 
144   RewriteHelper->replaceExpr(ASE, IdxStr);
145 }
146