# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import unittest

from multiprocessing.dummy import Pool as ThreadPool

###############################################################
# Simple functions to evaluate knn results

def knn_intersection_measure(I1, I2):
    """ computes the intersection measure of two knn result tables:
    the fraction of result ids common to both, averaged over queries
    """
    nq, rank = I1.shape
    assert I2.shape == (nq, rank)
    ninter = sum(
        np.intersect1d(I1[i], I2[i]).size
        for i in range(nq)
    )
    return ninter / I1.size
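
# Example usage (illustrative sketch): compare the result lists of two searches
# over the same queries. Ia and Ib are (nq, k) arrays of result ids; index_a,
# index_b and xq are placeholders, not defined in this module.
#
#   _, Ia = index_a.search(xq, 10)
#   _, Ib = index_b.search(xq, 10)
#   overlap = knn_intersection_measure(Ia, Ib)   # 1.0 means identical results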

###############################################################
# Range search results can be compared with Precision-Recall

def filter_range_results(lims, D, I, thresh):
    """ keep only the range search results with distance < thresh """
    nq = lims.size - 1
    mask = D < thresh
    new_lims = np.zeros_like(lims)
    for i in range(nq):
        new_lims[i + 1] = new_lims[i] + mask[lims[i] : lims[i + 1]].sum()
    return new_lims, D[mask], I[mask]

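
# Range search results use the usual Faiss layout: the result ids for query i
# are I[lims[i]:lims[i + 1]], with matching distances in D, and lims has
# nq + 1 entries. Tiny worked example (illustrative) with 2 queries:
#
#   lims = np.array([0, 3, 4])
#   D = np.array([0.1, 0.2, 0.5, 0.3])
#   I = np.array([12, 51, 7, 23])
#   filter_range_results(lims, D, I, 0.4)
#   # -> (array([0, 2, 3]), array([0.1, 0.2, 0.3]), array([12, 51, 23]))
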
def range_PR(lims_ref, Iref, lims_new, Inew, mode="overall"):
    """compute the precision and recall of range search results. The
    function does not take the distances into account. """

    def ref_result_for(i):
        return Iref[lims_ref[i]:lims_ref[i + 1]]

    def new_result_for(i):
        return Inew[lims_new[i]:lims_new[i + 1]]

    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    ninter = np.zeros(nq, dtype="int64")

    def compute_PR_for(q):

        # ground truth results for this query
        gt_ids = ref_result_for(q)

        # results for this query
        new_ids = new_result_for(q)

        # intersect1d treats the two id arrays as sets
        inter = np.intersect1d(gt_ids, new_ids)

        ninter[q] = len(inter)

    # run in a thread pool, which helps in spite of the GIL
    pool = ThreadPool(20)
    pool.map(compute_PR_for, range(nq))

    return counts_to_PR(
        lims_ref[1:] - lims_ref[:-1],
        lims_new[1:] - lims_new[:-1],
        ninter,
        mode=mode
    )

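
# Example usage (illustrative sketch): lims_ref/Iref would typically come from
# an exact range search and lims_new/Inew from an approximate one, both with
# the same radius (placeholders, not defined in this module).
#
#   precision, recall = range_PR(lims_ref, Iref, lims_new, Inew, mode="average")
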
def counts_to_PR(ngt, nres, ninter, mode="overall"):
    """ computes a precision-recall for a set of queries.
    ngt = nb of GT results per query
    nres = nb of found results per query
    ninter = nb of correct results per query (at most nres of course)
    """

    if mode == "overall":
        ngt, nres, ninter = ngt.sum(), nres.sum(), ninter.sum()

        if nres > 0:
            precision = ninter / nres
        else:
            precision = 1.0

        if ngt > 0:
            recall = ninter / ngt
        elif nres == 0:
            recall = 1.0
        else:
            recall = 0.0

        return precision, recall

    elif mode == "average":
        # average precision and recall over queries

        # work on copies to avoid modifying the caller's arrays in place
        ngt, nres, ninter = ngt.copy(), nres.copy(), ninter.copy()

        mask = ngt == 0
        ngt[mask] = 1

        recalls = ninter / ngt
        recalls[mask] = (nres[mask] == 0).astype(float)

        # avoid division by 0
        mask = nres == 0
        assert np.all(ninter[mask] == 0)
        ninter[mask] = 1
        nres[mask] = 1

        precisions = ninter / nres

        return precisions.mean(), recalls.mean()

    else:
        raise AssertionError("unknown mode %s" % mode)

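
# Worked example (illustrative), for 3 queries:
#
#   ngt    = np.array([2, 0, 3])   # nb of ground-truth results per query
#   nres   = np.array([2, 1, 0])   # nb of returned results per query
#   ninter = np.array([1, 0, 0])   # nb of correct results per query
#
#   counts_to_PR(ngt, nres, ninter, mode="overall")  # -> (1/3, 1/5)
#   counts_to_PR(ngt, nres, ninter, mode="average")  # -> (0.5, 1/6)
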
def sort_range_res_2(lims, D, I):
    """ for each query, sort the results by distance (D is the sort key)
    and return the reordered (I, D) """
    I2 = np.empty_like(I)
    D2 = np.empty_like(D)
    nq = len(lims) - 1
    for i in range(nq):
        l0, l1 = lims[i], lims[i + 1]
        ii = I[l0:l1]
        di = D[l0:l1]
        o = di.argsort()
        I2[l0:l1] = ii[o]
        D2[l0:l1] = di[o]
    return I2, D2


def sort_range_res_1(lims, I):
    """ for each query, sort the result ids """
    I2 = np.empty_like(I)
    nq = len(lims) - 1
    for i in range(nq):
        l0, l1 = lims[i], lims[i + 1]
        I2[l0:l1] = I[l0:l1]
        I2[l0:l1].sort()
    return I2


def range_PR_multiple_thresholds(
            lims_ref, Iref,
            lims_new, Dnew, Inew,
            thresholds,
            mode="overall", do_sort="ref,new"
    ):
    """ compute precision-recall values for range search results
    for several thresholds on the "new" results.
    This is useful to plot PR curves.
    """
    # ref should be sorted by ids
    if "ref" in do_sort:
        Iref = sort_range_res_1(lims_ref, Iref)

    # new should be sorted by distances
    if "new" in do_sort:
        Inew, Dnew = sort_range_res_2(lims_new, Dnew, Inew)

    def ref_result_for(i):
        return Iref[lims_ref[i]:lims_ref[i + 1]]

    def new_result_for(i):
        l0, l1 = lims_new[i], lims_new[i + 1]
        return Inew[l0:l1], Dnew[l0:l1]

    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    nt = len(thresholds)
    counts = np.zeros((nq, nt, 3), dtype="int64")

    def compute_PR_for(q):
        gt_ids = ref_result_for(q)
        res_ids, res_dis = new_result_for(q)

        counts[q, :, 0] = len(gt_ids)

        if res_dis.size == 0:
            # the rest remains at 0
            return

        # number of results below each threshold
        nres = np.searchsorted(res_dis, thresholds)
        counts[q, :, 1] = nres

        if gt_ids.size == 0:
            return

        # find number of TPs at each stage in the result list
        ii = np.searchsorted(gt_ids, res_ids)
        ii[ii == len(gt_ids)] = -1
        n_ok = np.cumsum(gt_ids[ii] == res_ids)

        # focus on threshold points
        n_ok = np.hstack(([0], n_ok))
        counts[q, :, 2] = n_ok[nres]

    pool = ThreadPool(20)
    pool.map(compute_PR_for, range(nq))

    precisions = np.zeros(nt)
    recalls = np.zeros(nt)
    for t in range(nt):
        p, r = counts_to_PR(
                counts[:, t, 0], counts[:, t, 1], counts[:, t, 2],
                mode=mode
        )
        precisions[t] = p
        recalls[t] = r

    return precisions, recalls

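
# Example usage (illustrative sketch; lims_ref/Iref and lims_new/Dnew/Inew are
# placeholders for two sets of range search results, and matplotlib is only an
# assumption for the plotting step):
#
#   thresholds = np.linspace(Dnew.min(), Dnew.max(), 20)
#   precisions, recalls = range_PR_multiple_thresholds(
#       lims_ref, Iref, lims_new, Dnew, Inew, thresholds)
#   plt.plot(recalls, precisions)   # PR curve (plt = matplotlib.pyplot)
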
###############################################################
# Functions that compare search results with a reference result.
# They are intended for use in tests

def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
    """ test that knn search results are identical, raise if not """
    np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
    # here we have to be careful because of draws (ties in the distances)
    testcase = unittest.TestCase()   # because it makes nice error messages
    for i in range(len(Iref)):
        if np.all(Iref[i] == Inew[i]):  # easy case
            continue
        # results tied at the last (k-th) distance may be cut off arbitrarily,
        # so nothing can be checked for that distance value
        skip_dis = Dref[i, -1]
        for dis in np.unique(Dref[i]):
            if dis == skip_dis:
                continue
            mask = Dref[i, :] == dis
            testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))

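
# Example usage (illustrative sketch): compare two implementations that should
# return the same neighbors up to ties in the distances, e.g. the same flat
# index searched on CPU and on GPU (index_cpu, index_gpu, xq are placeholders):
#
#   Dref, Iref = index_cpu.search(xq, 10)
#   Dnew, Inew = index_gpu.search(xq, 10)
#   test_ref_knn_with_draws(Dref, Iref, Dnew, Inew)   # raises on mismatch
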
def test_ref_range_results(lims_ref, Dref, Iref,
                           lims_new, Dnew, Inew):
    """ compare range search results to a reference result,
    raise if they differ """
    np.testing.assert_array_equal(lims_ref, lims_new)
    nq = len(lims_ref) - 1
    for i in range(nq):
        l0, l1 = lims_ref[i], lims_ref[i + 1]
        Ii_ref = Iref[l0:l1]
        Ii_new = Inew[l0:l1]
        Di_ref = Dref[l0:l1]
        Di_new = Dnew[l0:l1]
        if np.all(Ii_ref == Ii_new):  # easy case: same ids in the same order
            pass
        else:
            def sort_by_ids(I, D):
                o = I.argsort()
                return I[o], D[o]
            # sort both by id before comparing
            (Ii_ref, Di_ref) = sort_by_ids(Ii_ref, Di_ref)
            (Ii_new, Di_new) = sort_by_ids(Ii_new, Di_new)
            np.testing.assert_array_equal(Ii_ref, Ii_new)
        np.testing.assert_array_almost_equal(Di_ref, Di_new, decimal=5)
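
# Example usage (illustrative sketch; the arguments are placeholders for two
# range search results over the same queries and radius):
#
#   lims_ref, Dref, Iref = index_ref.range_search(xq, radius)
#   lims_new, Dnew, Inew = index_new.range_search(xq, radius)
#   test_ref_range_results(lims_ref, Dref, Iref, lims_new, Dnew, Inew)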