1 //===- Coroutines.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the common infrastructure for Coroutine Passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Transforms/Coroutines.h"
14 #include "CoroInstr.h"
15 #include "CoroInternal.h"
16 #include "llvm-c/Transforms/Coroutines.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/Analysis/CallGraph.h"
20 #include "llvm/Analysis/CallGraphSCCPass.h"
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/DerivedTypes.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/InstIterator.h"
26 #include "llvm/IR/Instructions.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Intrinsics.h"
29 #include "llvm/IR/LegacyPassManager.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/InitializePasses.h"
33 #include "llvm/Support/Casting.h"
34 #include "llvm/Support/ErrorHandling.h"
35 #include "llvm/Transforms/IPO.h"
36 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
37 #include "llvm/Transforms/Utils/Local.h"
38 #include <cassert>
39 #include <cstddef>
40 #include <utility>
41 
42 using namespace llvm;
43 
44 void llvm::initializeCoroutines(PassRegistry &Registry) {
45   initializeCoroEarlyLegacyPass(Registry);
46   initializeCoroSplitLegacyPass(Registry);
47   initializeCoroElideLegacyPass(Registry);
48   initializeCoroCleanupLegacyPass(Registry);
49 }
50 
51 static void addCoroutineOpt0Passes(const PassManagerBuilder &Builder,
52                                    legacy::PassManagerBase &PM) {
53   PM.add(createCoroSplitLegacyPass());
54   PM.add(createCoroElideLegacyPass());
55 
56   PM.add(createBarrierNoopPass());
57   PM.add(createCoroCleanupLegacyPass());
58 }
59 
60 static void addCoroutineEarlyPasses(const PassManagerBuilder &Builder,
61                                     legacy::PassManagerBase &PM) {
62   PM.add(createCoroEarlyLegacyPass());
63 }
64 
65 static void addCoroutineScalarOptimizerPasses(const PassManagerBuilder &Builder,
66                                               legacy::PassManagerBase &PM) {
67   PM.add(createCoroElideLegacyPass());
68 }
69 
70 static void addCoroutineSCCPasses(const PassManagerBuilder &Builder,
71                                   legacy::PassManagerBase &PM) {
72   PM.add(createCoroSplitLegacyPass(Builder.OptLevel != 0));
73 }
74 
75 static void addCoroutineOptimizerLastPasses(const PassManagerBuilder &Builder,
76                                             legacy::PassManagerBase &PM) {
77   PM.add(createCoroCleanupLegacyPass());
78 }
79 
80 void llvm::addCoroutinePassesToExtensionPoints(PassManagerBuilder &Builder) {
81   Builder.addExtension(PassManagerBuilder::EP_EarlyAsPossible,
82                        addCoroutineEarlyPasses);
83   Builder.addExtension(PassManagerBuilder::EP_EnabledOnOptLevel0,
84                        addCoroutineOpt0Passes);
85   Builder.addExtension(PassManagerBuilder::EP_CGSCCOptimizerLate,
86                        addCoroutineSCCPasses);
87   Builder.addExtension(PassManagerBuilder::EP_ScalarOptimizerLate,
88                        addCoroutineScalarOptimizerPasses);
89   Builder.addExtension(PassManagerBuilder::EP_OptimizerLast,
90                        addCoroutineOptimizerLastPasses);
91 }
92 
// Construct the lowerer base class and initialize its members: the module and
// its context, the i8* type, the `void (i8*)` resume-function type, and an i8*
// null constant.
// NOTE: ResumeFnType and NullPtr are built from Int8Ptr, so this initializer
// list relies on Int8Ptr being declared before them in the class (member
// initializers run in declaration order, not in the order written here).
coro::LowererBase::LowererBase(Module &M)
    : TheModule(M), Context(M.getContext()),
      Int8Ptr(Type::getInt8PtrTy(Context)),
      ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
                                     /*isVarArg=*/false)),
      NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
100 
101 // Creates a sequence of instructions to obtain a resume function address using
102 // llvm.coro.subfn.addr. It generates the following sequence:
103 //
104 //    call i8* @llvm.coro.subfn.addr(i8* %Arg, i8 %index)
105 //    bitcast i8* %2 to void(i8*)*
106 
107 Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
108                                         Instruction *InsertPt) {
109   auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
110   auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
111 
112   assert(Index >= CoroSubFnInst::IndexFirst &&
113          Index < CoroSubFnInst::IndexLast &&
114          "makeSubFnCall: Index value out of range");
115   auto *Call = CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt);
116 
117   auto *Bitcast =
118       new BitCastInst(Call, ResumeFnType->getPointerTo(), "", InsertPt);
119   return Bitcast;
120 }
121 
#ifndef NDEBUG
// Return true if Name is one of the coroutine intrinsic names in the table
// below. Compiled only into assertion-enabled builds; declaresIntrinsics uses
// it to check that callers only pass coroutine intrinsic names.
// When a new llvm.coro.* intrinsic is added, it should also be added here.
static bool isCoroutineIntrinsicName(StringRef Name) {
  // NOTE: Must be sorted! (presumably because lookupLLVMIntrinsicByName
  // searches the table assuming sorted order — keep entries in strcmp order.)
  static const char *const CoroIntrinsics[] = {
      "llvm.coro.alloc",
      "llvm.coro.async.context.alloc",
      "llvm.coro.async.context.dealloc",
      "llvm.coro.async.store_resume",
      "llvm.coro.begin",
      "llvm.coro.destroy",
      "llvm.coro.done",
      "llvm.coro.end",
      "llvm.coro.end.async",
      "llvm.coro.frame",
      "llvm.coro.free",
      "llvm.coro.id",
      "llvm.coro.id.async",
      "llvm.coro.id.retcon",
      "llvm.coro.id.retcon.once",
      "llvm.coro.noop",
      "llvm.coro.param",
      "llvm.coro.prepare.async",
      "llvm.coro.prepare.retcon",
      "llvm.coro.promise",
      "llvm.coro.resume",
      "llvm.coro.save",
      "llvm.coro.size",
      "llvm.coro.subfn.addr",
      "llvm.coro.suspend",
      "llvm.coro.suspend.async",
      "llvm.coro.suspend.retcon",
  };
  // The lookup yields -1 when Name matches no table entry.
  return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
}
#endif
157 
158 // Verifies if a module has named values listed. Also, in debug mode verifies
159 // that names are intrinsic names.
160 bool coro::declaresIntrinsics(const Module &M,
161                               const std::initializer_list<StringRef> List) {
162   for (StringRef Name : List) {
163     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
164     if (M.getNamedValue(Name))
165       return true;
166   }
167 
168   return false;
169 }
170 
171 // Replace all coro.frees associated with the provided CoroId either with 'null'
172 // if Elide is true and with its frame parameter otherwise.
173 void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
174   SmallVector<CoroFreeInst *, 4> CoroFrees;
175   for (User *U : CoroId->users())
176     if (auto CF = dyn_cast<CoroFreeInst>(U))
177       CoroFrees.push_back(CF);
178 
179   if (CoroFrees.empty())
180     return;
181 
182   Value *Replacement =
183       Elide ? ConstantPointerNull::get(Type::getInt8PtrTy(CoroId->getContext()))
184             : CoroFrees.front()->getFrame();
185 
186   for (CoroFreeInst *CF : CoroFrees) {
187     CF->replaceAllUsesWith(Replacement);
188     CF->eraseFromParent();
189   }
190 }
191 
192 // FIXME: This code is stolen from CallGraph::addToCallGraph(Function *F), which
193 // happens to be private. It is better for this functionality exposed by the
194 // CallGraph.
195 static void buildCGN(CallGraph &CG, CallGraphNode *Node) {
196   Function *F = Node->getFunction();
197 
198   // Look for calls by this function.
199   for (Instruction &I : instructions(F))
200     if (auto *Call = dyn_cast<CallBase>(&I)) {
201       const Function *Callee = Call->getCalledFunction();
202       if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID()))
203         // Indirect calls of intrinsics are not allowed so no need to check.
204         // We can be more precise here by using TargetArg returned by
205         // Intrinsic::isLeaf.
206         Node->addCalledFunction(Call, CG.getCallsExternalNode());
207       else if (!Callee->isIntrinsic())
208         Node->addCalledFunction(Call, CG.getOrInsertFunction(Callee));
209     }
210 }
211 
212 // Rebuild CGN after we extracted parts of the code from ParentFunc into
213 // NewFuncs. Builds CGNs for the NewFuncs and adds them to the current SCC.
214 void coro::updateCallGraph(Function &ParentFunc, ArrayRef<Function *> NewFuncs,
215                            CallGraph &CG, CallGraphSCC &SCC) {
216   // Rebuild CGN from scratch for the ParentFunc
217   auto *ParentNode = CG[&ParentFunc];
218   ParentNode->removeAllCalledFunctions();
219   buildCGN(CG, ParentNode);
220 
221   SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
222 
223   for (Function *F : NewFuncs) {
224     CallGraphNode *Callee = CG.getOrInsertFunction(F);
225     Nodes.push_back(Callee);
226     buildCGN(CG, Callee);
227   }
228 
229   SCC.initialize(Nodes);
230 }
231 
232 static void clear(coro::Shape &Shape) {
233   Shape.CoroBegin = nullptr;
234   Shape.CoroEnds.clear();
235   Shape.CoroSizes.clear();
236   Shape.CoroSuspends.clear();
237 
238   Shape.FrameTy = nullptr;
239   Shape.FramePtr = nullptr;
240   Shape.AllocaSpillBlock = nullptr;
241 }
242 
243 static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
244                                     CoroSuspendInst *SuspendInst) {
245   Module *M = SuspendInst->getModule();
246   auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
247   auto *SaveInst =
248       cast<CoroSaveInst>(CallInst::Create(Fn, CoroBegin, "", SuspendInst));
249   assert(!SuspendInst->getCoroSave());
250   SuspendInst->setArgOperand(0, SaveInst);
251   return SaveInst;
252 }
253 
254 // Collect "interesting" coroutine intrinsics.
255 void coro::Shape::buildFrom(Function &F) {
256   bool HasFinalSuspend = false;
257   size_t FinalSuspendIndex = 0;
258   clear(*this);
259   SmallVector<CoroFrameInst *, 8> CoroFrames;
260   SmallVector<CoroSaveInst *, 2> UnusedCoroSaves;
261 
262   for (Instruction &I : instructions(F)) {
263     if (auto II = dyn_cast<IntrinsicInst>(&I)) {
264       switch (II->getIntrinsicID()) {
265       default:
266         continue;
267       case Intrinsic::coro_size:
268         CoroSizes.push_back(cast<CoroSizeInst>(II));
269         break;
270       case Intrinsic::coro_frame:
271         CoroFrames.push_back(cast<CoroFrameInst>(II));
272         break;
273       case Intrinsic::coro_save:
274         // After optimizations, coro_suspends using this coro_save might have
275         // been removed, remember orphaned coro_saves to remove them later.
276         if (II->use_empty())
277           UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
278         break;
279       case Intrinsic::coro_suspend_async: {
280         auto *Suspend = cast<CoroSuspendAsyncInst>(II);
281         Suspend->checkWellFormed();
282         CoroSuspends.push_back(Suspend);
283         break;
284       }
285       case Intrinsic::coro_suspend_retcon: {
286         auto Suspend = cast<CoroSuspendRetconInst>(II);
287         CoroSuspends.push_back(Suspend);
288         break;
289       }
290       case Intrinsic::coro_suspend: {
291         auto Suspend = cast<CoroSuspendInst>(II);
292         CoroSuspends.push_back(Suspend);
293         if (Suspend->isFinal()) {
294           if (HasFinalSuspend)
295             report_fatal_error(
296               "Only one suspend point can be marked as final");
297           HasFinalSuspend = true;
298           FinalSuspendIndex = CoroSuspends.size() - 1;
299         }
300         break;
301       }
302       case Intrinsic::coro_begin: {
303         auto CB = cast<CoroBeginInst>(II);
304 
305         // Ignore coro id's that aren't pre-split.
306         auto Id = dyn_cast<CoroIdInst>(CB->getId());
307         if (Id && !Id->getInfo().isPreSplit())
308           break;
309 
310         if (CoroBegin)
311           report_fatal_error(
312                 "coroutine should have exactly one defining @llvm.coro.begin");
313         CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
314         CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
315         CB->removeAttribute(AttributeList::FunctionIndex,
316                             Attribute::NoDuplicate);
317         CoroBegin = CB;
318         break;
319       }
320       case Intrinsic::coro_end_async:
321       case Intrinsic::coro_end:
322         CoroEnds.push_back(cast<AnyCoroEndInst>(II));
323         if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(II)) {
324           AsyncEnd->checkWellFormed();
325         }
326         if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(II)) {
327           // Make sure that the fallthrough coro.end is the first element in the
328           // CoroEnds vector.
329           // Note: I don't think this is neccessary anymore.
330           if (CoroEnds.size() > 1) {
331             if (CoroEnds.front()->isFallthrough())
332               report_fatal_error(
333                   "Only one coro.end can be marked as fallthrough");
334             std::swap(CoroEnds.front(), CoroEnds.back());
335           }
336         }
337         break;
338       }
339     }
340   }
341 
342   // If for some reason, we were not able to find coro.begin, bailout.
343   if (!CoroBegin) {
344     // Replace coro.frame which are supposed to be lowered to the result of
345     // coro.begin with undef.
346     auto *Undef = UndefValue::get(Type::getInt8PtrTy(F.getContext()));
347     for (CoroFrameInst *CF : CoroFrames) {
348       CF->replaceAllUsesWith(Undef);
349       CF->eraseFromParent();
350     }
351 
352     // Replace all coro.suspend with undef and remove related coro.saves if
353     // present.
354     for (AnyCoroSuspendInst *CS : CoroSuspends) {
355       CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
356       CS->eraseFromParent();
357       if (auto *CoroSave = CS->getCoroSave())
358         CoroSave->eraseFromParent();
359     }
360 
361     // Replace all coro.ends with unreachable instruction.
362     for (AnyCoroEndInst *CE : CoroEnds)
363       changeToUnreachable(CE, /*UseLLVMTrap=*/false);
364 
365     return;
366   }
367 
368   auto Id = CoroBegin->getId();
369   switch (auto IdIntrinsic = Id->getIntrinsicID()) {
370   case Intrinsic::coro_id: {
371     auto SwitchId = cast<CoroIdInst>(Id);
372     this->ABI = coro::ABI::Switch;
373     this->SwitchLowering.HasFinalSuspend = HasFinalSuspend;
374     this->SwitchLowering.ResumeSwitch = nullptr;
375     this->SwitchLowering.PromiseAlloca = SwitchId->getPromise();
376     this->SwitchLowering.ResumeEntryBlock = nullptr;
377 
378     for (auto AnySuspend : CoroSuspends) {
379       auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
380       if (!Suspend) {
381 #ifndef NDEBUG
382         AnySuspend->dump();
383 #endif
384         report_fatal_error("coro.id must be paired with coro.suspend");
385       }
386 
387       if (!Suspend->getCoroSave())
388         createCoroSave(CoroBegin, Suspend);
389     }
390     break;
391   }
392   case Intrinsic::coro_id_async: {
393     auto *AsyncId = cast<CoroIdAsyncInst>(Id);
394     AsyncId->checkWellFormed();
395     this->ABI = coro::ABI::Async;
396     this->AsyncLowering.Context = AsyncId->getStorage();
397     this->AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
398     this->AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
399     this->AsyncLowering.ContextAlignment =
400         AsyncId->getStorageAlignment().value();
401     this->AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
402     auto &Context = F.getContext();
403     auto *Int8PtrTy = Type::getInt8PtrTy(Context);
404     auto *VoidTy = Type::getVoidTy(Context);
405     this->AsyncLowering.AsyncFuncTy =
406         FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy, Int8PtrTy}, false);
407     break;
408   };
409   case Intrinsic::coro_id_retcon:
410   case Intrinsic::coro_id_retcon_once: {
411     auto ContinuationId = cast<AnyCoroIdRetconInst>(Id);
412     ContinuationId->checkWellFormed();
413     this->ABI = (IdIntrinsic == Intrinsic::coro_id_retcon
414                   ? coro::ABI::Retcon
415                   : coro::ABI::RetconOnce);
416     auto Prototype = ContinuationId->getPrototype();
417     this->RetconLowering.ResumePrototype = Prototype;
418     this->RetconLowering.Alloc = ContinuationId->getAllocFunction();
419     this->RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
420     this->RetconLowering.ReturnBlock = nullptr;
421     this->RetconLowering.IsFrameInlineInStorage = false;
422 
423     // Determine the result value types, and make sure they match up with
424     // the values passed to the suspends.
425     auto ResultTys = getRetconResultTypes();
426     auto ResumeTys = getRetconResumeTypes();
427 
428     for (auto AnySuspend : CoroSuspends) {
429       auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
430       if (!Suspend) {
431 #ifndef NDEBUG
432         AnySuspend->dump();
433 #endif
434         report_fatal_error("coro.id.retcon.* must be paired with "
435                            "coro.suspend.retcon");
436       }
437 
438       // Check that the argument types of the suspend match the results.
439       auto SI = Suspend->value_begin(), SE = Suspend->value_end();
440       auto RI = ResultTys.begin(), RE = ResultTys.end();
441       for (; SI != SE && RI != RE; ++SI, ++RI) {
442         auto SrcTy = (*SI)->getType();
443         if (SrcTy != *RI) {
444           // The optimizer likes to eliminate bitcasts leading into variadic
445           // calls, but that messes with our invariants.  Re-insert the
446           // bitcast and ignore this type mismatch.
447           if (CastInst::isBitCastable(SrcTy, *RI)) {
448             auto BCI = new BitCastInst(*SI, *RI, "", Suspend);
449             SI->set(BCI);
450             continue;
451           }
452 
453 #ifndef NDEBUG
454           Suspend->dump();
455           Prototype->getFunctionType()->dump();
456 #endif
457           report_fatal_error("argument to coro.suspend.retcon does not "
458                              "match corresponding prototype function result");
459         }
460       }
461       if (SI != SE || RI != RE) {
462 #ifndef NDEBUG
463         Suspend->dump();
464         Prototype->getFunctionType()->dump();
465 #endif
466         report_fatal_error("wrong number of arguments to coro.suspend.retcon");
467       }
468 
469       // Check that the result type of the suspend matches the resume types.
470       Type *SResultTy = Suspend->getType();
471       ArrayRef<Type*> SuspendResultTys;
472       if (SResultTy->isVoidTy()) {
473         // leave as empty array
474       } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
475         SuspendResultTys = SResultStructTy->elements();
476       } else {
477         // forms an ArrayRef using SResultTy, be careful
478         SuspendResultTys = SResultTy;
479       }
480       if (SuspendResultTys.size() != ResumeTys.size()) {
481 #ifndef NDEBUG
482         Suspend->dump();
483         Prototype->getFunctionType()->dump();
484 #endif
485         report_fatal_error("wrong number of results from coro.suspend.retcon");
486       }
487       for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
488         if (SuspendResultTys[I] != ResumeTys[I]) {
489 #ifndef NDEBUG
490           Suspend->dump();
491           Prototype->getFunctionType()->dump();
492 #endif
493           report_fatal_error("result from coro.suspend.retcon does not "
494                              "match corresponding prototype function param");
495         }
496       }
497     }
498     break;
499   }
500 
501   default:
502     llvm_unreachable("coro.begin is not dependent on a coro.id call");
503   }
504 
505   // The coro.free intrinsic is always lowered to the result of coro.begin.
506   for (CoroFrameInst *CF : CoroFrames) {
507     CF->replaceAllUsesWith(CoroBegin);
508     CF->eraseFromParent();
509   }
510 
511   // Move final suspend to be the last element in the CoroSuspends vector.
512   if (ABI == coro::ABI::Switch &&
513       SwitchLowering.HasFinalSuspend &&
514       FinalSuspendIndex != CoroSuspends.size() - 1)
515     std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
516 
517   // Remove orphaned coro.saves.
518   for (CoroSaveInst *CoroSave : UnusedCoroSaves)
519     CoroSave->eraseFromParent();
520 }
521 
522 static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
523   Call->setCallingConv(Callee->getCallingConv());
524   // TODO: attributes?
525 }
526 
527 static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
528   if (CG)
529     (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
530 }
531 
532 Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
533                               CallGraph *CG) const {
534   switch (ABI) {
535   case coro::ABI::Switch:
536     llvm_unreachable("can't allocate memory in coro switch-lowering");
537 
538   case coro::ABI::Retcon:
539   case coro::ABI::RetconOnce: {
540     auto Alloc = RetconLowering.Alloc;
541     Size = Builder.CreateIntCast(Size,
542                                  Alloc->getFunctionType()->getParamType(0),
543                                  /*is signed*/ false);
544     auto *Call = Builder.CreateCall(Alloc, Size);
545     propagateCallAttrsFromCallee(Call, Alloc);
546     addCallToCallGraph(CG, Call, Alloc);
547     return Call;
548   }
549   case coro::ABI::Async:
550     llvm_unreachable("can't allocate memory in coro async-lowering");
551   }
552   llvm_unreachable("Unknown coro::ABI enum");
553 }
554 
555 void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
556                               CallGraph *CG) const {
557   switch (ABI) {
558   case coro::ABI::Switch:
559     llvm_unreachable("can't allocate memory in coro switch-lowering");
560 
561   case coro::ABI::Retcon:
562   case coro::ABI::RetconOnce: {
563     auto Dealloc = RetconLowering.Dealloc;
564     Ptr = Builder.CreateBitCast(Ptr,
565                                 Dealloc->getFunctionType()->getParamType(0));
566     auto *Call = Builder.CreateCall(Dealloc, Ptr);
567     propagateCallAttrsFromCallee(Call, Dealloc);
568     addCallToCallGraph(CG, Call, Dealloc);
569     return;
570   }
571   case coro::ABI::Async:
572     llvm_unreachable("can't allocate memory in coro async-lowering");
573   }
574   llvm_unreachable("Unknown coro::ABI enum");
575 }
576 
577 LLVM_ATTRIBUTE_NORETURN
578 static void fail(const Instruction *I, const char *Reason, Value *V) {
579 #ifndef NDEBUG
580   I->dump();
581   if (V) {
582     errs() << "  Value: ";
583     V->printAsOperand(llvm::errs());
584     errs() << '\n';
585   }
586 #endif
587   report_fatal_error(Reason);
588 }
589 
590 /// Check that the given value is a well-formed prototype for the
591 /// llvm.coro.id.retcon.* intrinsics.
592 static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
593   auto F = dyn_cast<Function>(V->stripPointerCasts());
594   if (!F)
595     fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
596 
597   auto FT = F->getFunctionType();
598 
599   if (isa<CoroIdRetconInst>(I)) {
600     bool ResultOkay;
601     if (FT->getReturnType()->isPointerTy()) {
602       ResultOkay = true;
603     } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
604       ResultOkay = (!SRetTy->isOpaque() &&
605                     SRetTy->getNumElements() > 0 &&
606                     SRetTy->getElementType(0)->isPointerTy());
607     } else {
608       ResultOkay = false;
609     }
610     if (!ResultOkay)
611       fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
612               "result", F);
613 
614     if (FT->getReturnType() !=
615           I->getFunction()->getFunctionType()->getReturnType())
616       fail(I, "llvm.coro.id.retcon prototype return type must be same as"
617               "current function return type", F);
618   } else {
619     // No meaningful validation to do here for llvm.coro.id.unique.once.
620   }
621 
622   if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
623     fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
624             "its first parameter", F);
625 }
626 
627 /// Check that the given value is a well-formed allocator.
628 static void checkWFAlloc(const Instruction *I, Value *V) {
629   auto F = dyn_cast<Function>(V->stripPointerCasts());
630   if (!F)
631     fail(I, "llvm.coro.* allocator not a Function", V);
632 
633   auto FT = F->getFunctionType();
634   if (!FT->getReturnType()->isPointerTy())
635     fail(I, "llvm.coro.* allocator must return a pointer", F);
636 
637   if (FT->getNumParams() != 1 ||
638       !FT->getParamType(0)->isIntegerTy())
639     fail(I, "llvm.coro.* allocator must take integer as only param", F);
640 }
641 
642 /// Check that the given value is a well-formed deallocator.
643 static void checkWFDealloc(const Instruction *I, Value *V) {
644   auto F = dyn_cast<Function>(V->stripPointerCasts());
645   if (!F)
646     fail(I, "llvm.coro.* deallocator not a Function", V);
647 
648   auto FT = F->getFunctionType();
649   if (!FT->getReturnType()->isVoidTy())
650     fail(I, "llvm.coro.* deallocator must return void", F);
651 
652   if (FT->getNumParams() != 1 ||
653       !FT->getParamType(0)->isPointerTy())
654     fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
655 }
656 
657 static void checkConstantInt(const Instruction *I, Value *V,
658                              const char *Reason) {
659   if (!isa<ConstantInt>(V)) {
660     fail(I, Reason, V);
661   }
662 }
663 
664 void AnyCoroIdRetconInst::checkWellFormed() const {
665   checkConstantInt(this, getArgOperand(SizeArg),
666                    "size argument to coro.id.retcon.* must be constant");
667   checkConstantInt(this, getArgOperand(AlignArg),
668                    "alignment argument to coro.id.retcon.* must be constant");
669   checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
670   checkWFAlloc(this, getArgOperand(AllocArg));
671   checkWFDealloc(this, getArgOperand(DeallocArg));
672 }
673 
674 static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
675   auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts());
676   if (!AsyncFuncPtrAddr)
677     fail(I, "llvm.coro.id.async async function pointer not a global", V);
678 
679   auto *StructTy =
680       cast<StructType>(AsyncFuncPtrAddr->getType()->getPointerElementType());
681   if (StructTy->isOpaque() || !StructTy->isPacked() ||
682       StructTy->getNumElements() != 2 ||
683       !StructTy->getElementType(0)->isIntegerTy(32) ||
684       !StructTy->getElementType(1)->isIntegerTy(32))
685     fail(I,
686          "llvm.coro.id.async async function pointer argument's type is not "
687          "<{i32, i32}>",
688          V);
689 }
690 
691 void CoroIdAsyncInst::checkWellFormed() const {
692   checkConstantInt(this, getArgOperand(SizeArg),
693                    "size argument to coro.id.async must be constant");
694   checkConstantInt(this, getArgOperand(AlignArg),
695                    "alignment argument to coro.id.async must be constant");
696   checkConstantInt(this, getArgOperand(StorageArg),
697                    "storage argument offset to coro.id.async must be constant");
698   checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
699 }
700 
701 static void checkAsyncContextProjectFunction(const Instruction *I,
702                                              Function *F) {
703   auto *FunTy = cast<FunctionType>(F->getType()->getPointerElementType());
704   if (!FunTy->getReturnType()->isPointerTy() ||
705       !FunTy->getReturnType()->getPointerElementType()->isIntegerTy(8))
706     fail(I,
707          "llvm.coro.suspend.async resume function projection function must "
708          "return an i8* type",
709          F);
710   if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy() ||
711       !FunTy->getParamType(0)->getPointerElementType()->isIntegerTy(8))
712     fail(I,
713          "llvm.coro.suspend.async resume function projection function must "
714          "take one i8* type as parameter",
715          F);
716 }
717 
718 void CoroSuspendAsyncInst::checkWellFormed() const {
719   checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
720 }
721 
722 void CoroAsyncEndInst::checkWellFormed() const {
723   auto *MustTailCallFunc = getMustTailCallFunction();
724   if (!MustTailCallFunc)
725     return;
726   auto *FnTy =
727       cast<FunctionType>(MustTailCallFunc->getType()->getPointerElementType());
728   if (FnTy->getNumParams() != (getNumArgOperands() - 3))
729     fail(this,
730          "llvm.coro.end.async must tail call function argument type must "
731          "match the tail arguments",
732          MustTailCallFunc);
733 }
734 
735 void LLVMAddCoroEarlyPass(LLVMPassManagerRef PM) {
736   unwrap(PM)->add(createCoroEarlyLegacyPass());
737 }
738 
739 void LLVMAddCoroSplitPass(LLVMPassManagerRef PM) {
740   unwrap(PM)->add(createCoroSplitLegacyPass());
741 }
742 
743 void LLVMAddCoroElidePass(LLVMPassManagerRef PM) {
744   unwrap(PM)->add(createCoroElideLegacyPass());
745 }
746 
747 void LLVMAddCoroCleanupPass(LLVMPassManagerRef PM) {
748   unwrap(PM)->add(createCoroCleanupLegacyPass());
749 }
750 
751 void
752 LLVMPassManagerBuilderAddCoroutinePassesToExtensionPoints(LLVMPassManagerBuilderRef PMB) {
753   PassManagerBuilder *Builder = unwrap(PMB);
754   addCoroutinePassesToExtensionPoints(*Builder);
755 }
756