1"""
2Support for native homogeneous sets.
3"""
4
5
6import collections
7import contextlib
8import math
9import operator
10
11from llvmlite import ir
12from numba.core import types, typing, cgutils
13from numba.core.imputils import (lower_builtin, lower_cast,
14                                    iternext_impl, impl_ret_borrowed,
15                                    impl_ret_new_ref, impl_ret_untracked,
16                                    for_iter, call_len, RefType)
17from numba.core.utils import cached_property
18from numba.misc import quicksort
19from numba.cpython import slicing
20from numba.extending import intrinsic
21
22
def get_payload_struct(context, builder, set_type, ptr):
    """
    Build a data helper that views the memory at *ptr* as the payload
    structure of *set_type*.

    The helper is created by reference (``ref=``), so that mutations
    made through it are seen by all users of the same pointer.
    """
    payload_ty = types.SetPayload(set_type)
    ll_payload_ptr_ty = context.get_data_type(payload_ty).as_pointer()
    return context.make_data_helper(
        builder, payload_ty,
        ref=builder.bitcast(ptr, ll_payload_ptr_ty))
32
33
def get_entry_size(context, set_type):
    """
    Compute the ABI size of a single table entry for the given set type.
    """
    entry_ty = types.SetEntry(set_type)
    return context.get_abi_sizeof(context.get_data_type(entry_ty))
40
41
# Reserved hash values marking special table entries.
# Note these values are special:
# - EMPTY is obtained by issuing memset(..., 0xFF)
# - (unsigned) EMPTY > (unsigned) DELETED > any other hash value
EMPTY = -1
DELETED = -2
# Replacement hash returned by get_hash_value() when the computed hash
# is one of the reserved values above.
FALLBACK = -43

# Minimal size of entries table.  Must be a power of 2!
MINSIZE = 16

# Number of cache-friendly linear probes before switching to non-linear probing
LINEAR_PROBES = 3

# When true, resizing code (see SetInstance.upsize()) emits a printf()
# call tracing the old and new table sizes.
DEBUG_ALLOCS = False
56
57
def get_hash_value(context, builder, typ, value):
    """
    Emit code computing ``hash(value)`` for a value of Numba type *typ*.

    Hashes that would collide with the reserved EMPTY / DELETED markers
    are replaced with FALLBACK, so the result always denotes a used entry.
    """
    typing_context = context.typing_context
    hash_fnty = typing_context.resolve_value_type(hash)
    hash_sig = hash_fnty.get_call_type(typing_context, (typ,), {})
    hash_impl = context.get_function(hash_fnty, hash_sig)
    raw_hash = hash_impl(builder, (value,))
    # Substitute FALLBACK when the raw hash is a reserved value
    usable = is_hash_used(context, builder, raw_hash)
    return builder.select(usable, raw_hash,
                          ir.Constant(raw_hash.type, FALLBACK))
71
72
@intrinsic
def _get_hash_value_intrinsic(typingctx, value):
    """
    Intrinsic exposing get_hash_value() to jitted code.

    *value* is the Numba type of the argument; the returned signature is
    the one ``hash()`` resolves to for that type.
    """
    hash_fnty = typingctx.resolve_value_type(hash)
    hash_sig = hash_fnty.get_call_type(typingctx, (value,), {})

    def impl(context, builder, typ, args):
        # Lower to the same reserved-value-safe hash used by the set code
        return get_hash_value(context, builder, value, args[0])

    return hash_sig, impl
80
81
def is_hash_empty(context, builder, h):
    """
    Emit a test telling whether hash *h* is the EMPTY marker.
    """
    return builder.icmp_unsigned('==', h, ir.Constant(h.type, EMPTY))
88
def is_hash_deleted(context, builder, h):
    """
    Emit a test telling whether hash *h* is the DELETED marker.
    """
    return builder.icmp_unsigned('==', h, ir.Constant(h.type, DELETED))
95
def is_hash_used(context, builder, h):
    """
    Emit a test telling whether hash *h* belongs to an active entry.
    """
    # As unsigned values, EMPTY > DELETED > every regular hash, so any
    # hash strictly below DELETED marks a used entry.
    return builder.icmp_unsigned('<', h, ir.Constant(h.type, DELETED))
103
104
# Loop state yielded by _SetPayload._iterate(): the current table index,
# the entry helper at that index, and a callable breaking out of the loop.
SetLoop = collections.namedtuple('SetLoop', ['index', 'entry', 'do_break'])
106
107
class _SetPayload(object):
    """
    Code-generation helper around the payload structure of a native set.

    The payload holds the table bookkeeping fields (mask, used, fill,
    finger, dirty) followed by the entries array.  All accessors emit
    LLVM IR through the builder given at construction time.
    """

    def __init__(self, context, builder, set_type, ptr):
        """
        Wrap the payload found at *ptr* for a set of type *set_type*.
        """
        payload = get_payload_struct(context, builder, set_type, ptr)
        self._context = context
        self._builder = builder
        self._ty = set_type
        self._payload = payload
        # Pointer to the 'entries' field, used as the base of the table
        self._entries = payload._get_ptr_by_name('entries')
        self._ptr = ptr

    @property
    def mask(self):
        """
        The table index mask; the table holds ``mask + 1`` entries.
        """
        return self._payload.mask

    @mask.setter
    def mask(self, value):
        # CAUTION: mask must be a power of 2 minus 1
        self._payload.mask = value

    @property
    def used(self):
        """
        The number of active entries (i.e. the set's length).
        """
        return self._payload.used

    @used.setter
    def used(self, value):
        self._payload.used = value

    @property
    def fill(self):
        """
        The fill counter: incremented when an EMPTY entry is filled and
        not decremented on removal, so ``fill == used`` means the table
        has no deleted entries.
        """
        return self._payload.fill

    @fill.setter
    def fill(self, value):
        self._payload.fill = value

    @property
    def finger(self):
        """
        The "search finger": table index where _next_entry() resumes.
        """
        return self._payload.finger

    @finger.setter
    def finger(self, value):
        self._payload.finger = value

    @property
    def dirty(self):
        """
        The dirty flag (written by SetInstance.set_dirty() for reflected
        sets).
        """
        return self._payload.dirty

    @dirty.setter
    def dirty(self, value):
        self._payload.dirty = value

    @property
    def entries(self):
        """
        A pointer to the start of the entries array.
        """
        return self._entries

    @property
    def ptr(self):
        """
        A pointer to the start of the NRT-allocated area.
        """
        return self._ptr

    def get_entry(self, idx):
        """
        Get entry number *idx*.

        The entry is returned as a by-reference data helper, so writing
        its fields modifies the table in place.
        """
        entry_ptr = cgutils.gep(self._builder, self._entries, idx)
        entry = self._context.make_data_helper(self._builder,
                                               types.SetEntry(self._ty),
                                               ref=entry_ptr)
        return entry

    def _lookup(self, item, h, for_insert=False):
        """
        Lookup the *item* with the given hash values in the entries.

        Return a (found, entry index) tuple:
        - If found is true, <entry index> points to the entry containing
          the item.
        - If found is false, <entry index> points to the empty entry that
          the item can be written to (only if *for_insert* is true)

        With *for_insert*, the first deleted entry met along the probe
        chain is preferred over the terminating empty entry.
        The builder is left positioned in a fresh "lookup.end" block.
        """
        context = self._context
        builder = self._builder

        intp_t = h.type

        mask = self.mask
        dtype = self._ty.dtype
        eqfn = context.get_function(operator.eq,
                                    typing.signature(types.boolean, dtype, dtype))

        one = ir.Constant(intp_t, 1)
        five = ir.Constant(intp_t, 5)

        # The perturbation value for probing
        perturb = cgutils.alloca_once_value(builder, h)
        # The index of the entry being considered: start with (hash & mask)
        index = cgutils.alloca_once_value(builder,
                                          builder.and_(h, mask))
        if for_insert:
            # The index of the first deleted entry in the lookup chain
            free_index_sentinel = mask.type(-1)  # highest unsigned index
            free_index = cgutils.alloca_once_value(builder, free_index_sentinel)

        bb_body = builder.append_basic_block("lookup.body")
        bb_found = builder.append_basic_block("lookup.found")
        bb_not_found = builder.append_basic_block("lookup.not_found")
        bb_end = builder.append_basic_block("lookup.end")

        def check_entry(i):
            """
            Check entry *i* against the value being searched for.

            Branches to bb_found on a match and to bb_not_found on an
            empty entry; otherwise falls through so probing continues.
            """
            entry = self.get_entry(i)
            entry_hash = entry.hash

            with builder.if_then(builder.icmp_unsigned('==', h, entry_hash)):
                # Hashes are equal, compare values
                # (note this also ensures the entry is used)
                eq = eqfn(builder, (item, entry.key))
                with builder.if_then(eq):
                    builder.branch(bb_found)

            # An empty entry terminates the probe chain
            with builder.if_then(is_hash_empty(context, builder, entry_hash)):
                builder.branch(bb_not_found)

            if for_insert:
                # Memorize the index of the first deleted entry
                with builder.if_then(is_hash_deleted(context, builder, entry_hash)):
                    j = builder.load(free_index)
                    j = builder.select(builder.icmp_unsigned('==', j, free_index_sentinel),
                                       i, j)
                    builder.store(j, free_index)

        # First linear probing.  When the number of collisions is small,
        # the linear probing loop achieves better cache locality and
        # is also slightly cheaper computationally.
        with cgutils.for_range(builder, ir.Constant(intp_t, LINEAR_PROBES)):
            i = builder.load(index)
            check_entry(i)
            i = builder.add(i, one)
            i = builder.and_(i, mask)
            builder.store(i, index)

        # If not found after linear probing, switch to a non-linear
        # perturbation keyed on the unmasked hash value.
        # XXX how to tell LLVM this branch is unlikely?
        builder.branch(bb_body)
        with builder.goto_block(bb_body):
            i = builder.load(index)
            check_entry(i)

            # Perturb to go to next entry:
            #   perturb >>= 5
            #   i = (i * 5 + 1 + perturb) & mask
            p = builder.load(perturb)
            p = builder.lshr(p, five)
            i = builder.add(one, builder.mul(i, five))
            i = builder.and_(mask, builder.add(i, p))
            builder.store(i, index)
            builder.store(p, perturb)
            # Loop
            builder.branch(bb_body)

        with builder.goto_block(bb_not_found):
            if for_insert:
                # Not found => for insertion, return the index of the first
                # deleted entry (if any), to avoid creating an infinite
                # lookup chain (issue #1913).
                i = builder.load(index)
                j = builder.load(free_index)
                i = builder.select(builder.icmp_unsigned('==', j, free_index_sentinel),
                                   i, j)
                builder.store(i, index)
            builder.branch(bb_end)

        with builder.goto_block(bb_found):
            builder.branch(bb_end)

        builder.position_at_end(bb_end)

        # Merge both outcomes into a single boolean
        found = builder.phi(ir.IntType(1), 'found')
        found.add_incoming(cgutils.true_bit, bb_found)
        found.add_incoming(cgutils.false_bit, bb_not_found)

        return found, builder.load(index)

    @contextlib.contextmanager
    def _iterate(self, start=None):
        """
        Iterate over the payload's entries.  Yield a SetLoop.

        The whole table (``mask + 1`` slots) is walked, with *start*
        forwarded to cgutils.for_range(); the body of the ``with`` block
        is only emitted for used entries.
        """
        context = self._context
        builder = self._builder

        intp_t = context.get_value_type(types.intp)
        one = ir.Constant(intp_t, 1)
        size = builder.add(self.mask, one)

        with cgutils.for_range(builder, size, start=start) as range_loop:
            entry = self.get_entry(range_loop.index)
            is_used = is_hash_used(context, builder, entry.hash)
            with builder.if_then(is_used):
                loop = SetLoop(index=range_loop.index, entry=entry,
                               do_break=range_loop.do_break)
                yield loop

    @contextlib.contextmanager
    def _next_entry(self):
        """
        Yield an arbitrary used entry from the payload: the first one
        found when scanning from the stored search finger.  Caller must
        ensure the set isn't empty, otherwise the function won't end.
        """
        context = self._context
        builder = self._builder

        intp_t = context.get_value_type(types.intp)
        zero = ir.Constant(intp_t, 0)
        one = ir.Constant(intp_t, 1)
        mask = self.mask

        # Start walking the entries from the stored "search finger" and
        # break as soon as we find a used entry.

        bb_body = builder.append_basic_block('next_entry_body')
        bb_end = builder.append_basic_block('next_entry_end')

        index = cgutils.alloca_once_value(builder, self.finger)
        builder.branch(bb_body)

        with builder.goto_block(bb_body):
            i = builder.load(index)
            # ANDing with mask ensures we stay inside the table boundaries
            i = builder.and_(mask, builder.add(i, one))
            builder.store(i, index)
            entry = self.get_entry(i)
            is_used = is_hash_used(context, builder, entry.hash)
            builder.cbranch(is_used, bb_end, bb_body)

        builder.position_at_end(bb_end)

        # Update the search finger with the next position.  This avoids
        # O(n**2) behaviour when pop() is called in a loop.
        i = builder.load(index)
        self.finger = i
        yield self.get_entry(i)
359
360
361class SetInstance(object):
362
    def __init__(self, context, builder, set_type, set_val):
        """
        Wrap the set value *set_val* of type *set_type* for code generation.

        *set_val* is handed to ``context.make_helper()``; other code in
        this class passes None to obtain a fresh helper that is filled
        in afterwards.
        """
        self._context = context
        self._builder = builder
        self._ty = set_type
        # Size of one table entry for this set type
        self._entrysize = get_entry_size(context, set_type)
        self._set = context.make_helper(builder, set_type, set_val)
369
370    @property
371    def dtype(self):
372        return self._ty.dtype
373
374    @property
375    def payload(self):
376        """
377        The _SetPayload for this set.
378        """
379        # This cannot be cached as the pointer can move around!
380        context = self._context
381        builder = self._builder
382
383        ptr = self._context.nrt.meminfo_data(builder, self.meminfo)
384        return _SetPayload(context, builder, self._ty, ptr)
385
386    @property
387    def value(self):
388        return self._set._getvalue()
389
390    @property
391    def meminfo(self):
392        return self._set.meminfo
393
394    @property
395    def parent(self):
396        return self._set.parent
397
398    @parent.setter
399    def parent(self, value):
400        self._set.parent = value
401
402    def get_size(self):
403        """
404        Return the number of elements in the size.
405        """
406        return self.payload.used
407
408    def set_dirty(self, val):
409        if self._ty.reflected:
410            self.payload.dirty = cgutils.true_bit if val else cgutils.false_bit
411
412    def _add_entry(self, payload, entry, item, h, do_resize=True):
413        context = self._context
414        builder = self._builder
415
416        old_hash = entry.hash
417        entry.hash = h
418        entry.key = item
419        # used++
420        used = payload.used
421        one = ir.Constant(used.type, 1)
422        used = payload.used = builder.add(used, one)
423        # fill++ if entry wasn't a deleted one
424        with builder.if_then(is_hash_empty(context, builder, old_hash),
425                             likely=True):
426            payload.fill = builder.add(payload.fill, one)
427        # Grow table if necessary
428        if do_resize:
429            self.upsize(used)
430        self.set_dirty(True)
431
432    def _add_key(self, payload, item, h, do_resize=True):
433        context = self._context
434        builder = self._builder
435
436        found, i = payload._lookup(item, h, for_insert=True)
437        not_found = builder.not_(found)
438
439        with builder.if_then(not_found):
440            # Not found => add it
441            entry = payload.get_entry(i)
442            old_hash = entry.hash
443            entry.hash = h
444            entry.key = item
445            # used++
446            used = payload.used
447            one = ir.Constant(used.type, 1)
448            used = payload.used = builder.add(used, one)
449            # fill++ if entry wasn't a deleted one
450            with builder.if_then(is_hash_empty(context, builder, old_hash),
451                                 likely=True):
452                payload.fill = builder.add(payload.fill, one)
453            # Grow table if necessary
454            if do_resize:
455                self.upsize(used)
456            self.set_dirty(True)
457
458    def _remove_entry(self, payload, entry, do_resize=True):
459        # Mark entry deleted
460        entry.hash = ir.Constant(entry.hash.type, DELETED)
461        # used--
462        used = payload.used
463        one = ir.Constant(used.type, 1)
464        used = payload.used = self._builder.sub(used, one)
465        # Shrink table if necessary
466        if do_resize:
467            self.downsize(used)
468        self.set_dirty(True)
469
470    def _remove_key(self, payload, item, h, do_resize=True):
471        context = self._context
472        builder = self._builder
473
474        found, i = payload._lookup(item, h)
475
476        with builder.if_then(found):
477            entry = payload.get_entry(i)
478            self._remove_entry(payload, entry, do_resize)
479
480        return found
481
482    def add(self, item, do_resize=True):
483        context = self._context
484        builder = self._builder
485
486        payload = self.payload
487        h = get_hash_value(context, builder, self._ty.dtype, item)
488        self._add_key(payload, item, h, do_resize)
489
490    def add_pyapi(self, pyapi, item, do_resize=True):
491        """A version of .add for use inside functions following Python calling
492        convention.
493        """
494        context = self._context
495        builder = self._builder
496
497        payload = self.payload
498        h = self._pyapi_get_hash_value(pyapi, context, builder, item)
499        self._add_key(payload, item, h, do_resize)
500
501    def _pyapi_get_hash_value(self, pyapi, context, builder, item):
502        """Python API compatible version of `get_hash_value()`.
503        """
504        argtypes = [self._ty.dtype]
505        resty = types.intp
506
507        def wrapper(val):
508            return _get_hash_value_intrinsic(val)
509
510        args = [item]
511        sig = typing.signature(resty, *argtypes)
512        is_error, retval = pyapi.call_jit_code(wrapper, sig, args)
513        # Handle return status
514        with builder.if_then(is_error, likely=False):
515            # Raise nopython exception as a Python exception
516            builder.ret(pyapi.get_null_object())
517        return retval
518
519    def contains(self, item):
520        context = self._context
521        builder = self._builder
522
523        payload = self.payload
524        h = get_hash_value(context, builder, self._ty.dtype, item)
525        found, i = payload._lookup(item, h)
526        return found
527
528    def discard(self, item):
529        context = self._context
530        builder = self._builder
531
532        payload = self.payload
533        h = get_hash_value(context, builder, self._ty.dtype, item)
534        found = self._remove_key(payload, item, h)
535        return found
536
537    def pop(self):
538        context = self._context
539        builder = self._builder
540
541        lty = context.get_value_type(self._ty.dtype)
542        key = cgutils.alloca_once(builder, lty)
543
544        payload = self.payload
545        with payload._next_entry() as entry:
546            builder.store(entry.key, key)
547            self._remove_entry(payload, entry)
548
549        return builder.load(key)
550
551    def clear(self):
552        context = self._context
553        builder = self._builder
554
555        intp_t = context.get_value_type(types.intp)
556        minsize = ir.Constant(intp_t, MINSIZE)
557        self._replace_payload(minsize)
558        self.set_dirty(True)
559
    def copy(self):
        """
        Return a copy of this set.

        When the table has no deleted entries (``used == fill``) the
        payload is copied raw; otherwise a new payload is allocated and
        the used entries are re-inserted one by one.  In both cases an
        allocation failure returns a MemoryError ("cannot copy set")
        through the current call convention.
        """
        context = self._context
        builder = self._builder

        payload = self.payload
        used = payload.used
        fill = payload.fill

        # New instance of the same class, not yet backed by a payload
        other = type(self)(context, builder, self._ty, None)

        no_deleted_entries = builder.icmp_unsigned('==', used, fill)
        with builder.if_else(no_deleted_entries, likely=True) \
            as (if_no_deleted, if_deleted):
            with if_no_deleted:
                # No deleted entries => raw copy the payload
                ok = other._copy_payload(payload)
                with builder.if_then(builder.not_(ok), likely=False):
                    context.call_conv.return_user_exc(builder, MemoryError,
                                                      ("cannot copy set",))

            with if_deleted:
                # Deleted entries => re-insert entries one by one
                nentries = self.choose_alloc_size(context, builder, used)
                ok = other._allocate_payload(nentries)
                with builder.if_then(builder.not_(ok), likely=False):
                    context.call_conv.return_user_exc(builder, MemoryError,
                                                      ("cannot copy set",))

                other_payload = other.payload
                with payload._iterate() as loop:
                    entry = loop.entry
                    # The table was sized for `used` items up front, so
                    # no resize is requested while filling it.
                    other._add_key(other_payload, entry.key, entry.hash,
                                   do_resize=False)

        return other
598
599    def intersect(self, other):
600        """
601        In-place intersection with *other* set.
602        """
603        context = self._context
604        builder = self._builder
605        payload = self.payload
606        other_payload = other.payload
607
608        with payload._iterate() as loop:
609            entry = loop.entry
610            found, _ = other_payload._lookup(entry.key, entry.hash)
611            with builder.if_then(builder.not_(found)):
612                self._remove_entry(payload, entry, do_resize=False)
613
614        # Final downsize
615        self.downsize(payload.used)
616
617    def difference(self, other):
618        """
619        In-place difference with *other* set.
620        """
621        context = self._context
622        builder = self._builder
623        payload = self.payload
624        other_payload = other.payload
625
626        with other_payload._iterate() as loop:
627            entry = loop.entry
628            self._remove_key(payload, entry.key, entry.hash, do_resize=False)
629
630        # Final downsize
631        self.downsize(payload.used)
632
    def symmetric_difference(self, other):
        """
        In-place symmetric difference with *other* set.

        Each key of *other* is removed from this set when already
        present, and added otherwise.
        """
        context = self._context
        builder = self._builder
        other_payload = other.payload

        with other_payload._iterate() as loop:
            key = loop.entry.key
            h = loop.entry.hash
            # We must reload our payload as it may be resized during the loop
            payload = self.payload
            # for_insert=True so that, when the key is missing, the index
            # designates a slot the key can be written to
            found, i = payload._lookup(key, h, for_insert=True)
            entry = payload.get_entry(i)
            with builder.if_else(found) as (if_common, if_not_common):
                with if_common:
                    # Removals do not shrink the table inside the loop
                    self._remove_entry(payload, entry, do_resize=False)
                with if_not_common:
                    # Additions may grow the table (default do_resize)
                    self._add_entry(payload, entry, key, h)

        # Final downsize
        self.downsize(self.payload.used)
656
657    def issubset(self, other, strict=False):
658        context = self._context
659        builder = self._builder
660        payload = self.payload
661        other_payload = other.payload
662
663        cmp_op = '<' if strict else '<='
664
665        res = cgutils.alloca_once_value(builder, cgutils.true_bit)
666        with builder.if_else(
667            builder.icmp_unsigned(cmp_op, payload.used, other_payload.used)
668            ) as (if_smaller, if_larger):
669            with if_larger:
670                # self larger than other => self cannot possibly a subset
671                builder.store(cgutils.false_bit, res)
672            with if_smaller:
673                # check whether each key of self is in other
674                with payload._iterate() as loop:
675                    entry = loop.entry
676                    found, _ = other_payload._lookup(entry.key, entry.hash)
677                    with builder.if_then(builder.not_(found)):
678                        builder.store(cgutils.false_bit, res)
679                        loop.do_break()
680
681        return builder.load(res)
682
    def isdisjoint(self, other):
        """
        Return an LLVM boolean telling whether this set and *other* have
        no element in common.
        """
        context = self._context
        builder = self._builder
        payload = self.payload
        other_payload = other.payload

        # Result slot: stays true unless a common element is found
        res = cgutils.alloca_once_value(builder, cgutils.true_bit)

        def check(smaller, larger):
            # Loop over the smaller of the two, and search in the larger
            with smaller._iterate() as loop:
                entry = loop.entry
                found, _ = larger._lookup(entry.key, entry.hash)
                with builder.if_then(found):
                    # Common element => not disjoint, stop iterating
                    builder.store(cgutils.false_bit, res)
                    loop.do_break()

        with builder.if_else(
            builder.icmp_unsigned('>', payload.used, other_payload.used)
            ) as (if_larger, otherwise):

            with if_larger:
                # len(self) > len(other)
                check(other_payload, payload)

            with otherwise:
                # len(self) <= len(other)
                check(payload, other_payload)

        return builder.load(res)
713
714    def equals(self, other):
715        context = self._context
716        builder = self._builder
717        payload = self.payload
718        other_payload = other.payload
719
720        res = cgutils.alloca_once_value(builder, cgutils.true_bit)
721        with builder.if_else(
722            builder.icmp_unsigned('==', payload.used, other_payload.used)
723            ) as (if_same_size, otherwise):
724            with if_same_size:
725                # same sizes => check whether each key of self is in other
726                with payload._iterate() as loop:
727                    entry = loop.entry
728                    found, _ = other_payload._lookup(entry.key, entry.hash)
729                    with builder.if_then(builder.not_(found)):
730                        builder.store(cgutils.false_bit, res)
731                        loop.do_break()
732            with otherwise:
733                # different sizes => cannot possibly be equal
734                builder.store(cgutils.false_bit, res)
735
736        return builder.load(res)
737
738    @classmethod
739    def allocate_ex(cls, context, builder, set_type, nitems=None):
740        """
741        Allocate a SetInstance with its storage.
742        Return a (ok, instance) tuple where *ok* is a LLVM boolean and
743        *instance* is a SetInstance object (the object's contents are
744        only valid when *ok* is true).
745        """
746        intp_t = context.get_value_type(types.intp)
747
748        if nitems is None:
749            nentries = ir.Constant(intp_t, MINSIZE)
750        else:
751            if isinstance(nitems, int):
752                nitems = ir.Constant(intp_t, nitems)
753            nentries = cls.choose_alloc_size(context, builder, nitems)
754
755        self = cls(context, builder, set_type, None)
756        ok = self._allocate_payload(nentries)
757        return ok, self
758
759    @classmethod
760    def allocate(cls, context, builder, set_type, nitems=None):
761        """
762        Allocate a SetInstance with its storage.  Same as allocate_ex(),
763        but return an initialized *instance*.  If allocation failed,
764        control is transferred to the caller using the target's current
765        call convention.
766        """
767        ok, self = cls.allocate_ex(context, builder, set_type, nitems)
768        with builder.if_then(builder.not_(ok), likely=False):
769            context.call_conv.return_user_exc(builder, MemoryError,
770                                              ("cannot allocate set",))
771        return self
772
773    @classmethod
774    def from_meminfo(cls, context, builder, set_type, meminfo):
775        """
776        Allocate a new set instance pointing to an existing payload
777        (a meminfo pointer).
778        Note the parent field has to be filled by the caller.
779        """
780        self = cls(context, builder, set_type, None)
781        self._set.meminfo = meminfo
782        self._set.parent = context.get_constant_null(types.pyobject)
783        context.nrt.incref(builder, set_type, self.value)
784        # Payload is part of the meminfo, no need to touch it
785        return self
786
    @classmethod
    def choose_alloc_size(cls, context, builder, nitems):
        """
        Choose a suitable number of entries for the given number of items.

        *nitems* is an LLVM integer value.  The result is obtained by
        doubling MINSIZE until it is at least ``2 * nitems``.  The builder
        is left positioned in a fresh "calcsize.end" block.
        """
        intp_t = nitems.type
        one = ir.Constant(intp_t, 1)
        minsize = ir.Constant(intp_t, MINSIZE)

        # Ensure number of entries >= 2 * used
        min_entries = builder.shl(nitems, one)
        # Find out first suitable power of 2, starting from MINSIZE
        size_p = cgutils.alloca_once_value(builder, minsize)

        bb_body = builder.append_basic_block("calcsize.body")
        bb_end = builder.append_basic_block("calcsize.end")

        builder.branch(bb_body)

        with builder.goto_block(bb_body):
            size = builder.load(size_p)
            is_large_enough = builder.icmp_unsigned('>=', size, min_entries)
            with builder.if_then(is_large_enough, likely=False):
                builder.branch(bb_end)
            # Too small: double the candidate size and try again
            next_size = builder.shl(size, one)
            builder.store(next_size, size_p)
            builder.branch(bb_body)

        builder.position_at_end(bb_end)
        return builder.load(size_p)
817
    def upsize(self, nitems):
        """
        When adding to the set, ensure it is properly sized for the given
        number of used entries.

        *nitems* is an LLVM integer value.  A resize is triggered once
        ``2 * nitems`` reaches the current table size; the size is then
        multiplied by 4 until it exceeds ``2 * nitems``.
        """
        context = self._context
        builder = self._builder
        intp_t = nitems.type

        one = ir.Constant(intp_t, 1)
        two = ir.Constant(intp_t, 2)

        payload = self.payload

        # Ensure number of entries >= 2 * used
        min_entries = builder.shl(nitems, one)
        size = builder.add(payload.mask, one)
        need_resize = builder.icmp_unsigned('>=', min_entries, size)

        with builder.if_then(need_resize, likely=False):
            # Find out next suitable size
            new_size_p = cgutils.alloca_once_value(builder, size)

            bb_body = builder.append_basic_block("calcsize.body")
            bb_end = builder.append_basic_block("calcsize.end")

            builder.branch(bb_body)

            with builder.goto_block(bb_body):
                # Multiply by 4 (ensuring size remains a power of two)
                new_size = builder.load(new_size_p)
                new_size = builder.shl(new_size, two)
                builder.store(new_size, new_size_p)
                # Keep growing while the table is still too small
                is_too_small = builder.icmp_unsigned('>=', min_entries, new_size)
                builder.cbranch(is_too_small, bb_body, bb_end)

            builder.position_at_end(bb_end)

            new_size = builder.load(new_size_p)
            if DEBUG_ALLOCS:
                context.printf(builder,
                               "upsize to %zd items: current size = %zd, "
                               "min entries = %zd, new size = %zd\n",
                               nitems, size, min_entries, new_size)
            self._resize(payload, new_size, "cannot grow set")
863
    def downsize(self, nitems):
        """
        When removing from the set, ensure it is properly sized for the given
        number of used entries.

        *nitems* is an LLVM integer value (its type is used for all the
        constants below).  Code is emitted that shrinks the table only when
        its current size is both >= 4 * max(2 * nitems, MINSIZE) and
        > MINSIZE.  In that case the size is halved repeatedly for as long
        as the result stays >= max(2 * nitems, MINSIZE), and the payload is
        rebuilt at that size through self._resize().
        """
        context = self._context
        builder = self._builder
        intp_t = nitems.type

        one = ir.Constant(intp_t, 1)
        two = ir.Constant(intp_t, 2)
        minsize = ir.Constant(intp_t, MINSIZE)

        payload = self.payload

        # Ensure entries >= max(2 * used, MINSIZE)
        # (shifting left by one doubles nitems; select() then takes the max)
        min_entries = builder.shl(nitems, one)
        min_entries = builder.select(builder.icmp_unsigned('>=', min_entries, minsize),
                                     min_entries, minsize)
        # Shrink only if size >= 4 * min_entries && size > MINSIZE
        # (shifting left by two multiplies by 4; size is mask + 1)
        max_size = builder.shl(min_entries, two)
        size = builder.add(payload.mask, one)
        need_resize = builder.and_(
            builder.icmp_unsigned('<=', max_size, size),
            builder.icmp_unsigned('<', minsize, size))

        with builder.if_then(need_resize, likely=False):
            # Find out next suitable size; the candidate lives in a stack
            # slot initialized with the current size.
            new_size_p = cgutils.alloca_once_value(builder, size)

            bb_body = builder.append_basic_block("calcsize.body")
            bb_end = builder.append_basic_block("calcsize.end")

            builder.branch(bb_body)

            with builder.goto_block(bb_body):
                # Divide by 2 (ensuring size remains a power of two)
                new_size = builder.load(new_size_p)
                new_size = builder.lshr(new_size, one)
                # Keep current size if new size would be < min_entries:
                # leave the loop *without* storing the halved candidate.
                is_too_small = builder.icmp_unsigned('>', min_entries, new_size)
                with builder.if_then(is_too_small):
                    builder.branch(bb_end)
                # Candidate is still large enough: record it and try again.
                builder.store(new_size, new_size_p)
                builder.branch(bb_body)

            builder.position_at_end(bb_end)

            # new_size >= MINSIZE holds here without an explicit clamp:
            # min_entries >= MINSIZE and the loop above never stores a
            # value smaller than min_entries.
            new_size = builder.load(new_size_p)
            # At this point, new_size should be < size if the factors
            # above were chosen carefully!

            if DEBUG_ALLOCS:
                context.printf(builder,
                               "downsize to %zd items: current size = %zd, "
                               "min entries = %zd, new size = %zd\n",
                               nitems, size, min_entries, new_size)
            self._resize(payload, new_size, "cannot shrink set")
923
924    def _resize(self, payload, nentries, errmsg):
925        """
926        Resize the payload to the given number of entries.
927
928        CAUTION: *nentries* must be a power of 2!
929        """
930        context = self._context
931        builder = self._builder
932
933        # Allocate new entries
934        old_payload = payload
935
936        ok = self._allocate_payload(nentries, realloc=True)
937        with builder.if_then(builder.not_(ok), likely=False):
938            context.call_conv.return_user_exc(builder, MemoryError,
939                                              (errmsg,))
940
941        # Re-insert old entries
942        payload = self.payload
943        with old_payload._iterate() as loop:
944            entry = loop.entry
945            self._add_key(payload, entry.key, entry.hash,
946                          do_resize=False)
947
948        self._free_payload(old_payload.ptr)
949
950    def _replace_payload(self, nentries):
951        """
952        Replace the payload with a new empty payload with the given number
953        of entries.
954
955        CAUTION: *nentries* must be a power of 2!
956        """
957        context = self._context
958        builder = self._builder
959
960        # Free old payload
961        self._free_payload(self.payload.ptr)
962
963        ok = self._allocate_payload(nentries, realloc=True)
964        with builder.if_then(builder.not_(ok), likely=False):
965            context.call_conv.return_user_exc(builder, MemoryError,
966                                              ("cannot reallocate set",))
967
968    def _allocate_payload(self, nentries, realloc=False):
969        """
970        Allocate and initialize payload for the given number of entries.
971        If *realloc* is True, the existing meminfo is reused.
972
973        CAUTION: *nentries* must be a power of 2!
974        """
975        context = self._context
976        builder = self._builder
977
978        ok = cgutils.alloca_once_value(builder, cgutils.true_bit)
979
980        intp_t = context.get_value_type(types.intp)
981        zero = ir.Constant(intp_t, 0)
982        one = ir.Constant(intp_t, 1)
983
984        payload_type = context.get_data_type(types.SetPayload(self._ty))
985        payload_size = context.get_abi_sizeof(payload_type)
986        entry_size = self._entrysize
987        # Account for the fact that the payload struct already contains an entry
988        payload_size -= entry_size
989
990        # Total allocation size = <payload header size> + nentries * entry_size
991        allocsize, ovf = cgutils.muladd_with_overflow(builder, nentries,
992                                                      ir.Constant(intp_t, entry_size),
993                                                      ir.Constant(intp_t, payload_size))
994        with builder.if_then(ovf, likely=False):
995            builder.store(cgutils.false_bit, ok)
996
997        with builder.if_then(builder.load(ok), likely=True):
998            if realloc:
999                meminfo = self._set.meminfo
1000                ptr = context.nrt.meminfo_varsize_alloc(builder, meminfo,
1001                                                        size=allocsize)
1002                alloc_ok = cgutils.is_null(builder, ptr)
1003            else:
1004                meminfo = context.nrt.meminfo_new_varsize(builder, size=allocsize)
1005                alloc_ok = cgutils.is_null(builder, meminfo)
1006
1007            with builder.if_else(cgutils.is_null(builder, meminfo),
1008                                 likely=False) as (if_error, if_ok):
1009                with if_error:
1010                    builder.store(cgutils.false_bit, ok)
1011                with if_ok:
1012                    if not realloc:
1013                        self._set.meminfo = meminfo
1014                        self._set.parent = context.get_constant_null(types.pyobject)
1015                    payload = self.payload
1016                    # Initialize entries to 0xff (EMPTY)
1017                    cgutils.memset(builder, payload.ptr, allocsize, 0xFF)
1018                    payload.used = zero
1019                    payload.fill = zero
1020                    payload.finger = zero
1021                    new_mask = builder.sub(nentries, one)
1022                    payload.mask = new_mask
1023
1024                    if DEBUG_ALLOCS:
1025                        context.printf(builder,
1026                                       "allocated %zd bytes for set at %p: mask = %zd\n",
1027                                       allocsize, payload.ptr, new_mask)
1028
1029        return builder.load(ok)
1030
1031    def _free_payload(self, ptr):
1032        """
1033        Free an allocated old payload at *ptr*.
1034        """
1035        self._context.nrt.meminfo_varsize_free(self._builder, self.meminfo, ptr)
1036
1037    def _copy_payload(self, src_payload):
1038        """
1039        Raw-copy the given payload into self.
1040        """
1041        context = self._context
1042        builder = self._builder
1043
1044        ok = cgutils.alloca_once_value(builder, cgutils.true_bit)
1045
1046        intp_t = context.get_value_type(types.intp)
1047        zero = ir.Constant(intp_t, 0)
1048        one = ir.Constant(intp_t, 1)
1049
1050        payload_type = context.get_data_type(types.SetPayload(self._ty))
1051        payload_size = context.get_abi_sizeof(payload_type)
1052        entry_size = self._entrysize
1053        # Account for the fact that the payload struct already contains an entry
1054        payload_size -= entry_size
1055
1056        mask = src_payload.mask
1057        nentries = builder.add(one, mask)
1058
1059        # Total allocation size = <payload header size> + nentries * entry_size
1060        # (note there can't be any overflow since we're reusing an existing
1061        #  payload's parameters)
1062        allocsize = builder.add(ir.Constant(intp_t, payload_size),
1063                                builder.mul(ir.Constant(intp_t, entry_size),
1064                                            nentries))
1065
1066        with builder.if_then(builder.load(ok), likely=True):
1067            meminfo = context.nrt.meminfo_new_varsize(builder, size=allocsize)
1068            alloc_ok = cgutils.is_null(builder, meminfo)
1069
1070            with builder.if_else(cgutils.is_null(builder, meminfo),
1071                                 likely=False) as (if_error, if_ok):
1072                with if_error:
1073                    builder.store(cgutils.false_bit, ok)
1074                with if_ok:
1075                    self._set.meminfo = meminfo
1076                    payload = self.payload
1077                    payload.used = src_payload.used
1078                    payload.fill = src_payload.fill
1079                    payload.finger = zero
1080                    payload.mask = mask
1081                    cgutils.raw_memcpy(builder, payload.entries,
1082                                       src_payload.entries, nentries,
1083                                       entry_size)
1084
1085                    if DEBUG_ALLOCS:
1086                        context.printf(builder,
1087                                       "allocated %zd bytes for set at %p: mask = %zd\n",
1088                                       allocsize, payload.ptr, mask)
1089
1090        return builder.load(ok)
1091
1092
class SetIterInstance(object):
    """
    Code-generation helper wrapping a native set iterator value.

    The iterator struct stores the meminfo of the set being iterated and a
    pointer to a slot holding the index at which the next scan starts.
    """

    def __init__(self, context, builder, iter_type, iter_val):
        self._context = context
        self._builder = builder
        self._ty = iter_type
        self._iter = context.make_helper(builder, iter_type, iter_val)
        # Reach the set's payload through the meminfo held by the iterator
        data_ptr = context.nrt.meminfo_data(builder, self.meminfo)
        self._payload = _SetPayload(context, builder, iter_type.container,
                                    data_ptr)

    @classmethod
    def from_set(cls, context, builder, iter_type, set_val):
        """
        Create an iterator over the set value *set_val*, positioned at
        index 0.
        """
        source = SetInstance(context, builder, iter_type.container, set_val)
        it = cls(context, builder, iter_type, None)
        start = context.get_constant(types.intp, 0)
        it._iter.index = cgutils.alloca_once_value(builder, start)
        it._iter.meminfo = source.meminfo
        return it

    @property
    def value(self):
        """The underlying iterator struct value."""
        return self._iter._getvalue()

    @property
    def meminfo(self):
        """The meminfo stored in the iterator struct."""
        return self._iter.meminfo

    @property
    def index(self):
        """The current index, loaded from the iterator's index slot."""
        return self._builder.load(self._iter.index)

    @index.setter
    def index(self, value):
        self._builder.store(value, self._iter.index)

    def iternext(self, result):
        """
        Emit code producing the next key into *result*, or marking it
        exhausted when no entry is found from the current index onwards.
        """
        start = self.index
        step = ir.Constant(start.type, 1)

        # Assume exhaustion; this is overridden below as soon as an entry
        # is found.
        result.set_exhausted()

        with self._payload._iterate(start=start) as loop:
            # An entry was found
            found = loop.entry
            result.set_valid()
            result.yield_(found.key)
            # Resume just past this entry on the next call
            self.index = self._builder.add(loop.index, step)
            loop.do_break()
1142
1143
1144#-------------------------------------------------------------------------------
1145# Constructors
1146
def build_set(context, builder, set_type, items):
    """
    Emit code building a new set of type *set_type* from the sequence of
    LLVM values *items*, and return it as a new reference.
    """
    count = len(items)
    new_set = SetInstance.allocate(context, builder, set_type, count)

    # Unrolling one insertion per item would bloat the generated code, so
    # the items are packed into an LLVM array and inserted from a loop.
    packed = cgutils.pack_array(builder, items)
    packed_ptr = cgutils.alloca_once_value(builder, packed)

    llcount = context.get_constant(types.intp, count)
    with cgutils.for_range(builder, llcount) as loop:
        slot = cgutils.gep(builder, packed_ptr, 0, loop.index)
        new_set.add(builder.load(slot))

    return impl_ret_new_ref(context, builder, set_type, new_set.value)
1165
1166
@lower_builtin(set)
def set_empty_constructor(context, builder, sig, args):
    """Lower ``set()`` called without arguments: allocate a new set."""
    result_type = sig.return_type
    empty = SetInstance.allocate(context, builder, result_type)
    return impl_ret_new_ref(context, builder, result_type, empty.value)
1172
@lower_builtin(set, types.IterableType)
def set_constructor(context, builder, sig, args):
    """Lower ``set(iterable)``: allocate a set and add every item to it."""
    result_type = sig.return_type
    (iterable_type,) = sig.args
    (iterable,) = args

    # The length (when the argument has one) is handed to the allocator so
    # that the set is presized and resizes are avoided.
    size_hint = call_len(context, builder, iterable_type, iterable)
    new_set = SetInstance.allocate(context, builder, result_type, size_hint)
    with for_iter(context, builder, iterable_type, iterable) as loop:
        new_set.add(loop.value)

    return impl_ret_new_ref(context, builder, result_type, new_set.value)
1187
1188
1189#-------------------------------------------------------------------------------
1190# Various operations
1191
@lower_builtin(len, types.Set)
def set_len(context, builder, sig, args):
    """Lower ``len(s)`` for a set."""
    set_type = sig.args[0]
    set_val = args[0]
    return SetInstance(context, builder, set_type, set_val).get_size()
1196
@lower_builtin(operator.contains, types.Set, types.Any)
def in_set(context, builder, sig, args):
    """Lower ``item in s`` for a set."""
    set_val, item = args[0], args[1]
    container = SetInstance(context, builder, sig.args[0], set_val)
    return container.contains(item)
1201
@lower_builtin('getiter', types.Set)
def getiter_set(context, builder, sig, args):
    """Lower ``iter(s)`` for a set, returning a set iterator value."""
    iter_type = sig.return_type
    it = SetIterInstance.from_set(context, builder, iter_type, args[0])
    return impl_ret_borrowed(context, builder, iter_type, it.value)
1206
@lower_builtin('iternext', types.SetIter)
@iternext_impl(RefType.BORROWED)
def iternext_listiter(context, builder, sig, args, result):
    """Lower ``next()`` on a set iterator (despite the "listiter" name)."""
    iter_type = sig.args[0]
    iter_val = args[0]
    SetIterInstance(context, builder, iter_type, iter_val).iternext(result)
1212
1213
1214#-------------------------------------------------------------------------------
1215# Methods
1216
1217# One-item-at-a-time operations
1218
@lower_builtin("set.add", types.Set, types.Any)
def set_add(context, builder, sig, args):
    """Lower ``s.add(item)``; the result is the dummy value."""
    set_val, item = args[0], args[1]
    target = SetInstance(context, builder, sig.args[0], set_val)
    target.add(item)
    return context.get_dummy_value()
1226
@lower_builtin("set.discard", types.Set, types.Any)
def set_discard(context, builder, sig, args):
    """Lower ``s.discard(item)``; the result is the dummy value."""
    set_val, item = args[0], args[1]
    target = SetInstance(context, builder, sig.args[0], set_val)
    target.discard(item)
    return context.get_dummy_value()
1234
@lower_builtin("set.pop", types.Set)
def set_pop(context, builder, sig, args):
    """
    Lower ``s.pop()``.  A KeyError is returned to the caller of the
    compiled function when the set has no used entries.
    """
    target = SetInstance(context, builder, sig.args[0], args[0])
    nused = target.payload.used
    is_empty = cgutils.is_null(builder, nused)
    with builder.if_then(is_empty, likely=False):
        context.call_conv.return_user_exc(
            builder, KeyError, ("set.pop(): empty set",))

    return target.pop()
1244
@lower_builtin("set.remove", types.Set, types.Any)
def set_remove(context, builder, sig, args):
    """
    Lower ``s.remove(item)``.  A KeyError is returned to the caller of the
    compiled function when discarding reports that the key was not found.
    """
    set_val, item = args[0], args[1]
    target = SetInstance(context, builder, sig.args[0], set_val)
    was_present = target.discard(item)
    missing = builder.not_(was_present)
    with builder.if_then(missing, likely=False):
        context.call_conv.return_user_exc(
            builder, KeyError, ("set.remove(): key not in set",))

    return context.get_dummy_value()
1255
1256
1257# Mutating set operations
1258
@lower_builtin("set.clear", types.Set)
def set_clear(context, builder, sig, args):
    """Lower ``s.clear()``; the result is the dummy value."""
    SetInstance(context, builder, sig.args[0], args[0]).clear()
    return context.get_dummy_value()
1264
@lower_builtin("set.copy", types.Set)
def set_copy(context, builder, sig, args):
    """Lower ``s.copy()``, returning the copy as a new reference."""
    original = SetInstance(context, builder, sig.args[0], args[0])
    duplicate = original.copy()
    return impl_ret_new_ref(context, builder, sig.return_type,
                            duplicate.value)
1270
@lower_builtin("set.difference_update", types.Set, types.IterableType)
def set_difference_update(context, builder, sig, args):
    """Lower ``a.difference_update(b)``; the result is the dummy value."""
    # NOTE(review): registered for any IterableType, yet the second argument
    # is wrapped as a SetInstance -- confirm callers only pass sets here.
    lhs_type, rhs_type = sig.args[0], sig.args[1]
    lhs = SetInstance(context, builder, lhs_type, args[0])
    rhs = SetInstance(context, builder, rhs_type, args[1])
    lhs.difference(rhs)
    return context.get_dummy_value()
1279
@lower_builtin("set.intersection_update", types.Set, types.Set)
def set_intersection_update(context, builder, sig, args):
    """Lower ``a.intersection_update(b)``; the result is the dummy value."""
    lhs_type, rhs_type = sig.args[0], sig.args[1]
    lhs = SetInstance(context, builder, lhs_type, args[0])
    rhs = SetInstance(context, builder, rhs_type, args[1])
    lhs.intersect(rhs)
    return context.get_dummy_value()
1288
@lower_builtin("set.symmetric_difference_update", types.Set, types.Set)
def set_symmetric_difference_update(context, builder, sig, args):
    """
    Lower ``a.symmetric_difference_update(b)``; the result is the dummy
    value.
    """
    lhs_type, rhs_type = sig.args[0], sig.args[1]
    lhs = SetInstance(context, builder, lhs_type, args[0])
    rhs = SetInstance(context, builder, rhs_type, args[1])
    lhs.symmetric_difference(rhs)
    return context.get_dummy_value()
1297
@lower_builtin("set.update", types.Set, types.IterableType)
def set_update(context, builder, sig, args):
    """
    Lower ``s.update(iterable)``: add every item of the iterable to the
    set.  The result is the dummy value.
    """
    iterable_type = sig.args[1]
    iterable = args[1]
    target = SetInstance(context, builder, sig.args[0], args[0])

    # When the argument has a len(), optimistically presize the set to
    # len(set) + len(items), i.e. as if few items were already present.
    length = call_len(context, builder, iterable_type, iterable)
    presized = length is not None
    if presized:
        wanted = builder.add(target.payload.used, length)
        target.upsize(wanted)

    with for_iter(context, builder, iterable_type, iterable) as loop:
        target.add(loop.value)

    if presized:
        # The presizing may have overshot if many items were already in
        # the set; give the surplus back.
        target.downsize(target.payload.used)

    return context.get_dummy_value()
1319
# Register the in-place operators (&=, |=, -=, ^=) on a pair of sets by
# delegating to the corresponding "*_update" lowering defined above.
for op_, op_impl in [
    (operator.iand, set_intersection_update),
    (operator.ior, set_update),
    (operator.isub, set_difference_update),
    (operator.ixor, set_symmetric_difference_update),
    ]:
    @lower_builtin(op_, types.Set, types.Set)
    def set_inplace(context, builder, sig, args, op_impl=op_impl):
        # op_impl is bound as a default argument so that each registered
        # function keeps the implementation of its own loop iteration
        # rather than the last one (closures bind late).
        assert sig.return_type == sig.args[0]
        op_impl(context, builder, sig, args)
        # The operator evaluates to its (mutated) left-hand operand.
        return impl_ret_borrowed(context, builder, sig.args[0], args[0])
1331
1332
1333# Set operations creating a new set
1334
@lower_builtin(operator.sub, types.Set, types.Set)
@lower_builtin("set.difference", types.Set, types.Set)
def set_difference(context, builder, sig, args):
    """Lower ``a - b`` / ``a.difference(b)``, producing a new set."""
    def difference_impl(a, b):
        # Work on a copy of the left operand so that it is left untouched
        result = a.copy()
        result.difference_update(b)
        return result

    return context.compile_internal(builder, difference_impl, sig, args)
1344
@lower_builtin(operator.and_, types.Set, types.Set)
@lower_builtin("set.intersection", types.Set, types.Set)
def set_intersection(context, builder, sig, args):
    """Lower ``a & b`` / ``a.intersection(b)``, producing a new set."""
    def intersection_impl(a, b):
        # Copy the strictly smaller operand (or b on a tie), then
        # intersect the copy with the other one.
        if len(a) < len(b):
            result = a.copy()
            result.intersection_update(b)
            return result
        result = b.copy()
        result.intersection_update(a)
        return result

    return context.compile_internal(builder, intersection_impl, sig, args)
1359
@lower_builtin(operator.xor, types.Set, types.Set)
@lower_builtin("set.symmetric_difference", types.Set, types.Set)
def set_symmetric_difference(context, builder, sig, args):
    """
    Lower ``a ^ b`` / ``a.symmetric_difference(b)``, producing a new set.
    """
    def symmetric_difference_impl(a, b):
        # Copy the strictly larger operand (or b on a tie), then apply the
        # other one to the copy.
        if len(a) > len(b):
            result = a.copy()
            result.symmetric_difference_update(b)
            return result
        result = b.copy()
        result.symmetric_difference_update(a)
        return result

    return context.compile_internal(
        builder, symmetric_difference_impl, sig, args)
1375
@lower_builtin(operator.or_, types.Set, types.Set)
@lower_builtin("set.union", types.Set, types.Set)
def set_union(context, builder, sig, args):
    """Lower ``a | b`` / ``a.union(b)``, producing a new set."""
    def union_impl(a, b):
        # Copy the strictly larger operand (or b on a tie), then add the
        # items of the other one to the copy.
        if len(a) > len(b):
            result = a.copy()
            result.update(b)
            return result
        result = b.copy()
        result.update(a)
        return result

    return context.compile_internal(builder, union_impl, sig, args)
1390
1391
1392# Predicates
1393
@lower_builtin("set.isdisjoint", types.Set, types.Set)
def set_isdisjoint(context, builder, sig, args):
    """Lower ``a.isdisjoint(b)`` for two sets."""
    lhs_type, rhs_type = sig.args[0], sig.args[1]
    lhs = SetInstance(context, builder, lhs_type, args[0])
    rhs = SetInstance(context, builder, rhs_type, args[1])
    return lhs.isdisjoint(rhs)
1400
@lower_builtin(operator.le, types.Set, types.Set)
@lower_builtin("set.issubset", types.Set, types.Set)
def set_issubset(context, builder, sig, args):
    """Lower ``a <= b`` / ``a.issubset(b)`` for two sets."""
    lhs_type, rhs_type = sig.args[0], sig.args[1]
    lhs = SetInstance(context, builder, lhs_type, args[0])
    rhs = SetInstance(context, builder, rhs_type, args[1])
    return lhs.issubset(rhs)
1408
@lower_builtin(operator.ge, types.Set, types.Set)
@lower_builtin("set.issuperset", types.Set, types.Set)
def set_issuperset(context, builder, sig, args):
    """
    Lower ``a >= b`` / ``a.issuperset(b)`` by swapping the operands of the
    subset test.
    """
    def superset_impl(a, b):
        return b.issubset(a)

    compiled = context.compile_internal(builder, superset_impl, sig, args)
    return compiled
1416
@lower_builtin(operator.eq, types.Set, types.Set)
def set_isdisjoint(context, builder, sig, args):
    """Lower ``a == b`` for two sets."""
    # NOTE(review): this function implements equality but reuses the name
    # of the "set.isdisjoint" lowering defined earlier, rebinding that
    # module-level name.  The name is kept so existing references still
    # resolve the same way.
    lhs_type, rhs_type = sig.args[0], sig.args[1]
    lhs = SetInstance(context, builder, lhs_type, args[0])
    rhs = SetInstance(context, builder, rhs_type, args[1])
    return lhs.equals(rhs)
1423
@lower_builtin(operator.ne, types.Set, types.Set)
def set_ne(context, builder, sig, args):
    """Lower ``a != b`` for two sets as the negation of ``a == b``."""
    def ne_impl(a, b):
        # Must stay spelled in terms of ==: writing a != b here would
        # dispatch back to this very lowering.
        return not a == b

    compiled = context.compile_internal(builder, ne_impl, sig, args)
    return compiled
1430
@lower_builtin(operator.lt, types.Set, types.Set)
def set_lt(context, builder, sig, args):
    """Lower ``a < b`` for two sets as a strict subset test."""
    lhs_type, rhs_type = sig.args[0], sig.args[1]
    lhs = SetInstance(context, builder, lhs_type, args[0])
    rhs = SetInstance(context, builder, rhs_type, args[1])
    return lhs.issubset(rhs, strict=True)
1437
@lower_builtin(operator.gt, types.Set, types.Set)
def set_gt(context, builder, sig, args):
    """Lower ``a > b`` for two sets by swapping the operands of ``<``."""
    def gt_impl(a, b):
        return b < a

    compiled = context.compile_internal(builder, gt_impl, sig, args)
    return compiled
1444
@lower_builtin(operator.is_, types.Set, types.Set)
def set_is(context, builder, sig, args):
    """
    Lower ``a is b`` for two sets by comparing their meminfo pointers as
    integers.
    """
    lhs = SetInstance(context, builder, sig.args[0], args[0])
    rhs = SetInstance(context, builder, sig.args[1], args[1])
    lhs_addr = builder.ptrtoint(lhs.meminfo, cgutils.intp_t)
    rhs_addr = builder.ptrtoint(rhs.meminfo, cgutils.intp_t)
    return builder.icmp_signed('==', lhs_addr, rhs_addr)
1452
1453
1454# -----------------------------------------------------------------------------
1455# Implicit casting
1456
@lower_cast(types.Set, types.Set)
def set_to_set(context, builder, fromty, toty, val):
    """
    Cast between two set types with the same element type; the value is
    passed through unchanged.
    """
    # Casting from non-reflected to reflected
    assert fromty.dtype == toty.dtype
    return val
1462