1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Structures for affine/polyhedral analysis of affine dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/AffineStructures.h"
14 #include "mlir/Analysis/LinearTransform.h"
15 #include "mlir/Analysis/Presburger/Simplex.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/AffineExprVisitor.h"
20 #include "mlir/IR/IntegerSet.h"
21 #include "mlir/Support/LLVM.h"
22 #include "mlir/Support/MathExtras.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallPtrSet.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 #define DEBUG_TYPE "affine-structures"
30 
31 using namespace mlir;
32 using llvm::SmallDenseMap;
33 using llvm::SmallDenseSet;
34 
35 namespace {
36 
37 // See comments for SimpleAffineExprFlattener.
38 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
39 // constraint information associated with mod's, floordiv's, and ceildiv's
40 // in FlatAffineConstraints 'localVarCst'.
41 struct AffineExprFlattener : public SimpleAffineExprFlattener {
42 public:
43   // Constraints connecting newly introduced local variables (for mod's and
44   // div's) to existing (dimensional and symbolic) ones. These are always
45   // inequalities.
46   FlatAffineConstraints localVarCst;
47 
AffineExprFlattener__anon7a31c6020111::AffineExprFlattener48   AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx)
49       : SimpleAffineExprFlattener(nDims, nSymbols) {
50     localVarCst.reset(nDims, nSymbols, /*numLocals=*/0);
51   }
52 
53 private:
54   // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
55   // The local identifier added is always a floordiv of a pure add/mul affine
56   // function of other identifiers, coefficients of which are specified in
57   // `dividend' and with respect to the positive constant `divisor'. localExpr
58   // is the simplified tree expression (AffineExpr) corresponding to the
59   // quantifier.
addLocalFloorDivId__anon7a31c6020111::AffineExprFlattener60   void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
61                           AffineExpr localExpr) override {
62     SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
63     // Update localVarCst.
64     localVarCst.addLocalFloorDiv(dividend, divisor);
65   }
66 };
67 
68 } // end anonymous namespace
69 
70 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
71 // flattened (i.e., semi-affine expressions not handled yet).
72 static LogicalResult
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs,unsigned numDims,unsigned numSymbols,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)73 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
74                         unsigned numSymbols,
75                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
76                         FlatAffineConstraints *localVarCst) {
77   if (exprs.empty()) {
78     localVarCst->reset(numDims, numSymbols);
79     return success();
80   }
81 
82   AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
83   // Use the same flattener to simplify each expression successively. This way
84   // local identifiers / expressions are shared.
85   for (auto expr : exprs) {
86     if (!expr.isPureAffine())
87       return failure();
88 
89     flattener.walkPostOrder(expr);
90   }
91 
92   assert(flattener.operandExprStack.size() == exprs.size());
93   flattenedExprs->clear();
94   flattenedExprs->assign(flattener.operandExprStack.begin(),
95                          flattener.operandExprStack.end());
96 
97   if (localVarCst)
98     localVarCst->clearAndCopyFrom(flattener.localVarCst);
99 
100   return success();
101 }
102 
103 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
104 // be flattened (semi-affine expressions not handled yet).
105 LogicalResult
getFlattenedAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols,SmallVectorImpl<int64_t> * flattenedExpr,FlatAffineConstraints * localVarCst)106 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
107                              unsigned numSymbols,
108                              SmallVectorImpl<int64_t> *flattenedExpr,
109                              FlatAffineConstraints *localVarCst) {
110   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
111   LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
112                                                 &flattenedExprs, localVarCst);
113   *flattenedExpr = flattenedExprs[0];
114   return ret;
115 }
116 
117 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
118 /// flattened (i.e., semi-affine expressions not handled yet).
getFlattenedAffineExprs(AffineMap map,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)119 LogicalResult mlir::getFlattenedAffineExprs(
120     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
121     FlatAffineConstraints *localVarCst) {
122   if (map.getNumResults() == 0) {
123     localVarCst->reset(map.getNumDims(), map.getNumSymbols());
124     return success();
125   }
126   return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
127                                    map.getNumSymbols(), flattenedExprs,
128                                    localVarCst);
129 }
130 
getFlattenedAffineExprs(IntegerSet set,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)131 LogicalResult mlir::getFlattenedAffineExprs(
132     IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
133     FlatAffineConstraints *localVarCst) {
134   if (set.getNumConstraints() == 0) {
135     localVarCst->reset(set.getNumDims(), set.getNumSymbols());
136     return success();
137   }
138   return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
139                                    set.getNumSymbols(), flattenedExprs,
140                                    localVarCst);
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // FlatAffineConstraints / FlatAffineValueConstraints.
145 //===----------------------------------------------------------------------===//
146 
147 // Clones this object.
clone() const148 std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
149   return std::make_unique<FlatAffineConstraints>(*this);
150 }
151 
152 std::unique_ptr<FlatAffineValueConstraints>
clone() const153 FlatAffineValueConstraints::clone() const {
154   return std::make_unique<FlatAffineValueConstraints>(*this);
155 }
156 
157 // Construct from an IntegerSet.
FlatAffineConstraints(IntegerSet set)158 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
159     : numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
160       numSymbols(set.getNumSymbols()),
161       equalities(0, numIds + 1, set.getNumEqualities(), numIds + 1),
162       inequalities(0, numIds + 1, set.getNumInequalities(), numIds + 1) {
163   // Flatten expressions and add them to the constraint system.
164   std::vector<SmallVector<int64_t, 8>> flatExprs;
165   FlatAffineConstraints localVarCst;
166   if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
167     assert(false && "flattening unimplemented for semi-affine integer sets");
168     return;
169   }
170   assert(flatExprs.size() == set.getNumConstraints());
171   appendLocalId(/*num=*/localVarCst.getNumLocalIds());
172 
173   for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
174     const auto &flatExpr = flatExprs[i];
175     assert(flatExpr.size() == getNumCols());
176     if (set.getEqFlags()[i]) {
177       addEquality(flatExpr);
178     } else {
179       addInequality(flatExpr);
180     }
181   }
182   // Add the other constraints involving local id's from flattening.
183   append(localVarCst);
184 }
185 
186 // Construct from an IntegerSet.
FlatAffineValueConstraints(IntegerSet set)187 FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set)
188     : FlatAffineConstraints(set) {
189   values.resize(numIds, None);
190 }
191 
192 // Construct a hyperrectangular constraint set from ValueRanges that represent
193 // induction variables, lower and upper bounds. `ivs`, `lbs` and `ubs` are
194 // expected to match one to one. The order of variables and constraints is:
195 //
196 // ivs | lbs | ubs | eq/ineq
197 // ----+-----+-----+---------
198 //   1   -1     0      >= 0
199 // ----+-----+-----+---------
200 //  -1    0     1      >= 0
201 //
202 // All dimensions as set as DimId.
203 FlatAffineValueConstraints
getHyperrectangular(ValueRange ivs,ValueRange lbs,ValueRange ubs)204 FlatAffineValueConstraints::getHyperrectangular(ValueRange ivs, ValueRange lbs,
205                                                 ValueRange ubs) {
206   FlatAffineValueConstraints res;
207   unsigned nIvs = ivs.size();
208   assert(nIvs == lbs.size() && "expected as many lower bounds as ivs");
209   assert(nIvs == ubs.size() && "expected as many upper bounds as ivs");
210 
211   if (nIvs == 0)
212     return res;
213 
214   res.appendDimId(ivs);
215   unsigned lbsStart = res.appendDimId(lbs);
216   unsigned ubsStart = res.appendDimId(ubs);
217 
218   MLIRContext *ctx = ivs.front().getContext();
219   for (int ivIdx = 0, e = nIvs; ivIdx < e; ++ivIdx) {
220     // iv - lb >= 0
221     AffineMap lb = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
222                                   getAffineDimExpr(lbsStart + ivIdx, ctx));
223     if (failed(res.addBound(BoundType::LB, ivIdx, lb)))
224       llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
225     // -iv + ub >= 0
226     AffineMap ub = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
227                                   getAffineDimExpr(ubsStart + ivIdx, ctx));
228     if (failed(res.addBound(BoundType::UB, ivIdx, ub)))
229       llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
230   }
231   return res;
232 }
233 
reset(unsigned numReservedInequalities,unsigned numReservedEqualities,unsigned newNumReservedCols,unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals)234 void FlatAffineConstraints::reset(unsigned numReservedInequalities,
235                                   unsigned numReservedEqualities,
236                                   unsigned newNumReservedCols,
237                                   unsigned newNumDims, unsigned newNumSymbols,
238                                   unsigned newNumLocals) {
239   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
240          "minimum 1 column");
241   *this = FlatAffineConstraints(numReservedInequalities, numReservedEqualities,
242                                 newNumReservedCols, newNumDims, newNumSymbols,
243                                 newNumLocals);
244 }
245 
reset(unsigned numReservedInequalities,unsigned numReservedEqualities,unsigned newNumReservedCols,unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals)246 void FlatAffineValueConstraints::reset(unsigned numReservedInequalities,
247                                        unsigned numReservedEqualities,
248                                        unsigned newNumReservedCols,
249                                        unsigned newNumDims,
250                                        unsigned newNumSymbols,
251                                        unsigned newNumLocals) {
252   reset(numReservedInequalities, numReservedEqualities, newNumReservedCols,
253         newNumDims, newNumSymbols, newNumLocals, /*valArgs=*/{});
254 }
255 
reset(unsigned numReservedInequalities,unsigned numReservedEqualities,unsigned newNumReservedCols,unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> valArgs)256 void FlatAffineValueConstraints::reset(
257     unsigned numReservedInequalities, unsigned numReservedEqualities,
258     unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols,
259     unsigned newNumLocals, ArrayRef<Value> valArgs) {
260   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
261          "minimum 1 column");
262   SmallVector<Optional<Value>, 8> newVals;
263   if (!valArgs.empty())
264     newVals.assign(valArgs.begin(), valArgs.end());
265 
266   *this = FlatAffineValueConstraints(
267       numReservedInequalities, numReservedEqualities, newNumReservedCols,
268       newNumDims, newNumSymbols, newNumLocals, newVals);
269 }
270 
reset(unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals)271 void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
272                                   unsigned newNumLocals) {
273   reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
274         newNumSymbols, newNumLocals);
275 }
276 
reset(unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> valArgs)277 void FlatAffineValueConstraints::reset(unsigned newNumDims,
278                                        unsigned newNumSymbols,
279                                        unsigned newNumLocals,
280                                        ArrayRef<Value> valArgs) {
281   reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
282         newNumSymbols, newNumLocals, valArgs);
283 }
284 
append(const FlatAffineConstraints & other)285 void FlatAffineConstraints::append(const FlatAffineConstraints &other) {
286   assert(other.getNumCols() == getNumCols());
287   assert(other.getNumDimIds() == getNumDimIds());
288   assert(other.getNumSymbolIds() == getNumSymbolIds());
289 
290   inequalities.reserveRows(inequalities.getNumRows() +
291                            other.getNumInequalities());
292   equalities.reserveRows(equalities.getNumRows() + other.getNumEqualities());
293 
294   for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
295     addInequality(other.getInequality(r));
296   }
297   for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
298     addEquality(other.getEquality(r));
299   }
300 }
301 
appendDimId(unsigned num)302 unsigned FlatAffineConstraints::appendDimId(unsigned num) {
303   unsigned pos = getNumDimIds();
304   insertId(IdKind::Dimension, pos, num);
305   return pos;
306 }
307 
appendDimId(ValueRange vals)308 unsigned FlatAffineValueConstraints::appendDimId(ValueRange vals) {
309   unsigned pos = getNumDimIds();
310   insertId(IdKind::Dimension, pos, vals);
311   return pos;
312 }
313 
appendSymbolId(unsigned num)314 unsigned FlatAffineConstraints::appendSymbolId(unsigned num) {
315   unsigned pos = getNumSymbolIds();
316   insertId(IdKind::Symbol, pos, num);
317   return pos;
318 }
319 
appendSymbolId(ValueRange vals)320 unsigned FlatAffineValueConstraints::appendSymbolId(ValueRange vals) {
321   unsigned pos = getNumSymbolIds();
322   insertId(IdKind::Symbol, pos, vals);
323   return pos;
324 }
325 
appendLocalId(unsigned num)326 unsigned FlatAffineConstraints::appendLocalId(unsigned num) {
327   unsigned pos = getNumLocalIds();
328   insertId(IdKind::Local, pos, num);
329   return pos;
330 }
331 
insertDimId(unsigned pos,unsigned num)332 unsigned FlatAffineConstraints::insertDimId(unsigned pos, unsigned num) {
333   return insertId(IdKind::Dimension, pos, num);
334 }
335 
insertDimId(unsigned pos,ValueRange vals)336 unsigned FlatAffineValueConstraints::insertDimId(unsigned pos,
337                                                  ValueRange vals) {
338   return insertId(IdKind::Dimension, pos, vals);
339 }
340 
insertSymbolId(unsigned pos,unsigned num)341 unsigned FlatAffineConstraints::insertSymbolId(unsigned pos, unsigned num) {
342   return insertId(IdKind::Symbol, pos, num);
343 }
344 
insertSymbolId(unsigned pos,ValueRange vals)345 unsigned FlatAffineValueConstraints::insertSymbolId(unsigned pos,
346                                                     ValueRange vals) {
347   return insertId(IdKind::Symbol, pos, vals);
348 }
349 
insertLocalId(unsigned pos,unsigned num)350 unsigned FlatAffineConstraints::insertLocalId(unsigned pos, unsigned num) {
351   return insertId(IdKind::Local, pos, num);
352 }
353 
insertId(IdKind kind,unsigned pos,unsigned num)354 unsigned FlatAffineConstraints::insertId(IdKind kind, unsigned pos,
355                                          unsigned num) {
356   assertAtMostNumIdKind(pos, kind);
357 
358   unsigned absolutePos = getIdKindOffset(kind) + pos;
359   if (kind == IdKind::Dimension)
360     numDims += num;
361   else if (kind == IdKind::Symbol)
362     numSymbols += num;
363   numIds += num;
364 
365   inequalities.insertColumns(absolutePos, num);
366   equalities.insertColumns(absolutePos, num);
367 
368   return absolutePos;
369 }
370 
assertAtMostNumIdKind(unsigned val,IdKind kind) const371 void FlatAffineConstraints::assertAtMostNumIdKind(unsigned val,
372                                                   IdKind kind) const {
373   if (kind == IdKind::Dimension)
374     assert(val <= getNumDimIds());
375   else if (kind == IdKind::Symbol)
376     assert(val <= getNumSymbolIds());
377   else if (kind == IdKind::Local)
378     assert(val <= getNumLocalIds());
379   else
380     llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
381 }
382 
getIdKindOffset(IdKind kind) const383 unsigned FlatAffineConstraints::getIdKindOffset(IdKind kind) const {
384   if (kind == IdKind::Dimension)
385     return 0;
386   if (kind == IdKind::Symbol)
387     return getNumDimIds();
388   if (kind == IdKind::Local)
389     return getNumDimAndSymbolIds();
390   llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
391 }
392 
insertId(IdKind kind,unsigned pos,unsigned num)393 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
394                                               unsigned num) {
395   unsigned absolutePos = FlatAffineConstraints::insertId(kind, pos, num);
396   values.insert(values.begin() + absolutePos, num, None);
397   assert(values.size() == getNumIds());
398   return absolutePos;
399 }
400 
insertId(IdKind kind,unsigned pos,ValueRange vals)401 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
402                                               ValueRange vals) {
403   assert(!vals.empty() && "expected ValueRange with Values");
404   unsigned num = vals.size();
405   unsigned absolutePos = FlatAffineConstraints::insertId(kind, pos, num);
406 
407   // If a Value is provided, insert it; otherwise use None.
408   for (unsigned i = 0; i < num; ++i)
409     values.insert(values.begin() + absolutePos + i,
410                   vals[i] ? Optional<Value>(vals[i]) : None);
411 
412   assert(values.size() == getNumIds());
413   return absolutePos;
414 }
415 
hasValues() const416 bool FlatAffineValueConstraints::hasValues() const {
417   return llvm::find_if(values, [](Optional<Value> id) {
418            return id.hasValue();
419          }) != values.end();
420 }
421 
removeId(IdKind kind,unsigned pos)422 void FlatAffineConstraints::removeId(IdKind kind, unsigned pos) {
423   removeIdRange(kind, pos, pos + 1);
424 }
425 
removeIdRange(IdKind kind,unsigned idStart,unsigned idLimit)426 void FlatAffineConstraints::removeIdRange(IdKind kind, unsigned idStart,
427                                           unsigned idLimit) {
428   assertAtMostNumIdKind(idLimit, kind);
429   removeIdRange(getIdKindOffset(kind) + idStart,
430                 getIdKindOffset(kind) + idLimit);
431 }
432 
433 /// Checks if two constraint systems are in the same space, i.e., if they are
434 /// associated with the same set of identifiers, appearing in the same order.
areIdsAligned(const FlatAffineValueConstraints & a,const FlatAffineValueConstraints & b)435 static bool areIdsAligned(const FlatAffineValueConstraints &a,
436                           const FlatAffineValueConstraints &b) {
437   return a.getNumDimIds() == b.getNumDimIds() &&
438          a.getNumSymbolIds() == b.getNumSymbolIds() &&
439          a.getNumIds() == b.getNumIds() &&
440          a.getMaybeValues().equals(b.getMaybeValues());
441 }
442 
443 /// Calls areIdsAligned to check if two constraint systems have the same set
444 /// of identifiers in the same order.
areIdsAlignedWithOther(const FlatAffineValueConstraints & other)445 bool FlatAffineValueConstraints::areIdsAlignedWithOther(
446     const FlatAffineValueConstraints &other) {
447   return areIdsAligned(*this, other);
448 }
449 
450 /// Checks if the SSA values associated with `cst`'s identifiers are unique.
451 static bool LLVM_ATTRIBUTE_UNUSED
areIdsUnique(const FlatAffineValueConstraints & cst)452 areIdsUnique(const FlatAffineValueConstraints &cst) {
453   SmallPtrSet<Value, 8> uniqueIds;
454   for (auto val : cst.getMaybeValues()) {
455     if (val.hasValue() && !uniqueIds.insert(val.getValue()).second)
456       return false;
457   }
458   return true;
459 }
460 
461 /// Merge and align the identifiers of A and B starting at 'offset', so that
462 /// both constraint systems get the union of the contained identifiers that is
463 /// dimension-wise and symbol-wise unique; both constraint systems are updated
464 /// so that they have the union of all identifiers, with A's original
465 /// identifiers appearing first followed by any of B's identifiers that didn't
466 /// appear in A. Local identifiers of each system are by design separate/local
467 /// and are placed one after other (A's followed by B's).
468 //  E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M])
469 //        Output: both A, B have (%i, %j, %k) [%M, %N, %P]
mergeAndAlignIds(unsigned offset,FlatAffineValueConstraints * a,FlatAffineValueConstraints * b)470 static void mergeAndAlignIds(unsigned offset, FlatAffineValueConstraints *a,
471                              FlatAffineValueConstraints *b) {
472   assert(offset <= a->getNumDimIds() && offset <= b->getNumDimIds());
473   // A merge/align isn't meaningful if a cst's ids aren't distinct.
474   assert(areIdsUnique(*a) && "A's values aren't unique");
475   assert(areIdsUnique(*b) && "B's values aren't unique");
476 
477   assert(std::all_of(a->getMaybeValues().begin() + offset,
478                      a->getMaybeValues().begin() + a->getNumDimAndSymbolIds(),
479                      [](Optional<Value> id) { return id.hasValue(); }));
480 
481   assert(std::all_of(b->getMaybeValues().begin() + offset,
482                      b->getMaybeValues().begin() + b->getNumDimAndSymbolIds(),
483                      [](Optional<Value> id) { return id.hasValue(); }));
484 
485   // Bring A and B to common local space
486   a->mergeLocalIds(*b);
487 
488   SmallVector<Value, 4> aDimValues;
489   a->getValues(offset, a->getNumDimIds(), &aDimValues);
490 
491   {
492     // Merge dims from A into B.
493     unsigned d = offset;
494     for (auto aDimValue : aDimValues) {
495       unsigned loc;
496       if (b->findId(aDimValue, &loc)) {
497         assert(loc >= offset && "A's dim appears in B's aligned range");
498         assert(loc < b->getNumDimIds() &&
499                "A's dim appears in B's non-dim position");
500         b->swapId(d, loc);
501       } else {
502         b->insertDimId(d, aDimValue);
503       }
504       d++;
505     }
506     // Dimensions that are in B, but not in A, are added at the end.
507     for (unsigned t = a->getNumDimIds(), e = b->getNumDimIds(); t < e; t++) {
508       a->appendDimId(b->getValue(t));
509     }
510     assert(a->getNumDimIds() == b->getNumDimIds() &&
511            "expected same number of dims");
512   }
513 
514   // Merge and align symbols of A and B
515   a->mergeSymbolIds(*b);
516 
517   assert(areIdsAligned(*a, *b) && "IDs expected to be aligned");
518 }
519 
520 // Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
mergeAndAlignIdsWithOther(unsigned offset,FlatAffineValueConstraints * other)521 void FlatAffineValueConstraints::mergeAndAlignIdsWithOther(
522     unsigned offset, FlatAffineValueConstraints *other) {
523   mergeAndAlignIds(offset, this, other);
524 }
525 
526 LogicalResult
composeMap(const AffineValueMap * vMap)527 FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) {
528   return composeMatchingMap(
529       computeAlignedMap(vMap->getAffineMap(), vMap->getOperands()));
530 }
531 
532 // Similar to `composeMap` except that no Values need be associated with the
533 // constraint system nor are they looked at -- the dimensions and symbols of
534 // `other` are expected to correspond 1:1 to `this` system.
composeMatchingMap(AffineMap other)535 LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
536   assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
537   assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
538 
539   std::vector<SmallVector<int64_t, 8>> flatExprs;
540   if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
541     return failure();
542   assert(flatExprs.size() == other.getNumResults());
543 
544   // Add dimensions corresponding to the map's results.
545   insertDimId(/*pos=*/0, /*num=*/other.getNumResults());
546 
547   // We add one equality for each result connecting the result dim of the map to
548   // the other identifiers.
549   // E.g.: if the expression is 16*i0 + i1, and this is the r^th
550   // iteration/result of the value map, we are adding the equality:
551   // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
552   // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
553   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
554     const auto &flatExpr = flatExprs[r];
555     assert(flatExpr.size() >= other.getNumInputs() + 1);
556 
557     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
558     // Set the coefficient for this result to one.
559     eqToAdd[r] = 1;
560 
561     // Dims and symbols.
562     for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
563       // Negate `eq[r]` since the newly added dimension will be set to this one.
564       eqToAdd[e + i] = -flatExpr[i];
565     }
566     // Local columns of `eq` are at the beginning.
567     unsigned j = getNumDimIds() + getNumSymbolIds();
568     unsigned end = flatExpr.size() - 1;
569     for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
570       eqToAdd[j] = -flatExpr[i];
571     }
572 
573     // Constant term.
574     eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
575 
576     // Add the equality connecting the result of the map to this constraint set.
577     addEquality(eqToAdd);
578   }
579 
580   return success();
581 }
582 
583 // Turn a symbol into a dimension.
turnSymbolIntoDim(FlatAffineValueConstraints * cst,Value id)584 static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value id) {
585   unsigned pos;
586   if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
587       pos < cst->getNumDimAndSymbolIds()) {
588     cst->swapId(pos, cst->getNumDimIds());
589     cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
590   }
591 }
592 
593 /// Merge and align symbols of `this` and `other` such that both get union of
594 /// of symbols that are unique. Symbols with Value as `None` are considered
595 /// to be inequal to all other symbols.
mergeSymbolIds(FlatAffineValueConstraints & other)596 void FlatAffineValueConstraints::mergeSymbolIds(
597     FlatAffineValueConstraints &other) {
598   SmallVector<Value, 4> aSymValues;
599   getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues);
600 
601   // Merge symbols: merge symbols into `other` first from `this`.
602   unsigned s = other.getNumDimIds();
603   for (Value aSymValue : aSymValues) {
604     unsigned loc;
605     // If the id is a symbol in `other`, then align it, otherwise assume that
606     // it is a new symbol
607     if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() &&
608         loc < getNumDimAndSymbolIds())
609       other.swapId(s, loc);
610     else
611       other.insertSymbolId(s - other.getNumDimIds(), aSymValue);
612     s++;
613   }
614 
615   // Symbols that are in other, but not in this, are added at the end.
616   for (unsigned t = other.getNumDimIds() + getNumSymbolIds(),
617                 e = other.getNumDimAndSymbolIds();
618        t < e; t++)
619     insertSymbolId(getNumSymbolIds(), other.getValue(t));
620 
621   assert(getNumSymbolIds() == other.getNumSymbolIds() &&
622          "expected same number of symbols");
623 }
624 
625 // Changes all symbol identifiers which are loop IVs to dim identifiers.
convertLoopIVSymbolsToDims()626 void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() {
627   // Gather all symbols which are loop IVs.
628   SmallVector<Value, 4> loopIVs;
629   for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
630     if (hasValue(i) && getForInductionVarOwner(getValue(i)))
631       loopIVs.push_back(getValue(i));
632   }
633   // Turn each symbol in 'loopIVs' into a dim identifier.
634   for (auto iv : loopIVs) {
635     turnSymbolIntoDim(this, iv);
636   }
637 }
638 
addInductionVarOrTerminalSymbol(Value val)639 void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) {
640   if (containsId(val))
641     return;
642 
643   // Caller is expected to fully compose map/operands if necessary.
644   assert((isTopLevelValue(val) || isForInductionVar(val)) &&
645          "non-terminal symbol / loop IV expected");
646   // Outer loop IVs could be used in forOp's bounds.
647   if (auto loop = getForInductionVarOwner(val)) {
648     appendDimId(val);
649     if (failed(this->addAffineForOpDomain(loop)))
650       LLVM_DEBUG(
651           loop.emitWarning("failed to add domain info to constraint system"));
652     return;
653   }
654   // Add top level symbol.
655   appendSymbolId(val);
656   // Check if the symbol is a constant.
657   if (auto constOp = val.getDefiningOp<ConstantIndexOp>())
658     addBound(BoundType::EQ, val, constOp.getValue());
659 }
660 
661 LogicalResult
addAffineForOpDomain(AffineForOp forOp)662 FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
663   unsigned pos;
664   // Pre-condition for this method.
665   if (!findId(forOp.getInductionVar(), &pos)) {
666     assert(false && "Value not found");
667     return failure();
668   }
669 
670   int64_t step = forOp.getStep();
671   if (step != 1) {
672     if (!forOp.hasConstantLowerBound())
673       LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated"));
674     else {
675       // Add constraints for the stride.
676       // (iv - lb) % step = 0 can be written as:
677       // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
678       // Add local variable 'q' and add the above equality.
679       // The first constraint is q = (iv - lb) floordiv step
680       SmallVector<int64_t, 8> dividend(getNumCols(), 0);
681       int64_t lb = forOp.getConstantLowerBound();
682       dividend[pos] = 1;
683       dividend.back() -= lb;
684       addLocalFloorDiv(dividend, step);
685       // Second constraint: (iv - lb) - step * q = 0.
686       SmallVector<int64_t, 8> eq(getNumCols(), 0);
687       eq[pos] = 1;
688       eq.back() -= lb;
689       // For the local var just added above.
690       eq[getNumCols() - 2] = -step;
691       addEquality(eq);
692     }
693   }
694 
695   if (forOp.hasConstantLowerBound()) {
696     addBound(BoundType::LB, pos, forOp.getConstantLowerBound());
697   } else {
698     // Non-constant lower bound case.
699     if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(),
700                         forOp.getLowerBoundOperands())))
701       return failure();
702   }
703 
704   if (forOp.hasConstantUpperBound()) {
705     addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1);
706     return success();
707   }
708   // Non-constant upper bound case.
709   return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(),
710                   forOp.getUpperBoundOperands());
711 }
712 
713 LogicalResult
addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,ArrayRef<AffineMap> ubMaps,ArrayRef<Value> operands)714 FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
715                                                    ArrayRef<AffineMap> ubMaps,
716                                                    ArrayRef<Value> operands) {
717   assert(lbMaps.size() == ubMaps.size());
718   assert(lbMaps.size() <= getNumDimIds());
719 
720   for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
721     AffineMap lbMap = lbMaps[i];
722     AffineMap ubMap = ubMaps[i];
723     assert(!lbMap || lbMap.getNumInputs() == operands.size());
724     assert(!ubMap || ubMap.getNumInputs() == operands.size());
725 
726     // Check if this slice is just an equality along this dimension. If so,
727     // retrieve the existing loop it equates to and add it to the system.
728     if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
729         ubMap.getNumResults() == 1 &&
730         lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
731         // The condition above will be true for maps describing a single
732         // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
733         // Make sure we skip those cases by checking that the lb result is not
734         // just a constant.
735         !lbMap.getResult(0).isa<AffineConstantExpr>()) {
736       // Limited support: we expect the lb result to be just a loop dimension.
737       // Not supported otherwise for now.
738       AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
739       if (!result)
740         return failure();
741 
742       AffineForOp loop =
743           getForInductionVarOwner(operands[result.getPosition()]);
744       if (!loop)
745         return failure();
746 
747       if (failed(addAffineForOpDomain(loop)))
748         return failure();
749       continue;
750     }
751 
752     // This slice refers to a loop that doesn't exist in the IR yet. Add its
753     // bounds to the system assuming its dimension identifier position is the
754     // same as the position of the loop in the loop nest.
755     if (lbMap && failed(addBound(BoundType::LB, i, lbMap, operands)))
756       return failure();
757     if (ubMap && failed(addBound(BoundType::UB, i, ubMap, operands)))
758       return failure();
759   }
760   return success();
761 }
762 
addAffineIfOpDomain(AffineIfOp ifOp)763 void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
764   // Create the base constraints from the integer set attached to ifOp.
765   FlatAffineValueConstraints cst(ifOp.getIntegerSet());
766 
767   // Bind ids in the constraints to ifOp operands.
768   SmallVector<Value, 4> operands = ifOp.getOperands();
769   cst.setValues(0, cst.getNumDimAndSymbolIds(), operands);
770 
771   // Merge the constraints from ifOp to the current domain. We need first merge
772   // and align the IDs from both constraints, and then append the constraints
773   // from the ifOp into the current one.
774   mergeAndAlignIdsWithOther(0, &cst);
775   append(cst);
776 }
777 
778 // Searches for a constraint with a non-zero coefficient at `colIdx` in
779 // equality (isEq=true) or inequality (isEq=false) constraints.
780 // Returns true and sets row found in search in `rowIdx`, false otherwise.
findConstraintWithNonZeroAt(const FlatAffineConstraints & cst,unsigned colIdx,bool isEq,unsigned * rowIdx)781 static bool findConstraintWithNonZeroAt(const FlatAffineConstraints &cst,
782                                         unsigned colIdx, bool isEq,
783                                         unsigned *rowIdx) {
784   assert(colIdx < cst.getNumCols() && "position out of bounds");
785   auto at = [&](unsigned rowIdx) -> int64_t {
786     return isEq ? cst.atEq(rowIdx, colIdx) : cst.atIneq(rowIdx, colIdx);
787   };
788   unsigned e = isEq ? cst.getNumEqualities() : cst.getNumInequalities();
789   for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
790     if (at(*rowIdx) != 0) {
791       return true;
792     }
793   }
794   return false;
795 }
796 
797 // Normalizes the coefficient values across all columns in `rowIdx` by their
798 // GCD in equality or inequality constraints as specified by `isEq`.
799 template <bool isEq>
normalizeConstraintByGCD(FlatAffineConstraints * constraints,unsigned rowIdx)800 static void normalizeConstraintByGCD(FlatAffineConstraints *constraints,
801                                      unsigned rowIdx) {
802   auto at = [&](unsigned colIdx) -> int64_t {
803     return isEq ? constraints->atEq(rowIdx, colIdx)
804                 : constraints->atIneq(rowIdx, colIdx);
805   };
806   uint64_t gcd = std::abs(at(0));
807   for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) {
808     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j)));
809   }
810   if (gcd > 0 && gcd != 1) {
811     for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) {
812       int64_t v = at(j) / static_cast<int64_t>(gcd);
813       isEq ? constraints->atEq(rowIdx, j) = v
814            : constraints->atIneq(rowIdx, j) = v;
815     }
816   }
817 }
818 
normalizeConstraintsByGCD()819 void FlatAffineConstraints::normalizeConstraintsByGCD() {
820   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
821     normalizeConstraintByGCD</*isEq=*/true>(this, i);
822   }
823   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
824     normalizeConstraintByGCD</*isEq=*/false>(this, i);
825   }
826 }
827 
hasConsistentState() const828 bool FlatAffineConstraints::hasConsistentState() const {
829   if (!inequalities.hasConsistentState())
830     return false;
831   if (!equalities.hasConsistentState())
832     return false;
833 
834   // Catches errors where numDims, numSymbols, numIds aren't consistent.
835   if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds)
836     return false;
837 
838   return true;
839 }
840 
hasConsistentState() const841 bool FlatAffineValueConstraints::hasConsistentState() const {
842   return FlatAffineConstraints::hasConsistentState() &&
843          values.size() == getNumIds();
844 }
845 
hasInvalidConstraint() const846 bool FlatAffineConstraints::hasInvalidConstraint() const {
847   assert(hasConsistentState());
848   auto check = [&](bool isEq) -> bool {
849     unsigned numCols = getNumCols();
850     unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
851     for (unsigned i = 0, e = numRows; i < e; ++i) {
852       unsigned j;
853       for (j = 0; j < numCols - 1; ++j) {
854         int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
855         // Skip rows with non-zero variable coefficients.
856         if (v != 0)
857           break;
858       }
859       if (j < numCols - 1) {
860         continue;
861       }
862       // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
863       // Example invalid constraints include: '1 == 0' or '-1 >= 0'
864       int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
865       if ((isEq && v != 0) || (!isEq && v < 0)) {
866         return true;
867       }
868     }
869     return false;
870   };
871   if (check(/*isEq=*/true))
872     return true;
873   return check(/*isEq=*/false);
874 }
875 
876 /// Eliminate identifier from constraint at `rowIdx` based on coefficient at
877 /// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
878 /// updated as they have already been eliminated.
eliminateFromConstraint(FlatAffineConstraints * constraints,unsigned rowIdx,unsigned pivotRow,unsigned pivotCol,unsigned elimColStart,bool isEq)879 static void eliminateFromConstraint(FlatAffineConstraints *constraints,
880                                     unsigned rowIdx, unsigned pivotRow,
881                                     unsigned pivotCol, unsigned elimColStart,
882                                     bool isEq) {
883   // Skip if equality 'rowIdx' if same as 'pivotRow'.
884   if (isEq && rowIdx == pivotRow)
885     return;
886   auto at = [&](unsigned i, unsigned j) -> int64_t {
887     return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
888   };
889   int64_t leadCoeff = at(rowIdx, pivotCol);
890   // Skip if leading coefficient at 'rowIdx' is already zero.
891   if (leadCoeff == 0)
892     return;
893   int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
894   int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
895   int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
896   int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
897   int64_t rowMultiplier = lcm / std::abs(leadCoeff);
898 
899   unsigned numCols = constraints->getNumCols();
900   for (unsigned j = 0; j < numCols; ++j) {
901     // Skip updating column 'j' if it was just eliminated.
902     if (j >= elimColStart && j < pivotCol)
903       continue;
904     int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
905                 rowMultiplier * at(rowIdx, j);
906     isEq ? constraints->atEq(rowIdx, j) = v
907          : constraints->atIneq(rowIdx, j) = v;
908   }
909 }
910 
removeIdRange(unsigned idStart,unsigned idLimit)911 void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
912   assert(idLimit < getNumCols() && "invalid id limit");
913 
914   if (idStart >= idLimit)
915     return;
916 
917   // We are going to be removing one or more identifiers from the range.
918   assert(idStart < numIds && "invalid idStart position");
919 
920   // TODO: Make 'removeIdRange' a lambda called from here.
921   // Remove eliminated identifiers from the constraints..
922   equalities.removeColumns(idStart, idLimit - idStart);
923   inequalities.removeColumns(idStart, idLimit - idStart);
924 
925   // Update members numDims, numSymbols and numIds.
926   unsigned numDimsEliminated = 0;
927   unsigned numLocalsEliminated = 0;
928   unsigned numColsEliminated = idLimit - idStart;
929   if (idStart < numDims) {
930     numDimsEliminated = std::min(numDims, idLimit) - idStart;
931   }
932   // Check how many local id's were removed. Note that our identifier order is
933   // [dims, symbols, locals]. Local id start at position numDims + numSymbols.
934   if (idLimit > numDims + numSymbols) {
935     numLocalsEliminated = std::min(
936         idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
937   }
938   unsigned numSymbolsEliminated =
939       numColsEliminated - numDimsEliminated - numLocalsEliminated;
940 
941   numDims -= numDimsEliminated;
942   numSymbols -= numSymbolsEliminated;
943   numIds = numIds - numColsEliminated;
944 }
945 
removeIdRange(unsigned idStart,unsigned idLimit)946 void FlatAffineValueConstraints::removeIdRange(unsigned idStart,
947                                                unsigned idLimit) {
948   FlatAffineConstraints::removeIdRange(idStart, idLimit);
949   values.erase(values.begin() + idStart, values.begin() + idLimit);
950 }
951 
952 /// Returns the position of the identifier that has the minimum <number of lower
953 /// bounds> times <number of upper bounds> from the specified range of
954 /// identifiers [start, end). It is often best to eliminate in the increasing
955 /// order of these counts when doing Fourier-Motzkin elimination since FM adds
956 /// that many new constraints.
getBestIdToEliminate(const FlatAffineConstraints & cst,unsigned start,unsigned end)957 static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst,
958                                      unsigned start, unsigned end) {
959   assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
960 
961   auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
962     unsigned numLb = 0;
963     unsigned numUb = 0;
964     for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
965       if (cst.atIneq(r, pos) > 0) {
966         ++numLb;
967       } else if (cst.atIneq(r, pos) < 0) {
968         ++numUb;
969       }
970     }
971     return numLb * numUb;
972   };
973 
974   unsigned minLoc = start;
975   unsigned min = getProductOfNumLowerUpperBounds(start);
976   for (unsigned c = start + 1; c < end; c++) {
977     unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
978     if (numLbUbProduct < min) {
979       min = numLbUbProduct;
980       minLoc = c;
981     }
982   }
983   return minLoc;
984 }
985 
986 // Checks for emptiness of the set by eliminating identifiers successively and
987 // using the GCD test (on all equality constraints) and checking for trivially
988 // invalid constraints. Returns 'true' if the constraint system is found to be
989 // empty; false otherwise.
isEmpty() const990 bool FlatAffineConstraints::isEmpty() const {
991   if (isEmptyByGCDTest() || hasInvalidConstraint())
992     return true;
993 
994   FlatAffineConstraints tmpCst(*this);
995 
996   // First, eliminate as many local variables as possible using equalities.
997   tmpCst.removeRedundantLocalVars();
998   if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint())
999     return true;
1000 
1001   // Eliminate as many identifiers as possible using Gaussian elimination.
1002   unsigned currentPos = 0;
1003   while (currentPos < tmpCst.getNumIds()) {
1004     tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
1005     ++currentPos;
1006     // We check emptiness through trivial checks after eliminating each ID to
1007     // detect emptiness early. Since the checks isEmptyByGCDTest() and
1008     // hasInvalidConstraint() are linear time and single sweep on the constraint
1009     // buffer, this appears reasonable - but can optimize in the future.
1010     if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
1011       return true;
1012   }
1013 
1014   // Eliminate the remaining using FM.
1015   for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
1016     tmpCst.fourierMotzkinEliminate(
1017         getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
1018     // Check for a constraint explosion. This rarely happens in practice, but
1019     // this check exists as a safeguard against improperly constructed
1020     // constraint systems or artificially created arbitrarily complex systems
1021     // that aren't the intended use case for FlatAffineConstraints. This is
1022     // needed since FM has a worst case exponential complexity in theory.
1023     if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
1024       LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
1025       return false;
1026     }
1027 
1028     // FM wouldn't have modified the equalities in any way. So no need to again
1029     // run GCD test. Check for trivial invalid constraints.
1030     if (tmpCst.hasInvalidConstraint())
1031       return true;
1032   }
1033   return false;
1034 }
1035 
1036 // Runs the GCD test on all equality constraints. Returns 'true' if this test
1037 // fails on any equality. Returns 'false' otherwise.
1038 // This test can be used to disprove the existence of a solution. If it returns
1039 // true, no integer solution to the equality constraints can exist.
1040 //
1041 // GCD test definition:
1042 //
1043 // The equality constraint:
1044 //
1045 //  c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
1046 //
1047 // has an integer solution iff:
1048 //
1049 //  GCD of c_1, c_2, ..., c_n divides c_0.
1050 //
isEmptyByGCDTest() const1051 bool FlatAffineConstraints::isEmptyByGCDTest() const {
1052   assert(hasConsistentState());
1053   unsigned numCols = getNumCols();
1054   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1055     uint64_t gcd = std::abs(atEq(i, 0));
1056     for (unsigned j = 1; j < numCols - 1; ++j) {
1057       gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
1058     }
1059     int64_t v = std::abs(atEq(i, numCols - 1));
1060     if (gcd > 0 && (v % gcd != 0)) {
1061       return true;
1062     }
1063   }
1064   return false;
1065 }
1066 
1067 // Returns a matrix where each row is a vector along which the polytope is
1068 // bounded. The span of the returned vectors is guaranteed to contain all
1069 // such vectors. The returned vectors are NOT guaranteed to be linearly
1070 // independent. This function should not be called on empty sets.
1071 //
1072 // It is sufficient to check the perpendiculars of the constraints, as the set
1073 // of perpendiculars which are bounded must span all bounded directions.
getBoundedDirections() const1074 Matrix FlatAffineConstraints::getBoundedDirections() const {
1075   // Note that it is necessary to add the equalities too (which the constructor
1076   // does) even though we don't need to check if they are bounded; whether an
1077   // inequality is bounded or not depends on what other constraints, including
1078   // equalities, are present.
1079   Simplex simplex(*this);
1080 
1081   assert(!simplex.isEmpty() && "It is not meaningful to ask whether a "
1082                                "direction is bounded in an empty set.");
1083 
1084   SmallVector<unsigned, 8> boundedIneqs;
1085   // The constructor adds the inequalities to the simplex first, so this
1086   // processes all the inequalities.
1087   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1088     if (simplex.isBoundedAlongConstraint(i))
1089       boundedIneqs.push_back(i);
1090   }
1091 
1092   // The direction vector is given by the coefficients and does not include the
1093   // constant term, so the matrix has one fewer column.
1094   unsigned dirsNumCols = getNumCols() - 1;
1095   Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
1096 
1097   // Copy the bounded inequalities.
1098   unsigned row = 0;
1099   for (unsigned i : boundedIneqs) {
1100     for (unsigned col = 0; col < dirsNumCols; ++col)
1101       dirs(row, col) = atIneq(i, col);
1102     ++row;
1103   }
1104 
1105   // Copy the equalities. All the equalities' perpendiculars are bounded.
1106   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1107     for (unsigned col = 0; col < dirsNumCols; ++col)
1108       dirs(row, col) = atEq(i, col);
1109     ++row;
1110   }
1111 
1112   return dirs;
1113 }
1114 
eqInvolvesSuffixDims(const FlatAffineConstraints & fac,unsigned eqIndex,unsigned numDims)1115 bool eqInvolvesSuffixDims(const FlatAffineConstraints &fac, unsigned eqIndex,
1116                           unsigned numDims) {
1117   for (unsigned e = fac.getNumIds(), j = e - numDims; j < e; ++j)
1118     if (fac.atEq(eqIndex, j) != 0)
1119       return true;
1120   return false;
1121 }
ineqInvolvesSuffixDims(const FlatAffineConstraints & fac,unsigned ineqIndex,unsigned numDims)1122 bool ineqInvolvesSuffixDims(const FlatAffineConstraints &fac,
1123                             unsigned ineqIndex, unsigned numDims) {
1124   for (unsigned e = fac.getNumIds(), j = e - numDims; j < e; ++j)
1125     if (fac.atIneq(ineqIndex, j) != 0)
1126       return true;
1127   return false;
1128 }
1129 
removeConstraintsInvolvingSuffixDims(FlatAffineConstraints & fac,unsigned unboundedDims)1130 void removeConstraintsInvolvingSuffixDims(FlatAffineConstraints &fac,
1131                                           unsigned unboundedDims) {
1132   // We iterate backwards so that whether we remove constraint i - 1 or not, the
1133   // next constraint to be tested is always i - 2.
1134   for (unsigned i = fac.getNumEqualities(); i > 0; i--)
1135     if (eqInvolvesSuffixDims(fac, i - 1, unboundedDims))
1136       fac.removeEquality(i - 1);
1137   for (unsigned i = fac.getNumInequalities(); i > 0; i--)
1138     if (ineqInvolvesSuffixDims(fac, i - 1, unboundedDims))
1139       fac.removeInequality(i - 1);
1140 }
1141 
isIntegerEmpty() const1142 bool FlatAffineConstraints::isIntegerEmpty() const {
1143   return !findIntegerSample().hasValue();
1144 }
1145 
/// Finds an integer point in this set, or returns an empty Optional if none is
/// found.
///
/// Let this set be S. If S is bounded then we directly call into the GBR
/// sampling algorithm. Otherwise, there are some unbounded directions, i.e.,
/// vectors v such that S extends to infinity along v or -v. In this case we
/// use an algorithm described in the integer set library (isl) manual and used
/// by the isl_set_sample function in that library. The algorithm is:
///
/// 1) Apply a unimodular transform T to S to obtain S*T, such that all
/// dimensions in which S*T is bounded lie in the linear span of a prefix of the
/// dimensions.
///
/// 2) Construct a set B by removing all constraints that involve
/// the unbounded dimensions and then deleting the unbounded dimensions. Note
/// that B is a Bounded set.
///
/// 3) Try to obtain a sample from B using the GBR sampling
/// algorithm. If no sample is found, return that S is empty.
///
/// 4) Otherwise, substitute the obtained sample into S*T to obtain a set
/// C. C is a full-dimensional Cone and always contains a sample.
///
/// 5) Obtain an integer sample from C.
///
/// 6) Return T*v, where v is the concatenation of the samples from B and C.
///
/// The following is a sketch of a proof that
/// a) If the algorithm returns empty, then S is empty.
/// b) If the algorithm returns a sample, it is a valid sample in S.
///
/// The algorithm returns empty only if B is empty, in which case S*T is
/// certainly empty since B was obtained by removing constraints and then
/// deleting unconstrained dimensions from S*T. Since T is unimodular, a vector
/// v is in S*T iff T*v is in S. So in this case, since
/// S*T is empty, S is empty too.
///
/// Otherwise, the algorithm substitutes the sample from B into S*T. All the
/// constraints of S*T that did not involve unbounded dimensions are satisfied
/// by this substitution. All dimensions in the linear span of the dimensions
/// outside the prefix are unbounded in S*T (step 1). Substituting values for
/// the bounded dimensions cannot make these dimensions bounded, and these are
/// the only remaining dimensions in C, so C is unbounded along every vector (in
/// the positive or negative direction, or both). C is hence a full-dimensional
/// cone and therefore always contains an integer point.
///
/// Concatenating the samples from B and C gives a sample v in S*T, so the
/// returned sample T*v is a sample in S.
Optional<SmallVector<int64_t, 8>>
FlatAffineConstraints::findIntegerSample() const {
  // First, try the GCD test heuristic.
  if (isEmptyByGCDTest())
    return {};

  Simplex simplex(*this);
  if (simplex.isEmpty())
    return {};

  // For a bounded set, we directly call into the GBR sampling algorithm.
  if (!simplex.isUnbounded())
    return simplex.findIntegerSample();

  // The set is unbounded. We cannot directly use the GBR algorithm.
  //
  // m is a matrix containing, in each row, a vector in which S is
  // bounded, such that the linear span of all these dimensions contains all
  // bounded dimensions in S.
  Matrix m = getBoundedDirections();
  // In column echelon form, each row of m occupies only the first rank(m)
  // columns and has zeros on the other columns. The transform T that brings S
  // to column echelon form is unimodular as well, so this is a suitable
  // transform to use in step 1 of the algorithm.
  //
  // `m` is moved into the call and must not be used afterwards. The first
  // element of the returned pair is used below as the number of bounded
  // dimensions, the second as the transform T.
  std::pair<unsigned, LinearTransform> result =
      LinearTransform::makeTransformToColumnEchelon(std::move(m));
  const LinearTransform &transform = result.second;
  // 1) Apply T to S to obtain S*T.
  FlatAffineConstraints transformedSet = transform.applyTo(*this);

  // 2) Remove the unbounded dimensions and constraints involving them to
  // obtain a bounded set.
  //
  // This works on a copy; transformedSet itself is still needed in step 4.
  FlatAffineConstraints boundedSet = transformedSet;
  unsigned numBoundedDims = result.first;
  unsigned numUnboundedDims = getNumIds() - numBoundedDims;
  removeConstraintsInvolvingSuffixDims(boundedSet, numUnboundedDims);
  boundedSet.removeIdRange(numBoundedDims, boundedSet.getNumIds());

  // 3) Try to obtain a sample from the bounded set.
  Optional<SmallVector<int64_t, 8>> boundedSample =
      Simplex(boundedSet).findIntegerSample();
  if (!boundedSample)
    return {};
  assert(boundedSet.containsPoint(*boundedSample) &&
         "Simplex returned an invalid sample!");

  // 4) Substitute the values of the bounded dimensions into S*T to obtain a
  // full-dimensional cone, which necessarily contains an integer sample.
  //
  // `cone` is just another name for transformedSet after the substitution; the
  // shrinking below modifies transformedSet in place.
  transformedSet.setAndEliminate(0, *boundedSample);
  FlatAffineConstraints &cone = transformedSet;

  // 5) Obtain an integer sample from the cone.
  //
  // We shrink the cone such that for any rational point in the shrunken cone,
  // rounding up each of the point's coordinates produces a point that still
  // lies in the original cone.
  //
  // Rounding up a point x adds a number e_i in [0, 1) to each coordinate x_i.
  // For each inequality sum_i a_i x_i + c >= 0 in the original cone, the
  // shrunken cone will have the inequality tightened by some amount s, such
  // that if x satisfies the shrunken cone's tightened inequality, then x + e
  // satisfies the original inequality, i.e.,
  //
  // sum_i a_i x_i + c + s >= 0 implies sum_i a_i (x_i + e_i) + c >= 0
  //
  // for any e_i values in [0, 1). In fact, we will handle the slightly more
  // general case where e_i can be in [0, 1]. For example, consider the
  // inequality 2x_1 - 3x_2 - 7x_3 - 6 >= 0, and let x = (3, 0, 0). How low
  // could the LHS go if we added a number in [0, 1] to each coordinate? The LHS
  // is minimized when we add 1 to the x_i with negative coefficient a_i and
  // keep the other x_i the same. In the example, we would get x = (3, 1, 1),
  // changing the value of the LHS by -3 + -7 = -10.
  //
  // In general, the value of the LHS can change by at most the sum of the
  // negative a_i, so we accommodate this by shifting the inequality by this
  // amount for the shrunken cone.
  //
  // Column `cone.numIds` is the column following the identifier coefficients,
  // i.e. the constant term that receives the shift.
  for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) {
    for (unsigned j = 0; j < cone.numIds; ++j) {
      int64_t coeff = cone.atIneq(i, j);
      if (coeff < 0)
        cone.atIneq(i, cone.numIds) += coeff;
    }
  }

  // Obtain an integer sample in the cone by rounding up a rational point from
  // the shrunken cone. Shrinking the cone amounts to shifting its apex
  // "inwards" without changing its "shape"; the shrunken cone is still a
  // full-dimensional cone and is hence non-empty.
  Simplex shrunkenConeSimplex(cone);
  assert(!shrunkenConeSimplex.isEmpty() && "Shrunken cone cannot be empty!");
  SmallVector<Fraction, 8> shrunkenConeSample =
      shrunkenConeSimplex.getRationalSample();

  // Round every rational coordinate up to an integer.
  SmallVector<int64_t, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil));

  // 6) Return transform * concat(boundedSample, coneSample).
  //
  // `sample` refers to the vector held by boundedSample, so the cone sample is
  // appended to the bounded sample in place.
  SmallVector<int64_t, 8> &sample = boundedSample.getValue();
  sample.append(coneSample.begin(), coneSample.end());
  return transform.preMultiplyColumn(sample);
}
1291 
1292 /// Helper to evaluate an affine expression at a point.
1293 /// The expression is a list of coefficients for the dimensions followed by the
1294 /// constant term.
valueAt(ArrayRef<int64_t> expr,ArrayRef<int64_t> point)1295 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
1296   assert(expr.size() == 1 + point.size() &&
1297          "Dimensionalities of point and expression don't match!");
1298   int64_t value = expr.back();
1299   for (unsigned i = 0; i < point.size(); ++i)
1300     value += expr[i] * point[i];
1301   return value;
1302 }
1303 
1304 /// A point satisfies an equality iff the value of the equality at the
1305 /// expression is zero, and it satisfies an inequality iff the value of the
1306 /// inequality at that point is non-negative.
containsPoint(ArrayRef<int64_t> point) const1307 bool FlatAffineConstraints::containsPoint(ArrayRef<int64_t> point) const {
1308   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1309     if (valueAt(getEquality(i), point) != 0)
1310       return false;
1311   }
1312   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1313     if (valueAt(getInequality(i), point) < 0)
1314       return false;
1315   }
1316   return true;
1317 }
1318 
1319 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
1320 /// function of other identifiers (where the divisor is a positive constant),
1321 /// `foundRepr` contains a boolean for each identifier indicating if the
1322 /// explicit representation for that identifier has already been computed.
1323 static Optional<std::pair<unsigned, unsigned>>
computeSingleVarRepr(const FlatAffineConstraints & cst,const SmallVector<bool,8> & foundRepr,unsigned pos)1324 computeSingleVarRepr(const FlatAffineConstraints &cst,
1325                      const SmallVector<bool, 8> &foundRepr, unsigned pos) {
1326   assert(pos < cst.getNumIds() && "invalid position");
1327   assert(foundRepr.size() == cst.getNumIds() &&
1328          "Size of foundRepr does not match total number of variables");
1329 
1330   SmallVector<unsigned, 4> lbIndices, ubIndices;
1331   cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices);
1332 
1333   // `id` is equivalent to `expr floordiv divisor` if there
1334   // are constraints of the form:
1335   //      0 <= expr - divisor * id <= divisor - 1
1336   // Rearranging, we have:
1337   //       divisor * id - expr + (divisor - 1) >= 0  <-- Lower bound for 'id'
1338   //      -divisor * id + expr                 >= 0  <-- Upper bound for 'id'
1339   //
1340   // For example:
1341   //       32*k >= 16*i + j - 31                 <-- Lower bound for 'k'
1342   //       32*k  <= 16*i + j                     <-- Upper bound for 'k'
1343   //       expr = 16*i + j, divisor = 32
1344   //       k = ( 16*i + j ) floordiv 32
1345   //
1346   //       4q >= i + j - 2                       <-- Lower bound for 'q'
1347   //       4q <= i + j + 1                       <-- Upper bound for 'q'
1348   //       expr = i + j + 1, divisor = 4
1349   //       q = (i + j + 1) floordiv 4
1350   for (unsigned ubPos : ubIndices) {
1351     for (unsigned lbPos : lbIndices) {
1352       // Due to the form of the inequalities, sum of constants of the
1353       // inequalities is (divisor - 1).
1354       int64_t divisor = cst.atIneq(lbPos, cst.getNumCols() - 1) +
1355                         cst.atIneq(ubPos, cst.getNumCols() - 1) + 1;
1356 
1357       // Divisor should be positive.
1358       if (divisor <= 0)
1359         continue;
1360 
1361       // Check if coeff of variable is equal to divisor.
1362       if (divisor != cst.atIneq(lbPos, pos))
1363         continue;
1364 
1365       // Check if constraints are opposite of each other. Constant term
1366       // is not required to be opposite and is not checked.
1367       unsigned c = 0, f = 0;
1368       for (c = 0, f = cst.getNumIds(); c < f; ++c)
1369         if (cst.atIneq(ubPos, c) != -cst.atIneq(lbPos, c))
1370           break;
1371 
1372       if (c < f)
1373         continue;
1374 
1375       // Check if the inequalities depend on a variable for which
1376       // an explicit representation has not been found yet.
1377       // Exit to avoid circular dependencies between divisions.
1378       for (c = 0, f = cst.getNumIds(); c < f; ++c) {
1379         if (c == pos)
1380           continue;
1381         if (!foundRepr[c] && cst.atIneq(lbPos, c) != 0)
1382           break;
1383       }
1384 
1385       // Expression can't be constructed as it depends on a yet unknown
1386       // identifier.
1387       // TODO: Visit/compute the identifiers in an order so that this doesn't
1388       // happen. More complex but much more efficient.
1389       if (c < f)
1390         continue;
1391 
1392       return std::make_pair(ubPos, lbPos);
1393     }
1394   }
1395 
1396   return llvm::None;
1397 }
1398 
1399 /// Find pairs of inequalities identified by their position indices, using
1400 /// which an explicit representation for each local variable can be computed
1401 /// The pairs are stored as indices of upperbound, lowerbound
1402 /// inequalities. If no such pair can be found, it is stored as llvm::None.
getLocalReprLbUbPairs(std::vector<llvm::Optional<std::pair<unsigned,unsigned>>> & repr) const1403 void FlatAffineConstraints::getLocalReprLbUbPairs(
1404     std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> &repr) const {
1405   assert(repr.size() == getNumLocalIds() &&
1406          "Size of repr does not match number of local variables");
1407 
1408   SmallVector<bool, 8> foundRepr(getNumIds(), false);
1409   for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i)
1410     foundRepr[i] = true;
1411 
1412   unsigned divOffset = getNumDimAndSymbolIds();
1413   bool changed;
1414   do {
1415     // Each time changed is true, at end of this iteration, one or more local
1416     // vars have been detected as floor divs.
1417     changed = false;
1418     for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) {
1419       if (!foundRepr[i + divOffset]) {
1420         if (auto res = computeSingleVarRepr(*this, foundRepr, divOffset + i)) {
1421           foundRepr[i + divOffset] = true;
1422           repr[i] = res;
1423           changed = true;
1424         }
1425       }
1426     }
1427   } while (changed);
1428 }
1429 
1430 /// Tightens inequalities given that we are dealing with integer spaces. This is
1431 /// analogous to the GCD test but applied to inequalities. The constant term can
1432 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
1433 ///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
1434 /// fast method - linear in the number of coefficients.
1435 // Example on how this affects practical cases: consider the scenario:
1436 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
1437 // j >= 100 instead of the tighter (exact) j >= 128.
gcdTightenInequalities()1438 void FlatAffineConstraints::gcdTightenInequalities() {
1439   unsigned numCols = getNumCols();
1440   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1441     uint64_t gcd = std::abs(atIneq(i, 0));
1442     for (unsigned j = 1; j < numCols - 1; ++j) {
1443       gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j)));
1444     }
1445     if (gcd > 0 && gcd != 1) {
1446       int64_t gcdI = static_cast<int64_t>(gcd);
1447       // Tighten the constant term and normalize the constraint by the GCD.
1448       atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI);
1449       for (unsigned j = 0, e = numCols - 1; j < e; ++j)
1450         atIneq(i, j) /= gcdI;
1451     }
1452   }
1453 }
1454 
1455 // Eliminates all identifier variables in column range [posStart, posLimit).
1456 // Returns the number of variables eliminated.
gaussianEliminateIds(unsigned posStart,unsigned posLimit)1457 unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
1458                                                      unsigned posLimit) {
1459   // Return if identifier positions to eliminate are out of range.
1460   assert(posLimit <= numIds);
1461   assert(hasConsistentState());
1462 
1463   if (posStart >= posLimit)
1464     return 0;
1465 
1466   gcdTightenInequalities();
1467 
1468   unsigned pivotCol = 0;
1469   for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
1470     // Find a row which has a non-zero coefficient in column 'j'.
1471     unsigned pivotRow;
1472     if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true,
1473                                      &pivotRow)) {
1474       // No pivot row in equalities with non-zero at 'pivotCol'.
1475       if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false,
1476                                        &pivotRow)) {
1477         // If inequalities are also non-zero in 'pivotCol', it can be
1478         // eliminated.
1479         continue;
1480       }
1481       break;
1482     }
1483 
1484     // Eliminate identifier at 'pivotCol' from each equality row.
1485     for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1486       eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1487                               /*isEq=*/true);
1488       normalizeConstraintByGCD</*isEq=*/true>(this, i);
1489     }
1490 
1491     // Eliminate identifier at 'pivotCol' from each inequality row.
1492     for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1493       eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1494                               /*isEq=*/false);
1495       normalizeConstraintByGCD</*isEq=*/false>(this, i);
1496     }
1497     removeEquality(pivotRow);
1498     gcdTightenInequalities();
1499   }
1500   // Update position limit based on number eliminated.
1501   posLimit = pivotCol;
1502   // Remove eliminated columns from all constraints.
1503   removeIdRange(posStart, posLimit);
1504   return posLimit - posStart;
1505 }
1506 
1507 // Determine whether the identifier at 'pos' (say id_r) can be expressed as
1508 // modulo of another known identifier (say id_n) w.r.t a constant. For example,
1509 // if the following constraints hold true:
1510 // ```
1511 // 0 <= id_r <= divisor - 1
1512 // id_n - (divisor * q_expr) = id_r
1513 // ```
1514 // where `id_n` is a known identifier (called dividend), and `q_expr` is an
1515 // `AffineExpr` (called the quotient expression), `id_r` can be written as:
1516 //
1517 // `id_r = id_n mod divisor`.
1518 //
1519 // Additionally, in a special case of the above constaints where `q_expr` is an
1520 // identifier itself that is not yet known (say `id_q`), it can be written as a
1521 // floordiv in the following way:
1522 //
1523 // `id_q = id_n floordiv divisor`.
1524 //
1525 // Returns true if the above mod or floordiv are detected, updating 'memo' with
1526 // these new expressions. Returns false otherwise.
detectAsMod(const FlatAffineConstraints & cst,unsigned pos,int64_t lbConst,int64_t ubConst,SmallVectorImpl<AffineExpr> & memo,MLIRContext * context)1527 static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
1528                         int64_t lbConst, int64_t ubConst,
1529                         SmallVectorImpl<AffineExpr> &memo,
1530                         MLIRContext *context) {
1531   assert(pos < cst.getNumIds() && "invalid position");
1532 
1533   // Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can
1534   // be determined.
1535   if (lbConst != 0 || ubConst < 1)
1536     return false;
1537   int64_t divisor = ubConst + 1;
1538 
1539   // Check for the aforementioned conditions in each equality.
1540   for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
1541        curEquality < numEqualities; curEquality++) {
1542     int64_t coefficientAtPos = cst.atEq(curEquality, pos);
1543     // If current equality does not involve `id_r`, continue to the next
1544     // equality.
1545     if (coefficientAtPos == 0)
1546       continue;
1547 
1548     // Constant term should be 0 in this equality.
1549     if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
1550       continue;
1551 
1552     // Traverse through the equality and construct the dividend expression
1553     // `dividendExpr`, to contain all the identifiers which are known and are
1554     // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
1555     // `dividendExpr` gets simplified into a single identifier `id_n` discussed
1556     // above.
1557     auto dividendExpr = getAffineConstantExpr(0, context);
1558 
1559     // Track the terms that go into quotient expression, later used to detect
1560     // additional floordiv.
1561     unsigned quotientCount = 0;
1562     int quotientPosition = -1;
1563     int quotientSign = 1;
1564 
1565     // Consider each term in the current equality.
1566     unsigned curId, e;
1567     for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) {
1568       // Ignore id_r.
1569       if (curId == pos)
1570         continue;
1571       int64_t coefficientOfCurId = cst.atEq(curEquality, curId);
1572       // Ignore ids that do not contribute to the current equality.
1573       if (coefficientOfCurId == 0)
1574         continue;
1575       // Check if the current id goes into the quotient expression.
1576       if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) {
1577         quotientCount++;
1578         quotientPosition = curId;
1579         quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1;
1580         continue;
1581       }
1582       // Identifiers that are part of dividendExpr should be known.
1583       if (!memo[curId])
1584         break;
1585       // Append the current identifier to the dividend expression.
1586       dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId;
1587     }
1588 
1589     // Can't construct expression as it depends on a yet uncomputed id.
1590     if (curId < e)
1591       continue;
1592 
1593     // Express `id_r` in terms of the other ids collected so far.
1594     if (coefficientAtPos > 0)
1595       dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
1596     else
1597       dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
1598 
1599     // Simplify the expression.
1600     dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(),
1601                                       cst.getNumSymbolIds());
1602     // Only if the final dividend expression is just a single id (which we call
1603     // `id_n`), we can proceed.
1604     // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
1605     // to dims themselves.
1606     auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
1607     if (!dimExpr)
1608       continue;
1609 
1610     // Express `id_r` as `id_n % divisor` and store the expression in `memo`.
1611     if (quotientCount >= 1) {
1612       auto ub = cst.getConstantBound(FlatAffineConstraints::BoundType::UB,
1613                                      dimExpr.getPosition());
1614       // If `id_n` has an upperbound that is less than the divisor, mod can be
1615       // eliminated altogether.
1616       if (ub.hasValue() && ub.getValue() < divisor)
1617         memo[pos] = dimExpr;
1618       else
1619         memo[pos] = dimExpr % divisor;
1620       // If a unique quotient `id_q` was seen, it can be expressed as
1621       // `id_n floordiv divisor`.
1622       if (quotientCount == 1 && !memo[quotientPosition])
1623         memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
1624 
1625       return true;
1626     }
1627   }
1628   return false;
1629 }
1630 
1631 /// Gather all lower and upper bounds of the identifier at `pos`, and
1632 /// optionally any equalities on it. In addition, the bounds are to be
1633 /// independent of identifiers in position range [`offset`, `offset` + `num`).
getLowerAndUpperBoundIndices(unsigned pos,SmallVectorImpl<unsigned> * lbIndices,SmallVectorImpl<unsigned> * ubIndices,SmallVectorImpl<unsigned> * eqIndices,unsigned offset,unsigned num) const1634 void FlatAffineConstraints::getLowerAndUpperBoundIndices(
1635     unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
1636     SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
1637     unsigned offset, unsigned num) const {
1638   assert(pos < getNumIds() && "invalid position");
1639   assert(offset + num < getNumCols() && "invalid range");
1640 
1641   // Checks for a constraint that has a non-zero coeff for the identifiers in
1642   // the position range [offset, offset + num) while ignoring `pos`.
1643   auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
1644     unsigned c, f;
1645     auto cst = isEq ? getEquality(r) : getInequality(r);
1646     for (c = offset, f = offset + num; c < f; ++c) {
1647       if (c == pos)
1648         continue;
1649       if (cst[c] != 0)
1650         break;
1651     }
1652     return c < f;
1653   };
1654 
1655   // Gather all lower bounds and upper bounds of the variable. Since the
1656   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1657   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1658   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1659     // The bounds are to be independent of [offset, offset + num) columns.
1660     if (containsConstraintDependentOnRange(r, /*isEq=*/false))
1661       continue;
1662     if (atIneq(r, pos) >= 1) {
1663       // Lower bound.
1664       lbIndices->push_back(r);
1665     } else if (atIneq(r, pos) <= -1) {
1666       // Upper bound.
1667       ubIndices->push_back(r);
1668     }
1669   }
1670 
1671   // An equality is both a lower and upper bound. Record any equalities
1672   // involving the pos^th identifier.
1673   if (!eqIndices)
1674     return;
1675 
1676   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1677     if (atEq(r, pos) == 0)
1678       continue;
1679     if (containsConstraintDependentOnRange(r, /*isEq=*/true))
1680       continue;
1681     eqIndices->push_back(r);
1682   }
1683 }
1684 
1685 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
1686 /// function of other identifiers (where the divisor is a positive constant)
1687 /// given the initial set of expressions in `exprs`. If it can be, the
1688 /// corresponding position in `exprs` is set as the detected affine expr. For
1689 /// eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4. An equality can
1690 /// also yield a floordiv: eg.  4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
1691 /// <= i <= 32q + 31 => q = i floordiv 32.
detectAsFloorDiv(const FlatAffineConstraints & cst,unsigned pos,MLIRContext * context,SmallVectorImpl<AffineExpr> & exprs)1692 static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
1693                              MLIRContext *context,
1694                              SmallVectorImpl<AffineExpr> &exprs) {
1695   assert(pos < cst.getNumIds() && "invalid position");
1696 
1697   // Get upper-lower bound pair for this variable.
1698   SmallVector<bool, 8> foundRepr(cst.getNumIds(), false);
1699   for (unsigned i = 0, e = cst.getNumIds(); i < e; ++i)
1700     if (exprs[i])
1701       foundRepr[i] = true;
1702 
1703   auto ulPair = computeSingleVarRepr(cst, foundRepr, pos);
1704 
1705   // No upper-lower bound pair found for this var.
1706   if (!ulPair)
1707     return false;
1708 
1709   unsigned ubPos = ulPair->first;
1710 
1711   // Upper bound is of the form:
1712   //      -divisor * id + expr >= 0
1713   // where `id` is equivalent to `expr floordiv divisor`.
1714   //
1715   // Since the division cannot be dependent on itself, the coefficient of
1716   // of `id` in `expr` is zero. The coefficient of `id` in the upperbound
1717   // is -divisor.
1718   int64_t divisor = -cst.atIneq(ubPos, pos);
1719   int64_t constantTerm = cst.atIneq(ubPos, cst.getNumCols() - 1);
1720 
1721   // Construct the dividend expression.
1722   auto dividendExpr = getAffineConstantExpr(constantTerm, context);
1723   unsigned c, f;
1724   for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1725     if (c == pos)
1726       continue;
1727     int64_t ubVal = cst.atIneq(ubPos, c);
1728     if (ubVal == 0)
1729       continue;
1730     // computeSingleVarRepr guarantees that expr is known here.
1731     dividendExpr = dividendExpr + ubVal * exprs[c];
1732   }
1733 
1734   // Successfully detected the floordiv.
1735   exprs[pos] = dividendExpr.floorDiv(divisor);
1736   return true;
1737 }
1738 
1739 // Fills an inequality row with the value 'val'.
fillInequality(FlatAffineConstraints * cst,unsigned r,int64_t val)1740 static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
1741                                   int64_t val) {
1742   for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1743     cst->atIneq(r, c) = val;
1744   }
1745 }
1746 
1747 // Negates an inequality.
negateInequality(FlatAffineConstraints * cst,unsigned r)1748 static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
1749   for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1750     cst->atIneq(r, c) = -cst->atIneq(r, c);
1751   }
1752 }
1753 
1754 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
1755 // to check if a constraint is redundant.
removeRedundantInequalities()1756 void FlatAffineConstraints::removeRedundantInequalities() {
1757   SmallVector<bool, 32> redun(getNumInequalities(), false);
1758   // To check if an inequality is redundant, we replace the inequality by its
1759   // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
1760   // system is empty. If it is, the inequality is redundant.
1761   FlatAffineConstraints tmpCst(*this);
1762   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1763     // Change the inequality to its complement.
1764     negateInequality(&tmpCst, r);
1765     tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
1766     if (tmpCst.isEmpty()) {
1767       redun[r] = true;
1768       // Zero fill the redundant inequality.
1769       fillInequality(this, r, /*val=*/0);
1770       fillInequality(&tmpCst, r, /*val=*/0);
1771     } else {
1772       // Reverse the change (to avoid recreating tmpCst each time).
1773       tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
1774       negateInequality(&tmpCst, r);
1775     }
1776   }
1777 
1778   // Scan to get rid of all rows marked redundant, in-place.
1779   auto copyRow = [&](unsigned src, unsigned dest) {
1780     if (src == dest)
1781       return;
1782     for (unsigned c = 0, e = getNumCols(); c < e; c++) {
1783       atIneq(dest, c) = atIneq(src, c);
1784     }
1785   };
1786   unsigned pos = 0;
1787   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1788     if (!redun[r])
1789       copyRow(r, pos++);
1790   }
1791   inequalities.resizeVertically(pos);
1792 }
1793 
1794 // A more complex check to eliminate redundant inequalities and equalities. Uses
1795 // Simplex to check if a constraint is redundant.
removeRedundantConstraints()1796 void FlatAffineConstraints::removeRedundantConstraints() {
1797   // First, we run gcdTightenInequalities. This allows us to catch some
1798   // constraints which are not redundant when considering rational solutions
1799   // but are redundant in terms of integer solutions.
1800   gcdTightenInequalities();
1801   Simplex simplex(*this);
1802   simplex.detectRedundant();
1803 
1804   auto copyInequality = [&](unsigned src, unsigned dest) {
1805     if (src == dest)
1806       return;
1807     for (unsigned c = 0, e = getNumCols(); c < e; c++)
1808       atIneq(dest, c) = atIneq(src, c);
1809   };
1810   unsigned pos = 0;
1811   unsigned numIneqs = getNumInequalities();
1812   // Scan to get rid of all inequalities marked redundant, in-place. In Simplex,
1813   // the first constraints added are the inequalities.
1814   for (unsigned r = 0; r < numIneqs; r++) {
1815     if (!simplex.isMarkedRedundant(r))
1816       copyInequality(r, pos++);
1817   }
1818   inequalities.resizeVertically(pos);
1819 
1820   // Scan to get rid of all equalities marked redundant, in-place. In Simplex,
1821   // after the inequalities, a pair of constraints for each equality is added.
1822   // An equality is redundant if both the inequalities in its pair are
1823   // redundant.
1824   auto copyEquality = [&](unsigned src, unsigned dest) {
1825     if (src == dest)
1826       return;
1827     for (unsigned c = 0, e = getNumCols(); c < e; c++)
1828       atEq(dest, c) = atEq(src, c);
1829   };
1830   pos = 0;
1831   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1832     if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) &&
1833           simplex.isMarkedRedundant(numIneqs + 2 * r + 1)))
1834       copyEquality(r, pos++);
1835   }
1836   equalities.resizeVertically(pos);
1837 }
1838 
1839 /// Merge local ids of `this` and `other`. This is done by appending local ids
1840 /// of `other` to `this` and inserting local ids of `this` to `other` at start
1841 /// of its local ids.
mergeLocalIds(FlatAffineConstraints & other)1842 void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) {
1843   unsigned initLocals = getNumLocalIds();
1844   insertLocalId(getNumLocalIds(), other.getNumLocalIds());
1845   other.insertLocalId(0, initLocals);
1846 }
1847 
1848 /// Removes local variables using equalities. Each equality is checked if it
1849 /// can be reduced to the form: `e = affine-expr`, where `e` is a local
1850 /// variable and `affine-expr` is an affine expression not containing `e`.
1851 /// If an equality satisfies this form, the local variable is replaced in
1852 /// each constraint and then removed. The equality used to replace this local
1853 /// variable is also removed.
removeRedundantLocalVars()1854 void FlatAffineConstraints::removeRedundantLocalVars() {
1855   // Normalize the equality constraints to reduce coefficients of local
1856   // variables to 1 wherever possible.
1857   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
1858     normalizeConstraintByGCD</*isEq=*/true>(this, i);
1859 
1860   while (true) {
1861     unsigned i, e, j, f;
1862     for (i = 0, e = getNumEqualities(); i < e; ++i) {
1863       // Find a local variable to eliminate using ith equality.
1864       for (j = getNumDimAndSymbolIds(), f = getNumIds(); j < f; ++j)
1865         if (std::abs(atEq(i, j)) == 1)
1866           break;
1867 
1868       // Local variable can be eliminated using ith equality.
1869       if (j < f)
1870         break;
1871     }
1872 
1873     // No equality can be used to eliminate a local variable.
1874     if (i == e)
1875       break;
1876 
1877     // Use the ith equality to simplify other equalities. If any changes
1878     // are made to an equality constraint, it is normalized by GCD.
1879     for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) {
1880       if (atEq(k, j) != 0) {
1881         eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true);
1882         normalizeConstraintByGCD</*isEq=*/true>(this, k);
1883       }
1884     }
1885 
1886     // Use the ith equality to simplify inequalities.
1887     for (unsigned k = 0, t = getNumInequalities(); k < t; ++k)
1888       eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false);
1889 
1890     // Remove the ith equality and the found local variable.
1891     removeId(j);
1892     removeEquality(i);
1893   }
1894 }
1895 
getLowerAndUpperBound(unsigned pos,unsigned offset,unsigned num,unsigned symStartPos,ArrayRef<AffineExpr> localExprs,MLIRContext * context) const1896 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
1897     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
1898     ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
1899   assert(pos + offset < getNumDimIds() && "invalid dim start pos");
1900   assert(symStartPos >= (pos + offset) && "invalid sym start pos");
1901   assert(getNumLocalIds() == localExprs.size() &&
1902          "incorrect local exprs count");
1903 
1904   SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
1905   getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
1906                                offset, num);
1907 
1908   /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
1909   auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
1910     b.clear();
1911     for (unsigned i = 0, e = a.size(); i < e; ++i) {
1912       if (i < offset || i >= offset + num)
1913         b.push_back(a[i]);
1914     }
1915   };
1916 
1917   SmallVector<int64_t, 8> lb, ub;
1918   SmallVector<AffineExpr, 4> lbExprs;
1919   unsigned dimCount = symStartPos - num;
1920   unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
1921   lbExprs.reserve(lbIndices.size() + eqIndices.size());
1922   // Lower bound expressions.
1923   for (auto idx : lbIndices) {
1924     auto ineq = getInequality(idx);
1925     // Extract the lower bound (in terms of other coeff's + const), i.e., if
1926     // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
1927     // - 1.
1928     addCoeffs(ineq, lb);
1929     std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
1930     auto expr =
1931         getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
1932     // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
1933     int64_t divisor = std::abs(ineq[pos + offset]);
1934     expr = (expr + divisor - 1).floorDiv(divisor);
1935     lbExprs.push_back(expr);
1936   }
1937 
1938   SmallVector<AffineExpr, 4> ubExprs;
1939   ubExprs.reserve(ubIndices.size() + eqIndices.size());
1940   // Upper bound expressions.
1941   for (auto idx : ubIndices) {
1942     auto ineq = getInequality(idx);
1943     // Extract the upper bound (in terms of other coeff's + const).
1944     addCoeffs(ineq, ub);
1945     auto expr =
1946         getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
1947     expr = expr.floorDiv(std::abs(ineq[pos + offset]));
1948     // Upper bound is exclusive.
1949     ubExprs.push_back(expr + 1);
1950   }
1951 
1952   // Equalities. It's both a lower and a upper bound.
1953   SmallVector<int64_t, 4> b;
1954   for (auto idx : eqIndices) {
1955     auto eq = getEquality(idx);
1956     addCoeffs(eq, b);
1957     if (eq[pos + offset] > 0)
1958       std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
1959 
1960     // Extract the upper bound (in terms of other coeff's + const).
1961     auto expr =
1962         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1963     expr = expr.floorDiv(std::abs(eq[pos + offset]));
1964     // Upper bound is exclusive.
1965     ubExprs.push_back(expr + 1);
1966     // Lower bound.
1967     expr =
1968         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1969     expr = expr.ceilDiv(std::abs(eq[pos + offset]));
1970     lbExprs.push_back(expr);
1971   }
1972 
1973   auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
1974   auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
1975 
1976   return {lbMap, ubMap};
1977 }
1978 
1979 /// Computes the lower and upper bounds of the first 'num' dimensional
1980 /// identifiers (starting at 'offset') as affine maps of the remaining
1981 /// identifiers (dimensional and symbolic identifiers). Local identifiers are
1982 /// themselves explicitly computed as affine functions of other identifiers in
1983 /// this process if needed.
getSliceBounds(unsigned offset,unsigned num,MLIRContext * context,SmallVectorImpl<AffineMap> * lbMaps,SmallVectorImpl<AffineMap> * ubMaps)1984 void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
1985                                            MLIRContext *context,
1986                                            SmallVectorImpl<AffineMap> *lbMaps,
1987                                            SmallVectorImpl<AffineMap> *ubMaps) {
1988   assert(num < getNumDimIds() && "invalid range");
1989 
1990   // Basic simplification.
1991   normalizeConstraintsByGCD();
1992 
1993   LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
1994                           << " identifiers\n");
1995   LLVM_DEBUG(dump());
1996 
1997   // Record computed/detected identifiers.
1998   SmallVector<AffineExpr, 8> memo(getNumIds());
1999   // Initialize dimensional and symbolic identifiers.
2000   for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
2001     if (i < offset)
2002       memo[i] = getAffineDimExpr(i, context);
2003     else if (i >= offset + num)
2004       memo[i] = getAffineDimExpr(i - num, context);
2005   }
2006   for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
2007     memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
2008 
2009   bool changed;
2010   do {
2011     changed = false;
2012     // Identify yet unknown identifiers as constants or mod's / floordiv's of
2013     // other identifiers if possible.
2014     for (unsigned pos = 0; pos < getNumIds(); pos++) {
2015       if (memo[pos])
2016         continue;
2017 
2018       auto lbConst = getConstantBound(BoundType::LB, pos);
2019       auto ubConst = getConstantBound(BoundType::UB, pos);
2020       if (lbConst.hasValue() && ubConst.hasValue()) {
2021         // Detect equality to a constant.
2022         if (lbConst.getValue() == ubConst.getValue()) {
2023           memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
2024           changed = true;
2025           continue;
2026         }
2027 
2028         // Detect an identifier as modulo of another identifier w.r.t a
2029         // constant.
2030         if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
2031                         memo, context)) {
2032           changed = true;
2033           continue;
2034         }
2035       }
2036 
2037       // Detect an identifier as a floordiv of an affine function of other
2038       // identifiers (divisor is a positive constant).
2039       if (detectAsFloorDiv(*this, pos, context, memo)) {
2040         changed = true;
2041         continue;
2042       }
2043 
2044       // Detect an identifier as an expression of other identifiers.
2045       unsigned idx;
2046       if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) {
2047         continue;
2048       }
2049 
2050       // Build AffineExpr solving for identifier 'pos' in terms of all others.
2051       auto expr = getAffineConstantExpr(0, context);
2052       unsigned j, e;
2053       for (j = 0, e = getNumIds(); j < e; ++j) {
2054         if (j == pos)
2055           continue;
2056         int64_t c = atEq(idx, j);
2057         if (c == 0)
2058           continue;
2059         // If any of the involved IDs hasn't been found yet, we can't proceed.
2060         if (!memo[j])
2061           break;
2062         expr = expr + memo[j] * c;
2063       }
2064       if (j < e)
2065         // Can't construct expression as it depends on a yet uncomputed
2066         // identifier.
2067         continue;
2068 
2069       // Add constant term to AffineExpr.
2070       expr = expr + atEq(idx, getNumIds());
2071       int64_t vPos = atEq(idx, pos);
2072       assert(vPos != 0 && "expected non-zero here");
2073       if (vPos > 0)
2074         expr = (-expr).floorDiv(vPos);
2075       else
2076         // vPos < 0.
2077         expr = expr.floorDiv(-vPos);
2078       // Successfully constructed expression.
2079       memo[pos] = expr;
2080       changed = true;
2081     }
2082     // This loop is guaranteed to reach a fixed point - since once an
2083     // identifier's explicit form is computed (in memo[pos]), it's not updated
2084     // again.
2085   } while (changed);
2086 
2087   // Set the lower and upper bound maps for all the identifiers that were
2088   // computed as affine expressions of the rest as the "detected expr" and
2089   // "detected expr + 1" respectively; set the undetected ones to null.
2090   Optional<FlatAffineConstraints> tmpClone;
2091   for (unsigned pos = 0; pos < num; pos++) {
2092     unsigned numMapDims = getNumDimIds() - num;
2093     unsigned numMapSymbols = getNumSymbolIds();
2094     AffineExpr expr = memo[pos + offset];
2095     if (expr)
2096       expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
2097 
2098     AffineMap &lbMap = (*lbMaps)[pos];
2099     AffineMap &ubMap = (*ubMaps)[pos];
2100 
2101     if (expr) {
2102       lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
2103       ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
2104     } else {
2105       // TODO: Whenever there are local identifiers in the dependence
2106       // constraints, we'll conservatively over-approximate, since we don't
2107       // always explicitly compute them above (in the while loop).
2108       if (getNumLocalIds() == 0) {
2109         // Work on a copy so that we don't update this constraint system.
2110         if (!tmpClone) {
2111           tmpClone.emplace(FlatAffineConstraints(*this));
2112           // Removing redundant inequalities is necessary so that we don't get
2113           // redundant loop bounds.
2114           tmpClone->removeRedundantInequalities();
2115         }
2116         std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
2117             pos, offset, num, getNumDimIds(), /*localExprs=*/{}, context);
2118       }
2119 
2120       // If the above fails, we'll just use the constant lower bound and the
2121       // constant upper bound (if they exist) as the slice bounds.
2122       // TODO: being conservative for the moment in cases that
2123       // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
2124       // fixed (b/126426796).
2125       if (!lbMap || lbMap.getNumResults() > 1) {
2126         LLVM_DEBUG(llvm::dbgs()
2127                    << "WARNING: Potentially over-approximating slice lb\n");
2128         auto lbConst = getConstantBound(BoundType::LB, pos + offset);
2129         if (lbConst.hasValue()) {
2130           lbMap = AffineMap::get(
2131               numMapDims, numMapSymbols,
2132               getAffineConstantExpr(lbConst.getValue(), context));
2133         }
2134       }
2135       if (!ubMap || ubMap.getNumResults() > 1) {
2136         LLVM_DEBUG(llvm::dbgs()
2137                    << "WARNING: Potentially over-approximating slice ub\n");
2138         auto ubConst = getConstantBound(BoundType::UB, pos + offset);
2139         if (ubConst.hasValue()) {
2140           (ubMap) = AffineMap::get(
2141               numMapDims, numMapSymbols,
2142               getAffineConstantExpr(ubConst.getValue() + 1, context));
2143         }
2144       }
2145     }
2146     LLVM_DEBUG(llvm::dbgs()
2147                << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
2148     LLVM_DEBUG(lbMap.dump(););
2149     LLVM_DEBUG(llvm::dbgs()
2150                << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
2151     LLVM_DEBUG(ubMap.dump(););
2152   }
2153 }
2154 
flattenAlignedMapAndMergeLocals(AffineMap map,std::vector<SmallVector<int64_t,8>> * flattenedExprs)2155 LogicalResult FlatAffineConstraints::flattenAlignedMapAndMergeLocals(
2156     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
2157   FlatAffineConstraints localCst;
2158   if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
2159     LLVM_DEBUG(llvm::dbgs()
2160                << "composition unimplemented for semi-affine maps\n");
2161     return failure();
2162   }
2163 
2164   // Add localCst information.
2165   if (localCst.getNumLocalIds() > 0) {
2166     unsigned numLocalIds = getNumLocalIds();
2167     // Insert local dims of localCst at the beginning.
2168     insertLocalId(/*pos=*/0, /*num=*/localCst.getNumLocalIds());
2169     // Insert local dims of `this` at the end of localCst.
2170     localCst.appendLocalId(/*num=*/numLocalIds);
2171     // Dimensions of localCst and this constraint set match. Append localCst to
2172     // this constraint set.
2173     append(localCst);
2174   }
2175 
2176   return success();
2177 }
2178 
addBound(BoundType type,unsigned pos,AffineMap boundMap)2179 LogicalResult FlatAffineConstraints::addBound(BoundType type, unsigned pos,
2180                                               AffineMap boundMap) {
2181   assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
2182   assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
2183   assert(pos < getNumDimAndSymbolIds() && "invalid position");
2184 
2185   // Equality follows the logic of lower bound except that we add an equality
2186   // instead of an inequality.
2187   assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
2188          "single result expected");
2189   bool lower = type == BoundType::LB || type == BoundType::EQ;
2190 
2191   std::vector<SmallVector<int64_t, 8>> flatExprs;
2192   if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
2193     return failure();
2194   assert(flatExprs.size() == boundMap.getNumResults());
2195 
2196   // Add one (in)equality for each result.
2197   for (const auto &flatExpr : flatExprs) {
2198     SmallVector<int64_t> ineq(getNumCols(), 0);
2199     // Dims and symbols.
2200     for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
2201       ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
2202     }
2203     // Invalid bound: pos appears in `boundMap`.
2204     // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
2205     // its callers to prevent invalid bounds from being added.
2206     if (ineq[pos] != 0)
2207       continue;
2208     ineq[pos] = lower ? 1 : -1;
2209     // Local columns of `ineq` are at the beginning.
2210     unsigned j = getNumDimIds() + getNumSymbolIds();
2211     unsigned end = flatExpr.size() - 1;
2212     for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
2213       ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
2214     }
2215     // Constant term.
2216     ineq[getNumCols() - 1] =
2217         lower ? -flatExpr[flatExpr.size() - 1]
2218               // Upper bound in flattenedExpr is an exclusive one.
2219               : flatExpr[flatExpr.size() - 1] - 1;
2220     type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq);
2221   }
2222 
2223   return success();
2224 }
2225 
2226 AffineMap
computeAlignedMap(AffineMap map,ValueRange operands) const2227 FlatAffineValueConstraints::computeAlignedMap(AffineMap map,
2228                                               ValueRange operands) const {
2229   assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
2230 
2231   SmallVector<Value> dims, syms;
2232 #ifndef NDEBUG
2233   SmallVector<Value> newSyms;
2234   SmallVector<Value> *newSymsPtr = &newSyms;
2235 #else
2236   SmallVector<Value> *newSymsPtr = nullptr;
2237 #endif // NDEBUG
2238 
2239   dims.reserve(numDims);
2240   syms.reserve(numSymbols);
2241   for (unsigned i = 0; i < numDims; ++i)
2242     dims.push_back(values[i] ? *values[i] : Value());
2243   for (unsigned i = numDims, e = numDims + numSymbols; i < e; ++i)
2244     syms.push_back(values[i] ? *values[i] : Value());
2245 
2246   AffineMap alignedMap =
2247       alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
2248   // All symbols are already part of this FlatAffineConstraints.
2249   assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
2250   assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
2251          "unexpected new/missing symbols");
2252   return alignedMap;
2253 }
2254 
addBound(BoundType type,unsigned pos,AffineMap boundMap,ValueRange boundOperands)2255 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
2256                                                    AffineMap boundMap,
2257                                                    ValueRange boundOperands) {
2258   // Fully compose map and operands; canonicalize and simplify so that we
2259   // transitively get to terminal symbols or loop IVs.
2260   auto map = boundMap;
2261   SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
2262   fullyComposeAffineMapAndOperands(&map, &operands);
2263   map = simplifyAffineMap(map);
2264   canonicalizeMapAndOperands(&map, &operands);
2265   for (auto operand : operands)
2266     addInductionVarOrTerminalSymbol(operand);
2267   return addBound(type, pos, computeAlignedMap(map, operands));
2268 }
2269 
2270 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
2271 // bounds in 'ubMaps' to each value in `values' that appears in the constraint
2272 // system. Note that both lower/upper bounds share the same operand list
2273 // 'operands'.
2274 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
2275 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
2276 // Note that both lower/upper bounds use operands from 'operands'.
2277 // Returns failure for unimplemented cases such as semi-affine expressions or
2278 // expressions with mod/floordiv.
addSliceBounds(ArrayRef<Value> values,ArrayRef<AffineMap> lbMaps,ArrayRef<AffineMap> ubMaps,ArrayRef<Value> operands)2279 LogicalResult FlatAffineValueConstraints::addSliceBounds(
2280     ArrayRef<Value> values, ArrayRef<AffineMap> lbMaps,
2281     ArrayRef<AffineMap> ubMaps, ArrayRef<Value> operands) {
2282   assert(values.size() == lbMaps.size());
2283   assert(lbMaps.size() == ubMaps.size());
2284 
2285   for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
2286     unsigned pos;
2287     if (!findId(values[i], &pos))
2288       continue;
2289 
2290     AffineMap lbMap = lbMaps[i];
2291     AffineMap ubMap = ubMaps[i];
2292     assert(!lbMap || lbMap.getNumInputs() == operands.size());
2293     assert(!ubMap || ubMap.getNumInputs() == operands.size());
2294 
2295     // Check if this slice is just an equality along this dimension.
2296     if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
2297         ubMap.getNumResults() == 1 &&
2298         lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
2299       if (failed(addBound(BoundType::EQ, pos, lbMap, operands)))
2300         return failure();
2301       continue;
2302     }
2303 
2304     // If lower or upper bound maps are null or provide no results, it implies
2305     // that the source loop was not at all sliced, and the entire loop will be a
2306     // part of the slice.
2307     if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
2308         ubMap.getNumResults() != 0) {
2309       if (failed(addBound(BoundType::LB, pos, lbMap, operands)))
2310         return failure();
2311       if (failed(addBound(BoundType::UB, pos, ubMap, operands)))
2312         return failure();
2313     } else {
2314       auto loop = getForInductionVarOwner(values[i]);
2315       if (failed(this->addAffineForOpDomain(loop)))
2316         return failure();
2317     }
2318   }
2319   return success();
2320 }
2321 
addEquality(ArrayRef<int64_t> eq)2322 void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
2323   assert(eq.size() == getNumCols());
2324   unsigned row = equalities.appendExtraRow();
2325   for (unsigned i = 0, e = eq.size(); i < e; ++i)
2326     equalities(row, i) = eq[i];
2327 }
2328 
addInequality(ArrayRef<int64_t> inEq)2329 void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
2330   assert(inEq.size() == getNumCols());
2331   unsigned row = inequalities.appendExtraRow();
2332   for (unsigned i = 0, e = inEq.size(); i < e; ++i)
2333     inequalities(row, i) = inEq[i];
2334 }
2335 
addBound(BoundType type,unsigned pos,int64_t value)2336 void FlatAffineConstraints::addBound(BoundType type, unsigned pos,
2337                                      int64_t value) {
2338   assert(pos < getNumCols());
2339   if (type == BoundType::EQ) {
2340     unsigned row = equalities.appendExtraRow();
2341     equalities(row, pos) = 1;
2342     equalities(row, getNumCols() - 1) = -value;
2343   } else {
2344     unsigned row = inequalities.appendExtraRow();
2345     inequalities(row, pos) = type == BoundType::LB ? 1 : -1;
2346     inequalities(row, getNumCols() - 1) =
2347         type == BoundType::LB ? -value : value;
2348   }
2349 }
2350 
addBound(BoundType type,ArrayRef<int64_t> expr,int64_t value)2351 void FlatAffineConstraints::addBound(BoundType type, ArrayRef<int64_t> expr,
2352                                      int64_t value) {
2353   assert(type != BoundType::EQ && "EQ not implemented");
2354   assert(expr.size() == getNumCols());
2355   unsigned row = inequalities.appendExtraRow();
2356   for (unsigned i = 0, e = expr.size(); i < e; ++i)
2357     inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i];
2358   inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) +=
2359       type == BoundType::LB ? -value : value;
2360 }
2361 
2362 /// Adds a new local identifier as the floordiv of an affine function of other
2363 /// identifiers, the coefficients of which are provided in 'dividend' and with
2364 /// respect to a positive constant 'divisor'. Two constraints are added to the
2365 /// system to capture equivalence with the floordiv.
2366 ///      q = expr floordiv c    <=>   c*q <= expr <= c*q + c - 1.
addLocalFloorDiv(ArrayRef<int64_t> dividend,int64_t divisor)2367 void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend,
2368                                              int64_t divisor) {
2369   assert(dividend.size() == getNumCols() && "incorrect dividend size");
2370   assert(divisor > 0 && "positive divisor expected");
2371 
2372   appendLocalId();
2373 
2374   // Add two constraints for this new identifier 'q'.
2375   SmallVector<int64_t, 8> bound(dividend.size() + 1);
2376 
2377   // dividend - q * divisor >= 0
2378   std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
2379             bound.begin());
2380   bound.back() = dividend.back();
2381   bound[getNumIds() - 1] = -divisor;
2382   addInequality(bound);
2383 
2384   // -dividend +qdivisor * q + divisor - 1 >= 0
2385   std::transform(bound.begin(), bound.end(), bound.begin(),
2386                  std::negate<int64_t>());
2387   bound[bound.size() - 1] += divisor - 1;
2388   addInequality(bound);
2389 }
2390 
findId(Value val,unsigned * pos) const2391 bool FlatAffineValueConstraints::findId(Value val, unsigned *pos) const {
2392   unsigned i = 0;
2393   for (const auto &mayBeId : values) {
2394     if (mayBeId.hasValue() && mayBeId.getValue() == val) {
2395       *pos = i;
2396       return true;
2397     }
2398     i++;
2399   }
2400   return false;
2401 }
2402 
containsId(Value val) const2403 bool FlatAffineValueConstraints::containsId(Value val) const {
2404   return llvm::any_of(values, [&](const Optional<Value> &mayBeId) {
2405     return mayBeId.hasValue() && mayBeId.getValue() == val;
2406   });
2407 }
2408 
swapId(unsigned posA,unsigned posB)2409 void FlatAffineConstraints::swapId(unsigned posA, unsigned posB) {
2410   assert(posA < getNumIds() && "invalid position A");
2411   assert(posB < getNumIds() && "invalid position B");
2412 
2413   if (posA == posB)
2414     return;
2415 
2416   for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
2417     std::swap(atIneq(r, posA), atIneq(r, posB));
2418   for (unsigned r = 0, e = getNumEqualities(); r < e; r++)
2419     std::swap(atEq(r, posA), atEq(r, posB));
2420 }
2421 
/// Swaps the identifiers at `posA` and `posB`: first the constraint columns
/// (through the base class implementation), then the values attached to the
/// two positions.
void FlatAffineValueConstraints::swapId(unsigned posA, unsigned posB) {
  FlatAffineConstraints::swapId(posA, posB);
  std::swap(values[posA], values[posB]);
}
2426 
setDimSymbolSeparation(unsigned newSymbolCount)2427 void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
2428   assert(newSymbolCount <= numDims + numSymbols &&
2429          "invalid separation position");
2430   numDims = numDims + numSymbols - newSymbolCount;
2431   numSymbols = newSymbolCount;
2432 }
2433 
addBound(BoundType type,Value val,int64_t value)2434 void FlatAffineValueConstraints::addBound(BoundType type, Value val,
2435                                           int64_t value) {
2436   unsigned pos;
2437   if (!findId(val, &pos))
2438     // This is a pre-condition for this method.
2439     assert(0 && "id not found");
2440   addBound(type, pos, value);
2441 }
2442 
/// Removes the equality at row `pos`.
void FlatAffineConstraints::removeEquality(unsigned pos) {
  equalities.removeRow(pos);
}
2446 
/// Removes the inequality at row `pos`.
void FlatAffineConstraints::removeInequality(unsigned pos) {
  inequalities.removeRow(pos);
}
2450 
removeEqualityRange(unsigned begin,unsigned end)2451 void FlatAffineConstraints::removeEqualityRange(unsigned begin, unsigned end) {
2452   if (begin >= end)
2453     return;
2454   equalities.removeRows(begin, end - begin);
2455 }
2456 
removeInequalityRange(unsigned begin,unsigned end)2457 void FlatAffineConstraints::removeInequalityRange(unsigned begin,
2458                                                   unsigned end) {
2459   if (begin >= end)
2460     return;
2461   inequalities.removeRows(begin, end - begin);
2462 }
2463 
2464 /// Finds an equality that equates the specified identifier to a constant.
2465 /// Returns the position of the equality row. If 'symbolic' is set to true,
2466 /// symbols are also treated like a constant, i.e., an affine function of the
2467 /// symbols is also treated like a constant. Returns -1 if such an equality
2468 /// could not be found.
findEqualityToConstant(const FlatAffineConstraints & cst,unsigned pos,bool symbolic=false)2469 static int findEqualityToConstant(const FlatAffineConstraints &cst,
2470                                   unsigned pos, bool symbolic = false) {
2471   assert(pos < cst.getNumIds() && "invalid position");
2472   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
2473     int64_t v = cst.atEq(r, pos);
2474     if (v * v != 1)
2475       continue;
2476     unsigned c;
2477     unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
2478     // This checks for zeros in all positions other than 'pos' in [0, f)
2479     for (c = 0; c < f; c++) {
2480       if (c == pos)
2481         continue;
2482       if (cst.atEq(r, c) != 0) {
2483         // Dependent on another identifier.
2484         break;
2485       }
2486     }
2487     if (c == f)
2488       // Equality is free of other identifiers.
2489       return r;
2490   }
2491   return -1;
2492 }
2493 
setAndEliminate(unsigned pos,ArrayRef<int64_t> values)2494 void FlatAffineConstraints::setAndEliminate(unsigned pos,
2495                                             ArrayRef<int64_t> values) {
2496   if (values.empty())
2497     return;
2498   assert(pos + values.size() <= getNumIds() &&
2499          "invalid position or too many values");
2500   // Setting x_j = p in sum_i a_i x_i + c is equivalent to adding p*a_j to the
2501   // constant term and removing the id x_j. We do this for all the ids
2502   // pos, pos + 1, ... pos + values.size() - 1.
2503   for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
2504     for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
2505       atIneq(r, getNumCols() - 1) += atIneq(r, pos + i) * values[i];
2506   for (unsigned r = 0, e = getNumEqualities(); r < e; r++)
2507     for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
2508       atEq(r, getNumCols() - 1) += atEq(r, pos + i) * values[i];
2509   removeIdRange(pos, pos + values.size());
2510 }
2511 
constantFoldId(unsigned pos)2512 LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) {
2513   assert(pos < getNumIds() && "invalid position");
2514   int rowIdx;
2515   if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
2516     return failure();
2517 
2518   // atEq(rowIdx, pos) is either -1 or 1.
2519   assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
2520   int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
2521   setAndEliminate(pos, constVal);
2522   return success();
2523 }
2524 
constantFoldIdRange(unsigned pos,unsigned num)2525 void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
2526   for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
2527     if (failed(constantFoldId(t)))
2528       t++;
2529   }
2530 }
2531 
/// Returns a non-negative constant bound on the extent (upper bound - lower
/// bound) of the specified identifier if it is found to be a constant; returns
/// None if it's not a constant. This method treats symbolic identifiers
/// specially, i.e., it looks for constant differences between affine
/// expressions involving only the symbolic identifiers. See the examples
/// below. 'lb', if provided, is set to the lower bound associated with the
/// constant difference. Note that 'lb' is purely symbolic and thus will
/// contain the coefficients of the symbolic identifiers and the constant
/// coefficient. 'ub', if provided together with 'lb', is filled in the same
/// layout from the matching upper bound. 'boundFloorDivisor' receives the
/// divisor to apply (as a floordiv) to 'lb'. 'minLbPos' / 'minUbPos', if
/// provided, receive the row positions of the constraints the result was
/// derived from: the equality row when the identifier is pinned by an
/// equality, otherwise the lower/upper bound inequality rows (0 when no
/// suitable pair was found).
//  Egs: 0 <= i <= 15, return 16.
//       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
//       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
//       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
//       ceil(s0 - 7 / 8) = floor(s0 / 8)).
Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
    unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
    SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
    unsigned *minUbPos) const {
  assert(pos < getNumDimIds() && "Invalid identifier position");

  // Find an equality for 'pos'^th identifier that equates it to some function
  // of the symbolic identifiers (+ constant).
  int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
  if (eqPos != -1) {
    auto eq = getEquality(eqPos);
    // If the equality involves a local var, punt for now.
    // TODO: this can be handled in the future by using the explicit
    // representation of the local vars.
    if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
                     [](int64_t coeff) { return coeff == 0; }))
      return None;

    // This identifier can only take a single value.
    if (lb) {
      // Set lb to that symbolic value.
      lb->resize(getNumSymbolIds() + 1);
      if (ub)
        ub->resize(getNumSymbolIds() + 1);
      // NOTE(review): for c == getNumSymbolIds() the column read below is
      // getNumDimIds() + getNumSymbolIds(), which is the constant column only
      // when the system has no local ids - confirm callers guarantee that.
      for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
        int64_t v = atEq(eqPos, pos);
        // atEq(eqRow, pos) is either -1 or 1.
        assert(v * v == 1);
        (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v
                         : -atEq(eqPos, getNumDimIds() + c) / v;
        // Since this is an equality, ub = lb.
        if (ub)
          (*ub)[c] = (*lb)[c];
      }
      assert(boundFloorDivisor &&
             "both lb and divisor or none should be provided");
      *boundFloorDivisor = 1;
    }
    if (minLbPos)
      *minLbPos = eqPos;
    if (minUbPos)
      *minUbPos = eqPos;
    // A single value: the extent is 1.
    return 1;
  }

  // Check if the identifier appears at all in any of the inequalities.
  unsigned r, e;
  for (r = 0, e = getNumInequalities(); r < e; r++) {
    if (atIneq(r, pos) != 0)
      break;
  }
  if (r == e)
    // If it doesn't, there isn't a bound on it.
    return None;

  // Positions of constraints that are lower/upper bounds on the variable.
  SmallVector<unsigned, 4> lbIndices, ubIndices;

  // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
  // the bounds can only involve symbolic (and local) identifiers. Since the
  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
  getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
                               /*eqIndices=*/nullptr, /*offset=*/0,
                               /*num=*/getNumDimIds());

  Optional<int64_t> minDiff = None;
  unsigned minLbPosition = 0, minUbPosition = 0;
  for (auto ubPos : ubIndices) {
    for (auto lbPos : lbIndices) {
      // Look for a lower bound and an upper bound that only differ by a
      // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
      // For example, if ii is the pos^th variable, we are looking for
      // constraints like ii >= i, ii <= i + 50, 50 being the difference. The
      // minimum among all such constant differences is kept since that's the
      // constant bounding the extent of the pos^th variable.
      unsigned j, e;
      // The two rows must be exact negations of each other in every
      // non-constant column.
      for (j = 0, e = getNumCols() - 1; j < e; j++)
        if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
          break;
        }
      if (j < getNumCols() - 1)
        continue;
      // Sum of the two constant terms (+ 1, since both bounds are inclusive),
      // scaled down by the coefficient of 'pos' in the lower bound.
      int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
                                 atIneq(lbPos, getNumCols() - 1) + 1,
                             atIneq(lbPos, pos));
      // This bound is non-negative by definition.
      diff = std::max<int64_t>(diff, 0);
      if (minDiff == None || diff < minDiff) {
        minDiff = diff;
        minLbPosition = lbPos;
        minUbPosition = ubPos;
      }
    }
  }
  if (lb && minDiff.hasValue()) {
    // Set lb to the symbolic lower bound.
    lb->resize(getNumSymbolIds() + 1);
    if (ub)
      ub->resize(getNumSymbolIds() + 1);
    // The lower bound is the ceildiv of the lb constraint over the coefficient
    // of the variable at 'pos'. We express the ceildiv equivalently as a floor
    // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
    // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
    // NOTE(review): unlike the equality case above, 'boundFloorDivisor' is
    // dereferenced here without a preceding non-null assertion.
    *boundFloorDivisor = atIneq(minLbPosition, pos);
    assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
    // NOTE(review): as in the equality case, the last entry is read from
    // column getNumDimIds() + getNumSymbolIds(), which is the constant column
    // only when the system has no local ids - confirm with callers.
    for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
      (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
    }
    if (ub) {
      for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
        (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
    }
    // The lower bound leads to a ceildiv while the upper bound is a floordiv
    // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
    // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
    // the constant term for the lower bound.
    (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
  }
  if (minLbPos)
    *minLbPos = minLbPosition;
  if (minUbPos)
    *minUbPos = minUbPosition;
  return minDiff;
}
2671 
2672 template <bool isLower>
2673 Optional<int64_t>
computeConstantLowerOrUpperBound(unsigned pos)2674 FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
2675   assert(pos < getNumIds() && "invalid position");
2676   // Project to 'pos'.
2677   projectOut(0, pos);
2678   projectOut(1, getNumIds() - 1);
2679   // Check if there's an equality equating the '0'^th identifier to a constant.
2680   int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
2681   if (eqRowIdx != -1)
2682     // atEq(rowIdx, 0) is either -1 or 1.
2683     return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
2684 
2685   // Check if the identifier appears at all in any of the inequalities.
2686   unsigned r, e;
2687   for (r = 0, e = getNumInequalities(); r < e; r++) {
2688     if (atIneq(r, 0) != 0)
2689       break;
2690   }
2691   if (r == e)
2692     // If it doesn't, there isn't a bound on it.
2693     return None;
2694 
2695   Optional<int64_t> minOrMaxConst = None;
2696 
2697   // Take the max across all const lower bounds (or min across all constant
2698   // upper bounds).
2699   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2700     if (isLower) {
2701       if (atIneq(r, 0) <= 0)
2702         // Not a lower bound.
2703         continue;
2704     } else if (atIneq(r, 0) >= 0) {
2705       // Not an upper bound.
2706       continue;
2707     }
2708     unsigned c, f;
2709     for (c = 0, f = getNumCols() - 1; c < f; c++)
2710       if (c != 0 && atIneq(r, c) != 0)
2711         break;
2712     if (c < getNumCols() - 1)
2713       // Not a constant bound.
2714       continue;
2715 
2716     int64_t boundConst =
2717         isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
2718                 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
2719     if (isLower) {
2720       if (minOrMaxConst == None || boundConst > minOrMaxConst)
2721         minOrMaxConst = boundConst;
2722     } else {
2723       if (minOrMaxConst == None || boundConst < minOrMaxConst)
2724         minOrMaxConst = boundConst;
2725     }
2726   }
2727   return minOrMaxConst;
2728 }
2729 
getConstantBound(BoundType type,unsigned pos) const2730 Optional<int64_t> FlatAffineConstraints::getConstantBound(BoundType type,
2731                                                           unsigned pos) const {
2732   assert(type != BoundType::EQ && "EQ not implemented");
2733   FlatAffineConstraints tmpCst(*this);
2734   if (type == BoundType::LB)
2735     return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
2736   return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
2737 }
2738 
2739 // A simple (naive and conservative) check for hyper-rectangularity.
isHyperRectangular(unsigned pos,unsigned num) const2740 bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
2741                                                unsigned num) const {
2742   assert(pos < getNumCols() - 1);
2743   // Check for two non-zero coefficients in the range [pos, pos + sum).
2744   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2745     unsigned sum = 0;
2746     for (unsigned c = pos; c < pos + num; c++) {
2747       if (atIneq(r, c) != 0)
2748         sum++;
2749     }
2750     if (sum > 1)
2751       return false;
2752   }
2753   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2754     unsigned sum = 0;
2755     for (unsigned c = pos; c < pos + num; c++) {
2756       if (atEq(r, c) != 0)
2757         sum++;
2758     }
2759     if (sum > 1)
2760       return false;
2761   }
2762   return true;
2763 }
2764 
print(raw_ostream & os) const2765 void FlatAffineConstraints::print(raw_ostream &os) const {
2766   assert(hasConsistentState());
2767   os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds()
2768      << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints()
2769      << " constraints)\n";
2770   os << "(";
2771   for (unsigned i = 0, e = getNumIds(); i < e; i++) {
2772     if (auto *valueCstr = dyn_cast<const FlatAffineValueConstraints>(this)) {
2773       if (valueCstr->hasValue(i))
2774         os << "Value ";
2775       else
2776         os << "None ";
2777     } else {
2778       os << "None ";
2779     }
2780   }
2781   os << " const)\n";
2782   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
2783     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2784       os << atEq(i, j) << " ";
2785     }
2786     os << "= 0\n";
2787   }
2788   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2789     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2790       os << atIneq(i, j) << " ";
2791     }
2792     os << ">= 0\n";
2793   }
2794   os << '\n';
2795 }
2796 
dump() const2797 void FlatAffineConstraints::dump() const { print(llvm::errs()); }
2798 
2799 /// Removes duplicate constraints, trivially true constraints, and constraints
2800 /// that can be detected as redundant as a result of differing only in their
2801 /// constant term part. A constraint of the form <non-negative constant> >= 0 is
2802 /// considered trivially true.
2803 //  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
2804 //  remove duplicates in place.
removeTrivialRedundancy()2805 void FlatAffineConstraints::removeTrivialRedundancy() {
2806   gcdTightenInequalities();
2807   normalizeConstraintsByGCD();
2808 
2809   // A map used to detect redundancy stemming from constraints that only differ
2810   // in their constant term. The value stored is <row position, const term>
2811   // for a given row.
2812   SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
2813       rowsWithoutConstTerm;
2814   // To unique rows.
2815   SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
2816 
2817   // Check if constraint is of the form <non-negative-constant> >= 0.
2818   auto isTriviallyValid = [&](unsigned r) -> bool {
2819     for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
2820       if (atIneq(r, c) != 0)
2821         return false;
2822     }
2823     return atIneq(r, getNumCols() - 1) >= 0;
2824   };
2825 
2826   // Detect and mark redundant constraints.
2827   SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
2828   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2829     int64_t *rowStart = &inequalities(r, 0);
2830     auto row = ArrayRef<int64_t>(rowStart, getNumCols());
2831     if (isTriviallyValid(r) || !rowSet.insert(row).second) {
2832       redunIneq[r] = true;
2833       continue;
2834     }
2835 
2836     // Among constraints that only differ in the constant term part, mark
2837     // everything other than the one with the smallest constant term redundant.
2838     // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
2839     // former two are redundant).
2840     int64_t constTerm = atIneq(r, getNumCols() - 1);
2841     auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
2842     const auto &ret =
2843         rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
2844     if (!ret.second) {
2845       // Check if the other constraint has a higher constant term.
2846       auto &val = ret.first->second;
2847       if (val.second > constTerm) {
2848         // The stored row is redundant. Mark it so, and update with this one.
2849         redunIneq[val.first] = true;
2850         val = {r, constTerm};
2851       } else {
2852         // The one stored makes this one redundant.
2853         redunIneq[r] = true;
2854       }
2855     }
2856   }
2857 
2858   // Scan to get rid of all rows marked redundant, in-place.
2859   unsigned pos = 0;
2860   for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
2861     if (!redunIneq[r])
2862       inequalities.copyRow(r, pos++);
2863 
2864   inequalities.resizeVertically(pos);
2865 
2866   // TODO: consider doing this for equalities as well, but probably not worth
2867   // the savings.
2868 }
2869 
clearAndCopyFrom(const FlatAffineConstraints & other)2870 void FlatAffineConstraints::clearAndCopyFrom(
2871     const FlatAffineConstraints &other) {
2872   if (auto *otherValueSet = dyn_cast<const FlatAffineValueConstraints>(&other))
2873     assert(!otherValueSet->hasValues() &&
2874            "cannot copy associated Values into FlatAffineConstraints");
2875   // Note: Assigment operator does not vtable pointer, so kind does not change.
2876   *this = other;
2877 }
2878 
clearAndCopyFrom(const FlatAffineConstraints & other)2879 void FlatAffineValueConstraints::clearAndCopyFrom(
2880     const FlatAffineConstraints &other) {
2881   if (auto *otherValueSet =
2882           dyn_cast<const FlatAffineValueConstraints>(&other)) {
2883     *this = *otherValueSet;
2884   } else {
2885     *static_cast<FlatAffineConstraints *>(this) = other;
2886     values.clear();
2887     values.resize(numIds, None);
2888   }
2889 }
2890 
removeId(unsigned pos)2891 void FlatAffineConstraints::removeId(unsigned pos) {
2892   removeIdRange(pos, pos + 1);
2893 }
2894 
2895 static std::pair<unsigned, unsigned>
getNewNumDimsSymbols(unsigned pos,const FlatAffineConstraints & cst)2896 getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
2897   unsigned numDims = cst.getNumDimIds();
2898   unsigned numSymbols = cst.getNumSymbolIds();
2899   unsigned newNumDims, newNumSymbols;
2900   if (pos < numDims) {
2901     newNumDims = numDims - 1;
2902     newNumSymbols = numSymbols;
2903   } else if (pos < numDims + numSymbols) {
2904     assert(numSymbols >= 1);
2905     newNumDims = numDims;
2906     newNumSymbols = numSymbols - 1;
2907   } else {
2908     newNumDims = numDims;
2909     newNumSymbols = numSymbols;
2910   }
2911   return {newNumDims, newNumSymbols};
2912 }
2913 
2914 #undef DEBUG_TYPE
2915 #define DEBUG_TYPE "fm"
2916 
2917 /// Eliminates identifier at the specified position using Fourier-Motzkin
2918 /// variable elimination. This technique is exact for rational spaces but
2919 /// conservative (in "rare" cases) for integer spaces. The operation corresponds
2920 /// to a projection operation yielding the (convex) set of integer points
2921 /// contained in the rational shadow of the set. An emptiness test that relies
2922 /// on this method will guarantee emptiness, i.e., it disproves the existence of
2923 /// a solution if it says it's empty.
2924 /// If a non-null isResultIntegerExact is passed, it is set to true if the
2925 /// result is also integer exact. If it's set to false, the obtained solution
2926 /// *may* not be exact, i.e., it may contain integer points that do not have an
2927 /// integer pre-image in the original set.
2928 ///
2929 /// Eg:
2930 /// j >= 0, j <= i + 1
2931 /// i >= 0, i <= N + 1
2932 /// Eliminating i yields,
2933 ///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
2934 ///
2935 /// If darkShadow = true, this method computes the dark shadow on elimination;
2936 /// the dark shadow is a convex integer subset of the exact integer shadow. A
2937 /// non-empty dark shadow proves the existence of an integer solution. The
2938 /// elimination in such a case could however be an under-approximation, and thus
2939 /// should not be used for scanning sets or used by itself for dependence
2940 /// checking.
2941 ///
2942 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
2943 ///            ^
2944 ///            |
2945 ///            | * * * * o o
2946 ///         i  | * * o o o o
2947 ///            | o * * * * *
2948 ///            --------------->
2949 ///                 j ->
2950 ///
2951 /// Eliminating i from this system (projecting on the j dimension):
2952 /// rational shadow / integer light shadow:  1 <= j <= 6
2953 /// dark shadow:                             3 <= j <= 6
2954 /// exact integer shadow:                    j = 1 \union  3 <= j <= 6
2955 /// holes/splinters:                         j = 2
2956 ///
2957 /// darkShadow = false, isResultIntegerExact = nullptr are default values.
2958 // TODO: a slight modification to yield dark shadow version of FM (tightened),
2959 // which can prove the existence of a solution if there is one.
fourierMotzkinEliminate(unsigned pos,bool darkShadow,bool * isResultIntegerExact)2960 void FlatAffineConstraints::fourierMotzkinEliminate(
2961     unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
2962   LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
2963   LLVM_DEBUG(dump());
2964   assert(pos < getNumIds() && "invalid position");
2965   assert(hasConsistentState());
2966 
2967   // Check if this identifier can be eliminated through a substitution.
2968   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2969     if (atEq(r, pos) != 0) {
2970       // Use Gaussian elimination here (since we have an equality).
2971       LogicalResult ret = gaussianEliminateId(pos);
2972       (void)ret;
2973       assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
2974       LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
2975       LLVM_DEBUG(dump());
2976       return;
2977     }
2978   }
2979 
2980   // A fast linear time tightening.
2981   gcdTightenInequalities();
2982 
2983   // Check if the identifier appears at all in any of the inequalities.
2984   unsigned r, e;
2985   for (r = 0, e = getNumInequalities(); r < e; r++) {
2986     if (atIneq(r, pos) != 0)
2987       break;
2988   }
2989   if (r == getNumInequalities()) {
2990     // If it doesn't appear, just remove the column and return.
2991     // TODO: refactor removeColumns to use it from here.
2992     removeId(pos);
2993     LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2994     LLVM_DEBUG(dump());
2995     return;
2996   }
2997 
2998   // Positions of constraints that are lower bounds on the variable.
2999   SmallVector<unsigned, 4> lbIndices;
3000   // Positions of constraints that are lower bounds on the variable.
3001   SmallVector<unsigned, 4> ubIndices;
3002   // Positions of constraints that do not involve the variable.
3003   std::vector<unsigned> nbIndices;
3004   nbIndices.reserve(getNumInequalities());
3005 
3006   // Gather all lower bounds and upper bounds of the variable. Since the
3007   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
3008   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
3009   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
3010     if (atIneq(r, pos) == 0) {
3011       // Id does not appear in bound.
3012       nbIndices.push_back(r);
3013     } else if (atIneq(r, pos) >= 1) {
3014       // Lower bound.
3015       lbIndices.push_back(r);
3016     } else {
3017       // Upper bound.
3018       ubIndices.push_back(r);
3019     }
3020   }
3021 
3022   // Set the number of dimensions, symbols in the resulting system.
3023   const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
3024   unsigned newNumDims = dimsSymbols.first;
3025   unsigned newNumSymbols = dimsSymbols.second;
3026 
3027   /// Create the new system which has one identifier less.
3028   FlatAffineConstraints newFac(
3029       lbIndices.size() * ubIndices.size() + nbIndices.size(),
3030       getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
3031       /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols);
3032 
3033   // This will be used to check if the elimination was integer exact.
3034   unsigned lcmProducts = 1;
3035 
3036   // Let x be the variable we are eliminating.
3037   // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
3038   // that c_l, c_u >= 1) we have:
3039   // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
3040   // We thus generate a constraint:
3041   // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
3042   // Note if c_l = c_u = 1, all integer points captured by the resulting
3043   // constraint correspond to integer points in the original system (i.e., they
3044   // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
3045   // integer exact.
3046   for (auto ubPos : ubIndices) {
3047     for (auto lbPos : lbIndices) {
3048       SmallVector<int64_t, 4> ineq;
3049       ineq.reserve(newFac.getNumCols());
3050       int64_t lbCoeff = atIneq(lbPos, pos);
3051       // Note that in the comments above, ubCoeff is the negation of the
3052       // coefficient in the canonical form as the view taken here is that of the
3053       // term being moved to the other size of '>='.
3054       int64_t ubCoeff = -atIneq(ubPos, pos);
3055       // TODO: refactor this loop to avoid all branches inside.
3056       for (unsigned l = 0, e = getNumCols(); l < e; l++) {
3057         if (l == pos)
3058           continue;
3059         assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
3060         int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
3061         ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
3062                        atIneq(lbPos, l) * (lcm / lbCoeff));
3063         lcmProducts *= lcm;
3064       }
3065       if (darkShadow) {
3066         // The dark shadow is a convex subset of the exact integer shadow. If
3067         // there is a point here, it proves the existence of a solution.
3068         ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
3069       }
3070       // TODO: we need to have a way to add inequalities in-place in
3071       // FlatAffineConstraints instead of creating and copying over.
3072       newFac.addInequality(ineq);
3073     }
3074   }
3075 
3076   LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
3077                           << "\n");
3078   if (lcmProducts == 1 && isResultIntegerExact)
3079     *isResultIntegerExact = true;
3080 
3081   // Copy over the constraints not involving this variable.
3082   for (auto nbPos : nbIndices) {
3083     SmallVector<int64_t, 4> ineq;
3084     ineq.reserve(getNumCols() - 1);
3085     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
3086       if (l == pos)
3087         continue;
3088       ineq.push_back(atIneq(nbPos, l));
3089     }
3090     newFac.addInequality(ineq);
3091   }
3092 
3093   assert(newFac.getNumConstraints() ==
3094          lbIndices.size() * ubIndices.size() + nbIndices.size());
3095 
3096   // Copy over the equalities.
3097   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
3098     SmallVector<int64_t, 4> eq;
3099     eq.reserve(newFac.getNumCols());
3100     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
3101       if (l == pos)
3102         continue;
3103       eq.push_back(atEq(r, l));
3104     }
3105     newFac.addEquality(eq);
3106   }
3107 
3108   // GCD tightening and normalization allows detection of more trivially
3109   // redundant constraints.
3110   newFac.gcdTightenInequalities();
3111   newFac.normalizeConstraintsByGCD();
3112   newFac.removeTrivialRedundancy();
3113   clearAndCopyFrom(newFac);
3114   LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
3115   LLVM_DEBUG(dump());
3116 }
3117 
3118 #undef DEBUG_TYPE
3119 #define DEBUG_TYPE "affine-structures"
3120 
fourierMotzkinEliminate(unsigned pos,bool darkShadow,bool * isResultIntegerExact)3121 void FlatAffineValueConstraints::fourierMotzkinEliminate(
3122     unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
3123   SmallVector<Optional<Value>, 8> newVals;
3124   newVals.reserve(numIds - 1);
3125   newVals.append(values.begin(), values.begin() + pos);
3126   newVals.append(values.begin() + pos + 1, values.end());
3127   // Note: Base implementation discards all associated Values.
3128   FlatAffineConstraints::fourierMotzkinEliminate(pos, darkShadow,
3129                                                  isResultIntegerExact);
3130   values = newVals;
3131   assert(values.size() == getNumIds());
3132 }
3133 
projectOut(unsigned pos,unsigned num)3134 void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
3135   if (num == 0)
3136     return;
3137 
3138   // 'pos' can be at most getNumCols() - 2 if num > 0.
3139   assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
3140   assert(pos + num < getNumCols() && "invalid range");
3141 
3142   // Eliminate as many identifiers as possible using Gaussian elimination.
3143   unsigned currentPos = pos;
3144   unsigned numToEliminate = num;
3145   unsigned numGaussianEliminated = 0;
3146 
3147   while (currentPos < getNumIds()) {
3148     unsigned curNumEliminated =
3149         gaussianEliminateIds(currentPos, currentPos + numToEliminate);
3150     ++currentPos;
3151     numToEliminate -= curNumEliminated + 1;
3152     numGaussianEliminated += curNumEliminated;
3153   }
3154 
3155   // Eliminate the remaining using Fourier-Motzkin.
3156   for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
3157     unsigned numToEliminate = num - numGaussianEliminated - i;
3158     fourierMotzkinEliminate(
3159         getBestIdToEliminate(*this, pos, pos + numToEliminate));
3160   }
3161 
3162   // Fast/trivial simplifications.
3163   gcdTightenInequalities();
3164   // Normalize constraints after tightening since the latter impacts this, but
3165   // not the other way round.
3166   normalizeConstraintsByGCD();
3167 }
3168 
projectOut(Value val)3169 void FlatAffineValueConstraints::projectOut(Value val) {
3170   unsigned pos;
3171   bool ret = findId(val, &pos);
3172   assert(ret);
3173   (void)ret;
3174   fourierMotzkinEliminate(pos);
3175 }
3176 
clearConstraints()3177 void FlatAffineConstraints::clearConstraints() {
3178   equalities.resizeVertically(0);
3179   inequalities.resizeVertically(0);
3180 }
3181 
3182 namespace {
3183 
3184 enum BoundCmpResult { Greater, Less, Equal, Unknown };
3185 
3186 /// Compares two affine bounds whose coefficients are provided in 'first' and
3187 /// 'second'. The last coefficient is the constant term.
compareBounds(ArrayRef<int64_t> a,ArrayRef<int64_t> b)3188 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
3189   assert(a.size() == b.size());
3190 
3191   // For the bounds to be comparable, their corresponding identifier
3192   // coefficients should be equal; the constant terms are then compared to
3193   // determine less/greater/equal.
3194 
3195   if (!std::equal(a.begin(), a.end() - 1, b.begin()))
3196     return Unknown;
3197 
3198   if (a.back() == b.back())
3199     return Equal;
3200 
3201   return a.back() < b.back() ? Less : Greater;
3202 }
3203 } // namespace
3204 
3205 // Returns constraints that are common to both A & B.
getCommonConstraints(const FlatAffineConstraints & a,const FlatAffineConstraints & b,FlatAffineConstraints & c)3206 static void getCommonConstraints(const FlatAffineConstraints &a,
3207                                  const FlatAffineConstraints &b,
3208                                  FlatAffineConstraints &c) {
3209   c.reset(a.getNumDimIds(), a.getNumSymbolIds(), a.getNumLocalIds());
3210   // a naive O(n^2) check should be enough here given the input sizes.
3211   for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) {
3212     for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) {
3213       if (a.getInequality(r) == b.getInequality(s)) {
3214         c.addInequality(a.getInequality(r));
3215         break;
3216       }
3217     }
3218   }
3219   for (unsigned r = 0, e = a.getNumEqualities(); r < e; ++r) {
3220     for (unsigned s = 0, f = b.getNumEqualities(); s < f; ++s) {
3221       if (a.getEquality(r) == b.getEquality(s)) {
3222         c.addEquality(a.getEquality(r));
3223         break;
3224       }
3225     }
3226   }
3227 }
3228 
// Computes the bounding box with respect to 'other' by finding the min of the
// lower bounds and the max of the upper bounds along each of the dimensions.
// On success, the constraints of this system are replaced by one lower and one
// upper bound per dimension plus the constraints that were common to both
// systems. Returns failure (leaving this system's constraints untouched) if,
// for some dimension, either system has no constant extent, the two lower
// bound divisors differ, or the bounds are not comparable and no constant
// bound is available to fall back on. Neither system may have local ids.
LogicalResult
FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
  assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
  assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
  assert(getNumLocalIds() == 0 && "local ids not supported yet here");

  // Get the constraints common to both systems; these will be added as is to
  // the union.
  FlatAffineConstraints commonCst;
  getCommonConstraints(*this, otherCst, commonCst);

  // The new bounds are collected here first and only installed once every
  // dimension has been processed successfully.
  std::vector<SmallVector<int64_t, 8>> boundingLbs;
  std::vector<SmallVector<int64_t, 8>> boundingUbs;
  boundingLbs.reserve(2 * getNumDimIds());
  boundingUbs.reserve(2 * getNumDimIds());

  // To hold lower and upper bounds for each dimension.
  SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
  // To compute min of lower bounds and max of upper bounds for each dimension.
  // These hold the symbolic coefficients followed by the constant term.
  SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
  SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
  // To compute final new lower and upper bounds for the union.
  SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());

  int64_t lbFloorDivisor, otherLbFloorDivisor;
  for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
    auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
    if (!extent.hasValue())
      // TODO: symbolic extents when necessary.
      // TODO: handle union if a dimension is unbounded.
      return failure();

    auto otherExtent = otherCst.getConstantBoundOnDimSize(
        d, &otherLb, &otherLbFloorDivisor, &otherUb);
    if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
      // TODO: symbolic extents when necessary.
      return failure();

    assert(lbFloorDivisor > 0 && "divisor always expected to be positive");

    auto res = compareBounds(lb, otherLb);
    // Identify min.
    if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
      minLb = lb;
      // Since the divisor is for a floordiv, we need to convert to ceildiv,
      // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
      // div * i >= expr - div + 1.
      minLb.back() -= lbFloorDivisor - 1;
    } else if (res == BoundCmpResult::Greater) {
      minLb = otherLb;
      minLb.back() -= otherLbFloorDivisor - 1;
    } else {
      // Uncomparable - check for constant lower/upper bounds.
      // NOTE(review): the constant chosen here is still paired with the
      // coefficient lbFloorDivisor on the dimension further below (newLb[d]);
      // confirm this is the intended bound when the divisor is not 1.
      auto constLb = getConstantBound(BoundType::LB, d);
      auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
      if (!constLb.hasValue() || !constOtherLb.hasValue())
        return failure();
      std::fill(minLb.begin(), minLb.end(), 0);
      minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
    }

    // Do the same for ub's but max of upper bounds. Identify max.
    auto uRes = compareBounds(ub, otherUb);
    if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
      maxUb = ub;
    } else if (uRes == BoundCmpResult::Less) {
      maxUb = otherUb;
    } else {
      // Uncomparable - check for constant lower/upper bounds.
      // NOTE(review): as for the lower bound, the constant is combined with
      // the coefficient -lbFloorDivisor below (newUb[d]); confirm for
      // divisors other than 1.
      auto constUb = getConstantBound(BoundType::UB, d);
      auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
      if (!constUb.hasValue() || !constOtherUb.hasValue())
        return failure();
      std::fill(maxUb.begin(), maxUb.end(), 0);
      maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
    }

    std::fill(newLb.begin(), newLb.end(), 0);
    std::fill(newUb.begin(), newUb.end(), 0);

    // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
    // and so it's the divisor for newLb and newUb as well.
    newLb[d] = lbFloorDivisor;
    newUb[d] = -lbFloorDivisor;
    // Copy over the symbolic part + constant term. The lower bound expression
    // is negated since it moves to the other side of the inequality:
    // div * d - minLb >= 0 and -div * d + maxUb >= 0.
    std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
    std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
                   newLb.begin() + getNumDimIds(), std::negate<int64_t>());
    std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());

    boundingLbs.push_back(newLb);
    boundingUbs.push_back(newUb);
  }

  // Clear all constraints and add the lower/upper bounds for the bounding box.
  clearConstraints();
  for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
    addInequality(boundingLbs[d]);
    addInequality(boundingUbs[d]);
  }

  // Add the constraints that were common to both systems.
  append(commonCst);
  removeTrivialRedundancy();

  // TODO: copy over pure symbolic constraints from this and 'other' over to the
  // union (since the above are just the union along dimensions); we shouldn't
  // be discarding any other constraints on the symbols.

  return success();
}
3342 
unionBoundingBox(const FlatAffineValueConstraints & otherCst)3343 LogicalResult FlatAffineValueConstraints::unionBoundingBox(
3344     const FlatAffineValueConstraints &otherCst) {
3345   assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
3346   assert(otherCst.getMaybeValues()
3347              .slice(0, getNumDimIds())
3348              .equals(getMaybeValues().slice(0, getNumDimIds())) &&
3349          "dim values mismatch");
3350   assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
3351   assert(getNumLocalIds() == 0 && "local ids not supported yet here");
3352 
3353   // Align `other` to this.
3354   if (!areIdsAligned(*this, otherCst)) {
3355     FlatAffineValueConstraints otherCopy(otherCst);
3356     mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy);
3357     return FlatAffineConstraints::unionBoundingBox(otherCopy);
3358   }
3359 
3360   return FlatAffineConstraints::unionBoundingBox(otherCst);
3361 }
3362 
3363 /// Compute an explicit representation for local vars. For all systems coming
3364 /// from MLIR integer sets, maps, or expressions where local vars were
3365 /// introduced to model floordivs and mods, this always succeeds.
computeLocalVars(const FlatAffineConstraints & cst,SmallVectorImpl<AffineExpr> & memo,MLIRContext * context)3366 static LogicalResult computeLocalVars(const FlatAffineConstraints &cst,
3367                                       SmallVectorImpl<AffineExpr> &memo,
3368                                       MLIRContext *context) {
3369   unsigned numDims = cst.getNumDimIds();
3370   unsigned numSyms = cst.getNumSymbolIds();
3371 
3372   // Initialize dimensional and symbolic identifiers.
3373   for (unsigned i = 0; i < numDims; i++)
3374     memo[i] = getAffineDimExpr(i, context);
3375   for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
3376     memo[i] = getAffineSymbolExpr(i - numDims, context);
3377 
3378   bool changed;
3379   do {
3380     // Each time `changed` is true at the end of this iteration, one or more
3381     // local vars would have been detected as floordivs and set in memo; so the
3382     // number of null entries in memo[...] strictly reduces; so this converges.
3383     changed = false;
3384     for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i)
3385       if (!memo[numDims + numSyms + i] &&
3386           detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
3387         changed = true;
3388   } while (changed);
3389 
3390   ArrayRef<AffineExpr> localExprs =
3391       ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds());
3392   return success(
3393       llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
3394 }
3395 
getIneqAsAffineValueMap(unsigned pos,unsigned ineqPos,AffineValueMap & vmap,MLIRContext * context) const3396 void FlatAffineValueConstraints::getIneqAsAffineValueMap(
3397     unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
3398     MLIRContext *context) const {
3399   unsigned numDims = getNumDimIds();
3400   unsigned numSyms = getNumSymbolIds();
3401 
3402   assert(pos < numDims && "invalid position");
3403   assert(ineqPos < getNumInequalities() && "invalid inequality position");
3404 
3405   // Get expressions for local vars.
3406   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
3407   if (failed(computeLocalVars(*this, memo, context)))
3408     assert(false &&
3409            "one or more local exprs do not have an explicit representation");
3410   auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
3411 
3412   // Compute the AffineExpr lower/upper bound for this inequality.
3413   ArrayRef<int64_t> inequality = getInequality(ineqPos);
3414   SmallVector<int64_t, 8> bound;
3415   bound.reserve(getNumCols() - 1);
3416   // Everything other than the coefficient at `pos`.
3417   bound.append(inequality.begin(), inequality.begin() + pos);
3418   bound.append(inequality.begin() + pos + 1, inequality.end());
3419 
3420   if (inequality[pos] > 0)
3421     // Lower bound.
3422     std::transform(bound.begin(), bound.end(), bound.begin(),
3423                    std::negate<int64_t>());
3424   else
3425     // Upper bound (which is exclusive).
3426     bound.back() += 1;
3427 
3428   // Convert to AffineExpr (tree) form.
3429   auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
3430                                              localExprs, context);
3431 
3432   // Get the values to bind to this affine expr (all dims and symbols).
3433   SmallVector<Value, 4> operands;
3434   getValues(0, pos, &operands);
3435   SmallVector<Value, 4> trailingOperands;
3436   getValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
3437   operands.append(trailingOperands.begin(), trailingOperands.end());
3438   vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
3439 }
3440 
3441 /// Returns true if the pos^th column is all zero for both inequalities and
3442 /// equalities..
isColZero(const FlatAffineConstraints & cst,unsigned pos)3443 static bool isColZero(const FlatAffineConstraints &cst, unsigned pos) {
3444   unsigned rowPos;
3445   return !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/false, &rowPos) &&
3446          !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/true, &rowPos);
3447 }
3448 
getAsIntegerSet(MLIRContext * context) const3449 IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
3450   if (getNumConstraints() == 0)
3451     // Return universal set (always true): 0 == 0.
3452     return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
3453                            getAffineConstantExpr(/*constant=*/0, context),
3454                            /*eqFlags=*/true);
3455 
3456   // Construct local references.
3457   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
3458 
3459   if (failed(computeLocalVars(*this, memo, context))) {
3460     // Check if the local variables without an explicit representation have
3461     // zero coefficients everywhere.
3462     for (unsigned i = getNumDimAndSymbolIds(), e = getNumIds(); i < e; ++i) {
3463       if (!memo[i] && !isColZero(*this, /*pos=*/i)) {
3464         LLVM_DEBUG(llvm::dbgs() << "one or more local exprs do not have an "
3465                                    "explicit representation");
3466         return IntegerSet();
3467       }
3468     }
3469   }
3470 
3471   ArrayRef<AffineExpr> localExprs =
3472       ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
3473 
3474   // Construct the IntegerSet from the equalities/inequalities.
3475   unsigned numDims = getNumDimIds();
3476   unsigned numSyms = getNumSymbolIds();
3477 
3478   SmallVector<bool, 16> eqFlags(getNumConstraints());
3479   std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
3480   std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
3481 
3482   SmallVector<AffineExpr, 8> exprs;
3483   exprs.reserve(getNumConstraints());
3484 
3485   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
3486     exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
3487                                               localExprs, context));
3488   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
3489     exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
3490                                               numSyms, localExprs, context));
3491   return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
3492 }
3493 
3494 /// Find positions of inequalities and equalities that do not have a coefficient
3495 /// for [pos, pos + num) identifiers.
getIndependentConstraints(const FlatAffineConstraints & cst,unsigned pos,unsigned num,SmallVectorImpl<unsigned> & nbIneqIndices,SmallVectorImpl<unsigned> & nbEqIndices)3496 static void getIndependentConstraints(const FlatAffineConstraints &cst,
3497                                       unsigned pos, unsigned num,
3498                                       SmallVectorImpl<unsigned> &nbIneqIndices,
3499                                       SmallVectorImpl<unsigned> &nbEqIndices) {
3500   assert(pos < cst.getNumIds() && "invalid start position");
3501   assert(pos + num <= cst.getNumIds() && "invalid limit");
3502 
3503   for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
3504     // The bounds are to be independent of [offset, offset + num) columns.
3505     unsigned c;
3506     for (c = pos; c < pos + num; ++c) {
3507       if (cst.atIneq(r, c) != 0)
3508         break;
3509     }
3510     if (c == pos + num)
3511       nbIneqIndices.push_back(r);
3512   }
3513 
3514   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
3515     // The bounds are to be independent of [offset, offset + num) columns.
3516     unsigned c;
3517     for (c = pos; c < pos + num; ++c) {
3518       if (cst.atEq(r, c) != 0)
3519         break;
3520     }
3521     if (c == pos + num)
3522       nbEqIndices.push_back(r);
3523   }
3524 }
3525 
removeIndependentConstraints(unsigned pos,unsigned num)3526 void FlatAffineConstraints::removeIndependentConstraints(unsigned pos,
3527                                                          unsigned num) {
3528   assert(pos + num <= getNumIds() && "invalid range");
3529 
3530   // Remove constraints that are independent of these identifiers.
3531   SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
3532   getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
3533 
3534   // Iterate in reverse so that indices don't have to be updated.
3535   // TODO: This method can be made more efficient (because removal of each
3536   // inequality leads to much shifting/copying in the underlying buffer).
3537   for (auto nbIndex : llvm::reverse(nbIneqIndices))
3538     removeInequality(nbIndex);
3539   for (auto nbIndex : llvm::reverse(nbEqIndices))
3540     removeEquality(nbIndex);
3541 }
3542 
alignAffineMapWithValues(AffineMap map,ValueRange operands,ValueRange dims,ValueRange syms,SmallVector<Value> * newSyms)3543 AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
3544                                          ValueRange dims, ValueRange syms,
3545                                          SmallVector<Value> *newSyms) {
3546   assert(operands.size() == map.getNumInputs() &&
3547          "expected same number of operands and map inputs");
3548   MLIRContext *ctx = map.getContext();
3549   Builder builder(ctx);
3550   SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
3551   unsigned numSymbols = syms.size();
3552   SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
3553   if (newSyms) {
3554     newSyms->clear();
3555     newSyms->append(syms.begin(), syms.end());
3556   }
3557 
3558   for (auto operand : llvm::enumerate(operands)) {
3559     // Compute replacement dim/sym of operand.
3560     AffineExpr replacement;
3561     auto dimIt = std::find(dims.begin(), dims.end(), operand.value());
3562     auto symIt = std::find(syms.begin(), syms.end(), operand.value());
3563     if (dimIt != dims.end()) {
3564       replacement =
3565           builder.getAffineDimExpr(std::distance(dims.begin(), dimIt));
3566     } else if (symIt != syms.end()) {
3567       replacement =
3568           builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt));
3569     } else {
3570       // This operand is neither a dimension nor a symbol. Add it as a new
3571       // symbol.
3572       replacement = builder.getAffineSymbolExpr(numSymbols++);
3573       if (newSyms)
3574         newSyms->push_back(operand.value());
3575     }
3576     // Add to corresponding replacements vector.
3577     if (operand.index() < map.getNumDims()) {
3578       dimReplacements[operand.index()] = replacement;
3579     } else {
3580       symReplacements[operand.index() - map.getNumDims()] = replacement;
3581     }
3582   }
3583 
3584   return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
3585                                    dims.size(), numSymbols);
3586 }
3587