//===--- ByteCodeStmtGen.cpp - Code generator for statements ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ByteCodeStmtGen.h"
10 #include "ByteCodeEmitter.h"
11 #include "ByteCodeGenError.h"
12 #include "Context.h"
13 #include "Function.h"
14 #include "PrimType.h"
15 #include "Program.h"
16 #include "State.h"
17 #include "clang/Basic/LLVM.h"
18 
19 using namespace clang;
20 using namespace clang::interp;
21 
22 namespace clang {
23 namespace interp {
24 
25 /// Scope managing label targets.
26 template <class Emitter> class LabelScope {
27 public:
28   virtual ~LabelScope() {  }
29 
30 protected:
31   LabelScope(ByteCodeStmtGen<Emitter> *Ctx) : Ctx(Ctx) {}
32   /// ByteCodeStmtGen instance.
33   ByteCodeStmtGen<Emitter> *Ctx;
34 };
35 
36 /// Sets the context for break/continue statements.
37 template <class Emitter> class LoopScope final : public LabelScope<Emitter> {
38 public:
39   using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
40   using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
41 
42   LoopScope(ByteCodeStmtGen<Emitter> *Ctx, LabelTy BreakLabel,
43             LabelTy ContinueLabel)
44       : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
45         OldContinueLabel(Ctx->ContinueLabel) {
46     this->Ctx->BreakLabel = BreakLabel;
47     this->Ctx->ContinueLabel = ContinueLabel;
48   }
49 
50   ~LoopScope() {
51     this->Ctx->BreakLabel = OldBreakLabel;
52     this->Ctx->ContinueLabel = OldContinueLabel;
53   }
54 
55 private:
56   OptLabelTy OldBreakLabel;
57   OptLabelTy OldContinueLabel;
58 };
59 
60 // Sets the context for a switch scope, mapping labels.
61 template <class Emitter> class SwitchScope final : public LabelScope<Emitter> {
62 public:
63   using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
64   using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
65   using CaseMap = typename ByteCodeStmtGen<Emitter>::CaseMap;
66 
67   SwitchScope(ByteCodeStmtGen<Emitter> *Ctx, CaseMap &&CaseLabels,
68               LabelTy BreakLabel, OptLabelTy DefaultLabel)
69       : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
70         OldDefaultLabel(this->Ctx->DefaultLabel),
71         OldCaseLabels(std::move(this->Ctx->CaseLabels)) {
72     this->Ctx->BreakLabel = BreakLabel;
73     this->Ctx->DefaultLabel = DefaultLabel;
74     this->Ctx->CaseLabels = std::move(CaseLabels);
75   }
76 
77   ~SwitchScope() {
78     this->Ctx->BreakLabel = OldBreakLabel;
79     this->Ctx->DefaultLabel = OldDefaultLabel;
80     this->Ctx->CaseLabels = std::move(OldCaseLabels);
81   }
82 
83 private:
84   OptLabelTy OldBreakLabel;
85   OptLabelTy OldDefaultLabel;
86   CaseMap OldCaseLabels;
87 };
88 
89 } // namespace interp
90 } // namespace clang
91 
92 template <class Emitter>
93 bool ByteCodeStmtGen<Emitter>::visitFunc(const FunctionDecl *F) {
94   // Classify the return type.
95   ReturnType = this->classify(F->getReturnType());
96 
97   // Constructor. Set up field initializers.
98   if (const auto Ctor = dyn_cast<CXXConstructorDecl>(F)) {
99     const RecordDecl *RD = Ctor->getParent();
100     const Record *R = this->getRecord(RD);
101     if (!R)
102       return false;
103 
104     for (const auto *Init : Ctor->inits()) {
105       const Expr *InitExpr = Init->getInit();
106       if (const FieldDecl *Member = Init->getMember()) {
107         const Record::Field *F = R->getField(Member);
108 
109         if (std::optional<PrimType> T = this->classify(InitExpr)) {
110           if (!this->emitThis(InitExpr))
111             return false;
112 
113           if (!this->visit(InitExpr))
114             return false;
115 
116           if (!this->emitInitField(*T, F->Offset, InitExpr))
117             return false;
118 
119           if (!this->emitPopPtr(InitExpr))
120             return false;
121         } else {
122           // Non-primitive case. Get a pointer to the field-to-initialize
123           // on the stack and call visitInitialzer() for it.
124           if (!this->emitThis(InitExpr))
125             return false;
126 
127           if (!this->emitGetPtrField(F->Offset, InitExpr))
128             return false;
129 
130           if (!this->visitInitializer(InitExpr))
131             return false;
132 
133           if (!this->emitPopPtr(InitExpr))
134             return false;
135         }
136       } else if (const Type *Base = Init->getBaseClass()) {
137         // Base class initializer.
138         // Get This Base and call initializer on it.
139         auto *BaseDecl = Base->getAsCXXRecordDecl();
140         assert(BaseDecl);
141         const Record::Base *B = R->getBase(BaseDecl);
142         assert(B);
143         if (!this->emitGetPtrThisBase(B->Offset, InitExpr))
144           return false;
145         if (!this->visitInitializer(InitExpr))
146           return false;
147         if (!this->emitPopPtr(InitExpr))
148           return false;
149       }
150     }
151   }
152 
153   if (const auto *Body = F->getBody())
154     if (!visitStmt(Body))
155       return false;
156 
157   // Emit a guard return to protect against a code path missing one.
158   if (F->getReturnType()->isVoidType())
159     return this->emitRetVoid(SourceInfo{});
160   else
161     return this->emitNoRet(SourceInfo{});
162 }
163 
164 template <class Emitter>
165 bool ByteCodeStmtGen<Emitter>::visitStmt(const Stmt *S) {
166   switch (S->getStmtClass()) {
167   case Stmt::CompoundStmtClass:
168     return visitCompoundStmt(cast<CompoundStmt>(S));
169   case Stmt::DeclStmtClass:
170     return visitDeclStmt(cast<DeclStmt>(S));
171   case Stmt::ReturnStmtClass:
172     return visitReturnStmt(cast<ReturnStmt>(S));
173   case Stmt::IfStmtClass:
174     return visitIfStmt(cast<IfStmt>(S));
175   case Stmt::WhileStmtClass:
176     return visitWhileStmt(cast<WhileStmt>(S));
177   case Stmt::DoStmtClass:
178     return visitDoStmt(cast<DoStmt>(S));
179   case Stmt::ForStmtClass:
180     return visitForStmt(cast<ForStmt>(S));
181   case Stmt::BreakStmtClass:
182     return visitBreakStmt(cast<BreakStmt>(S));
183   case Stmt::ContinueStmtClass:
184     return visitContinueStmt(cast<ContinueStmt>(S));
185   case Stmt::NullStmtClass:
186     return true;
187   default: {
188     if (auto *Exp = dyn_cast<Expr>(S))
189       return this->discard(Exp);
190     return this->bail(S);
191   }
192   }
193 }
194 
195 template <class Emitter>
196 bool ByteCodeStmtGen<Emitter>::visitCompoundStmt(
197     const CompoundStmt *CompoundStmt) {
198   BlockScope<Emitter> Scope(this);
199   for (auto *InnerStmt : CompoundStmt->body())
200     if (!visitStmt(InnerStmt))
201       return false;
202   return true;
203 }
204 
205 template <class Emitter>
206 bool ByteCodeStmtGen<Emitter>::visitDeclStmt(const DeclStmt *DS) {
207   for (auto *D : DS->decls()) {
208     // Variable declarator.
209     if (auto *VD = dyn_cast<VarDecl>(D)) {
210       if (!this->visitVarDecl(VD))
211         return false;
212       continue;
213     }
214 
215     // Decomposition declarator.
216     if (auto *DD = dyn_cast<DecompositionDecl>(D)) {
217       return this->bail(DD);
218     }
219   }
220 
221   return true;
222 }
223 
224 template <class Emitter>
225 bool ByteCodeStmtGen<Emitter>::visitReturnStmt(const ReturnStmt *RS) {
226   if (const Expr *RE = RS->getRetValue()) {
227     ExprScope<Emitter> RetScope(this);
228     if (ReturnType) {
229       // Primitive types are simply returned.
230       if (!this->visit(RE))
231         return false;
232       this->emitCleanup();
233       return this->emitRet(*ReturnType, RS);
234     } else {
235       // RVO - construct the value in the return location.
236       if (!this->emitRVOPtr(RE))
237         return false;
238       if (!this->visitInitializer(RE))
239         return false;
240       if (!this->emitPopPtr(RE))
241         return false;
242 
243       this->emitCleanup();
244       return this->emitRetVoid(RS);
245     }
246   }
247 
248   // Void return.
249   this->emitCleanup();
250   return this->emitRetVoid(RS);
251 }
252 
253 template <class Emitter>
254 bool ByteCodeStmtGen<Emitter>::visitIfStmt(const IfStmt *IS) {
255   BlockScope<Emitter> IfScope(this);
256 
257   if (IS->isNonNegatedConsteval())
258     return visitStmt(IS->getThen());
259   if (IS->isNegatedConsteval())
260     return IS->getElse() ? visitStmt(IS->getElse()) : true;
261 
262   if (auto *CondInit = IS->getInit())
263     if (!visitStmt(IS->getInit()))
264       return false;
265 
266   if (const DeclStmt *CondDecl = IS->getConditionVariableDeclStmt())
267     if (!visitDeclStmt(CondDecl))
268       return false;
269 
270   if (!this->visitBool(IS->getCond()))
271     return false;
272 
273   if (const Stmt *Else = IS->getElse()) {
274     LabelTy LabelElse = this->getLabel();
275     LabelTy LabelEnd = this->getLabel();
276     if (!this->jumpFalse(LabelElse))
277       return false;
278     if (!visitStmt(IS->getThen()))
279       return false;
280     if (!this->jump(LabelEnd))
281       return false;
282     this->emitLabel(LabelElse);
283     if (!visitStmt(Else))
284       return false;
285     this->emitLabel(LabelEnd);
286   } else {
287     LabelTy LabelEnd = this->getLabel();
288     if (!this->jumpFalse(LabelEnd))
289       return false;
290     if (!visitStmt(IS->getThen()))
291       return false;
292     this->emitLabel(LabelEnd);
293   }
294 
295   return true;
296 }
297 
298 template <class Emitter>
299 bool ByteCodeStmtGen<Emitter>::visitWhileStmt(const WhileStmt *S) {
300   const Expr *Cond = S->getCond();
301   const Stmt *Body = S->getBody();
302 
303   LabelTy CondLabel = this->getLabel(); // Label before the condition.
304   LabelTy EndLabel = this->getLabel();  // Label after the loop.
305   LoopScope<Emitter> LS(this, EndLabel, CondLabel);
306 
307   this->emitLabel(CondLabel);
308   if (!this->visitBool(Cond))
309     return false;
310   if (!this->jumpFalse(EndLabel))
311     return false;
312 
313   if (!this->visitStmt(Body))
314     return false;
315   if (!this->jump(CondLabel))
316     return false;
317 
318   this->emitLabel(EndLabel);
319 
320   return true;
321 }
322 
323 template <class Emitter>
324 bool ByteCodeStmtGen<Emitter>::visitDoStmt(const DoStmt *S) {
325   const Expr *Cond = S->getCond();
326   const Stmt *Body = S->getBody();
327 
328   LabelTy StartLabel = this->getLabel();
329   LabelTy EndLabel = this->getLabel();
330   LabelTy CondLabel = this->getLabel();
331   LoopScope<Emitter> LS(this, EndLabel, CondLabel);
332 
333   this->emitLabel(StartLabel);
334   if (!this->visitStmt(Body))
335     return false;
336   this->emitLabel(CondLabel);
337   if (!this->visitBool(Cond))
338     return false;
339   if (!this->jumpTrue(StartLabel))
340     return false;
341   this->emitLabel(EndLabel);
342   return true;
343 }
344 
345 template <class Emitter>
346 bool ByteCodeStmtGen<Emitter>::visitForStmt(const ForStmt *S) {
347   // for (Init; Cond; Inc) { Body }
348   const Stmt *Init = S->getInit();
349   const Expr *Cond = S->getCond();
350   const Expr *Inc = S->getInc();
351   const Stmt *Body = S->getBody();
352 
353   LabelTy EndLabel = this->getLabel();
354   LabelTy CondLabel = this->getLabel();
355   LabelTy IncLabel = this->getLabel();
356   LoopScope<Emitter> LS(this, EndLabel, IncLabel);
357 
358   if (Init && !this->visitStmt(Init))
359     return false;
360   this->emitLabel(CondLabel);
361   if (Cond) {
362     if (!this->visitBool(Cond))
363       return false;
364     if (!this->jumpFalse(EndLabel))
365       return false;
366   }
367   if (Body && !this->visitStmt(Body))
368     return false;
369   this->emitLabel(IncLabel);
370   if (Inc && !this->discard(Inc))
371     return false;
372   if (!this->jump(CondLabel))
373     return false;
374   this->emitLabel(EndLabel);
375   return true;
376 }
377 
378 template <class Emitter>
379 bool ByteCodeStmtGen<Emitter>::visitBreakStmt(const BreakStmt *S) {
380   if (!BreakLabel)
381     return false;
382 
383   return this->jump(*BreakLabel);
384 }
385 
386 template <class Emitter>
387 bool ByteCodeStmtGen<Emitter>::visitContinueStmt(const ContinueStmt *S) {
388   if (!ContinueLabel)
389     return false;
390 
391   return this->jump(*ContinueLabel);
392 }
393 
namespace clang {
namespace interp {

// Explicitly instantiate the statement code generator for ByteCodeEmitter so
// that the member function templates defined in this file are emitted in
// this translation unit.
template class ByteCodeStmtGen<ByteCodeEmitter>;

} // namespace interp
} // namespace clang
401