1 //===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// A bitvector that uses an IntervalMap to coalesce adjacent elements
11 /// into intervals.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_ADT_COALESCINGBITVECTOR_H
16 #define LLVM_ADT_COALESCINGBITVECTOR_H
17 
18 #include "llvm/ADT/IntervalMap.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/iterator_range.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/raw_ostream.h"
24 
25 #include <initializer_list>
26 
27 namespace llvm {
28 
29 /// A bitvector that, under the hood, relies on an IntervalMap to coalesce
30 /// elements into intervals. Good for representing sets which predominantly
31 /// contain contiguous ranges. Bad for representing sets with lots of gaps
32 /// between elements.
33 ///
34 /// Compared to SparseBitVector, CoalescingBitVector offers more predictable
35 /// performance for non-sequential find() operations.
36 ///
37 /// \tparam IndexT - The type of the index into the bitvector.
38 template <typename IndexT> class CoalescingBitVector {
39   static_assert(std::is_unsigned<IndexT>::value,
40                 "Index must be an unsigned integer.");
41 
42   using ThisT = CoalescingBitVector<IndexT>;
43 
44   /// An interval map for closed integer ranges. The mapped values are unused.
45   using MapT = IntervalMap<IndexT, char>;
46 
47   using UnderlyingIterator = typename MapT::const_iterator;
48 
49   using IntervalT = std::pair<IndexT, IndexT>;
50 
51 public:
52   using Allocator = typename MapT::Allocator;
53 
54   /// Construct by passing in a CoalescingBitVector<IndexT>::Allocator
55   /// reference.
56   CoalescingBitVector(Allocator &Alloc)
57       : Alloc(&Alloc), Intervals(Alloc) {}
58 
59   /// \name Copy/move constructors and assignment operators.
60   /// @{
61 
62   CoalescingBitVector(const ThisT &Other)
63       : Alloc(Other.Alloc), Intervals(*Other.Alloc) {
64     set(Other);
65   }
66 
67   ThisT &operator=(const ThisT &Other) {
68     clear();
69     set(Other);
70     return *this;
71   }
72 
73   CoalescingBitVector(ThisT &&Other) = delete;
74   ThisT &operator=(ThisT &&Other) = delete;
75 
76   /// @}
77 
78   /// Clear all the bits.
79   void clear() { Intervals.clear(); }
80 
81   /// Check whether no bits are set.
82   bool empty() const { return Intervals.empty(); }
83 
84   /// Count the number of set bits.
85   unsigned count() const {
86     unsigned Bits = 0;
87     for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It)
88       Bits += 1 + It.stop() - It.start();
89     return Bits;
90   }
91 
92   /// Set the bit at \p Index.
93   ///
94   /// This method does /not/ support setting a bit that has already been set,
95   /// for efficiency reasons. If possible, restructure your code to not set the
96   /// same bit multiple times, or use \ref test_and_set.
97   void set(IndexT Index) {
98     assert(!test(Index) && "Setting already-set bits not supported/efficient, "
99                            "IntervalMap will assert");
100     insert(Index, Index);
101   }
102 
103   /// Set the bits set in \p Other.
104   ///
105   /// This method does /not/ support setting already-set bits, see \ref set
106   /// for the rationale. For a safe set union operation, use \ref operator|=.
107   void set(const ThisT &Other) {
108     for (auto It = Other.Intervals.begin(), End = Other.Intervals.end();
109          It != End; ++It)
110       insert(It.start(), It.stop());
111   }
112 
113   /// Set the bits at \p Indices. Used for testing, primarily.
114   void set(std::initializer_list<IndexT> Indices) {
115     for (IndexT Index : Indices)
116       set(Index);
117   }
118 
119   /// Check whether the bit at \p Index is set.
120   bool test(IndexT Index) const {
121     const auto It = Intervals.find(Index);
122     if (It == Intervals.end())
123       return false;
124     assert(It.stop() >= Index && "Interval must end after Index");
125     return It.start() <= Index;
126   }
127 
128   /// Set the bit at \p Index. Supports setting an already-set bit.
129   void test_and_set(IndexT Index) {
130     if (!test(Index))
131       set(Index);
132   }
133 
134   /// Reset the bit at \p Index. Supports resetting an already-unset bit.
135   void reset(IndexT Index) {
136     auto It = Intervals.find(Index);
137     if (It == Intervals.end())
138       return;
139 
140     // Split the interval containing Index into up to two parts: one from
141     // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to
142     // either Start or Stop, we create one new interval. If Index is equal to
143     // both Start and Stop, we simply erase the existing interval.
144     IndexT Start = It.start();
145     if (Index < Start)
146       // The index was not set.
147       return;
148     IndexT Stop = It.stop();
149     assert(Index <= Stop && "Wrong interval for index");
150     It.erase();
151     if (Start < Index)
152       insert(Start, Index - 1);
153     if (Index < Stop)
154       insert(Index + 1, Stop);
155   }
156 
157   /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may
158   /// be a faster alternative.
159   void operator|=(const ThisT &RHS) {
160     // Get the overlaps between the two interval maps.
161     SmallVector<IntervalT, 8> Overlaps;
162     getOverlaps(RHS, Overlaps);
163 
164     // Insert the non-overlapping parts of all the intervals from RHS.
165     for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end();
166          It != End; ++It) {
167       IndexT Start = It.start();
168       IndexT Stop = It.stop();
169       SmallVector<IntervalT, 8> NonOverlappingParts;
170       getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts);
171       for (IntervalT AdditivePortion : NonOverlappingParts)
172         insert(AdditivePortion.first, AdditivePortion.second);
173     }
174   }
175 
176   /// Set intersection.
177   void operator&=(const ThisT &RHS) {
178     // Get the overlaps between the two interval maps (i.e. the intersection).
179     SmallVector<IntervalT, 8> Overlaps;
180     getOverlaps(RHS, Overlaps);
181     // Rebuild the interval map, including only the overlaps.
182     clear();
183     for (IntervalT Overlap : Overlaps)
184       insert(Overlap.first, Overlap.second);
185   }
186 
187   /// Reset all bits present in \p Other.
188   void intersectWithComplement(const ThisT &Other) {
189     SmallVector<IntervalT, 8> Overlaps;
190     if (!getOverlaps(Other, Overlaps)) {
191       // If there is no overlap with Other, the intersection is empty.
192       return;
193     }
194 
195     // Delete the overlapping intervals. Split up intervals that only partially
196     // intersect an overlap.
197     for (IntervalT Overlap : Overlaps) {
198       IndexT OlapStart, OlapStop;
199       std::tie(OlapStart, OlapStop) = Overlap;
200 
201       auto It = Intervals.find(OlapStart);
202       IndexT CurrStart = It.start();
203       IndexT CurrStop = It.stop();
204       assert(CurrStart <= OlapStart && OlapStop <= CurrStop &&
205              "Expected some intersection!");
206 
207       // Split the overlap interval into up to two parts: one from [CurrStart,
208       // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is
209       // equal to CurrStart, the first split interval is unnecessary. Ditto for
210       // when OlapStop is equal to CurrStop, we omit the second split interval.
211       It.erase();
212       if (CurrStart < OlapStart)
213         insert(CurrStart, OlapStart - 1);
214       if (OlapStop < CurrStop)
215         insert(OlapStop + 1, CurrStop);
216     }
217   }
218 
219   bool operator==(const ThisT &RHS) const {
220     // We cannot just use std::equal because it checks the dereferenced values
221     // of an iterator pair for equality, not the iterators themselves. In our
222     // case that results in comparison of the (unused) IntervalMap values.
223     auto ItL = Intervals.begin();
224     auto ItR = RHS.Intervals.begin();
225     while (ItL != Intervals.end() && ItR != RHS.Intervals.end() &&
226            ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) {
227       ++ItL;
228       ++ItR;
229     }
230     return ItL == Intervals.end() && ItR == RHS.Intervals.end();
231   }
232 
233   bool operator!=(const ThisT &RHS) const { return !operator==(RHS); }
234 
235   class const_iterator {
236     friend class CoalescingBitVector;
237 
238   public:
239     using iterator_category = std::forward_iterator_tag;
240     using value_type = IndexT;
241     using difference_type = std::ptrdiff_t;
242     using pointer = value_type *;
243     using reference = value_type &;
244 
245   private:
246     // For performance reasons, make the offset at the end different than the
247     // one used in \ref begin, to optimize the common `It == end()` pattern.
248     static constexpr unsigned kIteratorAtTheEndOffset = ~0u;
249 
250     UnderlyingIterator MapIterator;
251     unsigned OffsetIntoMapIterator = 0;
252 
253     // Querying the start/stop of an IntervalMap iterator can be very expensive.
254     // Cache these values for performance reasons.
255     IndexT CachedStart = IndexT();
256     IndexT CachedStop = IndexT();
257 
258     void setToEnd() {
259       OffsetIntoMapIterator = kIteratorAtTheEndOffset;
260       CachedStart = IndexT();
261       CachedStop = IndexT();
262     }
263 
264     /// MapIterator has just changed, reset the cached state to point to the
265     /// start of the new underlying iterator.
266     void resetCache() {
267       if (MapIterator.valid()) {
268         OffsetIntoMapIterator = 0;
269         CachedStart = MapIterator.start();
270         CachedStop = MapIterator.stop();
271       } else {
272         setToEnd();
273       }
274     }
275 
276     /// Advance the iterator to \p Index, if it is contained within the current
277     /// interval. The public-facing method which supports advancing past the
278     /// current interval is \ref advanceToLowerBound.
279     void advanceTo(IndexT Index) {
280       assert(Index <= CachedStop && "Cannot advance to OOB index");
281       if (Index < CachedStart)
282         // We're already past this index.
283         return;
284       OffsetIntoMapIterator = Index - CachedStart;
285     }
286 
287     const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) {
288       resetCache();
289     }
290 
291   public:
292     const_iterator() { setToEnd(); }
293 
294     bool operator==(const const_iterator &RHS) const {
295       // Do /not/ compare MapIterator for equality, as this is very expensive.
296       // The cached start/stop values make that check unnecessary.
297       return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) ==
298              std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart,
299                       RHS.CachedStop);
300     }
301 
302     bool operator!=(const const_iterator &RHS) const {
303       return !operator==(RHS);
304     }
305 
306     IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; }
307 
308     const_iterator &operator++() { // Pre-increment (++It).
309       if (CachedStart + OffsetIntoMapIterator < CachedStop) {
310         // Keep going within the current interval.
311         ++OffsetIntoMapIterator;
312       } else {
313         // We reached the end of the current interval: advance.
314         ++MapIterator;
315         resetCache();
316       }
317       return *this;
318     }
319 
320     const_iterator operator++(int) { // Post-increment (It++).
321       const_iterator tmp = *this;
322       operator++();
323       return tmp;
324     }
325 
326     /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If
327     /// no such set bit exists, advance to end(). This is like std::lower_bound.
328     /// This is useful if \p Index is close to the current iterator position.
329     /// However, unlike \ref find(), this has worst-case O(n) performance.
330     void advanceToLowerBound(IndexT Index) {
331       if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
332         return;
333 
334       // Advance to the first interval containing (or past) Index, or to end().
335       while (Index > CachedStop) {
336         ++MapIterator;
337         resetCache();
338         if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
339           return;
340       }
341 
342       advanceTo(Index);
343     }
344   };
345 
346   const_iterator begin() const { return const_iterator(Intervals.begin()); }
347 
348   const_iterator end() const { return const_iterator(); }
349 
350   /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index.
351   /// If no such set bit exists, return end(). This is like std::lower_bound.
352   /// This has worst-case logarithmic performance (roughly O(log(gaps between
353   /// contiguous ranges))).
354   const_iterator find(IndexT Index) const {
355     auto UnderlyingIt = Intervals.find(Index);
356     if (UnderlyingIt == Intervals.end())
357       return end();
358     auto It = const_iterator(UnderlyingIt);
359     It.advanceTo(Index);
360     return It;
361   }
362 
363   /// Return a range iterator which iterates over all of the set bits in the
364   /// half-open range [Start, End).
365   iterator_range<const_iterator> half_open_range(IndexT Start,
366                                                  IndexT End) const {
367     assert(Start < End && "Not a valid range");
368     auto StartIt = find(Start);
369     if (StartIt == end() || *StartIt >= End)
370       return {end(), end()};
371     auto EndIt = StartIt;
372     EndIt.advanceToLowerBound(End);
373     return {StartIt, EndIt};
374   }
375 
376   void print(raw_ostream &OS) const {
377     OS << "{";
378     for (auto It = Intervals.begin(), End = Intervals.end(); It != End;
379          ++It) {
380       OS << "[" << It.start();
381       if (It.start() != It.stop())
382         OS << ", " << It.stop();
383       OS << "]";
384     }
385     OS << "}";
386   }
387 
388 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
389   LLVM_DUMP_METHOD void dump() const {
390     // LLDB swallows the first line of output after callling dump(). Add
391     // newlines before/after the braces to work around this.
392     dbgs() << "\n";
393     print(dbgs());
394     dbgs() << "\n";
395   }
396 #endif
397 
398 private:
399   void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); }
400 
401   /// Record the overlaps between \p this and \p Other in \p Overlaps. Return
402   /// true if there is any overlap.
403   bool getOverlaps(const ThisT &Other,
404                    SmallVectorImpl<IntervalT> &Overlaps) const {
405     for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals);
406          I.valid(); ++I)
407       Overlaps.emplace_back(I.start(), I.stop());
408     assert(llvm::is_sorted(Overlaps,
409                            [](IntervalT LHS, IntervalT RHS) {
410                              return LHS.second < RHS.first;
411                            }) &&
412            "Overlaps must be sorted");
413     return !Overlaps.empty();
414   }
415 
416   /// Given the set of overlaps between this and some other bitvector, and an
417   /// interval [Start, Stop] from that bitvector, determine the portions of the
418   /// interval which do not overlap with this.
419   void getNonOverlappingParts(IndexT Start, IndexT Stop,
420                               const SmallVectorImpl<IntervalT> &Overlaps,
421                               SmallVectorImpl<IntervalT> &NonOverlappingParts) {
422     IndexT NextUncoveredBit = Start;
423     for (IntervalT Overlap : Overlaps) {
424       IndexT OlapStart, OlapStop;
425       std::tie(OlapStart, OlapStop) = Overlap;
426 
427       // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop
428       // and Start <= OlapStop.
429       bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop;
430       if (!DoesOverlap)
431         continue;
432 
433       // Cover the range [NextUncoveredBit, OlapStart). This puts the start of
434       // the next uncovered range at OlapStop+1.
435       if (NextUncoveredBit < OlapStart)
436         NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1);
437       NextUncoveredBit = OlapStop + 1;
438       if (NextUncoveredBit > Stop)
439         break;
440     }
441     if (NextUncoveredBit <= Stop)
442       NonOverlappingParts.emplace_back(NextUncoveredBit, Stop);
443   }
444 
445   Allocator *Alloc;
446   MapT Intervals;
447 };
448 
449 } // namespace llvm
450 
451 #endif // LLVM_ADT_COALESCINGBITVECTOR_H
452