//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the scf.parallel to scf.for + async.execute conversion
// pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-parallel-for"

namespace {

// Rewrite scf.parallel operation into multiple concurrent async.execute
// operations over non-overlapping subranges of the original loop.
//
// Example:
//
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
//     "do_some_compute"(%i, %j): () -> ()
//   }
//
// Converted to:
//
//   // Parallel compute function that executes the parallel body region for
//   // a subset of the parallel iteration space defined by the one-dimensional
//   // compute block index.
//   func parallel_compute_function(%block_index : index, %block_size : index,
//                                  <parallel operation properties>, ...) {
//     // Compute multi-dimensional loop bounds for %block_index.
//     %block_lbi, %block_lbj = ...
//     %block_ubi, %block_ubj = ...
//
//     // Clone parallel operation body into the scf.for loop nest.
//     scf.for %i = %block_lbi to %block_ubi {
//       scf.for %j = %block_lbj to %block_ubj {
//         "do_some_compute"(%i, %j): () -> ()
//       }
//     }
//   }
//
// And a dispatch function that depends on the `asyncDispatch` option.
//
// When async dispatch is on (pseudocode):
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//     // Keep splitting the block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
//   // Launch async dispatch for the [0, block_count) range.
//   call @async_dispatch(%c0, %block_count);
//
// When async dispatch is off:
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   scf.for %block_index = %c0 to %block_count {
//      call @parallel_compute_fn(%block_index, %block_size, ...)
//   }
//
struct AsyncParallelForPass
    : public AsyncParallelForBase<AsyncParallelForPass> {
  AsyncParallelForPass() = default;

  AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
                       int32_t targetBlockSize) {
    this->asyncDispatch = asyncDispatch;
    this->numWorkerThreads = numWorkerThreads;
    this->targetBlockSize = targetBlockSize;
  }

  void runOnOperation() override;
};

struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
                          int32_t numWorkerThreads, int32_t targetBlockSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads), targetBlockSize(targetBlockSize) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  bool asyncDispatch;
  int32_t numWorkerThreads;
  int32_t targetBlockSize;
};

struct ParallelComputeFunctionType {
  FunctionType type;
  llvm::SmallVector<Value> captures;
};

struct ParallelComputeFunction {
  FuncOp func;
  llvm::SmallVector<Value> captures;
};

} // namespace

// Converts a one-dimensional iteration index in the [0, tripCount) interval
// into a multi-dimensional iteration coordinate.
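//
// For example (hypothetical values): with tripCounts = [5, 10] and index = 42,
// the result is coords = [4, 2], since 42 = 4 * 10 + 2.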
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
                                      ArrayRef<Value> tripCounts) {
  assert(!tripCounts.empty() && "tripCounts must not be empty");
  SmallVector<Value> coords(tripCounts.size());

  for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
    coords[i] = b.create<SignedRemIOp>(index, tripCounts[i]);
    index = b.create<SignedDivIOp>(index, tripCounts[i]);
  }

  return coords;
}

// Returns a function type and implicit captures for a parallel compute
// function. We need the list of implicit captures to set up the block and
// value mapping when we clone the body of the parallel operation.
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
  // Values implicitly captured by the parallel operation.
  llvm::SetVector<Value> captures;
  getUsedValuesDefinedAbove(op.region(), op.region(), captures);

  llvm::SmallVector<Type> inputs;
  inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());

  Type indexTy = rewriter.getIndexType();

  // One-dimensional iteration space defined by the block index and size.
  inputs.push_back(indexTy); // blockIndex
  inputs.push_back(indexTy); // blockSize

  // Multi-dimensional parallel iteration space defined by the loop trip counts.
  for (unsigned i = 0; i < op.getNumLoops(); ++i)
    inputs.push_back(indexTy); // loop tripCount

  // Parallel operation lower bound, upper bound and step.
  for (unsigned i = 0; i < op.getNumLoops(); ++i) {
    inputs.push_back(indexTy); // lower bound
    inputs.push_back(indexTy); // upper bound
    inputs.push_back(indexTy); // step
  }

  // Types of the implicit captures.
  for (Value capture : captures)
    inputs.push_back(capture.getType());

  // Convert captures to a vector for later convenience.
  SmallVector<Value> capturesVector(captures.begin(), captures.end());
  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
}
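
// As a sketch (hypothetical 2-D parallel loop with one captured memref), the
// compute function built from this type would look like:
//
//   func private @parallel_compute_fn(
//       %block_index : index, %block_size : index,
//       %tripCount0 : index, %tripCount1 : index,
//       %lb0 : index, %lb1 : index,
//       %ub0 : index, %ub1 : index,
//       %step0 : index, %step1 : index,
//       %capture0 : memref<?xf32>)
//
// The dispatch functions below append operands in this grouped order: trip
// counts, lower bounds, upper bounds, steps, and finally the implicit
// captures.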

// Creates a parallel compute function from the parallel operation.
static ParallelComputeFunction
createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  ModuleOp module = op->getParentOfType<ModuleOp>();

  ParallelComputeFunctionType computeFuncType =
      getParallelComputeFunctionType(op, rewriter);

  FunctionType type = computeFuncType.type;
  FuncOp func = FuncOp::create(op.getLoc(), "parallel_compute_fn", type);
  func.setPrivate();

  // Insert the function into the module symbol table and assign it a unique
  // name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func);

  // Create the function entry block.
  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
  b.setInsertionPointToEnd(block);

  unsigned offset = 0; // argument offset for arguments decoding

  // Returns `numArguments` arguments starting from `offset` and advances
  // `offset` past the returned arguments.
  auto getArguments = [&](unsigned numArguments) -> ArrayRef<Value> {
    auto args = block->getArguments();
    auto slice = args.drop_front(offset).take_front(numArguments);
    offset += numArguments;
    return {slice.begin(), slice.end()};
  };

  // Block iteration position defined by the block index and size.
  Value blockIndex = block->getArgument(offset++);
  Value blockSize = block->getArgument(offset++);

  // Constants used below.
  Value c0 = b.create<ConstantIndexOp>(0);
  Value c1 = b.create<ConstantIndexOp>(1);

  // Multi-dimensional parallel iteration space defined by the loop trip counts.
  ArrayRef<Value> tripCounts = getArguments(op.getNumLoops());

  // Compute a product of trip counts to get the size of the flattened
  // one-dimensional iteration space.
  Value tripCount = tripCounts[0];
  for (unsigned i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
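
  // For example (hypothetical sizes): tripCounts = [10, 20] flattens into a
  // one-dimensional iteration space of tripCount = 200 indices.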

  // Parallel operation lower bound and step.
  ArrayRef<Value> lowerBound = getArguments(op.getNumLoops());
  offset += op.getNumLoops(); // skip upper bound arguments
  ArrayRef<Value> step = getArguments(op.getNumLoops());

  // Remaining arguments are implicit captures of the parallel operation.
  ArrayRef<Value> captures = getArguments(block->getNumArguments() - offset);

  // Find the one-dimensional iteration bounds [blockFirstIndex, blockLastIndex]:
  //   blockFirstIndex = blockIndex * blockSize
  Value blockFirstIndex = b.create<MulIOp>(blockIndex, blockSize);

  // The last one-dimensional index in the block defined by the `blockIndex`:
  //   blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
  Value blockEnd0 = b.create<AddIOp>(blockFirstIndex, blockSize);
  Value blockEnd1 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd0, tripCount);
  Value blockEnd2 = b.create<SelectOp>(blockEnd1, tripCount, blockEnd0);
  Value blockLastIndex = b.create<SubIOp>(blockEnd2, c1);
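
  // As a worked example (hypothetical values): with tripCount = 10,
  // blockSize = 4 and blockIndex = 2, blockFirstIndex = 8, blockEnd0 = 12 is
  // clamped to tripCount = 10, and blockLastIndex = 9.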

  // Convert one-dimensional indices to multi-dimensional coordinates.
  auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
  auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);

  // Compute loop upper bounds derived from the block last coordinates:
  //   blockEndCoord[i] = blockLastCoord[i] + 1
  //
  // Block first and last coordinates can be the same along the outer compute
  // dimension when the inner compute dimension contains multiple blocks.
  SmallVector<Value> blockEndCoord(op.getNumLoops());
  for (size_t i = 0; i < blockLastCoord.size(); ++i)
    blockEndCoord[i] = b.create<AddIOp>(blockLastCoord[i], c1);

  // Construct a loop nest out of scf.for operations that will iterate over
  // all coordinates in the [blockFirstCoord, blockLastCoord] range.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Parallel region induction variables computed from the multi-dimensional
  // iteration coordinate using parallel operation bounds and step:
  //
  //   computeBlockInductionVars[loopIdx] =
  //       lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx]
  SmallVector<Value> computeBlockInductionVars(op.getNumLoops());

  // We need to know if we are in the first or last iteration of the
  // multi-dimensional loop for each loop in the nest, so we can decide which
  // loop bounds to use for the nested loops: bounds defined by the compute
  // block interval, or bounds defined by the parallel operation.
  //
  // Example: 2d parallel operation
  //                   i   j
  //   loop sizes:   [50, 50]
  //   first coord:  [25, 25]
  //   last coord:   [30, 30]
  //
  // If `i` is equal to 25 then iteration over `j` should start at 25; when `i`
  // is between 25 and 30 it should start at 0. The upper bound for `j` should
  // be 50, except when `i` is equal to 30, then it should also be 30.
  //
  // The value at the i-th position specifies if all loops in the [0, i] range
  // of the loop nest are in their first/last iteration.
  SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
  SmallVector<Value> isBlockLastCoord(op.getNumLoops());

  // Builds the inner loop nest that does the actual work for a single compute
  // block (the compute function itself is launched concurrently from
  // async.execute operations).
  LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
                        ValueRange args) {
      ImplicitLocOpBuilder nb(loc, nestedBuilder);

      // Compute induction variable for `loopIdx`.
      computeBlockInductionVars[loopIdx] = nb.create<AddIOp>(
          lowerBound[loopIdx], nb.create<MulIOp>(iv, step[loopIdx]));

      // Check if we are inside the first or last iteration of the loop.
      isBlockFirstCoord[loopIdx] =
          nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
      isBlockLastCoord[loopIdx] =
          nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);

      // Check if the previous loop is in its first or last iteration.
      if (loopIdx > 0) {
        isBlockFirstCoord[loopIdx] = nb.create<AndOp>(
            isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
        isBlockLastCoord[loopIdx] = nb.create<AndOp>(
            isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
      }

      // Keep building the loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        // Select the nested loop lower/upper bounds depending on our position
        // in the multi-dimensional iteration space.
        auto lb = nb.create<SelectOp>(isBlockFirstCoord[loopIdx],
                                      blockFirstCoord[loopIdx + 1], c0);

        auto ub = nb.create<SelectOp>(isBlockLastCoord[loopIdx],
                                      blockEndCoord[loopIdx + 1],
                                      tripCounts[loopIdx + 1]);

        nb.create<scf::ForOp>(lb, ub, c1, ValueRange(),
                              workLoopBuilder(loopIdx + 1));
        nb.create<scf::YieldOp>(loc);
        return;
      }

      // Copy the body of the parallel op into the inner-most loop.
      BlockAndValueMapping mapping;
      mapping.map(op.getInductionVars(), computeBlockInductionVars);
      mapping.map(computeFuncType.captures, captures);

      for (auto &bodyOp : op.getLoopBody().getOps())
        nb.clone(bodyOp, mapping);
    };
  };

  b.create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
                       workLoopBuilder(0));
  b.create<ReturnOp>(ValueRange());

  return {func, std::move(computeFuncType.captures)};
}

// Creates a recursive async dispatch function for the given parallel compute
// function. The dispatch function keeps splitting the block range into halves
// until it reaches a single block, and then executes it inline.
//
// Function pseudocode (mix of C++ and MLIR):
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//
//     // Keep splitting the block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
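// As an illustration (hypothetical block count of 8): the root call
// @async_dispatch(0, 8) asynchronously dispatches the [4, 8) range, then
// [2, 4), then [1, 2), and finally executes block 0 inline. Each dispatched
// subrange splits the same way, so every block is eventually processed by
// exactly one @parallel_compute_fn call.
//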
static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
                                          PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = computeFunc.func.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);

  ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();

  ArrayRef<Type> computeFuncInputTypes =
      computeFunc.func.type().cast<FunctionType>().getInputs();

  // Compared to the parallel compute function, the async dispatch function
  // takes an additional !async.group argument. Also, instead of a single
  // `blockIndex` it takes `blockStart` and `blockEnd` arguments to define the
  // range of dispatched blocks.
  SmallVector<Type> inputTypes;
  inputTypes.push_back(async::GroupType::get(rewriter.getContext()));
  inputTypes.push_back(rewriter.getIndexType()); // add blockStart argument
  inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());

  FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
  FuncOp func = FuncOp::create(loc, "async_dispatch_fn", type);
  func.setPrivate();

  // Insert the function into the module symbol table and assign it a unique
  // name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func);

  // Create the function entry block.
  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
  b.setInsertionPointToEnd(block);

  Type indexTy = b.getIndexType();
  Value c1 = b.create<ConstantIndexOp>(1);
  Value c2 = b.create<ConstantIndexOp>(2);

  // Get the async group that will track async dispatch completion.
  Value group = block->getArgument(0);

  // Get the block iteration range: [blockStart, blockEnd)
  Value blockStart = block->getArgument(1);
  Value blockEnd = block->getArgument(2);

  // Create a work-splitting while loop for the [blockStart, blockEnd) range.
  SmallVector<Type> types = {indexTy, indexTy};
  SmallVector<Value> operands = {blockStart, blockEnd};

  // Create a recursive dispatch loop.
  scf::WhileOp whileOp = b.create<scf::WhileOp>(types, operands);
  Block *before = b.createBlock(&whileOp.before(), {}, types);
  Block *after = b.createBlock(&whileOp.after(), {}, types);

  // Set up the dispatch loop condition block: decide if we need to go into the
  // `after` block and launch one more async dispatch.
  {
    b.setInsertionPointToEnd(before);
    Value start = before->getArgument(0);
    Value end = before->getArgument(1);
    Value distance = b.create<SubIOp>(end, start);
    Value dispatch = b.create<CmpIOp>(CmpIPredicate::sgt, distance, c1);
    b.create<scf::ConditionOp>(dispatch, before->getArguments());
  }

  // Set up the async dispatch loop body: recursively call the dispatch
  // function for the second half of the original range and go to the next
  // iteration.
  {
    b.setInsertionPointToEnd(after);
    Value start = after->getArgument(0);
    Value end = after->getArgument(1);
    Value distance = b.create<SubIOp>(end, start);
    Value halfDistance = b.create<SignedDivIOp>(distance, c2);
    Value midIndex = b.create<AddIOp>(start, halfDistance);

    // Recursively call the async dispatch function inside the async.execute
    // region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      // Update the original `blockStart` and `blockEnd` with the new range.
      SmallVector<Value> operands{block->getArguments().begin(),
                                  block->getArguments().end()};
      operands[1] = midIndex;
      operands[2] = end;

      executeBuilder.create<CallOp>(executeLoc, func.sym_name(),
                                    func.getCallableResults(), operands);
      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
    };

    // Create an async.execute operation to dispatch half of the block range.
    auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                       executeBodyBuilder);
    b.create<AddToGroupOp>(indexTy, execute.token(), group);
    b.create<scf::YieldOp>(ValueRange({start, midIndex}));
  }

  // After dispatching async operations to process the tail of the block range,
  // call the parallel compute function for the first block of the range.
  b.setInsertionPointAfter(whileOp);

  // Drop async dispatch specific arguments: async group, block start and end.
  auto forwardedInputs = block->getArguments().drop_front(3);
  SmallVector<Value> computeFuncOperands = {blockStart};
  computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());

  b.create<CallOp>(computeFunc.func.sym_name(),
                   computeFunc.func.getCallableResults(), computeFuncOperands);
  b.create<ReturnOp>(ValueRange());

  return func;
}

// Launch async dispatch of the parallel compute function.
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                            ParallelComputeFunction &parallelComputeFunction,
                            scf::ParallelOp op, Value blockSize,
                            Value blockCount,
                            const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  // Add one more level of indirection to dispatch parallel compute functions
  // using async operations and recursive work splitting.
  FuncOp asyncDispatchFunction =
      createAsyncDispatchFunction(parallelComputeFunction, rewriter);

  Value c0 = b.create<ConstantIndexOp>(0);
  Value c1 = b.create<ConstantIndexOp>(1);

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block will be
  // executed synchronously in the caller thread.
  Value groupSize = b.create<SubIOp>(blockCount, c1);
  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

  // Appends operands shared by the async dispatch and parallel compute
  // functions to the given operands vector.
  auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
    operands.append(tripCounts);
    operands.append(op.lowerBound().begin(), op.lowerBound().end());
    operands.append(op.upperBound().begin(), op.upperBound().end());
    operands.append(op.step().begin(), op.step().end());
    operands.append(parallelComputeFunction.captures);
  };

  // Check if the block count is one, in which case we can skip the async
  // dispatch completely. If this is known statically, then canonicalization
  // will erase the async group operations.
  Value isSingleBlock = b.create<CmpIOp>(CmpIPredicate::eq, blockCount, c1);

  auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder nb(loc, nestedBuilder);

    // Call the parallel compute function for the single block.
    SmallVector<Value> operands = {c0, blockSize};
    appendBlockComputeOperands(operands);

    nb.create<CallOp>(parallelComputeFunction.func.sym_name(),
                      parallelComputeFunction.func.getCallableResults(),
                      operands);
    nb.create<scf::YieldOp>();
  };

  auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder nb(loc, nestedBuilder);

    // Launch the async dispatch function for the [0, blockCount) range.
    SmallVector<Value> operands = {group, c0, blockCount, blockSize};
    appendBlockComputeOperands(operands);

    nb.create<CallOp>(asyncDispatchFunction.sym_name(),
                      asyncDispatchFunction.getCallableResults(), operands);
    nb.create<scf::YieldOp>();
  };

  // Either dispatch the single-block compute function, or launch the async
  // dispatch.
  b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);

  // Wait for the completion of all parallel compute operations.
  b.create<AwaitAllOp>(group);
}

// Dispatches parallel compute functions by submitting all async compute tasks
// from a simple for loop in the caller thread.
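//
// A sketch of the generated code (hypothetical SSA names, syntax approximate):
//
//   %group = async.create_group %group_size : !async.group
//   scf.for %block_index = %c1 to %block_count step %c1 {
//     %token = async.execute {
//       call @parallel_compute_fn(%block_index, %block_size, ...)
//       async.yield
//     }
//     async.add_to_group %token, %group
//   }
//   // First block is executed in the caller thread.
//   call @parallel_compute_fn(%c0, %block_size, ...)
//   async.await_all %group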
static void
doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                     ParallelComputeFunction &parallelComputeFunction,
                     scf::ParallelOp op, Value blockSize, Value blockCount,
                     const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  FuncOp compute = parallelComputeFunction.func;

  Value c0 = b.create<ConstantIndexOp>(0);
  Value c1 = b.create<ConstantIndexOp>(1);

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block will be
  // executed synchronously in the caller thread.
  Value groupSize = b.create<SubIOp>(blockCount, c1);
  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

  // Call the parallel compute function for all blocks.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;

  // Returns parallel compute function operands to process the given block.
  auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
    SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
    computeFuncOperands.append(tripCounts);
    computeFuncOperands.append(op.lowerBound().begin(), op.lowerBound().end());
    computeFuncOperands.append(op.upperBound().begin(), op.upperBound().end());
    computeFuncOperands.append(op.step().begin(), op.step().end());
    computeFuncOperands.append(parallelComputeFunction.captures);
    return computeFuncOperands;
  };

  // The induction variable is the index of the block: [0, blockCount).
  LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
                                    Value iv, ValueRange args) {
    ImplicitLocOpBuilder nb(loc, loopBuilder);

    // Call the parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      executeBuilder.create<CallOp>(executeLoc, compute.sym_name(),
                                    compute.getCallableResults(),
                                    computeFuncOperands(iv));
      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
    };

    // Create an async.execute operation to launch the parallel compute
    // function.
    auto execute = nb.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                        executeBodyBuilder);
    nb.create<AddToGroupOp>(rewriter.getIndexType(), execute.token(), group);
    nb.create<scf::YieldOp>();
  };

  // Iterate over all compute blocks and launch parallel compute operations.
  b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);

  // Call the parallel compute function for the first block in the caller
  // thread.
  b.create<CallOp>(compute.sym_name(), compute.getCallableResults(),
                   computeFuncOperands(c0));

  // Wait for the completion of all async compute operations.
  b.create<AwaitAllOp>(group);
}

LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support the rewrite for parallel ops with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Compute the trip count for each loop induction variable:
  //   tripCount = ceil_div(upperBound - lowerBound, step);
  SmallVector<Value> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.lowerBound()[i];
    auto ub = op.upperBound()[i];
    auto step = op.step()[i];
    auto range = b.create<SubIOp>(ub, lb);
    tripCounts[i] = b.create<SignedCeilDivIOp>(range, step);
  }

  // Compute a product of trip counts to get the 1-dimensional iteration space
  // for the scf.parallel operation.
  Value tripCount = tripCounts[0];
  for (size_t i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);

  // Short-circuit no-op parallel loops (zero iterations) that can arise from
  // memrefs with dynamic dimension(s) equal to zero.
  Value c0 = b.create<ConstantIndexOp>(0);
  Value isZeroIterations = b.create<CmpIOp>(CmpIPredicate::eq, tripCount, c0);

  // Do absolutely nothing if the trip count is zero.
  auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
    nestedBuilder.create<scf::YieldOp>(loc);
  };

  // Compute the parallel block size and dispatch concurrent tasks computing
  // results for each block.
  auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder nb(loc, nestedBuilder);

    // With a large number of threads the value of creating many compute blocks
    // is reduced because the problem typically becomes memory bound. For a
    // small number of threads it helps with stragglers.
    float overshardingFactor = numWorkerThreads <= 4    ? 8.0
                               : numWorkerThreads <= 8  ? 4.0
                               : numWorkerThreads <= 16 ? 2.0
                               : numWorkerThreads <= 32 ? 1.0
                               : numWorkerThreads <= 64 ? 0.8
                                                        : 0.6;

    // Do not overload worker threads with too many compute blocks.
    Value maxComputeBlocks = b.create<ConstantIndexOp>(
        std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));

    // Target block size from the pass parameters.
    Value targetComputeBlock = b.create<ConstantIndexOp>(targetBlockSize);

    // Compute the parallel block size from the parallel problem size:
    //   blockSize = min(tripCount,
    //                   max(ceil_div(tripCount, maxComputeBlocks),
    //                       targetComputeBlock))
    Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
    Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlock);
    Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlock);
    Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
    Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
    Value blockCount0 = b.create<SignedCeilDivIOp>(tripCount, blockSize0);

    // Compute a balanced block size for the estimated block count.
    Value blockSize = b.create<SignedCeilDivIOp>(tripCount, blockCount0);
    Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);
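
    // As a worked example (hypothetical values): with tripCount = 100,
    // maxComputeBlocks = 8 and targetComputeBlock = 16, we get
    // bs0 = ceil_div(100, 8) = 13, bs2 = max(13, 16) = 16, and
    // blockSize0 = min(100, 16) = 16, so blockCount0 = ceil_div(100, 16) = 7.
    // Rebalancing then gives blockSize = ceil_div(100, 7) = 15 and
    // blockCount = ceil_div(100, 15) = 7, spreading the tail iterations more
    // evenly across the blocks.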

    // Create a parallel compute function that takes a block id and computes
    // the parallel operation body for a subset of the iteration space.
    ParallelComputeFunction parallelComputeFunction =
        createParallelComputeFunction(op, rewriter);

    // Dispatch the parallel compute function using async recursive work
    // splitting, or by submitting compute tasks sequentially from the caller
    // thread.
    if (asyncDispatch) {
      doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
                      blockCount, tripCounts);
    } else {
      doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
                           blockCount, tripCounts);
    }

    nb.create<scf::YieldOp>();
  };

  // Replace the `scf.parallel` operation with the parallel compute function.
  b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch);

  // The parallel operation was replaced with a block iteration loop.
  rewriter.eraseOp(op);

  return success();
}

void AsyncParallelForPass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        targetBlockSize);

  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
  return std::make_unique<AsyncParallelForPass>();
}

std::unique_ptr<Pass>
mlir::createAsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
                                 int32_t targetBlockSize) {
  return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
                                                targetBlockSize);
}