1 //===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous analysis routines for non-loop IR
10 // structures.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Analysis/Utils.h"
15 #include "mlir/Analysis/AffineAnalysis.h"
16 #include "mlir/Analysis/LoopAnalysis.h"
17 #include "mlir/Analysis/PresburgerSet.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/raw_ostream.h"
25 
26 #define DEBUG_TYPE "analysis-utils"
27 
28 using namespace mlir;
29 
30 using llvm::SmallDenseMap;
31 
32 /// Populates 'loops' with IVs of the loops surrounding 'op' ordered from
33 /// the outermost 'affine.for' operation to the innermost one.
getLoopIVs(Operation & op,SmallVectorImpl<AffineForOp> * loops)34 void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
35   auto *currOp = op.getParentOp();
36   AffineForOp currAffineForOp;
37   // Traverse up the hierarchy collecting all 'affine.for' operation while
38   // skipping over 'affine.if' operations.
39   while (currOp) {
40     if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
41       loops->push_back(currAffineForOp);
42     currOp = currOp->getParentOp();
43   }
44   std::reverse(loops->begin(), loops->end());
45 }
46 
47 /// Populates 'ops' with IVs of the loops surrounding `op`, along with
48 /// `affine.if` operations interleaved between these loops, ordered from the
49 /// outermost `affine.for` operation to the innermost one.
getEnclosingAffineForAndIfOps(Operation & op,SmallVectorImpl<Operation * > * ops)50 void mlir::getEnclosingAffineForAndIfOps(Operation &op,
51                                          SmallVectorImpl<Operation *> *ops) {
52   ops->clear();
53   Operation *currOp = op.getParentOp();
54 
55   // Traverse up the hierarchy collecting all `affine.for` and `affine.if`
56   // operations.
57   while (currOp) {
58     if (isa<AffineIfOp, AffineForOp>(currOp))
59       ops->push_back(currOp);
60     currOp = currOp->getParentOp();
61   }
62   std::reverse(ops->begin(), ops->end());
63 }
64 
65 // Populates 'cst' with FlatAffineValueConstraints which represent original
66 // domain of the loop bounds that define 'ivs'.
67 LogicalResult
getSourceAsConstraints(FlatAffineValueConstraints & cst)68 ComputationSliceState::getSourceAsConstraints(FlatAffineValueConstraints &cst) {
69   assert(!ivs.empty() && "Cannot have a slice without its IVs");
70   cst.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, /*numLocals=*/0, ivs);
71   for (Value iv : ivs) {
72     AffineForOp loop = getForInductionVarOwner(iv);
73     assert(loop && "Expected affine for");
74     if (failed(cst.addAffineForOpDomain(loop)))
75       return failure();
76   }
77   return success();
78 }
79 
80 // Populates 'cst' with FlatAffineValueConstraints which represent slice bounds.
81 LogicalResult
getAsConstraints(FlatAffineValueConstraints * cst)82 ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) {
83   assert(!lbOperands.empty());
84   // Adds src 'ivs' as dimension identifiers in 'cst'.
85   unsigned numDims = ivs.size();
86   // Adds operands (dst ivs and symbols) as symbols in 'cst'.
87   unsigned numSymbols = lbOperands[0].size();
88 
89   SmallVector<Value, 4> values(ivs);
90   // Append 'ivs' then 'operands' to 'values'.
91   values.append(lbOperands[0].begin(), lbOperands[0].end());
92   cst->reset(numDims, numSymbols, 0, values);
93 
94   // Add loop bound constraints for values which are loop IVs of the destination
95   // of fusion and equality constraints for symbols which are constants.
96   for (unsigned i = numDims, end = values.size(); i < end; ++i) {
97     Value value = values[i];
98     assert(cst->containsId(value) && "value expected to be present");
99     if (isValidSymbol(value)) {
100       // Check if the symbol is a constant.
101       if (auto cOp = value.getDefiningOp<ConstantIndexOp>())
102         cst->addBound(FlatAffineConstraints::EQ, value, cOp.getValue());
103     } else if (auto loop = getForInductionVarOwner(value)) {
104       if (failed(cst->addAffineForOpDomain(loop)))
105         return failure();
106     }
107   }
108 
109   // Add slices bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'
110   LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
111   assert(succeeded(ret) &&
112          "should not fail as we never have semi-affine slice maps");
113   (void)ret;
114   return success();
115 }
116 
117 // Clears state bounds and operand state.
clearBounds()118 void ComputationSliceState::clearBounds() {
119   lbs.clear();
120   ubs.clear();
121   lbOperands.clear();
122   ubOperands.clear();
123 }
124 
dump() const125 void ComputationSliceState::dump() const {
126   llvm::errs() << "\tIVs:\n";
127   for (Value iv : ivs)
128     llvm::errs() << "\t\t" << iv << "\n";
129 
130   llvm::errs() << "\tLBs:\n";
131   for (auto &en : llvm::enumerate(lbs)) {
132     llvm::errs() << "\t\t" << en.value() << "\n";
133     llvm::errs() << "\t\tOperands:\n";
134     for (Value lbOp : lbOperands[en.index()])
135       llvm::errs() << "\t\t\t" << lbOp << "\n";
136   }
137 
138   llvm::errs() << "\tUBs:\n";
139   for (auto &en : llvm::enumerate(ubs)) {
140     llvm::errs() << "\t\t" << en.value() << "\n";
141     llvm::errs() << "\t\tOperands:\n";
142     for (Value ubOp : ubOperands[en.index()])
143       llvm::errs() << "\t\t\t" << ubOp << "\n";
144   }
145 }
146 
147 /// Fast check to determine if the computation slice is maximal. Returns true if
148 /// each slice dimension maps to an existing dst dimension and both the src
149 /// and the dst loops for those dimensions have the same bounds. Returns false
150 /// if both the src and the dst loops don't have the same bounds. Returns
151 /// llvm::None if none of the above can be proven.
isSliceMaximalFastCheck() const152 Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
153   assert(lbs.size() == ubs.size() && lbs.size() && ivs.size() &&
154          "Unexpected number of lbs, ubs and ivs in slice");
155 
156   for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
157     AffineMap lbMap = lbs[i];
158     AffineMap ubMap = ubs[i];
159 
160     // Check if this slice is just an equality along this dimension.
161     if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
162         ubMap.getNumResults() != 1 ||
163         lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
164         // The condition above will be true for maps describing a single
165         // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
166         // Make sure we skip those cases by checking that the lb result is not
167         // just a constant.
168         lbMap.getResult(0).isa<AffineConstantExpr>())
169       return llvm::None;
170 
171     // Limited support: we expect the lb result to be just a loop dimension for
172     // now.
173     AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
174     if (!result)
175       return llvm::None;
176 
177     // Retrieve dst loop bounds.
178     AffineForOp dstLoop =
179         getForInductionVarOwner(lbOperands[i][result.getPosition()]);
180     if (!dstLoop)
181       return llvm::None;
182     AffineMap dstLbMap = dstLoop.getLowerBoundMap();
183     AffineMap dstUbMap = dstLoop.getUpperBoundMap();
184 
185     // Retrieve src loop bounds.
186     AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
187     assert(srcLoop && "Expected affine for");
188     AffineMap srcLbMap = srcLoop.getLowerBoundMap();
189     AffineMap srcUbMap = srcLoop.getUpperBoundMap();
190 
191     // Limited support: we expect simple src and dst loops with a single
192     // constant component per bound for now.
193     if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
194         dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
195       return llvm::None;
196 
197     AffineExpr srcLbResult = srcLbMap.getResult(0);
198     AffineExpr dstLbResult = dstLbMap.getResult(0);
199     AffineExpr srcUbResult = srcUbMap.getResult(0);
200     AffineExpr dstUbResult = dstUbMap.getResult(0);
201     if (!srcLbResult.isa<AffineConstantExpr>() ||
202         !srcUbResult.isa<AffineConstantExpr>() ||
203         !dstLbResult.isa<AffineConstantExpr>() ||
204         !dstUbResult.isa<AffineConstantExpr>())
205       return llvm::None;
206 
207     // Check if src and dst loop bounds are the same. If not, we can guarantee
208     // that the slice is not maximal.
209     if (srcLbResult != dstLbResult || srcUbResult != dstUbResult)
210       return false;
211   }
212 
213   return true;
214 }
215 
216 /// Returns true if it is deterministically verified that the original iteration
217 /// space of the slice is contained within the new iteration space that is
218 /// created after fusing 'this' slice into its destination.
isSliceValid()219 Optional<bool> ComputationSliceState::isSliceValid() {
220   // Fast check to determine if the slice is valid. If the following conditions
221   // are verified to be true, slice is declared valid by the fast check:
222   // 1. Each slice loop is a single iteration loop bound in terms of a single
223   //    destination loop IV.
224   // 2. Loop bounds of the destination loop IV (from above) and those of the
225   //    source loop IV are exactly the same.
226   // If the fast check is inconclusive or false, we proceed with a more
227   // expensive analysis.
228   // TODO: Store the result of the fast check, as it might be used again in
229   // `canRemoveSrcNodeAfterFusion`.
230   Optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
231   if (isValidFastCheck.hasValue() && isValidFastCheck.getValue())
232     return true;
233 
234   // Create constraints for the source loop nest using which slice is computed.
235   FlatAffineValueConstraints srcConstraints;
236   // TODO: Store the source's domain to avoid computation at each depth.
237   if (failed(getSourceAsConstraints(srcConstraints))) {
238     LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
239     return llvm::None;
240   }
241   // As the set difference utility currently cannot handle symbols in its
242   // operands, validity of the slice cannot be determined.
243   if (srcConstraints.getNumSymbolIds() > 0) {
244     LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
245     return llvm::None;
246   }
247   // TODO: Handle local ids in the source domains while using the 'projectOut'
248   // utility below. Currently, aligning is not done assuming that there will be
249   // no local ids in the source domain.
250   if (srcConstraints.getNumLocalIds() != 0) {
251     LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
252     return llvm::None;
253   }
254 
255   // Create constraints for the slice loop nest that would be created if the
256   // fusion succeeds.
257   FlatAffineValueConstraints sliceConstraints;
258   if (failed(getAsConstraints(&sliceConstraints))) {
259     LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
260     return llvm::None;
261   }
262 
263   // Projecting out every dimension other than the 'ivs' to express slice's
264   // domain completely in terms of source's IVs.
265   sliceConstraints.projectOut(ivs.size(),
266                               sliceConstraints.getNumIds() - ivs.size());
267 
268   LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
269   LLVM_DEBUG(srcConstraints.dump());
270   LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
271                              "(expressed in terms of its source's IVs):\n");
272   LLVM_DEBUG(sliceConstraints.dump());
273 
274   // TODO: Store 'srcSet' to avoid recalculating for each depth.
275   PresburgerSet srcSet(srcConstraints);
276   PresburgerSet sliceSet(sliceConstraints);
277   PresburgerSet diffSet = sliceSet.subtract(srcSet);
278 
279   if (!diffSet.isIntegerEmpty()) {
280     LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
281     return false;
282   }
283   return true;
284 }
285 
286 /// Returns true if the computation slice encloses all the iterations of the
287 /// sliced loop nest. Returns false if it does not. Returns llvm::None if it
288 /// cannot determine if the slice is maximal or not.
isMaximal() const289 Optional<bool> ComputationSliceState::isMaximal() const {
290   // Fast check to determine if the computation slice is maximal. If the result
291   // is inconclusive, we proceed with a more expensive analysis.
292   Optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
293   if (isMaximalFastCheck.hasValue())
294     return isMaximalFastCheck;
295 
296   // Create constraints for the src loop nest being sliced.
297   FlatAffineValueConstraints srcConstraints;
298   srcConstraints.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0,
299                        /*numLocals=*/0, ivs);
300   for (Value iv : ivs) {
301     AffineForOp loop = getForInductionVarOwner(iv);
302     assert(loop && "Expected affine for");
303     if (failed(srcConstraints.addAffineForOpDomain(loop)))
304       return llvm::None;
305   }
306 
307   // Create constraints for the slice using the dst loop nest information. We
308   // retrieve existing dst loops from the lbOperands.
309   SmallVector<Value, 8> consumerIVs;
310   for (Value lbOp : lbOperands[0])
311     if (getForInductionVarOwner(lbOp))
312       consumerIVs.push_back(lbOp);
313 
314   // Add empty IV Values for those new loops that are not equalities and,
315   // therefore, are not yet materialized in the IR.
316   for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
317     consumerIVs.push_back(Value());
318 
319   FlatAffineValueConstraints sliceConstraints;
320   sliceConstraints.reset(/*numDims=*/consumerIVs.size(), /*numSymbols=*/0,
321                          /*numLocals=*/0, consumerIVs);
322 
323   if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
324     return llvm::None;
325 
326   if (srcConstraints.getNumDimIds() != sliceConstraints.getNumDimIds())
327     // Constraint dims are different. The integer set difference can't be
328     // computed so we don't know if the slice is maximal.
329     return llvm::None;
330 
331   // Compute the difference between the src loop nest and the slice integer
332   // sets.
333   PresburgerSet srcSet(srcConstraints);
334   PresburgerSet sliceSet(sliceConstraints);
335   PresburgerSet diffSet = srcSet.subtract(sliceSet);
336   return diffSet.isIntegerEmpty();
337 }
338 
getRank() const339 unsigned MemRefRegion::getRank() const {
340   return memref.getType().cast<MemRefType>().getRank();
341 }
342 
getConstantBoundingSizeAndShape(SmallVectorImpl<int64_t> * shape,std::vector<SmallVector<int64_t,4>> * lbs,SmallVectorImpl<int64_t> * lbDivisors) const343 Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
344     SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
345     SmallVectorImpl<int64_t> *lbDivisors) const {
346   auto memRefType = memref.getType().cast<MemRefType>();
347   unsigned rank = memRefType.getRank();
348   if (shape)
349     shape->reserve(rank);
350 
351   assert(rank == cst.getNumDimIds() && "inconsistent memref region");
352 
353   // Use a copy of the region constraints that has upper/lower bounds for each
354   // memref dimension with static size added to guard against potential
355   // over-approximation from projection or union bounding box. We may not add
356   // this on the region itself since they might just be redundant constraints
357   // that will need non-trivials means to eliminate.
358   FlatAffineConstraints cstWithShapeBounds(cst);
359   for (unsigned r = 0; r < rank; r++) {
360     cstWithShapeBounds.addBound(FlatAffineConstraints::LB, r, 0);
361     int64_t dimSize = memRefType.getDimSize(r);
362     if (ShapedType::isDynamic(dimSize))
363       continue;
364     cstWithShapeBounds.addBound(FlatAffineConstraints::UB, r, dimSize - 1);
365   }
366 
367   // Find a constant upper bound on the extent of this memref region along each
368   // dimension.
369   int64_t numElements = 1;
370   int64_t diffConstant;
371   int64_t lbDivisor;
372   for (unsigned d = 0; d < rank; d++) {
373     SmallVector<int64_t, 4> lb;
374     Optional<int64_t> diff =
375         cstWithShapeBounds.getConstantBoundOnDimSize(d, &lb, &lbDivisor);
376     if (diff.hasValue()) {
377       diffConstant = diff.getValue();
378       assert(diffConstant >= 0 && "Dim size bound can't be negative");
379       assert(lbDivisor > 0);
380     } else {
381       // If no constant bound is found, then it can always be bound by the
382       // memref's dim size if the latter has a constant size along this dim.
383       auto dimSize = memRefType.getDimSize(d);
384       if (dimSize == -1)
385         return None;
386       diffConstant = dimSize;
387       // Lower bound becomes 0.
388       lb.resize(cstWithShapeBounds.getNumSymbolIds() + 1, 0);
389       lbDivisor = 1;
390     }
391     numElements *= diffConstant;
392     if (lbs) {
393       lbs->push_back(lb);
394       assert(lbDivisors && "both lbs and lbDivisor or none");
395       lbDivisors->push_back(lbDivisor);
396     }
397     if (shape) {
398       shape->push_back(diffConstant);
399     }
400   }
401   return numElements;
402 }
403 
getLowerAndUpperBound(unsigned pos,AffineMap & lbMap,AffineMap & ubMap) const404 void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
405                                          AffineMap &ubMap) const {
406   assert(pos < cst.getNumDimIds() && "invalid position");
407   auto memRefType = memref.getType().cast<MemRefType>();
408   unsigned rank = memRefType.getRank();
409 
410   assert(rank == cst.getNumDimIds() && "inconsistent memref region");
411 
412   auto boundPairs = cst.getLowerAndUpperBound(
413       pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolIds(),
414       /*localExprs=*/{}, memRefType.getContext());
415   lbMap = boundPairs.first;
416   ubMap = boundPairs.second;
417   assert(lbMap && "lower bound for a region must exist");
418   assert(ubMap && "upper bound for a region must exist");
419   assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolIds() - rank);
420   assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolIds() - rank);
421 }
422 
unionBoundingBox(const MemRefRegion & other)423 LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
424   assert(memref == other.memref);
425   return cst.unionBoundingBox(*other.getConstraints());
426 }
427 
428 /// Computes the memory region accessed by this memref with the region
429 /// represented as constraints symbolic/parametric in 'loopDepth' loops
430 /// surrounding opInst and any additional Function symbols.
431 //  For example, the memref region for this load operation at loopDepth = 1 will
432 //  be as below:
433 //
434 //    affine.for %i = 0 to 32 {
435 //      affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
436 //        load %A[%ii]
437 //      }
438 //    }
439 //
440 // region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
441 // The last field is a 2-d FlatAffineConstraints symbolic in %i.
442 //
443 // TODO: extend this to any other memref dereferencing ops
444 // (dma_start, dma_wait).
compute(Operation * op,unsigned loopDepth,const ComputationSliceState * sliceState,bool addMemRefDimBounds)445 LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
446                                     const ComputationSliceState *sliceState,
447                                     bool addMemRefDimBounds) {
448   assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
449          "affine read/write op expected");
450 
451   MemRefAccess access(op);
452   memref = access.memref;
453   write = access.isStore();
454 
455   unsigned rank = access.getRank();
456 
457   LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
458                           << "depth: " << loopDepth << "\n";);
459 
460   // 0-d memrefs.
461   if (rank == 0) {
462     SmallVector<AffineForOp, 4> ivs;
463     getLoopIVs(*op, &ivs);
464     assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
465     // The first 'loopDepth' IVs are symbols for this region.
466     ivs.resize(loopDepth);
467     SmallVector<Value, 4> regionSymbols;
468     extractForInductionVars(ivs, &regionSymbols);
469     // A 0-d memref has a 0-d region.
470     cst.reset(rank, loopDepth, /*numLocals=*/0, regionSymbols);
471     return success();
472   }
473 
474   // Build the constraints for this region.
475   AffineValueMap accessValueMap;
476   access.getAccessMap(&accessValueMap);
477   AffineMap accessMap = accessValueMap.getAffineMap();
478 
479   unsigned numDims = accessMap.getNumDims();
480   unsigned numSymbols = accessMap.getNumSymbols();
481   unsigned numOperands = accessValueMap.getNumOperands();
482   // Merge operands with slice operands.
483   SmallVector<Value, 4> operands;
484   operands.resize(numOperands);
485   for (unsigned i = 0; i < numOperands; ++i)
486     operands[i] = accessValueMap.getOperand(i);
487 
488   if (sliceState != nullptr) {
489     operands.reserve(operands.size() + sliceState->lbOperands[0].size());
490     // Append slice operands to 'operands' as symbols.
491     for (auto extraOperand : sliceState->lbOperands[0]) {
492       if (!llvm::is_contained(operands, extraOperand)) {
493         operands.push_back(extraOperand);
494         numSymbols++;
495       }
496     }
497   }
498   // We'll first associate the dims and symbols of the access map to the dims
499   // and symbols resp. of cst. This will change below once cst is
500   // fully constructed out.
501   cst.reset(numDims, numSymbols, 0, operands);
502 
503   // Add equality constraints.
504   // Add inequalities for loop lower/upper bounds.
505   for (unsigned i = 0; i < numDims + numSymbols; ++i) {
506     auto operand = operands[i];
507     if (auto loop = getForInductionVarOwner(operand)) {
508       // Note that cst can now have more dimensions than accessMap if the
509       // bounds expressions involve outer loops or other symbols.
510       // TODO: rewrite this to use getInstIndexSet; this way
511       // conditionals will be handled when the latter supports it.
512       if (failed(cst.addAffineForOpDomain(loop)))
513         return failure();
514     } else {
515       // Has to be a valid symbol.
516       auto symbol = operand;
517       assert(isValidSymbol(symbol));
518       // Check if the symbol is a constant.
519       if (auto *op = symbol.getDefiningOp()) {
520         if (auto constOp = dyn_cast<ConstantIndexOp>(op)) {
521           cst.addBound(FlatAffineConstraints::EQ, symbol, constOp.getValue());
522         }
523       }
524     }
525   }
526 
527   // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
528   if (sliceState != nullptr) {
529     // Add dim and symbol slice operands.
530     for (auto operand : sliceState->lbOperands[0]) {
531       cst.addInductionVarOrTerminalSymbol(operand);
532     }
533     // Add upper/lower bounds from 'sliceState' to 'cst'.
534     LogicalResult ret =
535         cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
536                            sliceState->lbOperands[0]);
537     assert(succeeded(ret) &&
538            "should not fail as we never have semi-affine slice maps");
539     (void)ret;
540   }
541 
542   // Add access function equalities to connect loop IVs to data dimensions.
543   if (failed(cst.composeMap(&accessValueMap))) {
544     op->emitError("getMemRefRegion: compose affine map failed");
545     LLVM_DEBUG(accessValueMap.getAffineMap().dump());
546     return failure();
547   }
548 
549   // Set all identifiers appearing after the first 'rank' identifiers as
550   // symbolic identifiers - so that the ones corresponding to the memref
551   // dimensions are the dimensional identifiers for the memref region.
552   cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank);
553 
554   // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
555   // this memref region is symbolic.
556   SmallVector<AffineForOp, 4> enclosingIVs;
557   getLoopIVs(*op, &enclosingIVs);
558   assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
559   enclosingIVs.resize(loopDepth);
560   SmallVector<Value, 4> ids;
561   cst.getValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids);
562   for (auto id : ids) {
563     AffineForOp iv;
564     if ((iv = getForInductionVarOwner(id)) &&
565         llvm::is_contained(enclosingIVs, iv) == false) {
566       cst.projectOut(id);
567     }
568   }
569 
570   // Project out any local variables (these would have been added for any
571   // mod/divs).
572   cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds());
573 
574   // Constant fold any symbolic identifiers.
575   cst.constantFoldIdRange(/*pos=*/cst.getNumDimIds(),
576                           /*num=*/cst.getNumSymbolIds());
577 
578   assert(cst.getNumDimIds() == rank && "unexpected MemRefRegion format");
579 
580   // Add upper/lower bounds for each memref dimension with static size
581   // to guard against potential over-approximation from projection.
582   // TODO: Support dynamic memref dimensions.
583   if (addMemRefDimBounds) {
584     auto memRefType = memref.getType().cast<MemRefType>();
585     for (unsigned r = 0; r < rank; r++) {
586       cst.addBound(FlatAffineConstraints::LB, /*pos=*/r, /*value=*/0);
587       if (memRefType.isDynamicDim(r))
588         continue;
589       cst.addBound(FlatAffineConstraints::UB, /*pos=*/r,
590                    memRefType.getDimSize(r) - 1);
591     }
592   }
593   cst.removeTrivialRedundancy();
594 
595   LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
596   LLVM_DEBUG(cst.dump());
597   return success();
598 }
599 
getMemRefEltSizeInBytes(MemRefType memRefType)600 static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
601   auto elementType = memRefType.getElementType();
602 
603   unsigned sizeInBits;
604   if (elementType.isIntOrFloat()) {
605     sizeInBits = elementType.getIntOrFloatBitWidth();
606   } else {
607     auto vectorType = elementType.cast<VectorType>();
608     sizeInBits =
609         vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
610   }
611   return llvm::divideCeil(sizeInBits, 8);
612 }
613 
614 // Returns the size of the region.
getRegionSize()615 Optional<int64_t> MemRefRegion::getRegionSize() {
616   auto memRefType = memref.getType().cast<MemRefType>();
617 
618   auto layoutMaps = memRefType.getAffineMaps();
619   if (layoutMaps.size() > 1 ||
620       (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
621     LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
622     return false;
623   }
624 
625   // Indices to use for the DmaStart op.
626   // Indices for the original memref being DMAed from/to.
627   SmallVector<Value, 4> memIndices;
628   // Indices for the faster buffer being DMAed into/from.
629   SmallVector<Value, 4> bufIndices;
630 
631   // Compute the extents of the buffer.
632   Optional<int64_t> numElements = getConstantBoundingSizeAndShape();
633   if (!numElements.hasValue()) {
634     LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
635     return None;
636   }
637   return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
638 }
639 
640 /// Returns the size of memref data in bytes if it's statically shaped, None
641 /// otherwise.  If the element of the memref has vector type, takes into account
642 /// size of the vector as well.
643 //  TODO: improve/complete this when we have target data.
getMemRefSizeInBytes(MemRefType memRefType)644 Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
645   if (!memRefType.hasStaticShape())
646     return None;
647   auto elementType = memRefType.getElementType();
648   if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
649     return None;
650 
651   uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
652   for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
653     sizeInBytes = sizeInBytes * memRefType.getDimSize(i);
654   }
655   return sizeInBytes;
656 }
657 
658 template <typename LoadOrStoreOp>
boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,bool emitError)659 LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
660                                             bool emitError) {
661   static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
662                                 AffineWriteOpInterface>::value,
663                 "argument should be either a AffineReadOpInterface or a "
664                 "AffineWriteOpInterface");
665 
666   Operation *op = loadOrStoreOp.getOperation();
667   MemRefRegion region(op->getLoc());
668   if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
669                             /*addMemRefDimBounds=*/false)))
670     return success();
671 
672   LLVM_DEBUG(llvm::dbgs() << "Memory region");
673   LLVM_DEBUG(region.getConstraints()->dump());
674 
675   bool outOfBounds = false;
676   unsigned rank = loadOrStoreOp.getMemRefType().getRank();
677 
678   // For each dimension, check for out of bounds.
679   for (unsigned r = 0; r < rank; r++) {
680     FlatAffineConstraints ucst(*region.getConstraints());
681 
682     // Intersect memory region with constraint capturing out of bounds (both out
683     // of upper and out of lower), and check if the constraint system is
684     // feasible. If it is, there is at least one point out of bounds.
685     SmallVector<int64_t, 4> ineq(rank + 1, 0);
686     int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
687     // TODO: handle dynamic dim sizes.
688     if (dimSize == -1)
689       continue;
690 
691     // Check for overflow: d_i >= memref dim size.
692     ucst.addBound(FlatAffineConstraints::LB, r, dimSize);
693     outOfBounds = !ucst.isEmpty();
694     if (outOfBounds && emitError) {
695       loadOrStoreOp.emitOpError()
696           << "memref out of upper bound access along dimension #" << (r + 1);
697     }
698 
699     // Check for a negative index.
700     FlatAffineConstraints lcst(*region.getConstraints());
701     std::fill(ineq.begin(), ineq.end(), 0);
702     // d_i <= -1;
703     lcst.addBound(FlatAffineConstraints::UB, r, -1);
704     outOfBounds = !lcst.isEmpty();
705     if (outOfBounds && emitError) {
706       loadOrStoreOp.emitOpError()
707           << "memref out of lower bound access along dimension #" << (r + 1);
708     }
709   }
710   return failure(outOfBounds);
711 }
712 
713 // Explicitly instantiate the template so that the compiler knows we need them!
714 template LogicalResult
715 mlir::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp, bool emitError);
716 template LogicalResult
717 mlir::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp, bool emitError);
718 
719 // Returns in 'positions' the Block positions of 'op' in each ancestor
720 // Block from the Block containing operation, stopping at 'limitBlock'.
findInstPosition(Operation * op,Block * limitBlock,SmallVectorImpl<unsigned> * positions)721 static void findInstPosition(Operation *op, Block *limitBlock,
722                              SmallVectorImpl<unsigned> *positions) {
723   Block *block = op->getBlock();
724   while (block != limitBlock) {
725     // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
726     // rely on linear scans.
727     int instPosInBlock = std::distance(block->begin(), op->getIterator());
728     positions->push_back(instPosInBlock);
729     op = block->getParentOp();
730     block = op->getBlock();
731   }
732   std::reverse(positions->begin(), positions->end());
733 }
734 
735 // Returns the Operation in a possibly nested set of Blocks, where the
736 // position of the operation is represented by 'positions', which has a
737 // Block position for each level of nesting.
getInstAtPosition(ArrayRef<unsigned> positions,unsigned level,Block * block)738 static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
739                                     unsigned level, Block *block) {
740   unsigned i = 0;
741   for (auto &op : *block) {
742     if (i != positions[level]) {
743       ++i;
744       continue;
745     }
746     if (level == positions.size() - 1)
747       return &op;
748     if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
749       return getInstAtPosition(positions, level + 1,
750                                childAffineForOp.getBody());
751 
752     for (auto &region : op.getRegions()) {
753       for (auto &b : region)
754         if (auto *ret = getInstAtPosition(positions, level + 1, &b))
755           return ret;
756     }
757     return nullptr;
758   }
759   return nullptr;
760 }
761 
762 // Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
addMissingLoopIVBounds(SmallPtrSet<Value,8> & ivs,FlatAffineValueConstraints * cst)763 static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
764                                             FlatAffineValueConstraints *cst) {
765   for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
766     auto value = cst->getValue(i);
767     if (ivs.count(value) == 0) {
768       assert(isForInductionVar(value));
769       auto loop = getForInductionVarOwner(value);
770       if (failed(cst->addAffineForOpDomain(loop)))
771         return failure();
772     }
773   }
774   return success();
775 }
776 
777 /// Returns the innermost common loop depth for the set of operations in 'ops'.
778 // TODO: Move this to LoopUtils.
getInnermostCommonLoopDepth(ArrayRef<Operation * > ops,SmallVectorImpl<AffineForOp> * surroundingLoops)779 unsigned mlir::getInnermostCommonLoopDepth(
780     ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
781   unsigned numOps = ops.size();
782   assert(numOps > 0 && "Expected at least one operation");
783 
784   std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
785   unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
786   for (unsigned i = 0; i < numOps; ++i) {
787     getLoopIVs(*ops[i], &loops[i]);
788     loopDepthLimit =
789         std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
790   }
791 
792   unsigned loopDepth = 0;
793   for (unsigned d = 0; d < loopDepthLimit; ++d) {
794     unsigned i;
795     for (i = 1; i < numOps; ++i) {
796       if (loops[i - 1][d] != loops[i][d])
797         return loopDepth;
798     }
799     if (surroundingLoops)
800       surroundingLoops->push_back(loops[i - 1][d]);
801     ++loopDepth;
802   }
803   return loopDepth;
804 }
805 
/// Computes in 'sliceUnion' the union of all slice bounds computed at
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
/// union was computed correctly, an appropriate failure otherwise.
SliceComputationResult
mlir::computeSliceUnion(ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB,
                        unsigned loopDepth, unsigned numCommonLoops,
                        bool isBackwardSlice,
                        ComputationSliceState *sliceUnion) {
  // Compute the union of slice bounds between all pairs in 'opsA' and
  // 'opsB' in 'sliceUnionCst'.
  FlatAffineValueConstraints sliceUnionCst;
  assert(sliceUnionCst.getNumDimAndSymbolIds() == 0);
  std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
  for (unsigned i = 0, numOpsA = opsA.size(); i < numOpsA; ++i) {
    MemRefAccess srcAccess(opsA[i]);
    for (unsigned j = 0, numOpsB = opsB.size(); j < numOpsB; ++j) {
      MemRefAccess dstAccess(opsB[j]);
      // Only accesses to the same memref can be dependent.
      if (srcAccess.memref != dstAccess.memref)
        continue;
      // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
      if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) ||
          (isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) {
        LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
        return SliceComputationResult::GenericFailure;
      }

      // Read-after-read pairs only matter for slicing (RAR is allowed below).
      bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
                              isa<AffineReadOpInterface>(dstAccess.opInst);
      FlatAffineValueConstraints dependenceConstraints;
      // Check dependence between 'srcAccess' and 'dstAccess'.
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
          &dependenceConstraints, /*dependenceComponents=*/nullptr,
          /*allowRAR=*/readReadAccesses);
      if (result.value == DependenceResult::Failure) {
        LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
        return SliceComputationResult::GenericFailure;
      }
      if (result.value == DependenceResult::NoDependence)
        continue;
      dependentOpPairs.push_back({opsA[i], opsB[j]});

      // Compute slice bounds for 'srcAccess' and 'dstAccess'.
      ComputationSliceState tmpSliceState;
      mlir::getComputationSliceState(opsA[i], opsB[j], &dependenceConstraints,
                                     loopDepth, isBackwardSlice,
                                     &tmpSliceState);

      if (sliceUnionCst.getNumDimAndSymbolIds() == 0) {
        // Initialize 'sliceUnionCst' with the bounds computed in previous step.
        if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Unable to compute slice bound constraints\n");
          return SliceComputationResult::GenericFailure;
        }
        assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
        continue;
      }

      // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
      FlatAffineValueConstraints tmpSliceCst;
      if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute slice bound constraints\n");
        return SliceComputationResult::GenericFailure;
      }

      // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
      if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) {

        // Pre-constraint id alignment: record loop IVs used in each constraint
        // system.
        SmallPtrSet<Value, 8> sliceUnionIVs;
        for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k)
          sliceUnionIVs.insert(sliceUnionCst.getValue(k));
        SmallPtrSet<Value, 8> tmpSliceIVs;
        for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k)
          tmpSliceIVs.insert(tmpSliceCst.getValue(k));

        sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst);

        // Post-constraint id alignment: add loop IV bounds missing after
        // id alignment to constraint systems. This can occur if one constraint
        // system uses a loop IV that is not used by the other. The call
        // to unionBoundingBox below expects constraints for each Loop IV, even
        // if they are the unsliced full loop bounds added here.
        if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
          return SliceComputationResult::GenericFailure;
        if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
          return SliceComputationResult::GenericFailure;
      }
      // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'. Local
      // ids are not supported by unionBoundingBox, so bail out if any exist.
      if (sliceUnionCst.getNumLocalIds() > 0 ||
          tmpSliceCst.getNumLocalIds() > 0 ||
          failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute union bounding box of slice bounds\n");
        return SliceComputationResult::GenericFailure;
      }
    }
  }

  // Empty union: no dependent pairs were found, so there is nothing to slice.
  if (sliceUnionCst.getNumDimAndSymbolIds() == 0)
    return SliceComputationResult::GenericFailure;

  // Gather loops surrounding ops from loop nest where slice will be inserted.
  SmallVector<Operation *, 4> ops;
  for (auto &dep : dependentOpPairs) {
    ops.push_back(isBackwardSlice ? dep.second : dep.first);
  }
  SmallVector<AffineForOp, 4> surroundingLoops;
  unsigned innermostCommonLoopDepth =
      getInnermostCommonLoopDepth(ops, &surroundingLoops);
  if (loopDepth > innermostCommonLoopDepth) {
    LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
    return SliceComputationResult::GenericFailure;
  }

  // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
  unsigned numSliceLoopIVs = sliceUnionCst.getNumDimIds();

  // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
  sliceUnionCst.convertLoopIVSymbolsToDims();
  sliceUnion->clearBounds();
  sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get slice bounds from slice union constraints 'sliceUnionCst'.
  sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
                               opsA[0]->getContext(), &sliceUnion->lbs,
                               &sliceUnion->ubs);

  // Add slice bound operands of union: all ids past the slice IVs.
  SmallVector<Value, 4> sliceBoundOperands;
  sliceUnionCst.getValues(numSliceLoopIVs,
                          sliceUnionCst.getNumDimAndSymbolIds(),
                          &sliceBoundOperands);

  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
  sliceUnion->ivs.clear();
  sliceUnionCst.getValues(0, numSliceLoopIVs, &sliceUnion->ivs);

  // Set loop nest insertion point to block start at 'loopDepth' for backward
  // slices, or to just before the block terminator for forward slices.
  sliceUnion->insertPoint =
      isBackwardSlice
          ? surroundingLoops[loopDepth - 1].getBody()->begin()
          : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Check if the slice computed is valid. Return success only if it is verified
  // that the slice is valid, otherwise return appropriate failure status.
  Optional<bool> isSliceValid = sliceUnion->isSliceValid();
  if (!isSliceValid.hasValue()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
    return SliceComputationResult::GenericFailure;
  }
  if (!isSliceValid.getValue())
    return SliceComputationResult::IncorrectSliceFailure;

  return SliceComputationResult::Success;
}
973 
974 // TODO: extend this to handle multiple result maps.
getConstDifference(AffineMap lbMap,AffineMap ubMap)975 static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
976   assert(lbMap.getNumResults() == 1 && "expected single result bound map");
977   assert(ubMap.getNumResults() == 1 && "expected single result bound map");
978   assert(lbMap.getNumDims() == ubMap.getNumDims());
979   assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
980   AffineExpr lbExpr(lbMap.getResult(0));
981   AffineExpr ubExpr(ubMap.getResult(0));
982   auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
983                                          lbMap.getNumSymbols());
984   auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
985   if (!cExpr)
986     return None;
987   return cExpr.getValue();
988 }
989 
990 // Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
991 // nest surrounding represented by slice loop bounds in 'slice'. Returns true
992 // on success, false otherwise (if a non-constant trip count was encountered).
993 // TODO: Make this work with non-unit step loops.
buildSliceTripCountMap(const ComputationSliceState & slice,llvm::SmallDenseMap<Operation *,uint64_t,8> * tripCountMap)994 bool mlir::buildSliceTripCountMap(
995     const ComputationSliceState &slice,
996     llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
997   unsigned numSrcLoopIVs = slice.ivs.size();
998   // Populate map from AffineForOp -> trip count
999   for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1000     AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
1001     auto *op = forOp.getOperation();
1002     AffineMap lbMap = slice.lbs[i];
1003     AffineMap ubMap = slice.ubs[i];
1004     // If lower or upper bound maps are null or provide no results, it implies
1005     // that source loop was not at all sliced, and the entire loop will be a
1006     // part of the slice.
1007     if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
1008         ubMap.getNumResults() == 0) {
1009       // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
1010       if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
1011         (*tripCountMap)[op] =
1012             forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
1013         continue;
1014       }
1015       Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
1016       if (maybeConstTripCount.hasValue()) {
1017         (*tripCountMap)[op] = maybeConstTripCount.getValue();
1018         continue;
1019       }
1020       return false;
1021     }
1022     Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
1023     // Slice bounds are created with a constant ub - lb difference.
1024     if (!tripCount.hasValue())
1025       return false;
1026     (*tripCountMap)[op] = tripCount.getValue();
1027   }
1028   return true;
1029 }
1030 
1031 // Return the number of iterations in the given slice.
getSliceIterationCount(const llvm::SmallDenseMap<Operation *,uint64_t,8> & sliceTripCountMap)1032 uint64_t mlir::getSliceIterationCount(
1033     const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
1034   uint64_t iterCount = 1;
1035   for (const auto &count : sliceTripCountMap) {
1036     iterCount *= count.second;
1037   }
1038   return iterCount;
1039 }
1040 
// Attribute used to mark a loop as a fusion barrier: sliced bounds are cleared
// from that loop inward (see loop below).
const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
// Computes slice bounds by projecting out any loop IVs from
// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
// bounds in 'sliceState' which represent the one loop nest's IVs in terms of
// the other loop nest's IVs, symbols and constants (using 'isBackwardsSlice').
void mlir::getComputationSliceState(
    Operation *depSourceOp, Operation *depSinkOp,
    FlatAffineValueConstraints *dependenceConstraints, unsigned loopDepth,
    bool isBackwardSlice, ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*depSourceOp, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getLoopIVs(*depSinkOp, &dstLoopIVs);
  unsigned numDstLoopIVs = dstLoopIVs.size();

  assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
         (isBackwardSlice && loopDepth <= numDstLoopIVs));

  // Project out dimensions other than those up to 'loopDepth'. For a backward
  // slice, the dst IVs beyond 'loopDepth' are projected out; for a forward
  // slice, the src IVs beyond 'loopDepth' are.
  unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
  unsigned num =
      isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
  dependenceConstraints->projectOut(pos, num);

  // Add slice loop IV values to 'sliceState'.
  unsigned offset = isBackwardSlice ? 0 : loopDepth;
  unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
  dependenceConstraints->getValues(offset, offset + numSliceLoopIVs,
                                   &sliceState->ivs);

  // Set up lower/upper bound affine maps for the slice.
  sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceState->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
  dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
                                        depSourceOp->getContext(),
                                        &sliceState->lbs, &sliceState->ubs);

  // Set up bound operands for the slice's lower and upper bounds: every id in
  // the constraint system other than the slice IVs themselves.
  SmallVector<Value, 4> sliceBoundOperands;
  unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds();
  for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
    if (i < offset || i >= offset + numSliceLoopIVs) {
      sliceBoundOperands.push_back(dependenceConstraints->getValue(i));
    }
  }

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Set destination loop nest insertion point to block start at 'dstLoopDepth'
  // for backward slices, or just before the terminator for forward slices.
  sliceState->insertPoint =
      isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
                      : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());

  llvm::SmallDenseSet<Value, 8> sequentialLoops;
  if (isa<AffineReadOpInterface>(depSourceOp) &&
      isa<AffineReadOpInterface>(depSinkOp)) {
    // For read-read access pairs, clear any slice bounds on sequential loops.
    // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
    getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
                       &sequentialLoops);
  }
  // Returns the i-th loop of the nest being sliced.
  auto getSliceLoop = [&](unsigned i) {
    return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
  };
  // Returns true if the slice is inserted at the innermost level of the
  // sliced nest.
  auto isInnermostInsertion = [&]() {
    return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
                            : loopDepth >= dstLoopIVs.size());
  };
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  // Returns true if the slice has exactly one iteration.
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(*sliceState, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Clear all sliced loop bounds beginning at the first sequential loop or the
  // first loop with a slice fusion barrier attribute.
  for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
    Value iv = getSliceLoop(i).getInductionVar();
    if (sequentialLoops.count(iv) == 0 &&
        getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
      continue;
    // Skip reset of bounds of reduction loop inserted in the destination loop
    // that meets the following conditions:
    //    1. Slice is a single trip count.
    //    2. Loop bounds of the source and destination match.
    //    3. Is being inserted at the innermost insertion point.
    Optional<bool> isMaximal = sliceState->isMaximal();
    if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
        isInnermostInsertion() && srcIsUnitSlice() && isMaximal.hasValue() &&
        isMaximal.getValue())
      continue;
    for (unsigned j = i; j < numSliceLoopIVs; ++j) {
      sliceState->lbs[j] = AffineMap();
      sliceState->ubs[j] = AffineMap();
    }
    break;
  }
}
1148 
/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
/// updates the slice loop bounds with any non-null bound maps specified in
/// 'sliceState', and inserts this slice into the loop nest surrounding
/// 'dstOpInst' at loop depth 'dstLoopDepth'. Returns the root of the cloned
/// slice nest, or a null AffineForOp on error.
// TODO: extend the slicing utility to compute slices that
// aren't necessarily a one-to-one relation b/w the source and destination. The
// relation between the source and destination could be many-to-many in general.
// TODO: the slice computation is incorrect in the cases
// where the dependence from the source to the destination does not cover the
// entire destination index set. Subtract out the dependent destination
// iterations from destination index set and check for emptiness --- this is one
// solution.
AffineForOp
mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
                                     unsigned dstLoopDepth,
                                     ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getLoopIVs(*dstOpInst, &dstLoopIVs);
  unsigned dstLoopIVsSize = dstLoopIVs.size();
  if (dstLoopDepth > dstLoopIVsSize) {
    dstOpInst->emitError("invalid destination loop depth");
    return AffineForOp();
  }

  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
  SmallVector<unsigned, 4> positions;
  // TODO: This code is incorrect since srcLoopIVs can be 0-d.
  findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);

  // Clone src loop nest and insert it at the beginning of the operation block
  // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
  auto sliceLoopNest =
      cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));

  // Locate the clone of 'srcOpInst' within the cloned nest using the block
  // positions recorded above.
  Operation *sliceInst =
      getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
  // Get loop nest surrounding 'sliceInst'.
  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
  getLoopIVs(*sliceInst, &sliceSurroundingLoops);

  // Sanity check.
  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
  (void)sliceSurroundingLoopsSize;
  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
  (void)sliceLoopLimit;
  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);

  // Update loop bounds for loops in 'sliceLoopNest' with any non-null bound
  // maps from 'sliceState'; null maps leave the cloned full bounds intact.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
    if (AffineMap lbMap = sliceState->lbs[i])
      forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
    if (AffineMap ubMap = sliceState->ubs[i])
      forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
  }
  return sliceLoopNest;
}
1215 
1216 // Constructs  MemRefAccess populating it with the memref, its indices and
1217 // opinst from 'loadOrStoreOpInst'.
MemRefAccess(Operation * loadOrStoreOpInst)1218 MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
1219   if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
1220     memref = loadOp.getMemRef();
1221     opInst = loadOrStoreOpInst;
1222     auto loadMemrefType = loadOp.getMemRefType();
1223     indices.reserve(loadMemrefType.getRank());
1224     for (auto index : loadOp.getMapOperands()) {
1225       indices.push_back(index);
1226     }
1227   } else {
1228     assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
1229            "Affine read/write op expected");
1230     auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
1231     opInst = loadOrStoreOpInst;
1232     memref = storeOp.getMemRef();
1233     auto storeMemrefType = storeOp.getMemRefType();
1234     indices.reserve(storeMemrefType.getRank());
1235     for (auto index : storeOp.getMapOperands()) {
1236       indices.push_back(index);
1237     }
1238   }
1239 }
1240 
getRank() const1241 unsigned MemRefAccess::getRank() const {
1242   return memref.getType().cast<MemRefType>().getRank();
1243 }
1244 
/// Returns true if this access is performed by an affine write (store) op.
bool MemRefAccess::isStore() const {
  return isa<AffineWriteOpInterface>(opInst);
}
1248 
1249 /// Returns the nesting depth of this statement, i.e., the number of loops
1250 /// surrounding this statement.
getNestingDepth(Operation * op)1251 unsigned mlir::getNestingDepth(Operation *op) {
1252   Operation *currOp = op;
1253   unsigned depth = 0;
1254   while ((currOp = currOp->getParentOp())) {
1255     if (isa<AffineForOp>(currOp))
1256       depth++;
1257   }
1258   return depth;
1259 }
1260 
1261 /// Equal if both affine accesses are provably equivalent (at compile
1262 /// time) when considering the memref, the affine maps and their respective
1263 /// operands. The equality of access functions + operands is checked by
1264 /// subtracting fully composed value maps, and then simplifying the difference
1265 /// using the expression flattener.
1266 /// TODO: this does not account for aliasing of memrefs.
operator ==(const MemRefAccess & rhs) const1267 bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
1268   if (memref != rhs.memref)
1269     return false;
1270 
1271   AffineValueMap diff, thisMap, rhsMap;
1272   getAccessMap(&thisMap);
1273   rhs.getAccessMap(&rhsMap);
1274   AffineValueMap::difference(thisMap, rhsMap, &diff);
1275   return llvm::all_of(diff.getAffineMap().getResults(),
1276                       [](AffineExpr e) { return e == 0; });
1277 }
1278 
1279 /// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
1280 /// where each lists loops from outer-most to inner-most in loop nest.
getNumCommonSurroundingLoops(Operation & A,Operation & B)1281 unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) {
1282   SmallVector<AffineForOp, 4> loopsA, loopsB;
1283   getLoopIVs(A, &loopsA);
1284   getLoopIVs(B, &loopsB);
1285 
1286   unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
1287   unsigned numCommonLoops = 0;
1288   for (unsigned i = 0; i < minNumLoops; ++i) {
1289     if (loopsA[i].getOperation() != loopsB[i].getOperation())
1290       break;
1291     ++numCommonLoops;
1292   }
1293   return numCommonLoops;
1294 }
1295 
/// Accumulates, over all affine read/write ops in the half-open op range
/// [start, end) of 'block', the byte sizes of the memref regions accessed,
/// unioning regions that access the same memref. Returns None if any region
/// cannot be computed or its size is unknown.
/// NOTE(review): 'memorySpace' is unused in this body — regions are summed
/// regardless of the memref's memory space; confirm whether filtering by
/// memory space was intended.
static Optional<int64_t> getMemoryFootprintBytes(Block &block,
                                                 Block::iterator start,
                                                 Block::iterator end,
                                                 int memorySpace) {
  SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;

  // Walk this 'affine.for' operation to gather all memory regions.
  auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
    if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
      // Neither load nor a store op.
      return WalkResult::advance();
    }

    // Compute the memref region symbolic in any IVs enclosing this block.
    auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
    if (failed(
            region->compute(opInst,
                            /*loopDepth=*/getNestingDepth(&*block.begin())))) {
      // emitError returns failure, which converts to an interrupting
      // WalkResult.
      return opInst->emitError("error obtaining memory region\n");
    }

    // Union this region with any previously recorded region for the same
    // memref.
    auto it = regions.find(region->memref);
    if (it == regions.end()) {
      regions[region->memref] = std::move(region);
    } else if (failed(it->second->unionBoundingBox(*region))) {
      return opInst->emitWarning(
          "getMemoryFootprintBytes: unable to perform a union on a memory "
          "region");
    }
    return WalkResult::advance();
  });
  // An interrupted walk means some region computation or union failed.
  if (result.wasInterrupted())
    return None;

  // Sum the sizes of the (unioned) regions, one per memref.
  int64_t totalSizeInBytes = 0;
  for (const auto &region : regions) {
    Optional<int64_t> size = region.second->getRegionSize();
    if (!size.hasValue())
      return None;
    totalSizeInBytes += size.getValue();
  }
  return totalSizeInBytes;
}
1339 
getMemoryFootprintBytes(AffineForOp forOp,int memorySpace)1340 Optional<int64_t> mlir::getMemoryFootprintBytes(AffineForOp forOp,
1341                                                 int memorySpace) {
1342   auto *forInst = forOp.getOperation();
1343   return ::getMemoryFootprintBytes(
1344       *forInst->getBlock(), Block::iterator(forInst),
1345       std::next(Block::iterator(forInst)), memorySpace);
1346 }
1347 
1348 /// Returns whether a loop is parallel and contains a reduction loop.
isLoopParallelAndContainsReduction(AffineForOp forOp)1349 bool mlir::isLoopParallelAndContainsReduction(AffineForOp forOp) {
1350   SmallVector<LoopReduction> reductions;
1351   if (!isLoopParallel(forOp, &reductions))
1352     return false;
1353   return !reductions.empty();
1354 }
1355 
1356 /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
1357 /// at 'forOp'.
getSequentialLoops(AffineForOp forOp,llvm::SmallDenseSet<Value,8> * sequentialLoops)1358 void mlir::getSequentialLoops(AffineForOp forOp,
1359                               llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
1360   forOp->walk([&](Operation *op) {
1361     if (auto innerFor = dyn_cast<AffineForOp>(op))
1362       if (!isLoopParallel(innerFor))
1363         sequentialLoops->insert(innerFor.getInductionVar());
1364   });
1365 }
1366 
simplifyIntegerSet(IntegerSet set)1367 IntegerSet mlir::simplifyIntegerSet(IntegerSet set) {
1368   FlatAffineConstraints fac(set);
1369   if (fac.isEmpty())
1370     return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
1371                                    set.getContext());
1372   fac.removeTrivialRedundancy();
1373 
1374   auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
1375   assert(simplifiedSet && "guaranteed to succeed while roundtripping");
1376   return simplifiedSet;
1377 }
1378