1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8 #include <faiss/impl/FaissAssert.h>
9 #include <exception>
10 #include <iostream>
11
12 namespace faiss {
13
14 template <typename IndexT>
ThreadedIndex(bool threaded)15 ThreadedIndex<IndexT>::ThreadedIndex(bool threaded)
16 // 0 is default dimension
17 : ThreadedIndex(0, threaded) {}
18
19 template <typename IndexT>
ThreadedIndex(int d,bool threaded)20 ThreadedIndex<IndexT>::ThreadedIndex(int d, bool threaded)
21 : IndexT(d), own_fields(false), isThreaded_(threaded) {}
22
23 template <typename IndexT>
~ThreadedIndex()24 ThreadedIndex<IndexT>::~ThreadedIndex() {
25 for (auto& p : indices_) {
26 if (isThreaded_) {
27 // should have worker thread
28 FAISS_ASSERT((bool)p.second);
29
30 // This will also flush all pending work
31 p.second->stop();
32 p.second->waitForThreadExit();
33 } else {
34 // should not have worker thread
35 FAISS_ASSERT(!(bool)p.second);
36 }
37
38 if (own_fields) {
39 delete p.first;
40 }
41 }
42 }
43
44 template <typename IndexT>
addIndex(IndexT * index)45 void ThreadedIndex<IndexT>::addIndex(IndexT* index) {
46 // We inherit the dimension from the first index added to us if we don't
47 // have a set dimension
48 if (indices_.empty() && this->d == 0) {
49 this->d = index->d;
50 }
51
52 // The new index must match our set dimension
53 FAISS_THROW_IF_NOT_FMT(
54 this->d == index->d,
55 "addIndex: dimension mismatch for "
56 "newly added index; expecting dim %d, "
57 "new index has dim %d",
58 this->d,
59 index->d);
60
61 if (!indices_.empty()) {
62 auto& existing = indices_.front().first;
63
64 FAISS_THROW_IF_NOT_MSG(
65 index->metric_type == existing->metric_type,
66 "addIndex: newly added index is "
67 "of different metric type than old index");
68
69 // Make sure this index is not duplicated
70 for (auto& p : indices_) {
71 FAISS_THROW_IF_NOT_MSG(
72 p.first != index,
73 "addIndex: attempting to add index "
74 "that is already in the collection");
75 }
76 }
77
78 indices_.emplace_back(std::make_pair(
79 index,
80 std::unique_ptr<WorkerThread>(
81 isThreaded_ ? new WorkerThread : nullptr)));
82
83 onAfterAddIndex(index);
84 }
85
86 template <typename IndexT>
removeIndex(IndexT * index)87 void ThreadedIndex<IndexT>::removeIndex(IndexT* index) {
88 for (auto it = indices_.begin(); it != indices_.end(); ++it) {
89 if (it->first == index) {
90 // This is our index; stop the worker thread before removing it,
91 // to ensure that it has finished before function exit
92 if (isThreaded_) {
93 // should have worker thread
94 FAISS_ASSERT((bool)it->second);
95 it->second->stop();
96 it->second->waitForThreadExit();
97 } else {
98 // should not have worker thread
99 FAISS_ASSERT(!(bool)it->second);
100 }
101
102 indices_.erase(it);
103 onAfterRemoveIndex(index);
104
105 if (own_fields) {
106 delete index;
107 }
108
109 return;
110 }
111 }
112
113 // could not find our index
114 FAISS_THROW_MSG("IndexReplicas::removeIndex: index not found");
115 }
116
117 template <typename IndexT>
runOnIndex(std::function<void (int,IndexT *)> f)118 void ThreadedIndex<IndexT>::runOnIndex(std::function<void(int, IndexT*)> f) {
119 if (isThreaded_) {
120 std::vector<std::future<bool>> v;
121
122 for (int i = 0; i < this->indices_.size(); ++i) {
123 auto& p = this->indices_[i];
124 auto indexPtr = p.first;
125 v.emplace_back(
126 p.second->add([f, i, indexPtr]() { f(i, indexPtr); }));
127 }
128
129 waitAndHandleFutures(v);
130 } else {
131 // Multiple exceptions may be thrown; gather them as we encounter them,
132 // while letting everything else run to completion
133 std::vector<std::pair<int, std::exception_ptr>> exceptions;
134
135 for (int i = 0; i < this->indices_.size(); ++i) {
136 auto& p = this->indices_[i];
137 try {
138 f(i, p.first);
139 } catch (...) {
140 exceptions.emplace_back(
141 std::make_pair(i, std::current_exception()));
142 }
143 }
144
145 handleExceptions(exceptions);
146 }
147 }
148
149 template <typename IndexT>
runOnIndex(std::function<void (int,const IndexT *)> f)150 void ThreadedIndex<IndexT>::runOnIndex(
151 std::function<void(int, const IndexT*)> f) const {
152 const_cast<ThreadedIndex<IndexT>*>(this)->runOnIndex(
153 [f](int i, IndexT* idx) { f(i, idx); });
154 }
155
156 template <typename IndexT>
reset()157 void ThreadedIndex<IndexT>::reset() {
158 runOnIndex([](int, IndexT* index) { index->reset(); });
159 this->ntotal = 0;
160 this->is_trained = false;
161 }
162
163 template <typename IndexT>
onAfterAddIndex(IndexT * index)164 void ThreadedIndex<IndexT>::onAfterAddIndex(IndexT* index) {}
165
166 template <typename IndexT>
onAfterRemoveIndex(IndexT * index)167 void ThreadedIndex<IndexT>::onAfterRemoveIndex(IndexT* index) {}
168
169 template <typename IndexT>
waitAndHandleFutures(std::vector<std::future<bool>> & v)170 void ThreadedIndex<IndexT>::waitAndHandleFutures(
171 std::vector<std::future<bool>>& v) {
172 // Blocking wait for completion for all of the indices, capturing any
173 // exceptions that are generated
174 std::vector<std::pair<int, std::exception_ptr>> exceptions;
175
176 for (int i = 0; i < v.size(); ++i) {
177 auto& fut = v[i];
178
179 try {
180 fut.get();
181 } catch (...) {
182 exceptions.emplace_back(
183 std::make_pair(i, std::current_exception()));
184 }
185 }
186
187 handleExceptions(exceptions);
188 }
189
190 } // namespace faiss
191