1 //===- CoroEarly.cpp - Coroutine Early Function Pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This pass lowers coroutine intrinsics that hide the details of the exact
9 // calling convention for coroutine resume and destroy functions and details of
10 // the structure of the coroutine frame.
11 //===----------------------------------------------------------------------===//
12 
13 #include "CoroInternal.h"
14 #include "llvm/IR/CallSite.h"
15 #include "llvm/IR/IRBuilder.h"
16 #include "llvm/IR/InstIterator.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/Pass.h"
19 
20 using namespace llvm;
21 
22 #define DEBUG_TYPE "coro-early"
23 
24 namespace {
25 // Created on demand if the coro-early pass has work to do.
26 class Lowerer : public coro::LowererBase {
27   IRBuilder<> Builder;
28   PointerType *const AnyResumeFnPtrTy;
29   Constant *NoopCoro = nullptr;
30 
31   void lowerResumeOrDestroy(CallSite CS, CoroSubFnInst::ResumeKind);
32   void lowerCoroPromise(CoroPromiseInst *Intrin);
33   void lowerCoroDone(IntrinsicInst *II);
34   void lowerCoroNoop(IntrinsicInst *II);
35 
36 public:
37   Lowerer(Module &M)
38       : LowererBase(M), Builder(Context),
39         AnyResumeFnPtrTy(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
40                                            /*isVarArg=*/false)
41                              ->getPointerTo()) {}
42   bool lowerEarlyIntrinsics(Function &F);
43 };
44 }
45 
46 // Replace a direct call to coro.resume or coro.destroy with an indirect call to
47 // an address returned by coro.subfn.addr intrinsic. This is done so that
48 // CGPassManager recognizes devirtualization when CoroElide pass replaces a call
49 // to coro.subfn.addr with an appropriate function address.
50 void Lowerer::lowerResumeOrDestroy(CallSite CS,
51                                    CoroSubFnInst::ResumeKind Index) {
52   Value *ResumeAddr =
53       makeSubFnCall(CS.getArgOperand(0), Index, CS.getInstruction());
54   CS.setCalledFunction(ResumeAddr);
55   CS.setCallingConv(CallingConv::Fast);
56 }
57 
58 // Coroutine promise field is always at the fixed offset from the beginning of
59 // the coroutine frame. i8* coro.promise(i8*, i1 from) intrinsic adds an offset
60 // to a passed pointer to move from coroutine frame to coroutine promise and
61 // vice versa. Since we don't know exactly which coroutine frame it is, we build
62 // a coroutine frame mock up starting with two function pointers, followed by a
63 // properly aligned coroutine promise field.
64 // TODO: Handle the case when coroutine promise alloca has align override.
65 void Lowerer::lowerCoroPromise(CoroPromiseInst *Intrin) {
66   Value *Operand = Intrin->getArgOperand(0);
67   unsigned Alignement = Intrin->getAlignment();
68   Type *Int8Ty = Builder.getInt8Ty();
69 
70   auto *SampleStruct =
71       StructType::get(Context, {AnyResumeFnPtrTy, AnyResumeFnPtrTy, Int8Ty});
72   const DataLayout &DL = TheModule.getDataLayout();
73   int64_t Offset = alignTo(
74       DL.getStructLayout(SampleStruct)->getElementOffset(2), Alignement);
75   if (Intrin->isFromPromise())
76     Offset = -Offset;
77 
78   Builder.SetInsertPoint(Intrin);
79   Value *Replacement =
80       Builder.CreateConstInBoundsGEP1_32(Int8Ty, Operand, Offset);
81 
82   Intrin->replaceAllUsesWith(Replacement);
83   Intrin->eraseFromParent();
84 }
85 
86 // When a coroutine reaches final suspend point, it zeros out ResumeFnAddr in
87 // the coroutine frame (it is UB to resume from a final suspend point).
88 // The llvm.coro.done intrinsic is used to check whether a coroutine is
89 // suspended at the final suspend point or not.
90 void Lowerer::lowerCoroDone(IntrinsicInst *II) {
91   Value *Operand = II->getArgOperand(0);
92 
93   // ResumeFnAddr is the first pointer sized element of the coroutine frame.
94   static_assert(coro::Shape::SwitchFieldIndex::Resume == 0,
95                 "resume function not at offset zero");
96   auto *FrameTy = Int8Ptr;
97   PointerType *FramePtrTy = FrameTy->getPointerTo();
98 
99   Builder.SetInsertPoint(II);
100   auto *BCI = Builder.CreateBitCast(Operand, FramePtrTy);
101   auto *Load = Builder.CreateLoad(BCI);
102   auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
103 
104   II->replaceAllUsesWith(Cond);
105   II->eraseFromParent();
106 }
107 
108 void Lowerer::lowerCoroNoop(IntrinsicInst *II) {
109   if (!NoopCoro) {
110     LLVMContext &C = Builder.getContext();
111     Module &M = *II->getModule();
112 
113     // Create a noop.frame struct type.
114     StructType *FrameTy = StructType::create(C, "NoopCoro.Frame");
115     auto *FramePtrTy = FrameTy->getPointerTo();
116     auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
117                                    /*isVarArg=*/false);
118     auto *FnPtrTy = FnTy->getPointerTo();
119     FrameTy->setBody({FnPtrTy, FnPtrTy});
120 
121     // Create a Noop function that does nothing.
122     Function *NoopFn =
123         Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
124                          "NoopCoro.ResumeDestroy", &M);
125     NoopFn->setCallingConv(CallingConv::Fast);
126     auto *Entry = BasicBlock::Create(C, "entry", NoopFn);
127     ReturnInst::Create(C, Entry);
128 
129     // Create a constant struct for the frame.
130     Constant* Values[] = {NoopFn, NoopFn};
131     Constant* NoopCoroConst = ConstantStruct::get(FrameTy, Values);
132     NoopCoro = new GlobalVariable(M, NoopCoroConst->getType(), /*isConstant=*/true,
133                                 GlobalVariable::PrivateLinkage, NoopCoroConst,
134                                 "NoopCoro.Frame.Const");
135   }
136 
137   Builder.SetInsertPoint(II);
138   auto *NoopCoroVoidPtr = Builder.CreateBitCast(NoopCoro, Int8Ptr);
139   II->replaceAllUsesWith(NoopCoroVoidPtr);
140   II->eraseFromParent();
141 }
142 
143 // Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate,
144 // as CoroSplit assumes there is exactly one coro.begin. After CoroSplit,
145 // NoDuplicate attribute will be removed from coro.begin otherwise, it will
146 // interfere with inlining.
147 static void setCannotDuplicate(CoroIdInst *CoroId) {
148   for (User *U : CoroId->users())
149     if (auto *CB = dyn_cast<CoroBeginInst>(U))
150       CB->setCannotDuplicate();
151 }
152 
153 bool Lowerer::lowerEarlyIntrinsics(Function &F) {
154   bool Changed = false;
155   CoroIdInst *CoroId = nullptr;
156   SmallVector<CoroFreeInst *, 4> CoroFrees;
157   for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) {
158     Instruction &I = *IB++;
159     if (auto CS = CallSite(&I)) {
160       switch (CS.getIntrinsicID()) {
161       default:
162         continue;
163       case Intrinsic::coro_free:
164         CoroFrees.push_back(cast<CoroFreeInst>(&I));
165         break;
166       case Intrinsic::coro_suspend:
167         // Make sure that final suspend point is not duplicated as CoroSplit
168         // pass expects that there is at most one final suspend point.
169         if (cast<CoroSuspendInst>(&I)->isFinal())
170           CS.setCannotDuplicate();
171         break;
172       case Intrinsic::coro_end:
173         // Make sure that fallthrough coro.end is not duplicated as CoroSplit
174         // pass expects that there is at most one fallthrough coro.end.
175         if (cast<CoroEndInst>(&I)->isFallthrough())
176           CS.setCannotDuplicate();
177         break;
178       case Intrinsic::coro_noop:
179         lowerCoroNoop(cast<IntrinsicInst>(&I));
180         break;
181       case Intrinsic::coro_id:
182         // Mark a function that comes out of the frontend that has a coro.id
183         // with a coroutine attribute.
184         if (auto *CII = cast<CoroIdInst>(&I)) {
185           if (CII->getInfo().isPreSplit()) {
186             F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT);
187             setCannotDuplicate(CII);
188             CII->setCoroutineSelf();
189             CoroId = cast<CoroIdInst>(&I);
190           }
191         }
192         break;
193       case Intrinsic::coro_id_retcon:
194       case Intrinsic::coro_id_retcon_once:
195         F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
196         break;
197       case Intrinsic::coro_resume:
198         lowerResumeOrDestroy(CS, CoroSubFnInst::ResumeIndex);
199         break;
200       case Intrinsic::coro_destroy:
201         lowerResumeOrDestroy(CS, CoroSubFnInst::DestroyIndex);
202         break;
203       case Intrinsic::coro_promise:
204         lowerCoroPromise(cast<CoroPromiseInst>(&I));
205         break;
206       case Intrinsic::coro_done:
207         lowerCoroDone(cast<IntrinsicInst>(&I));
208         break;
209       }
210       Changed = true;
211     }
212   }
213   // Make sure that all CoroFree reference the coro.id intrinsic.
214   // Token type is not exposed through coroutine C/C++ builtins to plain C, so
215   // we allow specifying none and fixing it up here.
216   if (CoroId)
217     for (CoroFreeInst *CF : CoroFrees)
218       CF->setArgOperand(0, CoroId);
219   return Changed;
220 }
221 
222 //===----------------------------------------------------------------------===//
223 //                              Top Level Driver
224 //===----------------------------------------------------------------------===//
225 
226 namespace {
227 
228 struct CoroEarlyLegacy : public FunctionPass {
229   static char ID; // Pass identification, replacement for typeid.
230   CoroEarlyLegacy() : FunctionPass(ID) {
231     initializeCoroEarlyLegacyPass(*PassRegistry::getPassRegistry());
232   }
233 
234   std::unique_ptr<Lowerer> L;
235 
236   // This pass has work to do only if we find intrinsics we are going to lower
237   // in the module.
238   bool doInitialization(Module &M) override {
239     if (coro::declaresIntrinsics(M, {"llvm.coro.id",
240                                      "llvm.coro.id.retcon",
241                                      "llvm.coro.id.retcon.once",
242                                      "llvm.coro.destroy",
243                                      "llvm.coro.done",
244                                      "llvm.coro.end",
245                                      "llvm.coro.noop",
246                                      "llvm.coro.free",
247                                      "llvm.coro.promise",
248                                      "llvm.coro.resume",
249                                      "llvm.coro.suspend"}))
250       L = std::make_unique<Lowerer>(M);
251     return false;
252   }
253 
254   bool runOnFunction(Function &F) override {
255     if (!L)
256       return false;
257 
258     return L->lowerEarlyIntrinsics(F);
259   }
260 
261   void getAnalysisUsage(AnalysisUsage &AU) const override {
262     AU.setPreservesCFG();
263   }
264   StringRef getPassName() const override {
265     return "Lower early coroutine intrinsics";
266   }
267 };
268 }
269 
// Out-of-line definition of the pass identifier declared in CoroEarlyLegacy.
char CoroEarlyLegacy::ID = 0;
// Register the legacy pass under the name "coro-early". This macro presumably
// also defines initializeCoroEarlyLegacyPass(), which the constructor calls.
INITIALIZE_PASS(CoroEarlyLegacy, "coro-early",
                "Lower early coroutine intrinsics", false, false)
273 
274 Pass *llvm::createCoroEarlyLegacyPass() { return new CoroEarlyLegacy(); }
275