1""" miscellaneous sorting / groupby utilities """
2from collections import defaultdict
3from typing import (
4    TYPE_CHECKING,
5    Callable,
6    DefaultDict,
7    Dict,
8    Iterable,
9    List,
10    Optional,
11    Sequence,
12    Tuple,
13    Union,
14)
15
16import numpy as np
17
18from pandas._libs import algos, hashtable, lib
19from pandas._libs.hashtable import unique_label_indices
20from pandas._typing import IndexKeyFunc
21
22from pandas.core.dtypes.common import (
23    ensure_int64,
24    ensure_platform_int,
25    is_extension_array_dtype,
26)
27from pandas.core.dtypes.generic import ABCMultiIndex
28from pandas.core.dtypes.missing import isna
29
30import pandas.core.algorithms as algorithms
31from pandas.core.construction import extract_array
32
33if TYPE_CHECKING:
34    from pandas import MultiIndex
35    from pandas.core.indexes.base import Index
36
37_INT64_MAX = np.iinfo(np.int64).max
38
39
def get_indexer_indexer(
    target: "Index",
    level: Union[str, int, List[str], List[int]],
    ascending: Union[Sequence[Union[bool, int]], Union[bool, int]],
    kind: str,
    na_position: str,
    sort_remaining: bool,
    key: IndexKeyFunc,
) -> Optional[np.array]:
    """
    Compute the indexer used by the sort_index methods of DataFrame and
    Series for the given parameters.

    Parameters
    ----------
    target : Index
    level : int or level name or list of ints or list of level names
    ascending : bool or list of bools, default True
    kind : {'quicksort', 'mergesort', 'heapsort'}, default 'quicksort'
    na_position : {'first', 'last'}, default 'last'
    sort_remaining : bool, default True
    key : callable, optional

    Returns
    -------
    Optional[ndarray]
        The indexer for the new index, or None if no reordering is needed.
    """
    # apply the key function (per level when sorting by level), then make
    # the levels monotonic before computing the indexer
    target = ensure_key_mapped(target, key, levels=level)
    target = target._sort_levels_monotonic()

    if level is not None:
        _, indexer = target.sortlevel(
            level, ascending=ascending, sort_remaining=sort_remaining
        )
        return indexer

    if isinstance(target, ABCMultiIndex):
        return lexsort_indexer(
            target._get_codes_for_sorting(), orders=ascending, na_position=na_position
        )

    # plain Index: if it is already in the requested order, skip sorting
    # entirely (GH 11080)
    if (ascending and target.is_monotonic_increasing) or (
        not ascending and target.is_monotonic_decreasing
    ):
        return None

    return nargsort(target, kind=kind, ascending=ascending, na_position=na_position)
91
92
def get_group_index(labels, shape, sort: bool, xnull: bool):
    """
    For the particular label_list, gets the offsets into the hypothetical list
    representing the totally ordered cartesian product of all possible label
    combinations, *as long as* this space fits within int64 bounds;
    otherwise, though group indices identify unique combinations of
    labels, they cannot be deconstructed.
    - If `sort`, rank of returned ids preserve lexical ranks of labels.
      i.e. returned id's can be used to do lexical sort on labels;
    - If `xnull` nulls (-1 labels) are passed through.

    Parameters
    ----------
    labels : sequence of arrays
        Integers identifying levels at each location
    shape : sequence of ints
        Number of unique levels at each location
    sort : bool
        If the ranks of returned ids should match lexical ranks of labels
    xnull : bool
        If true nulls are excluded. i.e. -1 values in the labels are
        passed through.

    Returns
    -------
    An array of type int64 where two elements are equal if their corresponding
    labels are equal at all location.

    Notes
    -----
    The length of `labels` and `shape` must be identical.
    """

    def _int64_cut_off(shape) -> int:
        # Length of the longest prefix of `shape` whose cumulative product
        # still fits in int64; len(shape) if the full product fits.
        acc = 1
        for i, mul in enumerate(shape):
            acc *= int(mul)
            if not acc < _INT64_MAX:
                return i
        return len(shape)

    def maybe_lift(lab, size):
        # promote nan values (assigned -1 label in lab array)
        # so that all output values are non-negative
        return (lab + 1, size + 1) if (lab == -1).any() else (lab, size)

    labels = map(ensure_int64, labels)
    if not xnull:
        # shift each level so that -1 (null) becomes a regular label 0;
        # sizes grow by one accordingly
        labels, shape = map(list, zip(*map(maybe_lift, labels, shape)))

    labels = list(labels)
    shape = list(shape)

    # Iteratively process all the labels in chunks sized so less
    # than _INT64_MAX unique int ids will be required for each chunk
    while True:
        # how many levels can be done without overflow:
        nlev = _int64_cut_off(shape)

        # compute flat ids for the first `nlev` levels
        stride = np.prod(shape[1:nlev], dtype="i8")
        out = stride * labels[0].astype("i8", subok=False, copy=False)

        for i in range(1, nlev):
            if shape[i] == 0:
                # degenerate level with no categories contributes nothing
                stride = 0
            else:
                stride //= shape[i]
            out += labels[i] * stride

        if xnull:  # exclude nulls
            # any -1 in any processed level marks the whole row as null
            mask = labels[0] == -1
            for lab in labels[1:nlev]:
                mask |= lab == -1
            out[mask] = -1

        if nlev == len(shape):  # all levels done!
            break

        # compress what has been done so far in order to avoid overflow
        # to retain lexical ranks, obs_ids should be sorted
        comp_ids, obs_ids = compress_group_index(out, sort=sort)

        # the compressed ids become the new "first level" for the next pass
        labels = [comp_ids] + labels[nlev:]
        shape = [len(obs_ids)] + shape[nlev:]

    return out
180
181
def get_compressed_ids(labels, sizes):
    """
    Compute compressed group ids for the given label arrays.

    The flat group index lives in the (potentially huge) cartesian product
    of all possible labels, so it is immediately compressed into offsets
    (comp_ids) into the array of observed group ids (obs_group_ids).

    Parameters
    ----------
    labels : list of label arrays
    sizes : list of size of the levels

    Returns
    -------
    tuple of (comp_ids, obs_group_ids)
    """
    flat_ids = get_group_index(labels, sizes, sort=True, xnull=False)
    return compress_group_index(flat_ids, sort=True)
199
200
def is_int64_overflow_possible(shape) -> bool:
    """
    Return True if the cartesian product implied by ``shape`` could
    overflow the int64 range.

    Parameters
    ----------
    shape : sequence of ints
        Number of unique labels at each level.
    """
    int64_max = np.iinfo(np.int64).max
    cardinality = 1
    for size in shape:
        cardinality *= int(size)
    return cardinality >= int64_max
207
208
def decons_group_index(comp_labels, shape):
    """
    Reconstruct the per-level label arrays from flat group indices.

    This is the inverse of ``get_group_index`` and is only valid when the
    cartesian product of ``shape`` fits within int64 bounds; negative
    (null) group indices are propagated as -1 at every level.
    """
    if is_int64_overflow_possible(shape):
        # at some point group indices are factorized,
        # and may not be deconstructed here! wrong path!
        raise ValueError("cannot deconstruct factorized group indices!")

    reconstructed = []
    factor = 1
    prev_contribution = 0
    for i in reversed(range(len(shape))):
        # peel off the labels for level i using modular arithmetic
        level_labels = (comp_labels - prev_contribution) % (factor * shape[i]) // factor
        np.putmask(level_labels, comp_labels < 0, -1)
        reconstructed.append(level_labels)
        prev_contribution = level_labels * factor
        factor *= shape[i]
    return reconstructed[::-1]
227
228
def decons_obs_group_ids(comp_ids, obs_ids, shape, labels, xnull: bool):
    """
    Reconstruct labels from observed group ids.

    Parameters
    ----------
    xnull : bool
        If nulls are excluded; i.e. -1 labels are passed through.
    """
    if not xnull:
        # per-level shift that was applied to make nulls non-negative
        lift = np.fromiter(((a == -1).any() for a in labels), dtype="i8")
        shape = np.asarray(shape, dtype="i8") + lift

    if not is_int64_overflow_possible(shape):
        # obs ids are deconstructable! take the fast route!
        out = decons_group_index(obs_ids, shape)
        if xnull or not lift.any():
            return out
        # undo the lift applied above
        return [x - y for x, y in zip(out, lift)]

    # slow route: pick one representative row per unique comp_id
    indexer = unique_label_indices(comp_ids)
    return [lab[indexer].astype("i8", subok=False, copy=True) for lab in labels]
250
251
def indexer_from_factorized(labels, shape, compress: bool = True):
    """Return an indexer that lexically sorts the factorized labels."""
    group_index = get_group_index(labels, shape, sort=True, xnull=False)

    if compress:
        group_index, obs_ids = compress_group_index(group_index, sort=True)
        ngroups = len(obs_ids)
    else:
        ngroups = (group_index.size and group_index.max()) + 1

    return get_group_index_sorter(group_index, ngroups)
262
263
def lexsort_indexer(
    keys, orders=None, na_position: str = "last", key: Optional[Callable] = None
):
    """
    Performs lexical sorting on a set of keys

    Parameters
    ----------
    keys : sequence of arrays
        Sequence of ndarrays to be sorted by the indexer
    orders : boolean or list of booleans, optional
        Determines the sorting order for each element in keys. If a list,
        it must be the same length as keys. This determines whether the
        corresponding element in keys should be sorted in ascending
        (True) or descending (False) order. if bool, applied to all
        elements as above. if None, defaults to True.
    na_position : {'first', 'last'}, default 'last'
        Determines placement of NA elements in the sorted list ("last" or "first")
    key : Callable, optional
        Callable key function applied to every element in keys before sorting

        .. versionadded:: 1.0.0

    Returns
    -------
    np.ndarray
        Indexer that lexically sorts the keys.

    Raises
    ------
    ValueError
        If ``na_position`` is not one of {'first', 'last'}.
    """
    from pandas.core.arrays import Categorical

    # Validate once, up front: previously this was checked inside the
    # per-key loop, so it was re-evaluated every iteration and an invalid
    # value slipped through silently when `keys` was empty.
    if na_position not in ["last", "first"]:
        raise ValueError(f"invalid na_position: {na_position}")

    if isinstance(orders, bool):
        orders = [orders] * len(keys)
    elif orders is None:
        orders = [True] * len(keys)

    keys = [ensure_key_mapped(k, key) for k in keys]

    labels = []
    shape = []
    for k, order in zip(keys, orders):
        # factorize via an ordered Categorical so codes reflect sort order
        cat = Categorical(k, ordered=True)
        n = len(cat.categories)
        codes = cat.codes.copy()
        mask = cat.codes == -1  # -1 codes mark missing values

        if order:  # ascending
            if na_position == "last":
                codes = np.where(mask, n, codes)
            elif na_position == "first":
                # shift everything up so NA (-1) sorts first as 0
                codes += 1
        else:  # not order means descending
            if na_position == "last":
                codes = np.where(mask, n, n - codes - 1)
            elif na_position == "first":
                codes = np.where(mask, 0, n - codes)
        if mask.any():
            # missing values occupy one extra slot in this level
            n += 1

        shape.append(n)
        labels.append(codes)

    return indexer_from_factorized(labels, shape)
325
326
def nargsort(
    items,
    kind: str = "quicksort",
    ascending: bool = True,
    na_position: str = "last",
    key: Optional[Callable] = None,
    mask: Optional[np.ndarray] = None,
):
    """
    Intended to be a drop-in replacement for np.argsort which handles NaNs.

    Adds ascending, na_position, and key parameters.

    (GH #6399, #5231, #27237)

    Parameters
    ----------
    kind : str, default 'quicksort'
    ascending : bool, default True
    na_position : {'first', 'last'}, default 'last'
    key : Optional[Callable], default None
    mask : Optional[np.ndarray], default None
        Passed when called by ExtensionArray.argsort.
    """
    if key is not None:
        # map through the key function once, then sort the mapped values
        items = ensure_key_mapped(items, key)

    items = extract_array(items)
    if mask is None:
        mask = np.asarray(isna(items))

    if is_extension_array_dtype(items):
        # extension arrays implement their own NA-aware argsort
        return items.argsort(ascending=ascending, kind=kind, na_position=na_position)

    items = np.asanyarray(items)

    positions = np.arange(len(items))
    valid = items[~mask]
    valid_positions = positions[~mask]
    nan_positions = np.nonzero(mask)[0]

    if ascending:
        sorted_valid = valid_positions[valid.argsort(kind=kind)]
    else:
        # reverse before and after a stable sort so equal elements keep
        # their original relative order
        rev_valid = valid[::-1]
        rev_positions = valid_positions[::-1]
        sorted_valid = rev_positions[rev_valid.argsort(kind=kind)][::-1]

    # place the NaNs at the end or the beginning according to na_position
    if na_position == "last":
        return np.concatenate([sorted_valid, nan_positions])
    elif na_position == "first":
        return np.concatenate([nan_positions, sorted_valid])
    else:
        raise ValueError(f"invalid na_position: {na_position}")
392
393
def nargminmax(values, method: str):
    """
    Implementation of np.argmin/argmax but for ExtensionArray and which
    handles missing values.

    Parameters
    ----------
    values : ExtensionArray
    method : {"argmax", "argmin"}

    Returns
    -------
    int
    """
    assert method in {"argmax", "argmin"}
    reducer = {"argmax": np.argmax, "argmin": np.argmin}[method]

    # compute the NA mask before converting to the argsort representation
    na_mask = np.asarray(isna(values))
    comparable = values._values_for_argsort()

    candidate_positions = np.arange(len(comparable))[~na_mask]
    candidates = comparable[~na_mask]

    # reduce over the non-missing values, then map back to the original
    # positions
    return candidate_positions[reducer(candidates)]
419
420
def _ensure_key_mapped_multiindex(
    index: "MultiIndex", key: Callable, level=None
) -> "MultiIndex":
    """
    Returns a new MultiIndex in which key has been applied
    to all levels specified in level (or all levels if level
    is None). Used for key sorting for MultiIndex.

    Parameters
    ----------
    index : MultiIndex
        Index to which to apply the key function on the
        specified levels.
    key : Callable
        Function that takes an Index and returns an Index of
        the same shape. This key is applied to each level
        separately. The name of the level can be used to
        distinguish different levels for application.
    level : list-like, int or str, default None
        Level or list of levels to apply the key function to.
        If None, key function is applied to all levels. Other
        levels are left unchanged.

    Returns
    -------
    labels : MultiIndex
        Resulting MultiIndex with modified levels.
    """
    if level is None:
        levels_to_map = list(range(index.nlevels))  # satisfies mypy
    elif isinstance(level, (str, int)):
        levels_to_map = [index._get_level_number(level)]
    else:
        levels_to_map = [index._get_level_number(lev) for lev in level]

    mapped = []
    for lvl in range(index.nlevels):
        values = index._get_level_values(lvl)
        if lvl in levels_to_map:
            # only the selected levels are passed through the key function
            values = ensure_key_mapped(values, key)
        mapped.append(values)

    return type(index).from_arrays(mapped)
470
471
def ensure_key_mapped(values, key: Optional[Callable], levels=None):
    """
    Applies a callable key function to the values function and checks
    that the resulting value has the same shape. Can be called on Index
    subclasses, Series, DataFrames, or ndarrays.

    Parameters
    ----------
    values : Series, DataFrame, Index subclass, or ndarray
    key : Optional[Callable], key to be called on the values array
    levels : Optional[List], if values is a MultiIndex, list of levels to
    apply the key to.
    """
    from pandas.core.indexes.api import Index

    if not key:
        # no key -> nothing to do
        return values

    if isinstance(values, ABCMultiIndex):
        # MultiIndex gets per-level treatment
        return _ensure_key_mapped_multiindex(values, key, level=levels)

    # apply to a copy so user functions cannot mutate the original
    result = key(values.copy())
    if len(result) != len(values):
        raise ValueError(
            "User-provided `key` function must not change the shape of the array."
        )

    try:
        if isinstance(values, Index):
            # convert to a new Index subclass, not necessarily the same
            result = Index(result)
        else:
            # try to revert to original type otherwise
            result = type(values)(result)
    except TypeError:
        raise TypeError(
            f"User-provided `key` function returned an invalid type {type(result)} \
            which could not be converted to {type(values)}."
        )

    return result
514
515
def get_flattened_list(
    comp_ids: np.ndarray,
    ngroups: int,
    levels: Iterable["Index"],
    labels: Iterable[np.ndarray],
) -> List[Tuple]:
    """Map compressed group id -> key tuple.

    For each of the ``ngroups`` compressed ids, collect the level value
    that id takes at every level, producing one key tuple per group.
    """
    comp_ids = comp_ids.astype(np.int64, copy=False)
    # group id -> list of level values accumulated across levels
    arrays: DefaultDict[int, List[int]] = defaultdict(list)
    for labs, level in zip(labels, levels):
        # hash table mapping each compressed id to its label in this level
        table = hashtable.Int64HashTable(ngroups)
        table.map(comp_ids, labs.astype(np.int64, copy=False))
        for i in range(ngroups):
            # look up the label for group i, then the actual level value
            arrays[i].append(level[table.get_item(i)])
    # defaultdict preserves insertion order (0..ngroups-1 per the loop above)
    return [tuple(array) for array in arrays.values()]
531
532
def get_indexer_dict(
    label_list: List[np.ndarray], keys: List["Index"]
) -> Dict[Union[str, Tuple], np.ndarray]:
    """
    Map each observed group key to the integer positions of its rows.

    Returns
    -------
    dict:
        Labels mapped to indexers.
    """
    shape = [len(x) for x in keys]

    group_index = get_group_index(label_list, shape, sort=True, xnull=True)
    if np.all(group_index == -1):
        # When all keys are nan and dropna=True, indices_fast can't handle this
        # and the return is empty anyway
        return {}

    if is_int64_overflow_possible(shape):
        ngroups = (group_index.size and group_index.max()) + 1
    else:
        ngroups = np.prod(shape, dtype="i8")

    sorter = get_group_index_sorter(group_index, ngroups)

    # reorder everything by group so indices_fast can slice out each group
    sorted_labels = [lab.take(sorter) for lab in label_list]
    group_index = group_index.take(sorter)

    return lib.indices_fast(sorter, group_index, keys, sorted_labels)
561
562
563# ----------------------------------------------------------------------
564# sorting levels...cleverly?
565
566
def get_group_index_sorter(group_index, ngroups: int):
    """
    Return a stable indexer that sorts ``group_index``.

    algos.groupsort_indexer implements `counting sort` and it is at least
    O(ngroups), where
        ngroups = prod(shape)
        shape = map(len, keys)
    that is, linear in the number of combinations (cartesian product) of unique
    values of groupby keys. This can be huge when doing multi-key groupby.
    np.argsort(kind='mergesort') is O(count x log(count)) where count is the
    length of the data-frame;
    Both algorithms are `stable` sort and that is necessary for correctness of
    groupby operations. e.g. consider:
        df.groupby(key)[col].transform('first')
    """
    count = len(group_index)
    # taking complexities literally; there may be some room for fine-tuning
    # these parameters
    alpha = 0.0
    beta = 1.0
    prefer_counting_sort = count > 0 and (
        (alpha + beta * ngroups) < (count * np.log(count))
    )
    if not prefer_counting_sort:
        return group_index.argsort(kind="mergesort")
    sorter, _ = algos.groupsort_indexer(ensure_int64(group_index), ngroups)
    return ensure_platform_int(sorter)
590
591
def compress_group_index(group_index, sort: bool = True):
    """
    Group_index is offsets into cartesian product of all possible labels. This
    space can be huge, so this function compresses it, by computing offsets
    (comp_ids) into the list of unique labels (obs_group_ids).
    """
    group_index = ensure_int64(group_index)
    table = hashtable.Int64HashTable(
        min(len(group_index), hashtable.SIZE_HINT_LIMIT)
    )

    # note, group labels come out ascending (ie, 1,2,3 etc)
    comp_ids, obs_group_ids = table.get_labels_groupby(group_index)

    if sort and len(obs_group_ids):
        # reorder so the observed ids come out ascending
        obs_group_ids, comp_ids = _reorder_by_uniques(obs_group_ids, comp_ids)

    return ensure_int64(comp_ids), ensure_int64(obs_group_ids)
610
611
def _reorder_by_uniques(uniques, labels):
    """
    Sort ``uniques`` ascending and remap ``labels`` to match.

    Parameters
    ----------
    uniques : np.ndarray
        Observed group ids in order of first appearance.
    labels : np.ndarray
        Offsets into ``uniques``; -1 marks missing values and is preserved.

    Returns
    -------
    tuple of (sorted uniques, remapped labels)
    """
    # sorter is index where elements ought to go
    sorter = uniques.argsort()

    # reverse_indexer is where elements came from
    reverse_indexer = np.empty(len(sorter), dtype=np.int64)
    reverse_indexer.put(sorter, np.arange(len(sorter)))

    # remember which labels were missing so the remap cannot clobber them
    mask = labels < 0

    # move labels to right locations (ie, unsort ascending labels)
    labels = algorithms.take_nd(reverse_indexer, labels, allow_fill=False)
    # restore -1 for missing entries (take_nd wrapped them to a real slot)
    np.putmask(labels, mask, -1)

    # sort observed ids
    uniques = algorithms.take_nd(uniques, sorter, allow_fill=False)

    return uniques, labels
630