1#!/usr/bin/env python 2# -*- coding: utf-8 -*- 3# 4# Copyright (C) 2013 Radim Rehurek <me@radimrehurek.com> 5# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html 6 7""" 8This module integrates Spotify's `Annoy <https://github.com/spotify/annoy>`_ (Approximate Nearest Neighbors Oh Yeah) 9library with Gensim's :class:`~gensim.models.word2vec.Word2Vec`, :class:`~gensim.models.doc2vec.Doc2Vec`, 10:class:`~gensim.models.fasttext.FastText` and :class:`~gensim.models.keyedvectors.KeyedVectors` word embeddings. 11 12.. Important:: 13 To use this module, you must have the ``annoy`` library installed. 14 To install it, run ``pip install annoy``. 15 16""" 17 18# Avoid import collisions on py2: this module has the same name as the actual Annoy library. 19from __future__ import absolute_import 20 21import os 22 23try: 24 import cPickle as _pickle 25except ImportError: 26 import pickle as _pickle 27 28from gensim import utils 29from gensim.models.doc2vec import Doc2Vec 30from gensim.models.word2vec import Word2Vec 31from gensim.models.fasttext import FastText 32from gensim.models import KeyedVectors 33 34 35_NOANNOY = ImportError("Annoy not installed. To use the Annoy indexer, please run `pip install annoy`.") 36 37 38class AnnoyIndexer(): 39 """This class allows the use of `Annoy <https://github.com/spotify/annoy>`_ for fast (approximate) 40 vector retrieval in `most_similar()` calls of 41 :class:`~gensim.models.word2vec.Word2Vec`, :class:`~gensim.models.doc2vec.Doc2Vec`, 42 :class:`~gensim.models.fasttext.FastText` and :class:`~gensim.models.keyedvectors.Word2VecKeyedVectors` models. 43 44 """ 45 46 def __init__(self, model=None, num_trees=None): 47 """ 48 Parameters 49 ---------- 50 model : trained model, optional 51 Use vectors from this model as the source for the index. 52 num_trees : int, optional 53 Number of trees for Annoy indexer. 54 55 Examples 56 -------- 57 .. sourcecode:: pycon 58 59 >>> from gensim.similarities.annoy import AnnoyIndexer 60 >>> from gensim.models import Word2Vec 61 >>> 62 >>> sentences = [['cute', 'cat', 'say', 'meow'], ['cute', 'dog', 'say', 'woof']] 63 >>> model = Word2Vec(sentences, min_count=1, seed=1) 64 >>> 65 >>> indexer = AnnoyIndexer(model, 2) 66 >>> model.most_similar("cat", topn=2, indexer=indexer) 67 [('cat', 1.0), ('dog', 0.32011348009109497)] 68 69 """ 70 self.index = None 71 self.labels = None 72 self.model = model 73 self.num_trees = num_trees 74 75 if model and num_trees: 76 # Extract the KeyedVectors object from whatever model we were given. 77 if isinstance(self.model, Doc2Vec): 78 kv = self.model.dv 79 elif isinstance(self.model, (Word2Vec, FastText)): 80 kv = self.model.wv 81 elif isinstance(self.model, (KeyedVectors,)): 82 kv = self.model 83 else: 84 raise ValueError("Only a Word2Vec, Doc2Vec, FastText or KeyedVectors instance can be used") 85 self._build_from_model(kv.get_normed_vectors(), kv.index_to_key, kv.vector_size) 86 87 def save(self, fname, protocol=utils.PICKLE_PROTOCOL): 88 """Save AnnoyIndexer instance to disk. 89 90 Parameters 91 ---------- 92 fname : str 93 Path to output. Save will produce 2 files: 94 `fname`: Annoy index itself. 95 `fname.dict`: Index metadata. 96 protocol : int, optional 97 Protocol for pickle. 98 99 Notes 100 ----- 101 This method saves **only the index**. The trained model isn't preserved. 102 103 """ 104 self.index.save(fname) 105 d = {'f': self.model.vector_size, 'num_trees': self.num_trees, 'labels': self.labels} 106 with utils.open(fname + '.dict', 'wb') as fout: 107 _pickle.dump(d, fout, protocol=protocol) 108 109 def load(self, fname): 110 """Load an AnnoyIndexer instance from disk. 111 112 Parameters 113 ---------- 114 fname : str 115 The path as previously used by ``save()``. 116 117 Examples 118 -------- 119 .. sourcecode:: pycon 120 121 >>> from gensim.similarities.index import AnnoyIndexer 122 >>> from gensim.models import Word2Vec 123 >>> from tempfile import mkstemp 124 >>> 125 >>> sentences = [['cute', 'cat', 'say', 'meow'], ['cute', 'dog', 'say', 'woof']] 126 >>> model = Word2Vec(sentences, min_count=1, seed=1, epochs=10) 127 >>> 128 >>> indexer = AnnoyIndexer(model, 2) 129 >>> _, temp_fn = mkstemp() 130 >>> indexer.save(temp_fn) 131 >>> 132 >>> new_indexer = AnnoyIndexer() 133 >>> new_indexer.load(temp_fn) 134 >>> new_indexer.model = model 135 136 """ 137 fname_dict = fname + '.dict' 138 if not (os.path.exists(fname) and os.path.exists(fname_dict)): 139 raise IOError( 140 f"Can't find index files '{fname}' and '{fname_dict}' - unable to restore AnnoyIndexer state." 141 ) 142 try: 143 from annoy import AnnoyIndex 144 except ImportError: 145 raise _NOANNOY 146 147 with utils.open(fname_dict, 'rb') as f: 148 d = _pickle.loads(f.read()) 149 self.num_trees = d['num_trees'] 150 self.index = AnnoyIndex(d['f'], metric='angular') 151 self.index.load(fname) 152 self.labels = d['labels'] 153 154 def _build_from_model(self, vectors, labels, num_features): 155 try: 156 from annoy import AnnoyIndex 157 except ImportError: 158 raise _NOANNOY 159 160 index = AnnoyIndex(num_features, metric='angular') 161 162 for vector_num, vector in enumerate(vectors): 163 index.add_item(vector_num, vector) 164 165 index.build(self.num_trees) 166 self.index = index 167 self.labels = labels 168 169 def most_similar(self, vector, num_neighbors): 170 """Find `num_neighbors` most similar items. 171 172 Parameters 173 ---------- 174 vector : numpy.array 175 Vector for word/document. 176 num_neighbors : int 177 Number of most similar items 178 179 Returns 180 ------- 181 list of (str, float) 182 List of most similar items in format [(`item`, `cosine_distance`), ... ] 183 184 """ 185 ids, distances = self.index.get_nns_by_vector( 186 vector, num_neighbors, include_distances=True) 187 188 return [(self.labels[ids[i]], 1 - distances[i] / 2) for i in range(len(ids))] 189