1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6#@nolint
7
8# not linting this file because it imports * from swigfaiss, which
9# causes a ton of useless warnings.
10
11import numpy as np
12import sys
13import inspect
14import array
15import warnings
16
17# We import * so that the symbol foo can be accessed as faiss.foo.
18from .loader import *
19
20
# Human-readable "MAJOR.MINOR.PATCH" string built from the version
# components exported by the compiled extension.
__version__ = ".".join(
    "%d" % part
    for part in (FAISS_VERSION_MAJOR, FAISS_VERSION_MINOR, FAISS_VERSION_PATCH)
)
24
25##################################################################
# The functions below add or replace methods of the SWIG-wrapped classes
# so that numpy arrays can be passed to them directly.
# The original (C++) version of each replaced method remains available
# under the same method name suffixed with _c.
29##################################################################
30
31
def replace_method(the_class, name, replacement, ignore_missing=False):
    """Install ``replacement`` as ``the_class.<name>``, keeping the previous one.

    The method that was there before (presumably the C++/SWIG implementation)
    stays reachable as ``the_class.<name>_c``.

    Parameters
    ----------
    the_class : type
        Class to patch.
    name : str
        Name of the method to replace.
    replacement : function
        New implementation. By convention it is named ``replacement_<name>``;
        that name is how an already-installed wrapper is recognized.
    ignore_missing : bool, optional
        If True, silently do nothing when ``the_class`` has no attribute
        ``name``; otherwise the AttributeError propagates.
    """
    try:
        previous = getattr(the_class, name)
    except AttributeError:
        if not ignore_missing:
            raise
        return

    # When a parent class was patched first, the subclass simply inherits the
    # wrapper; wrapping it again would make <name>_c point at the wrapper.
    if previous.__name__ != f"replacement_{name}":
        setattr(the_class, name + '_c', previous)
        setattr(the_class, name, replacement)
46
def handle_Clustering():
    """Give ``Clustering`` numpy-friendly ``train`` and ``train_encoded`` methods.

    The SWIG originals stay reachable as ``train_c`` and ``train_encoded_c``.
    """

    def replacement_train(self, x, index, weights=None):
        """Perform clustering on a set of vectors. The index is used for assignment.

        Parameters
        ----------
        x : array_like
            Training vectors, shape (n, self.d). `dtype` must be float32.
        index : faiss.Index
            Index used for assignment. Its dimension should be `self.d`.
        weights : array_like, optional
            Per-sample weights (shape (n,)) used when computing the weighted
            average that gives each centroid (default: weight 1 for every
            training vector).
        """
        n, d = x.shape
        assert d == self.d
        if weights is None:
            self.train_c(n, swig_ptr(x), index)
            return
        assert weights.shape == (n, )
        self.train_c(n, swig_ptr(x), index, swig_ptr(weights))

    def replacement_train_encoded(self, x, codec, index, weights=None):
        """Perform clustering on a set of compressed vectors.

        The index is used for assignment; the codes are decoded on the fly
        by `codec`.

        Parameters
        ----------
        x : array_like
            Encoded training vectors, shape (n, codec.sa_code_size()).
            `dtype` must be uint8.
        codec : faiss.Index
            Index used to decode the vectors. Its dimension must equal
            `index.d` (and should be `self.d`).
        index : faiss.Index
            Index used for assignment. Its dimension should be `self.d`.
        weights : array_like, optional
            Per-sample weights (shape (n,)) used when computing the weighted
            average that gives each centroid (default: weight 1 for every
            training vector).
        """
        n, code_size = x.shape
        assert code_size == codec.sa_code_size()
        assert codec.d == index.d
        if weights is None:
            self.train_encoded_c(n, swig_ptr(x), codec, index)
            return
        assert weights.shape == (n, )
        self.train_encoded_c(n, swig_ptr(x), codec, index, swig_ptr(weights))

    replace_method(Clustering, 'train', replacement_train)
    replace_method(Clustering, 'train_encoded', replacement_train_encoded)


handle_Clustering()
99
100
def handle_Quantizer(the_class):
    """Make ``train``, ``compute_codes`` and ``decode`` of a quantizer accept numpy arrays.

    The raw SWIG methods remain available as ``train_c``, ``compute_codes_c``
    and ``decode_c``.
    """

    def replacement_train(self, x):
        """Train the quantizer.

        Parameters
        ----------
        x : array_like
            Training vectors, shape (n, self.d). `dtype` must be float32.
        """
        num_vectors, dim = x.shape
        assert dim == self.d
        self.train_c(num_vectors, swig_ptr(x))

    def replacement_compute_codes(self, x):
        """Encode a set of vectors.

        Parameters
        ----------
        x : array_like
            Vectors to encode, shape (n, self.d). `dtype` must be float32.

        Returns
        -------
        codes : numpy.ndarray
            One code per row, shape (n, self.code_size), `dtype` uint8.
        """
        num_vectors, dim = x.shape
        assert dim == self.d
        out = np.empty((num_vectors, self.code_size), dtype='uint8')
        self.compute_codes_c(swig_ptr(x), swig_ptr(out), num_vectors)
        return out

    def replacement_decode(self, codes):
        """Reconstruct approximate vectors from their codes.

        Parameters
        ----------
        codes : array_like
            Codes to decode, shape (n, self.code_size). `dtype` must be uint8.

        Returns
        -------
        x : numpy.ndarray
            Reconstructed vectors, shape (n, self.d), `dtype` float32.
        """
        num_codes, code_len = codes.shape
        assert code_len == self.code_size
        out = np.empty((num_codes, self.d), dtype='float32')
        self.decode_c(swig_ptr(codes), swig_ptr(out), num_codes)
        return out

    for method_name, wrapper in (
        ('train', replacement_train),
        ('compute_codes', replacement_compute_codes),
        ('decode', replacement_decode),
    ):
        replace_method(the_class, method_name, wrapper)


handle_Quantizer(ProductQuantizer)
handle_Quantizer(ScalarQuantizer)
handle_Quantizer(ResidualQuantizer)
handle_Quantizer(LocalSearchQuantizer)
162
163
def handle_NSG(the_class):
    """Replace ``build`` on an NSG index class with a numpy-aware version."""

    def replacement_build(self, x, graph):
        """Build the index from vectors and a precomputed neighbor graph.

        Parameters
        ----------
        x : array_like
            Vectors, shape (n, self.d).
        graph : array_like
            2-D array with one row per vector (presumably the ids of each
            vector's neighbors); its column count is passed as ``K``.
        """
        n, d = x.shape
        assert d == self.d
        assert graph.ndim == 2
        assert graph.shape[0] == n
        self.build_c(n, swig_ptr(x), swig_ptr(graph), graph.shape[1])

    replace_method(the_class, 'build', replacement_build)
175
176
def handle_Index(the_class):
    """Install numpy-friendly wrappers for the main methods of an Index class.

    Each wrapper checks array shapes, allocates output arrays when the caller
    does not supply them, and forwards raw pointers (via `swig_ptr`) to the
    original SWIG method, which stays available with an `_c` suffix.
    Pickle support (`__getstate__` / `__setstate__`) is installed as well.
    """

    def replacement_add(self, x):
        """Adds vectors to the index.
        The index must be trained before vectors can be added to it.
        The vectors are implicitly numbered in sequence. When `n` vectors are
        added to the index, they are given ids `ntotal`, `ntotal + 1`, ..., `ntotal + n - 1`.

        Parameters
        ----------
        x : array_like
            Vectors to add, shape (n, self.d).
            `dtype` must be float32.
        """

        n, d = x.shape
        assert d == self.d
        self.add_c(n, swig_ptr(x))

    def replacement_add_with_ids(self, x, ids):
        """Adds vectors with arbitrary ids to the index (not all indexes support this).
        The index must be trained before vectors can be added to it.
        Vector `i` is stored in `x[i]` and has id `ids[i]`.

        Parameters
        ----------
        x : array_like
            Vectors to add, shape (n, self.d).
            `dtype` must be float32.
        ids : array_like
            Array of ids of size n. The ids must be of type `int64`. Note that `-1` is reserved
            in result lists to mean "not found" so it's better to not use it as an id.
        """
        n, d = x.shape
        assert d == self.d

        assert ids.shape == (n, ), 'not same nb of vectors as ids'
        self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids))

    def replacement_assign(self, x, k, labels=None):
        """Find the k nearest neighbors of the set of vectors x in the index.
        This is the same as the `search` method, but discards the distances.

        Parameters
        ----------
        x : array_like
            Query vectors, shape (n, self.d).
            `dtype` must be float32.
        k : int
            Number of nearest neighbors.
        labels : array_like, optional
            Pre-allocated int64 array of shape (n, k) to store the results.

        Returns
        -------
        labels: array_like
            Labels of the nearest neighbors, shape (n, k).
            When not enough results are found, the label is set to -1
        """
        n, d = x.shape
        assert d == self.d

        if labels is None:
            labels = np.empty((n, k), dtype=np.int64)
        else:
            assert labels.shape == (n, k)

        self.assign_c(n, swig_ptr(x), swig_ptr(labels), k)
        return labels

    def replacement_train(self, x):
        """Trains the index on a representative set of vectors.
        The index must be trained before vectors can be added to it.

        Parameters
        ----------
        x : array_like
            Training vectors, shape (n, self.d).
            `dtype` must be float32.
        """
        n, d = x.shape
        assert d == self.d
        self.train_c(n, swig_ptr(x))

    def replacement_search(self, x, k, D=None, I=None):
        """Find the k nearest neighbors of the set of vectors x in the index.

        Parameters
        ----------
        x : array_like
            Query vectors, shape (n, self.d).
            `dtype` must be float32.
        k : int
            Number of nearest neighbors (must be > 0).
        D : array_like, optional
            Pre-allocated float32 array of shape (n, k) for the distances.
        I : array_like, optional
            Pre-allocated int64 array of shape (n, k) for the labels.

        Returns
        -------
        D : array_like
            Distances of the nearest neighbors, shape (n, k). When not enough results are found
            the distance is set to +Inf or -Inf.
        I : array_like
            Labels of the nearest neighbors, shape (n, k).
            When not enough results are found, the label is set to -1
        """

        n, d = x.shape
        assert d == self.d

        assert k > 0

        if D is None:
            D = np.empty((n, k), dtype=np.float32)
        else:
            assert D.shape == (n, k)

        if I is None:
            I = np.empty((n, k), dtype=np.int64)
        else:
            assert I.shape == (n, k)

        self.search_c(n, swig_ptr(x), k, swig_ptr(D), swig_ptr(I))
        return D, I

    def replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None):
        """Find the k nearest neighbors of the set of vectors x in the index,
        and return an approximation of these vectors.

        Parameters
        ----------
        x : array_like
            Query vectors, shape (n, self.d).
            `dtype` must be float32.
        k : int
            Number of nearest neighbors (must be > 0).
        D : array_like, optional
            Pre-allocated float32 array of shape (n, k) for the distances.
        I : array_like, optional
            Pre-allocated int64 array of shape (n, k) for the labels.
        R : array_like, optional
            Pre-allocated float32 array of shape (n, k, d) for the
            reconstructed vectors.

        Returns
        -------
        D : array_like
            Distances of the nearest neighbors, shape (n, k). When not enough results are found
            the distance is set to +Inf or -Inf.
        I : array_like
            Labels of the nearest neighbors, shape (n, k). When not enough results are found,
            the label is set to -1
        R : array_like
            Approximate (reconstructed) nearest neighbor vectors, shape (n, k, d).
        """
        n, d = x.shape
        assert d == self.d

        assert k > 0

        if D is None:
            D = np.empty((n, k), dtype=np.float32)
        else:
            assert D.shape == (n, k)

        if I is None:
            I = np.empty((n, k), dtype=np.int64)
        else:
            assert I.shape == (n, k)

        if R is None:
            R = np.empty((n, k, d), dtype=np.float32)
        else:
            assert R.shape == (n, k, d)

        self.search_and_reconstruct_c(n, swig_ptr(x),
                                      k, swig_ptr(D),
                                      swig_ptr(I),
                                      swig_ptr(R))
        return D, I, R

    def replacement_remove_ids(self, x):
        """Remove some ids from the index.
        This is a O(ntotal) operation by default, so could be expensive.

        Parameters
        ----------
        x : array_like or faiss.IDSelector
            Either an IDSelector that returns True for vectors to remove, or a
            1-D sequence of ids to remove (list or integer array; it is
            converted to a contiguous int64 array). A sequence of ids is
            wrapped into an IDSelector.

        Returns
        -------
        n_remove: int
            number of vectors that were removed
        """
        if isinstance(x, IDSelector):
            sel = x
        else:
            # Lists and non-int64 / non-contiguous arrays are converted; an
            # int64 contiguous array is passed through without a copy. `x`
            # stays referenced until remove_ids_c returns, which matters for
            # IDSelectorArray, which is given a raw pointer into it.
            x = np.ascontiguousarray(x, dtype='int64')
            assert x.ndim == 1
            index_ivf = try_extract_index_ivf(self)
            if index_ivf and index_ivf.direct_map.type == DirectMap.Hashtable:
                sel = IDSelectorArray(x.size, swig_ptr(x))
            else:
                sel = IDSelectorBatch(x.size, swig_ptr(x))
        return self.remove_ids_c(sel)

    def replacement_reconstruct(self, key, x=None):
        """Approximate reconstruction of one vector from the index.

        Parameters
        ----------
        key : int
            Id of the vector to reconstruct
        x : array_like, optional
            pre-allocated float32 array of shape (self.d,) to store the result

        Returns
        -------
        x : array_like
            Reconstructed vector, size `self.d`, `dtype`=float32
        """
        if x is None:
            x = np.empty(self.d, dtype=np.float32)
        else:
            assert x.shape == (self.d, )

        self.reconstruct_c(key, swig_ptr(x))
        return x

    def replacement_reconstruct_n(self, n0, ni, x=None):
        """Approximate reconstruction of vectors `n0` ... `n0 + ni - 1` from the index.
        Missing vectors trigger an exception.

        Parameters
        ----------
        n0 : int
            Id of the first vector to reconstruct
        ni : int
            Number of vectors to reconstruct
        x : array_like, optional
            pre-allocated float32 array of shape (ni, self.d) to store the results

        Returns
        -------
        x : array_like
            Reconstructed vectors, size (`ni`, `self.d`), `dtype`=float32
        """
        if x is None:
            x = np.empty((ni, self.d), dtype=np.float32)
        else:
            assert x.shape == (ni, self.d)

        self.reconstruct_n_c(n0, ni, swig_ptr(x))
        return x

    def replacement_update_vectors(self, keys, x):
        """Overwrite the stored vectors with the given ids.

        Only installed on index classes that provide `update_vectors`.

        Parameters
        ----------
        keys : array_like
            Ids of the vectors to update, shape (n,) (presumably int64, like
            the ids elsewhere in this API).
        x : array_like
            New vectors, shape (n, self.d) (presumably float32).
        """
        n = keys.size
        assert keys.shape == (n, )
        assert x.shape == (n, self.d)

        self.update_vectors_c(n, swig_ptr(keys), swig_ptr(x))

    # Unlike search, range_search takes no pre-allocated output buffers: the
    # number of results is only known after the search (read from lims[n]).
    def replacement_range_search(self, x, thresh):
        """Search vectors that are within a distance of the query vectors.

        Parameters
        ----------
        x : array_like
            Query vectors, shape (n, self.d).
            `dtype` must be float32.
        thresh : float
            Threshold to select neighbors. All elements within this radius are returned,
            except for maximum inner product indexes, where the elements above the
            threshold are returned

        Returns
        -------
        lims: array_like
            Starting index of the results for each query vector, size n+1.
        D : array_like
            Distances of the nearest neighbors, shape `lims[n]`. The distances for
            query i are in `D[lims[i]:lims[i+1]]`.
        I : array_like
            Labels of nearest neighbors, shape `lims[n]`. The labels for query i
            are in `I[lims[i]:lims[i+1]]`.

        """
        n, d = x.shape
        assert d == self.d

        res = RangeSearchResult(n)
        self.range_search_c(n, swig_ptr(x), thresh, res)
        # get pointers and copy them
        lims = rev_swig_ptr(res.lims, n + 1).copy()
        nd = int(lims[-1])
        D = rev_swig_ptr(res.distances, nd).copy()
        I = rev_swig_ptr(res.labels, nd).copy()
        return lims, D, I

    def replacement_sa_encode(self, x, codes=None):
        """Encode vectors into codes of `self.sa_code_size()` bytes each.

        Parameters
        ----------
        x : array_like
            Vectors to encode, shape (n, self.d).
        codes : array_like, optional
            Pre-allocated uint8 array of shape (n, self.sa_code_size()).

        Returns
        -------
        codes : array_like
            Codes, shape (n, self.sa_code_size()), `dtype` uint8.
        """
        n, d = x.shape
        assert d == self.d

        if codes is None:
            codes = np.empty((n, self.sa_code_size()), dtype=np.uint8)
        else:
            assert codes.shape == (n, self.sa_code_size())

        self.sa_encode_c(n, swig_ptr(x), swig_ptr(codes))
        return codes

    def replacement_sa_decode(self, codes, x=None):
        """Decode codes produced by `sa_encode` back into vectors.

        Parameters
        ----------
        codes : array_like
            Codes, shape (n, self.sa_code_size()).
        x : array_like, optional
            Pre-allocated float32 array of shape (n, self.d).

        Returns
        -------
        x : array_like
            Decoded vectors, shape (n, self.d), `dtype` float32.
        """
        n, cs = codes.shape
        assert cs == self.sa_code_size()

        if x is None:
            x = np.empty((n, self.d), dtype=np.float32)
        else:
            assert x.shape == (n, self.d)

        self.sa_decode_c(n, swig_ptr(codes), swig_ptr(x))
        return x

    replace_method(the_class, 'add', replacement_add)
    replace_method(the_class, 'add_with_ids', replacement_add_with_ids)
    replace_method(the_class, 'assign', replacement_assign)
    replace_method(the_class, 'train', replacement_train)
    replace_method(the_class, 'search', replacement_search)
    replace_method(the_class, 'remove_ids', replacement_remove_ids)
    replace_method(the_class, 'reconstruct', replacement_reconstruct)
    replace_method(the_class, 'reconstruct_n', replacement_reconstruct_n)
    replace_method(the_class, 'range_search', replacement_range_search)
    replace_method(the_class, 'update_vectors', replacement_update_vectors,
                   ignore_missing=True)
    replace_method(the_class, 'search_and_reconstruct',
                   replacement_search_and_reconstruct, ignore_missing=True)
    replace_method(the_class, 'sa_encode', replacement_sa_encode)
    replace_method(the_class, 'sa_decode', replacement_sa_decode)

    # get/set state for pickle
    # the data is serialized to std::vector -> numpy array -> python bytes
    # so not very efficient for now.

    def index_getstate(self):
        """Pickle support: serialize the whole index to bytes."""
        return {"this": serialize_index(self).tobytes()}

    def index_setstate(self, st):
        """Pickle support: rebuild from bytes and adopt the new C++ object."""
        index2 = deserialize_index(np.frombuffer(st["this"], dtype="uint8"))
        self.this = index2.this

    the_class.__getstate__ = index_getstate
    the_class.__setstate__ = index_setstate
535
536
537
def handle_IndexBinary(the_class):
    """Install numpy-friendly wrappers for the main methods of a binary index class.

    Vectors are handled as arrays with `self.d // 8` columns (the wrappers
    check `d * 8 == self.d`; reconstruct returns uint8), i.e. presumably
    bit-packed codes. The original SWIG methods stay available with an `_c`
    suffix.
    """

    def replacement_add(self, x):
        """Add vectors, shape (n, self.d // 8), to the index."""
        n, d = x.shape
        assert d * 8 == self.d
        self.add_c(n, swig_ptr(x))

    def replacement_add_with_ids(self, x, ids):
        """Add vectors, shape (n, self.d // 8), with explicit ids of shape (n,)."""
        n, d = x.shape
        assert d * 8 == self.d
        assert ids.shape == (n, ), 'not same nb of vectors as ids'
        self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids))

    def replacement_train(self, x):
        """Train the index on vectors of shape (n, self.d // 8)."""
        n, d = x.shape
        assert d * 8 == self.d
        self.train_c(n, swig_ptr(x))

    def replacement_reconstruct(self, key):
        """Return the stored vector with id `key` as a uint8 array of size self.d // 8."""
        x = np.empty(self.d // 8, dtype=np.uint8)
        self.reconstruct_c(key, swig_ptr(x))
        return x

    def replacement_search(self, x, k, D=None, I=None):
        """Find the k nearest neighbors of the vectors x in the index.

        Parameters
        ----------
        x : array_like
            Query vectors, shape (n, self.d // 8).
        k : int
            Number of nearest neighbors (must be > 0).
        D : array_like, optional
            Pre-allocated int32 array of shape (n, k) for the distances.
        I : array_like, optional
            Pre-allocated int64 array of shape (n, k) for the labels.

        Returns
        -------
        D : array_like
            int32 distances, shape (n, k).
        I : array_like
            int64 labels, shape (n, k).
        """
        n, d = x.shape
        assert d * 8 == self.d
        assert k > 0

        if D is None:
            D = np.empty((n, k), dtype=np.int32)
        else:
            assert D.shape == (n, k)

        if I is None:
            I = np.empty((n, k), dtype=np.int64)
        else:
            assert I.shape == (n, k)

        self.search_c(n, swig_ptr(x),
                      k, swig_ptr(D),
                      swig_ptr(I))
        return D, I

    def replacement_range_search(self, x, thresh):
        """Return all neighbors within `thresh` of each query.

        Returns `(lims, D, I)`: the results of query i are
        `D[lims[i]:lims[i+1]]` and `I[lims[i]:lims[i+1]]`.
        """
        n, d = x.shape
        assert d * 8 == self.d
        res = RangeSearchResult(n)
        self.range_search_c(n, swig_ptr(x), thresh, res)
        # get pointers and copy them
        lims = rev_swig_ptr(res.lims, n + 1).copy()
        nd = int(lims[-1])
        D = rev_swig_ptr(res.distances, nd).copy()
        I = rev_swig_ptr(res.labels, nd).copy()
        return lims, D, I

    def replacement_remove_ids(self, x):
        """Remove ids given as an IDSelector or a 1-D sequence of ids.

        Returns the number of removed vectors (as returned by `remove_ids_c`).
        """
        if isinstance(x, IDSelector):
            sel = x
        else:
            # accept lists / other integer arrays; int64 contiguous input is
            # passed through without a copy
            x = np.ascontiguousarray(x, dtype='int64')
            assert x.ndim == 1
            sel = IDSelectorBatch(x.size, swig_ptr(x))
        return self.remove_ids_c(sel)

    replace_method(the_class, 'add', replacement_add)
    replace_method(the_class, 'add_with_ids', replacement_add_with_ids)
    replace_method(the_class, 'train', replacement_train)
    replace_method(the_class, 'search', replacement_search)
    replace_method(the_class, 'range_search', replacement_range_search)
    replace_method(the_class, 'reconstruct', replacement_reconstruct)
    replace_method(the_class, 'remove_ids', replacement_remove_ids)
599
600
def handle_VectorTransform(the_class):
    """Install numpy-friendly ``apply``, ``train`` and ``reverse_transform``.

    The SWIG ``train`` and ``reverse_transform`` remain available as
    ``train_c`` and ``reverse_transform_c``.
    """

    def apply_method(self, x):
        """Transform ``x`` (shape (n, self.d_in)) into a new float32 array of
        shape (n, self.d_out)."""
        n, d = x.shape
        assert d == self.d_in
        y = np.empty((n, self.d_out), dtype=np.float32)
        self.apply_noalloc(n, swig_ptr(x), swig_ptr(y))
        return y

    def replacement_reverse_transform(self, x):
        """Map ``x`` (shape (n, self.d_out)) back to a float32 array of shape
        (n, self.d_in)."""
        n, d = x.shape
        assert d == self.d_out
        y = np.empty((n, self.d_in), dtype=np.float32)
        self.reverse_transform_c(n, swig_ptr(x), swig_ptr(y))
        return y

    # Must be named ``replacement_train``: replace_method recognizes a wrapper
    # inherited from an already-patched parent class by the
    # ``replacement_<method>`` naming convention. With any other name, a
    # subclass visited after its parent would get ``train_c`` set to this
    # wrapper itself, and ``train`` would then call the wrapper with the wrong
    # arguments instead of reaching the C++ implementation.
    def replacement_train(self, x):
        """Train the transform on ``x``, shape (n, self.d_in)."""
        n, d = x.shape
        assert d == self.d_in
        self.train_c(n, swig_ptr(x))

    replace_method(the_class, 'train', replacement_train)
    # Exposed both as ``apply`` and as the ``apply_py`` alias (``apply`` was a
    # builtin in Python 2).
    the_class.apply_py = apply_method
    the_class.apply = apply_method
    replace_method(the_class, 'reverse_transform',
                   replacement_reverse_transform)
628
629
def handle_AutoTuneCriterion(the_class):
    """Install numpy-friendly ``set_groundtruth`` and ``evaluate`` on ``the_class``.

    The SWIG originals remain available as ``set_groundtruth_c`` and
    ``evaluate_c``.
    """

    def replacement_set_groundtruth(self, D, I):
        """Set the ground-truth search results the criterion compares against.

        Parameters
        ----------
        D : array_like or None
            Ground-truth distances, same shape as `I`, or None when no
            distances are provided.
        I : array_like
            Ground-truth labels, shape (nq, gt_nnn). ``self.nq`` and
            ``self.gt_nnn`` are set from this shape.
        """
        # Compare against None explicitly: the truth value of a numpy array
        # with more than one element is ambiguous (raises ValueError), and a
        # 1x1 array holding 0 would otherwise be mistaken for "no distances".
        has_distances = D is not None
        if has_distances:
            assert I.shape == D.shape
        self.nq, self.gt_nnn = I.shape
        self.set_groundtruth_c(
            self.gt_nnn, swig_ptr(D) if has_distances else None, swig_ptr(I))

    def replacement_evaluate(self, D, I):
        """Score search results ``D``/``I``, both of shape (self.nq, self.nnn).

        Returns whatever ``evaluate_c`` returns.
        """
        assert I.shape == D.shape
        assert I.shape == (self.nq, self.nnn)
        return self.evaluate_c(swig_ptr(D), swig_ptr(I))

    replace_method(the_class, 'set_groundtruth', replacement_set_groundtruth)
    replace_method(the_class, 'evaluate', replacement_evaluate)
645
646
def handle_ParameterSpace(the_class):
    """Give ``explore`` on ``the_class`` a numpy-friendly signature."""

    def replacement_explore(self, index, xq, crit):
        """Explore operating points of ``index`` using the queries ``xq``.

        ``xq`` must have shape (crit.nq, index.d). Returns a new
        ``OperatingPoints`` object filled in by ``explore_c``.
        """
        assert xq.shape == (crit.nq, index.d)
        points = OperatingPoints()
        self.explore_c(index, crit.nq, swig_ptr(xq), crit, points)
        return points

    replace_method(the_class, 'explore', replacement_explore)
655
656
def handle_MatrixStats(the_class):
    """Let ``the_class`` be constructed directly from a 2-D array."""
    swig_init = the_class.__init__

    def replacement_init(self, m):
        # The SWIG constructor takes (rows, cols, data pointer); all three
        # are derived from the array.
        assert len(m.shape) == 2
        n_rows, n_cols = m.shape
        swig_init(self, n_rows, n_cols, swig_ptr(m))

    the_class.__init__ = replacement_init

handle_MatrixStats(MatrixStats)
667
def handle_IOWriter(the_class):
    """Add a ``write_bytes`` convenience method to ``the_class``."""

    def write_bytes(self, b):
        """Write buffer ``b`` via ``self(ptr, 1, len(b))`` and return its result."""
        ptr = swig_ptr(b)
        return self(ptr, 1, len(b))

    the_class.write_bytes = write_bytes

handle_IOWriter(IOWriter)
676
def handle_IOReader(the_class):
    """Add a ``read_bytes`` convenience method to ``the_class``."""

    def read_bytes(self, totsz):
        """Read up to ``totsz`` bytes; only the bytes actually read are returned."""
        buf = bytearray(totsz)
        n_read = self(swig_ptr(buf), 1, len(buf))
        return bytes(memoryview(buf)[:n_read])

    the_class.read_bytes = read_bytes

handle_IOReader(IOReader)
687
# Handle to this module: used below to enumerate its classes, and later by
# add_ref_in_function to look up and replace module-level functions.
this_module = sys.modules[__name__]


# Patch every class exposed by this module according to the base class(es)
# it derives from. The checks are independent `if`s rather than an `elif`
# chain, so a class matching several bases is passed to every matching
# handler. dir() lists names alphabetically, so a subclass may be visited
# before or after its parent; replace_method() skips a method that is
# already a `replacement_<name>` wrapper inherited from a patched parent.
for symbol in dir(this_module):
    obj = getattr(this_module, symbol)
    # only classes are patched; functions, constants and modules are skipped
    if inspect.isclass(obj):
        the_class = obj
        if issubclass(the_class, Index):
            handle_Index(the_class)

        if issubclass(the_class, IndexBinary):
            handle_IndexBinary(the_class)

        if issubclass(the_class, VectorTransform):
            handle_VectorTransform(the_class)

        if issubclass(the_class, AutoTuneCriterion):
            handle_AutoTuneCriterion(the_class)

        if issubclass(the_class, ParameterSpace):
            handle_ParameterSpace(the_class)

        if issubclass(the_class, IndexNSG):
            handle_NSG(the_class)
713
714###########################################
715# Utility to add a deprecation warning to
716# classes from the SWIG interface
717###########################################
718
def _make_deprecated_swig_class(deprecated_name, base_name):
    """
    Dynamically construct deprecated classes as wrappers around renamed ones

    The deprecation warning emitted in their __new__-method is triggered upon
    construction of an instance of the class (how often it is shown depends
    on the active warnings filters).

    We do this here (in __init__.py) because the base classes are defined in
    the SWIG interface, making it cumbersome to add the deprecation there.

    Parameters
    ----------
    deprecated_name : string
        Name of the class to be deprecated; _not_ present in SWIG interface.
    base_name : string
        Name of the class that is replacing deprecated_name; must already be
        imported into the current namespace.

    Returns
    -------
    None
        However, the deprecated class gets added to the faiss namespace
    """
    base_class = globals()[base_name]

    def new_meth(cls, *args, **kwargs):
        msg = f"The class faiss.{deprecated_name} is deprecated in favour of faiss.{base_name}!"
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        # Look up the __new__ that the deprecated class would inherit (i.e.
        # the one base_class uses). object.__new__ rejects extra arguments
        # when __new__ is overridden, as it is here, so constructor arguments
        # are only forwarded to a custom __new__; __init__ still receives
        # them through the normal instantiation protocol.
        parent_new = super(klazz, cls).__new__
        if parent_new is object.__new__:
            return parent_new(cls)
        return parent_new(cls, *args, **kwargs)

    # three-argument version of "type" uses (name, tuple-of-bases, dict-of-attributes)
    klazz = type(deprecated_name, (base_class,), {"__new__": new_meth})

    # this ends up adding the class to the "faiss" namespace, in a way that it
    # is available both through "import faiss" and "from faiss import *"
    globals()[deprecated_name] = klazz
755
756###########################################
757# Add Python references to objects
758# we do this at the Python class wrapper level.
759###########################################
760
def add_ref_in_constructor(the_class, parameter_no):
    """Make ``the_class``'s constructor keep Python references to some arguments.

    The kept arguments are stored in ``self.referenced_objects`` so that they
    are not deallocated before ``self``.

    Parameters
    ----------
    the_class : type
        Class whose ``__init__`` is wrapped.
    parameter_no : int or dict
        If an int: index of the positional constructor argument to keep.
        If a dict: maps the number of positional arguments passed to the
        constructor to the list of argument indices to keep.
        A call with too few arguments for the int index, or with an argument
        count absent from the dict (e.g. a default constructor), keeps no
        reference instead of raising after the object was built.
    """
    original_init = the_class.__init__

    def replacement_init(self, *args):
        original_init(self, *args)
        try:
            kept = [args[parameter_no]]
        except IndexError:
            # constructor called without that argument: nothing to keep alive
            kept = []
        self.referenced_objects = kept

    def replacement_init_multiple(self, *args):
        original_init(self, *args)
        kept_indices = parameter_no.get(len(args), ())
        self.referenced_objects = [args[no] for no in kept_indices]

    if isinstance(parameter_no, dict):
        # a list of parameters to keep, depending on the number of arguments
        the_class.__init__ = replacement_init_multiple
    else:
        the_class.__init__ = replacement_init
780
781
def add_ref_in_method(the_class, method_name, parameter_no):
    """Make ``the_class.<method_name>`` keep a reference to one of its arguments.

    Every call appends positional argument ``parameter_no`` to
    ``self.referenced_objects`` (creating the list on first use) before
    delegating to the original method, whose result is returned.
    """
    wrapped = getattr(the_class, method_name)

    def replacement_method(self, *args):
        kept = args[parameter_no]
        if hasattr(self, 'referenced_objects'):
            self.referenced_objects.append(kept)
        else:
            self.referenced_objects = [kept]
        return wrapped(self, *args)

    setattr(the_class, method_name, replacement_method)
792
def add_ref_in_function(function_name, parameter_no):
    """Wrap module function ``function_name`` so its result references an argument.

    After each call, ``result.referenced_objects`` is set to a one-element
    list holding positional argument ``parameter_no``. The wrapped function
    must therefore return an object that accepts new attributes.
    """
    wrapped = getattr(this_module, function_name)

    def replacement_function(*args):
        result = wrapped(*args)
        result.referenced_objects = [args[parameter_no]]
        return result

    setattr(this_module, function_name, replacement_function)
802
# Keep Python references to objects handed to these constructors / methods,
# so they are not deallocated before the object that uses them.
#   add_ref_in_constructor(cls, i): keep positional constructor argument i;
#     a dict form maps the number of positional arguments to the indices
#     to keep.
#   add_ref_in_method(cls, name, i): keep argument i of every call to `name`.
add_ref_in_constructor(IndexIVFFlat, 0)
add_ref_in_constructor(IndexIVFFlatDedup, 0)
add_ref_in_constructor(IndexPreTransform, {2: [0, 1], 1: [0]})
add_ref_in_method(IndexPreTransform, 'prepend_transform', 0)
add_ref_in_constructor(IndexIVFPQ, 0)
add_ref_in_constructor(IndexIVFPQR, 0)
add_ref_in_constructor(IndexIVFPQFastScan, 0)
add_ref_in_constructor(Index2Layer, 0)
add_ref_in_constructor(Level1Quantizer, 0)
add_ref_in_constructor(IndexIVFScalarQuantizer, 0)
add_ref_in_constructor(IndexIDMap, 0)
add_ref_in_constructor(IndexIDMap2, 0)
add_ref_in_constructor(IndexHNSW, 0)
add_ref_in_method(IndexShards, 'add_shard', 0)
add_ref_in_method(IndexBinaryShards, 'add_shard', 0)
add_ref_in_constructor(IndexRefineFlat, {2:[0], 1:[0]})
add_ref_in_constructor(IndexRefine, {2:[0, 1]})

# binary-index classes
add_ref_in_constructor(IndexBinaryIVF, 0)
add_ref_in_constructor(IndexBinaryFromFloat, 0)
add_ref_in_constructor(IndexBinaryIDMap, 0)
add_ref_in_constructor(IndexBinaryIDMap2, 0)

# replicas keep every index added to them
add_ref_in_method(IndexReplicas, 'addIndex', 0)
add_ref_in_method(IndexBinaryReplicas, 'addIndex', 0)

# buffered I/O wrappers keep the wrapped reader/writer
add_ref_in_constructor(BufferedIOWriter, 0)
add_ref_in_constructor(BufferedIOReader, 0)

# Dropping the reference again in IndexReplicas.removeIndex was considered
# too marginal to implement:
# remove_ref_from_method(IndexReplicas, 'removeIndex', 0)
834
835###########################################
836# GPU functions
837###########################################
838
839
def index_cpu_to_gpu_multiple_py(resources, index, co=None, gpus=None):
    """Clone ``index`` onto several GPUs from Python lists.

    Builds the C++ vectors of GPU resources and device ids expected by
    ``index_cpu_to_gpu_multiple``. Resource ``resources[j]`` is paired with
    device ``gpus[j]``; when ``gpus`` is None, devices ``0 .. len(resources)-1``
    are used. Pairing stops at the shorter of the two sequences.
    ``co`` is forwarded unchanged.
    """
    device_ids = range(len(resources)) if gpus is None else gpus
    res_vector = GpuResourcesVector()
    dev_vector = Int32Vector()
    for device, resource in zip(device_ids, resources):
        dev_vector.push_back(device)
        res_vector.push_back(resource)
    return index_cpu_to_gpu_multiple(res_vector, dev_vector, index, co)
853
854
def index_cpu_to_all_gpus(index, co=None, ngpu=-1):
    """Clone ``index`` to the first ``ngpu`` GPUs (all of them when ngpu is -1).

    Thin wrapper around ``index_cpu_to_gpus_list`` with ``gpus=None``.
    """
    return index_cpu_to_gpus_list(index, co=co, gpus=None, ngpu=ngpu)
858
859
def index_cpu_to_gpus_list(index, co=None, gpus=None, ngpu=-1):
    """Clone ``index`` to a list of GPUs, with one StandardGpuResources per GPU.

    Parameters
    ----------
    index : faiss.Index
        CPU index to clone.
    co : optional
        Cloner options, forwarded unchanged.
    gpus : list of int or None
        Explicit GPU ids. When given, ``ngpu`` is ignored.
    ngpu : int
        When ``gpus`` is None, use GPUs ``0 .. ngpu-1``; the default -1 means
        every GPU reported by ``get_num_gpus()``.
    """
    if gpus is None:
        gpus = range(get_num_gpus() if ngpu == -1 else ngpu)
    resources = [StandardGpuResources() for _ in gpus]
    return index_cpu_to_gpu_multiple_py(resources, index, co, gpus)
870
871# allows numpy ndarray usage with bfKnn
def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2):
    """
    Compute the k nearest neighbors of a vector on one GPU without constructing an index

    Parameters
    ----------
    res : StandardGpuResources
        GPU resources to use during computation
    xq : array_like
        Query vectors, shape (nq, d). `dtype` must be float32 or float16.
        C-contiguous and Fortran-contiguous arrays are used without a copy;
        any other memory layout is first copied to a C-contiguous array.
    xb : array_like
        Database vectors, shape (nb, d) with the same d as `xq`. `dtype`
        must be float32 or float16. Same layout handling as `xq`.
    k : int
        Number of nearest neighbors.
    D : array_like, optional
        Output array for distances of the nearest neighbors, shape (nq, k),
        `dtype` float32. Allocated when not provided.
    I : array_like, optional
        Output array for the nearest neighbors, shape (nq, k), `dtype` int64
        or int32. Allocated as int64 when not provided.
    metric : MetricType, optional
        distance measure to use (either METRIC_L2 or METRIC_INNER_PRODUCT)

    Returns
    -------
    D : array_like
        Distances of the nearest neighbors, shape (nq, k)
    I : array_like
        Labels of the nearest neighbors, shape (nq, k)

    Raises
    ------
    TypeError
        if `xq`, `xb` or `I` has an unsupported dtype
    """
    nq, d = xq.shape
    if xq.flags.c_contiguous:
        xq_row_major = True
    elif xq.flags.f_contiguous:
        # the transpose of a Fortran-ordered matrix is C-contiguous, so it can
        # be passed without a copy; bfKnn is told the data is column-major
        xq = xq.T
        xq_row_major = False
    else:
        # arbitrarily strided view (e.g. a column slice): make a compact
        # row-major copy; the local name keeps it alive until bfKnn returns
        xq = np.ascontiguousarray(xq)
        xq_row_major = True

    xq_ptr = swig_ptr(xq)

    if xq.dtype == np.float32:
        xq_type = DistanceDataType_F32
    elif xq.dtype == np.float16:
        xq_type = DistanceDataType_F16
    else:
        raise TypeError('xq must be f32 or f16')

    nb, d2 = xb.shape
    assert d2 == d, 'xq and xb must have the same dimension'
    if xb.flags.c_contiguous:
        xb_row_major = True
    elif xb.flags.f_contiguous:
        # same transpose trick as for xq
        xb = xb.T
        xb_row_major = False
    else:
        xb = np.ascontiguousarray(xb)
        xb_row_major = True

    xb_ptr = swig_ptr(xb)

    if xb.dtype == np.float32:
        xb_type = DistanceDataType_F32
    elif xb.dtype == np.float16:
        xb_type = DistanceDataType_F16
    else:
        raise TypeError('xb must be float32 or float16')

    if D is None:
        D = np.empty((nq, k), dtype=np.float32)
    else:
        assert D.shape == (nq, k)
        # interface takes void*, we need to check this
        assert D.dtype == np.float32

    D_ptr = swig_ptr(D)

    if I is None:
        I = np.empty((nq, k), dtype=np.int64)
    else:
        assert I.shape == (nq, k)

    I_ptr = swig_ptr(I)

    if I.dtype == np.int64:
        I_type = IndicesDataType_I64
    elif I.dtype == np.int32:
        I_type = IndicesDataType_I32
    else:
        raise TypeError('I must be i64 or i32')

    args = GpuDistanceParams()
    args.metric = metric
    args.k = k
    args.dims = d
    args.vectors = xb_ptr
    args.vectorsRowMajor = xb_row_major
    args.vectorType = xb_type
    args.numVectors = nb
    args.queries = xq_ptr
    args.queriesRowMajor = xq_row_major
    args.queryType = xq_type
    args.numQueries = nq
    args.outDistances = D_ptr
    args.outIndices = I_ptr
    args.outIndicesType = I_type

    # no stream synchronization needed, inputs and outputs are guaranteed to
    # be on the CPU (numpy arrays)
    bfKnn(res, args)

    return D, I
983
984# allows numpy ndarray usage with bfKnn for all pairwise distances
def pairwise_distance_gpu(res, xq, xb, D=None, metric=METRIC_L2):
    """
    Compute all pairwise distances between xq and xb on one GPU without constructing an index

    Parameters
    ----------
    res : StandardGpuResources
        GPU resources to use during computation
    xq : array_like
        Query vectors, shape (nq, d). `dtype` must be float32 or float16.
        C-contiguous and Fortran-contiguous arrays are used without a copy;
        any other memory layout is first copied to a C-contiguous array.
    xb : array_like
        Database vectors, shape (nb, d) with the same d as `xq`. `dtype`
        must be float32 or float16. Same layout handling as `xq`.
    D : array_like, optional
        Output array for all pairwise distances, shape (nq, nb), `dtype`
        float32. Allocated when not provided.
    metric : MetricType, optional
        distance measure to use (either METRIC_L2 or METRIC_INNER_PRODUCT)

    Returns
    -------
    D : array_like
        All pairwise distances, shape (nq, nb)

    Raises
    ------
    TypeError
        if `xq` or `xb` has an unsupported dtype
    """
    nq, d = xq.shape
    if xq.flags.c_contiguous:
        xq_row_major = True
    elif xq.flags.f_contiguous:
        # the transpose of a Fortran-ordered matrix is C-contiguous, so it can
        # be passed without a copy; bfKnn is told the data is column-major
        xq = xq.T
        xq_row_major = False
    else:
        # arbitrarily strided view: make a compact row-major copy; the local
        # name keeps it alive until bfKnn returns
        xq = np.ascontiguousarray(xq)
        xq_row_major = True

    xq_ptr = swig_ptr(xq)

    if xq.dtype == np.float32:
        xq_type = DistanceDataType_F32
    elif xq.dtype == np.float16:
        xq_type = DistanceDataType_F16
    else:
        raise TypeError('xq must be float32 or float16')

    nb, d2 = xb.shape
    assert d2 == d, 'xq and xb must have the same dimension'
    if xb.flags.c_contiguous:
        xb_row_major = True
    elif xb.flags.f_contiguous:
        # same transpose trick as for xq
        xb = xb.T
        xb_row_major = False
    else:
        xb = np.ascontiguousarray(xb)
        xb_row_major = True

    xb_ptr = swig_ptr(xb)

    if xb.dtype == np.float32:
        xb_type = DistanceDataType_F32
    elif xb.dtype == np.float16:
        xb_type = DistanceDataType_F16
    else:
        raise TypeError('xb must be float32 or float16')

    if D is None:
        D = np.empty((nq, nb), dtype=np.float32)
    else:
        assert D.shape == (nq, nb)
        # interface takes void*, we need to check this
        assert D.dtype == np.float32

    D_ptr = swig_ptr(D)

    args = GpuDistanceParams()
    args.metric = metric
    args.k = -1 # selects all pairwise distances
    args.dims = d
    args.vectors = xb_ptr
    args.vectorsRowMajor = xb_row_major
    args.vectorType = xb_type
    args.numVectors = nb
    args.queries = xq_ptr
    args.queriesRowMajor = xq_row_major
    args.queryType = xq_type
    args.numQueries = nq
    args.outDistances = D_ptr

    # no stream synchronization needed, inputs and outputs are guaranteed to
    # be on the CPU (numpy arrays)
    bfKnn(res, args)

    return D
1074
1075
1076###########################################
1077# numpy array / std::vector conversions
1078###########################################
1079
# Size in bytes of a C `long` on this platform. It varies between platforms,
# so the deprecated 'Long' prefix below aliases Int32 when it is 4 and Int64
# otherwise.
sizeof_long = array.array('l').itemsize
# Deprecated SWIG vector class-name prefixes -> their fixed-width replacements.
deprecated_name_map = {
    # deprecated: replacement
    'Float': 'Float32',
    'Double': 'Float64',
    'Char': 'Int8',
    'Int': 'Int32',
    'Long': 'Int32' if sizeof_long == 4 else 'Int64',
    'LongLong': 'Int64',
    'Byte': 'UInt8',
    # previously misspelled variant
    'Uint64': 'UInt64',
}

# NOTE(review): _make_deprecated_swig_class is defined outside this chunk;
# presumably it exposes the deprecated class name as an alias of the
# replacement class (with a deprecation warning) — confirm.
# The loop variables depr_prefix / base_prefix remain as module globals.
for depr_prefix, base_prefix in deprecated_name_map.items():
    _make_deprecated_swig_class(depr_prefix + "Vector", base_prefix + "Vector")

    # same for the three legacy *VectorVector classes
    if depr_prefix in ['Float', 'Long', 'Byte']:
        _make_deprecated_swig_class(depr_prefix + "VectorVector",
                                    base_prefix + "VectorVector")

# mapping from the vector class-name prefixes in swigfaiss.swig (the part of
# the class name before "Vector") to numpy dtype names; each deprecated prefix
# maps to the dtype of its replacement (e.g. 'Float' -> 'float32')
# TODO: once deprecated classes are removed, remove the dict and just use .lower() below
vector_name_map = {
    'Float32': 'float32',
    'Float64': 'float64',
    'Int8': 'int8',
    'Int16': 'int16',
    'Int32': 'int32',
    'Int64': 'int64',
    'UInt8': 'uint8',
    'UInt16': 'uint16',
    'UInt32': 'uint32',
    'UInt64': 'uint64',
    **{k: v.lower() for k, v in deprecated_name_map.items()}
}
1117
1118
def vector_to_array(v):
    """Return a numpy copy of the contents of the C++ vector wrapper `v`.

    The element dtype is looked up in `vector_name_map` from the wrapper's
    class name with its trailing "Vector" removed (e.g. Float32Vector ->
    float32).
    """
    cls_name = v.__class__.__name__
    assert cls_name.endswith('Vector')
    prefix = cls_name[:-len('Vector')]
    elem_dtype = np.dtype(vector_name_map[prefix])
    out = np.empty(v.size(), dtype=elem_dtype)
    if v.size() > 0:
        # raw byte copy from the C++ buffer into the freshly allocated array
        memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out
1128
1129
def vector_float_to_array(v):
    """Legacy name kept for compatibility; forwards to `vector_to_array`,
    so it accepts any supported vector type, not only float ones."""
    converted = vector_to_array(v)
    return converted
1132
1133
def copy_array_to_vector(a, v):
    """Overwrite the C++ vector wrapper `v` with the 1-D numpy array `a`.

    `a.dtype` must equal the element dtype implied by the class name of `v`
    (looked up in `vector_name_map`). `v` is resized to len(a) first.
    """
    (nelem,) = a.shape
    cls_name = v.__class__.__name__
    assert cls_name.endswith('Vector')
    expected = np.dtype(vector_name_map[cls_name[:-len('Vector')]])
    assert expected == a.dtype, (
        'cannot copy a %s array to a %s (should be %s)' % (
            a.dtype, cls_name, expected))
    v.resize(nelem)
    if nelem > 0:
        memcpy(v.data(), swig_ptr(a), a.nbytes)
1146
# numpy array <-> AlignedTable conversions (counterparts of the std::vector helpers above)
1148
def copy_array_to_AlignedTable(a, v):
    """Overwrite the AlignedTable wrapper `v` with the 1-D numpy array `a`.

    Only the element size of `v` is checked against `a`, not the element
    type. `v` is resized to len(a) first.
    """
    (nelem,) = a.shape
    # TODO check class name
    assert v.itemsize() == a.itemsize
    v.resize(nelem)
    if nelem:
        memcpy(v.get(), swig_ptr(a), a.nbytes)
1156
def array_to_AlignedTable(a):
    """Return a new AlignedTable holding a copy of the 1-D numpy array `a`.

    Supported dtypes are uint16 (AlignedTableUint16) and uint8
    (AlignedTableUint8).

    Raises
    ------
    AssertionError
        if `a.dtype` is not supported. It is raised explicitly rather than
        through an `assert` statement so the check still runs under
        ``python -O``; the exception type is kept for existing callers.
    """
    if a.dtype == 'uint16':
        v = AlignedTableUint16(a.size)
    elif a.dtype == 'uint8':
        v = AlignedTableUint8(a.size)
    else:
        raise AssertionError(
            'unsupported dtype %s for AlignedTable (expected uint16 or uint8)'
            % a.dtype)
    copy_array_to_AlignedTable(a, v)
    return v
1166
def AlignedTable_to_array(v):
    """Return a numpy copy of the contents of the AlignedTable wrapper `v`.

    The dtype is the class-name suffix after "AlignedTable", lower-cased
    (e.g. AlignedTableUint8 -> 'uint8').
    """
    prefix = 'AlignedTable'
    cls_name = v.__class__.__name__
    assert cls_name.startswith(prefix)
    elem_dtype = cls_name[len(prefix):].lower()
    out = np.empty(v.size(), dtype=elem_dtype)
    if out.size > 0:
        memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out
1176
1177###########################################
1178# Wrapper for a few functions
1179###########################################
1180
def kmin(array, k):
    """Return the k smallest values of each row of the 2-D float32 `array`
    together with their column indices, as (D, I) arrays of shape (m, k).

    One max-heap per row keeps the k smallest values seen; the heaps are
    reordered before returning. (The parameter name shadows the `array`
    module inside this function; it is kept for keyword callers.)
    """
    nrow, ncol = array.shape
    labels = np.zeros((nrow, k), dtype='int64')
    values = np.zeros((nrow, k), dtype='float32')
    heap = float_maxheap_array_t()
    heap.ids, heap.val = swig_ptr(labels), swig_ptr(values)
    heap.nh, heap.k = nrow, k
    heap.heapify()
    heap.addn(ncol, swig_ptr(array))
    heap.reorder()
    return values, labels
1196
1197
def kmax(array, k):
    """Return the k largest values of each row of the 2-D float32 `array`
    together with their column indices, as (D, I) arrays of shape (m, k).

    One min-heap per row keeps the k largest values seen; the heaps are
    reordered before returning. (The parameter name shadows the `array`
    module inside this function; it is kept for keyword callers.)
    """
    nrow, ncol = array.shape
    labels = np.zeros((nrow, k), dtype='int64')
    values = np.zeros((nrow, k), dtype='float32')
    heap = float_minheap_array_t()
    heap.ids, heap.val = swig_ptr(labels), swig_ptr(values)
    heap.nh, heap.k = nrow, k
    heap.heapify()
    heap.addn(ncol, swig_ptr(array))
    heap.reorder()
    return values, labels
1213
1214
def pairwise_distances(xq, xb, mt=METRIC_L2, metric_arg=0):
    """Return the full (nq, nb) float32 matrix of distances between the rows
    of `xq` and the rows of `xb`.

    METRIC_L2 is computed by pairwise_L2sqr (squared L2); every other metric
    is delegated to pairwise_extra_distances together with `metric_arg`.
    """
    nq, d = xq.shape
    nb, d2 = xb.shape
    assert d == d2
    out = np.empty((nq, nb), dtype='float32')
    if mt != METRIC_L2:
        pairwise_extra_distances(d, nq, swig_ptr(xq), nb, swig_ptr(xb),
                                 mt, metric_arg, swig_ptr(out))
    else:
        pairwise_L2sqr(d, nq, swig_ptr(xq), nb, swig_ptr(xb), swig_ptr(out))
    return out
1234
1235
1236
1237
def rand(n, seed=12345):
    """Return a float32 array of shape `n` (an int or a shape tuple) filled
    by faiss' float_rand generator seeded with `seed`."""
    out = np.empty(n, dtype='float32')
    float_rand(swig_ptr(out), out.size, seed)
    return out
1242
1243
def randint(n, seed=12345, vmax=None):
    """Return an int64 array of shape `n` (an int or a shape tuple) of
    pseudo-random values seeded with `seed`.

    Values come from int64_rand, or from int64_rand_max bounded by `vmax`
    when `vmax` is given.
    """
    out = np.empty(n, dtype='int64')
    if vmax is not None:
        int64_rand_max(swig_ptr(out), out.size, vmax, seed)
    else:
        int64_rand(swig_ptr(out), out.size, seed)
    return out

lrand = randint
1253
def randn(n, seed=12345):
    """Return a float32 array of shape `n` (an int or a shape tuple) filled
    by faiss' float_randn generator seeded with `seed`."""
    out = np.empty(n, dtype='float32')
    float_randn(swig_ptr(out), out.size, seed)
    return out
1258
1259
def eval_intersection(I1, I2):
    """Sum, over all rows i, of the intersection size between row i of `I1`
    and row i of `I2` (as computed by ranklist_intersection_size).

    Both result tables must have the same number of rows; their widths may
    differ.
    """
    nrow = I1.shape[0]
    assert I2.shape[0] == nrow
    k1 = I1.shape[1]
    k2 = I2.shape[1]
    return sum(
        ranklist_intersection_size(k1, swig_ptr(I1[row]),
                                   k2, swig_ptr(I2[row]))
        for row in range(nrow)
    )
1270
1271
def normalize_L2(x):
    """Renormalize the rows of the 2-D array `x` in place via fvec_renorm_L2
    (presumably to unit L2 norm). Returns None."""
    dim = x.shape[1]
    nrow = x.shape[0]
    fvec_renorm_L2(dim, nrow, swig_ptr(x))
1274
1275######################################################
1276# MapLong2Long interface
1277######################################################
1278
def replacement_map_add(self, keys, vals):
    """Add the pairs (keys[i], vals[i]) to the MapLong2Long `self`.

    `keys` and `vals` must be 1-D arrays of the same length. add_c only
    receives the count and two raw pointers, so a shorter `vals` would be
    read out of bounds; the shape check below prevents that.
    """
    n, = keys.shape
    assert vals.shape == (n,), (
        'keys and vals must have the same 1-D shape, got %s and %s' % (
            keys.shape, vals.shape))
    self.add_c(n, swig_ptr(keys), swig_ptr(vals))
1283
def replacement_map_search_multiple(self, keys):
    """Look up every key of the 1-D array `keys` in the MapLong2Long `self`;
    return the results as a new int64 array of the same length."""
    (nkeys,) = keys.shape
    found = np.empty(nkeys, dtype='int64')
    self.search_multiple_c(nkeys, swig_ptr(keys), swig_ptr(found))
    return found
1289
# Install the numpy-friendly wrappers on MapLong2Long; replace_method keeps the
# original SWIG methods reachable as add_c / search_multiple_c.
replace_method(MapLong2Long, 'add', replacement_map_add)
replace_method(MapLong2Long, 'search_multiple', replacement_map_search_multiple)
1292
1293######################################################
1294# search_with_parameters interface
1295######################################################
1296
search_with_parameters_c = search_with_parameters

def search_with_parameters(index, x, k, params=None, output_stats=False):
    """Search `index` for the `k` nearest neighbors of the rows of `x`
    using explicit search parameters.

    When `params` is not given, an IVFSearchParameters is filled with the
    nprobe / max_codes of the IVF index returned by extract_index_ivf(index).

    Returns (distances, labels), both of shape (n, k). With
    `output_stats=True` a third value is returned: a dict holding the
    distance count reported by the C++ call ('ndis') and the per-stage
    timings in milliseconds.
    """
    n, d = x.shape
    assert d == index.d
    if not params:
        # fall back to the search-time settings stored in the IVF index
        params = IVFSearchParameters()
        ivf = extract_index_ivf(index)
        params.nprobe = ivf.nprobe
        params.max_codes = ivf.max_codes
    nb_dis = np.empty(1, 'uint64')
    ms_per_stage = np.empty(3, 'float64')
    distances = np.empty((n, k), dtype=np.float32)
    labels = np.empty((n, k), dtype=np.int64)
    search_with_parameters_c(
        index, n, swig_ptr(x), k,
        swig_ptr(distances), swig_ptr(labels),
        params, swig_ptr(nb_dis), swig_ptr(ms_per_stage),
    )
    if output_stats:
        stats = {
            'ndis': nb_dis[0],
            'pre_transform_ms': ms_per_stage[0],
            'coarse_quantizer_ms': ms_per_stage[1],
            'invlist_scan_ms': ms_per_stage[2],
        }
        return distances, labels, stats
    return distances, labels
1328
range_search_with_parameters_c = range_search_with_parameters

def range_search_with_parameters(index, x, radius, params=None, output_stats=False):
    """Range-search `index` for the rows of `x` within `radius`, using
    explicit search parameters.

    When `params` is not given, an IVFSearchParameters is filled with the
    nprobe / max_codes of the IVF index returned by extract_index_ivf(index).

    Returns (lims, D, I): `lims` has n + 1 entries and lims[-1] is the total
    number of results held in D / I (presumably the results of query i are
    D[lims[i]:lims[i + 1]]). With `output_stats=True` a fourth value is
    returned: a dict with the distance count ('ndis') and per-stage timings
    in milliseconds.
    """
    n, d = x.shape
    assert d == index.d
    if not params:
        # fall back to the search-time settings stored in the IVF index
        params = IVFSearchParameters()
        ivf = extract_index_ivf(index)
        params.nprobe = ivf.nprobe
        params.max_codes = ivf.max_codes
    nb_dis = np.empty(1, 'uint64')
    ms_per_stage = np.empty(3, 'float64')
    result = RangeSearchResult(n)
    range_search_with_parameters_c(
        index, n, swig_ptr(x), radius, result,
        params, swig_ptr(nb_dis), swig_ptr(ms_per_stage),
    )
    # rev_swig_ptr gives views on the result's C++ buffers; copy them so the
    # returned arrays do not depend on `result` staying alive
    lims = rev_swig_ptr(result.lims, n + 1).copy()
    ntotal = int(lims[-1])
    dist_out = rev_swig_ptr(result.distances, ntotal).copy()
    ids_out = rev_swig_ptr(result.labels, ntotal).copy()
    if output_stats:
        stats = {
            'ndis': nb_dis[0],
            'pre_transform_ms': ms_per_stage[0],
            'coarse_quantizer_ms': ms_per_stage[1],
            'invlist_scan_ms': ms_per_stage[2],
        }
        return lims, dist_out, ids_out, stats
    return lims, dist_out, ids_out
1362
1363
1364######################################################
1365# KNN function
1366######################################################
1367
def knn(xq, xb, k, metric=METRIC_L2):
    """
    Compute the k nearest neighbors of a vector without constructing an index


    Parameters
    ----------
    xq : array_like
        Query vectors, shape (nq, d). Converted to a C-contiguous float32
        array first (no copy when it already is one).
    xb : array_like
        Database vectors, shape (nb, d) with the same d as `xq`. Same
        conversion as `xq`.
    k : int
        Number of nearest neighbors.
    metric : MetricType, optional
        distance measure to use (either METRIC_L2 or METRIC_INNER_PRODUCT)

    Returns
    -------
    D : array_like
        Distances of the nearest neighbors, shape (nq, k). For METRIC_L2
        these come from knn_L2sqr, i.e. squared L2 distances.
    I : array_like
        Labels of the nearest neighbors, shape (nq, k)

    Raises
    ------
    NotImplementedError
        for any metric other than METRIC_L2 and METRIC_INNER_PRODUCT
    """
    # the C++ kernels read raw float* buffers, so the inputs must be compact
    # float32; ascontiguousarray is a no-op for inputs that already are
    xq = np.ascontiguousarray(xq, dtype='float32')
    xb = np.ascontiguousarray(xb, dtype='float32')
    nq, d = xq.shape
    nb, d2 = xb.shape
    assert d == d2, 'xq and xb must have the same dimension'

    I = np.empty((nq, k), dtype='int64')
    D = np.empty((nq, k), dtype='float32')

    if metric == METRIC_L2:
        # max-heap per query: keeps the k smallest distances
        heaps = float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        heaps.val = swig_ptr(D)
        heaps.ids = swig_ptr(I)
        knn_L2sqr(
            swig_ptr(xq), swig_ptr(xb),
            d, nq, nb, heaps
        )
    elif metric == METRIC_INNER_PRODUCT:
        # min-heap per query: keeps the k largest similarities
        heaps = float_minheap_array_t()
        heaps.k = k
        heaps.nh = nq
        heaps.val = swig_ptr(D)
        heaps.ids = swig_ptr(I)
        knn_inner_product(
            swig_ptr(xq), swig_ptr(xb),
            d, nq, nb, heaps
        )
    else:
        raise NotImplementedError("only L2 and INNER_PRODUCT are supported")
    return D, I
1423
1424
1425###########################################
1426# Kmeans object
1427###########################################
1428
1429
class Kmeans:
    """Object that performs k-means clustering and manages the centroids.
    The `Kmeans` class is essentially a wrapper around the C++ `Clustering` object.

    Parameters
    ----------
    d : int
       dimension of the vectors to cluster
    k : int
       number of clusters
    gpu: bool or int, optional
       False: don't use GPU
       True or -1: use all GPUs
       number: use this many GPUs
    progressive_dim_steps:
        use a progressive dimension clustering (with that number of steps)

    Subsequent parameters are fields of the Clustering object. The most important are:

    niter: int, optional
       clustering iterations
    nredo: int, optional
       redo clustering this many times and keep best
    verbose: bool, optional
    spherical: bool, optional
       do we want normalized centroids?
    int_centroids: bool, optional
       round centroids coordinates to integer
    seed: int, optional
       seed for the random number generator

    """


    def __init__(self, d, k, **kwargs):
        """d: input dimension, k: nb of centroids. Additional
         parameters are passed on the ClusteringParameters object,
         including niter=25, verbose=False, spherical = False
        """
        self.d = d
        self.k = k
        self.gpu = False
        if "progressive_dim_steps" in kwargs:
            self.cp = ProgressiveDimClusteringParameters()
        else:
            self.cp = ClusteringParameters()
        for name, value in kwargs.items():
            if name == 'gpu':
                # `value == True` would also match the integer 1 (1 == True in
                # Python) and turn "use 1 GPU" into "use all GPUs": test for
                # an actual boolean instead.
                if (isinstance(value, (bool, np.bool_)) and value) or value == -1:
                    value = get_num_gpus()
                self.gpu = value
            else:
                # if this raises an exception, it means that it is a non-existent field
                getattr(self.cp, name)
                setattr(self.cp, name, value)
        self.centroids = None

    def train(self, x, weights=None, init_centroids=None):
        """ Perform k-means clustering.
        On output of the function call:

        - the centroids are in the centroids field of size (`k`, `d`).

        - the objective value at each iteration is in the array obj (size `niter`)

        - detailed optimization statistics are in the array iteration_stats.

        Parameters
        ----------
        x : array_like
            Training vectors, shape (n, d), `dtype` must be float32 and n should
            be larger than the number of clusters `k`.
        weights : array_like
            weight associated to each vector, shape `n`
        init_centroids : array_like
            initial set of centroids, shape (nc, d) (normally nc == k)

        Returns
        -------
        final_obj: float
            final optimization objective

        """
        n, d = x.shape
        assert d == self.d

        # exact class test on purpose: the progressive-dim parameters are
        # handled by the other branch
        if self.cp.__class__ == ClusteringParameters:
            # regular clustering
            clus = Clustering(d, self.k, self.cp)
            if init_centroids is not None:
                _, d2 = init_centroids.shape
                assert d2 == d
                copy_array_to_vector(init_centroids.ravel(), clus.centroids)
            if self.cp.spherical:
                self.index = IndexFlatIP(d)
            else:
                self.index = IndexFlatL2(d)
            if self.gpu:
                self.index = index_cpu_to_all_gpus(self.index, ngpu=self.gpu)
            clus.train(x, self.index, weights)
        else:
            # not supported for progressive dim
            assert weights is None
            assert init_centroids is None
            assert not self.cp.spherical
            clus = ProgressiveDimClustering(d, self.k, self.cp)
            if self.gpu:
                fac = GpuProgressiveDimIndexFactory(ngpu=self.gpu)
            else:
                fac = ProgressiveDimIndexFactory()
            clus.train(n, swig_ptr(x), fac)
            # assign() needs self.index, which this branch did not set before;
            # spherical is excluded above, so an L2 index is the right match
            self.index = IndexFlatL2(d)
            if self.gpu:
                self.index = index_cpu_to_all_gpus(self.index, ngpu=self.gpu)

        centroids = vector_float_to_array(clus.centroids)

        self.centroids = centroids.reshape(self.k, d)
        stats_vec = clus.iteration_stats
        stats = [stats_vec.at(i) for i in range(stats_vec.size())]
        self.obj = np.array([st.obj for st in stats])
        # copy all the iteration_stats objects to a python array
        stat_fields = 'obj time time_search imbalance_factor nsplit'.split()
        self.iteration_stats = [
            {field: getattr(st, field) for field in stat_fields}
            for st in stats
        ]
        return self.obj[-1] if self.obj.size > 0 else 0.0

    def assign(self, x):
        """Return (D, I): for each row of `x`, the distance to its nearest
        centroid and the index of that centroid, as 1-D arrays."""
        assert self.centroids is not None, "should train before assigning"
        self.index.reset()
        self.index.add(self.centroids)
        D, I = self.index.search(x, 1)
        return D.ravel(), I.ravel()
1562
# Backward-compatible aliases for renamed classes, so that older user code
# keeps working: IndexProxy was renamed to IndexReplicas, and the old name
# ConcatenatedInvertedLists is kept as an alias of HStackInvertedLists.
IndexProxy = IndexReplicas
ConcatenatedInvertedLists = HStackInvertedLists
1567
1568###########################################
1569# serialization of indexes to byte arrays
1570###########################################
1571
def serialize_index(index):
    """Serialize `index` with write_index and return the bytes as a numpy
    uint8 array (inverse of deserialize_index)."""
    buf_writer = VectorIOWriter()
    write_index(index, buf_writer)
    serialized = vector_to_array(buf_writer.data)
    return serialized
1577
def deserialize_index(data):
    """Rebuild an index from the numpy byte array `data` produced by
    serialize_index."""
    buf_reader = VectorIOReader()
    copy_array_to_vector(data, buf_reader.data)
    restored = read_index(buf_reader)
    return restored
1582
def serialize_index_binary(index):
    """Serialize the binary index `index` with write_index_binary and return
    the bytes as a numpy uint8 array (inverse of deserialize_index_binary)."""
    buf_writer = VectorIOWriter()
    write_index_binary(index, buf_writer)
    serialized = vector_to_array(buf_writer.data)
    return serialized
1588
def deserialize_index_binary(data):
    """Rebuild a binary index from the numpy byte array `data` produced by
    serialize_index_binary."""
    buf_reader = VectorIOReader()
    copy_array_to_vector(data, buf_reader.data)
    restored = read_index_binary(buf_reader)
    return restored
1593
1594
1595###########################################
1596# ResultHeap
1597###########################################
1598
class ResultHeap:
    """Accumulate query results from a sliced dataset. The final result will
    be in self.D, self.I."""

    def __init__(self, nq, k, keep_max=False):
        """
        nq: number of query vectors
        k: number of results per query
        keep_max: if True keep the k largest values (e.g. inner-product
            similarities) instead of the k smallest (the default, e.g. L2
            distances)
        """
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        # as in kmin / kmax: a max-heap retains the k smallest values seen,
        # a min-heap retains the k largest
        if keep_max:
            heaps = float_minheap_array_t()
        else:
            heaps = float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        heaps.val = swig_ptr(self.D)
        heaps.ids = swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_result(self, D, I):
        """Merge one batch of results into the heaps.

        D, I: arrays of shape (nq, kd); kd may differ from k (e.g. a slice
        holding fewer than k vectors). D, I do not need to be in a particular
        order (heap or sorted).
        """
        assert D.ndim == 2 and D.shape[0] == self.nq, (
            'D must have shape (%d, kd), got %s' % (self.nq, D.shape))
        kd = D.shape[1]
        assert I.shape == (self.nq, kd), (
            'I must have the same shape as D, got %s' % (I.shape,))
        # the heaps read raw float* / int64* rows with stride kd
        D = np.ascontiguousarray(D, dtype='float32')
        I = np.ascontiguousarray(I, dtype='int64')
        self.heaps.addn_with_ids(
            kd, swig_ptr(D),
            swig_ptr(I), kd)

    def finalize(self):
        """Reorder the heaps into self.D / self.I; call once after the last
        add_result."""
        self.heaps.reorder()
1626