//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop transformation routines.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/LoopUtils.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "LoopUtils"

using namespace mlir;
using llvm::SetVector;
using llvm::SmallMapVector;

/// Computes the cleanup loop lower bound of the loop being unrolled with
/// the specified unroll factor; this bound will also be the upper bound of the
/// main part of the unrolled loop. Computes the bound as an AffineMap with its
/// operands or a null map when the trip count can't be expressed as an affine
/// expression.
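///
/// For illustration (a hand-worked sketch of the arithmetic, not IR emitted by
/// this routine): unrolling
///   affine.for %i = 0 to 102 step 1
/// by a factor of 4 has a trip count of 102, so the cleanup lower bound is
/// lb + (102 - 102 mod 4) * step = 100; the main unrolled loop then covers
/// [0, 100) and the cleanup loop covers [100, 102).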
void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
                                    AffineMap *map,
                                    SmallVectorImpl<Value> *operands,
                                    OpBuilder &b) {
  auto lbMap = forOp.getLowerBoundMap();

  // Single result lower bound map only.
  if (lbMap.getNumResults() != 1) {
    *map = AffineMap();
    return;
  }

  AffineMap tripCountMap;
  SmallVector<Value, 4> tripCountOperands;
  buildTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands);

  // Sometimes the trip count cannot be expressed as an affine expression.
  if (!tripCountMap) {
    *map = AffineMap();
    return;
  }

  unsigned step = forOp.getStep();
  auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap,
                                    forOp.getLowerBoundOperands());

  // For each upper bound expr, get the range.
  // Eg: affine.for %i = lb to min (ub1, ub2),
  // where tripCountExprs yield (tr1, tr2), we create affine.apply's:
  // lb + tr1 - tr1 % ufactor, lb + tr2 - tr2 % ufactor; the results of all
  // these affine.apply's make up the cleanup loop lower bound.
  SmallVector<AffineExpr, 4> bumpExprs(tripCountMap.getNumResults());
  SmallVector<Value, 4> bumpValues(tripCountMap.getNumResults());
  for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) {
    auto tripCountExpr = tripCountMap.getResult(i);
    bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step;
    auto bumpMap = AffineMap::get(tripCountMap.getNumDims(),
                                  tripCountMap.getNumSymbols(), bumpExprs[i]);
    bumpValues[i] =
        b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands);
  }

  SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults());
  for (unsigned i = 0, e = bumpExprs.size(); i < e; i++)
    newUbExprs[i] = b.getAffineDimExpr(0) + b.getAffineDimExpr(i + 1);

  operands->clear();
  operands->push_back(lb);
  operands->append(bumpValues.begin(), bumpValues.end());
  *map = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs);
  // Simplify the map + operands.
  fullyComposeAffineMapAndOperands(map, operands);
  *map = simplifyAffineMap(*map);
  canonicalizeMapAndOperands(map, operands);
  // Remove any affine.apply's that became dead from the simplification above.
  for (auto v : bumpValues) {
    if (v.use_empty())
      v.getDefiningOp()->erase();
  }
  if (lb.use_empty())
    lb.erase();
}


/// Promotes the loop body of a forOp to its containing block if the forOp
/// is known to have a single iteration.
// TODO(bondhugula): extend this for arbitrary affine bounds.
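//
// For illustration (a hand-written sketch; "foo" is a placeholder op): a loop
// such as
//   affine.for %i = 0 to 1 {
//     "foo"(%i) : (index) -> ()
//   }
// has exactly one iteration, so its body is promoted into the enclosing block
// and uses of %i are replaced by the constant lower bound 0.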
LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
  Optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount.hasValue() || tripCount.getValue() != 1)
    return failure();

  // TODO(mlir-team): there is no builder for a max.
  if (forOp.getLowerBoundMap().getNumResults() != 1)
    return failure();

  // Replace all uses of the IV with its single-iteration value.
  auto iv = forOp.getInductionVar();
  Operation *op = forOp.getOperation();
  if (!iv.use_empty()) {
    if (forOp.hasConstantLowerBound()) {
      OpBuilder topBuilder(op->getParentOfType<FuncOp>().getBody());
      auto constOp = topBuilder.create<ConstantIndexOp>(
          forOp.getLoc(), forOp.getConstantLowerBound());
      iv.replaceAllUsesWith(constOp);
    } else {
      AffineBound lb = forOp.getLowerBound();
      SmallVector<Value, 4> lbOperands(lb.operand_begin(), lb.operand_end());
      OpBuilder builder(op->getBlock(), Block::iterator(op));
      if (lb.getMap() == builder.getDimIdentityMap()) {
        // No need of generating an affine.apply.
        iv.replaceAllUsesWith(lbOperands[0]);
      } else {
        auto affineApplyOp = builder.create<AffineApplyOp>(
            op->getLoc(), lb.getMap(), lbOperands);
        iv.replaceAllUsesWith(affineApplyOp);
      }
    }
  }
  // Move the loop body operations, except for terminator, to the loop's
  // containing block.
  auto *block = op->getBlock();
  forOp.getBody()->getOperations().back().erase();
  block->getOperations().splice(Block::iterator(op),
                                forOp.getBody()->getOperations());
  forOp.erase();
  return success();
}

/// Promotes all single-iteration 'affine.for' ops in the FuncOp, i.e., moves
/// their body into the containing Block.
void mlir::promoteSingleIterationLoops(FuncOp f) {
  // Walk all 'affine.for' ops and promote any with a single iteration.
  f.walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); });
}

/// Generates an 'affine.for' op with the specified lower and upper bounds
/// while generating the right IV remappings for the shifted operations. The
/// operation blocks that go into the loop are specified in instGroupQueue
/// starting from the specified offset, and in that order; the first element of
/// the pair specifies the shift applied to that group of operations; note
/// that the shift is multiplied by the loop step before being applied. Returns
/// a null AffineForOp if the generated loop simplifies to a single-iteration
/// loop.
static AffineForOp
generateLoop(AffineMap lbMap, AffineMap ubMap,
             const std::vector<std::pair<uint64_t, ArrayRef<Operation *>>>
                 &instGroupQueue,
             unsigned offset, AffineForOp srcForInst, OpBuilder b) {
  SmallVector<Value, 4> lbOperands(srcForInst.getLowerBoundOperands());
  SmallVector<Value, 4> ubOperands(srcForInst.getUpperBoundOperands());

  assert(lbMap.getNumInputs() == lbOperands.size());
  assert(ubMap.getNumInputs() == ubOperands.size());

  auto loopChunk =
      b.create<AffineForOp>(srcForInst.getLoc(), lbOperands, lbMap, ubOperands,
                            ubMap, srcForInst.getStep());
  auto loopChunkIV = loopChunk.getInductionVar();
  auto srcIV = srcForInst.getInductionVar();

  BlockAndValueMapping operandMap;

  OpBuilder bodyBuilder = loopChunk.getBodyBuilder();
  for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end();
       it != e; ++it) {
    uint64_t shift = it->first;
    auto insts = it->second;
    // All 'same shift' operations get added with their operands being
    // remapped to results of cloned operations, and any use of their IV
    // remapped. Generate the remapping if the shift is not zero:
    // remappedIV = newIV - shift.
    if (!srcIV.use_empty() && shift != 0) {
      auto ivRemap = bodyBuilder.create<AffineApplyOp>(
          srcForInst.getLoc(),
          bodyBuilder.getSingleDimShiftAffineMap(
              -static_cast<int64_t>(srcForInst.getStep() * shift)),
          loopChunkIV);
      operandMap.map(srcIV, ivRemap);
    } else {
      operandMap.map(srcIV, loopChunkIV);
    }
    for (auto *op : insts) {
      if (!isa<AffineTerminatorOp>(op))
        bodyBuilder.clone(*op, operandMap);
    }
  }
  if (succeeded(promoteIfSingleIteration(loopChunk)))
    return AffineForOp();
  return loopChunk;
}


/// Skew the operations in the body of an 'affine.for' operation with the
/// specified operation-wise shifts. The shifts are with respect to the
/// original execution order, and are multiplied by the loop 'step' before being
/// applied. A shift of zero for each operation will lead to no change.
// The skewing of operations with respect to one another can be used for
// example to allow overlap of asynchronous operations (such as DMA
// communication) with computation, or just relative shifting of operations
// for better register reuse, locality or parallelism. As such, the shifts are
// typically expected to be at most of the order of the number of operations.
// This method should not be used as a substitute for loop distribution/fission.
// This method uses an algorithm that runs in time linear in the number of
// operations in the body of the for loop (using the 'sweep line' paradigm).
// This method asserts preservation of SSA dominance; the check for that, as
// well as for memory-based dependence preservation, rests with the users of
// this method.
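//
// For illustration (an informal sketch; "dma.start" and "dma.wait" are
// placeholder ops and the shifts are assumed to be {0, 1}): skewing
//   affine.for %i = 0 to 8 {
//     "dma.start"(%i) : (index) -> ()
//     "dma.wait"(%i) : (index) -> ()
//   }
// yields a prologue that runs only "dma.start" for %i = 0, a steady-state loop
// over [1, 8) that runs "dma.start" for %i together with "dma.wait" for
// %i - 1, and an epilogue that runs the final "dma.wait"; single-iteration
// pieces are promoted into straight-line code.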
LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
                                 bool unrollPrologueEpilogue) {
  if (forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
    return success();

  // If the trip counts aren't constant, we would need versioning and
  // conditional guards (or context information to prevent such versioning). The
  // better way to pipeline for such loops is to first tile them and extract
  // constant trip count "full tiles" before applying this.
  auto mayBeConstTripCount = getConstantTripCount(forOp);
  if (!mayBeConstTripCount.hasValue()) {
    LLVM_DEBUG(forOp.emitRemark("non-constant trip count loop not handled"));
    return success();
  }
  uint64_t tripCount = mayBeConstTripCount.getValue();

  assert(isInstwiseShiftValid(forOp, shifts) &&
         "shifts will lead to an invalid transformation\n");

  int64_t step = forOp.getStep();

  unsigned numChildInsts = forOp.getBody()->getOperations().size();

  // Do a linear time (counting) sort for the shifts.
  uint64_t maxShift = 0;
  for (unsigned i = 0; i < numChildInsts; i++) {
    maxShift = std::max(maxShift, shifts[i]);
  }
  // Such large shifts are not the typical use case.
  if (maxShift >= numChildInsts) {
    forOp.emitWarning("not shifting because shifts are unrealistically large");
    return success();
  }

  // An array of operation groups sorted by shift amount; each group has all
  // operations with the same shift in the order in which they appear in the
  // body of the 'affine.for' op.
  std::vector<std::vector<Operation *>> sortedInstGroups(maxShift + 1);
  unsigned pos = 0;
  for (auto &op : *forOp.getBody()) {
    auto shift = shifts[pos++];
    sortedInstGroups[shift].push_back(&op);
  }

  // Unless the shifts have a specific pattern (which actually would be the
  // common use case), prologue and epilogue are not meaningfully defined.
  // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
  // loop generated as the prologue and the last as epilogue and unroll these
  // fully.
  AffineForOp prologue;
  AffineForOp epilogue;

  // Do a sweep over the sorted shifts while storing open groups in a
  // vector, and generating loop portions as necessary during the sweep. A block
  // of operations is paired with its shift.
  std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> instGroupQueue;

  auto origLbMap = forOp.getLowerBoundMap();
  uint64_t lbShift = 0;
  OpBuilder b(forOp.getOperation());
  for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) {
    // If nothing is shifted by d, continue.
    if (sortedInstGroups[d].empty())
      continue;
    if (!instGroupQueue.empty()) {
      assert(d >= 1 &&
             "Queue expected to be empty when the first block is found");
      // The interval for which the loop needs to be generated here is:
      // [lbShift, min(lbShift + tripCount, d)) and the body of the
      // loop needs to have all operations in instQueue in that order.
      AffineForOp res;
      if (lbShift + tripCount * step < d * step) {
        res = generateLoop(
            b.getShiftedAffineMap(origLbMap, lbShift),
            b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
            instGroupQueue, 0, forOp, b);
        // Entire loop for the queued op groups generated, empty it.
        instGroupQueue.clear();
        lbShift += tripCount * step;
      } else {
        res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
                           b.getShiftedAffineMap(origLbMap, d), instGroupQueue,
                           0, forOp, b);
        lbShift = d * step;
      }
      if (!prologue && res)
        prologue = res;
      epilogue = res;
    } else {
      // Start of first interval.
      lbShift = d * step;
    }
    // Augment the list of operations that get into the current open interval.
    instGroupQueue.push_back({d, sortedInstGroups[d]});
  }

  // Those operation groups left in the queue now need to be processed (FIFO)
  // and their loops completed.
  for (unsigned i = 0, e = instGroupQueue.size(); i < e; ++i) {
    uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step;
    epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
                            b.getShiftedAffineMap(origLbMap, ubShift),
                            instGroupQueue, i, forOp, b);
    lbShift = ubShift;
    if (!prologue)
      prologue = epilogue;
  }

  // Erase the original for op.
  forOp.erase();

  if (unrollPrologueEpilogue && prologue)
    loopUnrollFull(prologue);
  if (unrollPrologueEpilogue && epilogue &&
      epilogue.getOperation() != prologue.getOperation())
    loopUnrollFull(epilogue);

  return success();
}

// Collect perfectly nested loops starting from `rootForOp`.  Loops are
// perfectly nested if each loop is the first and only non-terminator operation
// in the parent loop.  Collect at most `maxLoops` loops and append them to
// `forOps`.
template <typename T>
static void getPerfectlyNestedLoopsImpl(
    SmallVectorImpl<T> &forOps, T rootForOp,
    unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
  for (unsigned i = 0; i < maxLoops; ++i) {
    forOps.push_back(rootForOp);
    Block &body = rootForOp.region().front();
    if (body.begin() != std::prev(body.end(), 2))
      return;

    rootForOp = dyn_cast<T>(&body.front());
    if (!rootForOp)
      return;
  }
}

/// Get the perfectly nested sequence of loops starting at the root of the loop
/// nest. A loop is perfectly nested iff the first op in its body is another
/// AffineForOp and the second op is a terminator.
void mlir::getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
                                   AffineForOp root) {
  getPerfectlyNestedLoopsImpl(nestedLoops, root);
}

void mlir::getPerfectlyNestedLoops(SmallVectorImpl<loop::ForOp> &nestedLoops,
                                   loop::ForOp root) {
  getPerfectlyNestedLoopsImpl(nestedLoops, root);
}

/// Unrolls this loop completely.
LogicalResult mlir::loopUnrollFull(AffineForOp forOp) {
  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
  if (mayBeConstantTripCount.hasValue()) {
    uint64_t tripCount = mayBeConstantTripCount.getValue();
    if (tripCount == 1) {
      return promoteIfSingleIteration(forOp);
    }
    return loopUnrollByFactor(forOp, tripCount);
  }
  return failure();
}

/// Unrolls this loop by the specified factor or by the trip count (if
/// constant), whichever is lower.
LogicalResult mlir::loopUnrollUpToFactor(AffineForOp forOp,
                                         uint64_t unrollFactor) {
  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);

  if (mayBeConstantTripCount.hasValue() &&
      mayBeConstantTripCount.getValue() < unrollFactor)
    return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue());
  return loopUnrollByFactor(forOp, unrollFactor);
}

/// Unrolls this loop by the specified factor. Returns success if the loop
/// is successfully unrolled.
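///
/// For illustration (a hand-written sketch, not IR captured from this routine;
/// "op" is a placeholder and #map stands for (d0) -> (d0 + 1)): unrolling
///   affine.for %i = 0 to 8 {
///     "op"(%i) : (index) -> ()
///   }
/// by a factor of 2 scales the step to 2 and clones the body once per extra
/// copy:
///   affine.for %i = 0 to 8 step 2 {
///     "op"(%i) : (index) -> ()
///     %i1 = affine.apply #map(%i)
///     "op"(%i1) : (index) -> ()
///   }
/// A cleanup loop is generated only when the trip count is not known to be a
/// multiple of the unroll factor (here 8 is a multiple of 2, so none is
/// needed).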
LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp,
                                       uint64_t unrollFactor) {
  assert(unrollFactor >= 1 && "unroll factor should be >= 1");

  if (unrollFactor == 1)
    return promoteIfSingleIteration(forOp);

  if (forOp.getBody()->empty() ||
      forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
    return failure();

  // Loops where the lower bound is a max expression aren't supported for
  // unrolling since the trip count can't be expressed as an affine function
  // when both the lower bound and the upper bound are multi-result maps.
  // However, one meaningful way to do such unrolling would be to specialize
  // the loop for the 'hotspot' case and unroll that hotspot.
  if (forOp.getLowerBoundMap().getNumResults() != 1)
    return failure();

  // If the trip count is lower than the unroll factor, there would be no
  // unrolled body to create; bail out.
  // TODO(bondhugula): option to specify cleanup loop unrolling.
  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
  if (mayBeConstantTripCount.hasValue() &&
      mayBeConstantTripCount.getValue() < unrollFactor)
    return failure();

  // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
  Operation *op = forOp.getOperation();
  if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
    OpBuilder builder(op->getBlock(), ++Block::iterator(op));
    auto cleanupForInst = cast<AffineForOp>(builder.clone(*op));
    AffineMap cleanupMap;
    SmallVector<Value, 4> cleanupOperands;
    getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands,
                             builder);
    assert(cleanupMap &&
           "cleanup loop lower bound map for single result lower bound maps "
           "can always be determined");
    cleanupForInst.setLowerBound(cleanupOperands, cleanupMap);
    // Promote the loop body up if this has turned into a single iteration loop.
    promoteIfSingleIteration(cleanupForInst);

    // Adjust upper bound of the original loop; this is the same as the lower
    // bound of the cleanup loop.
    forOp.setUpperBound(cleanupOperands, cleanupMap);
  }

  // Scale the step of loop being unrolled by unroll factor.
  int64_t step = forOp.getStep();
  forOp.setStep(step * unrollFactor);

  // Builder to insert unrolled bodies just before the terminator of the body of
  // 'forOp'.
  OpBuilder builder = forOp.getBodyBuilder();

  // Keep a pointer to the last non-terminator operation in the original block
  // so that we know what to clone (since we are doing this in-place).
  Block::iterator srcBlockEnd = std::prev(forOp.getBody()->end(), 2);

  // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies).
  auto forOpIV = forOp.getInductionVar();
  for (unsigned i = 1; i < unrollFactor; i++) {
    BlockAndValueMapping operandMap;

    // If the induction variable is used, create a remapping to the value for
    // this unrolled instance.
    if (!forOpIV.use_empty()) {
      // iv' = iv + 1/2/3...unrollFactor-1;
      auto d0 = builder.getAffineDimExpr(0);
      auto bumpMap = AffineMap::get(1, 0, {d0 + i * step});
      auto ivUnroll =
          builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
      operandMap.map(forOpIV, ivUnroll);
    }

    // Clone the original body of 'forOp'.
    for (auto it = forOp.getBody()->begin(); it != std::next(srcBlockEnd);
         it++) {
      builder.clone(*it, operandMap);
    }
  }

  // Promote the loop body up if this has turned into a single iteration loop.
  promoteIfSingleIteration(forOp);
  return success();
}

/// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is
/// nested within 'forOpA' as the only non-terminator operation in its block.
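///
/// For illustration (a hand-written sketch, not IR captured from this
/// function; "work" is a placeholder op): given
///   affine.for %i ... {
///     affine.for %j ... {
///       "work"(%i, %j) : (index, index) -> ()
///     }
///   }
/// the interchange splices the blocks so that the %j loop becomes outermost
/// and the %i loop sits immediately inside it; the body is moved, not cloned,
/// so the induction variables keep their existing uses.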
void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) {
  auto *forOpAInst = forOpA.getOperation();

  assert(&*forOpA.getBody()->begin() == forOpB.getOperation());
  auto &forOpABody = forOpA.getBody()->getOperations();
  auto &forOpBBody = forOpB.getBody()->getOperations();

  // 1) Splice forOpA's non-terminator operations (which is just forOpB) just
  // before forOpA (in forOpA's parent's block); this leaves 'forOpA's body
  // containing only the terminator.
  forOpAInst->getBlock()->getOperations().splice(Block::iterator(forOpAInst),
                                                 forOpABody, forOpABody.begin(),
                                                 std::prev(forOpABody.end()));
  // 2) Splice forOpB's non-terminator operations into the beginning of forOpA's
  // body (this leaves forOpB's body containing only the terminator).
  forOpABody.splice(forOpABody.begin(), forOpBBody, forOpBBody.begin(),
                    std::prev(forOpBBody.end()));
  // 3) Splice forOpA into the beginning of forOpB's body.
  forOpBBody.splice(forOpBBody.begin(), forOpAInst->getBlock()->getOperations(),
                    Block::iterator(forOpAInst));
}

// Checks each dependence component against the permutation to see if the
// desired loop interchange would violate dependences by making the
// dependence component lexicographically negative.
static bool checkLoopInterchangeDependences(
    const std::vector<SmallVector<DependenceComponent, 2>> &depCompsVec,
    ArrayRef<AffineForOp> loops, ArrayRef<unsigned> loopPermMap) {
  // Invert permutation map.
  unsigned maxLoopDepth = loops.size();
  SmallVector<unsigned, 4> loopPermMapInv;
  loopPermMapInv.resize(maxLoopDepth);
  for (unsigned i = 0; i < maxLoopDepth; ++i)
    loopPermMapInv[loopPermMap[i]] = i;

  // Check each dependence component against the permutation to see if the
  // desired loop interchange permutation would make the dependence vectors
  // lexicographically negative.
  // Example 1: [-1, 1][0, 0]
  // Example 2: [0, 0][-1, 1]
  for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
    const SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
    assert(depComps.size() >= maxLoopDepth);
    // Check if the first non-zero dependence component is positive.
    // This iterates through loops in the desired order.
    for (unsigned j = 0; j < maxLoopDepth; ++j) {
      unsigned permIndex = loopPermMapInv[j];
      assert(depComps[permIndex].lb.hasValue());
      int64_t depCompLb = depComps[permIndex].lb.getValue();
      if (depCompLb > 0)
        break;
      if (depCompLb < 0)
        return false;
    }
  }
  return true;
}

/// Checks if the loop interchange permutation 'loopPermMap' of the perfectly
/// nested sequence of loops in 'loops' would violate dependences.
bool mlir::isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
                                             ArrayRef<unsigned> loopPermMap) {
  // Gather dependence components for dependences between all ops in loop nest
  // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
  assert(loopPermMap.size() == loops.size());
  unsigned maxLoopDepth = loops.size();
  std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
  getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
  return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap);
}

/// Performs a sequence of loop interchanges on the perfectly nested sequence
/// of loops in 'loops', as specified by the permutation 'loopPermMap'. Returns
/// the index into 'loops' of the loop that becomes the new outermost loop.
unsigned mlir::interchangeLoops(ArrayRef<AffineForOp> loops,
                                ArrayRef<unsigned> loopPermMap) {
  Optional<unsigned> loopNestRootIndex;
  for (int i = loops.size() - 1; i >= 0; --i) {
    int permIndex = static_cast<int>(loopPermMap[i]);
    // Store the index of the for loop which will be the new loop nest root.
    if (permIndex == 0)
      loopNestRootIndex = i;
    if (permIndex > i) {
      // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest.
      sinkLoop(loops[i], permIndex - i);
    }
  }
  assert(loopNestRootIndex.hasValue());
  return loopNestRootIndex.getValue();
}

// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
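//
// For illustration (an informal sketch, assuming the dependence analysis finds
// a carried dependence only on the outer loop): for a two-deep nest where %i
// is sequential and %j is parallel, the permutation computed below makes %j
// the outermost loop and sinks %i inside it, provided
// checkLoopInterchangeDependences accepts the permutation.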
AffineForOp mlir::sinkSequentialLoops(AffineForOp forOp) {
  SmallVector<AffineForOp, 4> loops;
  getPerfectlyNestedLoops(loops, forOp);
  if (loops.size() < 2)
    return forOp;

  // Gather dependence components for dependences between all ops in loop nest
  // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
  unsigned maxLoopDepth = loops.size();
  std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
  getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);

  // Mark loops as either parallel or sequential.
  SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true);
  for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
    SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
    assert(depComps.size() >= maxLoopDepth);
    for (unsigned j = 0; j < maxLoopDepth; ++j) {
      DependenceComponent &depComp = depComps[j];
      assert(depComp.lb.hasValue() && depComp.ub.hasValue());
      if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0)
        isParallelLoop[j] = false;
    }
  }

  // Count the number of parallel loops.
  unsigned numParallelLoops = 0;
  for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i)
    if (isParallelLoop[i])
      ++numParallelLoops;

  // Compute permutation of loops that sinks sequential loops (and thus raises
  // parallel loops) while preserving relative order.
  SmallVector<unsigned, 4> loopPermMap(maxLoopDepth);
  unsigned nextSequentialLoop = numParallelLoops;
  unsigned nextParallelLoop = 0;
  for (unsigned i = 0; i < maxLoopDepth; ++i) {
    if (isParallelLoop[i]) {
      loopPermMap[i] = nextParallelLoop++;
    } else {
      loopPermMap[i] = nextSequentialLoop++;
    }
  }

  // Check if permutation 'loopPermMap' would violate dependences.
  if (!checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap))
    return forOp;
  // Perform loop interchange according to permutation 'loopPermMap'.
  unsigned loopNestRootIndex = interchangeLoops(loops, loopPermMap);
  return loops[loopNestRootIndex];
}

/// Performs a series of loop interchanges to sink 'forOp' 'loopDepth' levels
/// deeper in the loop nest.
void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) {
  for (unsigned i = 0; i < loopDepth; ++i) {
    AffineForOp nextForOp = cast<AffineForOp>(forOp.getBody()->front());
    interchangeLoops(forOp, nextForOp);
  }
}

// Factors out common behavior to add a new `iv` (resp. `iv` + `offset`) to the
// lower (resp. upper) loop bound. When called for both the lower and upper
// bounds, the resulting IR resembles:
//
// ```mlir
//    affine.for %i = max (`iv`, ...) to min (`iv` + `offset`) {
//      ...
//    }
// ```
static void augmentMapAndBounds(OpBuilder &b, Value iv, AffineMap *map,
                                SmallVector<Value, 4> *operands,
                                int64_t offset = 0) {
  auto bounds = llvm::to_vector<4>(map->getResults());
  bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset);
  operands->insert(operands->begin() + map->getNumDims(), iv);
  *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds);
  canonicalizeMapAndOperands(map, operands);
}

// Stripmines `forOp` by `factor` and sinks it under each of the `targets`.
// Stripmine-sink is a primitive building block for generalized tiling of
// imperfectly nested loops.
// This transformation is purely mechanical and does not check legality,
// profitability or even structural correctness. It is the user's
// responsibility to specify `targets` that are dominated by `forOp`.
// Returns the new AffineForOps, one per `targets`, nested immediately under
// each of the `targets`.
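//
// For illustration (an informal sketch, not IR produced verbatim by this
// function): strip-mining
//   affine.for %i = 0 to 128 step 1
// by a factor of 8 rescales %i to step 8 and inserts, just before the
// terminator of each target, a new loop running from max(..., %i) to
// min(..., %i + 8) with the original step; the target's remaining body is
// spliced into that new loop and its uses of %i are rewired to the new IV.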
static SmallVector<AffineForOp, 8>
stripmineSink(AffineForOp forOp, uint64_t factor,
              ArrayRef<AffineForOp> targets) {
  auto originalStep = forOp.getStep();
  auto scaledStep = originalStep * factor;
  forOp.setStep(scaledStep);

  auto *op = forOp.getOperation();
  OpBuilder b(op->getBlock(), ++Block::iterator(op));

  // Lower-bound map creation.
  auto lbMap = forOp.getLowerBoundMap();
  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
  augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands);

  // Upper-bound map creation.
  auto ubMap = forOp.getUpperBoundMap();
  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
  augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands,
                      /*offset=*/scaledStep);

  auto iv = forOp.getInductionVar();
  SmallVector<AffineForOp, 8> innerLoops;
  for (auto t : targets) {
    // Insert newForOp before the terminator of `t`.
    OpBuilder b = t.getBodyBuilder();
    auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap,
                                          ubOperands, ubMap, originalStep);
    auto begin = t.getBody()->begin();
    // Skip terminator and `newForOp` which is just before the terminator.
    auto nOps = t.getBody()->getOperations().size() - 2;
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps));
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.region());
    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}

static Loops stripmineSink(loop::ForOp forOp, Value factor,
                           ArrayRef<loop::ForOp> targets) {
  auto originalStep = forOp.step();
  auto iv = forOp.getInductionVar();

  OpBuilder b(forOp);
  forOp.setStep(b.create<MulIOp>(forOp.getLoc(), originalStep, factor));

  Loops innerLoops;
  for (auto t : targets) {
    // Save information for splicing ops out of t when done
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`.
    OpBuilder b(t.getBodyBuilder());
    Value stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
    Value less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
                                  forOp.upperBound(), stepped);
    Value ub =
        b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    auto newForOp = b.create<loop::ForOp>(t.getLoc(), iv, ub, originalStep);
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.region());

    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}

// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
// Returns the new AffineForOps, nested immediately under `target`.
template <typename ForType, typename SizeType>
static ForType stripmineSink(ForType forOp, SizeType factor, ForType target) {
  // TODO(ntv): Use cheap structural assertions that targets are nested under
  // forOp and that targets are not nested under each other when DominanceInfo
  // exposes the capability. It seems overkill to construct a whole function
  // dominance tree at this point.
  auto res = stripmineSink(forOp, factor, ArrayRef<ForType>{target});
  assert(res.size() == 1 && "Expected 1 inner forOp");
  return res[0];
}

template <typename ForType, typename SizeType>
static SmallVector<SmallVector<ForType, 8>, 8>
tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes,
         ArrayRef<ForType> targets) {
  SmallVector<SmallVector<ForType, 8>, 8> res;
  SmallVector<ForType, 8> currentTargets(targets.begin(), targets.end());
  for (auto it : llvm::zip(forOps, sizes)) {
    auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
    res.push_back(step);
    currentTargets = step;
  }
  return res;
}

SmallVector<SmallVector<AffineForOp, 8>, 8>
mlir::tile(ArrayRef<AffineForOp> forOps, ArrayRef<uint64_t> sizes,
           ArrayRef<AffineForOp> targets) {
  return tileImpl(forOps, sizes, targets);
}

SmallVector<Loops, 8> mlir::tile(ArrayRef<loop::ForOp> forOps,
                                 ArrayRef<Value> sizes,
                                 ArrayRef<loop::ForOp> targets) {
  return tileImpl(forOps, sizes, targets);
}

template <typename ForType, typename SizeType>
static SmallVector<ForType, 8>
tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes, ForType target) {
  SmallVector<ForType, 8> res;
  for (auto loops : tile(forOps, sizes, ArrayRef<ForType>{target})) {
    assert(loops.size() == 1);
    res.push_back(loops[0]);
  }
  return res;
}

SmallVector<AffineForOp, 8> mlir::tile(ArrayRef<AffineForOp> forOps,
                                       ArrayRef<uint64_t> sizes,
                                       AffineForOp target) {
  return tileImpl(forOps, sizes, target);
}

Loops mlir::tile(ArrayRef<loop::ForOp> forOps, ArrayRef<Value> sizes,
                 loop::ForOp target) {
  return tileImpl(forOps, sizes, target);
}

Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef<Value> sizes) {
  // Collect perfectly nested loops.  If more size values provided than nested
  // loops available, truncate `sizes`.
  SmallVector<loop::ForOp, 4> forOps;
  forOps.reserve(sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());

  return ::tile(forOps, sizes, forOps.back());
}

// Build the IR that performs ceil division of a positive value by a constant:
//    ceildiv(a, B) = divis(a + (B-1), B)
// where divis is rounding-to-zero division.
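//
// For example, ceildiv(10, 4) is computed as divis(10 + 3, 4) = divis(13, 4)
// = 3, which matches ceil(10 / 4).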
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             int64_t divisor) {
  assert(divisor > 0 && "expected positive divisor");
  assert(dividend.getType().isIndex() && "expected index-typed value");

  Value divisorMinusOneCst = builder.create<ConstantIndexOp>(loc, divisor - 1);
  Value divisorCst = builder.create<ConstantIndexOp>(loc, divisor);
  Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOneCst);
  return builder.create<SignedDivIOp>(loc, sum, divisorCst);
}

// Build the IR that performs ceil division of a positive value by another
// positive value:
//    ceildiv(a, b) = divis(a + (b - 1), b)
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             Value divisor) {
  assert(dividend.getType().isIndex() && "expected index-typed value");

  Value cstOne = builder.create<ConstantIndexOp>(loc, 1);
  Value divisorMinusOne = builder.create<SubIOp>(loc, divisor, cstOne);
  Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOne);
  return builder.create<SignedDivIOp>(loc, sum, divisor);
}

// Hoist the ops within `outer` that appear before `inner`.
// Such ops include the ops that have been introduced by parametric tiling.
// Ops that come from triangular loops (i.e. that belong to the program slice
// rooted at `outer`) and ops that have side effects cannot be hoisted.
// Return failure when any op fails to hoist.
static LogicalResult hoistOpsBetween(loop::ForOp outer, loop::ForOp inner) {
  SetVector<Operation *> forwardSlice;
  getForwardSlice(outer.getOperation(), &forwardSlice, [&inner](Operation *op) {
    return op != inner.getOperation();
  });
  LogicalResult status = success();
  SmallVector<Operation *, 8> toHoist;
  for (auto &op : outer.getBody()->getOperations()) {
    // Stop when encountering the inner loop.
    if (&op == inner.getOperation())
      break;
    // Skip over non-hoistable ops.
    if (forwardSlice.count(&op) > 0) {
      status = failure();
      continue;
    }
    // Skip loop::ForOp, these are not considered a failure.
    if (isa<loop::ForOp>(op))
      continue;
    // Skip other ops with regions.
    if (op.getNumRegions() > 0) {
      status = failure();
      continue;
    }
    // Skip if op has side effects.
    // TODO(ntv): loads to immutable memory regions are ok.
    if (!op.hasNoSideEffect()) {
      status = failure();
      continue;
    }
    toHoist.push_back(&op);
  }
  auto *outerForOp = outer.getOperation();
  for (auto *op : toHoist)
    op->moveBefore(outerForOp);
  return status;
}

// Traverse the interTile and intraTile loops and try to hoist ops such that
// bands of perfectly nested loops are isolated.
// Return failure if either perfect interTile or perfect intraTile bands cannot
// be formed.
static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
  LogicalResult status = success();
  auto &interTile = tileLoops.first;
  auto &intraTile = tileLoops.second;
  auto size = interTile.size();
  assert(size == intraTile.size());
  if (size <= 1)
    return success();
  for (unsigned s = 1; s < size; ++s)
    status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
                               : failure();
  for (unsigned s = 1; s < size; ++s)
    status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
                               : failure();
  return status;
}

TileLoops mlir::extractFixedOuterLoops(loop::ForOp rootForOp,
                                       ArrayRef<int64_t> sizes) {
  // Collect perfectly nested loops.  If more size values provided than nested
  // loops available, truncate `sizes`.
  SmallVector<loop::ForOp, 4> forOps;
  forOps.reserve(sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());

  // Compute the tile sizes such that the i-th outer loop executes size[i]
  // iterations.  Given that the loop currently executes
  //   numIterations = ceildiv((upperBound - lowerBound), step)
  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
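  //
  // For example (an illustrative computation, not code): a loop running 50
  // iterations with sizes[i] == 4 is tiled with size ceildiv(50, 4) == 13,
  // so the resulting outer loop executes ceildiv(50, 13) == 4 iterations, as
  // requested.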
  SmallVector<Value, 4> tileSizes;
  tileSizes.reserve(sizes.size());
  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
    assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");

    auto forOp = forOps[i];
    OpBuilder builder(forOp);
    auto loc = forOp.getLoc();
    Value diff =
        builder.create<SubIOp>(loc, forOp.upperBound(), forOp.lowerBound());
    Value numIterations = ceilDivPositive(builder, loc, diff, forOp.step());
    Value iterationsPerBlock =
        ceilDivPositive(builder, loc, numIterations, sizes[i]);
    tileSizes.push_back(iterationsPerBlock);
  }

  // Call parametric tiling with the given sizes.
  auto intraTile = tile(forOps, tileSizes, forOps.back());
  TileLoops tileLoops = std::make_pair(forOps, intraTile);

  // TODO(ntv, zinenko) for now we just ignore the result of band isolation.
  // In the future, mapping decisions may be impacted by the ability to
  // isolate perfectly nested bands.
  tryIsolateBands(tileLoops);

  return tileLoops;
}

// Replaces all uses of `orig` with `replacement` except if the user is listed
// in `exceptions`.
static void
replaceAllUsesExcept(Value orig, Value replacement,
                     const SmallPtrSetImpl<Operation *> &exceptions) {
  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
    if (exceptions.count(use.getOwner()) == 0)
      use.set(replacement);
  }
}

// Transform a loop with a strictly positive step
//   for %i = %lb to %ub step %s
// into a 0-based loop with step 1
//   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
//     %i = %ii * %s + %lb
// Insert the induction variable remapping in the body of `inner`, which is
// expected to be either `loop` or another loop perfectly nested under `loop`.
// Insert the definition of the new bounds immediately before `outer`, which is
// expected to be either `loop` or its parent in the loop nest.
static void normalizeLoop(loop::ForOp loop, loop::ForOp outer,
                          loop::ForOp inner) {
  OpBuilder builder(outer);
  Location loc = loop.getLoc();

  // Check if the loop is already known to have a constant zero lower bound or
  // a constant one step.
  bool isZeroBased = false;
  if (auto lbCst =
          dyn_cast_or_null<ConstantIndexOp>(loop.lowerBound().getDefiningOp()))
    isZeroBased = lbCst.getValue() == 0;

  bool isStepOne = false;
  if (auto stepCst =
          dyn_cast_or_null<ConstantIndexOp>(loop.step().getDefiningOp()))
    isStepOne = stepCst.getValue() == 1;

  if (isZeroBased && isStepOne)
    return;

  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
  // assuming the step is strictly positive.  Update the bounds and the step
  // of the loop to go from 0 to the number of iterations, if necessary.
  // TODO(zinenko): introduce support for negative steps or emit dynamic asserts
  // on step positivity, whatever gets implemented first.
  Value diff =
      builder.create<SubIOp>(loc, loop.upperBound(), loop.lowerBound());
  Value numIterations = ceilDivPositive(builder, loc, diff, loop.step());
  loop.setUpperBound(numIterations);

  Value lb = loop.lowerBound();
  if (!isZeroBased) {
    Value cst0 = builder.create<ConstantIndexOp>(loc, 0);
    loop.setLowerBound(cst0);
  }

  Value step = loop.step();
  if (!isStepOne) {
    Value cst1 = builder.create<ConstantIndexOp>(loc, 1);
    loop.setStep(cst1);
  }

  // Insert code computing the value of the original loop induction variable
  // from the "normalized" one.
  builder.setInsertionPointToStart(inner.getBody());
  Value scaled =
      isStepOne ? loop.getInductionVar()
                : builder.create<MulIOp>(loc, loop.getInductionVar(), step);
  Value shifted =
      isZeroBased ? scaled : builder.create<AddIOp>(loc, scaled, lb);

  SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
                                       shifted.getDefiningOp()};
  replaceAllUsesExcept(loop.getInductionVar(), shifted, preserve);
}

void mlir::coalesceLoops(MutableArrayRef<loop::ForOp> loops) {
  if (loops.size() < 2)
    return;

  loop::ForOp innermost = loops.back();
  loop::ForOp outermost = loops.front();

  // 1. Make sure all loops iterate from 0 to upperBound with step 1.  This
  // allows the following code to assume upperBound is the number of iterations.
  for (auto loop : loops)
    normalizeLoop(loop, outermost, innermost);

  // 2. Emit code computing the upper bound of the coalesced loop as product
  // of the number of iterations of all loops.
  OpBuilder builder(outermost);
  Location loc = outermost.getLoc();
  Value upperBound = outermost.upperBound();
  for (auto loop : loops.drop_front())
    upperBound = builder.create<MulIOp>(loc, upperBound, loop.upperBound());
  outermost.setUpperBound(upperBound);

  builder.setInsertionPointToStart(outermost.getBody());

  // 3. Remap induction variables.  For each original loop, the value of the
  // induction variable can be obtained by dividing the induction variable of
  // the linearized loop by the total number of iterations of the loops nested
  // in it modulo the number of iterations in this loop (remove the values
  // related to the outer loops):
  //   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
  // Compute these iteratively from the innermost loop by creating a "running
  // quotient" of division by the range.
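  //
  // For example (an illustrative computation, not code): coalescing two loops
  // with 4 and 5 iterations gives a single loop of 20 iterations whose IV is
  // delinearized as iv_outer = iv floordiv 5 and iv_inner = iv mod 5.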
  Value previous = outermost.getInductionVar();
  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    unsigned idx = loops.size() - i - 1;
    if (i != 0)
      previous = builder.create<SignedDivIOp>(loc, previous,
                                              loops[idx + 1].upperBound());

    Value iv = (i == e - 1) ? previous
                            : builder.create<SignedRemIOp>(
                                  loc, previous, loops[idx].upperBound());
    replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
                               loops.back().region());
  }

  // 4. Move the operations from the innermost just above the second-outermost
  // loop, delete the extra terminator and the second-outermost loop.
  loop::ForOp second = loops[1];
  innermost.getBody()->back().erase();
  outermost.getBody()->getOperations().splice(
      Block::iterator(second.getOperation()),
      innermost.getBody()->getOperations());
  second.erase();
}

void mlir::mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef<Value> processorId,
                                 ArrayRef<Value> numProcessors) {
  assert(processorId.size() == numProcessors.size());
  if (processorId.empty())
    return;

  OpBuilder b(forOp);
  Location loc(forOp.getLoc());
  Value mul = processorId.front();
  for (unsigned i = 1, e = processorId.size(); i < e; ++i)
    mul = b.create<AddIOp>(loc, b.create<MulIOp>(loc, mul, numProcessors[i]),
                           processorId[i]);
  Value lb = b.create<AddIOp>(loc, forOp.lowerBound(),
                              b.create<MulIOp>(loc, forOp.step(), mul));
  forOp.setLowerBound(lb);

  Value step = forOp.step();
  for (auto numProcs : numProcessors)
    step = b.create<MulIOp>(loc, step, numProcs);
  forOp.setStep(step);
}

/// Given a memref region, determine the lowest depth at which transfers can be
/// placed for it, and return the corresponding block, start and end positions
/// in the block for placing incoming (read) and outgoing (write) copies
/// respectively. The lowest depth depends on whether the region being accessed
/// is hoistable with respect to one or more immediately surrounding loops.
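///
/// For illustration (an informal sketch): if the region accessed inside a
/// three-deep loop nest is invariant in the two innermost loops (none of their
/// IVs appear among the region's symbols) but varies with the outermost IV,
/// the copy-in position is just before the middle loop and the copy-out
/// position just after it, both in the outermost loop's body.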
static void
findHighestBlockForPlacement(const MemRefRegion &region, Block &block,
                             Block::iterator &begin, Block::iterator &end,
                             Block **copyPlacementBlock,
                             Block::iterator *copyInPlacementStart,
                             Block::iterator *copyOutPlacementStart) {
  const auto *cst = region.getConstraints();
  SmallVector<Value, 4> symbols;
  cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols);

  SmallVector<AffineForOp, 4> enclosingFors;
  getLoopIVs(*block.begin(), &enclosingFors);
  // Walk up loop parents till we find an IV on which this region is
  // symbolic/variant.
  auto it = enclosingFors.rbegin();
  for (auto e = enclosingFors.rend(); it != e; ++it) {
    // TODO(bondhugula): also need to be checking this for regions symbols that
    // aren't loop IVs, whether we are within their resp. defs' dominance scope.
    if (llvm::is_contained(symbols, it->getInductionVar()))
      break;
  }

  if (it != enclosingFors.rbegin()) {
    auto lastInvariantIV = *std::prev(it);
    *copyInPlacementStart = Block::iterator(lastInvariantIV.getOperation());
    *copyOutPlacementStart = std::next(*copyInPlacementStart);
    *copyPlacementBlock = lastInvariantIV.getOperation()->getBlock();
  } else {
    *copyInPlacementStart = begin;
    *copyOutPlacementStart = end;
    *copyPlacementBlock = &block;
  }
}

// Info comprising stride and number of elements transferred every stride.
struct StrideInfo {
  int64_t stride;
  int64_t numEltPerStride;
};

/// Returns striding information for a copy/transfer of this region with
/// potentially multiple striding levels from outermost to innermost. For an
/// n-dimensional region, there can be at most n-1 levels of striding
/// successively nested.
//  TODO(bondhugula): make this work with non-identity layout maps.
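//
//  For example (an illustrative computation under the assumption of an
//  identity layout): copying a 32x32 region out of a 128x128 memref records a
//  single stride level {stride = 128, numEltPerStride = 32}; a full 128x128
//  region records no stride levels at all.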
static void getMultiLevelStrides(const MemRefRegion &region,
                                 ArrayRef<int64_t> bufferShape,
                                 SmallVectorImpl<StrideInfo> *strideInfos) {
  if (bufferShape.size() <= 1)
    return;

  int64_t numEltPerStride = 1;
  int64_t stride = 1;
  for (int d = bufferShape.size() - 1; d >= 1; d--) {
    int64_t dimSize = region.memref.getType().cast<MemRefType>().getDimSize(d);
    stride *= dimSize;
    numEltPerStride *= bufferShape[d];
    // A stride is needed only if the region has a shorter extent than the
    // memref along the dimension *and* has an extent greater than one along the
    // next major dimension.
    if (bufferShape[d] < dimSize && bufferShape[d - 1] > 1) {
      strideInfos->push_back({stride, numEltPerStride});
    }
  }
}

1189 /// Generates a point-wise copy from/to `memref' to/from `fastMemRef' and
1190 /// returns the outermost AffineForOp of the copy loop nest. `memIndicesStart'
1191 /// holds the lower coordinates of the region in the original memref to copy
1192 /// in/out. If `isCopyOut' is true, generates a copy-out; otherwise a copy-in.
1193 static AffineForOp generatePointWiseCopy(Location loc, Value memref,
1194                                          Value fastMemRef,
1195                                          AffineMap memAffineMap,
1196                                          ArrayRef<Value> memIndicesStart,
1197                                          ArrayRef<int64_t> fastBufferShape,
1198                                          bool isCopyOut, OpBuilder b) {
1199   assert(!memIndicesStart.empty() && "only 1-d or more memrefs");
1200 
1201   // The copy-in nest is generated as follows as an example for a 2-d region:
1202   // for x = ...
1203   //   for y = ...
1204   //     fast_buf[x][y] = buf[mem_x + x][mem_y + y]
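  //
  // The copy-out nest (isCopyOut == true) is the mirror image:
  //     buf[mem_x + x][mem_y + y] = fast_buf[x][y]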
1205 
1206   SmallVector<Value, 4> fastBufIndices, memIndices;
1207   AffineForOp copyNestRoot;
1208   for (unsigned d = 0, e = fastBufferShape.size(); d < e; ++d) {
1209     auto forOp = b.create<AffineForOp>(loc, 0, fastBufferShape[d]);
1210     if (d == 0)
1211       copyNestRoot = forOp;
1212     b = forOp.getBodyBuilder();
1213     fastBufIndices.push_back(forOp.getInductionVar());
1214 
1215     Value memBase =
1216         (memAffineMap == b.getMultiDimIdentityMap(memAffineMap.getNumDims()))
1217             ? memIndicesStart[d]
1218             : b.create<AffineApplyOp>(
1219                   loc,
1220                   AffineMap::get(memAffineMap.getNumDims(),
1221                                  memAffineMap.getNumSymbols(),
1222                                  memAffineMap.getResult(d)),
1223                   memIndicesStart);
1224 
1225     // Construct the subscript for the slow memref being copied.
1226     auto memIndex = b.create<AffineApplyOp>(
1227         loc,
1228         AffineMap::get(2, 0, b.getAffineDimExpr(0) + b.getAffineDimExpr(1)),
1229         ValueRange({memBase, forOp.getInductionVar()}));
1230     memIndices.push_back(memIndex);
1231   }
1232 
1233   if (!isCopyOut) {
1234     // Copy in.
1235     auto load = b.create<AffineLoadOp>(loc, memref, memIndices);
1236     b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufIndices);
1237     return copyNestRoot;
1238   }
1239 
1240   // Copy out.
1241   auto load = b.create<AffineLoadOp>(loc, fastMemRef, fastBufIndices);
1242   b.create<AffineStoreOp>(loc, load, memref, memIndices);
1243   return copyNestRoot;
1244 }
1245 
1246 static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
1247 emitRemarkForBlock(Block &block) {
1248   return block.getParentOp()->emitRemark();
1249 }
1250 
1251 /// Creates a buffer in the faster memory space for the specified memref region;
1252 /// generates a copy from the lower memory space to this one, and replaces all
1253 /// loads/stores in the block range [`begin', `end') of `block' to load/store
1254 /// from that buffer. Returns failure if copies could not be generated due to
1255 /// yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart`
1256 /// in copyPlacementBlock specify the insertion points where the incoming copies
1257 /// and outgoing copies, respectively, should be inserted (the insertion happens
1258 /// right before the insertion point). Since `begin` can itself be invalidated
1259 /// due to the memref rewriting done from this method, the output argument
1260 /// `nBegin` is set to its replacement (set to `begin` if no invalidation
1261 /// happens). Since outgoing copies could have been inserted at `end`, the
1262 /// output argument `nEnd` is set to the new end. `sizeInBytes` is set to the
1263 /// size of the fast buffer allocated.
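///
/// As a rough sketch (illustrative only; SSA names are made up), for a read
/// region with point-wise copies (no DMA), the generated IR looks like:
///
///   %fast_buf = alloc() : memref<..., fastMemorySpace>
///   affine.for %x = 0 to <extent_0> {
///     affine.for %y = 0 to <extent_1> {
///       %v = affine.load %orig_buf[%offset_0 + %x, %offset_1 + %y]
///       affine.store %v, %fast_buf[%x, %y]
///     }
///   }
///   ... uses of %orig_buf in [`begin', `end') rewritten to use %fast_buf ...
///   dealloc %fast_buf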
1264 static LogicalResult generateCopy(
1265     const MemRefRegion &region, Block *block, Block::iterator begin,
1266     Block::iterator end, Block *copyPlacementBlock,
1267     Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart,
1268     AffineCopyOptions copyOptions, DenseMap<Value, Value> &fastBufferMap,
1269     DenseSet<Operation *> &copyNests, uint64_t *sizeInBytes,
1270     Block::iterator *nBegin, Block::iterator *nEnd) {
1271   *nBegin = begin;
1272   *nEnd = end;
1273 
1274   FuncOp f = begin->getParentOfType<FuncOp>();
1275   OpBuilder topBuilder(f.getBody());
1276   Value zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
1277 
1278   if (begin == end)
1279     return success();
1280 
1281   // Whether the copy-out point is at the end of the block where we are doing
1282   // explicit copying.
1283   bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart);
1284 
1285   // Copies for read regions are going to be inserted at 'begin'.
1286   OpBuilder prologue(copyPlacementBlock, copyInPlacementStart);
1287   // Copies for write regions are going to be inserted at 'end'.
1288   OpBuilder epilogue(copyPlacementBlock, copyOutPlacementStart);
1289   OpBuilder &b = region.isWrite() ? epilogue : prologue;
1290 
1291   // Builder to create constants at the top level.
1292   auto func = copyPlacementBlock->getParent()->getParentOfType<FuncOp>();
1293   OpBuilder top(func.getBody());
1294 
1295   auto loc = region.loc;
1296   auto memref = region.memref;
1297   auto memRefType = memref.getType().cast<MemRefType>();
1298 
1299   auto layoutMaps = memRefType.getAffineMaps();
1300   if (layoutMaps.size() > 1 ||
1301       (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
1302     LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
1303     return failure();
1304   }
1305 
1306   // Indices to use for the copying.
1307   // Indices for the original memref being copied from/to.
1308   SmallVector<Value, 4> memIndices;
1309   // Indices for the faster buffer being copied into/from.
1310   SmallVector<Value, 4> bufIndices;
1311 
1312   unsigned rank = memRefType.getRank();
1313   SmallVector<int64_t, 4> fastBufferShape;
1314 
1315   // Compute the extents of the buffer.
1316   std::vector<SmallVector<int64_t, 4>> lbs;
1317   SmallVector<int64_t, 8> lbDivisors;
1318   lbs.reserve(rank);
1319   Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(
1320       &fastBufferShape, &lbs, &lbDivisors);
1321   if (!numElements.hasValue()) {
1322     LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
1323     return failure();
1324   }
1325 
1326   if (numElements.getValue() == 0) {
1327     LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n");
1328     *sizeInBytes = 0;
1329     return success();
1330   }
1331 
1332   const FlatAffineConstraints *cst = region.getConstraints();
1333   // 'regionSymbols' hold values that this memory region is symbolic/parametric
1334   // on; these typically include loop IVs surrounding the level at which the
1335   // copy generation is being done or other valid symbols in MLIR.
1336   SmallVector<Value, 8> regionSymbols;
1337   cst->getIdValues(rank, cst->getNumIds(), &regionSymbols);
1338 
1339   // Construct the index expressions for the fast memory buffer. The index
1340   // expression for a particular dimension of the fast buffer is obtained by
1341   // subtracting out the lower bound on the original memref's data region
1342   // along the corresponding dimension.
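  //
  // For instance (illustrative), if the region along dimension d spans
  // [%iT, %iT + 31], the fast buffer extent along d is 32, the copy along d
  // starts at memIndices[d] = %iT, and bufIndices[d] is 0.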
1343 
1344   // Index start offsets for faster memory buffer relative to the original.
1345   SmallVector<AffineExpr, 4> offsets;
1346   offsets.reserve(rank);
1347   for (unsigned d = 0; d < rank; d++) {
1348     assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
1349 
1350     AffineExpr offset = top.getAffineConstantExpr(0);
1351     for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
1352       offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
1353     }
1354     assert(lbDivisors[d] > 0);
1355     offset =
1356         (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
1357 
1358     // Set copy start location for this dimension in the lower memory space
1359     // memref.
1360     if (auto caf = offset.dyn_cast<AffineConstantExpr>()) {
1361       auto indexVal = caf.getValue();
1362       if (indexVal == 0) {
1363         memIndices.push_back(zeroIndex);
1364       } else {
1365         memIndices.push_back(
1366             top.create<ConstantIndexOp>(loc, indexVal).getResult());
1367       }
1368     } else {
1369       // The coordinate for the start location is just the lower bound along the
1370       // corresponding dimension on the memory region (stored in 'offset').
1371       auto map = AffineMap::get(
1372           cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset);
1373       memIndices.push_back(b.create<AffineApplyOp>(loc, map, regionSymbols));
1374     }
1375     // The fast buffer is copied into at location zero; addressing is relative.
1376     bufIndices.push_back(zeroIndex);
1377 
1378     // Record the offsets since they are needed to remap the memory accesses of
1379     // the original memref further below.
1380     offsets.push_back(offset);
1381   }
1382 
1383   // The faster memory space buffer.
1384   Value fastMemRef;
1385 
1386   // Check if a buffer was already created.
1387   bool existingBuf = fastBufferMap.count(memref) > 0;
1388   if (!existingBuf) {
1389     AffineMap fastBufferLayout = b.getMultiDimIdentityMap(rank);
1390     auto fastMemRefType =
1391         MemRefType::get(fastBufferShape, memRefType.getElementType(),
1392                         fastBufferLayout, copyOptions.fastMemorySpace);
1393 
1394     // Create the fast memory space buffer just before the 'affine.for'
1395     // operation.
1396     fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType).getResult();
1397     // Record it.
1398     fastBufferMap[memref] = fastMemRef;
1399     // fastMemRefType is a constant shaped memref.
1400     *sizeInBytes = getMemRefSizeInBytes(fastMemRefType).getValue();
1401     LLVM_DEBUG(emitRemarkForBlock(*block)
1402                << "Creating fast buffer of type " << fastMemRefType
1403                << " and size " << llvm::divideCeil(*sizeInBytes, 1024)
1404                << " KiB\n");
1405   } else {
1406     // Reuse the one already created.
1407     fastMemRef = fastBufferMap[memref];
1408     *sizeInBytes = 0;
1409   }
1410 
1411   auto numElementsSSA =
1412       top.create<ConstantIndexOp>(loc, numElements.getValue());
1413 
1414   SmallVector<StrideInfo, 4> strideInfos;
1415   getMultiLevelStrides(region, fastBufferShape, &strideInfos);
1416 
1417   // TODO(bondhugula): use all stride levels once DmaStartOp is extended for
1418   // multi-level strides.
1419   if (strideInfos.size() > 1) {
1420     LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
1421     return failure();
1422   }
1423 
1424   Value stride = nullptr;
1425   Value numEltPerStride = nullptr;
1426   if (!strideInfos.empty()) {
1427     stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride);
1428     numEltPerStride =
1429         top.create<ConstantIndexOp>(loc, strideInfos[0].numEltPerStride);
1430   }
1431 
1432   // Record the last operation where we want the memref replacement to end. We
1433   // later do the memref replacement only in [begin, postDomFilter] so
1434   // that the original memrefs used in the data movement code themselves don't
1435   // get replaced.
1436   auto postDomFilter = std::prev(end);
1437 
1438   // Create fully composed affine maps for each memref.
1439   auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size());
1440   fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices);
1441   auto bufAffineMap = b.getMultiDimIdentityMap(bufIndices.size());
1442   fullyComposeAffineMapAndOperands(&bufAffineMap, &bufIndices);
1443 
1444   if (!copyOptions.generateDma) {
1445     // Point-wise copy generation.
1446     auto copyNest = generatePointWiseCopy(loc, memref, fastMemRef, memAffineMap,
1447                                           memIndices, fastBufferShape,
1448                                           /*isCopyOut=*/region.isWrite(), b);
1449 
1450     // Record this so that we can skip it from yet another copy.
1451     copyNests.insert(copyNest);
1452 
1453     // Since new ops are being appended (for copy-outs), adjust the end to
1454     // mark end of block range being processed if necessary.
1455     if (region.isWrite() && isCopyOutAtEndOfBlock)
1456       *nEnd = Block::iterator(copyNest.getOperation());
1457   } else {
1458     // DMA generation.
1459     // Create a tag (single element 1-d memref) for the DMA.
1460     auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {},
1461                                          copyOptions.tagMemorySpace);
1462     auto tagMemRef = prologue.create<AllocOp>(loc, tagMemRefType);
1463 
1464     SmallVector<Value, 4> tagIndices({zeroIndex});
1465     auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size());
1466     fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices);
1467     if (!region.isWrite()) {
1468       // DMA non-blocking read from original buffer to fast buffer.
1469       b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
1470                                  fastMemRef, bufAffineMap, bufIndices,
1471                                  tagMemRef, tagAffineMap, tagIndices,
1472                                  numElementsSSA, stride, numEltPerStride);
1473     } else {
1474       // DMA non-blocking write from fast buffer to the original memref.
1475       auto op = b.create<AffineDmaStartOp>(
1476           loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
1477           memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
1478           stride, numEltPerStride);
1479       // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
1480       // end to mark end of block range being processed.
1481       if (isCopyOutAtEndOfBlock)
1482         *nEnd = Block::iterator(op.getOperation());
1483     }
1484 
1485     // Matching DMA wait to block on completion; tag always has a 0 index.
1486     b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex,
1487                               numElementsSSA);
1488 
1489     // Generate dealloc for the tag.
1490     auto tagDeallocOp = epilogue.create<DeallocOp>(loc, tagMemRef);
1491     if (*nEnd == end && isCopyOutAtEndOfBlock)
1492       // Since new ops are being appended (for outgoing DMAs), adjust the end to
1493       // mark end of range of the original.
1494       *nEnd = Block::iterator(tagDeallocOp.getOperation());
1495   }
1496 
1497   // Generate dealloc for the buffer.
1498   if (!existingBuf) {
1499     auto bufDeallocOp = epilogue.create<DeallocOp>(loc, fastMemRef);
1500     // When generating pointwise copies, `nEnd' has to be set to deallocOp on
1501     // the fast buffer (since it marks the new end insertion point).
1502     if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock)
1503       *nEnd = Block::iterator(bufDeallocOp.getOperation());
1504   }
1505 
1506   // Replace all uses of the old memref with the faster one while remapping
1507   // access indices (subtracting out lower bound offsets for each dimension).
1508   // Ex: to replace load %A[%i, %j] with load %Abuf[%i - %iT, %j - %jT],
1509   // index remap will be (%i, %j) -> (%i - %iT, %j - %jT),
1510   // i.e., affine.apply (d0, d1, d2, d3) -> (d2-d0, d3-d1) (%iT, %jT, %i, %j),
1511   // and (%iT, %jT) will be the 'extraOperands' for 'rep all memref uses with'.
1512   // d2, d3 correspond to the original indices (%i, %j).
1513   SmallVector<AffineExpr, 4> remapExprs;
1514   remapExprs.reserve(rank);
1515   for (unsigned i = 0; i < rank; i++) {
1516     // The starting operands of indexRemap will be regionSymbols (the symbols on
1517     // which the memref region is parametric); then those corresponding to
1518     // the memref's original indices follow.
1519     auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i);
1520     remapExprs.push_back(dimExpr - offsets[i]);
1521   }
1522   auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs);
1523 
1524   // Record the begin since it may be invalidated by memref replacement.
1525   Block::iterator prevOfBegin;
1526   bool isBeginAtStartOfBlock = (begin == block->begin());
1527   if (!isBeginAtStartOfBlock)
1528     prevOfBegin = std::prev(begin);
1529 
1530   // *Only* those uses within the range [begin, end) of 'block' are replaced.
1531   replaceAllMemRefUsesWith(memref, fastMemRef,
1532                            /*extraIndices=*/{}, indexRemap,
1533                            /*extraOperands=*/regionSymbols,
1534                            /*symbolOperands=*/{},
1535                            /*domInstFilter=*/&*begin,
1536                            /*postDomInstFilter=*/&*postDomFilter);
1537 
1538   *nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);
1539 
1540   return success();
1541 }
1542 
1543 /// Constructs the memref region to just include the entire memref. Returns
1544 /// false for dynamically shaped memrefs for now. `numParamLoopIVs` is the
1545 /// number of enclosing loop IVs of opInst (starting from the outermost) that
1546 /// the region is parametric on.
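///
/// For example (illustrative), for an access to a memref<128x256xf32> with
/// `numParamLoopIVs` = 0, the region constructed is simply:
///   0 <= d0 <= 127, 0 <= d1 <= 255.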
1547 static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs,
1548                                   MemRefRegion *region) {
1549   unsigned rank;
1550   if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
1551     rank = loadOp.getMemRefType().getRank();
1552     region->memref = loadOp.getMemRef();
1553     region->setWrite(false);
1554   } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
1555     rank = storeOp.getMemRefType().getRank();
1556     region->memref = storeOp.getMemRef();
1557     region->setWrite(true);
1558   } else {
1559     assert(false && "expected load or store op");
1560     return false;
1561   }
1562   auto memRefType = region->memref.getType().cast<MemRefType>();
1563   if (!memRefType.hasStaticShape())
1564     return false;
1565 
1566   auto *regionCst = region->getConstraints();
1567 
1568   // Just get the first numSymbols IVs, which the memref region is parametric
1569   // on.
1570   SmallVector<AffineForOp, 4> ivs;
1571   getLoopIVs(*opInst, &ivs);
1572   ivs.resize(numParamLoopIVs);
1573   SmallVector<Value, 4> symbols;
1574   extractForInductionVars(ivs, &symbols);
1575   regionCst->reset(rank, numParamLoopIVs, 0);
1576   regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols);
1577 
1578   // Memref dim sizes provide the bounds.
1579   for (unsigned d = 0; d < rank; d++) {
1580     auto dimSize = memRefType.getDimSize(d);
1581     assert(dimSize > 0 && "filtered dynamic shapes above");
1582     regionCst->addConstantLowerBound(d, 0);
1583     regionCst->addConstantUpperBound(d, dimSize - 1);
1584   }
1585   return true;
1586 }
1587 
1588 /// Generates copies for a contiguous sequence of operations in `block` in the
1589 /// iterator range [`begin', `end'), where `end' can't be past the terminator of
1590 /// the block (since additional operations are potentially inserted right before
1591 /// `end'). Returns the total size of the fast buffers used.
1592 //  Since we generate alloc's and dealloc's for all fast buffers (before and
1593 //  after the range of operations resp.), all of the fast memory capacity is
1594 //  assumed to be available for processing this block range.
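//
//  A minimal caller sketch (illustrative; `forOp` and all option values below
//  are hypothetical, not defaults):
//
//    AffineCopyOptions options;
//    options.generateDma = false;
//    options.slowMemorySpace = 0;
//    options.fastMemorySpace = 1;
//    options.tagMemorySpace = 0;
//    options.fastMemCapacityBytes = 32 * 1024;
//    DenseSet<Operation *> copyNests;
//    Block::iterator start(forOp.getOperation());
//    uint64_t fastBufBytes =
//        affineDataCopyGenerate(start, std::next(start), options, copyNests);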
1595 uint64_t mlir::affineDataCopyGenerate(Block::iterator begin,
1596                                       Block::iterator end,
1597                                       const AffineCopyOptions &copyOptions,
1598                                       DenseSet<Operation *> &copyNests) {
1599   if (begin == end)
1600     return 0;
1601 
1602   assert(begin->getBlock() == std::prev(end)->getBlock() &&
1603          "Inconsistent block begin/end args");
1604   assert(end != end->getBlock()->end() && "end can't be the block terminator");
1605 
1606   Block *block = begin->getBlock();
1607 
1608   // Copies will be generated for this depth, i.e., symbolic in all loops
1609   // surrounding this block range.
1610   unsigned copyDepth = getNestingDepth(*begin);
1611 
1612   LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
1613                           << "\n");
1614   LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n");
1615   LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n");
1616 
1617   // List of memory regions to copy for. We need a map vector to have a
1618   // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
1619   // since the allocs, for example, are identical except for the SSA id.
1620   SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> readRegions;
1621   SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> writeRegions;
1622 
1623   // Map from original memref's to the fast buffers that their accesses are
1624   // replaced with.
1625   DenseMap<Value, Value> fastBufferMap;
1626 
1627   // To check for errors when walking the block.
1628   bool error = false;
1629 
1630   // Walk this range of operations to gather all memory regions.
1631   block->walk(begin, end, [&](Operation *opInst) {
1632     // Gather regions to allocate to buffers in faster memory space.
1633     if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
1634       if ((loadOp.getMemRefType().getMemorySpace() !=
1635            copyOptions.slowMemorySpace))
1636         return;
1637     } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
1638       if (storeOp.getMemRefType().getMemorySpace() !=
1639           copyOptions.slowMemorySpace)
1640         return;
1641     } else {
1642       // Neither a load nor a store op.
1643       return;
1644     }
1645 
1646     // Compute the MemRefRegion accessed.
1647     auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
1648     if (failed(region->compute(opInst, copyDepth))) {
1649       LLVM_DEBUG(llvm::dbgs()
1650                  << "Error obtaining memory region: semi-affine maps?\n");
1651       LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
1652       if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
1653         LLVM_DEBUG(
1654             opInst->emitError("non-constant memref sizes not yet supported"));
1655         error = true;
1656         return;
1657       }
1658     }
1659 
1660     // Each memref has a single buffer associated with it irrespective of how
1661     // many loads and stores happen on it.
1662     // TODO(bondhugula): in the future, when regions don't intersect and satisfy
1663     // other properties (based on load/store regions), we could consider
1664     // multiple buffers per memref.
1665 
1666     // Add to the appropriate region if it's not already in it, or take a
1667     // bounding box union with the existing one if it's already in there.
1668     // Note that a memref may have both read and write regions - so update the
1669     // region in the other list if one exists (write in case of read and vice
1670     // versa) since there is a single bounding box for a memref across all reads
1671     // and writes that happen on it.
1672 
1673     // Attempts to update; returns true if 'region' exists in targetRegions.
1674     auto updateRegion =
1675         [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
1676                 &targetRegions) {
1677           auto it = targetRegions.find(region->memref);
1678           if (it == targetRegions.end())
1679             return false;
1680 
1681           // Perform a union with the existing region.
1682           if (failed(it->second->unionBoundingBox(*region))) {
1683             LLVM_DEBUG(llvm::dbgs()
1684                        << "Memory region bounding box failed; "
1685                           "over-approximating to the entire memref\n");
1686             // If the union fails, we will overapproximate.
1687             if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
1688               LLVM_DEBUG(opInst->emitError(
1689                   "non-constant memref sizes not yet supported"));
1690               error = true;
1691               return true;
1692             }
1693             it->second->getConstraints()->clearAndCopyFrom(
1694                 *region->getConstraints());
1695           } else {
1696             // Union was computed and stored in 'it->second': copy to 'region'.
1697             region->getConstraints()->clearAndCopyFrom(
1698                 *it->second->getConstraints());
1699           }
1700           return true;
1701         };
1702 
1703     bool existsInRead = updateRegion(readRegions);
1704     if (error)
1705       return;
1706     bool existsInWrite = updateRegion(writeRegions);
1707     if (error)
1708       return;
1709 
1710     // Finally add it to the region list.
1711     if (region->isWrite() && !existsInWrite) {
1712       writeRegions[region->memref] = std::move(region);
1713     } else if (!region->isWrite() && !existsInRead) {
1714       readRegions[region->memref] = std::move(region);
1715     }
1716   });
1717 
1718   if (error) {
1719     begin->emitError(
1720         "copy generation failed for one or more memref's in this block\n");
1721     return 0;
1722   }
1723 
1724   uint64_t totalCopyBuffersSizeInBytes = 0;
1725   bool ret = true;
1726   auto processRegions =
1727       [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
1728               &regions) {
1729         for (const auto &regionEntry : regions) {
1730           // For each region, hoist copy in/out past all hoistable
1731           // 'affine.for's.
1732           Block::iterator copyInPlacementStart, copyOutPlacementStart;
1733           Block *copyPlacementBlock;
1734           findHighestBlockForPlacement(
1735               *regionEntry.second, *block, begin, end, &copyPlacementBlock,
1736               &copyInPlacementStart, &copyOutPlacementStart);
1737 
1738           uint64_t sizeInBytes;
1739           Block::iterator nBegin, nEnd;
1740           LogicalResult iRet = generateCopy(
1741               *regionEntry.second, block, begin, end, copyPlacementBlock,
1742               copyInPlacementStart, copyOutPlacementStart, copyOptions,
1743               fastBufferMap, copyNests, &sizeInBytes, &nBegin, &nEnd);
1744           if (succeeded(iRet)) {
1745             // begin/end could have been invalidated, and need update.
1746             begin = nBegin;
1747             end = nEnd;
1748             totalCopyBuffersSizeInBytes += sizeInBytes;
1749           }
1750           ret = ret & succeeded(iRet);
1751         }
1752       };
1753   processRegions(readRegions);
1754   processRegions(writeRegions);
1755 
1756   if (!ret) {
1757     begin->emitError(
1758         "copy generation failed for one or more memref's in this block\n");
1759     return totalCopyBuffersSizeInBytes;
1760   }
1761 
1762   // For a range of operations, a note will be emitted at the caller.
1763   AffineForOp forOp;
1764   uint64_t sizeInKib = llvm::divideCeil(totalCopyBuffersSizeInBytes, 1024);
1765   if (llvm::DebugFlag && (forOp = dyn_cast<AffineForOp>(&*begin))) {
1766     forOp.emitRemark()
1767         << sizeInKib
1768         << " KiB of copy buffers in fast memory space for this block\n";
1769   }
1770 
1771   if (totalCopyBuffersSizeInBytes > copyOptions.fastMemCapacityBytes) {
1772     StringRef str = "Total size of all copy buffers' for this block "
1773                     "exceeds fast memory capacity\n";
1774     block->getParentOp()->emitError(str);
1775   }
1776 
1777   return totalCopyBuffersSizeInBytes;
1778 }
1779