1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package http2
6
7import (
8	"bytes"
9	"crypto/tls"
10	"errors"
11	"flag"
12	"fmt"
13	"io"
14	"io/ioutil"
15	"log"
16	"net"
17	"net/http"
18	"net/http/httptest"
19	"os"
20	"os/exec"
21	"reflect"
22	"runtime"
23	"strconv"
24	"strings"
25	"sync"
26	"sync/atomic"
27	"testing"
28	"time"
29
30	"golang.org/x/net/http2/hpack"
31)
32
// stderrVerbose, when set via -stderr_verbose, mirrors verbose test output
// to stderr without buffering.
var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered")

// stderrv returns os.Stderr if -stderr_verbose was given, and a writer that
// discards all output otherwise.
func stderrv() io.Writer {
	if !*stderrVerbose {
		return ioutil.Discard
	}
	return os.Stderr
}
42
// serverTester drives an HTTP/2 server under test from the client side.
// newServerTester fills it with an httptest.Server (ts) running the test
// handler. Unless optOnlyServer is given, it also fills in a TLS client
// connection (cc) to that server and a Framer (fr) that reads and writes
// frames on cc. sc is filled in by testHookGetServerConn while holding scMu.
// hpackEnc writes encoded request headers into headerBuf. hpackDec feeds
// decoded fields to st.onHeaderField; presumably that method collects them
// into decodedHeaders (TODO confirm; it is defined outside this view).
type serverTester struct {
	cc             net.Conn // client conn
	t              testing.TB
	ts             *httptest.Server
	fr             *Framer
	serverLogBuf   bytes.Buffer // logger for httptest.Server
	logFilter      []string     // substrings to filter out
	scMu           sync.Mutex   // guards sc
	sc             *serverConn
	hpackDec       *hpack.Decoder
	decodedHeaders [][2]string

	// If http2debug!=2, then we capture Frame debug logs that will be written
	// to t.Log after a test fails. The read and write logs use separate locks
	// and buffers so we don't accidentally introduce synchronization between
	// the read and write goroutines, which may hide data races.
	frameReadLogMu   sync.Mutex
	frameReadLogBuf  bytes.Buffer
	frameWriteLogMu  sync.Mutex
	frameWriteLogBuf bytes.Buffer

	// writing headers:
	headerBuf bytes.Buffer
	hpackEnc  *hpack.Encoder
}
68
69func init() {
70	testHookOnPanicMu = new(sync.Mutex)
71	goAwayTimeout = 25 * time.Millisecond
72}
73
74func resetHooks() {
75	testHookOnPanicMu.Lock()
76	testHookOnPanic = nil
77	testHookOnPanicMu.Unlock()
78}
79
// serverTesterOpt is a flag-style option accepted by newServerTester.
type serverTesterOpt string

var (
	// optOnlyServer starts the server without dialing a client connection.
	optOnlyServer = serverTesterOpt("only_server")
	// optQuiet discards the server's error log output.
	optQuiet = serverTesterOpt("quiet_logging")
	// optFramerReuseFrames calls SetReuseFrames on the client Framer.
	optFramerReuseFrames = serverTesterOpt("frame_reuse_frames")
)
85
86func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
87	resetHooks()
88
89	ts := httptest.NewUnstartedServer(handler)
90
91	tlsConfig := &tls.Config{
92		InsecureSkipVerify: true,
93		NextProtos:         []string{NextProtoTLS},
94	}
95
96	var onlyServer, quiet, framerReuseFrames bool
97	h2server := new(Server)
98	for _, opt := range opts {
99		switch v := opt.(type) {
100		case func(*tls.Config):
101			v(tlsConfig)
102		case func(*httptest.Server):
103			v(ts)
104		case func(*Server):
105			v(h2server)
106		case serverTesterOpt:
107			switch v {
108			case optOnlyServer:
109				onlyServer = true
110			case optQuiet:
111				quiet = true
112			case optFramerReuseFrames:
113				framerReuseFrames = true
114			}
115		case func(net.Conn, http.ConnState):
116			ts.Config.ConnState = v
117		default:
118			t.Fatalf("unknown newServerTester option type %T", v)
119		}
120	}
121
122	ConfigureServer(ts.Config, h2server)
123
124	st := &serverTester{
125		t:  t,
126		ts: ts,
127	}
128	st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
129	st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
130
131	ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
132	if quiet {
133		ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
134	} else {
135		ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
136	}
137	ts.StartTLS()
138
139	if VerboseLogs {
140		t.Logf("Running test server at: %s", ts.URL)
141	}
142	testHookGetServerConn = func(v *serverConn) {
143		st.scMu.Lock()
144		defer st.scMu.Unlock()
145		st.sc = v
146	}
147	log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st}))
148	if !onlyServer {
149		cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
150		if err != nil {
151			t.Fatal(err)
152		}
153		st.cc = cc
154		st.fr = NewFramer(cc, cc)
155		if framerReuseFrames {
156			st.fr.SetReuseFrames()
157		}
158		if !logFrameReads && !logFrameWrites {
159			st.fr.debugReadLoggerf = func(m string, v ...interface{}) {
160				m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
161				st.frameReadLogMu.Lock()
162				fmt.Fprintf(&st.frameReadLogBuf, m, v...)
163				st.frameReadLogMu.Unlock()
164			}
165			st.fr.debugWriteLoggerf = func(m string, v ...interface{}) {
166				m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
167				st.frameWriteLogMu.Lock()
168				fmt.Fprintf(&st.frameWriteLogBuf, m, v...)
169				st.frameWriteLogMu.Unlock()
170			}
171			st.fr.logReads = true
172			st.fr.logWrites = true
173		}
174	}
175	return st
176}
177
178func (st *serverTester) closeConn() {
179	st.scMu.Lock()
180	defer st.scMu.Unlock()
181	st.sc.conn.Close()
182}
183
184func (st *serverTester) addLogFilter(phrase string) {
185	st.logFilter = append(st.logFilter, phrase)
186}
187
188func (st *serverTester) stream(id uint32) *stream {
189	ch := make(chan *stream, 1)
190	st.sc.serveMsgCh <- func(int) {
191		ch <- st.sc.streams[id]
192	}
193	return <-ch
194}
195
196func (st *serverTester) streamState(id uint32) streamState {
197	ch := make(chan streamState, 1)
198	st.sc.serveMsgCh <- func(int) {
199		state, _ := st.sc.state(id)
200		ch <- state
201	}
202	return <-ch
203}
204
205// loopNum reports how many times this conn's select loop has gone around.
206func (st *serverTester) loopNum() int {
207	lastc := make(chan int, 1)
208	st.sc.serveMsgCh <- func(loopNum int) {
209		lastc <- loopNum
210	}
211	return <-lastc
212}
213
214// awaitIdle heuristically awaits for the server conn's select loop to be idle.
215// The heuristic is that the server connection's serve loop must schedule
216// 50 times in a row without any channel sends or receives occurring.
217func (st *serverTester) awaitIdle() {
218	remain := 50
219	last := st.loopNum()
220	for remain > 0 {
221		n := st.loopNum()
222		if n == last+1 {
223			remain--
224		} else {
225			remain = 50
226		}
227		last = n
228	}
229}
230
231func (st *serverTester) Close() {
232	if st.t.Failed() {
233		st.frameReadLogMu.Lock()
234		if st.frameReadLogBuf.Len() > 0 {
235			st.t.Logf("Framer read log:\n%s", st.frameReadLogBuf.String())
236		}
237		st.frameReadLogMu.Unlock()
238
239		st.frameWriteLogMu.Lock()
240		if st.frameWriteLogBuf.Len() > 0 {
241			st.t.Logf("Framer write log:\n%s", st.frameWriteLogBuf.String())
242		}
243		st.frameWriteLogMu.Unlock()
244
245		// If we failed already (and are likely in a Fatal,
246		// unwindowing), force close the connection, so the
247		// httptest.Server doesn't wait forever for the conn
248		// to close.
249		if st.cc != nil {
250			st.cc.Close()
251		}
252	}
253	st.ts.Close()
254	if st.cc != nil {
255		st.cc.Close()
256	}
257	log.SetOutput(os.Stderr)
258}
259
260// greet initiates the client's HTTP/2 connection into a state where
261// frames may be sent.
262func (st *serverTester) greet() {
263	st.greetAndCheckSettings(func(Setting) error { return nil })
264}
265
266func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error) {
267	st.writePreface()
268	st.writeInitialSettings()
269	st.wantSettings().ForeachSetting(checkSetting)
270	st.writeSettingsAck()
271
272	// The initial WINDOW_UPDATE and SETTINGS ACK can come in any order.
273	var gotSettingsAck bool
274	var gotWindowUpdate bool
275
276	for i := 0; i < 2; i++ {
277		f, err := st.readFrame()
278		if err != nil {
279			st.t.Fatal(err)
280		}
281		switch f := f.(type) {
282		case *SettingsFrame:
283			if !f.Header().Flags.Has(FlagSettingsAck) {
284				st.t.Fatal("Settings Frame didn't have ACK set")
285			}
286			gotSettingsAck = true
287
288		case *WindowUpdateFrame:
289			if f.FrameHeader.StreamID != 0 {
290				st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
291			}
292			incr := uint32((&Server{}).initialConnRecvWindowSize() - initialWindowSize)
293			if f.Increment != incr {
294				st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr)
295			}
296			gotWindowUpdate = true
297
298		default:
299			st.t.Fatalf("Wanting a settings ACK or window update, received a %T", f)
300		}
301	}
302
303	if !gotSettingsAck {
304		st.t.Fatalf("Didn't get a settings ACK")
305	}
306	if !gotWindowUpdate {
307		st.t.Fatalf("Didn't get a window update")
308	}
309}
310
311func (st *serverTester) writePreface() {
312	n, err := st.cc.Write(clientPreface)
313	if err != nil {
314		st.t.Fatalf("Error writing client preface: %v", err)
315	}
316	if n != len(clientPreface) {
317		st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(clientPreface))
318	}
319}
320
321func (st *serverTester) writeInitialSettings() {
322	if err := st.fr.WriteSettings(); err != nil {
323		st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err)
324	}
325}
326
327func (st *serverTester) writeSettingsAck() {
328	if err := st.fr.WriteSettingsAck(); err != nil {
329		st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err)
330	}
331}
332
333func (st *serverTester) writeHeaders(p HeadersFrameParam) {
334	if err := st.fr.WriteHeaders(p); err != nil {
335		st.t.Fatalf("Error writing HEADERS: %v", err)
336	}
337}
338
339func (st *serverTester) writePriority(id uint32, p PriorityParam) {
340	if err := st.fr.WritePriority(id, p); err != nil {
341		st.t.Fatalf("Error writing PRIORITY: %v", err)
342	}
343}
344
345func (st *serverTester) encodeHeaderField(k, v string) {
346	err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
347	if err != nil {
348		st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
349	}
350}
351
352// encodeHeaderRaw is the magic-free version of encodeHeader.
353// It takes 0 or more (k, v) pairs and encodes them.
354func (st *serverTester) encodeHeaderRaw(headers ...string) []byte {
355	if len(headers)%2 == 1 {
356		panic("odd number of kv args")
357	}
358	st.headerBuf.Reset()
359	for len(headers) > 0 {
360		k, v := headers[0], headers[1]
361		st.encodeHeaderField(k, v)
362		headers = headers[2:]
363	}
364	return st.headerBuf.Bytes()
365}
366
367// encodeHeader encodes headers and returns their HPACK bytes. headers
368// must contain an even number of key/value pairs. There may be
369// multiple pairs for keys (e.g. "cookie").  The :method, :path, and
370// :scheme headers default to GET, / and https. The :authority header
371// defaults to st.ts.Listener.Addr().
372func (st *serverTester) encodeHeader(headers ...string) []byte {
373	if len(headers)%2 == 1 {
374		panic("odd number of kv args")
375	}
376
377	st.headerBuf.Reset()
378	defaultAuthority := st.ts.Listener.Addr().String()
379
380	if len(headers) == 0 {
381		// Fast path, mostly for benchmarks, so test code doesn't pollute
382		// profiles when we're looking to improve server allocations.
383		st.encodeHeaderField(":method", "GET")
384		st.encodeHeaderField(":scheme", "https")
385		st.encodeHeaderField(":authority", defaultAuthority)
386		st.encodeHeaderField(":path", "/")
387		return st.headerBuf.Bytes()
388	}
389
390	if len(headers) == 2 && headers[0] == ":method" {
391		// Another fast path for benchmarks.
392		st.encodeHeaderField(":method", headers[1])
393		st.encodeHeaderField(":scheme", "https")
394		st.encodeHeaderField(":authority", defaultAuthority)
395		st.encodeHeaderField(":path", "/")
396		return st.headerBuf.Bytes()
397	}
398
399	pseudoCount := map[string]int{}
400	keys := []string{":method", ":scheme", ":authority", ":path"}
401	vals := map[string][]string{
402		":method":    {"GET"},
403		":scheme":    {"https"},
404		":authority": {defaultAuthority},
405		":path":      {"/"},
406	}
407	for len(headers) > 0 {
408		k, v := headers[0], headers[1]
409		headers = headers[2:]
410		if _, ok := vals[k]; !ok {
411			keys = append(keys, k)
412		}
413		if strings.HasPrefix(k, ":") {
414			pseudoCount[k]++
415			if pseudoCount[k] == 1 {
416				vals[k] = []string{v}
417			} else {
418				// Allows testing of invalid headers w/ dup pseudo fields.
419				vals[k] = append(vals[k], v)
420			}
421		} else {
422			vals[k] = append(vals[k], v)
423		}
424	}
425	for _, k := range keys {
426		for _, v := range vals[k] {
427			st.encodeHeaderField(k, v)
428		}
429	}
430	return st.headerBuf.Bytes()
431}
432
433// bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set.
434func (st *serverTester) bodylessReq1(headers ...string) {
435	st.writeHeaders(HeadersFrameParam{
436		StreamID:      1, // clients send odd numbers
437		BlockFragment: st.encodeHeader(headers...),
438		EndStream:     true,
439		EndHeaders:    true,
440	})
441}
442
443func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
444	if err := st.fr.WriteData(streamID, endStream, data); err != nil {
445		st.t.Fatalf("Error writing DATA: %v", err)
446	}
447}
448
449func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
450	if err := st.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
451		st.t.Fatalf("Error writing DATA: %v", err)
452	}
453}
454
455func readFrameTimeout(fr *Framer, wait time.Duration) (Frame, error) {
456	ch := make(chan interface{}, 1)
457	go func() {
458		fr, err := fr.ReadFrame()
459		if err != nil {
460			ch <- err
461		} else {
462			ch <- fr
463		}
464	}()
465	t := time.NewTimer(wait)
466	select {
467	case v := <-ch:
468		t.Stop()
469		if fr, ok := v.(Frame); ok {
470			return fr, nil
471		}
472		return nil, v.(error)
473	case <-t.C:
474		return nil, errors.New("timeout waiting for frame")
475	}
476}
477
478func (st *serverTester) readFrame() (Frame, error) {
479	return readFrameTimeout(st.fr, 2*time.Second)
480}
481
482func (st *serverTester) wantHeaders() *HeadersFrame {
483	f, err := st.readFrame()
484	if err != nil {
485		st.t.Fatalf("Error while expecting a HEADERS frame: %v", err)
486	}
487	hf, ok := f.(*HeadersFrame)
488	if !ok {
489		st.t.Fatalf("got a %T; want *HeadersFrame", f)
490	}
491	return hf
492}
493
494func (st *serverTester) wantContinuation() *ContinuationFrame {
495	f, err := st.readFrame()
496	if err != nil {
497		st.t.Fatalf("Error while expecting a CONTINUATION frame: %v", err)
498	}
499	cf, ok := f.(*ContinuationFrame)
500	if !ok {
501		st.t.Fatalf("got a %T; want *ContinuationFrame", f)
502	}
503	return cf
504}
505
506func (st *serverTester) wantData() *DataFrame {
507	f, err := st.readFrame()
508	if err != nil {
509		st.t.Fatalf("Error while expecting a DATA frame: %v", err)
510	}
511	df, ok := f.(*DataFrame)
512	if !ok {
513		st.t.Fatalf("got a %T; want *DataFrame", f)
514	}
515	return df
516}
517
518func (st *serverTester) wantSettings() *SettingsFrame {
519	f, err := st.readFrame()
520	if err != nil {
521		st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
522	}
523	sf, ok := f.(*SettingsFrame)
524	if !ok {
525		st.t.Fatalf("got a %T; want *SettingsFrame", f)
526	}
527	return sf
528}
529
530func (st *serverTester) wantPing() *PingFrame {
531	f, err := st.readFrame()
532	if err != nil {
533		st.t.Fatalf("Error while expecting a PING frame: %v", err)
534	}
535	pf, ok := f.(*PingFrame)
536	if !ok {
537		st.t.Fatalf("got a %T; want *PingFrame", f)
538	}
539	return pf
540}
541
542func (st *serverTester) wantGoAway() *GoAwayFrame {
543	f, err := st.readFrame()
544	if err != nil {
545		st.t.Fatalf("Error while expecting a GOAWAY frame: %v", err)
546	}
547	gf, ok := f.(*GoAwayFrame)
548	if !ok {
549		st.t.Fatalf("got a %T; want *GoAwayFrame", f)
550	}
551	return gf
552}
553
554func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
555	f, err := st.readFrame()
556	if err != nil {
557		st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
558	}
559	rs, ok := f.(*RSTStreamFrame)
560	if !ok {
561		st.t.Fatalf("got a %T; want *RSTStreamFrame", f)
562	}
563	if rs.FrameHeader.StreamID != streamID {
564		st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
565	}
566	if rs.ErrCode != errCode {
567		st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode)
568	}
569}
570
571func (st *serverTester) wantWindowUpdate(streamID, incr uint32) {
572	f, err := st.readFrame()
573	if err != nil {
574		st.t.Fatalf("Error while expecting a WINDOW_UPDATE frame: %v", err)
575	}
576	wu, ok := f.(*WindowUpdateFrame)
577	if !ok {
578		st.t.Fatalf("got a %T; want *WindowUpdateFrame", f)
579	}
580	if wu.FrameHeader.StreamID != streamID {
581		st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
582	}
583	if wu.Increment != incr {
584		st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
585	}
586}
587
588func (st *serverTester) wantSettingsAck() {
589	f, err := st.readFrame()
590	if err != nil {
591		st.t.Fatal(err)
592	}
593	sf, ok := f.(*SettingsFrame)
594	if !ok {
595		st.t.Fatalf("Wanting a settings ACK, received a %T", f)
596	}
597	if !sf.Header().Flags.Has(FlagSettingsAck) {
598		st.t.Fatal("Settings Frame didn't have ACK set")
599	}
600}
601
602func (st *serverTester) wantPushPromise() *PushPromiseFrame {
603	f, err := st.readFrame()
604	if err != nil {
605		st.t.Fatal(err)
606	}
607	ppf, ok := f.(*PushPromiseFrame)
608	if !ok {
609		st.t.Fatalf("Wanted PushPromise, received %T", ppf)
610	}
611	return ppf
612}
613
614func TestServer(t *testing.T) {
615	gotReq := make(chan bool, 1)
616	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
617		w.Header().Set("Foo", "Bar")
618		gotReq <- true
619	})
620	defer st.Close()
621
622	covers("3.5", `
623		The server connection preface consists of a potentially empty
624		SETTINGS frame ([SETTINGS]) that MUST be the first frame the
625		server sends in the HTTP/2 connection.
626	`)
627
628	st.greet()
629	st.writeHeaders(HeadersFrameParam{
630		StreamID:      1, // clients send odd numbers
631		BlockFragment: st.encodeHeader(),
632		EndStream:     true, // no DATA frames
633		EndHeaders:    true,
634	})
635
636	select {
637	case <-gotReq:
638	case <-time.After(2 * time.Second):
639		t.Error("timeout waiting for request")
640	}
641}
642
643func TestServer_Request_Get(t *testing.T) {
644	testServerRequest(t, func(st *serverTester) {
645		st.writeHeaders(HeadersFrameParam{
646			StreamID:      1, // clients send odd numbers
647			BlockFragment: st.encodeHeader("foo-bar", "some-value"),
648			EndStream:     true, // no DATA frames
649			EndHeaders:    true,
650		})
651	}, func(r *http.Request) {
652		if r.Method != "GET" {
653			t.Errorf("Method = %q; want GET", r.Method)
654		}
655		if r.URL.Path != "/" {
656			t.Errorf("URL.Path = %q; want /", r.URL.Path)
657		}
658		if r.ContentLength != 0 {
659			t.Errorf("ContentLength = %v; want 0", r.ContentLength)
660		}
661		if r.Close {
662			t.Error("Close = true; want false")
663		}
664		if !strings.Contains(r.RemoteAddr, ":") {
665			t.Errorf("RemoteAddr = %q; want something with a colon", r.RemoteAddr)
666		}
667		if r.Proto != "HTTP/2.0" || r.ProtoMajor != 2 || r.ProtoMinor != 0 {
668			t.Errorf("Proto = %q Major=%v,Minor=%v; want HTTP/2.0", r.Proto, r.ProtoMajor, r.ProtoMinor)
669		}
670		wantHeader := http.Header{
671			"Foo-Bar": []string{"some-value"},
672		}
673		if !reflect.DeepEqual(r.Header, wantHeader) {
674			t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
675		}
676		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
677			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
678		}
679	})
680}
681
682func TestServer_Request_Get_PathSlashes(t *testing.T) {
683	testServerRequest(t, func(st *serverTester) {
684		st.writeHeaders(HeadersFrameParam{
685			StreamID:      1, // clients send odd numbers
686			BlockFragment: st.encodeHeader(":path", "/%2f/"),
687			EndStream:     true, // no DATA frames
688			EndHeaders:    true,
689		})
690	}, func(r *http.Request) {
691		if r.RequestURI != "/%2f/" {
692			t.Errorf("RequestURI = %q; want /%%2f/", r.RequestURI)
693		}
694		if r.URL.Path != "///" {
695			t.Errorf("URL.Path = %q; want ///", r.URL.Path)
696		}
697	})
698}
699
700// TODO: add a test with EndStream=true on the HEADERS but setting a
701// Content-Length anyway. Should we just omit it and force it to
702// zero?
703
704func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) {
705	testServerRequest(t, func(st *serverTester) {
706		st.writeHeaders(HeadersFrameParam{
707			StreamID:      1, // clients send odd numbers
708			BlockFragment: st.encodeHeader(":method", "POST"),
709			EndStream:     true,
710			EndHeaders:    true,
711		})
712	}, func(r *http.Request) {
713		if r.Method != "POST" {
714			t.Errorf("Method = %q; want POST", r.Method)
715		}
716		if r.ContentLength != 0 {
717			t.Errorf("ContentLength = %v; want 0", r.ContentLength)
718		}
719		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
720			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
721		}
722	})
723}
724
725func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) {
726	testBodyContents(t, -1, "", func(st *serverTester) {
727		st.writeHeaders(HeadersFrameParam{
728			StreamID:      1, // clients send odd numbers
729			BlockFragment: st.encodeHeader(":method", "POST"),
730			EndStream:     false, // to say DATA frames are coming
731			EndHeaders:    true,
732		})
733		st.writeData(1, true, nil) // just kidding. empty body.
734	})
735}
736
737func TestServer_Request_Post_Body_OneData(t *testing.T) {
738	const content = "Some content"
739	testBodyContents(t, -1, content, func(st *serverTester) {
740		st.writeHeaders(HeadersFrameParam{
741			StreamID:      1, // clients send odd numbers
742			BlockFragment: st.encodeHeader(":method", "POST"),
743			EndStream:     false, // to say DATA frames are coming
744			EndHeaders:    true,
745		})
746		st.writeData(1, true, []byte(content))
747	})
748}
749
750func TestServer_Request_Post_Body_TwoData(t *testing.T) {
751	const content = "Some content"
752	testBodyContents(t, -1, content, func(st *serverTester) {
753		st.writeHeaders(HeadersFrameParam{
754			StreamID:      1, // clients send odd numbers
755			BlockFragment: st.encodeHeader(":method", "POST"),
756			EndStream:     false, // to say DATA frames are coming
757			EndHeaders:    true,
758		})
759		st.writeData(1, false, []byte(content[:5]))
760		st.writeData(1, true, []byte(content[5:]))
761	})
762}
763
764func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) {
765	const content = "Some content"
766	testBodyContents(t, int64(len(content)), content, func(st *serverTester) {
767		st.writeHeaders(HeadersFrameParam{
768			StreamID: 1, // clients send odd numbers
769			BlockFragment: st.encodeHeader(
770				":method", "POST",
771				"content-length", strconv.Itoa(len(content)),
772			),
773			EndStream:  false, // to say DATA frames are coming
774			EndHeaders: true,
775		})
776		st.writeData(1, true, []byte(content))
777	})
778}
779
780func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) {
781	testBodyContentsFail(t, 3, "request declared a Content-Length of 3 but only wrote 2 bytes",
782		func(st *serverTester) {
783			st.writeHeaders(HeadersFrameParam{
784				StreamID: 1, // clients send odd numbers
785				BlockFragment: st.encodeHeader(
786					":method", "POST",
787					"content-length", "3",
788				),
789				EndStream:  false, // to say DATA frames are coming
790				EndHeaders: true,
791			})
792			st.writeData(1, true, []byte("12"))
793		})
794}
795
796func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) {
797	testBodyContentsFail(t, 4, "sender tried to send more than declared Content-Length of 4 bytes",
798		func(st *serverTester) {
799			st.writeHeaders(HeadersFrameParam{
800				StreamID: 1, // clients send odd numbers
801				BlockFragment: st.encodeHeader(
802					":method", "POST",
803					"content-length", "4",
804				),
805				EndStream:  false, // to say DATA frames are coming
806				EndHeaders: true,
807			})
808			st.writeData(1, true, []byte("12345"))
809		})
810}
811
812func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, write func(st *serverTester)) {
813	testServerRequest(t, write, func(r *http.Request) {
814		if r.Method != "POST" {
815			t.Errorf("Method = %q; want POST", r.Method)
816		}
817		if r.ContentLength != wantContentLength {
818			t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
819		}
820		all, err := ioutil.ReadAll(r.Body)
821		if err != nil {
822			t.Fatal(err)
823		}
824		if string(all) != wantBody {
825			t.Errorf("Read = %q; want %q", all, wantBody)
826		}
827		if err := r.Body.Close(); err != nil {
828			t.Fatalf("Close: %v", err)
829		}
830	})
831}
832
833func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError string, write func(st *serverTester)) {
834	testServerRequest(t, write, func(r *http.Request) {
835		if r.Method != "POST" {
836			t.Errorf("Method = %q; want POST", r.Method)
837		}
838		if r.ContentLength != wantContentLength {
839			t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
840		}
841		all, err := ioutil.ReadAll(r.Body)
842		if err == nil {
843			t.Fatalf("expected an error (%q) reading from the body. Successfully read %q instead.",
844				wantReadError, all)
845		}
846		if !strings.Contains(err.Error(), wantReadError) {
847			t.Fatalf("Body.Read = %v; want substring %q", err, wantReadError)
848		}
849		if err := r.Body.Close(); err != nil {
850			t.Fatalf("Close: %v", err)
851		}
852	})
853}
854
855// Using a Host header, instead of :authority
856func TestServer_Request_Get_Host(t *testing.T) {
857	const host = "example.com"
858	testServerRequest(t, func(st *serverTester) {
859		st.writeHeaders(HeadersFrameParam{
860			StreamID:      1, // clients send odd numbers
861			BlockFragment: st.encodeHeader(":authority", "", "host", host),
862			EndStream:     true,
863			EndHeaders:    true,
864		})
865	}, func(r *http.Request) {
866		if r.Host != host {
867			t.Errorf("Host = %q; want %q", r.Host, host)
868		}
869	})
870}
871
872// Using an :authority pseudo-header, instead of Host
873func TestServer_Request_Get_Authority(t *testing.T) {
874	const host = "example.com"
875	testServerRequest(t, func(st *serverTester) {
876		st.writeHeaders(HeadersFrameParam{
877			StreamID:      1, // clients send odd numbers
878			BlockFragment: st.encodeHeader(":authority", host),
879			EndStream:     true,
880			EndHeaders:    true,
881		})
882	}, func(r *http.Request) {
883		if r.Host != host {
884			t.Errorf("Host = %q; want %q", r.Host, host)
885		}
886	})
887}
888
889func TestServer_Request_WithContinuation(t *testing.T) {
890	wantHeader := http.Header{
891		"Foo-One":   []string{"value-one"},
892		"Foo-Two":   []string{"value-two"},
893		"Foo-Three": []string{"value-three"},
894	}
895	testServerRequest(t, func(st *serverTester) {
896		fullHeaders := st.encodeHeader(
897			"foo-one", "value-one",
898			"foo-two", "value-two",
899			"foo-three", "value-three",
900		)
901		remain := fullHeaders
902		chunks := 0
903		for len(remain) > 0 {
904			const maxChunkSize = 5
905			chunk := remain
906			if len(chunk) > maxChunkSize {
907				chunk = chunk[:maxChunkSize]
908			}
909			remain = remain[len(chunk):]
910
911			if chunks == 0 {
912				st.writeHeaders(HeadersFrameParam{
913					StreamID:      1, // clients send odd numbers
914					BlockFragment: chunk,
915					EndStream:     true,  // no DATA frames
916					EndHeaders:    false, // we'll have continuation frames
917				})
918			} else {
919				err := st.fr.WriteContinuation(1, len(remain) == 0, chunk)
920				if err != nil {
921					t.Fatal(err)
922				}
923			}
924			chunks++
925		}
926		if chunks < 2 {
927			t.Fatal("too few chunks")
928		}
929	}, func(r *http.Request) {
930		if !reflect.DeepEqual(r.Header, wantHeader) {
931			t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
932		}
933	})
934}
935
936// Concatenated cookie headers. ("8.1.2.5 Compressing the Cookie Header Field")
937func TestServer_Request_CookieConcat(t *testing.T) {
938	const host = "example.com"
939	testServerRequest(t, func(st *serverTester) {
940		st.bodylessReq1(
941			":authority", host,
942			"cookie", "a=b",
943			"cookie", "c=d",
944			"cookie", "e=f",
945		)
946	}, func(r *http.Request) {
947		const want = "a=b; c=d; e=f"
948		if got := r.Header.Get("Cookie"); got != want {
949			t.Errorf("Cookie = %q; want %q", got, want)
950		}
951	})
952}
953
954func TestServer_Request_Reject_CapitalHeader(t *testing.T) {
955	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("UPPER", "v") })
956}
957
958func TestServer_Request_Reject_HeaderFieldNameColon(t *testing.T) {
959	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has:colon", "v") })
960}
961
962func TestServer_Request_Reject_HeaderFieldNameNULL(t *testing.T) {
963	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has\x00null", "v") })
964}
965
966func TestServer_Request_Reject_HeaderFieldNameEmpty(t *testing.T) {
967	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("", "v") })
968}
969
970func TestServer_Request_Reject_HeaderFieldValueNewline(t *testing.T) {
971	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\nnewline") })
972}
973
974func TestServer_Request_Reject_HeaderFieldValueCR(t *testing.T) {
975	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\rcarriage") })
976}
977
978func TestServer_Request_Reject_HeaderFieldValueDEL(t *testing.T) {
979	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\x7fdel") })
980}
981
982func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) {
983	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "") })
984}
985
986func TestServer_Request_Reject_Pseudo_ExactlyOne(t *testing.T) {
987	// 8.1.2.3 Request Pseudo-Header Fields
988	// "All HTTP/2 requests MUST include exactly one valid value" ...
989	testRejectRequest(t, func(st *serverTester) {
990		st.addLogFilter("duplicate pseudo-header")
991		st.bodylessReq1(":method", "GET", ":method", "POST")
992	})
993}
994
995func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) {
996	// 8.1.2.3 Request Pseudo-Header Fields
997	// "All pseudo-header fields MUST appear in the header block
998	// before regular header fields. Any request or response that
999	// contains a pseudo-header field that appears in a header
1000	// block after a regular header field MUST be treated as
1001	// malformed (Section 8.1.2.6)."
1002	testRejectRequest(t, func(st *serverTester) {
1003		st.addLogFilter("pseudo-header after regular header")
1004		var buf bytes.Buffer
1005		enc := hpack.NewEncoder(&buf)
1006		enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
1007		enc.WriteField(hpack.HeaderField{Name: "regular", Value: "foobar"})
1008		enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"})
1009		enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
1010		st.writeHeaders(HeadersFrameParam{
1011			StreamID:      1, // clients send odd numbers
1012			BlockFragment: buf.Bytes(),
1013			EndStream:     true,
1014			EndHeaders:    true,
1015		})
1016	})
1017}
1018
1019func TestServer_Request_Reject_Pseudo_Missing_path(t *testing.T) {
1020	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":path", "") })
1021}
1022
1023func TestServer_Request_Reject_Pseudo_Missing_scheme(t *testing.T) {
1024	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "") })
1025}
1026
1027func TestServer_Request_Reject_Pseudo_scheme_invalid(t *testing.T) {
1028	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "bogus") })
1029}
1030
1031func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
1032	testRejectRequest(t, func(st *serverTester) {
1033		st.addLogFilter(`invalid pseudo-header ":unknown_thing"`)
1034		st.bodylessReq1(":unknown_thing", "")
1035	})
1036}
1037
1038func testRejectRequest(t *testing.T, send func(*serverTester)) {
1039	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1040		t.Error("server request made it to handler; should've been rejected")
1041	})
1042	defer st.Close()
1043
1044	st.greet()
1045	send(st)
1046	st.wantRSTStream(1, ErrCodeProtocol)
1047}
1048
1049func testRejectRequestWithProtocolError(t *testing.T, send func(*serverTester)) {
1050	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1051		t.Error("server request made it to handler; should've been rejected")
1052	}, optQuiet)
1053	defer st.Close()
1054
1055	st.greet()
1056	send(st)
1057	gf := st.wantGoAway()
1058	if gf.ErrCode != ErrCodeProtocol {
1059		t.Errorf("err code = %v; want %v", gf.ErrCode, ErrCodeProtocol)
1060	}
1061}
1062
1063// Section 5.1, on idle connections: "Receiving any frame other than
1064// HEADERS or PRIORITY on a stream in this state MUST be treated as a
1065// connection error (Section 5.4.1) of type PROTOCOL_ERROR."
1066func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) {
1067	testRejectRequestWithProtocolError(t, func(st *serverTester) {
1068		st.fr.WriteWindowUpdate(123, 456)
1069	})
1070}
1071func TestRejectFrameOnIdle_Data(t *testing.T) {
1072	testRejectRequestWithProtocolError(t, func(st *serverTester) {
1073		st.fr.WriteData(123, true, nil)
1074	})
1075}
1076func TestRejectFrameOnIdle_RSTStream(t *testing.T) {
1077	testRejectRequestWithProtocolError(t, func(st *serverTester) {
1078		st.fr.WriteRSTStream(123, ErrCodeCancel)
1079	})
1080}
1081
1082func TestServer_Request_Connect(t *testing.T) {
1083	testServerRequest(t, func(st *serverTester) {
1084		st.writeHeaders(HeadersFrameParam{
1085			StreamID: 1,
1086			BlockFragment: st.encodeHeaderRaw(
1087				":method", "CONNECT",
1088				":authority", "example.com:123",
1089			),
1090			EndStream:  true,
1091			EndHeaders: true,
1092		})
1093	}, func(r *http.Request) {
1094		if g, w := r.Method, "CONNECT"; g != w {
1095			t.Errorf("Method = %q; want %q", g, w)
1096		}
1097		if g, w := r.RequestURI, "example.com:123"; g != w {
1098			t.Errorf("RequestURI = %q; want %q", g, w)
1099		}
1100		if g, w := r.URL.Host, "example.com:123"; g != w {
1101			t.Errorf("URL.Host = %q; want %q", g, w)
1102		}
1103	})
1104}
1105
1106func TestServer_Request_Connect_InvalidPath(t *testing.T) {
1107	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1108		st.writeHeaders(HeadersFrameParam{
1109			StreamID: 1,
1110			BlockFragment: st.encodeHeaderRaw(
1111				":method", "CONNECT",
1112				":authority", "example.com:123",
1113				":path", "/bogus",
1114			),
1115			EndStream:  true,
1116			EndHeaders: true,
1117		})
1118	})
1119}
1120
1121func TestServer_Request_Connect_InvalidScheme(t *testing.T) {
1122	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1123		st.writeHeaders(HeadersFrameParam{
1124			StreamID: 1,
1125			BlockFragment: st.encodeHeaderRaw(
1126				":method", "CONNECT",
1127				":authority", "example.com:123",
1128				":scheme", "https",
1129			),
1130			EndStream:  true,
1131			EndHeaders: true,
1132		})
1133	})
1134}
1135
1136func TestServer_Ping(t *testing.T) {
1137	st := newServerTester(t, nil)
1138	defer st.Close()
1139	st.greet()
1140
1141	// Server should ignore this one, since it has ACK set.
1142	ackPingData := [8]byte{1, 2, 4, 8, 16, 32, 64, 128}
1143	if err := st.fr.WritePing(true, ackPingData); err != nil {
1144		t.Fatal(err)
1145	}
1146
1147	// But the server should reply to this one, since ACK is false.
1148	pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
1149	if err := st.fr.WritePing(false, pingData); err != nil {
1150		t.Fatal(err)
1151	}
1152
1153	pf := st.wantPing()
1154	if !pf.Flags.Has(FlagPingAck) {
1155		t.Error("response ping doesn't have ACK set")
1156	}
1157	if pf.Data != pingData {
1158		t.Errorf("response ping has data %q; want %q", pf.Data, pingData)
1159	}
1160}
1161
1162func TestServer_RejectsLargeFrames(t *testing.T) {
1163	if runtime.GOOS == "windows" {
1164		t.Skip("see golang.org/issue/13434")
1165	}
1166
1167	st := newServerTester(t, nil)
1168	defer st.Close()
1169	st.greet()
1170
1171	// Write too large of a frame (too large by one byte)
1172	// We ignore the return value because it's expected that the server
1173	// will only read the first 9 bytes (the headre) and then disconnect.
1174	st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, defaultMaxReadFrameSize+1))
1175
1176	gf := st.wantGoAway()
1177	if gf.ErrCode != ErrCodeFrameSize {
1178		t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFrameSize)
1179	}
1180	if st.serverLogBuf.Len() != 0 {
1181		// Previously we spun here for a bit until the GOAWAY disconnect
1182		// timer fired, logging while we fired.
1183		t.Errorf("unexpected server output: %.500s\n", st.serverLogBuf.Bytes())
1184	}
1185}
1186
1187func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
1188	puppet := newHandlerPuppet()
1189	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1190		puppet.act(w, r)
1191	})
1192	defer st.Close()
1193	defer puppet.done()
1194
1195	st.greet()
1196
1197	st.writeHeaders(HeadersFrameParam{
1198		StreamID:      1, // clients send odd numbers
1199		BlockFragment: st.encodeHeader(":method", "POST"),
1200		EndStream:     false, // data coming
1201		EndHeaders:    true,
1202	})
1203	st.writeData(1, false, []byte("abcdef"))
1204	puppet.do(readBodyHandler(t, "abc"))
1205	st.wantWindowUpdate(0, 3)
1206	st.wantWindowUpdate(1, 3)
1207
1208	puppet.do(readBodyHandler(t, "def"))
1209	st.wantWindowUpdate(0, 3)
1210	st.wantWindowUpdate(1, 3)
1211
1212	st.writeData(1, true, []byte("ghijkl")) // END_STREAM here
1213	puppet.do(readBodyHandler(t, "ghi"))
1214	puppet.do(readBodyHandler(t, "jkl"))
1215	st.wantWindowUpdate(0, 3)
1216	st.wantWindowUpdate(0, 3) // no more stream-level, since END_STREAM
1217}
1218
1219// the version of the TestServer_Handler_Sends_WindowUpdate with padding.
1220// See golang.org/issue/16556
1221func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) {
1222	puppet := newHandlerPuppet()
1223	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1224		puppet.act(w, r)
1225	})
1226	defer st.Close()
1227	defer puppet.done()
1228
1229	st.greet()
1230
1231	st.writeHeaders(HeadersFrameParam{
1232		StreamID:      1,
1233		BlockFragment: st.encodeHeader(":method", "POST"),
1234		EndStream:     false,
1235		EndHeaders:    true,
1236	})
1237	st.writeDataPadded(1, false, []byte("abcdef"), []byte{0, 0, 0, 0})
1238
1239	// Expect to immediately get our 5 bytes of padding back for
1240	// both the connection and stream (4 bytes of padding + 1 byte of length)
1241	st.wantWindowUpdate(0, 5)
1242	st.wantWindowUpdate(1, 5)
1243
1244	puppet.do(readBodyHandler(t, "abc"))
1245	st.wantWindowUpdate(0, 3)
1246	st.wantWindowUpdate(1, 3)
1247
1248	puppet.do(readBodyHandler(t, "def"))
1249	st.wantWindowUpdate(0, 3)
1250	st.wantWindowUpdate(1, 3)
1251}
1252
1253func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) {
1254	st := newServerTester(t, nil)
1255	defer st.Close()
1256	st.greet()
1257	if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
1258		t.Fatal(err)
1259	}
1260	gf := st.wantGoAway()
1261	if gf.ErrCode != ErrCodeFlowControl {
1262		t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFlowControl)
1263	}
1264	if gf.LastStreamID != 0 {
1265		t.Errorf("GOAWAY last stream ID = %v; want %v", gf.LastStreamID, 0)
1266	}
1267}
1268
1269func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) {
1270	inHandler := make(chan bool)
1271	blockHandler := make(chan bool)
1272	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1273		inHandler <- true
1274		<-blockHandler
1275	})
1276	defer st.Close()
1277	defer close(blockHandler)
1278	st.greet()
1279	st.writeHeaders(HeadersFrameParam{
1280		StreamID:      1,
1281		BlockFragment: st.encodeHeader(":method", "POST"),
1282		EndStream:     false, // keep it open
1283		EndHeaders:    true,
1284	})
1285	<-inHandler
1286	// Send a bogus window update:
1287	if err := st.fr.WriteWindowUpdate(1, 1<<31-1); err != nil {
1288		t.Fatal(err)
1289	}
1290	st.wantRSTStream(1, ErrCodeFlowControl)
1291}
1292
1293// testServerPostUnblock sends a hanging POST with unsent data to handler,
1294// then runs fn once in the handler, and verifies that the error returned from
1295// handler is acceptable. It fails if takes over 5 seconds for handler to exit.
1296func testServerPostUnblock(t *testing.T,
1297	handler func(http.ResponseWriter, *http.Request) error,
1298	fn func(*serverTester),
1299	checkErr func(error),
1300	otherHeaders ...string) {
1301	inHandler := make(chan bool)
1302	errc := make(chan error, 1)
1303	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1304		inHandler <- true
1305		errc <- handler(w, r)
1306	})
1307	defer st.Close()
1308	st.greet()
1309	st.writeHeaders(HeadersFrameParam{
1310		StreamID:      1,
1311		BlockFragment: st.encodeHeader(append([]string{":method", "POST"}, otherHeaders...)...),
1312		EndStream:     false, // keep it open
1313		EndHeaders:    true,
1314	})
1315	<-inHandler
1316	fn(st)
1317	select {
1318	case err := <-errc:
1319		if checkErr != nil {
1320			checkErr(err)
1321		}
1322	case <-time.After(5 * time.Second):
1323		t.Fatal("timeout waiting for Handler to return")
1324	}
1325}
1326
1327func TestServer_RSTStream_Unblocks_Read(t *testing.T) {
1328	testServerPostUnblock(t,
1329		func(w http.ResponseWriter, r *http.Request) (err error) {
1330			_, err = r.Body.Read(make([]byte, 1))
1331			return
1332		},
1333		func(st *serverTester) {
1334			if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1335				t.Fatal(err)
1336			}
1337		},
1338		func(err error) {
1339			want := StreamError{StreamID: 0x1, Code: 0x8}
1340			if !reflect.DeepEqual(err, want) {
1341				t.Errorf("Read error = %v; want %v", err, want)
1342			}
1343		},
1344	)
1345}
1346
1347func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
1348	// Run this test a bunch, because it doesn't always
1349	// deadlock. But with a bunch, it did.
1350	n := 50
1351	if testing.Short() {
1352		n = 5
1353	}
1354	for i := 0; i < n; i++ {
1355		testServer_RSTStream_Unblocks_Header_Write(t)
1356	}
1357}
1358
1359func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
1360	inHandler := make(chan bool, 1)
1361	unblockHandler := make(chan bool, 1)
1362	headerWritten := make(chan bool, 1)
1363	wroteRST := make(chan bool, 1)
1364
1365	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1366		inHandler <- true
1367		<-wroteRST
1368		w.Header().Set("foo", "bar")
1369		w.WriteHeader(200)
1370		w.(http.Flusher).Flush()
1371		headerWritten <- true
1372		<-unblockHandler
1373	})
1374	defer st.Close()
1375
1376	st.greet()
1377	st.writeHeaders(HeadersFrameParam{
1378		StreamID:      1,
1379		BlockFragment: st.encodeHeader(":method", "POST"),
1380		EndStream:     false, // keep it open
1381		EndHeaders:    true,
1382	})
1383	<-inHandler
1384	if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1385		t.Fatal(err)
1386	}
1387	wroteRST <- true
1388	st.awaitIdle()
1389	select {
1390	case <-headerWritten:
1391	case <-time.After(2 * time.Second):
1392		t.Error("timeout waiting for header write")
1393	}
1394	unblockHandler <- true
1395}
1396
1397func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
1398	testServerPostUnblock(t,
1399		func(w http.ResponseWriter, r *http.Request) (err error) {
1400			_, err = r.Body.Read(make([]byte, 1))
1401			return
1402		},
1403		func(st *serverTester) { st.cc.Close() },
1404		func(err error) {
1405			if err == nil {
1406				t.Error("unexpected nil error from Request.Body.Read")
1407			}
1408		},
1409	)
1410}
1411
// blockUntilClosed is a testServerPostUnblock handler. It waits until the
// ResponseWriter's CloseNotify channel delivers, then reports no error.
var blockUntilClosed = func(w http.ResponseWriter, r *http.Request) error {
	closed := w.(http.CloseNotifier).CloseNotify()
	<-closed
	return nil
}
1416
1417func TestServer_CloseNotify_After_RSTStream(t *testing.T) {
1418	testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
1419		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1420			t.Fatal(err)
1421		}
1422	}, nil)
1423}
1424
1425func TestServer_CloseNotify_After_ConnClose(t *testing.T) {
1426	testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { st.cc.Close() }, nil)
1427}
1428
1429// that CloseNotify unblocks after a stream error due to the client's
1430// problem that's unrelated to them explicitly canceling it (which is
1431// TestServer_CloseNotify_After_RSTStream above)
1432func TestServer_CloseNotify_After_StreamError(t *testing.T) {
1433	testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
1434		// data longer than declared Content-Length => stream error
1435		st.writeData(1, true, []byte("1234"))
1436	}, nil, "content-length", "3")
1437}
1438
1439func TestServer_StateTransitions(t *testing.T) {
1440	var st *serverTester
1441	inHandler := make(chan bool)
1442	writeData := make(chan bool)
1443	leaveHandler := make(chan bool)
1444	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1445		inHandler <- true
1446		if st.stream(1) == nil {
1447			t.Errorf("nil stream 1 in handler")
1448		}
1449		if got, want := st.streamState(1), stateOpen; got != want {
1450			t.Errorf("in handler, state is %v; want %v", got, want)
1451		}
1452		writeData <- true
1453		if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
1454			t.Errorf("body read = %d, %v; want 0, EOF", n, err)
1455		}
1456		if got, want := st.streamState(1), stateHalfClosedRemote; got != want {
1457			t.Errorf("in handler, state is %v; want %v", got, want)
1458		}
1459
1460		<-leaveHandler
1461	})
1462	st.greet()
1463	if st.stream(1) != nil {
1464		t.Fatal("stream 1 should be empty")
1465	}
1466	if got := st.streamState(1); got != stateIdle {
1467		t.Fatalf("stream 1 should be idle; got %v", got)
1468	}
1469
1470	st.writeHeaders(HeadersFrameParam{
1471		StreamID:      1,
1472		BlockFragment: st.encodeHeader(":method", "POST"),
1473		EndStream:     false, // keep it open
1474		EndHeaders:    true,
1475	})
1476	<-inHandler
1477	<-writeData
1478	st.writeData(1, true, nil)
1479
1480	leaveHandler <- true
1481	hf := st.wantHeaders()
1482	if !hf.StreamEnded() {
1483		t.Fatal("expected END_STREAM flag")
1484	}
1485
1486	if got, want := st.streamState(1), stateClosed; got != want {
1487		t.Errorf("at end, state is %v; want %v", got, want)
1488	}
1489	if st.stream(1) != nil {
1490		t.Fatal("at end, stream 1 should be gone")
1491	}
1492}
1493
1494// test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
1495func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
1496	testServerRejectsConn(t, func(st *serverTester) {
1497		st.writeHeaders(HeadersFrameParam{
1498			StreamID:      1,
1499			BlockFragment: st.encodeHeader(),
1500			EndStream:     true,
1501			EndHeaders:    false,
1502		})
1503		st.writeHeaders(HeadersFrameParam{ // Not a continuation.
1504			StreamID:      3, // different stream.
1505			BlockFragment: st.encodeHeader(),
1506			EndStream:     true,
1507			EndHeaders:    true,
1508		})
1509	})
1510}
1511
1512// test HEADERS w/o EndHeaders + PING (should get rejected)
1513func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
1514	testServerRejectsConn(t, func(st *serverTester) {
1515		st.writeHeaders(HeadersFrameParam{
1516			StreamID:      1,
1517			BlockFragment: st.encodeHeader(),
1518			EndStream:     true,
1519			EndHeaders:    false,
1520		})
1521		if err := st.fr.WritePing(false, [8]byte{}); err != nil {
1522			t.Fatal(err)
1523		}
1524	})
1525}
1526
1527// test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
1528func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
1529	testServerRejectsConn(t, func(st *serverTester) {
1530		st.writeHeaders(HeadersFrameParam{
1531			StreamID:      1,
1532			BlockFragment: st.encodeHeader(),
1533			EndStream:     true,
1534			EndHeaders:    true,
1535		})
1536		st.wantHeaders()
1537		if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
1538			t.Fatal(err)
1539		}
1540	})
1541}
1542
1543// test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID
1544func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
1545	testServerRejectsConn(t, func(st *serverTester) {
1546		st.writeHeaders(HeadersFrameParam{
1547			StreamID:      1,
1548			BlockFragment: st.encodeHeader(),
1549			EndStream:     true,
1550			EndHeaders:    false,
1551		})
1552		if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
1553			t.Fatal(err)
1554		}
1555	})
1556}
1557
1558// No HEADERS on stream 0.
1559func TestServer_Rejects_Headers0(t *testing.T) {
1560	testServerRejectsConn(t, func(st *serverTester) {
1561		st.fr.AllowIllegalWrites = true
1562		st.writeHeaders(HeadersFrameParam{
1563			StreamID:      0,
1564			BlockFragment: st.encodeHeader(),
1565			EndStream:     true,
1566			EndHeaders:    true,
1567		})
1568	})
1569}
1570
1571// No CONTINUATION on stream 0.
1572func TestServer_Rejects_Continuation0(t *testing.T) {
1573	testServerRejectsConn(t, func(st *serverTester) {
1574		st.fr.AllowIllegalWrites = true
1575		if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil {
1576			t.Fatal(err)
1577		}
1578	})
1579}
1580
1581// No PRIORITY on stream 0.
1582func TestServer_Rejects_Priority0(t *testing.T) {
1583	testServerRejectsConn(t, func(st *serverTester) {
1584		st.fr.AllowIllegalWrites = true
1585		st.writePriority(0, PriorityParam{StreamDep: 1})
1586	})
1587}
1588
1589// No HEADERS frame with a self-dependence.
1590func TestServer_Rejects_HeadersSelfDependence(t *testing.T) {
1591	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1592		st.fr.AllowIllegalWrites = true
1593		st.writeHeaders(HeadersFrameParam{
1594			StreamID:      1,
1595			BlockFragment: st.encodeHeader(),
1596			EndStream:     true,
1597			EndHeaders:    true,
1598			Priority:      PriorityParam{StreamDep: 1},
1599		})
1600	})
1601}
1602
1603// No PRIORTY frame with a self-dependence.
1604func TestServer_Rejects_PrioritySelfDependence(t *testing.T) {
1605	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1606		st.fr.AllowIllegalWrites = true
1607		st.writePriority(1, PriorityParam{StreamDep: 1})
1608	})
1609}
1610
1611func TestServer_Rejects_PushPromise(t *testing.T) {
1612	testServerRejectsConn(t, func(st *serverTester) {
1613		pp := PushPromiseParam{
1614			StreamID:  1,
1615			PromiseID: 3,
1616		}
1617		if err := st.fr.WritePushPromise(pp); err != nil {
1618			t.Fatal(err)
1619		}
1620	})
1621}
1622
1623// testServerRejectsConn tests that the server hangs up with a GOAWAY
1624// frame and a server close after the client does something
1625// deserving a CONNECTION_ERROR.
1626func testServerRejectsConn(t *testing.T, writeReq func(*serverTester)) {
1627	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
1628	st.addLogFilter("connection error: PROTOCOL_ERROR")
1629	defer st.Close()
1630	st.greet()
1631	writeReq(st)
1632
1633	st.wantGoAway()
1634	errc := make(chan error, 1)
1635	go func() {
1636		fr, err := st.fr.ReadFrame()
1637		if err == nil {
1638			err = fmt.Errorf("got frame of type %T", fr)
1639		}
1640		errc <- err
1641	}()
1642	select {
1643	case err := <-errc:
1644		if err != io.EOF {
1645			t.Errorf("ReadFrame = %v; want io.EOF", err)
1646		}
1647	case <-time.After(2 * time.Second):
1648		t.Error("timeout waiting for disconnect")
1649	}
1650}
1651
1652// testServerRejectsStream tests that the server sends a RST_STREAM with the provided
1653// error code after a client sends a bogus request.
1654func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTester)) {
1655	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
1656	defer st.Close()
1657	st.greet()
1658	writeReq(st)
1659	st.wantRSTStream(1, code)
1660}
1661
1662// testServerRequest sets up an idle HTTP/2 connection and lets you
1663// write a single request with writeReq, and then verify that the
1664// *http.Request is built correctly in checkReq.
1665func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) {
1666	gotReq := make(chan bool, 1)
1667	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1668		if r.Body == nil {
1669			t.Fatal("nil Body")
1670		}
1671		checkReq(r)
1672		gotReq <- true
1673	})
1674	defer st.Close()
1675
1676	st.greet()
1677	writeReq(st)
1678
1679	select {
1680	case <-gotReq:
1681	case <-time.After(2 * time.Second):
1682		t.Error("timeout waiting for request")
1683	}
1684}
1685
1686func getSlash(st *serverTester) { st.bodylessReq1() }
1687
1688func TestServer_Response_NoData(t *testing.T) {
1689	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1690		// Nothing.
1691		return nil
1692	}, func(st *serverTester) {
1693		getSlash(st)
1694		hf := st.wantHeaders()
1695		if !hf.StreamEnded() {
1696			t.Fatal("want END_STREAM flag")
1697		}
1698		if !hf.HeadersEnded() {
1699			t.Fatal("want END_HEADERS flag")
1700		}
1701	})
1702}
1703
1704func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
1705	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1706		w.Header().Set("Foo-Bar", "some-value")
1707		return nil
1708	}, func(st *serverTester) {
1709		getSlash(st)
1710		hf := st.wantHeaders()
1711		if !hf.StreamEnded() {
1712			t.Fatal("want END_STREAM flag")
1713		}
1714		if !hf.HeadersEnded() {
1715			t.Fatal("want END_HEADERS flag")
1716		}
1717		goth := st.decodeHeader(hf.HeaderBlockFragment())
1718		wanth := [][2]string{
1719			{":status", "200"},
1720			{"foo-bar", "some-value"},
1721			{"content-length", "0"},
1722		}
1723		if !reflect.DeepEqual(goth, wanth) {
1724			t.Errorf("Got headers %v; want %v", goth, wanth)
1725		}
1726	})
1727}
1728
1729func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) {
1730	const msg = "<html>this is HTML."
1731	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1732		w.Header().Set("Content-Type", "foo/bar")
1733		io.WriteString(w, msg)
1734		return nil
1735	}, func(st *serverTester) {
1736		getSlash(st)
1737		hf := st.wantHeaders()
1738		if hf.StreamEnded() {
1739			t.Fatal("don't want END_STREAM, expecting data")
1740		}
1741		if !hf.HeadersEnded() {
1742			t.Fatal("want END_HEADERS flag")
1743		}
1744		goth := st.decodeHeader(hf.HeaderBlockFragment())
1745		wanth := [][2]string{
1746			{":status", "200"},
1747			{"content-type", "foo/bar"},
1748			{"content-length", strconv.Itoa(len(msg))},
1749		}
1750		if !reflect.DeepEqual(goth, wanth) {
1751			t.Errorf("Got headers %v; want %v", goth, wanth)
1752		}
1753		df := st.wantData()
1754		if !df.StreamEnded() {
1755			t.Error("expected DATA to have END_STREAM flag")
1756		}
1757		if got := string(df.Data()); got != msg {
1758			t.Errorf("got DATA %q; want %q", got, msg)
1759		}
1760	})
1761}
1762
1763func TestServer_Response_TransferEncoding_chunked(t *testing.T) {
1764	const msg = "hi"
1765	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1766		w.Header().Set("Transfer-Encoding", "chunked") // should be stripped
1767		io.WriteString(w, msg)
1768		return nil
1769	}, func(st *serverTester) {
1770		getSlash(st)
1771		hf := st.wantHeaders()
1772		goth := st.decodeHeader(hf.HeaderBlockFragment())
1773		wanth := [][2]string{
1774			{":status", "200"},
1775			{"content-type", "text/plain; charset=utf-8"},
1776			{"content-length", strconv.Itoa(len(msg))},
1777		}
1778		if !reflect.DeepEqual(goth, wanth) {
1779			t.Errorf("Got headers %v; want %v", goth, wanth)
1780		}
1781	})
1782}
1783
1784// Header accessed only after the initial write.
1785func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) {
1786	const msg = "<html>this is HTML."
1787	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1788		io.WriteString(w, msg)
1789		w.Header().Set("foo", "should be ignored")
1790		return nil
1791	}, func(st *serverTester) {
1792		getSlash(st)
1793		hf := st.wantHeaders()
1794		if hf.StreamEnded() {
1795			t.Fatal("unexpected END_STREAM")
1796		}
1797		if !hf.HeadersEnded() {
1798			t.Fatal("want END_HEADERS flag")
1799		}
1800		goth := st.decodeHeader(hf.HeaderBlockFragment())
1801		wanth := [][2]string{
1802			{":status", "200"},
1803			{"content-type", "text/html; charset=utf-8"},
1804			{"content-length", strconv.Itoa(len(msg))},
1805		}
1806		if !reflect.DeepEqual(goth, wanth) {
1807			t.Errorf("Got headers %v; want %v", goth, wanth)
1808		}
1809	})
1810}
1811
1812// Header accessed before the initial write and later mutated.
1813func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) {
1814	const msg = "<html>this is HTML."
1815	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1816		w.Header().Set("foo", "proper value")
1817		io.WriteString(w, msg)
1818		w.Header().Set("foo", "should be ignored")
1819		return nil
1820	}, func(st *serverTester) {
1821		getSlash(st)
1822		hf := st.wantHeaders()
1823		if hf.StreamEnded() {
1824			t.Fatal("unexpected END_STREAM")
1825		}
1826		if !hf.HeadersEnded() {
1827			t.Fatal("want END_HEADERS flag")
1828		}
1829		goth := st.decodeHeader(hf.HeaderBlockFragment())
1830		wanth := [][2]string{
1831			{":status", "200"},
1832			{"foo", "proper value"},
1833			{"content-type", "text/html; charset=utf-8"},
1834			{"content-length", strconv.Itoa(len(msg))},
1835		}
1836		if !reflect.DeepEqual(goth, wanth) {
1837			t.Errorf("Got headers %v; want %v", goth, wanth)
1838		}
1839	})
1840}
1841
1842func TestServer_Response_Data_SniffLenType(t *testing.T) {
1843	const msg = "<html>this is HTML."
1844	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1845		io.WriteString(w, msg)
1846		return nil
1847	}, func(st *serverTester) {
1848		getSlash(st)
1849		hf := st.wantHeaders()
1850		if hf.StreamEnded() {
1851			t.Fatal("don't want END_STREAM, expecting data")
1852		}
1853		if !hf.HeadersEnded() {
1854			t.Fatal("want END_HEADERS flag")
1855		}
1856		goth := st.decodeHeader(hf.HeaderBlockFragment())
1857		wanth := [][2]string{
1858			{":status", "200"},
1859			{"content-type", "text/html; charset=utf-8"},
1860			{"content-length", strconv.Itoa(len(msg))},
1861		}
1862		if !reflect.DeepEqual(goth, wanth) {
1863			t.Errorf("Got headers %v; want %v", goth, wanth)
1864		}
1865		df := st.wantData()
1866		if !df.StreamEnded() {
1867			t.Error("expected DATA to have END_STREAM flag")
1868		}
1869		if got := string(df.Data()); got != msg {
1870			t.Errorf("got DATA %q; want %q", got, msg)
1871		}
1872	})
1873}
1874
1875func TestServer_Response_Header_Flush_MidWrite(t *testing.T) {
1876	const msg = "<html>this is HTML"
1877	const msg2 = ", and this is the next chunk"
1878	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1879		io.WriteString(w, msg)
1880		w.(http.Flusher).Flush()
1881		io.WriteString(w, msg2)
1882		return nil
1883	}, func(st *serverTester) {
1884		getSlash(st)
1885		hf := st.wantHeaders()
1886		if hf.StreamEnded() {
1887			t.Fatal("unexpected END_STREAM flag")
1888		}
1889		if !hf.HeadersEnded() {
1890			t.Fatal("want END_HEADERS flag")
1891		}
1892		goth := st.decodeHeader(hf.HeaderBlockFragment())
1893		wanth := [][2]string{
1894			{":status", "200"},
1895			{"content-type", "text/html; charset=utf-8"}, // sniffed
1896			// and no content-length
1897		}
1898		if !reflect.DeepEqual(goth, wanth) {
1899			t.Errorf("Got headers %v; want %v", goth, wanth)
1900		}
1901		{
1902			df := st.wantData()
1903			if df.StreamEnded() {
1904				t.Error("unexpected END_STREAM flag")
1905			}
1906			if got := string(df.Data()); got != msg {
1907				t.Errorf("got DATA %q; want %q", got, msg)
1908			}
1909		}
1910		{
1911			df := st.wantData()
1912			if !df.StreamEnded() {
1913				t.Error("wanted END_STREAM flag on last data chunk")
1914			}
1915			if got := string(df.Data()); got != msg2 {
1916				t.Errorf("got DATA %q; want %q", got, msg2)
1917			}
1918		}
1919	})
1920}
1921
1922func TestServer_Response_LargeWrite(t *testing.T) {
1923	const size = 1 << 20
1924	const maxFrameSize = 16 << 10
1925	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1926		n, err := w.Write(bytes.Repeat([]byte("a"), size))
1927		if err != nil {
1928			return fmt.Errorf("Write error: %v", err)
1929		}
1930		if n != size {
1931			return fmt.Errorf("wrong size %d from Write", n)
1932		}
1933		return nil
1934	}, func(st *serverTester) {
1935		if err := st.fr.WriteSettings(
1936			Setting{SettingInitialWindowSize, 0},
1937			Setting{SettingMaxFrameSize, maxFrameSize},
1938		); err != nil {
1939			t.Fatal(err)
1940		}
1941		st.wantSettingsAck()
1942
1943		getSlash(st) // make the single request
1944
1945		// Give the handler quota to write:
1946		if err := st.fr.WriteWindowUpdate(1, size); err != nil {
1947			t.Fatal(err)
1948		}
1949		// Give the handler quota to write to connection-level
1950		// window as well
1951		if err := st.fr.WriteWindowUpdate(0, size); err != nil {
1952			t.Fatal(err)
1953		}
1954		hf := st.wantHeaders()
1955		if hf.StreamEnded() {
1956			t.Fatal("unexpected END_STREAM flag")
1957		}
1958		if !hf.HeadersEnded() {
1959			t.Fatal("want END_HEADERS flag")
1960		}
1961		goth := st.decodeHeader(hf.HeaderBlockFragment())
1962		wanth := [][2]string{
1963			{":status", "200"},
1964			{"content-type", "text/plain; charset=utf-8"}, // sniffed
1965			// and no content-length
1966		}
1967		if !reflect.DeepEqual(goth, wanth) {
1968			t.Errorf("Got headers %v; want %v", goth, wanth)
1969		}
1970		var bytes, frames int
1971		for {
1972			df := st.wantData()
1973			bytes += len(df.Data())
1974			frames++
1975			for _, b := range df.Data() {
1976				if b != 'a' {
1977					t.Fatal("non-'a' byte seen in DATA")
1978				}
1979			}
1980			if df.StreamEnded() {
1981				break
1982			}
1983		}
1984		if bytes != size {
1985			t.Errorf("Got %d bytes; want %d", bytes, size)
1986		}
1987		if want := int(size / maxFrameSize); frames < want || frames > want*2 {
1988			t.Errorf("Got %d frames; want %d", frames, size)
1989		}
1990	})
1991}
1992
1993// Test that the handler can't write more than the client allows
1994func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) {
1995	// Make these reads. Before each read, the client adds exactly enough
1996	// flow-control to satisfy the read. Numbers chosen arbitrarily.
1997	reads := []int{123, 1, 13, 127}
1998	size := 0
1999	for _, n := range reads {
2000		size += n
2001	}
2002
2003	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2004		w.(http.Flusher).Flush()
2005		n, err := w.Write(bytes.Repeat([]byte("a"), size))
2006		if err != nil {
2007			return fmt.Errorf("Write error: %v", err)
2008		}
2009		if n != size {
2010			return fmt.Errorf("wrong size %d from Write", n)
2011		}
2012		return nil
2013	}, func(st *serverTester) {
2014		// Set the window size to something explicit for this test.
2015		// It's also how much initial data we expect.
2016		if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, uint32(reads[0])}); err != nil {
2017			t.Fatal(err)
2018		}
2019		st.wantSettingsAck()
2020
2021		getSlash(st) // make the single request
2022
2023		hf := st.wantHeaders()
2024		if hf.StreamEnded() {
2025			t.Fatal("unexpected END_STREAM flag")
2026		}
2027		if !hf.HeadersEnded() {
2028			t.Fatal("want END_HEADERS flag")
2029		}
2030
2031		df := st.wantData()
2032		if got := len(df.Data()); got != reads[0] {
2033			t.Fatalf("Initial window size = %d but got DATA with %d bytes", reads[0], got)
2034		}
2035
2036		for _, quota := range reads[1:] {
2037			if err := st.fr.WriteWindowUpdate(1, uint32(quota)); err != nil {
2038				t.Fatal(err)
2039			}
2040			df := st.wantData()
2041			if int(quota) != len(df.Data()) {
2042				t.Fatalf("read %d bytes after giving %d quota", len(df.Data()), quota)
2043			}
2044		}
2045	})
2046}
2047
2048// Test that the handler blocked in a Write is unblocked if the server sends a RST_STREAM.
2049func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) {
2050	const size = 1 << 20
2051	const maxFrameSize = 16 << 10
2052	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2053		w.(http.Flusher).Flush()
2054		errc := make(chan error, 1)
2055		go func() {
2056			_, err := w.Write(bytes.Repeat([]byte("a"), size))
2057			errc <- err
2058		}()
2059		select {
2060		case err := <-errc:
2061			if err == nil {
2062				return errors.New("unexpected nil error from Write in handler")
2063			}
2064			return nil
2065		case <-time.After(2 * time.Second):
2066			return errors.New("timeout waiting for Write in handler")
2067		}
2068	}, func(st *serverTester) {
2069		if err := st.fr.WriteSettings(
2070			Setting{SettingInitialWindowSize, 0},
2071			Setting{SettingMaxFrameSize, maxFrameSize},
2072		); err != nil {
2073			t.Fatal(err)
2074		}
2075		st.wantSettingsAck()
2076
2077		getSlash(st) // make the single request
2078
2079		hf := st.wantHeaders()
2080		if hf.StreamEnded() {
2081			t.Fatal("unexpected END_STREAM flag")
2082		}
2083		if !hf.HeadersEnded() {
2084			t.Fatal("want END_HEADERS flag")
2085		}
2086
2087		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
2088			t.Fatal(err)
2089		}
2090	})
2091}
2092
2093func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) {
2094	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2095		w.(http.Flusher).Flush()
2096		// Nothing; send empty DATA
2097		return nil
2098	}, func(st *serverTester) {
2099		// Handler gets no data quota:
2100		if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, 0}); err != nil {
2101			t.Fatal(err)
2102		}
2103		st.wantSettingsAck()
2104
2105		getSlash(st) // make the single request
2106
2107		hf := st.wantHeaders()
2108		if hf.StreamEnded() {
2109			t.Fatal("unexpected END_STREAM flag")
2110		}
2111		if !hf.HeadersEnded() {
2112			t.Fatal("want END_HEADERS flag")
2113		}
2114
2115		df := st.wantData()
2116		if got := len(df.Data()); got != 0 {
2117			t.Fatalf("unexpected %d DATA bytes; want 0", got)
2118		}
2119		if !df.StreamEnded() {
2120			t.Fatal("DATA didn't have END_STREAM")
2121		}
2122	})
2123}
2124
2125func TestServer_Response_Automatic100Continue(t *testing.T) {
2126	const msg = "foo"
2127	const reply = "bar"
2128	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2129		if v := r.Header.Get("Expect"); v != "" {
2130			t.Errorf("Expect header = %q; want empty", v)
2131		}
2132		buf := make([]byte, len(msg))
2133		// This read should trigger the 100-continue being sent.
2134		if n, err := io.ReadFull(r.Body, buf); err != nil || n != len(msg) || string(buf) != msg {
2135			return fmt.Errorf("ReadFull = %q, %v; want %q, nil", buf[:n], err, msg)
2136		}
2137		_, err := io.WriteString(w, reply)
2138		return err
2139	}, func(st *serverTester) {
2140		st.writeHeaders(HeadersFrameParam{
2141			StreamID:      1, // clients send odd numbers
2142			BlockFragment: st.encodeHeader(":method", "POST", "expect", "100-continue"),
2143			EndStream:     false,
2144			EndHeaders:    true,
2145		})
2146		hf := st.wantHeaders()
2147		if hf.StreamEnded() {
2148			t.Fatal("unexpected END_STREAM flag")
2149		}
2150		if !hf.HeadersEnded() {
2151			t.Fatal("want END_HEADERS flag")
2152		}
2153		goth := st.decodeHeader(hf.HeaderBlockFragment())
2154		wanth := [][2]string{
2155			{":status", "100"},
2156		}
2157		if !reflect.DeepEqual(goth, wanth) {
2158			t.Fatalf("Got headers %v; want %v", goth, wanth)
2159		}
2160
2161		// Okay, they sent status 100, so we can send our
2162		// gigantic and/or sensitive "foo" payload now.
2163		st.writeData(1, true, []byte(msg))
2164
2165		st.wantWindowUpdate(0, uint32(len(msg)))
2166
2167		hf = st.wantHeaders()
2168		if hf.StreamEnded() {
2169			t.Fatal("expected data to follow")
2170		}
2171		if !hf.HeadersEnded() {
2172			t.Fatal("want END_HEADERS flag")
2173		}
2174		goth = st.decodeHeader(hf.HeaderBlockFragment())
2175		wanth = [][2]string{
2176			{":status", "200"},
2177			{"content-type", "text/plain; charset=utf-8"},
2178			{"content-length", strconv.Itoa(len(reply))},
2179		}
2180		if !reflect.DeepEqual(goth, wanth) {
2181			t.Errorf("Got headers %v; want %v", goth, wanth)
2182		}
2183
2184		df := st.wantData()
2185		if string(df.Data()) != reply {
2186			t.Errorf("Client read %q; want %q", df.Data(), reply)
2187		}
2188		if !df.StreamEnded() {
2189			t.Errorf("expect data stream end")
2190		}
2191	})
2192}
2193
2194func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
2195	errc := make(chan error, 1)
2196	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2197		p := []byte("some data.\n")
2198		for {
2199			_, err := w.Write(p)
2200			if err != nil {
2201				errc <- err
2202				return nil
2203			}
2204		}
2205	}, func(st *serverTester) {
2206		st.writeHeaders(HeadersFrameParam{
2207			StreamID:      1,
2208			BlockFragment: st.encodeHeader(),
2209			EndStream:     false,
2210			EndHeaders:    true,
2211		})
2212		hf := st.wantHeaders()
2213		if hf.StreamEnded() {
2214			t.Fatal("unexpected END_STREAM flag")
2215		}
2216		if !hf.HeadersEnded() {
2217			t.Fatal("want END_HEADERS flag")
2218		}
2219		// Close the connection and wait for the handler to (hopefully) notice.
2220		st.cc.Close()
2221		select {
2222		case <-errc:
2223		case <-time.After(5 * time.Second):
2224			t.Error("timeout")
2225		}
2226	})
2227}
2228
// TestServer_Rejects_Too_Many_Streams opens defaultMaxStreams concurrent
// streams whose handlers block. It then checks three things: one more stream
// is refused with RST_STREAM(PROTOCOL_ERROR) even when its header block is
// split across HEADERS + CONTINUATION; after one blocked handler returns, a
// new stream is accepted; and that stream's :path still decodes correctly.
func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
	const testPath = "/some/path"

	// inHandler reports each stream ID as its handler starts; each send on
	// leaveHandler releases one blocked handler.
	inHandler := make(chan uint32)
	leaveHandler := make(chan bool)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		id := w.(*responseWriter).rws.stream.id
		inHandler <- id
		// 1+(defaultMaxStreams+1)*2 is goodID below: it comes after the
		// defaultMaxStreams filler streams and rejectID. It is the stream
		// that reaches a handler with :path = testPath, which proves the
		// decoder context survived the rejected CONTINUATION.
		if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath {
			t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath)
		}
		<-leaveHandler
	})
	defer st.Close()
	st.greet()
	// streamID hands out client stream IDs in order: 1, 3, 5, ...
	nextStreamID := uint32(1)
	streamID := func() uint32 {
		defer func() { nextStreamID += 2 }()
		return nextStreamID
	}
	// sendReq sends a complete request (END_STREAM + END_HEADERS) on stream id.
	sendReq := func(id uint32, headers ...string) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      id,
			BlockFragment: st.encodeHeader(headers...),
			EndStream:     true,
			EndHeaders:    true,
		})
	}
	// Fill the server up to its concurrent-stream limit; every handler
	// stays blocked on leaveHandler.
	for i := 0; i < defaultMaxStreams; i++ {
		sendReq(streamID())
		<-inHandler
	}
	// Release the blocked handlers when the test returns. By then,
	// defaultMaxStreams-1 filler handlers plus goodID's handler are waiting.
	// NOTE(review): if goodID's handler is never received from inHandler
	// below (the timeout path), this loop can block on its final send.
	defer func() {
		for i := 0; i < defaultMaxStreams; i++ {
			leaveHandler <- true
		}
	}()

	// And this one should cross the limit:
	// (It's also sent as a CONTINUATION, to verify we still track the decoder context,
	// even if we're rejecting it)
	rejectID := streamID()
	headerBlock := st.encodeHeader(":path", testPath)
	frag1, frag2 := headerBlock[:3], headerBlock[3:]
	st.writeHeaders(HeadersFrameParam{
		StreamID:      rejectID,
		BlockFragment: frag1,
		EndStream:     true,
		EndHeaders:    false, // CONTINUATION coming
	})
	if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil {
		t.Fatal(err)
	}
	st.wantRSTStream(rejectID, ErrCodeProtocol)

	// But let a handler finish:
	leaveHandler <- true
	st.wantHeaders()

	// And now another stream should be able to start:
	goodID := streamID()
	sendReq(goodID, ":path", testPath)
	select {
	case got := <-inHandler:
		if got != goodID {
			t.Errorf("Got stream %d; want %d", got, goodID)
		}
	case <-time.After(3 * time.Second):
		t.Error("timeout waiting for handler")
	}
}
2300
2301// So many response headers that the server needs to use CONTINUATION frames:
2302func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) {
2303	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2304		h := w.Header()
2305		for i := 0; i < 5000; i++ {
2306			h.Set(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("x-value-%d", i))
2307		}
2308		return nil
2309	}, func(st *serverTester) {
2310		getSlash(st)
2311		hf := st.wantHeaders()
2312		if hf.HeadersEnded() {
2313			t.Fatal("got unwanted END_HEADERS flag")
2314		}
2315		n := 0
2316		for {
2317			n++
2318			cf := st.wantContinuation()
2319			if cf.HeadersEnded() {
2320				break
2321			}
2322		}
2323		if n < 5 {
2324			t.Errorf("Only got %d CONTINUATION frames; expected 5+ (currently 6)", n)
2325		}
2326	})
2327}
2328
// This previously crashed (reported by Mathieu Lonjaret as observed
// while using Camlistore) because we got a DATA frame from the client
// after the handler exited and our logic at the time was wrong,
// keeping a stream in the map in stateClosed, which tickled an
// invariant check later when we tried to remove that stream (via
// defer sc.closeAllStreamsOnConnClose) when the serverConn serve loop
// ended.
//
// The test lets the handler finish while the client still has DATA to send,
// sends that DATA anyway, and then closes the connection. It fails if the
// serve loop panicked while shutting down.
func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		// nothing
		return nil
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(),
			EndStream:     false, // DATA is coming
			EndHeaders:    true,
		})
		hf := st.wantHeaders()
		if !hf.HeadersEnded() || !hf.StreamEnded() {
			t.Fatalf("want END_HEADERS+END_STREAM, got %v", hf)
		}

		// Sent when a Handler closes while a client has
		// indicated it's still sending DATA:
		st.wantRSTStream(1, ErrCodeNo)

		// Now the handler has ended, so it's ended its
		// stream, but the client hasn't closed its side
		// (stateClosedLocal).  So send more data and verify
		// it doesn't crash with an internal invariant panic, like
		// it did before.
		st.writeData(1, true, []byte("foo"))

		// Get our flow control bytes back, since the handler didn't get them.
		st.wantWindowUpdate(0, uint32(len("foo")))

		// Sent after a peer sends data anyway (admittedly the
		// previous RST_STREAM might've still been in-flight),
		// but they'll have gotten the friendlier ErrCodeNo
		// (NO_ERROR) reset first.
		st.wantRSTStream(1, ErrCodeStreamClosed)

		// Set up a bunch of machinery to record the panic we saw
		// previously.
		var (
			panMu    sync.Mutex
			panicVal interface{}
		)

		// Record any panic value. NOTE(review): the hook returns true and is
		// not cleared in this test; presumably true means "handled" and
		// resetHooks clears it elsewhere — confirm at the call sites.
		testHookOnPanicMu.Lock()
		testHookOnPanic = func(sc *serverConn, pv interface{}) bool {
			panMu.Lock()
			panicVal = pv
			panMu.Unlock()
			return true
		}
		testHookOnPanicMu.Unlock()

		// Now force the serve loop to end, via closing the connection.
		st.cc.Close()
		select {
		case <-st.sc.doneServing:
			// Loop has exited.
			panMu.Lock()
			got := panicVal
			panMu.Unlock()
			if got != nil {
				t.Errorf("Got panic: %v", got)
			}
		case <-time.After(5 * time.Second):
			t.Error("timeout")
		}
	})
}
2404
2405func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) }
2406func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) }
2407
2408func testRejectTLS(t *testing.T, max uint16) {
2409	st := newServerTester(t, nil, func(c *tls.Config) {
2410		c.MaxVersion = max
2411	})
2412	defer st.Close()
2413	gf := st.wantGoAway()
2414	if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
2415		t.Errorf("Got error code %v; want %v", got, want)
2416	}
2417}
2418
2419func TestServer_Rejects_TLSBadCipher(t *testing.T) {
2420	st := newServerTester(t, nil, func(c *tls.Config) {
2421		// Only list bad ones:
2422		c.CipherSuites = []uint16{
2423			tls.TLS_RSA_WITH_RC4_128_SHA,
2424			tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
2425			tls.TLS_RSA_WITH_AES_128_CBC_SHA,
2426			tls.TLS_RSA_WITH_AES_256_CBC_SHA,
2427			tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
2428			tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
2429			tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
2430			tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
2431			tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
2432			tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
2433			tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
2434			cipher_TLS_RSA_WITH_AES_128_CBC_SHA256,
2435		}
2436	})
2437	defer st.Close()
2438	gf := st.wantGoAway()
2439	if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
2440		t.Errorf("Got error code %v; want %v", got, want)
2441	}
2442}
2443
2444func TestServer_Advertises_Common_Cipher(t *testing.T) {
2445	const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
2446	st := newServerTester(t, nil, func(c *tls.Config) {
2447		// Have the client only support the one required by the spec.
2448		c.CipherSuites = []uint16{requiredSuite}
2449	}, func(ts *httptest.Server) {
2450		var srv *http.Server = ts.Config
2451		// Have the server configured with no specific cipher suites.
2452		// This tests that Go's defaults include the required one.
2453		srv.TLSConfig = nil
2454	})
2455	defer st.Close()
2456	st.greet()
2457}
2458
2459func (st *serverTester) onHeaderField(f hpack.HeaderField) {
2460	if f.Name == "date" {
2461		return
2462	}
2463	st.decodedHeaders = append(st.decodedHeaders, [2]string{f.Name, f.Value})
2464}
2465
2466func (st *serverTester) decodeHeader(headerBlock []byte) (pairs [][2]string) {
2467	st.decodedHeaders = nil
2468	if _, err := st.hpackDec.Write(headerBlock); err != nil {
2469		st.t.Fatalf("hpack decoding error: %v", err)
2470	}
2471	if err := st.hpackDec.Close(); err != nil {
2472		st.t.Fatalf("hpack decoding error: %v", err)
2473	}
2474	return st.decodedHeaders
2475}
2476
2477// testServerResponse sets up an idle HTTP/2 connection. The client function should
2478// write a single request that must be handled by the handler. This waits up to 5s
2479// for client to return, then up to an additional 2s for the handler to return.
2480func testServerResponse(t testing.TB,
2481	handler func(http.ResponseWriter, *http.Request) error,
2482	client func(*serverTester),
2483) {
2484	errc := make(chan error, 1)
2485	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2486		if r.Body == nil {
2487			t.Fatal("nil Body")
2488		}
2489		errc <- handler(w, r)
2490	})
2491	defer st.Close()
2492
2493	donec := make(chan bool)
2494	go func() {
2495		defer close(donec)
2496		st.greet()
2497		client(st)
2498	}()
2499
2500	select {
2501	case <-donec:
2502	case <-time.After(5 * time.Second):
2503		t.Fatal("timeout in client")
2504	}
2505
2506	select {
2507	case err := <-errc:
2508		if err != nil {
2509			t.Fatalf("Error in handler: %v", err)
2510		}
2511	case <-time.After(2 * time.Second):
2512		t.Fatal("timeout in handler")
2513	}
2514}
2515
// readBodyHandler returns an http Handler func that reads exactly len(want)
// bytes from r.Body. It reports a test error through t if that read fails or
// if the bytes read differ from want. It does not write a response.
func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		got := make([]byte, len(want))
		if _, err := io.ReadFull(r.Body, got); err != nil {
			t.Error(err)
		} else if string(got) != want {
			t.Errorf("read %q; want %q", got, want)
		}
	}
}
2532
2533// TestServerWithCurl currently fails, hence the LenientCipherSuites test. See:
2534//   https://github.com/tatsuhiro-t/nghttp2/issues/140 &
2535//   http://sourceforge.net/p/curl/bugs/1472/
2536func TestServerWithCurl(t *testing.T)                     { testServerWithCurl(t, false) }
2537func TestServerWithCurl_LenientCipherSuites(t *testing.T) { testServerWithCurl(t, true) }
2538
2539func testServerWithCurl(t *testing.T, permitProhibitedCipherSuites bool) {
2540	if runtime.GOOS != "linux" {
2541		t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway")
2542	}
2543	if testing.Short() {
2544		t.Skip("skipping curl test in short mode")
2545	}
2546	requireCurl(t)
2547	var gotConn int32
2548	testHookOnConn = func() { atomic.StoreInt32(&gotConn, 1) }
2549
2550	const msg = "Hello from curl!\n"
2551	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2552		w.Header().Set("Foo", "Bar")
2553		w.Header().Set("Client-Proto", r.Proto)
2554		io.WriteString(w, msg)
2555	}))
2556	ConfigureServer(ts.Config, &Server{
2557		PermitProhibitedCipherSuites: permitProhibitedCipherSuites,
2558	})
2559	ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
2560	ts.StartTLS()
2561	defer ts.Close()
2562
2563	t.Logf("Running test server for curl to hit at: %s", ts.URL)
2564	container := curl(t, "--silent", "--http2", "--insecure", "-v", ts.URL)
2565	defer kill(container)
2566	resc := make(chan interface{}, 1)
2567	go func() {
2568		res, err := dockerLogs(container)
2569		if err != nil {
2570			resc <- err
2571		} else {
2572			resc <- res
2573		}
2574	}()
2575	select {
2576	case res := <-resc:
2577		if err, ok := res.(error); ok {
2578			t.Fatal(err)
2579		}
2580		body := string(res.([]byte))
2581		// Search for both "key: value" and "key:value", since curl changed their format
2582		// Our Dockerfile contains the latest version (no space), but just in case people
2583		// didn't rebuild, check both.
2584		if !strings.Contains(body, "foo: Bar") && !strings.Contains(body, "foo:Bar") {
2585			t.Errorf("didn't see foo: Bar header")
2586			t.Logf("Got: %s", body)
2587		}
2588		if !strings.Contains(body, "client-proto: HTTP/2") && !strings.Contains(body, "client-proto:HTTP/2") {
2589			t.Errorf("didn't see client-proto: HTTP/2 header")
2590			t.Logf("Got: %s", res)
2591		}
2592		if !strings.Contains(string(res.([]byte)), msg) {
2593			t.Errorf("didn't see %q content", msg)
2594			t.Logf("Got: %s", res)
2595		}
2596	case <-time.After(3 * time.Second):
2597		t.Errorf("timeout waiting for curl")
2598	}
2599
2600	if atomic.LoadInt32(&gotConn) == 0 {
2601		t.Error("never saw an http2 connection")
2602	}
2603}
2604
2605var doh2load = flag.Bool("h2load", false, "Run h2load test")
2606
2607func TestServerWithH2Load(t *testing.T) {
2608	if !*doh2load {
2609		t.Skip("Skipping without --h2load flag.")
2610	}
2611	if runtime.GOOS != "linux" {
2612		t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway")
2613	}
2614	requireH2load(t)
2615
2616	msg := strings.Repeat("Hello, h2load!\n", 5000)
2617	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2618		io.WriteString(w, msg)
2619		w.(http.Flusher).Flush()
2620		io.WriteString(w, msg)
2621	}))
2622	ts.StartTLS()
2623	defer ts.Close()
2624
2625	cmd := exec.Command("docker", "run", "--net=host", "--entrypoint=/usr/local/bin/h2load", "gohttp2/curl",
2626		"-n100000", "-c100", "-m100", ts.URL)
2627	cmd.Stdout = os.Stdout
2628	cmd.Stderr = os.Stderr
2629	if err := cmd.Run(); err != nil {
2630		t.Fatal(err)
2631	}
2632}
2633
// Issue 12843
//
// TestServerDoS_MaxHeaderListSize checks that the server advertises a non-zero
// SETTINGS_MAX_HEADER_LIST_SIZE. It then sends a request whose header block
// decodes to roughly 1MB of cookie headers and expects a 431 response. The
// request is cheap to send because each repeated cookie is the short hpack
// encoding of the ~4K cookie already sent once.
func TestServerDoS_MaxHeaderListSize(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
	defer st.Close()

	// shake hands, recording the server's advertised max frame size and
	// max header list size
	frameSize := defaultMaxReadFrameSize
	var advHeaderListSize *uint32
	st.greetAndCheckSettings(func(s Setting) error {
		switch s.ID {
		case SettingMaxFrameSize:
			// Clamp into [minMaxFrameSize, maxFrameSize]; frameSize sizes
			// the CONTINUATION chunks sent below.
			if s.Val < minMaxFrameSize {
				frameSize = minMaxFrameSize
			} else if s.Val > maxFrameSize {
				frameSize = maxFrameSize
			} else {
				frameSize = int(s.Val)
			}
		case SettingMaxHeaderListSize:
			// s is this call's own copy, so keeping its address is safe.
			advHeaderListSize = &s.Val
		}
		return nil
	})

	if advHeaderListSize == nil {
		t.Errorf("server didn't advertise a max header list size")
	} else if *advHeaderListSize == 0 {
		t.Errorf("server advertised a max header list size of 0")
	}

	st.encodeHeaderField(":method", "GET")
	st.encodeHeaderField(":path", "/")
	st.encodeHeaderField(":scheme", "https")
	cookie := strings.Repeat("*", 4058)
	st.encodeHeaderField("cookie", cookie)
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.headerBuf.Bytes(),
		EndStream:     true,
		EndHeaders:    false,
	})

	// Capture the short encoding of a duplicate ~4K cookie, now
	// that we've already sent it once.
	st.headerBuf.Reset()
	st.encodeHeaderField("cookie", cookie)

	// Now send 1MB of it.
	const size = 1 << 20
	b := bytes.Repeat(st.headerBuf.Bytes(), size/st.headerBuf.Len())
	// Split into CONTINUATION frames of at most frameSize bytes; only the
	// last one sets END_HEADERS.
	// NOTE(review): WriteContinuation's error is ignored; a write failure
	// would presumably surface as a wantHeaders failure below — confirm.
	for len(b) > 0 {
		chunk := b
		if len(chunk) > frameSize {
			chunk = chunk[:frameSize]
		}
		b = b[len(chunk):]
		st.fr.WriteContinuation(1, len(b) == 0, chunk)
	}

	h := st.wantHeaders()
	if !h.HeadersEnded() {
		t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
	}
	headers := st.decodeHeader(h.HeaderBlockFragment())
	want := [][2]string{
		{":status", "431"},
		{"content-type", "text/html; charset=utf-8"},
		{"content-length", "63"},
	}
	if !reflect.DeepEqual(headers, want) {
		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
	}
}
2707
// TestCompressionErrorOnWrite probes the boundary of the server's hpack
// maximum string length. A header value of exactly maxAllowed bytes still
// decodes; it is rejected with a 431 because the header list is too large,
// and the compression context stays valid. A value one byte longer must
// produce a connection-level GOAWAY(COMPRESSION_ERROR).
func TestCompressionErrorOnWrite(t *testing.T) {
	const maxStrLen = 8 << 10
	var serverConfig *http.Server
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// No response body.
	}, func(ts *httptest.Server) {
		serverConfig = ts.Config
		serverConfig.MaxHeaderBytes = maxStrLen
	})
	st.addLogFilter("connection error: COMPRESSION_ERROR")
	defer st.Close()
	st.greet()

	// The longest hpack string the server's framer accepts on this conn.
	maxAllowed := st.sc.framer.maxHeaderStringLen()

	// Crank this up now that the conn is established and its
	// hpack.Decoder's max string length has already been initialized
	// from the earlier low ~8K value. We want this higher so we don't
	// hit the max header list size; we only want to test hitting the
	// max string size.
	// NOTE(review): the first request below still expects a 431, so the
	// raised limit doesn't appear to apply to this already-established
	// conn — confirm the intent of this assignment.
	serverConfig.MaxHeaderBytes = 1 << 20

	// First a request with a header that's exactly the max allowed size
	// for the hpack compression. It's still too long for the header list
	// size, so we'll get the 431 error, but that keeps the compression
	// context still valid.
	hbf := st.encodeHeader("foo", strings.Repeat("a", maxAllowed))

	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: hbf,
		EndStream:     true,
		EndHeaders:    true,
	})
	h := st.wantHeaders()
	if !h.HeadersEnded() {
		t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
	}
	headers := st.decodeHeader(h.HeaderBlockFragment())
	want := [][2]string{
		{":status", "431"},
		{"content-type", "text/html; charset=utf-8"},
		{"content-length", "63"},
	}
	if !reflect.DeepEqual(headers, want) {
		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
	}
	df := st.wantData()
	if !strings.Contains(string(df.Data()), "HTTP Error 431") {
		t.Errorf("Unexpected data body: %q", df.Data())
	}
	if !df.StreamEnded() {
		t.Fatalf("expect data stream end")
	}

	// And now send one that's just one byte too big.
	hbf = st.encodeHeader("bar", strings.Repeat("b", maxAllowed+1))
	st.writeHeaders(HeadersFrameParam{
		StreamID:      3,
		BlockFragment: hbf,
		EndStream:     true,
		EndHeaders:    true,
	})
	ga := st.wantGoAway()
	if ga.ErrCode != ErrCodeCompression {
		t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
	}
}
2776
2777func TestCompressionErrorOnClose(t *testing.T) {
2778	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2779		// No response body.
2780	})
2781	st.addLogFilter("connection error: COMPRESSION_ERROR")
2782	defer st.Close()
2783	st.greet()
2784
2785	hbf := st.encodeHeader("foo", "bar")
2786	hbf = hbf[:len(hbf)-1] // truncate one byte from the end, so hpack.Decoder.Close fails.
2787	st.writeHeaders(HeadersFrameParam{
2788		StreamID:      1,
2789		BlockFragment: hbf,
2790		EndStream:     true,
2791		EndHeaders:    true,
2792	})
2793	ga := st.wantGoAway()
2794	if ga.ErrCode != ErrCodeCompression {
2795		t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
2796	}
2797}
2798
2799// test that a server handler can read trailers from a client
2800func TestServerReadsTrailers(t *testing.T) {
2801	const testBody = "some test body"
2802	writeReq := func(st *serverTester) {
2803		st.writeHeaders(HeadersFrameParam{
2804			StreamID:      1, // clients send odd numbers
2805			BlockFragment: st.encodeHeader("trailer", "Foo, Bar", "trailer", "Baz"),
2806			EndStream:     false,
2807			EndHeaders:    true,
2808		})
2809		st.writeData(1, false, []byte(testBody))
2810		st.writeHeaders(HeadersFrameParam{
2811			StreamID: 1, // clients send odd numbers
2812			BlockFragment: st.encodeHeaderRaw(
2813				"foo", "foov",
2814				"bar", "barv",
2815				"baz", "bazv",
2816				"surprise", "wasn't declared; shouldn't show up",
2817			),
2818			EndStream:  true,
2819			EndHeaders: true,
2820		})
2821	}
2822	checkReq := func(r *http.Request) {
2823		wantTrailer := http.Header{
2824			"Foo": nil,
2825			"Bar": nil,
2826			"Baz": nil,
2827		}
2828		if !reflect.DeepEqual(r.Trailer, wantTrailer) {
2829			t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer)
2830		}
2831		slurp, err := ioutil.ReadAll(r.Body)
2832		if string(slurp) != testBody {
2833			t.Errorf("read body %q; want %q", slurp, testBody)
2834		}
2835		if err != nil {
2836			t.Fatalf("Body slurp: %v", err)
2837		}
2838		wantTrailerAfter := http.Header{
2839			"Foo": {"foov"},
2840			"Bar": {"barv"},
2841			"Baz": {"bazv"},
2842		}
2843		if !reflect.DeepEqual(r.Trailer, wantTrailerAfter) {
2844			t.Errorf("final Trailer = %v; want %v", r.Trailer, wantTrailerAfter)
2845		}
2846	}
2847	testServerRequest(t, writeReq, checkReq)
2848}
2849
2850// test that a server handler can send trailers
2851func TestServerWritesTrailers_WithFlush(t *testing.T)    { testServerWritesTrailers(t, true) }
2852func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) }
2853
2854func testServerWritesTrailers(t *testing.T, withFlush bool) {
2855	// See https://httpwg.github.io/specs/rfc7540.html#rfc.section.8.1.3
2856	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2857		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
2858		w.Header().Add("Trailer", "Server-Trailer-C")
2859		w.Header().Add("Trailer", "Transfer-Encoding, Content-Length, Trailer") // filtered
2860
2861		// Regular headers:
2862		w.Header().Set("Foo", "Bar")
2863		w.Header().Set("Content-Length", "5") // len("Hello")
2864
2865		io.WriteString(w, "Hello")
2866		if withFlush {
2867			w.(http.Flusher).Flush()
2868		}
2869		w.Header().Set("Server-Trailer-A", "valuea")
2870		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
2871		// After a flush, random keys like Server-Surprise shouldn't show up:
2872		w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!")
2873		// But we do permit promoting keys to trailers after a
2874		// flush if they start with the magic
2875		// otherwise-invalid "Trailer:" prefix:
2876		w.Header().Set("Trailer:Post-Header-Trailer", "hi1")
2877		w.Header().Set("Trailer:post-header-trailer2", "hi2")
2878		w.Header().Set("Trailer:Range", "invalid")
2879		w.Header().Set("Trailer:Foo\x01Bogus", "invalid")
2880		w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 2616 14.40")
2881		w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 2616 14.40")
2882		w.Header().Set("Trailer", "should not be included; Forbidden by RFC 2616 14.40")
2883		return nil
2884	}, func(st *serverTester) {
2885		getSlash(st)
2886		hf := st.wantHeaders()
2887		if hf.StreamEnded() {
2888			t.Fatal("response HEADERS had END_STREAM")
2889		}
2890		if !hf.HeadersEnded() {
2891			t.Fatal("response HEADERS didn't have END_HEADERS")
2892		}
2893		goth := st.decodeHeader(hf.HeaderBlockFragment())
2894		wanth := [][2]string{
2895			{":status", "200"},
2896			{"foo", "Bar"},
2897			{"trailer", "Server-Trailer-A, Server-Trailer-B"},
2898			{"trailer", "Server-Trailer-C"},
2899			{"trailer", "Transfer-Encoding, Content-Length, Trailer"},
2900			{"content-type", "text/plain; charset=utf-8"},
2901			{"content-length", "5"},
2902		}
2903		if !reflect.DeepEqual(goth, wanth) {
2904			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
2905		}
2906		df := st.wantData()
2907		if string(df.Data()) != "Hello" {
2908			t.Fatalf("Client read %q; want Hello", df.Data())
2909		}
2910		if df.StreamEnded() {
2911			t.Fatalf("data frame had STREAM_ENDED")
2912		}
2913		tf := st.wantHeaders() // for the trailers
2914		if !tf.StreamEnded() {
2915			t.Fatalf("trailers HEADERS lacked END_STREAM")
2916		}
2917		if !tf.HeadersEnded() {
2918			t.Fatalf("trailers HEADERS lacked END_HEADERS")
2919		}
2920		wanth = [][2]string{
2921			{"post-header-trailer", "hi1"},
2922			{"post-header-trailer2", "hi2"},
2923			{"server-trailer-a", "valuea"},
2924			{"server-trailer-c", "valuec"},
2925		}
2926		goth = st.decodeHeader(tf.HeaderBlockFragment())
2927		if !reflect.DeepEqual(goth, wanth) {
2928			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
2929		}
2930	})
2931}
2932
2933// validate transmitted header field names & values
2934// golang.org/issue/14048
2935func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
2936	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2937		w.Header().Add("OK1", "x")
2938		w.Header().Add("Bad:Colon", "x") // colon (non-token byte) in key
2939		w.Header().Add("Bad1\x00", "x")  // null in key
2940		w.Header().Add("Bad2", "x\x00y") // null in value
2941		return nil
2942	}, func(st *serverTester) {
2943		getSlash(st)
2944		hf := st.wantHeaders()
2945		if !hf.StreamEnded() {
2946			t.Error("response HEADERS lacked END_STREAM")
2947		}
2948		if !hf.HeadersEnded() {
2949			t.Fatal("response HEADERS didn't have END_HEADERS")
2950		}
2951		goth := st.decodeHeader(hf.HeaderBlockFragment())
2952		wanth := [][2]string{
2953			{":status", "200"},
2954			{"ok1", "x"},
2955			{"content-length", "0"},
2956		}
2957		if !reflect.DeepEqual(goth, wanth) {
2958			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
2959		}
2960	})
2961}
2962
2963func BenchmarkServerGets(b *testing.B) {
2964	defer disableGoroutineTracking()()
2965	b.ReportAllocs()
2966
2967	const msg = "Hello, world"
2968	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
2969		io.WriteString(w, msg)
2970	})
2971	defer st.Close()
2972	st.greet()
2973
2974	// Give the server quota to reply. (plus it has the the 64KB)
2975	if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
2976		b.Fatal(err)
2977	}
2978
2979	for i := 0; i < b.N; i++ {
2980		id := 1 + uint32(i)*2
2981		st.writeHeaders(HeadersFrameParam{
2982			StreamID:      id,
2983			BlockFragment: st.encodeHeader(),
2984			EndStream:     true,
2985			EndHeaders:    true,
2986		})
2987		st.wantHeaders()
2988		df := st.wantData()
2989		if !df.StreamEnded() {
2990			b.Fatalf("DATA didn't have END_STREAM; got %v", df)
2991		}
2992	}
2993}
2994
2995func BenchmarkServerPosts(b *testing.B) {
2996	defer disableGoroutineTracking()()
2997	b.ReportAllocs()
2998
2999	const msg = "Hello, world"
3000	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3001		// Consume the (empty) body from th peer before replying, otherwise
3002		// the server will sometimes (depending on scheduling) send the peer a
3003		// a RST_STREAM with the CANCEL error code.
3004		if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil {
3005			b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
3006		}
3007		io.WriteString(w, msg)
3008	})
3009	defer st.Close()
3010	st.greet()
3011
3012	// Give the server quota to reply. (plus it has the the 64KB)
3013	if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3014		b.Fatal(err)
3015	}
3016
3017	for i := 0; i < b.N; i++ {
3018		id := 1 + uint32(i)*2
3019		st.writeHeaders(HeadersFrameParam{
3020			StreamID:      id,
3021			BlockFragment: st.encodeHeader(":method", "POST"),
3022			EndStream:     false,
3023			EndHeaders:    true,
3024		})
3025		st.writeData(id, true, nil)
3026		st.wantHeaders()
3027		df := st.wantData()
3028		if !df.StreamEnded() {
3029			b.Fatalf("DATA didn't have END_STREAM; got %v", df)
3030		}
3031	}
3032}
3033
// Send a stream of messages from server to client in separate data frames.
// Brings up performance issues seen in long streams.
// Created to show the problem in golang.org/issue/18502.
//
// This variant passes no extra options to newServerTester, so it serves as
// the baseline for BenchmarkServerToClientStreamReuseFrames.
func BenchmarkServerToClientStreamDefaultOptions(b *testing.B) {
	benchmarkServerToClientStream(b)
}
3040
// Justification for Change-Id: Iad93420ef6c3918f54249d867098f1dadfa324d8
// Expect to see memory/alloc reduction by opting in to Frame reuse with the Framer.
//
// Same workload as BenchmarkServerToClientStreamDefaultOptions, but with
// optFramerReuseFrames passed through to newServerTester.
func BenchmarkServerToClientStreamReuseFrames(b *testing.B) {
	benchmarkServerToClientStream(b, optFramerReuseFrames)
}
3046
3047func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
3048	defer disableGoroutineTracking()()
3049	b.ReportAllocs()
3050	const msgLen = 1
3051	// default window size
3052	const windowSize = 1<<16 - 1
3053
3054	// next message to send from the server and for the client to expect
3055	nextMsg := func(i int) []byte {
3056		msg := make([]byte, msgLen)
3057		msg[0] = byte(i)
3058		if len(msg) != msgLen {
3059			panic("invalid test setup msg length")
3060		}
3061		return msg
3062	}
3063
3064	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3065		// Consume the (empty) body from th peer before replying, otherwise
3066		// the server will sometimes (depending on scheduling) send the peer a
3067		// a RST_STREAM with the CANCEL error code.
3068		if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil {
3069			b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
3070		}
3071		for i := 0; i < b.N; i += 1 {
3072			w.Write(nextMsg(i))
3073			w.(http.Flusher).Flush()
3074		}
3075	}, newServerOpts...)
3076	defer st.Close()
3077	st.greet()
3078
3079	const id = uint32(1)
3080
3081	st.writeHeaders(HeadersFrameParam{
3082		StreamID:      id,
3083		BlockFragment: st.encodeHeader(":method", "POST"),
3084		EndStream:     false,
3085		EndHeaders:    true,
3086	})
3087
3088	st.writeData(id, true, nil)
3089	st.wantHeaders()
3090
3091	var pendingWindowUpdate = uint32(0)
3092
3093	for i := 0; i < b.N; i += 1 {
3094		expected := nextMsg(i)
3095		df := st.wantData()
3096		if bytes.Compare(expected, df.data) != 0 {
3097			b.Fatalf("Bad message received; want %v; got %v", expected, df.data)
3098		}
3099		// try to send infrequent but large window updates so they don't overwhelm the test
3100		pendingWindowUpdate += uint32(len(df.data))
3101		if pendingWindowUpdate >= windowSize/2 {
3102			if err := st.fr.WriteWindowUpdate(0, pendingWindowUpdate); err != nil {
3103				b.Fatal(err)
3104			}
3105			if err := st.fr.WriteWindowUpdate(id, pendingWindowUpdate); err != nil {
3106				b.Fatal(err)
3107			}
3108			pendingWindowUpdate = 0
3109		}
3110	}
3111	df := st.wantData()
3112	if !df.StreamEnded() {
3113		b.Fatalf("DATA didn't have END_STREAM; got %v", df)
3114	}
3115}
3116
3117// go-fuzz bug, originally reported at https://github.com/bradfitz/http2/issues/53
3118// Verify we don't hang.
3119func TestIssue53(t *testing.T) {
3120	const data = "PRI * HTTP/2.0\r\n\r\nSM" +
3121		"\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad"
3122	s := &http.Server{
3123		ErrorLog: log.New(io.MultiWriter(stderrv(), twriter{t: t}), "", log.LstdFlags),
3124		Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
3125			w.Write([]byte("hello"))
3126		}),
3127	}
3128	s2 := &Server{
3129		MaxReadFrameSize:             1 << 16,
3130		PermitProhibitedCipherSuites: true,
3131	}
3132	c := &issue53Conn{[]byte(data), false, false}
3133	s2.ServeConn(c, &ServeConnOpts{BaseConfig: s})
3134	if !c.closed {
3135		t.Fatal("connection is not closed")
3136	}
3137}
3138
// issue53Conn is an in-memory net.Conn used by TestIssue53. Read serves the
// bytes in data until they run out and then reports io.EOF. Write discards
// its input but records in written that it was called. Close only records
// closed. Both addresses are fixed and the deadline setters are no-ops.
type issue53Conn struct {
	data    []byte
	closed  bool
	written bool
}

// Read copies as much of the remaining data into b as fits and drops the
// copied prefix from data.
func (c *issue53Conn) Read(b []byte) (int, error) {
	if len(c.data) == 0 {
		return 0, io.EOF
	}
	copied := copy(b, c.data)
	c.data = c.data[copied:]
	return copied, nil
}

// Write marks the connection as written to and claims all of b was sent.
func (c *issue53Conn) Write(b []byte) (int, error) {
	c.written = true
	return len(b), nil
}

// Close marks the connection as closed; it never fails.
func (c *issue53Conn) Close() error {
	c.closed = true
	return nil
}

// LocalAddr returns a fixed loopback TCP address.
func (c *issue53Conn) LocalAddr() net.Addr {
	addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706}
	return addr
}

// RemoteAddr returns the same fixed loopback TCP address as LocalAddr.
func (c *issue53Conn) RemoteAddr() net.Addr {
	addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706}
	return addr
}

// SetDeadline ignores the deadline.
func (c *issue53Conn) SetDeadline(deadline time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline.
func (c *issue53Conn) SetReadDeadline(deadline time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline.
func (c *issue53Conn) SetWriteDeadline(deadline time.Time) error {
	return nil
}
3173
3174// golang.org/issue/12895
3175func TestConfigureServer(t *testing.T) {
3176	tests := []struct {
3177		name      string
3178		tlsConfig *tls.Config
3179		wantErr   string
3180	}{
3181		{
3182			name: "empty server",
3183		},
3184		{
3185			name: "just the required cipher suite",
3186			tlsConfig: &tls.Config{
3187				CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
3188			},
3189		},
3190		{
3191			name: "just the alternative required cipher suite",
3192			tlsConfig: &tls.Config{
3193				CipherSuites: []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
3194			},
3195		},
3196		{
3197			name: "missing required cipher suite",
3198			tlsConfig: &tls.Config{
3199				CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384},
3200			},
3201			wantErr: "is missing an HTTP/2-required AES_128_GCM_SHA256 cipher.",
3202		},
3203		{
3204			name: "required after bad",
3205			tlsConfig: &tls.Config{
3206				CipherSuites: []uint16{tls.TLS_RSA_WITH_RC4_128_SHA, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
3207			},
3208			wantErr: "contains an HTTP/2-approved cipher suite (0xc02f), but it comes after",
3209		},
3210		{
3211			name: "bad after required",
3212			tlsConfig: &tls.Config{
3213				CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_RC4_128_SHA},
3214			},
3215		},
3216	}
3217	for _, tt := range tests {
3218		srv := &http.Server{TLSConfig: tt.tlsConfig}
3219		err := ConfigureServer(srv, nil)
3220		if (err != nil) != (tt.wantErr != "") {
3221			if tt.wantErr != "" {
3222				t.Errorf("%s: success, but want error", tt.name)
3223			} else {
3224				t.Errorf("%s: unexpected error: %v", tt.name, err)
3225			}
3226		}
3227		if err != nil && tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr) {
3228			t.Errorf("%s: err = %v; want substring %q", tt.name, err, tt.wantErr)
3229		}
3230		if err == nil && !srv.TLSConfig.PreferServerCipherSuites {
3231			t.Errorf("%s: PreferServerCipherSuite is false; want true", tt.name)
3232		}
3233	}
3234}
3235
3236func TestServerRejectHeadWithBody(t *testing.T) {
3237	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3238		// No response body.
3239	})
3240	defer st.Close()
3241	st.greet()
3242	st.writeHeaders(HeadersFrameParam{
3243		StreamID:      1, // clients send odd numbers
3244		BlockFragment: st.encodeHeader(":method", "HEAD"),
3245		EndStream:     false, // what we're testing, a bogus HEAD request with body
3246		EndHeaders:    true,
3247	})
3248	st.wantRSTStream(1, ErrCodeProtocol)
3249}
3250
3251func TestServerNoAutoContentLengthOnHead(t *testing.T) {
3252	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3253		// No response body. (or smaller than one frame)
3254	})
3255	defer st.Close()
3256	st.greet()
3257	st.writeHeaders(HeadersFrameParam{
3258		StreamID:      1, // clients send odd numbers
3259		BlockFragment: st.encodeHeader(":method", "HEAD"),
3260		EndStream:     true,
3261		EndHeaders:    true,
3262	})
3263	h := st.wantHeaders()
3264	headers := st.decodeHeader(h.HeaderBlockFragment())
3265	want := [][2]string{
3266		{":status", "200"},
3267	}
3268	if !reflect.DeepEqual(headers, want) {
3269		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
3270	}
3271}
3272
3273// golang.org/issue/13495
3274func TestServerNoDuplicateContentType(t *testing.T) {
3275	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3276		w.Header()["Content-Type"] = []string{""}
3277		fmt.Fprintf(w, "<html><head></head><body>hi</body></html>")
3278	})
3279	defer st.Close()
3280	st.greet()
3281	st.writeHeaders(HeadersFrameParam{
3282		StreamID:      1,
3283		BlockFragment: st.encodeHeader(),
3284		EndStream:     true,
3285		EndHeaders:    true,
3286	})
3287	h := st.wantHeaders()
3288	headers := st.decodeHeader(h.HeaderBlockFragment())
3289	want := [][2]string{
3290		{":status", "200"},
3291		{"content-type", ""},
3292		{"content-length", "41"},
3293	}
3294	if !reflect.DeepEqual(headers, want) {
3295		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
3296	}
3297}
3298
3299func disableGoroutineTracking() (restore func()) {
3300	old := DebugGoroutines
3301	DebugGoroutines = false
3302	return func() { DebugGoroutines = old }
3303}
3304
3305func BenchmarkServer_GetRequest(b *testing.B) {
3306	defer disableGoroutineTracking()()
3307	b.ReportAllocs()
3308	const msg = "Hello, world."
3309	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3310		n, err := io.Copy(ioutil.Discard, r.Body)
3311		if err != nil || n > 0 {
3312			b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
3313		}
3314		io.WriteString(w, msg)
3315	})
3316	defer st.Close()
3317
3318	st.greet()
3319	// Give the server quota to reply. (plus it has the the 64KB)
3320	if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3321		b.Fatal(err)
3322	}
3323	hbf := st.encodeHeader(":method", "GET")
3324	for i := 0; i < b.N; i++ {
3325		streamID := uint32(1 + 2*i)
3326		st.writeHeaders(HeadersFrameParam{
3327			StreamID:      streamID,
3328			BlockFragment: hbf,
3329			EndStream:     true,
3330			EndHeaders:    true,
3331		})
3332		st.wantHeaders()
3333		st.wantData()
3334	}
3335}
3336
3337func BenchmarkServer_PostRequest(b *testing.B) {
3338	defer disableGoroutineTracking()()
3339	b.ReportAllocs()
3340	const msg = "Hello, world."
3341	st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3342		n, err := io.Copy(ioutil.Discard, r.Body)
3343		if err != nil || n > 0 {
3344			b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
3345		}
3346		io.WriteString(w, msg)
3347	})
3348	defer st.Close()
3349	st.greet()
3350	// Give the server quota to reply. (plus it has the the 64KB)
3351	if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3352		b.Fatal(err)
3353	}
3354	hbf := st.encodeHeader(":method", "POST")
3355	for i := 0; i < b.N; i++ {
3356		streamID := uint32(1 + 2*i)
3357		st.writeHeaders(HeadersFrameParam{
3358			StreamID:      streamID,
3359			BlockFragment: hbf,
3360			EndStream:     false,
3361			EndHeaders:    true,
3362		})
3363		st.writeData(streamID, true, nil)
3364		st.wantHeaders()
3365		st.wantData()
3366	}
3367}
3368
// connStateConn wraps an arbitrary net.Conn and reports a fixed TLS
// connection state through ConnectionState. Tests use it to hand the server
// a connection that is not a *tls.Conn.
type connStateConn struct {
	net.Conn
	cs tls.ConnectionState
}

// ConnectionState returns the TLS state stored in the wrapper.
func (w connStateConn) ConnectionState() tls.ConnectionState {
	state := w.cs
	return state
}
3375
3376// golang.org/issue/12737 -- handle any net.Conn, not just
3377// *tls.Conn.
3378func TestServerHandleCustomConn(t *testing.T) {
3379	var s Server
3380	c1, c2 := net.Pipe()
3381	clientDone := make(chan struct{})
3382	handlerDone := make(chan struct{})
3383	var req *http.Request
3384	go func() {
3385		defer close(clientDone)
3386		defer c2.Close()
3387		fr := NewFramer(c2, c2)
3388		io.WriteString(c2, ClientPreface)
3389		fr.WriteSettings()
3390		fr.WriteSettingsAck()
3391		f, err := fr.ReadFrame()
3392		if err != nil {
3393			t.Error(err)
3394			return
3395		}
3396		if sf, ok := f.(*SettingsFrame); !ok || sf.IsAck() {
3397			t.Errorf("Got %v; want non-ACK SettingsFrame", summarizeFrame(f))
3398			return
3399		}
3400		f, err = fr.ReadFrame()
3401		if err != nil {
3402			t.Error(err)
3403			return
3404		}
3405		if sf, ok := f.(*SettingsFrame); !ok || !sf.IsAck() {
3406			t.Errorf("Got %v; want ACK SettingsFrame", summarizeFrame(f))
3407			return
3408		}
3409		var henc hpackEncoder
3410		fr.WriteHeaders(HeadersFrameParam{
3411			StreamID:      1,
3412			BlockFragment: henc.encodeHeaderRaw(t, ":method", "GET", ":path", "/", ":scheme", "https", ":authority", "foo.com"),
3413			EndStream:     true,
3414			EndHeaders:    true,
3415		})
3416		go io.Copy(ioutil.Discard, c2)
3417		<-handlerDone
3418	}()
3419	const testString = "my custom ConnectionState"
3420	fakeConnState := tls.ConnectionState{
3421		ServerName:  testString,
3422		Version:     tls.VersionTLS12,
3423		CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
3424	}
3425	go s.ServeConn(connStateConn{c1, fakeConnState}, &ServeConnOpts{
3426		BaseConfig: &http.Server{
3427			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3428				defer close(handlerDone)
3429				req = r
3430			}),
3431		}})
3432	select {
3433	case <-clientDone:
3434	case <-time.After(5 * time.Second):
3435		t.Fatal("timeout waiting for handler")
3436	}
3437	if req.TLS == nil {
3438		t.Fatalf("Request.TLS is nil. Got: %#v", req)
3439	}
3440	if req.TLS.ServerName != testString {
3441		t.Fatalf("Request.TLS = %+v; want ServerName of %q", req.TLS, testString)
3442	}
3443}
3444
3445// golang.org/issue/14214
3446func TestServer_Rejects_ConnHeaders(t *testing.T) {
3447	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3448		t.Error("should not get to Handler")
3449	})
3450	defer st.Close()
3451	st.greet()
3452	st.bodylessReq1("connection", "foo")
3453	hf := st.wantHeaders()
3454	goth := st.decodeHeader(hf.HeaderBlockFragment())
3455	wanth := [][2]string{
3456		{":status", "400"},
3457		{"content-type", "text/plain; charset=utf-8"},
3458		{"x-content-type-options", "nosniff"},
3459		{"content-length", "51"},
3460	}
3461	if !reflect.DeepEqual(goth, wanth) {
3462		t.Errorf("Got headers %v; want %v", goth, wanth)
3463	}
3464}
3465
3466type hpackEncoder struct {
3467	enc *hpack.Encoder
3468	buf bytes.Buffer
3469}
3470
3471func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte {
3472	if len(headers)%2 == 1 {
3473		panic("odd number of kv args")
3474	}
3475	he.buf.Reset()
3476	if he.enc == nil {
3477		he.enc = hpack.NewEncoder(&he.buf)
3478	}
3479	for len(headers) > 0 {
3480		k, v := headers[0], headers[1]
3481		err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v})
3482		if err != nil {
3483			t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
3484		}
3485		headers = headers[2:]
3486	}
3487	return he.buf.Bytes()
3488}
3489
3490func TestCheckValidHTTP2Request(t *testing.T) {
3491	tests := []struct {
3492		h    http.Header
3493		want error
3494	}{
3495		{
3496			h:    http.Header{"Te": {"trailers"}},
3497			want: nil,
3498		},
3499		{
3500			h:    http.Header{"Te": {"trailers", "bogus"}},
3501			want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`),
3502		},
3503		{
3504			h:    http.Header{"Foo": {""}},
3505			want: nil,
3506		},
3507		{
3508			h:    http.Header{"Connection": {""}},
3509			want: errors.New(`request header "Connection" is not valid in HTTP/2`),
3510		},
3511		{
3512			h:    http.Header{"Proxy-Connection": {""}},
3513			want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`),
3514		},
3515		{
3516			h:    http.Header{"Keep-Alive": {""}},
3517			want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`),
3518		},
3519		{
3520			h:    http.Header{"Upgrade": {""}},
3521			want: errors.New(`request header "Upgrade" is not valid in HTTP/2`),
3522		},
3523	}
3524	for i, tt := range tests {
3525		got := checkValidHTTP2RequestHeaders(tt.h)
3526		if !reflect.DeepEqual(got, tt.want) {
3527			t.Errorf("%d. checkValidHTTP2Request = %v; want %v", i, got, tt.want)
3528		}
3529	}
3530}
3531
3532// golang.org/issue/14030
3533func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
3534	const msg = "Hello"
3535	const msg2 = "World"
3536
3537	doRead := make(chan bool, 1)
3538	defer close(doRead) // fallback cleanup
3539
3540	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3541		io.WriteString(w, msg)
3542		w.(http.Flusher).Flush()
3543
3544		// Do a read, which might force a 100-continue status to be sent.
3545		<-doRead
3546		r.Body.Read(make([]byte, 10))
3547
3548		io.WriteString(w, msg2)
3549
3550	}, optOnlyServer)
3551	defer st.Close()
3552
3553	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3554	defer tr.CloseIdleConnections()
3555
3556	req, _ := http.NewRequest("POST", st.ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
3557	req.Header.Set("Expect", "100-continue")
3558
3559	res, err := tr.RoundTrip(req)
3560	if err != nil {
3561		t.Fatal(err)
3562	}
3563	defer res.Body.Close()
3564
3565	buf := make([]byte, len(msg))
3566	if _, err := io.ReadFull(res.Body, buf); err != nil {
3567		t.Fatal(err)
3568	}
3569	if string(buf) != msg {
3570		t.Fatalf("msg = %q; want %q", buf, msg)
3571	}
3572
3573	doRead <- true
3574
3575	if _, err := io.ReadFull(res.Body, buf); err != nil {
3576		t.Fatal(err)
3577	}
3578	if string(buf) != msg2 {
3579		t.Fatalf("second msg = %q; want %q", buf, msg2)
3580	}
3581}
3582
// funcReader adapts a plain function to io.Reader; Read just forwards to the
// function. Tests use it to run custom code when a reader is reached, for
// example as the last part of an io.MultiReader.
type funcReader func([]byte) (n int, err error)

// Read calls fn with p and returns its results unchanged.
func (fn funcReader) Read(p []byte) (n int, err error) {
	n, err = fn(p)
	return n, err
}
3586
3587// golang.org/issue/16481 -- return flow control when streams close with unread data.
3588// (The Server version of the bug. See also TestUnreadFlowControlReturned_Transport)
3589func TestUnreadFlowControlReturned_Server(t *testing.T) {
3590	unblock := make(chan bool, 1)
3591	defer close(unblock)
3592
3593	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3594		// Don't read the 16KB request body. Wait until the client's
3595		// done sending it and then return. This should cause the Server
3596		// to then return those 16KB of flow control to the client.
3597		<-unblock
3598	}, optOnlyServer)
3599	defer st.Close()
3600
3601	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3602	defer tr.CloseIdleConnections()
3603
3604	// This previously hung on the 4th iteration.
3605	for i := 0; i < 6; i++ {
3606		body := io.MultiReader(
3607			io.LimitReader(neverEnding('A'), 16<<10),
3608			funcReader(func([]byte) (n int, err error) {
3609				unblock <- true
3610				return 0, io.EOF
3611			}),
3612		)
3613		req, _ := http.NewRequest("POST", st.ts.URL, body)
3614		res, err := tr.RoundTrip(req)
3615		if err != nil {
3616			t.Fatal(err)
3617		}
3618		res.Body.Close()
3619	}
3620
3621}
3622
3623func TestServerIdleTimeout(t *testing.T) {
3624	if testing.Short() {
3625		t.Skip("skipping in short mode")
3626	}
3627
3628	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3629	}, func(h2s *Server) {
3630		h2s.IdleTimeout = 500 * time.Millisecond
3631	})
3632	defer st.Close()
3633
3634	st.greet()
3635	ga := st.wantGoAway()
3636	if ga.ErrCode != ErrCodeNo {
3637		t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
3638	}
3639}
3640
// TestServerIdleTimeout_AfterRequest checks two properties of the server's
// IdleTimeout. First, it does not fire while a request is in progress: the
// handler here runs for twice the timeout, and its response headers must
// still arrive. Second, it is re-armed afterwards, eventually producing
// GOAWAY(NO_ERROR).
func TestServerIdleTimeout_AfterRequest(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	const timeout = 250 * time.Millisecond

	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// Keep the request in flight well past the idle timeout.
		time.Sleep(timeout * 2)
	}, func(h2s *Server) {
		h2s.IdleTimeout = timeout
	})
	defer st.Close()

	st.greet()

	// Send a request which takes twice the timeout. Verifies the
	// idle timeout doesn't fire while we're in a request:
	st.bodylessReq1()
	st.wantHeaders()

	// But the idle timeout should be rearmed after the request
	// is done:
	ga := st.wantGoAway()
	if ga.ErrCode != ErrCodeNo {
		t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
	}
}
3668
// grpc-go closes the Request.Body currently with a Read.
// Verify that it doesn't race.
// See https://github.com/grpc/grpc-go/pull/938
//
// Each iteration builds a requestBody whose pipe has already been closed
// with io.EOF. It then calls Read on it while another goroutine calls Close
// at roughly the same time. The test makes no assertions of its own; it
// only detects problems when run under the race detector (NOTE(review):
// presumably `go test -race`).
func TestRequestBodyReadCloseRace(t *testing.T) {
	for i := 0; i < 100; i++ {
		body := &requestBody{
			pipe: &pipe{
				b: new(bytes.Buffer),
			},
		}
		body.pipe.CloseWithError(io.EOF)

		done := make(chan bool, 1)
		buf := make([]byte, 10)
		go func() {
			// Short delay so Close tends to overlap the Read below.
			time.Sleep(1 * time.Millisecond)
			body.Close()
			done <- true
		}()
		body.Read(buf)
		// Wait for Close before starting the next iteration.
		<-done
	}
}
3692
// TestIssue20704Race issues 1000 GET requests in a row. The handler writes
// up to itemCount chunks of itemSize bytes. Each response body is closed
// without being read, which resets the stream while the handler may still
// be writing; the handler stops at its first Write error. The test only
// fails on request errors, so it is mainly a reproduction case for the race
// detector (NOTE(review): presumably for golang.org/issue/20704, judging by
// the name). It is skipped in -short mode unless the GO_BUILDER_NAME
// environment variable is set.
func TestIssue20704Race(t *testing.T) {
	if testing.Short() && os.Getenv("GO_BUILDER_NAME") == "" {
		t.Skip("skipping in short mode")
	}
	const (
		itemSize  = 1 << 10
		itemCount = 100
	)

	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		for i := 0; i < itemCount; i++ {
			_, err := w.Write(make([]byte, itemSize))
			if err != nil {
				return
			}
		}
	}, optOnlyServer)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()
	cl := &http.Client{Transport: tr}

	for i := 0; i < 1000; i++ {
		resp, err := cl.Get(st.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		// Force a RST stream to the server by closing without
		// reading the body:
		resp.Body.Close()
	}
}
3726