1"""
2An implementation of an object that acts like a collection of on/off bits.
3"""
4
5import operator
6from array import array
7from bisect import bisect_left, bisect_right, insort
8
9from whoosh.compat import integer_types, izip, izip_longest, next, xrange
10from whoosh.util.numeric import bytes_for_bits
11
12
# Lookup table of population counts: _1SPERBYTE[b] is the number of '1' bits
# in the byte value b, for every b in 0-255.
_1SPERBYTE = array('B', [bin(value).count("1") for value in range(256)])
25
26
class DocIdSet(object):
    """Base class for a set of positive integers, implementing a subset of the
    built-in ``set`` type's interface with extra docid-related methods.

    This is a superclass for alternative set implementations to the built-in
    ``set`` which are more memory-efficient and specialized toward storing
    sorted lists of positive integers, though they will inevitably be slower
    than ``set`` for most operations since they're pure Python.

    Subclasses must provide ``__len__``, ``__iter__``, ``__contains__``,
    ``copy``, ``add`` and ``discard`` (plus ``before``, ``after``, ``first``
    and ``last`` where those are needed); the remaining methods here are
    generic implementations built on top of them.
    """

    def __eq__(self, other):
        """Returns True if ``other`` yields exactly the same numbers as this
        set, in the same order.

        The two iterables are walked in step, so the result is only
        meaningful when both sides iterate in a consistent order. Returns
        ``NotImplemented`` when ``other`` is not iterable.
        """

        try:
            theirs = iter(other)
        except TypeError:
            return NotImplemented

        # Pad the shorter side with a private marker instead of stopping
        # early, so a set never compares equal to a longer set that merely
        # starts with the same numbers.
        missing = object()
        for a, b in izip_longest(self, theirs, fillvalue=missing):
            if a is missing or b is missing or a != b:
                return False
        return True

    def __ne__(self, other):
        """Returns the inverse of ``__eq__`` (needed explicitly on Python 2,
        where ``!=`` does not fall back to ``__eq__``).
        """

        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __neq__(self, other):
        """Legacy spelling of the inequality test, kept so code that calls it
        directly keeps working. Prefer the ``!=`` operator.
        """

        return self != other

    def __len__(self):
        """Returns the number of integers in the set."""
        raise NotImplementedError

    def __iter__(self):
        """Returns an iterator over the integers in the set."""
        raise NotImplementedError

    def __contains__(self, i):
        """Returns True if ``i`` is in the set."""
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def __and__(self, other):
        return self.intersection(other)

    def __sub__(self, other):
        return self.difference(other)

    def copy(self):
        """Returns an independent copy of this set."""
        raise NotImplementedError

    def add(self, n):
        """Adds ``n`` to the set."""
        raise NotImplementedError

    def discard(self, n):
        """Removes ``n`` from the set."""
        raise NotImplementedError

    def update(self, other):
        """Adds every number yielded by ``other`` to this set in-place."""

        add = self.add
        for i in other:
            add(i)

    def intersection_update(self, other):
        """Removes from this set, in-place, every number that is not also in
        ``other``.
        """

        # Decide what to remove before removing anything: discarding while
        # the set is still being iterated can make the iterator skip members.
        doomed = [n for n in self if n not in other]
        discard = self.discard
        for n in doomed:
            discard(n)

    def difference_update(self, other):
        """Removes every number yielded by ``other`` from this set in-place.
        """

        discard = self.discard
        for n in other:
            discard(n)

    def invert_update(self, size):
        """Updates the set in-place to contain numbers in the range
        ``[0 - size)`` except numbers that are in this set.
        """

        for i in xrange(size):
            if i in self:
                self.discard(i)
            else:
                self.add(i)

    def intersection(self, other):
        """Returns a new set of the numbers that are in both this set and
        ``other``.
        """

        c = self.copy()
        c.intersection_update(other)
        return c

    def union(self, other):
        """Returns a new set of the numbers that are in this set or in
        ``other``.
        """

        c = self.copy()
        c.update(other)
        return c

    def difference(self, other):
        """Returns a new set of the numbers in this set that are not in
        ``other``.
        """

        c = self.copy()
        c.difference_update(other)
        return c

    def invert(self, size):
        """Returns a new set containing the numbers in the range
        ``[0 - size)`` that are not in this set.
        """

        c = self.copy()
        c.invert_update(size)
        return c

    def isdisjoint(self, other):
        """Returns True if this set and ``other`` have no numbers in common.
        ``other`` must support ``len()`` and membership tests.
        """

        # Iterate the smaller collection and probe the larger one
        a = self
        b = other
        if len(other) < len(self):
            a, b = other, self
        for num in a:
            if num in b:
                return False
        return True

    def before(self, i):
        """Returns the previous integer in the set before ``i``, or None.
        """
        raise NotImplementedError

    def after(self, i):
        """Returns the next integer in the set after ``i``, or None.
        """
        raise NotImplementedError

    def first(self):
        """Returns the first (lowest) integer in the set.
        """
        raise NotImplementedError

    def last(self):
        """Returns the last (highest) integer in the set.
        """
        raise NotImplementedError
147
148
class BaseBitSet(DocIdSet):
    """Common logic for DocIdSets stored as a sequence of bytes, where the
    number ``n`` is in the set when bit ``n & 7`` of byte ``n // 8`` is on.

    Subclasses provide the storage by overriding ``byte_count``,
    ``_get_byte`` and ``_iter_bytes``.
    """

    # Methods to override

    def byte_count(self):
        """Returns the number of bytes in the underlying bit array."""
        raise NotImplementedError

    def _get_byte(self, i):
        """Returns the value of the byte at index ``i``."""
        raise NotImplementedError

    def _iter_bytes(self):
        """Returns an iterator over the byte values, starting at index 0."""
        raise NotImplementedError

    # Base implementations

    def __len__(self):
        # Sum the per-byte population counts from the lookup table
        return sum(_1SPERBYTE[b] for b in self._iter_bytes())

    def __iter__(self):
        base = 0
        for byte in self._iter_bytes():
            for i in xrange(8):
                if byte & (1 << i):
                    yield base + i
            base += 8

    def __nonzero__(self):
        # The set is non-empty as soon as any byte has a bit switched on
        return any(self._iter_bytes())

    __bool__ = __nonzero__

    def __contains__(self, i):
        # A negative number is never a member. Without this check the bucket
        # would be negative and ``_get_byte`` would be asked for a byte
        # outside the intended range.
        if i < 0:
            return False
        bucket = i // 8
        if bucket >= self.byte_count():
            return False
        return bool(self._get_byte(bucket) & (1 << (i & 7)))

    def first(self):
        """Returns the lowest integer in the set, or None if it is empty."""
        return self.after(-1)

    def last(self):
        """Returns the highest integer in the set, or None if it is empty."""
        return self.before(self.byte_count() * 8 + 1)

    def before(self, i):
        """Returns the highest integer in the set that is lower than ``i``,
        or None if there is none.
        """

        _get_byte = self._get_byte
        size = self.byte_count() * 8

        if i <= 0:
            return None
        elif i >= size:
            # Everything stored is below ``i``; start from the very last bit
            i = size - 1
        else:
            i -= 1
        bucket = i // 8

        while i >= 0:
            byte = _get_byte(bucket)
            if not byte:
                # Empty byte: jump straight to the top bit of the previous one
                bucket -= 1
                i = bucket * 8 + 7
                continue
            if byte & (1 << (i & 7)):
                return i
            if i % 8 == 0:
                bucket -= 1
            i -= 1

        return None

    def after(self, i):
        """Returns the lowest integer in the set that is higher than ``i``,
        or None if there is none.
        """

        _get_byte = self._get_byte
        size = self.byte_count() * 8

        if i >= size:
            return None
        elif i < 0:
            i = 0
        else:
            i += 1
        bucket = i // 8

        while i < size:
            byte = _get_byte(bucket)
            if not byte:
                # Empty byte: jump straight to the first bit of the next one
                bucket += 1
                i = bucket * 8
                continue
            if byte & (1 << (i & 7)):
                return i
            i += 1
            if i % 8 == 0:
                bucket += 1

        return None
242
243
class OnDiskBitSet(BaseBitSet):
    """A DocIdSet backed by an array of bits on disk.

    >>> st = RamStorage()
    >>> f = st.create_file("test.bin")
    >>> bs = BitSet([1, 10, 15, 7, 2])
    >>> bytecount = bs.to_disk(f)
    >>> f.close()
    >>> # ...
    >>> f = st.open_file("test.bin")
    >>> odbs = OnDiskBitSet(f, 0, bytecount)
    >>> list(odbs)
    [1, 2, 7, 10, 15]
    """

    def __init__(self, dbfile, basepos, bytecount):
        """
        :param dbfile: a :class:`~whoosh.filedb.structfile.StructFile` object
            to read from.
        :param basepos: the base position of the bytes in the given file.
        :param bytecount: the number of bytes to use for the bit array.
        """

        self._dbfile = dbfile
        self._basepos = basepos
        self._bytecount = bytecount

    def __repr__(self):
        # The attributes are stored with a leading underscore (see __init__)
        return "%s(%s, %d, %d)" % (self.__class__.__name__, self._dbfile,
                                   self._basepos, self._bytecount)

    def byte_count(self):
        """Returns the number of bytes in the on-disk bit array."""
        return self._bytecount

    def _get_byte(self, n):
        """Reads and returns byte ``n`` of the bit array from the file."""
        return self._dbfile.get_byte(self._basepos + n)

    def _iter_bytes(self):
        """Yields the bytes of the bit array in order, reading them from the
        file sequentially starting at the base position.
        """

        dbfile = self._dbfile
        dbfile.seek(self._basepos)
        for _ in xrange(self._bytecount):
            yield dbfile.read_byte()
286
287
class BitSet(BaseBitSet):
    """A DocIdSet backed by an array of bits. This can also be useful as a bit
    array (e.g. for a Bloom filter). It is much more memory efficient than a
    large built-in set of integers, but wastes memory for sparse sets.

    The bits live in ``self.bits``, an ``array`` of unsigned bytes ("B").
    """

    def __init__(self, source=None, size=0):
        """
        :param source: an iterable of positive integers to add to this set.
        :param size: a size hint used to pre-allocate the bit array. The
            array still grows automatically when larger numbers are added.
        """

        # If the source is a non-empty list, tuple, or set, we can guess the
        # size (an empty one has no maximum to take)
        if (not size and source
                and isinstance(source, (list, tuple, set, frozenset))):
            size = max(source)
        bytecount = bytes_for_bits(size)
        self.bits = array("B", (0 for _ in xrange(bytecount)))

        if source:
            add = self.add
            for num in source:
                add(num)

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, list(self))

    def byte_count(self):
        """Returns the number of bytes currently in the bit array."""
        return len(self.bits)

    def _get_byte(self, n):
        return self.bits[n]

    def _iter_bytes(self):
        return iter(self.bits)

    def _trim(self):
        """Drops trailing bytes that have no bits switched on."""

        bits = self.bits
        last = len(bits) - 1
        while last >= 0 and not bits[last]:
            last -= 1
        del bits[last + 1:]

    def _resize(self, tosize):
        """Grows (with zero bytes) or shrinks the bit array so it is exactly
        ``bytes_for_bits(tosize)`` bytes long.
        """

        curlength = len(self.bits)
        newlength = bytes_for_bits(tosize)
        if newlength > curlength:
            self.bits.extend((0,) * (newlength - curlength))
        elif newlength < curlength:
            # Keep exactly ``newlength`` bytes, matching the growing branch
            del self.bits[newlength:]

    def _zero_extra_bits(self, size):
        """Switches off every stored bit at position ``size`` or higher."""

        bits = self.bits
        # ``whole`` bytes lie entirely below ``size``; the byte right after
        # them keeps only its lowest ``spill`` bits, and any later bytes are
        # cleared completely.
        whole, spill = divmod(max(size, 0), 8)
        if whole >= len(bits):
            return
        bits[whole] &= (1 << spill) - 1
        for i in xrange(whole + 1, len(bits)):
            bits[i] = 0

    def _logic(self, obj, op, other):
        """Combines the bytes of ``obj`` with the bytes of ``other`` using
        the two-argument function ``op``, storing the result in ``obj``.

        :param obj: the BitSet to modify in-place.
        :param op: a function taking a byte from each set and returning the
            combined byte.
        :param other: the BitSet supplying the second operand.
        :returns: ``obj``, with trailing empty bytes removed.
        """

        objbits = obj.bits
        otherbits = other.bits
        otherlen = len(otherbits)

        # Pad the target so both arrays can be walked by index; a missing
        # byte on either side counts as 0.
        if otherlen > len(objbits):
            objbits.extend((0,) * (otherlen - len(objbits)))

        for i in xrange(len(objbits)):
            byte2 = otherbits[i] if i < otherlen else 0
            objbits[i] = op(objbits[i], byte2) & 0xFF

        obj._trim()
        return obj

    def to_disk(self, dbfile):
        """Writes the bit array to ``dbfile`` and returns the number of bytes
        in the array.
        """

        dbfile.write_array(self.bits)
        return len(self.bits)

    @classmethod
    def from_bytes(cls, bs):
        """Returns a new set whose bit array is built from the bytes ``bs``.
        """

        b = cls()
        b.bits = array("B", bs)
        return b

    @classmethod
    def from_disk(cls, dbfile, bytecount):
        """Returns a new set built from ``bytecount`` bytes read from
        ``dbfile``.
        """

        return cls.from_bytes(dbfile.read_array("B", bytecount))

    def copy(self):
        b = self.__class__()
        b.bits = array("B", self.bits)
        return b

    def clear(self):
        """Switches off every bit without changing the array's length."""

        bits = self.bits
        for i in xrange(len(bits)):
            bits[i] = 0

    def add(self, i):
        """Adds ``i`` to the set, growing the bit array if necessary.

        :raises ValueError: if ``i`` is negative.
        """

        if i < 0:
            # A negative bucket would index the array from its end and
            # silently switch on an unrelated bit.
            raise ValueError("BitSet cannot store negative number %r" % (i,))
        bucket = i >> 3
        if bucket >= len(self.bits):
            self._resize(i + 1)
        self.bits[bucket] |= 1 << (i & 7)

    def discard(self, i):
        """Removes ``i`` from the set. Numbers that cannot be in the set
        (negative, or beyond the end of the bit array) are ignored.
        """

        bucket = i >> 3
        if 0 <= bucket < len(self.bits):
            self.bits[bucket] &= ~(1 << (i & 7))

    def _resize_to_other(self, other):
        """Pre-grows the bit array to fit the largest number in ``other``
        when that can be found cheaply (list, tuple, set or frozenset).
        """

        if other and isinstance(other, (list, tuple, set, frozenset)):
            maxbit = max(other)
            if maxbit // 8 >= len(self.bits):
                # Same sizing rule ``add`` uses for a single number
                self._resize(maxbit + 1)

    def update(self, iterable):
        self._resize_to_other(iterable)
        DocIdSet.update(self, iterable)

    def intersection_update(self, other):
        if isinstance(other, BitSet):
            # Fast path: combine whole bytes at a time
            return self._logic(self, operator.__and__, other)
        discard = self.discard
        for n in self:
            if n not in other:
                discard(n)

    def difference_update(self, other):
        if isinstance(other, BitSet):
            return self._logic(self, lambda x, y: x & ~y, other)
        discard = self.discard
        for n in other:
            discard(n)

    def invert_update(self, size):
        """Updates the set in-place to contain the numbers in the range
        ``[0 - size)`` except the numbers that are in this set.
        """

        bits = self.bits
        # Make sure there is a byte for every position below ``size``, so
        # numbers past the current end of the array get switched on too.
        needed = (max(size, 0) + 7) // 8
        if len(bits) < needed:
            bits.extend((0,) * (needed - len(bits)))
        for i in xrange(len(bits)):
            bits[i] = ~bits[i] & 0xFF
        self._zero_extra_bits(size)

    def union(self, other):
        if isinstance(other, BitSet):
            return self._logic(self.copy(), operator.__or__, other)
        b = self.copy()
        b.update(other)
        return b

    def intersection(self, other):
        if isinstance(other, BitSet):
            return self._logic(self.copy(), operator.__and__, other)
        return BitSet(source=(n for n in self if n in other))

    def difference(self, other):
        if isinstance(other, BitSet):
            return self._logic(self.copy(), lambda x, y: x & ~y, other)
        return BitSet(source=(n for n in self if n not in other))
440
441
class SortedIntSet(DocIdSet):
    """A DocIdSet backed by a sorted array of integers.
    """

    def __init__(self, source=None, typecode="I"):
        """
        :param source: an optional iterable of integers to start with. The
            numbers are sorted and duplicates are dropped.
        :param typecode: the ``array`` typecode used to store the numbers.
        """

        if source:
            # Every other method assumes the array is sorted and holds each
            # number at most once
            self.data = array(typecode, sorted(set(source)))
        else:
            self.data = array(typecode)
        self.typecode = typecode

    def copy(self):
        # Carry the typecode over so the copy's ``typecode`` attribute agrees
        # with the array it holds
        sis = SortedIntSet(typecode=self.typecode)
        sis.data = array(self.typecode, self.data)
        return sis

    def size(self):
        """Returns the number of bytes taken up by the stored numbers."""
        return len(self.data) * self.data.itemsize

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.data)

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data)

    def __nonzero__(self):
        return bool(self.data)

    __bool__ = __nonzero__

    def __contains__(self, i):
        data = self.data
        # Cheap range check before doing the binary search
        if not data or i < data[0] or i > data[-1]:
            return False

        pos = bisect_left(data, i)
        if pos == len(data):
            return False
        return data[pos] == i

    def add(self, i):
        """Adds ``i`` to the set, keeping the array sorted."""

        data = self.data
        if not data or i > data[-1]:
            # Common case: numbers arrive in ascending order
            data.append(i)
            return

        # Here ``i <= data[-1]``, so the insertion point is always a valid
        # index into the array
        pos = bisect_left(data, i)
        if data[pos] != i:
            data.insert(pos, i)

    def discard(self, i):
        """Removes ``i`` from the set; does nothing if it is not a member."""

        data = self.data
        pos = bisect_left(data, i)
        # ``pos`` equals ``len(data)`` when ``i`` is larger than every member
        # (or the set is empty), so it must be checked before indexing
        if pos < len(data) and data[pos] == i:
            data.pop(pos)

    def clear(self):
        self.data = array(self.typecode)

    def intersection_update(self, other):
        self.data = array(self.typecode, (num for num in self if num in other))

    def difference_update(self, other):
        self.data = array(self.typecode,
                          (num for num in self if num not in other))

    def intersection(self, other):
        return SortedIntSet((num for num in self if num in other),
                            typecode=self.typecode)

    def difference(self, other):
        return SortedIntSet((num for num in self if num not in other),
                            typecode=self.typecode)

    def first(self):
        """Returns the lowest integer in the set, or None if it is empty."""

        data = self.data
        return data[0] if data else None

    def last(self):
        """Returns the highest integer in the set, or None if it is empty."""

        data = self.data
        return data[-1] if data else None

    def before(self, i):
        """Returns the highest integer in the set that is lower than ``i``,
        or None if there is none.
        """

        data = self.data
        pos = bisect_left(data, i)
        if pos < 1:
            return None
        else:
            return data[pos - 1]

    def after(self, i):
        """Returns the lowest integer in the set that is higher than ``i``,
        or None if there is none.
        """

        data = self.data
        if not data or i >= data[-1]:
            return None
        elif i < data[0]:
            return data[0]

        pos = bisect_right(data, i)
        return data[pos]
548
549
class ReverseIdSet(DocIdSet):
    """
    Wraps a DocIdSet object and reverses its semantics, so docs in the wrapped
    set are not in this set, and vice-versa.
    """

    def __init__(self, idset, limit):
        """
        :param idset: the DocIdSet object to wrap.
        :param limit: the highest possible ID plus one.
        """

        self.idset = idset
        self.limit = limit

    def __len__(self):
        # NOTE(review): this assumes every number in the wrapped set is below
        # ``limit`` -- confirm against callers.
        return self.limit - len(self.idset)

    def __contains__(self, i):
        # Only IDs in [0, limit) can be members, which keeps this in line
        # with what ``__iter__`` yields
        return 0 <= i < self.limit and i not in self.idset

    def __iter__(self):
        # Walks the ID range and the wrapped set side by side, skipping each
        # ID the wrapped set yields. This relies on the wrapped set iterating
        # in ascending order.
        ids = iter(self.idset)
        try:
            nx = next(ids)
        except StopIteration:
            # -1 never matches an ID in the range below
            nx = -1

        for i in xrange(self.limit):
            if i == nx:
                try:
                    nx = next(ids)
                except StopIteration:
                    nx = -1
            else:
                yield i

    def add(self, n):
        """Adds ``n`` to this set by removing it from the wrapped set."""
        self.idset.discard(n)

    def discard(self, n):
        """Removes ``n`` from this set by adding it to the wrapped set."""
        self.idset.add(n)

    def first(self):
        """Returns the lowest integer in this set, or None if it is empty."""

        for i in self:
            return i
        return None

    def last(self):
        """Returns the highest integer in this set, or None if it is empty.
        """

        # Scan down from the top of the range using membership tests only, so
        # this also works when the wrapped set is empty.
        idset = self.idset
        for i in xrange(self.limit - 1, -1, -1):
            if i not in idset:
                return i
        return None
606
# Number of IDs a 2^16 range may hold as a sorted array before it is switched
# over to a BitSet
ROARING_CUTOFF = 1 << 12


class RoaringIdSet(DocIdSet):
    """
    Separates IDs into ranges of 2^16 bits, and stores each range in the most
    efficient type of doc set, either a BitSet (if the range has more than
    2^12 IDs) or a sorted ID set of 16-bit shorts.
    """

    cutoff = ROARING_CUTOFF

    def __init__(self, source=None):
        """
        :param source: an optional iterable of positive integers to add.
        """

        # self.idsets[b] holds the low 16 bits of every ID whose high bits
        # equal b
        self.idsets = []
        if source:
            self.update(source)

    def __len__(self):
        return sum(len(idset) for idset in self.idsets)

    def __contains__(self, n):
        if n < 0:
            # A negative bucket would index the list from its end
            return False
        bucket = n >> 16
        if bucket >= len(self.idsets):
            return False
        return (n - (bucket << 16)) in self.idsets[bucket]

    def __iter__(self):
        for i, idset in enumerate(self.idsets):
            floor = i << 16
            for n in idset:
                yield floor + n

    def _find(self, n):
        """Returns a ``(bucket, floor, idset)`` tuple for ``n``: the index of
        its 2^16 range, the first ID in that range, and the set object that
        stores the range. Missing ranges are created on demand.
        """

        bucket = n >> 16
        floor = bucket << 16
        if bucket >= len(self.idsets):
            # Values stored in a range are offsets 0-65535, so an unsigned
            # 16-bit array is enough
            self.idsets.extend([SortedIntSet(typecode="H") for _
                                in xrange(len(self.idsets), bucket + 1)])
        idset = self.idsets[bucket]
        return bucket, floor, idset

    def add(self, n):
        """Adds ``n`` to the set.

        :raises ValueError: if ``n`` is negative.
        """

        if n < 0:
            raise ValueError("RoaringIdSet cannot store negative number %r"
                             % (n,))
        bucket, floor, idset = self._find(n)
        oldlen = len(idset)
        idset.add(n - floor)
        # Switch the range to a bit array once it grows past the cutoff
        if oldlen <= self.cutoff < len(idset):
            self.idsets[bucket] = BitSet(idset)

    def discard(self, n):
        """Removes ``n`` from the set; does nothing if it is not a member."""

        bucket = n >> 16
        if n < 0 or bucket >= len(self.idsets):
            # Nothing is stored for this range; don't allocate it just to
            # remove a number that isn't there
            return
        idset = self.idsets[bucket]
        low = n - (bucket << 16)
        if low not in idset:
            return

        oldlen = len(idset)
        idset.discard(low)
        # Switch the range back to a sorted array once it shrinks to the
        # cutoff
        if oldlen > self.cutoff >= len(idset):
            self.idsets[bucket] = SortedIntSet(idset, typecode="H")
664
665
class MultiIdSet(DocIdSet):
    """Wraps multiple SERIAL sub-DocIdSet objects and presents them as an
    aggregated, read-only set.
    """

    def __init__(self, idsets, offsets):
        """
        :param idsets: a list of DocIdSet objects.
        :param offsets: a list of offsets corresponding to the DocIdSet objects
            in ``idsets``. Each offset is added to the numbers of its sub-set
            to form the aggregated numbers; the list must be in ascending
            order for lookups to work.
        """

        assert len(idsets) == len(offsets)
        self.idsets = idsets
        self.offsets = offsets

    def _document_set(self, n):
        """Returns the index of the sub-set responsible for the aggregated
        number ``n``: the last one whose offset is less than or equal to
        ``n`` (or 0 if ``n`` is below every offset).
        """

        return max(bisect_right(self.offsets, n) - 1, 0)

    def _set_and_docnum(self, n):
        """Returns a ``(idset, local_number)`` tuple for the aggregated
        number ``n``.
        """

        setnum = self._document_set(n)
        offset = self.offsets[setnum]
        return self.idsets[setnum], n - offset

    def __len__(self):
        return sum(len(idset) for idset in self.idsets)

    def __iter__(self):
        for idset, offset in izip(self.idsets, self.offsets):
            for docnum in idset:
                yield docnum + offset

    def __contains__(self, item):
        if not self.idsets:
            return False
        idset, n = self._set_and_docnum(item)
        # A negative local number means ``item`` lies below the first offset
        return n >= 0 and n in idset
702
703
704