1#!/usr/bin/env python
2# encoding: utf-8
3from collections import namedtuple
4import unittest
5import logging
6
7import numpy as np
8
9from scipy.spatial.distance import cosine
10from gensim.models.doc2vec import Doc2Vec
11from gensim import utils
12from gensim.models import translation_matrix
13from gensim.models import KeyedVectors
14from gensim.test.utils import datapath, get_tmpfile
15
16
class TestTranslationMatrix(unittest.TestCase):
    """Tests for the word-level ``translation_matrix.TranslationMatrix`` model (English -> Italian)."""

    def setUp(self):
        self.source_word_vec_file = datapath("EN.1-10.cbow1_wind5_hs0_neg10_size300_smpl1e-05.txt")
        self.target_word_vec_file = datapath("IT.1-10.cbow1_wind5_hs0_neg10_size300_smpl1e-05.txt")

        # (source word, target word) pairs used to fit the translation matrix.
        # NOTE(review): some pairs are not literal translations (e.g. "cavallo" means "horse",
        # not "fish"); presumably they were chosen to fit the small fixture vocabulary --
        # confirm against the fixture files before changing them.
        self.word_pairs = [
            ("one", "uno"), ("two", "due"), ("three", "tre"),
            ("four", "quattro"), ("five", "cinque"), ("seven", "sette"), ("eight", "otto"),
            ("dog", "cane"), ("pig", "maiale"), ("fish", "cavallo"), ("birds", "uccelli"),
            ("apple", "mela"), ("orange", "arancione"), ("grape", "acino"), ("banana", "banana"),
        ]

        # Held-out pairs: each target word is expected among the top-n translations of its source word.
        self.test_word_pairs = [("ten", "dieci"), ("cat", "gatto")]

        self.source_word_vec = KeyedVectors.load_word2vec_format(self.source_word_vec_file, binary=False)
        self.target_word_vec = KeyedVectors.load_word2vec_format(self.target_word_vec_file, binary=False)

    def _train_model(self):
        """Build a TranslationMatrix over the fixture vectors and train it on ``self.word_pairs``."""
        model = translation_matrix.TranslationMatrix(self.source_word_vec, self.target_word_vec, self.word_pairs)
        model.train(self.word_pairs)
        return model

    def _assert_held_out_translations(self, translated_words):
        """Assert every held-out target word appears in the translations returned for its source word."""
        for source_word, target_word in self.test_word_pairs:
            # assertIn reports the candidate list on failure, unlike assertTrue(x in y).
            self.assertIn(target_word, translated_words[source_word])

    def test_translation_matrix(self):
        model = self._train_model()
        self.assertEqual(model.translation_matrix.shape, (300, 300))

    def test_persistence(self):
        """Test storing/loading the entire model."""
        tmpf = get_tmpfile('transmat-en-it.pkl')

        model = self._train_model()
        model.save(tmpf)

        loaded_model = translation_matrix.TranslationMatrix.load(tmpf)
        # Same tolerances and NaN handling as np.allclose's defaults, but failures show the
        # mismatching entries instead of a bare "False is not true".
        np.testing.assert_allclose(
            model.translation_matrix, loaded_model.translation_matrix,
            rtol=1e-05, atol=1e-08, equal_nan=False,
        )

    def test_translate_nn(self):
        """Test the nearest-neighbour retrieval method."""
        model = self._train_model()

        source_words, _ = zip(*self.test_word_pairs)
        translated_words = model.translate(
            source_words, topn=5, source_lang_vec=self.source_word_vec, target_lang_vec=self.target_word_vec,
        )

        self._assert_held_out_translations(translated_words)

    def test_translate_gc(self):
        """Test the globally corrected neighbour retrieval method."""
        model = self._train_model()

        source_words, _ = zip(*self.test_word_pairs)
        translated_words = model.translate(
            source_words, topn=5, gc=1, sample_num=3,
            source_lang_vec=self.source_word_vec, target_lang_vec=self.target_word_vec,
        )

        self._assert_held_out_translations(translated_words)
76
77
def read_sentiment_docs(filename):
    """Read one document per line from `filename`, tagging each with its line number.

    Parameters
    ----------
    filename : str
        Path to a text file with one whitespace-tokenized document per line.

    Returns
    -------
    list of SentimentDocument
        Namedtuples with fields ``words`` (the line's whitespace-split tokens) and ``tags``
        (a one-element list holding the 0-based line number as a string), in file order.
    """
    sentiment_document = namedtuple('SentimentDocument', 'words tags')
    alldocs = []  # will hold all docs in original order
    with utils.open(filename, mode='rb', encoding='utf-8') as alldata:
        for line_no, line in enumerate(alldata):
            words = utils.to_unicode(line).split()
            # Doc2Vec treats `tags` as a list of tags. A bare string such as "12" would be
            # iterated character by character into the separate tags "1" and "2", so wrap it.
            tags = [str(line_no)]
            alldocs.append(sentiment_document(words, tags))
    return alldocs
88
89
class TestBackMappingTranslationMatrix(unittest.TestCase):
    """Tests for ``translation_matrix.BackMappingTranslationMatrix`` between two small Doc2Vec models."""

    def setUp(self):
        docs = read_sentiment_docs(datapath("alldata-id-10.txt"))
        self.train_docs = docs
        # The source model is trained on the first five documents only, the target model on all
        # of them, so docs[5:] are unseen by the source model.
        self.source_doc_vec = Doc2Vec(documents=docs[:5], vector_size=8, epochs=50, seed=1)
        self.target_doc_vec = Doc2Vec(documents=docs, vector_size=8, epochs=50, seed=2)

    def _fit_back_mapping(self):
        """Build a BackMappingTranslationMatrix and train it on the documents both models saw.

        Returns a ``(model, matrix)`` pair, where ``matrix`` is the value returned by ``train``.
        """
        shared_docs = self.train_docs[:5]
        back_mapper = translation_matrix.BackMappingTranslationMatrix(
            self.source_doc_vec, self.target_doc_vec, shared_docs,
        )
        matrix = back_mapper.train(shared_docs)
        return back_mapper, matrix

    def test_translation_matrix(self):
        _, matrix = self._fit_back_mapping()
        self.assertEqual(matrix.shape, (8, 8))

    @unittest.skip(
        "flaky test likely to be discarded when <https://github.com/RaRe-Technologies/gensim/issues/2977> "
        "is addressed"
    )
    def test_infer_vector(self):
        """Check that back-mapping a target doc vector lands near the source model's own inference.

        With data this tiny the comparison is only a rough sanity check. Whether the
        class under test survives at all is discussed in
        <https://github.com/RaRe-Technologies/gensim/issues/2977>.
        """
        back_mapper, _ = self._fit_back_mapping()
        held_out = self.train_docs[5]  # first document the source model was not trained on

        backmapped_vec = back_mapper.infer_vector(self.target_doc_vec.dv[held_out.tags[0]])
        self.assertEqual(backmapped_vec.shape, (8, ))

        d2v_inferred_vector = self.source_doc_vec.infer_vector(held_out.words)
        self.assertLessEqual(cosine(backmapped_vec, d2v_inferred_vector), 0.1)
130
131
if __name__ == '__main__':
    # Verbose logging when the test module is run directly as a script.
    logging.basicConfig(
        format='%(asctime)s : %(levelname)s : %(message)s',
        level=logging.DEBUG,
    )
    unittest.main()
135