1package route
2
3import (
4	"net/http"
5	"reflect"
6	"testing"
7
8	"github.com/fabiolb/fabio/auth"
9)
10
// testAuth is a stub auth scheme with a fixed verdict: Authorized ignores the
// request and the response writer and always reports the configured value.
type testAuth struct {
	ok bool // verdict returned by every call to Authorized
}

// Authorized returns the verdict stored in the stub.
func (a *testAuth) Authorized(r *http.Request, w http.ResponseWriter) bool {
	verdict := a.ok
	return verdict
}
18
// responseWriter is a minimal in-memory http.ResponseWriter used by tests to
// capture what an auth scheme writes back to the client.
type responseWriter struct {
	header  http.Header // allocated lazily by Header
	code    int         // last status passed to WriteHeader (0 if never called)
	written []byte      // concatenation of every slice passed to Write
}

// Compile-time check that responseWriter satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*responseWriter)(nil)

// Header returns the response header map. The map is allocated on first use,
// so a zero-value responseWriter (e.g. &responseWriter{}) can be handed to code
// that calls Header().Set without panicking on a nil map.
func (rw *responseWriter) Header() http.Header {
	if rw.header == nil {
		rw.header = make(http.Header)
	}
	return rw.header
}

// Write appends b to the recorded body and reports len(b) bytes written.
// Per the io.Writer contract, n must be the count from this call (at most len(b)),
// not the cumulative size of the buffer.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.written = append(rw.written, b...)
	return len(b), nil
}

// WriteHeader records statusCode as the response status.
func (rw *responseWriter) WriteHeader(statusCode int) {
	rw.code = statusCode
}
37
38func TestTarget_Authorized(t *testing.T) {
39	tests := []struct {
40		name        string
41		authScheme  string
42		authSchemes map[string]auth.AuthScheme
43		out         bool
44	}{
45		{
46			name:       "matches correct auth scheme",
47			authScheme: "mybasic",
48			authSchemes: map[string]auth.AuthScheme{
49				"mybasic": &testAuth{ok: true},
50			},
51			out: true,
52		},
53		{
54			name:       "returns true when scheme is empty",
55			authScheme: "",
56			authSchemes: map[string]auth.AuthScheme{
57				"mybasic": &testAuth{ok: false},
58			},
59			out: true,
60		},
61		{
62			name:       "returns false when scheme is unknown",
63			authScheme: "foobar",
64			authSchemes: map[string]auth.AuthScheme{
65				"mybasic": &testAuth{ok: true},
66			},
67			out: false,
68		},
69	}
70
71	for _, tt := range tests {
72		t.Run(tt.name, func(t *testing.T) {
73			target := &Target{
74				AuthScheme: tt.authScheme,
75			}
76
77			if got, want := target.Authorized(&http.Request{}, &responseWriter{}, tt.authSchemes), tt.out; !reflect.DeepEqual(got, want) {
78				t.Errorf("got %v want %v", got, want)
79			}
80		})
81	}
82}
83