1 //===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous analysis routines for non-loop IR
10 // structures.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Analysis/Utils.h"
15 #include "mlir/Analysis/AffineAnalysis.h"
16 #include "mlir/Analysis/LoopAnalysis.h"
17 #include "mlir/Analysis/PresburgerSet.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/raw_ostream.h"
25
26 #define DEBUG_TYPE "analysis-utils"
27
28 using namespace mlir;
29
30 using llvm::SmallDenseMap;
31
32 /// Populates 'loops' with IVs of the loops surrounding 'op' ordered from
33 /// the outermost 'affine.for' operation to the innermost one.
getLoopIVs(Operation & op,SmallVectorImpl<AffineForOp> * loops)34 void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
35 auto *currOp = op.getParentOp();
36 AffineForOp currAffineForOp;
37 // Traverse up the hierarchy collecting all 'affine.for' operation while
38 // skipping over 'affine.if' operations.
39 while (currOp) {
40 if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
41 loops->push_back(currAffineForOp);
42 currOp = currOp->getParentOp();
43 }
44 std::reverse(loops->begin(), loops->end());
45 }
46
47 /// Populates 'ops' with IVs of the loops surrounding `op`, along with
48 /// `affine.if` operations interleaved between these loops, ordered from the
49 /// outermost `affine.for` operation to the innermost one.
getEnclosingAffineForAndIfOps(Operation & op,SmallVectorImpl<Operation * > * ops)50 void mlir::getEnclosingAffineForAndIfOps(Operation &op,
51 SmallVectorImpl<Operation *> *ops) {
52 ops->clear();
53 Operation *currOp = op.getParentOp();
54
55 // Traverse up the hierarchy collecting all `affine.for` and `affine.if`
56 // operations.
57 while (currOp) {
58 if (isa<AffineIfOp, AffineForOp>(currOp))
59 ops->push_back(currOp);
60 currOp = currOp->getParentOp();
61 }
62 std::reverse(ops->begin(), ops->end());
63 }
64
65 // Populates 'cst' with FlatAffineConstraints which represent original domain of
66 // the loop bounds that define 'ivs'.
67 LogicalResult
getSourceAsConstraints(FlatAffineConstraints & cst)68 ComputationSliceState::getSourceAsConstraints(FlatAffineConstraints &cst) {
69 assert(!ivs.empty() && "Cannot have a slice without its IVs");
70 cst.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, /*numLocals=*/0, ivs);
71 for (Value iv : ivs) {
72 AffineForOp loop = getForInductionVarOwner(iv);
73 assert(loop && "Expected affine for");
74 if (failed(cst.addAffineForOpDomain(loop)))
75 return failure();
76 }
77 return success();
78 }
79
// Populates 'cst' with FlatAffineConstraints which represent slice bounds.
// The source 'ivs' become the dimensions of 'cst'; the slice's bound operands
// (destination IVs and symbols) become its symbols.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension identifiers in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  SmallVector<Value, 4> values(ivs);
  // Append 'ivs' then 'operands' to 'values'; their order matches the
  // dim-then-symbol identifier order used in the reset below.
  values.append(lbOperands[0].begin(), lbOperands[0].end());
  cst->reset(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs of the destination
  // of fusion and equality constraints for symbols which are constants.
  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
    Value value = values[i];
    assert(cst->containsId(value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant; if so, pin it with an equality.
      if (auto cOp = value.getDefiningOp<ConstantIndexOp>())
        cst->setIdToConstant(value, cOp.getValue());
    } else if (auto loop = getForInductionVarOwner(value)) {
      // Destination loop IV: add its lower/upper bound constraints.
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }

  // Add slices bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'
  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}
116
117 // Clears state bounds and operand state.
clearBounds()118 void ComputationSliceState::clearBounds() {
119 lbs.clear();
120 ubs.clear();
121 lbOperands.clear();
122 ubOperands.clear();
123 }
124
dump() const125 void ComputationSliceState::dump() const {
126 llvm::errs() << "\tIVs:\n";
127 for (Value iv : ivs)
128 llvm::errs() << "\t\t" << iv << "\n";
129
130 llvm::errs() << "\tLBs:\n";
131 for (auto &en : llvm::enumerate(lbs)) {
132 llvm::errs() << "\t\t" << en.value() << "\n";
133 llvm::errs() << "\t\tOperands:\n";
134 for (Value lbOp : lbOperands[en.index()])
135 llvm::errs() << "\t\t\t" << lbOp << "\n";
136 }
137
138 llvm::errs() << "\tUBs:\n";
139 for (auto &en : llvm::enumerate(ubs)) {
140 llvm::errs() << "\t\t" << en.value() << "\n";
141 llvm::errs() << "\t\tOperands:\n";
142 for (Value ubOp : ubOperands[en.index()])
143 llvm::errs() << "\t\t\t" << ubOp << "\n";
144 }
145 }
146
/// Fast check to determine if the computation slice is maximal. Returns true if
/// each slice dimension maps to an existing dst dimension and both the src
/// and the dst loops for those dimensions have the same bounds. Returns false
/// if both the src and the dst loops don't have the same bounds. Returns
/// llvm::None if none of the above can be proven.
Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
  assert(lbs.size() == ubs.size() && lbs.size() && ivs.size() &&
         "Unexpected number of lbs, ubs and ivs in slice");

  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
    AffineMap lbMap = lbs[i];
    AffineMap ubMap = ubs[i];

    // Check if this slice is just an equality along this dimension, i.e.,
    // ub == lb + 1 (a single-iteration loop expressed in dst IVs).
    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
        ubMap.getNumResults() != 1 ||
        lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        lbMap.getResult(0).isa<AffineConstantExpr>())
      return llvm::None;

    // Limited support: we expect the lb result to be just a loop dimension for
    // now.
    AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
    if (!result)
      return llvm::None;

    // Retrieve dst loop bounds. The dim position of the lb result selects the
    // dst IV (from 'lbOperands') this slice dimension is pinned to.
    AffineForOp dstLoop =
        getForInductionVarOwner(lbOperands[i][result.getPosition()]);
    if (!dstLoop)
      return llvm::None;
    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
    AffineMap dstUbMap = dstLoop.getUpperBoundMap();

    // Retrieve src loop bounds.
    AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
    assert(srcLoop && "Expected affine for");
    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
    AffineMap srcUbMap = srcLoop.getUpperBoundMap();

    // Limited support: we expect simple src and dst loops with a single
    // constant component per bound for now.
    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
      return llvm::None;

    AffineExpr srcLbResult = srcLbMap.getResult(0);
    AffineExpr dstLbResult = dstLbMap.getResult(0);
    AffineExpr srcUbResult = srcUbMap.getResult(0);
    AffineExpr dstUbResult = dstUbMap.getResult(0);
    // Non-constant bounds can't be compared structurally here; bail out.
    if (!srcLbResult.isa<AffineConstantExpr>() ||
        !srcUbResult.isa<AffineConstantExpr>() ||
        !dstLbResult.isa<AffineConstantExpr>() ||
        !dstUbResult.isa<AffineConstantExpr>())
      return llvm::None;

    // Check if src and dst loop bounds are the same. If not, we can guarantee
    // that the slice is not maximal.
    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult)
      return false;
  }

  return true;
}
215
/// Returns true if it is deterministically verified that the original iteration
/// space of the slice is contained within the new iteration space that is
/// created after fusing 'this' slice into its destination. Returns false if
/// the slice provably accesses iterations outside the source domain, and
/// llvm::None if validity cannot be determined.
Optional<bool> ComputationSliceState::isSliceValid() {
  // Fast check to determine if the slice is valid. If the following conditions
  // are verified to be true, slice is declared valid by the fast check:
  // 1. Each slice loop is a single iteration loop bound in terms of a single
  //    destination loop IV.
  // 2. Loop bounds of the destination loop IV (from above) and those of the
  //    source loop IV are exactly the same.
  // If the fast check is inconclusive or false, we proceed with a more
  // expensive analysis.
  // TODO: Store the result of the fast check, as it might be used again in
  // `canRemoveSrcNodeAfterFusion`.
  Optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
  if (isValidFastCheck.hasValue() && isValidFastCheck.getValue())
    return true;

  // Create constraints for the source loop nest using which slice is computed.
  FlatAffineConstraints srcConstraints;
  // TODO: Store the source's domain to avoid computation at each depth.
  if (failed(getSourceAsConstraints(srcConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
    return llvm::None;
  }
  // As the set difference utility currently cannot handle symbols in its
  // operands, validity of the slice cannot be determined.
  if (srcConstraints.getNumSymbolIds() > 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
    return llvm::None;
  }
  // TODO: Handle local ids in the source domains while using the 'projectOut'
  // utility below. Currently, aligning is not done assuming that there will be
  // no local ids in the source domain.
  if (srcConstraints.getNumLocalIds() != 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
    return llvm::None;
  }

  // Create constraints for the slice loop nest that would be created if the
  // fusion succeeds.
  FlatAffineConstraints sliceConstraints;
  if (failed(getAsConstraints(&sliceConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
    return llvm::None;
  }

  // Projecting out every dimension other than the 'ivs' to express slice's
  // domain completely in terms of source's IVs.
  sliceConstraints.projectOut(ivs.size(),
                              sliceConstraints.getNumIds() - ivs.size());

  LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
  LLVM_DEBUG(srcConstraints.dump());
  LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
                             "(expressed in terms of its source's IVs):\n");
  LLVM_DEBUG(sliceConstraints.dump());

  // TODO: Store 'srcSet' to avoid recalculating for each depth.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = sliceSet.subtract(srcSet);

  // A non-empty (slice - source) difference means the slice touches
  // iterations outside the source domain, i.e., it is invalid.
  if (!diffSet.isIntegerEmpty()) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
    return false;
  }
  return true;
}
285
/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns llvm::None if it
/// cannot determine if the slice is maximal or not.
Optional<bool> ComputationSliceState::isMaximal() const {
  // Fast check to determine if the computation slice is maximal. If the result
  // is inconclusive, we proceed with a more expensive analysis.
  Optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
  if (isMaximalFastCheck.hasValue())
    return isMaximalFastCheck;

  // Create constraints for the src loop nest being sliced.
  FlatAffineConstraints srcConstraints;
  srcConstraints.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0,
                       /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(srcConstraints.addAffineForOpDomain(loop)))
      return llvm::None;
  }

  // Create constraints for the slice using the dst loop nest information. We
  // retrieve existing dst loops from the lbOperands.
  SmallVector<Value, 8> consumerIVs;
  for (Value lbOp : lbOperands[0])
    if (getForInductionVarOwner(lbOp))
      consumerIVs.push_back(lbOp);

  // Add empty IV Values for those new loops that are not equalities and,
  // therefore, are not yet materialized in the IR.
  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
    consumerIVs.push_back(Value());

  FlatAffineConstraints sliceConstraints;
  sliceConstraints.reset(/*numDims=*/consumerIVs.size(), /*numSymbols=*/0,
                         /*numLocals=*/0, consumerIVs);

  // Build the slice's domain from its lower/upper bound maps.
  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
    return llvm::None;

  if (srcConstraints.getNumDimIds() != sliceConstraints.getNumDimIds())
    // Constraint dims are different. The integer set difference can't be
    // computed so we don't know if the slice is maximal.
    return llvm::None;

  // Compute the difference between the src loop nest and the slice integer
  // sets. The slice is maximal iff (src - slice) is empty, i.e., the slice
  // covers every iteration of the source nest.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = srcSet.subtract(sliceSet);
  return diffSet.isIntegerEmpty();
}
338
getRank() const339 unsigned MemRefRegion::getRank() const {
340 return memref.getType().cast<MemRefType>().getRank();
341 }
342
getConstantBoundingSizeAndShape(SmallVectorImpl<int64_t> * shape,std::vector<SmallVector<int64_t,4>> * lbs,SmallVectorImpl<int64_t> * lbDivisors) const343 Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
344 SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
345 SmallVectorImpl<int64_t> *lbDivisors) const {
346 auto memRefType = memref.getType().cast<MemRefType>();
347 unsigned rank = memRefType.getRank();
348 if (shape)
349 shape->reserve(rank);
350
351 assert(rank == cst.getNumDimIds() && "inconsistent memref region");
352
353 // Use a copy of the region constraints that has upper/lower bounds for each
354 // memref dimension with static size added to guard against potential
355 // over-approximation from projection or union bounding box. We may not add
356 // this on the region itself since they might just be redundant constraints
357 // that will need non-trivials means to eliminate.
358 FlatAffineConstraints cstWithShapeBounds(cst);
359 for (unsigned r = 0; r < rank; r++) {
360 cstWithShapeBounds.addConstantLowerBound(r, 0);
361 int64_t dimSize = memRefType.getDimSize(r);
362 if (ShapedType::isDynamic(dimSize))
363 continue;
364 cstWithShapeBounds.addConstantUpperBound(r, dimSize - 1);
365 }
366
367 // Find a constant upper bound on the extent of this memref region along each
368 // dimension.
369 int64_t numElements = 1;
370 int64_t diffConstant;
371 int64_t lbDivisor;
372 for (unsigned d = 0; d < rank; d++) {
373 SmallVector<int64_t, 4> lb;
374 Optional<int64_t> diff =
375 cstWithShapeBounds.getConstantBoundOnDimSize(d, &lb, &lbDivisor);
376 if (diff.hasValue()) {
377 diffConstant = diff.getValue();
378 assert(diffConstant >= 0 && "Dim size bound can't be negative");
379 assert(lbDivisor > 0);
380 } else {
381 // If no constant bound is found, then it can always be bound by the
382 // memref's dim size if the latter has a constant size along this dim.
383 auto dimSize = memRefType.getDimSize(d);
384 if (dimSize == -1)
385 return None;
386 diffConstant = dimSize;
387 // Lower bound becomes 0.
388 lb.resize(cstWithShapeBounds.getNumSymbolIds() + 1, 0);
389 lbDivisor = 1;
390 }
391 numElements *= diffConstant;
392 if (lbs) {
393 lbs->push_back(lb);
394 assert(lbDivisors && "both lbs and lbDivisor or none");
395 lbDivisors->push_back(lbDivisor);
396 }
397 if (shape) {
398 shape->push_back(diffConstant);
399 }
400 }
401 return numElements;
402 }
403
getLowerAndUpperBound(unsigned pos,AffineMap & lbMap,AffineMap & ubMap) const404 void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
405 AffineMap &ubMap) const {
406 assert(pos < cst.getNumDimIds() && "invalid position");
407 auto memRefType = memref.getType().cast<MemRefType>();
408 unsigned rank = memRefType.getRank();
409
410 assert(rank == cst.getNumDimIds() && "inconsistent memref region");
411
412 auto boundPairs = cst.getLowerAndUpperBound(
413 pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolIds(),
414 /*localExprs=*/{}, memRefType.getContext());
415 lbMap = boundPairs.first;
416 ubMap = boundPairs.second;
417 assert(lbMap && "lower bound for a region must exist");
418 assert(ubMap && "upper bound for a region must exist");
419 assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolIds() - rank);
420 assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolIds() - rank);
421 }
422
unionBoundingBox(const MemRefRegion & other)423 LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
424 assert(memref == other.memref);
425 return cst.unionBoundingBox(*other.getConstraints());
426 }
427
428 /// Computes the memory region accessed by this memref with the region
429 /// represented as constraints symbolic/parametric in 'loopDepth' loops
430 /// surrounding opInst and any additional Function symbols.
431 // For example, the memref region for this load operation at loopDepth = 1 will
432 // be as below:
433 //
434 // affine.for %i = 0 to 32 {
435 // affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
436 // load %A[%ii]
437 // }
438 // }
439 //
440 // region: {memref = %A, write = false, {%i <= m0 <= %i + 7} }
441 // The last field is a 2-d FlatAffineConstraints symbolic in %i.
442 //
443 // TODO: extend this to any other memref dereferencing ops
444 // (dma_start, dma_wait).
compute(Operation * op,unsigned loopDepth,const ComputationSliceState * sliceState,bool addMemRefDimBounds)445 LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
446 const ComputationSliceState *sliceState,
447 bool addMemRefDimBounds) {
448 assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
449 "affine read/write op expected");
450
451 MemRefAccess access(op);
452 memref = access.memref;
453 write = access.isStore();
454
455 unsigned rank = access.getRank();
456
457 LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
458 << "depth: " << loopDepth << "\n";);
459
460 // 0-d memrefs.
461 if (rank == 0) {
462 SmallVector<AffineForOp, 4> ivs;
463 getLoopIVs(*op, &ivs);
464 assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
465 // The first 'loopDepth' IVs are symbols for this region.
466 ivs.resize(loopDepth);
467 SmallVector<Value, 4> regionSymbols;
468 extractForInductionVars(ivs, ®ionSymbols);
469 // A 0-d memref has a 0-d region.
470 cst.reset(rank, loopDepth, /*numLocals=*/0, regionSymbols);
471 return success();
472 }
473
474 // Build the constraints for this region.
475 AffineValueMap accessValueMap;
476 access.getAccessMap(&accessValueMap);
477 AffineMap accessMap = accessValueMap.getAffineMap();
478
479 unsigned numDims = accessMap.getNumDims();
480 unsigned numSymbols = accessMap.getNumSymbols();
481 unsigned numOperands = accessValueMap.getNumOperands();
482 // Merge operands with slice operands.
483 SmallVector<Value, 4> operands;
484 operands.resize(numOperands);
485 for (unsigned i = 0; i < numOperands; ++i)
486 operands[i] = accessValueMap.getOperand(i);
487
488 if (sliceState != nullptr) {
489 operands.reserve(operands.size() + sliceState->lbOperands[0].size());
490 // Append slice operands to 'operands' as symbols.
491 for (auto extraOperand : sliceState->lbOperands[0]) {
492 if (!llvm::is_contained(operands, extraOperand)) {
493 operands.push_back(extraOperand);
494 numSymbols++;
495 }
496 }
497 }
498 // We'll first associate the dims and symbols of the access map to the dims
499 // and symbols resp. of cst. This will change below once cst is
500 // fully constructed out.
501 cst.reset(numDims, numSymbols, 0, operands);
502
503 // Add equality constraints.
504 // Add inequalities for loop lower/upper bounds.
505 for (unsigned i = 0; i < numDims + numSymbols; ++i) {
506 auto operand = operands[i];
507 if (auto loop = getForInductionVarOwner(operand)) {
508 // Note that cst can now have more dimensions than accessMap if the
509 // bounds expressions involve outer loops or other symbols.
510 // TODO: rewrite this to use getInstIndexSet; this way
511 // conditionals will be handled when the latter supports it.
512 if (failed(cst.addAffineForOpDomain(loop)))
513 return failure();
514 } else {
515 // Has to be a valid symbol.
516 auto symbol = operand;
517 assert(isValidSymbol(symbol));
518 // Check if the symbol is a constant.
519 if (auto *op = symbol.getDefiningOp()) {
520 if (auto constOp = dyn_cast<ConstantIndexOp>(op)) {
521 cst.setIdToConstant(symbol, constOp.getValue());
522 }
523 }
524 }
525 }
526
527 // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
528 if (sliceState != nullptr) {
529 // Add dim and symbol slice operands.
530 for (auto operand : sliceState->lbOperands[0]) {
531 cst.addInductionVarOrTerminalSymbol(operand);
532 }
533 // Add upper/lower bounds from 'sliceState' to 'cst'.
534 LogicalResult ret =
535 cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
536 sliceState->lbOperands[0]);
537 assert(succeeded(ret) &&
538 "should not fail as we never have semi-affine slice maps");
539 (void)ret;
540 }
541
542 // Add access function equalities to connect loop IVs to data dimensions.
543 if (failed(cst.composeMap(&accessValueMap))) {
544 op->emitError("getMemRefRegion: compose affine map failed");
545 LLVM_DEBUG(accessValueMap.getAffineMap().dump());
546 return failure();
547 }
548
549 // Set all identifiers appearing after the first 'rank' identifiers as
550 // symbolic identifiers - so that the ones corresponding to the memref
551 // dimensions are the dimensional identifiers for the memref region.
552 cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank);
553
554 // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
555 // this memref region is symbolic.
556 SmallVector<AffineForOp, 4> enclosingIVs;
557 getLoopIVs(*op, &enclosingIVs);
558 assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
559 enclosingIVs.resize(loopDepth);
560 SmallVector<Value, 4> ids;
561 cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids);
562 for (auto id : ids) {
563 AffineForOp iv;
564 if ((iv = getForInductionVarOwner(id)) &&
565 llvm::is_contained(enclosingIVs, iv) == false) {
566 cst.projectOut(id);
567 }
568 }
569
570 // Project out any local variables (these would have been added for any
571 // mod/divs).
572 cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds());
573
574 // Constant fold any symbolic identifiers.
575 cst.constantFoldIdRange(/*pos=*/cst.getNumDimIds(),
576 /*num=*/cst.getNumSymbolIds());
577
578 assert(cst.getNumDimIds() == rank && "unexpected MemRefRegion format");
579
580 // Add upper/lower bounds for each memref dimension with static size
581 // to guard against potential over-approximation from projection.
582 // TODO: Support dynamic memref dimensions.
583 if (addMemRefDimBounds) {
584 auto memRefType = memref.getType().cast<MemRefType>();
585 for (unsigned r = 0; r < rank; r++) {
586 cst.addConstantLowerBound(/*pos=*/r, /*lb=*/0);
587 if (memRefType.isDynamicDim(r))
588 continue;
589 cst.addConstantUpperBound(/*pos=*/r, memRefType.getDimSize(r) - 1);
590 }
591 }
592 cst.removeTrivialRedundancy();
593
594 LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
595 LLVM_DEBUG(cst.dump());
596 return success();
597 }
598
getMemRefEltSizeInBytes(MemRefType memRefType)599 static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
600 auto elementType = memRefType.getElementType();
601
602 unsigned sizeInBits;
603 if (elementType.isIntOrFloat()) {
604 sizeInBits = elementType.getIntOrFloatBitWidth();
605 } else {
606 auto vectorType = elementType.cast<VectorType>();
607 sizeInBits =
608 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
609 }
610 return llvm::divideCeil(sizeInBits, 8);
611 }
612
613 // Returns the size of the region.
getRegionSize()614 Optional<int64_t> MemRefRegion::getRegionSize() {
615 auto memRefType = memref.getType().cast<MemRefType>();
616
617 auto layoutMaps = memRefType.getAffineMaps();
618 if (layoutMaps.size() > 1 ||
619 (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
620 LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
621 return false;
622 }
623
624 // Indices to use for the DmaStart op.
625 // Indices for the original memref being DMAed from/to.
626 SmallVector<Value, 4> memIndices;
627 // Indices for the faster buffer being DMAed into/from.
628 SmallVector<Value, 4> bufIndices;
629
630 // Compute the extents of the buffer.
631 Optional<int64_t> numElements = getConstantBoundingSizeAndShape();
632 if (!numElements.hasValue()) {
633 LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
634 return None;
635 }
636 return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
637 }
638
639 /// Returns the size of memref data in bytes if it's statically shaped, None
640 /// otherwise. If the element of the memref has vector type, takes into account
641 /// size of the vector as well.
642 // TODO: improve/complete this when we have target data.
getMemRefSizeInBytes(MemRefType memRefType)643 Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
644 if (!memRefType.hasStaticShape())
645 return None;
646 auto elementType = memRefType.getElementType();
647 if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
648 return None;
649
650 uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
651 for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
652 sizeInBytes = sizeInBytes * memRefType.getDimSize(i);
653 }
654 return sizeInBytes;
655 }
656
657 template <typename LoadOrStoreOp>
boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,bool emitError)658 LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
659 bool emitError) {
660 static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
661 AffineWriteOpInterface>::value,
662 "argument should be either a AffineReadOpInterface or a "
663 "AffineWriteOpInterface");
664
665 Operation *op = loadOrStoreOp.getOperation();
666 MemRefRegion region(op->getLoc());
667 if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
668 /*addMemRefDimBounds=*/false)))
669 return success();
670
671 LLVM_DEBUG(llvm::dbgs() << "Memory region");
672 LLVM_DEBUG(region.getConstraints()->dump());
673
674 bool outOfBounds = false;
675 unsigned rank = loadOrStoreOp.getMemRefType().getRank();
676
677 // For each dimension, check for out of bounds.
678 for (unsigned r = 0; r < rank; r++) {
679 FlatAffineConstraints ucst(*region.getConstraints());
680
681 // Intersect memory region with constraint capturing out of bounds (both out
682 // of upper and out of lower), and check if the constraint system is
683 // feasible. If it is, there is at least one point out of bounds.
684 SmallVector<int64_t, 4> ineq(rank + 1, 0);
685 int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
686 // TODO: handle dynamic dim sizes.
687 if (dimSize == -1)
688 continue;
689
690 // Check for overflow: d_i >= memref dim size.
691 ucst.addConstantLowerBound(r, dimSize);
692 outOfBounds = !ucst.isEmpty();
693 if (outOfBounds && emitError) {
694 loadOrStoreOp.emitOpError()
695 << "memref out of upper bound access along dimension #" << (r + 1);
696 }
697
698 // Check for a negative index.
699 FlatAffineConstraints lcst(*region.getConstraints());
700 std::fill(ineq.begin(), ineq.end(), 0);
701 // d_i <= -1;
702 lcst.addConstantUpperBound(r, -1);
703 outOfBounds = !lcst.isEmpty();
704 if (outOfBounds && emitError) {
705 loadOrStoreOp.emitOpError()
706 << "memref out of lower bound access along dimension #" << (r + 1);
707 }
708 }
709 return failure(outOfBounds);
710 }
711
712 // Explicitly instantiate the template so that the compiler knows we need them!
713 template LogicalResult
714 mlir::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp, bool emitError);
715 template LogicalResult
716 mlir::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp, bool emitError);
717
718 // Returns in 'positions' the Block positions of 'op' in each ancestor
719 // Block from the Block containing operation, stopping at 'limitBlock'.
findInstPosition(Operation * op,Block * limitBlock,SmallVectorImpl<unsigned> * positions)720 static void findInstPosition(Operation *op, Block *limitBlock,
721 SmallVectorImpl<unsigned> *positions) {
722 Block *block = op->getBlock();
723 while (block != limitBlock) {
724 // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
725 // rely on linear scans.
726 int instPosInBlock = std::distance(block->begin(), op->getIterator());
727 positions->push_back(instPosInBlock);
728 op = block->getParentOp();
729 block = op->getBlock();
730 }
731 std::reverse(positions->begin(), positions->end());
732 }
733
734 // Returns the Operation in a possibly nested set of Blocks, where the
735 // position of the operation is represented by 'positions', which has a
736 // Block position for each level of nesting.
getInstAtPosition(ArrayRef<unsigned> positions,unsigned level,Block * block)737 static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
738 unsigned level, Block *block) {
739 unsigned i = 0;
740 for (auto &op : *block) {
741 if (i != positions[level]) {
742 ++i;
743 continue;
744 }
745 if (level == positions.size() - 1)
746 return &op;
747 if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
748 return getInstAtPosition(positions, level + 1,
749 childAffineForOp.getBody());
750
751 for (auto ®ion : op.getRegions()) {
752 for (auto &b : region)
753 if (auto *ret = getInstAtPosition(positions, level + 1, &b))
754 return ret;
755 }
756 return nullptr;
757 }
758 return nullptr;
759 }
760
761 // Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
addMissingLoopIVBounds(SmallPtrSet<Value,8> & ivs,FlatAffineConstraints * cst)762 static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
763 FlatAffineConstraints *cst) {
764 for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
765 auto value = cst->getIdValue(i);
766 if (ivs.count(value) == 0) {
767 assert(isForInductionVar(value));
768 auto loop = getForInductionVarOwner(value);
769 if (failed(cst->addAffineForOpDomain(loop)))
770 return failure();
771 }
772 }
773 return success();
774 }
775
776 /// Returns the innermost common loop depth for the set of operations in 'ops'.
777 // TODO: Move this to LoopUtils.
getInnermostCommonLoopDepth(ArrayRef<Operation * > ops,SmallVectorImpl<AffineForOp> * surroundingLoops)778 unsigned mlir::getInnermostCommonLoopDepth(
779 ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
780 unsigned numOps = ops.size();
781 assert(numOps > 0 && "Expected at least one operation");
782
783 std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
784 unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
785 for (unsigned i = 0; i < numOps; ++i) {
786 getLoopIVs(*ops[i], &loops[i]);
787 loopDepthLimit =
788 std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
789 }
790
791 unsigned loopDepth = 0;
792 for (unsigned d = 0; d < loopDepthLimit; ++d) {
793 unsigned i;
794 for (i = 1; i < numOps; ++i) {
795 if (loops[i - 1][d] != loops[i][d])
796 return loopDepth;
797 }
798 if (surroundingLoops)
799 surroundingLoops->push_back(loops[i - 1][d]);
800 ++loopDepth;
801 }
802 return loopDepth;
803 }
804
/// Computes in 'sliceUnion' the union of all slice bounds computed at
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
/// union was computed correctly, an appropriate failure otherwise.
/// 'numCommonLoops' is the number of loops surrounding both op lists;
/// 'isBackwardSlice' selects which nest is sliced into the other.
SliceComputationResult
mlir::computeSliceUnion(ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB,
                        unsigned loopDepth, unsigned numCommonLoops,
                        bool isBackwardSlice,
                        ComputationSliceState *sliceUnion) {
  // Compute the union of slice bounds between all pairs in 'opsA' and
  // 'opsB' in 'sliceUnionCst'.
  FlatAffineConstraints sliceUnionCst;
  assert(sliceUnionCst.getNumDimAndSymbolIds() == 0);
  std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
  for (unsigned i = 0, numOpsA = opsA.size(); i < numOpsA; ++i) {
    MemRefAccess srcAccess(opsA[i]);
    for (unsigned j = 0, numOpsB = opsB.size(); j < numOpsB; ++j) {
      MemRefAccess dstAccess(opsB[j]);
      // Only accesses to the same memref can be dependent.
      if (srcAccess.memref != dstAccess.memref)
        continue;
      // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
      if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) ||
          (isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) {
        LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
        return SliceComputationResult::GenericFailure;
      }

      // Read-after-read dependences are only admitted when both accesses are
      // loads (see 'allowRAR' below).
      bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
                              isa<AffineReadOpInterface>(dstAccess.opInst);
      FlatAffineConstraints dependenceConstraints;
      // Check dependence between 'srcAccess' and 'dstAccess'.
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
          &dependenceConstraints, /*dependenceComponents=*/nullptr,
          /*allowRAR=*/readReadAccesses);
      if (result.value == DependenceResult::Failure) {
        LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
        return SliceComputationResult::GenericFailure;
      }
      if (result.value == DependenceResult::NoDependence)
        continue;
      // Remember the dependent pair; used below to compute insertion points.
      dependentOpPairs.push_back({opsA[i], opsB[j]});

      // Compute slice bounds for 'srcAccess' and 'dstAccess'.
      ComputationSliceState tmpSliceState;
      mlir::getComputationSliceState(opsA[i], opsB[j], &dependenceConstraints,
                                     loopDepth, isBackwardSlice,
                                     &tmpSliceState);

      if (sliceUnionCst.getNumDimAndSymbolIds() == 0) {
        // Initialize 'sliceUnionCst' with the bounds computed in previous step.
        if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Unable to compute slice bound constraints\n");
          return SliceComputationResult::GenericFailure;
        }
        assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
        continue;
      }

      // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
      FlatAffineConstraints tmpSliceCst;
      if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute slice bound constraints\n");
        return SliceComputationResult::GenericFailure;
      }

      // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
      if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) {

        // Pre-constraint id alignment: record loop IVs used in each constraint
        // system.
        SmallPtrSet<Value, 8> sliceUnionIVs;
        for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k)
          sliceUnionIVs.insert(sliceUnionCst.getIdValue(k));
        SmallPtrSet<Value, 8> tmpSliceIVs;
        for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k)
          tmpSliceIVs.insert(tmpSliceCst.getIdValue(k));

        sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst);

        // Post-constraint id alignment: add loop IV bounds missing after
        // id alignment to constraint systems. This can occur if one constraint
        // system uses an loop IV that is not used by the other. The call
        // to unionBoundingBox below expects constraints for each Loop IV, even
        // if they are the unsliced full loop bounds added here.
        if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
          return SliceComputationResult::GenericFailure;
        if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
          return SliceComputationResult::GenericFailure;
      }
      // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
      // Local (existentially quantified) ids cause a bailout here —
      // presumably unionBoundingBox cannot handle them; TODO confirm.
      if (sliceUnionCst.getNumLocalIds() > 0 ||
          tmpSliceCst.getNumLocalIds() > 0 ||
          failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute union bounding box of slice bounds\n");
        return SliceComputationResult::GenericFailure;
      }
    }
  }

  // Empty union: no dependent pair produced any slice bounds.
  if (sliceUnionCst.getNumDimAndSymbolIds() == 0)
    return SliceComputationResult::GenericFailure;

  // Gather loops surrounding ops from loop nest where slice will be inserted.
  SmallVector<Operation *, 4> ops;
  for (auto &dep : dependentOpPairs) {
    ops.push_back(isBackwardSlice ? dep.second : dep.first);
  }
  SmallVector<AffineForOp, 4> surroundingLoops;
  unsigned innermostCommonLoopDepth =
      getInnermostCommonLoopDepth(ops, &surroundingLoops);
  if (loopDepth > innermostCommonLoopDepth) {
    LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
    return SliceComputationResult::GenericFailure;
  }

  // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
  unsigned numSliceLoopIVs = sliceUnionCst.getNumDimIds();

  // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
  sliceUnionCst.convertLoopIVSymbolsToDims();
  sliceUnion->clearBounds();
  sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get slice bounds from slice union constraints 'sliceUnionCst'.
  sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
                               opsA[0]->getContext(), &sliceUnion->lbs,
                               &sliceUnion->ubs);

  // Add slice bound operands of union: every id beyond the sliced loop IVs
  // becomes a bound operand.
  SmallVector<Value, 4> sliceBoundOperands;
  sliceUnionCst.getIdValues(numSliceLoopIVs,
                            sliceUnionCst.getNumDimAndSymbolIds(),
                            &sliceBoundOperands);

  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
  sliceUnion->ivs.clear();
  sliceUnionCst.getIdValues(0, numSliceLoopIVs, &sliceUnion->ivs);

  // Set loop nest insertion point to block start at 'loopDepth'.
  sliceUnion->insertPoint =
      isBackwardSlice
          ? surroundingLoops[loopDepth - 1].getBody()->begin()
          : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Check if the slice computed is valid. Return success only if it is verified
  // that the slice is valid, otherwise return appropriate failure status.
  Optional<bool> isSliceValid = sliceUnion->isSliceValid();
  if (!isSliceValid.hasValue()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
    return SliceComputationResult::GenericFailure;
  }
  if (!isSliceValid.getValue())
    return SliceComputationResult::IncorrectSliceFailure;

  return SliceComputationResult::Success;
}
972
973 // TODO: extend this to handle multiple result maps.
getConstDifference(AffineMap lbMap,AffineMap ubMap)974 static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
975 assert(lbMap.getNumResults() == 1 && "expected single result bound map");
976 assert(ubMap.getNumResults() == 1 && "expected single result bound map");
977 assert(lbMap.getNumDims() == ubMap.getNumDims());
978 assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
979 AffineExpr lbExpr(lbMap.getResult(0));
980 AffineExpr ubExpr(ubMap.getResult(0));
981 auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
982 lbMap.getNumSymbols());
983 auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
984 if (!cExpr)
985 return None;
986 return cExpr.getValue();
987 }
988
989 // Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
990 // nest surrounding represented by slice loop bounds in 'slice'. Returns true
991 // on success, false otherwise (if a non-constant trip count was encountered).
992 // TODO: Make this work with non-unit step loops.
buildSliceTripCountMap(const ComputationSliceState & slice,llvm::SmallDenseMap<Operation *,uint64_t,8> * tripCountMap)993 bool mlir::buildSliceTripCountMap(
994 const ComputationSliceState &slice,
995 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
996 unsigned numSrcLoopIVs = slice.ivs.size();
997 // Populate map from AffineForOp -> trip count
998 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
999 AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
1000 auto *op = forOp.getOperation();
1001 AffineMap lbMap = slice.lbs[i];
1002 AffineMap ubMap = slice.ubs[i];
1003 // If lower or upper bound maps are null or provide no results, it implies
1004 // that source loop was not at all sliced, and the entire loop will be a
1005 // part of the slice.
1006 if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
1007 ubMap.getNumResults() == 0) {
1008 // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
1009 if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
1010 (*tripCountMap)[op] =
1011 forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
1012 continue;
1013 }
1014 Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
1015 if (maybeConstTripCount.hasValue()) {
1016 (*tripCountMap)[op] = maybeConstTripCount.getValue();
1017 continue;
1018 }
1019 return false;
1020 }
1021 Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
1022 // Slice bounds are created with a constant ub - lb difference.
1023 if (!tripCount.hasValue())
1024 return false;
1025 (*tripCountMap)[op] = tripCount.getValue();
1026 }
1027 return true;
1028 }
1029
1030 // Return the number of iterations in the given slice.
getSliceIterationCount(const llvm::SmallDenseMap<Operation *,uint64_t,8> & sliceTripCountMap)1031 uint64_t mlir::getSliceIterationCount(
1032 const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
1033 uint64_t iterCount = 1;
1034 for (const auto &count : sliceTripCountMap) {
1035 iterCount *= count.second;
1036 }
1037 return iterCount;
1038 }
1039
// Attribute which, when present on an 'affine.for', stops slice bound
// propagation at that loop (see the bound-clearing loop below).
const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
// Computes slice bounds by projecting out any loop IVs from
// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
// bounds in 'sliceState' which represent the one loop nest's IVs in terms of
// the other loop nest's IVs, symbols and constants (using 'isBackwardsSlice').
void mlir::getComputationSliceState(
    Operation *depSourceOp, Operation *depSinkOp,
    FlatAffineConstraints *dependenceConstraints, unsigned loopDepth,
    bool isBackwardSlice, ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*depSourceOp, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getLoopIVs(*depSinkOp, &dstLoopIVs);
  unsigned numDstLoopIVs = dstLoopIVs.size();

  assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
         (isBackwardSlice && loopDepth <= numDstLoopIVs));

  // Project out dimensions other than those up to 'loopDepth'. For a backward
  // slice the dst IVs follow the src IVs in the constraint system (and vice
  // versa for a forward slice), hence the offset arithmetic below.
  unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
  unsigned num =
      isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
  dependenceConstraints->projectOut(pos, num);

  // Add slice loop IV values to 'sliceState'. The sliced nest is the src for
  // backward slices and the dst for forward slices.
  unsigned offset = isBackwardSlice ? 0 : loopDepth;
  unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
  dependenceConstraints->getIdValues(offset, offset + numSliceLoopIVs,
                                     &sliceState->ivs);

  // Set up lower/upper bound affine maps for the slice.
  sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceState->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
  dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
                                        depSourceOp->getContext(),
                                        &sliceState->lbs, &sliceState->ubs);

  // Set up bound operands for the slice's lower and upper bounds: every id
  // outside the sliced IV range becomes a bound operand.
  SmallVector<Value, 4> sliceBoundOperands;
  unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds();
  for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
    if (i < offset || i >= offset + numSliceLoopIVs) {
      sliceBoundOperands.push_back(dependenceConstraints->getIdValue(i));
    }
  }

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Set destination loop nest insertion point to block start at 'dstLoopDepth'.
  sliceState->insertPoint =
      isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
                      : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());

  llvm::SmallDenseSet<Value, 8> sequentialLoops;
  if (isa<AffineReadOpInterface>(depSourceOp) &&
      isa<AffineReadOpInterface>(depSinkOp)) {
    // For read-read access pairs, clear any slice bounds on sequential loops.
    // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
    getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
                       &sequentialLoops);
  }
  // Returns the i-th loop of the nest being sliced.
  auto getSliceLoop = [&](unsigned i) {
    return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
  };
  // True when the slice is inserted below all loops of the sliced nest.
  auto isInnermostInsertion = [&]() {
    return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
                            : loopDepth >= dstLoopIVs.size());
  };
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  // True when the computed slice spans exactly one iteration overall.
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(*sliceState, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Clear all sliced loop bounds beginning at the first sequential loop, or
  // first loop with a slice fusion barrier attribute.

  for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
    Value iv = getSliceLoop(i).getInductionVar();
    if (sequentialLoops.count(iv) == 0 &&
        getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
      continue;
    // Skip reset of bounds of reduction loop inserted in the destination loop
    // that meets the following conditions:
    // 1. Slice is single trip count.
    // 2. Loop bounds of the source and destination match.
    // 3. Is being inserted at the innermost insertion point.
    Optional<bool> isMaximal = sliceState->isMaximal();
    if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
        isInnermostInsertion() && srcIsUnitSlice() && isMaximal.hasValue() &&
        isMaximal.getValue())
      continue;
    // Reset bounds from this loop inwards and stop scanning.
    for (unsigned j = i; j < numSliceLoopIVs; ++j) {
      sliceState->lbs[j] = AffineMap();
      sliceState->ubs[j] = AffineMap();
    }
    break;
  }
}
1147
/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
/// updates the slice loop bounds with any non-null bound maps specified in
/// 'sliceState', and inserts this slice into the loop nest surrounding
/// 'dstOpInst' at loop depth 'dstLoopDepth'. Returns the root of the cloned
/// slice loop nest, or a null AffineForOp on error.
// TODO: extend the slicing utility to compute slices that
// aren't necessarily a one-to-one relation b/w the source and destination. The
// relation between the source and destination could be many-to-many in general.
// TODO: the slice computation is incorrect in the cases
// where the dependence from the source to the destination does not cover the
// entire destination index set. Subtract out the dependent destination
// iterations from destination index set and check for emptiness --- this is one
// solution.
AffineForOp
mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
                                     unsigned dstLoopDepth,
                                     ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getLoopIVs(*dstOpInst, &dstLoopIVs);
  unsigned dstLoopIVsSize = dstLoopIVs.size();
  if (dstLoopDepth > dstLoopIVsSize) {
    dstOpInst->emitError("invalid destination loop depth");
    return AffineForOp();
  }

  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
  SmallVector<unsigned, 4> positions;
  // TODO: This code is incorrect since srcLoopIVs can be 0-d.
  findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);

  // Clone src loop nest and insert it at the beginning of the operation block
  // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
  auto sliceLoopNest =
      cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));

  // Locate the clone of 'srcOpInst' inside the cloned nest via 'positions'.
  Operation *sliceInst =
      getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
  // Get loop nest surrounding 'sliceInst'.
  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
  getLoopIVs(*sliceInst, &sliceSurroundingLoops);

  // Sanity check.
  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
  (void)sliceSurroundingLoopsSize;
  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
  (void)sliceLoopLimit;
  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);

  // Update loop bounds for loops in 'sliceLoopNest': only non-null maps in
  // 'sliceState' override the cloned loop's original bounds.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
    if (AffineMap lbMap = sliceState->lbs[i])
      forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
    if (AffineMap ubMap = sliceState->ubs[i])
      forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
  }
  return sliceLoopNest;
}
1214
1215 // Constructs MemRefAccess populating it with the memref, its indices and
1216 // opinst from 'loadOrStoreOpInst'.
MemRefAccess(Operation * loadOrStoreOpInst)1217 MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
1218 if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
1219 memref = loadOp.getMemRef();
1220 opInst = loadOrStoreOpInst;
1221 auto loadMemrefType = loadOp.getMemRefType();
1222 indices.reserve(loadMemrefType.getRank());
1223 for (auto index : loadOp.getMapOperands()) {
1224 indices.push_back(index);
1225 }
1226 } else {
1227 assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
1228 "Affine read/write op expected");
1229 auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
1230 opInst = loadOrStoreOpInst;
1231 memref = storeOp.getMemRef();
1232 auto storeMemrefType = storeOp.getMemRefType();
1233 indices.reserve(storeMemrefType.getRank());
1234 for (auto index : storeOp.getMapOperands()) {
1235 indices.push_back(index);
1236 }
1237 }
1238 }
1239
getRank() const1240 unsigned MemRefAccess::getRank() const {
1241 return memref.getType().cast<MemRefType>().getRank();
1242 }
1243
/// Returns true if this access is an affine write (store) op.
bool MemRefAccess::isStore() const {
  return isa<AffineWriteOpInterface>(opInst);
}
1247
1248 /// Returns the nesting depth of this statement, i.e., the number of loops
1249 /// surrounding this statement.
getNestingDepth(Operation * op)1250 unsigned mlir::getNestingDepth(Operation *op) {
1251 Operation *currOp = op;
1252 unsigned depth = 0;
1253 while ((currOp = currOp->getParentOp())) {
1254 if (isa<AffineForOp>(currOp))
1255 depth++;
1256 }
1257 return depth;
1258 }
1259
1260 /// Equal if both affine accesses are provably equivalent (at compile
1261 /// time) when considering the memref, the affine maps and their respective
1262 /// operands. The equality of access functions + operands is checked by
1263 /// subtracting fully composed value maps, and then simplifying the difference
1264 /// using the expression flattener.
1265 /// TODO: this does not account for aliasing of memrefs.
operator ==(const MemRefAccess & rhs) const1266 bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
1267 if (memref != rhs.memref)
1268 return false;
1269
1270 AffineValueMap diff, thisMap, rhsMap;
1271 getAccessMap(&thisMap);
1272 rhs.getAccessMap(&rhsMap);
1273 AffineValueMap::difference(thisMap, rhsMap, &diff);
1274 return llvm::all_of(diff.getAffineMap().getResults(),
1275 [](AffineExpr e) { return e == 0; });
1276 }
1277
1278 /// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
1279 /// where each lists loops from outer-most to inner-most in loop nest.
getNumCommonSurroundingLoops(Operation & A,Operation & B)1280 unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) {
1281 SmallVector<AffineForOp, 4> loopsA, loopsB;
1282 getLoopIVs(A, &loopsA);
1283 getLoopIVs(B, &loopsB);
1284
1285 unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
1286 unsigned numCommonLoops = 0;
1287 for (unsigned i = 0; i < minNumLoops; ++i) {
1288 if (loopsA[i].getOperation() != loopsB[i].getOperation())
1289 break;
1290 ++numCommonLoops;
1291 }
1292 return numCommonLoops;
1293 }
1294
/// Accumulates the footprint (in bytes) of the memref regions accessed by all
/// affine read/write ops in the half-open range ['start', 'end') of 'block'.
/// Regions accessing the same memref are unioned before sizing. Returns None
/// if any region, union, or region size could not be computed.
/// NOTE(review): 'memorySpace' is not consulted in this body — presumably a
/// memory-space filter was intended; confirm against callers.
static Optional<int64_t> getMemoryFootprintBytes(Block &block,
                                                 Block::iterator start,
                                                 Block::iterator end,
                                                 int memorySpace) {
  SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;

  // Walk this 'affine.for' operation to gather all memory regions.
  auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
    if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
      // Neither load nor a store op.
      return WalkResult::advance();
    }

    // Compute the memref region symbolic in any IVs enclosing this block.
    auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
    if (failed(
            region->compute(opInst,
                            /*loopDepth=*/getNestingDepth(&*block.begin())))) {
      return opInst->emitError("error obtaining memory region\n");
    }

    // Union this region with any previously recorded region of the same
    // memref.
    auto it = regions.find(region->memref);
    if (it == regions.end()) {
      regions[region->memref] = std::move(region);
    } else if (failed(it->second->unionBoundingBox(*region))) {
      return opInst->emitWarning(
          "getMemoryFootprintBytes: unable to perform a union on a memory "
          "region");
    }
    return WalkResult::advance();
  });
  // An interrupted walk means a region computation or union above failed.
  if (result.wasInterrupted())
    return None;

  // Sum the sizes of the unioned regions; give up on any non-computable size.
  int64_t totalSizeInBytes = 0;
  for (const auto &region : regions) {
    Optional<int64_t> size = region.second->getRegionSize();
    if (!size.hasValue())
      return None;
    totalSizeInBytes += size.getValue();
  }
  return totalSizeInBytes;
}
1338
getMemoryFootprintBytes(AffineForOp forOp,int memorySpace)1339 Optional<int64_t> mlir::getMemoryFootprintBytes(AffineForOp forOp,
1340 int memorySpace) {
1341 auto *forInst = forOp.getOperation();
1342 return ::getMemoryFootprintBytes(
1343 *forInst->getBlock(), Block::iterator(forInst),
1344 std::next(Block::iterator(forInst)), memorySpace);
1345 }
1346
1347 /// Returns whether a loop is parallel and contains a reduction loop.
isLoopParallelAndContainsReduction(AffineForOp forOp)1348 bool mlir::isLoopParallelAndContainsReduction(AffineForOp forOp) {
1349 SmallVector<LoopReduction> reductions;
1350 if (!isLoopParallel(forOp, &reductions))
1351 return false;
1352 return !reductions.empty();
1353 }
1354
1355 /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
1356 /// at 'forOp'.
getSequentialLoops(AffineForOp forOp,llvm::SmallDenseSet<Value,8> * sequentialLoops)1357 void mlir::getSequentialLoops(AffineForOp forOp,
1358 llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
1359 forOp->walk([&](Operation *op) {
1360 if (auto innerFor = dyn_cast<AffineForOp>(op))
1361 if (!isLoopParallel(innerFor))
1362 sequentialLoops->insert(innerFor.getInductionVar());
1363 });
1364 }
1365
simplifyIntegerSet(IntegerSet set)1366 IntegerSet mlir::simplifyIntegerSet(IntegerSet set) {
1367 FlatAffineConstraints fac(set);
1368 if (fac.isEmpty())
1369 return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
1370 set.getContext());
1371 fac.removeTrivialRedundancy();
1372
1373 auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
1374 assert(simplifiedSet && "guaranteed to succeed while roundtripping");
1375 return simplifiedSet;
1376 }
1377