1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package http2
6
7import (
8	"bytes"
9	"compress/gzip"
10	"compress/zlib"
11	"context"
12	"crypto/tls"
13	"errors"
14	"flag"
15	"fmt"
16	"io"
17	"io/ioutil"
18	"log"
19	"net"
20	"net/http"
21	"net/http/httptest"
22	"os"
23	"os/exec"
24	"reflect"
25	"runtime"
26	"strconv"
27	"strings"
28	"sync"
29	"sync/atomic"
30	"testing"
31	"time"
32
33	"golang.org/x/net/http2/hpack"
34)
35
// stderrVerbose, when set via -stderr_verbose, mirrors verbose test output
// to stderr without buffering.
var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered")

// stderrv returns the writer for verbose test output: os.Stderr when
// -stderr_verbose is set, otherwise a writer that discards everything.
func stderrv() io.Writer {
	if !*stderrVerbose {
		return ioutil.Discard
	}
	return os.Stderr
}
45
// serverTester is a client-side harness for exercising the HTTP/2 Server.
// It owns an httptest.Server running the handler under test. Unless
// optOnlyServer was given, it also holds a client TLS connection (cc) with a
// Framer (fr) for writing and reading raw frames. Create one with
// newServerTester and release it with Close.
type serverTester struct {
	cc             net.Conn // client conn
	t              testing.TB
	ts             *httptest.Server
	fr             *Framer
	serverLogBuf   bytes.Buffer // logger for httptest.Server
	logFilter      []string     // substrings to filter out
	scMu           sync.Mutex   // guards sc
	// sc is the server's side of the connection, recorded by the
	// testHookGetServerConn hook that newServerTester installs.
	sc             *serverConn
	// hpackDec is created with st.onHeaderField as its emit callback;
	// presumably that callback appends to decodedHeaders (defined
	// elsewhere in this file — verify).
	hpackDec       *hpack.Decoder
	decodedHeaders [][2]string

	// If http2debug!=2, then we capture Frame debug logs that will be written
	// to t.Log after a test fails. The read and write logs use separate locks
	// and buffers so we don't accidentally introduce synchronization between
	// the read and write goroutines, which may hide data races.
	frameReadLogMu   sync.Mutex
	frameReadLogBuf  bytes.Buffer
	frameWriteLogMu  sync.Mutex
	frameWriteLogBuf bytes.Buffer

	// writing headers:
	// hpackEnc encodes into headerBuf; the encodeHeader* helpers reset
	// headerBuf and return its bytes.
	headerBuf bytes.Buffer
	hpackEnc  *hpack.Encoder
}
71
72func init() {
73	testHookOnPanicMu = new(sync.Mutex)
74	goAwayTimeout = 25 * time.Millisecond
75}
76
77func resetHooks() {
78	testHookOnPanicMu.Lock()
79	testHookOnPanic = nil
80	testHookOnPanicMu.Unlock()
81}
82
// serverTesterOpt is a flag-like option understood by newServerTester.
type serverTesterOpt string

var (
	// optOnlyServer starts the server without dialing a client connection.
	optOnlyServer = serverTesterOpt("only_server")

	// optQuiet sends the server's error log to a discarding writer.
	optQuiet = serverTesterOpt("quiet_logging")

	// optFramerReuseFrames calls SetReuseFrames on the client's Framer.
	optFramerReuseFrames = serverTesterOpt("frame_reuse_frames")
)
88
89func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
90	resetHooks()
91
92	ts := httptest.NewUnstartedServer(handler)
93
94	tlsConfig := &tls.Config{
95		InsecureSkipVerify: true,
96		NextProtos:         []string{NextProtoTLS},
97	}
98
99	var onlyServer, quiet, framerReuseFrames bool
100	h2server := new(Server)
101	for _, opt := range opts {
102		switch v := opt.(type) {
103		case func(*tls.Config):
104			v(tlsConfig)
105		case func(*httptest.Server):
106			v(ts)
107		case func(*Server):
108			v(h2server)
109		case serverTesterOpt:
110			switch v {
111			case optOnlyServer:
112				onlyServer = true
113			case optQuiet:
114				quiet = true
115			case optFramerReuseFrames:
116				framerReuseFrames = true
117			}
118		case func(net.Conn, http.ConnState):
119			ts.Config.ConnState = v
120		default:
121			t.Fatalf("unknown newServerTester option type %T", v)
122		}
123	}
124
125	ConfigureServer(ts.Config, h2server)
126
127	st := &serverTester{
128		t:  t,
129		ts: ts,
130	}
131	st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
132	st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
133
134	ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
135	if quiet {
136		ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
137	} else {
138		ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
139	}
140	ts.StartTLS()
141
142	if VerboseLogs {
143		t.Logf("Running test server at: %s", ts.URL)
144	}
145	testHookGetServerConn = func(v *serverConn) {
146		st.scMu.Lock()
147		defer st.scMu.Unlock()
148		st.sc = v
149	}
150	log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st}))
151	if !onlyServer {
152		cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
153		if err != nil {
154			t.Fatal(err)
155		}
156		st.cc = cc
157		st.fr = NewFramer(cc, cc)
158		if framerReuseFrames {
159			st.fr.SetReuseFrames()
160		}
161		if !logFrameReads && !logFrameWrites {
162			st.fr.debugReadLoggerf = func(m string, v ...interface{}) {
163				m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
164				st.frameReadLogMu.Lock()
165				fmt.Fprintf(&st.frameReadLogBuf, m, v...)
166				st.frameReadLogMu.Unlock()
167			}
168			st.fr.debugWriteLoggerf = func(m string, v ...interface{}) {
169				m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
170				st.frameWriteLogMu.Lock()
171				fmt.Fprintf(&st.frameWriteLogBuf, m, v...)
172				st.frameWriteLogMu.Unlock()
173			}
174			st.fr.logReads = true
175			st.fr.logWrites = true
176		}
177	}
178	return st
179}
180
181func (st *serverTester) closeConn() {
182	st.scMu.Lock()
183	defer st.scMu.Unlock()
184	st.sc.conn.Close()
185}
186
187func (st *serverTester) addLogFilter(phrase string) {
188	st.logFilter = append(st.logFilter, phrase)
189}
190
191func (st *serverTester) stream(id uint32) *stream {
192	ch := make(chan *stream, 1)
193	st.sc.serveMsgCh <- func(int) {
194		ch <- st.sc.streams[id]
195	}
196	return <-ch
197}
198
199func (st *serverTester) streamState(id uint32) streamState {
200	ch := make(chan streamState, 1)
201	st.sc.serveMsgCh <- func(int) {
202		state, _ := st.sc.state(id)
203		ch <- state
204	}
205	return <-ch
206}
207
208// loopNum reports how many times this conn's select loop has gone around.
209func (st *serverTester) loopNum() int {
210	lastc := make(chan int, 1)
211	st.sc.serveMsgCh <- func(loopNum int) {
212		lastc <- loopNum
213	}
214	return <-lastc
215}
216
217// awaitIdle heuristically awaits for the server conn's select loop to be idle.
218// The heuristic is that the server connection's serve loop must schedule
219// 50 times in a row without any channel sends or receives occurring.
220func (st *serverTester) awaitIdle() {
221	remain := 50
222	last := st.loopNum()
223	for remain > 0 {
224		n := st.loopNum()
225		if n == last+1 {
226			remain--
227		} else {
228			remain = 50
229		}
230		last = n
231	}
232}
233
234func (st *serverTester) Close() {
235	if st.t.Failed() {
236		st.frameReadLogMu.Lock()
237		if st.frameReadLogBuf.Len() > 0 {
238			st.t.Logf("Framer read log:\n%s", st.frameReadLogBuf.String())
239		}
240		st.frameReadLogMu.Unlock()
241
242		st.frameWriteLogMu.Lock()
243		if st.frameWriteLogBuf.Len() > 0 {
244			st.t.Logf("Framer write log:\n%s", st.frameWriteLogBuf.String())
245		}
246		st.frameWriteLogMu.Unlock()
247
248		// If we failed already (and are likely in a Fatal,
249		// unwindowing), force close the connection, so the
250		// httptest.Server doesn't wait forever for the conn
251		// to close.
252		if st.cc != nil {
253			st.cc.Close()
254		}
255	}
256	st.ts.Close()
257	if st.cc != nil {
258		st.cc.Close()
259	}
260	log.SetOutput(os.Stderr)
261}
262
263// greet initiates the client's HTTP/2 connection into a state where
264// frames may be sent.
265func (st *serverTester) greet() {
266	st.greetAndCheckSettings(func(Setting) error { return nil })
267}
268
269func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error) {
270	st.writePreface()
271	st.writeInitialSettings()
272	st.wantSettings().ForeachSetting(checkSetting)
273	st.writeSettingsAck()
274
275	// The initial WINDOW_UPDATE and SETTINGS ACK can come in any order.
276	var gotSettingsAck bool
277	var gotWindowUpdate bool
278
279	for i := 0; i < 2; i++ {
280		f, err := st.readFrame()
281		if err != nil {
282			st.t.Fatal(err)
283		}
284		switch f := f.(type) {
285		case *SettingsFrame:
286			if !f.Header().Flags.Has(FlagSettingsAck) {
287				st.t.Fatal("Settings Frame didn't have ACK set")
288			}
289			gotSettingsAck = true
290
291		case *WindowUpdateFrame:
292			if f.FrameHeader.StreamID != 0 {
293				st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
294			}
295			incr := uint32((&Server{}).initialConnRecvWindowSize() - initialWindowSize)
296			if f.Increment != incr {
297				st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr)
298			}
299			gotWindowUpdate = true
300
301		default:
302			st.t.Fatalf("Wanting a settings ACK or window update, received a %T", f)
303		}
304	}
305
306	if !gotSettingsAck {
307		st.t.Fatalf("Didn't get a settings ACK")
308	}
309	if !gotWindowUpdate {
310		st.t.Fatalf("Didn't get a window update")
311	}
312}
313
314func (st *serverTester) writePreface() {
315	n, err := st.cc.Write(clientPreface)
316	if err != nil {
317		st.t.Fatalf("Error writing client preface: %v", err)
318	}
319	if n != len(clientPreface) {
320		st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(clientPreface))
321	}
322}
323
324func (st *serverTester) writeInitialSettings() {
325	if err := st.fr.WriteSettings(); err != nil {
326		st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err)
327	}
328}
329
330func (st *serverTester) writeSettingsAck() {
331	if err := st.fr.WriteSettingsAck(); err != nil {
332		st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err)
333	}
334}
335
336func (st *serverTester) writeHeaders(p HeadersFrameParam) {
337	if err := st.fr.WriteHeaders(p); err != nil {
338		st.t.Fatalf("Error writing HEADERS: %v", err)
339	}
340}
341
342func (st *serverTester) writePriority(id uint32, p PriorityParam) {
343	if err := st.fr.WritePriority(id, p); err != nil {
344		st.t.Fatalf("Error writing PRIORITY: %v", err)
345	}
346}
347
348func (st *serverTester) encodeHeaderField(k, v string) {
349	err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
350	if err != nil {
351		st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
352	}
353}
354
355// encodeHeaderRaw is the magic-free version of encodeHeader.
356// It takes 0 or more (k, v) pairs and encodes them.
357func (st *serverTester) encodeHeaderRaw(headers ...string) []byte {
358	if len(headers)%2 == 1 {
359		panic("odd number of kv args")
360	}
361	st.headerBuf.Reset()
362	for len(headers) > 0 {
363		k, v := headers[0], headers[1]
364		st.encodeHeaderField(k, v)
365		headers = headers[2:]
366	}
367	return st.headerBuf.Bytes()
368}
369
370// encodeHeader encodes headers and returns their HPACK bytes. headers
371// must contain an even number of key/value pairs. There may be
372// multiple pairs for keys (e.g. "cookie").  The :method, :path, and
373// :scheme headers default to GET, / and https. The :authority header
374// defaults to st.ts.Listener.Addr().
375func (st *serverTester) encodeHeader(headers ...string) []byte {
376	if len(headers)%2 == 1 {
377		panic("odd number of kv args")
378	}
379
380	st.headerBuf.Reset()
381	defaultAuthority := st.ts.Listener.Addr().String()
382
383	if len(headers) == 0 {
384		// Fast path, mostly for benchmarks, so test code doesn't pollute
385		// profiles when we're looking to improve server allocations.
386		st.encodeHeaderField(":method", "GET")
387		st.encodeHeaderField(":scheme", "https")
388		st.encodeHeaderField(":authority", defaultAuthority)
389		st.encodeHeaderField(":path", "/")
390		return st.headerBuf.Bytes()
391	}
392
393	if len(headers) == 2 && headers[0] == ":method" {
394		// Another fast path for benchmarks.
395		st.encodeHeaderField(":method", headers[1])
396		st.encodeHeaderField(":scheme", "https")
397		st.encodeHeaderField(":authority", defaultAuthority)
398		st.encodeHeaderField(":path", "/")
399		return st.headerBuf.Bytes()
400	}
401
402	pseudoCount := map[string]int{}
403	keys := []string{":method", ":scheme", ":authority", ":path"}
404	vals := map[string][]string{
405		":method":    {"GET"},
406		":scheme":    {"https"},
407		":authority": {defaultAuthority},
408		":path":      {"/"},
409	}
410	for len(headers) > 0 {
411		k, v := headers[0], headers[1]
412		headers = headers[2:]
413		if _, ok := vals[k]; !ok {
414			keys = append(keys, k)
415		}
416		if strings.HasPrefix(k, ":") {
417			pseudoCount[k]++
418			if pseudoCount[k] == 1 {
419				vals[k] = []string{v}
420			} else {
421				// Allows testing of invalid headers w/ dup pseudo fields.
422				vals[k] = append(vals[k], v)
423			}
424		} else {
425			vals[k] = append(vals[k], v)
426		}
427	}
428	for _, k := range keys {
429		for _, v := range vals[k] {
430			st.encodeHeaderField(k, v)
431		}
432	}
433	return st.headerBuf.Bytes()
434}
435
436// bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set.
437func (st *serverTester) bodylessReq1(headers ...string) {
438	st.writeHeaders(HeadersFrameParam{
439		StreamID:      1, // clients send odd numbers
440		BlockFragment: st.encodeHeader(headers...),
441		EndStream:     true,
442		EndHeaders:    true,
443	})
444}
445
446func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
447	if err := st.fr.WriteData(streamID, endStream, data); err != nil {
448		st.t.Fatalf("Error writing DATA: %v", err)
449	}
450}
451
452func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
453	if err := st.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
454		st.t.Fatalf("Error writing DATA: %v", err)
455	}
456}
457
458func readFrameTimeout(fr *Framer, wait time.Duration) (Frame, error) {
459	ch := make(chan interface{}, 1)
460	go func() {
461		fr, err := fr.ReadFrame()
462		if err != nil {
463			ch <- err
464		} else {
465			ch <- fr
466		}
467	}()
468	t := time.NewTimer(wait)
469	select {
470	case v := <-ch:
471		t.Stop()
472		if fr, ok := v.(Frame); ok {
473			return fr, nil
474		}
475		return nil, v.(error)
476	case <-t.C:
477		return nil, errors.New("timeout waiting for frame")
478	}
479}
480
481func (st *serverTester) readFrame() (Frame, error) {
482	return readFrameTimeout(st.fr, 2*time.Second)
483}
484
485func (st *serverTester) wantHeaders() *HeadersFrame {
486	f, err := st.readFrame()
487	if err != nil {
488		st.t.Fatalf("Error while expecting a HEADERS frame: %v", err)
489	}
490	hf, ok := f.(*HeadersFrame)
491	if !ok {
492		st.t.Fatalf("got a %T; want *HeadersFrame", f)
493	}
494	return hf
495}
496
497func (st *serverTester) wantContinuation() *ContinuationFrame {
498	f, err := st.readFrame()
499	if err != nil {
500		st.t.Fatalf("Error while expecting a CONTINUATION frame: %v", err)
501	}
502	cf, ok := f.(*ContinuationFrame)
503	if !ok {
504		st.t.Fatalf("got a %T; want *ContinuationFrame", f)
505	}
506	return cf
507}
508
509func (st *serverTester) wantData() *DataFrame {
510	f, err := st.readFrame()
511	if err != nil {
512		st.t.Fatalf("Error while expecting a DATA frame: %v", err)
513	}
514	df, ok := f.(*DataFrame)
515	if !ok {
516		st.t.Fatalf("got a %T; want *DataFrame", f)
517	}
518	return df
519}
520
521func (st *serverTester) wantSettings() *SettingsFrame {
522	f, err := st.readFrame()
523	if err != nil {
524		st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
525	}
526	sf, ok := f.(*SettingsFrame)
527	if !ok {
528		st.t.Fatalf("got a %T; want *SettingsFrame", f)
529	}
530	return sf
531}
532
533func (st *serverTester) wantPing() *PingFrame {
534	f, err := st.readFrame()
535	if err != nil {
536		st.t.Fatalf("Error while expecting a PING frame: %v", err)
537	}
538	pf, ok := f.(*PingFrame)
539	if !ok {
540		st.t.Fatalf("got a %T; want *PingFrame", f)
541	}
542	return pf
543}
544
545func (st *serverTester) wantGoAway() *GoAwayFrame {
546	f, err := st.readFrame()
547	if err != nil {
548		st.t.Fatalf("Error while expecting a GOAWAY frame: %v", err)
549	}
550	gf, ok := f.(*GoAwayFrame)
551	if !ok {
552		st.t.Fatalf("got a %T; want *GoAwayFrame", f)
553	}
554	return gf
555}
556
557func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
558	f, err := st.readFrame()
559	if err != nil {
560		st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
561	}
562	rs, ok := f.(*RSTStreamFrame)
563	if !ok {
564		st.t.Fatalf("got a %T; want *RSTStreamFrame", f)
565	}
566	if rs.FrameHeader.StreamID != streamID {
567		st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
568	}
569	if rs.ErrCode != errCode {
570		st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode)
571	}
572}
573
574func (st *serverTester) wantWindowUpdate(streamID, incr uint32) {
575	f, err := st.readFrame()
576	if err != nil {
577		st.t.Fatalf("Error while expecting a WINDOW_UPDATE frame: %v", err)
578	}
579	wu, ok := f.(*WindowUpdateFrame)
580	if !ok {
581		st.t.Fatalf("got a %T; want *WindowUpdateFrame", f)
582	}
583	if wu.FrameHeader.StreamID != streamID {
584		st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
585	}
586	if wu.Increment != incr {
587		st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
588	}
589}
590
591func (st *serverTester) wantSettingsAck() {
592	f, err := st.readFrame()
593	if err != nil {
594		st.t.Fatal(err)
595	}
596	sf, ok := f.(*SettingsFrame)
597	if !ok {
598		st.t.Fatalf("Wanting a settings ACK, received a %T", f)
599	}
600	if !sf.Header().Flags.Has(FlagSettingsAck) {
601		st.t.Fatal("Settings Frame didn't have ACK set")
602	}
603}
604
605func (st *serverTester) wantPushPromise() *PushPromiseFrame {
606	f, err := st.readFrame()
607	if err != nil {
608		st.t.Fatal(err)
609	}
610	ppf, ok := f.(*PushPromiseFrame)
611	if !ok {
612		st.t.Fatalf("Wanted PushPromise, received %T", ppf)
613	}
614	return ppf
615}
616
617func TestServer(t *testing.T) {
618	gotReq := make(chan bool, 1)
619	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
620		w.Header().Set("Foo", "Bar")
621		gotReq <- true
622	})
623	defer st.Close()
624
625	covers("3.5", `
626		The server connection preface consists of a potentially empty
627		SETTINGS frame ([SETTINGS]) that MUST be the first frame the
628		server sends in the HTTP/2 connection.
629	`)
630
631	st.greet()
632	st.writeHeaders(HeadersFrameParam{
633		StreamID:      1, // clients send odd numbers
634		BlockFragment: st.encodeHeader(),
635		EndStream:     true, // no DATA frames
636		EndHeaders:    true,
637	})
638
639	select {
640	case <-gotReq:
641	case <-time.After(2 * time.Second):
642		t.Error("timeout waiting for request")
643	}
644}
645
646func TestServer_Request_Get(t *testing.T) {
647	testServerRequest(t, func(st *serverTester) {
648		st.writeHeaders(HeadersFrameParam{
649			StreamID:      1, // clients send odd numbers
650			BlockFragment: st.encodeHeader("foo-bar", "some-value"),
651			EndStream:     true, // no DATA frames
652			EndHeaders:    true,
653		})
654	}, func(r *http.Request) {
655		if r.Method != "GET" {
656			t.Errorf("Method = %q; want GET", r.Method)
657		}
658		if r.URL.Path != "/" {
659			t.Errorf("URL.Path = %q; want /", r.URL.Path)
660		}
661		if r.ContentLength != 0 {
662			t.Errorf("ContentLength = %v; want 0", r.ContentLength)
663		}
664		if r.Close {
665			t.Error("Close = true; want false")
666		}
667		if !strings.Contains(r.RemoteAddr, ":") {
668			t.Errorf("RemoteAddr = %q; want something with a colon", r.RemoteAddr)
669		}
670		if r.Proto != "HTTP/2.0" || r.ProtoMajor != 2 || r.ProtoMinor != 0 {
671			t.Errorf("Proto = %q Major=%v,Minor=%v; want HTTP/2.0", r.Proto, r.ProtoMajor, r.ProtoMinor)
672		}
673		wantHeader := http.Header{
674			"Foo-Bar": []string{"some-value"},
675		}
676		if !reflect.DeepEqual(r.Header, wantHeader) {
677			t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
678		}
679		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
680			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
681		}
682	})
683}
684
685func TestServer_Request_Get_PathSlashes(t *testing.T) {
686	testServerRequest(t, func(st *serverTester) {
687		st.writeHeaders(HeadersFrameParam{
688			StreamID:      1, // clients send odd numbers
689			BlockFragment: st.encodeHeader(":path", "/%2f/"),
690			EndStream:     true, // no DATA frames
691			EndHeaders:    true,
692		})
693	}, func(r *http.Request) {
694		if r.RequestURI != "/%2f/" {
695			t.Errorf("RequestURI = %q; want /%%2f/", r.RequestURI)
696		}
697		if r.URL.Path != "///" {
698			t.Errorf("URL.Path = %q; want ///", r.URL.Path)
699		}
700	})
701}
702
703// TODO: add a test with EndStream=true on the HEADERS but setting a
704// Content-Length anyway. Should we just omit it and force it to
705// zero?
706
707func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) {
708	testServerRequest(t, func(st *serverTester) {
709		st.writeHeaders(HeadersFrameParam{
710			StreamID:      1, // clients send odd numbers
711			BlockFragment: st.encodeHeader(":method", "POST"),
712			EndStream:     true,
713			EndHeaders:    true,
714		})
715	}, func(r *http.Request) {
716		if r.Method != "POST" {
717			t.Errorf("Method = %q; want POST", r.Method)
718		}
719		if r.ContentLength != 0 {
720			t.Errorf("ContentLength = %v; want 0", r.ContentLength)
721		}
722		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
723			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
724		}
725	})
726}
727
728func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) {
729	testBodyContents(t, -1, "", func(st *serverTester) {
730		st.writeHeaders(HeadersFrameParam{
731			StreamID:      1, // clients send odd numbers
732			BlockFragment: st.encodeHeader(":method", "POST"),
733			EndStream:     false, // to say DATA frames are coming
734			EndHeaders:    true,
735		})
736		st.writeData(1, true, nil) // just kidding. empty body.
737	})
738}
739
740func TestServer_Request_Post_Body_OneData(t *testing.T) {
741	const content = "Some content"
742	testBodyContents(t, -1, content, func(st *serverTester) {
743		st.writeHeaders(HeadersFrameParam{
744			StreamID:      1, // clients send odd numbers
745			BlockFragment: st.encodeHeader(":method", "POST"),
746			EndStream:     false, // to say DATA frames are coming
747			EndHeaders:    true,
748		})
749		st.writeData(1, true, []byte(content))
750	})
751}
752
753func TestServer_Request_Post_Body_TwoData(t *testing.T) {
754	const content = "Some content"
755	testBodyContents(t, -1, content, func(st *serverTester) {
756		st.writeHeaders(HeadersFrameParam{
757			StreamID:      1, // clients send odd numbers
758			BlockFragment: st.encodeHeader(":method", "POST"),
759			EndStream:     false, // to say DATA frames are coming
760			EndHeaders:    true,
761		})
762		st.writeData(1, false, []byte(content[:5]))
763		st.writeData(1, true, []byte(content[5:]))
764	})
765}
766
767func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) {
768	const content = "Some content"
769	testBodyContents(t, int64(len(content)), content, func(st *serverTester) {
770		st.writeHeaders(HeadersFrameParam{
771			StreamID: 1, // clients send odd numbers
772			BlockFragment: st.encodeHeader(
773				":method", "POST",
774				"content-length", strconv.Itoa(len(content)),
775			),
776			EndStream:  false, // to say DATA frames are coming
777			EndHeaders: true,
778		})
779		st.writeData(1, true, []byte(content))
780	})
781}
782
783func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) {
784	testBodyContentsFail(t, 3, "request declared a Content-Length of 3 but only wrote 2 bytes",
785		func(st *serverTester) {
786			st.writeHeaders(HeadersFrameParam{
787				StreamID: 1, // clients send odd numbers
788				BlockFragment: st.encodeHeader(
789					":method", "POST",
790					"content-length", "3",
791				),
792				EndStream:  false, // to say DATA frames are coming
793				EndHeaders: true,
794			})
795			st.writeData(1, true, []byte("12"))
796		})
797}
798
799func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) {
800	testBodyContentsFail(t, 4, "sender tried to send more than declared Content-Length of 4 bytes",
801		func(st *serverTester) {
802			st.writeHeaders(HeadersFrameParam{
803				StreamID: 1, // clients send odd numbers
804				BlockFragment: st.encodeHeader(
805					":method", "POST",
806					"content-length", "4",
807				),
808				EndStream:  false, // to say DATA frames are coming
809				EndHeaders: true,
810			})
811			st.writeData(1, true, []byte("12345"))
812		})
813}
814
815func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, write func(st *serverTester)) {
816	testServerRequest(t, write, func(r *http.Request) {
817		if r.Method != "POST" {
818			t.Errorf("Method = %q; want POST", r.Method)
819		}
820		if r.ContentLength != wantContentLength {
821			t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
822		}
823		all, err := ioutil.ReadAll(r.Body)
824		if err != nil {
825			t.Fatal(err)
826		}
827		if string(all) != wantBody {
828			t.Errorf("Read = %q; want %q", all, wantBody)
829		}
830		if err := r.Body.Close(); err != nil {
831			t.Fatalf("Close: %v", err)
832		}
833	})
834}
835
836func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError string, write func(st *serverTester)) {
837	testServerRequest(t, write, func(r *http.Request) {
838		if r.Method != "POST" {
839			t.Errorf("Method = %q; want POST", r.Method)
840		}
841		if r.ContentLength != wantContentLength {
842			t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
843		}
844		all, err := ioutil.ReadAll(r.Body)
845		if err == nil {
846			t.Fatalf("expected an error (%q) reading from the body. Successfully read %q instead.",
847				wantReadError, all)
848		}
849		if !strings.Contains(err.Error(), wantReadError) {
850			t.Fatalf("Body.Read = %v; want substring %q", err, wantReadError)
851		}
852		if err := r.Body.Close(); err != nil {
853			t.Fatalf("Close: %v", err)
854		}
855	})
856}
857
858// Using a Host header, instead of :authority
859func TestServer_Request_Get_Host(t *testing.T) {
860	const host = "example.com"
861	testServerRequest(t, func(st *serverTester) {
862		st.writeHeaders(HeadersFrameParam{
863			StreamID:      1, // clients send odd numbers
864			BlockFragment: st.encodeHeader(":authority", "", "host", host),
865			EndStream:     true,
866			EndHeaders:    true,
867		})
868	}, func(r *http.Request) {
869		if r.Host != host {
870			t.Errorf("Host = %q; want %q", r.Host, host)
871		}
872	})
873}
874
875// Using an :authority pseudo-header, instead of Host
876func TestServer_Request_Get_Authority(t *testing.T) {
877	const host = "example.com"
878	testServerRequest(t, func(st *serverTester) {
879		st.writeHeaders(HeadersFrameParam{
880			StreamID:      1, // clients send odd numbers
881			BlockFragment: st.encodeHeader(":authority", host),
882			EndStream:     true,
883			EndHeaders:    true,
884		})
885	}, func(r *http.Request) {
886		if r.Host != host {
887			t.Errorf("Host = %q; want %q", r.Host, host)
888		}
889	})
890}
891
892func TestServer_Request_WithContinuation(t *testing.T) {
893	wantHeader := http.Header{
894		"Foo-One":   []string{"value-one"},
895		"Foo-Two":   []string{"value-two"},
896		"Foo-Three": []string{"value-three"},
897	}
898	testServerRequest(t, func(st *serverTester) {
899		fullHeaders := st.encodeHeader(
900			"foo-one", "value-one",
901			"foo-two", "value-two",
902			"foo-three", "value-three",
903		)
904		remain := fullHeaders
905		chunks := 0
906		for len(remain) > 0 {
907			const maxChunkSize = 5
908			chunk := remain
909			if len(chunk) > maxChunkSize {
910				chunk = chunk[:maxChunkSize]
911			}
912			remain = remain[len(chunk):]
913
914			if chunks == 0 {
915				st.writeHeaders(HeadersFrameParam{
916					StreamID:      1, // clients send odd numbers
917					BlockFragment: chunk,
918					EndStream:     true,  // no DATA frames
919					EndHeaders:    false, // we'll have continuation frames
920				})
921			} else {
922				err := st.fr.WriteContinuation(1, len(remain) == 0, chunk)
923				if err != nil {
924					t.Fatal(err)
925				}
926			}
927			chunks++
928		}
929		if chunks < 2 {
930			t.Fatal("too few chunks")
931		}
932	}, func(r *http.Request) {
933		if !reflect.DeepEqual(r.Header, wantHeader) {
934			t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
935		}
936	})
937}
938
939// Concatenated cookie headers. ("8.1.2.5 Compressing the Cookie Header Field")
940func TestServer_Request_CookieConcat(t *testing.T) {
941	const host = "example.com"
942	testServerRequest(t, func(st *serverTester) {
943		st.bodylessReq1(
944			":authority", host,
945			"cookie", "a=b",
946			"cookie", "c=d",
947			"cookie", "e=f",
948		)
949	}, func(r *http.Request) {
950		const want = "a=b; c=d; e=f"
951		if got := r.Header.Get("Cookie"); got != want {
952			t.Errorf("Cookie = %q; want %q", got, want)
953		}
954	})
955}
956
957func TestServer_Request_Reject_CapitalHeader(t *testing.T) {
958	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("UPPER", "v") })
959}
960
961func TestServer_Request_Reject_HeaderFieldNameColon(t *testing.T) {
962	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has:colon", "v") })
963}
964
965func TestServer_Request_Reject_HeaderFieldNameNULL(t *testing.T) {
966	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has\x00null", "v") })
967}
968
969func TestServer_Request_Reject_HeaderFieldNameEmpty(t *testing.T) {
970	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("", "v") })
971}
972
973func TestServer_Request_Reject_HeaderFieldValueNewline(t *testing.T) {
974	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\nnewline") })
975}
976
977func TestServer_Request_Reject_HeaderFieldValueCR(t *testing.T) {
978	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\rcarriage") })
979}
980
981func TestServer_Request_Reject_HeaderFieldValueDEL(t *testing.T) {
982	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\x7fdel") })
983}
984
985func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) {
986	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "") })
987}
988
989func TestServer_Request_Reject_Pseudo_ExactlyOne(t *testing.T) {
990	// 8.1.2.3 Request Pseudo-Header Fields
991	// "All HTTP/2 requests MUST include exactly one valid value" ...
992	testRejectRequest(t, func(st *serverTester) {
993		st.addLogFilter("duplicate pseudo-header")
994		st.bodylessReq1(":method", "GET", ":method", "POST")
995	})
996}
997
998func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) {
999	// 8.1.2.3 Request Pseudo-Header Fields
1000	// "All pseudo-header fields MUST appear in the header block
1001	// before regular header fields. Any request or response that
1002	// contains a pseudo-header field that appears in a header
1003	// block after a regular header field MUST be treated as
1004	// malformed (Section 8.1.2.6)."
1005	testRejectRequest(t, func(st *serverTester) {
1006		st.addLogFilter("pseudo-header after regular header")
1007		var buf bytes.Buffer
1008		enc := hpack.NewEncoder(&buf)
1009		enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
1010		enc.WriteField(hpack.HeaderField{Name: "regular", Value: "foobar"})
1011		enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"})
1012		enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
1013		st.writeHeaders(HeadersFrameParam{
1014			StreamID:      1, // clients send odd numbers
1015			BlockFragment: buf.Bytes(),
1016			EndStream:     true,
1017			EndHeaders:    true,
1018		})
1019	})
1020}
1021
1022func TestServer_Request_Reject_Pseudo_Missing_path(t *testing.T) {
1023	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":path", "") })
1024}
1025
1026func TestServer_Request_Reject_Pseudo_Missing_scheme(t *testing.T) {
1027	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "") })
1028}
1029
1030func TestServer_Request_Reject_Pseudo_scheme_invalid(t *testing.T) {
1031	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "bogus") })
1032}
1033
1034func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
1035	testRejectRequest(t, func(st *serverTester) {
1036		st.addLogFilter(`invalid pseudo-header ":unknown_thing"`)
1037		st.bodylessReq1(":unknown_thing", "")
1038	})
1039}
1040
1041func testRejectRequest(t *testing.T, send func(*serverTester)) {
1042	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1043		t.Error("server request made it to handler; should've been rejected")
1044	})
1045	defer st.Close()
1046
1047	st.greet()
1048	send(st)
1049	st.wantRSTStream(1, ErrCodeProtocol)
1050}
1051
1052func testRejectRequestWithProtocolError(t *testing.T, send func(*serverTester)) {
1053	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1054		t.Error("server request made it to handler; should've been rejected")
1055	}, optQuiet)
1056	defer st.Close()
1057
1058	st.greet()
1059	send(st)
1060	gf := st.wantGoAway()
1061	if gf.ErrCode != ErrCodeProtocol {
1062		t.Errorf("err code = %v; want %v", gf.ErrCode, ErrCodeProtocol)
1063	}
1064}
1065
1066// Section 5.1, on idle connections: "Receiving any frame other than
1067// HEADERS or PRIORITY on a stream in this state MUST be treated as a
1068// connection error (Section 5.4.1) of type PROTOCOL_ERROR."
1069func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) {
1070	testRejectRequestWithProtocolError(t, func(st *serverTester) {
1071		st.fr.WriteWindowUpdate(123, 456)
1072	})
1073}
1074func TestRejectFrameOnIdle_Data(t *testing.T) {
1075	testRejectRequestWithProtocolError(t, func(st *serverTester) {
1076		st.fr.WriteData(123, true, nil)
1077	})
1078}
1079func TestRejectFrameOnIdle_RSTStream(t *testing.T) {
1080	testRejectRequestWithProtocolError(t, func(st *serverTester) {
1081		st.fr.WriteRSTStream(123, ErrCodeCancel)
1082	})
1083}
1084
1085func TestServer_Request_Connect(t *testing.T) {
1086	testServerRequest(t, func(st *serverTester) {
1087		st.writeHeaders(HeadersFrameParam{
1088			StreamID: 1,
1089			BlockFragment: st.encodeHeaderRaw(
1090				":method", "CONNECT",
1091				":authority", "example.com:123",
1092			),
1093			EndStream:  true,
1094			EndHeaders: true,
1095		})
1096	}, func(r *http.Request) {
1097		if g, w := r.Method, "CONNECT"; g != w {
1098			t.Errorf("Method = %q; want %q", g, w)
1099		}
1100		if g, w := r.RequestURI, "example.com:123"; g != w {
1101			t.Errorf("RequestURI = %q; want %q", g, w)
1102		}
1103		if g, w := r.URL.Host, "example.com:123"; g != w {
1104			t.Errorf("URL.Host = %q; want %q", g, w)
1105		}
1106	})
1107}
1108
1109func TestServer_Request_Connect_InvalidPath(t *testing.T) {
1110	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1111		st.writeHeaders(HeadersFrameParam{
1112			StreamID: 1,
1113			BlockFragment: st.encodeHeaderRaw(
1114				":method", "CONNECT",
1115				":authority", "example.com:123",
1116				":path", "/bogus",
1117			),
1118			EndStream:  true,
1119			EndHeaders: true,
1120		})
1121	})
1122}
1123
1124func TestServer_Request_Connect_InvalidScheme(t *testing.T) {
1125	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1126		st.writeHeaders(HeadersFrameParam{
1127			StreamID: 1,
1128			BlockFragment: st.encodeHeaderRaw(
1129				":method", "CONNECT",
1130				":authority", "example.com:123",
1131				":scheme", "https",
1132			),
1133			EndStream:  true,
1134			EndHeaders: true,
1135		})
1136	})
1137}
1138
1139func TestServer_Ping(t *testing.T) {
1140	st := newServerTester(t, nil)
1141	defer st.Close()
1142	st.greet()
1143
1144	// Server should ignore this one, since it has ACK set.
1145	ackPingData := [8]byte{1, 2, 4, 8, 16, 32, 64, 128}
1146	if err := st.fr.WritePing(true, ackPingData); err != nil {
1147		t.Fatal(err)
1148	}
1149
1150	// But the server should reply to this one, since ACK is false.
1151	pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
1152	if err := st.fr.WritePing(false, pingData); err != nil {
1153		t.Fatal(err)
1154	}
1155
1156	pf := st.wantPing()
1157	if !pf.Flags.Has(FlagPingAck) {
1158		t.Error("response ping doesn't have ACK set")
1159	}
1160	if pf.Data != pingData {
1161		t.Errorf("response ping has data %q; want %q", pf.Data, pingData)
1162	}
1163}
1164
1165func TestServer_MaxQueuedControlFrames(t *testing.T) {
1166	if testing.Short() {
1167		t.Skip("skipping in short mode")
1168	}
1169
1170	st := newServerTester(t, nil)
1171	defer st.Close()
1172	st.greet()
1173
1174	const extraPings = 500000 // enough to fill the TCP buffers
1175
1176	for i := 0; i < maxQueuedControlFrames+extraPings; i++ {
1177		pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
1178		if err := st.fr.WritePing(false, pingData); err != nil {
1179			if i == 0 {
1180				t.Fatal(err)
1181			}
1182			// We expect the connection to get closed by the server when the TCP
1183			// buffer fills up and the write queue reaches MaxQueuedControlFrames.
1184			t.Logf("sent %d PING frames", i)
1185			return
1186		}
1187	}
1188	t.Errorf("unexpected success sending all PING frames")
1189}
1190
1191func TestServer_RejectsLargeFrames(t *testing.T) {
1192	if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
1193		t.Skip("see golang.org/issue/13434, golang.org/issue/37321")
1194	}
1195
1196	st := newServerTester(t, nil)
1197	defer st.Close()
1198	st.greet()
1199
1200	// Write too large of a frame (too large by one byte)
1201	// We ignore the return value because it's expected that the server
1202	// will only read the first 9 bytes (the headre) and then disconnect.
1203	st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, defaultMaxReadFrameSize+1))
1204
1205	gf := st.wantGoAway()
1206	if gf.ErrCode != ErrCodeFrameSize {
1207		t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFrameSize)
1208	}
1209	if st.serverLogBuf.Len() != 0 {
1210		// Previously we spun here for a bit until the GOAWAY disconnect
1211		// timer fired, logging while we fired.
1212		t.Errorf("unexpected server output: %.500s\n", st.serverLogBuf.Bytes())
1213	}
1214}
1215
1216func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
1217	puppet := newHandlerPuppet()
1218	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1219		puppet.act(w, r)
1220	})
1221	defer st.Close()
1222	defer puppet.done()
1223
1224	st.greet()
1225
1226	st.writeHeaders(HeadersFrameParam{
1227		StreamID:      1, // clients send odd numbers
1228		BlockFragment: st.encodeHeader(":method", "POST"),
1229		EndStream:     false, // data coming
1230		EndHeaders:    true,
1231	})
1232	st.writeData(1, false, []byte("abcdef"))
1233	puppet.do(readBodyHandler(t, "abc"))
1234	st.wantWindowUpdate(0, 3)
1235	st.wantWindowUpdate(1, 3)
1236
1237	puppet.do(readBodyHandler(t, "def"))
1238	st.wantWindowUpdate(0, 3)
1239	st.wantWindowUpdate(1, 3)
1240
1241	st.writeData(1, true, []byte("ghijkl")) // END_STREAM here
1242	puppet.do(readBodyHandler(t, "ghi"))
1243	puppet.do(readBodyHandler(t, "jkl"))
1244	st.wantWindowUpdate(0, 3)
1245	st.wantWindowUpdate(0, 3) // no more stream-level, since END_STREAM
1246}
1247
1248// the version of the TestServer_Handler_Sends_WindowUpdate with padding.
1249// See golang.org/issue/16556
1250func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) {
1251	puppet := newHandlerPuppet()
1252	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1253		puppet.act(w, r)
1254	})
1255	defer st.Close()
1256	defer puppet.done()
1257
1258	st.greet()
1259
1260	st.writeHeaders(HeadersFrameParam{
1261		StreamID:      1,
1262		BlockFragment: st.encodeHeader(":method", "POST"),
1263		EndStream:     false,
1264		EndHeaders:    true,
1265	})
1266	st.writeDataPadded(1, false, []byte("abcdef"), []byte{0, 0, 0, 0})
1267
1268	// Expect to immediately get our 5 bytes of padding back for
1269	// both the connection and stream (4 bytes of padding + 1 byte of length)
1270	st.wantWindowUpdate(0, 5)
1271	st.wantWindowUpdate(1, 5)
1272
1273	puppet.do(readBodyHandler(t, "abc"))
1274	st.wantWindowUpdate(0, 3)
1275	st.wantWindowUpdate(1, 3)
1276
1277	puppet.do(readBodyHandler(t, "def"))
1278	st.wantWindowUpdate(0, 3)
1279	st.wantWindowUpdate(1, 3)
1280}
1281
1282func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) {
1283	st := newServerTester(t, nil)
1284	defer st.Close()
1285	st.greet()
1286	if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
1287		t.Fatal(err)
1288	}
1289	gf := st.wantGoAway()
1290	if gf.ErrCode != ErrCodeFlowControl {
1291		t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFlowControl)
1292	}
1293	if gf.LastStreamID != 0 {
1294		t.Errorf("GOAWAY last stream ID = %v; want %v", gf.LastStreamID, 0)
1295	}
1296}
1297
1298func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) {
1299	inHandler := make(chan bool)
1300	blockHandler := make(chan bool)
1301	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1302		inHandler <- true
1303		<-blockHandler
1304	})
1305	defer st.Close()
1306	defer close(blockHandler)
1307	st.greet()
1308	st.writeHeaders(HeadersFrameParam{
1309		StreamID:      1,
1310		BlockFragment: st.encodeHeader(":method", "POST"),
1311		EndStream:     false, // keep it open
1312		EndHeaders:    true,
1313	})
1314	<-inHandler
1315	// Send a bogus window update:
1316	if err := st.fr.WriteWindowUpdate(1, 1<<31-1); err != nil {
1317		t.Fatal(err)
1318	}
1319	st.wantRSTStream(1, ErrCodeFlowControl)
1320}
1321
1322// testServerPostUnblock sends a hanging POST with unsent data to handler,
1323// then runs fn once in the handler, and verifies that the error returned from
1324// handler is acceptable. It fails if takes over 5 seconds for handler to exit.
1325func testServerPostUnblock(t *testing.T,
1326	handler func(http.ResponseWriter, *http.Request) error,
1327	fn func(*serverTester),
1328	checkErr func(error),
1329	otherHeaders ...string) {
1330	inHandler := make(chan bool)
1331	errc := make(chan error, 1)
1332	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1333		inHandler <- true
1334		errc <- handler(w, r)
1335	})
1336	defer st.Close()
1337	st.greet()
1338	st.writeHeaders(HeadersFrameParam{
1339		StreamID:      1,
1340		BlockFragment: st.encodeHeader(append([]string{":method", "POST"}, otherHeaders...)...),
1341		EndStream:     false, // keep it open
1342		EndHeaders:    true,
1343	})
1344	<-inHandler
1345	fn(st)
1346	select {
1347	case err := <-errc:
1348		if checkErr != nil {
1349			checkErr(err)
1350		}
1351	case <-time.After(5 * time.Second):
1352		t.Fatal("timeout waiting for Handler to return")
1353	}
1354}
1355
1356func TestServer_RSTStream_Unblocks_Read(t *testing.T) {
1357	testServerPostUnblock(t,
1358		func(w http.ResponseWriter, r *http.Request) (err error) {
1359			_, err = r.Body.Read(make([]byte, 1))
1360			return
1361		},
1362		func(st *serverTester) {
1363			if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1364				t.Fatal(err)
1365			}
1366		},
1367		func(err error) {
1368			want := StreamError{StreamID: 0x1, Code: 0x8}
1369			if !reflect.DeepEqual(err, want) {
1370				t.Errorf("Read error = %v; want %v", err, want)
1371			}
1372		},
1373	)
1374}
1375
1376func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
1377	// Run this test a bunch, because it doesn't always
1378	// deadlock. But with a bunch, it did.
1379	n := 50
1380	if testing.Short() {
1381		n = 5
1382	}
1383	for i := 0; i < n; i++ {
1384		testServer_RSTStream_Unblocks_Header_Write(t)
1385	}
1386}
1387
// testServer_RSTStream_Unblocks_Header_Write runs one iteration of a deadlock
// regression check. The handler for stream 1 waits until the client has sent
// RST_STREAM(CANCEL), then sets a header, calls WriteHeader(200) and flushes.
// The flush must complete (signalled on headerWritten) within 2 seconds
// rather than hanging on the already-reset stream.
func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
	// Each channel receives exactly one send per iteration. A buffer of 1
	// means that send never blocks, even if the receiver has given up.
	inHandler := make(chan bool, 1)
	unblockHandler := make(chan bool, 1)
	headerWritten := make(chan bool, 1)
	wroteRST := make(chan bool, 1)

	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		inHandler <- true
		// Don't touch the response until the client has reset the stream.
		<-wroteRST
		w.Header().Set("foo", "bar")
		w.WriteHeader(200)
		w.(http.Flusher).Flush()
		headerWritten <- true
		// Stay in the handler until the test has observed headerWritten.
		<-unblockHandler
	})
	defer st.Close()

	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false, // keep it open
		EndHeaders:    true,
	})
	<-inHandler
	if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
		t.Fatal(err)
	}
	wroteRST <- true
	// NOTE(review): awaitIdle is defined elsewhere. Presumably it waits for
	// the server's connection loop to go idle; confirm.
	st.awaitIdle()
	select {
	case <-headerWritten:
	case <-time.After(2 * time.Second):
		t.Error("timeout waiting for header write")
	}
	unblockHandler <- true
}
1425
1426func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
1427	testServerPostUnblock(t,
1428		func(w http.ResponseWriter, r *http.Request) (err error) {
1429			_, err = r.Body.Read(make([]byte, 1))
1430			return
1431		},
1432		func(st *serverTester) { st.cc.Close() },
1433		func(err error) {
1434			if err == nil {
1435				t.Error("unexpected nil error from Request.Body.Read")
1436			}
1437		},
1438	)
1439}
1440
// blockUntilClosed is a testServerPostUnblock handler. It blocks until the
// http.CloseNotifier channel fires, then returns nil. CloseNotifier is
// deprecated, but these tests exercise it deliberately.
var blockUntilClosed = func(w http.ResponseWriter, r *http.Request) error {
	notifier := w.(http.CloseNotifier)
	closed := notifier.CloseNotify()
	<-closed
	return nil
}
1445
1446func TestServer_CloseNotify_After_RSTStream(t *testing.T) {
1447	testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
1448		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1449			t.Fatal(err)
1450		}
1451	}, nil)
1452}
1453
1454func TestServer_CloseNotify_After_ConnClose(t *testing.T) {
1455	testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { st.cc.Close() }, nil)
1456}
1457
1458// that CloseNotify unblocks after a stream error due to the client's
1459// problem that's unrelated to them explicitly canceling it (which is
1460// TestServer_CloseNotify_After_RSTStream above)
1461func TestServer_CloseNotify_After_StreamError(t *testing.T) {
1462	testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
1463		// data longer than declared Content-Length => stream error
1464		st.writeData(1, true, []byte("1234"))
1465	}, nil, "content-length", "3")
1466}
1467
1468func TestServer_StateTransitions(t *testing.T) {
1469	var st *serverTester
1470	inHandler := make(chan bool)
1471	writeData := make(chan bool)
1472	leaveHandler := make(chan bool)
1473	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1474		inHandler <- true
1475		if st.stream(1) == nil {
1476			t.Errorf("nil stream 1 in handler")
1477		}
1478		if got, want := st.streamState(1), stateOpen; got != want {
1479			t.Errorf("in handler, state is %v; want %v", got, want)
1480		}
1481		writeData <- true
1482		if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
1483			t.Errorf("body read = %d, %v; want 0, EOF", n, err)
1484		}
1485		if got, want := st.streamState(1), stateHalfClosedRemote; got != want {
1486			t.Errorf("in handler, state is %v; want %v", got, want)
1487		}
1488
1489		<-leaveHandler
1490	})
1491	st.greet()
1492	if st.stream(1) != nil {
1493		t.Fatal("stream 1 should be empty")
1494	}
1495	if got := st.streamState(1); got != stateIdle {
1496		t.Fatalf("stream 1 should be idle; got %v", got)
1497	}
1498
1499	st.writeHeaders(HeadersFrameParam{
1500		StreamID:      1,
1501		BlockFragment: st.encodeHeader(":method", "POST"),
1502		EndStream:     false, // keep it open
1503		EndHeaders:    true,
1504	})
1505	<-inHandler
1506	<-writeData
1507	st.writeData(1, true, nil)
1508
1509	leaveHandler <- true
1510	hf := st.wantHeaders()
1511	if !hf.StreamEnded() {
1512		t.Fatal("expected END_STREAM flag")
1513	}
1514
1515	if got, want := st.streamState(1), stateClosed; got != want {
1516		t.Errorf("at end, state is %v; want %v", got, want)
1517	}
1518	if st.stream(1) != nil {
1519		t.Fatal("at end, stream 1 should be gone")
1520	}
1521}
1522
1523// test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
1524func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
1525	testServerRejectsConn(t, func(st *serverTester) {
1526		st.writeHeaders(HeadersFrameParam{
1527			StreamID:      1,
1528			BlockFragment: st.encodeHeader(),
1529			EndStream:     true,
1530			EndHeaders:    false,
1531		})
1532		st.writeHeaders(HeadersFrameParam{ // Not a continuation.
1533			StreamID:      3, // different stream.
1534			BlockFragment: st.encodeHeader(),
1535			EndStream:     true,
1536			EndHeaders:    true,
1537		})
1538	})
1539}
1540
1541// test HEADERS w/o EndHeaders + PING (should get rejected)
1542func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
1543	testServerRejectsConn(t, func(st *serverTester) {
1544		st.writeHeaders(HeadersFrameParam{
1545			StreamID:      1,
1546			BlockFragment: st.encodeHeader(),
1547			EndStream:     true,
1548			EndHeaders:    false,
1549		})
1550		if err := st.fr.WritePing(false, [8]byte{}); err != nil {
1551			t.Fatal(err)
1552		}
1553	})
1554}
1555
1556// test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
1557func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
1558	testServerRejectsConn(t, func(st *serverTester) {
1559		st.writeHeaders(HeadersFrameParam{
1560			StreamID:      1,
1561			BlockFragment: st.encodeHeader(),
1562			EndStream:     true,
1563			EndHeaders:    true,
1564		})
1565		st.wantHeaders()
1566		if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
1567			t.Fatal(err)
1568		}
1569	})
1570}
1571
1572// test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID
1573func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
1574	testServerRejectsConn(t, func(st *serverTester) {
1575		st.writeHeaders(HeadersFrameParam{
1576			StreamID:      1,
1577			BlockFragment: st.encodeHeader(),
1578			EndStream:     true,
1579			EndHeaders:    false,
1580		})
1581		if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
1582			t.Fatal(err)
1583		}
1584	})
1585}
1586
1587// No HEADERS on stream 0.
1588func TestServer_Rejects_Headers0(t *testing.T) {
1589	testServerRejectsConn(t, func(st *serverTester) {
1590		st.fr.AllowIllegalWrites = true
1591		st.writeHeaders(HeadersFrameParam{
1592			StreamID:      0,
1593			BlockFragment: st.encodeHeader(),
1594			EndStream:     true,
1595			EndHeaders:    true,
1596		})
1597	})
1598}
1599
1600// No CONTINUATION on stream 0.
1601func TestServer_Rejects_Continuation0(t *testing.T) {
1602	testServerRejectsConn(t, func(st *serverTester) {
1603		st.fr.AllowIllegalWrites = true
1604		if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil {
1605			t.Fatal(err)
1606		}
1607	})
1608}
1609
1610// No PRIORITY on stream 0.
1611func TestServer_Rejects_Priority0(t *testing.T) {
1612	testServerRejectsConn(t, func(st *serverTester) {
1613		st.fr.AllowIllegalWrites = true
1614		st.writePriority(0, PriorityParam{StreamDep: 1})
1615	})
1616}
1617
1618// No HEADERS frame with a self-dependence.
1619func TestServer_Rejects_HeadersSelfDependence(t *testing.T) {
1620	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1621		st.fr.AllowIllegalWrites = true
1622		st.writeHeaders(HeadersFrameParam{
1623			StreamID:      1,
1624			BlockFragment: st.encodeHeader(),
1625			EndStream:     true,
1626			EndHeaders:    true,
1627			Priority:      PriorityParam{StreamDep: 1},
1628		})
1629	})
1630}
1631
1632// No PRIORTY frame with a self-dependence.
1633func TestServer_Rejects_PrioritySelfDependence(t *testing.T) {
1634	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1635		st.fr.AllowIllegalWrites = true
1636		st.writePriority(1, PriorityParam{StreamDep: 1})
1637	})
1638}
1639
1640func TestServer_Rejects_PushPromise(t *testing.T) {
1641	testServerRejectsConn(t, func(st *serverTester) {
1642		pp := PushPromiseParam{
1643			StreamID:  1,
1644			PromiseID: 3,
1645		}
1646		if err := st.fr.WritePushPromise(pp); err != nil {
1647			t.Fatal(err)
1648		}
1649	})
1650}
1651
1652// testServerRejectsConn tests that the server hangs up with a GOAWAY
1653// frame and a server close after the client does something
1654// deserving a CONNECTION_ERROR.
1655func testServerRejectsConn(t *testing.T, writeReq func(*serverTester)) {
1656	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
1657	st.addLogFilter("connection error: PROTOCOL_ERROR")
1658	defer st.Close()
1659	st.greet()
1660	writeReq(st)
1661
1662	st.wantGoAway()
1663	errc := make(chan error, 1)
1664	go func() {
1665		fr, err := st.fr.ReadFrame()
1666		if err == nil {
1667			err = fmt.Errorf("got frame of type %T", fr)
1668		}
1669		errc <- err
1670	}()
1671	select {
1672	case err := <-errc:
1673		if err != io.EOF {
1674			t.Errorf("ReadFrame = %v; want io.EOF", err)
1675		}
1676	case <-time.After(2 * time.Second):
1677		t.Error("timeout waiting for disconnect")
1678	}
1679}
1680
1681// testServerRejectsStream tests that the server sends a RST_STREAM with the provided
1682// error code after a client sends a bogus request.
1683func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTester)) {
1684	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
1685	defer st.Close()
1686	st.greet()
1687	writeReq(st)
1688	st.wantRSTStream(1, code)
1689}
1690
1691// testServerRequest sets up an idle HTTP/2 connection and lets you
1692// write a single request with writeReq, and then verify that the
1693// *http.Request is built correctly in checkReq.
1694func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) {
1695	gotReq := make(chan bool, 1)
1696	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1697		if r.Body == nil {
1698			t.Fatal("nil Body")
1699		}
1700		checkReq(r)
1701		gotReq <- true
1702	})
1703	defer st.Close()
1704
1705	st.greet()
1706	writeReq(st)
1707
1708	select {
1709	case <-gotReq:
1710	case <-time.After(2 * time.Second):
1711		t.Error("timeout waiting for request")
1712	}
1713}
1714
1715func getSlash(st *serverTester) { st.bodylessReq1() }
1716
1717func TestServer_Response_NoData(t *testing.T) {
1718	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1719		// Nothing.
1720		return nil
1721	}, func(st *serverTester) {
1722		getSlash(st)
1723		hf := st.wantHeaders()
1724		if !hf.StreamEnded() {
1725			t.Fatal("want END_STREAM flag")
1726		}
1727		if !hf.HeadersEnded() {
1728			t.Fatal("want END_HEADERS flag")
1729		}
1730	})
1731}
1732
1733func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
1734	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1735		w.Header().Set("Foo-Bar", "some-value")
1736		return nil
1737	}, func(st *serverTester) {
1738		getSlash(st)
1739		hf := st.wantHeaders()
1740		if !hf.StreamEnded() {
1741			t.Fatal("want END_STREAM flag")
1742		}
1743		if !hf.HeadersEnded() {
1744			t.Fatal("want END_HEADERS flag")
1745		}
1746		goth := st.decodeHeader(hf.HeaderBlockFragment())
1747		wanth := [][2]string{
1748			{":status", "200"},
1749			{"foo-bar", "some-value"},
1750			{"content-length", "0"},
1751		}
1752		if !reflect.DeepEqual(goth, wanth) {
1753			t.Errorf("Got headers %v; want %v", goth, wanth)
1754		}
1755	})
1756}
1757
1758func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) {
1759	const msg = "<html>this is HTML."
1760	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1761		w.Header().Set("Content-Type", "foo/bar")
1762		io.WriteString(w, msg)
1763		return nil
1764	}, func(st *serverTester) {
1765		getSlash(st)
1766		hf := st.wantHeaders()
1767		if hf.StreamEnded() {
1768			t.Fatal("don't want END_STREAM, expecting data")
1769		}
1770		if !hf.HeadersEnded() {
1771			t.Fatal("want END_HEADERS flag")
1772		}
1773		goth := st.decodeHeader(hf.HeaderBlockFragment())
1774		wanth := [][2]string{
1775			{":status", "200"},
1776			{"content-type", "foo/bar"},
1777			{"content-length", strconv.Itoa(len(msg))},
1778		}
1779		if !reflect.DeepEqual(goth, wanth) {
1780			t.Errorf("Got headers %v; want %v", goth, wanth)
1781		}
1782		df := st.wantData()
1783		if !df.StreamEnded() {
1784			t.Error("expected DATA to have END_STREAM flag")
1785		}
1786		if got := string(df.Data()); got != msg {
1787			t.Errorf("got DATA %q; want %q", got, msg)
1788		}
1789	})
1790}
1791
1792func TestServer_Response_TransferEncoding_chunked(t *testing.T) {
1793	const msg = "hi"
1794	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1795		w.Header().Set("Transfer-Encoding", "chunked") // should be stripped
1796		io.WriteString(w, msg)
1797		return nil
1798	}, func(st *serverTester) {
1799		getSlash(st)
1800		hf := st.wantHeaders()
1801		goth := st.decodeHeader(hf.HeaderBlockFragment())
1802		wanth := [][2]string{
1803			{":status", "200"},
1804			{"content-type", "text/plain; charset=utf-8"},
1805			{"content-length", strconv.Itoa(len(msg))},
1806		}
1807		if !reflect.DeepEqual(goth, wanth) {
1808			t.Errorf("Got headers %v; want %v", goth, wanth)
1809		}
1810	})
1811}
1812
1813// Header accessed only after the initial write.
1814func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) {
1815	const msg = "<html>this is HTML."
1816	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1817		io.WriteString(w, msg)
1818		w.Header().Set("foo", "should be ignored")
1819		return nil
1820	}, func(st *serverTester) {
1821		getSlash(st)
1822		hf := st.wantHeaders()
1823		if hf.StreamEnded() {
1824			t.Fatal("unexpected END_STREAM")
1825		}
1826		if !hf.HeadersEnded() {
1827			t.Fatal("want END_HEADERS flag")
1828		}
1829		goth := st.decodeHeader(hf.HeaderBlockFragment())
1830		wanth := [][2]string{
1831			{":status", "200"},
1832			{"content-type", "text/html; charset=utf-8"},
1833			{"content-length", strconv.Itoa(len(msg))},
1834		}
1835		if !reflect.DeepEqual(goth, wanth) {
1836			t.Errorf("Got headers %v; want %v", goth, wanth)
1837		}
1838	})
1839}
1840
1841// Header accessed before the initial write and later mutated.
1842func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) {
1843	const msg = "<html>this is HTML."
1844	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1845		w.Header().Set("foo", "proper value")
1846		io.WriteString(w, msg)
1847		w.Header().Set("foo", "should be ignored")
1848		return nil
1849	}, func(st *serverTester) {
1850		getSlash(st)
1851		hf := st.wantHeaders()
1852		if hf.StreamEnded() {
1853			t.Fatal("unexpected END_STREAM")
1854		}
1855		if !hf.HeadersEnded() {
1856			t.Fatal("want END_HEADERS flag")
1857		}
1858		goth := st.decodeHeader(hf.HeaderBlockFragment())
1859		wanth := [][2]string{
1860			{":status", "200"},
1861			{"foo", "proper value"},
1862			{"content-type", "text/html; charset=utf-8"},
1863			{"content-length", strconv.Itoa(len(msg))},
1864		}
1865		if !reflect.DeepEqual(goth, wanth) {
1866			t.Errorf("Got headers %v; want %v", goth, wanth)
1867		}
1868	})
1869}
1870
1871func TestServer_Response_Data_SniffLenType(t *testing.T) {
1872	const msg = "<html>this is HTML."
1873	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1874		io.WriteString(w, msg)
1875		return nil
1876	}, func(st *serverTester) {
1877		getSlash(st)
1878		hf := st.wantHeaders()
1879		if hf.StreamEnded() {
1880			t.Fatal("don't want END_STREAM, expecting data")
1881		}
1882		if !hf.HeadersEnded() {
1883			t.Fatal("want END_HEADERS flag")
1884		}
1885		goth := st.decodeHeader(hf.HeaderBlockFragment())
1886		wanth := [][2]string{
1887			{":status", "200"},
1888			{"content-type", "text/html; charset=utf-8"},
1889			{"content-length", strconv.Itoa(len(msg))},
1890		}
1891		if !reflect.DeepEqual(goth, wanth) {
1892			t.Errorf("Got headers %v; want %v", goth, wanth)
1893		}
1894		df := st.wantData()
1895		if !df.StreamEnded() {
1896			t.Error("expected DATA to have END_STREAM flag")
1897		}
1898		if got := string(df.Data()); got != msg {
1899			t.Errorf("got DATA %q; want %q", got, msg)
1900		}
1901	})
1902}
1903
1904func TestServer_Response_Header_Flush_MidWrite(t *testing.T) {
1905	const msg = "<html>this is HTML"
1906	const msg2 = ", and this is the next chunk"
1907	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1908		io.WriteString(w, msg)
1909		w.(http.Flusher).Flush()
1910		io.WriteString(w, msg2)
1911		return nil
1912	}, func(st *serverTester) {
1913		getSlash(st)
1914		hf := st.wantHeaders()
1915		if hf.StreamEnded() {
1916			t.Fatal("unexpected END_STREAM flag")
1917		}
1918		if !hf.HeadersEnded() {
1919			t.Fatal("want END_HEADERS flag")
1920		}
1921		goth := st.decodeHeader(hf.HeaderBlockFragment())
1922		wanth := [][2]string{
1923			{":status", "200"},
1924			{"content-type", "text/html; charset=utf-8"}, // sniffed
1925			// and no content-length
1926		}
1927		if !reflect.DeepEqual(goth, wanth) {
1928			t.Errorf("Got headers %v; want %v", goth, wanth)
1929		}
1930		{
1931			df := st.wantData()
1932			if df.StreamEnded() {
1933				t.Error("unexpected END_STREAM flag")
1934			}
1935			if got := string(df.Data()); got != msg {
1936				t.Errorf("got DATA %q; want %q", got, msg)
1937			}
1938		}
1939		{
1940			df := st.wantData()
1941			if !df.StreamEnded() {
1942				t.Error("wanted END_STREAM flag on last data chunk")
1943			}
1944			if got := string(df.Data()); got != msg2 {
1945				t.Errorf("got DATA %q; want %q", got, msg2)
1946			}
1947		}
1948	})
1949}
1950
1951func TestServer_Response_LargeWrite(t *testing.T) {
1952	const size = 1 << 20
1953	const maxFrameSize = 16 << 10
1954	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1955		n, err := w.Write(bytes.Repeat([]byte("a"), size))
1956		if err != nil {
1957			return fmt.Errorf("Write error: %v", err)
1958		}
1959		if n != size {
1960			return fmt.Errorf("wrong size %d from Write", n)
1961		}
1962		return nil
1963	}, func(st *serverTester) {
1964		if err := st.fr.WriteSettings(
1965			Setting{SettingInitialWindowSize, 0},
1966			Setting{SettingMaxFrameSize, maxFrameSize},
1967		); err != nil {
1968			t.Fatal(err)
1969		}
1970		st.wantSettingsAck()
1971
1972		getSlash(st) // make the single request
1973
1974		// Give the handler quota to write:
1975		if err := st.fr.WriteWindowUpdate(1, size); err != nil {
1976			t.Fatal(err)
1977		}
1978		// Give the handler quota to write to connection-level
1979		// window as well
1980		if err := st.fr.WriteWindowUpdate(0, size); err != nil {
1981			t.Fatal(err)
1982		}
1983		hf := st.wantHeaders()
1984		if hf.StreamEnded() {
1985			t.Fatal("unexpected END_STREAM flag")
1986		}
1987		if !hf.HeadersEnded() {
1988			t.Fatal("want END_HEADERS flag")
1989		}
1990		goth := st.decodeHeader(hf.HeaderBlockFragment())
1991		wanth := [][2]string{
1992			{":status", "200"},
1993			{"content-type", "text/plain; charset=utf-8"}, // sniffed
1994			// and no content-length
1995		}
1996		if !reflect.DeepEqual(goth, wanth) {
1997			t.Errorf("Got headers %v; want %v", goth, wanth)
1998		}
1999		var bytes, frames int
2000		for {
2001			df := st.wantData()
2002			bytes += len(df.Data())
2003			frames++
2004			for _, b := range df.Data() {
2005				if b != 'a' {
2006					t.Fatal("non-'a' byte seen in DATA")
2007				}
2008			}
2009			if df.StreamEnded() {
2010				break
2011			}
2012		}
2013		if bytes != size {
2014			t.Errorf("Got %d bytes; want %d", bytes, size)
2015		}
2016		if want := int(size / maxFrameSize); frames < want || frames > want*2 {
2017			t.Errorf("Got %d frames; want %d", frames, size)
2018		}
2019	})
2020}
2021
2022// Test that the handler can't write more than the client allows
2023func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) {
2024	// Make these reads. Before each read, the client adds exactly enough
2025	// flow-control to satisfy the read. Numbers chosen arbitrarily.
2026	reads := []int{123, 1, 13, 127}
2027	size := 0
2028	for _, n := range reads {
2029		size += n
2030	}
2031
2032	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2033		w.(http.Flusher).Flush()
2034		n, err := w.Write(bytes.Repeat([]byte("a"), size))
2035		if err != nil {
2036			return fmt.Errorf("Write error: %v", err)
2037		}
2038		if n != size {
2039			return fmt.Errorf("wrong size %d from Write", n)
2040		}
2041		return nil
2042	}, func(st *serverTester) {
2043		// Set the window size to something explicit for this test.
2044		// It's also how much initial data we expect.
2045		if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, uint32(reads[0])}); err != nil {
2046			t.Fatal(err)
2047		}
2048		st.wantSettingsAck()
2049
2050		getSlash(st) // make the single request
2051
2052		hf := st.wantHeaders()
2053		if hf.StreamEnded() {
2054			t.Fatal("unexpected END_STREAM flag")
2055		}
2056		if !hf.HeadersEnded() {
2057			t.Fatal("want END_HEADERS flag")
2058		}
2059
2060		df := st.wantData()
2061		if got := len(df.Data()); got != reads[0] {
2062			t.Fatalf("Initial window size = %d but got DATA with %d bytes", reads[0], got)
2063		}
2064
2065		for _, quota := range reads[1:] {
2066			if err := st.fr.WriteWindowUpdate(1, uint32(quota)); err != nil {
2067				t.Fatal(err)
2068			}
2069			df := st.wantData()
2070			if int(quota) != len(df.Data()) {
2071				t.Fatalf("read %d bytes after giving %d quota", len(df.Data()), quota)
2072			}
2073		}
2074	})
2075}
2076
// Test that the handler blocked in a Write is unblocked if the client sends a RST_STREAM.
//
// The client advertises an initial flow-control window of 0, so the handler's
// 1MB Write cannot make progress. The client then resets stream 1 with CANCEL,
// and the handler expects its Write to return a non-nil error within 2s.
func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) {
	const size = 1 << 20
	const maxFrameSize = 16 << 10
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		// Flush so the response HEADERS are sent before the blocked Write.
		w.(http.Flusher).Flush()
		// Write from a separate goroutine so the handler can give up after a
		// timeout instead of blocking forever if the reset doesn't unblock it.
		errc := make(chan error, 1)
		go func() {
			_, err := w.Write(bytes.Repeat([]byte("a"), size))
			errc <- err
		}()
		select {
		case err := <-errc:
			if err == nil {
				return errors.New("unexpected nil error from Write in handler")
			}
			return nil
		case <-time.After(2 * time.Second):
			return errors.New("timeout waiting for Write in handler")
		}
	}, func(st *serverTester) {
		// Zero window: the handler's Write can't send any DATA bytes.
		if err := st.fr.WriteSettings(
			Setting{SettingInitialWindowSize, 0},
			Setting{SettingMaxFrameSize, maxFrameSize},
		); err != nil {
			t.Fatal(err)
		}
		st.wantSettingsAck()

		getSlash(st) // make the single request

		hf := st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("unexpected END_STREAM flag")
		}
		if !hf.HeadersEnded() {
			t.Fatal("want END_HEADERS flag")
		}

		// Reset the stream while the handler is stuck in Write.
		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
			t.Fatal(err)
		}
	})
}
2121
// TestServer_Response_Empty_Data_Not_FlowControlled checks that, even with a
// flow-control window of 0, a handler that writes no body still produces a
// zero-length DATA frame carrying END_STREAM after the response HEADERS.
func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.(http.Flusher).Flush()
		// Nothing; send empty DATA
		return nil
	}, func(st *serverTester) {
		// Handler gets no data quota:
		if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, 0}); err != nil {
			t.Fatal(err)
		}
		st.wantSettingsAck()

		getSlash(st) // make the single request

		hf := st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("unexpected END_STREAM flag")
		}
		if !hf.HeadersEnded() {
			t.Fatal("want END_HEADERS flag")
		}

		// The empty DATA frame must arrive despite the zero window.
		df := st.wantData()
		if got := len(df.Data()); got != 0 {
			t.Fatalf("unexpected %d DATA bytes; want 0", got)
		}
		if !df.StreamEnded() {
			t.Fatal("DATA didn't have END_STREAM")
		}
	})
}
2153
// TestServer_Response_Automatic100Continue checks the "Expect: 100-continue"
// handling. The Expect header is not visible to the handler. A ":status 100"
// HEADERS frame is sent once the handler reads the body. The final 200
// response and its body then follow.
func TestServer_Response_Automatic100Continue(t *testing.T) {
	const msg = "foo"
	const reply = "bar"
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		if v := r.Header.Get("Expect"); v != "" {
			t.Errorf("Expect header = %q; want empty", v)
		}
		buf := make([]byte, len(msg))
		// This read should trigger the 100-continue being sent.
		if n, err := io.ReadFull(r.Body, buf); err != nil || n != len(msg) || string(buf) != msg {
			return fmt.Errorf("ReadFull = %q, %v; want %q, nil", buf[:n], err, msg)
		}
		_, err := io.WriteString(w, reply)
		return err
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1, // clients send odd numbers
			BlockFragment: st.encodeHeader(":method", "POST", "expect", "100-continue"),
			EndStream:     false,
			EndHeaders:    true,
		})
		hf := st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("unexpected END_STREAM flag")
		}
		if !hf.HeadersEnded() {
			t.Fatal("want END_HEADERS flag")
		}
		goth := st.decodeHeader(hf.HeaderBlockFragment())
		wanth := [][2]string{
			{":status", "100"},
		}
		if !reflect.DeepEqual(goth, wanth) {
			t.Fatalf("Got headers %v; want %v", goth, wanth)
		}

		// Okay, they sent status 100, so we can send our
		// gigantic and/or sensitive "foo" payload now.
		st.writeData(1, true, []byte(msg))

		// Expect a connection-level (stream 0) WINDOW_UPDATE for the bytes just sent.
		st.wantWindowUpdate(0, uint32(len(msg)))

		// Then the final response HEADERS, with the body still to come.
		hf = st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("expected data to follow")
		}
		if !hf.HeadersEnded() {
			t.Fatal("want END_HEADERS flag")
		}
		goth = st.decodeHeader(hf.HeaderBlockFragment())
		wanth = [][2]string{
			{":status", "200"},
			{"content-type", "text/plain; charset=utf-8"},
			{"content-length", strconv.Itoa(len(reply))},
		}
		if !reflect.DeepEqual(goth, wanth) {
			t.Errorf("Got headers %v; want %v", goth, wanth)
		}

		df := st.wantData()
		if string(df.Data()) != reply {
			t.Errorf("Client read %q; want %q", df.Data(), reply)
		}
		if !df.StreamEnded() {
			t.Errorf("expect data stream end")
		}
	})
}
2222
// TestServer_HandlerWriteErrorOnDisconnect checks that a handler writing in an
// endless loop eventually gets a Write error after the client closes the
// connection. The error is reported on errc, which the client waits on for
// up to 5s.
func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
	errc := make(chan error, 1)
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		p := []byte("some data.\n")
		// Keep writing until Write fails; the failure is the signal under test.
		for {
			_, err := w.Write(p)
			if err != nil {
				errc <- err
				return nil
			}
		}
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(),
			EndStream:     false,
			EndHeaders:    true,
		})
		hf := st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("unexpected END_STREAM flag")
		}
		if !hf.HeadersEnded() {
			t.Fatal("want END_HEADERS flag")
		}
		// Close the connection and wait for the handler to (hopefully) notice.
		st.cc.Close()
		select {
		case <-errc:
		case <-time.After(5 * time.Second):
			t.Error("timeout")
		}
	})
}
2257
// TestServer_Rejects_Too_Many_Streams opens defaultMaxStreams streams whose
// handlers all block. It then checks three things:
//   - one more stream, sent as HEADERS+CONTINUATION, is reset with
//     RST_STREAM(PROTOCOL_ERROR);
//   - the HPACK decoder context stays in sync despite the rejection;
//   - once one handler finishes, a new stream is accepted and its ":path"
//     decodes correctly.
//
// NOTE(review): presumably defaultMaxStreams is the server's default
// max-concurrent-streams limit; confirm against server.go.
func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
	const testPath = "/some/path"

	inHandler := make(chan uint32)
	leaveHandler := make(chan bool)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		id := w.(*responseWriter).rws.stream.id
		inHandler <- id
		// The stream after the rejected one must still decode its path,
		// proving the HPACK context survived the rejection.
		if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath {
			t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath)
		}
		<-leaveHandler
	})
	defer st.Close()
	st.greet()
	nextStreamID := uint32(1)
	// streamID returns the next odd client stream ID (1, 3, 5, ...).
	streamID := func() uint32 {
		defer func() { nextStreamID += 2 }()
		return nextStreamID
	}
	sendReq := func(id uint32, headers ...string) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      id,
			BlockFragment: st.encodeHeader(headers...),
			EndStream:     true,
			EndHeaders:    true,
		})
	}
	// Fill every available stream slot with a blocked handler.
	for i := 0; i < defaultMaxStreams; i++ {
		sendReq(streamID())
		<-inHandler
	}
	defer func() {
		for i := 0; i < defaultMaxStreams; i++ {
			leaveHandler <- true
		}
	}()

	// And this one should cross the limit:
	// (It's also sent as a CONTINUATION, to verify we still track the decoder context,
	// even if we're rejecting it)
	rejectID := streamID()
	headerBlock := st.encodeHeader(":path", testPath)
	frag1, frag2 := headerBlock[:3], headerBlock[3:]
	st.writeHeaders(HeadersFrameParam{
		StreamID:      rejectID,
		BlockFragment: frag1,
		EndStream:     true,
		EndHeaders:    false, // CONTINUATION coming
	})
	if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil {
		t.Fatal(err)
	}
	st.wantRSTStream(rejectID, ErrCodeProtocol)

	// But let a handler finish:
	leaveHandler <- true
	st.wantHeaders()

	// And now another stream should be able to start:
	goodID := streamID()
	sendReq(goodID, ":path", testPath)
	select {
	case got := <-inHandler:
		if got != goodID {
			t.Errorf("Got stream %d; want %d", got, goodID)
		}
	case <-time.After(3 * time.Second):
		t.Error("timeout waiting for handler")
	}
}
2329
2330// So many response headers that the server needs to use CONTINUATION frames:
2331func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) {
2332	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2333		h := w.Header()
2334		for i := 0; i < 5000; i++ {
2335			h.Set(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("x-value-%d", i))
2336		}
2337		return nil
2338	}, func(st *serverTester) {
2339		getSlash(st)
2340		hf := st.wantHeaders()
2341		if hf.HeadersEnded() {
2342			t.Fatal("got unwanted END_HEADERS flag")
2343		}
2344		n := 0
2345		for {
2346			n++
2347			cf := st.wantContinuation()
2348			if cf.HeadersEnded() {
2349				break
2350			}
2351		}
2352		if n < 5 {
2353			t.Errorf("Only got %d CONTINUATION frames; expected 5+ (currently 6)", n)
2354		}
2355	})
2356}
2357
// This previously crashed (reported by Mathieu Lonjaret as observed
// while using Camlistore) because we got a DATA frame from the client
// after the handler exited and our logic at the time was wrong,
// keeping a stream in the map in stateClosed, which tickled an
// invariant check later when we tried to remove that stream (via
// defer sc.closeAllStreamsOnConnClose) when the serverConn serve loop
// ended.
//
// The test installs testHookOnPanic to record any panic, closes the client
// connection, and fails if the serve loop reported one before doneServing.
func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		// nothing
		return nil
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(),
			EndStream:     false, // DATA is coming
			EndHeaders:    true,
		})
		hf := st.wantHeaders()
		if !hf.HeadersEnded() || !hf.StreamEnded() {
			t.Fatalf("want END_HEADERS+END_STREAM, got %v", hf)
		}

		// Sent when the Handler returns while the client has
		// indicated it's still sending DATA:
		st.wantRSTStream(1, ErrCodeNo)

		// Now the handler has ended, so it's ended its
		// stream, but the client hasn't closed its side
		// (stateClosedLocal).  So send more data and verify
		// it doesn't crash with an internal invariant panic, like
		// it did before.
		st.writeData(1, true, []byte("foo"))

		// Get our flow control bytes back, since the handler didn't get them.
		st.wantWindowUpdate(0, uint32(len("foo")))

		// Sent after a peer sends data anyway (admittedly the
		// previous RST_STREAM might've still been in-flight),
		// but they'll get the more friendly NO_ERROR code
		// (checked above) first.
		st.wantRSTStream(1, ErrCodeStreamClosed)

		// Set up a bunch of machinery to record the panic we saw
		// previously.
		var (
			panMu    sync.Mutex
			panicVal interface{}
		)

		testHookOnPanicMu.Lock()
		testHookOnPanic = func(sc *serverConn, pv interface{}) bool {
			panMu.Lock()
			panicVal = pv
			panMu.Unlock()
			return true
		}
		testHookOnPanicMu.Unlock()

		// Now force the serve loop to end, via closing the connection.
		st.cc.Close()
		select {
		case <-st.sc.doneServing:
			// Loop has exited.
			panMu.Lock()
			got := panicVal
			panMu.Unlock()
			if got != nil {
				t.Errorf("Got panic: %v", got)
			}
		case <-time.After(5 * time.Second):
			t.Error("timeout")
		}
	})
}
2433
2434func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) }
2435func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) }
2436
2437func testRejectTLS(t *testing.T, max uint16) {
2438	st := newServerTester(t, nil, func(c *tls.Config) {
2439		c.MaxVersion = max
2440	})
2441	defer st.Close()
2442	gf := st.wantGoAway()
2443	if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
2444		t.Errorf("Got error code %v; want %v", got, want)
2445	}
2446}
2447
2448func TestServer_Rejects_TLSBadCipher(t *testing.T) {
2449	st := newServerTester(t, nil, func(c *tls.Config) {
2450		// All TLS 1.3 ciphers are good. Test with TLS 1.2.
2451		c.MaxVersion = tls.VersionTLS12
2452		// Only list bad ones:
2453		c.CipherSuites = []uint16{
2454			tls.TLS_RSA_WITH_RC4_128_SHA,
2455			tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
2456			tls.TLS_RSA_WITH_AES_128_CBC_SHA,
2457			tls.TLS_RSA_WITH_AES_256_CBC_SHA,
2458			tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
2459			tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
2460			tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
2461			tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
2462			tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
2463			tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
2464			tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
2465			cipher_TLS_RSA_WITH_AES_128_CBC_SHA256,
2466		}
2467	})
2468	defer st.Close()
2469	gf := st.wantGoAway()
2470	if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
2471		t.Errorf("Got error code %v; want %v", got, want)
2472	}
2473}
2474
2475func TestServer_Advertises_Common_Cipher(t *testing.T) {
2476	const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
2477	st := newServerTester(t, nil, func(c *tls.Config) {
2478		// Have the client only support the one required by the spec.
2479		c.CipherSuites = []uint16{requiredSuite}
2480	}, func(ts *httptest.Server) {
2481		var srv *http.Server = ts.Config
2482		// Have the server configured with no specific cipher suites.
2483		// This tests that Go's defaults include the required one.
2484		srv.TLSConfig = nil
2485	})
2486	defer st.Close()
2487	st.greet()
2488}
2489
2490func (st *serverTester) onHeaderField(f hpack.HeaderField) {
2491	if f.Name == "date" {
2492		return
2493	}
2494	st.decodedHeaders = append(st.decodedHeaders, [2]string{f.Name, f.Value})
2495}
2496
2497func (st *serverTester) decodeHeader(headerBlock []byte) (pairs [][2]string) {
2498	st.decodedHeaders = nil
2499	if _, err := st.hpackDec.Write(headerBlock); err != nil {
2500		st.t.Fatalf("hpack decoding error: %v", err)
2501	}
2502	if err := st.hpackDec.Close(); err != nil {
2503		st.t.Fatalf("hpack decoding error: %v", err)
2504	}
2505	return st.decodedHeaders
2506}
2507
2508// testServerResponse sets up an idle HTTP/2 connection. The client function should
2509// write a single request that must be handled by the handler. This waits up to 5s
2510// for client to return, then up to an additional 2s for the handler to return.
2511func testServerResponse(t testing.TB,
2512	handler func(http.ResponseWriter, *http.Request) error,
2513	client func(*serverTester),
2514) {
2515	errc := make(chan error, 1)
2516	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2517		if r.Body == nil {
2518			t.Fatal("nil Body")
2519		}
2520		errc <- handler(w, r)
2521	})
2522	defer st.Close()
2523
2524	donec := make(chan bool)
2525	go func() {
2526		defer close(donec)
2527		st.greet()
2528		client(st)
2529	}()
2530
2531	select {
2532	case <-donec:
2533	case <-time.After(5 * time.Second):
2534		t.Fatal("timeout in client")
2535	}
2536
2537	select {
2538	case err := <-errc:
2539		if err != nil {
2540			t.Fatalf("Error in handler: %v", err)
2541		}
2542	case <-time.After(2 * time.Second):
2543		t.Fatal("timeout in handler")
2544	}
2545}
2546
// readBodyHandler returns a handler function that reads exactly len(want)
// bytes from r.Body. It reports a test error if the read fails or if the
// bytes read differ from want. It writes nothing to w.
func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		got := make([]byte, len(want))
		if _, err := io.ReadFull(r.Body, got); err != nil {
			t.Error(err)
		} else if string(got) != want {
			t.Errorf("read %q; want %q", got, want)
		}
	}
}
2563
2564// TestServerWithCurl currently fails, hence the LenientCipherSuites test. See:
2565//   https://github.com/tatsuhiro-t/nghttp2/issues/140 &
2566//   http://sourceforge.net/p/curl/bugs/1472/
2567func TestServerWithCurl(t *testing.T)                     { testServerWithCurl(t, false) }
2568func TestServerWithCurl_LenientCipherSuites(t *testing.T) { testServerWithCurl(t, true) }
2569
2570func testServerWithCurl(t *testing.T, permitProhibitedCipherSuites bool) {
2571	if runtime.GOOS != "linux" {
2572		t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway")
2573	}
2574	if testing.Short() {
2575		t.Skip("skipping curl test in short mode")
2576	}
2577	requireCurl(t)
2578	var gotConn int32
2579	testHookOnConn = func() { atomic.StoreInt32(&gotConn, 1) }
2580
2581	const msg = "Hello from curl!\n"
2582	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2583		w.Header().Set("Foo", "Bar")
2584		w.Header().Set("Client-Proto", r.Proto)
2585		io.WriteString(w, msg)
2586	}))
2587	ConfigureServer(ts.Config, &Server{
2588		PermitProhibitedCipherSuites: permitProhibitedCipherSuites,
2589	})
2590	ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
2591	ts.StartTLS()
2592	defer ts.Close()
2593
2594	t.Logf("Running test server for curl to hit at: %s", ts.URL)
2595	container := curl(t, "--silent", "--http2", "--insecure", "-v", ts.URL)
2596	defer kill(container)
2597	resc := make(chan interface{}, 1)
2598	go func() {
2599		res, err := dockerLogs(container)
2600		if err != nil {
2601			resc <- err
2602		} else {
2603			resc <- res
2604		}
2605	}()
2606	select {
2607	case res := <-resc:
2608		if err, ok := res.(error); ok {
2609			t.Fatal(err)
2610		}
2611		body := string(res.([]byte))
2612		// Search for both "key: value" and "key:value", since curl changed their format
2613		// Our Dockerfile contains the latest version (no space), but just in case people
2614		// didn't rebuild, check both.
2615		if !strings.Contains(body, "foo: Bar") && !strings.Contains(body, "foo:Bar") {
2616			t.Errorf("didn't see foo: Bar header")
2617			t.Logf("Got: %s", body)
2618		}
2619		if !strings.Contains(body, "client-proto: HTTP/2") && !strings.Contains(body, "client-proto:HTTP/2") {
2620			t.Errorf("didn't see client-proto: HTTP/2 header")
2621			t.Logf("Got: %s", res)
2622		}
2623		if !strings.Contains(string(res.([]byte)), msg) {
2624			t.Errorf("didn't see %q content", msg)
2625			t.Logf("Got: %s", res)
2626		}
2627	case <-time.After(3 * time.Second):
2628		t.Errorf("timeout waiting for curl")
2629	}
2630
2631	if atomic.LoadInt32(&gotConn) == 0 {
2632		t.Error("never saw an http2 connection")
2633	}
2634}
2635
// doh2load enables TestServerWithH2Load, which is skipped unless -h2load is set.
var doh2load = flag.Bool("h2load", false, "Run h2load test")
2637
2638func TestServerWithH2Load(t *testing.T) {
2639	if !*doh2load {
2640		t.Skip("Skipping without --h2load flag.")
2641	}
2642	if runtime.GOOS != "linux" {
2643		t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway")
2644	}
2645	requireH2load(t)
2646
2647	msg := strings.Repeat("Hello, h2load!\n", 5000)
2648	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2649		io.WriteString(w, msg)
2650		w.(http.Flusher).Flush()
2651		io.WriteString(w, msg)
2652	}))
2653	ts.StartTLS()
2654	defer ts.Close()
2655
2656	cmd := exec.Command("docker", "run", "--net=host", "--entrypoint=/usr/local/bin/h2load", "gohttp2/curl",
2657		"-n100000", "-c100", "-m100", ts.URL)
2658	cmd.Stdout = os.Stdout
2659	cmd.Stderr = os.Stderr
2660	if err := cmd.Run(); err != nil {
2661		t.Fatal(err)
2662	}
2663}
2664
2665// Issue 12843
2666func TestServerDoS_MaxHeaderListSize(t *testing.T) {
2667	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
2668	defer st.Close()
2669
2670	// shake hands
2671	frameSize := defaultMaxReadFrameSize
2672	var advHeaderListSize *uint32
2673	st.greetAndCheckSettings(func(s Setting) error {
2674		switch s.ID {
2675		case SettingMaxFrameSize:
2676			if s.Val < minMaxFrameSize {
2677				frameSize = minMaxFrameSize
2678			} else if s.Val > maxFrameSize {
2679				frameSize = maxFrameSize
2680			} else {
2681				frameSize = int(s.Val)
2682			}
2683		case SettingMaxHeaderListSize:
2684			advHeaderListSize = &s.Val
2685		}
2686		return nil
2687	})
2688
2689	if advHeaderListSize == nil {
2690		t.Errorf("server didn't advertise a max header list size")
2691	} else if *advHeaderListSize == 0 {
2692		t.Errorf("server advertised a max header list size of 0")
2693	}
2694
2695	st.encodeHeaderField(":method", "GET")
2696	st.encodeHeaderField(":path", "/")
2697	st.encodeHeaderField(":scheme", "https")
2698	cookie := strings.Repeat("*", 4058)
2699	st.encodeHeaderField("cookie", cookie)
2700	st.writeHeaders(HeadersFrameParam{
2701		StreamID:      1,
2702		BlockFragment: st.headerBuf.Bytes(),
2703		EndStream:     true,
2704		EndHeaders:    false,
2705	})
2706
2707	// Capture the short encoding of a duplicate ~4K cookie, now
2708	// that we've already sent it once.
2709	st.headerBuf.Reset()
2710	st.encodeHeaderField("cookie", cookie)
2711
2712	// Now send 1MB of it.
2713	const size = 1 << 20
2714	b := bytes.Repeat(st.headerBuf.Bytes(), size/st.headerBuf.Len())
2715	for len(b) > 0 {
2716		chunk := b
2717		if len(chunk) > frameSize {
2718			chunk = chunk[:frameSize]
2719		}
2720		b = b[len(chunk):]
2721		st.fr.WriteContinuation(1, len(b) == 0, chunk)
2722	}
2723
2724	h := st.wantHeaders()
2725	if !h.HeadersEnded() {
2726		t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
2727	}
2728	headers := st.decodeHeader(h.HeaderBlockFragment())
2729	want := [][2]string{
2730		{":status", "431"},
2731		{"content-type", "text/html; charset=utf-8"},
2732		{"content-length", "63"},
2733	}
2734	if !reflect.DeepEqual(headers, want) {
2735		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
2736	}
2737}
2738
2739func TestServer_Response_Stream_With_Missing_Trailer(t *testing.T) {
2740	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2741		w.Header().Set("Trailer", "test-trailer")
2742		return nil
2743	}, func(st *serverTester) {
2744		getSlash(st)
2745		hf := st.wantHeaders()
2746		if !hf.HeadersEnded() {
2747			t.Fatal("want END_HEADERS flag")
2748		}
2749		df := st.wantData()
2750		if len(df.data) != 0 {
2751			t.Fatal("did not want data")
2752		}
2753		if !df.StreamEnded() {
2754			t.Fatal("want END_STREAM flag")
2755		}
2756	})
2757}
2758
// TestCompressionErrorOnWrite probes the HPACK max string length limit
// reported by st.sc.framer.maxHeaderStringLen. It sends two requests:
//   - a header value of exactly that length gets a 431 response (with an
//     "HTTP Error 431" body);
//   - a value one byte longer makes the server send
//     GOAWAY(COMPRESSION_ERROR).
func TestCompressionErrorOnWrite(t *testing.T) {
	const maxStrLen = 8 << 10
	var serverConfig *http.Server
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// No response body.
	}, func(ts *httptest.Server) {
		serverConfig = ts.Config
		serverConfig.MaxHeaderBytes = maxStrLen
	})
	// The second request deliberately triggers this connection error.
	st.addLogFilter("connection error: COMPRESSION_ERROR")
	defer st.Close()
	st.greet()

	maxAllowed := st.sc.framer.maxHeaderStringLen()

	// Crank this up now that the conn is connected and its
	// hpack.Decoder's max string length has been initialized
	// from the earlier low ~8K value. We want this higher so we don't
	// hit the max header list size. We only want to test hitting
	// the max string size.
	// NOTE(review): the 431 expected below suggests the established conn
	// keeps its earlier header-list limit despite this change; confirm.
	serverConfig.MaxHeaderBytes = 1 << 20

	// First a request with a header that's exactly the max allowed size
	// for the hpack compression. It's still too long for the header list
	// size, so we'll get the 431 error, but that keeps the compression
	// context still valid.
	hbf := st.encodeHeader("foo", strings.Repeat("a", maxAllowed))

	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: hbf,
		EndStream:     true,
		EndHeaders:    true,
	})
	h := st.wantHeaders()
	if !h.HeadersEnded() {
		t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
	}
	headers := st.decodeHeader(h.HeaderBlockFragment())
	want := [][2]string{
		{":status", "431"},
		{"content-type", "text/html; charset=utf-8"},
		{"content-length", "63"},
	}
	if !reflect.DeepEqual(headers, want) {
		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
	}
	df := st.wantData()
	if !strings.Contains(string(df.Data()), "HTTP Error 431") {
		t.Errorf("Unexpected data body: %q", df.Data())
	}
	if !df.StreamEnded() {
		t.Fatalf("expect data stream end")
	}

	// And now send one that's just one byte too big.
	hbf = st.encodeHeader("bar", strings.Repeat("b", maxAllowed+1))
	st.writeHeaders(HeadersFrameParam{
		StreamID:      3,
		BlockFragment: hbf,
		EndStream:     true,
		EndHeaders:    true,
	})
	ga := st.wantGoAway()
	if ga.ErrCode != ErrCodeCompression {
		t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
	}
}
2827
2828func TestCompressionErrorOnClose(t *testing.T) {
2829	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2830		// No response body.
2831	})
2832	st.addLogFilter("connection error: COMPRESSION_ERROR")
2833	defer st.Close()
2834	st.greet()
2835
2836	hbf := st.encodeHeader("foo", "bar")
2837	hbf = hbf[:len(hbf)-1] // truncate one byte from the end, so hpack.Decoder.Close fails.
2838	st.writeHeaders(HeadersFrameParam{
2839		StreamID:      1,
2840		BlockFragment: hbf,
2841		EndStream:     true,
2842		EndHeaders:    true,
2843	})
2844	ga := st.wantGoAway()
2845	if ga.ErrCode != ErrCodeCompression {
2846		t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
2847	}
2848}
2849
2850// test that a server handler can read trailers from a client
2851func TestServerReadsTrailers(t *testing.T) {
2852	const testBody = "some test body"
2853	writeReq := func(st *serverTester) {
2854		st.writeHeaders(HeadersFrameParam{
2855			StreamID:      1, // clients send odd numbers
2856			BlockFragment: st.encodeHeader("trailer", "Foo, Bar", "trailer", "Baz"),
2857			EndStream:     false,
2858			EndHeaders:    true,
2859		})
2860		st.writeData(1, false, []byte(testBody))
2861		st.writeHeaders(HeadersFrameParam{
2862			StreamID: 1, // clients send odd numbers
2863			BlockFragment: st.encodeHeaderRaw(
2864				"foo", "foov",
2865				"bar", "barv",
2866				"baz", "bazv",
2867				"surprise", "wasn't declared; shouldn't show up",
2868			),
2869			EndStream:  true,
2870			EndHeaders: true,
2871		})
2872	}
2873	checkReq := func(r *http.Request) {
2874		wantTrailer := http.Header{
2875			"Foo": nil,
2876			"Bar": nil,
2877			"Baz": nil,
2878		}
2879		if !reflect.DeepEqual(r.Trailer, wantTrailer) {
2880			t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer)
2881		}
2882		slurp, err := ioutil.ReadAll(r.Body)
2883		if string(slurp) != testBody {
2884			t.Errorf("read body %q; want %q", slurp, testBody)
2885		}
2886		if err != nil {
2887			t.Fatalf("Body slurp: %v", err)
2888		}
2889		wantTrailerAfter := http.Header{
2890			"Foo": {"foov"},
2891			"Bar": {"barv"},
2892			"Baz": {"bazv"},
2893		}
2894		if !reflect.DeepEqual(r.Trailer, wantTrailerAfter) {
2895			t.Errorf("final Trailer = %v; want %v", r.Trailer, wantTrailerAfter)
2896		}
2897	}
2898	testServerRequest(t, writeReq, checkReq)
2899}
2900
2901// test that a server handler can send trailers
2902func TestServerWritesTrailers_WithFlush(t *testing.T)    { testServerWritesTrailers(t, true) }
2903func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) }
2904
// testServerWritesTrailers checks which handler-set header keys end up in the
// trailer HEADERS frame sent after the body. withFlush controls whether the
// handler flushes after writing the body.
//
// Expected trailers:
//   - keys pre-declared via "Trailer" that are later set;
//   - keys set with the "Trailer:" prefix.
//
// Expected to be omitted:
//   - keys neither declared nor prefixed;
//   - forbidden or invalid names (Transfer-Encoding, Content-Length, Trailer,
//     Range, and a name containing a control byte).
func testServerWritesTrailers(t *testing.T, withFlush bool) {
	// See https://httpwg.github.io/specs/rfc7540.html#rfc.section.8.1.3
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
		w.Header().Add("Trailer", "Server-Trailer-C")
		w.Header().Add("Trailer", "Transfer-Encoding, Content-Length, Trailer") // filtered

		// Regular headers:
		w.Header().Set("Foo", "Bar")
		w.Header().Set("Content-Length", "5") // len("Hello")

		io.WriteString(w, "Hello")
		if withFlush {
			w.(http.Flusher).Flush()
		}
		w.Header().Set("Server-Trailer-A", "valuea")
		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
		// After a flush, random undeclared keys (here spelled "Server-Surpise")
		// shouldn't show up:
		w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!")
		// But we do permit promoting keys to trailers after a
		// flush if they start with the magic
		// otherwise-invalid "Trailer:" prefix:
		w.Header().Set("Trailer:Post-Header-Trailer", "hi1")
		w.Header().Set("Trailer:post-header-trailer2", "hi2")
		// Prefixed, but still not allowed as trailers:
		w.Header().Set("Trailer:Range", "invalid")
		w.Header().Set("Trailer:Foo\x01Bogus", "invalid")
		w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 7230 4.1.2")
		w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 7230 4.1.2")
		w.Header().Set("Trailer", "should not be included; Forbidden by RFC 7230 4.1.2")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		hf := st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("response HEADERS had END_STREAM")
		}
		if !hf.HeadersEnded() {
			t.Fatal("response HEADERS didn't have END_HEADERS")
		}
		goth := st.decodeHeader(hf.HeaderBlockFragment())
		wanth := [][2]string{
			{":status", "200"},
			{"foo", "Bar"},
			{"trailer", "Server-Trailer-A, Server-Trailer-B"},
			{"trailer", "Server-Trailer-C"},
			{"trailer", "Transfer-Encoding, Content-Length, Trailer"},
			{"content-type", "text/plain; charset=utf-8"},
			{"content-length", "5"},
		}
		if !reflect.DeepEqual(goth, wanth) {
			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
		}
		df := st.wantData()
		if string(df.Data()) != "Hello" {
			t.Fatalf("Client read %q; want Hello", df.Data())
		}
		if df.StreamEnded() {
			t.Fatalf("data frame had STREAM_ENDED")
		}
		tf := st.wantHeaders() // for the trailers
		if !tf.StreamEnded() {
			t.Fatalf("trailers HEADERS lacked END_STREAM")
		}
		if !tf.HeadersEnded() {
			t.Fatalf("trailers HEADERS lacked END_HEADERS")
		}
		wanth = [][2]string{
			{"post-header-trailer", "hi1"},
			{"post-header-trailer2", "hi2"},
			{"server-trailer-a", "valuea"},
			{"server-trailer-c", "valuec"},
		}
		goth = st.decodeHeader(tf.HeaderBlockFragment())
		if !reflect.DeepEqual(goth, wanth) {
			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
		}
	})
}
2983
// validate transmitted header field names & values
// golang.org/issue/14048
//
// TestServerDoesntWriteInvalidHeaders sets one valid header and three invalid
// ones: a colon in the key, a NUL in the key, and a NUL in the value. Only the
// valid one ("ok1") may appear in the response HEADERS.
func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Add("OK1", "x")
		w.Header().Add("Bad:Colon", "x") // colon (non-token byte) in key
		w.Header().Add("Bad1\x00", "x")  // null in key
		w.Header().Add("Bad2", "x\x00y") // null in value
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		hf := st.wantHeaders()
		// No body was written, so the HEADERS frame should end the stream.
		if !hf.StreamEnded() {
			t.Error("response HEADERS lacked END_STREAM")
		}
		if !hf.HeadersEnded() {
			t.Fatal("response HEADERS didn't have END_HEADERS")
		}
		goth := st.decodeHeader(hf.HeaderBlockFragment())
		wanth := [][2]string{
			{":status", "200"},
			{"ok1", "x"},
			{"content-length", "0"},
		}
		if !reflect.DeepEqual(goth, wanth) {
			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
		}
	})
}
3013
// BenchmarkServerGets measures b.N sequential GET requests. Each request is a
// single HEADERS frame with END_STREAM on a new client stream. The handler
// replies with a short fixed body, and the client checks that the DATA frame
// ends the stream.
func BenchmarkServerGets(b *testing.B) {
	defer disableGoroutineTracking()()
	b.ReportAllocs()

	const msg = "Hello, world"
	handler := func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, msg)
	}
	st := newServerTester(b, handler)
	defer st.Close()
	st.greet()

	// Grant connection-level flow-control credit for every response body,
	// on top of the initial window the server already has.
	if err := st.fr.WriteWindowUpdate(0, uint32(len(msg)*b.N)); err != nil {
		b.Fatal(err)
	}

	for i := 0; i < b.N; i++ {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      uint32(2*i + 1), // client-initiated stream IDs are odd
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		st.wantHeaders()
		if df := st.wantData(); !df.StreamEnded() {
			b.Fatalf("DATA didn't have END_STREAM; got %v", df)
		}
	}
}
3045
// BenchmarkServerPosts measures b.N sequential POST requests with empty
// bodies. Each request is a HEADERS frame without END_STREAM, followed by an
// empty DATA frame that carries END_STREAM.
func BenchmarkServerPosts(b *testing.B) {
	defer disableGoroutineTracking()()
	b.ReportAllocs()

	const msg = "Hello, world"
	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
		// Drain the (empty) request body before replying. Otherwise,
		// depending on scheduling, the server may send the peer a
		// RST_STREAM with the CANCEL error code.
		n, err := io.Copy(ioutil.Discard, r.Body)
		if n != 0 || err != nil {
			b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
		}
		io.WriteString(w, msg)
	})
	defer st.Close()
	st.greet()

	// Grant connection-level flow-control credit for every response body,
	// on top of the initial window the server already has.
	quota := uint32(b.N * len(msg))
	if err := st.fr.WriteWindowUpdate(0, quota); err != nil {
		b.Fatal(err)
	}

	for i := 0; i < b.N; i++ {
		streamID := 2*uint32(i) + 1 // client-initiated stream IDs are odd
		st.writeHeaders(HeadersFrameParam{
			StreamID:      streamID,
			BlockFragment: st.encodeHeader(":method", "POST"),
			EndHeaders:    true, // EndStream stays false: a DATA frame follows
		})
		st.writeData(streamID, true, nil)
		st.wantHeaders()
		if df := st.wantData(); !df.StreamEnded() {
			b.Fatalf("DATA didn't have END_STREAM; got %v", df)
		}
	}
}
3084
3085// Send a stream of messages from server to client in separate data frames.
3086// Brings up performance issues seen in long streams.
3087// Created to show problem in go issue #18502
3088func BenchmarkServerToClientStreamDefaultOptions(b *testing.B) {
3089	benchmarkServerToClientStream(b)
3090}
3091
3092// Justification for Change-Id: Iad93420ef6c3918f54249d867098f1dadfa324d8
3093// Expect to see memory/alloc reduction by opting in to Frame reuse with the Framer.
3094func BenchmarkServerToClientStreamReuseFrames(b *testing.B) {
3095	benchmarkServerToClientStream(b, optFramerReuseFrames)
3096}
3097
// benchmarkServerToClientStream measures a single long-lived response stream.
// The handler writes b.N one-byte messages and flushes after each one, so every
// message goes out in its own DATA frame. The client reads the messages in
// order, checks each against the expected value, and returns flow-control
// credit in large batches. A final DATA frame must carry END_STREAM.
// newServerOpts is passed through unchanged to newServerTester.
func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
	defer disableGoroutineTracking()()
	b.ReportAllocs()
	const msgLen = 1
	// Default initial flow-control window size.
	const windowSize = 1<<16 - 1

	// nextMsg returns the i'th message: the one the server sends and the
	// client expects to receive.
	nextMsg := func(i int) []byte {
		msg := make([]byte, msgLen)
		msg[0] = byte(i)
		return msg
	}

	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
		// Drain the (empty) request body before replying. Otherwise,
		// depending on scheduling, the server may send the peer a
		// RST_STREAM with the CANCEL error code.
		if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil {
			b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
		}
		for i := 0; i < b.N; i++ {
			w.Write(nextMsg(i))
			w.(http.Flusher).Flush()
		}
	}, newServerOpts...)
	defer st.Close()
	st.greet()

	const id = uint32(1)

	st.writeHeaders(HeadersFrameParam{
		StreamID:      id,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})

	st.writeData(id, true, nil)
	st.wantHeaders()

	var pendingWindowUpdate uint32

	for i := 0; i < b.N; i++ {
		expected := nextMsg(i)
		df := st.wantData()
		// Use the accessor rather than the raw field: it verifies the frame
		// is still valid, which matters when the Framer reuses frames.
		data := df.Data()
		if !bytes.Equal(expected, data) {
			b.Fatalf("Bad message received; want %v; got %v", expected, data)
		}
		// Return flow-control credit infrequently but in large chunks, so
		// WINDOW_UPDATE frames don't overwhelm the test.
		pendingWindowUpdate += uint32(len(data))
		if pendingWindowUpdate >= windowSize/2 {
			if err := st.fr.WriteWindowUpdate(0, pendingWindowUpdate); err != nil {
				b.Fatal(err)
			}
			if err := st.fr.WriteWindowUpdate(id, pendingWindowUpdate); err != nil {
				b.Fatal(err)
			}
			pendingWindowUpdate = 0
		}
	}
	df := st.wantData()
	if !df.StreamEnded() {
		b.Fatalf("DATA didn't have END_STREAM; got %v", df)
	}
}
3167
// TestIssue53 replays input found by go-fuzz (originally reported at
// https://github.com/bradfitz/http2/issues/53): a connection preface followed
// by a malformed frame. It checks that ServeConn returns instead of hanging,
// and that it closes the connection.
func TestIssue53(t *testing.T) {
	const data = "PRI * HTTP/2.0\r\n\r\nSM" +
		"\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad"
	logger := log.New(io.MultiWriter(stderrv(), twriter{t: t}), "", log.LstdFlags)
	hello := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("hello"))
	})
	base := &http.Server{ErrorLog: logger, Handler: hello}
	h2srv := &Server{
		MaxReadFrameSize:             1 << 16,
		PermitProhibitedCipherSuites: true,
	}
	conn := &issue53Conn{data: []byte(data)}
	h2srv.ServeConn(conn, &ServeConnOpts{BaseConfig: base})
	if !conn.closed {
		t.Fatal("connection is not closed")
	}
}
3189
3190type issue53Conn struct {
3191	data    []byte
3192	closed  bool
3193	written bool
3194}
3195
3196func (c *issue53Conn) Read(b []byte) (n int, err error) {
3197	if len(c.data) == 0 {
3198		return 0, io.EOF
3199	}
3200	n = copy(b, c.data)
3201	c.data = c.data[n:]
3202	return
3203}
3204
3205func (c *issue53Conn) Write(b []byte) (n int, err error) {
3206	c.written = true
3207	return len(b), nil
3208}
3209
3210func (c *issue53Conn) Close() error {
3211	c.closed = true
3212	return nil
3213}
3214
3215func (c *issue53Conn) LocalAddr() net.Addr {
3216	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706}
3217}
3218func (c *issue53Conn) RemoteAddr() net.Addr {
3219	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706}
3220}
3221func (c *issue53Conn) SetDeadline(t time.Time) error      { return nil }
3222func (c *issue53Conn) SetReadDeadline(t time.Time) error  { return nil }
3223func (c *issue53Conn) SetWriteDeadline(t time.Time) error { return nil }
3224
// TestServeConnOptsNilReceiverBehavior checks that the ServeConnOpts accessors
// can be called on a nil *ServeConnOpts without panicking, and that each one
// still returns a non-nil value. See golang.org/issue/33839.
func TestServeConnOptsNilReceiverBehavior(t *testing.T) {
	defer func() {
		if r := recover(); r != nil {
			t.Errorf("got a panic that should not happen: %v", r)
		}
	}()

	var o *ServeConnOpts
	// Each check is a closure so the accessors run one at a time, in order,
	// each immediately before its own nil check.
	checks := []struct {
		msg   string
		isNil func() bool
	}{
		{"o.context should not return nil", func() bool { return o.context() == nil }},
		{"o.baseConfig should not return nil", func() bool { return o.baseConfig() == nil }},
		{"o.handler should not return nil", func() bool { return o.handler() == nil }},
	}
	for _, c := range checks {
		if c.isNil() {
			t.Error(c.msg)
		}
	}
}
3244
// TestConfigureServer checks how ConfigureServer validates an http.Server's
// TLS cipher suite list. It rejects configurations that lack an
// HTTP/2-required suite, or that list an approved suite only after a
// non-approved one. On success, it must enable PreferServerCipherSuites.
// See golang.org/issue/12895.
func TestConfigureServer(t *testing.T) {
	tests := []struct {
		name      string
		tlsConfig *tls.Config
		wantErr   string
	}{
		{
			name: "empty server",
		},
		{
			name: "just the required cipher suite",
			tlsConfig: &tls.Config{
				CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
			},
		},
		{
			name: "just the alternative required cipher suite",
			tlsConfig: &tls.Config{
				CipherSuites: []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			},
		},
		{
			name: "missing required cipher suite",
			tlsConfig: &tls.Config{
				CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384},
			},
			wantErr: "is missing an HTTP/2-required",
		},
		{
			name: "required after bad",
			tlsConfig: &tls.Config{
				CipherSuites: []uint16{tls.TLS_RSA_WITH_RC4_128_SHA, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
			},
			wantErr: "contains an HTTP/2-approved cipher suite (0xc02f), but it comes after",
		},
		{
			name: "bad after required",
			tlsConfig: &tls.Config{
				CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_RC4_128_SHA},
			},
		},
	}
	for _, tc := range tests {
		srv := &http.Server{TLSConfig: tc.tlsConfig}
		err := ConfigureServer(srv, nil)
		expectErr := tc.wantErr != ""
		switch {
		case err == nil && expectErr:
			t.Errorf("%s: success, but want error", tc.name)
		case err != nil && !expectErr:
			t.Errorf("%s: unexpected error: %v", tc.name, err)
		}
		if err != nil && expectErr && !strings.Contains(err.Error(), tc.wantErr) {
			t.Errorf("%s: err = %v; want substring %q", tc.name, err, tc.wantErr)
		}
		if err == nil && !srv.TLSConfig.PreferServerCipherSuites {
			t.Errorf("%s: PreferServerCipherSuite is false; want true", tc.name)
		}
	}
}
3306
3307func TestServerRejectHeadWithBody(t *testing.T) {
3308	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3309		// No response body.
3310	})
3311	defer st.Close()
3312	st.greet()
3313	st.writeHeaders(HeadersFrameParam{
3314		StreamID:      1, // clients send odd numbers
3315		BlockFragment: st.encodeHeader(":method", "HEAD"),
3316		EndStream:     false, // what we're testing, a bogus HEAD request with body
3317		EndHeaders:    true,
3318	})
3319	st.wantRSTStream(1, ErrCodeProtocol)
3320}
3321
3322func TestServerNoAutoContentLengthOnHead(t *testing.T) {
3323	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3324		// No response body. (or smaller than one frame)
3325	})
3326	defer st.Close()
3327	st.greet()
3328	st.writeHeaders(HeadersFrameParam{
3329		StreamID:      1, // clients send odd numbers
3330		BlockFragment: st.encodeHeader(":method", "HEAD"),
3331		EndStream:     true,
3332		EndHeaders:    true,
3333	})
3334	h := st.wantHeaders()
3335	headers := st.decodeHeader(h.HeaderBlockFragment())
3336	want := [][2]string{
3337		{":status", "200"},
3338	}
3339	if !reflect.DeepEqual(headers, want) {
3340		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
3341	}
3342}
3343
// TestServerNoDuplicateContentType checks that when a handler explicitly sets
// Content-Type to an empty value, the server sends that single empty
// content-type field and does not add a sniffed one.
// See golang.org/issue/13495.
func TestServerNoDuplicateContentType(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		w.Header()["Content-Type"] = []string{""}
		io.WriteString(w, "<html><head></head><body>hi</body></html>")
	})
	defer st.Close()
	st.greet()

	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})

	want := [][2]string{
		{":status", "200"},
		{"content-type", ""},
		{"content-length", "41"},
	}
	got := st.decodeHeader(st.wantHeaders().HeaderBlockFragment())
	if !reflect.DeepEqual(got, want) {
		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", got, want)
	}
}
3369
3370func disableGoroutineTracking() (restore func()) {
3371	old := DebugGoroutines
3372	DebugGoroutines = false
3373	return func() { DebugGoroutines = old }
3374}
3375
// BenchmarkServer_GetRequest measures b.N sequential GET requests that all
// reuse one pre-encoded header block. The handler checks that each request
// body is empty, then writes a short reply.
func BenchmarkServer_GetRequest(b *testing.B) {
	defer disableGoroutineTracking()()
	b.ReportAllocs()
	const msg = "Hello, world."
	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
		if n, err := io.Copy(ioutil.Discard, r.Body); err != nil || n > 0 {
			b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
		}
		io.WriteString(w, msg)
	})
	defer st.Close()

	st.greet()
	// Grant connection-level flow-control credit for every reply, on top of
	// the initial window the server already has.
	credit := uint32(b.N * len(msg))
	if err := st.fr.WriteWindowUpdate(0, credit); err != nil {
		b.Fatal(err)
	}
	hbf := st.encodeHeader(":method", "GET")
	for i := 0; i < b.N; i++ {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      uint32(1 + 2*i), // client-initiated stream IDs are odd
			BlockFragment: hbf,
			EndStream:     true,
			EndHeaders:    true,
		})
		st.wantHeaders()
		st.wantData()
	}
}
3407
// BenchmarkServer_PostRequest measures b.N sequential POST requests that reuse
// one pre-encoded header block. Each is followed by an empty DATA frame that
// ends the request stream. The handler checks that each body is empty, then
// writes a short reply.
func BenchmarkServer_PostRequest(b *testing.B) {
	defer disableGoroutineTracking()()
	b.ReportAllocs()
	const msg = "Hello, world."
	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
		if n, err := io.Copy(ioutil.Discard, r.Body); err != nil || n > 0 {
			b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
		}
		io.WriteString(w, msg)
	})
	defer st.Close()
	st.greet()
	// Grant connection-level flow-control credit for every reply, on top of
	// the initial window the server already has.
	credit := uint32(b.N * len(msg))
	if err := st.fr.WriteWindowUpdate(0, credit); err != nil {
		b.Fatal(err)
	}
	hbf := st.encodeHeader(":method", "POST")
	for i := 0; i < b.N; i++ {
		sid := uint32(1 + 2*i) // client-initiated stream IDs are odd
		st.writeHeaders(HeadersFrameParam{
			StreamID:      sid,
			BlockFragment: hbf,
			EndHeaders:    true, // EndStream stays false: a DATA frame follows
		})
		st.writeData(sid, true, nil)
		st.wantHeaders()
		st.wantData()
	}
}
3439
3440type connStateConn struct {
3441	net.Conn
3442	cs tls.ConnectionState
3443}
3444
3445func (c connStateConn) ConnectionState() tls.ConnectionState { return c.cs }
3446
// TestServerHandleCustomConn checks that ServeConn accepts any net.Conn, not
// just *tls.Conn. When the connection implements ConnectionState, that state
// must be exposed to the handler as Request.TLS. See golang.org/issue/12737.
func TestServerHandleCustomConn(t *testing.T) {
	var s Server
	c1, c2 := net.Pipe()
	clientDone := make(chan struct{})
	handlerDone := make(chan struct{})
	// req is written by the handler before handlerDone is closed. In the
	// success path, the client goroutine waits on handlerDone before
	// clientDone is closed, so reading req after clientDone is synchronized.
	var req *http.Request
	// Client side: speak raw HTTP/2 over c2. Send the preface, SETTINGS and a
	// SETTINGS ack. Expect the server's SETTINGS and its ack. Then send one
	// GET and wait for the handler to finish.
	go func() {
		defer close(clientDone)
		defer c2.Close()
		fr := NewFramer(c2, c2)
		io.WriteString(c2, ClientPreface)
		fr.WriteSettings()
		fr.WriteSettingsAck()
		f, err := fr.ReadFrame()
		if err != nil {
			t.Error(err)
			return
		}
		if sf, ok := f.(*SettingsFrame); !ok || sf.IsAck() {
			t.Errorf("Got %v; want non-ACK SettingsFrame", summarizeFrame(f))
			return
		}
		f, err = fr.ReadFrame()
		if err != nil {
			t.Error(err)
			return
		}
		if sf, ok := f.(*SettingsFrame); !ok || !sf.IsAck() {
			t.Errorf("Got %v; want ACK SettingsFrame", summarizeFrame(f))
			return
		}
		var henc hpackEncoder
		fr.WriteHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: henc.encodeHeaderRaw(t, ":method", "GET", ":path", "/", ":scheme", "https", ":authority", "foo.com"),
			EndStream:     true,
			EndHeaders:    true,
		})
		go io.Copy(ioutil.Discard, c2)
		<-handlerDone
	}()
	const testString = "my custom ConnectionState"
	fakeConnState := tls.ConnectionState{
		ServerName:  testString,
		Version:     tls.VersionTLS12,
		CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
	}
	go s.ServeConn(connStateConn{c1, fakeConnState}, &ServeConnOpts{
		BaseConfig: &http.Server{
			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				defer close(handlerDone)
				req = r
			}),
		}})
	select {
	case <-clientDone:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for handler")
	}
	// If the client goroutine bailed out before sending its request, the
	// handler never ran and req is still nil. Fail cleanly here instead of
	// panicking on req.TLS below.
	if req == nil {
		t.Fatal("handler was not called")
	}
	if req.TLS == nil {
		t.Fatalf("Request.TLS is nil. Got: %#v", req)
	}
	if req.TLS.ServerName != testString {
		t.Fatalf("Request.TLS = %+v; want ServerName of %q", req.TLS, testString)
	}
}
3515
// TestServer_Rejects_ConnHeaders checks that a request carrying a "connection"
// header field gets a 400 response and that the handler is never called.
// See golang.org/issue/14214.
func TestServer_Rejects_ConnHeaders(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		t.Error("should not get to Handler")
	})
	defer st.Close()
	st.greet()
	st.bodylessReq1("connection", "foo")

	got := st.decodeHeader(st.wantHeaders().HeaderBlockFragment())
	want := [][2]string{
		{":status", "400"},
		{"content-type", "text/plain; charset=utf-8"},
		{"x-content-type-options", "nosniff"},
		{"content-length", "51"},
	}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("Got headers %v; want %v", got, want)
	}
}
3536
3537type hpackEncoder struct {
3538	enc *hpack.Encoder
3539	buf bytes.Buffer
3540}
3541
3542func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte {
3543	if len(headers)%2 == 1 {
3544		panic("odd number of kv args")
3545	}
3546	he.buf.Reset()
3547	if he.enc == nil {
3548		he.enc = hpack.NewEncoder(&he.buf)
3549	}
3550	for len(headers) > 0 {
3551		k, v := headers[0], headers[1]
3552		err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v})
3553		if err != nil {
3554			t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
3555		}
3556		headers = headers[2:]
3557	}
3558	return he.buf.Bytes()
3559}
3560
// TestCheckValidHTTP2Request checks checkValidHTTP2RequestHeaders against
// request headers that HTTP/2 forbids or restricts. Connection-specific
// fields are rejected, and TE is allowed only with the value "trailers".
func TestCheckValidHTTP2Request(t *testing.T) {
	cases := []struct {
		h    http.Header
		want error
	}{
		{
			h:    http.Header{"Te": {"trailers"}},
			want: nil,
		},
		{
			h:    http.Header{"Te": {"trailers", "bogus"}},
			want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`),
		},
		{
			h:    http.Header{"Foo": {""}},
			want: nil,
		},
		{
			h:    http.Header{"Connection": {""}},
			want: errors.New(`request header "Connection" is not valid in HTTP/2`),
		},
		{
			h:    http.Header{"Proxy-Connection": {""}},
			want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`),
		},
		{
			h:    http.Header{"Keep-Alive": {""}},
			want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`),
		},
		{
			h:    http.Header{"Upgrade": {""}},
			want: errors.New(`request header "Upgrade" is not valid in HTTP/2`),
		},
	}
	for idx := range cases {
		c := &cases[idx]
		if got := checkValidHTTP2RequestHeaders(c.h); !equalError(got, c.want) {
			t.Errorf("%d. checkValidHTTP2Request = %v; want %v", idx, got, c.want)
		}
	}
}
3602
// TestExpect100ContinueAfterHandlerWrites sends a request with
// "Expect: 100-continue" to a handler that writes and flushes part of its
// response before touching the request body. The client must receive that
// first part. Once the handler reads from the body (which may trigger a
// 100-continue), the rest of the response must still arrive.
// See golang.org/issue/14030.
func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
	const msg = "Hello"
	const msg2 = "World"

	doRead := make(chan bool, 1)
	defer close(doRead) // fallback cleanup: releases a handler still waiting on doRead

	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, msg)
		w.(http.Flusher).Flush()

		// Do a read, which might force a 100-continue status to be sent.
		<-doRead
		r.Body.Read(make([]byte, 10))

		io.WriteString(w, msg2)

	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	req, err := http.NewRequest("POST", st.ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Expect", "100-continue")

	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	buf := make([]byte, len(msg))
	if _, err := io.ReadFull(res.Body, buf); err != nil {
		t.Fatal(err)
	}
	if string(buf) != msg {
		t.Fatalf("msg = %q; want %q", buf, msg)
	}

	doRead <- true

	if _, err := io.ReadFull(res.Body, buf); err != nil {
		t.Fatal(err)
	}
	if string(buf) != msg2 {
		t.Fatalf("second msg = %q; want %q", buf, msg2)
	}
}
3653
// funcReader adapts an ordinary function to io.Reader.
type funcReader func([]byte) (n int, err error)

// Read implements io.Reader by calling f with p and returning its results.
func (f funcReader) Read(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
3657
// TestUnreadFlowControlReturned_Server checks that when a handler returns
// without reading all of a request body, the server gives the flow-control
// credit for the unread bytes back to the client. Without that, repeated
// requests on one connection eventually stall. See golang.org/issue/16481.
// (This is the Server version of the bug; see also
// TestUnreadFlowControlReturned_Transport.)
func TestUnreadFlowControlReturned_Server(t *testing.T) {
	for _, tt := range []struct {
		name  string
		reqFn func(r *http.Request)
	}{
		{
			"body-open",
			func(r *http.Request) {},
		},
		{
			"body-closed",
			func(r *http.Request) {
				r.Body.Close()
			},
		},
		{
			"read-1-byte-and-close",
			func(r *http.Request) {
				b := make([]byte, 1)
				r.Body.Read(b)
				r.Body.Close()
			},
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			unblock := make(chan bool, 1)
			defer close(unblock)

			timeOut := time.NewTimer(5 * time.Second)
			defer timeOut.Stop()
			st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
				// Don't read the 16KB request body. Wait until the client's
				// done sending it and then return. This should cause the Server
				// to then return those 16KB of flow control to the client.
				tt.reqFn(r)
				select {
				case <-unblock:
				case <-timeOut.C:
					// This runs on a handler goroutine, where t.Fatal must not
					// be called. Record the failure and let the handler return.
					t.Error(tt.name, "timedout")
				}
			}, optOnlyServer)
			defer st.Close()

			tr := &Transport{TLSClientConfig: tlsConfigInsecure}
			defer tr.CloseIdleConnections()

			// This previously hung on the 4th iteration.
			iters := 100
			if testing.Short() {
				iters = 20
			}
			for i := 0; i < iters; i++ {
				// 16KB of body, then a reader that signals the handler once
				// the client has finished sending the body.
				body := io.MultiReader(
					io.LimitReader(neverEnding('A'), 16<<10),
					funcReader(func([]byte) (n int, err error) {
						unblock <- true
						return 0, io.EOF
					}),
				)
				req, err := http.NewRequest("POST", st.ts.URL, body)
				if err != nil {
					t.Fatal(tt.name, err)
				}
				res, err := tr.RoundTrip(req)
				if err != nil {
					t.Fatal(tt.name, err)
				}
				res.Body.Close()
			}
		})
	}
}
3729
// TestServerIdleTimeout checks that a connection that never sends a request
// receives a GOAWAY with ErrCodeNo once the server's IdleTimeout has elapsed.
func TestServerIdleTimeout(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	noop := func(w http.ResponseWriter, r *http.Request) {}
	withIdleTimeout := func(h2s *Server) {
		h2s.IdleTimeout = 500 * time.Millisecond
	}
	st := newServerTester(t, noop, withIdleTimeout)
	defer st.Close()

	st.greet()
	if ga := st.wantGoAway(); ga.ErrCode != ErrCodeNo {
		t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
	}
}
3747
// TestServerIdleTimeout_AfterRequest checks two things. The idle timer must
// not fire while a request is being handled, even one that takes twice the
// timeout. And the timer must be re-armed once the request completes, so the
// client then receives a GOAWAY with ErrCodeNo.
func TestServerIdleTimeout_AfterRequest(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	const timeout = 250 * time.Millisecond

	slowHandler := func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(timeout * 2)
	}
	withIdleTimeout := func(h2s *Server) {
		h2s.IdleTimeout = timeout
	}
	st := newServerTester(t, slowHandler, withIdleTimeout)
	defer st.Close()

	st.greet()

	// The request runs for twice the idle timeout; its response must still
	// arrive.
	st.bodylessReq1()
	st.wantHeaders()

	// With the request finished, the idle timeout applies again.
	if ga := st.wantGoAway(); ga.ErrCode != ErrCodeNo {
		t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
	}
}
3775
3776// grpc-go closes the Request.Body currently with a Read.
3777// Verify that it doesn't race.
3778// See https://github.com/grpc/grpc-go/pull/938
3779func TestRequestBodyReadCloseRace(t *testing.T) {
3780	for i := 0; i < 100; i++ {
3781		body := &requestBody{
3782			pipe: &pipe{
3783				b: new(bytes.Buffer),
3784			},
3785		}
3786		body.pipe.CloseWithError(io.EOF)
3787
3788		done := make(chan bool, 1)
3789		buf := make([]byte, 10)
3790		go func() {
3791			time.Sleep(1 * time.Millisecond)
3792			body.Close()
3793			done <- true
3794		}()
3795		body.Read(buf)
3796		<-done
3797	}
3798}
3799
// TestIssue20704Race makes 1000 requests to a handler that writes 100KB in 1KB
// chunks. Each response is closed without reading its body, which forces a
// RST_STREAM to the server while the handler may still be writing. The test
// is named for Go issue 20704.
func TestIssue20704Race(t *testing.T) {
	if testing.Short() && os.Getenv("GO_BUILDER_NAME") == "" {
		t.Skip("skipping in short mode")
	}
	const (
		itemSize  = 1 << 10
		itemCount = 100
	)

	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		for i := 0; i < itemCount; i++ {
			if _, err := w.Write(make([]byte, itemSize)); err != nil {
				return
			}
		}
	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()
	client := &http.Client{Transport: tr}

	for attempt := 0; attempt < 1000; attempt++ {
		resp, err := client.Get(st.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		// Closing without reading the body forces a RST_STREAM to the server.
		resp.Body.Close()
	}
}
3833
3834func TestServer_Rejects_TooSmall(t *testing.T) {
3835	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
3836		ioutil.ReadAll(r.Body)
3837		return nil
3838	}, func(st *serverTester) {
3839		st.writeHeaders(HeadersFrameParam{
3840			StreamID: 1, // clients send odd numbers
3841			BlockFragment: st.encodeHeader(
3842				":method", "POST",
3843				"content-length", "4",
3844			),
3845			EndStream:  false, // to say DATA frames are coming
3846			EndHeaders: true,
3847		})
3848		st.writeData(1, true, []byte("12345"))
3849
3850		st.wantRSTStream(1, ErrCodeProtocol)
3851	})
3852}
3853
// TestServerHandlerConnectionClose checks that a handler setting
// "Connection: close" causes the server to send a GOAWAY, and that the
// connection still completes.
//
// The handler flushes its response headers and then blocks until the client
// side has seen the GOAWAY. The client side reads every frame until io.EOF
// and checks each frame type it expects.
func TestServerHandlerConnectionClose(t *testing.T) {
	unblockHandler := make(chan bool, 1)
	defer close(unblockHandler) // backup; in case of errors
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("Connection", "close")
		w.Header().Set("Foo", "bar")
		w.(http.Flusher).Flush()
		// Stay in the handler until the client side has observed the GOAWAY.
		<-unblockHandler
		return nil
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		var sawGoAway bool
		var sawRes bool
		// Read frames until the server closes the connection (io.EOF).
		for {
			f, err := st.readFrame()
			if err == io.EOF {
				break
			}
			if err != nil {
				t.Fatal(err)
			}
			switch f := f.(type) {
			case *GoAwayFrame:
				sawGoAway = true
				// The GOAWAY has arrived; let the handler finish.
				unblockHandler <- true
				// Expect ErrCodeNo and LastStreamID 1, the stream being served.
				if f.LastStreamID != 1 || f.ErrCode != ErrCodeNo {
					t.Errorf("unexpected GOAWAY frame: %v", summarizeFrame(f))
				}
			case *HeadersFrame:
				// Only :status and foo are expected; the handler's
				// "Connection" header must not appear on the wire.
				goth := st.decodeHeader(f.HeaderBlockFragment())
				wanth := [][2]string{
					{":status", "200"},
					{"foo", "bar"},
				}
				if !reflect.DeepEqual(goth, wanth) {
					t.Errorf("got headers %v; want %v", goth, wanth)
				}
				sawRes = true
			case *DataFrame:
				// The handler writes no body, so the only DATA frame should be
				// an empty one that ends stream 1.
				if f.StreamID != 1 || !f.StreamEnded() || len(f.Data()) != 0 {
					t.Errorf("unexpected DATA frame: %v", summarizeFrame(f))
				}
			default:
				t.Logf("unexpected frame: %v", summarizeFrame(f))
			}
		}
		if !sawGoAway {
			t.Errorf("didn't see GOAWAY")
		}
		if !sawRes {
			t.Errorf("didn't see response")
		}
	})
}
3915
// TestServer_Headers_HalfCloseRemote checks how the server tracks the state
// of stream 1:
//   - it is open while the handler runs and the request body is still open;
//   - it becomes half-closed (remote) once the client's END_STREAM has been
//     read;
//   - a further HEADERS frame from the client on that stream is answered with
//     an RST_STREAM carrying ErrCodeStreamClosed.
//
// The handler and the test goroutine take turns through the unbuffered
// writeData/writeHeaders channels.
func TestServer_Headers_HalfCloseRemote(t *testing.T) {
	var st *serverTester
	writeData := make(chan bool)
	writeHeaders := make(chan bool)
	leaveHandler := make(chan bool)
	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		if st.stream(1) == nil {
			t.Errorf("nil stream 1 in handler")
		}
		if got, want := st.streamState(1), stateOpen; got != want {
			t.Errorf("in handler, state is %v; want %v", got, want)
		}
		writeData <- true
		if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
			t.Errorf("body read = %d, %v; want 0, EOF", n, err)
		}
		if got, want := st.streamState(1), stateHalfClosedRemote; got != want {
			t.Errorf("in handler, state is %v; want %v", got, want)
		}
		writeHeaders <- true

		<-leaveHandler
	})
	// Tear down the server and client connection when the test ends.
	defer st.Close()
	// Deferred after st.Close, so it runs first: the handler is released
	// before the server shuts down, including when the test fails early.
	defer close(leaveHandler)
	st.greet()

	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     false, // keep it open
		EndHeaders:    true,
	})
	<-writeData
	st.writeData(1, true, nil)

	<-writeHeaders

	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     false, // keep it open
		EndHeaders:    true,
	})

	st.wantRSTStream(1, ErrCodeStreamClosed)
}
3963
// TestServerGracefulShutdown checks graceful shutdown while a request is in
// flight. From inside the handler, the test starts a Shutdown of the
// underlying httptest server and reads the resulting GOAWAY from the client
// side of the connection. It expects ErrCodeNo and LastStreamID 1. The
// in-flight request must then still complete with a 200 response that
// carries the handler's x-foo header. Finally, a client read must fail,
// which the test takes to mean the server closed the connection.
func TestServerGracefulShutdown(t *testing.T) {
	var st *serverTester
	handlerDone := make(chan struct{})
	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		defer close(handlerDone)
		// Shutdown runs in its own goroutine so the handler can keep running
		// and observe the GOAWAY while the request is still in flight.
		go st.ts.Config.Shutdown(context.Background())

		// NOTE(review): wantGoAway reads frames from the client connection on
		// this handler goroutine. If it reports failures via t.Fatal, that
		// would be called off the test goroutine — confirm against its
		// definition.
		ga := st.wantGoAway()
		if ga.ErrCode != ErrCodeNo {
			t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
		}
		if ga.LastStreamID != 1 {
			t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID)
		}

		w.Header().Set("x-foo", "bar")
	})
	defer st.Close()

	st.greet()
	st.bodylessReq1()

	select {
	case <-handlerDone:
	case <-time.After(5 * time.Second):
		t.Fatalf("server did not shutdown?")
	}
	// The response to the in-flight request must still be delivered.
	hf := st.wantHeaders()
	goth := st.decodeHeader(hf.HeaderBlockFragment())
	wanth := [][2]string{
		{":status", "200"},
		{"x-foo", "bar"},
		{"content-length", "0"},
	}
	if !reflect.DeepEqual(goth, wanth) {
		t.Errorf("Got headers %v; want %v", goth, wanth)
	}

	// After the response, reading from the client conn is expected to fail
	// (0 bytes, non-nil error), i.e. the server has closed the connection.
	n, err := st.cc.Read([]byte{0})
	if n != 0 || err == nil {
		t.Errorf("Read = %v, %v; want 0, non-nil", n, err)
	}
}
4007
// TestContentEncodingNoSniffing checks that the server does not sniff a
// Content-Type when the handler has set a non-empty Content-Encoding. It does
// sniff when Content-Encoding is unset or explicitly empty. Each case runs
// against its own server and checks the Content-Encoding and Content-Type the
// client receives. See golang.org/issue/31753.
func TestContentEncodingNoSniffing(t *testing.T) {
	type resp struct {
		name string
		body []byte
		// setting Content-Encoding as an interface instead of a string
		// directly, so as to differentiate between 3 states:
		//    unset, empty string "" and set string "foo/bar".
		contentEncoding interface{}
		wantContentType string
	}

	resps := []*resp{
		{
			name:            "gzip content-encoding, gzipped", // don't sniff.
			contentEncoding: "application/gzip",
			wantContentType: "",
			body: func() []byte {
				buf := new(bytes.Buffer)
				gzw := gzip.NewWriter(buf)
				gzw.Write([]byte("doctype html><p>Hello</p>"))
				gzw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "zlib content-encoding, zlibbed", // don't sniff.
			contentEncoding: "application/zlib",
			wantContentType: "",
			body: func() []byte {
				buf := new(bytes.Buffer)
				zw := zlib.NewWriter(buf)
				zw.Write([]byte("doctype html><p>Hello</p>"))
				zw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "no content-encoding", // must sniff.
			wantContentType: "application/x-gzip",
			body: func() []byte {
				buf := new(bytes.Buffer)
				gzw := gzip.NewWriter(buf)
				gzw.Write([]byte("doctype html><p>Hello</p>"))
				gzw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "phony content-encoding", // don't sniff.
			contentEncoding: "foo/bar",
			body:            []byte("doctype html><p>Hello</p>"),
		},
		{
			name:            "empty but set content-encoding",
			contentEncoding: "",
			wantContentType: "audio/mpeg",
			body:            []byte("ID3"),
		},
	}

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	for _, tt := range resps {
		t.Run(tt.name, func(t *testing.T) {
			st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
				if tt.contentEncoding != nil {
					w.Header().Set("Content-Encoding", tt.contentEncoding.(string))
				}
				w.Write(tt.body)
			}, optOnlyServer)
			defer st.Close()

			req, err := http.NewRequest("GET", st.ts.URL, nil)
			if err != nil {
				t.Fatalf("Failed to create request: %v", err)
			}
			res, err := tr.RoundTrip(req)
			if err != nil {
				t.Fatalf("Failed to fetch URL: %v", err)
			}
			defer res.Body.Close()
			if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
				if w != nil { // The case where contentEncoding was set explicitly.
					t.Errorf("Content-Encoding mismatch\n\tgot:  %q\n\twant: %q", g, w)
				} else if g != "" { // "" should be the equivalent when the contentEncoding is unset.
					t.Errorf("Unexpected Content-Encoding %q", g)
				}
			}
			if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
				t.Errorf("Content-Type mismatch\n\tgot:  %q\n\twant: %q", g, w)
			}
		})
	}
}
4101