1/*
2Copyright 2016 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package rest
18
19import (
20	"context"
21	"errors"
22	"fmt"
23	"io"
24	"net"
25	"net/http"
26	"path/filepath"
27	"reflect"
28	"strings"
29	"testing"
30	"time"
31
32	v1 "k8s.io/api/core/v1"
33	"k8s.io/apimachinery/pkg/runtime"
34	"k8s.io/apimachinery/pkg/runtime/schema"
35	"k8s.io/apimachinery/pkg/util/diff"
36	"k8s.io/client-go/kubernetes/scheme"
37	clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
38	"k8s.io/client-go/transport"
39	"k8s.io/client-go/util/flowcontrol"
40
41	fuzz "github.com/google/gofuzz"
42	"github.com/stretchr/testify/assert"
43)
44
45func TestIsConfigTransportTLS(t *testing.T) {
46	testCases := []struct {
47		Config       *Config
48		TransportTLS bool
49	}{
50		{
51			Config:       &Config{},
52			TransportTLS: false,
53		},
54		{
55			Config: &Config{
56				Host: "https://localhost",
57			},
58			TransportTLS: true,
59		},
60		{
61			Config: &Config{
62				Host: "localhost",
63				TLSClientConfig: TLSClientConfig{
64					CertFile: "foo",
65				},
66			},
67			TransportTLS: true,
68		},
69		{
70			Config: &Config{
71				Host: "///:://localhost",
72				TLSClientConfig: TLSClientConfig{
73					CertFile: "foo",
74				},
75			},
76			TransportTLS: false,
77		},
78		{
79			Config: &Config{
80				Host: "1.2.3.4:567",
81				TLSClientConfig: TLSClientConfig{
82					Insecure: true,
83				},
84			},
85			TransportTLS: true,
86		},
87	}
88	for _, testCase := range testCases {
89		if err := SetKubernetesDefaults(testCase.Config); err != nil {
90			t.Errorf("setting defaults failed for %#v: %v", testCase.Config, err)
91			continue
92		}
93		useTLS := IsConfigTransportTLS(*testCase.Config)
94		if testCase.TransportTLS != useTLS {
95			t.Errorf("expected %v for %#v", testCase.TransportTLS, testCase.Config)
96		}
97	}
98}
99
100func TestSetKubernetesDefaultsUserAgent(t *testing.T) {
101	config := &Config{}
102	if err := SetKubernetesDefaults(config); err != nil {
103		t.Errorf("unexpected error: %v", err)
104	}
105	if !strings.Contains(config.UserAgent, "kubernetes/") {
106		t.Errorf("no user agent set: %#v", config)
107	}
108}
109
110func TestAdjustVersion(t *testing.T) {
111	assert := assert.New(t)
112	assert.Equal("1.2.3", adjustVersion("1.2.3-alpha4"))
113	assert.Equal("1.2.3", adjustVersion("1.2.3-alpha"))
114	assert.Equal("1.2.3", adjustVersion("1.2.3"))
115	assert.Equal("unknown", adjustVersion(""))
116}
117
118func TestAdjustCommit(t *testing.T) {
119	assert := assert.New(t)
120	assert.Equal("1234567", adjustCommit("1234567890"))
121	assert.Equal("123456", adjustCommit("123456"))
122	assert.Equal("unknown", adjustCommit(""))
123}
124
125func TestAdjustCommand(t *testing.T) {
126	assert := assert.New(t)
127	assert.Equal("beans", adjustCommand(filepath.Join("home", "bob", "Downloads", "beans")))
128	assert.Equal("beans", adjustCommand(filepath.Join(".", "beans")))
129	assert.Equal("beans", adjustCommand("beans"))
130	assert.Equal("unknown", adjustCommand(""))
131}
132
133func TestBuildUserAgent(t *testing.T) {
134	assert.New(t).Equal(
135		"lynx/nicest (beos/itanium) kubernetes/baaaaaaaaad",
136		buildUserAgent(
137			"lynx", "nicest",
138			"beos", "itanium", "baaaaaaaaad"))
139}
140
141// This function untestable since it doesn't accept arguments.
142func TestDefaultKubernetesUserAgent(t *testing.T) {
143	assert.New(t).Contains(DefaultKubernetesUserAgent(), "kubernetes")
144}
145
146func TestRESTClientRequires(t *testing.T) {
147	if _, err := RESTClientFor(&Config{Host: "127.0.0.1", ContentConfig: ContentConfig{NegotiatedSerializer: scheme.Codecs}}); err == nil {
148		t.Errorf("unexpected non-error")
149	}
150	if _, err := RESTClientFor(&Config{Host: "127.0.0.1", ContentConfig: ContentConfig{GroupVersion: &v1.SchemeGroupVersion}}); err == nil {
151		t.Errorf("unexpected non-error")
152	}
153	if _, err := RESTClientFor(&Config{Host: "127.0.0.1", ContentConfig: ContentConfig{GroupVersion: &v1.SchemeGroupVersion, NegotiatedSerializer: scheme.Codecs}}); err != nil {
154		t.Errorf("unexpected error: %v", err)
155	}
156}
157
// fakeLimiter is a stub flowcontrol.RateLimiter.
//
// TryAccept always succeeds. Saturation and QPS return the values stored in
// the exported fields, and Stop and Accept are no-ops.
type fakeLimiter struct {
	FakeSaturation float64
	FakeQPS        float32
}

// TryAccept always grants the request.
func (l *fakeLimiter) TryAccept() bool { return true }

// Saturation reports the configured FakeSaturation.
func (l *fakeLimiter) Saturation() float64 { return l.FakeSaturation }

// QPS reports the configured FakeQPS.
func (l *fakeLimiter) QPS() float32 { return l.FakeQPS }

// Stop does nothing.
func (l *fakeLimiter) Stop() {}

// Accept does nothing and never blocks.
func (l *fakeLimiter) Accept() {}
178
179type fakeCodec struct{}
180
181func (c *fakeCodec) Decode([]byte, *schema.GroupVersionKind, runtime.Object) (runtime.Object, *schema.GroupVersionKind, error) {
182	return nil, nil, nil
183}
184
185func (c *fakeCodec) Encode(obj runtime.Object, stream io.Writer) error {
186	return nil
187}
188
// fakeRoundTripper is an http.RoundTripper stub that performs no I/O. It
// answers every request with a nil response and a nil error.
type fakeRoundTripper struct{}

// RoundTrip ignores the request and returns (nil, nil).
func (*fakeRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	return nil, nil
}

// fakeWrapperFunc discards the transport it is given and always returns a
// fresh *fakeRoundTripper, so tests can recognise it by its result.
var fakeWrapperFunc = func(_ http.RoundTripper) http.RoundTripper {
	return new(fakeRoundTripper)
}
198
199type fakeNegotiatedSerializer struct{}
200
201func (n *fakeNegotiatedSerializer) SupportedMediaTypes() []runtime.SerializerInfo {
202	return nil
203}
204
205func (n *fakeNegotiatedSerializer) EncoderForVersion(serializer runtime.Encoder, gv runtime.GroupVersioner) runtime.Encoder {
206	return &fakeCodec{}
207}
208
209func (n *fakeNegotiatedSerializer) DecoderToVersion(serializer runtime.Decoder, gv runtime.GroupVersioner) runtime.Decoder {
210	return &fakeCodec{}
211}
212
var (
	// fakeDialerError is the sentinel every fakeDialFunc call fails with.
	// Tests compare it to tell whether a Config's Dial still points at
	// fakeDialFunc.
	fakeDialerError = errors.New("fakedialer")

	// fakeDialFunc never opens a connection. It ignores its arguments and
	// returns a nil net.Conn together with fakeDialerError.
	fakeDialFunc = func(_ context.Context, _, _ string) (net.Conn, error) {
		return nil, fakeDialerError
	}
)
217
218type fakeAuthProviderConfigPersister struct{}
219
220func (fakeAuthProviderConfigPersister) Persist(map[string]string) error {
221	return fakeAuthProviderConfigPersisterError
222}
223
224var fakeAuthProviderConfigPersisterError = errors.New("fakeAuthProviderConfigPersisterError")
225
226func TestAnonymousConfig(t *testing.T) {
227	f := fuzz.New().NilChance(0.0).NumElements(1, 1)
228	f.Funcs(
229		func(r *runtime.Codec, f fuzz.Continue) {
230			codec := &fakeCodec{}
231			f.Fuzz(codec)
232			*r = codec
233		},
234		func(r *http.RoundTripper, f fuzz.Continue) {
235			roundTripper := &fakeRoundTripper{}
236			f.Fuzz(roundTripper)
237			*r = roundTripper
238		},
239		func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) {
240			*fn = fakeWrapperFunc
241		},
242		func(fn *transport.WrapperFunc, f fuzz.Continue) {
243			*fn = fakeWrapperFunc
244		},
245		func(r *runtime.NegotiatedSerializer, f fuzz.Continue) {
246			serializer := &fakeNegotiatedSerializer{}
247			f.Fuzz(serializer)
248			*r = serializer
249		},
250		func(r *flowcontrol.RateLimiter, f fuzz.Continue) {
251			limiter := &fakeLimiter{}
252			f.Fuzz(limiter)
253			*r = limiter
254		},
255		// Authentication does not require fuzzer
256		func(r *AuthProviderConfigPersister, f fuzz.Continue) {},
257		func(r *clientcmdapi.AuthProviderConfig, f fuzz.Continue) {
258			r.Config = map[string]string{}
259		},
260		// Dial does not require fuzzer
261		func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) {},
262	)
263	for i := 0; i < 20; i++ {
264		original := &Config{}
265		f.Fuzz(original)
266		actual := AnonymousClientConfig(original)
267		expected := *original
268
269		// this is the list of known security related fields, add to this list if a new field
270		// is added to Config, update AnonymousClientConfig to preserve the field otherwise.
271		expected.Impersonate = ImpersonationConfig{}
272		expected.BearerToken = ""
273		expected.BearerTokenFile = ""
274		expected.Username = ""
275		expected.Password = ""
276		expected.AuthProvider = nil
277		expected.AuthConfigPersister = nil
278		expected.ExecProvider = nil
279		expected.TLSClientConfig.CertData = nil
280		expected.TLSClientConfig.CertFile = ""
281		expected.TLSClientConfig.KeyData = nil
282		expected.TLSClientConfig.KeyFile = ""
283		expected.Transport = nil
284		expected.WrapTransport = nil
285
286		if actual.Dial != nil {
287			_, actualError := actual.Dial(context.Background(), "", "")
288			_, expectedError := expected.Dial(context.Background(), "", "")
289			if !reflect.DeepEqual(expectedError, actualError) {
290				t.Fatalf("AnonymousClientConfig dropped the Dial field")
291			}
292		} else {
293			actual.Dial = nil
294			expected.Dial = nil
295		}
296
297		if !reflect.DeepEqual(*actual, expected) {
298			t.Fatalf("AnonymousClientConfig dropped unexpected fields, identify whether they are security related or not: %s", diff.ObjectGoPrintDiff(expected, actual))
299		}
300	}
301}
302
303func TestCopyConfig(t *testing.T) {
304	f := fuzz.New().NilChance(0.0).NumElements(1, 1)
305	f.Funcs(
306		func(r *runtime.Codec, f fuzz.Continue) {
307			codec := &fakeCodec{}
308			f.Fuzz(codec)
309			*r = codec
310		},
311		func(r *http.RoundTripper, f fuzz.Continue) {
312			roundTripper := &fakeRoundTripper{}
313			f.Fuzz(roundTripper)
314			*r = roundTripper
315		},
316		func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) {
317			*fn = fakeWrapperFunc
318		},
319		func(fn *transport.WrapperFunc, f fuzz.Continue) {
320			*fn = fakeWrapperFunc
321		},
322		func(r *runtime.NegotiatedSerializer, f fuzz.Continue) {
323			serializer := &fakeNegotiatedSerializer{}
324			f.Fuzz(serializer)
325			*r = serializer
326		},
327		func(r *flowcontrol.RateLimiter, f fuzz.Continue) {
328			limiter := &fakeLimiter{}
329			f.Fuzz(limiter)
330			*r = limiter
331		},
332		func(r *AuthProviderConfigPersister, f fuzz.Continue) {
333			*r = fakeAuthProviderConfigPersister{}
334		},
335		func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) {
336			*r = fakeDialFunc
337		},
338	)
339	for i := 0; i < 20; i++ {
340		original := &Config{}
341		f.Fuzz(original)
342		actual := CopyConfig(original)
343		expected := *original
344
345		// this is the list of known risky fields, add to this list if a new field
346		// is added to Config, update CopyConfig to preserve the field otherwise.
347
348		// The DeepEqual cannot handle the func comparison, so we just verify if the
349		// function return the expected object.
350		if actual.WrapTransport == nil || !reflect.DeepEqual(expected.WrapTransport(nil), &fakeRoundTripper{}) {
351			t.Fatalf("CopyConfig dropped the WrapTransport field")
352		} else {
353			actual.WrapTransport = nil
354			expected.WrapTransport = nil
355		}
356		if actual.Dial != nil {
357			_, actualError := actual.Dial(context.Background(), "", "")
358			_, expectedError := expected.Dial(context.Background(), "", "")
359			if !reflect.DeepEqual(expectedError, actualError) {
360				t.Fatalf("CopyConfig  dropped the Dial field")
361			}
362		}
363		actual.Dial = nil
364		expected.Dial = nil
365		if actual.AuthConfigPersister != nil {
366			actualError := actual.AuthConfigPersister.Persist(nil)
367			expectedError := expected.AuthConfigPersister.Persist(nil)
368			if !reflect.DeepEqual(expectedError, actualError) {
369				t.Fatalf("CopyConfig  dropped the Dial field")
370			}
371		}
372		actual.AuthConfigPersister = nil
373		expected.AuthConfigPersister = nil
374
375		if !reflect.DeepEqual(*actual, expected) {
376			t.Fatalf("CopyConfig  dropped unexpected fields, identify whether they are security related or not: %s", diff.ObjectReflectDiff(expected, *actual))
377		}
378	}
379}
380
381func TestConfigStringer(t *testing.T) {
382	formatBytes := func(b []byte) string {
383		// %#v for []byte always pre-pends "[]byte{".
384		// %#v for struct with []byte field always pre-pends "[]uint8{".
385		return strings.Replace(fmt.Sprintf("%#v", b), "byte", "uint8", 1)
386	}
387	tests := []struct {
388		desc            string
389		c               *Config
390		expectContent   []string
391		prohibitContent []string
392	}{
393		{
394			desc:          "nil config",
395			c:             nil,
396			expectContent: []string{"<nil>"},
397		},
398		{
399			desc: "non-sensitive config",
400			c: &Config{
401				Host:      "localhost:8080",
402				APIPath:   "v1",
403				UserAgent: "gobot",
404			},
405			expectContent: []string{"localhost:8080", "v1", "gobot"},
406		},
407		{
408			desc: "sensitive config",
409			c: &Config{
410				Host:        "localhost:8080",
411				Username:    "gopher",
412				Password:    "g0ph3r",
413				BearerToken: "1234567890",
414				TLSClientConfig: TLSClientConfig{
415					CertFile: "a.crt",
416					KeyFile:  "a.key",
417					CertData: []byte("fake cert"),
418					KeyData:  []byte("fake key"),
419				},
420				AuthProvider: &clientcmdapi.AuthProviderConfig{
421					Config: map[string]string{"secret": "s3cr3t"},
422				},
423				ExecProvider: &clientcmdapi.ExecConfig{
424					Args: []string{"secret"},
425					Env:  []clientcmdapi.ExecEnvVar{{Name: "secret", Value: "s3cr3t"}},
426				},
427			},
428			expectContent: []string{
429				"localhost:8080",
430				"gopher",
431				"a.crt",
432				"a.key",
433				"--- REDACTED ---",
434				formatBytes([]byte("--- REDACTED ---")),
435				formatBytes([]byte("--- TRUNCATED ---")),
436			},
437			prohibitContent: []string{
438				"g0ph3r",
439				"1234567890",
440				formatBytes([]byte("fake cert")),
441				formatBytes([]byte("fake key")),
442				"secret",
443				"s3cr3t",
444			},
445		},
446	}
447
448	for _, tt := range tests {
449		t.Run(tt.desc, func(t *testing.T) {
450			got := tt.c.String()
451			t.Logf("formatted config: %q", got)
452
453			for _, expect := range tt.expectContent {
454				if !strings.Contains(got, expect) {
455					t.Errorf("missing expected string %q", expect)
456				}
457			}
458			for _, prohibit := range tt.prohibitContent {
459				if strings.Contains(got, prohibit) {
460					t.Errorf("found prohibited string %q", prohibit)
461				}
462			}
463		})
464	}
465}
466
467func TestConfigSprint(t *testing.T) {
468	c := &Config{
469		Host:    "localhost:8080",
470		APIPath: "v1",
471		ContentConfig: ContentConfig{
472			AcceptContentTypes: "application/json",
473			ContentType:        "application/json",
474		},
475		Username:    "gopher",
476		Password:    "g0ph3r",
477		BearerToken: "1234567890",
478		Impersonate: ImpersonationConfig{
479			UserName: "gopher2",
480		},
481		AuthProvider: &clientcmdapi.AuthProviderConfig{
482			Name:   "gopher",
483			Config: map[string]string{"secret": "s3cr3t"},
484		},
485		AuthConfigPersister: fakeAuthProviderConfigPersister{},
486		ExecProvider: &clientcmdapi.ExecConfig{
487			Command: "sudo",
488			Args:    []string{"secret"},
489			Env:     []clientcmdapi.ExecEnvVar{{Name: "secret", Value: "s3cr3t"}},
490		},
491		TLSClientConfig: TLSClientConfig{
492			CertFile: "a.crt",
493			KeyFile:  "a.key",
494			CertData: []byte("fake cert"),
495			KeyData:  []byte("fake key"),
496		},
497		UserAgent:     "gobot",
498		Transport:     &fakeRoundTripper{},
499		WrapTransport: fakeWrapperFunc,
500		QPS:           1,
501		Burst:         2,
502		RateLimiter:   &fakeLimiter{},
503		Timeout:       3 * time.Second,
504		Dial:          fakeDialFunc,
505	}
506	want := fmt.Sprintf(
507		`&rest.Config{Host:"localhost:8080", APIPath:"v1", ContentConfig:rest.ContentConfig{AcceptContentTypes:"application/json", ContentType:"application/json", GroupVersion:(*schema.GroupVersion)(nil), NegotiatedSerializer:runtime.NegotiatedSerializer(nil)}, Username:"gopher", Password:"--- REDACTED ---", BearerToken:"--- REDACTED ---", BearerTokenFile:"", Impersonate:rest.ImpersonationConfig{UserName:"gopher2", Groups:[]string(nil), Extra:map[string][]string(nil)}, AuthProvider:api.AuthProviderConfig{Name: "gopher", Config: map[string]string{--- REDACTED ---}}, AuthConfigPersister:rest.AuthProviderConfigPersister(--- REDACTED ---), ExecProvider:api.AuthProviderConfig{Command: "sudo", Args: []string{"--- REDACTED ---"}, Env: []ExecEnvVar{--- REDACTED ---}, APIVersion: ""}, TLSClientConfig:rest.sanitizedTLSClientConfig{Insecure:false, ServerName:"", CertFile:"a.crt", KeyFile:"a.key", CAFile:"", CertData:[]uint8{0x2d, 0x2d, 0x2d, 0x20, 0x54, 0x52, 0x55, 0x4e, 0x43, 0x41, 0x54, 0x45, 0x44, 0x20, 0x2d, 0x2d, 0x2d}, KeyData:[]uint8{0x2d, 0x2d, 0x2d, 0x20, 0x52, 0x45, 0x44, 0x41, 0x43, 0x54, 0x45, 0x44, 0x20, 0x2d, 0x2d, 0x2d}, CAData:[]uint8(nil)}, UserAgent:"gobot", Transport:(*rest.fakeRoundTripper)(%p), WrapTransport:(transport.WrapperFunc)(%p), QPS:1, Burst:2, RateLimiter:(*rest.fakeLimiter)(%p), Timeout:3000000000, Dial:(func(context.Context, string, string) (net.Conn, error))(%p)}`,
508		c.Transport, fakeWrapperFunc, c.RateLimiter, fakeDialFunc,
509	)
510
511	for _, f := range []string{"%s", "%v", "%+v", "%#v"} {
512		if got := fmt.Sprintf(f, c); want != got {
513			t.Errorf("fmt.Sprintf(%q, c)\ngot:  %q\nwant: %q", f, got, want)
514		}
515	}
516}
517