1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Structures for affine/polyhedral analysis of affine dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Analysis/AffineStructures.h"
14 #include "mlir/Analysis/Presburger/Simplex.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/AffineExprVisitor.h"
19 #include "mlir/IR/IntegerSet.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Support/MathExtras.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/raw_ostream.h"
25
26 #define DEBUG_TYPE "affine-structures"
27
28 using namespace mlir;
29 using llvm::SmallDenseMap;
30 using llvm::SmallDenseSet;
31
32 namespace {
33
34 // See comments for SimpleAffineExprFlattener.
35 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
36 // constraint information associated with mod's, floordiv's, and ceildiv's
37 // in FlatAffineConstraints 'localVarCst'.
38 struct AffineExprFlattener : public SimpleAffineExprFlattener {
39 public:
40 // Constraints connecting newly introduced local variables (for mod's and
41 // div's) to existing (dimensional and symbolic) ones. These are always
42 // inequalities.
43 FlatAffineConstraints localVarCst;
44
AffineExprFlattener__anon882b85370111::AffineExprFlattener45 AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx)
46 : SimpleAffineExprFlattener(nDims, nSymbols) {
47 localVarCst.reset(nDims, nSymbols, /*numLocals=*/0);
48 }
49
50 private:
51 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
52 // The local identifier added is always a floordiv of a pure add/mul affine
53 // function of other identifiers, coefficients of which are specified in
54 // `dividend' and with respect to the positive constant `divisor'. localExpr
55 // is the simplified tree expression (AffineExpr) corresponding to the
56 // quantifier.
addLocalFloorDivId__anon882b85370111::AffineExprFlattener57 void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
58 AffineExpr localExpr) override {
59 SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
60 // Update localVarCst.
61 localVarCst.addLocalFloorDiv(dividend, divisor);
62 }
63 };
64
65 } // end anonymous namespace
66
67 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
68 // flattened (i.e., semi-affine expressions not handled yet).
69 static LogicalResult
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs,unsigned numDims,unsigned numSymbols,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)70 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
71 unsigned numSymbols,
72 std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
73 FlatAffineConstraints *localVarCst) {
74 if (exprs.empty()) {
75 localVarCst->reset(numDims, numSymbols);
76 return success();
77 }
78
79 AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
80 // Use the same flattener to simplify each expression successively. This way
81 // local identifiers / expressions are shared.
82 for (auto expr : exprs) {
83 if (!expr.isPureAffine())
84 return failure();
85
86 flattener.walkPostOrder(expr);
87 }
88
89 assert(flattener.operandExprStack.size() == exprs.size());
90 flattenedExprs->clear();
91 flattenedExprs->assign(flattener.operandExprStack.begin(),
92 flattener.operandExprStack.end());
93
94 if (localVarCst)
95 localVarCst->clearAndCopyFrom(flattener.localVarCst);
96
97 return success();
98 }
99
100 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
101 // be flattened (semi-affine expressions not handled yet).
102 LogicalResult
getFlattenedAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols,SmallVectorImpl<int64_t> * flattenedExpr,FlatAffineConstraints * localVarCst)103 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
104 unsigned numSymbols,
105 SmallVectorImpl<int64_t> *flattenedExpr,
106 FlatAffineConstraints *localVarCst) {
107 std::vector<SmallVector<int64_t, 8>> flattenedExprs;
108 LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
109 &flattenedExprs, localVarCst);
110 *flattenedExpr = flattenedExprs[0];
111 return ret;
112 }
113
114 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
115 /// flattened (i.e., semi-affine expressions not handled yet).
getFlattenedAffineExprs(AffineMap map,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)116 LogicalResult mlir::getFlattenedAffineExprs(
117 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
118 FlatAffineConstraints *localVarCst) {
119 if (map.getNumResults() == 0) {
120 localVarCst->reset(map.getNumDims(), map.getNumSymbols());
121 return success();
122 }
123 return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
124 map.getNumSymbols(), flattenedExprs,
125 localVarCst);
126 }
127
getFlattenedAffineExprs(IntegerSet set,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)128 LogicalResult mlir::getFlattenedAffineExprs(
129 IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
130 FlatAffineConstraints *localVarCst) {
131 if (set.getNumConstraints() == 0) {
132 localVarCst->reset(set.getNumDims(), set.getNumSymbols());
133 return success();
134 }
135 return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
136 set.getNumSymbols(), flattenedExprs,
137 localVarCst);
138 }
139
140 //===----------------------------------------------------------------------===//
141 // FlatAffineConstraints.
142 //===----------------------------------------------------------------------===//
143
144 // Copy constructor.
FlatAffineConstraints(const FlatAffineConstraints & other)145 FlatAffineConstraints::FlatAffineConstraints(
146 const FlatAffineConstraints &other) {
147 numReservedCols = other.numReservedCols;
148 numDims = other.getNumDimIds();
149 numSymbols = other.getNumSymbolIds();
150 numIds = other.getNumIds();
151
152 auto otherIds = other.getIds();
153 ids.reserve(numReservedCols);
154 ids.append(otherIds.begin(), otherIds.end());
155
156 unsigned numReservedEqualities = other.getNumReservedEqualities();
157 unsigned numReservedInequalities = other.getNumReservedInequalities();
158
159 equalities.reserve(numReservedEqualities * numReservedCols);
160 inequalities.reserve(numReservedInequalities * numReservedCols);
161
162 for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
163 addInequality(other.getInequality(r));
164 }
165 for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
166 addEquality(other.getEquality(r));
167 }
168 }
169
170 // Clones this object.
clone() const171 std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
172 return std::make_unique<FlatAffineConstraints>(*this);
173 }
174
175 // Construct from an IntegerSet.
FlatAffineConstraints(IntegerSet set)176 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
177 : numReservedCols(set.getNumInputs() + 1),
178 numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
179 numSymbols(set.getNumSymbols()) {
180 equalities.reserve(set.getNumEqualities() * numReservedCols);
181 inequalities.reserve(set.getNumInequalities() * numReservedCols);
182 ids.resize(numIds, None);
183
184 // Flatten expressions and add them to the constraint system.
185 std::vector<SmallVector<int64_t, 8>> flatExprs;
186 FlatAffineConstraints localVarCst;
187 if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
188 assert(false && "flattening unimplemented for semi-affine integer sets");
189 return;
190 }
191 assert(flatExprs.size() == set.getNumConstraints());
192 for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) {
193 addLocalId(getNumLocalIds());
194 }
195
196 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
197 const auto &flatExpr = flatExprs[i];
198 assert(flatExpr.size() == getNumCols());
199 if (set.getEqFlags()[i]) {
200 addEquality(flatExpr);
201 } else {
202 addInequality(flatExpr);
203 }
204 }
205 // Add the other constraints involving local id's from flattening.
206 append(localVarCst);
207 }
208
reset(unsigned numReservedInequalities,unsigned numReservedEqualities,unsigned newNumReservedCols,unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> idArgs)209 void FlatAffineConstraints::reset(unsigned numReservedInequalities,
210 unsigned numReservedEqualities,
211 unsigned newNumReservedCols,
212 unsigned newNumDims, unsigned newNumSymbols,
213 unsigned newNumLocals,
214 ArrayRef<Value> idArgs) {
215 assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
216 "minimum 1 column");
217 numReservedCols = newNumReservedCols;
218 numDims = newNumDims;
219 numSymbols = newNumSymbols;
220 numIds = numDims + numSymbols + newNumLocals;
221 assert(idArgs.empty() || idArgs.size() == numIds);
222
223 clearConstraints();
224 if (numReservedEqualities >= 1)
225 equalities.reserve(newNumReservedCols * numReservedEqualities);
226 if (numReservedInequalities >= 1)
227 inequalities.reserve(newNumReservedCols * numReservedInequalities);
228 if (idArgs.empty()) {
229 ids.resize(numIds, None);
230 } else {
231 ids.assign(idArgs.begin(), idArgs.end());
232 }
233 }
234
reset(unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> idArgs)235 void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
236 unsigned newNumLocals,
237 ArrayRef<Value> idArgs) {
238 reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
239 newNumSymbols, newNumLocals, idArgs);
240 }
241
append(const FlatAffineConstraints & other)242 void FlatAffineConstraints::append(const FlatAffineConstraints &other) {
243 assert(other.getNumCols() == getNumCols());
244 assert(other.getNumDimIds() == getNumDimIds());
245 assert(other.getNumSymbolIds() == getNumSymbolIds());
246
247 inequalities.reserve(inequalities.size() +
248 other.getNumInequalities() * numReservedCols);
249 equalities.reserve(equalities.size() +
250 other.getNumEqualities() * numReservedCols);
251
252 for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
253 addInequality(other.getInequality(r));
254 }
255 for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
256 addEquality(other.getEquality(r));
257 }
258 }
259
addLocalId(unsigned pos)260 void FlatAffineConstraints::addLocalId(unsigned pos) {
261 addId(IdKind::Local, pos);
262 }
263
addDimId(unsigned pos,Value id)264 void FlatAffineConstraints::addDimId(unsigned pos, Value id) {
265 addId(IdKind::Dimension, pos, id);
266 }
267
addSymbolId(unsigned pos,Value id)268 void FlatAffineConstraints::addSymbolId(unsigned pos, Value id) {
269 addId(IdKind::Symbol, pos, id);
270 }
271
272 /// Adds a dimensional identifier. The added column is initialized to
273 /// zero.
addId(IdKind kind,unsigned pos,Value id)274 void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value id) {
275 if (kind == IdKind::Dimension)
276 assert(pos <= getNumDimIds());
277 else if (kind == IdKind::Symbol)
278 assert(pos <= getNumSymbolIds());
279 else
280 assert(pos <= getNumLocalIds());
281
282 unsigned oldNumReservedCols = numReservedCols;
283
284 // Check if a resize is necessary.
285 if (getNumCols() + 1 > numReservedCols) {
286 equalities.resize(getNumEqualities() * (getNumCols() + 1));
287 inequalities.resize(getNumInequalities() * (getNumCols() + 1));
288 numReservedCols++;
289 }
290
291 int absolutePos;
292
293 if (kind == IdKind::Dimension) {
294 absolutePos = pos;
295 numDims++;
296 } else if (kind == IdKind::Symbol) {
297 absolutePos = pos + getNumDimIds();
298 numSymbols++;
299 } else {
300 absolutePos = pos + getNumDimIds() + getNumSymbolIds();
301 }
302 numIds++;
303
304 // Note that getNumCols() now will already return the new size, which will be
305 // at least one.
306 int numInequalities = static_cast<int>(getNumInequalities());
307 int numEqualities = static_cast<int>(getNumEqualities());
308 int numCols = static_cast<int>(getNumCols());
309 for (int r = numInequalities - 1; r >= 0; r--) {
310 for (int c = numCols - 2; c >= 0; c--) {
311 if (c < absolutePos)
312 atIneq(r, c) = inequalities[r * oldNumReservedCols + c];
313 else
314 atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c];
315 }
316 atIneq(r, absolutePos) = 0;
317 }
318
319 for (int r = numEqualities - 1; r >= 0; r--) {
320 for (int c = numCols - 2; c >= 0; c--) {
321 // All values in column absolutePositions < absolutePos have the same
322 // coordinates in the 2-d view of the coefficient buffer.
323 if (c < absolutePos)
324 atEq(r, c) = equalities[r * oldNumReservedCols + c];
325 else
326 // Those at absolutePosition >= absolutePos, get a shifted
327 // absolutePosition.
328 atEq(r, c + 1) = equalities[r * oldNumReservedCols + c];
329 }
330 // Initialize added dimension to zero.
331 atEq(r, absolutePos) = 0;
332 }
333
334 // If an 'id' is provided, insert it; otherwise use None.
335 if (id)
336 ids.insert(ids.begin() + absolutePos, id);
337 else
338 ids.insert(ids.begin() + absolutePos, None);
339 assert(ids.size() == getNumIds());
340 }
341
342 /// Checks if two constraint systems are in the same space, i.e., if they are
343 /// associated with the same set of identifiers, appearing in the same order.
areIdsAligned(const FlatAffineConstraints & A,const FlatAffineConstraints & B)344 static bool areIdsAligned(const FlatAffineConstraints &A,
345 const FlatAffineConstraints &B) {
346 return A.getNumDimIds() == B.getNumDimIds() &&
347 A.getNumSymbolIds() == B.getNumSymbolIds() &&
348 A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
349 }
350
351 /// Calls areIdsAligned to check if two constraint systems have the same set
352 /// of identifiers in the same order.
areIdsAlignedWithOther(const FlatAffineConstraints & other)353 bool FlatAffineConstraints::areIdsAlignedWithOther(
354 const FlatAffineConstraints &other) {
355 return areIdsAligned(*this, other);
356 }
357
358 /// Checks if the SSA values associated with `cst''s identifiers are unique.
359 static bool LLVM_ATTRIBUTE_UNUSED
areIdsUnique(const FlatAffineConstraints & cst)360 areIdsUnique(const FlatAffineConstraints &cst) {
361 SmallPtrSet<Value, 8> uniqueIds;
362 for (auto id : cst.getIds()) {
363 if (id.hasValue() && !uniqueIds.insert(id.getValue()).second)
364 return false;
365 }
366 return true;
367 }
368
369 // Swap the posA^th identifier with the posB^th identifier.
swapId(FlatAffineConstraints * A,unsigned posA,unsigned posB)370 static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) {
371 assert(posA < A->getNumIds() && "invalid position A");
372 assert(posB < A->getNumIds() && "invalid position B");
373
374 if (posA == posB)
375 return;
376
377 for (unsigned r = 0, e = A->getNumInequalities(); r < e; r++) {
378 std::swap(A->atIneq(r, posA), A->atIneq(r, posB));
379 }
380 for (unsigned r = 0, e = A->getNumEqualities(); r < e; r++) {
381 std::swap(A->atEq(r, posA), A->atEq(r, posB));
382 }
383 std::swap(A->getId(posA), A->getId(posB));
384 }
385
386 /// Merge and align the identifiers of A and B starting at 'offset', so that
387 /// both constraint systems get the union of the contained identifiers that is
388 /// dimension-wise and symbol-wise unique; both constraint systems are updated
389 /// so that they have the union of all identifiers, with A's original
390 /// identifiers appearing first followed by any of B's identifiers that didn't
391 /// appear in A. Local identifiers of each system are by design separate/local
392 /// and are placed one after other (A's followed by B's).
393 // Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M])
394 // Output: both A, B have (%i, %j, %k) [%M, %N, %P]
395 //
mergeAndAlignIds(unsigned offset,FlatAffineConstraints * A,FlatAffineConstraints * B)396 static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
397 FlatAffineConstraints *B) {
398 assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds());
399 // A merge/align isn't meaningful if a cst's ids aren't distinct.
400 assert(areIdsUnique(*A) && "A's id values aren't unique");
401 assert(areIdsUnique(*B) && "B's id values aren't unique");
402
403 assert(std::all_of(A->getIds().begin() + offset,
404 A->getIds().begin() + A->getNumDimAndSymbolIds(),
405 [](Optional<Value> id) { return id.hasValue(); }));
406
407 assert(std::all_of(B->getIds().begin() + offset,
408 B->getIds().begin() + B->getNumDimAndSymbolIds(),
409 [](Optional<Value> id) { return id.hasValue(); }));
410
411 // Place local id's of A after local id's of B.
412 for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) {
413 B->addLocalId(0);
414 }
415 for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e;
416 t++) {
417 A->addLocalId(A->getNumLocalIds());
418 }
419
420 SmallVector<Value, 4> aDimValues, aSymValues;
421 A->getIdValues(offset, A->getNumDimIds(), &aDimValues);
422 A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues);
423 {
424 // Merge dims from A into B.
425 unsigned d = offset;
426 for (auto aDimValue : aDimValues) {
427 unsigned loc;
428 if (B->findId(aDimValue, &loc)) {
429 assert(loc >= offset && "A's dim appears in B's aligned range");
430 assert(loc < B->getNumDimIds() &&
431 "A's dim appears in B's non-dim position");
432 swapId(B, d, loc);
433 } else {
434 B->addDimId(d);
435 B->setIdValue(d, aDimValue);
436 }
437 d++;
438 }
439
440 // Dimensions that are in B, but not in A, are added at the end.
441 for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) {
442 A->addDimId(A->getNumDimIds());
443 A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t));
444 }
445 }
446 {
447 // Merge symbols: merge A's symbols into B first.
448 unsigned s = B->getNumDimIds();
449 for (auto aSymValue : aSymValues) {
450 unsigned loc;
451 if (B->findId(aSymValue, &loc)) {
452 assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
453 "A's symbol appears in B's non-symbol position");
454 swapId(B, s, loc);
455 } else {
456 B->addSymbolId(s - B->getNumDimIds());
457 B->setIdValue(s, aSymValue);
458 }
459 s++;
460 }
461 // Symbols that are in B, but not in A, are added at the end.
462 for (unsigned t = A->getNumDimAndSymbolIds(),
463 e = B->getNumDimAndSymbolIds();
464 t < e; t++) {
465 A->addSymbolId(A->getNumSymbolIds());
466 A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t));
467 }
468 }
469 assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
470 }
471
472 // Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
mergeAndAlignIdsWithOther(unsigned offset,FlatAffineConstraints * other)473 void FlatAffineConstraints::mergeAndAlignIdsWithOther(
474 unsigned offset, FlatAffineConstraints *other) {
475 mergeAndAlignIds(offset, this, other);
476 }
477
478 // This routine may add additional local variables if the flattened expression
479 // corresponding to the map has such variables due to mod's, ceildiv's, and
480 // floordiv's in it.
composeMap(const AffineValueMap * vMap)481 LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) {
482 std::vector<SmallVector<int64_t, 8>> flatExprs;
483 FlatAffineConstraints localCst;
484 if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs,
485 &localCst))) {
486 LLVM_DEBUG(llvm::dbgs()
487 << "composition unimplemented for semi-affine maps\n");
488 return failure();
489 }
490 assert(flatExprs.size() == vMap->getNumResults());
491
492 // Add localCst information.
493 if (localCst.getNumLocalIds() > 0) {
494 localCst.setIdValues(0, /*end=*/localCst.getNumDimAndSymbolIds(),
495 /*values=*/vMap->getOperands());
496 // Align localCst and this.
497 mergeAndAlignIds(/*offset=*/0, &localCst, this);
498 // Finally, append localCst to this constraint set.
499 append(localCst);
500 }
501
502 // Add dimensions corresponding to the map's results.
503 for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) {
504 // TODO: Consider using a batched version to add a range of IDs.
505 addDimId(0);
506 }
507
508 // We add one equality for each result connecting the result dim of the map to
509 // the other identifiers.
510 // For eg: if the expression is 16*i0 + i1, and this is the r^th
511 // iteration/result of the value map, we are adding the equality:
512 // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
513 // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
514 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
515 const auto &flatExpr = flatExprs[r];
516 assert(flatExpr.size() >= vMap->getNumOperands() + 1);
517
518 // eqToAdd is the equality corresponding to the flattened affine expression.
519 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
520 // Set the coefficient for this result to one.
521 eqToAdd[r] = 1;
522
523 // Dims and symbols.
524 for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
525 unsigned loc;
526 bool ret = findId(vMap->getOperand(i), &loc);
527 assert(ret && "value map's id can't be found");
528 (void)ret;
529 // Negate 'eq[r]' since the newly added dimension will be set to this one.
530 eqToAdd[loc] = -flatExpr[i];
531 }
532 // Local vars common to eq and localCst are at the beginning.
533 unsigned j = getNumDimIds() + getNumSymbolIds();
534 unsigned end = flatExpr.size() - 1;
535 for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) {
536 eqToAdd[j] = -flatExpr[i];
537 }
538
539 // Constant term.
540 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
541
542 // Add the equality connecting the result of the map to this constraint set.
543 addEquality(eqToAdd);
544 }
545
546 return success();
547 }
548
549 // Similar to composeMap except that no Value's need be associated with the
550 // constraint system nor are they looked at -- since the dimensions and
551 // symbols of 'other' are expected to correspond 1:1 to 'this' system. It
552 // is thus not convenient to share code with composeMap.
composeMatchingMap(AffineMap other)553 LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
554 assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
555 assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
556
557 std::vector<SmallVector<int64_t, 8>> flatExprs;
558 FlatAffineConstraints localCst;
559 if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
560 LLVM_DEBUG(llvm::dbgs()
561 << "composition unimplemented for semi-affine maps\n");
562 return failure();
563 }
564 assert(flatExprs.size() == other.getNumResults());
565
566 // Add localCst information.
567 if (localCst.getNumLocalIds() > 0) {
568 // Place local id's of A after local id's of B.
569 for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; l++) {
570 addLocalId(0);
571 }
572 // Finally, append localCst to this constraint set.
573 append(localCst);
574 }
575
576 // Add dimensions corresponding to the map's results.
577 for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
578 addDimId(0);
579 }
580
581 // We add one equality for each result connecting the result dim of the map to
582 // the other identifiers.
583 // For eg: if the expression is 16*i0 + i1, and this is the r^th
584 // iteration/result of the value map, we are adding the equality:
585 // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
586 // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
587 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
588 const auto &flatExpr = flatExprs[r];
589 assert(flatExpr.size() >= other.getNumInputs() + 1);
590
591 // eqToAdd is the equality corresponding to the flattened affine expression.
592 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
593 // Set the coefficient for this result to one.
594 eqToAdd[r] = 1;
595
596 // Dims and symbols.
597 for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
598 // Negate 'eq[r]' since the newly added dimension will be set to this one.
599 eqToAdd[e + i] = -flatExpr[i];
600 }
601 // Local vars common to eq and localCst are at the beginning.
602 unsigned j = getNumDimIds() + getNumSymbolIds();
603 unsigned end = flatExpr.size() - 1;
604 for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
605 eqToAdd[j] = -flatExpr[i];
606 }
607
608 // Constant term.
609 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
610
611 // Add the equality connecting the result of the map to this constraint set.
612 addEquality(eqToAdd);
613 }
614
615 return success();
616 }
617
618 // Turn a dimension into a symbol.
turnDimIntoSymbol(FlatAffineConstraints * cst,Value id)619 static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value id) {
620 unsigned pos;
621 if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) {
622 swapId(cst, pos, cst->getNumDimIds() - 1);
623 cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1);
624 }
625 }
626
627 // Turn a symbol into a dimension.
turnSymbolIntoDim(FlatAffineConstraints * cst,Value id)628 static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value id) {
629 unsigned pos;
630 if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
631 pos < cst->getNumDimAndSymbolIds()) {
632 swapId(cst, pos, cst->getNumDimIds());
633 cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
634 }
635 }
636
637 // Changes all symbol identifiers which are loop IVs to dim identifiers.
convertLoopIVSymbolsToDims()638 void FlatAffineConstraints::convertLoopIVSymbolsToDims() {
639 // Gather all symbols which are loop IVs.
640 SmallVector<Value, 4> loopIVs;
641 for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
642 if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue()))
643 loopIVs.push_back(ids[i].getValue());
644 }
645 // Turn each symbol in 'loopIVs' into a dim identifier.
646 for (auto iv : loopIVs) {
647 turnSymbolIntoDim(this, iv);
648 }
649 }
650
addInductionVarOrTerminalSymbol(Value id)651 void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
652 if (containsId(id))
653 return;
654
655 // Caller is expected to fully compose map/operands if necessary.
656 assert((isTopLevelValue(id) || isForInductionVar(id)) &&
657 "non-terminal symbol / loop IV expected");
658 // Outer loop IVs could be used in forOp's bounds.
659 if (auto loop = getForInductionVarOwner(id)) {
660 addDimId(getNumDimIds(), id);
661 if (failed(this->addAffineForOpDomain(loop)))
662 LLVM_DEBUG(
663 loop.emitWarning("failed to add domain info to constraint system"));
664 return;
665 }
666 // Add top level symbol.
667 addSymbolId(getNumSymbolIds(), id);
668 // Check if the symbol is a constant.
669 if (auto constOp = id.getDefiningOp<ConstantIndexOp>())
670 setIdToConstant(id, constOp.getValue());
671 }
672
addAffineForOpDomain(AffineForOp forOp)673 LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
674 unsigned pos;
675 // Pre-condition for this method.
676 if (!findId(forOp.getInductionVar(), &pos)) {
677 assert(false && "Value not found");
678 return failure();
679 }
680
681 int64_t step = forOp.getStep();
682 if (step != 1) {
683 if (!forOp.hasConstantLowerBound())
684 forOp.emitWarning("domain conservatively approximated");
685 else {
686 // Add constraints for the stride.
687 // (iv - lb) % step = 0 can be written as:
688 // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
689 // Add local variable 'q' and add the above equality.
690 // The first constraint is q = (iv - lb) floordiv step
691 SmallVector<int64_t, 8> dividend(getNumCols(), 0);
692 int64_t lb = forOp.getConstantLowerBound();
693 dividend[pos] = 1;
694 dividend.back() -= lb;
695 addLocalFloorDiv(dividend, step);
696 // Second constraint: (iv - lb) - step * q = 0.
697 SmallVector<int64_t, 8> eq(getNumCols(), 0);
698 eq[pos] = 1;
699 eq.back() -= lb;
700 // For the local var just added above.
701 eq[getNumCols() - 2] = -step;
702 addEquality(eq);
703 }
704 }
705
706 if (forOp.hasConstantLowerBound()) {
707 addConstantLowerBound(pos, forOp.getConstantLowerBound());
708 } else {
709 // Non-constant lower bound case.
710 if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(),
711 forOp.getLowerBoundOperands(),
712 /*eq=*/false, /*lower=*/true)))
713 return failure();
714 }
715
716 if (forOp.hasConstantUpperBound()) {
717 addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1);
718 return success();
719 }
720 // Non-constant upper bound case.
721 return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(),
722 forOp.getUpperBoundOperands(),
723 /*eq=*/false, /*lower=*/false);
724 }
725
726 // Searches for a constraint with a non-zero coefficient at 'colIdx' in
727 // equality (isEq=true) or inequality (isEq=false) constraints.
728 // Returns true and sets row found in search in 'rowIdx'.
729 // Returns false otherwise.
findConstraintWithNonZeroAt(const FlatAffineConstraints & cst,unsigned colIdx,bool isEq,unsigned * rowIdx)730 static bool findConstraintWithNonZeroAt(const FlatAffineConstraints &cst,
731 unsigned colIdx, bool isEq,
732 unsigned *rowIdx) {
733 assert(colIdx < cst.getNumCols() && "position out of bounds");
734 auto at = [&](unsigned rowIdx) -> int64_t {
735 return isEq ? cst.atEq(rowIdx, colIdx) : cst.atIneq(rowIdx, colIdx);
736 };
737 unsigned e = isEq ? cst.getNumEqualities() : cst.getNumInequalities();
738 for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
739 if (at(*rowIdx) != 0) {
740 return true;
741 }
742 }
743 return false;
744 }
745
746 // Normalizes the coefficient values across all columns in 'rowIDx' by their
747 // GCD in equality or inequality constraints as specified by 'isEq'.
748 template <bool isEq>
normalizeConstraintByGCD(FlatAffineConstraints * constraints,unsigned rowIdx)749 static void normalizeConstraintByGCD(FlatAffineConstraints *constraints,
750 unsigned rowIdx) {
751 auto at = [&](unsigned colIdx) -> int64_t {
752 return isEq ? constraints->atEq(rowIdx, colIdx)
753 : constraints->atIneq(rowIdx, colIdx);
754 };
755 uint64_t gcd = std::abs(at(0));
756 for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) {
757 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j)));
758 }
759 if (gcd > 0 && gcd != 1) {
760 for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) {
761 int64_t v = at(j) / static_cast<int64_t>(gcd);
762 isEq ? constraints->atEq(rowIdx, j) = v
763 : constraints->atIneq(rowIdx, j) = v;
764 }
765 }
766 }
767
normalizeConstraintsByGCD()768 void FlatAffineConstraints::normalizeConstraintsByGCD() {
769 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
770 normalizeConstraintByGCD</*isEq=*/true>(this, i);
771 }
772 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
773 normalizeConstraintByGCD</*isEq=*/false>(this, i);
774 }
775 }
776
hasConsistentState() const777 bool FlatAffineConstraints::hasConsistentState() const {
778 if (inequalities.size() != getNumInequalities() * numReservedCols)
779 return false;
780 if (equalities.size() != getNumEqualities() * numReservedCols)
781 return false;
782 if (ids.size() != getNumIds())
783 return false;
784
785 // Catches errors where numDims, numSymbols, numIds aren't consistent.
786 if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds)
787 return false;
788
789 return true;
790 }
791
792 /// Checks all rows of equality/inequality constraints for trivial
793 /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced
794 /// after elimination. Returns 'true' if an invalid constraint is found;
795 /// 'false' otherwise.
hasInvalidConstraint() const796 bool FlatAffineConstraints::hasInvalidConstraint() const {
797 assert(hasConsistentState());
798 auto check = [&](bool isEq) -> bool {
799 unsigned numCols = getNumCols();
800 unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
801 for (unsigned i = 0, e = numRows; i < e; ++i) {
802 unsigned j;
803 for (j = 0; j < numCols - 1; ++j) {
804 int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
805 // Skip rows with non-zero variable coefficients.
806 if (v != 0)
807 break;
808 }
809 if (j < numCols - 1) {
810 continue;
811 }
812 // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
813 // Example invalid constraints include: '1 == 0' or '-1 >= 0'
814 int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
815 if ((isEq && v != 0) || (!isEq && v < 0)) {
816 return true;
817 }
818 }
819 return false;
820 };
821 if (check(/*isEq=*/true))
822 return true;
823 return check(/*isEq=*/false);
824 }
825
826 // Eliminate identifier from constraint at 'rowIdx' based on coefficient at
827 // pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
828 // updated as they have already been eliminated.
eliminateFromConstraint(FlatAffineConstraints * constraints,unsigned rowIdx,unsigned pivotRow,unsigned pivotCol,unsigned elimColStart,bool isEq)829 static void eliminateFromConstraint(FlatAffineConstraints *constraints,
830 unsigned rowIdx, unsigned pivotRow,
831 unsigned pivotCol, unsigned elimColStart,
832 bool isEq) {
833 // Skip if equality 'rowIdx' if same as 'pivotRow'.
834 if (isEq && rowIdx == pivotRow)
835 return;
836 auto at = [&](unsigned i, unsigned j) -> int64_t {
837 return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
838 };
839 int64_t leadCoeff = at(rowIdx, pivotCol);
840 // Skip if leading coefficient at 'rowIdx' is already zero.
841 if (leadCoeff == 0)
842 return;
843 int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
844 int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
845 int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
846 int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
847 int64_t rowMultiplier = lcm / std::abs(leadCoeff);
848
849 unsigned numCols = constraints->getNumCols();
850 for (unsigned j = 0; j < numCols; ++j) {
851 // Skip updating column 'j' if it was just eliminated.
852 if (j >= elimColStart && j < pivotCol)
853 continue;
854 int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
855 rowMultiplier * at(rowIdx, j);
856 isEq ? constraints->atEq(rowIdx, j) = v
857 : constraints->atIneq(rowIdx, j) = v;
858 }
859 }
860
861 // Remove coefficients in column range [colStart, colLimit) in place.
862 // This removes in data in the specified column range, and copies any
863 // remaining valid data into place.
shiftColumnsToLeft(FlatAffineConstraints * constraints,unsigned colStart,unsigned colLimit,bool isEq)864 static void shiftColumnsToLeft(FlatAffineConstraints *constraints,
865 unsigned colStart, unsigned colLimit,
866 bool isEq) {
867 assert(colLimit <= constraints->getNumIds());
868 if (colLimit <= colStart)
869 return;
870
871 unsigned numCols = constraints->getNumCols();
872 unsigned numRows = isEq ? constraints->getNumEqualities()
873 : constraints->getNumInequalities();
874 unsigned numToEliminate = colLimit - colStart;
875 for (unsigned r = 0, e = numRows; r < e; ++r) {
876 for (unsigned c = colLimit; c < numCols; ++c) {
877 if (isEq) {
878 constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c);
879 } else {
880 constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c);
881 }
882 }
883 }
884 }
885
886 // Removes identifiers in column range [idStart, idLimit), and copies any
887 // remaining valid data into place, and updates member variables.
removeIdRange(unsigned idStart,unsigned idLimit)888 void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
889 assert(idLimit < getNumCols() && "invalid id limit");
890
891 if (idStart >= idLimit)
892 return;
893
894 // We are going to be removing one or more identifiers from the range.
895 assert(idStart < numIds && "invalid idStart position");
896
897 // TODO: Make 'removeIdRange' a lambda called from here.
898 // Remove eliminated identifiers from equalities.
899 shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true);
900
901 // Remove eliminated identifiers from inequalities.
902 shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false);
903
904 // Update members numDims, numSymbols and numIds.
905 unsigned numDimsEliminated = 0;
906 unsigned numLocalsEliminated = 0;
907 unsigned numColsEliminated = idLimit - idStart;
908 if (idStart < numDims) {
909 numDimsEliminated = std::min(numDims, idLimit) - idStart;
910 }
911 // Check how many local id's were removed. Note that our identifier order is
912 // [dims, symbols, locals]. Local id start at position numDims + numSymbols.
913 if (idLimit > numDims + numSymbols) {
914 numLocalsEliminated = std::min(
915 idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
916 }
917 unsigned numSymbolsEliminated =
918 numColsEliminated - numDimsEliminated - numLocalsEliminated;
919
920 numDims -= numDimsEliminated;
921 numSymbols -= numSymbolsEliminated;
922 numIds = numIds - numColsEliminated;
923
924 ids.erase(ids.begin() + idStart, ids.begin() + idLimit);
925
926 // No resize necessary. numReservedCols remains the same.
927 }
928
929 /// Returns the position of the identifier that has the minimum <number of lower
930 /// bounds> times <number of upper bounds> from the specified range of
931 /// identifiers [start, end). It is often best to eliminate in the increasing
932 /// order of these counts when doing Fourier-Motzkin elimination since FM adds
933 /// that many new constraints.
getBestIdToEliminate(const FlatAffineConstraints & cst,unsigned start,unsigned end)934 static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst,
935 unsigned start, unsigned end) {
936 assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
937
938 auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
939 unsigned numLb = 0;
940 unsigned numUb = 0;
941 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
942 if (cst.atIneq(r, pos) > 0) {
943 ++numLb;
944 } else if (cst.atIneq(r, pos) < 0) {
945 ++numUb;
946 }
947 }
948 return numLb * numUb;
949 };
950
951 unsigned minLoc = start;
952 unsigned min = getProductOfNumLowerUpperBounds(start);
953 for (unsigned c = start + 1; c < end; c++) {
954 unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
955 if (numLbUbProduct < min) {
956 min = numLbUbProduct;
957 minLoc = c;
958 }
959 }
960 return minLoc;
961 }
962
963 // Checks for emptiness of the set by eliminating identifiers successively and
964 // using the GCD test (on all equality constraints) and checking for trivially
965 // invalid constraints. Returns 'true' if the constraint system is found to be
966 // empty; false otherwise.
isEmpty() const967 bool FlatAffineConstraints::isEmpty() const {
968 if (isEmptyByGCDTest() || hasInvalidConstraint())
969 return true;
970
971 // First, eliminate as many identifiers as possible using Gaussian
972 // elimination.
973 FlatAffineConstraints tmpCst(*this);
974 unsigned currentPos = 0;
975 while (currentPos < tmpCst.getNumIds()) {
976 tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
977 ++currentPos;
978 // We check emptiness through trivial checks after eliminating each ID to
979 // detect emptiness early. Since the checks isEmptyByGCDTest() and
980 // hasInvalidConstraint() are linear time and single sweep on the constraint
981 // buffer, this appears reasonable - but can optimize in the future.
982 if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
983 return true;
984 }
985
986 // Eliminate the remaining using FM.
987 for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
988 tmpCst.FourierMotzkinEliminate(
989 getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
990 // Check for a constraint explosion. This rarely happens in practice, but
991 // this check exists as a safeguard against improperly constructed
992 // constraint systems or artificially created arbitrarily complex systems
993 // that aren't the intended use case for FlatAffineConstraints. This is
994 // needed since FM has a worst case exponential complexity in theory.
995 if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
996 LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
997 return false;
998 }
999
1000 // FM wouldn't have modified the equalities in any way. So no need to again
1001 // run GCD test. Check for trivial invalid constraints.
1002 if (tmpCst.hasInvalidConstraint())
1003 return true;
1004 }
1005 return false;
1006 }
1007
1008 // Runs the GCD test on all equality constraints. Returns 'true' if this test
1009 // fails on any equality. Returns 'false' otherwise.
1010 // This test can be used to disprove the existence of a solution. If it returns
1011 // true, no integer solution to the equality constraints can exist.
1012 //
1013 // GCD test definition:
1014 //
1015 // The equality constraint:
1016 //
1017 // c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
1018 //
1019 // has an integer solution iff:
1020 //
1021 // GCD of c_1, c_2, ..., c_n divides c_0.
1022 //
isEmptyByGCDTest() const1023 bool FlatAffineConstraints::isEmptyByGCDTest() const {
1024 assert(hasConsistentState());
1025 unsigned numCols = getNumCols();
1026 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1027 uint64_t gcd = std::abs(atEq(i, 0));
1028 for (unsigned j = 1; j < numCols - 1; ++j) {
1029 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
1030 }
1031 int64_t v = std::abs(atEq(i, numCols - 1));
1032 if (gcd > 0 && (v % gcd != 0)) {
1033 return true;
1034 }
1035 }
1036 return false;
1037 }
1038
1039 // First, try the GCD test heuristic.
1040 //
1041 // If that doesn't find the set empty, check if the set is unbounded. If it is,
1042 // we cannot use the GBR algorithm and we conservatively return false.
1043 //
1044 // If the set is bounded, we use the complete emptiness check for this case
1045 // provided by Simplex::findIntegerSample(), which gives a definitive answer.
isIntegerEmpty() const1046 bool FlatAffineConstraints::isIntegerEmpty() const {
1047 if (isEmptyByGCDTest())
1048 return true;
1049
1050 Simplex simplex(*this);
1051 if (simplex.isUnbounded())
1052 return false;
1053 return !simplex.findIntegerSample().hasValue();
1054 }
1055
1056 Optional<SmallVector<int64_t, 8>>
findIntegerSample() const1057 FlatAffineConstraints::findIntegerSample() const {
1058 return Simplex(*this).findIntegerSample();
1059 }
1060
1061 /// Tightens inequalities given that we are dealing with integer spaces. This is
1062 /// analogous to the GCD test but applied to inequalities. The constant term can
1063 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
1064 /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a
1065 /// fast method - linear in the number of coefficients.
1066 // Example on how this affects practical cases: consider the scenario:
1067 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
1068 // j >= 100 instead of the tighter (exact) j >= 128.
GCDTightenInequalities()1069 void FlatAffineConstraints::GCDTightenInequalities() {
1070 unsigned numCols = getNumCols();
1071 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1072 uint64_t gcd = std::abs(atIneq(i, 0));
1073 for (unsigned j = 1; j < numCols - 1; ++j) {
1074 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j)));
1075 }
1076 if (gcd > 0 && gcd != 1) {
1077 int64_t gcdI = static_cast<int64_t>(gcd);
1078 // Tighten the constant term and normalize the constraint by the GCD.
1079 atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI);
1080 for (unsigned j = 0, e = numCols - 1; j < e; ++j)
1081 atIneq(i, j) /= gcdI;
1082 }
1083 }
1084 }
1085
1086 // Eliminates all identifier variables in column range [posStart, posLimit).
1087 // Returns the number of variables eliminated.
gaussianEliminateIds(unsigned posStart,unsigned posLimit)1088 unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
1089 unsigned posLimit) {
1090 // Return if identifier positions to eliminate are out of range.
1091 assert(posLimit <= numIds);
1092 assert(hasConsistentState());
1093
1094 if (posStart >= posLimit)
1095 return 0;
1096
1097 GCDTightenInequalities();
1098
1099 unsigned pivotCol = 0;
1100 for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
1101 // Find a row which has a non-zero coefficient in column 'j'.
1102 unsigned pivotRow;
1103 if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true,
1104 &pivotRow)) {
1105 // No pivot row in equalities with non-zero at 'pivotCol'.
1106 if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false,
1107 &pivotRow)) {
1108 // If inequalities are also non-zero in 'pivotCol', it can be
1109 // eliminated.
1110 continue;
1111 }
1112 break;
1113 }
1114
1115 // Eliminate identifier at 'pivotCol' from each equality row.
1116 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1117 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1118 /*isEq=*/true);
1119 normalizeConstraintByGCD</*isEq=*/true>(this, i);
1120 }
1121
1122 // Eliminate identifier at 'pivotCol' from each inequality row.
1123 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1124 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1125 /*isEq=*/false);
1126 normalizeConstraintByGCD</*isEq=*/false>(this, i);
1127 }
1128 removeEquality(pivotRow);
1129 GCDTightenInequalities();
1130 }
1131 // Update position limit based on number eliminated.
1132 posLimit = pivotCol;
1133 // Remove eliminated columns from all constraints.
1134 removeIdRange(posStart, posLimit);
1135 return posLimit - posStart;
1136 }
1137
1138 // Detect the identifier at 'pos' (say id_r) as modulo of another identifier
1139 // (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
1140 // could be detected as the floordiv of n. For eg:
1141 // id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=>
1142 // id_r = id_n mod 4, id_q = id_n floordiv 4.
1143 // lbConst and ubConst are the constant lower and upper bounds for 'pos' -
1144 // pre-detected at the caller.
detectAsMod(const FlatAffineConstraints & cst,unsigned pos,int64_t lbConst,int64_t ubConst,SmallVectorImpl<AffineExpr> * memo)1145 static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
1146 int64_t lbConst, int64_t ubConst,
1147 SmallVectorImpl<AffineExpr> *memo) {
1148 assert(pos < cst.getNumIds() && "invalid position");
1149
1150 // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
1151 // id_n - divisor * id_q. If these are true, then id_n becomes the dividend
1152 // and id_q the quotient when dividing id_n by the divisor.
1153
1154 if (lbConst != 0 || ubConst < 1)
1155 return false;
1156
1157 int64_t divisor = ubConst + 1;
1158
1159 // Now check for: id_r = id_n - divisor * id_q. As an example, we
1160 // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
1161 unsigned seenQuotient = 0, seenDividend = 0;
1162 int quotientPos = -1, dividendPos = -1;
1163 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1164 // id_n should have coeff 1 or -1.
1165 if (std::abs(cst.atEq(r, pos)) != 1)
1166 continue;
1167 // constant term should be 0.
1168 if (cst.atEq(r, cst.getNumCols() - 1) != 0)
1169 continue;
1170 unsigned c, f;
1171 int quotientSign = 1, dividendSign = 1;
1172 for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
1173 if (c == pos)
1174 continue;
1175 // The coefficient of the quotient should be +/-divisor.
1176 // TODO: could be extended to detect an affine function for the quotient
1177 // (i.e., the coeff could be a non-zero multiple of divisor).
1178 int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
1179 if (v == divisor || v == -divisor) {
1180 seenQuotient++;
1181 quotientPos = c;
1182 quotientSign = v > 0 ? 1 : -1;
1183 }
1184 // The coefficient of the dividend should be +/-1.
1185 // TODO: could be extended to detect an affine function of the other
1186 // identifiers as the dividend.
1187 else if (v == -1 || v == 1) {
1188 seenDividend++;
1189 dividendPos = c;
1190 dividendSign = v < 0 ? 1 : -1;
1191 } else if (cst.atEq(r, c) != 0) {
1192 // Cannot be inferred as a mod since the constraint has a coefficient
1193 // for an identifier that's neither a unit nor the divisor (see TODOs
1194 // above).
1195 break;
1196 }
1197 }
1198 if (c < f)
1199 // Cannot be inferred as a mod since the constraint has a coefficient for
1200 // an identifier that's neither a unit nor the divisor (see TODOs above).
1201 continue;
1202
1203 // We are looking for exactly one identifier as the dividend.
1204 if (seenDividend == 1 && seenQuotient >= 1) {
1205 if (!(*memo)[dividendPos])
1206 return false;
1207 // Successfully detected a mod.
1208 (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
1209 auto ub = cst.getConstantUpperBound(dividendPos);
1210 if (ub.hasValue() && ub.getValue() < divisor)
1211 // The mod can be optimized away.
1212 (*memo)[pos] = (*memo)[dividendPos] * dividendSign;
1213 else
1214 (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
1215
1216 if (seenQuotient == 1 && !(*memo)[quotientPos])
1217 // Successfully detected a floordiv as well.
1218 (*memo)[quotientPos] =
1219 (*memo)[dividendPos].floorDiv(divisor) * quotientSign;
1220 return true;
1221 }
1222 }
1223 return false;
1224 }
1225
1226 /// Gather all lower and upper bounds of the identifier at `pos`, and
1227 /// optionally any equalities on it. In addition, the bounds are to be
1228 /// independent of identifiers in position range [`offset`, `offset` + `num`).
getLowerAndUpperBoundIndices(unsigned pos,SmallVectorImpl<unsigned> * lbIndices,SmallVectorImpl<unsigned> * ubIndices,SmallVectorImpl<unsigned> * eqIndices,unsigned offset,unsigned num) const1229 void FlatAffineConstraints::getLowerAndUpperBoundIndices(
1230 unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
1231 SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
1232 unsigned offset, unsigned num) const {
1233 assert(pos < getNumIds() && "invalid position");
1234 assert(offset + num < getNumCols() && "invalid range");
1235
1236 // Checks for a constraint that has a non-zero coeff for the identifiers in
1237 // the position range [offset, offset + num) while ignoring `pos`.
1238 auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
1239 unsigned c, f;
1240 auto cst = isEq ? getEquality(r) : getInequality(r);
1241 for (c = offset, f = offset + num; c < f; ++c) {
1242 if (c == pos)
1243 continue;
1244 if (cst[c] != 0)
1245 break;
1246 }
1247 return c < f;
1248 };
1249
1250 // Gather all lower bounds and upper bounds of the variable. Since the
1251 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1252 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1253 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1254 // The bounds are to be independent of [offset, offset + num) columns.
1255 if (containsConstraintDependentOnRange(r, /*isEq=*/false))
1256 continue;
1257 if (atIneq(r, pos) >= 1) {
1258 // Lower bound.
1259 lbIndices->push_back(r);
1260 } else if (atIneq(r, pos) <= -1) {
1261 // Upper bound.
1262 ubIndices->push_back(r);
1263 }
1264 }
1265
1266 // An equality is both a lower and upper bound. Record any equalities
1267 // involving the pos^th identifier.
1268 if (!eqIndices)
1269 return;
1270
1271 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1272 if (atEq(r, pos) == 0)
1273 continue;
1274 if (containsConstraintDependentOnRange(r, /*isEq=*/true))
1275 continue;
1276 eqIndices->push_back(r);
1277 }
1278 }
1279
1280 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
1281 /// function of other identifiers (where the divisor is a positive constant)
1282 /// given the initial set of expressions in `exprs`. If it can be, the
1283 /// corresponding position in `exprs` is set as the detected affine expr. For
1284 /// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can
1285 /// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
1286 /// <= i <= 32q + 31 => q = i floordiv 32.
detectAsFloorDiv(const FlatAffineConstraints & cst,unsigned pos,MLIRContext * context,SmallVectorImpl<AffineExpr> & exprs)1287 static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
1288 MLIRContext *context,
1289 SmallVectorImpl<AffineExpr> &exprs) {
1290 assert(pos < cst.getNumIds() && "invalid position");
1291
1292 SmallVector<unsigned, 4> lbIndices, ubIndices;
1293 cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices);
1294
1295 // Check if any lower bound, upper bound pair is of the form:
1296 // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id'
1297 // divisor * id <= expr <-- Upper bound for 'id'
1298 // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1).
1299 //
1300 // For example, if -32*k + 16*i + j >= 0
1301 // 32*k - 16*i - j + 31 >= 0 <=>
1302 // k = ( 16*i + j ) floordiv 32
1303 unsigned seenDividends = 0;
1304 for (auto ubPos : ubIndices) {
1305 for (auto lbPos : lbIndices) {
1306 // Check if the lower bound's constant term is divisor - 1. The
1307 // 'divisor' here is cst.atIneq(lbPos, pos) and we already know that it's
1308 // positive (since cst.Ineq(lbPos, ...) is a lower bound expr for 'pos'.
1309 int64_t divisor = cst.atIneq(lbPos, pos);
1310 int64_t lbConstTerm = cst.atIneq(lbPos, cst.getNumCols() - 1);
1311 if (lbConstTerm != divisor - 1)
1312 continue;
1313 // Check if upper bound's constant term is 0.
1314 if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0)
1315 continue;
1316 // For the remaining part, check if the lower bound expr's coeff's are
1317 // negations of corresponding upper bound ones'.
1318 unsigned c, f;
1319 for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1320 if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c))
1321 break;
1322 if (c != pos && cst.atIneq(lbPos, c) != 0)
1323 seenDividends++;
1324 }
1325 // Lb coeff's aren't negative of ub coeff's (for the non constant term
1326 // part).
1327 if (c < f)
1328 continue;
1329 if (seenDividends >= 1) {
1330 // Construct the dividend expression.
1331 auto dividendExpr = getAffineConstantExpr(0, context);
1332 unsigned c, f;
1333 for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1334 if (c == pos)
1335 continue;
1336 int64_t ubVal = cst.atIneq(ubPos, c);
1337 if (ubVal == 0)
1338 continue;
1339 if (!exprs[c])
1340 break;
1341 dividendExpr = dividendExpr + ubVal * exprs[c];
1342 }
1343 // Expression can't be constructed as it depends on a yet unknown
1344 // identifier.
1345 // TODO: Visit/compute the identifiers in an order so that this doesn't
1346 // happen. More complex but much more efficient.
1347 if (c < f)
1348 continue;
1349 // Successfully detected the floordiv.
1350 exprs[pos] = dividendExpr.floorDiv(divisor);
1351 return true;
1352 }
1353 }
1354 }
1355 return false;
1356 }
1357
1358 // Fills an inequality row with the value 'val'.
fillInequality(FlatAffineConstraints * cst,unsigned r,int64_t val)1359 static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
1360 int64_t val) {
1361 for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1362 cst->atIneq(r, c) = val;
1363 }
1364 }
1365
1366 // Negates an inequality.
negateInequality(FlatAffineConstraints * cst,unsigned r)1367 static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
1368 for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1369 cst->atIneq(r, c) = -cst->atIneq(r, c);
1370 }
1371 }
1372
1373 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
1374 // to check if a constraint is redundant.
removeRedundantInequalities()1375 void FlatAffineConstraints::removeRedundantInequalities() {
1376 SmallVector<bool, 32> redun(getNumInequalities(), false);
1377 // To check if an inequality is redundant, we replace the inequality by its
1378 // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
1379 // system is empty. If it is, the inequality is redundant.
1380 FlatAffineConstraints tmpCst(*this);
1381 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1382 // Change the inequality to its complement.
1383 negateInequality(&tmpCst, r);
1384 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
1385 if (tmpCst.isEmpty()) {
1386 redun[r] = true;
1387 // Zero fill the redundant inequality.
1388 fillInequality(this, r, /*val=*/0);
1389 fillInequality(&tmpCst, r, /*val=*/0);
1390 } else {
1391 // Reverse the change (to avoid recreating tmpCst each time).
1392 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
1393 negateInequality(&tmpCst, r);
1394 }
1395 }
1396
1397 // Scan to get rid of all rows marked redundant, in-place.
1398 auto copyRow = [&](unsigned src, unsigned dest) {
1399 if (src == dest)
1400 return;
1401 for (unsigned c = 0, e = getNumCols(); c < e; c++) {
1402 atIneq(dest, c) = atIneq(src, c);
1403 }
1404 };
1405 unsigned pos = 0;
1406 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1407 if (!redun[r])
1408 copyRow(r, pos++);
1409 }
1410 inequalities.resize(numReservedCols * pos);
1411 }
1412
getLowerAndUpperBound(unsigned pos,unsigned offset,unsigned num,unsigned symStartPos,ArrayRef<AffineExpr> localExprs,MLIRContext * context) const1413 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
1414 unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
1415 ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
1416 assert(pos + offset < getNumDimIds() && "invalid dim start pos");
1417 assert(symStartPos >= (pos + offset) && "invalid sym start pos");
1418 assert(getNumLocalIds() == localExprs.size() &&
1419 "incorrect local exprs count");
1420
1421 SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
1422 getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
1423 offset, num);
1424
1425 /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
1426 auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
1427 b.clear();
1428 for (unsigned i = 0, e = a.size(); i < e; ++i) {
1429 if (i < offset || i >= offset + num)
1430 b.push_back(a[i]);
1431 }
1432 };
1433
1434 SmallVector<int64_t, 8> lb, ub;
1435 SmallVector<AffineExpr, 4> lbExprs;
1436 unsigned dimCount = symStartPos - num;
1437 unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
1438 lbExprs.reserve(lbIndices.size() + eqIndices.size());
1439 // Lower bound expressions.
1440 for (auto idx : lbIndices) {
1441 auto ineq = getInequality(idx);
1442 // Extract the lower bound (in terms of other coeff's + const), i.e., if
1443 // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
1444 // - 1.
1445 addCoeffs(ineq, lb);
1446 std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
1447 auto expr =
1448 getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
1449 // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
1450 int64_t divisor = std::abs(ineq[pos + offset]);
1451 expr = (expr + divisor - 1).floorDiv(divisor);
1452 lbExprs.push_back(expr);
1453 }
1454
1455 SmallVector<AffineExpr, 4> ubExprs;
1456 ubExprs.reserve(ubIndices.size() + eqIndices.size());
1457 // Upper bound expressions.
1458 for (auto idx : ubIndices) {
1459 auto ineq = getInequality(idx);
1460 // Extract the upper bound (in terms of other coeff's + const).
1461 addCoeffs(ineq, ub);
1462 auto expr =
1463 getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
1464 expr = expr.floorDiv(std::abs(ineq[pos + offset]));
1465 // Upper bound is exclusive.
1466 ubExprs.push_back(expr + 1);
1467 }
1468
1469 // Equalities. It's both a lower and a upper bound.
1470 SmallVector<int64_t, 4> b;
1471 for (auto idx : eqIndices) {
1472 auto eq = getEquality(idx);
1473 addCoeffs(eq, b);
1474 if (eq[pos + offset] > 0)
1475 std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
1476
1477 // Extract the upper bound (in terms of other coeff's + const).
1478 auto expr =
1479 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1480 expr = expr.floorDiv(std::abs(eq[pos + offset]));
1481 // Upper bound is exclusive.
1482 ubExprs.push_back(expr + 1);
1483 // Lower bound.
1484 expr =
1485 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1486 expr = expr.ceilDiv(std::abs(eq[pos + offset]));
1487 lbExprs.push_back(expr);
1488 }
1489
1490 auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
1491 auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
1492
1493 return {lbMap, ubMap};
1494 }
1495
1496 /// Computes the lower and upper bounds of the first 'num' dimensional
1497 /// identifiers (starting at 'offset') as affine maps of the remaining
1498 /// identifiers (dimensional and symbolic identifiers). Local identifiers are
1499 /// themselves explicitly computed as affine functions of other identifiers in
1500 /// this process if needed.
getSliceBounds(unsigned offset,unsigned num,MLIRContext * context,SmallVectorImpl<AffineMap> * lbMaps,SmallVectorImpl<AffineMap> * ubMaps)1501 void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
1502 MLIRContext *context,
1503 SmallVectorImpl<AffineMap> *lbMaps,
1504 SmallVectorImpl<AffineMap> *ubMaps) {
1505 assert(num < getNumDimIds() && "invalid range");
1506
1507 // Basic simplification.
1508 normalizeConstraintsByGCD();
1509
1510 LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
1511 << " identifiers\n");
1512 LLVM_DEBUG(dump());
1513
1514 // Record computed/detected identifiers.
1515 SmallVector<AffineExpr, 8> memo(getNumIds());
1516 // Initialize dimensional and symbolic identifiers.
1517 for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
1518 if (i < offset)
1519 memo[i] = getAffineDimExpr(i, context);
1520 else if (i >= offset + num)
1521 memo[i] = getAffineDimExpr(i - num, context);
1522 }
1523 for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
1524 memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
1525
1526 bool changed;
1527 do {
1528 changed = false;
1529 // Identify yet unknown identifiers as constants or mod's / floordiv's of
1530 // other identifiers if possible.
1531 for (unsigned pos = 0; pos < getNumIds(); pos++) {
1532 if (memo[pos])
1533 continue;
1534
1535 auto lbConst = getConstantLowerBound(pos);
1536 auto ubConst = getConstantUpperBound(pos);
1537 if (lbConst.hasValue() && ubConst.hasValue()) {
1538 // Detect equality to a constant.
1539 if (lbConst.getValue() == ubConst.getValue()) {
1540 memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
1541 changed = true;
1542 continue;
1543 }
1544
1545 // Detect an identifier as modulo of another identifier w.r.t a
1546 // constant.
1547 if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
1548 &memo)) {
1549 changed = true;
1550 continue;
1551 }
1552 }
1553
1554 // Detect an identifier as a floordiv of an affine function of other
1555 // identifiers (divisor is a positive constant).
1556 if (detectAsFloorDiv(*this, pos, context, memo)) {
1557 changed = true;
1558 continue;
1559 }
1560
1561 // Detect an identifier as an expression of other identifiers.
1562 unsigned idx;
1563 if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) {
1564 continue;
1565 }
1566
1567 // Build AffineExpr solving for identifier 'pos' in terms of all others.
1568 auto expr = getAffineConstantExpr(0, context);
1569 unsigned j, e;
1570 for (j = 0, e = getNumIds(); j < e; ++j) {
1571 if (j == pos)
1572 continue;
1573 int64_t c = atEq(idx, j);
1574 if (c == 0)
1575 continue;
1576 // If any of the involved IDs hasn't been found yet, we can't proceed.
1577 if (!memo[j])
1578 break;
1579 expr = expr + memo[j] * c;
1580 }
1581 if (j < e)
1582 // Can't construct expression as it depends on a yet uncomputed
1583 // identifier.
1584 continue;
1585
1586 // Add constant term to AffineExpr.
1587 expr = expr + atEq(idx, getNumIds());
1588 int64_t vPos = atEq(idx, pos);
1589 assert(vPos != 0 && "expected non-zero here");
1590 if (vPos > 0)
1591 expr = (-expr).floorDiv(vPos);
1592 else
1593 // vPos < 0.
1594 expr = expr.floorDiv(-vPos);
1595 // Successfully constructed expression.
1596 memo[pos] = expr;
1597 changed = true;
1598 }
1599 // This loop is guaranteed to reach a fixed point - since once an
1600 // identifier's explicit form is computed (in memo[pos]), it's not updated
1601 // again.
1602 } while (changed);
1603
1604 // Set the lower and upper bound maps for all the identifiers that were
1605 // computed as affine expressions of the rest as the "detected expr" and
1606 // "detected expr + 1" respectively; set the undetected ones to null.
1607 Optional<FlatAffineConstraints> tmpClone;
1608 for (unsigned pos = 0; pos < num; pos++) {
1609 unsigned numMapDims = getNumDimIds() - num;
1610 unsigned numMapSymbols = getNumSymbolIds();
1611 AffineExpr expr = memo[pos + offset];
1612 if (expr)
1613 expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
1614
1615 AffineMap &lbMap = (*lbMaps)[pos];
1616 AffineMap &ubMap = (*ubMaps)[pos];
1617
1618 if (expr) {
1619 lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
1620 ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
1621 } else {
1622 // TODO: Whenever there are local identifiers in the dependence
1623 // constraints, we'll conservatively over-approximate, since we don't
1624 // always explicitly compute them above (in the while loop).
1625 if (getNumLocalIds() == 0) {
1626 // Work on a copy so that we don't update this constraint system.
1627 if (!tmpClone) {
1628 tmpClone.emplace(FlatAffineConstraints(*this));
1629 // Removing redundant inequalities is necessary so that we don't get
1630 // redundant loop bounds.
1631 tmpClone->removeRedundantInequalities();
1632 }
1633 std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
1634 pos, offset, num, getNumDimIds(), /*localExprs=*/{}, context);
1635 }
1636
1637 // If the above fails, we'll just use the constant lower bound and the
1638 // constant upper bound (if they exist) as the slice bounds.
1639 // TODO: being conservative for the moment in cases that
1640 // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
1641 // fixed (b/126426796).
1642 if (!lbMap || lbMap.getNumResults() > 1) {
1643 LLVM_DEBUG(llvm::dbgs()
1644 << "WARNING: Potentially over-approximating slice lb\n");
1645 auto lbConst = getConstantLowerBound(pos + offset);
1646 if (lbConst.hasValue()) {
1647 lbMap = AffineMap::get(
1648 numMapDims, numMapSymbols,
1649 getAffineConstantExpr(lbConst.getValue(), context));
1650 }
1651 }
1652 if (!ubMap || ubMap.getNumResults() > 1) {
1653 LLVM_DEBUG(llvm::dbgs()
1654 << "WARNING: Potentially over-approximating slice ub\n");
1655 auto ubConst = getConstantUpperBound(pos + offset);
1656 if (ubConst.hasValue()) {
1657 (ubMap) = AffineMap::get(
1658 numMapDims, numMapSymbols,
1659 getAffineConstantExpr(ubConst.getValue() + 1, context));
1660 }
1661 }
1662 }
1663 LLVM_DEBUG(llvm::dbgs()
1664 << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
1665 LLVM_DEBUG(lbMap.dump(););
1666 LLVM_DEBUG(llvm::dbgs()
1667 << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
1668 LLVM_DEBUG(ubMap.dump(););
1669 }
1670 }
1671
1672 LogicalResult
addLowerOrUpperBound(unsigned pos,AffineMap boundMap,ValueRange boundOperands,bool eq,bool lower)1673 FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
1674 ValueRange boundOperands, bool eq,
1675 bool lower) {
1676 assert(pos < getNumDimAndSymbolIds() && "invalid position");
1677 // Equality follows the logic of lower bound except that we add an equality
1678 // instead of an inequality.
1679 assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
1680 if (eq)
1681 lower = true;
1682
1683 // Fully compose map and operands; canonicalize and simplify so that we
1684 // transitively get to terminal symbols or loop IVs.
1685 auto map = boundMap;
1686 SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
1687 fullyComposeAffineMapAndOperands(&map, &operands);
1688 map = simplifyAffineMap(map);
1689 canonicalizeMapAndOperands(&map, &operands);
1690 for (auto operand : operands)
1691 addInductionVarOrTerminalSymbol(operand);
1692
1693 FlatAffineConstraints localVarCst;
1694 std::vector<SmallVector<int64_t, 8>> flatExprs;
1695 if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) {
1696 LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
1697 return failure();
1698 }
1699
1700 // Merge and align with localVarCst.
1701 if (localVarCst.getNumLocalIds() > 0) {
1702 // Set values for localVarCst.
1703 localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
1704 for (auto operand : operands) {
1705 unsigned pos;
1706 if (findId(operand, &pos)) {
1707 if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
1708 // If the local var cst has this as a dim, turn it into its symbol.
1709 turnDimIntoSymbol(&localVarCst, operand);
1710 } else if (pos < getNumDimIds()) {
1711 // Or vice versa.
1712 turnSymbolIntoDim(&localVarCst, operand);
1713 }
1714 }
1715 }
1716 mergeAndAlignIds(/*offset=*/0, this, &localVarCst);
1717 append(localVarCst);
1718 }
1719
1720 // Record positions of the operands in the constraint system. Need to do
1721 // this here since the constraint system changes after a bound is added.
1722 SmallVector<unsigned, 8> positions;
1723 unsigned numOperands = operands.size();
1724 for (auto operand : operands) {
1725 unsigned pos;
1726 if (!findId(operand, &pos))
1727 assert(0 && "expected to be found");
1728 positions.push_back(pos);
1729 }
1730
1731 for (const auto &flatExpr : flatExprs) {
1732 SmallVector<int64_t, 4> ineq(getNumCols(), 0);
1733 ineq[pos] = lower ? 1 : -1;
1734 // Dims and symbols.
1735 for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) {
1736 ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
1737 }
1738 // Copy over the local id coefficients.
1739 unsigned numLocalIds = flatExpr.size() - 1 - numOperands;
1740 for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds;
1741 jj++, j++) {
1742 ineq[j] =
1743 lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj];
1744 }
1745 // Constant term.
1746 ineq[getNumCols() - 1] =
1747 lower ? -flatExpr[flatExpr.size() - 1]
1748 // Upper bound in flattenedExpr is an exclusive one.
1749 : flatExpr[flatExpr.size() - 1] - 1;
1750 eq ? addEquality(ineq) : addInequality(ineq);
1751 }
1752 return success();
1753 }
1754
1755 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
1756 // bounds in 'ubMaps' to each value in `values' that appears in the constraint
1757 // system. Note that both lower/upper bounds share the same operand list
1758 // 'operands'.
1759 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
1760 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
1761 // Note that both lower/upper bounds use operands from 'operands'.
1762 // Returns failure for unimplemented cases such as semi-affine expressions or
1763 // expressions with mod/floordiv.
addSliceBounds(ArrayRef<Value> values,ArrayRef<AffineMap> lbMaps,ArrayRef<AffineMap> ubMaps,ArrayRef<Value> operands)1764 LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values,
1765 ArrayRef<AffineMap> lbMaps,
1766 ArrayRef<AffineMap> ubMaps,
1767 ArrayRef<Value> operands) {
1768 assert(values.size() == lbMaps.size());
1769 assert(lbMaps.size() == ubMaps.size());
1770
1771 for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
1772 unsigned pos;
1773 if (!findId(values[i], &pos))
1774 continue;
1775
1776 AffineMap lbMap = lbMaps[i];
1777 AffineMap ubMap = ubMaps[i];
1778 assert(!lbMap || lbMap.getNumInputs() == operands.size());
1779 assert(!ubMap || ubMap.getNumInputs() == operands.size());
1780
1781 // Check if this slice is just an equality along this dimension.
1782 if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
1783 ubMap.getNumResults() == 1 &&
1784 lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
1785 if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
1786 /*lower=*/true)))
1787 return failure();
1788 continue;
1789 }
1790
1791 if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
1792 /*lower=*/true)))
1793 return failure();
1794
1795 if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
1796 /*lower=*/false)))
1797 return failure();
1798 }
1799 return success();
1800 }
1801
addEquality(ArrayRef<int64_t> eq)1802 void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
1803 assert(eq.size() == getNumCols());
1804 unsigned offset = equalities.size();
1805 equalities.resize(equalities.size() + numReservedCols);
1806 std::copy(eq.begin(), eq.end(), equalities.begin() + offset);
1807 }
1808
addInequality(ArrayRef<int64_t> inEq)1809 void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
1810 assert(inEq.size() == getNumCols());
1811 unsigned offset = inequalities.size();
1812 inequalities.resize(inequalities.size() + numReservedCols);
1813 std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset);
1814 }
1815
addConstantLowerBound(unsigned pos,int64_t lb)1816 void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
1817 assert(pos < getNumCols());
1818 unsigned offset = inequalities.size();
1819 inequalities.resize(inequalities.size() + numReservedCols);
1820 std::fill(inequalities.begin() + offset,
1821 inequalities.begin() + offset + getNumCols(), 0);
1822 inequalities[offset + pos] = 1;
1823 inequalities[offset + getNumCols() - 1] = -lb;
1824 }
1825
addConstantUpperBound(unsigned pos,int64_t ub)1826 void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) {
1827 assert(pos < getNumCols());
1828 unsigned offset = inequalities.size();
1829 inequalities.resize(inequalities.size() + numReservedCols);
1830 std::fill(inequalities.begin() + offset,
1831 inequalities.begin() + offset + getNumCols(), 0);
1832 inequalities[offset + pos] = -1;
1833 inequalities[offset + getNumCols() - 1] = ub;
1834 }
1835
addConstantLowerBound(ArrayRef<int64_t> expr,int64_t lb)1836 void FlatAffineConstraints::addConstantLowerBound(ArrayRef<int64_t> expr,
1837 int64_t lb) {
1838 assert(expr.size() == getNumCols());
1839 unsigned offset = inequalities.size();
1840 inequalities.resize(inequalities.size() + numReservedCols);
1841 std::fill(inequalities.begin() + offset,
1842 inequalities.begin() + offset + getNumCols(), 0);
1843 std::copy(expr.begin(), expr.end(), inequalities.begin() + offset);
1844 inequalities[offset + getNumCols() - 1] += -lb;
1845 }
1846
addConstantUpperBound(ArrayRef<int64_t> expr,int64_t ub)1847 void FlatAffineConstraints::addConstantUpperBound(ArrayRef<int64_t> expr,
1848 int64_t ub) {
1849 assert(expr.size() == getNumCols());
1850 unsigned offset = inequalities.size();
1851 inequalities.resize(inequalities.size() + numReservedCols);
1852 std::fill(inequalities.begin() + offset,
1853 inequalities.begin() + offset + getNumCols(), 0);
1854 for (unsigned i = 0, e = getNumCols(); i < e; i++) {
1855 inequalities[offset + i] = -expr[i];
1856 }
1857 inequalities[offset + getNumCols() - 1] += ub;
1858 }
1859
1860 /// Adds a new local identifier as the floordiv of an affine function of other
1861 /// identifiers, the coefficients of which are provided in 'dividend' and with
1862 /// respect to a positive constant 'divisor'. Two constraints are added to the
1863 /// system to capture equivalence with the floordiv.
1864 /// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1.
addLocalFloorDiv(ArrayRef<int64_t> dividend,int64_t divisor)1865 void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend,
1866 int64_t divisor) {
1867 assert(dividend.size() == getNumCols() && "incorrect dividend size");
1868 assert(divisor > 0 && "positive divisor expected");
1869
1870 addLocalId(getNumLocalIds());
1871
1872 // Add two constraints for this new identifier 'q'.
1873 SmallVector<int64_t, 8> bound(dividend.size() + 1);
1874
1875 // dividend - q * divisor >= 0
1876 std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
1877 bound.begin());
1878 bound.back() = dividend.back();
1879 bound[getNumIds() - 1] = -divisor;
1880 addInequality(bound);
1881
1882 // -dividend +qdivisor * q + divisor - 1 >= 0
1883 std::transform(bound.begin(), bound.end(), bound.begin(),
1884 std::negate<int64_t>());
1885 bound[bound.size() - 1] += divisor - 1;
1886 addInequality(bound);
1887 }
1888
findId(Value id,unsigned * pos) const1889 bool FlatAffineConstraints::findId(Value id, unsigned *pos) const {
1890 unsigned i = 0;
1891 for (const auto &mayBeId : ids) {
1892 if (mayBeId.hasValue() && mayBeId.getValue() == id) {
1893 *pos = i;
1894 return true;
1895 }
1896 i++;
1897 }
1898 return false;
1899 }
1900
containsId(Value id) const1901 bool FlatAffineConstraints::containsId(Value id) const {
1902 return llvm::any_of(ids, [&](const Optional<Value> &mayBeId) {
1903 return mayBeId.hasValue() && mayBeId.getValue() == id;
1904 });
1905 }
1906
setDimSymbolSeparation(unsigned newSymbolCount)1907 void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
1908 assert(newSymbolCount <= numDims + numSymbols &&
1909 "invalid separation position");
1910 numDims = numDims + numSymbols - newSymbolCount;
1911 numSymbols = newSymbolCount;
1912 }
1913
1914 /// Sets the specified identifier to a constant value.
setIdToConstant(unsigned pos,int64_t val)1915 void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
1916 unsigned offset = equalities.size();
1917 equalities.resize(equalities.size() + numReservedCols);
1918 std::fill(equalities.begin() + offset,
1919 equalities.begin() + offset + getNumCols(), 0);
1920 equalities[offset + pos] = 1;
1921 equalities[offset + getNumCols() - 1] = -val;
1922 }
1923
1924 /// Sets the specified identifier to a constant value; asserts if the id is not
1925 /// found.
setIdToConstant(Value id,int64_t val)1926 void FlatAffineConstraints::setIdToConstant(Value id, int64_t val) {
1927 unsigned pos;
1928 if (!findId(id, &pos))
1929 // This is a pre-condition for this method.
1930 assert(0 && "id not found");
1931 setIdToConstant(pos, val);
1932 }
1933
removeEquality(unsigned pos)1934 void FlatAffineConstraints::removeEquality(unsigned pos) {
1935 unsigned numEqualities = getNumEqualities();
1936 assert(pos < numEqualities);
1937 unsigned outputIndex = pos * numReservedCols;
1938 unsigned inputIndex = (pos + 1) * numReservedCols;
1939 unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols;
1940 std::copy(equalities.begin() + inputIndex,
1941 equalities.begin() + inputIndex + numElemsToCopy,
1942 equalities.begin() + outputIndex);
1943 assert(equalities.size() >= numReservedCols);
1944 equalities.resize(equalities.size() - numReservedCols);
1945 }
1946
removeInequality(unsigned pos)1947 void FlatAffineConstraints::removeInequality(unsigned pos) {
1948 unsigned numInequalities = getNumInequalities();
1949 assert(pos < numInequalities && "invalid position");
1950 unsigned outputIndex = pos * numReservedCols;
1951 unsigned inputIndex = (pos + 1) * numReservedCols;
1952 unsigned numElemsToCopy = (numInequalities - pos - 1) * numReservedCols;
1953 std::copy(inequalities.begin() + inputIndex,
1954 inequalities.begin() + inputIndex + numElemsToCopy,
1955 inequalities.begin() + outputIndex);
1956 assert(inequalities.size() >= numReservedCols);
1957 inequalities.resize(inequalities.size() - numReservedCols);
1958 }
1959
1960 /// Finds an equality that equates the specified identifier to a constant.
1961 /// Returns the position of the equality row. If 'symbolic' is set to true,
1962 /// symbols are also treated like a constant, i.e., an affine function of the
1963 /// symbols is also treated like a constant. Returns -1 if such an equality
1964 /// could not be found.
findEqualityToConstant(const FlatAffineConstraints & cst,unsigned pos,bool symbolic=false)1965 static int findEqualityToConstant(const FlatAffineConstraints &cst,
1966 unsigned pos, bool symbolic = false) {
1967 assert(pos < cst.getNumIds() && "invalid position");
1968 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1969 int64_t v = cst.atEq(r, pos);
1970 if (v * v != 1)
1971 continue;
1972 unsigned c;
1973 unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
1974 // This checks for zeros in all positions other than 'pos' in [0, f)
1975 for (c = 0; c < f; c++) {
1976 if (c == pos)
1977 continue;
1978 if (cst.atEq(r, c) != 0) {
1979 // Dependent on another identifier.
1980 break;
1981 }
1982 }
1983 if (c == f)
1984 // Equality is free of other identifiers.
1985 return r;
1986 }
1987 return -1;
1988 }
1989
setAndEliminate(unsigned pos,int64_t constVal)1990 void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) {
1991 assert(pos < getNumIds() && "invalid position");
1992 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1993 atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal;
1994 }
1995 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1996 atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal;
1997 }
1998 removeId(pos);
1999 }
2000
constantFoldId(unsigned pos)2001 LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) {
2002 assert(pos < getNumIds() && "invalid position");
2003 int rowIdx;
2004 if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
2005 return failure();
2006
2007 // atEq(rowIdx, pos) is either -1 or 1.
2008 assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
2009 int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
2010 setAndEliminate(pos, constVal);
2011 return success();
2012 }
2013
constantFoldIdRange(unsigned pos,unsigned num)2014 void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
2015 for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
2016 if (failed(constantFoldId(t)))
2017 t++;
2018 }
2019 }
2020
/// Returns the extent (upper bound - lower bound) of the specified
/// identifier if it is found to be a constant; returns None if it's not a
/// constant. This methods treats symbolic identifiers specially, i.e.,
/// it looks for constant differences between affine expressions involving
/// only the symbolic identifiers. See the examples below. 'lb', if provided,
/// is set to the lower bound associated with the constant difference. Note
/// that 'lb' is purely symbolic and thus will contain the coefficients of the
/// symbolic identifiers and the constant coefficient.
///
/// Output parameters (all optional):
///   - 'lb' / 'boundFloorDivisor': the symbolic lower bound is
///     'lb floordiv boundFloorDivisor'. 'boundFloorDivisor' is written
///     whenever 'lb' is filled in, so it must be non-null if 'lb' is.
///   - 'ub': coefficients taken from the matching upper bound; only filled in
///     when 'lb' is provided as well.
///   - 'minLbPos' / 'minUbPos': rows of the constraints that yielded the
///     result (the equality row for both when an equality was used).
//  Egs: 0 <= i <= 15, return 16.
//       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
//       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
//       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
//       ceil(s0 - 7 / 8) = floor(s0 / 8)).
Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
    unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
    SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
    unsigned *minUbPos) const {
  assert(pos < getNumDimIds() && "Invalid identifier position");

  // Find an equality for 'pos'^th identifier that equates it to some function
  // of the symbolic identifiers (+ constant).
  int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
  if (eqPos != -1) {
    auto eq = getEquality(eqPos);
    // If the equality involves a local var, punt for now.
    // TODO: this can be handled in the future by using the explicit
    // representation of the local vars.
    if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
                     [](int64_t coeff) { return coeff == 0; }))
      return None;

    // This identifier can only take a single value.
    if (lb) {
      // Set lb to that symbolic value.
      lb->resize(getNumSymbolIds() + 1);
      if (ub)
        ub->resize(getNumSymbolIds() + 1);
      // Walk the symbol columns plus the constant column (which directly
      // follows the symbols when there are no local ids).
      for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
        int64_t v = atEq(eqPos, pos);
        // atEq(eqRow, pos) is either -1 or 1.
        assert(v * v == 1);
        // Solve the equality for the identifier: move the other terms to the
        // opposite side and divide by the (unit) coefficient.
        (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v
                         : -atEq(eqPos, getNumDimIds() + c) / v;
        // Since this is an equality, ub = lb.
        if (ub)
          (*ub)[c] = (*lb)[c];
      }
      assert(boundFloorDivisor &&
             "both lb and divisor or none should be provided");
      *boundFloorDivisor = 1;
    }
    if (minLbPos)
      *minLbPos = eqPos;
    if (minUbPos)
      *minUbPos = eqPos;
    return 1;
  }

  // Check if the identifier appears at all in any of the inequalities.
  unsigned r, e;
  for (r = 0, e = getNumInequalities(); r < e; r++) {
    if (atIneq(r, pos) != 0)
      break;
  }
  if (r == e)
    // If it doesn't, there isn't a bound on it.
    return None;

  // Positions of constraints that are lower/upper bounds on the variable.
  SmallVector<unsigned, 4> lbIndices, ubIndices;

  // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
  // the bounds can only involve symbolic (and local) identifiers. Since the
  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
  getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
                               /*eqIndices=*/nullptr, /*offset=*/0,
                               /*num=*/getNumDimIds());

  Optional<int64_t> minDiff = None;
  unsigned minLbPosition = 0, minUbPosition = 0;
  for (auto ubPos : ubIndices) {
    for (auto lbPos : lbIndices) {
      // Look for a lower bound and an upper bound that only differ by a
      // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
      // For example, if ii is the pos^th variable, we are looking for
      // constraints like ii >= i, ii <= i + 50, 50 being the difference. The
      // minimum among all such constant differences is kept since that's the
      // constant bounding the extent of the pos^th variable.
      unsigned j, e;
      // The two rows qualify only if every non-constant coefficient of the
      // upper bound is the negation of the lower bound's.
      for (j = 0, e = getNumCols() - 1; j < e; j++)
        if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
          break;
        }
      if (j < getNumCols() - 1)
        continue;
      int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
                                 atIneq(lbPos, getNumCols() - 1) + 1,
                             atIneq(lbPos, pos));
      if (minDiff == None || diff < minDiff) {
        minDiff = diff;
        minLbPosition = lbPos;
        minUbPosition = ubPos;
      }
    }
  }
  if (lb && minDiff.hasValue()) {
    // Set lb to the symbolic lower bound.
    lb->resize(getNumSymbolIds() + 1);
    if (ub)
      ub->resize(getNumSymbolIds() + 1);
    // The lower bound is the ceildiv of the lb constraint over the coefficient
    // of the variable at 'pos'. We express the ceildiv equivalently as a floor
    // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
    // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
    // NOTE: 'boundFloorDivisor' is dereferenced here without a null check.
    *boundFloorDivisor = atIneq(minLbPosition, pos);
    assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
    for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
      (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
    }
    if (ub) {
      for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
        (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
    }
    // The lower bound leads to a ceildiv while the upper bound is a floordiv
    // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
    // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
    // the constant term for the lower bound.
    (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
  }
  // These are reported even when no constant difference was found, in which
  // case both positions still hold their initial value of 0.
  if (minLbPos)
    *minLbPos = minLbPosition;
  if (minUbPos)
    *minUbPos = minUbPosition;
  return minDiff;
}
2157
2158 template <bool isLower>
2159 Optional<int64_t>
computeConstantLowerOrUpperBound(unsigned pos)2160 FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
2161 assert(pos < getNumIds() && "invalid position");
2162 // Project to 'pos'.
2163 projectOut(0, pos);
2164 projectOut(1, getNumIds() - 1);
2165 // Check if there's an equality equating the '0'^th identifier to a constant.
2166 int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
2167 if (eqRowIdx != -1)
2168 // atEq(rowIdx, 0) is either -1 or 1.
2169 return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
2170
2171 // Check if the identifier appears at all in any of the inequalities.
2172 unsigned r, e;
2173 for (r = 0, e = getNumInequalities(); r < e; r++) {
2174 if (atIneq(r, 0) != 0)
2175 break;
2176 }
2177 if (r == e)
2178 // If it doesn't, there isn't a bound on it.
2179 return None;
2180
2181 Optional<int64_t> minOrMaxConst = None;
2182
2183 // Take the max across all const lower bounds (or min across all constant
2184 // upper bounds).
2185 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2186 if (isLower) {
2187 if (atIneq(r, 0) <= 0)
2188 // Not a lower bound.
2189 continue;
2190 } else if (atIneq(r, 0) >= 0) {
2191 // Not an upper bound.
2192 continue;
2193 }
2194 unsigned c, f;
2195 for (c = 0, f = getNumCols() - 1; c < f; c++)
2196 if (c != 0 && atIneq(r, c) != 0)
2197 break;
2198 if (c < getNumCols() - 1)
2199 // Not a constant bound.
2200 continue;
2201
2202 int64_t boundConst =
2203 isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
2204 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
2205 if (isLower) {
2206 if (minOrMaxConst == None || boundConst > minOrMaxConst)
2207 minOrMaxConst = boundConst;
2208 } else {
2209 if (minOrMaxConst == None || boundConst < minOrMaxConst)
2210 minOrMaxConst = boundConst;
2211 }
2212 }
2213 return minOrMaxConst;
2214 }
2215
2216 Optional<int64_t>
getConstantLowerBound(unsigned pos) const2217 FlatAffineConstraints::getConstantLowerBound(unsigned pos) const {
2218 FlatAffineConstraints tmpCst(*this);
2219 return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
2220 }
2221
2222 Optional<int64_t>
getConstantUpperBound(unsigned pos) const2223 FlatAffineConstraints::getConstantUpperBound(unsigned pos) const {
2224 FlatAffineConstraints tmpCst(*this);
2225 return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
2226 }
2227
2228 // A simple (naive and conservative) check for hyper-rectangularity.
isHyperRectangular(unsigned pos,unsigned num) const2229 bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
2230 unsigned num) const {
2231 assert(pos < getNumCols() - 1);
2232 // Check for two non-zero coefficients in the range [pos, pos + sum).
2233 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2234 unsigned sum = 0;
2235 for (unsigned c = pos; c < pos + num; c++) {
2236 if (atIneq(r, c) != 0)
2237 sum++;
2238 }
2239 if (sum > 1)
2240 return false;
2241 }
2242 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2243 unsigned sum = 0;
2244 for (unsigned c = pos; c < pos + num; c++) {
2245 if (atEq(r, c) != 0)
2246 sum++;
2247 }
2248 if (sum > 1)
2249 return false;
2250 }
2251 return true;
2252 }
2253
print(raw_ostream & os) const2254 void FlatAffineConstraints::print(raw_ostream &os) const {
2255 assert(hasConsistentState());
2256 os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds()
2257 << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints()
2258 << " constraints)\n";
2259 os << "(";
2260 for (unsigned i = 0, e = getNumIds(); i < e; i++) {
2261 if (ids[i] == None)
2262 os << "None ";
2263 else
2264 os << "Value ";
2265 }
2266 os << " const)\n";
2267 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
2268 for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2269 os << atEq(i, j) << " ";
2270 }
2271 os << "= 0\n";
2272 }
2273 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2274 for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2275 os << atIneq(i, j) << " ";
2276 }
2277 os << ">= 0\n";
2278 }
2279 os << '\n';
2280 }
2281
dump() const2282 void FlatAffineConstraints::dump() const { print(llvm::errs()); }
2283
/// Removes duplicate constraints, trivially true constraints, and constraints
/// that can be detected as redundant as a result of differing only in their
/// constant term part. A constraint of the form <non-negative constant> >= 0 is
/// considered trivially true.
/// Only inequalities are pruned; the inequalities are first tightened and all
/// constraints normalized by GCD so that equivalent rows compare equal.
// Uses a DenseSet to hash and detect duplicates followed by a linear scan to
// remove duplicates in place.
void FlatAffineConstraints::removeTrivialRedundancy() {
  GCDTightenInequalities();
  normalizeConstraintsByGCD();

  // A map used to detect redundancy stemming from constraints that only differ
  // in their constant term. The value stored is <row position, const term>
  // for a given row.
  SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
      rowsWithoutConstTerm;
  // To unique rows.
  SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;

  // Check if constraint is of the form <non-negative-constant> >= 0.
  auto isTriviallyValid = [&](unsigned r) -> bool {
    for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
      if (atIneq(r, c) != 0)
        return false;
    }
    return atIneq(r, getNumCols() - 1) >= 0;
  };

  // Detect and mark redundant constraints.
  SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
    // The keys below are views straight into the 'inequalities' storage, so
    // that storage must stay untouched until the marking pass is done.
    int64_t *rowStart = inequalities.data() + numReservedCols * r;
    auto row = ArrayRef<int64_t>(rowStart, getNumCols());
    // Trivially true rows and exact duplicates of an earlier row go away.
    if (isTriviallyValid(r) || !rowSet.insert(row).second) {
      redunIneq[r] = true;
      continue;
    }

    // Among constraints that only differ in the constant term part, mark
    // everything other than the one with the smallest constant term redundant.
    // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
    // former two are redundant).
    int64_t constTerm = atIneq(r, getNumCols() - 1);
    auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
    const auto &ret =
        rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
    if (!ret.second) {
      // Check if the other constraint has a higher constant term.
      auto &val = ret.first->second;
      if (val.second > constTerm) {
        // The stored row is redundant. Mark it so, and update with this one.
        redunIneq[val.first] = true;
        val = {r, constTerm};
      } else {
        // The one stored makes this one redundant.
        redunIneq[r] = true;
      }
    }
  }

  // Copies row 'src' of the inequalities over row 'dest' (no-op if equal).
  auto copyRow = [&](unsigned src, unsigned dest) {
    if (src == dest)
      return;
    for (unsigned c = 0, e = getNumCols(); c < e; c++) {
      atIneq(dest, c) = atIneq(src, c);
    }
  };

  // Scan to get rid of all rows marked redundant, in-place.
  // 'pos' counts the rows kept so far, which is also where the next kept row
  // is written.
  unsigned pos = 0;
  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
    if (!redunIneq[r])
      copyRow(r, pos++);
  }
  inequalities.resize(numReservedCols * pos);

  // TODO: consider doing this for equalities as well, but probably not worth
  // the savings.
}
2362
clearAndCopyFrom(const FlatAffineConstraints & other)2363 void FlatAffineConstraints::clearAndCopyFrom(
2364 const FlatAffineConstraints &other) {
2365 FlatAffineConstraints copy(other);
2366 std::swap(*this, copy);
2367 assert(copy.getNumIds() == copy.getIds().size());
2368 }
2369
removeId(unsigned pos)2370 void FlatAffineConstraints::removeId(unsigned pos) {
2371 removeIdRange(pos, pos + 1);
2372 }
2373
2374 static std::pair<unsigned, unsigned>
getNewNumDimsSymbols(unsigned pos,const FlatAffineConstraints & cst)2375 getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
2376 unsigned numDims = cst.getNumDimIds();
2377 unsigned numSymbols = cst.getNumSymbolIds();
2378 unsigned newNumDims, newNumSymbols;
2379 if (pos < numDims) {
2380 newNumDims = numDims - 1;
2381 newNumSymbols = numSymbols;
2382 } else if (pos < numDims + numSymbols) {
2383 assert(numSymbols >= 1);
2384 newNumDims = numDims;
2385 newNumSymbols = numSymbols - 1;
2386 } else {
2387 newNumDims = numDims;
2388 newNumSymbols = numSymbols;
2389 }
2390 return {newNumDims, newNumSymbols};
2391 }
2392
2393 #undef DEBUG_TYPE
2394 #define DEBUG_TYPE "fm"
2395
2396 /// Eliminates identifier at the specified position using Fourier-Motzkin
2397 /// variable elimination. This technique is exact for rational spaces but
2398 /// conservative (in "rare" cases) for integer spaces. The operation corresponds
2399 /// to a projection operation yielding the (convex) set of integer points
2400 /// contained in the rational shadow of the set. An emptiness test that relies
2401 /// on this method will guarantee emptiness, i.e., it disproves the existence of
2402 /// a solution if it says it's empty.
2403 /// If a non-null isResultIntegerExact is passed, it is set to true if the
2404 /// result is also integer exact. If it's set to false, the obtained solution
2405 /// *may* not be exact, i.e., it may contain integer points that do not have an
2406 /// integer pre-image in the original set.
2407 ///
2408 /// Eg:
2409 /// j >= 0, j <= i + 1
2410 /// i >= 0, i <= N + 1
2411 /// Eliminating i yields,
2412 /// j >= 0, 0 <= N + 1, j - 1 <= N + 1
2413 ///
2414 /// If darkShadow = true, this method computes the dark shadow on elimination;
2415 /// the dark shadow is a convex integer subset of the exact integer shadow. A
2416 /// non-empty dark shadow proves the existence of an integer solution. The
2417 /// elimination in such a case could however be an under-approximation, and thus
2418 /// should not be used for scanning sets or used by itself for dependence
2419 /// checking.
2420 ///
2421 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
2422 /// ^
2423 /// |
2424 /// | * * * * o o
2425 /// i | * * o o o o
2426 /// | o * * * * *
2427 /// --------------->
2428 /// j ->
2429 ///
2430 /// Eliminating i from this system (projecting on the j dimension):
2431 /// rational shadow / integer light shadow: 1 <= j <= 6
2432 /// dark shadow: 3 <= j <= 6
2433 /// exact integer shadow: j = 1 \union 3 <= j <= 6
2434 /// holes/splinters: j = 2
2435 ///
2436 /// darkShadow = false, isResultIntegerExact = nullptr are default values.
2437 // TODO: a slight modification to yield dark shadow version of FM (tightened),
2438 // which can prove the existence of a solution if there is one.
FourierMotzkinEliminate(unsigned pos,bool darkShadow,bool * isResultIntegerExact)2439 void FlatAffineConstraints::FourierMotzkinEliminate(
2440 unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
2441 LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
2442 LLVM_DEBUG(dump());
2443 assert(pos < getNumIds() && "invalid position");
2444 assert(hasConsistentState());
2445
2446 // Check if this identifier can be eliminated through a substitution.
2447 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2448 if (atEq(r, pos) != 0) {
2449 // Use Gaussian elimination here (since we have an equality).
2450 LogicalResult ret = gaussianEliminateId(pos);
2451 (void)ret;
2452 assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
2453 LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
2454 LLVM_DEBUG(dump());
2455 return;
2456 }
2457 }
2458
2459 // A fast linear time tightening.
2460 GCDTightenInequalities();
2461
2462 // Check if the identifier appears at all in any of the inequalities.
2463 unsigned r, e;
2464 for (r = 0, e = getNumInequalities(); r < e; r++) {
2465 if (atIneq(r, pos) != 0)
2466 break;
2467 }
2468 if (r == getNumInequalities()) {
2469 // If it doesn't appear, just remove the column and return.
2470 // TODO: refactor removeColumns to use it from here.
2471 removeId(pos);
2472 LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2473 LLVM_DEBUG(dump());
2474 return;
2475 }
2476
2477 // Positions of constraints that are lower bounds on the variable.
2478 SmallVector<unsigned, 4> lbIndices;
2479 // Positions of constraints that are lower bounds on the variable.
2480 SmallVector<unsigned, 4> ubIndices;
2481 // Positions of constraints that do not involve the variable.
2482 std::vector<unsigned> nbIndices;
2483 nbIndices.reserve(getNumInequalities());
2484
2485 // Gather all lower bounds and upper bounds of the variable. Since the
2486 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
2487 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
2488 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2489 if (atIneq(r, pos) == 0) {
2490 // Id does not appear in bound.
2491 nbIndices.push_back(r);
2492 } else if (atIneq(r, pos) >= 1) {
2493 // Lower bound.
2494 lbIndices.push_back(r);
2495 } else {
2496 // Upper bound.
2497 ubIndices.push_back(r);
2498 }
2499 }
2500
2501 // Set the number of dimensions, symbols in the resulting system.
2502 const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
2503 unsigned newNumDims = dimsSymbols.first;
2504 unsigned newNumSymbols = dimsSymbols.second;
2505
2506 SmallVector<Optional<Value>, 8> newIds;
2507 newIds.reserve(numIds - 1);
2508 newIds.append(ids.begin(), ids.begin() + pos);
2509 newIds.append(ids.begin() + pos + 1, ids.end());
2510
2511 /// Create the new system which has one identifier less.
2512 FlatAffineConstraints newFac(
2513 lbIndices.size() * ubIndices.size() + nbIndices.size(),
2514 getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
2515 /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds);
2516
2517 assert(newFac.getIds().size() == newFac.getNumIds());
2518
2519 // This will be used to check if the elimination was integer exact.
2520 unsigned lcmProducts = 1;
2521
2522 // Let x be the variable we are eliminating.
2523 // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
2524 // that c_l, c_u >= 1) we have:
2525 // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
2526 // We thus generate a constraint:
2527 // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
2528 // Note if c_l = c_u = 1, all integer points captured by the resulting
2529 // constraint correspond to integer points in the original system (i.e., they
2530 // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
2531 // integer exact.
2532 for (auto ubPos : ubIndices) {
2533 for (auto lbPos : lbIndices) {
2534 SmallVector<int64_t, 4> ineq;
2535 ineq.reserve(newFac.getNumCols());
2536 int64_t lbCoeff = atIneq(lbPos, pos);
2537 // Note that in the comments above, ubCoeff is the negation of the
2538 // coefficient in the canonical form as the view taken here is that of the
2539 // term being moved to the other size of '>='.
2540 int64_t ubCoeff = -atIneq(ubPos, pos);
2541 // TODO: refactor this loop to avoid all branches inside.
2542 for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2543 if (l == pos)
2544 continue;
2545 assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
2546 int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
2547 ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
2548 atIneq(lbPos, l) * (lcm / lbCoeff));
2549 lcmProducts *= lcm;
2550 }
2551 if (darkShadow) {
2552 // The dark shadow is a convex subset of the exact integer shadow. If
2553 // there is a point here, it proves the existence of a solution.
2554 ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
2555 }
2556 // TODO: we need to have a way to add inequalities in-place in
2557 // FlatAffineConstraints instead of creating and copying over.
2558 newFac.addInequality(ineq);
2559 }
2560 }
2561
2562 LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
2563 << "\n");
2564 if (lcmProducts == 1 && isResultIntegerExact)
2565 *isResultIntegerExact = true;
2566
2567 // Copy over the constraints not involving this variable.
2568 for (auto nbPos : nbIndices) {
2569 SmallVector<int64_t, 4> ineq;
2570 ineq.reserve(getNumCols() - 1);
2571 for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2572 if (l == pos)
2573 continue;
2574 ineq.push_back(atIneq(nbPos, l));
2575 }
2576 newFac.addInequality(ineq);
2577 }
2578
2579 assert(newFac.getNumConstraints() ==
2580 lbIndices.size() * ubIndices.size() + nbIndices.size());
2581
2582 // Copy over the equalities.
2583 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2584 SmallVector<int64_t, 4> eq;
2585 eq.reserve(newFac.getNumCols());
2586 for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2587 if (l == pos)
2588 continue;
2589 eq.push_back(atEq(r, l));
2590 }
2591 newFac.addEquality(eq);
2592 }
2593
2594 // GCD tightening and normalization allows detection of more trivially
2595 // redundant constraints.
2596 newFac.GCDTightenInequalities();
2597 newFac.normalizeConstraintsByGCD();
2598 newFac.removeTrivialRedundancy();
2599 clearAndCopyFrom(newFac);
2600 LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2601 LLVM_DEBUG(dump());
2602 }
2603
2604 #undef DEBUG_TYPE
2605 #define DEBUG_TYPE "affine-structures"
2606
projectOut(unsigned pos,unsigned num)2607 void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
2608 if (num == 0)
2609 return;
2610
2611 // 'pos' can be at most getNumCols() - 2 if num > 0.
2612 assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
2613 assert(pos + num < getNumCols() && "invalid range");
2614
2615 // Eliminate as many identifiers as possible using Gaussian elimination.
2616 unsigned currentPos = pos;
2617 unsigned numToEliminate = num;
2618 unsigned numGaussianEliminated = 0;
2619
2620 while (currentPos < getNumIds()) {
2621 unsigned curNumEliminated =
2622 gaussianEliminateIds(currentPos, currentPos + numToEliminate);
2623 ++currentPos;
2624 numToEliminate -= curNumEliminated + 1;
2625 numGaussianEliminated += curNumEliminated;
2626 }
2627
2628 // Eliminate the remaining using Fourier-Motzkin.
2629 for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
2630 unsigned numToEliminate = num - numGaussianEliminated - i;
2631 FourierMotzkinEliminate(
2632 getBestIdToEliminate(*this, pos, pos + numToEliminate));
2633 }
2634
2635 // Fast/trivial simplifications.
2636 GCDTightenInequalities();
2637 // Normalize constraints after tightening since the latter impacts this, but
2638 // not the other way round.
2639 normalizeConstraintsByGCD();
2640 }
2641
projectOut(Value id)2642 void FlatAffineConstraints::projectOut(Value id) {
2643 unsigned pos;
2644 bool ret = findId(id, &pos);
2645 assert(ret);
2646 (void)ret;
2647 FourierMotzkinEliminate(pos);
2648 }
2649
clearConstraints()2650 void FlatAffineConstraints::clearConstraints() {
2651 equalities.clear();
2652 inequalities.clear();
2653 }
2654
2655 namespace {
2656
2657 enum BoundCmpResult { Greater, Less, Equal, Unknown };
2658
2659 /// Compares two affine bounds whose coefficients are provided in 'first' and
2660 /// 'second'. The last coefficient is the constant term.
compareBounds(ArrayRef<int64_t> a,ArrayRef<int64_t> b)2661 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
2662 assert(a.size() == b.size());
2663
2664 // For the bounds to be comparable, their corresponding identifier
2665 // coefficients should be equal; the constant terms are then compared to
2666 // determine less/greater/equal.
2667
2668 if (!std::equal(a.begin(), a.end() - 1, b.begin()))
2669 return Unknown;
2670
2671 if (a.back() == b.back())
2672 return Equal;
2673
2674 return a.back() < b.back() ? Less : Greater;
2675 }
2676 } // namespace
2677
2678 // Returns constraints that are common to both A & B.
getCommonConstraints(const FlatAffineConstraints & A,const FlatAffineConstraints & B,FlatAffineConstraints & C)2679 static void getCommonConstraints(const FlatAffineConstraints &A,
2680 const FlatAffineConstraints &B,
2681 FlatAffineConstraints &C) {
2682 C.reset(A.getNumDimIds(), A.getNumSymbolIds(), A.getNumLocalIds());
2683 // A naive O(n^2) check should be enough here given the input sizes.
2684 for (unsigned r = 0, e = A.getNumInequalities(); r < e; ++r) {
2685 for (unsigned s = 0, f = B.getNumInequalities(); s < f; ++s) {
2686 if (A.getInequality(r) == B.getInequality(s)) {
2687 C.addInequality(A.getInequality(r));
2688 break;
2689 }
2690 }
2691 }
2692 for (unsigned r = 0, e = A.getNumEqualities(); r < e; ++r) {
2693 for (unsigned s = 0, f = B.getNumEqualities(); s < f; ++s) {
2694 if (A.getEquality(r) == B.getEquality(s)) {
2695 C.addEquality(A.getEquality(r));
2696 break;
2697 }
2698 }
2699 }
2700 }
2701
// Computes the bounding box of this system and 'otherCst' by finding the min
// of the lower bounds and the max of the upper bounds along each of the
// dimensions. On success, the constraints of this system are replaced by one
// lower and one upper bound per dimension plus the constraints common to both
// systems.
//
// Returns failure if, for some dimension, either system has no constant bound
// on its extent, the floor divisors of the two lower bounds differ, or the
// bounds are incomparable and constant bounds are not available on both
// sides. The constraints of this system are only cleared after all dimensions
// have been processed, i.e., not on any of these failure paths.
// NOTE(review): `this` is handed to mergeAndAlignIds by mutable pointer
// before any failure exit, so the identifiers of this system are presumably
// already changed when failure is returned - confirm.
LogicalResult
FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
  // Both systems must agree on their dimensional identifiers, and neither may
  // have local identifiers.
  assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
  assert(otherCst.getIds()
             .slice(0, getNumDimIds())
             .equals(getIds().slice(0, getNumDimIds())) &&
         "dim values mismatch");
  assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
  assert(getNumLocalIds() == 0 && "local ids not supported yet here");

  // Align `other` to this. A copy of `otherCst` is only made when the
  // identifiers are not already aligned.
  Optional<FlatAffineConstraints> otherCopy;
  if (!areIdsAligned(*this, otherCst)) {
    otherCopy.emplace(FlatAffineConstraints(otherCst));
    mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy.getValue());
  }

  // Refers to the aligned copy if one was made, to `otherCst` otherwise.
  const auto &otherAligned = otherCopy ? *otherCopy : otherCst;

  // Get the constraints common to both systems; these will be added as is to
  // the union.
  FlatAffineConstraints commonCst;
  getCommonConstraints(*this, otherAligned, commonCst);

  // Lower/upper bound inequality rows of the bounding box, one pair pushed
  // per dimension.
  std::vector<SmallVector<int64_t, 8>> boundingLbs;
  std::vector<SmallVector<int64_t, 8>> boundingUbs;
  boundingLbs.reserve(2 * getNumDimIds());
  boundingUbs.reserve(2 * getNumDimIds());

  // To hold lower and upper bounds for each dimension.
  SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
  // To compute min of lower bounds and max of upper bounds for each dimension.
  // These hold one coefficient per symbol followed by the constant term.
  SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
  SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
  // To compute final new lower and upper bounds for the union.
  SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());

  int64_t lbFloorDivisor, otherLbFloorDivisor;
  for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
    auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
    if (!extent.hasValue())
      // TODO: symbolic extents when necessary.
      // TODO: handle union if a dimension is unbounded.
      return failure();

    auto otherExtent = otherAligned.getConstantBoundOnDimSize(
        d, &otherLb, &otherLbFloorDivisor, &otherUb);
    // The two lower bounds can only be combined below if they share the same
    // floor divisor.
    if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
      // TODO: symbolic extents when necessary.
      return failure();

    assert(lbFloorDivisor > 0 && "divisor always expected to be positive");

    auto res = compareBounds(lb, otherLb);
    // Identify min.
    if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
      minLb = lb;
      // Since the divisor is for a floordiv, we need to convert to ceildiv,
      // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
      // div * i >= expr - div + 1.
      minLb.back() -= lbFloorDivisor - 1;
    } else if (res == BoundCmpResult::Greater) {
      minLb = otherLb;
      minLb.back() -= otherLbFloorDivisor - 1;
    } else {
      // Incomparable lower bounds - fall back to the constant lower bounds of
      // both systems, if available, and use the smaller one.
      auto constLb = getConstantLowerBound(d);
      auto constOtherLb = otherAligned.getConstantLowerBound(d);
      if (!constLb.hasValue() || !constOtherLb.hasValue())
        return failure();
      std::fill(minLb.begin(), minLb.end(), 0);
      minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
    }

    // Do the same for ub's but max of upper bounds. Identify max.
    auto uRes = compareBounds(ub, otherUb);
    if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
      maxUb = ub;
    } else if (uRes == BoundCmpResult::Less) {
      maxUb = otherUb;
    } else {
      // Incomparable upper bounds - fall back to the constant upper bounds of
      // both systems, if available, and use the larger one.
      auto constUb = getConstantUpperBound(d);
      auto constOtherUb = otherAligned.getConstantUpperBound(d);
      if (!constUb.hasValue() || !constOtherUb.hasValue())
        return failure();
      std::fill(maxUb.begin(), maxUb.end(), 0);
      maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
    }

    // newLb/newUb are reused across iterations; start from all-zero rows.
    std::fill(newLb.begin(), newLb.end(), 0);
    std::fill(newUb.begin(), newUb.end(), 0);

    // The divisor for lb, ub, otherLb, otherUb at this point is
    // lbFloorDivisor, and so it's the divisor for newLb and newUb as well.
    newLb[d] = lbFloorDivisor;
    newUb[d] = -lbFloorDivisor;
    // Copy over the symbolic part + constant term. The lower bound entries are
    // negated so that the row reads lbFloorDivisor * d - minLb >= 0; the upper
    // bound row reads -lbFloorDivisor * d + maxUb >= 0.
    std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
    std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
                   newLb.begin() + getNumDimIds(), std::negate<int64_t>());
    std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());

    boundingLbs.push_back(newLb);
    boundingUbs.push_back(newUb);
  }

  // Clear all constraints and add the lower/upper bounds for the bounding box.
  clearConstraints();
  for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
    addInequality(boundingLbs[d]);
    addInequality(boundingUbs[d]);
  }

  // Add the constraints that were common to both systems.
  append(commonCst);
  removeTrivialRedundancy();

  // TODO: copy over pure symbolic constraints from this and 'other' over to the
  // union (since the above are just the union along dimensions); we shouldn't
  // be discarding any other constraints on the symbols.

  return success();
}
2828
2829 /// Compute an explicit representation for local vars. For all systems coming
2830 /// from MLIR integer sets, maps, or expressions where local vars were
2831 /// introduced to model floordivs and mods, this always succeeds.
computeLocalVars(const FlatAffineConstraints & cst,SmallVectorImpl<AffineExpr> & memo,MLIRContext * context)2832 static LogicalResult computeLocalVars(const FlatAffineConstraints &cst,
2833 SmallVectorImpl<AffineExpr> &memo,
2834 MLIRContext *context) {
2835 unsigned numDims = cst.getNumDimIds();
2836 unsigned numSyms = cst.getNumSymbolIds();
2837
2838 // Initialize dimensional and symbolic identifiers.
2839 for (unsigned i = 0; i < numDims; i++)
2840 memo[i] = getAffineDimExpr(i, context);
2841 for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
2842 memo[i] = getAffineSymbolExpr(i - numDims, context);
2843
2844 bool changed;
2845 do {
2846 // Each time `changed` is true at the end of this iteration, one or more
2847 // local vars would have been detected as floordivs and set in memo; so the
2848 // number of null entries in memo[...] strictly reduces; so this converges.
2849 changed = false;
2850 for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i)
2851 if (!memo[numDims + numSyms + i] &&
2852 detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
2853 changed = true;
2854 } while (changed);
2855
2856 ArrayRef<AffineExpr> localExprs =
2857 ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds());
2858 return success(
2859 llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
2860 }
2861
getIneqAsAffineValueMap(unsigned pos,unsigned ineqPos,AffineValueMap & vmap,MLIRContext * context) const2862 void FlatAffineConstraints::getIneqAsAffineValueMap(
2863 unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
2864 MLIRContext *context) const {
2865 unsigned numDims = getNumDimIds();
2866 unsigned numSyms = getNumSymbolIds();
2867
2868 assert(pos < numDims && "invalid position");
2869 assert(ineqPos < getNumInequalities() && "invalid inequality position");
2870
2871 // Get expressions for local vars.
2872 SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
2873 if (failed(computeLocalVars(*this, memo, context)))
2874 assert(false &&
2875 "one or more local exprs do not have an explicit representation");
2876 auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
2877
2878 // Compute the AffineExpr lower/upper bound for this inequality.
2879 ArrayRef<int64_t> inequality = getInequality(ineqPos);
2880 SmallVector<int64_t, 8> bound;
2881 bound.reserve(getNumCols() - 1);
2882 // Everything other than the coefficient at `pos`.
2883 bound.append(inequality.begin(), inequality.begin() + pos);
2884 bound.append(inequality.begin() + pos + 1, inequality.end());
2885
2886 if (inequality[pos] > 0)
2887 // Lower bound.
2888 std::transform(bound.begin(), bound.end(), bound.begin(),
2889 std::negate<int64_t>());
2890 else
2891 // Upper bound (which is exclusive).
2892 bound.back() += 1;
2893
2894 // Convert to AffineExpr (tree) form.
2895 auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
2896 localExprs, context);
2897
2898 // Get the values to bind to this affine expr (all dims and symbols).
2899 SmallVector<Value, 4> operands;
2900 getIdValues(0, pos, &operands);
2901 SmallVector<Value, 4> trailingOperands;
2902 getIdValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
2903 operands.append(trailingOperands.begin(), trailingOperands.end());
2904 vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
2905 }
2906
2907 /// Returns true if the pos^th column is all zero for both inequalities and
2908 /// equalities..
isColZero(const FlatAffineConstraints & cst,unsigned pos)2909 static bool isColZero(const FlatAffineConstraints &cst, unsigned pos) {
2910 unsigned rowPos;
2911 return !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/false, &rowPos) &&
2912 !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/true, &rowPos);
2913 }
2914
getAsIntegerSet(MLIRContext * context) const2915 IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
2916 if (getNumConstraints() == 0)
2917 // Return universal set (always true): 0 == 0.
2918 return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
2919 getAffineConstantExpr(/*constant=*/0, context),
2920 /*eqFlags=*/true);
2921
2922 // Construct local references.
2923 SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
2924
2925 if (failed(computeLocalVars(*this, memo, context))) {
2926 // Check if the local variables without an explicit representation have
2927 // zero coefficients everywhere.
2928 for (unsigned i = getNumDimAndSymbolIds(), e = getNumIds(); i < e; ++i) {
2929 if (!memo[i] && !isColZero(*this, /*pos=*/i)) {
2930 LLVM_DEBUG(llvm::dbgs() << "one or more local exprs do not have an "
2931 "explicit representation");
2932 return IntegerSet();
2933 }
2934 }
2935 }
2936
2937 ArrayRef<AffineExpr> localExprs =
2938 ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
2939
2940 // Construct the IntegerSet from the equalities/inequalities.
2941 unsigned numDims = getNumDimIds();
2942 unsigned numSyms = getNumSymbolIds();
2943
2944 SmallVector<bool, 16> eqFlags(getNumConstraints());
2945 std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
2946 std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
2947
2948 SmallVector<AffineExpr, 8> exprs;
2949 exprs.reserve(getNumConstraints());
2950
2951 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
2952 exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
2953 localExprs, context));
2954 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
2955 exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
2956 numSyms, localExprs, context));
2957 return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
2958 }
2959
2960 /// Find positions of inequalities and equalities that do not have a coefficient
2961 /// for [pos, pos + num) identifiers.
getIndependentConstraints(const FlatAffineConstraints & cst,unsigned pos,unsigned num,SmallVectorImpl<unsigned> & nbIneqIndices,SmallVectorImpl<unsigned> & nbEqIndices)2962 static void getIndependentConstraints(const FlatAffineConstraints &cst,
2963 unsigned pos, unsigned num,
2964 SmallVectorImpl<unsigned> &nbIneqIndices,
2965 SmallVectorImpl<unsigned> &nbEqIndices) {
2966 assert(pos < cst.getNumIds() && "invalid start position");
2967 assert(pos + num <= cst.getNumIds() && "invalid limit");
2968
2969 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
2970 // The bounds are to be independent of [offset, offset + num) columns.
2971 unsigned c;
2972 for (c = pos; c < pos + num; ++c) {
2973 if (cst.atIneq(r, c) != 0)
2974 break;
2975 }
2976 if (c == pos + num)
2977 nbIneqIndices.push_back(r);
2978 }
2979
2980 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
2981 // The bounds are to be independent of [offset, offset + num) columns.
2982 unsigned c;
2983 for (c = pos; c < pos + num; ++c) {
2984 if (cst.atEq(r, c) != 0)
2985 break;
2986 }
2987 if (c == pos + num)
2988 nbEqIndices.push_back(r);
2989 }
2990 }
2991
removeIndependentConstraints(unsigned pos,unsigned num)2992 void FlatAffineConstraints::removeIndependentConstraints(unsigned pos,
2993 unsigned num) {
2994 assert(pos + num <= getNumIds() && "invalid range");
2995
2996 // Remove constraints that are independent of these identifiers.
2997 SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
2998 getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
2999
3000 // Iterate in reverse so that indices don't have to be updated.
3001 // TODO: This method can be made more efficient (because removal of each
3002 // inequality leads to much shifting/copying in the underlying buffer).
3003 for (auto nbIndex : llvm::reverse(nbIneqIndices))
3004 removeInequality(nbIndex);
3005 for (auto nbIndex : llvm::reverse(nbEqIndices))
3006 removeEquality(nbIndex);
3007 }
3008