1"""Test suite for statistics module, including helper NumericTestCase and
2approx_equal function.
3
4"""
5
6import bisect
7import collections
8import collections.abc
9import copy
10import decimal
11import doctest
12import itertools
13import math
14import pickle
15import random
16import sys
17import unittest
18from test import support
19from test.support import import_helper, requires_IEEE_754
20
21from decimal import Decimal
22from fractions import Fraction
23
24
25# Module to be tested.
26import statistics
27
28
29# === Helper functions and class ===
30
def sign(x):
    """Return the sign of x as a float.

    The result is -1.0 whenever the sign bit of x is set -- negative numbers
    and also negative zero -- and +1.0 otherwise.
    """
    return -1.0 if math.copysign(1.0, x) < 0 else 1.0
34
def _nan_equal(a, b):
    """Return True if a and b are both the same kind of NAN.

    >>> _nan_equal(Decimal('NAN'), Decimal('NAN'))
    True
    >>> _nan_equal(Decimal('sNAN'), Decimal('sNAN'))
    True
    >>> _nan_equal(Decimal('NAN'), Decimal('sNAN'))
    False
    >>> _nan_equal(Decimal(42), Decimal('NAN'))
    False

    >>> _nan_equal(float('NAN'), float('NAN'))
    True
    >>> _nan_equal(float('NAN'), 0.5)
    False

    >>> _nan_equal(float('NAN'), Decimal('NAN'))
    False

    NAN payloads are not compared.
    """
    # Values of different types (e.g. float vs Decimal) never match.
    if type(a) is not type(b):
        return False
    if isinstance(a, float):
        return math.isnan(a) and math.isnan(b)
    # Otherwise a Decimal is expected: the exponent slot of as_tuple() holds
    # 'n' for a quiet NAN and 'N' for a signalling NAN.
    kind_a = a.as_tuple()[2]
    kind_b = b.as_tuple()[2]
    return kind_a in ('n', 'N') and kind_a == kind_b
64
65
def _calc_errors(actual, expected):
    """Return the absolute and relative errors between two numbers.

    >>> _calc_errors(100, 75)
    (25, 0.25)
    >>> _calc_errors(100, 100)
    (0, 0.0)
    >>> _calc_errors(0, 0)
    (0, 0.0)

    Returns the (absolute error, relative error) between the two arguments.
    The relative error is measured against the larger of the two magnitudes.
    """
    base = max(abs(actual), abs(expected))
    abs_err = abs(actual - expected)
    # A zero base means both values are zero, i.e. identical, so the
    # relative error is zero as well (it used to be reported as infinite).
    rel_err = abs_err/base if base else 0.0
    return (abs_err, rel_err)
80
81
def approx_equal(x, y, tol=1e-12, rel=1e-7):
    """approx_equal(x, y [, tol [, rel]]) => True|False

    Return True if numbers x and y are approximately equal, to within some
    margin of error, otherwise return False. Numbers which compare equal
    will also compare approximately equal.

    x is approximately equal to y if the difference between them is less
    than or equal to an absolute error tol or a relative error rel,
    whichever is bigger.

    If given, both tol and rel must be finite, non-negative numbers. If not
    given, default values are tol=1e-12 and rel=1e-7. A negative tol or rel
    raises ValueError. When comparing Decimals that are not exactly equal,
    pass int or Decimal tolerances, since Decimal arithmetic does not mix
    with float.

    >>> approx_equal(1.2589, 1.2587, tol=0.0003, rel=0)
    True
    >>> approx_equal(1.2589, 1.2587, tol=0.0001, rel=0)
    False

    Absolute error is defined as abs(x-y); if that is less than or equal to
    tol, x and y are considered approximately equal.

    Relative error is defined as abs((x-y)/x) or abs((x-y)/y), whichever is
    smaller, provided x or y are not zero. If that figure is less than or
    equal to rel, x and y are considered approximately equal.

    Complex numbers are not directly supported. If you wish to compare two
    complex numbers, extract their real and imaginary parts and compare them
    individually.

    NANs always compare unequal, even with themselves. Infinities compare
    approximately equal if they have the same sign (both positive or both
    negative). Infinities with different signs compare unequal; so do
    comparisons of infinities with finite numbers. Finite Decimals too
    large to be represented as a float (such as Decimal('1e400')) are
    still treated as finite.
    """
    if tol < 0 or rel < 0:
        raise ValueError('error tolerances must be non-negative')
    # NANs are never equal to anything, approximately or otherwise.
    if math.isnan(x) or math.isnan(y):
        return False
    # Numbers which compare equal also compare approximately equal.
    if x == y:
        # This includes the case of two infinities with the same sign.
        return True
    # Detect infinities by comparison rather than math.isinf(): isinf()
    # converts to float, which turns huge but finite Decimals into inf.
    if abs(x) == math.inf or abs(y) == math.inf:
        # This includes the case of two infinities of opposite sign, or
        # one infinity and one finite number.
        return False
    # Two finite numbers.
    actual_error = abs(x - y)
    allowed_error = max(tol, rel*max(abs(x), abs(y)))
    return actual_error <= allowed_error
133
134
# This class exists only as somewhere to stick a docstring containing
# doctests. The following docstring and tests were originally in a separate
# module. Now that it has been merged in here, I need somewhere to hang the
# docstring. Ultimately, this class will die, and the information below will
# either become redundant, or be moved into more appropriate places.
class _DoNothing:
    """
    Numeric work, and float arithmetic in particular, seldom calls for
    exact equality: round-off error means two computations of "the same"
    quantity can differ slightly. The usual remedy is to compare values
    while allowing for some (hopefully small!) error.

    ``approx_equal`` accepts an absolute error tolerance, a relative error,
    or both.

    An absolute tolerance is easy to reason about, but only if you already
    know how large the quantities being compared are:

    >>> approx_equal(12.345, 12.346, tol=1e-3)
    True
    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3)  # tol is too small.
    False

    A relative error adapts to the magnitude of the values being compared:

    >>> approx_equal(12.345, 12.346, rel=1e-4)
    True
    >>> approx_equal(12.345e6, 12.346e6, rel=1e-4)
    True

    A purely relative test is fragile near zero, though, because there even
    a tiny absolute difference amounts to a large relative error.

    When both an absolute tolerance and a relative error are supplied, the
    comparison succeeds as soon as either one of the two tests succeeds:

    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3, rel=1e-4)
    True

    """
177
178
179
# For numeric values that may not be exactly equal we prefer
# NumericTestCase.assertApproxEqual (defined below) over
# TestCase.assertAlmostEqual, which can only round the difference to a fixed
# number of decimal places.
182
183py_statistics = import_helper.import_fresh_module('statistics',
184                                                  blocked=['_statistics'])
185c_statistics = import_helper.import_fresh_module('statistics',
186                                                 fresh=['_statistics'])
187
188
class TestModules(unittest.TestCase):
    # Functions that exist in both the pure-Python and the C implementation.
    func_names = ['_normal_dist_inv_cdf']

    def _assert_defined_in(self, module, owner):
        # Each listed function of `module` must report `owner` as its origin.
        for name in self.func_names:
            func = getattr(module, name)
            self.assertEqual(func.__module__, owner)

    def test_py_functions(self):
        self._assert_defined_in(py_statistics, 'statistics')

    @unittest.skipUnless(c_statistics, 'requires _statistics')
    def test_c_functions(self):
        self._assert_defined_in(c_statistics, '_statistics')
200
201
class NumericTestCase(unittest.TestCase):
    """TestCase subclass with support for approximate numeric comparisons.

    Besides the inherited ``TestCase.assertAlmostEqual``, this class offers
    ``assertApproxEqual``.
    """
    # Exact equality is required unless a subclass or a call overrides it.
    tol = rel = 0

    def assertApproxEqual(
            self, first, second, tol=None, rel=None, msg=None
            ):
        """Assert that ``first`` and ``second`` are approximately equal.

        The assertion holds when the values agree to within ``tol`` (an
        absolute error) or ``rel`` (a relative error). A ``tol`` or ``rel``
        that is None or omitted falls back to the test attribute of the
        same name (0 unless overridden).

        Both arguments may be numbers, or both may be sequences of numbers;
        sequences are compared item by item.

        >>> class MyTest(NumericTestCase):
        ...     def test_number(self):
        ...         x = 1.0/6
        ...         y = sum([x]*6)
        ...         self.assertApproxEqual(y, 1.0, tol=1e-15)
        ...     def test_sequence(self):
        ...         a = [1.001, 1.001e-10, 1.001e10]
        ...         b = [1.0, 1e-10, 1e10]
        ...         self.assertApproxEqual(a, b, rel=1e-3)
        ...
        >>> import unittest
        >>> from io import StringIO  # Suppress test runner output.
        >>> suite = unittest.TestLoader().loadTestsFromTestCase(MyTest)
        >>> unittest.TextTestRunner(stream=StringIO()).run(suite)
        <unittest.runner.TextTestResult run=2 errors=0 failures=0>

        """
        tol = self.tol if tol is None else tol
        rel = self.rel if rel is None else rel
        both_sequences = all(
            isinstance(obj, collections.abc.Sequence) for obj in (first, second)
        )
        if both_sequences:
            checker = self._check_approx_seq
        else:
            checker = self._check_approx_num
        checker(first, second, tol, rel, msg)

    def _check_approx_seq(self, first, second, tol, rel, msg):
        # Sequences must have equal length, then match element by element.
        n_first, n_second = len(first), len(second)
        if n_first != n_second:
            standardMsg = (
                "sequences differ in length: %d items != %d items"
                % (n_first, n_second)
            )
            raise self.failureException(self._formatMessage(msg, standardMsg))
        for index, (actual, expected) in enumerate(zip(first, second)):
            self._check_approx_num(actual, expected, tol, rel, msg, index)

    def _check_approx_num(self, first, second, tol, rel, msg, idx=None):
        # Raise failureException unless approx_equal accepts the pair.
        if not approx_equal(first, second, tol, rel):
            standardMsg = self._make_std_err_msg(first, second, tol, rel, idx)
            raise self.failureException(self._formatMessage(msg, standardMsg))

    @staticmethod
    def _make_std_err_msg(first, second, tol, rel, idx):
        # Build the standard failure message for an approx_equal mismatch,
        # prefixed with the position when comparing sequences.
        assert first != second
        template = '\n'.join([
            '  %r != %r',
            '  values differ by more than tol=%r and rel=%r',
            '  -> absolute error = %r',
            '  -> relative error = %r',
        ])
        if idx is not None:
            template = (
                'numeric sequences first differ at index %d.\n' % idx
            ) + template
        abs_err, rel_err = _calc_errors(first, second)
        return template % (first, second, tol, rel, abs_err, rel_err)
291
292
293# ========================
294# === Test the helpers ===
295# ========================
296
class TestSign(unittest.TestCase):
    """Check the sign() helper on signed zeroes."""
    def testZeroes(self):
        # +0.0 and -0.0 compare equal, so sign() has to read the sign bit.
        for value, expected in ((0.0, +1), (-0.0, -1)):
            self.assertEqual(sign(value), expected)
303
304
305# --- Tests for approx_equal ---
306
class ApproxEqualSymmetryTest(unittest.TestCase):
    # Test symmetry of approx_equal.

    def test_relative_symmetry(self):
        # Check that approx_equal treats relative error symmetrically.
        # (a-b)/a is usually not equal to (a-b)/b. Ensure that this
        # doesn't matter.
        #
        #   Note: the reason for this test is that an early version
        #   of approx_equal was not symmetric. A relative error test
        #   would pass, or fail, depending on which value was passed
        #   as the first argument.
        #
        args1 = [2456, 37.8, -12.45, Decimal('2.54'), Fraction(17, 54)]
        args2 = [2459, 37.2, -12.41, Decimal('2.59'), Fraction(15, 54)]
        assert len(args1) == len(args2)
        for a, b in zip(args1, args2):
            self.do_relative_symmetry(a, b)

    def do_relative_symmetry(self, a, b):
        a, b = min(a, b), max(a, b)
        assert a < b
        delta = b - a  # The absolute difference between the values.
        rel_err1, rel_err2 = abs(delta/a), abs(delta/b)
        # Choose an error margin halfway between the two.
        rel = (rel_err1 + rel_err2)/2
        # Now see that values a and b compare approx equal regardless of
        # which is given first.
        self.assertTrue(approx_equal(a, b, tol=0, rel=rel))
        self.assertTrue(approx_equal(b, a, tol=0, rel=rel))

    def test_symmetry(self):
        # Test that approx_equal(a, b) == approx_equal(b, a)
        args = [-23, -2, 5, 107, 93568]
        delta = 2
        for a in args:
            for type_ in (int, float, Decimal, Fraction):
                x = type_(a)*100
                y = x + delta
                r = abs(delta/max(x, y))
                # There are five cases to check:
                # 1) actual error <= tol, <= rel
                self.do_symmetry_test(x, y, tol=delta, rel=r)
                self.do_symmetry_test(x, y, tol=delta+1, rel=2*r)
                # 2) actual error > tol, > rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r/2)
                # 3) actual error <= tol, > rel
                self.do_symmetry_test(x, y, tol=delta, rel=r/2)
                # 4) actual error > tol, <= rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r)
                self.do_symmetry_test(x, y, tol=delta-1, rel=2*r)
                # 5) exact equality test
                self.do_symmetry_test(x, x, tol=0, rel=0)
                self.do_symmetry_test(x, y, tol=0, rel=0)

    def do_symmetry_test(self, a, b, tol, rel):
        # approx_equal must give the same verdict whichever argument is
        # passed first.
        template = "approx_equal comparisons don't match for %r"
        flag1 = approx_equal(a, b, tol, rel)
        flag2 = approx_equal(b, a, tol, rel)
        # The template uses printf-style %r, so format it with the %
        # operator (str.format() would silently leave %r unexpanded). The
        # arguments are wrapped in a 1-tuple so the whole tuple is shown.
        self.assertEqual(flag1, flag2, template % ((a, b, tol, rel),))
367
368
class ApproxEqualExactTest(unittest.TestCase):
    # approx_equal with values that are exactly equal. Such values must
    # compare approximately equal whatever error tolerances are supplied.

    def do_exactly_equal_test(self, x, tol, rel):
        # Compare x with itself, then -x with itself.
        for value in (x, -x):
            outcome = approx_equal(value, value, tol=tol, rel=rel)
            self.assertTrue(outcome, 'equality failure for x=%r' % value)

    def test_exactly_equal_ints(self):
        # Equal ints, zero tolerances.
        ints = (42, 19740, 14974, 230, 1795, 700245, 36587)
        for n in ints:
            self.do_exactly_equal_test(n, 0, 0)

    def test_exactly_equal_floats(self):
        # Equal floats, zero tolerances.
        floats = (0.42, 1.9740, 1497.4, 23.0, 179.5, 70.0245, 36.587)
        for x in floats:
            self.do_exactly_equal_test(x, 0, 0)

    def test_exactly_equal_fractions(self):
        # Equal Fractions, zero tolerances.
        pairs = ((1, 2), (0, 1), (5, 3), (9, 7), (35, 36), (3, 7))
        for numerator, denominator in pairs:
            self.do_exactly_equal_test(Fraction(numerator, denominator), 0, 0)

    def test_exactly_equal_decimals(self):
        # Equal Decimals, zero tolerances.
        for text in "8.2 31.274 912.04 16.745 1.2047".split():
            self.do_exactly_equal_test(Decimal(text), 0, 0)

    def test_exactly_equal_absolute(self):
        # Equal values with only an absolute tolerance, as int, float and
        # Fraction in turn.
        for n in (16, 1013, 1372, 1198, 971, 4):
            for value in (n, n/10, Fraction(n, 1234)):
                self.do_exactly_equal_test(value, 0.01, 0)

    def test_exactly_equal_absolute_decimals(self):
        # Equal Decimals with only a Decimal absolute tolerance.
        tolerance = Decimal("0.01")
        for value in (Decimal("3.571"), -Decimal("81.3971")):
            self.do_exactly_equal_test(value, tolerance, 0)

    def test_exactly_equal_relative(self):
        # Equal values with only a relative tolerance.
        for x in (8347, 101.3, -7910.28, Fraction(5, 21)):
            self.do_exactly_equal_test(x, 0, 0.01)
        self.do_exactly_equal_test(Decimal("11.68"), 0, Decimal("0.01"))

    def test_exactly_equal_both(self):
        # Equal values with both tolerances non-zero.
        for x in (41017, 16.742, -813.02, Fraction(3, 8)):
            self.do_exactly_equal_test(x, 0.1, 0.01)
        self.do_exactly_equal_test(
            Decimal("7.2"), Decimal("0.1"), Decimal("0.01"))
431
432
class ApproxEqualUnequalTest(unittest.TestCase):
    # approx_equal with values that differ: with both tolerances zero they
    # must never compare approximately equal.

    def do_exactly_unequal_test(self, x):
        # Compare x (and then -x) with a value exactly one larger.
        for value in (x, -x):
            self.assertFalse(
                approx_equal(value, value + 1, tol=0, rel=0),
                'inequality failure for x=%r' % value,
            )

    def test_exactly_unequal_ints(self):
        # Unequal ints, zero tolerances.
        for n in (951, 572305, 478, 917, 17240):
            self.do_exactly_unequal_test(n)

    def test_exactly_unequal_floats(self):
        # Unequal floats, zero tolerances.
        for x in (9.51, 5723.05, 47.8, 9.17, 17.24):
            self.do_exactly_unequal_test(x)

    def test_exactly_unequal_fractions(self):
        # Unequal Fractions, zero tolerances.
        pairs = ((1, 5), (7, 9), (12, 11), (101, 99023))
        for numerator, denominator in pairs:
            self.do_exactly_unequal_test(Fraction(numerator, denominator))

    def test_exactly_unequal_decimals(self):
        # Unequal Decimals, zero tolerances.
        for text in "3.1415 298.12 3.47 18.996 0.00245".split():
            self.do_exactly_unequal_test(Decimal(text))
462
463
class ApproxEqualInexactTest(unittest.TestCase):
    # Inexact test cases for approx_equal.
    # Test cases when comparing two values that are not exactly equal.

    # === Absolute error tests ===

    def do_approx_equal_abs_test(self, x, delta):
        # x +/- delta must pass with tol=2*delta and fail with tol=delta/2.
        template = "Test failure for x={!r}, y={!r}"
        for y in (x + delta, x - delta):
            msg = template.format(x, y)
            self.assertTrue(approx_equal(x, y, tol=2*delta, rel=0), msg)
            self.assertFalse(approx_equal(x, y, tol=delta/2, rel=0), msg)

    def test_approx_equal_absolute_ints(self):
        # Test approximate equality of ints with an absolute error.
        for n in [-10737, -1975, -7, -2, 0, 1, 9, 37, 423, 9874, 23789110]:
            self.do_approx_equal_abs_test(n, 10)
            self.do_approx_equal_abs_test(n, 2)

    def test_approx_equal_absolute_floats(self):
        # Test approximate equality of floats with an absolute error.
        for x in [-284.126, -97.1, -3.4, -2.15, 0.5, 1.0, 7.8, 4.23, 3817.4]:
            self.do_approx_equal_abs_test(x, 1.5)
            self.do_approx_equal_abs_test(x, 0.01)
            self.do_approx_equal_abs_test(x, 0.0001)

    def test_approx_equal_absolute_fractions(self):
        # Test approximate equality of Fractions with an absolute error.
        delta = Fraction(1, 29)
        numerators = [-84, -15, -2, -1, 0, 1, 5, 17, 23, 34, 71]
        for f in (Fraction(n, 29) for n in numerators):
            self.do_approx_equal_abs_test(f, delta)
            self.do_approx_equal_abs_test(f, float(delta))

    def test_approx_equal_absolute_decimals(self):
        # Test approximate equality of Decimals with an absolute error.
        delta = Decimal("0.01")
        for d in map(Decimal, "1.0 3.5 36.08 61.79 7912.3648".split()):
            self.do_approx_equal_abs_test(d, delta)
            self.do_approx_equal_abs_test(-d, delta)

    def test_cross_zero(self):
        # Test for the case of the two values having opposite signs.
        self.assertTrue(approx_equal(1e-5, -1e-5, tol=1e-4, rel=0))

    # === Relative error tests ===

    def do_approx_equal_rel_test(self, x, delta):
        # x*(1 +/- delta) must pass with rel=2*delta and fail with
        # rel=delta/2.
        template = "Test failure for x={!r}, y={!r}"
        for y in (x*(1+delta), x*(1-delta)):
            msg = template.format(x, y)
            self.assertTrue(approx_equal(x, y, tol=0, rel=2*delta), msg)
            self.assertFalse(approx_equal(x, y, tol=0, rel=delta/2), msg)

    def test_approx_equal_relative_ints(self):
        # Test approximate equality of ints with a relative error.
        self.assertTrue(approx_equal(64, 47, tol=0, rel=0.36))
        self.assertTrue(approx_equal(64, 47, tol=0, rel=0.37))
        # The relative error is measured against the larger magnitude:
        # 17/64 = 0.265625, so rel=0.26 must fail and rel=0.27 must pass.
        self.assertFalse(approx_equal(64, 47, tol=0, rel=0.26))
        self.assertTrue(approx_equal(64, 47, tol=0, rel=0.27))
        # ---
        self.assertTrue(approx_equal(449, 512, tol=0, rel=0.125))
        self.assertTrue(approx_equal(448, 512, tol=0, rel=0.125))
        self.assertFalse(approx_equal(447, 512, tol=0, rel=0.125))

    def test_approx_equal_relative_floats(self):
        # Test approximate equality of floats with a relative error.
        for x in [-178.34, -0.1, 0.1, 1.0, 36.97, 2847.136, 9145.074]:
            self.do_approx_equal_rel_test(x, 0.02)
            self.do_approx_equal_rel_test(x, 0.0001)

    def test_approx_equal_relative_fractions(self):
        # Test approximate equality of Fractions with a relative error.
        F = Fraction
        delta = Fraction(3, 8)
        for f in [F(3, 84), F(17, 30), F(49, 50), F(92, 85)]:
            for d in (delta, float(delta)):
                self.do_approx_equal_rel_test(f, d)
                self.do_approx_equal_rel_test(-f, d)

    def test_approx_equal_relative_decimals(self):
        # Test approximate equality of Decimals with a relative error.
        for d in map(Decimal, "0.02 1.0 5.7 13.67 94.138 91027.9321".split()):
            self.do_approx_equal_rel_test(d, Decimal("0.001"))
            self.do_approx_equal_rel_test(-d, Decimal("0.05"))

    # === Both absolute and relative error tests ===

    # There are four cases to consider:
    #   1) actual error <= both absolute and relative error
    #   2) actual error <= absolute error but > relative error
    #   3) actual error <= relative error but > absolute error
    #   4) actual error > both absolute and relative error

    def do_check_both(self, a, b, tol, rel, tol_flag, rel_flag):
        # Check each tolerance alone, then both together; the combined
        # check passes if either individual check does.
        check = self.assertTrue if tol_flag else self.assertFalse
        check(approx_equal(a, b, tol=tol, rel=0))
        check = self.assertTrue if rel_flag else self.assertFalse
        check(approx_equal(a, b, tol=0, rel=rel))
        check = self.assertTrue if (tol_flag or rel_flag) else self.assertFalse
        check(approx_equal(a, b, tol=tol, rel=rel))

    def test_approx_equal_both1(self):
        # Test actual error <= both absolute and relative error.
        self.do_check_both(7.955, 7.952, 0.004, 3.8e-4, True, True)
        self.do_check_both(-7.387, -7.386, 0.002, 0.0002, True, True)

    def test_approx_equal_both2(self):
        # Test actual error <= absolute error but > relative error.
        self.do_check_both(7.955, 7.952, 0.004, 3.7e-4, True, False)

    def test_approx_equal_both3(self):
        # Test actual error <= relative error but > absolute error.
        self.do_check_both(7.955, 7.952, 0.001, 3.8e-4, False, True)

    def test_approx_equal_both4(self):
        # Test actual error > both absolute and relative error.
        self.do_check_both(2.78, 2.75, 0.01, 0.001, False, False)
        self.do_check_both(971.44, 971.47, 0.02, 3e-5, False, False)
581
582
class ApproxEqualSpecialsTest(unittest.TestCase):
    # approx_equal with infinities, NANs and signed zeroes.

    def test_inf(self):
        for kind in (float, Decimal):
            inf = kind('inf')
            # Same-signed infinities match, whatever the tolerances.
            for extra_args in ((), (0, 0), (1, 0.01)):
                self.assertTrue(approx_equal(inf, inf, *extra_args))
            self.assertTrue(approx_equal(-inf, -inf))
            # Opposite signs, or infinite versus finite, never match.
            self.assertFalse(approx_equal(inf, -inf))
            self.assertFalse(approx_equal(inf, 1000))

    def test_nan(self):
        for kind in (float, Decimal):
            nan = kind('nan')
            # A NAN matches nothing, itself included.
            for other in (nan, kind('inf'), 1000):
                self.assertFalse(approx_equal(nan, other))

    def test_float_zeroes(self):
        # Negative and positive float zero compare approximately equal.
        negative_zero = math.copysign(0.0, -1)
        self.assertTrue(approx_equal(negative_zero, 0.0, tol=0.1, rel=0.1))

    def test_decimal_zeroes(self):
        # Negative and positive Decimal zero compare approximately equal.
        negative_zero = Decimal("-0.0")
        self.assertTrue(
            approx_equal(negative_zero, Decimal(0), tol=0.1, rel=0.1))
609
610
class TestApproxEqualErrors(unittest.TestCase):
    # approx_equal must reject negative error tolerances.

    def test_bad_tol(self):
        # A negative absolute tolerance raises ValueError.
        with self.assertRaises(ValueError):
            approx_equal(100, 100, -1, 0.1)

    def test_bad_rel(self):
        # A negative relative tolerance raises ValueError.
        with self.assertRaises(ValueError):
            approx_equal(100, 100, 1, -0.1)
621
622
623# --- Tests for NumericTestCase ---
624
625# The formatting routine that generates the error messages is complex enough
626# that it too needs testing.
627
class TestNumericTestCase(unittest.TestCase):
    # The exact wording of NumericTestCase failure messages is *not*
    # guaranteed, yet the messages still deserve a test. As a compromise we
    # only look for key substrings that should survive any rewording.

    def do_test(self, args):
        message = NumericTestCase._make_std_err_msg(*args)
        for fragment in self.generate_substrings(*args):
            self.assertIn(fragment, message)

    def test_numerictestcase_is_testcase(self):
        # NumericTestCase must genuinely be a unittest.TestCase.
        self.assertTrue(issubclass(NumericTestCase, unittest.TestCase))

    def test_error_msg_numeric(self):
        # Message for a plain numeric comparison (no sequence index).
        self.do_test((2.5, 4.0, 0.5, 0.25, None))

    def test_error_msg_sequence(self):
        # Message for a mismatch inside a sequence (at index 7).
        self.do_test((3.75, 8.25, 1.25, 0.5, 7))

    def generate_substrings(self, first, second, tol, rel, idx):
        """Return substrings we expect to see in error messages."""
        abs_err, rel_err = _calc_errors(first, second)
        expected = [
            'tol=%r' % tol,
            'rel=%r' % rel,
            'absolute error = %r' % abs_err,
            'relative error = %r' % rel_err,
        ]
        if idx is not None:
            expected.append('differ at index %d' % idx)
        return expected
666
667
668# =======================================
669# === Tests for the statistics module ===
670# =======================================
671
672
class GlobalsTest(unittest.TestCase):
    module = statistics
    expected_metadata = ["__doc__", "__all__"]

    def test_meta(self):
        # The module must define the standard metadata attributes.
        for attr in self.expected_metadata:
            self.assertTrue(hasattr(self.module, attr),
                            "%s not present" % attr)

    def test_check_all(self):
        # Every name exported through __all__ must be public and must exist.
        mod = self.module
        for name in mod.__all__:
            is_private = name.startswith("_")
            self.assertFalse(is_private,
                             'private name "%s" in __all__' % name)
            self.assertTrue(hasattr(mod, name),
                            'missing name "%s" in __all__' % name)
693
694
class DocTests(unittest.TestCase):
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -OO and above")
    def test_doc_tests(self):
        # Run the statistics module's own doctests; "..." may elide output.
        outcome = doctest.testmod(statistics, optionflags=doctest.ELLIPSIS)
        self.assertGreater(outcome.attempted, 0)
        self.assertEqual(outcome.failed, 0)
702
class StatisticsErrorTest(unittest.TestCase):
    def test_has_exception(self):
        # statistics.StatisticsError must exist and derive from ValueError.
        self.assertTrue(hasattr(statistics, 'StatisticsError'))
        exc_type = statistics.StatisticsError
        errmsg = (
            "Expected StatisticsError to be a ValueError, but got a"
            " subclass of %r instead." % exc_type.__base__
        )
        self.assertTrue(issubclass(exc_type, ValueError), errmsg)
714
715
716# === Tests for private utility functions ===
717
class ExactRatioTest(unittest.TestCase):
    """Tests for the private helper statistics._exact_ratio."""

    def test_int(self):
        # An int n is returned as the pair (n, 1).
        to_ratio = statistics._exact_ratio
        for n in (-20, -3, 0, 5, 99, 10**20):
            self.assertEqual(to_ratio(n), (n, 1))

    def test_fraction(self):
        # Fraction(n, 37) is returned as (n, 37) for each numerator tried.
        for numerator in (-5, 1, 12, 38):
            value = Fraction(numerator, 37)
            self.assertEqual(statistics._exact_ratio(value), (numerator, 37))

    def test_float(self):
        # Two hand-picked floats first, then random floats which must
        # round-trip through numerator/denominator division.
        to_ratio = statistics._exact_ratio
        self.assertEqual(to_ratio(0.125), (1, 8))
        self.assertEqual(to_ratio(1.125), (9, 8))
        samples = [random.uniform(-100, 100) for _ in range(100)]
        for sample in samples:
            numerator, denominator = to_ratio(sample)
            self.assertEqual(sample, numerator/denominator)

    def test_decimal(self):
        # Decimals are converted to exact ratios in lowest terms.
        expectations = (
            ("0.125", (1, 8)),
            ("12.345", (2469, 200)),
            ("-1.98", (-99, 50)),
        )
        for text, pair in expectations:
            self.assertEqual(statistics._exact_ratio(Decimal(text)), pair)

    def test_inf(self):
        # Infinities come back as (x, None), with the exact type of x
        # (including subclasses) preserved.
        class MyFloat(float):
            pass
        class MyDecimal(Decimal):
            pass
        positive = float("INF")
        for inf in (positive, -positive):
            for kind in (float, MyFloat, Decimal, MyDecimal):
                x = kind(inf)
                result = statistics._exact_ratio(x)
                self.assertEqual(result, (x, None))
                self.assertEqual(type(result[0]), kind)
                self.assertTrue(math.isinf(result[0]))

    def test_float_nan(self):
        # A float NAN, or a float-subclass NAN, comes back as (nan, None)
        # with its type preserved.
        class MyFloat(float):
            pass
        plain = float("NAN")
        for nan in (plain, MyFloat(plain)):
            result = statistics._exact_ratio(nan)
            self.assertTrue(math.isnan(result[0]))
            self.assertIs(result[1], None)
            self.assertEqual(type(result[0]), type(nan))

    def test_decimal_nan(self):
        # Quiet and signalling Decimal NANs (and subclass instances) come
        # back as (nan, None) with the same kind of NAN and the same type.
        class MyDecimal(Decimal):
            pass
        quiet = Decimal("NAN")
        signalling = Decimal("sNAN")
        candidates = (quiet, MyDecimal(quiet), signalling, MyDecimal(signalling))
        for nan in candidates:
            result = statistics._exact_ratio(nan)
            self.assertTrue(_nan_equal(result[0], nan))
            self.assertIs(result[1], None)
            self.assertEqual(type(result[0]), type(nan))
780
781
class DecimalToRatioTest(unittest.TestCase):
    """Decimal-specific tests for the private helper statistics._exact_ratio."""

    def test_infinity(self):
        # Decimal infinities of either sign map to (inf, None).
        positive = Decimal('INF')
        for inf in (positive, -positive):
            self.assertEqual(statistics._exact_ratio(inf), (inf, None))

    def test_nan(self):
        # Quiet and signalling NANs map to (nan, None). NANs never compare
        # equal and object identity is not guaranteed, so the numerator is
        # compared with _nan_equal instead of assertEqual/assertIs.
        for nan in (Decimal('NAN'), Decimal('sNAN')):
            top, bottom = statistics._exact_ratio(nan)
            self.assertTrue(_nan_equal(top, nan))
            self.assertIs(bottom, None)

    def test_sign(self):
        # The numerator carries the sign; the denominator stays positive.
        for magnitude in (Decimal("9.8765e12"), Decimal("9.8765e-12")):
            assert magnitude > 0
            top, bottom = statistics._exact_ratio(magnitude)
            self.assertGreaterEqual(top, 0)
            self.assertGreater(bottom, 0)
            top, bottom = statistics._exact_ratio(-magnitude)
            self.assertLessEqual(top, 0)
            self.assertGreater(bottom, 0)

    def test_negative_exponent(self):
        # Exponent below zero: 0.1234 == 617/5000.
        ratio = statistics._exact_ratio(Decimal("0.1234"))
        self.assertEqual(ratio, (617, 5000))

    def test_positive_exponent(self):
        # Exponent above zero: 1.234e7 is the whole number 12340000.
        ratio = statistics._exact_ratio(Decimal("1.234e7"))
        self.assertEqual(ratio, (12340000, 1))

    def test_regression_20536(self):
        # Regression test for issue 20536.
        # See http://bugs.python.org/issue20536
        for text, expected in (("1e2", (100, 1)), ("1.47e5", (147000, 1))):
            self.assertEqual(statistics._exact_ratio(Decimal(text)), expected)
832
833
class IsFiniteTest(unittest.TestCase):
    """Tests for the private predicate statistics._isfinite."""

    def test_finite(self):
        # Ordinary values of each supported numeric type are finite.
        values = [5, Fraction(1, 3), 2.5, Decimal("5.5")]
        for value in values:
            self.assertTrue(statistics._isfinite(value))

    def test_infinity(self):
        # float and Decimal infinities are not finite.
        values = [float("inf"), Decimal("inf")]
        for value in values:
            self.assertFalse(statistics._isfinite(value))

    def test_nan(self):
        # float NAN, Decimal NAN and Decimal sNAN are not finite.
        values = [float("nan"), Decimal("NAN"), Decimal("sNAN")]
        for value in values:
            self.assertFalse(statistics._isfinite(value))
851
852
class CoerceTest(unittest.TestCase):
    # Test that private function _coerce correctly deals with types.

    # The coercion rules are currently an implementation detail, although at
    # some point that should change. The tests and comments here define the
    # correct implementation.

    # Pre-conditions and rules of _coerce:
    #
    #   - coerce(T, S) will never be called with bool as the first argument;
    #     this is a pre-condition, guarded with an assertion.
    #
    #   - coerce(T, T) will always return T; we assume T is a valid numeric
    #     type. Violate this assumption at your own risk.
    #
    #   - Apart from as above, bool is treated as if it were actually int.
    #
    #   - coerce(int, X) and coerce(X, int) return X.

    def test_bool(self):
        # bool is somewhat special, due to the pre-condition that it is
        # never given as the first argument to _coerce, and that it cannot
        # be subclassed. So we test it specially.
        for T in (int, float, Fraction, Decimal):
            self.assertIs(statistics._coerce(T, bool), T)
            class MyClass(T): pass
            self.assertIs(statistics._coerce(MyClass, bool), MyClass)

    def assertCoerceTo(self, A, B):
        """Assert that type A coerces to B, in either argument order."""
        self.assertIs(statistics._coerce(A, B), B)
        self.assertIs(statistics._coerce(B, A), B)

    def check_coerce_to(self, A, B):
        """Checks that type A coerces to B, including subclasses."""
        # Assert that type A is coerced to B.
        self.assertCoerceTo(A, B)
        # Subclasses of A are also coerced to B.
        class SubclassOfA(A): pass
        self.assertCoerceTo(SubclassOfA, B)
        # A, and subclasses of A, are coerced to subclasses of B.
        class SubclassOfB(B): pass
        self.assertCoerceTo(A, SubclassOfB)
        self.assertCoerceTo(SubclassOfA, SubclassOfB)

    def assertCoerceRaises(self, A, B):
        """Assert that coercing A to B, or vice versa, raises TypeError."""
        # A and B must be passed as two separate arguments. Passing the
        # tuple (A, B) as a single argument makes the call fail with a
        # TypeError for the missing argument, so the assertion could never
        # fail, whatever _coerce does.
        self.assertRaises(TypeError, statistics._coerce, A, B)
        self.assertRaises(TypeError, statistics._coerce, B, A)

    def check_type_coercions(self, T):
        """Check that type T coerces correctly with subclasses of itself."""
        assert T is not bool
        # Coercing a type with itself returns the same type.
        self.assertIs(statistics._coerce(T, T), T)
        # Coercing a type with a subclass of itself returns the subclass.
        class U(T): pass
        class V(T): pass
        class W(U): pass
        for typ in (U, V, W):
            self.assertCoerceTo(T, typ)
        self.assertCoerceTo(U, W)
        # Coercing two subclasses that aren't parent/child is an error.
        # Subclasses of int are exempt: like int itself, an int subclass
        # coerces to the other type instead of raising.
        if not issubclass(T, int):
            self.assertCoerceRaises(U, V)
            self.assertCoerceRaises(V, W)

    def test_int(self):
        # Check that int coerces correctly.
        self.check_type_coercions(int)
        for typ in (float, Fraction, Decimal):
            self.check_coerce_to(int, typ)

    def test_fraction(self):
        # Check that Fraction coerces correctly.
        self.check_type_coercions(Fraction)
        self.check_coerce_to(Fraction, float)

    def test_decimal(self):
        # Check that Decimal coerces correctly.
        self.check_type_coercions(Decimal)

    def test_float(self):
        # Check that float coerces correctly.
        self.check_type_coercions(float)

    def test_non_numeric_types(self):
        # int is left out because coerce(int, X) and coerce(X, int) return X
        # (see the rules above), even when X is not a numeric type.
        for bad_type in (str, list, type(None), tuple, dict):
            for good_type in (float, Fraction, Decimal):
                self.assertCoerceRaises(good_type, bad_type)

    def test_incompatible_types(self):
        # Test that incompatible types raise.
        for T in (float, Fraction):
            class MySubclass(T): pass
            self.assertCoerceRaises(T, Decimal)
            self.assertCoerceRaises(MySubclass, Decimal)
951
952
class ConvertTest(unittest.TestCase):
    """Tests for the private helper statistics._convert."""

    def check_exact_equal(self, x, y):
        """Check that x equals y, and has the same type as well."""
        self.assertEqual(x, y)
        self.assertIs(type(x), type(y))

    def test_int(self):
        # Whole-number Fractions convert to int and to an int subclass.
        convert = statistics._convert
        self.check_exact_equal(convert(Fraction(71), int), 71)
        class MyInt(int): pass
        self.check_exact_equal(convert(Fraction(17), MyInt), MyInt(17))

    def test_fraction(self):
        # Fractions convert to Fraction and to a Fraction subclass whose
        # division returns the subclass.
        convert = statistics._convert
        self.check_exact_equal(convert(Fraction(95, 99), Fraction),
                               Fraction(95, 99))
        class MyFraction(Fraction):
            def __truediv__(self, other):
                return self.__class__(super().__truediv__(other))
        self.check_exact_equal(convert(Fraction(71, 13), MyFraction),
                               MyFraction(71, 13))

    def test_float(self):
        # Fractions convert to float and to a float subclass whose division
        # returns the subclass.
        convert = statistics._convert
        self.check_exact_equal(convert(Fraction(-1, 2), float), -0.5)
        class MyFloat(float):
            def __truediv__(self, other):
                return self.__class__(super().__truediv__(other))
        self.check_exact_equal(convert(Fraction(9, 8), MyFloat),
                               MyFloat(1.125))

    def test_decimal(self):
        # Fractions convert to Decimal and to a Decimal subclass whose
        # division returns the subclass.
        convert = statistics._convert
        self.check_exact_equal(convert(Fraction(1, 40), Decimal),
                               Decimal("0.025"))
        class MyDecimal(Decimal):
            def __truediv__(self, other):
                return self.__class__(super().__truediv__(other))
        self.check_exact_equal(convert(Fraction(-15, 16), MyDecimal),
                               MyDecimal("-0.9375"))

    def test_inf(self):
        # Infinities convert to themselves when the target is their own type.
        for positive in (float('inf'), Decimal('inf')):
            for value in (positive, -positive):
                converted = statistics._convert(value, type(value))
                self.check_exact_equal(converted, value)

    def test_nan(self):
        # NANs survive conversion to their own type as the same kind of NAN.
        for value in (float('nan'), Decimal('NAN'), Decimal('sNAN')):
            converted = statistics._convert(value, type(value))
            self.assertTrue(_nan_equal(converted, value))

    def test_invalid_input_type(self):
        # A value that cannot be converted at all raises TypeError.
        with self.assertRaises(TypeError):
            statistics._convert(None, float)
1013
1014
class FailNegTest(unittest.TestCase):
    """Test _fail_neg private function."""

    def test_pass_through(self):
        # Test that values are passed through unchanged.
        values = [1, 2.0, Fraction(3), Decimal(4)]
        new = list(statistics._fail_neg(values))
        self.assertEqual(values, new)

    def test_negatives_raise(self):
        # Test that negatives raise an exception. The error is expected when
        # the returned iterator is advanced.
        for x in [1, 2.0, Fraction(3), Decimal(4)]:
            with self.subTest(x=x):
                it = statistics._fail_neg([-x])
                self.assertRaises(statistics.StatisticsError, next, it)

    def test_error_msg(self):
        # Test that a given error message is used.
        msg = "badness #%d" % random.randint(10000, 99999)
        with self.assertRaises(statistics.StatisticsError) as cm:
            next(statistics._fail_neg([-1], msg))
        self.assertEqual(cm.exception.args[0], msg)
1041
1042
class FindLteqTest(unittest.TestCase):
    """Tests for the private helper statistics._find_lteq."""

    def test_invalid_input_values(self):
        # Each (a, x) pair here must be rejected with ValueError.
        bad_cases = (
            ([], 1),
            ([1, 2], 3),
            ([1, 3], 2),
        )
        for a, x in bad_cases:
            with self.subTest(a=a, x=x), self.assertRaises(ValueError):
                statistics._find_lteq(a, x)

    def test_locate_successfully(self):
        # The expected index is the first position holding x.
        good_cases = (
            ([1, 1, 1, 2, 3], 1, 0),
            ([0, 1, 1, 1, 2, 3], 1, 1),
            ([1, 2, 3, 3, 3], 3, 2),
        )
        for a, x, expected_i in good_cases:
            with self.subTest(a=a, x=x):
                self.assertEqual(expected_i, statistics._find_lteq(a, x))
1064
1065
class FindRteqTest(unittest.TestCase):
    # Test _find_rteq private function.

    def test_invalid_input_values(self):
        # Each (a, l, x) triple here must be rejected with ValueError. Use
        # subTest (as FindLteqTest does) so a failure names the bad case.
        for a, l, x in [
            ([1], 2, 1),
            ([1, 3], 0, 2)
        ]:
            with self.subTest(a=a, l=l, x=x):
                with self.assertRaises(ValueError):
                    statistics._find_rteq(a, l, x)

    def test_locate_successfully(self):
        # The expected index is the last position holding x.
        for a, l, x, expected_i in [
            ([1, 1, 1, 2, 3], 0, 1, 2),
            ([0, 1, 1, 1, 2, 3], 0, 1, 3),
            ([1, 2, 3, 3, 3], 0, 3, 4)
        ]:
            with self.subTest(a=a, l=l, x=x):
                self.assertEqual(expected_i, statistics._find_rteq(a, l, x))
1085
1086
1087# === Tests for public functions ===
1088
class UnivariateCommonMixin:
    """Common tests for most univariate functions that take a data argument.

    Concrete test classes combine this mixin with a TestCase base and set
    ``self.func`` to the function under test (usually in ``setUp``).
    """

    def test_no_args(self):
        # Fail if given no arguments.
        self.assertRaises(TypeError, self.func)

    def test_empty_data(self):
        # Fail when the data argument (first argument) is empty.
        for empty in ([], (), iter([])):
            self.assertRaises(statistics.StatisticsError, self.func, empty)

    def prepare_data(self):
        """Return int data for various tests.

        The result is the ints 0..9 shuffled until they are not in sorted
        order. Subclasses may override this.
        """
        data = list(range(10))
        while data == sorted(data):
            random.shuffle(data)
        return data

    def test_no_inplace_modifications(self):
        # Test that the function does not modify its input data.
        data = self.prepare_data()
        assert len(data) != 1  # Necessary to avoid infinite loop.
        assert data != sorted(data)
        saved = data[:]
        assert data is not saved
        _ = self.func(data)
        self.assertListEqual(data, saved, "data has been modified")

    def test_order_doesnt_matter(self):
        # Test that the order of data points doesn't change the result.

        # CAUTION: due to floating point rounding errors, the result actually
        # may depend on the order. Consider this test representing an ideal.
        # To avoid this test failing, only test with exact values such as ints
        # or Fractions.
        data = [1, 2, 3, 3, 3, 4, 5, 6]*100
        expected = self.func(data)
        random.shuffle(data)
        actual = self.func(data)
        self.assertEqual(expected, actual)

    def test_type_of_data_collection(self):
        # Test that the type of iterable data doesn't affect the result.
        class MyList(list):
            pass
        class MyTuple(tuple):
            pass
        def generator(data):
            return (obj for obj in data)
        data = self.prepare_data()
        expected = self.func(data)
        for kind in (list, tuple, iter, MyList, MyTuple, generator):
            result = self.func(kind(data))
            self.assertEqual(result, expected)

    def test_range_data(self):
        # Test that functions work with range objects.
        data = range(20, 50, 3)
        expected = self.func(list(data))
        self.assertEqual(self.func(data), expected)

    def test_bad_arg_types(self):
        # Test that function raises when given data of the wrong type.

        # Don't roll the following into a loop like this:
        #   for bad in list_of_bad:
        #       self.check_for_type_error(bad)
        #
        # Since assertRaises doesn't show the arguments that caused the test
        # failure, it is very difficult to debug these test failures when the
        # following are in a loop.
        self.check_for_type_error(None)
        self.check_for_type_error(23)
        self.check_for_type_error(42.0)
        self.check_for_type_error(object())

    def check_for_type_error(self, *args):
        """Assert that calling self.func(*args) raises TypeError."""
        self.assertRaises(TypeError, self.func, *args)

    def test_type_of_data_element(self):
        # Check the type of data elements doesn't affect the numeric result.
        # This is a weaker test than UnivariateTypeMixin.test_types_conserved,
        # because it checks the numeric result by equality, but not by type.
        class MyFloat(float):
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__

        raw = self.prepare_data()
        expected = self.func(raw)
        for kind in (float, MyFloat, Decimal, Fraction):
            data = [kind(x) for x in raw]
            result = type(expected)(self.func(data))
            self.assertEqual(result, expected)
1186
1187
class UnivariateTypeMixin:
    """Mixin class for type-conserving functions.

    Holds test(s) for functions whose result has the same type as the
    individual data points; e.g. the mean of a list of Fractions should
    itself be a Fraction.

    Only tests that rely on the function returning the type of its input
    data belong here; other type-related tests may live elsewhere.
    """

    def prepare_types_for_conservation_test(self):
        """Return the types which are expected to be conserved.

        MyFloat overrides its arithmetic operators so that results stay
        MyFloat instead of becoming plain float.
        """
        class MyFloat(float):
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__
            def __sub__(self, other):
                return type(self)(super().__sub__(other))
            def __rsub__(self, other):
                return type(self)(super().__rsub__(other))
            def __mul__(self, other):
                return type(self)(super().__mul__(other))
            __rmul__ = __mul__
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))
            def __rtruediv__(self, other):
                return type(self)(super().__rtruediv__(other))
            def __pow__(self, other):
                return type(self)(super().__pow__(other))
        conserved = (float, Decimal, Fraction, MyFloat)
        return conserved

    def test_types_conserved(self):
        # Only the type of the result is checked, not its value, and only
        # for homogeneous data (mixed data types are excluded).
        raw = self.prepare_data()
        for kind in self.prepare_types_for_conservation_test():
            converted = list(map(kind, raw))
            outcome = self.func(converted)
            self.assertIs(type(outcome), kind)
1228
1229
class TestSumCommon(UnivariateCommonMixin, UnivariateTypeMixin):
    # Common test cases for statistics._sum() function.

    # This test suite looks only at the numeric value returned by _sum,
    # after conversion to the appropriate type.
    #
    # NOTE: neither base class is a unittest.TestCase, so the tests in this
    # class are not collected or run as written.
    def setUp(self):
        def simplified_sum(*args):
            # _sum returns (type, exact Fraction total, count). Convert the
            # total back to the data type with _convert; _coerce combines two
            # *types* and cannot convert a value.
            T, value, n = statistics._sum(*args)
            return statistics._convert(value, T)
        self.func = simplified_sum
1240
1241
class TestSum(NumericTestCase):
    # Test cases for statistics._sum() function.
    # These tests check the whole (type, total, count) tuple returned by _sum,
    # not just the converted total.

    def setUp(self):
        self.func = statistics._sum

    def test_empty_data(self):
        # Override test for empty data: it sums to (int, Fraction(0), 0).
        expected = (int, Fraction(0), 0)
        for empty in ([], (), iter([])):
            self.assertEqual(self.func(empty), expected)

    def test_ints(self):
        data = [1, 5, 3, -4, -8, 20, 42, 1]
        self.assertEqual(self.func(data), (int, Fraction(60), 8))

    def test_floats(self):
        data = [0.25]*20
        self.assertEqual(self.func(data), (float, Fraction(5.0), 20))

    def test_fractions(self):
        data = [Fraction(1, 1000)]*500
        self.assertEqual(self.func(data), (Fraction, Fraction(1, 2), 500))

    def test_decimals(self):
        texts = ("0.001", "5.246", "1.702", "-0.025",
                 "3.974", "2.328", "4.617", "2.843")
        data = [Decimal(text) for text in texts]
        self.assertEqual(self.func(data), (Decimal, Decimal("20.686"), 8))

    def test_compare_with_math_fsum(self):
        # Ideally _sum matches math.fsum exactly, but the two sometimes
        # differ by a very slight amount, hence the tiny tolerance.
        data = [random.uniform(-100, 1000) for _ in range(1000)]
        total = self.func(data)[1]
        self.assertApproxEqual(float(total), math.fsum(data), rel=2e-16)

    def test_strings_fail(self):
        # Sum of strings should fail.
        # NOTE(review): if _sum accepts only a data argument, the
        # two-argument call raises TypeError because of the extra argument,
        # not because of the string. Confirm against _sum's signature.
        self.assertRaises(TypeError, self.func, [1, 2, 3], '999')
        self.assertRaises(TypeError, self.func, [1, 2, 3, '999'])

    def test_bytes_fail(self):
        # Sum of bytes should fail.
        # NOTE(review): same caveat as test_strings_fail for the
        # two-argument call.
        self.assertRaises(TypeError, self.func, [1, 2, 3], b'999')
        self.assertRaises(TypeError, self.func, [1, 2, 3, b'999'])

    def test_mixed_sum(self):
        # Mixed input types are not (currently) allowed, whether they appear
        # inside the data or as a separate start argument.
        # NOTE(review): as in test_strings_fail, the two-argument call only
        # exercises type mixing if _sum still accepts a start argument.
        self.assertRaises(TypeError, self.func, [1, 2.0, Decimal(1)])
        self.assertRaises(TypeError, self.func, [1, 2.0], Decimal(1))
1298
1299
class SumTortureTest(NumericTestCase):
    """Tim Peters' torture test for _sum, and variants of same."""

    def test_torture(self):
        # Huge values that cancel out must not swallow the small ones.
        expected = (float, Fraction(20000.0), 40000)
        for pattern in ([1, 1e100, 1, -1e100], [1e100, 1, 1, -1e100]):
            self.assertEqual(statistics._sum(pattern*10000), expected)
        # Tiny values next to cancelling ones must survive too.
        kind, total, count = statistics._sum([1e-100, 1, 1e-100, -1]*10000)
        self.assertIs(kind, float)
        self.assertEqual(count, 40000)
        self.assertApproxEqual(float(total), 2.0e-96, rel=5e-16)
1311
1312
class SumSpecialValues(NumericTestCase):
    # Test that sum works correctly with IEEE-754 special values.

    def test_nan(self):
        # A NAN in the data gives a NAN total of the same type.
        for type_ in (float, Decimal):
            nan = type_('nan')
            result = statistics._sum([1, nan, 2])[1]
            self.assertIs(type(result), type_)
            self.assertTrue(math.isnan(result))

    def check_infinity(self, x, inf):
        """Check x is an infinity of the same type and sign as inf."""
        self.assertTrue(math.isinf(x))
        self.assertIs(type(x), type(inf))
        self.assertEqual(x > 0, inf > 0)
        # A real assertion rather than a bare assert, which -O would strip.
        self.assertEqual(x, inf)

    def do_test_inf(self, inf):
        """Check that summing data containing inf gives that infinity."""
        # Adding a single infinity gives infinity.
        result = statistics._sum([1, 2, inf, 3])[1]
        self.check_infinity(result, inf)
        # Adding two infinities of the same sign also gives infinity.
        result = statistics._sum([1, 2, inf, 3, inf, 4])[1]
        self.check_infinity(result, inf)

    def test_float_inf(self):
        inf = float('inf')
        # Loop variable is `s`, not `sign`, to avoid shadowing the
        # module-level sign() helper.
        for s in (+1, -1):
            self.do_test_inf(s*inf)

    def test_decimal_inf(self):
        inf = Decimal('inf')
        # Loop variable is `s`, not `sign`, to avoid shadowing the
        # module-level sign() helper.
        for s in (+1, -1):
            self.do_test_inf(s*inf)

    def test_float_mismatched_infs(self):
        # Test that adding two infinities of opposite sign gives a NAN.
        inf = float('inf')
        result = statistics._sum([1, 2, inf, 3, -inf, 4])[1]
        self.assertTrue(math.isnan(result))

    def test_decimal_extendedcontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign returns NAN.
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.ExtendedContext):
            self.assertTrue(math.isnan(statistics._sum(data)[1]))

    def test_decimal_basiccontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign raises InvalidOperation.
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.BasicContext):
            self.assertRaises(decimal.InvalidOperation, statistics._sum, data)

    def test_decimal_snan_raises(self):
        # Adding sNAN should raise InvalidOperation.
        sNAN = Decimal('sNAN')
        data = [1, sNAN, 2]
        self.assertRaises(decimal.InvalidOperation, statistics._sum, data)
1373
1374
1375# === Tests for averages ===
1376
class AverageMixin(UnivariateCommonMixin):
    # Mixin class holding common tests for averages.

    def test_single_value(self):
        # Average of a single value is the value itself.
        # subTest names the failing value, as test_repeated_single_value does.
        for x in (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28')):
            with self.subTest(x=x):
                self.assertEqual(self.func([x]), x)

    def prepare_values_for_repeated_single_test(self):
        """Return the values used by test_repeated_single_value.

        Subclasses may override this.
        """
        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712'))

    def test_repeated_single_value(self):
        # The average of a single repeated value is the value itself.
        for x in self.prepare_values_for_repeated_single_test():
            for count in (2, 5, 10, 20):
                with self.subTest(x=x, count=count):
                    data = [x]*count
                    self.assertEqual(self.func(data), x)
1395
1396
class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for statistics.mean.

    def setUp(self):
        self.func = statistics.mean

    def test_torture_pep(self):
        # "Torture Test" from PEP-450.
        self.assertEqual(self.func([1e100, 1, 3, -1e100]), 1)

    def test_ints(self):
        # Test mean with ints.
        data = [0, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 7, 7, 7, 8, 9]
        random.shuffle(data)
        self.assertEqual(self.func(data), 4.8125)

    def test_floats(self):
        # Test mean with floats.
        data = [17.25, 19.75, 20.0, 21.5, 21.75, 23.25, 25.125, 27.5]
        random.shuffle(data)
        self.assertEqual(self.func(data), 22.015625)

    def test_decimals(self):
        # Test mean with Decimals.
        D = Decimal
        data = [D("1.634"), D("2.517"), D("3.912"), D("4.072"), D("5.813")]
        random.shuffle(data)
        self.assertEqual(self.func(data), D("3.5896"))

    def test_fractions(self):
        # Test mean with Fractions.
        F = Fraction
        data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)]
        random.shuffle(data)
        self.assertEqual(self.func(data), F(1479, 1960))

    def test_inf(self):
        # Test mean with infinities.
        raw = [1, 3, 5, 7, 9]  # Use only ints, to avoid TypeError later.
        for kind in (float, Decimal):
            # Loop variable is `s`, not `sign`, to avoid shadowing the
            # module-level sign() helper.
            for s in (1, -1):
                inf = kind("inf")*s
                data = raw + [inf]
                result = self.func(data)
                self.assertTrue(math.isinf(result))
                self.assertEqual(result, inf)

    def test_mismatched_infs(self):
        # Test mean with infinities of opposite sign.
        data = [2, 4, 6, float('inf'), 1, 3, 5, float('-inf')]
        result = self.func(data)
        self.assertTrue(math.isnan(result))

    def test_nan(self):
        # Test mean with NANs.
        raw = [1, 3, 5, 7, 9]  # Use only ints, to avoid TypeError later.
        for kind in (float, Decimal):
            nan = kind("nan")
            data = raw + [nan]
            result = self.func(data)
            self.assertTrue(math.isnan(result))

    def test_big_data(self):
        # Test adding a large constant to every data point.
        c = 1e9
        data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
        expected = self.func(data) + c
        assert expected != c
        result = self.func([x+c for x in data])
        self.assertEqual(result, expected)

    def test_doubled_data(self):
        # Mean of [a,b,c...z] should be same as for [a,a,b,b,c,c...z,z].
        data = [random.uniform(-3, 5) for _ in range(1000)]
        expected = self.func(data)
        actual = self.func(data*2)
        self.assertApproxEqual(actual, expected)

    def test_regression_20561(self):
        # Regression test for issue 20561.
        # See http://bugs.python.org/issue20561
        d = Decimal('1e4')
        self.assertEqual(statistics.mean([d]), d)

    def test_regression_25177(self):
        # Regression test for issue 25177.
        # Ensure very big and very small floats don't overflow.
        # See http://bugs.python.org/issue25177.
        self.assertEqual(statistics.mean(
            [8.988465674311579e+307, 8.98846567431158e+307]),
            8.98846567431158e+307)
        big = 8.98846567431158e+307
        tiny = 5e-324
        for n in (2, 3, 5, 200):
            self.assertEqual(statistics.mean([big]*n), big)
            self.assertEqual(statistics.mean([tiny]*n), tiny)
1491
1492
1493class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
1494    def setUp(self):
1495        self.func = statistics.harmonic_mean
1496
1497    def prepare_data(self):
1498        # Override mixin method.
1499        values = super().prepare_data()
1500        values.remove(0)
1501        return values
1502
1503    def prepare_values_for_repeated_single_test(self):
1504        # Override mixin method.
1505        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.125'))
1506
1507    def test_zero(self):
1508        # Test that harmonic mean returns zero when given zero.
1509        values = [1, 0, 2]
1510        self.assertEqual(self.func(values), 0)
1511
1512    def test_negative_error(self):
1513        # Test that harmonic mean raises when given a negative value.
1514        exc = statistics.StatisticsError
1515        for values in ([-1], [1, -2, 3]):
1516            with self.subTest(values=values):
1517                self.assertRaises(exc, self.func, values)
1518
1519    def test_invalid_type_error(self):
1520        # Test error is raised when input contains invalid type(s)
1521        for data in [
1522            ['3.14'],               # single string
1523            ['1', '2', '3'],        # multiple strings
1524            [1, '2', 3, '4', 5],    # mixed strings and valid integers
1525            [2.3, 3.4, 4.5, '5.6']  # only one string and valid floats
1526        ]:
1527            with self.subTest(data=data):
1528                with self.assertRaises(TypeError):
1529                    self.func(data)
1530
1531    def test_ints(self):
1532        # Test harmonic mean with ints.
1533        data = [2, 4, 4, 8, 16, 16]
1534        random.shuffle(data)
1535        self.assertEqual(self.func(data), 6*4/5)
1536
1537    def test_floats_exact(self):
1538        # Test harmonic mean with some carefully chosen floats.
1539        data = [1/8, 1/4, 1/4, 1/2, 1/2]
1540        random.shuffle(data)
1541        self.assertEqual(self.func(data), 1/4)
1542        self.assertEqual(self.func([0.25, 0.5, 1.0, 1.0]), 0.5)
1543
1544    def test_singleton_lists(self):
1545        # Test that harmonic mean([x]) returns (approximately) x.
1546        for x in range(1, 101):
1547            self.assertEqual(self.func([x]), x)
1548
1549    def test_decimals_exact(self):
1550        # Test harmonic mean with some carefully chosen Decimals.
1551        D = Decimal
1552        self.assertEqual(self.func([D(15), D(30), D(60), D(60)]), D(30))
1553        data = [D("0.05"), D("0.10"), D("0.20"), D("0.20")]
1554        random.shuffle(data)
1555        self.assertEqual(self.func(data), D("0.10"))
1556        data = [D("1.68"), D("0.32"), D("5.94"), D("2.75")]
1557        random.shuffle(data)
1558        self.assertEqual(self.func(data), D(66528)/70723)
1559
1560    def test_fractions(self):
1561        # Test harmonic mean with Fractions.
1562        F = Fraction
1563        data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)]
1564        random.shuffle(data)
1565        self.assertEqual(self.func(data), F(7*420, 4029))
1566
1567    def test_inf(self):
1568        # Test harmonic mean with infinity.
1569        values = [2.0, float('inf'), 1.0]
1570        self.assertEqual(self.func(values), 2.0)
1571
1572    def test_nan(self):
1573        # Test harmonic mean with NANs.
1574        values = [2.0, float('nan'), 1.0]
1575        self.assertTrue(math.isnan(self.func(values)))
1576
1577    def test_multiply_data_points(self):
1578        # Test multiplying every data point by a constant.
1579        c = 111
1580        data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
1581        expected = self.func(data)*c
1582        result = self.func([x*c for x in data])
1583        self.assertEqual(result, expected)
1584
1585    def test_doubled_data(self):
1586        # Harmonic mean of [a,b...z] should be same as for [a,a,b,b...z,z].
1587        data = [random.uniform(1, 5) for _ in range(1000)]
1588        expected = self.func(data)
1589        actual = self.func(data*2)
1590        self.assertApproxEqual(actual, expected)
1591
1592    def test_with_weights(self):
1593        self.assertEqual(self.func([40, 60], [5, 30]), 56.0)  # common case
1594        self.assertEqual(self.func([40, 60],
1595                                   weights=[5, 30]), 56.0)    # keyword argument
1596        self.assertEqual(self.func(iter([40, 60]),
1597                                   iter([5, 30])), 56.0)      # iterator inputs
1598        self.assertEqual(
1599            self.func([Fraction(10, 3), Fraction(23, 5), Fraction(7, 2)], [5, 2, 10]),
1600            self.func([Fraction(10, 3)] * 5 +
1601                      [Fraction(23, 5)] * 2 +
1602                      [Fraction(7, 2)] * 10))
1603        self.assertEqual(self.func([10], [7]), 10)            # n=1 fast path
1604        with self.assertRaises(TypeError):
1605            self.func([1, 2, 3], [1, (), 3])                  # non-numeric weight
1606        with self.assertRaises(statistics.StatisticsError):
1607            self.func([1, 2, 3], [1, 2])                      # wrong number of weights
1608        with self.assertRaises(statistics.StatisticsError):
1609            self.func([10], [0])                              # no non-zero weights
1610        with self.assertRaises(statistics.StatisticsError):
1611            self.func([10, 20], [0, 0])                       # no non-zero weights
1612
1613
class TestMedian(NumericTestCase, AverageMixin):
    """Common tests for median(); the median_low/high/grouped test classes
    subclass this and override what differs."""

    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        """Overload method from UnivariateCommonMixin."""
        data = super().prepare_data()
        if not len(data) % 2:
            # Pad to an odd count so there is a single middle value.
            data.append(2)
        return data

    def test_even_ints(self):
        # Even count of ints: mean of the two middle values, 3 and 4.
        data = list(range(1, 7))
        assert not len(data) % 2
        self.assertEqual(self.func(data), 3.5)

    def test_odd_ints(self):
        # Odd count of ints: the single middle value.
        data = [*range(1, 7), 9]
        assert len(data) % 2
        self.assertEqual(self.func(data), 4)

    def test_odd_fractions(self):
        # Odd count of Fractions, in shuffled order.
        data = [Fraction(k, 7) for k in range(1, 6)]
        assert len(data) % 2
        random.shuffle(data)
        self.assertEqual(self.func(data), Fraction(3, 7))

    def test_even_fractions(self):
        # Even count of Fractions: (3/7 + 4/7) / 2 == 1/2.
        data = [Fraction(k, 7) for k in range(1, 7)]
        assert not len(data) % 2
        random.shuffle(data)
        self.assertEqual(self.func(data), Fraction(1, 2))

    def test_odd_decimals(self):
        # Odd count of Decimals, in shuffled order.
        data = list(map(Decimal, '2.5 3.1 4.2 5.7 5.8'.split()))
        assert len(data) % 2
        random.shuffle(data)
        self.assertEqual(self.func(data), Decimal('4.2'))

    def test_even_decimals(self):
        # Even count of Decimals: (3.1 + 4.2) / 2 == 3.65.
        data = list(map(Decimal, '1.2 2.5 3.1 4.2 5.7 5.8'.split()))
        assert not len(data) % 2
        random.shuffle(data)
        self.assertEqual(self.func(data), Decimal('3.65'))
1669
1670
class TestMedianDataType(NumericTestCase, UnivariateTypeMixin):
    """Check that median() conserves the type of the data elements."""

    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        # An odd-length permutation of 0..14 that is guaranteed to be unsorted.
        data = list(range(15))
        assert len(data) % 2
        while True:
            random.shuffle(data)
            if data != sorted(data):
                break
        return data
1682
1683
class TestMedianLow(TestMedian, UnivariateTypeMixin):
    """median_low: with an even count, the lower of the two middle values."""

    def setUp(self):
        self.func = statistics.median_low

    def test_even_ints(self):
        data = list(range(1, 7))
        assert not len(data) % 2
        self.assertEqual(self.func(data), 3)

    def test_even_fractions(self):
        data = [Fraction(k, 7) for k in range(1, 7)]
        assert not len(data) % 2
        random.shuffle(data)
        self.assertEqual(self.func(data), Fraction(3, 7))

    def test_even_decimals(self):
        # Decimals 1.1, 2.2, ..., 6.6
        data = [Decimal(f'{k}.{k}') for k in range(1, 7)]
        assert not len(data) % 2
        random.shuffle(data)
        self.assertEqual(self.func(data), Decimal('3.3'))
1709
1710
class TestMedianHigh(TestMedian, UnivariateTypeMixin):
    """median_high: with an even count, the higher of the two middle values."""

    def setUp(self):
        self.func = statistics.median_high

    def test_even_ints(self):
        data = list(range(1, 7))
        assert not len(data) % 2
        self.assertEqual(self.func(data), 4)

    def test_even_fractions(self):
        data = [Fraction(k, 7) for k in range(1, 7)]
        assert not len(data) % 2
        random.shuffle(data)
        self.assertEqual(self.func(data), Fraction(4, 7))

    def test_even_decimals(self):
        # Decimals 1.1, 2.2, ..., 6.6
        data = [Decimal(f'{k}.{k}') for k in range(1, 7)]
        assert not len(data) % 2
        random.shuffle(data)
        self.assertEqual(self.func(data), Decimal('4.4'))
1736
1737
class TestMedianGrouped(TestMedian):
    # Tests for median_grouped.
    # median_grouped does not conserve the type of the data elements, so,
    # unlike TestMedianLow and TestMedianHigh, this class does not mix in
    # UnivariateTypeMixin.
    def setUp(self):
        self.func = statistics.median_grouped

    def test_odd_number_repeated(self):
        # Test median_grouped with an odd number of points and repeated
        # values at the median.
        data = [12, 13, 14, 14, 14, 15, 15]
        assert len(data)%2 == 1
        self.assertEqual(self.func(data), 14)
        #---
        data = [12, 13, 14, 14, 14, 14, 15]
        assert len(data)%2 == 1
        self.assertEqual(self.func(data), 13.875)
        #---
        data = [5, 10, 10, 15, 20, 20, 20, 20, 25, 25, 30]
        assert len(data)%2 == 1
        self.assertEqual(self.func(data, 5), 19.375)
        #---
        data = [16, 18, 18, 18, 18, 20, 20, 20, 22, 22, 22, 24, 24, 26, 28]
        assert len(data)%2 == 1
        self.assertApproxEqual(self.func(data, 2), 20.66666667, tol=1e-8)

    def test_even_number_repeated(self):
        # Test median_grouped with an even number of points and repeated
        # values at the median.
        data = [5, 10, 10, 15, 20, 20, 20, 25, 25, 30]
        assert len(data)%2 == 0
        self.assertApproxEqual(self.func(data, 5), 19.16666667, tol=1e-8)
        #---
        data = [2, 3, 4, 4, 4, 5]
        assert len(data)%2 == 0
        self.assertApproxEqual(self.func(data), 3.83333333, tol=1e-8)
        #---
        data = [2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
        assert len(data)%2 == 0
        self.assertEqual(self.func(data), 4.5)
        #---
        data = [3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
        assert len(data)%2 == 0
        self.assertEqual(self.func(data), 4.75)

    def test_repeated_single_value(self):
        # Override method from AverageMixin.
        # median_grouped does not conserve the input type, so the expected
        # value is float(x) rather than x itself.
        for x in (5.3, 68, 4.3e17, Fraction(29, 101), Decimal('32.9714')):
            for count in (2, 5, 10, 20):
                with self.subTest(x=x, count=count):
                    data = [x]*count
                    self.assertEqual(self.func(data), float(x))

    def test_odd_fractions(self):
        # Test median_grouped works with an odd number of Fractions.
        F = Fraction
        data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4)]
        assert len(data)%2 == 1
        random.shuffle(data)
        self.assertEqual(self.func(data), 3.0)

    def test_even_fractions(self):
        # Test median_grouped works with an even number of Fractions.
        F = Fraction
        data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4), F(17, 4)]
        assert len(data)%2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), 3.25)

    def test_odd_decimals(self):
        # Test median_grouped works with an odd number of Decimals.
        D = Decimal
        data = [D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')]
        assert len(data)%2 == 1
        random.shuffle(data)
        self.assertEqual(self.func(data), 6.75)

    def test_even_decimals(self):
        # Test median_grouped works with an even number of Decimals.
        D = Decimal
        data = [D('5.5'), D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')]
        assert len(data)%2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), 6.5)
        #---
        data = [D('5.5'), D('5.5'), D('6.5'), D('7.5'), D('7.5'), D('8.5')]
        assert len(data)%2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), 7.0)

    def test_interval(self):
        # Test median_grouped with an explicit interval (class width) argument.
        data = [2.25, 2.5, 2.5, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
        self.assertEqual(self.func(data, 0.25), 2.875)
        data = [2.25, 2.5, 2.5, 2.75, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
        self.assertApproxEqual(self.func(data, 0.25), 2.83333333, tol=1e-8)
        data = [220, 220, 240, 260, 260, 260, 260, 280, 280, 300, 320, 340]
        self.assertEqual(self.func(data, 20), 265.0)

    def test_data_type_error(self):
        # str or bytes, either as the data or as the interval, raise TypeError.
        data = ["", "", ""]
        self.assertRaises(TypeError, self.func, data)
        #---
        data = [b"", b"", b""]
        self.assertRaises(TypeError, self.func, data)
        #---
        data = [1, 2, 3]
        interval = ""
        self.assertRaises(TypeError, self.func, data, interval)
        #---
        data = [1, 2, 3]
        interval = b""
        self.assertRaises(TypeError, self.func, data, interval)
1850
1851
class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    """Tests for the discrete mode, statistics.mode."""

    def setUp(self):
        self.func = statistics.mode

    def prepare_data(self):
        """Overload method from UnivariateCommonMixin."""
        # 1 occurs four times and every other value once: a unique mode.
        return [1, 1, 1, 1, 3, 4, 7, 9, 0, 8, 2]

    def test_range_data(self):
        # Overrides the UnivariateCommonMixin test. All values in the range
        # are distinct, so the first one is reported.
        self.assertEqual(self.func(range(20, 50, 3)), 20)

    def test_nominal_data(self):
        # mode() also accepts non-numeric (nominal) data.
        for data, expected in (('abcbdb', 'b'),
                               ('fe fi fo fum fi fi'.split(), 'fi')):
            self.assertEqual(self.func(data), expected)

    def test_discrete_data(self):
        # Duplicating any single value of 0..9 makes that value the mode.
        base = list(range(10))
        for duplicated in range(10):
            sample = base + [duplicated]
            random.shuffle(sample)
            self.assertEqual(self.func(sample), duplicated)

    def test_bimodal_data(self):
        # With two equally common values (2 and 6), the first one seen wins.
        data = [1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, 9, 9]
        assert data.count(2) == data.count(6) == 4
        self.assertEqual(self.func(data), 2)

    def test_unique_data(self):
        # All-distinct data: the first value encountered is the mode.
        self.assertEqual(self.func(list(range(10))), 0)

    def test_none_data(self):
        # None is not a valid data set. This is checked explicitly because
        # collections.Counter(None) quietly produces an empty counter.
        self.assertRaises(TypeError, self.func, None)

    def test_counter_data(self):
        # A Counter passed as data must be iterated like any other iterable
        # (over its keys 'a', 'b'), not treated as precomputed counts. Each
        # key then occurs once, so the first key, 'a', is the mode.
        counts = collections.Counter(a=1, b=2)
        self.assertEqual(self.func(counts), 'a')
1911
1912
class TestMultiMode(unittest.TestCase):
    """Tests for statistics.multimode."""

    def test_basics(self):
        # multimode returns a list of all most-common values ([] if empty).
        cases = [
            ('aabbbbbbbbcc', ['b']),
            ('aabbbbccddddeeffffgg', ['b', 'd', 'f']),
            ('', []),
        ]
        for data, expected in cases:
            self.assertEqual(statistics.multimode(data), expected)
1920
1921
1922class TestFMean(unittest.TestCase):
1923
1924    def test_basics(self):
1925        fmean = statistics.fmean
1926        D = Decimal
1927        F = Fraction
1928        for data, expected_mean, kind in [
1929            ([3.5, 4.0, 5.25], 4.25, 'floats'),
1930            ([D('3.5'), D('4.0'), D('5.25')], 4.25, 'decimals'),
1931            ([F(7, 2), F(4, 1), F(21, 4)], 4.25, 'fractions'),
1932            ([True, False, True, True, False], 0.60, 'booleans'),
1933            ([3.5, 4, F(21, 4)], 4.25, 'mixed types'),
1934            ((3.5, 4.0, 5.25), 4.25, 'tuple'),
1935            (iter([3.5, 4.0, 5.25]), 4.25, 'iterator'),
1936                ]:
1937            actual_mean = fmean(data)
1938            self.assertIs(type(actual_mean), float, kind)
1939            self.assertEqual(actual_mean, expected_mean, kind)
1940
1941    def test_error_cases(self):
1942        fmean = statistics.fmean
1943        StatisticsError = statistics.StatisticsError
1944        with self.assertRaises(StatisticsError):
1945            fmean([])                               # empty input
1946        with self.assertRaises(StatisticsError):
1947            fmean(iter([]))                         # empty iterator
1948        with self.assertRaises(TypeError):
1949            fmean(None)                             # non-iterable input
1950        with self.assertRaises(TypeError):
1951            fmean([10, None, 20])                   # non-numeric input
1952        with self.assertRaises(TypeError):
1953            fmean()                                 # missing data argument
1954        with self.assertRaises(TypeError):
1955            fmean([10, 20, 60], 70)                 # too many arguments
1956
1957    def test_special_values(self):
1958        # Rules for special values are inherited from math.fsum()
1959        fmean = statistics.fmean
1960        NaN = float('Nan')
1961        Inf = float('Inf')
1962        self.assertTrue(math.isnan(fmean([10, NaN])), 'nan')
1963        self.assertTrue(math.isnan(fmean([NaN, Inf])), 'nan and infinity')
1964        self.assertTrue(math.isinf(fmean([10, Inf])), 'infinity')
1965        with self.assertRaises(ValueError):
1966            fmean([Inf, -Inf])
1967
1968    def test_weights(self):
1969        fmean = statistics.fmean
1970        StatisticsError = statistics.StatisticsError
1971        self.assertEqual(
1972            fmean([10, 10, 10, 50], [0.25] * 4),
1973            fmean([10, 10, 10, 50]))
1974        self.assertEqual(
1975            fmean([10, 10, 20], [0.25, 0.25, 0.50]),
1976            fmean([10, 10, 20, 20]))
1977        self.assertEqual(                           # inputs are iterators
1978            fmean(iter([10, 10, 20]), iter([0.25, 0.25, 0.50])),
1979            fmean([10, 10, 20, 20]))
1980        with self.assertRaises(StatisticsError):
1981            fmean([10, 20, 30], [1, 2])             # unequal lengths
1982        with self.assertRaises(StatisticsError):
1983            fmean(iter([10, 20, 30]), iter([1, 2])) # unequal lengths
1984        with self.assertRaises(StatisticsError):
1985            fmean([10, 20], [-1, 1])                # sum of weights is zero
1986        with self.assertRaises(StatisticsError):
1987            fmean(iter([10, 20]), iter([-1, 1]))    # sum of weights is zero
1988
1989
1990# === Tests for variances and standard deviations ===
1991
class VarianceStdevMixin(UnivariateCommonMixin):
    # Mixin class holding common tests for variance and std dev.

    # Subclasses should list this mixin before NumericTestCase in their
    # bases, so that the rel attribute below comes earlier in the MRO than
    # any rel provided by NumericTestCase. See test_shift_data for why a
    # tolerance is needed.

    rel = 1e-12

    def test_single_value(self):
        # Deviation of a single value is zero.
        for x in (11, 19.8, 4.6e14, Fraction(21, 34), Decimal('8.392')):
            with self.subTest(x=x):
                self.assertEqual(self.func([x]), 0)

    def test_repeated_single_value(self):
        # The deviation of a single repeated value is zero.
        for x in (7.2, 49, 8.1e15, Fraction(3, 7), Decimal('62.4802')):
            for count in (2, 3, 5, 15):
                with self.subTest(x=x, count=count):
                    data = [x]*count
                    self.assertEqual(self.func(data), 0)

    def test_domain_error_regression(self):
        # Regression test for a domain error exception.
        # (Thanks to Geremy Condra.)
        data = [0.123456789012345]*10000
        # All the items are identical, so variance should be exactly zero.
        # We allow some small round-off error, but not much.
        result = self.func(data)
        self.assertApproxEqual(result, 0.0, tol=5e-17)
        self.assertGreaterEqual(result, 0)  # A negative result must fail.

    def test_shift_data(self):
        # Test that shifting the data by a constant amount does not affect
        # the variance or stdev. Or at least not much.

        # Due to rounding, this test should be considered an ideal. We allow
        # some tolerance away from "no change at all" by setting tol and/or rel
        # attributes. Subclasses may set tighter or looser error tolerances.
        raw = [1.03, 1.27, 1.94, 2.04, 2.58, 3.14, 4.75, 4.98, 5.42, 6.78]
        expected = self.func(raw)
        # Don't set shift too high, the bigger it is, the more rounding error.
        shift = 1e5
        data = [x + shift for x in raw]
        self.assertApproxEqual(self.func(data), expected)

    def test_shift_data_exact(self):
        # Like test_shift_data, but with integer data the result is exact.
        raw = [1, 3, 3, 4, 5, 7, 9, 10, 11, 16]
        assert all(x==int(x) for x in raw)
        expected = self.func(raw)
        shift = 10**9
        data = [x + shift for x in raw]
        self.assertEqual(self.func(data), expected)

    def test_iter_list_same(self):
        # Test that iter data and list data give the same result.

        # This is an explicit test that iterators and lists are treated the
        # same; justification for this test over and above the similar test
        # in UnivariateCommonMixin is that an earlier design had variance and
        # friends swap between one- and two-pass algorithms, which would
        # sometimes give different results.
        data = [random.uniform(-3, 8) for _ in range(1000)]
        expected = self.func(data)
        self.assertEqual(self.func(iter(data)), expected)
2056
2057
class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    """Tests for the population variance, statistics.pvariance."""

    def setUp(self):
        self.func = statistics.pvariance

    def test_exact_uniform(self):
        # The population variance of 0..N-1 is exactly (N**2 - 1)/12.
        size = 10000
        data = list(range(size))
        random.shuffle(data)
        self.assertEqual(self.func(data), (size**2 - 1)/12)

    def test_ints(self):
        # Mean 10; squared deviations 36+9+9+36 = 90; divided by n=4.
        self.assertEqual(self.func([4, 7, 13, 16]), 22.5)

    def test_fractions(self):
        # Fraction input gives an exact Fraction result.
        quarters = [Fraction(q, 4) for q in (1, 1, 3, 7)]
        variance = self.func(quarters)
        self.assertEqual(variance, Fraction(3, 8))
        self.assertIsInstance(variance, Fraction)

    def test_decimals(self):
        # Decimal input gives a Decimal result.
        data = list(map(Decimal, ["12.1", "12.2", "12.5", "12.9"]))
        variance = self.func(data)
        self.assertEqual(variance, Decimal('0.096875'))
        self.assertIsInstance(variance, Decimal)

    def test_accuracy_bug_20499(self):
        # Integer data whose variance is 2/9: the float result must equal
        # the float 2 / 9 exactly.
        variance = self.func([0, 0, 1])
        self.assertEqual(variance, 2 / 9)
        self.assertIsInstance(variance, float)
2100
2101
class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    """Tests for the sample variance, statistics.variance."""

    def setUp(self):
        self.func = statistics.variance

    def test_single_value(self):
        # Overrides the mixin: a single data point must raise StatisticsError.
        for value in (35, 24.7, 8.2e15, Fraction(19, 30), Decimal('4.2084')):
            with self.assertRaises(statistics.StatisticsError):
                self.func([value])

    def test_ints(self):
        # Squared deviations sum to 90; divided by n-1 = 3.
        self.assertEqual(self.func([4, 7, 13, 16]), 30)

    def test_fractions(self):
        # Fraction input gives an exact Fraction result.
        quarters = [Fraction(q, 4) for q in (1, 1, 3, 7)]
        variance = self.func(quarters)
        self.assertEqual(variance, Fraction(1, 2))
        self.assertIsInstance(variance, Fraction)

    def test_decimals(self):
        # Decimal input gives a Decimal result.
        variance = self.func([Decimal(d) for d in (2, 2, 7, 9)])
        self.assertEqual(variance, 4*Decimal('9.5')/Decimal(3))
        self.assertIsInstance(variance, Decimal)

    def test_center_not_at_mean(self):
        # Passing xbar overrides the computed mean used as the center.
        pair = (1.0, 2.0)
        self.assertEqual(self.func(pair), 0.5)
        self.assertEqual(self.func(pair, xbar=2.0), 1.0)

    def test_accuracy_bug_20499(self):
        # Integer data whose variance is 4/3: the float result must equal
        # the float 4 / 3 exactly.
        variance = self.func([0, 0, 2])
        self.assertEqual(variance, 4 / 3)
        self.assertIsInstance(variance, float)
2147
class TestPStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for population standard deviation.
    def setUp(self):
        self.func = statistics.pstdev

    def test_compare_to_variance(self):
        # Test that pstdev is, in fact, the square root of pvariance.
        # math.sqrt(pvariance(data)) rounds twice: once when the variance is
        # converted to a float, and again for the root. A root taken from the
        # exact variance can therefore differ in the last bit. Compare
        # approximately, as TestStdev.test_compare_to_variance does, rather
        # than demanding bit-for-bit equality on random data.
        data = [random.uniform(-17, 24) for _ in range(1000)]
        expected = math.sqrt(statistics.pvariance(data))
        self.assertAlmostEqual(self.func(data), expected)

    def test_center_not_at_mean(self):
        # See issue: 40855
        # Passing mu overrides the computed mean used as the center.
        data = (3, 6, 7, 10)
        self.assertEqual(self.func(data), 2.5)
        self.assertEqual(self.func(data, mu=0.5), 6.5)
2164
2165class TestSqrtHelpers(unittest.TestCase):
2166
2167    def test_integer_sqrt_of_frac_rto(self):
2168        for n, m in itertools.product(range(100), range(1, 1000)):
2169            r = statistics._integer_sqrt_of_frac_rto(n, m)
2170            self.assertIsInstance(r, int)
2171            if r*r*m == n:
2172                # Root is exact
2173                continue
2174            # Inexact, so the root should be odd
2175            self.assertEqual(r&1, 1)
2176            # Verify correct rounding
2177            self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)
2178
2179    @requires_IEEE_754
2180    def test_float_sqrt_of_frac(self):
2181
2182        def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
2183            if not x:
2184                return root == 0.0
2185
2186            # Extract adjacent representable floats
2187            r_up: float = math.nextafter(root, math.inf)
2188            r_down: float = math.nextafter(root, -math.inf)
2189            assert r_down < root < r_up
2190
2191            # Convert to fractions for exact arithmetic
2192            frac_root: Fraction = Fraction(root)
2193            half_way_up: Fraction = (frac_root + Fraction(r_up)) / 2
2194            half_way_down: Fraction = (frac_root + Fraction(r_down)) / 2
2195
2196            # Check a closed interval.
2197            # Does not test for a midpoint rounding rule.
2198            return half_way_down ** 2 <= x <= half_way_up ** 2
2199
2200        randrange = random.randrange
2201
2202        for i in range(60_000):
2203            numerator: int = randrange(10 ** randrange(50))
2204            denonimator: int = randrange(10 ** randrange(50)) + 1
2205            with self.subTest(numerator=numerator, denonimator=denonimator):
2206                x: Fraction = Fraction(numerator, denonimator)
2207                root: float = statistics._float_sqrt_of_frac(numerator, denonimator)
2208                self.assertTrue(is_root_correctly_rounded(x, root))
2209
2210        # Verify that corner cases and error handling match math.sqrt()
2211        self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0)
2212        with self.assertRaises(ValueError):
2213            statistics._float_sqrt_of_frac(-1, 1)
2214        with self.assertRaises(ValueError):
2215            statistics._float_sqrt_of_frac(1, -1)
2216
2217        # Error handling for zero denominator matches that for Fraction(1, 0)
2218        with self.assertRaises(ZeroDivisionError):
2219            statistics._float_sqrt_of_frac(1, 0)
2220
2221        # The result is well defined if both inputs are negative
2222        self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1))
2223
2224    def test_decimal_sqrt_of_frac(self):
2225        root: Decimal
2226        numerator: int
2227        denominator: int
2228
2229        for root, numerator, denominator in [
2230            (Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000),  # No adj
2231            (Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000),  # Adj up
2232            (Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000),  # Adj down
2233        ]:
2234            with decimal.localcontext(decimal.DefaultContext):
2235                self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, denominator), root)
2236
2237            # Confirm expected root with a quad precision decimal computation
2238            with decimal.localcontext(decimal.DefaultContext) as ctx:
2239                ctx.prec *= 4
2240                high_prec_ratio = Decimal(numerator) / Decimal(denominator)
2241                ctx.rounding = decimal.ROUND_05UP
2242                high_prec_root = high_prec_ratio.sqrt()
2243            with decimal.localcontext(decimal.DefaultContext):
2244                target_root = +high_prec_root
2245            self.assertEqual(root, target_root)
2246
2247        # Verify that corner cases and error handling match Decimal.sqrt()
2248        self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0)
2249        with self.assertRaises(decimal.InvalidOperation):
2250            statistics._decimal_sqrt_of_frac(-1, 1)
2251        with self.assertRaises(decimal.InvalidOperation):
2252            statistics._decimal_sqrt_of_frac(1, -1)
2253
2254        # Error handling for zero denominator matches that for Fraction(1, 0)
2255        with self.assertRaises(ZeroDivisionError):
2256            statistics._decimal_sqrt_of_frac(1, 0)
2257
2258        # The result is well defined if both inputs are negative
2259        self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1))
2260
2261
class TestStdev(VarianceStdevMixin, NumericTestCase):
    """Tests for the sample standard deviation, statistics.stdev."""

    def setUp(self):
        self.func = statistics.stdev

    def test_single_value(self):
        # Overrides the mixin: a single data point must raise StatisticsError.
        for value in (81, 203.74, 3.9e14, Fraction(5, 21), Decimal('35.719')):
            with self.assertRaises(statistics.StatisticsError):
                self.func([value])

    def test_compare_to_variance(self):
        # stdev(data) should agree with sqrt(variance(data)).
        sample = [random.uniform(-2, 9) for _ in range(1000)]
        root_of_variance = math.sqrt(statistics.variance(sample))
        self.assertAlmostEqual(self.func(sample), root_of_variance)

    def test_center_not_at_mean(self):
        # Passing xbar overrides the computed mean used as the center.
        self.assertEqual(self.func((1.0, 2.0), xbar=2.0), 1.0)
2281
class TestGeometricMean(unittest.TestCase):
    """Tests for statistics.geometric_mean()."""

    def test_basics(self):
        geometric_mean = statistics.geometric_mean
        self.assertAlmostEqual(geometric_mean([54, 24, 36]), 36.0)
        self.assertAlmostEqual(geometric_mean([4.0, 9.0]), 6.0)
        self.assertAlmostEqual(geometric_mean([17.625]), 17.625)

        # Draw the random datasets from a private, seeded generator.  This
        # produces the same values as seeding the module-level generator,
        # but leaves the global random state (shared with every other test
        # in this file) untouched.
        gen = random.Random(86753095551212)
        for data in [
                range(1, 100),
                range(1, 1_000),
                range(1, 10_000),
                range(500, 10_000, 3),
                range(10_000, 500, -3),
                [12, 17, 13, 5, 120, 7],
                [gen.expovariate(50.0) for _ in range(1_000)],
                [gen.lognormvariate(20.0, 3.0) for _ in range(2_000)],
                [gen.triangular(2000, 3000, 2200) for _ in range(3_000)],
            ]:
            # Reference value computed with Decimal arithmetic.
            gm_decimal = math.prod(map(Decimal, data)) ** (Decimal(1) / len(data))
            gm_float = geometric_mean(data)
            self.assertTrue(math.isclose(gm_float, float(gm_decimal)))

    def test_various_input_types(self):
        geometric_mean = statistics.geometric_mean
        D = Decimal
        F = Fraction
        # https://www.wolframalpha.com/input/?i=geometric+mean+3.5,+4.0,+5.25
        expected_mean = 4.18886
        for data, kind in [
            ([3.5, 4.0, 5.25], 'floats'),
            ([D('3.5'), D('4.0'), D('5.25')], 'decimals'),
            ([F(7, 2), F(4, 1), F(21, 4)], 'fractions'),
            ([3.5, 4, F(21, 4)], 'mixed types'),
            ((3.5, 4.0, 5.25), 'tuple'),
            (iter([3.5, 4.0, 5.25]), 'iterator'),
                ]:
            actual_mean = geometric_mean(data)
            # The result is always a float, whatever the input types.
            self.assertIs(type(actual_mean), float, kind)
            self.assertAlmostEqual(actual_mean, expected_mean, places=5)

    def test_big_and_small(self):
        geometric_mean = statistics.geometric_mean

        # Avoid overflow to infinity
        large = 2.0 ** 1000
        big_gm = geometric_mean([54.0 * large, 24.0 * large, 36.0 * large])
        self.assertTrue(math.isclose(big_gm, 36.0 * large))
        self.assertFalse(math.isinf(big_gm))

        # Avoid underflow to zero
        small = 2.0 ** -1000
        small_gm = geometric_mean([54.0 * small, 24.0 * small, 36.0 * small])
        self.assertTrue(math.isclose(small_gm, 36.0 * small))
        self.assertNotEqual(small_gm, 0.0)

    def test_error_cases(self):
        geometric_mean = statistics.geometric_mean
        StatisticsError = statistics.StatisticsError
        with self.assertRaises(StatisticsError):
            geometric_mean([])                      # empty input
        with self.assertRaises(StatisticsError):
            geometric_mean([3.5, 0.0, 5.25])        # zero input
        with self.assertRaises(StatisticsError):
            geometric_mean([3.5, -4.0, 5.25])       # negative input
        with self.assertRaises(StatisticsError):
            geometric_mean(iter([]))                # empty iterator
        with self.assertRaises(TypeError):
            geometric_mean(None)                    # non-iterable input
        with self.assertRaises(TypeError):
            geometric_mean([10, None, 20])          # non-numeric input
        with self.assertRaises(TypeError):
            geometric_mean()                        # missing data argument
        with self.assertRaises(TypeError):
            geometric_mean([10, 20, 60], 70)        # too many arguments

    def test_special_values(self):
        # Rules for special values are inherited from math.fsum()
        geometric_mean = statistics.geometric_mean
        NaN = float('Nan')
        Inf = float('Inf')
        self.assertTrue(math.isnan(geometric_mean([10, NaN])), 'nan')
        self.assertTrue(math.isnan(geometric_mean([NaN, Inf])), 'nan and infinity')
        self.assertTrue(math.isinf(geometric_mean([10, Inf])), 'infinity')
        with self.assertRaises(ValueError):
            geometric_mean([Inf, -Inf])

    def test_mixed_int_and_float(self):
        # Regression test for b.p.o. issue #28327
        geometric_mean = statistics.geometric_mean
        expected_mean = 3.80675409583932
        values = [
            [2, 3, 5, 7],
            [2, 3, 5, 7.0],
            [2, 3, 5.0, 7.0],
            [2, 3.0, 5.0, 7.0],
            [2.0, 3.0, 5.0, 7.0],
        ]
        for v in values:
            with self.subTest(v=v):
                actual_mean = geometric_mean(v)
                self.assertAlmostEqual(actual_mean, expected_mean, places=5)
2385
2386
class TestQuantiles(unittest.TestCase):
    """Tests for statistics.quantiles() with both interpolation methods."""

    def test_specific_cases(self):
        # Match results computed by hand and cross-checked
        # against the PERCENTILE.EXC function in MS Excel.
        quantiles = statistics.quantiles
        data = [120, 200, 250, 320, 350]
        random.shuffle(data)
        for n, expected in [
            (1, []),
            (2, [250.0]),
            (3, [200.0, 320.0]),
            (4, [160.0, 250.0, 335.0]),
            (5, [136.0, 220.0, 292.0, 344.0]),
            (6, [120.0, 200.0, 250.0, 320.0, 350.0]),
            (8, [100.0, 160.0, 212.5, 250.0, 302.5, 335.0, 357.5]),
            (10, [88.0, 136.0, 184.0, 220.0, 250.0, 292.0, 326.0, 344.0, 362.0]),
            (12, [80.0, 120.0, 160.0, 200.0, 225.0, 250.0, 285.0, 320.0, 335.0,
                  350.0, 365.0]),
            (15, [72.0, 104.0, 136.0, 168.0, 200.0, 220.0, 240.0, 264.0, 292.0,
                  320.0, 332.0, 344.0, 356.0, 368.0]),
                ]:
            self.assertEqual(expected, quantiles(data, n=n))
            self.assertEqual(len(quantiles(data, n=n)), n - 1)
            # Preserve datatype when possible
            for datatype in (float, Decimal, Fraction):
                result = quantiles(map(datatype, data), n=n)
                # all() must consume the generator: handing a bare
                # generator to assertTrue() always passes because
                # generator objects are truthy.
                self.assertTrue(all(type(x) is datatype for x in result))
                self.assertEqual(result, list(map(datatype, expected)))
            # Quantiles should be idempotent
            if len(expected) >= 2:
                self.assertEqual(quantiles(expected, n=n), expected)
            # Cross-check against method='inclusive' which should give
            # the same result after adding in minimum and maximum values
            # extrapolated from the two lowest and two highest points.
            sdata = sorted(data)
            lo = 2 * sdata[0] - sdata[1]
            hi = 2 * sdata[-1] - sdata[-2]
            padded_data = data + [lo, hi]
            self.assertEqual(
                quantiles(data, n=n),
                quantiles(padded_data, n=n, method='inclusive'),
                (n, data),
            )
            # Invariant under translation and scaling
            def f(x):
                return 3.5 * x - 1234.675
            exp = list(map(f, expected))
            act = quantiles(map(f, data), n=n)
            self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act)))
        # Q2 agrees with median()
        for k in range(2, 60):
            data = random.choices(range(100), k=k)
            q1, q2, q3 = quantiles(data)
            self.assertEqual(q2, statistics.median(data))

    def test_specific_cases_inclusive(self):
        # Match results computed by hand and cross-checked
        # against the PERCENTILE.INC function in MS Excel
        # and against the quantile() function in SciPy.
        quantiles = statistics.quantiles
        data = [100, 200, 400, 800]
        random.shuffle(data)
        for n, expected in [
            (1, []),
            (2, [300.0]),
            (3, [200.0, 400.0]),
            (4, [175.0, 300.0, 500.0]),
            (5, [160.0, 240.0, 360.0, 560.0]),
            (6, [150.0, 200.0, 300.0, 400.0, 600.0]),
            (8, [137.5, 175.0, 225.0, 300.0, 375.0, 500.0, 650.0]),
            (10, [130.0, 160.0, 190.0, 240.0, 300.0, 360.0, 440.0, 560.0, 680.0]),
            (12, [125.0, 150.0, 175.0, 200.0, 250.0, 300.0, 350.0, 400.0,
                  500.0, 600.0, 700.0]),
            (15, [120.0, 140.0, 160.0, 180.0, 200.0, 240.0, 280.0, 320.0, 360.0,
                  400.0, 480.0, 560.0, 640.0, 720.0]),
                ]:
            self.assertEqual(expected, quantiles(data, n=n, method="inclusive"))
            self.assertEqual(len(quantiles(data, n=n, method="inclusive")), n - 1)
            # Preserve datatype when possible
            for datatype in (float, Decimal, Fraction):
                result = quantiles(map(datatype, data), n=n, method="inclusive")
                # all() must consume the generator (see test_specific_cases).
                self.assertTrue(all(type(x) is datatype for x in result))
                self.assertEqual(result, list(map(datatype, expected)))
            # Invariant under translation and scaling
            def f(x):
                return 3.5 * x - 1234.675
            exp = list(map(f, expected))
            act = quantiles(map(f, data), n=n, method="inclusive")
            self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act)))
        # Natural deciles
        self.assertEqual(quantiles([0, 100], n=10, method='inclusive'),
                         [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0])
        self.assertEqual(quantiles(range(0, 101), n=10, method='inclusive'),
                         [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0])
        # Whenever n is smaller than the number of data points, running
        # method='inclusive' should give the same result as method='exclusive'
        # after the two included extreme points are removed.
        data = [random.randrange(10_000) for i in range(501)]
        actual = quantiles(data, n=32, method='inclusive')
        data.remove(min(data))
        data.remove(max(data))
        expected = quantiles(data, n=32)
        self.assertEqual(expected, actual)
        # Q2 agrees with median()
        for k in range(2, 60):
            data = random.choices(range(100), k=k)
            q1, q2, q3 = quantiles(data, method='inclusive')
            self.assertEqual(q2, statistics.median(data))

    def test_equal_inputs(self):
        # Constant data yields constant quartiles under either method.
        quantiles = statistics.quantiles
        for size in range(2, 10):
            data = [10.0] * size
            self.assertEqual(quantiles(data), [10.0, 10.0, 10.0])
            self.assertEqual(quantiles(data, method='inclusive'),
                             [10.0, 10.0, 10.0])

    def test_equal_sized_groups(self):
        quantiles = statistics.quantiles
        total = 10_000
        data = [random.expovariate(0.2) for i in range(total)]
        while len(set(data)) != total:
            data.append(random.expovariate(0.2))
        data.sort()

        # Cases where the group size exactly divides the total
        for n in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000):
            group_size = total // n
            self.assertEqual(
                [bisect.bisect(data, q) for q in quantiles(data, n=n)],
                list(range(group_size, total, group_size)))

        # When the group sizes can't be exactly equal, they should
        # differ by no more than one
        for n in (13, 19, 59, 109, 211, 571, 1019, 1907, 5261, 9769):
            group_sizes = {total // n, total // n + 1}
            pos = [bisect.bisect(data, q) for q in quantiles(data, n=n)]
            sizes = {q - p for p, q in zip(pos, pos[1:])}
            self.assertTrue(sizes <= group_sizes)

    def test_error_cases(self):
        quantiles = statistics.quantiles
        StatisticsError = statistics.StatisticsError
        with self.assertRaises(TypeError):
            quantiles()                         # Missing arguments
        with self.assertRaises(TypeError):
            quantiles([10, 20, 30], 13, n=4)    # Too many arguments
        with self.assertRaises(TypeError):
            quantiles([10, 20, 30], 4)          # n is a positional argument
        with self.assertRaises(StatisticsError):
            quantiles([10, 20, 30], n=0)        # n is zero
        with self.assertRaises(StatisticsError):
            quantiles([10, 20, 30], n=-1)       # n is negative
        with self.assertRaises(TypeError):
            quantiles([10, 20, 30], n=1.5)      # n is not an integer
        with self.assertRaises(ValueError):
            quantiles([10, 20, 30], method='X') # method is unknown
        with self.assertRaises(StatisticsError):
            quantiles([10], n=4)                # not enough data points
        with self.assertRaises(TypeError):
            quantiles([10, None, 30], n=4)      # data is non-numeric
2549
2550
2551class TestBivariateStatistics(unittest.TestCase):
2552
2553    def test_unequal_size_error(self):
2554        for x, y in [
2555            ([1, 2, 3], [1, 2]),
2556            ([1, 2], [1, 2, 3]),
2557        ]:
2558            with self.assertRaises(statistics.StatisticsError):
2559                statistics.covariance(x, y)
2560            with self.assertRaises(statistics.StatisticsError):
2561                statistics.correlation(x, y)
2562            with self.assertRaises(statistics.StatisticsError):
2563                statistics.linear_regression(x, y)
2564
2565    def test_small_sample_error(self):
2566        for x, y in [
2567            ([], []),
2568            ([], [1, 2,]),
2569            ([1, 2,], []),
2570            ([1,], [1,]),
2571            ([1,], [1, 2,]),
2572            ([1, 2,], [1,]),
2573        ]:
2574            with self.assertRaises(statistics.StatisticsError):
2575                statistics.covariance(x, y)
2576            with self.assertRaises(statistics.StatisticsError):
2577                statistics.correlation(x, y)
2578            with self.assertRaises(statistics.StatisticsError):
2579                statistics.linear_regression(x, y)
2580
2581
2582class TestCorrelationAndCovariance(unittest.TestCase):
2583
2584    def test_results(self):
2585        for x, y, result in [
2586            ([1, 2, 3], [1, 2, 3], 1),
2587            ([1, 2, 3], [-1, -2, -3], -1),
2588            ([1, 2, 3], [3, 2, 1], -1),
2589            ([1, 2, 3], [1, 2, 1], 0),
2590            ([1, 2, 3], [1, 3, 2], 0.5),
2591        ]:
2592            self.assertAlmostEqual(statistics.correlation(x, y), result)
2593            self.assertAlmostEqual(statistics.covariance(x, y), result)
2594
2595    def test_different_scales(self):
2596        x = [1, 2, 3]
2597        y = [10, 30, 20]
2598        self.assertAlmostEqual(statistics.correlation(x, y), 0.5)
2599        self.assertAlmostEqual(statistics.covariance(x, y), 5)
2600
2601        y = [.1, .2, .3]
2602        self.assertAlmostEqual(statistics.correlation(x, y), 1)
2603        self.assertAlmostEqual(statistics.covariance(x, y), 0.1)
2604
2605
2606class TestLinearRegression(unittest.TestCase):
2607
2608    def test_constant_input_error(self):
2609        x = [1, 1, 1,]
2610        y = [1, 2, 3,]
2611        with self.assertRaises(statistics.StatisticsError):
2612            statistics.linear_regression(x, y)
2613
2614    def test_results(self):
2615        for x, y, true_intercept, true_slope in [
2616            ([1, 2, 3], [0, 0, 0], 0, 0),
2617            ([1, 2, 3], [1, 2, 3], 0, 1),
2618            ([1, 2, 3], [100, 100, 100], 100, 0),
2619            ([1, 2, 3], [12, 14, 16], 10, 2),
2620            ([1, 2, 3], [-1, -2, -3], 0, -1),
2621            ([1, 2, 3], [21, 22, 23], 20, 1),
2622            ([1, 2, 3], [5.1, 5.2, 5.3], 5, 0.1),
2623        ]:
2624            slope, intercept = statistics.linear_regression(x, y)
2625            self.assertAlmostEqual(intercept, true_intercept)
2626            self.assertAlmostEqual(slope, true_slope)
2627
2628    def test_proportional(self):
2629        x = [10, 20, 30, 40]
2630        y = [180, 398, 610, 799]
2631        slope, intercept = statistics.linear_regression(x, y, proportional=True)
2632        self.assertAlmostEqual(slope, 20 + 1/150)
2633        self.assertEqual(intercept, 0.0)
2634
2635class TestNormalDist:
2636
    # General note on precision: The pdf(), cdf(), and overlap() methods
    # depend on functions in the math libraries that do not make
    # explicit accuracy guarantees.  Accordingly, some of the accuracy
    # tests below may fail if the underlying math functions are
    # inaccurate.  There isn't much we can do about this short of
    # writing our own implementations of those functions from scratch.
2643
2644    def test_slots(self):
2645        nd = self.module.NormalDist(300, 23)
2646        with self.assertRaises(TypeError):
2647            vars(nd)
2648        self.assertEqual(tuple(nd.__slots__), ('_mu', '_sigma'))
2649
2650    def test_instantiation_and_attributes(self):
2651        nd = self.module.NormalDist(500, 17)
2652        self.assertEqual(nd.mean, 500)
2653        self.assertEqual(nd.stdev, 17)
2654        self.assertEqual(nd.variance, 17**2)
2655
2656        # default arguments
2657        nd = self.module.NormalDist()
2658        self.assertEqual(nd.mean, 0)
2659        self.assertEqual(nd.stdev, 1)
2660        self.assertEqual(nd.variance, 1**2)
2661
2662        # error case: negative sigma
2663        with self.assertRaises(self.module.StatisticsError):
2664            self.module.NormalDist(500, -10)
2665
2666        # verify that subclass type is honored
2667        class NewNormalDist(self.module.NormalDist):
2668            pass
2669        nnd = NewNormalDist(200, 5)
2670        self.assertEqual(type(nnd), NewNormalDist)
2671
2672    def test_alternative_constructor(self):
2673        NormalDist = self.module.NormalDist
2674        data = [96, 107, 90, 92, 110]
2675        # list input
2676        self.assertEqual(NormalDist.from_samples(data), NormalDist(99, 9))
2677        # tuple input
2678        self.assertEqual(NormalDist.from_samples(tuple(data)), NormalDist(99, 9))
2679        # iterator input
2680        self.assertEqual(NormalDist.from_samples(iter(data)), NormalDist(99, 9))
2681        # error cases
2682        with self.assertRaises(self.module.StatisticsError):
2683            NormalDist.from_samples([])                      # empty input
2684        with self.assertRaises(self.module.StatisticsError):
2685            NormalDist.from_samples([10])                    # only one input
2686
2687        # verify that subclass type is honored
2688        class NewNormalDist(NormalDist):
2689            pass
2690        nnd = NewNormalDist.from_samples(data)
2691        self.assertEqual(type(nnd), NewNormalDist)
2692
2693    def test_sample_generation(self):
2694        NormalDist = self.module.NormalDist
2695        mu, sigma = 10_000, 3.0
2696        X = NormalDist(mu, sigma)
2697        n = 1_000
2698        data = X.samples(n)
2699        self.assertEqual(len(data), n)
2700        self.assertEqual(set(map(type, data)), {float})
2701        # mean(data) expected to fall within 8 standard deviations
2702        xbar = self.module.mean(data)
2703        self.assertTrue(mu - sigma*8 <= xbar <= mu + sigma*8)
2704
2705        # verify that seeding makes reproducible sequences
2706        n = 100
2707        data1 = X.samples(n, seed='happiness and joy')
2708        data2 = X.samples(n, seed='trouble and despair')
2709        data3 = X.samples(n, seed='happiness and joy')
2710        data4 = X.samples(n, seed='trouble and despair')
2711        self.assertEqual(data1, data3)
2712        self.assertEqual(data2, data4)
2713        self.assertNotEqual(data1, data2)
2714
2715    def test_pdf(self):
2716        NormalDist = self.module.NormalDist
2717        X = NormalDist(100, 15)
2718        # Verify peak around center
2719        self.assertLess(X.pdf(99), X.pdf(100))
2720        self.assertLess(X.pdf(101), X.pdf(100))
2721        # Test symmetry
2722        for i in range(50):
2723            self.assertAlmostEqual(X.pdf(100 - i), X.pdf(100 + i))
2724        # Test vs CDF
2725        dx = 2.0 ** -10
2726        for x in range(90, 111):
2727            est_pdf = (X.cdf(x + dx) - X.cdf(x)) / dx
2728            self.assertAlmostEqual(X.pdf(x), est_pdf, places=4)
2729        # Test vs table of known values -- CRC 26th Edition
2730        Z = NormalDist()
2731        for x, px in enumerate([
2732            0.3989, 0.3989, 0.3989, 0.3988, 0.3986,
2733            0.3984, 0.3982, 0.3980, 0.3977, 0.3973,
2734            0.3970, 0.3965, 0.3961, 0.3956, 0.3951,
2735            0.3945, 0.3939, 0.3932, 0.3925, 0.3918,
2736            0.3910, 0.3902, 0.3894, 0.3885, 0.3876,
2737            0.3867, 0.3857, 0.3847, 0.3836, 0.3825,
2738            0.3814, 0.3802, 0.3790, 0.3778, 0.3765,
2739            0.3752, 0.3739, 0.3725, 0.3712, 0.3697,
2740            0.3683, 0.3668, 0.3653, 0.3637, 0.3621,
2741            0.3605, 0.3589, 0.3572, 0.3555, 0.3538,
2742        ]):
2743            self.assertAlmostEqual(Z.pdf(x / 100.0), px, places=4)
2744            self.assertAlmostEqual(Z.pdf(-x / 100.0), px, places=4)
2745        # Error case: variance is zero
2746        Y = NormalDist(100, 0)
2747        with self.assertRaises(self.module.StatisticsError):
2748            Y.pdf(90)
2749        # Special values
2750        self.assertEqual(X.pdf(float('-Inf')), 0.0)
2751        self.assertEqual(X.pdf(float('Inf')), 0.0)
2752        self.assertTrue(math.isnan(X.pdf(float('NaN'))))
2753
2754    def test_cdf(self):
2755        NormalDist = self.module.NormalDist
2756        X = NormalDist(100, 15)
2757        cdfs = [X.cdf(x) for x in range(1, 200)]
2758        self.assertEqual(set(map(type, cdfs)), {float})
2759        # Verify montonic
2760        self.assertEqual(cdfs, sorted(cdfs))
2761        # Verify center (should be exact)
2762        self.assertEqual(X.cdf(100), 0.50)
2763        # Check against a table of known values
2764        # https://en.wikipedia.org/wiki/Standard_normal_table#Cumulative
2765        Z = NormalDist()
2766        for z, cum_prob in [
2767            (0.00, 0.50000), (0.01, 0.50399), (0.02, 0.50798),
2768            (0.14, 0.55567), (0.29, 0.61409), (0.33, 0.62930),
2769            (0.54, 0.70540), (0.60, 0.72575), (1.17, 0.87900),
2770            (1.60, 0.94520), (2.05, 0.97982), (2.89, 0.99807),
2771            (3.52, 0.99978), (3.98, 0.99997), (4.07, 0.99998),
2772            ]:
2773            self.assertAlmostEqual(Z.cdf(z), cum_prob, places=5)
2774            self.assertAlmostEqual(Z.cdf(-z), 1.0 - cum_prob, places=5)
2775        # Error case: variance is zero
2776        Y = NormalDist(100, 0)
2777        with self.assertRaises(self.module.StatisticsError):
2778            Y.cdf(90)
2779        # Special values
2780        self.assertEqual(X.cdf(float('-Inf')), 0.0)
2781        self.assertEqual(X.cdf(float('Inf')), 1.0)
2782        self.assertTrue(math.isnan(X.cdf(float('NaN'))))
2783
2784    @support.skip_if_pgo_task
2785    def test_inv_cdf(self):
2786        NormalDist = self.module.NormalDist
2787
2788        # Center case should be exact.
2789        iq = NormalDist(100, 15)
2790        self.assertEqual(iq.inv_cdf(0.50), iq.mean)
2791
2792        # Test versus a published table of known percentage points.
2793        # See the second table at the bottom of the page here:
2794        # http://people.bath.ac.uk/masss/tables/normaltable.pdf
2795        Z = NormalDist()
2796        pp = {5.0: (0.000, 1.645, 2.576, 3.291, 3.891,
2797                    4.417, 4.892, 5.327, 5.731, 6.109),
2798              2.5: (0.674, 1.960, 2.807, 3.481, 4.056,
2799                    4.565, 5.026, 5.451, 5.847, 6.219),
2800              1.0: (1.282, 2.326, 3.090, 3.719, 4.265,
2801                    4.753, 5.199, 5.612, 5.998, 6.361)}
2802        for base, row in pp.items():
2803            for exp, x in enumerate(row, start=1):
2804                p = base * 10.0 ** (-exp)
2805                self.assertAlmostEqual(-Z.inv_cdf(p), x, places=3)
2806                p = 1.0 - p
2807                self.assertAlmostEqual(Z.inv_cdf(p), x, places=3)
2808
2809        # Match published example for MS Excel
2810        # https://support.office.com/en-us/article/norm-inv-function-54b30935-fee7-493c-bedb-2278a9db7e13
2811        self.assertAlmostEqual(NormalDist(40, 1.5).inv_cdf(0.908789), 42.000002)
2812
2813        # One million equally spaced probabilities
2814        n = 2**20
2815        for p in range(1, n):
2816            p /= n
2817            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2818
2819        # One hundred ever smaller probabilities to test tails out to
2820        # extreme probabilities: 1 / 2**50 and (2**50-1) / 2 ** 50
2821        for e in range(1, 51):
2822            p = 2.0 ** (-e)
2823            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2824            p = 1.0 - p
2825            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2826
2827        # Now apply cdf() first.  Near the tails, the round-trip loses
2828        # precision and is ill-conditioned (small changes in the inputs
2829        # give large changes in the output), so only check to 5 places.
2830        for x in range(200):
2831            self.assertAlmostEqual(iq.inv_cdf(iq.cdf(x)), x, places=5)
2832
2833        # Error cases:
2834        with self.assertRaises(self.module.StatisticsError):
2835            iq.inv_cdf(0.0)                         # p is zero
2836        with self.assertRaises(self.module.StatisticsError):
2837            iq.inv_cdf(-0.1)                        # p under zero
2838        with self.assertRaises(self.module.StatisticsError):
2839            iq.inv_cdf(1.0)                         # p is one
2840        with self.assertRaises(self.module.StatisticsError):
2841            iq.inv_cdf(1.1)                         # p over one
2842        with self.assertRaises(self.module.StatisticsError):
2843            iq = NormalDist(100, 0)                 # sigma is zero
2844            iq.inv_cdf(0.5)
2845
2846        # Special values
2847        self.assertTrue(math.isnan(Z.inv_cdf(float('NaN'))))
2848
2849    def test_quantiles(self):
2850        # Quartiles of a standard normal distribution
2851        Z = self.module.NormalDist()
2852        for n, expected in [
2853            (1, []),
2854            (2, [0.0]),
2855            (3, [-0.4307, 0.4307]),
2856            (4 ,[-0.6745, 0.0, 0.6745]),
2857                ]:
2858            actual = Z.quantiles(n=n)
2859            self.assertTrue(all(math.isclose(e, a, abs_tol=0.0001)
2860                            for e, a in zip(expected, actual)))
2861
2862    def test_overlap(self):
2863        NormalDist = self.module.NormalDist
2864
2865        # Match examples from Imman and Bradley
2866        for X1, X2, published_result in [
2867                (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0), 0.80258),
2868                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0), 0.60993),
2869            ]:
2870            self.assertAlmostEqual(X1.overlap(X2), published_result, places=4)
2871            self.assertAlmostEqual(X2.overlap(X1), published_result, places=4)
2872
2873        # Check against integration of the PDF
2874        def overlap_numeric(X, Y, *, steps=8_192, z=5):
2875            'Numerical integration cross-check for overlap() '
2876            fsum = math.fsum
2877            center = (X.mean + Y.mean) / 2.0
2878            width = z * max(X.stdev, Y.stdev)
2879            start = center - width
2880            dx = 2.0 * width / steps
2881            x_arr = [start + i*dx for i in range(steps)]
2882            xp = list(map(X.pdf, x_arr))
2883            yp = list(map(Y.pdf, x_arr))
2884            total = max(fsum(xp), fsum(yp))
2885            return fsum(map(min, xp, yp)) / total
2886
2887        for X1, X2 in [
2888                # Examples from Imman and Bradley
2889                (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0)),
2890                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)),
2891                # Example from https://www.rasch.org/rmt/rmt101r.htm
2892                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)),
2893                # Gender heights from http://www.usablestats.com/lessons/normal
2894                (NormalDist(70, 4), NormalDist(65, 3.5)),
2895                # Misc cases with equal standard deviations
2896                (NormalDist(100, 15), NormalDist(110, 15)),
2897                (NormalDist(-100, 15), NormalDist(110, 15)),
2898                (NormalDist(-100, 15), NormalDist(-110, 15)),
2899                # Misc cases with unequal standard deviations
2900                (NormalDist(100, 12), NormalDist(100, 15)),
2901                (NormalDist(100, 12), NormalDist(110, 15)),
2902                (NormalDist(100, 12), NormalDist(150, 15)),
2903                (NormalDist(100, 12), NormalDist(150, 35)),
2904                # Misc cases with small values
2905                (NormalDist(1.000, 0.002), NormalDist(1.001, 0.003)),
2906                (NormalDist(1.000, 0.002), NormalDist(1.006, 0.0003)),
2907                (NormalDist(1.000, 0.002), NormalDist(1.001, 0.099)),
2908            ]:
2909            self.assertAlmostEqual(X1.overlap(X2), overlap_numeric(X1, X2), places=5)
2910            self.assertAlmostEqual(X2.overlap(X1), overlap_numeric(X1, X2), places=5)
2911
2912        # Error cases
2913        X = NormalDist()
2914        with self.assertRaises(TypeError):
2915            X.overlap()                             # too few arguments
2916        with self.assertRaises(TypeError):
2917            X.overlap(X, X)                         # too may arguments
2918        with self.assertRaises(TypeError):
2919            X.overlap(None)                         # right operand not a NormalDist
2920        with self.assertRaises(self.module.StatisticsError):
2921            X.overlap(NormalDist(1, 0))             # right operand sigma is zero
2922        with self.assertRaises(self.module.StatisticsError):
2923            NormalDist(1, 0).overlap(X)             # left operand sigma is zero
2924
2925    def test_zscore(self):
2926        NormalDist = self.module.NormalDist
2927        X = NormalDist(100, 15)
2928        self.assertEqual(X.zscore(142), 2.8)
2929        self.assertEqual(X.zscore(58), -2.8)
2930        self.assertEqual(X.zscore(100), 0.0)
2931        with self.assertRaises(TypeError):
2932            X.zscore()                              # too few arguments
2933        with self.assertRaises(TypeError):
2934            X.zscore(1, 1)                          # too may arguments
2935        with self.assertRaises(TypeError):
2936            X.zscore(None)                          # non-numeric type
2937        with self.assertRaises(self.module.StatisticsError):
2938            NormalDist(1, 0).zscore(100)            # sigma is zero
2939
2940    def test_properties(self):
2941        X = self.module.NormalDist(100, 15)
2942        self.assertEqual(X.mean, 100)
2943        self.assertEqual(X.median, 100)
2944        self.assertEqual(X.mode, 100)
2945        self.assertEqual(X.stdev, 15)
2946        self.assertEqual(X.variance, 225)
2947
2948    def test_same_type_addition_and_subtraction(self):
2949        NormalDist = self.module.NormalDist
2950        X = NormalDist(100, 12)
2951        Y = NormalDist(40, 5)
2952        self.assertEqual(X + Y, NormalDist(140, 13))        # __add__
2953        self.assertEqual(X - Y, NormalDist(60, 13))         # __sub__
2954
2955    def test_translation_and_scaling(self):
2956        NormalDist = self.module.NormalDist
2957        X = NormalDist(100, 15)
2958        y = 10
2959        self.assertEqual(+X, NormalDist(100, 15))           # __pos__
2960        self.assertEqual(-X, NormalDist(-100, 15))          # __neg__
2961        self.assertEqual(X + y, NormalDist(110, 15))        # __add__
2962        self.assertEqual(y + X, NormalDist(110, 15))        # __radd__
2963        self.assertEqual(X - y, NormalDist(90, 15))         # __sub__
2964        self.assertEqual(y - X, NormalDist(-90, 15))        # __rsub__
2965        self.assertEqual(X * y, NormalDist(1000, 150))      # __mul__
2966        self.assertEqual(y * X, NormalDist(1000, 150))      # __rmul__
2967        self.assertEqual(X / y, NormalDist(10, 1.5))        # __truediv__
2968        with self.assertRaises(TypeError):                  # __rtruediv__
2969            y / X
2970
2971    def test_unary_operations(self):
2972        NormalDist = self.module.NormalDist
2973        X = NormalDist(100, 12)
2974        Y = +X
2975        self.assertIsNot(X, Y)
2976        self.assertEqual(X.mean, Y.mean)
2977        self.assertEqual(X.stdev, Y.stdev)
2978        Y = -X
2979        self.assertIsNot(X, Y)
2980        self.assertEqual(X.mean, -Y.mean)
2981        self.assertEqual(X.stdev, Y.stdev)
2982
2983    def test_equality(self):
2984        NormalDist = self.module.NormalDist
2985        nd1 = NormalDist()
2986        nd2 = NormalDist(2, 4)
2987        nd3 = NormalDist()
2988        nd4 = NormalDist(2, 4)
2989        nd5 = NormalDist(2, 8)
2990        nd6 = NormalDist(8, 4)
2991        self.assertNotEqual(nd1, nd2)
2992        self.assertEqual(nd1, nd3)
2993        self.assertEqual(nd2, nd4)
2994        self.assertNotEqual(nd2, nd5)
2995        self.assertNotEqual(nd2, nd6)
2996
2997        # Test NotImplemented when types are different
2998        class A:
2999            def __eq__(self, other):
3000                return 10
3001        a = A()
3002        self.assertEqual(nd1.__eq__(a), NotImplemented)
3003        self.assertEqual(nd1 == a, 10)
3004        self.assertEqual(a == nd1, 10)
3005
3006        # All subclasses to compare equal giving the same behavior
3007        # as list, tuple, int, float, complex, str, dict, set, etc.
3008        class SizedNormalDist(NormalDist):
3009            def __init__(self, mu, sigma, n):
3010                super().__init__(mu, sigma)
3011                self.n = n
3012        s = SizedNormalDist(100, 15, 57)
3013        nd4 = NormalDist(100, 15)
3014        self.assertEqual(s, nd4)
3015
3016        # Don't allow duck type equality because we wouldn't
3017        # want a lognormal distribution to compare equal
3018        # to a normal distribution with the same parameters
3019        class LognormalDist:
3020            def __init__(self, mu, sigma):
3021                self.mu = mu
3022                self.sigma = sigma
3023        lnd = LognormalDist(100, 15)
3024        nd = NormalDist(100, 15)
3025        self.assertNotEqual(nd, lnd)
3026
3027    def test_pickle_and_copy(self):
3028        nd = self.module.NormalDist(37.5, 5.625)
3029        nd1 = copy.copy(nd)
3030        self.assertEqual(nd, nd1)
3031        nd2 = copy.deepcopy(nd)
3032        self.assertEqual(nd, nd2)
3033        nd3 = pickle.loads(pickle.dumps(nd))
3034        self.assertEqual(nd, nd3)
3035
3036    def test_hashability(self):
3037        ND = self.module.NormalDist
3038        s = {ND(100, 15), ND(100.0, 15.0), ND(100, 10), ND(95, 15), ND(100, 15)}
3039        self.assertEqual(len(s), 3)
3040
3041    def test_repr(self):
3042        nd = self.module.NormalDist(37.5, 5.625)
3043        self.assertEqual(repr(nd), 'NormalDist(mu=37.5, sigma=5.625)')
3044
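# sys.modules['statistics'] is swapped during these tests to avoid:
#   _pickle.PicklingError:
#   Can't pickle <class 'statistics.NormalDist'>:
#   it's not the same object as statistics.NormalDist
# pickle looks classes up by module name, so the module under test must be
# the object registered as 'statistics' while its tests run.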
class TestNormalDistPython(unittest.TestCase, TestNormalDist):
    """Run the shared TestNormalDist checks against ``py_statistics``.

    NOTE(review): ``py_statistics`` is defined earlier in this file,
    presumably the statistics module imported without its C accelerator.
    """

    module = py_statistics

    def setUp(self):
        # Register the module under test as 'statistics' so pickle resolves
        # NormalDist to this very class.  Remember the entry being replaced
        # so tearDown restores exactly that, not a hard-coded module.
        self._saved_statistics = sys.modules.get('statistics', statistics)
        sys.modules['statistics'] = self.module

    def tearDown(self):
        sys.modules['statistics'] = self._saved_statistics
3056
3057
@unittest.skipUnless(c_statistics, 'requires _statistics')
class TestNormalDistC(unittest.TestCase, TestNormalDist):
    """Run the shared TestNormalDist checks against ``c_statistics``.

    Skipped when ``c_statistics`` is falsy (the _statistics accelerator
    is unavailable).
    """

    module = c_statistics

    def setUp(self):
        # Register the module under test as 'statistics' so pickle resolves
        # NormalDist to this very class.  Remember the entry being replaced
        # so tearDown restores exactly that, not a hard-coded module.
        self._saved_statistics = sys.modules.get('statistics', statistics)
        sys.modules['statistics'] = self.module

    def tearDown(self):
        sys.modules['statistics'] = self._saved_statistics
3066
3067
3068# === Run tests ===
3069
def load_tests(loader, tests, ignore):
    """unittest ``load_tests`` protocol hook.

    Extend *tests* (the suite unittest already collected from this module)
    with the doctests found in this module's docstrings and return it.
    *loader* and *ignore* (the discovery pattern) are unused.
    """
    doctests = doctest.DocTestSuite()
    tests.addTests(doctests)
    return tests
3074
3075
if __name__ == "__main__":
    # Allow running this file directly as a script; unittest.main() loads
    # the tests from this module (honouring the load_tests hook above).
    unittest.main()
3078