1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 // -*- c++ -*-
9 
10 #include <faiss/impl/NNDescent.h>
11 
12 #include <mutex>
13 #include <string>
14 
15 #include <faiss/impl/AuxIndexStructures.h>
16 
17 namespace faiss {
18 
19 using LockGuard = std::lock_guard<std::mutex>;
20 
21 namespace nndescent {
22 
23 void gen_random(std::mt19937& rng, int* addr, const int size, const int N);
24 
Nhood(int l,int s,std::mt19937 & rng,int N)25 Nhood::Nhood(int l, int s, std::mt19937& rng, int N) {
26     M = s;
27     nn_new.resize(s * 2);
28     gen_random(rng, nn_new.data(), (int)nn_new.size(), N);
29 }
30 
31 /// Copy operator
operator =(const Nhood & other)32 Nhood& Nhood::operator=(const Nhood& other) {
33     M = other.M;
34     std::copy(
35             other.nn_new.begin(),
36             other.nn_new.end(),
37             std::back_inserter(nn_new));
38     nn_new.reserve(other.nn_new.capacity());
39     pool.reserve(other.pool.capacity());
40     return *this;
41 }
42 
43 /// Copy constructor
Nhood(const Nhood & other)44 Nhood::Nhood(const Nhood& other) {
45     M = other.M;
46     std::copy(
47             other.nn_new.begin(),
48             other.nn_new.end(),
49             std::back_inserter(nn_new));
50     nn_new.reserve(other.nn_new.capacity());
51     pool.reserve(other.pool.capacity());
52 }
53 
54 /// Insert a point into the candidate pool
insert(int id,float dist)55 void Nhood::insert(int id, float dist) {
56     LockGuard guard(lock);
57     if (dist > pool.front().distance)
58         return;
59     for (int i = 0; i < pool.size(); i++) {
60         if (id == pool[i].id)
61             return;
62     }
63     if (pool.size() < pool.capacity()) {
64         pool.push_back(Neighbor(id, dist, true));
65         std::push_heap(pool.begin(), pool.end());
66     } else {
67         std::pop_heap(pool.begin(), pool.end());
68         pool[pool.size() - 1] = Neighbor(id, dist, true);
69         std::push_heap(pool.begin(), pool.end());
70     }
71 }
72 
73 /// In local join, two objects are compared only if at least
74 /// one of them is new.
75 template <typename C>
join(C callback) const76 void Nhood::join(C callback) const {
77     for (int const i : nn_new) {
78         for (int const j : nn_new) {
79             if (i < j) {
80                 callback(i, j);
81             }
82         }
83         for (int j : nn_old) {
84             callback(i, j);
85         }
86     }
87 }
88 
/// Fill addr[0 .. size) with pseudo-random IDs in [0, N).
///
/// `size` values are drawn, sorted, and ties are bumped upwards so the
/// sequence becomes strictly increasing; the result is then rotated by a
/// random offset modulo N. The generated IDs are pairwise distinct whenever
/// size <= N. If size > N distinct IDs are impossible: the values are still
/// reduced into [0, N) but repeat.
///
/// For N > size the produced sequence is identical to what
/// `rng() % (N - size)` yields; size >= N is handled by drawing from a
/// range of one value instead of dividing by zero (or by a negative number).
///
/// @param rng   random engine (its state is advanced)
/// @param addr  output buffer with room for `size` ints
/// @param size  number of IDs to generate
/// @param N     exclusive upper bound of the IDs; if N <= 0 no valid ID
///              exists and the buffer is left untouched
void gen_random(std::mt19937& rng, int* addr, const int size, const int N) {
    if (N <= 0) {
        return;
    }

    // Width of the range the raw draws come from; never below 1 so the
    // modulo is always well defined.
    const int span = std::max(1, N - size);
    for (int i = 0; i < size; ++i) {
        addr[i] = static_cast<int>(rng() % span);
    }

    std::sort(addr, addr + size);

    // Make the sorted sequence strictly increasing. The largest value is at
    // most (span - 1) + (size - 1), which stays below N when size <= N.
    for (int i = 1; i < size; ++i) {
        if (addr[i] <= addr[i - 1]) {
            addr[i] = addr[i - 1] + 1;
        }
    }

    // Rotate by a random offset.
    const int off = static_cast<int>(rng() % N);
    for (int i = 0; i < size; ++i) {
        addr[i] = (addr[i] + off) % N;
    }
}
104 
105 // Insert a new point into the candidate pool in ascending order
insert_into_pool(Neighbor * addr,int size,Neighbor nn)106 int insert_into_pool(Neighbor* addr, int size, Neighbor nn) {
107     // find the location to insert
108     int left = 0, right = size - 1;
109     if (addr[left].distance > nn.distance) {
110         memmove((char*)&addr[left + 1], &addr[left], size * sizeof(Neighbor));
111         addr[left] = nn;
112         return left;
113     }
114     if (addr[right].distance < nn.distance) {
115         addr[size] = nn;
116         return size;
117     }
118     while (left < right - 1) {
119         int mid = (left + right) / 2;
120         if (addr[mid].distance > nn.distance)
121             right = mid;
122         else
123             left = mid;
124     }
125     // check equal ID
126 
127     while (left > 0) {
128         if (addr[left].distance < nn.distance)
129             break;
130         if (addr[left].id == nn.id)
131             return size + 1;
132         left--;
133     }
134     if (addr[left].id == nn.id || addr[right].id == nn.id)
135         return size + 1;
136     memmove((char*)&addr[right + 1],
137             &addr[right],
138             (size - right) * sizeof(Neighbor));
139     addr[right] = nn;
140     return right;
141 }
142 
143 } // namespace nndescent
144 
145 using namespace nndescent;
146 
147 constexpr int NUM_EVAL_POINTS = 100;
148 
NNDescent(const int d,const int K)149 NNDescent::NNDescent(const int d, const int K) : K(K), random_seed(2021), d(d) {
150     ntotal = 0;
151     has_built = false;
152     S = 10;
153     R = 100;
154     L = K + 50;
155     iter = 10;
156     search_L = 0;
157 }
158 
~NNDescent()159 NNDescent::~NNDescent() {}
160 
join(DistanceComputer & qdis)161 void NNDescent::join(DistanceComputer& qdis) {
162 #pragma omp parallel for default(shared) schedule(dynamic, 100)
163     for (int n = 0; n < ntotal; n++) {
164         graph[n].join([&](int i, int j) {
165             if (i != j) {
166                 float dist = qdis.symmetric_dis(i, j);
167                 graph[i].insert(j, dist);
168                 graph[j].insert(i, dist);
169             }
170         });
171     }
172 }
173 
174 /// Sample neighbors for each node to peform local join later
175 /// Store them in nn_new and nn_old
update()176 void NNDescent::update() {
177     // Step 1.
178     // Clear all nn_new and nn_old
179 #pragma omp parallel for
180     for (int i = 0; i < ntotal; i++) {
181         std::vector<int>().swap(graph[i].nn_new);
182         std::vector<int>().swap(graph[i].nn_old);
183     }
184 
185     // Step 2.
186     // Compute the number of neighbors which is new i.e. flag is true
187     // in the candidate pool. This must not exceed the sample number S.
188     // That means We only select S new neighbors.
189 #pragma omp parallel for
190     for (int n = 0; n < ntotal; ++n) {
191         auto& nn = graph[n];
192         std::sort(nn.pool.begin(), nn.pool.end());
193 
194         if (nn.pool.size() > L)
195             nn.pool.resize(L);
196         nn.pool.reserve(L); // keep the pool size be L
197 
198         int maxl = std::min(nn.M + S, (int)nn.pool.size());
199         int c = 0;
200         int l = 0;
201 
202         while ((l < maxl) && (c < S)) {
203             if (nn.pool[l].flag)
204                 ++c;
205             ++l;
206         }
207         nn.M = l;
208     }
209 
210     // Step 3.
211     // Find reverse links for each node
212     // Randomly choose R reverse links.
213 #pragma omp parallel
214     {
215         std::mt19937 rng(random_seed * 5081 + omp_get_thread_num());
216 #pragma omp for
217         for (int n = 0; n < ntotal; ++n) {
218             auto& node = graph[n];
219             auto& nn_new = node.nn_new;
220             auto& nn_old = node.nn_old;
221 
222             for (int l = 0; l < node.M; ++l) {
223                 auto& nn = node.pool[l];
224                 auto& other = graph[nn.id]; // the other side of the edge
225 
226                 if (nn.flag) { // the node is inserted newly
227                     // push the neighbor into nn_new
228                     nn_new.push_back(nn.id);
229                     // push itself into other.rnn_new if it is not in
230                     // the candidate pool of the other side
231                     if (nn.distance > other.pool.back().distance) {
232                         LockGuard guard(other.lock);
233                         if (other.rnn_new.size() < R) {
234                             other.rnn_new.push_back(n);
235                         } else {
236                             int pos = rng() % R;
237                             other.rnn_new[pos] = n;
238                         }
239                     }
240                     nn.flag = false;
241 
242                 } else { // the node is old
243                     // push the neighbor into nn_old
244                     nn_old.push_back(nn.id);
245                     // push itself into other.rnn_old if it is not in
246                     // the candidate pool of the other side
247                     if (nn.distance > other.pool.back().distance) {
248                         LockGuard guard(other.lock);
249                         if (other.rnn_old.size() < R) {
250                             other.rnn_old.push_back(n);
251                         } else {
252                             int pos = rng() % R;
253                             other.rnn_old[pos] = n;
254                         }
255                     }
256                 }
257             }
258             // make heap to join later (in join() function)
259             std::make_heap(node.pool.begin(), node.pool.end());
260         }
261     }
262 
263     // Step 4.
264     // Combine the forward and the reverse links
265     // R = 0 means no reverse links are used.
266 #pragma omp parallel for
267     for (int i = 0; i < ntotal; ++i) {
268         auto& nn_new = graph[i].nn_new;
269         auto& nn_old = graph[i].nn_old;
270         auto& rnn_new = graph[i].rnn_new;
271         auto& rnn_old = graph[i].rnn_old;
272 
273         nn_new.insert(nn_new.end(), rnn_new.begin(), rnn_new.end());
274         nn_old.insert(nn_old.end(), rnn_old.begin(), rnn_old.end());
275         if (nn_old.size() > R * 2) {
276             nn_old.resize(R * 2);
277             nn_old.reserve(R * 2);
278         }
279 
280         std::vector<int>().swap(graph[i].rnn_new);
281         std::vector<int>().swap(graph[i].rnn_old);
282     }
283 }
284 
nndescent(DistanceComputer & qdis,bool verbose)285 void NNDescent::nndescent(DistanceComputer& qdis, bool verbose) {
286     int num_eval_points = std::min(NUM_EVAL_POINTS, ntotal);
287     std::vector<int> eval_points(num_eval_points);
288     std::vector<std::vector<int>> acc_eval_set(num_eval_points);
289     std::mt19937 rng(random_seed * 6577 + omp_get_thread_num());
290     gen_random(rng, eval_points.data(), eval_points.size(), ntotal);
291     generate_eval_set(qdis, eval_points, acc_eval_set, ntotal);
292     for (int it = 0; it < iter; it++) {
293         join(qdis);
294         update();
295 
296         if (verbose) {
297             float recall = eval_recall(eval_points, acc_eval_set);
298             printf("Iter: %d, recall@%d: %lf\n", it, K, recall);
299         }
300     }
301 }
302 
303 /// Sample a small number of points to evaluate the quality of KNNG built
generate_eval_set(DistanceComputer & qdis,std::vector<int> & c,std::vector<std::vector<int>> & v,int N)304 void NNDescent::generate_eval_set(
305         DistanceComputer& qdis,
306         std::vector<int>& c,
307         std::vector<std::vector<int>>& v,
308         int N) {
309 #pragma omp parallel for
310     for (int i = 0; i < c.size(); i++) {
311         std::vector<Neighbor> tmp;
312         for (int j = 0; j < N; j++) {
313             if (i == j)
314                 continue; // skip itself
315             float dist = qdis.symmetric_dis(c[i], j);
316             tmp.push_back(Neighbor(j, dist, true));
317         }
318 
319         std::partial_sort(tmp.begin(), tmp.begin() + K, tmp.end());
320         for (int j = 0; j < K; j++) {
321             v[i].push_back(tmp[j].id);
322         }
323     }
324 }
325 
326 /// Evaluate the quality of KNNG built
eval_recall(std::vector<int> & eval_points,std::vector<std::vector<int>> & acc_eval_set)327 float NNDescent::eval_recall(
328         std::vector<int>& eval_points,
329         std::vector<std::vector<int>>& acc_eval_set) {
330     float mean_acc = 0.0f;
331     for (size_t i = 0; i < eval_points.size(); i++) {
332         float acc = 0;
333         std::vector<Neighbor>& g = graph[eval_points[i]].pool;
334         std::vector<int>& v = acc_eval_set[i];
335         for (size_t j = 0; j < g.size(); j++) {
336             for (size_t k = 0; k < v.size(); k++) {
337                 if (g[j].id == v[k]) {
338                     acc++;
339                     break;
340                 }
341             }
342         }
343         mean_acc += acc / v.size();
344     }
345     return mean_acc / eval_points.size();
346 }
347 
348 /// Initialize the KNN graph randomly
init_graph(DistanceComputer & qdis)349 void NNDescent::init_graph(DistanceComputer& qdis) {
350     graph.reserve(ntotal);
351     {
352         std::mt19937 rng(random_seed * 6007);
353         for (int i = 0; i < ntotal; i++) {
354             graph.push_back(Nhood(L, S, rng, (int)ntotal));
355         }
356     }
357 #pragma omp parallel
358     {
359         std::mt19937 rng(random_seed * 7741 + omp_get_thread_num());
360 #pragma omp for
361         for (int i = 0; i < ntotal; i++) {
362             std::vector<int> tmp(S);
363 
364             gen_random(rng, tmp.data(), S, ntotal);
365 
366             for (int j = 0; j < S; j++) {
367                 int id = tmp[j];
368                 if (id == i)
369                     continue;
370                 float dist = qdis.symmetric_dis(i, id);
371 
372                 graph[i].pool.push_back(Neighbor(id, dist, true));
373             }
374             std::make_heap(graph[i].pool.begin(), graph[i].pool.end());
375             graph[i].pool.reserve(L);
376         }
377     }
378 }
379 
build(DistanceComputer & qdis,const int n,bool verbose)380 void NNDescent::build(DistanceComputer& qdis, const int n, bool verbose) {
381     FAISS_THROW_IF_NOT_MSG(L >= K, "L should be >= K in NNDescent.build");
382 
383     if (verbose) {
384         printf("Parameters: K=%d, S=%d, R=%d, L=%d, iter=%d\n",
385                K,
386                S,
387                R,
388                L,
389                iter);
390     }
391 
392     ntotal = n;
393     init_graph(qdis);
394     nndescent(qdis, verbose);
395 
396     final_graph.resize(ntotal * K);
397 
398     // Store the neighbor link structure into final_graph
399     // Clear the old graph
400     for (int i = 0; i < ntotal; i++) {
401         std::sort(graph[i].pool.begin(), graph[i].pool.end());
402         for (int j = 0; j < K; j++) {
403             FAISS_ASSERT(graph[i].pool[j].id < ntotal);
404             final_graph[i * K + j] = graph[i].pool[j].id;
405         }
406     }
407     std::vector<Nhood>().swap(graph);
408     has_built = true;
409 
410     if (verbose) {
411         printf("Addes %d points into the index\n", ntotal);
412     }
413 }
414 
search(DistanceComputer & qdis,const int topk,idx_t * indices,float * dists,VisitedTable & vt) const415 void NNDescent::search(
416         DistanceComputer& qdis,
417         const int topk,
418         idx_t* indices,
419         float* dists,
420         VisitedTable& vt) const {
421     FAISS_THROW_IF_NOT_MSG(has_built, "The index is not build yet.");
422     int L = std::max(search_L, topk);
423 
424     // candidate pool, the K best items is the result.
425     std::vector<Neighbor> retset(L + 1);
426 
427     // Randomly choose L points to intialize the candidate pool
428     std::vector<int> init_ids(L);
429     std::mt19937 rng(random_seed);
430 
431     gen_random(rng, init_ids.data(), L, ntotal);
432     for (int i = 0; i < L; i++) {
433         int id = init_ids[i];
434         float dist = qdis(id);
435         retset[i] = Neighbor(id, dist, true);
436     }
437 
438     // Maintain the candidate pool in ascending order
439     std::sort(retset.begin(), retset.begin() + L);
440 
441     int k = 0;
442 
443     // Stop until the smallest position updated is >= L
444     while (k < L) {
445         int nk = L;
446 
447         if (retset[k].flag) {
448             retset[k].flag = false;
449             int n = retset[k].id;
450 
451             for (int m = 0; m < K; ++m) {
452                 int id = final_graph[n * K + m];
453                 if (vt.get(id))
454                     continue;
455 
456                 vt.set(id);
457                 float dist = qdis(id);
458                 if (dist >= retset[L - 1].distance)
459                     continue;
460 
461                 Neighbor nn(id, dist, true);
462                 int r = insert_into_pool(retset.data(), L, nn);
463 
464                 if (r < nk)
465                     nk = r;
466             }
467         }
468         if (nk <= k)
469             k = nk;
470         else
471             ++k;
472     }
473     for (size_t i = 0; i < topk; i++) {
474         indices[i] = retset[i].id;
475         dists[i] = retset[i].distance;
476     }
477 
478     vt.advance();
479 }
480 
reset()481 void NNDescent::reset() {
482     has_built = false;
483     ntotal = 0;
484     final_graph.resize(0);
485 }
486 
487 } // namespace faiss
488