1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import functools
19import os
20import random
21import re
22import warnings
23
24import numpy as np
25import pytest
26from mxnet import ndarray as nd
27from mxnet.test_utils import *
28
29import gluonnlp as nlp
30from gluonnlp.base import get_home_dir
31
32
33@pytest.fixture
34def counter():
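    """Shared token counter: 'c' occurs three times, 'b' twice, 'a' and 'some_word$' once."""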
35    return nlp.data.utils.Counter(['a', 'b', 'b', 'c', 'c', 'c',
36                                   'some_word$'])
37
38def _get_test_str_of_tokens(token_delim, seq_delim):
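    """Build a test string of three token sequences using the given token and sequence delimiters."""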
39    seq1 = token_delim + token_delim.join(['Life', 'is', 'great', '!']) + token_delim + seq_delim
40    seq2 = token_delim + token_delim.join(['life', 'is', 'good', '.']) + token_delim + seq_delim
41    seq3 = token_delim + token_delim.join(['life', "isn't", 'bad', '.']) + token_delim + seq_delim
42    seqs = seq1 + seq2 + seq3
43    return seqs
44
45def simple_tokenize(source_str, token_delim=' ', seq_delim='\n'):
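    """Split source_str on token_delim or seq_delim (assumed regex-safe) and drop empty strings."""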
46    return filter(None, re.split(token_delim + '|' + seq_delim, source_str))
47
48def _test_count_tokens(token_delim, seq_delim):
49    source_str = _get_test_str_of_tokens(token_delim, seq_delim)
50
51    tokens = list(simple_tokenize(source_str, token_delim, seq_delim))
52    cnt1 = nlp.data.count_tokens(tokens, to_lower=False)
53    assert cnt1 == nlp.data.utils.Counter(
54        {'is': 2, 'life': 2, '.': 2, 'Life': 1, 'great': 1, '!': 1, 'good': 1, "isn't": 1,
55         'bad': 1})
56
57    cnt2 = nlp.data.count_tokens(tokens, to_lower=True)
58    assert cnt2 == nlp.data.utils.Counter(
59        {'life': 3, 'is': 2, '.': 2, 'great': 1, '!': 1, 'good': 1, "isn't": 1, 'bad': 1}), cnt2
60
61    counter_to_update = nlp.data.utils.Counter({'life': 2})
62
63    cnt3 = nlp.data.utils.count_tokens(tokens, to_lower=False,
64                                       counter=counter_to_update.copy())
65    assert cnt3 == nlp.data.utils.Counter(
66        {'is': 2, 'life': 4, '.': 2, 'Life': 1, 'great': 1, '!': 1, 'good': 1, "isn't": 1,
67         'bad': 1})
68
69    cnt4 = nlp.data.count_tokens(tokens, to_lower=True,
70                                 counter=counter_to_update.copy())
71    assert cnt4 == nlp.data.utils.Counter(
72        {'life': 5, 'is': 2, '.': 2, 'great': 1, '!': 1, 'good': 1, "isn't": 1, 'bad': 1})
73
74
75def test_count_tokens():
76    _test_count_tokens(' ', '\n')
77    _test_count_tokens('IS', 'LIFE')
78
79
80def test_vocabulary_getitem(counter):
81    vocab = nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token='<unk>',
82                      bos_token=None, eos_token=None, reserved_tokens=None)
83
84    i1 = vocab['c']
85    assert i1 == 2
86    assert vocab.to_indices('c') == 2
87
88    i2 = vocab[['c']]
89    assert i2 == [2]
90    assert vocab.to_indices(['c']) == [2]
91
92    i3 = vocab[['<unk>', 'non-exist']]
93    assert i3 == [0, 0]
94    assert vocab.to_indices(['<unk>', 'non-exist']) == [0, 0]
95
96    i4 = vocab[['a', 'non-exist', 'a', 'b']]
97    assert i4 == [4, 0, 4, 3]
98    assert vocab.to_indices(['a', 'non-exist', 'a', 'b']) == [4, 0, 4, 3]
99
100    no_unk_vocab = nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token=None,
101                             bos_token=None, eos_token=None, reserved_tokens=None)
102    assert no_unk_vocab['c'] == 1
103    assert no_unk_vocab.to_indices('c') == 1
104
105    assert no_unk_vocab[['c']] == [1]
106    assert no_unk_vocab.to_indices(['c']) == [1]
107
108    for words in [['<unk>', 'non-exist'], ['a', 'non-exist', 'a', 'b']]:
109        with pytest.raises(KeyError):
110            no_unk_vocab.to_indices(words)
111
112
113def test_vocabulary_to_tokens(counter):
114    vocab = nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token='<unknown>',
115                      bos_token=None, eos_token=None, reserved_tokens=None)
116    i1 = vocab.to_tokens(2)
117    assert i1 == 'c'
118
119    i2 = vocab.to_tokens([2])
120    assert i2 == ['c']
121
122    i3 = vocab.to_tokens([0, 0])
123    assert i3 == ['<unknown>', '<unknown>']
124
125    i4 = vocab.to_tokens([4, 0, 4, 3])
126    assert i4 == ['a', '<unknown>', 'a', 'b']
127
128    for indices in [6, [6, 7]]:
129        with pytest.raises(ValueError):
130            vocab.to_tokens(indices)
131
132
133def test_vocabulary(counter):
134    v1 = nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token='<unk>',
135                   padding_token=None, bos_token=None, eos_token=None, reserved_tokens=None)
136    assert len(v1) == 5
137    assert v1.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3, 'some_word$': 4}
138    assert v1.idx_to_token[1] == 'c'
139    assert v1.unknown_token == '<unk>'
140    assert v1.reserved_tokens is None
141    assert v1.embedding is None
142    assert 'a' in v1
143    assert v1.unknown_token in v1
144
145    v2 = nlp.Vocab(counter, max_size=None, min_freq=2, unknown_token='<unk>',
146                   padding_token=None, bos_token=None, eos_token=None, reserved_tokens=None)
147    assert len(v2) == 3
148    assert v2.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
149    assert v2.idx_to_token[1] == 'c'
150    assert v2.unknown_token == '<unk>'
151    assert v2.reserved_tokens is None
152    assert v2.embedding is None
153    assert 'a' not in v2
154    assert v2.unknown_token in v2
155
156    v3 = nlp.Vocab(counter, max_size=None, min_freq=100, unknown_token='<unk>',
157                   padding_token=None, bos_token=None, eos_token=None, reserved_tokens=None)
158    assert len(v3) == 1
159    assert v3.token_to_idx == {'<unk>': 0}
160    assert v3.idx_to_token[0] == '<unk>'
161    assert v3.unknown_token == '<unk>'
162    assert v3.reserved_tokens is None
163    assert v3.embedding is None
164    assert 'a' not in v3
165
166    v4 = nlp.Vocab(counter, max_size=2, min_freq=1, unknown_token='<unk>',
167                   padding_token=None, bos_token=None, eos_token=None, reserved_tokens=None)
168    assert len(v4) == 3
169    assert v4.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
170    assert v4.idx_to_token[1] == 'c'
171    assert v4.unknown_token == '<unk>'
172    assert v4.reserved_tokens is None
173    assert v4.embedding is None
174    assert 'a' not in v4
175
176    v5 = nlp.Vocab(counter, max_size=3, min_freq=1, unknown_token='<unk>',
177                   padding_token=None, bos_token=None, eos_token=None, reserved_tokens=None)
178    assert len(v5) == 4
179    assert v5.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3}
180    assert v5.idx_to_token[1] == 'c'
181    assert v5.unknown_token == '<unk>'
182    assert v5.reserved_tokens is None
183    assert v5.embedding is None
184    assert 'a' in v5
185
186    v6 = nlp.Vocab(counter, max_size=100, min_freq=1, unknown_token='<unk>',
187                   padding_token=None, bos_token=None, eos_token=None, reserved_tokens=None)
188    assert len(v6) == 5
189    assert v6.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3,
190                               'some_word$': 4}
191    assert v6.idx_to_token[1] == 'c'
192    assert v6.unknown_token == '<unk>'
193    assert v6.reserved_tokens is None
194    assert v6.embedding is None
195    assert 'a' in v6
196
197    v7 = nlp.Vocab(counter, max_size=1, min_freq=2, unknown_token='<unk>',
198                   padding_token=None, bos_token=None, eos_token=None, reserved_tokens=None)
199    assert len(v7) == 2
200    assert v7.token_to_idx == {'<unk>': 0, 'c': 1}
201    assert v7.idx_to_token[1] == 'c'
202    assert v7.unknown_token == '<unk>'
203    assert v7.reserved_tokens is None
204    assert v7.embedding is None
205    assert 'a' not in v7
206
207    with pytest.raises(AssertionError):
208        nlp.Vocab(counter, max_size=None, min_freq=0, unknown_token='<unknown>',
209                  reserved_tokens=['b'])
210    with pytest.raises(AssertionError):
211        nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token='<unknown>',
212                  reserved_tokens=['b', 'b'])
213    with pytest.raises(AssertionError):
214        nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token='<unknown>',
215                  reserved_tokens=['b', '<unknown>'])
216
217    v8 = nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token='<unknown>',
218                   padding_token=None, bos_token=None, eos_token=None, reserved_tokens=['b'])
219    assert len(v8) == 5
220    assert v8.token_to_idx == {'<unknown>': 0, 'b': 1, 'c': 2, 'a': 3, 'some_word$': 4}
221    assert v8.idx_to_token[1] == 'b'
222    assert v8.unknown_token == '<unknown>'
223    assert v8.reserved_tokens == ['b']
224    assert v8.embedding is None
225    assert 'a' in v8
226
227    v9 = nlp.Vocab(counter, max_size=None, min_freq=2, unknown_token='<unk>',
228                   padding_token=None, bos_token=None, eos_token=None, reserved_tokens=['b', 'a'])
229    assert len(v9) == 4
230    assert v9.token_to_idx == {'<unk>': 0, 'b': 1, 'a': 2, 'c': 3}
231    assert v9.idx_to_token[1] == 'b'
232    assert v9.unknown_token == '<unk>'
233    assert v9.reserved_tokens == ['b', 'a']
234    assert v9.embedding is None
235    assert 'a' in v9
236
237    v10 = nlp.Vocab(counter, max_size=None, min_freq=100, unknown_token='<unk>',
238                    padding_token=None, bos_token=None, eos_token=None, reserved_tokens=['b', 'c'])
239    assert len(v10) == 3
240    assert v10.token_to_idx == {'<unk>': 0, 'b': 1, 'c': 2}
241    assert v10.idx_to_token[1] == 'b'
242    assert v10.unknown_token == '<unk>'
243    assert v10.reserved_tokens == ['b', 'c']
244    assert v10.embedding is None
245    assert 'a' not in v10
246
247    v11 = nlp.Vocab(counter, max_size=1, min_freq=2, unknown_token='<unk>',
248                    padding_token=None, bos_token=None, eos_token=None,
249                    reserved_tokens=['<pad>', 'b'])
250    assert len(v11) == 4
251    assert v11.token_to_idx == {'<unk>': 0, '<pad>': 1, 'b': 2, 'c': 3}
252    assert v11.idx_to_token[1] == '<pad>'
253    assert v11.unknown_token == '<unk>'
254    assert v11.reserved_tokens == ['<pad>', 'b']
255    assert v11.embedding is None
256    assert 'a' not in v11
257
258    v12 = nlp.Vocab(counter, max_size=None, min_freq=2, unknown_token='b',
259                    padding_token=None, bos_token=None, eos_token=None, reserved_tokens=['<pad>'])
260    assert len(v12) == 3
261    assert v12.token_to_idx == {'b': 0, '<pad>': 1, 'c': 2}
262    assert v12.idx_to_token[1] == '<pad>'
263    assert v12.unknown_token == 'b'
264    assert v12.reserved_tokens == ['<pad>']
265    assert v12.embedding is None
266    assert 'a' not in v12
267
268    v13 = nlp.Vocab(counter, max_size=None, min_freq=2, unknown_token='a',
269                    padding_token=None, bos_token=None, eos_token=None, reserved_tokens=['<pad>'])
270    assert len(v13) == 4
271    assert v13.token_to_idx == {'a': 0, '<pad>': 1, 'c': 2, 'b': 3}
272    assert v13.idx_to_token[1] == '<pad>'
273    assert v13.unknown_token == 'a'
274    assert v13.reserved_tokens == ['<pad>']
275    assert v13.embedding is None
276    assert 'a' in v13
277
278    counter_tuple = nlp.data.utils.Counter([('a', 'a'), ('b', 'b'), ('b', 'b'), ('c', 'c'),
279                                            ('c', 'c'), ('c', 'c'), ('some_word$', 'some_word$')])
280
281    v14 = nlp.Vocab(counter_tuple, max_size=None, min_freq=1, unknown_token=('<unk>', '<unk>'),
282                    padding_token=None, bos_token=None, eos_token=None, reserved_tokens=None)
283    assert len(v14) == 5
284    assert v14.token_to_idx == {('<unk>', '<unk>'): 0, ('c', 'c'): 1, ('b', 'b'): 2, ('a', 'a'): 3,
285                                ('some_word$', 'some_word$'): 4}
286    assert v14.idx_to_token[1] == ('c', 'c')
287    assert v14.unknown_token == ('<unk>', '<unk>')
288    assert v14.reserved_tokens is None
289    assert v14.embedding is None
290    assert ('a', 'a') in v14
291    assert ('<unk>', '<unk>') in v14
292
293
294def _mk_my_pretrain_file(path, token_delim, pretrain_file):
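    """Write a tiny pretrained embedding text file with 5-dim vectors for tokens 'a' and 'b'."""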
295    path = os.path.expanduser(path)
296    if not os.path.exists(path):
297        os.makedirs(path)
298    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
299    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
300    seqs = seq1 + seq2
301    with open(os.path.join(path, pretrain_file), 'w') as fout:
302        fout.write(seqs)
303
304
305def _mk_my_pretrain_file2(path, token_delim, pretrain_file):
306    path = os.path.expanduser(path)
307    if not os.path.exists(path):
308        os.makedirs(path)
309    seq1 = token_delim.join(['a', '0.01', '0.02', '0.03', '0.04', '0.05']) + '\n'
310    seq2 = token_delim.join(['c', '0.06', '0.07', '0.08', '0.09', '0.1']) + '\n'
311    seqs = seq1 + seq2
312    with open(os.path.join(path, pretrain_file), 'w') as fout:
313        fout.write(seqs)
314
315
316def _mk_my_pretrain_file3(path, token_delim, pretrain_file):
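    """Like _mk_my_pretrain_file, but additionally writes a vector for the token '<unk1>'."""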
317    path = os.path.expanduser(path)
318    if not os.path.exists(path):
319        os.makedirs(path)
320    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
321    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
322    seq3 = token_delim.join(['<unk1>', '1.1', '1.2', '1.3', '1.4',
323                             '1.5']) + '\n'
324    seqs = seq1 + seq2 + seq3
325    with open(os.path.join(path, pretrain_file), 'w') as fout:
326        fout.write(seqs)
327
328
329def _mk_my_pretrain_file4(path, token_delim, pretrain_file):
330    path = os.path.expanduser(path)
331    if not os.path.exists(path):
332        os.makedirs(path)
333    seq1 = token_delim.join(['a', '0.01', '0.02', '0.03', '0.04', '0.05']) + '\n'
334    seq2 = token_delim.join(['c', '0.06', '0.07', '0.08', '0.09', '0.1']) + '\n'
335    seq3 = token_delim.join(['<unk2>', '0.11', '0.12', '0.13', '0.14', '0.15']) + '\n'
336    seqs = seq1 + seq2 + seq3
337    with open(os.path.join(path, pretrain_file), 'w') as fout:
338        fout.write(seqs)
339
340
341def _mk_my_invalid_pretrain_file(path, token_delim, pretrain_file):
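    """Write an invalid pretrained embedding file whose last line ('c') has no vector elements."""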
342    path = os.path.expanduser(path)
343    if not os.path.exists(path):
344        os.makedirs(path)
345    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
346    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
347    seq3 = token_delim.join(['c']) + '\n'
348    seqs = seq1 + seq2 + seq3
349    with open(os.path.join(path, pretrain_file), 'w') as fout:
350        fout.write(seqs)
351
352
353def _mk_my_invalid_pretrain_file2(path, token_delim, pretrain_file):
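    """Write an invalid pretrained embedding file whose last vector has only 3 of the 5 elements."""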
354    path = os.path.expanduser(path)
355    if not os.path.exists(path):
356        os.makedirs(path)
357    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
358    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
359    seq3 = token_delim.join(['c', '0.6', '0.7', '0.8']) + '\n'
360    seqs = seq1 + seq2 + seq3
361    with open(os.path.join(path, pretrain_file), 'w') as fout:
362        fout.write(seqs)
363
364
365@pytest.mark.parametrize('allow_extend', [True, False])
366@pytest.mark.serial
367def test_token_embedding_from_file(tmpdir, allow_extend):
368    embed_root = str(tmpdir)
369    embed_name = 'my_embed'
370    elem_delim = '\t'
371    pretrain_file = 'my_pretrain_file.txt'
372
373    from_file = functools.partial(nlp.embedding.TokenEmbedding.from_file, allow_extend=allow_extend)
374
375    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim, pretrain_file)
376
377    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file)
378
379    my_embed = from_file(pretrain_file_path, elem_delim)
380
381    assert 'a' in my_embed
382    assert my_embed.unknown_token == '<unk>'
383    assert my_embed.unknown_token in my_embed
384
385    first_vec = my_embed.idx_to_vec[0]
386    assert_almost_equal(first_vec.asnumpy(), np.array([0, 0, 0, 0, 0]))
387
388    # Test properties
389    assert my_embed.token_to_idx == {'<unk>': 0, 'a': 1, 'b': 2}
390    assert my_embed.idx_to_token == ['<unk>', 'a', 'b']
391
392    assert_almost_equal(my_embed.idx_to_vec.asnumpy(),
393                       np.array([[0,  0,  0,  0,  0],
394                                 [0.1, 0.2, 0.3, 0.4, 0.5],
395                                 [0.6, 0.7, 0.8, 0.9, 1]])
396                       )
397
398    # Test __getitem__.
399    unk_vec = my_embed['A']
400    assert_almost_equal(unk_vec.asnumpy(), np.array([0, 0, 0, 0, 0]))
401
402    a_vec = my_embed['a']
403    assert_almost_equal(a_vec.asnumpy(), np.array([0.1, 0.2, 0.3, 0.4, 0.5]))
404
405    my_embed = from_file(pretrain_file_path, elem_delim)
406    # Test __setitem__.
407    my_embed['a'] = nd.array([1, 2, 3, 4, 5])
408    assert_almost_equal(my_embed['a'].asnumpy(), np.array([1, 2, 3, 4, 5]))
409    if allow_extend:
410        with pytest.warns(UserWarning):  # Should add multiple new tokens at a time
411            my_embed['unknown$$$'] = nd.array([0, 0, 0, 0, 0])
412        assert_almost_equal(my_embed['unknown$$$'].asnumpy(), np.array([0, 0, 0, 0, 0]))
413    else:
414        with pytest.raises(KeyError):
415            my_embed['unknown$$$'] = nd.array([0, 0, 0, 0, 0])
416    with pytest.raises(AssertionError):
417        my_embed['<unk>'] = nd.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
418    with pytest.raises(AssertionError):
419        my_embed['<unk>'] = nd.array([0])
420
421    unk_vecs = my_embed['<unk$unk@unk>', '<unk$unk@unk>']
422    assert_almost_equal(unk_vecs.asnumpy(), np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]))
423
424    # Test loaded unknown vectors.
425    pretrain_file2 = 'my_pretrain_file2.txt'
426    _mk_my_pretrain_file3(os.path.join(embed_root, embed_name), elem_delim, pretrain_file2)
427    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file2)
428    my_embed2 = from_file(pretrain_file_path, elem_delim, init_unknown_vec=nd.ones, unknown_token='<unk>')
429    unk_vec2 = my_embed2['<unk>']
430    assert_almost_equal(unk_vec2.asnumpy(), np.array([1, 1, 1, 1, 1]))
431    unk_vec2 = my_embed2['<unk$unk@unk>']
432    assert_almost_equal(unk_vec2.asnumpy(), np.array([1, 1, 1, 1, 1]))
433
434    my_embed3 = from_file(pretrain_file_path, elem_delim, init_unknown_vec=nd.ones, unknown_token='<unk1>')
435    unk_vec3 = my_embed3['<unk1>']
436    assert_almost_equal(unk_vec3.asnumpy(), np.array([1.1, 1.2, 1.3, 1.4, 1.5]))
437    unk_vec3 = my_embed3['<unk$unk@unk>']
438    assert_almost_equal(unk_vec3.asnumpy(), np.array([1.1, 1.2, 1.3, 1.4, 1.5]))
439
440    # Test error handling.
441    invalid_pretrain_file = 'invalid_pretrain_file.txt'
442    _mk_my_invalid_pretrain_file(os.path.join(embed_root, embed_name), elem_delim,
443                                 invalid_pretrain_file)
444    pretrain_file_path = os.path.join(embed_root, embed_name, invalid_pretrain_file)
445    with pytest.raises(AssertionError):
446        from_file(pretrain_file_path, elem_delim)
447
448    invalid_pretrain_file2 = 'invalid_pretrain_file2.txt'
449    _mk_my_invalid_pretrain_file2(os.path.join(embed_root, embed_name), elem_delim,
450                                  invalid_pretrain_file2)
451    pretrain_file_path = os.path.join(embed_root, embed_name, invalid_pretrain_file2)
452    with pytest.raises(AssertionError):
453        from_file(pretrain_file_path, elem_delim)
454
455
456def test_embedding_get_and_pretrain_file_names():
457    assert len(nlp.embedding.list_sources(embedding_name='fasttext')) == 486
458    assert len(nlp.embedding.list_sources(embedding_name='glove')) == 10
459    assert len(nlp.embedding.list_sources(embedding_name='word2vec')) == 3
460
461    reg = nlp.embedding.list_sources(embedding_name=None)
462
463    assert len(reg['glove']) == 10
464    assert len(reg['fasttext']) == 486
465    assert len(reg['word2vec']) == 3
466
467    with pytest.raises(KeyError):
468        nlp.embedding.list_sources('unknown$$')
469
470
471@pytest.mark.parametrize('allow_extend', [True, False])
472def test_vocab_set_embedding_with_one_custom_embedding(tmpdir, allow_extend, counter):
473    embed_root = str(tmpdir)
474    embed_name = 'my_embed'
475    elem_delim = '\t'
476    pretrain_file = 'my_pretrain_file1.txt'
477
478    from_file = functools.partial(nlp.embedding.TokenEmbedding.from_file, allow_extend=allow_extend)
479
480    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim, pretrain_file)
481
482    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file)
483
484    v1 = nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token='<unk>',
485                   padding_token=None, bos_token=None, eos_token=None, reserved_tokens=['<pad>'])
486    v1_no_unk = nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token=None,
487                          padding_token=None, bos_token=None, eos_token=None,
488                          reserved_tokens=['<pad>'])
489
490    e1 = from_file(pretrain_file_path, elem_delim, init_unknown_vec=nd.ones)
491
492    assert v1.embedding is None
493    assert v1_no_unk.embedding is None
494    v1.set_embedding(e1)
495    v1_no_unk.set_embedding(e1)
496    assert v1.embedding is not None
497    assert v1_no_unk.embedding is not None
498
499    # Test properties
500    assert v1.embedding.token_to_idx == {'<unk>': 0, '<pad>': 1, 'c': 2, 'b': 3, 'a': 4, 'some_word$': 5}
501    assert v1.embedding.idx_to_token == ['<unk>', '<pad>', 'c', 'b', 'a', 'some_word$']
502
503    assert v1_no_unk.embedding.token_to_idx == {'<pad>': 0, 'c': 1, 'b': 2, 'a': 3, 'some_word$': 4}
504    assert v1_no_unk.embedding.idx_to_token == ['<pad>', 'c', 'b', 'a', 'some_word$']
505
506    assert_almost_equal(v1.embedding.idx_to_vec.asnumpy(),
507                        np.array([[1, 1, 1, 1, 1],
508                                  [1, 1, 1, 1, 1],
509                                  [1, 1, 1, 1, 1],
510                                  [0.6, 0.7, 0.8, 0.9, 1],
511                                  [0.1, 0.2, 0.3, 0.4, 0.5],
512                                  [1, 1, 1, 1, 1]])
513                        )
514    assert_almost_equal(v1_no_unk.embedding.idx_to_vec.asnumpy(),
515                        np.array([[1, 1, 1, 1, 1],
516                                  [1, 1, 1, 1, 1],
517                                  [0.6, 0.7, 0.8, 0.9, 1],
518                                  [0.1, 0.2, 0.3, 0.4, 0.5],
519                                  [1, 1, 1, 1, 1]])
520                        )
521
522    assert_almost_equal(v1.embedding['c'].asnumpy(),
523                        np.array([1, 1, 1, 1, 1])
524                        )
525    assert_almost_equal(v1_no_unk.embedding['c'].asnumpy(),
526                        np.array([1, 1, 1, 1, 1])
527                        )
528
529    assert_almost_equal(v1.embedding[['c']].asnumpy(),
530                        np.array([[1, 1, 1, 1, 1]])
531                        )
532    assert_almost_equal(v1_no_unk.embedding[['c']].asnumpy(),
533                        np.array([[1, 1, 1, 1, 1]])
534                        )
535
536    assert_almost_equal(v1.embedding[['a', 'not_exist']].asnumpy(),
537                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
538                                  [1, 1, 1, 1, 1]])
539                        )
540    with pytest.raises(KeyError):
541        v1_no_unk.embedding['a', 'not_exist']
542
543    assert_almost_equal(v1.embedding[['a', 'b']].asnumpy(),
544                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
545                                  [0.6, 0.7, 0.8, 0.9, 1]])
546                        )
547    assert_almost_equal(v1_no_unk.embedding[['a', 'b']].asnumpy(),
548                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
549                                  [0.6, 0.7, 0.8, 0.9, 1]])
550                        )
551
552    assert_almost_equal(v1.embedding[['A', 'b']].asnumpy(),
553                        np.array([[1, 1, 1, 1, 1],
554                                  [0.6, 0.7, 0.8, 0.9, 1]])
555                        )
556    with pytest.raises(KeyError):
557        v1_no_unk.embedding['A', 'b']
558
559    v1.embedding['a'] = nd.array([2, 2, 2, 2, 2])
560    v1.embedding['b'] = nd.array([3, 3, 3, 3, 3])
561    v1_no_unk.embedding['a'] = nd.array([2, 2, 2, 2, 2])
562    v1_no_unk.embedding['b'] = nd.array([3, 3, 3, 3, 3])
563
564    assert_almost_equal(v1.embedding.idx_to_vec.asnumpy(),
565                        np.array([[1, 1, 1, 1, 1],
566                                  [1, 1, 1, 1, 1],
567                                  [1, 1, 1, 1, 1],
568                                  [3, 3, 3, 3, 3],
569                                  [2, 2, 2, 2, 2],
570                                  [1, 1, 1, 1, 1]])
571                        )
572
573    assert_almost_equal(v1_no_unk.embedding.idx_to_vec.asnumpy(),
574                        np.array([[1, 1, 1, 1, 1],
575                                  [1, 1, 1, 1, 1],
576                                  [3, 3, 3, 3, 3],
577                                  [2, 2, 2, 2, 2],
578                                  [1, 1, 1, 1, 1]])
579                        )
580
581    v1.embedding['<unk>'] = nd.array([0, 0, 0, 0, 0])
582    assert_almost_equal(v1.embedding.idx_to_vec.asnumpy(),
583                        np.array([[0, 0, 0, 0, 0],
584                                  [1, 1, 1, 1, 1],
585                                  [1, 1, 1, 1, 1],
586                                  [3, 3, 3, 3, 3],
587                                  [2, 2, 2, 2, 2],
588                                  [1, 1, 1, 1, 1]])
589                        )
590    with pytest.raises(KeyError):
591        # The TokenEmbedding assigned to a vocab is never extendable
592        v1_no_unk.embedding['<unk>'] = nd.array([0, 0, 0, 0, 0])
593    v1.embedding['<unk>'] = nd.array([10, 10, 10, 10, 10])
594    assert_almost_equal(v1.embedding.idx_to_vec.asnumpy(),
595                        np.array([[10, 10, 10, 10, 10],
596                                  [1, 1, 1, 1, 1],
597                                  [1, 1, 1, 1, 1],
598                                  [3, 3, 3, 3, 3],
599                                  [2, 2, 2, 2, 2],
600                                  [1, 1, 1, 1, 1]])
601                        )
602
603    v1.set_embedding(None)
604    assert v1.embedding is None
605    v1_no_unk.set_embedding(None)
606    assert v1_no_unk.embedding is None
607
608
609@pytest.mark.parametrize('allow_extend', [True, False])
610def test_vocab_set_embedding_with_two_custom_embeddings(tmpdir, allow_extend, counter):
611    embed_root = str(tmpdir)
612    embed_name = 'my_embed'
613    elem_delim = '\t'
614    pretrain_file1 = 'my_pretrain_file1.txt'
615    pretrain_file2 = 'my_pretrain_file2.txt'
616
617    from_file = functools.partial(nlp.embedding.TokenEmbedding.from_file, allow_extend=allow_extend)
618
619    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim, pretrain_file1)
620    _mk_my_pretrain_file2(os.path.join(embed_root, embed_name), elem_delim, pretrain_file2)
621
622    pretrain_file_path1 = os.path.join(embed_root, embed_name, pretrain_file1)
623    pretrain_file_path2 = os.path.join(embed_root, embed_name, pretrain_file2)
624
625    my_embed1 = from_file(pretrain_file_path1, elem_delim, init_unknown_vec=nd.ones)
626    my_embed2 = from_file(pretrain_file_path2, elem_delim)
627
628    v1 = nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token='<unk>',
629                   padding_token=None, bos_token=None, eos_token=None, reserved_tokens=None)
630    v1.set_embedding(my_embed1, my_embed2)
631    assert v1.embedding is not None
632    assert v1.embedding.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3, 'some_word$': 4}
633    assert v1.embedding.idx_to_token == ['<unk>', 'c', 'b', 'a', 'some_word$']
634
635    with pytest.raises(AssertionError):
636        v1.set_embedding(my_embed1, None, my_embed2)
637    assert_almost_equal(v1.embedding.idx_to_vec.asnumpy(),
638                        np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
639                                  [1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1],
640                                  [0.6, 0.7, 0.8, 0.9, 1, 0, 0, 0, 0, 0],
641                                  [0.1, 0.2, 0.3, 0.4, 0.5,
642                                   0.01, 0.02, 0.03, 0.04, 0.05],
643                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
644                        )
645
646    assert_almost_equal(v1.embedding['c'].asnumpy(),
647                        np.array([1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1])
648                        )
649
650    assert_almost_equal(v1.embedding[['b', 'not_exist']].asnumpy(),
651                        np.array([[0.6, 0.7, 0.8, 0.9, 1, 0, 0, 0, 0, 0],
652                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
653                        )
654
655    v1.embedding['a'] = nd.array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
656    v1.embedding['b'] = nd.array([3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
657
658    assert_almost_equal(v1.embedding.idx_to_vec.asnumpy(),
659                        np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
660                                  [1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1],
661                                  [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
662                                  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
663                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
664                        )
665
666    # Test loaded unknown tokens
667    pretrain_file3 = 'my_pretrain_file3.txt'
668    pretrain_file4 = 'my_pretrain_file4.txt'
669
670    _mk_my_pretrain_file3(os.path.join(embed_root, embed_name), elem_delim, pretrain_file3)
671    _mk_my_pretrain_file4(os.path.join(embed_root, embed_name), elem_delim, pretrain_file4)
672
673    pretrain_file_path3 = os.path.join(embed_root, embed_name, pretrain_file3)
674    pretrain_file_path4 = os.path.join(embed_root, embed_name, pretrain_file4)
675
676    my_embed3 = from_file(pretrain_file_path3, elem_delim, init_unknown_vec=nd.ones,
677                          unknown_token='<unk1>')
678    my_embed4 = from_file(pretrain_file_path4, elem_delim, unknown_token='<unk2>')
679
680    v2 = nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token='<unk>', padding_token=None,
681                   bos_token=None, eos_token=None, reserved_tokens=None)
682    v2.set_embedding(my_embed3, my_embed4)
683    assert v2.embedding.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3, 'some_word$': 4}
684    assert v2.embedding.idx_to_token == ['<unk>', 'c', 'b', 'a', 'some_word$']
685    assert_almost_equal(v2.embedding.idx_to_vec.asnumpy(),
686                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
687                                   0.11, 0.12, 0.13, 0.14, 0.15],
688                                  [1.1, 1.2, 1.3, 1.4, 1.5,
689                                   0.06, 0.07, 0.08, 0.09, 0.1],
690                                  [0.6, 0.7, 0.8, 0.9, 1,
691                                   0.11, 0.12, 0.13, 0.14, 0.15],
692                                  [0.1, 0.2, 0.3, 0.4, 0.5,
693                                   0.01, 0.02, 0.03, 0.04, 0.05],
694                                  [1.1, 1.2, 1.3, 1.4, 1.5,
695                                   0.11, 0.12, 0.13, 0.14, 0.15]])
696                        )
697
698    v3 = nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token='<unk1>', padding_token=None,
699                   bos_token=None, eos_token=None, reserved_tokens=None)
700    v3.set_embedding(my_embed3, my_embed4)
701    assert v3.embedding.token_to_idx == {'<unk1>': 0, 'c': 1, 'b': 2, 'a': 3, 'some_word$': 4}
702    assert v3.embedding.idx_to_token == ['<unk1>', 'c', 'b', 'a', 'some_word$']
703    assert_almost_equal(v3.embedding.idx_to_vec.asnumpy(),
704                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
705                                   0.11, 0.12, 0.13, 0.14, 0.15],
706                                  [1.1, 1.2, 1.3, 1.4, 1.5,
707                                   0.06, 0.07, 0.08, 0.09, 0.1],
708                                  [0.6, 0.7, 0.8, 0.9, 1,
709                                   0.11, 0.12, 0.13, 0.14, 0.15],
710                                  [0.1, 0.2, 0.3, 0.4, 0.5,
711                                   0.01, 0.02, 0.03, 0.04, 0.05],
712                                  [1.1, 1.2, 1.3, 1.4, 1.5,
713                                   0.11, 0.12, 0.13, 0.14, 0.15]])
714                        )
715
716    v4 = nlp.Vocab(counter, max_size=None, min_freq=1, unknown_token='<unk2>', padding_token=None,
717                   bos_token=None, eos_token=None, reserved_tokens=None)
718    v4.set_embedding(my_embed3, my_embed4)
719    assert v4.embedding.token_to_idx == {'<unk2>': 0, 'c': 1, 'b': 2, 'a': 3, 'some_word$': 4}
720    assert v4.embedding.idx_to_token == ['<unk2>', 'c', 'b', 'a', 'some_word$']
721    assert_almost_equal(v4.embedding.idx_to_vec.asnumpy(),
722                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
723                                   0.11, 0.12, 0.13, 0.14, 0.15],
724                                  [1.1, 1.2, 1.3, 1.4, 1.5,
725                                   0.06, 0.07, 0.08, 0.09, 0.1],
726                                  [0.6, 0.7, 0.8, 0.9, 1,
727                                   0.11, 0.12, 0.13, 0.14, 0.15],
728                                  [0.1, 0.2, 0.3, 0.4, 0.5,
729                                   0.01, 0.02, 0.03, 0.04, 0.05],
730                                  [1.1, 1.2, 1.3, 1.4, 1.5,
731                                   0.11, 0.12, 0.13, 0.14, 0.15]])
732                        )
733
734    counter2 = nlp.data.utils.Counter(['b', 'b', 'c', 'c', 'c', 'some_word$'])
735
736    v5 = nlp.Vocab(counter2, max_size=None, min_freq=1, unknown_token='a', padding_token=None,
737                   bos_token=None, eos_token=None, reserved_tokens=None)
738    v5.set_embedding(my_embed3, my_embed4)
739    assert v5.embedding.token_to_idx == {'a': 0, 'c': 1, 'b': 2, 'some_word$': 3}
740    assert v5.embedding.idx_to_token == ['a', 'c', 'b', 'some_word$']
741    assert_almost_equal(v5.embedding.idx_to_vec.asnumpy(),
742                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
743                                   0.11, 0.12, 0.13, 0.14, 0.15],
744                                  [1.1, 1.2, 1.3, 1.4, 1.5,
745                                   0.06, 0.07, 0.08, 0.09, 0.1],
746                                  [0.6, 0.7, 0.8, 0.9, 1,
747                                   0.11, 0.12, 0.13, 0.14, 0.15],
748                                  [1.1, 1.2, 1.3, 1.4, 1.5,
749                                   0.11, 0.12, 0.13, 0.14, 0.15]])
750                        )
751
752
753@pytest.mark.parametrize('allow_extend', [True, False])
754@pytest.mark.parametrize('unknown_token', [True, False])
755@pytest.mark.parametrize('vocab_unknown_token', [True, False])
756@pytest.mark.parametrize('initialize', [True, False])
757def test_vocab_set_embedding_with_subword_lookup_only_token_embedding(
758        allow_extend, unknown_token, vocab_unknown_token, initialize):
759    embsize = 5
760
761    class NaiveLookup:
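        """Stub unknown_lookup that claims to contain every token and returns all-ones vectors."""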
762        def __contains__(self, token):
763            return True
764
765        def __getitem__(self, tokens):
766            if isinstance(tokens, str):
767                return nd.ones(embsize)
768            else:
769                return nd.ones((len(tokens), embsize))
770
771    c = nlp.data.utils.Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
772    v = nlp.Vocab(c, max_size=None, min_freq=1,
773                  unknown_token='<unk>' if vocab_unknown_token else None,
774                  padding_token='<pad>')
775
776    assert v.embedding is None
777
778    e = nlp.embedding.TokenEmbedding(
779        unknown_lookup=NaiveLookup(), allow_extend=allow_extend,
780        unknown_token='<unk>' if unknown_token else None)
781
782    if initialize and unknown_token:
783        e[e.unknown_token] = nd.zeros(embsize)
784    elif initialize and allow_extend:
785        with pytest.warns(UserWarning):  # encouraged to batch their updates
786            e["hello"] = e.unknown_lookup["hello"]
787    else:  # Cannot initialize, even if initialize is True
788        with pytest.raises(AssertionError):
789            v.set_embedding(e)
790        return  # cannot test more
791
792    v.set_embedding(e)
793    assert v.embedding is not None
794    assert v.embedding.idx_to_vec is not None
795    assert v.embedding.idx_to_vec.shape == (len(v), embsize)
796
797    for t in c.keys():
798        assert np.all(np.isclose(1, v.embedding[t].asnumpy()))
799
800
801@pytest.mark.serial
802@pytest.mark.remote_required
803def test_download_embed():
804    @nlp.embedding.register
805    class Test(nlp.embedding.TokenEmbedding):
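        """Minimal registered TokenEmbedding backed by the tiny 'embedding_test' source."""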
806        # 33 bytes.
807        source_file_hash = \
808                {'embedding_test': ('embedding_test.vec',
809                                    '29b9a6511cf4b5aae293c44a9ec1365b74f2a2f8')}
810        namespace = 'test'
811
812        def __init__(self, embedding_root=os.path.join(get_home_dir(), 'embedding'),
813                     init_unknown_vec=nd.zeros, **kwargs):
814            source = 'embedding_test'
815            Test._check_source(self.source_file_hash, source)
816
817            file_path = Test._get_file_path(self.source_file_hash,
818                                            embedding_root, source)
819            unknown_token = kwargs.pop('unknown_token', '<unk>')
820            idx_to_token, idx_to_vec, unknown_token = self._load_embedding(
821                file_path,
822                elem_delim=' ',
823                unknown_token=unknown_token,
824                init_unknown_vec=init_unknown_vec)
825
826            super(Test, self).__init__(unknown_token=unknown_token,
827                                       init_unknown_vec=None,
828                                       idx_to_token=idx_to_token,
829                                       idx_to_vec=idx_to_vec,
830                                       **kwargs)
831
832    test_embed = nlp.embedding.create('test')
833    assert_almost_equal(test_embed['hello'].asnumpy(), (nd.arange(5) + 1).asnumpy())
834    assert_almost_equal(test_embed['world'].asnumpy(), (nd.arange(5) + 6).asnumpy())
835    assert_almost_equal(test_embed['<unk>'].asnumpy(), nd.zeros((5,)).asnumpy())
836
837
838def test_vocab_serialization():
839    # Preserving unknown_token behaviour
840    vocab = nlp.Vocab(unknown_token=None)
841    with pytest.raises(KeyError):
842        vocab['hello']
843    loaded_vocab = nlp.Vocab.from_json(vocab.to_json())
844    with pytest.raises(KeyError):
845        loaded_vocab['hello']
846
847    vocab = nlp.Vocab(unknown_token='abc')
848    vocab['hello']
849    loaded_vocab = nlp.Vocab.from_json(vocab.to_json())
850    loaded_vocab['hello']
851
852
853def test_token_embedding_from_serialized_file(tmpdir):
854    embed_root = str(tmpdir)
855    embed_name = 'my_embed'
856    elem_delim = '\t'
857    pretrain_file = 'my_pretrain_file.txt'
858    serialize_file = 'my_pretrain_file.npz'
859
860    _mk_my_pretrain_file(
861        os.path.join(embed_root, embed_name), elem_delim, pretrain_file)
862
863    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file)
864    serialize_file_path = os.path.join(embed_root, embed_name, serialize_file)
865
866    # Serialize the embedding in a format suitable for storage on S3 and test
867    # that loading the serialized file always gives the same result as loading
868    # the text file
869    my_embed_for_serialization = nlp.embedding.TokenEmbedding.from_file(
870        pretrain_file_path, elem_delim=elem_delim, unknown_token=None)
871    my_embed_for_serialization.serialize(serialize_file_path)
872
873    # Test w/wo unknown token
874    known_unknown_token = my_embed_for_serialization.idx_to_token[-1]
875    for unknown_token in [None, '<some_unknown_token>', known_unknown_token]:
876        my_embed_text = nlp.embedding.TokenEmbedding.from_file(
877            pretrain_file_path, elem_delim=elem_delim,
878            unknown_token=unknown_token)
879        my_embed_serialize = nlp.embedding.TokenEmbedding.from_file(
880            serialize_file_path, unknown_token=unknown_token)
881        assert my_embed_serialize == my_embed_text
882
883
884@pytest.mark.parametrize('unknown_token',
885                         ['<strangetoken>', None, nlp._constants.UNK_TOKEN])
886@pytest.mark.serial
887@pytest.mark.remote_required
888def test_token_embedding_from_file_S3_with_custom_unknown_token(unknown_token):
889    nlp.embedding.create('glove', source='glove.6B.50d',
890                         unknown_token=unknown_token)
891
892
893@pytest.mark.parametrize('load_ngrams', [True, False])
894@pytest.mark.serial
895@pytest.mark.remote_required
896def test_token_embedding_from_S3_fasttext_with_ngrams(load_ngrams):
897    embed = nlp.embedding.create('fasttext', source='wiki.simple',
898                                 load_ngrams=load_ngrams, unknown_token=None)
899    if load_ngrams:
900        embed['$$$unknownword$$$']
901    else:
902        with pytest.raises(KeyError):
903            embed['$$$unknownword$$$']
904
905
906@pytest.mark.parametrize('setinconstructor', [True, False])
907@pytest.mark.parametrize('lookup', ['naive', 'incapable'])
908@pytest.mark.parametrize('initializetokenembedding', [True, False])
909@pytest.mark.parametrize('unknown_token', [True, False])
910@pytest.mark.parametrize('allow_extend', [True, False])
911def test_token_embedding_unknown_lookup(setinconstructor, lookup,
912                                        initializetokenembedding,
913                                        unknown_token, allow_extend, tmpdir):
914    class NaiveLookup:
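        """Stub unknown_lookup that contains every token and returns all-ones vectors of size dim."""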
915        dim = 5  # Must match _mk_my_pretrain_file
916
917        def __contains__(self, token):
918            return True
919
920        def __getitem__(self, tokens):
921            if isinstance(tokens, str):
922                return nd.ones(self.dim)
923            else:
924                return nd.ones((len(tokens), self.dim))
925
926    class IncapableLookup:
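        """Stub unknown_lookup that contains no tokens and raises KeyError on every lookup."""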
927        def __contains__(self, token):
928            return False
929
930        def __getitem__(self, tokens):
931            raise KeyError
932
933    if initializetokenembedding:
934        # Load a TokenEmbedding with idx_to_vec already initialized
935        embed_root = str(tmpdir)
936        embed_name = 'my_embed'
937        elem_delim = '\t'
938        pretrain_file = 'my_pretrain_file.txt'
939        _mk_my_pretrain_file(
940            os.path.join(embed_root, embed_name), elem_delim, pretrain_file)
941        pretrain_file_path = os.path.join(embed_root, embed_name,
942                                          pretrain_file)
943        TokenEmbedding = functools.partial(
944            nlp.embedding.TokenEmbedding.from_file, pretrain_file_path,
945            elem_delim)
946    else:
947        TokenEmbedding = nlp.embedding.token_embedding.TokenEmbedding
948
949    Lookup = NaiveLookup if lookup == "naive" else IncapableLookup
950
951    if setinconstructor:
952        TokEmb = functools.partial(
953            TokenEmbedding, unknown_lookup=Lookup(), allow_extend=allow_extend,
954            unknown_token='<unk>' if unknown_token else None)
955    else:
956
957        def TokEmb(*args, **kwargs):
958            token_embedding = TokenEmbedding(
959                allow_extend=allow_extend,
960                unknown_token='<unk>' if unknown_token else None, *args,
961                **kwargs)
962            token_embedding.unknown_lookup = Lookup()
963            return token_embedding
964
965    token_embedding = TokEmb()
966    if lookup == 'incapable' and not initializetokenembedding:
967        with pytest.raises(KeyError):
968            token_embedding['hello']
969    elif lookup == 'incapable' and initializetokenembedding and not unknown_token:
970        with pytest.raises(KeyError):
971            token_embedding['hello']
972    elif lookup == 'incapable' and initializetokenembedding and unknown_token:
973        assert 'hello' not in token_embedding.token_to_idx
974        assert np.all(np.isclose(0, token_embedding['hello'].asnumpy()))
975        assert 'hello' not in token_embedding.token_to_idx
976    elif lookup != 'naive':
977        raise RuntimeError('Invalid test parameterization.')
978    else:
979        assert 'hello' not in token_embedding.token_to_idx
980        assert np.all(np.isclose(1, token_embedding['hello'].asnumpy()))
981        assert 'hello' not in token_embedding.token_to_idx
982
983        if allow_extend:
984            with pytest.warns(UserWarning):  # encouraged to batch their updates
985                token_embedding['hello'] = token_embedding.unknown_lookup['hello']
986            assert 'hello' in token_embedding.token_to_idx
987            assert np.all(np.isclose(1, token_embedding['hello'].asnumpy()))
988
989            token_embedding[['hello2', 'world']] = \
990                token_embedding.unknown_lookup[['hello2', 'world']]
991            assert 'hello2' in token_embedding.token_to_idx
992            assert 'world' in token_embedding.token_to_idx
993            assert np.all(np.isclose(1, token_embedding['hello2'].asnumpy()))
994
995
996@pytest.mark.parametrize('initializeidxtovecbyextending', [True, False])
997def test_token_embedding_manual_extension(initializeidxtovecbyextending,
998                                          tmpdir):
999    if not initializeidxtovecbyextending:
1000        # Load a TokenEmbedding with idx_to_vec already initialized
1001        embed_root = str(tmpdir)
1002        embed_name = 'my_embed'
1003        elem_delim = '\t'
1004        pretrain_file = 'my_pretrain_file.txt'
1005        _mk_my_pretrain_file(
1006            os.path.join(embed_root, embed_name), elem_delim, pretrain_file)
1007        pretrain_file_path = os.path.join(embed_root, embed_name,
1008                                          pretrain_file)
1009        TokEmb = functools.partial(nlp.embedding.TokenEmbedding.from_file,
1010                                   pretrain_file_path, elem_delim,
1011                                   allow_extend=True)
1012    else:
1013        TokEmb = functools.partial(
1014            nlp.embedding.token_embedding.TokenEmbedding, allow_extend=True)
1015
1016    # Extend with single tokens; token_embedding._idx_to_vec may still be uninitialized here
1017    token_embedding = TokEmb()
1018    with pytest.warns(UserWarning):  # encouraged to batch their updates
1019        token_embedding['hello'] = nd.zeros(shape=(1, 5))
1020    assert np.all(np.isclose(0, token_embedding['hello'].asnumpy()))
1021
1022    token_embedding = TokEmb()
1023    with pytest.warns(UserWarning):  # encouraged to batch their updates
1024        token_embedding['hello'] = nd.zeros(shape=(5, ))
1025    assert np.all(np.isclose(0, token_embedding['hello'].asnumpy()))
1026
1027    token_embedding = TokEmb()
1028    token_embedding[['hello', 'world']] = nd.zeros(shape=(2, 5))
1029    assert np.all(np.isclose(0, token_embedding['hello'].asnumpy()))
1030    assert np.all(np.isclose(0, token_embedding['world'].asnumpy()))
1031
1032    with pytest.raises(AssertionError):
1033        token_embedding = TokEmb()
1034        token_embedding[['hello', 'world']] = nd.zeros(shape=(1, 5))
1035
1036    with pytest.raises(AssertionError):
1037        token_embedding = TokEmb()
1038        token_embedding[['hello', 'world']] = nd.zeros(shape=(5, ))
1039
1040@pytest.mark.serial
1041@pytest.mark.remote_required
1042def test_token_embedding_serialization():
1043    with warnings.catch_warnings():
1044        warnings.simplefilter("ignore")
1045        # UserWarning: New token embedding test_vocab_embed.Test registered
1046        # with name test is overriding existing token embedding
1047        # test_vocab_embed.Test
1048
1049        @nlp.embedding.register
1050        class Test(nlp.embedding.TokenEmbedding):
1051            # 33 bytes.
1052            source_file_hash = \
1053                    {'embedding_test': ('embedding_test.vec',
1054                                        '29b9a6511cf4b5aae293c44a9ec1365b74f2a2f8')}
1055            namespace = 'test'
1056
1057            def __init__(self, embedding_root=os.path.join(get_home_dir(), 'embedding'), **kwargs):
1058                source = 'embedding_test'
1059                Test._check_source(self.source_file_hash, source)
1060
1061                file_path = Test._get_file_path(self.source_file_hash, embedding_root, source)
1062
1063                unknown_token = kwargs.pop('unknown_token', '<unk>')
1064                init_unknown_vec = kwargs.pop('init_unknown_vec', nd.zeros)
1065                idx_to_token, idx_to_vec, unknown_token = self._load_embedding(
1066                    file_path, elem_delim=' ', unknown_token=unknown_token,
1067                    init_unknown_vec=init_unknown_vec)
1068
1069                super(Test,
1070                      self).__init__(unknown_token=unknown_token, init_unknown_vec=None,
1071                                     idx_to_token=idx_to_token, idx_to_vec=idx_to_vec, **kwargs)
1072
1073
1074    emb = nlp.embedding.create('test')
1075
1076    # Test uncompressed serialization
1077    file_path = os.path.join('tests', 'data', 'embedding', 'embeddings.npz')
1078    emb.serialize(file_path, compress=False)
1079    loaded_emb = Test.deserialize(file_path)
1080    assert loaded_emb == emb
1081
1082    # Test compressed serialization
1083    file_path_compressed = os.path.join('tests', 'data', 'embedding', 'embeddings_compressed.npz')
1084    emb.serialize(file_path_compressed, compress=True)
1085    loaded_emb = Test.deserialize(file_path_compressed)
1086    assert loaded_emb == emb
1087
1088
1089def test_word_embedding_evaluation_registry():
1090    with pytest.raises(RuntimeError):
1091
1092        @nlp.embedding.evaluation.register
1093        class InvalidEvaluationFunction:
1094            pass
1095
1096    with pytest.raises(KeyError):
1097        nlp.embedding.evaluation.create('invalid', 'InvalidEvaluationFunction')
1098
1099    nlp.embedding.evaluation.list_evaluation_functions()
1100    nlp.embedding.evaluation.list_evaluation_functions(kind='similarity')
1101    nlp.embedding.evaluation.list_evaluation_functions(kind='analogy')
1102    with pytest.raises(KeyError):
1103        nlp.embedding.evaluation.list_evaluation_functions('invalid')
1104
1105
1106@pytest.mark.parametrize(
1107    'similarity_function',
1108    nlp.embedding.evaluation.list_evaluation_functions('similarity'))
1109@pytest.mark.serial
1110@pytest.mark.remote_required
1111def test_word_embedding_similarity_evaluation_models(similarity_function):
1112    try:
1113        from scipy import stats
1114    except ImportError:
1115        raise ImportError('This testcase requires scipy.')
1116
1117    dataset = nlp.data.WordSim353()
1118
1119    counter = nlp.data.utils.Counter(w for wpair in dataset for w in wpair[:2])
1120    vocab = nlp.vocab.Vocab(counter)
1121    vocab.set_embedding(nlp.embedding.create('fasttext', source='wiki.simple'))
1122
1123    data = [[vocab[d[0]], vocab[d[1]], d[2]] for d in dataset]
1124    words1, words2, scores = zip(*data)
1125
1126    evaluator = nlp.embedding.evaluation.WordEmbeddingSimilarity(
1127        vocab.embedding.idx_to_vec,
1128        similarity_function=similarity_function)
1129    evaluator.initialize()
1130
1131    words1, words2 = nd.array(words1), nd.array(words2)
1132    pred_similarity = evaluator(words1, words2)
1133
1134    sr = stats.spearmanr(pred_similarity.asnumpy(), np.array(scores))
1135    assert np.isclose(0.6076485693769645, sr.correlation)
1136
1137
1138@pytest.mark.parametrize(
1139    'analogy_function',
1140    nlp.embedding.evaluation.list_evaluation_functions('analogy'))
1141@pytest.mark.serial
1142@pytest.mark.remote_required
1143def test_word_embedding_analogy_evaluation_models(analogy_function):
1144    dataset = nlp.data.GoogleAnalogyTestSet()
1145    dataset = [d for i, d in enumerate(dataset) if i < 10]
1146
1147    embedding = nlp.embedding.create('fasttext', source='wiki.simple')
1148    counter = nlp.data.utils.Counter(embedding.idx_to_token)
1149    vocab = nlp.vocab.Vocab(counter)
1150    vocab.set_embedding(embedding)
1151
1152    dataset_coded = [[vocab[d[0]], vocab[d[1]], vocab[d[2]], vocab[d[3]]]
1153                     for d in dataset]
1154    dataset_coded_nd = nd.array(dataset_coded, dtype=np.int64)
1155
1156    for k in [1, 3]:
1157        for exclude_question_words in [True, False]:
1158            evaluator = nlp.embedding.evaluation.WordEmbeddingAnalogy(
1159                idx_to_vec=vocab.embedding.idx_to_vec,
1160                analogy_function=analogy_function, k=k,
1161                exclude_question_words=exclude_question_words)
1162            evaluator.initialize()
1163
1164            words1 = dataset_coded_nd[:, 0]
1165            words2 = dataset_coded_nd[:, 1]
1166            words3 = dataset_coded_nd[:, 2]
1167            pred_idxs = evaluator(words1, words2, words3).astype(np.int64)
1168
1169            # If we don't exclude inputs, most predictions should be wrong
1170            words4 = dataset_coded_nd[:, 3]
1171            accuracy = (pred_idxs[:, 0] == words4).astype(np.float64).mean()
1172            accuracy = accuracy.asscalar()
1173            if not exclude_question_words:
1174                assert accuracy <= 0.1
1175
1176                # Instead the model would predict W3 most of the time
1177                accuracy_w3 = (pred_idxs[:, 0] == words3).astype(np.float64).mean()
1178                assert accuracy_w3.asscalar() >= 0.89
1179
1180            else:
1181                # The wiki.simple vectors don't perform very well
1182                assert accuracy >= 0.29
1183
1184            # Assert output shape
1185            assert pred_idxs.shape[1] == k
1186
1187
1188def test_subword_function_bytes():
1189    sf = nlp.vocab.create_subword_function('ByteSubwords')
1190
1191    assert [[116, 101, 115, 116]] == sf([u'test'])
1192    assert [[207, 132, 206, 181, 207, 131, 207, 132]] == sf([u'τεστ'])
1193
1194
1195def test_subword_function_ngramhashes():
1196    num_subwords = 1000
1197    sf = nlp.vocab.create_subword_function('NGramHashes', ngrams=[3, 4, 5, 6],
1198                                           num_subwords=num_subwords)
1199
1200    assert set([8, 195, 271, 500, 201, 445, 379, 831, 617, 851]) == set(sf(['test'])[0])
1201    assert set([8, 195, 271, 500, 201, 445, 379, 831, 617, 851]) == set(sf([u'test'])[0])
1202    assert set([429, 793, 101, 334, 295, 474, 145, 524, 388, 790]) == set(sf([u'τεστ'])[0])
1203    assert 1669484008 == sf.fasttext_hash_asbytes('<te')
1204    assert 1669484008 == sf.fasttext_hash_asbytes(u'<te')
1205    assert 2688791429 == sf.fasttext_hash_asbytes(u'<τε')
1206    assert 1669484008 % num_subwords == next(iter(sf.subwords_to_indices(['<te'])))
1207    assert 1669484008 % num_subwords == next(iter(sf.subwords_to_indices([u'<te'])))
1208    assert 2688791429 % num_subwords == next(iter(sf.subwords_to_indices([u'<τε'])))
1209
1210
1211@pytest.mark.parametrize('unknown_token', ['<unk>', None])
1212@pytest.mark.parametrize('padding_token', ['<pad>', '<eos>', None])  # padding_token == eos_token
1213@pytest.mark.parametrize('eos_token', ['<eos>', None])
1214@pytest.mark.parametrize('reserved_tokens', [['<tok>'], []])
1215def test_vocab_duplicate_special_tokens(unknown_token, padding_token,
1216                                        eos_token, reserved_tokens):
1217    """Different special tokens are allowed to map to the same representations.
1218
1219    Special tokens are a subset of the reserved tokens. In general, reserved
1220    tokens must not contain duplicates; however, multiple special tokens may
1221    map to the same reserved token.
1222
1223    """
1224    counter = nlp.data.utils.Counter(
1225        ['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
1226
1227    Vocab = functools.partial(nlp.Vocab,
1228                              counter,
1229                              max_size=None,
1230                              min_freq=1,
1231                              unknown_token=unknown_token,
1232                              padding_token=padding_token,
1233                              bos_token=None,
1234                              eos_token=eos_token)
1235
1236    v = Vocab(reserved_tokens=reserved_tokens)
1237
1238    # Duplicate special tokens must not corrupt the index
1239    # (Broken before GluonNLP 0.7)
1240    if eos_token is not None and padding_token == eos_token:
1241        # padding_token == eos_token; there should only be a single index for
1242        # <eos>
1243        # Before GluonNLP 0.7, idx_to_token looked like
1244        # ['<unk>', '<eos>', '<eos>', 'c', 'b', 'a']
1245        # But it should look like
1246        # ['<unk>', '<eos>', 'c', 'b', 'a']
1247        assert len(v.idx_to_token) == len(v.token_to_idx)
1248        assert len(v.idx_to_token) == len(set(v.idx_to_token))
1249
    # Specifying a special token in reserved_tokens counts as a duplicate
1251    if eos_token is not None:
1252        with pytest.raises(AssertionError):
1253            Vocab(reserved_tokens=reserved_tokens + [eos_token])
1254
1255
1256@pytest.mark.parametrize('unknown_token', ['<unk>', None])
1257@pytest.mark.parametrize('padding_token', ['<pad>', None])
1258def test_vocab_identifiers_to_tokens_sanity_checks(unknown_token,
1259                                                   padding_token, counter):
1260    Vocab = functools.partial(nlp.Vocab,
1261                              counter,
1262                              max_size=None,
1263                              min_freq=1,
1264                              unknown_token=unknown_token,
1265                              bos_token=None,
1266                              eos_token=None)
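    # Keyword arguments whose names end in '_token' register additional
    # special tokens that become accessible as attributes of the Vocab.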
1267    # Special tokens are automatically added
1268    v = Vocab(my_token='<does_not_exist_yet>')
1269    assert v.my_token == '<does_not_exist_yet>'
1270
1271    # Special token names must end in _token
1272    with pytest.raises(ValueError):
1273        Vocab(special_tok='<token>')
1274
1275    # Cannot set internals
1276    with pytest.raises(ValueError):
1277        Vocab(_private_token='<token>')
1278
    # Enforce the uniqueness requirement on the reserved_tokens argument
1280    with pytest.raises(AssertionError):
1281        Vocab(reserved_tokens=['<token>'], special_token='<token>')
1282
1283    # Many-to-one mapping is allowed
1284    v = Vocab(first_name_of_token='<token>', second_name_of_token='<token>')
1285    assert v.first_name_of_token == '<token>'
1286    assert v.second_name_of_token == '<token>'
1287    if unknown_token:
1288        v = Vocab(unk_token=unknown_token)
1289        assert v.unk_token == unknown_token
1290        assert v.unk_token == v.unknown_token
1291    if padding_token:
1292        v = Vocab(pad_token=padding_token)
1293        assert v.pad_token == padding_token
1294        assert v.pad_token == v.padding_token
1295
1296
1297@pytest.mark.parametrize('unknown_token', ['<unk>', None])
1298@pytest.mark.parametrize('padding_token', ['<pad>', None])
1299@pytest.mark.parametrize('identifiers_to_tokens', [{
1300    'important_token': '<imp>'
1301}, {}])
1302@pytest.mark.parametrize('test_serialization', [True, False])
1303def test_vocab_identifiers_to_tokens(unknown_token, padding_token,
1304                                     identifiers_to_tokens, test_serialization,
1305                                     counter):
1306    vocab = nlp.Vocab(counter,
1307                      max_size=None,
1308                      min_freq=1,
1309                      unknown_token=unknown_token,
1310                      padding_token=padding_token,
1311                      bos_token=None,
1312                      eos_token=None,
1313                      **identifiers_to_tokens)
1314
1315    if test_serialization:
1316        vocab = nlp.Vocab.from_json(vocab.to_json())
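        # The JSON round trip must preserve user-defined special tokens;
        # the assertions below run on the deserialized vocabulary.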
1317
1318    if identifiers_to_tokens:
1319        for identifier, token in identifiers_to_tokens.items():
1320            assert hasattr(vocab, identifier)
1321            assert getattr(vocab, identifier) == token
1322            assert token in vocab.reserved_tokens
1323
1324    assert getattr(vocab, 'unknown_token') == unknown_token
1325    assert getattr(vocab, 'padding_token') == padding_token
1326
1327
1328@pytest.mark.parametrize('unknown_token', ['<unk>', None])
1329@pytest.mark.parametrize('padding_token', ['<pad>', None])
1330def test_vocab_token_to_idx(unknown_token, padding_token, counter):
1331    reserved_tokens = ['<tok>']
1332    Vocab = functools.partial(nlp.Vocab,
1333                              counter,
1334                              max_size=None,
1335                              min_freq=1,
1336                              unknown_token=unknown_token,
1337                              padding_token=padding_token,
1338                              bos_token=None,
1339                              eos_token=None,
1340                              reserved_tokens=reserved_tokens)
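    # token_to_idx pins selected tokens to user-chosen indices; the remaining
    # tokens are assigned the remaining indices automatically.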
1341    tokens = set(counter)
1342    if unknown_token is not None:
1343        tokens.add(unknown_token)
1344    if padding_token is not None:
1345        tokens.add(padding_token)
1346    if isinstance(reserved_tokens, dict):
1347        tokens.update(reserved_tokens.values())
1348    elif isinstance(reserved_tokens, list):
1349        tokens.update(reserved_tokens)
1350
    # Sanity checks: unknown tokens and out-of-range indices in token_to_idx must be rejected
1352    valid_token = next(counter.elements())
1353    invalid_token = 'token_that_does_not_occur_in_vocab'
1354    assert invalid_token not in counter
1355    with pytest.raises(ValueError):
1356        Vocab(token_to_idx={invalid_token: 0})
1357    with pytest.raises(ValueError):
1358        Vocab(token_to_idx={valid_token: -1})
1359    with pytest.raises(ValueError):
1360        Vocab(token_to_idx={valid_token: len(tokens)})
1361
1362    def token_idx_check(token, idx):
1363        assert v[token] == idx
1364        assert v.token_to_idx[token] == idx
1365        assert v.idx_to_token[idx] == token
1366
1367    def consistency_check(v):
1368        assert set(v.idx_to_token) == set(v.token_to_idx.keys())
1369        assert set(v.token_to_idx.keys()) == set(tokens)
1370        assert set(v.token_to_idx.values()) == set(range(len(tokens)))
1371
1372    # Manual checks with special tokens
1373    if unknown_token:
1374        v = Vocab(token_to_idx={unknown_token: len(tokens) - 1})
1375        consistency_check(v)
        token_idx_check(unknown_token, len(tokens) - 1)
1377    if padding_token:
1378        v = Vocab(token_to_idx={padding_token: len(tokens) - 1})
1379        consistency_check(v)
        token_idx_check(padding_token, len(tokens) - 1)
1381
    # Test 10 random user-specified index assignments for random token subsets
    for _ in range(10):
        num_pinned = random.randint(0, len(tokens) - 1)
        # random.sample requires a sequence, so sort the token set first
        token_to_idx = dict(
            zip(random.sample(sorted(tokens), num_pinned),
                random.sample(range(num_pinned), num_pinned)))
1390        v = Vocab(token_to_idx=token_to_idx)
1391        consistency_check(v)
1392        for token, idx in token_to_idx.items():
1393            token_idx_check(token, idx)
1394
1395
1396@pytest.mark.parametrize('unknown_token', ['<unk>', None])
1397@pytest.mark.parametrize('padding_token', ['<pad>', '<eos>', None])
1398@pytest.mark.parametrize('eos_token', ['<eos>', None])
1399@pytest.mark.parametrize('reserved_tokens', [['<tok>'], []])
def test_vocab_special_token_in_reserved_tokens_raises(
        unknown_token, padding_token, eos_token, reserved_tokens, counter):
    """Explicitly listing a special token in reserved_tokens must be rejected.

    Special tokens are a subset of the reserved tokens and may share a
    representation among themselves, but repeating one of them in the
    reserved_tokens argument counts as a duplicate.

    """
    Vocab = functools.partial(nlp.Vocab,
                              counter,
                              max_size=None,
                              min_freq=1,
                              unknown_token=unknown_token,
                              padding_token=padding_token,
                              bos_token=None,
                              eos_token=eos_token)

    # Constructing the vocabulary with special tokens that may share a
    # representation must not raise
    Vocab(reserved_tokens=reserved_tokens)
1419
    # Specifying a special token in reserved_tokens counts as a duplicate
1421    if eos_token is not None:
1422        with pytest.raises(AssertionError):
1423            Vocab(reserved_tokens=reserved_tokens + [eos_token])
1424
1425
1426def test_vocab_backwards_compatibility_prior_v0_7_corrupted_index_bug():
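    # The stored vocabulary was serialized by GluonNLP < 0.7, where
    # idx_to_token could contain a duplicate '<eos>' entry. Deserializing it
    # must warn and keep token_to_idx free of duplicates while preserving the
    # old idx_to_token layout.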
1427    with open('tests/data/vocab/backward_compat_0_7_corrupted_index', 'r') as f:
        with pytest.warns(UserWarning):  # Detected a corrupted index in the deserialized vocabulary
1429            v = nlp.Vocab.from_json(f.read())
1430
1431    assert len(set(v.idx_to_token)) == len(v.token_to_idx)
1432    assert v['<unk>'] == 0
1433    assert v['<bos>'] == 2
1434    assert v['<eos>'] == 3
1435    assert v['token'] == 4
1436
1437    assert v.idx_to_token[0] == '<unk>'
    assert v.idx_to_token[1] == '<eos>'  # corruption preserved for backward compatibility
1440    assert v.idx_to_token[2] == '<bos>'
1441    assert v.idx_to_token[3] == '<eos>'
1442    assert v.idx_to_token[4] == 'token'
1443
1444
1445@pytest.mark.parametrize('unknown_token', ['<unk>', '<UNK>'])
1446@pytest.mark.parametrize('padding_token', ['<pad>', '<eos>', None])
1447@pytest.mark.parametrize('eos_token', ['<eos>', None])
1448@pytest.mark.parametrize('reserved_tokens', [['<tok>'], []])
1449def test_vocab_remapped_unknown_token_idx(unknown_token, padding_token, eos_token, reserved_tokens,
1450                                          counter):
1451    Vocab = functools.partial(nlp.Vocab, counter, max_size=None, min_freq=1,
1452                              unknown_token=unknown_token, padding_token=padding_token,
1453                              bos_token=None, eos_token=eos_token)
1454
1455    v = Vocab()
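    # The unknown token defaults to index 0, so out-of-vocabulary words map
    # to index 0.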
1456    assert v['UNKNOWNWORD'] == 0
1457
1458    v = Vocab(token_to_idx={unknown_token: 1})
1459    assert v['UNKNOWNWORD'] == 1
1460
1461def test_vocab_consistency():
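    # Vocabularies built from the same counter and special tokens must assign
    # identical indices to those special tokens.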
1462    v0 = nlp.Vocab({'a': 1}, mask_token='[MASK]', sep_token='[SEP]',
1463                   cls_token='[CLS]')
1464    v1 = nlp.Vocab({'a': 1}, mask_token='[MASK]', sep_token='[SEP]',
1465                   cls_token='[CLS]')
1466    assert v0[v0.mask_token] == v1[v1.mask_token]
1467    assert v0[v0.sep_token] == v1[v1.sep_token]
1468    assert v0[v0.cls_token] == v1[v1.cls_token]
1469