1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/Async/IR/Async.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21
22 #define DEBUG_TYPE "convert-async-to-llvm"
23
24 using namespace mlir;
25 using namespace mlir::async;
26
27 //===----------------------------------------------------------------------===//
28 // Async Runtime C API declaration.
29 //===----------------------------------------------------------------------===//
30
// Symbol names of the Async Runtime C API functions. They are declared in the
// module by addAsyncRuntimeApiDeclarations() and referenced by name from the
// lowering patterns below.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
// Creation of async objects.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
// Marking async objects as available.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
// Blocking waits.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
// Coroutine execution in the runtime.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
// Access to the storage backing an async value.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
// Group membership.
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// Non-blocking waits that resume a coroutine once the operand is ready.
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
52
53 namespace {
54 /// Async Runtime API function types.
55 ///
56 /// Because we can't create API function signature for type parametrized
57 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
58 /// lowering all async data types become opaque pointers at runtime.
59 struct AsyncAPI {
60 // All async types are lowered to opaque i8* LLVM pointers at runtime.
opaquePointerType__anon008b44d10111::AsyncAPI61 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
62 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
63 }
64
tokenType__anon008b44d10111::AsyncAPI65 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
66 return LLVM::LLVMTokenType::get(ctx);
67 }
68
addOrDropRefFunctionType__anon008b44d10111::AsyncAPI69 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
70 auto ref = opaquePointerType(ctx);
71 auto count = IntegerType::get(ctx, 32);
72 return FunctionType::get(ctx, {ref, count}, {});
73 }
74
createTokenFunctionType__anon008b44d10111::AsyncAPI75 static FunctionType createTokenFunctionType(MLIRContext *ctx) {
76 return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
77 }
78
createValueFunctionType__anon008b44d10111::AsyncAPI79 static FunctionType createValueFunctionType(MLIRContext *ctx) {
80 auto i32 = IntegerType::get(ctx, 32);
81 auto value = opaquePointerType(ctx);
82 return FunctionType::get(ctx, {i32}, {value});
83 }
84
createGroupFunctionType__anon008b44d10111::AsyncAPI85 static FunctionType createGroupFunctionType(MLIRContext *ctx) {
86 return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
87 }
88
getValueStorageFunctionType__anon008b44d10111::AsyncAPI89 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
90 auto value = opaquePointerType(ctx);
91 auto storage = opaquePointerType(ctx);
92 return FunctionType::get(ctx, {value}, {storage});
93 }
94
emplaceTokenFunctionType__anon008b44d10111::AsyncAPI95 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
96 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
97 }
98
emplaceValueFunctionType__anon008b44d10111::AsyncAPI99 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
100 auto value = opaquePointerType(ctx);
101 return FunctionType::get(ctx, {value}, {});
102 }
103
awaitTokenFunctionType__anon008b44d10111::AsyncAPI104 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
105 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
106 }
107
awaitValueFunctionType__anon008b44d10111::AsyncAPI108 static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
109 auto value = opaquePointerType(ctx);
110 return FunctionType::get(ctx, {value}, {});
111 }
112
awaitGroupFunctionType__anon008b44d10111::AsyncAPI113 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
114 return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
115 }
116
executeFunctionType__anon008b44d10111::AsyncAPI117 static FunctionType executeFunctionType(MLIRContext *ctx) {
118 auto hdl = opaquePointerType(ctx);
119 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
120 return FunctionType::get(ctx, {hdl, resume}, {});
121 }
122
addTokenToGroupFunctionType__anon008b44d10111::AsyncAPI123 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
124 auto i64 = IntegerType::get(ctx, 64);
125 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
126 {i64});
127 }
128
awaitTokenAndExecuteFunctionType__anon008b44d10111::AsyncAPI129 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
130 auto hdl = opaquePointerType(ctx);
131 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
132 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
133 }
134
awaitValueAndExecuteFunctionType__anon008b44d10111::AsyncAPI135 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
136 auto value = opaquePointerType(ctx);
137 auto hdl = opaquePointerType(ctx);
138 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
139 return FunctionType::get(ctx, {value, hdl, resume}, {});
140 }
141
awaitAllAndExecuteFunctionType__anon008b44d10111::AsyncAPI142 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
143 auto hdl = opaquePointerType(ctx);
144 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
145 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
146 }
147
148 // Auxiliary coroutine resume intrinsic wrapper.
resumeFunctionType__anon008b44d10111::AsyncAPI149 static Type resumeFunctionType(MLIRContext *ctx) {
150 auto voidTy = LLVM::LLVMVoidType::get(ctx);
151 auto i8Ptr = opaquePointerType(ctx);
152 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
153 }
154 };
155 } // namespace
156
157 /// Adds Async Runtime C API declarations to the module.
addAsyncRuntimeApiDeclarations(ModuleOp module)158 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
159 auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
160 module.getBody());
161
162 auto addFuncDecl = [&](StringRef name, FunctionType type) {
163 if (module.lookupSymbol(name))
164 return;
165 builder.create<FuncOp>(name, type).setPrivate();
166 };
167
168 MLIRContext *ctx = module.getContext();
169 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
170 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
171 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
172 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
173 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
174 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
175 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
176 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
177 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
178 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
179 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
180 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
181 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
182 addFuncDecl(kAwaitTokenAndExecute,
183 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
184 addFuncDecl(kAwaitValueAndExecute,
185 AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
186 addFuncDecl(kAwaitAllAndExecute,
187 AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
188 }
189
190 //===----------------------------------------------------------------------===//
191 // Add malloc/free declarations to the module.
192 //===----------------------------------------------------------------------===//
193
194 static constexpr const char *kMalloc = "malloc";
195 static constexpr const char *kFree = "free";
196
addLLVMFuncDecl(ModuleOp module,ImplicitLocOpBuilder & builder,StringRef name,Type ret,ArrayRef<Type> params)197 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
198 StringRef name, Type ret, ArrayRef<Type> params) {
199 if (module.lookupSymbol(name))
200 return;
201 Type type = LLVM::LLVMFunctionType::get(ret, params);
202 builder.create<LLVM::LLVMFuncOp>(name, type);
203 }
204
205 /// Adds malloc/free declarations to the module.
addCRuntimeDeclarations(ModuleOp module)206 static void addCRuntimeDeclarations(ModuleOp module) {
207 using namespace mlir::LLVM;
208
209 MLIRContext *ctx = module.getContext();
210 ImplicitLocOpBuilder builder(module.getLoc(),
211 module.getBody()->getTerminator());
212
213 auto voidTy = LLVMVoidType::get(ctx);
214 auto i64 = IntegerType::get(ctx, 64);
215 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8));
216
217 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
218 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
219 }
220
221 //===----------------------------------------------------------------------===//
222 // Coroutine resume function wrapper.
223 //===----------------------------------------------------------------------===//
224
225 static constexpr const char *kResume = "__resume";
226
227 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
228 /// intrinsics. We need this function to be able to pass it to the async
229 /// runtime execute API.
addResumeFunction(ModuleOp module)230 static void addResumeFunction(ModuleOp module) {
231 MLIRContext *ctx = module.getContext();
232
233 OpBuilder moduleBuilder(module.getBody()->getTerminator());
234 Location loc = module.getLoc();
235
236 if (module.lookupSymbol(kResume))
237 return;
238
239 auto voidTy = LLVM::LLVMVoidType::get(ctx);
240 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
241
242 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
243 loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
244 resumeOp.setPrivate();
245
246 auto *block = resumeOp.addEntryBlock();
247 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
248
249 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
250 blockBuilder.create<LLVM::ReturnOp>(ValueRange());
251 }
252
253 //===----------------------------------------------------------------------===//
254 // Convert Async dialect types to LLVM types.
255 //===----------------------------------------------------------------------===//
256
257 namespace {
258 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
259 /// their runtime type (opaque pointers) and does not convert any other types.
260 class AsyncRuntimeTypeConverter : public TypeConverter {
261 public:
AsyncRuntimeTypeConverter()262 AsyncRuntimeTypeConverter() {
263 addConversion([](Type type) { return type; });
264 addConversion(convertAsyncTypes);
265 }
266
convertAsyncTypes(Type type)267 static Optional<Type> convertAsyncTypes(Type type) {
268 if (type.isa<TokenType, GroupType, ValueType>())
269 return AsyncAPI::opaquePointerType(type.getContext());
270
271 if (type.isa<CoroIdType, CoroStateType>())
272 return AsyncAPI::tokenType(type.getContext());
273 if (type.isa<CoroHandleType>())
274 return AsyncAPI::opaquePointerType(type.getContext());
275
276 return llvm::None;
277 }
278 };
279 } // namespace
280
281 //===----------------------------------------------------------------------===//
282 // Convert async.coro.id to @llvm.coro.id intrinsic.
283 //===----------------------------------------------------------------------===//
284
285 namespace {
286 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
287 public:
288 using OpConversionPattern::OpConversionPattern;
289
290 LogicalResult
matchAndRewrite(CoroIdOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const291 matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands,
292 ConversionPatternRewriter &rewriter) const override {
293 auto token = AsyncAPI::tokenType(op->getContext());
294 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
295 auto loc = op->getLoc();
296
297 // Constants for initializing coroutine frame.
298 auto constZero = rewriter.create<LLVM::ConstantOp>(
299 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
300 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
301
302 // Get coroutine id: @llvm.coro.id.
303 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
304 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
305
306 return success();
307 }
308 };
309 } // namespace
310
311 //===----------------------------------------------------------------------===//
312 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
313 //===----------------------------------------------------------------------===//
314
315 namespace {
316 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
317 public:
318 using OpConversionPattern::OpConversionPattern;
319
320 LogicalResult
matchAndRewrite(CoroBeginOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const321 matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands,
322 ConversionPatternRewriter &rewriter) const override {
323 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
324 auto loc = op->getLoc();
325
326 // Get coroutine frame size: @llvm.coro.size.i64.
327 auto coroSize =
328 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
329
330 // Allocate memory for the coroutine frame.
331 auto coroAlloc = rewriter.create<LLVM::CallOp>(
332 loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc),
333 ValueRange(coroSize.getResult()));
334
335 // Begin a coroutine: @llvm.coro.begin.
336 auto coroId = CoroBeginOpAdaptor(operands).id();
337 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
338 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
339
340 return success();
341 }
342 };
343 } // namespace
344
345 //===----------------------------------------------------------------------===//
346 // Convert async.coro.free to @llvm.coro.free intrinsic.
347 //===----------------------------------------------------------------------===//
348
349 namespace {
350 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
351 public:
352 using OpConversionPattern::OpConversionPattern;
353
354 LogicalResult
matchAndRewrite(CoroFreeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const355 matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands,
356 ConversionPatternRewriter &rewriter) const override {
357 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
358 auto loc = op->getLoc();
359
360 // Get a pointer to the coroutine frame memory: @llvm.coro.free.
361 auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands);
362
363 // Free the memory.
364 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
365 rewriter.getSymbolRefAttr(kFree),
366 ValueRange(coroMem.getResult()));
367
368 return success();
369 }
370 };
371 } // namespace
372
373 //===----------------------------------------------------------------------===//
374 // Convert async.coro.end to @llvm.coro.end intrinsic.
375 //===----------------------------------------------------------------------===//
376
377 namespace {
378 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
379 public:
380 using OpConversionPattern::OpConversionPattern;
381
382 LogicalResult
matchAndRewrite(CoroEndOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const383 matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands,
384 ConversionPatternRewriter &rewriter) const override {
385 // We are not in the block that is part of the unwind sequence.
386 auto constFalse = rewriter.create<LLVM::ConstantOp>(
387 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
388
389 // Mark the end of a coroutine: @llvm.coro.end.
390 auto coroHdl = CoroEndOpAdaptor(operands).handle();
391 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
392 ValueRange({coroHdl, constFalse}));
393 rewriter.eraseOp(op);
394
395 return success();
396 }
397 };
398 } // namespace
399
400 //===----------------------------------------------------------------------===//
401 // Convert async.coro.save to @llvm.coro.save intrinsic.
402 //===----------------------------------------------------------------------===//
403
404 namespace {
405 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
406 public:
407 using OpConversionPattern::OpConversionPattern;
408
409 LogicalResult
matchAndRewrite(CoroSaveOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const410 matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands,
411 ConversionPatternRewriter &rewriter) const override {
412 // Save the coroutine state: @llvm.coro.save
413 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
414 op, AsyncAPI::tokenType(op->getContext()), operands);
415
416 return success();
417 }
418 };
419 } // namespace
420
421 //===----------------------------------------------------------------------===//
422 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
423 //===----------------------------------------------------------------------===//
424
425 namespace {
426
427 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
428 /// branch to the appropriate block based on the return code.
429 ///
430 /// Before:
431 ///
432 /// ^suspended:
433 /// "opBefore"(...)
434 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup
435 /// ^resume:
436 /// "op"(...)
437 /// ^cleanup: ...
438 /// ^suspend: ...
439 ///
440 /// After:
441 ///
442 /// ^suspended:
443 /// "opBefore"(...)
444 /// %suspend = llmv.intr.coro.suspend ...
445 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
446 /// ^resume:
447 /// "op"(...)
448 /// ^cleanup: ...
449 /// ^suspend: ...
450 ///
451 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
452 public:
453 using OpConversionPattern::OpConversionPattern;
454
455 LogicalResult
matchAndRewrite(CoroSuspendOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const456 matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands,
457 ConversionPatternRewriter &rewriter) const override {
458 auto i8 = rewriter.getIntegerType(8);
459 auto i32 = rewriter.getI32Type();
460 auto loc = op->getLoc();
461
462 // This is not a final suspension point.
463 auto constFalse = rewriter.create<LLVM::ConstantOp>(
464 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
465
466 // Suspend a coroutine: @llvm.coro.suspend
467 auto coroState = CoroSuspendOpAdaptor(operands).state();
468 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
469 loc, i8, ValueRange({coroState, constFalse}));
470
471 // Cast return code to i32.
472
473 // After a suspension point decide if we should branch into resume, cleanup
474 // or suspend block of the coroutine (see @llvm.coro.suspend return code
475 // documentation).
476 llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
477 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
478 op.cleanupDest()};
479 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
480 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
481 /*defaultDestination=*/op.suspendDest(),
482 /*defaultOperands=*/ValueRange(),
483 /*caseValues=*/caseValues,
484 /*caseDestinations=*/caseDest,
485 /*caseOperands=*/ArrayRef<ValueRange>(),
486 /*branchWeights=*/ArrayRef<int32_t>());
487
488 return success();
489 }
490 };
491 } // namespace
492
493 //===----------------------------------------------------------------------===//
494 // Convert async.runtime.create to the corresponding runtime API call.
495 //
496 // To allocate storage for the async values we use getelementptr trick:
497 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
498 //===----------------------------------------------------------------------===//
499
500 namespace {
501 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
502 public:
503 using OpConversionPattern::OpConversionPattern;
504
505 LogicalResult
matchAndRewrite(RuntimeCreateOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const506 matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands,
507 ConversionPatternRewriter &rewriter) const override {
508 TypeConverter *converter = getTypeConverter();
509 Type resultType = op->getResultTypes()[0];
510
511 // Tokens and Groups lowered to function calls without arguments.
512 if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) {
513 rewriter.replaceOpWithNewOp<CallOp>(
514 op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup,
515 converter->convertType(resultType));
516 return success();
517 }
518
519 // To create a value we need to compute the storage requirement.
520 if (auto value = resultType.dyn_cast<ValueType>()) {
521 // Returns the size requirements for the async value storage.
522 auto sizeOf = [&](ValueType valueType) -> Value {
523 auto loc = op->getLoc();
524 auto i32 = rewriter.getI32Type();
525
526 auto storedType = converter->convertType(valueType.getValueType());
527 auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
528
529 // %Size = getelementptr %T* null, int 1
530 // %SizeI = ptrtoint %T* %Size to i32
531 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
532 auto one = rewriter.create<LLVM::ConstantOp>(
533 loc, i32, rewriter.getI32IntegerAttr(1));
534 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
535 one.getResult());
536 return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep);
537 };
538
539 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
540 sizeOf(value));
541
542 return success();
543 }
544
545 return rewriter.notifyMatchFailure(op, "unsupported async type");
546 }
547 };
548 } // namespace
549
550 //===----------------------------------------------------------------------===//
551 // Convert async.runtime.set_available to the corresponding runtime API call.
552 //===----------------------------------------------------------------------===//
553
554 namespace {
555 class RuntimeSetAvailableOpLowering
556 : public OpConversionPattern<RuntimeSetAvailableOp> {
557 public:
558 using OpConversionPattern::OpConversionPattern;
559
560 LogicalResult
matchAndRewrite(RuntimeSetAvailableOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const561 matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
562 ConversionPatternRewriter &rewriter) const override {
563 Type operandType = op.operand().getType();
564
565 if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) {
566 rewriter.create<CallOp>(op->getLoc(),
567 operandType.isa<TokenType>() ? kEmplaceToken
568 : kEmplaceValue,
569 TypeRange(), operands);
570 rewriter.eraseOp(op);
571 return success();
572 }
573
574 return rewriter.notifyMatchFailure(op, "unsupported async type");
575 }
576 };
577 } // namespace
578
579 //===----------------------------------------------------------------------===//
580 // Convert async.runtime.await to the corresponding runtime API call.
581 //===----------------------------------------------------------------------===//
582
583 namespace {
584 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
585 public:
586 using OpConversionPattern::OpConversionPattern;
587
588 LogicalResult
matchAndRewrite(RuntimeAwaitOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const589 matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
590 ConversionPatternRewriter &rewriter) const override {
591 Type operandType = op.operand().getType();
592
593 StringRef apiFuncName;
594 if (operandType.isa<TokenType>())
595 apiFuncName = kAwaitToken;
596 else if (operandType.isa<ValueType>())
597 apiFuncName = kAwaitValue;
598 else if (operandType.isa<GroupType>())
599 apiFuncName = kAwaitGroup;
600 else
601 return rewriter.notifyMatchFailure(op, "unsupported async type");
602
603 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
604 rewriter.eraseOp(op);
605
606 return success();
607 }
608 };
609 } // namespace
610
611 //===----------------------------------------------------------------------===//
612 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
613 //===----------------------------------------------------------------------===//
614
615 namespace {
616 class RuntimeAwaitAndResumeOpLowering
617 : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
618 public:
619 using OpConversionPattern::OpConversionPattern;
620
621 LogicalResult
matchAndRewrite(RuntimeAwaitAndResumeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const622 matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
623 ConversionPatternRewriter &rewriter) const override {
624 Type operandType = op.operand().getType();
625
626 StringRef apiFuncName;
627 if (operandType.isa<TokenType>())
628 apiFuncName = kAwaitTokenAndExecute;
629 else if (operandType.isa<ValueType>())
630 apiFuncName = kAwaitValueAndExecute;
631 else if (operandType.isa<GroupType>())
632 apiFuncName = kAwaitAllAndExecute;
633 else
634 return rewriter.notifyMatchFailure(op, "unsupported async type");
635
636 Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
637 Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
638
639 // A pointer to coroutine resume intrinsic wrapper.
640 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
641 auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
642 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
643
644 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
645 ValueRange({operand, handle, resumePtr.res()}));
646 rewriter.eraseOp(op);
647
648 return success();
649 }
650 };
651 } // namespace
652
653 //===----------------------------------------------------------------------===//
654 // Convert async.runtime.resume to the corresponding runtime API call.
655 //===----------------------------------------------------------------------===//
656
657 namespace {
658 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
659 public:
660 using OpConversionPattern::OpConversionPattern;
661
662 LogicalResult
matchAndRewrite(RuntimeResumeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const663 matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands,
664 ConversionPatternRewriter &rewriter) const override {
665 // A pointer to coroutine resume intrinsic wrapper.
666 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
667 auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
668 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
669
670 // Call async runtime API to execute a coroutine in the managed thread.
671 auto coroHdl = RuntimeResumeOpAdaptor(operands).handle();
672 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute,
673 ValueRange({coroHdl, resumePtr.res()}));
674
675 return success();
676 }
677 };
678 } // namespace
679
680 //===----------------------------------------------------------------------===//
681 // Convert async.runtime.store to the corresponding runtime API call.
682 //===----------------------------------------------------------------------===//
683
684 namespace {
685 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
686 public:
687 using OpConversionPattern::OpConversionPattern;
688
689 LogicalResult
matchAndRewrite(RuntimeStoreOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const690 matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands,
691 ConversionPatternRewriter &rewriter) const override {
692 Location loc = op->getLoc();
693
694 // Get a pointer to the async value storage from the runtime.
695 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
696 auto storage = RuntimeStoreOpAdaptor(operands).storage();
697 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
698 TypeRange(i8Ptr), storage);
699
700 // Cast from i8* to the LLVM pointer type.
701 auto valueType = op.value().getType();
702 auto llvmValueType = getTypeConverter()->convertType(valueType);
703 if (!llvmValueType)
704 return rewriter.notifyMatchFailure(
705 op, "failed to convert stored value type to LLVM type");
706
707 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
708 loc, LLVM::LLVMPointerType::get(llvmValueType),
709 storagePtr.getResult(0));
710
711 // Store the yielded value into the async value storage.
712 auto value = RuntimeStoreOpAdaptor(operands).value();
713 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
714
715 // Erase the original runtime store operation.
716 rewriter.eraseOp(op);
717
718 return success();
719 }
720 };
721 } // namespace
722
723 //===----------------------------------------------------------------------===//
724 // Convert async.runtime.load to the corresponding runtime API call.
725 //===----------------------------------------------------------------------===//
726
727 namespace {
728 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
729 public:
730 using OpConversionPattern::OpConversionPattern;
731
732 LogicalResult
matchAndRewrite(RuntimeLoadOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const733 matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands,
734 ConversionPatternRewriter &rewriter) const override {
735 Location loc = op->getLoc();
736
737 // Get a pointer to the async value storage from the runtime.
738 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
739 auto storage = RuntimeLoadOpAdaptor(operands).storage();
740 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
741 TypeRange(i8Ptr), storage);
742
743 // Cast from i8* to the LLVM pointer type.
744 auto valueType = op.result().getType();
745 auto llvmValueType = getTypeConverter()->convertType(valueType);
746 if (!llvmValueType)
747 return rewriter.notifyMatchFailure(
748 op, "failed to convert loaded value type to LLVM type");
749
750 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
751 loc, LLVM::LLVMPointerType::get(llvmValueType),
752 storagePtr.getResult(0));
753
754 // Load from the casted pointer.
755 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
756
757 return success();
758 }
759 };
760 } // namespace
761
762 //===----------------------------------------------------------------------===//
763 // Convert async.runtime.add_to_group to the corresponding runtime API call.
764 //===----------------------------------------------------------------------===//
765
766 namespace {
767 class RuntimeAddToGroupOpLowering
768 : public OpConversionPattern<RuntimeAddToGroupOp> {
769 public:
770 using OpConversionPattern::OpConversionPattern;
771
772 LogicalResult
matchAndRewrite(RuntimeAddToGroupOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const773 matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands,
774 ConversionPatternRewriter &rewriter) const override {
775 // Currently we can only add tokens to the group.
776 if (!op.operand().getType().isa<TokenType>())
777 return rewriter.notifyMatchFailure(op, "only token type is supported");
778
779 // Replace with a runtime API function call.
780 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup,
781 rewriter.getI64Type(), operands);
782
783 return success();
784 }
785 };
786 } // namespace
787
788 //===----------------------------------------------------------------------===//
789 // Async reference counting ops lowering (`async.runtime.add_ref` and
790 // `async.runtime.drop_ref` to the corresponding API calls).
791 //===----------------------------------------------------------------------===//
792
793 namespace {
794 template <typename RefCountingOp>
795 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
796 public:
RefCountingOpLowering(TypeConverter & converter,MLIRContext * ctx,StringRef apiFunctionName)797 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
798 StringRef apiFunctionName)
799 : OpConversionPattern<RefCountingOp>(converter, ctx),
800 apiFunctionName(apiFunctionName) {}
801
802 LogicalResult
matchAndRewrite(RefCountingOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const803 matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands,
804 ConversionPatternRewriter &rewriter) const override {
805 auto count =
806 rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(),
807 rewriter.getI32IntegerAttr(op.count()));
808
809 auto operand = typename RefCountingOp::Adaptor(operands).operand();
810 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
811 ValueRange({operand, count}));
812
813 return success();
814 }
815
816 private:
817 StringRef apiFunctionName;
818 };
819
820 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
821 public:
RuntimeAddRefOpLowering(TypeConverter & converter,MLIRContext * ctx)822 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
823 : RefCountingOpLowering(converter, ctx, kAddRef) {}
824 };
825
826 class RuntimeDropRefOpLowering
827 : public RefCountingOpLowering<RuntimeDropRefOp> {
828 public:
RuntimeDropRefOpLowering(TypeConverter & converter,MLIRContext * ctx)829 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
830 : RefCountingOpLowering(converter, ctx, kDropRef) {}
831 };
832 } // namespace
833
834 //===----------------------------------------------------------------------===//
835 // Convert return operations that return async values from async regions.
836 //===----------------------------------------------------------------------===//
837
838 namespace {
839 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
840 public:
841 using OpConversionPattern::OpConversionPattern;
842
843 LogicalResult
matchAndRewrite(ReturnOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const844 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
845 ConversionPatternRewriter &rewriter) const override {
846 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
847 return success();
848 }
849 };
850 } // namespace
851
852 //===----------------------------------------------------------------------===//
853
854 namespace {
855 struct ConvertAsyncToLLVMPass
856 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
857 void runOnOperation() override;
858 };
859 } // namespace
860
runOnOperation()861 void ConvertAsyncToLLVMPass::runOnOperation() {
862 ModuleOp module = getOperation();
863 MLIRContext *ctx = module->getContext();
864
865 // Add declarations for all functions required by the coroutines lowering.
866 addResumeFunction(module);
867 addAsyncRuntimeApiDeclarations(module);
868 addCRuntimeDeclarations(module);
869
870 // Lower async.runtime and async.coro operations to Async Runtime API and
871 // LLVM coroutine intrinsics.
872
873 // Convert async dialect types and operations to LLVM dialect.
874 AsyncRuntimeTypeConverter converter;
875 OwningRewritePatternList patterns;
876
877 // We use conversion to LLVM type to lower async.runtime load and store
878 // operations.
879 LLVMTypeConverter llvmConverter(ctx);
880 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
881
882 // Convert async types in function signatures and function calls.
883 populateFuncOpTypeConversionPattern(patterns, ctx, converter);
884 populateCallOpTypeConversionPattern(patterns, ctx, converter);
885
886 // Convert return operations inside async.execute regions.
887 patterns.insert<ReturnOpOpConversion>(converter, ctx);
888
889 // Lower async.runtime operations to the async runtime API calls.
890 patterns.insert<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering,
891 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
892 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
893 RuntimeDropRefOpLowering>(converter, ctx);
894
895 // Lower async.runtime operations that rely on LLVM type converter to convert
896 // from async value payload type to the LLVM type.
897 patterns.insert<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
898 RuntimeLoadOpLowering>(llvmConverter, ctx);
899
900 // Lower async coroutine operations to LLVM coroutine intrinsics.
901 patterns.insert<CoroIdOpConversion, CoroBeginOpConversion,
902 CoroFreeOpConversion, CoroEndOpConversion,
903 CoroSaveOpConversion, CoroSuspendOpConversion>(converter,
904 ctx);
905
906 ConversionTarget target(*ctx);
907 target.addLegalOp<ConstantOp>();
908 target.addLegalDialect<LLVM::LLVMDialect>();
909
910 // All operations from Async dialect must be lowered to the runtime API and
911 // LLVM intrinsics calls.
912 target.addIllegalDialect<AsyncDialect>();
913
914 // Add dynamic legality constraints to apply conversions defined above.
915 target.addDynamicallyLegalOp<FuncOp>(
916 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
917 target.addDynamicallyLegalOp<ReturnOp>(
918 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
919 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
920 return converter.isSignatureLegal(op.getCalleeType());
921 });
922
923 if (failed(applyPartialConversion(module, target, std::move(patterns))))
924 signalPassFailure();
925 }
926
927 //===----------------------------------------------------------------------===//
928 // Patterns for structural type conversions for the Async dialect operations.
929 //===----------------------------------------------------------------------===//
930
931 namespace {
932 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
933 public:
934 using OpConversionPattern::OpConversionPattern;
935 LogicalResult
matchAndRewrite(ExecuteOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const936 matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
937 ConversionPatternRewriter &rewriter) const override {
938 ExecuteOp newOp =
939 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
940 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
941 newOp.getRegion().end());
942
943 // Set operands and update block argument and result types.
944 newOp->setOperands(operands);
945 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
946 return failure();
947 for (auto result : newOp.getResults())
948 result.setType(typeConverter->convertType(result.getType()));
949
950 rewriter.replaceOp(op, newOp.getResults());
951 return success();
952 }
953 };
954
955 // Dummy pattern to trigger the appropriate type conversion / materialization.
956 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
957 public:
958 using OpConversionPattern::OpConversionPattern;
959 LogicalResult
matchAndRewrite(AwaitOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const960 matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
961 ConversionPatternRewriter &rewriter) const override {
962 rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
963 return success();
964 }
965 };
966
967 // Dummy pattern to trigger the appropriate type conversion / materialization.
968 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
969 public:
970 using OpConversionPattern::OpConversionPattern;
971 LogicalResult
matchAndRewrite(async::YieldOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const972 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
973 ConversionPatternRewriter &rewriter) const override {
974 rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
975 return success();
976 }
977 };
978 } // namespace
979
createConvertAsyncToLLVMPass()980 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
981 return std::make_unique<ConvertAsyncToLLVMPass>();
982 }
983
populateAsyncStructuralTypeConversionsAndLegality(MLIRContext * context,TypeConverter & typeConverter,OwningRewritePatternList & patterns,ConversionTarget & target)984 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
985 MLIRContext *context, TypeConverter &typeConverter,
986 OwningRewritePatternList &patterns, ConversionTarget &target) {
987 typeConverter.addConversion([&](TokenType type) { return type; });
988 typeConverter.addConversion([&](ValueType type) {
989 return ValueType::get(typeConverter.convertType(type.getValueType()));
990 });
991
992 patterns
993 .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
994 typeConverter, context);
995
996 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
997 [&](Operation *op) { return typeConverter.isLegal(op); });
998 }
999