// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package stat

import (
	"math"
	"sort"

	"gonum.org/v1/gonum/floats"
)
// CumulantKind specifies the behavior for calculating the empirical CDF or Quantile.
type CumulantKind int

// List of supported CumulantKind values for the Quantile function.
// Constant values should match the R nomenclature. See
// https://en.wikipedia.org/wiki/Quantile#Estimating_the_quantiles_of_a_population
const (
	// Empirical treats the distribution as the actual empirical distribution.
	Empirical CumulantKind = 1
	// LinInterp linearly interpolates the empirical distribution between sample values, with a flat extrapolation.
	LinInterp CumulantKind = 4
)

// bhattacharyyaCoeff computes the Bhattacharyya Coefficient for probability distributions given by:
//  \sum_i \sqrt{p_i q_i}
//
// It is assumed that p and q have equal length.
func bhattacharyyaCoeff(p, q []float64) float64 {
	var bc float64
	for i, a := range p {
		bc += math.Sqrt(a * q[i])
	}
	return bc
}

// Bhattacharyya computes the distance between the probability distributions p and q given by:
//  -\ln ( \sum_i \sqrt{p_i q_i} )
//
// The lengths of p and q must be equal. It is assumed that p and q sum to 1.
func Bhattacharyya(p, q []float64) float64 {
	if len(p) != len(q) {
		panic("stat: slice length mismatch")
	}
	bc := bhattacharyyaCoeff(p, q)
	return -math.Log(bc)
}
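
// exampleBhattacharyya is a hypothetical usage sketch (not part of the
// original API): p and q must have equal length and each sum to 1. The
// distance is 0 when p == q and grows as the distributions diverge.
func exampleBhattacharyya() float64 {
	p := []float64{0.25, 0.25, 0.5}
	q := []float64{0.1, 0.4, 0.5}
	return Bhattacharyya(p, q) // 0 for identical distributions, +Inf for disjoint support
}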

// CDF returns the empirical cumulative distribution function value of q, that is,
// the fraction of the samples less than or equal to q. The
// exact behavior is determined by the CumulantKind. CDF is theoretically
// the inverse of the Quantile function, though it may not be the actual inverse
// for all values q and CumulantKinds.
//
// The x data must be sorted in increasing order. If weights is nil then all
// of the weights are 1. If weights is not nil, then len(x) must equal len(weights).
//
// CumulantKind behaviors:
//  - Empirical: Returns the lowest fraction for which q is greater than or equal
//  to that fraction of samples
func CDF(q float64, c CumulantKind, x, weights []float64) float64 {
	if weights != nil && len(x) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if floats.HasNaN(x) {
		return math.NaN()
	}
	if !sort.Float64sAreSorted(x) {
		panic("stat: x data are not sorted")
	}

	if q < x[0] {
		return 0
	}
	if q >= x[len(x)-1] {
		return 1
	}

	var sumWeights float64
	if weights == nil {
		sumWeights = float64(len(x))
	} else {
		sumWeights = floats.Sum(weights)
	}

	switch c {
	case Empirical:
		// Sum the weights of the samples less than or equal to q and
		// return that sum as a fraction of the total weight.
		var w float64
		for i, v := range x {
			if v > q {
				return w / sumWeights
			}
			if weights == nil {
				w++
			} else {
				w += weights[i]
			}
		}
		panic("impossible")
	default:
		panic("stat: bad cumulant kind")
	}
}
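
// exampleCDF is a hypothetical sketch (not part of the original API): x must
// already be sorted in increasing order. With four unweighted samples, the
// fraction of samples less than or equal to 0.45 is 2/4.
func exampleCDF() float64 {
	x := []float64{0.1, 0.2, 0.5, 0.9}
	return CDF(0.45, Empirical, x, nil) // 0.5: two of the four samples are <= 0.45
}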

// ChiSquare computes the chi-square distance between the observed frequencies 'obs' and
// expected frequencies 'exp' given by:
//  \sum_i (obs_i-exp_i)^2 / exp_i
//
// The lengths of obs and exp must be equal.
func ChiSquare(obs, exp []float64) float64 {
	if len(obs) != len(exp) {
		panic("stat: slice length mismatch")
	}
	var result float64
	for i, a := range obs {
		b := exp[i]
		if a == 0 && b == 0 {
			continue
		}
		result += (a - b) * (a - b) / b
	}
	return result
}
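
// exampleChiSquare is a hypothetical sketch (not part of the original API)
// comparing observed counts against expected counts; a result near 0
// indicates close agreement.
func exampleChiSquare() float64 {
	obs := []float64{48, 52}
	exp := []float64{50, 50}
	return ChiSquare(obs, exp) // (48-50)^2/50 + (52-50)^2/50 = 0.16
}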

// CircularMean returns the circular mean of the dataset.
//	atan2(\sum_i w_i * sin(alpha_i), \sum_i w_i * cos(alpha_i))
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
func CircularMean(x, weights []float64) float64 {
	if weights != nil && len(x) != len(weights) {
		panic("stat: slice length mismatch")
	}

	var aX, aY float64
	if weights != nil {
		for i, v := range x {
			aX += weights[i] * math.Cos(v)
			aY += weights[i] * math.Sin(v)
		}
	} else {
		for _, v := range x {
			aX += math.Cos(v)
			aY += math.Sin(v)
		}
	}

	return math.Atan2(aY, aX)
}
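
// exampleCircularMean is a hypothetical sketch (not part of the original
// API): angles just above and just below zero average to approximately 0,
// whereas a naive arithmetic mean of the same values would be near π.
func exampleCircularMean() float64 {
	x := []float64{0.1, 2*math.Pi - 0.1}
	return CircularMean(x, nil) // approximately 0, not π
}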

// Correlation returns the weighted correlation between the samples of x and y
// with the given means.
//  sum_i {w_i (x_i - meanX) * (y_i - meanY)} / (stdX * stdY)
// The lengths of x and y must be equal. If weights is nil then all of the
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
func Correlation(x, y, weights []float64) float64 {
	// This is a two-pass corrected implementation. It is an adaptation of the
	// algorithm used in the MeanVariance function, which applies a correction
	// to the typical two-pass approach.

	if len(x) != len(y) {
		panic("stat: slice length mismatch")
	}
	xu := Mean(x, weights)
	yu := Mean(y, weights)
	var (
		sxx           float64
		syy           float64
		sxy           float64
		xcompensation float64
		ycompensation float64
	)
	if weights == nil {
		for i, xv := range x {
			yv := y[i]
			xd := xv - xu
			yd := yv - yu
			sxx += xd * xd
			syy += yd * yd
			sxy += xd * yd
			xcompensation += xd
			ycompensation += yd
		}
		// xcompensation and ycompensation are from Chan et al.,
		// referenced in the MeanVariance function. They are analogous
		// to the second term in (1.7) in that paper.
		sxx -= xcompensation * xcompensation / float64(len(x))
		syy -= ycompensation * ycompensation / float64(len(x))

		return (sxy - xcompensation*ycompensation/float64(len(x))) / math.Sqrt(sxx*syy)
	}

	var sumWeights float64
	for i, xv := range x {
		w := weights[i]
		yv := y[i]
		xd := xv - xu
		wxd := w * xd
		yd := yv - yu
		wyd := w * yd
		sxx += wxd * xd
		syy += wyd * yd
		sxy += wxd * yd
		xcompensation += wxd
		ycompensation += wyd
		sumWeights += w
	}
	// xcompensation and ycompensation are from Chan et al.,
	// referenced in the MeanVariance function. They are analogous
	// to the second term in (1.7) in that paper, except they use
	// the sumWeights instead of the sample count.
	sxx -= xcompensation * xcompensation / sumWeights
	syy -= ycompensation * ycompensation / sumWeights

	return (sxy - xcompensation*ycompensation/sumWeights) / math.Sqrt(sxx*syy)
}
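
// exampleCorrelation is a hypothetical sketch (not part of the original
// API): y is a decreasing linear function of x, so the weighted correlation
// is exactly -1.
func exampleCorrelation() float64 {
	x := []float64{1, 2, 3, 4}
	y := []float64{8, 6, 4, 2}
	return Correlation(x, y, nil) // -1: a perfect negative linear relationship
}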

// Kendall returns the weighted Tau-a Kendall correlation between the
// samples of x and y. The Kendall correlation measures the quantity of
// concordant and discordant pairs of numbers. If weights are specified then
// each pair is weighted by weights[i] * weights[j] and the final sum is
// normalized to stay between -1 and 1.
// The lengths of x and y must be equal. If weights is nil then all of the
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
func Kendall(x, y, weights []float64) float64 {
	if len(x) != len(y) {
		panic("stat: slice length mismatch")
	}

	var (
		cc float64 // number of concordant pairs
		dc float64 // number of discordant pairs
		n  = len(x)
	)

	if weights == nil {
		for i := 0; i < n; i++ {
			for j := i + 1; j < n; j++ {
				if math.Signbit(x[j]-x[i]) == math.Signbit(y[j]-y[i]) {
					cc++
				} else {
					dc++
				}
			}
		}
		return (cc - dc) / float64(n*(n-1)/2)
	}

	var sumWeights float64

	for i := 0; i < n; i++ {
		for j := i + 1; j < n; j++ {
			weight := weights[i] * weights[j]
			if math.Signbit(x[j]-x[i]) == math.Signbit(y[j]-y[i]) {
				cc += weight
			} else {
				dc += weight
			}
			sumWeights += weight
		}
	}
	return (cc - dc) / sumWeights
}
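
// exampleKendall is a hypothetical sketch (not part of the original API):
// every pair of samples is ordered the same way in x and y, so the Tau-a
// correlation is 1; reversing y would give -1.
func exampleKendall() float64 {
	x := []float64{1, 2, 3, 4}
	y := []float64{10, 20, 30, 40}
	return Kendall(x, y, nil) // 1: all pairs are concordant
}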

// Covariance returns the weighted covariance between the samples of x and y.
//  sum_i {w_i (x_i - meanX) * (y_i - meanY)} / (sum_j {w_j} - 1)
// The lengths of x and y must be equal. If weights is nil then all of the
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
func Covariance(x, y, weights []float64) float64 {
	// This is a two-pass corrected implementation. It is an adaptation of the
	// algorithm used in the MeanVariance function, which applies a correction
	// to the typical two-pass approach.

	if len(x) != len(y) {
		panic("stat: slice length mismatch")
	}
	xu := Mean(x, weights)
	yu := Mean(y, weights)
	return covarianceMeans(x, y, weights, xu, yu)
}
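
// exampleCovariance is a hypothetical sketch (not part of the original API):
// with nil weights this reduces to the usual unbiased sample covariance,
// which is positive when x and y increase together.
func exampleCovariance() float64 {
	x := []float64{1, 2, 3}
	y := []float64{2, 4, 6}
	return Covariance(x, y, nil) // 2: ((-1)(-2) + 0 + (1)(2)) / (3 - 1)
}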

// covarianceMeans returns the weighted covariance between x and y with the mean
// of x and y already specified. See the documentation of Covariance for more
// information.
func covarianceMeans(x, y, weights []float64, xu, yu float64) float64 {
	var (
		ss            float64
		xcompensation float64
		ycompensation float64
	)
	if weights == nil {
		for i, xv := range x {
			yv := y[i]
			xd := xv - xu
			yd := yv - yu
			ss += xd * yd
			xcompensation += xd
			ycompensation += yd
		}
		// xcompensation and ycompensation are from Chan et al.,
		// referenced in the MeanVariance function. They are analogous
		// to the second term in (1.7) in that paper.
		return (ss - xcompensation*ycompensation/float64(len(x))) / float64(len(x)-1)
	}

	var sumWeights float64

	for i, xv := range x {
		w := weights[i]
		yv := y[i]
		wxd := w * (xv - xu)
		yd := yv - yu
		ss += wxd * yd
		xcompensation += wxd
		ycompensation += w * yd
		sumWeights += w
	}
	// xcompensation and ycompensation are from Chan et al.,
	// referenced in the MeanVariance function. They are analogous
	// to the second term in (1.7) in that paper, except they use
	// the sumWeights instead of the sample count.
	return (ss - xcompensation*ycompensation/sumWeights) / (sumWeights - 1)
}

// CrossEntropy computes the cross-entropy between the two distributions specified
// in p and q. The natural logarithm is used.
func CrossEntropy(p, q []float64) float64 {
	if len(p) != len(q) {
		panic("stat: slice length mismatch")
	}
	var ce float64
	for i, v := range p {
		if v != 0 {
			ce -= v * math.Log(q[i])
		}
	}
	return ce
}

// Entropy computes the Shannon entropy of a distribution. The natural
// logarithm is used.
//  - sum_i (p_i * log_e(p_i))
func Entropy(p []float64) float64 {
	var e float64
	for _, v := range p {
		if v != 0 { // Entropy needs 0 * log(0) == 0.
			e -= v * math.Log(v)
		}
	}
	return e
}
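
// exampleEntropyRelations is a hypothetical sketch (not part of the original
// API) of the identity
//  CrossEntropy(p, q) = Entropy(p) + KullbackLeibler(p, q)
// for distributions that sum to 1; the returned difference is zero up to
// floating-point rounding.
func exampleEntropyRelations() float64 {
	p := []float64{0.5, 0.25, 0.25}
	q := []float64{0.25, 0.5, 0.25}
	return CrossEntropy(p, q) - Entropy(p) - KullbackLeibler(p, q) // ~0
}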

// ExKurtosis returns the population excess kurtosis of the sample.
// The kurtosis is defined by the fourth central moment (the fourth moment
// about the mean) divided by the squared variance. The excess kurtosis
// subtracts 3.0 so that the excess kurtosis of the normal distribution is
// zero.
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
func ExKurtosis(x, weights []float64) float64 {
	mean, std := MeanStdDev(x, weights)
	if weights == nil {
		var e float64
		for _, v := range x {
			z := (v - mean) / std
			e += z * z * z * z
		}
		mul, offset := kurtosisCorrection(float64(len(x)))
		return e*mul - offset
	}

	var (
		e          float64
		sumWeights float64
	)
	for i, v := range x {
		z := (v - mean) / std
		e += weights[i] * z * z * z * z
		sumWeights += weights[i]
	}
	mul, offset := kurtosisCorrection(sumWeights)
	return e*mul - offset
}

// kurtosisCorrection returns the unbiasing factors for the sample excess
// kurtosis, where n is the number of samples.
// See https://en.wikipedia.org/wiki/Kurtosis
func kurtosisCorrection(n float64) (mul, offset float64) {
	return ((n + 1) / (n - 1)) * (n / (n - 2)) * (1 / (n - 3)), 3 * ((n - 1) / (n - 2)) * ((n - 1) / (n - 3))
}

// GeometricMean returns the weighted geometric mean of the dataset
//  ( \prod_i {x_i ^ w_i} ) ^ (1 / \sum_i w_i)
// This only applies with positive x and positive weights. If weights is nil
// then all of the weights are 1. If weights is not nil, then len(x) must equal
// len(weights).
func GeometricMean(x, weights []float64) float64 {
	if weights == nil {
		var s float64
		for _, v := range x {
			s += math.Log(v)
		}
		s /= float64(len(x))
		return math.Exp(s)
	}
	if len(x) != len(weights) {
		panic("stat: slice length mismatch")
	}
	var (
		s          float64
		sumWeights float64
	)
	for i, v := range x {
		s += weights[i] * math.Log(v)
		sumWeights += weights[i]
	}
	s /= sumWeights
	return math.Exp(s)
}
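
// exampleGeometricMean is a hypothetical sketch (not part of the original
// API): for positive data the geometric mean never exceeds the arithmetic
// mean, with equality only when all values are equal.
func exampleGeometricMean() (geometric, arithmetic float64) {
	x := []float64{1, 4, 16}
	return GeometricMean(x, nil), Mean(x, nil) // 4 and 7
}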

// HarmonicMean returns the weighted harmonic mean of the dataset
//  \sum_i {w_i} / ( \sum_i {w_i / x_i} )
// This only applies with positive x and positive weights.
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
func HarmonicMean(x, weights []float64) float64 {
	if weights != nil && len(x) != len(weights) {
		panic("stat: slice length mismatch")
	}
	// TODO(btracey): Fix this to make it more efficient and avoid allocation.

	// This can be numerically unstable (for example if x is very small).
	// W = \sum_i {w_i}
	// hm = exp(log(W) - log(\sum_i w_i / x_i))

	logs := make([]float64, len(x))
	var W float64
	for i := range x {
		if weights == nil {
			logs[i] = -math.Log(x[i])
			W++
			continue
		}
		logs[i] = math.Log(weights[i]) - math.Log(x[i])
		W += weights[i]
	}

	// Sum all of the logs.
	v := floats.LogSumExp(logs) // This computes log(\sum_i { w_i / x_i}).
	return math.Exp(math.Log(W) - v)
}

// Hellinger computes the distance between the probability distributions p and q given by:
//  \sqrt{ 1 - \sum_i \sqrt{p_i q_i} }
//
// The lengths of p and q must be equal. It is assumed that p and q sum to 1.
func Hellinger(p, q []float64) float64 {
	if len(p) != len(q) {
		panic("stat: slice length mismatch")
	}
	bc := bhattacharyyaCoeff(p, q)
	return math.Sqrt(1 - bc)
}

// Histogram sums up the weighted number of data points in each bin.
// The weight of data point x[i] will be placed into count[j] if
// dividers[j] <= x[i] < dividers[j+1]. The floats.Span function can assist
// with bin creation.
//
// The following conditions on the inputs apply:
//  - The count variable must either be nil or have length of one less than dividers.
//  - The values in dividers must be sorted (use the sort package).
//  - The x values must be sorted.
//  - If weights is nil then all of the weights are 1.
//  - If weights is not nil, then len(x) must equal len(weights).
func Histogram(count, dividers, x, weights []float64) []float64 {
	if weights != nil && len(x) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if count == nil {
		count = make([]float64, len(dividers)-1)
	}
	if len(dividers) < 2 {
		panic("histogram: fewer than two dividers")
	}
	if len(count) != len(dividers)-1 {
		panic("histogram: bin count mismatch")
	}
	if !sort.Float64sAreSorted(dividers) {
		panic("histogram: dividers are not sorted")
	}
	if !sort.Float64sAreSorted(x) {
		panic("histogram: x data are not sorted")
	}
	for i := range count {
		count[i] = 0
	}
	if len(x) == 0 {
		return count
	}
	if x[0] < dividers[0] {
		panic("histogram: minimum x value is less than lowest divider")
	}
	if dividers[len(dividers)-1] <= x[len(x)-1] {
		panic("histogram: maximum x value is greater than or equal to highest divider")
	}

	idx := 0
	comp := dividers[idx+1]
	if weights == nil {
		for _, v := range x {
			if v < comp {
				// Still in the current bucket.
				count[idx]++
				continue
			}
			// Find the next divider where v is less than the divider.
			for j := idx + 1; j < len(count); j++ {
				if v < dividers[j+1] {
					idx = j
					comp = dividers[j+1]
					break
				}
			}
			count[idx]++
		}
		return count
	}

	for i, v := range x {
		if v < comp {
			// Still in the current bucket.
			count[idx] += weights[i]
			continue
		}
		// Find the next divider where v is less than the divider.
		for j := idx + 1; j < len(count); j++ {
			if v < dividers[j+1] {
				idx = j
				comp = dividers[j+1]
				break
			}
		}
		count[idx] += weights[i]
	}
	return count
}
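
// exampleHistogram is a hypothetical sketch (not part of the original API):
// x must be sorted, every value must lie in [dividers[0], dividers[len-1]),
// and a nil count allocates one bin per pair of adjacent dividers.
func exampleHistogram() []float64 {
	x := []float64{0.1, 0.4, 0.5, 1.2, 2.7}
	dividers := []float64{0, 1, 2, 3}
	return Histogram(nil, dividers, x, nil) // [3 1 1]: bins [0,1), [1,2) and [2,3)
}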

// JensenShannon computes the Jensen-Shannon divergence between the distributions
// p and q. The Jensen-Shannon divergence is defined as
//  m = 0.5 * (p + q)
//  JS(p, q) = 0.5 ( KL(p, m) + KL(q, m) )
// Unlike Kullback-Leibler, the Jensen-Shannon divergence is symmetric. The value
// is between 0 and ln(2).
func JensenShannon(p, q []float64) float64 {
	if len(p) != len(q) {
		panic("stat: slice length mismatch")
	}
	var js float64
	for i, v := range p {
		qi := q[i]
		m := 0.5 * (v + qi)
		if v != 0 {
			// Add the KL term from p to m.
			js += 0.5 * v * (math.Log(v) - math.Log(m))
		}
		if qi != 0 {
			// Add the KL term from q to m.
			js += 0.5 * qi * (math.Log(qi) - math.Log(m))
		}
	}
	return js
}

// KolmogorovSmirnov computes the largest distance between two empirical CDFs.
// Each dataset x and y consists of sample locations and counts, xWeights and
// yWeights, respectively.
//
// x and y may have different lengths, though len(x) must equal len(xWeights), and
// len(y) must equal len(yWeights). Both x and y must be sorted.
//
// Special cases are:
//  = 0 if len(x) == len(y) == 0
//  = 1 if len(x) == 0 and len(y) != 0, or len(x) != 0 and len(y) == 0
func KolmogorovSmirnov(x, xWeights, y, yWeights []float64) float64 {
	if xWeights != nil && len(x) != len(xWeights) {
		panic("stat: slice length mismatch")
	}
	if yWeights != nil && len(y) != len(yWeights) {
		panic("stat: slice length mismatch")
	}
	if len(x) == 0 || len(y) == 0 {
		if len(x) == 0 && len(y) == 0 {
			return 0
		}
		return 1
	}

	if floats.HasNaN(x) {
		return math.NaN()
	}
	if floats.HasNaN(y) {
		return math.NaN()
	}

	if !sort.Float64sAreSorted(x) {
		panic("stat: x data are not sorted")
	}
	if !sort.Float64sAreSorted(y) {
		panic("stat: y data are not sorted")
	}

	xWeightsNil := xWeights == nil
	yWeightsNil := yWeights == nil

	var (
		maxDist    float64
		xSum, ySum float64
		xCdf, yCdf float64
		xIdx, yIdx int
	)

	if xWeightsNil {
		xSum = float64(len(x))
	} else {
		xSum = floats.Sum(xWeights)
	}

	if yWeightsNil {
		ySum = float64(len(y))
	} else {
		ySum = floats.Sum(yWeights)
	}

	xVal := x[0]
	yVal := y[0]

	// Algorithm description:
	// The goal is to find the maximum difference in the empirical CDFs for the
	// two datasets. The CDFs are piecewise-constant, and thus the distance
	// between the CDFs will only change at the values themselves.
	//
	// To find the maximum distance, step through the data in ascending order
	// of value between the two datasets. At each step, compute the empirical CDF
	// and compare the local distance with the maximum distance.
	// Due to some corner cases, equal data entries must be tallied simultaneously.
	for {
		switch {
		case xVal < yVal:
			xVal, xCdf, xIdx = updateKS(xIdx, xCdf, xSum, x, xWeights, xWeightsNil)
		case yVal < xVal:
			yVal, yCdf, yIdx = updateKS(yIdx, yCdf, ySum, y, yWeights, yWeightsNil)
		case xVal == yVal:
			newX := x[xIdx]
			newY := y[yIdx]
			if newX < newY {
				xVal, xCdf, xIdx = updateKS(xIdx, xCdf, xSum, x, xWeights, xWeightsNil)
			} else if newY < newX {
				yVal, yCdf, yIdx = updateKS(yIdx, yCdf, ySum, y, yWeights, yWeightsNil)
			} else {
				// Update them both; they'll be equal next time and the right
				// thing will happen.
				xVal, xCdf, xIdx = updateKS(xIdx, xCdf, xSum, x, xWeights, xWeightsNil)
				yVal, yCdf, yIdx = updateKS(yIdx, yCdf, ySum, y, yWeights, yWeightsNil)
			}
		default:
			panic("unreachable")
		}

		dist := math.Abs(xCdf - yCdf)
		if dist > maxDist {
			maxDist = dist
		}

		// Both xCdf and yCdf will equal 1 at the end, so if we have reached the
		// end of either sample list, the distance is as large as it can be.
		if xIdx == len(x) || yIdx == len(y) {
			return maxDist
		}
	}
}
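
// exampleKolmogorovSmirnov is a hypothetical sketch (not part of the original
// API): both samples must be sorted. Identical samples give a distance of 0,
// while samples with disjoint support give the maximum distance of 1.
func exampleKolmogorovSmirnov() float64 {
	x := []float64{1, 2, 3}
	y := []float64{10, 20, 30}
	return KolmogorovSmirnov(x, nil, y, nil) // 1: x's CDF reaches 1 before y's leaves 0
}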

// updateKS gets the next data point from one of the data sets. In doing so, it combines
// the weight of all the data points of equal value. Upon return, val is the new
// value of the data set, newCdf is the total combined CDF up until this point,
// and newIdx is the index of the next location in that sample to examine.
func updateKS(idx int, cdf, sum float64, values, weights []float64, isNil bool) (val, newCdf float64, newIdx int) {
	// Sum up all the weights of consecutive values that are equal.
	if isNil {
		newCdf = cdf + 1/sum
	} else {
		newCdf = cdf + weights[idx]/sum
	}
	newIdx = idx + 1
	for {
		if newIdx == len(values) {
			return values[newIdx-1], newCdf, newIdx
		}
		if values[newIdx-1] != values[newIdx] {
			return values[newIdx], newCdf, newIdx
		}
		if isNil {
			newCdf += 1 / sum
		} else {
			newCdf += weights[newIdx] / sum
		}
		newIdx++
	}
}

// KullbackLeibler computes the Kullback-Leibler divergence between the
// distributions p and q. The natural logarithm is used.
//  sum_i(p_i * log(p_i / q_i))
// Note that the Kullback-Leibler divergence is not symmetric;
// KullbackLeibler(p, q) != KullbackLeibler(q, p) in general.
func KullbackLeibler(p, q []float64) float64 {
	if len(p) != len(q) {
		panic("stat: slice length mismatch")
	}
	var kl float64
	for i, v := range p {
		if v != 0 { // Entropy needs 0 * log(0) == 0.
			kl += v * (math.Log(v) - math.Log(q[i]))
		}
	}
	return kl
}

// LinearRegression computes the best-fit line
//  y = alpha + beta*x
// to the data in x and y with the given weights. If origin is true, the
// regression is forced to pass through the origin.
//
// Specifically, LinearRegression computes the values of alpha and
// beta such that the total residual
//  \sum_i w[i]*(y[i] - alpha - beta*x[i])^2
// is minimized. If origin is true, then alpha is forced to be zero.
//
// The lengths of x and y must be equal. If weights is nil then all of the
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
func LinearRegression(x, y, weights []float64, origin bool) (alpha, beta float64) {
	if len(x) != len(y) {
		panic("stat: slice length mismatch")
	}
	if weights != nil && len(weights) != len(x) {
		panic("stat: slice length mismatch")
	}

	w := 1.0
	if origin {
		var x2Sum, xySum float64
		for i, xi := range x {
			if weights != nil {
				w = weights[i]
			}
			yi := y[i]
			xySum += w * xi * yi
			x2Sum += w * xi * xi
		}
		beta = xySum / x2Sum

		return 0, beta
	}

	xu, xv := MeanVariance(x, weights)
	yu := Mean(y, weights)
	cov := covarianceMeans(x, y, weights, xu, yu)
	beta = cov / xv
	alpha = yu - beta*xu
	return alpha, beta
}
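
// exampleLinearRegression is a hypothetical sketch (not part of the original
// API): the data lie exactly on y = 1 + 2x, so the fit recovers alpha = 1 and
// beta = 2, and RSquared reports a perfect fit of 1.
func exampleLinearRegression() (alpha, beta, r2 float64) {
	x := []float64{1, 2, 3, 4}
	y := []float64{3, 5, 7, 9}
	alpha, beta = LinearRegression(x, y, nil, false)
	r2 = RSquared(x, y, nil, alpha, beta)
	return alpha, beta, r2 // 1, 2, 1
}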

// RSquared returns the coefficient of determination defined as
//  R^2 = 1 - \sum_i w[i]*(y[i] - alpha - beta*x[i])^2 / \sum_i w[i]*(y[i] - mean(y))^2
// for the line
//  y = alpha + beta*x
// and the data in x and y with the given weights.
//
// The lengths of x and y must be equal. If weights is nil then all of the
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
func RSquared(x, y, weights []float64, alpha, beta float64) float64 {
	if len(x) != len(y) {
		panic("stat: slice length mismatch")
	}
	if weights != nil && len(weights) != len(x) {
		panic("stat: slice length mismatch")
	}

	w := 1.0
	yMean := Mean(y, weights)
	var res, tot, d float64
	for i, xi := range x {
		if weights != nil {
			w = weights[i]
		}
		yi := y[i]
		fi := alpha + beta*xi
		d = yi - fi
		res += w * d * d
		d = yi - yMean
		tot += w * d * d
	}
	return 1 - res/tot
}

// RSquaredFrom returns the coefficient of determination defined as
//  R^2 = 1 - \sum_i w[i]*(estimate[i] - value[i])^2 / \sum_i w[i]*(value[i] - mean(values))^2
// for the given estimates and values with the given weights.
//
// The lengths of estimates and values must be equal. If weights is nil then
// all of the weights are 1. If weights is not nil, then len(values) must
// equal len(weights).
func RSquaredFrom(estimates, values, weights []float64) float64 {
	if len(estimates) != len(values) {
		panic("stat: slice length mismatch")
	}
	if weights != nil && len(weights) != len(values) {
		panic("stat: slice length mismatch")
	}

	w := 1.0
	mean := Mean(values, weights)
	var res, tot, d float64
	for i, val := range values {
		if weights != nil {
			w = weights[i]
		}
		d = val - estimates[i]
		res += w * d * d
		d = val - mean
		tot += w * d * d
	}
	return 1 - res/tot
}

// RNoughtSquared returns the coefficient of determination defined as
//  R₀^2 = \sum_i w[i]*(beta*x[i])^2 / \sum_i w[i]*y[i]^2
// for the line
//  y = beta*x
// and the data in x and y with the given weights. RNoughtSquared should
// only be used for best-fit lines regressed through the origin.
//
// The lengths of x and y must be equal. If weights is nil then all of the
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
func RNoughtSquared(x, y, weights []float64, beta float64) float64 {
	if len(x) != len(y) {
		panic("stat: slice length mismatch")
	}
	if weights != nil && len(weights) != len(x) {
		panic("stat: slice length mismatch")
	}

	w := 1.0
	var ssr, tot float64
	for i, xi := range x {
		if weights != nil {
			w = weights[i]
		}
		fi := beta * xi
		ssr += w * fi * fi
		yi := y[i]
		tot += w * yi * yi
	}
	return ssr / tot
}

// Mean computes the weighted mean of the data set.
//  sum_i {w_i * x_i} / sum_i {w_i}
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
func Mean(x, weights []float64) float64 {
	if weights == nil {
		return floats.Sum(x) / float64(len(x))
	}
	if len(x) != len(weights) {
		panic("stat: slice length mismatch")
	}
	var (
		sumValues  float64
		sumWeights float64
	)
	for i, w := range weights {
		sumValues += w * x[i]
		sumWeights += w
	}
	return sumValues / sumWeights
}

// Mode returns the most common value in the dataset specified by x and the
// given weights. Strict float64 equality is used when comparing values, so users
// should take care. If several values are the mode, any of them may be returned.
func Mode(x, weights []float64) (val float64, count float64) {
	if weights != nil && len(x) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if len(x) == 0 {
		return 0, 0
	}
	m := make(map[float64]float64)
	if weights == nil {
		for _, v := range x {
			m[v]++
		}
	} else {
		for i, v := range x {
			m[v] += weights[i]
		}
	}
	var (
		maxCount float64
		max      float64
	)
	for val, count := range m {
		if count > maxCount {
			maxCount = count
			max = val
		}
	}
	return max, maxCount
}

// BivariateMoment computes the weighted mixed moment between the samples x and y.
//  E[(x - μ_x)^r*(y - μ_y)^s]
// No degrees of freedom correction is done.
// The lengths of x and y must be equal. If weights is nil then all of the
// weights are 1. If weights is not nil, then len(x) must equal len(weights).
func BivariateMoment(r, s float64, x, y, weights []float64) float64 {
	meanX := Mean(x, weights)
	meanY := Mean(y, weights)
	if len(x) != len(y) {
		panic("stat: slice length mismatch")
	}
	if weights == nil {
		var m float64
		for i, vx := range x {
			vy := y[i]
			m += math.Pow(vx-meanX, r) * math.Pow(vy-meanY, s)
		}
		return m / float64(len(x))
	}
	if len(weights) != len(x) {
		panic("stat: slice length mismatch")
	}
	var (
		m          float64
		sumWeights float64
	)
	for i, vx := range x {
		vy := y[i]
		w := weights[i]
		m += w * math.Pow(vx-meanX, r) * math.Pow(vy-meanY, s)
		sumWeights += w
	}
	return m / sumWeights
}

// Moment computes the weighted n^th moment of the samples,
//  E[(x - μ)^n]
// No degrees of freedom correction is done.
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
func Moment(moment float64, x, weights []float64) float64 {
	// This also checks that x and weights have the same length.
	mean := Mean(x, weights)
	if weights == nil {
		var m float64
		for _, v := range x {
			m += math.Pow(v-mean, moment)
		}
		return m / float64(len(x))
	}
	var (
		m          float64
		sumWeights float64
	)
	for i, v := range x {
		w := weights[i]
		m += w * math.Pow(v-mean, moment)
		sumWeights += w
	}
	return m / sumWeights
}

// MomentAbout computes the weighted n^th moment of the samples about
// the given mean \mu,
//  E[(x - μ)^n]
// No degrees of freedom correction is done.
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
func MomentAbout(moment float64, x []float64, mean float64, weights []float64) float64 {
	if weights == nil {
		var m float64
		for _, v := range x {
			m += math.Pow(v-mean, moment)
		}
		m /= float64(len(x))
		return m
	}
	if len(weights) != len(x) {
		panic("stat: slice length mismatch")
	}
	var (
		m          float64
		sumWeights float64
	)
	for i, v := range x {
		m += weights[i] * math.Pow(v-mean, moment)
		sumWeights += weights[i]
	}
	return m / sumWeights
}

// Quantile returns the value of x such that the fraction p of samples is
// less than or equal to it. The exact behavior is determined by the
// CumulantKind, and p should be a number between 0 and 1. Quantile is theoretically
// the inverse of the CDF function, though it may not be the actual inverse
// for all values p and CumulantKinds.
//
// The x data must be sorted in increasing order. If weights is nil then all
// of the weights are 1. If weights is not nil, then len(x) must equal len(weights).
//
// CumulantKind behaviors:
//  - Empirical: Returns the lowest value q for which q is greater than or equal
//  to the fraction p of samples
//  - LinInterp: Returns the linearly interpolated value
func Quantile(p float64, c CumulantKind, x, weights []float64) float64 {
	if !(p >= 0 && p <= 1) {
		panic("stat: percentile out of bounds")
	}

	if weights != nil && len(x) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if floats.HasNaN(x) {
		return math.NaN() // This is needed because the algorithm breaks otherwise.
	}
	if !sort.Float64sAreSorted(x) {
		panic("stat: x data are not sorted")
	}

	var sumWeights float64
	if weights == nil {
		sumWeights = float64(len(x))
	} else {
		sumWeights = floats.Sum(weights)
	}
	switch c {
	case Empirical:
		return empiricalQuantile(p, x, weights, sumWeights)
	case LinInterp:
		return linInterpQuantile(p, x, weights, sumWeights)
	default:
		panic("stat: bad cumulant kind")
	}
}
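
// exampleQuantile is a hypothetical sketch (not part of the original API):
// x must be sorted. Empirical returns an actual sample, while LinInterp
// interpolates between neighboring samples.
func exampleQuantile() (empirical, interpolated float64) {
	x := []float64{1, 2, 3, 4}
	return Quantile(0.4, Empirical, x, nil), Quantile(0.4, LinInterp, x, nil) // 2 and 1.6
}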

func empiricalQuantile(p float64, x, weights []float64, sumWeights float64) float64 {
	var cumsum float64
	fidx := p * sumWeights
	for i := range x {
		if weights == nil {
			cumsum++
		} else {
			cumsum += weights[i]
		}
		if cumsum >= fidx {
			return x[i]
		}
	}
	panic("impossible")
}

func linInterpQuantile(p float64, x, weights []float64, sumWeights float64) float64 {
	var cumsum float64
	fidx := p * sumWeights
	for i := range x {
		if weights == nil {
			cumsum++
		} else {
			cumsum += weights[i]
		}
		if cumsum >= fidx {
			if i == 0 {
				return x[0]
			}
			t := cumsum - fidx
			if weights != nil {
				t /= weights[i]
			}
			return t*x[i-1] + (1-t)*x[i]
		}
	}
	panic("impossible")
}

// Skew computes the skewness of the sample data.
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
// When weights sum to 1 or less, a biased variance estimator should be used.
func Skew(x, weights []float64) float64 {
	mean, std := MeanStdDev(x, weights)
	if weights == nil {
		var s float64
		for _, v := range x {
			z := (v - mean) / std
			s += z * z * z
		}
		return s * skewCorrection(float64(len(x)))
	}
	var (
		s          float64
		sumWeights float64
	)
	for i, v := range x {
		z := (v - mean) / std
		s += weights[i] * z * z * z
		sumWeights += weights[i]
	}
	return s * skewCorrection(sumWeights)
}

// skewCorrection returns the unbiasing factor for the sample skewness.
// From: http://www.amstat.org/publications/jse/v19n2/doane.pdf page 7
func skewCorrection(n float64) float64 {
	return (n / (n - 1)) * (1 / (n - 2))
}

// SortWeighted rearranges the data in x along with their corresponding
// weights so that the x data are sorted. The data is sorted in place.
// Weights may be nil, but if weights is non-nil then it must have the same
// length as x.
func SortWeighted(x, weights []float64) {
	if weights == nil {
		sort.Float64s(x)
		return
	}
	if len(x) != len(weights) {
		panic("stat: slice length mismatch")
	}
	sort.Sort(weightSorter{
		x: x,
		w: weights,
	})
}
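
// exampleSortWeighted is a hypothetical sketch (not part of the original
// API): functions such as CDF and Quantile require sorted input, so unsorted
// weighted data can be prepared with SortWeighted, which keeps each weight
// attached to its value.
func exampleSortWeighted() float64 {
	x := []float64{3, 1, 2}
	w := []float64{1, 2, 1}
	SortWeighted(x, w) // x is now [1 2 3] and w is [2 1 1]
	return Quantile(0.5, Empirical, x, w)
}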

type weightSorter struct {
	x []float64
	w []float64
}

func (w weightSorter) Len() int           { return len(w.x) }
func (w weightSorter) Less(i, j int) bool { return w.x[i] < w.x[j] }
func (w weightSorter) Swap(i, j int) {
	w.x[i], w.x[j] = w.x[j], w.x[i]
	w.w[i], w.w[j] = w.w[j], w.w[i]
}

// SortWeightedLabeled rearranges the data in x along with their
// corresponding weights and boolean labels so that the x data are sorted.
// The data is sorted in place. Weights and labels may be nil; if either
// is non-nil it must have the same length as x.
func SortWeightedLabeled(x []float64, labels []bool, weights []float64) {
	if labels == nil {
		SortWeighted(x, weights)
		return
	}
	if weights == nil {
		if len(x) != len(labels) {
			panic("stat: slice length mismatch")
		}
		sort.Sort(labelSorter{
			x: x,
			l: labels,
		})
		return
	}
	if len(x) != len(labels) || len(x) != len(weights) {
		panic("stat: slice length mismatch")
	}
	sort.Sort(weightLabelSorter{
		x: x,
		l: labels,
		w: weights,
	})
}

type labelSorter struct {
	x []float64
	l []bool
}

func (a labelSorter) Len() int           { return len(a.x) }
func (a labelSorter) Less(i, j int) bool { return a.x[i] < a.x[j] }
func (a labelSorter) Swap(i, j int) {
	a.x[i], a.x[j] = a.x[j], a.x[i]
	a.l[i], a.l[j] = a.l[j], a.l[i]
}

type weightLabelSorter struct {
	x []float64
	l []bool
	w []float64
}

func (a weightLabelSorter) Len() int           { return len(a.x) }
func (a weightLabelSorter) Less(i, j int) bool { return a.x[i] < a.x[j] }
func (a weightLabelSorter) Swap(i, j int) {
	a.x[i], a.x[j] = a.x[j], a.x[i]
	a.l[i], a.l[j] = a.l[j], a.l[i]
	a.w[i], a.w[j] = a.w[j], a.w[i]
}

// StdDev returns the sample standard deviation.
func StdDev(x, weights []float64) float64 {
	_, std := MeanStdDev(x, weights)
	return std
}

// MeanStdDev returns the sample mean and unbiased standard deviation.
// When weights sum to 1 or less, a biased variance estimator should be used.
func MeanStdDev(x, weights []float64) (mean, std float64) {
	mean, variance := MeanVariance(x, weights)
	return mean, math.Sqrt(variance)
}

// StdErr returns the standard error in the mean with the given values.
func StdErr(std, sampleSize float64) float64 {
	return std / math.Sqrt(sampleSize)
}

// StdScore returns the standard score (a.k.a. z-score, z-value) for the value x
// with the given mean and standard deviation, i.e.
//  (x - mean) / std
func StdScore(x, mean, std float64) float64 {
	return (x - mean) / std
}

// Variance computes the unbiased weighted sample variance:
//  \sum_i w_i (x_i - mean)^2 / (sum_i w_i - 1)
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
// When weights sum to 1 or less, a biased variance estimator should be used.
func Variance(x, weights []float64) float64 {
	_, variance := MeanVariance(x, weights)
	return variance
}

// MeanVariance computes the sample mean and unbiased variance, where the mean and variance are
//  \sum_i w_i * x_i / (sum_i w_i)
//  \sum_i w_i (x_i - mean)^2 / (sum_i w_i - 1)
// respectively.
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
// When weights sum to 1 or less, a biased variance estimator should be used.
func MeanVariance(x, weights []float64) (mean, variance float64) {
	var (
		unnormalisedVariance float64
		sumWeights           float64
	)
	mean, unnormalisedVariance, sumWeights = meanUnnormalisedVarianceSumWeights(x, weights)
	return mean, unnormalisedVariance / (sumWeights - 1)
}

// PopMeanVariance computes the sample mean and biased variance (also known as
// "population variance"), where the mean and variance are
//  \sum_i w_i * x_i / (sum_i w_i)
//  \sum_i w_i (x_i - mean)^2 / (sum_i w_i)
// respectively.
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
func PopMeanVariance(x, weights []float64) (mean, variance float64) {
	var (
		unnormalisedVariance float64
		sumWeights           float64
	)
	mean, unnormalisedVariance, sumWeights = meanUnnormalisedVarianceSumWeights(x, weights)
	return mean, unnormalisedVariance / sumWeights
}

// PopMeanStdDev returns the sample mean and biased standard deviation
// (also known as "population standard deviation").
func PopMeanStdDev(x, weights []float64) (mean, std float64) {
	mean, variance := PopMeanVariance(x, weights)
	return mean, math.Sqrt(variance)
}

// PopStdDev returns the population standard deviation, i.e., a square root
// of the biased variance estimate.
func PopStdDev(x, weights []float64) float64 {
	_, stDev := PopMeanStdDev(x, weights)
	return stDev
}

// PopVariance computes the biased weighted sample variance:
//  \sum_i w_i (x_i - mean)^2 / (sum_i w_i)
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
func PopVariance(x, weights []float64) float64 {
	_, variance := PopMeanVariance(x, weights)
	return variance
}
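
// exampleVarianceEstimators is a hypothetical sketch (not part of the
// original API) contrasting the unbiased estimator (dividing by n-1) with
// the biased population estimator (dividing by n) on the same data.
func exampleVarianceEstimators() (unbiased, biased float64) {
	x := []float64{1, 2, 3, 4, 5}
	return Variance(x, nil), PopVariance(x, nil) // 2.5 and 2
}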

func meanUnnormalisedVarianceSumWeights(x, weights []float64) (mean, unnormalisedVariance, sumWeights float64) {
	// This uses the corrected two-pass algorithm (1.7), from "Algorithms for computing
	// the sample variance: Analysis and recommendations" by Chan, Tony F., Gene H. Golub,
	// and Randall J. LeVeque.

	// Note that this will panic if the slice lengths do not match.
	mean = Mean(x, weights)
	var (
		ss           float64
		compensation float64
	)
	if weights == nil {
		for _, v := range x {
			d := v - mean
			ss += d * d
			compensation += d
		}
		unnormalisedVariance = ss - compensation*compensation/float64(len(x))
		return mean, unnormalisedVariance, float64(len(x))
	}

	for i, v := range x {
		w := weights[i]
		d := v - mean
		wd := w * d
		ss += wd * d
		compensation += wd
		sumWeights += w
	}
	unnormalisedVariance = ss - compensation*compensation/sumWeights
	return mean, unnormalisedVariance, sumWeights
}