1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Structures for affine/polyhedral analysis of affine dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Analysis/AffineStructures.h"
14 #include "mlir/Analysis/LinearTransform.h"
15 #include "mlir/Analysis/Presburger/Simplex.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/AffineExprVisitor.h"
20 #include "mlir/IR/IntegerSet.h"
21 #include "mlir/Support/LLVM.h"
22 #include "mlir/Support/MathExtras.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallPtrSet.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28
29 #define DEBUG_TYPE "affine-structures"
30
31 using namespace mlir;
32 using llvm::SmallDenseMap;
33 using llvm::SmallDenseSet;
34
35 namespace {
36
37 // See comments for SimpleAffineExprFlattener.
38 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
39 // constraint information associated with mod's, floordiv's, and ceildiv's
40 // in FlatAffineConstraints 'localVarCst'.
41 struct AffineExprFlattener : public SimpleAffineExprFlattener {
42 public:
43 // Constraints connecting newly introduced local variables (for mod's and
44 // div's) to existing (dimensional and symbolic) ones. These are always
45 // inequalities.
46 FlatAffineConstraints localVarCst;
47
AffineExprFlattener__anon7a31c6020111::AffineExprFlattener48 AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx)
49 : SimpleAffineExprFlattener(nDims, nSymbols) {
50 localVarCst.reset(nDims, nSymbols, /*numLocals=*/0);
51 }
52
53 private:
54 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
55 // The local identifier added is always a floordiv of a pure add/mul affine
56 // function of other identifiers, coefficients of which are specified in
57 // `dividend' and with respect to the positive constant `divisor'. localExpr
58 // is the simplified tree expression (AffineExpr) corresponding to the
59 // quantifier.
addLocalFloorDivId__anon7a31c6020111::AffineExprFlattener60 void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
61 AffineExpr localExpr) override {
62 SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
63 // Update localVarCst.
64 localVarCst.addLocalFloorDiv(dividend, divisor);
65 }
66 };
67
68 } // end anonymous namespace
69
70 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
71 // flattened (i.e., semi-affine expressions not handled yet).
72 static LogicalResult
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs,unsigned numDims,unsigned numSymbols,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)73 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
74 unsigned numSymbols,
75 std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
76 FlatAffineConstraints *localVarCst) {
77 if (exprs.empty()) {
78 localVarCst->reset(numDims, numSymbols);
79 return success();
80 }
81
82 AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
83 // Use the same flattener to simplify each expression successively. This way
84 // local identifiers / expressions are shared.
85 for (auto expr : exprs) {
86 if (!expr.isPureAffine())
87 return failure();
88
89 flattener.walkPostOrder(expr);
90 }
91
92 assert(flattener.operandExprStack.size() == exprs.size());
93 flattenedExprs->clear();
94 flattenedExprs->assign(flattener.operandExprStack.begin(),
95 flattener.operandExprStack.end());
96
97 if (localVarCst)
98 localVarCst->clearAndCopyFrom(flattener.localVarCst);
99
100 return success();
101 }
102
103 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
104 // be flattened (semi-affine expressions not handled yet).
105 LogicalResult
getFlattenedAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols,SmallVectorImpl<int64_t> * flattenedExpr,FlatAffineConstraints * localVarCst)106 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
107 unsigned numSymbols,
108 SmallVectorImpl<int64_t> *flattenedExpr,
109 FlatAffineConstraints *localVarCst) {
110 std::vector<SmallVector<int64_t, 8>> flattenedExprs;
111 LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
112 &flattenedExprs, localVarCst);
113 *flattenedExpr = flattenedExprs[0];
114 return ret;
115 }
116
117 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
118 /// flattened (i.e., semi-affine expressions not handled yet).
getFlattenedAffineExprs(AffineMap map,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)119 LogicalResult mlir::getFlattenedAffineExprs(
120 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
121 FlatAffineConstraints *localVarCst) {
122 if (map.getNumResults() == 0) {
123 localVarCst->reset(map.getNumDims(), map.getNumSymbols());
124 return success();
125 }
126 return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
127 map.getNumSymbols(), flattenedExprs,
128 localVarCst);
129 }
130
getFlattenedAffineExprs(IntegerSet set,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)131 LogicalResult mlir::getFlattenedAffineExprs(
132 IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
133 FlatAffineConstraints *localVarCst) {
134 if (set.getNumConstraints() == 0) {
135 localVarCst->reset(set.getNumDims(), set.getNumSymbols());
136 return success();
137 }
138 return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
139 set.getNumSymbols(), flattenedExprs,
140 localVarCst);
141 }
142
143 //===----------------------------------------------------------------------===//
144 // FlatAffineConstraints / FlatAffineValueConstraints.
145 //===----------------------------------------------------------------------===//
146
147 // Clones this object.
clone() const148 std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
149 return std::make_unique<FlatAffineConstraints>(*this);
150 }
151
152 std::unique_ptr<FlatAffineValueConstraints>
clone() const153 FlatAffineValueConstraints::clone() const {
154 return std::make_unique<FlatAffineValueConstraints>(*this);
155 }
156
157 // Construct from an IntegerSet.
FlatAffineConstraints(IntegerSet set)158 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
159 : numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
160 numSymbols(set.getNumSymbols()),
161 equalities(0, numIds + 1, set.getNumEqualities(), numIds + 1),
162 inequalities(0, numIds + 1, set.getNumInequalities(), numIds + 1) {
163 // Flatten expressions and add them to the constraint system.
164 std::vector<SmallVector<int64_t, 8>> flatExprs;
165 FlatAffineConstraints localVarCst;
166 if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
167 assert(false && "flattening unimplemented for semi-affine integer sets");
168 return;
169 }
170 assert(flatExprs.size() == set.getNumConstraints());
171 appendLocalId(/*num=*/localVarCst.getNumLocalIds());
172
173 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
174 const auto &flatExpr = flatExprs[i];
175 assert(flatExpr.size() == getNumCols());
176 if (set.getEqFlags()[i]) {
177 addEquality(flatExpr);
178 } else {
179 addInequality(flatExpr);
180 }
181 }
182 // Add the other constraints involving local id's from flattening.
183 append(localVarCst);
184 }
185
186 // Construct from an IntegerSet.
FlatAffineValueConstraints(IntegerSet set)187 FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set)
188 : FlatAffineConstraints(set) {
189 values.resize(numIds, None);
190 }
191
192 // Construct a hyperrectangular constraint set from ValueRanges that represent
193 // induction variables, lower and upper bounds. `ivs`, `lbs` and `ubs` are
194 // expected to match one to one. The order of variables and constraints is:
195 //
196 // ivs | lbs | ubs | eq/ineq
197 // ----+-----+-----+---------
198 // 1 -1 0 >= 0
199 // ----+-----+-----+---------
200 // -1 0 1 >= 0
201 //
202 // All dimensions as set as DimId.
203 FlatAffineValueConstraints
getHyperrectangular(ValueRange ivs,ValueRange lbs,ValueRange ubs)204 FlatAffineValueConstraints::getHyperrectangular(ValueRange ivs, ValueRange lbs,
205 ValueRange ubs) {
206 FlatAffineValueConstraints res;
207 unsigned nIvs = ivs.size();
208 assert(nIvs == lbs.size() && "expected as many lower bounds as ivs");
209 assert(nIvs == ubs.size() && "expected as many upper bounds as ivs");
210
211 if (nIvs == 0)
212 return res;
213
214 res.appendDimId(ivs);
215 unsigned lbsStart = res.appendDimId(lbs);
216 unsigned ubsStart = res.appendDimId(ubs);
217
218 MLIRContext *ctx = ivs.front().getContext();
219 for (int ivIdx = 0, e = nIvs; ivIdx < e; ++ivIdx) {
220 // iv - lb >= 0
221 AffineMap lb = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
222 getAffineDimExpr(lbsStart + ivIdx, ctx));
223 if (failed(res.addBound(BoundType::LB, ivIdx, lb)))
224 llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
225 // -iv + ub >= 0
226 AffineMap ub = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
227 getAffineDimExpr(ubsStart + ivIdx, ctx));
228 if (failed(res.addBound(BoundType::UB, ivIdx, ub)))
229 llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
230 }
231 return res;
232 }
233
reset(unsigned numReservedInequalities,unsigned numReservedEqualities,unsigned newNumReservedCols,unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals)234 void FlatAffineConstraints::reset(unsigned numReservedInequalities,
235 unsigned numReservedEqualities,
236 unsigned newNumReservedCols,
237 unsigned newNumDims, unsigned newNumSymbols,
238 unsigned newNumLocals) {
239 assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
240 "minimum 1 column");
241 *this = FlatAffineConstraints(numReservedInequalities, numReservedEqualities,
242 newNumReservedCols, newNumDims, newNumSymbols,
243 newNumLocals);
244 }
245
reset(unsigned numReservedInequalities,unsigned numReservedEqualities,unsigned newNumReservedCols,unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals)246 void FlatAffineValueConstraints::reset(unsigned numReservedInequalities,
247 unsigned numReservedEqualities,
248 unsigned newNumReservedCols,
249 unsigned newNumDims,
250 unsigned newNumSymbols,
251 unsigned newNumLocals) {
252 reset(numReservedInequalities, numReservedEqualities, newNumReservedCols,
253 newNumDims, newNumSymbols, newNumLocals, /*valArgs=*/{});
254 }
255
reset(unsigned numReservedInequalities,unsigned numReservedEqualities,unsigned newNumReservedCols,unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> valArgs)256 void FlatAffineValueConstraints::reset(
257 unsigned numReservedInequalities, unsigned numReservedEqualities,
258 unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols,
259 unsigned newNumLocals, ArrayRef<Value> valArgs) {
260 assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
261 "minimum 1 column");
262 SmallVector<Optional<Value>, 8> newVals;
263 if (!valArgs.empty())
264 newVals.assign(valArgs.begin(), valArgs.end());
265
266 *this = FlatAffineValueConstraints(
267 numReservedInequalities, numReservedEqualities, newNumReservedCols,
268 newNumDims, newNumSymbols, newNumLocals, newVals);
269 }
270
reset(unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals)271 void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
272 unsigned newNumLocals) {
273 reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
274 newNumSymbols, newNumLocals);
275 }
276
reset(unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> valArgs)277 void FlatAffineValueConstraints::reset(unsigned newNumDims,
278 unsigned newNumSymbols,
279 unsigned newNumLocals,
280 ArrayRef<Value> valArgs) {
281 reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
282 newNumSymbols, newNumLocals, valArgs);
283 }
284
append(const FlatAffineConstraints & other)285 void FlatAffineConstraints::append(const FlatAffineConstraints &other) {
286 assert(other.getNumCols() == getNumCols());
287 assert(other.getNumDimIds() == getNumDimIds());
288 assert(other.getNumSymbolIds() == getNumSymbolIds());
289
290 inequalities.reserveRows(inequalities.getNumRows() +
291 other.getNumInequalities());
292 equalities.reserveRows(equalities.getNumRows() + other.getNumEqualities());
293
294 for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
295 addInequality(other.getInequality(r));
296 }
297 for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
298 addEquality(other.getEquality(r));
299 }
300 }
301
appendDimId(unsigned num)302 unsigned FlatAffineConstraints::appendDimId(unsigned num) {
303 unsigned pos = getNumDimIds();
304 insertId(IdKind::Dimension, pos, num);
305 return pos;
306 }
307
appendDimId(ValueRange vals)308 unsigned FlatAffineValueConstraints::appendDimId(ValueRange vals) {
309 unsigned pos = getNumDimIds();
310 insertId(IdKind::Dimension, pos, vals);
311 return pos;
312 }
313
appendSymbolId(unsigned num)314 unsigned FlatAffineConstraints::appendSymbolId(unsigned num) {
315 unsigned pos = getNumSymbolIds();
316 insertId(IdKind::Symbol, pos, num);
317 return pos;
318 }
319
appendSymbolId(ValueRange vals)320 unsigned FlatAffineValueConstraints::appendSymbolId(ValueRange vals) {
321 unsigned pos = getNumSymbolIds();
322 insertId(IdKind::Symbol, pos, vals);
323 return pos;
324 }
325
appendLocalId(unsigned num)326 unsigned FlatAffineConstraints::appendLocalId(unsigned num) {
327 unsigned pos = getNumLocalIds();
328 insertId(IdKind::Local, pos, num);
329 return pos;
330 }
331
insertDimId(unsigned pos,unsigned num)332 unsigned FlatAffineConstraints::insertDimId(unsigned pos, unsigned num) {
333 return insertId(IdKind::Dimension, pos, num);
334 }
335
insertDimId(unsigned pos,ValueRange vals)336 unsigned FlatAffineValueConstraints::insertDimId(unsigned pos,
337 ValueRange vals) {
338 return insertId(IdKind::Dimension, pos, vals);
339 }
340
insertSymbolId(unsigned pos,unsigned num)341 unsigned FlatAffineConstraints::insertSymbolId(unsigned pos, unsigned num) {
342 return insertId(IdKind::Symbol, pos, num);
343 }
344
insertSymbolId(unsigned pos,ValueRange vals)345 unsigned FlatAffineValueConstraints::insertSymbolId(unsigned pos,
346 ValueRange vals) {
347 return insertId(IdKind::Symbol, pos, vals);
348 }
349
insertLocalId(unsigned pos,unsigned num)350 unsigned FlatAffineConstraints::insertLocalId(unsigned pos, unsigned num) {
351 return insertId(IdKind::Local, pos, num);
352 }
353
insertId(IdKind kind,unsigned pos,unsigned num)354 unsigned FlatAffineConstraints::insertId(IdKind kind, unsigned pos,
355 unsigned num) {
356 assertAtMostNumIdKind(pos, kind);
357
358 unsigned absolutePos = getIdKindOffset(kind) + pos;
359 if (kind == IdKind::Dimension)
360 numDims += num;
361 else if (kind == IdKind::Symbol)
362 numSymbols += num;
363 numIds += num;
364
365 inequalities.insertColumns(absolutePos, num);
366 equalities.insertColumns(absolutePos, num);
367
368 return absolutePos;
369 }
370
assertAtMostNumIdKind(unsigned val,IdKind kind) const371 void FlatAffineConstraints::assertAtMostNumIdKind(unsigned val,
372 IdKind kind) const {
373 if (kind == IdKind::Dimension)
374 assert(val <= getNumDimIds());
375 else if (kind == IdKind::Symbol)
376 assert(val <= getNumSymbolIds());
377 else if (kind == IdKind::Local)
378 assert(val <= getNumLocalIds());
379 else
380 llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
381 }
382
getIdKindOffset(IdKind kind) const383 unsigned FlatAffineConstraints::getIdKindOffset(IdKind kind) const {
384 if (kind == IdKind::Dimension)
385 return 0;
386 if (kind == IdKind::Symbol)
387 return getNumDimIds();
388 if (kind == IdKind::Local)
389 return getNumDimAndSymbolIds();
390 llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
391 }
392
insertId(IdKind kind,unsigned pos,unsigned num)393 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
394 unsigned num) {
395 unsigned absolutePos = FlatAffineConstraints::insertId(kind, pos, num);
396 values.insert(values.begin() + absolutePos, num, None);
397 assert(values.size() == getNumIds());
398 return absolutePos;
399 }
400
insertId(IdKind kind,unsigned pos,ValueRange vals)401 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
402 ValueRange vals) {
403 assert(!vals.empty() && "expected ValueRange with Values");
404 unsigned num = vals.size();
405 unsigned absolutePos = FlatAffineConstraints::insertId(kind, pos, num);
406
407 // If a Value is provided, insert it; otherwise use None.
408 for (unsigned i = 0; i < num; ++i)
409 values.insert(values.begin() + absolutePos + i,
410 vals[i] ? Optional<Value>(vals[i]) : None);
411
412 assert(values.size() == getNumIds());
413 return absolutePos;
414 }
415
hasValues() const416 bool FlatAffineValueConstraints::hasValues() const {
417 return llvm::find_if(values, [](Optional<Value> id) {
418 return id.hasValue();
419 }) != values.end();
420 }
421
removeId(IdKind kind,unsigned pos)422 void FlatAffineConstraints::removeId(IdKind kind, unsigned pos) {
423 removeIdRange(kind, pos, pos + 1);
424 }
425
removeIdRange(IdKind kind,unsigned idStart,unsigned idLimit)426 void FlatAffineConstraints::removeIdRange(IdKind kind, unsigned idStart,
427 unsigned idLimit) {
428 assertAtMostNumIdKind(idLimit, kind);
429 removeIdRange(getIdKindOffset(kind) + idStart,
430 getIdKindOffset(kind) + idLimit);
431 }
432
433 /// Checks if two constraint systems are in the same space, i.e., if they are
434 /// associated with the same set of identifiers, appearing in the same order.
areIdsAligned(const FlatAffineValueConstraints & a,const FlatAffineValueConstraints & b)435 static bool areIdsAligned(const FlatAffineValueConstraints &a,
436 const FlatAffineValueConstraints &b) {
437 return a.getNumDimIds() == b.getNumDimIds() &&
438 a.getNumSymbolIds() == b.getNumSymbolIds() &&
439 a.getNumIds() == b.getNumIds() &&
440 a.getMaybeValues().equals(b.getMaybeValues());
441 }
442
443 /// Calls areIdsAligned to check if two constraint systems have the same set
444 /// of identifiers in the same order.
areIdsAlignedWithOther(const FlatAffineValueConstraints & other)445 bool FlatAffineValueConstraints::areIdsAlignedWithOther(
446 const FlatAffineValueConstraints &other) {
447 return areIdsAligned(*this, other);
448 }
449
450 /// Checks if the SSA values associated with `cst`'s identifiers are unique.
451 static bool LLVM_ATTRIBUTE_UNUSED
areIdsUnique(const FlatAffineValueConstraints & cst)452 areIdsUnique(const FlatAffineValueConstraints &cst) {
453 SmallPtrSet<Value, 8> uniqueIds;
454 for (auto val : cst.getMaybeValues()) {
455 if (val.hasValue() && !uniqueIds.insert(val.getValue()).second)
456 return false;
457 }
458 return true;
459 }
460
461 /// Merge and align the identifiers of A and B starting at 'offset', so that
462 /// both constraint systems get the union of the contained identifiers that is
463 /// dimension-wise and symbol-wise unique; both constraint systems are updated
464 /// so that they have the union of all identifiers, with A's original
465 /// identifiers appearing first followed by any of B's identifiers that didn't
466 /// appear in A. Local identifiers of each system are by design separate/local
467 /// and are placed one after other (A's followed by B's).
468 // E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M])
469 // Output: both A, B have (%i, %j, %k) [%M, %N, %P]
mergeAndAlignIds(unsigned offset,FlatAffineValueConstraints * a,FlatAffineValueConstraints * b)470 static void mergeAndAlignIds(unsigned offset, FlatAffineValueConstraints *a,
471 FlatAffineValueConstraints *b) {
472 assert(offset <= a->getNumDimIds() && offset <= b->getNumDimIds());
473 // A merge/align isn't meaningful if a cst's ids aren't distinct.
474 assert(areIdsUnique(*a) && "A's values aren't unique");
475 assert(areIdsUnique(*b) && "B's values aren't unique");
476
477 assert(std::all_of(a->getMaybeValues().begin() + offset,
478 a->getMaybeValues().begin() + a->getNumDimAndSymbolIds(),
479 [](Optional<Value> id) { return id.hasValue(); }));
480
481 assert(std::all_of(b->getMaybeValues().begin() + offset,
482 b->getMaybeValues().begin() + b->getNumDimAndSymbolIds(),
483 [](Optional<Value> id) { return id.hasValue(); }));
484
485 // Bring A and B to common local space
486 a->mergeLocalIds(*b);
487
488 SmallVector<Value, 4> aDimValues;
489 a->getValues(offset, a->getNumDimIds(), &aDimValues);
490
491 {
492 // Merge dims from A into B.
493 unsigned d = offset;
494 for (auto aDimValue : aDimValues) {
495 unsigned loc;
496 if (b->findId(aDimValue, &loc)) {
497 assert(loc >= offset && "A's dim appears in B's aligned range");
498 assert(loc < b->getNumDimIds() &&
499 "A's dim appears in B's non-dim position");
500 b->swapId(d, loc);
501 } else {
502 b->insertDimId(d, aDimValue);
503 }
504 d++;
505 }
506 // Dimensions that are in B, but not in A, are added at the end.
507 for (unsigned t = a->getNumDimIds(), e = b->getNumDimIds(); t < e; t++) {
508 a->appendDimId(b->getValue(t));
509 }
510 assert(a->getNumDimIds() == b->getNumDimIds() &&
511 "expected same number of dims");
512 }
513
514 // Merge and align symbols of A and B
515 a->mergeSymbolIds(*b);
516
517 assert(areIdsAligned(*a, *b) && "IDs expected to be aligned");
518 }
519
520 // Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
mergeAndAlignIdsWithOther(unsigned offset,FlatAffineValueConstraints * other)521 void FlatAffineValueConstraints::mergeAndAlignIdsWithOther(
522 unsigned offset, FlatAffineValueConstraints *other) {
523 mergeAndAlignIds(offset, this, other);
524 }
525
526 LogicalResult
composeMap(const AffineValueMap * vMap)527 FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) {
528 return composeMatchingMap(
529 computeAlignedMap(vMap->getAffineMap(), vMap->getOperands()));
530 }
531
532 // Similar to `composeMap` except that no Values need be associated with the
533 // constraint system nor are they looked at -- the dimensions and symbols of
534 // `other` are expected to correspond 1:1 to `this` system.
composeMatchingMap(AffineMap other)535 LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
536 assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
537 assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
538
539 std::vector<SmallVector<int64_t, 8>> flatExprs;
540 if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
541 return failure();
542 assert(flatExprs.size() == other.getNumResults());
543
544 // Add dimensions corresponding to the map's results.
545 insertDimId(/*pos=*/0, /*num=*/other.getNumResults());
546
547 // We add one equality for each result connecting the result dim of the map to
548 // the other identifiers.
549 // E.g.: if the expression is 16*i0 + i1, and this is the r^th
550 // iteration/result of the value map, we are adding the equality:
551 // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
552 // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
553 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
554 const auto &flatExpr = flatExprs[r];
555 assert(flatExpr.size() >= other.getNumInputs() + 1);
556
557 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
558 // Set the coefficient for this result to one.
559 eqToAdd[r] = 1;
560
561 // Dims and symbols.
562 for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
563 // Negate `eq[r]` since the newly added dimension will be set to this one.
564 eqToAdd[e + i] = -flatExpr[i];
565 }
566 // Local columns of `eq` are at the beginning.
567 unsigned j = getNumDimIds() + getNumSymbolIds();
568 unsigned end = flatExpr.size() - 1;
569 for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
570 eqToAdd[j] = -flatExpr[i];
571 }
572
573 // Constant term.
574 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
575
576 // Add the equality connecting the result of the map to this constraint set.
577 addEquality(eqToAdd);
578 }
579
580 return success();
581 }
582
583 // Turn a symbol into a dimension.
turnSymbolIntoDim(FlatAffineValueConstraints * cst,Value id)584 static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value id) {
585 unsigned pos;
586 if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
587 pos < cst->getNumDimAndSymbolIds()) {
588 cst->swapId(pos, cst->getNumDimIds());
589 cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
590 }
591 }
592
593 /// Merge and align symbols of `this` and `other` such that both get union of
594 /// of symbols that are unique. Symbols with Value as `None` are considered
595 /// to be inequal to all other symbols.
mergeSymbolIds(FlatAffineValueConstraints & other)596 void FlatAffineValueConstraints::mergeSymbolIds(
597 FlatAffineValueConstraints &other) {
598 SmallVector<Value, 4> aSymValues;
599 getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues);
600
601 // Merge symbols: merge symbols into `other` first from `this`.
602 unsigned s = other.getNumDimIds();
603 for (Value aSymValue : aSymValues) {
604 unsigned loc;
605 // If the id is a symbol in `other`, then align it, otherwise assume that
606 // it is a new symbol
607 if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() &&
608 loc < getNumDimAndSymbolIds())
609 other.swapId(s, loc);
610 else
611 other.insertSymbolId(s - other.getNumDimIds(), aSymValue);
612 s++;
613 }
614
615 // Symbols that are in other, but not in this, are added at the end.
616 for (unsigned t = other.getNumDimIds() + getNumSymbolIds(),
617 e = other.getNumDimAndSymbolIds();
618 t < e; t++)
619 insertSymbolId(getNumSymbolIds(), other.getValue(t));
620
621 assert(getNumSymbolIds() == other.getNumSymbolIds() &&
622 "expected same number of symbols");
623 }
624
625 // Changes all symbol identifiers which are loop IVs to dim identifiers.
convertLoopIVSymbolsToDims()626 void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() {
627 // Gather all symbols which are loop IVs.
628 SmallVector<Value, 4> loopIVs;
629 for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
630 if (hasValue(i) && getForInductionVarOwner(getValue(i)))
631 loopIVs.push_back(getValue(i));
632 }
633 // Turn each symbol in 'loopIVs' into a dim identifier.
634 for (auto iv : loopIVs) {
635 turnSymbolIntoDim(this, iv);
636 }
637 }
638
addInductionVarOrTerminalSymbol(Value val)639 void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) {
640 if (containsId(val))
641 return;
642
643 // Caller is expected to fully compose map/operands if necessary.
644 assert((isTopLevelValue(val) || isForInductionVar(val)) &&
645 "non-terminal symbol / loop IV expected");
646 // Outer loop IVs could be used in forOp's bounds.
647 if (auto loop = getForInductionVarOwner(val)) {
648 appendDimId(val);
649 if (failed(this->addAffineForOpDomain(loop)))
650 LLVM_DEBUG(
651 loop.emitWarning("failed to add domain info to constraint system"));
652 return;
653 }
654 // Add top level symbol.
655 appendSymbolId(val);
656 // Check if the symbol is a constant.
657 if (auto constOp = val.getDefiningOp<ConstantIndexOp>())
658 addBound(BoundType::EQ, val, constOp.getValue());
659 }
660
661 LogicalResult
addAffineForOpDomain(AffineForOp forOp)662 FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
663 unsigned pos;
664 // Pre-condition for this method.
665 if (!findId(forOp.getInductionVar(), &pos)) {
666 assert(false && "Value not found");
667 return failure();
668 }
669
670 int64_t step = forOp.getStep();
671 if (step != 1) {
672 if (!forOp.hasConstantLowerBound())
673 LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated"));
674 else {
675 // Add constraints for the stride.
676 // (iv - lb) % step = 0 can be written as:
677 // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
678 // Add local variable 'q' and add the above equality.
679 // The first constraint is q = (iv - lb) floordiv step
680 SmallVector<int64_t, 8> dividend(getNumCols(), 0);
681 int64_t lb = forOp.getConstantLowerBound();
682 dividend[pos] = 1;
683 dividend.back() -= lb;
684 addLocalFloorDiv(dividend, step);
685 // Second constraint: (iv - lb) - step * q = 0.
686 SmallVector<int64_t, 8> eq(getNumCols(), 0);
687 eq[pos] = 1;
688 eq.back() -= lb;
689 // For the local var just added above.
690 eq[getNumCols() - 2] = -step;
691 addEquality(eq);
692 }
693 }
694
695 if (forOp.hasConstantLowerBound()) {
696 addBound(BoundType::LB, pos, forOp.getConstantLowerBound());
697 } else {
698 // Non-constant lower bound case.
699 if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(),
700 forOp.getLowerBoundOperands())))
701 return failure();
702 }
703
704 if (forOp.hasConstantUpperBound()) {
705 addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1);
706 return success();
707 }
708 // Non-constant upper bound case.
709 return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(),
710 forOp.getUpperBoundOperands());
711 }
712
713 LogicalResult
addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,ArrayRef<AffineMap> ubMaps,ArrayRef<Value> operands)714 FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
715 ArrayRef<AffineMap> ubMaps,
716 ArrayRef<Value> operands) {
717 assert(lbMaps.size() == ubMaps.size());
718 assert(lbMaps.size() <= getNumDimIds());
719
720 for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
721 AffineMap lbMap = lbMaps[i];
722 AffineMap ubMap = ubMaps[i];
723 assert(!lbMap || lbMap.getNumInputs() == operands.size());
724 assert(!ubMap || ubMap.getNumInputs() == operands.size());
725
726 // Check if this slice is just an equality along this dimension. If so,
727 // retrieve the existing loop it equates to and add it to the system.
728 if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
729 ubMap.getNumResults() == 1 &&
730 lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
731 // The condition above will be true for maps describing a single
732 // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
733 // Make sure we skip those cases by checking that the lb result is not
734 // just a constant.
735 !lbMap.getResult(0).isa<AffineConstantExpr>()) {
736 // Limited support: we expect the lb result to be just a loop dimension.
737 // Not supported otherwise for now.
738 AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
739 if (!result)
740 return failure();
741
742 AffineForOp loop =
743 getForInductionVarOwner(operands[result.getPosition()]);
744 if (!loop)
745 return failure();
746
747 if (failed(addAffineForOpDomain(loop)))
748 return failure();
749 continue;
750 }
751
752 // This slice refers to a loop that doesn't exist in the IR yet. Add its
753 // bounds to the system assuming its dimension identifier position is the
754 // same as the position of the loop in the loop nest.
755 if (lbMap && failed(addBound(BoundType::LB, i, lbMap, operands)))
756 return failure();
757 if (ubMap && failed(addBound(BoundType::UB, i, ubMap, operands)))
758 return failure();
759 }
760 return success();
761 }
762
addAffineIfOpDomain(AffineIfOp ifOp)763 void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
764 // Create the base constraints from the integer set attached to ifOp.
765 FlatAffineValueConstraints cst(ifOp.getIntegerSet());
766
767 // Bind ids in the constraints to ifOp operands.
768 SmallVector<Value, 4> operands = ifOp.getOperands();
769 cst.setValues(0, cst.getNumDimAndSymbolIds(), operands);
770
771 // Merge the constraints from ifOp to the current domain. We need first merge
772 // and align the IDs from both constraints, and then append the constraints
773 // from the ifOp into the current one.
774 mergeAndAlignIdsWithOther(0, &cst);
775 append(cst);
776 }
777
778 // Searches for a constraint with a non-zero coefficient at `colIdx` in
779 // equality (isEq=true) or inequality (isEq=false) constraints.
780 // Returns true and sets row found in search in `rowIdx`, false otherwise.
findConstraintWithNonZeroAt(const FlatAffineConstraints & cst,unsigned colIdx,bool isEq,unsigned * rowIdx)781 static bool findConstraintWithNonZeroAt(const FlatAffineConstraints &cst,
782 unsigned colIdx, bool isEq,
783 unsigned *rowIdx) {
784 assert(colIdx < cst.getNumCols() && "position out of bounds");
785 auto at = [&](unsigned rowIdx) -> int64_t {
786 return isEq ? cst.atEq(rowIdx, colIdx) : cst.atIneq(rowIdx, colIdx);
787 };
788 unsigned e = isEq ? cst.getNumEqualities() : cst.getNumInequalities();
789 for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
790 if (at(*rowIdx) != 0) {
791 return true;
792 }
793 }
794 return false;
795 }
796
797 // Normalizes the coefficient values across all columns in `rowIdx` by their
798 // GCD in equality or inequality constraints as specified by `isEq`.
799 template <bool isEq>
normalizeConstraintByGCD(FlatAffineConstraints * constraints,unsigned rowIdx)800 static void normalizeConstraintByGCD(FlatAffineConstraints *constraints,
801 unsigned rowIdx) {
802 auto at = [&](unsigned colIdx) -> int64_t {
803 return isEq ? constraints->atEq(rowIdx, colIdx)
804 : constraints->atIneq(rowIdx, colIdx);
805 };
806 uint64_t gcd = std::abs(at(0));
807 for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) {
808 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j)));
809 }
810 if (gcd > 0 && gcd != 1) {
811 for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) {
812 int64_t v = at(j) / static_cast<int64_t>(gcd);
813 isEq ? constraints->atEq(rowIdx, j) = v
814 : constraints->atIneq(rowIdx, j) = v;
815 }
816 }
817 }
818
normalizeConstraintsByGCD()819 void FlatAffineConstraints::normalizeConstraintsByGCD() {
820 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
821 normalizeConstraintByGCD</*isEq=*/true>(this, i);
822 }
823 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
824 normalizeConstraintByGCD</*isEq=*/false>(this, i);
825 }
826 }
827
hasConsistentState() const828 bool FlatAffineConstraints::hasConsistentState() const {
829 if (!inequalities.hasConsistentState())
830 return false;
831 if (!equalities.hasConsistentState())
832 return false;
833
834 // Catches errors where numDims, numSymbols, numIds aren't consistent.
835 if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds)
836 return false;
837
838 return true;
839 }
840
hasConsistentState() const841 bool FlatAffineValueConstraints::hasConsistentState() const {
842 return FlatAffineConstraints::hasConsistentState() &&
843 values.size() == getNumIds();
844 }
845
hasInvalidConstraint() const846 bool FlatAffineConstraints::hasInvalidConstraint() const {
847 assert(hasConsistentState());
848 auto check = [&](bool isEq) -> bool {
849 unsigned numCols = getNumCols();
850 unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
851 for (unsigned i = 0, e = numRows; i < e; ++i) {
852 unsigned j;
853 for (j = 0; j < numCols - 1; ++j) {
854 int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
855 // Skip rows with non-zero variable coefficients.
856 if (v != 0)
857 break;
858 }
859 if (j < numCols - 1) {
860 continue;
861 }
862 // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
863 // Example invalid constraints include: '1 == 0' or '-1 >= 0'
864 int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
865 if ((isEq && v != 0) || (!isEq && v < 0)) {
866 return true;
867 }
868 }
869 return false;
870 };
871 if (check(/*isEq=*/true))
872 return true;
873 return check(/*isEq=*/false);
874 }
875
876 /// Eliminate identifier from constraint at `rowIdx` based on coefficient at
877 /// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
878 /// updated as they have already been eliminated.
eliminateFromConstraint(FlatAffineConstraints * constraints,unsigned rowIdx,unsigned pivotRow,unsigned pivotCol,unsigned elimColStart,bool isEq)879 static void eliminateFromConstraint(FlatAffineConstraints *constraints,
880 unsigned rowIdx, unsigned pivotRow,
881 unsigned pivotCol, unsigned elimColStart,
882 bool isEq) {
883 // Skip if equality 'rowIdx' if same as 'pivotRow'.
884 if (isEq && rowIdx == pivotRow)
885 return;
886 auto at = [&](unsigned i, unsigned j) -> int64_t {
887 return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
888 };
889 int64_t leadCoeff = at(rowIdx, pivotCol);
890 // Skip if leading coefficient at 'rowIdx' is already zero.
891 if (leadCoeff == 0)
892 return;
893 int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
894 int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
895 int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
896 int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
897 int64_t rowMultiplier = lcm / std::abs(leadCoeff);
898
899 unsigned numCols = constraints->getNumCols();
900 for (unsigned j = 0; j < numCols; ++j) {
901 // Skip updating column 'j' if it was just eliminated.
902 if (j >= elimColStart && j < pivotCol)
903 continue;
904 int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
905 rowMultiplier * at(rowIdx, j);
906 isEq ? constraints->atEq(rowIdx, j) = v
907 : constraints->atIneq(rowIdx, j) = v;
908 }
909 }
910
/// Removes the identifiers in the column range [idStart, idLimit) from every
/// equality and inequality. Then updates numDims, numSymbols and numIds
/// according to how many dimension, symbol and local ids fell in that range.
/// Does nothing if the range is empty (idStart >= idLimit).
void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
  assert(idLimit < getNumCols() && "invalid id limit");

  if (idStart >= idLimit)
    return;

  // We are going to be removing one or more identifiers from the range.
  assert(idStart < numIds && "invalid idStart position");

  // TODO: Make 'removeIdRange' a lambda called from here.
  // Remove eliminated identifiers from the constraints.
  equalities.removeColumns(idStart, idLimit - idStart);
  inequalities.removeColumns(idStart, idLimit - idStart);

  // Update members numDims, numSymbols and numIds.
  unsigned numDimsEliminated = 0;
  unsigned numLocalsEliminated = 0;
  unsigned numColsEliminated = idLimit - idStart;
  // Dimension ids occupy positions [0, numDims).
  if (idStart < numDims) {
    numDimsEliminated = std::min(numDims, idLimit) - idStart;
  }
  // Check how many local id's were removed. Note that our identifier order is
  // [dims, symbols, locals]. Local id start at position numDims + numSymbols.
  if (idLimit > numDims + numSymbols) {
    numLocalsEliminated = std::min(
        idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
  }
  // Every removed id that is neither a dim nor a local is a symbol.
  unsigned numSymbolsEliminated =
      numColsEliminated - numDimsEliminated - numLocalsEliminated;

  numDims -= numDimsEliminated;
  numSymbols -= numSymbolsEliminated;
  numIds = numIds - numColsEliminated;
}
945
removeIdRange(unsigned idStart,unsigned idLimit)946 void FlatAffineValueConstraints::removeIdRange(unsigned idStart,
947 unsigned idLimit) {
948 FlatAffineConstraints::removeIdRange(idStart, idLimit);
949 values.erase(values.begin() + idStart, values.begin() + idLimit);
950 }
951
952 /// Returns the position of the identifier that has the minimum <number of lower
953 /// bounds> times <number of upper bounds> from the specified range of
954 /// identifiers [start, end). It is often best to eliminate in the increasing
955 /// order of these counts when doing Fourier-Motzkin elimination since FM adds
956 /// that many new constraints.
getBestIdToEliminate(const FlatAffineConstraints & cst,unsigned start,unsigned end)957 static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst,
958 unsigned start, unsigned end) {
959 assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
960
961 auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
962 unsigned numLb = 0;
963 unsigned numUb = 0;
964 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
965 if (cst.atIneq(r, pos) > 0) {
966 ++numLb;
967 } else if (cst.atIneq(r, pos) < 0) {
968 ++numUb;
969 }
970 }
971 return numLb * numUb;
972 };
973
974 unsigned minLoc = start;
975 unsigned min = getProductOfNumLowerUpperBounds(start);
976 for (unsigned c = start + 1; c < end; c++) {
977 unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
978 if (numLbUbProduct < min) {
979 min = numLbUbProduct;
980 minLoc = c;
981 }
982 }
983 return minLoc;
984 }
985
// Checks for emptiness of the set by eliminating identifiers successively and
// using the GCD test (on all equality constraints) and checking for trivially
// invalid constraints. Returns 'true' if the constraint system is found to be
// empty; false otherwise. The eliminations run on a copy, so this system is
// not modified. A 'false' result is not a proof of non-emptiness: 'false' is
// also returned when Fourier-Motzkin elimination trips the
// constraint-explosion safeguard below.
bool FlatAffineConstraints::isEmpty() const {
  // Cheap checks on the original system first.
  if (isEmptyByGCDTest() || hasInvalidConstraint())
    return true;

  // All eliminations below operate on this copy.
  FlatAffineConstraints tmpCst(*this);

  // First, eliminate as many local variables as possible using equalities.
  tmpCst.removeRedundantLocalVars();
  if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint())
    return true;

  // Eliminate as many identifiers as possible using Gaussian elimination.
  // gaussianEliminateIds removes the columns it eliminates, so
  // tmpCst.getNumIds() can shrink between iterations.
  unsigned currentPos = 0;
  while (currentPos < tmpCst.getNumIds()) {
    tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
    ++currentPos;
    // We check emptiness through trivial checks after eliminating each ID to
    // detect emptiness early. Since the checks isEmptyByGCDTest() and
    // hasInvalidConstraint() are linear time and single sweep on the constraint
    // buffer, this appears reasonable - but can optimize in the future.
    if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
      return true;
  }

  // Eliminate the remaining using FM, one identifier per iteration (the
  // iteration count is fixed before the loop starts).
  for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
    tmpCst.fourierMotzkinEliminate(
        getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
    // Check for a constraint explosion. This rarely happens in practice, but
    // this check exists as a safeguard against improperly constructed
    // constraint systems or artificially created arbitrarily complex systems
    // that aren't the intended use case for FlatAffineConstraints. This is
    // needed since FM has a worst case exponential complexity in theory.
    // The threshold is relative to the id count of the original system.
    if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
      LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
      return false;
    }

    // FM wouldn't have modified the equalities in any way. So no need to again
    // run GCD test. Check for trivial invalid constraints.
    if (tmpCst.hasInvalidConstraint())
      return true;
  }
  return false;
}
1035
1036 // Runs the GCD test on all equality constraints. Returns 'true' if this test
1037 // fails on any equality. Returns 'false' otherwise.
1038 // This test can be used to disprove the existence of a solution. If it returns
1039 // true, no integer solution to the equality constraints can exist.
1040 //
1041 // GCD test definition:
1042 //
1043 // The equality constraint:
1044 //
1045 // c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
1046 //
1047 // has an integer solution iff:
1048 //
1049 // GCD of c_1, c_2, ..., c_n divides c_0.
1050 //
isEmptyByGCDTest() const1051 bool FlatAffineConstraints::isEmptyByGCDTest() const {
1052 assert(hasConsistentState());
1053 unsigned numCols = getNumCols();
1054 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1055 uint64_t gcd = std::abs(atEq(i, 0));
1056 for (unsigned j = 1; j < numCols - 1; ++j) {
1057 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
1058 }
1059 int64_t v = std::abs(atEq(i, numCols - 1));
1060 if (gcd > 0 && (v % gcd != 0)) {
1061 return true;
1062 }
1063 }
1064 return false;
1065 }
1066
1067 // Returns a matrix where each row is a vector along which the polytope is
1068 // bounded. The span of the returned vectors is guaranteed to contain all
1069 // such vectors. The returned vectors are NOT guaranteed to be linearly
1070 // independent. This function should not be called on empty sets.
1071 //
1072 // It is sufficient to check the perpendiculars of the constraints, as the set
1073 // of perpendiculars which are bounded must span all bounded directions.
getBoundedDirections() const1074 Matrix FlatAffineConstraints::getBoundedDirections() const {
1075 // Note that it is necessary to add the equalities too (which the constructor
1076 // does) even though we don't need to check if they are bounded; whether an
1077 // inequality is bounded or not depends on what other constraints, including
1078 // equalities, are present.
1079 Simplex simplex(*this);
1080
1081 assert(!simplex.isEmpty() && "It is not meaningful to ask whether a "
1082 "direction is bounded in an empty set.");
1083
1084 SmallVector<unsigned, 8> boundedIneqs;
1085 // The constructor adds the inequalities to the simplex first, so this
1086 // processes all the inequalities.
1087 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1088 if (simplex.isBoundedAlongConstraint(i))
1089 boundedIneqs.push_back(i);
1090 }
1091
1092 // The direction vector is given by the coefficients and does not include the
1093 // constant term, so the matrix has one fewer column.
1094 unsigned dirsNumCols = getNumCols() - 1;
1095 Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
1096
1097 // Copy the bounded inequalities.
1098 unsigned row = 0;
1099 for (unsigned i : boundedIneqs) {
1100 for (unsigned col = 0; col < dirsNumCols; ++col)
1101 dirs(row, col) = atIneq(i, col);
1102 ++row;
1103 }
1104
1105 // Copy the equalities. All the equalities' perpendiculars are bounded.
1106 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1107 for (unsigned col = 0; col < dirsNumCols; ++col)
1108 dirs(row, col) = atEq(i, col);
1109 ++row;
1110 }
1111
1112 return dirs;
1113 }
1114
eqInvolvesSuffixDims(const FlatAffineConstraints & fac,unsigned eqIndex,unsigned numDims)1115 bool eqInvolvesSuffixDims(const FlatAffineConstraints &fac, unsigned eqIndex,
1116 unsigned numDims) {
1117 for (unsigned e = fac.getNumIds(), j = e - numDims; j < e; ++j)
1118 if (fac.atEq(eqIndex, j) != 0)
1119 return true;
1120 return false;
1121 }
ineqInvolvesSuffixDims(const FlatAffineConstraints & fac,unsigned ineqIndex,unsigned numDims)1122 bool ineqInvolvesSuffixDims(const FlatAffineConstraints &fac,
1123 unsigned ineqIndex, unsigned numDims) {
1124 for (unsigned e = fac.getNumIds(), j = e - numDims; j < e; ++j)
1125 if (fac.atIneq(ineqIndex, j) != 0)
1126 return true;
1127 return false;
1128 }
1129
removeConstraintsInvolvingSuffixDims(FlatAffineConstraints & fac,unsigned unboundedDims)1130 void removeConstraintsInvolvingSuffixDims(FlatAffineConstraints &fac,
1131 unsigned unboundedDims) {
1132 // We iterate backwards so that whether we remove constraint i - 1 or not, the
1133 // next constraint to be tested is always i - 2.
1134 for (unsigned i = fac.getNumEqualities(); i > 0; i--)
1135 if (eqInvolvesSuffixDims(fac, i - 1, unboundedDims))
1136 fac.removeEquality(i - 1);
1137 for (unsigned i = fac.getNumInequalities(); i > 0; i--)
1138 if (ineqInvolvesSuffixDims(fac, i - 1, unboundedDims))
1139 fac.removeInequality(i - 1);
1140 }
1141
isIntegerEmpty() const1142 bool FlatAffineConstraints::isIntegerEmpty() const {
1143 return !findIntegerSample().hasValue();
1144 }
1145
/// Let this set be S. If S is bounded then we directly call into the GBR
/// sampling algorithm. Otherwise, there are some unbounded directions, i.e.,
/// vectors v such that S extends to infinity along v or -v. In this case we
/// use an algorithm described in the integer set library (isl) manual and used
/// by the isl_set_sample function in that library. The algorithm is:
///
/// 1) Apply a unimodular transform T to S to obtain S*T, such that all
/// dimensions in which S*T is bounded lie in the linear span of a prefix of the
/// dimensions.
///
/// 2) Construct a set B by removing all constraints that involve
/// the unbounded dimensions and then deleting the unbounded dimensions. Note
/// that B is a Bounded set.
///
/// 3) Try to obtain a sample from B using the GBR sampling
/// algorithm. If no sample is found, return that S is empty.
///
/// 4) Otherwise, substitute the obtained sample into S*T to obtain a set
/// C. C is a full-dimensional Cone and always contains a sample.
///
/// 5) Obtain an integer sample from C.
///
/// 6) Return T*v, where v is the concatenation of the samples from B and C.
///
/// The following is a sketch of a proof that
/// a) If the algorithm returns empty, then S is empty.
/// b) If the algorithm returns a sample, it is a valid sample in S.
///
/// The algorithm returns empty only if B is empty, in which case S*T is
/// certainly empty since B was obtained by removing constraints and then
/// deleting unconstrained dimensions from S*T. Since T is unimodular, a vector
/// v is in S*T iff T*v is in S. So in this case, since
/// S*T is empty, S is empty too.
///
/// Otherwise, the algorithm substitutes the sample from B into S*T. All the
/// constraints of S*T that did not involve unbounded dimensions are satisfied
/// by this substitution. All dimensions in the linear span of the dimensions
/// outside the prefix are unbounded in S*T (step 1). Substituting values for
/// the bounded dimensions cannot make these dimensions bounded, and these are
/// the only remaining dimensions in C, so C is unbounded along every vector (in
/// the positive or negative direction, or both). C is hence a full-dimensional
/// cone and therefore always contains an integer point.
///
/// Concatenating the samples from B and C gives a sample v in S*T, so the
/// returned sample T*v is a sample in S.
///
/// Before any of this, an empty result is also returned early when the GCD
/// test or the rational simplex already shows that S is empty.
Optional<SmallVector<int64_t, 8>>
FlatAffineConstraints::findIntegerSample() const {
  // First, try the GCD test heuristic.
  if (isEmptyByGCDTest())
    return {};

  Simplex simplex(*this);
  if (simplex.isEmpty())
    return {};

  // For a bounded set, we directly call into the GBR sampling algorithm.
  if (!simplex.isUnbounded())
    return simplex.findIntegerSample();

  // The set is unbounded. We cannot directly use the GBR algorithm.
  //
  // m is a matrix containing, in each row, a vector in which S is
  // bounded, such that the linear span of all these dimensions contains all
  // bounded dimensions in S.
  Matrix m = getBoundedDirections();
  // In column echelon form, each row of m occupies only the first rank(m)
  // columns and has zeros on the other columns. The transform T that brings m
  // to column echelon form is unimodular as well, so this is a suitable
  // transform to use in step 1 of the algorithm.
  std::pair<unsigned, LinearTransform> result =
      LinearTransform::makeTransformToColumnEchelon(std::move(m));
  const LinearTransform &transform = result.second;
  // 1) Apply T to S to obtain S*T.
  FlatAffineConstraints transformedSet = transform.applyTo(*this);

  // 2) Remove the unbounded dimensions and constraints involving them to
  // obtain a bounded set.
  FlatAffineConstraints boundedSet = transformedSet;
  // result.first is used as the number of leading (bounded) dimensions.
  unsigned numBoundedDims = result.first;
  unsigned numUnboundedDims = getNumIds() - numBoundedDims;
  removeConstraintsInvolvingSuffixDims(boundedSet, numUnboundedDims);
  boundedSet.removeIdRange(numBoundedDims, boundedSet.getNumIds());

  // 3) Try to obtain a sample from the bounded set.
  Optional<SmallVector<int64_t, 8>> boundedSample =
      Simplex(boundedSet).findIntegerSample();
  if (!boundedSample)
    return {};
  assert(boundedSet.containsPoint(*boundedSample) &&
         "Simplex returned an invalid sample!");

  // 4) Substitute the values of the bounded dimensions into S*T to obtain a
  // full-dimensional cone, which necessarily contains an integer sample.
  transformedSet.setAndEliminate(0, *boundedSample);
  FlatAffineConstraints &cone = transformedSet;

  // 5) Obtain an integer sample from the cone.
  //
  // We shrink the cone such that for any rational point in the shrunken cone,
  // rounding up each of the point's coordinates produces a point that still
  // lies in the original cone.
  //
  // Rounding up a point x adds a number e_i in [0, 1) to each coordinate x_i.
  // For each inequality sum_i a_i x_i + c >= 0 in the original cone, the
  // shrunken cone will have the inequality tightened by some amount s, such
  // that if x satisfies the shrunken cone's tightened inequality, then x + e
  // satisfies the original inequality, i.e.,
  //
  // sum_i a_i x_i + c + s >= 0 implies sum_i a_i (x_i + e_i) + c >= 0
  //
  // for any e_i values in [0, 1). In fact, we will handle the slightly more
  // general case where e_i can be in [0, 1]. For example, consider the
  // inequality 2x_1 - 3x_2 - 7x_3 - 6 >= 0, and let x = (3, 0, 0). How low
  // could the LHS go if we added a number in [0, 1] to each coordinate? The LHS
  // is minimized when we add 1 to the x_i with negative coefficient a_i and
  // keep the other x_i the same. In the example, we would get x = (3, 1, 1),
  // changing the value of the LHS by -3 + -7 = -10.
  //
  // In general, the value of the LHS can change by at most the sum of the
  // negative a_i, so we accommodate this by shifting the inequality by this
  // amount for the shrunken cone. Column `cone.numIds` holds the constant term.
  for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) {
    for (unsigned j = 0; j < cone.numIds; ++j) {
      int64_t coeff = cone.atIneq(i, j);
      if (coeff < 0)
        cone.atIneq(i, cone.numIds) += coeff;
    }
  }

  // Obtain an integer sample in the cone by rounding up a rational point from
  // the shrunken cone. Shrinking the cone amounts to shifting its apex
  // "inwards" without changing its "shape"; the shrunken cone is still a
  // full-dimensional cone and is hence non-empty.
  Simplex shrunkenConeSimplex(cone);
  assert(!shrunkenConeSimplex.isEmpty() && "Shrunken cone cannot be empty!");
  SmallVector<Fraction, 8> shrunkenConeSample =
      shrunkenConeSimplex.getRationalSample();

  SmallVector<int64_t, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil));

  // 6) Return transform * concat(boundedSample, coneSample).
  SmallVector<int64_t, 8> &sample = boundedSample.getValue();
  sample.append(coneSample.begin(), coneSample.end());
  return transform.preMultiplyColumn(sample);
}
1291
1292 /// Helper to evaluate an affine expression at a point.
1293 /// The expression is a list of coefficients for the dimensions followed by the
1294 /// constant term.
valueAt(ArrayRef<int64_t> expr,ArrayRef<int64_t> point)1295 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
1296 assert(expr.size() == 1 + point.size() &&
1297 "Dimensionalities of point and expression don't match!");
1298 int64_t value = expr.back();
1299 for (unsigned i = 0; i < point.size(); ++i)
1300 value += expr[i] * point[i];
1301 return value;
1302 }
1303
1304 /// A point satisfies an equality iff the value of the equality at the
1305 /// expression is zero, and it satisfies an inequality iff the value of the
1306 /// inequality at that point is non-negative.
containsPoint(ArrayRef<int64_t> point) const1307 bool FlatAffineConstraints::containsPoint(ArrayRef<int64_t> point) const {
1308 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1309 if (valueAt(getEquality(i), point) != 0)
1310 return false;
1311 }
1312 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1313 if (valueAt(getInequality(i), point) < 0)
1314 return false;
1315 }
1316 return true;
1317 }
1318
1319 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
1320 /// function of other identifiers (where the divisor is a positive constant),
1321 /// `foundRepr` contains a boolean for each identifier indicating if the
1322 /// explicit representation for that identifier has already been computed.
1323 static Optional<std::pair<unsigned, unsigned>>
computeSingleVarRepr(const FlatAffineConstraints & cst,const SmallVector<bool,8> & foundRepr,unsigned pos)1324 computeSingleVarRepr(const FlatAffineConstraints &cst,
1325 const SmallVector<bool, 8> &foundRepr, unsigned pos) {
1326 assert(pos < cst.getNumIds() && "invalid position");
1327 assert(foundRepr.size() == cst.getNumIds() &&
1328 "Size of foundRepr does not match total number of variables");
1329
1330 SmallVector<unsigned, 4> lbIndices, ubIndices;
1331 cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices);
1332
1333 // `id` is equivalent to `expr floordiv divisor` if there
1334 // are constraints of the form:
1335 // 0 <= expr - divisor * id <= divisor - 1
1336 // Rearranging, we have:
1337 // divisor * id - expr + (divisor - 1) >= 0 <-- Lower bound for 'id'
1338 // -divisor * id + expr >= 0 <-- Upper bound for 'id'
1339 //
1340 // For example:
1341 // 32*k >= 16*i + j - 31 <-- Lower bound for 'k'
1342 // 32*k <= 16*i + j <-- Upper bound for 'k'
1343 // expr = 16*i + j, divisor = 32
1344 // k = ( 16*i + j ) floordiv 32
1345 //
1346 // 4q >= i + j - 2 <-- Lower bound for 'q'
1347 // 4q <= i + j + 1 <-- Upper bound for 'q'
1348 // expr = i + j + 1, divisor = 4
1349 // q = (i + j + 1) floordiv 4
1350 for (unsigned ubPos : ubIndices) {
1351 for (unsigned lbPos : lbIndices) {
1352 // Due to the form of the inequalities, sum of constants of the
1353 // inequalities is (divisor - 1).
1354 int64_t divisor = cst.atIneq(lbPos, cst.getNumCols() - 1) +
1355 cst.atIneq(ubPos, cst.getNumCols() - 1) + 1;
1356
1357 // Divisor should be positive.
1358 if (divisor <= 0)
1359 continue;
1360
1361 // Check if coeff of variable is equal to divisor.
1362 if (divisor != cst.atIneq(lbPos, pos))
1363 continue;
1364
1365 // Check if constraints are opposite of each other. Constant term
1366 // is not required to be opposite and is not checked.
1367 unsigned c = 0, f = 0;
1368 for (c = 0, f = cst.getNumIds(); c < f; ++c)
1369 if (cst.atIneq(ubPos, c) != -cst.atIneq(lbPos, c))
1370 break;
1371
1372 if (c < f)
1373 continue;
1374
1375 // Check if the inequalities depend on a variable for which
1376 // an explicit representation has not been found yet.
1377 // Exit to avoid circular dependencies between divisions.
1378 for (c = 0, f = cst.getNumIds(); c < f; ++c) {
1379 if (c == pos)
1380 continue;
1381 if (!foundRepr[c] && cst.atIneq(lbPos, c) != 0)
1382 break;
1383 }
1384
1385 // Expression can't be constructed as it depends on a yet unknown
1386 // identifier.
1387 // TODO: Visit/compute the identifiers in an order so that this doesn't
1388 // happen. More complex but much more efficient.
1389 if (c < f)
1390 continue;
1391
1392 return std::make_pair(ubPos, lbPos);
1393 }
1394 }
1395
1396 return llvm::None;
1397 }
1398
1399 /// Find pairs of inequalities identified by their position indices, using
1400 /// which an explicit representation for each local variable can be computed
1401 /// The pairs are stored as indices of upperbound, lowerbound
1402 /// inequalities. If no such pair can be found, it is stored as llvm::None.
getLocalReprLbUbPairs(std::vector<llvm::Optional<std::pair<unsigned,unsigned>>> & repr) const1403 void FlatAffineConstraints::getLocalReprLbUbPairs(
1404 std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> &repr) const {
1405 assert(repr.size() == getNumLocalIds() &&
1406 "Size of repr does not match number of local variables");
1407
1408 SmallVector<bool, 8> foundRepr(getNumIds(), false);
1409 for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i)
1410 foundRepr[i] = true;
1411
1412 unsigned divOffset = getNumDimAndSymbolIds();
1413 bool changed;
1414 do {
1415 // Each time changed is true, at end of this iteration, one or more local
1416 // vars have been detected as floor divs.
1417 changed = false;
1418 for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) {
1419 if (!foundRepr[i + divOffset]) {
1420 if (auto res = computeSingleVarRepr(*this, foundRepr, divOffset + i)) {
1421 foundRepr[i + divOffset] = true;
1422 repr[i] = res;
1423 changed = true;
1424 }
1425 }
1426 }
1427 } while (changed);
1428 }
1429
1430 /// Tightens inequalities given that we are dealing with integer spaces. This is
1431 /// analogous to the GCD test but applied to inequalities. The constant term can
1432 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
1433 /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a
1434 /// fast method - linear in the number of coefficients.
1435 // Example on how this affects practical cases: consider the scenario:
1436 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
1437 // j >= 100 instead of the tighter (exact) j >= 128.
gcdTightenInequalities()1438 void FlatAffineConstraints::gcdTightenInequalities() {
1439 unsigned numCols = getNumCols();
1440 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1441 uint64_t gcd = std::abs(atIneq(i, 0));
1442 for (unsigned j = 1; j < numCols - 1; ++j) {
1443 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j)));
1444 }
1445 if (gcd > 0 && gcd != 1) {
1446 int64_t gcdI = static_cast<int64_t>(gcd);
1447 // Tighten the constant term and normalize the constraint by the GCD.
1448 atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI);
1449 for (unsigned j = 0, e = numCols - 1; j < e; ++j)
1450 atIneq(i, j) /= gcdI;
1451 }
1452 }
1453 }
1454
// Eliminates identifier variables in column range [posStart, posLimit) by
// Gaussian elimination on the equalities. Stops early at the first column that
// appears in no equality but in some inequality. The columns processed before
// that point are removed from the system: the eliminated ones and those that
// appear in no constraint at all. Returns the number of columns removed.
unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
                                                     unsigned posLimit) {
  // Identifier positions to eliminate must be in range.
  assert(posLimit <= numIds);
  assert(hasConsistentState());

  if (posStart >= posLimit)
    return 0;

  gcdTightenInequalities();

  unsigned pivotCol = 0;
  for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
    // Find a row which has a non-zero coefficient in column 'pivotCol'.
    unsigned pivotRow;
    if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true,
                                     &pivotRow)) {
      // No pivot row in equalities with non-zero at 'pivotCol'.
      if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false,
                                       &pivotRow)) {
        // If inequalities are also zero in 'pivotCol', the identifier occurs
        // in no constraint and can be eliminated (its column is removed
        // below).
        continue;
      }
      // The identifier occurs only in inequalities, so Gaussian elimination
      // cannot remove it; stop here.
      break;
    }

    // Eliminate identifier at 'pivotCol' from each equality row.
    for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
                              /*isEq=*/true);
      normalizeConstraintByGCD</*isEq=*/true>(this, i);
    }

    // Eliminate identifier at 'pivotCol' from each inequality row.
    for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
                              /*isEq=*/false);
      normalizeConstraintByGCD</*isEq=*/false>(this, i);
    }
    // The pivot equality has been used up; drop it.
    removeEquality(pivotRow);
    gcdTightenInequalities();
  }
  // Update position limit based on number eliminated.
  posLimit = pivotCol;
  // Remove eliminated columns from all constraints.
  removeIdRange(posStart, posLimit);
  return posLimit - posStart;
}
1506
1507 // Determine whether the identifier at 'pos' (say id_r) can be expressed as
1508 // modulo of another known identifier (say id_n) w.r.t a constant. For example,
1509 // if the following constraints hold true:
1510 // ```
1511 // 0 <= id_r <= divisor - 1
1512 // id_n - (divisor * q_expr) = id_r
1513 // ```
1514 // where `id_n` is a known identifier (called dividend), and `q_expr` is an
1515 // `AffineExpr` (called the quotient expression), `id_r` can be written as:
1516 //
1517 // `id_r = id_n mod divisor`.
1518 //
1519 // Additionally, in a special case of the above constaints where `q_expr` is an
1520 // identifier itself that is not yet known (say `id_q`), it can be written as a
1521 // floordiv in the following way:
1522 //
1523 // `id_q = id_n floordiv divisor`.
1524 //
1525 // Returns true if the above mod or floordiv are detected, updating 'memo' with
1526 // these new expressions. Returns false otherwise.
detectAsMod(const FlatAffineConstraints & cst,unsigned pos,int64_t lbConst,int64_t ubConst,SmallVectorImpl<AffineExpr> & memo,MLIRContext * context)1527 static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
1528 int64_t lbConst, int64_t ubConst,
1529 SmallVectorImpl<AffineExpr> &memo,
1530 MLIRContext *context) {
1531 assert(pos < cst.getNumIds() && "invalid position");
1532
1533 // Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can
1534 // be determined.
1535 if (lbConst != 0 || ubConst < 1)
1536 return false;
1537 int64_t divisor = ubConst + 1;
1538
1539 // Check for the aforementioned conditions in each equality.
1540 for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
1541 curEquality < numEqualities; curEquality++) {
1542 int64_t coefficientAtPos = cst.atEq(curEquality, pos);
1543 // If current equality does not involve `id_r`, continue to the next
1544 // equality.
1545 if (coefficientAtPos == 0)
1546 continue;
1547
1548 // Constant term should be 0 in this equality.
1549 if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
1550 continue;
1551
1552 // Traverse through the equality and construct the dividend expression
1553 // `dividendExpr`, to contain all the identifiers which are known and are
1554 // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
1555 // `dividendExpr` gets simplified into a single identifier `id_n` discussed
1556 // above.
1557 auto dividendExpr = getAffineConstantExpr(0, context);
1558
1559 // Track the terms that go into quotient expression, later used to detect
1560 // additional floordiv.
1561 unsigned quotientCount = 0;
1562 int quotientPosition = -1;
1563 int quotientSign = 1;
1564
1565 // Consider each term in the current equality.
1566 unsigned curId, e;
1567 for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) {
1568 // Ignore id_r.
1569 if (curId == pos)
1570 continue;
1571 int64_t coefficientOfCurId = cst.atEq(curEquality, curId);
1572 // Ignore ids that do not contribute to the current equality.
1573 if (coefficientOfCurId == 0)
1574 continue;
1575 // Check if the current id goes into the quotient expression.
1576 if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) {
1577 quotientCount++;
1578 quotientPosition = curId;
1579 quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1;
1580 continue;
1581 }
1582 // Identifiers that are part of dividendExpr should be known.
1583 if (!memo[curId])
1584 break;
1585 // Append the current identifier to the dividend expression.
1586 dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId;
1587 }
1588
1589 // Can't construct expression as it depends on a yet uncomputed id.
1590 if (curId < e)
1591 continue;
1592
1593 // Express `id_r` in terms of the other ids collected so far.
1594 if (coefficientAtPos > 0)
1595 dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
1596 else
1597 dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
1598
1599 // Simplify the expression.
1600 dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(),
1601 cst.getNumSymbolIds());
1602 // Only if the final dividend expression is just a single id (which we call
1603 // `id_n`), we can proceed.
1604 // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
1605 // to dims themselves.
1606 auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
1607 if (!dimExpr)
1608 continue;
1609
1610 // Express `id_r` as `id_n % divisor` and store the expression in `memo`.
1611 if (quotientCount >= 1) {
1612 auto ub = cst.getConstantBound(FlatAffineConstraints::BoundType::UB,
1613 dimExpr.getPosition());
1614 // If `id_n` has an upperbound that is less than the divisor, mod can be
1615 // eliminated altogether.
1616 if (ub.hasValue() && ub.getValue() < divisor)
1617 memo[pos] = dimExpr;
1618 else
1619 memo[pos] = dimExpr % divisor;
1620 // If a unique quotient `id_q` was seen, it can be expressed as
1621 // `id_n floordiv divisor`.
1622 if (quotientCount == 1 && !memo[quotientPosition])
1623 memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
1624
1625 return true;
1626 }
1627 }
1628 return false;
1629 }
1630
1631 /// Gather all lower and upper bounds of the identifier at `pos`, and
1632 /// optionally any equalities on it. In addition, the bounds are to be
1633 /// independent of identifiers in position range [`offset`, `offset` + `num`).
getLowerAndUpperBoundIndices(unsigned pos,SmallVectorImpl<unsigned> * lbIndices,SmallVectorImpl<unsigned> * ubIndices,SmallVectorImpl<unsigned> * eqIndices,unsigned offset,unsigned num) const1634 void FlatAffineConstraints::getLowerAndUpperBoundIndices(
1635 unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
1636 SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
1637 unsigned offset, unsigned num) const {
1638 assert(pos < getNumIds() && "invalid position");
1639 assert(offset + num < getNumCols() && "invalid range");
1640
1641 // Checks for a constraint that has a non-zero coeff for the identifiers in
1642 // the position range [offset, offset + num) while ignoring `pos`.
1643 auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
1644 unsigned c, f;
1645 auto cst = isEq ? getEquality(r) : getInequality(r);
1646 for (c = offset, f = offset + num; c < f; ++c) {
1647 if (c == pos)
1648 continue;
1649 if (cst[c] != 0)
1650 break;
1651 }
1652 return c < f;
1653 };
1654
1655 // Gather all lower bounds and upper bounds of the variable. Since the
1656 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1657 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1658 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1659 // The bounds are to be independent of [offset, offset + num) columns.
1660 if (containsConstraintDependentOnRange(r, /*isEq=*/false))
1661 continue;
1662 if (atIneq(r, pos) >= 1) {
1663 // Lower bound.
1664 lbIndices->push_back(r);
1665 } else if (atIneq(r, pos) <= -1) {
1666 // Upper bound.
1667 ubIndices->push_back(r);
1668 }
1669 }
1670
1671 // An equality is both a lower and upper bound. Record any equalities
1672 // involving the pos^th identifier.
1673 if (!eqIndices)
1674 return;
1675
1676 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1677 if (atEq(r, pos) == 0)
1678 continue;
1679 if (containsConstraintDependentOnRange(r, /*isEq=*/true))
1680 continue;
1681 eqIndices->push_back(r);
1682 }
1683 }
1684
1685 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
1686 /// function of other identifiers (where the divisor is a positive constant)
1687 /// given the initial set of expressions in `exprs`. If it can be, the
1688 /// corresponding position in `exprs` is set as the detected affine expr. For
1689 /// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can
1690 /// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
1691 /// <= i <= 32q + 31 => q = i floordiv 32.
detectAsFloorDiv(const FlatAffineConstraints & cst,unsigned pos,MLIRContext * context,SmallVectorImpl<AffineExpr> & exprs)1692 static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
1693 MLIRContext *context,
1694 SmallVectorImpl<AffineExpr> &exprs) {
1695 assert(pos < cst.getNumIds() && "invalid position");
1696
1697 // Get upper-lower bound pair for this variable.
1698 SmallVector<bool, 8> foundRepr(cst.getNumIds(), false);
1699 for (unsigned i = 0, e = cst.getNumIds(); i < e; ++i)
1700 if (exprs[i])
1701 foundRepr[i] = true;
1702
1703 auto ulPair = computeSingleVarRepr(cst, foundRepr, pos);
1704
1705 // No upper-lower bound pair found for this var.
1706 if (!ulPair)
1707 return false;
1708
1709 unsigned ubPos = ulPair->first;
1710
1711 // Upper bound is of the form:
1712 // -divisor * id + expr >= 0
1713 // where `id` is equivalent to `expr floordiv divisor`.
1714 //
1715 // Since the division cannot be dependent on itself, the coefficient of
1716 // of `id` in `expr` is zero. The coefficient of `id` in the upperbound
1717 // is -divisor.
1718 int64_t divisor = -cst.atIneq(ubPos, pos);
1719 int64_t constantTerm = cst.atIneq(ubPos, cst.getNumCols() - 1);
1720
1721 // Construct the dividend expression.
1722 auto dividendExpr = getAffineConstantExpr(constantTerm, context);
1723 unsigned c, f;
1724 for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1725 if (c == pos)
1726 continue;
1727 int64_t ubVal = cst.atIneq(ubPos, c);
1728 if (ubVal == 0)
1729 continue;
1730 // computeSingleVarRepr guarantees that expr is known here.
1731 dividendExpr = dividendExpr + ubVal * exprs[c];
1732 }
1733
1734 // Successfully detected the floordiv.
1735 exprs[pos] = dividendExpr.floorDiv(divisor);
1736 return true;
1737 }
1738
1739 // Fills an inequality row with the value 'val'.
fillInequality(FlatAffineConstraints * cst,unsigned r,int64_t val)1740 static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
1741 int64_t val) {
1742 for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1743 cst->atIneq(r, c) = val;
1744 }
1745 }
1746
1747 // Negates an inequality.
negateInequality(FlatAffineConstraints * cst,unsigned r)1748 static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
1749 for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1750 cst->atIneq(r, c) = -cst->atIneq(r, c);
1751 }
1752 }
1753
1754 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
1755 // to check if a constraint is redundant.
removeRedundantInequalities()1756 void FlatAffineConstraints::removeRedundantInequalities() {
1757 SmallVector<bool, 32> redun(getNumInequalities(), false);
1758 // To check if an inequality is redundant, we replace the inequality by its
1759 // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
1760 // system is empty. If it is, the inequality is redundant.
1761 FlatAffineConstraints tmpCst(*this);
1762 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1763 // Change the inequality to its complement.
1764 negateInequality(&tmpCst, r);
1765 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
1766 if (tmpCst.isEmpty()) {
1767 redun[r] = true;
1768 // Zero fill the redundant inequality.
1769 fillInequality(this, r, /*val=*/0);
1770 fillInequality(&tmpCst, r, /*val=*/0);
1771 } else {
1772 // Reverse the change (to avoid recreating tmpCst each time).
1773 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
1774 negateInequality(&tmpCst, r);
1775 }
1776 }
1777
1778 // Scan to get rid of all rows marked redundant, in-place.
1779 auto copyRow = [&](unsigned src, unsigned dest) {
1780 if (src == dest)
1781 return;
1782 for (unsigned c = 0, e = getNumCols(); c < e; c++) {
1783 atIneq(dest, c) = atIneq(src, c);
1784 }
1785 };
1786 unsigned pos = 0;
1787 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1788 if (!redun[r])
1789 copyRow(r, pos++);
1790 }
1791 inequalities.resizeVertically(pos);
1792 }
1793
1794 // A more complex check to eliminate redundant inequalities and equalities. Uses
1795 // Simplex to check if a constraint is redundant.
removeRedundantConstraints()1796 void FlatAffineConstraints::removeRedundantConstraints() {
1797 // First, we run gcdTightenInequalities. This allows us to catch some
1798 // constraints which are not redundant when considering rational solutions
1799 // but are redundant in terms of integer solutions.
1800 gcdTightenInequalities();
1801 Simplex simplex(*this);
1802 simplex.detectRedundant();
1803
1804 auto copyInequality = [&](unsigned src, unsigned dest) {
1805 if (src == dest)
1806 return;
1807 for (unsigned c = 0, e = getNumCols(); c < e; c++)
1808 atIneq(dest, c) = atIneq(src, c);
1809 };
1810 unsigned pos = 0;
1811 unsigned numIneqs = getNumInequalities();
1812 // Scan to get rid of all inequalities marked redundant, in-place. In Simplex,
1813 // the first constraints added are the inequalities.
1814 for (unsigned r = 0; r < numIneqs; r++) {
1815 if (!simplex.isMarkedRedundant(r))
1816 copyInequality(r, pos++);
1817 }
1818 inequalities.resizeVertically(pos);
1819
1820 // Scan to get rid of all equalities marked redundant, in-place. In Simplex,
1821 // after the inequalities, a pair of constraints for each equality is added.
1822 // An equality is redundant if both the inequalities in its pair are
1823 // redundant.
1824 auto copyEquality = [&](unsigned src, unsigned dest) {
1825 if (src == dest)
1826 return;
1827 for (unsigned c = 0, e = getNumCols(); c < e; c++)
1828 atEq(dest, c) = atEq(src, c);
1829 };
1830 pos = 0;
1831 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1832 if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) &&
1833 simplex.isMarkedRedundant(numIneqs + 2 * r + 1)))
1834 copyEquality(r, pos++);
1835 }
1836 equalities.resizeVertically(pos);
1837 }
1838
1839 /// Merge local ids of `this` and `other`. This is done by appending local ids
1840 /// of `other` to `this` and inserting local ids of `this` to `other` at start
1841 /// of its local ids.
mergeLocalIds(FlatAffineConstraints & other)1842 void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) {
1843 unsigned initLocals = getNumLocalIds();
1844 insertLocalId(getNumLocalIds(), other.getNumLocalIds());
1845 other.insertLocalId(0, initLocals);
1846 }
1847
1848 /// Removes local variables using equalities. Each equality is checked if it
1849 /// can be reduced to the form: `e = affine-expr`, where `e` is a local
1850 /// variable and `affine-expr` is an affine expression not containing `e`.
1851 /// If an equality satisfies this form, the local variable is replaced in
1852 /// each constraint and then removed. The equality used to replace this local
1853 /// variable is also removed.
removeRedundantLocalVars()1854 void FlatAffineConstraints::removeRedundantLocalVars() {
1855 // Normalize the equality constraints to reduce coefficients of local
1856 // variables to 1 wherever possible.
1857 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
1858 normalizeConstraintByGCD</*isEq=*/true>(this, i);
1859
1860 while (true) {
1861 unsigned i, e, j, f;
1862 for (i = 0, e = getNumEqualities(); i < e; ++i) {
1863 // Find a local variable to eliminate using ith equality.
1864 for (j = getNumDimAndSymbolIds(), f = getNumIds(); j < f; ++j)
1865 if (std::abs(atEq(i, j)) == 1)
1866 break;
1867
1868 // Local variable can be eliminated using ith equality.
1869 if (j < f)
1870 break;
1871 }
1872
1873 // No equality can be used to eliminate a local variable.
1874 if (i == e)
1875 break;
1876
1877 // Use the ith equality to simplify other equalities. If any changes
1878 // are made to an equality constraint, it is normalized by GCD.
1879 for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) {
1880 if (atEq(k, j) != 0) {
1881 eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true);
1882 normalizeConstraintByGCD</*isEq=*/true>(this, k);
1883 }
1884 }
1885
1886 // Use the ith equality to simplify inequalities.
1887 for (unsigned k = 0, t = getNumInequalities(); k < t; ++k)
1888 eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false);
1889
1890 // Remove the ith equality and the found local variable.
1891 removeId(j);
1892 removeEquality(i);
1893 }
1894 }
1895
getLowerAndUpperBound(unsigned pos,unsigned offset,unsigned num,unsigned symStartPos,ArrayRef<AffineExpr> localExprs,MLIRContext * context) const1896 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
1897 unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
1898 ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
1899 assert(pos + offset < getNumDimIds() && "invalid dim start pos");
1900 assert(symStartPos >= (pos + offset) && "invalid sym start pos");
1901 assert(getNumLocalIds() == localExprs.size() &&
1902 "incorrect local exprs count");
1903
1904 SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
1905 getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
1906 offset, num);
1907
1908 /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
1909 auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
1910 b.clear();
1911 for (unsigned i = 0, e = a.size(); i < e; ++i) {
1912 if (i < offset || i >= offset + num)
1913 b.push_back(a[i]);
1914 }
1915 };
1916
1917 SmallVector<int64_t, 8> lb, ub;
1918 SmallVector<AffineExpr, 4> lbExprs;
1919 unsigned dimCount = symStartPos - num;
1920 unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
1921 lbExprs.reserve(lbIndices.size() + eqIndices.size());
1922 // Lower bound expressions.
1923 for (auto idx : lbIndices) {
1924 auto ineq = getInequality(idx);
1925 // Extract the lower bound (in terms of other coeff's + const), i.e., if
1926 // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
1927 // - 1.
1928 addCoeffs(ineq, lb);
1929 std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
1930 auto expr =
1931 getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
1932 // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
1933 int64_t divisor = std::abs(ineq[pos + offset]);
1934 expr = (expr + divisor - 1).floorDiv(divisor);
1935 lbExprs.push_back(expr);
1936 }
1937
1938 SmallVector<AffineExpr, 4> ubExprs;
1939 ubExprs.reserve(ubIndices.size() + eqIndices.size());
1940 // Upper bound expressions.
1941 for (auto idx : ubIndices) {
1942 auto ineq = getInequality(idx);
1943 // Extract the upper bound (in terms of other coeff's + const).
1944 addCoeffs(ineq, ub);
1945 auto expr =
1946 getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
1947 expr = expr.floorDiv(std::abs(ineq[pos + offset]));
1948 // Upper bound is exclusive.
1949 ubExprs.push_back(expr + 1);
1950 }
1951
1952 // Equalities. It's both a lower and a upper bound.
1953 SmallVector<int64_t, 4> b;
1954 for (auto idx : eqIndices) {
1955 auto eq = getEquality(idx);
1956 addCoeffs(eq, b);
1957 if (eq[pos + offset] > 0)
1958 std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
1959
1960 // Extract the upper bound (in terms of other coeff's + const).
1961 auto expr =
1962 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1963 expr = expr.floorDiv(std::abs(eq[pos + offset]));
1964 // Upper bound is exclusive.
1965 ubExprs.push_back(expr + 1);
1966 // Lower bound.
1967 expr =
1968 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1969 expr = expr.ceilDiv(std::abs(eq[pos + offset]));
1970 lbExprs.push_back(expr);
1971 }
1972
1973 auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
1974 auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
1975
1976 return {lbMap, ubMap};
1977 }
1978
/// Computes the lower and upper bounds of the `num` dimensional identifiers
/// starting at position `offset` as affine maps of the remaining
/// identifiers (dimensional and symbolic identifiers). Local identifiers are
/// themselves explicitly computed as affine functions of other identifiers in
/// this process if needed.
///
/// The resulting maps take the dims outside [offset, offset + num) as their
/// dims (renumbered contiguously) and the symbols of this system as their
/// symbols. When an identifier's value is detected exactly as `expr`, its
/// bounds are `expr` (lower) and `expr + 1` (exclusive upper). Otherwise, when
/// there are no local identifiers, bounds are derived from the constraints via
/// getLowerAndUpperBound. Null or multi-result bounds are then replaced by
/// constant bounds when those exist.
///
/// Note: this normalizes the constraints of `this` by their GCD as a side
/// effect. `lbMaps` and `ubMaps` are indexed in [0, num) and are not resized
/// here, so they must already hold at least `num` entries.
void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
                                           MLIRContext *context,
                                           SmallVectorImpl<AffineMap> *lbMaps,
                                           SmallVectorImpl<AffineMap> *ubMaps) {
  // NOTE(review): this only checks `num`, not `offset + num`, against the
  // number of dims — confirm callers guarantee offset + num <= getNumDimIds().
  assert(num < getNumDimIds() && "invalid range");

  // Basic simplification.
  normalizeConstraintsByGCD();

  LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
                          << " identifiers\n");
  LLVM_DEBUG(dump());

  // Record computed/detected identifiers. A null entry means "not yet known".
  SmallVector<AffineExpr, 8> memo(getNumIds());
  // Initialize dimensional and symbolic identifiers. Dims outside the slice
  // range become map dims; those after the range are shifted down by `num`.
  for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
    if (i < offset)
      memo[i] = getAffineDimExpr(i, context);
    else if (i >= offset + num)
      memo[i] = getAffineDimExpr(i - num, context);
  }
  for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
    memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);

  bool changed;
  do {
    changed = false;
    // Identify yet unknown identifiers as constants or mod's / floordiv's of
    // other identifiers if possible.
    for (unsigned pos = 0; pos < getNumIds(); pos++) {
      if (memo[pos])
        continue;

      auto lbConst = getConstantBound(BoundType::LB, pos);
      auto ubConst = getConstantBound(BoundType::UB, pos);
      if (lbConst.hasValue() && ubConst.hasValue()) {
        // Detect equality to a constant.
        if (lbConst.getValue() == ubConst.getValue()) {
          memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
          changed = true;
          continue;
        }

        // Detect an identifier as modulo of another identifier w.r.t a
        // constant.
        if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
                        memo, context)) {
          changed = true;
          continue;
        }
      }

      // Detect an identifier as a floordiv of an affine function of other
      // identifiers (divisor is a positive constant).
      if (detectAsFloorDiv(*this, pos, context, memo)) {
        changed = true;
        continue;
      }

      // Detect an identifier as an expression of other identifiers, using the
      // first equality that involves it (other equalities are not tried).
      unsigned idx;
      if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) {
        continue;
      }

      // Build AffineExpr solving for identifier 'pos' in terms of all others:
      // from c_pos * x_pos + sum_j c_j * x_j + c_0 = 0.
      auto expr = getAffineConstantExpr(0, context);
      unsigned j, e;
      for (j = 0, e = getNumIds(); j < e; ++j) {
        if (j == pos)
          continue;
        int64_t c = atEq(idx, j);
        if (c == 0)
          continue;
        // If any of the involved IDs hasn't been found yet, we can't proceed.
        if (!memo[j])
          break;
        expr = expr + memo[j] * c;
      }
      if (j < e)
        // Can't construct expression as it depends on a yet uncomputed
        // identifier.
        continue;

      // Add constant term to AffineExpr.
      expr = expr + atEq(idx, getNumIds());
      int64_t vPos = atEq(idx, pos);
      assert(vPos != 0 && "expected non-zero here");
      // x_pos = -(rest) / c_pos; the sign is moved into the dividend so the
      // floordiv divisor is positive.
      if (vPos > 0)
        expr = (-expr).floorDiv(vPos);
      else
        // vPos < 0.
        expr = expr.floorDiv(-vPos);
      // Successfully constructed expression.
      memo[pos] = expr;
      changed = true;
    }
    // This loop is guaranteed to reach a fixed point - since once an
    // identifier's explicit form is computed (in memo[pos]), it's not updated
    // again.
  } while (changed);

  // Set the lower and upper bound maps for all the identifiers that were
  // computed as affine expressions of the rest as the "detected expr" and
  // "detected expr + 1" respectively; set the undetected ones to null.
  // `tmpClone` is created lazily, at most once, and shared across positions.
  Optional<FlatAffineConstraints> tmpClone;
  for (unsigned pos = 0; pos < num; pos++) {
    unsigned numMapDims = getNumDimIds() - num;
    unsigned numMapSymbols = getNumSymbolIds();
    AffineExpr expr = memo[pos + offset];
    if (expr)
      expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);

    AffineMap &lbMap = (*lbMaps)[pos];
    AffineMap &ubMap = (*ubMaps)[pos];

    if (expr) {
      lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
      ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
    } else {
      // TODO: Whenever there are local identifiers in the dependence
      // constraints, we'll conservatively over-approximate, since we don't
      // always explicitly compute them above (in the while loop).
      if (getNumLocalIds() == 0) {
        // Work on a copy so that we don't update this constraint system.
        if (!tmpClone) {
          tmpClone.emplace(FlatAffineConstraints(*this));
          // Removing redundant inequalities is necessary so that we don't get
          // redundant loop bounds.
          tmpClone->removeRedundantInequalities();
        }
        std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
            pos, offset, num, getNumDimIds(), /*localExprs=*/{}, context);
      }

      // If the above fails, we'll just use the constant lower bound and the
      // constant upper bound (if they exist) as the slice bounds.
      // TODO: being conservative for the moment in cases that
      // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
      // fixed (b/126426796).
      if (!lbMap || lbMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice lb\n");
        auto lbConst = getConstantBound(BoundType::LB, pos + offset);
        if (lbConst.hasValue()) {
          lbMap = AffineMap::get(
              numMapDims, numMapSymbols,
              getAffineConstantExpr(lbConst.getValue(), context));
        }
      }
      if (!ubMap || ubMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice ub\n");
        auto ubConst = getConstantBound(BoundType::UB, pos + offset);
        if (ubConst.hasValue()) {
          // The constant upper bound is inclusive; slice upper bounds are
          // exclusive, hence the + 1.
          (ubMap) = AffineMap::get(
              numMapDims, numMapSymbols,
              getAffineConstantExpr(ubConst.getValue() + 1, context));
        }
      }
    }
    LLVM_DEBUG(llvm::dbgs()
               << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(lbMap.dump(););
    LLVM_DEBUG(llvm::dbgs()
               << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(ubMap.dump(););
  }
}
2154
flattenAlignedMapAndMergeLocals(AffineMap map,std::vector<SmallVector<int64_t,8>> * flattenedExprs)2155 LogicalResult FlatAffineConstraints::flattenAlignedMapAndMergeLocals(
2156 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
2157 FlatAffineConstraints localCst;
2158 if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
2159 LLVM_DEBUG(llvm::dbgs()
2160 << "composition unimplemented for semi-affine maps\n");
2161 return failure();
2162 }
2163
2164 // Add localCst information.
2165 if (localCst.getNumLocalIds() > 0) {
2166 unsigned numLocalIds = getNumLocalIds();
2167 // Insert local dims of localCst at the beginning.
2168 insertLocalId(/*pos=*/0, /*num=*/localCst.getNumLocalIds());
2169 // Insert local dims of `this` at the end of localCst.
2170 localCst.appendLocalId(/*num=*/numLocalIds);
2171 // Dimensions of localCst and this constraint set match. Append localCst to
2172 // this constraint set.
2173 append(localCst);
2174 }
2175
2176 return success();
2177 }
2178
addBound(BoundType type,unsigned pos,AffineMap boundMap)2179 LogicalResult FlatAffineConstraints::addBound(BoundType type, unsigned pos,
2180 AffineMap boundMap) {
2181 assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
2182 assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
2183 assert(pos < getNumDimAndSymbolIds() && "invalid position");
2184
2185 // Equality follows the logic of lower bound except that we add an equality
2186 // instead of an inequality.
2187 assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
2188 "single result expected");
2189 bool lower = type == BoundType::LB || type == BoundType::EQ;
2190
2191 std::vector<SmallVector<int64_t, 8>> flatExprs;
2192 if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
2193 return failure();
2194 assert(flatExprs.size() == boundMap.getNumResults());
2195
2196 // Add one (in)equality for each result.
2197 for (const auto &flatExpr : flatExprs) {
2198 SmallVector<int64_t> ineq(getNumCols(), 0);
2199 // Dims and symbols.
2200 for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
2201 ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
2202 }
2203 // Invalid bound: pos appears in `boundMap`.
2204 // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
2205 // its callers to prevent invalid bounds from being added.
2206 if (ineq[pos] != 0)
2207 continue;
2208 ineq[pos] = lower ? 1 : -1;
2209 // Local columns of `ineq` are at the beginning.
2210 unsigned j = getNumDimIds() + getNumSymbolIds();
2211 unsigned end = flatExpr.size() - 1;
2212 for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
2213 ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
2214 }
2215 // Constant term.
2216 ineq[getNumCols() - 1] =
2217 lower ? -flatExpr[flatExpr.size() - 1]
2218 // Upper bound in flattenedExpr is an exclusive one.
2219 : flatExpr[flatExpr.size() - 1] - 1;
2220 type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq);
2221 }
2222
2223 return success();
2224 }
2225
2226 AffineMap
computeAlignedMap(AffineMap map,ValueRange operands) const2227 FlatAffineValueConstraints::computeAlignedMap(AffineMap map,
2228 ValueRange operands) const {
2229 assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
2230
2231 SmallVector<Value> dims, syms;
2232 #ifndef NDEBUG
2233 SmallVector<Value> newSyms;
2234 SmallVector<Value> *newSymsPtr = &newSyms;
2235 #else
2236 SmallVector<Value> *newSymsPtr = nullptr;
2237 #endif // NDEBUG
2238
2239 dims.reserve(numDims);
2240 syms.reserve(numSymbols);
2241 for (unsigned i = 0; i < numDims; ++i)
2242 dims.push_back(values[i] ? *values[i] : Value());
2243 for (unsigned i = numDims, e = numDims + numSymbols; i < e; ++i)
2244 syms.push_back(values[i] ? *values[i] : Value());
2245
2246 AffineMap alignedMap =
2247 alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
2248 // All symbols are already part of this FlatAffineConstraints.
2249 assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
2250 assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
2251 "unexpected new/missing symbols");
2252 return alignedMap;
2253 }
2254
addBound(BoundType type,unsigned pos,AffineMap boundMap,ValueRange boundOperands)2255 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
2256 AffineMap boundMap,
2257 ValueRange boundOperands) {
2258 // Fully compose map and operands; canonicalize and simplify so that we
2259 // transitively get to terminal symbols or loop IVs.
2260 auto map = boundMap;
2261 SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
2262 fullyComposeAffineMapAndOperands(&map, &operands);
2263 map = simplifyAffineMap(map);
2264 canonicalizeMapAndOperands(&map, &operands);
2265 for (auto operand : operands)
2266 addInductionVarOrTerminalSymbol(operand);
2267 return addBound(type, pos, computeAlignedMap(map, operands));
2268 }
2269
2270 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
2271 // bounds in 'ubMaps' to each value in `values' that appears in the constraint
2272 // system. Note that both lower/upper bounds share the same operand list
2273 // 'operands'.
2274 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
2275 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
2276 // Note that both lower/upper bounds use operands from 'operands'.
2277 // Returns failure for unimplemented cases such as semi-affine expressions or
2278 // expressions with mod/floordiv.
addSliceBounds(ArrayRef<Value> values,ArrayRef<AffineMap> lbMaps,ArrayRef<AffineMap> ubMaps,ArrayRef<Value> operands)2279 LogicalResult FlatAffineValueConstraints::addSliceBounds(
2280 ArrayRef<Value> values, ArrayRef<AffineMap> lbMaps,
2281 ArrayRef<AffineMap> ubMaps, ArrayRef<Value> operands) {
2282 assert(values.size() == lbMaps.size());
2283 assert(lbMaps.size() == ubMaps.size());
2284
2285 for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
2286 unsigned pos;
2287 if (!findId(values[i], &pos))
2288 continue;
2289
2290 AffineMap lbMap = lbMaps[i];
2291 AffineMap ubMap = ubMaps[i];
2292 assert(!lbMap || lbMap.getNumInputs() == operands.size());
2293 assert(!ubMap || ubMap.getNumInputs() == operands.size());
2294
2295 // Check if this slice is just an equality along this dimension.
2296 if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
2297 ubMap.getNumResults() == 1 &&
2298 lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
2299 if (failed(addBound(BoundType::EQ, pos, lbMap, operands)))
2300 return failure();
2301 continue;
2302 }
2303
2304 // If lower or upper bound maps are null or provide no results, it implies
2305 // that the source loop was not at all sliced, and the entire loop will be a
2306 // part of the slice.
2307 if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
2308 ubMap.getNumResults() != 0) {
2309 if (failed(addBound(BoundType::LB, pos, lbMap, operands)))
2310 return failure();
2311 if (failed(addBound(BoundType::UB, pos, ubMap, operands)))
2312 return failure();
2313 } else {
2314 auto loop = getForInductionVarOwner(values[i]);
2315 if (failed(this->addAffineForOpDomain(loop)))
2316 return failure();
2317 }
2318 }
2319 return success();
2320 }
2321
addEquality(ArrayRef<int64_t> eq)2322 void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
2323 assert(eq.size() == getNumCols());
2324 unsigned row = equalities.appendExtraRow();
2325 for (unsigned i = 0, e = eq.size(); i < e; ++i)
2326 equalities(row, i) = eq[i];
2327 }
2328
addInequality(ArrayRef<int64_t> inEq)2329 void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
2330 assert(inEq.size() == getNumCols());
2331 unsigned row = inequalities.appendExtraRow();
2332 for (unsigned i = 0, e = inEq.size(); i < e; ++i)
2333 inequalities(row, i) = inEq[i];
2334 }
2335
addBound(BoundType type,unsigned pos,int64_t value)2336 void FlatAffineConstraints::addBound(BoundType type, unsigned pos,
2337 int64_t value) {
2338 assert(pos < getNumCols());
2339 if (type == BoundType::EQ) {
2340 unsigned row = equalities.appendExtraRow();
2341 equalities(row, pos) = 1;
2342 equalities(row, getNumCols() - 1) = -value;
2343 } else {
2344 unsigned row = inequalities.appendExtraRow();
2345 inequalities(row, pos) = type == BoundType::LB ? 1 : -1;
2346 inequalities(row, getNumCols() - 1) =
2347 type == BoundType::LB ? -value : value;
2348 }
2349 }
2350
addBound(BoundType type,ArrayRef<int64_t> expr,int64_t value)2351 void FlatAffineConstraints::addBound(BoundType type, ArrayRef<int64_t> expr,
2352 int64_t value) {
2353 assert(type != BoundType::EQ && "EQ not implemented");
2354 assert(expr.size() == getNumCols());
2355 unsigned row = inequalities.appendExtraRow();
2356 for (unsigned i = 0, e = expr.size(); i < e; ++i)
2357 inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i];
2358 inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) +=
2359 type == BoundType::LB ? -value : value;
2360 }
2361
2362 /// Adds a new local identifier as the floordiv of an affine function of other
2363 /// identifiers, the coefficients of which are provided in 'dividend' and with
2364 /// respect to a positive constant 'divisor'. Two constraints are added to the
2365 /// system to capture equivalence with the floordiv.
2366 /// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1.
addLocalFloorDiv(ArrayRef<int64_t> dividend,int64_t divisor)2367 void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend,
2368 int64_t divisor) {
2369 assert(dividend.size() == getNumCols() && "incorrect dividend size");
2370 assert(divisor > 0 && "positive divisor expected");
2371
2372 appendLocalId();
2373
2374 // Add two constraints for this new identifier 'q'.
2375 SmallVector<int64_t, 8> bound(dividend.size() + 1);
2376
2377 // dividend - q * divisor >= 0
2378 std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
2379 bound.begin());
2380 bound.back() = dividend.back();
2381 bound[getNumIds() - 1] = -divisor;
2382 addInequality(bound);
2383
2384 // -dividend +qdivisor * q + divisor - 1 >= 0
2385 std::transform(bound.begin(), bound.end(), bound.begin(),
2386 std::negate<int64_t>());
2387 bound[bound.size() - 1] += divisor - 1;
2388 addInequality(bound);
2389 }
2390
findId(Value val,unsigned * pos) const2391 bool FlatAffineValueConstraints::findId(Value val, unsigned *pos) const {
2392 unsigned i = 0;
2393 for (const auto &mayBeId : values) {
2394 if (mayBeId.hasValue() && mayBeId.getValue() == val) {
2395 *pos = i;
2396 return true;
2397 }
2398 i++;
2399 }
2400 return false;
2401 }
2402
containsId(Value val) const2403 bool FlatAffineValueConstraints::containsId(Value val) const {
2404 return llvm::any_of(values, [&](const Optional<Value> &mayBeId) {
2405 return mayBeId.hasValue() && mayBeId.getValue() == val;
2406 });
2407 }
2408
swapId(unsigned posA,unsigned posB)2409 void FlatAffineConstraints::swapId(unsigned posA, unsigned posB) {
2410 assert(posA < getNumIds() && "invalid position A");
2411 assert(posB < getNumIds() && "invalid position B");
2412
2413 if (posA == posB)
2414 return;
2415
2416 for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
2417 std::swap(atIneq(r, posA), atIneq(r, posB));
2418 for (unsigned r = 0, e = getNumEqualities(); r < e; r++)
2419 std::swap(atEq(r, posA), atEq(r, posB));
2420 }
2421
swapId(unsigned posA,unsigned posB)2422 void FlatAffineValueConstraints::swapId(unsigned posA, unsigned posB) {
2423 FlatAffineConstraints::swapId(posA, posB);
2424 std::swap(values[posA], values[posB]);
2425 }
2426
setDimSymbolSeparation(unsigned newSymbolCount)2427 void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
2428 assert(newSymbolCount <= numDims + numSymbols &&
2429 "invalid separation position");
2430 numDims = numDims + numSymbols - newSymbolCount;
2431 numSymbols = newSymbolCount;
2432 }
2433
addBound(BoundType type,Value val,int64_t value)2434 void FlatAffineValueConstraints::addBound(BoundType type, Value val,
2435 int64_t value) {
2436 unsigned pos;
2437 if (!findId(val, &pos))
2438 // This is a pre-condition for this method.
2439 assert(0 && "id not found");
2440 addBound(type, pos, value);
2441 }
2442
/// Removes the equality constraint stored at row 'pos' of the equality matrix.
void FlatAffineConstraints::removeEquality(unsigned pos) {
  equalities.removeRow(pos);
}
2446
/// Removes the inequality constraint stored at row 'pos' of the inequality
/// matrix.
void FlatAffineConstraints::removeInequality(unsigned pos) {
  inequalities.removeRow(pos);
}
2450
removeEqualityRange(unsigned begin,unsigned end)2451 void FlatAffineConstraints::removeEqualityRange(unsigned begin, unsigned end) {
2452 if (begin >= end)
2453 return;
2454 equalities.removeRows(begin, end - begin);
2455 }
2456
removeInequalityRange(unsigned begin,unsigned end)2457 void FlatAffineConstraints::removeInequalityRange(unsigned begin,
2458 unsigned end) {
2459 if (begin >= end)
2460 return;
2461 inequalities.removeRows(begin, end - begin);
2462 }
2463
2464 /// Finds an equality that equates the specified identifier to a constant.
2465 /// Returns the position of the equality row. If 'symbolic' is set to true,
2466 /// symbols are also treated like a constant, i.e., an affine function of the
2467 /// symbols is also treated like a constant. Returns -1 if such an equality
2468 /// could not be found.
findEqualityToConstant(const FlatAffineConstraints & cst,unsigned pos,bool symbolic=false)2469 static int findEqualityToConstant(const FlatAffineConstraints &cst,
2470 unsigned pos, bool symbolic = false) {
2471 assert(pos < cst.getNumIds() && "invalid position");
2472 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
2473 int64_t v = cst.atEq(r, pos);
2474 if (v * v != 1)
2475 continue;
2476 unsigned c;
2477 unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
2478 // This checks for zeros in all positions other than 'pos' in [0, f)
2479 for (c = 0; c < f; c++) {
2480 if (c == pos)
2481 continue;
2482 if (cst.atEq(r, c) != 0) {
2483 // Dependent on another identifier.
2484 break;
2485 }
2486 }
2487 if (c == f)
2488 // Equality is free of other identifiers.
2489 return r;
2490 }
2491 return -1;
2492 }
2493
setAndEliminate(unsigned pos,ArrayRef<int64_t> values)2494 void FlatAffineConstraints::setAndEliminate(unsigned pos,
2495 ArrayRef<int64_t> values) {
2496 if (values.empty())
2497 return;
2498 assert(pos + values.size() <= getNumIds() &&
2499 "invalid position or too many values");
2500 // Setting x_j = p in sum_i a_i x_i + c is equivalent to adding p*a_j to the
2501 // constant term and removing the id x_j. We do this for all the ids
2502 // pos, pos + 1, ... pos + values.size() - 1.
2503 for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
2504 for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
2505 atIneq(r, getNumCols() - 1) += atIneq(r, pos + i) * values[i];
2506 for (unsigned r = 0, e = getNumEqualities(); r < e; r++)
2507 for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
2508 atEq(r, getNumCols() - 1) += atEq(r, pos + i) * values[i];
2509 removeIdRange(pos, pos + values.size());
2510 }
2511
constantFoldId(unsigned pos)2512 LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) {
2513 assert(pos < getNumIds() && "invalid position");
2514 int rowIdx;
2515 if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
2516 return failure();
2517
2518 // atEq(rowIdx, pos) is either -1 or 1.
2519 assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
2520 int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
2521 setAndEliminate(pos, constVal);
2522 return success();
2523 }
2524
constantFoldIdRange(unsigned pos,unsigned num)2525 void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
2526 for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
2527 if (failed(constantFoldId(t)))
2528 t++;
2529 }
2530 }
2531
/// Returns a non-negative constant bound on the extent (upper bound - lower
/// bound) of the specified identifier if it is found to be a constant; returns
/// None if it's not a constant. This method treats symbolic identifiers
/// specially, i.e., it looks for constant differences between affine
/// expressions involving only the symbolic identifiers. See the examples
/// below. 'lb', if provided, is set to the lower bound associated with the
/// constant difference, and 'ub', if provided, to the corresponding upper
/// bound. Note that 'lb' and 'ub' are purely symbolic and thus contain the
/// coefficients of the symbolic identifiers followed by the constant
/// coefficient. '*boundFloorDivisor' receives the divisor by which 'lb'/'ub'
/// are to be floor-divided; it must be non-null whenever 'lb' is non-null.
/// '*minLbPos'/'*minUbPos', if provided, receive the row positions of the
/// constraints (the equality row, when the identifier is fixed by an equality)
/// that yield the returned bound; they are only meaningful when a bound is
/// returned.
// Egs: 0 <= i <= 15, returns 16.
// s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
// ceil((s0 - 7) / 8) = floor(s0 / 8)).
Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
    unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
    SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
    unsigned *minUbPos) const {
  assert(pos < getNumDimIds() && "Invalid identifier position");

  // Find an equality for 'pos'^th identifier that equates it to some function
  // of the symbolic identifiers (+ constant).
  int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
  if (eqPos != -1) {
    auto eq = getEquality(eqPos);
    // If the equality involves a local var, punt for now.
    // TODO: this can be handled in the future by using the explicit
    // representation of the local vars.
    if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
                     [](int64_t coeff) { return coeff == 0; }))
      return None;

    // This identifier can only take a single value.
    if (lb) {
      // Set lb to that symbolic value.
      lb->resize(getNumSymbolIds() + 1);
      if (ub)
        ub->resize(getNumSymbolIds() + 1);
      // Solve 'v * x_pos + sum_s a_s * s + c = 0' for x_pos, one symbol (and
      // finally the constant) coefficient at a time.
      for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
        int64_t v = atEq(eqPos, pos);
        // atEq(eqRow, pos) is either -1 or 1.
        assert(v * v == 1);
        (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v
                         : -atEq(eqPos, getNumDimIds() + c) / v;
        // Since this is an equality, ub = lb.
        if (ub)
          (*ub)[c] = (*lb)[c];
      }
      assert(boundFloorDivisor &&
             "both lb and divisor or none should be provided");
      *boundFloorDivisor = 1;
    }
    if (minLbPos)
      *minLbPos = eqPos;
    if (minUbPos)
      *minUbPos = eqPos;
    return 1;
  }

  // Check if the identifier appears at all in any of the inequalities.
  unsigned r, e;
  for (r = 0, e = getNumInequalities(); r < e; r++) {
    if (atIneq(r, pos) != 0)
      break;
  }
  if (r == e)
    // If it doesn't, there isn't a bound on it.
    return None;

  // Positions of constraints that are lower/upper bounds on the variable.
  SmallVector<unsigned, 4> lbIndices, ubIndices;

  // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
  // the bounds can only involve symbolic (and local) identifiers. Since the
  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
  getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
                               /*eqIndices=*/nullptr, /*offset=*/0,
                               /*num=*/getNumDimIds());

  Optional<int64_t> minDiff = None;
  unsigned minLbPosition = 0, minUbPosition = 0;
  for (auto ubPos : ubIndices) {
    for (auto lbPos : lbIndices) {
      // Look for a lower bound and an upper bound that only differ by a
      // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst.
      // For example, if ii is the pos^th variable, we are looking for
      // constraints like ii >= i, ii <= i + 50, 50 being the difference. The
      // minimum among all such constant differences is kept since that's the
      // constant bounding the extent of the pos^th variable.
      unsigned j, e;
      for (j = 0, e = getNumCols() - 1; j < e; j++)
        if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
          break;
        }
      if (j < getNumCols() - 1)
        continue;
      // With the lower bound 'k*x + g + l0 >= 0' and the upper bound
      // '-k*x - g + u0 >= 0', k*x ranges over an interval of length l0 + u0,
      // so x takes at most ceil((l0 + u0 + 1) / k) integer values.
      int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
                                 atIneq(lbPos, getNumCols() - 1) + 1,
                             atIneq(lbPos, pos));
      // This bound is non-negative by definition.
      diff = std::max<int64_t>(diff, 0);
      if (minDiff == None || diff < minDiff) {
        minDiff = diff;
        minLbPosition = lbPos;
        minUbPosition = ubPos;
      }
    }
  }
  if (lb && minDiff.hasValue()) {
    // Set lb to the symbolic lower bound.
    lb->resize(getNumSymbolIds() + 1);
    if (ub)
      ub->resize(getNumSymbolIds() + 1);
    // The lower bound is the ceildiv of the lb constraint over the coefficient
    // of the variable at 'pos'. We express the ceildiv equivalently as a floor
    // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
    // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
    // Note: 'boundFloorDivisor' is dereferenced here without a null check; it
    // must be provided whenever 'lb' is.
    *boundFloorDivisor = atIneq(minLbPosition, pos);
    assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
    for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
      (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
    }
    if (ub) {
      for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
        (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
    }
    // The lower bound leads to a ceildiv while the upper bound is a floordiv
    // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
    // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
    // the constant term for the lower bound.
    (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
  }
  if (minLbPos)
    *minLbPos = minLbPosition;
  if (minUbPos)
    *minUbPos = minUbPosition;
  return minDiff;
}
2671
2672 template <bool isLower>
2673 Optional<int64_t>
computeConstantLowerOrUpperBound(unsigned pos)2674 FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
2675 assert(pos < getNumIds() && "invalid position");
2676 // Project to 'pos'.
2677 projectOut(0, pos);
2678 projectOut(1, getNumIds() - 1);
2679 // Check if there's an equality equating the '0'^th identifier to a constant.
2680 int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
2681 if (eqRowIdx != -1)
2682 // atEq(rowIdx, 0) is either -1 or 1.
2683 return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
2684
2685 // Check if the identifier appears at all in any of the inequalities.
2686 unsigned r, e;
2687 for (r = 0, e = getNumInequalities(); r < e; r++) {
2688 if (atIneq(r, 0) != 0)
2689 break;
2690 }
2691 if (r == e)
2692 // If it doesn't, there isn't a bound on it.
2693 return None;
2694
2695 Optional<int64_t> minOrMaxConst = None;
2696
2697 // Take the max across all const lower bounds (or min across all constant
2698 // upper bounds).
2699 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2700 if (isLower) {
2701 if (atIneq(r, 0) <= 0)
2702 // Not a lower bound.
2703 continue;
2704 } else if (atIneq(r, 0) >= 0) {
2705 // Not an upper bound.
2706 continue;
2707 }
2708 unsigned c, f;
2709 for (c = 0, f = getNumCols() - 1; c < f; c++)
2710 if (c != 0 && atIneq(r, c) != 0)
2711 break;
2712 if (c < getNumCols() - 1)
2713 // Not a constant bound.
2714 continue;
2715
2716 int64_t boundConst =
2717 isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
2718 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
2719 if (isLower) {
2720 if (minOrMaxConst == None || boundConst > minOrMaxConst)
2721 minOrMaxConst = boundConst;
2722 } else {
2723 if (minOrMaxConst == None || boundConst < minOrMaxConst)
2724 minOrMaxConst = boundConst;
2725 }
2726 }
2727 return minOrMaxConst;
2728 }
2729
getConstantBound(BoundType type,unsigned pos) const2730 Optional<int64_t> FlatAffineConstraints::getConstantBound(BoundType type,
2731 unsigned pos) const {
2732 assert(type != BoundType::EQ && "EQ not implemented");
2733 FlatAffineConstraints tmpCst(*this);
2734 if (type == BoundType::LB)
2735 return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
2736 return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
2737 }
2738
2739 // A simple (naive and conservative) check for hyper-rectangularity.
isHyperRectangular(unsigned pos,unsigned num) const2740 bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
2741 unsigned num) const {
2742 assert(pos < getNumCols() - 1);
2743 // Check for two non-zero coefficients in the range [pos, pos + sum).
2744 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2745 unsigned sum = 0;
2746 for (unsigned c = pos; c < pos + num; c++) {
2747 if (atIneq(r, c) != 0)
2748 sum++;
2749 }
2750 if (sum > 1)
2751 return false;
2752 }
2753 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2754 unsigned sum = 0;
2755 for (unsigned c = pos; c < pos + num; c++) {
2756 if (atEq(r, c) != 0)
2757 sum++;
2758 }
2759 if (sum > 1)
2760 return false;
2761 }
2762 return true;
2763 }
2764
print(raw_ostream & os) const2765 void FlatAffineConstraints::print(raw_ostream &os) const {
2766 assert(hasConsistentState());
2767 os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds()
2768 << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints()
2769 << " constraints)\n";
2770 os << "(";
2771 for (unsigned i = 0, e = getNumIds(); i < e; i++) {
2772 if (auto *valueCstr = dyn_cast<const FlatAffineValueConstraints>(this)) {
2773 if (valueCstr->hasValue(i))
2774 os << "Value ";
2775 else
2776 os << "None ";
2777 } else {
2778 os << "None ";
2779 }
2780 }
2781 os << " const)\n";
2782 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
2783 for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2784 os << atEq(i, j) << " ";
2785 }
2786 os << "= 0\n";
2787 }
2788 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2789 for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2790 os << atIneq(i, j) << " ";
2791 }
2792 os << ">= 0\n";
2793 }
2794 os << '\n';
2795 }
2796
/// Prints the constraint system to llvm::errs() (used from LLVM_DEBUG).
void FlatAffineConstraints::dump() const { print(llvm::errs()); }
2798
2799 /// Removes duplicate constraints, trivially true constraints, and constraints
2800 /// that can be detected as redundant as a result of differing only in their
2801 /// constant term part. A constraint of the form <non-negative constant> >= 0 is
2802 /// considered trivially true.
2803 // Uses a DenseSet to hash and detect duplicates followed by a linear scan to
2804 // remove duplicates in place.
removeTrivialRedundancy()2805 void FlatAffineConstraints::removeTrivialRedundancy() {
2806 gcdTightenInequalities();
2807 normalizeConstraintsByGCD();
2808
2809 // A map used to detect redundancy stemming from constraints that only differ
2810 // in their constant term. The value stored is <row position, const term>
2811 // for a given row.
2812 SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
2813 rowsWithoutConstTerm;
2814 // To unique rows.
2815 SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
2816
2817 // Check if constraint is of the form <non-negative-constant> >= 0.
2818 auto isTriviallyValid = [&](unsigned r) -> bool {
2819 for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
2820 if (atIneq(r, c) != 0)
2821 return false;
2822 }
2823 return atIneq(r, getNumCols() - 1) >= 0;
2824 };
2825
2826 // Detect and mark redundant constraints.
2827 SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
2828 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2829 int64_t *rowStart = &inequalities(r, 0);
2830 auto row = ArrayRef<int64_t>(rowStart, getNumCols());
2831 if (isTriviallyValid(r) || !rowSet.insert(row).second) {
2832 redunIneq[r] = true;
2833 continue;
2834 }
2835
2836 // Among constraints that only differ in the constant term part, mark
2837 // everything other than the one with the smallest constant term redundant.
2838 // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
2839 // former two are redundant).
2840 int64_t constTerm = atIneq(r, getNumCols() - 1);
2841 auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
2842 const auto &ret =
2843 rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
2844 if (!ret.second) {
2845 // Check if the other constraint has a higher constant term.
2846 auto &val = ret.first->second;
2847 if (val.second > constTerm) {
2848 // The stored row is redundant. Mark it so, and update with this one.
2849 redunIneq[val.first] = true;
2850 val = {r, constTerm};
2851 } else {
2852 // The one stored makes this one redundant.
2853 redunIneq[r] = true;
2854 }
2855 }
2856 }
2857
2858 // Scan to get rid of all rows marked redundant, in-place.
2859 unsigned pos = 0;
2860 for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
2861 if (!redunIneq[r])
2862 inequalities.copyRow(r, pos++);
2863
2864 inequalities.resizeVertically(pos);
2865
2866 // TODO: consider doing this for equalities as well, but probably not worth
2867 // the savings.
2868 }
2869
clearAndCopyFrom(const FlatAffineConstraints & other)2870 void FlatAffineConstraints::clearAndCopyFrom(
2871 const FlatAffineConstraints &other) {
2872 if (auto *otherValueSet = dyn_cast<const FlatAffineValueConstraints>(&other))
2873 assert(!otherValueSet->hasValues() &&
2874 "cannot copy associated Values into FlatAffineConstraints");
2875 // Note: Assigment operator does not vtable pointer, so kind does not change.
2876 *this = other;
2877 }
2878
clearAndCopyFrom(const FlatAffineConstraints & other)2879 void FlatAffineValueConstraints::clearAndCopyFrom(
2880 const FlatAffineConstraints &other) {
2881 if (auto *otherValueSet =
2882 dyn_cast<const FlatAffineValueConstraints>(&other)) {
2883 *this = *otherValueSet;
2884 } else {
2885 *static_cast<FlatAffineConstraints *>(this) = other;
2886 values.clear();
2887 values.resize(numIds, None);
2888 }
2889 }
2890
/// Removes the identifier at position 'pos' by delegating to removeIdRange on
/// the single-element range [pos, pos + 1).
void FlatAffineConstraints::removeId(unsigned pos) {
  removeIdRange(pos, pos + 1);
}
2894
2895 static std::pair<unsigned, unsigned>
getNewNumDimsSymbols(unsigned pos,const FlatAffineConstraints & cst)2896 getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
2897 unsigned numDims = cst.getNumDimIds();
2898 unsigned numSymbols = cst.getNumSymbolIds();
2899 unsigned newNumDims, newNumSymbols;
2900 if (pos < numDims) {
2901 newNumDims = numDims - 1;
2902 newNumSymbols = numSymbols;
2903 } else if (pos < numDims + numSymbols) {
2904 assert(numSymbols >= 1);
2905 newNumDims = numDims;
2906 newNumSymbols = numSymbols - 1;
2907 } else {
2908 newNumDims = numDims;
2909 newNumSymbols = numSymbols;
2910 }
2911 return {newNumDims, newNumSymbols};
2912 }
2913
2914 #undef DEBUG_TYPE
2915 #define DEBUG_TYPE "fm"
2916
2917 /// Eliminates identifier at the specified position using Fourier-Motzkin
2918 /// variable elimination. This technique is exact for rational spaces but
2919 /// conservative (in "rare" cases) for integer spaces. The operation corresponds
2920 /// to a projection operation yielding the (convex) set of integer points
2921 /// contained in the rational shadow of the set. An emptiness test that relies
2922 /// on this method will guarantee emptiness, i.e., it disproves the existence of
2923 /// a solution if it says it's empty.
2924 /// If a non-null isResultIntegerExact is passed, it is set to true if the
2925 /// result is also integer exact. If it's set to false, the obtained solution
2926 /// *may* not be exact, i.e., it may contain integer points that do not have an
2927 /// integer pre-image in the original set.
2928 ///
2929 /// Eg:
2930 /// j >= 0, j <= i + 1
2931 /// i >= 0, i <= N + 1
2932 /// Eliminating i yields,
2933 /// j >= 0, 0 <= N + 1, j - 1 <= N + 1
2934 ///
2935 /// If darkShadow = true, this method computes the dark shadow on elimination;
2936 /// the dark shadow is a convex integer subset of the exact integer shadow. A
2937 /// non-empty dark shadow proves the existence of an integer solution. The
2938 /// elimination in such a case could however be an under-approximation, and thus
2939 /// should not be used for scanning sets or used by itself for dependence
2940 /// checking.
2941 ///
2942 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
2943 /// ^
2944 /// |
2945 /// | * * * * o o
2946 /// i | * * o o o o
2947 /// | o * * * * *
2948 /// --------------->
2949 /// j ->
2950 ///
2951 /// Eliminating i from this system (projecting on the j dimension):
2952 /// rational shadow / integer light shadow: 1 <= j <= 6
2953 /// dark shadow: 3 <= j <= 6
2954 /// exact integer shadow: j = 1 \union 3 <= j <= 6
2955 /// holes/splinters: j = 2
2956 ///
2957 /// darkShadow = false, isResultIntegerExact = nullptr are default values.
2958 // TODO: a slight modification to yield dark shadow version of FM (tightened),
2959 // which can prove the existence of a solution if there is one.
fourierMotzkinEliminate(unsigned pos,bool darkShadow,bool * isResultIntegerExact)2960 void FlatAffineConstraints::fourierMotzkinEliminate(
2961 unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
2962 LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
2963 LLVM_DEBUG(dump());
2964 assert(pos < getNumIds() && "invalid position");
2965 assert(hasConsistentState());
2966
2967 // Check if this identifier can be eliminated through a substitution.
2968 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2969 if (atEq(r, pos) != 0) {
2970 // Use Gaussian elimination here (since we have an equality).
2971 LogicalResult ret = gaussianEliminateId(pos);
2972 (void)ret;
2973 assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
2974 LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
2975 LLVM_DEBUG(dump());
2976 return;
2977 }
2978 }
2979
2980 // A fast linear time tightening.
2981 gcdTightenInequalities();
2982
2983 // Check if the identifier appears at all in any of the inequalities.
2984 unsigned r, e;
2985 for (r = 0, e = getNumInequalities(); r < e; r++) {
2986 if (atIneq(r, pos) != 0)
2987 break;
2988 }
2989 if (r == getNumInequalities()) {
2990 // If it doesn't appear, just remove the column and return.
2991 // TODO: refactor removeColumns to use it from here.
2992 removeId(pos);
2993 LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2994 LLVM_DEBUG(dump());
2995 return;
2996 }
2997
2998 // Positions of constraints that are lower bounds on the variable.
2999 SmallVector<unsigned, 4> lbIndices;
3000 // Positions of constraints that are lower bounds on the variable.
3001 SmallVector<unsigned, 4> ubIndices;
3002 // Positions of constraints that do not involve the variable.
3003 std::vector<unsigned> nbIndices;
3004 nbIndices.reserve(getNumInequalities());
3005
3006 // Gather all lower bounds and upper bounds of the variable. Since the
3007 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
3008 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
3009 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
3010 if (atIneq(r, pos) == 0) {
3011 // Id does not appear in bound.
3012 nbIndices.push_back(r);
3013 } else if (atIneq(r, pos) >= 1) {
3014 // Lower bound.
3015 lbIndices.push_back(r);
3016 } else {
3017 // Upper bound.
3018 ubIndices.push_back(r);
3019 }
3020 }
3021
3022 // Set the number of dimensions, symbols in the resulting system.
3023 const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
3024 unsigned newNumDims = dimsSymbols.first;
3025 unsigned newNumSymbols = dimsSymbols.second;
3026
3027 /// Create the new system which has one identifier less.
3028 FlatAffineConstraints newFac(
3029 lbIndices.size() * ubIndices.size() + nbIndices.size(),
3030 getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
3031 /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols);
3032
3033 // This will be used to check if the elimination was integer exact.
3034 unsigned lcmProducts = 1;
3035
3036 // Let x be the variable we are eliminating.
3037 // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
3038 // that c_l, c_u >= 1) we have:
3039 // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
3040 // We thus generate a constraint:
3041 // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
3042 // Note if c_l = c_u = 1, all integer points captured by the resulting
3043 // constraint correspond to integer points in the original system (i.e., they
3044 // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
3045 // integer exact.
3046 for (auto ubPos : ubIndices) {
3047 for (auto lbPos : lbIndices) {
3048 SmallVector<int64_t, 4> ineq;
3049 ineq.reserve(newFac.getNumCols());
3050 int64_t lbCoeff = atIneq(lbPos, pos);
3051 // Note that in the comments above, ubCoeff is the negation of the
3052 // coefficient in the canonical form as the view taken here is that of the
3053 // term being moved to the other size of '>='.
3054 int64_t ubCoeff = -atIneq(ubPos, pos);
3055 // TODO: refactor this loop to avoid all branches inside.
3056 for (unsigned l = 0, e = getNumCols(); l < e; l++) {
3057 if (l == pos)
3058 continue;
3059 assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
3060 int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
3061 ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
3062 atIneq(lbPos, l) * (lcm / lbCoeff));
3063 lcmProducts *= lcm;
3064 }
3065 if (darkShadow) {
3066 // The dark shadow is a convex subset of the exact integer shadow. If
3067 // there is a point here, it proves the existence of a solution.
3068 ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
3069 }
3070 // TODO: we need to have a way to add inequalities in-place in
3071 // FlatAffineConstraints instead of creating and copying over.
3072 newFac.addInequality(ineq);
3073 }
3074 }
3075
3076 LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
3077 << "\n");
3078 if (lcmProducts == 1 && isResultIntegerExact)
3079 *isResultIntegerExact = true;
3080
3081 // Copy over the constraints not involving this variable.
3082 for (auto nbPos : nbIndices) {
3083 SmallVector<int64_t, 4> ineq;
3084 ineq.reserve(getNumCols() - 1);
3085 for (unsigned l = 0, e = getNumCols(); l < e; l++) {
3086 if (l == pos)
3087 continue;
3088 ineq.push_back(atIneq(nbPos, l));
3089 }
3090 newFac.addInequality(ineq);
3091 }
3092
3093 assert(newFac.getNumConstraints() ==
3094 lbIndices.size() * ubIndices.size() + nbIndices.size());
3095
3096 // Copy over the equalities.
3097 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
3098 SmallVector<int64_t, 4> eq;
3099 eq.reserve(newFac.getNumCols());
3100 for (unsigned l = 0, e = getNumCols(); l < e; l++) {
3101 if (l == pos)
3102 continue;
3103 eq.push_back(atEq(r, l));
3104 }
3105 newFac.addEquality(eq);
3106 }
3107
3108 // GCD tightening and normalization allows detection of more trivially
3109 // redundant constraints.
3110 newFac.gcdTightenInequalities();
3111 newFac.normalizeConstraintsByGCD();
3112 newFac.removeTrivialRedundancy();
3113 clearAndCopyFrom(newFac);
3114 LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
3115 LLVM_DEBUG(dump());
3116 }
3117
3118 #undef DEBUG_TYPE
3119 #define DEBUG_TYPE "affine-structures"
3120
fourierMotzkinEliminate(unsigned pos,bool darkShadow,bool * isResultIntegerExact)3121 void FlatAffineValueConstraints::fourierMotzkinEliminate(
3122 unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
3123 SmallVector<Optional<Value>, 8> newVals;
3124 newVals.reserve(numIds - 1);
3125 newVals.append(values.begin(), values.begin() + pos);
3126 newVals.append(values.begin() + pos + 1, values.end());
3127 // Note: Base implementation discards all associated Values.
3128 FlatAffineConstraints::fourierMotzkinEliminate(pos, darkShadow,
3129 isResultIntegerExact);
3130 values = newVals;
3131 assert(values.size() == getNumIds());
3132 }
3133
projectOut(unsigned pos,unsigned num)3134 void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
3135 if (num == 0)
3136 return;
3137
3138 // 'pos' can be at most getNumCols() - 2 if num > 0.
3139 assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
3140 assert(pos + num < getNumCols() && "invalid range");
3141
3142 // Eliminate as many identifiers as possible using Gaussian elimination.
3143 unsigned currentPos = pos;
3144 unsigned numToEliminate = num;
3145 unsigned numGaussianEliminated = 0;
3146
3147 while (currentPos < getNumIds()) {
3148 unsigned curNumEliminated =
3149 gaussianEliminateIds(currentPos, currentPos + numToEliminate);
3150 ++currentPos;
3151 numToEliminate -= curNumEliminated + 1;
3152 numGaussianEliminated += curNumEliminated;
3153 }
3154
3155 // Eliminate the remaining using Fourier-Motzkin.
3156 for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
3157 unsigned numToEliminate = num - numGaussianEliminated - i;
3158 fourierMotzkinEliminate(
3159 getBestIdToEliminate(*this, pos, pos + numToEliminate));
3160 }
3161
3162 // Fast/trivial simplifications.
3163 gcdTightenInequalities();
3164 // Normalize constraints after tightening since the latter impacts this, but
3165 // not the other way round.
3166 normalizeConstraintsByGCD();
3167 }
3168
projectOut(Value val)3169 void FlatAffineValueConstraints::projectOut(Value val) {
3170 unsigned pos;
3171 bool ret = findId(val, &pos);
3172 assert(ret);
3173 (void)ret;
3174 fourierMotzkinEliminate(pos);
3175 }
3176
clearConstraints()3177 void FlatAffineConstraints::clearConstraints() {
3178 equalities.resizeVertically(0);
3179 inequalities.resizeVertically(0);
3180 }
3181
3182 namespace {
3183
3184 enum BoundCmpResult { Greater, Less, Equal, Unknown };
3185
3186 /// Compares two affine bounds whose coefficients are provided in 'first' and
3187 /// 'second'. The last coefficient is the constant term.
compareBounds(ArrayRef<int64_t> a,ArrayRef<int64_t> b)3188 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
3189 assert(a.size() == b.size());
3190
3191 // For the bounds to be comparable, their corresponding identifier
3192 // coefficients should be equal; the constant terms are then compared to
3193 // determine less/greater/equal.
3194
3195 if (!std::equal(a.begin(), a.end() - 1, b.begin()))
3196 return Unknown;
3197
3198 if (a.back() == b.back())
3199 return Equal;
3200
3201 return a.back() < b.back() ? Less : Greater;
3202 }
3203 } // namespace
3204
3205 // Returns constraints that are common to both A & B.
getCommonConstraints(const FlatAffineConstraints & a,const FlatAffineConstraints & b,FlatAffineConstraints & c)3206 static void getCommonConstraints(const FlatAffineConstraints &a,
3207 const FlatAffineConstraints &b,
3208 FlatAffineConstraints &c) {
3209 c.reset(a.getNumDimIds(), a.getNumSymbolIds(), a.getNumLocalIds());
3210 // a naive O(n^2) check should be enough here given the input sizes.
3211 for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) {
3212 for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) {
3213 if (a.getInequality(r) == b.getInequality(s)) {
3214 c.addInequality(a.getInequality(r));
3215 break;
3216 }
3217 }
3218 }
3219 for (unsigned r = 0, e = a.getNumEqualities(); r < e; ++r) {
3220 for (unsigned s = 0, f = b.getNumEqualities(); s < f; ++s) {
3221 if (a.getEquality(r) == b.getEquality(s)) {
3222 c.addEquality(a.getEquality(r));
3223 break;
3224 }
3225 }
3226 }
3227 }
3228
3229 // Computes the bounding box with respect to 'other' by finding the min of the
3230 // lower bounds and the max of the upper bounds along each of the dimensions.
3231 LogicalResult
unionBoundingBox(const FlatAffineConstraints & otherCst)3232 FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
3233 assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
3234 assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
3235 assert(getNumLocalIds() == 0 && "local ids not supported yet here");
3236
3237 // Get the constraints common to both systems; these will be added as is to
3238 // the union.
3239 FlatAffineConstraints commonCst;
3240 getCommonConstraints(*this, otherCst, commonCst);
3241
3242 std::vector<SmallVector<int64_t, 8>> boundingLbs;
3243 std::vector<SmallVector<int64_t, 8>> boundingUbs;
3244 boundingLbs.reserve(2 * getNumDimIds());
3245 boundingUbs.reserve(2 * getNumDimIds());
3246
3247 // To hold lower and upper bounds for each dimension.
3248 SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
3249 // To compute min of lower bounds and max of upper bounds for each dimension.
3250 SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
3251 SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
3252 // To compute final new lower and upper bounds for the union.
3253 SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
3254
3255 int64_t lbFloorDivisor, otherLbFloorDivisor;
3256 for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
3257 auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
3258 if (!extent.hasValue())
3259 // TODO: symbolic extents when necessary.
3260 // TODO: handle union if a dimension is unbounded.
3261 return failure();
3262
3263 auto otherExtent = otherCst.getConstantBoundOnDimSize(
3264 d, &otherLb, &otherLbFloorDivisor, &otherUb);
3265 if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
3266 // TODO: symbolic extents when necessary.
3267 return failure();
3268
3269 assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
3270
3271 auto res = compareBounds(lb, otherLb);
3272 // Identify min.
3273 if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
3274 minLb = lb;
3275 // Since the divisor is for a floordiv, we need to convert to ceildiv,
3276 // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
3277 // div * i >= expr - div + 1.
3278 minLb.back() -= lbFloorDivisor - 1;
3279 } else if (res == BoundCmpResult::Greater) {
3280 minLb = otherLb;
3281 minLb.back() -= otherLbFloorDivisor - 1;
3282 } else {
3283 // Uncomparable - check for constant lower/upper bounds.
3284 auto constLb = getConstantBound(BoundType::LB, d);
3285 auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
3286 if (!constLb.hasValue() || !constOtherLb.hasValue())
3287 return failure();
3288 std::fill(minLb.begin(), minLb.end(), 0);
3289 minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
3290 }
3291
3292 // Do the same for ub's but max of upper bounds. Identify max.
3293 auto uRes = compareBounds(ub, otherUb);
3294 if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
3295 maxUb = ub;
3296 } else if (uRes == BoundCmpResult::Less) {
3297 maxUb = otherUb;
3298 } else {
3299 // Uncomparable - check for constant lower/upper bounds.
3300 auto constUb = getConstantBound(BoundType::UB, d);
3301 auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
3302 if (!constUb.hasValue() || !constOtherUb.hasValue())
3303 return failure();
3304 std::fill(maxUb.begin(), maxUb.end(), 0);
3305 maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
3306 }
3307
3308 std::fill(newLb.begin(), newLb.end(), 0);
3309 std::fill(newUb.begin(), newUb.end(), 0);
3310
3311 // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
3312 // and so it's the divisor for newLb and newUb as well.
3313 newLb[d] = lbFloorDivisor;
3314 newUb[d] = -lbFloorDivisor;
3315 // Copy over the symbolic part + constant term.
3316 std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
3317 std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
3318 newLb.begin() + getNumDimIds(), std::negate<int64_t>());
3319 std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());
3320
3321 boundingLbs.push_back(newLb);
3322 boundingUbs.push_back(newUb);
3323 }
3324
3325 // Clear all constraints and add the lower/upper bounds for the bounding box.
3326 clearConstraints();
3327 for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
3328 addInequality(boundingLbs[d]);
3329 addInequality(boundingUbs[d]);
3330 }
3331
3332 // Add the constraints that were common to both systems.
3333 append(commonCst);
3334 removeTrivialRedundancy();
3335
3336 // TODO: copy over pure symbolic constraints from this and 'other' over to the
3337 // union (since the above are just the union along dimensions); we shouldn't
3338 // be discarding any other constraints on the symbols.
3339
3340 return success();
3341 }
3342
unionBoundingBox(const FlatAffineValueConstraints & otherCst)3343 LogicalResult FlatAffineValueConstraints::unionBoundingBox(
3344 const FlatAffineValueConstraints &otherCst) {
3345 assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
3346 assert(otherCst.getMaybeValues()
3347 .slice(0, getNumDimIds())
3348 .equals(getMaybeValues().slice(0, getNumDimIds())) &&
3349 "dim values mismatch");
3350 assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
3351 assert(getNumLocalIds() == 0 && "local ids not supported yet here");
3352
3353 // Align `other` to this.
3354 if (!areIdsAligned(*this, otherCst)) {
3355 FlatAffineValueConstraints otherCopy(otherCst);
3356 mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy);
3357 return FlatAffineConstraints::unionBoundingBox(otherCopy);
3358 }
3359
3360 return FlatAffineConstraints::unionBoundingBox(otherCst);
3361 }
3362
3363 /// Compute an explicit representation for local vars. For all systems coming
3364 /// from MLIR integer sets, maps, or expressions where local vars were
3365 /// introduced to model floordivs and mods, this always succeeds.
computeLocalVars(const FlatAffineConstraints & cst,SmallVectorImpl<AffineExpr> & memo,MLIRContext * context)3366 static LogicalResult computeLocalVars(const FlatAffineConstraints &cst,
3367 SmallVectorImpl<AffineExpr> &memo,
3368 MLIRContext *context) {
3369 unsigned numDims = cst.getNumDimIds();
3370 unsigned numSyms = cst.getNumSymbolIds();
3371
3372 // Initialize dimensional and symbolic identifiers.
3373 for (unsigned i = 0; i < numDims; i++)
3374 memo[i] = getAffineDimExpr(i, context);
3375 for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
3376 memo[i] = getAffineSymbolExpr(i - numDims, context);
3377
3378 bool changed;
3379 do {
3380 // Each time `changed` is true at the end of this iteration, one or more
3381 // local vars would have been detected as floordivs and set in memo; so the
3382 // number of null entries in memo[...] strictly reduces; so this converges.
3383 changed = false;
3384 for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i)
3385 if (!memo[numDims + numSyms + i] &&
3386 detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
3387 changed = true;
3388 } while (changed);
3389
3390 ArrayRef<AffineExpr> localExprs =
3391 ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds());
3392 return success(
3393 llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
3394 }
3395
getIneqAsAffineValueMap(unsigned pos,unsigned ineqPos,AffineValueMap & vmap,MLIRContext * context) const3396 void FlatAffineValueConstraints::getIneqAsAffineValueMap(
3397 unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
3398 MLIRContext *context) const {
3399 unsigned numDims = getNumDimIds();
3400 unsigned numSyms = getNumSymbolIds();
3401
3402 assert(pos < numDims && "invalid position");
3403 assert(ineqPos < getNumInequalities() && "invalid inequality position");
3404
3405 // Get expressions for local vars.
3406 SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
3407 if (failed(computeLocalVars(*this, memo, context)))
3408 assert(false &&
3409 "one or more local exprs do not have an explicit representation");
3410 auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
3411
3412 // Compute the AffineExpr lower/upper bound for this inequality.
3413 ArrayRef<int64_t> inequality = getInequality(ineqPos);
3414 SmallVector<int64_t, 8> bound;
3415 bound.reserve(getNumCols() - 1);
3416 // Everything other than the coefficient at `pos`.
3417 bound.append(inequality.begin(), inequality.begin() + pos);
3418 bound.append(inequality.begin() + pos + 1, inequality.end());
3419
3420 if (inequality[pos] > 0)
3421 // Lower bound.
3422 std::transform(bound.begin(), bound.end(), bound.begin(),
3423 std::negate<int64_t>());
3424 else
3425 // Upper bound (which is exclusive).
3426 bound.back() += 1;
3427
3428 // Convert to AffineExpr (tree) form.
3429 auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
3430 localExprs, context);
3431
3432 // Get the values to bind to this affine expr (all dims and symbols).
3433 SmallVector<Value, 4> operands;
3434 getValues(0, pos, &operands);
3435 SmallVector<Value, 4> trailingOperands;
3436 getValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
3437 operands.append(trailingOperands.begin(), trailingOperands.end());
3438 vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
3439 }
3440
3441 /// Returns true if the pos^th column is all zero for both inequalities and
3442 /// equalities..
isColZero(const FlatAffineConstraints & cst,unsigned pos)3443 static bool isColZero(const FlatAffineConstraints &cst, unsigned pos) {
3444 unsigned rowPos;
3445 return !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/false, &rowPos) &&
3446 !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/true, &rowPos);
3447 }
3448
getAsIntegerSet(MLIRContext * context) const3449 IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
3450 if (getNumConstraints() == 0)
3451 // Return universal set (always true): 0 == 0.
3452 return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
3453 getAffineConstantExpr(/*constant=*/0, context),
3454 /*eqFlags=*/true);
3455
3456 // Construct local references.
3457 SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
3458
3459 if (failed(computeLocalVars(*this, memo, context))) {
3460 // Check if the local variables without an explicit representation have
3461 // zero coefficients everywhere.
3462 for (unsigned i = getNumDimAndSymbolIds(), e = getNumIds(); i < e; ++i) {
3463 if (!memo[i] && !isColZero(*this, /*pos=*/i)) {
3464 LLVM_DEBUG(llvm::dbgs() << "one or more local exprs do not have an "
3465 "explicit representation");
3466 return IntegerSet();
3467 }
3468 }
3469 }
3470
3471 ArrayRef<AffineExpr> localExprs =
3472 ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
3473
3474 // Construct the IntegerSet from the equalities/inequalities.
3475 unsigned numDims = getNumDimIds();
3476 unsigned numSyms = getNumSymbolIds();
3477
3478 SmallVector<bool, 16> eqFlags(getNumConstraints());
3479 std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
3480 std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
3481
3482 SmallVector<AffineExpr, 8> exprs;
3483 exprs.reserve(getNumConstraints());
3484
3485 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
3486 exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
3487 localExprs, context));
3488 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
3489 exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
3490 numSyms, localExprs, context));
3491 return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
3492 }
3493
3494 /// Find positions of inequalities and equalities that do not have a coefficient
3495 /// for [pos, pos + num) identifiers.
getIndependentConstraints(const FlatAffineConstraints & cst,unsigned pos,unsigned num,SmallVectorImpl<unsigned> & nbIneqIndices,SmallVectorImpl<unsigned> & nbEqIndices)3496 static void getIndependentConstraints(const FlatAffineConstraints &cst,
3497 unsigned pos, unsigned num,
3498 SmallVectorImpl<unsigned> &nbIneqIndices,
3499 SmallVectorImpl<unsigned> &nbEqIndices) {
3500 assert(pos < cst.getNumIds() && "invalid start position");
3501 assert(pos + num <= cst.getNumIds() && "invalid limit");
3502
3503 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
3504 // The bounds are to be independent of [offset, offset + num) columns.
3505 unsigned c;
3506 for (c = pos; c < pos + num; ++c) {
3507 if (cst.atIneq(r, c) != 0)
3508 break;
3509 }
3510 if (c == pos + num)
3511 nbIneqIndices.push_back(r);
3512 }
3513
3514 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
3515 // The bounds are to be independent of [offset, offset + num) columns.
3516 unsigned c;
3517 for (c = pos; c < pos + num; ++c) {
3518 if (cst.atEq(r, c) != 0)
3519 break;
3520 }
3521 if (c == pos + num)
3522 nbEqIndices.push_back(r);
3523 }
3524 }
3525
removeIndependentConstraints(unsigned pos,unsigned num)3526 void FlatAffineConstraints::removeIndependentConstraints(unsigned pos,
3527 unsigned num) {
3528 assert(pos + num <= getNumIds() && "invalid range");
3529
3530 // Remove constraints that are independent of these identifiers.
3531 SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
3532 getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
3533
3534 // Iterate in reverse so that indices don't have to be updated.
3535 // TODO: This method can be made more efficient (because removal of each
3536 // inequality leads to much shifting/copying in the underlying buffer).
3537 for (auto nbIndex : llvm::reverse(nbIneqIndices))
3538 removeInequality(nbIndex);
3539 for (auto nbIndex : llvm::reverse(nbEqIndices))
3540 removeEquality(nbIndex);
3541 }
3542
alignAffineMapWithValues(AffineMap map,ValueRange operands,ValueRange dims,ValueRange syms,SmallVector<Value> * newSyms)3543 AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
3544 ValueRange dims, ValueRange syms,
3545 SmallVector<Value> *newSyms) {
3546 assert(operands.size() == map.getNumInputs() &&
3547 "expected same number of operands and map inputs");
3548 MLIRContext *ctx = map.getContext();
3549 Builder builder(ctx);
3550 SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
3551 unsigned numSymbols = syms.size();
3552 SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
3553 if (newSyms) {
3554 newSyms->clear();
3555 newSyms->append(syms.begin(), syms.end());
3556 }
3557
3558 for (auto operand : llvm::enumerate(operands)) {
3559 // Compute replacement dim/sym of operand.
3560 AffineExpr replacement;
3561 auto dimIt = std::find(dims.begin(), dims.end(), operand.value());
3562 auto symIt = std::find(syms.begin(), syms.end(), operand.value());
3563 if (dimIt != dims.end()) {
3564 replacement =
3565 builder.getAffineDimExpr(std::distance(dims.begin(), dimIt));
3566 } else if (symIt != syms.end()) {
3567 replacement =
3568 builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt));
3569 } else {
3570 // This operand is neither a dimension nor a symbol. Add it as a new
3571 // symbol.
3572 replacement = builder.getAffineSymbolExpr(numSymbols++);
3573 if (newSyms)
3574 newSyms->push_back(operand.value());
3575 }
3576 // Add to corresponding replacements vector.
3577 if (operand.index() < map.getNumDims()) {
3578 dimReplacements[operand.index()] = replacement;
3579 } else {
3580 symReplacements[operand.index() - map.getNumDims()] = replacement;
3581 }
3582 }
3583
3584 return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
3585 dims.size(), numSymbols);
3586 }
3587