1package basic 2 3import ( 4 "context" 5 "encoding/base64" 6 "fmt" 7 "testing" 8 9 httptransport "github.com/go-kit/kit/transport/http" 10) 11 12func TestWithBasicAuth(t *testing.T) { 13 requiredUser := "test-user" 14 requiredPassword := "test-pass" 15 realm := "test realm" 16 17 type want struct { 18 result interface{} 19 err error 20 } 21 tests := []struct { 22 name string 23 authHeader interface{} 24 want want 25 }{ 26 {"Isn't valid with nil header", nil, want{nil, AuthError{realm}}}, 27 {"Isn't valid with non-string header", 42, want{nil, AuthError{realm}}}, 28 {"Isn't valid without authHeader", "", want{nil, AuthError{realm}}}, 29 {"Isn't valid for wrong user", makeAuthString("wrong-user", requiredPassword), want{nil, AuthError{realm}}}, 30 {"Isn't valid for wrong password", makeAuthString(requiredUser, "wrong-password"), want{nil, AuthError{realm}}}, 31 {"Is valid for correct creds", makeAuthString(requiredUser, requiredPassword), want{true, nil}}, 32 } 33 for _, tt := range tests { 34 t.Run(tt.name, func(t *testing.T) { 35 ctx := context.WithValue(context.TODO(), httptransport.ContextKeyRequestAuthorization, tt.authHeader) 36 37 result, err := AuthMiddleware(requiredUser, requiredPassword, realm)(passedValidation)(ctx, nil) 38 if result != tt.want.result || err != tt.want.err { 39 t.Errorf("WithBasicAuth() = result: %v, err: %v, want result: %v, want error: %v", result, err, tt.want.result, tt.want.err) 40 } 41 }) 42 } 43} 44 45func makeAuthString(user string, password string) string { 46 data := []byte(fmt.Sprintf("%s:%s", user, password)) 47 return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString(data)) 48} 49 50func passedValidation(ctx context.Context, request interface{}) (response interface{}, err error) { 51 return true, nil 52} 53