1package http3
2
3import (
4	"bytes"
5	"fmt"
6	"io"
7
8	"github.com/golang/mock/gomock"
9	"github.com/lucas-clemente/quic-go"
10	mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
11
12	. "github.com/onsi/ginkgo"
13	. "github.com/onsi/gomega"
14)
15
// bodyType identifies which flavour of body a spec is run against.
type bodyType uint8

const (
	// bodyTypeRequest is the zero value of bodyType.
	bodyTypeRequest bodyType = iota
	bodyTypeResponse
)

// String returns "request" for bodyTypeRequest and "response" for every
// other value, so it can be used directly with the %s formatting verb.
func (t bodyType) String() string {
	switch t {
	case bodyTypeRequest:
		return "request"
	default:
		return "response"
	}
}
29
30var _ = Describe("Body", func() {
31	var (
32		rb            *body
33		str           *mockquic.MockStream
34		buf           *bytes.Buffer
35		reqDone       chan struct{}
36		errorCbCalled bool
37	)
38
39	errorCb := func() { errorCbCalled = true }
40
41	getDataFrame := func(data []byte) []byte {
42		b := &bytes.Buffer{}
43		(&dataFrame{Length: uint64(len(data))}).Write(b)
44		b.Write(data)
45		return b.Bytes()
46	}
47
48	BeforeEach(func() {
49		buf = &bytes.Buffer{}
50		errorCbCalled = false
51	})
52
53	for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} {
54		bodyType := bt
55
56		Context(fmt.Sprintf("using a %s body", bodyType), func() {
57			BeforeEach(func() {
58				str = mockquic.NewMockStream(mockCtrl)
59				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
60					return buf.Write(b)
61				}).AnyTimes()
62				str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
63					return buf.Read(b)
64				}).AnyTimes()
65
66				switch bodyType {
67				case bodyTypeRequest:
68					rb = newRequestBody(str, errorCb)
69				case bodyTypeResponse:
70					reqDone = make(chan struct{})
71					rb = newResponseBody(str, reqDone, errorCb)
72				}
73			})
74
75			It("reads DATA frames in a single run", func() {
76				buf.Write(getDataFrame([]byte("foobar")))
77				b := make([]byte, 6)
78				n, err := rb.Read(b)
79				Expect(err).ToNot(HaveOccurred())
80				Expect(n).To(Equal(6))
81				Expect(b).To(Equal([]byte("foobar")))
82			})
83
84			It("reads DATA frames in multiple runs", func() {
85				buf.Write(getDataFrame([]byte("foobar")))
86				b := make([]byte, 3)
87				n, err := rb.Read(b)
88				Expect(err).ToNot(HaveOccurred())
89				Expect(n).To(Equal(3))
90				Expect(b).To(Equal([]byte("foo")))
91				n, err = rb.Read(b)
92				Expect(err).ToNot(HaveOccurred())
93				Expect(n).To(Equal(3))
94				Expect(b).To(Equal([]byte("bar")))
95			})
96
97			It("reads DATA frames into too large buffers", func() {
98				buf.Write(getDataFrame([]byte("foobar")))
99				b := make([]byte, 10)
100				n, err := rb.Read(b)
101				Expect(err).ToNot(HaveOccurred())
102				Expect(n).To(Equal(6))
103				Expect(b[:n]).To(Equal([]byte("foobar")))
104			})
105
106			It("reads DATA frames into too large buffers, in multiple runs", func() {
107				buf.Write(getDataFrame([]byte("foobar")))
108				b := make([]byte, 4)
109				n, err := rb.Read(b)
110				Expect(err).ToNot(HaveOccurred())
111				Expect(n).To(Equal(4))
112				Expect(b).To(Equal([]byte("foob")))
113				n, err = rb.Read(b)
114				Expect(err).ToNot(HaveOccurred())
115				Expect(n).To(Equal(2))
116				Expect(b[:n]).To(Equal([]byte("ar")))
117			})
118
119			It("reads multiple DATA frames", func() {
120				buf.Write(getDataFrame([]byte("foo")))
121				buf.Write(getDataFrame([]byte("bar")))
122				b := make([]byte, 6)
123				n, err := rb.Read(b)
124				Expect(err).ToNot(HaveOccurred())
125				Expect(n).To(Equal(3))
126				Expect(b[:n]).To(Equal([]byte("foo")))
127				n, err = rb.Read(b)
128				Expect(err).ToNot(HaveOccurred())
129				Expect(n).To(Equal(3))
130				Expect(b[:n]).To(Equal([]byte("bar")))
131			})
132
133			It("skips HEADERS frames", func() {
134				buf.Write(getDataFrame([]byte("foo")))
135				(&headersFrame{Length: 10}).Write(buf)
136				buf.Write(make([]byte, 10))
137				buf.Write(getDataFrame([]byte("bar")))
138				b := make([]byte, 6)
139				n, err := io.ReadFull(rb, b)
140				Expect(err).ToNot(HaveOccurred())
141				Expect(n).To(Equal(6))
142				Expect(b).To(Equal([]byte("foobar")))
143			})
144
145			It("errors when it can't parse the frame", func() {
146				buf.Write([]byte("invalid"))
147				_, err := rb.Read([]byte{0})
148				Expect(err).To(HaveOccurred())
149			})
150
151			It("errors on unexpected frames, and calls the error callback", func() {
152				(&settingsFrame{}).Write(buf)
153				_, err := rb.Read([]byte{0})
154				Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame"))
155				Expect(errorCbCalled).To(BeTrue())
156			})
157
158			if bodyType == bodyTypeResponse {
159				It("closes the reqDone channel when Read errors", func() {
160					buf.Write([]byte("invalid"))
161					_, err := rb.Read([]byte{0})
162					Expect(err).To(HaveOccurred())
163					Expect(reqDone).To(BeClosed())
164				})
165
166				It("allows multiple calls to Read, when Read errors", func() {
167					buf.Write([]byte("invalid"))
168					_, err := rb.Read([]byte{0})
169					Expect(err).To(HaveOccurred())
170					Expect(reqDone).To(BeClosed())
171					_, err = rb.Read([]byte{0})
172					Expect(err).To(HaveOccurred())
173				})
174
175				It("closes responses", func() {
176					str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled))
177					Expect(rb.Close()).To(Succeed())
178				})
179
180				It("allows multiple calls to Close", func() {
181					str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2)
182					Expect(rb.Close()).To(Succeed())
183					Expect(reqDone).To(BeClosed())
184					Expect(rb.Close()).To(Succeed())
185				})
186			}
187		})
188	}
189})
190