1 //===----------------------------------------------------------------------===//
2 //
3 // Copyright (c) 2012, 2013, 2015, 2016, 2017, 2019 The University of Utah
4 // All rights reserved.
5 //
6 // This file is distributed under the University of Illinois Open Source
7 // License. See the file COPYING for details.
8 //
9 //===----------------------------------------------------------------------===//
10
11 #if HAVE_CONFIG_H
12 # include <config.h>
13 #endif
14
15 #include "SimplifyStruct.h"
16
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/Lex/Lexer.h"
19 #include "clang/AST/ASTContext.h"
20 #include "clang/Basic/SourceManager.h"
21 #include "clang/AST/RecordLayout.h"
22
23 #include "TransformationManager.h"
24
25 using namespace clang;
26
// Help text for the "simplify-struct" pass. The pass replaces a struct that
// has exactly one field, where that field is itself a struct, with the
// field's struct type everywhere it is referenced.
static const char *DescriptionMsg =
  "This pass replaces a struct with the type of its only field if it has "
  "exactly one field and that field is a struct, e.g., \n"
  "  struct S1 { \n"
  "    int f1;\n"
  "    int f2;\n"
  "  };\n"
  "  struct S2 { \n"
  "    struct S1 f1;\n"
  "  }; \n"
  "In the above code, struct S2 will be replaced with struct S1, including\n"
  "all its references. \n";
39
40 static RegisterTransformation<SimplifyStruct>
41 Trans("simplify-struct", DescriptionMsg);
42
43 class SimplifyStructCollectionVisitor : public
44 RecursiveASTVisitor<SimplifyStructCollectionVisitor> {
45
46 public:
SimplifyStructCollectionVisitor(SimplifyStruct * Instance)47 explicit SimplifyStructCollectionVisitor(SimplifyStruct *Instance)
48 : ConsumerInstance(Instance)
49 { }
50
51 bool VisitRecordDecl(RecordDecl *RD);
52
53 private:
54
55 SimplifyStruct *ConsumerInstance;
56 };
57
58 class SimplifyStructRewriteVisitor : public
59 RecursiveASTVisitor<SimplifyStructRewriteVisitor> {
60
61 public:
SimplifyStructRewriteVisitor(SimplifyStruct * Instance)62 explicit SimplifyStructRewriteVisitor(SimplifyStruct *Instance)
63 : ConsumerInstance(Instance)
64 { }
65
66 bool VisitVarDecl(VarDecl *VD);
67
68 bool VisitRecordDecl(RecordDecl *RD);
69
70 bool VisitRecordTypeLoc(RecordTypeLoc RTLoc);
71
72 bool VisitMemberExpr(MemberExpr *ME);
73
74 private:
75
76 SimplifyStruct *ConsumerInstance;
77 };
78
VisitRecordDecl(RecordDecl * RD)79 bool SimplifyStructCollectionVisitor::VisitRecordDecl(RecordDecl *RD)
80 {
81 if (ConsumerInstance->isInIncludedFile(RD))
82 return true;
83 if (!RD->isThisDeclarationADefinition() || !RD->isStruct())
84 return true;
85 if (ConsumerInstance->isSpecialRecordDecl(RD))
86 return true;
87 if (RD->isInvalidDecl())
88 return true;
89
90 const ASTRecordLayout &Info = ConsumerInstance->Context->getASTRecordLayout(RD);
91 unsigned Count = Info.getFieldCount();
92 if (Count != 1)
93 return true;
94
95 const FieldDecl *FD = *(RD->field_begin());
96 TransAssert(FD && "Invalid FieldDecl!");
97 const Type *Ty = FD->getType().getTypePtr();
98 const RecordType *RT = Ty->getAs<RecordType>();
99 if (!RT)
100 return true;
101
102 const RecordDecl *NestedRD = RT->getDecl();
103 if (NestedRD->getNameAsString() == "")
104 return true;
105
106 ConsumerInstance->ValidInstanceNum++;
107 if (ConsumerInstance->TransformationCounter ==
108 ConsumerInstance->ValidInstanceNum) {
109 ConsumerInstance->TheRecordDecl =
110 dyn_cast<RecordDecl>(RD->getCanonicalDecl());
111 ConsumerInstance->ReplacingRecordDecl =
112 dyn_cast<RecordDecl>(NestedRD->getCanonicalDecl());
113 ConsumerInstance->setQualifierFlags(FD);
114 }
115 return true;
116 }
117
VisitVarDecl(VarDecl * VD)118 bool SimplifyStructRewriteVisitor::VisitVarDecl(VarDecl *VD)
119 {
120 if (!ConsumerInstance->ConstField && !ConsumerInstance->VolatileField)
121 return true;
122
123 QualType QT = VD->getType();
124 const Type *Ty = QT.getTypePtr();
125 const RecordType *RT = Ty->getAs<RecordType>();
126 if (!RT)
127 return true;
128
129 const RecordDecl *RD = RT->getDecl();
130 if (RD != ConsumerInstance->TheRecordDecl)
131 return true;
132
133 SourceLocation LocStart = VD->getBeginLoc();
134 void *LocPtr = LocStart.getPtrEncoding();
135 if (ConsumerInstance->VisitedVarDeclLocs.count(LocPtr))
136 return true;
137
138 ConsumerInstance->VisitedVarDeclLocs.insert(LocPtr);
139
140 std::string QualStr = "";
141 if (ConsumerInstance->ConstField && !QT.isConstQualified())
142 QualStr += "const ";
143 if (ConsumerInstance->VolatileField && !QT.isVolatileQualified())
144 QualStr += "volatile ";
145 ConsumerInstance->TheRewriter.InsertText(LocStart, QualStr);
146 return true;
147 }
148
VisitRecordDecl(RecordDecl * RD)149 bool SimplifyStructRewriteVisitor::VisitRecordDecl(RecordDecl *RD)
150 {
151 RecordDecl *CanonicalRD = dyn_cast<RecordDecl>(RD->getCanonicalDecl());
152 if (CanonicalRD != ConsumerInstance->TheRecordDecl)
153 return true;
154
155 SourceLocation LocStart = RD->getLocation();
156 void *LocPtr = LocStart.getPtrEncoding();
157 if (!ConsumerInstance->VisitedLocs.count(LocPtr)) {
158 ConsumerInstance->VisitedLocs.insert(LocPtr);
159 std::string RPName =
160 ConsumerInstance->ReplacingRecordDecl->getNameAsString();
161 if (RD->getNameAsString() != "") {
162 ConsumerInstance->RewriteHelper->replaceRecordDeclName(RD, RPName);
163 }
164 else {
165 ConsumerInstance->TheRewriter.ReplaceText(LocStart,
166 /*struct*/6, "struct " + RPName);
167 }
168 }
169
170 if (!RD->isThisDeclarationADefinition())
171 return true;
172
173 SourceLocation LBLoc = RD->getBraceRange().getBegin();
174 SourceLocation RBLoc = RD->getBraceRange().getEnd();
175 ConsumerInstance->TheRewriter.RemoveText(SourceRange(LBLoc, RBLoc));
176 return true;
177 }
178
VisitRecordTypeLoc(RecordTypeLoc RTLoc)179 bool SimplifyStructRewriteVisitor::VisitRecordTypeLoc(RecordTypeLoc RTLoc)
180 {
181 const Type *Ty = RTLoc.getTypePtr();
182 if (Ty->isUnionType())
183 return true;
184
185 RecordDecl *RD = RTLoc.getDecl();
186 RecordDecl *CanonicalRD = dyn_cast<RecordDecl>(RD->getCanonicalDecl());
187 if (CanonicalRD != ConsumerInstance->TheRecordDecl)
188 return true;
189
190 SourceLocation LocStart = RTLoc.getBeginLoc();
191 void *LocPtr = LocStart.getPtrEncoding();
192 if (ConsumerInstance->VisitedLocs.count(LocPtr))
193 return true;
194 ConsumerInstance->VisitedLocs.insert(LocPtr);
195
196 ConsumerInstance->RewriteHelper->replaceRecordType(RTLoc,
197 ConsumerInstance->ReplacingRecordDecl->getNameAsString());
198 return true;
199 }
200
VisitMemberExpr(MemberExpr * ME)201 bool SimplifyStructRewriteVisitor::VisitMemberExpr(MemberExpr *ME)
202 {
203 ValueDecl *OrigDecl = ME->getMemberDecl();
204 FieldDecl *FD = dyn_cast<FieldDecl>(OrigDecl);
205
206 if (!FD) {
207 // in C++, getMemberDecl returns a CXXMethodDecl.
208 if (TransformationManager::isCXXLangOpt())
209 return true;
210 TransAssert(0 && "Bad FD!\n");
211 }
212
213 RecordDecl *RD = FD->getParent();
214 if (!RD || (dyn_cast<RecordDecl>(RD->getCanonicalDecl()) !=
215 ConsumerInstance->TheRecordDecl))
216 return true;
217
218 const Type *T = FD->getType().getTypePtr();
219 const RecordType *RT = T->getAs<RecordType>();
220 TransAssert(RT && "Invalid record type!");
221 const RecordDecl *ReplacingRD =
222 dyn_cast<RecordDecl>(RT->getDecl()->getCanonicalDecl());
223 (void)ReplacingRD;
224 TransAssert((ReplacingRD == ConsumerInstance->ReplacingRecordDecl) &&
225 "Unmatched Replacing RD!");
226
227 SourceLocation LocEnd = ME->getEndLoc();
228 if (LocEnd.isMacroID()) {
229 LocEnd = ConsumerInstance->SrcManager->getSpellingLoc(LocEnd);
230 }
231 SourceLocation ArrowPos =
232 Lexer::findLocationAfterToken(LocEnd,
233 tok::arrow,
234 *(ConsumerInstance->SrcManager),
235 ConsumerInstance->Context->getLangOpts(),
236 /*SkipTrailingWhitespaceAndNewLine=*/true);
237 SourceLocation PeriodPos =
238 Lexer::findLocationAfterToken(LocEnd,
239 tok::period,
240 *(ConsumerInstance->SrcManager),
241 ConsumerInstance->Context->getLangOpts(),
242 /*SkipTrailingWhitespaceAndNewLine=*/true);
243
244 std::string ES;
245 ConsumerInstance->RewriteHelper->getExprString(ME, ES);
246
247 // no more MemberExpr upon this ME
248 if (ArrowPos.isInvalid() && PeriodPos.isInvalid()) {
249 SourceLocation StartLoc = ME->getBeginLoc();
250 size_t Pos;
251
252 if (ME->isArrow()) {
253 Pos = ES.find("->");
254 }
255 else {
256 Pos = ES.find(".");
257 }
258 TransAssert((Pos != std::string::npos) && "Cannot find arrow or dot!");
259 StartLoc = StartLoc.getLocWithOffset(Pos);
260
261 int Off = ES.length() - Pos;
262 ConsumerInstance->TheRewriter.RemoveText(StartLoc, Off);
263 return true;
264 }
265
266 SourceLocation StartLoc = ME->getMemberLoc();
267 const char *StartBuf =
268 ConsumerInstance->SrcManager->getCharacterData(StartLoc);
269 const char *EndBuf;
270 if (ArrowPos.isValid()) {
271 EndBuf = ConsumerInstance->SrcManager->getCharacterData(ArrowPos);
272 EndBuf++;
273 }
274 else {
275 TransAssert(PeriodPos.isValid() && "Bad dot position!");
276 EndBuf = ConsumerInstance->SrcManager->getCharacterData(PeriodPos);
277 }
278 int Off = EndBuf - StartBuf;
279 ConsumerInstance->TheRewriter.RemoveText(StartLoc, Off);
280 return true;
281 }
282
Initialize(ASTContext & context)283 void SimplifyStruct::Initialize(ASTContext &context)
284 {
285 Transformation::Initialize(context);
286 CollectionVisitor = new SimplifyStructCollectionVisitor(this);
287 RewriteVisitor = new SimplifyStructRewriteVisitor(this);
288 }
289
HandleTranslationUnit(ASTContext & Ctx)290 void SimplifyStruct::HandleTranslationUnit(ASTContext &Ctx)
291 {
292 // ISSUE: not well tested on CXX code, so currently disable this pass for CXX
293 if (TransformationManager::isCXXLangOpt()) {
294 ValidInstanceNum = 0;
295 TransError = TransMaxInstanceError;
296 return;
297 }
298
299 TransAssert(CollectionVisitor && "NULL CollectionVisitor!");
300 CollectionVisitor->TraverseDecl(Ctx.getTranslationUnitDecl());
301 if (QueryInstanceOnly) {
302 return;
303 }
304
305 if (TransformationCounter > ValidInstanceNum) {
306 TransError = TransMaxInstanceError;
307 return;
308 }
309
310 Ctx.getDiagnostics().setSuppressAllDiagnostics(false);
311
312 TransAssert(RewriteVisitor && "NULL RewriteVisitor!");
313 TransAssert(TheRecordDecl && "NULL TheRecordDecl!");
314 TransAssert(ReplacingRecordDecl && "NULL ReplacingRecordDecl!");
315 RewriteVisitor->TraverseDecl(Ctx.getTranslationUnitDecl());
316
317 if (Ctx.getDiagnostics().hasErrorOccurred() ||
318 Ctx.getDiagnostics().hasFatalErrorOccurred())
319 TransError = TransInternalError;
320 }
321
setQualifierFlags(const FieldDecl * FD)322 void SimplifyStruct::setQualifierFlags(const FieldDecl *FD)
323 {
324 QualType QT = FD->getType();
325 if (QT.isConstQualified())
326 ConstField = true;
327 if (QT.isVolatileQualified())
328 VolatileField = true;
329 }
330
~SimplifyStruct(void)331 SimplifyStruct::~SimplifyStruct(void)
332 {
333 delete CollectionVisitor;
334 delete RewriteVisitor;
335 }
336
337