1import warnings
2
3import numpy as np
4
5cimport numpy as cnp
6from numpy cimport (
7    float32_t,
8    float64_t,
9    int8_t,
10    int16_t,
11    int32_t,
12    int64_t,
13    intp_t,
14    ndarray,
15    uint8_t,
16    uint16_t,
17    uint32_t,
18    uint64_t,
19)
20
21cnp.import_array()
22
23
24from pandas._libs cimport util
25from pandas._libs.hashtable cimport HashTable
26from pandas._libs.tslibs.nattype cimport c_NaT as NaT
27from pandas._libs.tslibs.period cimport is_period_object
28from pandas._libs.tslibs.timedeltas cimport _Timedelta
29from pandas._libs.tslibs.timestamps cimport _Timestamp
30
31from pandas._libs import algos, hashtable as _hash
32from pandas._libs.missing import checknull
33
34
cdef inline bint is_definitely_invalid_key(object val):
    """
    Report whether ``val`` can be rejected as a key because it is unhashable.

    Returns True when ``hash(val)`` raises TypeError and False otherwise.
    Any other exception raised by ``hash`` is not handled here.
    """
    cdef:
        bint unhashable = False

    try:
        hash(val)
    except TypeError:
        unhashable = True
    return unhashable
41
42
# Monotonic-increasing indexes with at least this many elements are searched
# by bisection in get_loc instead of populating a hash table
44_SIZE_CUTOFF = 1_000_000
45
46
cdef class IndexEngine:
    """
    Base engine translating index labels into integer positions.

    The index values are fetched lazily by calling ``vgetter``.  From them the
    engine derives and caches a hash table (``mapping``) plus uniqueness and
    monotonicity flags; ``clear_mapping`` resets all of that state.  The
    concrete hash table is supplied by ``_make_hash_table``, which this base
    class leaves unimplemented.
    """

    cdef readonly:
        object vgetter
        HashTable mapping
        bint over_size_threshold

    cdef:
        bint unique, monotonic_inc, monotonic_dec
        bint need_monotonic_check, need_unique_check

    def __init__(self, vgetter, n):
        """
        Parameters
        ----------
        vgetter : callable
            Zero-argument callable returning the index values.
        n : int
            Number of values; only used to decide whether the engine is at
            or above ``_SIZE_CUTOFF``.
        """
        self.vgetter = vgetter

        self.over_size_threshold = n >= _SIZE_CUTOFF
        self.clear_mapping()

    def __contains__(self, val: object) -> bool:
        """Return whether ``val`` is a key of the (populated) hash table."""
        # We assume before we get here:
        #  - val is hashable
        self._ensure_mapping_populated()
        return val in self.mapping

    cpdef get_loc(self, object val):
        """
        Locate ``val`` in the index.

        Returns
        -------
        int, slice or boolean ndarray
            An integer position when the match is unique; for duplicated
            labels a slice (monotonic increasing index) or a boolean mask.

        Raises
        ------
        TypeError
            If ``val`` is unhashable.
        KeyError
            If ``val`` is not present.
        """
        cdef:
            Py_ssize_t loc

        if is_definitely_invalid_key(val):
            raise TypeError(f"'{val}' is an invalid key")

        if self.over_size_threshold and self.is_monotonic_increasing:
            # Large sorted index: bisect instead of building the hash table.
            if not self.is_unique:
                return self._get_loc_duplicates(val)
            values = self._get_index_values()

            self._check_type(val)
            try:
                loc = _bin_search(values, val)  # .searchsorted(val, side='left')
            except TypeError:
                # GH#35788 e.g. val=None with float64 values
                raise KeyError(val)
            if loc >= len(values):
                raise KeyError(val)
            if values[loc] != val:
                raise KeyError(val)
            return loc

        self._ensure_mapping_populated()
        if not self.unique:
            return self._get_loc_duplicates(val)

        self._check_type(val)

        try:
            return self.mapping.get_item(val)
        except (TypeError, ValueError):
            raise KeyError(val)

    cdef inline _get_loc_duplicates(self, object val):
        """
        Locate ``val`` when the index may hold duplicates.

        On a monotonic increasing index the matches are contiguous, so the
        result is an int (one match) or a slice (several); otherwise the
        lookup falls back to an elementwise comparison.  Raises KeyError when
        nothing matches.
        """
        cdef:
            Py_ssize_t diff

        if self.is_monotonic_increasing:
            values = self._get_index_values()
            try:
                left = values.searchsorted(val, side='left')
                right = values.searchsorted(val, side='right')
            except TypeError:
                # e.g. GH#29189 get_loc(None) with a Float64Index
                raise KeyError(val)

            diff = right - left
            if diff == 0:
                raise KeyError(val)
            elif diff == 1:
                return left
            else:
                return slice(left, right)

        return self._maybe_get_bool_indexer(val)

    cdef _maybe_get_bool_indexer(self, object val):
        """Compare every index value with ``val`` and unpack the mask."""
        cdef:
            ndarray[uint8_t, ndim=1, cast=True] indexer

        indexer = self._get_index_values() == val
        return self._unpack_bool_indexer(indexer, val)

    cdef _unpack_bool_indexer(self,
                              ndarray[uint8_t, ndim=1, cast=True] indexer,
                              object val):
        """
        Reduce a boolean mask to the most compact answer.

        Returns the mask itself when several entries are set, the integer
        position when exactly one is set, and raises ``KeyError(val)`` when
        none is.
        """
        cdef:
            ndarray[intp_t, ndim=1] found
            # Py_ssize_t rather than C int: the number of hits is bounded by
            # the array length, which may not fit in 32 bits.
            Py_ssize_t count

        found = np.where(indexer)[0]
        count = len(found)

        if count > 1:
            return indexer
        if count == 1:
            return int(found[0])

        raise KeyError(val)

    def sizeof(self, deep: bool = False) -> int:
        """
        Return the size reported by the hash table, or 0 if it has not been
        populated yet.
        """
        if not self.is_mapping_populated:
            return 0
        return self.mapping.sizeof(deep=deep)

    def __sizeof__(self) -> int:
        return self.sizeof()

    @property
    def is_unique(self) -> bool:
        """Whether all index values are distinct (computed on first use)."""
        if self.need_unique_check:
            self._do_unique_check()

        return self.unique == 1

    cdef inline _do_unique_check(self):
        # Populating the mapping settles uniqueness as a side effect.
        self._ensure_mapping_populated()

    @property
    def is_monotonic_increasing(self) -> bool:
        """Cached ``monotonic_inc`` flag, computed on first use."""
        if self.need_monotonic_check:
            self._do_monotonic_check()

        return self.monotonic_inc == 1

    @property
    def is_monotonic_decreasing(self) -> bool:
        """Cached ``monotonic_dec`` flag, computed on first use."""
        if self.need_monotonic_check:
            self._do_monotonic_check()

        return self.monotonic_dec == 1

    cdef inline _do_monotonic_check(self):
        """
        Compute and cache both monotonicity flags.

        A TypeError while fetching or checking the values is treated as
        "not monotonic".  The check may also prove uniqueness, in which case
        the uniqueness flag is cached as well.
        """
        cdef:
            bint is_unique
        try:
            values = self._get_index_values()
            self.monotonic_inc, self.monotonic_dec, is_unique = \
                self._call_monotonic(values)
        except TypeError:
            self.monotonic_inc = 0
            self.monotonic_dec = 0
            is_unique = 0

        self.need_monotonic_check = 0

        # we can only be sure of uniqueness if is_unique=1
        if is_unique:
            self.unique = 1
            self.need_unique_check = 0

    cdef _get_index_values(self):
        # The values are re-fetched from the getter on every call.
        return self.vgetter()

    cdef _call_monotonic(self, values):
        return algos.is_monotonic(values, timelike=False)

    def get_backfill_indexer(self, other: np.ndarray, limit=None) -> np.ndarray:
        """Delegate to ``algos.backfill`` with the index values and ``other``."""
        return algos.backfill(self._get_index_values(), other, limit=limit)

    def get_pad_indexer(self, other: np.ndarray, limit=None) -> np.ndarray:
        """Delegate to ``algos.pad`` with the index values and ``other``."""
        return algos.pad(self._get_index_values(), other, limit=limit)

    cdef _make_hash_table(self, Py_ssize_t n):
        # Not provided by the base class.
        raise NotImplementedError

    cdef _check_type(self, object val):
        # Raises TypeError if ``val`` is unhashable.
        hash(val)

    @property
    def is_mapping_populated(self) -> bool:
        """Whether the hash table has been built."""
        return self.mapping is not None

    cdef inline _ensure_mapping_populated(self):
        """
        Build the hash table if it does not exist yet.

        Comparing the table size with the number of values also answers the
        uniqueness question, so ``need_unique_check`` is cleared.
        """
        if not self.is_mapping_populated:

            values = self._get_index_values()
            self.mapping = self._make_hash_table(len(values))
            self._call_map_locations(values)

            if len(self.mapping) == len(values):
                self.unique = 1

        self.need_unique_check = 0

    cdef void _call_map_locations(self, values):
        self.mapping.map_locations(values)

    def clear_mapping(self):
        """Drop the hash table and reset every cached flag."""
        self.mapping = None
        self.need_monotonic_check = 1
        self.need_unique_check = 1

        self.unique = 0
        self.monotonic_inc = 0
        self.monotonic_dec = 0

    def get_indexer(self, values):
        """Look ``values`` up in the hash table via ``mapping.lookup``."""
        self._ensure_mapping_populated()
        return self.mapping.lookup(values)

    def get_indexer_non_unique(self, targets):
        """
        Return an indexer suitable for taking from a non unique index.

        Positions are emitted in the order of ``targets``; a target that
        occurs several times in the index contributes all of its positions.

        Returns
        -------
        result : ndarray[intp]
            Positions into the index, with -1 for each target not found.
        missing : ndarray[intp]
            Positions within ``targets`` of the values that were not found
            (they correspond to the -1 entries of ``result``).
        """
        cdef:
            ndarray values
            ndarray[intp_t] result, missing
            set stargets, remaining_stargets
            dict d = {}
            object val
            # Py_ssize_t rather than C int so the counters cannot overflow
            # before the Py_ssize_t allocation size does.
            Py_ssize_t count = 0, count_missing = 0
            Py_ssize_t i, j, n, n_t, n_alloc

        self._ensure_mapping_populated()
        values = np.array(self._get_index_values(), copy=False)
        stargets = set(targets)
        n = len(values)
        n_t = len(targets)
        # Start with a bounded buffer; it is grown in 10_000 steps below.
        if n > 10_000:
            n_alloc = 10_000
        else:
            n_alloc = n

        result = np.empty(n_alloc, dtype=np.intp)
        missing = np.empty(n_t, dtype=np.intp)

        # map each starget to its position in the index
        if stargets and len(stargets) < 5 and self.is_monotonic_increasing:
            # if there are few enough stargets and the index is monotonically
            # increasing, then use binary search for each starget
            remaining_stargets = set()
            for starget in stargets:
                try:
                    start = values.searchsorted(starget, side='left')
                    end = values.searchsorted(starget, side='right')
                except TypeError:  # e.g. if we tried to search for string in int array
                    remaining_stargets.add(starget)
                else:
                    if start != end:
                        d[starget] = list(range(start, end))

            stargets = remaining_stargets

        if stargets:
            # otherwise, map by iterating through all items in the index
            for i in range(n):
                val = values[i]
                if val in stargets:
                    d.setdefault(val, []).append(i)

        for i in range(n_t):
            val = targets[i]

            # found
            if val in d:
                for j in d[val]:

                    # realloc if needed
                    if count >= n_alloc:
                        n_alloc += 10_000
                        result = np.resize(result, n_alloc)

                    result[count] = j
                    count += 1

            # value not found
            else:

                if count >= n_alloc:
                    n_alloc += 10_000
                    result = np.resize(result, n_alloc)
                result[count] = -1
                count += 1
                missing[count_missing] = i
                count_missing += 1

        return result[0:count], missing[0:count_missing]
342
343
cdef Py_ssize_t _bin_search(ndarray values, object val) except -1:
    """
    Left-most insertion point of ``val`` in the sorted array ``values``.

    Equivalent in intent to ``values.searchsorted(val, side='left')``.

    Parameters
    ----------
    values : ndarray
        Values sorted in increasing order.
    val : object
        Value to locate.

    Returns
    -------
    Py_ssize_t
        Position of the first element equal to ``val`` if there is one,
        otherwise the position where ``val`` would be inserted
        (``len(values)`` when ``val`` is greater than every element, and 0
        for an empty array).

    Raises
    ------
    TypeError
        Propagated from comparing ``val`` with the array elements.
    """
    cdef:
        Py_ssize_t n = len(values)
        Py_ssize_t mid = 0, lo = 0, hi = n - 1
        object pval

    if n == 0:
        # Nothing to compare against; indexing values[0] would raise.
        return 0

    # Also covers the single-element case (hi == 0), which must still be
    # compared instead of being reported as "past the end".
    if val > values[hi]:
        return n

    while lo < hi:
        mid = (lo + hi) // 2
        pval = values[mid]
        if val < pval:
            hi = mid
        elif val > pval:
            lo = mid + 1
        else:
            # Walk back to the first of a run of equal values.
            while mid > 0 and val == values[mid - 1]:
                mid -= 1
            return mid

    if val <= values[mid]:
        return mid
    else:
        return mid + 1
368
369
cdef class ObjectEngine(IndexEngine):
    """
    Index Engine for use with object-dtype Index, namely the base class Index.
    """
    cdef _make_hash_table(self, Py_ssize_t n):
        # Labels are stored in a PyObjectHashTable created with ``n``.
        table = _hash.PyObjectHashTable(n)
        return table
376
377
cdef class DatetimeEngine(Int64Engine):
    """
    Engine whose values are handled through their int64 ('i8') view and whose
    lookups accept ``_Timestamp`` scalars or ``NaT``.
    """

    cdef str _get_box_dtype(self):
        # dtype that array arguments must have to be looked up
        return 'M8[ns]'

    cdef int64_t _unbox_scalar(self, scalar) except? -1:
        """
        Return ``scalar.value`` for a ``_Timestamp`` or ``NaT``; raise
        TypeError for anything else.
        """
        # NB: caller is responsible for ensuring tzawareness compat
        #  before we get here
        if not (isinstance(scalar, _Timestamp) or scalar is NaT):
            raise TypeError(scalar)
        return scalar.value

    def __contains__(self, val: object) -> bool:
        """
        Return whether ``val`` is present in the index.

        Raises
        ------
        TypeError
            If ``val`` cannot be unboxed by ``_unbox_scalar``.
        """
        # We assume before we get here:
        #  - val is hashable
        cdef:
            int64_t loc, conv

        conv = self._unbox_scalar(val)
        if self.over_size_threshold and self.is_monotonic_increasing:
            if not self.is_unique:
                # _get_loc_duplicates returns a position/slice (position 0 is
                # falsy!) or raises KeyError, so translate it explicitly.
                try:
                    self._get_loc_duplicates(conv)
                except KeyError:
                    return False
                return True
            values = self._get_index_values()
            loc = values.searchsorted(conv, side='left')
            # searchsorted returns len(values) for a key past the end
            return loc < len(values) and values[loc] == conv

        self._ensure_mapping_populated()
        return conv in self.mapping

    cdef _get_index_values(self):
        return self.vgetter().view('i8')

    cdef _call_monotonic(self, values):
        return algos.is_monotonic(values, timelike=True)

    cpdef get_loc(self, object val):
        """
        Locate ``val`` in the index.

        Returns
        -------
        int, slice or boolean ndarray
            Same conventions as the duplicate-aware lookup of the base
            engine.

        Raises
        ------
        TypeError
            If ``val`` is unhashable.
        KeyError
            If ``val`` is of the wrong scalar type or is not present.
        """
        # NB: the caller is responsible for ensuring that we are called
        #  with either a Timestamp or NaT (Timedelta or NaT for TimedeltaEngine)

        cdef:
            int64_t loc
        if is_definitely_invalid_key(val):
            raise TypeError(f"'{val}' is an invalid key")

        try:
            conv = self._unbox_scalar(val)
        except TypeError:
            raise KeyError(val)

        # Large sorted index: bisect on the int64 view instead of building
        # the hash table.
        if self.over_size_threshold and self.is_monotonic_increasing:
            if not self.is_unique:
                return self._get_loc_duplicates(conv)
            values = self._get_index_values()

            loc = values.searchsorted(conv, side='left')

            if loc == len(values) or values[loc] != conv:
                raise KeyError(val)
            return loc

        self._ensure_mapping_populated()
        if not self.unique:
            return self._get_loc_duplicates(conv)

        try:
            return self.mapping.get_item(conv)
        except KeyError:
            # re-raise with the original scalar rather than its int64 value
            raise KeyError(val)

    def get_indexer_non_unique(self, targets):
        # we may get datetime64[ns] or timedelta64[ns], cast these to int64
        return super().get_indexer_non_unique(targets.view("i8"))

    def get_indexer(self, values):
        """
        Look ``values`` up in the hash table; an array whose dtype differs
        from ``_get_box_dtype()`` matches nothing (all -1).
        """
        self._ensure_mapping_populated()
        if values.dtype != self._get_box_dtype():
            return np.repeat(-1, len(values)).astype('i4')
        values = np.asarray(values).view('i8')
        return self.mapping.lookup(values)

    def get_pad_indexer(self, other: np.ndarray, limit=None) -> np.ndarray:
        """
        ``algos.pad`` on the int64 views; an ``other`` whose dtype differs
        from ``_get_box_dtype()`` matches nothing (all -1).
        """
        if other.dtype != self._get_box_dtype():
            return np.repeat(-1, len(other)).astype('i4')
        other = np.asarray(other).view('i8')
        return algos.pad(self._get_index_values(), other, limit=limit)

    def get_backfill_indexer(self, other: np.ndarray, limit=None) -> np.ndarray:
        """
        ``algos.backfill`` on the int64 views; an ``other`` whose dtype
        differs from ``_get_box_dtype()`` matches nothing (all -1).
        """
        if other.dtype != self._get_box_dtype():
            return np.repeat(-1, len(other)).astype('i4')
        other = np.asarray(other).view('i8')
        return algos.backfill(self._get_index_values(), other, limit=limit)
470
471
cdef class TimedeltaEngine(DatetimeEngine):
    """
    DatetimeEngine variant whose box dtype is 'm8[ns]' and whose scalars are
    ``_Timedelta`` objects or ``NaT``.
    """

    cdef str _get_box_dtype(self):
        return 'm8[ns]'

    cdef int64_t _unbox_scalar(self, scalar) except? -1:
        # Accept NaT or a _Timedelta; anything else is a TypeError.
        if scalar is NaT or isinstance(scalar, _Timedelta):
            return scalar.value
        raise TypeError(scalar)
481
482
cdef class PeriodEngine(Int64Engine):
    """
    Int64Engine variant whose scalars are Period objects (unboxed to their
    ``ordinal``) or ``NaT``.
    """

    cdef int64_t _unbox_scalar(self, scalar) except? -1:
        # NaT unboxes to its ``value``, a Period to its ``ordinal``; every
        # other object is rejected with TypeError.
        if scalar is not NaT:
            if not is_period_object(scalar):
                raise TypeError(scalar)
            # NB: we assume that we have the correct freq here.
            return scalar.ordinal
        return scalar.value

    cpdef get_loc(self, object val):
        """
        Unbox ``val`` and delegate to ``Int64Engine.get_loc``.

        A scalar that cannot be unboxed is reported as ``KeyError(val)``.
        """
        # NB: the caller is responsible for ensuring that we are called
        #  with either a Period or NaT
        cdef:
            int64_t ordinal

        try:
            ordinal = self._unbox_scalar(val)
        except TypeError:
            raise KeyError(val)

        return Int64Engine.get_loc(self, ordinal)

    cdef _get_index_values(self):
        # Fetch the raw values, then reinterpret them as int64.
        raw = super(PeriodEngine, self).vgetter()
        return raw.view("i8")

    cdef _call_monotonic(self, values):
        return algos.is_monotonic(values, timelike=True)
511
512
cdef class BaseMultiIndexCodesEngine:
    """
    Base class for MultiIndexUIntEngine and MultiIndexPyIntEngine, which
    represent each label in a MultiIndex as an integer, by juxtaposing the bits
    encoding each level, with appropriate offsets.

    For instance: if 3 levels have respectively 3, 6 and 1 possible values,
    then their labels can be represented using respectively 2, 3 and 1 bits,
    as follows:
     _ _ _ _____ _ __ __ __
    |0|0|0| ... |0| 0|a1|a0| -> offset 0 (first level)
     — — — ————— — —— —— ——
    |0|0|0| ... |0|b2|b1|b0| -> offset 2 (bits required for first level)
     — — — ————— — —— —— ——
    |0|0|0| ... |0| 0| 0|c0| -> offset 5 (bits required for first two levels)
     ‾ ‾ ‾ ‾‾‾‾‾ ‾ ‾‾ ‾‾ ‾‾
    and the resulting unsigned integer representation will be:
     _ _ _ _____ _ __ __ __ __ __ __
    |0|0|0| ... |0|c0|b2|b1|b0|a1|a0|
     ‾ ‾ ‾ ‾‾‾‾‾ ‾ ‾‾ ‾‾ ‾‾ ‾‾ ‾‾ ‾‾

    Offsets are calculated at initialization, labels are transformed by method
    _codes_to_ints.

    Keys are located by first locating each component against the respective
    level, then locating (the integer representation of) codes.
    """
    def __init__(self, object levels, object labels,
                 ndarray[uint64_t, ndim=1] offsets):
        """
        Parameters
        ----------
        levels : list-like of numpy arrays
            Levels of the MultiIndex.
        labels : list-like of numpy arrays of integer dtype
            Labels of the MultiIndex.
        offsets : numpy array of uint64 dtype
            Pre-calculated offsets, one for each level of the index.
        """
        self.levels = levels
        self.offsets = offsets

        # Transform labels in a single array, and add 1 so that we are working
        # with positive integers (-1 for NaN becomes 0):
        codes = (np.array(labels, dtype='int64').T + 1).astype('uint64',
                                                               copy=False)

        # Map each codes combination in the index to an integer unambiguously
        # (no collisions possible), based on the "offsets", which describe the
        # number of bits to switch labels for each level:
        lab_ints = self._codes_to_ints(codes)

        # Initialize underlying index (e.g. libindex.UInt64Engine) with
        # integers representing labels: we will use its get_loc and get_indexer
        self._base.__init__(self, lambda: lab_ints, len(lab_ints))

    def _codes_to_ints(self, codes):
        # Combining the per-level codes into one integer is left to the
        # concrete engine.
        raise NotImplementedError("Implemented by subclass")

    def _extract_level_codes(self, object target):
        """
        Map the requested list of (tuple) keys to their integer representations
        for searching in the underlying integer index.

        Parameters
        ----------
        target : list-like of keys
            Each key is a tuple, with a label for each level of the index.

        Returns
        -------
        int_keys : 1-dimensional array of dtype uint64 or object
            Integers representing one combination each
        """
        # zip(*target) regroups the keys by level; +1 shifts the "not found"
        # marker -1 to 0 so every code is non-negative.
        level_codes = [lev.get_indexer(codes) + 1 for lev, codes
                       in zip(self.levels, zip(*target))]
        return self._codes_to_ints(np.array(level_codes, dtype='uint64').T)

    def get_indexer_no_fill(self, object target) -> np.ndarray:
        """
        Returns an array giving the positions of each value of `target` in
        `self.values`, where -1 represents a value in `target` which does not
        appear in `self.values`

        Parameters
        ----------
        target : list-like of keys
            Each key is a tuple, with a label for each level of the index

        Returns
        -------
        np.ndarray[int64_t, ndim=1] of the indexer of `target` into
        `self.values`
        """
        lab_ints = self._extract_level_codes(target)
        return self._base.get_indexer(self, lab_ints)

    def get_indexer(self, object target, object values = None,
                    object method = None, object limit = None) -> np.ndarray:
        """
        Returns an array giving the positions of each value of `target` in
        `values`, where -1 represents a value in `target` which does not
        appear in `values`

        If `method` is "backfill" then the position for a value in `target`
        which does not appear in `values` is that of the next greater value
        in `values` (if one exists), and -1 if there is no such value.

        Similarly, if the method is "pad" then the position for a value in
        `target` which does not appear in `values` is that of the next smaller
        value in `values` (if one exists), and -1 if there is no such value.

        Parameters
        ----------
        target: list-like of tuples
            need not be sorted, but all must have the same length, which must be
            the same as the length of all tuples in `values`.  When `method`
            is given, the tuples are read through ``target.values``.
        values : list-like of tuples
            must be sorted and all have the same length.  Should be the set of
            the MultiIndex's values.  Needed only if `method` is not None
        method: string
            "backfill" or "pad"
        limit: int, optional
            if provided, limit the number of fills to this value

        Returns
        -------
        np.ndarray[int64_t, ndim=1] of the indexer of `target` into `values`,
        filled with the `method` (and optionally `limit`) specified
        """
        cdef:
            int64_t i, j, next_code
            int64_t num_values, num_target_values
            ndarray[int64_t, ndim=1] target_order
            ndarray[object, ndim=1] target_values
            ndarray[int64_t, ndim=1] new_codes, new_target_codes
            ndarray[int64_t, ndim=1] sorted_indexer

        if method is None:
            return self.get_indexer_no_fill(target)

        assert method in ("backfill", "pad")

        target_order = np.argsort(target.values).astype('int64')
        target_values = target.values[target_order]
        num_values, num_target_values = len(values), len(target_values)
        # Allocate as int64 directly; every slot is written by the loops below.
        new_codes = np.empty(num_values, dtype='int64')
        new_target_codes = np.empty(num_target_values, dtype='int64')

        # `values` and `target_values` are both sorted, so we walk through them
        # and memoize the (ordered) set of indices in the (implicit) merged-and
        # sorted list of the two which belong to each of them
        # the effect of this is to create a factorization for the (sorted)
        # merger of the index values, where `new_codes` and `new_target_codes`
        # are the subset of the factors which appear in `values` and `target`,
        # respectively
        i, j, next_code = 0, 0, 0
        while i < num_values and j < num_target_values:
            val, target_val = values[i], target_values[j]
            # equal entries satisfy both tests and therefore share one code
            if val <= target_val:
                new_codes[i] = next_code
                i += 1
            if target_val <= val:
                new_target_codes[j] = next_code
                j += 1
            next_code += 1

        # at this point, at least one should have reached the end
        # the remaining values of the other should be added to the end
        assert i == num_values or j == num_target_values
        while i < num_values:
            new_codes[i] = next_code
            i += 1
            next_code += 1
        while j < num_target_values:
            new_target_codes[j] = next_code
            j += 1
            next_code += 1

        # get the indexer, and undo the sorting of `target.values`
        fill = algos.backfill if method == "backfill" else algos.pad
        sorted_indexer = fill(
            new_codes, new_target_codes, limit=limit
        ).astype('int64')
        return sorted_indexer[np.argsort(target_order)]

    def get_loc(self, object key):
        """
        Locate a full tuple key.

        Each component is located in its level, the resulting codes are
        combined by ``_codes_to_ints`` and the integer is looked up in the
        underlying engine.

        Raises
        ------
        TypeError
            If ``key`` is unhashable.
        KeyError
            If ``key`` is not a tuple or a component is missing from its
            level.
        """
        if is_definitely_invalid_key(key):
            raise TypeError(f"'{key}' is an invalid key")
        if not isinstance(key, tuple):
            raise KeyError(key)
        try:
            # null components map to code 0, found ones to position + 1
            indices = [0 if checknull(v) else lev.get_loc(v) + 1
                       for lev, v in zip(self.levels, key)]
        except KeyError:
            # report the whole key rather than the failing component
            raise KeyError(key)

        # Transform indices into single integer:
        lab_int = self._codes_to_ints(np.array(indices, dtype='uint64'))

        return self._base.get_loc(self, lab_int)

    def get_indexer_non_unique(self, object target):
        # This needs to be overridden just because the default one works on
        # target._values, and target can be itself a MultiIndex.

        lab_ints = self._extract_level_codes(target)
        return self._base.get_indexer_non_unique(self, lab_ints)

    def __contains__(self, val: object) -> bool:
        # We assume before we get here:
        #  - val is hashable
        # Default __contains__ looks in the underlying mapping, which in this
        # case only contains integer representations.
        try:
            self.get_loc(val)
            return True
        except (KeyError, TypeError, ValueError):
            return False
734
735
736# Generated from template.
737include "index_class_helper.pxi"
738