1# Copyright (c) Facebook, Inc. and its affiliates. 2# 3# This source code is licensed under the MIT license found in the 4# LICENSE file in the root directory of this source tree. 5 6"""this is a basic test script for simple indices work""" 7from __future__ import absolute_import, division, print_function 8# no unicode_literals because it messes up in py2 9 10import numpy as np 11import unittest 12import faiss 13import tempfile 14import os 15import re 16import warnings 17 18from common_faiss_tests import get_dataset, get_dataset_2 19 20class TestModuleInterface(unittest.TestCase): 21 22 def test_version_attribute(self): 23 assert hasattr(faiss, '__version__') 24 assert re.match('^\\d+\\.\\d+\\.\\d+$', faiss.__version__) 25 26class TestIndexFlat(unittest.TestCase): 27 28 def do_test(self, nq, metric_type=faiss.METRIC_L2, k=10): 29 d = 32 30 nb = 1000 31 nt = 0 32 33 (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) 34 index = faiss.IndexFlat(d, metric_type) 35 36 ### k-NN search 37 38 index.add(xb) 39 D1, I1 = index.search(xq, k) 40 41 if metric_type == faiss.METRIC_L2: 42 all_dis = ((xq.reshape(nq, 1, d) - xb.reshape(1, nb, d)) ** 2).sum(2) 43 Iref = all_dis.argsort(axis=1)[:, :k] 44 else: 45 all_dis = np.dot(xq, xb.T) 46 Iref = all_dis.argsort(axis=1)[:, ::-1][:, :k] 47 48 Dref = all_dis[np.arange(nq)[:, None], Iref] 49 self.assertLessEqual((Iref != I1).sum(), Iref.size * 0.0001) 50 # np.testing.assert_equal(Iref, I1) 51 np.testing.assert_almost_equal(Dref, D1, decimal=5) 52 53 ### Range search 54 55 radius = float(np.median(Dref[:, -1])) 56 57 lims, D2, I2 = index.range_search(xq, radius) 58 59 for i in range(nq): 60 l0, l1 = lims[i:i + 2] 61 _, Il = D2[l0:l1], I2[l0:l1] 62 if metric_type == faiss.METRIC_L2: 63 Ilref, = np.where(all_dis[i] < radius) 64 else: 65 Ilref, = np.where(all_dis[i] > radius) 66 Il.sort() 67 Ilref.sort() 68 np.testing.assert_equal(Il, Ilref) 69 np.testing.assert_almost_equal( 70 all_dis[i, Ilref], D2[l0:l1], 71 decimal=5 72 ) 73 74 def set_blas_blocks(self, small): 75 if small: 76 faiss.cvar.distance_compute_blas_query_bs = 16 77 faiss.cvar.distance_compute_blas_database_bs = 12 78 else: 79 faiss.cvar.distance_compute_blas_query_bs = 4096 80 faiss.cvar.distance_compute_blas_database_bs = 1024 81 82 def test_with_blas(self): 83 self.set_blas_blocks(small=True) 84 self.do_test(200) 85 self.set_blas_blocks(small=False) 86 87 def test_noblas(self): 88 self.do_test(10) 89 90 def test_with_blas_ip(self): 91 self.set_blas_blocks(small=True) 92 self.do_test(200, faiss.METRIC_INNER_PRODUCT) 93 self.set_blas_blocks(small=False) 94 95 def test_noblas_ip(self): 96 self.do_test(10, faiss.METRIC_INNER_PRODUCT) 97 98 def test_noblas_reservoir(self): 99 self.do_test(10, k=150) 100 101 def test_with_blas_reservoir(self): 102 self.do_test(200, k=150) 103 104 def test_noblas_reservoir_ip(self): 105 self.do_test(10, faiss.METRIC_INNER_PRODUCT, k=150) 106 107 def test_with_blas_reservoir_ip(self): 108 self.do_test(200, faiss.METRIC_INNER_PRODUCT, k=150) 109 110 111 112 113 114class EvalIVFPQAccuracy(unittest.TestCase): 115 116 def test_IndexIVFPQ(self): 117 d = 32 118 nb = 1000 119 nt = 1500 120 nq = 200 121 122 (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) 123 124 gt_index = faiss.IndexFlatL2(d) 125 gt_index.add(xb) 126 D, gt_nns = gt_index.search(xq, 1) 127 128 coarse_quantizer = faiss.IndexFlatL2(d) 129 index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8) 130 index.cp.min_points_per_centroid = 5 # quiet warning 131 index.train(xt) 132 index.add(xb) 133 index.nprobe = 4 134 D, nns = index.search(xq, 10) 135 n_ok = (nns == gt_nns).sum() 136 nq = xq.shape[0] 137 138 self.assertGreater(n_ok, nq * 0.66) 139 140 # check that and Index2Layer gives the same reconstruction 141 # this is a bit fragile: it assumes 2 runs of training give 142 # the exact same result. 143 index2 = faiss.Index2Layer(coarse_quantizer, 32, 8) 144 if True: 145 index2.train(xt) 146 else: 147 index2.pq = index.pq 148 index2.is_trained = True 149 index2.add(xb) 150 ref_recons = index.reconstruct_n(0, nb) 151 new_recons = index2.reconstruct_n(0, nb) 152 self.assertTrue(np.all(ref_recons == new_recons)) 153 154 155 def test_IMI(self): 156 d = 32 157 nb = 1000 158 nt = 1500 159 nq = 200 160 161 (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) 162 d = xt.shape[1] 163 164 gt_index = faiss.IndexFlatL2(d) 165 gt_index.add(xb) 166 D, gt_nns = gt_index.search(xq, 1) 167 168 nbits = 5 169 coarse_quantizer = faiss.MultiIndexQuantizer(d, 2, nbits) 170 index = faiss.IndexIVFPQ(coarse_quantizer, d, (1 << nbits) ** 2, 8, 8) 171 index.quantizer_trains_alone = 1 172 index.train(xt) 173 index.add(xb) 174 index.nprobe = 100 175 D, nns = index.search(xq, 10) 176 n_ok = (nns == gt_nns).sum() 177 178 # Should return 166 on mac, and 170 on linux. 179 self.assertGreater(n_ok, 165) 180 181 ############# replace with explicit assignment indexes 182 nbits = 5 183 pq = coarse_quantizer.pq 184 centroids = faiss.vector_to_array(pq.centroids) 185 centroids = centroids.reshape(pq.M, pq.ksub, pq.dsub) 186 ai0 = faiss.IndexFlatL2(pq.dsub) 187 ai0.add(centroids[0]) 188 ai1 = faiss.IndexFlatL2(pq.dsub) 189 ai1.add(centroids[1]) 190 191 coarse_quantizer_2 = faiss.MultiIndexQuantizer2(d, nbits, ai0, ai1) 192 coarse_quantizer_2.pq = pq 193 coarse_quantizer_2.is_trained = True 194 195 index.quantizer = coarse_quantizer_2 196 197 index.reset() 198 index.add(xb) 199 200 D, nns = index.search(xq, 10) 201 n_ok = (nns == gt_nns).sum() 202 203 # should return the same result 204 self.assertGreater(n_ok, 165) 205 206 207 def test_IMI_2(self): 208 d = 32 209 nb = 1000 210 nt = 1500 211 nq = 200 212 213 (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) 214 d = xt.shape[1] 215 216 gt_index = faiss.IndexFlatL2(d) 217 gt_index.add(xb) 218 D, gt_nns = gt_index.search(xq, 1) 219 220 ############# redo including training 221 nbits = 5 222 ai0 = faiss.IndexFlatL2(int(d / 2)) 223 ai1 = faiss.IndexFlatL2(int(d / 2)) 224 225 coarse_quantizer = faiss.MultiIndexQuantizer2(d, nbits, ai0, ai1) 226 index = faiss.IndexIVFPQ(coarse_quantizer, d, (1 << nbits) ** 2, 8, 8) 227 index.quantizer_trains_alone = 1 228 index.train(xt) 229 index.add(xb) 230 index.nprobe = 100 231 D, nns = index.search(xq, 10) 232 n_ok = (nns == gt_nns).sum() 233 234 # should return the same result 235 self.assertGreater(n_ok, 165) 236 237 238 239 240 241class TestMultiIndexQuantizer(unittest.TestCase): 242 243 def test_search_k1(self): 244 245 # verify codepath for k = 1 and k > 1 246 247 d = 64 248 nb = 0 249 nt = 1500 250 nq = 200 251 252 (xt, xb, xq) = get_dataset(d, nb, nt, nq) 253 254 miq = faiss.MultiIndexQuantizer(d, 2, 6) 255 256 miq.train(xt) 257 258 D1, I1 = miq.search(xq, 1) 259 260 D5, I5 = miq.search(xq, 5) 261 262 self.assertEqual(np.abs(I1[:, :1] - I5[:, :1]).max(), 0) 263 self.assertEqual(np.abs(D1[:, :1] - D5[:, :1]).max(), 0) 264 265 266class TestScalarQuantizer(unittest.TestCase): 267 268 def test_4variants_ivf(self): 269 d = 32 270 nt = 2500 271 nq = 400 272 nb = 5000 273 274 (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) 275 276 # common quantizer 277 quantizer = faiss.IndexFlatL2(d) 278 279 ncent = 64 280 281 index_gt = faiss.IndexFlatL2(d) 282 index_gt.add(xb) 283 D, I_ref = index_gt.search(xq, 10) 284 285 nok = {} 286 287 index = faiss.IndexIVFFlat(quantizer, d, ncent, 288 faiss.METRIC_L2) 289 index.cp.min_points_per_centroid = 5 # quiet warning 290 index.nprobe = 4 291 index.train(xt) 292 index.add(xb) 293 D, I = index.search(xq, 10) 294 nok['flat'] = (I[:, 0] == I_ref[:, 0]).sum() 295 296 for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16".split(): 297 qtype = getattr(faiss.ScalarQuantizer, qname) 298 index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent, 299 qtype, faiss.METRIC_L2) 300 301 index.nprobe = 4 302 index.train(xt) 303 index.add(xb) 304 D, I = index.search(xq, 10) 305 306 nok[qname] = (I[:, 0] == I_ref[:, 0]).sum() 307 print(nok, nq) 308 309 self.assertGreaterEqual(nok['flat'], nq * 0.6) 310 # The tests below are a bit fragile, it happens that the 311 # ordering between uniform and non-uniform are reverted, 312 # probably because the dataset is small, which introduces 313 # jitter 314 self.assertGreaterEqual(nok['flat'], nok['QT_8bit']) 315 self.assertGreaterEqual(nok['QT_8bit'], nok['QT_4bit']) 316 self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform']) 317 self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform']) 318 self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit']) 319 320 def test_4variants(self): 321 d = 32 322 nt = 2500 323 nq = 400 324 nb = 5000 325 326 (xt, xb, xq) = get_dataset(d, nb, nt, nq) 327 328 index_gt = faiss.IndexFlatL2(d) 329 index_gt.add(xb) 330 D_ref, I_ref = index_gt.search(xq, 10) 331 332 nok = {} 333 334 for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16".split(): 335 qtype = getattr(faiss.ScalarQuantizer, qname) 336 index = faiss.IndexScalarQuantizer(d, qtype, faiss.METRIC_L2) 337 index.train(xt) 338 index.add(xb) 339 D, I = index.search(xq, 10) 340 nok[qname] = (I[:, 0] == I_ref[:, 0]).sum() 341 342 print(nok, nq) 343 344 self.assertGreaterEqual(nok['QT_8bit'], nq * 0.9) 345 self.assertGreaterEqual(nok['QT_8bit'], nok['QT_4bit']) 346 self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform']) 347 self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform']) 348 self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit']) 349 350 351class TestRangeSearch(unittest.TestCase): 352 353 def test_range_search(self): 354 d = 4 355 nt = 100 356 nq = 10 357 nb = 50 358 359 (xt, xb, xq) = get_dataset(d, nb, nt, nq) 360 361 index = faiss.IndexFlatL2(d) 362 index.add(xb) 363 364 Dref, Iref = index.search(xq, 5) 365 366 thresh = 0.1 # *squared* distance 367 lims, D, I = index.range_search(xq, thresh) 368 369 for i in range(nq): 370 Iline = I[lims[i]:lims[i + 1]] 371 Dline = D[lims[i]:lims[i + 1]] 372 for j, dis in zip(Iref[i], Dref[i]): 373 if dis < thresh: 374 li, = np.where(Iline == j) 375 self.assertTrue(li.size == 1) 376 idx = li[0] 377 self.assertGreaterEqual(1e-4, abs(Dline[idx] - dis)) 378 379 380class TestSearchAndReconstruct(unittest.TestCase): 381 382 def run_search_and_reconstruct(self, index, xb, xq, k=10, eps=None): 383 n, d = xb.shape 384 assert xq.shape[1] == d 385 assert index.d == d 386 387 D_ref, I_ref = index.search(xq, k) 388 R_ref = index.reconstruct_n(0, n) 389 D, I, R = index.search_and_reconstruct(xq, k) 390 391 np.testing.assert_almost_equal(D, D_ref, decimal=5) 392 self.assertTrue((I == I_ref).all()) 393 self.assertEqual(R.shape[:2], I.shape) 394 self.assertEqual(R.shape[2], d) 395 396 # (n, k, ..) -> (n * k, ..) 397 I_flat = I.reshape(-1) 398 R_flat = R.reshape(-1, d) 399 # Filter out -1s when not enough results 400 R_flat = R_flat[I_flat >= 0] 401 I_flat = I_flat[I_flat >= 0] 402 403 recons_ref_err = np.mean(np.linalg.norm(R_flat - R_ref[I_flat])) 404 self.assertLessEqual(recons_ref_err, 1e-6) 405 406 def norm1(x): 407 return np.sqrt((x ** 2).sum(axis=1)) 408 409 recons_err = np.mean(norm1(R_flat - xb[I_flat])) 410 411 print('Reconstruction error = %.3f' % recons_err) 412 if eps is not None: 413 self.assertLessEqual(recons_err, eps) 414 415 return D, I, R 416 417 def test_IndexFlat(self): 418 d = 32 419 nb = 1000 420 nt = 1500 421 nq = 200 422 423 (xt, xb, xq) = get_dataset(d, nb, nt, nq) 424 425 index = faiss.IndexFlatL2(d) 426 index.add(xb) 427 428 self.run_search_and_reconstruct(index, xb, xq, eps=0.0) 429 430 def test_IndexIVFFlat(self): 431 d = 32 432 nb = 1000 433 nt = 1500 434 nq = 200 435 436 (xt, xb, xq) = get_dataset(d, nb, nt, nq) 437 438 quantizer = faiss.IndexFlatL2(d) 439 index = faiss.IndexIVFFlat(quantizer, d, 32, faiss.METRIC_L2) 440 index.cp.min_points_per_centroid = 5 # quiet warning 441 index.nprobe = 4 442 index.train(xt) 443 index.add(xb) 444 445 self.run_search_and_reconstruct(index, xb, xq, eps=0.0) 446 447 def test_IndexIVFPQ(self): 448 d = 32 449 nb = 1000 450 nt = 1500 451 nq = 200 452 453 (xt, xb, xq) = get_dataset(d, nb, nt, nq) 454 455 quantizer = faiss.IndexFlatL2(d) 456 index = faiss.IndexIVFPQ(quantizer, d, 32, 8, 8) 457 index.cp.min_points_per_centroid = 5 # quiet warning 458 index.nprobe = 4 459 index.train(xt) 460 index.add(xb) 461 462 self.run_search_and_reconstruct(index, xb, xq, eps=1.0) 463 464 def test_MultiIndex(self): 465 d = 32 466 nb = 1000 467 nt = 1500 468 nq = 200 469 470 (xt, xb, xq) = get_dataset(d, nb, nt, nq) 471 472 index = faiss.index_factory(d, "IMI2x5,PQ8np") 473 faiss.ParameterSpace().set_index_parameter(index, "nprobe", 4) 474 index.train(xt) 475 index.add(xb) 476 477 self.run_search_and_reconstruct(index, xb, xq, eps=1.0) 478 479 def test_IndexTransform(self): 480 d = 32 481 nb = 1000 482 nt = 1500 483 nq = 200 484 485 (xt, xb, xq) = get_dataset(d, nb, nt, nq) 486 487 index = faiss.index_factory(d, "L2norm,PCA8,IVF32,PQ8np") 488 faiss.ParameterSpace().set_index_parameter(index, "nprobe", 4) 489 index.train(xt) 490 index.add(xb) 491 492 self.run_search_and_reconstruct(index, xb, xq) 493 494 495class TestHNSW(unittest.TestCase): 496 497 def __init__(self, *args, **kwargs): 498 unittest.TestCase.__init__(self, *args, **kwargs) 499 d = 32 500 nt = 0 501 nb = 1500 502 nq = 500 503 504 (_, self.xb, self.xq) = get_dataset_2(d, nt, nb, nq) 505 index = faiss.IndexFlatL2(d) 506 index.add(self.xb) 507 Dref, Iref = index.search(self.xq, 1) 508 self.Iref = Iref 509 510 def test_hnsw(self): 511 d = self.xq.shape[1] 512 513 index = faiss.IndexHNSWFlat(d, 16) 514 index.add(self.xb) 515 Dhnsw, Ihnsw = index.search(self.xq, 1) 516 517 self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 460) 518 519 self.io_and_retest(index, Dhnsw, Ihnsw) 520 521 def test_hnsw_unbounded_queue(self): 522 d = self.xq.shape[1] 523 524 index = faiss.IndexHNSWFlat(d, 16) 525 index.add(self.xb) 526 index.search_bounded_queue = False 527 Dhnsw, Ihnsw = index.search(self.xq, 1) 528 529 self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 460) 530 531 self.io_and_retest(index, Dhnsw, Ihnsw) 532 533 def io_and_retest(self, index, Dhnsw, Ihnsw): 534 fd, tmpfile = tempfile.mkstemp() 535 os.close(fd) 536 try: 537 faiss.write_index(index, tmpfile) 538 index2 = faiss.read_index(tmpfile) 539 finally: 540 if os.path.exists(tmpfile): 541 os.unlink(tmpfile) 542 543 Dhnsw2, Ihnsw2 = index2.search(self.xq, 1) 544 545 self.assertTrue(np.all(Dhnsw2 == Dhnsw)) 546 self.assertTrue(np.all(Ihnsw2 == Ihnsw)) 547 548 # also test clone 549 index3 = faiss.clone_index(index) 550 Dhnsw3, Ihnsw3 = index3.search(self.xq, 1) 551 552 self.assertTrue(np.all(Dhnsw3 == Dhnsw)) 553 self.assertTrue(np.all(Ihnsw3 == Ihnsw)) 554 555 556 def test_hnsw_2level(self): 557 d = self.xq.shape[1] 558 559 quant = faiss.IndexFlatL2(d) 560 561 index = faiss.IndexHNSW2Level(quant, 256, 8, 8) 562 index.train(self.xb) 563 index.add(self.xb) 564 Dhnsw, Ihnsw = index.search(self.xq, 1) 565 566 self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 310) 567 568 self.io_and_retest(index, Dhnsw, Ihnsw) 569 570 def test_add_0_vecs(self): 571 index = faiss.IndexHNSWFlat(10, 16) 572 zero_vecs = np.zeros((0, 10), dtype='float32') 573 # infinite loop 574 index.add(zero_vecs) 575 576 def test_hnsw_IP(self): 577 d = self.xq.shape[1] 578 579 index_IP = faiss.IndexFlatIP(d) 580 index_IP.add(self.xb) 581 Dref, Iref = index_IP.search(self.xq, 1) 582 583 index = faiss.IndexHNSWFlat(d, 16, faiss.METRIC_INNER_PRODUCT) 584 index.add(self.xb) 585 Dhnsw, Ihnsw = index.search(self.xq, 1) 586 587 print('nb equal: ', (Iref == Ihnsw).sum()) 588 589 self.assertGreaterEqual((Iref == Ihnsw).sum(), 480) 590 591 mask = Iref[:, 0] == Ihnsw[:, 0] 592 assert np.allclose(Dref[mask, 0], Dhnsw[mask, 0]) 593 594 595class TestNSG(unittest.TestCase): 596 597 def __init__(self, *args, **kwargs): 598 unittest.TestCase.__init__(self, *args, **kwargs) 599 d = 32 600 nt = 0 601 nb = 1500 602 nq = 500 603 self.GK = 32 604 605 _, self.xb, self.xq = get_dataset_2(d, nt, nb, nq) 606 607 def make_knn_graph(self, metric): 608 n = self.xb.shape[0] 609 d = self.xb.shape[1] 610 index = faiss.IndexFlat(d, metric) 611 index.add(self.xb) 612 _, I = index.search(self.xb, self.GK + 1) 613 knn_graph = np.zeros((n, self.GK), dtype=np.int64) 614 615 # For the inner product distance, the distance between a vector and itself 616 # may not be the smallest, so it is not guaranteed that I[:, 0] is the query itself. 617 for i in range(n): 618 cnt = 0 619 for j in range(self.GK + 1): 620 if I[i, j] != i: 621 knn_graph[i, cnt] = I[i, j] 622 cnt += 1 623 if cnt == self.GK: 624 break 625 return knn_graph 626 627 def subtest_io_and_clone(self, index, Dnsg, Insg): 628 fd, tmpfile = tempfile.mkstemp() 629 os.close(fd) 630 try: 631 faiss.write_index(index, tmpfile) 632 index2 = faiss.read_index(tmpfile) 633 finally: 634 if os.path.exists(tmpfile): 635 os.unlink(tmpfile) 636 637 Dnsg2, Insg2 = index2.search(self.xq, 1) 638 639 self.assertTrue(np.all(Dnsg2 == Dnsg)) 640 self.assertTrue(np.all(Insg2 == Insg)) 641 642 # also test clone 643 index3 = faiss.clone_index(index) 644 Dnsg3, Insg3 = index3.search(self.xq, 1) 645 646 self.assertTrue(np.all(Dnsg3 == Dnsg)) 647 self.assertTrue(np.all(Insg3 == Insg)) 648 649 def subtest_connectivity(self, index, nb): 650 vt = faiss.VisitedTable(nb) 651 count = index.nsg.dfs(vt, index.nsg.enterpoint, 0) 652 self.assertEqual(count, nb) 653 654 def subtest_add(self, build_type, thresh, metric=faiss.METRIC_L2): 655 d = self.xq.shape[1] 656 metrics = {faiss.METRIC_L2: 'L2', 657 faiss.METRIC_INNER_PRODUCT: 'IP'} 658 659 flat_index = faiss.IndexFlat(d, metric) 660 flat_index.add(self.xb) 661 Dref, Iref = flat_index.search(self.xq, 1) 662 663 index = faiss.IndexNSGFlat(d, 16, metric) 664 index.verbose = True 665 index.build_type = build_type 666 index.GK = self.GK 667 index.add(self.xb) 668 Dnsg, Insg = index.search(self.xq, 1) 669 670 recalls = (Iref == Insg).sum() 671 print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) 672 self.assertGreaterEqual(recalls, thresh) 673 self.subtest_connectivity(index, self.xb.shape[0]) 674 self.subtest_io_and_clone(index, Dnsg, Insg) 675 676 def subtest_build(self, knn_graph, thresh, metric=faiss.METRIC_L2): 677 d = self.xq.shape[1] 678 metrics = {faiss.METRIC_L2: 'L2', 679 faiss.METRIC_INNER_PRODUCT: 'IP'} 680 681 flat_index = faiss.IndexFlat(d, metric) 682 flat_index.add(self.xb) 683 Dref, Iref = flat_index.search(self.xq, 1) 684 685 index = faiss.IndexNSGFlat(d, 16, metric) 686 index.verbose = True 687 688 index.build(self.xb, knn_graph) 689 Dnsg, Insg = index.search(self.xq, 1) 690 691 recalls = (Iref == Insg).sum() 692 print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) 693 self.assertGreaterEqual(recalls, thresh) 694 self.subtest_connectivity(index, self.xb.shape[0]) 695 696 def test_add_bruteforce_L2(self): 697 self.subtest_add(0, 475, faiss.METRIC_L2) 698 699 def test_add_nndescent_L2(self): 700 self.subtest_add(1, 475, faiss.METRIC_L2) 701 702 def test_add_bruteforce_IP(self): 703 self.subtest_add(0, 480, faiss.METRIC_INNER_PRODUCT) 704 705 def test_add_nndescent_IP(self): 706 self.subtest_add(1, 480, faiss.METRIC_INNER_PRODUCT) 707 708 def test_build_L2(self): 709 knn_graph = self.make_knn_graph(faiss.METRIC_L2) 710 self.subtest_build(knn_graph, 475, faiss.METRIC_L2) 711 712 def test_build_IP(self): 713 knn_graph = self.make_knn_graph(faiss.METRIC_INNER_PRODUCT) 714 self.subtest_build(knn_graph, 480, faiss.METRIC_INNER_PRODUCT) 715 716 def test_build_invalid_knng(self): 717 """Make some invalid entries in the input knn graph. 718 719 It would cause a warning but IndexNSG should be able 720 to handel this. 721 """ 722 knn_graph = self.make_knn_graph(faiss.METRIC_L2) 723 knn_graph[:100, 5] = -111 724 self.subtest_build(knn_graph, 475, faiss.METRIC_L2) 725 726 knn_graph = self.make_knn_graph(faiss.METRIC_INNER_PRODUCT) 727 knn_graph[:100, 5] = -111 728 self.subtest_build(knn_graph, 480, faiss.METRIC_INNER_PRODUCT) 729 730 def test_reset(self): 731 """test IndexNSG.reset()""" 732 d = self.xq.shape[1] 733 metrics = {faiss.METRIC_L2: 'L2', 734 faiss.METRIC_INNER_PRODUCT: 'IP'} 735 736 metric = faiss.METRIC_L2 737 flat_index = faiss.IndexFlat(d, metric) 738 flat_index.add(self.xb) 739 Dref, Iref = flat_index.search(self.xq, 1) 740 741 index = faiss.IndexNSGFlat(d, 16) 742 index.verbose = True 743 index.GK = 32 744 745 index.add(self.xb) 746 Dnsg, Insg = index.search(self.xq, 1) 747 recalls = (Iref == Insg).sum() 748 print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) 749 self.assertGreaterEqual(recalls, 475) 750 self.subtest_connectivity(index, self.xb.shape[0]) 751 752 index.reset() 753 index.add(self.xb) 754 Dnsg, Insg = index.search(self.xq, 1) 755 recalls = (Iref == Insg).sum() 756 print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) 757 self.assertGreaterEqual(recalls, 475) 758 self.subtest_connectivity(index, self.xb.shape[0]) 759 760 761class TestDistancesPositive(unittest.TestCase): 762 763 def test_l2_pos(self): 764 """ 765 roundoff errors occur only with the L2 decomposition used 766 with BLAS, ie. in IndexFlatL2 and with 767 n > distance_compute_blas_threshold = 20 768 """ 769 770 d = 128 771 n = 100 772 773 rs = np.random.RandomState(1234) 774 x = rs.rand(n, d).astype('float32') 775 776 index = faiss.IndexFlatL2(d) 777 index.add(x) 778 779 D, I = index.search(x, 10) 780 781 assert np.all(D >= 0) 782 783 784class TestShardReplicas(unittest.TestCase): 785 def test_shard_flag_propagation(self): 786 d = 64 # dimension 787 nb = 1000 788 rs = np.random.RandomState(1234) 789 xb = rs.rand(nb, d).astype('float32') 790 nlist = 10 791 quantizer1 = faiss.IndexFlatL2(d) 792 quantizer2 = faiss.IndexFlatL2(d) 793 index1 = faiss.IndexIVFFlat(quantizer1, d, nlist) 794 index2 = faiss.IndexIVFFlat(quantizer2, d, nlist) 795 796 index = faiss.IndexShards(d, True) 797 index.add_shard(index1) 798 index.add_shard(index2) 799 800 self.assertFalse(index.is_trained) 801 index.train(xb) 802 self.assertTrue(index.is_trained) 803 804 self.assertEqual(index.ntotal, 0) 805 index.add(xb) 806 self.assertEqual(index.ntotal, nb) 807 808 index.remove_shard(index2) 809 self.assertEqual(index.ntotal, nb / 2) 810 index.remove_shard(index1) 811 self.assertEqual(index.ntotal, 0) 812 813 def test_replica_flag_propagation(self): 814 d = 64 # dimension 815 nb = 1000 816 rs = np.random.RandomState(1234) 817 xb = rs.rand(nb, d).astype('float32') 818 nlist = 10 819 quantizer1 = faiss.IndexFlatL2(d) 820 quantizer2 = faiss.IndexFlatL2(d) 821 index1 = faiss.IndexIVFFlat(quantizer1, d, nlist) 822 index2 = faiss.IndexIVFFlat(quantizer2, d, nlist) 823 824 index = faiss.IndexReplicas(d, True) 825 index.add_replica(index1) 826 index.add_replica(index2) 827 828 self.assertFalse(index.is_trained) 829 index.train(xb) 830 self.assertTrue(index.is_trained) 831 832 self.assertEqual(index.ntotal, 0) 833 index.add(xb) 834 self.assertEqual(index.ntotal, nb) 835 836 index.remove_replica(index2) 837 self.assertEqual(index.ntotal, nb) 838 index.remove_replica(index1) 839 self.assertEqual(index.ntotal, 0) 840 841class TestReconsException(unittest.TestCase): 842 843 def test_recons_exception(self): 844 845 d = 64 # dimension 846 nb = 1000 847 rs = np.random.RandomState(1234) 848 xb = rs.rand(nb, d).astype('float32') 849 nlist = 10 850 quantizer = faiss.IndexFlatL2(d) # the other index 851 index = faiss.IndexIVFFlat(quantizer, d, nlist) 852 index.train(xb) 853 index.add(xb) 854 index.make_direct_map() 855 856 index.reconstruct(9) 857 858 self.assertRaises( 859 RuntimeError, 860 index.reconstruct, 100001 861 ) 862 863 def test_reconstuct_after_add(self): 864 index = faiss.index_factory(10, 'IVF5,SQfp16') 865 index.train(faiss.randn((100, 10), 123)) 866 index.add(faiss.randn((100, 10), 345)) 867 index.make_direct_map() 868 index.add(faiss.randn((100, 10), 678)) 869 870 # should not raise an exception 871 index.reconstruct(5) 872 print(index.ntotal) 873 index.reconstruct(150) 874 875 876class TestReconsHash(unittest.TestCase): 877 878 def do_test(self, index_key): 879 d = 32 880 index = faiss.index_factory(d, index_key) 881 index.train(faiss.randn((100, d), 123)) 882 883 # reference reconstruction 884 index.add(faiss.randn((100, d), 345)) 885 index.add(faiss.randn((100, d), 678)) 886 ref_recons = index.reconstruct_n(0, 200) 887 888 # with lookup 889 index.reset() 890 rs = np.random.RandomState(123) 891 ids = rs.choice(10000, size=200, replace=False).astype(np.int64) 892 index.add_with_ids(faiss.randn((100, d), 345), ids[:100]) 893 index.set_direct_map_type(faiss.DirectMap.Hashtable) 894 index.add_with_ids(faiss.randn((100, d), 678), ids[100:]) 895 896 # compare 897 for i in range(0, 200, 13): 898 recons = index.reconstruct(int(ids[i])) 899 self.assertTrue(np.all(recons == ref_recons[i])) 900 901 # test I/O 902 buf = faiss.serialize_index(index) 903 index2 = faiss.deserialize_index(buf) 904 905 # compare 906 for i in range(0, 200, 13): 907 recons = index2.reconstruct(int(ids[i])) 908 self.assertTrue(np.all(recons == ref_recons[i])) 909 910 # remove 911 toremove = np.ascontiguousarray(ids[0:200:3]) 912 913 sel = faiss.IDSelectorArray(50, faiss.swig_ptr(toremove[:50])) 914 915 # test both ways of removing elements 916 nremove = index2.remove_ids(sel) 917 nremove += index2.remove_ids(toremove[50:]) 918 919 self.assertEqual(nremove, len(toremove)) 920 921 for i in range(0, 200, 13): 922 if i % 3 == 0: 923 self.assertRaises( 924 RuntimeError, 925 index2.reconstruct, int(ids[i]) 926 ) 927 else: 928 recons = index2.reconstruct(int(ids[i])) 929 self.assertTrue(np.all(recons == ref_recons[i])) 930 931 # index error should raise 932 self.assertRaises( 933 RuntimeError, 934 index.reconstruct, 20000 935 ) 936 937 def test_IVFFlat(self): 938 self.do_test("IVF5,Flat") 939 940 def test_IVFSQ(self): 941 self.do_test("IVF5,SQfp16") 942 943 def test_IVFPQ(self): 944 self.do_test("IVF5,PQ4x4np") 945 946if __name__ == '__main__': 947 unittest.main() 948 949 950class TestValidIndexParams(unittest.TestCase): 951 952 def test_IndexIVFPQ(self): 953 d = 32 954 nb = 1000 955 nt = 1500 956 nq = 200 957 958 (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) 959 960 coarse_quantizer = faiss.IndexFlatL2(d) 961 index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8) 962 index.cp.min_points_per_centroid = 5 # quiet warning 963 index.train(xt) 964 index.add(xb) 965 966 # invalid nprobe 967 index.nprobe = 0 968 k = 10 969 self.assertRaises(RuntimeError, index.search, xq, k) 970 971 # invalid k 972 index.nprobe = 4 973 k = -10 974 self.assertRaises(AssertionError, index.search, xq, k) 975 976 # valid params 977 index.nprobe = 4 978 k = 10 979 D, nns = index.search(xq, k) 980 981 self.assertEqual(D.shape[0], nq) 982 self.assertEqual(D.shape[1], k) 983 984 def test_IndexFlat(self): 985 d = 32 986 nb = 1000 987 nt = 0 988 nq = 200 989 990 (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) 991 index = faiss.IndexFlat(d, faiss.METRIC_L2) 992 993 index.add(xb) 994 995 # invalid k 996 k = -5 997 self.assertRaises(AssertionError, index.search, xq, k) 998 999 # valid k 1000 k = 5 1001 D, I = index.search(xq, k) 1002 1003 self.assertEqual(D.shape[0], nq) 1004 self.assertEqual(D.shape[1], k) 1005 1006 1007class TestLargeRangeSearch(unittest.TestCase): 1008 1009 def test_range_search(self): 1010 # test for https://github.com/facebookresearch/faiss/issues/1889 1011 d = 256 1012 nq = 16 1013 nb = 1000000 1014 1015 # faiss.cvar.distance_compute_blas_threshold = 10 1016 faiss.omp_set_num_threads(1) 1017 1018 index = faiss.IndexFlatL2(d) 1019 xb = np.zeros((nb, d), dtype="float32") 1020 index.add(xb) 1021 1022 xq = np.zeros((nq, d), dtype="float32") 1023 lims, D, I = index.range_search(xq, 1.0) 1024 1025 assert len(D) == len(xb) * len(xq) 1026