1# Copyright (c) Facebook, Inc. and its affiliates. 2# 3# This source code is licensed under the MIT license found in the 4# LICENSE file in the root directory of this source tree. 5 6#@nolint 7 8# not linting this file because it imports * from swigfaiss, which 9# causes a ton of useless warnings. 10 11import numpy as np 12import sys 13import inspect 14import array 15import warnings 16 17# We import * so that the symbol foo can be accessed as faiss.foo. 18from .loader import * 19 20 21__version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR, 22 FAISS_VERSION_MINOR, 23 FAISS_VERSION_PATCH) 24 25################################################################## 26# The functions below add or replace some methods for classes 27# this is to be able to pass in numpy arrays directly 28# The C++ version of the classnames will be suffixed with _c 29################################################################## 30 31 32def replace_method(the_class, name, replacement, ignore_missing=False): 33 """ Replaces a method in a class with another version. The old method 34 is renamed to method_name_c (because presumably it was implemented in C) """ 35 try: 36 orig_method = getattr(the_class, name) 37 except AttributeError: 38 if ignore_missing: 39 return 40 raise 41 if orig_method.__name__ == 'replacement_' + name: 42 # replacement was done in parent class 43 return 44 setattr(the_class, name + '_c', orig_method) 45 setattr(the_class, name, replacement) 46 47def handle_Clustering(): 48 49 def replacement_train(self, x, index, weights=None): 50 """Perform clustering on a set of vectors. The index is used for assignment. 51 52 Parameters 53 ---------- 54 x : array_like 55 Training vectors, shape (n, self.d). `dtype` must be float32. 56 index : faiss.Index 57 Index used for assignment. The dimension of the index should be `self.d`. 
58 weights : array_like, optional 59 Per training sample weight (size n) used when computing the weighted 60 average to obtain the centroid (default is 1 for all training vectors). 61 """ 62 n, d = x.shape 63 assert d == self.d 64 if weights is not None: 65 assert weights.shape == (n, ) 66 self.train_c(n, swig_ptr(x), index, swig_ptr(weights)) 67 else: 68 self.train_c(n, swig_ptr(x), index) 69 70 def replacement_train_encoded(self, x, codec, index, weights=None): 71 """ Perform clustering on a set of compressed vectors. The index is used for assignment. 72 The decompression is performed on-the-fly. 73 74 Parameters 75 ---------- 76 x : array_like 77 Training vectors, shape (n, codec.code_size()). `dtype` must be `uint8`. 78 codec : faiss.Index 79 Index used to decode the vectors. Should have dimension `self.d`. 80 index : faiss.Index 81 Index used for assignment. The dimension of the index should be `self.d`. 82 weigths : array_like, optional 83 Per training sample weight (size n) used when computing the weighted 84 average to obtain the centroid (default is 1 for all training vectors). 85 """ 86 n, d = x.shape 87 assert d == codec.sa_code_size() 88 assert codec.d == index.d 89 if weights is not None: 90 assert weights.shape == (n, ) 91 self.train_encoded_c(n, swig_ptr(x), codec, index, swig_ptr(weights)) 92 else: 93 self.train_encoded_c(n, swig_ptr(x), codec, index) 94 replace_method(Clustering, 'train', replacement_train) 95 replace_method(Clustering, 'train_encoded', replacement_train_encoded) 96 97 98handle_Clustering() 99 100 101def handle_Quantizer(the_class): 102 103 def replacement_train(self, x): 104 """ Train the quantizer on a set of training vectors. 105 106 Parameters 107 ---------- 108 x : array_like 109 Training vectors, shape (n, self.d). `dtype` must be float32. 
110 """ 111 n, d = x.shape 112 assert d == self.d 113 self.train_c(n, swig_ptr(x)) 114 115 def replacement_compute_codes(self, x): 116 """ Compute the codes corresponding to a set of vectors. 117 118 Parameters 119 ---------- 120 x : array_like 121 Vectors to encode, shape (n, self.d). `dtype` must be float32. 122 123 Returns 124 ------- 125 codes : array_like 126 Corresponding code for each vector, shape (n, self.code_size) 127 and `dtype` uint8. 128 """ 129 n, d = x.shape 130 assert d == self.d 131 codes = np.empty((n, self.code_size), dtype='uint8') 132 self.compute_codes_c(swig_ptr(x), swig_ptr(codes), n) 133 return codes 134 135 def replacement_decode(self, codes): 136 """Reconstruct an approximation of vectors given their codes. 137 138 Parameters 139 ---------- 140 codes : array_like 141 Codes to decode, shape (n, self.code_size). `dtype` must be uint8. 142 143 Returns 144 ------- 145 Reconstructed vectors for each code, shape `(n, d)` and `dtype` float32. 146 """ 147 n, cs = codes.shape 148 assert cs == self.code_size 149 x = np.empty((n, self.d), dtype='float32') 150 self.decode_c(swig_ptr(codes), swig_ptr(x), n) 151 return x 152 153 replace_method(the_class, 'train', replacement_train) 154 replace_method(the_class, 'compute_codes', replacement_compute_codes) 155 replace_method(the_class, 'decode', replacement_decode) 156 157 158handle_Quantizer(ProductQuantizer) 159handle_Quantizer(ScalarQuantizer) 160handle_Quantizer(ResidualQuantizer) 161handle_Quantizer(LocalSearchQuantizer) 162 163 164def handle_NSG(the_class): 165 166 def replacement_build(self, x, graph): 167 n, d = x.shape 168 assert d == self.d 169 assert graph.ndim == 2 170 assert graph.shape[0] == n 171 K = graph.shape[1] 172 self.build_c(n, swig_ptr(x), swig_ptr(graph), K) 173 174 replace_method(the_class, 'build', replacement_build) 175 176 177def handle_Index(the_class): 178 179 def replacement_add(self, x): 180 """Adds vectors to the index. 
181 The index must be trained before vectors can be added to it. 182 The vectors are implicitly numbered in sequence. When `n` vectors are 183 added to the index, they are given ids `ntotal`, `ntotal + 1`, ..., `ntotal + n - 1`. 184 185 Parameters 186 ---------- 187 x : array_like 188 Query vectors, shape (n, d) where d is appropriate for the index. 189 `dtype` must be float32. 190 """ 191 192 n, d = x.shape 193 assert d == self.d 194 self.add_c(n, swig_ptr(x)) 195 196 def replacement_add_with_ids(self, x, ids): 197 """Adds vectors with arbitrary ids to the index (not all indexes support this). 198 The index must be trained before vectors can be added to it. 199 Vector `i` is stored in `x[i]` and has id `ids[i]`. 200 201 Parameters 202 ---------- 203 x : array_like 204 Query vectors, shape (n, d) where d is appropriate for the index. 205 `dtype` must be float32. 206 ids : array_like 207 Array if ids of size n. The ids must be of type `int64`. Note that `-1` is reserved 208 in result lists to mean "not found" so it's better to not use it as an id. 209 """ 210 n, d = x.shape 211 assert d == self.d 212 213 assert ids.shape == (n, ), 'not same nb of vectors as ids' 214 self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids)) 215 216 def replacement_assign(self, x, k, labels=None): 217 """Find the k nearest neighbors of the set of vectors x in the index. 218 This is the same as the `search` method, but discards the distances. 219 220 Parameters 221 ---------- 222 x : array_like 223 Query vectors, shape (n, d) where d is appropriate for the index. 224 `dtype` must be float32. 225 k : int 226 Number of nearest neighbors. 227 labels : array_like, optional 228 Labels array to store the results. 229 230 Returns 231 ------- 232 labels: array_like 233 Labels of the nearest neighbors, shape (n, k). 
234 When not enough results are found, the label is set to -1 235 """ 236 n, d = x.shape 237 assert d == self.d 238 239 if labels is None: 240 labels = np.empty((n, k), dtype=np.int64) 241 else: 242 assert labels.shape == (n, k) 243 244 self.assign_c(n, swig_ptr(x), swig_ptr(labels), k) 245 return labels 246 247 def replacement_train(self, x): 248 """Trains the index on a representative set of vectors. 249 The index must be trained before vectors can be added to it. 250 251 Parameters 252 ---------- 253 x : array_like 254 Query vectors, shape (n, d) where d is appropriate for the index. 255 `dtype` must be float32. 256 """ 257 n, d = x.shape 258 assert d == self.d 259 self.train_c(n, swig_ptr(x)) 260 261 def replacement_search(self, x, k, D=None, I=None): 262 """Find the k nearest neighbors of the set of vectors x in the index. 263 264 Parameters 265 ---------- 266 x : array_like 267 Query vectors, shape (n, d) where d is appropriate for the index. 268 `dtype` must be float32. 269 k : int 270 Number of nearest neighbors. 271 D : array_like, optional 272 Distance array to store the result. 273 I : array_like, optional 274 Labels array to store the results. 275 276 Returns 277 ------- 278 D : array_like 279 Distances of the nearest neighbors, shape (n, k). When not enough results are found 280 the label is set to +Inf or -Inf. 281 I : array_like 282 Labels of the nearest neighbors, shape (n, k). 
283 When not enough results are found, the label is set to -1 284 """ 285 286 n, d = x.shape 287 assert d == self.d 288 289 assert k > 0 290 291 if D is None: 292 D = np.empty((n, k), dtype=np.float32) 293 else: 294 assert D.shape == (n, k) 295 296 if I is None: 297 I = np.empty((n, k), dtype=np.int64) 298 else: 299 assert I.shape == (n, k) 300 301 self.search_c(n, swig_ptr(x), k, swig_ptr(D), swig_ptr(I)) 302 return D, I 303 304 def replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None): 305 """Find the k nearest neighbors of the set of vectors x in the index, 306 and return an approximation of these vectors. 307 308 Parameters 309 ---------- 310 x : array_like 311 Query vectors, shape (n, d) where d is appropriate for the index. 312 `dtype` must be float32. 313 k : int 314 Number of nearest neighbors. 315 D : array_like, optional 316 Distance array to store the result. 317 I : array_like, optional 318 Labels array to store the result. 319 R : array_like, optional 320 reconstruction array to store 321 322 Returns 323 ------- 324 D : array_like 325 Distances of the nearest neighbors, shape (n, k). When not enough results are found 326 the label is set to +Inf or -Inf. 327 I : array_like 328 Labels of the nearest neighbors, shape (n, k). When not enough results are found, 329 the label is set to -1 330 R : array_like 331 Approximate (reconstructed) nearest neighbor vectors, shape (n, k, d). 
332 """ 333 n, d = x.shape 334 assert d == self.d 335 336 assert k > 0 337 338 if D is None: 339 D = np.empty((n, k), dtype=np.float32) 340 else: 341 assert D.shape == (n, k) 342 343 if I is None: 344 I = np.empty((n, k), dtype=np.int64) 345 else: 346 assert I.shape == (n, k) 347 348 if R is None: 349 R = np.empty((n, k, d), dtype=np.float32) 350 else: 351 assert R.shape == (n, k, d) 352 353 self.search_and_reconstruct_c(n, swig_ptr(x), 354 k, swig_ptr(D), 355 swig_ptr(I), 356 swig_ptr(R)) 357 return D, I, R 358 359 def replacement_remove_ids(self, x): 360 """Remove some ids from the index. 361 This is a O(ntotal) operation by default, so could be expensive. 362 363 Parameters 364 ---------- 365 x : array_like or faiss.IDSelector 366 Either an IDSelector that returns True for vectors to remove, or a 367 list of ids to reomove (1D array of int64). When `x` is a list, 368 it is wrapped into an IDSelector. 369 370 Returns 371 ------- 372 n_remove: int 373 number of vectors that were removed 374 """ 375 if isinstance(x, IDSelector): 376 sel = x 377 else: 378 assert x.ndim == 1 379 index_ivf = try_extract_index_ivf (self) 380 if index_ivf and index_ivf.direct_map.type == DirectMap.Hashtable: 381 sel = IDSelectorArray(x.size, swig_ptr(x)) 382 else: 383 sel = IDSelectorBatch(x.size, swig_ptr(x)) 384 return self.remove_ids_c(sel) 385 386 def replacement_reconstruct(self, key, x=None): 387 """Approximate reconstruction of one vector from the index. 
388 389 Parameters 390 ---------- 391 key : int 392 Id of the vector to reconstruct 393 x : array_like, optional 394 pre-allocated array to store the results 395 396 Returns 397 ------- 398 x : array_like 399 Reconstructed vector, size `self.d`, `dtype`=float32 400 """ 401 if x is None: 402 x = np.empty(self.d, dtype=np.float32) 403 else: 404 assert x.shape == (self.d, ) 405 406 self.reconstruct_c(key, swig_ptr(x)) 407 return x 408 409 def replacement_reconstruct_n(self, n0, ni, x=None): 410 """Approximate reconstruction of vectors `n0` ... `n0 + ni - 1` from the index. 411 Missing vectors trigger an exception. 412 413 Parameters 414 ---------- 415 n0 : int 416 Id of the first vector to reconstruct 417 ni : int 418 Number of vectors to reconstruct 419 x : array_like, optional 420 pre-allocated array to store the results 421 422 Returns 423 ------- 424 x : array_like 425 Reconstructed vectors, size (`ni`, `self.d`), `dtype`=float32 426 """ 427 if x is None: 428 x = np.empty((ni, self.d), dtype=np.float32) 429 else: 430 assert x.shape == (ni, self.d) 431 432 self.reconstruct_n_c(n0, ni, swig_ptr(x)) 433 return x 434 435 def replacement_update_vectors(self, keys, x): 436 n = keys.size 437 assert keys.shape == (n, ) 438 assert x.shape == (n, self.d) 439 440 self.update_vectors_c(n, swig_ptr(keys), swig_ptr(x)) 441 442 # The CPU does not support passed-in output buffers 443 def replacement_range_search(self, x, thresh): 444 """Search vectors that are within a distance of the query vectors. 445 446 Parameters 447 ---------- 448 x : array_like 449 Query vectors, shape (n, d) where d is appropriate for the index. 450 `dtype` must be float32. 451 thresh : float 452 Threshold to select neighbors. All elements within this radius are returned, 453 except for maximum inner product indexes, where the elements above the 454 threshold are returned 455 456 Returns 457 ------- 458 lims: array_like 459 Startring index of the results for each query vector, size n+1. 
460 D : array_like 461 Distances of the nearest neighbors, shape `lims[n]`. The distances for 462 query i are in `D[lims[i]:lims[i+1]]`. 463 I : array_like 464 Labels of nearest neighbors, shape `lims[n]`. The labels for query i 465 are in `I[lims[i]:lims[i+1]]`. 466 467 """ 468 n, d = x.shape 469 assert d == self.d 470 471 res = RangeSearchResult(n) 472 self.range_search_c(n, swig_ptr(x), thresh, res) 473 # get pointers and copy them 474 lims = rev_swig_ptr(res.lims, n + 1).copy() 475 nd = int(lims[-1]) 476 D = rev_swig_ptr(res.distances, nd).copy() 477 I = rev_swig_ptr(res.labels, nd).copy() 478 return lims, D, I 479 480 def replacement_sa_encode(self, x, codes=None): 481 482 483 n, d = x.shape 484 assert d == self.d 485 486 if codes is None: 487 codes = np.empty((n, self.sa_code_size()), dtype=np.uint8) 488 else: 489 assert codes.shape == (n, self.sa_code_size()) 490 491 self.sa_encode_c(n, swig_ptr(x), swig_ptr(codes)) 492 return codes 493 494 def replacement_sa_decode(self, codes, x=None): 495 n, cs = codes.shape 496 assert cs == self.sa_code_size() 497 498 if x is None: 499 x = np.empty((n, self.d), dtype=np.float32) 500 else: 501 assert x.shape == (n, self.d) 502 503 self.sa_decode_c(n, swig_ptr(codes), swig_ptr(x)) 504 return x 505 506 replace_method(the_class, 'add', replacement_add) 507 replace_method(the_class, 'add_with_ids', replacement_add_with_ids) 508 replace_method(the_class, 'assign', replacement_assign) 509 replace_method(the_class, 'train', replacement_train) 510 replace_method(the_class, 'search', replacement_search) 511 replace_method(the_class, 'remove_ids', replacement_remove_ids) 512 replace_method(the_class, 'reconstruct', replacement_reconstruct) 513 replace_method(the_class, 'reconstruct_n', replacement_reconstruct_n) 514 replace_method(the_class, 'range_search', replacement_range_search) 515 replace_method(the_class, 'update_vectors', replacement_update_vectors, 516 ignore_missing=True) 517 replace_method(the_class, 
'search_and_reconstruct', 518 replacement_search_and_reconstruct, ignore_missing=True) 519 replace_method(the_class, 'sa_encode', replacement_sa_encode) 520 replace_method(the_class, 'sa_decode', replacement_sa_decode) 521 522 # get/set state for pickle 523 # the data is serialized to std::vector -> numpy array -> python bytes 524 # so not very efficient for now. 525 526 def index_getstate(self): 527 return {"this": serialize_index(self).tobytes()} 528 529 def index_setstate(self, st): 530 index2 = deserialize_index(np.frombuffer(st["this"], dtype="uint8")) 531 self.this = index2.this 532 533 the_class.__getstate__ = index_getstate 534 the_class.__setstate__ = index_setstate 535 536 537 538def handle_IndexBinary(the_class): 539 540 def replacement_add(self, x): 541 n, d = x.shape 542 assert d * 8 == self.d 543 self.add_c(n, swig_ptr(x)) 544 545 def replacement_add_with_ids(self, x, ids): 546 n, d = x.shape 547 assert d * 8 == self.d 548 assert ids.shape == (n, ), 'not same nb of vectors as ids' 549 self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids)) 550 551 def replacement_train(self, x): 552 n, d = x.shape 553 assert d * 8 == self.d 554 self.train_c(n, swig_ptr(x)) 555 556 def replacement_reconstruct(self, key): 557 x = np.empty(self.d // 8, dtype=np.uint8) 558 self.reconstruct_c(key, swig_ptr(x)) 559 return x 560 561 def replacement_search(self, x, k): 562 n, d = x.shape 563 assert d * 8 == self.d 564 assert k > 0 565 distances = np.empty((n, k), dtype=np.int32) 566 labels = np.empty((n, k), dtype=np.int64) 567 self.search_c(n, swig_ptr(x), 568 k, swig_ptr(distances), 569 swig_ptr(labels)) 570 return distances, labels 571 572 def replacement_range_search(self, x, thresh): 573 n, d = x.shape 574 assert d * 8 == self.d 575 res = RangeSearchResult(n) 576 self.range_search_c(n, swig_ptr(x), thresh, res) 577 # get pointers and copy them 578 lims = rev_swig_ptr(res.lims, n + 1).copy() 579 nd = int(lims[-1]) 580 D = rev_swig_ptr(res.distances, nd).copy() 581 I = 
rev_swig_ptr(res.labels, nd).copy() 582 return lims, D, I 583 584 def replacement_remove_ids(self, x): 585 if isinstance(x, IDSelector): 586 sel = x 587 else: 588 assert x.ndim == 1 589 sel = IDSelectorBatch(x.size, swig_ptr(x)) 590 return self.remove_ids_c(sel) 591 592 replace_method(the_class, 'add', replacement_add) 593 replace_method(the_class, 'add_with_ids', replacement_add_with_ids) 594 replace_method(the_class, 'train', replacement_train) 595 replace_method(the_class, 'search', replacement_search) 596 replace_method(the_class, 'range_search', replacement_range_search) 597 replace_method(the_class, 'reconstruct', replacement_reconstruct) 598 replace_method(the_class, 'remove_ids', replacement_remove_ids) 599 600 601def handle_VectorTransform(the_class): 602 603 def apply_method(self, x): 604 n, d = x.shape 605 assert d == self.d_in 606 y = np.empty((n, self.d_out), dtype=np.float32) 607 self.apply_noalloc(n, swig_ptr(x), swig_ptr(y)) 608 return y 609 610 def replacement_reverse_transform(self, x): 611 n, d = x.shape 612 assert d == self.d_out 613 y = np.empty((n, self.d_in), dtype=np.float32) 614 self.reverse_transform_c(n, swig_ptr(x), swig_ptr(y)) 615 return y 616 617 def replacement_vt_train(self, x): 618 n, d = x.shape 619 assert d == self.d_in 620 self.train_c(n, swig_ptr(x)) 621 622 replace_method(the_class, 'train', replacement_vt_train) 623 # apply is reserved in Pyton... 
624 the_class.apply_py = apply_method 625 the_class.apply = apply_method 626 replace_method(the_class, 'reverse_transform', 627 replacement_reverse_transform) 628 629 630def handle_AutoTuneCriterion(the_class): 631 def replacement_set_groundtruth(self, D, I): 632 if D: 633 assert I.shape == D.shape 634 self.nq, self.gt_nnn = I.shape 635 self.set_groundtruth_c( 636 self.gt_nnn, swig_ptr(D) if D else None, swig_ptr(I)) 637 638 def replacement_evaluate(self, D, I): 639 assert I.shape == D.shape 640 assert I.shape == (self.nq, self.nnn) 641 return self.evaluate_c(swig_ptr(D), swig_ptr(I)) 642 643 replace_method(the_class, 'set_groundtruth', replacement_set_groundtruth) 644 replace_method(the_class, 'evaluate', replacement_evaluate) 645 646 647def handle_ParameterSpace(the_class): 648 def replacement_explore(self, index, xq, crit): 649 assert xq.shape == (crit.nq, index.d) 650 ops = OperatingPoints() 651 self.explore_c(index, crit.nq, swig_ptr(xq), 652 crit, ops) 653 return ops 654 replace_method(the_class, 'explore', replacement_explore) 655 656 657def handle_MatrixStats(the_class): 658 original_init = the_class.__init__ 659 660 def replacement_init(self, m): 661 assert len(m.shape) == 2 662 original_init(self, m.shape[0], m.shape[1], swig_ptr(m)) 663 664 the_class.__init__ = replacement_init 665 666handle_MatrixStats(MatrixStats) 667 668def handle_IOWriter(the_class): 669 670 def write_bytes(self, b): 671 return self(swig_ptr(b), 1, len(b)) 672 673 the_class.write_bytes = write_bytes 674 675handle_IOWriter(IOWriter) 676 677def handle_IOReader(the_class): 678 679 def read_bytes(self, totsz): 680 buf = bytearray(totsz) 681 was_read = self(swig_ptr(buf), 1, len(buf)) 682 return bytes(buf[:was_read]) 683 684 the_class.read_bytes = read_bytes 685 686handle_IOReader(IOReader) 687 688this_module = sys.modules[__name__] 689 690 691for symbol in dir(this_module): 692 obj = getattr(this_module, symbol) 693 # print symbol, isinstance(obj, (type, types.ClassType)) 694 if 
inspect.isclass(obj): 695 the_class = obj 696 if issubclass(the_class, Index): 697 handle_Index(the_class) 698 699 if issubclass(the_class, IndexBinary): 700 handle_IndexBinary(the_class) 701 702 if issubclass(the_class, VectorTransform): 703 handle_VectorTransform(the_class) 704 705 if issubclass(the_class, AutoTuneCriterion): 706 handle_AutoTuneCriterion(the_class) 707 708 if issubclass(the_class, ParameterSpace): 709 handle_ParameterSpace(the_class) 710 711 if issubclass(the_class, IndexNSG): 712 handle_NSG(the_class) 713 714########################################### 715# Utility to add a deprecation warning to 716# classes from the SWIG interface 717########################################### 718 719def _make_deprecated_swig_class(deprecated_name, base_name): 720 """ 721 Dynamically construct deprecated classes as wrappers around renamed ones 722 723 The deprecation warning added in their __new__-method will trigger upon 724 construction of an instance of the class, but only once per session. 725 726 We do this here (in __init__.py) because the base classes are defined in 727 the SWIG interface, making it cumbersome to add the deprecation there. 728 729 Parameters 730 ---------- 731 deprecated_name : string 732 Name of the class to be deprecated; _not_ present in SWIG interface. 733 base_name : string 734 Name of the class that is replacing deprecated_name; must already be 735 imported into the current namespace. 736 737 Returns 738 ------- 739 None 740 However, the deprecated class gets added to the faiss namespace 741 """ 742 base_class = globals()[base_name] 743 def new_meth(cls, *args, **kwargs): 744 msg = f"The class faiss.{deprecated_name} is deprecated in favour of faiss.{base_name}!" 
745 warnings.warn(msg, DeprecationWarning, stacklevel=2) 746 instance = super(base_class, cls).__new__(cls, *args, **kwargs) 747 return instance 748 749 # three-argument version of "type" uses (name, tuple-of-bases, dict-of-attributes) 750 klazz = type(deprecated_name, (base_class,), {"__new__": new_meth}) 751 752 # this ends up adding the class to the "faiss" namespace, in a way that it 753 # is available both through "import faiss" and "from faiss import *" 754 globals()[deprecated_name] = klazz 755 756########################################### 757# Add Python references to objects 758# we do this at the Python class wrapper level. 759########################################### 760 761def add_ref_in_constructor(the_class, parameter_no): 762 # adds a reference to parameter parameter_no in self 763 # so that that parameter does not get deallocated before self 764 original_init = the_class.__init__ 765 766 def replacement_init(self, *args): 767 original_init(self, *args) 768 self.referenced_objects = [args[parameter_no]] 769 770 def replacement_init_multiple(self, *args): 771 original_init(self, *args) 772 pset = parameter_no[len(args)] 773 self.referenced_objects = [args[no] for no in pset] 774 775 if type(parameter_no) == dict: 776 # a list of parameters to keep, depending on the number of arguments 777 the_class.__init__ = replacement_init_multiple 778 else: 779 the_class.__init__ = replacement_init 780 781 782def add_ref_in_method(the_class, method_name, parameter_no): 783 original_method = getattr(the_class, method_name) 784 def replacement_method(self, *args): 785 ref = args[parameter_no] 786 if not hasattr(self, 'referenced_objects'): 787 self.referenced_objects = [ref] 788 else: 789 self.referenced_objects.append(ref) 790 return original_method(self, *args) 791 setattr(the_class, method_name, replacement_method) 792 793def add_ref_in_function(function_name, parameter_no): 794 # assumes the function returns an object 795 original_function = 
getattr(this_module, function_name) 796 def replacement_function(*args): 797 result = original_function(*args) 798 ref = args[parameter_no] 799 result.referenced_objects = [ref] 800 return result 801 setattr(this_module, function_name, replacement_function) 802 803add_ref_in_constructor(IndexIVFFlat, 0) 804add_ref_in_constructor(IndexIVFFlatDedup, 0) 805add_ref_in_constructor(IndexPreTransform, {2: [0, 1], 1: [0]}) 806add_ref_in_method(IndexPreTransform, 'prepend_transform', 0) 807add_ref_in_constructor(IndexIVFPQ, 0) 808add_ref_in_constructor(IndexIVFPQR, 0) 809add_ref_in_constructor(IndexIVFPQFastScan, 0) 810add_ref_in_constructor(Index2Layer, 0) 811add_ref_in_constructor(Level1Quantizer, 0) 812add_ref_in_constructor(IndexIVFScalarQuantizer, 0) 813add_ref_in_constructor(IndexIDMap, 0) 814add_ref_in_constructor(IndexIDMap2, 0) 815add_ref_in_constructor(IndexHNSW, 0) 816add_ref_in_method(IndexShards, 'add_shard', 0) 817add_ref_in_method(IndexBinaryShards, 'add_shard', 0) 818add_ref_in_constructor(IndexRefineFlat, {2:[0], 1:[0]}) 819add_ref_in_constructor(IndexRefine, {2:[0, 1]}) 820 821add_ref_in_constructor(IndexBinaryIVF, 0) 822add_ref_in_constructor(IndexBinaryFromFloat, 0) 823add_ref_in_constructor(IndexBinaryIDMap, 0) 824add_ref_in_constructor(IndexBinaryIDMap2, 0) 825 826add_ref_in_method(IndexReplicas, 'addIndex', 0) 827add_ref_in_method(IndexBinaryReplicas, 'addIndex', 0) 828 829add_ref_in_constructor(BufferedIOWriter, 0) 830add_ref_in_constructor(BufferedIOReader, 0) 831 832# seems really marginal... 833# remove_ref_from_method(IndexReplicas, 'removeIndex', 0) 834 835########################################### 836# GPU functions 837########################################### 838 839 840def index_cpu_to_gpu_multiple_py(resources, index, co=None, gpus=None): 841 """ builds the C++ vectors for the GPU indices and the 842 resources. 
Handles the case where the resources are assigned to 843 the list of GPUs """ 844 if gpus is None: 845 gpus = range(len(resources)) 846 vres = GpuResourcesVector() 847 vdev = Int32Vector() 848 for i, res in zip(gpus, resources): 849 vdev.push_back(i) 850 vres.push_back(res) 851 index = index_cpu_to_gpu_multiple(vres, vdev, index, co) 852 return index 853 854 855def index_cpu_to_all_gpus(index, co=None, ngpu=-1): 856 index_gpu = index_cpu_to_gpus_list(index, co=co, gpus=None, ngpu=ngpu) 857 return index_gpu 858 859 860def index_cpu_to_gpus_list(index, co=None, gpus=None, ngpu=-1): 861 """ Here we can pass list of GPU ids as a parameter or ngpu to 862 use first n GPU's. gpus mut be a list or None""" 863 if (gpus is None) and (ngpu == -1): # All blank 864 gpus = range(get_num_gpus()) 865 elif (gpus is None) and (ngpu != -1): # Get number of GPU's only 866 gpus = range(ngpu) 867 res = [StandardGpuResources() for _ in gpus] 868 index_gpu = index_cpu_to_gpu_multiple_py(res, index, co, gpus) 869 return index_gpu 870 871# allows numpy ndarray usage with bfKnn 872def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2): 873 """ 874 Compute the k nearest neighbors of a vector on one GPU without constructing an index 875 876 Parameters 877 ---------- 878 res : StandardGpuResources 879 GPU resources to use during computation 880 xq : array_like 881 Query vectors, shape (nq, d) where d is appropriate for the index. 882 `dtype` must be float32. 883 xb : array_like 884 Database vectors, shape (nb, d) where d is appropriate for the index. 885 `dtype` must be float32. 886 k : int 887 Number of nearest neighbors. 
888 D : array_like, optional 889 Output array for distances of the nearest neighbors, shape (nq, k) 890 I : array_like, optional 891 Output array for the nearest neighbors, shape (nq, k) 892 distance_type : MetricType, optional 893 distance measure to use (either METRIC_L2 or METRIC_INNER_PRODUCT) 894 895 Returns 896 ------- 897 D : array_like 898 Distances of the nearest neighbors, shape (nq, k) 899 I : array_like 900 Labels of the nearest neighbors, shape (nq, k) 901 """ 902 nq, d = xq.shape 903 if xq.flags.c_contiguous: 904 xq_row_major = True 905 elif xq.flags.f_contiguous: 906 xq = xq.T 907 xq_row_major = False 908 else: 909 raise TypeError('xq matrix should be row (C) or column-major (Fortran)') 910 911 xq_ptr = swig_ptr(xq) 912 913 if xq.dtype == np.float32: 914 xq_type = DistanceDataType_F32 915 elif xq.dtype == np.float16: 916 xq_type = DistanceDataType_F16 917 else: 918 raise TypeError('xq must be f32 or f16') 919 920 nb, d2 = xb.shape 921 assert d2 == d 922 if xb.flags.c_contiguous: 923 xb_row_major = True 924 elif xb.flags.f_contiguous: 925 xb = xb.T 926 xb_row_major = False 927 else: 928 raise TypeError('xb matrix should be row (C) or column-major (Fortran)') 929 930 xb_ptr = swig_ptr(xb) 931 932 if xb.dtype == np.float32: 933 xb_type = DistanceDataType_F32 934 elif xb.dtype == np.float16: 935 xb_type = DistanceDataType_F16 936 else: 937 raise TypeError('xb must be float32 or float16') 938 939 if D is None: 940 D = np.empty((nq, k), dtype=np.float32) 941 else: 942 assert D.shape == (nq, k) 943 # interface takes void*, we need to check this 944 assert D.dtype == np.float32 945 946 D_ptr = swig_ptr(D) 947 948 if I is None: 949 I = np.empty((nq, k), dtype=np.int64) 950 else: 951 assert I.shape == (nq, k) 952 953 I_ptr = swig_ptr(I) 954 955 if I.dtype == np.int64: 956 I_type = IndicesDataType_I64 957 elif I.dtype == I.dtype == np.int32: 958 I_type = IndicesDataType_I32 959 else: 960 raise TypeError('I must be i64 or i32') 961 962 args = GpuDistanceParams() 
def _gpu_distance_operand(x, name):
    """Validate one input matrix of a GPU brute-force distance computation.

    Parameters
    ----------
    x : numpy.ndarray
        2D matrix. Must be C- or Fortran-contiguous and of dtype float32 or
        float16.
    name : str
        Name of the argument ('xq' or 'xb'), only used in error messages.

    Returns
    -------
    x : numpy.ndarray
        The array whose buffer is handed to the GPU code (the transposed view
        when the input was Fortran-contiguous).
    row_major : bool
        True when the input was C-contiguous.
    ptr :
        Result of `swig_ptr` on the returned array.
    data_type :
        `DistanceDataType_F32` or `DistanceDataType_F16`.

    Raises
    ------
    TypeError
        If the matrix is neither C- nor Fortran-contiguous, or if its dtype is
        neither float32 nor float16.
    """
    if x.flags.c_contiguous:
        row_major = True
    elif x.flags.f_contiguous:
        # the transpose of a Fortran-ordered matrix is a C-contiguous view of
        # the same buffer; the layout is reported through the row_major flag
        x = x.T
        row_major = False
    else:
        raise TypeError(
            '%s matrix should be row (C) or column-major (Fortran)' % name)

    # the pointer is taken before the dtype check, as in the inlined version
    ptr = swig_ptr(x)

    if x.dtype == np.float32:
        data_type = DistanceDataType_F32
    elif x.dtype == np.float16:
        data_type = DistanceDataType_F16
    else:
        raise TypeError('%s must be float32 or float16' % name)

    return x, row_major, ptr, data_type


# allows numpy ndarray usage with bfKnn for all pairwise distances
def pairwise_distance_gpu(res, xq, xb, D=None, metric=METRIC_L2):
    """
    Compute all pairwise distances between xq and xb on one GPU without constructing an index

    Parameters
    ----------
    res : StandardGpuResources
        GPU resources to use during computation
    xq : array_like
        Query vectors, shape (nq, d).
        `dtype` must be float32 or float16, layout C- or Fortran-contiguous.
    xb : array_like
        Database vectors, shape (nb, d) with the same d as `xq`.
        `dtype` must be float32 or float16, layout C- or Fortran-contiguous.
    D : array_like, optional
        Output array for all pairwise distances, shape (nq, nb), `dtype`
        float32. Allocated here when not provided.
    metric : MetricType, optional
        distance measure to use (either METRIC_L2 or METRIC_INNER_PRODUCT)

    Returns
    -------
    D : array_like
        All pairwise distances, shape (nq, nb)

    Raises
    ------
    TypeError
        If `xq` or `xb` has an unsupported layout or dtype.
    """
    # shapes are read before any transposition done by the helper
    nq, d = xq.shape
    xq, xq_row_major, xq_ptr, xq_type = _gpu_distance_operand(xq, 'xq')

    nb, d2 = xb.shape
    assert d2 == d
    xb, xb_row_major, xb_ptr, xb_type = _gpu_distance_operand(xb, 'xb')

    if D is None:
        D = np.empty((nq, nb), dtype=np.float32)
    else:
        assert D.shape == (nq, nb)
        # interface takes void*, we need to check this
        assert D.dtype == np.float32

    D_ptr = swig_ptr(D)

    args = GpuDistanceParams()
    args.metric = metric
    args.k = -1  # selects all pairwise distances
    args.dims = d
    args.vectors = xb_ptr
    args.vectorsRowMajor = xb_row_major
    args.vectorType = xb_type
    args.numVectors = nb
    args.queries = xq_ptr
    args.queriesRowMajor = xq_row_major
    args.queryType = xq_type
    args.numQueries = nq
    args.outDistances = D_ptr

    # no stream synchronization needed, inputs and outputs are guaranteed to
    # be on the CPU (numpy arrays)
    bfKnn(res, args)

    return D


###########################################
# numpy array / std::vector conversions
###########################################

sizeof_long = array.array('l').itemsize
deprecated_name_map = {
    # deprecated: replacement
    'Float': 'Float32',
    'Double': 'Float64',
    'Char': 'Int8',
    'Int': 'Int32',
    # the replacement for 'Long' depends on the platform size of a C long
    'Long': 'Int32' if sizeof_long == 4 else 'Int64',
    'LongLong': 'Int64',
    'Byte': 'UInt8',
    # previously misspelled variant
    'Uint64': 'UInt64',
}

for depr_prefix, base_prefix in deprecated_name_map.items():
    _make_deprecated_swig_class(depr_prefix + "Vector", base_prefix + "Vector")

    # same for the three legacy *VectorVector classes
    if depr_prefix in ['Float', 'Long', 'Byte']:
        _make_deprecated_swig_class(depr_prefix + "VectorVector",
                                    base_prefix + "VectorVector")

# mapping from vector names in swigfaiss.swig and the numpy dtype names
# TODO: once deprecated classes are removed, remove the dict and just use .lower() below
vector_name_map = {
    'Float32': 'float32',
    'Float64': 'float64',
    'Int8': 'int8',
    'Int16': 'int16',
    'Int32': 'int32',
    'Int64': 'int64',
    'UInt8': 'uint8',
    'UInt16': 'uint16',
    'UInt32': 'uint32',
    'UInt64': 'uint64',
    # deprecated prefixes resolve to the dtype of their replacement
    **{k: v.lower() for k, v in deprecated_name_map.items()}
}


def vector_to_array(v):
    """ convert a C++ vector to a numpy array

    The dtype is derived from the class name of `v` (for example
    `Float32Vector` gives float32) through `vector_name_map`. The data is
    copied with `memcpy`, so the returned array is independent of `v`.
    """
    classname = v.__class__.__name__
    assert classname.endswith('Vector')
    # strip the trailing 'Vector' (6 characters) to get the type prefix
    dtype = np.dtype(vector_name_map[classname[:-6]])
    a = np.empty(v.size(), dtype=dtype)
    if v.size() > 0:
        memcpy(swig_ptr(a), v.data(), a.nbytes)
    return a


def vector_float_to_array(v):
    """ same as vector_to_array (kept as a separate name) """
    return vector_to_array(v)


def copy_array_to_vector(a, v):
    """ copy a numpy array to a vector

    `a` must be one-dimensional and its dtype must match the element type
    implied by the class name of `v`. The vector is resized to `len(a)`.
    """
    n, = a.shape
    classname = v.__class__.__name__
    assert classname.endswith('Vector')
    dtype = np.dtype(vector_name_map[classname[:-6]])
    assert dtype == a.dtype, (
        'cannot copy a %s array to a %s (should be %s)' % (
            a.dtype, classname, dtype))
    v.resize(n)
    if n > 0:
        memcpy(v.data(), swig_ptr(a), a.nbytes)

# same for AlignedTable

def copy_array_to_AlignedTable(a, v):
    """ copy a one-dimensional numpy array to an AlignedTable

    Only the item size is compared (`v.itemsize()` against `a.itemsize`);
    the table is resized to `len(a)` before the copy.
    """
    n, = a.shape
    # TODO check class name
    assert v.itemsize() == a.itemsize
    v.resize(n)
    if n > 0:
        memcpy(v.get(), swig_ptr(a), a.nbytes)

def array_to_AlignedTable(a):
    """ build a new AlignedTable holding a copy of a uint8 or uint16 array

    Raises
    ------
    AssertionError
        If the dtype of `a` is neither uint8 nor uint16.
    """
    if a.dtype == 'uint16':
        v = AlignedTableUint16(a.size)
    elif a.dtype == 'uint8':
        v = AlignedTableUint8(a.size)
    else:
        # raised explicitly (same exception type as the former `assert False`)
        # so that the check also fires when asserts are disabled with -O
        raise AssertionError(
            'array_to_AlignedTable supports only uint8 and uint16 arrays, '
            'got %s' % a.dtype)
    copy_array_to_AlignedTable(a, v)
    return v

def AlignedTable_to_array(v):
    """ convert an AlignedTable to a numpy array """
    classname = v.__class__.__name__
    assert classname.startswith('AlignedTable')
    # what follows the 12-character 'AlignedTable' prefix names the dtype
    dtype = classname[12:].lower()
    a = np.empty(v.size(), dtype=dtype)
    if a.size > 0:
        memcpy(swig_ptr(a), v.data(), a.nbytes)
    return a

###########################################
# Wrapper for a few functions
###########################################

def kmin(array, k):
    """return k smallest values (and their indices) of the lines of a
    float32 array

    Returns
    -------
    D : float32 array, shape (m, k), the selected values
    I : int64 array, shape (m, k), their column indices
    """
    m, n = array.shape
    I = np.zeros((m, k), dtype='int64')
    D = np.zeros((m, k), dtype='float32')
    # one heap of size k per line, writing directly into D and I
    ha = float_maxheap_array_t()
    ha.ids = swig_ptr(I)
    ha.val = swig_ptr(D)
    ha.nh = m
    ha.k = k
    ha.heapify()
    ha.addn(n, swig_ptr(array))
    ha.reorder()
    return D, I


def kmax(array, k):
    """return k largest values (and their indices) of the lines of a
    float32 array

    Returns
    -------
    D : float32 array, shape (m, k), the selected values
    I : int64 array, shape (m, k), their column indices
    """
    m, n = array.shape
    I = np.zeros((m, k), dtype='int64')
    D = np.zeros((m, k), dtype='float32')
    # one heap of size k per line, writing directly into D and I
    ha = float_minheap_array_t()
    ha.ids = swig_ptr(I)
    ha.val = swig_ptr(D)
    ha.nh = m
    ha.k = k
    ha.heapify()
    ha.addn(n, swig_ptr(array))
    ha.reorder()
    return D, I
def pairwise_distances(xq, xb, mt=METRIC_L2, metric_arg=0):
    """Return the full (nq, nb) float32 matrix of distances between the
    rows of xq and the rows of xb.

    METRIC_L2 is handled by `pairwise_L2sqr`; any other metric type is
    forwarded, together with `metric_arg`, to `pairwise_extra_distances`.
    """
    n_query, dim = xq.shape
    n_base, dim_base = xb.shape
    assert dim == dim_base
    out = np.empty((n_query, n_base), dtype='float32')
    if mt != METRIC_L2:
        pairwise_extra_distances(
            dim, n_query, swig_ptr(xq),
            n_base, swig_ptr(xb),
            mt, metric_arg,
            swig_ptr(out))
        return out
    pairwise_L2sqr(
        dim, n_query, swig_ptr(xq),
        n_base, swig_ptr(xb),
        swig_ptr(out))
    return out


def rand(n, seed=12345):
    """Allocate a float32 array of shape n and fill it with `float_rand`."""
    values = np.empty(n, dtype='float32')
    float_rand(swig_ptr(values), values.size, seed)
    return values


def randint(n, seed=12345, vmax=None):
    """Allocate an int64 array of shape n and fill it with `int64_rand`, or
    with `int64_rand_max` when a bound `vmax` is given."""
    values = np.empty(n, dtype='int64')
    if vmax is not None:
        int64_rand_max(swig_ptr(values), values.size, vmax, seed)
    else:
        int64_rand(swig_ptr(values), values.size, seed)
    return values

# alternate name for randint
lrand = randint

def randn(n, seed=12345):
    """Allocate a float32 array of shape n and fill it with `float_randn`."""
    values = np.empty(n, dtype='float32')
    float_randn(swig_ptr(values), values.size, seed)
    return values


def eval_intersection(I1, I2):
    """ size of intersection between each line of two result tables,
    summed over all lines (both tables must have the same number of lines)"""
    n_lines = I1.shape[0]
    assert I2.shape[0] == n_lines
    width1, width2 = I1.shape[1], I2.shape[1]
    return sum(
        ranklist_intersection_size(
            width1, swig_ptr(row1), width2, swig_ptr(row2))
        for row1, row2 in zip(I1, I2)
    )


def normalize_L2(x):
    """Call `fvec_renorm_L2` on the rows of the 2D array x. The array is
    passed by pointer and nothing is returned."""
    n_rows, dim = x.shape[0], x.shape[1]
    fvec_renorm_L2(dim, n_rows, swig_ptr(x))
######################################################
# MapLong2Long interface
######################################################

def replacement_map_add(self, keys, vals):
    """Add key/value pairs given as two numpy arrays.

    Parameters
    ----------
    keys : array_like
        Keys, shape (n, ).
    vals : array_like
        Values, shape (n, ), one per key.

    NOTE(review): both arrays are presumably int64 given the class name
    MapLong2Long -- confirm against the C++ signature.
    """
    n, = keys.shape
    # the C++ side reads n entries from both buffers, so vals must have the
    # same length as keys
    assert vals.shape == (n,)
    self.add_c(n, swig_ptr(keys), swig_ptr(vals))

def replacement_map_search_multiple(self, keys):
    """Look up several keys at once.

    Parameters
    ----------
    keys : array_like
        Keys, shape (n, ).

    Returns
    -------
    vals : array_like
        int64 array of shape (n, ) filled by the C++ `search_multiple`.
    """
    n, = keys.shape
    vals = np.empty(n, dtype='int64')
    self.search_multiple_c(n, swig_ptr(keys), swig_ptr(vals))
    return vals

replace_method(MapLong2Long, 'add', replacement_map_add)
replace_method(MapLong2Long, 'search_multiple', replacement_map_search_multiple)

######################################################
# search_with_parameters interface
######################################################

# keep a handle on the C++ function before the name is rebound below
search_with_parameters_c = search_with_parameters

def search_with_parameters(index, x, k, params=None, output_stats=False):
    """Search an index with explicit IVF search parameters.

    Parameters
    ----------
    index :
        Index to search; `x` must have `index.d` columns.
    x : array_like
        Query vectors, shape (n, d).
    k : int
        Number of results per query.
    params : optional
        Search parameters. When not provided, an `IVFSearchParameters` is
        built from the `nprobe` and `max_codes` of the IVF index extracted
        from `index`.
    output_stats : bool, optional
        Also return a dictionary of statistics.

    Returns
    -------
    distances : float32 array, shape (n, k)
    labels : int64 array, shape (n, k)
    stats : dict, only if `output_stats` is set, with keys 'ndis',
        'pre_transform_ms', 'coarse_quantizer_ms' and 'invlist_scan_ms'
    """
    n, d = x.shape
    assert d == index.d
    if not params:
        # if not provided use the ones set in the IVF object
        params = IVFSearchParameters()
        index_ivf = extract_index_ivf(index)
        params.nprobe = index_ivf.nprobe
        params.max_codes = index_ivf.max_codes
    # output buffers filled by the C++ call
    nb_dis = np.empty(1, 'uint64')
    ms_per_stage = np.empty(3, 'float64')
    distances = np.empty((n, k), dtype=np.float32)
    labels = np.empty((n, k), dtype=np.int64)
    search_with_parameters_c(
        index, n, swig_ptr(x),
        k, swig_ptr(distances),
        swig_ptr(labels),
        params, swig_ptr(nb_dis), swig_ptr(ms_per_stage)
    )
    if not output_stats:
        return distances, labels
    else:
        stats = {
            'ndis': nb_dis[0],
            'pre_transform_ms': ms_per_stage[0],
            'coarse_quantizer_ms': ms_per_stage[1],
            'invlist_scan_ms': ms_per_stage[2],
        }
        return distances, labels, stats

# keep a handle on the C++ function before the name is rebound below
range_search_with_parameters_c = range_search_with_parameters

def range_search_with_parameters(index, x, radius, params=None, output_stats=False):
    """Range search on an index with explicit IVF search parameters.

    Parameters
    ----------
    index :
        Index to search; `x` must have `index.d` columns.
    x : array_like
        Query vectors, shape (n, d).
    radius : float
        Search radius passed to the C++ function.
    params : optional
        Search parameters. When not provided, an `IVFSearchParameters` is
        built from the `nprobe` and `max_codes` of the IVF index extracted
        from `index`.
    output_stats : bool, optional
        Also return a dictionary of statistics.

    Returns
    -------
    lims : array of size n + 1; the last entry is the total number of results
    Dout : array of size lims[-1], copied from the result distances
    Iout : array of size lims[-1], copied from the result labels
    stats : dict, only if `output_stats` is set, with keys 'ndis',
        'pre_transform_ms', 'coarse_quantizer_ms' and 'invlist_scan_ms'
    """
    n, d = x.shape
    assert d == index.d
    if not params:
        # if not provided use the ones set in the IVF object
        params = IVFSearchParameters()
        index_ivf = extract_index_ivf(index)
        params.nprobe = index_ivf.nprobe
        params.max_codes = index_ivf.max_codes
    nb_dis = np.empty(1, 'uint64')
    ms_per_stage = np.empty(3, 'float64')
    res = RangeSearchResult(n)
    range_search_with_parameters_c(
        index, n, swig_ptr(x),
        radius, res,
        params, swig_ptr(nb_dis), swig_ptr(ms_per_stage)
    )
    # copy the buffers out of the result object before it goes away
    lims = rev_swig_ptr(res.lims, n + 1).copy()
    nd = int(lims[-1])
    Dout = rev_swig_ptr(res.distances, nd).copy()
    Iout = rev_swig_ptr(res.labels, nd).copy()
    if not output_stats:
        return lims, Dout, Iout
    else:
        stats = {
            'ndis': nb_dis[0],
            'pre_transform_ms': ms_per_stage[0],
            'coarse_quantizer_ms': ms_per_stage[1],
            'invlist_scan_ms': ms_per_stage[2],
        }
        return lims, Dout, Iout, stats
######################################################
# KNN function
######################################################

def knn(xq, xb, k, metric=METRIC_L2):
    """
    Compute the k nearest neighbors of a vector without constructing an index


    Parameters
    ----------
    xq : array_like
        Query vectors, shape (nq, d) where d is appropriate for the index.
        `dtype` must be float32.
    xb : array_like
        Database vectors, shape (nb, d) where d is appropriate for the index.
        `dtype` must be float32.
    k : int
        Number of nearest neighbors.
    metric : MetricType, optional
        distance measure to use (either METRIC_L2 or METRIC_INNER_PRODUCT)

    Returns
    -------
    D : array_like
        Distances of the nearest neighbors, shape (nq, k)
    I : array_like
        Labels of the nearest neighbors, shape (nq, k)

    Raises
    ------
    NotImplementedError
        If `metric` is neither METRIC_L2 nor METRIC_INNER_PRODUCT.
    """
    nq, d = xq.shape
    nb, d2 = xb.shape
    assert d == d2

    I = np.empty((nq, k), dtype='int64')
    D = np.empty((nq, k), dtype='float32')

    # the two metrics differ only in the heap type and the C++ routine
    if metric == METRIC_L2:
        heaps = float_maxheap_array_t()
        search = knn_L2sqr
    elif metric == METRIC_INNER_PRODUCT:
        heaps = float_minheap_array_t()
        search = knn_inner_product
    else:
        raise NotImplementedError("only L2 and INNER_PRODUCT are supported")

    # the heaps write their results directly into D and I
    heaps.k = k
    heaps.nh = nq
    heaps.val = swig_ptr(D)
    heaps.ids = swig_ptr(I)
    search(
        swig_ptr(xq), swig_ptr(xb),
        d, nq, nb, heaps
    )
    return D, I


###########################################
# Kmeans object
###########################################


class Kmeans:
    """Object that performs k-means clustering and manages the centroids.
    The `Kmeans` class is essentially a wrapper around the C++ `Clustering` object.

    Parameters
    ----------
    d : int
       dimension of the vectors to cluster
    k : int
       number of clusters
    gpu: bool or int, optional
       False: don't use GPU
       True or -1: use all GPUs
       number: use this many GPUs
    progressive_dim_steps:
        use a progressive dimension clustering (with that number of steps)

    Subsequent parameters are fields of the Clustering object. The most important are:

    niter: int, optional
       clustering iterations
    nredo: int, optional
       redo clustering this many times and keep best
    verbose: bool, optional
    spherical: bool, optional
       do we want normalized centroids?
    int_centroids: bool, optional
       round centroids coordinates to integer
    seed: int, optional
       seed for the random number generator

    """

    def __init__(self, d, k, **kwargs):
        """d: input dimension, k: nb of centroids. Additional
        parameters are passed on the ClusteringParameters object,
        including niter=25, verbose=False, spherical = False

        Raises AttributeError if a keyword other than `gpu` is not a field
        of the parameters object.
        """
        self.d = d
        self.k = k
        self.gpu = False
        if "progressive_dim_steps" in kwargs:
            self.cp = ProgressiveDimClusteringParameters()
        else:
            self.cp = ClusteringParameters()
        for name, value in kwargs.items():
            if name == 'gpu':
                # only a real boolean True (or -1) means "all GPUs"; the
                # integer 1 compares equal to True but asks for one GPU
                use_all = isinstance(value, (bool, np.bool_)) and bool(value)
                if use_all or value == -1:
                    value = get_num_gpus()
                self.gpu = value
            else:
                # if this raises an exception, it means that it is a non-existent field
                getattr(self.cp, name)
                setattr(self.cp, name, value)
        self.centroids = None

    def train(self, x, weights=None, init_centroids=None):
        """ Perform k-means clustering.
        On output of the function call:

        - the centroids are in the centroids field of size (`k`, `d`).

        - the objective value at each iteration is in the array obj (one
          entry per element of the clustering's `iteration_stats`)

        - detailed optimization statistics are in the array iteration_stats.

        Parameters
        ----------
        x : array_like
            Training vectors, shape (n, d), `dtype` must be float32 and n should
            be larger than the number of clusters `k`.
        weights : array_like
            weight associated to each vector, shape `n`
            (not supported with progressive dimension clustering)
        init_centroids : array_like
            initial set of centroids, shape (nc, d); only the second
            dimension is checked here
            (not supported with progressive dimension clustering)

        Returns
        -------
        final_obj: float
            final optimization objective (0.0 if no iteration was recorded)

        """
        n, d = x.shape
        assert d == self.d

        if self.cp.__class__ == ClusteringParameters:
            # regular clustering
            clus = Clustering(d, self.k, self.cp)
            if init_centroids is not None:
                nc, d2 = init_centroids.shape
                assert d2 == d
                copy_array_to_vector(init_centroids.ravel(), clus.centroids)
            # the assignment index is kept on the object: assign() reuses it
            if self.cp.spherical:
                self.index = IndexFlatIP(d)
            else:
                self.index = IndexFlatL2(d)
            if self.gpu:
                self.index = index_cpu_to_all_gpus(self.index, ngpu=self.gpu)
            clus.train(x, self.index, weights)
        else:
            # not supported for progressive dim
            assert weights is None
            assert init_centroids is None
            assert not self.cp.spherical
            clus = ProgressiveDimClustering(d, self.k, self.cp)
            if self.gpu:
                fac = GpuProgressiveDimIndexFactory(ngpu=self.gpu)
            else:
                fac = ProgressiveDimIndexFactory()
            clus.train(n, swig_ptr(x), fac)

        centroids = vector_float_to_array(clus.centroids)

        self.centroids = centroids.reshape(self.k, d)
        stats = clus.iteration_stats
        stats = [stats.at(i) for i in range(stats.size())]
        self.obj = np.array([st.obj for st in stats])
        # copy all the iteration_stats objects to a python array
        stat_fields = 'obj time time_search imbalance_factor nsplit'.split()
        self.iteration_stats = [
            {field: getattr(st, field) for field in stat_fields}
            for st in stats
        ]
        return self.obj[-1] if self.obj.size > 0 else 0.0

    def assign(self, x):
        """Assign each vector of x to its nearest centroid.

        The index stored by `train` is reset and refilled with the current
        centroids before searching. `self.index` is only set by the regular
        (non progressive-dim) branch of `train`.

        Returns
        -------
        D : flattened distances to the assigned centroid, one per vector
        I : flattened centroid ids, one per vector
        """
        assert self.centroids is not None, "should train before assigning"
        self.index.reset()
        self.index.add(self.centroids)
        D, I = self.index.search(x, 1)
        return D.ravel(), I.ravel()
# IndexProxy was renamed to IndexReplicas, remap the old name for any old code
# people may have
IndexProxy = IndexReplicas
ConcatenatedInvertedLists = HStackInvertedLists

###########################################
# serialization of indexes to byte arrays
###########################################

def serialize_index(index):
    """ convert an index to a numpy uint8 array  """
    writer = VectorIOWriter()
    write_index(index, writer)
    return vector_to_array(writer.data)

def deserialize_index(data):
    """ rebuild an index from an array produced by serialize_index
    (`data` is copied into the reader, so it must be a 1D array whose dtype
    matches the reader's byte vector) """
    reader = VectorIOReader()
    copy_array_to_vector(data, reader.data)
    return read_index(reader)

def serialize_index_binary(index):
    """ convert a binary index to a numpy uint8 array  """
    writer = VectorIOWriter()
    write_index_binary(index, writer)
    return vector_to_array(writer.data)

def deserialize_index_binary(data):
    """ rebuild a binary index from an array produced by
    serialize_index_binary """
    reader = VectorIOReader()
    copy_array_to_vector(data, reader.data)
    return read_index_binary(reader)


###########################################
# ResultHeap
###########################################

class ResultHeap:
    """Accumulate query results from a sliced dataset. The final result will
    be in self.D, self.I."""

    def __init__(self, nq, k, keep_max=False):
        """
        nq: number of query vectors,
        k: number of results per query
        keep_max: keep the k largest values instead of the k smallest
            (for similarity measures such as inner product); the default
            keeps the k smallest, as before
        """
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        # same heap choice as kmax / kmin: a min-heap retains the largest
        # values, a max-heap the smallest
        if keep_max:
            heaps = float_minheap_array_t()
        else:
            heaps = float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        # the heaps operate directly on the memory of self.D and self.I,
        # which stay alive as attributes of this object
        heaps.val = swig_ptr(self.D)
        heaps.ids = swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_result(self, D, I):
        """D, I do not need to be in a particular order (heap or sorted);
        both must have shape (nq, k)"""
        assert D.shape == (self.nq, self.k)
        assert I.shape == (self.nq, self.k)
        self.heaps.addn_with_ids(
            self.k, swig_ptr(D),
            swig_ptr(I), self.k)

    def finalize(self):
        """reorder the heaps once all results have been added; the result
        is then read from self.D and self.I"""
        self.heaps.reorder()