# Author: Leland McInnes <leland.mcinnes@gmail.com>
# Enough simple sparse operations in numba to enable sparse UMAP
#
# License: BSD 3 clause
from __future__ import print_function
import locale
import numpy as np
import numba

from pynndescent.utils import norm, tau_rand
from pynndescent.distances import kantorovich

locale.setlocale(locale.LC_NUMERIC, "C")

FLOAT32_EPS = np.finfo(np.float32).eps
FLOAT32_MAX = np.finfo(np.float32).max

# Just reproduce a simpler version of numpy isclose (not numba supported yet)
@numba.njit(cache=True)
def isclose(a, b, rtol=1.0e-5, atol=1.0e-8):
    diff = np.abs(a - b)
    return diff <= (atol + rtol * np.abs(b))


# Just reproduce a simpler version of numpy unique (not numba supported yet)
@numba.njit(cache=True)
def arr_unique(arr):
    aux = np.sort(arr)
    flag = np.concatenate((np.ones(1, dtype=np.bool_), aux[1:] != aux[:-1]))
    return aux[flag]


# Just reproduce a simpler version of numpy union1d (not numba supported yet)
@numba.njit(cache=True)
def arr_union(ar1, ar2):
    if ar1.shape[0] == 0:
        return ar2
    elif ar2.shape[0] == 0:
        return ar1
    else:
        return arr_unique(np.concatenate((ar1, ar2)))


# Just reproduce a simpler version of numpy intersect1d (not numba supported
# yet)
@numba.njit(cache=True)
def arr_intersect(ar1, ar2):
    aux = np.concatenate((ar1, ar2))
    aux.sort()
    return aux[:-1][aux[1:] == aux[:-1]]


@numba.njit(
    [
        numba.types.Tuple(
            (
                numba.types.Array(numba.types.int32, 1, "C"),
                numba.types.Array(numba.types.float32, 1, "C"),
            )
        )(
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
        )
    ],
    fastmath=True,
    locals={
        "result_ind": numba.types.int32[::1],
        "result_data": numba.types.float32[::1],
        "val": numba.types.float32,
        "i1": numba.types.int32,
        "i2": numba.types.int32,
        "j1": numba.types.int32,
        "j2": numba.types.int32,
    },
    cache=True,
)
def sparse_sum(ind1, data1, ind2, data2):
    result_size = ind1.shape[0] + ind2.shape[0]
    result_ind = np.zeros(result_size, dtype=np.int32)
    result_data = np.zeros(result_size, dtype=np.float32)

    i1 = 0
    i2 = 0
    nnz = 0

    # pass through both index lists
    while i1 < ind1.shape[0] and i2 < ind2.shape[0]:
        j1 = ind1[i1]
        j2 = ind2[i2]

        if j1 == j2:
            val = data1[i1] + data2[i2]
            if val != 0:
                result_ind[nnz] = j1
                result_data[nnz] = val
                nnz += 1
            i1 += 1
            i2 += 1
        elif j1 < j2:
            val = data1[i1]
            if val != 0:
                result_ind[nnz] = j1
                result_data[nnz] = val
                nnz += 1
            i1 += 1
        else:
            val = data2[i2]
            if val != 0:
                result_ind[nnz] = j2
                result_data[nnz] = val
                nnz += 1
            i2 += 1

    # pass over the tails
    while i1 < ind1.shape[0]:
        j1 = ind1[i1]
        val = data1[i1]
        if val != 0:
            result_ind[nnz] = j1
            result_data[nnz] = val
            nnz += 1
        i1 += 1

    while i2 < ind2.shape[0]:
        j2 = ind2[i2]
        val = data2[i2]
        if val != 0:
            result_ind[nnz] = j2
            result_data[nnz] = val
            nnz += 1
        i2 += 1

    # truncate to the correct length in case there were zeros created
    result_ind = result_ind[:nnz]
    result_data = result_data[:nnz]

    return result_ind, result_data
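

# A worked example of the merge above (illustrative values only): summing
# ind1 = [0, 3, 5], data1 = [1.0, 2.0, 3.0] with ind2 = [3, 4],
# data2 = [-2.0, 1.0] walks both sorted index lists once; the entries at
# index 3 cancel to zero and are dropped, giving result_ind = [0, 4, 5]
# and result_data = [1.0, 1.0, 3.0].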


@numba.njit(cache=True)
def sparse_diff(ind1, data1, ind2, data2):
    return sparse_sum(ind1, data1, ind2, -data2)


@numba.njit(
    [
        # "Tuple((i4[::1],f4[::1]))(i4[::1],f4[::1],i4[::1],f4[::1])",
        numba.types.Tuple(
            (
                numba.types.ListType(numba.types.int32),
                numba.types.ListType(numba.types.float32),
            )
        )(
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
        )
    ],
    fastmath=True,
    locals={
        "val": numba.types.float32,
        "i1": numba.types.int32,
        "i2": numba.types.int32,
        "j1": numba.types.int32,
        "j2": numba.types.int32,
    },
    cache=True,
)
def sparse_mul(ind1, data1, ind2, data2):
    result_ind = numba.typed.List.empty_list(numba.types.int32)
    result_data = numba.typed.List.empty_list(numba.types.float32)

    i1 = 0
    i2 = 0

    # pass through both index lists
    while i1 < ind1.shape[0] and i2 < ind2.shape[0]:
        j1 = ind1[i1]
        j2 = ind2[i2]

        if j1 == j2:
            val = data1[i1] * data2[i2]
            if val != 0:
                result_ind.append(j1)
                result_data.append(val)
            i1 += 1
            i2 += 1
        elif j1 < j2:
            i1 += 1
        else:
            i2 += 1

    return result_ind, result_data


@numba.njit(cache=True)
def sparse_euclidean(ind1, data1, ind2, data2):
    _, aux_data = sparse_diff(ind1, data1, ind2, data2)
    result = 0.0
    for i in range(aux_data.shape[0]):
        result += aux_data[i] ** 2
    return np.sqrt(result)


@numba.njit(
    [
        "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
        numba.types.float32(
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
        ),
    ],
    fastmath=True,
    locals={
        "aux_data": numba.types.float32[::1],
        "result": numba.types.float32,
        "diff": numba.types.float32,
        "dim": numba.types.intp,
        "i": numba.types.uint16,
    },
    cache=True,
)
def sparse_squared_euclidean(ind1, data1, ind2, data2):
    _, aux_data = sparse_diff(ind1, data1, ind2, data2)
    result = 0.0
    dim = len(aux_data)
    for i in range(dim):
        result += aux_data[i] * aux_data[i]
    return result


@numba.njit(cache=True)
def sparse_manhattan(ind1, data1, ind2, data2):
    _, aux_data = sparse_diff(ind1, data1, ind2, data2)
    result = 0.0
    for i in range(aux_data.shape[0]):
        result += np.abs(aux_data[i])
    return result


@numba.njit(cache=True)
def sparse_chebyshev(ind1, data1, ind2, data2):
    _, aux_data = sparse_diff(ind1, data1, ind2, data2)
    result = 0.0
    for i in range(aux_data.shape[0]):
        result = max(result, np.abs(aux_data[i]))
    return result


@numba.njit(cache=True)
def sparse_minkowski(ind1, data1, ind2, data2, p=2.0):
    _, aux_data = sparse_diff(ind1, data1, ind2, data2)
    result = 0.0
    for i in range(aux_data.shape[0]):
        result += np.abs(aux_data[i]) ** p
    return result ** (1.0 / p)


@numba.njit(cache=True)
def sparse_hamming(ind1, data1, ind2, data2, n_features):
    num_not_equal = sparse_diff(ind1, data1, ind2, data2)[0].shape[0]
    return float(num_not_equal) / n_features


@numba.njit(cache=True)
def sparse_canberra(ind1, data1, ind2, data2):
    abs_data1 = np.abs(data1)
    abs_data2 = np.abs(data2)
    denom_inds, denom_data = sparse_sum(ind1, abs_data1, ind2, abs_data2)
    denom_data = (1.0 / denom_data).astype(np.float32)
    numer_inds, numer_data = sparse_diff(ind1, data1, ind2, data2)
    numer_data = np.abs(numer_data)

    _, val_data = sparse_mul(numer_inds, numer_data, denom_inds, denom_data)
    result = 0.0
    for val in val_data:
        result += val

    return result


@numba.njit(
    [
        "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
        numba.types.float32(
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
        ),
    ],
    fastmath=True,
    cache=True,
)
def sparse_bray_curtis(ind1, data1, ind2, data2):  # pragma: no cover
    _, denom_data = sparse_sum(ind1, data1, ind2, data2)
    denom_data = np.abs(denom_data)

    if denom_data.shape[0] == 0:
        return 0.0

    denominator = np.sum(denom_data)

    if denominator == 0.0:
        return 0.0

    _, numer_data = sparse_diff(ind1, data1, ind2, data2)
    numer_data = np.abs(numer_data)

    numerator = np.sum(numer_data)

    return float(numerator) / denominator


@numba.njit(cache=True)
def sparse_jaccard(ind1, data1, ind2, data2):
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_equal = arr_intersect(ind1, ind2).shape[0]

    if num_non_zero == 0:
        return 0.0
    else:
        return float(num_non_zero - num_equal) / num_non_zero


@numba.njit(
    [
        "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
        numba.types.float32(
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
            numba.types.Array(numba.types.int32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
        ),
    ],
    fastmath=True,
    locals={"num_non_zero": numba.types.intp, "num_equal": numba.types.intp},
    cache=True,
)
def sparse_alternative_jaccard(ind1, data1, ind2, data2):
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_equal = arr_intersect(ind1, ind2).shape[0]

    if num_non_zero == 0:
        return 0.0
    else:
        return -np.log2(num_equal / num_non_zero)


@numba.vectorize(fastmath=True)
def correct_alternative_jaccard(v):
    return 1.0 - pow(2.0, -v)


@numba.njit(cache=True)
def sparse_matching(ind1, data1, ind2, data2, n_features):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    return float(num_not_equal) / n_features


@numba.njit(cache=True)
def sparse_dice(ind1, data1, ind2, data2):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    if num_not_equal == 0.0:
        return 0.0
    else:
        return num_not_equal / (2.0 * num_true_true + num_not_equal)


@numba.njit(cache=True)
def sparse_kulsinski(ind1, data1, ind2, data2, n_features):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    if num_not_equal == 0:
        return 0.0
    else:
        return float(num_not_equal - num_true_true + n_features) / (
            num_not_equal + n_features
        )


@numba.njit(cache=True)
def sparse_rogers_tanimoto(ind1, data1, ind2, data2, n_features):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    return (2.0 * num_not_equal) / (n_features + num_not_equal)


@numba.njit(cache=True)
def sparse_russellrao(ind1, data1, ind2, data2, n_features):
    if ind1.shape[0] == ind2.shape[0] and np.all(ind1 == ind2):
        return 0.0

    num_true_true = arr_intersect(ind1, ind2).shape[0]

    if num_true_true == np.sum(data1 != 0) and num_true_true == np.sum(data2 != 0):
        return 0.0
    else:
        return float(n_features - num_true_true) / (n_features)


@numba.njit(cache=True)
def sparse_sokal_michener(ind1, data1, ind2, data2, n_features):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    return (2.0 * num_not_equal) / (n_features + num_not_equal)


@numba.njit(cache=True)
def sparse_sokal_sneath(ind1, data1, ind2, data2):
    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    if num_not_equal == 0.0:
        return 0.0
    else:
        return num_not_equal / (0.5 * num_true_true + num_not_equal)


@numba.njit(cache=True)
def sparse_cosine(ind1, data1, ind2, data2):
    _, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0
    norm1 = norm(data1)
    norm2 = norm(data2)

    for val in aux_data:
        result += val

    if norm1 == 0.0 and norm2 == 0.0:
        return 0.0
    elif norm1 == 0.0 or norm2 == 0.0:
        return 1.0
    else:
        return 1.0 - (result / (norm1 * norm2))


@numba.njit(
    # "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
    fastmath=True,
    locals={
        "result": numba.types.float32,
        "norm_x": numba.types.float32,
        "norm_y": numba.types.float32,
        "dim": numba.types.intp,
        "i": numba.types.uint16,
    },
    cache=True,
)
def sparse_alternative_cosine(ind1, data1, ind2, data2):
    _, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0
    norm_x = norm(data1)
    norm_y = norm(data2)
    dim = len(aux_data)
    for i in range(dim):
        result += aux_data[i]
    if norm_x == 0.0 and norm_y == 0.0:
        return 0.0
    elif norm_x == 0.0 or norm_y == 0.0:
        return FLOAT32_MAX
    elif result <= 0.0:
        return FLOAT32_MAX
    else:
        result = (norm_x * norm_y) / result
        return np.log2(result)


@numba.vectorize(fastmath=True, cache=True)
def sparse_correct_alternative_cosine(d):
    if isclose(0.0, abs(d), atol=1e-7) or d < 0.0:
        return 0.0
    else:
        return 1.0 - pow(2.0, -d)


@numba.njit(cache=True)
def sparse_dot(ind1, data1, ind2, data2):
    _, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0

    for val in aux_data:
        result += val

    return 1.0 - result


@numba.njit(
    # "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
    fastmath=True,
    locals={
        "result": numba.types.float32,
        "dim": numba.types.intp,
        "i": numba.types.uint16,
    },
    cache=True,
)
def sparse_alternative_dot(ind1, data1, ind2, data2):
    _, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0
    dim = len(aux_data)
    for i in range(dim):
        result += aux_data[i]

    if result <= 0.0:
        return FLOAT32_MAX
    else:
        return -np.log2(result)


@numba.njit(cache=True)
def sparse_correlation(ind1, data1, ind2, data2, n_features):

    mu_x = 0.0
    mu_y = 0.0
    dot_product = 0.0

    if ind1.shape[0] == 0 and ind2.shape[0] == 0:
        return 0.0
    elif ind1.shape[0] == 0 or ind2.shape[0] == 0:
        return 1.0

    for i in range(data1.shape[0]):
        mu_x += data1[i]
    for i in range(data2.shape[0]):
        mu_y += data2[i]

    mu_x /= n_features
    mu_y /= n_features

    shifted_data1 = np.empty(data1.shape[0], dtype=np.float32)
    shifted_data2 = np.empty(data2.shape[0], dtype=np.float32)

    for i in range(data1.shape[0]):
        shifted_data1[i] = data1[i] - mu_x
    for i in range(data2.shape[0]):
        shifted_data2[i] = data2[i] - mu_y

    norm1 = np.sqrt(
        (norm(shifted_data1) ** 2) + (n_features - ind1.shape[0]) * (mu_x ** 2)
    )
    norm2 = np.sqrt(
        (norm(shifted_data2) ** 2) + (n_features - ind2.shape[0]) * (mu_y ** 2)
    )

    dot_prod_inds, dot_prod_data = sparse_mul(ind1, shifted_data1, ind2, shifted_data2)

    common_indices = set(dot_prod_inds)

    for val in dot_prod_data:
        dot_product += val

    for i in range(ind1.shape[0]):
        if ind1[i] not in common_indices:
            dot_product -= shifted_data1[i] * (mu_y)

    for i in range(ind2.shape[0]):
        if ind2[i] not in common_indices:
            dot_product -= shifted_data2[i] * (mu_x)

    all_indices = arr_union(ind1, ind2)
    dot_product += mu_x * mu_y * (n_features - all_indices.shape[0])

    if norm1 == 0.0 and norm2 == 0.0:
        return 0.0
    elif dot_product == 0.0:
        return 1.0
    else:
        return 1.0 - (dot_product / (norm1 * norm2))


@numba.njit(cache=True)
def sparse_hellinger(ind1, data1, ind2, data2):
    aux_inds, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0
    norm1 = np.sum(data1)
    norm2 = np.sum(data2)
    sqrt_norm_prod = np.sqrt(norm1 * norm2)

    for val in aux_data:
        result += np.sqrt(val)

    if norm1 == 0.0 and norm2 == 0.0:
        return 0.0
    elif norm1 == 0.0 or norm2 == 0.0:
        return 1.0
    elif result > sqrt_norm_prod:
        return 0.0
    else:
        return np.sqrt(1.0 - (result / sqrt_norm_prod))


@numba.njit(
    # "f4(i4[::1],f4[::1],i4[::1],f4[::1])",
    fastmath=True,
    locals={
        "result": numba.types.float32,
        "l1_norm_x": numba.types.float32,
        "l1_norm_y": numba.types.float32,
        "dim": numba.types.intp,
        "i": numba.types.uint16,
    },
    cache=True,
)
def sparse_alternative_hellinger(ind1, data1, ind2, data2):
    aux_inds, aux_data = sparse_mul(ind1, data1, ind2, data2)
    result = 0.0
    l1_norm_x = np.sum(data1)
    l1_norm_y = np.sum(data2)
    dim = len(aux_data)

    for i in range(dim):
        result += np.sqrt(aux_data[i])

    if l1_norm_x == 0 and l1_norm_y == 0:
        return 0.0
    elif l1_norm_x == 0 or l1_norm_y == 0:
        return FLOAT32_MAX
    elif result <= 0:
        return FLOAT32_MAX
    else:
        result = np.sqrt(l1_norm_x * l1_norm_y) / result
        return np.log2(result)


@numba.vectorize(fastmath=True, cache=True)
def sparse_correct_alternative_hellinger(d):
    if isclose(0.0, abs(d), atol=1e-7) or d < 0.0:
        return 0.0
    else:
        return np.sqrt(1.0 - pow(2.0, -d))


@numba.njit(cache=True)
def dummy_ground_metric(x, y):
    return np.float32(not x == y)


def create_ground_metric(ground_vectors, metric):
    """Generate a "ground_metric" suitable for passing to a ``sparse_kantorovich``
    distance function. This should be a metric that, given indices of the data,
    produces the ground distance between the corresponding vectors. This
    allows the construction of a cost_matrix or ground_distance_matrix between
    sparse samples on the fly -- without having to compute an all pairs distance.
    This is particularly useful for things like word-mover-distance.

    For example, to create a suitable ground_metric for word-mover distance one
    would use:

    ``wmd_ground_metric = create_ground_metric(word_vectors, cosine)``

    Parameters
    ----------
    ground_vectors: array of shape (n_features, d)
        The set of vectors between which ground_distances are measured. That is,
        there should be a vector for each feature of the space one wishes to compute
        Kantorovich distance over.

    metric: callable (numba jitted)
        The underlying metric used to compute distances between feature vectors.

    Returns
    -------
    ground_metric: callable (numba jitted)
        A ground metric suitable for passing to ``sparse_kantorovich``.
    """

    @numba.njit()
    def ground_metric(index1, index2):
        return metric(ground_vectors[index1], ground_vectors[index2])

    return ground_metric


@numba.njit(cache=True)
def sparse_kantorovich(ind1, data1, ind2, data2, ground_metric=dummy_ground_metric):

    cost_matrix = np.empty((ind1.shape[0], ind2.shape[0]))
    for i in range(ind1.shape[0]):
        for j in range(ind2.shape[0]):
            cost_matrix[i, j] = ground_metric(ind1[i], ind2[j])

    return kantorovich(data1, data2, cost_matrix)
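

# A minimal usage sketch for the above (illustrative only): ``word_vectors``
# is a hypothetical dense array of per-feature embeddings, and the ground
# metric may be any numba-jitted dense metric, e.g. a cosine distance:
#
#     wmd_ground_metric = create_ground_metric(word_vectors, cosine)
#     d = sparse_kantorovich(ind1, data1, ind2, data2,
#                            ground_metric=wmd_ground_metric)
#
# where ``ind1, data1`` and ``ind2, data2`` are the index/data slices of two
# sparse rows; the ground distance is then evaluated only for the feature
# pairs that actually occur in the two samples.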


@numba.njit(parallel=True, cache=True)
def diversify(
    indices,
    distances,
    data_indices,
    data_indptr,
    data_data,
    dist,
    rng_state,
    prune_probability=1.0,
):

    for i in numba.prange(indices.shape[0]):

        new_indices = [indices[i, 0]]
        new_distances = [distances[i, 0]]
        for j in range(1, indices.shape[1]):
            if indices[i, j] < 0:
                break

            flag = True
            for k in range(len(new_indices)):
                c = new_indices[k]

                from_ind = data_indices[
                    data_indptr[indices[i, j]] : data_indptr[indices[i, j] + 1]
                ]
                from_data = data_data[
                    data_indptr[indices[i, j]] : data_indptr[indices[i, j] + 1]
                ]

                to_ind = data_indices[data_indptr[c] : data_indptr[c + 1]]
                to_data = data_data[data_indptr[c] : data_indptr[c + 1]]

                d = dist(from_ind, from_data, to_ind, to_data)
                if new_distances[k] > FLOAT32_EPS and d < distances[i, j]:
                    if tau_rand(rng_state) < prune_probability:
                        flag = False
                        break

            if flag:
                new_indices.append(indices[i, j])
                new_distances.append(distances[i, j])

        for j in range(indices.shape[1]):
            if j < len(new_indices):
                indices[i, j] = new_indices[j]
                distances[i, j] = new_distances[j]
            else:
                indices[i, j] = -1
                distances[i, j] = np.inf

    return indices, distances


@numba.njit(parallel=True, cache=True)
def diversify_csr(
    graph_indptr,
    graph_indices,
    graph_data,
    data_indptr,
    data_indices,
    data_data,
    dist,
    rng_state,
    prune_probability=1.0,
):

    n_nodes = graph_indptr.shape[0] - 1

    for i in numba.prange(n_nodes):

        current_indices = graph_indices[graph_indptr[i] : graph_indptr[i + 1]]
        current_data = graph_data[graph_indptr[i] : graph_indptr[i + 1]]

        order = np.argsort(current_data)
        retained = np.ones(order.shape[0], dtype=np.int8)

        for idx in range(1, order.shape[0]):

            j = order[idx]

            for k in range(idx):

                l = order[k]

                if retained[l] == 1:
                    p = current_indices[j]
                    q = current_indices[l]

                    from_inds = data_indices[data_indptr[p] : data_indptr[p + 1]]
                    from_data = data_data[data_indptr[p] : data_indptr[p + 1]]

                    to_inds = data_indices[data_indptr[q] : data_indptr[q + 1]]
                    to_data = data_data[data_indptr[q] : data_indptr[q + 1]]
                    d = dist(from_inds, from_data, to_inds, to_data)

                    if current_data[l] > FLOAT32_EPS and d < current_data[j]:
                        if tau_rand(rng_state) < prune_probability:
                            retained[j] = 0
                            break

        for idx in range(order.shape[0]):
            j = order[idx]
            if retained[j] == 0:
                graph_data[graph_indptr[i] + j] = 0

    return


sparse_named_distances = {
    # general minkowski distances
    "euclidean": sparse_euclidean,
    "l2": sparse_euclidean,
    "sqeuclidean": sparse_squared_euclidean,
    "manhattan": sparse_manhattan,
    "l1": sparse_manhattan,
    "taxicab": sparse_manhattan,
    "chebyshev": sparse_chebyshev,
    "linf": sparse_chebyshev,
    "linfty": sparse_chebyshev,
    "linfinity": sparse_chebyshev,
    "minkowski": sparse_minkowski,
    # Other distances
    "canberra": sparse_canberra,
    "kantorovich": sparse_kantorovich,
    "wasserstein": sparse_kantorovich,
    "braycurtis": sparse_bray_curtis,
    # Binary distances
    "hamming": sparse_hamming,
    "jaccard": sparse_jaccard,
    "dice": sparse_dice,
    "matching": sparse_matching,
    "kulsinski": sparse_kulsinski,
    "rogerstanimoto": sparse_rogers_tanimoto,
    "russellrao": sparse_russellrao,
    "sokalmichener": sparse_sokal_michener,
    "sokalsneath": sparse_sokal_sneath,
    # Angular distances
    "cosine": sparse_cosine,
    "correlation": sparse_correlation,
    "hellinger": sparse_hellinger,
}

sparse_need_n_features = (
    "hamming",
    "matching",
    "kulsinski",
    "rogerstanimoto",
    "russellrao",
    "sokalmichener",
    "correlation",
)


# Some distances have a faster-to-compute alternative that
# retains the same ordering of distances. We can compute with
# this instead, and then correct the final distances when complete.
# This provides a list of distances that have such an alternative
# along with the alternative distance function and the correction
# function to be applied.
sparse_fast_distance_alternatives = {
    "euclidean": {"dist": sparse_squared_euclidean, "correction": np.sqrt},
    "l2": {"dist": sparse_squared_euclidean, "correction": np.sqrt},
    "cosine": {
        "dist": sparse_alternative_cosine,
        "correction": sparse_correct_alternative_cosine,
    },
    "dot": {
        "dist": sparse_alternative_dot,
        "correction": sparse_correct_alternative_cosine,
    },
    "hellinger": {
        "dist": sparse_alternative_hellinger,
        "correction": sparse_correct_alternative_hellinger,
    },
    "jaccard": {
        "dist": sparse_alternative_jaccard,
        "correction": correct_alternative_jaccard,
    },
}
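

# A minimal sketch of how the alternative/correction pairs above are meant to
# be used (illustrative only; the vectors below are arbitrary). The alternative
# distance retains the ordering of the true distance, and applying the
# correction to the finished values recovers the true distance itself.
if __name__ == "__main__":
    ind_a = np.array([0, 2, 5], dtype=np.int32)
    dat_a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    ind_b = np.array([2, 3, 5], dtype=np.int32)
    dat_b = np.array([4.0, 5.0, 6.0], dtype=np.float32)

    # fast alternative distance used while searching ...
    alt = sparse_fast_distance_alternatives["cosine"]["dist"](
        ind_a, dat_a, ind_b, dat_b
    )
    # ... and the correction applied once the search is complete
    corrected = sparse_fast_distance_alternatives["cosine"]["correction"](alt)

    print("alternative cosine:", alt)
    print("corrected cosine:  ", corrected)
    print("direct cosine:     ", sparse_cosine(ind_a, dat_a, ind_b, dat_b))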