1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2013 Radim Rehurek <me@radimrehurek.com>
5# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
6
7"""
8This module integrates Spotify's `Annoy <https://github.com/spotify/annoy>`_ (Approximate Nearest Neighbors Oh Yeah)
9library with Gensim's :class:`~gensim.models.word2vec.Word2Vec`, :class:`~gensim.models.doc2vec.Doc2Vec`,
10:class:`~gensim.models.fasttext.FastText` and :class:`~gensim.models.keyedvectors.KeyedVectors` word embeddings.
11
12.. Important::
13    To use this module, you must have the ``annoy`` library installed.
14    To install it, run ``pip install annoy``.
15
16"""
17
18# Avoid import collisions on py2: this module has the same name as the actual Annoy library.
19from __future__ import absolute_import
20
21import os
22
23try:
24    import cPickle as _pickle
25except ImportError:
26    import pickle as _pickle
27
28from gensim import utils
29from gensim.models.doc2vec import Doc2Vec
30from gensim.models.word2vec import Word2Vec
31from gensim.models.fasttext import FastText
32from gensim.models import KeyedVectors
33
34
# Pre-built error shared by the indexer methods: raised whenever importing the optional ``annoy`` package fails.
_NOANNOY = ImportError(
    "Annoy not installed. To use the Annoy indexer, please run `pip install annoy`."
)
36
37
class AnnoyIndexer:
    """This class allows the use of `Annoy <https://github.com/spotify/annoy>`_ for fast (approximate)
    vector retrieval in `most_similar()` calls of
    :class:`~gensim.models.word2vec.Word2Vec`, :class:`~gensim.models.doc2vec.Doc2Vec`,
    :class:`~gensim.models.fasttext.FastText` and :class:`~gensim.models.keyedvectors.KeyedVectors` models.

    """

    def __init__(self, model=None, num_trees=None):
        """
        Parameters
        ----------
        model : trained model, optional
            Use vectors from this model as the source for the index.
        num_trees : int, optional
            Number of trees for Annoy indexer.

        Raises
        ------
        ValueError
            If `model` is not a Word2Vec, Doc2Vec, FastText or KeyedVectors instance.
        ImportError
            If the ``annoy`` package is not installed.

        Notes
        -----
        The index is built immediately only when both `model` and `num_trees` are supplied (and truthy).
        Otherwise the instance is left empty, e.g. to be filled by :meth:`load`.

        Examples
        --------
        The exact scores depend on training and are shown for illustration only.

        .. sourcecode:: pycon

            >>> from gensim.similarities.annoy import AnnoyIndexer
            >>> from gensim.models import Word2Vec
            >>>
            >>> sentences = [['cute', 'cat', 'say', 'meow'], ['cute', 'dog', 'say', 'woof']]
            >>> model = Word2Vec(sentences, min_count=1, seed=1)
            >>>
            >>> indexer = AnnoyIndexer(model, 2)
            >>> model.most_similar("cat", topn=2, indexer=indexer)
            [('cat', 1.0), ('dog', 0.32011348009109497)]

        """
        self.index = None
        self.labels = None
        self.model = model
        self.num_trees = num_trees
        # Dimensionality of the indexed vectors; remembered so that `save()` also works when no model is attached.
        self._num_features = None

        if model and num_trees:
            # Extract the KeyedVectors object from whatever model we were given.
            # NOTE(review): Doc2Vec is tested first on purpose - presumably it subclasses Word2Vec; keep this order.
            if isinstance(model, Doc2Vec):
                kv = model.dv
            elif isinstance(model, (Word2Vec, FastText)):
                kv = model.wv
            elif isinstance(model, KeyedVectors):
                kv = model
            else:
                raise ValueError("Only a Word2Vec, Doc2Vec, FastText or KeyedVectors instance can be used")
            self._build_from_model(kv.get_normed_vectors(), kv.index_to_key, kv.vector_size)

    def save(self, fname, protocol=utils.PICKLE_PROTOCOL):
        """Save AnnoyIndexer instance to disk.

        Parameters
        ----------
        fname : str
            Path to output. Save will produce 2 files:
            `fname`: Annoy index itself.
            `fname.dict`: Index metadata.
        protocol : int, optional
            Protocol for pickle.

        Notes
        -----
        This method saves **only the index**. The trained model isn't preserved.

        """
        self.index.save(fname)

        # Prefer the dimensionality recorded when the index was built/loaded: after `load()` there may be
        # no model attached at all. Fall back to the model for instances that never recorded it.
        num_features = getattr(self, '_num_features', None)
        if num_features is None:
            num_features = self.model.vector_size

        d = {'f': num_features, 'num_trees': self.num_trees, 'labels': self.labels}
        with utils.open(fname + '.dict', 'wb') as fout:
            _pickle.dump(d, fout, protocol=protocol)

    def load(self, fname):
        """Load an AnnoyIndexer instance from disk.

        Parameters
        ----------
        fname : str
            The path as previously used by ``save()``.

        Raises
        ------
        IOError
            If either `fname` or `fname.dict` does not exist.
        ImportError
            If the ``annoy`` package is not installed.

        Notes
        -----
        Only the index, its labels and `num_trees` are restored; `self.model` is left untouched.
        The metadata file is unpickled, so only load files from a trusted source.

        Examples
        --------
        .. sourcecode:: pycon

            >>> from gensim.similarities.annoy import AnnoyIndexer
            >>> from gensim.models import Word2Vec
            >>> from tempfile import mkstemp
            >>>
            >>> sentences = [['cute', 'cat', 'say', 'meow'], ['cute', 'dog', 'say', 'woof']]
            >>> model = Word2Vec(sentences, min_count=1, seed=1, epochs=10)
            >>>
            >>> indexer = AnnoyIndexer(model, 2)
            >>> _, temp_fn = mkstemp()
            >>> indexer.save(temp_fn)
            >>>
            >>> new_indexer = AnnoyIndexer()
            >>> new_indexer.load(temp_fn)
            >>> new_indexer.model = model

        """
        fname_dict = fname + '.dict'
        if not (os.path.exists(fname) and os.path.exists(fname_dict)):
            raise IOError(
                f"Can't find index files '{fname}' and '{fname_dict}' - unable to restore AnnoyIndexer state."
            )
        try:
            from annoy import AnnoyIndex
        except ImportError:
            raise _NOANNOY

        # SECURITY: unpickling can execute arbitrary code - `fname.dict` must come from a trusted source.
        with utils.open(fname_dict, 'rb') as f:
            d = _pickle.loads(f.read())

        self.num_trees = d['num_trees']
        self._num_features = d['f']
        self.index = AnnoyIndex(d['f'], metric='angular')
        self.index.load(fname)
        self.labels = d['labels']

    def _build_from_model(self, vectors, labels, num_features):
        """Build the Annoy index from `vectors` and remember `labels` for translating ids back to keys.

        Parameters
        ----------
        vectors : iterable of vectors
            Vectors to index; the i-th vector is stored under Annoy item id ``i``.
        labels : sequence
            Keys such that ``labels[i]`` names the i-th vector.
        num_features : int
            Dimensionality passed to the ``AnnoyIndex`` constructor.

        Raises
        ------
        ImportError
            If the ``annoy`` package is not installed.

        """
        try:
            from annoy import AnnoyIndex
        except ImportError:
            raise _NOANNOY

        index = AnnoyIndex(num_features, metric='angular')

        for vector_num, vector in enumerate(vectors):
            index.add_item(vector_num, vector)

        index.build(self.num_trees)
        self.index = index
        self.labels = labels
        self._num_features = num_features

    def most_similar(self, vector, num_neighbors):
        """Find `num_neighbors` most similar items.

        Parameters
        ----------
        vector : numpy.array
            Vector for word/document.
        num_neighbors : int
            Number of most similar items

        Returns
        -------
        list of (str, float)
            List of most similar items in format [(`item`, `cosine_similarity`), ... ]

        """
        ids, distances = self.index.get_nns_by_vector(
            vector, num_neighbors, include_distances=True)

        # Annoy's "angular" distance is sqrt(2 * (1 - cos(u, v))), hence cos(u, v) = 1 - distance ** 2 / 2.
        return [
            (self.labels[item_id], 1 - distance ** 2 / 2)
            for item_id, distance in zip(ids, distances)
        ]
189