1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include <faiss/impl/FaissAssert.h>
9 #include <exception>
10 #include <iostream>
11 
12 namespace faiss {
13 
14 template <typename IndexT>
ThreadedIndex(bool threaded)15 ThreadedIndex<IndexT>::ThreadedIndex(bool threaded)
16         // 0 is default dimension
17         : ThreadedIndex(0, threaded) {}
18 
19 template <typename IndexT>
ThreadedIndex(int d,bool threaded)20 ThreadedIndex<IndexT>::ThreadedIndex(int d, bool threaded)
21         : IndexT(d), own_fields(false), isThreaded_(threaded) {}
22 
23 template <typename IndexT>
~ThreadedIndex()24 ThreadedIndex<IndexT>::~ThreadedIndex() {
25     for (auto& p : indices_) {
26         if (isThreaded_) {
27             // should have worker thread
28             FAISS_ASSERT((bool)p.second);
29 
30             // This will also flush all pending work
31             p.second->stop();
32             p.second->waitForThreadExit();
33         } else {
34             // should not have worker thread
35             FAISS_ASSERT(!(bool)p.second);
36         }
37 
38         if (own_fields) {
39             delete p.first;
40         }
41     }
42 }
43 
44 template <typename IndexT>
addIndex(IndexT * index)45 void ThreadedIndex<IndexT>::addIndex(IndexT* index) {
46     // We inherit the dimension from the first index added to us if we don't
47     // have a set dimension
48     if (indices_.empty() && this->d == 0) {
49         this->d = index->d;
50     }
51 
52     // The new index must match our set dimension
53     FAISS_THROW_IF_NOT_FMT(
54             this->d == index->d,
55             "addIndex: dimension mismatch for "
56             "newly added index; expecting dim %d, "
57             "new index has dim %d",
58             this->d,
59             index->d);
60 
61     if (!indices_.empty()) {
62         auto& existing = indices_.front().first;
63 
64         FAISS_THROW_IF_NOT_MSG(
65                 index->metric_type == existing->metric_type,
66                 "addIndex: newly added index is "
67                 "of different metric type than old index");
68 
69         // Make sure this index is not duplicated
70         for (auto& p : indices_) {
71             FAISS_THROW_IF_NOT_MSG(
72                     p.first != index,
73                     "addIndex: attempting to add index "
74                     "that is already in the collection");
75         }
76     }
77 
78     indices_.emplace_back(std::make_pair(
79             index,
80             std::unique_ptr<WorkerThread>(
81                     isThreaded_ ? new WorkerThread : nullptr)));
82 
83     onAfterAddIndex(index);
84 }
85 
86 template <typename IndexT>
removeIndex(IndexT * index)87 void ThreadedIndex<IndexT>::removeIndex(IndexT* index) {
88     for (auto it = indices_.begin(); it != indices_.end(); ++it) {
89         if (it->first == index) {
90             // This is our index; stop the worker thread before removing it,
91             // to ensure that it has finished before function exit
92             if (isThreaded_) {
93                 // should have worker thread
94                 FAISS_ASSERT((bool)it->second);
95                 it->second->stop();
96                 it->second->waitForThreadExit();
97             } else {
98                 // should not have worker thread
99                 FAISS_ASSERT(!(bool)it->second);
100             }
101 
102             indices_.erase(it);
103             onAfterRemoveIndex(index);
104 
105             if (own_fields) {
106                 delete index;
107             }
108 
109             return;
110         }
111     }
112 
113     // could not find our index
114     FAISS_THROW_MSG("IndexReplicas::removeIndex: index not found");
115 }
116 
117 template <typename IndexT>
runOnIndex(std::function<void (int,IndexT *)> f)118 void ThreadedIndex<IndexT>::runOnIndex(std::function<void(int, IndexT*)> f) {
119     if (isThreaded_) {
120         std::vector<std::future<bool>> v;
121 
122         for (int i = 0; i < this->indices_.size(); ++i) {
123             auto& p = this->indices_[i];
124             auto indexPtr = p.first;
125             v.emplace_back(
126                     p.second->add([f, i, indexPtr]() { f(i, indexPtr); }));
127         }
128 
129         waitAndHandleFutures(v);
130     } else {
131         // Multiple exceptions may be thrown; gather them as we encounter them,
132         // while letting everything else run to completion
133         std::vector<std::pair<int, std::exception_ptr>> exceptions;
134 
135         for (int i = 0; i < this->indices_.size(); ++i) {
136             auto& p = this->indices_[i];
137             try {
138                 f(i, p.first);
139             } catch (...) {
140                 exceptions.emplace_back(
141                         std::make_pair(i, std::current_exception()));
142             }
143         }
144 
145         handleExceptions(exceptions);
146     }
147 }
148 
149 template <typename IndexT>
runOnIndex(std::function<void (int,const IndexT *)> f)150 void ThreadedIndex<IndexT>::runOnIndex(
151         std::function<void(int, const IndexT*)> f) const {
152     const_cast<ThreadedIndex<IndexT>*>(this)->runOnIndex(
153             [f](int i, IndexT* idx) { f(i, idx); });
154 }
155 
156 template <typename IndexT>
reset()157 void ThreadedIndex<IndexT>::reset() {
158     runOnIndex([](int, IndexT* index) { index->reset(); });
159     this->ntotal = 0;
160     this->is_trained = false;
161 }
162 
163 template <typename IndexT>
onAfterAddIndex(IndexT * index)164 void ThreadedIndex<IndexT>::onAfterAddIndex(IndexT* index) {}
165 
166 template <typename IndexT>
onAfterRemoveIndex(IndexT * index)167 void ThreadedIndex<IndexT>::onAfterRemoveIndex(IndexT* index) {}
168 
169 template <typename IndexT>
waitAndHandleFutures(std::vector<std::future<bool>> & v)170 void ThreadedIndex<IndexT>::waitAndHandleFutures(
171         std::vector<std::future<bool>>& v) {
172     // Blocking wait for completion for all of the indices, capturing any
173     // exceptions that are generated
174     std::vector<std::pair<int, std::exception_ptr>> exceptions;
175 
176     for (int i = 0; i < v.size(); ++i) {
177         auto& fut = v[i];
178 
179         try {
180             fut.get();
181         } catch (...) {
182             exceptions.emplace_back(
183                     std::make_pair(i, std::current_exception()));
184         }
185     }
186 
187     handleExceptions(exceptions);
188 }
189 
190 } // namespace faiss
191