1package distribution
2
3import (
4	"net/http"
5	"net/http/httptest"
6	"net/url"
7	"runtime"
8	"strings"
9	"testing"
10
11	"github.com/docker/distribution/reference"
12	"github.com/docker/docker/api/types"
13	registrytypes "github.com/docker/docker/api/types/registry"
14	"github.com/docker/docker/registry"
15	"github.com/sirupsen/logrus"
16	"golang.org/x/net/context"
17)
18
const (
	// secretRegistryToken is the RegistryToken handed to the puller's
	// AuthConfig; test handlers look for it inside incoming Authorization
	// headers to decide whether the token was forwarded.
	secretRegistryToken = "mysecrettoken"
)
20
// tokenPassThruHandler is a fake registry endpoint that records what it saw.
type tokenPassThruHandler struct {
	// reached is set once any request has been served.
	reached bool
	// gotToken is set once a request's Authorization header contained
	// secretRegistryToken.
	gotToken bool
	// shouldSend401 selects, by request URI, which requests get a 401
	// challenge; when nil, every request is challenged.
	shouldSend401 func(url string) bool
}
26
// ServeHTTP marks the handler as reached and records whether the request's
// Authorization header carried secretRegistryToken. A 401 response with a
// Bearer WWW-Authenticate challenge is sent when shouldSend401 is nil or
// accepts the request URI; otherwise nothing is written and net/http's
// default 200 applies.
func (h *tokenPassThruHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	h.reached = true

	authHeader := r.Header.Get("Authorization")
	if strings.Contains(authHeader, secretRegistryToken) {
		logrus.Debug("Detected registry token in auth header")
		h.gotToken = true
	}

	challenge := h.shouldSend401 == nil || h.shouldSend401(r.RequestURI)
	if !challenge {
		return
	}
	w.Header().Set("WWW-Authenticate", `Bearer realm="foorealm"`)
	w.WriteHeader(http.StatusUnauthorized)
}
38
// testTokenPassThru runs a v2 pull against the registry served by ts. The pull
// uses an AuthConfig whose RegistryToken is secretRegistryToken.
//
// The pull itself is expected to fail, because the test handlers only mock part
// of the registry exchange. Callers inspect their own handler afterwards to see
// whether it was reached and whether the token arrived. Any failure while
// setting up the pull aborts the test via t.Fatalf.
func testTokenPassThru(t *testing.T, ts *httptest.Server) {
	uri, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatalf("could not parse url from test server: %v", err)
	}

	endpoint := registry.APIEndpoint{
		Mirror:       false,
		URL:          uri,
		Version:      2,
		Official:     false,
		TrimHostname: false,
		TLSConfig:    nil,
	}
	// Check the parse error: a nil reference would otherwise surface as an
	// unrelated panic inside the puller.
	n, err := reference.ParseNormalizedNamed("testremotename")
	if err != nil {
		t.Fatalf("could not parse repository name: %v", err)
	}
	repoInfo := &registry.RepositoryInfo{
		Name: n,
		Index: &registrytypes.IndexInfo{
			Name:     "testrepo",
			Mirrors:  nil,
			Secure:   false,
			Official: false,
		},
		Official: false,
	}
	imagePullConfig := &ImagePullConfig{
		Config: Config{
			MetaHeaders: http.Header{},
			AuthConfig: &types.AuthConfig{
				RegistryToken: secretRegistryToken,
			},
		},
		Schema2Types: ImageTypes,
	}
	puller, err := newPuller(endpoint, repoInfo, imagePullConfig)
	if err != nil {
		t.Fatal(err)
	}
	// Fail cleanly instead of panicking if newPuller returns another implementation.
	p, ok := puller.(*v2Puller)
	if !ok {
		t.Fatalf("expected *v2Puller, got %T", puller)
	}
	ctx := context.Background()
	p.repo, _, err = NewV2Repository(ctx, p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull")
	if err != nil {
		t.Fatal(err)
	}

	tag, err := reference.WithTag(n, "tag_goes_here")
	if err != nil {
		t.Fatalf("could not tag reference: %v", err)
	}

	logrus.Debug("About to pull")
	// We expect it to fail, since we haven't mock'd the full registry exchange
	// in our handlers; the error is deliberately ignored and callers assert on
	// what their handler observed instead.
	_ = p.pullV2Repository(ctx, tag, runtime.GOOS)
}
89
// TestTokenPassThru verifies that the RegistryToken configured for a pull is
// sent in the Authorization header to the registry. Only requests whose URI is
// exactly "/v2/" get a 401 challenge; all other requests are answered by
// net/http's default 200.
func TestTokenPassThru(t *testing.T) {
	srvHandler := &tokenPassThruHandler{
		shouldSend401: func(url string) bool { return url == "/v2/" },
	}
	srv := httptest.NewServer(srvHandler)
	defer srv.Close()

	testTokenPassThru(t, srv)

	switch {
	case !srvHandler.reached:
		t.Fatal("Handler not reached")
	case !srvHandler.gotToken:
		t.Fatal("Failed to receive registry token")
	}
}
104
// TestTokenPassThruDifferentHost verifies that the registry token does not
// follow a redirect to another server.
//
// The front server answers "/v2/" with a 401 Bearer challenge and sends every
// other request a 301 to the same path on a second test server (different
// host:port). The second server must receive the redirected request, but its
// Authorization header must not contain the token.
func TestTokenPassThruDifferentHost(t *testing.T) {
	target := new(tokenPassThruHandler)
	targetSrv := httptest.NewServer(target)
	defer targetSrv.Close()

	redirector := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.RequestURI != "/v2/" {
			http.Redirect(w, r, targetSrv.URL+r.URL.Path, http.StatusMovedPermanently)
			return
		}
		w.Header().Set("WWW-Authenticate", `Bearer realm="foorealm"`)
		w.WriteHeader(401)
	})
	redirectSrv := httptest.NewServer(redirector)
	defer redirectSrv.Close()

	testTokenPassThru(t, redirectSrv)

	if !target.reached {
		t.Fatal("Handler not reached")
	}
	if target.gotToken {
		t.Fatal("Redirect should not forward Authorization header to another host")
	}
}
129