# -*- coding: utf-8 -*-
# Natural Language Toolkit: Language Model Unit Tests
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Ilia Kurenkov <ilia.kurenkov@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

from __future__ import division

import math
import sys
import unittest

from six import add_metaclass

from nltk.lm import (
    Vocabulary,
    MLE,
    Lidstone,
    Laplace,
    WittenBellInterpolated,
    KneserNeyInterpolated,
)
from nltk.lm.preprocessing import padded_everygrams


def _prepare_test_data(ngram_order):
    return (
        Vocabulary(["a", "b", "c", "d", "z", "<s>", "</s>"], unk_cutoff=1),
        [
            list(padded_everygrams(ngram_order, sent))
            for sent in (list("abcd"), list("egadbe"))
        ],
    )
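
# For reference: padded_everygrams(order, sent) pads the sentence with "<s>"
# and "</s>" (order - 1 of each) and yields every 1- to order-gram of the
# padded sequence.  E.g. for order=2 and list("ab") that includes ("<s>",),
# ("<s>", "a"), ("a",), ("a", "b"), ("b",), ("b", "</s>") and ("</s>",).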


class ParametrizeTestsMeta(type):
    """Metaclass for generating parametrized tests."""

    def __new__(cls, name, bases, dct):
        contexts = (
            ("a",),
            ("c",),
            (u"<s>",),
            ("b",),
            (u"<UNK>",),
            ("d",),
            ("e",),
            ("r",),
            ("w",),
        )
        for i, c in enumerate(contexts):
            dct["test_sumto1_{0}".format(i)] = cls.add_sum_to_1_test(c)
        scores = dct.get("score_tests", [])
        for i, (word, context, expected_score) in enumerate(scores):
            dct["test_score_{0}".format(i)] = cls.add_score_test(
                word, context, expected_score
            )
        return super(ParametrizeTestsMeta, cls).__new__(cls, name, bases, dct)

    @classmethod
    def add_score_test(cls, word, context, expected_score):
        if sys.version_info > (3, 5):
            message = "word='{word}', context={context}"
        else:
            # Python 2 doesn't report the mismatched values if we pass a custom
            # message, so we have to report them manually.
            message = (
                "{score} != {expected_score} within 4 places, "
                "word='{word}', context={context}"
            )

        def test_method(self):
            score = self.model.score(word, context)
            self.assertAlmostEqual(
                score, expected_score, msg=message.format(**locals()), places=4
            )

        return test_method

    @classmethod
    def add_sum_to_1_test(cls, context):
        def test(self):
            s = sum(self.model.score(w, context) for w in self.model.vocab)
            self.assertAlmostEqual(s, 1.0, msg="The context is {}".format(context))

        return test
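
# How the metaclass is used below: a TestCase subclass that defines
# score_tests = [(word, context, expected_score), ...] and uses this metaclass
# gets one generated method per entry (test_score_0, test_score_1, ...), each
# asserting self.model.score(word, context) == expected_score to four decimal
# places, plus the test_sumto1_* checks for every context listed above.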


@add_metaclass(ParametrizeTestsMeta)
class MleBigramTests(unittest.TestCase):
    """Unit tests for the MLE bigram model."""

    score_tests = [
        ("d", ["c"], 1),
        # Unseen ngrams should yield 0
        ("d", ["e"], 0),
        # In-vocabulary but unseen unigrams should also be 0
        ("z", None, 0),
        # N unigrams = 14
        # count('a') = 2
        ("a", None, 2.0 / 14),
        # out of vocabulary, so "y" gets the "<UNK>" score; count('<UNK>') = 3
        ("y", None, 3.0 / 14),
    ]
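
    # For reference: MLE scores are plain relative frequencies, i.e. the number
    # of times the context was followed by the word divided by the number of
    # times the context occurred.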

    def setUp(self):
        vocab, training_text = _prepare_test_data(2)
        self.model = MLE(2, vocabulary=vocab)
        self.model.fit(training_text)

    def test_logscore_zero_score(self):
        # logscore of unseen ngrams should be -inf
        logscore = self.model.logscore("d", ["e"])

        self.assertTrue(math.isinf(logscore))

    def test_entropy_perplexity_seen(self):
        # ngrams seen during training
        trained = [
            ("<s>", "a"),
            ("a", "b"),
            ("b", "<UNK>"),
            ("<UNK>", "a"),
            ("a", "d"),
            ("d", "</s>"),
        ]
        # Ngram = Log score
        # <s>, a    = -1
        # a, b      = -1
        # b, UNK    = -1
        # UNK, a    = -1.585
        # a, d      = -1
        # d, </s>   = -1
        # TOTAL logscores   = -6.585
        # - AVG logscores   = 1.0975
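        # (entropy is the negative average of the per-ngram log2 scores, and
        # perplexity is 2 ** entropy)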
        H = 1.0975
        perplexity = 2.1398

        self.assertAlmostEqual(H, self.model.entropy(trained), places=4)
        self.assertAlmostEqual(perplexity, self.model.perplexity(trained), places=4)

    def test_entropy_perplexity_unseen(self):
        # In MLE, even one unseen ngram should make entropy and perplexity infinite
        untrained = [("<s>", "a"), ("a", "c"), ("c", "d"), ("d", "</s>")]

        self.assertTrue(math.isinf(self.model.entropy(untrained)))
        self.assertTrue(math.isinf(self.model.perplexity(untrained)))

    def test_entropy_perplexity_unigrams(self):
        # word = score, log score
        # <s>   = 0.1429, -2.8074
        # a     = 0.1429, -2.8074
        # c     = 0.0714, -3.8073
        # UNK   = 0.2143, -2.2224
        # d     = 0.1429, -2.8074
        # c     = 0.0714, -3.8073
        # </s>  = 0.1429, -2.8074
        # TOTAL logscores = -21.0666
        # - AVG logscores = 3.0095
        H = 3.0095
        perplexity = 8.0529

        text = [("<s>",), ("a",), ("c",), ("-",), ("d",), ("c",), ("</s>",)]

        self.assertAlmostEqual(H, self.model.entropy(text), places=4)
        self.assertAlmostEqual(perplexity, self.model.perplexity(text), places=4)


@add_metaclass(ParametrizeTestsMeta)
class MleTrigramTests(unittest.TestCase):
    """MLE trigram model tests"""

    score_tests = [
        # count(d | b, c) = 1
        # count(b, c) = 1
        ("d", ("b", "c"), 1),
        # count(d | c) = 1
        # count(c) = 1
        ("d", ["c"], 1),
        # total number of tokens is 18, of which "a" occurred 2 times
        ("a", None, 2.0 / 18),
        # in vocabulary but unseen
        ("z", None, 0),
        # out of vocabulary should use "UNK" score
        ("y", None, 3.0 / 18),
    ]

    def setUp(self):
        vocab, training_text = _prepare_test_data(3)
        self.model = MLE(3, vocabulary=vocab)
        self.model.fit(training_text)


@add_metaclass(ParametrizeTestsMeta)
class LidstoneBigramTests(unittest.TestCase):
    """unit tests for Lidstone class"""

    score_tests = [
        # count(d | c) = 1
        # *count(d | c) = 1 + 0.1 = 1.1
        # sum(count(w | c) for w in vocab) = count(c) = 1
        # *sum (smoothed denominator) = 1 + 0.1 * 8 = 1.8
        ("d", ["c"], 1.1 / 1.8),
        # Total unigrams: 14
        # Vocab size: 8
        # Denominator: 14 + 0.8 = 14.8
        # count("a") = 2
        # *count("a") = 2.1
        ("a", None, 2.1 / 14.8),
        # in vocabulary but unseen
        # count("z") = 0
        # *count("z") = 0.1
        ("z", None, 0.1 / 14.8),
        # out of vocabulary should use "UNK" score
        # count("<UNK>") = 3
        # *count("<UNK>") = 3.1
        ("y", None, 3.1 / 14.8),
    ]
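
    # For reference: the Lidstone-smoothed score used above is
    # (count + gamma) / (context_count + gamma * vocab_size), with gamma = 0.1
    # and a vocabulary of 8 types (including "<UNK>").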

    def setUp(self):
        vocab, training_text = _prepare_test_data(2)
        self.model = Lidstone(0.1, 2, vocabulary=vocab)
        self.model.fit(training_text)

    def test_gamma(self):
        self.assertEqual(0.1, self.model.gamma)

    def test_entropy_perplexity(self):
        text = [
            ("<s>", "a"),
            ("a", "c"),
            ("c", "<UNK>"),
            ("<UNK>", "d"),
            ("d", "c"),
            ("c", "</s>"),
        ]
        # Unlike MLE this should be able to handle completely novel ngrams
        # Ngram = score, log score
        # <s>, a    = 0.3929, -1.3479
        # a, c      = 0.0357, -4.8074
        # c, UNK    = 0.0(5), -4.1699
        # UNK, d    = 0.0263, -5.2479
        # d, c      = 0.0357, -4.8074
        # c, </s>   = 0.0(5), -4.1699
        # TOTAL logscore: -24.5504
        # - AVG logscore: 4.0917
        H = 4.0917
        perplexity = 17.0504
        self.assertAlmostEqual(H, self.model.entropy(text), places=4)
        self.assertAlmostEqual(perplexity, self.model.perplexity(text), places=4)


@add_metaclass(ParametrizeTestsMeta)
class LidstoneTrigramTests(unittest.TestCase):
    score_tests = [
        # Logic behind this is the same as for bigram model
        ("d", ["c"], 1.1 / 1.8),
        # if we choose a word that hasn't appeared after 'c'
        ("e", ["c"], 0.1 / 1.8),
        # Trigram score now
        ("d", ["b", "c"], 1.1 / 1.8),
        ("e", ["b", "c"], 0.1 / 1.8),
    ]

    def setUp(self):
        vocab, training_text = _prepare_test_data(3)
        self.model = Lidstone(0.1, 3, vocabulary=vocab)
        self.model.fit(training_text)

@add_metaclass(ParametrizeTestsMeta)
class LaplaceBigramTests(unittest.TestCase):
    """unit tests for Laplace class"""

    score_tests = [
        # basic sanity-check:
        # count(d | c) = 1
        # *count(d | c) = 2
        # sum(count(w | c) for w in vocab) = count(c) = 1
        # *sum (smoothed denominator) = 1 + 8 = 9
        ("d", ["c"], 2.0 / 9),
        # Total unigrams: 14
        # Vocab size: 8
        # Denominator: 14 + 8 = 22
        # count("a") = 2
        # *count("a") = 3
        ("a", None, 3.0 / 22),
        # in vocabulary but unseen
        # count("z") = 0
        # *count("z") = 1
        ("z", None, 1.0 / 22),
        # out of vocabulary should use "UNK" score
        # count("<UNK>") = 3
        # *count("<UNK>") = 4
        ("y", None, 4.0 / 22),
    ]
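
    # For reference: Laplace smoothing is just Lidstone with gamma fixed to 1,
    # i.e. (count + 1) / (context_count + vocab_size), with vocab_size = 8 here.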

    def setUp(self):
        vocab, training_text = _prepare_test_data(2)
        self.model = Laplace(2, vocabulary=vocab)
        self.model.fit(training_text)

    def test_gamma(self):
        # Make sure the gamma is set to 1
        self.assertEqual(1, self.model.gamma)

    def test_entropy_perplexity(self):
        text = [
            ("<s>", "a"),
            ("a", "c"),
            ("c", "<UNK>"),
            ("<UNK>", "d"),
            ("d", "c"),
            ("c", "</s>"),
        ]
        # Unlike MLE this should be able to handle completely novel ngrams
        # Ngram = score, log score
        # <s>, a    = 0.2, -2.3219
        # a, c      = 0.1, -3.3219
        # c, UNK    = 0.(1), -3.1699
        # UNK, d    = 0.(09), -3.4594
        # d, c      = 0.1, -3.3219
        # c, </s>   = 0.(1), -3.1699
        # Total logscores: -18.7651
        # - AVG logscores: 3.1275
        H = 3.1275
        perplexity = 8.7393
        self.assertAlmostEqual(H, self.model.entropy(text), places=4)
        self.assertAlmostEqual(perplexity, self.model.perplexity(text), places=4)


@add_metaclass(ParametrizeTestsMeta)
class WittenBellInterpolatedTrigramTests(unittest.TestCase):
    def setUp(self):
        vocab, training_text = _prepare_test_data(3)
        self.model = WittenBellInterpolated(3, vocabulary=vocab)
        self.model.fit(training_text)

    score_tests = [
        # Unigram scores revert to MLE by default
        # Total unigrams: 18
        # count('c'): 1
        ("c", None, 1.0 / 18),
        # in vocabulary but unseen
        # count("z") = 0
        ("z", None, 0.0 / 18),
        # out of vocabulary should use "UNK" score
        # count("<UNK>") = 3
        ("y", None, 3.0 / 18),
        # gamma(['b']) = 0.1111
        # mle.score('c', ['b']) = 0.5
        # (1 - gamma) * mle_score + gamma * unigram_score
        #   = (1 - 0.1111) * 0.5 + 0.1111 * (1 / 18)
        ("c", ["b"], (1 - 0.1111) * 0.5 + 0.1111 * 1 / 18),
        # building on that, let's try 'a b c' as the trigram
        # gamma(['a', 'b']) = 0.0667
        # mle("c", ["a", "b"]) = 1
        ("c", ["a", "b"], (1 - 0.0667) + 0.0667 * ((1 - 0.1111) * 0.5 + 0.1111 / 18)),
    ]
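
    # For reference (roughly): Witten-Bell interpolation computes
    # (1 - gamma(ctx)) * mle_score(w | ctx) + gamma(ctx) * score(w | shorter ctx),
    # where gamma(ctx) grows with the number of distinct words seen after ctx
    # (hence the gamma values quoted in the comments above).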


@add_metaclass(ParametrizeTestsMeta)
class KneserNeyInterpolatedTrigramTests(unittest.TestCase):
    def setUp(self):
        vocab, training_text = _prepare_test_data(3)
        self.model = KneserNeyInterpolated(3, vocabulary=vocab)
        self.model.fit(training_text)

    score_tests = [
        # Unigram scores revert to the uniform distribution
        # Vocab size: 8
        ("c", None, 1.0 / 8),
        # in vocabulary but unseen, still uses uniform
        ("z", None, 1.0 / 8),
        # out of vocabulary should use "UNK" score, i.e. again uniform
        ("y", None, 1.0 / 8),
        # alpha = count('bc') - discount = 1 - 0.1 = 0.9
        # gamma(['b']) = discount * number of unique words that follow ['b'] = 0.1 * 2
        # normalizer = total number of bigrams with this context = 2
        # final score = (alpha + gamma * unigram_score("c")) / normalizer
        ("c", ["b"], (0.9 + 0.2 * (1 / 8)) / 2),
        # building on that, let's try 'a b c' as the trigram
        # alpha = count('abc') - discount = 1 - 0.1 = 0.9
        # gamma(['a', 'b']) = 0.1 * 1
        # normalizer = total number of trigrams with prefix "ab" = 1, so we can ignore it
        ("c", ["a", "b"], 0.9 + 0.1 * ((0.9 + 0.2 * (1 / 8)) / 2)),
    ]
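
    # For reference (roughly): interpolated Kneser-Ney computes
    # max(count(ctx + w) - discount, 0) / count(ctx)
    #   + gamma(ctx) * lower_order_score(w),
    # with gamma(ctx) = discount * (number of distinct words seen after ctx)
    # / count(ctx) and a default discount of 0.1, which is what the alpha /
    # gamma / normalizer comments above spell out.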


class NgramModelTextGenerationTests(unittest.TestCase):
    """Using MLE estimator, generate some text."""

    def setUp(self):
        vocab, training_text = _prepare_test_data(3)
        self.model = MLE(3, vocabulary=vocab)
        self.model.fit(training_text)
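
    # For reference: generate() samples the next word from the model's
    # distribution given the last (order - 1) words of text_seed, backing off
    # to shorter contexts when the full context was never seen; random_seed
    # makes that choice deterministic, which is what these tests rely on.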

    def test_generate_one_no_context(self):
        self.assertEqual(self.model.generate(random_seed=3), "<UNK>")

    def test_generate_one_limiting_context(self):
        # We don't need random_seed for contexts with only one continuation
        self.assertEqual(self.model.generate(text_seed=["c"]), "d")
        self.assertEqual(self.model.generate(text_seed=["b", "c"]), "d")
        self.assertEqual(self.model.generate(text_seed=["a", "c"]), "d")

    def test_generate_one_varied_context(self):
        # When context doesn't limit our options enough, seed the random choice
        self.assertEqual(
            self.model.generate(text_seed=("a", "<s>"), random_seed=2), "a"
        )

    def test_generate_no_seed_unigrams(self):
        self.assertEqual(
            self.model.generate(5, random_seed=3),
            ["<UNK>", "</s>", "</s>", "</s>", "</s>"],
        )

    def test_generate_with_text_seed(self):
        self.assertEqual(
            self.model.generate(5, text_seed=("<s>", "e"), random_seed=3),
            ["<UNK>", "a", "d", "b", "<UNK>"],
        )

    def test_generate_oov_text_seed(self):
        self.assertEqual(
            self.model.generate(text_seed=("aliens",), random_seed=3),
            self.model.generate(text_seed=("<UNK>",), random_seed=3),
        )

    def test_generate_None_text_seed(self):
        # should raise a TypeError when we try to look it up in the vocabulary
        with self.assertRaises(TypeError):
            self.model.generate(text_seed=(None,))

        # This will work
        self.assertEqual(
            self.model.generate(text_seed=None, random_seed=3),
            self.model.generate(random_seed=3),
        )