1package yamux 2 3import ( 4 "bytes" 5 "io" 6 "sync" 7 "sync/atomic" 8 "time" 9) 10 11type streamState int 12 13const ( 14 streamInit streamState = iota 15 streamSYNSent 16 streamSYNReceived 17 streamEstablished 18 streamLocalClose 19 streamRemoteClose 20 streamClosed 21 streamReset 22) 23 24// Stream is used to represent a logical stream 25// within a session. 26type Stream struct { 27 recvWindow uint32 28 sendWindow uint32 29 30 id uint32 31 session *Session 32 33 state streamState 34 stateLock sync.Mutex 35 36 recvBuf *bytes.Buffer 37 recvLock sync.Mutex 38 39 controlHdr header 40 controlErr chan error 41 controlHdrLock sync.Mutex 42 43 sendHdr header 44 sendErr chan error 45 sendLock sync.Mutex 46 47 recvNotifyCh chan struct{} 48 sendNotifyCh chan struct{} 49 50 readDeadline atomic.Value // time.Time 51 writeDeadline atomic.Value // time.Time 52} 53 54// newStream is used to construct a new stream within 55// a given session for an ID 56func newStream(session *Session, id uint32, state streamState) *Stream { 57 s := &Stream{ 58 id: id, 59 session: session, 60 state: state, 61 controlHdr: header(make([]byte, headerSize)), 62 controlErr: make(chan error, 1), 63 sendHdr: header(make([]byte, headerSize)), 64 sendErr: make(chan error, 1), 65 recvWindow: initialStreamWindow, 66 sendWindow: initialStreamWindow, 67 recvNotifyCh: make(chan struct{}, 1), 68 sendNotifyCh: make(chan struct{}, 1), 69 } 70 s.readDeadline.Store(time.Time{}) 71 s.writeDeadline.Store(time.Time{}) 72 return s 73} 74 75// Session returns the associated stream session 76func (s *Stream) Session() *Session { 77 return s.session 78} 79 80// StreamID returns the ID of this stream 81func (s *Stream) StreamID() uint32 { 82 return s.id 83} 84 85// Read is used to read from the stream 86func (s *Stream) Read(b []byte) (n int, err error) { 87 defer asyncNotify(s.recvNotifyCh) 88START: 89 s.stateLock.Lock() 90 switch s.state { 91 case streamLocalClose: 92 fallthrough 93 case streamRemoteClose: 94 fallthrough 95 case streamClosed: 96 s.recvLock.Lock() 97 if s.recvBuf == nil || s.recvBuf.Len() == 0 { 98 s.recvLock.Unlock() 99 s.stateLock.Unlock() 100 return 0, io.EOF 101 } 102 s.recvLock.Unlock() 103 case streamReset: 104 s.stateLock.Unlock() 105 return 0, ErrConnectionReset 106 } 107 s.stateLock.Unlock() 108 109 // If there is no data available, block 110 s.recvLock.Lock() 111 if s.recvBuf == nil || s.recvBuf.Len() == 0 { 112 s.recvLock.Unlock() 113 goto WAIT 114 } 115 116 // Read any bytes 117 n, _ = s.recvBuf.Read(b) 118 s.recvLock.Unlock() 119 120 // Send a window update potentially 121 err = s.sendWindowUpdate() 122 return n, err 123 124WAIT: 125 var timeout <-chan time.Time 126 var timer *time.Timer 127 readDeadline := s.readDeadline.Load().(time.Time) 128 if !readDeadline.IsZero() { 129 delay := readDeadline.Sub(time.Now()) 130 timer = time.NewTimer(delay) 131 timeout = timer.C 132 } 133 select { 134 case <-s.recvNotifyCh: 135 if timer != nil { 136 timer.Stop() 137 } 138 goto START 139 case <-timeout: 140 return 0, ErrTimeout 141 } 142} 143 144// Write is used to write to the stream 145func (s *Stream) Write(b []byte) (n int, err error) { 146 s.sendLock.Lock() 147 defer s.sendLock.Unlock() 148 total := 0 149 for total < len(b) { 150 n, err := s.write(b[total:]) 151 total += n 152 if err != nil { 153 return total, err 154 } 155 } 156 return total, nil 157} 158 159// write is used to write to the stream, may return on 160// a short write. 161func (s *Stream) write(b []byte) (n int, err error) { 162 var flags uint16 163 var max uint32 164 var body io.Reader 165START: 166 s.stateLock.Lock() 167 switch s.state { 168 case streamLocalClose: 169 fallthrough 170 case streamClosed: 171 s.stateLock.Unlock() 172 return 0, ErrStreamClosed 173 case streamReset: 174 s.stateLock.Unlock() 175 return 0, ErrConnectionReset 176 } 177 s.stateLock.Unlock() 178 179 // If there is no data available, block 180 window := atomic.LoadUint32(&s.sendWindow) 181 if window == 0 { 182 goto WAIT 183 } 184 185 // Determine the flags if any 186 flags = s.sendFlags() 187 188 // Send up to our send window 189 max = min(window, uint32(len(b))) 190 body = bytes.NewReader(b[:max]) 191 192 // Send the header 193 s.sendHdr.encode(typeData, flags, s.id, max) 194 if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil { 195 return 0, err 196 } 197 198 // Reduce our send window 199 atomic.AddUint32(&s.sendWindow, ^uint32(max-1)) 200 201 // Unlock 202 return int(max), err 203 204WAIT: 205 var timeout <-chan time.Time 206 writeDeadline := s.writeDeadline.Load().(time.Time) 207 if !writeDeadline.IsZero() { 208 delay := writeDeadline.Sub(time.Now()) 209 timeout = time.After(delay) 210 } 211 select { 212 case <-s.sendNotifyCh: 213 goto START 214 case <-timeout: 215 return 0, ErrTimeout 216 } 217 return 0, nil 218} 219 220// sendFlags determines any flags that are appropriate 221// based on the current stream state 222func (s *Stream) sendFlags() uint16 { 223 s.stateLock.Lock() 224 defer s.stateLock.Unlock() 225 var flags uint16 226 switch s.state { 227 case streamInit: 228 flags |= flagSYN 229 s.state = streamSYNSent 230 case streamSYNReceived: 231 flags |= flagACK 232 s.state = streamEstablished 233 } 234 return flags 235} 236 237// sendWindowUpdate potentially sends a window update enabling 238// further writes to take place. Must be invoked with the lock. 239func (s *Stream) sendWindowUpdate() error { 240 s.controlHdrLock.Lock() 241 defer s.controlHdrLock.Unlock() 242 243 // Determine the delta update 244 max := s.session.config.MaxStreamWindowSize 245 var bufLen uint32 246 s.recvLock.Lock() 247 if s.recvBuf != nil { 248 bufLen = uint32(s.recvBuf.Len()) 249 } 250 delta := (max - bufLen) - s.recvWindow 251 252 // Determine the flags if any 253 flags := s.sendFlags() 254 255 // Check if we can omit the update 256 if delta < (max/2) && flags == 0 { 257 s.recvLock.Unlock() 258 return nil 259 } 260 261 // Update our window 262 s.recvWindow += delta 263 s.recvLock.Unlock() 264 265 // Send the header 266 s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) 267 if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { 268 return err 269 } 270 return nil 271} 272 273// sendClose is used to send a FIN 274func (s *Stream) sendClose() error { 275 s.controlHdrLock.Lock() 276 defer s.controlHdrLock.Unlock() 277 278 flags := s.sendFlags() 279 flags |= flagFIN 280 s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0) 281 if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { 282 return err 283 } 284 return nil 285} 286 287// Close is used to close the stream 288func (s *Stream) Close() error { 289 closeStream := false 290 s.stateLock.Lock() 291 switch s.state { 292 // Opened means we need to signal a close 293 case streamSYNSent: 294 fallthrough 295 case streamSYNReceived: 296 fallthrough 297 case streamEstablished: 298 s.state = streamLocalClose 299 goto SEND_CLOSE 300 301 case streamLocalClose: 302 case streamRemoteClose: 303 s.state = streamClosed 304 closeStream = true 305 goto SEND_CLOSE 306 307 case streamClosed: 308 case streamReset: 309 default: 310 panic("unhandled state") 311 } 312 s.stateLock.Unlock() 313 return nil 314SEND_CLOSE: 315 s.stateLock.Unlock() 316 s.sendClose() 317 s.notifyWaiting() 318 if closeStream { 319 s.session.closeStream(s.id) 320 } 321 return nil 322} 323 324// forceClose is used for when the session is exiting 325func (s *Stream) forceClose() { 326 s.stateLock.Lock() 327 s.state = streamClosed 328 s.stateLock.Unlock() 329 s.notifyWaiting() 330} 331 332// processFlags is used to update the state of the stream 333// based on set flags, if any. Lock must be held 334func (s *Stream) processFlags(flags uint16) error { 335 // Close the stream without holding the state lock 336 closeStream := false 337 defer func() { 338 if closeStream { 339 s.session.closeStream(s.id) 340 } 341 }() 342 343 s.stateLock.Lock() 344 defer s.stateLock.Unlock() 345 if flags&flagACK == flagACK { 346 if s.state == streamSYNSent { 347 s.state = streamEstablished 348 } 349 s.session.establishStream(s.id) 350 } 351 if flags&flagFIN == flagFIN { 352 switch s.state { 353 case streamSYNSent: 354 fallthrough 355 case streamSYNReceived: 356 fallthrough 357 case streamEstablished: 358 s.state = streamRemoteClose 359 s.notifyWaiting() 360 case streamLocalClose: 361 s.state = streamClosed 362 closeStream = true 363 s.notifyWaiting() 364 default: 365 s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state) 366 return ErrUnexpectedFlag 367 } 368 } 369 if flags&flagRST == flagRST { 370 s.state = streamReset 371 closeStream = true 372 s.notifyWaiting() 373 } 374 return nil 375} 376 377// notifyWaiting notifies all the waiting channels 378func (s *Stream) notifyWaiting() { 379 asyncNotify(s.recvNotifyCh) 380 asyncNotify(s.sendNotifyCh) 381} 382 383// incrSendWindow updates the size of our send window 384func (s *Stream) incrSendWindow(hdr header, flags uint16) error { 385 if err := s.processFlags(flags); err != nil { 386 return err 387 } 388 389 // Increase window, unblock a sender 390 atomic.AddUint32(&s.sendWindow, hdr.Length()) 391 asyncNotify(s.sendNotifyCh) 392 return nil 393} 394 395// readData is used to handle a data frame 396func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { 397 if err := s.processFlags(flags); err != nil { 398 return err 399 } 400 401 // Check that our recv window is not exceeded 402 length := hdr.Length() 403 if length == 0 { 404 return nil 405 } 406 407 // Wrap in a limited reader 408 conn = &io.LimitedReader{R: conn, N: int64(length)} 409 410 // Copy into buffer 411 s.recvLock.Lock() 412 413 if length > s.recvWindow { 414 s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length) 415 return ErrRecvWindowExceeded 416 } 417 418 if s.recvBuf == nil { 419 // Allocate the receive buffer just-in-time to fit the full data frame. 420 // This way we can read in the whole packet without further allocations. 421 s.recvBuf = bytes.NewBuffer(make([]byte, 0, length)) 422 } 423 if _, err := io.Copy(s.recvBuf, conn); err != nil { 424 s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err) 425 s.recvLock.Unlock() 426 return err 427 } 428 429 // Decrement the receive window 430 s.recvWindow -= length 431 s.recvLock.Unlock() 432 433 // Unblock any readers 434 asyncNotify(s.recvNotifyCh) 435 return nil 436} 437 438// SetDeadline sets the read and write deadlines 439func (s *Stream) SetDeadline(t time.Time) error { 440 if err := s.SetReadDeadline(t); err != nil { 441 return err 442 } 443 if err := s.SetWriteDeadline(t); err != nil { 444 return err 445 } 446 return nil 447} 448 449// SetReadDeadline sets the deadline for future Read calls. 450func (s *Stream) SetReadDeadline(t time.Time) error { 451 s.readDeadline.Store(t) 452 return nil 453} 454 455// SetWriteDeadline sets the deadline for future Write calls 456func (s *Stream) SetWriteDeadline(t time.Time) error { 457 s.writeDeadline.Store(t) 458 return nil 459} 460 461// Shrink is used to compact the amount of buffers utilized 462// This is useful when using Yamux in a connection pool to reduce 463// the idle memory utilization. 464func (s *Stream) Shrink() { 465 s.recvLock.Lock() 466 if s.recvBuf != nil && s.recvBuf.Len() == 0 { 467 s.recvBuf = nil 468 } 469 s.recvLock.Unlock() 470} 471