# Author: Leland McInnes <leland.mcinnes@gmail.com>
# Enough simple sparse operations in numba to enable sparse UMAP
#
# License: BSD 3 clause
from __future__ import print_function
import locale
import numpy as np
import numba

from pynndescent.utils import norm, tau_rand
from pynndescent.distances import kantorovich

locale.setlocale(locale.LC_NUMERIC, "C")

FLOAT32_EPS = np.finfo(np.float32).eps
FLOAT32_MAX = np.finfo(np.float32).max

# A simplified version of numpy.isclose (the numpy version is not yet
# supported by numba)
@numba.njit(cache=True)
def isclose(a, b, rtol=1.0e-5, atol=1.0e-8):
    diff = np.abs(a - b)
    return diff <= (atol + rtol * np.abs(b))


# A simplified version of numpy.unique (the numpy version is not yet
# supported by numba)
@numba.njit(cache=True)
def arr_unique(arr):
    aux = np.sort(arr)
    flag = np.concatenate((np.ones(1, dtype=np.bool_), aux[1:] != aux[:-1]))
    return aux[flag]


# A simplified version of numpy.union1d (the numpy version is not yet
# supported by numba)
@numba.njit(cache=True)
def arr_union(ar1, ar2):
    if ar1.shape[0] == 0:
        return ar2
    elif ar2.shape[0] == 0:
        return ar1
    else:
        return arr_unique(np.concatenate((ar1, ar2)))


# A simplified version of numpy.intersect1d (the numpy version is not yet
# supported by numba)
@numba.njit(cache=True)
def arr_intersect(ar1, ar2):
    aux = np.concatenate((ar1, ar2))
    aux.sort()
    return aux[:-1][aux[1:] == aux[:-1]]

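# Usage sketch (illustrative, not part of the library): these helpers operate
# on index arrays as produced by CSR rows; arr_intersect in particular assumes
# each input has no repeated indices. For example
#
#     a = np.array([0, 2, 5], dtype=np.int32)
#     b = np.array([2, 3], dtype=np.int32)
#
# arr_union(a, b) gives array([0, 2, 3, 5], dtype=int32) and
# arr_intersect(a, b) gives array([2], dtype=int32).
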

@numba.njit(
    [
        numba.types.Tuple(
            (
                numba.types.Array(numba.types.int32, 1, "C"),
                numba.types.Array(numba.types.float32, 1, "C"),
            )
        )(
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
        )
    ],
    fastmath=True,
    locals={
        "result_ind": numba.types.int32[::1],
        "result_data": numba.types.float32[::1],
        "val": numba.types.float32,
        "i1": numba.types.int32,
        "i2": numba.types.int32,
        "j1": numba.types.int32,
        "j2": numba.types.int32,
    },
    cache=True,
)
def sparse_sum(ind1, data1, ind2, data2):
    result_size = ind1.shape[0] + ind2.shape[0]
    result_ind = np.zeros(result_size, dtype=np.int32)
    result_data = np.zeros(result_size, dtype=np.float32)

    i1 = 0
    i2 = 0
    nnz = 0

    # pass through both index lists
    while i1 < ind1.shape[0] and i2 < ind2.shape[0]:
        j1 = ind1[i1]
        j2 = ind2[i2]

        if j1 == j2:
            val = data1[i1] + data2[i2]
            if val != 0:
                result_ind[nnz] = j1
                result_data[nnz] = val
                nnz += 1
            i1 += 1
            i2 += 1
        elif j1 < j2:
            val = data1[i1]
            if val != 0:
                result_ind[nnz] = j1
                result_data[nnz] = val
                nnz += 1
            i1 += 1
        else:
            val = data2[i2]
            if val != 0:
                result_ind[nnz] = j2
                result_data[nnz] = val
                nnz += 1
            i2 += 1

    # pass over the tails
    while i1 < ind1.shape[0]:
        j1 = ind1[i1]
        val = data1[i1]
        if val != 0:
            result_ind[nnz] = j1
            result_data[nnz] = val
            nnz += 1
        i1 += 1

    while i2 < ind2.shape[0]:
        j2 = ind2[i2]
        val = data2[i2]
        if val != 0:
            result_ind[nnz] = j2
            result_data[nnz] = val
            nnz += 1
        i2 += 1

    # truncate to the correct length in case there were zeros created
    result_ind = result_ind[:nnz]
    result_data = result_data[:nnz]

    return result_ind, result_data

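# Usage sketch (illustrative): sparse_sum merges two sorted index lists and
# drops entries that cancel to zero. With C-contiguous int32/float32 inputs
# (as the explicit signature requires), e.g.
#
#     ind1 = np.array([0, 2], dtype=np.int32)
#     data1 = np.array([1.0, 2.0], dtype=np.float32)
#     ind2 = np.array([2, 5], dtype=np.int32)
#     data2 = np.array([-2.0, 3.0], dtype=np.float32)
#
# sparse_sum(ind1, data1, ind2, data2) returns (array([0, 5]),
# array([1.0, 3.0])): the entries at index 2 cancel and are dropped.
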

@numba.njit(cache=True)
def sparse_diff(ind1, data1, ind2, data2):
    return sparse_sum(ind1, data1, ind2, -data2)


@numba.njit(
    [
        # "Tuple((i4[::1],f4[::1]))(i4[::1],f4[::1],i4[::1],f4[::1])",
        numba.types.Tuple(
            (
                numba.types.ListType(numba.types.int32),
                numba.types.ListType(numba.types.float32),
            )
        )(
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
        )
    ],
    fastmath=True,
    locals={
        "val": numba.types.float32,
        "i1": numba.types.int32,
        "i2": numba.types.int32,
        "j1": numba.types.int32,
        "j2": numba.types.int32,
    },
    cache=True,
)
def sparse_mul(ind1, data1, ind2, data2):
    result_ind = numba.typed.List.empty_list(numba.types.int32)
    result_data = numba.typed.List.empty_list(numba.types.float32)

    i1 = 0
    i2 = 0

    # pass through both index lists
    while i1 < ind1.shape[0] and i2 < ind2.shape[0]:
        j1 = ind1[i1]
        j2 = ind2[i2]

        if j1 == j2:
            val = data1[i1] * data2[i2]
            if val != 0:
                result_ind.append(j1)
                result_data.append(val)
            i1 += 1
            i2 += 1
        elif j1 < j2:
            i1 += 1
        else:
            i2 += 1

    return result_ind, result_data

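# Usage sketch (illustrative): sparse_mul keeps only indices present in both
# inputs, and returns numba typed lists rather than arrays. For the arrays in
# the sparse_sum sketch above, but with data2 = [4.0, 3.0],
# sparse_mul(ind1, data1, ind2, data2) yields indices [2] and values [8.0].
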

@numba.njit(cache=True)
def sparse_euclidean(ind1, data1, ind2, data2):
    _, aux_data = sparse_diff(ind1, data1, ind2, data2)
    result = 0.0
    for i in range(aux_data.shape[0]):
        result += aux_data[i] ** 2
    return np.sqrt(result)


@numba.njit(
    [
        "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
        numba.types.float32(
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
        ),
    ],
    fastmath=True,
    locals={
        "aux_data": numba.types.float32[::1],
        "result": numba.types.float32,
        "diff": numba.types.float32,
        "dim": numba.types.intp,
        "i": numba.types.uint16,
    },
    cache=True,
)
def sparse_squared_euclidean(ind1, data1, ind2, data2):
    _, aux_data = sparse_diff(ind1, data1, ind2, data2)
    result = 0.0
    dim = len(aux_data)
    for i in range(dim):
        result += aux_data[i] * aux_data[i]
    return result


@numba.njit(cache=True)
def sparse_manhattan(ind1, data1, ind2, data2):
    _, aux_data = sparse_diff(ind1, data1, ind2, data2)
    result = 0.0
    for i in range(aux_data.shape[0]):
        result += np.abs(aux_data[i])
    return result


@numba.njit(cache=True)
def sparse_chebyshev(ind1, data1, ind2, data2):
    _, aux_data = sparse_diff(ind1, data1, ind2, data2)
    result = 0.0
    for i in range(aux_data.shape[0]):
        result = max(result, np.abs(aux_data[i]))
    return result


@numba.njit(cache=True)
def sparse_minkowski(ind1, data1, ind2, data2, p=2.0):
    _, aux_data = sparse_diff(ind1, data1, ind2, data2)
    result = 0.0
    for i in range(aux_data.shape[0]):
        result += np.abs(aux_data[i]) ** p
    return result ** (1.0 / p)

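# Note (illustrative): sparse_minkowski generalizes the metrics above; with
# p=1.0 it matches sparse_manhattan and with p=2.0 it matches sparse_euclidean
# (up to floating point rounding), e.g.
#
#     sparse_minkowski(ind1, data1, ind2, data2, 1.0)
#
# agrees with sparse_manhattan(ind1, data1, ind2, data2).
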

@numba.njit(cache=True)
def sparse_hamming(ind1, data1, ind2, data2, n_features):
    num_not_equal = sparse_diff(ind1, data1, ind2, data2)[0].shape[0]
    return float(num_not_equal) / n_features


@numba.njit(cache=True)
def sparse_canberra(ind1, data1, ind2, data2):
    abs_data1 = np.abs(data1)
    abs_data2 = np.abs(data2)
    denom_inds, denom_data = sparse_sum(ind1, abs_data1, ind2, abs_data2)
    denom_data = (1.0 / denom_data).astype(np.float32)
    numer_inds, numer_data = sparse_diff(ind1, data1, ind2, data2)
    numer_data = np.abs(numer_data)

    _, val_data = sparse_mul(numer_inds, numer_data, denom_inds, denom_data)
    result = 0.0
    for val in val_data:
        result += val

    return result


@numba.njit(
    [
        "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
        numba.types.float32(
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
        ),
    ],
    fastmath=True,
    cache=True,
)
def sparse_bray_curtis(ind1, data1, ind2, data2):  # pragma: no cover
    _, denom_data = sparse_sum(ind1, data1, ind2, data2)
    denom_data = np.abs(denom_data)

    if denom_data.shape[0] == 0:
        return 0.0

    denominator = np.sum(denom_data)

    if denominator == 0.0:
        return 0.0

    _, numer_data = sparse_diff(ind1, data1, ind2, data2)
    numer_data = np.abs(numer_data)

    numerator = np.sum(numer_data)

    return float(numerator) / denominator


@numba.njit(cache=True)
def sparse_jaccard(ind1, data1, ind2, data2):
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_equal = arr_intersect(ind1, ind2).shape[0]

    if num_non_zero == 0:
        return 0.0
    else:
        return float(num_non_zero - num_equal) / num_non_zero

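# Usage sketch (illustrative): the binary metrics here look only at index
# membership and ignore the stored values. For example, with
# ind1 = [0, 1, 3] and ind2 = [1, 3, 4] the union has 4 indices and the
# intersection has 2, so sparse_jaccard returns (4 - 2) / 4 = 0.5.
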

@numba.njit(
    [
        "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
        numba.types.float32(
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
        ),
    ],
    fastmath=True,
    locals={"num_non_zero": numba.types.intp, "num_equal": numba.types.intp},
    cache=True,
)
def sparse_alternative_jaccard(ind1, data1, ind2, data2):
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_equal = arr_intersect(ind1, ind2).shape[0]

    if num_non_zero == 0:
        return 0.0
    else:
        return -np.log2(num_equal / num_non_zero)


@numba.vectorize(fastmath=True)
def correct_alternative_jaccard(v):
    return 1.0 - pow(2.0, -v)

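# Note (illustrative): correct_alternative_jaccard inverts the log transform,
# so correct_alternative_jaccard(sparse_alternative_jaccard(...)) recovers
# 1 - num_equal / num_non_zero, i.e. the plain sparse_jaccard value, while
# the cheaper alternative form preserves the ordering of distances during
# search.
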

@numba.njit(cache=True)
def sparse_matching(ind1, data1, ind2, data2, n_features):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    return float(num_not_equal) / n_features


@numba.njit(cache=True)
def sparse_dice(ind1, data1, ind2, data2):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    if num_not_equal == 0.0:
        return 0.0
    else:
        return num_not_equal / (2.0 * num_true_true + num_not_equal)


@numba.njit(cache=True)
def sparse_kulsinski(ind1, data1, ind2, data2, n_features):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    if num_not_equal == 0:
        return 0.0
    else:
        return float(num_not_equal - num_true_true + n_features) / (
            num_not_equal + n_features
        )


@numba.njit(cache=True)
def sparse_rogers_tanimoto(ind1, data1, ind2, data2, n_features):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    return (2.0 * num_not_equal) / (n_features + num_not_equal)


@numba.njit(cache=True)
def sparse_russellrao(ind1, data1, ind2, data2, n_features):
    if ind1.shape[0] == ind2.shape[0] and np.all(ind1 == ind2):
        return 0.0

    num_true_true = arr_intersect(ind1, ind2).shape[0]

    if num_true_true == np.sum(data1 != 0) and num_true_true == np.sum(data2 != 0):
        return 0.0
    else:
        return float(n_features - num_true_true) / (n_features)


@numba.njit(cache=True)
def sparse_sokal_michener(ind1, data1, ind2, data2, n_features):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    return (2.0 * num_not_equal) / (n_features + num_not_equal)


@numba.njit(cache=True)
def sparse_sokal_sneath(ind1, data1, ind2, data2):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    if num_not_equal == 0.0:
        return 0.0
    else:
        return num_not_equal / (0.5 * num_true_true + num_not_equal)


@numba.njit(cache=True)
def sparse_cosine(ind1, data1, ind2, data2):
    _, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0
    norm1 = norm(data1)
    norm2 = norm(data2)

    for val in aux_data:
        result += val

    if norm1 == 0.0 and norm2 == 0.0:
        return 0.0
    elif norm1 == 0.0 or norm2 == 0.0:
        return 1.0
    else:
        return 1.0 - (result / (norm1 * norm2))


@numba.njit(
    #    "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
    fastmath=True,
    locals={
        "result": numba.types.float32,
        "norm_x": numba.types.float32,
        "norm_y": numba.types.float32,
        "dim": numba.types.intp,
        "i": numba.types.uint16,
    },
    cache=True,
)
def sparse_alternative_cosine(ind1, data1, ind2, data2):
    _, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0
    norm_x = norm(data1)
    norm_y = norm(data2)
    dim = len(aux_data)
    for i in range(dim):
        result += aux_data[i]
    if norm_x == 0.0 and norm_y == 0.0:
        return 0.0
    elif norm_x == 0.0 or norm_y == 0.0:
        return FLOAT32_MAX
    elif result <= 0.0:
        return FLOAT32_MAX
    else:
        result = (norm_x * norm_y) / result
        return np.log2(result)


@numba.vectorize(fastmath=True, cache=True)
def sparse_correct_alternative_cosine(d):
    if isclose(0.0, abs(d), atol=1e-7) or d < 0.0:
        return 0.0
    else:
        return 1.0 - pow(2.0, -d)

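# Note (illustrative): as with the jaccard pair above,
# sparse_correct_alternative_cosine(sparse_alternative_cosine(...)) recovers
# approximately the sparse_cosine value, since
# 1 - 2 ** (-log2((norm_x * norm_y) / result)) simplifies to
# 1 - result / (norm_x * norm_y).
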

@numba.njit(cache=True)
def sparse_dot(ind1, data1, ind2, data2):
    _, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0

    for val in aux_data:
        result += val

    return 1.0 - result


@numba.njit(
    #    "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
    fastmath=True,
    locals={
        "result": numba.types.float32,
        "dim": numba.types.intp,
        "i": numba.types.uint16,
    },
    cache=True,
)
def sparse_alternative_dot(ind1, data1, ind2, data2):
    _, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0
    dim = len(aux_data)
    for i in range(dim):
        result += aux_data[i]

    if result <= 0.0:
        return FLOAT32_MAX
    else:
        return -np.log2(result)

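# Note (illustrative): sparse_correlation below centers both vectors by their
# dense means over all n_features coordinates. The bookkeeping for implicit
# zeros is what makes this work on sparse rows: a coordinate stored in
# neither vector contributes (-mu_x) * (-mu_y) to the centered dot product
# (the final mu_x * mu_y * (n_features - union size) term), and the
# (n_features - nnz) * mu ** 2 terms play the same role in the norms.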
@numba.njit(cache=True)
def sparse_correlation(ind1, data1, ind2, data2, n_features):

    mu_x = 0.0
    mu_y = 0.0
    dot_product = 0.0

    if ind1.shape[0] == 0 and ind2.shape[0] == 0:
        return 0.0
    elif ind1.shape[0] == 0 or ind2.shape[0] == 0:
        return 1.0

    for i in range(data1.shape[0]):
        mu_x += data1[i]
    for i in range(data2.shape[0]):
        mu_y += data2[i]

    mu_x /= n_features
    mu_y /= n_features

    shifted_data1 = np.empty(data1.shape[0], dtype=np.float32)
    shifted_data2 = np.empty(data2.shape[0], dtype=np.float32)

    for i in range(data1.shape[0]):
        shifted_data1[i] = data1[i] - mu_x
    for i in range(data2.shape[0]):
        shifted_data2[i] = data2[i] - mu_y

    norm1 = np.sqrt(
        (norm(shifted_data1) ** 2) + (n_features - ind1.shape[0]) * (mu_x ** 2)
    )
    norm2 = np.sqrt(
        (norm(shifted_data2) ** 2) + (n_features - ind2.shape[0]) * (mu_y ** 2)
    )

    dot_prod_inds, dot_prod_data = sparse_mul(ind1, shifted_data1, ind2, shifted_data2)

    common_indices = set(dot_prod_inds)

    for val in dot_prod_data:
        dot_product += val

    for i in range(ind1.shape[0]):
        if ind1[i] not in common_indices:
            dot_product -= shifted_data1[i] * (mu_y)

    for i in range(ind2.shape[0]):
        if ind2[i] not in common_indices:
            dot_product -= shifted_data2[i] * (mu_x)

    all_indices = arr_union(ind1, ind2)
    dot_product += mu_x * mu_y * (n_features - all_indices.shape[0])

    if norm1 == 0.0 and norm2 == 0.0:
        return 0.0
    elif dot_product == 0.0:
        return 1.0
    else:
        return 1.0 - (dot_product / (norm1 * norm2))


@numba.njit(cache=True)
def sparse_hellinger(ind1, data1, ind2, data2):
    aux_inds, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0
    norm1 = np.sum(data1)
    norm2 = np.sum(data2)
    sqrt_norm_prod = np.sqrt(norm1 * norm2)

    for val in aux_data:
        result += np.sqrt(val)

    if norm1 == 0.0 and norm2 == 0.0:
        return 0.0
    elif norm1 == 0.0 or norm2 == 0.0:
        return 1.0
    elif result > sqrt_norm_prod:
        return 0.0
    else:
        return np.sqrt(1.0 - (result / sqrt_norm_prod))


@numba.njit(
    #   "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
    fastmath=True,
    locals={
        "result": numba.types.float32,
        "l1_norm_x": numba.types.float32,
        "l1_norm_y": numba.types.float32,
        "dim": numba.types.intp,
        "i": numba.types.uint16,
    },
    cache=True,
)
def sparse_alternative_hellinger(ind1, data1, ind2, data2):
    aux_inds, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0
    l1_norm_x = np.sum(data1)
    l1_norm_y = np.sum(data2)
    dim = len(aux_data)

    for i in range(dim):
        result += np.sqrt(aux_data[i])

    if l1_norm_x == 0 and l1_norm_y == 0:
        return 0.0
    elif l1_norm_x == 0 or l1_norm_y == 0:
        return FLOAT32_MAX
    elif result <= 0:
        return FLOAT32_MAX
    else:
        result = np.sqrt(l1_norm_x * l1_norm_y) / result
        return np.log2(result)


@numba.vectorize(fastmath=True, cache=True)
def sparse_correct_alternative_hellinger(d):
    if isclose(0.0, abs(d), atol=1e-7) or d < 0.0:
        return 0.0
    else:
        return np.sqrt(1.0 - pow(2.0, -d))

@numba.njit(cache=True)
def dummy_ground_metric(x, y):
    return np.float32(x != y)


def create_ground_metric(ground_vectors, metric):
    """Generate a "ground_metric" suitable for passing to the
    ``sparse_kantorovich`` distance function. The result is a metric that,
    given indices into the data, produces the ground distance between the
    corresponding vectors. This allows a cost_matrix or ground_distance_matrix
    between sparse samples to be constructed on the fly -- without having to
    compute an all pairs distance matrix. This is particularly useful for
    things like word mover's distance.

    For example, to create a suitable ground_metric for word mover's distance
    one would use:

    ``wmd_ground_metric = create_ground_metric(word_vectors, cosine)``

    Parameters
    ----------
    ground_vectors: array of shape (n_features, d)
        The set of vectors between which ground_distances are measured. That is,
        there should be a vector for each feature of the space one wishes to compute
        Kantorovich distance over.

    metric: callable (numba jitted)
        The underlying metric used to compute distances between feature vectors.

    Returns
    -------
    ground_metric: callable (numba jitted)
        A ground metric suitable for passing to ``sparse_kantorovich``.
    """

    @numba.njit()
    def ground_metric(index1, index2):
        return metric(ground_vectors[index1], ground_vectors[index2])

    return ground_metric


@numba.njit(cache=True)
def sparse_kantorovich(ind1, data1, ind2, data2, ground_metric=dummy_ground_metric):

    cost_matrix = np.empty((ind1.shape[0], ind2.shape[0]))
    for i in range(ind1.shape[0]):
        for j in range(ind2.shape[0]):
            cost_matrix[i, j] = ground_metric(ind1[i], ind2[j])

    return kantorovich(data1, data2, cost_matrix)

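# Usage sketch (illustrative): the default 0/1 ground metric charges unit
# cost to move mass between any two distinct features, so for rows that are
# nonnegative and sum to 1 the resulting transport distance reduces to total
# variation distance. A custom ground metric can be supplied instead, e.g.
#
#     wmd_ground_metric = create_ground_metric(word_vectors, cosine)
#     d = sparse_kantorovich(ind1, data1, ind2, data2, wmd_ground_metric)
#
# where word_vectors and cosine are placeholders for the caller's data and
# jitted metric.
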
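# diversify prunes each point's candidate neighbor list (an "occlusion"
# rule): a candidate j of point i is dropped, with probability
# prune_probability, when some closer already-retained neighbor k satisfies
# d(k, j) < d(i, j), since j then remains reachable through k in the graph.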
@numba.njit(parallel=True, cache=True)
def diversify(
    indices,
    distances,
    data_indices,
    data_indptr,
    data_data,
    dist,
    rng_state,
    prune_probability=1.0,
):

    for i in numba.prange(indices.shape[0]):

        new_indices = [indices[i, 0]]
        new_distances = [distances[i, 0]]
        for j in range(1, indices.shape[1]):
            if indices[i, j] < 0:
                break

            flag = True
            for k in range(len(new_indices)):
                c = new_indices[k]

                from_ind = data_indices[
                    data_indptr[indices[i, j]] : data_indptr[indices[i, j] + 1]
                ]
                from_data = data_data[
                    data_indptr[indices[i, j]] : data_indptr[indices[i, j] + 1]
                ]

                to_ind = data_indices[data_indptr[c] : data_indptr[c + 1]]
                to_data = data_data[data_indptr[c] : data_indptr[c + 1]]

                d = dist(from_ind, from_data, to_ind, to_data)
                if new_distances[k] > FLOAT32_EPS and d < distances[i, j]:
                    if tau_rand(rng_state) < prune_probability:
                        flag = False
                        break

            if flag:
                new_indices.append(indices[i, j])
                new_distances.append(distances[i, j])

        for j in range(indices.shape[1]):
            if j < len(new_indices):
                indices[i, j] = new_indices[j]
                distances[i, j] = new_distances[j]
            else:
                indices[i, j] = -1
                distances[i, j] = np.inf

    return indices, distances


@numba.njit(parallel=True, cache=True)
def diversify_csr(
    graph_indptr,
    graph_indices,
    graph_data,
    data_indptr,
    data_indices,
    data_data,
    dist,
    rng_state,
    prune_probability=1.0,
):

    n_nodes = graph_indptr.shape[0] - 1

    for i in numba.prange(n_nodes):

        current_indices = graph_indices[graph_indptr[i] : graph_indptr[i + 1]]
        current_data = graph_data[graph_indptr[i] : graph_indptr[i + 1]]

        order = np.argsort(current_data)
        retained = np.ones(order.shape[0], dtype=np.int8)

        for idx in range(1, order.shape[0]):

            j = order[idx]

            for k in range(idx):

                l = order[k]

                if retained[l] == 1:
                    p = current_indices[j]
                    q = current_indices[l]

                    from_inds = data_indices[data_indptr[p] : data_indptr[p + 1]]
                    from_data = data_data[data_indptr[p] : data_indptr[p + 1]]

                    to_inds = data_indices[data_indptr[q] : data_indptr[q + 1]]
                    to_data = data_data[data_indptr[q] : data_indptr[q + 1]]
                    d = dist(from_inds, from_data, to_inds, to_data)

                    if current_data[l] > FLOAT32_EPS and d < current_data[j]:
                        if tau_rand(rng_state) < prune_probability:
                            retained[j] = 0
                            break

        for idx in range(order.shape[0]):
            j = order[idx]
            if retained[j] == 0:
                graph_data[graph_indptr[i] + j] = 0

    return


sparse_named_distances = {
    # general minkowski distances
    "euclidean": sparse_euclidean,
    "l2": sparse_euclidean,
    "sqeuclidean": sparse_squared_euclidean,
    "manhattan": sparse_manhattan,
    "l1": sparse_manhattan,
    "taxicab": sparse_manhattan,
    "chebyshev": sparse_chebyshev,
    "linf": sparse_chebyshev,
    "linfty": sparse_chebyshev,
    "linfinity": sparse_chebyshev,
    "minkowski": sparse_minkowski,
    # Other distances
    "canberra": sparse_canberra,
    "kantorovich": sparse_kantorovich,
    "wasserstein": sparse_kantorovich,
    "braycurtis": sparse_bray_curtis,
    # Binary distances
    "hamming": sparse_hamming,
    "jaccard": sparse_jaccard,
    "dice": sparse_dice,
    "matching": sparse_matching,
    "kulsinski": sparse_kulsinski,
    "rogerstanimoto": sparse_rogers_tanimoto,
    "russellrao": sparse_russellrao,
    "sokalmichener": sparse_sokal_michener,
    "sokalsneath": sparse_sokal_sneath,
    # Angular distances
    "cosine": sparse_cosine,
    "correlation": sparse_correlation,
    "hellinger": sparse_hellinger,
}

sparse_need_n_features = (
    "hamming",
    "matching",
    "kulsinski",
    "rogerstanimoto",
    "russellrao",
    "sokalmichener",
    "correlation",
)

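# Usage sketch (illustrative): metrics are looked up by name, e.g.
#
#     dist = sparse_named_distances["cosine"]
#     d = dist(ind1, data1, ind2, data2)
#
# Metrics listed in sparse_need_n_features take the total feature count as an
# additional trailing argument:
#
#     d = sparse_named_distances["hamming"](ind1, data1, ind2, data2, n_features)
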

# Some distances have a faster-to-compute alternative that
# retains the same ordering of distances. We can compute with
# this instead, and then correct the final distances when complete.
# This dict maps each distance that has such an alternative to the
# alternative distance function and the correction function to be
# applied afterwards.
sparse_fast_distance_alternatives = {
    "euclidean": {"dist": sparse_squared_euclidean, "correction": np.sqrt},
    "l2": {"dist": sparse_squared_euclidean, "correction": np.sqrt},
    "cosine": {
        "dist": sparse_alternative_cosine,
        "correction": sparse_correct_alternative_cosine,
    },
    "dot": {
        "dist": sparse_alternative_dot,
        "correction": sparse_correct_alternative_cosine,
    },
    "hellinger": {
        "dist": sparse_alternative_hellinger,
        "correction": sparse_correct_alternative_hellinger,
    },
    "jaccard": {
        "dist": sparse_alternative_jaccard,
        "correction": correct_alternative_jaccard,
    },
}

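# Usage sketch (illustrative): during a nearest neighbor search one computes
# the cheap order-preserving proxy and applies the correction only to the
# final distances, e.g.
#
#     alt = sparse_fast_distance_alternatives["cosine"]
#     raw = alt["dist"](ind1, data1, ind2, data2)
#     d = alt["correction"](raw)  # approximately sparse_cosine(ind1, ...)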