1/*
2Copyright 2016 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package rest
18
19import (
20	"fmt"
21	"net/http"
22	"reflect"
23	"strconv"
24	"testing"
25
26	clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
27)
28
29func TestAuthPluginWrapTransport(t *testing.T) {
30	if err := RegisterAuthProviderPlugin("pluginA", pluginAProvider); err != nil {
31		t.Errorf("Unexpected error: failed to register pluginA: %v", err)
32	}
33	if err := RegisterAuthProviderPlugin("pluginB", pluginBProvider); err != nil {
34		t.Errorf("Unexpected error: failed to register pluginB: %v", err)
35	}
36	if err := RegisterAuthProviderPlugin("pluginFail", pluginFailProvider); err != nil {
37		t.Errorf("Unexpected error: failed to register pluginFail: %v", err)
38	}
39	testCases := []struct {
40		useWrapTransport bool
41		plugin           string
42		expectErr        bool
43		expectPluginA    bool
44		expectPluginB    bool
45	}{
46		{false, "", false, false, false},
47		{false, "pluginA", false, true, false},
48		{false, "pluginB", false, false, true},
49		{false, "pluginFail", true, false, false},
50		{false, "pluginUnknown", true, false, false},
51	}
52	for i, tc := range testCases {
53		c := Config{}
54		if tc.useWrapTransport {
55			// Specify an existing WrapTransport in the config to make sure that
56			// plugins play nicely.
57			c.WrapTransport = func(rt http.RoundTripper) http.RoundTripper {
58				return &wrapTransport{rt}
59			}
60		}
61		if len(tc.plugin) != 0 {
62			c.AuthProvider = &clientcmdapi.AuthProviderConfig{Name: tc.plugin}
63		}
64		tConfig, err := c.TransportConfig()
65		if err != nil {
66			// Unknown/bad plugins are expected to fail here.
67			if !tc.expectErr {
68				t.Errorf("%d. Did not expect errors loading Auth Plugin: %q. Got: %v", i, tc.plugin, err)
69			}
70			continue
71		}
72		var fullyWrappedTransport http.RoundTripper
73		fullyWrappedTransport = &emptyTransport{}
74		if tConfig.WrapTransport != nil {
75			fullyWrappedTransport = tConfig.WrapTransport(&emptyTransport{})
76		}
77		res, err := fullyWrappedTransport.RoundTrip(&http.Request{})
78		if err != nil {
79			t.Errorf("%d. Unexpected error in RoundTrip: %v", i, err)
80			continue
81		}
82		hasWrapTransport := res.Header.Get("wrapTransport") == "Y"
83		hasPluginA := res.Header.Get("pluginA") == "Y"
84		hasPluginB := res.Header.Get("pluginB") == "Y"
85		if hasWrapTransport != tc.useWrapTransport {
86			t.Errorf("%d. Expected Existing config.WrapTransport: %t; Got: %t", i, tc.useWrapTransport, hasWrapTransport)
87		}
88		if hasPluginA != tc.expectPluginA {
89			t.Errorf("%d. Expected Plugin A: %t; Got: %t", i, tc.expectPluginA, hasPluginA)
90		}
91		if hasPluginB != tc.expectPluginB {
92			t.Errorf("%d. Expected Plugin B: %t; Got: %t", i, tc.expectPluginB, hasPluginB)
93		}
94	}
95}
96
97func TestAuthPluginPersist(t *testing.T) {
98	// register pluginA by a different name so we don't collide across tests.
99	if err := RegisterAuthProviderPlugin("pluginA2", pluginAProvider); err != nil {
100		t.Errorf("Unexpected error: failed to register pluginA: %v", err)
101	}
102	if err := RegisterAuthProviderPlugin("pluginPersist", pluginPersistProvider); err != nil {
103		t.Errorf("Unexpected error: failed to register pluginPersist: %v", err)
104	}
105	fooBarConfig := map[string]string{"foo": "bar"}
106	testCases := []struct {
107		plugin                       string
108		startingConfig               map[string]string
109		expectedConfigAfterLogin     map[string]string
110		expectedConfigAfterRoundTrip map[string]string
111	}{
112		// non-persisting plugins should work fine without modifying config.
113		{"pluginA2", map[string]string{}, map[string]string{}, map[string]string{}},
114		{"pluginA2", fooBarConfig, fooBarConfig, fooBarConfig},
115		// plugins that persist config should be able to persist when they want.
116		{
117			"pluginPersist",
118			map[string]string{},
119			map[string]string{
120				"login": "Y",
121			},
122			map[string]string{
123				"login":      "Y",
124				"roundTrips": "1",
125			},
126		},
127		{
128			"pluginPersist",
129			map[string]string{
130				"login":      "Y",
131				"roundTrips": "123",
132			},
133			map[string]string{
134				"login":      "Y",
135				"roundTrips": "123",
136			},
137			map[string]string{
138				"login":      "Y",
139				"roundTrips": "124",
140			},
141		},
142	}
143	for i, tc := range testCases {
144		cfg := &clientcmdapi.AuthProviderConfig{
145			Name:   tc.plugin,
146			Config: tc.startingConfig,
147		}
148		persister := &inMemoryPersister{make(map[string]string)}
149		persister.Persist(tc.startingConfig)
150		plugin, err := GetAuthProvider("127.0.0.1", cfg, persister)
151		if err != nil {
152			t.Errorf("%d. Unexpected error: failed to get plugin %q: %v", i, tc.plugin, err)
153		}
154		if err := plugin.Login(); err != nil {
155			t.Errorf("%d. Unexpected error calling Login() w/ plugin %q: %v", i, tc.plugin, err)
156		}
157		// Make sure the plugin persisted what we expect after Login().
158		if !reflect.DeepEqual(persister.savedConfig, tc.expectedConfigAfterLogin) {
159			t.Errorf("%d. Unexpected persisted config after calling %s.Login(): \nGot:\n%v\nExpected:\n%v",
160				i, tc.plugin, persister.savedConfig, tc.expectedConfigAfterLogin)
161		}
162		if _, err := plugin.WrapTransport(&emptyTransport{}).RoundTrip(&http.Request{}); err != nil {
163			t.Errorf("%d. Unexpected error round-tripping w/ plugin %q: %v", i, tc.plugin, err)
164		}
165		// Make sure the plugin persisted what we expect after RoundTrip().
166		if !reflect.DeepEqual(persister.savedConfig, tc.expectedConfigAfterRoundTrip) {
167			t.Errorf("%d. Unexpected persisted config after calling %s.WrapTransport.RoundTrip(): \nGot:\n%v\nExpected:\n%v",
168				i, tc.plugin, persister.savedConfig, tc.expectedConfigAfterLogin)
169		}
170	}
171
172}
173
// emptyTransport is the innermost RoundTripper used by these tests. It performs
// no I/O. It answers every request with a blank http.Response whose Header map
// is already allocated, so wrapping RoundTrippers can add header values to it.
type emptyTransport struct{}

func (*emptyTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	return &http.Response{Header: http.Header{}}, nil
}
184
// wrapTransport stands in for a caller-supplied Config.WrapTransport. It
// forwards to rt and tags a successful response with the header
// "wrapTransport: Y".
type wrapTransport struct {
	rt http.RoundTripper
}

// RoundTrip delegates to the wrapped transport. On error it returns a nil
// response together with that error.
func (w *wrapTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := w.rt.RoundTrip(req)
	if err == nil {
		resp.Header.Add("wrapTransport", "Y")
		return resp, nil
	}
	return nil, err
}
198
// wrapTransportA is the transport installed by pluginA. It forwards to rt and
// tags a successful response with the header "pluginA: Y".
type wrapTransportA struct {
	rt http.RoundTripper
}

// RoundTrip delegates to the wrapped transport. On error it returns a nil
// response together with that error.
func (w *wrapTransportA) RoundTrip(req *http.Request) (*http.Response, error) {
	next := w.rt
	response, rtErr := next.RoundTrip(req)
	if rtErr != nil {
		return nil, rtErr
	}
	response.Header.Add("pluginA", "Y")
	return response, nil
}
212
213type pluginA struct{}
214
215func (*pluginA) WrapTransport(rt http.RoundTripper) http.RoundTripper {
216	return &wrapTransportA{rt}
217}
218
219func (*pluginA) Login() error { return nil }
220
221func pluginAProvider(string, map[string]string, AuthProviderConfigPersister) (AuthProvider, error) {
222	return &pluginA{}, nil
223}
224
// wrapTransportB is the transport installed by pluginB. It forwards to rt and
// tags a successful response with the header "pluginB: Y".
type wrapTransportB struct {
	rt http.RoundTripper
}

// RoundTrip delegates to the wrapped transport. On error it returns a nil
// response together with that error.
func (w *wrapTransportB) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := w.rt.RoundTrip(req)
	if err != nil {
		return nil, err
	}
	hdr := resp.Header
	hdr.Add("pluginB", "Y")
	return resp, nil
}
238
239type pluginB struct{}
240
241func (*pluginB) WrapTransport(rt http.RoundTripper) http.RoundTripper {
242	return &wrapTransportB{rt}
243}
244
245func (*pluginB) Login() error { return nil }
246
247func pluginBProvider(string, map[string]string, AuthProviderConfigPersister) (AuthProvider, error) {
248	return &pluginB{}, nil
249}
250
251// pluginFailProvider simulates a registered AuthPlugin that fails to load.
252func pluginFailProvider(string, map[string]string, AuthProviderConfigPersister) (AuthProvider, error) {
253	return nil, fmt.Errorf("Failed to load AuthProvider")
254}
255
// inMemoryPersister is the persister passed to GetAuthProvider in these tests.
// It keeps the most recently persisted config in memory for inspection.
type inMemoryPersister struct {
	savedConfig map[string]string
}

// Persist replaces savedConfig with a fresh copy of config, so later changes
// to config are not reflected in savedConfig. A nil config is stored as an
// empty, non-nil map. It always returns nil.
func (p *inMemoryPersister) Persist(config map[string]string) error {
	snapshot := map[string]string{}
	for key, value := range config {
		snapshot[key] = value
	}
	p.savedConfig = snapshot
	return nil
}
267
268// wrapTransportPersist increments the "roundTrips" entry from the config when
269// roundTrip is called.
270type wrapTransportPersist struct {
271	rt        http.RoundTripper
272	config    map[string]string
273	persister AuthProviderConfigPersister
274}
275
276func (w *wrapTransportPersist) RoundTrip(req *http.Request) (*http.Response, error) {
277	roundTrips := 0
278	if rtVal, ok := w.config["roundTrips"]; ok {
279		var err error
280		roundTrips, err = strconv.Atoi(rtVal)
281		if err != nil {
282			return nil, err
283		}
284	}
285	roundTrips++
286	w.config["roundTrips"] = fmt.Sprintf("%d", roundTrips)
287	if err := w.persister.Persist(w.config); err != nil {
288		return nil, err
289	}
290	return w.rt.RoundTrip(req)
291}
292
293type pluginPersist struct {
294	config    map[string]string
295	persister AuthProviderConfigPersister
296}
297
298func (p *pluginPersist) WrapTransport(rt http.RoundTripper) http.RoundTripper {
299	return &wrapTransportPersist{rt, p.config, p.persister}
300}
301
302// Login sets the config entry "login" to "Y".
303func (p *pluginPersist) Login() error {
304	p.config["login"] = "Y"
305	p.persister.Persist(p.config)
306	return nil
307}
308
309func pluginPersistProvider(_ string, config map[string]string, persister AuthProviderConfigPersister) (AuthProvider, error) {
310	return &pluginPersist{config, persister}, nil
311}
312