1package basic
2
3import (
4	"context"
5	"encoding/base64"
6	"fmt"
7	"testing"
8
9	httptransport "github.com/go-kit/kit/transport/http"
10)
11
12func TestWithBasicAuth(t *testing.T) {
13	requiredUser := "test-user"
14	requiredPassword := "test-pass"
15	realm := "test realm"
16
17	type want struct {
18		result interface{}
19		err    error
20	}
21	tests := []struct {
22		name       string
23		authHeader interface{}
24		want       want
25	}{
26		{"Isn't valid with nil header", nil, want{nil, AuthError{realm}}},
27		{"Isn't valid with non-string header", 42, want{nil, AuthError{realm}}},
28		{"Isn't valid without authHeader", "", want{nil, AuthError{realm}}},
29		{"Isn't valid for wrong user", makeAuthString("wrong-user", requiredPassword), want{nil, AuthError{realm}}},
30		{"Isn't valid for wrong password", makeAuthString(requiredUser, "wrong-password"), want{nil, AuthError{realm}}},
31		{"Is valid for correct creds", makeAuthString(requiredUser, requiredPassword), want{true, nil}},
32	}
33	for _, tt := range tests {
34		t.Run(tt.name, func(t *testing.T) {
35			ctx := context.WithValue(context.TODO(), httptransport.ContextKeyRequestAuthorization, tt.authHeader)
36
37			result, err := AuthMiddleware(requiredUser, requiredPassword, realm)(passedValidation)(ctx, nil)
38			if result != tt.want.result || err != tt.want.err {
39				t.Errorf("WithBasicAuth() = result: %v, err: %v, want result: %v, want error: %v", result, err, tt.want.result, tt.want.err)
40			}
41		})
42	}
43}
44
// makeAuthString builds an HTTP Basic Authorization header value: the literal
// prefix "Basic " followed by the standard (padded) base64 encoding of
// "user:password".
func makeAuthString(user string, password string) string {
	credentials := fmt.Sprintf("%s:%s", user, password)
	encoded := base64.StdEncoding.EncodeToString([]byte(credentials))
	return "Basic " + encoded
}
49
50func passedValidation(ctx context.Context, request interface{}) (response interface{}, err error) {
51	return true, nil
52}
53