//===--- ByteCodeStmtGen.cpp - Code generator for statements ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ByteCodeStmtGen.h"
10 #include "ByteCodeEmitter.h"
11 #include "ByteCodeGenError.h"
12 #include "Context.h"
13 #include "Function.h"
14 #include "PrimType.h"
15 #include "Program.h"
16 #include "State.h"
17 #include "clang/Basic/LLVM.h"
18 
19 using namespace clang;
20 using namespace clang::interp;
21 
22 namespace clang {
23 namespace interp {
24 
25 /// Scope managing label targets.
26 template <class Emitter> class LabelScope {
27 public:
28   virtual ~LabelScope() {  }
29 
30 protected:
31   LabelScope(ByteCodeStmtGen<Emitter> *Ctx) : Ctx(Ctx) {}
32   /// ByteCodeStmtGen instance.
33   ByteCodeStmtGen<Emitter> *Ctx;
34 };
35 
36 /// Sets the context for break/continue statements.
37 template <class Emitter> class LoopScope final : public LabelScope<Emitter> {
38 public:
39   using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
40   using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
41 
42   LoopScope(ByteCodeStmtGen<Emitter> *Ctx, LabelTy BreakLabel,
43             LabelTy ContinueLabel)
44       : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
45         OldContinueLabel(Ctx->ContinueLabel) {
46     this->Ctx->BreakLabel = BreakLabel;
47     this->Ctx->ContinueLabel = ContinueLabel;
48   }
49 
50   ~LoopScope() {
51     this->Ctx->BreakLabel = OldBreakLabel;
52     this->Ctx->ContinueLabel = OldContinueLabel;
53   }
54 
55 private:
56   OptLabelTy OldBreakLabel;
57   OptLabelTy OldContinueLabel;
58 };
59 
60 // Sets the context for a switch scope, mapping labels.
61 template <class Emitter> class SwitchScope final : public LabelScope<Emitter> {
62 public:
63   using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
64   using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
65   using CaseMap = typename ByteCodeStmtGen<Emitter>::CaseMap;
66 
67   SwitchScope(ByteCodeStmtGen<Emitter> *Ctx, CaseMap &&CaseLabels,
68               LabelTy BreakLabel, OptLabelTy DefaultLabel)
69       : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
70         OldDefaultLabel(this->Ctx->DefaultLabel),
71         OldCaseLabels(std::move(this->Ctx->CaseLabels)) {
72     this->Ctx->BreakLabel = BreakLabel;
73     this->Ctx->DefaultLabel = DefaultLabel;
74     this->Ctx->CaseLabels = std::move(CaseLabels);
75   }
76 
77   ~SwitchScope() {
78     this->Ctx->BreakLabel = OldBreakLabel;
79     this->Ctx->DefaultLabel = OldDefaultLabel;
80     this->Ctx->CaseLabels = std::move(OldCaseLabels);
81   }
82 
83 private:
84   OptLabelTy OldBreakLabel;
85   OptLabelTy OldDefaultLabel;
86   CaseMap OldCaseLabels;
87 };
88 
89 } // namespace interp
90 } // namespace clang
91 
92 template <class Emitter>
93 bool ByteCodeStmtGen<Emitter>::visitFunc(const FunctionDecl *F) {
94   // Classify the return type.
95   ReturnType = this->classify(F->getReturnType());
96 
97   // Constructor. Set up field initializers.
98   if (const auto *Ctor = dyn_cast<CXXConstructorDecl>(F)) {
99     const RecordDecl *RD = Ctor->getParent();
100     const Record *R = this->getRecord(RD);
101     if (!R)
102       return false;
103 
104     for (const auto *Init : Ctor->inits()) {
105       // Scope needed for the initializers.
106       BlockScope<Emitter> Scope(this);
107 
108       const Expr *InitExpr = Init->getInit();
109       if (const FieldDecl *Member = Init->getMember()) {
110         const Record::Field *F = R->getField(Member);
111 
112         if (std::optional<PrimType> T = this->classify(InitExpr)) {
113           if (!this->visit(InitExpr))
114             return false;
115 
116           if (!this->emitInitThisField(*T, F->Offset, InitExpr))
117             return false;
118         } else {
119           // Non-primitive case. Get a pointer to the field-to-initialize
120           // on the stack and call visitInitialzer() for it.
121           if (!this->emitThis(InitExpr))
122             return false;
123 
124           if (!this->emitGetPtrField(F->Offset, InitExpr))
125             return false;
126 
127           if (!this->visitInitializer(InitExpr))
128             return false;
129 
130           if (!this->emitPopPtr(InitExpr))
131             return false;
132         }
133       } else if (const Type *Base = Init->getBaseClass()) {
134         // Base class initializer.
135         // Get This Base and call initializer on it.
136         const auto *BaseDecl = Base->getAsCXXRecordDecl();
137         assert(BaseDecl);
138         const Record::Base *B = R->getBase(BaseDecl);
139         assert(B);
140         if (!this->emitGetPtrThisBase(B->Offset, InitExpr))
141           return false;
142         if (!this->visitInitializer(InitExpr))
143           return false;
144         if (!this->emitPopPtr(InitExpr))
145           return false;
146       }
147     }
148   }
149 
150   if (const auto *Body = F->getBody())
151     if (!visitStmt(Body))
152       return false;
153 
154   // Emit a guard return to protect against a code path missing one.
155   if (F->getReturnType()->isVoidType())
156     return this->emitRetVoid(SourceInfo{});
157   else
158     return this->emitNoRet(SourceInfo{});
159 }
160 
161 template <class Emitter>
162 bool ByteCodeStmtGen<Emitter>::visitStmt(const Stmt *S) {
163   switch (S->getStmtClass()) {
164   case Stmt::CompoundStmtClass:
165     return visitCompoundStmt(cast<CompoundStmt>(S));
166   case Stmt::DeclStmtClass:
167     return visitDeclStmt(cast<DeclStmt>(S));
168   case Stmt::ReturnStmtClass:
169     return visitReturnStmt(cast<ReturnStmt>(S));
170   case Stmt::IfStmtClass:
171     return visitIfStmt(cast<IfStmt>(S));
172   case Stmt::WhileStmtClass:
173     return visitWhileStmt(cast<WhileStmt>(S));
174   case Stmt::DoStmtClass:
175     return visitDoStmt(cast<DoStmt>(S));
176   case Stmt::ForStmtClass:
177     return visitForStmt(cast<ForStmt>(S));
178   case Stmt::CXXForRangeStmtClass:
179     return visitCXXForRangeStmt(cast<CXXForRangeStmt>(S));
180   case Stmt::BreakStmtClass:
181     return visitBreakStmt(cast<BreakStmt>(S));
182   case Stmt::ContinueStmtClass:
183     return visitContinueStmt(cast<ContinueStmt>(S));
184   case Stmt::SwitchStmtClass:
185     return visitSwitchStmt(cast<SwitchStmt>(S));
186   case Stmt::CaseStmtClass:
187     return visitCaseStmt(cast<CaseStmt>(S));
188   case Stmt::DefaultStmtClass:
189     return visitDefaultStmt(cast<DefaultStmt>(S));
190   case Stmt::NullStmtClass:
191     return true;
192   default: {
193     if (auto *Exp = dyn_cast<Expr>(S))
194       return this->discard(Exp);
195     return this->bail(S);
196   }
197   }
198 }
199 
200 /// Visits the given statment without creating a variable
201 /// scope for it in case it is a compound statement.
202 template <class Emitter>
203 bool ByteCodeStmtGen<Emitter>::visitLoopBody(const Stmt *S) {
204   if (isa<NullStmt>(S))
205     return true;
206 
207   if (const auto *CS = dyn_cast<CompoundStmt>(S)) {
208     for (auto *InnerStmt : CS->body())
209       if (!visitStmt(InnerStmt))
210         return false;
211     return true;
212   }
213 
214   return this->visitStmt(S);
215 }
216 
217 template <class Emitter>
218 bool ByteCodeStmtGen<Emitter>::visitCompoundStmt(
219     const CompoundStmt *CompoundStmt) {
220   BlockScope<Emitter> Scope(this);
221   for (auto *InnerStmt : CompoundStmt->body())
222     if (!visitStmt(InnerStmt))
223       return false;
224   return true;
225 }
226 
227 template <class Emitter>
228 bool ByteCodeStmtGen<Emitter>::visitDeclStmt(const DeclStmt *DS) {
229   for (auto *D : DS->decls()) {
230     if (isa<StaticAssertDecl, TagDecl, TypedefNameDecl>(D))
231       continue;
232 
233     const auto *VD = dyn_cast<VarDecl>(D);
234     if (!VD)
235       return false;
236     if (!this->visitVarDecl(VD))
237       return false;
238   }
239 
240   return true;
241 }
242 
243 template <class Emitter>
244 bool ByteCodeStmtGen<Emitter>::visitReturnStmt(const ReturnStmt *RS) {
245   if (const Expr *RE = RS->getRetValue()) {
246     ExprScope<Emitter> RetScope(this);
247     if (ReturnType) {
248       // Primitive types are simply returned.
249       if (!this->visit(RE))
250         return false;
251       this->emitCleanup();
252       return this->emitRet(*ReturnType, RS);
253     } else {
254       // RVO - construct the value in the return location.
255       if (!this->emitRVOPtr(RE))
256         return false;
257       if (!this->visitInitializer(RE))
258         return false;
259       if (!this->emitPopPtr(RE))
260         return false;
261 
262       this->emitCleanup();
263       return this->emitRetVoid(RS);
264     }
265   }
266 
267   // Void return.
268   this->emitCleanup();
269   return this->emitRetVoid(RS);
270 }
271 
272 template <class Emitter>
273 bool ByteCodeStmtGen<Emitter>::visitIfStmt(const IfStmt *IS) {
274   BlockScope<Emitter> IfScope(this);
275 
276   if (IS->isNonNegatedConsteval())
277     return visitStmt(IS->getThen());
278   if (IS->isNegatedConsteval())
279     return IS->getElse() ? visitStmt(IS->getElse()) : true;
280 
281   if (auto *CondInit = IS->getInit())
282     if (!visitStmt(IS->getInit()))
283       return false;
284 
285   if (const DeclStmt *CondDecl = IS->getConditionVariableDeclStmt())
286     if (!visitDeclStmt(CondDecl))
287       return false;
288 
289   if (!this->visitBool(IS->getCond()))
290     return false;
291 
292   if (const Stmt *Else = IS->getElse()) {
293     LabelTy LabelElse = this->getLabel();
294     LabelTy LabelEnd = this->getLabel();
295     if (!this->jumpFalse(LabelElse))
296       return false;
297     if (!visitStmt(IS->getThen()))
298       return false;
299     if (!this->jump(LabelEnd))
300       return false;
301     this->emitLabel(LabelElse);
302     if (!visitStmt(Else))
303       return false;
304     this->emitLabel(LabelEnd);
305   } else {
306     LabelTy LabelEnd = this->getLabel();
307     if (!this->jumpFalse(LabelEnd))
308       return false;
309     if (!visitStmt(IS->getThen()))
310       return false;
311     this->emitLabel(LabelEnd);
312   }
313 
314   return true;
315 }
316 
317 template <class Emitter>
318 bool ByteCodeStmtGen<Emitter>::visitWhileStmt(const WhileStmt *S) {
319   const Expr *Cond = S->getCond();
320   const Stmt *Body = S->getBody();
321 
322   LabelTy CondLabel = this->getLabel(); // Label before the condition.
323   LabelTy EndLabel = this->getLabel();  // Label after the loop.
324   LoopScope<Emitter> LS(this, EndLabel, CondLabel);
325 
326   this->emitLabel(CondLabel);
327   if (!this->visitBool(Cond))
328     return false;
329   if (!this->jumpFalse(EndLabel))
330     return false;
331 
332   LocalScope<Emitter> Scope(this);
333   {
334     DestructorScope<Emitter> DS(Scope);
335     if (!this->visitLoopBody(Body))
336       return false;
337   }
338 
339   if (!this->jump(CondLabel))
340     return false;
341   this->emitLabel(EndLabel);
342 
343   return true;
344 }
345 
346 template <class Emitter>
347 bool ByteCodeStmtGen<Emitter>::visitDoStmt(const DoStmt *S) {
348   const Expr *Cond = S->getCond();
349   const Stmt *Body = S->getBody();
350 
351   LabelTy StartLabel = this->getLabel();
352   LabelTy EndLabel = this->getLabel();
353   LabelTy CondLabel = this->getLabel();
354   LoopScope<Emitter> LS(this, EndLabel, CondLabel);
355   LocalScope<Emitter> Scope(this);
356 
357   this->emitLabel(StartLabel);
358   {
359     DestructorScope<Emitter> DS(Scope);
360 
361     if (!this->visitLoopBody(Body))
362       return false;
363     this->emitLabel(CondLabel);
364     if (!this->visitBool(Cond))
365       return false;
366   }
367   if (!this->jumpTrue(StartLabel))
368     return false;
369 
370   this->emitLabel(EndLabel);
371   return true;
372 }
373 
374 template <class Emitter>
375 bool ByteCodeStmtGen<Emitter>::visitForStmt(const ForStmt *S) {
376   // for (Init; Cond; Inc) { Body }
377   const Stmt *Init = S->getInit();
378   const Expr *Cond = S->getCond();
379   const Expr *Inc = S->getInc();
380   const Stmt *Body = S->getBody();
381 
382   LabelTy EndLabel = this->getLabel();
383   LabelTy CondLabel = this->getLabel();
384   LabelTy IncLabel = this->getLabel();
385   LoopScope<Emitter> LS(this, EndLabel, IncLabel);
386   LocalScope<Emitter> Scope(this);
387 
388   if (Init && !this->visitStmt(Init))
389     return false;
390   this->emitLabel(CondLabel);
391   if (Cond) {
392     if (!this->visitBool(Cond))
393       return false;
394     if (!this->jumpFalse(EndLabel))
395       return false;
396   }
397 
398   {
399     DestructorScope<Emitter> DS(Scope);
400 
401     if (Body && !this->visitLoopBody(Body))
402       return false;
403     this->emitLabel(IncLabel);
404     if (Inc && !this->discard(Inc))
405       return false;
406   }
407 
408   if (!this->jump(CondLabel))
409     return false;
410   this->emitLabel(EndLabel);
411   return true;
412 }
413 
414 template <class Emitter>
415 bool ByteCodeStmtGen<Emitter>::visitCXXForRangeStmt(const CXXForRangeStmt *S) {
416   const Stmt *Init = S->getInit();
417   const Expr *Cond = S->getCond();
418   const Expr *Inc = S->getInc();
419   const Stmt *Body = S->getBody();
420   const Stmt *BeginStmt = S->getBeginStmt();
421   const Stmt *RangeStmt = S->getRangeStmt();
422   const Stmt *EndStmt = S->getEndStmt();
423   const VarDecl *LoopVar = S->getLoopVariable();
424 
425   LabelTy EndLabel = this->getLabel();
426   LabelTy CondLabel = this->getLabel();
427   LabelTy IncLabel = this->getLabel();
428   LoopScope<Emitter> LS(this, EndLabel, IncLabel);
429 
430   // Emit declarations needed in the loop.
431   if (Init && !this->visitStmt(Init))
432     return false;
433   if (!this->visitStmt(RangeStmt))
434     return false;
435   if (!this->visitStmt(BeginStmt))
436     return false;
437   if (!this->visitStmt(EndStmt))
438     return false;
439 
440   // Now the condition as well as the loop variable assignment.
441   this->emitLabel(CondLabel);
442   if (!this->visitBool(Cond))
443     return false;
444   if (!this->jumpFalse(EndLabel))
445     return false;
446 
447   if (!this->visitVarDecl(LoopVar))
448     return false;
449 
450   // Body.
451   LocalScope<Emitter> Scope(this);
452   {
453     DestructorScope<Emitter> DS(Scope);
454 
455     if (!this->visitLoopBody(Body))
456       return false;
457     this->emitLabel(IncLabel);
458     if (!this->discard(Inc))
459       return false;
460   }
461   if (!this->jump(CondLabel))
462     return false;
463 
464   this->emitLabel(EndLabel);
465   return true;
466 }
467 
468 template <class Emitter>
469 bool ByteCodeStmtGen<Emitter>::visitBreakStmt(const BreakStmt *S) {
470   if (!BreakLabel)
471     return false;
472 
473   this->VarScope->emitDestructors();
474   return this->jump(*BreakLabel);
475 }
476 
477 template <class Emitter>
478 bool ByteCodeStmtGen<Emitter>::visitContinueStmt(const ContinueStmt *S) {
479   if (!ContinueLabel)
480     return false;
481 
482   this->VarScope->emitDestructors();
483   return this->jump(*ContinueLabel);
484 }
485 
486 template <class Emitter>
487 bool ByteCodeStmtGen<Emitter>::visitSwitchStmt(const SwitchStmt *S) {
488   const Expr *Cond = S->getCond();
489   PrimType CondT = this->classifyPrim(Cond->getType());
490 
491   LabelTy EndLabel = this->getLabel();
492   OptLabelTy DefaultLabel = std::nullopt;
493   unsigned CondVar = this->allocateLocalPrimitive(Cond, CondT, true, false);
494 
495   if (const auto *CondInit = S->getInit())
496     if (!visitStmt(CondInit))
497       return false;
498 
499   // Initialize condition variable.
500   if (!this->visit(Cond))
501     return false;
502   if (!this->emitSetLocal(CondT, CondVar, S))
503     return false;
504 
505   CaseMap CaseLabels;
506   // Create labels and comparison ops for all case statements.
507   for (const SwitchCase *SC = S->getSwitchCaseList(); SC;
508        SC = SC->getNextSwitchCase()) {
509     if (const auto *CS = dyn_cast<CaseStmt>(SC)) {
510       // FIXME: Implement ranges.
511       if (CS->caseStmtIsGNURange())
512         return false;
513       CaseLabels[SC] = this->getLabel();
514 
515       const Expr *Value = CS->getLHS();
516       PrimType ValueT = this->classifyPrim(Value->getType());
517 
518       // Compare the case statement's value to the switch condition.
519       if (!this->emitGetLocal(CondT, CondVar, CS))
520         return false;
521       if (!this->visit(Value))
522         return false;
523 
524       // Compare and jump to the case label.
525       if (!this->emitEQ(ValueT, S))
526         return false;
527       if (!this->jumpTrue(CaseLabels[CS]))
528         return false;
529     } else {
530       assert(!DefaultLabel);
531       DefaultLabel = this->getLabel();
532     }
533   }
534 
535   // If none of the conditions above were true, fall through to the default
536   // statement or jump after the switch statement.
537   if (DefaultLabel) {
538     if (!this->jump(*DefaultLabel))
539       return false;
540   } else {
541     if (!this->jump(EndLabel))
542       return false;
543   }
544 
545   SwitchScope<Emitter> SS(this, std::move(CaseLabels), EndLabel, DefaultLabel);
546   if (!this->visitStmt(S->getBody()))
547     return false;
548   this->emitLabel(EndLabel);
549   return true;
550 }
551 
552 template <class Emitter>
553 bool ByteCodeStmtGen<Emitter>::visitCaseStmt(const CaseStmt *S) {
554   this->emitLabel(CaseLabels[S]);
555   return this->visitStmt(S->getSubStmt());
556 }
557 
558 template <class Emitter>
559 bool ByteCodeStmtGen<Emitter>::visitDefaultStmt(const DefaultStmt *S) {
560   this->emitLabel(*DefaultLabel);
561   return this->visitStmt(S->getSubStmt());
562 }
563 
564 namespace clang {
565 namespace interp {
566 
/// Explicit instantiation of the statement generator for ByteCodeEmitter;
/// the member templates above are defined in this file.
template class ByteCodeStmtGen<ByteCodeEmitter>;
568 
569 } // namespace interp
570 } // namespace clang
571