1package yamux 2 3import ( 4 "bufio" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "log" 9 "math" 10 "net" 11 "strings" 12 "sync" 13 "sync/atomic" 14 "time" 15) 16 17// Session is used to wrap a reliable ordered connection and to 18// multiplex it into multiple streams. 19type Session struct { 20 // remoteGoAway indicates the remote side does 21 // not want futher connections. Must be first for alignment. 22 remoteGoAway int32 23 24 // localGoAway indicates that we should stop 25 // accepting futher connections. Must be first for alignment. 26 localGoAway int32 27 28 // nextStreamID is the next stream we should 29 // send. This depends if we are a client/server. 30 nextStreamID uint32 31 32 // config holds our configuration 33 config *Config 34 35 // logger is used for our logs 36 logger *log.Logger 37 38 // conn is the underlying connection 39 conn io.ReadWriteCloser 40 41 // bufRead is a buffered reader 42 bufRead *bufio.Reader 43 44 // pings is used to track inflight pings 45 pings map[uint32]chan struct{} 46 pingID uint32 47 pingLock sync.Mutex 48 49 // streams maps a stream id to a stream, and inflight has an entry 50 // for any outgoing stream that has not yet been established. Both are 51 // protected by streamLock. 52 streams map[uint32]*Stream 53 inflight map[uint32]struct{} 54 streamLock sync.Mutex 55 56 // synCh acts like a semaphore. It is sized to the AcceptBacklog which 57 // is assumed to be symmetric between the client and server. This allows 58 // the client to avoid exceeding the backlog and instead blocks the open. 59 synCh chan struct{} 60 61 // acceptCh is used to pass ready streams to the client 62 acceptCh chan *Stream 63 64 // sendCh is used to mark a stream as ready to send, 65 // or to send a header out directly. 66 sendCh chan sendReady 67 68 // recvDoneCh is closed when recv() exits to avoid a race 69 // between stream registration and stream shutdown 70 recvDoneCh chan struct{} 71 72 // shutdown is used to safely close a session 73 shutdown bool 74 shutdownErr error 75 shutdownCh chan struct{} 76 shutdownLock sync.Mutex 77} 78 79// sendReady is used to either mark a stream as ready 80// or to directly send a header 81type sendReady struct { 82 Hdr []byte 83 Body io.Reader 84 Err chan error 85} 86 87// newSession is used to construct a new session 88func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { 89 logger := config.Logger 90 if logger == nil { 91 logger = log.New(config.LogOutput, "", log.LstdFlags) 92 } 93 94 s := &Session{ 95 config: config, 96 logger: logger, 97 conn: conn, 98 bufRead: bufio.NewReader(conn), 99 pings: make(map[uint32]chan struct{}), 100 streams: make(map[uint32]*Stream), 101 inflight: make(map[uint32]struct{}), 102 synCh: make(chan struct{}, config.AcceptBacklog), 103 acceptCh: make(chan *Stream, config.AcceptBacklog), 104 sendCh: make(chan sendReady, 64), 105 recvDoneCh: make(chan struct{}), 106 shutdownCh: make(chan struct{}), 107 } 108 if client { 109 s.nextStreamID = 1 110 } else { 111 s.nextStreamID = 2 112 } 113 go s.recv() 114 go s.send() 115 if config.EnableKeepAlive { 116 go s.keepalive() 117 } 118 return s 119} 120 121// IsClosed does a safe check to see if we have shutdown 122func (s *Session) IsClosed() bool { 123 select { 124 case <-s.shutdownCh: 125 return true 126 default: 127 return false 128 } 129} 130 131// CloseChan returns a read-only channel which is closed as 132// soon as the session is closed. 133func (s *Session) CloseChan() <-chan struct{} { 134 return s.shutdownCh 135} 136 137// NumStreams returns the number of currently open streams 138func (s *Session) NumStreams() int { 139 s.streamLock.Lock() 140 num := len(s.streams) 141 s.streamLock.Unlock() 142 return num 143} 144 145// Open is used to create a new stream as a net.Conn 146func (s *Session) Open() (net.Conn, error) { 147 conn, err := s.OpenStream() 148 if err != nil { 149 return nil, err 150 } 151 return conn, nil 152} 153 154// OpenStream is used to create a new stream 155func (s *Session) OpenStream() (*Stream, error) { 156 if s.IsClosed() { 157 return nil, ErrSessionShutdown 158 } 159 if atomic.LoadInt32(&s.remoteGoAway) == 1 { 160 return nil, ErrRemoteGoAway 161 } 162 163 // Block if we have too many inflight SYNs 164 select { 165 case s.synCh <- struct{}{}: 166 case <-s.shutdownCh: 167 return nil, ErrSessionShutdown 168 } 169 170GET_ID: 171 // Get an ID, and check for stream exhaustion 172 id := atomic.LoadUint32(&s.nextStreamID) 173 if id >= math.MaxUint32-1 { 174 return nil, ErrStreamsExhausted 175 } 176 if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) { 177 goto GET_ID 178 } 179 180 // Register the stream 181 stream := newStream(s, id, streamInit) 182 s.streamLock.Lock() 183 s.streams[id] = stream 184 s.inflight[id] = struct{}{} 185 s.streamLock.Unlock() 186 187 if s.config.StreamOpenTimeout > 0 { 188 go s.setOpenTimeout(stream) 189 } 190 191 // Send the window update to create 192 if err := stream.sendWindowUpdate(); err != nil { 193 select { 194 case <-s.synCh: 195 default: 196 s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore") 197 } 198 return nil, err 199 } 200 return stream, nil 201} 202 203// setOpenTimeout implements a timeout for streams that are opened but not established. 204// If the StreamOpenTimeout is exceeded we assume the peer is unable to ACK, 205// and close the session. 206// The number of running timers is bounded by the capacity of the synCh. 207func (s *Session) setOpenTimeout(stream *Stream) { 208 timer := time.NewTimer(s.config.StreamOpenTimeout) 209 defer timer.Stop() 210 211 select { 212 case <-stream.establishCh: 213 return 214 case <-s.shutdownCh: 215 return 216 case <-timer.C: 217 // Timeout reached while waiting for ACK. 218 // Close the session to force connection re-establishment. 219 s.logger.Printf("[ERR] yamux: aborted stream open (destination=%s): %v", s.RemoteAddr().String(), ErrTimeout.err) 220 s.Close() 221 } 222} 223 224// Accept is used to block until the next available stream 225// is ready to be accepted. 226func (s *Session) Accept() (net.Conn, error) { 227 conn, err := s.AcceptStream() 228 if err != nil { 229 return nil, err 230 } 231 return conn, err 232} 233 234// AcceptStream is used to block until the next available stream 235// is ready to be accepted. 236func (s *Session) AcceptStream() (*Stream, error) { 237 select { 238 case stream := <-s.acceptCh: 239 if err := stream.sendWindowUpdate(); err != nil { 240 return nil, err 241 } 242 return stream, nil 243 case <-s.shutdownCh: 244 return nil, s.shutdownErr 245 } 246} 247 248// Close is used to close the session and all streams. 249// Attempts to send a GoAway before closing the connection. 250func (s *Session) Close() error { 251 s.shutdownLock.Lock() 252 defer s.shutdownLock.Unlock() 253 254 if s.shutdown { 255 return nil 256 } 257 s.shutdown = true 258 if s.shutdownErr == nil { 259 s.shutdownErr = ErrSessionShutdown 260 } 261 close(s.shutdownCh) 262 s.conn.Close() 263 <-s.recvDoneCh 264 265 s.streamLock.Lock() 266 defer s.streamLock.Unlock() 267 for _, stream := range s.streams { 268 stream.forceClose() 269 } 270 return nil 271} 272 273// exitErr is used to handle an error that is causing the 274// session to terminate. 275func (s *Session) exitErr(err error) { 276 s.shutdownLock.Lock() 277 if s.shutdownErr == nil { 278 s.shutdownErr = err 279 } 280 s.shutdownLock.Unlock() 281 s.Close() 282} 283 284// GoAway can be used to prevent accepting further 285// connections. It does not close the underlying conn. 286func (s *Session) GoAway() error { 287 return s.waitForSend(s.goAway(goAwayNormal), nil) 288} 289 290// goAway is used to send a goAway message 291func (s *Session) goAway(reason uint32) header { 292 atomic.SwapInt32(&s.localGoAway, 1) 293 hdr := header(make([]byte, headerSize)) 294 hdr.encode(typeGoAway, 0, 0, reason) 295 return hdr 296} 297 298// Ping is used to measure the RTT response time 299func (s *Session) Ping() (time.Duration, error) { 300 // Get a channel for the ping 301 ch := make(chan struct{}) 302 303 // Get a new ping id, mark as pending 304 s.pingLock.Lock() 305 id := s.pingID 306 s.pingID++ 307 s.pings[id] = ch 308 s.pingLock.Unlock() 309 310 // Send the ping request 311 hdr := header(make([]byte, headerSize)) 312 hdr.encode(typePing, flagSYN, 0, id) 313 if err := s.waitForSend(hdr, nil); err != nil { 314 return 0, err 315 } 316 317 // Wait for a response 318 start := time.Now() 319 select { 320 case <-ch: 321 case <-time.After(s.config.ConnectionWriteTimeout): 322 s.pingLock.Lock() 323 delete(s.pings, id) // Ignore it if a response comes later. 324 s.pingLock.Unlock() 325 return 0, ErrTimeout 326 case <-s.shutdownCh: 327 return 0, ErrSessionShutdown 328 } 329 330 // Compute the RTT 331 return time.Now().Sub(start), nil 332} 333 334// keepalive is a long running goroutine that periodically does 335// a ping to keep the connection alive. 336func (s *Session) keepalive() { 337 for { 338 select { 339 case <-time.After(s.config.KeepAliveInterval): 340 _, err := s.Ping() 341 if err != nil { 342 if err != ErrSessionShutdown { 343 s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) 344 s.exitErr(ErrKeepAliveTimeout) 345 } 346 return 347 } 348 case <-s.shutdownCh: 349 return 350 } 351 } 352} 353 354// waitForSendErr waits to send a header, checking for a potential shutdown 355func (s *Session) waitForSend(hdr header, body io.Reader) error { 356 errCh := make(chan error, 1) 357 return s.waitForSendErr(hdr, body, errCh) 358} 359 360// waitForSendErr waits to send a header with optional data, checking for a 361// potential shutdown. Since there's the expectation that sends can happen 362// in a timely manner, we enforce the connection write timeout here. 363func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error { 364 t := timerPool.Get() 365 timer := t.(*time.Timer) 366 timer.Reset(s.config.ConnectionWriteTimeout) 367 defer func() { 368 timer.Stop() 369 select { 370 case <-timer.C: 371 default: 372 } 373 timerPool.Put(t) 374 }() 375 376 ready := sendReady{Hdr: hdr, Body: body, Err: errCh} 377 select { 378 case s.sendCh <- ready: 379 case <-s.shutdownCh: 380 return ErrSessionShutdown 381 case <-timer.C: 382 return ErrConnectionWriteTimeout 383 } 384 385 select { 386 case err := <-errCh: 387 return err 388 case <-s.shutdownCh: 389 return ErrSessionShutdown 390 case <-timer.C: 391 return ErrConnectionWriteTimeout 392 } 393} 394 395// sendNoWait does a send without waiting. Since there's the expectation that 396// the send happens right here, we enforce the connection write timeout if we 397// can't queue the header to be sent. 398func (s *Session) sendNoWait(hdr header) error { 399 t := timerPool.Get() 400 timer := t.(*time.Timer) 401 timer.Reset(s.config.ConnectionWriteTimeout) 402 defer func() { 403 timer.Stop() 404 select { 405 case <-timer.C: 406 default: 407 } 408 timerPool.Put(t) 409 }() 410 411 select { 412 case s.sendCh <- sendReady{Hdr: hdr}: 413 return nil 414 case <-s.shutdownCh: 415 return ErrSessionShutdown 416 case <-timer.C: 417 return ErrConnectionWriteTimeout 418 } 419} 420 421// send is a long running goroutine that sends data 422func (s *Session) send() { 423 for { 424 select { 425 case ready := <-s.sendCh: 426 // Send a header if ready 427 if ready.Hdr != nil { 428 sent := 0 429 for sent < len(ready.Hdr) { 430 n, err := s.conn.Write(ready.Hdr[sent:]) 431 if err != nil { 432 s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) 433 asyncSendErr(ready.Err, err) 434 s.exitErr(err) 435 return 436 } 437 sent += n 438 } 439 } 440 441 // Send data from a body if given 442 if ready.Body != nil { 443 _, err := io.Copy(s.conn, ready.Body) 444 if err != nil { 445 s.logger.Printf("[ERR] yamux: Failed to write body: %v", err) 446 asyncSendErr(ready.Err, err) 447 s.exitErr(err) 448 return 449 } 450 } 451 452 // No error, successful send 453 asyncSendErr(ready.Err, nil) 454 case <-s.shutdownCh: 455 return 456 } 457 } 458} 459 460// recv is a long running goroutine that accepts new data 461func (s *Session) recv() { 462 if err := s.recvLoop(); err != nil { 463 s.exitErr(err) 464 } 465} 466 467// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type 468var ( 469 handlers = []func(*Session, header) error{ 470 typeData: (*Session).handleStreamMessage, 471 typeWindowUpdate: (*Session).handleStreamMessage, 472 typePing: (*Session).handlePing, 473 typeGoAway: (*Session).handleGoAway, 474 } 475) 476 477// recvLoop continues to receive data until a fatal error is encountered 478func (s *Session) recvLoop() error { 479 defer close(s.recvDoneCh) 480 hdr := header(make([]byte, headerSize)) 481 for { 482 // Read the header 483 if _, err := io.ReadFull(s.bufRead, hdr); err != nil { 484 if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { 485 s.logger.Printf("[ERR] yamux: Failed to read header: %v", err) 486 } 487 return err 488 } 489 490 // Verify the version 491 if hdr.Version() != protoVersion { 492 s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version()) 493 return ErrInvalidVersion 494 } 495 496 mt := hdr.MsgType() 497 if mt < typeData || mt > typeGoAway { 498 return ErrInvalidMsgType 499 } 500 501 if err := handlers[mt](s, hdr); err != nil { 502 return err 503 } 504 } 505} 506 507// handleStreamMessage handles either a data or window update frame 508func (s *Session) handleStreamMessage(hdr header) error { 509 // Check for a new stream creation 510 id := hdr.StreamID() 511 flags := hdr.Flags() 512 if flags&flagSYN == flagSYN { 513 if err := s.incomingStream(id); err != nil { 514 return err 515 } 516 } 517 518 // Get the stream 519 s.streamLock.Lock() 520 stream := s.streams[id] 521 s.streamLock.Unlock() 522 523 // If we do not have a stream, likely we sent a RST 524 if stream == nil { 525 // Drain any data on the wire 526 if hdr.MsgType() == typeData && hdr.Length() > 0 { 527 s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id) 528 if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil { 529 s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err) 530 return nil 531 } 532 } else { 533 s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr) 534 } 535 return nil 536 } 537 538 // Check if this is a window update 539 if hdr.MsgType() == typeWindowUpdate { 540 if err := stream.incrSendWindow(hdr, flags); err != nil { 541 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 542 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 543 } 544 return err 545 } 546 return nil 547 } 548 549 // Read the new data 550 if err := stream.readData(hdr, flags, s.bufRead); err != nil { 551 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 552 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 553 } 554 return err 555 } 556 return nil 557} 558 559// handlePing is invokde for a typePing frame 560func (s *Session) handlePing(hdr header) error { 561 flags := hdr.Flags() 562 pingID := hdr.Length() 563 564 // Check if this is a query, respond back in a separate context so we 565 // don't interfere with the receiving thread blocking for the write. 566 if flags&flagSYN == flagSYN { 567 go func() { 568 hdr := header(make([]byte, headerSize)) 569 hdr.encode(typePing, flagACK, 0, pingID) 570 if err := s.sendNoWait(hdr); err != nil { 571 s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err) 572 } 573 }() 574 return nil 575 } 576 577 // Handle a response 578 s.pingLock.Lock() 579 ch := s.pings[pingID] 580 if ch != nil { 581 delete(s.pings, pingID) 582 close(ch) 583 } 584 s.pingLock.Unlock() 585 return nil 586} 587 588// handleGoAway is invokde for a typeGoAway frame 589func (s *Session) handleGoAway(hdr header) error { 590 code := hdr.Length() 591 switch code { 592 case goAwayNormal: 593 atomic.SwapInt32(&s.remoteGoAway, 1) 594 case goAwayProtoErr: 595 s.logger.Printf("[ERR] yamux: received protocol error go away") 596 return fmt.Errorf("yamux protocol error") 597 case goAwayInternalErr: 598 s.logger.Printf("[ERR] yamux: received internal error go away") 599 return fmt.Errorf("remote yamux internal error") 600 default: 601 s.logger.Printf("[ERR] yamux: received unexpected go away") 602 return fmt.Errorf("unexpected go away received") 603 } 604 return nil 605} 606 607// incomingStream is used to create a new incoming stream 608func (s *Session) incomingStream(id uint32) error { 609 // Reject immediately if we are doing a go away 610 if atomic.LoadInt32(&s.localGoAway) == 1 { 611 hdr := header(make([]byte, headerSize)) 612 hdr.encode(typeWindowUpdate, flagRST, id, 0) 613 return s.sendNoWait(hdr) 614 } 615 616 // Allocate a new stream 617 stream := newStream(s, id, streamSYNReceived) 618 619 s.streamLock.Lock() 620 defer s.streamLock.Unlock() 621 622 // Check if stream already exists 623 if _, ok := s.streams[id]; ok { 624 s.logger.Printf("[ERR] yamux: duplicate stream declared") 625 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 626 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 627 } 628 return ErrDuplicateStream 629 } 630 631 // Register the stream 632 s.streams[id] = stream 633 634 // Check if we've exceeded the backlog 635 select { 636 case s.acceptCh <- stream: 637 return nil 638 default: 639 // Backlog exceeded! RST the stream 640 s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset") 641 delete(s.streams, id) 642 stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0) 643 return s.sendNoWait(stream.sendHdr) 644 } 645} 646 647// closeStream is used to close a stream once both sides have 648// issued a close. If there was an in-flight SYN and the stream 649// was not yet established, then this will give the credit back. 650func (s *Session) closeStream(id uint32) { 651 s.streamLock.Lock() 652 if _, ok := s.inflight[id]; ok { 653 select { 654 case <-s.synCh: 655 default: 656 s.logger.Printf("[ERR] yamux: SYN tracking out of sync") 657 } 658 } 659 delete(s.streams, id) 660 s.streamLock.Unlock() 661} 662 663// establishStream is used to mark a stream that was in the 664// SYN Sent state as established. 665func (s *Session) establishStream(id uint32) { 666 s.streamLock.Lock() 667 if _, ok := s.inflight[id]; ok { 668 delete(s.inflight, id) 669 } else { 670 s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)") 671 } 672 select { 673 case <-s.synCh: 674 default: 675 s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)") 676 } 677 s.streamLock.Unlock() 678} 679