1// Copyright 2015 go-swagger maintainers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package security
16
17import (
18	"context"
19	"net/http"
20	"testing"
21
22	"github.com/go-openapi/errors"
23	"github.com/stretchr/testify/assert"
24)
25
26var basicAuthHandler = UserPassAuthentication(func(user, pass string) (interface{}, error) {
27	if user == "admin" && pass == "123456" {
28		return "admin", nil
29	}
30	return "", errors.Unauthenticated("basic")
31})
32
33func TestValidBasicAuth(t *testing.T) {
34	ba := BasicAuth(basicAuthHandler)
35
36	req, _ := http.NewRequest("GET", "/blah", nil)
37	req.SetBasicAuth("admin", "123456")
38	ok, usr, err := ba.Authenticate(req)
39
40	assert.NoError(t, err)
41	assert.True(t, ok)
42	assert.Equal(t, "admin", usr)
43}
44
45func TestInvalidBasicAuth(t *testing.T) {
46	ba := BasicAuth(basicAuthHandler)
47
48	req, _ := http.NewRequest("GET", "/blah", nil)
49	req.SetBasicAuth("admin", "admin")
50	ok, usr, err := ba.Authenticate(req)
51
52	assert.Error(t, err)
53	assert.True(t, ok)
54	assert.Equal(t, "", usr)
55
56	assert.NotEmpty(t, FailedBasicAuth(req))
57	assert.Equal(t, DefaultRealmName, FailedBasicAuth(req))
58}
59
60func TestMissingbasicAuth(t *testing.T) {
61	ba := BasicAuth(basicAuthHandler)
62
63	req, _ := http.NewRequest("GET", "/blah", nil)
64
65	ok, usr, err := ba.Authenticate(req)
66	assert.NoError(t, err)
67	assert.False(t, ok)
68	assert.Equal(t, nil, usr)
69
70	assert.NotEmpty(t, FailedBasicAuth(req))
71	assert.Equal(t, DefaultRealmName, FailedBasicAuth(req))
72}
73
74func TestNoRequestBasicAuth(t *testing.T) {
75	ba := BasicAuth(basicAuthHandler)
76
77	ok, usr, err := ba.Authenticate("token")
78
79	assert.NoError(t, err)
80	assert.False(t, ok)
81	assert.Nil(t, usr)
82}
83
// secTestKey is an unexported type for the context keys used by the *Ctx
// tests, so the keys cannot collide with context keys from other packages.
type secTestKey uint8

// Context keys for the *Ctx tests:
//   - original: set on the request by the test before authenticating.
//   - extra: added by the context-aware handler when credentials are accepted.
//   - reason: added by the context-aware handler when credentials are rejected.
const (
	original secTestKey = 0
	extra    secTestKey = 1
	reason   secTestKey = 2
)

// Marker strings stored under the keys above, so each test can tell exactly
// which value reached the request context.
const (
	wisdom      = "The man who is swimming against the stream knows the strength of it."
	extraWisdom = "Our greatest glory is not in never falling, but in rising every time we fall."
	expReason   = "I like the dreams of the future better than the history of the past."
)
97
98var basicAuthHandlerCtx = UserPassAuthenticationCtx(func(ctx context.Context, user, pass string) (context.Context, interface{}, error) {
99	if user == "admin" && pass == "123456" {
100		return context.WithValue(ctx, extra, extraWisdom), "admin", nil
101	}
102	return context.WithValue(ctx, reason, expReason), "", errors.Unauthenticated("basic")
103})
104
105func TestValidBasicAuthCtx(t *testing.T) {
106	ba := BasicAuthCtx(basicAuthHandlerCtx)
107
108	req, _ := http.NewRequest("GET", "/blah", nil)
109	req = req.WithContext(context.WithValue(req.Context(), original, wisdom))
110	req.SetBasicAuth("admin", "123456")
111	ok, usr, err := ba.Authenticate(req)
112
113	assert.NoError(t, err)
114	assert.True(t, ok)
115	assert.Equal(t, "admin", usr)
116	assert.Equal(t, wisdom, req.Context().Value(original))
117	assert.Equal(t, extraWisdom, req.Context().Value(extra))
118	assert.Nil(t, req.Context().Value(reason))
119}
120
121func TestInvalidBasicAuthCtx(t *testing.T) {
122	ba := BasicAuthCtx(basicAuthHandlerCtx)
123
124	req, _ := http.NewRequest("GET", "/blah", nil)
125	req = req.WithContext(context.WithValue(req.Context(), original, wisdom))
126	req.SetBasicAuth("admin", "admin")
127	ok, usr, err := ba.Authenticate(req)
128
129	assert.Error(t, err)
130	assert.True(t, ok)
131	assert.Equal(t, "", usr)
132	assert.Equal(t, wisdom, req.Context().Value(original))
133	assert.Nil(t, req.Context().Value(extra))
134	assert.Equal(t, expReason, req.Context().Value(reason))
135}
136
137func TestMissingbasicAuthCtx(t *testing.T) {
138	ba := BasicAuthCtx(basicAuthHandlerCtx)
139
140	req, _ := http.NewRequest("GET", "/blah", nil)
141	req = req.WithContext(context.WithValue(req.Context(), original, wisdom))
142	ok, usr, err := ba.Authenticate(req)
143	assert.NoError(t, err)
144	assert.False(t, ok)
145	assert.Equal(t, nil, usr)
146
147	assert.Equal(t, wisdom, req.Context().Value(original))
148	assert.Nil(t, req.Context().Value(extra))
149	assert.Nil(t, req.Context().Value(reason))
150}
151
152func TestNoRequestBasicAuthCtx(t *testing.T) {
153	ba := BasicAuthCtx(basicAuthHandlerCtx)
154
155	ok, usr, err := ba.Authenticate("token")
156
157	assert.NoError(t, err)
158	assert.False(t, ok)
159	assert.Nil(t, usr)
160}
161