1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/Async/IR/Async.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 #define DEBUG_TYPE "convert-async-to-llvm"
23 
24 using namespace mlir;
25 using namespace mlir::async;
26 
27 //===----------------------------------------------------------------------===//
28 // Async Runtime C API declaration.
29 //===----------------------------------------------------------------------===//
30 
// Reference counting of runtime objects.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of tokens, values and groups (used by async.runtime.create).
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Marking tokens and values available (used by async.runtime.set_available).
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";

// Awaiting tokens, values and groups (used by async.runtime.await).
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Scheduling a coroutine resumption (used by async.runtime.resume).
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";

// Access to the storage of an async value.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";

// Adding a token to a group.
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Awaiting and then resuming a coroutine (used by
// async.runtime.await_and_resume).
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
52 
53 namespace {
54 /// Async Runtime API function types.
55 ///
56 /// Because we can't create API function signature for type parametrized
57 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
58 /// lowering all async data types become opaque pointers at runtime.
59 struct AsyncAPI {
60   // All async types are lowered to opaque i8* LLVM pointers at runtime.
opaquePointerType__anon008b44d10111::AsyncAPI61   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
62     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
63   }
64 
tokenType__anon008b44d10111::AsyncAPI65   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
66     return LLVM::LLVMTokenType::get(ctx);
67   }
68 
addOrDropRefFunctionType__anon008b44d10111::AsyncAPI69   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
70     auto ref = opaquePointerType(ctx);
71     auto count = IntegerType::get(ctx, 32);
72     return FunctionType::get(ctx, {ref, count}, {});
73   }
74 
createTokenFunctionType__anon008b44d10111::AsyncAPI75   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
76     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
77   }
78 
createValueFunctionType__anon008b44d10111::AsyncAPI79   static FunctionType createValueFunctionType(MLIRContext *ctx) {
80     auto i32 = IntegerType::get(ctx, 32);
81     auto value = opaquePointerType(ctx);
82     return FunctionType::get(ctx, {i32}, {value});
83   }
84 
createGroupFunctionType__anon008b44d10111::AsyncAPI85   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
86     return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
87   }
88 
getValueStorageFunctionType__anon008b44d10111::AsyncAPI89   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
90     auto value = opaquePointerType(ctx);
91     auto storage = opaquePointerType(ctx);
92     return FunctionType::get(ctx, {value}, {storage});
93   }
94 
emplaceTokenFunctionType__anon008b44d10111::AsyncAPI95   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
96     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
97   }
98 
emplaceValueFunctionType__anon008b44d10111::AsyncAPI99   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
100     auto value = opaquePointerType(ctx);
101     return FunctionType::get(ctx, {value}, {});
102   }
103 
awaitTokenFunctionType__anon008b44d10111::AsyncAPI104   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
105     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
106   }
107 
awaitValueFunctionType__anon008b44d10111::AsyncAPI108   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
109     auto value = opaquePointerType(ctx);
110     return FunctionType::get(ctx, {value}, {});
111   }
112 
awaitGroupFunctionType__anon008b44d10111::AsyncAPI113   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
114     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
115   }
116 
executeFunctionType__anon008b44d10111::AsyncAPI117   static FunctionType executeFunctionType(MLIRContext *ctx) {
118     auto hdl = opaquePointerType(ctx);
119     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
120     return FunctionType::get(ctx, {hdl, resume}, {});
121   }
122 
addTokenToGroupFunctionType__anon008b44d10111::AsyncAPI123   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
124     auto i64 = IntegerType::get(ctx, 64);
125     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
126                              {i64});
127   }
128 
awaitTokenAndExecuteFunctionType__anon008b44d10111::AsyncAPI129   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
130     auto hdl = opaquePointerType(ctx);
131     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
132     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
133   }
134 
awaitValueAndExecuteFunctionType__anon008b44d10111::AsyncAPI135   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
136     auto value = opaquePointerType(ctx);
137     auto hdl = opaquePointerType(ctx);
138     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
139     return FunctionType::get(ctx, {value, hdl, resume}, {});
140   }
141 
awaitAllAndExecuteFunctionType__anon008b44d10111::AsyncAPI142   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
143     auto hdl = opaquePointerType(ctx);
144     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
145     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
146   }
147 
148   // Auxiliary coroutine resume intrinsic wrapper.
resumeFunctionType__anon008b44d10111::AsyncAPI149   static Type resumeFunctionType(MLIRContext *ctx) {
150     auto voidTy = LLVM::LLVMVoidType::get(ctx);
151     auto i8Ptr = opaquePointerType(ctx);
152     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
153   }
154 };
155 } // namespace
156 
157 /// Adds Async Runtime C API declarations to the module.
addAsyncRuntimeApiDeclarations(ModuleOp module)158 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
159   auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
160                                                          module.getBody());
161 
162   auto addFuncDecl = [&](StringRef name, FunctionType type) {
163     if (module.lookupSymbol(name))
164       return;
165     builder.create<FuncOp>(name, type).setPrivate();
166   };
167 
168   MLIRContext *ctx = module.getContext();
169   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
170   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
171   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
172   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
173   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
174   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
175   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
176   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
177   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
178   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
179   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
180   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
181   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
182   addFuncDecl(kAwaitTokenAndExecute,
183               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
184   addFuncDecl(kAwaitValueAndExecute,
185               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
186   addFuncDecl(kAwaitAllAndExecute,
187               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // Add malloc/free declarations to the module.
192 //===----------------------------------------------------------------------===//
193 
// C library allocation entry points. The coroutine lowering calls them to
// allocate (async.coro.begin) and release (async.coro.free) coroutine frames.
static constexpr const char *kMalloc = "malloc";
static constexpr const char *kFree = "free";
196 
addLLVMFuncDecl(ModuleOp module,ImplicitLocOpBuilder & builder,StringRef name,Type ret,ArrayRef<Type> params)197 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
198                             StringRef name, Type ret, ArrayRef<Type> params) {
199   if (module.lookupSymbol(name))
200     return;
201   Type type = LLVM::LLVMFunctionType::get(ret, params);
202   builder.create<LLVM::LLVMFuncOp>(name, type);
203 }
204 
205 /// Adds malloc/free declarations to the module.
addCRuntimeDeclarations(ModuleOp module)206 static void addCRuntimeDeclarations(ModuleOp module) {
207   using namespace mlir::LLVM;
208 
209   MLIRContext *ctx = module.getContext();
210   ImplicitLocOpBuilder builder(module.getLoc(),
211                                module.getBody()->getTerminator());
212 
213   auto voidTy = LLVMVoidType::get(ctx);
214   auto i64 = IntegerType::get(ctx, 64);
215   auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8));
216 
217   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
218   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
219 }
220 
221 //===----------------------------------------------------------------------===//
222 // Coroutine resume function wrapper.
223 //===----------------------------------------------------------------------===//
224 
225 static constexpr const char *kResume = "__resume";
226 
227 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
228 /// intrinsics. We need this function to be able to pass it to the async
229 /// runtime execute API.
addResumeFunction(ModuleOp module)230 static void addResumeFunction(ModuleOp module) {
231   MLIRContext *ctx = module.getContext();
232 
233   OpBuilder moduleBuilder(module.getBody()->getTerminator());
234   Location loc = module.getLoc();
235 
236   if (module.lookupSymbol(kResume))
237     return;
238 
239   auto voidTy = LLVM::LLVMVoidType::get(ctx);
240   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
241 
242   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
243       loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
244   resumeOp.setPrivate();
245 
246   auto *block = resumeOp.addEntryBlock();
247   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
248 
249   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
250   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // Convert Async dialect types to LLVM types.
255 //===----------------------------------------------------------------------===//
256 
257 namespace {
258 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
259 /// their runtime type (opaque pointers) and does not convert any other types.
260 class AsyncRuntimeTypeConverter : public TypeConverter {
261 public:
AsyncRuntimeTypeConverter()262   AsyncRuntimeTypeConverter() {
263     addConversion([](Type type) { return type; });
264     addConversion(convertAsyncTypes);
265   }
266 
convertAsyncTypes(Type type)267   static Optional<Type> convertAsyncTypes(Type type) {
268     if (type.isa<TokenType, GroupType, ValueType>())
269       return AsyncAPI::opaquePointerType(type.getContext());
270 
271     if (type.isa<CoroIdType, CoroStateType>())
272       return AsyncAPI::tokenType(type.getContext());
273     if (type.isa<CoroHandleType>())
274       return AsyncAPI::opaquePointerType(type.getContext());
275 
276     return llvm::None;
277   }
278 };
279 } // namespace
280 
281 //===----------------------------------------------------------------------===//
282 // Convert async.coro.id to @llvm.coro.id intrinsic.
283 //===----------------------------------------------------------------------===//
284 
285 namespace {
286 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
287 public:
288   using OpConversionPattern::OpConversionPattern;
289 
290   LogicalResult
matchAndRewrite(CoroIdOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const291   matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands,
292                   ConversionPatternRewriter &rewriter) const override {
293     auto token = AsyncAPI::tokenType(op->getContext());
294     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
295     auto loc = op->getLoc();
296 
297     // Constants for initializing coroutine frame.
298     auto constZero = rewriter.create<LLVM::ConstantOp>(
299         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
300     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
301 
302     // Get coroutine id: @llvm.coro.id.
303     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
304         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
305 
306     return success();
307   }
308 };
309 } // namespace
310 
311 //===----------------------------------------------------------------------===//
312 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
313 //===----------------------------------------------------------------------===//
314 
315 namespace {
316 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
317 public:
318   using OpConversionPattern::OpConversionPattern;
319 
320   LogicalResult
matchAndRewrite(CoroBeginOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const321   matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands,
322                   ConversionPatternRewriter &rewriter) const override {
323     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
324     auto loc = op->getLoc();
325 
326     // Get coroutine frame size: @llvm.coro.size.i64.
327     auto coroSize =
328         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
329 
330     // Allocate memory for the coroutine frame.
331     auto coroAlloc = rewriter.create<LLVM::CallOp>(
332         loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc),
333         ValueRange(coroSize.getResult()));
334 
335     // Begin a coroutine: @llvm.coro.begin.
336     auto coroId = CoroBeginOpAdaptor(operands).id();
337     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
338         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
339 
340     return success();
341   }
342 };
343 } // namespace
344 
345 //===----------------------------------------------------------------------===//
346 // Convert async.coro.free to @llvm.coro.free intrinsic.
347 //===----------------------------------------------------------------------===//
348 
349 namespace {
350 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
351 public:
352   using OpConversionPattern::OpConversionPattern;
353 
354   LogicalResult
matchAndRewrite(CoroFreeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const355   matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands,
356                   ConversionPatternRewriter &rewriter) const override {
357     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
358     auto loc = op->getLoc();
359 
360     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
361     auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands);
362 
363     // Free the memory.
364     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
365                                               rewriter.getSymbolRefAttr(kFree),
366                                               ValueRange(coroMem.getResult()));
367 
368     return success();
369   }
370 };
371 } // namespace
372 
373 //===----------------------------------------------------------------------===//
374 // Convert async.coro.end to @llvm.coro.end intrinsic.
375 //===----------------------------------------------------------------------===//
376 
377 namespace {
378 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
379 public:
380   using OpConversionPattern::OpConversionPattern;
381 
382   LogicalResult
matchAndRewrite(CoroEndOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const383   matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands,
384                   ConversionPatternRewriter &rewriter) const override {
385     // We are not in the block that is part of the unwind sequence.
386     auto constFalse = rewriter.create<LLVM::ConstantOp>(
387         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
388 
389     // Mark the end of a coroutine: @llvm.coro.end.
390     auto coroHdl = CoroEndOpAdaptor(operands).handle();
391     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
392                                      ValueRange({coroHdl, constFalse}));
393     rewriter.eraseOp(op);
394 
395     return success();
396   }
397 };
398 } // namespace
399 
400 //===----------------------------------------------------------------------===//
401 // Convert async.coro.save to @llvm.coro.save intrinsic.
402 //===----------------------------------------------------------------------===//
403 
404 namespace {
405 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
406 public:
407   using OpConversionPattern::OpConversionPattern;
408 
409   LogicalResult
matchAndRewrite(CoroSaveOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const410   matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands,
411                   ConversionPatternRewriter &rewriter) const override {
412     // Save the coroutine state: @llvm.coro.save
413     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
414         op, AsyncAPI::tokenType(op->getContext()), operands);
415 
416     return success();
417   }
418 };
419 } // namespace
420 
421 //===----------------------------------------------------------------------===//
422 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
423 //===----------------------------------------------------------------------===//
424 
425 namespace {
426 
427 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
428 /// branch to the appropriate block based on the return code.
429 ///
430 /// Before:
431 ///
432 ///   ^suspended:
433 ///     "opBefore"(...)
434 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
435 ///   ^resume:
436 ///     "op"(...)
437 ///   ^cleanup: ...
438 ///   ^suspend: ...
439 ///
440 /// After:
441 ///
442 ///   ^suspended:
443 ///     "opBefore"(...)
444 ///     %suspend = llmv.intr.coro.suspend ...
445 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
446 ///   ^resume:
447 ///     "op"(...)
448 ///   ^cleanup: ...
449 ///   ^suspend: ...
450 ///
451 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
452 public:
453   using OpConversionPattern::OpConversionPattern;
454 
455   LogicalResult
matchAndRewrite(CoroSuspendOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const456   matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands,
457                   ConversionPatternRewriter &rewriter) const override {
458     auto i8 = rewriter.getIntegerType(8);
459     auto i32 = rewriter.getI32Type();
460     auto loc = op->getLoc();
461 
462     // This is not a final suspension point.
463     auto constFalse = rewriter.create<LLVM::ConstantOp>(
464         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
465 
466     // Suspend a coroutine: @llvm.coro.suspend
467     auto coroState = CoroSuspendOpAdaptor(operands).state();
468     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
469         loc, i8, ValueRange({coroState, constFalse}));
470 
471     // Cast return code to i32.
472 
473     // After a suspension point decide if we should branch into resume, cleanup
474     // or suspend block of the coroutine (see @llvm.coro.suspend return code
475     // documentation).
476     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
477     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
478                                               op.cleanupDest()};
479     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
480         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
481         /*defaultDestination=*/op.suspendDest(),
482         /*defaultOperands=*/ValueRange(),
483         /*caseValues=*/caseValues,
484         /*caseDestinations=*/caseDest,
485         /*caseOperands=*/ArrayRef<ValueRange>(),
486         /*branchWeights=*/ArrayRef<int32_t>());
487 
488     return success();
489   }
490 };
491 } // namespace
492 
493 //===----------------------------------------------------------------------===//
494 // Convert async.runtime.create to the corresponding runtime API call.
495 //
496 // To allocate storage for the async values we use getelementptr trick:
497 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
498 //===----------------------------------------------------------------------===//
499 
500 namespace {
501 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
502 public:
503   using OpConversionPattern::OpConversionPattern;
504 
505   LogicalResult
matchAndRewrite(RuntimeCreateOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const506   matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands,
507                   ConversionPatternRewriter &rewriter) const override {
508     TypeConverter *converter = getTypeConverter();
509     Type resultType = op->getResultTypes()[0];
510 
511     // Tokens and Groups lowered to function calls without arguments.
512     if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) {
513       rewriter.replaceOpWithNewOp<CallOp>(
514           op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup,
515           converter->convertType(resultType));
516       return success();
517     }
518 
519     // To create a value we need to compute the storage requirement.
520     if (auto value = resultType.dyn_cast<ValueType>()) {
521       // Returns the size requirements for the async value storage.
522       auto sizeOf = [&](ValueType valueType) -> Value {
523         auto loc = op->getLoc();
524         auto i32 = rewriter.getI32Type();
525 
526         auto storedType = converter->convertType(valueType.getValueType());
527         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
528 
529         // %Size = getelementptr %T* null, int 1
530         // %SizeI = ptrtoint %T* %Size to i32
531         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
532         auto one = rewriter.create<LLVM::ConstantOp>(
533             loc, i32, rewriter.getI32IntegerAttr(1));
534         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
535                                                 one.getResult());
536         return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep);
537       };
538 
539       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
540                                           sizeOf(value));
541 
542       return success();
543     }
544 
545     return rewriter.notifyMatchFailure(op, "unsupported async type");
546   }
547 };
548 } // namespace
549 
550 //===----------------------------------------------------------------------===//
551 // Convert async.runtime.set_available to the corresponding runtime API call.
552 //===----------------------------------------------------------------------===//
553 
554 namespace {
555 class RuntimeSetAvailableOpLowering
556     : public OpConversionPattern<RuntimeSetAvailableOp> {
557 public:
558   using OpConversionPattern::OpConversionPattern;
559 
560   LogicalResult
matchAndRewrite(RuntimeSetAvailableOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const561   matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
562                   ConversionPatternRewriter &rewriter) const override {
563     Type operandType = op.operand().getType();
564 
565     if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) {
566       rewriter.create<CallOp>(op->getLoc(),
567                               operandType.isa<TokenType>() ? kEmplaceToken
568                                                            : kEmplaceValue,
569                               TypeRange(), operands);
570       rewriter.eraseOp(op);
571       return success();
572     }
573 
574     return rewriter.notifyMatchFailure(op, "unsupported async type");
575   }
576 };
577 } // namespace
578 
579 //===----------------------------------------------------------------------===//
580 // Convert async.runtime.await to the corresponding runtime API call.
581 //===----------------------------------------------------------------------===//
582 
583 namespace {
584 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
585 public:
586   using OpConversionPattern::OpConversionPattern;
587 
588   LogicalResult
matchAndRewrite(RuntimeAwaitOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const589   matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
590                   ConversionPatternRewriter &rewriter) const override {
591     Type operandType = op.operand().getType();
592 
593     StringRef apiFuncName;
594     if (operandType.isa<TokenType>())
595       apiFuncName = kAwaitToken;
596     else if (operandType.isa<ValueType>())
597       apiFuncName = kAwaitValue;
598     else if (operandType.isa<GroupType>())
599       apiFuncName = kAwaitGroup;
600     else
601       return rewriter.notifyMatchFailure(op, "unsupported async type");
602 
603     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
604     rewriter.eraseOp(op);
605 
606     return success();
607   }
608 };
609 } // namespace
610 
611 //===----------------------------------------------------------------------===//
612 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
613 //===----------------------------------------------------------------------===//
614 
615 namespace {
616 class RuntimeAwaitAndResumeOpLowering
617     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
618 public:
619   using OpConversionPattern::OpConversionPattern;
620 
621   LogicalResult
matchAndRewrite(RuntimeAwaitAndResumeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const622   matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
623                   ConversionPatternRewriter &rewriter) const override {
624     Type operandType = op.operand().getType();
625 
626     StringRef apiFuncName;
627     if (operandType.isa<TokenType>())
628       apiFuncName = kAwaitTokenAndExecute;
629     else if (operandType.isa<ValueType>())
630       apiFuncName = kAwaitValueAndExecute;
631     else if (operandType.isa<GroupType>())
632       apiFuncName = kAwaitAllAndExecute;
633     else
634       return rewriter.notifyMatchFailure(op, "unsupported async type");
635 
636     Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
637     Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
638 
639     // A pointer to coroutine resume intrinsic wrapper.
640     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
641     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
642         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
643 
644     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
645                             ValueRange({operand, handle, resumePtr.res()}));
646     rewriter.eraseOp(op);
647 
648     return success();
649   }
650 };
651 } // namespace
652 
653 //===----------------------------------------------------------------------===//
654 // Convert async.runtime.resume to the corresponding runtime API call.
655 //===----------------------------------------------------------------------===//
656 
657 namespace {
658 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
659 public:
660   using OpConversionPattern::OpConversionPattern;
661 
662   LogicalResult
matchAndRewrite(RuntimeResumeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const663   matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands,
664                   ConversionPatternRewriter &rewriter) const override {
665     // A pointer to coroutine resume intrinsic wrapper.
666     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
667     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
668         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
669 
670     // Call async runtime API to execute a coroutine in the managed thread.
671     auto coroHdl = RuntimeResumeOpAdaptor(operands).handle();
672     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute,
673                                         ValueRange({coroHdl, resumePtr.res()}));
674 
675     return success();
676   }
677 };
678 } // namespace
679 
680 //===----------------------------------------------------------------------===//
681 // Convert async.runtime.store to the corresponding runtime API call.
682 //===----------------------------------------------------------------------===//
683 
684 namespace {
685 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
686 public:
687   using OpConversionPattern::OpConversionPattern;
688 
689   LogicalResult
matchAndRewrite(RuntimeStoreOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const690   matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands,
691                   ConversionPatternRewriter &rewriter) const override {
692     Location loc = op->getLoc();
693 
694     // Get a pointer to the async value storage from the runtime.
695     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
696     auto storage = RuntimeStoreOpAdaptor(operands).storage();
697     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
698                                               TypeRange(i8Ptr), storage);
699 
700     // Cast from i8* to the LLVM pointer type.
701     auto valueType = op.value().getType();
702     auto llvmValueType = getTypeConverter()->convertType(valueType);
703     if (!llvmValueType)
704       return rewriter.notifyMatchFailure(
705           op, "failed to convert stored value type to LLVM type");
706 
707     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
708         loc, LLVM::LLVMPointerType::get(llvmValueType),
709         storagePtr.getResult(0));
710 
711     // Store the yielded value into the async value storage.
712     auto value = RuntimeStoreOpAdaptor(operands).value();
713     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
714 
715     // Erase the original runtime store operation.
716     rewriter.eraseOp(op);
717 
718     return success();
719   }
720 };
721 } // namespace
722 
723 //===----------------------------------------------------------------------===//
724 // Convert async.runtime.load to the corresponding runtime API call.
725 //===----------------------------------------------------------------------===//
726 
727 namespace {
728 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
729 public:
730   using OpConversionPattern::OpConversionPattern;
731 
732   LogicalResult
matchAndRewrite(RuntimeLoadOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const733   matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands,
734                   ConversionPatternRewriter &rewriter) const override {
735     Location loc = op->getLoc();
736 
737     // Get a pointer to the async value storage from the runtime.
738     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
739     auto storage = RuntimeLoadOpAdaptor(operands).storage();
740     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
741                                               TypeRange(i8Ptr), storage);
742 
743     // Cast from i8* to the LLVM pointer type.
744     auto valueType = op.result().getType();
745     auto llvmValueType = getTypeConverter()->convertType(valueType);
746     if (!llvmValueType)
747       return rewriter.notifyMatchFailure(
748           op, "failed to convert loaded value type to LLVM type");
749 
750     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
751         loc, LLVM::LLVMPointerType::get(llvmValueType),
752         storagePtr.getResult(0));
753 
754     // Load from the casted pointer.
755     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
756 
757     return success();
758   }
759 };
760 } // namespace
761 
762 //===----------------------------------------------------------------------===//
763 // Convert async.runtime.add_to_group to the corresponding runtime API call.
764 //===----------------------------------------------------------------------===//
765 
766 namespace {
767 class RuntimeAddToGroupOpLowering
768     : public OpConversionPattern<RuntimeAddToGroupOp> {
769 public:
770   using OpConversionPattern::OpConversionPattern;
771 
772   LogicalResult
matchAndRewrite(RuntimeAddToGroupOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const773   matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands,
774                   ConversionPatternRewriter &rewriter) const override {
775     // Currently we can only add tokens to the group.
776     if (!op.operand().getType().isa<TokenType>())
777       return rewriter.notifyMatchFailure(op, "only token type is supported");
778 
779     // Replace with a runtime API function call.
780     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup,
781                                         rewriter.getI64Type(), operands);
782 
783     return success();
784   }
785 };
786 } // namespace
787 
788 //===----------------------------------------------------------------------===//
789 // Async reference counting ops lowering (`async.runtime.add_ref` and
790 // `async.runtime.drop_ref` to the corresponding API calls).
791 //===----------------------------------------------------------------------===//
792 
793 namespace {
794 template <typename RefCountingOp>
795 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
796 public:
RefCountingOpLowering(TypeConverter & converter,MLIRContext * ctx,StringRef apiFunctionName)797   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
798                                  StringRef apiFunctionName)
799       : OpConversionPattern<RefCountingOp>(converter, ctx),
800         apiFunctionName(apiFunctionName) {}
801 
802   LogicalResult
matchAndRewrite(RefCountingOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const803   matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands,
804                   ConversionPatternRewriter &rewriter) const override {
805     auto count =
806         rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(),
807                                     rewriter.getI32IntegerAttr(op.count()));
808 
809     auto operand = typename RefCountingOp::Adaptor(operands).operand();
810     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
811                                         ValueRange({operand, count}));
812 
813     return success();
814   }
815 
816 private:
817   StringRef apiFunctionName;
818 };
819 
820 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
821 public:
RuntimeAddRefOpLowering(TypeConverter & converter,MLIRContext * ctx)822   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
823       : RefCountingOpLowering(converter, ctx, kAddRef) {}
824 };
825 
826 class RuntimeDropRefOpLowering
827     : public RefCountingOpLowering<RuntimeDropRefOp> {
828 public:
RuntimeDropRefOpLowering(TypeConverter & converter,MLIRContext * ctx)829   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
830       : RefCountingOpLowering(converter, ctx, kDropRef) {}
831 };
832 } // namespace
833 
834 //===----------------------------------------------------------------------===//
835 // Convert return operations that return async values from async regions.
836 //===----------------------------------------------------------------------===//
837 
838 namespace {
839 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
840 public:
841   using OpConversionPattern::OpConversionPattern;
842 
843   LogicalResult
matchAndRewrite(ReturnOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const844   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
845                   ConversionPatternRewriter &rewriter) const override {
846     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
847     return success();
848   }
849 };
850 } // namespace
851 
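//===----------------------------------------------------------------------===//
// Async to LLVM lowering pass.
//===----------------------------------------------------------------------===//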
853 
854 namespace {
855 struct ConvertAsyncToLLVMPass
856     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
857   void runOnOperation() override;
858 };
859 } // namespace
860 
runOnOperation()861 void ConvertAsyncToLLVMPass::runOnOperation() {
862   ModuleOp module = getOperation();
863   MLIRContext *ctx = module->getContext();
864 
865   // Add declarations for all functions required by the coroutines lowering.
866   addResumeFunction(module);
867   addAsyncRuntimeApiDeclarations(module);
868   addCRuntimeDeclarations(module);
869 
870   // Lower async.runtime and async.coro operations to Async Runtime API and
871   // LLVM coroutine intrinsics.
872 
873   // Convert async dialect types and operations to LLVM dialect.
874   AsyncRuntimeTypeConverter converter;
875   OwningRewritePatternList patterns;
876 
877   // We use conversion to LLVM type to lower async.runtime load and store
878   // operations.
879   LLVMTypeConverter llvmConverter(ctx);
880   llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
881 
882   // Convert async types in function signatures and function calls.
883   populateFuncOpTypeConversionPattern(patterns, ctx, converter);
884   populateCallOpTypeConversionPattern(patterns, ctx, converter);
885 
886   // Convert return operations inside async.execute regions.
887   patterns.insert<ReturnOpOpConversion>(converter, ctx);
888 
889   // Lower async.runtime operations to the async runtime API calls.
890   patterns.insert<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering,
891                   RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
892                   RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
893                   RuntimeDropRefOpLowering>(converter, ctx);
894 
895   // Lower async.runtime operations that rely on LLVM type converter to convert
896   // from async value payload type to the LLVM type.
897   patterns.insert<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
898                   RuntimeLoadOpLowering>(llvmConverter, ctx);
899 
900   // Lower async coroutine operations to LLVM coroutine intrinsics.
901   patterns.insert<CoroIdOpConversion, CoroBeginOpConversion,
902                   CoroFreeOpConversion, CoroEndOpConversion,
903                   CoroSaveOpConversion, CoroSuspendOpConversion>(converter,
904                                                                  ctx);
905 
906   ConversionTarget target(*ctx);
907   target.addLegalOp<ConstantOp>();
908   target.addLegalDialect<LLVM::LLVMDialect>();
909 
910   // All operations from Async dialect must be lowered to the runtime API and
911   // LLVM intrinsics calls.
912   target.addIllegalDialect<AsyncDialect>();
913 
914   // Add dynamic legality constraints to apply conversions defined above.
915   target.addDynamicallyLegalOp<FuncOp>(
916       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
917   target.addDynamicallyLegalOp<ReturnOp>(
918       [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
919   target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
920     return converter.isSignatureLegal(op.getCalleeType());
921   });
922 
923   if (failed(applyPartialConversion(module, target, std::move(patterns))))
924     signalPassFailure();
925 }
926 
927 //===----------------------------------------------------------------------===//
928 // Patterns for structural type conversions for the Async dialect operations.
929 //===----------------------------------------------------------------------===//
930 
931 namespace {
932 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
933 public:
934   using OpConversionPattern::OpConversionPattern;
935   LogicalResult
matchAndRewrite(ExecuteOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const936   matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
937                   ConversionPatternRewriter &rewriter) const override {
938     ExecuteOp newOp =
939         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
940     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
941                                 newOp.getRegion().end());
942 
943     // Set operands and update block argument and result types.
944     newOp->setOperands(operands);
945     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
946       return failure();
947     for (auto result : newOp.getResults())
948       result.setType(typeConverter->convertType(result.getType()));
949 
950     rewriter.replaceOp(op, newOp.getResults());
951     return success();
952   }
953 };
954 
955 // Dummy pattern to trigger the appropriate type conversion / materialization.
956 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
957 public:
958   using OpConversionPattern::OpConversionPattern;
959   LogicalResult
matchAndRewrite(AwaitOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const960   matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
961                   ConversionPatternRewriter &rewriter) const override {
962     rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
963     return success();
964   }
965 };
966 
967 // Dummy pattern to trigger the appropriate type conversion / materialization.
968 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
969 public:
970   using OpConversionPattern::OpConversionPattern;
971   LogicalResult
matchAndRewrite(async::YieldOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const972   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
973                   ConversionPatternRewriter &rewriter) const override {
974     rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
975     return success();
976   }
977 };
978 } // namespace
979 
createConvertAsyncToLLVMPass()980 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
981   return std::make_unique<ConvertAsyncToLLVMPass>();
982 }
983 
populateAsyncStructuralTypeConversionsAndLegality(MLIRContext * context,TypeConverter & typeConverter,OwningRewritePatternList & patterns,ConversionTarget & target)984 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
985     MLIRContext *context, TypeConverter &typeConverter,
986     OwningRewritePatternList &patterns, ConversionTarget &target) {
987   typeConverter.addConversion([&](TokenType type) { return type; });
988   typeConverter.addConversion([&](ValueType type) {
989     return ValueType::get(typeConverter.convertType(type.getValueType()));
990   });
991 
992   patterns
993       .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
994           typeConverter, context);
995 
996   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
997       [&](Operation *op) { return typeConverter.isLegal(op); });
998 }
999