1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Async/IR/Async.h"
10 
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::async;
16 
17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
18 
/// Registers every operation and type of the Async dialect. Both lists are
/// generated by TableGen and spliced in from the GET_OP_LIST section of
/// AsyncOps.cpp.inc and the GET_TYPEDEF_LIST section of AsyncOpsTypes.cpp.inc.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
29 
30 //===----------------------------------------------------------------------===//
31 // YieldOp
32 //===----------------------------------------------------------------------===//
33 
verify(YieldOp op)34 static LogicalResult verify(YieldOp op) {
35   // Get the underlying value types from async values returned from the
36   // parent `async.execute` operation.
37   auto executeOp = op->getParentOfType<ExecuteOp>();
38   auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
39     return result.getType().cast<ValueType>().getValueType();
40   });
41 
42   if (op.getOperandTypes() != types)
43     return op.emitOpError("operand types do not match the types returned from "
44                           "the parent ExecuteOp");
45 
46   return success();
47 }
48 
49 //===----------------------------------------------------------------------===//
50 /// ExecuteOp
51 //===----------------------------------------------------------------------===//
52 
/// Name of the derived attribute holding the sizes of the two variadic operand
/// groups of `async.execute` (dependencies, then async value operands). It is
/// attached by ExecuteOp::build and parseExecuteOp, and elided when printing.
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
54 
getNumRegionInvocations(ArrayRef<Attribute> operands,SmallVectorImpl<int64_t> & countPerRegion)55 void ExecuteOp::getNumRegionInvocations(
56     ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
57   (void)operands;
58   assert(countPerRegion.empty());
59   countPerRegion.push_back(1);
60 }
61 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)62 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
63                                     ArrayRef<Attribute> operands,
64                                     SmallVectorImpl<RegionSuccessor> &regions) {
65   // The `body` region branch back to the parent operation.
66   if (index.hasValue()) {
67     assert(*index == 0);
68     regions.push_back(RegionSuccessor(getResults()));
69     return;
70   }
71 
72   // Otherwise the successor is the body region.
73   regions.push_back(RegionSuccessor(&body()));
74 }
75 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ValueRange dependencies,ValueRange operands,BodyBuilderFn bodyBuilder)76 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
77                       TypeRange resultTypes, ValueRange dependencies,
78                       ValueRange operands, BodyBuilderFn bodyBuilder) {
79 
80   result.addOperands(dependencies);
81   result.addOperands(operands);
82 
83   // Add derived `operand_segment_sizes` attribute based on parsed operands.
84   int32_t numDependencies = dependencies.size();
85   int32_t numOperands = operands.size();
86   auto operandSegmentSizes = DenseIntElementsAttr::get(
87       VectorType::get({2}, builder.getIntegerType(32)),
88       {numDependencies, numOperands});
89   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
90 
91   // First result is always a token, and then `resultTypes` wrapped into
92   // `async.value`.
93   result.addTypes({TokenType::get(result.getContext())});
94   for (Type type : resultTypes)
95     result.addTypes(ValueType::get(type));
96 
97   // Add a body region with block arguments as unwrapped async value operands.
98   Region *bodyRegion = result.addRegion();
99   bodyRegion->push_back(new Block);
100   Block &bodyBlock = bodyRegion->front();
101   for (Value operand : operands) {
102     auto valueType = operand.getType().dyn_cast<ValueType>();
103     bodyBlock.addArgument(valueType ? valueType.getValueType()
104                                     : operand.getType());
105   }
106 
107   // Create the default terminator if the builder is not provided and if the
108   // expected result is empty. Otherwise, leave this to the caller
109   // because we don't know which values to return from the execute op.
110   if (resultTypes.empty() && !bodyBuilder) {
111     OpBuilder::InsertionGuard guard(builder);
112     builder.setInsertionPointToStart(&bodyBlock);
113     builder.create<async::YieldOp>(result.location, ValueRange());
114   } else if (bodyBuilder) {
115     OpBuilder::InsertionGuard guard(builder);
116     builder.setInsertionPointToStart(&bodyBlock);
117     bodyBuilder(builder, result.location, bodyBlock.getArguments());
118   }
119 }
120 
print(OpAsmPrinter & p,ExecuteOp op)121 static void print(OpAsmPrinter &p, ExecuteOp op) {
122   p << op.getOperationName();
123 
124   // [%tokens,...]
125   if (!op.dependencies().empty())
126     p << " [" << op.dependencies() << "]";
127 
128   // (%value as %unwrapped: !async.value<!arg.type>, ...)
129   if (!op.operands().empty()) {
130     p << " (";
131     Block *entry = op.body().empty() ? nullptr : &op.body().front();
132     llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
133       Value argument = entry ? entry->getArgument(n++) : Value();
134       p << operand << " as " << argument << ": " << operand.getType();
135     });
136     p << ")";
137   }
138 
139   // -> (!async.value<!return.type>, ...)
140   p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes()));
141   p.printOptionalAttrDictWithKeyword(op->getAttrs(),
142                                      {kOperandSegmentSizesAttr});
143   p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
144 }
145 
parseExecuteOp(OpAsmParser & parser,OperationState & result)146 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
147   MLIRContext *ctx = result.getContext();
148 
149   // Sizes of parsed variadic operands, will be updated below after parsing.
150   int32_t numDependencies = 0;
151   int32_t numOperands = 0;
152 
153   auto tokenTy = TokenType::get(ctx);
154 
155   // Parse dependency tokens.
156   if (succeeded(parser.parseOptionalLSquare())) {
157     SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
158     if (parser.parseOperandList(tokenArgs) ||
159         parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
160         parser.parseRSquare())
161       return failure();
162 
163     numDependencies = tokenArgs.size();
164   }
165 
166   // Parse async value operands (%value as %unwrapped : !async.value<!type>).
167   SmallVector<OpAsmParser::OperandType, 4> valueArgs;
168   SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
169   SmallVector<Type, 4> valueTypes;
170   SmallVector<Type, 4> unwrappedTypes;
171 
172   if (succeeded(parser.parseOptionalLParen())) {
173     auto argsLoc = parser.getCurrentLocation();
174 
175     // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
176     auto parseAsyncValueArg = [&]() -> ParseResult {
177       if (parser.parseOperand(valueArgs.emplace_back()) ||
178           parser.parseKeyword("as") ||
179           parser.parseOperand(unwrappedArgs.emplace_back()) ||
180           parser.parseColonType(valueTypes.emplace_back()))
181         return failure();
182 
183       auto valueTy = valueTypes.back().dyn_cast<ValueType>();
184       unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
185 
186       return success();
187     };
188 
189     // If the next token is `)` skip async value arguments parsing.
190     if (failed(parser.parseOptionalRParen())) {
191       do {
192         if (parseAsyncValueArg())
193           return failure();
194       } while (succeeded(parser.parseOptionalComma()));
195 
196       if (parser.parseRParen() ||
197           parser.resolveOperands(valueArgs, valueTypes, argsLoc,
198                                  result.operands))
199         return failure();
200     }
201 
202     numOperands = valueArgs.size();
203   }
204 
205   // Add derived `operand_segment_sizes` attribute based on parsed operands.
206   auto operandSegmentSizes = DenseIntElementsAttr::get(
207       VectorType::get({2}, parser.getBuilder().getI32Type()),
208       {numDependencies, numOperands});
209   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
210 
211   // Parse the types of results returned from the async execute op.
212   SmallVector<Type, 4> resultTypes;
213   if (parser.parseOptionalArrowTypeList(resultTypes))
214     return failure();
215 
216   // Async execute first result is always a completion token.
217   parser.addTypeToList(tokenTy, result.types);
218   parser.addTypesToList(resultTypes, result.types);
219 
220   // Parse operation attributes.
221   NamedAttrList attrs;
222   if (parser.parseOptionalAttrDictWithKeyword(attrs))
223     return failure();
224   result.addAttributes(attrs);
225 
226   // Parse asynchronous region.
227   Region *body = result.addRegion();
228   if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
229                          /*argTypes=*/{unwrappedTypes},
230                          /*enableNameShadowing=*/false))
231     return failure();
232 
233   return success();
234 }
235 
verify(ExecuteOp op)236 static LogicalResult verify(ExecuteOp op) {
237   // Unwrap async.execute value operands types.
238   auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
239     return operand.getType().cast<ValueType>().getValueType();
240   });
241 
242   // Verify that unwrapped argument types matches the body region arguments.
243   if (op.body().getArgumentTypes() != unwrappedTypes)
244     return op.emitOpError("async body region argument types do not match the "
245                           "execute operation arguments types");
246 
247   return success();
248 }
249 
250 //===----------------------------------------------------------------------===//
251 /// CreateGroupOp
252 //===----------------------------------------------------------------------===//
253 
canonicalize(CreateGroupOp op,PatternRewriter & rewriter)254 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
255                                           PatternRewriter &rewriter) {
256   // Find all `await_all` users of the group.
257   llvm::SmallVector<AwaitAllOp> awaitAllUsers;
258 
259   auto isAwaitAll = [&](Operation *op) -> bool {
260     if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
261       awaitAllUsers.push_back(awaitAll);
262       return true;
263     }
264     return false;
265   };
266 
267   // Check if all users of the group are `await_all` operations.
268   if (!llvm::all_of(op->getUsers(), isAwaitAll))
269     return failure();
270 
271   // If group is only awaited without adding anything to it, we can safely erase
272   // the create operation and all users.
273   for (AwaitAllOp awaitAll : awaitAllUsers)
274     rewriter.eraseOp(awaitAll);
275   rewriter.eraseOp(op);
276 
277   return success();
278 }
279 
280 //===----------------------------------------------------------------------===//
281 /// AwaitOp
282 //===----------------------------------------------------------------------===//
283 
build(OpBuilder & builder,OperationState & result,Value operand,ArrayRef<NamedAttribute> attrs)284 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
285                     ArrayRef<NamedAttribute> attrs) {
286   result.addOperands({operand});
287   result.attributes.append(attrs.begin(), attrs.end());
288 
289   // Add unwrapped async.value type to the returned values types.
290   if (auto valueType = operand.getType().dyn_cast<ValueType>())
291     result.addTypes(valueType.getValueType());
292 }
293 
parseAwaitResultType(OpAsmParser & parser,Type & operandType,Type & resultType)294 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
295                                         Type &resultType) {
296   if (parser.parseType(operandType))
297     return failure();
298 
299   // Add unwrapped async.value type to the returned values types.
300   if (auto valueType = operandType.dyn_cast<ValueType>())
301     resultType = valueType.getValueType();
302 
303   return success();
304 }
305 
printAwaitResultType(OpAsmPrinter & p,Operation * op,Type operandType,Type resultType)306 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
307                                  Type operandType, Type resultType) {
308   p << operandType;
309 }
310 
verify(AwaitOp op)311 static LogicalResult verify(AwaitOp op) {
312   Type argType = op.operand().getType();
313 
314   // Awaiting on a token does not have any results.
315   if (argType.isa<TokenType>() && !op.getResultTypes().empty())
316     return op.emitOpError("awaiting on a token must have empty result");
317 
318   // Awaiting on a value unwraps the async value type.
319   if (auto value = argType.dyn_cast<ValueType>()) {
320     if (*op.getResultType() != value.getValueType())
321       return op.emitOpError()
322              << "result type " << *op.getResultType()
323              << " does not match async value type " << value.getValueType();
324   }
325 
326   return success();
327 }
328 
329 //===----------------------------------------------------------------------===//
330 // TableGen'd op method definitions
331 //===----------------------------------------------------------------------===//
332 
333 #define GET_OP_CLASSES
334 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
335 
336 //===----------------------------------------------------------------------===//
337 // TableGen'd type method definitions
338 //===----------------------------------------------------------------------===//
339 
340 #define GET_TYPEDEF_CLASSES
341 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
342 
print(DialectAsmPrinter & printer) const343 void ValueType::print(DialectAsmPrinter &printer) const {
344   printer << getMnemonic();
345   printer << "<";
346   printer.printType(getValueType());
347   printer << '>';
348 }
349 
parse(mlir::MLIRContext *,mlir::DialectAsmParser & parser)350 Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) {
351   Type ty;
352   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
353     parser.emitError(parser.getNameLoc(), "failed to parse async value type");
354     return Type();
355   }
356   return ValueType::get(ty);
357 }
358 
359 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const360 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
361   if (failed(generatedTypePrinter(type, os)))
362     llvm_unreachable("unexpected 'async' type kind");
363 }
364 
365 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const366 Type AsyncDialect::parseType(DialectAsmParser &parser) const {
367   StringRef typeTag;
368   if (parser.parseKeyword(&typeTag))
369     return Type();
370   Type genType;
371   auto parseResult = generatedTypeParser(parser.getBuilder().getContext(),
372                                          parser, typeTag, genType);
373   if (parseResult.hasValue())
374     return genType;
375   parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag;
376   return {};
377 }
378