1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/gpu/impl/Distance.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/CopyUtils.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>

#include <iostream>
16
17 namespace faiss {
18 namespace gpu {
19
/// Brute-force k-nearest-neighbor search (or, when `args.k == -1`, an
/// all-pairwise distance computation) over vectors and queries whose element
/// type is `T`.
///
/// The raw pointers in `args` are wrapped with `toDeviceTemporary` before the
/// device routines run, and results are handed back with `fromDevice`.
/// NOTE(review): those helpers presumably make the call work for both host-
/// and device-resident buffers; confirm against CopyUtils.cuh.
///
/// @param prov  Provider of the GPU resources used for the call; must be
///              non-null.
/// @param args  Description of the search: input data, layout flags, metric,
///              k, and output buffers. `args.vectors` / `args.queries` are
///              reinterpreted as `const T*`.
///
/// Throws (through the FAISS_THROW_* macros) if any argument fails validation,
/// including an `outIndicesType` other than I64 / I32 when `k != -1`.
template <typename T>
void bfKnnConvert(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
    // k == -1 selects the "all pairwise distances" mode, in which no indices
    // are produced and outIndices / outIndicesType are not consulted
    const bool allPairwise = (args.k == -1);

    // Validate the input data
    FAISS_THROW_IF_NOT_MSG(
            prov, "bfKnn: GpuResourcesProvider must be provided (passed null)");
    FAISS_THROW_IF_NOT_MSG(
            args.k > 0 || allPairwise,
            "bfKnn: k must be > 0 for top-k reduction, "
            "or -1 for all pairwise distances");
    FAISS_THROW_IF_NOT_MSG(args.dims > 0, "bfKnn: dims must be > 0");
    FAISS_THROW_IF_NOT_MSG(
            args.numVectors > 0, "bfKnn: numVectors must be > 0");
    FAISS_THROW_IF_NOT_MSG(
            args.vectors, "bfKnn: vectors must be provided (passed null)");
    FAISS_THROW_IF_NOT_MSG(
            args.numQueries > 0, "bfKnn: numQueries must be > 0");
    FAISS_THROW_IF_NOT_MSG(
            args.queries, "bfKnn: queries must be provided (passed null)");
    FAISS_THROW_IF_NOT_MSG(
            args.outDistances,
            "bfKnn: outDistances must be provided (passed null)");
    FAISS_THROW_IF_NOT_MSG(
            args.outIndices || allPairwise,
            "bfKnn: outIndices must be provided (passed null)");

    // Reject an unsupported index type now, before any buffer is staged or
    // any work is launched, rather than after the inputs have been processed
    if (!allPairwise && args.outIndicesType != IndicesDataType::I64 &&
        args.outIndicesType != IndicesDataType::I32) {
        FAISS_THROW_MSG("unknown outIndicesType");
    }

    // Don't let the resources go out of scope
    auto resImpl = prov->getResources();
    auto res = resImpl.get();
    auto device = getCurrentDevice();
    auto stream = res->getDefaultStreamCurrentDevice();

    // Row-major data is viewed as (count x dims); otherwise as (dims x count)
    auto tVectors = toDeviceTemporary<T, 2>(
            res,
            device,
            const_cast<T*>(reinterpret_cast<const T*>(args.vectors)),
            stream,
            {args.vectorsRowMajor ? args.numVectors : args.dims,
             args.vectorsRowMajor ? args.dims : args.numVectors});
    auto tQueries = toDeviceTemporary<T, 2>(
            res,
            device,
            const_cast<T*>(reinterpret_cast<const T*>(args.queries)),
            stream,
            {args.queriesRowMajor ? args.numQueries : args.dims,
             args.queriesRowMajor ? args.dims : args.numQueries});

    // Optional caller-supplied per-vector norms (one float per vector); the
    // tensor stays default-constructed when none were given
    DeviceTensor<float, 1, true> tVectorNorms;
    if (args.vectorNorms) {
        tVectorNorms = toDeviceTemporary<float, 1>(
                res,
                device,
                const_cast<float*>(args.vectorNorms),
                stream,
                {args.numVectors});
    }
    auto vectorNormsArg = args.vectorNorms ? &tVectorNorms : nullptr;

    // One row per query; numVectors columns in all-pairwise mode, else k
    auto tOutDistances = toDeviceTemporary<float, 2>(
            res,
            device,
            args.outDistances,
            stream,
            {args.numQueries, allPairwise ? args.numVectors : args.k});

    // Shared top-k invocation; the two index-type paths below differ only in
    // which i32 index tensor receives the result
    auto runTopK = [&](DeviceTensor<int, 2, true>& outIntIndices) {
        // Since we've guaranteed that all arguments are on device, call the
        // implementation
        bfKnnOnDevice<T>(
                res,
                device,
                stream,
                tVectors,
                args.vectorsRowMajor,
                vectorNormsArg,
                tQueries,
                args.queriesRowMajor,
                args.k,
                args.metric,
                args.metricArg,
                tOutDistances,
                outIntIndices,
                args.ignoreOutDistances);
    };

    if (allPairwise) {
        // Reporting all pairwise distances
        allPairwiseDistanceOnDevice<T>(
                res,
                device,
                stream,
                tVectors,
                args.vectorsRowMajor,
                vectorNormsArg,
                tQueries,
                args.queriesRowMajor,
                args.metric,
                args.metricArg,
                tOutDistances);
    } else if (args.outIndicesType == IndicesDataType::I64) {
        // The brute-force API only supports an interface for i32 indices,
        // so we must create an output i32 buffer then convert back
        DeviceTensor<int, 2, true> tOutIntIndices(
                res,
                makeTempAlloc(AllocType::Other, stream),
                {args.numQueries, args.k});

        runTopK(tOutIntIndices);

        // Destination for the widened indices
        auto tOutIndices = toDeviceTemporary<Index::idx_t, 2>(
                res,
                device,
                (Index::idx_t*)args.outIndices,
                stream,
                {args.numQueries, args.k});

        // Convert int to idx_t
        convertTensor<int, Index::idx_t, 2>(
                stream, tOutIntIndices, tOutIndices);

        // Copy back if necessary
        fromDevice<Index::idx_t, 2>(
                tOutIndices, (Index::idx_t*)args.outIndices, stream);
    } else {
        // Only IndicesDataType::I32 can reach here; any other value was
        // rejected during validation above.
        // We can use the brute-force API directly, as it takes i32 indices
        // FIXME: convert to int32_t everywhere?
        static_assert(sizeof(int) == 4, "");

        auto tOutIntIndices = toDeviceTemporary<int, 2>(
                res,
                device,
                (int*)args.outIndices,
                stream,
                {args.numQueries, args.k});

        runTopK(tOutIntIndices);

        // Copy back if necessary
        fromDevice<int, 2>(tOutIntIndices, (int*)args.outIndices, stream);
    }

    // Copy distances back if necessary
    fromDevice<float, 2>(tOutDistances, args.outDistances, stream);
}
176
/// Public brute-force k-NN entry point: selects the bfKnnConvert<T>
/// instantiation matching the element type named in `args`.
///
/// @param res   Provider of the GPU resources; forwarded unchanged.
/// @param args  Search description; forwarded unchanged. `args.vectorType`
///              and `args.queryType` must be equal, and must be either
///              DistanceDataType::F32 (dispatches to `float`) or
///              DistanceDataType::F16 (dispatches to `half`).
///
/// Throws (through the FAISS_THROW_* macros) if the two types differ or the
/// type is not one of the two handled here.
void bfKnn(GpuResourcesProvider* res, const GpuDistanceParams& args) {
    // For now, both vectors and queries must be of the same data type
    FAISS_THROW_IF_NOT_MSG(
            args.vectorType == args.queryType,
            "limitation: both vectorType and queryType must currently "
            "be the same (F32 or F16)");

    if (args.vectorType == DistanceDataType::F32) {
        bfKnnConvert<float>(res, args);
    } else if (args.vectorType == DistanceDataType::F16) {
        bfKnnConvert<half>(res, args);
    } else {
        FAISS_THROW_MSG("unknown vectorType");
    }
}
192
/// Legacy entry point, kept for existing callers. Writes a deprecation notice
/// to std::cerr on every call, packs its arguments into a GpuDistanceParams,
/// and forwards to bfKnn().
///
/// Any GpuDistanceParams field not listed below is left at whatever value the
/// struct's default construction gives it.
void bruteForceKnn(
        GpuResourcesProvider* res,
        faiss::MetricType metric,
        // A region of memory holding numVectors x dims floats; dims is
        // innermost in the row-major case (the layout flag below is
        // forwarded to bfKnn unchanged)
        const float* vectors,
        bool vectorsRowMajor,
        int numVectors,
        // A region of memory holding numQueries x dims floats; dims is
        // innermost in the row-major case (the layout flag below is
        // forwarded to bfKnn unchanged)
        const float* queries,
        bool queriesRowMajor,
        int numQueries,
        int dims,
        int k,
        // A region of memory size numQueries x k, with k
        // innermost
        float* outDistances,
        // A region of memory size numQueries x k, with k
        // innermost
        Index::idx_t* outIndices) {
    std::cerr << "bruteForceKnn is deprecated; call bfKnn instead" << std::endl;

    GpuDistanceParams args;
    args.metric = metric;
    args.k = k;
    args.dims = dims;

    // This signature only accepts float data, so state the element types
    // explicitly instead of depending on GpuDistanceParams' defaults
    args.vectors = vectors;
    args.vectorType = DistanceDataType::F32;
    args.vectorsRowMajor = vectorsRowMajor;
    args.numVectors = numVectors;

    args.queries = queries;
    args.queryType = DistanceDataType::F32;
    args.queriesRowMajor = queriesRowMajor;
    args.numQueries = numQueries;

    args.outDistances = outDistances;

    // Indices are handed in as Index::idx_t*, which is the pointer type the
    // I64 path of the implementation writes through
    args.outIndices = outIndices;
    args.outIndicesType = IndicesDataType::I64;

    bfKnn(res, args);
}
232
233 } // namespace gpu
234 } // namespace faiss
235