1"""Test suite for statistics module, including helper NumericTestCase and 2approx_equal function. 3 4""" 5 6import bisect 7import collections 8import collections.abc 9import copy 10import decimal 11import doctest 12import itertools 13import math 14import pickle 15import random 16import sys 17import unittest 18from test import support 19from test.support import import_helper, requires_IEEE_754 20 21from decimal import Decimal 22from fractions import Fraction 23 24 25# Module to be tested. 26import statistics 27 28 29# === Helper functions and class === 30 31def sign(x): 32 """Return -1.0 for negatives, including -0.0, otherwise +1.0.""" 33 return math.copysign(1, x) 34 35def _nan_equal(a, b): 36 """Return True if a and b are both the same kind of NAN. 37 38 >>> _nan_equal(Decimal('NAN'), Decimal('NAN')) 39 True 40 >>> _nan_equal(Decimal('sNAN'), Decimal('sNAN')) 41 True 42 >>> _nan_equal(Decimal('NAN'), Decimal('sNAN')) 43 False 44 >>> _nan_equal(Decimal(42), Decimal('NAN')) 45 False 46 47 >>> _nan_equal(float('NAN'), float('NAN')) 48 True 49 >>> _nan_equal(float('NAN'), 0.5) 50 False 51 52 >>> _nan_equal(float('NAN'), Decimal('NAN')) 53 False 54 55 NAN payloads are not compared. 56 """ 57 if type(a) is not type(b): 58 return False 59 if isinstance(a, float): 60 return math.isnan(a) and math.isnan(b) 61 aexp = a.as_tuple()[2] 62 bexp = b.as_tuple()[2] 63 return (aexp == bexp) and (aexp in ('n', 'N')) # Both NAN or both sNAN. 64 65 66def _calc_errors(actual, expected): 67 """Return the absolute and relative errors between two numbers. 68 69 >>> _calc_errors(100, 75) 70 (25, 0.25) 71 >>> _calc_errors(100, 100) 72 (0, 0.0) 73 74 Returns the (absolute error, relative error) between the two arguments. 75 """ 76 base = max(abs(actual), abs(expected)) 77 abs_err = abs(actual - expected) 78 rel_err = abs_err/base if base else float('inf') 79 return (abs_err, rel_err) 80 81 82def approx_equal(x, y, tol=1e-12, rel=1e-7): 83 """approx_equal(x, y [, tol [, rel]]) => True|False 84 85 Return True if numbers x and y are approximately equal, to within some 86 margin of error, otherwise return False. Numbers which compare equal 87 will also compare approximately equal. 88 89 x is approximately equal to y if the difference between them is less than 90 an absolute error tol or a relative error rel, whichever is bigger. 91 92 If given, both tol and rel must be finite, non-negative numbers. If not 93 given, default values are tol=1e-12 and rel=1e-7. 94 95 >>> approx_equal(1.2589, 1.2587, tol=0.0003, rel=0) 96 True 97 >>> approx_equal(1.2589, 1.2587, tol=0.0001, rel=0) 98 False 99 100 Absolute error is defined as abs(x-y); if that is less than or equal to 101 tol, x and y are considered approximately equal. 102 103 Relative error is defined as abs((x-y)/x) or abs((x-y)/y), whichever is 104 smaller, provided x or y are not zero. If that figure is less than or 105 equal to rel, x and y are considered approximately equal. 106 107 Complex numbers are not directly supported. If you wish to compare to 108 complex numbers, extract their real and imaginary parts and compare them 109 individually. 110 111 NANs always compare unequal, even with themselves. Infinities compare 112 approximately equal if they have the same sign (both positive or both 113 negative). Infinities with different signs compare unequal; so do 114 comparisons of infinities with finite numbers. 115 """ 116 if tol < 0 or rel < 0: 117 raise ValueError('error tolerances must be non-negative') 118 # NANs are never equal to anything, approximately or otherwise. 119 if math.isnan(x) or math.isnan(y): 120 return False 121 # Numbers which compare equal also compare approximately equal. 122 if x == y: 123 # This includes the case of two infinities with the same sign. 124 return True 125 if math.isinf(x) or math.isinf(y): 126 # This includes the case of two infinities of opposite sign, or 127 # one infinity and one finite number. 128 return False 129 # Two finite numbers. 130 actual_error = abs(x - y) 131 allowed_error = max(tol, rel*max(abs(x), abs(y))) 132 return actual_error <= allowed_error 133 134 135# This class exists only as somewhere to stick a docstring containing 136# doctests. The following docstring and tests were originally in a separate 137# module. Now that it has been merged in here, I need somewhere to hang the. 138# docstring. Ultimately, this class will die, and the information below will 139# either become redundant, or be moved into more appropriate places. 140class _DoNothing: 141 """ 142 When doing numeric work, especially with floats, exact equality is often 143 not what you want. Due to round-off error, it is often a bad idea to try 144 to compare floats with equality. Instead the usual procedure is to test 145 them with some (hopefully small!) allowance for error. 146 147 The ``approx_equal`` function allows you to specify either an absolute 148 error tolerance, or a relative error, or both. 149 150 Absolute error tolerances are simple, but you need to know the magnitude 151 of the quantities being compared: 152 153 >>> approx_equal(12.345, 12.346, tol=1e-3) 154 True 155 >>> approx_equal(12.345e6, 12.346e6, tol=1e-3) # tol is too small. 156 False 157 158 Relative errors are more suitable when the values you are comparing can 159 vary in magnitude: 160 161 >>> approx_equal(12.345, 12.346, rel=1e-4) 162 True 163 >>> approx_equal(12.345e6, 12.346e6, rel=1e-4) 164 True 165 166 but a naive implementation of relative error testing can run into trouble 167 around zero. 168 169 If you supply both an absolute tolerance and a relative error, the 170 comparison succeeds if either individual test succeeds: 171 172 >>> approx_equal(12.345e6, 12.346e6, tol=1e-3, rel=1e-4) 173 True 174 175 """ 176 pass 177 178 179 180# We prefer this for testing numeric values that may not be exactly equal, 181# and avoid using TestCase.assertAlmostEqual, because it sucks :-) 182 183py_statistics = import_helper.import_fresh_module('statistics', 184 blocked=['_statistics']) 185c_statistics = import_helper.import_fresh_module('statistics', 186 fresh=['_statistics']) 187 188 189class TestModules(unittest.TestCase): 190 func_names = ['_normal_dist_inv_cdf'] 191 192 def test_py_functions(self): 193 for fname in self.func_names: 194 self.assertEqual(getattr(py_statistics, fname).__module__, 'statistics') 195 196 @unittest.skipUnless(c_statistics, 'requires _statistics') 197 def test_c_functions(self): 198 for fname in self.func_names: 199 self.assertEqual(getattr(c_statistics, fname).__module__, '_statistics') 200 201 202class NumericTestCase(unittest.TestCase): 203 """Unit test class for numeric work. 204 205 This subclasses TestCase. In addition to the standard method 206 ``TestCase.assertAlmostEqual``, ``assertApproxEqual`` is provided. 207 """ 208 # By default, we expect exact equality, unless overridden. 209 tol = rel = 0 210 211 def assertApproxEqual( 212 self, first, second, tol=None, rel=None, msg=None 213 ): 214 """Test passes if ``first`` and ``second`` are approximately equal. 215 216 This test passes if ``first`` and ``second`` are equal to 217 within ``tol``, an absolute error, or ``rel``, a relative error. 218 219 If either ``tol`` or ``rel`` are None or not given, they default to 220 test attributes of the same name (by default, 0). 221 222 The objects may be either numbers, or sequences of numbers. Sequences 223 are tested element-by-element. 224 225 >>> class MyTest(NumericTestCase): 226 ... def test_number(self): 227 ... x = 1.0/6 228 ... y = sum([x]*6) 229 ... self.assertApproxEqual(y, 1.0, tol=1e-15) 230 ... def test_sequence(self): 231 ... a = [1.001, 1.001e-10, 1.001e10] 232 ... b = [1.0, 1e-10, 1e10] 233 ... self.assertApproxEqual(a, b, rel=1e-3) 234 ... 235 >>> import unittest 236 >>> from io import StringIO # Suppress test runner output. 237 >>> suite = unittest.TestLoader().loadTestsFromTestCase(MyTest) 238 >>> unittest.TextTestRunner(stream=StringIO()).run(suite) 239 <unittest.runner.TextTestResult run=2 errors=0 failures=0> 240 241 """ 242 if tol is None: 243 tol = self.tol 244 if rel is None: 245 rel = self.rel 246 if ( 247 isinstance(first, collections.abc.Sequence) and 248 isinstance(second, collections.abc.Sequence) 249 ): 250 check = self._check_approx_seq 251 else: 252 check = self._check_approx_num 253 check(first, second, tol, rel, msg) 254 255 def _check_approx_seq(self, first, second, tol, rel, msg): 256 if len(first) != len(second): 257 standardMsg = ( 258 "sequences differ in length: %d items != %d items" 259 % (len(first), len(second)) 260 ) 261 msg = self._formatMessage(msg, standardMsg) 262 raise self.failureException(msg) 263 for i, (a,e) in enumerate(zip(first, second)): 264 self._check_approx_num(a, e, tol, rel, msg, i) 265 266 def _check_approx_num(self, first, second, tol, rel, msg, idx=None): 267 if approx_equal(first, second, tol, rel): 268 # Test passes. Return early, we are done. 269 return None 270 # Otherwise we failed. 271 standardMsg = self._make_std_err_msg(first, second, tol, rel, idx) 272 msg = self._formatMessage(msg, standardMsg) 273 raise self.failureException(msg) 274 275 @staticmethod 276 def _make_std_err_msg(first, second, tol, rel, idx): 277 # Create the standard error message for approx_equal failures. 278 assert first != second 279 template = ( 280 ' %r != %r\n' 281 ' values differ by more than tol=%r and rel=%r\n' 282 ' -> absolute error = %r\n' 283 ' -> relative error = %r' 284 ) 285 if idx is not None: 286 header = 'numeric sequences first differ at index %d.\n' % idx 287 template = header + template 288 # Calculate actual errors: 289 abs_err, rel_err = _calc_errors(first, second) 290 return template % (first, second, tol, rel, abs_err, rel_err) 291 292 293# ======================== 294# === Test the helpers === 295# ======================== 296 297class TestSign(unittest.TestCase): 298 """Test that the helper function sign() works correctly.""" 299 def testZeroes(self): 300 # Test that signed zeroes report their sign correctly. 301 self.assertEqual(sign(0.0), +1) 302 self.assertEqual(sign(-0.0), -1) 303 304 305# --- Tests for approx_equal --- 306 307class ApproxEqualSymmetryTest(unittest.TestCase): 308 # Test symmetry of approx_equal. 309 310 def test_relative_symmetry(self): 311 # Check that approx_equal treats relative error symmetrically. 312 # (a-b)/a is usually not equal to (a-b)/b. Ensure that this 313 # doesn't matter. 314 # 315 # Note: the reason for this test is that an early version 316 # of approx_equal was not symmetric. A relative error test 317 # would pass, or fail, depending on which value was passed 318 # as the first argument. 319 # 320 args1 = [2456, 37.8, -12.45, Decimal('2.54'), Fraction(17, 54)] 321 args2 = [2459, 37.2, -12.41, Decimal('2.59'), Fraction(15, 54)] 322 assert len(args1) == len(args2) 323 for a, b in zip(args1, args2): 324 self.do_relative_symmetry(a, b) 325 326 def do_relative_symmetry(self, a, b): 327 a, b = min(a, b), max(a, b) 328 assert a < b 329 delta = b - a # The absolute difference between the values. 330 rel_err1, rel_err2 = abs(delta/a), abs(delta/b) 331 # Choose an error margin halfway between the two. 332 rel = (rel_err1 + rel_err2)/2 333 # Now see that values a and b compare approx equal regardless of 334 # which is given first. 335 self.assertTrue(approx_equal(a, b, tol=0, rel=rel)) 336 self.assertTrue(approx_equal(b, a, tol=0, rel=rel)) 337 338 def test_symmetry(self): 339 # Test that approx_equal(a, b) == approx_equal(b, a) 340 args = [-23, -2, 5, 107, 93568] 341 delta = 2 342 for a in args: 343 for type_ in (int, float, Decimal, Fraction): 344 x = type_(a)*100 345 y = x + delta 346 r = abs(delta/max(x, y)) 347 # There are five cases to check: 348 # 1) actual error <= tol, <= rel 349 self.do_symmetry_test(x, y, tol=delta, rel=r) 350 self.do_symmetry_test(x, y, tol=delta+1, rel=2*r) 351 # 2) actual error > tol, > rel 352 self.do_symmetry_test(x, y, tol=delta-1, rel=r/2) 353 # 3) actual error <= tol, > rel 354 self.do_symmetry_test(x, y, tol=delta, rel=r/2) 355 # 4) actual error > tol, <= rel 356 self.do_symmetry_test(x, y, tol=delta-1, rel=r) 357 self.do_symmetry_test(x, y, tol=delta-1, rel=2*r) 358 # 5) exact equality test 359 self.do_symmetry_test(x, x, tol=0, rel=0) 360 self.do_symmetry_test(x, y, tol=0, rel=0) 361 362 def do_symmetry_test(self, a, b, tol, rel): 363 template = "approx_equal comparisons don't match for %r" 364 flag1 = approx_equal(a, b, tol, rel) 365 flag2 = approx_equal(b, a, tol, rel) 366 self.assertEqual(flag1, flag2, template.format((a, b, tol, rel))) 367 368 369class ApproxEqualExactTest(unittest.TestCase): 370 # Test the approx_equal function with exactly equal values. 371 # Equal values should compare as approximately equal. 372 # Test cases for exactly equal values, which should compare approx 373 # equal regardless of the error tolerances given. 374 375 def do_exactly_equal_test(self, x, tol, rel): 376 result = approx_equal(x, x, tol=tol, rel=rel) 377 self.assertTrue(result, 'equality failure for x=%r' % x) 378 result = approx_equal(-x, -x, tol=tol, rel=rel) 379 self.assertTrue(result, 'equality failure for x=%r' % -x) 380 381 def test_exactly_equal_ints(self): 382 # Test that equal int values are exactly equal. 383 for n in [42, 19740, 14974, 230, 1795, 700245, 36587]: 384 self.do_exactly_equal_test(n, 0, 0) 385 386 def test_exactly_equal_floats(self): 387 # Test that equal float values are exactly equal. 388 for x in [0.42, 1.9740, 1497.4, 23.0, 179.5, 70.0245, 36.587]: 389 self.do_exactly_equal_test(x, 0, 0) 390 391 def test_exactly_equal_fractions(self): 392 # Test that equal Fraction values are exactly equal. 393 F = Fraction 394 for f in [F(1, 2), F(0), F(5, 3), F(9, 7), F(35, 36), F(3, 7)]: 395 self.do_exactly_equal_test(f, 0, 0) 396 397 def test_exactly_equal_decimals(self): 398 # Test that equal Decimal values are exactly equal. 399 D = Decimal 400 for d in map(D, "8.2 31.274 912.04 16.745 1.2047".split()): 401 self.do_exactly_equal_test(d, 0, 0) 402 403 def test_exactly_equal_absolute(self): 404 # Test that equal values are exactly equal with an absolute error. 405 for n in [16, 1013, 1372, 1198, 971, 4]: 406 # Test as ints. 407 self.do_exactly_equal_test(n, 0.01, 0) 408 # Test as floats. 409 self.do_exactly_equal_test(n/10, 0.01, 0) 410 # Test as Fractions. 411 f = Fraction(n, 1234) 412 self.do_exactly_equal_test(f, 0.01, 0) 413 414 def test_exactly_equal_absolute_decimals(self): 415 # Test equal Decimal values are exactly equal with an absolute error. 416 self.do_exactly_equal_test(Decimal("3.571"), Decimal("0.01"), 0) 417 self.do_exactly_equal_test(-Decimal("81.3971"), Decimal("0.01"), 0) 418 419 def test_exactly_equal_relative(self): 420 # Test that equal values are exactly equal with a relative error. 421 for x in [8347, 101.3, -7910.28, Fraction(5, 21)]: 422 self.do_exactly_equal_test(x, 0, 0.01) 423 self.do_exactly_equal_test(Decimal("11.68"), 0, Decimal("0.01")) 424 425 def test_exactly_equal_both(self): 426 # Test that equal values are equal when both tol and rel are given. 427 for x in [41017, 16.742, -813.02, Fraction(3, 8)]: 428 self.do_exactly_equal_test(x, 0.1, 0.01) 429 D = Decimal 430 self.do_exactly_equal_test(D("7.2"), D("0.1"), D("0.01")) 431 432 433class ApproxEqualUnequalTest(unittest.TestCase): 434 # Unequal values should compare unequal with zero error tolerances. 435 # Test cases for unequal values, with exact equality test. 436 437 def do_exactly_unequal_test(self, x): 438 for a in (x, -x): 439 result = approx_equal(a, a+1, tol=0, rel=0) 440 self.assertFalse(result, 'inequality failure for x=%r' % a) 441 442 def test_exactly_unequal_ints(self): 443 # Test unequal int values are unequal with zero error tolerance. 444 for n in [951, 572305, 478, 917, 17240]: 445 self.do_exactly_unequal_test(n) 446 447 def test_exactly_unequal_floats(self): 448 # Test unequal float values are unequal with zero error tolerance. 449 for x in [9.51, 5723.05, 47.8, 9.17, 17.24]: 450 self.do_exactly_unequal_test(x) 451 452 def test_exactly_unequal_fractions(self): 453 # Test that unequal Fractions are unequal with zero error tolerance. 454 F = Fraction 455 for f in [F(1, 5), F(7, 9), F(12, 11), F(101, 99023)]: 456 self.do_exactly_unequal_test(f) 457 458 def test_exactly_unequal_decimals(self): 459 # Test that unequal Decimals are unequal with zero error tolerance. 460 for d in map(Decimal, "3.1415 298.12 3.47 18.996 0.00245".split()): 461 self.do_exactly_unequal_test(d) 462 463 464class ApproxEqualInexactTest(unittest.TestCase): 465 # Inexact test cases for approx_error. 466 # Test cases when comparing two values that are not exactly equal. 467 468 # === Absolute error tests === 469 470 def do_approx_equal_abs_test(self, x, delta): 471 template = "Test failure for x={!r}, y={!r}" 472 for y in (x + delta, x - delta): 473 msg = template.format(x, y) 474 self.assertTrue(approx_equal(x, y, tol=2*delta, rel=0), msg) 475 self.assertFalse(approx_equal(x, y, tol=delta/2, rel=0), msg) 476 477 def test_approx_equal_absolute_ints(self): 478 # Test approximate equality of ints with an absolute error. 479 for n in [-10737, -1975, -7, -2, 0, 1, 9, 37, 423, 9874, 23789110]: 480 self.do_approx_equal_abs_test(n, 10) 481 self.do_approx_equal_abs_test(n, 2) 482 483 def test_approx_equal_absolute_floats(self): 484 # Test approximate equality of floats with an absolute error. 485 for x in [-284.126, -97.1, -3.4, -2.15, 0.5, 1.0, 7.8, 4.23, 3817.4]: 486 self.do_approx_equal_abs_test(x, 1.5) 487 self.do_approx_equal_abs_test(x, 0.01) 488 self.do_approx_equal_abs_test(x, 0.0001) 489 490 def test_approx_equal_absolute_fractions(self): 491 # Test approximate equality of Fractions with an absolute error. 492 delta = Fraction(1, 29) 493 numerators = [-84, -15, -2, -1, 0, 1, 5, 17, 23, 34, 71] 494 for f in (Fraction(n, 29) for n in numerators): 495 self.do_approx_equal_abs_test(f, delta) 496 self.do_approx_equal_abs_test(f, float(delta)) 497 498 def test_approx_equal_absolute_decimals(self): 499 # Test approximate equality of Decimals with an absolute error. 500 delta = Decimal("0.01") 501 for d in map(Decimal, "1.0 3.5 36.08 61.79 7912.3648".split()): 502 self.do_approx_equal_abs_test(d, delta) 503 self.do_approx_equal_abs_test(-d, delta) 504 505 def test_cross_zero(self): 506 # Test for the case of the two values having opposite signs. 507 self.assertTrue(approx_equal(1e-5, -1e-5, tol=1e-4, rel=0)) 508 509 # === Relative error tests === 510 511 def do_approx_equal_rel_test(self, x, delta): 512 template = "Test failure for x={!r}, y={!r}" 513 for y in (x*(1+delta), x*(1-delta)): 514 msg = template.format(x, y) 515 self.assertTrue(approx_equal(x, y, tol=0, rel=2*delta), msg) 516 self.assertFalse(approx_equal(x, y, tol=0, rel=delta/2), msg) 517 518 def test_approx_equal_relative_ints(self): 519 # Test approximate equality of ints with a relative error. 520 self.assertTrue(approx_equal(64, 47, tol=0, rel=0.36)) 521 self.assertTrue(approx_equal(64, 47, tol=0, rel=0.37)) 522 # --- 523 self.assertTrue(approx_equal(449, 512, tol=0, rel=0.125)) 524 self.assertTrue(approx_equal(448, 512, tol=0, rel=0.125)) 525 self.assertFalse(approx_equal(447, 512, tol=0, rel=0.125)) 526 527 def test_approx_equal_relative_floats(self): 528 # Test approximate equality of floats with a relative error. 529 for x in [-178.34, -0.1, 0.1, 1.0, 36.97, 2847.136, 9145.074]: 530 self.do_approx_equal_rel_test(x, 0.02) 531 self.do_approx_equal_rel_test(x, 0.0001) 532 533 def test_approx_equal_relative_fractions(self): 534 # Test approximate equality of Fractions with a relative error. 535 F = Fraction 536 delta = Fraction(3, 8) 537 for f in [F(3, 84), F(17, 30), F(49, 50), F(92, 85)]: 538 for d in (delta, float(delta)): 539 self.do_approx_equal_rel_test(f, d) 540 self.do_approx_equal_rel_test(-f, d) 541 542 def test_approx_equal_relative_decimals(self): 543 # Test approximate equality of Decimals with a relative error. 544 for d in map(Decimal, "0.02 1.0 5.7 13.67 94.138 91027.9321".split()): 545 self.do_approx_equal_rel_test(d, Decimal("0.001")) 546 self.do_approx_equal_rel_test(-d, Decimal("0.05")) 547 548 # === Both absolute and relative error tests === 549 550 # There are four cases to consider: 551 # 1) actual error <= both absolute and relative error 552 # 2) actual error <= absolute error but > relative error 553 # 3) actual error <= relative error but > absolute error 554 # 4) actual error > both absolute and relative error 555 556 def do_check_both(self, a, b, tol, rel, tol_flag, rel_flag): 557 check = self.assertTrue if tol_flag else self.assertFalse 558 check(approx_equal(a, b, tol=tol, rel=0)) 559 check = self.assertTrue if rel_flag else self.assertFalse 560 check(approx_equal(a, b, tol=0, rel=rel)) 561 check = self.assertTrue if (tol_flag or rel_flag) else self.assertFalse 562 check(approx_equal(a, b, tol=tol, rel=rel)) 563 564 def test_approx_equal_both1(self): 565 # Test actual error <= both absolute and relative error. 566 self.do_check_both(7.955, 7.952, 0.004, 3.8e-4, True, True) 567 self.do_check_both(-7.387, -7.386, 0.002, 0.0002, True, True) 568 569 def test_approx_equal_both2(self): 570 # Test actual error <= absolute error but > relative error. 571 self.do_check_both(7.955, 7.952, 0.004, 3.7e-4, True, False) 572 573 def test_approx_equal_both3(self): 574 # Test actual error <= relative error but > absolute error. 575 self.do_check_both(7.955, 7.952, 0.001, 3.8e-4, False, True) 576 577 def test_approx_equal_both4(self): 578 # Test actual error > both absolute and relative error. 579 self.do_check_both(2.78, 2.75, 0.01, 0.001, False, False) 580 self.do_check_both(971.44, 971.47, 0.02, 3e-5, False, False) 581 582 583class ApproxEqualSpecialsTest(unittest.TestCase): 584 # Test approx_equal with NANs and INFs and zeroes. 585 586 def test_inf(self): 587 for type_ in (float, Decimal): 588 inf = type_('inf') 589 self.assertTrue(approx_equal(inf, inf)) 590 self.assertTrue(approx_equal(inf, inf, 0, 0)) 591 self.assertTrue(approx_equal(inf, inf, 1, 0.01)) 592 self.assertTrue(approx_equal(-inf, -inf)) 593 self.assertFalse(approx_equal(inf, -inf)) 594 self.assertFalse(approx_equal(inf, 1000)) 595 596 def test_nan(self): 597 for type_ in (float, Decimal): 598 nan = type_('nan') 599 for other in (nan, type_('inf'), 1000): 600 self.assertFalse(approx_equal(nan, other)) 601 602 def test_float_zeroes(self): 603 nzero = math.copysign(0.0, -1) 604 self.assertTrue(approx_equal(nzero, 0.0, tol=0.1, rel=0.1)) 605 606 def test_decimal_zeroes(self): 607 nzero = Decimal("-0.0") 608 self.assertTrue(approx_equal(nzero, Decimal(0), tol=0.1, rel=0.1)) 609 610 611class TestApproxEqualErrors(unittest.TestCase): 612 # Test error conditions of approx_equal. 613 614 def test_bad_tol(self): 615 # Test negative tol raises. 616 self.assertRaises(ValueError, approx_equal, 100, 100, -1, 0.1) 617 618 def test_bad_rel(self): 619 # Test negative rel raises. 620 self.assertRaises(ValueError, approx_equal, 100, 100, 1, -0.1) 621 622 623# --- Tests for NumericTestCase --- 624 625# The formatting routine that generates the error messages is complex enough 626# that it too needs testing. 627 628class TestNumericTestCase(unittest.TestCase): 629 # The exact wording of NumericTestCase error messages is *not* guaranteed, 630 # but we need to give them some sort of test to ensure that they are 631 # generated correctly. As a compromise, we look for specific substrings 632 # that are expected to be found even if the overall error message changes. 633 634 def do_test(self, args): 635 actual_msg = NumericTestCase._make_std_err_msg(*args) 636 expected = self.generate_substrings(*args) 637 for substring in expected: 638 self.assertIn(substring, actual_msg) 639 640 def test_numerictestcase_is_testcase(self): 641 # Ensure that NumericTestCase actually is a TestCase. 642 self.assertTrue(issubclass(NumericTestCase, unittest.TestCase)) 643 644 def test_error_msg_numeric(self): 645 # Test the error message generated for numeric comparisons. 646 args = (2.5, 4.0, 0.5, 0.25, None) 647 self.do_test(args) 648 649 def test_error_msg_sequence(self): 650 # Test the error message generated for sequence comparisons. 651 args = (3.75, 8.25, 1.25, 0.5, 7) 652 self.do_test(args) 653 654 def generate_substrings(self, first, second, tol, rel, idx): 655 """Return substrings we expect to see in error messages.""" 656 abs_err, rel_err = _calc_errors(first, second) 657 substrings = [ 658 'tol=%r' % tol, 659 'rel=%r' % rel, 660 'absolute error = %r' % abs_err, 661 'relative error = %r' % rel_err, 662 ] 663 if idx is not None: 664 substrings.append('differ at index %d' % idx) 665 return substrings 666 667 668# ======================================= 669# === Tests for the statistics module === 670# ======================================= 671 672 673class GlobalsTest(unittest.TestCase): 674 module = statistics 675 expected_metadata = ["__doc__", "__all__"] 676 677 def test_meta(self): 678 # Test for the existence of metadata. 679 for meta in self.expected_metadata: 680 self.assertTrue(hasattr(self.module, meta), 681 "%s not present" % meta) 682 683 def test_check_all(self): 684 # Check everything in __all__ exists and is public. 685 module = self.module 686 for name in module.__all__: 687 # No private names in __all__: 688 self.assertFalse(name.startswith("_"), 689 'private name "%s" in __all__' % name) 690 # And anything in __all__ must exist: 691 self.assertTrue(hasattr(module, name), 692 'missing name "%s" in __all__' % name) 693 694 695class DocTests(unittest.TestCase): 696 @unittest.skipIf(sys.flags.optimize >= 2, 697 "Docstrings are omitted with -OO and above") 698 def test_doc_tests(self): 699 failed, tried = doctest.testmod(statistics, optionflags=doctest.ELLIPSIS) 700 self.assertGreater(tried, 0) 701 self.assertEqual(failed, 0) 702 703class StatisticsErrorTest(unittest.TestCase): 704 def test_has_exception(self): 705 errmsg = ( 706 "Expected StatisticsError to be a ValueError, but got a" 707 " subclass of %r instead." 708 ) 709 self.assertTrue(hasattr(statistics, 'StatisticsError')) 710 self.assertTrue( 711 issubclass(statistics.StatisticsError, ValueError), 712 errmsg % statistics.StatisticsError.__base__ 713 ) 714 715 716# === Tests for private utility functions === 717 718class ExactRatioTest(unittest.TestCase): 719 # Test _exact_ratio utility. 720 721 def test_int(self): 722 for i in (-20, -3, 0, 5, 99, 10**20): 723 self.assertEqual(statistics._exact_ratio(i), (i, 1)) 724 725 def test_fraction(self): 726 numerators = (-5, 1, 12, 38) 727 for n in numerators: 728 f = Fraction(n, 37) 729 self.assertEqual(statistics._exact_ratio(f), (n, 37)) 730 731 def test_float(self): 732 self.assertEqual(statistics._exact_ratio(0.125), (1, 8)) 733 self.assertEqual(statistics._exact_ratio(1.125), (9, 8)) 734 data = [random.uniform(-100, 100) for _ in range(100)] 735 for x in data: 736 num, den = statistics._exact_ratio(x) 737 self.assertEqual(x, num/den) 738 739 def test_decimal(self): 740 D = Decimal 741 _exact_ratio = statistics._exact_ratio 742 self.assertEqual(_exact_ratio(D("0.125")), (1, 8)) 743 self.assertEqual(_exact_ratio(D("12.345")), (2469, 200)) 744 self.assertEqual(_exact_ratio(D("-1.98")), (-99, 50)) 745 746 def test_inf(self): 747 INF = float("INF") 748 class MyFloat(float): 749 pass 750 class MyDecimal(Decimal): 751 pass 752 for inf in (INF, -INF): 753 for type_ in (float, MyFloat, Decimal, MyDecimal): 754 x = type_(inf) 755 ratio = statistics._exact_ratio(x) 756 self.assertEqual(ratio, (x, None)) 757 self.assertEqual(type(ratio[0]), type_) 758 self.assertTrue(math.isinf(ratio[0])) 759 760 def test_float_nan(self): 761 NAN = float("NAN") 762 class MyFloat(float): 763 pass 764 for nan in (NAN, MyFloat(NAN)): 765 ratio = statistics._exact_ratio(nan) 766 self.assertTrue(math.isnan(ratio[0])) 767 self.assertIs(ratio[1], None) 768 self.assertEqual(type(ratio[0]), type(nan)) 769 770 def test_decimal_nan(self): 771 NAN = Decimal("NAN") 772 sNAN = Decimal("sNAN") 773 class MyDecimal(Decimal): 774 pass 775 for nan in (NAN, MyDecimal(NAN), sNAN, MyDecimal(sNAN)): 776 ratio = statistics._exact_ratio(nan) 777 self.assertTrue(_nan_equal(ratio[0], nan)) 778 self.assertIs(ratio[1], None) 779 self.assertEqual(type(ratio[0]), type(nan)) 780 781 782class DecimalToRatioTest(unittest.TestCase): 783 # Test _exact_ratio private function. 784 785 def test_infinity(self): 786 # Test that INFs are handled correctly. 787 inf = Decimal('INF') 788 self.assertEqual(statistics._exact_ratio(inf), (inf, None)) 789 self.assertEqual(statistics._exact_ratio(-inf), (-inf, None)) 790 791 def test_nan(self): 792 # Test that NANs are handled correctly. 793 for nan in (Decimal('NAN'), Decimal('sNAN')): 794 num, den = statistics._exact_ratio(nan) 795 # Because NANs always compare non-equal, we cannot use assertEqual. 796 # Nor can we use an identity test, as we don't guarantee anything 797 # about the object identity. 798 self.assertTrue(_nan_equal(num, nan)) 799 self.assertIs(den, None) 800 801 def test_sign(self): 802 # Test sign is calculated correctly. 803 numbers = [Decimal("9.8765e12"), Decimal("9.8765e-12")] 804 for d in numbers: 805 # First test positive decimals. 806 assert d > 0 807 num, den = statistics._exact_ratio(d) 808 self.assertGreaterEqual(num, 0) 809 self.assertGreater(den, 0) 810 # Then test negative decimals. 811 num, den = statistics._exact_ratio(-d) 812 self.assertLessEqual(num, 0) 813 self.assertGreater(den, 0) 814 815 def test_negative_exponent(self): 816 # Test result when the exponent is negative. 817 t = statistics._exact_ratio(Decimal("0.1234")) 818 self.assertEqual(t, (617, 5000)) 819 820 def test_positive_exponent(self): 821 # Test results when the exponent is positive. 822 t = statistics._exact_ratio(Decimal("1.234e7")) 823 self.assertEqual(t, (12340000, 1)) 824 825 def test_regression_20536(self): 826 # Regression test for issue 20536. 827 # See http://bugs.python.org/issue20536 828 t = statistics._exact_ratio(Decimal("1e2")) 829 self.assertEqual(t, (100, 1)) 830 t = statistics._exact_ratio(Decimal("1.47e5")) 831 self.assertEqual(t, (147000, 1)) 832 833 834class IsFiniteTest(unittest.TestCase): 835 # Test _isfinite private function. 836 837 def test_finite(self): 838 # Test that finite numbers are recognised as finite. 839 for x in (5, Fraction(1, 3), 2.5, Decimal("5.5")): 840 self.assertTrue(statistics._isfinite(x)) 841 842 def test_infinity(self): 843 # Test that INFs are not recognised as finite. 844 for x in (float("inf"), Decimal("inf")): 845 self.assertFalse(statistics._isfinite(x)) 846 847 def test_nan(self): 848 # Test that NANs are not recognised as finite. 849 for x in (float("nan"), Decimal("NAN"), Decimal("sNAN")): 850 self.assertFalse(statistics._isfinite(x)) 851 852 853class CoerceTest(unittest.TestCase): 854 # Test that private function _coerce correctly deals with types. 855 856 # The coercion rules are currently an implementation detail, although at 857 # some point that should change. The tests and comments here define the 858 # correct implementation. 859 860 # Pre-conditions of _coerce: 861 # 862 # - The first time _sum calls _coerce, the 863 # - coerce(T, S) will never be called with bool as the first argument; 864 # this is a pre-condition, guarded with an assertion. 865 866 # 867 # - coerce(T, T) will always return T; we assume T is a valid numeric 868 # type. Violate this assumption at your own risk. 869 # 870 # - Apart from as above, bool is treated as if it were actually int. 871 # 872 # - coerce(int, X) and coerce(X, int) return X. 873 # - 874 def test_bool(self): 875 # bool is somewhat special, due to the pre-condition that it is 876 # never given as the first argument to _coerce, and that it cannot 877 # be subclassed. So we test it specially. 878 for T in (int, float, Fraction, Decimal): 879 self.assertIs(statistics._coerce(T, bool), T) 880 class MyClass(T): pass 881 self.assertIs(statistics._coerce(MyClass, bool), MyClass) 882 883 def assertCoerceTo(self, A, B): 884 """Assert that type A coerces to B.""" 885 self.assertIs(statistics._coerce(A, B), B) 886 self.assertIs(statistics._coerce(B, A), B) 887 888 def check_coerce_to(self, A, B): 889 """Checks that type A coerces to B, including subclasses.""" 890 # Assert that type A is coerced to B. 891 self.assertCoerceTo(A, B) 892 # Subclasses of A are also coerced to B. 893 class SubclassOfA(A): pass 894 self.assertCoerceTo(SubclassOfA, B) 895 # A, and subclasses of A, are coerced to subclasses of B. 896 class SubclassOfB(B): pass 897 self.assertCoerceTo(A, SubclassOfB) 898 self.assertCoerceTo(SubclassOfA, SubclassOfB) 899 900 def assertCoerceRaises(self, A, B): 901 """Assert that coercing A to B, or vice versa, raises TypeError.""" 902 self.assertRaises(TypeError, statistics._coerce, (A, B)) 903 self.assertRaises(TypeError, statistics._coerce, (B, A)) 904 905 def check_type_coercions(self, T): 906 """Check that type T coerces correctly with subclasses of itself.""" 907 assert T is not bool 908 # Coercing a type with itself returns the same type. 909 self.assertIs(statistics._coerce(T, T), T) 910 # Coercing a type with a subclass of itself returns the subclass. 911 class U(T): pass 912 class V(T): pass 913 class W(U): pass 914 for typ in (U, V, W): 915 self.assertCoerceTo(T, typ) 916 self.assertCoerceTo(U, W) 917 # Coercing two subclasses that aren't parent/child is an error. 918 self.assertCoerceRaises(U, V) 919 self.assertCoerceRaises(V, W) 920 921 def test_int(self): 922 # Check that int coerces correctly. 923 self.check_type_coercions(int) 924 for typ in (float, Fraction, Decimal): 925 self.check_coerce_to(int, typ) 926 927 def test_fraction(self): 928 # Check that Fraction coerces correctly. 929 self.check_type_coercions(Fraction) 930 self.check_coerce_to(Fraction, float) 931 932 def test_decimal(self): 933 # Check that Decimal coerces correctly. 934 self.check_type_coercions(Decimal) 935 936 def test_float(self): 937 # Check that float coerces correctly. 938 self.check_type_coercions(float) 939 940 def test_non_numeric_types(self): 941 for bad_type in (str, list, type(None), tuple, dict): 942 for good_type in (int, float, Fraction, Decimal): 943 self.assertCoerceRaises(good_type, bad_type) 944 945 def test_incompatible_types(self): 946 # Test that incompatible types raise. 947 for T in (float, Fraction): 948 class MySubclass(T): pass 949 self.assertCoerceRaises(T, Decimal) 950 self.assertCoerceRaises(MySubclass, Decimal) 951 952 953class ConvertTest(unittest.TestCase): 954 # Test private _convert function. 955 956 def check_exact_equal(self, x, y): 957 """Check that x equals y, and has the same type as well.""" 958 self.assertEqual(x, y) 959 self.assertIs(type(x), type(y)) 960 961 def test_int(self): 962 # Test conversions to int. 963 x = statistics._convert(Fraction(71), int) 964 self.check_exact_equal(x, 71) 965 class MyInt(int): pass 966 x = statistics._convert(Fraction(17), MyInt) 967 self.check_exact_equal(x, MyInt(17)) 968 969 def test_fraction(self): 970 # Test conversions to Fraction. 971 x = statistics._convert(Fraction(95, 99), Fraction) 972 self.check_exact_equal(x, Fraction(95, 99)) 973 class MyFraction(Fraction): 974 def __truediv__(self, other): 975 return self.__class__(super().__truediv__(other)) 976 x = statistics._convert(Fraction(71, 13), MyFraction) 977 self.check_exact_equal(x, MyFraction(71, 13)) 978 979 def test_float(self): 980 # Test conversions to float. 981 x = statistics._convert(Fraction(-1, 2), float) 982 self.check_exact_equal(x, -0.5) 983 class MyFloat(float): 984 def __truediv__(self, other): 985 return self.__class__(super().__truediv__(other)) 986 x = statistics._convert(Fraction(9, 8), MyFloat) 987 self.check_exact_equal(x, MyFloat(1.125)) 988 989 def test_decimal(self): 990 # Test conversions to Decimal. 991 x = statistics._convert(Fraction(1, 40), Decimal) 992 self.check_exact_equal(x, Decimal("0.025")) 993 class MyDecimal(Decimal): 994 def __truediv__(self, other): 995 return self.__class__(super().__truediv__(other)) 996 x = statistics._convert(Fraction(-15, 16), MyDecimal) 997 self.check_exact_equal(x, MyDecimal("-0.9375")) 998 999 def test_inf(self): 1000 for INF in (float('inf'), Decimal('inf')): 1001 for inf in (INF, -INF): 1002 x = statistics._convert(inf, type(inf)) 1003 self.check_exact_equal(x, inf) 1004 1005 def test_nan(self): 1006 for nan in (float('nan'), Decimal('NAN'), Decimal('sNAN')): 1007 x = statistics._convert(nan, type(nan)) 1008 self.assertTrue(_nan_equal(x, nan)) 1009 1010 def test_invalid_input_type(self): 1011 with self.assertRaises(TypeError): 1012 statistics._convert(None, float) 1013 1014 1015class FailNegTest(unittest.TestCase): 1016 """Test _fail_neg private function.""" 1017 1018 def test_pass_through(self): 1019 # Test that values are passed through unchanged. 1020 values = [1, 2.0, Fraction(3), Decimal(4)] 1021 new = list(statistics._fail_neg(values)) 1022 self.assertEqual(values, new) 1023 1024 def test_negatives_raise(self): 1025 # Test that negatives raise an exception. 1026 for x in [1, 2.0, Fraction(3), Decimal(4)]: 1027 seq = [-x] 1028 it = statistics._fail_neg(seq) 1029 self.assertRaises(statistics.StatisticsError, next, it) 1030 1031 def test_error_msg(self): 1032 # Test that a given error message is used. 1033 msg = "badness #%d" % random.randint(10000, 99999) 1034 try: 1035 next(statistics._fail_neg([-1], msg)) 1036 except statistics.StatisticsError as e: 1037 errmsg = e.args[0] 1038 else: 1039 self.fail("expected exception, but it didn't happen") 1040 self.assertEqual(errmsg, msg) 1041 1042 1043class FindLteqTest(unittest.TestCase): 1044 # Test _find_lteq private function. 1045 1046 def test_invalid_input_values(self): 1047 for a, x in [ 1048 ([], 1), 1049 ([1, 2], 3), 1050 ([1, 3], 2) 1051 ]: 1052 with self.subTest(a=a, x=x): 1053 with self.assertRaises(ValueError): 1054 statistics._find_lteq(a, x) 1055 1056 def test_locate_successfully(self): 1057 for a, x, expected_i in [ 1058 ([1, 1, 1, 2, 3], 1, 0), 1059 ([0, 1, 1, 1, 2, 3], 1, 1), 1060 ([1, 2, 3, 3, 3], 3, 2) 1061 ]: 1062 with self.subTest(a=a, x=x): 1063 self.assertEqual(expected_i, statistics._find_lteq(a, x)) 1064 1065 1066class FindRteqTest(unittest.TestCase): 1067 # Test _find_rteq private function. 1068 1069 def test_invalid_input_values(self): 1070 for a, l, x in [ 1071 ([1], 2, 1), 1072 ([1, 3], 0, 2) 1073 ]: 1074 with self.assertRaises(ValueError): 1075 statistics._find_rteq(a, l, x) 1076 1077 def test_locate_successfully(self): 1078 for a, l, x, expected_i in [ 1079 ([1, 1, 1, 2, 3], 0, 1, 2), 1080 ([0, 1, 1, 1, 2, 3], 0, 1, 3), 1081 ([1, 2, 3, 3, 3], 0, 3, 4) 1082 ]: 1083 with self.subTest(a=a, l=l, x=x): 1084 self.assertEqual(expected_i, statistics._find_rteq(a, l, x)) 1085 1086 1087# === Tests for public functions === 1088 1089class UnivariateCommonMixin: 1090 # Common tests for most univariate functions that take a data argument. 1091 1092 def test_no_args(self): 1093 # Fail if given no arguments. 1094 self.assertRaises(TypeError, self.func) 1095 1096 def test_empty_data(self): 1097 # Fail when the data argument (first argument) is empty. 1098 for empty in ([], (), iter([])): 1099 self.assertRaises(statistics.StatisticsError, self.func, empty) 1100 1101 def prepare_data(self): 1102 """Return int data for various tests.""" 1103 data = list(range(10)) 1104 while data == sorted(data): 1105 random.shuffle(data) 1106 return data 1107 1108 def test_no_inplace_modifications(self): 1109 # Test that the function does not modify its input data. 1110 data = self.prepare_data() 1111 assert len(data) != 1 # Necessary to avoid infinite loop. 1112 assert data != sorted(data) 1113 saved = data[:] 1114 assert data is not saved 1115 _ = self.func(data) 1116 self.assertListEqual(data, saved, "data has been modified") 1117 1118 def test_order_doesnt_matter(self): 1119 # Test that the order of data points doesn't change the result. 1120 1121 # CAUTION: due to floating point rounding errors, the result actually 1122 # may depend on the order. Consider this test representing an ideal. 1123 # To avoid this test failing, only test with exact values such as ints 1124 # or Fractions. 1125 data = [1, 2, 3, 3, 3, 4, 5, 6]*100 1126 expected = self.func(data) 1127 random.shuffle(data) 1128 actual = self.func(data) 1129 self.assertEqual(expected, actual) 1130 1131 def test_type_of_data_collection(self): 1132 # Test that the type of iterable data doesn't effect the result. 1133 class MyList(list): 1134 pass 1135 class MyTuple(tuple): 1136 pass 1137 def generator(data): 1138 return (obj for obj in data) 1139 data = self.prepare_data() 1140 expected = self.func(data) 1141 for kind in (list, tuple, iter, MyList, MyTuple, generator): 1142 result = self.func(kind(data)) 1143 self.assertEqual(result, expected) 1144 1145 def test_range_data(self): 1146 # Test that functions work with range objects. 1147 data = range(20, 50, 3) 1148 expected = self.func(list(data)) 1149 self.assertEqual(self.func(data), expected) 1150 1151 def test_bad_arg_types(self): 1152 # Test that function raises when given data of the wrong type. 1153 1154 # Don't roll the following into a loop like this: 1155 # for bad in list_of_bad: 1156 # self.check_for_type_error(bad) 1157 # 1158 # Since assertRaises doesn't show the arguments that caused the test 1159 # failure, it is very difficult to debug these test failures when the 1160 # following are in a loop. 1161 self.check_for_type_error(None) 1162 self.check_for_type_error(23) 1163 self.check_for_type_error(42.0) 1164 self.check_for_type_error(object()) 1165 1166 def check_for_type_error(self, *args): 1167 self.assertRaises(TypeError, self.func, *args) 1168 1169 def test_type_of_data_element(self): 1170 # Check the type of data elements doesn't affect the numeric result. 1171 # This is a weaker test than UnivariateTypeMixin.testTypesConserved, 1172 # because it checks the numeric result by equality, but not by type. 1173 class MyFloat(float): 1174 def __truediv__(self, other): 1175 return type(self)(super().__truediv__(other)) 1176 def __add__(self, other): 1177 return type(self)(super().__add__(other)) 1178 __radd__ = __add__ 1179 1180 raw = self.prepare_data() 1181 expected = self.func(raw) 1182 for kind in (float, MyFloat, Decimal, Fraction): 1183 data = [kind(x) for x in raw] 1184 result = type(expected)(self.func(data)) 1185 self.assertEqual(result, expected) 1186 1187 1188class UnivariateTypeMixin: 1189 """Mixin class for type-conserving functions. 1190 1191 This mixin class holds test(s) for functions which conserve the type of 1192 individual data points. E.g. the mean of a list of Fractions should itself 1193 be a Fraction. 1194 1195 Not all tests to do with types need go in this class. Only those that 1196 rely on the function returning the same type as its input data. 1197 """ 1198 def prepare_types_for_conservation_test(self): 1199 """Return the types which are expected to be conserved.""" 1200 class MyFloat(float): 1201 def __truediv__(self, other): 1202 return type(self)(super().__truediv__(other)) 1203 def __rtruediv__(self, other): 1204 return type(self)(super().__rtruediv__(other)) 1205 def __sub__(self, other): 1206 return type(self)(super().__sub__(other)) 1207 def __rsub__(self, other): 1208 return type(self)(super().__rsub__(other)) 1209 def __pow__(self, other): 1210 return type(self)(super().__pow__(other)) 1211 def __add__(self, other): 1212 return type(self)(super().__add__(other)) 1213 __radd__ = __add__ 1214 def __mul__(self, other): 1215 return type(self)(super().__mul__(other)) 1216 __rmul__ = __mul__ 1217 return (float, Decimal, Fraction, MyFloat) 1218 1219 def test_types_conserved(self): 1220 # Test that functions keeps the same type as their data points. 1221 # (Excludes mixed data types.) This only tests the type of the return 1222 # result, not the value. 1223 data = self.prepare_data() 1224 for kind in self.prepare_types_for_conservation_test(): 1225 d = [kind(x) for x in data] 1226 result = self.func(d) 1227 self.assertIs(type(result), kind) 1228 1229 1230class TestSumCommon(UnivariateCommonMixin, UnivariateTypeMixin): 1231 # Common test cases for statistics._sum() function. 1232 1233 # This test suite looks only at the numeric value returned by _sum, 1234 # after conversion to the appropriate type. 1235 def setUp(self): 1236 def simplified_sum(*args): 1237 T, value, n = statistics._sum(*args) 1238 return statistics._coerce(value, T) 1239 self.func = simplified_sum 1240 1241 1242class TestSum(NumericTestCase): 1243 # Test cases for statistics._sum() function. 1244 1245 # These tests look at the entire three value tuple returned by _sum. 1246 1247 def setUp(self): 1248 self.func = statistics._sum 1249 1250 def test_empty_data(self): 1251 # Override test for empty data. 1252 for data in ([], (), iter([])): 1253 self.assertEqual(self.func(data), (int, Fraction(0), 0)) 1254 1255 def test_ints(self): 1256 self.assertEqual(self.func([1, 5, 3, -4, -8, 20, 42, 1]), 1257 (int, Fraction(60), 8)) 1258 1259 def test_floats(self): 1260 self.assertEqual(self.func([0.25]*20), 1261 (float, Fraction(5.0), 20)) 1262 1263 def test_fractions(self): 1264 self.assertEqual(self.func([Fraction(1, 1000)]*500), 1265 (Fraction, Fraction(1, 2), 500)) 1266 1267 def test_decimals(self): 1268 D = Decimal 1269 data = [D("0.001"), D("5.246"), D("1.702"), D("-0.025"), 1270 D("3.974"), D("2.328"), D("4.617"), D("2.843"), 1271 ] 1272 self.assertEqual(self.func(data), 1273 (Decimal, Decimal("20.686"), 8)) 1274 1275 def test_compare_with_math_fsum(self): 1276 # Compare with the math.fsum function. 1277 # Ideally we ought to get the exact same result, but sometimes 1278 # we differ by a very slight amount :-( 1279 data = [random.uniform(-100, 1000) for _ in range(1000)] 1280 self.assertApproxEqual(float(self.func(data)[1]), math.fsum(data), rel=2e-16) 1281 1282 def test_strings_fail(self): 1283 # Sum of strings should fail. 1284 self.assertRaises(TypeError, self.func, [1, 2, 3], '999') 1285 self.assertRaises(TypeError, self.func, [1, 2, 3, '999']) 1286 1287 def test_bytes_fail(self): 1288 # Sum of bytes should fail. 1289 self.assertRaises(TypeError, self.func, [1, 2, 3], b'999') 1290 self.assertRaises(TypeError, self.func, [1, 2, 3, b'999']) 1291 1292 def test_mixed_sum(self): 1293 # Mixed input types are not (currently) allowed. 1294 # Check that mixed data types fail. 1295 self.assertRaises(TypeError, self.func, [1, 2.0, Decimal(1)]) 1296 # And so does mixed start argument. 1297 self.assertRaises(TypeError, self.func, [1, 2.0], Decimal(1)) 1298 1299 1300class SumTortureTest(NumericTestCase): 1301 def test_torture(self): 1302 # Tim Peters' torture test for sum, and variants of same. 1303 self.assertEqual(statistics._sum([1, 1e100, 1, -1e100]*10000), 1304 (float, Fraction(20000.0), 40000)) 1305 self.assertEqual(statistics._sum([1e100, 1, 1, -1e100]*10000), 1306 (float, Fraction(20000.0), 40000)) 1307 T, num, count = statistics._sum([1e-100, 1, 1e-100, -1]*10000) 1308 self.assertIs(T, float) 1309 self.assertEqual(count, 40000) 1310 self.assertApproxEqual(float(num), 2.0e-96, rel=5e-16) 1311 1312 1313class SumSpecialValues(NumericTestCase): 1314 # Test that sum works correctly with IEEE-754 special values. 1315 1316 def test_nan(self): 1317 for type_ in (float, Decimal): 1318 nan = type_('nan') 1319 result = statistics._sum([1, nan, 2])[1] 1320 self.assertIs(type(result), type_) 1321 self.assertTrue(math.isnan(result)) 1322 1323 def check_infinity(self, x, inf): 1324 """Check x is an infinity of the same type and sign as inf.""" 1325 self.assertTrue(math.isinf(x)) 1326 self.assertIs(type(x), type(inf)) 1327 self.assertEqual(x > 0, inf > 0) 1328 assert x == inf 1329 1330 def do_test_inf(self, inf): 1331 # Adding a single infinity gives infinity. 1332 result = statistics._sum([1, 2, inf, 3])[1] 1333 self.check_infinity(result, inf) 1334 # Adding two infinities of the same sign also gives infinity. 1335 result = statistics._sum([1, 2, inf, 3, inf, 4])[1] 1336 self.check_infinity(result, inf) 1337 1338 def test_float_inf(self): 1339 inf = float('inf') 1340 for sign in (+1, -1): 1341 self.do_test_inf(sign*inf) 1342 1343 def test_decimal_inf(self): 1344 inf = Decimal('inf') 1345 for sign in (+1, -1): 1346 self.do_test_inf(sign*inf) 1347 1348 def test_float_mismatched_infs(self): 1349 # Test that adding two infinities of opposite sign gives a NAN. 1350 inf = float('inf') 1351 result = statistics._sum([1, 2, inf, 3, -inf, 4])[1] 1352 self.assertTrue(math.isnan(result)) 1353 1354 def test_decimal_extendedcontext_mismatched_infs_to_nan(self): 1355 # Test adding Decimal INFs with opposite sign returns NAN. 1356 inf = Decimal('inf') 1357 data = [1, 2, inf, 3, -inf, 4] 1358 with decimal.localcontext(decimal.ExtendedContext): 1359 self.assertTrue(math.isnan(statistics._sum(data)[1])) 1360 1361 def test_decimal_basiccontext_mismatched_infs_to_nan(self): 1362 # Test adding Decimal INFs with opposite sign raises InvalidOperation. 1363 inf = Decimal('inf') 1364 data = [1, 2, inf, 3, -inf, 4] 1365 with decimal.localcontext(decimal.BasicContext): 1366 self.assertRaises(decimal.InvalidOperation, statistics._sum, data) 1367 1368 def test_decimal_snan_raises(self): 1369 # Adding sNAN should raise InvalidOperation. 1370 sNAN = Decimal('sNAN') 1371 data = [1, sNAN, 2] 1372 self.assertRaises(decimal.InvalidOperation, statistics._sum, data) 1373 1374 1375# === Tests for averages === 1376 1377class AverageMixin(UnivariateCommonMixin): 1378 # Mixin class holding common tests for averages. 1379 1380 def test_single_value(self): 1381 # Average of a single value is the value itself. 1382 for x in (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28')): 1383 self.assertEqual(self.func([x]), x) 1384 1385 def prepare_values_for_repeated_single_test(self): 1386 return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712')) 1387 1388 def test_repeated_single_value(self): 1389 # The average of a single repeated value is the value itself. 1390 for x in self.prepare_values_for_repeated_single_test(): 1391 for count in (2, 5, 10, 20): 1392 with self.subTest(x=x, count=count): 1393 data = [x]*count 1394 self.assertEqual(self.func(data), x) 1395 1396 1397class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin): 1398 def setUp(self): 1399 self.func = statistics.mean 1400 1401 def test_torture_pep(self): 1402 # "Torture Test" from PEP-450. 1403 self.assertEqual(self.func([1e100, 1, 3, -1e100]), 1) 1404 1405 def test_ints(self): 1406 # Test mean with ints. 1407 data = [0, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 7, 7, 7, 8, 9] 1408 random.shuffle(data) 1409 self.assertEqual(self.func(data), 4.8125) 1410 1411 def test_floats(self): 1412 # Test mean with floats. 1413 data = [17.25, 19.75, 20.0, 21.5, 21.75, 23.25, 25.125, 27.5] 1414 random.shuffle(data) 1415 self.assertEqual(self.func(data), 22.015625) 1416 1417 def test_decimals(self): 1418 # Test mean with Decimals. 1419 D = Decimal 1420 data = [D("1.634"), D("2.517"), D("3.912"), D("4.072"), D("5.813")] 1421 random.shuffle(data) 1422 self.assertEqual(self.func(data), D("3.5896")) 1423 1424 def test_fractions(self): 1425 # Test mean with Fractions. 1426 F = Fraction 1427 data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)] 1428 random.shuffle(data) 1429 self.assertEqual(self.func(data), F(1479, 1960)) 1430 1431 def test_inf(self): 1432 # Test mean with infinities. 1433 raw = [1, 3, 5, 7, 9] # Use only ints, to avoid TypeError later. 1434 for kind in (float, Decimal): 1435 for sign in (1, -1): 1436 inf = kind("inf")*sign 1437 data = raw + [inf] 1438 result = self.func(data) 1439 self.assertTrue(math.isinf(result)) 1440 self.assertEqual(result, inf) 1441 1442 def test_mismatched_infs(self): 1443 # Test mean with infinities of opposite sign. 1444 data = [2, 4, 6, float('inf'), 1, 3, 5, float('-inf')] 1445 result = self.func(data) 1446 self.assertTrue(math.isnan(result)) 1447 1448 def test_nan(self): 1449 # Test mean with NANs. 1450 raw = [1, 3, 5, 7, 9] # Use only ints, to avoid TypeError later. 1451 for kind in (float, Decimal): 1452 inf = kind("nan") 1453 data = raw + [inf] 1454 result = self.func(data) 1455 self.assertTrue(math.isnan(result)) 1456 1457 def test_big_data(self): 1458 # Test adding a large constant to every data point. 1459 c = 1e9 1460 data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4] 1461 expected = self.func(data) + c 1462 assert expected != c 1463 result = self.func([x+c for x in data]) 1464 self.assertEqual(result, expected) 1465 1466 def test_doubled_data(self): 1467 # Mean of [a,b,c...z] should be same as for [a,a,b,b,c,c...z,z]. 1468 data = [random.uniform(-3, 5) for _ in range(1000)] 1469 expected = self.func(data) 1470 actual = self.func(data*2) 1471 self.assertApproxEqual(actual, expected) 1472 1473 def test_regression_20561(self): 1474 # Regression test for issue 20561. 1475 # See http://bugs.python.org/issue20561 1476 d = Decimal('1e4') 1477 self.assertEqual(statistics.mean([d]), d) 1478 1479 def test_regression_25177(self): 1480 # Regression test for issue 25177. 1481 # Ensure very big and very small floats don't overflow. 1482 # See http://bugs.python.org/issue25177. 1483 self.assertEqual(statistics.mean( 1484 [8.988465674311579e+307, 8.98846567431158e+307]), 1485 8.98846567431158e+307) 1486 big = 8.98846567431158e+307 1487 tiny = 5e-324 1488 for n in (2, 3, 5, 200): 1489 self.assertEqual(statistics.mean([big]*n), big) 1490 self.assertEqual(statistics.mean([tiny]*n), tiny) 1491 1492 1493class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin): 1494 def setUp(self): 1495 self.func = statistics.harmonic_mean 1496 1497 def prepare_data(self): 1498 # Override mixin method. 1499 values = super().prepare_data() 1500 values.remove(0) 1501 return values 1502 1503 def prepare_values_for_repeated_single_test(self): 1504 # Override mixin method. 1505 return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.125')) 1506 1507 def test_zero(self): 1508 # Test that harmonic mean returns zero when given zero. 1509 values = [1, 0, 2] 1510 self.assertEqual(self.func(values), 0) 1511 1512 def test_negative_error(self): 1513 # Test that harmonic mean raises when given a negative value. 1514 exc = statistics.StatisticsError 1515 for values in ([-1], [1, -2, 3]): 1516 with self.subTest(values=values): 1517 self.assertRaises(exc, self.func, values) 1518 1519 def test_invalid_type_error(self): 1520 # Test error is raised when input contains invalid type(s) 1521 for data in [ 1522 ['3.14'], # single string 1523 ['1', '2', '3'], # multiple strings 1524 [1, '2', 3, '4', 5], # mixed strings and valid integers 1525 [2.3, 3.4, 4.5, '5.6'] # only one string and valid floats 1526 ]: 1527 with self.subTest(data=data): 1528 with self.assertRaises(TypeError): 1529 self.func(data) 1530 1531 def test_ints(self): 1532 # Test harmonic mean with ints. 1533 data = [2, 4, 4, 8, 16, 16] 1534 random.shuffle(data) 1535 self.assertEqual(self.func(data), 6*4/5) 1536 1537 def test_floats_exact(self): 1538 # Test harmonic mean with some carefully chosen floats. 1539 data = [1/8, 1/4, 1/4, 1/2, 1/2] 1540 random.shuffle(data) 1541 self.assertEqual(self.func(data), 1/4) 1542 self.assertEqual(self.func([0.25, 0.5, 1.0, 1.0]), 0.5) 1543 1544 def test_singleton_lists(self): 1545 # Test that harmonic mean([x]) returns (approximately) x. 1546 for x in range(1, 101): 1547 self.assertEqual(self.func([x]), x) 1548 1549 def test_decimals_exact(self): 1550 # Test harmonic mean with some carefully chosen Decimals. 1551 D = Decimal 1552 self.assertEqual(self.func([D(15), D(30), D(60), D(60)]), D(30)) 1553 data = [D("0.05"), D("0.10"), D("0.20"), D("0.20")] 1554 random.shuffle(data) 1555 self.assertEqual(self.func(data), D("0.10")) 1556 data = [D("1.68"), D("0.32"), D("5.94"), D("2.75")] 1557 random.shuffle(data) 1558 self.assertEqual(self.func(data), D(66528)/70723) 1559 1560 def test_fractions(self): 1561 # Test harmonic mean with Fractions. 1562 F = Fraction 1563 data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)] 1564 random.shuffle(data) 1565 self.assertEqual(self.func(data), F(7*420, 4029)) 1566 1567 def test_inf(self): 1568 # Test harmonic mean with infinity. 1569 values = [2.0, float('inf'), 1.0] 1570 self.assertEqual(self.func(values), 2.0) 1571 1572 def test_nan(self): 1573 # Test harmonic mean with NANs. 1574 values = [2.0, float('nan'), 1.0] 1575 self.assertTrue(math.isnan(self.func(values))) 1576 1577 def test_multiply_data_points(self): 1578 # Test multiplying every data point by a constant. 1579 c = 111 1580 data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4] 1581 expected = self.func(data)*c 1582 result = self.func([x*c for x in data]) 1583 self.assertEqual(result, expected) 1584 1585 def test_doubled_data(self): 1586 # Harmonic mean of [a,b...z] should be same as for [a,a,b,b...z,z]. 1587 data = [random.uniform(1, 5) for _ in range(1000)] 1588 expected = self.func(data) 1589 actual = self.func(data*2) 1590 self.assertApproxEqual(actual, expected) 1591 1592 def test_with_weights(self): 1593 self.assertEqual(self.func([40, 60], [5, 30]), 56.0) # common case 1594 self.assertEqual(self.func([40, 60], 1595 weights=[5, 30]), 56.0) # keyword argument 1596 self.assertEqual(self.func(iter([40, 60]), 1597 iter([5, 30])), 56.0) # iterator inputs 1598 self.assertEqual( 1599 self.func([Fraction(10, 3), Fraction(23, 5), Fraction(7, 2)], [5, 2, 10]), 1600 self.func([Fraction(10, 3)] * 5 + 1601 [Fraction(23, 5)] * 2 + 1602 [Fraction(7, 2)] * 10)) 1603 self.assertEqual(self.func([10], [7]), 10) # n=1 fast path 1604 with self.assertRaises(TypeError): 1605 self.func([1, 2, 3], [1, (), 3]) # non-numeric weight 1606 with self.assertRaises(statistics.StatisticsError): 1607 self.func([1, 2, 3], [1, 2]) # wrong number of weights 1608 with self.assertRaises(statistics.StatisticsError): 1609 self.func([10], [0]) # no non-zero weights 1610 with self.assertRaises(statistics.StatisticsError): 1611 self.func([10, 20], [0, 0]) # no non-zero weights 1612 1613 1614class TestMedian(NumericTestCase, AverageMixin): 1615 # Common tests for median and all median.* functions. 1616 def setUp(self): 1617 self.func = statistics.median 1618 1619 def prepare_data(self): 1620 """Overload method from UnivariateCommonMixin.""" 1621 data = super().prepare_data() 1622 if len(data)%2 != 1: 1623 data.append(2) 1624 return data 1625 1626 def test_even_ints(self): 1627 # Test median with an even number of int data points. 1628 data = [1, 2, 3, 4, 5, 6] 1629 assert len(data)%2 == 0 1630 self.assertEqual(self.func(data), 3.5) 1631 1632 def test_odd_ints(self): 1633 # Test median with an odd number of int data points. 1634 data = [1, 2, 3, 4, 5, 6, 9] 1635 assert len(data)%2 == 1 1636 self.assertEqual(self.func(data), 4) 1637 1638 def test_odd_fractions(self): 1639 # Test median works with an odd number of Fractions. 1640 F = Fraction 1641 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7)] 1642 assert len(data)%2 == 1 1643 random.shuffle(data) 1644 self.assertEqual(self.func(data), F(3, 7)) 1645 1646 def test_even_fractions(self): 1647 # Test median works with an even number of Fractions. 1648 F = Fraction 1649 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] 1650 assert len(data)%2 == 0 1651 random.shuffle(data) 1652 self.assertEqual(self.func(data), F(1, 2)) 1653 1654 def test_odd_decimals(self): 1655 # Test median works with an odd number of Decimals. 1656 D = Decimal 1657 data = [D('2.5'), D('3.1'), D('4.2'), D('5.7'), D('5.8')] 1658 assert len(data)%2 == 1 1659 random.shuffle(data) 1660 self.assertEqual(self.func(data), D('4.2')) 1661 1662 def test_even_decimals(self): 1663 # Test median works with an even number of Decimals. 1664 D = Decimal 1665 data = [D('1.2'), D('2.5'), D('3.1'), D('4.2'), D('5.7'), D('5.8')] 1666 assert len(data)%2 == 0 1667 random.shuffle(data) 1668 self.assertEqual(self.func(data), D('3.65')) 1669 1670 1671class TestMedianDataType(NumericTestCase, UnivariateTypeMixin): 1672 # Test conservation of data element type for median. 1673 def setUp(self): 1674 self.func = statistics.median 1675 1676 def prepare_data(self): 1677 data = list(range(15)) 1678 assert len(data)%2 == 1 1679 while data == sorted(data): 1680 random.shuffle(data) 1681 return data 1682 1683 1684class TestMedianLow(TestMedian, UnivariateTypeMixin): 1685 def setUp(self): 1686 self.func = statistics.median_low 1687 1688 def test_even_ints(self): 1689 # Test median_low with an even number of ints. 1690 data = [1, 2, 3, 4, 5, 6] 1691 assert len(data)%2 == 0 1692 self.assertEqual(self.func(data), 3) 1693 1694 def test_even_fractions(self): 1695 # Test median_low works with an even number of Fractions. 1696 F = Fraction 1697 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] 1698 assert len(data)%2 == 0 1699 random.shuffle(data) 1700 self.assertEqual(self.func(data), F(3, 7)) 1701 1702 def test_even_decimals(self): 1703 # Test median_low works with an even number of Decimals. 1704 D = Decimal 1705 data = [D('1.1'), D('2.2'), D('3.3'), D('4.4'), D('5.5'), D('6.6')] 1706 assert len(data)%2 == 0 1707 random.shuffle(data) 1708 self.assertEqual(self.func(data), D('3.3')) 1709 1710 1711class TestMedianHigh(TestMedian, UnivariateTypeMixin): 1712 def setUp(self): 1713 self.func = statistics.median_high 1714 1715 def test_even_ints(self): 1716 # Test median_high with an even number of ints. 1717 data = [1, 2, 3, 4, 5, 6] 1718 assert len(data)%2 == 0 1719 self.assertEqual(self.func(data), 4) 1720 1721 def test_even_fractions(self): 1722 # Test median_high works with an even number of Fractions. 1723 F = Fraction 1724 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] 1725 assert len(data)%2 == 0 1726 random.shuffle(data) 1727 self.assertEqual(self.func(data), F(4, 7)) 1728 1729 def test_even_decimals(self): 1730 # Test median_high works with an even number of Decimals. 1731 D = Decimal 1732 data = [D('1.1'), D('2.2'), D('3.3'), D('4.4'), D('5.5'), D('6.6')] 1733 assert len(data)%2 == 0 1734 random.shuffle(data) 1735 self.assertEqual(self.func(data), D('4.4')) 1736 1737 1738class TestMedianGrouped(TestMedian): 1739 # Test median_grouped. 1740 # Doesn't conserve data element types, so don't use TestMedianType. 1741 def setUp(self): 1742 self.func = statistics.median_grouped 1743 1744 def test_odd_number_repeated(self): 1745 # Test median.grouped with repeated median values. 1746 data = [12, 13, 14, 14, 14, 15, 15] 1747 assert len(data)%2 == 1 1748 self.assertEqual(self.func(data), 14) 1749 #--- 1750 data = [12, 13, 14, 14, 14, 14, 15] 1751 assert len(data)%2 == 1 1752 self.assertEqual(self.func(data), 13.875) 1753 #--- 1754 data = [5, 10, 10, 15, 20, 20, 20, 20, 25, 25, 30] 1755 assert len(data)%2 == 1 1756 self.assertEqual(self.func(data, 5), 19.375) 1757 #--- 1758 data = [16, 18, 18, 18, 18, 20, 20, 20, 22, 22, 22, 24, 24, 26, 28] 1759 assert len(data)%2 == 1 1760 self.assertApproxEqual(self.func(data, 2), 20.66666667, tol=1e-8) 1761 1762 def test_even_number_repeated(self): 1763 # Test median.grouped with repeated median values. 1764 data = [5, 10, 10, 15, 20, 20, 20, 25, 25, 30] 1765 assert len(data)%2 == 0 1766 self.assertApproxEqual(self.func(data, 5), 19.16666667, tol=1e-8) 1767 #--- 1768 data = [2, 3, 4, 4, 4, 5] 1769 assert len(data)%2 == 0 1770 self.assertApproxEqual(self.func(data), 3.83333333, tol=1e-8) 1771 #--- 1772 data = [2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6] 1773 assert len(data)%2 == 0 1774 self.assertEqual(self.func(data), 4.5) 1775 #--- 1776 data = [3, 4, 4, 4, 5, 5, 5, 5, 6, 6] 1777 assert len(data)%2 == 0 1778 self.assertEqual(self.func(data), 4.75) 1779 1780 def test_repeated_single_value(self): 1781 # Override method from AverageMixin. 1782 # Yet again, failure of median_grouped to conserve the data type 1783 # causes me headaches :-( 1784 for x in (5.3, 68, 4.3e17, Fraction(29, 101), Decimal('32.9714')): 1785 for count in (2, 5, 10, 20): 1786 data = [x]*count 1787 self.assertEqual(self.func(data), float(x)) 1788 1789 def test_odd_fractions(self): 1790 # Test median_grouped works with an odd number of Fractions. 1791 F = Fraction 1792 data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4)] 1793 assert len(data)%2 == 1 1794 random.shuffle(data) 1795 self.assertEqual(self.func(data), 3.0) 1796 1797 def test_even_fractions(self): 1798 # Test median_grouped works with an even number of Fractions. 1799 F = Fraction 1800 data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4), F(17, 4)] 1801 assert len(data)%2 == 0 1802 random.shuffle(data) 1803 self.assertEqual(self.func(data), 3.25) 1804 1805 def test_odd_decimals(self): 1806 # Test median_grouped works with an odd number of Decimals. 1807 D = Decimal 1808 data = [D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')] 1809 assert len(data)%2 == 1 1810 random.shuffle(data) 1811 self.assertEqual(self.func(data), 6.75) 1812 1813 def test_even_decimals(self): 1814 # Test median_grouped works with an even number of Decimals. 1815 D = Decimal 1816 data = [D('5.5'), D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')] 1817 assert len(data)%2 == 0 1818 random.shuffle(data) 1819 self.assertEqual(self.func(data), 6.5) 1820 #--- 1821 data = [D('5.5'), D('5.5'), D('6.5'), D('7.5'), D('7.5'), D('8.5')] 1822 assert len(data)%2 == 0 1823 random.shuffle(data) 1824 self.assertEqual(self.func(data), 7.0) 1825 1826 def test_interval(self): 1827 # Test median_grouped with interval argument. 1828 data = [2.25, 2.5, 2.5, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75] 1829 self.assertEqual(self.func(data, 0.25), 2.875) 1830 data = [2.25, 2.5, 2.5, 2.75, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75] 1831 self.assertApproxEqual(self.func(data, 0.25), 2.83333333, tol=1e-8) 1832 data = [220, 220, 240, 260, 260, 260, 260, 280, 280, 300, 320, 340] 1833 self.assertEqual(self.func(data, 20), 265.0) 1834 1835 def test_data_type_error(self): 1836 # Test median_grouped with str, bytes data types for data and interval 1837 data = ["", "", ""] 1838 self.assertRaises(TypeError, self.func, data) 1839 #--- 1840 data = [b"", b"", b""] 1841 self.assertRaises(TypeError, self.func, data) 1842 #--- 1843 data = [1, 2, 3] 1844 interval = "" 1845 self.assertRaises(TypeError, self.func, data, interval) 1846 #--- 1847 data = [1, 2, 3] 1848 interval = b"" 1849 self.assertRaises(TypeError, self.func, data, interval) 1850 1851 1852class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin): 1853 # Test cases for the discrete version of mode. 1854 def setUp(self): 1855 self.func = statistics.mode 1856 1857 def prepare_data(self): 1858 """Overload method from UnivariateCommonMixin.""" 1859 # Make sure test data has exactly one mode. 1860 return [1, 1, 1, 1, 3, 4, 7, 9, 0, 8, 2] 1861 1862 def test_range_data(self): 1863 # Override test from UnivariateCommonMixin. 1864 data = range(20, 50, 3) 1865 self.assertEqual(self.func(data), 20) 1866 1867 def test_nominal_data(self): 1868 # Test mode with nominal data. 1869 data = 'abcbdb' 1870 self.assertEqual(self.func(data), 'b') 1871 data = 'fe fi fo fum fi fi'.split() 1872 self.assertEqual(self.func(data), 'fi') 1873 1874 def test_discrete_data(self): 1875 # Test mode with discrete numeric data. 1876 data = list(range(10)) 1877 for i in range(10): 1878 d = data + [i] 1879 random.shuffle(d) 1880 self.assertEqual(self.func(d), i) 1881 1882 def test_bimodal_data(self): 1883 # Test mode with bimodal data. 1884 data = [1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, 9, 9] 1885 assert data.count(2) == data.count(6) == 4 1886 # mode() should return 2, the first encountered mode 1887 self.assertEqual(self.func(data), 2) 1888 1889 def test_unique_data(self): 1890 # Test mode when data points are all unique. 1891 data = list(range(10)) 1892 # mode() should return 0, the first encountered mode 1893 self.assertEqual(self.func(data), 0) 1894 1895 def test_none_data(self): 1896 # Test that mode raises TypeError if given None as data. 1897 1898 # This test is necessary because the implementation of mode uses 1899 # collections.Counter, which accepts None and returns an empty dict. 1900 self.assertRaises(TypeError, self.func, None) 1901 1902 def test_counter_data(self): 1903 # Test that a Counter is treated like any other iterable. 1904 # We're making sure mode() first calls iter() on its input. 1905 # The concern is that a Counter of a Counter returns the original 1906 # unchanged rather than counting its keys. 1907 c = collections.Counter(a=1, b=2) 1908 # If iter() is called, mode(c) loops over the keys, ['a', 'b'], 1909 # all the counts will be 1, and the first encountered mode is 'a'. 1910 self.assertEqual(self.func(c), 'a') 1911 1912 1913class TestMultiMode(unittest.TestCase): 1914 1915 def test_basics(self): 1916 multimode = statistics.multimode 1917 self.assertEqual(multimode('aabbbbbbbbcc'), ['b']) 1918 self.assertEqual(multimode('aabbbbccddddeeffffgg'), ['b', 'd', 'f']) 1919 self.assertEqual(multimode(''), []) 1920 1921 1922class TestFMean(unittest.TestCase): 1923 1924 def test_basics(self): 1925 fmean = statistics.fmean 1926 D = Decimal 1927 F = Fraction 1928 for data, expected_mean, kind in [ 1929 ([3.5, 4.0, 5.25], 4.25, 'floats'), 1930 ([D('3.5'), D('4.0'), D('5.25')], 4.25, 'decimals'), 1931 ([F(7, 2), F(4, 1), F(21, 4)], 4.25, 'fractions'), 1932 ([True, False, True, True, False], 0.60, 'booleans'), 1933 ([3.5, 4, F(21, 4)], 4.25, 'mixed types'), 1934 ((3.5, 4.0, 5.25), 4.25, 'tuple'), 1935 (iter([3.5, 4.0, 5.25]), 4.25, 'iterator'), 1936 ]: 1937 actual_mean = fmean(data) 1938 self.assertIs(type(actual_mean), float, kind) 1939 self.assertEqual(actual_mean, expected_mean, kind) 1940 1941 def test_error_cases(self): 1942 fmean = statistics.fmean 1943 StatisticsError = statistics.StatisticsError 1944 with self.assertRaises(StatisticsError): 1945 fmean([]) # empty input 1946 with self.assertRaises(StatisticsError): 1947 fmean(iter([])) # empty iterator 1948 with self.assertRaises(TypeError): 1949 fmean(None) # non-iterable input 1950 with self.assertRaises(TypeError): 1951 fmean([10, None, 20]) # non-numeric input 1952 with self.assertRaises(TypeError): 1953 fmean() # missing data argument 1954 with self.assertRaises(TypeError): 1955 fmean([10, 20, 60], 70) # too many arguments 1956 1957 def test_special_values(self): 1958 # Rules for special values are inherited from math.fsum() 1959 fmean = statistics.fmean 1960 NaN = float('Nan') 1961 Inf = float('Inf') 1962 self.assertTrue(math.isnan(fmean([10, NaN])), 'nan') 1963 self.assertTrue(math.isnan(fmean([NaN, Inf])), 'nan and infinity') 1964 self.assertTrue(math.isinf(fmean([10, Inf])), 'infinity') 1965 with self.assertRaises(ValueError): 1966 fmean([Inf, -Inf]) 1967 1968 def test_weights(self): 1969 fmean = statistics.fmean 1970 StatisticsError = statistics.StatisticsError 1971 self.assertEqual( 1972 fmean([10, 10, 10, 50], [0.25] * 4), 1973 fmean([10, 10, 10, 50])) 1974 self.assertEqual( 1975 fmean([10, 10, 20], [0.25, 0.25, 0.50]), 1976 fmean([10, 10, 20, 20])) 1977 self.assertEqual( # inputs are iterators 1978 fmean(iter([10, 10, 20]), iter([0.25, 0.25, 0.50])), 1979 fmean([10, 10, 20, 20])) 1980 with self.assertRaises(StatisticsError): 1981 fmean([10, 20, 30], [1, 2]) # unequal lengths 1982 with self.assertRaises(StatisticsError): 1983 fmean(iter([10, 20, 30]), iter([1, 2])) # unequal lengths 1984 with self.assertRaises(StatisticsError): 1985 fmean([10, 20], [-1, 1]) # sum of weights is zero 1986 with self.assertRaises(StatisticsError): 1987 fmean(iter([10, 20]), iter([-1, 1])) # sum of weights is zero 1988 1989 1990# === Tests for variances and standard deviations === 1991 1992class VarianceStdevMixin(UnivariateCommonMixin): 1993 # Mixin class holding common tests for variance and std dev. 1994 1995 # Subclasses should inherit from this before NumericTestClass, in order 1996 # to see the rel attribute below. See testShiftData for an explanation. 1997 1998 rel = 1e-12 1999 2000 def test_single_value(self): 2001 # Deviation of a single value is zero. 2002 for x in (11, 19.8, 4.6e14, Fraction(21, 34), Decimal('8.392')): 2003 self.assertEqual(self.func([x]), 0) 2004 2005 def test_repeated_single_value(self): 2006 # The deviation of a single repeated value is zero. 2007 for x in (7.2, 49, 8.1e15, Fraction(3, 7), Decimal('62.4802')): 2008 for count in (2, 3, 5, 15): 2009 data = [x]*count 2010 self.assertEqual(self.func(data), 0) 2011 2012 def test_domain_error_regression(self): 2013 # Regression test for a domain error exception. 2014 # (Thanks to Geremy Condra.) 2015 data = [0.123456789012345]*10000 2016 # All the items are identical, so variance should be exactly zero. 2017 # We allow some small round-off error, but not much. 2018 result = self.func(data) 2019 self.assertApproxEqual(result, 0.0, tol=5e-17) 2020 self.assertGreaterEqual(result, 0) # A negative result must fail. 2021 2022 def test_shift_data(self): 2023 # Test that shifting the data by a constant amount does not affect 2024 # the variance or stdev. Or at least not much. 2025 2026 # Due to rounding, this test should be considered an ideal. We allow 2027 # some tolerance away from "no change at all" by setting tol and/or rel 2028 # attributes. Subclasses may set tighter or looser error tolerances. 2029 raw = [1.03, 1.27, 1.94, 2.04, 2.58, 3.14, 4.75, 4.98, 5.42, 6.78] 2030 expected = self.func(raw) 2031 # Don't set shift too high, the bigger it is, the more rounding error. 2032 shift = 1e5 2033 data = [x + shift for x in raw] 2034 self.assertApproxEqual(self.func(data), expected) 2035 2036 def test_shift_data_exact(self): 2037 # Like test_shift_data, but result is always exact. 2038 raw = [1, 3, 3, 4, 5, 7, 9, 10, 11, 16] 2039 assert all(x==int(x) for x in raw) 2040 expected = self.func(raw) 2041 shift = 10**9 2042 data = [x + shift for x in raw] 2043 self.assertEqual(self.func(data), expected) 2044 2045 def test_iter_list_same(self): 2046 # Test that iter data and list data give the same result. 2047 2048 # This is an explicit test that iterators and lists are treated the 2049 # same; justification for this test over and above the similar test 2050 # in UnivariateCommonMixin is that an earlier design had variance and 2051 # friends swap between one- and two-pass algorithms, which would 2052 # sometimes give different results. 2053 data = [random.uniform(-3, 8) for _ in range(1000)] 2054 expected = self.func(data) 2055 self.assertEqual(self.func(iter(data)), expected) 2056 2057 2058class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin): 2059 # Tests for population variance. 2060 def setUp(self): 2061 self.func = statistics.pvariance 2062 2063 def test_exact_uniform(self): 2064 # Test the variance against an exact result for uniform data. 2065 data = list(range(10000)) 2066 random.shuffle(data) 2067 expected = (10000**2 - 1)/12 # Exact value. 2068 self.assertEqual(self.func(data), expected) 2069 2070 def test_ints(self): 2071 # Test population variance with int data. 2072 data = [4, 7, 13, 16] 2073 exact = 22.5 2074 self.assertEqual(self.func(data), exact) 2075 2076 def test_fractions(self): 2077 # Test population variance with Fraction data. 2078 F = Fraction 2079 data = [F(1, 4), F(1, 4), F(3, 4), F(7, 4)] 2080 exact = F(3, 8) 2081 result = self.func(data) 2082 self.assertEqual(result, exact) 2083 self.assertIsInstance(result, Fraction) 2084 2085 def test_decimals(self): 2086 # Test population variance with Decimal data. 2087 D = Decimal 2088 data = [D("12.1"), D("12.2"), D("12.5"), D("12.9")] 2089 exact = D('0.096875') 2090 result = self.func(data) 2091 self.assertEqual(result, exact) 2092 self.assertIsInstance(result, Decimal) 2093 2094 def test_accuracy_bug_20499(self): 2095 data = [0, 0, 1] 2096 exact = 2 / 9 2097 result = self.func(data) 2098 self.assertEqual(result, exact) 2099 self.assertIsInstance(result, float) 2100 2101 2102class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin): 2103 # Tests for sample variance. 2104 def setUp(self): 2105 self.func = statistics.variance 2106 2107 def test_single_value(self): 2108 # Override method from VarianceStdevMixin. 2109 for x in (35, 24.7, 8.2e15, Fraction(19, 30), Decimal('4.2084')): 2110 self.assertRaises(statistics.StatisticsError, self.func, [x]) 2111 2112 def test_ints(self): 2113 # Test sample variance with int data. 2114 data = [4, 7, 13, 16] 2115 exact = 30 2116 self.assertEqual(self.func(data), exact) 2117 2118 def test_fractions(self): 2119 # Test sample variance with Fraction data. 2120 F = Fraction 2121 data = [F(1, 4), F(1, 4), F(3, 4), F(7, 4)] 2122 exact = F(1, 2) 2123 result = self.func(data) 2124 self.assertEqual(result, exact) 2125 self.assertIsInstance(result, Fraction) 2126 2127 def test_decimals(self): 2128 # Test sample variance with Decimal data. 2129 D = Decimal 2130 data = [D(2), D(2), D(7), D(9)] 2131 exact = 4*D('9.5')/D(3) 2132 result = self.func(data) 2133 self.assertEqual(result, exact) 2134 self.assertIsInstance(result, Decimal) 2135 2136 def test_center_not_at_mean(self): 2137 data = (1.0, 2.0) 2138 self.assertEqual(self.func(data), 0.5) 2139 self.assertEqual(self.func(data, xbar=2.0), 1.0) 2140 2141 def test_accuracy_bug_20499(self): 2142 data = [0, 0, 2] 2143 exact = 4 / 3 2144 result = self.func(data) 2145 self.assertEqual(result, exact) 2146 self.assertIsInstance(result, float) 2147 2148class TestPStdev(VarianceStdevMixin, NumericTestCase): 2149 # Tests for population standard deviation. 2150 def setUp(self): 2151 self.func = statistics.pstdev 2152 2153 def test_compare_to_variance(self): 2154 # Test that stdev is, in fact, the square root of variance. 2155 data = [random.uniform(-17, 24) for _ in range(1000)] 2156 expected = math.sqrt(statistics.pvariance(data)) 2157 self.assertEqual(self.func(data), expected) 2158 2159 def test_center_not_at_mean(self): 2160 # See issue: 40855 2161 data = (3, 6, 7, 10) 2162 self.assertEqual(self.func(data), 2.5) 2163 self.assertEqual(self.func(data, mu=0.5), 6.5) 2164 2165class TestSqrtHelpers(unittest.TestCase): 2166 2167 def test_integer_sqrt_of_frac_rto(self): 2168 for n, m in itertools.product(range(100), range(1, 1000)): 2169 r = statistics._integer_sqrt_of_frac_rto(n, m) 2170 self.assertIsInstance(r, int) 2171 if r*r*m == n: 2172 # Root is exact 2173 continue 2174 # Inexact, so the root should be odd 2175 self.assertEqual(r&1, 1) 2176 # Verify correct rounding 2177 self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2) 2178 2179 @requires_IEEE_754 2180 def test_float_sqrt_of_frac(self): 2181 2182 def is_root_correctly_rounded(x: Fraction, root: float) -> bool: 2183 if not x: 2184 return root == 0.0 2185 2186 # Extract adjacent representable floats 2187 r_up: float = math.nextafter(root, math.inf) 2188 r_down: float = math.nextafter(root, -math.inf) 2189 assert r_down < root < r_up 2190 2191 # Convert to fractions for exact arithmetic 2192 frac_root: Fraction = Fraction(root) 2193 half_way_up: Fraction = (frac_root + Fraction(r_up)) / 2 2194 half_way_down: Fraction = (frac_root + Fraction(r_down)) / 2 2195 2196 # Check a closed interval. 2197 # Does not test for a midpoint rounding rule. 2198 return half_way_down ** 2 <= x <= half_way_up ** 2 2199 2200 randrange = random.randrange 2201 2202 for i in range(60_000): 2203 numerator: int = randrange(10 ** randrange(50)) 2204 denonimator: int = randrange(10 ** randrange(50)) + 1 2205 with self.subTest(numerator=numerator, denonimator=denonimator): 2206 x: Fraction = Fraction(numerator, denonimator) 2207 root: float = statistics._float_sqrt_of_frac(numerator, denonimator) 2208 self.assertTrue(is_root_correctly_rounded(x, root)) 2209 2210 # Verify that corner cases and error handling match math.sqrt() 2211 self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0) 2212 with self.assertRaises(ValueError): 2213 statistics._float_sqrt_of_frac(-1, 1) 2214 with self.assertRaises(ValueError): 2215 statistics._float_sqrt_of_frac(1, -1) 2216 2217 # Error handling for zero denominator matches that for Fraction(1, 0) 2218 with self.assertRaises(ZeroDivisionError): 2219 statistics._float_sqrt_of_frac(1, 0) 2220 2221 # The result is well defined if both inputs are negative 2222 self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1)) 2223 2224 def test_decimal_sqrt_of_frac(self): 2225 root: Decimal 2226 numerator: int 2227 denominator: int 2228 2229 for root, numerator, denominator in [ 2230 (Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000), # No adj 2231 (Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000), # Adj up 2232 (Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000), # Adj down 2233 ]: 2234 with decimal.localcontext(decimal.DefaultContext): 2235 self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, denominator), root) 2236 2237 # Confirm expected root with a quad precision decimal computation 2238 with decimal.localcontext(decimal.DefaultContext) as ctx: 2239 ctx.prec *= 4 2240 high_prec_ratio = Decimal(numerator) / Decimal(denominator) 2241 ctx.rounding = decimal.ROUND_05UP 2242 high_prec_root = high_prec_ratio.sqrt() 2243 with decimal.localcontext(decimal.DefaultContext): 2244 target_root = +high_prec_root 2245 self.assertEqual(root, target_root) 2246 2247 # Verify that corner cases and error handling match Decimal.sqrt() 2248 self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0) 2249 with self.assertRaises(decimal.InvalidOperation): 2250 statistics._decimal_sqrt_of_frac(-1, 1) 2251 with self.assertRaises(decimal.InvalidOperation): 2252 statistics._decimal_sqrt_of_frac(1, -1) 2253 2254 # Error handling for zero denominator matches that for Fraction(1, 0) 2255 with self.assertRaises(ZeroDivisionError): 2256 statistics._decimal_sqrt_of_frac(1, 0) 2257 2258 # The result is well defined if both inputs are negative 2259 self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1)) 2260 2261 2262class TestStdev(VarianceStdevMixin, NumericTestCase): 2263 # Tests for sample standard deviation. 2264 def setUp(self): 2265 self.func = statistics.stdev 2266 2267 def test_single_value(self): 2268 # Override method from VarianceStdevMixin. 2269 for x in (81, 203.74, 3.9e14, Fraction(5, 21), Decimal('35.719')): 2270 self.assertRaises(statistics.StatisticsError, self.func, [x]) 2271 2272 def test_compare_to_variance(self): 2273 # Test that stdev is, in fact, the square root of variance. 2274 data = [random.uniform(-2, 9) for _ in range(1000)] 2275 expected = math.sqrt(statistics.variance(data)) 2276 self.assertAlmostEqual(self.func(data), expected) 2277 2278 def test_center_not_at_mean(self): 2279 data = (1.0, 2.0) 2280 self.assertEqual(self.func(data, xbar=2.0), 1.0) 2281 2282class TestGeometricMean(unittest.TestCase): 2283 2284 def test_basics(self): 2285 geometric_mean = statistics.geometric_mean 2286 self.assertAlmostEqual(geometric_mean([54, 24, 36]), 36.0) 2287 self.assertAlmostEqual(geometric_mean([4.0, 9.0]), 6.0) 2288 self.assertAlmostEqual(geometric_mean([17.625]), 17.625) 2289 2290 random.seed(86753095551212) 2291 for rng in [ 2292 range(1, 100), 2293 range(1, 1_000), 2294 range(1, 10_000), 2295 range(500, 10_000, 3), 2296 range(10_000, 500, -3), 2297 [12, 17, 13, 5, 120, 7], 2298 [random.expovariate(50.0) for i in range(1_000)], 2299 [random.lognormvariate(20.0, 3.0) for i in range(2_000)], 2300 [random.triangular(2000, 3000, 2200) for i in range(3_000)], 2301 ]: 2302 gm_decimal = math.prod(map(Decimal, rng)) ** (Decimal(1) / len(rng)) 2303 gm_float = geometric_mean(rng) 2304 self.assertTrue(math.isclose(gm_float, float(gm_decimal))) 2305 2306 def test_various_input_types(self): 2307 geometric_mean = statistics.geometric_mean 2308 D = Decimal 2309 F = Fraction 2310 # https://www.wolframalpha.com/input/?i=geometric+mean+3.5,+4.0,+5.25 2311 expected_mean = 4.18886 2312 for data, kind in [ 2313 ([3.5, 4.0, 5.25], 'floats'), 2314 ([D('3.5'), D('4.0'), D('5.25')], 'decimals'), 2315 ([F(7, 2), F(4, 1), F(21, 4)], 'fractions'), 2316 ([3.5, 4, F(21, 4)], 'mixed types'), 2317 ((3.5, 4.0, 5.25), 'tuple'), 2318 (iter([3.5, 4.0, 5.25]), 'iterator'), 2319 ]: 2320 actual_mean = geometric_mean(data) 2321 self.assertIs(type(actual_mean), float, kind) 2322 self.assertAlmostEqual(actual_mean, expected_mean, places=5) 2323 2324 def test_big_and_small(self): 2325 geometric_mean = statistics.geometric_mean 2326 2327 # Avoid overflow to infinity 2328 large = 2.0 ** 1000 2329 big_gm = geometric_mean([54.0 * large, 24.0 * large, 36.0 * large]) 2330 self.assertTrue(math.isclose(big_gm, 36.0 * large)) 2331 self.assertFalse(math.isinf(big_gm)) 2332 2333 # Avoid underflow to zero 2334 small = 2.0 ** -1000 2335 small_gm = geometric_mean([54.0 * small, 24.0 * small, 36.0 * small]) 2336 self.assertTrue(math.isclose(small_gm, 36.0 * small)) 2337 self.assertNotEqual(small_gm, 0.0) 2338 2339 def test_error_cases(self): 2340 geometric_mean = statistics.geometric_mean 2341 StatisticsError = statistics.StatisticsError 2342 with self.assertRaises(StatisticsError): 2343 geometric_mean([]) # empty input 2344 with self.assertRaises(StatisticsError): 2345 geometric_mean([3.5, 0.0, 5.25]) # zero input 2346 with self.assertRaises(StatisticsError): 2347 geometric_mean([3.5, -4.0, 5.25]) # negative input 2348 with self.assertRaises(StatisticsError): 2349 geometric_mean(iter([])) # empty iterator 2350 with self.assertRaises(TypeError): 2351 geometric_mean(None) # non-iterable input 2352 with self.assertRaises(TypeError): 2353 geometric_mean([10, None, 20]) # non-numeric input 2354 with self.assertRaises(TypeError): 2355 geometric_mean() # missing data argument 2356 with self.assertRaises(TypeError): 2357 geometric_mean([10, 20, 60], 70) # too many arguments 2358 2359 def test_special_values(self): 2360 # Rules for special values are inherited from math.fsum() 2361 geometric_mean = statistics.geometric_mean 2362 NaN = float('Nan') 2363 Inf = float('Inf') 2364 self.assertTrue(math.isnan(geometric_mean([10, NaN])), 'nan') 2365 self.assertTrue(math.isnan(geometric_mean([NaN, Inf])), 'nan and infinity') 2366 self.assertTrue(math.isinf(geometric_mean([10, Inf])), 'infinity') 2367 with self.assertRaises(ValueError): 2368 geometric_mean([Inf, -Inf]) 2369 2370 def test_mixed_int_and_float(self): 2371 # Regression test for b.p.o. issue #28327 2372 geometric_mean = statistics.geometric_mean 2373 expected_mean = 3.80675409583932 2374 values = [ 2375 [2, 3, 5, 7], 2376 [2, 3, 5, 7.0], 2377 [2, 3, 5.0, 7.0], 2378 [2, 3.0, 5.0, 7.0], 2379 [2.0, 3.0, 5.0, 7.0], 2380 ] 2381 for v in values: 2382 with self.subTest(v=v): 2383 actual_mean = geometric_mean(v) 2384 self.assertAlmostEqual(actual_mean, expected_mean, places=5) 2385 2386 2387class TestQuantiles(unittest.TestCase): 2388 2389 def test_specific_cases(self): 2390 # Match results computed by hand and cross-checked 2391 # against the PERCENTILE.EXC function in MS Excel. 2392 quantiles = statistics.quantiles 2393 data = [120, 200, 250, 320, 350] 2394 random.shuffle(data) 2395 for n, expected in [ 2396 (1, []), 2397 (2, [250.0]), 2398 (3, [200.0, 320.0]), 2399 (4, [160.0, 250.0, 335.0]), 2400 (5, [136.0, 220.0, 292.0, 344.0]), 2401 (6, [120.0, 200.0, 250.0, 320.0, 350.0]), 2402 (8, [100.0, 160.0, 212.5, 250.0, 302.5, 335.0, 357.5]), 2403 (10, [88.0, 136.0, 184.0, 220.0, 250.0, 292.0, 326.0, 344.0, 362.0]), 2404 (12, [80.0, 120.0, 160.0, 200.0, 225.0, 250.0, 285.0, 320.0, 335.0, 2405 350.0, 365.0]), 2406 (15, [72.0, 104.0, 136.0, 168.0, 200.0, 220.0, 240.0, 264.0, 292.0, 2407 320.0, 332.0, 344.0, 356.0, 368.0]), 2408 ]: 2409 self.assertEqual(expected, quantiles(data, n=n)) 2410 self.assertEqual(len(quantiles(data, n=n)), n - 1) 2411 # Preserve datatype when possible 2412 for datatype in (float, Decimal, Fraction): 2413 result = quantiles(map(datatype, data), n=n) 2414 self.assertTrue(all(type(x) == datatype) for x in result) 2415 self.assertEqual(result, list(map(datatype, expected))) 2416 # Quantiles should be idempotent 2417 if len(expected) >= 2: 2418 self.assertEqual(quantiles(expected, n=n), expected) 2419 # Cross-check against method='inclusive' which should give 2420 # the same result after adding in minimum and maximum values 2421 # extrapolated from the two lowest and two highest points. 2422 sdata = sorted(data) 2423 lo = 2 * sdata[0] - sdata[1] 2424 hi = 2 * sdata[-1] - sdata[-2] 2425 padded_data = data + [lo, hi] 2426 self.assertEqual( 2427 quantiles(data, n=n), 2428 quantiles(padded_data, n=n, method='inclusive'), 2429 (n, data), 2430 ) 2431 # Invariant under translation and scaling 2432 def f(x): 2433 return 3.5 * x - 1234.675 2434 exp = list(map(f, expected)) 2435 act = quantiles(map(f, data), n=n) 2436 self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act))) 2437 # Q2 agrees with median() 2438 for k in range(2, 60): 2439 data = random.choices(range(100), k=k) 2440 q1, q2, q3 = quantiles(data) 2441 self.assertEqual(q2, statistics.median(data)) 2442 2443 def test_specific_cases_inclusive(self): 2444 # Match results computed by hand and cross-checked 2445 # against the PERCENTILE.INC function in MS Excel 2446 # and against the quantile() function in SciPy. 2447 quantiles = statistics.quantiles 2448 data = [100, 200, 400, 800] 2449 random.shuffle(data) 2450 for n, expected in [ 2451 (1, []), 2452 (2, [300.0]), 2453 (3, [200.0, 400.0]), 2454 (4, [175.0, 300.0, 500.0]), 2455 (5, [160.0, 240.0, 360.0, 560.0]), 2456 (6, [150.0, 200.0, 300.0, 400.0, 600.0]), 2457 (8, [137.5, 175, 225.0, 300.0, 375.0, 500.0,650.0]), 2458 (10, [130.0, 160.0, 190.0, 240.0, 300.0, 360.0, 440.0, 560.0, 680.0]), 2459 (12, [125.0, 150.0, 175.0, 200.0, 250.0, 300.0, 350.0, 400.0, 2460 500.0, 600.0, 700.0]), 2461 (15, [120.0, 140.0, 160.0, 180.0, 200.0, 240.0, 280.0, 320.0, 360.0, 2462 400.0, 480.0, 560.0, 640.0, 720.0]), 2463 ]: 2464 self.assertEqual(expected, quantiles(data, n=n, method="inclusive")) 2465 self.assertEqual(len(quantiles(data, n=n, method="inclusive")), n - 1) 2466 # Preserve datatype when possible 2467 for datatype in (float, Decimal, Fraction): 2468 result = quantiles(map(datatype, data), n=n, method="inclusive") 2469 self.assertTrue(all(type(x) == datatype) for x in result) 2470 self.assertEqual(result, list(map(datatype, expected))) 2471 # Invariant under translation and scaling 2472 def f(x): 2473 return 3.5 * x - 1234.675 2474 exp = list(map(f, expected)) 2475 act = quantiles(map(f, data), n=n, method="inclusive") 2476 self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act))) 2477 # Natural deciles 2478 self.assertEqual(quantiles([0, 100], n=10, method='inclusive'), 2479 [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]) 2480 self.assertEqual(quantiles(range(0, 101), n=10, method='inclusive'), 2481 [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]) 2482 # Whenever n is smaller than the number of data points, running 2483 # method='inclusive' should give the same result as method='exclusive' 2484 # after the two included extreme points are removed. 2485 data = [random.randrange(10_000) for i in range(501)] 2486 actual = quantiles(data, n=32, method='inclusive') 2487 data.remove(min(data)) 2488 data.remove(max(data)) 2489 expected = quantiles(data, n=32) 2490 self.assertEqual(expected, actual) 2491 # Q2 agrees with median() 2492 for k in range(2, 60): 2493 data = random.choices(range(100), k=k) 2494 q1, q2, q3 = quantiles(data, method='inclusive') 2495 self.assertEqual(q2, statistics.median(data)) 2496 2497 def test_equal_inputs(self): 2498 quantiles = statistics.quantiles 2499 for n in range(2, 10): 2500 data = [10.0] * n 2501 self.assertEqual(quantiles(data), [10.0, 10.0, 10.0]) 2502 self.assertEqual(quantiles(data, method='inclusive'), 2503 [10.0, 10.0, 10.0]) 2504 2505 def test_equal_sized_groups(self): 2506 quantiles = statistics.quantiles 2507 total = 10_000 2508 data = [random.expovariate(0.2) for i in range(total)] 2509 while len(set(data)) != total: 2510 data.append(random.expovariate(0.2)) 2511 data.sort() 2512 2513 # Cases where the group size exactly divides the total 2514 for n in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000): 2515 group_size = total // n 2516 self.assertEqual( 2517 [bisect.bisect(data, q) for q in quantiles(data, n=n)], 2518 list(range(group_size, total, group_size))) 2519 2520 # When the group sizes can't be exactly equal, they should 2521 # differ by no more than one 2522 for n in (13, 19, 59, 109, 211, 571, 1019, 1907, 5261, 9769): 2523 group_sizes = {total // n, total // n + 1} 2524 pos = [bisect.bisect(data, q) for q in quantiles(data, n=n)] 2525 sizes = {q - p for p, q in zip(pos, pos[1:])} 2526 self.assertTrue(sizes <= group_sizes) 2527 2528 def test_error_cases(self): 2529 quantiles = statistics.quantiles 2530 StatisticsError = statistics.StatisticsError 2531 with self.assertRaises(TypeError): 2532 quantiles() # Missing arguments 2533 with self.assertRaises(TypeError): 2534 quantiles([10, 20, 30], 13, n=4) # Too many arguments 2535 with self.assertRaises(TypeError): 2536 quantiles([10, 20, 30], 4) # n is a positional argument 2537 with self.assertRaises(StatisticsError): 2538 quantiles([10, 20, 30], n=0) # n is zero 2539 with self.assertRaises(StatisticsError): 2540 quantiles([10, 20, 30], n=-1) # n is negative 2541 with self.assertRaises(TypeError): 2542 quantiles([10, 20, 30], n=1.5) # n is not an integer 2543 with self.assertRaises(ValueError): 2544 quantiles([10, 20, 30], method='X') # method is unknown 2545 with self.assertRaises(StatisticsError): 2546 quantiles([10], n=4) # not enough data points 2547 with self.assertRaises(TypeError): 2548 quantiles([10, None, 30], n=4) # data is non-numeric 2549 2550 2551class TestBivariateStatistics(unittest.TestCase): 2552 2553 def test_unequal_size_error(self): 2554 for x, y in [ 2555 ([1, 2, 3], [1, 2]), 2556 ([1, 2], [1, 2, 3]), 2557 ]: 2558 with self.assertRaises(statistics.StatisticsError): 2559 statistics.covariance(x, y) 2560 with self.assertRaises(statistics.StatisticsError): 2561 statistics.correlation(x, y) 2562 with self.assertRaises(statistics.StatisticsError): 2563 statistics.linear_regression(x, y) 2564 2565 def test_small_sample_error(self): 2566 for x, y in [ 2567 ([], []), 2568 ([], [1, 2,]), 2569 ([1, 2,], []), 2570 ([1,], [1,]), 2571 ([1,], [1, 2,]), 2572 ([1, 2,], [1,]), 2573 ]: 2574 with self.assertRaises(statistics.StatisticsError): 2575 statistics.covariance(x, y) 2576 with self.assertRaises(statistics.StatisticsError): 2577 statistics.correlation(x, y) 2578 with self.assertRaises(statistics.StatisticsError): 2579 statistics.linear_regression(x, y) 2580 2581 2582class TestCorrelationAndCovariance(unittest.TestCase): 2583 2584 def test_results(self): 2585 for x, y, result in [ 2586 ([1, 2, 3], [1, 2, 3], 1), 2587 ([1, 2, 3], [-1, -2, -3], -1), 2588 ([1, 2, 3], [3, 2, 1], -1), 2589 ([1, 2, 3], [1, 2, 1], 0), 2590 ([1, 2, 3], [1, 3, 2], 0.5), 2591 ]: 2592 self.assertAlmostEqual(statistics.correlation(x, y), result) 2593 self.assertAlmostEqual(statistics.covariance(x, y), result) 2594 2595 def test_different_scales(self): 2596 x = [1, 2, 3] 2597 y = [10, 30, 20] 2598 self.assertAlmostEqual(statistics.correlation(x, y), 0.5) 2599 self.assertAlmostEqual(statistics.covariance(x, y), 5) 2600 2601 y = [.1, .2, .3] 2602 self.assertAlmostEqual(statistics.correlation(x, y), 1) 2603 self.assertAlmostEqual(statistics.covariance(x, y), 0.1) 2604 2605 2606class TestLinearRegression(unittest.TestCase): 2607 2608 def test_constant_input_error(self): 2609 x = [1, 1, 1,] 2610 y = [1, 2, 3,] 2611 with self.assertRaises(statistics.StatisticsError): 2612 statistics.linear_regression(x, y) 2613 2614 def test_results(self): 2615 for x, y, true_intercept, true_slope in [ 2616 ([1, 2, 3], [0, 0, 0], 0, 0), 2617 ([1, 2, 3], [1, 2, 3], 0, 1), 2618 ([1, 2, 3], [100, 100, 100], 100, 0), 2619 ([1, 2, 3], [12, 14, 16], 10, 2), 2620 ([1, 2, 3], [-1, -2, -3], 0, -1), 2621 ([1, 2, 3], [21, 22, 23], 20, 1), 2622 ([1, 2, 3], [5.1, 5.2, 5.3], 5, 0.1), 2623 ]: 2624 slope, intercept = statistics.linear_regression(x, y) 2625 self.assertAlmostEqual(intercept, true_intercept) 2626 self.assertAlmostEqual(slope, true_slope) 2627 2628 def test_proportional(self): 2629 x = [10, 20, 30, 40] 2630 y = [180, 398, 610, 799] 2631 slope, intercept = statistics.linear_regression(x, y, proportional=True) 2632 self.assertAlmostEqual(slope, 20 + 1/150) 2633 self.assertEqual(intercept, 0.0) 2634 2635class TestNormalDist: 2636 2637 # General note on precision: The pdf(), cdf(), and overlap() methods 2638 # depend on functions in the math libraries that do not make 2639 # explicit accuracy guarantees. Accordingly, some of the accuracy 2640 # tests below may fail if the underlying math functions are 2641 # inaccurate. There isn't much we can do about this short of 2642 # implementing our own implementations from scratch. 2643 2644 def test_slots(self): 2645 nd = self.module.NormalDist(300, 23) 2646 with self.assertRaises(TypeError): 2647 vars(nd) 2648 self.assertEqual(tuple(nd.__slots__), ('_mu', '_sigma')) 2649 2650 def test_instantiation_and_attributes(self): 2651 nd = self.module.NormalDist(500, 17) 2652 self.assertEqual(nd.mean, 500) 2653 self.assertEqual(nd.stdev, 17) 2654 self.assertEqual(nd.variance, 17**2) 2655 2656 # default arguments 2657 nd = self.module.NormalDist() 2658 self.assertEqual(nd.mean, 0) 2659 self.assertEqual(nd.stdev, 1) 2660 self.assertEqual(nd.variance, 1**2) 2661 2662 # error case: negative sigma 2663 with self.assertRaises(self.module.StatisticsError): 2664 self.module.NormalDist(500, -10) 2665 2666 # verify that subclass type is honored 2667 class NewNormalDist(self.module.NormalDist): 2668 pass 2669 nnd = NewNormalDist(200, 5) 2670 self.assertEqual(type(nnd), NewNormalDist) 2671 2672 def test_alternative_constructor(self): 2673 NormalDist = self.module.NormalDist 2674 data = [96, 107, 90, 92, 110] 2675 # list input 2676 self.assertEqual(NormalDist.from_samples(data), NormalDist(99, 9)) 2677 # tuple input 2678 self.assertEqual(NormalDist.from_samples(tuple(data)), NormalDist(99, 9)) 2679 # iterator input 2680 self.assertEqual(NormalDist.from_samples(iter(data)), NormalDist(99, 9)) 2681 # error cases 2682 with self.assertRaises(self.module.StatisticsError): 2683 NormalDist.from_samples([]) # empty input 2684 with self.assertRaises(self.module.StatisticsError): 2685 NormalDist.from_samples([10]) # only one input 2686 2687 # verify that subclass type is honored 2688 class NewNormalDist(NormalDist): 2689 pass 2690 nnd = NewNormalDist.from_samples(data) 2691 self.assertEqual(type(nnd), NewNormalDist) 2692 2693 def test_sample_generation(self): 2694 NormalDist = self.module.NormalDist 2695 mu, sigma = 10_000, 3.0 2696 X = NormalDist(mu, sigma) 2697 n = 1_000 2698 data = X.samples(n) 2699 self.assertEqual(len(data), n) 2700 self.assertEqual(set(map(type, data)), {float}) 2701 # mean(data) expected to fall within 8 standard deviations 2702 xbar = self.module.mean(data) 2703 self.assertTrue(mu - sigma*8 <= xbar <= mu + sigma*8) 2704 2705 # verify that seeding makes reproducible sequences 2706 n = 100 2707 data1 = X.samples(n, seed='happiness and joy') 2708 data2 = X.samples(n, seed='trouble and despair') 2709 data3 = X.samples(n, seed='happiness and joy') 2710 data4 = X.samples(n, seed='trouble and despair') 2711 self.assertEqual(data1, data3) 2712 self.assertEqual(data2, data4) 2713 self.assertNotEqual(data1, data2) 2714 2715 def test_pdf(self): 2716 NormalDist = self.module.NormalDist 2717 X = NormalDist(100, 15) 2718 # Verify peak around center 2719 self.assertLess(X.pdf(99), X.pdf(100)) 2720 self.assertLess(X.pdf(101), X.pdf(100)) 2721 # Test symmetry 2722 for i in range(50): 2723 self.assertAlmostEqual(X.pdf(100 - i), X.pdf(100 + i)) 2724 # Test vs CDF 2725 dx = 2.0 ** -10 2726 for x in range(90, 111): 2727 est_pdf = (X.cdf(x + dx) - X.cdf(x)) / dx 2728 self.assertAlmostEqual(X.pdf(x), est_pdf, places=4) 2729 # Test vs table of known values -- CRC 26th Edition 2730 Z = NormalDist() 2731 for x, px in enumerate([ 2732 0.3989, 0.3989, 0.3989, 0.3988, 0.3986, 2733 0.3984, 0.3982, 0.3980, 0.3977, 0.3973, 2734 0.3970, 0.3965, 0.3961, 0.3956, 0.3951, 2735 0.3945, 0.3939, 0.3932, 0.3925, 0.3918, 2736 0.3910, 0.3902, 0.3894, 0.3885, 0.3876, 2737 0.3867, 0.3857, 0.3847, 0.3836, 0.3825, 2738 0.3814, 0.3802, 0.3790, 0.3778, 0.3765, 2739 0.3752, 0.3739, 0.3725, 0.3712, 0.3697, 2740 0.3683, 0.3668, 0.3653, 0.3637, 0.3621, 2741 0.3605, 0.3589, 0.3572, 0.3555, 0.3538, 2742 ]): 2743 self.assertAlmostEqual(Z.pdf(x / 100.0), px, places=4) 2744 self.assertAlmostEqual(Z.pdf(-x / 100.0), px, places=4) 2745 # Error case: variance is zero 2746 Y = NormalDist(100, 0) 2747 with self.assertRaises(self.module.StatisticsError): 2748 Y.pdf(90) 2749 # Special values 2750 self.assertEqual(X.pdf(float('-Inf')), 0.0) 2751 self.assertEqual(X.pdf(float('Inf')), 0.0) 2752 self.assertTrue(math.isnan(X.pdf(float('NaN')))) 2753 2754 def test_cdf(self): 2755 NormalDist = self.module.NormalDist 2756 X = NormalDist(100, 15) 2757 cdfs = [X.cdf(x) for x in range(1, 200)] 2758 self.assertEqual(set(map(type, cdfs)), {float}) 2759 # Verify montonic 2760 self.assertEqual(cdfs, sorted(cdfs)) 2761 # Verify center (should be exact) 2762 self.assertEqual(X.cdf(100), 0.50) 2763 # Check against a table of known values 2764 # https://en.wikipedia.org/wiki/Standard_normal_table#Cumulative 2765 Z = NormalDist() 2766 for z, cum_prob in [ 2767 (0.00, 0.50000), (0.01, 0.50399), (0.02, 0.50798), 2768 (0.14, 0.55567), (0.29, 0.61409), (0.33, 0.62930), 2769 (0.54, 0.70540), (0.60, 0.72575), (1.17, 0.87900), 2770 (1.60, 0.94520), (2.05, 0.97982), (2.89, 0.99807), 2771 (3.52, 0.99978), (3.98, 0.99997), (4.07, 0.99998), 2772 ]: 2773 self.assertAlmostEqual(Z.cdf(z), cum_prob, places=5) 2774 self.assertAlmostEqual(Z.cdf(-z), 1.0 - cum_prob, places=5) 2775 # Error case: variance is zero 2776 Y = NormalDist(100, 0) 2777 with self.assertRaises(self.module.StatisticsError): 2778 Y.cdf(90) 2779 # Special values 2780 self.assertEqual(X.cdf(float('-Inf')), 0.0) 2781 self.assertEqual(X.cdf(float('Inf')), 1.0) 2782 self.assertTrue(math.isnan(X.cdf(float('NaN')))) 2783 2784 @support.skip_if_pgo_task 2785 def test_inv_cdf(self): 2786 NormalDist = self.module.NormalDist 2787 2788 # Center case should be exact. 2789 iq = NormalDist(100, 15) 2790 self.assertEqual(iq.inv_cdf(0.50), iq.mean) 2791 2792 # Test versus a published table of known percentage points. 2793 # See the second table at the bottom of the page here: 2794 # http://people.bath.ac.uk/masss/tables/normaltable.pdf 2795 Z = NormalDist() 2796 pp = {5.0: (0.000, 1.645, 2.576, 3.291, 3.891, 2797 4.417, 4.892, 5.327, 5.731, 6.109), 2798 2.5: (0.674, 1.960, 2.807, 3.481, 4.056, 2799 4.565, 5.026, 5.451, 5.847, 6.219), 2800 1.0: (1.282, 2.326, 3.090, 3.719, 4.265, 2801 4.753, 5.199, 5.612, 5.998, 6.361)} 2802 for base, row in pp.items(): 2803 for exp, x in enumerate(row, start=1): 2804 p = base * 10.0 ** (-exp) 2805 self.assertAlmostEqual(-Z.inv_cdf(p), x, places=3) 2806 p = 1.0 - p 2807 self.assertAlmostEqual(Z.inv_cdf(p), x, places=3) 2808 2809 # Match published example for MS Excel 2810 # https://support.office.com/en-us/article/norm-inv-function-54b30935-fee7-493c-bedb-2278a9db7e13 2811 self.assertAlmostEqual(NormalDist(40, 1.5).inv_cdf(0.908789), 42.000002) 2812 2813 # One million equally spaced probabilities 2814 n = 2**20 2815 for p in range(1, n): 2816 p /= n 2817 self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p) 2818 2819 # One hundred ever smaller probabilities to test tails out to 2820 # extreme probabilities: 1 / 2**50 and (2**50-1) / 2 ** 50 2821 for e in range(1, 51): 2822 p = 2.0 ** (-e) 2823 self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p) 2824 p = 1.0 - p 2825 self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p) 2826 2827 # Now apply cdf() first. Near the tails, the round-trip loses 2828 # precision and is ill-conditioned (small changes in the inputs 2829 # give large changes in the output), so only check to 5 places. 2830 for x in range(200): 2831 self.assertAlmostEqual(iq.inv_cdf(iq.cdf(x)), x, places=5) 2832 2833 # Error cases: 2834 with self.assertRaises(self.module.StatisticsError): 2835 iq.inv_cdf(0.0) # p is zero 2836 with self.assertRaises(self.module.StatisticsError): 2837 iq.inv_cdf(-0.1) # p under zero 2838 with self.assertRaises(self.module.StatisticsError): 2839 iq.inv_cdf(1.0) # p is one 2840 with self.assertRaises(self.module.StatisticsError): 2841 iq.inv_cdf(1.1) # p over one 2842 with self.assertRaises(self.module.StatisticsError): 2843 iq = NormalDist(100, 0) # sigma is zero 2844 iq.inv_cdf(0.5) 2845 2846 # Special values 2847 self.assertTrue(math.isnan(Z.inv_cdf(float('NaN')))) 2848 2849 def test_quantiles(self): 2850 # Quartiles of a standard normal distribution 2851 Z = self.module.NormalDist() 2852 for n, expected in [ 2853 (1, []), 2854 (2, [0.0]), 2855 (3, [-0.4307, 0.4307]), 2856 (4 ,[-0.6745, 0.0, 0.6745]), 2857 ]: 2858 actual = Z.quantiles(n=n) 2859 self.assertTrue(all(math.isclose(e, a, abs_tol=0.0001) 2860 for e, a in zip(expected, actual))) 2861 2862 def test_overlap(self): 2863 NormalDist = self.module.NormalDist 2864 2865 # Match examples from Imman and Bradley 2866 for X1, X2, published_result in [ 2867 (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0), 0.80258), 2868 (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0), 0.60993), 2869 ]: 2870 self.assertAlmostEqual(X1.overlap(X2), published_result, places=4) 2871 self.assertAlmostEqual(X2.overlap(X1), published_result, places=4) 2872 2873 # Check against integration of the PDF 2874 def overlap_numeric(X, Y, *, steps=8_192, z=5): 2875 'Numerical integration cross-check for overlap() ' 2876 fsum = math.fsum 2877 center = (X.mean + Y.mean) / 2.0 2878 width = z * max(X.stdev, Y.stdev) 2879 start = center - width 2880 dx = 2.0 * width / steps 2881 x_arr = [start + i*dx for i in range(steps)] 2882 xp = list(map(X.pdf, x_arr)) 2883 yp = list(map(Y.pdf, x_arr)) 2884 total = max(fsum(xp), fsum(yp)) 2885 return fsum(map(min, xp, yp)) / total 2886 2887 for X1, X2 in [ 2888 # Examples from Imman and Bradley 2889 (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0)), 2890 (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)), 2891 # Example from https://www.rasch.org/rmt/rmt101r.htm 2892 (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)), 2893 # Gender heights from http://www.usablestats.com/lessons/normal 2894 (NormalDist(70, 4), NormalDist(65, 3.5)), 2895 # Misc cases with equal standard deviations 2896 (NormalDist(100, 15), NormalDist(110, 15)), 2897 (NormalDist(-100, 15), NormalDist(110, 15)), 2898 (NormalDist(-100, 15), NormalDist(-110, 15)), 2899 # Misc cases with unequal standard deviations 2900 (NormalDist(100, 12), NormalDist(100, 15)), 2901 (NormalDist(100, 12), NormalDist(110, 15)), 2902 (NormalDist(100, 12), NormalDist(150, 15)), 2903 (NormalDist(100, 12), NormalDist(150, 35)), 2904 # Misc cases with small values 2905 (NormalDist(1.000, 0.002), NormalDist(1.001, 0.003)), 2906 (NormalDist(1.000, 0.002), NormalDist(1.006, 0.0003)), 2907 (NormalDist(1.000, 0.002), NormalDist(1.001, 0.099)), 2908 ]: 2909 self.assertAlmostEqual(X1.overlap(X2), overlap_numeric(X1, X2), places=5) 2910 self.assertAlmostEqual(X2.overlap(X1), overlap_numeric(X1, X2), places=5) 2911 2912 # Error cases 2913 X = NormalDist() 2914 with self.assertRaises(TypeError): 2915 X.overlap() # too few arguments 2916 with self.assertRaises(TypeError): 2917 X.overlap(X, X) # too may arguments 2918 with self.assertRaises(TypeError): 2919 X.overlap(None) # right operand not a NormalDist 2920 with self.assertRaises(self.module.StatisticsError): 2921 X.overlap(NormalDist(1, 0)) # right operand sigma is zero 2922 with self.assertRaises(self.module.StatisticsError): 2923 NormalDist(1, 0).overlap(X) # left operand sigma is zero 2924 2925 def test_zscore(self): 2926 NormalDist = self.module.NormalDist 2927 X = NormalDist(100, 15) 2928 self.assertEqual(X.zscore(142), 2.8) 2929 self.assertEqual(X.zscore(58), -2.8) 2930 self.assertEqual(X.zscore(100), 0.0) 2931 with self.assertRaises(TypeError): 2932 X.zscore() # too few arguments 2933 with self.assertRaises(TypeError): 2934 X.zscore(1, 1) # too may arguments 2935 with self.assertRaises(TypeError): 2936 X.zscore(None) # non-numeric type 2937 with self.assertRaises(self.module.StatisticsError): 2938 NormalDist(1, 0).zscore(100) # sigma is zero 2939 2940 def test_properties(self): 2941 X = self.module.NormalDist(100, 15) 2942 self.assertEqual(X.mean, 100) 2943 self.assertEqual(X.median, 100) 2944 self.assertEqual(X.mode, 100) 2945 self.assertEqual(X.stdev, 15) 2946 self.assertEqual(X.variance, 225) 2947 2948 def test_same_type_addition_and_subtraction(self): 2949 NormalDist = self.module.NormalDist 2950 X = NormalDist(100, 12) 2951 Y = NormalDist(40, 5) 2952 self.assertEqual(X + Y, NormalDist(140, 13)) # __add__ 2953 self.assertEqual(X - Y, NormalDist(60, 13)) # __sub__ 2954 2955 def test_translation_and_scaling(self): 2956 NormalDist = self.module.NormalDist 2957 X = NormalDist(100, 15) 2958 y = 10 2959 self.assertEqual(+X, NormalDist(100, 15)) # __pos__ 2960 self.assertEqual(-X, NormalDist(-100, 15)) # __neg__ 2961 self.assertEqual(X + y, NormalDist(110, 15)) # __add__ 2962 self.assertEqual(y + X, NormalDist(110, 15)) # __radd__ 2963 self.assertEqual(X - y, NormalDist(90, 15)) # __sub__ 2964 self.assertEqual(y - X, NormalDist(-90, 15)) # __rsub__ 2965 self.assertEqual(X * y, NormalDist(1000, 150)) # __mul__ 2966 self.assertEqual(y * X, NormalDist(1000, 150)) # __rmul__ 2967 self.assertEqual(X / y, NormalDist(10, 1.5)) # __truediv__ 2968 with self.assertRaises(TypeError): # __rtruediv__ 2969 y / X 2970 2971 def test_unary_operations(self): 2972 NormalDist = self.module.NormalDist 2973 X = NormalDist(100, 12) 2974 Y = +X 2975 self.assertIsNot(X, Y) 2976 self.assertEqual(X.mean, Y.mean) 2977 self.assertEqual(X.stdev, Y.stdev) 2978 Y = -X 2979 self.assertIsNot(X, Y) 2980 self.assertEqual(X.mean, -Y.mean) 2981 self.assertEqual(X.stdev, Y.stdev) 2982 2983 def test_equality(self): 2984 NormalDist = self.module.NormalDist 2985 nd1 = NormalDist() 2986 nd2 = NormalDist(2, 4) 2987 nd3 = NormalDist() 2988 nd4 = NormalDist(2, 4) 2989 nd5 = NormalDist(2, 8) 2990 nd6 = NormalDist(8, 4) 2991 self.assertNotEqual(nd1, nd2) 2992 self.assertEqual(nd1, nd3) 2993 self.assertEqual(nd2, nd4) 2994 self.assertNotEqual(nd2, nd5) 2995 self.assertNotEqual(nd2, nd6) 2996 2997 # Test NotImplemented when types are different 2998 class A: 2999 def __eq__(self, other): 3000 return 10 3001 a = A() 3002 self.assertEqual(nd1.__eq__(a), NotImplemented) 3003 self.assertEqual(nd1 == a, 10) 3004 self.assertEqual(a == nd1, 10) 3005 3006 # All subclasses to compare equal giving the same behavior 3007 # as list, tuple, int, float, complex, str, dict, set, etc. 3008 class SizedNormalDist(NormalDist): 3009 def __init__(self, mu, sigma, n): 3010 super().__init__(mu, sigma) 3011 self.n = n 3012 s = SizedNormalDist(100, 15, 57) 3013 nd4 = NormalDist(100, 15) 3014 self.assertEqual(s, nd4) 3015 3016 # Don't allow duck type equality because we wouldn't 3017 # want a lognormal distribution to compare equal 3018 # to a normal distribution with the same parameters 3019 class LognormalDist: 3020 def __init__(self, mu, sigma): 3021 self.mu = mu 3022 self.sigma = sigma 3023 lnd = LognormalDist(100, 15) 3024 nd = NormalDist(100, 15) 3025 self.assertNotEqual(nd, lnd) 3026 3027 def test_pickle_and_copy(self): 3028 nd = self.module.NormalDist(37.5, 5.625) 3029 nd1 = copy.copy(nd) 3030 self.assertEqual(nd, nd1) 3031 nd2 = copy.deepcopy(nd) 3032 self.assertEqual(nd, nd2) 3033 nd3 = pickle.loads(pickle.dumps(nd)) 3034 self.assertEqual(nd, nd3) 3035 3036 def test_hashability(self): 3037 ND = self.module.NormalDist 3038 s = {ND(100, 15), ND(100.0, 15.0), ND(100, 10), ND(95, 15), ND(100, 15)} 3039 self.assertEqual(len(s), 3) 3040 3041 def test_repr(self): 3042 nd = self.module.NormalDist(37.5, 5.625) 3043 self.assertEqual(repr(nd), 'NormalDist(mu=37.5, sigma=5.625)') 3044 3045# Swapping the sys.modules['statistics'] is to solving the 3046# _pickle.PicklingError: 3047# Can't pickle <class 'statistics.NormalDist'>: 3048# it's not the same object as statistics.NormalDist 3049class TestNormalDistPython(unittest.TestCase, TestNormalDist): 3050 module = py_statistics 3051 def setUp(self): 3052 sys.modules['statistics'] = self.module 3053 3054 def tearDown(self): 3055 sys.modules['statistics'] = statistics 3056 3057 3058@unittest.skipUnless(c_statistics, 'requires _statistics') 3059class TestNormalDistC(unittest.TestCase, TestNormalDist): 3060 module = c_statistics 3061 def setUp(self): 3062 sys.modules['statistics'] = self.module 3063 3064 def tearDown(self): 3065 sys.modules['statistics'] = statistics 3066 3067 3068# === Run tests === 3069 3070def load_tests(loader, tests, ignore): 3071 """Used for doctest/unittest integration.""" 3072 tests.addTests(doctest.DocTestSuite()) 3073 return tests 3074 3075 3076if __name__ == "__main__": 3077 unittest.main() 3078