# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import unittest

from multiprocessing.dummy import Pool as ThreadPool


def _map_in_threads(func, n, nthreads=20):
    """ call func(i) for i in range(n) from a pool of threads and make
    sure the worker threads are shut down afterwards, even if func raises
    """
    pool = ThreadPool(nthreads)
    try:
        pool.map(func, range(n))
    finally:
        # without close + join the worker threads stay alive after the call
        pool.close()
        pool.join()


###############################################################
# Simple functions to evaluate knn results

def knn_intersection_measure(I1, I2):
    """ computes the intersection measure of two result tables

    I1 and I2 are 2D id tables of identical shape (nq, rank). The result
    is the number of ids that rows of same index have in common, summed
    over all rows and divided by the total number of entries of I1.
    Ids repeated within a row are counted once (np.intersect1d returns
    unique values).
    """
    nq, rank = I1.shape
    assert I2.shape == (nq, rank)
    ninter = 0
    for row1, row2 in zip(I1, I2):
        ninter += np.intersect1d(row1, row2).size
    return ninter / I1.size


###############################################################
# Range search results can be compared with Precision-Recall

def filter_range_results(lims, D, I, thresh):
    """ select a set of results

    The results of query i are D[lims[i]:lims[i + 1]] and
    I[lims[i]:lims[i + 1]]. Only the entries with a distance strictly
    below thresh are kept. Returns (new_lims, new_D, new_I) in the same
    layout.
    """
    nq = lims.size - 1
    mask = D < thresh
    new_lims = np.zeros_like(lims)
    for i in range(nq):
        # number of surviving results for query i
        nkeep = mask[lims[i]:lims[i + 1]].sum()
        new_lims[i + 1] = new_lims[i] + nkeep
    return new_lims, D[mask], I[mask]


def range_PR(lims_ref, Iref, lims_new, Inew, mode="overall"):
    """compute the precision and recall of range search results. The
    function does not take the distances into account.

    lims_ref, Iref: reference results, the ids for query i are
        Iref[lims_ref[i]:lims_ref[i + 1]]
    lims_new, Inew: results to evaluate, same layout
    mode: "overall" or "average", see counts_to_PR

    Returns (precision, recall).
    """
    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    ninter = np.zeros(nq, dtype="int64")

    def compute_PR_for(q):
        # ground truth results for this query
        gt_ids = Iref[lims_ref[q]:lims_ref[q + 1]]

        # results for this query
        new_ids = Inew[lims_new[q]:lims_new[q + 1]]

        # there are no set functions in numpy so let's do this
        ninter[q] = len(np.intersect1d(gt_ids, new_ids))

    # run in a thread pool, which helps in spite of the GIL
    _map_in_threads(compute_PR_for, nq)

    return counts_to_PR(
        lims_ref[1:] - lims_ref[:-1],
        lims_new[1:] - lims_new[:-1],
        ninter,
        mode=mode
    )


def counts_to_PR(ngt, nres, ninter, mode="overall"):
    """ computes a precision-recall for a set of queries.
    ngt = nb of GT results per query
    nres = nb of found results per query
    ninter = nb of correct results per query (smaller than nres of course)

    mode = "overall": the counts are summed over the queries before
        computing a single precision and recall
    mode = "average": precision and recall are computed per query and
        then averaged over the queries

    Conventions for empty sets: the precision is 1 when there are no
    found results; the recall without GT results is 1 if there are no
    found results either and 0 otherwise.

    The input arrays are not modified. Returns (precision, recall),
    raises AssertionError for an unknown mode.
    """
    ngt = np.asarray(ngt)
    nres = np.asarray(nres)
    ninter = np.asarray(ninter)

    if mode == "overall":
        ngt, nres, ninter = ngt.sum(), nres.sum(), ninter.sum()

        if nres > 0:
            precision = ninter / nres
        else:
            precision = 1.0

        if ngt > 0:
            recall = ninter / ngt
        elif nres == 0:
            recall = 1.0
        else:
            recall = 0.0

        return precision, recall

    elif mode == "average":
        # average precision and recall over queries

        no_gt = ngt == 0
        no_res = nres == 0
        assert np.all(ninter[no_res] == 0)

        # np.where builds new arrays, so the divisions by 0 are avoided
        # without writing into the arrays of the caller
        recalls = ninter / np.where(no_gt, 1, ngt)
        recalls[no_gt] = (nres[no_gt] == 0).astype(float)

        precisions = ninter / np.where(no_res, 1, nres)
        precisions[no_res] = 1.0

        return precisions.mean(), recalls.mean()

    else:
        raise AssertionError("unknown mode {!r}".format(mode))


def sort_range_res_2(lims, D, I):
    """ sort 2 arrays using the first as key

    Within each query range lims[i]:lims[i + 1], the entries of D and I
    are reordered by increasing D. Returns (I2, D2), ids first; the
    inputs are left untouched.
    """
    I2 = np.empty_like(I)
    D2 = np.empty_like(D)
    nq = len(lims) - 1
    for i in range(nq):
        l0, l1 = lims[i], lims[i + 1]
        ii = I[l0:l1]
        di = D[l0:l1]
        o = di.argsort()
        I2[l0:l1] = ii[o]
        D2[l0:l1] = di[o]
    return I2, D2


def sort_range_res_1(lims, I):
    """ return a copy of I where the entries of each query range
    lims[i]:lims[i + 1] are sorted in increasing order
    """
    I2 = np.empty_like(I)
    nq = len(lims) - 1
    for i in range(nq):
        l0, l1 = lims[i], lims[i + 1]
        I2[l0:l1] = np.sort(I[l0:l1])
    return I2


def range_PR_multiple_thresholds(
            lims_ref, Iref,
            lims_new, Dnew, Inew,
            thresholds,
            mode="overall", do_sort="ref,new"
    ):
    """ compute precision-recall values for range search results
    for several thresholds on the "new" results.
    This is to plot PR curves

    lims_ref, Iref: reference results (ids only)
    lims_new, Dnew, Inew: results to evaluate, with their distances
    thresholds: sequence of distance thresholds; for each of them, only
        the new results with a distance strictly below it are considered
    mode: "overall" or "average", see counts_to_PR
    do_sort: string that contains "ref" if the reference ids still have
        to be sorted by id within each query, and "new" if the new
        results still have to be sorted by distance within each query

    Returns (precisions, recalls), two arrays of size len(thresholds).
    """
    # ref should be sorted by ids
    if "ref" in do_sort:
        Iref = sort_range_res_1(lims_ref, Iref)

    # new should be sorted by distances
    if "new" in do_sort:
        Inew, Dnew = sort_range_res_2(lims_new, Dnew, Inew)

    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    nt = len(thresholds)
    # counts[q, t] = (nb GT, nb results, nb correct results) for query q
    # at threshold t
    counts = np.zeros((nq, nt, 3), dtype="int64")

    def compute_PR_for(q):
        gt_ids = Iref[lims_ref[q]:lims_ref[q + 1]]
        l0, l1 = lims_new[q], lims_new[q + 1]
        res_ids, res_dis = Inew[l0:l1], Dnew[l0:l1]

        counts[q, :, 0] = len(gt_ids)

        if res_dis.size == 0:
            # the rest remains at 0
            return

        # which offsets we are interested in: since res_dis is sorted,
        # this is the nb of results below each threshold
        nres = np.searchsorted(res_dis, thresholds)
        counts[q, :, 1] = nres

        if gt_ids.size == 0:
            return

        # find number of TPs at each stage in the result list
        ii = np.searchsorted(gt_ids, res_ids)
        # ids beyond the largest GT id cannot match, point them to any
        # valid entry so that the lookup below does not go out of bounds
        ii[ii == len(gt_ids)] = -1
        n_ok = np.cumsum(gt_ids[ii] == res_ids)

        # focus on threshold points
        n_ok = np.hstack(([0], n_ok))
        counts[q, :, 2] = n_ok[nres]

    _map_in_threads(compute_PR_for, nq)

    precisions = np.zeros(nt)
    recalls = np.zeros(nt)
    for t in range(nt):
        p, r = counts_to_PR(
            counts[:, t, 0], counts[:, t, 1], counts[:, t, 2],
            mode=mode
        )
        precisions[t] = p
        recalls[t] = r

    return precisions, recalls


###############################################################
# Functions that compare search results with a reference result.
# They are intended for use in tests

def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
    """ test that knn search results are identical, raise if not

    The distance tables must match to 5 decimals. The id tables may
    differ only by the order of results that have exactly the same
    reference distance (draws).
    """
    np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
    # here we have to be careful because of draws
    testcase = unittest.TestCase()   # because it makes nice error messages
    for i in range(len(Iref)):
        if np.all(Iref[i] == Inew[i]):  # easy case
            continue
        # we can deduce nothing about the latest line
        skip_dis = Dref[i, -1]
        # only the distances of this row can select something in it
        for dis in np.unique(Dref[i]):
            if dis == skip_dis:
                continue
            mask = Dref[i, :] == dis
            testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))


def test_ref_range_results(lims_ref, Dref, Iref,
                           lims_new, Dnew, Inew):
    """ compare range search results wrt. a reference result,
    throw if it fails

    The lims tables must be identical. Within each query, the results
    may come in a different order: ids must match exactly and distances
    to 5 decimals once both sides are sorted by id.
    """
    def sort_by_ids(I, D):
        o = I.argsort()
        return I[o], D[o]

    np.testing.assert_array_equal(lims_ref, lims_new)
    nq = len(lims_ref) - 1
    for i in range(nq):
        l0, l1 = lims_ref[i], lims_ref[i + 1]
        Ii_ref = Iref[l0:l1]
        Ii_new = Inew[l0:l1]
        Di_ref = Dref[l0:l1]
        Di_new = Dnew[l0:l1]
        if not np.all(Ii_ref == Ii_new):
            # not the same order: sort both
            (Ii_ref, Di_ref) = sort_by_ids(Ii_ref, Di_ref)
            (Ii_new, Di_new) = sort_by_ids(Ii_new, Di_new)
            np.testing.assert_array_equal(Ii_ref, Ii_new)
        np.testing.assert_array_almost_equal(Di_ref, Di_new, decimal=5)


# The two functions above are helpers that take arguments, not test cases.
# Their names start with "test_", so pytest would try to collect them when
# they are imported into a test module; the __test__ flag prevents that.
test_ref_knn_with_draws.__test__ = False
test_ref_range_results.__test__ = False