1package quic
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"net"
8
9	"github.com/lucas-clemente/quic-go/internal/flowcontrol"
10	"github.com/lucas-clemente/quic-go/internal/protocol"
11	"github.com/lucas-clemente/quic-go/internal/qerr"
12	"github.com/lucas-clemente/quic-go/internal/wire"
13)
14
15type streamError struct {
16	message string
17	nums    []protocol.StreamNum
18}
19
20func (e streamError) Error() string {
21	return e.message
22}
23
24func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
25	strError, ok := err.(streamError)
26	if !ok {
27		return err
28	}
29	ids := make([]interface{}, len(strError.nums))
30	for i, num := range strError.nums {
31		ids[i] = num.StreamID(stype, pers)
32	}
33	return fmt.Errorf(strError.Error(), ids...)
34}
35
// streamOpenErr wraps an error encountered while opening a stream and
// satisfies net.Error.
type streamOpenErr struct{ error }

var _ net.Error = (*streamOpenErr)(nil)

// Temporary reports true only when the wrapped error is errTooManyOpenStreams.
func (e streamOpenErr) Temporary() bool { return errTooManyOpenStreams == e.error }

// Timeout always reports false.
func (streamOpenErr) Timeout() bool { return false }

// errTooManyOpenStreams is used internally by the outgoing streams maps.
var errTooManyOpenStreams = errors.New("too many open streams")
45
46type streamsMap struct {
47	perspective protocol.Perspective
48
49	sender            streamSender
50	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
51
52	outgoingBidiStreams *outgoingBidiStreamsMap
53	outgoingUniStreams  *outgoingUniStreamsMap
54	incomingBidiStreams *incomingBidiStreamsMap
55	incomingUniStreams  *incomingUniStreamsMap
56}
57
58var _ streamManager = &streamsMap{}
59
60func newStreamsMap(
61	sender streamSender,
62	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
63	maxIncomingBidiStreams uint64,
64	maxIncomingUniStreams uint64,
65	perspective protocol.Perspective,
66	version protocol.VersionNumber,
67) streamManager {
68	m := &streamsMap{
69		perspective:       perspective,
70		newFlowController: newFlowController,
71		sender:            sender,
72	}
73	m.outgoingBidiStreams = newOutgoingBidiStreamsMap(
74		func(num protocol.StreamNum) streamI {
75			id := num.StreamID(protocol.StreamTypeBidi, perspective)
76			return newStream(id, m.sender, m.newFlowController(id), version)
77		},
78		sender.queueControlFrame,
79	)
80	m.incomingBidiStreams = newIncomingBidiStreamsMap(
81		func(num protocol.StreamNum) streamI {
82			id := num.StreamID(protocol.StreamTypeBidi, perspective.Opposite())
83			return newStream(id, m.sender, m.newFlowController(id), version)
84		},
85		maxIncomingBidiStreams,
86		sender.queueControlFrame,
87	)
88	m.outgoingUniStreams = newOutgoingUniStreamsMap(
89		func(num protocol.StreamNum) sendStreamI {
90			id := num.StreamID(protocol.StreamTypeUni, perspective)
91			return newSendStream(id, m.sender, m.newFlowController(id), version)
92		},
93		sender.queueControlFrame,
94	)
95	m.incomingUniStreams = newIncomingUniStreamsMap(
96		func(num protocol.StreamNum) receiveStreamI {
97			id := num.StreamID(protocol.StreamTypeUni, perspective.Opposite())
98			return newReceiveStream(id, m.sender, m.newFlowController(id), version)
99		},
100		maxIncomingUniStreams,
101		sender.queueControlFrame,
102	)
103	return m
104}
105
106func (m *streamsMap) OpenStream() (Stream, error) {
107	str, err := m.outgoingBidiStreams.OpenStream()
108	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
109}
110
111func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
112	str, err := m.outgoingBidiStreams.OpenStreamSync(ctx)
113	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
114}
115
116func (m *streamsMap) OpenUniStream() (SendStream, error) {
117	str, err := m.outgoingUniStreams.OpenStream()
118	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
119}
120
121func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
122	str, err := m.outgoingUniStreams.OpenStreamSync(ctx)
123	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
124}
125
126func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
127	str, err := m.incomingBidiStreams.AcceptStream(ctx)
128	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
129}
130
131func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
132	str, err := m.incomingUniStreams.AcceptStream(ctx)
133	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
134}
135
136func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
137	num := id.StreamNum()
138	switch id.Type() {
139	case protocol.StreamTypeUni:
140		if id.InitiatedBy() == m.perspective {
141			return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective)
142		}
143		return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite())
144	case protocol.StreamTypeBidi:
145		if id.InitiatedBy() == m.perspective {
146			return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective)
147		}
148		return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite())
149	}
150	panic("")
151}
152
153func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
154	str, err := m.getOrOpenReceiveStream(id)
155	if err != nil {
156		return nil, qerr.NewError(qerr.StreamStateError, err.Error())
157	}
158	return str, nil
159}
160
161func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
162	num := id.StreamNum()
163	switch id.Type() {
164	case protocol.StreamTypeUni:
165		if id.InitiatedBy() == m.perspective {
166			// an outgoing unidirectional stream is a send stream, not a receive stream
167			return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
168		}
169		str, err := m.incomingUniStreams.GetOrOpenStream(num)
170		return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
171	case protocol.StreamTypeBidi:
172		var str receiveStreamI
173		var err error
174		if id.InitiatedBy() == m.perspective {
175			str, err = m.outgoingBidiStreams.GetStream(num)
176		} else {
177			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
178		}
179		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
180	}
181	panic("")
182}
183
184func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
185	str, err := m.getOrOpenSendStream(id)
186	if err != nil {
187		return nil, qerr.NewError(qerr.StreamStateError, err.Error())
188	}
189	return str, nil
190}
191
192func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
193	num := id.StreamNum()
194	switch id.Type() {
195	case protocol.StreamTypeUni:
196		if id.InitiatedBy() == m.perspective {
197			str, err := m.outgoingUniStreams.GetStream(num)
198			return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
199		}
200		// an incoming unidirectional stream is a receive stream, not a send stream
201		return nil, fmt.Errorf("peer attempted to open send stream %d", id)
202	case protocol.StreamTypeBidi:
203		var str sendStreamI
204		var err error
205		if id.InitiatedBy() == m.perspective {
206			str, err = m.outgoingBidiStreams.GetStream(num)
207		} else {
208			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
209		}
210		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
211	}
212	panic("")
213}
214
215func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error {
216	switch f.Type {
217	case protocol.StreamTypeUni:
218		m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
219	case protocol.StreamTypeBidi:
220		m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
221	}
222	return nil
223}
224
225func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) error {
226	if p.MaxBidiStreamNum > protocol.MaxStreamCount ||
227		p.MaxUniStreamNum > protocol.MaxStreamCount {
228		return qerr.StreamLimitError
229	}
230	// Max{Uni,Bidi}StreamID returns the highest stream ID that the peer is allowed to open.
231	m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
232	m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
233	return nil
234}
235
236func (m *streamsMap) CloseWithError(err error) {
237	m.outgoingBidiStreams.CloseWithError(err)
238	m.outgoingUniStreams.CloseWithError(err)
239	m.incomingBidiStreams.CloseWithError(err)
240	m.incomingUniStreams.CloseWithError(err)
241}
242