1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include <faiss/gpu/GpuDistance.h>
9 #include <faiss/gpu/GpuResources.h>
10 #include <faiss/gpu/utils/DeviceUtils.h>
11 #include <faiss/impl/FaissAssert.h>
12 #include <faiss/gpu/impl/Distance.cuh>
13 #include <faiss/gpu/utils/ConversionOperators.cuh>
14 #include <faiss/gpu/utils/CopyUtils.cuh>
15 #include <faiss/gpu/utils/DeviceTensor.cuh>
16 
17 namespace faiss {
18 namespace gpu {
19 
/// Validates `args`, stages all inputs and outputs as device tensors on the
/// current device, runs either a top-k brute-force search (k > 0) or an
/// all-pairwise distance computation (k == -1), and copies the results back
/// to the caller-provided output buffers.
///
/// T is the element type of both `args.vectors` and `args.queries`; bfKnn
/// instantiates this with `float` (F32) and `half` (F16).
///
/// Output shapes:
///   - outDistances: numQueries x k (numQueries x numVectors when k == -1)
///   - outIndices:   numQueries x k, as int (I32) or Index::idx_t (I64);
///                   not written when k == -1
///
/// @param prov  provider of the GPU resources (default stream, temporary
///              memory) used for the computation
/// @param args  query description; see GpuDistanceParams
/// @throws FaissException on invalid arguments or an unknown outIndicesType
template <typename T>
void bfKnnConvert(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
    // Validate the input data
    FAISS_THROW_IF_NOT_MSG(
            args.k > 0 || args.k == -1,
            "bfKnn: k must be > 0 for top-k reduction, "
            "or -1 for all pairwise distances");
    FAISS_THROW_IF_NOT_MSG(args.dims > 0, "bfKnn: dims must be > 0");
    FAISS_THROW_IF_NOT_MSG(
            args.numVectors > 0, "bfKnn: numVectors must be > 0");
    FAISS_THROW_IF_NOT_MSG(
            args.vectors, "bfKnn: vectors must be provided (passed null)");
    FAISS_THROW_IF_NOT_MSG(
            args.numQueries > 0, "bfKnn: numQueries must be > 0");
    FAISS_THROW_IF_NOT_MSG(
            args.queries, "bfKnn: queries must be provided (passed null)");
    FAISS_THROW_IF_NOT_MSG(
            args.outDistances,
            "bfKnn: outDistances must be provided (passed null)");
    FAISS_THROW_IF_NOT_MSG(
            args.outIndices || args.k == -1,
            "bfKnn: outIndices must be provided (passed null)");

    // The index output type only matters for top-k queries (k != -1). Reject
    // an unsupported type here, before any host -> device copies are issued,
    // instead of after all inputs have already been staged on the device.
    if (args.k != -1 && args.outIndicesType != IndicesDataType::I64 &&
        args.outIndicesType != IndicesDataType::I32) {
        FAISS_THROW_MSG("unknown outIndicesType");
    }

    // Don't let the resources go out of scope
    auto resImpl = prov->getResources();
    auto res = resImpl.get();
    auto device = getCurrentDevice();
    auto stream = res->getDefaultStreamCurrentDevice();

    // Stage inputs on `device`. NOTE(review): toDeviceTemporary presumably
    // wraps device-resident pointers directly and copies host pointers into a
    // temporary device buffer -- see CopyUtils.cuh. The tensor shape follows
    // the caller's layout: (n x dims) if row-major, (dims x n) otherwise.
    auto tVectors = toDeviceTemporary<T, 2>(
            res,
            device,
            const_cast<T*>(reinterpret_cast<const T*>(args.vectors)),
            stream,
            {args.vectorsRowMajor ? args.numVectors : args.dims,
             args.vectorsRowMajor ? args.dims : args.numVectors});
    auto tQueries = toDeviceTemporary<T, 2>(
            res,
            device,
            const_cast<T*>(reinterpret_cast<const T*>(args.queries)),
            stream,
            {args.queriesRowMajor ? args.numQueries : args.dims,
             args.queriesRowMajor ? args.dims : args.numQueries});

    // Optional precomputed per-vector norms (one float per database vector)
    DeviceTensor<float, 1, true> tVectorNorms;
    if (args.vectorNorms) {
        tVectorNorms = toDeviceTemporary<float, 1>(
                res,
                device,
                const_cast<float*>(args.vectorNorms),
                stream,
                {args.numVectors});
    }

    // k == -1 requests the full numQueries x numVectors distance matrix
    auto tOutDistances = toDeviceTemporary<float, 2>(
            res,
            device,
            args.outDistances,
            stream,
            {args.numQueries, args.k == -1 ? args.numVectors : args.k});

    if (args.k == -1) {
        // Reporting all pairwise distances
        allPairwiseDistanceOnDevice<T>(
                res,
                device,
                stream,
                tVectors,
                args.vectorsRowMajor,
                args.vectorNorms ? &tVectorNorms : nullptr,
                tQueries,
                args.queriesRowMajor,
                args.metric,
                args.metricArg,
                tOutDistances);
    } else if (args.outIndicesType == IndicesDataType::I64) {
        // The brute-force implementation only produces i32 indices, so write
        // them into a temporary i32 buffer and widen to idx_t afterwards
        DeviceTensor<int, 2, true> tOutIntIndices(
                res,
                makeTempAlloc(AllocType::Other, stream),
                {args.numQueries, args.k});

        // Since we've guaranteed that all arguments are on device, call the
        // implementation
        bfKnnOnDevice<T>(
                res,
                device,
                stream,
                tVectors,
                args.vectorsRowMajor,
                args.vectorNorms ? &tVectorNorms : nullptr,
                tQueries,
                args.queriesRowMajor,
                args.k,
                args.metric,
                args.metricArg,
                tOutDistances,
                tOutIntIndices,
                args.ignoreOutDistances);

        // Device-side view of the caller's idx_t output buffer
        auto tOutIndices = toDeviceTemporary<Index::idx_t, 2>(
                res,
                device,
                (Index::idx_t*)args.outIndices,
                stream,
                {args.numQueries, args.k});

        // Convert int to idx_t
        convertTensor<int, Index::idx_t, 2>(
                stream, tOutIntIndices, tOutIndices);

        // Copy back if necessary
        fromDevice<Index::idx_t, 2>(
                tOutIndices, (Index::idx_t*)args.outIndices, stream);
    } else if (args.outIndicesType == IndicesDataType::I32) {
        // The brute-force implementation takes i32 indices, so the caller's
        // buffer can be used directly
        // FIXME: convert to int32_t everywhere?
        static_assert(sizeof(int) == 4, "");

        auto tOutIntIndices = toDeviceTemporary<int, 2>(
                res,
                device,
                (int*)args.outIndices,
                stream,
                {args.numQueries, args.k});

        // Since we've guaranteed that all arguments are on device, call the
        // implementation
        bfKnnOnDevice<T>(
                res,
                device,
                stream,
                tVectors,
                args.vectorsRowMajor,
                args.vectorNorms ? &tVectorNorms : nullptr,
                tQueries,
                args.queriesRowMajor,
                args.k,
                args.metric,
                args.metricArg,
                tOutDistances,
                tOutIntIndices,
                args.ignoreOutDistances);

        // Copy back if necessary
        fromDevice<int, 2>(tOutIntIndices, (int*)args.outIndices, stream);
    } else {
        // Unreachable: outIndicesType is validated above. Kept defensively.
        FAISS_THROW_MSG("unknown outIndicesType");
    }

    // Copy distances back if necessary
    fromDevice<float, 2>(tOutDistances, args.outDistances, stream);
}
176 
/// Public brute-force k-NN / all-pairs distance entry point.
///
/// Dispatches to bfKnnConvert<T> based on the element type of the database
/// vectors and queries. Both must currently have the same type, and only
/// F32 (float) and F16 (half) are supported.
///
/// @param res   provider of the GPU resources to use; must not be null
/// @param args  query description; see GpuDistanceParams
/// @throws FaissException if `res` is null, the vector/query types differ,
///         or the type is unknown. Other argument validation happens in
///         bfKnnConvert.
void bfKnn(GpuResourcesProvider* res, const GpuDistanceParams& args) {
    // bfKnnConvert dereferences the provider; fail with a clear message
    // rather than crashing on a null pointer
    FAISS_THROW_IF_NOT_MSG(
            res, "bfKnn: GpuResourcesProvider must be provided (passed null)");

    // For now, both vectors and queries must be of the same data type
    FAISS_THROW_IF_NOT_MSG(
            args.vectorType == args.queryType,
            "limitation: both vectorType and queryType must currently "
            "be the same (F32 or F16)");

    if (args.vectorType == DistanceDataType::F32) {
        bfKnnConvert<float>(res, args);
    } else if (args.vectorType == DistanceDataType::F16) {
        bfKnnConvert<half>(res, args);
    } else {
        FAISS_THROW_MSG("unknown vectorType");
    }
}
192 
// Legacy brute-force k-NN entry point, kept for backwards compatibility.
// Deprecated: prints a warning to stderr on every call, then forwards to
// bfKnn.
//
// This signature only accepts float vectors/queries and Index::idx_t output
// indices, so the corresponding GpuDistanceParams type fields are set
// explicitly (F32 / F32 / I64) instead of relying on the struct's defaults.
void bruteForceKnn(
        GpuResourcesProvider* res,
        faiss::MetricType metric,
        // numVectors x dims if vectorsRowMajor (dims innermost),
        // otherwise dims x numVectors
        const float* vectors,
        bool vectorsRowMajor,
        int numVectors,
        // numQueries x dims if queriesRowMajor (dims innermost),
        // otherwise dims x numQueries
        const float* queries,
        bool queriesRowMajor,
        int numQueries,
        int dims,
        int k,
        // A region of memory of size numQueries x k, with k innermost
        // (numQueries x numVectors when k == -1)
        float* outDistances,
        // A region of memory of size numQueries x k, with k innermost
        Index::idx_t* outIndices) {
    std::cerr << "bruteForceKnn is deprecated; call bfKnn instead" << std::endl;

    GpuDistanceParams args;
    args.metric = metric;
    args.k = k;
    args.dims = dims;

    // The legacy signature fixes the element and index types
    args.vectorType = DistanceDataType::F32;
    args.queryType = DistanceDataType::F32;
    args.outIndicesType = IndicesDataType::I64;

    args.vectors = vectors;
    args.vectorsRowMajor = vectorsRowMajor;
    args.numVectors = numVectors;
    args.queries = queries;
    args.queriesRowMajor = queriesRowMajor;
    args.numQueries = numQueries;
    args.outDistances = outDistances;
    args.outIndices = outIndices;

    bfKnn(res, args);
}
232 
233 } // namespace gpu
234 } // namespace faiss
235