1package quic
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"net"
8	"sync"
9
10	"github.com/lucas-clemente/quic-go/internal/flowcontrol"
11	"github.com/lucas-clemente/quic-go/internal/protocol"
12	"github.com/lucas-clemente/quic-go/internal/qerr"
13	"github.com/lucas-clemente/quic-go/internal/wire"
14)
15
16type streamError struct {
17	message string
18	nums    []protocol.StreamNum
19}
20
21func (e streamError) Error() string {
22	return e.message
23}
24
25func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
26	strError, ok := err.(streamError)
27	if !ok {
28		return err
29	}
30	ids := make([]interface{}, len(strError.nums))
31	for i, num := range strError.nums {
32		ids[i] = num.StreamID(stype, pers)
33	}
34	return fmt.Errorf(strError.Error(), ids...)
35}
36
// streamOpenErr wraps an error that occurred while opening a stream and
// implements net.Error, so that callers can query Temporary() and Timeout().
type streamOpenErr struct{ error }

var _ net.Error = &streamOpenErr{}

// Temporary reports whether the wrapped error is, or wraps, errTooManyOpenStreams.
// errors.Is is used instead of == so that a wrapped sentinel is still recognized.
func (e streamOpenErr) Temporary() bool {
	return errors.Is(e.error, errTooManyOpenStreams)
}

// Timeout always reports false.
func (streamOpenErr) Timeout() bool { return false }

// errTooManyOpenStreams is used internally by the outgoing streams maps.
var errTooManyOpenStreams = errors.New("too many open streams")
46
// streamsMap holds the stream maps of a connection. There are four of them:
// streams we open (outgoing) and streams the peer opens (incoming), each split
// into bidirectional and unidirectional streams.
type streamsMap struct {
	perspective protocol.Perspective
	version     protocol.VersionNumber

	// Passed to the incoming maps whenever initMaps (re)creates them;
	// presumably the maximum number of streams the peer may open — confirm.
	maxIncomingBidiStreams uint64
	maxIncomingUniStreams  uint64

	sender            streamSender
	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController

	// mutex guards reset and the four map pointers, which ResetFor0RTT replaces.
	// The Open*/Accept* methods take it to read a consistent snapshot.
	// NOTE(review): the remaining methods read the map pointers without it;
	// presumably they run on the same goroutine as ResetFor0RTT — confirm.
	mutex               sync.Mutex
	outgoingBidiStreams *outgoingBidiStreamsMap
	outgoingUniStreams  *outgoingUniStreamsMap
	incomingBidiStreams *incomingBidiStreamsMap
	incomingUniStreams  *incomingUniStreamsMap
	// reset is set by ResetFor0RTT and cleared by UseResetMaps; while it is set,
	// the Open*/Accept* methods return Err0RTTRejected.
	reset               bool
}

// Compile-time check that *streamsMap implements streamManager.
var _ streamManager = &streamsMap{}
66
67func newStreamsMap(
68	sender streamSender,
69	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
70	maxIncomingBidiStreams uint64,
71	maxIncomingUniStreams uint64,
72	perspective protocol.Perspective,
73	version protocol.VersionNumber,
74) streamManager {
75	m := &streamsMap{
76		perspective:            perspective,
77		newFlowController:      newFlowController,
78		maxIncomingBidiStreams: maxIncomingBidiStreams,
79		maxIncomingUniStreams:  maxIncomingUniStreams,
80		sender:                 sender,
81		version:                version,
82	}
83	m.initMaps()
84	return m
85}
86
87func (m *streamsMap) initMaps() {
88	m.outgoingBidiStreams = newOutgoingBidiStreamsMap(
89		func(num protocol.StreamNum) streamI {
90			id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
91			return newStream(id, m.sender, m.newFlowController(id), m.version)
92		},
93		m.sender.queueControlFrame,
94	)
95	m.incomingBidiStreams = newIncomingBidiStreamsMap(
96		func(num protocol.StreamNum) streamI {
97			id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
98			return newStream(id, m.sender, m.newFlowController(id), m.version)
99		},
100		m.maxIncomingBidiStreams,
101		m.sender.queueControlFrame,
102	)
103	m.outgoingUniStreams = newOutgoingUniStreamsMap(
104		func(num protocol.StreamNum) sendStreamI {
105			id := num.StreamID(protocol.StreamTypeUni, m.perspective)
106			return newSendStream(id, m.sender, m.newFlowController(id), m.version)
107		},
108		m.sender.queueControlFrame,
109	)
110	m.incomingUniStreams = newIncomingUniStreamsMap(
111		func(num protocol.StreamNum) receiveStreamI {
112			id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite())
113			return newReceiveStream(id, m.sender, m.newFlowController(id), m.version)
114		},
115		m.maxIncomingUniStreams,
116		m.sender.queueControlFrame,
117	)
118}
119
120func (m *streamsMap) OpenStream() (Stream, error) {
121	m.mutex.Lock()
122	reset := m.reset
123	mm := m.outgoingBidiStreams
124	m.mutex.Unlock()
125	if reset {
126		return nil, Err0RTTRejected
127	}
128	str, err := mm.OpenStream()
129	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
130}
131
132func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
133	m.mutex.Lock()
134	reset := m.reset
135	mm := m.outgoingBidiStreams
136	m.mutex.Unlock()
137	if reset {
138		return nil, Err0RTTRejected
139	}
140	str, err := mm.OpenStreamSync(ctx)
141	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
142}
143
144func (m *streamsMap) OpenUniStream() (SendStream, error) {
145	m.mutex.Lock()
146	reset := m.reset
147	mm := m.outgoingUniStreams
148	m.mutex.Unlock()
149	if reset {
150		return nil, Err0RTTRejected
151	}
152	str, err := mm.OpenStream()
153	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
154}
155
156func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
157	m.mutex.Lock()
158	reset := m.reset
159	mm := m.outgoingUniStreams
160	m.mutex.Unlock()
161	if reset {
162		return nil, Err0RTTRejected
163	}
164	str, err := mm.OpenStreamSync(ctx)
165	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
166}
167
168func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
169	m.mutex.Lock()
170	reset := m.reset
171	mm := m.incomingBidiStreams
172	m.mutex.Unlock()
173	if reset {
174		return nil, Err0RTTRejected
175	}
176	str, err := mm.AcceptStream(ctx)
177	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
178}
179
180func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
181	m.mutex.Lock()
182	reset := m.reset
183	mm := m.incomingUniStreams
184	m.mutex.Unlock()
185	if reset {
186		return nil, Err0RTTRejected
187	}
188	str, err := mm.AcceptStream(ctx)
189	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
190}
191
192func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
193	num := id.StreamNum()
194	switch id.Type() {
195	case protocol.StreamTypeUni:
196		if id.InitiatedBy() == m.perspective {
197			return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective)
198		}
199		return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite())
200	case protocol.StreamTypeBidi:
201		if id.InitiatedBy() == m.perspective {
202			return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective)
203		}
204		return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite())
205	}
206	panic("")
207}
208
209func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
210	str, err := m.getOrOpenReceiveStream(id)
211	if err != nil {
212		return nil, &qerr.TransportError{
213			ErrorCode:    qerr.StreamStateError,
214			ErrorMessage: err.Error(),
215		}
216	}
217	return str, nil
218}
219
220func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
221	num := id.StreamNum()
222	switch id.Type() {
223	case protocol.StreamTypeUni:
224		if id.InitiatedBy() == m.perspective {
225			// an outgoing unidirectional stream is a send stream, not a receive stream
226			return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
227		}
228		str, err := m.incomingUniStreams.GetOrOpenStream(num)
229		return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
230	case protocol.StreamTypeBidi:
231		var str receiveStreamI
232		var err error
233		if id.InitiatedBy() == m.perspective {
234			str, err = m.outgoingBidiStreams.GetStream(num)
235		} else {
236			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
237		}
238		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
239	}
240	panic("")
241}
242
243func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
244	str, err := m.getOrOpenSendStream(id)
245	if err != nil {
246		return nil, &qerr.TransportError{
247			ErrorCode:    qerr.StreamStateError,
248			ErrorMessage: err.Error(),
249		}
250	}
251	return str, nil
252}
253
254func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
255	num := id.StreamNum()
256	switch id.Type() {
257	case protocol.StreamTypeUni:
258		if id.InitiatedBy() == m.perspective {
259			str, err := m.outgoingUniStreams.GetStream(num)
260			return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
261		}
262		// an incoming unidirectional stream is a receive stream, not a send stream
263		return nil, fmt.Errorf("peer attempted to open send stream %d", id)
264	case protocol.StreamTypeBidi:
265		var str sendStreamI
266		var err error
267		if id.InitiatedBy() == m.perspective {
268			str, err = m.outgoingBidiStreams.GetStream(num)
269		} else {
270			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
271		}
272		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
273	}
274	panic("")
275}
276
277func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
278	switch f.Type {
279	case protocol.StreamTypeUni:
280		m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
281	case protocol.StreamTypeBidi:
282		m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
283	}
284}
285
286func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
287	m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
288	m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
289	m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni)
290	m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
291}
292
293func (m *streamsMap) CloseWithError(err error) {
294	m.outgoingBidiStreams.CloseWithError(err)
295	m.outgoingUniStreams.CloseWithError(err)
296	m.incomingBidiStreams.CloseWithError(err)
297	m.incomingUniStreams.CloseWithError(err)
298}
299
300// ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are
301// 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error.
302// 2. reset to their initial state, such that we can immediately process new incoming stream data.
303// Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error,
304// until UseResetMaps() has been called.
305func (m *streamsMap) ResetFor0RTT() {
306	m.mutex.Lock()
307	defer m.mutex.Unlock()
308	m.reset = true
309	m.CloseWithError(Err0RTTRejected)
310	m.initMaps()
311}
312
313func (m *streamsMap) UseResetMaps() {
314	m.mutex.Lock()
315	m.reset = false
316	m.mutex.Unlock()
317}
318