"""
An implementation of an object that acts like a collection of on/off bits.
"""

import operator
from array import array
from bisect import bisect_left, bisect_right, insort

from whoosh.compat import integer_types, izip, izip_longest, next, xrange
from whoosh.util.numeric import bytes_for_bits


# Number of '1' bits in each byte value (0-255), used to count the members of
# a bit array one byte at a time.
_1SPERBYTE = array('B', (bin(byteval).count('1') for byteval in range(256)))


class DocIdSet(object):
    """Base class for a set of positive integers, implementing a subset of the
    built-in ``set`` type's interface with extra docid-related methods.

    This is a superclass for alternative set implementations to the built-in
    ``set`` which are more memory-efficient and specialized toward storing
    sorted lists of positive integers, though they will inevitably be slower
    than ``set`` for most operations since they're pure Python.

    Subclasses must implement ``__len__``, ``__iter__`` (expected to yield
    members in ascending order), ``__contains__``, ``copy``, ``add`` and
    ``discard``; the remaining operations are built on top of those.
    """

    def __eq__(self, other):
        """Returns True if ``self`` and ``other`` yield exactly the same
        numbers in the same order. ``other`` is expected to iterate in
        ascending order, as DocIdSets do; an unordered iterable may compare
        unequal even when it has the same members.
        """
        # A unique sentinel marks the shorter side running out, so a set that
        # is merely a prefix of the other does not compare equal.
        missing = object()
        for a, b in izip_longest(self, other, fillvalue=missing):
            if a is missing or b is missing or a != b:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    # Kept for backwards compatibility with the original (misspelled) name.
    __neq__ = __ne__

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError

    def __contains__(self, i):
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def __and__(self, other):
        return self.intersection(other)

    def __sub__(self, other):
        return self.difference(other)

    def copy(self):
        """Returns an independent copy of this set."""
        raise NotImplementedError

    def add(self, n):
        """Adds ``n`` to the set."""
        raise NotImplementedError

    def discard(self, n):
        """Removes ``n`` from the set if it is present."""
        raise NotImplementedError

    def update(self, other):
        """Adds every number in the iterable ``other`` to this set."""
        add = self.add
        for i in other:
            add(i)

    def intersection_update(self, other):
        """Removes every number from this set that is not in ``other``."""
        # Iterate over a snapshot: discarding from the live set while
        # iterating it can skip members (e.g. when the storage is an array
        # that shifts on removal).
        for n in list(self):
            if n not in other:
                self.discard(n)

    def difference_update(self, other):
        """Removes every number in the iterable ``other`` from this set."""
        discard = self.discard
        for n in other:
            discard(n)

    def invert_update(self, size):
        """Updates the set in-place so that each number in the range
        ``[0, size)`` is present if and only if it was absent before. Numbers
        greater than or equal to ``size`` are left unchanged.
        """

        for i in xrange(size):
            if i in self:
                self.discard(i)
            else:
                self.add(i)

    def intersection(self, other):
        """Returns a new set with the numbers present in both sets."""
        c = self.copy()
        c.intersection_update(other)
        return c

    def union(self, other):
        """Returns a new set with the numbers present in either set."""
        c = self.copy()
        c.update(other)
        return c

    def difference(self, other):
        """Returns a new set with this set's numbers that are not in
        ``other``."""
        c = self.copy()
        c.difference_update(other)
        return c

    def invert(self, size):
        """Returns a new set; see :meth:`invert_update`."""
        c = self.copy()
        c.invert_update(size)
        return c

    def isdisjoint(self, other):
        """Returns True if the two sets have no numbers in common."""
        # Iterate the smaller set and probe the larger one
        a = self
        b = other
        if len(other) < len(self):
            a, b = other, self
        for num in a:
            if num in b:
                return False
        return True

    def before(self, i):
        """Returns the previous integer in the set before ``i``, or None.
        """
        raise NotImplementedError

    def after(self, i):
        """Returns the next integer in the set after ``i``, or None.
        """
        raise NotImplementedError

    def first(self):
        """Returns the first (lowest) integer in the set, or None if the set
        is empty.
        """
        raise NotImplementedError

    def last(self):
        """Returns the last (highest) integer in the set, or None if the set
        is empty.
        """
        raise NotImplementedError


class BaseBitSet(DocIdSet):
    """Base class for DocIdSets stored as an array of bytes, in which bit
    ``k`` (least significant first) of byte ``b`` represents the number
    ``b * 8 + k``.

    Subclasses provide the storage by overriding :meth:`byte_count`,
    :meth:`_get_byte` and :meth:`_iter_bytes`.
    """

    # Methods to override

    def byte_count(self):
        """Returns the number of bytes in the underlying bit array."""
        raise NotImplementedError

    def _get_byte(self, i):
        """Returns the byte at index ``i`` of the bit array as an integer."""
        raise NotImplementedError

    def _iter_bytes(self):
        """Yields every byte of the bit array, in order, as integers."""
        raise NotImplementedError

    # Base implementations

    def __len__(self):
        return sum(_1SPERBYTE[b] for b in self._iter_bytes())

    def __iter__(self):
        base = 0
        for byte in self._iter_bytes():
            # Empty bytes are skipped without testing their 8 bits
            if byte:
                for i in xrange(8):
                    if byte & (1 << i):
                        yield base + i
            base += 8

    def __nonzero__(self):
        return any(self._iter_bytes())

    __bool__ = __nonzero__

    def __contains__(self, i):
        # Negative numbers are never members; without this check a negative
        # bucket index would address the wrong byte.
        if i < 0:
            return False
        bucket = i // 8
        if bucket >= self.byte_count():
            return False
        return bool(self._get_byte(bucket) & (1 << (i & 7)))

    def first(self):
        return self.after(-1)

    def last(self):
        return self.before(self.byte_count() * 8 + 1)

    def before(self, i):
        _get_byte = self._get_byte
        size = self.byte_count() * 8

        if i <= 0:
            return None
        elif i >= size:
            i = size - 1
        else:
            i -= 1
        bucket = i // 8

        # Scan downwards, jumping over whole zero bytes at once
        while i >= 0:
            byte = _get_byte(bucket)
            if not byte:
                bucket -= 1
                i = bucket * 8 + 7
                continue
            if byte & (1 << (i & 7)):
                return i
            if i % 8 == 0:
                bucket -= 1
            i -= 1

        return None

    def after(self, i):
        _get_byte = self._get_byte
        size = self.byte_count() * 8

        if i >= size:
            return None
        elif i < 0:
            i = 0
        else:
            i += 1
        bucket = i // 8

        # Scan upwards, jumping over whole zero bytes at once
        while i < size:
            byte = _get_byte(bucket)
            if not byte:
                bucket += 1
                i = bucket * 8
                continue
            if byte & (1 << (i & 7)):
                return i
            i += 1
            if i % 8 == 0:
                bucket += 1

        return None


class OnDiskBitSet(BaseBitSet):
    """A DocIdSet backed by an array of bits on disk.

    >>> st = RamStorage()
    >>> f = st.create_file("test.bin")
    >>> bs = BitSet([1, 10, 15, 7, 2])
    >>> bytecount = bs.to_disk(f)
    >>> f.close()
    >>> # ...
    >>> f = st.open_file("test.bin")
    >>> odbs = OnDiskBitSet(f, 0, bytecount)
    >>> list(odbs)
    [1, 2, 7, 10, 15]
    """

    def __init__(self, dbfile, basepos, bytecount):
        """
        :param dbfile: a :class:`~whoosh.filedb.structfile.StructFile` object
            to read from.
        :param basepos: the base position of the bytes in the given file.
        :param bytecount: the number of bytes to use for the bit array.
        """

        self._dbfile = dbfile
        self._basepos = basepos
        self._bytecount = bytecount

    def __repr__(self):
        return "%s(%s, %d, %d)" % (self.__class__.__name__, self._dbfile,
                                   self._basepos, self._bytecount)

    def byte_count(self):
        return self._bytecount

    def _get_byte(self, n):
        return self._dbfile.get_byte(self._basepos + n)

    def _iter_bytes(self):
        dbfile = self._dbfile
        dbfile.seek(self._basepos)
        for _ in xrange(self._bytecount):
            yield dbfile.read_byte()


class BitSet(BaseBitSet):
    """A DocIdSet backed by an array of bits. This can also be useful as a bit
    array (e.g. for a Bloom filter). It is much more memory efficient than a
    large built-in set of integers, but wastes memory for sparse sets.
    """

    def __init__(self, source=None, size=0):
        """
        :param source: an iterable of non-negative integers to add to this
            set.
        :param size: the number of bits to pre-allocate. When it is not given
            and ``source`` is a list, tuple, set or frozenset, it is computed
            from the largest number in ``source``. The array still grows as
            needed when larger numbers are added.
        """

        # If the source is a non-empty sized collection we can guess the size
        # up front (max() would fail on an empty one)
        if (not size and source
                and isinstance(source, (list, tuple, set, frozenset))):
            size = max(source) + 1
        self.bits = array("B", [0]) * bytes_for_bits(size)

        if source:
            add = self.add
            for num in source:
                add(num)

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, list(self))

    def byte_count(self):
        return len(self.bits)

    def _get_byte(self, n):
        return self.bits[n]

    def _iter_bytes(self):
        return iter(self.bits)

    def _trim(self):
        """Removes trailing zero bytes from the bit array."""
        bits = self.bits
        last = len(bits) - 1
        while last >= 0 and not bits[last]:
            last -= 1
        del bits[last + 1:]

    def _resize(self, tosize):
        """Grows or shrinks the bit array to exactly the number of bytes
        needed for ``tosize`` bits."""
        curlength = len(self.bits)
        newlength = bytes_for_bits(tosize)
        if newlength > curlength:
            self.bits.extend((0,) * (newlength - curlength))
        elif newlength < curlength:
            # Keep exactly ``newlength`` bytes
            del self.bits[newlength:]

    def _zero_extra_bits(self, size):
        """Clears every bit at a position greater than or equal to ``size``
        without changing the length of the array."""
        bits = self.bits
        full, spill = divmod(max(size, 0), 8)
        if spill and full < len(bits):
            # Keep only the low ``spill`` bits of the byte containing ``size``
            bits[full] &= (1 << spill) - 1
            full += 1
        for i in xrange(full, len(bits)):
            bits[i] = 0

    def _logic(self, obj, op, other):
        """Combines ``obj.bits`` with ``other.bits`` byte by byte using the
        binary function ``op`` and stores the result in ``obj``, which is
        trimmed and returned. The shorter array is treated as zero-padded.
        """
        objbits = obj.bits
        otherbits = other.bits
        otherlen = len(otherbits)
        # Grow obj up front rather than appending to it while iterating
        if len(objbits) < otherlen:
            objbits.extend((0,) * (otherlen - len(objbits)))
        for i in xrange(len(objbits)):
            byte2 = otherbits[i] if i < otherlen else 0
            objbits[i] = op(objbits[i], byte2) & 0xFF

        obj._trim()
        return obj

    def to_disk(self, dbfile):
        """Writes the bit array to ``dbfile`` and returns the number of bytes
        written."""
        dbfile.write_array(self.bits)
        return len(self.bits)

    @classmethod
    def from_bytes(cls, bs):
        b = cls()
        b.bits = array("B", bs)
        return b

    @classmethod
    def from_disk(cls, dbfile, bytecount):
        return cls.from_bytes(dbfile.read_array("B", bytecount))

    def copy(self):
        b = self.__class__()
        b.bits = array("B", self.bits)
        return b

    def clear(self):
        """Clears every bit in place, keeping the array's length."""
        bits = self.bits
        for i in xrange(len(bits)):
            bits[i] = 0

    def add(self, i):
        """Adds ``i`` to the set, growing the bit array if necessary.

        :raises ValueError: if ``i`` is negative.
        """
        if i < 0:
            raise ValueError("BitSet cannot hold negative number %r" % (i,))
        bucket = i >> 3
        if bucket >= len(self.bits):
            self._resize(i + 1)
        self.bits[bucket] |= 1 << (i & 7)

    def discard(self, i):
        bucket = i >> 3
        # Numbers outside the array (including negative ones) are not
        # members, so there is nothing to clear
        if 0 <= bucket < len(self.bits):
            self.bits[bucket] &= ~(1 << (i & 7))

    def _resize_to_other(self, other):
        """Pre-grows the array to hold the largest number of ``other`` when
        that is cheap to find (``other`` is a non-empty sized collection)."""
        if other and isinstance(other, (list, tuple, set, frozenset)):
            maxbit = max(other)
            if bytes_for_bits(maxbit + 1) > len(self.bits):
                self._resize(maxbit + 1)

    def update(self, iterable):
        self._resize_to_other(iterable)
        DocIdSet.update(self, iterable)

    def intersection_update(self, other):
        if isinstance(other, BitSet):
            return self._logic(self, operator.__and__, other)
        discard = self.discard
        for n in list(self):
            if n not in other:
                discard(n)

    def difference_update(self, other):
        if isinstance(other, BitSet):
            return self._logic(self, lambda x, y: x & ~y, other)
        discard = self.discard
        for n in other:
            discard(n)

    def invert_update(self, size):
        """Updates the set in-place so that each number in the range
        ``[0, size)`` is present if and only if it was absent before. Numbers
        greater than or equal to ``size`` are left unchanged.
        """
        if size <= 0:
            return
        bits = self.bits
        full, spill = divmod(size, 8)
        needed = full + (1 if spill else 0)
        # Positions past the current array are absent, so they must exist
        # before they can be flipped on
        if len(bits) < needed:
            bits.extend((0,) * (needed - len(bits)))
        for i in xrange(full):
            bits[i] ^= 0xFF
        if spill:
            # Only flip the low ``spill`` bits of the partial last byte
            bits[full] ^= (1 << spill) - 1

    def union(self, other):
        if isinstance(other, BitSet):
            return self._logic(self.copy(), operator.__or__, other)
        b = self.copy()
        b.update(other)
        return b

    def intersection(self, other):
        if isinstance(other, BitSet):
            return self._logic(self.copy(), operator.__and__, other)
        return BitSet(source=(n for n in self if n in other))

    def difference(self, other):
        if isinstance(other, BitSet):
            return self._logic(self.copy(), lambda x, y: x & ~y, other)
        return BitSet(source=(n for n in self if n not in other))


class SortedIntSet(DocIdSet):
    """A DocIdSet backed by a sorted array of integers.
    """

    def __init__(self, source=None, typecode="I"):
        """
        :param source: an iterable of integers to add to the set. Duplicate
            numbers are stored only once.
        :param typecode: the :mod:`array` typecode used to store the numbers.
        """
        if source:
            # Deduplicate so each number is stored once, like a set
            self.data = array(typecode, sorted(set(source)))
        else:
            self.data = array(typecode)
        self.typecode = typecode

    def copy(self):
        sis = SortedIntSet(typecode=self.typecode)
        sis.data = array(self.typecode, self.data)
        return sis

    def size(self):
        """Returns the number of bytes used by the underlying array."""
        return len(self.data) * self.data.itemsize

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.data)

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data)

    def __nonzero__(self):
        return bool(self.data)

    __bool__ = __nonzero__

    def __contains__(self, i):
        data = self.data
        if not data or i < data[0] or i > data[-1]:
            return False
        # data[0] <= i <= data[-1], so the position is always in range
        return data[bisect_left(data, i)] == i

    def add(self, i):
        data = self.data
        if not data or i > data[-1]:
            # Fast path: appending in ascending order
            data.append(i)
        else:
            # i <= data[-1], so pos is always a valid index
            pos = bisect_left(data, i)
            if data[pos] != i:
                data.insert(pos, i)

    def discard(self, i):
        data = self.data
        pos = bisect_left(data, i)
        if pos < len(data) and data[pos] == i:
            data.pop(pos)

    def clear(self):
        self.data = array(self.typecode)

    def intersection_update(self, other):
        self.data = array(self.typecode, (num for num in self if num in other))

    def difference_update(self, other):
        self.data = array(self.typecode,
                          (num for num in self if num not in other))

    def intersection(self, other):
        return SortedIntSet((num for num in self if num in other),
                            typecode=self.typecode)

    def difference(self, other):
        return SortedIntSet((num for num in self if num not in other),
                            typecode=self.typecode)

    def first(self):
        data = self.data
        return data[0] if data else None

    def last(self):
        data = self.data
        return data[-1] if data else None

    def before(self, i):
        data = self.data
        pos = bisect_left(data, i)
        if pos < 1:
            return None
        else:
            return data[pos - 1]

    def after(self, i):
        data = self.data
        if not data or i >= data[-1]:
            return None
        elif i < data[0]:
            return data[0]

        pos = bisect_right(data, i)
        return data[pos]


class ReverseIdSet(DocIdSet):
    """
    Wraps a DocIdSet object and reverses its semantics, so docs in the wrapped
    set are not in this set, and vice-versa.
    """

    def __init__(self, idset, limit):
        """
        :param idset: the DocIdSet object to wrap.
        :param limit: the highest possible ID plus one.
        """

        self.idset = idset
        self.limit = limit

    def __len__(self):
        # NOTE(review): assumes every ID in the wrapped set is below ``limit``
        return self.limit - len(self.idset)

    def __contains__(self, i):
        return i not in self.idset

    def __iter__(self):
        # Walk [0, limit) alongside the wrapped set's ascending IDs, yielding
        # only numbers it does not contain. -1 stands for "no more wrapped
        # IDs" because it never equals a number in the range.
        ids = iter(self.idset)
        try:
            nx = next(ids)
        except StopIteration:
            nx = -1

        for i in xrange(self.limit):
            if i == nx:
                try:
                    nx = next(ids)
                except StopIteration:
                    nx = -1
            else:
                yield i

    def add(self, n):
        self.idset.discard(n)

    def discard(self, n):
        self.idset.add(n)

    def first(self):
        for i in self:
            return i
        return None

    def last(self):
        idset = self.idset
        maxid = self.limit - 1
        if maxid < 0:
            return None
        # If the wrapped set is empty or ends below maxid, maxid itself is
        # not in the wrapped set and is therefore the answer
        idlast = idset.last()
        if idlast is None or idlast < maxid:
            return maxid

        for i in xrange(maxid, -1, -1):
            if i not in idset:
                return i
        return None


ROARING_CUTOFF = 1 << 12


class RoaringIdSet(DocIdSet):
    """
    Separates IDs into ranges of 2^16 bits, and stores each range in the most
    efficient type of doc set, either a BitSet (if the range has more than
    2^12 IDs) or a sorted ID set of 16-bit shorts.
    """

    cutoff = ROARING_CUTOFF

    def __init__(self, source=None):
        # idsets[b] holds the numbers in [b << 16, (b + 1) << 16), stored as
        # offsets from b << 16
        self.idsets = []
        if source:
            self.update(source)

    def __len__(self):
        return sum(len(idset) for idset in self.idsets)

    def __contains__(self, n):
        bucket = n >> 16
        if n < 0 or bucket >= len(self.idsets):
            return False
        return (n - (bucket << 16)) in self.idsets[bucket]

    def __iter__(self):
        for bucket, idset in enumerate(self.idsets):
            floor = bucket << 16
            for n in idset:
                yield floor + n

    def copy(self):
        ris = self.__class__()
        ris.idsets = [idset.copy() for idset in self.idsets]
        return ris

    def _find(self, n):
        """Returns ``(bucket, floor, idset)`` for ``n``, creating empty
        buckets up to and including ``n``'s bucket if needed."""
        bucket = n >> 16
        floor = bucket << 16
        idsets = self.idsets
        if bucket >= len(idsets):
            idsets.extend([SortedIntSet(typecode="H") for _
                           in xrange(len(idsets), bucket + 1)])
        return bucket, floor, idsets[bucket]

    def add(self, n):
        """Adds ``n``, switching its bucket to a BitSet once the bucket grows
        past the cutoff.

        :raises ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError("RoaringIdSet cannot hold negative number %r"
                             % (n,))
        bucket, floor, idset = self._find(n)
        oldlen = len(idset)
        idset.add(n - floor)
        if oldlen <= self.cutoff < len(idset):
            self.idsets[bucket] = BitSet(idset)

    def discard(self, n):
        """Removes ``n`` if present, switching its bucket back to a sorted
        set once it shrinks to the cutoff."""
        bucket = n >> 16
        # Numbers past the last bucket are not members; don't allocate
        # buckets just to discard from them
        if n < 0 or bucket >= len(self.idsets):
            return
        floor = bucket << 16
        idset = self.idsets[bucket]
        oldlen = len(idset)
        idset.discard(n - floor)
        if oldlen > self.cutoff >= len(idset):
            self.idsets[bucket] = SortedIntSet(idset, typecode="H")


class MultiIdSet(DocIdSet):
    """Wraps multiple SERIAL sub-DocIdSet objects and presents them as an
    aggregated, read-only set.
    """

    def __init__(self, idsets, offsets):
        """
        :param idsets: a list of DocIdSet objects.
        :param offsets: a list of offsets corresponding to the DocIdSet objects
            in ``idsets``, in ascending order.
        """

        assert len(idsets) == len(offsets)
        self.idsets = idsets
        self.offsets = offsets

    def _document_set(self, n):
        """Returns the index of the sub-set covering ``n``: the last one whose
        offset is less than or equal to ``n`` (0 if ``n`` is below all)."""
        return max(bisect_right(self.offsets, n) - 1, 0)

    def _set_and_docnum(self, n):
        """Returns the sub-set covering ``n`` and ``n`` relative to it."""
        setnum = self._document_set(n)
        offset = self.offsets[setnum]
        return self.idsets[setnum], n - offset

    def __len__(self):
        return sum(len(idset) for idset in self.idsets)

    def __iter__(self):
        for idset, offset in izip(self.idsets, self.offsets):
            for docnum in idset:
                yield docnum + offset

    def __contains__(self, item):
        if not self.idsets:
            return False
        idset, n = self._set_and_docnum(item)
        return n in idset