1 //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion transformation utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Transforms/LoopFusionUtils.h"
14
15 #include "mlir/Analysis/AffineAnalysis.h"
16 #include "mlir/Analysis/AffineStructures.h"
17 #include "mlir/Analysis/LoopAnalysis.h"
18 #include "mlir/Analysis/Utils.h"
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/Operation.h"
26 #include "mlir/Transforms/LoopUtils.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/raw_ostream.h"
31
32 #define DEBUG_TYPE "loop-fusion-utils"
33
34 using namespace mlir;
35
36 // Gathers all load and store memref accesses in 'opA' into 'values', where
37 // 'values[memref] == true' for each store operation.
getLoadAndStoreMemRefAccesses(Operation * opA,DenseMap<Value,bool> & values)38 static void getLoadAndStoreMemRefAccesses(Operation *opA,
39 DenseMap<Value, bool> &values) {
40 opA->walk([&](Operation *op) {
41 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
42 if (values.count(loadOp.getMemRef()) == 0)
43 values[loadOp.getMemRef()] = false;
44 } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
45 values[storeOp.getMemRef()] = true;
46 }
47 });
48 }
49
50 /// Returns true if 'op' is a load or store operation which access a memref
51 /// accessed 'values' and at least one of the access is a store operation.
52 /// Returns false otherwise.
isDependentLoadOrStoreOp(Operation * op,DenseMap<Value,bool> & values)53 static bool isDependentLoadOrStoreOp(Operation *op,
54 DenseMap<Value, bool> &values) {
55 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
56 return values.count(loadOp.getMemRef()) > 0 &&
57 values[loadOp.getMemRef()] == true;
58 } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
59 return values.count(storeOp.getMemRef()) > 0;
60 }
61 return false;
62 }
63
64 // Returns the first operation in range ('opA', 'opB') which has a data
65 // dependence on 'opA'. Returns 'nullptr' of no dependence exists.
getFirstDependentOpInRange(Operation * opA,Operation * opB)66 static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
67 // Record memref values from all loads/store in loop nest rooted at 'opA'.
68 // Map from memref value to bool which is true if store, false otherwise.
69 DenseMap<Value, bool> values;
70 getLoadAndStoreMemRefAccesses(opA, values);
71
72 // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
73 // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
74 // and at least one of the accesses is a store).
75 Operation *firstDepOp = nullptr;
76 for (Block::iterator it = std::next(Block::iterator(opA));
77 it != Block::iterator(opB); ++it) {
78 Operation *opX = &(*it);
79 opX->walk([&](Operation *op) {
80 if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
81 firstDepOp = opX;
82 });
83 if (firstDepOp)
84 break;
85 }
86 return firstDepOp;
87 }
88
// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
// exists a data dependence from 'opX' to 'opB'.
// Returns 'nullptr' if no dependence exists.
static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/store in loop nest rooted at 'opB'.
  // Map from memref value to bool which is true if store, false otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opB, values);

  // For each 'opX' in block in range ('opA', 'opB') in reverse order,
  // check if there is a data dependence from 'opX' to 'opB':
  // *) 'opX' and 'opB' access the same memref and at least one of the accesses
  //    is a store.
  // *) 'opX' produces an SSA Value which is used by 'opB'.
  Operation *lastDepOp = nullptr;
  for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
       it != Block::reverse_iterator(opA); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      // Case 1: memory dependence -- a load/store nested in 'opX' conflicts
      // with an access recorded from 'opB'.
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        if (isDependentLoadOrStoreOp(op, values)) {
          lastDepOp = opX;
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      }
      // Case 2: SSA dependence -- a result produced inside 'opX' is used by
      // an operation nested within the loop nest rooted at 'opB'.
      for (auto value : op->getResults()) {
        for (Operation *user : value.getUsers()) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in loop nest surrounding 'user' is 'opB'.
          getLoopIVs(*user, &loops);
          if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
            lastDepOp = opX;
            return WalkResult::interrupt();
          }
        }
      }
      return WalkResult::advance();
    });
    // Reverse iteration guarantees the first match is the *last* dependent op.
    if (lastDepOp)
      break;
  }
  return lastDepOp;
}
133
134 // Computes and returns an insertion point operation, before which the
135 // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving
136 // dependences. Returns nullptr if no such insertion point is found.
getFusedLoopNestInsertionPoint(AffineForOp srcForOp,AffineForOp dstForOp)137 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
138 AffineForOp dstForOp) {
139 bool isSrcForOpBeforeDstForOp =
140 srcForOp->isBeforeInBlock(dstForOp.getOperation());
141 auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
142 auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
143
144 auto *firstDepOpA =
145 getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
146 auto *lastDepOpB =
147 getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
148 // Block:
149 // ...
150 // |-- opA
151 // | ...
152 // | lastDepOpB --|
153 // | ... |
154 // |-> firstDepOpA |
155 // ... |
156 // opB <---------
157 //
158 // Valid insertion point range: (lastDepOpB, firstDepOpA)
159 //
160 if (firstDepOpA != nullptr) {
161 if (lastDepOpB != nullptr) {
162 if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
163 // No valid insertion point exists which preserves dependences.
164 return nullptr;
165 }
166 // Return insertion point in valid range closest to 'opB'.
167 // TODO: Consider other insertion points in valid range.
168 return firstDepOpA;
169 }
170 // No dependences from 'opA' to operation in range ('opA', 'opB'), return
171 // 'opB' insertion point.
172 return forOpB.getOperation();
173 }
174
175 // Gathers all load and store ops in loop nest rooted at 'forOp' into
176 // 'loadAndStoreOps'.
177 static bool
gatherLoadsAndStores(AffineForOp forOp,SmallVectorImpl<Operation * > & loadAndStoreOps)178 gatherLoadsAndStores(AffineForOp forOp,
179 SmallVectorImpl<Operation *> &loadAndStoreOps) {
180 bool hasIfOp = false;
181 forOp.walk([&](Operation *op) {
182 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
183 loadAndStoreOps.push_back(op);
184 else if (isa<AffineIfOp>(op))
185 hasIfOp = true;
186 });
187 return !hasIfOp;
188 }
189
/// Returns the maximum loop depth at which we could fuse producer loop
/// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
/// 'srcOps'/'dstOps' are the load/store ops gathered from the respective
/// loop nests.
// TODO: Generalize this check for sibling and more generic fusion scenarios.
// TODO: Support forward slice fusion.
static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
                                ArrayRef<Operation *> dstOps) {
  if (dstOps.empty())
    // Expected at least one memory operation.
    // TODO: Revisit this case with a specific example.
    return 0;

  // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
  // that they are not considered for analysis.
  DenseSet<Value> producerConsumerMemrefs;
  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
  SmallVector<Operation *, 4> targetDstOps;
  for (Operation *dstOp : dstOps) {
    auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
    // Every op in 'dstOps' is either a load or a store (the cast asserts it).
    Value memref = loadOp ? loadOp.getMemRef()
                          : cast<AffineWriteOpInterface>(dstOp).getMemRef();
    if (producerConsumerMemrefs.count(memref) > 0)
      targetDstOps.push_back(dstOp);
  }

  assert(!targetDstOps.empty() &&
         "No dependences between 'srcForOp' and 'dstForOp'?");

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

  // Return common loop depth for loads if there are no store ops.
  if (all_of(targetDstOps,
             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
    return loopDepth;

  // Check dependences on all pairs of ops in 'targetDstOps' (including an op
  // paired with itself) and store the minimum loop depth at which a
  // dependence is satisfied.
  for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
    auto *srcOpInst = targetDstOps[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = targetDstOps[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      // Test each candidate depth, from outermost (1) inward.
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        FlatAffineConstraints dependenceConstraints;
        // TODO: Cache dependence analysis results, check cache here.
        DependenceResult result = checkMemrefAccessDependence(
            srcAccess, dstAccess, d, &dependenceConstraints,
            /*dependenceComponents=*/nullptr);
        if (hasDependence(result)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }

  return loopDepth;
}
254
// TODO: Prevent fusion of loop nests with side-effecting operations.
// TODO: This pass performs some computation that is the same for all the depths
// (e.g., getMaxLoopDepth). Implement a version of this utility that processes
// all the depths at once or only the legal maximal depth for maximal fusion.
/// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into
/// the loop nest rooted at 'dstForOp' at 'dstLoopDepth', according to
/// 'fusionStrategy'. On success, computes the slice of 'srcForOp' to insert
/// into 'dstForOp' and returns it in 'srcSlice'.
FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                                unsigned dstLoopDepth,
                                ComputationSliceState *srcSlice,
                                FusionStrategy fusionStrategy) {
  // Return 'failure' if 'dstLoopDepth == 0'.
  if (dstLoopDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
    return FusionResult::FailPrecondition;
  }
  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
  auto *block = srcForOp->getBlock();
  if (block != dstForOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if no valid insertion point for fused loop nest in 'block'
  // exists which would preserve dependences.
  if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
    return FusionResult::FailBlockDependence;
  }

  // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
  bool isSrcForOpBeforeDstForOp =
      srcForOp->isBeforeInBlock(dstForOp.getOperation());
  // 'forOpA' executes before 'forOpB' in 'block'.
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
  SmallVector<Operation *, 4> opsA;
  if (!gatherLoadsAndStores(forOpA, opsA)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
  SmallVector<Operation *, 4> opsB;
  if (!gatherLoadsAndStores(forOpB, opsB)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
  // loop dependences.
  // TODO: Enable this check for sibling and more generic loop fusion
  // strategies.
  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
    // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
    assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
      return FusionResult::FailFusionDependence;
    }
  }

  // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
  unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
      *srcForOp.getOperation(), *dstForOp.getOperation());

  // Filter out ops in 'opsA' to compute the slice union based on the
  // assumptions made by the fusion strategy.
  SmallVector<Operation *, 4> strategyOpsA;
  switch (fusionStrategy.getStrategy()) {
  case FusionStrategy::Generic:
    // Generic fusion. Take into account all the memory operations to compute
    // the slice union.
    strategyOpsA.append(opsA.begin(), opsA.end());
    break;
  case FusionStrategy::ProducerConsumer:
    // Producer-consumer fusion (AffineLoopFusion pass) only takes into
    // account stores in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      if (isa<AffineWriteOpInterface>(op))
        strategyOpsA.push_back(op);
    }
    break;
  case FusionStrategy::Sibling:
    // Sibling fusion (AffineLoopFusion pass) only takes into account the loads
    // to 'memref' in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      auto load = dyn_cast<AffineReadOpInterface>(op);
      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
        strategyOpsA.push_back(op);
    }
    break;
  }

  // Compute union of computation slices computed between all pairs of ops
  // from 'forOpA' and 'forOpB'.
  if (failed(mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth,
                                     numCommonLoops, isSrcForOpBeforeDstForOp,
                                     srcSlice))) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
    return FusionResult::FailPrecondition;
  }

  return FusionResult::Success;
}
359
/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
/// and source slice loop bounds specified in 'srcSlice'.
void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                     const ComputationSliceState &srcSlice) {
  // Clone 'srcForOp' into 'dstForOp' at 'srcSlice.insertPoint'.
  OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
  BlockAndValueMapping mapper;
  b.clone(*srcForOp, mapper);

  // Update the cloned loops' upper and lower bounds from computed 'srcSlice'.
  SmallVector<AffineForOp, 4> sliceLoops;
  for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
    auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
    // An IV without a mapping was not cloned (e.g. it lies outside the
    // cloned region); skip it.
    if (!loopIV)
      continue;
    auto forOp = getForInductionVarOwner(loopIV);
    sliceLoops.push_back(forOp);
    // A null (default-constructed) map means the bound was not sliced; keep
    // the cloned loop's original bound in that case.
    if (AffineMap lbMap = srcSlice.lbs[i]) {
      auto lbOperands = srcSlice.lbOperands[i];
      canonicalizeMapAndOperands(&lbMap, &lbOperands);
      forOp.setLowerBound(lbOperands, lbMap);
    }
    if (AffineMap ubMap = srcSlice.ubs[i]) {
      auto ubOperands = srcSlice.ubOperands[i];
      canonicalizeMapAndOperands(&ubMap, &ubOperands);
      forOp.setUpperBound(ubOperands, ubMap);
    }
  }

  // Promote any single iteration slice loops.
  for (AffineForOp forOp : sliceLoops)
    promoteIfSingleIteration(forOp);
}
393
394 /// Collect loop nest statistics (eg. loop trip count and operation count)
395 /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
396 /// returns false otherwise.
getLoopNestStats(AffineForOp forOpRoot,LoopNestStats * stats)397 bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
398 auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
399 auto *childForOp = forOp.getOperation();
400 auto *parentForOp = forOp->getParentOp();
401 if (!llvm::isa<FuncOp>(parentForOp)) {
402 if (!isa<AffineForOp>(parentForOp)) {
403 LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
404 return WalkResult::interrupt();
405 }
406 // Add mapping to 'forOp' from its parent AffineForOp.
407 stats->loopMap[parentForOp].push_back(forOp);
408 }
409
410 // Record the number of op operations in the body of 'forOp'.
411 unsigned count = 0;
412 stats->opCountMap[childForOp] = 0;
413 for (auto &op : *forOp.getBody()) {
414 if (!isa<AffineForOp, AffineIfOp>(op))
415 ++count;
416 }
417 stats->opCountMap[childForOp] = count;
418
419 // Record trip count for 'forOp'. Set flag if trip count is not
420 // constant.
421 Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
422 if (!maybeConstTripCount.hasValue()) {
423 // Currently only constant trip count loop nests are supported.
424 LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
425 return WalkResult::interrupt();
426 }
427
428 stats->tripCountMap[childForOp] = maybeConstTripCount.getValue();
429 return WalkResult::advance();
430 });
431 return !walkResult.wasInterrupted();
432 }
433
// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body * loop
// trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTEs: 1) This is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// 2) This is also used to compute the cost of fusing a slice of some loop nest
// within another loop.
static int64_t getComputeCostHelper(
    Operation *forOp, LoopNestStats &stats,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Operation *, int64_t> *computeCostMap) {
  // 'opCount' is the total number operations in one iteration of 'forOp' body,
  // minus terminator op which is a no-op.
  int64_t opCount = stats.opCountMap[forOp] - 1;
  // Recurse into child loops; each contributes its own trip-count-weighted
  // cost to one iteration of this loop.
  if (stats.loopMap.count(forOp) > 0) {
    for (auto childForOp : stats.loopMap[forOp]) {
      opCount += getComputeCostHelper(childForOp.getOperation(), stats,
                                      tripCountOverrideMap, computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forOp);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats.tripCountMap[forOp];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forOp);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}
480
481 // TODO: extend this to handle multiple result maps.
getConstDifference(AffineMap lbMap,AffineMap ubMap)482 static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
483 assert(lbMap.getNumResults() == 1 && "expected single result bound map");
484 assert(ubMap.getNumResults() == 1 && "expected single result bound map");
485 assert(lbMap.getNumDims() == ubMap.getNumDims());
486 assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
487 AffineExpr lbExpr(lbMap.getResult(0));
488 AffineExpr ubExpr(ubMap.getResult(0));
489 auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
490 lbMap.getNumSymbols());
491 auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
492 if (!cExpr)
493 return None;
494 return cExpr.getValue();
495 }
496
497 // Return the number of iterations in the given slice.
getSliceIterationCount(const llvm::SmallDenseMap<Operation *,uint64_t,8> & sliceTripCountMap)498 static uint64_t getSliceIterationCount(
499 const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
500 uint64_t iterCount = 1;
501 for (const auto &count : sliceTripCountMap) {
502 iterCount *= count.second;
503 }
504 return iterCount;
505 }
506
// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
// nest surrounding represented by slice loop bounds in 'slice'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO: Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
    const ComputationSliceState &slice,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
  unsigned numSrcLoopIVs = slice.ivs.size();
  // Populate map from AffineForOp -> trip count
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
    auto *op = forOp.getOperation();
    AffineMap lbMap = slice.lbs[i];
    AffineMap ubMap = slice.ubs[i];
    // A null (default-constructed) map means the bound was not sliced.
    if (lbMap == AffineMap() || ubMap == AffineMap()) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
        (*tripCountMap)[op] =
            forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
        continue;
      }
      // Fall back to trip count analysis when bounds are not both constant.
      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
      if (maybeConstTripCount.hasValue()) {
        (*tripCountMap)[op] = maybeConstTripCount.getValue();
        continue;
      }
      // Neither constant bounds nor a constant trip count: give up.
      return false;
    }
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    // Slice bounds are created with a constant ub - lb difference.
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[op] = tripCount.getValue();
  }
  return true;
}
544
545 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
546 /// Currently, the total cost is computed by counting the total operation
547 /// instance count (i.e. total number of operations in the loop body * loop
548 /// trip count) for the entire loop nest.
getComputeCost(AffineForOp forOp,LoopNestStats & stats)549 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
550 return getComputeCostHelper(forOp.getOperation(), stats,
551 /*tripCountOverrideMap=*/nullptr,
552 /*computeCostMap=*/nullptr);
553 }
554
/// Computes and returns in 'computeCost', the total compute cost of fusing the
/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
/// the total cost is computed by counting the total operation instance count
/// (i.e. total number of operations in the loop body * loop trip count) for
/// the entire loop nest. Returns false if the slice has a non-constant trip
/// count, true otherwise.
bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
                                AffineForOp dstForOp, LoopNestStats &dstStats,
                                const ComputationSliceState &slice,
                                int64_t *computeCost) {
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Operation *, int64_t> computeCostMap;

  // Build trip count map for computation slice.
  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
    return false;
  // Checks whether a store to load forwarding will happen.
  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
  assert(sliceIterationCount > 0);
  // A single-iteration slice guarantees store->load forwarding can occur.
  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
  auto *insertPointParent = slice.insertPoint->getParentOp();

  // The store and loads to this memref will disappear.
  // TODO: Add load coalescing to memref data flow opt pass.
  if (storeLoadFwdGuaranteed) {
    // Subtract from operation count the loads/store we expect load/store
    // forwarding to remove.
    unsigned storeCount = 0;
    llvm::SmallDenseSet<Value, 4> storeMemrefs;
    srcForOp.walk([&](Operation *op) {
      if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
        storeMemrefs.insert(storeOp.getMemRef());
        ++storeCount;
      }
    });
    // Subtract out any store ops in single-iteration src slice loop nest.
    if (storeCount > 0)
      computeCostMap[insertPointParent] = -storeCount;
    // Subtract out any load users of 'storeMemrefs' nested below
    // 'insertPointParent'.
    for (auto value : storeMemrefs) {
      for (auto *user : value.getUsers()) {
        if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in loop nest surrounding 'user' is
          // 'insertPointParent'.
          getLoopIVs(*user, &loops);
          if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
            // Decrement the op count of the load's immediately enclosing loop.
            if (auto forOp =
                    dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
              if (computeCostMap.count(forOp) == 0)
                computeCostMap[forOp] = 0;
              computeCostMap[forOp] -= 1;
            }
          }
        }
      }
    }
  }

  // Compute op instance count for the src loop nest with iteration slicing.
  int64_t sliceComputeCost = getComputeCostHelper(
      srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);

  // Compute cost of fusion for this depth: the slice's cost is added at the
  // insertion point inside the destination loop nest.
  computeCostMap[insertPointParent] = sliceComputeCost;

  *computeCost =
      getComputeCostHelper(dstForOp.getOperation(), dstStats,
                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
  return true;
}
626
627 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
628 /// producer-consumer dependence between write ops in 'srcOps' and read ops in
629 /// 'dstOps'.
gatherProducerConsumerMemrefs(ArrayRef<Operation * > srcOps,ArrayRef<Operation * > dstOps,DenseSet<Value> & producerConsumerMemrefs)630 void mlir::gatherProducerConsumerMemrefs(
631 ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
632 DenseSet<Value> &producerConsumerMemrefs) {
633 // Gather memrefs from stores in 'srcOps'.
634 DenseSet<Value> srcStoreMemRefs;
635 for (Operation *op : srcOps)
636 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
637 srcStoreMemRefs.insert(storeOp.getMemRef());
638
639 // Compute the intersection between memrefs from stores in 'srcOps' and
640 // memrefs from loads in 'dstOps'.
641 for (Operation *op : dstOps)
642 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
643 if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
644 producerConsumerMemrefs.insert(loadOp.getMemRef());
645 }
646