1// Copyright ©2016 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package stat
6
7import (
8	"math"
9	"testing"
10
11	"gonum.org/v1/gonum/floats"
12)
13
14func TestROC(t *testing.T) {
15	const tol = 1e-14
16
17	cases := []struct {
18		y          []float64
19		c          []bool
20		w          []float64
21		cutoffs    []float64
22		wantTPR    []float64
23		wantFPR    []float64
24		wantThresh []float64
25	}{
26		// Test cases were informed by using sklearn metrics.roc_curve when
27		// cutoffs is nil, but all test cases (including when cutoffs is not
28		// nil) were calculated manually.
29		// Some differences exist between unweighted ROCs from our function
30		// and metrics.roc_curve which appears to use integer cutoffs in that
31		// case. sklearn also appears to do some magic that trims leading zeros
32		// sometimes.
33		{ // 0
34			y:          []float64{0, 3, 5, 6, 7.5, 8},
35			c:          []bool{false, true, false, true, true, true},
36			wantTPR:    []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
37			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1},
38			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
39		},
40		{ // 1
41			y:          []float64{0, 3, 5, 6, 7.5, 8},
42			c:          []bool{false, true, false, true, true, true},
43			w:          []float64{4, 1, 6, 3, 2, 2},
44			wantTPR:    []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1},
45			wantFPR:    []float64{0, 0, 0, 0, 0.6, 0.6, 1},
46			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
47		},
48		{ // 2
49			y:          []float64{0, 3, 5, 6, 7.5, 8},
50			c:          []bool{false, true, false, true, true, true},
51			cutoffs:    []float64{-1, 2, 4, 6, 8},
52			wantTPR:    []float64{0.25, 0.75, 0.75, 1, 1},
53			wantFPR:    []float64{0, 0, 0.5, 0.5, 1},
54			wantThresh: []float64{8, 6, 4, 2, -1},
55		},
56		{ // 3
57			y:          []float64{0, 3, 5, 6, 7.5, 8},
58			c:          []bool{false, true, false, true, true, true},
59			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
60			wantTPR:    []float64{0.25, 0.5, 0.75, 0.75, 0.75, 1, 1, 1, 1},
61			wantFPR:    []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
62			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
63		},
64		{ // 4
65			y:          []float64{0, 3, 5, 6, 7.5, 8},
66			c:          []bool{false, true, false, true, true, true},
67			w:          []float64{4, 1, 6, 3, 2, 2},
68			cutoffs:    []float64{-1, 2, 4, 6, 8},
69			wantTPR:    []float64{0.25, 0.875, 0.875, 1, 1},
70			wantFPR:    []float64{0, 0, 0.6, 0.6, 1},
71			wantThresh: []float64{8, 6, 4, 2, -1},
72		},
73		{ // 5
74			y:          []float64{0, 3, 5, 6, 7.5, 8},
75			c:          []bool{false, true, false, true, true, true},
76			w:          []float64{4, 1, 6, 3, 2, 2},
77			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
78			wantTPR:    []float64{0.25, 0.5, 0.875, 0.875, 0.875, 1, 1, 1, 1},
79			wantFPR:    []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
80			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
81		},
82		{ // 6
83			y:          []float64{0, 3, 6, 6, 6, 8},
84			c:          []bool{false, true, false, true, true, true},
85			wantTPR:    []float64{0, 0.25, 0.75, 1, 1},
86			wantFPR:    []float64{0, 0, 0.5, 0.5, 1},
87			wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
88		},
89		{ // 7
90			y:          []float64{0, 3, 6, 6, 6, 8},
91			c:          []bool{false, true, false, true, true, true},
92			w:          []float64{4, 1, 6, 3, 2, 2},
93			wantTPR:    []float64{0, 0.25, 0.875, 1, 1},
94			wantFPR:    []float64{0, 0, 0.6, 0.6, 1},
95			wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
96		},
97		{ // 8
98			y:          []float64{0, 3, 6, 6, 6, 8},
99			c:          []bool{false, true, false, true, true, true},
100			cutoffs:    []float64{-1, 2, 4, 6, 8},
101			wantTPR:    []float64{0.25, 0.75, 0.75, 1, 1},
102			wantFPR:    []float64{0, 0.5, 0.5, 0.5, 1},
103			wantThresh: []float64{8, 6, 4, 2, -1},
104		},
105		{ // 9
106			y:          []float64{0, 3, 6, 6, 6, 8},
107			c:          []bool{false, true, false, true, true, true},
108			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
109			wantTPR:    []float64{0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1, 1},
110			wantFPR:    []float64{0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
111			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
112		},
113		{ // 10
114			y:          []float64{0, 3, 6, 6, 6, 8},
115			c:          []bool{false, true, false, true, true, true},
116			w:          []float64{4, 1, 6, 3, 2, 2},
117			cutoffs:    []float64{-1, 2, 4, 6, 8},
118			wantTPR:    []float64{0.25, 0.875, 0.875, 1, 1},
119			wantFPR:    []float64{0, 0.6, 0.6, 0.6, 1},
120			wantThresh: []float64{8, 6, 4, 2, -1},
121		},
122		{ // 11
123			y:          []float64{0, 3, 6, 6, 6, 8},
124			c:          []bool{false, true, false, true, true, true},
125			w:          []float64{4, 1, 6, 3, 2, 2},
126			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
127			wantTPR:    []float64{0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1, 1},
128			wantFPR:    []float64{0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
129			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
130		},
131		{ // 12
132			y:          []float64{0.1, 0.35, 0.4, 0.8},
133			c:          []bool{true, false, true, false},
134			wantTPR:    []float64{0, 0, 0.5, 0.5, 1},
135			wantFPR:    []float64{0, 0.5, 0.5, 1, 1},
136			wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
137		},
138		{ // 13
139			y:          []float64{0.1, 0.35, 0.4, 0.8},
140			c:          []bool{false, false, true, true},
141			wantTPR:    []float64{0, 0.5, 1, 1, 1},
142			wantFPR:    []float64{0, 0, 0, 0.5, 1},
143			wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
144		},
145		{ // 14
146			y:          []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
147			c:          []bool{false, true, false, false, true, true, false},
148			cutoffs:    []float64{-1, 2.5, 5, 7.5, 10},
149			wantTPR:    []float64{0, 0, 0, 0, 1},
150			wantFPR:    []float64{0.25, 0.25, 0.25, 0.25, 1},
151			wantThresh: []float64{10, 7.5, 5, 2.5, -1},
152		},
153		{ // 15
154			y:          []float64{1, 2},
155			c:          []bool{false, false},
156			wantTPR:    []float64{math.NaN(), math.NaN(), math.NaN()},
157			wantFPR:    []float64{0, 0.5, 1},
158			wantThresh: []float64{math.Inf(1), 2, 1},
159		},
160		{ // 16
161			y:          []float64{1, 2},
162			c:          []bool{false, false},
163			cutoffs:    []float64{-1, 2},
164			wantTPR:    []float64{math.NaN(), math.NaN()},
165			wantFPR:    []float64{0.5, 1},
166			wantThresh: []float64{2, -1},
167		},
168		{ // 17
169			y:          []float64{1, 2},
170			c:          []bool{false, false},
171			cutoffs:    []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
172			wantTPR:    []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
173			wantFPR:    []float64{0.5, 0.5, 0.5, 0.5, 0.5, 1},
174			wantThresh: []float64{2, 1.8, 1.6, 1.4, 1.2, 0},
175		},
176		{ // 18
177			y:          []float64{1},
178			c:          []bool{false},
179			wantTPR:    []float64{math.NaN(), math.NaN()},
180			wantFPR:    []float64{0, 1},
181			wantThresh: []float64{math.Inf(1), 1},
182		},
183		{ // 19
184			y:          []float64{1},
185			c:          []bool{false},
186			cutoffs:    []float64{-1, 1},
187			wantTPR:    []float64{math.NaN(), math.NaN()},
188			wantFPR:    []float64{1, 1},
189			wantThresh: []float64{1, -1},
190		},
191		{ // 20
192			y:          []float64{1},
193			c:          []bool{true},
194			wantTPR:    []float64{0, 1},
195			wantFPR:    []float64{math.NaN(), math.NaN()},
196			wantThresh: []float64{math.Inf(1), 1},
197		},
198		{ // 21
199			y:          []float64{},
200			c:          []bool{},
201			wantTPR:    nil,
202			wantFPR:    nil,
203			wantThresh: nil,
204		},
205		{ // 22
206			y:          []float64{},
207			c:          []bool{},
208			cutoffs:    []float64{-1, 2.5, 5, 7.5, 10},
209			wantTPR:    nil,
210			wantFPR:    nil,
211			wantThresh: nil,
212		},
213		{ // 23
214			y:          []float64{0.1, 0.35, 0.4, 0.8},
215			c:          []bool{true, false, true, false},
216			cutoffs:    []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1},
217			wantTPR:    []float64{0, 0, 0, 0.5, 0.5, 1, 1},
218			wantFPR:    []float64{0, 0, 0.5, 0.5, 1, 1, 1},
219			wantThresh: []float64{1, 0.9, 0.8, 0.4, 0.35, 0.1, -1},
220		},
221		{ // 24
222			y:          []float64{0.1, 0.35, 0.4, 0.8},
223			c:          []bool{true, false, true, false},
224			cutoffs:    []float64{math.Inf(-1), 0.1, 0.36, 0.8},
225			wantTPR:    []float64{0, 0.5, 1, 1},
226			wantFPR:    []float64{0.5, 0.5, 1, 1},
227			wantThresh: []float64{0.8, 0.36, 0.1, math.Inf(-1)},
228		},
229		{ // 25
230			y:          []float64{0, 3, 5, 6, 7.5, 8},
231			c:          []bool{false, true, false, true, true, true},
232			cutoffs:    make([]float64, 0, 10),
233			wantTPR:    []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
234			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1},
235			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
236		},
237		{ // 26
238			y:          []float64{0.1, 0.35, 0.4, 0.8},
239			c:          []bool{true, false, true, false},
240			cutoffs:    []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1, 1.1, 1.2},
241			wantTPR:    []float64{0, 0, 0, 0, 0, 0.5, 0.5, 1, 1},
242			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1, 1, 1},
243			wantThresh: []float64{1.2, 1.1, 1, 0.9, 0.8, 0.4, 0.35, 0.1, -1},
244		},
245	}
246	for i, test := range cases {
247		gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w)
248		if !floats.Same(gotTPR, test.wantTPR) && !floats.EqualApprox(gotTPR, test.wantTPR, tol) {
249			t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR)
250		}
251		if !floats.Same(gotFPR, test.wantFPR) && !floats.EqualApprox(gotFPR, test.wantFPR, tol) {
252			t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR)
253		}
254		if !floats.Same(gotThresh, test.wantThresh) {
255			t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh)
256		}
257	}
258}
259