1package proxy
2
3import (
4	"context"
5	"io/ioutil"
6	"math/rand"
7	"net/http"
8	"net/http/httptest"
9	"sync"
10	"testing"
11	"time"
12
13	"github.com/docker/distribution"
14	"github.com/docker/distribution/reference"
15	"github.com/docker/distribution/registry/proxy/scheduler"
16	"github.com/docker/distribution/registry/storage"
17	"github.com/docker/distribution/registry/storage/cache/memory"
18	"github.com/docker/distribution/registry/storage/driver/filesystem"
19	"github.com/docker/distribution/registry/storage/driver/inmemory"
20	"github.com/opencontainers/go-digest"
21)
22
23var sbsMu sync.Mutex
24
25type statsBlobStore struct {
26	stats map[string]int
27	blobs distribution.BlobStore
28}
29
30func (sbs statsBlobStore) Put(ctx context.Context, mediaType string, p []byte) (distribution.Descriptor, error) {
31	sbsMu.Lock()
32	sbs.stats["put"]++
33	sbsMu.Unlock()
34
35	return sbs.blobs.Put(ctx, mediaType, p)
36}
37
38func (sbs statsBlobStore) Get(ctx context.Context, dgst digest.Digest) ([]byte, error) {
39	sbsMu.Lock()
40	sbs.stats["get"]++
41	sbsMu.Unlock()
42
43	return sbs.blobs.Get(ctx, dgst)
44}
45
46func (sbs statsBlobStore) Create(ctx context.Context, options ...distribution.BlobCreateOption) (distribution.BlobWriter, error) {
47	sbsMu.Lock()
48	sbs.stats["create"]++
49	sbsMu.Unlock()
50
51	return sbs.blobs.Create(ctx, options...)
52}
53
54func (sbs statsBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) {
55	sbsMu.Lock()
56	sbs.stats["resume"]++
57	sbsMu.Unlock()
58
59	return sbs.blobs.Resume(ctx, id)
60}
61
62func (sbs statsBlobStore) Open(ctx context.Context, dgst digest.Digest) (distribution.ReadSeekCloser, error) {
63	sbsMu.Lock()
64	sbs.stats["open"]++
65	sbsMu.Unlock()
66
67	return sbs.blobs.Open(ctx, dgst)
68}
69
70func (sbs statsBlobStore) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error {
71	sbsMu.Lock()
72	sbs.stats["serveblob"]++
73	sbsMu.Unlock()
74
75	return sbs.blobs.ServeBlob(ctx, w, r, dgst)
76}
77
78func (sbs statsBlobStore) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
79
80	sbsMu.Lock()
81	sbs.stats["stat"]++
82	sbsMu.Unlock()
83
84	return sbs.blobs.Stat(ctx, dgst)
85}
86
87func (sbs statsBlobStore) Delete(ctx context.Context, dgst digest.Digest) error {
88	sbsMu.Lock()
89	sbs.stats["delete"]++
90	sbsMu.Unlock()
91
92	return sbs.blobs.Delete(ctx, dgst)
93}
94
95type testEnv struct {
96	numUnique int
97	inRemote  []distribution.Descriptor
98	store     proxyBlobStore
99	ctx       context.Context
100}
101
102func (te *testEnv) LocalStats() *map[string]int {
103	sbsMu.Lock()
104	ls := te.store.localStore.(statsBlobStore).stats
105	sbsMu.Unlock()
106	return &ls
107}
108
109func (te *testEnv) RemoteStats() *map[string]int {
110	sbsMu.Lock()
111	rs := te.store.remoteStore.(statsBlobStore).stats
112	sbsMu.Unlock()
113	return &rs
114}
115
116// Populate remote store and record the digests
117func makeTestEnv(t *testing.T, name string) *testEnv {
118	nameRef, err := reference.WithName(name)
119	if err != nil {
120		t.Fatalf("unable to parse reference: %s", err)
121	}
122
123	ctx := context.Background()
124
125	truthDir, err := ioutil.TempDir("", "truth")
126	if err != nil {
127		t.Fatalf("unable to create tempdir: %s", err)
128	}
129
130	cacheDir, err := ioutil.TempDir("", "cache")
131	if err != nil {
132		t.Fatalf("unable to create tempdir: %s", err)
133	}
134
135	localDriver, err := filesystem.FromParameters(map[string]interface{}{
136		"rootdirectory": truthDir,
137	})
138	if err != nil {
139		t.Fatalf("unable to create filesystem driver: %s", err)
140	}
141
142	// todo: create a tempfile area here
143	localRegistry, err := storage.NewRegistry(ctx, localDriver, storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption)
144	if err != nil {
145		t.Fatalf("error creating registry: %v", err)
146	}
147	localRepo, err := localRegistry.Repository(ctx, nameRef)
148	if err != nil {
149		t.Fatalf("unexpected error getting repo: %v", err)
150	}
151
152	cacheDriver, err := filesystem.FromParameters(map[string]interface{}{
153		"rootdirectory": cacheDir,
154	})
155	if err != nil {
156		t.Fatalf("unable to create filesystem driver: %s", err)
157	}
158
159	truthRegistry, err := storage.NewRegistry(ctx, cacheDriver, storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()))
160	if err != nil {
161		t.Fatalf("error creating registry: %v", err)
162	}
163	truthRepo, err := truthRegistry.Repository(ctx, nameRef)
164	if err != nil {
165		t.Fatalf("unexpected error getting repo: %v", err)
166	}
167
168	truthBlobs := statsBlobStore{
169		stats: make(map[string]int),
170		blobs: truthRepo.Blobs(ctx),
171	}
172
173	localBlobs := statsBlobStore{
174		stats: make(map[string]int),
175		blobs: localRepo.Blobs(ctx),
176	}
177
178	s := scheduler.New(ctx, inmemory.New(), "/scheduler-state.json")
179
180	proxyBlobStore := proxyBlobStore{
181		repositoryName: nameRef,
182		remoteStore:    truthBlobs,
183		localStore:     localBlobs,
184		scheduler:      s,
185		authChallenger: &mockChallenger{},
186	}
187
188	te := &testEnv{
189		store: proxyBlobStore,
190		ctx:   ctx,
191	}
192	return te
193}
194
// makeBlob returns a fresh slice of size pseudo-random bytes drawn from the
// global math/rand source. Every byte lies in 'A' (65) through 'A'+47 (112).
func makeBlob(size int) []byte {
	buf := make([]byte, size)
	for idx := range buf {
		buf[idx] = byte('A' + rand.Int()%48)
	}
	return buf
}
202
// init seeds the global math/rand source with the fixed value 42, so the
// sequence consumed by makeBlob and perm repeats from run to run.
// NOTE(review): math/rand.Seed is deprecated, and newer toolchains may ignore
// it; confirm against the Go release notes before relying on reproducibility.
func init() {
	const fixedSeed int64 = 42
	rand.Seed(fixedSeed)
}
206
207func perm(m []distribution.Descriptor) []distribution.Descriptor {
208	for i := 0; i < len(m); i++ {
209		j := rand.Intn(i + 1)
210		tmp := m[i]
211		m[i] = m[j]
212		m[j] = tmp
213	}
214	return m
215}
216
217func populate(t *testing.T, te *testEnv, blobCount, size, numUnique int) {
218	var inRemote []distribution.Descriptor
219
220	for i := 0; i < numUnique; i++ {
221		bytes := makeBlob(size)
222		for j := 0; j < blobCount/numUnique; j++ {
223			desc, err := te.store.remoteStore.Put(te.ctx, "", bytes)
224			if err != nil {
225				t.Fatalf("Put in store")
226			}
227
228			inRemote = append(inRemote, desc)
229		}
230	}
231
232	te.inRemote = inRemote
233	te.numUnique = numUnique
234}
235func TestProxyStoreGet(t *testing.T) {
236	te := makeTestEnv(t, "foo/bar")
237
238	localStats := te.LocalStats()
239	remoteStats := te.RemoteStats()
240
241	populate(t, te, 1, 10, 1)
242	_, err := te.store.Get(te.ctx, te.inRemote[0].Digest)
243	if err != nil {
244		t.Fatal(err)
245	}
246
247	if (*localStats)["get"] != 1 && (*localStats)["put"] != 1 {
248		t.Errorf("Unexpected local counts")
249	}
250
251	if (*remoteStats)["get"] != 1 {
252		t.Errorf("Unexpected remote get count")
253	}
254
255	_, err = te.store.Get(te.ctx, te.inRemote[0].Digest)
256	if err != nil {
257		t.Fatal(err)
258	}
259
260	if (*localStats)["get"] != 2 && (*localStats)["put"] != 1 {
261		t.Errorf("Unexpected local counts")
262	}
263
264	if (*remoteStats)["get"] != 1 {
265		t.Errorf("Unexpected remote get count")
266	}
267
268}
269
270func TestProxyStoreStat(t *testing.T) {
271	te := makeTestEnv(t, "foo/bar")
272
273	remoteBlobCount := 1
274	populate(t, te, remoteBlobCount, 10, 1)
275
276	localStats := te.LocalStats()
277	remoteStats := te.RemoteStats()
278
279	// Stat - touches both stores
280	for _, d := range te.inRemote {
281		_, err := te.store.Stat(te.ctx, d.Digest)
282		if err != nil {
283			t.Fatalf("Error stating proxy store")
284		}
285	}
286
287	if (*localStats)["stat"] != remoteBlobCount {
288		t.Errorf("Unexpected local stat count")
289	}
290
291	if (*remoteStats)["stat"] != remoteBlobCount {
292		t.Errorf("Unexpected remote stat count")
293	}
294
295	if te.store.authChallenger.(*mockChallenger).count != len(te.inRemote) {
296		t.Fatalf("Unexpected auth challenge count, got %#v", te.store.authChallenger)
297	}
298
299}
300
301func TestProxyStoreServeHighConcurrency(t *testing.T) {
302	te := makeTestEnv(t, "foo/bar")
303	blobSize := 200
304	blobCount := 10
305	numUnique := 1
306	populate(t, te, blobCount, blobSize, numUnique)
307
308	numClients := 16
309	testProxyStoreServe(t, te, numClients)
310}
311
312func TestProxyStoreServeMany(t *testing.T) {
313	te := makeTestEnv(t, "foo/bar")
314	blobSize := 200
315	blobCount := 10
316	numUnique := 4
317	populate(t, te, blobCount, blobSize, numUnique)
318
319	numClients := 4
320	testProxyStoreServe(t, te, numClients)
321}
322
323// todo(richardscothern): blobCount must be smaller than num clients
324func TestProxyStoreServeBig(t *testing.T) {
325	te := makeTestEnv(t, "foo/bar")
326
327	blobSize := 2 << 20
328	blobCount := 4
329	numUnique := 2
330	populate(t, te, blobCount, blobSize, numUnique)
331
332	numClients := 4
333	testProxyStoreServe(t, te, numClients)
334}
335
336// testProxyStoreServe will create clients to consume all blobs
337// populated in the truth store
338func testProxyStoreServe(t *testing.T, te *testEnv, numClients int) {
339	localStats := te.LocalStats()
340	remoteStats := te.RemoteStats()
341
342	var wg sync.WaitGroup
343
344	for i := 0; i < numClients; i++ {
345		// Serveblob - pulls through blobs
346		wg.Add(1)
347		go func() {
348			defer wg.Done()
349			for _, remoteBlob := range te.inRemote {
350				w := httptest.NewRecorder()
351				r, err := http.NewRequest("GET", "", nil)
352				if err != nil {
353					t.Error(err)
354					return
355				}
356
357				err = te.store.ServeBlob(te.ctx, w, r, remoteBlob.Digest)
358				if err != nil {
359					t.Errorf(err.Error())
360					return
361				}
362
363				bodyBytes := w.Body.Bytes()
364				localDigest := digest.FromBytes(bodyBytes)
365				if localDigest != remoteBlob.Digest {
366					t.Errorf("Mismatching blob fetch from proxy")
367					return
368				}
369			}
370		}()
371	}
372
373	wg.Wait()
374	if t.Failed() {
375		t.FailNow()
376	}
377
378	remoteBlobCount := len(te.inRemote)
379	sbsMu.Lock()
380	if (*localStats)["stat"] != remoteBlobCount*numClients && (*localStats)["create"] != te.numUnique {
381		sbsMu.Unlock()
382		t.Fatal("Expected: stat:", remoteBlobCount*numClients, "create:", remoteBlobCount)
383	}
384	sbsMu.Unlock()
385
386	// Wait for any async storage goroutines to finish
387	time.Sleep(3 * time.Second)
388
389	sbsMu.Lock()
390	remoteStatCount := (*remoteStats)["stat"]
391	remoteOpenCount := (*remoteStats)["open"]
392	sbsMu.Unlock()
393
394	// Serveblob - blobs come from local
395	for _, dr := range te.inRemote {
396		w := httptest.NewRecorder()
397		r, err := http.NewRequest("GET", "", nil)
398		if err != nil {
399			t.Fatal(err)
400		}
401
402		err = te.store.ServeBlob(te.ctx, w, r, dr.Digest)
403		if err != nil {
404			t.Fatalf(err.Error())
405		}
406
407		dl := digest.FromBytes(w.Body.Bytes())
408		if dl != dr.Digest {
409			t.Errorf("Mismatching blob fetch from proxy")
410		}
411	}
412
413	remoteStats = te.RemoteStats()
414
415	// Ensure remote unchanged
416	sbsMu.Lock()
417	defer sbsMu.Unlock()
418	if (*remoteStats)["stat"] != remoteStatCount && (*remoteStats)["open"] != remoteOpenCount {
419		t.Fatalf("unexpected remote stats: %#v", remoteStats)
420	}
421}
422