1 //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion transformation utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/LoopFusionUtils.h"
14 
15 #include "mlir/Analysis/AffineAnalysis.h"
16 #include "mlir/Analysis/AffineStructures.h"
17 #include "mlir/Analysis/LoopAnalysis.h"
18 #include "mlir/Analysis/Utils.h"
19 #include "mlir/Dialect/AffineOps/AffineOps.h"
20 #include "mlir/Dialect/StandardOps/Ops.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/BlockAndValueMapping.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/Function.h"
26 #include "mlir/IR/Operation.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/raw_ostream.h"
31 
32 #define DEBUG_TYPE "loop-fusion-utils"
33 
34 using namespace mlir;
35 
36 // Gathers all load and store memref accesses in 'opA' into 'values', where
37 // 'values[memref] == true' for each store operation.
getLoadAndStoreMemRefAccesses(Operation * opA,DenseMap<Value,bool> & values)38 static void getLoadAndStoreMemRefAccesses(Operation *opA,
39                                           DenseMap<Value, bool> &values) {
40   opA->walk([&](Operation *op) {
41     if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
42       if (values.count(loadOp.getMemRef()) == 0)
43         values[loadOp.getMemRef()] = false;
44     } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
45       values[storeOp.getMemRef()] = true;
46     }
47   });
48 }
49 
50 // Returns true if 'op' is a load or store operation which access an memref
51 // accessed 'values' and at least one of the access is a store operation.
52 // Returns false otherwise.
isDependentLoadOrStoreOp(Operation * op,DenseMap<Value,bool> & values)53 static bool isDependentLoadOrStoreOp(Operation *op,
54                                      DenseMap<Value, bool> &values) {
55   if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
56     return values.count(loadOp.getMemRef()) > 0 &&
57            values[loadOp.getMemRef()] == true;
58   } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
59     return values.count(storeOp.getMemRef()) > 0;
60   }
61   return false;
62 }
63 
64 // Returns the first operation in range ('opA', 'opB') which has a data
65 // dependence on 'opA'. Returns 'nullptr' of no dependence exists.
getFirstDependentOpInRange(Operation * opA,Operation * opB)66 static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
67   // Record memref values from all loads/store in loop nest rooted at 'opA'.
68   // Map from memref value to bool which is true if store, false otherwise.
69   DenseMap<Value, bool> values;
70   getLoadAndStoreMemRefAccesses(opA, values);
71 
72   // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
73   // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
74   // and at least one of the accesses is a store).
75   Operation *firstDepOp = nullptr;
76   for (Block::iterator it = std::next(Block::iterator(opA));
77        it != Block::iterator(opB); ++it) {
78     Operation *opX = &(*it);
79     opX->walk([&](Operation *op) {
80       if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
81         firstDepOp = opX;
82     });
83     if (firstDepOp)
84       break;
85   }
86   return firstDepOp;
87 }
88 
89 // Returns the last operation 'opX' in range ('opA', 'opB'), for which there
90 // exists a data dependence from 'opX' to 'opB'.
91 // Returns 'nullptr' of no dependence exists.
getLastDependentOpInRange(Operation * opA,Operation * opB)92 static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
93   // Record memref values from all loads/store in loop nest rooted at 'opB'.
94   // Map from memref value to bool which is true if store, false otherwise.
95   DenseMap<Value, bool> values;
96   getLoadAndStoreMemRefAccesses(opB, values);
97 
98   // For each 'opX' in block in range ('opA', 'opB') in reverse order,
99   // check if there is a data dependence from 'opX' to 'opB':
100   // *) 'opX' and 'opB' access the same memref and at least one of the accesses
101   //    is a store.
102   // *) 'opX' produces an SSA Value which is used by 'opB'.
103   Operation *lastDepOp = nullptr;
104   for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
105        it != Block::reverse_iterator(opA); ++it) {
106     Operation *opX = &(*it);
107     opX->walk([&](Operation *op) {
108       if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) {
109         if (isDependentLoadOrStoreOp(op, values)) {
110           lastDepOp = opX;
111           return WalkResult::interrupt();
112         }
113         return WalkResult::advance();
114       }
115       for (auto value : op->getResults()) {
116         for (auto user : value.getUsers()) {
117           SmallVector<AffineForOp, 4> loops;
118           // Check if any loop in loop nest surrounding 'user' is 'opB'.
119           getLoopIVs(*user, &loops);
120           if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
121             lastDepOp = opX;
122             return WalkResult::interrupt();
123           }
124         }
125       }
126       return WalkResult::advance();
127     });
128     if (lastDepOp)
129       break;
130   }
131   return lastDepOp;
132 }
133 
134 // Computes and returns an insertion point operation, before which the
135 // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving
136 // dependences. Returns nullptr if no such insertion point is found.
getFusedLoopNestInsertionPoint(AffineForOp srcForOp,AffineForOp dstForOp)137 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
138                                                  AffineForOp dstForOp) {
139   bool isSrcForOpBeforeDstForOp =
140       srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation());
141   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
142   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
143 
144   auto *firstDepOpA =
145       getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
146   auto *lastDepOpB =
147       getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
148   // Block:
149   //      ...
150   //  |-- opA
151   //  |   ...
152   //  |   lastDepOpB --|
153   //  |   ...          |
154   //  |-> firstDepOpA  |
155   //      ...          |
156   //      opB <---------
157   //
158   // Valid insertion point range: (lastDepOpB, firstDepOpA)
159   //
160   if (firstDepOpA != nullptr) {
161     if (lastDepOpB != nullptr) {
162       if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
163         // No valid insertion point exists which preserves dependences.
164         return nullptr;
165     }
166     // Return insertion point in valid range closest to 'opB'.
167     // TODO(andydavis) Consider other insertion points in valid range.
168     return firstDepOpA;
169   }
170   // No dependences from 'opA' to operation in range ('opA', 'opB'), return
171   // 'opB' insertion point.
172   return forOpB.getOperation();
173 }
174 
175 // Gathers all load and store ops in loop nest rooted at 'forOp' into
176 // 'loadAndStoreOps'.
177 static bool
gatherLoadsAndStores(AffineForOp forOp,SmallVectorImpl<Operation * > & loadAndStoreOps)178 gatherLoadsAndStores(AffineForOp forOp,
179                      SmallVectorImpl<Operation *> &loadAndStoreOps) {
180   bool hasIfOp = false;
181   forOp.walk([&](Operation *op) {
182     if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
183       loadAndStoreOps.push_back(op);
184     else if (isa<AffineIfOp>(op))
185       hasIfOp = true;
186   });
187   return !hasIfOp;
188 }
189 
190 // TODO(andydavis) Prevent fusion of loop nests with side-effecting operations.
canFuseLoops(AffineForOp srcForOp,AffineForOp dstForOp,unsigned dstLoopDepth,ComputationSliceState * srcSlice)191 FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
192                                 unsigned dstLoopDepth,
193                                 ComputationSliceState *srcSlice) {
194   // Return 'failure' if 'dstLoopDepth == 0'.
195   if (dstLoopDepth == 0) {
196     LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n.");
197     return FusionResult::FailPrecondition;
198   }
199   // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
200   auto *block = srcForOp.getOperation()->getBlock();
201   if (block != dstForOp.getOperation()->getBlock()) {
202     LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n.");
203     return FusionResult::FailPrecondition;
204   }
205 
206   // Return 'failure' if no valid insertion point for fused loop nest in 'block'
207   // exists which would preserve dependences.
208   if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
209     LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n.");
210     return FusionResult::FailBlockDependence;
211   }
212 
213   // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
214   bool isSrcForOpBeforeDstForOp =
215       srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation());
216   // 'forOpA' executes before 'forOpB' in 'block'.
217   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
218   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
219 
220   // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
221   SmallVector<Operation *, 4> opsA;
222   if (!gatherLoadsAndStores(forOpA, opsA)) {
223     LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
224     return FusionResult::FailPrecondition;
225   }
226 
227   // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
228   SmallVector<Operation *, 4> opsB;
229   if (!gatherLoadsAndStores(forOpB, opsB)) {
230     LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
231     return FusionResult::FailPrecondition;
232   }
233 
234   // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
235   unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
236       *srcForOp.getOperation(), *dstForOp.getOperation());
237 
238   // Compute union of computation slices computed between all pairs of ops
239   // from 'forOpA' and 'forOpB'.
240   if (failed(mlir::computeSliceUnion(opsA, opsB, dstLoopDepth, numCommonLoops,
241                                      isSrcForOpBeforeDstForOp, srcSlice))) {
242     LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
243     return FusionResult::FailPrecondition;
244   }
245 
246   return FusionResult::Success;
247 }
248 
249 /// Collect loop nest statistics (eg. loop trip count and operation count)
250 /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
251 /// returns false otherwise.
getLoopNestStats(AffineForOp forOpRoot,LoopNestStats * stats)252 bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
253   auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
254     auto *childForOp = forOp.getOperation();
255     auto *parentForOp = forOp.getParentOp();
256     if (!llvm::isa<FuncOp>(parentForOp)) {
257       if (!isa<AffineForOp>(parentForOp)) {
258         LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
259         return WalkResult::interrupt();
260       }
261       // Add mapping to 'forOp' from its parent AffineForOp.
262       stats->loopMap[parentForOp].push_back(forOp);
263     }
264 
265     // Record the number of op operations in the body of 'forOp'.
266     unsigned count = 0;
267     stats->opCountMap[childForOp] = 0;
268     for (auto &op : *forOp.getBody()) {
269       if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
270         ++count;
271     }
272     stats->opCountMap[childForOp] = count;
273 
274     // Record trip count for 'forOp'. Set flag if trip count is not
275     // constant.
276     Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
277     if (!maybeConstTripCount.hasValue()) {
278       // Currently only constant trip count loop nests are supported.
279       LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
280       return WalkResult::interrupt();
281     }
282 
283     stats->tripCountMap[childForOp] = maybeConstTripCount.getValue();
284     return WalkResult::advance();
285   });
286   return !walkResult.wasInterrupted();
287 }
288 
289 // Computes the total cost of the loop nest rooted at 'forOp'.
290 // Currently, the total cost is computed by counting the total operation
291 // instance count (i.e. total number of operations in the loop bodyloop
292 // operation count * loop trip count) for the entire loop nest.
293 // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
294 // specified in the map when computing the total op instance count.
295 // NOTEs: 1) This is used to compute the cost of computation slices, which are
296 // sliced along the iteration dimension, and thus reduce the trip count.
297 // If 'computeCostMap' is non-null, the total op count for forOps specified
298 // in the map is increased (not overridden) by adding the op count from the
299 // map to the existing op count for the for loop. This is done before
300 // multiplying by the loop's trip count, and is used to model the cost of
301 // inserting a sliced loop nest of known cost into the loop's body.
302 // 2) This is also used to compute the cost of fusing a slice of some loop nest
303 // within another loop.
getComputeCostHelper(Operation * forOp,LoopNestStats & stats,llvm::SmallDenseMap<Operation *,uint64_t,8> * tripCountOverrideMap,DenseMap<Operation *,int64_t> * computeCostMap)304 static int64_t getComputeCostHelper(
305     Operation *forOp, LoopNestStats &stats,
306     llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
307     DenseMap<Operation *, int64_t> *computeCostMap) {
308   // 'opCount' is the total number operations in one iteration of 'forOp' body,
309   // minus terminator op which is a no-op.
310   int64_t opCount = stats.opCountMap[forOp] - 1;
311   if (stats.loopMap.count(forOp) > 0) {
312     for (auto childForOp : stats.loopMap[forOp]) {
313       opCount += getComputeCostHelper(childForOp.getOperation(), stats,
314                                       tripCountOverrideMap, computeCostMap);
315     }
316   }
317   // Add in additional op instances from slice (if specified in map).
318   if (computeCostMap != nullptr) {
319     auto it = computeCostMap->find(forOp);
320     if (it != computeCostMap->end()) {
321       opCount += it->second;
322     }
323   }
324   // Override trip count (if specified in map).
325   int64_t tripCount = stats.tripCountMap[forOp];
326   if (tripCountOverrideMap != nullptr) {
327     auto it = tripCountOverrideMap->find(forOp);
328     if (it != tripCountOverrideMap->end()) {
329       tripCount = it->second;
330     }
331   }
332   // Returns the total number of dynamic instances of operations in loop body.
333   return tripCount * opCount;
334 }
335 
336 // TODO(andydavis,b/126426796): extend this to handle multiple result maps.
getConstDifference(AffineMap lbMap,AffineMap ubMap)337 static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
338   assert(lbMap.getNumResults() == 1 && "expected single result bound map");
339   assert(ubMap.getNumResults() == 1 && "expected single result bound map");
340   assert(lbMap.getNumDims() == ubMap.getNumDims());
341   assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
342   AffineExpr lbExpr(lbMap.getResult(0));
343   AffineExpr ubExpr(ubMap.getResult(0));
344   auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
345                                          lbMap.getNumSymbols());
346   auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
347   if (!cExpr)
348     return None;
349   return cExpr.getValue();
350 }
351 
352 // Return the number of iterations in the given slice.
getSliceIterationCount(const llvm::SmallDenseMap<Operation *,uint64_t,8> & sliceTripCountMap)353 static uint64_t getSliceIterationCount(
354     const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
355   uint64_t iterCount = 1;
356   for (const auto &count : sliceTripCountMap) {
357     iterCount *= count.second;
358   }
359   return iterCount;
360 }
361 
362 // Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
363 // nest surrounding represented by slice loop bounds in 'slice'.
364 // Returns true on success, false otherwise (if a non-constant trip count
365 // was encountered).
366 // TODO(andydavis) Make this work with non-unit step loops.
buildSliceTripCountMap(ComputationSliceState * slice,llvm::SmallDenseMap<Operation *,uint64_t,8> * tripCountMap)367 static bool buildSliceTripCountMap(
368     ComputationSliceState *slice,
369     llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
370   unsigned numSrcLoopIVs = slice->ivs.size();
371   // Populate map from AffineForOp -> trip count
372   for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
373     AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]);
374     auto *op = forOp.getOperation();
375     AffineMap lbMap = slice->lbs[i];
376     AffineMap ubMap = slice->ubs[i];
377     if (lbMap == AffineMap() || ubMap == AffineMap()) {
378       // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
379       if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
380         (*tripCountMap)[op] =
381             forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
382         continue;
383       }
384       Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
385       if (maybeConstTripCount.hasValue()) {
386         (*tripCountMap)[op] = maybeConstTripCount.getValue();
387         continue;
388       }
389       return false;
390     }
391     Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
392     // Slice bounds are created with a constant ub - lb difference.
393     if (!tripCount.hasValue())
394       return false;
395     (*tripCountMap)[op] = tripCount.getValue();
396   }
397   return true;
398 }
399 
400 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
401 /// Currently, the total cost is computed by counting the total operation
402 /// instance count (i.e. total number of operations in the loop body * loop
403 /// trip count) for the entire loop nest.
getComputeCost(AffineForOp forOp,LoopNestStats & stats)404 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
405   return getComputeCostHelper(forOp.getOperation(), stats,
406                               /*tripCountOverrideMap=*/nullptr,
407                               /*computeCostMap=*/nullptr);
408 }
409 
410 /// Computes and returns in 'computeCost', the total compute cost of fusing the
411 /// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
412 /// the total cost is computed by counting the total operation instance count
413 /// (i.e. total number of operations in the loop body * loop trip count) for
414 /// the entire loop nest.
getFusionComputeCost(AffineForOp srcForOp,LoopNestStats & srcStats,AffineForOp dstForOp,LoopNestStats & dstStats,ComputationSliceState * slice,int64_t * computeCost)415 bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
416                                 AffineForOp dstForOp, LoopNestStats &dstStats,
417                                 ComputationSliceState *slice,
418                                 int64_t *computeCost) {
419   llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
420   DenseMap<Operation *, int64_t> computeCostMap;
421 
422   // Build trip count map for computation slice.
423   if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
424     return false;
425   // Checks whether a store to load forwarding will happen.
426   int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
427   assert(sliceIterationCount > 0);
428   bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
429   auto *insertPointParent = slice->insertPoint->getParentOp();
430 
431   // The store and loads to this memref will disappear.
432   // TODO(andydavis) Add load coalescing to memref data flow opt pass.
433   if (storeLoadFwdGuaranteed) {
434     // Subtract from operation count the loads/store we expect load/store
435     // forwarding to remove.
436     unsigned storeCount = 0;
437     llvm::SmallDenseSet<Value, 4> storeMemrefs;
438     srcForOp.walk([&](Operation *op) {
439       if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
440         storeMemrefs.insert(storeOp.getMemRef());
441         ++storeCount;
442       }
443     });
444     // Subtract out any store ops in single-iteration src slice loop nest.
445     if (storeCount > 0)
446       computeCostMap[insertPointParent] = -storeCount;
447     // Subtract out any load users of 'storeMemrefs' nested below
448     // 'insertPointParent'.
449     for (auto value : storeMemrefs) {
450       for (auto *user : value.getUsers()) {
451         if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
452           SmallVector<AffineForOp, 4> loops;
453           // Check if any loop in loop nest surrounding 'user' is
454           // 'insertPointParent'.
455           getLoopIVs(*user, &loops);
456           if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
457             if (auto forOp =
458                     dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
459               if (computeCostMap.count(forOp) == 0)
460                 computeCostMap[forOp] = 0;
461               computeCostMap[forOp] -= 1;
462             }
463           }
464         }
465       }
466     }
467   }
468 
469   // Compute op instance count for the src loop nest with iteration slicing.
470   int64_t sliceComputeCost = getComputeCostHelper(
471       srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
472 
473   // Compute cost of fusion for this depth.
474   computeCostMap[insertPointParent] = sliceComputeCost;
475 
476   *computeCost =
477       getComputeCostHelper(dstForOp.getOperation(), dstStats,
478                            /*tripCountOverrideMap=*/nullptr, &computeCostMap);
479   return true;
480 }
481