1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package http2
6
7import (
8	"bytes"
9	"fmt"
10	"io"
11	"reflect"
12	"strings"
13	"testing"
14	"unsafe"
15
16	"golang.org/x/net/http2/hpack"
17)
18
19func testFramer() (*Framer, *bytes.Buffer) {
20	buf := new(bytes.Buffer)
21	return NewFramer(buf, buf), buf
22}
23
24func TestFrameSizes(t *testing.T) {
25	// Catch people rearranging the FrameHeader fields.
26	if got, want := int(unsafe.Sizeof(FrameHeader{})), 12; got != want {
27		t.Errorf("FrameHeader size = %d; want %d", got, want)
28	}
29}
30
31func TestFrameTypeString(t *testing.T) {
32	tests := []struct {
33		ft   FrameType
34		want string
35	}{
36		{FrameData, "DATA"},
37		{FramePing, "PING"},
38		{FrameGoAway, "GOAWAY"},
39		{0xf, "UNKNOWN_FRAME_TYPE_15"},
40	}
41
42	for i, tt := range tests {
43		got := tt.ft.String()
44		if got != tt.want {
45			t.Errorf("%d. String(FrameType %d) = %q; want %q", i, int(tt.ft), got, tt.want)
46		}
47	}
48}
49
50func TestWriteRST(t *testing.T) {
51	fr, buf := testFramer()
52	var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
53	var errCode uint32 = 7<<24 + 6<<16 + 5<<8 + 4
54	fr.WriteRSTStream(streamID, ErrCode(errCode))
55	const wantEnc = "\x00\x00\x04\x03\x00\x01\x02\x03\x04\x07\x06\x05\x04"
56	if buf.String() != wantEnc {
57		t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
58	}
59	f, err := fr.ReadFrame()
60	if err != nil {
61		t.Fatal(err)
62	}
63	want := &RSTStreamFrame{
64		FrameHeader: FrameHeader{
65			valid:    true,
66			Type:     0x3,
67			Flags:    0x0,
68			Length:   0x4,
69			StreamID: 0x1020304,
70		},
71		ErrCode: 0x7060504,
72	}
73	if !reflect.DeepEqual(f, want) {
74		t.Errorf("parsed back %#v; want %#v", f, want)
75	}
76}
77
78func TestWriteData(t *testing.T) {
79	fr, buf := testFramer()
80	var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
81	data := []byte("ABC")
82	fr.WriteData(streamID, true, data)
83	const wantEnc = "\x00\x00\x03\x00\x01\x01\x02\x03\x04ABC"
84	if buf.String() != wantEnc {
85		t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
86	}
87	f, err := fr.ReadFrame()
88	if err != nil {
89		t.Fatal(err)
90	}
91	df, ok := f.(*DataFrame)
92	if !ok {
93		t.Fatalf("got %T; want *DataFrame", f)
94	}
95	if !bytes.Equal(df.Data(), data) {
96		t.Errorf("got %q; want %q", df.Data(), data)
97	}
98	if f.Header().Flags&1 == 0 {
99		t.Errorf("didn't see END_STREAM flag")
100	}
101}
102
103func TestWriteDataPadded(t *testing.T) {
104	tests := [...]struct {
105		streamID   uint32
106		endStream  bool
107		data       []byte
108		pad        []byte
109		wantHeader FrameHeader
110	}{
111		// Unpadded:
112		0: {
113			streamID:  1,
114			endStream: true,
115			data:      []byte("foo"),
116			pad:       nil,
117			wantHeader: FrameHeader{
118				Type:     FrameData,
119				Flags:    FlagDataEndStream,
120				Length:   3,
121				StreamID: 1,
122			},
123		},
124
125		// Padded bit set, but no padding:
126		1: {
127			streamID:  1,
128			endStream: true,
129			data:      []byte("foo"),
130			pad:       []byte{},
131			wantHeader: FrameHeader{
132				Type:     FrameData,
133				Flags:    FlagDataEndStream | FlagDataPadded,
134				Length:   4,
135				StreamID: 1,
136			},
137		},
138
139		// Padded bit set, with padding:
140		2: {
141			streamID:  1,
142			endStream: false,
143			data:      []byte("foo"),
144			pad:       []byte{0, 0, 0},
145			wantHeader: FrameHeader{
146				Type:     FrameData,
147				Flags:    FlagDataPadded,
148				Length:   7,
149				StreamID: 1,
150			},
151		},
152	}
153	for i, tt := range tests {
154		fr, _ := testFramer()
155		fr.WriteDataPadded(tt.streamID, tt.endStream, tt.data, tt.pad)
156		f, err := fr.ReadFrame()
157		if err != nil {
158			t.Errorf("%d. ReadFrame: %v", i, err)
159			continue
160		}
161		got := f.Header()
162		tt.wantHeader.valid = true
163		if got != tt.wantHeader {
164			t.Errorf("%d. read %+v; want %+v", i, got, tt.wantHeader)
165			continue
166		}
167		df := f.(*DataFrame)
168		if !bytes.Equal(df.Data(), tt.data) {
169			t.Errorf("%d. got %q; want %q", i, df.Data(), tt.data)
170		}
171	}
172}
173
// TestWriteHeaders writes HEADERS frames built from several HeadersFrameParam
// values (plain, END_STREAM/END_HEADERS flags, padding, priority, and
// priority with a zero stream dependency), checks the exact wire encoding of
// each, and verifies that reading it back yields the expected *HeadersFrame.
func TestWriteHeaders(t *testing.T) {
	tests := []struct {
		name      string
		p         HeadersFrameParam
		wantEnc   string
		wantFrame *HeadersFrame
	}{
		{
			"basic",
			HeadersFrameParam{
				StreamID:      42,
				BlockFragment: []byte("abc"),
				Priority:      PriorityParam{},
			},
			"\x00\x00\x03\x01\x00\x00\x00\x00*abc",
			&HeadersFrame{
				FrameHeader: FrameHeader{
					valid:    true,
					StreamID: 42,
					Type:     FrameHeaders,
					Length:   uint32(len("abc")),
				},
				Priority:      PriorityParam{},
				headerFragBuf: []byte("abc"),
			},
		},
		{
			"basic + end flags",
			HeadersFrameParam{
				StreamID:      42,
				BlockFragment: []byte("abc"),
				EndStream:     true,
				EndHeaders:    true,
				Priority:      PriorityParam{},
			},
			"\x00\x00\x03\x01\x05\x00\x00\x00*abc",
			&HeadersFrame{
				FrameHeader: FrameHeader{
					valid:    true,
					StreamID: 42,
					Type:     FrameHeaders,
					Flags:    FlagHeadersEndStream | FlagHeadersEndHeaders,
					Length:   uint32(len("abc")),
				},
				Priority:      PriorityParam{},
				headerFragBuf: []byte("abc"),
			},
		},
		{
			"with padding",
			HeadersFrameParam{
				StreamID:      42,
				BlockFragment: []byte("abc"),
				EndStream:     true,
				EndHeaders:    true,
				PadLength:     5,
				Priority:      PriorityParam{},
			},
			"\x00\x00\t\x01\r\x00\x00\x00*\x05abc\x00\x00\x00\x00\x00",
			&HeadersFrame{
				FrameHeader: FrameHeader{
					valid:    true,
					StreamID: 42,
					Type:     FrameHeaders,
					Flags:    FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded,
					Length:   uint32(1 + len("abc") + 5), // pad length + contents + padding
				},
				Priority:      PriorityParam{},
				headerFragBuf: []byte("abc"),
			},
		},
		{
			"with priority",
			HeadersFrameParam{
				StreamID:      42,
				BlockFragment: []byte("abc"),
				EndStream:     true,
				EndHeaders:    true,
				PadLength:     2,
				Priority: PriorityParam{
					StreamDep: 15,
					Exclusive: true,
					Weight:    127,
				},
			},
			"\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x0f\u007fabc\x00\x00",
			&HeadersFrame{
				FrameHeader: FrameHeader{
					valid:    true,
					StreamID: 42,
					Type:     FrameHeaders,
					Flags:    FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority,
					Length:   uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding
				},
				Priority: PriorityParam{
					StreamDep: 15,
					Exclusive: true,
					Weight:    127,
				},
				headerFragBuf: []byte("abc"),
			},
		},
		{
			"with priority stream dep zero", // golang.org/issue/15444
			HeadersFrameParam{
				StreamID:      42,
				BlockFragment: []byte("abc"),
				EndStream:     true,
				EndHeaders:    true,
				PadLength:     2,
				Priority: PriorityParam{
					StreamDep: 0,
					Exclusive: true,
					Weight:    127,
				},
			},
			"\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x00\u007fabc\x00\x00",
			&HeadersFrame{
				FrameHeader: FrameHeader{
					valid:    true,
					StreamID: 42,
					Type:     FrameHeaders,
					Flags:    FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority,
					Length:   uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding
				},
				Priority: PriorityParam{
					StreamDep: 0,
					Exclusive: true,
					Weight:    127,
				},
				headerFragBuf: []byte("abc"),
			},
		},
	}
	for _, tt := range tests {
		// A fresh framer per case, so buf holds only this case's frame.
		fr, buf := testFramer()
		if err := fr.WriteHeaders(tt.p); err != nil {
			t.Errorf("test %q: %v", tt.name, err)
			continue
		}
		// Compare the exact bytes first, then the structurally decoded frame.
		if buf.String() != tt.wantEnc {
			t.Errorf("test %q: encoded %q; want %q", tt.name, buf.Bytes(), tt.wantEnc)
		}
		f, err := fr.ReadFrame()
		if err != nil {
			t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
			continue
		}
		if !reflect.DeepEqual(f, tt.wantFrame) {
			t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
		}
	}
}
327
328func TestWriteInvalidStreamDep(t *testing.T) {
329	fr, _ := testFramer()
330	err := fr.WriteHeaders(HeadersFrameParam{
331		StreamID: 42,
332		Priority: PriorityParam{
333			StreamDep: 1 << 31,
334		},
335	})
336	if err != errDepStreamID {
337		t.Errorf("header error = %v; want %q", err, errDepStreamID)
338	}
339
340	err = fr.WritePriority(2, PriorityParam{StreamDep: 1 << 31})
341	if err != errDepStreamID {
342		t.Errorf("priority error = %v; want %q", err, errDepStreamID)
343	}
344}
345
346func TestWriteContinuation(t *testing.T) {
347	const streamID = 42
348	tests := []struct {
349		name string
350		end  bool
351		frag []byte
352
353		wantFrame *ContinuationFrame
354	}{
355		{
356			"not end",
357			false,
358			[]byte("abc"),
359			&ContinuationFrame{
360				FrameHeader: FrameHeader{
361					valid:    true,
362					StreamID: streamID,
363					Type:     FrameContinuation,
364					Length:   uint32(len("abc")),
365				},
366				headerFragBuf: []byte("abc"),
367			},
368		},
369		{
370			"end",
371			true,
372			[]byte("def"),
373			&ContinuationFrame{
374				FrameHeader: FrameHeader{
375					valid:    true,
376					StreamID: streamID,
377					Type:     FrameContinuation,
378					Flags:    FlagContinuationEndHeaders,
379					Length:   uint32(len("def")),
380				},
381				headerFragBuf: []byte("def"),
382			},
383		},
384	}
385	for _, tt := range tests {
386		fr, _ := testFramer()
387		if err := fr.WriteContinuation(streamID, tt.end, tt.frag); err != nil {
388			t.Errorf("test %q: %v", tt.name, err)
389			continue
390		}
391		fr.AllowIllegalReads = true
392		f, err := fr.ReadFrame()
393		if err != nil {
394			t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
395			continue
396		}
397		if !reflect.DeepEqual(f, tt.wantFrame) {
398			t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
399		}
400	}
401}
402
403func TestWritePriority(t *testing.T) {
404	const streamID = 42
405	tests := []struct {
406		name      string
407		priority  PriorityParam
408		wantFrame *PriorityFrame
409	}{
410		{
411			"not exclusive",
412			PriorityParam{
413				StreamDep: 2,
414				Exclusive: false,
415				Weight:    127,
416			},
417			&PriorityFrame{
418				FrameHeader{
419					valid:    true,
420					StreamID: streamID,
421					Type:     FramePriority,
422					Length:   5,
423				},
424				PriorityParam{
425					StreamDep: 2,
426					Exclusive: false,
427					Weight:    127,
428				},
429			},
430		},
431
432		{
433			"exclusive",
434			PriorityParam{
435				StreamDep: 3,
436				Exclusive: true,
437				Weight:    77,
438			},
439			&PriorityFrame{
440				FrameHeader{
441					valid:    true,
442					StreamID: streamID,
443					Type:     FramePriority,
444					Length:   5,
445				},
446				PriorityParam{
447					StreamDep: 3,
448					Exclusive: true,
449					Weight:    77,
450				},
451			},
452		},
453	}
454	for _, tt := range tests {
455		fr, _ := testFramer()
456		if err := fr.WritePriority(streamID, tt.priority); err != nil {
457			t.Errorf("test %q: %v", tt.name, err)
458			continue
459		}
460		f, err := fr.ReadFrame()
461		if err != nil {
462			t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
463			continue
464		}
465		if !reflect.DeepEqual(f, tt.wantFrame) {
466			t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
467		}
468	}
469}
470
471func TestWriteSettings(t *testing.T) {
472	fr, buf := testFramer()
473	settings := []Setting{{1, 2}, {3, 4}}
474	fr.WriteSettings(settings...)
475	const wantEnc = "\x00\x00\f\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x03\x00\x00\x00\x04"
476	if buf.String() != wantEnc {
477		t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
478	}
479	f, err := fr.ReadFrame()
480	if err != nil {
481		t.Fatal(err)
482	}
483	sf, ok := f.(*SettingsFrame)
484	if !ok {
485		t.Fatalf("Got a %T; want a SettingsFrame", f)
486	}
487	var got []Setting
488	sf.ForeachSetting(func(s Setting) error {
489		got = append(got, s)
490		valBack, ok := sf.Value(s.ID)
491		if !ok || valBack != s.Val {
492			t.Errorf("Value(%d) = %v, %v; want %v, true", s.ID, valBack, ok, s.Val)
493		}
494		return nil
495	})
496	if !reflect.DeepEqual(settings, got) {
497		t.Errorf("Read settings %+v != written settings %+v", got, settings)
498	}
499}
500
501func TestWriteSettingsAck(t *testing.T) {
502	fr, buf := testFramer()
503	fr.WriteSettingsAck()
504	const wantEnc = "\x00\x00\x00\x04\x01\x00\x00\x00\x00"
505	if buf.String() != wantEnc {
506		t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
507	}
508}
509
510func TestWriteWindowUpdate(t *testing.T) {
511	fr, buf := testFramer()
512	const streamID = 1<<24 + 2<<16 + 3<<8 + 4
513	const incr = 7<<24 + 6<<16 + 5<<8 + 4
514	if err := fr.WriteWindowUpdate(streamID, incr); err != nil {
515		t.Fatal(err)
516	}
517	const wantEnc = "\x00\x00\x04\x08\x00\x01\x02\x03\x04\x07\x06\x05\x04"
518	if buf.String() != wantEnc {
519		t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
520	}
521	f, err := fr.ReadFrame()
522	if err != nil {
523		t.Fatal(err)
524	}
525	want := &WindowUpdateFrame{
526		FrameHeader: FrameHeader{
527			valid:    true,
528			Type:     0x8,
529			Flags:    0x0,
530			Length:   0x4,
531			StreamID: 0x1020304,
532		},
533		Increment: 0x7060504,
534	}
535	if !reflect.DeepEqual(f, want) {
536		t.Errorf("parsed back %#v; want %#v", f, want)
537	}
538}
539
540func TestWritePing(t *testing.T)    { testWritePing(t, false) }
541func TestWritePingAck(t *testing.T) { testWritePing(t, true) }
542
543func testWritePing(t *testing.T, ack bool) {
544	fr, buf := testFramer()
545	if err := fr.WritePing(ack, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
546		t.Fatal(err)
547	}
548	var wantFlags Flags
549	if ack {
550		wantFlags = FlagPingAck
551	}
552	var wantEnc = "\x00\x00\x08\x06" + string(wantFlags) + "\x00\x00\x00\x00" + "\x01\x02\x03\x04\x05\x06\x07\x08"
553	if buf.String() != wantEnc {
554		t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
555	}
556
557	f, err := fr.ReadFrame()
558	if err != nil {
559		t.Fatal(err)
560	}
561	want := &PingFrame{
562		FrameHeader: FrameHeader{
563			valid:    true,
564			Type:     0x6,
565			Flags:    wantFlags,
566			Length:   0x8,
567			StreamID: 0,
568		},
569		Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8},
570	}
571	if !reflect.DeepEqual(f, want) {
572		t.Errorf("parsed back %#v; want %#v", f, want)
573	}
574}
575
576func TestReadFrameHeader(t *testing.T) {
577	tests := []struct {
578		in   string
579		want FrameHeader
580	}{
581		{in: "\x00\x00\x00" + "\x00" + "\x00" + "\x00\x00\x00\x00", want: FrameHeader{}},
582		{in: "\x01\x02\x03" + "\x04" + "\x05" + "\x06\x07\x08\x09", want: FrameHeader{
583			Length: 66051, Type: 4, Flags: 5, StreamID: 101124105,
584		}},
585		// Ignore high bit:
586		{in: "\xff\xff\xff" + "\xff" + "\xff" + "\xff\xff\xff\xff", want: FrameHeader{
587			Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}},
588		{in: "\xff\xff\xff" + "\xff" + "\xff" + "\x7f\xff\xff\xff", want: FrameHeader{
589			Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}},
590	}
591	for i, tt := range tests {
592		got, err := readFrameHeader(make([]byte, 9), strings.NewReader(tt.in))
593		if err != nil {
594			t.Errorf("%d. readFrameHeader(%q) = %v", i, tt.in, err)
595			continue
596		}
597		tt.want.valid = true
598		if got != tt.want {
599			t.Errorf("%d. readFrameHeader(%q) = %+v; want %+v", i, tt.in, got, tt.want)
600		}
601	}
602}
603
604func TestReadWriteFrameHeader(t *testing.T) {
605	tests := []struct {
606		len      uint32
607		typ      FrameType
608		flags    Flags
609		streamID uint32
610	}{
611		{len: 0, typ: 255, flags: 1, streamID: 0},
612		{len: 0, typ: 255, flags: 1, streamID: 1},
613		{len: 0, typ: 255, flags: 1, streamID: 255},
614		{len: 0, typ: 255, flags: 1, streamID: 256},
615		{len: 0, typ: 255, flags: 1, streamID: 65535},
616		{len: 0, typ: 255, flags: 1, streamID: 65536},
617
618		{len: 0, typ: 1, flags: 255, streamID: 1},
619		{len: 255, typ: 1, flags: 255, streamID: 1},
620		{len: 256, typ: 1, flags: 255, streamID: 1},
621		{len: 65535, typ: 1, flags: 255, streamID: 1},
622		{len: 65536, typ: 1, flags: 255, streamID: 1},
623		{len: 16777215, typ: 1, flags: 255, streamID: 1},
624	}
625	for _, tt := range tests {
626		fr, buf := testFramer()
627		fr.startWrite(tt.typ, tt.flags, tt.streamID)
628		fr.writeBytes(make([]byte, tt.len))
629		fr.endWrite()
630		fh, err := ReadFrameHeader(buf)
631		if err != nil {
632			t.Errorf("ReadFrameHeader(%+v) = %v", tt, err)
633			continue
634		}
635		if fh.Type != tt.typ || fh.Flags != tt.flags || fh.Length != tt.len || fh.StreamID != tt.streamID {
636			t.Errorf("ReadFrameHeader(%+v) = %+v; mismatch", tt, fh)
637		}
638	}
639
640}
641
642func TestWriteTooLargeFrame(t *testing.T) {
643	fr, _ := testFramer()
644	fr.startWrite(0, 1, 1)
645	fr.writeBytes(make([]byte, 1<<24))
646	err := fr.endWrite()
647	if err != ErrFrameTooLarge {
648		t.Errorf("endWrite = %v; want errFrameTooLarge", err)
649	}
650}
651
652func TestWriteGoAway(t *testing.T) {
653	const debug = "foo"
654	fr, buf := testFramer()
655	if err := fr.WriteGoAway(0x01020304, 0x05060708, []byte(debug)); err != nil {
656		t.Fatal(err)
657	}
658	const wantEnc = "\x00\x00\v\a\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08" + debug
659	if buf.String() != wantEnc {
660		t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
661	}
662	f, err := fr.ReadFrame()
663	if err != nil {
664		t.Fatal(err)
665	}
666	want := &GoAwayFrame{
667		FrameHeader: FrameHeader{
668			valid:    true,
669			Type:     0x7,
670			Flags:    0,
671			Length:   uint32(4 + 4 + len(debug)),
672			StreamID: 0,
673		},
674		LastStreamID: 0x01020304,
675		ErrCode:      0x05060708,
676		debugData:    []byte(debug),
677	}
678	if !reflect.DeepEqual(f, want) {
679		t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
680	}
681	if got := string(f.(*GoAwayFrame).DebugData()); got != debug {
682		t.Errorf("debug data = %q; want %q", got, debug)
683	}
684}
685
686func TestWritePushPromise(t *testing.T) {
687	pp := PushPromiseParam{
688		StreamID:      42,
689		PromiseID:     42,
690		BlockFragment: []byte("abc"),
691	}
692	fr, buf := testFramer()
693	if err := fr.WritePushPromise(pp); err != nil {
694		t.Fatal(err)
695	}
696	const wantEnc = "\x00\x00\x07\x05\x00\x00\x00\x00*\x00\x00\x00*abc"
697	if buf.String() != wantEnc {
698		t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
699	}
700	f, err := fr.ReadFrame()
701	if err != nil {
702		t.Fatal(err)
703	}
704	_, ok := f.(*PushPromiseFrame)
705	if !ok {
706		t.Fatalf("got %T; want *PushPromiseFrame", f)
707	}
708	want := &PushPromiseFrame{
709		FrameHeader: FrameHeader{
710			valid:    true,
711			Type:     0x5,
712			Flags:    0x0,
713			Length:   0x7,
714			StreamID: 42,
715		},
716		PromiseID:     42,
717		headerFragBuf: []byte("abc"),
718	}
719	if !reflect.DeepEqual(f, want) {
720		t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
721	}
722}
723
724// test checkFrameOrder and that HEADERS and CONTINUATION frames can't be intermingled.
725func TestReadFrameOrder(t *testing.T) {
726	head := func(f *Framer, id uint32, end bool) {
727		f.WriteHeaders(HeadersFrameParam{
728			StreamID:      id,
729			BlockFragment: []byte("foo"), // unused, but non-empty
730			EndHeaders:    end,
731		})
732	}
733	cont := func(f *Framer, id uint32, end bool) {
734		f.WriteContinuation(id, end, []byte("foo"))
735	}
736
737	tests := [...]struct {
738		name    string
739		w       func(*Framer)
740		atLeast int
741		wantErr string
742	}{
743		0: {
744			w: func(f *Framer) {
745				head(f, 1, true)
746			},
747		},
748		1: {
749			w: func(f *Framer) {
750				head(f, 1, true)
751				head(f, 2, true)
752			},
753		},
754		2: {
755			wantErr: "got HEADERS for stream 2; expected CONTINUATION following HEADERS for stream 1",
756			w: func(f *Framer) {
757				head(f, 1, false)
758				head(f, 2, true)
759			},
760		},
761		3: {
762			wantErr: "got DATA for stream 1; expected CONTINUATION following HEADERS for stream 1",
763			w: func(f *Framer) {
764				head(f, 1, false)
765			},
766		},
767		4: {
768			w: func(f *Framer) {
769				head(f, 1, false)
770				cont(f, 1, true)
771				head(f, 2, true)
772			},
773		},
774		5: {
775			wantErr: "got CONTINUATION for stream 2; expected stream 1",
776			w: func(f *Framer) {
777				head(f, 1, false)
778				cont(f, 2, true)
779				head(f, 2, true)
780			},
781		},
782		6: {
783			wantErr: "unexpected CONTINUATION for stream 1",
784			w: func(f *Framer) {
785				cont(f, 1, true)
786			},
787		},
788		7: {
789			wantErr: "unexpected CONTINUATION for stream 1",
790			w: func(f *Framer) {
791				cont(f, 1, false)
792			},
793		},
794		8: {
795			wantErr: "HEADERS frame with stream ID 0",
796			w: func(f *Framer) {
797				head(f, 0, true)
798			},
799		},
800		9: {
801			wantErr: "CONTINUATION frame with stream ID 0",
802			w: func(f *Framer) {
803				cont(f, 0, true)
804			},
805		},
806		10: {
807			wantErr: "unexpected CONTINUATION for stream 1",
808			atLeast: 5,
809			w: func(f *Framer) {
810				head(f, 1, false)
811				cont(f, 1, false)
812				cont(f, 1, false)
813				cont(f, 1, false)
814				cont(f, 1, true)
815				cont(f, 1, false)
816			},
817		},
818	}
819	for i, tt := range tests {
820		buf := new(bytes.Buffer)
821		f := NewFramer(buf, buf)
822		f.AllowIllegalWrites = true
823		tt.w(f)
824		f.WriteData(1, true, nil) // to test transition away from last step
825
826		var err error
827		n := 0
828		var log bytes.Buffer
829		for {
830			var got Frame
831			got, err = f.ReadFrame()
832			fmt.Fprintf(&log, "  read %v, %v\n", got, err)
833			if err != nil {
834				break
835			}
836			n++
837		}
838		if err == io.EOF {
839			err = nil
840		}
841		ok := tt.wantErr == ""
842		if ok && err != nil {
843			t.Errorf("%d. after %d good frames, ReadFrame = %v; want success\n%s", i, n, err, log.Bytes())
844			continue
845		}
846		if !ok && err != ConnectionError(ErrCodeProtocol) {
847			t.Errorf("%d. after %d good frames, ReadFrame = %v; want ConnectionError(ErrCodeProtocol)\n%s", i, n, err, log.Bytes())
848			continue
849		}
850		if !((f.errDetail == nil && tt.wantErr == "") || (fmt.Sprint(f.errDetail) == tt.wantErr)) {
851			t.Errorf("%d. framer eror = %q; want %q\n%s", i, f.errDetail, tt.wantErr, log.Bytes())
852		}
853		if n < tt.atLeast {
854			t.Errorf("%d. framer only read %d frames; want at least %d\n%s", i, n, tt.atLeast, log.Bytes())
855		}
856	}
857}
858
859func TestMetaFrameHeader(t *testing.T) {
860	write := func(f *Framer, frags ...[]byte) {
861		for i, frag := range frags {
862			end := (i == len(frags)-1)
863			if i == 0 {
864				f.WriteHeaders(HeadersFrameParam{
865					StreamID:      1,
866					BlockFragment: frag,
867					EndHeaders:    end,
868				})
869			} else {
870				f.WriteContinuation(1, end, frag)
871			}
872		}
873	}
874
875	want := func(flags Flags, length uint32, pairs ...string) *MetaHeadersFrame {
876		mh := &MetaHeadersFrame{
877			HeadersFrame: &HeadersFrame{
878				FrameHeader: FrameHeader{
879					Type:     FrameHeaders,
880					Flags:    flags,
881					Length:   length,
882					StreamID: 1,
883				},
884			},
885			Fields: []hpack.HeaderField(nil),
886		}
887		for len(pairs) > 0 {
888			mh.Fields = append(mh.Fields, hpack.HeaderField{
889				Name:  pairs[0],
890				Value: pairs[1],
891			})
892			pairs = pairs[2:]
893		}
894		return mh
895	}
896	truncated := func(mh *MetaHeadersFrame) *MetaHeadersFrame {
897		mh.Truncated = true
898		return mh
899	}
900
901	const noFlags Flags = 0
902
903	oneKBString := strings.Repeat("a", 1<<10)
904
905	tests := [...]struct {
906		name              string
907		w                 func(*Framer)
908		want              interface{} // *MetaHeaderFrame or error
909		wantErrReason     string
910		maxHeaderListSize uint32
911	}{
912		0: {
913			name: "single_headers",
914			w: func(f *Framer) {
915				var he hpackEncoder
916				all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/")
917				write(f, all)
918			},
919			want: want(FlagHeadersEndHeaders, 2, ":method", "GET", ":path", "/"),
920		},
921		1: {
922			name: "with_continuation",
923			w: func(f *Framer) {
924				var he hpackEncoder
925				all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar")
926				write(f, all[:1], all[1:])
927			},
928			want: want(noFlags, 1, ":method", "GET", ":path", "/", "foo", "bar"),
929		},
930		2: {
931			name: "with_two_continuation",
932			w: func(f *Framer) {
933				var he hpackEncoder
934				all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar")
935				write(f, all[:2], all[2:4], all[4:])
936			},
937			want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", "bar"),
938		},
939		3: {
940			name: "big_string_okay",
941			w: func(f *Framer) {
942				var he hpackEncoder
943				all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString)
944				write(f, all[:2], all[2:])
945			},
946			want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", oneKBString),
947		},
948		4: {
949			name: "big_string_error",
950			w: func(f *Framer) {
951				var he hpackEncoder
952				all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString)
953				write(f, all[:2], all[2:])
954			},
955			maxHeaderListSize: (1 << 10) / 2,
956			want:              ConnectionError(ErrCodeCompression),
957		},
958		5: {
959			name: "max_header_list_truncated",
960			w: func(f *Framer) {
961				var he hpackEncoder
962				var pairs = []string{":method", "GET", ":path", "/"}
963				for i := 0; i < 100; i++ {
964					pairs = append(pairs, "foo", "bar")
965				}
966				all := he.encodeHeaderRaw(t, pairs...)
967				write(f, all[:2], all[2:])
968			},
969			maxHeaderListSize: (1 << 10) / 2,
970			want: truncated(want(noFlags, 2,
971				":method", "GET",
972				":path", "/",
973				"foo", "bar",
974				"foo", "bar",
975				"foo", "bar",
976				"foo", "bar",
977				"foo", "bar",
978				"foo", "bar",
979				"foo", "bar",
980				"foo", "bar",
981				"foo", "bar",
982				"foo", "bar",
983				"foo", "bar", // 11
984			)),
985		},
986		6: {
987			name: "pseudo_order",
988			w: func(f *Framer) {
989				write(f, encodeHeaderRaw(t,
990					":method", "GET",
991					"foo", "bar",
992					":path", "/", // bogus
993				))
994			},
995			want:          streamError(1, ErrCodeProtocol),
996			wantErrReason: "pseudo header field after regular",
997		},
998		7: {
999			name: "pseudo_unknown",
1000			w: func(f *Framer) {
1001				write(f, encodeHeaderRaw(t,
1002					":unknown", "foo", // bogus
1003					"foo", "bar",
1004				))
1005			},
1006			want:          streamError(1, ErrCodeProtocol),
1007			wantErrReason: "invalid pseudo-header \":unknown\"",
1008		},
1009		8: {
1010			name: "pseudo_mix_request_response",
1011			w: func(f *Framer) {
1012				write(f, encodeHeaderRaw(t,
1013					":method", "GET",
1014					":status", "100",
1015				))
1016			},
1017			want:          streamError(1, ErrCodeProtocol),
1018			wantErrReason: "mix of request and response pseudo headers",
1019		},
1020		9: {
1021			name: "pseudo_dup",
1022			w: func(f *Framer) {
1023				write(f, encodeHeaderRaw(t,
1024					":method", "GET",
1025					":method", "POST",
1026				))
1027			},
1028			want:          streamError(1, ErrCodeProtocol),
1029			wantErrReason: "duplicate pseudo-header \":method\"",
1030		},
1031		10: {
1032			name: "trailer_okay_no_pseudo",
1033			w:    func(f *Framer) { write(f, encodeHeaderRaw(t, "foo", "bar")) },
1034			want: want(FlagHeadersEndHeaders, 8, "foo", "bar"),
1035		},
1036		11: {
1037			name:          "invalid_field_name",
1038			w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) },
1039			want:          streamError(1, ErrCodeProtocol),
1040			wantErrReason: "invalid header field name \"CapitalBad\"",
1041		},
1042		12: {
1043			name:          "invalid_field_value",
1044			w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) },
1045			want:          streamError(1, ErrCodeProtocol),
1046			wantErrReason: "invalid header field value \"bad_null\\x00\"",
1047		},
1048	}
1049	for i, tt := range tests {
1050		buf := new(bytes.Buffer)
1051		f := NewFramer(buf, buf)
1052		f.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
1053		f.MaxHeaderListSize = tt.maxHeaderListSize
1054		tt.w(f)
1055
1056		name := tt.name
1057		if name == "" {
1058			name = fmt.Sprintf("test index %d", i)
1059		}
1060
1061		var got interface{}
1062		var err error
1063		got, err = f.ReadFrame()
1064		if err != nil {
1065			got = err
1066
1067			// Ignore the StreamError.Cause field, if it matches the wantErrReason.
1068			// The test table above predates the Cause field.
1069			if se, ok := err.(StreamError); ok && se.Cause != nil && se.Cause.Error() == tt.wantErrReason {
1070				se.Cause = nil
1071				got = se
1072			}
1073		}
1074		if !reflect.DeepEqual(got, tt.want) {
1075			if mhg, ok := got.(*MetaHeadersFrame); ok {
1076				if mhw, ok := tt.want.(*MetaHeadersFrame); ok {
1077					hg := mhg.HeadersFrame
1078					hw := mhw.HeadersFrame
1079					if hg != nil && hw != nil && !reflect.DeepEqual(*hg, *hw) {
1080						t.Errorf("%s: headers differ:\n got: %+v\nwant: %+v\n", name, *hg, *hw)
1081					}
1082				}
1083			}
1084			str := func(v interface{}) string {
1085				if _, ok := v.(error); ok {
1086					return fmt.Sprintf("error %v", v)
1087				} else {
1088					return fmt.Sprintf("value %#v", v)
1089				}
1090			}
1091			t.Errorf("%s:\n got: %v\nwant: %s", name, str(got), str(tt.want))
1092		}
1093		if tt.wantErrReason != "" && tt.wantErrReason != fmt.Sprint(f.errDetail) {
1094			t.Errorf("%s: got error reason %q; want %q", name, f.errDetail, tt.wantErrReason)
1095		}
1096	}
1097}
1098
1099func TestSetReuseFrames(t *testing.T) {
1100	fr, buf := testFramer()
1101	fr.SetReuseFrames()
1102
1103	// Check that DataFrames are reused. Note that
1104	// SetReuseFrames only currently implements reuse of DataFrames.
1105	firstDf := readAndVerifyDataFrame("ABC", 3, fr, buf, t)
1106
1107	for i := 0; i < 10; i++ {
1108		df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t)
1109		if df != firstDf {
1110			t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf)
1111		}
1112	}
1113
1114	for i := 0; i < 10; i++ {
1115		df := readAndVerifyDataFrame("", 0, fr, buf, t)
1116		if df != firstDf {
1117			t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf)
1118		}
1119	}
1120
1121	for i := 0; i < 10; i++ {
1122		df := readAndVerifyDataFrame("HHH", 3, fr, buf, t)
1123		if df != firstDf {
1124			t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf)
1125		}
1126	}
1127}
1128
1129func TestSetReuseFramesMoreThanOnce(t *testing.T) {
1130	fr, buf := testFramer()
1131	fr.SetReuseFrames()
1132
1133	firstDf := readAndVerifyDataFrame("ABC", 3, fr, buf, t)
1134	fr.SetReuseFrames()
1135
1136	for i := 0; i < 10; i++ {
1137		df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t)
1138		// SetReuseFrames should be idempotent
1139		fr.SetReuseFrames()
1140		if df != firstDf {
1141			t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf)
1142		}
1143	}
1144}
1145
1146func TestNoSetReuseFrames(t *testing.T) {
1147	fr, buf := testFramer()
1148	const numNewDataFrames = 10
1149	dfSoFar := make([]interface{}, numNewDataFrames)
1150
1151	// Check that DataFrames are not reused if SetReuseFrames wasn't called.
1152	// SetReuseFrames only currently implements reuse of DataFrames.
1153	for i := 0; i < numNewDataFrames; i++ {
1154		df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t)
1155		for _, item := range dfSoFar {
1156			if df == item {
1157				t.Errorf("Expected Framer to return new DataFrames since SetNoReuseFrames not set.")
1158			}
1159		}
1160		dfSoFar[i] = df
1161	}
1162}
1163
1164func readAndVerifyDataFrame(data string, length byte, fr *Framer, buf *bytes.Buffer, t *testing.T) *DataFrame {
1165	var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
1166	fr.WriteData(streamID, true, []byte(data))
1167	wantEnc := "\x00\x00" + string(length) + "\x00\x01\x01\x02\x03\x04" + data
1168	if buf.String() != wantEnc {
1169		t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
1170	}
1171	f, err := fr.ReadFrame()
1172	if err != nil {
1173		t.Fatal(err)
1174	}
1175	df, ok := f.(*DataFrame)
1176	if !ok {
1177		t.Fatalf("got %T; want *DataFrame", f)
1178	}
1179	if !bytes.Equal(df.Data(), []byte(data)) {
1180		t.Errorf("got %q; want %q", df.Data(), []byte(data))
1181	}
1182	if f.Header().Flags&1 == 0 {
1183		t.Errorf("didn't see END_STREAM flag")
1184	}
1185	return df
1186}
1187
1188func encodeHeaderRaw(t *testing.T, pairs ...string) []byte {
1189	var he hpackEncoder
1190	return he.encodeHeaderRaw(t, pairs...)
1191}
1192
1193func TestSettingsDuplicates(t *testing.T) {
1194	tests := []struct {
1195		settings []Setting
1196		want     bool
1197	}{
1198		{nil, false},
1199		{[]Setting{{ID: 1}}, false},
1200		{[]Setting{{ID: 1}, {ID: 2}}, false},
1201		{[]Setting{{ID: 1}, {ID: 2}}, false},
1202		{[]Setting{{ID: 1}, {ID: 2}, {ID: 3}}, false},
1203		{[]Setting{{ID: 1}, {ID: 2}, {ID: 3}}, false},
1204		{[]Setting{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}}, false},
1205
1206		{[]Setting{{ID: 1}, {ID: 2}, {ID: 3}, {ID: 2}}, true},
1207		{[]Setting{{ID: 4}, {ID: 2}, {ID: 3}, {ID: 4}}, true},
1208
1209		{[]Setting{
1210			{ID: 1}, {ID: 2}, {ID: 3}, {ID: 4},
1211			{ID: 5}, {ID: 6}, {ID: 7}, {ID: 8},
1212			{ID: 9}, {ID: 10}, {ID: 11}, {ID: 12},
1213		}, false},
1214
1215		{[]Setting{
1216			{ID: 1}, {ID: 2}, {ID: 3}, {ID: 4},
1217			{ID: 5}, {ID: 6}, {ID: 7}, {ID: 8},
1218			{ID: 9}, {ID: 10}, {ID: 11}, {ID: 11},
1219		}, true},
1220	}
1221	for i, tt := range tests {
1222		fr, _ := testFramer()
1223		fr.WriteSettings(tt.settings...)
1224		f, err := fr.ReadFrame()
1225		if err != nil {
1226			t.Fatalf("%d. ReadFrame: %v", i, err)
1227		}
1228		sf := f.(*SettingsFrame)
1229		got := sf.HasDuplicates()
1230		if got != tt.want {
1231			t.Errorf("%d. HasDuplicates = %v; want %v", i, got, tt.want)
1232		}
1233	}
1234
1235}
1236