1##  Module statistics.py
2##
3##  Copyright (c) 2013 Steven D'Aprano <steve+python@pearwood.info>.
4##
5##  Licensed under the Apache License, Version 2.0 (the "License");
6##  you may not use this file except in compliance with the License.
7##  You may obtain a copy of the License at
8##
9##  http://www.apache.org/licenses/LICENSE-2.0
10##
11##  Unless required by applicable law or agreed to in writing, software
12##  distributed under the License is distributed on an "AS IS" BASIS,
13##  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14##  See the License for the specific language governing permissions and
15##  limitations under the License.
16
17
18"""
19Basic statistics module.
20
21This module provides functions for calculating statistics of data, including
22averages, variance, and standard deviation.
23
24Calculating averages
25--------------------
26
27==================  =============================================
28Function            Description
29==================  =============================================
30mean                Arithmetic mean (average) of data.
31median              Median (middle value) of data.
32median_low          Low median of data.
33median_high         High median of data.
34median_grouped      Median, or 50th percentile, of grouped data.
35mode                Mode (most common value) of data.
36==================  =============================================
37
38Calculate the arithmetic mean ("the average") of data:
39
40>>> mean([-1.0, 2.5, 3.25, 5.75])
412.625
42
43
44Calculate the standard median of discrete data:
45
46>>> median([2, 3, 4, 5])
473.5
48
49
50Calculate the median, or 50th percentile, of data grouped into class intervals
51centred on the data values provided. E.g. if your data points are rounded to
52the nearest whole number:
53
54>>> median_grouped([2, 2, 3, 3, 3, 4])  #doctest: +ELLIPSIS
552.8333333333...
56
57This should be interpreted in this way: you have two data points in the class
58interval 1.5-2.5, three data points in the class interval 2.5-3.5, and one in
59the class interval 3.5-4.5. The median of these data points is 2.8333...
60
61
62Calculating variability or spread
63---------------------------------
64
65==================  =============================================
66Function            Description
67==================  =============================================
68pvariance           Population variance of data.
69variance            Sample variance of data.
70pstdev              Population standard deviation of data.
71stdev               Sample standard deviation of data.
72==================  =============================================
73
74Calculate the standard deviation of sample data:
75
76>>> stdev([2.5, 3.25, 5.5, 11.25, 11.75])  #doctest: +ELLIPSIS
774.38961843444...
78
79If you have previously calculated the mean, you can pass it as the optional
80second argument to the four "spread" functions to avoid recalculating it:
81
82>>> data = [1, 2, 2, 4, 4, 4, 5, 6]
83>>> mu = mean(data)
84>>> pvariance(data, mu)
852.5
86
87
88Exceptions
89----------
90
91A single exception is defined: StatisticsError is a subclass of ValueError.
92
93"""
94
# Public API: the names exported by ``from statistics import *``.
__all__ = [
    'StatisticsError',
    'pstdev', 'pvariance', 'stdev', 'variance',
    'median', 'median_low', 'median_high', 'median_grouped',
    'mean', 'mode',
]
100
101
102import collections
103import math
104
105from fractions import Fraction
106from decimal import Decimal
107
108
109# === Exceptions ===
110
class StatisticsError(ValueError):
    """Raised when a statistic cannot be computed for the given data.

    Being a ValueError subclass, it can be caught by code that only
    expects ValueError.
    """
113
114
115# === Private utilities ===
116
def _sum(data, start=0):
    """_sum(data [, start]) -> value

    Return a high-precision sum of the given numeric data. If optional
    argument ``start`` is given, it is added to the total. If ``data`` is
    empty, ``start`` (defaulting to 0) is returned.

    Each value is turned into an exact integer ratio and the numerators are
    accumulated per denominator, so rounding only happens once, when the
    exact total is converted back to the type of the data.


    Examples
    --------

    >>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
    11.0

    Some sources of round-off error will be avoided:

    >>> _sum([1e50, 1, -1e50] * 1000)  # Built-in sum returns zero.
    1000.0

    Fractions and Decimals are also supported:

    >>> from fractions import Fraction as F
    >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)])
    Fraction(63, 20)

    >>> from decimal import Decimal as D
    >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")]
    >>> _sum(data)
    Decimal('0.6963')

    Mixed types are currently treated as an error, except that int is
    allowed.
    """
    # Accepted types start as {int, type(start)}. _check_type adds the type
    # of the first non-int value it sees and rejects any further new type,
    # so [int, int, float, int] is fine but [int, int, float, Fraction] is not.
    allowed = {int, type(start)}
    numerator, denominator = _exact_ratio(start)
    # denominator -> running sum of numerators. Infinities and NANs come
    # back from _exact_ratio with a None denominator and are added as-is.
    totals = {denominator: numerator}
    # Hot-loop lookups bound to locals.
    to_ratio = _exact_ratio
    get_total = totals.get
    for value in data:
        _check_type(type(value), allowed)
        numerator, denominator = to_ratio(value)
        totals[denominator] = get_total(denominator, 0) + numerator

    # The result type is the single non-int member of ``allowed``, or int
    # when every value (including ``start``) was an int.
    assert int in allowed and len(allowed) in (1, 2)
    others = allowed - {int}
    result_type = others.pop() if others else int

    if None in totals:
        # A non-finite value was seen: the total is itself inf or NAN.
        assert issubclass(result_type, (float, Decimal))
        special = totals[None]
        assert not math.isfinite(special)
        return result_type(special)

    exact = sum(
        (Fraction(num, den) for den, num in sorted(totals.items())),
        Fraction(),
    )
    if issubclass(result_type, int):
        assert exact.denominator == 1
        return result_type(exact.numerator)
    if issubclass(result_type, Decimal):
        # A single Decimal division, rounded by the current Decimal context.
        return result_type(exact.numerator) / exact.denominator
    return result_type(exact)
185
186
def _check_type(T, allowed):
    """Accept type ``T`` into the set ``allowed`` or reject it.

    ``allowed`` is mutated in place. If ``T`` is already present nothing
    happens. If ``allowed`` holds a single type, ``T`` is added to it.
    Otherwise TypeError is raised, naming every type involved.
    """
    if T in allowed:
        return
    if len(allowed) != 1:
        names = [t.__name__ for t in allowed]
        names.append(T.__name__)
        raise TypeError("unsupported mixed types: %s" % ', '.join(names))
    allowed.add(T)
194
195
def _exact_ratio(x):
    """Convert Real number x exactly to (numerator, denominator) pair.

    >>> _exact_ratio(0.25)
    (1, 4)

    x is expected to be an int, Fraction, Decimal or float. Infinities and
    NANs have no exact ratio; for them ``(x, None)`` is returned, and the
    ``None`` denominator marks the value as non-finite. Any other type
    raises TypeError.
    """
    try:
        try:
            # int, Fraction
            return (x.numerator, x.denominator)
        except AttributeError:
            pass
        try:
            # float
            return x.as_integer_ratio()
        except AttributeError:
            pass
        try:
            # Decimal
            return _decimal_to_ratio(x)
        except AttributeError:
            pass
        msg = "can't convert type '{}' to numerator/denominator"
        raise TypeError(msg.format(type(x).__name__)) from None
    except (OverflowError, ValueError):
        # INF or NAN.
        if __debug__:
            # Decimal signalling NANs cannot be converted to float :-(
            finite = x.is_finite() if isinstance(x, Decimal) else math.isfinite(x)
            assert not finite
        return (x, None)
228
229
230# FIXME This is faster than Fraction.from_decimal, but still too slow.
def _decimal_to_ratio(d):
    """Convert Decimal d to exact integer ratio (numerator, denominator).

    The ratio is not reduced: the denominator is a power of ten taken
    directly from the exponent. Raises ValueError for infinities and NANs.

    >>> from decimal import Decimal
    >>> _decimal_to_ratio(Decimal("2.6"))
    (26, 10)

    """
    sign, digits, exp = d.as_tuple()
    if exp in ('F', 'n', 'N'):  # INF, NAN, sNAN
        assert not d.is_finite()
        raise ValueError
    # Rebuild the coefficient arithmetically. int(str) is avoided because
    # newer Pythons limit the length of integer strings.
    coefficient = 0
    for digit in digits:
        coefficient = coefficient * 10 + digit
    if exp >= 0:
        numerator, denominator = coefficient * 10**exp, 1
    else:
        numerator, denominator = coefficient, 10**-exp
    return (-numerator if sign else numerator, denominator)
254
255
def _counts(data):
    """Return the (value, frequency) pairs that share the highest frequency.

    Pairs come out in ``Counter.most_common`` order. An empty input gives
    an empty list.
    """
    # iter() matters: a mapping handed straight to Counter would be read as
    # precomputed counts rather than as a collection of values.
    ranked = collections.Counter(iter(data)).most_common()
    if not ranked:
        return ranked
    top = ranked[0][1]
    # most_common() sorts by descending frequency, so the pairs with the
    # top frequency form a prefix of ``ranked``.
    return [pair for pair in ranked if pair[1] == top]
268
269
270# === Measures of central tendency (averages) ===
271
def mean(data):
    """Return the sample arithmetic mean of data.

    The sum is computed with ``_sum`` (exact intermediate arithmetic) and
    then divided by the number of data points.

    >>> mean([1, 2, 3, 4, 4])
    2.8

    >>> from fractions import Fraction as F
    >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
    Fraction(13, 21)

    >>> from decimal import Decimal as D
    >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
    Decimal('0.5625')

    If ``data`` is empty, StatisticsError will be raised.
    """
    if iter(data) is data:
        # A one-shot iterator: materialise it so it can be counted and summed.
        data = list(data)
    count = len(data)
    if not count:
        raise StatisticsError('mean requires at least one data point')
    total = _sum(data)
    return total / count
294
295
296# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
    """Return the median (middle value) of numeric data.

    For an odd number of points the middle point itself is returned. For an
    even number, the two middle points are averaged:

    >>> median([1, 3, 5])
    3
    >>> median([1, 3, 5, 7])
    4.0

    Raises StatisticsError if ``data`` is empty.
    """
    ordered = sorted(data)
    size = len(ordered)
    if not size:
        raise StatisticsError("no median for empty data")
    half, odd = divmod(size, 2)
    if odd:
        return ordered[half]
    return (ordered[half - 1] + ordered[half]) / 2
319
320
def median_low(data):
    """Return the low median of numeric data.

    An odd number of points yields the middle point; an even number yields
    the smaller of the two middle points. The result is always a member of
    ``data``.

    >>> median_low([1, 3, 5])
    3
    >>> median_low([1, 3, 5, 7])
    3

    Raises StatisticsError if ``data`` is empty.
    """
    ordered = sorted(data)
    if not ordered:
        raise StatisticsError("no median for empty data")
    # (n - 1) // 2 is the middle index for odd n and the lower-middle for even n.
    return ordered[(len(ordered) - 1) // 2]
341
342
def median_high(data):
    """Return the high median of data.

    An odd number of points yields the middle point; an even number yields
    the larger of the two middle points. The result is always a member of
    ``data``.

    >>> median_high([1, 3, 5])
    3
    >>> median_high([1, 3, 5, 7])
    5

    Raises StatisticsError if ``data`` is empty.
    """
    ordered = sorted(data)
    if not ordered:
        raise StatisticsError("no median for empty data")
    return ordered[len(ordered) // 2]
360
361
def median_grouped(data, interval=1):
    """Return the 50th percentile (median) of grouped continuous data.

    >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
    3.7
    >>> median_grouped([52, 52, 53, 54])
    52.5

    This calculates the median as the 50th percentile, and should be
    used when your data is continuous and grouped. In the above example,
    the values 1, 2, 3, etc. actually represent the midpoint of classes
    0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in
    class 3.5-4.5, and interpolation is used to estimate it.

    Optional argument ``interval`` represents the class interval, and
    defaults to 1. Changing the class interval naturally will change the
    interpolated 50th percentile value:

    >>> median_grouped([1, 3, 3, 5, 7], interval=1)
    3.25
    >>> median_grouped([1, 3, 3, 5, 7], interval=2)
    3.5

    This function does not check whether the data points are at least
    ``interval`` apart.

    A single data point is returned unchanged. Raises StatisticsError if
    ``data`` is empty, and TypeError if the middle value or ``interval`` is
    a str or bytes object.
    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    elif n == 1:
        return data[0]
    # Find the value at the midpoint. Remember this corresponds to the
    # centre of the class interval.
    x = data[n//2]
    # Reject text up front. Otherwise the TypeError fallback below would
    # quietly coerce numeric strings such as '3' through float().
    for obj in (x, interval):
        if isinstance(obj, (str, bytes)):
            raise TypeError('expected number but got %r' % obj)
    try:
        L = x - interval/2  # The lower limit of the median interval.
    except TypeError:
        # Mixed type. For now we just coerce to float.
        L = float(x) - float(interval)/2
    # data is sorted, so the index of the first occurrence of x equals the
    # number of values below the median interval.
    cf = data.index(x)
    # FIXME The following line could be more efficient for big lists.
    f = data.count(x)  # Number of data points in the median interval.
    # Linear interpolation inside the median class.
    return L + interval*(n/2 - cf)/f
409
410
def mode(data):
    """Return the most common data point from discrete or nominal data.

    ``mode`` treats the data as discrete and returns exactly one value,
    which is the usual school-book definition of the mode:

    >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
    3

    Non-numeric (nominal) data works too:

    >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
    'red'

    StatisticsError is raised when the data is empty, or when several
    values tie for most common.
    """
    winners = _counts(data)
    if not winners:
        raise StatisticsError('no mode for empty data')
    if len(winners) > 1:
        raise StatisticsError(
                'no unique mode; found %d equally common values' % len(winners)
                )
    return winners[0][0]
438
439
440# === Measures of spread ===
441
442# See http://mathworld.wolfram.com/Variance.html
443#     http://mathworld.wolfram.com/SampleVariance.html
444#     http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
445#
446# Under no circumstances use the so-called "computational formula for
447# variance", as that is only suitable for hand calculations with a small
448# amount of low-precision data. It has terrible numeric properties.
449#
450# See a comparison of three computational methods here:
451# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
452
def _ss(data, c=None):
    """Return sum of square deviations of sequence data.

    If ``c`` is None, the mean is calculated in one pass, and the deviations
    from the mean are calculated in a second pass. Otherwise, deviations are
    calculated from ``c`` as given. Use the second case with care, as it can
    lead to garbage results.

    ``data`` is iterated more than once and passed to ``len``. It must
    therefore be a re-iterable sequence rather than a one-shot iterator;
    ``variance`` and ``pvariance`` convert iterators to lists before calling.
    ``len(data)`` is used as a divisor, so ``data`` must not be empty.

    Raises AssertionError if the result comes out negative, but only when
    assertions are enabled.
    """
    if c is None:
        c = mean(data)
    # Each squared deviation is computed in the data's own type, then all of
    # them are added by _sum using exact intermediate arithmetic.
    ss = _sum((x-c)**2 for x in data)
    # The following sum should mathematically equal zero, but due to rounding
    # error may not.
    # Subtracting (sum of deviations)**2 / n compensates for that error.
    # NOTE(review): algebraically sum((x-c)**2) - (sum(x-c))**2/n equals the
    # sum of squares about the true mean for *any* c. This term therefore
    # also corrects a caller-supplied c that is not the exact mean, which may
    # make the docstring's warning pessimistic -- confirm. This subtraction is
    # done in the type _sum returns (e.g. float), not exactly.
    ss -= _sum((x-c) for x in data)**2/len(data)
    # Sanity check against rounding driving the result below zero. Being an
    # assert, it is skipped under ``python -O``.
    assert not ss < 0, 'negative sum of square deviations: %f' % ss
    return ss
469
470
def variance(data, xbar=None):
    """Return the sample variance of data.

    ``data`` may be any iterable of Real-valued numbers and must contain at
    least two values. Optional ``xbar`` is the mean of the data; when it is
    omitted or None, the mean is computed here.

    Use this when the data is a sample drawn from a larger population. For
    the variance of an entire population, use ``pvariance``.

    Examples:

    >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
    >>> variance(data)
    1.3720238095238095

    A mean you have already computed can be passed as ``xbar`` so it is
    not recalculated:

    >>> m = mean(data)
    >>> variance(data, m)
    1.3720238095238095

    ``xbar`` is trusted as given and is not checked against the data.
    Passing an arbitrary value may produce invalid or impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('31.01875')

    >>> from fractions import Fraction as F
    >>> variance([F(1, 6), F(1, 2), F(5, 3)])
    Fraction(67, 108)

    """
    # _ss walks the data more than once, so one-shot iterators become lists.
    points = list(data) if iter(data) is data else data
    count = len(points)
    if count < 2:
        raise StatisticsError('variance requires at least two data points')
    # Bessel's correction: divide by n - 1 for a sample estimate.
    return _ss(points, xbar) / (count - 1)
516
517
def pvariance(data, mu=None):
    """Return the population variance of ``data``.

    ``data`` may be any iterable of Real-valued numbers and must contain at
    least one value. Optional ``mu`` is the mean of the data; when it is
    omitted or None, the mean is computed here.

    Use this when the data is the entire population. When the data is only
    a sample, ``variance`` is usually the better estimate.

    Examples:

    >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
    >>> pvariance(data)
    1.25

    A mean you have already computed can be passed as the second argument
    so it is not recalculated:

    >>> mu = mean(data)
    >>> pvariance(data, mu)
    1.25

    ``mu`` is trusted as given and is not checked against the data.
    Passing an arbitrary value may produce invalid or impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('24.815')

    >>> from fractions import Fraction as F
    >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
    Fraction(13, 72)

    """
    # _ss walks the data more than once, so one-shot iterators become lists.
    points = list(data) if iter(data) is data else data
    count = len(points)
    if not count:
        raise StatisticsError('pvariance requires at least one data point')
    return _ss(points, mu) / count
564
565
def stdev(data, xbar=None):
    """Return the square root of the sample variance.

    Takes the same arguments as ``variance``; see there for details. A
    variance that provides its own ``sqrt`` method (such as a Decimal) is
    rooted with it. Any other variance goes through ``math.sqrt``.

    >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    1.0810874155219827

    """
    sample_var = variance(data, xbar)
    try:
        root = sample_var.sqrt()
    except AttributeError:
        # No sqrt method (e.g. float, int, Fraction): use math.sqrt instead.
        root = math.sqrt(sample_var)
    return root
580
581
def pstdev(data, mu=None):
    """Return the square root of the population variance.

    Takes the same arguments as ``pvariance``; see there for details. A
    variance that provides its own ``sqrt`` method (such as a Decimal) is
    rooted with it. Any other variance goes through ``math.sqrt``.

    >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    0.986893273527251

    """
    population_var = pvariance(data, mu)
    try:
        root = population_var.sqrt()
    except AttributeError:
        # No sqrt method (e.g. float, int, Fraction): use math.sqrt instead.
        root = math.sqrt(population_var)
    return root
596