//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to pipeline data transfers.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "affine-pipeline-data-transfer"

using namespace mlir;

namespace {
struct PipelineDataTransfer
    : public AffinePipelineDataTransferBase<PipelineDataTransfer> {
  void runOnFunction() override;
  void runOnAffineForOp(AffineForOp forOp);

  std::vector<AffineForOp> forOps;
};

} // end anonymous namespace

/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
std::unique_ptr<OperationPass<FuncOp>> mlir::createPipelineDataTransferPass() {
  return std::make_unique<PipelineDataTransfer>();
}
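
// Illustrative usage sketch (not part of this file's logic): the pass runs on
// functions, so it would typically be scheduled with something like
//   pm.addNestedPass<FuncOp>(mlir::createPipelineDataTransferPass());
// or invoked from the command line via the pass's registered flag (typically
// -affine-pipeline-data-transfer; see the pass declaration for the exact name).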

// Returns the position of the tag memref operand given a DMA operation.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added.  TODO
static unsigned getTagMemRefPos(Operation &dmaOp) {
  assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
    return dmaStartOp.getTagMemRefOperandIndex();
  }
  // First operand for a dma finish operation.
  return 0;
}

/// Doubles the buffer of the supplied memref on the specified 'affine.for'
/// operation by adding a leading dimension of size two to the memref.
/// Replaces all uses of the old memref by the new one while indexing the newly
/// added dimension by the loop IV of the specified 'affine.for' operation
/// modulo 2. Returns false if such a replacement cannot be performed.
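///
/// Illustrative sketch (names hypothetical, unit-step loop with IV %iv): a
/// fast-memory buffer of type memref<256xf32, 1> becomes memref<2x256xf32, 1>,
/// and a use %buf[%j] in the loop body becomes %buf[%iv mod 2, %j], so
/// successive iterations alternate between the two halves of the new buffer.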
static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
  auto *forBody = forOp.getBody();
  OpBuilder bInner(forBody, forBody->begin());

  // Doubles the shape with a leading dimension extent of 2.
  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
    // Add the leading dimension in the shape for the double buffer.
    ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
    SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
    newShape[0] = 2;
    std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
    return MemRefType::Builder(oldMemRefType)
        .setShape(newShape)
        .setAffineMaps({});
  };

  auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
  auto newMemRefType = doubleShape(oldMemRefType);

  // The double buffer is allocated right before 'forOp'.
  OpBuilder bOuter(forOp);
  // Put together alloc operands for any dynamic dimensions of the memref.
  SmallVector<Value, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(
          bOuter.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++));
  }

  // Create and place the alloc right before the 'affine.for' operation.
  Value newMemRef =
      bOuter.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands);

  // Create 'iv mod 2' value to index the leading dimension.
  auto d0 = bInner.getAffineDimExpr(0);
  int64_t step = forOp.getStep();
  auto modTwoMap =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                 forOp.getInductionVar());
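  // For example, with a step of 4 the map yields (iv floordiv 4) mod 2, so the
  // iterations at iv = 0, 4, 8, 12, ... select buffer halves 0, 1, 0, 1, ...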

  // replaceAllMemRefUsesWith will succeed unless the forOp body has
  // non-dereferencing uses of the memref (dealloc's are fine though).
  if (failed(replaceAllMemRefUsesWith(
          oldMemRef, newMemRef,
          /*extraIndices=*/{ivModTwoOp},
          /*indexRemap=*/AffineMap(),
          /*extraOperands=*/{},
          /*symbolOperands=*/{},
          /*domInstFilter=*/&*forOp.getBody()->begin()))) {
    LLVM_DEBUG(
        forOp.emitError("memref replacement for double buffering failed"));
    ivModTwoOp.erase();
    return false;
  }
  // Insert the dealloc op right after the for loop.
  bOuter.setInsertionPointAfter(forOp);
  bOuter.create<DeallocOp>(forOp.getLoc(), newMemRef);

  return true;
}

/// Pipelines the data transfers in every 'affine.for' op of the function.
void PipelineDataTransfer::runOnFunction() {
  // Do a post order walk so that inner loop DMAs are processed first. This is
  // necessary since 'affine.for' operations nested within would otherwise
  // become invalid (erased) when the outer loop is pipelined (the pipelined one
  // gets deleted and replaced by a prologue, a new steady-state loop and an
  // epilogue).
  forOps.clear();
  getFunction().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
  for (auto forOp : forOps)
    runOnAffineForOp(forOp);
}

// Check if tags of the dma start op and dma wait op match.
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
    return false;
  auto startIndices = startOp.getTagIndices();
  auto waitIndices = waitOp.getTagIndices();
  // Both of these have the same number of indices since they correspond to the
  // same tag memref.
  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
            e = startIndices.end();
       it != e; ++it, ++wIt) {
    // Keep it simple for now, just checking if indices match.
    // TODO: this would in general need to check if there is no
    // intervening write writing to the same tag location, i.e., memory last
    // write/data flow analysis. This is however sufficient/powerful enough for
    // now since the DMA generation pass or the input for it will always have
    // start/wait with matching tags (same SSA operand indices).
    if (*it != *wIt)
      return false;
  }
  return true;
}

// Identify matching DMA start/finish operations to overlap computation with.
static void findMatchingStartFinishInsts(
    AffineForOp forOp,
    SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {

  // Collect outgoing DMA operations - needed to check for dependences below.
  SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
  for (auto &op : *forOp.getBody()) {
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
      outgoingDmaOps.push_back(dmaStartOp);
  }

  SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
  for (auto &op : *forOp.getBody()) {
    // Collect DMA finish operations.
    if (isa<AffineDmaWaitOp>(op)) {
      dmaFinishInsts.push_back(&op);
      continue;
    }
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (!dmaStartOp)
      continue;

    // Only DMAs incoming into higher memory spaces are pipelined for now.
    // TODO: handle outgoing DMA pipelining.
    if (!dmaStartOp.isDestMemorySpaceFaster())
      continue;

    // Check for dependence with outgoing DMAs. Doing this conservatively.
    // TODO: use the dependence analysis to check for
    // dependences between an incoming and outgoing DMA in the same iteration.
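    // For instance (illustrative), an outgoing DMA that writes back to the
    // same slow-memory buffer this incoming DMA reads from would create a
    // dependence; such incoming DMAs are conservatively not pipelined.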
    auto it = outgoingDmaOps.begin();
    for (; it != outgoingDmaOps.end(); ++it) {
      if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
        break;
    }
    if (it != outgoingDmaOps.end())
      continue;

    // We only double buffer if the buffer is not live out of loop.
    auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
    bool escapingUses = false;
    for (auto *user : memref.getUsers()) {
      // We can double buffer regardless of dealloc's outside the loop.
      if (isa<DeallocOp>(user))
        continue;
      if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
        LLVM_DEBUG(llvm::dbgs()
                       << "can't pipeline: buffer is live out of loop\n";);
        escapingUses = true;
        break;
      }
    }
    if (!escapingUses)
      dmaStartInsts.push_back(&op);
  }

  // For each start operation, we look for a matching finish operation.
  for (auto *dmaStartOp : dmaStartInsts) {
    for (auto *dmaFinishOp : dmaFinishInsts) {
      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
                        cast<AffineDmaWaitOp>(dmaFinishOp))) {
        startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
        break;
      }
    }
  }
}

/// Overlap DMA transfers with computation in this loop. If successful,
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and an epilogue
/// are inserted right before where it was.
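///
/// Illustrative sketch (simplified, operand lists elided, names hypothetical):
///
///   affine.for %i = 0 to %N {
///     affine.dma_start %slow[...], %fast[...], %tag[...], %size
///     affine.dma_wait %tag[...], %size
///     "compute"(%fast) ...
///   }
///
/// is rewritten, roughly, into a prologue that starts the transfer for
/// iteration 0, a steady-state loop in which the transfer for the next
/// iteration is issued into one half of the doubled %fast buffer while the
/// compute for the current iteration runs on the other half, and an epilogue
/// that waits for and computes on the final transfer.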
void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
  auto mayBeConstTripCount = getConstantTripCount(forOp);
  if (!mayBeConstTripCount.hasValue()) {
    LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
    return;
  }

  SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  if (startWaitPairs.empty()) {
    LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
    return;
  }

  // Double the buffers for the higher memory space memref's.
  // Identify memref's to replace by scanning through all DMA start
  // operations. A DMA start operation has two memref's - the one from the
  // higher level of memory hierarchy is the one to double buffer.
  // TODO: check whether double-buffering is even necessary.
  // TODO: make this work with different layouts: assuming here that
  // the dimension we are adding here for the double buffering is the outermost
  // dimension.
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    Value oldMemRef = dmaStartOp->getOperand(
        cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
    if (!doubleBuffer(oldMemRef, forOp)) {
      // Normally, double buffering should not fail because we already checked
      // that there are no uses outside.
      LLVM_DEBUG(llvm::dbgs()
                     << "double buffering failed for " << *dmaStartOp << "\n";);
      // IR still valid and semantically correct.
      return;
    }
    // If the old memref has no more uses, remove its 'dead' alloc if it was
    // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
    // operation could have been used on it if it was dynamically shaped in
    // order to create the double buffer above.)
    // '-canonicalize' does this in a more general way, but we'll anyway do the
    // simple/common case so that the output / test cases look clear.
    if (auto *allocOp = oldMemRef.getDefiningOp()) {
      if (oldMemRef.use_empty()) {
        allocOp->erase();
      } else if (oldMemRef.hasOneUse()) {
        if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef.user_begin())) {
          dealloc.erase();
          allocOp->erase();
        }
      }
    }
  }

  // Double the buffers for tag memrefs.
  for (auto &pair : startWaitPairs) {
    auto *dmaFinishOp = pair.second;
    Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
    if (!doubleBuffer(oldTagMemRef, forOp)) {
      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
      return;
    }
    // If the old tag has no uses or a single dealloc use, remove it.
    // (canonicalization handles more complex cases).
    if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
      if (oldTagMemRef.use_empty()) {
        tagAllocOp->erase();
      } else if (oldTagMemRef.hasOneUse()) {
        if (auto dealloc = dyn_cast<DeallocOp>(*oldTagMemRef.user_begin())) {
          dealloc.erase();
          tagAllocOp->erase();
        }
      }
    }
  }

  // Double buffering would have invalidated all the old DMA start/wait insts.
  startWaitPairs.clear();
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  // Store the shift for each operation for later lookup, including for the
  // affine.apply ops associated with each DMA start.
  DenseMap<Operation *, unsigned> instShiftMap;
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    assert(isa<AffineDmaStartOp>(dmaStartOp));
    instShiftMap[dmaStartOp] = 0;
    // Set shifts for DMA start op's affine operand computation slices to 0.
    SmallVector<AffineApplyOp, 4> sliceOps;
    mlir::createAffineComputationSlice(dmaStartOp, &sliceOps);
    if (!sliceOps.empty()) {
      for (auto sliceOp : sliceOps) {
        instShiftMap[sliceOp.getOperation()] = 0;
      }
    } else {
      // If a slice wasn't created, the reachable affine.apply op's from its
      // operands are the ones that go with it.
      SmallVector<Operation *, 4> affineApplyInsts;
      SmallVector<Value, 4> operands(dmaStartOp->getOperands());
      getReachableAffineApplyOps(operands, affineApplyInsts);
      for (auto *op : affineApplyInsts) {
        instShiftMap[op] = 0;
      }
    }
  }
  // Everything else (including compute ops and dma finish) is shifted by one.
  for (auto &op : forOp.getBody()->without_terminator())
    if (instShiftMap.find(&op) == instShiftMap.end())
      instShiftMap[&op] = 1;
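
  // With DMA starts (and their index computations) at shift 0 and everything
  // else at shift 1, the body skewing below produces a prologue that issues
  // the first transfer, a steady-state loop in which each iteration's compute
  // overlaps the next iteration's transfer, and an epilogue for the last
  // compute.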

  // Get shifts stored in map.
  SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
  unsigned s = 0;
  for (auto &op : forOp.getBody()->without_terminator()) {
    assert(instShiftMap.find(&op) != instShiftMap.end());
    shifts[s++] = instShiftMap[&op];

    // Tagging operations with shifts for debugging purposes.
    LLVM_DEBUG({
      OpBuilder b(&op);
      op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
    });
  }

  if (!isOpwiseShiftValid(forOp, shifts)) {
    // Violates dependences.
    LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
    return;
  }

  if (failed(affineForOpBodySkew(forOp, shifts))) {
    LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
    return;
  }
}