1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6
7import unittest
8import platform
9
10import numpy as np
11import faiss
12
13from faiss.contrib import datasets
14from faiss.contrib.inspect_tools import get_invlist
15
16
class TestLUTQuantization(unittest.TestCase):
    """Compare the output of faiss.quantize_LUT_and_bias with a float
    reference computed in numpy."""

    def compute_dis_float(self, codes, LUT, bias):
        """Accumulate reference distances from the float look-up tables.

        codes is unpacked as (nprobe, nt, M). LUT is either 2-D (a single
        table used for every probe) or 3-D (one table per probe). When bias
        is not None, it is broadcast as one offset per probe.
        Returns a float32 array of shape (nprobe, nt).
        """
        nprobe, nt, M = codes.shape
        acc = np.zeros((nprobe, nt), dtype='float32')
        if bias is not None:
            acc[:] = bias.reshape(-1, 1)

        shared_table = LUT if LUT.ndim == 2 else None
        for p in range(nprobe):
            table = LUT[p] if LUT.ndim == 3 else shared_table
            for i in range(nt):
                # one table entry per sub-quantizer, selected by the code
                picked = table[np.arange(M), codes[p, i]]
                acc[p, i] += picked.sum()

        return acc

    def compute_dis_quant(self, codes, LUT, bias, a, b):
        """Same accumulation as compute_dis_float, but in uint16 on the
        quantized tables; the integer result is then mapped through
        ``dis / a + b`` before being returned.
        """
        nprobe, nt, M = codes.shape
        acc = np.zeros((nprobe, nt), dtype='uint16')
        if bias is not None:
            acc[:] = bias.reshape(-1, 1)

        shared_table = LUT if LUT.ndim == 2 else None
        for p in range(nprobe):
            table = LUT[p] if LUT.ndim == 3 else shared_table
            for i in range(nt):
                picked = table[np.arange(M), codes[p, i]]
                acc[p, i] += picked.astype('uint16').sum()

        return acc / a + b

    def do_test(self, LUT, bias, nprobe, alt_3d=False):
        """Quantize LUT (and bias) with faiss, then check that distances
        rebuilt from the quantized tables stay close to the float ones.

        With alt_3d set, or without a bias, no output buffer is passed for
        the quantized bias (a null pointer is given instead).
        """
        M, ksub = LUT.shape[-2:]
        nt = 200

        rng = np.random.RandomState(123)
        codes = rng.randint(ksub, size=(nprobe, nt, M)).astype('uint8')

        expected = self.compute_dis_float(codes, LUT, bias)

        LUTq = np.zeros(LUT.shape, dtype='uint8')
        if bias is None or alt_3d:
            biasq = None
        else:
            biasq = np.zeros(bias.shape, dtype='uint16')
        a_out = np.zeros(1, dtype='float32')
        b_out = np.zeros(1, dtype='float32')

        def ptr(arr):
            # None is forwarded unchanged so that optional buffers can be
            # omitted
            if arr is None:
                return None
            return faiss.swig_ptr(arr)

        faiss.quantize_LUT_and_bias(
            nprobe, M, ksub, LUT.ndim == 3,
            ptr(LUT), ptr(bias), ptr(LUTq), M, ptr(biasq),
            ptr(a_out), ptr(b_out)
        )
        a, b = a_out[0], b_out[0]
        actual = self.compute_dis_quant(codes, LUTq, biasq, a, b)

        avg_relative_error = np.abs(actual - expected).sum() / expected.sum()
        self.assertLess(avg_relative_error, 0.0005)

    def _random_tables(self, nprobe, per_probe_lut, with_bias):
        """Draw the random float LUT (2-D, or 3-D when per_probe_lut is
        set) and, if requested, a bias vector scaled by 10."""
        ksub = 16
        M = 20
        rs = np.random.RandomState(1234)
        shape = (nprobe, M, ksub) if per_probe_lut else (M, ksub)
        LUT = rs.rand(*shape).astype('float32')
        if not with_bias:
            return LUT, None
        bias = rs.rand(nprobe).astype('float32')
        bias *= 10
        return LUT, bias

    def test_no_residual_ip(self):
        nprobe = 10
        LUT, bias = self._random_tables(
            nprobe, per_probe_lut=False, with_bias=False)
        self.do_test(LUT, bias, nprobe)

    def test_by_residual_ip(self):
        nprobe = 10
        LUT, bias = self._random_tables(
            nprobe, per_probe_lut=False, with_bias=True)
        self.do_test(LUT, bias, nprobe)

    def test_by_residual_L2(self):
        nprobe = 10
        LUT, bias = self._random_tables(
            nprobe, per_probe_lut=True, with_bias=True)
        self.do_test(LUT, bias, nprobe)

    def test_by_residual_L2_v2(self):
        nprobe = 10
        LUT, bias = self._random_tables(
            nprobe, per_probe_lut=True, with_bias=True)
        self.do_test(LUT, bias, nprobe, alt_3d=True)
131
132
133
134
##########################################################
# Helpers shared by the IndexIVFPQFastScan tests below
##########################################################
138
def verify_with_draws(testcase, Dref, Iref, Dnew, Inew):
    """Verify search results whose distances may contain draws (ties).

    The distance matrices must agree to 5 decimals. For every row where the
    label rows differ, labels are compared as sets per distinct reference
    distance, so that a different ordering inside a group of tied results
    is accepted.

    testcase: object providing ``assertEqual`` (a unittest.TestCase).
    Dref, Iref: reference distances and labels, indexed as [row, rank].
    Dnew, Inew: distances and labels under test, same layout.
    Raises AssertionError on any mismatch.
    """
    np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
    # here we have to be careful because of draws
    for i, ref_ids in enumerate(Iref):
        new_ids = Inew[i]
        if np.all(ref_ids == new_ids):  # easy case
            continue
        ref_dis = Dref[i]
        # nothing can be deduced about the last distance of the row: the
        # group of tied results may be cut off by the end of the list
        skip_dis = ref_dis[-1]
        # only the distances present in this row can select anything, so
        # there is no need to scan the distinct values of the whole matrix
        for dis in np.unique(ref_dis):
            if dis == skip_dis:
                continue
            mask = ref_dis == dis
            testcase.assertEqual(set(ref_ids[mask]), set(new_ids[mask]))
153
def three_metrics(Dref, Iref, Dnew, Inew):
    """Summarize how well the labels Inew reproduce the labels Iref.

    Returns a 3-tuple, each entry averaged over the rows of Iref:
    - fraction of rows where the first label is the same in both,
    - fraction of rows where the first reference label appears among the
      first 10 new labels,
    - number of labels the two rows have in common.
    Dref and Dnew are accepted but not used.
    """
    n_queries = Iref.shape[0]
    top1_match = Iref[:, 0] == Inew[:, 0]
    # (nq, 1) against (nq, <=10): broadcasts the top reference label over
    # the leading new labels of the same row
    top1_in_top10 = Iref[:, :1] == Inew[:, :10]
    shared = sum(
        len(np.intersect1d(Inew[q], Iref[q])) for q in range(n_queries)
    )
    return (
        top1_match.sum() / n_queries,
        top1_in_top10.sum() / n_queries,
        shared / n_queries,
    )
163
164
165##########################################################
166# Tests for various IndexIVFPQFastScan implementations
167##########################################################
168
class TestIVFImplem1(unittest.TestCase):
    """ Verify implem 1 (search from original invlists)
    against IndexIVFPQ """

    def do_test(self, by_residual, metric_type=faiss.METRIC_L2,
                use_precomputed_table=0):
        """Search with an IndexIVFPQ and with the IndexIVFPQFastScan built
        from it (implem 1); labels must be identical and distances equal
        to 5 decimals.

        by_residual: value assigned to index.by_residual before searching.
        metric_type: metric passed to the index factory.
        use_precomputed_table: value assigned to the attribute of the same
            name before training.
        """
        ds = datasets.SyntheticDataset(32, 2000, 5000, 1000)
        xq = ds.get_queries()

        index = faiss.index_factory(32, "IVF32,PQ16x4np", metric_type)
        index.use_precomputed_table = use_precomputed_table
        index.train(ds.get_train())
        index.add(ds.get_database())
        index.nprobe = 4
        # NOTE(review): by_residual is switched only after train()/add();
        # kept as is since both indexes are compared on the same state
        index.by_residual = by_residual
        Da, Ia = index.search(xq, 10)

        index2 = faiss.IndexIVFPQFastScan(index)
        index2.implem = 1
        Db, Ib = index2.search(xq, 10)

        np.testing.assert_array_equal(Ia, Ib)
        np.testing.assert_almost_equal(Da, Db, decimal=5)

    def test_no_residual(self):
        self.do_test(False)

    def test_by_residual(self):
        self.do_test(True)

    def test_by_residual_no_precomputed(self):
        self.do_test(True, use_precomputed_table=-1)

    def test_no_residual_ip(self):
        self.do_test(False, faiss.METRIC_INNER_PRODUCT)

    def test_by_residual_ip(self):
        self.do_test(True, faiss.METRIC_INNER_PRODUCT)
207
208
209
class TestIVFImplem2(unittest.TestCase):
    """ Compare implem 2 (original invlists, uint8 look-up tables) with
    IndexIVFPQ. The quantized tables cost some accuracy, so the check is
    on recall-style metrics rather than on exact equality. """

    def eval_quant_loss(self, by_residual, metric=faiss.METRIC_L2):
        """Measure three_metrics between IndexIVFPQ and its implem-2
        fast-scan counterpart and require each value to reach 99.5% of the
        recorded reference for this (by_residual, metric) combination."""
        dataset = datasets.SyntheticDataset(32, 2000, 5000, 1000)

        baseline = faiss.index_factory(32, "IVF32,PQ16x4np", metric)
        baseline.train(dataset.get_train())
        baseline.add(dataset.get_database())
        baseline.nprobe = 4
        # NOTE(review): by_residual is set after train()/add(); the
        # reference values below were presumably recorded with this
        # ordering -- confirm before moving the assignment
        baseline.by_residual = by_residual
        Da, Ia = baseline.search(dataset.get_queries(), 10)

        # loss due to int8 quantization of LUTs
        fast_scan = faiss.IndexIVFPQFastScan(baseline)
        fast_scan.implem = 2
        Db, Ib = fast_scan.search(dataset.get_queries(), 10)

        measured = three_metrics(Da, Ia, Db, Ib)

        # keyed by (by_residual, metric); each value lists
        # recall@1, recall@10, intersection
        expected_by_config = {
            (True, 1): [0.985, 1.0, 9.872],
            (True, 0): [0.987, 1.0, 9.914],
            (False, 1): [0.991, 1.0, 9.907],
            (False, 0): [0.986, 1.0, 9.917],
        }
        expected = expected_by_config[(by_residual, metric)]

        for value, reference in zip(measured, expected):
            self.assertGreaterEqual(value, reference * 0.995)

    def test_qloss_no_residual(self):
        self.eval_quant_loss(False)

    def test_qloss_by_residual(self):
        self.eval_quant_loss(True)

    def test_qloss_no_residual_ip(self):
        self.eval_quant_loss(False, faiss.METRIC_INNER_PRODUCT)

    def test_qloss_by_residual_ip(self):
        self.eval_quant_loss(True, faiss.METRIC_INNER_PRODUCT)
258
class TestEquivPQ(unittest.TestCase):
    """A one-list IVF-PQ index without residuals must return the same
    results as a flat PQ index holding the same codes, both for the plain
    indexes and for their fast-scan versions (implem 12)."""

    def test_equiv_pq(self):
        k = 4
        dataset = datasets.SyntheticDataset(32, 2000, 200, 4)

        ivf = faiss.index_factory(32, "IVF1,PQ16x4np")
        ivf.by_residual = False
        # force coarse quantizer: its single centroid is the zero vector
        ivf.quantizer.add(np.zeros((1, 32), dtype='float32'))
        ivf.train(dataset.get_train())
        ivf.add(dataset.get_database())
        D_ivf, I_ivf = ivf.search(dataset.get_queries(), k)

        # flat PQ index assembled by hand from the IVF index: same product
        # quantizer, codes taken from inverted list 0
        flat = faiss.index_factory(32, "PQ16x4np")
        flat.pq = ivf.pq
        flat.is_trained = True
        inverted_lists = faiss.downcast_InvertedLists(ivf.invlists)
        flat.codes = inverted_lists.codes.at(0)
        flat.ntotal = ivf.ntotal
        D_flat, I_flat = flat.search(dataset.get_queries(), k)

        np.testing.assert_array_equal(I_ivf, I_flat)
        np.testing.assert_array_equal(D_ivf, D_flat)

        # same comparison between the two fast-scan variants
        flat_fs = faiss.IndexPQFastScan(flat)
        flat_fs.implem = 12
        D_flat_fs, I_flat_fs = flat_fs.search(dataset.get_queries(), k)

        ivf_fs = faiss.IndexIVFPQFastScan(ivf)
        ivf_fs.implem = 12
        D_ivf_fs, I_ivf_fs = ivf_fs.search(dataset.get_queries(), k)

        np.testing.assert_array_equal(I_flat_fs, I_ivf_fs)
        np.testing.assert_array_equal(D_flat_fs, D_ivf_fs)
292
293
class TestIVFImplem12(unittest.TestCase):
    """Check the fast-scan implementation selected by IMPLEM against
    implem 2 on the same index. Subclasses override IMPLEM."""

    IMPLEM = 12

    def do_test(self, by_residual, metric=faiss.METRIC_L2, d=32):
        """Build an IVF-PQ index, convert it twice to IndexIVFPQFastScan
        (implem 2 as reference, self.IMPLEM under test) and compare the
        results for k=4 and for k=1."""
        dataset = datasets.SyntheticDataset(d, 2000, 5000, 200)

        base = faiss.index_factory(d, f"IVF32,PQ{d//2}x4np", metric)
        base.by_residual = by_residual
        base.train(dataset.get_train())
        base.add(dataset.get_database())
        base.nprobe = 4

        reference = faiss.IndexIVFPQFastScan(base)
        reference.implem = 2
        Dref, Iref = reference.search(dataset.get_queries(), 4)

        candidate = faiss.IndexIVFPQFastScan(base)
        candidate.implem = self.IMPLEM
        Dnew, Inew = candidate.search(dataset.get_queries(), 4)

        verify_with_draws(self, Dref, Iref, Dnew, Inew)

        # counters are cleared here so that the final check only sees the
        # single-result search below
        ivf_stats = faiss.cvar.indexIVF_stats
        ivf_stats.reset()

        # also verify with single result
        Dnew, Inew = candidate.search(dataset.get_queries(), 1)
        for q, ref_row in enumerate(Dref):
            tie_at_top = ref_row[1] == ref_row[0]
            if not tie_at_top:
                # with a tie for first place either label would be valid,
                # so only untied rows are compared
                self.assertEqual(Iref[q, 0], Inew[q, 0])
                np.testing.assert_almost_equal(
                    ref_row[0], Dnew[q, 0], decimal=5)

        self.assertGreater(ivf_stats.ndis, 0)

    def test_no_residual(self):
        self.do_test(False)

    def test_by_residual(self):
        self.do_test(True)

    def test_no_residual_ip(self):
        self.do_test(False, metric=faiss.METRIC_INNER_PRODUCT)

    def test_by_residual_ip(self):
        self.do_test(True, metric=faiss.METRIC_INNER_PRODUCT)

    def test_no_residual_odd_dim(self):
        self.do_test(False, d=30)

    def test_by_residual_odd_dim(self):
        self.do_test(True, d=30)
349
350
class TestIVFImplem10(TestIVFImplem12):
    """Re-run the checks inherited from TestIVFImplem12 with implem 10."""
    IMPLEM = 10
353
354
class TestIVFImplem11(TestIVFImplem12):
    """Re-run the checks inherited from TestIVFImplem12 with implem 11."""
    IMPLEM = 11
357
class TestIVFImplem13(TestIVFImplem12):
    """Re-run the checks inherited from TestIVFImplem12 with implem 13."""
    IMPLEM = 13
360
361
class TestAdd(unittest.TestCase):
    """Adding vectors to an IndexIVFPQFastScan must give the same index as
    converting an IndexIVFPQ that already holds all the vectors."""

    def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
        """Compare search results and stored codes of two fast-scan indexes:
        one converted after all vectors were added, one converted early and
        completed with add().

        by_residual, metric, d: settings of the underlying IVF-PQ index.
        bbs: block size passed to the IndexIVFPQFastScan constructors and
            used to address the codes returned by get_invlist. It is
            honoured as passed, so that test_bbs64 really runs with 64.
        """
        n_first = 1235   # vectors present before the early conversion
        ds = datasets.SyntheticDataset(d, 2000, 5000, 200)

        index = faiss.index_factory(d, f"IVF32,PQ{d//2}x4np", metric)
        index.by_residual = by_residual
        index.train(ds.get_train())
        index.nprobe = 4

        xb = ds.get_database()
        index.add(xb[:n_first])

        # converted early, completed with add() further down
        index2 = faiss.IndexIVFPQFastScan(index, bbs)

        # reference: converted once everything is in the source index
        index.add(xb[n_first:])
        index3 = faiss.IndexIVFPQFastScan(index, bbs)
        Dref, Iref = index3.search(ds.get_queries(), 10)

        index2.add(xb[n_first:])
        Dnew, Inew = index2.search(ds.get_queries(), 10)

        np.testing.assert_array_equal(Dref, Dnew)
        np.testing.assert_array_equal(Iref, Inew)

        # direct verification of code content. Not sure the test is correct
        # if codes are shuffled.
        for list_no in range(32):
            ref_ids, ref_codes = get_invlist(index3.invlists, list_no)
            new_ids, new_codes = get_invlist(index2.invlists, list_no)
            self.assertEqual(set(ref_ids), set(new_ids))
            # entry i of a list is addressed as block i // bbs, slot i % bbs
            new_code_per_id = {
                the_id: new_codes[i // bbs, :, i % bbs]
                for i, the_id in enumerate(new_ids)
            }
            for i, the_id in enumerate(ref_ids):
                ref_code_i = ref_codes[i // bbs, :, i % bbs]
                new_code_i = new_code_per_id[the_id]
                np.testing.assert_array_equal(ref_code_i, new_code_i)

    def test_add(self):
        self.do_test()

    def test_odd_d(self):
        self.do_test(d=30)

    def test_bbs64(self):
        self.do_test(bbs=64)
412
413
class TestTraining(unittest.TestCase):
    """Train an IndexIVFPQFastScan directly (instead of converting an
    IndexIVFPQ) and check its accuracy and its serialization."""

    def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
        """Compare a directly trained IndexIVFPQFastScan with an IndexIVFPQ
        on the same data, then verify that a serialize/deserialize round
        trip leaves the search results unchanged.

        by_residual, metric, d: settings applied to both indexes; the
            combination must have an entry in the reference table below.
        bbs: block size passed to the IndexIVFPQFastScan constructor (it is
            honoured as passed rather than being reset to 32).
        """
        ds = datasets.SyntheticDataset(d, 2000, 5000, 200)

        index = faiss.index_factory(d, f"IVF32,PQ{d//2}x4np", metric)
        index.by_residual = by_residual
        index.train(ds.get_train())
        index.add(ds.get_database())
        index.nprobe = 4
        Dref, Iref = index.search(ds.get_queries(), 10)

        # the fast-scan index is built on the coarse quantizer object of
        # the reference index
        index2 = faiss.IndexIVFPQFastScan(
            index.quantizer, d, 32, d // 2, 4, metric, bbs)
        index2.by_residual = by_residual
        index2.train(ds.get_train())

        index2.add(ds.get_database())
        index2.nprobe = 4
        Dnew, Inew = index2.search(ds.get_queries(), 10)

        m3 = three_metrics(Dref, Iref, Dnew, Inew)
        # reference metrics keyed by (by_residual, metric, d)
        ref_m3_tab = {
            (True, 1, 32): (0.995, 1.0, 9.91),
            (True, 0, 32): (0.99, 1.0, 9.91),
            (True, 1, 30): (0.99, 1.0, 9.885),
            (False, 1, 32): (0.99, 1.0, 9.875),
            (False, 0, 32): (0.99, 1.0, 9.92),
            (False, 1, 30): (1.0, 1.0, 9.895),
        }
        ref_m3 = ref_m3_tab[(by_residual, metric, d)]
        # each metric must exceed 99% of its recorded reference
        for measured, reference in zip(m3, ref_m3):
            self.assertGreater(measured, reference * 0.99)

        # Test I/O
        data = faiss.serialize_index(index2)
        index3 = faiss.deserialize_index(data)
        D3, I3 = index3.search(ds.get_queries(), 10)

        np.testing.assert_array_equal(I3, Inew)
        np.testing.assert_array_equal(D3, Dnew)

    def test_no_residual(self):
        self.do_test(by_residual=False)

    def test_by_residual(self):
        self.do_test(by_residual=True)

    def test_no_residual_ip(self):
        self.do_test(by_residual=False, metric=faiss.METRIC_INNER_PRODUCT)

    def test_by_residual_ip(self):
        self.do_test(by_residual=True, metric=faiss.METRIC_INNER_PRODUCT)

    def test_no_residual_odd_dim(self):
        self.do_test(by_residual=False, d=30)

    def test_by_residual_odd_dim(self):
        self.do_test(by_residual=True, d=30)
476