1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
"""Basic tests for simple index types (flat, IVF, scalar quantizer, HNSW, NSG, ...)."""
7from __future__ import absolute_import, division, print_function
8# no unicode_literals because it messes up in py2
9
10import numpy as np
11import unittest
12import faiss
13import tempfile
14import os
15import re
16import warnings
17
18from common_faiss_tests import get_dataset, get_dataset_2
19
class TestModuleInterface(unittest.TestCase):
    """Sanity checks on attributes exposed by the faiss Python module."""

    def test_version_attribute(self):
        # unittest assertions (not bare ``assert``) so the checks are not
        # stripped when Python runs with -O
        self.assertTrue(hasattr(faiss, '__version__'))
        # the version string must look like MAJOR.MINOR.PATCH
        self.assertRegex(faiss.__version__, '^\\d+\\.\\d+\\.\\d+$')
25
class TestIndexFlat(unittest.TestCase):
    """Compare IndexFlat k-NN and range search with brute-force numpy."""

    def do_test(self, nq, metric_type=faiss.METRIC_L2, k=10):
        """Check IndexFlat search() and range_search() against numpy.

        nq: number of queries. The test names suggest nq=200 exercises the
            BLAS code path and nq=10 does not.
        metric_type: faiss.METRIC_L2 or faiss.METRIC_INNER_PRODUCT.
        k: number of neighbors requested from search().
        """
        d = 32
        nb = 1000
        nt = 0

        (_, xb, xq) = get_dataset_2(d, nt, nb, nq)
        index = faiss.IndexFlat(d, metric_type)

        ### k-NN search

        index.add(xb)
        D1, I1 = index.search(xq, k)

        if metric_type == faiss.METRIC_L2:
            all_dis = ((xq.reshape(nq, 1, d) - xb.reshape(1, nb, d)) ** 2).sum(2)
            Iref = all_dis.argsort(axis=1)[:, :k]
        else:
            all_dis = np.dot(xq, xb.T)
            # inner product: larger is better, so reverse the ordering
            Iref = all_dis.argsort(axis=1)[:, ::-1][:, :k]

        Dref = all_dis[np.arange(nq)[:, None], Iref]
        # tolerate a tiny fraction of swaps between near-tied neighbors
        self.assertLessEqual((Iref != I1).sum(), Iref.size * 0.0001)
        np.testing.assert_almost_equal(Dref, D1, decimal=5)

        ### Range search

        # median distance of the k-th neighbor as radius
        radius = float(np.median(Dref[:, -1]))

        lims, D2, I2 = index.range_search(xq, radius)

        for i in range(nq):
            l0, l1 = lims[i:i + 2]
            Il, Dl = I2[l0:l1], D2[l0:l1]
            if metric_type == faiss.METRIC_L2:
                Ilref, = np.where(all_dis[i] < radius)
            else:
                Ilref, = np.where(all_dis[i] > radius)
            # Do not rely on range_search returning ids in any particular
            # order: sort ids and their distances together (np.where already
            # returns Ilref in ascending order). Fancy indexing copies, so
            # I2 / D2 are not mutated.
            order = np.argsort(Il, kind='stable')
            Il, Dl = Il[order], Dl[order]
            np.testing.assert_equal(Il, Ilref)
            np.testing.assert_almost_equal(all_dis[i, Ilref], Dl, decimal=5)

    def set_blas_blocks(self, small):
        """Set the BLAS query / database block sizes of exhaustive search.

        small=True selects tiny blocks (16 queries x 12 database vectors) so
        block-boundary handling is exercised; small=False sets 4096 / 1024.
        """
        if small:
            faiss.cvar.distance_compute_blas_query_bs = 16
            faiss.cvar.distance_compute_blas_database_bs = 12
        else:
            faiss.cvar.distance_compute_blas_query_bs = 4096
            faiss.cvar.distance_compute_blas_database_bs = 1024

    def _do_test_small_blas_blocks(self, *args, **kwargs):
        """Run do_test with small BLAS blocks, then restore the previous
        block sizes even if the test fails."""
        saved = (faiss.cvar.distance_compute_blas_query_bs,
                 faiss.cvar.distance_compute_blas_database_bs)
        self.set_blas_blocks(small=True)
        try:
            self.do_test(*args, **kwargs)
        finally:
            (faiss.cvar.distance_compute_blas_query_bs,
             faiss.cvar.distance_compute_blas_database_bs) = saved

    def test_with_blas(self):
        self._do_test_small_blas_blocks(200)

    def test_noblas(self):
        self.do_test(10)

    def test_with_blas_ip(self):
        self._do_test_small_blas_blocks(200, faiss.METRIC_INNER_PRODUCT)

    def test_noblas_ip(self):
        self.do_test(10, faiss.METRIC_INNER_PRODUCT)

    def test_noblas_reservoir(self):
        self.do_test(10, k=150)

    def test_with_blas_reservoir(self):
        self.do_test(200, k=150)

    def test_noblas_reservoir_ip(self):
        self.do_test(10, faiss.METRIC_INNER_PRODUCT, k=150)

    def test_with_blas_reservoir_ip(self):
        self.do_test(200, faiss.METRIC_INNER_PRODUCT, k=150)
109
110
111
112
113
class EvalIVFPQAccuracy(unittest.TestCase):
    """Accuracy of IVFPQ with flat and multi-index coarse quantizers."""

    def _get_data(self):
        """Return (d, nb, xt, xb, xq, gt_nns).

        gt_nns is the exact L2 nearest neighbor of each query in xb,
        shape (nq, 1).
        """
        d = 32
        nb = 1000
        nt = 1500
        nq = 200

        (xt, xb, xq) = get_dataset_2(d, nt, nb, nq)

        gt_index = faiss.IndexFlatL2(d)
        gt_index.add(xb)
        _, gt_nns = gt_index.search(xq, 1)
        return d, nb, xt, xb, xq, gt_nns

    def test_IndexIVFPQ(self):
        d, nb, xt, xb, xq, gt_nns = self._get_data()

        coarse_quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.train(xt)
        index.add(xb)
        index.nprobe = 4
        _, nns = index.search(xq, 10)
        # gt_nns broadcasts over the 10 columns: counts how often the exact
        # 1-NN shows up in the returned top-10 lists
        n_ok = (nns == gt_nns).sum()

        self.assertGreater(n_ok, xq.shape[0] * 0.66)

        # Check that an Index2Layer gives the same reconstruction.
        # This is a bit fragile: it assumes 2 runs of training give
        # the exact same result. (An alternative is to copy index.pq into
        # index2 and set index2.is_trained = True instead of training.)
        index2 = faiss.Index2Layer(coarse_quantizer, 32, 8)
        index2.train(xt)
        index2.add(xb)
        ref_recons = index.reconstruct_n(0, nb)
        new_recons = index2.reconstruct_n(0, nb)
        np.testing.assert_array_equal(ref_recons, new_recons)

    def test_IMI(self):
        d, _, xt, xb, xq, gt_nns = self._get_data()

        nbits = 5
        coarse_quantizer = faiss.MultiIndexQuantizer(d, 2, nbits)
        index = faiss.IndexIVFPQ(coarse_quantizer, d, (1 << nbits) ** 2, 8, 8)
        index.quantizer_trains_alone = 1
        index.train(xt)
        index.add(xb)
        index.nprobe = 100
        _, nns = index.search(xq, 10)
        n_ok = (nns == gt_nns).sum()

        # Should return 166 on mac, and 170 on linux.
        self.assertGreater(n_ok, 165)

        ############# replace with explicit assignment indexes
        # one flat index per sub-quantizer, filled with that sub-quantizer's
        # centroids taken from the trained multi-index PQ
        pq = coarse_quantizer.pq
        centroids = faiss.vector_to_array(pq.centroids)
        centroids = centroids.reshape(pq.M, pq.ksub, pq.dsub)
        ai0 = faiss.IndexFlatL2(pq.dsub)
        ai0.add(centroids[0])
        ai1 = faiss.IndexFlatL2(pq.dsub)
        ai1.add(centroids[1])

        coarse_quantizer_2 = faiss.MultiIndexQuantizer2(d, nbits, ai0, ai1)
        coarse_quantizer_2.pq = pq
        coarse_quantizer_2.is_trained = True

        index.quantizer = coarse_quantizer_2

        index.reset()
        index.add(xb)

        _, nns = index.search(xq, 10)
        n_ok = (nns == gt_nns).sum()

        # should return the same result
        self.assertGreater(n_ok, 165)

    def test_IMI_2(self):
        d, _, xt, xb, xq, gt_nns = self._get_data()

        ############# redo including training
        nbits = 5
        ai0 = faiss.IndexFlatL2(d // 2)
        ai1 = faiss.IndexFlatL2(d // 2)

        coarse_quantizer = faiss.MultiIndexQuantizer2(d, nbits, ai0, ai1)
        index = faiss.IndexIVFPQ(coarse_quantizer, d, (1 << nbits) ** 2, 8, 8)
        index.quantizer_trains_alone = 1
        index.train(xt)
        index.add(xb)
        index.nprobe = 100
        _, nns = index.search(xq, 10)
        n_ok = (nns == gt_nns).sum()

        # should return the same result
        self.assertGreater(n_ok, 165)
236
237
238
239
240
class TestMultiIndexQuantizer(unittest.TestCase):
    """Tests for faiss.MultiIndexQuantizer."""

    def test_search_k1(self):
        # k == 1 and k > 1 take different search code paths; the best
        # result (id and distance) must be identical in both
        d, nb, nt, nq = 64, 0, 1500, 200
        xt, _, xq = get_dataset(d, nb, nt, nq)

        quantizer = faiss.MultiIndexQuantizer(d, 2, 6)
        quantizer.train(xt)

        dist_k1, ids_k1 = quantizer.search(xq, 1)
        dist_k5, ids_k5 = quantizer.search(xq, 5)

        self.assertEqual(np.abs(ids_k1[:, :1] - ids_k5[:, :1]).max(), 0)
        self.assertEqual(np.abs(dist_k1[:, :1] - dist_k5[:, :1]).max(), 0)
264
265
class TestScalarQuantizer(unittest.TestCase):
    """Relative top-1 accuracy of scalar-quantizer variants (flat and IVF)."""

    def test_4variants_ivf(self):
        d, nt, nq, nb = 32, 2500, 400, 5000
        xt, xb, xq = get_dataset_2(d, nt, nb, nq)

        # coarse quantizer shared by every IVF index below
        quantizer = faiss.IndexFlatL2(d)
        ncent = 64

        index_gt = faiss.IndexFlatL2(d)
        index_gt.add(xb)
        _, I_ref = index_gt.search(xq, 10)

        def count_top1_hits(index):
            """Train and fill `index` (nprobe=4), count exact top-1 matches."""
            index.nprobe = 4
            index.train(xt)
            index.add(xb)
            _, I = index.search(xq, 10)
            return (I[:, 0] == I_ref[:, 0]).sum()

        ivf_flat = faiss.IndexIVFFlat(quantizer, d, ncent, faiss.METRIC_L2)
        ivf_flat.cp.min_points_per_centroid = 5    # quiet warning
        nok = {'flat': count_top1_hits(ivf_flat)}

        for qname in ("QT_4bit", "QT_4bit_uniform", "QT_8bit",
                      "QT_8bit_uniform", "QT_fp16"):
            qtype = getattr(faiss.ScalarQuantizer, qname)
            nok[qname] = count_top1_hits(faiss.IndexIVFScalarQuantizer(
                quantizer, d, ncent, qtype, faiss.METRIC_L2))
        print(nok, nq)

        self.assertGreaterEqual(nok['flat'], nq * 0.6)
        # The orderings below are a bit fragile: uniform and non-uniform
        # variants sometimes swap, probably because the small dataset adds
        # jitter.
        self.assertGreaterEqual(nok['flat'], nok['QT_8bit'])
        self.assertGreaterEqual(nok['QT_8bit'], nok['QT_4bit'])
        self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform'])
        self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform'])
        self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit'])

    def test_4variants(self):
        d, nt, nq, nb = 32, 2500, 400, 5000
        xt, xb, xq = get_dataset(d, nb, nt, nq)

        index_gt = faiss.IndexFlatL2(d)
        index_gt.add(xb)
        _, I_ref = index_gt.search(xq, 10)

        nok = {}
        for qname in ("QT_4bit", "QT_4bit_uniform", "QT_8bit",
                      "QT_8bit_uniform", "QT_fp16"):
            index = faiss.IndexScalarQuantizer(
                d, getattr(faiss.ScalarQuantizer, qname), faiss.METRIC_L2)
            index.train(xt)
            index.add(xb)
            _, I = index.search(xq, 10)
            nok[qname] = (I[:, 0] == I_ref[:, 0]).sum()

        print(nok, nq)

        self.assertGreaterEqual(nok['QT_8bit'], nq * 0.9)
        self.assertGreaterEqual(nok['QT_8bit'], nok['QT_4bit'])
        self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform'])
        self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform'])
        self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit'])
349
350
class TestRangeSearch(unittest.TestCase):
    """Cross-check IndexFlatL2.range_search against k-NN search."""

    def test_range_search(self):
        d, nt, nq, nb = 4, 100, 10, 50
        _, xb, xq = get_dataset(d, nb, nt, nq)

        index = faiss.IndexFlatL2(d)
        index.add(xb)

        Dref, Iref = index.search(xq, 5)

        thresh = 0.1   # compared against *squared* L2 distances
        lims, D, I = index.range_search(xq, thresh)

        for q in range(nq):
            begin, end = lims[q], lims[q + 1]
            ids_in_range = I[begin:end]
            dis_in_range = D[begin:end]
            # every k-NN result closer than thresh must appear exactly once
            # in the range results, with (almost) the same distance
            for nn_id, nn_dis in zip(Iref[q], Dref[q]):
                if not (nn_dis < thresh):
                    continue
                positions, = np.where(ids_in_range == nn_id)
                self.assertTrue(positions.size == 1)
                self.assertGreaterEqual(
                    1e-4, abs(dis_in_range[positions[0]] - nn_dis))
378
379
class TestSearchAndReconstruct(unittest.TestCase):
    """search_and_reconstruct must agree with search + reconstruct_n."""

    def run_search_and_reconstruct(self, index, xb, xq, k=10, eps=None):
        """Compare index.search_and_reconstruct with separate calls.

        index: an index that already contains xb.
        xb: database vectors, shape (n, d).
        xq: query vectors, shape (nq, d).
        k: number of neighbors.
        eps: if not None, upper bound on the mean per-vector L2 error
            between the returned reconstructions and the original vectors.

        Returns the (D, I, R) triple from search_and_reconstruct.
        """
        n, d = xb.shape
        assert xq.shape[1] == d
        assert index.d == d

        D_ref, I_ref = index.search(xq, k)
        R_ref = index.reconstruct_n(0, n)
        D, I, R = index.search_and_reconstruct(xq, k)

        # results must match a plain search
        np.testing.assert_almost_equal(D, D_ref, decimal=5)
        self.assertTrue((I == I_ref).all())
        self.assertEqual(R.shape[:2], I.shape)
        self.assertEqual(R.shape[2], d)

        # flatten (nq, k, ...) -> (nq * k, ...) and drop the -1 entries
        # returned when fewer than k results were found
        ids = I.reshape(-1)
        recons = R.reshape(-1, d)
        valid = ids >= 0
        recons = recons[valid]
        ids = ids[valid]

        # norm over the whole difference matrix (np.mean of a scalar)
        recons_ref_err = np.mean(np.linalg.norm(recons - R_ref[ids]))
        self.assertLessEqual(recons_ref_err, 1e-6)

        per_vector_err = np.sqrt(((recons - xb[ids]) ** 2).sum(axis=1))
        recons_err = np.mean(per_vector_err)

        print('Reconstruction error = %.3f' % recons_err)
        if eps is not None:
            self.assertLessEqual(recons_err, eps)

        return D, I, R

    def _dataset(self):
        """Return (d, xt, xb, xq) shared by the tests of this class."""
        d = 32
        nb, nt, nq = 1000, 1500, 200
        xt, xb, xq = get_dataset(d, nb, nt, nq)
        return d, xt, xb, xq

    def test_IndexFlat(self):
        d, _, xb, xq = self._dataset()

        index = faiss.IndexFlatL2(d)
        index.add(xb)

        self.run_search_and_reconstruct(index, xb, xq, eps=0.0)

    def test_IndexIVFFlat(self):
        d, xt, xb, xq = self._dataset()

        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFFlat(quantizer, d, 32, faiss.METRIC_L2)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(xt)
        index.add(xb)

        self.run_search_and_reconstruct(index, xb, xq, eps=0.0)

    def test_IndexIVFPQ(self):
        d, xt, xb, xq = self._dataset()

        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFPQ(quantizer, d, 32, 8, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(xt)
        index.add(xb)

        self.run_search_and_reconstruct(index, xb, xq, eps=1.0)

    def test_MultiIndex(self):
        d, xt, xb, xq = self._dataset()

        index = faiss.index_factory(d, "IMI2x5,PQ8np")
        faiss.ParameterSpace().set_index_parameter(index, "nprobe", 4)
        index.train(xt)
        index.add(xb)

        self.run_search_and_reconstruct(index, xb, xq, eps=1.0)

    def test_IndexTransform(self):
        d, xt, xb, xq = self._dataset()

        index = faiss.index_factory(d, "L2norm,PCA8,IVF32,PQ8np")
        faiss.ParameterSpace().set_index_parameter(index, "nprobe", 4)
        index.train(xt)
        index.add(xb)

        self.run_search_and_reconstruct(index, xb, xq)
493
494
class TestHNSW(unittest.TestCase):
    """Recall, I/O and clone tests for HNSW indexes."""

    @classmethod
    def setUpClass(cls):
        # Shared data and exact L2 1-NN ground truth, computed once per
        # class instead of in __init__ (which unittest runs once per test
        # method, and already at test-loading time).
        d = 32
        nt = 0
        nb = 1500
        nq = 500

        (_, cls.xb, cls.xq) = get_dataset_2(d, nt, nb, nq)
        index = faiss.IndexFlatL2(d)
        index.add(cls.xb)
        _, cls.Iref = index.search(cls.xq, 1)

    def test_hnsw(self):
        d = self.xq.shape[1]

        index = faiss.IndexHNSWFlat(d, 16)
        index.add(self.xb)
        Dhnsw, Ihnsw = index.search(self.xq, 1)

        self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 460)

        self.io_and_retest(index, Dhnsw, Ihnsw)

    def test_hnsw_unbounded_queue(self):
        d = self.xq.shape[1]

        index = faiss.IndexHNSWFlat(d, 16)
        index.add(self.xb)
        index.search_bounded_queue = False
        Dhnsw, Ihnsw = index.search(self.xq, 1)

        self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 460)

        self.io_and_retest(index, Dhnsw, Ihnsw)

    def io_and_retest(self, index, Dhnsw, Ihnsw):
        """Check that a write/read round trip and clone_index reproduce the
        search results (Dhnsw, Ihnsw) exactly."""
        fd, tmpfile = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index(index, tmpfile)
            index2 = faiss.read_index(tmpfile)
        finally:
            if os.path.exists(tmpfile):
                os.unlink(tmpfile)

        Dhnsw2, Ihnsw2 = index2.search(self.xq, 1)

        self.assertTrue(np.all(Dhnsw2 == Dhnsw))
        self.assertTrue(np.all(Ihnsw2 == Ihnsw))

        # also test clone
        index3 = faiss.clone_index(index)
        Dhnsw3, Ihnsw3 = index3.search(self.xq, 1)

        self.assertTrue(np.all(Dhnsw3 == Dhnsw))
        self.assertTrue(np.all(Ihnsw3 == Ihnsw))

    def test_hnsw_2level(self):
        d = self.xq.shape[1]

        quant = faiss.IndexFlatL2(d)

        index = faiss.IndexHNSW2Level(quant, 256, 8, 8)
        index.train(self.xb)
        index.add(self.xb)
        Dhnsw, Ihnsw = index.search(self.xq, 1)

        self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 310)

        self.io_and_retest(index, Dhnsw, Ihnsw)

    def test_add_0_vecs(self):
        index = faiss.IndexHNSWFlat(10, 16)
        zero_vecs = np.zeros((0, 10), dtype='float32')
        # must return; presumably a regression test for an infinite loop
        # when adding 0 vectors
        index.add(zero_vecs)

    def test_hnsw_IP(self):
        d = self.xq.shape[1]

        index_IP = faiss.IndexFlatIP(d)
        index_IP.add(self.xb)
        Dref, Iref = index_IP.search(self.xq, 1)

        index = faiss.IndexHNSWFlat(d, 16, faiss.METRIC_INNER_PRODUCT)
        index.add(self.xb)
        Dhnsw, Ihnsw = index.search(self.xq, 1)

        print('nb equal: ', (Iref == Ihnsw).sum())

        self.assertGreaterEqual((Iref == Ihnsw).sum(), 480)

        # where the ids agree, the distances must agree too
        mask = Iref[:, 0] == Ihnsw[:, 0]
        self.assertTrue(np.allclose(Dref[mask, 0], Dhnsw[mask, 0]))
593
594
class TestNSG(unittest.TestCase):
    """Recall, connectivity, I/O and clone tests for IndexNSGFlat."""

    # number of neighbors per node in the input k-NN graph
    GK = 32

    # metric names used in the diagnostic prints
    METRIC_NAMES = {faiss.METRIC_L2: 'L2',
                    faiss.METRIC_INNER_PRODUCT: 'IP'}

    @classmethod
    def setUpClass(cls):
        # computed once per class instead of in __init__, which unittest
        # runs once per test method
        d = 32
        nt = 0
        nb = 1500
        nq = 500

        _, cls.xb, cls.xq = get_dataset_2(d, nt, nb, nq)

    def make_knn_graph(self, metric):
        """Return an exact (nb, GK) int64 k-NN graph of self.xb that
        excludes each node itself."""
        n, d = self.xb.shape
        index = faiss.IndexFlat(d, metric)
        index.add(self.xb)
        _, I = index.search(self.xb, self.GK + 1)
        knn_graph = np.zeros((n, self.GK), dtype=np.int64)

        # For the inner product distance, a vector is not necessarily its
        # own nearest neighbor, so I[i, 0] is not guaranteed to be i: drop i
        # wherever it appears and keep the first GK remaining neighbors.
        for i in range(n):
            others = [j for j in I[i] if j != i][:self.GK]
            knn_graph[i, :len(others)] = others
        return knn_graph

    def _ground_truth(self, metric):
        """Exact 1-NN ids of self.xq in self.xb under `metric`."""
        flat_index = faiss.IndexFlat(self.xq.shape[1], metric)
        flat_index.add(self.xb)
        _, Iref = flat_index.search(self.xq, 1)
        return Iref

    def _check_recall(self, Iref, Insg, thresh, metric):
        """Assert that at least `thresh` NSG results match the ground truth."""
        recalls = (Iref == Insg).sum()
        print('metric: {}, nb equal: {}'.format(
            self.METRIC_NAMES[metric], recalls))
        self.assertGreaterEqual(recalls, thresh)

    def subtest_io_and_clone(self, index, Dnsg, Insg):
        """Check that a write/read round trip and clone_index reproduce the
        search results (Dnsg, Insg) exactly."""
        fd, tmpfile = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index(index, tmpfile)
            index2 = faiss.read_index(tmpfile)
        finally:
            if os.path.exists(tmpfile):
                os.unlink(tmpfile)

        Dnsg2, Insg2 = index2.search(self.xq, 1)

        self.assertTrue(np.all(Dnsg2 == Dnsg))
        self.assertTrue(np.all(Insg2 == Insg))

        # also test clone
        index3 = faiss.clone_index(index)
        Dnsg3, Insg3 = index3.search(self.xq, 1)

        self.assertTrue(np.all(Dnsg3 == Dnsg))
        self.assertTrue(np.all(Insg3 == Insg))

    def subtest_connectivity(self, index, nb):
        """A DFS from the entry point must visit all nb nodes."""
        vt = faiss.VisitedTable(nb)
        count = index.nsg.dfs(vt, index.nsg.enterpoint, 0)
        self.assertEqual(count, nb)

    def subtest_add(self, build_type, thresh, metric=faiss.METRIC_L2):
        """Build the graph via add() with the given build_type and check
        recall >= thresh, connectivity, and I/O / clone round trips.

        The test names suggest build_type 0 is brute force and 1 is
        NN-descent.
        """
        d = self.xq.shape[1]
        Iref = self._ground_truth(metric)

        index = faiss.IndexNSGFlat(d, 16, metric)
        index.verbose = True
        index.build_type = build_type
        index.GK = self.GK
        index.add(self.xb)
        Dnsg, Insg = index.search(self.xq, 1)

        self._check_recall(Iref, Insg, thresh, metric)
        self.subtest_connectivity(index, self.xb.shape[0])
        self.subtest_io_and_clone(index, Dnsg, Insg)

    def subtest_build(self, knn_graph, thresh, metric=faiss.METRIC_L2):
        """Build the graph from a precomputed knn_graph and check
        recall >= thresh and connectivity."""
        d = self.xq.shape[1]
        Iref = self._ground_truth(metric)

        index = faiss.IndexNSGFlat(d, 16, metric)
        index.verbose = True

        index.build(self.xb, knn_graph)
        _, Insg = index.search(self.xq, 1)

        self._check_recall(Iref, Insg, thresh, metric)
        self.subtest_connectivity(index, self.xb.shape[0])

    def test_add_bruteforce_L2(self):
        self.subtest_add(0, 475, faiss.METRIC_L2)

    def test_add_nndescent_L2(self):
        self.subtest_add(1, 475, faiss.METRIC_L2)

    def test_add_bruteforce_IP(self):
        self.subtest_add(0, 480, faiss.METRIC_INNER_PRODUCT)

    def test_add_nndescent_IP(self):
        self.subtest_add(1, 480, faiss.METRIC_INNER_PRODUCT)

    def test_build_L2(self):
        knn_graph = self.make_knn_graph(faiss.METRIC_L2)
        self.subtest_build(knn_graph, 475, faiss.METRIC_L2)

    def test_build_IP(self):
        knn_graph = self.make_knn_graph(faiss.METRIC_INNER_PRODUCT)
        self.subtest_build(knn_graph, 480, faiss.METRIC_INNER_PRODUCT)

    def test_build_invalid_knng(self):
        """Make some invalid entries in the input knn graph.

        It causes a warning, but IndexNSG should be able to handle it.
        """
        knn_graph = self.make_knn_graph(faiss.METRIC_L2)
        knn_graph[:100, 5] = -111
        self.subtest_build(knn_graph, 475, faiss.METRIC_L2)

        knn_graph = self.make_knn_graph(faiss.METRIC_INNER_PRODUCT)
        knn_graph[:100, 5] = -111
        self.subtest_build(knn_graph, 480, faiss.METRIC_INNER_PRODUCT)

    def test_reset(self):
        """IndexNSG.reset() followed by add() must rebuild a usable graph."""
        d = self.xq.shape[1]
        metric = faiss.METRIC_L2
        Iref = self._ground_truth(metric)

        index = faiss.IndexNSGFlat(d, 16)
        index.verbose = True
        index.GK = self.GK

        index.add(self.xb)
        _, Insg = index.search(self.xq, 1)
        self._check_recall(Iref, Insg, 475, metric)
        self.subtest_connectivity(index, self.xb.shape[0])

        index.reset()
        index.add(self.xb)
        _, Insg = index.search(self.xq, 1)
        self._check_recall(Iref, Insg, 475, metric)
        self.subtest_connectivity(index, self.xb.shape[0])
759
760
class TestDistancesPositive(unittest.TestCase):
    """L2 distances returned by IndexFlatL2 must never be negative."""

    def test_l2_pos(self):
        """
        roundoff errors occur only with the L2 decomposition used
        with BLAS, ie. in IndexFlatL2 and with
        n > distance_compute_blas_threshold = 20
        """

        d = 128
        n = 100

        rs = np.random.RandomState(1234)
        x = rs.rand(n, d).astype('float32')

        index = faiss.IndexFlatL2(d)
        index.add(x)

        D, _ = index.search(x, 10)

        # unittest assertion so the check survives python -O
        self.assertTrue(np.all(D >= 0))
782
783
class TestShardReplicas(unittest.TestCase):
    """IndexShards / IndexReplicas must reflect their sub-indexes in
    is_trained and ntotal through train, add and removal."""

    def _check_flag_propagation(self, container_class, add_method,
                                remove_method, split_data):
        """Build a container of two IVFFlat sub-indexes and check flags.

        container_class: faiss.IndexShards or faiss.IndexReplicas.
        add_method / remove_method: names of the container's methods that
            add / remove a sub-index.
        split_data: True if the data is spread over the sub-indexes
            (shards), False if each sub-index holds all of it (replicas).
        """
        d = 64                           # dimension
        nb = 1000
        rs = np.random.RandomState(1234)
        xb = rs.rand(nb, d).astype('float32')
        nlist = 10
        quantizer1 = faiss.IndexFlatL2(d)
        quantizer2 = faiss.IndexFlatL2(d)
        index1 = faiss.IndexIVFFlat(quantizer1, d, nlist)
        index2 = faiss.IndexIVFFlat(quantizer2, d, nlist)

        index = container_class(d, True)
        add_child = getattr(index, add_method)
        remove_child = getattr(index, remove_method)
        add_child(index1)
        add_child(index2)

        self.assertFalse(index.is_trained)
        index.train(xb)
        self.assertTrue(index.is_trained)

        self.assertEqual(index.ntotal, 0)
        index.add(xb)
        self.assertEqual(index.ntotal, nb)

        remove_child(index2)
        # integer division: ntotal is an integer count
        self.assertEqual(index.ntotal, nb // 2 if split_data else nb)
        remove_child(index1)
        self.assertEqual(index.ntotal, 0)

    def test_shard_flag_propagation(self):
        self._check_flag_propagation(
            faiss.IndexShards, 'add_shard', 'remove_shard', split_data=True)

    def test_replica_flag_propagation(self):
        self._check_flag_propagation(
            faiss.IndexReplicas, 'add_replica', 'remove_replica',
            split_data=False)
840
class TestReconsException(unittest.TestCase):
    """reconstruct() on IVF indexes that have a direct map."""

    def test_recons_exception(self):
        d = 64                           # dimension
        nb = 1000
        nlist = 10
        xb = np.random.RandomState(1234).rand(nb, d).astype('float32')

        quantizer = faiss.IndexFlatL2(d)  # coarse quantizer
        index = faiss.IndexIVFFlat(quantizer, d, nlist)
        index.train(xb)
        index.add(xb)
        index.make_direct_map()

        # an id that was added reconstructs fine...
        index.reconstruct(9)

        # ...while one far beyond ntotal must raise
        with self.assertRaises(RuntimeError):
            index.reconstruct(100001)

    def test_reconstuct_after_add(self):
        index = faiss.index_factory(10, 'IVF5,SQfp16')
        index.train(faiss.randn((100, 10), 123))
        index.add(faiss.randn((100, 10), 345))
        index.make_direct_map()
        # vectors added after make_direct_map must be reconstructible too
        index.add(faiss.randn((100, 10), 678))

        # neither call should raise
        index.reconstruct(5)
        print(index.ntotal)
        index.reconstruct(150)
874
875
class TestReconsHash(unittest.TestCase):
    """reconstruct() / remove_ids() with a Hashtable direct map."""

    def do_test(self, index_key):
        d = 32
        index = faiss.index_factory(d, index_key)
        index.train(faiss.randn((100, d), 123))

        # two fixed-seed batches, reused for both passes below
        batch_a = faiss.randn((100, d), 345)
        batch_b = faiss.randn((100, d), 678)

        # reference reconstructions, with sequential ids 0..199
        index.add(batch_a)
        index.add(batch_b)
        ref_recons = index.reconstruct_n(0, 200)

        # same vectors with 200 distinct random ids; the hashtable direct
        # map is switched on between the two batches
        index.reset()
        ids = np.random.RandomState(123).choice(
            10000, size=200, replace=False).astype(np.int64)
        index.add_with_ids(batch_a, ids[:100])
        index.set_direct_map_type(faiss.DirectMap.Hashtable)
        index.add_with_ids(batch_b, ids[100:])

        def check_recons(idx, removed_every_3rd=False):
            # probe every 13th vector; after removal, positions i with
            # i % 3 == 0 are gone and must raise
            for i in range(0, 200, 13):
                key = int(ids[i])
                if removed_every_3rd and i % 3 == 0:
                    self.assertRaises(RuntimeError, idx.reconstruct, key)
                else:
                    self.assertTrue(np.all(idx.reconstruct(key) == ref_recons[i]))

        check_recons(index)

        # I/O round trip
        index2 = faiss.deserialize_index(faiss.serialize_index(index))
        check_recons(index2)

        # remove every 3rd id, half through an IDSelectorArray and the rest
        # through a plain id array
        toremove = np.ascontiguousarray(ids[0:200:3])
        sel = faiss.IDSelectorArray(50, faiss.swig_ptr(toremove[:50]))
        nremove = index2.remove_ids(sel)
        nremove += index2.remove_ids(toremove[50:])
        self.assertEqual(nremove, len(toremove))

        check_recons(index2, removed_every_3rd=True)

        # an id that was never added must raise
        self.assertRaises(RuntimeError, index.reconstruct, 20000)

    def test_IVFFlat(self):
        self.do_test("IVF5,Flat")

    def test_IVFSQ(self):
        self.do_test("IVF5,SQfp16")

    def test_IVFPQ(self):
        self.do_test("IVF5,PQ4x4np")
945
class TestValidIndexParams(unittest.TestCase):
    """search() must reject invalid nprobe / k values."""

    def test_IndexIVFPQ(self):
        d = 32
        nb = 1000
        nt = 1500
        nq = 200

        (xt, xb, xq) = get_dataset_2(d, nt, nb, nq)

        coarse_quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.train(xt)
        index.add(xb)

        # invalid nprobe
        index.nprobe = 0
        k = 10
        self.assertRaises(RuntimeError, index.search, xq, k)

        # invalid k (AssertionError, presumably from the Python wrapper)
        index.nprobe = 4
        self.assertRaises(AssertionError, index.search, xq, -10)

        # valid params
        k = 10
        D, _ = index.search(xq, k)

        self.assertEqual(D.shape, (nq, k))

    def test_IndexFlat(self):
        d = 32
        nb = 1000
        nt = 0
        nq = 200

        (_, xb, xq) = get_dataset_2(d, nt, nb, nq)
        index = faiss.IndexFlat(d, faiss.METRIC_L2)

        index.add(xb)

        # invalid k
        self.assertRaises(AssertionError, index.search, xq, -5)

        # valid k
        k = 5
        D, _ = index.search(xq, k)

        self.assertEqual(D.shape, (nq, k))


class TestLargeRangeSearch(unittest.TestCase):

    def test_range_search(self):
        """Range search where every (query, database) pair is a result.

        Test for https://github.com/facebookresearch/faiss/issues/1889.
        All vectors are zero, so all nq * nb distances are 0 < 1.0.
        """
        d = 256
        nq = 16
        nb = 1000000

        # run single-threaded, but restore the previous thread count so
        # later tests are not affected
        saved_nthreads = faiss.omp_get_max_threads()
        faiss.omp_set_num_threads(1)
        try:
            index = faiss.IndexFlatL2(d)
            xb = np.zeros((nb, d), dtype="float32")
            index.add(xb)

            xq = np.zeros((nq, d), dtype="float32")
            _, D, _ = index.range_search(xq, 1.0)
        finally:
            faiss.omp_set_num_threads(saved_nthreads)

        self.assertEqual(len(D), len(xb) * len(xq))


# Must come after every TestCase definition: unittest.main() exits, so any
# class defined below this point would never run.
if __name__ == '__main__':
    unittest.main()