1import cython
2from cython import Py_ssize_t
3
4from libc.math cimport fabs, sqrt
5from libc.stdlib cimport free, malloc
6from libc.string cimport memmove
7
8import numpy as np
9
10cimport numpy as cnp
11from numpy cimport (
12    NPY_FLOAT32,
13    NPY_FLOAT64,
14    NPY_INT8,
15    NPY_INT16,
16    NPY_INT32,
17    NPY_INT64,
18    NPY_OBJECT,
19    NPY_UINT8,
20    NPY_UINT16,
21    NPY_UINT32,
22    NPY_UINT64,
23    float32_t,
24    float64_t,
25    int8_t,
26    int16_t,
27    int32_t,
28    int64_t,
29    ndarray,
30    uint8_t,
31    uint16_t,
32    uint32_t,
33    uint64_t,
34)
35
36cnp.import_array()
37
38
39cimport pandas._libs.util as util
40from pandas._libs.khash cimport (
41    kh_destroy_int64,
42    kh_get_int64,
43    kh_init_int64,
44    kh_int64_t,
45    kh_put_int64,
46    kh_resize_int64,
47    khiter_t,
48)
49from pandas._libs.util cimport get_nat, numeric
50
51import pandas._libs.missing as missing
52
cdef:
    # Absolute tolerance used by ``are_diff`` when comparing numeric values.
    float64_t FP_ERR = 1e-13
    # C-level NaN written into results that cannot be computed.
    float64_t NaN = <float64_t>np.NaN
    # Sentinel returned by ``util.get_nat()``; int64 values equal to it are
    # treated as missing by the routines below.
    int64_t NPY_NAT = get_nat()

# Maps the ``ties_method`` string accepted by the rank functions to its
# TIEBREAK_* constant.
# NOTE(review): the TIEBREAK_* names are not defined in the visible part of
# this file -- presumably declared in a companion .pxd/.pxi; confirm.
tiebreakers = {
    "average": TIEBREAK_AVERAGE,
    "min": TIEBREAK_MIN,
    "max": TIEBREAK_MAX,
    "first": TIEBREAK_FIRST,
    "dense": TIEBREAK_DENSE,
}
65
66
cdef inline bint are_diff(object left, object right):
    """
    Report whether two values should be treated as distinct.

    When ``left - right`` can be handed to ``fabs``, the values are distinct
    if that absolute difference exceeds ``FP_ERR``. If a TypeError is raised
    along the way, fall back to a plain ``!=`` comparison.
    """
    cdef bint distinct

    try:
        distinct = fabs(left - right) > FP_ERR
    except TypeError:
        distinct = left != right
    return distinct
72
73
class Infinity:
    """
    Sentinel that orders as positive infinity when ranking objects.

    It compares equal only to other ``Infinity`` instances and is never less
    than anything. It is not reported as greater than (or greater-or-equal
    to) values that ``missing.checknull`` flags as null.
    """

    def __lt__(self, other):
        return False

    def __le__(self, other):
        return isinstance(other, Infinity)

    def __eq__(self, other):
        return isinstance(other, Infinity)

    def __ne__(self, other):
        return not isinstance(other, Infinity)

    def __gt__(self, other):
        # another Infinity is equal, hence not strictly smaller
        if isinstance(other, Infinity):
            return False
        return not missing.checknull(other)

    def __ge__(self, other):
        return not missing.checknull(other)
85
86
class NegInfinity:
    """
    Sentinel that orders as negative infinity when ranking objects.

    It compares equal only to other ``NegInfinity`` instances and is never
    greater than anything. It is not reported as less than (or
    less-or-equal to) values that ``missing.checknull`` flags as null.
    """

    def __lt__(self, other):
        # another NegInfinity is equal, hence not strictly larger
        if isinstance(other, NegInfinity):
            return False
        return not missing.checknull(other)

    def __le__(self, other):
        return not missing.checknull(other)

    def __eq__(self, other):
        return isinstance(other, NegInfinity)

    def __ne__(self, other):
        return not isinstance(other, NegInfinity)

    def __gt__(self, other):
        return False

    def __ge__(self, other):
        return isinstance(other, NegInfinity)
98
99
@cython.wraparound(False)
@cython.boundscheck(False)
cpdef ndarray[int64_t, ndim=1] unique_deltas(const int64_t[:] arr):
    """
    Collect the distinct consecutive differences of ``arr``, sorted.

    Parameters
    ----------
    arr : ndarray[int64_t]

    Returns
    -------
    ndarray[int64_t]
        The distinct values of ``arr[i + 1] - arr[i]`` in ascending order.
    """
    cdef:
        Py_ssize_t pos, length = len(arr)
        int64_t delta
        khiter_t slot
        kh_int64_t *seen
        int status = 0
        list found = []
        ndarray[int64_t, ndim=1] out

    seen = kh_init_int64()
    kh_resize_int64(seen, 10)
    for pos in range(1, length):
        delta = arr[pos] - arr[pos - 1]
        slot = kh_get_int64(seen, delta)
        # a lookup result equal to n_buckets is treated as "not seen yet"
        if slot == seen.n_buckets:
            kh_put_int64(seen, delta, &status)
            found.append(delta)
    kh_destroy_int64(seen)

    out = np.array(found, dtype=np.int64)
    out.sort()
    return out
137
138
@cython.wraparound(False)
@cython.boundscheck(False)
def is_lexsorted(list_of_arrays: list) -> bint:
    """
    Check whether parallel int64 arrays are jointly in lexicographic order.

    Position ``i`` is compared with position ``i - 1`` level by level (first
    array first); the first level on which they differ decides the order.

    Parameters
    ----------
    list_of_arrays : list of ndarray
        One array per level. Each must report ``dtype.name == 'int64'``.
        The number of positions scanned is taken from the first array.

    Returns
    -------
    bool
        False as soon as some position sorts before its predecessor,
        True otherwise.

    Raises
    ------
    IndexError
        If ``list_of_arrays`` is empty.
    AssertionError
        If an array is not int64.
    MemoryError
        If the temporary pointer table cannot be allocated.
    """
    cdef:
        Py_ssize_t i
        Py_ssize_t n, nlevels
        int64_t k, cur, pre
        ndarray arr
        bint result = True
        int64_t **vecs

    nlevels = len(list_of_arrays)
    n = len(list_of_arrays[0])

    vecs = <int64_t**>malloc(nlevels * sizeof(int64_t*))
    if vecs == NULL:
        raise MemoryError()

    # try/finally guarantees the pointer table is released even when the
    # dtype assertion (or the ndarray type check on assignment) raises.
    try:
        for i in range(nlevels):
            arr = list_of_arrays[i]
            assert arr.dtype.name == 'int64'
            vecs[i] = <int64_t*>cnp.PyArray_DATA(arr)

        # Assume uniqueness??
        with nogil:
            for i in range(1, n):
                for k in range(nlevels):
                    cur = vecs[k][i]
                    pre = vecs[k][i - 1]
                    if cur == pre:
                        # tie on this level, let the next level decide
                        continue
                    elif cur > pre:
                        break
                    else:
                        result = False
                        break
                if not result:
                    # the answer cannot change any more; stop scanning
                    break
    finally:
        free(vecs)
    return result
173
174
@cython.boundscheck(False)
@cython.wraparound(False)
def groupsort_indexer(const int64_t[:] index, Py_ssize_t ngroups):
    """
    Compute a 1-d indexer that orders positions by their group label.

    Parameters
    ----------
    index : int64 ndarray
        Group label of each position. Labels are shifted by one before
        being used as offsets, so label -1 lands in slot 0.
    ngroups : int64
        Number of groups.

    Returns
    -------
    tuple
        1-d indexer ordered by groups, group counts (length ``ngroups + 1``,
        slot 0 holding the count for label -1).

    Notes
    -----
    This is a reverse of the label factorization process.
    """
    cdef:
        Py_ssize_t pos, slot, size
        ndarray[int64_t] counts, offsets, indexer

    size = len(index)
    counts = np.zeros(ngroups + 1, dtype=np.int64)
    offsets = np.zeros(ngroups + 1, dtype=np.int64)
    indexer = np.zeros(size, dtype=np.int64)

    with nogil:

        # tally how many positions carry each label
        for pos in range(size):
            counts[index[pos] + 1] += 1

        # running totals give the first output slot of every group
        for slot in range(1, ngroups + 1):
            offsets[slot] = offsets[slot - 1] + counts[slot - 1]

        # drop each position into the next free slot of its group
        for pos in range(size):
            slot = index[pos] + 1
            indexer[offsets[slot]] = pos
            offsets[slot] += 1

    return indexer, counts
226
227
@cython.boundscheck(False)
@cython.wraparound(False)
def kth_smallest(numeric[:] a, Py_ssize_t k) -> numeric:
    """
    Select the value at sorted position ``k`` (0-based) by partitioning ``a``.

    Parameters
    ----------
    a : numeric[:]
        Values to search. Elements are exchanged in place while searching,
        so the caller's buffer is reordered.
    k : Py_ssize_t
        Position to select.

    Returns
    -------
    numeric
        ``a[k]`` once the partitioning loop has finished.

    Notes
    -----
    Bounds checking and wraparound are disabled and ``k`` is not validated
    here, so callers must pass ``0 <= k < len(a)`` on a non-empty buffer.
    NOTE(review): ``swap`` is defined outside the visible part of this file;
    presumably it exchanges the two pointed-to elements -- confirm.
    """
    cdef:
        Py_ssize_t i, j, l, m, n = a.shape[0]
        numeric x

    with nogil:
        # [l, m] is the window that still has to contain position k
        l = 0
        m = n - 1

        while l < m:
            # pivot on the value currently sitting at position k
            x = a[k]
            i = l
            j = m

            while 1:
                # advance both cursors past elements already on the correct
                # side of the pivot
                while a[i] < x: i += 1
                while x < a[j]: j -= 1
                if i <= j:
                    swap(&a[i], &a[j])
                    i += 1; j -= 1

                if i > j: break

            # shrink the window to the side that still contains k
            if j < k: l = i
            if k < i: m = j
    return a[k]
256
257
258# ----------------------------------------------------------------------
259# Pairwise correlation/covariance
260
261
@cython.boundscheck(False)
@cython.wraparound(False)
def nancorr(const float64_t[:, :] mat, bint cov=False, minp=None):
    """
    Pairwise correlation (or covariance) between the columns of ``mat``.

    For every pair of columns only the rows where both entries are finite
    are used.

    Parameters
    ----------
    mat : float64 2-d array
        Observations in rows, variables in columns.
    cov : bool, default False
        If True, divide the co-moment by ``nobs - 1`` (covariance) instead
        of by ``sqrt(ssqdmx * ssqdmy)`` (correlation).
    minp : int, optional
        Minimum number of jointly finite observations required for a pair;
        ``None`` means 1.

    Returns
    -------
    ndarray[float64_t, ndim=2]
        Symmetric (K, K) matrix. Entries are NaN where fewer than ``minp``
        observations are available or where the divisor is zero.
    """
    cdef:
        Py_ssize_t i, xi, yi, N, K
        # an observation-count threshold, so it must be an integer type
        # (a boolean type would not be guaranteed to keep values above 1)
        int64_t minpv
        ndarray[float64_t, ndim=2] result
        ndarray[uint8_t, ndim=2] mask
        int64_t nobs = 0
        float64_t vx, vy, meanx, meany, divisor, prev_meany, prev_meanx, ssqdmx
        float64_t ssqdmy, covxy

    N, K = (<object>mat).shape

    if minp is None:
        minpv = 1
    else:
        minpv = <int>minp

    result = np.empty((K, K), dtype=np.float64)
    mask = np.isfinite(mat).view(np.uint8)

    with nogil:
        for xi in range(K):
            # only the lower triangle is computed; each value is mirrored
            for yi in range(xi + 1):
                # Welford's method for the variance-calculation
                # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
                nobs = ssqdmx = ssqdmy = covxy = meanx = meany = 0
                for i in range(N):
                    if mask[i, xi] and mask[i, yi]:
                        vx = mat[i, xi]
                        vy = mat[i, yi]
                        nobs += 1
                        prev_meanx = meanx
                        prev_meany = meany
                        meanx = meanx + 1 / nobs * (vx - meanx)
                        meany = meany + 1 / nobs * (vy - meany)
                        ssqdmx = ssqdmx + (vx - meanx) * (vx - prev_meanx)
                        ssqdmy = ssqdmy + (vy - meany) * (vy - prev_meany)
                        covxy = covxy + (vx - meanx) * (vy - prev_meany)

                if nobs < minpv:
                    result[xi, yi] = result[yi, xi] = NaN
                else:
                    divisor = (nobs - 1.0) if cov else sqrt(ssqdmx * ssqdmy)

                    if divisor != 0:
                        result[xi, yi] = result[yi, xi] = covxy / divisor
                    else:
                        result[xi, yi] = result[yi, xi] = NaN

    return result
314
315# ----------------------------------------------------------------------
316# Pairwise Spearman correlation
317
318
@cython.boundscheck(False)
@cython.wraparound(False)
def nancorr_spearman(ndarray[float64_t, ndim=2] mat, Py_ssize_t minp=1) -> ndarray:
    """
    Pairwise rank correlation between the columns of ``mat``.

    Each column is ranked once with ``rank_1d``. For a pair of columns only
    rows where both entries are finite are used; when the two columns do not
    share the same finite/non-finite pattern, the selected ranks are ranked
    again so they refer to the retained rows only.

    Parameters
    ----------
    mat : ndarray[float64_t, ndim=2]
        Observations in rows, variables in columns.
    minp : int, default 1
        Minimum number of jointly finite observations required for a pair.

    Returns
    -------
    ndarray[float64_t, ndim=2]
        Symmetric (K, K) matrix. Entries are NaN where fewer than ``minp``
        observations are available or where the divisor is zero.
    """
    cdef:
        Py_ssize_t i, j, xi, yi, N, K
        ndarray[float64_t, ndim=2] result
        ndarray[float64_t, ndim=2] ranked_mat
        ndarray[float64_t, ndim=1] maskedx
        ndarray[float64_t, ndim=1] maskedy
        ndarray[uint8_t, ndim=2] mask
        int64_t nobs = 0
        # typed so the per-row update in the counting loop stays at C level
        bint all_ranks
        float64_t vx, vy, sumx, sumxx, sumyy, mean, divisor

    N, K = (<object>mat).shape

    result = np.empty((K, K), dtype=np.float64)
    mask = np.isfinite(mat).view(np.uint8)

    ranked_mat = np.empty((N, K), dtype=np.float64)

    for i in range(K):
        ranked_mat[:, i] = rank_1d(mat[:, i])

    for xi in range(K):
        # only the lower triangle is computed; each value is mirrored
        for yi in range(xi + 1):
            nobs = 0
            # Keep track of whether we need to recompute ranks
            all_ranks = True
            for i in range(N):
                all_ranks &= not (mask[i, xi] ^ mask[i, yi])
                if mask[i, xi] and mask[i, yi]:
                    nobs += 1

            if nobs < minp:
                result[xi, yi] = result[yi, xi] = NaN
            else:
                maskedx = np.empty(nobs, dtype=np.float64)
                maskedy = np.empty(nobs, dtype=np.float64)
                j = 0

                for i in range(N):
                    if mask[i, xi] and mask[i, yi]:
                        maskedx[j] = ranked_mat[i, xi]
                        maskedy[j] = ranked_mat[i, yi]
                        j += 1

                if not all_ranks:
                    maskedx = rank_1d(maskedx)
                    maskedy = rank_1d(maskedy)

                # value subtracted from every rank before forming the sums
                mean = (nobs + 1) / 2.

                # now the cov numerator
                sumx = sumxx = sumyy = 0

                for i in range(nobs):
                    vx = maskedx[i] - mean
                    vy = maskedy[i] - mean

                    sumx += vx * vy
                    sumxx += vx * vx
                    sumyy += vy * vy

                divisor = sqrt(sumxx * sumyy)

                if divisor != 0:
                    result[xi, yi] = result[yi, xi] = sumx / divisor
                else:
                    result[xi, yi] = result[yi, xi] = NaN

    return result
390
391
392# ----------------------------------------------------------------------
393
# Fused type listing the element types for which the pad/backfill routines
# and ``is_monotonic`` below are specialised.
ctypedef fused algos_t:
    float64_t
    float32_t
    object
    int64_t
    int32_t
    int16_t
    int8_t
    uint64_t
    uint32_t
    uint16_t
    uint8_t
406
407
def validate_limit(nobs: int, limit=None) -> int:
    """
    Resolve the ``limit`` argument to a usable fill limit.

    Parameters
    ----------
    nobs : int
        Value used as the limit when ``limit`` is None.
    limit : object
        None, or an integer of at least 1.

    Returns
    -------
    int
        ``nobs`` when ``limit`` is None, otherwise ``limit``.

    Raises
    ------
    ValueError
        If ``limit`` is not an integer object or is smaller than 1.
    """
    if limit is None:
        return nobs

    if not util.is_integer_object(limit):
        raise ValueError('Limit must be an integer')
    if limit < 1:
        raise ValueError('Limit must be greater than 0')

    return limit
432
433
@cython.boundscheck(False)
@cython.wraparound(False)
def pad(ndarray[algos_t] old, ndarray[algos_t] new, limit=None):
    """
    Build a forward-fill indexer from ``new`` into ``old``.

    Parameters
    ----------
    old : ndarray
        Values to fill from.
    new : ndarray
        Values to locate.
    limit : int, optional
        Checked by ``validate_limit``; defaults to ``len(new)``. Caps how
        many non-exact matches are assigned to one ``old`` position.

    Returns
    -------
    ndarray[int64_t, ndim=1]
        For each element of ``new``, the position in ``old`` it is matched
        to, or -1 when no position was assigned.

    Notes
    -----
    NOTE(review): the single forward pass appears to assume both inputs are
    sorted ascending -- confirm against callers.
    """
    cdef:
        Py_ssize_t i, j, nleft, nright
        ndarray[int64_t, ndim=1] indexer
        algos_t cur, next_val
        int lim, fill_count = 0

    nleft = len(old)
    nright = len(new)
    indexer = np.empty(nright, dtype=np.int64)
    indexer[:] = -1

    lim = validate_limit(nright, limit)

    # nothing to match: an input is empty or the last element of ``new`` is
    # below the first element of ``old``
    if nleft == 0 or nright == 0 or new[nright - 1] < old[0]:
        return indexer

    i = j = 0

    cur = old[0]

    # leading elements of ``new`` below old[0] keep the -1 default
    while j <= nright - 1 and new[j] < cur:
        j += 1

    while True:
        if j == nright:
            break

        if i == nleft - 1:
            # last ``old`` element: it covers all remaining ``new`` values
            while j < nright:
                if new[j] == cur:
                    indexer[j] = i
                elif new[j] > cur and fill_count < lim:
                    indexer[j] = i
                    fill_count += 1
                j += 1
            break

        next_val = old[i + 1]

        # values in [cur, next_val) map to position i; only non-exact
        # matches count towards the limit
        while j < nright and cur <= new[j] < next_val:
            if new[j] == cur:
                indexer[j] = i
            elif fill_count < lim:
                indexer[j] = i
                fill_count += 1
            j += 1

        # the limit applies separately to each ``old`` position
        fill_count = 0
        i += 1
        cur = next_val

    return indexer
489
490
@cython.boundscheck(False)
@cython.wraparound(False)
def pad_inplace(algos_t[:] values, const uint8_t[:] mask, limit=None):
    """
    Forward-fill the masked entries of ``values`` in place.

    Parameters
    ----------
    values : 1-d buffer
        Modified in place.
    mask : 1-d uint8 buffer
        Non-zero marks an entry to be filled.
    limit : int, optional
        Maximum number of consecutive masked entries to fill; checked by
        ``validate_limit`` and defaulting to ``len(values)``.
    """
    cdef:
        Py_ssize_t pos, length
        algos_t carried
        int max_fill, run = 0

    length = len(values)

    # GH#2778
    if length == 0:
        return

    max_fill = validate_limit(length, limit)

    carried = values[0]
    for pos in range(length):
        if not mask[pos]:
            # unmasked entry: remember it and restart the run counter
            run = 0
            carried = values[pos]
        elif run < max_fill:
            run += 1
            values[pos] = carried
517
518
@cython.boundscheck(False)
@cython.wraparound(False)
def pad_2d_inplace(algos_t[:, :] values, const uint8_t[:, :] mask, limit=None):
    """
    Forward-fill the masked entries of each row of ``values`` in place.

    Parameters
    ----------
    values : 2-d buffer
        Modified in place; every row is filled independently along its
        second axis.
    mask : 2-d uint8 buffer
        Non-zero marks an entry to be filled.
    limit : int, optional
        Maximum number of consecutive masked entries to fill per row;
        checked by ``validate_limit`` and defaulting to the row length.
    """
    cdef:
        Py_ssize_t col, row, ncols, nrows
        algos_t carried
        int max_fill, run = 0

    nrows, ncols = (<object>values).shape

    # GH#2778
    if ncols == 0:
        return

    max_fill = validate_limit(ncols, limit)

    for row in range(nrows):
        run = 0
        carried = values[row, 0]
        for col in range(ncols):
            if not mask[row, col]:
                # unmasked entry: remember it and restart the run counter
                run = 0
                carried = values[row, col]
            elif run < max_fill:
                run += 1
                values[row, col] = carried
547
548
549"""
550Backfilling logic for generating fill vector
551
552Diagram of what's going on
553
554Old      New    Fill vector    Mask
555         .        0               1
556         .        0               1
557         .        0               1
558A        A        0               1
559         .        1               1
560         .        1               1
561         .        1               1
562         .        1               1
563         .        1               1
564B        B        1               1
565         .        2               1
566         .        2               1
567         .        2               1
568C        C        2               1
569         .                        0
570         .                        0
571D
572"""
573
574
@cython.boundscheck(False)
@cython.wraparound(False)
def backfill(ndarray[algos_t] old, ndarray[algos_t] new, limit=None) -> ndarray:
    """
    Build a backward-fill indexer from ``new`` into ``old``.

    Parameters
    ----------
    old : ndarray
        Values to fill from.
    new : ndarray
        Values to locate.
    limit : int, optional
        Checked by ``validate_limit``; defaults to ``len(new)``. Caps how
        many non-exact matches are assigned to one ``old`` position.

    Returns
    -------
    ndarray[int64_t, ndim=1]
        For each element of ``new``, the position in ``old`` it is matched
        to, or -1 when no position was assigned.

    Notes
    -----
    NOTE(review): the single backward pass appears to assume both inputs
    are sorted ascending -- confirm against callers.
    """
    cdef:
        Py_ssize_t i, j, nleft, nright
        ndarray[int64_t, ndim=1] indexer
        algos_t cur, prev
        int lim, fill_count = 0

    nleft = len(old)
    nright = len(new)
    indexer = np.empty(nright, dtype=np.int64)
    indexer[:] = -1

    lim = validate_limit(nright, limit)

    # nothing to match: an input is empty or the first element of ``new``
    # is above the last element of ``old``
    if nleft == 0 or nright == 0 or new[0] > old[nleft - 1]:
        return indexer

    # both cursors walk from the end towards the start
    i = nleft - 1
    j = nright - 1

    cur = old[nleft - 1]

    # trailing elements of ``new`` above the last ``old`` value keep -1
    while j >= 0 and new[j] > cur:
        j -= 1

    while True:
        if j < 0:
            break

        if i == 0:
            # first ``old`` element: it covers all remaining ``new`` values
            while j >= 0:
                if new[j] == cur:
                    indexer[j] = i
                elif new[j] < cur and fill_count < lim:
                    indexer[j] = i
                    fill_count += 1
                j -= 1
            break

        prev = old[i - 1]

        # values in (prev, cur] map to position i; only non-exact matches
        # count towards the limit (the ``new[j] < cur`` test below is
        # already implied by the loop condition once equality is excluded)
        while j >= 0 and prev < new[j] <= cur:
            if new[j] == cur:
                indexer[j] = i
            elif new[j] < cur and fill_count < lim:
                indexer[j] = i
                fill_count += 1
            j -= 1

        # the limit applies separately to each ``old`` position
        fill_count = 0
        i -= 1
        cur = prev

    return indexer
631
632
@cython.boundscheck(False)
@cython.wraparound(False)
def backfill_inplace(algos_t[:] values, const uint8_t[:] mask, limit=None):
    """
    Backward-fill the masked entries of ``values`` in place.

    Parameters
    ----------
    values : 1-d buffer
        Modified in place.
    mask : 1-d uint8 buffer
        Non-zero marks an entry to be filled.
    limit : int, optional
        Maximum number of consecutive masked entries to fill; checked by
        ``validate_limit`` and defaulting to ``len(values)``.
    """
    cdef:
        Py_ssize_t pos, length
        algos_t carried
        int max_fill, run = 0

    length = len(values)

    # GH#2778
    if length == 0:
        return

    max_fill = validate_limit(length, limit)

    carried = values[length - 1]
    # walk from the last entry to the first
    for pos in range(length - 1, -1, -1):
        if not mask[pos]:
            # unmasked entry: remember it and restart the run counter
            run = 0
            carried = values[pos]
        elif run < max_fill:
            run += 1
            values[pos] = carried
659
660
@cython.boundscheck(False)
@cython.wraparound(False)
def backfill_2d_inplace(algos_t[:, :] values,
                        const uint8_t[:, :] mask,
                        limit=None):
    """
    Backward-fill the masked entries of each row of ``values`` in place.

    Parameters
    ----------
    values : 2-d buffer
        Modified in place; every row is filled independently along its
        second axis.
    mask : 2-d uint8 buffer
        Non-zero marks an entry to be filled.
    limit : int, optional
        Maximum number of consecutive masked entries to fill per row;
        checked by ``validate_limit`` and defaulting to the row length.
    """
    cdef:
        Py_ssize_t col, row, ncols, nrows
        algos_t carried
        int max_fill, run = 0

    nrows, ncols = (<object>values).shape

    # GH#2778
    if ncols == 0:
        return

    max_fill = validate_limit(ncols, limit)

    for row in range(nrows):
        run = 0
        carried = values[row, ncols - 1]
        # walk each row from its last entry to its first
        for col in range(ncols - 1, -1, -1):
            if not mask[row, col]:
                # unmasked entry: remember it and restart the run counter
                run = 0
                carried = values[row, col]
            elif run < max_fill:
                run += 1
                values[row, col] = carried
691
692
@cython.boundscheck(False)
@cython.wraparound(False)
def is_monotonic(ndarray[algos_t, ndim=1] arr, bint timelike):
    """
    Scan ``arr`` once and report its monotonicity.

    Parameters
    ----------
    arr : ndarray, 1-d
    timelike : bool
        When True, values equal to ``NPY_NAT`` (after an int64 cast) are
        treated as missing and make both monotonic flags False.

    Returns
    -------
    tuple
        is_monotonic_inc : bool
            No element was smaller than its predecessor.
        is_monotonic_dec : bool
            No element was larger than its predecessor.
        is_strict_monotonic : bool
            No two adjacent elements compared equal during the scan and at
            least one of the two flags above is True. The early returns
            (length below 2, or a leading NaT when ``timelike``) report
            True here.
    """
    cdef:
        Py_ssize_t i, n
        algos_t prev, cur
        bint is_monotonic_inc = 1
        bint is_monotonic_dec = 1
        bint is_unique = 1
        bint is_strict_monotonic = 1

    n = len(arr)

    if n == 1:
        if arr[0] != arr[0] or (timelike and <int64_t>arr[0] == NPY_NAT):
            # single value is NaN
            return False, False, True
        else:
            return True, True, True
    elif n < 2:
        # empty input
        return True, True, True

    if timelike and <int64_t>arr[0] == NPY_NAT:
        return False, False, True

    if algos_t is not object:
        with nogil:
            prev = arr[0]
            for i in range(1, n):
                cur = arr[i]
                if timelike and <int64_t>cur == NPY_NAT:
                    is_monotonic_inc = 0
                    is_monotonic_dec = 0
                    break
                if cur < prev:
                    is_monotonic_inc = 0
                elif cur > prev:
                    is_monotonic_dec = 0
                elif cur == prev:
                    is_unique = 0
                else:
                    # cur or prev is NaN
                    is_monotonic_inc = 0
                    is_monotonic_dec = 0
                    break
                if not is_monotonic_inc and not is_monotonic_dec:
                    # neither direction can hold any more; the assignments
                    # restate values that are already 0
                    is_monotonic_inc = 0
                    is_monotonic_dec = 0
                    break
                prev = cur
    else:
        # object-dtype, identical to above except we cannot use `with nogil`
        prev = arr[0]
        for i in range(1, n):
            cur = arr[i]
            if timelike and <int64_t>cur == NPY_NAT:
                is_monotonic_inc = 0
                is_monotonic_dec = 0
                break
            if cur < prev:
                is_monotonic_inc = 0
            elif cur > prev:
                is_monotonic_dec = 0
            elif cur == prev:
                is_unique = 0
            else:
                # cur or prev is NaN
                is_monotonic_inc = 0
                is_monotonic_dec = 0
                break
            if not is_monotonic_inc and not is_monotonic_dec:
                # neither direction can hold any more; the assignments
                # restate values that are already 0
                is_monotonic_inc = 0
                is_monotonic_dec = 0
                break
            prev = cur

    is_strict_monotonic = is_unique and (is_monotonic_inc or is_monotonic_dec)
    return is_monotonic_inc, is_monotonic_dec, is_strict_monotonic
779
780
781# ----------------------------------------------------------------------
782# rank_1d, rank_2d
783# ----------------------------------------------------------------------
784
# Fused type listing the element types for which ``rank_1d`` and ``rank_2d``
# are specialised.
ctypedef fused rank_t:
    object
    float64_t
    uint64_t
    int64_t
790
791
@cython.wraparound(False)
@cython.boundscheck(False)
def rank_1d(
    ndarray[rank_t, ndim=1] in_arr,
    ties_method="average",
    bint ascending=True,
    na_option="keep",
    bint pct=False,
):
    """
    Fast NaN-friendly version of ``scipy.stats.rankdata``.

    Parameters
    ----------
    in_arr : ndarray, 1-d
        Values to rank; a copy is ranked, the input is not modified here.
    ties_method : str, default "average"
        Key of the module-level ``tiebreakers`` dict ("average", "min",
        "max", "first" or "dense").
    ascending : bool, default True
        If False, the sort order is reversed before ranks are assigned.
    na_option : str, default "keep"
        "keep" gives missing entries a rank of NaN. "top" is the only other
        value inspected explicitly; together with ``ascending`` it decides
        at which end missing entries are sorted.
    pct : bool, default False
        If True, divide the ranks by the number of ranked entries (by the
        number of distinct groups for "dense").

    Returns
    -------
    ndarray[float64_t]
        Rank of every element of ``in_arr``.

    Raises
    ------
    KeyError
        If ``ties_method`` is not a key of ``tiebreakers``.
    ValueError
        If ``ties_method`` is "first" and the data are object dtype.
    """
    cdef:
        Py_ssize_t i, j, n, dups = 0, total_tie_count = 0, non_na_idx = 0
        ndarray[rank_t] sorted_data, values
        ndarray[float64_t] ranks
        ndarray[int64_t] argsorted
        ndarray[uint8_t, cast=True] sorted_mask
        rank_t val, nan_value
        float64_t sum_ranks = 0
        int tiebreak = 0
        bint keep_na = False
        bint isnan, condition
        float64_t count = 0.0

    tiebreak = tiebreakers[ties_method]

    # work on a copy: missing entries are overwritten with a sentinel below
    if rank_t is float64_t:
        values = np.asarray(in_arr).copy()
    elif rank_t is object:
        values = np.array(in_arr, copy=True)

        if values.dtype != np.object_:
            values = values.astype('O')
    else:
        values = np.asarray(in_arr).copy()

    keep_na = na_option == 'keep'

    # missing-value mask; the uint64 specialisation gets an all-False mask
    # further down
    if rank_t is object:
        mask = missing.isnaobj(values)
    elif rank_t is float64_t:
        mask = np.isnan(values)
    elif rank_t is int64_t:
        mask = values == NPY_NAT

    # double sort first by mask and then by values to ensure nan values are
    # either at the beginning or the end. mask/(~mask) controls padding at
    # tail or the head
    if rank_t is not uint64_t:
        if ascending ^ (na_option == 'top'):
            if rank_t is object:
                nan_value = Infinity()
            elif rank_t is float64_t:
                nan_value = np.inf
            elif rank_t is int64_t:
                nan_value = np.iinfo(np.int64).max

            order = (values, mask)
        else:
            if rank_t is object:
                nan_value = NegInfinity()
            elif rank_t is float64_t:
                nan_value = -np.inf
            elif rank_t is int64_t:
                nan_value = np.iinfo(np.int64).min

            order = (values, ~mask)
        np.putmask(values, mask, nan_value)
    else:
        mask = np.zeros(shape=len(values), dtype=bool)
        order = (values, mask)

    n = len(values)
    ranks = np.empty(n, dtype='f8')

    # every branch below sorts with np.lexsort; the only extra effect is
    # switching to the descending variant of the "first" tiebreak
    if rank_t is object:
        _as = np.lexsort(keys=order)
    else:
        if tiebreak == TIEBREAK_FIRST:
            # need to use a stable sort here
            _as = np.lexsort(keys=order)
            if not ascending:
                tiebreak = TIEBREAK_FIRST_DESCENDING
        else:
            _as = np.lexsort(keys=order)

    if not ascending:
        _as = _as[::-1]

    sorted_data = values.take(_as)
    sorted_mask = mask.take(_as)
    # first position where the sorted mask changes value, i.e. the boundary
    # between missing and non-missing entries; -1 when there is none
    _indices = np.diff(sorted_mask.astype(int)).nonzero()[0]
    non_na_idx = _indices[0] if len(_indices) > 0 else -1
    argsorted = _as.astype('i8')

    if rank_t is object:
        # TODO: de-duplicate once cython supports conditional nogil
        for i in range(n):
            # accumulate the current run of equal values
            sum_ranks += i + 1
            dups += 1

            val = sorted_data[i]

            if rank_t is not uint64_t:
                isnan = sorted_mask[i]
                if isnan and keep_na:
                    ranks[argsorted[i]] = NaN
                    continue

            count += 1.0

            # the run ends at the last element, when the next value differs,
            # or at the missing/non-missing boundary
            if rank_t is object:
                condition = (
                    i == n - 1 or
                    are_diff(sorted_data[i + 1], val) or
                    i == non_na_idx
                )
            else:
                condition = (
                    i == n - 1 or
                    sorted_data[i + 1] != val or
                    i == non_na_idx
                )

            if condition:

                # assign ranks to the whole run [i - dups + 1, i]
                if tiebreak == TIEBREAK_AVERAGE:
                    for j in range(i - dups + 1, i + 1):
                        ranks[argsorted[j]] = sum_ranks / dups
                elif tiebreak == TIEBREAK_MIN:
                    for j in range(i - dups + 1, i + 1):
                        ranks[argsorted[j]] = i - dups + 2
                elif tiebreak == TIEBREAK_MAX:
                    for j in range(i - dups + 1, i + 1):
                        ranks[argsorted[j]] = i + 1
                elif tiebreak == TIEBREAK_FIRST:
                    if rank_t is object:
                        raise ValueError('first not supported for non-numeric data')
                    else:
                        for j in range(i - dups + 1, i + 1):
                            ranks[argsorted[j]] = j + 1
                elif tiebreak == TIEBREAK_FIRST_DESCENDING:
                    for j in range(i - dups + 1, i + 1):
                        ranks[argsorted[j]] = 2 * i - j - dups + 2
                elif tiebreak == TIEBREAK_DENSE:
                    total_tie_count += 1
                    for j in range(i - dups + 1, i + 1):
                        ranks[argsorted[j]] = total_tie_count
                sum_ranks = dups = 0

    else:
        with nogil:
            # TODO: why does the 2d version not have a nogil block?
            for i in range(n):
                # accumulate the current run of equal values
                sum_ranks += i + 1
                dups += 1

                val = sorted_data[i]

                if rank_t is not uint64_t:
                    isnan = sorted_mask[i]
                    if isnan and keep_na:
                        ranks[argsorted[i]] = NaN
                        continue

                count += 1.0

                # the run ends at the last element, when the next value
                # differs, or at the missing/non-missing boundary
                if rank_t is object:
                    condition = (
                        i == n - 1 or
                        are_diff(sorted_data[i + 1], val) or
                        i == non_na_idx
                    )
                else:
                    condition = (
                        i == n - 1 or
                        sorted_data[i + 1] != val or
                        i == non_na_idx
                    )

                if condition:

                    # assign ranks to the whole run [i - dups + 1, i]
                    if tiebreak == TIEBREAK_AVERAGE:
                        for j in range(i - dups + 1, i + 1):
                            ranks[argsorted[j]] = sum_ranks / dups
                    elif tiebreak == TIEBREAK_MIN:
                        for j in range(i - dups + 1, i + 1):
                            ranks[argsorted[j]] = i - dups + 2
                    elif tiebreak == TIEBREAK_MAX:
                        for j in range(i - dups + 1, i + 1):
                            ranks[argsorted[j]] = i + 1
                    elif tiebreak == TIEBREAK_FIRST:
                        if rank_t is object:
                            raise ValueError('first not supported for non-numeric data')
                        else:
                            for j in range(i - dups + 1, i + 1):
                                ranks[argsorted[j]] = j + 1
                    elif tiebreak == TIEBREAK_FIRST_DESCENDING:
                        for j in range(i - dups + 1, i + 1):
                            ranks[argsorted[j]] = 2 * i - j - dups + 2
                    elif tiebreak == TIEBREAK_DENSE:
                        total_tie_count += 1
                        for j in range(i - dups + 1, i + 1):
                            ranks[argsorted[j]] = total_tie_count
                    sum_ranks = dups = 0

    if pct:
        if tiebreak == TIEBREAK_DENSE:
            return ranks / total_tie_count
        else:
            return ranks / count
    else:
        return ranks
1006
1007
def rank_2d(
    ndarray[rank_t, ndim=2] in_arr,
    int axis=0,
    ties_method="average",
    bint ascending=True,
    na_option="keep",
    bint pct=False,
):
    """
    Fast NaN-friendly version of ``scipy.stats.rankdata``.

    Parameters
    ----------
    in_arr : ndarray[rank_t, ndim=2]
        Values to rank. The input itself is not modified; a copy is ranked.
    axis : int, default 0
        With ``axis == 0`` the input is transposed before ranking and the
        result is transposed back, so each column is ranked independently.
        Any other value ranks each row independently.
    ties_method : str, default "average"
        Key into the module-level ``tiebreakers`` mapping
        ("average", "min", "max", "first" or "dense").
    ascending : bool, default True
        If False, the sort order is reversed before ranks are assigned.
    na_option : str, default "keep"
        "keep" assigns NaN to missing values, "top" places missing values
        first in the ranking order, and any other value places them last.
        Missing values are not handled for the ``uint64_t`` specialization.
    pct : bool, default False
        If True, every ranked slice is divided by the number of ranked
        (non-skipped) entries, or by the number of distinct tie groups when
        ``ties_method`` is "dense".

    Returns
    -------
    ndarray[float64_t, ndim=2]
        Ranks with the same orientation as ``in_arr``.

    Raises
    ------
    KeyError
        If ``ties_method`` is not a key of ``tiebreakers``.
    ValueError
        If ``ties_method`` is "first" and the data is of object dtype.
    """
    cdef:
        Py_ssize_t i, j, z, k, n, dups = 0, total_tie_count = 0
        Py_ssize_t infs
        ndarray[float64_t, ndim=2] ranks
        ndarray[rank_t, ndim=2] values
        ndarray[int64_t, ndim=2] argsorted
        rank_t val, nan_value
        float64_t sum_ranks = 0
        int tiebreak = 0
        bint keep_na = False
        float64_t count = 0.0
        bint condition, skip_condition

    tiebreak = tiebreakers[ties_method]

    keep_na = na_option == 'keep'

    # Always work row-wise on a private copy: for axis=0 each column of the
    # input becomes a row here and the result is transposed back on return.
    if axis == 0:
        values = np.asarray(in_arr).T.copy()
    else:
        values = np.asarray(in_arr).copy()

    if rank_t is object:
        if values.dtype != np.object_:
            values = values.astype('O')

    if rank_t is not uint64_t:
        # Replace missing values by a sentinel that sorts to the required end.
        # NOTE(review): for the numeric specializations a genuine value equal
        #  to the sentinel (e.g. +/-inf) cannot be told apart from a missing
        #  value further down, where ``val == nan_value`` is tested.
        if ascending ^ (na_option == 'top'):
            if rank_t is object:
                nan_value = Infinity()
            elif rank_t is float64_t:
                nan_value = np.inf
            elif rank_t is int64_t:
                nan_value = np.iinfo(np.int64).max

        else:
            if rank_t is object:
                nan_value = NegInfinity()
            elif rank_t is float64_t:
                nan_value = -np.inf
            elif rank_t is int64_t:
                nan_value = NPY_NAT

        if rank_t is object:
            mask = missing.isnaobj2d(values)
        elif rank_t is float64_t:
            mask = np.isnan(values)
        elif rank_t is int64_t:
            mask = values == NPY_NAT

        np.putmask(values, mask, nan_value)

    n, k = (<object>values).shape
    ranks = np.empty((n, k), dtype='f8')

    if rank_t is object:
        try:
            _as = values.argsort(1)
        except TypeError:
            # The objects could not be sorted as a 2D block; rank every slice
            # on its own instead. The slices must come from the axis-oriented
            # view of the *original* data (no sentinels inserted) so that they
            # line up with the (n, k) layout of ``ranks``, which is transposed
            # back below when axis == 0.
            # NOTE(review): ``na_option`` is not forwarded to rank_1d here, so
            #  this path relies on rank_1d's default handling; confirm against
            #  rank_1d's signature before forwarding it.
            if axis == 0:
                oriented = np.asarray(in_arr).T
            else:
                oriented = np.asarray(in_arr)
            for i in range(n):
                ranks[i] = rank_1d(oriented[i], ties_method=ties_method,
                                   ascending=ascending, pct=pct)
            if axis == 0:
                return ranks.T
            else:
                return ranks
    else:
        if tiebreak == TIEBREAK_FIRST:
            # need to use a stable sort here
            _as = values.argsort(axis=1, kind='mergesort')
            if not ascending:
                tiebreak = TIEBREAK_FIRST_DESCENDING
        else:
            _as = values.argsort(1)

    if not ascending:
        _as = _as[:, ::-1]

    # ``values`` now holds each row in ranking order; ``argsorted`` maps a
    # sorted position back to its original column within the row.
    values = _take_2d(values, _as)
    argsorted = _as.astype('i8')

    for i in range(n):
        if rank_t is object:
            dups = sum_ranks = infs = 0
        else:
            dups = sum_ranks = 0

        total_tie_count = 0
        count = 0.0
        for j in range(k):
            if rank_t is not object:
                sum_ranks += j + 1
                dups += 1

            val = values[i, j]

            if rank_t is not uint64_t:
                if rank_t is object:
                    skip_condition = (val is nan_value) and keep_na
                else:
                    skip_condition = (val == nan_value) and keep_na
                if skip_condition:
                    # Missing value with na_option='keep': rank is NaN and the
                    # entry does not contribute to ``count``.
                    ranks[i, argsorted[i, j]] = NaN

                    if rank_t is object:
                        infs += 1

                    continue

            count += 1.0

            if rank_t is object:
                # Offset by the number of skipped entries seen so far.
                sum_ranks += (j - infs) + 1
                dups += 1

            # A tie group ends at the last element of the row or when the
            # next sorted value differs from the current one.
            if rank_t is object:
                condition = j == k - 1 or are_diff(values[i, j + 1], val)
            else:
                condition = j == k - 1 or values[i, j + 1] != val

            if condition:
                # Positions j - dups + 1 .. j form the tie group just closed.
                if tiebreak == TIEBREAK_AVERAGE:
                    for z in range(j - dups + 1, j + 1):
                        ranks[i, argsorted[i, z]] = sum_ranks / dups
                elif tiebreak == TIEBREAK_MIN:
                    for z in range(j - dups + 1, j + 1):
                        ranks[i, argsorted[i, z]] = j - dups + 2
                elif tiebreak == TIEBREAK_MAX:
                    for z in range(j - dups + 1, j + 1):
                        ranks[i, argsorted[i, z]] = j + 1
                elif tiebreak == TIEBREAK_FIRST:
                    if rank_t is object:
                        raise ValueError('first not supported for non-numeric data')
                    else:
                        for z in range(j - dups + 1, j + 1):
                            ranks[i, argsorted[i, z]] = z + 1
                elif tiebreak == TIEBREAK_FIRST_DESCENDING:
                    # Same as "first" but numbered in reverse within the
                    # group, because the sort order was flipped above.
                    for z in range(j - dups + 1, j + 1):
                        ranks[i, argsorted[i, z]] = 2 * j - z - dups + 2
                elif tiebreak == TIEBREAK_DENSE:
                    total_tie_count += 1
                    for z in range(j - dups + 1, j + 1):
                        ranks[i, argsorted[i, z]] = total_tie_count
                sum_ranks = dups = 0
        if pct:
            if tiebreak == TIEBREAK_DENSE:
                ranks[i, :] /= total_tie_count
            else:
                ranks[i, :] /= count
    if axis == 0:
        return ranks.T
    else:
        return ranks
1173
1174
# Element types accepted for the input array ``arr`` of ``diff_2d``.
ctypedef fused diff_t:
    float64_t
    float32_t
    int8_t
    int16_t
    int32_t
    int64_t
1182
# Element types accepted for the output array ``out`` of ``diff_2d``.
ctypedef fused out_t:
    float32_t
    float64_t
    int64_t
1187
1188
@cython.boundscheck(False)
@cython.wraparound(False)
def diff_2d(
    ndarray[diff_t, ndim=2] arr,  # TODO(cython 3) update to "const diff_t[:, :] arr"
    ndarray[out_t, ndim=2] out,
    Py_ssize_t periods,
    int axis,
    bint datetimelike=False,
):
    """
    Write the ``periods``-lagged difference of ``arr`` along ``axis`` into ``out``.

    Parameters
    ----------
    arr : ndarray[diff_t, ndim=2]
        Input values.
    out : ndarray[out_t, ndim=2]
        Destination, filled in place. Only positions that have a partner
        ``periods`` steps away along ``axis`` are written; all other
        positions are left untouched.
    periods : Py_ssize_t
        Lag of the difference; each written element is
        ``arr[pos] - arr[pos - periods]`` along ``axis``. May be negative.
    axis : int
        0 differences along the first dimension; any other value differences
        along the second dimension.
    datetimelike : bool, default False
        When True and ``out`` is int64, a result is set to ``NPY_NAT``
        whenever either operand equals ``NPY_NAT``.

    Raises
    ------
    NotImplementedError
        For the input/output dtype combinations rejected below.
    """
    cdef:
        Py_ssize_t row, col, nrows, ncols, lo, hi
        bint f_contig = arr.flags.f_contiguous
        # bint f_contig = arr.is_f_contig()  # TODO(cython 3)
        diff_t left, right

    # Disable for unsupported dtype combinations,
    #  see https://github.com/cython/cython/issues/2646
    if (out_t is float32_t
            and not (diff_t is float32_t or diff_t is int8_t or diff_t is int16_t)):
        raise NotImplementedError
    elif (out_t is float64_t
          and (diff_t is float32_t or diff_t is int8_t or diff_t is int16_t)):
        raise NotImplementedError
    elif out_t is int64_t and diff_t is not int64_t:
        # We only have out_t of int64_t if we have datetimelike
        raise NotImplementedError
    else:
        # We put this inside an indented else block to avoid cython build
        #  warnings about unreachable code
        nrows, ncols = (<object>arr).shape

        # Half-open range [lo, hi) of positions along ``axis`` that have a
        # partner ``periods`` steps away.
        if axis == 0:
            hi = nrows
        else:
            hi = ncols
        if periods >= 0:
            lo = periods
        else:
            lo = 0
            hi = hi + periods

        with nogil:
            # The nesting order of the two loops is selected by ``f_contig``:
            # columns outermost for F-contiguous input, rows outermost
            # otherwise.
            if axis == 0:
                if f_contig:
                    for col in range(ncols):
                        for row in range(lo, hi):
                            left = arr[row, col]
                            right = arr[row - periods, col]
                            if out_t is int64_t and datetimelike:
                                if left == NPY_NAT or right == NPY_NAT:
                                    out[row, col] = NPY_NAT
                                else:
                                    out[row, col] = left - right
                            else:
                                out[row, col] = left - right
                else:
                    for row in range(lo, hi):
                        for col in range(ncols):
                            left = arr[row, col]
                            right = arr[row - periods, col]
                            if out_t is int64_t and datetimelike:
                                if left == NPY_NAT or right == NPY_NAT:
                                    out[row, col] = NPY_NAT
                                else:
                                    out[row, col] = left - right
                            else:
                                out[row, col] = left - right
            else:
                if f_contig:
                    for col in range(lo, hi):
                        for row in range(nrows):
                            left = arr[row, col]
                            right = arr[row, col - periods]
                            if out_t is int64_t and datetimelike:
                                if left == NPY_NAT or right == NPY_NAT:
                                    out[row, col] = NPY_NAT
                                else:
                                    out[row, col] = left - right
                            else:
                                out[row, col] = left - right
                else:
                    for row in range(nrows):
                        for col in range(lo, hi):
                            left = arr[row, col]
                            right = arr[row, col - periods]
                            if out_t is int64_t and datetimelike:
                                if left == NPY_NAT or right == NPY_NAT:
                                    out[row, col] = NPY_NAT
                                else:
                                    out[row, col] = left - right
                            else:
                                out[row, col] = left - right
1286
1287
1288# generated from template
1289include "algos_common_helper.pxi"
1290include "algos_take_helper.pxi"
1291