1/*
2   Copyright The containerd Authors.
3
4   Licensed under the Apache License, Version 2.0 (the "License");
5   you may not use this file except in compliance with the License.
6   You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10   Unless required by applicable law or agreed to in writing, software
11   distributed under the License is distributed on an "AS IS" BASIS,
12   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   See the License for the specific language governing permissions and
14   limitations under the License.
15*/
16
17package docker
18
19import (
20	"context"
21	"crypto/tls"
22	"crypto/x509"
23	"encoding/json"
24	"fmt"
25	"io"
26	"io/ioutil"
27	"net/http"
28	"net/http/httptest"
29	"strconv"
30	"strings"
31	"testing"
32	"time"
33
34	"github.com/containerd/containerd/remotes"
35	digest "github.com/opencontainers/go-digest"
36	specs "github.com/opencontainers/image-spec/specs-go"
37	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
38	"github.com/pkg/errors"
39)
40
41func TestHTTPResolver(t *testing.T) {
42	s := func(h http.Handler) (string, ResolverOptions, func()) {
43		s := httptest.NewServer(h)
44
45		options := ResolverOptions{}
46		base := s.URL[7:] // strip "http://"
47		return base, options, s.Close
48	}
49
50	runBasicTest(t, "testname", s)
51}
52
53func TestHTTPSResolver(t *testing.T) {
54	runBasicTest(t, "testname", tlsServer)
55}
56
57func TestBasicResolver(t *testing.T) {
58	basicAuth := func(h http.Handler) (string, ResolverOptions, func()) {
59		// Wrap with basic auth
60		wrapped := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
61			username, password, ok := r.BasicAuth()
62			if !ok || username != "user1" || password != "password1" {
63				rw.Header().Set("WWW-Authenticate", "Basic realm=localhost")
64				rw.WriteHeader(http.StatusUnauthorized)
65				return
66			}
67			h.ServeHTTP(rw, r)
68		})
69
70		base, options, close := tlsServer(wrapped)
71		options.Hosts = ConfigureDefaultRegistries(
72			WithClient(options.Client),
73			WithAuthorizer(NewAuthorizer(options.Client, func(string) (string, string, error) {
74				return "user1", "password1", nil
75			})),
76		)
77		return base, options, close
78	}
79	runBasicTest(t, "testname", basicAuth)
80}
81
82func TestAnonymousTokenResolver(t *testing.T) {
83	th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
84		if r.Method != http.MethodGet {
85			rw.WriteHeader(http.StatusMethodNotAllowed)
86			return
87		}
88		rw.Header().Set("Content-Type", "application/json")
89		rw.WriteHeader(http.StatusOK)
90		rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`))
91	})
92
93	runBasicTest(t, "testname", withTokenServer(th, nil))
94}
95
96func TestBasicAuthTokenResolver(t *testing.T) {
97	th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
98		if r.Method != http.MethodGet {
99			rw.WriteHeader(http.StatusMethodNotAllowed)
100			return
101		}
102		rw.Header().Set("Content-Type", "application/json")
103		rw.WriteHeader(http.StatusOK)
104		username, password, ok := r.BasicAuth()
105		if !ok || username != "user1" || password != "password1" {
106			rw.Write([]byte(`{"access_token":"insufficientscope"}`))
107		} else {
108			rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`))
109		}
110	})
111	creds := func(string) (string, string, error) {
112		return "user1", "password1", nil
113	}
114
115	runBasicTest(t, "testname", withTokenServer(th, creds))
116}
117
118func TestRefreshTokenResolver(t *testing.T) {
119	th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
120		if r.Method != http.MethodPost {
121			rw.WriteHeader(http.StatusMethodNotAllowed)
122			return
123		}
124		rw.Header().Set("Content-Type", "application/json")
125		rw.WriteHeader(http.StatusOK)
126
127		r.ParseForm()
128		if r.PostForm.Get("grant_type") != "refresh_token" || r.PostForm.Get("refresh_token") != "somerefreshtoken" {
129			rw.Write([]byte(`{"access_token":"insufficientscope"}`))
130		} else {
131			rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`))
132		}
133	})
134	creds := func(string) (string, string, error) {
135		return "", "somerefreshtoken", nil
136	}
137
138	runBasicTest(t, "testname", withTokenServer(th, creds))
139}
140
141func TestPostBasicAuthTokenResolver(t *testing.T) {
142	th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
143		if r.Method != http.MethodPost {
144			rw.WriteHeader(http.StatusMethodNotAllowed)
145			return
146		}
147		rw.Header().Set("Content-Type", "application/json")
148		rw.WriteHeader(http.StatusOK)
149
150		r.ParseForm()
151		if r.PostForm.Get("grant_type") != "password" || r.PostForm.Get("username") != "user1" || r.PostForm.Get("password") != "password1" {
152			rw.Write([]byte(`{"access_token":"insufficientscope"}`))
153		} else {
154			rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`))
155		}
156	})
157	creds := func(string) (string, string, error) {
158		return "user1", "password1", nil
159	}
160
161	runBasicTest(t, "testname", withTokenServer(th, creds))
162}
163
164func TestBadTokenResolver(t *testing.T) {
165	th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
166		if r.Method != http.MethodPost {
167			rw.WriteHeader(http.StatusMethodNotAllowed)
168			return
169		}
170		rw.Header().Set("Content-Type", "application/json")
171		rw.WriteHeader(http.StatusOK)
172		rw.Write([]byte(`{"access_token":"insufficientscope"}`))
173	})
174	creds := func(string) (string, string, error) {
175		return "", "somerefreshtoken", nil
176	}
177
178	ctx := context.Background()
179	h := newContent(ocispec.MediaTypeImageManifest, []byte("not anything parse-able"))
180
181	base, ro, close := withTokenServer(th, creds)(logHandler{t, h})
182	defer close()
183
184	resolver := NewResolver(ro)
185	image := fmt.Sprintf("%s/doesntmatter:sometatg", base)
186
187	_, _, err := resolver.Resolve(ctx, image)
188	if err == nil {
189		t.Fatal("Expected error getting token with inssufficient scope")
190	}
191	if !errors.Is(err, ErrInvalidAuthorization) {
192		t.Fatal(err)
193	}
194}
195
196func TestHostFailureFallbackResolver(t *testing.T) {
197	sf := func(h http.Handler) (string, ResolverOptions, func()) {
198		s := httptest.NewServer(h)
199		base := s.URL[7:] // strip "http://"
200
201		options := ResolverOptions{}
202		createHost := func(host string) RegistryHost {
203			return RegistryHost{
204				Client: &http.Client{
205					// Set the timeout so we timeout waiting for the non-responsive HTTP server
206					Timeout: 500 * time.Millisecond,
207				},
208				Host:         host,
209				Scheme:       "http",
210				Path:         "/v2",
211				Capabilities: HostCapabilityPull | HostCapabilityResolve | HostCapabilityPush,
212			}
213		}
214
215		// Create an unstarted HTTP server. We use this to generate a random port.
216		notRunning := httptest.NewUnstartedServer(nil)
217		notRunningBase := notRunning.Listener.Addr().String()
218
219		// Override hosts with two hosts
220		options.Hosts = func(host string) ([]RegistryHost, error) {
221			return []RegistryHost{
222				createHost(notRunningBase), // This host IS running, but with a non-responsive HTTP server
223				createHost(base),           // This host IS running
224			}, nil
225		}
226
227		return base, options, s.Close
228	}
229
230	runBasicTest(t, "testname", sf)
231}
232
233func TestHostTLSFailureFallbackResolver(t *testing.T) {
234	sf := func(h http.Handler) (string, ResolverOptions, func()) {
235		// Start up two servers
236		server := httptest.NewServer(h)
237		httpBase := server.URL[7:] // strip "http://"
238
239		tlsServer := httptest.NewUnstartedServer(h)
240		tlsServer.StartTLS()
241		httpsBase := tlsServer.URL[8:] // strip "https://"
242
243		capool := x509.NewCertPool()
244		cert, _ := x509.ParseCertificate(tlsServer.TLS.Certificates[0].Certificate[0])
245		capool.AddCert(cert)
246
247		client := &http.Client{
248			Transport: &http.Transport{
249				TLSClientConfig: &tls.Config{
250					RootCAs: capool,
251				},
252			},
253		}
254
255		options := ResolverOptions{}
256		createHost := func(host string) RegistryHost {
257			return RegistryHost{
258				Client:       client,
259				Host:         host,
260				Scheme:       "https",
261				Path:         "/v2",
262				Capabilities: HostCapabilityPull | HostCapabilityResolve | HostCapabilityPush,
263			}
264		}
265
266		// Override hosts with two hosts
267		options.Hosts = func(host string) ([]RegistryHost, error) {
268			return []RegistryHost{
269				createHost(httpBase),  // This host is serving plain HTTP
270				createHost(httpsBase), // This host is serving TLS
271			}, nil
272		}
273
274		return httpBase, options, func() {
275			server.Close()
276			tlsServer.Close()
277		}
278	}
279
280	runBasicTest(t, "testname", sf)
281}
282
283func TestResolveProxy(t *testing.T) {
284	var (
285		ctx  = context.Background()
286		tag  = "latest"
287		r    = http.NewServeMux()
288		name = "testname"
289		ns   = "upstream.example.com"
290	)
291
292	m := newManifest(
293		newContent(ocispec.MediaTypeImageConfig, []byte("1")),
294		newContent(ocispec.MediaTypeImageLayerGzip, []byte("2")),
295	)
296	mc := newContent(ocispec.MediaTypeImageManifest, m.OCIManifest())
297	m.RegisterHandler(r, name)
298	r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, tag), mc)
299	r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, mc.Digest()), mc)
300
301	nr := namespaceRouter{
302		"upstream.example.com": r,
303	}
304
305	base, ro, close := tlsServer(logHandler{t, nr})
306	defer close()
307
308	ro.Hosts = func(host string) ([]RegistryHost, error) {
309		return []RegistryHost{{
310			Client:       ro.Client,
311			Host:         base,
312			Scheme:       "https",
313			Path:         "/v2",
314			Capabilities: HostCapabilityPull | HostCapabilityResolve,
315		}}, nil
316	}
317
318	resolver := NewResolver(ro)
319	image := fmt.Sprintf("%s/%s:%s", ns, name, tag)
320
321	_, d, err := resolver.Resolve(ctx, image)
322	if err != nil {
323		t.Fatal(err)
324	}
325	f, err := resolver.Fetcher(ctx, image)
326	if err != nil {
327		t.Fatal(err)
328	}
329
330	refs, err := testocimanifest(ctx, f, d)
331	if err != nil {
332		t.Fatal(err)
333	}
334
335	if len(refs) != 2 {
336		t.Fatalf("Unexpected number of references: %d, expected 2", len(refs))
337	}
338
339	for _, ref := range refs {
340		if err := testFetch(ctx, f, ref); err != nil {
341			t.Fatal(err)
342		}
343	}
344}
345
346func TestResolveProxyFallback(t *testing.T) {
347	var (
348		ctx  = context.Background()
349		tag  = "latest"
350		r    = http.NewServeMux()
351		name = "testname"
352	)
353
354	m := newManifest(
355		newContent(ocispec.MediaTypeImageConfig, []byte("1")),
356		newContent(ocispec.MediaTypeImageLayerGzip, []byte("2")),
357	)
358	mc := newContent(ocispec.MediaTypeImageManifest, m.OCIManifest())
359	m.RegisterHandler(r, name)
360	r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, tag), mc)
361	r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, mc.Digest()), mc)
362
363	nr := namespaceRouter{
364		"": r,
365	}
366	s := httptest.NewServer(logHandler{t, nr})
367	defer s.Close()
368
369	base := s.URL[7:] // strip "http://"
370
371	ro := ResolverOptions{
372		Hosts: func(host string) ([]RegistryHost, error) {
373			return []RegistryHost{
374				{
375					Host:         flipLocalhost(host),
376					Scheme:       "http",
377					Path:         "/v2",
378					Capabilities: HostCapabilityPull | HostCapabilityResolve,
379				},
380				{
381					Host:         host,
382					Scheme:       "http",
383					Path:         "/v2",
384					Capabilities: HostCapabilityPull | HostCapabilityResolve | HostCapabilityPush,
385				},
386			}, nil
387		},
388	}
389
390	resolver := NewResolver(ro)
391	image := fmt.Sprintf("%s/%s:%s", base, name, tag)
392
393	_, d, err := resolver.Resolve(ctx, image)
394	if err != nil {
395		t.Fatal(err)
396	}
397	f, err := resolver.Fetcher(ctx, image)
398	if err != nil {
399		t.Fatal(err)
400	}
401
402	refs, err := testocimanifest(ctx, f, d)
403	if err != nil {
404		t.Fatal(err)
405	}
406
407	if len(refs) != 2 {
408		t.Fatalf("Unexpected number of references: %d, expected 2", len(refs))
409	}
410
411	for _, ref := range refs {
412		if err := testFetch(ctx, f, ref); err != nil {
413			t.Fatal(err)
414		}
415	}
416}
417
// flipLocalhost swaps between the two loopback spellings "127.0.0.1" and
// "localhost", keeping any ":port" suffix. The swap only applies when the
// loopback name is the whole host part, i.e. followed by ":" or end of string.
// That way "127.0.0.10:80" or "localhostfoo" are not mangled. Any other host
// is returned unchanged.
func flipLocalhost(host string) string {
	swaps := [...][2]string{
		{"127.0.0.1", "localhost"},
		{"localhost", "127.0.0.1"},
	}
	for _, s := range swaps {
		rest := strings.TrimPrefix(host, s[0])
		if rest == host {
			continue // prefix not present
		}
		if rest == "" || rest[0] == ':' {
			return s[1] + rest
		}
	}
	return host
}
427
428func withTokenServer(th http.Handler, creds func(string) (string, string, error)) func(h http.Handler) (string, ResolverOptions, func()) {
429	return func(h http.Handler) (string, ResolverOptions, func()) {
430		s := httptest.NewUnstartedServer(th)
431		s.StartTLS()
432
433		cert, _ := x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
434		tokenBase := s.URL + "/token"
435
436		// Wrap with token auth
437		wrapped := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
438			auth := strings.ToLower(r.Header.Get("Authorization"))
439			if auth != "bearer perfectlyvalidopaquetoken" {
440				authHeader := fmt.Sprintf("Bearer realm=%q,service=registry,scope=\"repository:testname:pull,pull\"", tokenBase)
441				if strings.HasPrefix(auth, "bearer ") {
442					authHeader = authHeader + ",error=" + auth[7:]
443				}
444				rw.Header().Set("WWW-Authenticate", authHeader)
445				rw.WriteHeader(http.StatusUnauthorized)
446				return
447			}
448			h.ServeHTTP(rw, r)
449		})
450
451		base, options, close := tlsServer(wrapped)
452		options.Hosts = ConfigureDefaultRegistries(
453			WithClient(options.Client),
454			WithAuthorizer(NewDockerAuthorizer(
455				WithAuthClient(options.Client),
456				WithAuthCreds(creds),
457			)),
458		)
459		options.Client.Transport.(*http.Transport).TLSClientConfig.RootCAs.AddCert(cert)
460		return base, options, func() {
461			s.Close()
462			close()
463		}
464	}
465}
466
467func tlsServer(h http.Handler) (string, ResolverOptions, func()) {
468	s := httptest.NewUnstartedServer(h)
469	s.StartTLS()
470
471	capool := x509.NewCertPool()
472	cert, _ := x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
473	capool.AddCert(cert)
474
475	client := &http.Client{
476		Transport: &http.Transport{
477			TLSClientConfig: &tls.Config{
478				RootCAs: capool,
479			},
480		},
481	}
482	options := ResolverOptions{
483		Hosts: ConfigureDefaultRegistries(WithClient(client)),
484		// Set deprecated field for tests to use for configuration
485		Client: client,
486	}
487	base := s.URL[8:] // strip "https://"
488	return base, options, s.Close
489}
490
// logHandler wraps handler and records each incoming request (method and URL)
// in the test log via t before delegating. This makes failed resolutions
// easier to diagnose from `go test -v` output.
type logHandler struct {
	t       *testing.T
	handler http.Handler
}

// ServeHTTP logs the request line, then passes the request to the wrapped handler.
func (h logHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	h.t.Logf("%s %s", r.Method, r.URL.String())
	h.handler.ServeHTTP(rw, r)
}
499
// namespaceRouter dispatches each request to the handler registered under the
// value of its "ns" query parameter. A missing parameter looks up the "" key.
// Requests with no matching entry get 404 Not Found.
type namespaceRouter map[string]http.Handler

// ServeHTTP routes r by its "ns" query parameter.
func (nr namespaceRouter) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	namespace := r.URL.Query().Get("ns")
	if target, found := nr[namespace]; found {
		target.ServeHTTP(rw, r)
		return
	}
	rw.WriteHeader(http.StatusNotFound)
}
510
511func runBasicTest(t *testing.T, name string, sf func(h http.Handler) (string, ResolverOptions, func())) {
512	var (
513		ctx = context.Background()
514		tag = "latest"
515		r   = http.NewServeMux()
516	)
517
518	m := newManifest(
519		newContent(ocispec.MediaTypeImageConfig, []byte("1")),
520		newContent(ocispec.MediaTypeImageLayerGzip, []byte("2")),
521	)
522	mc := newContent(ocispec.MediaTypeImageManifest, m.OCIManifest())
523	m.RegisterHandler(r, name)
524	r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, tag), mc)
525	r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, mc.Digest()), mc)
526
527	base, ro, close := sf(logHandler{t, r})
528	defer close()
529
530	resolver := NewResolver(ro)
531	image := fmt.Sprintf("%s/%s:%s", base, name, tag)
532
533	_, d, err := resolver.Resolve(ctx, image)
534	if err != nil {
535		t.Fatal(err)
536	}
537	f, err := resolver.Fetcher(ctx, image)
538	if err != nil {
539		t.Fatal(err)
540	}
541
542	refs, err := testocimanifest(ctx, f, d)
543	if err != nil {
544		t.Fatal(err)
545	}
546
547	if len(refs) != 2 {
548		t.Fatalf("Unexpected number of references: %d, expected 2", len(refs))
549	}
550
551	for _, ref := range refs {
552		if err := testFetch(ctx, f, ref); err != nil {
553			t.Fatal(err)
554		}
555	}
556}
557
558func testFetch(ctx context.Context, f remotes.Fetcher, desc ocispec.Descriptor) error {
559	r, err := f.Fetch(ctx, desc)
560	if err != nil {
561		return err
562	}
563	dgstr := desc.Digest.Algorithm().Digester()
564	io.Copy(dgstr.Hash(), r)
565	if dgstr.Digest() != desc.Digest {
566		return errors.Errorf("content mismatch: %s != %s", dgstr.Digest(), desc.Digest)
567	}
568
569	return nil
570}
571
572func testocimanifest(ctx context.Context, f remotes.Fetcher, desc ocispec.Descriptor) ([]ocispec.Descriptor, error) {
573	r, err := f.Fetch(ctx, desc)
574	if err != nil {
575		return nil, errors.Wrapf(err, "failed to fetch %s", desc.Digest)
576	}
577	p, err := ioutil.ReadAll(r)
578	if err != nil {
579		return nil, err
580	}
581	if dgst := desc.Digest.Algorithm().FromBytes(p); dgst != desc.Digest {
582		return nil, errors.Errorf("digest mismatch: %s != %s", dgst, desc.Digest)
583	}
584
585	var manifest ocispec.Manifest
586	if err := json.Unmarshal(p, &manifest); err != nil {
587		return nil, err
588	}
589
590	var descs []ocispec.Descriptor
591
592	descs = append(descs, manifest.Config)
593	descs = append(descs, manifest.Layers...)
594
595	return descs, nil
596}
597
598type testContent struct {
599	mediaType string
600	content   []byte
601}
602
603func newContent(mediaType string, b []byte) testContent {
604	return testContent{
605		mediaType: mediaType,
606		content:   b,
607	}
608}
609
610func (tc testContent) Descriptor() ocispec.Descriptor {
611	return ocispec.Descriptor{
612		MediaType: tc.mediaType,
613		Digest:    digest.FromBytes(tc.content),
614		Size:      int64(len(tc.content)),
615	}
616}
617
618func (tc testContent) Digest() digest.Digest {
619	return digest.FromBytes(tc.content)
620}
621
622func (tc testContent) ServeHTTP(w http.ResponseWriter, r *http.Request) {
623	w.Header().Add("Content-Type", tc.mediaType)
624	w.Header().Add("Content-Length", strconv.Itoa(len(tc.content)))
625	w.Header().Add("Docker-Content-Digest", tc.Digest().String())
626	w.WriteHeader(http.StatusOK)
627	w.Write(tc.content)
628}
629
630type testManifest struct {
631	config     testContent
632	references []testContent
633}
634
635func newManifest(config testContent, refs ...testContent) testManifest {
636	return testManifest{
637		config:     config,
638		references: refs,
639	}
640}
641
642func (m testManifest) OCIManifest() []byte {
643	manifest := ocispec.Manifest{
644		Versioned: specs.Versioned{
645			SchemaVersion: 1,
646		},
647		Config: m.config.Descriptor(),
648		Layers: make([]ocispec.Descriptor, len(m.references)),
649	}
650	for i, c := range m.references {
651		manifest.Layers[i] = c.Descriptor()
652	}
653	b, _ := json.Marshal(manifest)
654	return b
655}
656
657func (m testManifest) RegisterHandler(r *http.ServeMux, name string) {
658	for _, c := range append(m.references, m.config) {
659		r.Handle(fmt.Sprintf("/v2/%s/blobs/%s", name, c.Digest()), c)
660	}
661}
662