1# -*- coding: utf-8 -*- 2# Natural Language Toolkit: Language Model Unit Tests 3# 4# Copyright (C) 2001-2019 NLTK Project 5# Author: Ilia Kurenkov <ilia.kurenkov@gmail.com> 6# URL: <http://nltk.org/> 7# For license information, see LICENSE.TXT 8 9from __future__ import division 10 11import math 12import sys 13import unittest 14 15from six import add_metaclass 16 17from nltk.lm import ( 18 Vocabulary, 19 MLE, 20 Lidstone, 21 Laplace, 22 WittenBellInterpolated, 23 KneserNeyInterpolated, 24) 25from nltk.lm.preprocessing import padded_everygrams 26 27 28def _prepare_test_data(ngram_order): 29 return ( 30 Vocabulary(["a", "b", "c", "d", "z", "<s>", "</s>"], unk_cutoff=1), 31 [ 32 list(padded_everygrams(ngram_order, sent)) 33 for sent in (list("abcd"), list("egadbe")) 34 ], 35 ) 36 37 38class ParametrizeTestsMeta(type): 39 """Metaclass for generating parametrized tests.""" 40 41 def __new__(cls, name, bases, dct): 42 contexts = ( 43 ("a",), 44 ("c",), 45 (u"<s>",), 46 ("b",), 47 (u"<UNK>",), 48 ("d",), 49 ("e",), 50 ("r",), 51 ("w",), 52 ) 53 for i, c in enumerate(contexts): 54 dct["test_sumto1_{0}".format(i)] = cls.add_sum_to_1_test(c) 55 scores = dct.get("score_tests", []) 56 for i, (word, context, expected_score) in enumerate(scores): 57 dct["test_score_{0}".format(i)] = cls.add_score_test( 58 word, context, expected_score 59 ) 60 return super(ParametrizeTestsMeta, cls).__new__(cls, name, bases, dct) 61 62 @classmethod 63 def add_score_test(cls, word, context, expected_score): 64 if sys.version_info > (3, 5): 65 message = "word='{word}', context={context}" 66 else: 67 # Python 2 doesn't report the mismatched values if we pass a custom 68 # message, so we have to report them manually. 69 message = ( 70 "{score} != {expected_score} within 4 places, " 71 "word='{word}', context={context}" 72 ) 73 74 def test_method(self): 75 score = self.model.score(word, context) 76 self.assertAlmostEqual( 77 score, expected_score, msg=message.format(**locals()), places=4 78 ) 79 80 return test_method 81 82 @classmethod 83 def add_sum_to_1_test(cls, context): 84 def test(self): 85 s = sum(self.model.score(w, context) for w in self.model.vocab) 86 self.assertAlmostEqual(s, 1.0, msg="The context is {}".format(context)) 87 88 return test 89 90 91@add_metaclass(ParametrizeTestsMeta) 92class MleBigramTests(unittest.TestCase): 93 """unit tests for MLENgramModel class""" 94 95 score_tests = [ 96 ("d", ["c"], 1), 97 # Unseen ngrams should yield 0 98 ("d", ["e"], 0), 99 # Unigrams should also be 0 100 ("z", None, 0), 101 # N unigrams = 14 102 # count('a') = 2 103 ("a", None, 2.0 / 14), 104 # count('y') = 3 105 ("y", None, 3.0 / 14), 106 ] 107 108 def setUp(self): 109 vocab, training_text = _prepare_test_data(2) 110 self.model = MLE(2, vocabulary=vocab) 111 self.model.fit(training_text) 112 113 def test_logscore_zero_score(self): 114 # logscore of unseen ngrams should be -inf 115 logscore = self.model.logscore("d", ["e"]) 116 117 self.assertTrue(math.isinf(logscore)) 118 119 def test_entropy_perplexity_seen(self): 120 # ngrams seen during training 121 trained = [ 122 ("<s>", "a"), 123 ("a", "b"), 124 ("b", "<UNK>"), 125 ("<UNK>", "a"), 126 ("a", "d"), 127 ("d", "</s>"), 128 ] 129 # Ngram = Log score 130 # <s>, a = -1 131 # a, b = -1 132 # b, UNK = -1 133 # UNK, a = -1.585 134 # a, d = -1 135 # d, </s> = -1 136 # TOTAL logscores = -6.585 137 # - AVG logscores = 1.0975 138 H = 1.0975 139 perplexity = 2.1398 140 141 self.assertAlmostEqual(H, self.model.entropy(trained), places=4) 142 self.assertAlmostEqual(perplexity, self.model.perplexity(trained), places=4) 143 144 def test_entropy_perplexity_unseen(self): 145 # In MLE, even one unseen ngram should make entropy and perplexity infinite 146 untrained = [("<s>", "a"), ("a", "c"), ("c", "d"), ("d", "</s>")] 147 148 self.assertTrue(math.isinf(self.model.entropy(untrained))) 149 self.assertTrue(math.isinf(self.model.perplexity(untrained))) 150 151 def test_entropy_perplexity_unigrams(self): 152 # word = score, log score 153 # <s> = 0.1429, -2.8074 154 # a = 0.1429, -2.8074 155 # c = 0.0714, -3.8073 156 # UNK = 0.2143, -2.2224 157 # d = 0.1429, -2.8074 158 # c = 0.0714, -3.8073 159 # </s> = 0.1429, -2.8074 160 # TOTAL logscores = -21.6243 161 # - AVG logscores = 3.0095 162 H = 3.0095 163 perplexity = 8.0529 164 165 text = [("<s>",), ("a",), ("c",), ("-",), ("d",), ("c",), ("</s>",)] 166 167 self.assertAlmostEqual(H, self.model.entropy(text), places=4) 168 self.assertAlmostEqual(perplexity, self.model.perplexity(text), places=4) 169 170 171@add_metaclass(ParametrizeTestsMeta) 172class MleTrigramTests(unittest.TestCase): 173 """MLE trigram model tests""" 174 175 score_tests = [ 176 # count(d | b, c) = 1 177 # count(b, c) = 1 178 ("d", ("b", "c"), 1), 179 # count(d | c) = 1 180 # count(c) = 1 181 ("d", ["c"], 1), 182 # total number of tokens is 18, of which "a" occured 2 times 183 ("a", None, 2.0 / 18), 184 # in vocabulary but unseen 185 ("z", None, 0), 186 # out of vocabulary should use "UNK" score 187 ("y", None, 3.0 / 18), 188 ] 189 190 def setUp(self): 191 vocab, training_text = _prepare_test_data(3) 192 self.model = MLE(3, vocabulary=vocab) 193 self.model.fit(training_text) 194 195 196@add_metaclass(ParametrizeTestsMeta) 197class LidstoneBigramTests(unittest.TestCase): 198 """unit tests for Lidstone class""" 199 200 score_tests = [ 201 # count(d | c) = 1 202 # *count(d | c) = 1.1 203 # Count(w | c for w in vocab) = 1 204 # *Count(w | c for w in vocab) = 1.8 205 ("d", ["c"], 1.1 / 1.8), 206 # Total unigrams: 14 207 # Vocab size: 8 208 # Denominator: 14 + 0.8 = 14.8 209 # count("a") = 2 210 # *count("a") = 2.1 211 ("a", None, 2.1 / 14.8), 212 # in vocabulary but unseen 213 # count("z") = 0 214 # *count("z") = 0.1 215 ("z", None, 0.1 / 14.8), 216 # out of vocabulary should use "UNK" score 217 # count("<UNK>") = 3 218 # *count("<UNK>") = 3.1 219 ("y", None, 3.1 / 14.8), 220 ] 221 222 def setUp(self): 223 vocab, training_text = _prepare_test_data(2) 224 self.model = Lidstone(0.1, 2, vocabulary=vocab) 225 self.model.fit(training_text) 226 227 def test_gamma(self): 228 self.assertEqual(0.1, self.model.gamma) 229 230 def test_entropy_perplexity(self): 231 text = [ 232 ("<s>", "a"), 233 ("a", "c"), 234 ("c", "<UNK>"), 235 ("<UNK>", "d"), 236 ("d", "c"), 237 ("c", "</s>"), 238 ] 239 # Unlike MLE this should be able to handle completely novel ngrams 240 # Ngram = score, log score 241 # <s>, a = 0.3929, -1.3479 242 # a, c = 0.0357, -4.8074 243 # c, UNK = 0.0(5), -4.1699 244 # UNK, d = 0.0263, -5.2479 245 # d, c = 0.0357, -4.8074 246 # c, </s> = 0.0(5), -4.1699 247 # TOTAL logscore: −24.5504 248 # - AVG logscore: 4.0917 249 H = 4.0917 250 perplexity = 17.0504 251 self.assertAlmostEqual(H, self.model.entropy(text), places=4) 252 self.assertAlmostEqual(perplexity, self.model.perplexity(text), places=4) 253 254 255@add_metaclass(ParametrizeTestsMeta) 256class LidstoneTrigramTests(unittest.TestCase): 257 score_tests = [ 258 # Logic behind this is the same as for bigram model 259 ("d", ["c"], 1.1 / 1.8), 260 # if we choose a word that hasn't appeared after (b, c) 261 ("e", ["c"], 0.1 / 1.8), 262 # Trigram score now 263 ("d", ["b", "c"], 1.1 / 1.8), 264 ("e", ["b", "c"], 0.1 / 1.8), 265 ] 266 267 def setUp(self): 268 vocab, training_text = _prepare_test_data(3) 269 self.model = Lidstone(0.1, 3, vocabulary=vocab) 270 self.model.fit(training_text) 271 272 273@add_metaclass(ParametrizeTestsMeta) 274class LaplaceBigramTests(unittest.TestCase): 275 """unit tests for Laplace class""" 276 277 score_tests = [ 278 # basic sanity-check: 279 # count(d | c) = 1 280 # *count(d | c) = 2 281 # Count(w | c for w in vocab) = 1 282 # *Count(w | c for w in vocab) = 9 283 ("d", ["c"], 2.0 / 9), 284 # Total unigrams: 14 285 # Vocab size: 8 286 # Denominator: 14 + 8 = 22 287 # count("a") = 2 288 # *count("a") = 3 289 ("a", None, 3.0 / 22), 290 # in vocabulary but unseen 291 # count("z") = 0 292 # *count("z") = 1 293 ("z", None, 1.0 / 22), 294 # out of vocabulary should use "UNK" score 295 # count("<UNK>") = 3 296 # *count("<UNK>") = 4 297 ("y", None, 4.0 / 22), 298 ] 299 300 def setUp(self): 301 vocab, training_text = _prepare_test_data(2) 302 self.model = Laplace(2, vocabulary=vocab) 303 self.model.fit(training_text) 304 305 def test_gamma(self): 306 # Make sure the gamma is set to 1 307 self.assertEqual(1, self.model.gamma) 308 309 def test_entropy_perplexity(self): 310 text = [ 311 ("<s>", "a"), 312 ("a", "c"), 313 ("c", "<UNK>"), 314 ("<UNK>", "d"), 315 ("d", "c"), 316 ("c", "</s>"), 317 ] 318 # Unlike MLE this should be able to handle completely novel ngrams 319 # Ngram = score, log score 320 # <s>, a = 0.2, -2.3219 321 # a, c = 0.1, -3.3219 322 # c, UNK = 0.(1), -3.1699 323 # UNK, d = 0.(09), 3.4594 324 # d, c = 0.1 -3.3219 325 # c, </s> = 0.(1), -3.1699 326 # Total logscores: −18.7651 327 # - AVG logscores: 3.1275 328 H = 3.1275 329 perplexity = 8.7393 330 self.assertAlmostEqual(H, self.model.entropy(text), places=4) 331 self.assertAlmostEqual(perplexity, self.model.perplexity(text), places=4) 332 333 334@add_metaclass(ParametrizeTestsMeta) 335class WittenBellInterpolatedTrigramTests(unittest.TestCase): 336 def setUp(self): 337 vocab, training_text = _prepare_test_data(3) 338 self.model = WittenBellInterpolated(3, vocabulary=vocab) 339 self.model.fit(training_text) 340 341 score_tests = [ 342 # For unigram scores by default revert to MLE 343 # Total unigrams: 18 344 # count('c'): 1 345 ("c", None, 1.0 / 18), 346 # in vocabulary but unseen 347 # count("z") = 0 348 ("z", None, 0.0 / 18), 349 # out of vocabulary should use "UNK" score 350 # count("<UNK>") = 3 351 ("y", None, 3.0 / 18), 352 # gamma(['b']) = 0.1111 353 # mle.score('c', ['b']) = 0.5 354 # (1 - gamma) * mle + gamma * mle('c') ~= 0.45 + .3 / 18 355 ("c", ["b"], (1 - 0.1111) * 0.5 + 0.1111 * 1 / 18), 356 # building on that, let's try 'a b c' as the trigram 357 # gamma(['a', 'b']) = 0.0667 358 # mle("c", ["a", "b"]) = 1 359 ("c", ["a", "b"], (1 - 0.0667) + 0.0667 * ((1 - 0.1111) * 0.5 + 0.1111 / 18)), 360 ] 361 362 363@add_metaclass(ParametrizeTestsMeta) 364class KneserNeyInterpolatedTrigramTests(unittest.TestCase): 365 def setUp(self): 366 vocab, training_text = _prepare_test_data(3) 367 self.model = KneserNeyInterpolated(3, vocabulary=vocab) 368 self.model.fit(training_text) 369 370 score_tests = [ 371 # For unigram scores revert to uniform 372 # Vocab size: 8 373 # count('c'): 1 374 ("c", None, 1.0 / 8), 375 # in vocabulary but unseen, still uses uniform 376 ("z", None, 1 / 8), 377 # out of vocabulary should use "UNK" score, i.e. again uniform 378 ("y", None, 1.0 / 8), 379 # alpha = count('bc') - discount = 1 - 0.1 = 0.9 380 # gamma(['b']) = discount * number of unique words that follow ['b'] = 0.1 * 2 381 # normalizer = total number of bigrams with this context = 2 382 # the final should be: (alpha + gamma * unigram_score("c")) 383 ("c", ["b"], (0.9 + 0.2 * (1 / 8)) / 2), 384 # building on that, let's try 'a b c' as the trigram 385 # alpha = count('abc') - discount = 1 - 0.1 = 0.9 386 # gamma(['a', 'b']) = 0.1 * 1 387 # normalizer = total number of trigrams with prefix "ab" = 1 => we can ignore it! 388 ("c", ["a", "b"], 0.9 + 0.1 * ((0.9 + 0.2 * (1 / 8)) / 2)), 389 ] 390 391 392class NgramModelTextGenerationTests(unittest.TestCase): 393 """Using MLE estimator, generate some text.""" 394 395 def setUp(self): 396 vocab, training_text = _prepare_test_data(3) 397 self.model = MLE(3, vocabulary=vocab) 398 self.model.fit(training_text) 399 400 def test_generate_one_no_context(self): 401 self.assertEqual(self.model.generate(random_seed=3), "<UNK>") 402 403 def test_generate_one_limiting_context(self): 404 # We don't need random_seed for contexts with only one continuation 405 self.assertEqual(self.model.generate(text_seed=["c"]), "d") 406 self.assertEqual(self.model.generate(text_seed=["b", "c"]), "d") 407 self.assertEqual(self.model.generate(text_seed=["a", "c"]), "d") 408 409 def test_generate_one_varied_context(self): 410 # When context doesn't limit our options enough, seed the random choice 411 self.assertEqual( 412 self.model.generate(text_seed=("a", "<s>"), random_seed=2), "a" 413 ) 414 415 def test_generate_no_seed_unigrams(self): 416 self.assertEqual( 417 self.model.generate(5, random_seed=3), 418 ["<UNK>", "</s>", "</s>", "</s>", "</s>"], 419 ) 420 421 def test_generate_with_text_seed(self): 422 self.assertEqual( 423 self.model.generate(5, text_seed=("<s>", "e"), random_seed=3), 424 ["<UNK>", "a", "d", "b", "<UNK>"], 425 ) 426 427 def test_generate_oov_text_seed(self): 428 self.assertEqual( 429 self.model.generate(text_seed=("aliens",), random_seed=3), 430 self.model.generate(text_seed=("<UNK>",), random_seed=3), 431 ) 432 433 def test_generate_None_text_seed(self): 434 # should crash with type error when we try to look it up in vocabulary 435 with self.assertRaises(TypeError): 436 self.model.generate(text_seed=(None,)) 437 438 # This will work 439 self.assertEqual( 440 self.model.generate(text_seed=None, random_seed=3), 441 self.model.generate(random_seed=3), 442 ) 443