1 //===----------------------------------------------------------------------===//
2 //
3 // Copyright (c) 2012, 2013, 2015, 2016, 2017, 2019 The University of Utah
4 // All rights reserved.
5 //
6 // This file is distributed under the University of Illinois Open Source
7 // License.  See the file COPYING for details.
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #if HAVE_CONFIG_H
12 #  include <config.h>
13 #endif
14 
15 #include "SimplifyStruct.h"
16 
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/Lex/Lexer.h"
19 #include "clang/AST/ASTContext.h"
20 #include "clang/Basic/SourceManager.h"
21 #include "clang/AST/RecordLayout.h"
22 
23 #include "TransformationManager.h"
24 
25 using namespace clang;
26 
// Human-readable description of the "simplify-struct" pass; it is handed to
// RegisterTransformation together with the pass name.
static const char *DescriptionMsg =
  "This pass replaces a struct with the struct type of its only field, "
  "if it has exactly one field and that field is itself a struct, e.g.,\n"
  "  struct S1 {\n"
  "    int f1;\n"
  "    int f2;\n"
  "  };\n"
  "  struct S2 {\n"
  "    struct S1 f1;\n"
  "  };\n"
  "In the above code, struct S2 will be replaced with struct S1, including\n"
  "all its references.\n";
39 
40 static RegisterTransformation<SimplifyStruct>
41          Trans("simplify-struct", DescriptionMsg);
42 
43 class SimplifyStructCollectionVisitor : public
44   RecursiveASTVisitor<SimplifyStructCollectionVisitor> {
45 
46 public:
SimplifyStructCollectionVisitor(SimplifyStruct * Instance)47   explicit SimplifyStructCollectionVisitor(SimplifyStruct *Instance)
48     : ConsumerInstance(Instance)
49   { }
50 
51   bool VisitRecordDecl(RecordDecl *RD);
52 
53 private:
54 
55   SimplifyStruct *ConsumerInstance;
56 };
57 
58 class SimplifyStructRewriteVisitor : public
59   RecursiveASTVisitor<SimplifyStructRewriteVisitor> {
60 
61 public:
SimplifyStructRewriteVisitor(SimplifyStruct * Instance)62   explicit SimplifyStructRewriteVisitor(SimplifyStruct *Instance)
63     : ConsumerInstance(Instance)
64   { }
65 
66   bool VisitVarDecl(VarDecl *VD);
67 
68   bool VisitRecordDecl(RecordDecl *RD);
69 
70   bool VisitRecordTypeLoc(RecordTypeLoc RTLoc);
71 
72   bool VisitMemberExpr(MemberExpr *ME);
73 
74 private:
75 
76   SimplifyStruct *ConsumerInstance;
77 };
78 
VisitRecordDecl(RecordDecl * RD)79 bool SimplifyStructCollectionVisitor::VisitRecordDecl(RecordDecl *RD)
80 {
81   if (ConsumerInstance->isInIncludedFile(RD))
82     return true;
83   if (!RD->isThisDeclarationADefinition() || !RD->isStruct())
84     return true;
85   if (ConsumerInstance->isSpecialRecordDecl(RD))
86     return true;
87   if (RD->isInvalidDecl())
88     return true;
89 
90   const ASTRecordLayout &Info = ConsumerInstance->Context->getASTRecordLayout(RD);
91   unsigned Count = Info.getFieldCount();
92   if (Count != 1)
93     return true;
94 
95   const FieldDecl *FD = *(RD->field_begin());
96   TransAssert(FD && "Invalid FieldDecl!");
97   const Type *Ty = FD->getType().getTypePtr();
98   const RecordType *RT = Ty->getAs<RecordType>();
99   if (!RT)
100     return true;
101 
102   const RecordDecl *NestedRD = RT->getDecl();
103   if (NestedRD->getNameAsString() == "")
104     return true;
105 
106   ConsumerInstance->ValidInstanceNum++;
107   if (ConsumerInstance->TransformationCounter ==
108       ConsumerInstance->ValidInstanceNum) {
109     ConsumerInstance->TheRecordDecl =
110       dyn_cast<RecordDecl>(RD->getCanonicalDecl());
111     ConsumerInstance->ReplacingRecordDecl =
112       dyn_cast<RecordDecl>(NestedRD->getCanonicalDecl());
113     ConsumerInstance->setQualifierFlags(FD);
114   }
115   return true;
116 }
117 
VisitVarDecl(VarDecl * VD)118 bool SimplifyStructRewriteVisitor::VisitVarDecl(VarDecl *VD)
119 {
120   if (!ConsumerInstance->ConstField && !ConsumerInstance->VolatileField)
121     return true;
122 
123   QualType QT = VD->getType();
124   const Type *Ty = QT.getTypePtr();
125   const RecordType *RT = Ty->getAs<RecordType>();
126   if (!RT)
127     return true;
128 
129   const RecordDecl *RD = RT->getDecl();
130   if (RD != ConsumerInstance->TheRecordDecl)
131     return true;
132 
133   SourceLocation LocStart = VD->getBeginLoc();
134   void *LocPtr = LocStart.getPtrEncoding();
135   if (ConsumerInstance->VisitedVarDeclLocs.count(LocPtr))
136     return true;
137 
138   ConsumerInstance->VisitedVarDeclLocs.insert(LocPtr);
139 
140   std::string QualStr = "";
141   if (ConsumerInstance->ConstField && !QT.isConstQualified())
142     QualStr += "const ";
143   if (ConsumerInstance->VolatileField && !QT.isVolatileQualified())
144     QualStr += "volatile ";
145   ConsumerInstance->TheRewriter.InsertText(LocStart, QualStr);
146   return true;
147 }
148 
VisitRecordDecl(RecordDecl * RD)149 bool SimplifyStructRewriteVisitor::VisitRecordDecl(RecordDecl *RD)
150 {
151   RecordDecl *CanonicalRD = dyn_cast<RecordDecl>(RD->getCanonicalDecl());
152   if (CanonicalRD != ConsumerInstance->TheRecordDecl)
153     return true;
154 
155   SourceLocation LocStart = RD->getLocation();
156   void *LocPtr = LocStart.getPtrEncoding();
157   if (!ConsumerInstance->VisitedLocs.count(LocPtr)) {
158     ConsumerInstance->VisitedLocs.insert(LocPtr);
159     std::string RPName =
160       ConsumerInstance->ReplacingRecordDecl->getNameAsString();
161     if (RD->getNameAsString() != "") {
162       ConsumerInstance->RewriteHelper->replaceRecordDeclName(RD, RPName);
163     }
164     else {
165       ConsumerInstance->TheRewriter.ReplaceText(LocStart,
166         /*struct*/6, "struct " + RPName);
167     }
168   }
169 
170   if (!RD->isThisDeclarationADefinition())
171     return true;
172 
173   SourceLocation LBLoc = RD->getBraceRange().getBegin();
174   SourceLocation RBLoc = RD->getBraceRange().getEnd();
175   ConsumerInstance->TheRewriter.RemoveText(SourceRange(LBLoc, RBLoc));
176   return true;
177 }
178 
VisitRecordTypeLoc(RecordTypeLoc RTLoc)179 bool SimplifyStructRewriteVisitor::VisitRecordTypeLoc(RecordTypeLoc RTLoc)
180 {
181   const Type *Ty = RTLoc.getTypePtr();
182   if (Ty->isUnionType())
183     return true;
184 
185   RecordDecl *RD = RTLoc.getDecl();
186   RecordDecl *CanonicalRD = dyn_cast<RecordDecl>(RD->getCanonicalDecl());
187   if (CanonicalRD != ConsumerInstance->TheRecordDecl)
188     return true;
189 
190   SourceLocation LocStart = RTLoc.getBeginLoc();
191   void *LocPtr = LocStart.getPtrEncoding();
192   if (ConsumerInstance->VisitedLocs.count(LocPtr))
193     return true;
194   ConsumerInstance->VisitedLocs.insert(LocPtr);
195 
196   ConsumerInstance->RewriteHelper->replaceRecordType(RTLoc,
197     ConsumerInstance->ReplacingRecordDecl->getNameAsString());
198   return true;
199 }
200 
VisitMemberExpr(MemberExpr * ME)201 bool SimplifyStructRewriteVisitor::VisitMemberExpr(MemberExpr *ME)
202 {
203   ValueDecl *OrigDecl = ME->getMemberDecl();
204   FieldDecl *FD = dyn_cast<FieldDecl>(OrigDecl);
205 
206   if (!FD) {
207     // in C++, getMemberDecl returns a CXXMethodDecl.
208     if (TransformationManager::isCXXLangOpt())
209       return true;
210     TransAssert(0 && "Bad FD!\n");
211   }
212 
213   RecordDecl *RD = FD->getParent();
214   if (!RD || (dyn_cast<RecordDecl>(RD->getCanonicalDecl()) !=
215               ConsumerInstance->TheRecordDecl))
216     return true;
217 
218   const Type *T = FD->getType().getTypePtr();
219   const RecordType *RT = T->getAs<RecordType>();
220   TransAssert(RT && "Invalid record type!");
221   const RecordDecl *ReplacingRD =
222     dyn_cast<RecordDecl>(RT->getDecl()->getCanonicalDecl());
223   (void)ReplacingRD;
224   TransAssert((ReplacingRD == ConsumerInstance->ReplacingRecordDecl) &&
225     "Unmatched Replacing RD!");
226 
227   SourceLocation LocEnd = ME->getEndLoc();
228   if (LocEnd.isMacroID()) {
229     LocEnd = ConsumerInstance->SrcManager->getSpellingLoc(LocEnd);
230   }
231   SourceLocation ArrowPos =
232       Lexer::findLocationAfterToken(LocEnd,
233                                     tok::arrow,
234                                     *(ConsumerInstance->SrcManager),
235                                     ConsumerInstance->Context->getLangOpts(),
236                                     /*SkipTrailingWhitespaceAndNewLine=*/true);
237   SourceLocation PeriodPos =
238       Lexer::findLocationAfterToken(LocEnd,
239                                     tok::period,
240                                     *(ConsumerInstance->SrcManager),
241                                     ConsumerInstance->Context->getLangOpts(),
242                                     /*SkipTrailingWhitespaceAndNewLine=*/true);
243 
244   std::string ES;
245   ConsumerInstance->RewriteHelper->getExprString(ME, ES);
246 
247   // no more MemberExpr upon this ME
248   if (ArrowPos.isInvalid() && PeriodPos.isInvalid()) {
249     SourceLocation StartLoc = ME->getBeginLoc();
250     size_t Pos;
251 
252     if (ME->isArrow()) {
253       Pos = ES.find("->");
254     }
255     else {
256       Pos = ES.find(".");
257     }
258     TransAssert((Pos != std::string::npos) && "Cannot find arrow or dot!");
259     StartLoc = StartLoc.getLocWithOffset(Pos);
260 
261     int Off = ES.length() - Pos;
262     ConsumerInstance->TheRewriter.RemoveText(StartLoc, Off);
263     return true;
264   }
265 
266   SourceLocation StartLoc = ME->getMemberLoc();
267   const char *StartBuf =
268     ConsumerInstance->SrcManager->getCharacterData(StartLoc);
269   const char *EndBuf;
270   if (ArrowPos.isValid()) {
271     EndBuf = ConsumerInstance->SrcManager->getCharacterData(ArrowPos);
272     EndBuf++;
273   }
274   else {
275     TransAssert(PeriodPos.isValid() && "Bad dot position!");
276     EndBuf = ConsumerInstance->SrcManager->getCharacterData(PeriodPos);
277   }
278   int Off = EndBuf - StartBuf;
279   ConsumerInstance->TheRewriter.RemoveText(StartLoc, Off);
280   return true;
281 }
282 
Initialize(ASTContext & context)283 void SimplifyStruct::Initialize(ASTContext &context)
284 {
285   Transformation::Initialize(context);
286   CollectionVisitor = new SimplifyStructCollectionVisitor(this);
287   RewriteVisitor = new SimplifyStructRewriteVisitor(this);
288 }
289 
HandleTranslationUnit(ASTContext & Ctx)290 void SimplifyStruct::HandleTranslationUnit(ASTContext &Ctx)
291 {
292   // ISSUE: not well tested on CXX code, so currently disable this pass for CXX
293   if (TransformationManager::isCXXLangOpt()) {
294     ValidInstanceNum = 0;
295     TransError = TransMaxInstanceError;
296     return;
297   }
298 
299   TransAssert(CollectionVisitor && "NULL CollectionVisitor!");
300   CollectionVisitor->TraverseDecl(Ctx.getTranslationUnitDecl());
301   if (QueryInstanceOnly) {
302     return;
303   }
304 
305   if (TransformationCounter > ValidInstanceNum) {
306     TransError = TransMaxInstanceError;
307     return;
308   }
309 
310   Ctx.getDiagnostics().setSuppressAllDiagnostics(false);
311 
312   TransAssert(RewriteVisitor && "NULL RewriteVisitor!");
313   TransAssert(TheRecordDecl && "NULL TheRecordDecl!");
314   TransAssert(ReplacingRecordDecl && "NULL ReplacingRecordDecl!");
315   RewriteVisitor->TraverseDecl(Ctx.getTranslationUnitDecl());
316 
317   if (Ctx.getDiagnostics().hasErrorOccurred() ||
318       Ctx.getDiagnostics().hasFatalErrorOccurred())
319     TransError = TransInternalError;
320 }
321 
setQualifierFlags(const FieldDecl * FD)322 void SimplifyStruct::setQualifierFlags(const FieldDecl *FD)
323 {
324   QualType QT = FD->getType();
325   if (QT.isConstQualified())
326     ConstField = true;
327   if (QT.isVolatileQualified())
328     VolatileField = true;
329 }
330 
~SimplifyStruct(void)331 SimplifyStruct::~SimplifyStruct(void)
332 {
333   delete CollectionVisitor;
334   delete RewriteVisitor;
335 }
336 
337