1 //===- Set.cpp - MLIR PresburgerSet Class ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/PresburgerSet.h"
10 #include "mlir/Analysis/Presburger/Simplex.h"
11 #include "llvm/ADT/STLExtras.h"
12 #include "llvm/ADT/SmallBitVector.h"
13 
14 using namespace mlir;
15 
PresburgerSet(const FlatAffineConstraints & fac)16 PresburgerSet::PresburgerSet(const FlatAffineConstraints &fac)
17     : nDim(fac.getNumDimIds()), nSym(fac.getNumSymbolIds()) {
18   unionFACInPlace(fac);
19 }
20 
getNumFACs() const21 unsigned PresburgerSet::getNumFACs() const {
22   return flatAffineConstraints.size();
23 }
24 
getNumDims() const25 unsigned PresburgerSet::getNumDims() const { return nDim; }
26 
getNumSyms() const27 unsigned PresburgerSet::getNumSyms() const { return nSym; }
28 
29 ArrayRef<FlatAffineConstraints>
getAllFlatAffineConstraints() const30 PresburgerSet::getAllFlatAffineConstraints() const {
31   return flatAffineConstraints;
32 }
33 
34 const FlatAffineConstraints &
getFlatAffineConstraints(unsigned index) const35 PresburgerSet::getFlatAffineConstraints(unsigned index) const {
36   assert(index < flatAffineConstraints.size() && "index out of bounds!");
37   return flatAffineConstraints[index];
38 }
39 
40 /// Assert that the FlatAffineConstraints and PresburgerSet live in
41 /// compatible spaces.
assertDimensionsCompatible(const FlatAffineConstraints & fac,const PresburgerSet & set)42 static void assertDimensionsCompatible(const FlatAffineConstraints &fac,
43                                        const PresburgerSet &set) {
44   assert(fac.getNumDimIds() == set.getNumDims() &&
45          "Number of dimensions of the FlatAffineConstraints and PresburgerSet"
46          "do not match!");
47   assert(fac.getNumSymbolIds() == set.getNumSyms() &&
48          "Number of symbols of the FlatAffineConstraints and PresburgerSet"
49          "do not match!");
50 }
51 
52 /// Assert that the two PresburgerSets live in compatible spaces.
assertDimensionsCompatible(const PresburgerSet & setA,const PresburgerSet & setB)53 static void assertDimensionsCompatible(const PresburgerSet &setA,
54                                        const PresburgerSet &setB) {
55   assert(setA.getNumDims() == setB.getNumDims() &&
56          "Number of dimensions of the PresburgerSets do not match!");
57   assert(setA.getNumSyms() == setB.getNumSyms() &&
58          "Number of symbols of the PresburgerSets do not match!");
59 }
60 
61 /// Mutate this set, turning it into the union of this set and the given
62 /// FlatAffineConstraints.
unionFACInPlace(const FlatAffineConstraints & fac)63 void PresburgerSet::unionFACInPlace(const FlatAffineConstraints &fac) {
64   assertDimensionsCompatible(fac, *this);
65   flatAffineConstraints.push_back(fac);
66 }
67 
68 /// Mutate this set, turning it into the union of this set and the given set.
69 ///
70 /// This is accomplished by simply adding all the FACs of the given set to this
71 /// set.
unionSetInPlace(const PresburgerSet & set)72 void PresburgerSet::unionSetInPlace(const PresburgerSet &set) {
73   assertDimensionsCompatible(set, *this);
74   for (const FlatAffineConstraints &fac : set.flatAffineConstraints)
75     unionFACInPlace(fac);
76 }
77 
78 /// Return the union of this set and the given set.
unionSet(const PresburgerSet & set) const79 PresburgerSet PresburgerSet::unionSet(const PresburgerSet &set) const {
80   assertDimensionsCompatible(set, *this);
81   PresburgerSet result = *this;
82   result.unionSetInPlace(set);
83   return result;
84 }
85 
86 /// A point is contained in the union iff any of the parts contain the point.
containsPoint(ArrayRef<int64_t> point) const87 bool PresburgerSet::containsPoint(ArrayRef<int64_t> point) const {
88   for (const FlatAffineConstraints &fac : flatAffineConstraints) {
89     if (fac.containsPoint(point))
90       return true;
91   }
92   return false;
93 }
94 
getUniverse(unsigned nDim,unsigned nSym)95 PresburgerSet PresburgerSet::getUniverse(unsigned nDim, unsigned nSym) {
96   PresburgerSet result(nDim, nSym);
97   result.unionFACInPlace(FlatAffineConstraints::getUniverse(nDim, nSym));
98   return result;
99 }
100 
getEmptySet(unsigned nDim,unsigned nSym)101 PresburgerSet PresburgerSet::getEmptySet(unsigned nDim, unsigned nSym) {
102   return PresburgerSet(nDim, nSym);
103 }
104 
105 // Return the intersection of this set with the given set.
106 //
107 // We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
108 // as (S_1 and T_1) or (S_1 and T_2) or ...
intersect(const PresburgerSet & set) const109 PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
110   assertDimensionsCompatible(set, *this);
111 
112   PresburgerSet result(nDim, nSym);
113   for (const FlatAffineConstraints &csA : flatAffineConstraints) {
114     for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
115       FlatAffineConstraints intersection(csA);
116       intersection.append(csB);
117       if (!intersection.isEmpty())
118         result.unionFACInPlace(std::move(intersection));
119     }
120   }
121   return result;
122 }
123 
124 /// Return `coeffs` with all the elements negated.
getNegatedCoeffs(ArrayRef<int64_t> coeffs)125 static SmallVector<int64_t, 8> getNegatedCoeffs(ArrayRef<int64_t> coeffs) {
126   SmallVector<int64_t, 8> negatedCoeffs;
127   negatedCoeffs.reserve(coeffs.size());
128   for (int64_t coeff : coeffs)
129     negatedCoeffs.emplace_back(-coeff);
130   return negatedCoeffs;
131 }
132 
133 /// Return the complement of the given inequality.
134 ///
135 /// The complement of a_1 x_1 + ... + a_n x_ + c >= 0 is
136 /// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0.
getComplementIneq(ArrayRef<int64_t> ineq)137 static SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq) {
138   SmallVector<int64_t, 8> coeffs;
139   coeffs.reserve(ineq.size());
140   for (int64_t coeff : ineq)
141     coeffs.emplace_back(-coeff);
142   --coeffs.back();
143   return coeffs;
144 }
145 
146 /// Return the set difference b \ s and accumulate the result into `result`.
147 /// `simplex` must correspond to b.
148 ///
149 /// In the following, V denotes union, ^ denotes intersection, \ denotes set
150 /// difference and ~ denotes complement.
151 /// Let b be the FlatAffineConstraints and s = (V_i s_i) be the set. We want
152 /// b \ (V_i s_i).
153 ///
154 /// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute
155 /// b \ s_i = b ^ ~s_i, we partition s_i based on the first violated inequality:
156 /// ~s_i = (~s_i1) V (s_i1 ^ ~s_i2) V (s_i1 ^ s_i2 ^ ~s_i3) V ...
157 /// And the required result is (b ^ ~s_i1) V (b ^ s_i1 ^ ~s_i2) V ...
158 /// We recurse by subtracting V_{j > i} S_j from each of these parts and
159 /// returning the union of the results. Each equality is handled as a
160 /// conjunction of two inequalities.
161 ///
162 /// As a heuristic, we try adding all the constraints and check if simplex
163 /// says that the intersection is empty. Also, in the process we find out that
164 /// some constraints are redundant. These redundant constraints are ignored.
subtractRecursively(FlatAffineConstraints & b,Simplex & simplex,const PresburgerSet & s,unsigned i,PresburgerSet & result)165 static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
166                                 const PresburgerSet &s, unsigned i,
167                                 PresburgerSet &result) {
168   if (i == s.getNumFACs()) {
169     result.unionFACInPlace(b);
170     return;
171   }
172   const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
173   assert(sI.getNumLocalIds() == 0 &&
174          "Subtracting sets with divisions is not yet supported!");
175   unsigned initialSnapshot = simplex.getSnapshot();
176   unsigned offset = simplex.numConstraints();
177   simplex.intersectFlatAffineConstraints(sI);
178 
179   if (simplex.isEmpty()) {
180     /// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
181     simplex.rollback(initialSnapshot);
182     subtractRecursively(b, simplex, s, i + 1, result);
183     return;
184   }
185 
186   simplex.detectRedundant();
187   llvm::SmallBitVector isMarkedRedundant;
188   for (unsigned j = 0; j < 2 * sI.getNumEqualities() + sI.getNumInequalities();
189        j++)
190     isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j));
191 
192   simplex.rollback(initialSnapshot);
193 
194   // Recurse with the part b ^ ~ineq. Note that b is modified throughout
195   // subtractRecursively. At the time this function is called, the current b is
196   // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
197   // inequality, s_{i,j+1}. This function recurses into the next level i + 1
198   // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
199   auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
200     size_t snapshot = simplex.getSnapshot();
201     b.addInequality(ineq);
202     simplex.addInequality(ineq);
203     subtractRecursively(b, simplex, s, i + 1, result);
204     b.removeInequality(b.getNumInequalities() - 1);
205     simplex.rollback(snapshot);
206   };
207 
208   // For each inequality ineq, we first recurse with the part where ineq
209   // is not satisfied, and then add the ineq to b and simplex because
210   // ineq must be satisfied by all later parts.
211   auto processInequality = [&](ArrayRef<int64_t> ineq) {
212     recurseWithInequality(getComplementIneq(ineq));
213     b.addInequality(ineq);
214     simplex.addInequality(ineq);
215   };
216 
217   // processInequality appends some additional constraints to b. We want to
218   // rollback b to its initial state before returning, which we will do by
219   // removing all constraints beyond the original number of inequalities
220   // and equalities, so we store these counts first.
221   unsigned originalNumIneqs = b.getNumInequalities();
222   unsigned originalNumEqs = b.getNumEqualities();
223 
224   for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
225     if (isMarkedRedundant[j])
226       continue;
227     processInequality(sI.getInequality(j));
228   }
229 
230   offset = sI.getNumInequalities();
231   for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
232     const ArrayRef<int64_t> &coeffs = sI.getEquality(j);
233     // Same as the above loop for inequalities, done once each for the positive
234     // and negative inequalities that make up this equality.
235     if (!isMarkedRedundant[offset + 2 * j])
236       processInequality(coeffs);
237     if (!isMarkedRedundant[offset + 2 * j + 1])
238       processInequality(getNegatedCoeffs(coeffs));
239   }
240 
241   // Rollback b and simplex to their initial states.
242   for (unsigned i = b.getNumInequalities(); i > originalNumIneqs; --i)
243     b.removeInequality(i - 1);
244 
245   for (unsigned i = b.getNumEqualities(); i > originalNumEqs; --i)
246     b.removeEquality(i - 1);
247 
248   simplex.rollback(initialSnapshot);
249 }
250 
251 /// Return the set difference fac \ set.
252 ///
253 /// The FAC here is modified in subtractRecursively, so it cannot be a const
254 /// reference even though it is restored to its original state before returning
255 /// from that function.
getSetDifference(FlatAffineConstraints fac,const PresburgerSet & set)256 PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
257                                               const PresburgerSet &set) {
258   assertDimensionsCompatible(fac, set);
259   assert(fac.getNumLocalIds() == 0 &&
260          "Subtracting sets with divisions is not yet supported!");
261   if (fac.isEmptyByGCDTest())
262     return PresburgerSet::getEmptySet(fac.getNumDimIds(),
263                                       fac.getNumSymbolIds());
264 
265   PresburgerSet result(fac.getNumDimIds(), fac.getNumSymbolIds());
266   Simplex simplex(fac);
267   subtractRecursively(fac, simplex, set, 0, result);
268   return result;
269 }
270 
271 /// Return the complement of this set.
complement() const272 PresburgerSet PresburgerSet::complement() const {
273   return getSetDifference(
274       FlatAffineConstraints::getUniverse(getNumDims(), getNumSyms()), *this);
275 }
276 
277 /// Return the result of subtract the given set from this set, i.e.,
278 /// return `this \ set`.
subtract(const PresburgerSet & set) const279 PresburgerSet PresburgerSet::subtract(const PresburgerSet &set) const {
280   assertDimensionsCompatible(set, *this);
281   PresburgerSet result(nDim, nSym);
282   // We compute (V_i t_i) \ (V_i set_i) as V_i (t_i \ V_i set_i).
283   for (const FlatAffineConstraints &fac : flatAffineConstraints)
284     result.unionSetInPlace(getSetDifference(fac, set));
285   return result;
286 }
287 
288 /// Two sets S and T are equal iff S contains T and T contains S.
289 /// By "S contains T", we mean that S is a superset of or equal to T.
290 ///
291 /// S contains T iff T \ S is empty, since if T \ S contains a
292 /// point then this is a point that is contained in T but not S.
293 ///
294 /// Therefore, S is equal to T iff S \ T and T \ S are both empty.
isEqual(const PresburgerSet & set) const295 bool PresburgerSet::isEqual(const PresburgerSet &set) const {
296   assertDimensionsCompatible(set, *this);
297   return this->subtract(set).isIntegerEmpty() &&
298          set.subtract(*this).isIntegerEmpty();
299 }
300 
301 /// Return true if all the sets in the union are known to be integer empty,
302 /// false otherwise.
isIntegerEmpty() const303 bool PresburgerSet::isIntegerEmpty() const {
304   // The set is empty iff all of the disjuncts are empty.
305   for (const FlatAffineConstraints &fac : flatAffineConstraints) {
306     if (!fac.isIntegerEmpty())
307       return false;
308   }
309   return true;
310 }
311 
findIntegerSample(SmallVectorImpl<int64_t> & sample)312 bool PresburgerSet::findIntegerSample(SmallVectorImpl<int64_t> &sample) {
313   // A sample exists iff any of the disjuncts contains a sample.
314   for (const FlatAffineConstraints &fac : flatAffineConstraints) {
315     if (Optional<SmallVector<int64_t, 8>> opt = fac.findIntegerSample()) {
316       sample = std::move(*opt);
317       return true;
318     }
319   }
320   return false;
321 }
322 
print(raw_ostream & os) const323 void PresburgerSet::print(raw_ostream &os) const {
324   os << getNumFACs() << " FlatAffineConstraints:\n";
325   for (const FlatAffineConstraints &fac : flatAffineConstraints) {
326     fac.print(os);
327     os << '\n';
328   }
329 }
330 
dump() const331 void PresburgerSet::dump() const { print(llvm::errs()); }
332