1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8 #include <cstdio>
9 #include <cstdlib>
10 #include <random>
11
12 #include <gtest/gtest.h>
13
14 #include <faiss/IndexFlat.h>
15 #include <faiss/IndexIVFPQ.h>
16 #include <faiss/index_io.h>
17
// Accuracy smoke test for IndexIVFPQ.
//
// Trains an IVFPQ index on uniform random vectors, adds a random database to
// both the IVFPQ index and a brute-force IndexFlatL2 (the ground truth), and
// then checks that, for more than 40% of the queries, the exact nearest
// neighbor reported by the flat index appears among the k results returned by
// the IVFPQ index.
//
// The random generator is default-constructed (fixed default seed), so the
// generated data is the same on every run.
TEST(IVFPQ, accuracy) {
    // dimension of the vectors to index
    const int d = 64;

    // size of the database we plan to index
    const size_t nb = 1000;

    // number of training vectors (could be the database)
    const size_t nt = 1500;

    // number of query vectors; kept unsigned so that the size arithmetic and
    // loop bounds below never mix signed and unsigned operands
    const size_t nq = 200;

    // a reasonable number of centroids to index nb vectors
    const int ncentroids = 25;

    // make the index object; the trailing 16 and 8 are the product-quantizer
    // settings (see the IndexIVFPQ constructor documentation)
    faiss::IndexFlatL2 coarse_quantizer(d);
    faiss::IndexIVFPQ index(&coarse_quantizer, d, ncentroids, 16, 8);

    // index that gives the ground-truth
    faiss::IndexFlatL2 index_gt(d);

    std::mt19937 rng;
    std::uniform_real_distribution<> distrib;

    // Draws n vectors of dimension d from the shared generator. The calls
    // below must stay in the order train -> database -> queries, because all
    // three consume the same random stream.
    auto random_vectors = [&](size_t n) {
        std::vector<float> vecs(n * d);
        for (float& x : vecs) {
            // the distribution yields double; narrow explicitly
            x = static_cast<float>(distrib(rng));
        }
        return vecs;
    };

    { // training
        const std::vector<float> trainvecs = random_vectors(nt);
        index.verbose = true;
        index.train(nt, trainvecs.data());
    }

    { // populating the database
        const std::vector<float> database = random_vectors(nb);
        index.add(nb, database.data());
        index_gt.add(nb, database.data());
    }

    { // searching the database
        const std::vector<float> queries = random_vectors(nq);

        // exact nearest neighbor of each query
        std::vector<faiss::Index::idx_t> gt_nns(nq);
        std::vector<float> gt_dis(nq);
        index_gt.search(nq, queries.data(), 1, gt_dis.data(), gt_nns.data());

        // approximate k nearest neighbors of each query
        index.nprobe = 5;
        const size_t k = 5;
        std::vector<faiss::Index::idx_t> nns(k * nq);
        std::vector<float> dis(k * nq);
        index.search(nq, queries.data(), k, dis.data(), nns.data());

        // count the result slots that hold the true nearest neighbor
        int n_ok = 0;
        for (size_t q = 0; q < nq; q++) {
            for (size_t i = 0; i < k; i++) {
                if (nns[q * k + i] == gt_nns[q]) {
                    n_ok++;
                }
            }
        }
        EXPECT_GT(n_ok, nq * 0.4);
    }
}
95