//===--- ByteCodeStmtGen.cpp - Code generator for statements ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "ByteCodeStmtGen.h"
10 #include "ByteCodeEmitter.h"
11 #include "ByteCodeGenError.h"
12 #include "Context.h"
13 #include "Function.h"
14 #include "PrimType.h"
15 #include "Program.h"
16 #include "State.h"
17 #include "clang/Basic/LLVM.h"
18
19 using namespace clang;
20 using namespace clang::interp;
21
22 namespace clang {
23 namespace interp {
24
25 /// Scope managing label targets.
26 template <class Emitter> class LabelScope {
27 public:
~LabelScope()28 virtual ~LabelScope() { }
29
30 protected:
LabelScope(ByteCodeStmtGen<Emitter> * Ctx)31 LabelScope(ByteCodeStmtGen<Emitter> *Ctx) : Ctx(Ctx) {}
32 /// ByteCodeStmtGen instance.
33 ByteCodeStmtGen<Emitter> *Ctx;
34 };
35
36 /// Sets the context for break/continue statements.
37 template <class Emitter> class LoopScope final : public LabelScope<Emitter> {
38 public:
39 using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
40 using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
41
LoopScope(ByteCodeStmtGen<Emitter> * Ctx,LabelTy BreakLabel,LabelTy ContinueLabel)42 LoopScope(ByteCodeStmtGen<Emitter> *Ctx, LabelTy BreakLabel,
43 LabelTy ContinueLabel)
44 : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
45 OldContinueLabel(Ctx->ContinueLabel) {
46 this->Ctx->BreakLabel = BreakLabel;
47 this->Ctx->ContinueLabel = ContinueLabel;
48 }
49
~LoopScope()50 ~LoopScope() {
51 this->Ctx->BreakLabel = OldBreakLabel;
52 this->Ctx->ContinueLabel = OldContinueLabel;
53 }
54
55 private:
56 OptLabelTy OldBreakLabel;
57 OptLabelTy OldContinueLabel;
58 };
59
60 // Sets the context for a switch scope, mapping labels.
61 template <class Emitter> class SwitchScope final : public LabelScope<Emitter> {
62 public:
63 using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
64 using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
65 using CaseMap = typename ByteCodeStmtGen<Emitter>::CaseMap;
66
SwitchScope(ByteCodeStmtGen<Emitter> * Ctx,CaseMap && CaseLabels,LabelTy BreakLabel,OptLabelTy DefaultLabel)67 SwitchScope(ByteCodeStmtGen<Emitter> *Ctx, CaseMap &&CaseLabels,
68 LabelTy BreakLabel, OptLabelTy DefaultLabel)
69 : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
70 OldDefaultLabel(this->Ctx->DefaultLabel),
71 OldCaseLabels(std::move(this->Ctx->CaseLabels)) {
72 this->Ctx->BreakLabel = BreakLabel;
73 this->Ctx->DefaultLabel = DefaultLabel;
74 this->Ctx->CaseLabels = std::move(CaseLabels);
75 }
76
~SwitchScope()77 ~SwitchScope() {
78 this->Ctx->BreakLabel = OldBreakLabel;
79 this->Ctx->DefaultLabel = OldDefaultLabel;
80 this->Ctx->CaseLabels = std::move(OldCaseLabels);
81 }
82
83 private:
84 OptLabelTy OldBreakLabel;
85 OptLabelTy OldDefaultLabel;
86 CaseMap OldCaseLabels;
87 };
88
89 } // namespace interp
90 } // namespace clang
91
92 template <class Emitter>
visitFunc(const FunctionDecl * F)93 bool ByteCodeStmtGen<Emitter>::visitFunc(const FunctionDecl *F) {
94 // Classify the return type.
95 ReturnType = this->classify(F->getReturnType());
96
97 // Constructor. Set up field initializers.
98 if (const auto Ctor = dyn_cast<CXXConstructorDecl>(F)) {
99 const RecordDecl *RD = Ctor->getParent();
100 const Record *R = this->getRecord(RD);
101 if (!R)
102 return false;
103
104 for (const auto *Init : Ctor->inits()) {
105 const Expr *InitExpr = Init->getInit();
106 if (const FieldDecl *Member = Init->getMember()) {
107 const Record::Field *F = R->getField(Member);
108
109 if (std::optional<PrimType> T = this->classify(InitExpr)) {
110 if (!this->emitThis(InitExpr))
111 return false;
112
113 if (!this->visit(InitExpr))
114 return false;
115
116 if (!this->emitInitField(*T, F->Offset, InitExpr))
117 return false;
118
119 if (!this->emitPopPtr(InitExpr))
120 return false;
121 } else {
122 // Non-primitive case. Get a pointer to the field-to-initialize
123 // on the stack and call visitInitialzer() for it.
124 if (!this->emitThis(InitExpr))
125 return false;
126
127 if (!this->emitGetPtrField(F->Offset, InitExpr))
128 return false;
129
130 if (!this->visitInitializer(InitExpr))
131 return false;
132
133 if (!this->emitPopPtr(InitExpr))
134 return false;
135 }
136 } else if (const Type *Base = Init->getBaseClass()) {
137 // Base class initializer.
138 // Get This Base and call initializer on it.
139 auto *BaseDecl = Base->getAsCXXRecordDecl();
140 assert(BaseDecl);
141 const Record::Base *B = R->getBase(BaseDecl);
142 assert(B);
143 if (!this->emitGetPtrThisBase(B->Offset, InitExpr))
144 return false;
145 if (!this->visitInitializer(InitExpr))
146 return false;
147 if (!this->emitPopPtr(InitExpr))
148 return false;
149 }
150 }
151 }
152
153 if (const auto *Body = F->getBody())
154 if (!visitStmt(Body))
155 return false;
156
157 // Emit a guard return to protect against a code path missing one.
158 if (F->getReturnType()->isVoidType())
159 return this->emitRetVoid(SourceInfo{});
160 else
161 return this->emitNoRet(SourceInfo{});
162 }
163
164 template <class Emitter>
visitStmt(const Stmt * S)165 bool ByteCodeStmtGen<Emitter>::visitStmt(const Stmt *S) {
166 switch (S->getStmtClass()) {
167 case Stmt::CompoundStmtClass:
168 return visitCompoundStmt(cast<CompoundStmt>(S));
169 case Stmt::DeclStmtClass:
170 return visitDeclStmt(cast<DeclStmt>(S));
171 case Stmt::ReturnStmtClass:
172 return visitReturnStmt(cast<ReturnStmt>(S));
173 case Stmt::IfStmtClass:
174 return visitIfStmt(cast<IfStmt>(S));
175 case Stmt::WhileStmtClass:
176 return visitWhileStmt(cast<WhileStmt>(S));
177 case Stmt::DoStmtClass:
178 return visitDoStmt(cast<DoStmt>(S));
179 case Stmt::ForStmtClass:
180 return visitForStmt(cast<ForStmt>(S));
181 case Stmt::BreakStmtClass:
182 return visitBreakStmt(cast<BreakStmt>(S));
183 case Stmt::ContinueStmtClass:
184 return visitContinueStmt(cast<ContinueStmt>(S));
185 case Stmt::NullStmtClass:
186 return true;
187 default: {
188 if (auto *Exp = dyn_cast<Expr>(S))
189 return this->discard(Exp);
190 return this->bail(S);
191 }
192 }
193 }
194
195 template <class Emitter>
visitCompoundStmt(const CompoundStmt * CompoundStmt)196 bool ByteCodeStmtGen<Emitter>::visitCompoundStmt(
197 const CompoundStmt *CompoundStmt) {
198 BlockScope<Emitter> Scope(this);
199 for (auto *InnerStmt : CompoundStmt->body())
200 if (!visitStmt(InnerStmt))
201 return false;
202 return true;
203 }
204
205 template <class Emitter>
visitDeclStmt(const DeclStmt * DS)206 bool ByteCodeStmtGen<Emitter>::visitDeclStmt(const DeclStmt *DS) {
207 for (auto *D : DS->decls()) {
208 // Variable declarator.
209 if (auto *VD = dyn_cast<VarDecl>(D)) {
210 if (!this->visitVarDecl(VD))
211 return false;
212 continue;
213 }
214
215 // Decomposition declarator.
216 if (auto *DD = dyn_cast<DecompositionDecl>(D)) {
217 return this->bail(DD);
218 }
219 }
220
221 return true;
222 }
223
224 template <class Emitter>
visitReturnStmt(const ReturnStmt * RS)225 bool ByteCodeStmtGen<Emitter>::visitReturnStmt(const ReturnStmt *RS) {
226 if (const Expr *RE = RS->getRetValue()) {
227 ExprScope<Emitter> RetScope(this);
228 if (ReturnType) {
229 // Primitive types are simply returned.
230 if (!this->visit(RE))
231 return false;
232 this->emitCleanup();
233 return this->emitRet(*ReturnType, RS);
234 } else {
235 // RVO - construct the value in the return location.
236 if (!this->emitRVOPtr(RE))
237 return false;
238 if (!this->visitInitializer(RE))
239 return false;
240 if (!this->emitPopPtr(RE))
241 return false;
242
243 this->emitCleanup();
244 return this->emitRetVoid(RS);
245 }
246 }
247
248 // Void return.
249 this->emitCleanup();
250 return this->emitRetVoid(RS);
251 }
252
253 template <class Emitter>
visitIfStmt(const IfStmt * IS)254 bool ByteCodeStmtGen<Emitter>::visitIfStmt(const IfStmt *IS) {
255 BlockScope<Emitter> IfScope(this);
256
257 if (IS->isNonNegatedConsteval())
258 return visitStmt(IS->getThen());
259 if (IS->isNegatedConsteval())
260 return IS->getElse() ? visitStmt(IS->getElse()) : true;
261
262 if (auto *CondInit = IS->getInit())
263 if (!visitStmt(IS->getInit()))
264 return false;
265
266 if (const DeclStmt *CondDecl = IS->getConditionVariableDeclStmt())
267 if (!visitDeclStmt(CondDecl))
268 return false;
269
270 if (!this->visitBool(IS->getCond()))
271 return false;
272
273 if (const Stmt *Else = IS->getElse()) {
274 LabelTy LabelElse = this->getLabel();
275 LabelTy LabelEnd = this->getLabel();
276 if (!this->jumpFalse(LabelElse))
277 return false;
278 if (!visitStmt(IS->getThen()))
279 return false;
280 if (!this->jump(LabelEnd))
281 return false;
282 this->emitLabel(LabelElse);
283 if (!visitStmt(Else))
284 return false;
285 this->emitLabel(LabelEnd);
286 } else {
287 LabelTy LabelEnd = this->getLabel();
288 if (!this->jumpFalse(LabelEnd))
289 return false;
290 if (!visitStmt(IS->getThen()))
291 return false;
292 this->emitLabel(LabelEnd);
293 }
294
295 return true;
296 }
297
298 template <class Emitter>
visitWhileStmt(const WhileStmt * S)299 bool ByteCodeStmtGen<Emitter>::visitWhileStmt(const WhileStmt *S) {
300 const Expr *Cond = S->getCond();
301 const Stmt *Body = S->getBody();
302
303 LabelTy CondLabel = this->getLabel(); // Label before the condition.
304 LabelTy EndLabel = this->getLabel(); // Label after the loop.
305 LoopScope<Emitter> LS(this, EndLabel, CondLabel);
306
307 this->emitLabel(CondLabel);
308 if (!this->visitBool(Cond))
309 return false;
310 if (!this->jumpFalse(EndLabel))
311 return false;
312
313 if (!this->visitStmt(Body))
314 return false;
315 if (!this->jump(CondLabel))
316 return false;
317
318 this->emitLabel(EndLabel);
319
320 return true;
321 }
322
323 template <class Emitter>
visitDoStmt(const DoStmt * S)324 bool ByteCodeStmtGen<Emitter>::visitDoStmt(const DoStmt *S) {
325 const Expr *Cond = S->getCond();
326 const Stmt *Body = S->getBody();
327
328 LabelTy StartLabel = this->getLabel();
329 LabelTy EndLabel = this->getLabel();
330 LabelTy CondLabel = this->getLabel();
331 LoopScope<Emitter> LS(this, EndLabel, CondLabel);
332
333 this->emitLabel(StartLabel);
334 if (!this->visitStmt(Body))
335 return false;
336 this->emitLabel(CondLabel);
337 if (!this->visitBool(Cond))
338 return false;
339 if (!this->jumpTrue(StartLabel))
340 return false;
341 this->emitLabel(EndLabel);
342 return true;
343 }
344
345 template <class Emitter>
visitForStmt(const ForStmt * S)346 bool ByteCodeStmtGen<Emitter>::visitForStmt(const ForStmt *S) {
347 // for (Init; Cond; Inc) { Body }
348 const Stmt *Init = S->getInit();
349 const Expr *Cond = S->getCond();
350 const Expr *Inc = S->getInc();
351 const Stmt *Body = S->getBody();
352
353 LabelTy EndLabel = this->getLabel();
354 LabelTy CondLabel = this->getLabel();
355 LabelTy IncLabel = this->getLabel();
356 LoopScope<Emitter> LS(this, EndLabel, IncLabel);
357
358 if (Init && !this->visitStmt(Init))
359 return false;
360 this->emitLabel(CondLabel);
361 if (Cond) {
362 if (!this->visitBool(Cond))
363 return false;
364 if (!this->jumpFalse(EndLabel))
365 return false;
366 }
367 if (Body && !this->visitStmt(Body))
368 return false;
369 this->emitLabel(IncLabel);
370 if (Inc && !this->discard(Inc))
371 return false;
372 if (!this->jump(CondLabel))
373 return false;
374 this->emitLabel(EndLabel);
375 return true;
376 }
377
378 template <class Emitter>
visitBreakStmt(const BreakStmt * S)379 bool ByteCodeStmtGen<Emitter>::visitBreakStmt(const BreakStmt *S) {
380 if (!BreakLabel)
381 return false;
382
383 return this->jump(*BreakLabel);
384 }
385
386 template <class Emitter>
visitContinueStmt(const ContinueStmt * S)387 bool ByteCodeStmtGen<Emitter>::visitContinueStmt(const ContinueStmt *S) {
388 if (!ContinueLabel)
389 return false;
390
391 return this->jump(*ContinueLabel);
392 }
393
namespace clang {
namespace interp {

// The member functions of ByteCodeStmtGen are defined out of line in this
// file. This explicit instantiation definition emits them for the
// ByteCodeEmitter backend so that other translation units can link to them.
template class ByteCodeStmtGen<ByteCodeEmitter>;

} // namespace interp
} // namespace clang
401