1package distribution // import "github.com/docker/docker/distribution"
2
3import (
4	"context"
5	"net/http"
6	"net/http/httptest"
7	"net/url"
8	"strings"
9	"testing"
10
11	"github.com/docker/distribution/reference"
12	"github.com/docker/docker/api/types"
13	registrytypes "github.com/docker/docker/api/types/registry"
14	"github.com/docker/docker/registry"
15	"github.com/sirupsen/logrus"
16)
17
// secretRegistryToken is the registry token placed in the pull config's
// AuthConfig. The fake registry handlers search the Authorization header for
// this value to detect whether the token was passed through.
const (
	secretRegistryToken = "mysecrettoken"
)
19
20type tokenPassThruHandler struct {
21	reached       bool
22	gotToken      bool
23	shouldSend401 func(url string) bool
24}
25
26func (h *tokenPassThruHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
27	h.reached = true
28	if strings.Contains(r.Header.Get("Authorization"), secretRegistryToken) {
29		logrus.Debug("Detected registry token in auth header")
30		h.gotToken = true
31	}
32	if h.shouldSend401 == nil || h.shouldSend401(r.RequestURI) {
33		w.Header().Set("WWW-Authenticate", `Bearer realm="foorealm"`)
34		w.WriteHeader(401)
35	}
36}
37
// testTokenPassThru builds a v2 puller that targets ts and carries
// secretRegistryToken as its registry token. It then attempts to pull
// "testremotename:tag_goes_here".
//
// The pull itself is expected to fail, because the test servers do not mock
// the full registry exchange. Callers assert on what their handler observed,
// not on the pull result. Any setup failure aborts the calling test.
func testTokenPassThru(t *testing.T, ts *httptest.Server) {
	t.Helper()

	uri, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatalf("could not parse url from test server: %v", err)
	}

	endpoint := registry.APIEndpoint{
		Mirror:       false,
		URL:          uri,
		Version:      2,
		Official:     false,
		TrimHostname: false,
		TLSConfig:    nil,
	}
	n, err := reference.ParseNormalizedNamed("testremotename")
	if err != nil {
		t.Fatalf("could not parse repository name: %v", err)
	}
	repoInfo := &registry.RepositoryInfo{
		Name: n,
		Index: &registrytypes.IndexInfo{
			Name:     "testrepo",
			Mirrors:  nil,
			Secure:   false,
			Official: false,
		},
		Official: false,
	}
	imagePullConfig := &ImagePullConfig{
		Config: Config{
			MetaHeaders: http.Header{},
			AuthConfig: &types.AuthConfig{
				RegistryToken: secretRegistryToken,
			},
		},
		Schema2Types: ImageTypes,
	}
	puller, err := newPuller(endpoint, repoInfo, imagePullConfig)
	if err != nil {
		t.Fatal(err)
	}
	// Fail the test cleanly rather than panicking if newPuller ever returns a
	// different implementation.
	p, ok := puller.(*v2Puller)
	if !ok {
		t.Fatalf("expected *v2Puller, got %T", puller)
	}
	ctx := context.Background()
	p.repo, _, err = NewV2Repository(ctx, p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull")
	if err != nil {
		t.Fatal(err)
	}

	tag, err := reference.WithTag(n, "tag_goes_here")
	if err != nil {
		t.Fatalf("could not build tagged reference: %v", err)
	}

	logrus.Debug("About to pull")
	// We expect it to fail, since we haven't mock'd the full registry exchange
	// in the handlers; the error is deliberately ignored.
	_ = p.pullV2Repository(ctx, tag, nil)
}
88
// TestTokenPassThru verifies that the registry token from the pull config is
// sent in the Authorization header to a registry that issues a Bearer
// challenge only for the /v2/ endpoint.
func TestTokenPassThru(t *testing.T) {
	onlyPing := func(uri string) bool { return uri == "/v2/" }
	handler := &tokenPassThruHandler{shouldSend401: onlyPing}

	server := httptest.NewServer(handler)
	defer server.Close()

	testTokenPassThru(t, server)

	switch {
	case !handler.reached:
		t.Fatal("Handler not reached")
	case !handler.gotToken:
		t.Fatal("Failed to receive registry token")
	}
}
103
// TestTokenPassThruDifferentHost verifies that when the registry redirects
// requests to a different host, the Authorization header carrying the
// registry token is not forwarded to that host.
func TestTokenPassThruDifferentHost(t *testing.T) {
	// target challenges every request (nil shouldSend401) and records whether
	// the token reached it.
	target := &tokenPassThruHandler{}
	targetServer := httptest.NewServer(target)
	defer targetServer.Close()

	// The front server answers the /v2/ endpoint with a Bearer challenge and
	// redirects every other path to targetServer.
	redirector := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.RequestURI != "/v2/" {
			http.Redirect(w, r, targetServer.URL+r.URL.Path, http.StatusMovedPermanently)
			return
		}
		w.Header().Set("WWW-Authenticate", `Bearer realm="foorealm"`)
		w.WriteHeader(http.StatusUnauthorized)
	})
	frontServer := httptest.NewServer(redirector)
	defer frontServer.Close()

	testTokenPassThru(t, frontServer)

	if !target.reached {
		t.Fatal("Handler not reached")
	}
	if target.gotToken {
		t.Fatal("Redirect should not forward Authorization header to another host")
	}
}
128