1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6from __future__ import absolute_import, division, print_function, unicode_literals
7
8# translation of test_meta_index.lua
9
10import sys
11import numpy as np
12import faiss
13import unittest
14
15from common_faiss_tests import Randu10k
16
# Dataset shared by every test in this module, built once at import time.
ru = Randu10k()

# xb is added to the indexes, xt is used for training, xq for searching.
xb, xt, xq = ru.xb, ru.xt, ru.xq
nb, d = xb.shape
# NOTE: this rebinds d to the query dimensionality (expected to equal the
# database one).
nq, d = xq.shape
24
25
class IDRemap(unittest.TestCase):
    """Check that indexes fed user-supplied ids return those ids on search.

    Each test builds a reference index with implicit sequential ids and a
    second, identically configured index whose vector i is given id
    nb - 1 - i; the reference result ids must map exactly onto the remapped
    ones.
    """

    def test_id_remap_idmap(self):
        """IndexIDMap wrapped around an IndexPQ translates ids on search."""
        k = 10

        # reference: index without remapping
        index = faiss.IndexPQ(d, 8, 8)
        index.train(xt)
        index.add(xb)
        _Dref, Iref = index.search(xq, k)

        # try a remapping: vector i gets id nb - 1 - i (contiguous int64)
        ids = np.arange(nb - 1, -1, -1, dtype='int64')

        sub_index = faiss.IndexPQ(d, 8, 8)
        index2 = faiss.IndexIDMap(sub_index)

        index2.train(xt)
        index2.add_with_ids(xb, ids)

        _D, I = index2.search(xq, k)

        # np.testing rather than a bare assert: it is not stripped by
        # `python -O` and reports the mismatching entries on failure
        np.testing.assert_array_equal(I, nb - 1 - Iref)

    def test_id_remap_ivf(self):
        """IndexIVFPQ.add_with_ids stores and returns the supplied ids."""
        # coarse quantizer in common, so both indexes use the same centroids
        coarse_quantizer = faiss.IndexFlatIP(d)
        ncentroids = 25
        k = 10

        # reference: index without remapping
        index = faiss.IndexIVFPQ(coarse_quantizer, d, ncentroids, 8, 8)
        index.nprobe = 5
        index.train(xt)
        index.add(xb)
        _Dref, Iref = index.search(xq, k)

        # try a remapping: vector i gets id nb - 1 - i (contiguous int64)
        ids = np.arange(nb - 1, -1, -1, dtype='int64')

        index2 = faiss.IndexIVFPQ(coarse_quantizer, d, ncentroids, 8, 8)
        index2.nprobe = 5

        index2.train(xt)
        index2.add_with_ids(xb, ids)

        _D, I = index2.search(xq, k)
        np.testing.assert_array_equal(I, nb - 1 - Iref)
77
78
class Shards(unittest.TestCase):
    """Searching an IndexShards must reproduce a single flat index."""

    def test_shards(self):
        """Compare sharded search results against an IndexFlatL2 reference.

        Three variants are checked: sequential shard search, threaded shard
        search (with OpenMP limited to one thread), and an IndexShards made
        of IndexIDMap shards filled through a single add() call.
        Any global OpenMP thread-count change is undone even on failure.
        """
        k = 32
        ref_index = faiss.IndexFlatL2(d)

        print('ref search')
        ref_index.add(xb)
        _Dref, Iref = ref_index.search(xq, k)
        print(Iref[:5, :6])

        # there is an OpenMP bug in this configuration, so disable threading
        if sys.platform == "darwin" and "Clang 12" in sys.version:
            nthreads = faiss.omp_get_max_threads()
            faiss.omp_set_num_threads(1)
        else:
            nthreads = None

        try:
            shard_index = faiss.IndexShards(d)
            # NOTE(review): positional args presumably are
            # (threaded=True, successive_ids=False) -- confirm with faiss docs
            shard_index_2 = faiss.IndexShards(d, True, False)

            ni = 3
            for i in range(ni):
                # integer arithmetic for the shard boundaries
                i0 = i * nb // ni
                i1 = (i + 1) * nb // ni
                index = faiss.IndexFlatL2(d)
                index.add(xb[i0:i1])
                shard_index.add_shard(index)

                index_2 = faiss.IndexFlatL2(d)
                irm = faiss.IndexIDMap(index_2)
                shard_index_2.add_shard(irm)

            # test parallel add
            shard_index_2.verbose = True
            shard_index_2.add(xb)

            for test_no in range(3):
                with_threads = test_no == 1

                print('shard search test_no = %d' % test_no)
                if with_threads:
                    remember_nt = faiss.omp_get_max_threads()
                    faiss.omp_set_num_threads(1)
                    shard_index.threaded = True
                else:
                    shard_index.threaded = False

                try:
                    if test_no != 2:
                        _D, I = shard_index.search(xq, k)
                    else:
                        _D, I = shard_index_2.search(xq, k)
                finally:
                    # restore OpenMP threads even if the search raises
                    if with_threads:
                        faiss.omp_set_num_threads(remember_nt)

                print(I[:5, :6])

                ndiff = (I != Iref).sum()

                print('%d / %d differences' % (ndiff, nq * k))
                # TestCase assertion: not stripped by `python -O`
                self.assertLess(ndiff, nq * k / 1000.)
        finally:
            # never leak the single-thread workaround into other tests
            if nthreads is not None:
                faiss.omp_set_num_threads(nthreads)
144
class Merge(unittest.TestCase):
    """merge_from and remove_ids on IVF-family indexes and IndexIDMap."""

    def make_index_for_merge(self, quant, index_type, master_index):
        """Build an empty index of the requested kind on top of ``quant``.

        index_type: 1 = IndexIVFFlat, 2 = IndexIVFPQ, 3 = IndexIVFPQR,
            4 = IndexIDMap wrapping ``quant`` itself (quant is the index).
        master_index: if truthy, an index of the same kind whose product
            quantizer(s) are copied into the new one, which is then marked
            trained. Pass None/False for a fresh index. Ignored for type 4.
        Raises ValueError for an unknown index_type.
        """
        ncent = 40
        if index_type == 1:
            index = faiss.IndexIVFFlat(quant, d, ncent, faiss.METRIC_L2)
            if master_index:
                index.is_trained = True
        elif index_type == 2:
            index = faiss.IndexIVFPQ(quant, d, ncent, 4, 8)
            if master_index:
                index.pq = master_index.pq
                index.is_trained = True
        elif index_type == 3:
            index = faiss.IndexIVFPQR(quant, d, ncent, 4, 8, 8, 8)
            if master_index:
                index.pq = master_index.pq
                index.refine_pq = master_index.refine_pq
                index.is_trained = True
        elif index_type == 4:
            # quant used as the actual index
            index = faiss.IndexIDMap(quant)
        else:
            # previously fell through to an UnboundLocalError on `index`
            raise ValueError('unknown index_type %r' % (index_type,))
        return index

    def do_test_merge(self, index_type):
        """Split xb over 3 indexes, merge them, and compare with one index."""
        k = 16
        quant = faiss.IndexFlatL2(d)
        ref_index = self.make_index_for_merge(quant, index_type, False)

        # trains the quantizer
        ref_index.train(xt)

        print('ref search')
        ref_index.add(xb)
        _Dref, Iref = ref_index.search(xq, k)
        print(Iref[:5, :6])

        indexes = []
        ni = 3
        for i in range(ni):
            i0 = int(i * nb / ni)
            i1 = int((i + 1) * nb / ni)
            index = self.make_index_for_merge(quant, index_type, ref_index)
            index.is_trained = True
            index.add(xb[i0:i1])
            indexes.append(index)

        index = indexes[0]

        for i in range(1, ni):
            print('merge ntotal=%d other.ntotal=%d ' % (
                index.ntotal, indexes[i].ntotal))
            # shift the merged ids by the current size so they stay sequential
            index.merge_from(indexes[i], index.ntotal)

        _D, I = index.search(xq, k)
        print(I[:5, :6])

        ndiff = (I != Iref).sum()
        print('%d / %d differences' % (ndiff, nq * k))
        self.assertLess(ndiff, nq * k / 1000.)

    def test_merge(self):
        self.do_test_merge(1)
        self.do_test_merge(2)
        self.do_test_merge(3)

    def do_test_remove(self, index_type):
        """Remove every even id found in the reference results, then check
        that the remaining (odd) results keep their order and distances."""
        k = 16
        quant = faiss.IndexFlatL2(d)
        index = self.make_index_for_merge(quant, index_type, None)

        # trains the quantizer
        index.train(xt)

        if index_type < 4:
            index.add(xb)
        else:
            gen = np.random.RandomState(1234)
            id_list = gen.permutation(nb * 7)[:nb].astype('int64')
            index.add_with_ids(xb, id_list)

        print('ref search ntotal=%d' % index.ntotal)
        Dref, Iref = index.search(xq, k)

        # remove all even results (it's ok if there are duplicates in the
        # list of ids). The previous loop incremented nr before storing, so
        # toremove[0] was a spurious 0 and the last even id was left out.
        toremove = np.ascontiguousarray(Iref[Iref % 2 == 0], dtype='int64')
        nr = toremove.size

        print('nr=', nr)

        idsel = faiss.IDSelectorBatch(
            nr, faiss.swig_ptr(toremove))

        for idx in toremove:
            self.assertTrue(idsel.is_member(int(idx)))

        nremoved = index.remove_ids(idsel)

        print('nremoved=%d ntotal=%d' % (nremoved, index.ntotal))

        D, I = index.search(xq, k)

        # make sure results are in the same order with even ones removed
        ndiff = 0
        for i in range(nq):
            j2 = 0
            for j in range(k):
                if Iref[i, j] % 2 != 0:
                    if I[i, j2] != Iref[i, j]:
                        ndiff += 1
                    self.assertLess(abs(D[i, j2] - Dref[i, j]), 1e-5)
                    j2 += 1
        # results with equal distances may come back in arbitrary order
        self.assertLess(ndiff, 5)

    def test_remove(self):
        self.do_test_remove(1)
        self.do_test_remove(2)
        self.do_test_remove(4)
270
271
272
273
274
275
if __name__ == '__main__':
    # Executed directly (not imported): hand control to the unittest runner,
    # which discovers and runs the TestCase classes in this module.
    unittest.main()
278