1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6import faiss
7import unittest
8import numpy as np
9
10from faiss.contrib import datasets
11from faiss.contrib.exhaustive_search import knn_ground_truth, range_ground_truth
12from faiss.contrib import evaluation
13
14
15from common_faiss_tests import get_dataset_2
16
17
class TestComputeGT(unittest.TestCase):
    """Check the exhaustive-search ground-truth helpers against faiss.

    ``knn_ground_truth`` and ``range_ground_truth`` receive the database as
    an iterator over row blocks; their results must agree with a flat index
    that was given the whole database in a single ``add`` call.
    """

    def test_compute_GT(self):
        """Blockwise k-NN ground truth matches ``IndexFlatL2.search``."""
        d = 64
        k = 10
        xt, xb, xq = get_dataset_2(d, 0, 10000, 100)

        index = faiss.IndexFlatL2(d)
        index.add(xb)
        Dref, Iref = index.search(xq, k)

        # iterator function on the matrix

        def matrix_iterator(x, bs):
            """Yield consecutive row slices of ``x`` of at most ``bs`` rows."""
            for i0 in range(0, x.shape[0], bs):
                yield x[i0:i0 + bs]

        Dnew, Inew = knn_ground_truth(xq, matrix_iterator(xb, 1000), k)

        # ids must match exactly; distances are only compared approximately
        # because the two code paths may round floats differently
        np.testing.assert_array_equal(Iref, Inew)
        np.testing.assert_almost_equal(Dref, Dnew, decimal=4)

    def do_test_range(self, metric):
        """Check that ``range_ground_truth`` matches ``IndexFlat.range_search``.

        Args:
            metric: faiss metric type (e.g. ``faiss.METRIC_L2`` or
                ``faiss.METRIC_INNER_PRODUCT``) used for both the reference
                index and the blockwise ground-truth computation.
        """
        ds = datasets.SyntheticDataset(32, 0, 1000, 10)
        xq = ds.get_queries()
        xb = ds.get_database()
        # derive the search radius from the 10-NN results (mean value of
        # the last neighbor column) so it is on the scale of this data
        D, _ = faiss.knn(xq, xb, 10, metric=metric)
        threshold = float(D[:, -1].mean())

        # take the dimension from the dataset instead of repeating the
        # literal, so the index can never disagree with the data
        index = faiss.IndexFlat(ds.d, metric)
        index.add(xb)
        ref_lims, ref_D, ref_I = index.range_search(xq, threshold)

        # same query, database fed in blocks of 100 vectors
        new_lims, new_D, new_I = range_ground_truth(
            xq, ds.database_iterator(bs=100), threshold,
            metric_type=metric)

        evaluation.test_ref_range_results(
            ref_lims, ref_D, ref_I,
            new_lims, new_D, new_I
        )

    def test_range_L2(self):
        """Range ground truth under the L2 metric."""
        self.do_test_range(faiss.METRIC_L2)

    def test_range_IP(self):
        """Range ground truth under the inner-product metric."""
        self.do_test_range(faiss.METRIC_INNER_PRODUCT)
64