1 //===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file A bitvector that uses an IntervalMap to coalesce adjacent elements
10 /// into intervals.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_ADT_COALESCINGBITVECTOR_H
15 #define LLVM_ADT_COALESCINGBITVECTOR_H
16 
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <tuple>
#include <type_traits>
25 
26 namespace llvm {
27 
28 /// A bitvector that, under the hood, relies on an IntervalMap to coalesce
29 /// elements into intervals. Good for representing sets which predominantly
30 /// contain contiguous ranges. Bad for representing sets with lots of gaps
31 /// between elements.
32 ///
33 /// Compared to SparseBitVector, CoalescingBitVector offers more predictable
34 /// performance for non-sequential find() operations.
35 ///
36 /// \tparam IndexT - The type of the index into the bitvector.
37 template <typename IndexT> class CoalescingBitVector {
38   static_assert(std::is_unsigned<IndexT>::value,
39                 "Index must be an unsigned integer.");
40 
41   using ThisT = CoalescingBitVector<IndexT>;
42 
43   /// An interval map for closed integer ranges. The mapped values are unused.
44   using MapT = IntervalMap<IndexT, char>;
45 
46   using UnderlyingIterator = typename MapT::const_iterator;
47 
48   using IntervalT = std::pair<IndexT, IndexT>;
49 
50 public:
51   using Allocator = typename MapT::Allocator;
52 
53   /// Construct by passing in a CoalescingBitVector<IndexT>::Allocator
54   /// reference.
55   CoalescingBitVector(Allocator &Alloc)
56       : Alloc(&Alloc), Intervals(Alloc) {}
57 
58   /// \name Copy/move constructors and assignment operators.
59   /// @{
60 
61   CoalescingBitVector(const ThisT &Other)
62       : Alloc(Other.Alloc), Intervals(*Other.Alloc) {
63     set(Other);
64   }
65 
66   ThisT &operator=(const ThisT &Other) {
67     clear();
68     set(Other);
69     return *this;
70   }
71 
72   CoalescingBitVector(ThisT &&Other) = delete;
73   ThisT &operator=(ThisT &&Other) = delete;
74 
75   /// @}
76 
77   /// Clear all the bits.
78   void clear() { Intervals.clear(); }
79 
80   /// Check whether no bits are set.
81   bool empty() const { return Intervals.empty(); }
82 
83   /// Count the number of set bits.
84   unsigned count() const {
85     unsigned Bits = 0;
86     for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It)
87       Bits += 1 + It.stop() - It.start();
88     return Bits;
89   }
90 
91   /// Set the bit at \p Index.
92   ///
93   /// This method does /not/ support setting a bit that has already been set,
94   /// for efficiency reasons. If possible, restructure your code to not set the
95   /// same bit multiple times, or use \ref test_and_set.
96   void set(IndexT Index) {
97     assert(!test(Index) && "Setting already-set bits not supported/efficient, "
98                            "IntervalMap will assert");
99     insert(Index, Index);
100   }
101 
102   /// Set the bits set in \p Other.
103   ///
104   /// This method does /not/ support setting already-set bits, see \ref set
105   /// for the rationale. For a safe set union operation, use \ref operator|=.
106   void set(const ThisT &Other) {
107     for (auto It = Other.Intervals.begin(), End = Other.Intervals.end();
108          It != End; ++It)
109       insert(It.start(), It.stop());
110   }
111 
112   /// Set the bits at \p Indices. Used for testing, primarily.
113   void set(std::initializer_list<IndexT> Indices) {
114     for (IndexT Index : Indices)
115       set(Index);
116   }
117 
118   /// Check whether the bit at \p Index is set.
119   bool test(IndexT Index) const {
120     const auto It = Intervals.find(Index);
121     if (It == Intervals.end())
122       return false;
123     assert(It.stop() >= Index && "Interval must end after Index");
124     return It.start() <= Index;
125   }
126 
127   /// Set the bit at \p Index. Supports setting an already-set bit.
128   void test_and_set(IndexT Index) {
129     if (!test(Index))
130       set(Index);
131   }
132 
133   /// Reset the bit at \p Index. Supports resetting an already-unset bit.
134   void reset(IndexT Index) {
135     auto It = Intervals.find(Index);
136     if (It == Intervals.end())
137       return;
138 
139     // Split the interval containing Index into up to two parts: one from
140     // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to
141     // either Start or Stop, we create one new interval. If Index is equal to
142     // both Start and Stop, we simply erase the existing interval.
143     IndexT Start = It.start();
144     if (Index < Start)
145       // The index was not set.
146       return;
147     IndexT Stop = It.stop();
148     assert(Index <= Stop && "Wrong interval for index");
149     It.erase();
150     if (Start < Index)
151       insert(Start, Index - 1);
152     if (Index < Stop)
153       insert(Index + 1, Stop);
154   }
155 
156   /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may
157   /// be a faster alternative.
158   void operator|=(const ThisT &RHS) {
159     // Get the overlaps between the two interval maps.
160     SmallVector<IntervalT, 8> Overlaps;
161     getOverlaps(RHS, Overlaps);
162 
163     // Insert the non-overlapping parts of all the intervals from RHS.
164     for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end();
165          It != End; ++It) {
166       IndexT Start = It.start();
167       IndexT Stop = It.stop();
168       SmallVector<IntervalT, 8> NonOverlappingParts;
169       getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts);
170       for (IntervalT AdditivePortion : NonOverlappingParts)
171         insert(AdditivePortion.first, AdditivePortion.second);
172     }
173   }
174 
175   /// Set intersection.
176   void operator&=(const ThisT &RHS) {
177     // Get the overlaps between the two interval maps (i.e. the intersection).
178     SmallVector<IntervalT, 8> Overlaps;
179     getOverlaps(RHS, Overlaps);
180     // Rebuild the interval map, including only the overlaps.
181     clear();
182     for (IntervalT Overlap : Overlaps)
183       insert(Overlap.first, Overlap.second);
184   }
185 
186   /// Reset all bits present in \p Other.
187   void intersectWithComplement(const ThisT &Other) {
188     SmallVector<IntervalT, 8> Overlaps;
189     if (!getOverlaps(Other, Overlaps)) {
190       // If there is no overlap with Other, the intersection is empty.
191       return;
192     }
193 
194     // Delete the overlapping intervals. Split up intervals that only partially
195     // intersect an overlap.
196     for (IntervalT Overlap : Overlaps) {
197       IndexT OlapStart, OlapStop;
198       std::tie(OlapStart, OlapStop) = Overlap;
199 
200       auto It = Intervals.find(OlapStart);
201       IndexT CurrStart = It.start();
202       IndexT CurrStop = It.stop();
203       assert(CurrStart <= OlapStart && OlapStop <= CurrStop &&
204              "Expected some intersection!");
205 
206       // Split the overlap interval into up to two parts: one from [CurrStart,
207       // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is
208       // equal to CurrStart, the first split interval is unnecessary. Ditto for
209       // when OlapStop is equal to CurrStop, we omit the second split interval.
210       It.erase();
211       if (CurrStart < OlapStart)
212         insert(CurrStart, OlapStart - 1);
213       if (OlapStop < CurrStop)
214         insert(OlapStop + 1, CurrStop);
215     }
216   }
217 
218   bool operator==(const ThisT &RHS) const {
219     // We cannot just use std::equal because it checks the dereferenced values
220     // of an iterator pair for equality, not the iterators themselves. In our
221     // case that results in comparison of the (unused) IntervalMap values.
222     auto ItL = Intervals.begin();
223     auto ItR = RHS.Intervals.begin();
224     while (ItL != Intervals.end() && ItR != RHS.Intervals.end() &&
225            ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) {
226       ++ItL;
227       ++ItR;
228     }
229     return ItL == Intervals.end() && ItR == RHS.Intervals.end();
230   }
231 
232   bool operator!=(const ThisT &RHS) const { return !operator==(RHS); }
233 
234   class const_iterator
235       : public std::iterator<std::forward_iterator_tag, IndexT> {
236     friend class CoalescingBitVector;
237 
238     // For performance reasons, make the offset at the end different than the
239     // one used in \ref begin, to optimize the common `It == end()` pattern.
240     static constexpr unsigned kIteratorAtTheEndOffset = ~0u;
241 
242     UnderlyingIterator MapIterator;
243     unsigned OffsetIntoMapIterator = 0;
244 
245     // Querying the start/stop of an IntervalMap iterator can be very expensive.
246     // Cache these values for performance reasons.
247     IndexT CachedStart = IndexT();
248     IndexT CachedStop = IndexT();
249 
250     void setToEnd() {
251       OffsetIntoMapIterator = kIteratorAtTheEndOffset;
252       CachedStart = IndexT();
253       CachedStop = IndexT();
254     }
255 
256     /// MapIterator has just changed, reset the cached state to point to the
257     /// start of the new underlying iterator.
258     void resetCache() {
259       if (MapIterator.valid()) {
260         OffsetIntoMapIterator = 0;
261         CachedStart = MapIterator.start();
262         CachedStop = MapIterator.stop();
263       } else {
264         setToEnd();
265       }
266     }
267 
268     /// Advance the iterator to \p Index, if it is contained within the current
269     /// interval. The public-facing method which supports advancing past the
270     /// current interval is \ref advanceToLowerBound.
271     void advanceTo(IndexT Index) {
272       assert(Index <= CachedStop && "Cannot advance to OOB index");
273       if (Index < CachedStart)
274         // We're already past this index.
275         return;
276       OffsetIntoMapIterator = Index - CachedStart;
277     }
278 
279     const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) {
280       resetCache();
281     }
282 
283   public:
284     const_iterator() { setToEnd(); }
285 
286     bool operator==(const const_iterator &RHS) const {
287       // Do /not/ compare MapIterator for equality, as this is very expensive.
288       // The cached start/stop values make that check unnecessary.
289       return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) ==
290              std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart,
291                       RHS.CachedStop);
292     }
293 
294     bool operator!=(const const_iterator &RHS) const {
295       return !operator==(RHS);
296     }
297 
298     IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; }
299 
300     const_iterator &operator++() { // Pre-increment (++It).
301       if (CachedStart + OffsetIntoMapIterator < CachedStop) {
302         // Keep going within the current interval.
303         ++OffsetIntoMapIterator;
304       } else {
305         // We reached the end of the current interval: advance.
306         ++MapIterator;
307         resetCache();
308       }
309       return *this;
310     }
311 
312     const_iterator operator++(int) { // Post-increment (It++).
313       const_iterator tmp = *this;
314       operator++();
315       return tmp;
316     }
317 
318     /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If
319     /// no such set bit exists, advance to end(). This is like std::lower_bound.
320     /// This is useful if \p Index is close to the current iterator position.
321     /// However, unlike \ref find(), this has worst-case O(n) performance.
322     void advanceToLowerBound(IndexT Index) {
323       if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
324         return;
325 
326       // Advance to the first interval containing (or past) Index, or to end().
327       while (Index > CachedStop) {
328         ++MapIterator;
329         resetCache();
330         if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
331           return;
332       }
333 
334       advanceTo(Index);
335     }
336   };
337 
338   const_iterator begin() const { return const_iterator(Intervals.begin()); }
339 
340   const_iterator end() const { return const_iterator(); }
341 
342   /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index.
343   /// If no such set bit exists, return end(). This is like std::lower_bound.
344   /// This has worst-case logarithmic performance (roughly O(log(gaps between
345   /// contiguous ranges))).
346   const_iterator find(IndexT Index) const {
347     auto UnderlyingIt = Intervals.find(Index);
348     if (UnderlyingIt == Intervals.end())
349       return end();
350     auto It = const_iterator(UnderlyingIt);
351     It.advanceTo(Index);
352     return It;
353   }
354 
355   /// Return a range iterator which iterates over all of the set bits in the
356   /// half-open range [Start, End).
357   iterator_range<const_iterator> half_open_range(IndexT Start,
358                                                  IndexT End) const {
359     assert(Start < End && "Not a valid range");
360     auto StartIt = find(Start);
361     if (StartIt == end() || *StartIt >= End)
362       return {end(), end()};
363     auto EndIt = StartIt;
364     EndIt.advanceToLowerBound(End);
365     return {StartIt, EndIt};
366   }
367 
368   void print(raw_ostream &OS) const {
369     OS << "{";
370     for (auto It = Intervals.begin(), End = Intervals.end(); It != End;
371          ++It) {
372       OS << "[" << It.start();
373       if (It.start() != It.stop())
374         OS << ", " << It.stop();
375       OS << "]";
376     }
377     OS << "}";
378   }
379 
380 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
381   LLVM_DUMP_METHOD void dump() const {
382     // LLDB swallows the first line of output after callling dump(). Add
383     // newlines before/after the braces to work around this.
384     dbgs() << "\n";
385     print(dbgs());
386     dbgs() << "\n";
387   }
388 #endif
389 
390 private:
391   void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); }
392 
393   /// Record the overlaps between \p this and \p Other in \p Overlaps. Return
394   /// true if there is any overlap.
395   bool getOverlaps(const ThisT &Other,
396                    SmallVectorImpl<IntervalT> &Overlaps) const {
397     for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals);
398          I.valid(); ++I)
399       Overlaps.emplace_back(I.start(), I.stop());
400     assert(llvm::is_sorted(Overlaps,
401                            [](IntervalT LHS, IntervalT RHS) {
402                              return LHS.second < RHS.first;
403                            }) &&
404            "Overlaps must be sorted");
405     return !Overlaps.empty();
406   }
407 
408   /// Given the set of overlaps between this and some other bitvector, and an
409   /// interval [Start, Stop] from that bitvector, determine the portions of the
410   /// interval which do not overlap with this.
411   void getNonOverlappingParts(IndexT Start, IndexT Stop,
412                               const SmallVectorImpl<IntervalT> &Overlaps,
413                               SmallVectorImpl<IntervalT> &NonOverlappingParts) {
414     IndexT NextUncoveredBit = Start;
415     for (IntervalT Overlap : Overlaps) {
416       IndexT OlapStart, OlapStop;
417       std::tie(OlapStart, OlapStop) = Overlap;
418 
419       // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop
420       // and Start <= OlapStop.
421       bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop;
422       if (!DoesOverlap)
423         continue;
424 
425       // Cover the range [NextUncoveredBit, OlapStart). This puts the start of
426       // the next uncovered range at OlapStop+1.
427       if (NextUncoveredBit < OlapStart)
428         NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1);
429       NextUncoveredBit = OlapStop + 1;
430       if (NextUncoveredBit > Stop)
431         break;
432     }
433     if (NextUncoveredBit <= Stop)
434       NonOverlappingParts.emplace_back(NextUncoveredBit, Stop);
435   }
436 
437   Allocator *Alloc;
438   MapT Intervals;
439 };
440 
441 } // namespace llvm
442 
443 #endif // LLVM_ADT_COALESCINGBITVECTOR_H
444