1// Copyright ©2016 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package stat
6
7import (
8	"math"
9	"sort"
10)
11
// ROC returns paired true positive rate (TPR) and false positive rate
// (FPR) values corresponding to cutoff points on the receiver operator
// characteristic (ROC) curve obtained when y is treated as a binary
// classifier for classes with weights. The cutoff thresholds used to
// calculate the ROC are returned in thresh such that tpr[i] and fpr[i]
// are the true and false positive rates for y >= thresh[i]. The values
// in thresh are in descending order.
//
// The input y and cutoffs must be sorted ascending, and values in y must
// correspond to values in classes and weights. SortWeightedLabeled can be
// used to sort y together with classes and weights. ROC panics if y or
// cutoffs is not sorted.
//
// For a given cutoff value, observations corresponding to entries in y
// greater than or equal to the cutoff value are classified as true, while
// those less than the cutoff value are classified as false. These
// assigned class labels are compared with the true values in the classes
// slice and used to calculate the FPR and TPR.
//
// If weights is nil, all weights are treated as 1. If weights is not nil
// it must have the same length as y and classes, otherwise ROC will panic.
// ROC also panics if y and classes have different lengths.
//
// If cutoffs is nil or empty, all possible cutoffs are calculated,
// resulting in fpr and tpr having length one greater than the number of
// unique values in y. In that case thresh holds +Inf followed by the
// unique values of y in descending order. If cutoffs is empty but has
// capacity for at least len(y)+1 elements, its backing array is reused
// to hold thresh. Otherwise fpr and tpr will be returned with the same
// length as cutoffs, and thresh is a reversed copy of cutoffs; the
// provided cutoffs are not modified. floats.Span can be used to generate
// equally spaced cutoffs.
//
// If y is empty, ROC returns nil slices. If the total weight of the
// positive (or negative) class is zero, the corresponding rate is
// undefined and may be returned as NaN.
//
// More details about ROC curves are available at
// https://en.wikipedia.org/wiki/Receiver_operating_characteristic
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
	if len(y) != len(classes) {
		panic("stat: slice length mismatch")
	}
	if weights != nil && len(y) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if !sort.Float64sAreSorted(y) {
		panic("stat: input must be sorted ascending")
	}
	if !sort.Float64sAreSorted(cutoffs) {
		panic("stat: cutoff values must be sorted ascending")
	}
	if len(y) == 0 {
		return nil, nil, nil
	}
	if len(cutoffs) == 0 {
		// Reuse the caller's backing array when it can hold every
		// unique value of y plus the trailing +Inf; a nil slice has
		// zero capacity and always takes the allocating branch.
		if cap(cutoffs) < len(y)+1 {
			cutoffs = make([]float64, len(y)+1)
		} else {
			cutoffs = cutoffs[:len(y)+1]
		}
		// Choose all possible cutoffs for unique values in y.
		bin := 0
		cutoffs[bin] = y[0]
		for i, u := range y[1:] {
			if u == y[i] {
				continue
			}
			bin++
			cutoffs[bin] = u
		}
		// A +Inf cutoff classifies every observation as false.
		cutoffs[bin+1] = math.Inf(1)
		cutoffs = cutoffs[:bin+2]
	} else {
		// Don't mutate the provided cutoffs.
		tmp := cutoffs
		cutoffs = make([]float64, len(cutoffs))
		copy(cutoffs, tmp)
	}

	tpr = make([]float64, len(cutoffs))
	fpr = make([]float64, len(cutoffs))
	// For each cutoff, accumulate the weight of positive (in tpr) and
	// negative (in fpr) observations with y below that cutoff. Because
	// y and cutoffs are sorted, one pass suffices: each bin starts from
	// the running totals of the bin before it.
	var bin int
	var nPos, nNeg float64
	for i, u := range classes {
		// Update the bin until it matches the next y value,
		// carrying the running totals through empty bins.
		for bin < len(cutoffs)-1 && y[i] >= cutoffs[bin] {
			bin++
			tpr[bin] = tpr[bin-1]
			fpr[bin] = fpr[bin-1]
		}
		posWeight, negWeight := 1.0, 0.0
		if weights != nil {
			posWeight = weights[i]
		}
		if !u {
			posWeight, negWeight = negWeight, posWeight
		}
		nPos += posWeight
		nNeg += negWeight
		// Count false negatives (in tpr) and true negatives (in fpr).
		// An observation at or above the last cutoff is classified as
		// true for every cutoff, so it is not counted anywhere.
		if y[i] < cutoffs[bin] {
			tpr[bin] += posWeight
			fpr[bin] += negWeight
		}
	}

	// Convert negative counts to TPR and FPR as 1 - count/total.
	// Dividing, rather than multiplying by a precomputed reciprocal,
	// is correctly rounded and maps count == total to exactly zero;
	// for example 49*(1/49) != 1 in float64.
	// Bins after the last one reached have cutoffs above every value
	// in y, so nothing is classified as true there and their zero
	// values are already the correct rates.
	for i := range tpr[:bin+1] {
		// Explicit float64 conversions round the quotient first,
		// preventing fused floating-point operations.
		tpr[i] = 1 - float64(tpr[i]/nPos)
		fpr[i] = 1 - float64(fpr[i]/nNeg)
	}

	// Reverse all three slices (they share a length) so that the
	// thresholds are descending.
	for i, j := 0, len(cutoffs)-1; i < j; i, j = i+1, j-1 {
		tpr[i], tpr[j] = tpr[j], tpr[i]
		fpr[i], fpr[j] = fpr[j], fpr[i]
		cutoffs[i], cutoffs[j] = cutoffs[j], cutoffs[i]
	}

	return tpr, fpr, cutoffs
}
130