1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package http2
6
7import (
8	"bytes"
9	"crypto/tls"
10	"errors"
11	"flag"
12	"fmt"
13	"io"
14	"io/ioutil"
15	"log"
16	"net"
17	"net/http"
18	"net/http/httptest"
19	"os"
20	"os/exec"
21	"reflect"
22	"runtime"
23	"strconv"
24	"strings"
25	"sync"
26	"sync/atomic"
27	"testing"
28	"time"
29
30	"golang.org/x/net/http2/hpack"
31)
32
var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered")

// stderrv returns the writer that verbose test output is mirrored to:
// os.Stderr when -stderr_verbose is set, otherwise a writer that discards
// everything written to it.
func stderrv() io.Writer {
	var out io.Writer = ioutil.Discard
	if *stderrVerbose {
		out = os.Stderr
	}
	return out
}
42
// serverTester drives an HTTP/2 server running inside an httptest.Server
// (see newServerTester) from the client side of a TLS connection. Tests use
// it to write raw frames to the server and to assert on the frames it sends
// back.
type serverTester struct {
	cc             net.Conn // client conn; nil when created with optOnlyServer
	t              testing.TB
	ts             *httptest.Server
	fr             *Framer      // reads and writes frames on cc
	serverLogBuf   bytes.Buffer // copy of the httptest.Server error log (unless optQuiet)
	logFilter      []string     // substrings to filter out; see addLogFilter
	scMu           sync.Mutex   // guards sc
	sc             *serverConn  // recorded by testHookGetServerConn
	hpackDec       *hpack.Decoder
	decodedHeaders [][2]string // NOTE(review): presumably appended to by onHeaderField (defined elsewhere)

	// Unless frame debug logging is already enabled (logFrameReads or
	// logFrameWrites; presumably http2debug=2), newServerTester captures the
	// Framer's read and write debug logs here, and Close writes them to the
	// test log if the test failed. The read and write logs use separate
	// locks and buffers so we don't accidentally introduce synchronization
	// between the read and write goroutines, which may hide data races.
	frameReadLogMu   sync.Mutex
	frameReadLogBuf  bytes.Buffer
	frameWriteLogMu  sync.Mutex
	frameWriteLogBuf bytes.Buffer

	// writing headers: hpackEnc encodes into headerBuf (see encodeHeader
	// and encodeHeaderRaw).
	headerBuf bytes.Buffer
	hpackEnc  *hpack.Encoder
}
68
69func init() {
70	testHookOnPanicMu = new(sync.Mutex)
71	goAwayTimeout = 25 * time.Millisecond
72}
73
74func resetHooks() {
75	testHookOnPanicMu.Lock()
76	testHookOnPanic = nil
77	testHookOnPanicMu.Unlock()
78}
79
// serverTesterOpt is a flag-style option accepted by newServerTester.
type serverTesterOpt string

var (
	optOnlyServer        = serverTesterOpt("only_server")        // start the server but do not dial it
	optQuiet             = serverTesterOpt("quiet_logging")      // discard the server's error log
	optFramerReuseFrames = serverTesterOpt("frame_reuse_frames") // call SetReuseFrames on the client Framer
)
85
86func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
87	resetHooks()
88
89	ts := httptest.NewUnstartedServer(handler)
90
91	tlsConfig := &tls.Config{
92		InsecureSkipVerify: true,
93		NextProtos:         []string{NextProtoTLS},
94	}
95
96	var onlyServer, quiet, framerReuseFrames bool
97	h2server := new(Server)
98	for _, opt := range opts {
99		switch v := opt.(type) {
100		case func(*tls.Config):
101			v(tlsConfig)
102		case func(*httptest.Server):
103			v(ts)
104		case func(*Server):
105			v(h2server)
106		case serverTesterOpt:
107			switch v {
108			case optOnlyServer:
109				onlyServer = true
110			case optQuiet:
111				quiet = true
112			case optFramerReuseFrames:
113				framerReuseFrames = true
114			}
115		case func(net.Conn, http.ConnState):
116			ts.Config.ConnState = v
117		default:
118			t.Fatalf("unknown newServerTester option type %T", v)
119		}
120	}
121
122	ConfigureServer(ts.Config, h2server)
123
124	st := &serverTester{
125		t:  t,
126		ts: ts,
127	}
128	st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
129	st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
130
131	ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
132	if quiet {
133		ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
134	} else {
135		ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
136	}
137	ts.StartTLS()
138
139	if VerboseLogs {
140		t.Logf("Running test server at: %s", ts.URL)
141	}
142	testHookGetServerConn = func(v *serverConn) {
143		st.scMu.Lock()
144		defer st.scMu.Unlock()
145		st.sc = v
146	}
147	log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st}))
148	if !onlyServer {
149		cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
150		if err != nil {
151			t.Fatal(err)
152		}
153		st.cc = cc
154		st.fr = NewFramer(cc, cc)
155		if framerReuseFrames {
156			st.fr.SetReuseFrames()
157		}
158		if !logFrameReads && !logFrameWrites {
159			st.fr.debugReadLoggerf = func(m string, v ...interface{}) {
160				m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
161				st.frameReadLogMu.Lock()
162				fmt.Fprintf(&st.frameReadLogBuf, m, v...)
163				st.frameReadLogMu.Unlock()
164			}
165			st.fr.debugWriteLoggerf = func(m string, v ...interface{}) {
166				m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
167				st.frameWriteLogMu.Lock()
168				fmt.Fprintf(&st.frameWriteLogBuf, m, v...)
169				st.frameWriteLogMu.Unlock()
170			}
171			st.fr.logReads = true
172			st.fr.logWrites = true
173		}
174	}
175	return st
176}
177
178func (st *serverTester) closeConn() {
179	st.scMu.Lock()
180	defer st.scMu.Unlock()
181	st.sc.conn.Close()
182}
183
184func (st *serverTester) addLogFilter(phrase string) {
185	st.logFilter = append(st.logFilter, phrase)
186}
187
188func (st *serverTester) stream(id uint32) *stream {
189	ch := make(chan *stream, 1)
190	st.sc.serveMsgCh <- func(int) {
191		ch <- st.sc.streams[id]
192	}
193	return <-ch
194}
195
196func (st *serverTester) streamState(id uint32) streamState {
197	ch := make(chan streamState, 1)
198	st.sc.serveMsgCh <- func(int) {
199		state, _ := st.sc.state(id)
200		ch <- state
201	}
202	return <-ch
203}
204
205// loopNum reports how many times this conn's select loop has gone around.
206func (st *serverTester) loopNum() int {
207	lastc := make(chan int, 1)
208	st.sc.serveMsgCh <- func(loopNum int) {
209		lastc <- loopNum
210	}
211	return <-lastc
212}
213
214// awaitIdle heuristically awaits for the server conn's select loop to be idle.
215// The heuristic is that the server connection's serve loop must schedule
216// 50 times in a row without any channel sends or receives occurring.
217func (st *serverTester) awaitIdle() {
218	remain := 50
219	last := st.loopNum()
220	for remain > 0 {
221		n := st.loopNum()
222		if n == last+1 {
223			remain--
224		} else {
225			remain = 50
226		}
227		last = n
228	}
229}
230
231func (st *serverTester) Close() {
232	if st.t.Failed() {
233		st.frameReadLogMu.Lock()
234		if st.frameReadLogBuf.Len() > 0 {
235			st.t.Logf("Framer read log:\n%s", st.frameReadLogBuf.String())
236		}
237		st.frameReadLogMu.Unlock()
238
239		st.frameWriteLogMu.Lock()
240		if st.frameWriteLogBuf.Len() > 0 {
241			st.t.Logf("Framer write log:\n%s", st.frameWriteLogBuf.String())
242		}
243		st.frameWriteLogMu.Unlock()
244
245		// If we failed already (and are likely in a Fatal,
246		// unwindowing), force close the connection, so the
247		// httptest.Server doesn't wait forever for the conn
248		// to close.
249		if st.cc != nil {
250			st.cc.Close()
251		}
252	}
253	st.ts.Close()
254	if st.cc != nil {
255		st.cc.Close()
256	}
257	log.SetOutput(os.Stderr)
258}
259
260// greet initiates the client's HTTP/2 connection into a state where
261// frames may be sent.
262func (st *serverTester) greet() {
263	st.greetAndCheckSettings(func(Setting) error { return nil })
264}
265
266func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error) {
267	st.writePreface()
268	st.writeInitialSettings()
269	st.wantSettings().ForeachSetting(checkSetting)
270	st.writeSettingsAck()
271
272	// The initial WINDOW_UPDATE and SETTINGS ACK can come in any order.
273	var gotSettingsAck bool
274	var gotWindowUpdate bool
275
276	for i := 0; i < 2; i++ {
277		f, err := st.readFrame()
278		if err != nil {
279			st.t.Fatal(err)
280		}
281		switch f := f.(type) {
282		case *SettingsFrame:
283			if !f.Header().Flags.Has(FlagSettingsAck) {
284				st.t.Fatal("Settings Frame didn't have ACK set")
285			}
286			gotSettingsAck = true
287
288		case *WindowUpdateFrame:
289			if f.FrameHeader.StreamID != 0 {
290				st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
291			}
292			incr := uint32((&Server{}).initialConnRecvWindowSize() - initialWindowSize)
293			if f.Increment != incr {
294				st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr)
295			}
296			gotWindowUpdate = true
297
298		default:
299			st.t.Fatalf("Wanting a settings ACK or window update, received a %T", f)
300		}
301	}
302
303	if !gotSettingsAck {
304		st.t.Fatalf("Didn't get a settings ACK")
305	}
306	if !gotWindowUpdate {
307		st.t.Fatalf("Didn't get a window update")
308	}
309}
310
311func (st *serverTester) writePreface() {
312	n, err := st.cc.Write(clientPreface)
313	if err != nil {
314		st.t.Fatalf("Error writing client preface: %v", err)
315	}
316	if n != len(clientPreface) {
317		st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(clientPreface))
318	}
319}
320
321func (st *serverTester) writeInitialSettings() {
322	if err := st.fr.WriteSettings(); err != nil {
323		st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err)
324	}
325}
326
327func (st *serverTester) writeSettingsAck() {
328	if err := st.fr.WriteSettingsAck(); err != nil {
329		st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err)
330	}
331}
332
333func (st *serverTester) writeHeaders(p HeadersFrameParam) {
334	if err := st.fr.WriteHeaders(p); err != nil {
335		st.t.Fatalf("Error writing HEADERS: %v", err)
336	}
337}
338
339func (st *serverTester) writePriority(id uint32, p PriorityParam) {
340	if err := st.fr.WritePriority(id, p); err != nil {
341		st.t.Fatalf("Error writing PRIORITY: %v", err)
342	}
343}
344
345func (st *serverTester) encodeHeaderField(k, v string) {
346	err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
347	if err != nil {
348		st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
349	}
350}
351
352// encodeHeaderRaw is the magic-free version of encodeHeader.
353// It takes 0 or more (k, v) pairs and encodes them.
354func (st *serverTester) encodeHeaderRaw(headers ...string) []byte {
355	if len(headers)%2 == 1 {
356		panic("odd number of kv args")
357	}
358	st.headerBuf.Reset()
359	for len(headers) > 0 {
360		k, v := headers[0], headers[1]
361		st.encodeHeaderField(k, v)
362		headers = headers[2:]
363	}
364	return st.headerBuf.Bytes()
365}
366
367// encodeHeader encodes headers and returns their HPACK bytes. headers
368// must contain an even number of key/value pairs. There may be
369// multiple pairs for keys (e.g. "cookie").  The :method, :path, and
370// :scheme headers default to GET, / and https. The :authority header
371// defaults to st.ts.Listener.Addr().
372func (st *serverTester) encodeHeader(headers ...string) []byte {
373	if len(headers)%2 == 1 {
374		panic("odd number of kv args")
375	}
376
377	st.headerBuf.Reset()
378	defaultAuthority := st.ts.Listener.Addr().String()
379
380	if len(headers) == 0 {
381		// Fast path, mostly for benchmarks, so test code doesn't pollute
382		// profiles when we're looking to improve server allocations.
383		st.encodeHeaderField(":method", "GET")
384		st.encodeHeaderField(":scheme", "https")
385		st.encodeHeaderField(":authority", defaultAuthority)
386		st.encodeHeaderField(":path", "/")
387		return st.headerBuf.Bytes()
388	}
389
390	if len(headers) == 2 && headers[0] == ":method" {
391		// Another fast path for benchmarks.
392		st.encodeHeaderField(":method", headers[1])
393		st.encodeHeaderField(":scheme", "https")
394		st.encodeHeaderField(":authority", defaultAuthority)
395		st.encodeHeaderField(":path", "/")
396		return st.headerBuf.Bytes()
397	}
398
399	pseudoCount := map[string]int{}
400	keys := []string{":method", ":scheme", ":authority", ":path"}
401	vals := map[string][]string{
402		":method":    {"GET"},
403		":scheme":    {"https"},
404		":authority": {defaultAuthority},
405		":path":      {"/"},
406	}
407	for len(headers) > 0 {
408		k, v := headers[0], headers[1]
409		headers = headers[2:]
410		if _, ok := vals[k]; !ok {
411			keys = append(keys, k)
412		}
413		if strings.HasPrefix(k, ":") {
414			pseudoCount[k]++
415			if pseudoCount[k] == 1 {
416				vals[k] = []string{v}
417			} else {
418				// Allows testing of invalid headers w/ dup pseudo fields.
419				vals[k] = append(vals[k], v)
420			}
421		} else {
422			vals[k] = append(vals[k], v)
423		}
424	}
425	for _, k := range keys {
426		for _, v := range vals[k] {
427			st.encodeHeaderField(k, v)
428		}
429	}
430	return st.headerBuf.Bytes()
431}
432
433// bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set.
434func (st *serverTester) bodylessReq1(headers ...string) {
435	st.writeHeaders(HeadersFrameParam{
436		StreamID:      1, // clients send odd numbers
437		BlockFragment: st.encodeHeader(headers...),
438		EndStream:     true,
439		EndHeaders:    true,
440	})
441}
442
443func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
444	if err := st.fr.WriteData(streamID, endStream, data); err != nil {
445		st.t.Fatalf("Error writing DATA: %v", err)
446	}
447}
448
449func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
450	if err := st.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
451		st.t.Fatalf("Error writing DATA: %v", err)
452	}
453}
454
455func readFrameTimeout(fr *Framer, wait time.Duration) (Frame, error) {
456	ch := make(chan interface{}, 1)
457	go func() {
458		fr, err := fr.ReadFrame()
459		if err != nil {
460			ch <- err
461		} else {
462			ch <- fr
463		}
464	}()
465	t := time.NewTimer(wait)
466	select {
467	case v := <-ch:
468		t.Stop()
469		if fr, ok := v.(Frame); ok {
470			return fr, nil
471		}
472		return nil, v.(error)
473	case <-t.C:
474		return nil, errors.New("timeout waiting for frame")
475	}
476}
477
478func (st *serverTester) readFrame() (Frame, error) {
479	return readFrameTimeout(st.fr, 2*time.Second)
480}
481
482func (st *serverTester) wantHeaders() *HeadersFrame {
483	f, err := st.readFrame()
484	if err != nil {
485		st.t.Fatalf("Error while expecting a HEADERS frame: %v", err)
486	}
487	hf, ok := f.(*HeadersFrame)
488	if !ok {
489		st.t.Fatalf("got a %T; want *HeadersFrame", f)
490	}
491	return hf
492}
493
494func (st *serverTester) wantContinuation() *ContinuationFrame {
495	f, err := st.readFrame()
496	if err != nil {
497		st.t.Fatalf("Error while expecting a CONTINUATION frame: %v", err)
498	}
499	cf, ok := f.(*ContinuationFrame)
500	if !ok {
501		st.t.Fatalf("got a %T; want *ContinuationFrame", f)
502	}
503	return cf
504}
505
506func (st *serverTester) wantData() *DataFrame {
507	f, err := st.readFrame()
508	if err != nil {
509		st.t.Fatalf("Error while expecting a DATA frame: %v", err)
510	}
511	df, ok := f.(*DataFrame)
512	if !ok {
513		st.t.Fatalf("got a %T; want *DataFrame", f)
514	}
515	return df
516}
517
518func (st *serverTester) wantSettings() *SettingsFrame {
519	f, err := st.readFrame()
520	if err != nil {
521		st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
522	}
523	sf, ok := f.(*SettingsFrame)
524	if !ok {
525		st.t.Fatalf("got a %T; want *SettingsFrame", f)
526	}
527	return sf
528}
529
530func (st *serverTester) wantPing() *PingFrame {
531	f, err := st.readFrame()
532	if err != nil {
533		st.t.Fatalf("Error while expecting a PING frame: %v", err)
534	}
535	pf, ok := f.(*PingFrame)
536	if !ok {
537		st.t.Fatalf("got a %T; want *PingFrame", f)
538	}
539	return pf
540}
541
542func (st *serverTester) wantGoAway() *GoAwayFrame {
543	f, err := st.readFrame()
544	if err != nil {
545		st.t.Fatalf("Error while expecting a GOAWAY frame: %v", err)
546	}
547	gf, ok := f.(*GoAwayFrame)
548	if !ok {
549		st.t.Fatalf("got a %T; want *GoAwayFrame", f)
550	}
551	return gf
552}
553
554func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
555	f, err := st.readFrame()
556	if err != nil {
557		st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
558	}
559	rs, ok := f.(*RSTStreamFrame)
560	if !ok {
561		st.t.Fatalf("got a %T; want *RSTStreamFrame", f)
562	}
563	if rs.FrameHeader.StreamID != streamID {
564		st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
565	}
566	if rs.ErrCode != errCode {
567		st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode)
568	}
569}
570
571func (st *serverTester) wantWindowUpdate(streamID, incr uint32) {
572	f, err := st.readFrame()
573	if err != nil {
574		st.t.Fatalf("Error while expecting a WINDOW_UPDATE frame: %v", err)
575	}
576	wu, ok := f.(*WindowUpdateFrame)
577	if !ok {
578		st.t.Fatalf("got a %T; want *WindowUpdateFrame", f)
579	}
580	if wu.FrameHeader.StreamID != streamID {
581		st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
582	}
583	if wu.Increment != incr {
584		st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
585	}
586}
587
588func (st *serverTester) wantSettingsAck() {
589	f, err := st.readFrame()
590	if err != nil {
591		st.t.Fatal(err)
592	}
593	sf, ok := f.(*SettingsFrame)
594	if !ok {
595		st.t.Fatalf("Wanting a settings ACK, received a %T", f)
596	}
597	if !sf.Header().Flags.Has(FlagSettingsAck) {
598		st.t.Fatal("Settings Frame didn't have ACK set")
599	}
600}
601
602func (st *serverTester) wantPushPromise() *PushPromiseFrame {
603	f, err := st.readFrame()
604	if err != nil {
605		st.t.Fatal(err)
606	}
607	ppf, ok := f.(*PushPromiseFrame)
608	if !ok {
609		st.t.Fatalf("Wanted PushPromise, received %T", ppf)
610	}
611	return ppf
612}
613
614func TestServer(t *testing.T) {
615	gotReq := make(chan bool, 1)
616	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
617		w.Header().Set("Foo", "Bar")
618		gotReq <- true
619	})
620	defer st.Close()
621
622	covers("3.5", `
623		The server connection preface consists of a potentially empty
624		SETTINGS frame ([SETTINGS]) that MUST be the first frame the
625		server sends in the HTTP/2 connection.
626	`)
627
628	st.greet()
629	st.writeHeaders(HeadersFrameParam{
630		StreamID:      1, // clients send odd numbers
631		BlockFragment: st.encodeHeader(),
632		EndStream:     true, // no DATA frames
633		EndHeaders:    true,
634	})
635
636	select {
637	case <-gotReq:
638	case <-time.After(2 * time.Second):
639		t.Error("timeout waiting for request")
640	}
641}
642
643func TestServer_Request_Get(t *testing.T) {
644	testServerRequest(t, func(st *serverTester) {
645		st.writeHeaders(HeadersFrameParam{
646			StreamID:      1, // clients send odd numbers
647			BlockFragment: st.encodeHeader("foo-bar", "some-value"),
648			EndStream:     true, // no DATA frames
649			EndHeaders:    true,
650		})
651	}, func(r *http.Request) {
652		if r.Method != "GET" {
653			t.Errorf("Method = %q; want GET", r.Method)
654		}
655		if r.URL.Path != "/" {
656			t.Errorf("URL.Path = %q; want /", r.URL.Path)
657		}
658		if r.ContentLength != 0 {
659			t.Errorf("ContentLength = %v; want 0", r.ContentLength)
660		}
661		if r.Close {
662			t.Error("Close = true; want false")
663		}
664		if !strings.Contains(r.RemoteAddr, ":") {
665			t.Errorf("RemoteAddr = %q; want something with a colon", r.RemoteAddr)
666		}
667		if r.Proto != "HTTP/2.0" || r.ProtoMajor != 2 || r.ProtoMinor != 0 {
668			t.Errorf("Proto = %q Major=%v,Minor=%v; want HTTP/2.0", r.Proto, r.ProtoMajor, r.ProtoMinor)
669		}
670		wantHeader := http.Header{
671			"Foo-Bar": []string{"some-value"},
672		}
673		if !reflect.DeepEqual(r.Header, wantHeader) {
674			t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
675		}
676		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
677			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
678		}
679	})
680}
681
682func TestServer_Request_Get_PathSlashes(t *testing.T) {
683	testServerRequest(t, func(st *serverTester) {
684		st.writeHeaders(HeadersFrameParam{
685			StreamID:      1, // clients send odd numbers
686			BlockFragment: st.encodeHeader(":path", "/%2f/"),
687			EndStream:     true, // no DATA frames
688			EndHeaders:    true,
689		})
690	}, func(r *http.Request) {
691		if r.RequestURI != "/%2f/" {
692			t.Errorf("RequestURI = %q; want /%%2f/", r.RequestURI)
693		}
694		if r.URL.Path != "///" {
695			t.Errorf("URL.Path = %q; want ///", r.URL.Path)
696		}
697	})
698}
699
// TODO: add a test that sends HEADERS with EndStream=true but also sets a
// Content-Length. Should the server ignore that header and force the
// length to zero?
703
704func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) {
705	testServerRequest(t, func(st *serverTester) {
706		st.writeHeaders(HeadersFrameParam{
707			StreamID:      1, // clients send odd numbers
708			BlockFragment: st.encodeHeader(":method", "POST"),
709			EndStream:     true,
710			EndHeaders:    true,
711		})
712	}, func(r *http.Request) {
713		if r.Method != "POST" {
714			t.Errorf("Method = %q; want POST", r.Method)
715		}
716		if r.ContentLength != 0 {
717			t.Errorf("ContentLength = %v; want 0", r.ContentLength)
718		}
719		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
720			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
721		}
722	})
723}
724
725func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) {
726	testBodyContents(t, -1, "", func(st *serverTester) {
727		st.writeHeaders(HeadersFrameParam{
728			StreamID:      1, // clients send odd numbers
729			BlockFragment: st.encodeHeader(":method", "POST"),
730			EndStream:     false, // to say DATA frames are coming
731			EndHeaders:    true,
732		})
733		st.writeData(1, true, nil) // just kidding. empty body.
734	})
735}
736
737func TestServer_Request_Post_Body_OneData(t *testing.T) {
738	const content = "Some content"
739	testBodyContents(t, -1, content, func(st *serverTester) {
740		st.writeHeaders(HeadersFrameParam{
741			StreamID:      1, // clients send odd numbers
742			BlockFragment: st.encodeHeader(":method", "POST"),
743			EndStream:     false, // to say DATA frames are coming
744			EndHeaders:    true,
745		})
746		st.writeData(1, true, []byte(content))
747	})
748}
749
750func TestServer_Request_Post_Body_TwoData(t *testing.T) {
751	const content = "Some content"
752	testBodyContents(t, -1, content, func(st *serverTester) {
753		st.writeHeaders(HeadersFrameParam{
754			StreamID:      1, // clients send odd numbers
755			BlockFragment: st.encodeHeader(":method", "POST"),
756			EndStream:     false, // to say DATA frames are coming
757			EndHeaders:    true,
758		})
759		st.writeData(1, false, []byte(content[:5]))
760		st.writeData(1, true, []byte(content[5:]))
761	})
762}
763
764func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) {
765	const content = "Some content"
766	testBodyContents(t, int64(len(content)), content, func(st *serverTester) {
767		st.writeHeaders(HeadersFrameParam{
768			StreamID: 1, // clients send odd numbers
769			BlockFragment: st.encodeHeader(
770				":method", "POST",
771				"content-length", strconv.Itoa(len(content)),
772			),
773			EndStream:  false, // to say DATA frames are coming
774			EndHeaders: true,
775		})
776		st.writeData(1, true, []byte(content))
777	})
778}
779
780func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) {
781	testBodyContentsFail(t, 3, "request declared a Content-Length of 3 but only wrote 2 bytes",
782		func(st *serverTester) {
783			st.writeHeaders(HeadersFrameParam{
784				StreamID: 1, // clients send odd numbers
785				BlockFragment: st.encodeHeader(
786					":method", "POST",
787					"content-length", "3",
788				),
789				EndStream:  false, // to say DATA frames are coming
790				EndHeaders: true,
791			})
792			st.writeData(1, true, []byte("12"))
793		})
794}
795
796func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) {
797	testBodyContentsFail(t, 4, "sender tried to send more than declared Content-Length of 4 bytes",
798		func(st *serverTester) {
799			st.writeHeaders(HeadersFrameParam{
800				StreamID: 1, // clients send odd numbers
801				BlockFragment: st.encodeHeader(
802					":method", "POST",
803					"content-length", "4",
804				),
805				EndStream:  false, // to say DATA frames are coming
806				EndHeaders: true,
807			})
808			st.writeData(1, true, []byte("12345"))
809		})
810}
811
812func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, write func(st *serverTester)) {
813	testServerRequest(t, write, func(r *http.Request) {
814		if r.Method != "POST" {
815			t.Errorf("Method = %q; want POST", r.Method)
816		}
817		if r.ContentLength != wantContentLength {
818			t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
819		}
820		all, err := ioutil.ReadAll(r.Body)
821		if err != nil {
822			t.Fatal(err)
823		}
824		if string(all) != wantBody {
825			t.Errorf("Read = %q; want %q", all, wantBody)
826		}
827		if err := r.Body.Close(); err != nil {
828			t.Fatalf("Close: %v", err)
829		}
830	})
831}
832
833func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError string, write func(st *serverTester)) {
834	testServerRequest(t, write, func(r *http.Request) {
835		if r.Method != "POST" {
836			t.Errorf("Method = %q; want POST", r.Method)
837		}
838		if r.ContentLength != wantContentLength {
839			t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
840		}
841		all, err := ioutil.ReadAll(r.Body)
842		if err == nil {
843			t.Fatalf("expected an error (%q) reading from the body. Successfully read %q instead.",
844				wantReadError, all)
845		}
846		if !strings.Contains(err.Error(), wantReadError) {
847			t.Fatalf("Body.Read = %v; want substring %q", err, wantReadError)
848		}
849		if err := r.Body.Close(); err != nil {
850			t.Fatalf("Close: %v", err)
851		}
852	})
853}
854
855// Using a Host header, instead of :authority
856func TestServer_Request_Get_Host(t *testing.T) {
857	const host = "example.com"
858	testServerRequest(t, func(st *serverTester) {
859		st.writeHeaders(HeadersFrameParam{
860			StreamID:      1, // clients send odd numbers
861			BlockFragment: st.encodeHeader(":authority", "", "host", host),
862			EndStream:     true,
863			EndHeaders:    true,
864		})
865	}, func(r *http.Request) {
866		if r.Host != host {
867			t.Errorf("Host = %q; want %q", r.Host, host)
868		}
869	})
870}
871
872// Using an :authority pseudo-header, instead of Host
873func TestServer_Request_Get_Authority(t *testing.T) {
874	const host = "example.com"
875	testServerRequest(t, func(st *serverTester) {
876		st.writeHeaders(HeadersFrameParam{
877			StreamID:      1, // clients send odd numbers
878			BlockFragment: st.encodeHeader(":authority", host),
879			EndStream:     true,
880			EndHeaders:    true,
881		})
882	}, func(r *http.Request) {
883		if r.Host != host {
884			t.Errorf("Host = %q; want %q", r.Host, host)
885		}
886	})
887}
888
889func TestServer_Request_WithContinuation(t *testing.T) {
890	wantHeader := http.Header{
891		"Foo-One":   []string{"value-one"},
892		"Foo-Two":   []string{"value-two"},
893		"Foo-Three": []string{"value-three"},
894	}
895	testServerRequest(t, func(st *serverTester) {
896		fullHeaders := st.encodeHeader(
897			"foo-one", "value-one",
898			"foo-two", "value-two",
899			"foo-three", "value-three",
900		)
901		remain := fullHeaders
902		chunks := 0
903		for len(remain) > 0 {
904			const maxChunkSize = 5
905			chunk := remain
906			if len(chunk) > maxChunkSize {
907				chunk = chunk[:maxChunkSize]
908			}
909			remain = remain[len(chunk):]
910
911			if chunks == 0 {
912				st.writeHeaders(HeadersFrameParam{
913					StreamID:      1, // clients send odd numbers
914					BlockFragment: chunk,
915					EndStream:     true,  // no DATA frames
916					EndHeaders:    false, // we'll have continuation frames
917				})
918			} else {
919				err := st.fr.WriteContinuation(1, len(remain) == 0, chunk)
920				if err != nil {
921					t.Fatal(err)
922				}
923			}
924			chunks++
925		}
926		if chunks < 2 {
927			t.Fatal("too few chunks")
928		}
929	}, func(r *http.Request) {
930		if !reflect.DeepEqual(r.Header, wantHeader) {
931			t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
932		}
933	})
934}
935
936// Concatenated cookie headers. ("8.1.2.5 Compressing the Cookie Header Field")
937func TestServer_Request_CookieConcat(t *testing.T) {
938	const host = "example.com"
939	testServerRequest(t, func(st *serverTester) {
940		st.bodylessReq1(
941			":authority", host,
942			"cookie", "a=b",
943			"cookie", "c=d",
944			"cookie", "e=f",
945		)
946	}, func(r *http.Request) {
947		const want = "a=b; c=d; e=f"
948		if got := r.Header.Get("Cookie"); got != want {
949			t.Errorf("Cookie = %q; want %q", got, want)
950		}
951	})
952}
953
954func TestServer_Request_Reject_CapitalHeader(t *testing.T) {
955	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("UPPER", "v") })
956}
957
958func TestServer_Request_Reject_HeaderFieldNameColon(t *testing.T) {
959	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has:colon", "v") })
960}
961
962func TestServer_Request_Reject_HeaderFieldNameNULL(t *testing.T) {
963	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has\x00null", "v") })
964}
965
966func TestServer_Request_Reject_HeaderFieldNameEmpty(t *testing.T) {
967	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("", "v") })
968}
969
970func TestServer_Request_Reject_HeaderFieldValueNewline(t *testing.T) {
971	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\nnewline") })
972}
973
974func TestServer_Request_Reject_HeaderFieldValueCR(t *testing.T) {
975	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\rcarriage") })
976}
977
978func TestServer_Request_Reject_HeaderFieldValueDEL(t *testing.T) {
979	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\x7fdel") })
980}
981
982func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) {
983	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "") })
984}
985
986func TestServer_Request_Reject_Pseudo_ExactlyOne(t *testing.T) {
987	// 8.1.2.3 Request Pseudo-Header Fields
988	// "All HTTP/2 requests MUST include exactly one valid value" ...
989	testRejectRequest(t, func(st *serverTester) {
990		st.addLogFilter("duplicate pseudo-header")
991		st.bodylessReq1(":method", "GET", ":method", "POST")
992	})
993}
994
995func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) {
996	// 8.1.2.3 Request Pseudo-Header Fields
997	// "All pseudo-header fields MUST appear in the header block
998	// before regular header fields. Any request or response that
999	// contains a pseudo-header field that appears in a header
1000	// block after a regular header field MUST be treated as
1001	// malformed (Section 8.1.2.6)."
1002	testRejectRequest(t, func(st *serverTester) {
1003		st.addLogFilter("pseudo-header after regular header")
1004		var buf bytes.Buffer
1005		enc := hpack.NewEncoder(&buf)
1006		enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
1007		enc.WriteField(hpack.HeaderField{Name: "regular", Value: "foobar"})
1008		enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"})
1009		enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
1010		st.writeHeaders(HeadersFrameParam{
1011			StreamID:      1, // clients send odd numbers
1012			BlockFragment: buf.Bytes(),
1013			EndStream:     true,
1014			EndHeaders:    true,
1015		})
1016	})
1017}
1018
1019func TestServer_Request_Reject_Pseudo_Missing_path(t *testing.T) {
1020	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":path", "") })
1021}
1022
1023func TestServer_Request_Reject_Pseudo_Missing_scheme(t *testing.T) {
1024	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "") })
1025}
1026
1027func TestServer_Request_Reject_Pseudo_scheme_invalid(t *testing.T) {
1028	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "bogus") })
1029}
1030
1031func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
1032	testRejectRequest(t, func(st *serverTester) {
1033		st.addLogFilter(`invalid pseudo-header ":unknown_thing"`)
1034		st.bodylessReq1(":unknown_thing", "")
1035	})
1036}
1037
1038func testRejectRequest(t *testing.T, send func(*serverTester)) {
1039	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1040		t.Error("server request made it to handler; should've been rejected")
1041	})
1042	defer st.Close()
1043
1044	st.greet()
1045	send(st)
1046	st.wantRSTStream(1, ErrCodeProtocol)
1047}
1048
1049func testRejectRequestWithProtocolError(t *testing.T, send func(*serverTester)) {
1050	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1051		t.Error("server request made it to handler; should've been rejected")
1052	}, optQuiet)
1053	defer st.Close()
1054
1055	st.greet()
1056	send(st)
1057	gf := st.wantGoAway()
1058	if gf.ErrCode != ErrCodeProtocol {
1059		t.Errorf("err code = %v; want %v", gf.ErrCode, ErrCodeProtocol)
1060	}
1061}
1062
1063// Section 5.1, on idle connections: "Receiving any frame other than
1064// HEADERS or PRIORITY on a stream in this state MUST be treated as a
1065// connection error (Section 5.4.1) of type PROTOCOL_ERROR."
1066func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) {
1067	testRejectRequestWithProtocolError(t, func(st *serverTester) {
1068		st.fr.WriteWindowUpdate(123, 456)
1069	})
1070}
1071func TestRejectFrameOnIdle_Data(t *testing.T) {
1072	testRejectRequestWithProtocolError(t, func(st *serverTester) {
1073		st.fr.WriteData(123, true, nil)
1074	})
1075}
1076func TestRejectFrameOnIdle_RSTStream(t *testing.T) {
1077	testRejectRequestWithProtocolError(t, func(st *serverTester) {
1078		st.fr.WriteRSTStream(123, ErrCodeCancel)
1079	})
1080}
1081
1082func TestServer_Request_Connect(t *testing.T) {
1083	testServerRequest(t, func(st *serverTester) {
1084		st.writeHeaders(HeadersFrameParam{
1085			StreamID: 1,
1086			BlockFragment: st.encodeHeaderRaw(
1087				":method", "CONNECT",
1088				":authority", "example.com:123",
1089			),
1090			EndStream:  true,
1091			EndHeaders: true,
1092		})
1093	}, func(r *http.Request) {
1094		if g, w := r.Method, "CONNECT"; g != w {
1095			t.Errorf("Method = %q; want %q", g, w)
1096		}
1097		if g, w := r.RequestURI, "example.com:123"; g != w {
1098			t.Errorf("RequestURI = %q; want %q", g, w)
1099		}
1100		if g, w := r.URL.Host, "example.com:123"; g != w {
1101			t.Errorf("URL.Host = %q; want %q", g, w)
1102		}
1103	})
1104}
1105
1106func TestServer_Request_Connect_InvalidPath(t *testing.T) {
1107	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1108		st.writeHeaders(HeadersFrameParam{
1109			StreamID: 1,
1110			BlockFragment: st.encodeHeaderRaw(
1111				":method", "CONNECT",
1112				":authority", "example.com:123",
1113				":path", "/bogus",
1114			),
1115			EndStream:  true,
1116			EndHeaders: true,
1117		})
1118	})
1119}
1120
1121func TestServer_Request_Connect_InvalidScheme(t *testing.T) {
1122	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1123		st.writeHeaders(HeadersFrameParam{
1124			StreamID: 1,
1125			BlockFragment: st.encodeHeaderRaw(
1126				":method", "CONNECT",
1127				":authority", "example.com:123",
1128				":scheme", "https",
1129			),
1130			EndStream:  true,
1131			EndHeaders: true,
1132		})
1133	})
1134}
1135
1136func TestServer_Ping(t *testing.T) {
1137	st := newServerTester(t, nil)
1138	defer st.Close()
1139	st.greet()
1140
1141	// Server should ignore this one, since it has ACK set.
1142	ackPingData := [8]byte{1, 2, 4, 8, 16, 32, 64, 128}
1143	if err := st.fr.WritePing(true, ackPingData); err != nil {
1144		t.Fatal(err)
1145	}
1146
1147	// But the server should reply to this one, since ACK is false.
1148	pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
1149	if err := st.fr.WritePing(false, pingData); err != nil {
1150		t.Fatal(err)
1151	}
1152
1153	pf := st.wantPing()
1154	if !pf.Flags.Has(FlagPingAck) {
1155		t.Error("response ping doesn't have ACK set")
1156	}
1157	if pf.Data != pingData {
1158		t.Errorf("response ping has data %q; want %q", pf.Data, pingData)
1159	}
1160}
1161
1162func TestServer_RejectsLargeFrames(t *testing.T) {
1163	if runtime.GOOS == "windows" {
1164		t.Skip("see golang.org/issue/13434")
1165	}
1166
1167	st := newServerTester(t, nil)
1168	defer st.Close()
1169	st.greet()
1170
1171	// Write too large of a frame (too large by one byte)
1172	// We ignore the return value because it's expected that the server
1173	// will only read the first 9 bytes (the headre) and then disconnect.
1174	st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, defaultMaxReadFrameSize+1))
1175
1176	gf := st.wantGoAway()
1177	if gf.ErrCode != ErrCodeFrameSize {
1178		t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFrameSize)
1179	}
1180	if st.serverLogBuf.Len() != 0 {
1181		// Previously we spun here for a bit until the GOAWAY disconnect
1182		// timer fired, logging while we fired.
1183		t.Errorf("unexpected server output: %.500s\n", st.serverLogBuf.Bytes())
1184	}
1185}
1186
1187func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
1188	puppet := newHandlerPuppet()
1189	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1190		puppet.act(w, r)
1191	})
1192	defer st.Close()
1193	defer puppet.done()
1194
1195	st.greet()
1196
1197	st.writeHeaders(HeadersFrameParam{
1198		StreamID:      1, // clients send odd numbers
1199		BlockFragment: st.encodeHeader(":method", "POST"),
1200		EndStream:     false, // data coming
1201		EndHeaders:    true,
1202	})
1203	st.writeData(1, false, []byte("abcdef"))
1204	puppet.do(readBodyHandler(t, "abc"))
1205	st.wantWindowUpdate(0, 3)
1206	st.wantWindowUpdate(1, 3)
1207
1208	puppet.do(readBodyHandler(t, "def"))
1209	st.wantWindowUpdate(0, 3)
1210	st.wantWindowUpdate(1, 3)
1211
1212	st.writeData(1, true, []byte("ghijkl")) // END_STREAM here
1213	puppet.do(readBodyHandler(t, "ghi"))
1214	puppet.do(readBodyHandler(t, "jkl"))
1215	st.wantWindowUpdate(0, 3)
1216	st.wantWindowUpdate(0, 3) // no more stream-level, since END_STREAM
1217}
1218
1219// the version of the TestServer_Handler_Sends_WindowUpdate with padding.
1220// See golang.org/issue/16556
1221func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) {
1222	puppet := newHandlerPuppet()
1223	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1224		puppet.act(w, r)
1225	})
1226	defer st.Close()
1227	defer puppet.done()
1228
1229	st.greet()
1230
1231	st.writeHeaders(HeadersFrameParam{
1232		StreamID:      1,
1233		BlockFragment: st.encodeHeader(":method", "POST"),
1234		EndStream:     false,
1235		EndHeaders:    true,
1236	})
1237	st.writeDataPadded(1, false, []byte("abcdef"), []byte{0, 0, 0, 0})
1238
1239	// Expect to immediately get our 5 bytes of padding back for
1240	// both the connection and stream (4 bytes of padding + 1 byte of length)
1241	st.wantWindowUpdate(0, 5)
1242	st.wantWindowUpdate(1, 5)
1243
1244	puppet.do(readBodyHandler(t, "abc"))
1245	st.wantWindowUpdate(0, 3)
1246	st.wantWindowUpdate(1, 3)
1247
1248	puppet.do(readBodyHandler(t, "def"))
1249	st.wantWindowUpdate(0, 3)
1250	st.wantWindowUpdate(1, 3)
1251}
1252
1253func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) {
1254	st := newServerTester(t, nil)
1255	defer st.Close()
1256	st.greet()
1257	if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
1258		t.Fatal(err)
1259	}
1260	gf := st.wantGoAway()
1261	if gf.ErrCode != ErrCodeFlowControl {
1262		t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFlowControl)
1263	}
1264	if gf.LastStreamID != 0 {
1265		t.Errorf("GOAWAY last stream ID = %v; want %v", gf.LastStreamID, 0)
1266	}
1267}
1268
1269func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) {
1270	inHandler := make(chan bool)
1271	blockHandler := make(chan bool)
1272	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1273		inHandler <- true
1274		<-blockHandler
1275	})
1276	defer st.Close()
1277	defer close(blockHandler)
1278	st.greet()
1279	st.writeHeaders(HeadersFrameParam{
1280		StreamID:      1,
1281		BlockFragment: st.encodeHeader(":method", "POST"),
1282		EndStream:     false, // keep it open
1283		EndHeaders:    true,
1284	})
1285	<-inHandler
1286	// Send a bogus window update:
1287	if err := st.fr.WriteWindowUpdate(1, 1<<31-1); err != nil {
1288		t.Fatal(err)
1289	}
1290	st.wantRSTStream(1, ErrCodeFlowControl)
1291}
1292
1293// testServerPostUnblock sends a hanging POST with unsent data to handler,
1294// then runs fn once in the handler, and verifies that the error returned from
1295// handler is acceptable. It fails if takes over 5 seconds for handler to exit.
1296func testServerPostUnblock(t *testing.T,
1297	handler func(http.ResponseWriter, *http.Request) error,
1298	fn func(*serverTester),
1299	checkErr func(error),
1300	otherHeaders ...string) {
1301	inHandler := make(chan bool)
1302	errc := make(chan error, 1)
1303	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1304		inHandler <- true
1305		errc <- handler(w, r)
1306	})
1307	defer st.Close()
1308	st.greet()
1309	st.writeHeaders(HeadersFrameParam{
1310		StreamID:      1,
1311		BlockFragment: st.encodeHeader(append([]string{":method", "POST"}, otherHeaders...)...),
1312		EndStream:     false, // keep it open
1313		EndHeaders:    true,
1314	})
1315	<-inHandler
1316	fn(st)
1317	select {
1318	case err := <-errc:
1319		if checkErr != nil {
1320			checkErr(err)
1321		}
1322	case <-time.After(5 * time.Second):
1323		t.Fatal("timeout waiting for Handler to return")
1324	}
1325}
1326
1327func TestServer_RSTStream_Unblocks_Read(t *testing.T) {
1328	testServerPostUnblock(t,
1329		func(w http.ResponseWriter, r *http.Request) (err error) {
1330			_, err = r.Body.Read(make([]byte, 1))
1331			return
1332		},
1333		func(st *serverTester) {
1334			if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1335				t.Fatal(err)
1336			}
1337		},
1338		func(err error) {
1339			want := StreamError{StreamID: 0x1, Code: 0x8}
1340			if !reflect.DeepEqual(err, want) {
1341				t.Errorf("Read error = %v; want %v", err, want)
1342			}
1343		},
1344	)
1345}
1346
1347func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
1348	// Run this test a bunch, because it doesn't always
1349	// deadlock. But with a bunch, it did.
1350	n := 50
1351	if testing.Short() {
1352		n = 5
1353	}
1354	for i := 0; i < n; i++ {
1355		testServer_RSTStream_Unblocks_Header_Write(t)
1356	}
1357}
1358
1359func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
1360	inHandler := make(chan bool, 1)
1361	unblockHandler := make(chan bool, 1)
1362	headerWritten := make(chan bool, 1)
1363	wroteRST := make(chan bool, 1)
1364
1365	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1366		inHandler <- true
1367		<-wroteRST
1368		w.Header().Set("foo", "bar")
1369		w.WriteHeader(200)
1370		w.(http.Flusher).Flush()
1371		headerWritten <- true
1372		<-unblockHandler
1373	})
1374	defer st.Close()
1375
1376	st.greet()
1377	st.writeHeaders(HeadersFrameParam{
1378		StreamID:      1,
1379		BlockFragment: st.encodeHeader(":method", "POST"),
1380		EndStream:     false, // keep it open
1381		EndHeaders:    true,
1382	})
1383	<-inHandler
1384	if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1385		t.Fatal(err)
1386	}
1387	wroteRST <- true
1388	st.awaitIdle()
1389	select {
1390	case <-headerWritten:
1391	case <-time.After(2 * time.Second):
1392		t.Error("timeout waiting for header write")
1393	}
1394	unblockHandler <- true
1395}
1396
1397func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
1398	testServerPostUnblock(t,
1399		func(w http.ResponseWriter, r *http.Request) (err error) {
1400			_, err = r.Body.Read(make([]byte, 1))
1401			return
1402		},
1403		func(st *serverTester) { st.cc.Close() },
1404		func(err error) {
1405			if err == nil {
1406				t.Error("unexpected nil error from Request.Body.Read")
1407			}
1408		},
1409	)
1410}
1411
// blockUntilClosed is a handler for testServerPostUnblock. It waits on the
// ResponseWriter's CloseNotify channel and then returns nil. It panics if w
// does not implement http.CloseNotifier.
//
// NOTE: http.CloseNotifier is deprecated in favor of Request.Context; this
// handler exercises the legacy API that the CloseNotify tests cover.
var blockUntilClosed = func(w http.ResponseWriter, r *http.Request) error {
	notifier := w.(http.CloseNotifier)
	closed := notifier.CloseNotify()
	<-closed
	return nil
}
1416
1417func TestServer_CloseNotify_After_RSTStream(t *testing.T) {
1418	testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
1419		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1420			t.Fatal(err)
1421		}
1422	}, nil)
1423}
1424
1425func TestServer_CloseNotify_After_ConnClose(t *testing.T) {
1426	testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { st.cc.Close() }, nil)
1427}
1428
1429// that CloseNotify unblocks after a stream error due to the client's
1430// problem that's unrelated to them explicitly canceling it (which is
1431// TestServer_CloseNotify_After_RSTStream above)
1432func TestServer_CloseNotify_After_StreamError(t *testing.T) {
1433	testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
1434		// data longer than declared Content-Length => stream error
1435		st.writeData(1, true, []byte("1234"))
1436	}, nil, "content-length", "3")
1437}
1438
1439func TestServer_StateTransitions(t *testing.T) {
1440	var st *serverTester
1441	inHandler := make(chan bool)
1442	writeData := make(chan bool)
1443	leaveHandler := make(chan bool)
1444	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1445		inHandler <- true
1446		if st.stream(1) == nil {
1447			t.Errorf("nil stream 1 in handler")
1448		}
1449		if got, want := st.streamState(1), stateOpen; got != want {
1450			t.Errorf("in handler, state is %v; want %v", got, want)
1451		}
1452		writeData <- true
1453		if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
1454			t.Errorf("body read = %d, %v; want 0, EOF", n, err)
1455		}
1456		if got, want := st.streamState(1), stateHalfClosedRemote; got != want {
1457			t.Errorf("in handler, state is %v; want %v", got, want)
1458		}
1459
1460		<-leaveHandler
1461	})
1462	st.greet()
1463	if st.stream(1) != nil {
1464		t.Fatal("stream 1 should be empty")
1465	}
1466	if got := st.streamState(1); got != stateIdle {
1467		t.Fatalf("stream 1 should be idle; got %v", got)
1468	}
1469
1470	st.writeHeaders(HeadersFrameParam{
1471		StreamID:      1,
1472		BlockFragment: st.encodeHeader(":method", "POST"),
1473		EndStream:     false, // keep it open
1474		EndHeaders:    true,
1475	})
1476	<-inHandler
1477	<-writeData
1478	st.writeData(1, true, nil)
1479
1480	leaveHandler <- true
1481	hf := st.wantHeaders()
1482	if !hf.StreamEnded() {
1483		t.Fatal("expected END_STREAM flag")
1484	}
1485
1486	if got, want := st.streamState(1), stateClosed; got != want {
1487		t.Errorf("at end, state is %v; want %v", got, want)
1488	}
1489	if st.stream(1) != nil {
1490		t.Fatal("at end, stream 1 should be gone")
1491	}
1492}
1493
1494// test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
1495func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
1496	testServerRejectsConn(t, func(st *serverTester) {
1497		st.writeHeaders(HeadersFrameParam{
1498			StreamID:      1,
1499			BlockFragment: st.encodeHeader(),
1500			EndStream:     true,
1501			EndHeaders:    false,
1502		})
1503		st.writeHeaders(HeadersFrameParam{ // Not a continuation.
1504			StreamID:      3, // different stream.
1505			BlockFragment: st.encodeHeader(),
1506			EndStream:     true,
1507			EndHeaders:    true,
1508		})
1509	})
1510}
1511
1512// test HEADERS w/o EndHeaders + PING (should get rejected)
1513func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
1514	testServerRejectsConn(t, func(st *serverTester) {
1515		st.writeHeaders(HeadersFrameParam{
1516			StreamID:      1,
1517			BlockFragment: st.encodeHeader(),
1518			EndStream:     true,
1519			EndHeaders:    false,
1520		})
1521		if err := st.fr.WritePing(false, [8]byte{}); err != nil {
1522			t.Fatal(err)
1523		}
1524	})
1525}
1526
1527// test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
1528func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
1529	testServerRejectsConn(t, func(st *serverTester) {
1530		st.writeHeaders(HeadersFrameParam{
1531			StreamID:      1,
1532			BlockFragment: st.encodeHeader(),
1533			EndStream:     true,
1534			EndHeaders:    true,
1535		})
1536		st.wantHeaders()
1537		if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
1538			t.Fatal(err)
1539		}
1540	})
1541}
1542
1543// test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID
1544func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
1545	testServerRejectsConn(t, func(st *serverTester) {
1546		st.writeHeaders(HeadersFrameParam{
1547			StreamID:      1,
1548			BlockFragment: st.encodeHeader(),
1549			EndStream:     true,
1550			EndHeaders:    false,
1551		})
1552		if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
1553			t.Fatal(err)
1554		}
1555	})
1556}
1557
1558// No HEADERS on stream 0.
1559func TestServer_Rejects_Headers0(t *testing.T) {
1560	testServerRejectsConn(t, func(st *serverTester) {
1561		st.fr.AllowIllegalWrites = true
1562		st.writeHeaders(HeadersFrameParam{
1563			StreamID:      0,
1564			BlockFragment: st.encodeHeader(),
1565			EndStream:     true,
1566			EndHeaders:    true,
1567		})
1568	})
1569}
1570
1571// No CONTINUATION on stream 0.
1572func TestServer_Rejects_Continuation0(t *testing.T) {
1573	testServerRejectsConn(t, func(st *serverTester) {
1574		st.fr.AllowIllegalWrites = true
1575		if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil {
1576			t.Fatal(err)
1577		}
1578	})
1579}
1580
1581// No PRIORITY on stream 0.
1582func TestServer_Rejects_Priority0(t *testing.T) {
1583	testServerRejectsConn(t, func(st *serverTester) {
1584		st.fr.AllowIllegalWrites = true
1585		st.writePriority(0, PriorityParam{StreamDep: 1})
1586	})
1587}
1588
1589// No HEADERS frame with a self-dependence.
1590func TestServer_Rejects_HeadersSelfDependence(t *testing.T) {
1591	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1592		st.fr.AllowIllegalWrites = true
1593		st.writeHeaders(HeadersFrameParam{
1594			StreamID:      1,
1595			BlockFragment: st.encodeHeader(),
1596			EndStream:     true,
1597			EndHeaders:    true,
1598			Priority:      PriorityParam{StreamDep: 1},
1599		})
1600	})
1601}
1602
1603// No PRIORTY frame with a self-dependence.
1604func TestServer_Rejects_PrioritySelfDependence(t *testing.T) {
1605	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1606		st.fr.AllowIllegalWrites = true
1607		st.writePriority(1, PriorityParam{StreamDep: 1})
1608	})
1609}
1610
1611func TestServer_Rejects_PushPromise(t *testing.T) {
1612	testServerRejectsConn(t, func(st *serverTester) {
1613		pp := PushPromiseParam{
1614			StreamID:  1,
1615			PromiseID: 3,
1616		}
1617		if err := st.fr.WritePushPromise(pp); err != nil {
1618			t.Fatal(err)
1619		}
1620	})
1621}
1622
1623// testServerRejectsConn tests that the server hangs up with a GOAWAY
1624// frame and a server close after the client does something
1625// deserving a CONNECTION_ERROR.
1626func testServerRejectsConn(t *testing.T, writeReq func(*serverTester)) {
1627	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
1628	st.addLogFilter("connection error: PROTOCOL_ERROR")
1629	defer st.Close()
1630	st.greet()
1631	writeReq(st)
1632
1633	st.wantGoAway()
1634	errc := make(chan error, 1)
1635	go func() {
1636		fr, err := st.fr.ReadFrame()
1637		if err == nil {
1638			err = fmt.Errorf("got frame of type %T", fr)
1639		}
1640		errc <- err
1641	}()
1642	select {
1643	case err := <-errc:
1644		if err != io.EOF {
1645			t.Errorf("ReadFrame = %v; want io.EOF", err)
1646		}
1647	case <-time.After(2 * time.Second):
1648		t.Error("timeout waiting for disconnect")
1649	}
1650}
1651
1652// testServerRejectsStream tests that the server sends a RST_STREAM with the provided
1653// error code after a client sends a bogus request.
1654func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTester)) {
1655	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
1656	defer st.Close()
1657	st.greet()
1658	writeReq(st)
1659	st.wantRSTStream(1, code)
1660}
1661
1662// testServerRequest sets up an idle HTTP/2 connection and lets you
1663// write a single request with writeReq, and then verify that the
1664// *http.Request is built correctly in checkReq.
1665func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) {
1666	gotReq := make(chan bool, 1)
1667	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1668		if r.Body == nil {
1669			t.Fatal("nil Body")
1670		}
1671		checkReq(r)
1672		gotReq <- true
1673	})
1674	defer st.Close()
1675
1676	st.greet()
1677	writeReq(st)
1678
1679	select {
1680	case <-gotReq:
1681	case <-time.After(2 * time.Second):
1682		t.Error("timeout waiting for request")
1683	}
1684}
1685
1686func getSlash(st *serverTester) { st.bodylessReq1() }
1687
1688func TestServer_Response_NoData(t *testing.T) {
1689	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1690		// Nothing.
1691		return nil
1692	}, func(st *serverTester) {
1693		getSlash(st)
1694		hf := st.wantHeaders()
1695		if !hf.StreamEnded() {
1696			t.Fatal("want END_STREAM flag")
1697		}
1698		if !hf.HeadersEnded() {
1699			t.Fatal("want END_HEADERS flag")
1700		}
1701	})
1702}
1703
1704func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
1705	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1706		w.Header().Set("Foo-Bar", "some-value")
1707		return nil
1708	}, func(st *serverTester) {
1709		getSlash(st)
1710		hf := st.wantHeaders()
1711		if !hf.StreamEnded() {
1712			t.Fatal("want END_STREAM flag")
1713		}
1714		if !hf.HeadersEnded() {
1715			t.Fatal("want END_HEADERS flag")
1716		}
1717		goth := st.decodeHeader(hf.HeaderBlockFragment())
1718		wanth := [][2]string{
1719			{":status", "200"},
1720			{"foo-bar", "some-value"},
1721			{"content-length", "0"},
1722		}
1723		if !reflect.DeepEqual(goth, wanth) {
1724			t.Errorf("Got headers %v; want %v", goth, wanth)
1725		}
1726	})
1727}
1728
1729func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) {
1730	const msg = "<html>this is HTML."
1731	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1732		w.Header().Set("Content-Type", "foo/bar")
1733		io.WriteString(w, msg)
1734		return nil
1735	}, func(st *serverTester) {
1736		getSlash(st)
1737		hf := st.wantHeaders()
1738		if hf.StreamEnded() {
1739			t.Fatal("don't want END_STREAM, expecting data")
1740		}
1741		if !hf.HeadersEnded() {
1742			t.Fatal("want END_HEADERS flag")
1743		}
1744		goth := st.decodeHeader(hf.HeaderBlockFragment())
1745		wanth := [][2]string{
1746			{":status", "200"},
1747			{"content-type", "foo/bar"},
1748			{"content-length", strconv.Itoa(len(msg))},
1749		}
1750		if !reflect.DeepEqual(goth, wanth) {
1751			t.Errorf("Got headers %v; want %v", goth, wanth)
1752		}
1753		df := st.wantData()
1754		if !df.StreamEnded() {
1755			t.Error("expected DATA to have END_STREAM flag")
1756		}
1757		if got := string(df.Data()); got != msg {
1758			t.Errorf("got DATA %q; want %q", got, msg)
1759		}
1760	})
1761}
1762
1763func TestServer_Response_TransferEncoding_chunked(t *testing.T) {
1764	const msg = "hi"
1765	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1766		w.Header().Set("Transfer-Encoding", "chunked") // should be stripped
1767		io.WriteString(w, msg)
1768		return nil
1769	}, func(st *serverTester) {
1770		getSlash(st)
1771		hf := st.wantHeaders()
1772		goth := st.decodeHeader(hf.HeaderBlockFragment())
1773		wanth := [][2]string{
1774			{":status", "200"},
1775			{"content-type", "text/plain; charset=utf-8"},
1776			{"content-length", strconv.Itoa(len(msg))},
1777		}
1778		if !reflect.DeepEqual(goth, wanth) {
1779			t.Errorf("Got headers %v; want %v", goth, wanth)
1780		}
1781	})
1782}
1783
1784// Header accessed only after the initial write.
1785func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) {
1786	const msg = "<html>this is HTML."
1787	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1788		io.WriteString(w, msg)
1789		w.Header().Set("foo", "should be ignored")
1790		return nil
1791	}, func(st *serverTester) {
1792		getSlash(st)
1793		hf := st.wantHeaders()
1794		if hf.StreamEnded() {
1795			t.Fatal("unexpected END_STREAM")
1796		}
1797		if !hf.HeadersEnded() {
1798			t.Fatal("want END_HEADERS flag")
1799		}
1800		goth := st.decodeHeader(hf.HeaderBlockFragment())
1801		wanth := [][2]string{
1802			{":status", "200"},
1803			{"content-type", "text/html; charset=utf-8"},
1804			{"content-length", strconv.Itoa(len(msg))},
1805		}
1806		if !reflect.DeepEqual(goth, wanth) {
1807			t.Errorf("Got headers %v; want %v", goth, wanth)
1808		}
1809	})
1810}
1811
1812// Header accessed before the initial write and later mutated.
1813func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) {
1814	const msg = "<html>this is HTML."
1815	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1816		w.Header().Set("foo", "proper value")
1817		io.WriteString(w, msg)
1818		w.Header().Set("foo", "should be ignored")
1819		return nil
1820	}, func(st *serverTester) {
1821		getSlash(st)
1822		hf := st.wantHeaders()
1823		if hf.StreamEnded() {
1824			t.Fatal("unexpected END_STREAM")
1825		}
1826		if !hf.HeadersEnded() {
1827			t.Fatal("want END_HEADERS flag")
1828		}
1829		goth := st.decodeHeader(hf.HeaderBlockFragment())
1830		wanth := [][2]string{
1831			{":status", "200"},
1832			{"foo", "proper value"},
1833			{"content-type", "text/html; charset=utf-8"},
1834			{"content-length", strconv.Itoa(len(msg))},
1835		}
1836		if !reflect.DeepEqual(goth, wanth) {
1837			t.Errorf("Got headers %v; want %v", goth, wanth)
1838		}
1839	})
1840}
1841
1842func TestServer_Response_Data_SniffLenType(t *testing.T) {
1843	const msg = "<html>this is HTML."
1844	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1845		io.WriteString(w, msg)
1846		return nil
1847	}, func(st *serverTester) {
1848		getSlash(st)
1849		hf := st.wantHeaders()
1850		if hf.StreamEnded() {
1851			t.Fatal("don't want END_STREAM, expecting data")
1852		}
1853		if !hf.HeadersEnded() {
1854			t.Fatal("want END_HEADERS flag")
1855		}
1856		goth := st.decodeHeader(hf.HeaderBlockFragment())
1857		wanth := [][2]string{
1858			{":status", "200"},
1859			{"content-type", "text/html; charset=utf-8"},
1860			{"content-length", strconv.Itoa(len(msg))},
1861		}
1862		if !reflect.DeepEqual(goth, wanth) {
1863			t.Errorf("Got headers %v; want %v", goth, wanth)
1864		}
1865		df := st.wantData()
1866		if !df.StreamEnded() {
1867			t.Error("expected DATA to have END_STREAM flag")
1868		}
1869		if got := string(df.Data()); got != msg {
1870			t.Errorf("got DATA %q; want %q", got, msg)
1871		}
1872	})
1873}
1874
1875func TestServer_Response_Header_Flush_MidWrite(t *testing.T) {
1876	const msg = "<html>this is HTML"
1877	const msg2 = ", and this is the next chunk"
1878	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1879		io.WriteString(w, msg)
1880		w.(http.Flusher).Flush()
1881		io.WriteString(w, msg2)
1882		return nil
1883	}, func(st *serverTester) {
1884		getSlash(st)
1885		hf := st.wantHeaders()
1886		if hf.StreamEnded() {
1887			t.Fatal("unexpected END_STREAM flag")
1888		}
1889		if !hf.HeadersEnded() {
1890			t.Fatal("want END_HEADERS flag")
1891		}
1892		goth := st.decodeHeader(hf.HeaderBlockFragment())
1893		wanth := [][2]string{
1894			{":status", "200"},
1895			{"content-type", "text/html; charset=utf-8"}, // sniffed
1896			// and no content-length
1897		}
1898		if !reflect.DeepEqual(goth, wanth) {
1899			t.Errorf("Got headers %v; want %v", goth, wanth)
1900		}
1901		{
1902			df := st.wantData()
1903			if df.StreamEnded() {
1904				t.Error("unexpected END_STREAM flag")
1905			}
1906			if got := string(df.Data()); got != msg {
1907				t.Errorf("got DATA %q; want %q", got, msg)
1908			}
1909		}
1910		{
1911			df := st.wantData()
1912			if !df.StreamEnded() {
1913				t.Error("wanted END_STREAM flag on last data chunk")
1914			}
1915			if got := string(df.Data()); got != msg2 {
1916				t.Errorf("got DATA %q; want %q", got, msg2)
1917			}
1918		}
1919	})
1920}
1921
1922func TestServer_Response_LargeWrite(t *testing.T) {
1923	const size = 1 << 20
1924	const maxFrameSize = 16 << 10
1925	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1926		n, err := w.Write(bytes.Repeat([]byte("a"), size))
1927		if err != nil {
1928			return fmt.Errorf("Write error: %v", err)
1929		}
1930		if n != size {
1931			return fmt.Errorf("wrong size %d from Write", n)
1932		}
1933		return nil
1934	}, func(st *serverTester) {
1935		if err := st.fr.WriteSettings(
1936			Setting{SettingInitialWindowSize, 0},
1937			Setting{SettingMaxFrameSize, maxFrameSize},
1938		); err != nil {
1939			t.Fatal(err)
1940		}
1941		st.wantSettingsAck()
1942
1943		getSlash(st) // make the single request
1944
1945		// Give the handler quota to write:
1946		if err := st.fr.WriteWindowUpdate(1, size); err != nil {
1947			t.Fatal(err)
1948		}
1949		// Give the handler quota to write to connection-level
1950		// window as well
1951		if err := st.fr.WriteWindowUpdate(0, size); err != nil {
1952			t.Fatal(err)
1953		}
1954		hf := st.wantHeaders()
1955		if hf.StreamEnded() {
1956			t.Fatal("unexpected END_STREAM flag")
1957		}
1958		if !hf.HeadersEnded() {
1959			t.Fatal("want END_HEADERS flag")
1960		}
1961		goth := st.decodeHeader(hf.HeaderBlockFragment())
1962		wanth := [][2]string{
1963			{":status", "200"},
1964			{"content-type", "text/plain; charset=utf-8"}, // sniffed
1965			// and no content-length
1966		}
1967		if !reflect.DeepEqual(goth, wanth) {
1968			t.Errorf("Got headers %v; want %v", goth, wanth)
1969		}
1970		var bytes, frames int
1971		for {
1972			df := st.wantData()
1973			bytes += len(df.Data())
1974			frames++
1975			for _, b := range df.Data() {
1976				if b != 'a' {
1977					t.Fatal("non-'a' byte seen in DATA")
1978				}
1979			}
1980			if df.StreamEnded() {
1981				break
1982			}
1983		}
1984		if bytes != size {
1985			t.Errorf("Got %d bytes; want %d", bytes, size)
1986		}
1987		if want := int(size / maxFrameSize); frames < want || frames > want*2 {
1988			t.Errorf("Got %d frames; want %d", frames, size)
1989		}
1990	})
1991}
1992
1993// Test that the handler can't write more than the client allows
1994func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) {
1995	// Make these reads. Before each read, the client adds exactly enough
1996	// flow-control to satisfy the read. Numbers chosen arbitrarily.
1997	reads := []int{123, 1, 13, 127}
1998	size := 0
1999	for _, n := range reads {
2000		size += n
2001	}
2002
2003	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2004		w.(http.Flusher).Flush()
2005		n, err := w.Write(bytes.Repeat([]byte("a"), size))
2006		if err != nil {
2007			return fmt.Errorf("Write error: %v", err)
2008		}
2009		if n != size {
2010			return fmt.Errorf("wrong size %d from Write", n)
2011		}
2012		return nil
2013	}, func(st *serverTester) {
2014		// Set the window size to something explicit for this test.
2015		// It's also how much initial data we expect.
2016		if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, uint32(reads[0])}); err != nil {
2017			t.Fatal(err)
2018		}
2019		st.wantSettingsAck()
2020
2021		getSlash(st) // make the single request
2022
2023		hf := st.wantHeaders()
2024		if hf.StreamEnded() {
2025			t.Fatal("unexpected END_STREAM flag")
2026		}
2027		if !hf.HeadersEnded() {
2028			t.Fatal("want END_HEADERS flag")
2029		}
2030
2031		df := st.wantData()
2032		if got := len(df.Data()); got != reads[0] {
2033			t.Fatalf("Initial window size = %d but got DATA with %d bytes", reads[0], got)
2034		}
2035
2036		for _, quota := range reads[1:] {
2037			if err := st.fr.WriteWindowUpdate(1, uint32(quota)); err != nil {
2038				t.Fatal(err)
2039			}
2040			df := st.wantData()
2041			if int(quota) != len(df.Data()) {
2042				t.Fatalf("read %d bytes after giving %d quota", len(df.Data()), quota)
2043			}
2044		}
2045	})
2046}
2047
2048// Test that the handler blocked in a Write is unblocked if the server sends a RST_STREAM.
2049func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) {
2050	const size = 1 << 20
2051	const maxFrameSize = 16 << 10
2052	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2053		w.(http.Flusher).Flush()
2054		errc := make(chan error, 1)
2055		go func() {
2056			_, err := w.Write(bytes.Repeat([]byte("a"), size))
2057			errc <- err
2058		}()
2059		select {
2060		case err := <-errc:
2061			if err == nil {
2062				return errors.New("unexpected nil error from Write in handler")
2063			}
2064			return nil
2065		case <-time.After(2 * time.Second):
2066			return errors.New("timeout waiting for Write in handler")
2067		}
2068	}, func(st *serverTester) {
2069		if err := st.fr.WriteSettings(
2070			Setting{SettingInitialWindowSize, 0},
2071			Setting{SettingMaxFrameSize, maxFrameSize},
2072		); err != nil {
2073			t.Fatal(err)
2074		}
2075		st.wantSettingsAck()
2076
2077		getSlash(st) // make the single request
2078
2079		hf := st.wantHeaders()
2080		if hf.StreamEnded() {
2081			t.Fatal("unexpected END_STREAM flag")
2082		}
2083		if !hf.HeadersEnded() {
2084			t.Fatal("want END_HEADERS flag")
2085		}
2086
2087		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
2088			t.Fatal(err)
2089		}
2090	})
2091}
2092
2093func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) {
2094	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2095		w.(http.Flusher).Flush()
2096		// Nothing; send empty DATA
2097		return nil
2098	}, func(st *serverTester) {
2099		// Handler gets no data quota:
2100		if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, 0}); err != nil {
2101			t.Fatal(err)
2102		}
2103		st.wantSettingsAck()
2104
2105		getSlash(st) // make the single request
2106
2107		hf := st.wantHeaders()
2108		if hf.StreamEnded() {
2109			t.Fatal("unexpected END_STREAM flag")
2110		}
2111		if !hf.HeadersEnded() {
2112			t.Fatal("want END_HEADERS flag")
2113		}
2114
2115		df := st.wantData()
2116		if got := len(df.Data()); got != 0 {
2117			t.Fatalf("unexpected %d DATA bytes; want 0", got)
2118		}
2119		if !df.StreamEnded() {
2120			t.Fatal("DATA didn't have END_STREAM")
2121		}
2122	})
2123}
2124
2125func TestServer_Response_Automatic100Continue(t *testing.T) {
2126	const msg = "foo"
2127	const reply = "bar"
2128	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2129		if v := r.Header.Get("Expect"); v != "" {
2130			t.Errorf("Expect header = %q; want empty", v)
2131		}
2132		buf := make([]byte, len(msg))
2133		// This read should trigger the 100-continue being sent.
2134		if n, err := io.ReadFull(r.Body, buf); err != nil || n != len(msg) || string(buf) != msg {
2135			return fmt.Errorf("ReadFull = %q, %v; want %q, nil", buf[:n], err, msg)
2136		}
2137		_, err := io.WriteString(w, reply)
2138		return err
2139	}, func(st *serverTester) {
2140		st.writeHeaders(HeadersFrameParam{
2141			StreamID:      1, // clients send odd numbers
2142			BlockFragment: st.encodeHeader(":method", "POST", "expect", "100-continue"),
2143			EndStream:     false,
2144			EndHeaders:    true,
2145		})
2146		hf := st.wantHeaders()
2147		if hf.StreamEnded() {
2148			t.Fatal("unexpected END_STREAM flag")
2149		}
2150		if !hf.HeadersEnded() {
2151			t.Fatal("want END_HEADERS flag")
2152		}
2153		goth := st.decodeHeader(hf.HeaderBlockFragment())
2154		wanth := [][2]string{
2155			{":status", "100"},
2156		}
2157		if !reflect.DeepEqual(goth, wanth) {
2158			t.Fatalf("Got headers %v; want %v", goth, wanth)
2159		}
2160
2161		// Okay, they sent status 100, so we can send our
2162		// gigantic and/or sensitive "foo" payload now.
2163		st.writeData(1, true, []byte(msg))
2164
2165		st.wantWindowUpdate(0, uint32(len(msg)))
2166
2167		hf = st.wantHeaders()
2168		if hf.StreamEnded() {
2169			t.Fatal("expected data to follow")
2170		}
2171		if !hf.HeadersEnded() {
2172			t.Fatal("want END_HEADERS flag")
2173		}
2174		goth = st.decodeHeader(hf.HeaderBlockFragment())
2175		wanth = [][2]string{
2176			{":status", "200"},
2177			{"content-type", "text/plain; charset=utf-8"},
2178			{"content-length", strconv.Itoa(len(reply))},
2179		}
2180		if !reflect.DeepEqual(goth, wanth) {
2181			t.Errorf("Got headers %v; want %v", goth, wanth)
2182		}
2183
2184		df := st.wantData()
2185		if string(df.Data()) != reply {
2186			t.Errorf("Client read %q; want %q", df.Data(), reply)
2187		}
2188		if !df.StreamEnded() {
2189			t.Errorf("expect data stream end")
2190		}
2191	})
2192}
2193
2194func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
2195	errc := make(chan error, 1)
2196	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2197		p := []byte("some data.\n")
2198		for {
2199			_, err := w.Write(p)
2200			if err != nil {
2201				errc <- err
2202				return nil
2203			}
2204		}
2205	}, func(st *serverTester) {
2206		st.writeHeaders(HeadersFrameParam{
2207			StreamID:      1,
2208			BlockFragment: st.encodeHeader(),
2209			EndStream:     false,
2210			EndHeaders:    true,
2211		})
2212		hf := st.wantHeaders()
2213		if hf.StreamEnded() {
2214			t.Fatal("unexpected END_STREAM flag")
2215		}
2216		if !hf.HeadersEnded() {
2217			t.Fatal("want END_HEADERS flag")
2218		}
2219		// Close the connection and wait for the handler to (hopefully) notice.
2220		st.cc.Close()
2221		select {
2222		case <-errc:
2223		case <-time.After(5 * time.Second):
2224			t.Error("timeout")
2225		}
2226	})
2227}
2228
// TestServer_Rejects_Too_Many_Streams fills the server with defaultMaxStreams
// concurrently running (blocked) handlers. It checks that one more stream is
// refused with RST_STREAM(PROTOCOL_ERROR). It then checks that a new stream is
// accepted once one of the running handlers has finished.
func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
	const testPath = "/some/path"

	// inHandler receives each handler's stream ID as it starts; each send on
	// leaveHandler releases one blocked handler.
	inHandler := make(chan uint32)
	leaveHandler := make(chan bool)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		id := w.(*responseWriter).rws.stream.id
		inHandler <- id
		// Only the stream opened after the rejected one is checked: its
		// :path must still decode correctly.
		if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath {
			t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath)
		}
		<-leaveHandler
	})
	defer st.Close()
	st.greet()
	// streamID returns successive odd (client-initiated) stream IDs: 1, 3, 5, ...
	nextStreamID := uint32(1)
	streamID := func() uint32 {
		defer func() { nextStreamID += 2 }()
		return nextStreamID
	}
	// sendReq sends a complete request (END_STREAM + END_HEADERS) on stream id.
	sendReq := func(id uint32, headers ...string) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      id,
			BlockFragment: st.encodeHeader(headers...),
			EndStream:     true,
			EndHeaders:    true,
		})
	}
	// Open exactly defaultMaxStreams streams, waiting for each handler to start.
	for i := 0; i < defaultMaxStreams; i++ {
		sendReq(streamID())
		<-inHandler
	}
	// Release the blocked handlers when the test returns.
	// NOTE(review): if the final select below times out, the last handler is
	// still blocked sending on inHandler, so only defaultMaxStreams-1 handlers
	// wait on leaveHandler and this loop's last send can block forever.
	defer func() {
		for i := 0; i < defaultMaxStreams; i++ {
			leaveHandler <- true
		}
	}()

	// And this one should cross the limit:
	// (It's also sent as a CONTINUATION, to verify we still track the decoder context,
	// even if we're rejecting it)
	rejectID := streamID()
	headerBlock := st.encodeHeader(":path", testPath)
	frag1, frag2 := headerBlock[:3], headerBlock[3:]
	st.writeHeaders(HeadersFrameParam{
		StreamID:      rejectID,
		BlockFragment: frag1,
		EndStream:     true,
		EndHeaders:    false, // CONTINUATION coming
	})
	if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil {
		t.Fatal(err)
	}
	st.wantRSTStream(rejectID, ErrCodeProtocol)

	// But let a handler finish (and consume its response HEADERS):
	leaveHandler <- true
	st.wantHeaders()

	// And now another stream should be able to start:
	goodID := streamID()
	sendReq(goodID, ":path", testPath)
	select {
	case got := <-inHandler:
		if got != goodID {
			t.Errorf("Got stream %d; want %d", got, goodID)
		}
	case <-time.After(3 * time.Second):
		t.Error("timeout waiting for handler")
	}
}
2300
2301// So many response headers that the server needs to use CONTINUATION frames:
2302func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) {
2303	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2304		h := w.Header()
2305		for i := 0; i < 5000; i++ {
2306			h.Set(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("x-value-%d", i))
2307		}
2308		return nil
2309	}, func(st *serverTester) {
2310		getSlash(st)
2311		hf := st.wantHeaders()
2312		if hf.HeadersEnded() {
2313			t.Fatal("got unwanted END_HEADERS flag")
2314		}
2315		n := 0
2316		for {
2317			n++
2318			cf := st.wantContinuation()
2319			if cf.HeadersEnded() {
2320				break
2321			}
2322		}
2323		if n < 5 {
2324			t.Errorf("Only got %d CONTINUATION frames; expected 5+ (currently 6)", n)
2325		}
2326	})
2327}
2328
// This previously crashed (reported by Mathieu Lonjaret as observed
// while using Camlistore) because we got a DATA frame from the client
// after the handler exited and our logic at the time was wrong,
// keeping a stream in the map in stateClosed, which tickled an
// invariant check later when we tried to remove that stream (via
// defer sc.closeAllStreamsOnConnClose) when the serverConn serve loop
// ended.
//
// The test lets the handler exit while the client is still sending DATA, then
// sends more DATA. It closes the connection and checks that the serve loop
// ends without the panic hook recording a panic.
func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		// nothing
		return nil
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(),
			EndStream:     false, // DATA is coming
			EndHeaders:    true,
		})
		hf := st.wantHeaders()
		if !hf.HeadersEnded() || !hf.StreamEnded() {
			t.Fatalf("want END_HEADERS+END_STREAM, got %v", hf)
		}

		// Sent when a handler exits while the client has
		// indicated it's still sending DATA:
		st.wantRSTStream(1, ErrCodeNo)

		// Now the handler has ended, so it's ended its
		// stream, but the client hasn't closed its side
		// (stateClosedLocal). So send more data and verify
		// it doesn't crash with an internal invariant panic, like
		// it did before.
		st.writeData(1, true, []byte("foo"))

		// Sent after a peer sends data anyway (admittedly the
		// previous RST_STREAM might've still been in-flight),
		// but they'll get the friendlier NO_ERROR code
		// first.
		st.wantRSTStream(1, ErrCodeStreamClosed)

		// Set up a bunch of machinery to record the panic we saw
		// previously. panMu guards panicVal, which the hook writes
		// (presumably from the server's goroutine) and this
		// goroutine reads below.
		var (
			panMu    sync.Mutex
			panicVal interface{}
		)

		testHookOnPanicMu.Lock()
		testHookOnPanic = func(sc *serverConn, pv interface{}) bool {
			panMu.Lock()
			panicVal = pv
			panMu.Unlock()
			return true
		}
		testHookOnPanicMu.Unlock()

		// Now force the serve loop to end, via closing the connection.
		st.cc.Close()
		select {
		case <-st.sc.doneServing:
			// Loop has exited.
			panMu.Lock()
			got := panicVal
			panMu.Unlock()
			if got != nil {
				t.Errorf("Got panic: %v", got)
			}
		case <-time.After(5 * time.Second):
			t.Error("timeout")
		}
	})
}
2401
2402func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) }
2403func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) }
2404
2405func testRejectTLS(t *testing.T, max uint16) {
2406	st := newServerTester(t, nil, func(c *tls.Config) {
2407		c.MaxVersion = max
2408	})
2409	defer st.Close()
2410	gf := st.wantGoAway()
2411	if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
2412		t.Errorf("Got error code %v; want %v", got, want)
2413	}
2414}
2415
2416func TestServer_Rejects_TLSBadCipher(t *testing.T) {
2417	st := newServerTester(t, nil, func(c *tls.Config) {
2418		// Only list bad ones:
2419		c.CipherSuites = []uint16{
2420			tls.TLS_RSA_WITH_RC4_128_SHA,
2421			tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
2422			tls.TLS_RSA_WITH_AES_128_CBC_SHA,
2423			tls.TLS_RSA_WITH_AES_256_CBC_SHA,
2424			tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
2425			tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
2426			tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
2427			tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
2428			tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
2429			tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
2430			tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
2431			cipher_TLS_RSA_WITH_AES_128_CBC_SHA256,
2432		}
2433	})
2434	defer st.Close()
2435	gf := st.wantGoAway()
2436	if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
2437		t.Errorf("Got error code %v; want %v", got, want)
2438	}
2439}
2440
2441func TestServer_Advertises_Common_Cipher(t *testing.T) {
2442	const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
2443	st := newServerTester(t, nil, func(c *tls.Config) {
2444		// Have the client only support the one required by the spec.
2445		c.CipherSuites = []uint16{requiredSuite}
2446	}, func(ts *httptest.Server) {
2447		var srv *http.Server = ts.Config
2448		// Have the server configured with no specific cipher suites.
2449		// This tests that Go's defaults include the required one.
2450		srv.TLSConfig = nil
2451	})
2452	defer st.Close()
2453	st.greet()
2454}
2455
2456func (st *serverTester) onHeaderField(f hpack.HeaderField) {
2457	if f.Name == "date" {
2458		return
2459	}
2460	st.decodedHeaders = append(st.decodedHeaders, [2]string{f.Name, f.Value})
2461}
2462
2463func (st *serverTester) decodeHeader(headerBlock []byte) (pairs [][2]string) {
2464	st.decodedHeaders = nil
2465	if _, err := st.hpackDec.Write(headerBlock); err != nil {
2466		st.t.Fatalf("hpack decoding error: %v", err)
2467	}
2468	if err := st.hpackDec.Close(); err != nil {
2469		st.t.Fatalf("hpack decoding error: %v", err)
2470	}
2471	return st.decodedHeaders
2472}
2473
2474// testServerResponse sets up an idle HTTP/2 connection. The client function should
2475// write a single request that must be handled by the handler. This waits up to 5s
2476// for client to return, then up to an additional 2s for the handler to return.
2477func testServerResponse(t testing.TB,
2478	handler func(http.ResponseWriter, *http.Request) error,
2479	client func(*serverTester),
2480) {
2481	errc := make(chan error, 1)
2482	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2483		if r.Body == nil {
2484			t.Fatal("nil Body")
2485		}
2486		errc <- handler(w, r)
2487	})
2488	defer st.Close()
2489
2490	donec := make(chan bool)
2491	go func() {
2492		defer close(donec)
2493		st.greet()
2494		client(st)
2495	}()
2496
2497	select {
2498	case <-donec:
2499	case <-time.After(5 * time.Second):
2500		t.Fatal("timeout in client")
2501	}
2502
2503	select {
2504	case err := <-errc:
2505		if err != nil {
2506			t.Fatalf("Error in handler: %v", err)
2507		}
2508	case <-time.After(2 * time.Second):
2509		t.Fatal("timeout in handler")
2510	}
2511}
2512
// readBodyHandler returns an http handler func that reads exactly len(want)
// bytes from the request body. It reports a test error on t if the read fails
// or the bytes read differ from want. Any remaining body bytes are left unread.
func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		got := make([]byte, len(want))
		if _, err := io.ReadFull(r.Body, got); err != nil {
			t.Error(err)
			return
		}
		if string(got) != want {
			t.Errorf("read %q; want %q", got, want)
		}
	}
}
2529
2530// TestServerWithCurl currently fails, hence the LenientCipherSuites test. See:
2531//   https://github.com/tatsuhiro-t/nghttp2/issues/140 &
2532//   http://sourceforge.net/p/curl/bugs/1472/
2533func TestServerWithCurl(t *testing.T)                     { testServerWithCurl(t, false) }
2534func TestServerWithCurl_LenientCipherSuites(t *testing.T) { testServerWithCurl(t, true) }
2535
2536func testServerWithCurl(t *testing.T, permitProhibitedCipherSuites bool) {
2537	if runtime.GOOS != "linux" {
2538		t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway")
2539	}
2540	if testing.Short() {
2541		t.Skip("skipping curl test in short mode")
2542	}
2543	requireCurl(t)
2544	var gotConn int32
2545	testHookOnConn = func() { atomic.StoreInt32(&gotConn, 1) }
2546
2547	const msg = "Hello from curl!\n"
2548	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2549		w.Header().Set("Foo", "Bar")
2550		w.Header().Set("Client-Proto", r.Proto)
2551		io.WriteString(w, msg)
2552	}))
2553	ConfigureServer(ts.Config, &Server{
2554		PermitProhibitedCipherSuites: permitProhibitedCipherSuites,
2555	})
2556	ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
2557	ts.StartTLS()
2558	defer ts.Close()
2559
2560	t.Logf("Running test server for curl to hit at: %s", ts.URL)
2561	container := curl(t, "--silent", "--http2", "--insecure", "-v", ts.URL)
2562	defer kill(container)
2563	resc := make(chan interface{}, 1)
2564	go func() {
2565		res, err := dockerLogs(container)
2566		if err != nil {
2567			resc <- err
2568		} else {
2569			resc <- res
2570		}
2571	}()
2572	select {
2573	case res := <-resc:
2574		if err, ok := res.(error); ok {
2575			t.Fatal(err)
2576		}
2577		body := string(res.([]byte))
2578		// Search for both "key: value" and "key:value", since curl changed their format
2579		// Our Dockerfile contains the latest version (no space), but just in case people
2580		// didn't rebuild, check both.
2581		if !strings.Contains(body, "foo: Bar") && !strings.Contains(body, "foo:Bar") {
2582			t.Errorf("didn't see foo: Bar header")
2583			t.Logf("Got: %s", body)
2584		}
2585		if !strings.Contains(body, "client-proto: HTTP/2") && !strings.Contains(body, "client-proto:HTTP/2") {
2586			t.Errorf("didn't see client-proto: HTTP/2 header")
2587			t.Logf("Got: %s", res)
2588		}
2589		if !strings.Contains(string(res.([]byte)), msg) {
2590			t.Errorf("didn't see %q content", msg)
2591			t.Logf("Got: %s", res)
2592		}
2593	case <-time.After(3 * time.Second):
2594		t.Errorf("timeout waiting for curl")
2595	}
2596
2597	if atomic.LoadInt32(&gotConn) == 0 {
2598		t.Error("never saw an http2 connection")
2599	}
2600}
2601
2602var doh2load = flag.Bool("h2load", false, "Run h2load test")
2603
2604func TestServerWithH2Load(t *testing.T) {
2605	if !*doh2load {
2606		t.Skip("Skipping without --h2load flag.")
2607	}
2608	if runtime.GOOS != "linux" {
2609		t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway")
2610	}
2611	requireH2load(t)
2612
2613	msg := strings.Repeat("Hello, h2load!\n", 5000)
2614	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2615		io.WriteString(w, msg)
2616		w.(http.Flusher).Flush()
2617		io.WriteString(w, msg)
2618	}))
2619	ts.StartTLS()
2620	defer ts.Close()
2621
2622	cmd := exec.Command("docker", "run", "--net=host", "--entrypoint=/usr/local/bin/h2load", "gohttp2/curl",
2623		"-n100000", "-c100", "-m100", ts.URL)
2624	cmd.Stdout = os.Stdout
2625	cmd.Stderr = os.Stderr
2626	if err := cmd.Run(); err != nil {
2627		t.Fatal(err)
2628	}
2629}
2630
2631// Issue 12843
2632func TestServerDoS_MaxHeaderListSize(t *testing.T) {
2633	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
2634	defer st.Close()
2635
2636	// shake hands
2637	frameSize := defaultMaxReadFrameSize
2638	var advHeaderListSize *uint32
2639	st.greetAndCheckSettings(func(s Setting) error {
2640		switch s.ID {
2641		case SettingMaxFrameSize:
2642			if s.Val < minMaxFrameSize {
2643				frameSize = minMaxFrameSize
2644			} else if s.Val > maxFrameSize {
2645				frameSize = maxFrameSize
2646			} else {
2647				frameSize = int(s.Val)
2648			}
2649		case SettingMaxHeaderListSize:
2650			advHeaderListSize = &s.Val
2651		}
2652		return nil
2653	})
2654
2655	if advHeaderListSize == nil {
2656		t.Errorf("server didn't advertise a max header list size")
2657	} else if *advHeaderListSize == 0 {
2658		t.Errorf("server advertised a max header list size of 0")
2659	}
2660
2661	st.encodeHeaderField(":method", "GET")
2662	st.encodeHeaderField(":path", "/")
2663	st.encodeHeaderField(":scheme", "https")
2664	cookie := strings.Repeat("*", 4058)
2665	st.encodeHeaderField("cookie", cookie)
2666	st.writeHeaders(HeadersFrameParam{
2667		StreamID:      1,
2668		BlockFragment: st.headerBuf.Bytes(),
2669		EndStream:     true,
2670		EndHeaders:    false,
2671	})
2672
2673	// Capture the short encoding of a duplicate ~4K cookie, now
2674	// that we've already sent it once.
2675	st.headerBuf.Reset()
2676	st.encodeHeaderField("cookie", cookie)
2677
2678	// Now send 1MB of it.
2679	const size = 1 << 20
2680	b := bytes.Repeat(st.headerBuf.Bytes(), size/st.headerBuf.Len())
2681	for len(b) > 0 {
2682		chunk := b
2683		if len(chunk) > frameSize {
2684			chunk = chunk[:frameSize]
2685		}
2686		b = b[len(chunk):]
2687		st.fr.WriteContinuation(1, len(b) == 0, chunk)
2688	}
2689
2690	h := st.wantHeaders()
2691	if !h.HeadersEnded() {
2692		t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
2693	}
2694	headers := st.decodeHeader(h.HeaderBlockFragment())
2695	want := [][2]string{
2696		{":status", "431"},
2697		{"content-type", "text/html; charset=utf-8"},
2698		{"content-length", "63"},
2699	}
2700	if !reflect.DeepEqual(headers, want) {
2701		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
2702	}
2703}
2704
// TestCompressionErrorOnWrite checks the hpack max string length limit. A
// header value of exactly the framer's maxHeaderStringLen is still decoded:
// the request gets a 431 response and the compression context stays valid. A
// value one byte longer ends the connection with GOAWAY(COMPRESSION_ERROR).
func TestCompressionErrorOnWrite(t *testing.T) {
	const maxStrLen = 8 << 10
	var serverConfig *http.Server
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// No response body.
	}, func(ts *httptest.Server) {
		serverConfig = ts.Config
		serverConfig.MaxHeaderBytes = maxStrLen
	})
	st.addLogFilter("connection error: COMPRESSION_ERROR")
	defer st.Close()
	st.greet()

	// The largest header string the server's framer accepts on this connection.
	maxAllowed := st.sc.framer.maxHeaderStringLen()

	// Crank this up now that we have a connection whose hpack.Decoder max
	// string length was already initialized from the earlier low ~8K value.
	// We want this higher so we don't hit the max header list size; we only
	// want to test hitting the max string size.
	serverConfig.MaxHeaderBytes = 1 << 20

	// First a request with a header that's exactly the max allowed size
	// for the hpack compression. It's still too long for the header list
	// size, so we'll get the 431 error, but that keeps the compression
	// context still valid.
	hbf := st.encodeHeader("foo", strings.Repeat("a", maxAllowed))

	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: hbf,
		EndStream:     true,
		EndHeaders:    true,
	})
	h := st.wantHeaders()
	if !h.HeadersEnded() {
		t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
	}
	headers := st.decodeHeader(h.HeaderBlockFragment())
	want := [][2]string{
		{":status", "431"},
		{"content-type", "text/html; charset=utf-8"},
		{"content-length", "63"},
	}
	if !reflect.DeepEqual(headers, want) {
		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
	}
	df := st.wantData()
	if !strings.Contains(string(df.Data()), "HTTP Error 431") {
		t.Errorf("Unexpected data body: %q", df.Data())
	}
	if !df.StreamEnded() {
		t.Fatalf("expect data stream end")
	}

	// And now send one that's just one byte too big; this is a
	// connection-level error.
	hbf = st.encodeHeader("bar", strings.Repeat("b", maxAllowed+1))
	st.writeHeaders(HeadersFrameParam{
		StreamID:      3,
		BlockFragment: hbf,
		EndStream:     true,
		EndHeaders:    true,
	})
	ga := st.wantGoAway()
	if ga.ErrCode != ErrCodeCompression {
		t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
	}
}
2773
2774func TestCompressionErrorOnClose(t *testing.T) {
2775	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2776		// No response body.
2777	})
2778	st.addLogFilter("connection error: COMPRESSION_ERROR")
2779	defer st.Close()
2780	st.greet()
2781
2782	hbf := st.encodeHeader("foo", "bar")
2783	hbf = hbf[:len(hbf)-1] // truncate one byte from the end, so hpack.Decoder.Close fails.
2784	st.writeHeaders(HeadersFrameParam{
2785		StreamID:      1,
2786		BlockFragment: hbf,
2787		EndStream:     true,
2788		EndHeaders:    true,
2789	})
2790	ga := st.wantGoAway()
2791	if ga.ErrCode != ErrCodeCompression {
2792		t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
2793	}
2794}
2795
2796// test that a server handler can read trailers from a client
2797func TestServerReadsTrailers(t *testing.T) {
2798	const testBody = "some test body"
2799	writeReq := func(st *serverTester) {
2800		st.writeHeaders(HeadersFrameParam{
2801			StreamID:      1, // clients send odd numbers
2802			BlockFragment: st.encodeHeader("trailer", "Foo, Bar", "trailer", "Baz"),
2803			EndStream:     false,
2804			EndHeaders:    true,
2805		})
2806		st.writeData(1, false, []byte(testBody))
2807		st.writeHeaders(HeadersFrameParam{
2808			StreamID: 1, // clients send odd numbers
2809			BlockFragment: st.encodeHeaderRaw(
2810				"foo", "foov",
2811				"bar", "barv",
2812				"baz", "bazv",
2813				"surprise", "wasn't declared; shouldn't show up",
2814			),
2815			EndStream:  true,
2816			EndHeaders: true,
2817		})
2818	}
2819	checkReq := func(r *http.Request) {
2820		wantTrailer := http.Header{
2821			"Foo": nil,
2822			"Bar": nil,
2823			"Baz": nil,
2824		}
2825		if !reflect.DeepEqual(r.Trailer, wantTrailer) {
2826			t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer)
2827		}
2828		slurp, err := ioutil.ReadAll(r.Body)
2829		if string(slurp) != testBody {
2830			t.Errorf("read body %q; want %q", slurp, testBody)
2831		}
2832		if err != nil {
2833			t.Fatalf("Body slurp: %v", err)
2834		}
2835		wantTrailerAfter := http.Header{
2836			"Foo": {"foov"},
2837			"Bar": {"barv"},
2838			"Baz": {"bazv"},
2839		}
2840		if !reflect.DeepEqual(r.Trailer, wantTrailerAfter) {
2841			t.Errorf("final Trailer = %v; want %v", r.Trailer, wantTrailerAfter)
2842		}
2843	}
2844	testServerRequest(t, writeReq, checkReq)
2845}
2846
2847// test that a server handler can send trailers
2848func TestServerWritesTrailers_WithFlush(t *testing.T)    { testServerWritesTrailers(t, true) }
2849func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) }
2850
// testServerWritesTrailers checks that a handler can send trailers, flushing
// after the body when withFlush is set. Two kinds of keys must arrive in a
// final HEADERS frame with END_STREAM and END_HEADERS:
//   - keys declared in the "Trailer" header that are set after the body;
//   - keys set with the special "Trailer:" prefix.
// Undeclared keys, names with invalid bytes, Range, Transfer-Encoding,
// Content-Length and Trailer must not appear in the trailers.
func testServerWritesTrailers(t *testing.T, withFlush bool) {
	// See https://httpwg.github.io/specs/rfc7540.html#rfc.section.8.1.3
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
		w.Header().Add("Trailer", "Server-Trailer-C")
		w.Header().Add("Trailer", "Transfer-Encoding, Content-Length, Trailer") // filtered

		// Regular headers:
		w.Header().Set("Foo", "Bar")
		w.Header().Set("Content-Length", "5") // len("Hello")

		io.WriteString(w, "Hello")
		if withFlush {
			w.(http.Flusher).Flush()
		}
		w.Header().Set("Server-Trailer-A", "valuea")
		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
		// After a flush, random keys like Server-Surprise shouldn't show up
		// (the key below is spelled "Server-Surpise"; it is not predeclared
		// either way):
		w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!")
		// But we do permit promoting keys to trailers after a
		// flush if they start with the magic
		// otherwise-invalid "Trailer:" prefix:
		w.Header().Set("Trailer:Post-Header-Trailer", "hi1")
		w.Header().Set("Trailer:post-header-trailer2", "hi2")
		w.Header().Set("Trailer:Range", "invalid")
		w.Header().Set("Trailer:Foo\x01Bogus", "invalid")
		w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 7230 4.1.2")
		w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 7230 4.1.2")
		w.Header().Set("Trailer", "should not be included; Forbidden by RFC 7230 4.1.2")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		// Response HEADERS: must not end the stream, since DATA and trailers follow.
		hf := st.wantHeaders()
		if hf.StreamEnded() {
			t.Fatal("response HEADERS had END_STREAM")
		}
		if !hf.HeadersEnded() {
			t.Fatal("response HEADERS didn't have END_HEADERS")
		}
		goth := st.decodeHeader(hf.HeaderBlockFragment())
		wanth := [][2]string{
			{":status", "200"},
			{"foo", "Bar"},
			{"trailer", "Server-Trailer-A, Server-Trailer-B"},
			{"trailer", "Server-Trailer-C"},
			{"trailer", "Transfer-Encoding, Content-Length, Trailer"},
			{"content-type", "text/plain; charset=utf-8"},
			{"content-length", "5"},
		}
		if !reflect.DeepEqual(goth, wanth) {
			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
		}
		// Body DATA: must not end the stream either.
		df := st.wantData()
		if string(df.Data()) != "Hello" {
			t.Fatalf("Client read %q; want Hello", df.Data())
		}
		if df.StreamEnded() {
			t.Fatalf("data frame had STREAM_ENDED")
		}
		tf := st.wantHeaders() // for the trailers
		if !tf.StreamEnded() {
			t.Fatalf("trailers HEADERS lacked END_STREAM")
		}
		if !tf.HeadersEnded() {
			t.Fatalf("trailers HEADERS lacked END_HEADERS")
		}
		wanth = [][2]string{
			{"post-header-trailer", "hi1"},
			{"post-header-trailer2", "hi2"},
			{"server-trailer-a", "valuea"},
			{"server-trailer-c", "valuec"},
		}
		goth = st.decodeHeader(tf.HeaderBlockFragment())
		if !reflect.DeepEqual(goth, wanth) {
			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
		}
	})
}
2929
2930// validate transmitted header field names & values
2931// golang.org/issue/14048
2932func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
2933	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2934		w.Header().Add("OK1", "x")
2935		w.Header().Add("Bad:Colon", "x") // colon (non-token byte) in key
2936		w.Header().Add("Bad1\x00", "x")  // null in key
2937		w.Header().Add("Bad2", "x\x00y") // null in value
2938		return nil
2939	}, func(st *serverTester) {
2940		getSlash(st)
2941		hf := st.wantHeaders()
2942		if !hf.StreamEnded() {
2943			t.Error("response HEADERS lacked END_STREAM")
2944		}
2945		if !hf.HeadersEnded() {
2946			t.Fatal("response HEADERS didn't have END_HEADERS")
2947		}
2948		goth := st.decodeHeader(hf.HeaderBlockFragment())
2949		wanth := [][2]string{
2950			{":status", "200"},
2951			{"ok1", "x"},
2952			{"content-length", "0"},
2953		}
2954		if !reflect.DeepEqual(goth, wanth) {
2955			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
2956		}
2957	})
2958}
2959
2960func BenchmarkServerGets(b *testing.B) {
2961	defer disableGoroutineTracking()()
2962	b.ReportAllocs()
2963
2964	const msg = "Hello, world"
2965	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
2966		io.WriteString(w, msg)
2967	})
2968	defer st.Close()
2969	st.greet()
2970
2971	// Give the server quota to reply. (plus it has the 64KB)
2972	if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
2973		b.Fatal(err)
2974	}
2975
2976	for i := 0; i < b.N; i++ {
2977		id := 1 + uint32(i)*2
2978		st.writeHeaders(HeadersFrameParam{
2979			StreamID:      id,
2980			BlockFragment: st.encodeHeader(),
2981			EndStream:     true,
2982			EndHeaders:    true,
2983		})
2984		st.wantHeaders()
2985		df := st.wantData()
2986		if !df.StreamEnded() {
2987			b.Fatalf("DATA didn't have END_STREAM; got %v", df)
2988		}
2989	}
2990}
2991
2992func BenchmarkServerPosts(b *testing.B) {
2993	defer disableGoroutineTracking()()
2994	b.ReportAllocs()
2995
2996	const msg = "Hello, world"
2997	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
2998		// Consume the (empty) body from th peer before replying, otherwise
2999		// the server will sometimes (depending on scheduling) send the peer a
3000		// a RST_STREAM with the CANCEL error code.
3001		if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil {
3002			b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
3003		}
3004		io.WriteString(w, msg)
3005	})
3006	defer st.Close()
3007	st.greet()
3008
3009	// Give the server quota to reply. (plus it has the 64KB)
3010	if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3011		b.Fatal(err)
3012	}
3013
3014	for i := 0; i < b.N; i++ {
3015		id := 1 + uint32(i)*2
3016		st.writeHeaders(HeadersFrameParam{
3017			StreamID:      id,
3018			BlockFragment: st.encodeHeader(":method", "POST"),
3019			EndStream:     false,
3020			EndHeaders:    true,
3021		})
3022		st.writeData(id, true, nil)
3023		st.wantHeaders()
3024		df := st.wantData()
3025		if !df.StreamEnded() {
3026			b.Fatalf("DATA didn't have END_STREAM; got %v", df)
3027		}
3028	}
3029}
3030
3031// Send a stream of messages from server to client in separate data frames.
3032// Brings up performance issues seen in long streams.
3033// Created to show problem in go issue #18502
3034func BenchmarkServerToClientStreamDefaultOptions(b *testing.B) {
3035	benchmarkServerToClientStream(b)
3036}
3037
3038// Justification for Change-Id: Iad93420ef6c3918f54249d867098f1dadfa324d8
3039// Expect to see memory/alloc reduction by opting in to Frame reuse with the Framer.
3040func BenchmarkServerToClientStreamReuseFrames(b *testing.B) {
3041	benchmarkServerToClientStream(b, optFramerReuseFrames)
3042}
3043
3044func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
3045	defer disableGoroutineTracking()()
3046	b.ReportAllocs()
3047	const msgLen = 1
3048	// default window size
3049	const windowSize = 1<<16 - 1
3050
3051	// next message to send from the server and for the client to expect
3052	nextMsg := func(i int) []byte {
3053		msg := make([]byte, msgLen)
3054		msg[0] = byte(i)
3055		if len(msg) != msgLen {
3056			panic("invalid test setup msg length")
3057		}
3058		return msg
3059	}
3060
3061	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3062		// Consume the (empty) body from th peer before replying, otherwise
3063		// the server will sometimes (depending on scheduling) send the peer a
3064		// a RST_STREAM with the CANCEL error code.
3065		if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil {
3066			b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
3067		}
3068		for i := 0; i < b.N; i += 1 {
3069			w.Write(nextMsg(i))
3070			w.(http.Flusher).Flush()
3071		}
3072	}, newServerOpts...)
3073	defer st.Close()
3074	st.greet()
3075
3076	const id = uint32(1)
3077
3078	st.writeHeaders(HeadersFrameParam{
3079		StreamID:      id,
3080		BlockFragment: st.encodeHeader(":method", "POST"),
3081		EndStream:     false,
3082		EndHeaders:    true,
3083	})
3084
3085	st.writeData(id, true, nil)
3086	st.wantHeaders()
3087
3088	var pendingWindowUpdate = uint32(0)
3089
3090	for i := 0; i < b.N; i += 1 {
3091		expected := nextMsg(i)
3092		df := st.wantData()
3093		if bytes.Compare(expected, df.data) != 0 {
3094			b.Fatalf("Bad message received; want %v; got %v", expected, df.data)
3095		}
3096		// try to send infrequent but large window updates so they don't overwhelm the test
3097		pendingWindowUpdate += uint32(len(df.data))
3098		if pendingWindowUpdate >= windowSize/2 {
3099			if err := st.fr.WriteWindowUpdate(0, pendingWindowUpdate); err != nil {
3100				b.Fatal(err)
3101			}
3102			if err := st.fr.WriteWindowUpdate(id, pendingWindowUpdate); err != nil {
3103				b.Fatal(err)
3104			}
3105			pendingWindowUpdate = 0
3106		}
3107	}
3108	df := st.wantData()
3109	if !df.StreamEnded() {
3110		b.Fatalf("DATA didn't have END_STREAM; got %v", df)
3111	}
3112}
3113
3114// go-fuzz bug, originally reported at https://github.com/bradfitz/http2/issues/53
3115// Verify we don't hang.
3116func TestIssue53(t *testing.T) {
3117	const data = "PRI * HTTP/2.0\r\n\r\nSM" +
3118		"\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad"
3119	s := &http.Server{
3120		ErrorLog: log.New(io.MultiWriter(stderrv(), twriter{t: t}), "", log.LstdFlags),
3121		Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
3122			w.Write([]byte("hello"))
3123		}),
3124	}
3125	s2 := &Server{
3126		MaxReadFrameSize:             1 << 16,
3127		PermitProhibitedCipherSuites: true,
3128	}
3129	c := &issue53Conn{[]byte(data), false, false}
3130	s2.ServeConn(c, &ServeConnOpts{BaseConfig: s})
3131	if !c.closed {
3132		t.Fatal("connection is not closed")
3133	}
3134}
3135
// issue53Conn is an in-memory net.Conn fed with canned input. Reads drain
// data and then report io.EOF. Writes are discarded but flagged in written.
// Close only sets closed. Deadlines are accepted and ignored.
type issue53Conn struct {
	data    []byte
	closed  bool
	written bool
}

// Read copies as much of the remaining canned data into b as fits, and
// returns io.EOF once nothing is left.
func (c *issue53Conn) Read(b []byte) (n int, err error) {
	if len(c.data) > 0 {
		n = copy(b, c.data)
		c.data = c.data[n:]
		return n, nil
	}
	return 0, io.EOF
}

// Write records that something was written and pretends all of b was sent.
func (c *issue53Conn) Write(b []byte) (n int, err error) {
	c.written = true
	n = len(b)
	return n, nil
}

// Close marks the connection as closed; it never fails.
func (c *issue53Conn) Close() error {
	c.closed = true
	return nil
}

// LocalAddr reports a fixed loopback address.
func (c *issue53Conn) LocalAddr() net.Addr {
	addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706}
	return addr
}

// RemoteAddr reports the same fixed loopback address as LocalAddr.
func (c *issue53Conn) RemoteAddr() net.Addr {
	addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706}
	return addr
}

// SetDeadline is a no-op.
func (c *issue53Conn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (c *issue53Conn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (c *issue53Conn) SetWriteDeadline(time.Time) error {
	return nil
}
3170
3171// golang.org/issue/12895
3172func TestConfigureServer(t *testing.T) {
3173	tests := []struct {
3174		name      string
3175		tlsConfig *tls.Config
3176		wantErr   string
3177	}{
3178		{
3179			name: "empty server",
3180		},
3181		{
3182			name: "just the required cipher suite",
3183			tlsConfig: &tls.Config{
3184				CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
3185			},
3186		},
3187		{
3188			name: "just the alternative required cipher suite",
3189			tlsConfig: &tls.Config{
3190				CipherSuites: []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
3191			},
3192		},
3193		{
3194			name: "missing required cipher suite",
3195			tlsConfig: &tls.Config{
3196				CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384},
3197			},
3198			wantErr: "is missing an HTTP/2-required AES_128_GCM_SHA256 cipher.",
3199		},
3200		{
3201			name: "required after bad",
3202			tlsConfig: &tls.Config{
3203				CipherSuites: []uint16{tls.TLS_RSA_WITH_RC4_128_SHA, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
3204			},
3205			wantErr: "contains an HTTP/2-approved cipher suite (0xc02f), but it comes after",
3206		},
3207		{
3208			name: "bad after required",
3209			tlsConfig: &tls.Config{
3210				CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_RC4_128_SHA},
3211			},
3212		},
3213	}
3214	for _, tt := range tests {
3215		srv := &http.Server{TLSConfig: tt.tlsConfig}
3216		err := ConfigureServer(srv, nil)
3217		if (err != nil) != (tt.wantErr != "") {
3218			if tt.wantErr != "" {
3219				t.Errorf("%s: success, but want error", tt.name)
3220			} else {
3221				t.Errorf("%s: unexpected error: %v", tt.name, err)
3222			}
3223		}
3224		if err != nil && tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr) {
3225			t.Errorf("%s: err = %v; want substring %q", tt.name, err, tt.wantErr)
3226		}
3227		if err == nil && !srv.TLSConfig.PreferServerCipherSuites {
3228			t.Errorf("%s: PreferServerCipherSuite is false; want true", tt.name)
3229		}
3230	}
3231}
3232
3233func TestServerRejectHeadWithBody(t *testing.T) {
3234	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3235		// No response body.
3236	})
3237	defer st.Close()
3238	st.greet()
3239	st.writeHeaders(HeadersFrameParam{
3240		StreamID:      1, // clients send odd numbers
3241		BlockFragment: st.encodeHeader(":method", "HEAD"),
3242		EndStream:     false, // what we're testing, a bogus HEAD request with body
3243		EndHeaders:    true,
3244	})
3245	st.wantRSTStream(1, ErrCodeProtocol)
3246}
3247
3248func TestServerNoAutoContentLengthOnHead(t *testing.T) {
3249	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3250		// No response body. (or smaller than one frame)
3251	})
3252	defer st.Close()
3253	st.greet()
3254	st.writeHeaders(HeadersFrameParam{
3255		StreamID:      1, // clients send odd numbers
3256		BlockFragment: st.encodeHeader(":method", "HEAD"),
3257		EndStream:     true,
3258		EndHeaders:    true,
3259	})
3260	h := st.wantHeaders()
3261	headers := st.decodeHeader(h.HeaderBlockFragment())
3262	want := [][2]string{
3263		{":status", "200"},
3264	}
3265	if !reflect.DeepEqual(headers, want) {
3266		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
3267	}
3268}
3269
3270// golang.org/issue/13495
3271func TestServerNoDuplicateContentType(t *testing.T) {
3272	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3273		w.Header()["Content-Type"] = []string{""}
3274		fmt.Fprintf(w, "<html><head></head><body>hi</body></html>")
3275	})
3276	defer st.Close()
3277	st.greet()
3278	st.writeHeaders(HeadersFrameParam{
3279		StreamID:      1,
3280		BlockFragment: st.encodeHeader(),
3281		EndStream:     true,
3282		EndHeaders:    true,
3283	})
3284	h := st.wantHeaders()
3285	headers := st.decodeHeader(h.HeaderBlockFragment())
3286	want := [][2]string{
3287		{":status", "200"},
3288		{"content-type", ""},
3289		{"content-length", "41"},
3290	}
3291	if !reflect.DeepEqual(headers, want) {
3292		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
3293	}
3294}
3295
3296func disableGoroutineTracking() (restore func()) {
3297	old := DebugGoroutines
3298	DebugGoroutines = false
3299	return func() { DebugGoroutines = old }
3300}
3301
3302func BenchmarkServer_GetRequest(b *testing.B) {
3303	defer disableGoroutineTracking()()
3304	b.ReportAllocs()
3305	const msg = "Hello, world."
3306	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3307		n, err := io.Copy(ioutil.Discard, r.Body)
3308		if err != nil || n > 0 {
3309			b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
3310		}
3311		io.WriteString(w, msg)
3312	})
3313	defer st.Close()
3314
3315	st.greet()
3316	// Give the server quota to reply. (plus it has the 64KB)
3317	if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3318		b.Fatal(err)
3319	}
3320	hbf := st.encodeHeader(":method", "GET")
3321	for i := 0; i < b.N; i++ {
3322		streamID := uint32(1 + 2*i)
3323		st.writeHeaders(HeadersFrameParam{
3324			StreamID:      streamID,
3325			BlockFragment: hbf,
3326			EndStream:     true,
3327			EndHeaders:    true,
3328		})
3329		st.wantHeaders()
3330		st.wantData()
3331	}
3332}
3333
3334func BenchmarkServer_PostRequest(b *testing.B) {
3335	defer disableGoroutineTracking()()
3336	b.ReportAllocs()
3337	const msg = "Hello, world."
3338	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3339		n, err := io.Copy(ioutil.Discard, r.Body)
3340		if err != nil || n > 0 {
3341			b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
3342		}
3343		io.WriteString(w, msg)
3344	})
3345	defer st.Close()
3346	st.greet()
3347	// Give the server quota to reply. (plus it has the 64KB)
3348	if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3349		b.Fatal(err)
3350	}
3351	hbf := st.encodeHeader(":method", "POST")
3352	for i := 0; i < b.N; i++ {
3353		streamID := uint32(1 + 2*i)
3354		st.writeHeaders(HeadersFrameParam{
3355			StreamID:      streamID,
3356			BlockFragment: hbf,
3357			EndStream:     false,
3358			EndHeaders:    true,
3359		})
3360		st.writeData(streamID, true, nil)
3361		st.wantHeaders()
3362		st.wantData()
3363	}
3364}
3365
// connStateConn wraps an arbitrary net.Conn and reports a caller-supplied
// tls.ConnectionState. This lets a test hand the server a connection that
// is not a *tls.Conn but still exposes a ConnectionState method.
type connStateConn struct {
	net.Conn
	cs tls.ConnectionState
}

// ConnectionState returns the canned TLS state stored in cs.
func (c connStateConn) ConnectionState() tls.ConnectionState {
	state := c.cs
	return state
}
3372
3373// golang.org/issue/12737 -- handle any net.Conn, not just
3374// *tls.Conn.
3375func TestServerHandleCustomConn(t *testing.T) {
3376	var s Server
3377	c1, c2 := net.Pipe()
3378	clientDone := make(chan struct{})
3379	handlerDone := make(chan struct{})
3380	var req *http.Request
3381	go func() {
3382		defer close(clientDone)
3383		defer c2.Close()
3384		fr := NewFramer(c2, c2)
3385		io.WriteString(c2, ClientPreface)
3386		fr.WriteSettings()
3387		fr.WriteSettingsAck()
3388		f, err := fr.ReadFrame()
3389		if err != nil {
3390			t.Error(err)
3391			return
3392		}
3393		if sf, ok := f.(*SettingsFrame); !ok || sf.IsAck() {
3394			t.Errorf("Got %v; want non-ACK SettingsFrame", summarizeFrame(f))
3395			return
3396		}
3397		f, err = fr.ReadFrame()
3398		if err != nil {
3399			t.Error(err)
3400			return
3401		}
3402		if sf, ok := f.(*SettingsFrame); !ok || !sf.IsAck() {
3403			t.Errorf("Got %v; want ACK SettingsFrame", summarizeFrame(f))
3404			return
3405		}
3406		var henc hpackEncoder
3407		fr.WriteHeaders(HeadersFrameParam{
3408			StreamID:      1,
3409			BlockFragment: henc.encodeHeaderRaw(t, ":method", "GET", ":path", "/", ":scheme", "https", ":authority", "foo.com"),
3410			EndStream:     true,
3411			EndHeaders:    true,
3412		})
3413		go io.Copy(ioutil.Discard, c2)
3414		<-handlerDone
3415	}()
3416	const testString = "my custom ConnectionState"
3417	fakeConnState := tls.ConnectionState{
3418		ServerName:  testString,
3419		Version:     tls.VersionTLS12,
3420		CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
3421	}
3422	go s.ServeConn(connStateConn{c1, fakeConnState}, &ServeConnOpts{
3423		BaseConfig: &http.Server{
3424			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3425				defer close(handlerDone)
3426				req = r
3427			}),
3428		}})
3429	select {
3430	case <-clientDone:
3431	case <-time.After(5 * time.Second):
3432		t.Fatal("timeout waiting for handler")
3433	}
3434	if req.TLS == nil {
3435		t.Fatalf("Request.TLS is nil. Got: %#v", req)
3436	}
3437	if req.TLS.ServerName != testString {
3438		t.Fatalf("Request.TLS = %+v; want ServerName of %q", req.TLS, testString)
3439	}
3440}
3441
3442// golang.org/issue/14214
3443func TestServer_Rejects_ConnHeaders(t *testing.T) {
3444	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3445		t.Error("should not get to Handler")
3446	})
3447	defer st.Close()
3448	st.greet()
3449	st.bodylessReq1("connection", "foo")
3450	hf := st.wantHeaders()
3451	goth := st.decodeHeader(hf.HeaderBlockFragment())
3452	wanth := [][2]string{
3453		{":status", "400"},
3454		{"content-type", "text/plain; charset=utf-8"},
3455		{"x-content-type-options", "nosniff"},
3456		{"content-length", "51"},
3457	}
3458	if !reflect.DeepEqual(goth, wanth) {
3459		t.Errorf("Got headers %v; want %v", goth, wanth)
3460	}
3461}
3462
3463type hpackEncoder struct {
3464	enc *hpack.Encoder
3465	buf bytes.Buffer
3466}
3467
3468func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte {
3469	if len(headers)%2 == 1 {
3470		panic("odd number of kv args")
3471	}
3472	he.buf.Reset()
3473	if he.enc == nil {
3474		he.enc = hpack.NewEncoder(&he.buf)
3475	}
3476	for len(headers) > 0 {
3477		k, v := headers[0], headers[1]
3478		err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v})
3479		if err != nil {
3480			t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
3481		}
3482		headers = headers[2:]
3483	}
3484	return he.buf.Bytes()
3485}
3486
3487func TestCheckValidHTTP2Request(t *testing.T) {
3488	tests := []struct {
3489		h    http.Header
3490		want error
3491	}{
3492		{
3493			h:    http.Header{"Te": {"trailers"}},
3494			want: nil,
3495		},
3496		{
3497			h:    http.Header{"Te": {"trailers", "bogus"}},
3498			want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`),
3499		},
3500		{
3501			h:    http.Header{"Foo": {""}},
3502			want: nil,
3503		},
3504		{
3505			h:    http.Header{"Connection": {""}},
3506			want: errors.New(`request header "Connection" is not valid in HTTP/2`),
3507		},
3508		{
3509			h:    http.Header{"Proxy-Connection": {""}},
3510			want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`),
3511		},
3512		{
3513			h:    http.Header{"Keep-Alive": {""}},
3514			want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`),
3515		},
3516		{
3517			h:    http.Header{"Upgrade": {""}},
3518			want: errors.New(`request header "Upgrade" is not valid in HTTP/2`),
3519		},
3520	}
3521	for i, tt := range tests {
3522		got := checkValidHTTP2RequestHeaders(tt.h)
3523		if !reflect.DeepEqual(got, tt.want) {
3524			t.Errorf("%d. checkValidHTTP2Request = %v; want %v", i, got, tt.want)
3525		}
3526	}
3527}
3528
3529// golang.org/issue/14030
3530func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
3531	const msg = "Hello"
3532	const msg2 = "World"
3533
3534	doRead := make(chan bool, 1)
3535	defer close(doRead) // fallback cleanup
3536
3537	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3538		io.WriteString(w, msg)
3539		w.(http.Flusher).Flush()
3540
3541		// Do a read, which might force a 100-continue status to be sent.
3542		<-doRead
3543		r.Body.Read(make([]byte, 10))
3544
3545		io.WriteString(w, msg2)
3546
3547	}, optOnlyServer)
3548	defer st.Close()
3549
3550	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3551	defer tr.CloseIdleConnections()
3552
3553	req, _ := http.NewRequest("POST", st.ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
3554	req.Header.Set("Expect", "100-continue")
3555
3556	res, err := tr.RoundTrip(req)
3557	if err != nil {
3558		t.Fatal(err)
3559	}
3560	defer res.Body.Close()
3561
3562	buf := make([]byte, len(msg))
3563	if _, err := io.ReadFull(res.Body, buf); err != nil {
3564		t.Fatal(err)
3565	}
3566	if string(buf) != msg {
3567		t.Fatalf("msg = %q; want %q", buf, msg)
3568	}
3569
3570	doRead <- true
3571
3572	if _, err := io.ReadFull(res.Body, buf); err != nil {
3573		t.Fatal(err)
3574	}
3575	if string(buf) != msg2 {
3576		t.Fatalf("second msg = %q; want %q", buf, msg2)
3577	}
3578}
3579
// funcReader adapts a plain function to the io.Reader interface, in the
// same way http.HandlerFunc adapts a function to http.Handler.
type funcReader func([]byte) (n int, err error)

// Read calls f(p) and returns its results unchanged.
func (f funcReader) Read(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
3583
3584// golang.org/issue/16481 -- return flow control when streams close with unread data.
3585// (The Server version of the bug. See also TestUnreadFlowControlReturned_Transport)
3586func TestUnreadFlowControlReturned_Server(t *testing.T) {
3587	unblock := make(chan bool, 1)
3588	defer close(unblock)
3589
3590	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3591		// Don't read the 16KB request body. Wait until the client's
3592		// done sending it and then return. This should cause the Server
3593		// to then return those 16KB of flow control to the client.
3594		<-unblock
3595	}, optOnlyServer)
3596	defer st.Close()
3597
3598	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3599	defer tr.CloseIdleConnections()
3600
3601	// This previously hung on the 4th iteration.
3602	for i := 0; i < 6; i++ {
3603		body := io.MultiReader(
3604			io.LimitReader(neverEnding('A'), 16<<10),
3605			funcReader(func([]byte) (n int, err error) {
3606				unblock <- true
3607				return 0, io.EOF
3608			}),
3609		)
3610		req, _ := http.NewRequest("POST", st.ts.URL, body)
3611		res, err := tr.RoundTrip(req)
3612		if err != nil {
3613			t.Fatal(err)
3614		}
3615		res.Body.Close()
3616	}
3617
3618}
3619
3620func TestServerIdleTimeout(t *testing.T) {
3621	if testing.Short() {
3622		t.Skip("skipping in short mode")
3623	}
3624
3625	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3626	}, func(h2s *Server) {
3627		h2s.IdleTimeout = 500 * time.Millisecond
3628	})
3629	defer st.Close()
3630
3631	st.greet()
3632	ga := st.wantGoAway()
3633	if ga.ErrCode != ErrCodeNo {
3634		t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
3635	}
3636}
3637
3638func TestServerIdleTimeout_AfterRequest(t *testing.T) {
3639	if testing.Short() {
3640		t.Skip("skipping in short mode")
3641	}
3642	const timeout = 250 * time.Millisecond
3643
3644	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3645		time.Sleep(timeout * 2)
3646	}, func(h2s *Server) {
3647		h2s.IdleTimeout = timeout
3648	})
3649	defer st.Close()
3650
3651	st.greet()
3652
3653	// Send a request which takes twice the timeout. Verifies the
3654	// idle timeout doesn't fire while we're in a request:
3655	st.bodylessReq1()
3656	st.wantHeaders()
3657
3658	// But the idle timeout should be rearmed after the request
3659	// is done:
3660	ga := st.wantGoAway()
3661	if ga.ErrCode != ErrCodeNo {
3662		t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
3663	}
3664}
3665
3666// grpc-go closes the Request.Body currently with a Read.
3667// Verify that it doesn't race.
3668// See https://github.com/grpc/grpc-go/pull/938
3669func TestRequestBodyReadCloseRace(t *testing.T) {
3670	for i := 0; i < 100; i++ {
3671		body := &requestBody{
3672			pipe: &pipe{
3673				b: new(bytes.Buffer),
3674			},
3675		}
3676		body.pipe.CloseWithError(io.EOF)
3677
3678		done := make(chan bool, 1)
3679		buf := make([]byte, 10)
3680		go func() {
3681			time.Sleep(1 * time.Millisecond)
3682			body.Close()
3683			done <- true
3684		}()
3685		body.Read(buf)
3686		<-done
3687	}
3688}
3689
3690func TestIssue20704Race(t *testing.T) {
3691	if testing.Short() && os.Getenv("GO_BUILDER_NAME") == "" {
3692		t.Skip("skipping in short mode")
3693	}
3694	const (
3695		itemSize  = 1 << 10
3696		itemCount = 100
3697	)
3698
3699	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3700		for i := 0; i < itemCount; i++ {
3701			_, err := w.Write(make([]byte, itemSize))
3702			if err != nil {
3703				return
3704			}
3705		}
3706	}, optOnlyServer)
3707	defer st.Close()
3708
3709	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3710	defer tr.CloseIdleConnections()
3711	cl := &http.Client{Transport: tr}
3712
3713	for i := 0; i < 1000; i++ {
3714		resp, err := cl.Get(st.ts.URL)
3715		if err != nil {
3716			t.Fatal(err)
3717		}
3718		// Force a RST stream to the server by closing without
3719		// reading the body:
3720		resp.Body.Close()
3721	}
3722}
3723
3724func TestServer_Rejects_TooSmall(t *testing.T) {
3725	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
3726		ioutil.ReadAll(r.Body)
3727		return nil
3728	}, func(st *serverTester) {
3729		st.writeHeaders(HeadersFrameParam{
3730			StreamID: 1, // clients send odd numbers
3731			BlockFragment: st.encodeHeader(
3732				":method", "POST",
3733				"content-length", "4",
3734			),
3735			EndStream:  false, // to say DATA frames are coming
3736			EndHeaders: true,
3737		})
3738		st.writeData(1, true, []byte("12345"))
3739
3740		st.wantRSTStream(1, ErrCodeProtocol)
3741	})
3742}
3743
3744// Tests that a handler setting "Connection: close" results in a GOAWAY being sent,
3745// and the connection still completing.
3746func TestServerHandlerConnectionClose(t *testing.T) {
3747	unblockHandler := make(chan bool, 1)
3748	defer close(unblockHandler) // backup; in case of errors
3749	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
3750		w.Header().Set("Connection", "close")
3751		w.Header().Set("Foo", "bar")
3752		w.(http.Flusher).Flush()
3753		<-unblockHandler
3754		return nil
3755	}, func(st *serverTester) {
3756		st.writeHeaders(HeadersFrameParam{
3757			StreamID:      1,
3758			BlockFragment: st.encodeHeader(),
3759			EndStream:     true,
3760			EndHeaders:    true,
3761		})
3762		var sawGoAway bool
3763		var sawRes bool
3764		for {
3765			f, err := st.readFrame()
3766			if err == io.EOF {
3767				break
3768			}
3769			if err != nil {
3770				t.Fatal(err)
3771			}
3772			switch f := f.(type) {
3773			case *GoAwayFrame:
3774				sawGoAway = true
3775				unblockHandler <- true
3776				if f.LastStreamID != 1 || f.ErrCode != ErrCodeNo {
3777					t.Errorf("unexpected GOAWAY frame: %v", summarizeFrame(f))
3778				}
3779			case *HeadersFrame:
3780				goth := st.decodeHeader(f.HeaderBlockFragment())
3781				wanth := [][2]string{
3782					{":status", "200"},
3783					{"foo", "bar"},
3784				}
3785				if !reflect.DeepEqual(goth, wanth) {
3786					t.Errorf("got headers %v; want %v", goth, wanth)
3787				}
3788				sawRes = true
3789			case *DataFrame:
3790				if f.StreamID != 1 || !f.StreamEnded() || len(f.Data()) != 0 {
3791					t.Errorf("unexpected DATA frame: %v", summarizeFrame(f))
3792				}
3793			default:
3794				t.Logf("unexpected frame: %v", summarizeFrame(f))
3795			}
3796		}
3797		if !sawGoAway {
3798			t.Errorf("didn't see GOAWAY")
3799		}
3800		if !sawRes {
3801			t.Errorf("didn't see response")
3802		}
3803	})
3804}
3805
3806func TestServer_Headers_HalfCloseRemote(t *testing.T) {
3807	var st *serverTester
3808	writeData := make(chan bool)
3809	writeHeaders := make(chan bool)
3810	leaveHandler := make(chan bool)
3811	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3812		if st.stream(1) == nil {
3813			t.Errorf("nil stream 1 in handler")
3814		}
3815		if got, want := st.streamState(1), stateOpen; got != want {
3816			t.Errorf("in handler, state is %v; want %v", got, want)
3817		}
3818		writeData <- true
3819		if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
3820			t.Errorf("body read = %d, %v; want 0, EOF", n, err)
3821		}
3822		if got, want := st.streamState(1), stateHalfClosedRemote; got != want {
3823			t.Errorf("in handler, state is %v; want %v", got, want)
3824		}
3825		writeHeaders <- true
3826
3827		<-leaveHandler
3828	})
3829	st.greet()
3830
3831	st.writeHeaders(HeadersFrameParam{
3832		StreamID:      1,
3833		BlockFragment: st.encodeHeader(),
3834		EndStream:     false, // keep it open
3835		EndHeaders:    true,
3836	})
3837	<-writeData
3838	st.writeData(1, true, nil)
3839
3840	<-writeHeaders
3841
3842	st.writeHeaders(HeadersFrameParam{
3843		StreamID:      1,
3844		BlockFragment: st.encodeHeader(),
3845		EndStream:     false, // keep it open
3846		EndHeaders:    true,
3847	})
3848
3849	defer close(leaveHandler)
3850
3851	st.wantRSTStream(1, ErrCodeStreamClosed)
3852}
3853