# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import unittest
import platform

import numpy as np
import faiss

from faiss.contrib import datasets
from faiss.contrib.inspect_tools import get_invlist


class TestLUTQuantization(unittest.TestCase):
    """Compare faiss.quantize_LUT_and_bias with a float32 numpy reference."""

    def compute_dis_float(self, codes, LUT, bias):
        """Compute reference distances in float32.

        For probe p and vector i the result is
        bias[p] + sum_m LUT[m, codes[p, i, m]].

        codes: integer array of shape (nprobe, nt, M).
        LUT: either 2D (shared by all probes) or 3D (first axis indexed
            by probe).
        bias: None, or an array with one entry per probe.
        Returns a float32 array of shape (nprobe, nt).
        """
        nprobe, nt, M = codes.shape
        dis = np.zeros((nprobe, nt), dtype='float32')
        if bias is not None:
            # one bias value per probe, broadcast over the nt vectors
            dis[:] = bias.reshape(-1, 1)

        rows = np.arange(M)
        for p in range(nprobe):
            LUTp = LUT[p] if LUT.ndim == 3 else LUT
            for i in range(nt):
                dis[p, i] += LUTp[rows, codes[p, i]].sum()

        return dis

    def compute_dis_quant(self, codes, LUT, bias, a, b):
        """Compute distances from quantized tables, then de-quantize.

        Same lookup-and-sum as compute_dis_float, but the accumulator is
        uint16. The integer result is mapped back to floats with
        dis / a + b.
        """
        nprobe, nt, M = codes.shape
        dis = np.zeros((nprobe, nt), dtype='uint16')
        if bias is not None:
            dis[:] = bias.reshape(-1, 1)

        rows = np.arange(M)
        for p in range(nprobe):
            LUTp = LUT[p] if LUT.ndim == 3 else LUT
            for i in range(nt):
                dis[p, i] += LUTp[rows, codes[p, i]].astype('uint16').sum()

        return dis / a + b

    def do_test(self, LUT, bias, nprobe, alt_3d=False):
        """Quantize LUT and bias, and check the average relative error.

        Random codes are drawn with a fixed seed, distances are computed
        with the float tables and with the quantized tables, and the summed
        absolute difference relative to the summed reference must be
        below 0.0005.

        alt_3d: when True, no output buffer is allocated for the quantized
            bias even if a bias is given (a null pointer is passed).
            Presumably faiss then folds the bias into the quantized LUT;
            confirm against the faiss implementation.
        """
        M, ksub = LUT.shape[-2:]
        nt = 200

        rs = np.random.RandomState(123)
        codes = rs.randint(ksub, size=(nprobe, nt, M)).astype('uint8')

        dis_ref = self.compute_dis_float(codes, LUT, bias)

        # output buffers filled in by faiss
        LUTq = np.zeros(LUT.shape, dtype='uint8')
        biasq = (
            np.zeros(bias.shape, dtype='uint16')
            if (bias is not None) and not alt_3d else None
        )
        atab = np.zeros(1, dtype='float32')
        btab = np.zeros(1, dtype='float32')

        def sp(x):
            # None is passed through unchanged for optional arrays
            return faiss.swig_ptr(x) if x is not None else None

        faiss.quantize_LUT_and_bias(
            nprobe, M, ksub, LUT.ndim == 3,
            sp(LUT), sp(bias), sp(LUTq), M, sp(biasq),
            sp(atab), sp(btab)
        )
        a = atab[0]
        b = btab[0]
        dis_new = self.compute_dis_quant(codes, LUTq, biasq, a, b)

        avg_relative_error = np.abs(dis_new - dis_ref).sum() / dis_ref.sum()
        self.assertLess(avg_relative_error, 0.0005)

    def test_no_residual_ip(self):
        ksub = 16
        M = 20
        nprobe = 10
        rs = np.random.RandomState(1234)
        LUT = rs.rand(M, ksub).astype('float32')
        bias = None

        self.do_test(LUT, bias, nprobe)

    def test_by_residual_ip(self):
        ksub = 16
        M = 20
        nprobe = 10
        rs = np.random.RandomState(1234)
        LUT = rs.rand(M, ksub).astype('float32')
        bias = rs.rand(nprobe).astype('float32')
        bias *= 10

        self.do_test(LUT, bias, nprobe)

    def test_by_residual_L2(self):
        ksub = 16
        M = 20
        nprobe = 10
        rs = np.random.RandomState(1234)
        LUT = rs.rand(nprobe, M, ksub).astype('float32')
        bias = rs.rand(nprobe).astype('float32')
        bias *= 10

        self.do_test(LUT, bias, nprobe)

    def test_by_residual_L2_v2(self):
        ksub = 16
        M = 20
        nprobe = 10
        rs = np.random.RandomState(1234)
        LUT = rs.rand(nprobe, M, ksub).astype('float32')
        bias = rs.rand(nprobe).astype('float32')
        bias *= 10

        self.do_test(LUT, bias, nprobe, alt_3d=True)


##########################################################
# Tests for various IndexPQFastScan implementations
##########################################################

def verify_with_draws(testcase, Dref, Iref, Dnew, Inew):
    """Verify search results whose distances may contain draws (ties).

    The distances must agree to 5 decimals. For each query whose id rows
    differ, the ids are compared as sets within each group of equal
    reference distance, because ids in a tie may come back in any order.

    testcase: the unittest.TestCase used to report failures.
    Dref, Iref: reference distances and ids, one row per query.
    Dnew, Inew: distances and ids under test, same shapes.
    """
    np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
    # here we have to be careful because of draws
    for i in range(len(Iref)):
        if np.all(Iref[i] == Inew[i]):  # easy case
            continue
        # nothing can be deduced about the group sharing the last distance
        skip_dis = Dref[i, -1]
        # only distances present in this row can select any entry
        for dis in np.unique(Dref[i]):
            if dis == skip_dis:
                continue
            mask = Dref[i, :] == dis
            testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))


def three_metrics(Dref, Iref, Dnew, Inew):
    """Summarize how well the new results match the reference ones.

    Returns a tuple (recall_at_1, recall_at_10, intersection_at_10):
    - recall_at_1: fraction of queries with the same first result id,
    - recall_at_10: occurrences of the reference first id among the first
      10 new ids, divided by the number of queries,
    - intersection_at_10: average size per query of the intersection
      between the full rows of Inew and Iref.

    Dref and Dnew are accepted but not used.
    """
    nq = Iref.shape[0]
    recall_at_1 = (Iref[:, 0] == Inew[:, 0]).sum() / nq
    # Iref[:, :1] keeps a column axis so that it broadcasts over the 10 ids
    recall_at_10 = (Iref[:, :1] == Inew[:, :10]).sum() / nq
    ninter = 0
    for i in range(nq):
        ninter += len(np.intersect1d(Inew[i], Iref[i]))
    intersection_at_10 = ninter / nq
    return recall_at_1, recall_at_10, intersection_at_10


##########################################################
# Tests for various IndexIVFPQFastScan implementations
##########################################################

class TestIVFImplem1(unittest.TestCase):
    """ Verify implem 1 (search from original invlists)
    against IndexIVFPQ """

    def do_test(self, by_residual, metric_type=faiss.METRIC_L2,
                use_precomputed_table=0):
        """Build an IVF32,PQ16x4np index and compare it with implem 1.

        The ids must be identical and the distances equal to 5 decimals.
        """
        ds = datasets.SyntheticDataset(32, 2000, 5000, 1000)

        index = faiss.index_factory(32, "IVF32,PQ16x4np", metric_type)
        index.use_precomputed_table = use_precomputed_table
        index.train(ds.get_train())
        index.add(ds.get_database())
        index.nprobe = 4
        # NOTE(review): by_residual is set after train() and add(), unlike
        # in TestIVFImplem12 -- confirm this ordering is intended.
        index.by_residual = by_residual
        Da, Ia = index.search(ds.get_queries(), 10)

        index2 = faiss.IndexIVFPQFastScan(index)
        index2.implem = 1
        Db, Ib = index2.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(Ia, Ib)
        np.testing.assert_almost_equal(Da, Db, decimal=5)

    def test_no_residual(self):
        self.do_test(False)

    def test_by_residual(self):
        self.do_test(True)

    def test_by_residual_no_precomputed(self):
        self.do_test(True, use_precomputed_table=-1)

    def test_no_residual_ip(self):
        self.do_test(False, faiss.METRIC_INNER_PRODUCT)

    def test_by_residual_ip(self):
        self.do_test(True, faiss.METRIC_INNER_PRODUCT)


class TestIVFImplem2(unittest.TestCase):
    """ Verify implem 2 (search with original invlists with uint8 LUTs)
    against IndexIVFPQ. Entails some loss in accuracy. """

    def eval_quant_loss(self, by_residual, metric=faiss.METRIC_L2):
        """Check that implem 2 stays close to the recorded accuracy.

        Each of the three metrics returned by three_metrics must reach at
        least 99.5% of the reference value recorded for this
        (by_residual, metric) combination.
        """
        ds = datasets.SyntheticDataset(32, 2000, 5000, 1000)

        index = faiss.index_factory(32, "IVF32,PQ16x4np", metric)
        index.train(ds.get_train())
        index.add(ds.get_database())
        index.nprobe = 4
        # NOTE(review): by_residual is set after train() and add(); the
        # reference values below were presumably recorded with this
        # ordering -- confirm before changing it.
        index.by_residual = by_residual
        Da, Ia = index.search(ds.get_queries(), 10)

        # loss due to int8 quantization of LUTs
        index2 = faiss.IndexIVFPQFastScan(index)
        index2.implem = 2
        Db, Ib = index2.search(ds.get_queries(), 10)

        m3 = three_metrics(Da, Ia, Db, Ib)

        # keys are (by_residual, metric) with metric as its integer value;
        # values are [recall_at_1, recall_at_10, intersection_at_10]
        ref_results = {
            (True, 1): [0.985, 1.0, 9.872],
            (True, 0): [0.987, 1.0, 9.914],
            (False, 1): [0.991, 1.0, 9.907],
            (False, 0): [0.986, 1.0, 9.917],
        }

        ref = ref_results[(by_residual, metric)]

        self.assertGreaterEqual(m3[0], ref[0] * 0.995)
        self.assertGreaterEqual(m3[1], ref[1] * 0.995)
        self.assertGreaterEqual(m3[2], ref[2] * 0.995)

    def test_qloss_no_residual(self):
        self.eval_quant_loss(False)

    def test_qloss_by_residual(self):
        self.eval_quant_loss(True)

    def test_qloss_no_residual_ip(self):
        self.eval_quant_loss(False, faiss.METRIC_INNER_PRODUCT)

    def test_qloss_by_residual_ip(self):
        self.eval_quant_loss(True, faiss.METRIC_INNER_PRODUCT)


class TestEquivPQ(unittest.TestCase):
    """Check that a one-list IVFPQ index behaves like a flat PQ index."""

    def test_equiv_pq(self):
        ds = datasets.SyntheticDataset(32, 2000, 200, 4)

        index = faiss.index_factory(32, "IVF1,PQ16x4np")
        index.by_residual = False
        # force coarse quantizer
        index.quantizer.add(np.zeros((1, 32), dtype='float32'))
        index.train(ds.get_train())
        index.add(ds.get_database())
        Dref, Iref = index.search(ds.get_queries(), 4)

        # flat PQ index that reuses the PQ and the codes of the single list
        index_pq = faiss.index_factory(32, "PQ16x4np")
        index_pq.pq = index.pq
        index_pq.is_trained = True
        index_pq.codes = faiss.downcast_InvertedLists(
            index.invlists).codes.at(0)
        index_pq.ntotal = index.ntotal
        Dnew, Inew = index_pq.search(ds.get_queries(), 4)

        np.testing.assert_array_equal(Iref, Inew)
        np.testing.assert_array_equal(Dref, Dnew)

        # the FastScan versions of both indexes must agree as well
        index_pq2 = faiss.IndexPQFastScan(index_pq)
        index_pq2.implem = 12
        Dref, Iref = index_pq2.search(ds.get_queries(), 4)

        index2 = faiss.IndexIVFPQFastScan(index)
        index2.implem = 12
        Dnew, Inew = index2.search(ds.get_queries(), 4)
        np.testing.assert_array_equal(Iref, Inew)
        np.testing.assert_array_equal(Dref, Dnew)


class TestIVFImplem12(unittest.TestCase):
    """Compare implem IMPLEM of IndexIVFPQFastScan with implem 2.

    Subclasses override IMPLEM to run the same tests on other
    implementations.
    """

    IMPLEM = 12

    def do_test(self, by_residual, metric=faiss.METRIC_L2, d=32):
        """Check the results for k=4 and k=1 against implem 2."""
        ds = datasets.SyntheticDataset(d, 2000, 5000, 200)

        index = faiss.index_factory(d, f"IVF32,PQ{d//2}x4np", metric)
        index.by_residual = by_residual
        index.train(ds.get_train())
        index.add(ds.get_database())
        index.nprobe = 4

        index2 = faiss.IndexIVFPQFastScan(index)
        index2.implem = 2
        Dref, Iref = index2.search(ds.get_queries(), 4)
        index2 = faiss.IndexIVFPQFastScan(index)
        index2.implem = self.IMPLEM
        Dnew, Inew = index2.search(ds.get_queries(), 4)

        verify_with_draws(self, Dref, Iref, Dnew, Inew)

        stats = faiss.cvar.indexIVF_stats
        stats.reset()

        # also verify with single result
        Dnew, Inew = index2.search(ds.get_queries(), 1)
        for q in range(len(Dref)):
            if Dref[q, 1] == Dref[q, 0]:
                # then we cannot conclude
                continue
            self.assertEqual(Iref[q, 0], Inew[q, 0])
            np.testing.assert_almost_equal(Dref[q, 0], Dnew[q, 0], decimal=5)

        # the k=1 search above must have been counted in the statistics
        self.assertGreater(stats.ndis, 0)

    def test_no_residual(self):
        self.do_test(False)

    def test_by_residual(self):
        self.do_test(True)

    def test_no_residual_ip(self):
        self.do_test(False, metric=faiss.METRIC_INNER_PRODUCT)

    def test_by_residual_ip(self):
        self.do_test(True, metric=faiss.METRIC_INNER_PRODUCT)

    def test_no_residual_odd_dim(self):
        self.do_test(False, d=30)

    def test_by_residual_odd_dim(self):
        self.do_test(True, d=30)


class TestIVFImplem10(TestIVFImplem12):
    IMPLEM = 10


class TestIVFImplem11(TestIVFImplem12):
    IMPLEM = 11


class TestIVFImplem13(TestIVFImplem12):
    IMPLEM = 13


class TestAdd(unittest.TestCase):
    """Adding vectors after conversion to FastScan must be equivalent to
    converting after all vectors were added."""

    def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
        """Compare an index extended after conversion with one converted
        after all vectors were added.

        bbs: block size given to both IndexIVFPQFastScan constructors and
            used to locate codes in the blocks returned by get_invlist.
        """
        ds = datasets.SyntheticDataset(d, 2000, 5000, 200)

        index = faiss.index_factory(d, f"IVF32,PQ{d//2}x4np", metric)
        index.by_residual = by_residual
        index.train(ds.get_train())
        index.nprobe = 4

        xb = ds.get_database()
        index.add(xb[:1235])

        # index2 is converted with part of the data, then completed below
        index2 = faiss.IndexIVFPQFastScan(index, bbs)

        # index3 is converted once all the data is in
        index.add(xb[1235:])
        index3 = faiss.IndexIVFPQFastScan(index, bbs)
        Dref, Iref = index3.search(ds.get_queries(), 10)

        index2.add(xb[1235:])
        Dnew, Inew = index2.search(ds.get_queries(), 10)

        np.testing.assert_array_equal(Dref, Dnew)
        np.testing.assert_array_equal(Iref, Inew)

        # direct verification of code content. Not sure the test is correct
        # if codes are shuffled.
        for list_no in range(32):
            ref_ids, ref_codes = get_invlist(index3.invlists, list_no)
            new_ids, new_codes = get_invlist(index2.invlists, list_no)
            self.assertEqual(set(ref_ids), set(new_ids))
            new_code_per_id = {
                new_ids[i]: new_codes[i // bbs, :, i % bbs]
                for i in range(new_ids.size)
            }
            for i, the_id in enumerate(ref_ids):
                ref_code_i = ref_codes[i // bbs, :, i % bbs]
                new_code_i = new_code_per_id[the_id]
                np.testing.assert_array_equal(ref_code_i, new_code_i)

    def test_add(self):
        self.do_test()

    def test_odd_d(self):
        self.do_test(d=30)

    def test_bbs64(self):
        self.do_test(bbs=64)


class TestTraining(unittest.TestCase):
    """Train an IndexIVFPQFastScan directly and compare with IndexIVFPQ."""

    def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
        """Check the accuracy of a directly trained FastScan index and
        that it survives a serialization round trip.

        Each metric from three_metrics must exceed 99% of the reference
        value recorded for this (by_residual, metric, d) combination.
        """
        ds = datasets.SyntheticDataset(d, 2000, 5000, 200)

        index = faiss.index_factory(d, f"IVF32,PQ{d//2}x4np", metric)
        index.by_residual = by_residual
        index.train(ds.get_train())
        index.add(ds.get_database())
        index.nprobe = 4
        Dref, Iref = index.search(ds.get_queries(), 10)

        # built on top of the coarse quantizer of the reference index
        index2 = faiss.IndexIVFPQFastScan(
            index.quantizer, d, 32, d // 2, 4, metric, bbs)
        index2.by_residual = by_residual
        index2.train(ds.get_train())

        index2.add(ds.get_database())
        index2.nprobe = 4
        Dnew, Inew = index2.search(ds.get_queries(), 10)

        m3 = three_metrics(Dref, Iref, Dnew, Inew)
        # keys are (by_residual, metric, d) with metric as its integer
        # value; values are (recall_at_1, recall_at_10, intersection_at_10)
        ref_m3_tab = {
            (True, 1, 32): (0.995, 1.0, 9.91),
            (True, 0, 32): (0.99, 1.0, 9.91),
            (True, 1, 30): (0.99, 1.0, 9.885),
            (False, 1, 32): (0.99, 1.0, 9.875),
            (False, 0, 32): (0.99, 1.0, 9.92),
            (False, 1, 30): (1.0, 1.0, 9.895),
        }
        ref_m3 = ref_m3_tab[(by_residual, metric, d)]
        self.assertGreater(m3[0], ref_m3[0] * 0.99)
        self.assertGreater(m3[1], ref_m3[1] * 0.99)
        self.assertGreater(m3[2], ref_m3[2] * 0.99)

        # Test I/O
        data = faiss.serialize_index(index2)
        index3 = faiss.deserialize_index(data)
        D3, I3 = index3.search(ds.get_queries(), 10)

        np.testing.assert_array_equal(I3, Inew)
        np.testing.assert_array_equal(D3, Dnew)

    def test_no_residual(self):
        self.do_test(by_residual=False)

    def test_by_residual(self):
        self.do_test(by_residual=True)

    def test_no_residual_ip(self):
        self.do_test(by_residual=False, metric=faiss.METRIC_INNER_PRODUCT)

    def test_by_residual_ip(self):
        self.do_test(by_residual=True, metric=faiss.METRIC_INNER_PRODUCT)

    def test_no_residual_odd_dim(self):
        self.do_test(by_residual=False, d=30)

    def test_by_residual_odd_dim(self):
        self.do_test(by_residual=True, d=30)