1 //===- SetTest.cpp - Tests for PresburgerSet ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains tests for PresburgerSet. The tests for union,
10 // intersection, subtract, and complement work by computing the operation on
11 // two sets and checking, for a set of points, that the resulting set contains
12 // the point iff the result is supposed to contain it. The test for isEqual just
13 // checks if the result for two sets matches the expected result.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Analysis/PresburgerSet.h"
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 
22 namespace mlir {
23 
24 /// Compute the union of s and t, and check that each of the given points
25 /// belongs to the union iff it belongs to at least one of s and t.
testUnionAtPoints(PresburgerSet s,PresburgerSet t,ArrayRef<SmallVector<int64_t,4>> points)26 static void testUnionAtPoints(PresburgerSet s, PresburgerSet t,
27                               ArrayRef<SmallVector<int64_t, 4>> points) {
28   PresburgerSet unionSet = s.unionSet(t);
29   for (const SmallVector<int64_t, 4> &point : points) {
30     bool inS = s.containsPoint(point);
31     bool inT = t.containsPoint(point);
32     bool inUnion = unionSet.containsPoint(point);
33     EXPECT_EQ(inUnion, inS || inT);
34   }
35 }
36 
37 /// Compute the intersection of s and t, and check that each of the given points
38 /// belongs to the intersection iff it belongs to both s and t.
testIntersectAtPoints(PresburgerSet s,PresburgerSet t,ArrayRef<SmallVector<int64_t,4>> points)39 static void testIntersectAtPoints(PresburgerSet s, PresburgerSet t,
40                                   ArrayRef<SmallVector<int64_t, 4>> points) {
41   PresburgerSet intersection = s.intersect(t);
42   for (const SmallVector<int64_t, 4> &point : points) {
43     bool inS = s.containsPoint(point);
44     bool inT = t.containsPoint(point);
45     bool inIntersection = intersection.containsPoint(point);
46     EXPECT_EQ(inIntersection, inS && inT);
47   }
48 }
49 
50 /// Compute the set difference s \ t, and check that each of the given points
51 /// belongs to the difference iff it belongs to s and does not belong to t.
testSubtractAtPoints(PresburgerSet s,PresburgerSet t,ArrayRef<SmallVector<int64_t,4>> points)52 static void testSubtractAtPoints(PresburgerSet s, PresburgerSet t,
53                                  ArrayRef<SmallVector<int64_t, 4>> points) {
54   PresburgerSet diff = s.subtract(t);
55   for (const SmallVector<int64_t, 4> &point : points) {
56     bool inS = s.containsPoint(point);
57     bool inT = t.containsPoint(point);
58     bool inDiff = diff.containsPoint(point);
59     if (inT)
60       EXPECT_FALSE(inDiff);
61     else
62       EXPECT_EQ(inDiff, inS);
63   }
64 }
65 
66 /// Compute the complement of s, and check that each of the given points
67 /// belongs to the complement iff it does not belong to s.
testComplementAtPoints(PresburgerSet s,ArrayRef<SmallVector<int64_t,4>> points)68 static void testComplementAtPoints(PresburgerSet s,
69                                    ArrayRef<SmallVector<int64_t, 4>> points) {
70   PresburgerSet complement = s.complement();
71   complement.complement();
72   for (const SmallVector<int64_t, 4> &point : points) {
73     bool inS = s.containsPoint(point);
74     bool inComplement = complement.containsPoint(point);
75     if (inS)
76       EXPECT_FALSE(inComplement);
77     else
78       EXPECT_TRUE(inComplement);
79   }
80 }
81 
82 /// Construct a FlatAffineConstraints from a set of inequality and
83 /// equality constraints. `numIds` is the total number of ids, of which
84 /// `numLocals` is the number of local ids.
85 static FlatAffineConstraints
makeFACFromConstraints(unsigned numIds,ArrayRef<SmallVector<int64_t,4>> ineqs,ArrayRef<SmallVector<int64_t,4>> eqs,unsigned numLocals=0)86 makeFACFromConstraints(unsigned numIds, ArrayRef<SmallVector<int64_t, 4>> ineqs,
87                        ArrayRef<SmallVector<int64_t, 4>> eqs,
88                        unsigned numLocals = 0) {
89   FlatAffineConstraints fac(/*numReservedInequalities=*/ineqs.size(),
90                             /*numReservedEqualities=*/eqs.size(),
91                             /*numReservedCols=*/numIds + 1,
92                             /*numDims=*/numIds - numLocals,
93                             /*numSymbols=*/0, numLocals);
94   for (const SmallVector<int64_t, 4> &eq : eqs)
95     fac.addEquality(eq);
96   for (const SmallVector<int64_t, 4> &ineq : ineqs)
97     fac.addInequality(ineq);
98   return fac;
99 }
100 
101 /// Construct a FlatAffineConstraints having `numDims` dimensions from the given
102 /// set of inequality constraints. This is a convenience function to be used
103 /// when the FAC to be constructed does not have any local ids and does not have
104 /// equalties.
105 static FlatAffineConstraints
makeFACFromIneqs(unsigned numDims,ArrayRef<SmallVector<int64_t,4>> ineqs)106 makeFACFromIneqs(unsigned numDims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
107   return makeFACFromConstraints(numDims, ineqs, /*eqs=*/{});
108 }
109 
110 /// Construct a PresburgerSet having `numDims` dimensions and no symbols from
111 /// the given list of FlatAffineConstraints. Each FAC in `facs` should also have
112 /// `numDims` dimensions and no symbols, although it can have any number of
113 /// local ids.
makeSetFromFACs(unsigned numDims,ArrayRef<FlatAffineConstraints> facs)114 static PresburgerSet makeSetFromFACs(unsigned numDims,
115                                      ArrayRef<FlatAffineConstraints> facs) {
116   PresburgerSet set = PresburgerSet::getEmptySet(numDims);
117   for (const FlatAffineConstraints &fac : facs)
118     set.unionFACInPlace(fac);
119   return set;
120 }
121 
TEST(SetTest, containsPoint) {
  // setA = [2, 8] U [10, 20].
  PresburgerSet setA =
      makeSetFromFACs(1, {
                             makeFACFromIneqs(1, {{1, -2},    // x >= 2.
                                                  {-1, 8}}),  // x <= 8.
                             makeFACFromIneqs(1, {{1, -10},   // x >= 10.
                                                  {-1, 20}}), // x <= 20.
                         });
  for (int64_t x = 0; x <= 21; ++x) {
    if ((2 <= x && x <= 8) || (10 <= x && x <= 20))
      EXPECT_TRUE(setA.containsPoint({x}));
    else
      EXPECT_FALSE(setA.containsPoint({x}));
  }

  // A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} union
  // a square with opposite corners (2, 2) and (10, 10).
  PresburgerSet setB =
      makeSetFromFACs(2, {makeFACFromIneqs(2,
                                           {
                                               {1, 1, -4},   // x + y >= 4.
                                               {-1, -1, 32}, // x + y <= 32.
                                               {1, -1, -2},  // x - y >= 2.
                                               {-1, 1, 16},  // x - y <= 16.
                                           }),
                          makeFACFromIneqs(2, {
                                                  {1, 0, -2},  // x >= 2.
                                                  {0, 1, -2},  // y >= 2.
                                                  {-1, 0, 10}, // x <= 10.
                                                  {0, -1, 10}  // y <= 10.
                                              })});

  // The loop variables must be signed: y starts below zero and x - y can be
  // negative. With an unsigned y initialised to -6 the inner loop would wrap
  // to a huge value and run zero times.
  for (int64_t x = 1; x <= 25; ++x) {
    for (int64_t y = -6; y <= 16; ++y) {
      if (4 <= x + y && x + y <= 32 && 2 <= x - y && x - y <= 16)
        EXPECT_TRUE(setB.containsPoint({x, y}));
      else if (2 <= x && x <= 10 && 2 <= y && y <= 10)
        EXPECT_TRUE(setB.containsPoint({x, y}));
      else
        EXPECT_FALSE(setB.containsPoint({x, y}));
    }
  }
}
165 
TEST(SetTest, Union) {
  // set = [2, 8] U [10, 20].
  PresburgerSet set =
      makeSetFromFACs(1, {
                             makeFACFromIneqs(1, {{1, -2},    // x >= 2.
                                                  {-1, 8}}),  // x <= 8.
                             makeFACFromIneqs(1, {{1, -10},   // x >= 10.
                                                  {-1, 20}}), // x <= 20.
                         });
  PresburgerSet universe = PresburgerSet::getUniverse(1);
  PresburgerSet empty = PresburgerSet::getEmptySet(1);

  // Points on and just beyond the endpoints of both intervals of `set`.
  const SmallVector<SmallVector<int64_t, 4>, 8> boundaryPoints = {
      {1}, {2}, {8}, {9}, {10}, {20}, {21}};
  // A few points around the origin for the cases that do not involve `set`.
  const SmallVector<SmallVector<int64_t, 4>, 4> originPoints = {
      {1}, {2}, {0}, {-1}};

  testUnionAtPoints(universe, set, boundaryPoints);  // universe U set
  testUnionAtPoints(empty, set, boundaryPoints);     // empty U set
  testUnionAtPoints(empty, universe, originPoints);  // empty U universe
  testUnionAtPoints(universe, empty, originPoints);  // universe U empty
  testUnionAtPoints(empty, empty, originPoints);     // empty U empty
}
195 
TEST(SetTest, Intersect) {
  // set = [2, 8] U [10, 20].
  PresburgerSet set =
      makeSetFromFACs(1, {
                             makeFACFromIneqs(1, {{1, -2},    // x >= 2.
                                                  {-1, 8}}),  // x <= 8.
                             makeFACFromIneqs(1, {{1, -10},   // x >= 10.
                                                  {-1, 20}}), // x <= 20.
                         });
  PresburgerSet universe = PresburgerSet::getUniverse(1);
  PresburgerSet empty = PresburgerSet::getEmptySet(1);

  // Points on and just beyond the endpoints of both intervals of `set`.
  const SmallVector<SmallVector<int64_t, 4>, 8> boundaryPoints = {
      {1}, {2}, {8}, {9}, {10}, {20}, {21}};
  // A few points around the origin for the cases that do not involve `set`.
  const SmallVector<SmallVector<int64_t, 4>, 4> originPoints = {
      {1}, {2}, {0}, {-1}};

  testIntersectAtPoints(universe, set, boundaryPoints);    // universe /\ set
  testIntersectAtPoints(empty, set, boundaryPoints);       // empty /\ set
  testIntersectAtPoints(empty, universe, originPoints);    // empty /\ universe
  testIntersectAtPoints(universe, empty, originPoints);    // universe /\ empty
  testIntersectAtPoints(universe, universe, originPoints); // universe /\ univ.
}
225 
TEST(SetTest,Subtract)226 TEST(SetTest, Subtract) {
227   // The interval [2, 8] minus
228   // the interval [10, 20].
229   testSubtractAtPoints(
230       makeSetFromFACs(1, {makeFACFromIneqs(1, {})}),
231       makeSetFromFACs(1,
232                       {
233                           makeFACFromIneqs(1, {{1, -2},    // x >= 2.
234                                                {-1, 8}}),  // x <= 8.
235                           makeFACFromIneqs(1, {{1, -10},   // x >= 10.
236                                                {-1, 20}}), // x <= 20.
237                       }),
238       {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
239 
240   // ((-infinity, 0] U [3, 4] U [6, 7]) - ([2, 3] U [5, 6])
241   testSubtractAtPoints(
242       makeSetFromFACs(1,
243                       {
244                           makeFACFromIneqs(1,
245                                            {
246                                                {-1, 0} // x <= 0.
247                                            }),
248                           makeFACFromIneqs(1,
249                                            {
250                                                {1, -3}, // x >= 3.
251                                                {-1, 4}  // x <= 4.
252                                            }),
253                           makeFACFromIneqs(1,
254                                            {
255                                                {1, -6}, // x >= 6.
256                                                {-1, 7}  // x <= 7.
257                                            }),
258                       }),
259       makeSetFromFACs(1, {makeFACFromIneqs(1,
260                                            {
261                                                {1, -2}, // x >= 2.
262                                                {-1, 3}, // x <= 3.
263                                            }),
264                           makeFACFromIneqs(1,
265                                            {
266                                                {1, -5}, // x >= 5.
267                                                {-1, 6}  // x <= 6.
268                                            })}),
269       {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
270 
271   // Expected result is {[x, y] : x > y}, i.e., {[x, y] : x >= y + 1}.
272   testSubtractAtPoints(
273       makeSetFromFACs(2, {makeFACFromIneqs(2,
274                                            {
275                                                {1, -1, 0} // x >= y.
276                                            })}),
277       makeSetFromFACs(2, {makeFACFromIneqs(2,
278                                            {
279                                                {1, 1, 0} // x >= -y.
280                                            })}),
281       {{0, 1}, {1, 1}, {1, 0}, {1, -1}, {0, -1}});
282 
283   // A rectangle with corners at (2, 2) and (10, 10), minus
284   // a rectangle with corners at (5, -10) and (7, 100).
285   // This splits the former rectangle into two halves, (2, 2) to (5, 10) and
286   // (7, 2) to (10, 10).
287   testSubtractAtPoints(
288       makeSetFromFACs(2, {makeFACFromIneqs(2,
289                                            {
290                                                {1, 0, -2},  // x >= 2.
291                                                {0, 1, -2},  // y >= 2.
292                                                {-1, 0, 10}, // x <= 10.
293                                                {0, -1, 10}  // y <= 10.
294                                            })}),
295       makeSetFromFACs(2, {makeFACFromIneqs(2,
296                                            {
297                                                {1, 0, -5},   // x >= 5.
298                                                {0, 1, 10},   // y >= -10.
299                                                {-1, 0, 7},   // x <= 7.
300                                                {0, -1, 100}, // y <= 100.
301                                            })}),
302       {{1, 2},  {2, 2},  {4, 2},  {5, 2},  {7, 2},  {8, 2},  {11, 2},
303        {1, 1},  {2, 1},  {4, 1},  {5, 1},  {7, 1},  {8, 1},  {11, 1},
304        {1, 10}, {2, 10}, {4, 10}, {5, 10}, {7, 10}, {8, 10}, {11, 10},
305        {1, 11}, {2, 11}, {4, 11}, {5, 11}, {7, 11}, {8, 11}, {11, 11}});
306 
307   // A rectangle with corners at (2, 2) and (10, 10), minus
308   // a rectangle with corners at (5, 4) and (7, 8).
309   // This creates a hole in the middle of the former rectangle, and the
310   // resulting set can be represented as a union of four rectangles.
311   testSubtractAtPoints(
312       makeSetFromFACs(2, {makeFACFromIneqs(2,
313                                            {
314                                                {1, 0, -2},  // x >= 2.
315                                                {0, 1, -2},  // y >= 2.
316                                                {-1, 0, 10}, // x <= 10.
317                                                {0, -1, 10}  // y <= 10.
318                                            })}),
319       makeSetFromFACs(2, {makeFACFromIneqs(2,
320                                            {
321                                                {1, 0, -5}, // x >= 5.
322                                                {0, 1, -4}, // y >= 4.
323                                                {-1, 0, 7}, // x <= 7.
324                                                {0, -1, 8}, // y <= 8.
325                                            })}),
326       {{1, 1},
327        {2, 2},
328        {10, 10},
329        {11, 11},
330        {5, 4},
331        {7, 4},
332        {5, 8},
333        {7, 8},
334        {4, 4},
335        {8, 4},
336        {4, 8},
337        {8, 8}});
338 
339   // The second set is a superset of the first one, since on the line x + y = 0,
340   // y <= 1 is equivalent to x >= -1. So the result is empty.
341   testSubtractAtPoints(
342       makeSetFromFACs(2, {makeFACFromConstraints(2,
343                                                  {
344                                                      {1, 0, 0} // x >= 0.
345                                                  },
346                                                  {
347                                                      {1, 1, 0} // x + y = 0.
348                                                  })}),
349       makeSetFromFACs(2, {makeFACFromConstraints(2,
350                                                  {
351                                                      {0, -1, 1} // y <= 1.
352                                                  },
353                                                  {
354                                                      {1, 1, 0} // x + y = 0.
355                                                  })}),
356       {{0, 0},
357        {1, -1},
358        {2, -2},
359        {-1, 1},
360        {-2, 2},
361        {1, 1},
362        {-1, -1},
363        {-1, 1},
364        {1, -1}});
365 
366   // The result should be {0} U {2}.
367   testSubtractAtPoints(
368       makeSetFromFACs(1,
369                       {
370                           makeFACFromIneqs(1, {{1, 0},    // x >= 0.
371                                                {-1, 2}}), // x <= 2.
372                       }),
373       makeSetFromFACs(1,
374                       {
375                           makeFACFromConstraints(1, {},
376                                                  {
377                                                      {1, -1} // x = 1.
378                                                  }),
379                       }),
380       {{-1}, {0}, {1}, {2}, {3}});
381 
382   // Sets with lots of redundant inequalities to test the redundancy heuristic.
383   // (the heuristic is for the subtrahend, the second set which is the one being
384   // subtracted)
385 
386   // A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} minus
387   // a triangle with vertices {(2, 2), (10, 2), (10, 10)}.
388   testSubtractAtPoints(
389       makeSetFromFACs(2, {makeFACFromIneqs(2,
390                                            {
391                                                {1, 1, -2},   // x + y >= 4.
392                                                {-1, -1, 30}, // x + y <= 32.
393                                                {1, -1, 0},   // x - y >= 2.
394                                                {-1, 1, 10},  // x - y <= 16.
395                                            })}),
396       makeSetFromFACs(
397           2, {makeFACFromIneqs(2,
398                                {
399                                    {1, 0, -2},   // x >= 2. [redundant]
400                                    {0, 1, -2},   // y >= 2.
401                                    {-1, 0, 10},  // x <= 10.
402                                    {0, -1, 10},  // y <= 10. [redundant]
403                                    {1, 1, -2},   // x + y >= 2. [redundant]
404                                    {-1, -1, 30}, // x + y <= 30. [redundant]
405                                    {1, -1, 0},   // x - y >= 0.
406                                    {-1, 1, 10},  // x - y <= 10.
407                                })}),
408       {{1, 2},  {2, 2},   {3, 2},   {4, 2},  {1, 1},   {2, 1},   {3, 1},
409        {4, 1},  {2, 0},   {3, 0},   {4, 0},  {5, 0},   {10, 2},  {11, 2},
410        {10, 1}, {10, 10}, {10, 11}, {10, 9}, {11, 10}, {10, -6}, {11, -6},
411        {24, 8}, {24, 7},  {17, 15}, {16, 15}});
412 
413   testSubtractAtPoints(
414       makeSetFromFACs(2, {makeFACFromIneqs(2,
415                                            {
416                                                {1, 1, -2},   // x + y >= 4.
417                                                {-1, -1, 30}, // x + y <= 32.
418                                                {1, -1, 0},   // x - y >= 2.
419                                                {-1, 1, 10},  // x - y <= 16.
420                                            })}),
421       makeSetFromFACs(
422           2, {makeFACFromIneqs(2,
423                                {
424                                    {1, 0, -2},   // x >= 2. [redundant]
425                                    {0, 1, -2},   // y >= 2.
426                                    {-1, 0, 10},  // x <= 10.
427                                    {0, -1, 10},  // y <= 10. [redundant]
428                                    {1, 1, -2},   // x + y >= 2. [redundant]
429                                    {-1, -1, 30}, // x + y <= 30. [redundant]
430                                    {1, -1, 0},   // x - y >= 0.
431                                    {-1, 1, 10},  // x - y <= 10.
432                                })}),
433       {{1, 2},  {2, 2},   {3, 2},   {4, 2},  {1, 1},   {2, 1},   {3, 1},
434        {4, 1},  {2, 0},   {3, 0},   {4, 0},  {5, 0},   {10, 2},  {11, 2},
435        {10, 1}, {10, 10}, {10, 11}, {10, 9}, {11, 10}, {10, -6}, {11, -6},
436        {24, 8}, {24, 7},  {17, 15}, {16, 15}});
437 
438   // ((-infinity, -5] U [3, 3] U [4, 4] U [5, 5]) - ([-2, -10] U [3, 4] U [6,
439   // 7])
440   testSubtractAtPoints(
441       makeSetFromFACs(1,
442                       {
443                           makeFACFromIneqs(1,
444                                            {
445                                                {-1, -5}, // x <= -5.
446                                            }),
447                           makeFACFromConstraints(1, {},
448                                                  {
449                                                      {1, -3} // x = 3.
450                                                  }),
451                           makeFACFromConstraints(1, {},
452                                                  {
453                                                      {1, -4} // x = 4.
454                                                  }),
455                           makeFACFromConstraints(1, {},
456                                                  {
457                                                      {1, -5} // x = 5.
458                                                  }),
459                       }),
460       makeSetFromFACs(
461           1,
462           {
463               makeFACFromIneqs(1,
464                                {
465                                    {-1, -2},  // x <= -2.
466                                    {1, -10},  // x >= -10.
467                                    {-1, 0},   // x <= 0. [redundant]
468                                    {-1, 10},  // x <= 10. [redundant]
469                                    {1, -100}, // x >= -100. [redundant]
470                                    {1, -50}   // x >= -50. [redundant]
471                                }),
472               makeFACFromIneqs(1,
473                                {
474                                    {1, -3}, // x >= 3.
475                                    {-1, 4}, // x <= 4.
476                                    {1, 1},  // x >= -1. [redundant]
477                                    {1, 7},  // x >= -7. [redundant]
478                                    {-1, 10} // x <= 10. [redundant]
479                                }),
480               makeFACFromIneqs(1,
481                                {
482                                    {1, -6}, // x >= 6.
483                                    {-1, 7}, // x <= 7.
484                                    {1, 1},  // x >= -1. [redundant]
485                                    {1, -3}, // x >= -3. [redundant]
486                                    {-1, 5}  // x <= 5. [redundant]
487                                }),
488           }),
489       {{-6},
490        {-5},
491        {-4},
492        {-9},
493        {-10},
494        {-11},
495        {0},
496        {1},
497        {2},
498        {3},
499        {4},
500        {5},
501        {6},
502        {7},
503        {8}});
504 }
505 
TEST(SetTest, Complement) {
  // Sample points on both sides of the origin, shared by the 1-D cases.
  const SmallVector<SmallVector<int64_t, 4>, 10> linePoints = {
      {-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}};

  // The complement of the universe, then the complement of the empty set.
  testComplementAtPoints(PresburgerSet::getUniverse(1), linePoints);
  testComplementAtPoints(PresburgerSet::getEmptySet(1), linePoints);

  // The square with opposite corners (2, 2) and (10, 10), probed at points on
  // and just outside its corners and edges.
  PresburgerSet square =
      makeSetFromFACs(2, {makeFACFromIneqs(2,
                                           {
                                               {1, 0, -2},  // x >= 2.
                                               {0, 1, -2},  // y >= 2.
                                               {-1, 0, 10}, // x <= 10.
                                               {0, -1, 10}  // y <= 10.
                                           })});
  testComplementAtPoints(square, {{1, 1},
                                  {2, 1},
                                  {1, 2},
                                  {2, 2},
                                  {2, 3},
                                  {3, 2},
                                  {10, 10},
                                  {10, 11},
                                  {11, 10},
                                  {2, 10},
                                  {2, 11},
                                  {1, 10}});
}
538 
TEST(SetTest,isEqual)539 TEST(SetTest, isEqual) {
540   // set = [2, 8] U [10, 20].
541   PresburgerSet universe = PresburgerSet::getUniverse(1);
542   PresburgerSet emptySet = PresburgerSet::getEmptySet(1);
543   PresburgerSet set =
544       makeSetFromFACs(1, {
545                              makeFACFromIneqs(1, {{1, -2},    // x >= 2.
546                                                   {-1, 8}}),  // x <= 8.
547                              makeFACFromIneqs(1, {{1, -10},   // x >= 10.
548                                                   {-1, 20}}), // x <= 20.
549                          });
550 
551   // universe != emptySet.
552   EXPECT_FALSE(universe.isEqual(emptySet));
553   // emptySet != universe.
554   EXPECT_FALSE(emptySet.isEqual(universe));
555   // emptySet == emptySet.
556   EXPECT_TRUE(emptySet.isEqual(emptySet));
557   // universe == universe.
558   EXPECT_TRUE(universe.isEqual(universe));
559 
560   // universe U emptySet == universe.
561   EXPECT_TRUE(universe.unionSet(emptySet).isEqual(universe));
562   // universe U universe == universe.
563   EXPECT_TRUE(universe.unionSet(universe).isEqual(universe));
564   // emptySet U emptySet == emptySet.
565   EXPECT_TRUE(emptySet.unionSet(emptySet).isEqual(emptySet));
566   // universe U emptySet != emptySet.
567   EXPECT_FALSE(universe.unionSet(emptySet).isEqual(emptySet));
568   // universe U universe != emptySet.
569   EXPECT_FALSE(universe.unionSet(universe).isEqual(emptySet));
570   // emptySet U emptySet != universe.
571   EXPECT_FALSE(emptySet.unionSet(emptySet).isEqual(universe));
572 
573   // set \ set == emptySet.
574   EXPECT_TRUE(set.subtract(set).isEqual(emptySet));
575   // set == set.
576   EXPECT_TRUE(set.isEqual(set));
577   // set U (universe \ set) == universe.
578   EXPECT_TRUE(set.unionSet(set.complement()).isEqual(universe));
579   // set U (universe \ set) != set.
580   EXPECT_FALSE(set.unionSet(set.complement()).isEqual(set));
581   // set != set U (universe \ set).
582   EXPECT_FALSE(set.isEqual(set.unionSet(set.complement())));
583 
584   // square is one unit taller than rect.
585   PresburgerSet square =
586       makeSetFromFACs(2, {makeFACFromIneqs(2, {
587                                                   {1, 0, -2}, // x >= 2.
588                                                   {0, 1, -2}, // y >= 2.
589                                                   {-1, 0, 9}, // x <= 9.
590                                                   {0, -1, 9}  // y <= 9.
591                                               })});
592   PresburgerSet rect =
593       makeSetFromFACs(2, {makeFACFromIneqs(2, {
594                                                   {1, 0, -2}, // x >= 2.
595                                                   {0, 1, -2}, // y >= 2.
596                                                   {-1, 0, 9}, // x <= 9.
597                                                   {0, -1, 8}  // y <= 8.
598                                               })});
599   EXPECT_FALSE(square.isEqual(rect));
600   PresburgerSet universeRect = square.unionSet(square.complement());
601   PresburgerSet universeSquare = rect.unionSet(rect.complement());
602   EXPECT_TRUE(universeRect.isEqual(universeSquare));
603   EXPECT_FALSE(universeRect.isEqual(rect));
604   EXPECT_FALSE(universeSquare.isEqual(square));
605   EXPECT_FALSE(rect.complement().isEqual(square.complement()));
606 }
607 
expectEqual(PresburgerSet s,PresburgerSet t)608 void expectEqual(PresburgerSet s, PresburgerSet t) {
609   EXPECT_TRUE(s.isEqual(t));
610 }
611 
expectEmpty(PresburgerSet s)612 void expectEmpty(PresburgerSet s) { EXPECT_TRUE(s.isIntegerEmpty()); }
613 
TEST(SetTest, divisions) {
  // Each set below has one dimension x followed by one local id q. Besides the
  // equality that defines the set, each FAC carries the inequalities
  // k*q <= x <= k*q + (k - 1), which pin q to floor(x / k). These inequalities
  // are needed because divisions are not yet detected from equalities alone.

  // evens = {x : exists q, x = 2q}.
  PresburgerSet evens{
      makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, 0}}, 1)};
  // odds = {x : exists q, x = 2q + 1}.
  PresburgerSet odds{
      makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, -1}}, 1)};
  // multiples3 = {x : exists q, x = 3q}.
  PresburgerSet multiples3{
      makeFACFromConstraints(2, {{1, -3, 0}, {-1, 3, 2}}, {{1, -3, 0}}, 1)};
  // multiples6 = {x : exists q, x = 6q}.
  PresburgerSet multiples6{
      makeFACFromConstraints(2, {{1, -6, 0}, {-1, 6, 5}}, {{1, -6, 0}}, 1)};

  // evens /\ odds = empty.
  expectEmpty(evens.intersect(odds));
  // evens U odds = universe.
  expectEqual(evens.unionSet(odds), PresburgerSet::getUniverse(1));
  // Each parity class is the complement of the other.
  expectEqual(evens.complement(), odds);
  expectEqual(odds.complement(), evens);
  // even multiples of 3 = multiples of 6.
  expectEqual(multiples3.intersect(evens), multiples6);
}
640 
641 } // namespace mlir
642