1 //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion transformation utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/LoopFusionUtils.h"
14 
15 #include "mlir/Analysis/AffineAnalysis.h"
16 #include "mlir/Analysis/AffineStructures.h"
17 #include "mlir/Analysis/LoopAnalysis.h"
18 #include "mlir/Analysis/SliceAnalysis.h"
19 #include "mlir/Analysis/Utils.h"
20 #include "mlir/Dialect/Affine/IR/AffineOps.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/BlockAndValueMapping.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/Transforms/LoopUtils.h"
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 #define DEBUG_TYPE "loop-fusion-utils"
34 
35 using namespace mlir;
36 
37 // Gathers all load and store memref accesses in 'opA' into 'values', where
38 // 'values[memref] == true' for each store operation.
getLoadAndStoreMemRefAccesses(Operation * opA,DenseMap<Value,bool> & values)39 static void getLoadAndStoreMemRefAccesses(Operation *opA,
40                                           DenseMap<Value, bool> &values) {
41   opA->walk([&](Operation *op) {
42     if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
43       if (values.count(loadOp.getMemRef()) == 0)
44         values[loadOp.getMemRef()] = false;
45     } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
46       values[storeOp.getMemRef()] = true;
47     }
48   });
49 }
50 
51 /// Returns true if 'op' is a load or store operation which access a memref
52 /// accessed 'values' and at least one of the access is a store operation.
53 /// Returns false otherwise.
isDependentLoadOrStoreOp(Operation * op,DenseMap<Value,bool> & values)54 static bool isDependentLoadOrStoreOp(Operation *op,
55                                      DenseMap<Value, bool> &values) {
56   if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
57     return values.count(loadOp.getMemRef()) > 0 &&
58            values[loadOp.getMemRef()] == true;
59   } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
60     return values.count(storeOp.getMemRef()) > 0;
61   }
62   return false;
63 }
64 
65 // Returns the first operation in range ('opA', 'opB') which has a data
66 // dependence on 'opA'. Returns 'nullptr' of no dependence exists.
getFirstDependentOpInRange(Operation * opA,Operation * opB)67 static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
68   // Record memref values from all loads/store in loop nest rooted at 'opA'.
69   // Map from memref value to bool which is true if store, false otherwise.
70   DenseMap<Value, bool> values;
71   getLoadAndStoreMemRefAccesses(opA, values);
72 
73   // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
74   // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
75   // and at least one of the accesses is a store).
76   Operation *firstDepOp = nullptr;
77   for (Block::iterator it = std::next(Block::iterator(opA));
78        it != Block::iterator(opB); ++it) {
79     Operation *opX = &(*it);
80     opX->walk([&](Operation *op) {
81       if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
82         firstDepOp = opX;
83     });
84     if (firstDepOp)
85       break;
86   }
87   return firstDepOp;
88 }
89 
90 // Returns the last operation 'opX' in range ('opA', 'opB'), for which there
91 // exists a data dependence from 'opX' to 'opB'.
92 // Returns 'nullptr' of no dependence exists.
getLastDependentOpInRange(Operation * opA,Operation * opB)93 static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
94   // Record memref values from all loads/store in loop nest rooted at 'opB'.
95   // Map from memref value to bool which is true if store, false otherwise.
96   DenseMap<Value, bool> values;
97   getLoadAndStoreMemRefAccesses(opB, values);
98 
99   // For each 'opX' in block in range ('opA', 'opB') in reverse order,
100   // check if there is a data dependence from 'opX' to 'opB':
101   // *) 'opX' and 'opB' access the same memref and at least one of the accesses
102   //    is a store.
103   // *) 'opX' produces an SSA Value which is used by 'opB'.
104   Operation *lastDepOp = nullptr;
105   for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
106        it != Block::reverse_iterator(opA); ++it) {
107     Operation *opX = &(*it);
108     opX->walk([&](Operation *op) {
109       if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
110         if (isDependentLoadOrStoreOp(op, values)) {
111           lastDepOp = opX;
112           return WalkResult::interrupt();
113         }
114         return WalkResult::advance();
115       }
116       for (auto value : op->getResults()) {
117         for (Operation *user : value.getUsers()) {
118           SmallVector<AffineForOp, 4> loops;
119           // Check if any loop in loop nest surrounding 'user' is 'opB'.
120           getLoopIVs(*user, &loops);
121           if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
122             lastDepOp = opX;
123             return WalkResult::interrupt();
124           }
125         }
126       }
127       return WalkResult::advance();
128     });
129     if (lastDepOp)
130       break;
131   }
132   return lastDepOp;
133 }
134 
135 // Computes and returns an insertion point operation, before which the
136 // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving
137 // dependences. Returns nullptr if no such insertion point is found.
getFusedLoopNestInsertionPoint(AffineForOp srcForOp,AffineForOp dstForOp)138 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
139                                                  AffineForOp dstForOp) {
140   bool isSrcForOpBeforeDstForOp =
141       srcForOp->isBeforeInBlock(dstForOp.getOperation());
142   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
143   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
144 
145   auto *firstDepOpA =
146       getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
147   auto *lastDepOpB =
148       getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
149   // Block:
150   //      ...
151   //  |-- opA
152   //  |   ...
153   //  |   lastDepOpB --|
154   //  |   ...          |
155   //  |-> firstDepOpA  |
156   //      ...          |
157   //      opB <---------
158   //
159   // Valid insertion point range: (lastDepOpB, firstDepOpA)
160   //
161   if (firstDepOpA != nullptr) {
162     if (lastDepOpB != nullptr) {
163       if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
164         // No valid insertion point exists which preserves dependences.
165         return nullptr;
166     }
167     // Return insertion point in valid range closest to 'opB'.
168     // TODO: Consider other insertion points in valid range.
169     return firstDepOpA;
170   }
171   // No dependences from 'opA' to operation in range ('opA', 'opB'), return
172   // 'opB' insertion point.
173   return forOpB.getOperation();
174 }
175 
176 // Gathers all load and store ops in loop nest rooted at 'forOp' into
177 // 'loadAndStoreOps'.
178 static bool
gatherLoadsAndStores(AffineForOp forOp,SmallVectorImpl<Operation * > & loadAndStoreOps)179 gatherLoadsAndStores(AffineForOp forOp,
180                      SmallVectorImpl<Operation *> &loadAndStoreOps) {
181   bool hasIfOp = false;
182   forOp.walk([&](Operation *op) {
183     if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
184       loadAndStoreOps.push_back(op);
185     else if (isa<AffineIfOp>(op))
186       hasIfOp = true;
187   });
188   return !hasIfOp;
189 }
190 
191 /// Returns the maximum loop depth at which we could fuse producer loop
192 /// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
193 // TODO: Generalize this check for sibling and more generic fusion scenarios.
194 // TODO: Support forward slice fusion.
getMaxLoopDepth(ArrayRef<Operation * > srcOps,ArrayRef<Operation * > dstOps)195 static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
196                                 ArrayRef<Operation *> dstOps) {
197   if (dstOps.empty())
198     // Expected at least one memory operation.
199     // TODO: Revisit this case with a specific example.
200     return 0;
201 
202   // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
203   // that they are not considered for analysis.
204   DenseSet<Value> producerConsumerMemrefs;
205   gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
206   SmallVector<Operation *, 4> targetDstOps;
207   for (Operation *dstOp : dstOps) {
208     auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
209     Value memref = loadOp ? loadOp.getMemRef()
210                           : cast<AffineWriteOpInterface>(dstOp).getMemRef();
211     if (producerConsumerMemrefs.count(memref) > 0)
212       targetDstOps.push_back(dstOp);
213   }
214 
215   assert(!targetDstOps.empty() &&
216          "No dependences between 'srcForOp' and 'dstForOp'?");
217 
218   // Compute the innermost common loop depth for loads and stores.
219   unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);
220 
221   // Return common loop depth for loads if there are no store ops.
222   if (all_of(targetDstOps,
223              [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
224     return loopDepth;
225 
226   // Check dependences on all pairs of ops in 'targetDstOps' and store the
227   // minimum loop depth at which a dependence is satisfied.
228   for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
229     auto *srcOpInst = targetDstOps[i];
230     MemRefAccess srcAccess(srcOpInst);
231     for (unsigned j = 0; j < e; ++j) {
232       auto *dstOpInst = targetDstOps[j];
233       MemRefAccess dstAccess(dstOpInst);
234 
235       unsigned numCommonLoops =
236           getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
237       for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
238         FlatAffineConstraints dependenceConstraints;
239         // TODO: Cache dependence analysis results, check cache here.
240         DependenceResult result = checkMemrefAccessDependence(
241             srcAccess, dstAccess, d, &dependenceConstraints,
242             /*dependenceComponents=*/nullptr);
243         if (hasDependence(result)) {
244           // Store minimum loop depth and break because we want the min 'd' at
245           // which there is a dependence.
246           loopDepth = std::min(loopDepth, d - 1);
247           break;
248         }
249       }
250     }
251   }
252 
253   return loopDepth;
254 }
255 
256 // TODO: Prevent fusion of loop nests with side-effecting operations.
257 // TODO: This pass performs some computation that is the same for all the depths
258 // (e.g., getMaxLoopDepth). Implement a version of this utility that processes
259 // all the depths at once or only the legal maximal depth for maximal fusion.
canFuseLoops(AffineForOp srcForOp,AffineForOp dstForOp,unsigned dstLoopDepth,ComputationSliceState * srcSlice,FusionStrategy fusionStrategy)260 FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
261                                 unsigned dstLoopDepth,
262                                 ComputationSliceState *srcSlice,
263                                 FusionStrategy fusionStrategy) {
264   // Return 'failure' if 'dstLoopDepth == 0'.
265   if (dstLoopDepth == 0) {
266     LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
267     return FusionResult::FailPrecondition;
268   }
269   // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
270   auto *block = srcForOp->getBlock();
271   if (block != dstForOp->getBlock()) {
272     LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
273     return FusionResult::FailPrecondition;
274   }
275 
276   // Return 'failure' if no valid insertion point for fused loop nest in 'block'
277   // exists which would preserve dependences.
278   if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
279     LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
280     return FusionResult::FailBlockDependence;
281   }
282 
283   // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
284   bool isSrcForOpBeforeDstForOp =
285       srcForOp->isBeforeInBlock(dstForOp.getOperation());
286   // 'forOpA' executes before 'forOpB' in 'block'.
287   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
288   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
289 
290   // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
291   SmallVector<Operation *, 4> opsA;
292   if (!gatherLoadsAndStores(forOpA, opsA)) {
293     LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
294     return FusionResult::FailPrecondition;
295   }
296 
297   // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
298   SmallVector<Operation *, 4> opsB;
299   if (!gatherLoadsAndStores(forOpB, opsB)) {
300     LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
301     return FusionResult::FailPrecondition;
302   }
303 
304   // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
305   // loop dependences.
306   // TODO: Enable this check for sibling and more generic loop fusion
307   // strategies.
308   if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
309     // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
310     assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
311     if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
312       LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
313       return FusionResult::FailFusionDependence;
314     }
315   }
316 
317   // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
318   unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
319       *srcForOp.getOperation(), *dstForOp.getOperation());
320 
321   // Filter out ops in 'opsA' to compute the slice union based on the
322   // assumptions made by the fusion strategy.
323   SmallVector<Operation *, 4> strategyOpsA;
324   switch (fusionStrategy.getStrategy()) {
325   case FusionStrategy::Generic:
326     // Generic fusion. Take into account all the memory operations to compute
327     // the slice union.
328     strategyOpsA.append(opsA.begin(), opsA.end());
329     break;
330   case FusionStrategy::ProducerConsumer:
331     // Producer-consumer fusion (AffineLoopFusion pass) only takes into
332     // account stores in 'srcForOp' to compute the slice union.
333     for (Operation *op : opsA) {
334       if (isa<AffineWriteOpInterface>(op))
335         strategyOpsA.push_back(op);
336     }
337     break;
338   case FusionStrategy::Sibling:
339     // Sibling fusion (AffineLoopFusion pass) only takes into account the loads
340     // to 'memref' in 'srcForOp' to compute the slice union.
341     for (Operation *op : opsA) {
342       auto load = dyn_cast<AffineReadOpInterface>(op);
343       if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
344         strategyOpsA.push_back(op);
345     }
346     break;
347   }
348 
349   // Compute union of computation slices computed between all pairs of ops
350   // from 'forOpA' and 'forOpB'.
351   SliceComputationResult sliceComputationResult =
352       mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
353                               isSrcForOpBeforeDstForOp, srcSlice);
354   if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
355     LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
356     return FusionResult::FailPrecondition;
357   }
358   if (sliceComputationResult.value ==
359       SliceComputationResult::IncorrectSliceFailure) {
360     LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
361     return FusionResult::FailIncorrectSlice;
362   }
363 
364   return FusionResult::Success;
365 }
366 
367 /// Patch the loop body of a forOp that is a single iteration reduction loop
368 /// into its containing block.
promoteSingleIterReductionLoop(AffineForOp forOp,bool siblingFusionUser)369 LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
370                                              bool siblingFusionUser) {
371   // Check if the reduction loop is a single iteration loop.
372   Optional<uint64_t> tripCount = getConstantTripCount(forOp);
373   if (!tripCount || tripCount.getValue() != 1)
374     return failure();
375   auto iterOperands = forOp.getIterOperands();
376   auto *parentOp = forOp->getParentOp();
377   if (!isa<AffineForOp>(parentOp))
378     return failure();
379   auto newOperands = forOp.getBody()->getTerminator()->getOperands();
380   OpBuilder b(parentOp);
381   // Replace the parent loop and add iteroperands and results from the `forOp`.
382   AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
383   AffineForOp newLoop = replaceForOpWithNewYields(
384       b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs());
385 
386   // For sibling-fusion users, collect operations that use the results of the
387   // `forOp` outside the new parent loop that has absorbed all its iter args
388   // and operands. These operations will be moved later after the results
389   // have been replaced.
390   SetVector<Operation *> forwardSlice;
391   if (siblingFusionUser) {
392     for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
393       SetVector<Operation *> tmpForwardSlice;
394       getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
395       forwardSlice.set_union(tmpForwardSlice);
396     }
397   }
398   // Update the results of the `forOp` in the new loop.
399   for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
400     forOp.getResult(i).replaceAllUsesWith(
401         newLoop.getResult(i + parentOp->getNumResults()));
402   }
403   // For sibling-fusion users, move operations that use the results of the
404   // `forOp` outside the new parent loop
405   if (siblingFusionUser) {
406     topologicalSort(forwardSlice);
407     for (Operation *op : llvm::reverse(forwardSlice))
408       op->moveAfter(newLoop);
409   }
410   // Replace the induction variable.
411   auto iv = forOp.getInductionVar();
412   iv.replaceAllUsesWith(newLoop.getInductionVar());
413   // Replace the iter args.
414   auto forOpIterArgs = forOp.getRegionIterArgs();
415   for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
416                                               forOpIterArgs.size()))) {
417     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
418   }
419   // Move the loop body operations, except for its terminator, to the loop's
420   // containing block.
421   forOp.getBody()->back().erase();
422   auto *parentBlock = forOp->getBlock();
423   parentBlock->getOperations().splice(Block::iterator(forOp),
424                                       forOp.getBody()->getOperations());
425   forOp.erase();
426   parentForOp.erase();
427   return success();
428 }
429 
430 /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
431 /// and source slice loop bounds specified in 'srcSlice'.
fuseLoops(AffineForOp srcForOp,AffineForOp dstForOp,const ComputationSliceState & srcSlice,bool isInnermostSiblingInsertion)432 void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
433                      const ComputationSliceState &srcSlice,
434                      bool isInnermostSiblingInsertion) {
435   // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
436   OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
437   BlockAndValueMapping mapper;
438   b.clone(*srcForOp, mapper);
439 
440   // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'.
441   SmallVector<AffineForOp, 4> sliceLoops;
442   for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
443     auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
444     if (!loopIV)
445       continue;
446     auto forOp = getForInductionVarOwner(loopIV);
447     sliceLoops.push_back(forOp);
448     if (AffineMap lbMap = srcSlice.lbs[i]) {
449       auto lbOperands = srcSlice.lbOperands[i];
450       canonicalizeMapAndOperands(&lbMap, &lbOperands);
451       forOp.setLowerBound(lbOperands, lbMap);
452     }
453     if (AffineMap ubMap = srcSlice.ubs[i]) {
454       auto ubOperands = srcSlice.ubOperands[i];
455       canonicalizeMapAndOperands(&ubMap, &ubOperands);
456       forOp.setUpperBound(ubOperands, ubMap);
457     }
458   }
459 
460   llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
461   auto srcIsUnitSlice = [&]() {
462     return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
463             (getSliceIterationCount(sliceTripCountMap) == 1));
464   };
465   // Fix up and if possible, eliminate single iteration loops.
466   for (AffineForOp forOp : sliceLoops) {
467     if (isLoopParallelAndContainsReduction(forOp) &&
468         isInnermostSiblingInsertion && srcIsUnitSlice())
469       // Patch reduction loop - only ones that are sibling-fused with the
470       // destination loop - into the parent loop.
471       (void)promoteSingleIterReductionLoop(forOp, true);
472     else
473       // Promote any single iteration slice loops.
474       (void)promoteIfSingleIteration(forOp);
475   }
476 }
477 
478 /// Collect loop nest statistics (eg. loop trip count and operation count)
479 /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
480 /// returns false otherwise.
getLoopNestStats(AffineForOp forOpRoot,LoopNestStats * stats)481 bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
482   auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
483     auto *childForOp = forOp.getOperation();
484     auto *parentForOp = forOp->getParentOp();
485     if (!llvm::isa<FuncOp>(parentForOp)) {
486       if (!isa<AffineForOp>(parentForOp)) {
487         LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
488         return WalkResult::interrupt();
489       }
490       // Add mapping to 'forOp' from its parent AffineForOp.
491       stats->loopMap[parentForOp].push_back(forOp);
492     }
493 
494     // Record the number of op operations in the body of 'forOp'.
495     unsigned count = 0;
496     stats->opCountMap[childForOp] = 0;
497     for (auto &op : *forOp.getBody()) {
498       if (!isa<AffineForOp, AffineIfOp>(op))
499         ++count;
500     }
501     stats->opCountMap[childForOp] = count;
502 
503     // Record trip count for 'forOp'. Set flag if trip count is not
504     // constant.
505     Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
506     if (!maybeConstTripCount.hasValue()) {
507       // Currently only constant trip count loop nests are supported.
508       LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
509       return WalkResult::interrupt();
510     }
511 
512     stats->tripCountMap[childForOp] = maybeConstTripCount.getValue();
513     return WalkResult::advance();
514   });
515   return !walkResult.wasInterrupted();
516 }
517 
518 // Computes the total cost of the loop nest rooted at 'forOp'.
519 // Currently, the total cost is computed by counting the total operation
520 // instance count (i.e. total number of operations in the loop bodyloop
521 // operation count * loop trip count) for the entire loop nest.
522 // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
523 // specified in the map when computing the total op instance count.
524 // NOTEs: 1) This is used to compute the cost of computation slices, which are
525 // sliced along the iteration dimension, and thus reduce the trip count.
526 // If 'computeCostMap' is non-null, the total op count for forOps specified
527 // in the map is increased (not overridden) by adding the op count from the
528 // map to the existing op count for the for loop. This is done before
529 // multiplying by the loop's trip count, and is used to model the cost of
530 // inserting a sliced loop nest of known cost into the loop's body.
531 // 2) This is also used to compute the cost of fusing a slice of some loop nest
532 // within another loop.
getComputeCostHelper(Operation * forOp,LoopNestStats & stats,llvm::SmallDenseMap<Operation *,uint64_t,8> * tripCountOverrideMap,DenseMap<Operation *,int64_t> * computeCostMap)533 static int64_t getComputeCostHelper(
534     Operation *forOp, LoopNestStats &stats,
535     llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
536     DenseMap<Operation *, int64_t> *computeCostMap) {
537   // 'opCount' is the total number operations in one iteration of 'forOp' body,
538   // minus terminator op which is a no-op.
539   int64_t opCount = stats.opCountMap[forOp] - 1;
540   if (stats.loopMap.count(forOp) > 0) {
541     for (auto childForOp : stats.loopMap[forOp]) {
542       opCount += getComputeCostHelper(childForOp.getOperation(), stats,
543                                       tripCountOverrideMap, computeCostMap);
544     }
545   }
546   // Add in additional op instances from slice (if specified in map).
547   if (computeCostMap != nullptr) {
548     auto it = computeCostMap->find(forOp);
549     if (it != computeCostMap->end()) {
550       opCount += it->second;
551     }
552   }
553   // Override trip count (if specified in map).
554   int64_t tripCount = stats.tripCountMap[forOp];
555   if (tripCountOverrideMap != nullptr) {
556     auto it = tripCountOverrideMap->find(forOp);
557     if (it != tripCountOverrideMap->end()) {
558       tripCount = it->second;
559     }
560   }
561   // Returns the total number of dynamic instances of operations in loop body.
562   return tripCount * opCount;
563 }
564 
565 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
566 /// Currently, the total cost is computed by counting the total operation
567 /// instance count (i.e. total number of operations in the loop body * loop
568 /// trip count) for the entire loop nest.
getComputeCost(AffineForOp forOp,LoopNestStats & stats)569 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
570   return getComputeCostHelper(forOp.getOperation(), stats,
571                               /*tripCountOverrideMap=*/nullptr,
572                               /*computeCostMap=*/nullptr);
573 }
574 
575 /// Computes and returns in 'computeCost', the total compute cost of fusing the
576 /// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
577 /// the total cost is computed by counting the total operation instance count
578 /// (i.e. total number of operations in the loop body * loop trip count) for
579 /// the entire loop nest.
getFusionComputeCost(AffineForOp srcForOp,LoopNestStats & srcStats,AffineForOp dstForOp,LoopNestStats & dstStats,const ComputationSliceState & slice,int64_t * computeCost)580 bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
581                                 AffineForOp dstForOp, LoopNestStats &dstStats,
582                                 const ComputationSliceState &slice,
583                                 int64_t *computeCost) {
584   llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
585   DenseMap<Operation *, int64_t> computeCostMap;
586 
587   // Build trip count map for computation slice.
588   if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
589     return false;
590   // Checks whether a store to load forwarding will happen.
591   int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
592   assert(sliceIterationCount > 0);
593   bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
594   auto *insertPointParent = slice.insertPoint->getParentOp();
595 
596   // The store and loads to this memref will disappear.
597   // TODO: Add load coalescing to memref data flow opt pass.
598   if (storeLoadFwdGuaranteed) {
599     // Subtract from operation count the loads/store we expect load/store
600     // forwarding to remove.
601     unsigned storeCount = 0;
602     llvm::SmallDenseSet<Value, 4> storeMemrefs;
603     srcForOp.walk([&](Operation *op) {
604       if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
605         storeMemrefs.insert(storeOp.getMemRef());
606         ++storeCount;
607       }
608     });
609     // Subtract out any store ops in single-iteration src slice loop nest.
610     if (storeCount > 0)
611       computeCostMap[insertPointParent] = -storeCount;
612     // Subtract out any load users of 'storeMemrefs' nested below
613     // 'insertPointParent'.
614     for (auto value : storeMemrefs) {
615       for (auto *user : value.getUsers()) {
616         if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
617           SmallVector<AffineForOp, 4> loops;
618           // Check if any loop in loop nest surrounding 'user' is
619           // 'insertPointParent'.
620           getLoopIVs(*user, &loops);
621           if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
622             if (auto forOp =
623                     dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
624               if (computeCostMap.count(forOp) == 0)
625                 computeCostMap[forOp] = 0;
626               computeCostMap[forOp] -= 1;
627             }
628           }
629         }
630       }
631     }
632   }
633 
634   // Compute op instance count for the src loop nest with iteration slicing.
635   int64_t sliceComputeCost = getComputeCostHelper(
636       srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
637 
638   // Compute cost of fusion for this depth.
639   computeCostMap[insertPointParent] = sliceComputeCost;
640 
641   *computeCost =
642       getComputeCostHelper(dstForOp.getOperation(), dstStats,
643                            /*tripCountOverrideMap=*/nullptr, &computeCostMap);
644   return true;
645 }
646 
647 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
648 /// producer-consumer dependence between write ops in 'srcOps' and read ops in
649 /// 'dstOps'.
gatherProducerConsumerMemrefs(ArrayRef<Operation * > srcOps,ArrayRef<Operation * > dstOps,DenseSet<Value> & producerConsumerMemrefs)650 void mlir::gatherProducerConsumerMemrefs(
651     ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
652     DenseSet<Value> &producerConsumerMemrefs) {
653   // Gather memrefs from stores in 'srcOps'.
654   DenseSet<Value> srcStoreMemRefs;
655   for (Operation *op : srcOps)
656     if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
657       srcStoreMemRefs.insert(storeOp.getMemRef());
658 
659   // Compute the intersection between memrefs from stores in 'srcOps' and
660   // memrefs from loads in 'dstOps'.
661   for (Operation *op : dstOps)
662     if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
663       if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
664         producerConsumerMemrefs.insert(loadOp.getMemRef());
665 }
666