1package quic
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"net"
8
9	"github.com/golang/mock/gomock"
10
11	"github.com/lucas-clemente/quic-go/internal/flowcontrol"
12	"github.com/lucas-clemente/quic-go/internal/mocks"
13	"github.com/lucas-clemente/quic-go/internal/protocol"
14	"github.com/lucas-clemente/quic-go/internal/qerr"
15	"github.com/lucas-clemente/quic-go/internal/wire"
16
17	. "github.com/onsi/ginkgo"
18	. "github.com/onsi/gomega"
19)
20
21func (e streamError) TestError() error {
22	nums := make([]interface{}, len(e.nums))
23	for i, num := range e.nums {
24		nums[i] = num
25	}
26	return fmt.Errorf(e.message, nums...)
27}
28
29type streamMapping struct {
30	firstIncomingBidiStream protocol.StreamID
31	firstIncomingUniStream  protocol.StreamID
32	firstOutgoingBidiStream protocol.StreamID
33	firstOutgoingUniStream  protocol.StreamID
34}
35
36func expectTooManyStreamsError(err error) {
37	ExpectWithOffset(1, err).To(HaveOccurred())
38	ExpectWithOffset(1, err.Error()).To(Equal(errTooManyOpenStreams.Error()))
39	nerr, ok := err.(net.Error)
40	ExpectWithOffset(1, ok).To(BeTrue())
41	ExpectWithOffset(1, nerr.Temporary()).To(BeTrue())
42	ExpectWithOffset(1, nerr.Timeout()).To(BeFalse())
43}
44
45var _ = Describe("Streams Map", func() {
46	newFlowController := func(protocol.StreamID) flowcontrol.StreamFlowController {
47		return mocks.NewMockStreamFlowController(mockCtrl)
48	}
49
50	serverStreamMapping := streamMapping{
51		firstIncomingBidiStream: 0,
52		firstOutgoingBidiStream: 1,
53		firstIncomingUniStream:  2,
54		firstOutgoingUniStream:  3,
55	}
56	clientStreamMapping := streamMapping{
57		firstIncomingBidiStream: 1,
58		firstOutgoingBidiStream: 0,
59		firstIncomingUniStream:  3,
60		firstOutgoingUniStream:  2,
61	}
62
63	for _, p := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} {
64		perspective := p
65		var ids streamMapping
66		if perspective == protocol.PerspectiveClient {
67			ids = clientStreamMapping
68		} else {
69			ids = serverStreamMapping
70		}
71
72		Context(perspective.String(), func() {
73			var (
74				m          *streamsMap
75				mockSender *MockStreamSender
76			)
77
78			const (
79				MaxBidiStreamNum = 111
80				MaxUniStreamNum  = 222
81			)
82
83			allowUnlimitedStreams := func() {
84				m.UpdateLimits(&wire.TransportParameters{
85					MaxBidiStreamNum: protocol.MaxStreamCount,
86					MaxUniStreamNum:  protocol.MaxStreamCount,
87				})
88			}
89
90			BeforeEach(func() {
91				mockSender = NewMockStreamSender(mockCtrl)
92				m = newStreamsMap(mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective, protocol.VersionWhatever).(*streamsMap)
93			})
94
95			Context("opening", func() {
96				It("opens bidirectional streams", func() {
97					allowUnlimitedStreams()
98					str, err := m.OpenStream()
99					Expect(err).ToNot(HaveOccurred())
100					Expect(str).To(BeAssignableToTypeOf(&stream{}))
101					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
102					str, err = m.OpenStream()
103					Expect(err).ToNot(HaveOccurred())
104					Expect(str).To(BeAssignableToTypeOf(&stream{}))
105					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + 4))
106				})
107
108				It("opens unidirectional streams", func() {
109					allowUnlimitedStreams()
110					str, err := m.OpenUniStream()
111					Expect(err).ToNot(HaveOccurred())
112					Expect(str).To(BeAssignableToTypeOf(&sendStream{}))
113					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
114					str, err = m.OpenUniStream()
115					Expect(err).ToNot(HaveOccurred())
116					Expect(str).To(BeAssignableToTypeOf(&sendStream{}))
117					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + 4))
118				})
119			})
120
121			Context("accepting", func() {
122				It("accepts bidirectional streams", func() {
123					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
124					Expect(err).ToNot(HaveOccurred())
125					str, err := m.AcceptStream(context.Background())
126					Expect(err).ToNot(HaveOccurred())
127					Expect(str).To(BeAssignableToTypeOf(&stream{}))
128					Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream))
129				})
130
131				It("accepts unidirectional streams", func() {
132					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
133					Expect(err).ToNot(HaveOccurred())
134					str, err := m.AcceptUniStream(context.Background())
135					Expect(err).ToNot(HaveOccurred())
136					Expect(str).To(BeAssignableToTypeOf(&receiveStream{}))
137					Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream))
138				})
139			})
140
141			Context("deleting", func() {
142				BeforeEach(func() {
143					mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
144					allowUnlimitedStreams()
145				})
146
147				It("deletes outgoing bidirectional streams", func() {
148					id := ids.firstOutgoingBidiStream
149					str, err := m.OpenStream()
150					Expect(err).ToNot(HaveOccurred())
151					Expect(str.StreamID()).To(Equal(id))
152					Expect(m.DeleteStream(id)).To(Succeed())
153					dstr, err := m.GetOrOpenSendStream(id)
154					Expect(err).ToNot(HaveOccurred())
155					Expect(dstr).To(BeNil())
156				})
157
158				It("deletes incoming bidirectional streams", func() {
159					id := ids.firstIncomingBidiStream
160					str, err := m.GetOrOpenReceiveStream(id)
161					Expect(err).ToNot(HaveOccurred())
162					Expect(str.StreamID()).To(Equal(id))
163					Expect(m.DeleteStream(id)).To(Succeed())
164					dstr, err := m.GetOrOpenReceiveStream(id)
165					Expect(err).ToNot(HaveOccurred())
166					Expect(dstr).To(BeNil())
167				})
168
169				It("accepts bidirectional streams after they have been deleted", func() {
170					id := ids.firstIncomingBidiStream
171					_, err := m.GetOrOpenReceiveStream(id)
172					Expect(err).ToNot(HaveOccurred())
173					Expect(m.DeleteStream(id)).To(Succeed())
174					str, err := m.AcceptStream(context.Background())
175					Expect(err).ToNot(HaveOccurred())
176					Expect(str).ToNot(BeNil())
177					Expect(str.StreamID()).To(Equal(id))
178				})
179
180				It("deletes outgoing unidirectional streams", func() {
181					id := ids.firstOutgoingUniStream
182					str, err := m.OpenUniStream()
183					Expect(err).ToNot(HaveOccurred())
184					Expect(str.StreamID()).To(Equal(id))
185					Expect(m.DeleteStream(id)).To(Succeed())
186					dstr, err := m.GetOrOpenSendStream(id)
187					Expect(err).ToNot(HaveOccurred())
188					Expect(dstr).To(BeNil())
189				})
190
191				It("deletes incoming unidirectional streams", func() {
192					id := ids.firstIncomingUniStream
193					str, err := m.GetOrOpenReceiveStream(id)
194					Expect(err).ToNot(HaveOccurred())
195					Expect(str.StreamID()).To(Equal(id))
196					Expect(m.DeleteStream(id)).To(Succeed())
197					dstr, err := m.GetOrOpenReceiveStream(id)
198					Expect(err).ToNot(HaveOccurred())
199					Expect(dstr).To(BeNil())
200				})
201
202				It("accepts unirectional streams after they have been deleted", func() {
203					id := ids.firstIncomingUniStream
204					_, err := m.GetOrOpenReceiveStream(id)
205					Expect(err).ToNot(HaveOccurred())
206					Expect(m.DeleteStream(id)).To(Succeed())
207					str, err := m.AcceptUniStream(context.Background())
208					Expect(err).ToNot(HaveOccurred())
209					Expect(str).ToNot(BeNil())
210					Expect(str.StreamID()).To(Equal(id))
211				})
212
213				It("errors when deleting unknown incoming unidirectional streams", func() {
214					id := ids.firstIncomingUniStream + 4
215					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id)))
216				})
217
218				It("errors when deleting unknown outgoing unidirectional streams", func() {
219					id := ids.firstOutgoingUniStream + 4
220					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id)))
221				})
222
223				It("errors when deleting unknown incoming bidirectional streams", func() {
224					id := ids.firstIncomingBidiStream + 4
225					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id)))
226				})
227
228				It("errors when deleting unknown outgoing bidirectional streams", func() {
229					id := ids.firstOutgoingBidiStream + 4
230					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id)))
231				})
232			})
233
234			Context("getting streams", func() {
235				BeforeEach(func() {
236					allowUnlimitedStreams()
237				})
238
239				Context("send streams", func() {
240					It("gets an outgoing bidirectional stream", func() {
241						// need to open the stream ourselves first
242						// the peer is not allowed to create a stream initiated by us
243						_, err := m.OpenStream()
244						Expect(err).ToNot(HaveOccurred())
245						str, err := m.GetOrOpenSendStream(ids.firstOutgoingBidiStream)
246						Expect(err).ToNot(HaveOccurred())
247						Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
248					})
249
250					It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
251						id := ids.firstOutgoingBidiStream + 5*4
252						_, err := m.GetOrOpenSendStream(id)
253						Expect(err).To(MatchError(&qerr.TransportError{
254							ErrorCode:    qerr.StreamStateError,
255							ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
256						}))
257					})
258
259					It("gets an outgoing unidirectional stream", func() {
260						// need to open the stream ourselves first
261						// the peer is not allowed to create a stream initiated by us
262						_, err := m.OpenUniStream()
263						Expect(err).ToNot(HaveOccurred())
264						str, err := m.GetOrOpenSendStream(ids.firstOutgoingUniStream)
265						Expect(err).ToNot(HaveOccurred())
266						Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
267					})
268
269					It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
270						id := ids.firstOutgoingUniStream + 5*4
271						_, err := m.GetOrOpenSendStream(id)
272						Expect(err).To(MatchError(&qerr.TransportError{
273							ErrorCode:    qerr.StreamStateError,
274							ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
275						}))
276					})
277
278					It("gets an incoming bidirectional stream", func() {
279						id := ids.firstIncomingBidiStream + 4*7
280						str, err := m.GetOrOpenSendStream(id)
281						Expect(err).ToNot(HaveOccurred())
282						Expect(str.StreamID()).To(Equal(id))
283					})
284
285					It("errors when trying to get an incoming unidirectional stream", func() {
286						id := ids.firstIncomingUniStream
287						_, err := m.GetOrOpenSendStream(id)
288						Expect(err).To(MatchError(&qerr.TransportError{
289							ErrorCode:    qerr.StreamStateError,
290							ErrorMessage: fmt.Sprintf("peer attempted to open send stream %d", id),
291						}))
292					})
293				})
294
295				Context("receive streams", func() {
296					It("gets an outgoing bidirectional stream", func() {
297						// need to open the stream ourselves first
298						// the peer is not allowed to create a stream initiated by us
299						_, err := m.OpenStream()
300						Expect(err).ToNot(HaveOccurred())
301						str, err := m.GetOrOpenReceiveStream(ids.firstOutgoingBidiStream)
302						Expect(err).ToNot(HaveOccurred())
303						Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
304					})
305
306					It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
307						id := ids.firstOutgoingBidiStream + 5*4
308						_, err := m.GetOrOpenReceiveStream(id)
309						Expect(err).To(MatchError(&qerr.TransportError{
310							ErrorCode:    qerr.StreamStateError,
311							ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
312						}))
313					})
314
315					It("gets an incoming bidirectional stream", func() {
316						id := ids.firstIncomingBidiStream + 4*7
317						str, err := m.GetOrOpenReceiveStream(id)
318						Expect(err).ToNot(HaveOccurred())
319						Expect(str.StreamID()).To(Equal(id))
320					})
321
322					It("gets an incoming unidirectional stream", func() {
323						id := ids.firstIncomingUniStream + 4*10
324						str, err := m.GetOrOpenReceiveStream(id)
325						Expect(err).ToNot(HaveOccurred())
326						Expect(str.StreamID()).To(Equal(id))
327					})
328
329					It("errors when trying to get an outgoing unidirectional stream", func() {
330						id := ids.firstOutgoingUniStream
331						_, err := m.GetOrOpenReceiveStream(id)
332						Expect(err).To(MatchError(&qerr.TransportError{
333							ErrorCode:    qerr.StreamStateError,
334							ErrorMessage: fmt.Sprintf("peer attempted to open receive stream %d", id),
335						}))
336					})
337				})
338			})
339
340			It("processes the parameter for outgoing streams", func() {
341				mockSender.EXPECT().queueControlFrame(gomock.Any())
342				_, err := m.OpenStream()
343				expectTooManyStreamsError(err)
344				m.UpdateLimits(&wire.TransportParameters{
345					MaxBidiStreamNum: 5,
346					MaxUniStreamNum:  8,
347				})
348
349				mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2)
350				// test we can only 5 bidirectional streams
351				for i := 0; i < 5; i++ {
352					str, err := m.OpenStream()
353					Expect(err).ToNot(HaveOccurred())
354					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + protocol.StreamID(4*i)))
355				}
356				_, err = m.OpenStream()
357				expectTooManyStreamsError(err)
358				// test we can only 8 unidirectional streams
359				for i := 0; i < 8; i++ {
360					str, err := m.OpenUniStream()
361					Expect(err).ToNot(HaveOccurred())
362					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + protocol.StreamID(4*i)))
363				}
364				_, err = m.OpenUniStream()
365				expectTooManyStreamsError(err)
366			})
367
368			if perspective == protocol.PerspectiveClient {
369				It("applies parameters to existing streams (needed for 0-RTT)", func() {
370					m.UpdateLimits(&wire.TransportParameters{
371						MaxBidiStreamNum: 1000,
372						MaxUniStreamNum:  1000,
373					})
374					flowControllers := make(map[protocol.StreamID]*mocks.MockStreamFlowController)
375					m.newFlowController = func(id protocol.StreamID) flowcontrol.StreamFlowController {
376						fc := mocks.NewMockStreamFlowController(mockCtrl)
377						flowControllers[id] = fc
378						return fc
379					}
380
381					str, err := m.OpenStream()
382					Expect(err).ToNot(HaveOccurred())
383					unistr, err := m.OpenUniStream()
384					Expect(err).ToNot(HaveOccurred())
385
386					Expect(flowControllers).To(HaveKey(str.StreamID()))
387					flowControllers[str.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(4321))
388					Expect(flowControllers).To(HaveKey(unistr.StreamID()))
389					flowControllers[unistr.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(1234))
390
391					m.UpdateLimits(&wire.TransportParameters{
392						MaxBidiStreamNum:               1000,
393						InitialMaxStreamDataUni:        1234,
394						MaxUniStreamNum:                1000,
395						InitialMaxStreamDataBidiRemote: 4321,
396					})
397				})
398			}
399
400			Context("handling MAX_STREAMS frames", func() {
401				BeforeEach(func() {
402					mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
403				})
404
405				It("processes IDs for outgoing bidirectional streams", func() {
406					_, err := m.OpenStream()
407					expectTooManyStreamsError(err)
408					m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
409						Type:         protocol.StreamTypeBidi,
410						MaxStreamNum: 1,
411					})
412					str, err := m.OpenStream()
413					Expect(err).ToNot(HaveOccurred())
414					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
415					_, err = m.OpenStream()
416					expectTooManyStreamsError(err)
417				})
418
419				It("processes IDs for outgoing unidirectional streams", func() {
420					_, err := m.OpenUniStream()
421					expectTooManyStreamsError(err)
422					m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
423						Type:         protocol.StreamTypeUni,
424						MaxStreamNum: 1,
425					})
426					str, err := m.OpenUniStream()
427					Expect(err).ToNot(HaveOccurred())
428					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
429					_, err = m.OpenUniStream()
430					expectTooManyStreamsError(err)
431				})
432			})
433
434			Context("sending MAX_STREAMS frames", func() {
435				It("sends a MAX_STREAMS frame for bidirectional streams", func() {
436					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
437					Expect(err).ToNot(HaveOccurred())
438					_, err = m.AcceptStream(context.Background())
439					Expect(err).ToNot(HaveOccurred())
440					mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
441						Type:         protocol.StreamTypeBidi,
442						MaxStreamNum: MaxBidiStreamNum + 1,
443					})
444					Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed())
445				})
446
447				It("sends a MAX_STREAMS frame for unidirectional streams", func() {
448					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
449					Expect(err).ToNot(HaveOccurred())
450					_, err = m.AcceptUniStream(context.Background())
451					Expect(err).ToNot(HaveOccurred())
452					mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
453						Type:         protocol.StreamTypeUni,
454						MaxStreamNum: MaxUniStreamNum + 1,
455					})
456					Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed())
457				})
458			})
459
460			It("closes", func() {
461				testErr := errors.New("test error")
462				m.CloseWithError(testErr)
463				_, err := m.OpenStream()
464				Expect(err).To(HaveOccurred())
465				Expect(err.Error()).To(Equal(testErr.Error()))
466				_, err = m.OpenUniStream()
467				Expect(err).To(HaveOccurred())
468				Expect(err.Error()).To(Equal(testErr.Error()))
469				_, err = m.AcceptStream(context.Background())
470				Expect(err).To(HaveOccurred())
471				Expect(err.Error()).To(Equal(testErr.Error()))
472				_, err = m.AcceptUniStream(context.Background())
473				Expect(err).To(HaveOccurred())
474				Expect(err.Error()).To(Equal(testErr.Error()))
475			})
476
477			if perspective == protocol.PerspectiveClient {
478				It("resets for 0-RTT", func() {
479					mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
480					m.ResetFor0RTT()
481					// make sure that calls to open / accept streams fail
482					_, err := m.OpenStream()
483					Expect(err).To(MatchError(Err0RTTRejected))
484					_, err = m.AcceptStream(context.Background())
485					Expect(err).To(MatchError(Err0RTTRejected))
486					// make sure that we can still get new streams, as the server might be sending us data
487					str, err := m.GetOrOpenReceiveStream(3)
488					Expect(err).ToNot(HaveOccurred())
489					Expect(str).ToNot(BeNil())
490
491					// now switch to using the new streams map
492					m.UseResetMaps()
493					_, err = m.OpenStream()
494					Expect(err).To(HaveOccurred())
495					Expect(err.Error()).To(ContainSubstring("too many open streams"))
496				})
497			}
498		})
499	}
500})
501