// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gensupport

import (
	"context"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"reflect"
	"strings"
	"testing"
	"time"
)

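// unexpectedReader fails every Read call; the canned response bodies used in
// these tests are never expected to be read.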
type unexpectedReader struct{}

func (unexpectedReader) Read([]byte) (int, error) {
	return 0, fmt.Errorf("unexpected read in test")
}

// event is an expected request/response pair.
type event struct {
	// the Content-Range header that should be present in the request.
	byteRange string
	// the HTTP status code to send in response.
	responseStatus int
}

// interruptibleTransport is configured with a canned set of requests/responses.
// It records the incoming data, unless the corresponding event is configured to return
// http.StatusServiceUnavailable.
type interruptibleTransport struct {
	events []event
	buf    []byte
	bodies bodyTracker
}

// bodyTracker keeps track of response bodies that have not been closed.
type bodyTracker map[io.ReadCloser]struct{}

func (bt bodyTracker) Add(body io.ReadCloser) {
	bt[body] = struct{}{}
}

func (bt bodyTracker) Close(body io.ReadCloser) {
	delete(bt, body)
}

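// trackingCloser wraps a Reader and registers itself with a bodyTracker on
// Open, removing itself on Close, so tests can detect unclosed response bodies.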
type trackingCloser struct {
	io.Reader
	tracker bodyTracker
}

func (tc *trackingCloser) Close() error {
	tc.tracker.Close(tc)
	return nil
}

func (tc *trackingCloser) Open() {
	tc.tracker.Add(tc)
}

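// RoundTrip consumes the next expected event, verifies the request's
// Content-Range header against it, records the uploaded bytes (unless the
// event simulates a 503), and returns a canned response with the event's
// status code.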
func (t *interruptibleTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(t.events) == 0 {
		panic("ran out of events, but got a request")
	}
	ev := t.events[0]
	t.events = t.events[1:]
	if got, want := req.Header.Get("Content-Range"), ev.byteRange; got != want {
		return nil, fmt.Errorf("byte range: got %s; want %s", got, want)
	}

	if ev.responseStatus != http.StatusServiceUnavailable {
		buf, err := ioutil.ReadAll(req.Body)
		if err != nil {
			return nil, fmt.Errorf("error reading from request data: %v", err)
		}
		t.buf = append(t.buf, buf...)
	}

	tc := &trackingCloser{unexpectedReader{}, t.bodies}
	tc.Open()
	h := http.Header{}
	status := ev.responseStatus

	// Emulate Google's handling of the "X-GUploader-No-308" header: respond
	// with 200 plus an X-Http-Status-Code-Override header instead of a 308.
	if status == 308 && req.Header.Get("X-GUploader-No-308") == "yes" {
		status = 200
		h.Set("X-Http-Status-Code-Override", "308")
	}

	res := &http.Response{
		StatusCode: status,
		Header:     h,
		Body:       tc,
	}
	return res, nil
}

// progressRecorder records progress updates and calls f, if set, on every
// invocation of ProgressUpdate.
type progressRecorder struct {
	updates []int64
	f       func()
}

func (pr *progressRecorder) ProgressUpdate(current int64) {
	pr.updates = append(pr.updates, current)
	if pr.f != nil {
		pr.f()
	}
}

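// TestInterruptedTransferChunks uploads media in chunks through a transport
// that intermittently returns 503, verifying that interrupted chunks are
// retried, the full payload arrives intact, the expected progress updates are
// reported, and no response bodies are left unclosed.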
func TestInterruptedTransferChunks(t *testing.T) {
	type testCase struct {
		name         string
		data         string
		chunkSize    int
		events       []event
		wantProgress []int64
	}

	for _, tc := range []testCase{
		{
			name:      "large",
			data:      strings.Repeat("a", 300),
			chunkSize: 90,
			events: []event{
				{"bytes 0-89/*", http.StatusServiceUnavailable},
				{"bytes 0-89/*", 308},
				{"bytes 90-179/*", 308},
				{"bytes 180-269/*", http.StatusServiceUnavailable},
				{"bytes 180-269/*", 308},
				{"bytes 270-299/300", 200},
			},
			wantProgress: []int64{90, 180, 270, 300},
		},
		{
			name:      "small",
			data:      strings.Repeat("a", 20),
			chunkSize: 10,
			events: []event{
				{"bytes 0-9/*", http.StatusServiceUnavailable},
				{"bytes 0-9/*", 308},
				{"bytes 10-19/*", http.StatusServiceUnavailable},
				{"bytes 10-19/*", 308},
				// A zero-byte final request requires a byte range with a leading asterisk.
				{"bytes */20", http.StatusServiceUnavailable},
				{"bytes */20", 200},
			},
			wantProgress: []int64{10, 20},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			media := strings.NewReader(tc.data)

			tr := &interruptibleTransport{
				buf:    make([]byte, 0, len(tc.data)),
				events: tc.events,
				bodies: bodyTracker{},
			}

			pr := progressRecorder{}
			rx := &ResumableUpload{
				Client:    &http.Client{Transport: tr},
				Media:     NewMediaBuffer(media, tc.chunkSize),
				MediaType: "text/plain",
				Callback:  pr.ProgressUpdate,
			}

			oldBackoff := backoff
			backoff = func() Backoff { return new(NoPauseBackoff) }
			defer func() { backoff = oldBackoff }()

			res, err := rx.Upload(context.Background())
			if err == nil {
				res.Body.Close()
			}
			if err != nil || res == nil || res.StatusCode != http.StatusOK {
				if res == nil {
					t.Fatalf("Upload not successful, res=nil: %v", err)
				} else {
					t.Fatalf("Upload not successful, statusCode=%v, err=%v", res.StatusCode, err)
				}
			}
			if !reflect.DeepEqual(tr.buf, []byte(tc.data)) {
				t.Fatalf("transferred contents:\ngot %s\nwant %s", tr.buf, tc.data)
			}

			if !reflect.DeepEqual(pr.updates, tc.wantProgress) {
				t.Fatalf("progress updates: got %v, want %v", pr.updates, tc.wantProgress)
			}

			if len(tr.events) > 0 {
				t.Fatalf("did not observe all expected events; leftover events: %v", tr.events)
			}
			if len(tr.bodies) > 0 {
				t.Errorf("unclosed request bodies: %v", tr.bodies)
			}
		})
	}
}

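// TestCancelUploadFast verifies that cancelling the context before the upload
// starts returns context.Canceled with no result and no progress updates.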
func TestCancelUploadFast(t *testing.T) {
	const (
		chunkSize = 90
		mediaSize = 300
	)
	media := strings.NewReader(strings.Repeat("a", mediaSize))

	tr := &interruptibleTransport{
		buf: make([]byte, 0, mediaSize),
		// No event should be needed, but the test sometimes loses the race with
		// the cancelled context, so provide a filler event.
		events: []event{{"bytes 0-9/*", http.StatusServiceUnavailable}},
	}

	pr := progressRecorder{}
	rx := &ResumableUpload{
		Client:    &http.Client{Transport: tr},
		Media:     NewMediaBuffer(media, chunkSize),
		MediaType: "text/plain",
		Callback:  pr.ProgressUpdate,
	}

	oldBackoff := backoff
	backoff = func() Backoff { return new(NoPauseBackoff) }
	defer func() { backoff = oldBackoff }()

	ctx, cancelFunc := context.WithCancel(context.Background())
	cancelFunc() // stop the upload that hasn't started yet
	res, err := rx.Upload(ctx)
	if err != context.Canceled {
		t.Fatalf("Upload err: got: %v; want: context.Canceled", err)
	}
	if res != nil {
		t.Fatalf("Upload result: got: %v; want: nil", res)
	}
	if pr.updates != nil {
		t.Errorf("progress updates: got %v; want: nil", pr.updates)
	}
}

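// TestCancelUploadBasic cancels the context from the progress callback after
// two chunks have been uploaded and verifies that the upload stops with
// context.Canceled, having transferred exactly two chunks.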
func TestCancelUploadBasic(t *testing.T) {
	const (
		chunkSize = 90
		mediaSize = 300
	)
	media := strings.NewReader(strings.Repeat("a", mediaSize))

	tr := &interruptibleTransport{
		buf: make([]byte, 0, mediaSize),
		events: []event{
			{"bytes 0-89/*", http.StatusServiceUnavailable},
			{"bytes 0-89/*", 308},
			{"bytes 90-179/*", 308},
			{"bytes 180-269/*", 308}, // Upload should be cancelled before this event.
		},
		bodies: bodyTracker{},
	}

	ctx, cancelFunc := context.WithCancel(context.Background())
	numUpdates := 0

	pr := progressRecorder{f: func() {
		numUpdates++
		if numUpdates >= 2 {
			cancelFunc()
		}
	}}

	rx := &ResumableUpload{
		Client:    &http.Client{Transport: tr},
		Media:     NewMediaBuffer(media, chunkSize),
		MediaType: "text/plain",
		Callback:  pr.ProgressUpdate,
	}

	oldBackoff := backoff
	backoff = func() Backoff { return new(PauseOneSecond) }
	defer func() { backoff = oldBackoff }()

	res, err := rx.Upload(ctx)
	if err != context.Canceled {
		t.Fatalf("Upload err: got: %v; want: context.Canceled", err)
	}
	if res != nil {
		t.Fatalf("Upload result: got: %v; want: nil", res)
	}
	if got, want := tr.buf, []byte(strings.Repeat("a", chunkSize*2)); !reflect.DeepEqual(got, want) {
		t.Fatalf("transferred contents:\ngot %s\nwant %s", got, want)
	}
	if got, want := pr.updates, []int64{chunkSize, chunkSize * 2}; !reflect.DeepEqual(got, want) {
		t.Fatalf("progress updates: got %v; want: %v", got, want)
	}
	if len(tr.bodies) > 0 {
		t.Errorf("unclosed request bodies: %v", tr.bodies)
	}
}

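// TestRetry_Bounded verifies that retries of a failing chunk are bounded by
// retryDeadline: the upload gives up and surfaces the 503 response to the
// caller instead of retrying forever.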
func TestRetry_Bounded(t *testing.T) {
	const (
		chunkSize = 90
		mediaSize = 300
	)
	media := strings.NewReader(strings.Repeat("a", mediaSize))

	tr := &interruptibleTransport{
		buf: make([]byte, 0, mediaSize),
		events: []event{
			{"bytes 0-89/*", http.StatusServiceUnavailable},
			{"bytes 0-89/*", http.StatusServiceUnavailable},
		},
		bodies: bodyTracker{},
	}

	rx := &ResumableUpload{
		Client:    &http.Client{Transport: tr},
		Media:     NewMediaBuffer(media, chunkSize),
		MediaType: "text/plain",
		Callback:  func(int64) {},
	}

	oldRetryDeadline := retryDeadline
	retryDeadline = time.Second
	defer func() { retryDeadline = oldRetryDeadline }()

	oldBackoff := backoff
	backoff = func() Backoff { return new(PauseForeverBackoff) }
	defer func() { backoff = oldBackoff }()

	resCode := make(chan int, 1)
	go func() {
		resp, err := rx.Upload(context.Background())
		if err != nil {
			t.Error(err)
			return
		}
		resCode <- resp.StatusCode
	}()

	select {
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for Upload to complete")
	case got := <-resCode:
		if want := http.StatusServiceUnavailable; got != want {
			t.Fatalf("want %d, got %d", want, got)
		}
	}
}

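// TestRetry_EachChunkHasItsOwnRetryDeadline verifies that the retry deadline
// applies per chunk rather than to the whole upload: the first chunk
// accumulates roughly 4s of backoff without exceeding the 5s retryDeadline,
// and each later chunk starts a fresh deadline, so the upload completes with 200.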
func TestRetry_EachChunkHasItsOwnRetryDeadline(t *testing.T) {
	const (
		chunkSize = 90
		mediaSize = 300
	)
	media := strings.NewReader(strings.Repeat("a", mediaSize))

	tr := &interruptibleTransport{
		buf: make([]byte, 0, mediaSize),
		events: []event{
			{"bytes 0-89/*", http.StatusServiceUnavailable},
			// cum: 1s sleep
			{"bytes 0-89/*", http.StatusServiceUnavailable},
			// cum: 2s sleep
			{"bytes 0-89/*", http.StatusServiceUnavailable},
			// cum: 3s sleep
			{"bytes 0-89/*", http.StatusServiceUnavailable},
			// cum: 4s sleep
			{"bytes 0-89/*", 308},
			// cum: 1s sleep <-- resets because it's a new chunk
			{"bytes 90-179/*", 308},
			// cum: 1s sleep <-- resets because it's a new chunk
			{"bytes 180-269/*", 308},
			// cum: 1s sleep <-- resets because it's a new chunk
			{"bytes 270-299/300", 200},
		},
		bodies: bodyTracker{},
	}

	rx := &ResumableUpload{
		Client:    &http.Client{Transport: tr},
		Media:     NewMediaBuffer(media, chunkSize),
		MediaType: "text/plain",
		Callback:  func(int64) {},
	}

	oldRetryDeadline := retryDeadline
	retryDeadline = 5 * time.Second
	defer func() { retryDeadline = oldRetryDeadline }()

	oldBackoff := backoff
	backoff = func() Backoff { return new(PauseOneSecond) }
	defer func() { backoff = oldBackoff }()

	resCode := make(chan int, 1)
	go func() {
		resp, err := rx.Upload(context.Background())
		if err != nil {
			t.Error(err)
			return
		}
		resCode <- resp.StatusCode
	}()

	select {
	case <-time.After(15 * time.Second):
		t.Fatal("timed out waiting for Upload to complete")
	case got := <-resCode:
		if want := http.StatusOK; got != want {
			t.Fatalf("want %d, got %d", want, got)
		}
	}
}