1package quic
2
3import (
4	"context"
5	"fmt"
6	"sync"
7	"time"
8
9	"github.com/lucas-clemente/quic-go/internal/ackhandler"
10	"github.com/lucas-clemente/quic-go/internal/flowcontrol"
11	"github.com/lucas-clemente/quic-go/internal/protocol"
12	"github.com/lucas-clemente/quic-go/internal/qerr"
13	"github.com/lucas-clemente/quic-go/internal/utils"
14	"github.com/lucas-clemente/quic-go/internal/wire"
15)
16
// sendStreamI is the internal view of a send stream: the public SendStream API
// plus the hooks the connection uses to drive it (control-frame handling,
// packetization, shutdown and flow-control updates).
type sendStreamI interface {
	SendStream
	// handleStopSendingFrame processes a received STOP_SENDING frame.
	handleStopSendingFrame(*wire.StopSendingFrame)
	// hasData reports whether the stream has new data queued for sending.
	hasData() bool
	// popStreamFrame returns the next frame to send, at most maxBytes long
	// (including the frame header), and whether more data remains.
	popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool)
	// closeForShutdown closes the stream abruptly with the given error.
	closeForShutdown(error)
	// updateSendWindow passes a new flow-control send limit to the stream.
	updateSendWindow(protocol.ByteCount)
}
25
// sendStream implements the sending half of a QUIC stream. Write hands user
// data to the packet packer, which pulls it out via popStreamFrame; the
// ackhandler callbacks frameAcked and queueRetransmission track STREAM frames
// that were sent but not yet acknowledged.
//
// The mutable state below is accessed while holding mutex. Some sender
// callbacks (onHasStreamData) must be invoked without holding it.
type sendStream struct {
	mutex sync.Mutex

	// Number of STREAM frames handed out by popStreamFrame that have been
	// neither acknowledged nor declared lost.
	numOutstandingFrames int64
	// Lost STREAM frames waiting to be sent again; served before new data.
	retransmissionQueue  []*wire.StreamFrame

	// ctx is canceled when the stream is closed, canceled, or shut down.
	ctx       context.Context
	ctxCancel context.CancelFunc

	streamID protocol.StreamID
	sender   streamSender

	// Offset of the next byte of new data to be popped into a STREAM frame.
	writeOffset protocol.ByteCount

	cancelWriteErr      error // returned by Write once the write side was canceled
	closeForShutdownErr error // returned by Write once closeForShutdown was called

	closedForShutdown bool // set when CloseForShutdown() is called
	finishedWriting   bool // set once Close() is called
	canceledWrite     bool // set when CancelWrite() is called, or a STOP_SENDING frame is received
	finSent           bool // set when a STREAM_FRAME with FIN bit has been sent
	completed         bool // set when this stream has been reported to the streamSender as completed

	dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
	// Data already copied out of a Write call and waiting to be popped
	// (see the comment inside Write for why this buffering exists).
	nextFrame      *wire.StreamFrame

	// Wakes up a Write call blocked waiting for progress.
	writeChan chan struct{}
	// Write deadline; the zero value means no deadline.
	deadline  time.Time

	flowController flowcontrol.StreamFlowController

	version protocol.VersionNumber
}
59
60var (
61	_ SendStream  = &sendStream{}
62	_ sendStreamI = &sendStream{}
63)
64
65func newSendStream(
66	streamID protocol.StreamID,
67	sender streamSender,
68	flowController flowcontrol.StreamFlowController,
69	version protocol.VersionNumber,
70) *sendStream {
71	s := &sendStream{
72		streamID:       streamID,
73		sender:         sender,
74		flowController: flowController,
75		writeChan:      make(chan struct{}, 1),
76		version:        version,
77	}
78	s.ctx, s.ctxCancel = context.WithCancel(context.Background())
79	return s
80}
81
82func (s *sendStream) StreamID() protocol.StreamID {
83	return s.streamID // same for receiveStream and sendStream
84}
85
86func (s *sendStream) Write(p []byte) (int, error) {
87	s.mutex.Lock()
88	defer s.mutex.Unlock()
89
90	if s.finishedWriting {
91		return 0, fmt.Errorf("write on closed stream %d", s.streamID)
92	}
93	if s.canceledWrite {
94		return 0, s.cancelWriteErr
95	}
96	if s.closeForShutdownErr != nil {
97		return 0, s.closeForShutdownErr
98	}
99	if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
100		return 0, errDeadline
101	}
102	if len(p) == 0 {
103		return 0, nil
104	}
105
106	s.dataForWriting = p
107
108	var (
109		deadlineTimer  *utils.Timer
110		bytesWritten   int
111		notifiedSender bool
112	)
113	for {
114		var copied bool
115		var deadline time.Time
116		// As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame),
117		// which can the be popped the next time we assemble a packet.
118		// This allows us to return Write() when all data but x bytes have been sent out.
119		// When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame,
120		// allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN).
121		if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 {
122			if s.nextFrame == nil {
123				f := wire.GetStreamFrame()
124				f.Offset = s.writeOffset
125				f.StreamID = s.streamID
126				f.DataLenPresent = true
127				f.Data = f.Data[:len(s.dataForWriting)]
128				copy(f.Data, s.dataForWriting)
129				s.nextFrame = f
130			} else {
131				l := len(s.nextFrame.Data)
132				s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)]
133				copy(s.nextFrame.Data[l:], s.dataForWriting)
134			}
135			s.dataForWriting = nil
136			bytesWritten = len(p)
137			copied = true
138		} else {
139			bytesWritten = len(p) - len(s.dataForWriting)
140			deadline = s.deadline
141			if !deadline.IsZero() {
142				if !time.Now().Before(deadline) {
143					s.dataForWriting = nil
144					return bytesWritten, errDeadline
145				}
146				if deadlineTimer == nil {
147					deadlineTimer = utils.NewTimer()
148					defer deadlineTimer.Stop()
149				}
150				deadlineTimer.Reset(deadline)
151			}
152			if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown {
153				break
154			}
155		}
156
157		s.mutex.Unlock()
158		if !notifiedSender {
159			s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
160			notifiedSender = true
161		}
162		if copied {
163			s.mutex.Lock()
164			break
165		}
166		if deadline.IsZero() {
167			<-s.writeChan
168		} else {
169			select {
170			case <-s.writeChan:
171			case <-deadlineTimer.Chan():
172				deadlineTimer.SetRead()
173			}
174		}
175		s.mutex.Lock()
176	}
177
178	if bytesWritten == len(p) {
179		return bytesWritten, nil
180	}
181	if s.closeForShutdownErr != nil {
182		return bytesWritten, s.closeForShutdownErr
183	} else if s.cancelWriteErr != nil {
184		return bytesWritten, s.cancelWriteErr
185	}
186	return bytesWritten, nil
187}
188
189func (s *sendStream) canBufferStreamFrame() bool {
190	var l protocol.ByteCount
191	if s.nextFrame != nil {
192		l = s.nextFrame.DataLen()
193	}
194	return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize
195}
196
197// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
198// maxBytes is the maximum length this frame (including frame header) will have.
199func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool /* has more data to send */) {
200	s.mutex.Lock()
201	f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes)
202	if f != nil {
203		s.numOutstandingFrames++
204	}
205	s.mutex.Unlock()
206
207	if f == nil {
208		return nil, hasMoreData
209	}
210	return &ackhandler.Frame{Frame: f, OnLost: s.queueRetransmission, OnAcked: s.frameAcked}, hasMoreData
211}
212
213func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) {
214	if s.canceledWrite || s.closeForShutdownErr != nil {
215		return nil, false
216	}
217
218	if len(s.retransmissionQueue) > 0 {
219		f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes)
220		if f != nil || hasMoreRetransmissions {
221			if f == nil {
222				return nil, true
223			}
224			// We always claim that we have more data to send.
225			// This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
226			return f, true
227		}
228	}
229
230	if len(s.dataForWriting) == 0 && s.nextFrame == nil {
231		if s.finishedWriting && !s.finSent {
232			s.finSent = true
233			return &wire.StreamFrame{
234				StreamID:       s.streamID,
235				Offset:         s.writeOffset,
236				DataLenPresent: true,
237				Fin:            true,
238			}, false
239		}
240		return nil, false
241	}
242
243	sendWindow := s.flowController.SendWindowSize()
244	if sendWindow == 0 {
245		if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
246			s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
247				StreamID:          s.streamID,
248				MaximumStreamData: offset,
249			})
250			return nil, false
251		}
252		return nil, true
253	}
254
255	f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow)
256	if dataLen := f.DataLen(); dataLen > 0 {
257		s.writeOffset += f.DataLen()
258		s.flowController.AddBytesSent(f.DataLen())
259	}
260	f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
261	if f.Fin {
262		s.finSent = true
263	}
264	return f, hasMoreData
265}
266
267func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) (*wire.StreamFrame, bool) {
268	if s.nextFrame != nil {
269		nextFrame := s.nextFrame
270		s.nextFrame = nil
271
272		maxDataLen := utils.MinByteCount(sendWindow, nextFrame.MaxDataLen(maxBytes, s.version))
273		if nextFrame.DataLen() > maxDataLen {
274			s.nextFrame = wire.GetStreamFrame()
275			s.nextFrame.StreamID = s.streamID
276			s.nextFrame.Offset = s.writeOffset + maxDataLen
277			s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen]
278			s.nextFrame.DataLenPresent = true
279			copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:])
280			nextFrame.Data = nextFrame.Data[:maxDataLen]
281		} else {
282			s.signalWrite()
283		}
284		return nextFrame, s.nextFrame != nil || s.dataForWriting != nil
285	}
286
287	f := wire.GetStreamFrame()
288	f.Fin = false
289	f.StreamID = s.streamID
290	f.Offset = s.writeOffset
291	f.DataLenPresent = true
292	f.Data = f.Data[:0]
293
294	hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow)
295	if len(f.Data) == 0 && !f.Fin {
296		f.PutBack()
297		return nil, hasMoreData
298	}
299	return f, hasMoreData
300}
301
302func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount) bool {
303	maxDataLen := f.MaxDataLen(maxBytes, s.version)
304	if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
305		return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
306	}
307	s.getDataForWriting(f, utils.MinByteCount(maxDataLen, sendWindow))
308
309	return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
310}
311
312func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more retransmissions */) {
313	f := s.retransmissionQueue[0]
314	newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, s.version)
315	if needsSplit {
316		return newFrame, true
317	}
318	s.retransmissionQueue = s.retransmissionQueue[1:]
319	return f, len(s.retransmissionQueue) > 0
320}
321
322func (s *sendStream) hasData() bool {
323	s.mutex.Lock()
324	hasData := len(s.dataForWriting) > 0
325	s.mutex.Unlock()
326	return hasData
327}
328
329func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
330	if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
331		f.Data = f.Data[:len(s.dataForWriting)]
332		copy(f.Data, s.dataForWriting)
333		s.dataForWriting = nil
334		s.signalWrite()
335		return
336	}
337	f.Data = f.Data[:maxBytes]
338	copy(f.Data, s.dataForWriting)
339	s.dataForWriting = s.dataForWriting[maxBytes:]
340	if s.canBufferStreamFrame() {
341		s.signalWrite()
342	}
343}
344
345func (s *sendStream) frameAcked(f wire.Frame) {
346	f.(*wire.StreamFrame).PutBack()
347
348	s.mutex.Lock()
349	if s.canceledWrite {
350		s.mutex.Unlock()
351		return
352	}
353	s.numOutstandingFrames--
354	if s.numOutstandingFrames < 0 {
355		panic("numOutStandingFrames negative")
356	}
357	newlyCompleted := s.isNewlyCompleted()
358	s.mutex.Unlock()
359
360	if newlyCompleted {
361		s.sender.onStreamCompleted(s.streamID)
362	}
363}
364
365func (s *sendStream) isNewlyCompleted() bool {
366	completed := (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
367	if completed && !s.completed {
368		s.completed = true
369		return true
370	}
371	return false
372}
373
374func (s *sendStream) queueRetransmission(f wire.Frame) {
375	sf := f.(*wire.StreamFrame)
376	sf.DataLenPresent = true
377	s.mutex.Lock()
378	if s.canceledWrite {
379		s.mutex.Unlock()
380		return
381	}
382	s.retransmissionQueue = append(s.retransmissionQueue, sf)
383	s.numOutstandingFrames--
384	if s.numOutstandingFrames < 0 {
385		panic("numOutStandingFrames negative")
386	}
387	s.mutex.Unlock()
388
389	s.sender.onHasStreamData(s.streamID)
390}
391
392func (s *sendStream) Close() error {
393	s.mutex.Lock()
394	if s.closedForShutdown {
395		s.mutex.Unlock()
396		return nil
397	}
398	if s.canceledWrite {
399		s.mutex.Unlock()
400		return fmt.Errorf("close called for canceled stream %d", s.streamID)
401	}
402	s.ctxCancel()
403	s.finishedWriting = true
404	s.mutex.Unlock()
405
406	s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
407	return nil
408}
409
410func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
411	s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode))
412}
413
414// must be called after locking the mutex
415func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) {
416	s.mutex.Lock()
417	if s.canceledWrite {
418		s.mutex.Unlock()
419		return
420	}
421	s.ctxCancel()
422	s.canceledWrite = true
423	s.cancelWriteErr = writeErr
424	s.numOutstandingFrames = 0
425	s.retransmissionQueue = nil
426	newlyCompleted := s.isNewlyCompleted()
427	s.mutex.Unlock()
428
429	s.signalWrite()
430	s.sender.queueControlFrame(&wire.ResetStreamFrame{
431		StreamID:  s.streamID,
432		FinalSize: s.writeOffset,
433		ErrorCode: errorCode,
434	})
435	if newlyCompleted {
436		s.sender.onStreamCompleted(s.streamID)
437	}
438}
439
440func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
441	s.mutex.Lock()
442	hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
443	s.mutex.Unlock()
444
445	s.flowController.UpdateSendWindow(limit)
446	if hasStreamData {
447		s.sender.onHasStreamData(s.streamID)
448	}
449}
450
451func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
452	s.cancelWriteImpl(frame.ErrorCode, &StreamError{
453		StreamID:  s.streamID,
454		ErrorCode: frame.ErrorCode,
455	})
456}
457
458func (s *sendStream) Context() context.Context {
459	return s.ctx
460}
461
462func (s *sendStream) SetWriteDeadline(t time.Time) error {
463	s.mutex.Lock()
464	s.deadline = t
465	s.mutex.Unlock()
466	s.signalWrite()
467	return nil
468}
469
470// CloseForShutdown closes a stream abruptly.
471// It makes Write unblock (and return the error) immediately.
472// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
473func (s *sendStream) closeForShutdown(err error) {
474	s.mutex.Lock()
475	s.ctxCancel()
476	s.closedForShutdown = true
477	s.closeForShutdownErr = err
478	s.mutex.Unlock()
479	s.signalWrite()
480}
481
482// signalWrite performs a non-blocking send on the writeChan
483func (s *sendStream) signalWrite() {
484	select {
485	case s.writeChan <- struct{}{}:
486	default:
487	}
488}
489