1"""
2Basic statistics module.
3
4This module provides functions for calculating statistics of data, including
5averages, variance, and standard deviation.
6
7Calculating averages
8--------------------
9
10==================  ==================================================
11Function            Description
12==================  ==================================================
13mean                Arithmetic mean (average) of data.
14fmean               Fast, floating point arithmetic mean.
15geometric_mean      Geometric mean of data.
16harmonic_mean       Harmonic mean of data.
17median              Median (middle value) of data.
18median_low          Low median of data.
19median_high         High median of data.
20median_grouped      Median, or 50th percentile, of grouped data.
21mode                Mode (most common value) of data.
22multimode           List of modes (most common values of data).
23quantiles           Divide data into intervals with equal probability.
24==================  ==================================================
25
26Calculate the arithmetic mean ("the average") of data:
27
28>>> mean([-1.0, 2.5, 3.25, 5.75])
292.625
30
31
32Calculate the standard median of discrete data:
33
34>>> median([2, 3, 4, 5])
353.5
36
37
38Calculate the median, or 50th percentile, of data grouped into class intervals
39centred on the data values provided. E.g. if your data points are rounded to
40the nearest whole number:
41
42>>> median_grouped([2, 2, 3, 3, 3, 4])  #doctest: +ELLIPSIS
432.8333333333...
44
45This should be interpreted in this way: you have two data points in the class
46interval 1.5-2.5, three data points in the class interval 2.5-3.5, and one in
47the class interval 3.5-4.5. The median of these data points is 2.8333...
48
49
50Calculating variability or spread
51---------------------------------
52
53==================  =============================================
54Function            Description
55==================  =============================================
56pvariance           Population variance of data.
57variance            Sample variance of data.
58pstdev              Population standard deviation of data.
59stdev               Sample standard deviation of data.
60==================  =============================================
61
62Calculate the standard deviation of sample data:
63
64>>> stdev([2.5, 3.25, 5.5, 11.25, 11.75])  #doctest: +ELLIPSIS
654.38961843444...
66
67If you have previously calculated the mean, you can pass it as the optional
68second argument to the four "spread" functions to avoid recalculating it:
69
70>>> data = [1, 2, 2, 4, 4, 4, 5, 6]
71>>> mu = mean(data)
72>>> pvariance(data, mu)
732.5
74
75
76Statistics for relations between two inputs
77-------------------------------------------
78
79==================  ====================================================
80Function            Description
81==================  ====================================================
82covariance          Sample covariance for two variables.
83correlation         Pearson's correlation coefficient for two variables.
84linear_regression   Intercept and slope for simple linear regression.
85==================  ====================================================
86
87Calculate covariance, Pearson's correlation, and simple linear regression
88for two inputs:
89
90>>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
91>>> y = [1, 2, 3, 1, 2, 3, 1, 2, 3]
92>>> covariance(x, y)
930.75
94>>> correlation(x, y)  #doctest: +ELLIPSIS
950.31622776601...
>>> linear_regression(x, y)
97LinearRegression(slope=0.1, intercept=1.5)
98
99
100Exceptions
101----------
102
103A single exception is defined: StatisticsError is a subclass of ValueError.
104
105"""
106
# Public API re-exported by "from statistics import *"; kept sorted.
__all__ = [
    'NormalDist',
    'StatisticsError',
    'correlation',
    'covariance',
    'fmean',
    'geometric_mean',
    'harmonic_mean',
    'linear_regression',
    'mean',
    'median',
    'median_grouped',
    'median_high',
    'median_low',
    'mode',
    'multimode',
    'pstdev',
    'pvariance',
    'quantiles',
    'stdev',
    'variance',
]
129
130import math
131import numbers
132import random
133import sys
134
135from fractions import Fraction
136from decimal import Decimal
137from itertools import groupby, repeat
138from bisect import bisect_left, bisect_right
139from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
140from operator import mul
141from collections import Counter, namedtuple
142
_SQRT2 = sqrt(2.0)  # Square root of two, computed once at import time.
144
145# === Exceptions ===
146
class StatisticsError(ValueError):
    """The single exception raised by this module's functions.

    Subclasses ValueError so callers catching ValueError keep working.
    """
    pass
149
150
151# === Private utilities ===
152
def _sum(data):
    """_sum(data) -> (type, sum, count)

    Return a high-precision sum of the given numeric data as a fraction,
    together with the type to be converted to and the count of items.

    Examples
    --------

    >>> _sum([3, 2.25, 4.5, -0.5, 0.25])
    (<class 'float'>, Fraction(19, 2), 5)

    Some sources of round-off error will be avoided:

    # Built-in sum returns zero.
    >>> _sum([1e50, 1, -1e50] * 1000)
    (<class 'float'>, Fraction(1000, 1), 3000)

    Fractions and Decimals are also supported:

    >>> from fractions import Fraction as F
    >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)])
    (<class 'fractions.Fraction'>, Fraction(63, 20), 4)

    >>> from decimal import Decimal as D
    >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")]
    >>> _sum(data)
    (<class 'decimal.Decimal'>, Fraction(6963, 10000), 4)

    Mixed types are currently treated as an error, except that int is
    allowed.
    """
    count = 0
    # Map each distinct denominator to the running sum of numerators seen
    # with it, so expensive Fraction arithmetic happens once per distinct
    # denominator rather than once per item.
    partials = {}
    partials_get = partials.get  # Bind once: avoids an attribute lookup per item.
    T = int
    # groupby(data, type) batches consecutive same-typed items so the type
    # coercion runs once per run of values instead of once per value.
    for typ, values in groupby(data, type):
        T = _coerce(T, typ)  # or raise TypeError
        for n, d in map(_exact_ratio, values):
            count += 1
            partials[d] = partials_get(d, 0) + n
    if None in partials:
        # The sum will be a NAN or INF. We can ignore all the finite
        # partials, and just look at this special one.
        total = partials[None]
        assert not _isfinite(total)
    else:
        # Sum all the partial sums using builtin sum.
        total = sum(Fraction(n, d) for d, n in partials.items())
    return (T, total, count)
203
204
205def _isfinite(x):
206    try:
207        return x.is_finite()  # Likely a Decimal.
208    except AttributeError:
209        return math.isfinite(x)  # Coerces to float first.
210
211
212def _coerce(T, S):
213    """Coerce types T and S to a common type, or raise TypeError.
214
215    Coercion rules are currently an implementation detail. See the CoerceTest
216    test class in test_statistics for details.
217    """
218    # See http://bugs.python.org/issue24068.
219    assert T is not bool, "initial type T is bool"
220    # If the types are the same, no need to coerce anything. Put this
221    # first, so that the usual case (no coercion needed) happens as soon
222    # as possible.
223    if T is S:  return T
224    # Mixed int & other coerce to the other type.
225    if S is int or S is bool:  return T
226    if T is int:  return S
227    # If one is a (strict) subclass of the other, coerce to the subclass.
228    if issubclass(S, T):  return S
229    if issubclass(T, S):  return T
230    # Ints coerce to the other type.
231    if issubclass(T, int):  return S
232    if issubclass(S, int):  return T
233    # Mixed fraction & float coerces to float (or float subclass).
234    if issubclass(T, Fraction) and issubclass(S, float):
235        return S
236    if issubclass(T, float) and issubclass(S, Fraction):
237        return T
238    # Any other combination is disallowed.
239    msg = "don't know how to coerce %s and %s"
240    raise TypeError(msg % (T.__name__, S.__name__))
241
242
def _exact_ratio(x):
    """Return Real number x to exact (numerator, denominator) pair.

    >>> _exact_ratio(0.25)
    (1, 4)

    x is expected to be an int, Fraction, Decimal or float.

    For a non-finite float or Decimal (NAN or INF) the pair ``(x, None)``
    is returned; callers treat the ``None`` denominator as a sentinel.
    """

    # XXX We should revisit whether using fractions to accumulate exact
    # ratios is the right way to go.

    # The integer ratios for binary floats can have numerators or
    # denominators with over 300 decimal digits.  The problem is more
    # acute with decimal floats where the default decimal context
    # supports a huge range of exponents from Emin=-999999 to
    # Emax=999999.  When expanded with as_integer_ratio(), numbers like
    # Decimal('3.14E+5000') and Decimal('3.14E-5000') have large
    # numerators or denominators that will slow computation.

    # When the integer ratios are accumulated as fractions, the size
    # grows to cover the full range from the smallest magnitude to the
    # largest.  For example, Fraction(3.14E+300) + Fraction(3.14E-300),
    # has a 616 digit numerator.  Likewise,
    # Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000'))
    # has 10,003 digit numerator.

    # This doesn't seem to have been problem in practice, but it is a
    # potential pitfall.

    try:
        return x.as_integer_ratio()
    except AttributeError:
        pass
    except (OverflowError, ValueError):
        # float NAN or INF.
        assert not _isfinite(x)
        return (x, None)
    try:
        # x may be an Integral ABC.
        return (x.numerator, x.denominator)
    except AttributeError:
        msg = f"can't convert type '{type(x).__name__}' to numerator/denominator"
        raise TypeError(msg)
287
288
289def _convert(value, T):
290    """Convert value to given numeric type T."""
291    if type(value) is T:
292        # This covers the cases where T is Fraction, or where value is
293        # a NAN or INF (Decimal or float).
294        return value
295    if issubclass(T, int) and value.denominator != 1:
296        T = float
297    try:
298        # FIXME: what do we do if this overflows?
299        return T(value)
300    except TypeError:
301        if issubclass(T, Decimal):
302            return T(value.numerator) / T(value.denominator)
303        else:
304            raise
305
306
307def _find_lteq(a, x):
308    'Locate the leftmost value exactly equal to x'
309    i = bisect_left(a, x)
310    if i != len(a) and a[i] == x:
311        return i
312    raise ValueError
313
314
315def _find_rteq(a, l, x):
316    'Locate the rightmost value exactly equal to x'
317    i = bisect_right(a, x, lo=l)
318    if i != (len(a) + 1) and a[i - 1] == x:
319        return i - 1
320    raise ValueError
321
322
323def _fail_neg(values, errmsg='negative value'):
324    """Iterate over values, failing if any are less than zero."""
325    for x in values:
326        if x < 0:
327            raise StatisticsError(errmsg)
328        yield x
329
330
331def _integer_sqrt_of_frac_rto(n: int, m: int) -> int:
332    """Square root of n/m, rounded to the nearest integer using round-to-odd."""
333    # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
334    a = math.isqrt(n // m)
335    return a | (a*a*m != n)
336
337
# For 53 bit precision floats, the bit width used in
# _float_sqrt_of_frac() is 109.
# (See the proof sketch linked in _float_sqrt_of_frac() for why
# 2 * mant_dig + 3 bits suffice for correct rounding.)
_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3
341
342
def _float_sqrt_of_frac(n: int, m: int) -> float:
    """Square root of n/m as a float, correctly rounded."""
    # See principle and proof sketch at: https://bugs.python.org/msg407078
    # Shift the fraction by an even power of two (2*q) so the integer
    # square root below carries about _sqrt_bit_width bits; round-to-odd
    # then lets the final float division perform the correct rounding.
    q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2
    if q >= 0:
        # Large fraction: scale the denominator up, shift the root back.
        numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q
        denominator = 1
    else:
        # Small fraction: scale the numerator up, divide the root back down.
        numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m)
        denominator = 1 << -q
    return numerator / denominator   # Convert to float
354
355
def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal:
    """Square root of n/m as a Decimal, correctly rounded."""
    # Premise:  For decimal, computing (n/m).sqrt() can be off
    #           by 1 ulp from the correctly rounded result.
    # Method:   Check the result, moving up or down a step if needed.
    if n <= 0:
        if not n:
            return Decimal('0.0')
        # Negate both so the fraction n/m is unchanged but n is positive
        # for the integer comparisons below.
        n, m = -n, -m

    root = (Decimal(n) / Decimal(m)).sqrt()
    nr, dr = root.as_integer_ratio()

    plus = root.next_plus()
    np, dp = plus.as_integer_ratio()
    # If the true value lies above the midpoint of root and its successor,
    # the successor is the correctly rounded result.
    # test: n / m > ((root + plus) / 2) ** 2
    if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2:
        return plus

    minus = root.next_minus()
    nm, dm = minus.as_integer_ratio()
    # Likewise below the midpoint of root and its predecessor.
    # test: n / m < ((root + minus) / 2) ** 2
    if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2:
        return minus

    return root
382
383
384# === Measures of central tendency (averages) ===
385
def mean(data):
    """Return the sample arithmetic mean of data.

    >>> mean([1, 2, 3, 4, 4])
    2.8

    >>> from fractions import Fraction as F
    >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
    Fraction(13, 21)

    >>> from decimal import Decimal as D
    >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
    Decimal('0.5625')

    If ``data`` is empty, StatisticsError will be raised.
    """
    # Materialize one-shot iterators so the data can be both sized and summed.
    if iter(data) is data:
        data = list(data)
    if len(data) < 1:
        raise StatisticsError('mean requires at least one data point')
    T, total, count = _sum(data)
    assert count == len(data)
    return _convert(total / count, T)
410
411
def fmean(data, weights=None):
    """Convert data to floats and compute the arithmetic mean.

    This runs faster than the mean() function and it always returns a float.
    If *weights* is given, a weighted average is computed instead.
    If the input dataset is empty, it raises a StatisticsError.

    >>> fmean([3.5, 4.0, 5.25])
    4.25
    """
    try:
        n = len(data)
    except TypeError:
        # Iterators have no len(); tally the item count while iterating.
        n = 0
        def counted(iterable):
            nonlocal n
            for item in iterable:
                n += 1
                yield item
        data = counted(data)
    if weights is None:
        total = fsum(data)
        if not n:
            raise StatisticsError('fmean requires at least one data point')
        return total / n
    try:
        num_weights = len(weights)
    except TypeError:
        weights = list(weights)
        num_weights = len(weights)
    # Consume data first (fixing n if data was an iterator), then check sizes.
    num = fsum(map(mul, data, weights))
    if n != num_weights:
        raise StatisticsError('data and weights must be the same length')
    den = fsum(weights)
    if not den:
        raise StatisticsError('sum of weights must be non-zero')
    return num / den
448
449
def geometric_mean(data):
    """Convert data to floats and compute the geometric mean.

    Raises a StatisticsError if the input dataset is empty,
    if it contains a zero, or if it contains a negative value.

    No special efforts are made to achieve exact results.
    (However, this may change in the future.)

    >>> round(geometric_mean([54, 24, 36]), 9)
    36.0
    """
    # The geometric mean is exp of the arithmetic mean of the logarithms;
    # log() raises ValueError for zero or negative inputs.
    try:
        return exp(fmean(log(x) for x in data))
    except ValueError:
        raise StatisticsError('geometric mean requires a non-empty dataset '
                              'containing positive numbers') from None
467
468
def harmonic_mean(data, weights=None):
    """Return the harmonic mean of data.

    The harmonic mean is the reciprocal of the arithmetic mean of the
    reciprocals of the data.  It can be used for averaging ratios or
    rates, for example speeds.

    Suppose a car travels 40 km/hr for 5 km and then speeds-up to
    60 km/hr for another 5 km. What is the average speed?

        >>> harmonic_mean([40, 60])
        48.0

    Suppose a car travels 40 km/hr for 5 km, and when traffic clears,
    speeds-up to 60 km/hr for the remaining 30 km of the journey. What
    is the average speed?

        >>> harmonic_mean([40, 60], weights=[5, 30])
        56.0

    If ``data`` is empty, or any element is less than zero,
    ``harmonic_mean`` will raise ``StatisticsError``.
    """
    # Materialize one-shot iterators so the data can be sized and indexed.
    if iter(data) is data:
        data = list(data)
    errmsg = 'harmonic mean does not support negative values'
    n = len(data)
    if n < 1:
        raise StatisticsError('harmonic_mean requires at least one data point')
    elif n == 1 and weights is None:
        # Fast path: the unweighted harmonic mean of one value is that value.
        x = data[0]
        if isinstance(x, (numbers.Real, Decimal)):
            if x < 0:
                raise StatisticsError(errmsg)
            return x
        else:
            raise TypeError('unsupported type')
    if weights is None:
        # Unweighted case: every data point gets weight 1.
        weights = repeat(1, n)
        sum_weights = n
    else:
        if iter(weights) is weights:
            weights = list(weights)
        if len(weights) != n:
            raise StatisticsError('Number of weights does not match data size')
        _, sum_weights, _ = _sum(w for w in _fail_neg(weights, errmsg))
    try:
        data = _fail_neg(data, errmsg)
        # A zero weight contributes 0 and skips the division, so a
        # zero-weighted zero data point does not raise.
        T, total, count = _sum(w / x if w else 0 for w, x in zip(weights, data))
    except ZeroDivisionError:
        # A zero data point with a non-zero weight makes the harmonic mean 0.
        return 0
    if total <= 0:
        raise StatisticsError('Weighted sum must be positive')
    return _convert(sum_weights / total, T)
523
524# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
    """Return the median (middle value) of numeric data.

    When the number of data points is odd, return the middle data point.
    When the number of data points is even, the median is interpolated by
    taking the average of the two middle values:

    >>> median([1, 3, 5])
    3
    >>> median([1, 3, 5, 7])
    4.0

    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    mid = n // 2
    if n % 2:
        return data[mid]
    return (data[mid - 1] + data[mid]) / 2
547
548
def median_low(data):
    """Return the low median of numeric data.

    When the number of data points is odd, the middle value is returned.
    When it is even, the smaller of the two middle values is returned.

    >>> median_low([1, 3, 5])
    3
    >>> median_low([1, 3, 5, 7])
    3

    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    # (n - 1) // 2 is the middle index for odd n and the lower of the two
    # middle indices for even n.
    return data[(n - 1) // 2]
569
570
def median_high(data):
    """Return the high median of data.

    When the number of data points is odd, the middle value is returned.
    When it is even, the larger of the two middle values is returned.

    >>> median_high([1, 3, 5])
    3
    >>> median_high([1, 3, 5, 7])
    5

    """
    data = sorted(data)
    if not data:
        raise StatisticsError("no median for empty data")
    # len(data) // 2 is the middle index for odd sizes and the upper of
    # the two middle indices for even sizes.
    return data[len(data) // 2]
588
589
def median_grouped(data, interval=1):
    """Return the 50th percentile (median) of grouped continuous data.

    >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
    3.7
    >>> median_grouped([52, 52, 53, 54])
    52.5

    This calculates the median as the 50th percentile, and should be
    used when your data is continuous and grouped. In the above example,
    the values 1, 2, 3, etc. actually represent the midpoint of classes
    0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in
    class 3.5-4.5, and interpolation is used to estimate it.

    Optional argument ``interval`` represents the class interval, and
    defaults to 1. Changing the class interval naturally will change the
    interpolated 50th percentile value:

    >>> median_grouped([1, 3, 3, 5, 7], interval=1)
    3.25
    >>> median_grouped([1, 3, 3, 5, 7], interval=2)
    3.5

    This function does not check whether the data points are at least
    ``interval`` apart.
    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    elif n == 1:
        return data[0]
    # Find the value at the midpoint. Remember this corresponds to the
    # centre of the class interval.
    x = data[n // 2]
    # Reject str/bytes early: they would otherwise "work" by accident
    # with the arithmetic below for some inputs.
    for obj in (x, interval):
        if isinstance(obj, (str, bytes)):
            raise TypeError('expected number but got %r' % obj)
    try:
        L = x - interval / 2  # The lower limit of the median interval.
    except TypeError:
        # Mixed type. For now we just coerce to float.
        L = float(x) - float(interval) / 2

    # Uses bisection search to search for x in data with log(n) time complexity
    # Find the position of leftmost occurrence of x in data
    l1 = _find_lteq(data, x)
    # Find the position of rightmost occurrence of x in data[l1...len(data)]
    # Assuming always l1 <= l2
    l2 = _find_rteq(data, l1, x)
    cf = l1  # Cumulative frequency: count of values below the median interval.
    f = l2 - l1 + 1  # Frequency: number of data points inside the median interval.
    return L + interval * (n / 2 - cf) / f
643
644
def mode(data):
    """Return the most common data point from discrete or nominal data.

    ``mode`` assumes discrete data, and returns a single value. This is the
    standard treatment of the mode as commonly taught in schools:

        >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
        3

    This also works with nominal (non-numeric) data:

        >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
        'red'

    If there are multiple modes with the same frequency, the first one
    encountered in the data is returned:

        >>> mode(['red', 'red', 'green', 'blue', 'blue'])
        'red'

    If *data* is empty, ``mode`` raises StatisticsError.

    """
    counts = Counter(iter(data))
    if not counts:
        raise StatisticsError('no mode for empty data')
    # Counter preserves first-seen order, so ties resolve to the value
    # that appeared earliest in the data.
    return counts.most_common(1)[0][0]
673
674
def multimode(data):
    """Return a list of the most frequently occurring values.

    Will return more than one result if there are multiple modes
    or an empty list if *data* is empty.

    >>> multimode('aabbbbbbbbcc')
    ['b']
    >>> multimode('aabbbbccddddeeffffgg')
    ['b', 'd', 'f']
    >>> multimode('')
    []
    """
    # most_common() sorts by decreasing count but is stable for equal
    # counts, so the filtered result keeps first-seen input order.
    pairs = Counter(iter(data)).most_common()
    if not pairs:
        return []
    top_count = pairs[0][1]
    return [value for value, count in pairs if count == top_count]
693
694
695# Notes on methods for computing quantiles
696# ----------------------------------------
697#
698# There is no one perfect way to compute quantiles.  Here we offer
699# two methods that serve common needs.  Most other packages
700# surveyed offered at least one or both of these two, making them
701# "standard" in the sense of "widely-adopted and reproducible".
702# They are also easy to explain, easy to compute manually, and have
703# straight-forward interpretations that aren't surprising.
704
705# The default method is known as "R6", "PERCENTILE.EXC", or "expected
706# value of rank order statistics". The alternative method is known as
707# "R7", "PERCENTILE.INC", or "mode of rank order statistics".
708
709# For sample data where there is a positive probability for values
710# beyond the range of the data, the R6 exclusive method is a
711# reasonable choice.  Consider a random sample of nine values from a
712# population with a uniform distribution from 0.0 to 1.0.  The
713# distribution of the third ranked sample point is described by
714# betavariate(alpha=3, beta=7) which has mode=0.250, median=0.286, and
715# mean=0.300.  Only the latter (which corresponds with R6) gives the
716# desired cut point with 30% of the population falling below that
717# value, making it comparable to a result from an inv_cdf() function.
718# The R6 exclusive method is also idempotent.
719
720# For describing population data where the end points are known to
721# be included in the data, the R7 inclusive method is a reasonable
722# choice.  Instead of the mean, it uses the mode of the beta
723# distribution for the interior points.  Per Hyndman & Fan, "One nice
724# property is that the vertices of Q7(p) divide the range into n - 1
725# intervals, and exactly 100p% of the intervals lie to the left of
726# Q7(p) and 100(1 - p)% of the intervals lie to the right of Q7(p)."
727
728# If needed, other methods could be added.  However, for now, the
729# position is that fewer options make for easier choices and that
730# external packages can be used for anything more advanced.
731
def quantiles(data, *, n=4, method='exclusive'):
    """Divide *data* into *n* continuous intervals with equal probability.

    Returns a list of (n - 1) cut points separating the intervals.

    Set *n* to 4 for quartiles (the default).  Set *n* to 10 for deciles.
    Set *n* to 100 for percentiles which gives the 99 cuts points that
    separate *data* in to 100 equal sized groups.

    The *data* can be any iterable containing sample.
    The cut points are linearly interpolated between data points.

    If *method* is set to *inclusive*, *data* is treated as population
    data.  The minimum value is treated as the 0th percentile and the
    maximum value is treated as the 100th percentile.
    """
    if n < 1:
        raise StatisticsError('n must be at least 1')
    data = sorted(data)
    ld = len(data)
    if ld < 2:
        raise StatisticsError('must have at least two data points')
    if method == 'inclusive':
        # R7 method: the data's min and max are the 0th and 100th percentiles.
        m = ld - 1
        def cut_point(i):
            j, delta = divmod(i * m, n)
            return (data[j] * (n - delta) + data[j + 1] * delta) / n
        return [cut_point(i) for i in range(1, n)]
    if method == 'exclusive':
        # R6 method: allows for values beyond the observed sample range.
        m = ld + 1
        def cut_point(i):
            j = min(max(i * m // n, 1), ld - 1)  # rescale i to m/n, clamped to 1 .. ld-1
            delta = i * m - j * n                # exact integer math
            return (data[j - 1] * (n - delta) + data[j] * delta) / n
        return [cut_point(i) for i in range(1, n)]
    raise ValueError(f'Unknown method: {method!r}')
773
774
775# === Measures of spread ===
776
777# See http://mathworld.wolfram.com/Variance.html
778#     http://mathworld.wolfram.com/SampleVariance.html
779#     http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
780#
781# Under no circumstances use the so-called "computational formula for
782# variance", as that is only suitable for hand calculations with a small
783# amount of low-precision data. It has terrible numeric properties.
784#
785# See a comparison of three computational methods here:
786# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
787
def _ss(data, c=None):
    """Return sum of square deviations of sequence data.

    Returns a pair (T, total): the coerced result type and the sum of
    squared deviations as an exact value.

    If ``c`` is None, the mean is calculated in one pass, and the deviations
    from the mean are calculated in a second pass. Otherwise, deviations are
    calculated from ``c`` as given. Use the second case with care, as it can
    lead to garbage results.
    """
    if c is not None:
        # Deviations from a caller-supplied center: single exact pass.
        T, total, count = _sum((d := x - c) * d for x in data)
        return (T, total)
    T, total, count = _sum(data)
    # Exact mean as an integer ratio; everything below works on integer
    # numerator/denominator pairs so no rounding error is introduced.
    mean_n, mean_d = (total / count).as_integer_ratio()
    partials = Counter()
    for n, d in map(_exact_ratio, data):
        # (x - mean) as the exact ratio diff_n / diff_d.
        diff_n = n * mean_d - d * mean_n
        diff_d = d * mean_d
        # Accumulate squared numerators keyed by squared denominator.
        partials[diff_d * diff_d] += diff_n * diff_n
    if None in partials:
        # The sum will be a NAN or INF. We can ignore all the finite
        # partials, and just look at this special one.
        total = partials[None]
        assert not _isfinite(total)
    else:
        total = sum(Fraction(n, d) for d, n in partials.items())
    return (T, total)
814
815
def variance(data, xbar=None):
    """Return the sample variance of data.

    data should be an iterable of Real-valued numbers, with at least two
    values. The optional argument xbar, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function when your data is a sample from a population. To
    calculate the variance from the entire population, see ``pvariance``.

    Examples:

    >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
    >>> variance(data)
    1.3720238095238095

    If you have already calculated the mean of your data, you can pass it as
    the optional second argument ``xbar`` to avoid recalculating it:

    >>> m = mean(data)
    >>> variance(data, m)
    1.3720238095238095

    This function does not check that ``xbar`` is actually the mean of
    ``data``. Giving arbitrary values for ``xbar`` may lead to invalid or
    impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('31.01875')

    >>> from fractions import Fraction as F
    >>> variance([F(1, 6), F(1, 2), F(5, 3)])
    Fraction(67, 108)

    """
    # One-shot iterators must be materialized before they can be sized.
    if iter(data) is data:
        data = list(data)
    num_points = len(data)
    if num_points < 2:
        raise StatisticsError('variance requires at least two data points')
    T, ss = _ss(data, xbar)
    # Divide by n - 1 (Bessel's correction) for the sample estimate.
    return _convert(ss / (num_points - 1), T)
861
862
def pvariance(data, mu=None):
    """Return the population variance of ``data``.

    data should be a sequence or iterable of Real-valued numbers, with at least one
    value. The optional argument mu, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function to calculate the variance from the entire population.
    To estimate the variance from a sample, the ``variance`` function is
    usually a better choice.

    Examples:

    >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
    >>> pvariance(data)
    1.25

    If you have already calculated the mean of the data, you can pass it as
    the optional second argument to avoid recalculating it:

    >>> mu = mean(data)
    >>> pvariance(data, mu)
    1.25

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('24.815')

    >>> from fractions import Fraction as F
    >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
    Fraction(13, 72)

    """
    # One-shot iterators must be materialized before they can be sized.
    if iter(data) is data:
        data = list(data)
    num_points = len(data)
    if num_points < 1:
        raise StatisticsError('pvariance requires at least one data point')
    T, ss = _ss(data, mu)
    # Population variance divides by n (no Bessel's correction).
    return _convert(ss / num_points, T)
905
906
def stdev(data, xbar=None):
    """Return the square root of the sample variance.

    See ``variance`` for arguments and other details.

    >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    1.0810874155219827

    """
    # Materialize one-shot iterators so the data can be measured and scanned.
    if iter(data) is data:
        data = list(data)
    count = len(data)
    if count < 2:
        raise StatisticsError('stdev requires at least two data points')
    T, ss = _ss(data, xbar)
    mean_square = ss / (count - 1)
    # Take the square root of the exact fraction, preserving Decimal inputs.
    sqrt_of_frac = (_decimal_sqrt_of_frac if issubclass(T, Decimal)
                    else _float_sqrt_of_frac)
    return sqrt_of_frac(mean_square.numerator, mean_square.denominator)
926
927
def pstdev(data, mu=None):
    """Return the square root of the population variance.

    See ``pvariance`` for arguments and other details.

    >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    0.986893273527251

    """
    # Materialize one-shot iterators so the data can be measured and scanned.
    if iter(data) is data:
        data = list(data)
    count = len(data)
    if count < 1:
        raise StatisticsError('pstdev requires at least one data point')
    T, ss = _ss(data, mu)
    mean_square = ss / count
    # Take the square root of the exact fraction, preserving Decimal inputs.
    sqrt_of_frac = (_decimal_sqrt_of_frac if issubclass(T, Decimal)
                    else _float_sqrt_of_frac)
    return sqrt_of_frac(mean_square.numerator, mean_square.denominator)
947
948
949# === Statistics for relations between two inputs ===
950
951# See https://en.wikipedia.org/wiki/Covariance
952#     https://en.wikipedia.org/wiki/Pearson_correlation_coefficient
953#     https://en.wikipedia.org/wiki/Simple_linear_regression
954
955
def covariance(x, y, /):
    """Covariance

    Return the sample covariance of two inputs *x* and *y*. Covariance
    is a measure of the joint variability of two inputs.

    >>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    >>> y = [1, 2, 3, 1, 2, 3, 1, 2, 3]
    >>> covariance(x, y)
    0.75
    >>> z = [9, 8, 7, 6, 5, 4, 3, 2, 1]
    >>> covariance(x, z)
    -7.5
    >>> covariance(z, x)
    -7.5

    """
    size = len(x)
    if len(y) != size:
        raise StatisticsError('covariance requires that both inputs have same number of data points')
    if size < 2:
        raise StatisticsError('covariance requires at least two data points')
    # Center each input on its mean, then sum the products of the deviations.
    x_mean = fsum(x) / size
    y_mean = fsum(y) / size
    s_xy = fsum((xi - x_mean) * (yi - y_mean) for xi, yi in zip(x, y))
    # Bessel's correction: divide by n - 1 for the *sample* covariance.
    return s_xy / (size - 1)
982
983
def correlation(x, y, /):
    """Pearson's correlation coefficient

    Return the Pearson's correlation coefficient for two inputs. Pearson's
    correlation coefficient *r* takes values between -1 and +1. It measures the
    strength and direction of the linear relationship, where +1 means very
    strong, positive linear relationship, -1 very strong, negative linear
    relationship, and 0 no linear relationship.

    >>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    >>> y = [9, 8, 7, 6, 5, 4, 3, 2, 1]
    >>> correlation(x, x)
    1.0
    >>> correlation(x, y)
    -1.0

    """
    size = len(x)
    if len(y) != size:
        raise StatisticsError('correlation requires that both inputs have same number of data points')
    if size < 2:
        raise StatisticsError('correlation requires at least two data points')
    x_mean = fsum(x) / size
    y_mean = fsum(y) / size
    # Sums of products of deviations: s_xy is the (unscaled) covariance,
    # s_xx and s_yy the (unscaled) variances.
    s_xy = fsum((xi - x_mean) * (yi - y_mean) for xi, yi in zip(x, y))
    s_xx = fsum((xi - x_mean) * (xi - x_mean) for xi in x)
    s_yy = fsum((yi - y_mean) * (yi - y_mean) for yi in y)
    try:
        return s_xy / sqrt(s_xx * s_yy)
    except ZeroDivisionError:
        # A constant input has zero variance, making r undefined.
        raise StatisticsError('at least one of the inputs is constant')
1015
1016
LinearRegression = namedtuple('LinearRegression', ('slope', 'intercept'))


def linear_regression(x, y, /, *, proportional=False):
    """Slope and intercept for simple linear regression.

    Return the slope and intercept of simple linear regression
    parameters estimated using ordinary least squares. Simple linear
    regression describes relationship between an independent variable
    *x* and a dependent variable *y* in terms of a linear function:

        y = slope * x + intercept + noise

    where *slope* and *intercept* are the regression parameters that are
    estimated, and noise represents the variability of the data that was
    not explained by the linear regression (it is equal to the
    difference between predicted and actual values of the dependent
    variable).

    The parameters are returned as a named tuple.

    >>> x = [1, 2, 3, 4, 5]
    >>> noise = NormalDist().samples(5, seed=42)
    >>> y = [3 * x[i] + 2 + noise[i] for i in range(5)]
    >>> linear_regression(x, y)  #doctest: +ELLIPSIS
    LinearRegression(slope=3.09078914170..., intercept=1.75684970486...)

    If *proportional* is true, the independent variable *x* and the
    dependent variable *y* are assumed to be directly proportional.
    The data is fit to a line passing through the origin.

    Since the *intercept* will always be 0.0, the underlying linear
    function simplifies to:

        y = slope * x + noise

    >>> y = [3 * x[i] + noise[i] for i in range(5)]
    >>> linear_regression(x, y, proportional=True)  #doctest: +ELLIPSIS
    LinearRegression(slope=3.02447542484..., intercept=0.0)

    """
    size = len(x)
    if len(y) != size:
        raise StatisticsError('linear regression requires that both inputs have same number of data points')
    if size < 2:
        raise StatisticsError('linear regression requires at least two data points')
    if proportional:
        # Fit through the origin: no centering needed.
        sum_xy = fsum(xi * yi for xi, yi in zip(x, y))
        sum_xx = fsum(xi * xi for xi in x)
    else:
        # Ordinary least squares on mean-centered data.
        x_mean = fsum(x) / size
        y_mean = fsum(y) / size
        sum_xy = fsum((xi - x_mean) * (yi - y_mean) for xi, yi in zip(x, y))
        sum_xx = fsum((xi - x_mean) * (xi - x_mean) for xi in x)
    try:
        slope = sum_xy / sum_xx   # equivalent to:  covariance(x, y) / variance(x)
    except ZeroDivisionError:
        raise StatisticsError('x is constant')
    if proportional:
        intercept = 0.0
    else:
        intercept = y_mean - slope * x_mean
    return LinearRegression(slope=slope, intercept=intercept)
1077
1078
1079## Normal Distribution #####################################################
1080
1081
def _normal_dist_inv_cdf(p, mu, sigma):
    """Return x such that CDF(x) == p for a Normal(mu, sigma) distribution.

    Pure-Python fallback; the C implementation from _statistics replaces
    it when available (see the try/except import below this function).
    """
    # There is no closed-form solution to the inverse CDF for the normal
    # distribution, so we use a rational approximation instead:
    # Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
    # Normal Distribution".  Applied Statistics. Blackwell Publishing. 37
    # (3): 477–484. doi:10.2307/2347330. JSTOR 2347330.
    q = p - 0.5
    # Central region: |q| <= 0.425 (i.e. 0.075 <= p <= 0.925).
    # Rational polynomial in r = 0.180625 - q*q, scaled by q.
    if fabs(q) <= 0.425:
        r = 0.180625 - q * q
        # Hash sum: 55.88319_28806_14901_4439
        num = (((((((2.50908_09287_30122_6727e+3 * r +
                     3.34305_75583_58812_8105e+4) * r +
                     6.72657_70927_00870_0853e+4) * r +
                     4.59219_53931_54987_1457e+4) * r +
                     1.37316_93765_50946_1125e+4) * r +
                     1.97159_09503_06551_4427e+3) * r +
                     1.33141_66789_17843_7745e+2) * r +
                     3.38713_28727_96366_6080e+0) * q
        den = (((((((5.22649_52788_52854_5610e+3 * r +
                     2.87290_85735_72194_2674e+4) * r +
                     3.93078_95800_09271_0610e+4) * r +
                     2.12137_94301_58659_5867e+4) * r +
                     5.39419_60214_24751_1077e+3) * r +
                     6.87187_00749_20579_0830e+2) * r +
                     4.23133_30701_60091_1252e+1) * r +
                     1.0)
        x = num / den
        return mu + (x * sigma)
    # Tail regions: work with r = sqrt(-log(min(p, 1-p))); the sign of q
    # is reapplied at the end.
    r = p if q <= 0.0 else 1.0 - p
    r = sqrt(-log(r))
    # Near tail: r <= 5.0 (roughly 1.39e-11 <= min(p, 1-p) < 0.075).
    if r <= 5.0:
        r = r - 1.6
        # Hash sum: 49.33206_50330_16102_89036
        num = (((((((7.74545_01427_83414_07640e-4 * r +
                     2.27238_44989_26918_45833e-2) * r +
                     2.41780_72517_74506_11770e-1) * r +
                     1.27045_82524_52368_38258e+0) * r +
                     3.64784_83247_63204_60504e+0) * r +
                     5.76949_72214_60691_40550e+0) * r +
                     4.63033_78461_56545_29590e+0) * r +
                     1.42343_71107_49683_57734e+0)
        den = (((((((1.05075_00716_44416_84324e-9 * r +
                     5.47593_80849_95344_94600e-4) * r +
                     1.51986_66563_61645_71966e-2) * r +
                     1.48103_97642_74800_74590e-1) * r +
                     6.89767_33498_51000_04550e-1) * r +
                     1.67638_48301_83803_84940e+0) * r +
                     2.05319_16266_37758_82187e+0) * r +
                     1.0)
    # Far tail: r > 5.0, extreme probabilities.
    else:
        r = r - 5.0
        # Hash sum: 47.52583_31754_92896_71629
        num = (((((((2.01033_43992_92288_13265e-7 * r +
                     2.71155_55687_43487_57815e-5) * r +
                     1.24266_09473_88078_43860e-3) * r +
                     2.65321_89526_57612_30930e-2) * r +
                     2.96560_57182_85048_91230e-1) * r +
                     1.78482_65399_17291_33580e+0) * r +
                     5.46378_49111_64114_36990e+0) * r +
                     6.65790_46435_01103_77720e+0)
        den = (((((((2.04426_31033_89939_78564e-15 * r +
                     1.42151_17583_16445_88870e-7) * r +
                     1.84631_83175_10054_68180e-5) * r +
                     7.86869_13114_56132_59100e-4) * r +
                     1.48753_61290_85061_48525e-2) * r +
                     1.36929_88092_27358_05310e-1) * r +
                     5.99832_20655_58879_37690e-1) * r +
                     1.0)
    x = num / den
    # Mirror the result for the lower tail (q < 0 means p < 0.5).
    if q < 0.0:
        x = -x
    return mu + (x * sigma)
1154
1155
1156# If available, use C implementation
1157try:
1158    from _statistics import _normal_dist_inv_cdf
1159except ImportError:
1160    pass
1161
1162
class NormalDist:
    "Normal distribution of a random variable"
    # https://en.wikipedia.org/wiki/Normal_distribution
    # https://en.wikipedia.org/wiki/Variance#Properties

    # Mapping-style __slots__: keys are the slot names, values serve as
    # per-attribute documentation (shown by help()), while still avoiding
    # a per-instance __dict__.
    __slots__ = {
        '_mu': 'Arithmetic mean of a normal distribution',
        '_sigma': 'Standard deviation of a normal distribution',
    }
1172
1173    def __init__(self, mu=0.0, sigma=1.0):
1174        "NormalDist where mu is the mean and sigma is the standard deviation."
1175        if sigma < 0.0:
1176            raise StatisticsError('sigma must be non-negative')
1177        self._mu = float(mu)
1178        self._sigma = float(sigma)
1179
1180    @classmethod
1181    def from_samples(cls, data):
1182        "Make a normal distribution instance from sample data."
1183        if not isinstance(data, (list, tuple)):
1184            data = list(data)
1185        xbar = fmean(data)
1186        return cls(xbar, stdev(data, xbar))
1187
1188    def samples(self, n, *, seed=None):
1189        "Generate *n* samples for a given mean and standard deviation."
1190        gauss = random.gauss if seed is None else random.Random(seed).gauss
1191        mu, sigma = self._mu, self._sigma
1192        return [gauss(mu, sigma) for i in range(n)]
1193
1194    def pdf(self, x):
1195        "Probability density function.  P(x <= X < x+dx) / dx"
1196        variance = self._sigma * self._sigma
1197        if not variance:
1198            raise StatisticsError('pdf() not defined when sigma is zero')
1199        diff = x - self._mu
1200        return exp(diff * diff / (-2.0 * variance)) / sqrt(tau * variance)
1201
1202    def cdf(self, x):
1203        "Cumulative distribution function.  P(X <= x)"
1204        if not self._sigma:
1205            raise StatisticsError('cdf() not defined when sigma is zero')
1206        return 0.5 * (1.0 + erf((x - self._mu) / (self._sigma * _SQRT2)))
1207
1208    def inv_cdf(self, p):
1209        """Inverse cumulative distribution function.  x : P(X <= x) = p
1210
1211        Finds the value of the random variable such that the probability of
1212        the variable being less than or equal to that value equals the given
1213        probability.
1214
1215        This function is also called the percent point function or quantile
1216        function.
1217        """
1218        if p <= 0.0 or p >= 1.0:
1219            raise StatisticsError('p must be in the range 0.0 < p < 1.0')
1220        if self._sigma <= 0.0:
1221            raise StatisticsError('cdf() not defined when sigma at or below zero')
1222        return _normal_dist_inv_cdf(p, self._mu, self._sigma)
1223
1224    def quantiles(self, n=4):
1225        """Divide into *n* continuous intervals with equal probability.
1226
1227        Returns a list of (n - 1) cut points separating the intervals.
1228
1229        Set *n* to 4 for quartiles (the default).  Set *n* to 10 for deciles.
1230        Set *n* to 100 for percentiles which gives the 99 cuts points that
1231        separate the normal distribution in to 100 equal sized groups.
1232        """
1233        return [self.inv_cdf(i / n) for i in range(1, n)]
1234
1235    def overlap(self, other):
1236        """Compute the overlapping coefficient (OVL) between two normal distributions.
1237
1238        Measures the agreement between two normal probability distributions.
1239        Returns a value between 0.0 and 1.0 giving the overlapping area in
1240        the two underlying probability density functions.
1241
1242            >>> N1 = NormalDist(2.4, 1.6)
1243            >>> N2 = NormalDist(3.2, 2.0)
1244            >>> N1.overlap(N2)
1245            0.8035050657330205
1246        """
1247        # See: "The overlapping coefficient as a measure of agreement between
1248        # probability distributions and point estimation of the overlap of two
1249        # normal densities" -- Henry F. Inman and Edwin L. Bradley Jr
1250        # http://dx.doi.org/10.1080/03610928908830127
1251        if not isinstance(other, NormalDist):
1252            raise TypeError('Expected another NormalDist instance')
1253        X, Y = self, other
1254        if (Y._sigma, Y._mu) < (X._sigma, X._mu):  # sort to assure commutativity
1255            X, Y = Y, X
1256        X_var, Y_var = X.variance, Y.variance
1257        if not X_var or not Y_var:
1258            raise StatisticsError('overlap() not defined when sigma is zero')
1259        dv = Y_var - X_var
1260        dm = fabs(Y._mu - X._mu)
1261        if not dv:
1262            return 1.0 - erf(dm / (2.0 * X._sigma * _SQRT2))
1263        a = X._mu * Y_var - Y._mu * X_var
1264        b = X._sigma * Y._sigma * sqrt(dm * dm + dv * log(Y_var / X_var))
1265        x1 = (a + b) / dv
1266        x2 = (a - b) / dv
1267        return 1.0 - (fabs(Y.cdf(x1) - X.cdf(x1)) + fabs(Y.cdf(x2) - X.cdf(x2)))
1268
1269    def zscore(self, x):
1270        """Compute the Standard Score.  (x - mean) / stdev
1271
1272        Describes *x* in terms of the number of standard deviations
1273        above or below the mean of the normal distribution.
1274        """
1275        # https://www.statisticshowto.com/probability-and-statistics/z-score/
1276        if not self._sigma:
1277            raise StatisticsError('zscore() not defined when sigma is zero')
1278        return (x - self._mu) / self._sigma
1279
1280    @property
1281    def mean(self):
1282        "Arithmetic mean of the normal distribution."
1283        return self._mu
1284
1285    @property
1286    def median(self):
1287        "Return the median of the normal distribution"
1288        return self._mu
1289
1290    @property
1291    def mode(self):
1292        """Return the mode of the normal distribution
1293
1294        The mode is the value x where which the probability density
1295        function (pdf) takes its maximum value.
1296        """
1297        return self._mu
1298
1299    @property
1300    def stdev(self):
1301        "Standard deviation of the normal distribution."
1302        return self._sigma
1303
1304    @property
1305    def variance(self):
1306        "Square of the standard deviation."
1307        return self._sigma * self._sigma
1308
1309    def __add__(x1, x2):
1310        """Add a constant or another NormalDist instance.
1311
1312        If *other* is a constant, translate mu by the constant,
1313        leaving sigma unchanged.
1314
1315        If *other* is a NormalDist, add both the means and the variances.
1316        Mathematically, this works only if the two distributions are
1317        independent or if they are jointly normally distributed.
1318        """
1319        if isinstance(x2, NormalDist):
1320            return NormalDist(x1._mu + x2._mu, hypot(x1._sigma, x2._sigma))
1321        return NormalDist(x1._mu + x2, x1._sigma)
1322
1323    def __sub__(x1, x2):
1324        """Subtract a constant or another NormalDist instance.
1325
1326        If *other* is a constant, translate by the constant mu,
1327        leaving sigma unchanged.
1328
1329        If *other* is a NormalDist, subtract the means and add the variances.
1330        Mathematically, this works only if the two distributions are
1331        independent or if they are jointly normally distributed.
1332        """
1333        if isinstance(x2, NormalDist):
1334            return NormalDist(x1._mu - x2._mu, hypot(x1._sigma, x2._sigma))
1335        return NormalDist(x1._mu - x2, x1._sigma)
1336
1337    def __mul__(x1, x2):
1338        """Multiply both mu and sigma by a constant.
1339
1340        Used for rescaling, perhaps to change measurement units.
1341        Sigma is scaled with the absolute value of the constant.
1342        """
1343        return NormalDist(x1._mu * x2, x1._sigma * fabs(x2))
1344
1345    def __truediv__(x1, x2):
1346        """Divide both mu and sigma by a constant.
1347
1348        Used for rescaling, perhaps to change measurement units.
1349        Sigma is scaled with the absolute value of the constant.
1350        """
1351        return NormalDist(x1._mu / x2, x1._sigma / fabs(x2))
1352
1353    def __pos__(x1):
1354        "Return a copy of the instance."
1355        return NormalDist(x1._mu, x1._sigma)
1356
1357    def __neg__(x1):
1358        "Negates mu while keeping sigma the same."
1359        return NormalDist(-x1._mu, x1._sigma)
1360
    # Addition is commutative, so constant + NormalDist reuses __add__.
    __radd__ = __add__
1362
1363    def __rsub__(x1, x2):
1364        "Subtract a NormalDist from a constant or another NormalDist."
1365        return -(x1 - x2)
1366
    # Scaling by a constant is commutative, so constant * NormalDist reuses __mul__.
    __rmul__ = __mul__
1368
1369    def __eq__(x1, x2):
1370        "Two NormalDist objects are equal if their mu and sigma are both equal."
1371        if not isinstance(x2, NormalDist):
1372            return NotImplemented
1373        return x1._mu == x2._mu and x1._sigma == x2._sigma
1374
1375    def __hash__(self):
1376        "NormalDist objects hash equal if their mu and sigma are both equal."
1377        return hash((self._mu, self._sigma))
1378
1379    def __repr__(self):
1380        return f'{type(self).__name__}(mu={self._mu!r}, sigma={self._sigma!r})'
1381