#!/usr/bin/env python
# encoding: utf-8
"""Tests for :mod:`gensim.models.translation_matrix`."""
from collections import namedtuple
import unittest
import logging

import numpy as np

from scipy.spatial.distance import cosine
from gensim.models.doc2vec import Doc2Vec
from gensim import utils
from gensim.models import translation_matrix
from gensim.models import KeyedVectors
from gensim.test.utils import datapath, get_tmpfile


class TestTranslationMatrix(unittest.TestCase):
    """Exercise TranslationMatrix on small English -> Italian word-vector fixtures."""

    def setUp(self):
        self.source_word_vec_file = datapath("EN.1-10.cbow1_wind5_hs0_neg10_size300_smpl1e-05.txt")
        self.target_word_vec_file = datapath("IT.1-10.cbow1_wind5_hs0_neg10_size300_smpl1e-05.txt")

        # (source word, target word) pairs used to fit the translation matrix.
        # NOTE(review): "cavallo" means "horse", not "fish"; left unchanged because
        # editing it changes the training data -- confirm "pesce" exists in the
        # Italian fixture before correcting.
        self.word_pairs = [
            ("one", "uno"), ("two", "due"), ("three", "tre"),
            ("four", "quattro"), ("five", "cinque"), ("seven", "sette"), ("eight", "otto"),
            ("dog", "cane"), ("pig", "maiale"), ("fish", "cavallo"), ("birds", "uccelli"),
            ("apple", "mela"), ("orange", "arancione"), ("grape", "acino"), ("banana", "banana"),
        ]

        # Held-out pairs: each expected target must appear among the top-n candidates.
        self.test_word_pairs = [("ten", "dieci"), ("cat", "gatto")]

        self.source_word_vec = KeyedVectors.load_word2vec_format(self.source_word_vec_file, binary=False)
        self.target_word_vec = KeyedVectors.load_word2vec_format(self.target_word_vec_file, binary=False)

    def _trained_model(self):
        """Build a TranslationMatrix on the fixture vectors and train it on ``self.word_pairs``."""
        model = translation_matrix.TranslationMatrix(self.source_word_vec, self.target_word_vec, self.word_pairs)
        model.train(self.word_pairs)
        return model

    def _assert_held_out_translations(self, model, **translate_kwargs):
        """Translate the held-out source words and check every expected target is a candidate.

        Extra keyword arguments (e.g. ``gc``/``sample_num``) are forwarded to ``model.translate``.
        """
        source_words, _ = zip(*self.test_word_pairs)
        translated_words = model.translate(
            source_words, topn=5,
            source_lang_vec=self.source_word_vec, target_lang_vec=self.target_word_vec,
            **translate_kwargs
        )
        for source_word, expected_target in self.test_word_pairs:
            self.assertIn(expected_target, translated_words[source_word])

    def test_translation_matrix(self):
        model = self._trained_model()
        # The fixture vectors are 300-dimensional, so the mapping is 300 x 300.
        self.assertEqual(model.translation_matrix.shape, (300, 300))

    def test_persistence(self):
        """Test storing/loading the entire model."""
        tmpf = get_tmpfile('transmat-en-it.pkl')

        model = self._trained_model()
        model.save(tmpf)

        loaded_model = translation_matrix.TranslationMatrix.load(tmpf)
        # Same tolerances as np.allclose, but a failure reports the mismatching entries.
        np.testing.assert_allclose(
            loaded_model.translation_matrix, model.translation_matrix, rtol=1e-05, atol=1e-08,
        )

    def test_translate_nn(self):
        # Test the nearest neighbor retrieval method
        self._assert_held_out_translations(self._trained_model())

    def test_translate_gc(self):
        # Test the globally corrected neighbor retrieval method
        self._assert_held_out_translations(self._trained_model(), gc=1, sample_num=3)


SentimentDocument = namedtuple('SentimentDocument', 'words tags')


def read_sentiment_docs(filename):
    """Read a one-document-per-line corpus for Doc2Vec.

    Parameters
    ----------
    filename : str
        Path to a UTF-8 text file; every line is one document.

    Returns
    -------
    list of SentimentDocument
        Documents in file order. ``words`` holds the line's whitespace-separated
        tokens and ``tags`` is a one-element list with the 0-based line number as a string.

    """
    alldocs = []  # will hold all docs in original order
    with utils.open(filename, mode='rb', encoding='utf-8') as alldata:
        for line_no, line in enumerate(alldata):
            words = utils.to_unicode(line).split()
            # Doc2Vec iterates over ``tags``: a bare string such as "12" would become
            # the two tags "1" and "2", so the line id is wrapped in a list.
            tags = [str(line_no)]
            alldocs.append(SentimentDocument(words, tags))
    return alldocs


class TestBackMappingTranslationMatrix(unittest.TestCase):
    """Exercise BackMappingTranslationMatrix between two small Doc2Vec models."""

    def setUp(self):
        filename = datapath("alldata-id-10.txt")
        self.train_docs = read_sentiment_docs(filename)
        # The source model only sees the first five documents; the target model sees all of them.
        self.source_doc_vec = Doc2Vec(documents=self.train_docs[:5], vector_size=8, epochs=50, seed=1)
        self.target_doc_vec = Doc2Vec(documents=self.train_docs, vector_size=8, epochs=50, seed=2)

    def test_translation_matrix(self):
        model = translation_matrix.BackMappingTranslationMatrix(
            self.source_doc_vec, self.target_doc_vec, self.train_docs[:5],
        )
        transmat = model.train(self.train_docs[:5])
        # Both Doc2Vec models use vector_size=8.
        self.assertEqual(transmat.shape, (8, 8))

    @unittest.skip(
        "flaky test likely to be discarded when <https://github.com/RaRe-Technologies/gensim/issues/2977> "
        "is addressed"
    )
    def test_infer_vector(self):
        """Test that translation gives similar results to traditional inference.

        This may not be completely sensible/salient with such tiny data, but
        replaces what seemed to me to be an ever-more-nonsensical test.

        See <https://github.com/RaRe-Technologies/gensim/issues/2977> for discussion
        of whether the class this supposedly tested even survives when the
        TranslationMatrix functionality is better documented.
        """
        model = translation_matrix.BackMappingTranslationMatrix(
            self.source_doc_vec, self.target_doc_vec, self.train_docs[:5],
        )
        model.train(self.train_docs[:5])

        # Document 5 was never seen by the source model, only by the target model.
        held_out = self.train_docs[5]
        backmapped_vec = model.infer_vector(self.target_doc_vec.dv[held_out.tags[0]])
        self.assertEqual(backmapped_vec.shape, (8, ))

        d2v_inferred_vector = self.source_doc_vec.infer_vector(held_out.words)

        distance = cosine(backmapped_vec, d2v_inferred_vector)
        self.assertLessEqual(distance, 0.1)


if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
    unittest.main()