1#cython: language_level=3
2from libc cimport math
3cimport cython
4cimport numpy as np
5
6
# Fused element type used by ks_2samp below. Cython compiles a separate
# specialization for each listed type: unsigned and signed integers of
# 8/16/32/64 bits, float32, float64 and long double. Because ks_2samp declares
# both of its arguments with this single fused type, the two arrays passed in
# one call must have the same element type.
ctypedef fused dtype:
    np.uint8_t
    np.uint16_t
    np.uint32_t
    np.uint64_t
    np.int8_t
    np.int16_t
    np.int32_t
    np.int64_t
    np.float32_t
    np.float64_t
    np.longdouble_t
19
20
@cython.boundscheck(False)
@cython.nonecheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cpdef double ks_2samp(dtype[:] data1, dtype[:] data2):
    """Two-sample Kolmogorov-Smirnov statistic for pre-sorted samples.

    Computes D = sup_x |F1(x) - F2(x)|, where F1 and F2 are the empirical
    CDFs of ``data1`` and ``data2``. It does this with one linear merge over
    both arrays, O(n1 + n2), with no allocation.

    Parameters
    ----------
    data1, data2 : 1-D memoryviews sharing the same ``dtype``
        Samples to compare. Both MUST already be sorted in ascending order,
        with any NaNs last (the order ``numpy.sort`` produces). This is not
        checked, and unsorted input yields a meaningless value.

    Returns
    -------
    double
        The KS statistic, in [0, 1].
        - 0.0 if either sample is empty. The statistic is undefined there;
          0.0 is kept for backward compatibility.
        - NaN if a NaN is found in either sample.
    """
    cdef:
        size_t i = 0, j = 0, n1 = data1.shape[0], n2 = data2.shape[0]
        dtype d1i, d2j
        bint step1, step2
        double d, dmax = 0
    if n1 == 0 or n2 == 0:
        return 0.0
    # Sorted input keeps any NaN at the tail. Only NaN compares unequal to
    # itself; for integer dtypes this test is always false.
    if data1[n1 - 1] != data1[n1 - 1] or data2[n2 - 1] != data2[n2 - 1]:
        return math.NAN
    while i < n1 and j < n2:
        d1i = data1[i]
        d2j = data2[j]
        step1 = d1i <= d2j
        step2 = d1i >= d2j
        if not (step1 or step2):
            # Unordered comparison (a NaN inside an unsorted sample).
            # Neither index would advance, so stop instead of looping forever.
            return math.NAN
        if step1:
            # Consume every copy of d1i so that tied values are counted
            # together before the CDFs are compared.
            while i < n1 and data1[i] == d1i:
                i += 1
        if step2:
            while j < n2 and data2[j] == d2j:
                j += 1
        # Recompute the CDF gap from the counts rather than accumulating
        # 1/n per element, so rounding error does not grow with n.
        d = math.fabs((<double>i) / n1 - (<double>j) / n2)
        if d > dmax:
            dmax = d
    # Once one sample is exhausted its CDF is 1 and the other CDF only rises
    # toward 1. The remaining |F1 - F2| therefore only shrinks, so stopping
    # here cannot miss the supremum.
    return dmax
44