1"""Test suite for statistics module, including helper NumericTestCase and
2approx_equal function.
3
4"""
5
6import bisect
7import collections
8import collections.abc
9import copy
10import decimal
11import doctest
12import math
13import pickle
14import random
15import sys
16import unittest
17from test import support
18
19from decimal import Decimal
20from fractions import Fraction
21from test import support
22
23
24# Module to be tested.
25import statistics
26
27
28# === Helper functions and class ===
29
def sign(x):
    """Return -1.0 when x has its sign bit set, else +1.0.

    Negative numbers, negative zero included, give -1.0; everything else
    gives +1.0.
    """
    unit_with_sign = math.copysign(1.0, x)
    return -1.0 if unit_with_sign < 0 else 1.0
33
34def _nan_equal(a, b):
35    """Return True if a and b are both the same kind of NAN.
36
37    >>> _nan_equal(Decimal('NAN'), Decimal('NAN'))
38    True
39    >>> _nan_equal(Decimal('sNAN'), Decimal('sNAN'))
40    True
41    >>> _nan_equal(Decimal('NAN'), Decimal('sNAN'))
42    False
43    >>> _nan_equal(Decimal(42), Decimal('NAN'))
44    False
45
46    >>> _nan_equal(float('NAN'), float('NAN'))
47    True
48    >>> _nan_equal(float('NAN'), 0.5)
49    False
50
51    >>> _nan_equal(float('NAN'), Decimal('NAN'))
52    False
53
54    NAN payloads are not compared.
55    """
56    if type(a) is not type(b):
57        return False
58    if isinstance(a, float):
59        return math.isnan(a) and math.isnan(b)
60    aexp = a.as_tuple()[2]
61    bexp = b.as_tuple()[2]
62    return (aexp == bexp) and (aexp in ('n', 'N'))  # Both NAN or both sNAN.
63
64
65def _calc_errors(actual, expected):
66    """Return the absolute and relative errors between two numbers.
67
68    >>> _calc_errors(100, 75)
69    (25, 0.25)
70    >>> _calc_errors(100, 100)
71    (0, 0.0)
72
73    Returns the (absolute error, relative error) between the two arguments.
74    """
75    base = max(abs(actual), abs(expected))
76    abs_err = abs(actual - expected)
77    rel_err = abs_err/base if base else float('inf')
78    return (abs_err, rel_err)
79
80
def _approx_isnan(value):
    """Return True if the number value is a NAN, quiet or signalling."""
    if isinstance(value, Decimal):
        # float(Decimal('sNaN')) raises ValueError, so ask Decimal directly.
        return value.is_nan()
    try:
        return math.isnan(value)
    except OverflowError:
        # An int or Fraction too large for a float is finite, not a NAN.
        return False


def _approx_isinf(value):
    """Return True if the number value is an infinity of either sign."""
    if isinstance(value, Decimal):
        # A huge finite Decimal converts to float inf; check it exactly.
        return value.is_infinite()
    try:
        return math.isinf(value)
    except OverflowError:
        # An int or Fraction too large for a float is still finite.
        return False


def approx_equal(x, y, tol=1e-12, rel=1e-7):
    """approx_equal(x, y [, tol [, rel]]) => True|False

    Return True if numbers x and y are approximately equal, to within some
    margin of error, otherwise return False. Numbers which compare equal
    will also compare approximately equal.

    x is approximately equal to y if the difference between them is less
    than or equal to an absolute error tol or a relative error rel,
    whichever is bigger.

    If given, both tol and rel must be finite, non-negative numbers. If not
    given, default values are tol=1e-12 and rel=1e-7.

    >>> approx_equal(1.2589, 1.2587, tol=0.0003, rel=0)
    True
    >>> approx_equal(1.2589, 1.2587, tol=0.0001, rel=0)
    False

    Absolute error is defined as abs(x-y); if that is less than or equal to
    tol, x and y are considered approximately equal.

    Relative error is defined as abs((x-y)/x) or abs((x-y)/y), whichever is
    smaller, provided x or y are not zero. If that figure is less than or
    equal to rel, x and y are considered approximately equal.

    Complex numbers are not directly supported. If you wish to compare to
    complex numbers, extract their real and imaginary parts and compare them
    individually.

    NANs (including Decimal signalling NANs) always compare unequal, even
    with themselves. Infinities compare approximately equal if they have the
    same sign (both positive or both negative). Infinities with different
    signs compare unequal; so do comparisons of infinities with finite
    numbers. Ints, Fractions and Decimals too large to convert to float are
    treated as finite.

    Raises ValueError if tol or rel is negative.
    """
    if tol < 0 or rel < 0:
        raise ValueError('error tolerances must be non-negative')
    # NANs are never equal to anything, approximately or otherwise.
    if _approx_isnan(x) or _approx_isnan(y):
        return False
    # Numbers which compare equal also compare approximately equal.
    if x == y:
        # This includes the case of two infinities with the same sign.
        return True
    if _approx_isinf(x) or _approx_isinf(y):
        # This includes the case of two infinities of opposite sign, or
        # one infinity and one finite number.
        return False
    # Two finite numbers. Comparing against rel scaled by the larger
    # magnitude is the same as using the smaller of the two relative errors.
    actual_error = abs(x - y)
    allowed_error = max(tol, rel*max(abs(x), abs(y)))
    return actual_error <= allowed_error
132
133
# This class exists only as somewhere to stick a docstring containing
# doctests. The following docstring and tests were originally in a separate
# module. Now that it has been merged in here, I need somewhere to hang the
# docstring. Ultimately, this class will die, and the information below will
# either become redundant, or be moved into more appropriate places.
139class _DoNothing:
140    """
141    When doing numeric work, especially with floats, exact equality is often
142    not what you want. Due to round-off error, it is often a bad idea to try
143    to compare floats with equality. Instead the usual procedure is to test
144    them with some (hopefully small!) allowance for error.
145
146    The ``approx_equal`` function allows you to specify either an absolute
147    error tolerance, or a relative error, or both.
148
149    Absolute error tolerances are simple, but you need to know the magnitude
150    of the quantities being compared:
151
152    >>> approx_equal(12.345, 12.346, tol=1e-3)
153    True
154    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3)  # tol is too small.
155    False
156
157    Relative errors are more suitable when the values you are comparing can
158    vary in magnitude:
159
160    >>> approx_equal(12.345, 12.346, rel=1e-4)
161    True
162    >>> approx_equal(12.345e6, 12.346e6, rel=1e-4)
163    True
164
165    but a naive implementation of relative error testing can run into trouble
166    around zero.
167
168    If you supply both an absolute tolerance and a relative error, the
169    comparison succeeds if either individual test succeeds:
170
171    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3, rel=1e-4)
172    True
173
174    """
175    pass
176
177
178
# For numeric values that may not be exactly equal we prefer
# NumericTestCase.assertApproxEqual over TestCase.assertAlmostEqual, which
# offers neither a relative tolerance nor element-by-element sequence checks.
181
182py_statistics = support.import_fresh_module('statistics', blocked=['_statistics'])
183c_statistics = support.import_fresh_module('statistics', fresh=['_statistics'])
184
185
class TestModules(unittest.TestCase):
    # Functions expected in both the pure-Python and the C-accelerated
    # copies of the module.
    func_names = ['_normal_dist_inv_cdf']

    def test_py_functions(self):
        # With _statistics blocked, each function comes from statistics.
        for name in self.func_names:
            func = getattr(py_statistics, name)
            self.assertEqual(func.__module__, 'statistics')

    @unittest.skipUnless(c_statistics, 'requires _statistics')
    def test_c_functions(self):
        # With _statistics available, each function comes from the C module.
        for name in self.func_names:
            func = getattr(c_statistics, name)
            self.assertEqual(func.__module__, '_statistics')
197
198
class NumericTestCase(unittest.TestCase):
    """Unit test class for numeric work.

    A TestCase subclass which, alongside the inherited
    ``TestCase.assertAlmostEqual``, provides ``assertApproxEqual`` for
    comparing numbers, or sequences of numbers, within an absolute and/or
    relative error tolerance.
    """
    # Defaults used when assertApproxEqual is called without tol/rel.
    # Zero for both means exact equality is expected unless overridden.
    tol = rel = 0

    def assertApproxEqual(
            self, first, second, tol=None, rel=None, msg=None
            ):
        """Test passes if ``first`` and ``second`` are approximately equal.

        The check succeeds when the two agree to within ``tol``, an
        absolute error, or ``rel``, a relative error.

        A ``tol`` or ``rel`` that is None (or not given) is taken from the
        test attribute of the same name, which is 0 by default.

        When both objects are sequences they are compared element by
        element; otherwise they are compared as single numbers.

        >>> class MyTest(NumericTestCase):
        ...     def test_number(self):
        ...         x = 1.0/6
        ...         y = sum([x]*6)
        ...         self.assertApproxEqual(y, 1.0, tol=1e-15)
        ...     def test_sequence(self):
        ...         a = [1.001, 1.001e-10, 1.001e10]
        ...         b = [1.0, 1e-10, 1e10]
        ...         self.assertApproxEqual(a, b, rel=1e-3)
        ...
        >>> import unittest
        >>> from io import StringIO  # Suppress test runner output.
        >>> suite = unittest.TestLoader().loadTestsFromTestCase(MyTest)
        >>> unittest.TextTestRunner(stream=StringIO()).run(suite)
        <unittest.runner.TextTestResult run=2 errors=0 failures=0>

        """
        tol = self.tol if tol is None else tol
        rel = self.rel if rel is None else rel
        both_sequences = all(
            isinstance(obj, collections.abc.Sequence)
            for obj in (first, second)
        )
        if both_sequences:
            self._check_approx_seq(first, second, tol, rel, msg)
        else:
            self._check_approx_num(first, second, tol, rel, msg)

    def _check_approx_seq(self, first, second, tol, rel, msg):
        # Sequences of different lengths fail immediately; otherwise each
        # pair of items is checked in order, reporting its index on failure.
        n_first, n_second = len(first), len(second)
        if n_first != n_second:
            standardMsg = (
                "sequences differ in length: %d items != %d items"
                % (n_first, n_second)
                )
            raise self.failureException(self._formatMessage(msg, standardMsg))
        for index, (actual, expected) in enumerate(zip(first, second)):
            self._check_approx_num(actual, expected, tol, rel, msg, index)

    def _check_approx_num(self, first, second, tol, rel, msg, idx=None):
        # Raise failureException unless the two numbers are approx equal.
        if not approx_equal(first, second, tol, rel):
            standardMsg = self._make_std_err_msg(first, second, tol, rel, idx)
            raise self.failureException(self._formatMessage(msg, standardMsg))
        return None

    @staticmethod
    def _make_std_err_msg(first, second, tol, rel, idx):
        # Build the standard failure message for approx_equal mismatches,
        # prefixed with the failing index when comparing sequences.
        assert first != second
        template = (
            '  %r != %r\n'
            '  values differ by more than tol=%r and rel=%r\n'
            '  -> absolute error = %r\n'
            '  -> relative error = %r'
            )
        prefix = ''
        if idx is not None:
            prefix = 'numeric sequences first differ at index %d.\n' % idx
        abs_err, rel_err = _calc_errors(first, second)
        return (prefix + template) % (first, second, tol, rel, abs_err, rel_err)
288
289
290# ========================
291# === Test the helpers ===
292# ========================
293
294class TestSign(unittest.TestCase):
295    """Test that the helper function sign() works correctly."""
296    def testZeroes(self):
297        # Test that signed zeroes report their sign correctly.
298        self.assertEqual(sign(0.0), +1)
299        self.assertEqual(sign(-0.0), -1)
300
301
302# --- Tests for approx_equal ---
303
class ApproxEqualSymmetryTest(unittest.TestCase):
    # Test symmetry of approx_equal.

    def test_relative_symmetry(self):
        # Check that approx_equal treats relative error symmetrically.
        # (a-b)/a is usually not equal to (a-b)/b. Ensure that this
        # doesn't matter.
        #
        #   Note: the reason for this test is that an early version
        #   of approx_equal was not symmetric. A relative error test
        #   would pass, or fail, depending on which value was passed
        #   as the first argument.
        #
        args1 = [2456, 37.8, -12.45, Decimal('2.54'), Fraction(17, 54)]
        args2 = [2459, 37.2, -12.41, Decimal('2.59'), Fraction(15, 54)]
        assert len(args1) == len(args2)
        for a, b in zip(args1, args2):
            self.do_relative_symmetry(a, b)

    def do_relative_symmetry(self, a, b):
        a, b = min(a, b), max(a, b)
        assert a < b
        delta = b - a  # The absolute difference between the values.
        rel_err1, rel_err2 = abs(delta/a), abs(delta/b)
        # Choose an error margin halfway between the two.
        rel = (rel_err1 + rel_err2)/2
        # Now see that values a and b compare approx equal regardless of
        # which is given first.
        self.assertTrue(approx_equal(a, b, tol=0, rel=rel))
        self.assertTrue(approx_equal(b, a, tol=0, rel=rel))

    def test_symmetry(self):
        # Test that approx_equal(a, b) == approx_equal(b, a)
        args = [-23, -2, 5, 107, 93568]
        delta = 2
        for a in args:
            for type_ in (int, float, Decimal, Fraction):
                x = type_(a)*100
                y = x + delta
                r = abs(delta/max(x, y))
                # There are five cases to check:
                # 1) actual error <= tol, <= rel
                self.do_symmetry_test(x, y, tol=delta, rel=r)
                self.do_symmetry_test(x, y, tol=delta+1, rel=2*r)
                # 2) actual error > tol, > rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r/2)
                # 3) actual error <= tol, > rel
                self.do_symmetry_test(x, y, tol=delta, rel=r/2)
                # 4) actual error > tol, <= rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r)
                self.do_symmetry_test(x, y, tol=delta-1, rel=2*r)
                # 5) exact equality test
                self.do_symmetry_test(x, x, tol=0, rel=0)
                self.do_symmetry_test(x, y, tol=0, rel=0)

    def do_symmetry_test(self, a, b, tol, rel):
        # Assert approx_equal gives the same answer in both argument orders.
        # The template uses printf-style %r, so it must be filled with the
        # % operator: str.format() ignores its arguments here and the failure
        # message would never show the offending values.
        template = "approx_equal comparisons don't match for %r"
        flag1 = approx_equal(a, b, tol, rel)
        flag2 = approx_equal(b, a, tol, rel)
        self.assertEqual(flag1, flag2, template % ((a, b, tol, rel),))
364
365
366class ApproxEqualExactTest(unittest.TestCase):
367    # Test the approx_equal function with exactly equal values.
368    # Equal values should compare as approximately equal.
369    # Test cases for exactly equal values, which should compare approx
370    # equal regardless of the error tolerances given.
371
372    def do_exactly_equal_test(self, x, tol, rel):
373        result = approx_equal(x, x, tol=tol, rel=rel)
374        self.assertTrue(result, 'equality failure for x=%r' % x)
375        result = approx_equal(-x, -x, tol=tol, rel=rel)
376        self.assertTrue(result, 'equality failure for x=%r' % -x)
377
378    def test_exactly_equal_ints(self):
379        # Test that equal int values are exactly equal.
380        for n in [42, 19740, 14974, 230, 1795, 700245, 36587]:
381            self.do_exactly_equal_test(n, 0, 0)
382
383    def test_exactly_equal_floats(self):
384        # Test that equal float values are exactly equal.
385        for x in [0.42, 1.9740, 1497.4, 23.0, 179.5, 70.0245, 36.587]:
386            self.do_exactly_equal_test(x, 0, 0)
387
388    def test_exactly_equal_fractions(self):
389        # Test that equal Fraction values are exactly equal.
390        F = Fraction
391        for f in [F(1, 2), F(0), F(5, 3), F(9, 7), F(35, 36), F(3, 7)]:
392            self.do_exactly_equal_test(f, 0, 0)
393
394    def test_exactly_equal_decimals(self):
395        # Test that equal Decimal values are exactly equal.
396        D = Decimal
397        for d in map(D, "8.2 31.274 912.04 16.745 1.2047".split()):
398            self.do_exactly_equal_test(d, 0, 0)
399
400    def test_exactly_equal_absolute(self):
401        # Test that equal values are exactly equal with an absolute error.
402        for n in [16, 1013, 1372, 1198, 971, 4]:
403            # Test as ints.
404            self.do_exactly_equal_test(n, 0.01, 0)
405            # Test as floats.
406            self.do_exactly_equal_test(n/10, 0.01, 0)
407            # Test as Fractions.
408            f = Fraction(n, 1234)
409            self.do_exactly_equal_test(f, 0.01, 0)
410
411    def test_exactly_equal_absolute_decimals(self):
412        # Test equal Decimal values are exactly equal with an absolute error.
413        self.do_exactly_equal_test(Decimal("3.571"), Decimal("0.01"), 0)
414        self.do_exactly_equal_test(-Decimal("81.3971"), Decimal("0.01"), 0)
415
416    def test_exactly_equal_relative(self):
417        # Test that equal values are exactly equal with a relative error.
418        for x in [8347, 101.3, -7910.28, Fraction(5, 21)]:
419            self.do_exactly_equal_test(x, 0, 0.01)
420        self.do_exactly_equal_test(Decimal("11.68"), 0, Decimal("0.01"))
421
422    def test_exactly_equal_both(self):
423        # Test that equal values are equal when both tol and rel are given.
424        for x in [41017, 16.742, -813.02, Fraction(3, 8)]:
425            self.do_exactly_equal_test(x, 0.1, 0.01)
426        D = Decimal
427        self.do_exactly_equal_test(D("7.2"), D("0.1"), D("0.01"))
428
429
430class ApproxEqualUnequalTest(unittest.TestCase):
431    # Unequal values should compare unequal with zero error tolerances.
432    # Test cases for unequal values, with exact equality test.
433
434    def do_exactly_unequal_test(self, x):
435        for a in (x, -x):
436            result = approx_equal(a, a+1, tol=0, rel=0)
437            self.assertFalse(result, 'inequality failure for x=%r' % a)
438
439    def test_exactly_unequal_ints(self):
440        # Test unequal int values are unequal with zero error tolerance.
441        for n in [951, 572305, 478, 917, 17240]:
442            self.do_exactly_unequal_test(n)
443
444    def test_exactly_unequal_floats(self):
445        # Test unequal float values are unequal with zero error tolerance.
446        for x in [9.51, 5723.05, 47.8, 9.17, 17.24]:
447            self.do_exactly_unequal_test(x)
448
449    def test_exactly_unequal_fractions(self):
450        # Test that unequal Fractions are unequal with zero error tolerance.
451        F = Fraction
452        for f in [F(1, 5), F(7, 9), F(12, 11), F(101, 99023)]:
453            self.do_exactly_unequal_test(f)
454
455    def test_exactly_unequal_decimals(self):
456        # Test that unequal Decimals are unequal with zero error tolerance.
457        for d in map(Decimal, "3.1415 298.12 3.47 18.996 0.00245".split()):
458            self.do_exactly_unequal_test(d)
459
460
461class ApproxEqualInexactTest(unittest.TestCase):
462    # Inexact test cases for approx_error.
463    # Test cases when comparing two values that are not exactly equal.
464
465    # === Absolute error tests ===
466
467    def do_approx_equal_abs_test(self, x, delta):
468        template = "Test failure for x={!r}, y={!r}"
469        for y in (x + delta, x - delta):
470            msg = template.format(x, y)
471            self.assertTrue(approx_equal(x, y, tol=2*delta, rel=0), msg)
472            self.assertFalse(approx_equal(x, y, tol=delta/2, rel=0), msg)
473
474    def test_approx_equal_absolute_ints(self):
475        # Test approximate equality of ints with an absolute error.
476        for n in [-10737, -1975, -7, -2, 0, 1, 9, 37, 423, 9874, 23789110]:
477            self.do_approx_equal_abs_test(n, 10)
478            self.do_approx_equal_abs_test(n, 2)
479
480    def test_approx_equal_absolute_floats(self):
481        # Test approximate equality of floats with an absolute error.
482        for x in [-284.126, -97.1, -3.4, -2.15, 0.5, 1.0, 7.8, 4.23, 3817.4]:
483            self.do_approx_equal_abs_test(x, 1.5)
484            self.do_approx_equal_abs_test(x, 0.01)
485            self.do_approx_equal_abs_test(x, 0.0001)
486
487    def test_approx_equal_absolute_fractions(self):
488        # Test approximate equality of Fractions with an absolute error.
489        delta = Fraction(1, 29)
490        numerators = [-84, -15, -2, -1, 0, 1, 5, 17, 23, 34, 71]
491        for f in (Fraction(n, 29) for n in numerators):
492            self.do_approx_equal_abs_test(f, delta)
493            self.do_approx_equal_abs_test(f, float(delta))
494
495    def test_approx_equal_absolute_decimals(self):
496        # Test approximate equality of Decimals with an absolute error.
497        delta = Decimal("0.01")
498        for d in map(Decimal, "1.0 3.5 36.08 61.79 7912.3648".split()):
499            self.do_approx_equal_abs_test(d, delta)
500            self.do_approx_equal_abs_test(-d, delta)
501
502    def test_cross_zero(self):
503        # Test for the case of the two values having opposite signs.
504        self.assertTrue(approx_equal(1e-5, -1e-5, tol=1e-4, rel=0))
505
506    # === Relative error tests ===
507
508    def do_approx_equal_rel_test(self, x, delta):
509        template = "Test failure for x={!r}, y={!r}"
510        for y in (x*(1+delta), x*(1-delta)):
511            msg = template.format(x, y)
512            self.assertTrue(approx_equal(x, y, tol=0, rel=2*delta), msg)
513            self.assertFalse(approx_equal(x, y, tol=0, rel=delta/2), msg)
514
515    def test_approx_equal_relative_ints(self):
516        # Test approximate equality of ints with a relative error.
517        self.assertTrue(approx_equal(64, 47, tol=0, rel=0.36))
518        self.assertTrue(approx_equal(64, 47, tol=0, rel=0.37))
519        # ---
520        self.assertTrue(approx_equal(449, 512, tol=0, rel=0.125))
521        self.assertTrue(approx_equal(448, 512, tol=0, rel=0.125))
522        self.assertFalse(approx_equal(447, 512, tol=0, rel=0.125))
523
524    def test_approx_equal_relative_floats(self):
525        # Test approximate equality of floats with a relative error.
526        for x in [-178.34, -0.1, 0.1, 1.0, 36.97, 2847.136, 9145.074]:
527            self.do_approx_equal_rel_test(x, 0.02)
528            self.do_approx_equal_rel_test(x, 0.0001)
529
530    def test_approx_equal_relative_fractions(self):
531        # Test approximate equality of Fractions with a relative error.
532        F = Fraction
533        delta = Fraction(3, 8)
534        for f in [F(3, 84), F(17, 30), F(49, 50), F(92, 85)]:
535            for d in (delta, float(delta)):
536                self.do_approx_equal_rel_test(f, d)
537                self.do_approx_equal_rel_test(-f, d)
538
539    def test_approx_equal_relative_decimals(self):
540        # Test approximate equality of Decimals with a relative error.
541        for d in map(Decimal, "0.02 1.0 5.7 13.67 94.138 91027.9321".split()):
542            self.do_approx_equal_rel_test(d, Decimal("0.001"))
543            self.do_approx_equal_rel_test(-d, Decimal("0.05"))
544
545    # === Both absolute and relative error tests ===
546
547    # There are four cases to consider:
548    #   1) actual error <= both absolute and relative error
549    #   2) actual error <= absolute error but > relative error
550    #   3) actual error <= relative error but > absolute error
551    #   4) actual error > both absolute and relative error
552
553    def do_check_both(self, a, b, tol, rel, tol_flag, rel_flag):
554        check = self.assertTrue if tol_flag else self.assertFalse
555        check(approx_equal(a, b, tol=tol, rel=0))
556        check = self.assertTrue if rel_flag else self.assertFalse
557        check(approx_equal(a, b, tol=0, rel=rel))
558        check = self.assertTrue if (tol_flag or rel_flag) else self.assertFalse
559        check(approx_equal(a, b, tol=tol, rel=rel))
560
561    def test_approx_equal_both1(self):
562        # Test actual error <= both absolute and relative error.
563        self.do_check_both(7.955, 7.952, 0.004, 3.8e-4, True, True)
564        self.do_check_both(-7.387, -7.386, 0.002, 0.0002, True, True)
565
566    def test_approx_equal_both2(self):
567        # Test actual error <= absolute error but > relative error.
568        self.do_check_both(7.955, 7.952, 0.004, 3.7e-4, True, False)
569
570    def test_approx_equal_both3(self):
571        # Test actual error <= relative error but > absolute error.
572        self.do_check_both(7.955, 7.952, 0.001, 3.8e-4, False, True)
573
574    def test_approx_equal_both4(self):
575        # Test actual error > both absolute and relative error.
576        self.do_check_both(2.78, 2.75, 0.01, 0.001, False, False)
577        self.do_check_both(971.44, 971.47, 0.02, 3e-5, False, False)
578
579
580class ApproxEqualSpecialsTest(unittest.TestCase):
581    # Test approx_equal with NANs and INFs and zeroes.
582
583    def test_inf(self):
584        for type_ in (float, Decimal):
585            inf = type_('inf')
586            self.assertTrue(approx_equal(inf, inf))
587            self.assertTrue(approx_equal(inf, inf, 0, 0))
588            self.assertTrue(approx_equal(inf, inf, 1, 0.01))
589            self.assertTrue(approx_equal(-inf, -inf))
590            self.assertFalse(approx_equal(inf, -inf))
591            self.assertFalse(approx_equal(inf, 1000))
592
593    def test_nan(self):
594        for type_ in (float, Decimal):
595            nan = type_('nan')
596            for other in (nan, type_('inf'), 1000):
597                self.assertFalse(approx_equal(nan, other))
598
599    def test_float_zeroes(self):
600        nzero = math.copysign(0.0, -1)
601        self.assertTrue(approx_equal(nzero, 0.0, tol=0.1, rel=0.1))
602
603    def test_decimal_zeroes(self):
604        nzero = Decimal("-0.0")
605        self.assertTrue(approx_equal(nzero, Decimal(0), tol=0.1, rel=0.1))
606
607
608class TestApproxEqualErrors(unittest.TestCase):
609    # Test error conditions of approx_equal.
610
611    def test_bad_tol(self):
612        # Test negative tol raises.
613        self.assertRaises(ValueError, approx_equal, 100, 100, -1, 0.1)
614
615    def test_bad_rel(self):
616        # Test negative rel raises.
617        self.assertRaises(ValueError, approx_equal, 100, 100, 1, -0.1)
618
619
620# --- Tests for NumericTestCase ---
621
622# The formatting routine that generates the error messages is complex enough
623# that it too needs testing.
624
625class TestNumericTestCase(unittest.TestCase):
626    # The exact wording of NumericTestCase error messages is *not* guaranteed,
627    # but we need to give them some sort of test to ensure that they are
628    # generated correctly. As a compromise, we look for specific substrings
629    # that are expected to be found even if the overall error message changes.
630
631    def do_test(self, args):
632        actual_msg = NumericTestCase._make_std_err_msg(*args)
633        expected = self.generate_substrings(*args)
634        for substring in expected:
635            self.assertIn(substring, actual_msg)
636
637    def test_numerictestcase_is_testcase(self):
638        # Ensure that NumericTestCase actually is a TestCase.
639        self.assertTrue(issubclass(NumericTestCase, unittest.TestCase))
640
641    def test_error_msg_numeric(self):
642        # Test the error message generated for numeric comparisons.
643        args = (2.5, 4.0, 0.5, 0.25, None)
644        self.do_test(args)
645
646    def test_error_msg_sequence(self):
647        # Test the error message generated for sequence comparisons.
648        args = (3.75, 8.25, 1.25, 0.5, 7)
649        self.do_test(args)
650
651    def generate_substrings(self, first, second, tol, rel, idx):
652        """Return substrings we expect to see in error messages."""
653        abs_err, rel_err = _calc_errors(first, second)
654        substrings = [
655                'tol=%r' % tol,
656                'rel=%r' % rel,
657                'absolute error = %r' % abs_err,
658                'relative error = %r' % rel_err,
659                ]
660        if idx is not None:
661            substrings.append('differ at index %d' % idx)
662        return substrings
663
664
665# =======================================
666# === Tests for the statistics module ===
667# =======================================
668
669
class GlobalsTest(unittest.TestCase):
    module = statistics
    expected_metadata = ["__doc__", "__all__"]

    def test_meta(self):
        # Every expected piece of module metadata must be present.
        for attr in self.expected_metadata:
            present = hasattr(self.module, attr)
            self.assertTrue(present, "%s not present" % attr)

    def test_check_all(self):
        # Everything named in __all__ must be public and must exist.
        mod = self.module
        for public_name in mod.__all__:
            self.assertFalse(public_name.startswith("_"),
                             'private name "%s" in __all__' % public_name)
            self.assertTrue(hasattr(mod, public_name),
                            'missing name "%s" in __all__' % public_name)
690
691
class DocTests(unittest.TestCase):
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -OO and above")
    def test_doc_tests(self):
        # Run the statistics module's own doctests; at least one must run
        # and none may fail.
        results = doctest.testmod(statistics, optionflags=doctest.ELLIPSIS)
        failures, attempts = results
        self.assertGreater(attempts, 0)
        self.assertEqual(failures, 0)
699
class StatisticsErrorTest(unittest.TestCase):
    def test_has_exception(self):
        # statistics.StatisticsError must exist and derive from ValueError.
        errmsg = (
                "Expected StatisticsError to be a ValueError, but got a"
                " subclass of %r instead."
                )
        self.assertTrue(hasattr(statistics, 'StatisticsError'))
        exc_class = statistics.StatisticsError
        failure_text = errmsg % exc_class.__base__
        self.assertTrue(issubclass(exc_class, ValueError), failure_text)
711
712
713# === Tests for private utility functions ===
714
715class ExactRatioTest(unittest.TestCase):
716    # Test _exact_ratio utility.
717
718    def test_int(self):
719        for i in (-20, -3, 0, 5, 99, 10**20):
720            self.assertEqual(statistics._exact_ratio(i), (i, 1))
721
722    def test_fraction(self):
723        numerators = (-5, 1, 12, 38)
724        for n in numerators:
725            f = Fraction(n, 37)
726            self.assertEqual(statistics._exact_ratio(f), (n, 37))
727
728    def test_float(self):
729        self.assertEqual(statistics._exact_ratio(0.125), (1, 8))
730        self.assertEqual(statistics._exact_ratio(1.125), (9, 8))
731        data = [random.uniform(-100, 100) for _ in range(100)]
732        for x in data:
733            num, den = statistics._exact_ratio(x)
734            self.assertEqual(x, num/den)
735
736    def test_decimal(self):
737        D = Decimal
738        _exact_ratio = statistics._exact_ratio
739        self.assertEqual(_exact_ratio(D("0.125")), (1, 8))
740        self.assertEqual(_exact_ratio(D("12.345")), (2469, 200))
741        self.assertEqual(_exact_ratio(D("-1.98")), (-99, 50))
742
743    def test_inf(self):
744        INF = float("INF")
745        class MyFloat(float):
746            pass
747        class MyDecimal(Decimal):
748            pass
749        for inf in (INF, -INF):
750            for type_ in (float, MyFloat, Decimal, MyDecimal):
751                x = type_(inf)
752                ratio = statistics._exact_ratio(x)
753                self.assertEqual(ratio, (x, None))
754                self.assertEqual(type(ratio[0]), type_)
755                self.assertTrue(math.isinf(ratio[0]))
756
757    def test_float_nan(self):
758        NAN = float("NAN")
759        class MyFloat(float):
760            pass
761        for nan in (NAN, MyFloat(NAN)):
762            ratio = statistics._exact_ratio(nan)
763            self.assertTrue(math.isnan(ratio[0]))
764            self.assertIs(ratio[1], None)
765            self.assertEqual(type(ratio[0]), type(nan))
766
767    def test_decimal_nan(self):
768        NAN = Decimal("NAN")
769        sNAN = Decimal("sNAN")
770        class MyDecimal(Decimal):
771            pass
772        for nan in (NAN, MyDecimal(NAN), sNAN, MyDecimal(sNAN)):
773            ratio = statistics._exact_ratio(nan)
774            self.assertTrue(_nan_equal(ratio[0], nan))
775            self.assertIs(ratio[1], None)
776            self.assertEqual(type(ratio[0]), type(nan))
777
778
779class DecimalToRatioTest(unittest.TestCase):
780    # Test _exact_ratio private function.
781
782    def test_infinity(self):
783        # Test that INFs are handled correctly.
784        inf = Decimal('INF')
785        self.assertEqual(statistics._exact_ratio(inf), (inf, None))
786        self.assertEqual(statistics._exact_ratio(-inf), (-inf, None))
787
788    def test_nan(self):
789        # Test that NANs are handled correctly.
790        for nan in (Decimal('NAN'), Decimal('sNAN')):
791            num, den = statistics._exact_ratio(nan)
792            # Because NANs always compare non-equal, we cannot use assertEqual.
793            # Nor can we use an identity test, as we don't guarantee anything
794            # about the object identity.
795            self.assertTrue(_nan_equal(num, nan))
796            self.assertIs(den, None)
797
798    def test_sign(self):
799        # Test sign is calculated correctly.
800        numbers = [Decimal("9.8765e12"), Decimal("9.8765e-12")]
801        for d in numbers:
802            # First test positive decimals.
803            assert d > 0
804            num, den = statistics._exact_ratio(d)
805            self.assertGreaterEqual(num, 0)
806            self.assertGreater(den, 0)
807            # Then test negative decimals.
808            num, den = statistics._exact_ratio(-d)
809            self.assertLessEqual(num, 0)
810            self.assertGreater(den, 0)
811
812    def test_negative_exponent(self):
813        # Test result when the exponent is negative.
814        t = statistics._exact_ratio(Decimal("0.1234"))
815        self.assertEqual(t, (617, 5000))
816
817    def test_positive_exponent(self):
818        # Test results when the exponent is positive.
819        t = statistics._exact_ratio(Decimal("1.234e7"))
820        self.assertEqual(t, (12340000, 1))
821
822    def test_regression_20536(self):
823        # Regression test for issue 20536.
824        # See http://bugs.python.org/issue20536
825        t = statistics._exact_ratio(Decimal("1e2"))
826        self.assertEqual(t, (100, 1))
827        t = statistics._exact_ratio(Decimal("1.47e5"))
828        self.assertEqual(t, (147000, 1))
829
830
class IsFiniteTest(unittest.TestCase):
    """Tests for the private _isfinite helper."""

    def test_finite(self):
        # Ordinary values of every numeric type count as finite.
        is_finite = statistics._isfinite
        finite_values = [5, Fraction(1, 3), 2.5, Decimal("5.5")]
        for value in finite_values:
            self.assertTrue(is_finite(value))

    def test_infinity(self):
        # Float and Decimal infinities are not finite.
        is_finite = statistics._isfinite
        for value in [float("inf"), Decimal("inf")]:
            self.assertFalse(is_finite(value))

    def test_nan(self):
        # Float NAN and both kinds of Decimal NAN are not finite.
        is_finite = statistics._isfinite
        for value in [float("nan"), Decimal("NAN"), Decimal("sNAN")]:
            self.assertFalse(is_finite(value))
848
849
class CoerceTest(unittest.TestCase):
    # Test that private function _coerce correctly deals with types.

    # The coercion rules are currently an implementation detail, although at
    # some point that should change. The tests and comments here define the
    # correct implementation.

    # Pre-conditions of _coerce:
    #
    #   - coerce(T, S) will never be called with bool as the first argument;
    #     this is a pre-condition, guarded with an assertion.
    #
    #   - coerce(T, T) will always return T; we assume T is a valid numeric
    #     type. Violate this assumption at your own risk.
    #
    #   - Apart from as above, bool is treated as if it were actually int.
    #
    #   - coerce(int, X) and coerce(X, int) return X.

    def test_bool(self):
        # bool is somewhat special, due to the pre-condition that it is
        # never given as the first argument to _coerce, and that it cannot
        # be subclassed. So we test it specially.
        for T in (int, float, Fraction, Decimal):
            self.assertIs(statistics._coerce(T, bool), T)
            class MyClass(T): pass
            self.assertIs(statistics._coerce(MyClass, bool), MyClass)

    def assertCoerceTo(self, A, B):
        """Assert that type A coerces to B."""
        self.assertIs(statistics._coerce(A, B), B)
        self.assertIs(statistics._coerce(B, A), B)

    def check_coerce_to(self, A, B):
        """Checks that type A coerces to B, including subclasses."""
        # Assert that type A is coerced to B.
        self.assertCoerceTo(A, B)
        # Subclasses of A are also coerced to B.
        class SubclassOfA(A): pass
        self.assertCoerceTo(SubclassOfA, B)
        # A, and subclasses of A, are coerced to subclasses of B.
        class SubclassOfB(B): pass
        self.assertCoerceTo(A, SubclassOfB)
        self.assertCoerceTo(SubclassOfA, SubclassOfB)

    def assertCoerceRaises(self, A, B):
        """Assert that coercing A to B, or vice versa, raises TypeError."""
        # The two types must be passed as separate arguments. Passing them
        # as one tuple calls _coerce with a single argument, which raises
        # TypeError for the wrong arity and makes this check always pass.
        self.assertRaises(TypeError, statistics._coerce, A, B)
        self.assertRaises(TypeError, statistics._coerce, B, A)

    def check_type_coercions(self, T):
        """Check that type T coerces correctly with subclasses of itself."""
        assert T is not bool
        # Coercing a type with itself returns the same type.
        self.assertIs(statistics._coerce(T, T), T)
        # Coercing a type with a subclass of itself returns the subclass.
        class U(T): pass
        class V(T): pass
        class W(U): pass
        for typ in (U, V, W):
            self.assertCoerceTo(T, typ)
        self.assertCoerceTo(U, W)
        # Coercing two subclasses that aren't parent/child is an error.
        # int is exempt: per the pre-conditions above, int coerces to the
        # other type rather than raising, so no TypeError is expected for
        # unrelated int subclasses.
        if T is not int:
            self.assertCoerceRaises(U, V)
            self.assertCoerceRaises(V, W)

    def test_int(self):
        # Check that int coerces correctly.
        self.check_type_coercions(int)
        for typ in (float, Fraction, Decimal):
            self.check_coerce_to(int, typ)

    def test_fraction(self):
        # Check that Fraction coerces correctly.
        self.check_type_coercions(Fraction)
        self.check_coerce_to(Fraction, float)

    def test_decimal(self):
        # Check that Decimal coerces correctly.
        self.check_type_coercions(Decimal)

    def test_float(self):
        # Check that float coerces correctly.
        self.check_type_coercions(float)

    def test_non_numeric_types(self):
        # int is deliberately absent from the good types. Per the
        # pre-conditions above, coerce(int, X) returns X for any X, so
        # pairing int with a non-numeric type does not raise.
        for bad_type in (str, list, type(None), tuple, dict):
            for good_type in (float, Fraction, Decimal):
                self.assertCoerceRaises(good_type, bad_type)

    def test_incompatible_types(self):
        # Test that incompatible types raise.
        for T in (float, Fraction):
            class MySubclass(T): pass
            self.assertCoerceRaises(T, Decimal)
            self.assertCoerceRaises(MySubclass, Decimal)
948
949
950class ConvertTest(unittest.TestCase):
951    # Test private _convert function.
952
953    def check_exact_equal(self, x, y):
954        """Check that x equals y, and has the same type as well."""
955        self.assertEqual(x, y)
956        self.assertIs(type(x), type(y))
957
958    def test_int(self):
959        # Test conversions to int.
960        x = statistics._convert(Fraction(71), int)
961        self.check_exact_equal(x, 71)
962        class MyInt(int): pass
963        x = statistics._convert(Fraction(17), MyInt)
964        self.check_exact_equal(x, MyInt(17))
965
966    def test_fraction(self):
967        # Test conversions to Fraction.
968        x = statistics._convert(Fraction(95, 99), Fraction)
969        self.check_exact_equal(x, Fraction(95, 99))
970        class MyFraction(Fraction):
971            def __truediv__(self, other):
972                return self.__class__(super().__truediv__(other))
973        x = statistics._convert(Fraction(71, 13), MyFraction)
974        self.check_exact_equal(x, MyFraction(71, 13))
975
976    def test_float(self):
977        # Test conversions to float.
978        x = statistics._convert(Fraction(-1, 2), float)
979        self.check_exact_equal(x, -0.5)
980        class MyFloat(float):
981            def __truediv__(self, other):
982                return self.__class__(super().__truediv__(other))
983        x = statistics._convert(Fraction(9, 8), MyFloat)
984        self.check_exact_equal(x, MyFloat(1.125))
985
986    def test_decimal(self):
987        # Test conversions to Decimal.
988        x = statistics._convert(Fraction(1, 40), Decimal)
989        self.check_exact_equal(x, Decimal("0.025"))
990        class MyDecimal(Decimal):
991            def __truediv__(self, other):
992                return self.__class__(super().__truediv__(other))
993        x = statistics._convert(Fraction(-15, 16), MyDecimal)
994        self.check_exact_equal(x, MyDecimal("-0.9375"))
995
996    def test_inf(self):
997        for INF in (float('inf'), Decimal('inf')):
998            for inf in (INF, -INF):
999                x = statistics._convert(inf, type(inf))
1000                self.check_exact_equal(x, inf)
1001
1002    def test_nan(self):
1003        for nan in (float('nan'), Decimal('NAN'), Decimal('sNAN')):
1004            x = statistics._convert(nan, type(nan))
1005            self.assertTrue(_nan_equal(x, nan))
1006
1007
class FailNegTest(unittest.TestCase):
    """Tests for the private _fail_neg helper.

    _fail_neg is iterated lazily; it passes values through and raises
    StatisticsError once it reaches a negative one.
    """

    def test_pass_through(self):
        # Non-negative values come out exactly as they went in.
        values = [1, 2.0, Fraction(3), Decimal(4)]
        self.assertEqual(values, list(statistics._fail_neg(values)))

    def test_negatives_raise(self):
        # A negative value of any numeric type raises when reached.
        for magnitude in [1, 2.0, Fraction(3), Decimal(4)]:
            checked = statistics._fail_neg([-magnitude])
            with self.assertRaises(statistics.StatisticsError):
                next(checked)

    def test_error_msg(self):
        # A caller-supplied message ends up as the exception's message.
        msg = "badness #%d" % random.randint(10000, 99999)
        try:
            next(statistics._fail_neg([-1], msg))
        except statistics.StatisticsError as exc:
            self.assertEqual(exc.args[0], msg)
        else:
            self.fail("expected exception, but it didn't happen")
1034
1035
1036# === Tests for public functions ===
1037
class UnivariateCommonMixin:
    """Shared tests for univariate functions whose first argument is data.

    The concrete test case sets ``self.func`` to the function under test
    (see the ``setUp`` methods of the test classes) and supplies the
    unittest assertion methods.
    """

    def test_no_args(self):
        # Calling with no arguments at all is a TypeError.
        self.assertRaises(TypeError, self.func)

    def test_empty_data(self):
        # Empty data, whatever the container, raises StatisticsError.
        empties = ([], (), iter([]))
        for empty in empties:
            self.assertRaises(statistics.StatisticsError, self.func, empty)

    def prepare_data(self):
        """Return the ints 0..9 as a list shuffled out of sorted order."""
        shuffled = list(range(10))
        while shuffled == sorted(shuffled):
            random.shuffle(shuffled)
        return shuffled

    def test_no_inplace_modifications(self):
        # The function must leave its input list untouched.
        data = self.prepare_data()
        # Sanity-check the premises: a one-element list is always sorted (so
        # prepare_data's shuffle loop could never end), and the data must be
        # unsorted for an in-place sort to be detectable.
        assert len(data) != 1
        assert data != sorted(data)
        snapshot = data[:]
        assert data is not snapshot
        self.func(data)
        self.assertListEqual(data, snapshot, "data has been modified")

    def test_order_doesnt_matter(self):
        # Shuffling the data points must not change the result.
        #
        # CAUTION: floating point rounding can make the result depend on
        # order, so this is an ideal rather than a law. Only exact values
        # (ints here) are used to keep the test reliable.
        data = [1, 2, 3, 3, 3, 4, 5, 6]*100
        before = self.func(data)
        random.shuffle(data)
        after = self.func(data)
        self.assertEqual(before, after)

    def test_type_of_data_collection(self):
        # The kind of iterable holding the data must not affect the result.
        class MyList(list):
            pass
        class MyTuple(tuple):
            pass
        def generator(items):
            yield from items
        data = self.prepare_data()
        baseline = self.func(data)
        wrappers = (list, tuple, iter, MyList, MyTuple, generator)
        for wrap in wrappers:
            self.assertEqual(self.func(wrap(data)), baseline)

    def test_range_data(self):
        # range objects are accepted and agree with the equivalent list.
        data = range(20, 50, 3)
        from_list = self.func(list(data))
        self.assertEqual(self.func(data), from_list)

    def test_bad_arg_types(self):
        # Non-iterable data must raise TypeError.
        #
        # Each case is a separate call on purpose: assertRaises does not
        # report which argument triggered a failure, so a loop would make
        # failures very hard to trace back to the offending input.
        self.check_for_type_error(None)
        self.check_for_type_error(23)
        self.check_for_type_error(42.0)
        self.check_for_type_error(object())

    def check_for_type_error(self, *args):
        """Assert that self.func(*args) raises TypeError."""
        self.assertRaises(TypeError, self.func, *args)

    def test_type_of_data_element(self):
        # The numeric type of the elements must not change the numeric
        # result. Weaker than UnivariateTypeMixin.test_types_conserved: the
        # result is compared by value only, not by type.
        class MyFloat(float):
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__

        raw = self.prepare_data()
        expected = self.func(raw)
        as_expected_type = type(expected)
        for kind in (float, MyFloat, Decimal, Fraction):
            converted = [kind(x) for x in raw]
            self.assertEqual(as_expected_type(self.func(converted)), expected)
1135
1136
class UnivariateTypeMixin:
    """Tests for functions whose result has the same type as their data.

    For example, the mean of a list of Fractions should itself be a
    Fraction. Only tests that depend on this type-preserving behaviour
    belong here; other type-related tests can live elsewhere.
    """
    def prepare_types_for_conservation_test(self):
        """Return the element types the function is expected to preserve."""
        class MyFloat(float):
            # Every arithmetic operator re-wraps its result, so MyFloat
            # values stay MyFloat throughout a computation.
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__
            def __sub__(self, other):
                return type(self)(super().__sub__(other))
            def __rsub__(self, other):
                return type(self)(super().__rsub__(other))
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))
            def __rtruediv__(self, other):
                return type(self)(super().__rtruediv__(other))
            def __pow__(self, other):
                return type(self)(super().__pow__(other))
        return (float, Decimal, Fraction, MyFloat)

    def test_types_conserved(self):
        # For homogeneous data, the result has the same type as the data
        # points. Only the type of the result is checked, not its value.
        data = self.prepare_data()
        for kind in self.prepare_types_for_conservation_test():
            result = self.func([kind(x) for x in data])
            self.assertIs(type(result), kind)
1174
1175
class TestSumCommon(UnivariateCommonMixin, UnivariateTypeMixin):
    # Common test cases for statistics._sum() function.
    #
    # This test suite looks only at the numeric value returned by _sum,
    # after conversion to the appropriate type.
    #
    # NOTE: this class does not inherit from unittest.TestCase, so the
    # unittest loader does not collect or run it. _sum accepts empty data
    # (see TestSum.test_empty_data), so the inherited test_empty_data would
    # not hold if this were turned into a TestCase as-is.
    def setUp(self):
        def simplified_sum(*args):
            T, value, _count = statistics._sum(*args)
            # _sum reports an exact total plus the type to express it in.
            # _convert(value, T) builds that result; _coerce takes two
            # *types* and is not meant to be given a numeric value.
            return statistics._convert(value, T)
        self.func = simplified_sum
1186
1187
class TestSum(NumericTestCase):
    # Tests for statistics._sum() that check the complete
    # (type, exact total, count) tuple it returns.

    def setUp(self):
        self.func = statistics._sum

    def test_empty_data(self):
        # Override test for empty data: summing nothing yields the start
        # value (default 0) and a count of zero.
        cases = [
            ((), (int, Fraction(0), 0)),
            ((23,), (int, Fraction(23), 0)),
            ((2.3,), (float, Fraction(2.3), 0)),
        ]
        for data in ([], (), iter([])):
            for extra_args, expected in cases:
                self.assertEqual(self.func(data, *extra_args), expected)

    def test_ints(self):
        summed = self.func([1, 5, 3, -4, -8, 20, 42, 1])
        self.assertEqual(summed, (int, Fraction(60), 8))
        summed = self.func([4, 2, 3, -8, 7], 1000)
        self.assertEqual(summed, (int, Fraction(1008), 5))

    def test_floats(self):
        summed = self.func([0.25]*20)
        self.assertEqual(summed, (float, Fraction(5.0), 20))
        summed = self.func([0.125, 0.25, 0.5, 0.75], 1.5)
        self.assertEqual(summed, (float, Fraction(3.125), 4))

    def test_fractions(self):
        summed = self.func([Fraction(1, 1000)]*500)
        self.assertEqual(summed, (Fraction, Fraction(1, 2), 500))

    def test_decimals(self):
        texts = ("0.001", "5.246", "1.702", "-0.025",
                 "3.974", "2.328", "4.617", "2.843")
        data = [Decimal(text) for text in texts]
        self.assertEqual(self.func(data), (Decimal, Decimal("20.686"), 8))

    def test_compare_with_math_fsum(self):
        # Compare with math.fsum. Ideally the results would be identical,
        # but they can differ by a very slight amount.
        data = [random.uniform(-100, 1000) for _ in range(1000)]
        total = float(self.func(data)[1])
        self.assertApproxEqual(total, math.fsum(data), rel=2e-16)

    def test_start_argument(self):
        # The optional start value is added exactly to the total.
        data = [random.uniform(1, 1000) for _ in range(100)]
        base = self.func(data)[1]
        for start, expected in ((42, base+42),
                                (-23, base-23),
                                (1e20, base+Fraction(1e20))):
            self.assertEqual(expected, self.func(data, start)[1])

    def test_strings_fail(self):
        # Strings cannot be summed, whether as start value or data point.
        with self.assertRaises(TypeError):
            self.func([1, 2, 3], '999')
        with self.assertRaises(TypeError):
            self.func([1, 2, 3, '999'])

    def test_bytes_fail(self):
        # Bytes cannot be summed, whether as start value or data point.
        with self.assertRaises(TypeError):
            self.func([1, 2, 3], b'999')
        with self.assertRaises(TypeError):
            self.func([1, 2, 3, b'999'])

    def test_mixed_sum(self):
        # Mixed input types are not (currently) allowed, neither among the
        # data points nor between the data and the start value.
        with self.assertRaises(TypeError):
            self.func([1, 2.0, Decimal(1)])
        with self.assertRaises(TypeError):
            self.func([1, 2.0], Decimal(1))
1258
1259
class SumTortureTest(NumericTestCase):
    def test_torture(self):
        # Tim Peters' torture test for sum, and variants of same: huge values
        # that cancel out must not swallow the small ones.
        expected = (float, Fraction(20000.0), 40000)
        for pattern in ([1, 1e100, 1, -1e100], [1e100, 1, 1, -1e100]):
            self.assertEqual(statistics._sum(pattern*10000), expected)
        kind, total, count = statistics._sum([1e-100, 1, 1e-100, -1]*10000)
        self.assertIs(kind, float)
        self.assertEqual(count, 40000)
        self.assertApproxEqual(float(total), 2.0e-96, rel=5e-16)
1271
1272
class SumSpecialValues(NumericTestCase):
    # Test that sum works correctly with IEEE-754 special values.

    def test_nan(self):
        # A NAN anywhere in the data makes the total a NAN of that type.
        for type_ in (float, Decimal):
            nan = type_('nan')
            result = statistics._sum([1, nan, 2])[1]
            self.assertIs(type(result), type_)
            self.assertTrue(math.isnan(result))

    def check_infinity(self, x, inf):
        """Check x is an infinity of the same type and sign as inf."""
        self.assertTrue(math.isinf(x))
        self.assertIs(type(x), type(inf))
        self.assertEqual(x > 0, inf > 0)
        # A unittest assertion rather than a bare assert, which would be
        # stripped when Python runs with -O.
        self.assertEqual(x, inf)

    def do_test_inf(self, inf):
        # Adding a single infinity gives infinity.
        result = statistics._sum([1, 2, inf, 3])[1]
        self.check_infinity(result, inf)
        # Adding two infinities of the same sign also gives infinity.
        result = statistics._sum([1, 2, inf, 3, inf, 4])[1]
        self.check_infinity(result, inf)

    def test_float_inf(self):
        inf = float('inf')
        # 'direction' avoids shadowing the module-level sign() helper.
        for direction in (+1, -1):
            self.do_test_inf(direction*inf)

    def test_decimal_inf(self):
        inf = Decimal('inf')
        for direction in (+1, -1):
            self.do_test_inf(direction*inf)

    def test_float_mismatched_infs(self):
        # Test that adding two infinities of opposite sign gives a NAN.
        inf = float('inf')
        result = statistics._sum([1, 2, inf, 3, -inf, 4])[1]
        self.assertTrue(math.isnan(result))

    def test_decimal_extendedcontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign returns NAN.
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.ExtendedContext):
            self.assertTrue(math.isnan(statistics._sum(data)[1]))

    def test_decimal_basiccontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign raises InvalidOperation.
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.BasicContext):
            self.assertRaises(decimal.InvalidOperation, statistics._sum, data)

    def test_decimal_snan_raises(self):
        # Adding sNAN should raise InvalidOperation.
        sNAN = Decimal('sNAN')
        data = [1, sNAN, 2]
        self.assertRaises(decimal.InvalidOperation, statistics._sum, data)
1333
1334
1335# === Tests for averages ===
1336
class AverageMixin(UnivariateCommonMixin):
    # Tests shared by every kind of average.

    def test_single_value(self):
        # The average of a one-element dataset is that element.
        singles = (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28'))
        for value in singles:
            self.assertEqual(self.func([value]), value)

    def prepare_values_for_repeated_single_test(self):
        """Return the values that test_repeated_single_value repeats."""
        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712'))

    def test_repeated_single_value(self):
        # Repeating one value any number of times leaves the average as
        # that value.
        values = self.prepare_values_for_repeated_single_test()
        for x in values:
            for count in (2, 5, 10, 20):
                with self.subTest(x=x, count=count):
                    self.assertEqual(self.func([x]*count), x)
1355
1356
class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    def setUp(self):
        self.func = statistics.mean

    def test_torture_pep(self):
        # "Torture Test" from PEP-450: huge values that cancel exactly.
        self.assertEqual(self.func([1e100, 1, 3, -1e100]), 1)

    def test_ints(self):
        # Mean of shuffled int data.
        data = [0, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 7, 7, 7, 8, 9]
        random.shuffle(data)
        self.assertEqual(self.func(data), 4.8125)

    def test_floats(self):
        # Mean of shuffled float data.
        data = [17.25, 19.75, 20.0, 21.5, 21.75, 23.25, 25.125, 27.5]
        random.shuffle(data)
        self.assertEqual(self.func(data), 22.015625)

    def test_decimals(self):
        # Mean of shuffled Decimal data.
        texts = ("1.634", "2.517", "3.912", "4.072", "5.813")
        data = [Decimal(text) for text in texts]
        random.shuffle(data)
        self.assertEqual(self.func(data), Decimal("3.5896"))

    def test_fractions(self):
        # Mean of shuffled Fractions 1/2, 2/3, ..., 7/8.
        data = [Fraction(k, k + 1) for k in range(1, 8)]
        random.shuffle(data)
        self.assertEqual(self.func(data), Fraction(1479, 1960))

    def test_inf(self):
        # A single infinity dominates the mean.
        raw = [1, 3, 5, 7, 9]  # ints mix with either kind without TypeError
        for kind in (float, Decimal):
            for direction in (1, -1):
                inf = kind("inf")*direction
                result = self.func(raw + [inf])
                self.assertTrue(math.isinf(result))
                self.assertEqual(result, inf)

    def test_mismatched_infs(self):
        # Infinities of opposite sign give a NAN mean.
        positive, negative = float('inf'), float('-inf')
        data = [2, 4, 6, positive, 1, 3, 5, negative]
        self.assertTrue(math.isnan(self.func(data)))

    def test_nan(self):
        # A NAN anywhere in the data gives a NAN mean.
        raw = [1, 3, 5, 7, 9]  # ints mix with either kind without TypeError
        for kind in (float, Decimal):
            nan = kind("nan")
            self.assertTrue(math.isnan(self.func(raw + [nan])))

    def test_big_data(self):
        # Shifting every data point by a large constant shifts the mean.
        c = 1e9
        data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
        expected = self.func(data) + c
        assert expected != c
        shifted = [x+c for x in data]
        self.assertEqual(self.func(shifted), expected)

    def test_doubled_data(self):
        # Mean of [a,b,c...z] should be same as for [a,a,b,b,c,c...z,z].
        data = [random.uniform(-3, 5) for _ in range(1000)]
        single = self.func(data)
        doubled = self.func(data*2)
        self.assertApproxEqual(doubled, single)

    def test_regression_20561(self):
        # Regression test for issue 20561.
        # See http://bugs.python.org/issue20561
        value = Decimal('1e4')
        self.assertEqual(statistics.mean([value]), value)

    def test_regression_25177(self):
        # Regression test for issue 25177: very big and very small floats
        # must not overflow. See http://bugs.python.org/issue25177.
        big = 8.98846567431158e+307
        tiny = 5e-324
        self.assertEqual(statistics.mean([8.988465674311579e+307, big]), big)
        for n in (2, 3, 5, 200):
            self.assertEqual(statistics.mean([big]*n), big)
            self.assertEqual(statistics.mean([tiny]*n), tiny)
1451
1452
class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    def setUp(self):
        self.func = statistics.harmonic_mean

    def prepare_data(self):
        # Override mixin method: drop the 0, since any zero makes the
        # harmonic mean zero (see test_zero).
        data = super().prepare_data()
        data.remove(0)
        return data

    def prepare_values_for_repeated_single_test(self):
        # Override mixin method.
        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.125'))

    def test_zero(self):
        # A zero data point makes the harmonic mean zero.
        self.assertEqual(self.func([1, 0, 2]), 0)

    def test_negative_error(self):
        # Negative data points raise StatisticsError.
        for values in ([-1], [1, -2, 3]):
            with self.subTest(values=values):
                with self.assertRaises(statistics.StatisticsError):
                    self.func(values)

    def test_ints(self):
        # Harmonic mean of shuffled ints.
        data = [2, 4, 4, 8, 16, 16]
        random.shuffle(data)
        self.assertEqual(self.func(data), 6*4/5)

    def test_floats_exact(self):
        # Harmonic mean of carefully chosen, exactly representable floats.
        data = [1/8, 1/4, 1/4, 1/2, 1/2]
        random.shuffle(data)
        self.assertEqual(self.func(data), 1/4)
        self.assertEqual(self.func([0.25, 0.5, 1.0, 1.0]), 0.5)

    def test_singleton_lists(self):
        # The harmonic mean of [x] is exactly x.
        for x in range(1, 101):
            self.assertEqual(self.func([x]), x)

    def test_decimals_exact(self):
        # Harmonic mean of carefully chosen Decimals.
        whole = [Decimal(n) for n in (15, 30, 60, 60)]
        self.assertEqual(self.func(whole), Decimal(30))
        shuffled_cases = [
            (("0.05", "0.10", "0.20", "0.20"), Decimal("0.10")),
            (("1.68", "0.32", "5.94", "2.75"), Decimal(66528)/70723),
        ]
        for texts, expected in shuffled_cases:
            data = [Decimal(text) for text in texts]
            random.shuffle(data)
            self.assertEqual(self.func(data), expected)

    def test_fractions(self):
        # Harmonic mean of shuffled Fractions 1/2, 2/3, ..., 7/8.
        data = [Fraction(k, k + 1) for k in range(1, 8)]
        random.shuffle(data)
        self.assertEqual(self.func(data), Fraction(7*420, 4029))

    def test_inf(self):
        # An infinite data point contributes nothing to the reciprocal sum.
        self.assertEqual(self.func([2.0, float('inf'), 1.0]), 2.0)

    def test_nan(self):
        # A NAN data point gives a NAN result.
        self.assertTrue(math.isnan(self.func([2.0, float('nan'), 1.0])))

    def test_multiply_data_points(self):
        # Scaling every data point by a constant scales the harmonic mean.
        c = 111
        data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
        expected = self.func(data)*c
        scaled = [x*c for x in data]
        self.assertEqual(self.func(scaled), expected)

    def test_doubled_data(self):
        # Harmonic mean of [a,b...z] should be same as for [a,a,b,b...z,z].
        data = [random.uniform(1, 5) for _ in range(1000)]
        single = self.func(data)
        doubled = self.func(data*2)
        self.assertApproxEqual(doubled, single)
1539
1540
class TestMedian(NumericTestCase, AverageMixin):
    # Tests for statistics.median. The median_low, median_high and
    # median_grouped test classes subclass this one and override the
    # cases where their results differ.
    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        """Override UnivariateCommonMixin.prepare_data.

        Pad the inherited data with an extra 2 when needed, so the number
        of data points is always odd.
        """
        data = super().prepare_data()
        if len(data) % 2 == 0:
            data.append(2)
        return data

    def test_even_ints(self):
        # Even count of ints: the mean of the two middle values.
        values = [1, 2, 3, 4, 5, 6]
        assert len(values) % 2 == 0
        self.assertEqual(self.func(values), 3.5)

    def test_odd_ints(self):
        # Odd count of ints: the single middle value.
        values = [1, 2, 3, 4, 5, 6, 9]
        assert len(values) % 2 == 1
        self.assertEqual(self.func(values), 4)

    def test_odd_fractions(self):
        # Odd count of Fractions, in random order.
        values = [Fraction(k, 7) for k in range(1, 6)]
        assert len(values) % 2 == 1
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(3, 7))

    def test_even_fractions(self):
        # Even count of Fractions: (3/7 + 4/7) / 2 == 1/2.
        values = [Fraction(k, 7) for k in range(1, 7)]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(1, 2))

    def test_odd_decimals(self):
        # Odd count of Decimals, in random order.
        values = [Decimal(s) for s in ('2.5', '3.1', '4.2', '5.7', '5.8')]
        assert len(values) % 2 == 1
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('4.2'))

    def test_even_decimals(self):
        # Even count of Decimals: (3.1 + 4.2) / 2 == 3.65.
        values = [Decimal(s)
                  for s in ('1.2', '2.5', '3.1', '4.2', '5.7', '5.8')]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('3.65'))
1596
1597
class TestMedianDataType(NumericTestCase, UnivariateTypeMixin):
    # Check that median returns a value of the same type as the data.
    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        # Fifteen ints (an odd count, so the median is one of the data
        # points), shuffled until they are definitely out of order.
        data = list(range(15))
        assert len(data) % 2 == 1
        while sorted(data) == data:
            random.shuffle(data)
        return data
1609
1610
class TestMedianLow(TestMedian, UnivariateTypeMixin):
    # median_low: with an even count, the lower of the two middle values.
    def setUp(self):
        self.func = statistics.median_low

    def test_even_ints(self):
        values = [1, 2, 3, 4, 5, 6]
        assert len(values) % 2 == 0
        self.assertEqual(self.func(values), 3)

    def test_even_fractions(self):
        values = [Fraction(k, 7) for k in range(1, 7)]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(3, 7))

    def test_even_decimals(self):
        values = [Decimal(s)
                  for s in ('1.1', '2.2', '3.3', '4.4', '5.5', '6.6')]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('3.3'))
1636
1637
class TestMedianHigh(TestMedian, UnivariateTypeMixin):
    # median_high: with an even count, the higher of the two middle values.
    def setUp(self):
        self.func = statistics.median_high

    def test_even_ints(self):
        values = [1, 2, 3, 4, 5, 6]
        assert len(values) % 2 == 0
        self.assertEqual(self.func(values), 4)

    def test_even_fractions(self):
        values = [Fraction(k, 7) for k in range(1, 7)]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(4, 7))

    def test_even_decimals(self):
        values = [Decimal(s)
                  for s in ('1.1', '2.2', '3.3', '4.4', '5.5', '6.6')]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('4.4'))
1663
1664
class TestMedianGrouped(TestMedian):
    # Tests for statistics.median_grouped.
    # median_grouped does not conserve the data element type (see
    # test_repeated_single_value), so unlike TestMedianLow/TestMedianHigh
    # this class does not mix in UnivariateTypeMixin.
    #
    # The expected values below are consistent with interpolating inside
    # the class interval that contains the median:
    #     L + interval * (n/2 - cf) / f
    # where L is that interval's lower limit, cf is the number of data
    # points below L, and f is the number inside the interval. For example,
    # [12, 13, 14, 14, 14, 14, 15] gives 13.5 + 1 * (3.5 - 2) / 4 == 13.875.
    def setUp(self):
        self.func = statistics.median_grouped

    def test_odd_number_repeated(self):
        # Test median_grouped with repeated median values, odd count.
        data = [12, 13, 14, 14, 14, 15, 15]
        assert len(data)%2 == 1
        self.assertEqual(self.func(data), 14)
        #---
        data = [12, 13, 14, 14, 14, 14, 15]
        assert len(data)%2 == 1
        self.assertEqual(self.func(data), 13.875)
        #---
        data = [5, 10, 10, 15, 20, 20, 20, 20, 25, 25, 30]
        assert len(data)%2 == 1
        self.assertEqual(self.func(data, 5), 19.375)
        #---
        data = [16, 18, 18, 18, 18, 20, 20, 20, 22, 22, 22, 24, 24, 26, 28]
        assert len(data)%2 == 1
        self.assertApproxEqual(self.func(data, 2), 20.66666667, tol=1e-8)

    def test_even_number_repeated(self):
        # Test median_grouped with repeated median values, even count.
        data = [5, 10, 10, 15, 20, 20, 20, 25, 25, 30]
        assert len(data)%2 == 0
        self.assertApproxEqual(self.func(data, 5), 19.16666667, tol=1e-8)
        #---
        data = [2, 3, 4, 4, 4, 5]
        assert len(data)%2 == 0
        self.assertApproxEqual(self.func(data), 3.83333333, tol=1e-8)
        #---
        data = [2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
        assert len(data)%2 == 0
        self.assertEqual(self.func(data), 4.5)
        #---
        data = [3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
        assert len(data)%2 == 0
        self.assertEqual(self.func(data), 4.75)

    def test_repeated_single_value(self):
        # Override method from AverageMixin: because median_grouped does
        # not conserve the data type, compare against float(x) instead.
        for x in (5.3, 68, 4.3e17, Fraction(29, 101), Decimal('32.9714')):
            for count in (2, 5, 10, 20):
                data = [x]*count
                self.assertEqual(self.func(data), float(x))

    def test_odd_fractions(self):
        # Test median_grouped works with an odd number of Fractions.
        F = Fraction
        data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4)]
        assert len(data)%2 == 1
        random.shuffle(data)
        self.assertEqual(self.func(data), 3.0)

    def test_even_fractions(self):
        # Test median_grouped works with an even number of Fractions.
        F = Fraction
        data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4), F(17, 4)]
        assert len(data)%2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), 3.25)

    def test_odd_decimals(self):
        # Test median_grouped works with an odd number of Decimals.
        D = Decimal
        data = [D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')]
        assert len(data)%2 == 1
        random.shuffle(data)
        self.assertEqual(self.func(data), 6.75)

    def test_even_decimals(self):
        # Test median_grouped works with an even number of Decimals.
        D = Decimal
        data = [D('5.5'), D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')]
        assert len(data)%2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), 6.5)
        #---
        data = [D('5.5'), D('5.5'), D('6.5'), D('7.5'), D('7.5'), D('8.5')]
        assert len(data)%2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), 7.0)

    def test_interval(self):
        # Test median_grouped with an explicit interval (class width).
        data = [2.25, 2.5, 2.5, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
        self.assertEqual(self.func(data, 0.25), 2.875)
        data = [2.25, 2.5, 2.5, 2.75, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
        self.assertApproxEqual(self.func(data, 0.25), 2.83333333, tol=1e-8)
        data = [220, 220, 240, 260, 260, 260, 260, 280, 280, 300, 320, 340]
        self.assertEqual(self.func(data, 20), 265.0)

    def test_data_type_error(self):
        # str and bytes are rejected, both as data and as the interval.
        data = ["", "", ""]
        self.assertRaises(TypeError, self.func, data)
        #---
        data = [b"", b"", b""]
        self.assertRaises(TypeError, self.func, data)
        #---
        data = [1, 2, 3]
        interval = ""
        self.assertRaises(TypeError, self.func, data, interval)
        #---
        data = [1, 2, 3]
        interval = b""
        self.assertRaises(TypeError, self.func, data, interval)
1777
1778
class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for statistics.mode, the most common (discrete) value.
    def setUp(self):
        self.func = statistics.mode

    def prepare_data(self):
        """Override UnivariateCommonMixin.prepare_data with single-mode data."""
        # 1 occurs four times; every other value occurs exactly once.
        return [1, 1, 1, 1, 3, 4, 7, 9, 0, 8, 2]

    def test_range_data(self):
        # Override the UnivariateCommonMixin test: every value in the range
        # is unique, so the first one is the mode.
        self.assertEqual(self.func(range(20, 50, 3)), 20)

    def test_nominal_data(self):
        # mode works on non-numeric (nominal) data as well.
        for data, expected in (('abcbdb', 'b'),
                               ('fe fi fo fum fi fi'.split(), 'fi')):
            self.assertEqual(self.func(data), expected)

    def test_discrete_data(self):
        # Adding a second copy of i to 0..9 makes i the unique mode.
        base = list(range(10))
        for i in range(10):
            d = base + [i]
            random.shuffle(d)
            self.assertEqual(self.func(d), i)

    def test_bimodal_data(self):
        # Both 2 and 6 occur four times; mode() returns the one seen first.
        data = [1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, 9, 9]
        assert data.count(2) == data.count(6) == 4
        self.assertEqual(self.func(data), 2)

    def test_unique_data(self):
        # With all values unique, mode() returns the first one seen, 0.
        self.assertEqual(self.func(list(range(10))), 0)

    def test_none_data(self):
        # mode is implemented with collections.Counter, which accepts None
        # and produces an empty mapping; mode itself must still raise.
        self.assertRaises(TypeError, self.func, None)

    def test_counter_data(self):
        # A Counter is just another iterable: its keys are the data points
        # and its counts are ignored, so the first key seen (1) is the mode.
        counter = collections.Counter([1, 1, 1, 2])
        self.assertEqual(self.func(counter), 1)
1835
1836
class TestMultiMode(unittest.TestCase):

    def test_basics(self):
        # multimode returns every value tied for most common; empty data
        # gives an empty list.
        cases = [
            ('aabbbbbbbbcc', ['b']),
            ('aabbbbccddddeeffffgg', ['b', 'd', 'f']),
            ('', []),
        ]
        for data, expected in cases:
            self.assertEqual(statistics.multimode(data), expected)
1844
1845
class TestFMean(unittest.TestCase):

    def test_basics(self):
        # fmean returns a float for every numeric type and kind of iterable.
        D, F = Decimal, Fraction
        cases = [
            ([3.5, 4.0, 5.25], 4.25, 'floats'),
            ([D('3.5'), D('4.0'), D('5.25')], 4.25, 'decimals'),
            ([F(7, 2), F(4, 1), F(21, 4)], 4.25, 'fractions'),
            ([True, False, True, True, False], 0.60, 'booleans'),
            ([3.5, 4, F(21, 4)], 4.25, 'mixed types'),
            ((3.5, 4.0, 5.25), 4.25, 'tuple'),
            (iter([3.5, 4.0, 5.25]), 4.25, 'iterator'),
        ]
        for data, expected_mean, kind in cases:
            actual_mean = statistics.fmean(data)
            self.assertIs(type(actual_mean), float, kind)
            self.assertEqual(actual_mean, expected_mean, kind)

    def test_error_cases(self):
        # Each call must raise the paired exception type.
        fmean = statistics.fmean
        StatisticsError = statistics.StatisticsError
        cases = [
            (StatisticsError, lambda: fmean([])),          # empty input
            (StatisticsError, lambda: fmean(iter([]))),    # empty iterator
            (TypeError, lambda: fmean(None)),              # non-iterable input
            (TypeError, lambda: fmean([10, None, 20])),    # non-numeric input
            (TypeError, lambda: fmean()),                  # missing data argument
            (TypeError, lambda: fmean([10, 20, 60], 70)),  # too many arguments
        ]
        for expected_exc, call in cases:
            with self.assertRaises(expected_exc):
                call()

    def test_special_values(self):
        # Special-value handling follows math.fsum().
        fmean = statistics.fmean
        nan, inf = float('Nan'), float('Inf')
        self.assertTrue(math.isnan(fmean([10, nan])), 'nan')
        self.assertTrue(math.isnan(fmean([nan, inf])), 'nan and infinity')
        self.assertTrue(math.isinf(fmean([10, inf])), 'infinity')
        self.assertRaises(ValueError, fmean, [inf, -inf])
1891
1892
1893# === Tests for variances and standard deviations ===
1894
class VarianceStdevMixin(UnivariateCommonMixin):
    # Mixin class holding tests common to the variance, pvariance, stdev
    # and pstdev test classes.

    # Subclasses should list this mixin before NumericTestCase in their
    # bases, so that the ``rel`` tolerance below is found first in the MRO.
    # See test_shift_data for why a tolerance is needed at all.
    # NOTE(review): this presumably shadows a looser default defined on
    # NumericTestCase -- confirm.

    rel = 1e-12

    def test_single_value(self):
        # Deviation of a single value is zero.
        for x in (11, 19.8, 4.6e14, Fraction(21, 34), Decimal('8.392')):
            self.assertEqual(self.func([x]), 0)

    def test_repeated_single_value(self):
        # The deviation of a single repeated value is zero.
        for x in (7.2, 49, 8.1e15, Fraction(3, 7), Decimal('62.4802')):
            for count in (2, 3, 5, 15):
                data = [x]*count
                self.assertEqual(self.func(data), 0)

    def test_domain_error_regression(self):
        # Regression test for a domain error exception.
        # (Thanks to Geremy Condra.)
        data = [0.123456789012345]*10000
        # All the items are identical, so variance should be exactly zero.
        # We allow some small round-off error, but not much.
        result = self.func(data)
        self.assertApproxEqual(result, 0.0, tol=5e-17)
        self.assertGreaterEqual(result, 0)  # A negative result must fail.

    def test_shift_data(self):
        # Test that shifting the data by a constant amount does not affect
        # the variance or stdev. Or at least not much.

        # Due to rounding, this test should be considered an ideal. We allow
        # some tolerance away from "no change at all" by setting tol and/or rel
        # attributes. Subclasses may set tighter or looser error tolerances.
        raw = [1.03, 1.27, 1.94, 2.04, 2.58, 3.14, 4.75, 4.98, 5.42, 6.78]
        expected = self.func(raw)
        # Don't set shift too high, the bigger it is, the more rounding error.
        shift = 1e5
        data = [x + shift for x in raw]
        self.assertApproxEqual(self.func(data), expected)

    def test_shift_data_exact(self):
        # Like test_shift_data, but with integer data the result is exact.
        raw = [1, 3, 3, 4, 5, 7, 9, 10, 11, 16]
        assert all(x==int(x) for x in raw)
        expected = self.func(raw)
        shift = 10**9
        data = [x + shift for x in raw]
        self.assertEqual(self.func(data), expected)

    def test_iter_list_same(self):
        # Test that iter data and list data give the same result.

        # This is an explicit test that iterators and lists are treated the
        # same; justification for this test over and above the similar test
        # in UnivariateCommonMixin is that an earlier design had variance and
        # friends swap between one- and two-pass algorithms, which would
        # sometimes give different results.
        data = [random.uniform(-3, 8) for _ in range(1000)]
        expected = self.func(data)
        self.assertEqual(self.func(iter(data)), expected)
1959
1960
class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for statistics.pvariance (population variance).
    def setUp(self):
        self.func = statistics.pvariance

    def test_exact_uniform(self):
        # The population variance of 0..N-1 is exactly (N**2 - 1)/12.
        n = 10000
        data = list(range(n))
        random.shuffle(data)
        self.assertEqual(self.func(data), (n**2 - 1)/12)

    def test_ints(self):
        # Mean 10; squared deviations 36 + 9 + 9 + 36 == 90; 90/4 == 22.5.
        self.assertEqual(self.func([4, 7, 13, 16]), 22.5)

    def test_fractions(self):
        # Mean 3/4; squared deviations sum to 3/2; (3/2)/4 == 3/8.
        quarters = [Fraction(k, 4) for k in (1, 1, 3, 7)]
        result = self.func(quarters)
        self.assertEqual(result, Fraction(3, 8))
        self.assertIsInstance(result, Fraction)

    def test_decimals(self):
        # Decimal data gives an exact Decimal result.
        values = [Decimal(s) for s in ("12.1", "12.2", "12.5", "12.9")]
        result = self.func(values)
        self.assertEqual(result, Decimal('0.096875'))
        self.assertIsInstance(result, Decimal)
1996
1997
class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for statistics.variance (sample variance).
    def setUp(self):
        self.func = statistics.variance

    def test_single_value(self):
        # Overrides the mixin: a sample variance needs two or more points.
        for x in (35, 24.7, 8.2e15, Fraction(19, 30), Decimal('4.2084')):
            with self.assertRaises(statistics.StatisticsError):
                self.func([x])

    def test_ints(self):
        # Squared deviations from the mean sum to 90; 90/(4 - 1) == 30.
        self.assertEqual(self.func([4, 7, 13, 16]), 30)

    def test_fractions(self):
        # Squared deviations sum to 3/2; (3/2)/(4 - 1) == 1/2.
        quarters = [Fraction(k, 4) for k in (1, 1, 3, 7)]
        result = self.func(quarters)
        self.assertEqual(result, Fraction(1, 2))
        self.assertIsInstance(result, Fraction)

    def test_decimals(self):
        # Mean 5; squared deviations 9 + 9 + 4 + 16 == 38 == 4*9.5; /3.
        values = [Decimal(2), Decimal(2), Decimal(7), Decimal(9)]
        expected = 4*Decimal('9.5')/Decimal(3)
        result = self.func(values)
        self.assertEqual(result, expected)
        self.assertIsInstance(result, Decimal)

    def test_center_not_at_mean(self):
        # An explicit xbar other than the true mean changes the result.
        data = (1.0, 2.0)
        self.assertEqual(self.func(data), 0.5)
        self.assertEqual(self.func(data, xbar=2.0), 1.0)
2036
class TestPStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for statistics.pstdev (population standard deviation).
    def setUp(self):
        self.func = statistics.pstdev

    def test_compare_to_variance(self):
        # pstdev must be exactly the square root of pvariance.
        data = [random.uniform(-17, 24) for _ in range(1000)]
        root_of_variance = math.sqrt(statistics.pvariance(data))
        self.assertEqual(self.func(data), root_of_variance)

    def test_center_not_at_mean(self):
        # bpo-40855: an explicit mu other than the true mean.
        data = (3, 6, 7, 10)
        # Deviations from 6.5 square to 12.25 + 0.25 + 0.25 + 12.25 == 25.
        self.assertEqual(self.func(data), 2.5)
        # Deviations from 0.5 square to 169 in total; sqrt(169/4) == 6.5.
        self.assertEqual(self.func(data, mu=0.5), 6.5)
2053
class TestStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for statistics.stdev (sample standard deviation).
    def setUp(self):
        self.func = statistics.stdev

    def test_single_value(self):
        # Overrides the mixin: a sample needs two or more data points.
        for x in (81, 203.74, 3.9e14, Fraction(5, 21), Decimal('35.719')):
            with self.assertRaises(statistics.StatisticsError):
                self.func([x])

    def test_compare_to_variance(self):
        # stdev must be exactly the square root of variance.
        data = [random.uniform(-2, 9) for _ in range(1000)]
        root_of_variance = math.sqrt(statistics.variance(data))
        self.assertEqual(self.func(data), root_of_variance)

    def test_center_not_at_mean(self):
        # An explicit xbar other than the true mean: sqrt((1 + 0)/1) == 1.
        self.assertEqual(self.func((1.0, 2.0), xbar=2.0), 1.0)
2073
class TestGeometricMean(unittest.TestCase):

    def test_basics(self):
        geometric_mean = statistics.geometric_mean
        self.assertAlmostEqual(geometric_mean([54, 24, 36]), 36.0)
        self.assertAlmostEqual(geometric_mean([4.0, 9.0]), 6.0)
        self.assertAlmostEqual(geometric_mean([17.625]), 17.625)

        # Use a private, seeded generator instead of random.seed(). The data
        # stay reproducible (Random(seed) yields the same stream as
        # random.seed(seed)), but the global RNG, which other tests use for
        # shuffling and random data, is left untouched.
        rng = random.Random(86753095551212)
        for data in [
                range(1, 100),
                range(1, 1_000),
                range(1, 10_000),
                range(500, 10_000, 3),
                range(10_000, 500, -3),
                [12, 17, 13, 5, 120, 7],
                [rng.expovariate(50.0) for i in range(1_000)],
                [rng.lognormvariate(20.0, 3.0) for i in range(2_000)],
                [rng.triangular(2000, 3000, 2200) for i in range(3_000)],
            ]:
            # Reference value: nth root of the product, in Decimal arithmetic.
            gm_decimal = math.prod(map(Decimal, data)) ** (Decimal(1) / len(data))
            gm_float = geometric_mean(data)
            self.assertTrue(math.isclose(gm_float, float(gm_decimal)))

    def test_various_input_types(self):
        # geometric_mean returns a float for every numeric type and iterable.
        geometric_mean = statistics.geometric_mean
        D = Decimal
        F = Fraction
        # https://www.wolframalpha.com/input/?i=geometric+mean+3.5,+4.0,+5.25
        expected_mean = 4.18886
        for data, kind in [
            ([3.5, 4.0, 5.25], 'floats'),
            ([D('3.5'), D('4.0'), D('5.25')], 'decimals'),
            ([F(7, 2), F(4, 1), F(21, 4)], 'fractions'),
            ([3.5, 4, F(21, 4)], 'mixed types'),
            ((3.5, 4.0, 5.25), 'tuple'),
            (iter([3.5, 4.0, 5.25]), 'iterator'),
                ]:
            actual_mean = geometric_mean(data)
            self.assertIs(type(actual_mean), float, kind)
            self.assertAlmostEqual(actual_mean, expected_mean, places=5)

    def test_big_and_small(self):
        geometric_mean = statistics.geometric_mean

        # Avoid overflow to infinity
        large = 2.0 ** 1000
        big_gm = geometric_mean([54.0 * large, 24.0 * large, 36.0 * large])
        self.assertTrue(math.isclose(big_gm, 36.0 * large))
        self.assertFalse(math.isinf(big_gm))

        # Avoid underflow to zero
        small = 2.0 ** -1000
        small_gm = geometric_mean([54.0 * small, 24.0 * small, 36.0 * small])
        self.assertTrue(math.isclose(small_gm, 36.0 * small))
        self.assertNotEqual(small_gm, 0.0)

    def test_error_cases(self):
        geometric_mean = statistics.geometric_mean
        StatisticsError = statistics.StatisticsError
        with self.assertRaises(StatisticsError):
            geometric_mean([])                      # empty input
        with self.assertRaises(StatisticsError):
            geometric_mean([3.5, 0.0, 5.25])        # zero input
        with self.assertRaises(StatisticsError):
            geometric_mean([3.5, -4.0, 5.25])       # negative input
        with self.assertRaises(StatisticsError):
            geometric_mean(iter([]))                # empty iterator
        with self.assertRaises(TypeError):
            geometric_mean(None)                    # non-iterable input
        with self.assertRaises(TypeError):
            geometric_mean([10, None, 20])          # non-numeric input
        with self.assertRaises(TypeError):
            geometric_mean()                        # missing data argument
        with self.assertRaises(TypeError):
            geometric_mean([10, 20, 60], 70)        # too many arguments

    def test_special_values(self):
        # Rules for special values are inherited from math.fsum()
        geometric_mean = statistics.geometric_mean
        NaN = float('Nan')
        Inf = float('Inf')
        self.assertTrue(math.isnan(geometric_mean([10, NaN])), 'nan')
        self.assertTrue(math.isnan(geometric_mean([NaN, Inf])), 'nan and infinity')
        self.assertTrue(math.isinf(geometric_mean([10, Inf])), 'infinity')
        with self.assertRaises(ValueError):
            geometric_mean([Inf, -Inf])
2161
2162
2163class TestQuantiles(unittest.TestCase):
2164
2165    def test_specific_cases(self):
2166        # Match results computed by hand and cross-checked
2167        # against the PERCENTILE.EXC function in MS Excel.
2168        quantiles = statistics.quantiles
2169        data = [120, 200, 250, 320, 350]
2170        random.shuffle(data)
2171        for n, expected in [
2172            (1, []),
2173            (2, [250.0]),
2174            (3, [200.0, 320.0]),
2175            (4, [160.0, 250.0, 335.0]),
2176            (5, [136.0, 220.0, 292.0, 344.0]),
2177            (6, [120.0, 200.0, 250.0, 320.0, 350.0]),
2178            (8, [100.0, 160.0, 212.5, 250.0, 302.5, 335.0, 357.5]),
2179            (10, [88.0, 136.0, 184.0, 220.0, 250.0, 292.0, 326.0, 344.0, 362.0]),
2180            (12, [80.0, 120.0, 160.0, 200.0, 225.0, 250.0, 285.0, 320.0, 335.0,
2181                  350.0, 365.0]),
2182            (15, [72.0, 104.0, 136.0, 168.0, 200.0, 220.0, 240.0, 264.0, 292.0,
2183                  320.0, 332.0, 344.0, 356.0, 368.0]),
2184                ]:
2185            self.assertEqual(expected, quantiles(data, n=n))
2186            self.assertEqual(len(quantiles(data, n=n)), n - 1)
2187            # Preserve datatype when possible
2188            for datatype in (float, Decimal, Fraction):
2189                result = quantiles(map(datatype, data), n=n)
2190                self.assertTrue(all(type(x) == datatype) for x in result)
2191                self.assertEqual(result, list(map(datatype, expected)))
2192            # Quantiles should be idempotent
2193            if len(expected) >= 2:
2194                self.assertEqual(quantiles(expected, n=n), expected)
2195            # Cross-check against method='inclusive' which should give
2196            # the same result after adding in minimum and maximum values
2197            # extrapolated from the two lowest and two highest points.
2198            sdata = sorted(data)
2199            lo = 2 * sdata[0] - sdata[1]
2200            hi = 2 * sdata[-1] - sdata[-2]
2201            padded_data = data + [lo, hi]
2202            self.assertEqual(
2203                quantiles(data, n=n),
2204                quantiles(padded_data, n=n, method='inclusive'),
2205                (n, data),
2206            )
2207            # Invariant under translation and scaling
2208            def f(x):
2209                return 3.5 * x - 1234.675
2210            exp = list(map(f, expected))
2211            act = quantiles(map(f, data), n=n)
2212            self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act)))
2213        # Q2 agrees with median()
2214        for k in range(2, 60):
2215            data = random.choices(range(100), k=k)
2216            q1, q2, q3 = quantiles(data)
2217            self.assertEqual(q2, statistics.median(data))
2218
2219    def test_specific_cases_inclusive(self):
2220        # Match results computed by hand and cross-checked
2221        # against the PERCENTILE.INC function in MS Excel
2222        # and against the quantile() function in SciPy.
2223        quantiles = statistics.quantiles
2224        data = [100, 200, 400, 800]
2225        random.shuffle(data)
2226        for n, expected in [
2227            (1, []),
2228            (2, [300.0]),
2229            (3, [200.0, 400.0]),
2230            (4, [175.0, 300.0, 500.0]),
2231            (5, [160.0, 240.0, 360.0, 560.0]),
2232            (6, [150.0, 200.0, 300.0, 400.0, 600.0]),
2233            (8, [137.5, 175, 225.0, 300.0, 375.0, 500.0,650.0]),
2234            (10, [130.0, 160.0, 190.0, 240.0, 300.0, 360.0, 440.0, 560.0, 680.0]),
2235            (12, [125.0, 150.0, 175.0, 200.0, 250.0, 300.0, 350.0, 400.0,
2236                  500.0, 600.0, 700.0]),
2237            (15, [120.0, 140.0, 160.0, 180.0, 200.0, 240.0, 280.0, 320.0, 360.0,
2238                  400.0, 480.0, 560.0, 640.0, 720.0]),
2239                ]:
2240            self.assertEqual(expected, quantiles(data, n=n, method="inclusive"))
2241            self.assertEqual(len(quantiles(data, n=n, method="inclusive")), n - 1)
2242            # Preserve datatype when possible
2243            for datatype in (float, Decimal, Fraction):
2244                result = quantiles(map(datatype, data), n=n, method="inclusive")
2245                self.assertTrue(all(type(x) == datatype) for x in result)
2246                self.assertEqual(result, list(map(datatype, expected)))
2247            # Invariant under translation and scaling
2248            def f(x):
2249                return 3.5 * x - 1234.675
2250            exp = list(map(f, expected))
2251            act = quantiles(map(f, data), n=n, method="inclusive")
2252            self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act)))
2253        # Natural deciles
2254        self.assertEqual(quantiles([0, 100], n=10, method='inclusive'),
2255                         [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0])
2256        self.assertEqual(quantiles(range(0, 101), n=10, method='inclusive'),
2257                         [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0])
2258        # Whenever n is smaller than the number of data points, running
2259        # method='inclusive' should give the same result as method='exclusive'
2260        # after the two included extreme points are removed.
2261        data = [random.randrange(10_000) for i in range(501)]
2262        actual = quantiles(data, n=32, method='inclusive')
2263        data.remove(min(data))
2264        data.remove(max(data))
2265        expected = quantiles(data, n=32)
2266        self.assertEqual(expected, actual)
2267        # Q2 agrees with median()
2268        for k in range(2, 60):
2269            data = random.choices(range(100), k=k)
2270            q1, q2, q3 = quantiles(data, method='inclusive')
2271            self.assertEqual(q2, statistics.median(data))
2272
2273    def test_equal_inputs(self):
2274        quantiles = statistics.quantiles
2275        for n in range(2, 10):
2276            data = [10.0] * n
2277            self.assertEqual(quantiles(data), [10.0, 10.0, 10.0])
2278            self.assertEqual(quantiles(data, method='inclusive'),
2279                            [10.0, 10.0, 10.0])
2280
2281    def test_equal_sized_groups(self):
2282        quantiles = statistics.quantiles
2283        total = 10_000
2284        data = [random.expovariate(0.2) for i in range(total)]
2285        while len(set(data)) != total:
2286            data.append(random.expovariate(0.2))
2287        data.sort()
2288
2289        # Cases where the group size exactly divides the total
2290        for n in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000):
2291            group_size = total // n
2292            self.assertEqual(
2293                [bisect.bisect(data, q) for q in quantiles(data, n=n)],
2294                list(range(group_size, total, group_size)))
2295
2296        # When the group sizes can't be exactly equal, they should
2297        # differ by no more than one
2298        for n in (13, 19, 59, 109, 211, 571, 1019, 1907, 5261, 9769):
2299            group_sizes = {total // n, total // n + 1}
2300            pos = [bisect.bisect(data, q) for q in quantiles(data, n=n)]
2301            sizes = {q - p for p, q in zip(pos, pos[1:])}
2302            self.assertTrue(sizes <= group_sizes)
2303
2304    def test_error_cases(self):
2305        quantiles = statistics.quantiles
2306        StatisticsError = statistics.StatisticsError
2307        with self.assertRaises(TypeError):
2308            quantiles()                         # Missing arguments
2309        with self.assertRaises(TypeError):
2310            quantiles([10, 20, 30], 13, n=4)    # Too many arguments
2311        with self.assertRaises(TypeError):
2312            quantiles([10, 20, 30], 4)          # n is a positional argument
2313        with self.assertRaises(StatisticsError):
2314            quantiles([10, 20, 30], n=0)        # n is zero
2315        with self.assertRaises(StatisticsError):
2316            quantiles([10, 20, 30], n=-1)       # n is negative
2317        with self.assertRaises(TypeError):
2318            quantiles([10, 20, 30], n=1.5)      # n is not an integer
2319        with self.assertRaises(ValueError):
2320            quantiles([10, 20, 30], method='X') # method is unknown
2321        with self.assertRaises(StatisticsError):
2322            quantiles([10], n=4)                # not enough data points
2323        with self.assertRaises(TypeError):
2324            quantiles([10, None, 30], n=4)      # data is non-numeric
2325
2326
2327class TestNormalDist:
2328
2329    # General note on precision: The pdf(), cdf(), and overlap() methods
2330    # depend on functions in the math libraries that do not make
2331    # explicit accuracy guarantees.  Accordingly, some of the accuracy
2332    # tests below may fail if the underlying math functions are
2333    # inaccurate.  There isn't much we can do about this short of
2334    # implementing our own implementations from scratch.
2335
2336    def test_slots(self):
2337        nd = self.module.NormalDist(300, 23)
2338        with self.assertRaises(TypeError):
2339            vars(nd)
2340        self.assertEqual(tuple(nd.__slots__), ('_mu', '_sigma'))
2341
2342    def test_instantiation_and_attributes(self):
2343        nd = self.module.NormalDist(500, 17)
2344        self.assertEqual(nd.mean, 500)
2345        self.assertEqual(nd.stdev, 17)
2346        self.assertEqual(nd.variance, 17**2)
2347
2348        # default arguments
2349        nd = self.module.NormalDist()
2350        self.assertEqual(nd.mean, 0)
2351        self.assertEqual(nd.stdev, 1)
2352        self.assertEqual(nd.variance, 1**2)
2353
2354        # error case: negative sigma
2355        with self.assertRaises(self.module.StatisticsError):
2356            self.module.NormalDist(500, -10)
2357
2358        # verify that subclass type is honored
2359        class NewNormalDist(self.module.NormalDist):
2360            pass
2361        nnd = NewNormalDist(200, 5)
2362        self.assertEqual(type(nnd), NewNormalDist)
2363
2364    def test_alternative_constructor(self):
2365        NormalDist = self.module.NormalDist
2366        data = [96, 107, 90, 92, 110]
2367        # list input
2368        self.assertEqual(NormalDist.from_samples(data), NormalDist(99, 9))
2369        # tuple input
2370        self.assertEqual(NormalDist.from_samples(tuple(data)), NormalDist(99, 9))
2371        # iterator input
2372        self.assertEqual(NormalDist.from_samples(iter(data)), NormalDist(99, 9))
2373        # error cases
2374        with self.assertRaises(self.module.StatisticsError):
2375            NormalDist.from_samples([])                      # empty input
2376        with self.assertRaises(self.module.StatisticsError):
2377            NormalDist.from_samples([10])                    # only one input
2378
2379        # verify that subclass type is honored
2380        class NewNormalDist(NormalDist):
2381            pass
2382        nnd = NewNormalDist.from_samples(data)
2383        self.assertEqual(type(nnd), NewNormalDist)
2384
2385    def test_sample_generation(self):
2386        NormalDist = self.module.NormalDist
2387        mu, sigma = 10_000, 3.0
2388        X = NormalDist(mu, sigma)
2389        n = 1_000
2390        data = X.samples(n)
2391        self.assertEqual(len(data), n)
2392        self.assertEqual(set(map(type, data)), {float})
2393        # mean(data) expected to fall within 8 standard deviations
2394        xbar = self.module.mean(data)
2395        self.assertTrue(mu - sigma*8 <= xbar <= mu + sigma*8)
2396
2397        # verify that seeding makes reproducible sequences
2398        n = 100
2399        data1 = X.samples(n, seed='happiness and joy')
2400        data2 = X.samples(n, seed='trouble and despair')
2401        data3 = X.samples(n, seed='happiness and joy')
2402        data4 = X.samples(n, seed='trouble and despair')
2403        self.assertEqual(data1, data3)
2404        self.assertEqual(data2, data4)
2405        self.assertNotEqual(data1, data2)
2406
2407    def test_pdf(self):
2408        NormalDist = self.module.NormalDist
2409        X = NormalDist(100, 15)
2410        # Verify peak around center
2411        self.assertLess(X.pdf(99), X.pdf(100))
2412        self.assertLess(X.pdf(101), X.pdf(100))
2413        # Test symmetry
2414        for i in range(50):
2415            self.assertAlmostEqual(X.pdf(100 - i), X.pdf(100 + i))
2416        # Test vs CDF
2417        dx = 2.0 ** -10
2418        for x in range(90, 111):
2419            est_pdf = (X.cdf(x + dx) - X.cdf(x)) / dx
2420            self.assertAlmostEqual(X.pdf(x), est_pdf, places=4)
2421        # Test vs table of known values -- CRC 26th Edition
2422        Z = NormalDist()
2423        for x, px in enumerate([
2424            0.3989, 0.3989, 0.3989, 0.3988, 0.3986,
2425            0.3984, 0.3982, 0.3980, 0.3977, 0.3973,
2426            0.3970, 0.3965, 0.3961, 0.3956, 0.3951,
2427            0.3945, 0.3939, 0.3932, 0.3925, 0.3918,
2428            0.3910, 0.3902, 0.3894, 0.3885, 0.3876,
2429            0.3867, 0.3857, 0.3847, 0.3836, 0.3825,
2430            0.3814, 0.3802, 0.3790, 0.3778, 0.3765,
2431            0.3752, 0.3739, 0.3725, 0.3712, 0.3697,
2432            0.3683, 0.3668, 0.3653, 0.3637, 0.3621,
2433            0.3605, 0.3589, 0.3572, 0.3555, 0.3538,
2434        ]):
2435            self.assertAlmostEqual(Z.pdf(x / 100.0), px, places=4)
2436            self.assertAlmostEqual(Z.pdf(-x / 100.0), px, places=4)
2437        # Error case: variance is zero
2438        Y = NormalDist(100, 0)
2439        with self.assertRaises(self.module.StatisticsError):
2440            Y.pdf(90)
2441        # Special values
2442        self.assertEqual(X.pdf(float('-Inf')), 0.0)
2443        self.assertEqual(X.pdf(float('Inf')), 0.0)
2444        self.assertTrue(math.isnan(X.pdf(float('NaN'))))
2445
2446    def test_cdf(self):
2447        NormalDist = self.module.NormalDist
2448        X = NormalDist(100, 15)
2449        cdfs = [X.cdf(x) for x in range(1, 200)]
2450        self.assertEqual(set(map(type, cdfs)), {float})
2451        # Verify montonic
2452        self.assertEqual(cdfs, sorted(cdfs))
2453        # Verify center (should be exact)
2454        self.assertEqual(X.cdf(100), 0.50)
2455        # Check against a table of known values
2456        # https://en.wikipedia.org/wiki/Standard_normal_table#Cumulative
2457        Z = NormalDist()
2458        for z, cum_prob in [
2459            (0.00, 0.50000), (0.01, 0.50399), (0.02, 0.50798),
2460            (0.14, 0.55567), (0.29, 0.61409), (0.33, 0.62930),
2461            (0.54, 0.70540), (0.60, 0.72575), (1.17, 0.87900),
2462            (1.60, 0.94520), (2.05, 0.97982), (2.89, 0.99807),
2463            (3.52, 0.99978), (3.98, 0.99997), (4.07, 0.99998),
2464            ]:
2465            self.assertAlmostEqual(Z.cdf(z), cum_prob, places=5)
2466            self.assertAlmostEqual(Z.cdf(-z), 1.0 - cum_prob, places=5)
2467        # Error case: variance is zero
2468        Y = NormalDist(100, 0)
2469        with self.assertRaises(self.module.StatisticsError):
2470            Y.cdf(90)
2471        # Special values
2472        self.assertEqual(X.cdf(float('-Inf')), 0.0)
2473        self.assertEqual(X.cdf(float('Inf')), 1.0)
2474        self.assertTrue(math.isnan(X.cdf(float('NaN'))))
2475
2476    @support.skip_if_pgo_task
2477    def test_inv_cdf(self):
2478        NormalDist = self.module.NormalDist
2479
2480        # Center case should be exact.
2481        iq = NormalDist(100, 15)
2482        self.assertEqual(iq.inv_cdf(0.50), iq.mean)
2483
2484        # Test versus a published table of known percentage points.
2485        # See the second table at the bottom of the page here:
2486        # http://people.bath.ac.uk/masss/tables/normaltable.pdf
2487        Z = NormalDist()
2488        pp = {5.0: (0.000, 1.645, 2.576, 3.291, 3.891,
2489                    4.417, 4.892, 5.327, 5.731, 6.109),
2490              2.5: (0.674, 1.960, 2.807, 3.481, 4.056,
2491                    4.565, 5.026, 5.451, 5.847, 6.219),
2492              1.0: (1.282, 2.326, 3.090, 3.719, 4.265,
2493                    4.753, 5.199, 5.612, 5.998, 6.361)}
2494        for base, row in pp.items():
2495            for exp, x in enumerate(row, start=1):
2496                p = base * 10.0 ** (-exp)
2497                self.assertAlmostEqual(-Z.inv_cdf(p), x, places=3)
2498                p = 1.0 - p
2499                self.assertAlmostEqual(Z.inv_cdf(p), x, places=3)
2500
2501        # Match published example for MS Excel
2502        # https://support.office.com/en-us/article/norm-inv-function-54b30935-fee7-493c-bedb-2278a9db7e13
2503        self.assertAlmostEqual(NormalDist(40, 1.5).inv_cdf(0.908789), 42.000002)
2504
2505        # One million equally spaced probabilities
2506        n = 2**20
2507        for p in range(1, n):
2508            p /= n
2509            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2510
2511        # One hundred ever smaller probabilities to test tails out to
2512        # extreme probabilities: 1 / 2**50 and (2**50-1) / 2 ** 50
2513        for e in range(1, 51):
2514            p = 2.0 ** (-e)
2515            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2516            p = 1.0 - p
2517            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2518
2519        # Now apply cdf() first.  Near the tails, the round-trip loses
2520        # precision and is ill-conditioned (small changes in the inputs
2521        # give large changes in the output), so only check to 5 places.
2522        for x in range(200):
2523            self.assertAlmostEqual(iq.inv_cdf(iq.cdf(x)), x, places=5)
2524
2525        # Error cases:
2526        with self.assertRaises(self.module.StatisticsError):
2527            iq.inv_cdf(0.0)                         # p is zero
2528        with self.assertRaises(self.module.StatisticsError):
2529            iq.inv_cdf(-0.1)                        # p under zero
2530        with self.assertRaises(self.module.StatisticsError):
2531            iq.inv_cdf(1.0)                         # p is one
2532        with self.assertRaises(self.module.StatisticsError):
2533            iq.inv_cdf(1.1)                         # p over one
2534        with self.assertRaises(self.module.StatisticsError):
2535            iq = NormalDist(100, 0)                 # sigma is zero
2536            iq.inv_cdf(0.5)
2537
2538        # Special values
2539        self.assertTrue(math.isnan(Z.inv_cdf(float('NaN'))))
2540
2541    def test_quantiles(self):
2542        # Quartiles of a standard normal distribution
2543        Z = self.module.NormalDist()
2544        for n, expected in [
2545            (1, []),
2546            (2, [0.0]),
2547            (3, [-0.4307, 0.4307]),
2548            (4 ,[-0.6745, 0.0, 0.6745]),
2549                ]:
2550            actual = Z.quantiles(n=n)
2551            self.assertTrue(all(math.isclose(e, a, abs_tol=0.0001)
2552                            for e, a in zip(expected, actual)))
2553
2554    def test_overlap(self):
2555        NormalDist = self.module.NormalDist
2556
2557        # Match examples from Imman and Bradley
2558        for X1, X2, published_result in [
2559                (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0), 0.80258),
2560                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0), 0.60993),
2561            ]:
2562            self.assertAlmostEqual(X1.overlap(X2), published_result, places=4)
2563            self.assertAlmostEqual(X2.overlap(X1), published_result, places=4)
2564
2565        # Check against integration of the PDF
2566        def overlap_numeric(X, Y, *, steps=8_192, z=5):
2567            'Numerical integration cross-check for overlap() '
2568            fsum = math.fsum
2569            center = (X.mean + Y.mean) / 2.0
2570            width = z * max(X.stdev, Y.stdev)
2571            start = center - width
2572            dx = 2.0 * width / steps
2573            x_arr = [start + i*dx for i in range(steps)]
2574            xp = list(map(X.pdf, x_arr))
2575            yp = list(map(Y.pdf, x_arr))
2576            total = max(fsum(xp), fsum(yp))
2577            return fsum(map(min, xp, yp)) / total
2578
2579        for X1, X2 in [
2580                # Examples from Imman and Bradley
2581                (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0)),
2582                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)),
2583                # Example from https://www.rasch.org/rmt/rmt101r.htm
2584                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)),
2585                # Gender heights from http://www.usablestats.com/lessons/normal
2586                (NormalDist(70, 4), NormalDist(65, 3.5)),
2587                # Misc cases with equal standard deviations
2588                (NormalDist(100, 15), NormalDist(110, 15)),
2589                (NormalDist(-100, 15), NormalDist(110, 15)),
2590                (NormalDist(-100, 15), NormalDist(-110, 15)),
2591                # Misc cases with unequal standard deviations
2592                (NormalDist(100, 12), NormalDist(100, 15)),
2593                (NormalDist(100, 12), NormalDist(110, 15)),
2594                (NormalDist(100, 12), NormalDist(150, 15)),
2595                (NormalDist(100, 12), NormalDist(150, 35)),
2596                # Misc cases with small values
2597                (NormalDist(1.000, 0.002), NormalDist(1.001, 0.003)),
2598                (NormalDist(1.000, 0.002), NormalDist(1.006, 0.0003)),
2599                (NormalDist(1.000, 0.002), NormalDist(1.001, 0.099)),
2600            ]:
2601            self.assertAlmostEqual(X1.overlap(X2), overlap_numeric(X1, X2), places=5)
2602            self.assertAlmostEqual(X2.overlap(X1), overlap_numeric(X1, X2), places=5)
2603
2604        # Error cases
2605        X = NormalDist()
2606        with self.assertRaises(TypeError):
2607            X.overlap()                             # too few arguments
2608        with self.assertRaises(TypeError):
2609            X.overlap(X, X)                         # too may arguments
2610        with self.assertRaises(TypeError):
2611            X.overlap(None)                         # right operand not a NormalDist
2612        with self.assertRaises(self.module.StatisticsError):
2613            X.overlap(NormalDist(1, 0))             # right operand sigma is zero
2614        with self.assertRaises(self.module.StatisticsError):
2615            NormalDist(1, 0).overlap(X)             # left operand sigma is zero
2616
2617    def test_properties(self):
2618        X = self.module.NormalDist(100, 15)
2619        self.assertEqual(X.mean, 100)
2620        self.assertEqual(X.median, 100)
2621        self.assertEqual(X.mode, 100)
2622        self.assertEqual(X.stdev, 15)
2623        self.assertEqual(X.variance, 225)
2624
2625    def test_same_type_addition_and_subtraction(self):
2626        NormalDist = self.module.NormalDist
2627        X = NormalDist(100, 12)
2628        Y = NormalDist(40, 5)
2629        self.assertEqual(X + Y, NormalDist(140, 13))        # __add__
2630        self.assertEqual(X - Y, NormalDist(60, 13))         # __sub__
2631
2632    def test_translation_and_scaling(self):
2633        NormalDist = self.module.NormalDist
2634        X = NormalDist(100, 15)
2635        y = 10
2636        self.assertEqual(+X, NormalDist(100, 15))           # __pos__
2637        self.assertEqual(-X, NormalDist(-100, 15))          # __neg__
2638        self.assertEqual(X + y, NormalDist(110, 15))        # __add__
2639        self.assertEqual(y + X, NormalDist(110, 15))        # __radd__
2640        self.assertEqual(X - y, NormalDist(90, 15))         # __sub__
2641        self.assertEqual(y - X, NormalDist(-90, 15))        # __rsub__
2642        self.assertEqual(X * y, NormalDist(1000, 150))      # __mul__
2643        self.assertEqual(y * X, NormalDist(1000, 150))      # __rmul__
2644        self.assertEqual(X / y, NormalDist(10, 1.5))        # __truediv__
2645        with self.assertRaises(TypeError):                  # __rtruediv__
2646            y / X
2647
2648    def test_unary_operations(self):
2649        NormalDist = self.module.NormalDist
2650        X = NormalDist(100, 12)
2651        Y = +X
2652        self.assertIsNot(X, Y)
2653        self.assertEqual(X.mean, Y.mean)
2654        self.assertEqual(X.stdev, Y.stdev)
2655        Y = -X
2656        self.assertIsNot(X, Y)
2657        self.assertEqual(X.mean, -Y.mean)
2658        self.assertEqual(X.stdev, Y.stdev)
2659
2660    def test_equality(self):
2661        NormalDist = self.module.NormalDist
2662        nd1 = NormalDist()
2663        nd2 = NormalDist(2, 4)
2664        nd3 = NormalDist()
2665        nd4 = NormalDist(2, 4)
2666        nd5 = NormalDist(2, 8)
2667        nd6 = NormalDist(8, 4)
2668        self.assertNotEqual(nd1, nd2)
2669        self.assertEqual(nd1, nd3)
2670        self.assertEqual(nd2, nd4)
2671        self.assertNotEqual(nd2, nd5)
2672        self.assertNotEqual(nd2, nd6)
2673
2674        # Test NotImplemented when types are different
2675        class A:
2676            def __eq__(self, other):
2677                return 10
2678        a = A()
2679        self.assertEqual(nd1.__eq__(a), NotImplemented)
2680        self.assertEqual(nd1 == a, 10)
2681        self.assertEqual(a == nd1, 10)
2682
2683        # All subclasses to compare equal giving the same behavior
2684        # as list, tuple, int, float, complex, str, dict, set, etc.
2685        class SizedNormalDist(NormalDist):
2686            def __init__(self, mu, sigma, n):
2687                super().__init__(mu, sigma)
2688                self.n = n
2689        s = SizedNormalDist(100, 15, 57)
2690        nd4 = NormalDist(100, 15)
2691        self.assertEqual(s, nd4)
2692
2693        # Don't allow duck type equality because we wouldn't
2694        # want a lognormal distribution to compare equal
2695        # to a normal distribution with the same parameters
2696        class LognormalDist:
2697            def __init__(self, mu, sigma):
2698                self.mu = mu
2699                self.sigma = sigma
2700        lnd = LognormalDist(100, 15)
2701        nd = NormalDist(100, 15)
2702        self.assertNotEqual(nd, lnd)
2703
2704    def test_pickle_and_copy(self):
2705        nd = self.module.NormalDist(37.5, 5.625)
2706        nd1 = copy.copy(nd)
2707        self.assertEqual(nd, nd1)
2708        nd2 = copy.deepcopy(nd)
2709        self.assertEqual(nd, nd2)
2710        nd3 = pickle.loads(pickle.dumps(nd))
2711        self.assertEqual(nd, nd3)
2712
2713    def test_hashability(self):
2714        ND = self.module.NormalDist
2715        s = {ND(100, 15), ND(100.0, 15.0), ND(100, 10), ND(95, 15), ND(100, 15)}
2716        self.assertEqual(len(s), 3)
2717
2718    def test_repr(self):
2719        nd = self.module.NormalDist(37.5, 5.625)
2720        self.assertEqual(repr(nd), 'NormalDist(mu=37.5, sigma=5.625)')
2721
# Pointing sys.modules['statistics'] at the module under test works around
# this error when pickling NormalDist instances:
#   _pickle.PicklingError: Can't pickle <class 'statistics.NormalDist'>:
#   it's not the same object as statistics.NormalDist
class TestNormalDistPython(unittest.TestCase, TestNormalDist):
    """Run the TestNormalDist mixin against ``py_statistics``."""

    module = py_statistics

    def setUp(self):
        # Register the module under test as 'statistics' so pickle resolves
        # statistics.NormalDist to the very class these tests instantiate.
        sys.modules.update(statistics=self.module)

    def tearDown(self):
        # Put back the statistics module imported at the top of this file.
        sys.modules.update(statistics=statistics)
2733
2734
@unittest.skipUnless(c_statistics, 'requires _statistics')
class TestNormalDistC(unittest.TestCase, TestNormalDist):
    """Run the TestNormalDist mixin against ``c_statistics``.

    Skipped when ``c_statistics`` is falsy (the _statistics accelerator
    is unavailable).
    """

    module = c_statistics

    def setUp(self):
        # Register the module under test as 'statistics' so pickle resolves
        # statistics.NormalDist to the very class these tests instantiate.
        sys.modules.update(statistics=self.module)

    def tearDown(self):
        # Put back the statistics module imported at the top of this file.
        sys.modules.update(statistics=statistics)
2743
2744
2745# === Run tests ===
2746
def load_tests(loader, tests, ignore):
    """Hook for the unittest ``load_tests`` protocol.

    Adds this module's doctests to *tests* and returns that same suite.
    *loader* and *ignore* are accepted for the protocol but not used.
    """
    # DocTestSuite() with no argument finds the module by looking at its
    # caller's globals, so it must be called directly from this function.
    doctest_suite = doctest.DocTestSuite()
    tests.addTests(doctest_suite)
    return tests
2751
2752
if __name__ == "__main__":
    # Running this file directly executes the suite via unittest's CLI.
    unittest.main()
2755