# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import faiss

from .vecs_io import fvecs_read, ivecs_read, bvecs_mmap, fvecs_mmap
from .exhaustive_search import knn


class Dataset:
    """ Generic abstract class for a test dataset """

    def __init__(self):
        """ the constructor should set the following fields: """
        self.d = -1
        self.metric = 'L2'   # or 'IP'
        self.nq = -1
        self.nb = -1
        self.nt = -1

    def get_queries(self):
        """ return the queries as a (nq, d) array """
        raise NotImplementedError()

    def get_train(self, maxtrain=None):
        """ return the training vectors as a (nt, d) array """
        raise NotImplementedError()

    def get_database(self):
        """ return the database vectors as a (nb, d) array """
        raise NotImplementedError()

    def database_iterator(self, bs=128, split=(1, 0)):
        """ returns an iterator on database vectors.
        bs is the number of vectors per batch.
        split = (nsplit, rank) means the dataset is split into nsplit
        shards and we want shard number rank.
        The default implementation just iterates over the full matrix
        returned by get_database.
        """
        xb = self.get_database()
        nsplit, rank = split
        i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
        for j0 in range(i0, i1, bs):
            yield xb[j0: min(j0 + bs, i1)]

    def get_groundtruth(self, k=None):
        """ return the ground truth for k-nearest neighbor search """
        raise NotImplementedError()

    def get_groundtruth_range(self, thresh=None):
        """ return the ground truth for range search """
        raise NotImplementedError()

    def __str__(self):
        return (f"dataset in dimension {self.d}, with metric {self.metric}, "
                f"size: Q {self.nq} B {self.nb} T {self.nt}")

    def check_sizes(self):
        """ runs the accessors above and checks the sizes of the matrices """
        assert self.get_queries().shape == (self.nq, self.d)
        if self.nt > 0:
            xt = self.get_train(maxtrain=123)
            assert xt.shape == (123, self.d), "shape=%s" % (xt.shape, )
        assert self.get_database().shape == (self.nb, self.d)
        assert self.get_groundtruth(k=13).shape == (self.nq, 13)


class SyntheticDataset(Dataset):
    """ A dataset that is not completely random but still challenging to
    index
    """

    def __init__(self, d, nt, nb, nq, metric='L2'):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = d, nt, nb, nq
        d1 = 10     # intrinsic dimension (more or less)
        n = nb + nt + nq
        rs = np.random.RandomState(1338)
        x = rs.normal(size=(n, d1))
        x = np.dot(x, rs.rand(d1, d))
        # now we have a d1-dim ellipsoid in d-dimensional space
        # higher factor (>4) -> higher frequency -> less linear
        x = x * (rs.rand(d) * 4 + 0.1)
        x = np.sin(x)
        x = x.astype('float32')
        self.metric = metric
        self.xt = x[:nt]
        self.xb = x[nt:nt + nb]
        self.xq = x[nt + nb:]

    def get_queries(self):
        return self.xq

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return self.xt[:maxtrain]

    def get_database(self):
        return self.xb

    def get_groundtruth(self, k=100):
        return knn(
            self.xq, self.xb, k,
            faiss.METRIC_L2 if self.metric == 'L2'
            else faiss.METRIC_INNER_PRODUCT
        )[1]
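

# Illustrative usage sketch, not part of the original module: exercises
# SyntheticDataset against a brute-force faiss index. The function name
# `demo_synthetic_dataset` is ours, chosen for this example.

def demo_synthetic_dataset():
    ds = SyntheticDataset(d=32, nt=1000, nb=2000, nq=10)
    ds.check_sizes()
    index = faiss.IndexFlatL2(ds.d)
    index.add(ds.get_database())
    _, I = index.search(ds.get_queries(), 10)
    gt = ds.get_groundtruth(k=10)
    # both searches are exhaustive, so the recall should be 1.0 up to ties
    ninter = sum(np.intersect1d(I[i], gt[i]).size for i in range(ds.nq))
    print("recall@10:", ninter / float(ds.nq * 10))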


############################################################################
# The following datasets are a few standard open-source datasets.
# They should be stored in a directory, and we start by guessing where
# that directory is.
############################################################################


for dataset_basedir in (
        '/datasets01/simsearch/041218/',
        '/mnt/vol/gfsai-flash3-east/ai-group/datasets/simsearch/'):
    if os.path.exists(dataset_basedir):
        break
else:
    # users can symlink their data directory to `data/`
    dataset_basedir = 'data/'


class DatasetSIFT1M(Dataset):
    """
    The original dataset is available at http://corpus-texmex.irisa.fr/
    (ANN_SIFT1M)
    """

    def __init__(self):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = 128, 100000, 1000000, 10000
        self.basedir = dataset_basedir + 'sift1M/'

    def get_queries(self):
        return fvecs_read(self.basedir + "sift_query.fvecs")

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return fvecs_read(self.basedir + "sift_learn.fvecs")[:maxtrain]

    def get_database(self):
        return fvecs_read(self.basedir + "sift_base.fvecs")

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.basedir + "sift_groundtruth.ivecs")
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt


def sanitize(x):
    """ convert to a contiguous float32 array, as faiss expects """
    return np.ascontiguousarray(x, dtype='float32')


class DatasetBigANN(Dataset):
    """
    The original dataset is available at http://corpus-texmex.irisa.fr/
    (ANN_SIFT1B)
    """

    def __init__(self, nb_M=1000):
        Dataset.__init__(self)
        assert nb_M in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000)
        self.nb_M = nb_M
        nb = nb_M * 10**6
        self.d, self.nt, self.nb, self.nq = 128, 10**8, nb, 10000
        self.basedir = dataset_basedir + 'bigann/'

    def get_queries(self):
        return sanitize(bvecs_mmap(self.basedir + 'bigann_query.bvecs')[:])

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return sanitize(
            bvecs_mmap(self.basedir + 'bigann_learn.bvecs')[:maxtrain])

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.basedir + 'gnd/idx_%dM.ivecs' % self.nb_M)
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt

    def get_database(self):
        assert self.nb_M < 100, "dataset too large, use iterator"
        return sanitize(
            bvecs_mmap(self.basedir + 'bigann_base.bvecs')[:self.nb])

    def database_iterator(self, bs=128, split=(1, 0)):
        xb = bvecs_mmap(self.basedir + 'bigann_base.bvecs')
        nsplit, rank = split
        i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
        for j0 in range(i0, i1, bs):
            yield sanitize(xb[j0: min(j0 + bs, i1)])
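

# Illustrative sketch, not part of the original module: adding one shard
# of a large dataset to an index via database_iterator, so the base
# vectors are never fully materialized in RAM. Assumes `index` supports
# add_with_ids (e.g. an IndexIVF variant or an IndexIDMap wrapper); the
# helper name `add_shard_to_index` is ours.

def add_shard_to_index(ds, index, nsplit=4, rank=0, bs=65536):
    # id of the first vector of this shard, mirroring the arithmetic
    # inside database_iterator
    i0 = ds.nb * rank // nsplit
    for xb in ds.database_iterator(bs=bs, split=(nsplit, rank)):
        index.add_with_ids(xb, np.arange(i0, i0 + len(xb), dtype='int64'))
        i0 += len(xb)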


class DatasetDeep1B(Dataset):
    """
    See
    https://github.com/facebookresearch/faiss/tree/master/benchs#getting-deep1b
    on how to get the data
    """

    def __init__(self, nb=10**9):
        Dataset.__init__(self)
        nb_to_name = {
            10**5: '100k',
            10**6: '1M',
            10**7: '10M',
            10**8: '100M',
            10**9: '1B'
        }
        assert nb in nb_to_name
        self.d, self.nt, self.nb, self.nq = 96, 358480000, nb, 10000
        self.basedir = dataset_basedir + 'deep1b/'
        self.gt_fname = "%sdeep%s_groundtruth.ivecs" % (
            self.basedir, nb_to_name[self.nb])

    def get_queries(self):
        return sanitize(fvecs_read(self.basedir + "deep1B_queries.fvecs"))

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return sanitize(fvecs_mmap(self.basedir + "learn.fvecs")[:maxtrain])

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.gt_fname)
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt

    def get_database(self):
        assert self.nb <= 10**8, "dataset too large, use iterator"
        return sanitize(fvecs_mmap(self.basedir + "base.fvecs")[:self.nb])

    def database_iterator(self, bs=128, split=(1, 0)):
        xb = fvecs_mmap(self.basedir + "base.fvecs")
        nsplit, rank = split
        i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
        for j0 in range(i0, i1, bs):
            yield sanitize(xb[j0: min(j0 + bs, i1)])


class DatasetGlove(Dataset):
    """
    Data from http://ann-benchmarks.com/glove-100-angular.hdf5
    """

    def __init__(self, loc=None, download=False):
        import h5py
        Dataset.__init__(self)
        assert not download, "not implemented"
        if not loc:
            loc = dataset_basedir + 'glove/glove-100-angular.hdf5'
        self.glove_h5py = h5py.File(loc, 'r')
        # since the vectors are normalized, IP and L2 are equivalent,
        # but this is traditionally seen as an IP dataset
        self.metric = 'IP'
        self.d, self.nt = 100, 0
        self.nb = self.glove_h5py['train'].shape[0]
        self.nq = self.glove_h5py['test'].shape[0]

    def get_queries(self):
        xq = np.array(self.glove_h5py['test'])
        faiss.normalize_L2(xq)
        return xq

    def get_database(self):
        xb = np.array(self.glove_h5py['train'])
        faiss.normalize_L2(xb)
        return xb

    def get_groundtruth(self, k=None):
        gt = self.glove_h5py['neighbors']
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt


class DatasetMusic100(Dataset):
    """
    Get the dataset from
    https://github.com/stanis-morozov/ip-nsw#dataset
    """

    def __init__(self):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = 100, 0, 10**6, 10000
        self.metric = 'IP'
        self.basedir = dataset_basedir + 'music-100/'

    def get_queries(self):
        xq = np.fromfile(self.basedir + 'query_music100.bin', dtype='float32')
        xq = xq.reshape(-1, 100)
        return xq

    def get_database(self):
        xb = np.fromfile(
            self.basedir + 'database_music100.bin', dtype='float32')
        xb = xb.reshape(-1, 100)
        return xb

    def get_groundtruth(self, k=None):
        gt = np.load(self.basedir + 'gt.npy')
        if k is not None:
            assert k <= 100
            gt = gt[:, :k]
        return gt
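

# Illustrative end-to-end sketch, not part of the original module: the
# typical benchmark loop with the datasets above -- train, add, search,
# then measure 1-recall@k against the stored ground truth. The name
# `evaluate_1_recall_at_k` is ours.

def evaluate_1_recall_at_k(ds, index, k=100):
    if ds.nt > 0:
        index.train(sanitize(ds.get_train()))
    index.add(sanitize(ds.get_database()))
    _, I = index.search(sanitize(ds.get_queries()), k)
    gt = ds.get_groundtruth(k=1)
    # fraction of queries whose true nearest neighbor appears in the top k
    return float((I == gt).sum()) / ds.nq

# possible usage, assuming the SIFT1M files are in place:
#   evaluate_1_recall_at_k(DatasetSIFT1M(),
#                          faiss.index_factory(128, "IVF1024,Flat"))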