/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

import org.apache.spark.annotation.Since

/**
 * Accumulates summary statistics (count, mean, variance, minimum and maximum) over a set of
 * doubles without storing the individual values. Two accumulators can be combined with
 * `merge`. The running-variance updates follow Welford and Chan's
 * <a href="http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance">
 * algorithms</a>.
 *
 * @constructor Create a StatCounter that has already absorbed every element of `values`.
 */
class StatCounter(values: TraversableOnce[Double]) extends Serializable {
  // NOTE: these fields must be declared before the `merge(values)` call below so that they are
  // initialized by the time the constructor starts folding values in.
  private var n: Long = 0 // How many values have been absorbed
  private var mu: Double = 0 // Mean of the absorbed values
  private var m2: Double = 0 // Sum of squared deviations from the mean (variance numerator)
  private var maxValue: Double = Double.NegativeInfinity // Largest value absorbed so far
  private var minValue: Double = Double.PositiveInfinity // Smallest value absorbed so far

  merge(values)

  /** Create a StatCounter that has not absorbed any values yet. */
  def this() = this(Nil)

  /**
   * Absorb a single value, updating count, mean, variance numerator, max and min in place.
   *
   * @param value the number to add
   * @return this StatCounter, to allow chaining
   */
  def merge(value: Double): StatCounter = {
    val offsetFromOldMean = value - mu
    n += 1
    mu += offsetFromOldMean / n
    // The second factor deliberately uses the mean *after* the update above.
    m2 += offsetFromOldMean * (value - mu)
    maxValue = math.max(maxValue, value)
    minValue = math.min(minValue, value)
    this
  }

  /**
   * Absorb every element of `values`, one at a time, in iteration order.
   *
   * @param values the numbers to add
   * @return this StatCounter, to allow chaining
   */
  def merge(values: TraversableOnce[Double]): StatCounter = {
    for (v <- values) merge(v)
    this
  }

  /**
   * Fold the statistics held by another StatCounter into this one. `other` is only read, except
   * when it is this very instance, in which case a snapshot of it is merged instead.
   *
   * @param other the StatCounter whose statistics should be added
   * @return this StatCounter, to allow chaining
   */
  def merge(other: StatCounter): StatCounter = {
    if (other == this) {
      // Merging with ourselves would read fields while they are being rewritten; merge a
      // snapshot instead.
      merge(other.copy())
    } else if (n == 0) {
      // Nothing absorbed on this side yet: simply take over the other counter's state.
      n = other.n
      mu = other.mu
      m2 = other.m2
      maxValue = other.maxValue
      minValue = other.minValue
      this
    } else if (other.n == 0) {
      // The other side is empty, so there is nothing to add.
      this
    } else {
      val combinedCount = n + other.n
      val meanGap = other.mu - mu
      // One of three formulas for the merged mean is picked using a factor-of-ten threshold on
      // the two counts. NOTE(review): presumably this limits rounding error when one side
      // dominates -- confirm against the referenced algorithm.
      mu =
        if (other.n * 10 < n) {
          mu + (meanGap * other.n) / combinedCount
        } else if (n * 10 < other.n) {
          other.mu - (meanGap * n) / combinedCount
        } else {
          (mu * n + other.mu * other.n) / combinedCount
        }
      // `n` still holds this side's pre-merge count here.
      m2 += other.m2 + (meanGap * meanGap * n * other.n) / combinedCount
      n = combinedCount
      maxValue = math.max(maxValue, other.maxValue)
      minValue = math.min(minValue, other.minValue)
      this
    }
  }

  /** Return a new StatCounter holding the same statistics as this one. */
  def copy(): StatCounter = {
    val snapshot = new StatCounter
    snapshot.n = n
    snapshot.mu = mu
    snapshot.m2 = m2
    snapshot.maxValue = maxValue
    snapshot.minValue = minValue
    snapshot
  }

  /** Number of values absorbed. */
  def count: Long = n

  /** Mean of the absorbed values (0 when nothing has been absorbed). */
  def mean: Double = mu

  /** Sum of the absorbed values, computed as count times mean. */
  def sum: Double = n * mu

  /** Largest absorbed value (negative infinity when nothing has been absorbed). */
  def max: Double = maxValue

  /** Smallest absorbed value (positive infinity when nothing has been absorbed). */
  def min: Double = minValue

  /** Population variance of the values; same as `popVariance`. */
  def variance: Double = popVariance

  /**
   * Population variance of the values (variance numerator divided by N), or NaN when no value
   * has been absorbed.
   */
  @Since("2.1.0")
  def popVariance: Double = {
    if (n != 0) {
      m2 / n
    } else {
      Double.NaN
    }
  }

  /**
   * Sample variance of the values, which divides by N-1 rather than N to correct for bias.
   * NaN when fewer than two values have been absorbed.
   */
  def sampleVariance: Double = {
    if (n > 1) {
      m2 / (n - 1)
    } else {
      Double.NaN
    }
  }

  /** Population standard deviation of the values; same as `popStdev`. */
  def stdev: Double = popStdev

  /**
   * Population standard deviation of the values: the square root of `popVariance`.
   */
  @Since("2.1.0")
  def popStdev: Double = math.sqrt(popVariance)

  /**
   * Sample standard deviation of the values: the square root of `sampleVariance`, which divides
   * by N-1 rather than N.
   */
  def sampleStdev: Double = math.sqrt(sampleVariance)

  /** One-line summary of count, mean, population stdev, max and min. */
  override def toString: String = {
    val template = "(count: %d, mean: %f, stdev: %f, max: %f, min: %f)"
    template.format(count, mean, stdev, max, min)
  }
}

object StatCounter {
  /** Create a StatCounter that has absorbed every element of the given collection. */
  def apply(values: TraversableOnce[Double]): StatCounter = new StatCounter(values)

  /** Create a StatCounter that has absorbed the values given as variable-length arguments. */
  def apply(values: Double*): StatCounter = new StatCounter(values)
}