1# Copyright (c) Facebook, Inc. and its affiliates. 2# 3# This source code is licensed under the MIT license found in the 4# LICENSE file in the root directory of this source tree. 5 6from __future__ import absolute_import, division, print_function, unicode_literals 7 8# translation of test_meta_index.lua 9 10import sys 11import numpy as np 12import faiss 13import unittest 14 15from common_faiss_tests import Randu10k 16 17ru = Randu10k() 18 19xb = ru.xb 20xt = ru.xt 21xq = ru.xq 22nb, d = xb.shape 23nq, d = xq.shape 24 25 26class IDRemap(unittest.TestCase): 27 28 def test_id_remap_idmap(self): 29 # reference: index without remapping 30 31 index = faiss.IndexPQ(d, 8, 8) 32 k = 10 33 index.train(xt) 34 index.add(xb) 35 _Dref, Iref = index.search(xq, k) 36 37 # try a remapping 38 ids = np.arange(nb)[::-1].copy().astype('int64') 39 40 sub_index = faiss.IndexPQ(d, 8, 8) 41 index2 = faiss.IndexIDMap(sub_index) 42 43 index2.train(xt) 44 index2.add_with_ids(xb, ids) 45 46 _D, I = index2.search(xq, k) 47 48 assert np.all(I == nb - 1 - Iref) 49 50 def test_id_remap_ivf(self): 51 # coarse quantizer in common 52 coarse_quantizer = faiss.IndexFlatIP(d) 53 ncentroids = 25 54 55 # reference: index without remapping 56 57 index = faiss.IndexIVFPQ(coarse_quantizer, d, 58 ncentroids, 8, 8) 59 index.nprobe = 5 60 k = 10 61 index.train(xt) 62 index.add(xb) 63 _Dref, Iref = index.search(xq, k) 64 65 # try a remapping 66 ids = np.arange(nb)[::-1].copy().astype('int64') 67 68 index2 = faiss.IndexIVFPQ(coarse_quantizer, d, 69 ncentroids, 8, 8) 70 index2.nprobe = 5 71 72 index2.train(xt) 73 index2.add_with_ids(xb, ids) 74 75 _D, I = index2.search(xq, k) 76 assert np.all(I == nb - 1 - Iref) 77 78 79class Shards(unittest.TestCase): 80 81 def test_shards(self): 82 k = 32 83 ref_index = faiss.IndexFlatL2(d) 84 85 print('ref search') 86 ref_index.add(xb) 87 _Dref, Iref = ref_index.search(xq, k) 88 print(Iref[:5, :6]) 89 90 # there is a OpenMP bug in this configuration, so disable threading 91 if sys.platform == "darwin" and "Clang 12" in sys.version: 92 nthreads = faiss.omp_get_max_threads() 93 faiss.omp_set_num_threads(1) 94 else: 95 nthreads = None 96 97 shard_index = faiss.IndexShards(d) 98 shard_index_2 = faiss.IndexShards(d, True, False) 99 100 ni = 3 101 for i in range(ni): 102 i0 = int(i * nb / ni) 103 i1 = int((i + 1) * nb / ni) 104 index = faiss.IndexFlatL2(d) 105 index.add(xb[i0:i1]) 106 shard_index.add_shard(index) 107 108 index_2 = faiss.IndexFlatL2(d) 109 irm = faiss.IndexIDMap(index_2) 110 shard_index_2.add_shard(irm) 111 112 # test parallel add 113 shard_index_2.verbose = True 114 shard_index_2.add(xb) 115 116 for test_no in range(3): 117 with_threads = test_no == 1 118 119 print('shard search test_no = %d' % test_no) 120 if with_threads: 121 remember_nt = faiss.omp_get_max_threads() 122 faiss.omp_set_num_threads(1) 123 shard_index.threaded = True 124 else: 125 shard_index.threaded = False 126 127 if test_no != 2: 128 _D, I = shard_index.search(xq, k) 129 else: 130 _D, I = shard_index_2.search(xq, k) 131 132 print(I[:5, :6]) 133 134 if with_threads: 135 faiss.omp_set_num_threads(remember_nt) 136 137 ndiff = (I != Iref).sum() 138 139 print('%d / %d differences' % (ndiff, nq * k)) 140 assert(ndiff < nq * k / 1000.) 141 142 if nthreads is not None: 143 faiss.omp_set_num_threads(nthreads) 144 145class Merge(unittest.TestCase): 146 147 def make_index_for_merge(self, quant, index_type, master_index): 148 ncent = 40 149 if index_type == 1: 150 index = faiss.IndexIVFFlat(quant, d, ncent, faiss.METRIC_L2) 151 if master_index: 152 index.is_trained = True 153 elif index_type == 2: 154 index = faiss.IndexIVFPQ(quant, d, ncent, 4, 8) 155 if master_index: 156 index.pq = master_index.pq 157 index.is_trained = True 158 elif index_type == 3: 159 index = faiss.IndexIVFPQR(quant, d, ncent, 4, 8, 8, 8) 160 if master_index: 161 index.pq = master_index.pq 162 index.refine_pq = master_index.refine_pq 163 index.is_trained = True 164 elif index_type == 4: 165 # quant used as the actual index 166 index = faiss.IndexIDMap(quant) 167 return index 168 169 def do_test_merge(self, index_type): 170 k = 16 171 quant = faiss.IndexFlatL2(d) 172 ref_index = self.make_index_for_merge(quant, index_type, False) 173 174 # trains the quantizer 175 ref_index.train(xt) 176 177 print('ref search') 178 ref_index.add(xb) 179 _Dref, Iref = ref_index.search(xq, k) 180 print(Iref[:5, :6]) 181 182 indexes = [] 183 ni = 3 184 for i in range(ni): 185 i0 = int(i * nb / ni) 186 i1 = int((i + 1) * nb / ni) 187 index = self.make_index_for_merge(quant, index_type, ref_index) 188 index.is_trained = True 189 index.add(xb[i0:i1]) 190 indexes.append(index) 191 192 index = indexes[0] 193 194 for i in range(1, ni): 195 print('merge ntotal=%d other.ntotal=%d ' % ( 196 index.ntotal, indexes[i].ntotal)) 197 index.merge_from(indexes[i], index.ntotal) 198 199 _D, I = index.search(xq, k) 200 print(I[:5, :6]) 201 202 ndiff = (I != Iref).sum() 203 print('%d / %d differences' % (ndiff, nq * k)) 204 assert(ndiff < nq * k / 1000.) 205 206 def test_merge(self): 207 self.do_test_merge(1) 208 self.do_test_merge(2) 209 self.do_test_merge(3) 210 211 def do_test_remove(self, index_type): 212 k = 16 213 quant = faiss.IndexFlatL2(d) 214 index = self.make_index_for_merge(quant, index_type, None) 215 216 # trains the quantizer 217 index.train(xt) 218 219 if index_type < 4: 220 index.add(xb) 221 else: 222 gen = np.random.RandomState(1234) 223 id_list = gen.permutation(nb * 7)[:nb].astype('int64') 224 index.add_with_ids(xb, id_list) 225 226 print('ref search ntotal=%d' % index.ntotal) 227 Dref, Iref = index.search(xq, k) 228 229 toremove = np.zeros(nq * k, dtype='int64') 230 nr = 0 231 for i in range(nq): 232 for j in range(k): 233 # remove all even results (it's ok if there are duplicates 234 # in the list of ids) 235 if Iref[i, j] % 2 == 0: 236 nr = nr + 1 237 toremove[nr] = Iref[i, j] 238 239 print('nr=', nr) 240 241 idsel = faiss.IDSelectorBatch( 242 nr, faiss.swig_ptr(toremove)) 243 244 for i in range(nr): 245 assert(idsel.is_member(int(toremove[i]))) 246 247 nremoved = index.remove_ids(idsel) 248 249 print('nremoved=%d ntotal=%d' % (nremoved, index.ntotal)) 250 251 D, I = index.search(xq, k) 252 253 # make sure results are in the same order with even ones removed 254 ndiff = 0 255 for i in range(nq): 256 j2 = 0 257 for j in range(k): 258 if Iref[i, j] % 2 != 0: 259 if I[i, j2] != Iref[i, j]: 260 ndiff += 1 261 assert abs(D[i, j2] - Dref[i, j]) < 1e-5 262 j2 += 1 263 # draws are ordered arbitrarily 264 assert ndiff < 5 265 266 def test_remove(self): 267 self.do_test_remove(1) 268 self.do_test_remove(2) 269 self.do_test_remove(4) 270 271 272 273 274 275 276if __name__ == '__main__': 277 unittest.main() 278