1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include <cstdio>
9 #include <cstdlib>
10 #include <random>
11 
12 #include <gtest/gtest.h>
13 
14 #include <faiss/IndexFlat.h>
15 #include <faiss/IndexIVFPQ.h>
16 #include <faiss/index_io.h>
17 
TEST(IVFPQ,accuracy)18 TEST(IVFPQ, accuracy) {
19     // dimension of the vectors to index
20     int d = 64;
21 
22     // size of the database we plan to index
23     size_t nb = 1000;
24 
25     // make a set of nt training vectors in the unit cube
26     // (could be the database)
27     size_t nt = 1500;
28 
29     // make the index object and train it
30     faiss::IndexFlatL2 coarse_quantizer(d);
31 
32     // a reasonable number of cetroids to index nb vectors
33     int ncentroids = 25;
34 
35     faiss::IndexIVFPQ index(&coarse_quantizer, d, ncentroids, 16, 8);
36 
37     // index that gives the ground-truth
38     faiss::IndexFlatL2 index_gt(d);
39 
40     std::mt19937 rng;
41     std::uniform_real_distribution<> distrib;
42 
43     { // training
44 
45         std::vector<float> trainvecs(nt * d);
46         for (size_t i = 0; i < nt * d; i++) {
47             trainvecs[i] = distrib(rng);
48         }
49         index.verbose = true;
50         index.train(nt, trainvecs.data());
51     }
52 
53     { // populating the database
54 
55         std::vector<float> database(nb * d);
56         for (size_t i = 0; i < nb * d; i++) {
57             database[i] = distrib(rng);
58         }
59 
60         index.add(nb, database.data());
61         index_gt.add(nb, database.data());
62     }
63 
64     int nq = 200;
65     int n_ok;
66 
67     { // searching the database
68 
69         std::vector<float> queries(nq * d);
70         for (size_t i = 0; i < nq * d; i++) {
71             queries[i] = distrib(rng);
72         }
73 
74         std::vector<faiss::Index::idx_t> gt_nns(nq);
75         std::vector<float> gt_dis(nq);
76 
77         index_gt.search(nq, queries.data(), 1, gt_dis.data(), gt_nns.data());
78 
79         index.nprobe = 5;
80         int k = 5;
81         std::vector<faiss::Index::idx_t> nns(k * nq);
82         std::vector<float> dis(k * nq);
83 
84         index.search(nq, queries.data(), k, dis.data(), nns.data());
85 
86         n_ok = 0;
87         for (int q = 0; q < nq; q++) {
88             for (int i = 0; i < k; i++)
89                 if (nns[q * k + i] == gt_nns[q])
90                     n_ok++;
91         }
92         EXPECT_GT(n_ok, nq * 0.4);
93     }
94 }
95