1 //===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file A bitvector that uses an IntervalMap to coalesce adjacent elements 10 /// into intervals. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_ADT_COALESCINGBITVECTOR_H 15 #define LLVM_ADT_COALESCINGBITVECTOR_H 16 17 #include "llvm/ADT/IntervalMap.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/iterator_range.h" 20 #include "llvm/Support/Debug.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 #include <algorithm> 24 #include <initializer_list> 25 26 namespace llvm { 27 28 /// A bitvector that, under the hood, relies on an IntervalMap to coalesce 29 /// elements into intervals. Good for representing sets which predominantly 30 /// contain contiguous ranges. Bad for representing sets with lots of gaps 31 /// between elements. 32 /// 33 /// Compared to SparseBitVector, CoalescingBitVector offers more predictable 34 /// performance for non-sequential find() operations. 35 /// 36 /// \tparam IndexT - The type of the index into the bitvector. 37 template <typename IndexT> class CoalescingBitVector { 38 static_assert(std::is_unsigned<IndexT>::value, 39 "Index must be an unsigned integer."); 40 41 using ThisT = CoalescingBitVector<IndexT>; 42 43 /// An interval map for closed integer ranges. The mapped values are unused. 44 using MapT = IntervalMap<IndexT, char>; 45 46 using UnderlyingIterator = typename MapT::const_iterator; 47 48 using IntervalT = std::pair<IndexT, IndexT>; 49 50 public: 51 using Allocator = typename MapT::Allocator; 52 53 /// Construct by passing in a CoalescingBitVector<IndexT>::Allocator 54 /// reference. CoalescingBitVector(Allocator & Alloc)55 CoalescingBitVector(Allocator &Alloc) 56 : Alloc(&Alloc), Intervals(Alloc) {} 57 58 /// \name Copy/move constructors and assignment operators. 59 /// @{ 60 CoalescingBitVector(const ThisT & Other)61 CoalescingBitVector(const ThisT &Other) 62 : Alloc(Other.Alloc), Intervals(*Other.Alloc) { 63 set(Other); 64 } 65 66 ThisT &operator=(const ThisT &Other) { 67 clear(); 68 set(Other); 69 return *this; 70 } 71 72 CoalescingBitVector(ThisT &&Other) = delete; 73 ThisT &operator=(ThisT &&Other) = delete; 74 75 /// @} 76 77 /// Clear all the bits. clear()78 void clear() { Intervals.clear(); } 79 80 /// Check whether no bits are set. empty()81 bool empty() const { return Intervals.empty(); } 82 83 /// Count the number of set bits. count()84 unsigned count() const { 85 unsigned Bits = 0; 86 for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It) 87 Bits += 1 + It.stop() - It.start(); 88 return Bits; 89 } 90 91 /// Set the bit at \p Index. 92 /// 93 /// This method does /not/ support setting a bit that has already been set, 94 /// for efficiency reasons. If possible, restructure your code to not set the 95 /// same bit multiple times, or use \ref test_and_set. set(IndexT Index)96 void set(IndexT Index) { 97 assert(!test(Index) && "Setting already-set bits not supported/efficient, " 98 "IntervalMap will assert"); 99 insert(Index, Index); 100 } 101 102 /// Set the bits set in \p Other. 103 /// 104 /// This method does /not/ support setting already-set bits, see \ref set 105 /// for the rationale. For a safe set union operation, use \ref operator|=. set(const ThisT & Other)106 void set(const ThisT &Other) { 107 for (auto It = Other.Intervals.begin(), End = Other.Intervals.end(); 108 It != End; ++It) 109 insert(It.start(), It.stop()); 110 } 111 112 /// Set the bits at \p Indices. Used for testing, primarily. set(std::initializer_list<IndexT> Indices)113 void set(std::initializer_list<IndexT> Indices) { 114 for (IndexT Index : Indices) 115 set(Index); 116 } 117 118 /// Check whether the bit at \p Index is set. test(IndexT Index)119 bool test(IndexT Index) const { 120 const auto It = Intervals.find(Index); 121 if (It == Intervals.end()) 122 return false; 123 assert(It.stop() >= Index && "Interval must end after Index"); 124 return It.start() <= Index; 125 } 126 127 /// Set the bit at \p Index. Supports setting an already-set bit. test_and_set(IndexT Index)128 void test_and_set(IndexT Index) { 129 if (!test(Index)) 130 set(Index); 131 } 132 133 /// Reset the bit at \p Index. Supports resetting an already-unset bit. reset(IndexT Index)134 void reset(IndexT Index) { 135 auto It = Intervals.find(Index); 136 if (It == Intervals.end()) 137 return; 138 139 // Split the interval containing Index into up to two parts: one from 140 // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to 141 // either Start or Stop, we create one new interval. If Index is equal to 142 // both Start and Stop, we simply erase the existing interval. 143 IndexT Start = It.start(); 144 if (Index < Start) 145 // The index was not set. 146 return; 147 IndexT Stop = It.stop(); 148 assert(Index <= Stop && "Wrong interval for index"); 149 It.erase(); 150 if (Start < Index) 151 insert(Start, Index - 1); 152 if (Index < Stop) 153 insert(Index + 1, Stop); 154 } 155 156 /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may 157 /// be a faster alternative. 158 void operator|=(const ThisT &RHS) { 159 // Get the overlaps between the two interval maps. 160 SmallVector<IntervalT, 8> Overlaps; 161 getOverlaps(RHS, Overlaps); 162 163 // Insert the non-overlapping parts of all the intervals from RHS. 164 for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end(); 165 It != End; ++It) { 166 IndexT Start = It.start(); 167 IndexT Stop = It.stop(); 168 SmallVector<IntervalT, 8> NonOverlappingParts; 169 getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts); 170 for (IntervalT AdditivePortion : NonOverlappingParts) 171 insert(AdditivePortion.first, AdditivePortion.second); 172 } 173 } 174 175 /// Set intersection. 176 void operator&=(const ThisT &RHS) { 177 // Get the overlaps between the two interval maps (i.e. the intersection). 178 SmallVector<IntervalT, 8> Overlaps; 179 getOverlaps(RHS, Overlaps); 180 // Rebuild the interval map, including only the overlaps. 181 clear(); 182 for (IntervalT Overlap : Overlaps) 183 insert(Overlap.first, Overlap.second); 184 } 185 186 /// Reset all bits present in \p Other. intersectWithComplement(const ThisT & Other)187 void intersectWithComplement(const ThisT &Other) { 188 SmallVector<IntervalT, 8> Overlaps; 189 if (!getOverlaps(Other, Overlaps)) { 190 // If there is no overlap with Other, the intersection is empty. 191 return; 192 } 193 194 // Delete the overlapping intervals. Split up intervals that only partially 195 // intersect an overlap. 196 for (IntervalT Overlap : Overlaps) { 197 IndexT OlapStart, OlapStop; 198 std::tie(OlapStart, OlapStop) = Overlap; 199 200 auto It = Intervals.find(OlapStart); 201 IndexT CurrStart = It.start(); 202 IndexT CurrStop = It.stop(); 203 assert(CurrStart <= OlapStart && OlapStop <= CurrStop && 204 "Expected some intersection!"); 205 206 // Split the overlap interval into up to two parts: one from [CurrStart, 207 // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is 208 // equal to CurrStart, the first split interval is unnecessary. Ditto for 209 // when OlapStop is equal to CurrStop, we omit the second split interval. 210 It.erase(); 211 if (CurrStart < OlapStart) 212 insert(CurrStart, OlapStart - 1); 213 if (OlapStop < CurrStop) 214 insert(OlapStop + 1, CurrStop); 215 } 216 } 217 218 bool operator==(const ThisT &RHS) const { 219 // We cannot just use std::equal because it checks the dereferenced values 220 // of an iterator pair for equality, not the iterators themselves. In our 221 // case that results in comparison of the (unused) IntervalMap values. 222 auto ItL = Intervals.begin(); 223 auto ItR = RHS.Intervals.begin(); 224 while (ItL != Intervals.end() && ItR != RHS.Intervals.end() && 225 ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) { 226 ++ItL; 227 ++ItR; 228 } 229 return ItL == Intervals.end() && ItR == RHS.Intervals.end(); 230 } 231 232 bool operator!=(const ThisT &RHS) const { return !operator==(RHS); } 233 234 class const_iterator { 235 friend class CoalescingBitVector; 236 237 public: 238 using iterator_category = std::forward_iterator_tag; 239 using value_type = IndexT; 240 using difference_type = std::ptrdiff_t; 241 using pointer = value_type *; 242 using reference = value_type &; 243 244 private: 245 // For performance reasons, make the offset at the end different than the 246 // one used in \ref begin, to optimize the common `It == end()` pattern. 247 static constexpr unsigned kIteratorAtTheEndOffset = ~0u; 248 249 UnderlyingIterator MapIterator; 250 unsigned OffsetIntoMapIterator = 0; 251 252 // Querying the start/stop of an IntervalMap iterator can be very expensive. 253 // Cache these values for performance reasons. 254 IndexT CachedStart = IndexT(); 255 IndexT CachedStop = IndexT(); 256 setToEnd()257 void setToEnd() { 258 OffsetIntoMapIterator = kIteratorAtTheEndOffset; 259 CachedStart = IndexT(); 260 CachedStop = IndexT(); 261 } 262 263 /// MapIterator has just changed, reset the cached state to point to the 264 /// start of the new underlying iterator. resetCache()265 void resetCache() { 266 if (MapIterator.valid()) { 267 OffsetIntoMapIterator = 0; 268 CachedStart = MapIterator.start(); 269 CachedStop = MapIterator.stop(); 270 } else { 271 setToEnd(); 272 } 273 } 274 275 /// Advance the iterator to \p Index, if it is contained within the current 276 /// interval. The public-facing method which supports advancing past the 277 /// current interval is \ref advanceToLowerBound. advanceTo(IndexT Index)278 void advanceTo(IndexT Index) { 279 assert(Index <= CachedStop && "Cannot advance to OOB index"); 280 if (Index < CachedStart) 281 // We're already past this index. 282 return; 283 OffsetIntoMapIterator = Index - CachedStart; 284 } 285 const_iterator(UnderlyingIterator MapIt)286 const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) { 287 resetCache(); 288 } 289 290 public: const_iterator()291 const_iterator() { setToEnd(); } 292 293 bool operator==(const const_iterator &RHS) const { 294 // Do /not/ compare MapIterator for equality, as this is very expensive. 295 // The cached start/stop values make that check unnecessary. 296 return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) == 297 std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart, 298 RHS.CachedStop); 299 } 300 301 bool operator!=(const const_iterator &RHS) const { 302 return !operator==(RHS); 303 } 304 305 IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; } 306 307 const_iterator &operator++() { // Pre-increment (++It). 308 if (CachedStart + OffsetIntoMapIterator < CachedStop) { 309 // Keep going within the current interval. 310 ++OffsetIntoMapIterator; 311 } else { 312 // We reached the end of the current interval: advance. 313 ++MapIterator; 314 resetCache(); 315 } 316 return *this; 317 } 318 319 const_iterator operator++(int) { // Post-increment (It++). 320 const_iterator tmp = *this; 321 operator++(); 322 return tmp; 323 } 324 325 /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If 326 /// no such set bit exists, advance to end(). This is like std::lower_bound. 327 /// This is useful if \p Index is close to the current iterator position. 328 /// However, unlike \ref find(), this has worst-case O(n) performance. advanceToLowerBound(IndexT Index)329 void advanceToLowerBound(IndexT Index) { 330 if (OffsetIntoMapIterator == kIteratorAtTheEndOffset) 331 return; 332 333 // Advance to the first interval containing (or past) Index, or to end(). 334 while (Index > CachedStop) { 335 ++MapIterator; 336 resetCache(); 337 if (OffsetIntoMapIterator == kIteratorAtTheEndOffset) 338 return; 339 } 340 341 advanceTo(Index); 342 } 343 }; 344 begin()345 const_iterator begin() const { return const_iterator(Intervals.begin()); } 346 end()347 const_iterator end() const { return const_iterator(); } 348 349 /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index. 350 /// If no such set bit exists, return end(). This is like std::lower_bound. 351 /// This has worst-case logarithmic performance (roughly O(log(gaps between 352 /// contiguous ranges))). find(IndexT Index)353 const_iterator find(IndexT Index) const { 354 auto UnderlyingIt = Intervals.find(Index); 355 if (UnderlyingIt == Intervals.end()) 356 return end(); 357 auto It = const_iterator(UnderlyingIt); 358 It.advanceTo(Index); 359 return It; 360 } 361 362 /// Return a range iterator which iterates over all of the set bits in the 363 /// half-open range [Start, End). half_open_range(IndexT Start,IndexT End)364 iterator_range<const_iterator> half_open_range(IndexT Start, 365 IndexT End) const { 366 assert(Start < End && "Not a valid range"); 367 auto StartIt = find(Start); 368 if (StartIt == end() || *StartIt >= End) 369 return {end(), end()}; 370 auto EndIt = StartIt; 371 EndIt.advanceToLowerBound(End); 372 return {StartIt, EndIt}; 373 } 374 print(raw_ostream & OS)375 void print(raw_ostream &OS) const { 376 OS << "{"; 377 for (auto It = Intervals.begin(), End = Intervals.end(); It != End; 378 ++It) { 379 OS << "[" << It.start(); 380 if (It.start() != It.stop()) 381 OS << ", " << It.stop(); 382 OS << "]"; 383 } 384 OS << "}"; 385 } 386 387 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) dump()388 LLVM_DUMP_METHOD void dump() const { 389 // LLDB swallows the first line of output after callling dump(). Add 390 // newlines before/after the braces to work around this. 391 dbgs() << "\n"; 392 print(dbgs()); 393 dbgs() << "\n"; 394 } 395 #endif 396 397 private: insert(IndexT Start,IndexT End)398 void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); } 399 400 /// Record the overlaps between \p this and \p Other in \p Overlaps. Return 401 /// true if there is any overlap. getOverlaps(const ThisT & Other,SmallVectorImpl<IntervalT> & Overlaps)402 bool getOverlaps(const ThisT &Other, 403 SmallVectorImpl<IntervalT> &Overlaps) const { 404 for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals); 405 I.valid(); ++I) 406 Overlaps.emplace_back(I.start(), I.stop()); 407 assert(llvm::is_sorted(Overlaps, 408 [](IntervalT LHS, IntervalT RHS) { 409 return LHS.second < RHS.first; 410 }) && 411 "Overlaps must be sorted"); 412 return !Overlaps.empty(); 413 } 414 415 /// Given the set of overlaps between this and some other bitvector, and an 416 /// interval [Start, Stop] from that bitvector, determine the portions of the 417 /// interval which do not overlap with this. getNonOverlappingParts(IndexT Start,IndexT Stop,const SmallVectorImpl<IntervalT> & Overlaps,SmallVectorImpl<IntervalT> & NonOverlappingParts)418 void getNonOverlappingParts(IndexT Start, IndexT Stop, 419 const SmallVectorImpl<IntervalT> &Overlaps, 420 SmallVectorImpl<IntervalT> &NonOverlappingParts) { 421 IndexT NextUncoveredBit = Start; 422 for (IntervalT Overlap : Overlaps) { 423 IndexT OlapStart, OlapStop; 424 std::tie(OlapStart, OlapStop) = Overlap; 425 426 // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop 427 // and Start <= OlapStop. 428 bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop; 429 if (!DoesOverlap) 430 continue; 431 432 // Cover the range [NextUncoveredBit, OlapStart). This puts the start of 433 // the next uncovered range at OlapStop+1. 434 if (NextUncoveredBit < OlapStart) 435 NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1); 436 NextUncoveredBit = OlapStop + 1; 437 if (NextUncoveredBit > Stop) 438 break; 439 } 440 if (NextUncoveredBit <= Stop) 441 NonOverlappingParts.emplace_back(NextUncoveredBit, Stop); 442 } 443 444 Allocator *Alloc; 445 MapT Intervals; 446 }; 447 448 } // namespace llvm 449 450 #endif // LLVM_ADT_COALESCINGBITVECTOR_H 451