# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faiss
import time
import numpy as np

import logging

LOG = logging.getLogger(__name__)


def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
    """Computes the exact KNN search results for a dataset that possibly
    does not fit in RAM but for which we have an iterator that
    returns it block by block.

    :param xq: 2D query matrix; its shape is unpacked as (nq, d)
    :param db_iterator: iterable yielding database blocks; each block is
        added to a flat index of dimension d and searched
    :param k: number of neighbors to keep per query
    :param metric_type: Faiss metric used to build the flat index
    :return: the (D, I) tables accumulated by the result heap, with ids
        offset so that they index the concatenation of all blocks
    """
    LOG.info("knn_ground_truth queries size %s k=%d", xq.shape, k)
    t0 = time.time()
    nq, d = xq.shape
    if metric_type == faiss.METRIC_INNER_PRODUCT:
        # Inner product is a similarity: the best results are the largest
        # ones, so the heap must keep maxima instead of the default minima.
        # NOTE(review): relies on ResultHeap accepting keep_max -- confirm
        # against the installed Faiss version.
        rh = faiss.ResultHeap(nq, k, keep_max=True)
    else:
        rh = faiss.ResultHeap(nq, k)

    index = faiss.IndexFlat(d, metric_type)
    ngpu = faiss.get_num_gpus()
    if ngpu:
        LOG.info('running on %d GPUs', ngpu)
        index = faiss.index_cpu_to_all_gpus(index)

    # compute ground-truth by blocks, and add to heaps
    i0 = 0
    for xbi in db_iterator:
        ni = xbi.shape[0]
        index.add(xbi)
        D, I = index.search(xq, k)
        # ids returned by the search are local to the block
        I += i0
        rh.add_result(D, I)
        # drop the block so that only one block is indexed at a time
        index.reset()
        i0 += ni
        LOG.info("%d db elements, %.3f s", i0, time.time() - t0)

    rh.finalize()
    LOG.info("GT time: %.3f s (%d vectors)", time.time() - t0, i0)

    return rh.D, rh.I


# knn function used to be here
knn = faiss.knn


def range_search_gpu(xq, r2, index_gpu, index_cpu):
    """GPU does not support range search, so we emulate it with
    knn search + fallback to CPU index.

    The index_cpu can either be a CPU index or a numpy table that will
    be used to construct a Flat index if needed.

    :param xq: 2D query matrix; its shape is unpacked as (nq, d)
    :param r2: range-search threshold. Results are kept when strictly
        below it for METRIC_L2, strictly above it for any other metric
    :param index_gpu: index searched with knn (k = min(ntotal, 1024))
    :param index_cpu: fallback used for the queries whose k-th neighbor is
        still within the threshold, ie. whose knn result may be truncated
    :return: (lims, D, I) where the results of query i are
        D[lims[i]:lims[i + 1]] and I[lims[i]:lims[i + 1]]
    """
    nq, d = xq.shape
    LOG.debug("GPU search %d queries", nq)
    k = min(index_gpu.ntotal, 1024)
    if nq == 0 or k == 0:
        # No query or empty index: there cannot be any result. Without this
        # guard, D[:, k - 1] fails for k == 0 and np.hstack fails for nq == 0.
        return (
            np.zeros(nq + 1, dtype=int),
            np.zeros(0, dtype='float32'),
            np.zeros(0, dtype='int64'),
        )
    D, I = index_gpu.search(xq, k)
    is_l2 = index_gpu.metric_type == faiss.METRIC_L2
    # per-entry test "is within the threshold", depends on the metric
    in_range = (D < r2) if is_l2 else (D > r2)
    # queries whose last knn result is still in range need the CPU fallback
    mask = in_range[:, k - 1]
    n_remain = int(mask.sum())
    if n_remain > 0:
        LOG.debug("CPU search remain %d", n_remain)
        if isinstance(index_cpu, np.ndarray):
            # then it in fact an array that we have to make flat
            xb = index_cpu
            index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
            index_cpu.add(xb)
        lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2)
    LOG.debug("combine")
    n_valid = in_range.sum(axis=1)
    D_res, I_res = [], []
    nr = 0  # position of the current query within the CPU fallback results
    for i in range(nq):
        if not mask[i]:
            # knn result is complete: keep the leading in-range entries
            nv = n_valid[i]
            D_res.append(D[i, :nv])
            I_res.append(I[i, :nv])
        else:
            l0, l1 = lim_remain[nr], lim_remain[nr + 1]
            D_res.append(D_remain[l0:l1])
            I_res.append(I_remain[l0:l1])
            nr += 1
    lims = np.cumsum([0] + [len(di) for di in D_res])
    return lims, np.hstack(D_res), np.hstack(I_res)


def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2,
                       shard=False, ngpu=-1):
    """Computes the range-search search results for a dataset that possibly
    does not fit in RAM but for which we have an iterator that
    returns it block by block.

    :param xq: 2D query matrix; converted to contiguous float32
    :param db_iterator: iterable yielding database blocks
    :param threshold: range-search threshold passed to the index
    :param metric_type: Faiss metric used to build the flat index
    :param shard: value assigned to GpuMultipleClonerOptions.shard
    :param ngpu: number of GPUs to use; -1 means faiss.get_num_gpus(),
        0 means CPU only
    :return: (lims, D, I) where the results of query i are
        D[lims[i]:lims[i + 1]] and I[lims[i]:lims[i + 1]], with ids offset
        so that they index the concatenation of all blocks
    """
    nq, d = xq.shape
    t0 = time.time()
    xq = np.ascontiguousarray(xq, dtype='float32')

    index = faiss.IndexFlat(d, metric_type)
    if ngpu == -1:
        ngpu = faiss.get_num_gpus()
    if ngpu:
        LOG.info('running on %d GPUs', ngpu)
        co = faiss.GpuMultipleClonerOptions()
        co.shard = shard
        index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)

    # compute ground-truth by blocks
    i0 = 0
    # per-query lists of result chunks, one chunk per non-empty block result
    D = [[] for _i in range(nq)]
    I = [[] for _i in range(nq)]
    for xbi in db_iterator:
        ni = xbi.shape[0]
        if ngpu > 0:
            index_gpu.add(xbi)
            # the raw block serves as CPU fallback for range_search_gpu
            lims_i, Di, Ii = range_search_gpu(xq, threshold, index_gpu, xbi)
            index_gpu.reset()
        else:
            index.add(xbi)
            lims_i, Di, Ii = index.range_search(xq, threshold)
            index.reset()
        # ids returned by the search are local to the block
        Ii += i0
        for j in range(nq):
            l0, l1 = lims_i[j], lims_i[j + 1]
            if l1 > l0:
                D[j].append(Di[l0:l1])
                I[j].append(Ii[l0:l1])
        i0 += ni
        LOG.info("%d db elements, %.3f s", i0, time.time() - t0)

    empty_I = np.zeros(0, dtype='int64')
    empty_D = np.zeros(0, dtype='float32')
    D = [(np.hstack(chunks) if chunks else empty_D) for chunks in D]
    I = [(np.hstack(chunks) if chunks else empty_I) for chunks in I]
    sizes = [len(i) for i in I]
    assert len(sizes) == nq
    lims = np.zeros(nq + 1, dtype="uint64")
    lims[1:] = np.cumsum(sizes)
    # np.hstack rejects an empty list, which happens when there is no query
    all_D = np.hstack(D) if D else empty_D
    all_I = np.hstack(I) if I else empty_I
    return lims, all_D, all_I


def threshold_radius_nres(nres, dis, ids, thresh):
    """ select a set of results

    :param nres: number of results per query (the segments of dis / ids)
    :param dis: flat array of distances
    :param ids: flat array of ids, parallel to dis
    :param thresh: only entries with dis < thresh are kept
    :return: (new_nres, dis, ids) restricted to the kept entries
    """
    mask = dis < thresh
    new_nres = np.zeros_like(nres)
    o = 0
    for i, nr in enumerate(nres):
        nr = int(nr)  # avoid issues with int64 + uint64
        new_nres[i] = mask[o:o + nr].sum()
        o += nr
    return new_nres, dis[mask], ids[mask]


def threshold_radius(lims, dis, ids, thresh):
    """ restrict range-search results to those below a given radius

    :param lims: segment boundaries; query i owns dis[lims[i]:lims[i + 1]]
    :param dis: flat array of distances
    :param ids: flat array of ids, parallel to dis
    :param thresh: only entries with dis < thresh are kept
    :return: (new_lims, dis, ids) restricted to the kept entries
    """
    mask = dis < thresh
    new_lims = np.zeros_like(lims)
    n = len(lims) - 1
    for i in range(n):
        # Plain Python ints, as in threshold_radius_nres: mixing uint64 and
        # int64 numpy scalars would silently promote the sum to float64.
        l0, l1 = int(lims[i]), int(lims[i + 1])
        new_lims[i + 1] = int(new_lims[i]) + int(mask[l0:l1].sum())
    return new_lims, dis[mask], ids[mask]


def apply_maxres(res_batches, target_nres):
    """find radius that reduces number of results to target_nres, and
    applies it in-place to the result batches used in range_search_max_results

    :param res_batches: list of (nres, dis, ids) tuples, updated in place
    :param target_nres: rank of the distance that becomes the new radius;
        it is used as an index into the concatenated distances, so it must
        be smaller than the total number of results
    :return: (radius, totres), the new radius and the number of results that
        remain after thresholding
    """
    alldis = np.hstack([dis for _, dis, _ in res_batches])
    # partial sort: only the element of rank target_nres is needed
    alldis.partition(target_nres)
    radius = alldis[target_nres]

    # convert the numpy scalar to a native Python number
    if alldis.dtype == 'float32':
        radius = float(radius)
    else:
        radius = int(radius)
    LOG.debug('   setting radius to %s', radius)
    totres = 0
    for i, (nres, dis, ids) in enumerate(res_batches):
        nres, dis, ids = threshold_radius_nres(nres, dis, ids, radius)
        totres += len(dis)
        res_batches[i] = nres, dis, ids
    LOG.debug('   updated previous results, new nb results %d', totres)
    return radius, totres


def range_search_max_results(index, query_iterator, radius,
                             max_results=None, min_results=None,
                             shard=False, ngpu=0):
    """Performs a range search with many queries (given by an iterator)
    and adjusts the threshold on-the-fly so that the total results
    table does not grow larger than max_results.

    If ngpu != 0, the function moves the index to this many GPUs to
    speed up search.

    :param index: index to search
    :param query_iterator: iterable yielding blocks of queries
    :param radius: initial range-search radius
    :param max_results: when the accumulated number of results exceeds
        this, the radius is reduced; None disables the adjustment
    :param min_results: number of results to scale back to when max_results
        is exceeded; defaults to 80% of max_results
    :param shard: value assigned to GpuMultipleClonerOptions.shard
    :param ngpu: number of GPUs to use; -1 means faiss.get_num_gpus()
    :return: (radius, lims, dis, ids) where radius is the final radius and
        the results of query i are dis[lims[i]:lims[i + 1]] and
        ids[lims[i]:lims[i + 1]]. When query_iterator yields nothing the
        result arrays are empty (dis is then float32).
    """

    if max_results is not None and min_results is None:
        min_results = int(0.8 * max_results)

    if ngpu == -1:
        ngpu = faiss.get_num_gpus()

    if ngpu:
        LOG.info('running on %d GPUs', ngpu)
        co = faiss.GpuMultipleClonerOptions()
        co.shard = shard
        index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)

    t_start = time.time()
    t_search = t_post_process = 0
    qtot = totres = raw_totres = 0
    res_batches = []

    for xqi in query_iterator:
        t0 = time.time()
        if ngpu > 0:
            lims_i, Di, Ii = range_search_gpu(xqi, radius, index_gpu, index)
        else:
            lims_i, Di, Ii = index.range_search(xqi, radius)

        # store per-query counts rather than offsets so that batches can be
        # thresholded independently and concatenated at the end
        nres_i = lims_i[1:] - lims_i[:-1]
        raw_totres += len(Di)
        qtot += len(xqi)

        t1 = time.time()
        if xqi.dtype != np.float32:
            # for binary indexes
            # weird Faiss quirk that returns floats for Hamming distances
            Di = Di.astype('int16')

        totres += len(Di)
        res_batches.append((nres_i, Di, Ii))

        if max_results is not None and totres > max_results:
            LOG.info('too many results %d > %d, scaling back radius',
                     totres, max_results)
            radius, totres = apply_maxres(res_batches, min_results)
        t2 = time.time()
        t_search += t1 - t0
        t_post_process += t2 - t1
        LOG.debug('   [%.3f s] %d queries done, %d results',
                  time.time() - t_start, qtot, totres)

    LOG.info(
        '   search done in %.3f s + %.3f s, total %d results, '
        'end threshold %g',
        t_search, t_post_process, totres, radius)

    if not res_batches:
        # query_iterator was empty: np.hstack rejects an empty list
        return (
            radius,
            np.zeros(1, dtype='uint64'),
            np.zeros(0, dtype='float32'),
            np.zeros(0, dtype='int64'),
        )

    nres = np.hstack([nres_i for nres_i, dis_i, ids_i in res_batches])
    dis = np.hstack([dis_i for nres_i, dis_i, ids_i in res_batches])
    ids = np.hstack([ids_i for nres_i, dis_i, ids_i in res_batches])

    lims = np.zeros(len(nres) + 1, dtype='uint64')
    lims[1:] = np.cumsum(nres)

    return radius, lims, dis, ids