1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6import os
7import numpy as np
8import faiss
9
10from .vecs_io import fvecs_read, ivecs_read, bvecs_mmap, fvecs_mmap
11from .exhaustive_search import knn
12
class Dataset:
    """Abstract base class for the benchmark datasets."""

    def __init__(self):
        """Subclasses are expected to set all the fields below."""
        self.d = -1          # vector dimensionality
        self.metric = 'L2'   # distance metric, 'L2' or 'IP'
        self.nq = -1         # number of query vectors
        self.nb = -1         # number of database vectors
        self.nt = -1         # number of training vectors

    def get_queries(self):
        """ return the queries as a (nq, d) array """
        raise NotImplementedError()

    def get_train(self, maxtrain=None):
        """ return at most maxtrain training vectors as an (nt, d) array """
        raise NotImplementedError()

    def get_database(self):
        """ return the database vectors as an (nb, d) array """
        raise NotImplementedError()

    def database_iterator(self, bs=128, split=(1, 0)):
        """Iterate over the database vectors, batch by batch.

        bs is the number of vectors per batch.
        split = (nsplit, rank) means the dataset is cut into nsplit
        shards and only shard number rank is visited.
        This default implementation slices the full matrix returned by
        get_database.
        """
        nsplit, rank = split
        lo = self.nb * rank // nsplit
        hi = self.nb * (rank + 1) // nsplit
        xb = self.get_database()
        cursor = lo
        while cursor < hi:
            yield xb[cursor: min(cursor + bs, hi)]
            cursor += bs

    def get_groundtruth(self, k=None):
        """ return the ground truth for k-nearest neighbor search """
        raise NotImplementedError()

    def get_groundtruth_range(self, thresh=None):
        """ return the ground truth for range search """
        raise NotImplementedError()

    def __str__(self):
        return (f"dataset in dimension {self.d}, with metric {self.metric}, "
                f"size: Q {self.nq} B {self.nb} T {self.nt}")

    def check_sizes(self):
        """ calls the accessors above and verifies the advertised sizes """
        d, nq = self.d, self.nq
        assert self.get_queries().shape == (nq, d)
        if self.nt > 0:
            xt = self.get_train(maxtrain=123)
            assert xt.shape == (123, d), "shape=%s" % (xt.shape, )
        assert self.get_database().shape == (self.nb, d)
        assert self.get_groundtruth(k=13).shape == (nq, 13)
70
71
class SyntheticDataset(Dataset):
    """A dataset that is not completely random but still challenging to
    index
    """

    def __init__(self, d, nt, nb, nq, metric='L2'):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = d, nt, nb, nq
        intrinsic_d = 10    # intrinsic dimension (more or less)
        ntotal = nb + nt + nq
        rs = np.random.RandomState(1338)
        points = rs.normal(size=(ntotal, intrinsic_d))
        # project the low-dim points into the target space: this gives an
        # intrinsic_d-dim ellipsoid embedded in d-dimensional space
        points = np.dot(points, rs.rand(intrinsic_d, d))
        # higher factor (>4) -> higher frequency -> less linear
        points = np.sin(points * (rs.rand(d) * 4 + 0.1))
        points = points.astype('float32')
        self.metric = metric
        # carve the generated matrix into train / database / query parts
        self.xt = points[:nt]
        self.xb = points[nt:nt + nb]
        self.xq = points[nt + nb:]

    def get_queries(self):
        return self.xq

    def get_train(self, maxtrain=None):
        if maxtrain is None:
            maxtrain = self.nt
        return self.xt[:maxtrain]

    def get_database(self):
        return self.xb

    def get_groundtruth(self, k=100):
        metric_type = (faiss.METRIC_L2 if self.metric == 'L2'
                       else faiss.METRIC_INNER_PRODUCT)
        return knn(self.xq, self.xb, k, metric_type)[1]
110
111
112############################################################################
113# The following datasets are a few standard open-source datasets
114# they should be stored in a directory, and we start by guessing where
115# that directory is
116############################################################################
117
118
# probe the standard storage locations in order and keep the first one
# that exists; users can link their data directory to `./data` otherwise
_candidate_basedirs = (
    '/datasets01/simsearch/041218/',
    '/mnt/vol/gfsai-flash3-east/ai-group/datasets/simsearch/',
)
dataset_basedir = next(
    (bd for bd in _candidate_basedirs if os.path.exists(bd)),
    'data/',
)
127
128
class DatasetSIFT1M(Dataset):
    """
    The original dataset is available at: http://corpus-texmex.irisa.fr/
    (ANN_SIFT1M)
    """

    def __init__(self):
        Dataset.__init__(self)
        # fixed, well-known sizes of the SIFT1M benchmark
        self.d, self.nt, self.nb, self.nq = 128, 100000, 1000000, 10000
        self.basedir = dataset_basedir + 'sift1M/'

    def get_queries(self):
        return fvecs_read(self.basedir + "sift_query.fvecs")

    def get_train(self, maxtrain=None):
        if maxtrain is None:
            maxtrain = self.nt
        return fvecs_read(self.basedir + "sift_learn.fvecs")[:maxtrain]

    def get_database(self):
        return fvecs_read(self.basedir + "sift_base.fvecs")

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.basedir + "sift_groundtruth.ivecs")
        if k is None:
            return gt
        # the stored ground truth contains at most 100 neighbors per query
        assert k <= 100
        return gt[:, :k]
156
157
def sanitize(x):
    """ convert x to a C-contiguous float32 array (copying only if needed) """
    return np.ascontiguousarray(x, dtype=np.float32)
160
161
class DatasetBigANN(Dataset):
    """
    The original dataset is available at: http://corpus-texmex.irisa.fr/
    (ANN_SIFT1B)
    """

    def __init__(self, nb_M=1000):
        Dataset.__init__(self)
        # database size expressed in millions of vectors
        assert nb_M in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000)
        self.nb_M = nb_M
        self.d, self.nt, self.nb, self.nq = 128, 10**8, nb_M * 10**6, 10000
        self.basedir = dataset_basedir + 'bigann/'

    def get_queries(self):
        return sanitize(bvecs_mmap(self.basedir + 'bigann_query.bvecs')[:])

    def get_train(self, maxtrain=None):
        if maxtrain is None:
            maxtrain = self.nt
        xt = bvecs_mmap(self.basedir + 'bigann_learn.bvecs')
        return sanitize(xt[:maxtrain])

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.basedir + 'gnd/idx_%dM.ivecs' % self.nb_M)
        if k is None:
            return gt
        assert k <= 100
        return gt[:, :k]

    def get_database(self):
        assert self.nb_M < 100, "dataset too large, use iterator"
        return sanitize(bvecs_mmap(self.basedir + 'bigann_base.bvecs')[:self.nb])

    def database_iterator(self, bs=128, split=(1, 0)):
        # memory-maps the base file so that arbitrarily large shards can be
        # streamed; each batch is converted to float32 on the fly
        nsplit, rank = split
        lo = self.nb * rank // nsplit
        hi = self.nb * (rank + 1) // nsplit
        xb = bvecs_mmap(self.basedir + 'bigann_base.bvecs')
        for start in range(lo, hi, bs):
            yield sanitize(xb[start: min(start + bs, hi)])
200
201
class DatasetDeep1B(Dataset):
    """
    See
    https://github.com/facebookresearch/faiss/tree/master/benchs#getting-deep1b
    on how to get the data
    """

    def __init__(self, nb=10**9):
        Dataset.__init__(self)
        # suffix of the ground-truth file for each supported database size
        size_names = {
            10**5: '100k',
            10**6: '1M',
            10**7: '10M',
            10**8: '100M',
            10**9: '1B',
        }
        assert nb in size_names
        self.d, self.nt, self.nb, self.nq = 96, 358480000, nb, 10000
        self.basedir = dataset_basedir + 'deep1b/'
        self.gt_fname = "%sdeep%s_groundtruth.ivecs" % (
            self.basedir, size_names[nb])

    def get_queries(self):
        return sanitize(fvecs_read(self.basedir + "deep1B_queries.fvecs"))

    def get_train(self, maxtrain=None):
        if maxtrain is None:
            maxtrain = self.nt
        return sanitize(fvecs_mmap(self.basedir + "learn.fvecs")[:maxtrain])

    def get_groundtruth(self, k=None):
        gt = ivecs_read(self.gt_fname)
        if k is None:
            return gt
        assert k <= 100
        return gt[:, :k]

    def get_database(self):
        assert self.nb <= 10**8, "dataset too large, use iterator"
        return sanitize(fvecs_mmap(self.basedir + "base.fvecs")[:self.nb])

    def database_iterator(self, bs=128, split=(1, 0)):
        # stream the memory-mapped base file so shards of up to 1B vectors
        # never have to be materialized at once
        nsplit, rank = split
        lo = self.nb * rank // nsplit
        hi = self.nb * (rank + 1) // nsplit
        xb = fvecs_mmap(self.basedir + "base.fvecs")
        for start in range(lo, hi, bs):
            yield sanitize(xb[start: min(start + bs, hi)])
248
249
class DatasetGlove(Dataset):
    """
    Data from http://ann-benchmarks.com/glove-100-angular.hdf5
    """

    def __init__(self, loc=None, download=False):
        # h5py is imported lazily so that the rest of the module works
        # without it
        import h5py
        assert not download, "not implemented"
        # fix: call the base constructor like every other Dataset subclass,
        # so that fields initialized there always have a defined value
        Dataset.__init__(self)
        if not loc:
            loc = dataset_basedir + 'glove/glove-100-angular.hdf5'
        self.glove_h5py = h5py.File(loc, 'r')
        # IP and L2 are equivalent in this case, but it is traditionally seen as an IP dataset
        self.metric = 'IP'
        self.d, self.nt = 100, 0
        self.nb = self.glove_h5py['train'].shape[0]
        self.nq = self.glove_h5py['test'].shape[0]

    def get_queries(self):
        """ return the query vectors, L2-normalized in place """
        xq = np.array(self.glove_h5py['test'])
        faiss.normalize_L2(xq)
        return xq

    def get_database(self):
        """ return the database vectors, L2-normalized in place """
        xb = np.array(self.glove_h5py['train'])
        faiss.normalize_L2(xb)
        return xb

    def get_groundtruth(self, k=None):
        """ return the precomputed neighbors, truncated to k columns if given """
        gt = self.glove_h5py['neighbors']
        if k is not None:
            # the stored ground truth has at most 100 neighbors per query
            assert k <= 100
            gt = gt[:, :k]
        return gt
283
284
class DatasetMusic100(Dataset):
    """
    get dataset from
    https://github.com/stanis-morozov/ip-nsw#dataset
    """

    def __init__(self):
        Dataset.__init__(self)
        self.d, self.nt, self.nb, self.nq = 100, 0, 10**6, 10000
        self.metric = 'IP'
        self.basedir = dataset_basedir + 'music-100/'

    def get_queries(self):
        # queries are stored as a flat float32 binary file, 100 floats/vector
        raw = np.fromfile(self.basedir + 'query_music100.bin', dtype='float32')
        return raw.reshape(-1, 100)

    def get_database(self):
        # same flat float32 layout as the queries
        raw = np.fromfile(self.basedir + 'database_music100.bin', dtype='float32')
        return raw.reshape(-1, 100)

    def get_groundtruth(self, k=None):
        gt = np.load(self.basedir + 'gt.npy')
        if k is None:
            return gt
        assert k <= 100
        return gt[:, :k]