1package auth
2
3import (
4	"encoding/base64"
5	"fmt"
6	"io/ioutil"
7	"net/http"
8	"os"
9	"reflect"
10	"strings"
11	"testing"
12	"time"
13
14	"github.com/fabiolb/fabio/config"
15	"github.com/fabiolb/fabio/uuid"
16)
17
// responseWriter is a minimal in-memory http.ResponseWriter used by these
// tests to capture what an AuthScheme writes: headers, status code and body.
type responseWriter struct {
	header  http.Header
	code    int
	written []byte
}

// Compile-time check that responseWriter satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*responseWriter)(nil)

// Header returns the response header map. It is allocated lazily on first
// use so that a zero-value responseWriter is ready to use.
func (rw *responseWriter) Header() http.Header {
	if rw.header == nil {
		rw.header = make(http.Header)
	}
	return rw.header
}

// Write appends b to the captured body. Per the io.Writer contract it
// returns the number of bytes consumed from b (always len(b)), not the
// total size of the body accumulated so far.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.written = append(rw.written, b...)
	return len(b), nil
}

// WriteHeader records the status code passed by the handler.
func (rw *responseWriter) WriteHeader(statusCode int) {
	rw.code = statusCode
}
39
40func createBasicAuthFile(contents string) (string, error) {
41	dir, err := ioutil.TempDir("", "basicauth")
42
43	if err != nil {
44		return "", fmt.Errorf("could not create temp dir: %s", err)
45	}
46
47	filename := fmt.Sprintf("%s/%s", dir, uuid.NewUUID())
48
49	err = ioutil.WriteFile(filename, []byte(contents), 0666)
50
51	if err != nil {
52		return "", fmt.Errorf("could not write password file: %s", err)
53	}
54
55	return filename, nil
56}
57
58func createBasicAuth(user string, password string) (AuthScheme, error) {
59	contents := fmt.Sprintf("%s:%s", user, password)
60
61	filename, err := createBasicAuthFile(contents)
62
63	a, err := newBasicAuth(config.BasicAuth{
64		File:  filename,
65		Realm: "testrealm",
66	})
67
68	if err != nil {
69		return nil, fmt.Errorf("could not create basic auth: %s", err)
70	}
71
72	return a, nil
73}
74
75func TestNewBasicAuth(t *testing.T) {
76
77	t.Run("should create a basic auth scheme from the supplied config", func(t *testing.T) {
78		filename, err := createBasicAuthFile("foo:bar")
79
80		if err != nil {
81			t.Error(err)
82		}
83
84		_, err = newBasicAuth(config.BasicAuth{
85			File: filename,
86		})
87
88		if err != nil {
89			t.Error(err)
90		}
91	})
92
93	t.Run("should log a warning when credentials are malformed", func(t *testing.T) {
94		filename, err := createBasicAuthFile("foosdlijdgohdgdbar")
95
96		if err != nil {
97			t.Error(err)
98		}
99
100		_, err = newBasicAuth(config.BasicAuth{
101			File: filename,
102		})
103
104		if err != nil {
105			t.Error(err)
106		}
107	})
108}
109
110func TestBasic_Authorised(t *testing.T) {
111	basicAuth, err := createBasicAuth("foo", "bar")
112	creds := []byte("foo:bar")
113
114	if err != nil {
115		t.Fatal(err)
116	}
117
118	tests := []struct {
119		name string
120		req  *http.Request
121		res  http.ResponseWriter
122		out  bool
123	}{
124		{
125			"correct credentials should be authorized",
126			&http.Request{
127				Header: http.Header{
128					"Authorization": []string{fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString(creds))},
129				},
130			},
131			&responseWriter{},
132			true,
133		},
134		{
135			"incorrect credentials should not be authorized",
136			&http.Request{
137				Header: http.Header{
138					"Authorization": []string{fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("baz:blarg")))},
139				},
140			},
141			&responseWriter{},
142			false,
143		},
144		{
145			"missing Authorization header should not be authorized",
146			&http.Request{
147				Header: http.Header{},
148			},
149			&responseWriter{},
150			false,
151		},
152		{
153			"malformed Authorization header should not be authorized",
154			&http.Request{
155				Header: http.Header{
156					"Authorization": []string{"malformed"},
157				},
158			},
159			&responseWriter{},
160			false,
161		},
162	}
163
164	for _, tt := range tests {
165		t.Run(tt.name, func(t *testing.T) {
166			if got, want := basicAuth.Authorized(tt.req, tt.res), tt.out; !reflect.DeepEqual(got, want) {
167				t.Errorf("got %v want %v", got, want)
168			}
169		})
170	}
171}
172
173func TestBasic_Authorised_should_fail_without_htpasswd_file(t *testing.T) {
174	filename, err := createBasicAuthFile("foo:bar")
175	if err != nil {
176		t.Error(err)
177	}
178
179	a, err := newBasicAuth(config.BasicAuth{
180		File:    filename,
181		Refresh: time.Second,
182	})
183	if err != nil {
184		t.Error(err)
185	}
186
187	creds := []byte("foo:bar")
188	r := &http.Request{
189		Header: http.Header{
190			"Authorization": []string{fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString(creds))},
191		},
192	}
193
194	w := &responseWriter{}
195
196	t.Run("should authorize against supplied htpasswd file", func(t *testing.T) {
197		if got, want := a.Authorized(r, w), true; !reflect.DeepEqual(got, want) {
198			t.Errorf("got %v want %v", got, want)
199		}
200	})
201
202	if err := os.Remove(filename); err != nil {
203		t.Fatalf("removing htpasswd file: %s", err)
204	}
205
206	time.Sleep(2 * time.Second) // ensure htpasswd file refresh happend
207
208	t.Run("should not authorize after removing htpasswd file", func(t *testing.T) {
209		if got, want := a.Authorized(r, w), false; !reflect.DeepEqual(got, want) {
210			t.Errorf("got %v want %v", got, want)
211		}
212	})
213}
214
215func TestBasic_Authorized_should_set_www_realm_header(t *testing.T) {
216	basicAuth, err := createBasicAuth("foo", "bar")
217
218	if err != nil {
219		t.Fatal(err)
220	}
221
222	rw := &responseWriter{}
223
224	_ = basicAuth.Authorized(&http.Request{Header: http.Header{}}, rw)
225
226	got := rw.Header().Get("WWW-Authenticate")
227	want := `Basic realm="testrealm"`
228
229	if strings.Compare(got, want) != 0 {
230		t.Errorf("got '%s', want '%s'", got, want)
231	}
232}
233