// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package stat

import (
	"math"
	"sort"
)

// ROC returns paired true positive rate (TPR) and false positive rate
// (FPR) values corresponding to cutoff points on the receiver operating
// characteristic (ROC) curve obtained when y is treated as a binary
// classifier for classes with weights. The cutoff thresholds used to
// calculate the ROC are returned in thresh such that tpr[i] and fpr[i]
// are the true and false positive rates for y >= thresh[i]. The results
// are ordered by descending threshold: thresh holds the cutoffs in
// reverse (descending) order.
//
// The input y and cutoffs must be sorted in ascending order, and values
// in y must correspond to values in classes and weights.
// SortWeightedLabeled can be used to sort y together with classes and
// weights. ROC panics if y or cutoffs is not sorted, or if the lengths of
// y and classes differ.
//
// For a given cutoff value, observations corresponding to entries in y
// greater than or equal to the cutoff value are classified as true, while
// those less than the cutoff value are classified as false. These
// assigned class labels are compared with the true values in the classes
// slice and used to calculate the FPR and TPR.
//
// If weights is nil, all weights are treated as 1. If weights is not nil
// it must have the same length as y and classes, otherwise ROC will panic.
//
// If cutoffs is nil or empty, all possible cutoffs are calculated,
// resulting in fpr and tpr having length one greater than the number of
// unique values in y. The generated thresholds are the unique values of y
// plus a final +Inf, so in that case thresh[0] is +Inf. If cutoffs is
// empty but has a capacity of at least len(y)+1, its backing array is
// reused to hold the generated thresholds and the returned thresh shares
// that storage. Otherwise fpr and tpr will be returned with the same
// length as cutoffs, and the provided cutoffs are copied rather than
// modified. floats.Span can be used to generate equally spaced cutoffs.
//
// If y is empty, ROC returns nil for tpr, fpr and thresh. If the total
// weight of the true (or false) class is zero, the corresponding rates
// involve a division by zero and are typically NaN.
//
// More details about ROC curves are available at
// https://en.wikipedia.org/wiki/Receiver_operating_characteristic
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
	if len(y) != len(classes) {
		panic("stat: slice length mismatch")
	}
	if weights != nil && len(y) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if !sort.Float64sAreSorted(y) {
		panic("stat: input must be sorted ascending")
	}
	if !sort.Float64sAreSorted(cutoffs) {
		panic("stat: cutoff values must be sorted ascending")
	}
	if len(y) == 0 {
		return nil, nil, nil
	}
	if len(cutoffs) == 0 {
		// Reuse the caller's backing array when it is large enough;
		// the result then aliases that storage.
		if cutoffs == nil || cap(cutoffs) < len(y)+1 {
			cutoffs = make([]float64, len(y)+1)
		} else {
			cutoffs = cutoffs[:len(y)+1]
		}
		// Choose all possible cutoffs for unique values in y.
		// Because y is sorted, duplicates are adjacent, so comparing
		// each value with its predecessor is enough to skip them.
		bin := 0
		cutoffs[bin] = y[0]
		for i, u := range y[1:] {
			// i indexes y[1:], so y[i] is the element preceding u.
			if u == y[i] {
				continue
			}
			bin++
			cutoffs[bin] = u
		}
		// A final +Inf cutoff gives the point at which (for finite y)
		// no observation is classified as true.
		cutoffs[bin+1] = math.Inf(1)
		cutoffs = cutoffs[:bin+2]
	} else {
		// Don't mutate the provided cutoffs.
		tmp := cutoffs
		cutoffs = make([]float64, len(cutoffs))
		copy(cutoffs, tmp)
	}

	// Until the conversion below, tpr[k] and fpr[k] accumulate the total
	// weight of true-class and false-class observations with
	// y < cutoffs[k] (false negatives and true negatives respectively).
	tpr = make([]float64, len(cutoffs))
	fpr = make([]float64, len(cutoffs))
	var bin int
	var nPos, nNeg float64
	for i, u := range classes {
		// Advance to the first cutoff greater than y[i] (or to the last
		// cutoff), skipping empty bins. The running totals are carried
		// forward into each bin passed over, which keeps the counts
		// cumulative.
		for bin < len(cutoffs)-1 && y[i] >= cutoffs[bin] {
			bin++
			tpr[bin] = tpr[bin-1]
			fpr[bin] = fpr[bin-1]
		}
		// Exactly one of posWeight and negWeight carries this
		// observation's weight, depending on its true class; the
		// other is zero.
		posWeight, negWeight := 1.0, 0.0
		if weights != nil {
			posWeight = weights[i]
		}
		if !u {
			posWeight, negWeight = negWeight, posWeight
		}
		nPos += posWeight
		nNeg += negWeight
		// Count false negatives (in tpr) and true negatives (in fpr).
		// The condition only fails when bin is the last cutoff and
		// y[i] >= that cutoff. Such an observation is classified as
		// true at every cutoff, so it adds nothing to the negative
		// counts.
		if y[i] < cutoffs[bin] {
			tpr[bin] += posWeight
			fpr[bin] += negWeight
		}
	}

	invNeg := 1 / nNeg
	invPos := 1 / nPos
	// Convert negative counts to TPR and FPR:
	// rate = 1 - (negative count / class total).
	// Bins after the final bin are skipped. Their cutoffs exceed every
	// value in y, so they were never written and remain zero. Zero is
	// the rate when every observation is classified as false.
	for i := range tpr[:bin+1] {
		// Prevent fused float operations by
		// making explicit float64 conversions.
		tpr[i] = 1 - float64(tpr[i]*invPos)
		fpr[i] = 1 - float64(fpr[i]*invNeg)
	}
	// Reverse all three slices so results are ordered by descending
	// threshold.
	for i, j := 0, len(tpr)-1; i < j; i, j = i+1, j-1 {
		tpr[i], tpr[j] = tpr[j], tpr[i]
		fpr[i], fpr[j] = fpr[j], fpr[i]
	}
	for i, j := 0, len(cutoffs)-1; i < j; i, j = i+1, j-1 {
		cutoffs[i], cutoffs[j] = cutoffs[j], cutoffs[i]
	}

	return tpr, fpr, cutoffs
}