# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faiss
import unittest
import numpy as np

from faiss.contrib import datasets
from faiss.contrib.exhaustive_search import knn_ground_truth, range_ground_truth
from faiss.contrib import evaluation


from common_faiss_tests import get_dataset_2


class TestComputeGT(unittest.TestCase):
    """Check faiss.contrib's blockwise ground-truth helpers against a
    flat index that searches the whole database at once."""

    def test_compute_GT(self):
        """knn_ground_truth fed the database in chunks must reproduce the
        k-NN results of IndexFlatL2 (identical ids, distances to 4 decimals).
        """
        d = 64
        k = 10
        _, xb, xq = get_dataset_2(d, 0, 10000, 100)

        index = faiss.IndexFlatL2(d)
        index.add(xb)
        Dref, Iref = index.search(xq, k)

        def matrix_iterator(xb, bs):
            # Yield consecutive row slices of at most `bs` rows; the final
            # slice is shorter when the row count is not a multiple of bs.
            for i0 in range(0, xb.shape[0], bs):
                yield xb[i0:i0 + bs]

        Dnew, Inew = knn_ground_truth(xq, matrix_iterator(xb, 1000), k)

        np.testing.assert_array_equal(Iref, Inew)
        np.testing.assert_almost_equal(Dref, Dnew, decimal=4)

    def do_test_range(self, metric):
        """Compare range_ground_truth against IndexFlat.range_search.

        The search radius is the mean, over queries, of the k-th result's
        distance as returned by faiss.knn for the same metric.

        Args:
            metric: a faiss metric constant (e.g. faiss.METRIC_L2).
        """
        d = 32  # single source of truth for the dataset and index dimension
        k = 10
        ds = datasets.SyntheticDataset(d, 0, 1000, 10)
        xq = ds.get_queries()
        xb = ds.get_database()
        D, _ = faiss.knn(xq, xb, k, metric=metric)
        threshold = float(D[:, -1].mean())

        index = faiss.IndexFlat(d, metric)
        index.add(xb)
        ref_lims, ref_D, ref_I = index.range_search(xq, threshold)

        # Stream the database in blocks of 100 rows to exercise the
        # incremental accumulation in range_ground_truth.
        new_lims, new_D, new_I = range_ground_truth(
            xq, ds.database_iterator(bs=100), threshold,
            metric_type=metric)

        evaluation.test_ref_range_results(
            ref_lims, ref_D, ref_I,
            new_lims, new_D, new_I
        )

    def test_range_L2(self):
        self.do_test_range(faiss.METRIC_L2)

    def test_range_IP(self):
        self.do_test_range(faiss.METRIC_INNER_PRODUCT)