1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6import faiss
7import time
8import numpy as np
9
10import logging
11
# Module-level logger: the functions below report progress and timings
# through LOG.info / LOG.debug.
LOG = logging.getLogger(__name__)
13
def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
    """Computes the exact KNN search results for a dataset that possibly
    does not fit in RAM but for which we have an iterator that
    returns it block by block.

    Args:
        xq: query vectors, shape (nq, d).
        db_iterator: iterable yielding database blocks of shape (ni, d).
            Database ids are assigned consecutively across blocks.
        k: number of neighbors to return per query.
        metric_type: faiss metric used by the flat index. For
            METRIC_INNER_PRODUCT the k largest similarities are kept, for
            the other metrics the k smallest distances.

    Returns:
        D, I: distances and ids of the k nearest neighbors of each query.
    """
    LOG.info("knn_ground_truth queries size %s k=%d", xq.shape, k)
    t0 = time.time()
    nq, d = xq.shape
    # ResultHeap(nq, k) keeps the k *smallest* values. For inner product the
    # best results are the largest similarities, so negated similarities are
    # fed to the heap and negated back at the end.
    negate = metric_type == faiss.METRIC_INNER_PRODUCT
    rh = faiss.ResultHeap(nq, k)

    index = faiss.IndexFlat(d, metric_type)
    ngpu = faiss.get_num_gpus()
    if ngpu:
        LOG.info('running on %d GPUs', ngpu)
        index = faiss.index_cpu_to_all_gpus(index)

    # compute ground-truth by blocks, and add to heaps
    i0 = 0
    for xbi in db_iterator:
        ni = xbi.shape[0]
        index.add(xbi)
        D, I = index.search(xq, k)
        index.reset()
        # Shift block-local ids to global ids. The -1 markers returned when
        # a block holds fewer than k vectors must stay -1, not become i0 - 1.
        I[I >= 0] += i0
        rh.add_result(-D if negate else D, I)
        i0 += ni
        LOG.info("%d db elements, %.3f s", i0, time.time() - t0)

    rh.finalize()
    LOG.info("GT time: %.3f s (%d vectors)", time.time() - t0, i0)

    if negate:
        return -rh.D, rh.I
    return rh.D, rh.I
45
# The knn function used to be defined here. It is now provided by the main
# faiss module (faiss.knn); this alias keeps code that imports `knn` from
# this module working.
knn = faiss.knn
48
49
50
51
def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
    """GPU does not support range search, so we emulate it with
    knn search + fallback to CPU index.

    A k-NN search with k = min(ntotal, gpu_k) is run on index_gpu. A query
    whose k-th result is still inside the radius may have more results than
    were returned. If the database also holds more than k vectors, that query
    is re-run as a range search on the CPU index.

    The index_cpu can either be a CPU index or a numpy table that will
    be used to construct a Flat index if needed.

    Args:
        xq: queries, shape (nq, d).
        r2: search radius. For METRIC_L2 results must be below it, for
            other metrics above it.
        index_gpu: index (typically on GPU) holding the database vectors.
        index_cpu: CPU index or array with the same database vectors.
        gpu_k: maximum number of results requested per query from
            index_gpu (default 1024).

    Returns:
        lims, D, I: the results of query i are D[lims[i]:lims[i + 1]] and
        I[lims[i]:lims[i + 1]].
    """
    nq, d = xq.shape
    LOG.debug("GPU search %d queries", nq)
    ntotal = index_gpu.ntotal
    if nq == 0 or ntotal == 0:
        # Nothing to search: k would be 0 (so D[:, k - 1] fails) and
        # np.hstack would be called on an empty list.
        return (np.zeros(nq + 1, dtype='int64'),
                np.zeros(0, dtype='float32'),
                np.zeros(0, dtype='int64'))

    is_l2 = index_gpu.metric_type == faiss.METRIC_L2
    k = min(ntotal, gpu_k)
    D, I = index_gpu.search(xq, k)
    # knn results are sorted, so the in-range results of a row form a prefix
    in_range = D < r2 if is_l2 else D > r2
    n_in_range = in_range.sum(axis=1)

    # More results than returned are only possible if the k-th result is in
    # range AND the database holds more than k vectors.
    if k < ntotal:
        mask = in_range[:, k - 1]
    else:
        mask = np.zeros(nq, dtype=bool)

    if mask.any():
        LOG.debug("CPU search remain %d", mask.sum())
        if isinstance(index_cpu, np.ndarray):
            # then it is in fact an array that we have to make flat
            xb = index_cpu
            index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
            index_cpu.add(xb)
        lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2)

    LOG.debug("combine")
    D_res, I_res = [], []
    nr = 0  # index of the next query in the CPU fallback results
    for i in range(nq):
        if mask[i]:
            l0, l1 = lim_remain[nr], lim_remain[nr + 1]
            D_res.append(D_remain[l0:l1])
            I_res.append(I_remain[l0:l1])
            nr += 1
        else:
            nv = n_in_range[i]
            D_res.append(D[i, :nv])
            I_res.append(I[i, :nv])
    lims = np.cumsum([0] + [len(di) for di in D_res])
    return lims, np.hstack(D_res), np.hstack(I_res)
93
94
def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2,
                       shard=False, ngpu=-1):
    """Computes the range-search results for a dataset that possibly
    does not fit in RAM but for which we have an iterator that
    returns it block by block.

    Args:
        xq: query vectors, shape (nq, d); converted to contiguous float32.
        db_iterator: iterable yielding database blocks. Database ids are
            assigned consecutively across blocks.
        threshold: range-search radius.
        metric_type: faiss metric of the flat index.
        shard: forwarded to GpuMultipleClonerOptions.shard on GPU.
        ngpu: number of GPUs to use; -1 means all available GPUs, 0 runs
            on CPU.

    Returns:
        lims (uint64, length nq + 1), D, I: the results of query j are
        D[lims[j]:lims[j + 1]] and I[lims[j]:lims[j + 1]].
    """
    nq, d = xq.shape
    t0 = time.time()
    xq = np.ascontiguousarray(xq, dtype='float32')

    index = faiss.IndexFlat(d, metric_type)
    if ngpu == -1:
        ngpu = faiss.get_num_gpus()
    # use the same condition here and in the loop below, so a GPU index is
    # only built when it is actually used
    if ngpu > 0:
        LOG.info('running on %d GPUs', ngpu)
        co = faiss.GpuMultipleClonerOptions()
        co.shard = shard
        index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)

    # per-query lists of result chunks (one chunk per block with hits)
    D = [[] for _ in range(nq)]
    I = [[] for _ in range(nq)]
    i0 = 0
    for xbi in db_iterator:
        ni = xbi.shape[0]
        if ngpu > 0:
            index_gpu.add(xbi)
            lims_i, Di, Ii = range_search_gpu(xq, threshold, index_gpu, xbi)
            index_gpu.reset()
        else:
            index.add(xbi)
            lims_i, Di, Ii = index.range_search(xq, threshold)
            index.reset()
        # block-local ids -> global ids
        Ii += i0
        for j in range(nq):
            l0, l1 = lims_i[j], lims_i[j + 1]
            if l1 > l0:
                D[j].append(Di[l0:l1])
                I[j].append(Ii[l0:l1])
        i0 += ni
        LOG.info("%d db elements, %.3f s", i0, time.time() - t0)

    empty_D = np.zeros(0, dtype='float32')
    empty_I = np.zeros(0, dtype='int64')
    D = [np.hstack(chunks) if chunks else empty_D for chunks in D]
    I = [np.hstack(chunks) if chunks else empty_I for chunks in I]
    lims = np.zeros(nq + 1, dtype="uint64")
    lims[1:] = np.cumsum([len(ids) for ids in I])
    if nq == 0:
        # np.hstack([]) raises, so return explicit empty arrays
        return lims, empty_D, empty_I
    return lims, np.hstack(D), np.hstack(I)
147
148
def threshold_radius_nres(nres, dis, ids, thresh, keep_max=False):
    """Select the subset of range-search results that pass a new threshold.

    Results are stored as consecutive per-query segments: the first
    nres[0] entries of dis/ids belong to query 0, the next nres[1] to
    query 1, and so on.

    Args:
        nres: per-query result counts (int64 or uint64).
        dis: distances or similarities, concatenated over queries.
        ids: result ids, aligned with dis.
        thresh: new threshold.
        keep_max: if False (default), keep results with dis < thresh
            (distance metrics). If True, keep dis > thresh (similarity
            metrics such as inner product).

    Returns:
        (new_nres, new_dis, new_ids), with new_nres of the same dtype as nres.
    """
    nres = np.asarray(nres)
    mask = dis > thresh if keep_max else dis < thresh
    # kept[j] = number of selected results among the first j entries
    kept = np.zeros(len(mask) + 1, dtype='int64')
    kept[1:] = np.cumsum(mask)
    # segment boundaries computed in int64 to avoid int64/uint64 mixing
    counts = nres.astype('int64')
    ends = np.cumsum(counts)
    starts = ends - counts
    new_nres = (kept[ends] - kept[starts]).astype(nres.dtype)
    return new_nres, dis[mask], ids[mask]
159
160
def threshold_radius(lims, dis, ids, thresh, keep_max=False):
    """Restrict range-search results to those that pass a new threshold.

    Args:
        lims: range-search limits. The results of query i are
            dis[lims[i]:lims[i + 1]] and ids[lims[i]:lims[i + 1]].
        dis: distances or similarities.
        ids: result ids, aligned with dis.
        thresh: new threshold.
        keep_max: if False (default), keep results with dis < thresh
            (below the radius). If True, keep dis > thresh (similarity
            metrics).

    Returns:
        (new_lims, new_dis, new_ids), with new_lims of the same dtype as lims
        and new_lims[0] == 0.
    """
    lims = np.asarray(lims)
    mask = dis > thresh if keep_max else dis < thresh
    # kept[j] = number of selected results among the first j entries
    kept = np.zeros(len(mask) + 1, dtype='int64')
    kept[1:] = np.cumsum(mask)
    bounds = lims.astype('int64')
    # subtract the count before lims[0] so the new limits start at 0; the
    # [:1] slice keeps this valid (and empty) for an empty lims
    new_lims = (kept[bounds] - kept[bounds[:1]]).astype(lims.dtype)
    return new_lims, dis[mask], ids[mask]
170
171
def apply_maxres(res_batches, target_nres):
    """find radius that reduces number of results to target_nres, and
    applies it in-place to the result batches used in range_search_max_results

    Args:
        res_batches: list of (nres, dis, ids) tuples. Each entry is replaced
            in place by its thresholded version.
        target_nres: the radius is set to the target_nres-th smallest value
            (0-based) over all batches. Only results strictly below it are
            kept, so at most target_nres results remain.

    Returns:
        (radius, totres): the new radius (float for floating-point distances,
        int otherwise) and the number of results left.

    Raises:
        ValueError: if target_nres is not smaller than the total number of
            results.
    """
    alldis = np.hstack([dis for _, dis, _ in res_batches])
    if target_nres >= len(alldis):
        raise ValueError('target_nres=%d must be smaller than the number of '
                         'results (%d)' % (target_nres, len(alldis)))
    # alldis is a fresh copy, so partitioning it in place is safe
    alldis.partition(target_nres)
    radius = alldis[target_nres]

    # Any floating dtype (not only float32) must give a float radius.
    # Casting e.g. a float64 radius to int would discard nearly everything.
    if np.issubdtype(alldis.dtype, np.floating):
        radius = float(radius)
    else:
        radius = int(radius)
    LOG.debug('   setting radius to %s', radius)
    totres = 0
    for i, (nres, dis, ids) in enumerate(res_batches):
        nres, dis, ids = threshold_radius_nres(nres, dis, ids, radius)
        totres += len(dis)
        res_batches[i] = nres, dis, ids
    LOG.debug('   updated previous results, new nb results %d', totres)
    return radius, totres
191
192
def range_search_max_results(index, query_iterator, radius,
                             max_results=None, min_results=None,
                             shard=False, ngpu=0):
    """Performs a range search with many queries (given by an iterator)
    and adjusts the threshold on-the-fly so that the total results
    table does not grow larger than max_results.

    Each time the accumulated number of results exceeds max_results, the
    radius is tightened so that about min_results results remain. The new
    radius applies to the batches already seen and to the following ones.
    min_results defaults to int(0.8 * max_results).

    If ngpu != 0, the function moves the index to this many GPUs to
    speed up search (ngpu == -1 means all available GPUs).

    Returns:
        radius, lims, dis, ids: the final radius, plus the results. The
        results of query i are dis[lims[i]:lims[i + 1]] and
        ids[lims[i]:lims[i + 1]].
    """
    if max_results is not None and min_results is None:
        min_results = int(0.8 * max_results)

    if ngpu == -1:
        ngpu = faiss.get_num_gpus()

    # use the same condition here and in the loop below
    if ngpu > 0:
        LOG.info('running on %d GPUs', ngpu)
        co = faiss.GpuMultipleClonerOptions()
        co.shard = shard
        index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)

    # apply_maxres / threshold_radius_nres keep values *below* the radius.
    # For inner-product indexes the best results have the largest
    # similarities, so the accumulated values are stored negated and the
    # radius returned by apply_maxres is negated back.
    negate = (getattr(index, 'metric_type', faiss.METRIC_L2)
              == faiss.METRIC_INNER_PRODUCT)

    t_start = time.time()
    t_search = t_post_process = 0
    qtot = totres = 0
    res_batches = []

    for xqi in query_iterator:
        t0 = time.time()
        if ngpu > 0:
            lims_i, Di, Ii = range_search_gpu(xqi, radius, index_gpu, index)
        else:
            lims_i, Di, Ii = index.range_search(xqi, radius)

        nres_i = lims_i[1:] - lims_i[:-1]
        qtot += len(xqi)

        t1 = time.time()
        if xqi.dtype != np.float32:
            # for binary indexes
            # weird Faiss quirk that returns floats for Hamming distances
            Di = Di.astype('int16')
        if negate:
            Di = -Di

        totres += len(Di)
        res_batches.append((nres_i, Di, Ii))

        if max_results is not None and totres > max_results:
            LOG.info('too many results %d > %d, scaling back radius',
                     totres, max_results)
            new_radius, totres = apply_maxres(res_batches, min_results)
            radius = -new_radius if negate else new_radius
        t2 = time.time()
        t_search += t1 - t0
        t_post_process += t2 - t1
        LOG.debug('   [%.3f s] %d queries done, %d results',
                  time.time() - t_start, qtot, totres)

    LOG.info('   search done in %.3f s + %.3f s, total %d results, end threshold %g',
             t_search, t_post_process, totres, radius)

    if not res_batches:
        # empty query iterator: np.hstack([]) would raise
        return (radius, np.zeros(1, dtype='uint64'),
                np.zeros(0, dtype='float32'), np.zeros(0, dtype='int64'))

    nres = np.hstack([nres_i for nres_i, _, _ in res_batches])
    dis = np.hstack([dis_i for _, dis_i, _ in res_batches])
    ids = np.hstack([ids_i for _, _, ids_i in res_batches])
    if negate:
        dis = -dis

    lims = np.zeros(len(nres) + 1, dtype='uint64')
    lims[1:] = np.cumsum(nres)

    return radius, lims, dis, ids
263