1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Async/IR/Async.h"
10
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13
14 using namespace mlir;
15 using namespace mlir::async;
16
17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
18
/// Registers all operations and types of the Async dialect.
///
/// Both lists are produced by TableGen: defining `GET_OP_LIST` before
/// including AsyncOps.cpp.inc expands to the op class list, and defining
/// `GET_TYPEDEF_LIST` before including AsyncOpsTypes.cpp.inc expands to the
/// type class list used as template arguments below.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
29
30 //===----------------------------------------------------------------------===//
31 // YieldOp
32 //===----------------------------------------------------------------------===//
33
verify(YieldOp op)34 static LogicalResult verify(YieldOp op) {
35 // Get the underlying value types from async values returned from the
36 // parent `async.execute` operation.
37 auto executeOp = op->getParentOfType<ExecuteOp>();
38 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
39 return result.getType().cast<ValueType>().getValueType();
40 });
41
42 if (op.getOperandTypes() != types)
43 return op.emitOpError("operand types do not match the types returned from "
44 "the parent ExecuteOp");
45
46 return success();
47 }
48
49 //===----------------------------------------------------------------------===//
// ExecuteOp
51 //===----------------------------------------------------------------------===//
52
/// Name of the derived attribute that records the sizes of the two variadic
/// operand groups of `async.execute` (dependencies first, then async value
/// operands). It is set by the builder and the parser, and elided when
/// printing.
static constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
54
getNumRegionInvocations(ArrayRef<Attribute> operands,SmallVectorImpl<int64_t> & countPerRegion)55 void ExecuteOp::getNumRegionInvocations(
56 ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
57 (void)operands;
58 assert(countPerRegion.empty());
59 countPerRegion.push_back(1);
60 }
61
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)62 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
63 ArrayRef<Attribute> operands,
64 SmallVectorImpl<RegionSuccessor> ®ions) {
65 // The `body` region branch back to the parent operation.
66 if (index.hasValue()) {
67 assert(*index == 0);
68 regions.push_back(RegionSuccessor(getResults()));
69 return;
70 }
71
72 // Otherwise the successor is the body region.
73 regions.push_back(RegionSuccessor(&body()));
74 }
75
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ValueRange dependencies,ValueRange operands,BodyBuilderFn bodyBuilder)76 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
77 TypeRange resultTypes, ValueRange dependencies,
78 ValueRange operands, BodyBuilderFn bodyBuilder) {
79
80 result.addOperands(dependencies);
81 result.addOperands(operands);
82
83 // Add derived `operand_segment_sizes` attribute based on parsed operands.
84 int32_t numDependencies = dependencies.size();
85 int32_t numOperands = operands.size();
86 auto operandSegmentSizes = DenseIntElementsAttr::get(
87 VectorType::get({2}, builder.getIntegerType(32)),
88 {numDependencies, numOperands});
89 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
90
91 // First result is always a token, and then `resultTypes` wrapped into
92 // `async.value`.
93 result.addTypes({TokenType::get(result.getContext())});
94 for (Type type : resultTypes)
95 result.addTypes(ValueType::get(type));
96
97 // Add a body region with block arguments as unwrapped async value operands.
98 Region *bodyRegion = result.addRegion();
99 bodyRegion->push_back(new Block);
100 Block &bodyBlock = bodyRegion->front();
101 for (Value operand : operands) {
102 auto valueType = operand.getType().dyn_cast<ValueType>();
103 bodyBlock.addArgument(valueType ? valueType.getValueType()
104 : operand.getType());
105 }
106
107 // Create the default terminator if the builder is not provided and if the
108 // expected result is empty. Otherwise, leave this to the caller
109 // because we don't know which values to return from the execute op.
110 if (resultTypes.empty() && !bodyBuilder) {
111 OpBuilder::InsertionGuard guard(builder);
112 builder.setInsertionPointToStart(&bodyBlock);
113 builder.create<async::YieldOp>(result.location, ValueRange());
114 } else if (bodyBuilder) {
115 OpBuilder::InsertionGuard guard(builder);
116 builder.setInsertionPointToStart(&bodyBlock);
117 bodyBuilder(builder, result.location, bodyBlock.getArguments());
118 }
119 }
120
print(OpAsmPrinter & p,ExecuteOp op)121 static void print(OpAsmPrinter &p, ExecuteOp op) {
122 p << op.getOperationName();
123
124 // [%tokens,...]
125 if (!op.dependencies().empty())
126 p << " [" << op.dependencies() << "]";
127
128 // (%value as %unwrapped: !async.value<!arg.type>, ...)
129 if (!op.operands().empty()) {
130 p << " (";
131 Block *entry = op.body().empty() ? nullptr : &op.body().front();
132 llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
133 Value argument = entry ? entry->getArgument(n++) : Value();
134 p << operand << " as " << argument << ": " << operand.getType();
135 });
136 p << ")";
137 }
138
139 // -> (!async.value<!return.type>, ...)
140 p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes()));
141 p.printOptionalAttrDictWithKeyword(op->getAttrs(),
142 {kOperandSegmentSizesAttr});
143 p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
144 }
145
parseExecuteOp(OpAsmParser & parser,OperationState & result)146 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
147 MLIRContext *ctx = result.getContext();
148
149 // Sizes of parsed variadic operands, will be updated below after parsing.
150 int32_t numDependencies = 0;
151 int32_t numOperands = 0;
152
153 auto tokenTy = TokenType::get(ctx);
154
155 // Parse dependency tokens.
156 if (succeeded(parser.parseOptionalLSquare())) {
157 SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
158 if (parser.parseOperandList(tokenArgs) ||
159 parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
160 parser.parseRSquare())
161 return failure();
162
163 numDependencies = tokenArgs.size();
164 }
165
166 // Parse async value operands (%value as %unwrapped : !async.value<!type>).
167 SmallVector<OpAsmParser::OperandType, 4> valueArgs;
168 SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
169 SmallVector<Type, 4> valueTypes;
170 SmallVector<Type, 4> unwrappedTypes;
171
172 if (succeeded(parser.parseOptionalLParen())) {
173 auto argsLoc = parser.getCurrentLocation();
174
175 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
176 auto parseAsyncValueArg = [&]() -> ParseResult {
177 if (parser.parseOperand(valueArgs.emplace_back()) ||
178 parser.parseKeyword("as") ||
179 parser.parseOperand(unwrappedArgs.emplace_back()) ||
180 parser.parseColonType(valueTypes.emplace_back()))
181 return failure();
182
183 auto valueTy = valueTypes.back().dyn_cast<ValueType>();
184 unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
185
186 return success();
187 };
188
189 // If the next token is `)` skip async value arguments parsing.
190 if (failed(parser.parseOptionalRParen())) {
191 do {
192 if (parseAsyncValueArg())
193 return failure();
194 } while (succeeded(parser.parseOptionalComma()));
195
196 if (parser.parseRParen() ||
197 parser.resolveOperands(valueArgs, valueTypes, argsLoc,
198 result.operands))
199 return failure();
200 }
201
202 numOperands = valueArgs.size();
203 }
204
205 // Add derived `operand_segment_sizes` attribute based on parsed operands.
206 auto operandSegmentSizes = DenseIntElementsAttr::get(
207 VectorType::get({2}, parser.getBuilder().getI32Type()),
208 {numDependencies, numOperands});
209 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
210
211 // Parse the types of results returned from the async execute op.
212 SmallVector<Type, 4> resultTypes;
213 if (parser.parseOptionalArrowTypeList(resultTypes))
214 return failure();
215
216 // Async execute first result is always a completion token.
217 parser.addTypeToList(tokenTy, result.types);
218 parser.addTypesToList(resultTypes, result.types);
219
220 // Parse operation attributes.
221 NamedAttrList attrs;
222 if (parser.parseOptionalAttrDictWithKeyword(attrs))
223 return failure();
224 result.addAttributes(attrs);
225
226 // Parse asynchronous region.
227 Region *body = result.addRegion();
228 if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
229 /*argTypes=*/{unwrappedTypes},
230 /*enableNameShadowing=*/false))
231 return failure();
232
233 return success();
234 }
235
verify(ExecuteOp op)236 static LogicalResult verify(ExecuteOp op) {
237 // Unwrap async.execute value operands types.
238 auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
239 return operand.getType().cast<ValueType>().getValueType();
240 });
241
242 // Verify that unwrapped argument types matches the body region arguments.
243 if (op.body().getArgumentTypes() != unwrappedTypes)
244 return op.emitOpError("async body region argument types do not match the "
245 "execute operation arguments types");
246
247 return success();
248 }
249
250 //===----------------------------------------------------------------------===//
// CreateGroupOp
252 //===----------------------------------------------------------------------===//
253
canonicalize(CreateGroupOp op,PatternRewriter & rewriter)254 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
255 PatternRewriter &rewriter) {
256 // Find all `await_all` users of the group.
257 llvm::SmallVector<AwaitAllOp> awaitAllUsers;
258
259 auto isAwaitAll = [&](Operation *op) -> bool {
260 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
261 awaitAllUsers.push_back(awaitAll);
262 return true;
263 }
264 return false;
265 };
266
267 // Check if all users of the group are `await_all` operations.
268 if (!llvm::all_of(op->getUsers(), isAwaitAll))
269 return failure();
270
271 // If group is only awaited without adding anything to it, we can safely erase
272 // the create operation and all users.
273 for (AwaitAllOp awaitAll : awaitAllUsers)
274 rewriter.eraseOp(awaitAll);
275 rewriter.eraseOp(op);
276
277 return success();
278 }
279
280 //===----------------------------------------------------------------------===//
// AwaitOp
282 //===----------------------------------------------------------------------===//
283
build(OpBuilder & builder,OperationState & result,Value operand,ArrayRef<NamedAttribute> attrs)284 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
285 ArrayRef<NamedAttribute> attrs) {
286 result.addOperands({operand});
287 result.attributes.append(attrs.begin(), attrs.end());
288
289 // Add unwrapped async.value type to the returned values types.
290 if (auto valueType = operand.getType().dyn_cast<ValueType>())
291 result.addTypes(valueType.getValueType());
292 }
293
parseAwaitResultType(OpAsmParser & parser,Type & operandType,Type & resultType)294 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
295 Type &resultType) {
296 if (parser.parseType(operandType))
297 return failure();
298
299 // Add unwrapped async.value type to the returned values types.
300 if (auto valueType = operandType.dyn_cast<ValueType>())
301 resultType = valueType.getValueType();
302
303 return success();
304 }
305
printAwaitResultType(OpAsmPrinter & p,Operation * op,Type operandType,Type resultType)306 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
307 Type operandType, Type resultType) {
308 p << operandType;
309 }
310
verify(AwaitOp op)311 static LogicalResult verify(AwaitOp op) {
312 Type argType = op.operand().getType();
313
314 // Awaiting on a token does not have any results.
315 if (argType.isa<TokenType>() && !op.getResultTypes().empty())
316 return op.emitOpError("awaiting on a token must have empty result");
317
318 // Awaiting on a value unwraps the async value type.
319 if (auto value = argType.dyn_cast<ValueType>()) {
320 if (*op.getResultType() != value.getValueType())
321 return op.emitOpError()
322 << "result type " << *op.getResultType()
323 << " does not match async value type " << value.getValueType();
324 }
325
326 return success();
327 }
328
329 //===----------------------------------------------------------------------===//
330 // TableGen'd op method definitions
331 //===----------------------------------------------------------------------===//
332
333 #define GET_OP_CLASSES
334 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
335
336 //===----------------------------------------------------------------------===//
337 // TableGen'd type method definitions
338 //===----------------------------------------------------------------------===//
339
340 #define GET_TYPEDEF_CLASSES
341 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
342
print(DialectAsmPrinter & printer) const343 void ValueType::print(DialectAsmPrinter &printer) const {
344 printer << getMnemonic();
345 printer << "<";
346 printer.printType(getValueType());
347 printer << '>';
348 }
349
parse(mlir::MLIRContext *,mlir::DialectAsmParser & parser)350 Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) {
351 Type ty;
352 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
353 parser.emitError(parser.getNameLoc(), "failed to parse async value type");
354 return Type();
355 }
356 return ValueType::get(ty);
357 }
358
359 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const360 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
361 if (failed(generatedTypePrinter(type, os)))
362 llvm_unreachable("unexpected 'async' type kind");
363 }
364
365 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const366 Type AsyncDialect::parseType(DialectAsmParser &parser) const {
367 StringRef typeTag;
368 if (parser.parseKeyword(&typeTag))
369 return Type();
370 Type genType;
371 auto parseResult = generatedTypeParser(parser.getBuilder().getContext(),
372 parser, typeTag, genType);
373 if (parseResult.hasValue())
374 return genType;
375 parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag;
376 return {};
377 }
378