1"""
2Basic statistics module.
3
4This module provides functions for calculating statistics of data, including
5averages, variance, and standard deviation.
6
7Calculating averages
8--------------------
9
10==================  ==================================================
11Function            Description
12==================  ==================================================
13mean                Arithmetic mean (average) of data.
14fmean               Fast, floating point arithmetic mean.
15geometric_mean      Geometric mean of data.
16harmonic_mean       Harmonic mean of data.
17median              Median (middle value) of data.
18median_low          Low median of data.
19median_high         High median of data.
20median_grouped      Median, or 50th percentile, of grouped data.
21mode                Mode (most common value) of data.
22multimode           List of modes (most common values of data).
23quantiles           Divide data into intervals with equal probability.
24==================  ==================================================
25
26Calculate the arithmetic mean ("the average") of data:
27
28>>> mean([-1.0, 2.5, 3.25, 5.75])
292.625
30
31
32Calculate the standard median of discrete data:
33
34>>> median([2, 3, 4, 5])
353.5
36
37
38Calculate the median, or 50th percentile, of data grouped into class intervals
39centred on the data values provided. E.g. if your data points are rounded to
40the nearest whole number:
41
42>>> median_grouped([2, 2, 3, 3, 3, 4])  #doctest: +ELLIPSIS
432.8333333333...
44
45This should be interpreted in this way: you have two data points in the class
46interval 1.5-2.5, three data points in the class interval 2.5-3.5, and one in
47the class interval 3.5-4.5. The median of these data points is 2.8333...
48
49
50Calculating variability or spread
51---------------------------------
52
53==================  =============================================
54Function            Description
55==================  =============================================
56pvariance           Population variance of data.
57variance            Sample variance of data.
58pstdev              Population standard deviation of data.
59stdev               Sample standard deviation of data.
60==================  =============================================
61
62Calculate the standard deviation of sample data:
63
64>>> stdev([2.5, 3.25, 5.5, 11.25, 11.75])  #doctest: +ELLIPSIS
654.38961843444...
66
67If you have previously calculated the mean, you can pass it as the optional
68second argument to the four "spread" functions to avoid recalculating it:
69
70>>> data = [1, 2, 2, 4, 4, 4, 5, 6]
71>>> mu = mean(data)
72>>> pvariance(data, mu)
732.5
74
75
76Exceptions
77----------
78
79A single exception is defined: StatisticsError is a subclass of ValueError.
80
81"""
82
# Public names exported by ``from statistics import *`` (kept sorted).
__all__ = [
    'NormalDist',
    'StatisticsError',
    'fmean',
    'geometric_mean',
    'harmonic_mean',
    'mean',
    'median',
    'median_grouped',
    'median_high',
    'median_low',
    'mode',
    'multimode',
    'pstdev',
    'pvariance',
    'quantiles',
    'stdev',
    'variance',
]
102
103import math
104import numbers
105import random
106
107from fractions import Fraction
108from decimal import Decimal
109from itertools import groupby
110from bisect import bisect_left, bisect_right
111from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
112from operator import itemgetter
113from collections import Counter
114
115# === Exceptions ===
116
class StatisticsError(ValueError):
    """Exception raised by functions in this module for invalid data.

    Subclasses ValueError so existing callers that catch ValueError
    keep working.
    """
    pass
119
120
121# === Private utilities ===
122
def _sum(data, start=0):
    """_sum(data [, start]) -> (type, sum, count)

    Return a high-precision sum of the given numeric data as a fraction,
    together with the type to be converted to and the count of items.

    The optional ``start`` value is added to the total; for empty
    ``data`` the result is ``start`` (defaulting to 0).

    Examples
    --------

    >>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
    (<class 'float'>, Fraction(11, 1), 5)

    Some sources of round-off error will be avoided:

    # Built-in sum returns zero.
    >>> _sum([1e50, 1, -1e50] * 1000)
    (<class 'float'>, Fraction(1000, 1), 3000)

    Fractions and Decimals are also supported:

    >>> from fractions import Fraction as F
    >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)])
    (<class 'fractions.Fraction'>, Fraction(63, 20), 4)

    >>> from decimal import Decimal as D
    >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")]
    >>> _sum(data)
    (<class 'decimal.Decimal'>, Fraction(6963, 10000), 4)

    Mixed types are currently treated as an error, except that int is
    allowed.
    """
    count = 0
    num, den = _exact_ratio(start)
    # Exact partial sums of numerators, keyed by denominator.  A key of
    # None marks a non-finite value (NAN or INF) from _exact_ratio().
    partials = {den: num}
    result_type = _coerce(int, type(start))
    for typ, group in groupby(data, type):
        result_type = _coerce(result_type, typ)  # or raise TypeError
        for num, den in map(_exact_ratio, group):
            count += 1
            partials[den] = partials.get(den, 0) + num
    if None in partials:
        # A NAN or INF dominates the sum; every finite partial can be
        # ignored in favour of this special entry.
        total = partials[None]
        assert not _isfinite(total)
    else:
        # Combine the per-denominator partials exactly as Fractions.
        total = sum(Fraction(num, den) for den, num in sorted(partials.items()))
    return (result_type, total, count)
179
180
def _isfinite(x):
    """Return True if x is finite (works for Decimal as well as floats)."""
    try:
        # Decimal supplies its own method; using it avoids a lossy
        # conversion to float.
        return x.is_finite()
    except AttributeError:
        pass
    return math.isfinite(x)  # Coerces to float first.
186
187
def _coerce(T, S):
    """Coerce types T and S to a common type, or raise TypeError.

    Coercion rules are currently an implementation detail. See the CoerceTest
    test class in test_statistics for details.
    """
    # See http://bugs.python.org/issue24068.
    assert T is not bool, "initial type T is bool"
    # Fast path: identical types need no coercion at all.  This is the
    # usual case, so it is checked first.
    if T is S:
        return T
    # int (or bool) mixed with any other type coerces to the other type.
    if S is int or S is bool:
        return T
    if T is int:
        return S
    # A strict subclass wins over its parent class.
    if issubclass(S, T):
        return S
    if issubclass(T, S):
        return T
    # Subclasses of int likewise defer to the other type.
    if issubclass(T, int):
        return S
    if issubclass(S, int):
        return T
    # Fraction mixed with float coerces to the float (or float subclass).
    if issubclass(T, Fraction) and issubclass(S, float):
        return S
    if issubclass(T, float) and issubclass(S, Fraction):
        return T
    # Any other combination is disallowed.
    raise TypeError("don't know how to coerce %s and %s"
                    % (T.__name__, S.__name__))
217
218
def _exact_ratio(x):
    """Return Real number x to exact (numerator, denominator) pair.

    >>> _exact_ratio(0.25)
    (1, 4)

    x is expected to be an int, Fraction, Decimal or float.

    A non-finite float or Decimal (NAN or INF) is returned as the pair
    ``(x, None)``; callers treat the None denominator as a sentinel.
    """
    try:
        # Optimise the common case of floats. We expect that the most often
        # used numeric type will be builtin floats, so try to make this as
        # fast as possible.
        if type(x) is float or type(x) is Decimal:
            return x.as_integer_ratio()
        try:
            # x may be an int, Fraction, or Integral ABC.
            return (x.numerator, x.denominator)
        except AttributeError:
            try:
                # x may be a float or Decimal subclass.
                return x.as_integer_ratio()
            except AttributeError:
                # Not a supported numeric type; fall through to the
                # TypeError at the end.
                pass
    except (OverflowError, ValueError):
        # as_integer_ratio() raises for a float/Decimal NAN or INF.
        assert not _isfinite(x)
        return (x, None)
    msg = "can't convert type '{}' to numerator/denominator"
    raise TypeError(msg.format(type(x).__name__))
249
250
def _convert(value, T):
    """Convert value to given numeric type T."""
    if type(value) is T:
        # Already the right type; this also covers Fraction results and
        # non-finite values (Decimal or float NAN/INF).
        return value
    if issubclass(T, int) and value.denominator != 1:
        # A non-integral exact result cannot round-trip through int.
        T = float
    try:
        # FIXME: what do we do if this overflows?
        return T(value)
    except TypeError:
        if not issubclass(T, Decimal):
            raise
        # Decimal cannot be constructed from a Fraction directly, so
        # divide the exact numerator by the exact denominator instead.
        return T(value.numerator) / T(value.denominator)
267
268
def _find_lteq(a, x):
    """Locate the leftmost value exactly equal to x, else raise ValueError."""
    i = bisect_left(a, x)
    if i == len(a) or a[i] != x:
        raise ValueError
    return i
275
276
def _find_rteq(a, l, x):
    """Locate the rightmost value exactly equal to x at index >= l."""
    i = bisect_right(a, x, lo=l)
    if i == (len(a) + 1) or a[i - 1] != x:
        raise ValueError
    return i - 1
283
284
def _fail_neg(values, errmsg='negative value'):
    """Yield each value unchanged, raising StatisticsError on a negative one."""
    for value in values:
        if value < 0:
            raise StatisticsError(errmsg)
        yield value
291
292
293# === Measures of central tendency (averages) ===
294
def mean(data):
    """Return the sample arithmetic mean of data.

    >>> mean([1, 2, 3, 4, 4])
    2.8

    >>> from fractions import Fraction as F
    >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
    Fraction(13, 21)

    >>> from decimal import Decimal as D
    >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
    Decimal('0.5625')

    If ``data`` is empty, StatisticsError will be raised.
    """
    # Materialize a one-shot iterator so len() can be applied.
    if iter(data) is data:
        data = list(data)
    n = len(data)
    if n < 1:
        raise StatisticsError('mean requires at least one data point')
    result_type, total, count = _sum(data)
    assert count == n
    # Divide exactly, then convert back to the input's numeric type.
    return _convert(total / n, result_type)
319
320
def fmean(data):
    """Convert data to floats and compute the arithmetic mean.

    This runs faster than the mean() function and it always returns a float.
    If the input dataset is empty, it raises a StatisticsError.

    >>> fmean([3.5, 4.0, 5.25])
    4.25
    """
    n = 0

    def tally(iterable):
        # Track the running element count while feeding values to fsum().
        nonlocal n
        for n, value in enumerate(iterable, start=1):
            yield value

    try:
        n = len(data)
    except TypeError:
        # Iterators have no len(); count the items as they are summed.
        total = fsum(tally(data))
    else:
        total = fsum(data)
    try:
        return total / n
    except ZeroDivisionError:
        raise StatisticsError('fmean requires at least one data point') from None
346
347
def geometric_mean(data):
    """Convert data to floats and compute the geometric mean.

    Raises a StatisticsError if the input dataset is empty,
    if it contains a zero, or if it contains a negative value.

    No special efforts are made to achieve exact results.
    (However, this may change in the future.)

    >>> round(geometric_mean([54, 24, 36]), 9)
    36.0
    """
    try:
        # log() raises ValueError for zero or negative inputs, and fmean()
        # raises StatisticsError (a ValueError subclass) for an empty
        # dataset, so both failure modes are handled below.
        return exp(fmean(map(log, data)))
    except ValueError:
        # Note: the two string fragments previously joined with a doubled
        # space ("dataset  containing"); fixed to a single space.
        raise StatisticsError('geometric mean requires a non-empty dataset '
                              'containing positive numbers') from None
365
366
def harmonic_mean(data):
    """Return the harmonic mean of data.

    The harmonic mean, sometimes called the subcontrary mean, is the
    reciprocal of the arithmetic mean of the reciprocals of the data,
    and is often appropriate when averaging quantities which are rates
    or ratios, for example speeds. Example:

    Suppose an investor purchases an equal value of shares in each of
    three companies, with P/E (price/earning) ratios of 2.5, 3 and 10.
    What is the average P/E ratio for the investor's portfolio?

    >>> harmonic_mean([2.5, 3, 10])  # For an equal investment portfolio.
    3.6

    Using the arithmetic mean would give an average of about 5.167, which
    is too high.

    If ``data`` is empty, or any element is less than zero,
    ``harmonic_mean`` will raise ``StatisticsError``.
    """
    # For a justification for using harmonic mean for P/E ratios, see
    # http://fixthepitch.pellucid.com/comps-analysis-the-missing-harmony-of-summary-statistics/
    # http://papers.ssrn.com/sol3/papers.cfm?abstract_id=2621087
    if iter(data) is data:
        data = list(data)
    errmsg = 'harmonic mean does not support negative values'
    n = len(data)
    if n < 1:
        raise StatisticsError('harmonic_mean requires at least one data point')
    if n == 1:
        # Single value: validate and return it unchanged.
        x = data[0]
        if not isinstance(x, (numbers.Real, Decimal)):
            raise TypeError('unsupported type')
        if x < 0:
            raise StatisticsError(errmsg)
        return x
    try:
        T, total, count = _sum(1 / x for x in _fail_neg(data, errmsg))
    except ZeroDivisionError:
        # Any zero in the data drives the harmonic mean to zero.
        return 0
    assert count == n
    return _convert(n / total, T)
411
412
# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
    """Return the median (middle value) of numeric data.

    When the number of data points is odd, return the middle data point.
    When the number of data points is even, the median is interpolated by
    taking the average of the two middle values:

    >>> median([1, 3, 5])
    3
    >>> median([1, 3, 5, 7])
    4.0

    """
    values = sorted(data)
    size = len(values)
    if size == 0:
        raise StatisticsError("no median for empty data")
    mid = size // 2
    if size % 2:
        return values[mid]
    # Even count: interpolate between the two middle values.
    return (values[mid - 1] + values[mid]) / 2
436
437
def median_low(data):
    """Return the low median of numeric data.

    When the number of data points is odd, the middle value is returned.
    When it is even, the smaller of the two middle values is returned.

    >>> median_low([1, 3, 5])
    3
    >>> median_low([1, 3, 5, 7])
    3

    """
    values = sorted(data)
    size = len(values)
    if size == 0:
        raise StatisticsError("no median for empty data")
    if size % 2:
        return values[size // 2]
    # Even count: step back one position for the smaller middle value.
    return values[size // 2 - 1]
458
459
def median_high(data):
    """Return the high median of data.

    When the number of data points is odd, the middle value is returned.
    When it is even, the larger of the two middle values is returned.

    >>> median_high([1, 3, 5])
    3
    >>> median_high([1, 3, 5, 7])
    5

    """
    values = sorted(data)
    if not values:
        raise StatisticsError("no median for empty data")
    # Index len//2 is the middle for odd counts and the larger of the
    # two middle values for even counts.
    return values[len(values) // 2]
477
478
def median_grouped(data, interval=1):
    """Return the 50th percentile (median) of grouped continuous data.

    >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
    3.7
    >>> median_grouped([52, 52, 53, 54])
    52.5

    This calculates the median as the 50th percentile, and should be
    used when your data is continuous and grouped. In the above example,
    the values 1, 2, 3, etc. actually represent the midpoint of classes
    0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in
    class 3.5-4.5, and interpolation is used to estimate it.

    Optional argument ``interval`` represents the class interval, and
    defaults to 1. Changing the class interval naturally will change the
    interpolated 50th percentile value:

    >>> median_grouped([1, 3, 3, 5, 7], interval=1)
    3.25
    >>> median_grouped([1, 3, 3, 5, 7], interval=2)
    3.5

    This function does not check whether the data points are at least
    ``interval`` apart.
    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    elif n == 1:
        return data[0]
    # Find the value at the midpoint. Remember this corresponds to the
    # centre of the class interval.
    x = data[n//2]
    # Reject str/bytes explicitly so the caller gets a clear message
    # instead of an obscure failure in the arithmetic below.
    for obj in (x, interval):
        if isinstance(obj, (str, bytes)):
            raise TypeError('expected number but got %r' % obj)
    try:
        L = x - interval/2  # The lower limit of the median interval.
    except TypeError:
        # Mixed type. For now we just coerce to float.
        L = float(x) - float(interval)/2

    # Uses bisection search to search for x in data with log(n) time complexity
    # Find the position of leftmost occurrence of x in data
    l1 = _find_lteq(data, x)
    # Find the position of rightmost occurrence of x in data[l1...len(data)]
    # Assuming always l1 <= l2
    l2 = _find_rteq(data, l1, x)
    # cf: cumulative frequency below the median interval;
    # f: frequency of values inside the median interval.
    cf = l1
    f = l2 - l1 + 1
    return L + interval*(n/2 - cf)/f
532
533
def mode(data):
    """Return the most common data point from discrete or nominal data.

    ``mode`` assumes discrete data, and returns a single value. This is the
    standard treatment of the mode as commonly taught in schools:

        >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
        3

    This also works with nominal (non-numeric) data:

        >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
        'red'

    If there are multiple modes with same frequency, return the first one
    encountered:

        >>> mode(['red', 'red', 'green', 'blue', 'blue'])
        'red'

    If *data* is empty, ``mode`` raises StatisticsError.

    """
    # Counter consumes any iterable directly, so the previous
    # ``data = iter(data)`` wrapping was redundant and has been removed.
    pairs = Counter(data).most_common(1)
    try:
        return pairs[0][0]
    except IndexError:
        raise StatisticsError('no mode for empty data') from None
563
564
def multimode(data):
    """Return a list of the most frequently occurring values.

    Will return more than one result if there are multiple modes
    or an empty list if *data* is empty.

    >>> multimode('aabbbbbbbbcc')
    ['b']
    >>> multimode('aabbbbccddddeeffffgg')
    ['b', 'd', 'f']
    >>> multimode('')
    []
    """
    counts = Counter(iter(data)).most_common()
    # most_common() sorts by descending count, so the first group holds
    # every value tied for the highest count; (0, []) covers empty data.
    grouped = groupby(counts, key=itemgetter(1))
    maxcount, mode_items = next(grouped, (0, []))
    return [value for value, count in mode_items]
581
582
583# Notes on methods for computing quantiles
584# ----------------------------------------
585#
586# There is no one perfect way to compute quantiles.  Here we offer
587# two methods that serve common needs.  Most other packages
588# surveyed offered at least one or both of these two, making them
589# "standard" in the sense of "widely-adopted and reproducible".
590# They are also easy to explain, easy to compute manually, and have
591# straight-forward interpretations that aren't surprising.
592
593# The default method is known as "R6", "PERCENTILE.EXC", or "expected
594# value of rank order statistics". The alternative method is known as
595# "R7", "PERCENTILE.INC", or "mode of rank order statistics".
596
597# For sample data where there is a positive probability for values
598# beyond the range of the data, the R6 exclusive method is a
599# reasonable choice.  Consider a random sample of nine values from a
600# population with a uniform distribution from 0.0 to 100.0.  The
601# distribution of the third ranked sample point is described by
602# betavariate(alpha=3, beta=7) which has mode=0.250, median=0.286, and
603# mean=0.300.  Only the latter (which corresponds with R6) gives the
604# desired cut point with 30% of the population falling below that
605# value, making it comparable to a result from an inv_cdf() function.
606# The R6 exclusive method is also idempotent.
607
608# For describing population data where the end points are known to
609# be included in the data, the R7 inclusive method is a reasonable
610# choice.  Instead of the mean, it uses the mode of the beta
611# distribution for the interior points.  Per Hyndman & Fan, "One nice
612# property is that the vertices of Q7(p) divide the range into n - 1
613# intervals, and exactly 100p% of the intervals lie to the left of
614# Q7(p) and 100(1 - p)% of the intervals lie to the right of Q7(p)."
615
616# If needed, other methods could be added.  However, for now, the
617# position is that fewer options make for easier choices and that
618# external packages can be used for anything more advanced.
619
def quantiles(data, *, n=4, method='exclusive'):
    """Divide *data* into *n* continuous intervals with equal probability.

    Returns a list of (n - 1) cut points separating the intervals.

    Set *n* to 4 for quartiles (the default).  Set *n* to 10 for deciles.
    Set *n* to 100 for percentiles which gives the 99 cut points that
    separate *data* into 100 equal sized groups.

    The *data* can be any iterable containing sample data.  The cut
    points are linearly interpolated between data points.

    If *method* is set to *inclusive*, *data* is treated as population
    data.  The minimum value is treated as the 0th percentile and the
    maximum value is treated as the 100th percentile.
    """
    if n < 1:
        raise StatisticsError('n must be at least 1')
    data = sorted(data)
    ld = len(data)
    if ld < 2:
        raise StatisticsError('must have at least two data points')
    if method == 'inclusive':
        # R7 / PERCENTILE.INC: the extremes map to the 0th and 100th
        # percentiles, so interpolate over ld - 1 gaps.
        m = ld - 1
        result = []
        for i in range(1, n):
            j, delta = divmod(i * m, n)  # exact integer arithmetic
            result.append((data[j] * (n - delta) + data[j + 1] * delta) / n)
        return result
    if method == 'exclusive':
        # R6 / PERCENTILE.EXC: positions run over ld + 1 virtual gaps.
        m = ld + 1
        result = []
        for i in range(1, n):
            j = i * m // n                # rescale i to the m/n grid
            j = min(max(j, 1), ld - 1)    # clamp to 1 .. ld-1
            delta = i * m - j * n         # exact integer arithmetic
            result.append((data[j - 1] * (n - delta) + data[j] * delta) / n)
        return result
    raise ValueError(f'Unknown method: {method!r}')
662
663
664# === Measures of spread ===
665
666# See http://mathworld.wolfram.com/Variance.html
667#     http://mathworld.wolfram.com/SampleVariance.html
668#     http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
669#
670# Under no circumstances use the so-called "computational formula for
671# variance", as that is only suitable for hand calculations with a small
672# amount of low-precision data. It has terrible numeric properties.
673#
674# See a comparison of three computational methods here:
675# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
676
def _ss(data, c=None):
    """Return sum of square deviations of sequence data.

    If ``c`` is None, the mean is calculated in one pass, and the deviations
    from the mean are calculated in a second pass. Otherwise, deviations are
    calculated from ``c`` as given. Use the second case with care, as it can
    lead to garbage results.
    """
    if c is not None:
        T, ss, count = _sum((x - c)**2 for x in data)
        return (T, ss)
    c = mean(data)
    T, ss, count = _sum((x - c)**2 for x in data)
    # Mathematically the residual sum below is zero; due to rounding it
    # may not be, so subtract its contribution as an error correction.
    U, residual, count2 = _sum((x - c) for x in data)
    assert T == U and count == count2
    ss -= residual**2 / len(data)
    assert not ss < 0, 'negative sum of square deviations: %f' % ss
    return (T, ss)
697
698
def variance(data, xbar=None):
    """Return the sample variance of data.

    data should be an iterable of Real-valued numbers, with at least two
    values. The optional argument xbar, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function when your data is a sample from a population. To
    calculate the variance from the entire population, see ``pvariance``.

    Examples:

    >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
    >>> variance(data)
    1.3720238095238095

    If you have already calculated the mean of your data, you can pass it as
    the optional second argument ``xbar`` to avoid recalculating it:

    >>> m = mean(data)
    >>> variance(data, m)
    1.3720238095238095

    This function does not check that ``xbar`` is actually the mean of
    ``data``. Giving arbitrary values for ``xbar`` may lead to invalid or
    impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('31.01875')

    >>> from fractions import Fraction as F
    >>> variance([F(1, 6), F(1, 2), F(5, 3)])
    Fraction(67, 108)

    """
    # Materialize a one-shot iterator: the data may be traversed twice
    # (once for the mean, once for the deviations).
    if iter(data) is data:
        data = list(data)
    count = len(data)
    if count < 2:
        raise StatisticsError('variance requires at least two data points')
    result_type, ss = _ss(data, xbar)
    # Divide by n - 1 (Bessel's correction) for an unbiased sample estimate.
    return _convert(ss / (count - 1), result_type)
744
745
def pvariance(data, mu=None):
    """Return the population variance of ``data``.

    data should be a sequence or iterable of Real-valued numbers, with at least one
    value. The optional argument mu, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function to calculate the variance from the entire population.
    To estimate the variance from a sample, the ``variance`` function is
    usually a better choice.

    Examples:

    >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
    >>> pvariance(data)
    1.25

    If you have already calculated the mean of the data, you can pass it as
    the optional second argument to avoid recalculating it:

    >>> mu = mean(data)
    >>> pvariance(data, mu)
    1.25

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('24.815')

    >>> from fractions import Fraction as F
    >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
    Fraction(13, 72)

    """
    # Materialize a one-shot iterator: the data may be traversed twice
    # (once for the mean, once for the deviations).
    if iter(data) is data:
        data = list(data)
    count = len(data)
    if count < 1:
        raise StatisticsError('pvariance requires at least one data point')
    result_type, ss = _ss(data, mu)
    # Divide by n for the full-population variance.
    return _convert(ss / count, result_type)
788
789
def stdev(data, xbar=None):
    """Return the square root of the sample variance.

    See ``variance`` for arguments and other details.

    >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    1.0810874155219827

    """
    var = variance(data, xbar)
    # Decimal results carry their own sqrt() method; math.sqrt would
    # coerce them to float.
    try:
        sqrt_method = var.sqrt
    except AttributeError:
        return math.sqrt(var)
    return sqrt_method()
804
805
def pstdev(data, mu=None):
    """Return the square root of the population variance.

    See ``pvariance`` for arguments and other details.

    >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    0.986893273527251

    """
    var = pvariance(data, mu)
    # Decimal results carry their own sqrt() method; math.sqrt would
    # coerce them to float.
    try:
        sqrt_method = var.sqrt
    except AttributeError:
        return math.sqrt(var)
    return sqrt_method()
820
821
822## Normal Distribution #####################################################
823
824
def _normal_dist_inv_cdf(p, mu, sigma):
    """Return x such that the normal CDF of N(mu, sigma) at x equals p.

    There is no closed-form solution to the inverse CDF for the normal
    distribution, so we use a rational approximation instead:
    Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
    Normal Distribution".  Applied Statistics. Blackwell Publishing. 37
    (3): 477-484. doi:10.2307/2347330. JSTOR 2347330.
    """
    q = p - 0.5
    # Central region: |q| <= 0.425 uses one rational approximation in
    # r = 0.180625 - q*q.
    if fabs(q) <= 0.425:
        r = 0.180625 - q * q
        # Hash sum: 55.88319_28806_14901_4439
        num = (((((((2.50908_09287_30122_6727e+3 * r +
                     3.34305_75583_58812_8105e+4) * r +
                     6.72657_70927_00870_0853e+4) * r +
                     4.59219_53931_54987_1457e+4) * r +
                     1.37316_93765_50946_1125e+4) * r +
                     1.97159_09503_06551_4427e+3) * r +
                     1.33141_66789_17843_7745e+2) * r +
                     3.38713_28727_96366_6080e+0) * q
        den = (((((((5.22649_52788_52854_5610e+3 * r +
                     2.87290_85735_72194_2674e+4) * r +
                     3.93078_95800_09271_0610e+4) * r +
                     2.12137_94301_58659_5867e+4) * r +
                     5.39419_60214_24751_1077e+3) * r +
                     6.87187_00749_20579_0830e+2) * r +
                     4.23133_30701_60091_1252e+1) * r +
                     1.0)
        x = num / den
        return mu + (x * sigma)
    # Tail regions: work with r = sqrt(-log(min(p, 1-p))).
    r = p if q <= 0.0 else 1.0 - p
    r = sqrt(-log(r))
    if r <= 5.0:
        # Near tails.
        r = r - 1.6
        # Hash sum: 49.33206_50330_16102_89036
        num = (((((((7.74545_01427_83414_07640e-4 * r +
                     2.27238_44989_26918_45833e-2) * r +
                     2.41780_72517_74506_11770e-1) * r +
                     1.27045_82524_52368_38258e+0) * r +
                     3.64784_83247_63204_60504e+0) * r +
                     5.76949_72214_60691_40550e+0) * r +
                     4.63033_78461_56545_29590e+0) * r +
                     1.42343_71107_49683_57734e+0)
        den = (((((((1.05075_00716_44416_84324e-9 * r +
                     5.47593_80849_95344_94600e-4) * r +
                     1.51986_66563_61645_71966e-2) * r +
                     1.48103_97642_74800_74590e-1) * r +
                     6.89767_33498_51000_04550e-1) * r +
                     1.67638_48301_83803_84940e+0) * r +
                     2.05319_16266_37758_82187e+0) * r +
                     1.0)
    else:
        # Far tails.
        r = r - 5.0
        # Hash sum: 47.52583_31754_92896_71629
        num = (((((((2.01033_43992_92288_13265e-7 * r +
                     2.71155_55687_43487_57815e-5) * r +
                     1.24266_09473_88078_43860e-3) * r +
                     2.65321_89526_57612_30930e-2) * r +
                     2.96560_57182_85048_91230e-1) * r +
                     1.78482_65399_17291_33580e+0) * r +
                     5.46378_49111_64114_36990e+0) * r +
                     6.65790_46435_01103_77720e+0)
        den = (((((((2.04426_31033_89939_78564e-15 * r +
                     1.42151_17583_16445_88870e-7) * r +
                     1.84631_83175_10054_68180e-5) * r +
                     7.86869_13114_56132_59100e-4) * r +
                     1.48753_61290_85061_48525e-2) * r +
                     1.36929_88092_27358_05310e-1) * r +
                     5.99832_20655_58879_37690e-1) * r +
                     1.0)
    x = num / den
    # The approximation was computed for the upper tail; mirror for q < 0.
    if q < 0.0:
        x = -x
    return mu + (x * sigma)
897
898
class NormalDist:
    "Normal distribution of a random variable"
    # https://en.wikipedia.org/wiki/Normal_distribution
    # https://en.wikipedia.org/wiki/Variance#Properties

    # Using a slots dict (instead of a tuple) attaches a per-attribute
    # docstring while still suppressing the per-instance __dict__.
    __slots__ = {
        '_mu': 'Arithmetic mean of a normal distribution',
        '_sigma': 'Standard deviation of a normal distribution',
    }

    def __init__(self, mu=0.0, sigma=1.0):
        "NormalDist where mu is the mean and sigma is the standard deviation."
        if sigma < 0.0:
            raise StatisticsError('sigma must be non-negative')
        self._mu = float(mu)
        self._sigma = float(sigma)

    @classmethod
    def from_samples(cls, data):
        "Make a normal distribution instance from sample data."
        if not isinstance(data, (list, tuple)):
            # Materialize one-shot iterators so the data can be scanned twice
            # (once for the mean, once for the standard deviation).
            data = list(data)
        xbar = fmean(data)
        return cls(xbar, stdev(data, xbar))

    def samples(self, n, *, seed=None):
        "Generate *n* samples for a given mean and standard deviation."
        # A seeded private Random instance gives reproducible results
        # without perturbing the shared global generator.
        gauss = random.gauss if seed is None else random.Random(seed).gauss
        mu, sigma = self._mu, self._sigma
        return [gauss(mu, sigma) for _ in range(n)]

    def pdf(self, x):
        "Probability density function.  P(x <= X < x+dx) / dx"
        variance = self._sigma ** 2.0
        if not variance:
            raise StatisticsError('pdf() not defined when sigma is zero')
        # Classic Gaussian density:  exp(-(x-mu)^2 / 2s^2) / sqrt(2*pi*s^2)
        # where tau is 2*pi.
        return exp((x - self._mu)**2.0 / (-2.0*variance)) / sqrt(tau*variance)

    def cdf(self, x):
        "Cumulative distribution function.  P(X <= x)"
        if not self._sigma:
            raise StatisticsError('cdf() not defined when sigma is zero')
        return 0.5 * (1.0 + erf((x - self._mu) / (self._sigma * sqrt(2.0))))

    def inv_cdf(self, p):
        """Inverse cumulative distribution function.  x : P(X <= x) = p

        Finds the value of the random variable such that the probability of
        the variable being less than or equal to that value equals the given
        probability.

        This function is also called the percent point function or quantile
        function.
        """
        if p <= 0.0 or p >= 1.0:
            raise StatisticsError('p must be in the range 0.0 < p < 1.0')
        if self._sigma <= 0.0:
            # Bug fix: the message used to blame cdf() for a failure in
            # inv_cdf().  Note the constructor already rejects sigma < 0,
            # so in practice this triggers only when sigma == 0.
            raise StatisticsError('inv_cdf() not defined when sigma is zero')
        return _normal_dist_inv_cdf(p, self._mu, self._sigma)

    def quantiles(self, n=4):
        """Divide into *n* continuous intervals with equal probability.

        Returns a list of (n - 1) cut points separating the intervals.

        Set *n* to 4 for quartiles (the default).  Set *n* to 10 for deciles.
        Set *n* to 100 for percentiles which gives the 99 cuts points that
        separate the normal distribution in to 100 equal sized groups.
        """
        return [self.inv_cdf(i / n) for i in range(1, n)]

    def overlap(self, other):
        """Compute the overlapping coefficient (OVL) between two normal distributions.

        Measures the agreement between two normal probability distributions.
        Returns a value between 0.0 and 1.0 giving the overlapping area in
        the two underlying probability density functions.

            >>> N1 = NormalDist(2.4, 1.6)
            >>> N2 = NormalDist(3.2, 2.0)
            >>> N1.overlap(N2)
            0.8035050657330205
        """
        # See: "The overlapping coefficient as a measure of agreement between
        # probability distributions and point estimation of the overlap of two
        # normal densities" -- Henry F. Inman and Edwin L. Bradley Jr
        # http://dx.doi.org/10.1080/03610928908830127
        if not isinstance(other, NormalDist):
            raise TypeError('Expected another NormalDist instance')
        X, Y = self, other
        if (Y._sigma, Y._mu) < (X._sigma, X._mu):   # sort to assure commutativity
            X, Y = Y, X
        X_var, Y_var = X.variance, Y.variance
        if not X_var or not Y_var:
            raise StatisticsError('overlap() not defined when sigma is zero')
        dv = Y_var - X_var
        dm = fabs(Y._mu - X._mu)
        if not dv:
            # Equal variances: the two pdf curves cross exactly once,
            # midway between the means.
            return 1.0 - erf(dm / (2.0 * X._sigma * sqrt(2.0)))
        # Unequal variances: the pdf curves cross at two points, x1 and x2,
        # found by solving the quadratic that equates the two densities.
        a = X._mu * Y_var - Y._mu * X_var
        b = X._sigma * Y._sigma * sqrt(dm**2.0 + dv * log(Y_var / X_var))
        x1 = (a + b) / dv
        x2 = (a - b) / dv
        return 1.0 - (fabs(Y.cdf(x1) - X.cdf(x1)) + fabs(Y.cdf(x2) - X.cdf(x2)))

    @property
    def mean(self):
        "Arithmetic mean of the normal distribution."
        return self._mu

    @property
    def median(self):
        "Return the median of the normal distribution"
        # For a normal distribution, the median equals the mean.
        return self._mu

    @property
    def mode(self):
        """Return the mode of the normal distribution

        The mode is the value x where which the probability density
        function (pdf) takes its maximum value.
        """
        return self._mu

    @property
    def stdev(self):
        "Standard deviation of the normal distribution."
        return self._sigma

    @property
    def variance(self):
        "Square of the standard deviation."
        return self._sigma ** 2.0

    def __add__(x1, x2):
        """Add a constant or another NormalDist instance.

        If *other* is a constant, translate mu by the constant,
        leaving sigma unchanged.

        If *other* is a NormalDist, add both the means and the variances.
        Mathematically, this works only if the two distributions are
        independent or if they are jointly normally distributed.
        """
        if isinstance(x2, NormalDist):
            # Variances add, so the new sigma is sqrt(s1^2 + s2^2) == hypot.
            return NormalDist(x1._mu + x2._mu, hypot(x1._sigma, x2._sigma))
        return NormalDist(x1._mu + x2, x1._sigma)

    def __sub__(x1, x2):
        """Subtract a constant or another NormalDist instance.

        If *other* is a constant, translate by the constant mu,
        leaving sigma unchanged.

        If *other* is a NormalDist, subtract the means and add the variances.
        Mathematically, this works only if the two distributions are
        independent or if they are jointly normally distributed.
        """
        if isinstance(x2, NormalDist):
            return NormalDist(x1._mu - x2._mu, hypot(x1._sigma, x2._sigma))
        return NormalDist(x1._mu - x2, x1._sigma)

    def __mul__(x1, x2):
        """Multiply both mu and sigma by a constant.

        Used for rescaling, perhaps to change measurement units.
        Sigma is scaled with the absolute value of the constant.
        """
        return NormalDist(x1._mu * x2, x1._sigma * fabs(x2))

    def __truediv__(x1, x2):
        """Divide both mu and sigma by a constant.

        Used for rescaling, perhaps to change measurement units.
        Sigma is scaled with the absolute value of the constant.
        """
        return NormalDist(x1._mu / x2, x1._sigma / fabs(x2))

    def __pos__(x1):
        "Return a copy of the instance."
        return NormalDist(x1._mu, x1._sigma)

    def __neg__(x1):
        "Negates mu while keeping sigma the same."
        return NormalDist(-x1._mu, x1._sigma)

    # Addition with a constant is commutative.
    __radd__ = __add__

    def __rsub__(x1, x2):
        "Subtract a NormalDist from a constant or another NormalDist."
        # const - X  ==  -(X - const):  same sigma, negated translation.
        return -(x1 - x2)

    # Scaling by a constant is commutative.
    __rmul__ = __mul__

    def __eq__(x1, x2):
        "Two NormalDist objects are equal if their mu and sigma are both equal."
        if not isinstance(x2, NormalDist):
            return NotImplemented
        return x1._mu == x2._mu and x1._sigma == x2._sigma

    def __hash__(self):
        "NormalDist objects hash equal if their mu and sigma are both equal."
        return hash((self._mu, self._sigma))

    def __repr__(self):
        return f'{type(self).__name__}(mu={self._mu!r}, sigma={self._sigma!r})'
1105
1106# If available, use C implementation
1107try:
1108    from _statistics import _normal_dist_inv_cdf
1109except ImportError:
1110    pass
1111
1112
if __name__ == '__main__':

    # Show math operations computed analytically in comparison
    # to a monte carlo simulation of the same operations

    from math import isclose
    from operator import add, sub, mul, truediv
    from itertools import repeat
    import doctest

    g1 = NormalDist(10, 20)
    g2 = NormalDist(-5, 25)

    # Test scaling by a constant
    assert (g1 * 5 / 5).mean == g1.mean
    assert (g1 * 5 / 5).stdev == g1.stdev

    n = 100_000
    G1 = g1.samples(n)
    G2 = g2.samples(n)

    for func in (add, sub):
        print(f'\nTest {func.__name__} with another NormalDist:')
        print(func(g1, g2))
        print(NormalDist.from_samples(map(func, G1, G2)))

    const = 11
    for func in (add, sub, mul, truediv):
        print(f'\nTest {func.__name__} with a constant:')
        print(func(g1, const))
        print(NormalDist.from_samples(map(func, G1, repeat(const))))

    const = 19
    for func in (add, sub, mul):
        print(f'\nTest constant with {func.__name__}:')
        print(func(const, g1))
        print(NormalDist.from_samples(map(func, repeat(const), G1)))

    def assert_close(G1, G2):
        # Allow 1% relative tolerance for monte carlo sampling error.
        # Bug fix: the mean check used to compare G1.mean against itself,
        # which made it a tautology that could never fail.
        assert isclose(G1.mean, G2.mean, rel_tol=0.01), (G1, G2)
        assert isclose(G1.stdev, G2.stdev, rel_tol=0.01), (G1, G2)

    X = NormalDist(-105, 73)
    Y = NormalDist(31, 47)
    s = 32.75
    n = 100_000

    S = NormalDist.from_samples([x + s for x in X.samples(n)])
    assert_close(X + s, S)

    S = NormalDist.from_samples([x - s for x in X.samples(n)])
    assert_close(X - s, S)

    S = NormalDist.from_samples([x * s for x in X.samples(n)])
    assert_close(X * s, S)

    S = NormalDist.from_samples([x / s for x in X.samples(n)])
    assert_close(X / s, S)

    S = NormalDist.from_samples([x + y for x, y in zip(X.samples(n),
                                                       Y.samples(n))])
    assert_close(X + Y, S)

    S = NormalDist.from_samples([x - y for x, y in zip(X.samples(n),
                                                       Y.samples(n))])
    assert_close(X - Y, S)

    print(doctest.testmod())
1181