1 //===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file A bitvector that uses an IntervalMap to coalesce adjacent elements
10 /// into intervals.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_ADT_COALESCINGBITVECTOR_H
15 #define LLVM_ADT_COALESCINGBITVECTOR_H
16 
17 #include "llvm/ADT/IntervalMap.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/iterator_range.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 #include <algorithm>
24 #include <initializer_list>
25 
26 namespace llvm {
27 
28 /// A bitvector that, under the hood, relies on an IntervalMap to coalesce
29 /// elements into intervals. Good for representing sets which predominantly
30 /// contain contiguous ranges. Bad for representing sets with lots of gaps
31 /// between elements.
32 ///
33 /// Compared to SparseBitVector, CoalescingBitVector offers more predictable
34 /// performance for non-sequential find() operations.
35 ///
36 /// \tparam IndexT - The type of the index into the bitvector.
37 template <typename IndexT> class CoalescingBitVector {
38   static_assert(std::is_unsigned<IndexT>::value,
39                 "Index must be an unsigned integer.");
40 
41   using ThisT = CoalescingBitVector<IndexT>;
42 
43   /// An interval map for closed integer ranges. The mapped values are unused.
44   using MapT = IntervalMap<IndexT, char>;
45 
46   using UnderlyingIterator = typename MapT::const_iterator;
47 
48   using IntervalT = std::pair<IndexT, IndexT>;
49 
50 public:
51   using Allocator = typename MapT::Allocator;
52 
53   /// Construct by passing in a CoalescingBitVector<IndexT>::Allocator
54   /// reference.
CoalescingBitVector(Allocator & Alloc)55   CoalescingBitVector(Allocator &Alloc)
56       : Alloc(&Alloc), Intervals(Alloc) {}
57 
58   /// \name Copy/move constructors and assignment operators.
59   /// @{
60 
CoalescingBitVector(const ThisT & Other)61   CoalescingBitVector(const ThisT &Other)
62       : Alloc(Other.Alloc), Intervals(*Other.Alloc) {
63     set(Other);
64   }
65 
66   ThisT &operator=(const ThisT &Other) {
67     clear();
68     set(Other);
69     return *this;
70   }
71 
72   CoalescingBitVector(ThisT &&Other) = delete;
73   ThisT &operator=(ThisT &&Other) = delete;
74 
75   /// @}
76 
77   /// Clear all the bits.
clear()78   void clear() { Intervals.clear(); }
79 
80   /// Check whether no bits are set.
empty()81   bool empty() const { return Intervals.empty(); }
82 
83   /// Count the number of set bits.
count()84   unsigned count() const {
85     unsigned Bits = 0;
86     for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It)
87       Bits += 1 + It.stop() - It.start();
88     return Bits;
89   }
90 
91   /// Set the bit at \p Index.
92   ///
93   /// This method does /not/ support setting a bit that has already been set,
94   /// for efficiency reasons. If possible, restructure your code to not set the
95   /// same bit multiple times, or use \ref test_and_set.
set(IndexT Index)96   void set(IndexT Index) {
97     assert(!test(Index) && "Setting already-set bits not supported/efficient, "
98                            "IntervalMap will assert");
99     insert(Index, Index);
100   }
101 
102   /// Set the bits set in \p Other.
103   ///
104   /// This method does /not/ support setting already-set bits, see \ref set
105   /// for the rationale. For a safe set union operation, use \ref operator|=.
set(const ThisT & Other)106   void set(const ThisT &Other) {
107     for (auto It = Other.Intervals.begin(), End = Other.Intervals.end();
108          It != End; ++It)
109       insert(It.start(), It.stop());
110   }
111 
112   /// Set the bits at \p Indices. Used for testing, primarily.
set(std::initializer_list<IndexT> Indices)113   void set(std::initializer_list<IndexT> Indices) {
114     for (IndexT Index : Indices)
115       set(Index);
116   }
117 
118   /// Check whether the bit at \p Index is set.
test(IndexT Index)119   bool test(IndexT Index) const {
120     const auto It = Intervals.find(Index);
121     if (It == Intervals.end())
122       return false;
123     assert(It.stop() >= Index && "Interval must end after Index");
124     return It.start() <= Index;
125   }
126 
127   /// Set the bit at \p Index. Supports setting an already-set bit.
test_and_set(IndexT Index)128   void test_and_set(IndexT Index) {
129     if (!test(Index))
130       set(Index);
131   }
132 
133   /// Reset the bit at \p Index. Supports resetting an already-unset bit.
reset(IndexT Index)134   void reset(IndexT Index) {
135     auto It = Intervals.find(Index);
136     if (It == Intervals.end())
137       return;
138 
139     // Split the interval containing Index into up to two parts: one from
140     // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to
141     // either Start or Stop, we create one new interval. If Index is equal to
142     // both Start and Stop, we simply erase the existing interval.
143     IndexT Start = It.start();
144     if (Index < Start)
145       // The index was not set.
146       return;
147     IndexT Stop = It.stop();
148     assert(Index <= Stop && "Wrong interval for index");
149     It.erase();
150     if (Start < Index)
151       insert(Start, Index - 1);
152     if (Index < Stop)
153       insert(Index + 1, Stop);
154   }
155 
156   /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may
157   /// be a faster alternative.
158   void operator|=(const ThisT &RHS) {
159     // Get the overlaps between the two interval maps.
160     SmallVector<IntervalT, 8> Overlaps;
161     getOverlaps(RHS, Overlaps);
162 
163     // Insert the non-overlapping parts of all the intervals from RHS.
164     for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end();
165          It != End; ++It) {
166       IndexT Start = It.start();
167       IndexT Stop = It.stop();
168       SmallVector<IntervalT, 8> NonOverlappingParts;
169       getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts);
170       for (IntervalT AdditivePortion : NonOverlappingParts)
171         insert(AdditivePortion.first, AdditivePortion.second);
172     }
173   }
174 
175   /// Set intersection.
176   void operator&=(const ThisT &RHS) {
177     // Get the overlaps between the two interval maps (i.e. the intersection).
178     SmallVector<IntervalT, 8> Overlaps;
179     getOverlaps(RHS, Overlaps);
180     // Rebuild the interval map, including only the overlaps.
181     clear();
182     for (IntervalT Overlap : Overlaps)
183       insert(Overlap.first, Overlap.second);
184   }
185 
186   /// Reset all bits present in \p Other.
intersectWithComplement(const ThisT & Other)187   void intersectWithComplement(const ThisT &Other) {
188     SmallVector<IntervalT, 8> Overlaps;
189     if (!getOverlaps(Other, Overlaps)) {
190       // If there is no overlap with Other, the intersection is empty.
191       return;
192     }
193 
194     // Delete the overlapping intervals. Split up intervals that only partially
195     // intersect an overlap.
196     for (IntervalT Overlap : Overlaps) {
197       IndexT OlapStart, OlapStop;
198       std::tie(OlapStart, OlapStop) = Overlap;
199 
200       auto It = Intervals.find(OlapStart);
201       IndexT CurrStart = It.start();
202       IndexT CurrStop = It.stop();
203       assert(CurrStart <= OlapStart && OlapStop <= CurrStop &&
204              "Expected some intersection!");
205 
206       // Split the overlap interval into up to two parts: one from [CurrStart,
207       // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is
208       // equal to CurrStart, the first split interval is unnecessary. Ditto for
209       // when OlapStop is equal to CurrStop, we omit the second split interval.
210       It.erase();
211       if (CurrStart < OlapStart)
212         insert(CurrStart, OlapStart - 1);
213       if (OlapStop < CurrStop)
214         insert(OlapStop + 1, CurrStop);
215     }
216   }
217 
218   bool operator==(const ThisT &RHS) const {
219     // We cannot just use std::equal because it checks the dereferenced values
220     // of an iterator pair for equality, not the iterators themselves. In our
221     // case that results in comparison of the (unused) IntervalMap values.
222     auto ItL = Intervals.begin();
223     auto ItR = RHS.Intervals.begin();
224     while (ItL != Intervals.end() && ItR != RHS.Intervals.end() &&
225            ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) {
226       ++ItL;
227       ++ItR;
228     }
229     return ItL == Intervals.end() && ItR == RHS.Intervals.end();
230   }
231 
232   bool operator!=(const ThisT &RHS) const { return !operator==(RHS); }
233 
234   class const_iterator {
235     friend class CoalescingBitVector;
236 
237   public:
238     using iterator_category = std::forward_iterator_tag;
239     using value_type = IndexT;
240     using difference_type = std::ptrdiff_t;
241     using pointer = value_type *;
242     using reference = value_type &;
243 
244   private:
245     // For performance reasons, make the offset at the end different than the
246     // one used in \ref begin, to optimize the common `It == end()` pattern.
247     static constexpr unsigned kIteratorAtTheEndOffset = ~0u;
248 
249     UnderlyingIterator MapIterator;
250     unsigned OffsetIntoMapIterator = 0;
251 
252     // Querying the start/stop of an IntervalMap iterator can be very expensive.
253     // Cache these values for performance reasons.
254     IndexT CachedStart = IndexT();
255     IndexT CachedStop = IndexT();
256 
setToEnd()257     void setToEnd() {
258       OffsetIntoMapIterator = kIteratorAtTheEndOffset;
259       CachedStart = IndexT();
260       CachedStop = IndexT();
261     }
262 
263     /// MapIterator has just changed, reset the cached state to point to the
264     /// start of the new underlying iterator.
resetCache()265     void resetCache() {
266       if (MapIterator.valid()) {
267         OffsetIntoMapIterator = 0;
268         CachedStart = MapIterator.start();
269         CachedStop = MapIterator.stop();
270       } else {
271         setToEnd();
272       }
273     }
274 
275     /// Advance the iterator to \p Index, if it is contained within the current
276     /// interval. The public-facing method which supports advancing past the
277     /// current interval is \ref advanceToLowerBound.
advanceTo(IndexT Index)278     void advanceTo(IndexT Index) {
279       assert(Index <= CachedStop && "Cannot advance to OOB index");
280       if (Index < CachedStart)
281         // We're already past this index.
282         return;
283       OffsetIntoMapIterator = Index - CachedStart;
284     }
285 
const_iterator(UnderlyingIterator MapIt)286     const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) {
287       resetCache();
288     }
289 
290   public:
const_iterator()291     const_iterator() { setToEnd(); }
292 
293     bool operator==(const const_iterator &RHS) const {
294       // Do /not/ compare MapIterator for equality, as this is very expensive.
295       // The cached start/stop values make that check unnecessary.
296       return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) ==
297              std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart,
298                       RHS.CachedStop);
299     }
300 
301     bool operator!=(const const_iterator &RHS) const {
302       return !operator==(RHS);
303     }
304 
305     IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; }
306 
307     const_iterator &operator++() { // Pre-increment (++It).
308       if (CachedStart + OffsetIntoMapIterator < CachedStop) {
309         // Keep going within the current interval.
310         ++OffsetIntoMapIterator;
311       } else {
312         // We reached the end of the current interval: advance.
313         ++MapIterator;
314         resetCache();
315       }
316       return *this;
317     }
318 
319     const_iterator operator++(int) { // Post-increment (It++).
320       const_iterator tmp = *this;
321       operator++();
322       return tmp;
323     }
324 
325     /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If
326     /// no such set bit exists, advance to end(). This is like std::lower_bound.
327     /// This is useful if \p Index is close to the current iterator position.
328     /// However, unlike \ref find(), this has worst-case O(n) performance.
advanceToLowerBound(IndexT Index)329     void advanceToLowerBound(IndexT Index) {
330       if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
331         return;
332 
333       // Advance to the first interval containing (or past) Index, or to end().
334       while (Index > CachedStop) {
335         ++MapIterator;
336         resetCache();
337         if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
338           return;
339       }
340 
341       advanceTo(Index);
342     }
343   };
344 
begin()345   const_iterator begin() const { return const_iterator(Intervals.begin()); }
346 
end()347   const_iterator end() const { return const_iterator(); }
348 
349   /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index.
350   /// If no such set bit exists, return end(). This is like std::lower_bound.
351   /// This has worst-case logarithmic performance (roughly O(log(gaps between
352   /// contiguous ranges))).
find(IndexT Index)353   const_iterator find(IndexT Index) const {
354     auto UnderlyingIt = Intervals.find(Index);
355     if (UnderlyingIt == Intervals.end())
356       return end();
357     auto It = const_iterator(UnderlyingIt);
358     It.advanceTo(Index);
359     return It;
360   }
361 
362   /// Return a range iterator which iterates over all of the set bits in the
363   /// half-open range [Start, End).
half_open_range(IndexT Start,IndexT End)364   iterator_range<const_iterator> half_open_range(IndexT Start,
365                                                  IndexT End) const {
366     assert(Start < End && "Not a valid range");
367     auto StartIt = find(Start);
368     if (StartIt == end() || *StartIt >= End)
369       return {end(), end()};
370     auto EndIt = StartIt;
371     EndIt.advanceToLowerBound(End);
372     return {StartIt, EndIt};
373   }
374 
print(raw_ostream & OS)375   void print(raw_ostream &OS) const {
376     OS << "{";
377     for (auto It = Intervals.begin(), End = Intervals.end(); It != End;
378          ++It) {
379       OS << "[" << It.start();
380       if (It.start() != It.stop())
381         OS << ", " << It.stop();
382       OS << "]";
383     }
384     OS << "}";
385   }
386 
387 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
dump()388   LLVM_DUMP_METHOD void dump() const {
389     // LLDB swallows the first line of output after callling dump(). Add
390     // newlines before/after the braces to work around this.
391     dbgs() << "\n";
392     print(dbgs());
393     dbgs() << "\n";
394   }
395 #endif
396 
397 private:
insert(IndexT Start,IndexT End)398   void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); }
399 
400   /// Record the overlaps between \p this and \p Other in \p Overlaps. Return
401   /// true if there is any overlap.
getOverlaps(const ThisT & Other,SmallVectorImpl<IntervalT> & Overlaps)402   bool getOverlaps(const ThisT &Other,
403                    SmallVectorImpl<IntervalT> &Overlaps) const {
404     for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals);
405          I.valid(); ++I)
406       Overlaps.emplace_back(I.start(), I.stop());
407     assert(llvm::is_sorted(Overlaps,
408                            [](IntervalT LHS, IntervalT RHS) {
409                              return LHS.second < RHS.first;
410                            }) &&
411            "Overlaps must be sorted");
412     return !Overlaps.empty();
413   }
414 
415   /// Given the set of overlaps between this and some other bitvector, and an
416   /// interval [Start, Stop] from that bitvector, determine the portions of the
417   /// interval which do not overlap with this.
getNonOverlappingParts(IndexT Start,IndexT Stop,const SmallVectorImpl<IntervalT> & Overlaps,SmallVectorImpl<IntervalT> & NonOverlappingParts)418   void getNonOverlappingParts(IndexT Start, IndexT Stop,
419                               const SmallVectorImpl<IntervalT> &Overlaps,
420                               SmallVectorImpl<IntervalT> &NonOverlappingParts) {
421     IndexT NextUncoveredBit = Start;
422     for (IntervalT Overlap : Overlaps) {
423       IndexT OlapStart, OlapStop;
424       std::tie(OlapStart, OlapStop) = Overlap;
425 
426       // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop
427       // and Start <= OlapStop.
428       bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop;
429       if (!DoesOverlap)
430         continue;
431 
432       // Cover the range [NextUncoveredBit, OlapStart). This puts the start of
433       // the next uncovered range at OlapStop+1.
434       if (NextUncoveredBit < OlapStart)
435         NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1);
436       NextUncoveredBit = OlapStop + 1;
437       if (NextUncoveredBit > Stop)
438         break;
439     }
440     if (NextUncoveredBit <= Stop)
441       NonOverlappingParts.emplace_back(NextUncoveredBit, Stop);
442   }
443 
444   Allocator *Alloc;
445   MapT Intervals;
446 };
447 
448 } // namespace llvm
449 
450 #endif // LLVM_ADT_COALESCINGBITVECTOR_H
451