1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package http2
6
7import (
8	"errors"
9	"fmt"
10	"io"
11	"io/ioutil"
12	"net/http"
13	"reflect"
14	"strconv"
15	"sync"
16	"testing"
17	"time"
18)
19
20func TestServer_Push_Success(t *testing.T) {
21	const (
22		mainBody   = "<html>index page</html>"
23		pushedBody = "<html>pushed page</html>"
24		userAgent  = "testagent"
25		cookie     = "testcookie"
26	)
27
28	var stURL string
29	checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
30		if got, want := r.Method, wantMethod; got != want {
31			return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
32		}
33		if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
34			return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
35		}
36		if got, want := "https://"+r.Host, stURL; got != want {
37			return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
38		}
39		if r.Body == nil {
40			return fmt.Errorf("nil Body")
41		}
42		if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 {
43			return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
44		}
45		return nil
46	}
47
48	errc := make(chan error, 3)
49	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
50		switch r.URL.RequestURI() {
51		case "/":
52			// Push "/pushed?get" as a GET request, using an absolute URL.
53			opt := &http.PushOptions{
54				Header: http.Header{
55					"User-Agent": {userAgent},
56				},
57			}
58			if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
59				errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
60				return
61			}
62			// Push "/pushed?head" as a HEAD request, using a path.
63			opt = &http.PushOptions{
64				Method: "HEAD",
65				Header: http.Header{
66					"User-Agent": {userAgent},
67					"Cookie":     {cookie},
68				},
69			}
70			if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
71				errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
72				return
73			}
74			w.Header().Set("Content-Type", "text/html")
75			w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
76			w.WriteHeader(200)
77			io.WriteString(w, mainBody)
78			errc <- nil
79
80		case "/pushed?get":
81			wantH := http.Header{}
82			wantH.Set("User-Agent", userAgent)
83			if err := checkPromisedReq(r, "GET", wantH); err != nil {
84				errc <- fmt.Errorf("/pushed?get: %v", err)
85				return
86			}
87			w.Header().Set("Content-Type", "text/html")
88			w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
89			w.WriteHeader(200)
90			io.WriteString(w, pushedBody)
91			errc <- nil
92
93		case "/pushed?head":
94			wantH := http.Header{}
95			wantH.Set("User-Agent", userAgent)
96			wantH.Set("Cookie", cookie)
97			if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
98				errc <- fmt.Errorf("/pushed?head: %v", err)
99				return
100			}
101			w.WriteHeader(204)
102			errc <- nil
103
104		default:
105			errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
106		}
107	})
108	stURL = st.ts.URL
109
110	// Send one request, which should push two responses.
111	st.greet()
112	getSlash(st)
113	for k := 0; k < 3; k++ {
114		select {
115		case <-time.After(2 * time.Second):
116			t.Errorf("timeout waiting for handler %d to finish", k)
117		case err := <-errc:
118			if err != nil {
119				t.Fatal(err)
120			}
121		}
122	}
123
124	checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
125		pp, ok := f.(*PushPromiseFrame)
126		if !ok {
127			return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
128		}
129		if !pp.HeadersEnded() {
130			return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
131		}
132		if got, want := pp.PromiseID, promiseID; got != want {
133			return fmt.Errorf("got PromiseID %v; want %v", got, want)
134		}
135		gotH := st.decodeHeader(pp.HeaderBlockFragment())
136		if !reflect.DeepEqual(gotH, wantH) {
137			return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
138		}
139		return nil
140	}
141	checkHeaders := func(f Frame, wantH [][2]string) error {
142		hf, ok := f.(*HeadersFrame)
143		if !ok {
144			return fmt.Errorf("got a %T; want *HeadersFrame", f)
145		}
146		gotH := st.decodeHeader(hf.HeaderBlockFragment())
147		if !reflect.DeepEqual(gotH, wantH) {
148			return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
149		}
150		return nil
151	}
152	checkData := func(f Frame, wantData string) error {
153		df, ok := f.(*DataFrame)
154		if !ok {
155			return fmt.Errorf("got a %T; want *DataFrame", f)
156		}
157		if gotData := string(df.Data()); gotData != wantData {
158			return fmt.Errorf("got response data %q; want %q", gotData, wantData)
159		}
160		return nil
161	}
162
163	// Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA
164	// Stream 2 has HEADERS + DATA
165	// Stream 4 has HEADERS
166	expected := map[uint32][]func(Frame) error{
167		1: {
168			func(f Frame) error {
169				return checkPushPromise(f, 2, [][2]string{
170					{":method", "GET"},
171					{":scheme", "https"},
172					{":authority", st.ts.Listener.Addr().String()},
173					{":path", "/pushed?get"},
174					{"user-agent", userAgent},
175				})
176			},
177			func(f Frame) error {
178				return checkPushPromise(f, 4, [][2]string{
179					{":method", "HEAD"},
180					{":scheme", "https"},
181					{":authority", st.ts.Listener.Addr().String()},
182					{":path", "/pushed?head"},
183					{"cookie", cookie},
184					{"user-agent", userAgent},
185				})
186			},
187			func(f Frame) error {
188				return checkHeaders(f, [][2]string{
189					{":status", "200"},
190					{"content-type", "text/html"},
191					{"content-length", strconv.Itoa(len(mainBody))},
192				})
193			},
194			func(f Frame) error {
195				return checkData(f, mainBody)
196			},
197		},
198		2: {
199			func(f Frame) error {
200				return checkHeaders(f, [][2]string{
201					{":status", "200"},
202					{"content-type", "text/html"},
203					{"content-length", strconv.Itoa(len(pushedBody))},
204				})
205			},
206			func(f Frame) error {
207				return checkData(f, pushedBody)
208			},
209		},
210		4: {
211			func(f Frame) error {
212				return checkHeaders(f, [][2]string{
213					{":status", "204"},
214				})
215			},
216		},
217	}
218
219	consumed := map[uint32]int{}
220	for k := 0; len(expected) > 0; k++ {
221		f, err := st.readFrame()
222		if err != nil {
223			for id, left := range expected {
224				t.Errorf("stream %d: missing %d frames", id, len(left))
225			}
226			t.Fatalf("readFrame %d: %v", k, err)
227		}
228		id := f.Header().StreamID
229		label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
230		if len(expected[id]) == 0 {
231			t.Fatalf("%s: unexpected frame %#+v", label, f)
232		}
233		check := expected[id][0]
234		expected[id] = expected[id][1:]
235		if len(expected[id]) == 0 {
236			delete(expected, id)
237		}
238		if err := check(f); err != nil {
239			t.Fatalf("%s: %v", label, err)
240		}
241		consumed[id]++
242	}
243}
244
245func TestServer_Push_SuccessNoRace(t *testing.T) {
246	// Regression test for issue #18326. Ensure the request handler can mutate
247	// pushed request headers without racing with the PUSH_PROMISE write.
248	errc := make(chan error, 2)
249	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
250		switch r.URL.RequestURI() {
251		case "/":
252			opt := &http.PushOptions{
253				Header: http.Header{"User-Agent": {"testagent"}},
254			}
255			if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
256				errc <- fmt.Errorf("error pushing: %v", err)
257				return
258			}
259			w.WriteHeader(200)
260			errc <- nil
261
262		case "/pushed":
263			// Update request header, ensure there is no race.
264			r.Header.Set("User-Agent", "newagent")
265			r.Header.Set("Cookie", "cookie")
266			w.WriteHeader(200)
267			errc <- nil
268
269		default:
270			errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
271		}
272	})
273
274	// Send one request, which should push one response.
275	st.greet()
276	getSlash(st)
277	for k := 0; k < 2; k++ {
278		select {
279		case <-time.After(2 * time.Second):
280			t.Errorf("timeout waiting for handler %d to finish", k)
281		case err := <-errc:
282			if err != nil {
283				t.Fatal(err)
284			}
285		}
286	}
287}
288
289func TestServer_Push_RejectRecursivePush(t *testing.T) {
290	// Expect two requests, but might get three if there's a bug and the second push succeeds.
291	errc := make(chan error, 3)
292	handler := func(w http.ResponseWriter, r *http.Request) error {
293		baseURL := "https://" + r.Host
294		switch r.URL.Path {
295		case "/":
296			if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
297				return fmt.Errorf("first Push()=%v, want nil", err)
298			}
299			return nil
300
301		case "/push1":
302			if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
303				return fmt.Errorf("Push()=%v, want %v", got, want)
304			}
305			return nil
306
307		default:
308			return fmt.Errorf("unexpected path: %q", r.URL.Path)
309		}
310	}
311	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
312		errc <- handler(w, r)
313	})
314	defer st.Close()
315	st.greet()
316	getSlash(st)
317	if err := <-errc; err != nil {
318		t.Errorf("First request failed: %v", err)
319	}
320	if err := <-errc; err != nil {
321		t.Errorf("Second request failed: %v", err)
322	}
323}
324
325func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
326	// Expect one request, but might get two if there's a bug and the push succeeds.
327	errc := make(chan error, 2)
328	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
329		errc <- doPush(w.(http.Pusher), r)
330	})
331	defer st.Close()
332	st.greet()
333	if err := st.fr.WriteSettings(settings...); err != nil {
334		st.t.Fatalf("WriteSettings: %v", err)
335	}
336	st.wantSettingsAck()
337	getSlash(st)
338	if err := <-errc; err != nil {
339		t.Error(err)
340	}
341	// Should not get a PUSH_PROMISE frame.
342	hf := st.wantHeaders()
343	if !hf.StreamEnded() {
344		t.Error("stream should end after headers")
345	}
346}
347
348func TestServer_Push_RejectIfDisabled(t *testing.T) {
349	testServer_Push_RejectSingleRequest(t,
350		func(p http.Pusher, r *http.Request) error {
351			if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
352				return fmt.Errorf("Push()=%v, want %v", got, want)
353			}
354			return nil
355		},
356		Setting{SettingEnablePush, 0})
357}
358
359func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
360	testServer_Push_RejectSingleRequest(t,
361		func(p http.Pusher, r *http.Request) error {
362			if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
363				return fmt.Errorf("Push()=%v, want %v", got, want)
364			}
365			return nil
366		},
367		Setting{SettingMaxConcurrentStreams, 0})
368}
369
370func TestServer_Push_RejectWrongScheme(t *testing.T) {
371	testServer_Push_RejectSingleRequest(t,
372		func(p http.Pusher, r *http.Request) error {
373			if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
374				return errors.New("Push() should have failed (push target URL is http)")
375			}
376			return nil
377		})
378}
379
380func TestServer_Push_RejectMissingHost(t *testing.T) {
381	testServer_Push_RejectSingleRequest(t,
382		func(p http.Pusher, r *http.Request) error {
383			if err := p.Push("https:pushed", nil); err == nil {
384				return errors.New("Push() should have failed (push target URL missing host)")
385			}
386			return nil
387		})
388}
389
390func TestServer_Push_RejectRelativePath(t *testing.T) {
391	testServer_Push_RejectSingleRequest(t,
392		func(p http.Pusher, r *http.Request) error {
393			if err := p.Push("../test", nil); err == nil {
394				return errors.New("Push() should have failed (push target is a relative path)")
395			}
396			return nil
397		})
398}
399
400func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
401	testServer_Push_RejectSingleRequest(t,
402		func(p http.Pusher, r *http.Request) error {
403			if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
404				return errors.New("Push() should have failed (cannot promise a POST)")
405			}
406			return nil
407		})
408}
409
410func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
411	testServer_Push_RejectSingleRequest(t,
412		func(p http.Pusher, r *http.Request) error {
413			header := http.Header{
414				"Content-Length":   {"10"},
415				"Content-Encoding": {"gzip"},
416				"Trailer":          {"Foo"},
417				"Te":               {"trailers"},
418				"Host":             {"test.com"},
419				":authority":       {"test.com"},
420			}
421			if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
422				return errors.New("Push() should have failed (forbidden headers)")
423			}
424			return nil
425		})
426}
427
428func TestServer_Push_StateTransitions(t *testing.T) {
429	const body = "foo"
430
431	gotPromise := make(chan bool)
432	finishedPush := make(chan bool)
433
434	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
435		switch r.URL.RequestURI() {
436		case "/":
437			if err := w.(http.Pusher).Push("/pushed", nil); err != nil {
438				t.Errorf("Push error: %v", err)
439			}
440			// Don't finish this request until the push finishes so we don't
441			// nondeterministically interleave output frames with the push.
442			<-finishedPush
443		case "/pushed":
444			<-gotPromise
445		}
446		w.Header().Set("Content-Type", "text/html")
447		w.Header().Set("Content-Length", strconv.Itoa(len(body)))
448		w.WriteHeader(200)
449		io.WriteString(w, body)
450	})
451	defer st.Close()
452
453	st.greet()
454	if st.stream(2) != nil {
455		t.Fatal("stream 2 should be empty")
456	}
457	if got, want := st.streamState(2), stateIdle; got != want {
458		t.Fatalf("streamState(2)=%v, want %v", got, want)
459	}
460	getSlash(st)
461	// After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote.
462	st.wantPushPromise()
463	if got, want := st.streamState(2), stateHalfClosedRemote; got != want {
464		t.Fatalf("streamState(2)=%v, want %v", got, want)
465	}
466	// We stall the HTTP handler for "/pushed" until the above check. If we don't
467	// stall the handler, then the handler might write HEADERS and DATA and finish
468	// the stream before we check st.streamState(2) -- should that happen, we'll
469	// see stateClosed and fail the above check.
470	close(gotPromise)
471	st.wantHeaders()
472	if df := st.wantData(); !df.StreamEnded() {
473		t.Fatal("expected END_STREAM flag on DATA")
474	}
475	if got, want := st.streamState(2), stateClosed; got != want {
476		t.Fatalf("streamState(2)=%v, want %v", got, want)
477	}
478	close(finishedPush)
479}
480
481func TestServer_Push_RejectAfterGoAway(t *testing.T) {
482	var readyOnce sync.Once
483	ready := make(chan struct{})
484	errc := make(chan error, 2)
485	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
486		select {
487		case <-ready:
488		case <-time.After(5 * time.Second):
489			errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed")
490		}
491		if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
492			errc <- fmt.Errorf("Push()=%v, want %v", got, want)
493		}
494		errc <- nil
495	})
496	defer st.Close()
497	st.greet()
498	getSlash(st)
499
500	// Send GOAWAY and wait for it to be processed.
501	st.fr.WriteGoAway(1, ErrCodeNo, nil)
502	go func() {
503		for {
504			select {
505			case <-ready:
506				return
507			default:
508			}
509			st.sc.serveMsgCh <- func(loopNum int) {
510				if !st.sc.pushEnabled {
511					readyOnce.Do(func() { close(ready) })
512				}
513			}
514		}
515	}()
516	if err := <-errc; err != nil {
517		t.Error(err)
518	}
519}
520