1"""Test suite for statistics module, including helper NumericTestCase and
2approx_equal function.
3
4"""
5
6import bisect
7import collections
8import collections.abc
9import copy
10import decimal
11import doctest
12import math
13import pickle
14import random
15import sys
16import unittest
17from test import support
18from test.support import import_helper
19
20from decimal import Decimal
21from fractions import Fraction
22
23
24# Module to be tested.
25import statistics
26
27
28# === Helper functions and class ===
29
def sign(x):
    """Return the sign of x as a float.

    The result is -1.0 for negatives, including -0.0, otherwise +1.0.
    """
    # copysign() takes the magnitude of its first argument and the sign of
    # its second, which is what lets a negative zero be told apart.
    unit = 1
    return math.copysign(unit, x)
33
def _nan_equal(a, b):
    """Report whether a and b are NANs of the same type and the same kind.

    Quiet and signalling Decimal NANs are distinct kinds:

    >>> _nan_equal(Decimal('NAN'), Decimal('NAN'))
    True
    >>> _nan_equal(Decimal('sNAN'), Decimal('sNAN'))
    True
    >>> _nan_equal(Decimal('NAN'), Decimal('sNAN'))
    False
    >>> _nan_equal(Decimal(42), Decimal('NAN'))
    False

    Floats only have the one kind:

    >>> _nan_equal(float('NAN'), float('NAN'))
    True
    >>> _nan_equal(float('NAN'), 0.5)
    False

    A float NAN and a Decimal NAN never match, as the types differ:

    >>> _nan_equal(float('NAN'), Decimal('NAN'))
    False

    NAN payloads are not compared.
    """
    same_type = type(a) is type(b)
    if not same_type:
        return False
    if isinstance(a, float):
        return math.isnan(a) and math.isnan(b)
    # Anything else is treated as a Decimal, whose tuple form stores 'n'
    # (quiet NAN) or 'N' (signalling NAN) in the exponent slot.
    exp_a = a.as_tuple().exponent
    exp_b = b.as_tuple().exponent
    return (exp_a in ('n', 'N')) and (exp_a == exp_b)
63
64
def _calc_errors(actual, expected):
    """Compute how far apart two numbers are, absolutely and relatively.

    >>> _calc_errors(100, 75)
    (25, 0.25)
    >>> _calc_errors(100, 100)
    (0, 0.0)

    The result is the tuple (absolute error, relative error). The relative
    error is measured against whichever argument has the larger magnitude,
    and is reported as infinity when both arguments are zero.
    """
    scale = max(abs(actual), abs(expected))
    difference = abs(actual - expected)
    if scale:
        relative = difference/scale
    else:
        # Nothing to divide by: both arguments are zero.
        relative = float('inf')
    return (difference, relative)
79
80
def approx_equal(x, y, tol=1e-12, rel=1e-7):
    """approx_equal(x, y [, tol [, rel]]) => True|False

    Decide whether the numbers x and y are equal to within some margin of
    error, returning True if they are and False if they are not. Numbers
    which compare equal always compare approximately equal as well.

    The margin of error is the absolute error tol or the relative error
    rel, whichever allows the bigger difference between x and y.

    Both tol and rel, when given, must be finite, non-negative numbers.
    They default to tol=1e-12 and rel=1e-7.

    >>> approx_equal(1.2589, 1.2587, tol=0.0003, rel=0)
    True
    >>> approx_equal(1.2589, 1.2587, tol=0.0001, rel=0)
    False

    The absolute error is abs(x-y). x and y are approximately equal if it
    is less than or equal to tol.

    The relative error is the smaller of abs((x-y)/x) and abs((x-y)/y),
    provided x or y are not zero. x and y are approximately equal if it is
    less than or equal to rel.

    Complex numbers are not directly supported. To compare two complex
    numbers, extract their real and imaginary parts and compare those
    individually.

    NANs never compare equal to anything, themselves included. Two
    infinities compare approximately equal only if they have the same
    sign; an infinity never compares equal to a finite number.

    Raises ValueError if tol or rel is negative.
    """
    if tol < 0 or rel < 0:
        raise ValueError('error tolerances must be non-negative')
    # A NAN is never equal to anything, approximately or otherwise.
    if any(math.isnan(value) for value in (x, y)):
        return False
    # Exact equality implies approximate equality. Two infinities with the
    # same sign are caught here too.
    if x == y:
        return True
    # Any infinity left over is paired with a finite number or with the
    # infinity of opposite sign, so the values cannot be close.
    if any(math.isinf(value) for value in (x, y)):
        return False
    # Both values are finite from here on.
    difference = abs(x - y)
    magnitude = max(abs(x), abs(y))
    margin = max(tol, rel*magnitude)
    return difference <= margin
132
133
# This class exists only as somewhere to stick a docstring containing
# doctests. The following docstring and tests were originally in a separate
# module. Now that they have been merged in here, I need somewhere to hang
# the docstring. Ultimately, this class will die, and the information below
# will either become redundant, or be moved into more appropriate places.
class _DoNothing:
    """
    Exact equality is rarely the right test in numeric code, and with floats
    in particular: round-off error means that values which ought to be equal
    frequently are not. The usual remedy is to compare them with some
    (hopefully small!) allowance for error.

    ``approx_equal`` accepts an absolute error tolerance, a relative error,
    or both.

    An absolute tolerance is the simpler of the two, but choosing one means
    knowing the magnitude of the quantities being compared:

    >>> approx_equal(12.345, 12.346, tol=1e-3)
    True
    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3)  # tol is too small.
    False

    A relative error copes better when the values being compared can vary
    in magnitude:

    >>> approx_equal(12.345, 12.346, rel=1e-4)
    True
    >>> approx_equal(12.345e6, 12.346e6, rel=1e-4)
    True

    although a naive implementation of relative error testing can run into
    trouble around zero.

    When both an absolute tolerance and a relative error are supplied, the
    comparison succeeds if either individual test succeeds:

    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3, rel=1e-4)
    True

    """
176
177
178
179# We prefer this for testing numeric values that may not be exactly equal,
180# and avoid using TestCase.assertAlmostEqual, because it sucks :-)
181
# Import two independent copies of the statistics module so that both
# implementations can be checked side by side:
#   * py_statistics -- imported with the '_statistics' module blocked.
#   * c_statistics  -- imported together with a fresh '_statistics'.
# NOTE(review): c_statistics is presumably falsy (None) when '_statistics'
# is unavailable, which is what the skipUnless() guard on
# TestModules.test_c_functions relies on -- confirm against the
# import_fresh_module documentation.
py_statistics = import_helper.import_fresh_module('statistics',
                                                  blocked=['_statistics'])
c_statistics = import_helper.import_fresh_module('statistics',
                                                 fresh=['_statistics'])
186
187
class TestModules(unittest.TestCase):
    # Names of the functions whose defining module is checked in each of
    # the two copies of the statistics module.
    func_names = ['_normal_dist_inv_cdf']

    def test_py_functions(self):
        # In py_statistics every listed function must report 'statistics'
        # as its module.
        for name in self.func_names:
            func = getattr(py_statistics, name)
            self.assertEqual(func.__module__, 'statistics')

    @unittest.skipUnless(c_statistics, 'requires _statistics')
    def test_c_functions(self):
        # In c_statistics every listed function must report '_statistics'
        # as its module.
        for name in self.func_names:
            func = getattr(c_statistics, name)
            self.assertEqual(func.__module__, '_statistics')
199
200
class NumericTestCase(unittest.TestCase):
    """TestCase subclass with an extra assertion for numeric work.

    On top of everything inherited from ``unittest.TestCase`` (including
    ``TestCase.assertAlmostEqual``), ``assertApproxEqual`` is provided.
    """
    # Class-wide default tolerances. Zero demands exact equality unless a
    # subclass or an individual call overrides it.
    tol = rel = 0

    def assertApproxEqual(
            self, first, second, tol=None, rel=None, msg=None
            ):
        """Fail unless ``first`` and ``second`` are approximately equal.

        The two are approximately equal when they agree to within ``tol``,
        an absolute error, or ``rel``, a relative error.

        A ``tol`` or ``rel`` that is None or not given is taken from the
        test attribute of the same name (by default, 0).

        The objects may be either numbers, or sequences of numbers.
        Sequences are tested element-by-element.

        >>> class MyTest(NumericTestCase):
        ...     def test_number(self):
        ...         x = 1.0/6
        ...         y = sum([x]*6)
        ...         self.assertApproxEqual(y, 1.0, tol=1e-15)
        ...     def test_sequence(self):
        ...         a = [1.001, 1.001e-10, 1.001e10]
        ...         b = [1.0, 1e-10, 1e10]
        ...         self.assertApproxEqual(a, b, rel=1e-3)
        ...
        >>> import unittest
        >>> from io import StringIO  # Suppress test runner output.
        >>> suite = unittest.TestLoader().loadTestsFromTestCase(MyTest)
        >>> unittest.TextTestRunner(stream=StringIO()).run(suite)
        <unittest.runner.TextTestResult run=2 errors=0 failures=0>

        """
        tol = self.tol if tol is None else tol
        rel = self.rel if rel is None else rel
        both_sequences = (
            isinstance(first, collections.abc.Sequence)
            and isinstance(second, collections.abc.Sequence)
        )
        if both_sequences:
            self._check_approx_seq(first, second, tol, rel, msg)
        else:
            self._check_approx_num(first, second, tol, rel, msg)

    def _check_approx_seq(self, first, second, tol, rel, msg):
        """Compare two sequences item by item, failing at the first mismatch.

        Sequences of different lengths fail before any item is compared.
        """
        if len(first) != len(second):
            lengths = (len(first), len(second))
            standardMsg = (
                "sequences differ in length: %d items != %d items" % lengths
                )
            raise self.failureException(self._formatMessage(msg, standardMsg))
        for position, pair in enumerate(zip(first, second)):
            actual, expected = pair
            self._check_approx_num(actual, expected, tol, rel, msg, position)

    def _check_approx_num(self, first, second, tol, rel, msg, idx=None):
        """Fail if the two numbers are not approximately equal.

        ``idx`` is the position of the pair within a sequence comparison,
        or None when two plain numbers are being compared.
        """
        if not approx_equal(first, second, tol, rel):
            standardMsg = self._make_std_err_msg(first, second, tol, rel, idx)
            raise self.failureException(self._formatMessage(msg, standardMsg))

    @staticmethod
    def _make_std_err_msg(first, second, tol, rel, idx):
        """Build the standard error message for approx_equal failures."""
        assert first != second
        lines = [
            '  %r != %r',
            '  values differ by more than tol=%r and rel=%r',
            '  -> absolute error = %r',
            '  -> relative error = %r',
            ]
        if idx is not None:
            # Sequence comparisons also say where the mismatch was found.
            lines.insert(
                0, 'numeric sequences first differ at index %d.' % idx)
        # Calculate the errors actually observed.
        abs_err, rel_err = _calc_errors(first, second)
        return '\n'.join(lines) % (first, second, tol, rel, abs_err, rel_err)
290
291
292# ========================
293# === Test the helpers ===
294# ========================
295
class TestSign(unittest.TestCase):
    """Check the behaviour of the sign() helper function."""
    def testZeroes(self):
        # A signed zero must report its own sign.
        positive_zero, negative_zero = 0.0, -0.0
        self.assertEqual(sign(positive_zero), +1)
        self.assertEqual(sign(negative_zero), -1)
302
303
304# --- Tests for approx_equal ---
305
class ApproxEqualSymmetryTest(unittest.TestCase):
    # Test symmetry of approx_equal: swapping the two values being compared
    # must never change the outcome.

    def test_relative_symmetry(self):
        # Check that approx_equal treats relative error symmetrically.
        # (a-b)/a is usually not equal to (a-b)/b. Ensure that this
        # doesn't matter.
        #
        #   Note: the reason for this test is that an early version
        #   of approx_equal was not symmetric. A relative error test
        #   would pass, or fail, depending on which value was passed
        #   as the first argument.
        #
        args1 = [2456, 37.8, -12.45, Decimal('2.54'), Fraction(17, 54)]
        args2 = [2459, 37.2, -12.41, Decimal('2.59'), Fraction(15, 54)]
        assert len(args1) == len(args2)
        for a, b in zip(args1, args2):
            self.do_relative_symmetry(a, b)

    def do_relative_symmetry(self, a, b):
        # Helper: check that a and b compare approximately equal in both
        # argument orders, using a relative error that lies between the
        # error measured against a and the error measured against b.
        a, b = min(a, b), max(a, b)
        assert a < b
        delta = b - a  # The absolute difference between the values.
        rel_err1, rel_err2 = abs(delta/a), abs(delta/b)
        # Choose an error margin halfway between the two.
        rel = (rel_err1 + rel_err2)/2
        # Now see that values a and b compare approx equal regardless of
        # which is given first.
        self.assertTrue(approx_equal(a, b, tol=0, rel=rel))
        self.assertTrue(approx_equal(b, a, tol=0, rel=rel))

    def test_symmetry(self):
        # Test that approx_equal(a, b) == approx_equal(b, a)
        args = [-23, -2, 5, 107, 93568]
        delta = 2
        for a in args:
            for type_ in (int, float, Decimal, Fraction):
                x = type_(a)*100
                y = x + delta
                r = abs(delta/max(x, y))
                # There are five cases to check:
                # 1) actual error <= tol, <= rel
                self.do_symmetry_test(x, y, tol=delta, rel=r)
                self.do_symmetry_test(x, y, tol=delta+1, rel=2*r)
                # 2) actual error > tol, > rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r/2)
                # 3) actual error <= tol, > rel
                self.do_symmetry_test(x, y, tol=delta, rel=r/2)
                # 4) actual error > tol, <= rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r)
                self.do_symmetry_test(x, y, tol=delta-1, rel=2*r)
                # 5) exact equality test
                self.do_symmetry_test(x, x, tol=0, rel=0)
                self.do_symmetry_test(x, y, tol=0, rel=0)

    def do_symmetry_test(self, a, b, tol, rel):
        # Helper: fail unless approx_equal gives the same answer for (a, b)
        # as it does for (b, a).
        #
        # The template uses a printf-style placeholder, so it has to be
        # filled in with the % operator; str.format() would leave the "%r"
        # untouched and drop the arguments from the failure message. The
        # one-element tuple makes %r render (a, b, tol, rel) as a whole.
        template = "approx_equal comparisons don't match for %r"
        flag1 = approx_equal(a, b, tol, rel)
        flag2 = approx_equal(b, a, tol, rel)
        self.assertEqual(flag1, flag2, template % ((a, b, tol, rel),))
366
367
class ApproxEqualExactTest(unittest.TestCase):
    # Exercise approx_equal with values that are exactly equal. Such values
    # must compare approximately equal whatever error tolerances are given.

    def do_exactly_equal_test(self, x, tol, rel):
        # Compare the value, and then its negation, against itself.
        for value in (x, -x):
            outcome = approx_equal(value, value, tol=tol, rel=rel)
            self.assertTrue(outcome, 'equality failure for x=%r' % value)

    def test_exactly_equal_ints(self):
        # Equal ints, with zero tolerances.
        for value in (42, 19740, 14974, 230, 1795, 700245, 36587):
            self.do_exactly_equal_test(value, 0, 0)

    def test_exactly_equal_floats(self):
        # Equal floats, with zero tolerances.
        for value in (0.42, 1.9740, 1497.4, 23.0, 179.5, 70.0245, 36.587):
            self.do_exactly_equal_test(value, 0, 0)

    def test_exactly_equal_fractions(self):
        # Equal Fractions, with zero tolerances.
        pairs = [(1, 2), (0, 1), (5, 3), (9, 7), (35, 36), (3, 7)]
        for numerator, denominator in pairs:
            self.do_exactly_equal_test(Fraction(numerator, denominator), 0, 0)

    def test_exactly_equal_decimals(self):
        # Equal Decimals, with zero tolerances.
        for text in ("8.2", "31.274", "912.04", "16.745", "1.2047"):
            self.do_exactly_equal_test(Decimal(text), 0, 0)

    def test_exactly_equal_absolute(self):
        # Equal values, with an absolute error only. Each number is tried
        # as an int, then as a float, then as a Fraction.
        for n in (16, 1013, 1372, 1198, 971, 4):
            for value in (n, n/10, Fraction(n, 1234)):
                self.do_exactly_equal_test(value, 0.01, 0)

    def test_exactly_equal_absolute_decimals(self):
        # Equal Decimals, with a Decimal absolute error only.
        tolerance = Decimal("0.01")
        for value in (Decimal("3.571"), -Decimal("81.3971")):
            self.do_exactly_equal_test(value, tolerance, 0)

    def test_exactly_equal_relative(self):
        # Equal values, with a relative error only.
        for value in (8347, 101.3, -7910.28, Fraction(5, 21)):
            self.do_exactly_equal_test(value, 0, 0.01)
        self.do_exactly_equal_test(Decimal("11.68"), 0, Decimal("0.01"))

    def test_exactly_equal_both(self):
        # Equal values, with both an absolute and a relative error.
        for value in (41017, 16.742, -813.02, Fraction(3, 8)):
            self.do_exactly_equal_test(value, 0.1, 0.01)
        self.do_exactly_equal_test(
            Decimal("7.2"), Decimal("0.1"), Decimal("0.01"))
430
431
class ApproxEqualUnequalTest(unittest.TestCase):
    # With both error tolerances set to zero, approx_equal is an exact
    # equality test, so values that differ must compare unequal.

    def do_exactly_unequal_test(self, x):
        # Neither x nor -x may compare equal to the value one above it.
        for start in (x, -x):
            outcome = approx_equal(start, start+1, tol=0, rel=0)
            self.assertFalse(outcome, 'inequality failure for x=%r' % start)

    def test_exactly_unequal_ints(self):
        # Unequal ints, with zero error tolerance.
        for value in (951, 572305, 478, 917, 17240):
            self.do_exactly_unequal_test(value)

    def test_exactly_unequal_floats(self):
        # Unequal floats, with zero error tolerance.
        for value in (9.51, 5723.05, 47.8, 9.17, 17.24):
            self.do_exactly_unequal_test(value)

    def test_exactly_unequal_fractions(self):
        # Unequal Fractions, with zero error tolerance.
        for numerator, denominator in [(1, 5), (7, 9), (12, 11), (101, 99023)]:
            self.do_exactly_unequal_test(Fraction(numerator, denominator))

    def test_exactly_unequal_decimals(self):
        # Unequal Decimals, with zero error tolerance.
        for text in ("3.1415", "298.12", "3.47", "18.996", "0.00245"):
            self.do_exactly_unequal_test(Decimal(text))
461
462
class ApproxEqualInexactTest(unittest.TestCase):
    # Exercise approx_equal with pairs of values that are close to each
    # other without being exactly equal.

    # === Absolute error tests ===

    def do_approx_equal_abs_test(self, x, delta):
        # y sits delta away from x, first above and then below. A tolerance
        # of twice delta must accept the pair; half of delta must reject it.
        for y in (x + delta, x - delta):
            msg = f"Test failure for x={x!r}, y={y!r}"
            loose = approx_equal(x, y, tol=2*delta, rel=0)
            self.assertTrue(loose, msg)
            tight = approx_equal(x, y, tol=delta/2, rel=0)
            self.assertFalse(tight, msg)

    def test_approx_equal_absolute_ints(self):
        # Ints, compared with an absolute error.
        values = (-10737, -1975, -7, -2, 0, 1, 9, 37, 423, 9874, 23789110)
        for value in values:
            for delta in (10, 2):
                self.do_approx_equal_abs_test(value, delta)

    def test_approx_equal_absolute_floats(self):
        # Floats, compared with an absolute error.
        values = (-284.126, -97.1, -3.4, -2.15, 0.5, 1.0, 7.8, 4.23, 3817.4)
        for value in values:
            for delta in (1.5, 0.01, 0.0001):
                self.do_approx_equal_abs_test(value, delta)

    def test_approx_equal_absolute_fractions(self):
        # Fractions, compared with an absolute error given first as a
        # Fraction and then as a float.
        step = Fraction(1, 29)
        for numerator in (-84, -15, -2, -1, 0, 1, 5, 17, 23, 34, 71):
            value = Fraction(numerator, 29)
            for delta in (step, float(step)):
                self.do_approx_equal_abs_test(value, delta)

    def test_approx_equal_absolute_decimals(self):
        # Decimals, both positive and negated, compared with an absolute
        # error.
        delta = Decimal("0.01")
        for text in ("1.0", "3.5", "36.08", "61.79", "7912.3648"):
            value = Decimal(text)
            for signed in (value, -value):
                self.do_approx_equal_abs_test(signed, delta)

    def test_cross_zero(self):
        # The two values lie on opposite sides of zero.
        outcome = approx_equal(1e-5, -1e-5, tol=1e-4, rel=0)
        self.assertTrue(outcome)

    # === Relative error tests ===

    def do_approx_equal_rel_test(self, x, delta):
        # y is x scaled up, and then down, by the fraction delta. A relative
        # error of twice delta must accept the pair; half of delta must
        # reject it.
        for y in (x*(1+delta), x*(1-delta)):
            msg = f"Test failure for x={x!r}, y={y!r}"
            loose = approx_equal(x, y, tol=0, rel=2*delta)
            self.assertTrue(loose, msg)
            tight = approx_equal(x, y, tol=0, rel=delta/2)
            self.assertFalse(tight, msg)

    def test_approx_equal_relative_ints(self):
        # Ints, compared with a relative error.
        for rel in (0.36, 0.37):
            self.assertTrue(approx_equal(64, 47, tol=0, rel=rel))
        # ---
        for close in (449, 448):
            self.assertTrue(approx_equal(close, 512, tol=0, rel=0.125))
        self.assertFalse(approx_equal(447, 512, tol=0, rel=0.125))

    def test_approx_equal_relative_floats(self):
        # Floats, compared with a relative error.
        for value in (-178.34, -0.1, 0.1, 1.0, 36.97, 2847.136, 9145.074):
            for delta in (0.02, 0.0001):
                self.do_approx_equal_rel_test(value, delta)

    def test_approx_equal_relative_fractions(self):
        # Fractions, both positive and negated, compared with a relative
        # error given first as a Fraction and then as a float.
        step = Fraction(3, 8)
        for numerator, denominator in [(3, 84), (17, 30), (49, 50), (92, 85)]:
            value = Fraction(numerator, denominator)
            for delta in (step, float(step)):
                for signed in (value, -value):
                    self.do_approx_equal_rel_test(signed, delta)

    def test_approx_equal_relative_decimals(self):
        # Decimals, compared with a relative error. The negated value is
        # checked against a different error from the positive one.
        texts = ("0.02", "1.0", "5.7", "13.67", "94.138", "91027.9321")
        for text in texts:
            value = Decimal(text)
            self.do_approx_equal_rel_test(value, Decimal("0.001"))
            self.do_approx_equal_rel_test(-value, Decimal("0.05"))

    # === Both absolute and relative error tests ===

    # There are four cases to consider:
    #   1) actual error <= both absolute and relative error
    #   2) actual error <= absolute error but > relative error
    #   3) actual error <= relative error but > absolute error
    #   4) actual error > both absolute and relative error

    def do_check_both(self, a, b, tol, rel, tol_flag, rel_flag):
        # tol_flag and rel_flag say whether the absolute and the relative
        # test, each taken alone, should pass. With both tolerances given,
        # the comparison should pass if either one does.
        cases = (
            (tol_flag, {'tol': tol, 'rel': 0}),
            (rel_flag, {'tol': 0, 'rel': rel}),
            (tol_flag or rel_flag, {'tol': tol, 'rel': rel}),
            )
        for should_pass, tolerances in cases:
            outcome = approx_equal(a, b, **tolerances)
            if should_pass:
                self.assertTrue(outcome)
            else:
                self.assertFalse(outcome)

    def test_approx_equal_both1(self):
        # Case 1: actual error <= both absolute and relative error.
        for a, b, tol, rel in [
                (7.955, 7.952, 0.004, 3.8e-4),
                (-7.387, -7.386, 0.002, 0.0002),
                ]:
            self.do_check_both(a, b, tol, rel, True, True)

    def test_approx_equal_both2(self):
        # Case 2: actual error <= absolute error but > relative error.
        self.do_check_both(7.955, 7.952, 0.004, 3.7e-4, True, False)

    def test_approx_equal_both3(self):
        # Case 3: actual error <= relative error but > absolute error.
        self.do_check_both(7.955, 7.952, 0.001, 3.8e-4, False, True)

    def test_approx_equal_both4(self):
        # Case 4: actual error > both absolute and relative error.
        for a, b, tol, rel in [
                (2.78, 2.75, 0.01, 0.001),
                (971.44, 971.47, 0.02, 3e-5),
                ]:
            self.do_check_both(a, b, tol, rel, False, False)
580
581
class ApproxEqualSpecialsTest(unittest.TestCase):
    # Exercise approx_equal with the special values: infinities, NANs and
    # signed zeroes.

    def test_inf(self):
        # An infinity is approximately equal to itself, whatever the
        # tolerances, and to nothing else.
        for make in (float, Decimal):
            pos_inf = make('inf')
            neg_inf = -pos_inf
            for tolerances in ((), (0, 0), (1, 0.01)):
                self.assertTrue(approx_equal(pos_inf, pos_inf, *tolerances))
            self.assertTrue(approx_equal(neg_inf, neg_inf))
            for other in (neg_inf, 1000):
                self.assertFalse(approx_equal(pos_inf, other))

    def test_nan(self):
        # A NAN is not approximately equal to anything, itself included.
        for make in (float, Decimal):
            nan = make('nan')
            for other in (nan, make('inf'), 1000):
                outcome = approx_equal(nan, other)
                self.assertFalse(outcome)

    def test_float_zeroes(self):
        # Negative and positive float zero are approximately equal.
        negative_zero = math.copysign(0.0, -1)
        outcome = approx_equal(negative_zero, 0.0, tol=0.1, rel=0.1)
        self.assertTrue(outcome)

    def test_decimal_zeroes(self):
        # Negative and positive Decimal zero are approximately equal.
        negative_zero = Decimal("-0.0")
        outcome = approx_equal(negative_zero, Decimal(0), tol=0.1, rel=0.1)
        self.assertTrue(outcome)
608
609
class TestApproxEqualErrors(unittest.TestCase):
    # approx_equal must refuse error tolerances that are negative.

    def test_bad_tol(self):
        # A negative tol raises ValueError.
        with self.assertRaises(ValueError):
            approx_equal(100, 100, -1, 0.1)

    def test_bad_rel(self):
        # A negative rel raises ValueError.
        with self.assertRaises(ValueError):
            approx_equal(100, 100, 1, -0.1)
620
621
622# --- Tests for NumericTestCase ---
623
624# The formatting routine that generates the error messages is complex enough
625# that it too needs testing.
626
class TestNumericTestCase(unittest.TestCase):
    # The exact wording of the error messages produced by NumericTestCase
    # is *not* guaranteed, yet the code producing them still needs some
    # sort of test. As a compromise, we only look for specific substrings
    # that should survive even if the overall wording changes.

    def do_test(self, args):
        # Build the real message, then check it for each expected substring.
        message = NumericTestCase._make_std_err_msg(*args)
        for fragment in self.generate_substrings(*args):
            self.assertIn(fragment, message)

    def test_numerictestcase_is_testcase(self):
        # NumericTestCase must actually be a TestCase.
        is_testcase = issubclass(NumericTestCase, unittest.TestCase)
        self.assertTrue(is_testcase)

    def test_error_msg_numeric(self):
        # The message for a comparison of two plain numbers (no index).
        self.do_test((2.5, 4.0, 0.5, 0.25, None))

    def test_error_msg_sequence(self):
        # The message for a mismatch found inside a sequence, at index 7.
        self.do_test((3.75, 8.25, 1.25, 0.5, 7))

    def generate_substrings(self, first, second, tol, rel, idx):
        """Return substrings we expect to see in error messages."""
        abs_err, rel_err = _calc_errors(first, second)
        labelled = (
            ('tol=%r', tol),
            ('rel=%r', rel),
            ('absolute error = %r', abs_err),
            ('relative error = %r', rel_err),
            )
        substrings = [label % value for label, value in labelled]
        if idx is not None:
            # Only sequence comparisons mention an index.
            substrings.append('differ at index %d' % idx)
        return substrings
665
666
667# =======================================
668# === Tests for the statistics module ===
669# =======================================
670
671
class GlobalsTest(unittest.TestCase):
    # The module under test, and the metadata attributes it must define.
    module = statistics
    expected_metadata = ["__doc__", "__all__"]

    def test_meta(self):
        # Every expected metadata attribute must be present on the module.
        for attribute in self.expected_metadata:
            present = hasattr(self.module, attribute)
            self.assertTrue(present, "%s not present" % attribute)

    def test_check_all(self):
        # Everything listed in __all__ must be public and must exist.
        target = self.module
        for public_name in target.__all__:
            # Private names have no business being in __all__.
            is_private = public_name.startswith("_")
            self.assertFalse(is_private,
                             'private name "%s" in __all__' % public_name)
            # A listed name must really be defined by the module.
            is_defined = hasattr(target, public_name)
            self.assertTrue(is_defined,
                            'missing name "%s" in __all__' % public_name)
692
693
class DocTests(unittest.TestCase):
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -OO and above")
    def test_doc_tests(self):
        # Run the doctests of the statistics module. At least one example
        # must have been tried, and none of them may have failed.
        results = doctest.testmod(statistics, optionflags=doctest.ELLIPSIS)
        self.assertGreater(results.attempted, 0)
        self.assertEqual(results.failed, 0)
701
class StatisticsErrorTest(unittest.TestCase):
    def test_has_exception(self):
        # StatisticsError must exist, and must be a kind of ValueError.
        self.assertTrue(hasattr(statistics, 'StatisticsError'))
        exception = statistics.StatisticsError
        errmsg = (
                "Expected StatisticsError to be a ValueError, but got a"
                " subclass of %r instead."
                ) % exception.__base__
        self.assertTrue(issubclass(exception, ValueError), errmsg)
713
714
715# === Tests for private utility functions ===
716
class ExactRatioTest(unittest.TestCase):
    # Exercise the _exact_ratio utility with each kind of number.

    def test_int(self):
        # An int n is reported as the ratio n/1.
        for n in (-20, -3, 0, 5, 99, 10**20):
            expected = (n, 1)
            self.assertEqual(statistics._exact_ratio(n), expected)

    def test_fraction(self):
        # A Fraction reports its own numerator and denominator.
        denominator = 37
        for numerator in (-5, 1, 12, 38):
            value = Fraction(numerator, denominator)
            self.assertEqual(
                statistics._exact_ratio(value), (numerator, denominator))

    def test_float(self):
        # Two known ratios first, then one hundred random floats, each of
        # which must be reproduced exactly by dividing out its ratio.
        exact_ratio = statistics._exact_ratio
        for value, expected in ((0.125, (1, 8)), (1.125, (9, 8))):
            self.assertEqual(exact_ratio(value), expected)
        samples = [random.uniform(-100, 100) for _ in range(100)]
        for sample in samples:
            numerator, denominator = exact_ratio(sample)
            self.assertEqual(sample, numerator/denominator)

    def test_decimal(self):
        # Decimals are reduced to their lowest terms.
        known = (
            ("0.125", (1, 8)),
            ("12.345", (2469, 200)),
            ("-1.98", (-99, 50)),
            )
        for text, expected in known:
            self.assertEqual(statistics._exact_ratio(Decimal(text)), expected)

    def test_inf(self):
        # An infinity comes back paired with None, keeping both its value
        # and its exact type, subclasses included.
        class MyFloat(float):
            pass
        class MyDecimal(Decimal):
            pass
        positive = float("INF")
        for infinity in (positive, -positive):
            for kind in (float, MyFloat, Decimal, MyDecimal):
                value = kind(infinity)
                ratio = statistics._exact_ratio(value)
                self.assertEqual(ratio, (value, None))
                self.assertEqual(type(ratio[0]), kind)
                self.assertTrue(math.isinf(ratio[0]))

    def test_float_nan(self):
        # A float NAN comes back as a NAN of the same type, paired with
        # None.
        class MyFloat(float):
            pass
        plain = float("NAN")
        for value in (plain, MyFloat(plain)):
            ratio = statistics._exact_ratio(value)
            self.assertTrue(math.isnan(ratio[0]))
            self.assertIs(ratio[1], None)
            self.assertEqual(type(ratio[0]), type(value))

    def test_decimal_nan(self):
        # A Decimal NAN, quiet or signalling, comes back as the same kind
        # of NAN of the same type, paired with None.
        class MyDecimal(Decimal):
            pass
        quiet = Decimal("NAN")
        signalling = Decimal("sNAN")
        candidates = (
            quiet, MyDecimal(quiet), signalling, MyDecimal(signalling))
        for value in candidates:
            ratio = statistics._exact_ratio(value)
            self.assertTrue(_nan_equal(ratio[0], value))
            self.assertIs(ratio[1], None)
            self.assertEqual(type(ratio[0]), type(value))
779
780
781class DecimalToRatioTest(unittest.TestCase):
782    # Test _exact_ratio private function.
783
784    def test_infinity(self):
785        # Test that INFs are handled correctly.
786        inf = Decimal('INF')
787        self.assertEqual(statistics._exact_ratio(inf), (inf, None))
788        self.assertEqual(statistics._exact_ratio(-inf), (-inf, None))
789
790    def test_nan(self):
791        # Test that NANs are handled correctly.
792        for nan in (Decimal('NAN'), Decimal('sNAN')):
793            num, den = statistics._exact_ratio(nan)
794            # Because NANs always compare non-equal, we cannot use assertEqual.
795            # Nor can we use an identity test, as we don't guarantee anything
796            # about the object identity.
797            self.assertTrue(_nan_equal(num, nan))
798            self.assertIs(den, None)
799
800    def test_sign(self):
801        # Test sign is calculated correctly.
802        numbers = [Decimal("9.8765e12"), Decimal("9.8765e-12")]
803        for d in numbers:
804            # First test positive decimals.
805            assert d > 0
806            num, den = statistics._exact_ratio(d)
807            self.assertGreaterEqual(num, 0)
808            self.assertGreater(den, 0)
809            # Then test negative decimals.
810            num, den = statistics._exact_ratio(-d)
811            self.assertLessEqual(num, 0)
812            self.assertGreater(den, 0)
813
814    def test_negative_exponent(self):
815        # Test result when the exponent is negative.
816        t = statistics._exact_ratio(Decimal("0.1234"))
817        self.assertEqual(t, (617, 5000))
818
819    def test_positive_exponent(self):
820        # Test results when the exponent is positive.
821        t = statistics._exact_ratio(Decimal("1.234e7"))
822        self.assertEqual(t, (12340000, 1))
823
824    def test_regression_20536(self):
825        # Regression test for issue 20536.
826        # See http://bugs.python.org/issue20536
827        t = statistics._exact_ratio(Decimal("1e2"))
828        self.assertEqual(t, (100, 1))
829        t = statistics._exact_ratio(Decimal("1.47e5"))
830        self.assertEqual(t, (147000, 1))
831
832
class IsFiniteTest(unittest.TestCase):
    # Test the _isfinite private function.

    def test_finite(self):
        # Ordinary ints, Fractions, floats and Decimals are all finite.
        finite_values = (5, Fraction(1, 3), 2.5, Decimal("5.5"))
        for value in finite_values:
            self.assertTrue(statistics._isfinite(value))

    def test_infinity(self):
        # Float and Decimal infinities are not finite.
        infinities = (float("inf"), Decimal("inf"))
        for value in infinities:
            self.assertFalse(statistics._isfinite(value))

    def test_nan(self):
        # Float NANs and both kinds of Decimal NAN are not finite.
        nans = (float("nan"), Decimal("NAN"), Decimal("sNAN"))
        for value in nans:
            self.assertFalse(statistics._isfinite(value))
850
851
class CoerceTest(unittest.TestCase):
    # Test that private function _coerce correctly deals with types.

    # The coercion rules are currently an implementation detail, although at
    # some point that should change. The tests and comments here define the
    # correct implementation.

    # Pre-conditions of _coerce:
    #
    #   - coerce(T, S) will never be called with bool as the first argument;
    #     this is a pre-condition, guarded with an assertion.
    #
    #   - coerce(T, T) will always return T; we assume T is a valid numeric
    #     type. Violate this assumption at your own risk.
    #
    #   - Apart from as above, bool is treated as if it were actually int.
    #
    #   - coerce(int, X) and coerce(X, int) return X.

    def test_bool(self):
        # bool is somewhat special, due to the pre-condition that it is
        # never given as the first argument to _coerce, and that it cannot
        # be subclassed. So we test it specially.
        for T in (int, float, Fraction, Decimal):
            self.assertIs(statistics._coerce(T, bool), T)
            class MyClass(T): pass
            self.assertIs(statistics._coerce(MyClass, bool), MyClass)

    def assertCoerceTo(self, A, B):
        """Assert that type A coerces to B, in either argument order."""
        self.assertIs(statistics._coerce(A, B), B)
        self.assertIs(statistics._coerce(B, A), B)

    def check_coerce_to(self, A, B):
        """Checks that type A coerces to B, including subclasses."""
        # Assert that type A is coerced to B.
        self.assertCoerceTo(A, B)
        # Subclasses of A are also coerced to B.
        class SubclassOfA(A): pass
        self.assertCoerceTo(SubclassOfA, B)
        # A, and subclasses of A, are coerced to subclasses of B.
        class SubclassOfB(B): pass
        self.assertCoerceTo(A, SubclassOfB)
        self.assertCoerceTo(SubclassOfA, SubclassOfB)

    def assertCoerceRaises(self, A, B):
        """Assert that coercing A to B, or vice versa, raises TypeError."""
        # The two types must be passed as separate positional arguments.
        # Passing them as a single tuple makes the call fail with a
        # TypeError for the wrong number of arguments, so the assertion
        # could never fail regardless of what _coerce does.
        self.assertRaises(TypeError, statistics._coerce, A, B)
        self.assertRaises(TypeError, statistics._coerce, B, A)

    def check_type_coercions(self, T):
        """Check that type T coerces correctly with subclasses of itself."""
        assert T is not bool
        # Coercing a type with itself returns the same type.
        self.assertIs(statistics._coerce(T, T), T)
        # Coercing a type with a subclass of itself returns the subclass.
        class U(T): pass
        class V(T): pass
        class W(U): pass
        for typ in (U, V, W):
            self.assertCoerceTo(T, typ)
        self.assertCoerceTo(U, W)
        # Coercing two subclasses that aren't parent/child is an error.
        # int is exempt: ints give way to the other type (see the rules
        # above), so no TypeError is expected for unrelated int subclasses.
        # NOTE(review): the exemption for int *subclasses* reflects the
        # current _coerce implementation -- confirm it is intended.
        if T is not int:
            self.assertCoerceRaises(U, V)
            self.assertCoerceRaises(V, W)

    def test_int(self):
        # Check that int coerces correctly.
        self.check_type_coercions(int)
        for typ in (float, Fraction, Decimal):
            self.check_coerce_to(int, typ)

    def test_fraction(self):
        # Check that Fraction coerces correctly.
        self.check_type_coercions(Fraction)
        self.check_coerce_to(Fraction, float)

    def test_decimal(self):
        # Check that Decimal coerces correctly.
        self.check_type_coercions(Decimal)

    def test_float(self):
        # Check that float coerces correctly.
        self.check_type_coercions(float)

    def test_non_numeric_types(self):
        # int is deliberately not among the good types: coerce(int, X) and
        # coerce(X, int) return X for any X, so no TypeError is raised.
        for bad_type in (str, list, type(None), tuple, dict):
            for good_type in (float, Fraction, Decimal):
                self.assertCoerceRaises(good_type, bad_type)

    def test_incompatible_types(self):
        # Test that incompatible types raise.
        for T in (float, Fraction):
            class MySubclass(T): pass
            self.assertCoerceRaises(T, Decimal)
            self.assertCoerceRaises(MySubclass, Decimal)
950
951
952class ConvertTest(unittest.TestCase):
953    # Test private _convert function.
954
955    def check_exact_equal(self, x, y):
956        """Check that x equals y, and has the same type as well."""
957        self.assertEqual(x, y)
958        self.assertIs(type(x), type(y))
959
960    def test_int(self):
961        # Test conversions to int.
962        x = statistics._convert(Fraction(71), int)
963        self.check_exact_equal(x, 71)
964        class MyInt(int): pass
965        x = statistics._convert(Fraction(17), MyInt)
966        self.check_exact_equal(x, MyInt(17))
967
968    def test_fraction(self):
969        # Test conversions to Fraction.
970        x = statistics._convert(Fraction(95, 99), Fraction)
971        self.check_exact_equal(x, Fraction(95, 99))
972        class MyFraction(Fraction):
973            def __truediv__(self, other):
974                return self.__class__(super().__truediv__(other))
975        x = statistics._convert(Fraction(71, 13), MyFraction)
976        self.check_exact_equal(x, MyFraction(71, 13))
977
978    def test_float(self):
979        # Test conversions to float.
980        x = statistics._convert(Fraction(-1, 2), float)
981        self.check_exact_equal(x, -0.5)
982        class MyFloat(float):
983            def __truediv__(self, other):
984                return self.__class__(super().__truediv__(other))
985        x = statistics._convert(Fraction(9, 8), MyFloat)
986        self.check_exact_equal(x, MyFloat(1.125))
987
988    def test_decimal(self):
989        # Test conversions to Decimal.
990        x = statistics._convert(Fraction(1, 40), Decimal)
991        self.check_exact_equal(x, Decimal("0.025"))
992        class MyDecimal(Decimal):
993            def __truediv__(self, other):
994                return self.__class__(super().__truediv__(other))
995        x = statistics._convert(Fraction(-15, 16), MyDecimal)
996        self.check_exact_equal(x, MyDecimal("-0.9375"))
997
998    def test_inf(self):
999        for INF in (float('inf'), Decimal('inf')):
1000            for inf in (INF, -INF):
1001                x = statistics._convert(inf, type(inf))
1002                self.check_exact_equal(x, inf)
1003
1004    def test_nan(self):
1005        for nan in (float('nan'), Decimal('NAN'), Decimal('sNAN')):
1006            x = statistics._convert(nan, type(nan))
1007            self.assertTrue(_nan_equal(x, nan))
1008
1009    def test_invalid_input_type(self):
1010        with self.assertRaises(TypeError):
1011            statistics._convert(None, float)
1012
1013
class FailNegTest(unittest.TestCase):
    """Test _fail_neg private function."""

    def test_pass_through(self):
        # Test that non-negative values are passed through unchanged.
        values = [1, 2.0, Fraction(3), Decimal(4)]
        new = list(statistics._fail_neg(values))
        self.assertEqual(values, new)

    def test_negatives_raise(self):
        # Test that negatives raise an exception.
        for x in [1, 2.0, Fraction(3), Decimal(4)]:
            # subTest reports which value failed instead of stopping at
            # the first failure.
            with self.subTest(x=x):
                # Build the iterator outside assertRaises: the error is
                # expected when a value is requested, not at call time.
                it = statistics._fail_neg([-x])
                self.assertRaises(statistics.StatisticsError, next, it)

    def test_error_msg(self):
        # Test that a given error message is used.
        msg = "badness #%d" % random.randint(10000, 99999)
        with self.assertRaises(statistics.StatisticsError) as cm:
            next(statistics._fail_neg([-1], msg))
        self.assertEqual(cm.exception.args[0], msg)
1040
1041
class FindLteqTest(unittest.TestCase):
    # Test the _find_lteq private function.

    def test_invalid_input_values(self):
        # Each case is a list that does not contain the requested value.
        bad_cases = (
            ([], 1),
            ([1, 2], 3),
            ([1, 3], 2),
        )
        for a, x in bad_cases:
            with self.subTest(a=a, x=x):
                self.assertRaises(ValueError, statistics._find_lteq, a, x)

    def test_locate_successfully(self):
        # The expected index is that of the first occurrence of x in a.
        good_cases = (
            ([1, 1, 1, 2, 3], 1, 0),
            ([0, 1, 1, 1, 2, 3], 1, 1),
            ([1, 2, 3, 3, 3], 3, 2),
        )
        for a, x, expected_i in good_cases:
            with self.subTest(a=a, x=x):
                found = statistics._find_lteq(a, x)
                self.assertEqual(expected_i, found)
1063
1064
class FindRteqTest(unittest.TestCase):
    # Test _find_rteq private function.

    def test_invalid_input_values(self):
        # Each case is (a, l, x) for which _find_rteq must raise ValueError.
        for a, l, x in [
            ([1], 2, 1),
            ([1, 3], 0, 2)
        ]:
            # subTest identifies the failing case, matching the other
            # _find_lteq/_find_rteq tests.
            with self.subTest(a=a, l=l, x=x):
                with self.assertRaises(ValueError):
                    statistics._find_rteq(a, l, x)

    def test_locate_successfully(self):
        # The expected index is that of the last occurrence of x in a.
        for a, l, x, expected_i in [
            ([1, 1, 1, 2, 3], 0, 1, 2),
            ([0, 1, 1, 1, 2, 3], 0, 1, 3),
            ([1, 2, 3, 3, 3], 0, 3, 4)
        ]:
            with self.subTest(a=a, l=l, x=x):
                self.assertEqual(expected_i, statistics._find_rteq(a, l, x))
1084
1085
1086# === Tests for public functions ===
1087
class UnivariateCommonMixin:
    # Tests shared by most univariate functions that take a data argument.
    # The class mixing this in must set the function under test as self.func.

    def test_no_args(self):
        # Calling the function with no arguments at all must fail.
        self.assertRaises(TypeError, self.func)

    def test_empty_data(self):
        # An empty data argument (the first argument) must be rejected,
        # whether it is a list, a tuple or an exhausted iterator.
        empties = ([], (), iter([]))
        for empty in empties:
            self.assertRaises(statistics.StatisticsError, self.func, empty)

    def prepare_data(self):
        """Return a list of ints, in non-sorted order, for various tests."""
        values = list(range(10))
        # Keep shuffling until the order differs from sorted order.
        while values == sorted(values):
            random.shuffle(values)
        return values

    def test_no_inplace_modifications(self):
        # The function must leave its input data untouched.
        original = self.prepare_data()
        assert len(original) != 1  # Necessary to avoid infinite loop.
        assert original != sorted(original)
        snapshot = original[:]
        assert original is not snapshot
        self.func(original)
        self.assertListEqual(original, snapshot, "data has been modified")

    def test_order_doesnt_matter(self):
        # The result should not depend on the order of the data points.

        # CAUTION: due to floating point rounding errors, the result actually
        # may depend on the order. Consider this test representing an ideal.
        # To avoid this test failing, only test with exact values such as ints
        # or Fractions.
        values = [1, 2, 3, 3, 3, 4, 5, 6]*100
        before_shuffle = self.func(values)
        random.shuffle(values)
        after_shuffle = self.func(values)
        self.assertEqual(before_shuffle, after_shuffle)

    def test_type_of_data_collection(self):
        # The type of the iterable holding the data must not affect the
        # result.
        class MyList(list):
            pass
        class MyTuple(tuple):
            pass
        def generator(data):
            return (obj for obj in data)
        values = self.prepare_data()
        expected = self.func(values)
        containers = (list, tuple, iter, MyList, MyTuple, generator)
        for make_container in containers:
            self.assertEqual(self.func(make_container(values)), expected)

    def test_range_data(self):
        # range objects must give the same result as the equivalent list.
        values = range(20, 50, 3)
        from_list = self.func(list(values))
        self.assertEqual(self.func(values), from_list)

    def test_bad_arg_types(self):
        # The function must raise when given data of the wrong type.

        # Don't roll the following into a loop like this:
        #   for bad in list_of_bad:
        #       self.check_for_type_error(bad)
        #
        # Since assertRaises doesn't show the arguments that caused the test
        # failure, it is very difficult to debug these test failures when the
        # following are in a loop.
        self.check_for_type_error(None)
        self.check_for_type_error(23)
        self.check_for_type_error(42.0)
        self.check_for_type_error(object())

    def check_for_type_error(self, *args):
        # Helper: calling the function with these arguments is a TypeError.
        self.assertRaises(TypeError, self.func, *args)

    def test_type_of_data_element(self):
        # The type of the data elements must not affect the numeric result.
        # This is a weaker test than UnivariateTypeMixin.test_types_conserved,
        # because it checks the numeric result by equality, but not by type.
        class MyFloat(float):
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))

        ints = self.prepare_data()
        expected = self.func(ints)
        result_type = type(expected)
        for kind in (float, MyFloat, Decimal, Fraction):
            converted = [kind(x) for x in ints]
            # Cast back to the type of the int result before comparing.
            self.assertEqual(result_type(self.func(converted)), expected)
1185
1186
class UnivariateTypeMixin:
    """Mixin class for type-conserving functions.

    The test(s) here are for functions that conserve the type of the
    individual data points, e.g. the mean of a list of Fractions should
    itself be a Fraction.

    Tests that merely involve types do not all belong here; only those
    relying on the function returning the same type as its input data.
    """
    def prepare_types_for_conservation_test(self):
        """Return the types which are expected to be conserved."""
        class MyFloat(float):
            # Every arithmetic operation below casts its result back to
            # the subclass.
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__
            def __sub__(self, other):
                return type(self)(super().__sub__(other))
            def __rsub__(self, other):
                return type(self)(super().__rsub__(other))
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))
            def __rtruediv__(self, other):
                return type(self)(super().__rtruediv__(other))
            def __pow__(self, other):
                return type(self)(super().__pow__(other))
        return (float, Decimal, Fraction, MyFloat)

    def test_types_conserved(self):
        # The function must return the same type as its data points.
        # (Excludes mixed data types.) Only the type of the result is
        # checked, not its value.
        raw = self.prepare_data()
        conserved_types = self.prepare_types_for_conservation_test()
        for kind in conserved_types:
            typed_data = [kind(x) for x in raw]
            self.assertIs(type(self.func(typed_data)), kind)
1224
1225
class TestSumCommon(UnivariateCommonMixin, UnivariateTypeMixin):
    # Common test cases for statistics._sum() function.

    # This test suite looks only at the numeric value returned by _sum,
    # after conversion to the appropriate type.
    def setUp(self):
        def simplified_sum(*args):
            # _sum returns a (type, total, count) tuple; reduce it to the
            # total alone, converted to the reported type.
            T, value, n = statistics._sum(*args)
            # _convert turns a value into type T. _coerce expects two
            # types, so it is the wrong helper for a value.
            return statistics._convert(value, T)
        self.func = simplified_sum
1236
1237
class TestSum(NumericTestCase):
    # Test cases for statistics._sum() function.

    # These tests check the complete three-item tuple returned by _sum.

    def setUp(self):
        self.func = statistics._sum

    def test_empty_data(self):
        # Empty data of any kind gives an int-typed zero total and a
        # count of zero.
        expected = (int, Fraction(0), 0)
        for empty in ([], (), iter([])):
            self.assertEqual(self.func(empty), expected)

    def test_ints(self):
        values = [1, 5, 3, -4, -8, 20, 42, 1]
        self.assertEqual(self.func(values), (int, Fraction(60), 8))

    def test_floats(self):
        values = [0.25]*20
        self.assertEqual(self.func(values), (float, Fraction(5.0), 20))

    def test_fractions(self):
        values = [Fraction(1, 1000)]*500
        self.assertEqual(self.func(values), (Fraction, Fraction(1, 2), 500))

    def test_decimals(self):
        texts = ("0.001", "5.246", "1.702", "-0.025",
                 "3.974", "2.328", "4.617", "2.843")
        values = [Decimal(text) for text in texts]
        self.assertEqual(self.func(values), (Decimal, Decimal("20.686"), 8))

    def test_compare_with_math_fsum(self):
        # Compare with the math.fsum function.
        # Ideally we ought to get the exact same result, but sometimes
        # we differ by a very slight amount :-(
        values = [random.uniform(-100, 1000) for _ in range(1000)]
        total = self.func(values)[1]
        self.assertApproxEqual(float(total), math.fsum(values), rel=2e-16)

    def test_strings_fail(self):
        # Strings are rejected, as a second argument or as a data point.
        self.assertRaises(TypeError, self.func, [1, 2, 3], '999')
        self.assertRaises(TypeError, self.func, [1, 2, 3, '999'])

    def test_bytes_fail(self):
        # Bytes are rejected, as a second argument or as a data point.
        self.assertRaises(TypeError, self.func, [1, 2, 3], b'999')
        self.assertRaises(TypeError, self.func, [1, 2, 3, b'999'])

    def test_mixed_sum(self):
        # Mixed input types are not (currently) allowed.
        # Mixed types among the data points fail ...
        self.assertRaises(TypeError, self.func, [1, 2.0, Decimal(1)])
        # ... and so does a second argument of a mismatched type.
        self.assertRaises(TypeError, self.func, [1, 2.0], Decimal(1))
1294
1295
class SumTortureTest(NumericTestCase):
    def test_torture(self):
        # Tim Peters' torture test for sum, and variants of same.
        exact_expected = (float, Fraction(20000.0), 40000)
        for pattern in ([1, 1e100, 1, -1e100], [1e100, 1, 1, -1e100]):
            self.assertEqual(statistics._sum(pattern*10000), exact_expected)
        # The last variant is only checked approximately.
        kind, total, count = statistics._sum([1e-100, 1, 1e-100, -1]*10000)
        self.assertIs(kind, float)
        self.assertEqual(count, 40000)
        self.assertApproxEqual(float(total), 2.0e-96, rel=5e-16)
1307
1308
class SumSpecialValues(NumericTestCase):
    # Test that sum works correctly with IEEE-754 special values.

    def test_nan(self):
        # A NAN among the data gives a NAN total of the same type.
        for type_ in (float, Decimal):
            nan = type_('nan')
            result = statistics._sum([1, nan, 2])[1]
            self.assertIs(type(result), type_)
            self.assertTrue(math.isnan(result))

    def check_infinity(self, x, inf):
        """Check x is an infinity of the same type and sign as inf."""
        self.assertTrue(math.isinf(x))
        self.assertIs(type(x), type(inf))
        self.assertEqual(x > 0, inf > 0)
        # Use a real assertion method: a bare assert statement is skipped
        # when Python runs with -O.
        self.assertEqual(x, inf)

    def do_test_inf(self, inf):
        # Adding a single infinity gives infinity.
        result = statistics._sum([1, 2, inf, 3])[1]
        self.check_infinity(result, inf)
        # Adding two infinities of the same sign also gives infinity.
        result = statistics._sum([1, 2, inf, 3, inf, 4])[1]
        self.check_infinity(result, inf)

    def test_float_inf(self):
        inf = float('inf')
        # Named "multiplier" to avoid shadowing the module-level sign()
        # helper.
        for multiplier in (+1, -1):
            self.do_test_inf(multiplier*inf)

    def test_decimal_inf(self):
        inf = Decimal('inf')
        for multiplier in (+1, -1):
            self.do_test_inf(multiplier*inf)

    def test_float_mismatched_infs(self):
        # Test that adding two infinities of opposite sign gives a NAN.
        inf = float('inf')
        result = statistics._sum([1, 2, inf, 3, -inf, 4])[1]
        self.assertTrue(math.isnan(result))

    def test_decimal_extendedcontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign returns NAN.
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.ExtendedContext):
            self.assertTrue(math.isnan(statistics._sum(data)[1]))

    def test_decimal_basiccontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign raises InvalidOperation.
        # (Despite the "to_nan" in the name, under BasicContext the sum is
        # expected to raise rather than return a NAN.)
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.BasicContext):
            self.assertRaises(decimal.InvalidOperation, statistics._sum, data)

    def test_decimal_snan_raises(self):
        # Adding sNAN should raise InvalidOperation.
        sNAN = Decimal('sNAN')
        data = [1, sNAN, 2]
        self.assertRaises(decimal.InvalidOperation, statistics._sum, data)
1369
1370
1371# === Tests for averages ===
1372
class AverageMixin(UnivariateCommonMixin):
    # Mixin class with the tests common to all averages.

    def test_single_value(self):
        # The average of one data point is that data point.
        singles = (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28'))
        for value in singles:
            self.assertEqual(self.func([value]), value)

    def prepare_values_for_repeated_single_test(self):
        # Values used by test_repeated_single_value.
        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712'))

    def test_repeated_single_value(self):
        # The average of one value repeated many times is that value.
        repeat_counts = (2, 5, 10, 20)
        for x in self.prepare_values_for_repeated_single_test():
            for count in repeat_counts:
                with self.subTest(x=x, count=count):
                    self.assertEqual(self.func([x]*count), x)
1391
1392
class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    def setUp(self):
        # The function under test for the inherited and the local tests.
        self.func = statistics.mean

    def test_torture_pep(self):
        # "Torture Test" from PEP-450.
        values = [1e100, 1, 3, -1e100]
        self.assertEqual(self.func(values), 1)

    def test_ints(self):
        # Mean of ints, presented in random order.
        values = [0, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 7, 7, 7, 8, 9]
        random.shuffle(values)
        self.assertEqual(self.func(values), 4.8125)

    def test_floats(self):
        # Mean of floats, presented in random order.
        values = [17.25, 19.75, 20.0, 21.5, 21.75, 23.25, 25.125, 27.5]
        random.shuffle(values)
        self.assertEqual(self.func(values), 22.015625)

    def test_decimals(self):
        # Mean of Decimals, presented in random order.
        texts = ("1.634", "2.517", "3.912", "4.072", "5.813")
        values = [Decimal(text) for text in texts]
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal("3.5896"))

    def test_fractions(self):
        # Mean of the Fractions 1/2, 2/3, ... 7/8, in random order.
        values = [Fraction(n, n + 1) for n in range(1, 8)]
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(1479, 1960))

    def test_inf(self):
        # Mean of data containing one infinity is that infinity.
        finite = [1, 3, 5, 7, 9]  # Use only ints, to avoid TypeError later.
        for kind in (float, Decimal):
            for multiplier in (1, -1):
                inf = kind("inf")*multiplier
                result = self.func(finite + [inf])
                self.assertTrue(math.isinf(result))
                self.assertEqual(result, inf)

    def test_mismatched_infs(self):
        # Mean of data containing infinities of opposite sign is a NAN.
        values = [2, 4, 6, float('inf'), 1, 3, 5, float('-inf')]
        self.assertTrue(math.isnan(self.func(values)))

    def test_nan(self):
        # Mean of data containing a NAN is a NAN.
        finite = [1, 3, 5, 7, 9]  # Use only ints, to avoid TypeError later.
        for kind in (float, Decimal):
            nan = kind("nan")
            result = self.func(finite + [nan])
            self.assertTrue(math.isnan(result))

    def test_big_data(self):
        # Adding a large constant to every data point shifts the mean by
        # that constant.
        offset = 1e9
        values = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
        expected = self.func(values) + offset
        assert expected != offset
        shifted = [x+offset for x in values]
        self.assertEqual(self.func(shifted), expected)

    def test_doubled_data(self):
        # Mean of [a,b,c...z] should be same as for [a,a,b,b,c,c...z,z].
        values = [random.uniform(-3, 5) for _ in range(1000)]
        single = self.func(values)
        doubled = self.func(values*2)
        self.assertApproxEqual(doubled, single)

    def test_regression_20561(self):
        # Regression test for issue 20561.
        # See http://bugs.python.org/issue20561
        value = Decimal('1e4')
        self.assertEqual(statistics.mean([value]), value)

    def test_regression_25177(self):
        # Regression test for issue 25177.
        # Ensure very big and very small floats don't overflow.
        # See http://bugs.python.org/issue25177.
        near_max_pair = [8.988465674311579e+307, 8.98846567431158e+307]
        self.assertEqual(statistics.mean(near_max_pair),
                         8.98846567431158e+307)
        big = 8.98846567431158e+307
        tiny = 5e-324
        for repeats in (2, 3, 5, 200):
            self.assertEqual(statistics.mean([big]*repeats), big)
            self.assertEqual(statistics.mean([tiny]*repeats), tiny)
1487
1488
1489class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
1490    def setUp(self):
1491        self.func = statistics.harmonic_mean
1492
1493    def prepare_data(self):
1494        # Override mixin method.
1495        values = super().prepare_data()
1496        values.remove(0)
1497        return values
1498
1499    def prepare_values_for_repeated_single_test(self):
1500        # Override mixin method.
1501        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.125'))
1502
1503    def test_zero(self):
1504        # Test that harmonic mean returns zero when given zero.
1505        values = [1, 0, 2]
1506        self.assertEqual(self.func(values), 0)
1507
1508    def test_negative_error(self):
1509        # Test that harmonic mean raises when given a negative value.
1510        exc = statistics.StatisticsError
1511        for values in ([-1], [1, -2, 3]):
1512            with self.subTest(values=values):
1513                self.assertRaises(exc, self.func, values)
1514
1515    def test_invalid_type_error(self):
1516        # Test error is raised when input contains invalid type(s)
1517        for data in [
1518            ['3.14'],               # single string
1519            ['1', '2', '3'],        # multiple strings
1520            [1, '2', 3, '4', 5],    # mixed strings and valid integers
1521            [2.3, 3.4, 4.5, '5.6']  # only one string and valid floats
1522        ]:
1523            with self.subTest(data=data):
1524                with self.assertRaises(TypeError):
1525                    self.func(data)
1526
1527    def test_ints(self):
1528        # Test harmonic mean with ints.
1529        data = [2, 4, 4, 8, 16, 16]
1530        random.shuffle(data)
1531        self.assertEqual(self.func(data), 6*4/5)
1532
1533    def test_floats_exact(self):
1534        # Test harmonic mean with some carefully chosen floats.
1535        data = [1/8, 1/4, 1/4, 1/2, 1/2]
1536        random.shuffle(data)
1537        self.assertEqual(self.func(data), 1/4)
1538        self.assertEqual(self.func([0.25, 0.5, 1.0, 1.0]), 0.5)
1539
1540    def test_singleton_lists(self):
1541        # Test that harmonic mean([x]) returns (approximately) x.
1542        for x in range(1, 101):
1543            self.assertEqual(self.func([x]), x)
1544
1545    def test_decimals_exact(self):
1546        # Test harmonic mean with some carefully chosen Decimals.
1547        D = Decimal
1548        self.assertEqual(self.func([D(15), D(30), D(60), D(60)]), D(30))
1549        data = [D("0.05"), D("0.10"), D("0.20"), D("0.20")]
1550        random.shuffle(data)
1551        self.assertEqual(self.func(data), D("0.10"))
1552        data = [D("1.68"), D("0.32"), D("5.94"), D("2.75")]
1553        random.shuffle(data)
1554        self.assertEqual(self.func(data), D(66528)/70723)
1555
1556    def test_fractions(self):
1557        # Test harmonic mean with Fractions.
1558        F = Fraction
1559        data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)]
1560        random.shuffle(data)
1561        self.assertEqual(self.func(data), F(7*420, 4029))
1562
1563    def test_inf(self):
1564        # Test harmonic mean with infinity.
1565        values = [2.0, float('inf'), 1.0]
1566        self.assertEqual(self.func(values), 2.0)
1567
1568    def test_nan(self):
1569        # Test harmonic mean with NANs.
1570        values = [2.0, float('nan'), 1.0]
1571        self.assertTrue(math.isnan(self.func(values)))
1572
1573    def test_multiply_data_points(self):
1574        # Test multiplying every data point by a constant.
1575        c = 111
1576        data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
1577        expected = self.func(data)*c
1578        result = self.func([x*c for x in data])
1579        self.assertEqual(result, expected)
1580
1581    def test_doubled_data(self):
1582        # Harmonic mean of [a,b...z] should be same as for [a,a,b,b...z,z].
1583        data = [random.uniform(1, 5) for _ in range(1000)]
1584        expected = self.func(data)
1585        actual = self.func(data*2)
1586        self.assertApproxEqual(actual, expected)
1587
1588    def test_with_weights(self):
1589        self.assertEqual(self.func([40, 60], [5, 30]), 56.0)  # common case
1590        self.assertEqual(self.func([40, 60],
1591                                   weights=[5, 30]), 56.0)    # keyword argument
1592        self.assertEqual(self.func(iter([40, 60]),
1593                                   iter([5, 30])), 56.0)      # iterator inputs
1594        self.assertEqual(
1595            self.func([Fraction(10, 3), Fraction(23, 5), Fraction(7, 2)], [5, 2, 10]),
1596            self.func([Fraction(10, 3)] * 5 +
1597                      [Fraction(23, 5)] * 2 +
1598                      [Fraction(7, 2)] * 10))
1599        self.assertEqual(self.func([10], [7]), 10)            # n=1 fast path
1600        with self.assertRaises(TypeError):
1601            self.func([1, 2, 3], [1, (), 3])                  # non-numeric weight
1602        with self.assertRaises(statistics.StatisticsError):
1603            self.func([1, 2, 3], [1, 2])                      # wrong number of weights
1604        with self.assertRaises(statistics.StatisticsError):
1605            self.func([10], [0])                              # no non-zero weights
1606        with self.assertRaises(statistics.StatisticsError):
1607            self.func([10, 20], [0, 0])                       # no non-zero weights
1608
1609
class TestMedian(NumericTestCase, AverageMixin):
    # Tests for median, shared with the median_* test classes that
    # subclass this one and install their own function in setUp.
    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        """Overload method from UnivariateCommonMixin."""
        values = super().prepare_data()
        # Pad with one extra point when needed so the count is always odd.
        if len(values) % 2 == 0:
            values.append(2)
        return values

    def test_even_ints(self):
        # Median of an even number of ints.
        values = [1, 2, 3, 4, 5, 6]
        assert len(values) % 2 == 0
        self.assertEqual(self.func(values), 3.5)

    def test_odd_ints(self):
        # Median of an odd number of ints.
        values = [1, 2, 3, 4, 5, 6, 9]
        assert len(values) % 2 == 1
        self.assertEqual(self.func(values), 4)

    def test_odd_fractions(self):
        # Median of an odd number of Fractions: 1/7 through 5/7, shuffled.
        values = [Fraction(k, 7) for k in range(1, 6)]
        assert len(values) % 2 == 1
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(3, 7))

    def test_even_fractions(self):
        # Median of an even number of Fractions: 1/7 through 6/7, shuffled.
        values = [Fraction(k, 7) for k in range(1, 7)]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(1, 2))

    def test_odd_decimals(self):
        # Median of an odd number of Decimals.
        values = [Decimal(s) for s in ('2.5', '3.1', '4.2', '5.7', '5.8')]
        assert len(values) % 2 == 1
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('4.2'))

    def test_even_decimals(self):
        # Median of an even number of Decimals.
        values = [Decimal(s)
                  for s in ('1.2', '2.5', '3.1', '4.2', '5.7', '5.8')]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('3.65'))
1665
1666
class TestMedianDataType(NumericTestCase, UnivariateTypeMixin):
    # Check that median conserves the type of the data elements.
    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        # Fifteen distinct ints (an odd count), reshuffled until they are
        # no longer in sorted order.
        values = [*range(15)]
        assert len(values) % 2 == 1
        in_order = sorted(values)
        while values == in_order:
            random.shuffle(values)
        return values
1678
1679
class TestMedianLow(TestMedian, UnivariateTypeMixin):
    # Reuses the TestMedian tests with median_low, overriding the
    # even-length cases whose expected values differ.
    def setUp(self):
        self.func = statistics.median_low

    def test_even_ints(self):
        # median_low of an even number of ints.
        values = [1, 2, 3, 4, 5, 6]
        assert len(values) % 2 == 0
        self.assertEqual(self.func(values), 3)

    def test_even_fractions(self):
        # median_low of an even number of Fractions: 1/7 through 6/7.
        values = [Fraction(k, 7) for k in range(1, 7)]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(3, 7))

    def test_even_decimals(self):
        # median_low of an even number of Decimals.
        values = [Decimal(s)
                  for s in ('1.1', '2.2', '3.3', '4.4', '5.5', '6.6')]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('3.3'))
1705
1706
class TestMedianHigh(TestMedian, UnivariateTypeMixin):
    # Reuses the TestMedian tests with median_high, overriding the
    # even-length cases whose expected values differ.
    def setUp(self):
        self.func = statistics.median_high

    def test_even_ints(self):
        # median_high of an even number of ints.
        values = [1, 2, 3, 4, 5, 6]
        assert len(values) % 2 == 0
        self.assertEqual(self.func(values), 4)

    def test_even_fractions(self):
        # median_high of an even number of Fractions: 1/7 through 6/7.
        values = [Fraction(k, 7) for k in range(1, 7)]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(4, 7))

    def test_even_decimals(self):
        # median_high of an even number of Decimals.
        values = [Decimal(s)
                  for s in ('1.1', '2.2', '3.3', '4.4', '5.5', '6.6')]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal('4.4'))
1732
1733
class TestMedianGrouped(TestMedian):
    # Tests for median_grouped.
    # It doesn't conserve data element types, so the data-type tests are
    # deliberately not mixed in here.
    def setUp(self):
        self.func = statistics.median_grouped

    def test_odd_number_repeated(self):
        # median_grouped on odd-length data with a repeated median value.
        sample = [12, 13, 14, 14, 14, 15, 15]
        assert len(sample) % 2 == 1
        self.assertEqual(self.func(sample), 14)

        sample = [12, 13, 14, 14, 14, 14, 15]
        assert len(sample) % 2 == 1
        self.assertEqual(self.func(sample), 13.875)

        # Explicit interval of 5.
        sample = [5, 10, 10, 15, 20, 20, 20, 20, 25, 25, 30]
        assert len(sample) % 2 == 1
        self.assertEqual(self.func(sample, 5), 19.375)

        # Explicit interval of 2; the expected value is only approximate.
        sample = [16, 18, 18, 18, 18, 20, 20, 20, 22, 22, 22, 24, 24, 26, 28]
        assert len(sample) % 2 == 1
        self.assertApproxEqual(self.func(sample, 2), 20.66666667, tol=1e-8)

    def test_even_number_repeated(self):
        # median_grouped on even-length data with a repeated median value.
        # Explicit interval of 5; the expected value is only approximate.
        sample = [5, 10, 10, 15, 20, 20, 20, 25, 25, 30]
        assert len(sample) % 2 == 0
        self.assertApproxEqual(self.func(sample, 5), 19.16666667, tol=1e-8)

        sample = [2, 3, 4, 4, 4, 5]
        assert len(sample) % 2 == 0
        self.assertApproxEqual(self.func(sample), 3.83333333, tol=1e-8)

        sample = [2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
        assert len(sample) % 2 == 0
        self.assertEqual(self.func(sample), 4.5)

        sample = [3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
        assert len(sample) % 2 == 0
        self.assertEqual(self.func(sample), 4.75)

    def test_repeated_single_value(self):
        # Override method from AverageMixin.
        # median_grouped does not conserve the data type, so the result is
        # compared against float(x) rather than x itself.
        singles = (5.3, 68, 4.3e17, Fraction(29, 101), Decimal('32.9714'))
        repeats = (2, 5, 10, 20)
        for value in singles:
            for times in repeats:
                self.assertEqual(self.func([value]*times), float(value))

    def test_odd_fractions(self):
        # median_grouped of an odd number of Fractions.
        quarters = [Fraction(n, 4) for n in (5, 9, 13, 13, 17)]
        assert len(quarters) % 2 == 1
        random.shuffle(quarters)
        self.assertEqual(self.func(quarters), 3.0)

    def test_even_fractions(self):
        # median_grouped of an even number of Fractions.
        quarters = [Fraction(n, 4) for n in (5, 9, 13, 13, 17, 17)]
        assert len(quarters) % 2 == 0
        random.shuffle(quarters)
        self.assertEqual(self.func(quarters), 3.25)

    def test_odd_decimals(self):
        # median_grouped of an odd number of Decimals.
        values = [Decimal(s) for s in ('5.5', '6.5', '6.5', '7.5', '8.5')]
        assert len(values) % 2 == 1
        random.shuffle(values)
        self.assertEqual(self.func(values), 6.75)

    def test_even_decimals(self):
        # median_grouped of an even number of Decimals.
        values = [Decimal(s)
                  for s in ('5.5', '5.5', '6.5', '6.5', '7.5', '8.5')]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), 6.5)

        values = [Decimal(s)
                  for s in ('5.5', '5.5', '6.5', '7.5', '7.5', '8.5')]
        assert len(values) % 2 == 0
        random.shuffle(values)
        self.assertEqual(self.func(values), 7.0)

    def test_interval(self):
        # median_grouped with an explicit interval argument.
        quarter_steps = [2.25, 2.5, 2.5, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
        self.assertEqual(self.func(quarter_steps, 0.25), 2.875)

        quarter_steps = [2.25, 2.5, 2.5, 2.75, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
        self.assertApproxEqual(self.func(quarter_steps, 0.25), 2.83333333, tol=1e-8)

        twenties = [220, 220, 240, 260, 260, 260, 260, 280, 280, 300, 320, 340]
        self.assertEqual(self.func(twenties, 20), 265.0)

    def test_data_type_error(self):
        # str or bytes must raise TypeError, whether passed as the data
        # points or as the interval.
        for bad_data in (["", "", ""], [b"", b"", b""]):
            with self.assertRaises(TypeError):
                self.func(bad_data)

        for bad_interval in ("", b""):
            with self.assertRaises(TypeError):
                self.func([1, 2, 3], bad_interval)
1846
1847
class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for mode, the discrete version.
    def setUp(self):
        self.func = statistics.mode

    def prepare_data(self):
        """Overload method from UnivariateCommonMixin."""
        # The value 1 occurs four times and every other value once, so
        # the data has exactly one mode.
        unimodal = [1, 1, 1, 1, 3, 4, 7, 9, 0, 8, 2]
        return unimodal

    def test_range_data(self):
        # Override test from UnivariateCommonMixin.
        # Every value in a range is unique, so the first one is expected.
        values = range(20, 50, 3)
        self.assertEqual(self.func(values), 20)

    def test_nominal_data(self):
        # mode also works with nominal (non-numeric) data.
        letters = 'abcbdb'
        self.assertEqual(self.func(letters), 'b')
        words = 'fe fi fo fum fi fi'.split()
        self.assertEqual(self.func(words), 'fi')

    def test_discrete_data(self):
        # Duplicate each of 0..9 in turn; the duplicate must be the mode
        # wherever shuffling puts it.
        base = list(range(10))
        for extra in range(10):
            sample = base + [extra]
            random.shuffle(sample)
            self.assertEqual(self.func(sample), extra)

    def test_bimodal_data(self):
        # Two values tie for the highest count.
        values = [1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, 9, 9]
        assert values.count(2) == values.count(6) == 4
        # mode() should return 2, the first encountered mode
        self.assertEqual(self.func(values), 2)

    def test_unique_data(self):
        # Every data point is unique.
        values = list(range(10))
        # mode() should return 0, the first encountered mode
        self.assertEqual(self.func(values), 0)

    def test_none_data(self):
        # mode must raise TypeError when given None as data.

        # This test is necessary because the implementation of mode uses
        # collections.Counter, which accepts None and returns an empty dict.
        with self.assertRaises(TypeError):
            self.func(None)

    def test_counter_data(self):
        # A Counter must be treated like any other iterable, which means
        # mode() has to call iter() on its input first.  The concern is
        # that a Counter of a Counter returns the original unchanged
        # rather than counting its keys.
        counts = collections.Counter(a=1, b=2)
        # If iter() is called, mode(counts) loops over the keys,
        # ['a', 'b'], all the counts will be 1, and the first encountered
        # mode is 'a'.
        self.assertEqual(self.func(counts), 'a')
1907
1908
class TestMultiMode(unittest.TestCase):

    def test_basics(self):
        # A single mode, several tied modes, and empty input.
        cases = [
            ('aabbbbbbbbcc', ['b']),
            ('aabbbbccddddeeffffgg', ['b', 'd', 'f']),
            ('', []),
        ]
        for data, expected in cases:
            self.assertEqual(statistics.multimode(data), expected)
1916
1917
class TestFMean(unittest.TestCase):

    def test_basics(self):
        # Each kind of input must give the expected mean, and the result
        # must be exactly a float.
        cases = [
            ('floats', [3.5, 4.0, 5.25], 4.25),
            ('decimals', [Decimal('3.5'), Decimal('4.0'), Decimal('5.25')], 4.25),
            ('fractions', [Fraction(7, 2), Fraction(4, 1), Fraction(21, 4)], 4.25),
            ('booleans', [True, False, True, True, False], 0.60),
            ('mixed types', [3.5, 4, Fraction(21, 4)], 4.25),
            ('tuple', (3.5, 4.0, 5.25), 4.25),
            ('iterator', iter([3.5, 4.0, 5.25]), 4.25),
        ]
        for kind, data, expected_mean in cases:
            actual_mean = statistics.fmean(data)
            self.assertIs(type(actual_mean), float, kind)
            self.assertEqual(actual_mean, expected_mean, kind)

    def test_error_cases(self):
        fmean = statistics.fmean
        StatisticsError = statistics.StatisticsError
        # Empty data raises StatisticsError.
        self.assertRaises(StatisticsError, fmean, [])           # empty input
        self.assertRaises(StatisticsError, fmean, iter([]))     # empty iterator
        # Unusable arguments raise TypeError.
        self.assertRaises(TypeError, fmean, None)               # non-iterable input
        self.assertRaises(TypeError, fmean, [10, None, 20])     # non-numeric input
        self.assertRaises(TypeError, fmean)                     # missing data argument
        self.assertRaises(TypeError, fmean, [10, 20, 60], 70)   # too many arguments

    def test_special_values(self):
        # Rules for special values are inherited from math.fsum()
        fmean = statistics.fmean
        nan, inf = float('Nan'), float('Inf')
        with_nan = fmean([10, nan])
        self.assertTrue(math.isnan(with_nan), 'nan')
        nan_and_inf = fmean([nan, inf])
        self.assertTrue(math.isnan(nan_and_inf), 'nan and infinity')
        with_inf = fmean([10, inf])
        self.assertTrue(math.isinf(with_inf), 'infinity')
        # Infinities of opposite sign must raise ValueError.
        self.assertRaises(ValueError, fmean, [inf, -inf])
1963
1964
1965# === Tests for variances and standard deviations ===
1966
class VarianceStdevMixin(UnivariateCommonMixin):
    # Mixin with the tests common to variance and standard deviation.

    # Subclasses should list this mixin before NumericTestCase in their
    # bases so that the rel attribute below is the one they see.  See
    # test_shift_data for why a tolerance is needed.

    rel = 1e-12

    def test_single_value(self):
        # A single value has zero deviation.
        singles = (11, 19.8, 4.6e14, Fraction(21, 34), Decimal('8.392'))
        for value in singles:
            self.assertEqual(self.func([value]), 0)

    def test_repeated_single_value(self):
        # One value repeated any number of times has zero deviation.
        singles = (7.2, 49, 8.1e15, Fraction(3, 7), Decimal('62.4802'))
        repeats = (2, 3, 5, 15)
        for value in singles:
            for times in repeats:
                self.assertEqual(self.func([value]*times), 0)

    def test_domain_error_regression(self):
        # Regression test for a domain error exception.
        # (Thanks to Geremy Condra.)
        identical = [0.123456789012345]*10000
        # With every item identical the variance should be exactly zero.
        # A little round-off error is tolerated, but not much.
        outcome = self.func(identical)
        self.assertApproxEqual(outcome, 0.0, tol=5e-17)
        self.assertGreaterEqual(outcome, 0)  # A negative result must fail.

    def test_shift_data(self):
        # Adding a constant to every data point should not change the
        # variance or stdev. Or at least not much.

        # Because of rounding this is an ideal rather than a guarantee.
        # Some tolerance away from "no change at all" is allowed through
        # the tol and/or rel attributes; subclasses may set tighter or
        # looser error tolerances.
        raw = [1.03, 1.27, 1.94, 2.04, 2.58, 3.14, 4.75, 4.98, 5.42, 6.78]
        unshifted = self.func(raw)
        # Keep the offset modest: the bigger it is, the more rounding error.
        offset = 1e5
        shifted = [x + offset for x in raw]
        self.assertApproxEqual(self.func(shifted), unshifted)

    def test_shift_data_exact(self):
        # Like test_shift_data, but result is always exact.
        raw = [1, 3, 3, 4, 5, 7, 9, 10, 11, 16]
        assert all(x == int(x) for x in raw)
        unshifted = self.func(raw)
        offset = 10**9
        shifted = [x + offset for x in raw]
        self.assertEqual(self.func(shifted), unshifted)

    def test_iter_list_same(self):
        # Iterator data and list data must give the same result.

        # This is tested explicitly, over and above the similar test in
        # UnivariateCommonMixin, because an earlier design had variance
        # and friends swap between one- and two-pass algorithms, which
        # would sometimes give different results.
        sample = [random.uniform(-3, 8) for _ in range(1000)]
        from_list = self.func(sample)
        self.assertEqual(self.func(iter(sample)), from_list)
2031
2032
class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for pvariance, the population variance.
    def setUp(self):
        self.func = statistics.pvariance

    def test_exact_uniform(self):
        # Compare against the exact variance of the ints 0 through 9999.
        n = 10000
        values = list(range(n))
        random.shuffle(values)
        exact = (n**2 - 1)/12  # Exact value.
        self.assertEqual(self.func(values), exact)

    def test_ints(self):
        # Population variance of int data.
        self.assertEqual(self.func([4, 7, 13, 16]), 22.5)

    def test_fractions(self):
        # Population variance of Fraction data; the result must itself
        # be a Fraction.
        quarters = [Fraction(n, 4) for n in (1, 1, 3, 7)]
        outcome = self.func(quarters)
        self.assertEqual(outcome, Fraction(3, 8))
        self.assertIsInstance(outcome, Fraction)

    def test_decimals(self):
        # Population variance of Decimal data; the result must itself
        # be a Decimal.
        values = [Decimal(s) for s in ("12.1", "12.2", "12.5", "12.9")]
        outcome = self.func(values)
        self.assertEqual(outcome, Decimal('0.096875'))
        self.assertIsInstance(outcome, Decimal)

    def test_accuracy_bug_20499(self):
        # Int data must give the float 2/9 exactly.
        outcome = self.func([0, 0, 1])
        self.assertEqual(outcome, 2 / 9)
        self.assertIsInstance(outcome, float)
2075
2076
class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for variance, the sample variance.
    def setUp(self):
        self.func = statistics.variance

    def test_single_value(self):
        # Override method from VarianceStdevMixin.
        # A single data point must raise StatisticsError.
        singles = (35, 24.7, 8.2e15, Fraction(19, 30), Decimal('4.2084'))
        for value in singles:
            with self.assertRaises(statistics.StatisticsError):
                self.func([value])

    def test_ints(self):
        # Sample variance of int data.
        self.assertEqual(self.func([4, 7, 13, 16]), 30)

    def test_fractions(self):
        # Sample variance of Fraction data; the result must itself be a
        # Fraction.
        quarters = [Fraction(n, 4) for n in (1, 1, 3, 7)]
        outcome = self.func(quarters)
        self.assertEqual(outcome, Fraction(1, 2))
        self.assertIsInstance(outcome, Fraction)

    def test_decimals(self):
        # Sample variance of Decimal data; the result must itself be a
        # Decimal.
        values = [Decimal(n) for n in (2, 2, 7, 9)]
        exact = 4*Decimal('9.5')/Decimal(3)
        outcome = self.func(values)
        self.assertEqual(outcome, exact)
        self.assertIsInstance(outcome, Decimal)

    def test_center_not_at_mean(self):
        # An explicit xbar that is not the mean changes the result.
        pair = (1.0, 2.0)
        self.assertEqual(self.func(pair), 0.5)
        self.assertEqual(self.func(pair, xbar=2.0), 1.0)

    def test_accuracy_bug_20499(self):
        # Int data must give the float 4/3 exactly.
        outcome = self.func([0, 0, 2])
        self.assertEqual(outcome, 4 / 3)
        self.assertIsInstance(outcome, float)
2122
class TestPStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for pstdev, the population standard deviation.
    def setUp(self):
        self.func = statistics.pstdev

    def test_compare_to_variance(self):
        # pstdev must equal the square root of pvariance on the same data.
        sample = [random.uniform(-17, 24) for _ in range(1000)]
        root_of_variance = math.sqrt(statistics.pvariance(sample))
        self.assertEqual(self.func(sample), root_of_variance)

    def test_center_not_at_mean(self):
        # See issue: 40855
        # An explicit mu that is not the mean changes the result.
        values = (3, 6, 7, 10)
        self.assertEqual(self.func(values), 2.5)
        self.assertEqual(self.func(values, mu=0.5), 6.5)
2139
class TestStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for stdev, the sample standard deviation.
    def setUp(self):
        self.func = statistics.stdev

    def test_single_value(self):
        # Override method from VarianceStdevMixin.
        # A single data point must raise StatisticsError.
        singles = (81, 203.74, 3.9e14, Fraction(5, 21), Decimal('35.719'))
        for value in singles:
            with self.assertRaises(statistics.StatisticsError):
                self.func([value])

    def test_compare_to_variance(self):
        # stdev must equal the square root of variance on the same data.
        sample = [random.uniform(-2, 9) for _ in range(1000)]
        root_of_variance = math.sqrt(statistics.variance(sample))
        self.assertEqual(self.func(sample), root_of_variance)

    def test_center_not_at_mean(self):
        # An explicit xbar that is not the mean is used as the centre.
        pair = (1.0, 2.0)
        self.assertEqual(self.func(pair, xbar=2.0), 1.0)
2159
class TestGeometricMean(unittest.TestCase):

    def test_basics(self):
        gmean = statistics.geometric_mean
        self.assertAlmostEqual(gmean([54, 24, 36]), 36.0)
        self.assertAlmostEqual(gmean([4.0, 9.0]), 6.0)
        self.assertAlmostEqual(gmean([17.625]), 17.625)

        # Cross-check against a reference computed with Decimal
        # arithmetic.  The fixed seed makes the random datasets, which
        # are drawn in the order listed, reproducible.
        random.seed(86753095551212)
        datasets = [
            range(1, 100),
            range(1, 1_000),
            range(1, 10_000),
            range(500, 10_000, 3),
            range(10_000, 500, -3),
            [12, 17, 13, 5, 120, 7],
            [random.expovariate(50.0) for _ in range(1_000)],
            [random.lognormvariate(20.0, 3.0) for _ in range(2_000)],
            [random.triangular(2000, 3000, 2200) for _ in range(3_000)],
        ]
        for dataset in datasets:
            root = Decimal(1) / len(dataset)
            reference = math.prod(map(Decimal, dataset)) ** root
            self.assertTrue(math.isclose(gmean(dataset), float(reference)))

    def test_various_input_types(self):
        gmean = statistics.geometric_mean
        # https://www.wolframalpha.com/input/?i=geometric+mean+3.5,+4.0,+5.25
        expected_mean = 4.18886
        # The same three values as each kind of input; the result must be
        # a float that agrees with the reference to 5 places.
        cases = [
            ('floats', [3.5, 4.0, 5.25]),
            ('decimals', [Decimal('3.5'), Decimal('4.0'), Decimal('5.25')]),
            ('fractions', [Fraction(7, 2), Fraction(4, 1), Fraction(21, 4)]),
            ('mixed types', [3.5, 4, Fraction(21, 4)]),
            ('tuple', (3.5, 4.0, 5.25)),
            ('iterator', iter([3.5, 4.0, 5.25])),
        ]
        for kind, data in cases:
            actual_mean = gmean(data)
            self.assertIs(type(actual_mean), float, kind)
            self.assertAlmostEqual(actual_mean, expected_mean, places=5)

    def test_big_and_small(self):
        gmean = statistics.geometric_mean
        base = (54.0, 24.0, 36.0)

        # Avoid overflow to infinity
        huge = 2.0 ** 1000
        huge_gm = gmean([x * huge for x in base])
        self.assertTrue(math.isclose(huge_gm, 36.0 * huge))
        self.assertFalse(math.isinf(huge_gm))

        # Avoid underflow to zero
        tiny = 2.0 ** -1000
        tiny_gm = gmean([x * tiny for x in base])
        self.assertTrue(math.isclose(tiny_gm, 36.0 * tiny))
        self.assertNotEqual(tiny_gm, 0.0)

    def test_error_cases(self):
        gmean = statistics.geometric_mean
        StatisticsError = statistics.StatisticsError
        # Data that must raise StatisticsError.
        self.assertRaises(StatisticsError, gmean, [])                 # empty input
        self.assertRaises(StatisticsError, gmean, [3.5, 0.0, 5.25])   # zero input
        self.assertRaises(StatisticsError, gmean, [3.5, -4.0, 5.25])  # negative input
        self.assertRaises(StatisticsError, gmean, iter([]))           # empty iterator
        # Arguments that must raise TypeError.
        self.assertRaises(TypeError, gmean, None)                     # non-iterable input
        self.assertRaises(TypeError, gmean, [10, None, 20])           # non-numeric input
        self.assertRaises(TypeError, gmean)                           # missing data argument
        self.assertRaises(TypeError, gmean, [10, 20, 60], 70)         # too many arguments

    def test_special_values(self):
        # Rules for special values are inherited from math.fsum()
        gmean = statistics.geometric_mean
        nan, inf = float('Nan'), float('Inf')
        with_nan = gmean([10, nan])
        self.assertTrue(math.isnan(with_nan), 'nan')
        nan_and_inf = gmean([nan, inf])
        self.assertTrue(math.isnan(nan_and_inf), 'nan and infinity')
        with_inf = gmean([10, inf])
        self.assertTrue(math.isinf(with_inf), 'infinity')
        # Negative infinity in the data must raise ValueError.
        self.assertRaises(ValueError, gmean, [inf, -inf])
2247
2248
class TestQuantiles(unittest.TestCase):
    """Tests for statistics.quantiles().

    Covers hand-computed results for the default (exclusive) method and
    for method='inclusive', datatype preservation, constant inputs,
    the sizes of the groups produced by the cut points, and the
    exceptions raised for bad arguments.
    """

    def test_specific_cases(self):
        # Match results computed by hand and cross-checked
        # against the PERCENTILE.EXC function in MS Excel.
        quantiles = statistics.quantiles
        data = [120, 200, 250, 320, 350]
        # The input order must not matter.
        random.shuffle(data)

        # Affine map used for the translation/scaling invariance check.
        def f(x):
            return 3.5 * x - 1234.675

        for n, expected in [
            (1, []),
            (2, [250.0]),
            (3, [200.0, 320.0]),
            (4, [160.0, 250.0, 335.0]),
            (5, [136.0, 220.0, 292.0, 344.0]),
            (6, [120.0, 200.0, 250.0, 320.0, 350.0]),
            (8, [100.0, 160.0, 212.5, 250.0, 302.5, 335.0, 357.5]),
            (10, [88.0, 136.0, 184.0, 220.0, 250.0, 292.0, 326.0, 344.0, 362.0]),
            (12, [80.0, 120.0, 160.0, 200.0, 225.0, 250.0, 285.0, 320.0, 335.0,
                  350.0, 365.0]),
            (15, [72.0, 104.0, 136.0, 168.0, 200.0, 220.0, 240.0, 264.0, 292.0,
                  320.0, 332.0, 344.0, 356.0, 368.0]),
                ]:
            self.assertEqual(expected, quantiles(data, n=n))
            self.assertEqual(len(quantiles(data, n=n)), n - 1)
            # Preserve datatype when possible
            for datatype in (float, Decimal, Fraction):
                result = quantiles(map(datatype, data), n=n)
                # Check each element on its own: wrapping a generator
                # expression in assertTrue() would always pass because
                # a generator object is truthy.
                for x in result:
                    self.assertIs(type(x), datatype)
                self.assertEqual(result, list(map(datatype, expected)))
            # Quantiles should be idempotent
            if len(expected) >= 2:
                self.assertEqual(quantiles(expected, n=n), expected)
            # Cross-check against method='inclusive' which should give
            # the same result after adding in minimum and maximum values
            # extrapolated from the two lowest and two highest points.
            sdata = sorted(data)
            lo = 2 * sdata[0] - sdata[1]
            hi = 2 * sdata[-1] - sdata[-2]
            padded_data = data + [lo, hi]
            self.assertEqual(
                quantiles(data, n=n),
                quantiles(padded_data, n=n, method='inclusive'),
                (n, data),
            )
            # Invariant under translation and scaling
            exp = list(map(f, expected))
            act = quantiles(map(f, data), n=n)
            # zip() stops at the shorter input, so compare lengths first.
            self.assertEqual(len(exp), len(act))
            self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act)))
        # Q2 agrees with median()
        for k in range(2, 60):
            data = random.choices(range(100), k=k)
            q1, q2, q3 = quantiles(data)
            self.assertEqual(q2, statistics.median(data))

    def test_specific_cases_inclusive(self):
        # Match results computed by hand and cross-checked
        # against the PERCENTILE.INC function in MS Excel
        # and against the quantile() function in SciPy.
        quantiles = statistics.quantiles
        data = [100, 200, 400, 800]
        # The input order must not matter.
        random.shuffle(data)

        # Affine map used for the translation/scaling invariance check.
        def f(x):
            return 3.5 * x - 1234.675

        for n, expected in [
            (1, []),
            (2, [300.0]),
            (3, [200.0, 400.0]),
            (4, [175.0, 300.0, 500.0]),
            (5, [160.0, 240.0, 360.0, 560.0]),
            (6, [150.0, 200.0, 300.0, 400.0, 600.0]),
            (8, [137.5, 175.0, 225.0, 300.0, 375.0, 500.0, 650.0]),
            (10, [130.0, 160.0, 190.0, 240.0, 300.0, 360.0, 440.0, 560.0, 680.0]),
            (12, [125.0, 150.0, 175.0, 200.0, 250.0, 300.0, 350.0, 400.0,
                  500.0, 600.0, 700.0]),
            (15, [120.0, 140.0, 160.0, 180.0, 200.0, 240.0, 280.0, 320.0, 360.0,
                  400.0, 480.0, 560.0, 640.0, 720.0]),
                ]:
            self.assertEqual(expected, quantiles(data, n=n, method="inclusive"))
            self.assertEqual(len(quantiles(data, n=n, method="inclusive")), n - 1)
            # Preserve datatype when possible
            for datatype in (float, Decimal, Fraction):
                result = quantiles(map(datatype, data), n=n, method="inclusive")
                # Check each element on its own: wrapping a generator
                # expression in assertTrue() would always pass because
                # a generator object is truthy.
                for x in result:
                    self.assertIs(type(x), datatype)
                self.assertEqual(result, list(map(datatype, expected)))
            # Invariant under translation and scaling
            exp = list(map(f, expected))
            act = quantiles(map(f, data), n=n, method="inclusive")
            # zip() stops at the shorter input, so compare lengths first.
            self.assertEqual(len(exp), len(act))
            self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act)))
        # Natural deciles
        self.assertEqual(quantiles([0, 100], n=10, method='inclusive'),
                         [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0])
        self.assertEqual(quantiles(range(0, 101), n=10, method='inclusive'),
                         [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0])
        # Whenever n is smaller than the number of data points, running
        # method='inclusive' should give the same result as method='exclusive'
        # after the two included extreme points are removed.
        data = [random.randrange(10_000) for i in range(501)]
        actual = quantiles(data, n=32, method='inclusive')
        data.remove(min(data))
        data.remove(max(data))
        expected = quantiles(data, n=32)
        self.assertEqual(expected, actual)
        # Q2 agrees with median()
        for k in range(2, 60):
            data = random.choices(range(100), k=k)
            q1, q2, q3 = quantiles(data, method='inclusive')
            self.assertEqual(q2, statistics.median(data))

    def test_equal_inputs(self):
        # When every data point is the same value, every quartile
        # must be that value, for both methods.
        quantiles = statistics.quantiles
        for n in range(2, 10):
            data = [10.0] * n
            self.assertEqual(quantiles(data), [10.0, 10.0, 10.0])
            self.assertEqual(quantiles(data, method='inclusive'),
                            [10.0, 10.0, 10.0])

    def test_equal_sized_groups(self):
        # The cut points should split the data into groups of equal
        # (or nearly equal) size.
        quantiles = statistics.quantiles
        total = 10_000
        # Build exactly `total` distinct values.  Collecting them in a set
        # guarantees both the uniqueness and the count that the position
        # arithmetic below relies on.
        unique = set()
        while len(unique) < total:
            unique.add(random.expovariate(0.2))
        data = sorted(unique)

        # Cases where the group size exactly divides the total
        for n in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000):
            group_size = total // n
            self.assertEqual(
                [bisect.bisect(data, q) for q in quantiles(data, n=n)],
                list(range(group_size, total, group_size)))

        # When the group sizes can't be exactly equal, they should
        # differ by no more than one
        for n in (13, 19, 59, 109, 211, 571, 1019, 1907, 5261, 9769):
            group_sizes = {total // n, total // n + 1}
            pos = [bisect.bisect(data, q) for q in quantiles(data, n=n)]
            sizes = {q - p for p, q in zip(pos, pos[1:])}
            self.assertTrue(sizes <= group_sizes)

    def test_error_cases(self):
        # Each bad call must raise the exception type listed beside it.
        quantiles = statistics.quantiles
        StatisticsError = statistics.StatisticsError
        with self.assertRaises(TypeError):
            quantiles()                         # Missing arguments
        with self.assertRaises(TypeError):
            quantiles([10, 20, 30], 13, n=4)    # Too many arguments
        with self.assertRaises(TypeError):
            quantiles([10, 20, 30], 4)          # n is a positional argument
        with self.assertRaises(StatisticsError):
            quantiles([10, 20, 30], n=0)        # n is zero
        with self.assertRaises(StatisticsError):
            quantiles([10, 20, 30], n=-1)       # n is negative
        with self.assertRaises(TypeError):
            quantiles([10, 20, 30], n=1.5)      # n is not an integer
        with self.assertRaises(ValueError):
            quantiles([10, 20, 30], method='X') # method is unknown
        with self.assertRaises(StatisticsError):
            quantiles([10], n=4)                # not enough data points
        with self.assertRaises(TypeError):
            quantiles([10, None, 30], n=4)      # data is non-numeric
2411
2412
2413class TestBivariateStatistics(unittest.TestCase):
2414
2415    def test_unequal_size_error(self):
2416        for x, y in [
2417            ([1, 2, 3], [1, 2]),
2418            ([1, 2], [1, 2, 3]),
2419        ]:
2420            with self.assertRaises(statistics.StatisticsError):
2421                statistics.covariance(x, y)
2422            with self.assertRaises(statistics.StatisticsError):
2423                statistics.correlation(x, y)
2424            with self.assertRaises(statistics.StatisticsError):
2425                statistics.linear_regression(x, y)
2426
2427    def test_small_sample_error(self):
2428        for x, y in [
2429            ([], []),
2430            ([], [1, 2,]),
2431            ([1, 2,], []),
2432            ([1,], [1,]),
2433            ([1,], [1, 2,]),
2434            ([1, 2,], [1,]),
2435        ]:
2436            with self.assertRaises(statistics.StatisticsError):
2437                statistics.covariance(x, y)
2438            with self.assertRaises(statistics.StatisticsError):
2439                statistics.correlation(x, y)
2440            with self.assertRaises(statistics.StatisticsError):
2441                statistics.linear_regression(x, y)
2442
2443
2444class TestCorrelationAndCovariance(unittest.TestCase):
2445
2446    def test_results(self):
2447        for x, y, result in [
2448            ([1, 2, 3], [1, 2, 3], 1),
2449            ([1, 2, 3], [-1, -2, -3], -1),
2450            ([1, 2, 3], [3, 2, 1], -1),
2451            ([1, 2, 3], [1, 2, 1], 0),
2452            ([1, 2, 3], [1, 3, 2], 0.5),
2453        ]:
2454            self.assertAlmostEqual(statistics.correlation(x, y), result)
2455            self.assertAlmostEqual(statistics.covariance(x, y), result)
2456
2457    def test_different_scales(self):
2458        x = [1, 2, 3]
2459        y = [10, 30, 20]
2460        self.assertAlmostEqual(statistics.correlation(x, y), 0.5)
2461        self.assertAlmostEqual(statistics.covariance(x, y), 5)
2462
2463        y = [.1, .2, .3]
2464        self.assertAlmostEqual(statistics.correlation(x, y), 1)
2465        self.assertAlmostEqual(statistics.covariance(x, y), 0.1)
2466
2467
2468class TestLinearRegression(unittest.TestCase):
2469
2470    def test_constant_input_error(self):
2471        x = [1, 1, 1,]
2472        y = [1, 2, 3,]
2473        with self.assertRaises(statistics.StatisticsError):
2474            statistics.linear_regression(x, y)
2475
2476    def test_results(self):
2477        for x, y, true_intercept, true_slope in [
2478            ([1, 2, 3], [0, 0, 0], 0, 0),
2479            ([1, 2, 3], [1, 2, 3], 0, 1),
2480            ([1, 2, 3], [100, 100, 100], 100, 0),
2481            ([1, 2, 3], [12, 14, 16], 10, 2),
2482            ([1, 2, 3], [-1, -2, -3], 0, -1),
2483            ([1, 2, 3], [21, 22, 23], 20, 1),
2484            ([1, 2, 3], [5.1, 5.2, 5.3], 5, 0.1),
2485        ]:
2486            slope, intercept = statistics.linear_regression(x, y)
2487            self.assertAlmostEqual(intercept, true_intercept)
2488            self.assertAlmostEqual(slope, true_slope)
2489
2490
2491class TestNormalDist:
2492
2493    # General note on precision: The pdf(), cdf(), and overlap() methods
2494    # depend on functions in the math libraries that do not make
2495    # explicit accuracy guarantees.  Accordingly, some of the accuracy
2496    # tests below may fail if the underlying math functions are
2497    # inaccurate.  There isn't much we can do about this short of
2498    # implementing our own implementations from scratch.
2499
2500    def test_slots(self):
2501        nd = self.module.NormalDist(300, 23)
2502        with self.assertRaises(TypeError):
2503            vars(nd)
2504        self.assertEqual(tuple(nd.__slots__), ('_mu', '_sigma'))
2505
2506    def test_instantiation_and_attributes(self):
2507        nd = self.module.NormalDist(500, 17)
2508        self.assertEqual(nd.mean, 500)
2509        self.assertEqual(nd.stdev, 17)
2510        self.assertEqual(nd.variance, 17**2)
2511
2512        # default arguments
2513        nd = self.module.NormalDist()
2514        self.assertEqual(nd.mean, 0)
2515        self.assertEqual(nd.stdev, 1)
2516        self.assertEqual(nd.variance, 1**2)
2517
2518        # error case: negative sigma
2519        with self.assertRaises(self.module.StatisticsError):
2520            self.module.NormalDist(500, -10)
2521
2522        # verify that subclass type is honored
2523        class NewNormalDist(self.module.NormalDist):
2524            pass
2525        nnd = NewNormalDist(200, 5)
2526        self.assertEqual(type(nnd), NewNormalDist)
2527
2528    def test_alternative_constructor(self):
2529        NormalDist = self.module.NormalDist
2530        data = [96, 107, 90, 92, 110]
2531        # list input
2532        self.assertEqual(NormalDist.from_samples(data), NormalDist(99, 9))
2533        # tuple input
2534        self.assertEqual(NormalDist.from_samples(tuple(data)), NormalDist(99, 9))
2535        # iterator input
2536        self.assertEqual(NormalDist.from_samples(iter(data)), NormalDist(99, 9))
2537        # error cases
2538        with self.assertRaises(self.module.StatisticsError):
2539            NormalDist.from_samples([])                      # empty input
2540        with self.assertRaises(self.module.StatisticsError):
2541            NormalDist.from_samples([10])                    # only one input
2542
2543        # verify that subclass type is honored
2544        class NewNormalDist(NormalDist):
2545            pass
2546        nnd = NewNormalDist.from_samples(data)
2547        self.assertEqual(type(nnd), NewNormalDist)
2548
2549    def test_sample_generation(self):
2550        NormalDist = self.module.NormalDist
2551        mu, sigma = 10_000, 3.0
2552        X = NormalDist(mu, sigma)
2553        n = 1_000
2554        data = X.samples(n)
2555        self.assertEqual(len(data), n)
2556        self.assertEqual(set(map(type, data)), {float})
2557        # mean(data) expected to fall within 8 standard deviations
2558        xbar = self.module.mean(data)
2559        self.assertTrue(mu - sigma*8 <= xbar <= mu + sigma*8)
2560
2561        # verify that seeding makes reproducible sequences
2562        n = 100
2563        data1 = X.samples(n, seed='happiness and joy')
2564        data2 = X.samples(n, seed='trouble and despair')
2565        data3 = X.samples(n, seed='happiness and joy')
2566        data4 = X.samples(n, seed='trouble and despair')
2567        self.assertEqual(data1, data3)
2568        self.assertEqual(data2, data4)
2569        self.assertNotEqual(data1, data2)
2570
2571    def test_pdf(self):
2572        NormalDist = self.module.NormalDist
2573        X = NormalDist(100, 15)
2574        # Verify peak around center
2575        self.assertLess(X.pdf(99), X.pdf(100))
2576        self.assertLess(X.pdf(101), X.pdf(100))
2577        # Test symmetry
2578        for i in range(50):
2579            self.assertAlmostEqual(X.pdf(100 - i), X.pdf(100 + i))
2580        # Test vs CDF
2581        dx = 2.0 ** -10
2582        for x in range(90, 111):
2583            est_pdf = (X.cdf(x + dx) - X.cdf(x)) / dx
2584            self.assertAlmostEqual(X.pdf(x), est_pdf, places=4)
2585        # Test vs table of known values -- CRC 26th Edition
2586        Z = NormalDist()
2587        for x, px in enumerate([
2588            0.3989, 0.3989, 0.3989, 0.3988, 0.3986,
2589            0.3984, 0.3982, 0.3980, 0.3977, 0.3973,
2590            0.3970, 0.3965, 0.3961, 0.3956, 0.3951,
2591            0.3945, 0.3939, 0.3932, 0.3925, 0.3918,
2592            0.3910, 0.3902, 0.3894, 0.3885, 0.3876,
2593            0.3867, 0.3857, 0.3847, 0.3836, 0.3825,
2594            0.3814, 0.3802, 0.3790, 0.3778, 0.3765,
2595            0.3752, 0.3739, 0.3725, 0.3712, 0.3697,
2596            0.3683, 0.3668, 0.3653, 0.3637, 0.3621,
2597            0.3605, 0.3589, 0.3572, 0.3555, 0.3538,
2598        ]):
2599            self.assertAlmostEqual(Z.pdf(x / 100.0), px, places=4)
2600            self.assertAlmostEqual(Z.pdf(-x / 100.0), px, places=4)
2601        # Error case: variance is zero
2602        Y = NormalDist(100, 0)
2603        with self.assertRaises(self.module.StatisticsError):
2604            Y.pdf(90)
2605        # Special values
2606        self.assertEqual(X.pdf(float('-Inf')), 0.0)
2607        self.assertEqual(X.pdf(float('Inf')), 0.0)
2608        self.assertTrue(math.isnan(X.pdf(float('NaN'))))
2609
2610    def test_cdf(self):
2611        NormalDist = self.module.NormalDist
2612        X = NormalDist(100, 15)
2613        cdfs = [X.cdf(x) for x in range(1, 200)]
2614        self.assertEqual(set(map(type, cdfs)), {float})
2615        # Verify montonic
2616        self.assertEqual(cdfs, sorted(cdfs))
2617        # Verify center (should be exact)
2618        self.assertEqual(X.cdf(100), 0.50)
2619        # Check against a table of known values
2620        # https://en.wikipedia.org/wiki/Standard_normal_table#Cumulative
2621        Z = NormalDist()
2622        for z, cum_prob in [
2623            (0.00, 0.50000), (0.01, 0.50399), (0.02, 0.50798),
2624            (0.14, 0.55567), (0.29, 0.61409), (0.33, 0.62930),
2625            (0.54, 0.70540), (0.60, 0.72575), (1.17, 0.87900),
2626            (1.60, 0.94520), (2.05, 0.97982), (2.89, 0.99807),
2627            (3.52, 0.99978), (3.98, 0.99997), (4.07, 0.99998),
2628            ]:
2629            self.assertAlmostEqual(Z.cdf(z), cum_prob, places=5)
2630            self.assertAlmostEqual(Z.cdf(-z), 1.0 - cum_prob, places=5)
2631        # Error case: variance is zero
2632        Y = NormalDist(100, 0)
2633        with self.assertRaises(self.module.StatisticsError):
2634            Y.cdf(90)
2635        # Special values
2636        self.assertEqual(X.cdf(float('-Inf')), 0.0)
2637        self.assertEqual(X.cdf(float('Inf')), 1.0)
2638        self.assertTrue(math.isnan(X.cdf(float('NaN'))))
2639
2640    @support.skip_if_pgo_task
2641    def test_inv_cdf(self):
2642        NormalDist = self.module.NormalDist
2643
2644        # Center case should be exact.
2645        iq = NormalDist(100, 15)
2646        self.assertEqual(iq.inv_cdf(0.50), iq.mean)
2647
2648        # Test versus a published table of known percentage points.
2649        # See the second table at the bottom of the page here:
2650        # http://people.bath.ac.uk/masss/tables/normaltable.pdf
2651        Z = NormalDist()
2652        pp = {5.0: (0.000, 1.645, 2.576, 3.291, 3.891,
2653                    4.417, 4.892, 5.327, 5.731, 6.109),
2654              2.5: (0.674, 1.960, 2.807, 3.481, 4.056,
2655                    4.565, 5.026, 5.451, 5.847, 6.219),
2656              1.0: (1.282, 2.326, 3.090, 3.719, 4.265,
2657                    4.753, 5.199, 5.612, 5.998, 6.361)}
2658        for base, row in pp.items():
2659            for exp, x in enumerate(row, start=1):
2660                p = base * 10.0 ** (-exp)
2661                self.assertAlmostEqual(-Z.inv_cdf(p), x, places=3)
2662                p = 1.0 - p
2663                self.assertAlmostEqual(Z.inv_cdf(p), x, places=3)
2664
2665        # Match published example for MS Excel
2666        # https://support.office.com/en-us/article/norm-inv-function-54b30935-fee7-493c-bedb-2278a9db7e13
2667        self.assertAlmostEqual(NormalDist(40, 1.5).inv_cdf(0.908789), 42.000002)
2668
2669        # One million equally spaced probabilities
2670        n = 2**20
2671        for p in range(1, n):
2672            p /= n
2673            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2674
2675        # One hundred ever smaller probabilities to test tails out to
2676        # extreme probabilities: 1 / 2**50 and (2**50-1) / 2 ** 50
2677        for e in range(1, 51):
2678            p = 2.0 ** (-e)
2679            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2680            p = 1.0 - p
2681            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2682
2683        # Now apply cdf() first.  Near the tails, the round-trip loses
2684        # precision and is ill-conditioned (small changes in the inputs
2685        # give large changes in the output), so only check to 5 places.
2686        for x in range(200):
2687            self.assertAlmostEqual(iq.inv_cdf(iq.cdf(x)), x, places=5)
2688
2689        # Error cases:
2690        with self.assertRaises(self.module.StatisticsError):
2691            iq.inv_cdf(0.0)                         # p is zero
2692        with self.assertRaises(self.module.StatisticsError):
2693            iq.inv_cdf(-0.1)                        # p under zero
2694        with self.assertRaises(self.module.StatisticsError):
2695            iq.inv_cdf(1.0)                         # p is one
2696        with self.assertRaises(self.module.StatisticsError):
2697            iq.inv_cdf(1.1)                         # p over one
2698        with self.assertRaises(self.module.StatisticsError):
2699            iq = NormalDist(100, 0)                 # sigma is zero
2700            iq.inv_cdf(0.5)
2701
2702        # Special values
2703        self.assertTrue(math.isnan(Z.inv_cdf(float('NaN'))))
2704
2705    def test_quantiles(self):
2706        # Quartiles of a standard normal distribution
2707        Z = self.module.NormalDist()
2708        for n, expected in [
2709            (1, []),
2710            (2, [0.0]),
2711            (3, [-0.4307, 0.4307]),
2712            (4 ,[-0.6745, 0.0, 0.6745]),
2713                ]:
2714            actual = Z.quantiles(n=n)
2715            self.assertTrue(all(math.isclose(e, a, abs_tol=0.0001)
2716                            for e, a in zip(expected, actual)))
2717
2718    def test_overlap(self):
2719        NormalDist = self.module.NormalDist
2720
2721        # Match examples from Imman and Bradley
2722        for X1, X2, published_result in [
2723                (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0), 0.80258),
2724                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0), 0.60993),
2725            ]:
2726            self.assertAlmostEqual(X1.overlap(X2), published_result, places=4)
2727            self.assertAlmostEqual(X2.overlap(X1), published_result, places=4)
2728
2729        # Check against integration of the PDF
2730        def overlap_numeric(X, Y, *, steps=8_192, z=5):
2731            'Numerical integration cross-check for overlap() '
2732            fsum = math.fsum
2733            center = (X.mean + Y.mean) / 2.0
2734            width = z * max(X.stdev, Y.stdev)
2735            start = center - width
2736            dx = 2.0 * width / steps
2737            x_arr = [start + i*dx for i in range(steps)]
2738            xp = list(map(X.pdf, x_arr))
2739            yp = list(map(Y.pdf, x_arr))
2740            total = max(fsum(xp), fsum(yp))
2741            return fsum(map(min, xp, yp)) / total
2742
2743        for X1, X2 in [
2744                # Examples from Imman and Bradley
2745                (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0)),
2746                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)),
2747                # Example from https://www.rasch.org/rmt/rmt101r.htm
2748                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)),
2749                # Gender heights from http://www.usablestats.com/lessons/normal
2750                (NormalDist(70, 4), NormalDist(65, 3.5)),
2751                # Misc cases with equal standard deviations
2752                (NormalDist(100, 15), NormalDist(110, 15)),
2753                (NormalDist(-100, 15), NormalDist(110, 15)),
2754                (NormalDist(-100, 15), NormalDist(-110, 15)),
2755                # Misc cases with unequal standard deviations
2756                (NormalDist(100, 12), NormalDist(100, 15)),
2757                (NormalDist(100, 12), NormalDist(110, 15)),
2758                (NormalDist(100, 12), NormalDist(150, 15)),
2759                (NormalDist(100, 12), NormalDist(150, 35)),
2760                # Misc cases with small values
2761                (NormalDist(1.000, 0.002), NormalDist(1.001, 0.003)),
2762                (NormalDist(1.000, 0.002), NormalDist(1.006, 0.0003)),
2763                (NormalDist(1.000, 0.002), NormalDist(1.001, 0.099)),
2764            ]:
2765            self.assertAlmostEqual(X1.overlap(X2), overlap_numeric(X1, X2), places=5)
2766            self.assertAlmostEqual(X2.overlap(X1), overlap_numeric(X1, X2), places=5)
2767
2768        # Error cases
2769        X = NormalDist()
2770        with self.assertRaises(TypeError):
2771            X.overlap()                             # too few arguments
2772        with self.assertRaises(TypeError):
2773            X.overlap(X, X)                         # too may arguments
2774        with self.assertRaises(TypeError):
2775            X.overlap(None)                         # right operand not a NormalDist
2776        with self.assertRaises(self.module.StatisticsError):
2777            X.overlap(NormalDist(1, 0))             # right operand sigma is zero
2778        with self.assertRaises(self.module.StatisticsError):
2779            NormalDist(1, 0).overlap(X)             # left operand sigma is zero
2780
2781    def test_zscore(self):
2782        NormalDist = self.module.NormalDist
2783        X = NormalDist(100, 15)
2784        self.assertEqual(X.zscore(142), 2.8)
2785        self.assertEqual(X.zscore(58), -2.8)
2786        self.assertEqual(X.zscore(100), 0.0)
2787        with self.assertRaises(TypeError):
2788            X.zscore()                              # too few arguments
2789        with self.assertRaises(TypeError):
2790            X.zscore(1, 1)                          # too may arguments
2791        with self.assertRaises(TypeError):
2792            X.zscore(None)                          # non-numeric type
2793        with self.assertRaises(self.module.StatisticsError):
2794            NormalDist(1, 0).zscore(100)            # sigma is zero
2795
2796    def test_properties(self):
2797        X = self.module.NormalDist(100, 15)
2798        self.assertEqual(X.mean, 100)
2799        self.assertEqual(X.median, 100)
2800        self.assertEqual(X.mode, 100)
2801        self.assertEqual(X.stdev, 15)
2802        self.assertEqual(X.variance, 225)
2803
2804    def test_same_type_addition_and_subtraction(self):
2805        NormalDist = self.module.NormalDist
2806        X = NormalDist(100, 12)
2807        Y = NormalDist(40, 5)
2808        self.assertEqual(X + Y, NormalDist(140, 13))        # __add__
2809        self.assertEqual(X - Y, NormalDist(60, 13))         # __sub__
2810
2811    def test_translation_and_scaling(self):
2812        NormalDist = self.module.NormalDist
2813        X = NormalDist(100, 15)
2814        y = 10
2815        self.assertEqual(+X, NormalDist(100, 15))           # __pos__
2816        self.assertEqual(-X, NormalDist(-100, 15))          # __neg__
2817        self.assertEqual(X + y, NormalDist(110, 15))        # __add__
2818        self.assertEqual(y + X, NormalDist(110, 15))        # __radd__
2819        self.assertEqual(X - y, NormalDist(90, 15))         # __sub__
2820        self.assertEqual(y - X, NormalDist(-90, 15))        # __rsub__
2821        self.assertEqual(X * y, NormalDist(1000, 150))      # __mul__
2822        self.assertEqual(y * X, NormalDist(1000, 150))      # __rmul__
2823        self.assertEqual(X / y, NormalDist(10, 1.5))        # __truediv__
2824        with self.assertRaises(TypeError):                  # __rtruediv__
2825            y / X
2826
2827    def test_unary_operations(self):
2828        NormalDist = self.module.NormalDist
2829        X = NormalDist(100, 12)
2830        Y = +X
2831        self.assertIsNot(X, Y)
2832        self.assertEqual(X.mean, Y.mean)
2833        self.assertEqual(X.stdev, Y.stdev)
2834        Y = -X
2835        self.assertIsNot(X, Y)
2836        self.assertEqual(X.mean, -Y.mean)
2837        self.assertEqual(X.stdev, Y.stdev)
2838
2839    def test_equality(self):
2840        NormalDist = self.module.NormalDist
2841        nd1 = NormalDist()
2842        nd2 = NormalDist(2, 4)
2843        nd3 = NormalDist()
2844        nd4 = NormalDist(2, 4)
2845        nd5 = NormalDist(2, 8)
2846        nd6 = NormalDist(8, 4)
2847        self.assertNotEqual(nd1, nd2)
2848        self.assertEqual(nd1, nd3)
2849        self.assertEqual(nd2, nd4)
2850        self.assertNotEqual(nd2, nd5)
2851        self.assertNotEqual(nd2, nd6)
2852
2853        # Test NotImplemented when types are different
2854        class A:
2855            def __eq__(self, other):
2856                return 10
2857        a = A()
2858        self.assertEqual(nd1.__eq__(a), NotImplemented)
2859        self.assertEqual(nd1 == a, 10)
2860        self.assertEqual(a == nd1, 10)
2861
2862        # All subclasses to compare equal giving the same behavior
2863        # as list, tuple, int, float, complex, str, dict, set, etc.
2864        class SizedNormalDist(NormalDist):
2865            def __init__(self, mu, sigma, n):
2866                super().__init__(mu, sigma)
2867                self.n = n
2868        s = SizedNormalDist(100, 15, 57)
2869        nd4 = NormalDist(100, 15)
2870        self.assertEqual(s, nd4)
2871
2872        # Don't allow duck type equality because we wouldn't
2873        # want a lognormal distribution to compare equal
2874        # to a normal distribution with the same parameters
2875        class LognormalDist:
2876            def __init__(self, mu, sigma):
2877                self.mu = mu
2878                self.sigma = sigma
2879        lnd = LognormalDist(100, 15)
2880        nd = NormalDist(100, 15)
2881        self.assertNotEqual(nd, lnd)
2882
2883    def test_pickle_and_copy(self):
2884        nd = self.module.NormalDist(37.5, 5.625)
2885        nd1 = copy.copy(nd)
2886        self.assertEqual(nd, nd1)
2887        nd2 = copy.deepcopy(nd)
2888        self.assertEqual(nd, nd2)
2889        nd3 = pickle.loads(pickle.dumps(nd))
2890        self.assertEqual(nd, nd3)
2891
2892    def test_hashability(self):
2893        ND = self.module.NormalDist
2894        s = {ND(100, 15), ND(100.0, 15.0), ND(100, 10), ND(95, 15), ND(100, 15)}
2895        self.assertEqual(len(s), 3)
2896
2897    def test_repr(self):
2898        nd = self.module.NormalDist(37.5, 5.625)
2899        self.assertEqual(repr(nd), 'NormalDist(mu=37.5, sigma=5.625)')
2900
# Swapping sys.modules['statistics'] is needed to avoid this error:
# _pickle.PicklingError:
# Can't pickle <class 'statistics.NormalDist'>:
# it's not the same object as statistics.NormalDist
class TestNormalDistPython(unittest.TestCase, TestNormalDist):
    """Run the TestNormalDist tests against the module bound to
    ``py_statistics``."""

    module = py_statistics

    def setUp(self):
        # Register the module under test as 'statistics' while a test runs.
        sys.modules.update(statistics=self.module)

    def tearDown(self):
        # Put back the statistics module imported at the top of this file.
        sys.modules.update(statistics=statistics)
2912
2913
@unittest.skipUnless(c_statistics, 'requires _statistics')
class TestNormalDistC(unittest.TestCase, TestNormalDist):
    """Run the TestNormalDist tests against the module bound to
    ``c_statistics``; skipped when that name is falsy."""

    module = c_statistics

    def setUp(self):
        # Register the module under test as 'statistics' while a test runs.
        sys.modules.update(statistics=self.module)

    def tearDown(self):
        # Put back the statistics module imported at the top of this file.
        sys.modules.update(statistics=statistics)
2922
2923
2924# === Run tests ===
2925
def load_tests(loader, tests, ignore):
    """Add this module's doctests to the unittest suite.

    The suite built by doctest.DocTestSuite() is appended to *tests*,
    which is then returned.  *loader* and *ignore* are not used.
    """
    doctest_suite = doctest.DocTestSuite()
    tests.addTests(doctest_suite)
    return tests
2930
2931
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
2934