1package quic
2
3import (
4	"context"
5	"sync"
6
7	"github.com/lucas-clemente/quic-go/internal/protocol"
8	"github.com/lucas-clemente/quic-go/internal/wire"
9)
10
// outgoingItemsMap keeps track of the streams opened by this endpoint and of
// the OpenStreamSync calls that are waiting for the stream limit to be raised.
//
//go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi"
//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni"
type outgoingItemsMap struct {
	// mutex is taken by every exported method before it reads or modifies
	// the fields below; the unexported helpers expect it to be held.
	mutex sync.RWMutex

	// streams holds the opened streams by stream number, until DeleteStream
	// removes them.
	streams map[protocol.StreamNum]item

	// openQueue holds one wake-up channel per blocked OpenStreamSync call,
	// keyed by the call's position in the queue.
	openQueue      map[uint64]chan struct{}
	// lowestInQueue is the position at which unblockOpenSync starts looking
	// for a waiter to wake.
	lowestInQueue  uint64
	// highestInQueue is the position that will be assigned to the next
	// OpenStreamSync call that has to wait.
	highestInQueue uint64

	nextStream  protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
	maxStream   protocol.StreamNum // the maximum stream ID we're allowed to open
	blockedSent bool               // was a STREAMS_BLOCKED sent for the current maxStream

	// newStream constructs the stream object for a given stream number.
	newStream            func(protocol.StreamNum) item
	// queueStreamIDBlocked queues a STREAMS_BLOCKED frame for sending.
	queueStreamIDBlocked func(*wire.StreamsBlockedFrame)

	// closeErr is set by CloseWithError. While it is non-nil, OpenStream and
	// OpenStreamSync return it instead of opening a stream.
	closeErr error
}
31
32func newOutgoingItemsMap(
33	newStream func(protocol.StreamNum) item,
34	queueControlFrame func(wire.Frame),
35) *outgoingItemsMap {
36	return &outgoingItemsMap{
37		streams:              make(map[protocol.StreamNum]item),
38		openQueue:            make(map[uint64]chan struct{}),
39		maxStream:            protocol.InvalidStreamNum,
40		nextStream:           1,
41		newStream:            newStream,
42		queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) },
43	}
44}
45
46func (m *outgoingItemsMap) OpenStream() (item, error) {
47	m.mutex.Lock()
48	defer m.mutex.Unlock()
49
50	if m.closeErr != nil {
51		return nil, m.closeErr
52	}
53
54	// if there are OpenStreamSync calls waiting, return an error here
55	if len(m.openQueue) > 0 || m.nextStream > m.maxStream {
56		m.maybeSendBlockedFrame()
57		return nil, streamOpenErr{errTooManyOpenStreams}
58	}
59	return m.openStream(), nil
60}
61
62func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) {
63	m.mutex.Lock()
64	defer m.mutex.Unlock()
65
66	if m.closeErr != nil {
67		return nil, m.closeErr
68	}
69
70	if err := ctx.Err(); err != nil {
71		return nil, err
72	}
73
74	if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
75		return m.openStream(), nil
76	}
77
78	waitChan := make(chan struct{}, 1)
79	queuePos := m.highestInQueue
80	m.highestInQueue++
81	if len(m.openQueue) == 0 {
82		m.lowestInQueue = queuePos
83	}
84	m.openQueue[queuePos] = waitChan
85	m.maybeSendBlockedFrame()
86
87	for {
88		m.mutex.Unlock()
89		select {
90		case <-ctx.Done():
91			m.mutex.Lock()
92			delete(m.openQueue, queuePos)
93			return nil, ctx.Err()
94		case <-waitChan:
95		}
96		m.mutex.Lock()
97
98		if m.closeErr != nil {
99			return nil, m.closeErr
100		}
101		if m.nextStream > m.maxStream {
102			// no stream available. Continue waiting
103			continue
104		}
105		str := m.openStream()
106		delete(m.openQueue, queuePos)
107		m.lowestInQueue = queuePos + 1
108		m.unblockOpenSync()
109		return str, nil
110	}
111}
112
113func (m *outgoingItemsMap) openStream() item {
114	s := m.newStream(m.nextStream)
115	m.streams[m.nextStream] = s
116	m.nextStream++
117	return s
118}
119
120// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset,
121// if we haven't sent one for this offset yet
122func (m *outgoingItemsMap) maybeSendBlockedFrame() {
123	if m.blockedSent {
124		return
125	}
126
127	var streamNum protocol.StreamNum
128	if m.maxStream != protocol.InvalidStreamNum {
129		streamNum = m.maxStream
130	}
131	m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
132		Type:        streamTypeGeneric,
133		StreamLimit: streamNum,
134	})
135	m.blockedSent = true
136}
137
138func (m *outgoingItemsMap) GetStream(num protocol.StreamNum) (item, error) {
139	m.mutex.RLock()
140	if num >= m.nextStream {
141		m.mutex.RUnlock()
142		return nil, streamError{
143			message: "peer attempted to open stream %d",
144			nums:    []protocol.StreamNum{num},
145		}
146	}
147	s := m.streams[num]
148	m.mutex.RUnlock()
149	return s, nil
150}
151
152func (m *outgoingItemsMap) DeleteStream(num protocol.StreamNum) error {
153	m.mutex.Lock()
154	defer m.mutex.Unlock()
155
156	if _, ok := m.streams[num]; !ok {
157		return streamError{
158			message: "tried to delete unknown outgoing stream %d",
159			nums:    []protocol.StreamNum{num},
160		}
161	}
162	delete(m.streams, num)
163	return nil
164}
165
166func (m *outgoingItemsMap) SetMaxStream(num protocol.StreamNum) {
167	m.mutex.Lock()
168	defer m.mutex.Unlock()
169
170	if num <= m.maxStream {
171		return
172	}
173	m.maxStream = num
174	m.blockedSent = false
175	if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) {
176		m.maybeSendBlockedFrame()
177	}
178	m.unblockOpenSync()
179}
180
181// UpdateSendWindow is called when the peer's transport parameters are received.
182// Only in the case of a 0-RTT handshake will we have open streams at this point.
183// We might need to update the send window, in case the server increased it.
184func (m *outgoingItemsMap) UpdateSendWindow(limit protocol.ByteCount) {
185	m.mutex.Lock()
186	for _, str := range m.streams {
187		str.updateSendWindow(limit)
188	}
189	m.mutex.Unlock()
190}
191
192// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream
193func (m *outgoingItemsMap) unblockOpenSync() {
194	if len(m.openQueue) == 0 {
195		return
196	}
197	for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
198		c, ok := m.openQueue[qp]
199		if !ok { // entry was deleted because the context was canceled
200			continue
201		}
202		// unblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
203		// It's sufficient to only unblock OpenStreamSync once.
204		select {
205		case c <- struct{}{}:
206		default:
207		}
208		return
209	}
210}
211
212func (m *outgoingItemsMap) CloseWithError(err error) {
213	m.mutex.Lock()
214	m.closeErr = err
215	for _, str := range m.streams {
216		str.closeForShutdown(err)
217	}
218	for _, c := range m.openQueue {
219		if c != nil {
220			close(c)
221		}
222	}
223	m.mutex.Unlock()
224}
225