1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.util
19
20import org.apache.spark.annotation.Since
21
22/**
23 * A class for tracking the statistics of a set of numbers (count, mean and variance) in a
24 * numerically robust way. Includes support for merging two StatCounters. Based on Welford
25 * and Chan's <a href="http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance">
26 * algorithms</a> for running variance.
27 *
28 * @constructor Initialize the StatCounter with the given values.
29 */
30class StatCounter(values: TraversableOnce[Double]) extends Serializable {
31  private var n: Long = 0     // Running count of our values
32  private var mu: Double = 0  // Running mean of our values
33  private var m2: Double = 0  // Running variance numerator (sum of (x - mean)^2)
34  private var maxValue: Double = Double.NegativeInfinity // Running max of our values
35  private var minValue: Double = Double.PositiveInfinity // Running min of our values
36
37  merge(values)
38
39  /** Initialize the StatCounter with no values. */
40  def this() = this(Nil)
41
42  /** Add a value into this StatCounter, updating the internal statistics. */
43  def merge(value: Double): StatCounter = {
44    val delta = value - mu
45    n += 1
46    mu += delta / n
47    m2 += delta * (value - mu)
48    maxValue = math.max(maxValue, value)
49    minValue = math.min(minValue, value)
50    this
51  }
52
53  /** Add multiple values into this StatCounter, updating the internal statistics. */
54  def merge(values: TraversableOnce[Double]): StatCounter = {
55    values.foreach(v => merge(v))
56    this
57  }
58
59  /** Merge another StatCounter into this one, adding up the internal statistics. */
60  def merge(other: StatCounter): StatCounter = {
61    if (other == this) {
62      merge(other.copy())  // Avoid overwriting fields in a weird order
63    } else {
64      if (n == 0) {
65        mu = other.mu
66        m2 = other.m2
67        n = other.n
68        maxValue = other.maxValue
69        minValue = other.minValue
70      } else if (other.n != 0) {
71        val delta = other.mu - mu
72        if (other.n * 10 < n) {
73          mu = mu + (delta * other.n) / (n + other.n)
74        } else if (n * 10 < other.n) {
75          mu = other.mu - (delta * n) / (n + other.n)
76        } else {
77          mu = (mu * n + other.mu * other.n) / (n + other.n)
78        }
79        m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
80        n += other.n
81        maxValue = math.max(maxValue, other.maxValue)
82        minValue = math.min(minValue, other.minValue)
83      }
84      this
85    }
86  }
87
88  /** Clone this StatCounter */
89  def copy(): StatCounter = {
90    val other = new StatCounter
91    other.n = n
92    other.mu = mu
93    other.m2 = m2
94    other.maxValue = maxValue
95    other.minValue = minValue
96    other
97  }
98
99  def count: Long = n
100
101  def mean: Double = mu
102
103  def sum: Double = n * mu
104
105  def max: Double = maxValue
106
107  def min: Double = minValue
108
109  /** Return the population variance of the values. */
110  def variance: Double = popVariance
111
112  /**
113   * Return the population variance of the values.
114   */
115  @Since("2.1.0")
116  def popVariance: Double = {
117    if (n == 0) {
118      Double.NaN
119    } else {
120      m2 / n
121    }
122  }
123
124  /**
125   * Return the sample variance, which corrects for bias in estimating the variance by dividing
126   * by N-1 instead of N.
127   */
128  def sampleVariance: Double = {
129    if (n <= 1) {
130      Double.NaN
131    } else {
132      m2 / (n - 1)
133    }
134  }
135
136  /** Return the population standard deviation of the values. */
137  def stdev: Double = popStdev
138
139  /**
140   * Return the population standard deviation of the values.
141   */
142  @Since("2.1.0")
143  def popStdev: Double = math.sqrt(popVariance)
144
145  /**
146   * Return the sample standard deviation of the values, which corrects for bias in estimating the
147   * variance by dividing by N-1 instead of N.
148   */
149  def sampleStdev: Double = math.sqrt(sampleVariance)
150
151  override def toString: String = {
152    "(count: %d, mean: %f, stdev: %f, max: %f, min: %f)".format(count, mean, stdev, max, min)
153  }
154}
155
object StatCounter {
  /**
   * Build a StatCounter from a collection of values.
   *
   * @param values the numbers to summarize, merged in iteration order
   * @return a new StatCounter holding the statistics of `values`
   */
  def apply(values: TraversableOnce[Double]): StatCounter =
    new StatCounter().merge(values)

  /**
   * Build a StatCounter from values passed as variable-length arguments.
   *
   * @param values the numbers to summarize, merged left to right
   * @return a new StatCounter holding the statistics of `values`
   */
  def apply(values: Double*): StatCounter =
    values.foldLeft(new StatCounter()) { (counter, v) => counter.merge(v) }
}
163