1package yamux
2
3import (
4	"bytes"
5	"io"
6	"sync"
7	"sync/atomic"
8	"time"
9)
10
// streamState is the lifecycle state of a Stream. Every read and write of
// Stream.state in this file happens while Stream.stateLock is held.
type streamState int

const (
	// streamInit: nothing has been sent yet. The next outgoing frame carries
	// SYN and sendFlags moves the stream to streamSYNSent.
	streamInit streamState = iota
	// streamSYNSent: SYN has been sent; an incoming ACK (processFlags) moves
	// the stream to streamEstablished.
	streamSYNSent
	// streamSYNReceived: the next outgoing frame carries ACK and sendFlags
	// moves the stream to streamEstablished.
	streamSYNReceived
	// streamEstablished: open in both directions.
	streamEstablished
	// streamLocalClose: Close has sent our FIN. Writes fail with
	// ErrStreamClosed; reads drain the buffer and then return io.EOF. A FIN
	// from the peer moves the stream to streamClosed.
	streamLocalClose
	// streamRemoteClose: the peer's FIN was received. Reads drain the buffer
	// and then return io.EOF; writes are still accepted. Close moves the
	// stream to streamClosed.
	streamRemoteClose
	// streamClosed: both FINs exchanged, or set directly by forceClose.
	streamClosed
	// streamReset: an RST was received; Read and Write return
	// ErrConnectionReset.
	streamReset
)
23
// Stream is used to represent a logical stream
// within a session.
//
// Locking, as used by the methods in this file: state is guarded by
// stateLock; recvBuf and recvWindow by recvLock; controlHdr by
// controlHdrLock; sendHdr by sendLock (held for the whole of Write).
// sendWindow is only accessed through sync/atomic. When Read holds both
// stateLock and recvLock, it takes stateLock first.
type Stream struct {
	// recvWindow is how many more payload bytes the peer may send us.
	// readData rejects larger frames and decrements it; sendWindowUpdate
	// grows it again. Guarded by recvLock.
	recvWindow uint32
	// sendWindow is how many payload bytes we may still send. write
	// consumes it and incrSendWindow grows it; accessed only atomically.
	sendWindow uint32

	id      uint32
	session *Session

	// state is the lifecycle state, guarded by stateLock.
	state     streamState
	stateLock sync.Mutex

	// recvBuf holds received data not yet consumed by Read. It stays nil
	// until readData allocates it and Shrink may reset it to nil when empty.
	// Guarded by recvLock.
	recvBuf  *bytes.Buffer
	recvLock sync.Mutex

	// controlHdr is the reusable header for window-update/FIN frames sent by
	// sendWindowUpdate and sendClose; controlErr is handed to
	// session.waitForSendErr alongside it. Guarded by controlHdrLock.
	controlHdr     header
	controlErr     chan error
	controlHdrLock sync.Mutex

	// sendHdr is the reusable header for data frames, with sendErr handed to
	// session.waitForSendErr alongside it. Guarded by sendLock, which Write
	// holds for its whole duration.
	sendHdr  header
	sendErr  chan error
	sendLock sync.Mutex

	// recvNotifyCh wakes a Read waiting for data; sendNotifyCh wakes a write
	// waiting for send window. newStream gives each a buffer of one.
	recvNotifyCh chan struct{}
	sendNotifyCh chan struct{}

	// Deadlines for Read and write; the zero time means "no deadline".
	readDeadline  atomic.Value // time.Time
	writeDeadline atomic.Value // time.Time
}
53
54// newStream is used to construct a new stream within
55// a given session for an ID
56func newStream(session *Session, id uint32, state streamState) *Stream {
57	s := &Stream{
58		id:           id,
59		session:      session,
60		state:        state,
61		controlHdr:   header(make([]byte, headerSize)),
62		controlErr:   make(chan error, 1),
63		sendHdr:      header(make([]byte, headerSize)),
64		sendErr:      make(chan error, 1),
65		recvWindow:   initialStreamWindow,
66		sendWindow:   initialStreamWindow,
67		recvNotifyCh: make(chan struct{}, 1),
68		sendNotifyCh: make(chan struct{}, 1),
69	}
70	s.readDeadline.Store(time.Time{})
71	s.writeDeadline.Store(time.Time{})
72	return s
73}
74
75// Session returns the associated stream session
76func (s *Stream) Session() *Session {
77	return s.session
78}
79
80// StreamID returns the ID of this stream
81func (s *Stream) StreamID() uint32 {
82	return s.id
83}
84
85// Read is used to read from the stream
86func (s *Stream) Read(b []byte) (n int, err error) {
87	defer asyncNotify(s.recvNotifyCh)
88START:
89	s.stateLock.Lock()
90	switch s.state {
91	case streamLocalClose:
92		fallthrough
93	case streamRemoteClose:
94		fallthrough
95	case streamClosed:
96		s.recvLock.Lock()
97		if s.recvBuf == nil || s.recvBuf.Len() == 0 {
98			s.recvLock.Unlock()
99			s.stateLock.Unlock()
100			return 0, io.EOF
101		}
102		s.recvLock.Unlock()
103	case streamReset:
104		s.stateLock.Unlock()
105		return 0, ErrConnectionReset
106	}
107	s.stateLock.Unlock()
108
109	// If there is no data available, block
110	s.recvLock.Lock()
111	if s.recvBuf == nil || s.recvBuf.Len() == 0 {
112		s.recvLock.Unlock()
113		goto WAIT
114	}
115
116	// Read any bytes
117	n, _ = s.recvBuf.Read(b)
118	s.recvLock.Unlock()
119
120	// Send a window update potentially
121	err = s.sendWindowUpdate()
122	return n, err
123
124WAIT:
125	var timeout <-chan time.Time
126	var timer *time.Timer
127	readDeadline := s.readDeadline.Load().(time.Time)
128	if !readDeadline.IsZero() {
129		delay := readDeadline.Sub(time.Now())
130		timer = time.NewTimer(delay)
131		timeout = timer.C
132	}
133	select {
134	case <-s.recvNotifyCh:
135		if timer != nil {
136			timer.Stop()
137		}
138		goto START
139	case <-timeout:
140		return 0, ErrTimeout
141	}
142}
143
144// Write is used to write to the stream
145func (s *Stream) Write(b []byte) (n int, err error) {
146	s.sendLock.Lock()
147	defer s.sendLock.Unlock()
148	total := 0
149	for total < len(b) {
150		n, err := s.write(b[total:])
151		total += n
152		if err != nil {
153			return total, err
154		}
155	}
156	return total, nil
157}
158
159// write is used to write to the stream, may return on
160// a short write.
161func (s *Stream) write(b []byte) (n int, err error) {
162	var flags uint16
163	var max uint32
164	var body io.Reader
165START:
166	s.stateLock.Lock()
167	switch s.state {
168	case streamLocalClose:
169		fallthrough
170	case streamClosed:
171		s.stateLock.Unlock()
172		return 0, ErrStreamClosed
173	case streamReset:
174		s.stateLock.Unlock()
175		return 0, ErrConnectionReset
176	}
177	s.stateLock.Unlock()
178
179	// If there is no data available, block
180	window := atomic.LoadUint32(&s.sendWindow)
181	if window == 0 {
182		goto WAIT
183	}
184
185	// Determine the flags if any
186	flags = s.sendFlags()
187
188	// Send up to our send window
189	max = min(window, uint32(len(b)))
190	body = bytes.NewReader(b[:max])
191
192	// Send the header
193	s.sendHdr.encode(typeData, flags, s.id, max)
194	if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
195		return 0, err
196	}
197
198	// Reduce our send window
199	atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
200
201	// Unlock
202	return int(max), err
203
204WAIT:
205	var timeout <-chan time.Time
206	writeDeadline := s.writeDeadline.Load().(time.Time)
207	if !writeDeadline.IsZero() {
208		delay := writeDeadline.Sub(time.Now())
209		timeout = time.After(delay)
210	}
211	select {
212	case <-s.sendNotifyCh:
213		goto START
214	case <-timeout:
215		return 0, ErrTimeout
216	}
217	return 0, nil
218}
219
220// sendFlags determines any flags that are appropriate
221// based on the current stream state
222func (s *Stream) sendFlags() uint16 {
223	s.stateLock.Lock()
224	defer s.stateLock.Unlock()
225	var flags uint16
226	switch s.state {
227	case streamInit:
228		flags |= flagSYN
229		s.state = streamSYNSent
230	case streamSYNReceived:
231		flags |= flagACK
232		s.state = streamEstablished
233	}
234	return flags
235}
236
237// sendWindowUpdate potentially sends a window update enabling
238// further writes to take place. Must be invoked with the lock.
239func (s *Stream) sendWindowUpdate() error {
240	s.controlHdrLock.Lock()
241	defer s.controlHdrLock.Unlock()
242
243	// Determine the delta update
244	max := s.session.config.MaxStreamWindowSize
245	var bufLen uint32
246	s.recvLock.Lock()
247	if s.recvBuf != nil {
248		bufLen = uint32(s.recvBuf.Len())
249	}
250	delta := (max - bufLen) - s.recvWindow
251
252	// Determine the flags if any
253	flags := s.sendFlags()
254
255	// Check if we can omit the update
256	if delta < (max/2) && flags == 0 {
257		s.recvLock.Unlock()
258		return nil
259	}
260
261	// Update our window
262	s.recvWindow += delta
263	s.recvLock.Unlock()
264
265	// Send the header
266	s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
267	if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
268		return err
269	}
270	return nil
271}
272
273// sendClose is used to send a FIN
274func (s *Stream) sendClose() error {
275	s.controlHdrLock.Lock()
276	defer s.controlHdrLock.Unlock()
277
278	flags := s.sendFlags()
279	flags |= flagFIN
280	s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
281	if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
282		return err
283	}
284	return nil
285}
286
287// Close is used to close the stream
288func (s *Stream) Close() error {
289	closeStream := false
290	s.stateLock.Lock()
291	switch s.state {
292	// Opened means we need to signal a close
293	case streamSYNSent:
294		fallthrough
295	case streamSYNReceived:
296		fallthrough
297	case streamEstablished:
298		s.state = streamLocalClose
299		goto SEND_CLOSE
300
301	case streamLocalClose:
302	case streamRemoteClose:
303		s.state = streamClosed
304		closeStream = true
305		goto SEND_CLOSE
306
307	case streamClosed:
308	case streamReset:
309	default:
310		panic("unhandled state")
311	}
312	s.stateLock.Unlock()
313	return nil
314SEND_CLOSE:
315	s.stateLock.Unlock()
316	s.sendClose()
317	s.notifyWaiting()
318	if closeStream {
319		s.session.closeStream(s.id)
320	}
321	return nil
322}
323
324// forceClose is used for when the session is exiting
325func (s *Stream) forceClose() {
326	s.stateLock.Lock()
327	s.state = streamClosed
328	s.stateLock.Unlock()
329	s.notifyWaiting()
330}
331
332// processFlags is used to update the state of the stream
333// based on set flags, if any. Lock must be held
334func (s *Stream) processFlags(flags uint16) error {
335	// Close the stream without holding the state lock
336	closeStream := false
337	defer func() {
338		if closeStream {
339			s.session.closeStream(s.id)
340		}
341	}()
342
343	s.stateLock.Lock()
344	defer s.stateLock.Unlock()
345	if flags&flagACK == flagACK {
346		if s.state == streamSYNSent {
347			s.state = streamEstablished
348		}
349		s.session.establishStream(s.id)
350	}
351	if flags&flagFIN == flagFIN {
352		switch s.state {
353		case streamSYNSent:
354			fallthrough
355		case streamSYNReceived:
356			fallthrough
357		case streamEstablished:
358			s.state = streamRemoteClose
359			s.notifyWaiting()
360		case streamLocalClose:
361			s.state = streamClosed
362			closeStream = true
363			s.notifyWaiting()
364		default:
365			s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
366			return ErrUnexpectedFlag
367		}
368	}
369	if flags&flagRST == flagRST {
370		s.state = streamReset
371		closeStream = true
372		s.notifyWaiting()
373	}
374	return nil
375}
376
377// notifyWaiting notifies all the waiting channels
378func (s *Stream) notifyWaiting() {
379	asyncNotify(s.recvNotifyCh)
380	asyncNotify(s.sendNotifyCh)
381}
382
383// incrSendWindow updates the size of our send window
384func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
385	if err := s.processFlags(flags); err != nil {
386		return err
387	}
388
389	// Increase window, unblock a sender
390	atomic.AddUint32(&s.sendWindow, hdr.Length())
391	asyncNotify(s.sendNotifyCh)
392	return nil
393}
394
395// readData is used to handle a data frame
396func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
397	if err := s.processFlags(flags); err != nil {
398		return err
399	}
400
401	// Check that our recv window is not exceeded
402	length := hdr.Length()
403	if length == 0 {
404		return nil
405	}
406
407	// Wrap in a limited reader
408	conn = &io.LimitedReader{R: conn, N: int64(length)}
409
410	// Copy into buffer
411	s.recvLock.Lock()
412
413	if length > s.recvWindow {
414		s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
415		return ErrRecvWindowExceeded
416	}
417
418	if s.recvBuf == nil {
419		// Allocate the receive buffer just-in-time to fit the full data frame.
420		// This way we can read in the whole packet without further allocations.
421		s.recvBuf = bytes.NewBuffer(make([]byte, 0, length))
422	}
423	if _, err := io.Copy(s.recvBuf, conn); err != nil {
424		s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
425		s.recvLock.Unlock()
426		return err
427	}
428
429	// Decrement the receive window
430	s.recvWindow -= length
431	s.recvLock.Unlock()
432
433	// Unblock any readers
434	asyncNotify(s.recvNotifyCh)
435	return nil
436}
437
438// SetDeadline sets the read and write deadlines
439func (s *Stream) SetDeadline(t time.Time) error {
440	if err := s.SetReadDeadline(t); err != nil {
441		return err
442	}
443	if err := s.SetWriteDeadline(t); err != nil {
444		return err
445	}
446	return nil
447}
448
449// SetReadDeadline sets the deadline for future Read calls.
450func (s *Stream) SetReadDeadline(t time.Time) error {
451	s.readDeadline.Store(t)
452	return nil
453}
454
455// SetWriteDeadline sets the deadline for future Write calls
456func (s *Stream) SetWriteDeadline(t time.Time) error {
457	s.writeDeadline.Store(t)
458	return nil
459}
460
461// Shrink is used to compact the amount of buffers utilized
462// This is useful when using Yamux in a connection pool to reduce
463// the idle memory utilization.
464func (s *Stream) Shrink() {
465	s.recvLock.Lock()
466	if s.recvBuf != nil && s.recvBuf.Len() == 0 {
467		s.recvBuf = nil
468	}
469	s.recvLock.Unlock()
470}
471