1package yamux 2 3import ( 4 "bufio" 5 "context" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "log" 10 "math" 11 "net" 12 "os" 13 "strings" 14 "sync" 15 "sync/atomic" 16 "time" 17 18 "github.com/libp2p/go-buffer-pool" 19) 20 21// Session is used to wrap a reliable ordered connection and to 22// multiplex it into multiple streams. 23type Session struct { 24 // remoteGoAway indicates the remote side does 25 // not want futher connections. Must be first for alignment. 26 remoteGoAway int32 27 28 // localGoAway indicates that we should stop 29 // accepting futher connections. Must be first for alignment. 30 localGoAway int32 31 32 // nextStreamID is the next stream we should 33 // send. This depends if we are a client/server. 34 nextStreamID uint32 35 36 // config holds our configuration 37 config *Config 38 39 // logger is used for our logs 40 logger *log.Logger 41 42 // conn is the underlying connection 43 conn net.Conn 44 45 // reader is a buffered reader 46 reader io.Reader 47 48 // pings is used to track inflight pings 49 pingLock sync.Mutex 50 pingID uint32 51 activePing *ping 52 53 // streams maps a stream id to a stream, and inflight has an entry 54 // for any outgoing stream that has not yet been established. Both are 55 // protected by streamLock. 56 streams map[uint32]*Stream 57 inflight map[uint32]struct{} 58 streamLock sync.Mutex 59 60 // synCh acts like a semaphore. It is sized to the AcceptBacklog which 61 // is assumed to be symmetric between the client and server. This allows 62 // the client to avoid exceeding the backlog and instead blocks the open. 63 synCh chan struct{} 64 65 // acceptCh is used to pass ready streams to the client 66 acceptCh chan *Stream 67 68 // sendCh is used to send messages 69 sendCh chan []byte 70 71 // pingCh and pingCh are used to send pings and pongs 72 pongCh, pingCh chan uint32 73 74 // recvDoneCh is closed when recv() exits to avoid a race 75 // between stream registration and stream shutdown 76 recvDoneCh chan struct{} 77 78 // sendDoneCh is closed when send() exits to avoid a race 79 // between returning from a Stream.Write and exiting from the send loop 80 // (which may be reading a buffer on-load-from Stream.Write). 81 sendDoneCh chan struct{} 82 83 // client is true if we're the client and our stream IDs should be odd. 84 client bool 85 86 // shutdown is used to safely close a session 87 shutdown bool 88 shutdownErr error 89 shutdownCh chan struct{} 90 shutdownLock sync.Mutex 91 92 // keepaliveTimer is a periodic timer for keepalive messages. It's nil 93 // when keepalives are disabled. 94 keepaliveLock sync.Mutex 95 keepaliveTimer *time.Timer 96 keepaliveActive bool 97} 98 99// newSession is used to construct a new session 100func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Session { 101 var reader io.Reader = conn 102 if readBuf > 0 { 103 reader = bufio.NewReaderSize(reader, readBuf) 104 } 105 s := &Session{ 106 config: config, 107 client: client, 108 logger: log.New(config.LogOutput, "", log.LstdFlags), 109 conn: conn, 110 reader: reader, 111 streams: make(map[uint32]*Stream), 112 inflight: make(map[uint32]struct{}), 113 synCh: make(chan struct{}, config.AcceptBacklog), 114 acceptCh: make(chan *Stream, config.AcceptBacklog), 115 sendCh: make(chan []byte, 64), 116 pongCh: make(chan uint32, config.PingBacklog), 117 pingCh: make(chan uint32), 118 recvDoneCh: make(chan struct{}), 119 sendDoneCh: make(chan struct{}), 120 shutdownCh: make(chan struct{}), 121 } 122 if client { 123 s.nextStreamID = 1 124 } else { 125 s.nextStreamID = 2 126 } 127 if config.EnableKeepAlive { 128 s.startKeepalive() 129 } 130 go s.recv() 131 go s.send() 132 return s 133} 134 135// IsClosed does a safe check to see if we have shutdown 136func (s *Session) IsClosed() bool { 137 select { 138 case <-s.shutdownCh: 139 return true 140 default: 141 return false 142 } 143} 144 145// CloseChan returns a read-only channel which is closed as 146// soon as the session is closed. 147func (s *Session) CloseChan() <-chan struct{} { 148 return s.shutdownCh 149} 150 151// NumStreams returns the number of currently open streams 152func (s *Session) NumStreams() int { 153 s.streamLock.Lock() 154 num := len(s.streams) 155 s.streamLock.Unlock() 156 return num 157} 158 159// Open is used to create a new stream as a net.Conn 160func (s *Session) Open(ctx context.Context) (net.Conn, error) { 161 conn, err := s.OpenStream(ctx) 162 if err != nil { 163 return nil, err 164 } 165 return conn, nil 166} 167 168// OpenStream is used to create a new stream 169func (s *Session) OpenStream(ctx context.Context) (*Stream, error) { 170 if s.IsClosed() { 171 return nil, s.shutdownErr 172 } 173 if atomic.LoadInt32(&s.remoteGoAway) == 1 { 174 return nil, ErrRemoteGoAway 175 } 176 177 // Block if we have too many inflight SYNs 178 select { 179 case s.synCh <- struct{}{}: 180 case <-ctx.Done(): 181 return nil, ctx.Err() 182 case <-s.shutdownCh: 183 return nil, s.shutdownErr 184 } 185 186GET_ID: 187 // Get an ID, and check for stream exhaustion 188 id := atomic.LoadUint32(&s.nextStreamID) 189 if id >= math.MaxUint32-1 { 190 return nil, ErrStreamsExhausted 191 } 192 if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) { 193 goto GET_ID 194 } 195 196 // Register the stream 197 stream := newStream(s, id, streamInit) 198 s.streamLock.Lock() 199 s.streams[id] = stream 200 s.inflight[id] = struct{}{} 201 s.streamLock.Unlock() 202 203 // Send the window update to create 204 if err := stream.sendWindowUpdate(); err != nil { 205 select { 206 case <-s.synCh: 207 default: 208 s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore") 209 } 210 return nil, err 211 } 212 return stream, nil 213} 214 215// Accept is used to block until the next available stream 216// is ready to be accepted. 217func (s *Session) Accept() (net.Conn, error) { 218 conn, err := s.AcceptStream() 219 if err != nil { 220 return nil, err 221 } 222 return conn, err 223} 224 225// AcceptStream is used to block until the next available stream 226// is ready to be accepted. 227func (s *Session) AcceptStream() (*Stream, error) { 228 for { 229 select { 230 case stream := <-s.acceptCh: 231 if err := stream.sendWindowUpdate(); err != nil { 232 // don't return accept errors. 233 s.logger.Printf("[WARN] error sending window update before accepting: %s", err) 234 continue 235 } 236 return stream, nil 237 case <-s.shutdownCh: 238 return nil, s.shutdownErr 239 } 240 } 241} 242 243// Close is used to close the session and all streams. 244// Attempts to send a GoAway before closing the connection. 245func (s *Session) Close() error { 246 s.shutdownLock.Lock() 247 defer s.shutdownLock.Unlock() 248 249 if s.shutdown { 250 return nil 251 } 252 s.shutdown = true 253 if s.shutdownErr == nil { 254 s.shutdownErr = ErrSessionShutdown 255 } 256 close(s.shutdownCh) 257 s.conn.Close() 258 s.stopKeepalive() 259 <-s.recvDoneCh 260 <-s.sendDoneCh 261 262 s.streamLock.Lock() 263 defer s.streamLock.Unlock() 264 for _, stream := range s.streams { 265 stream.forceClose() 266 } 267 return nil 268} 269 270// exitErr is used to handle an error that is causing the 271// session to terminate. 272func (s *Session) exitErr(err error) { 273 s.shutdownLock.Lock() 274 if s.shutdownErr == nil { 275 s.shutdownErr = err 276 } 277 s.shutdownLock.Unlock() 278 s.Close() 279} 280 281// GoAway can be used to prevent accepting further 282// connections. It does not close the underlying conn. 283func (s *Session) GoAway() error { 284 return s.sendMsg(s.goAway(goAwayNormal), nil, nil) 285} 286 287// goAway is used to send a goAway message 288func (s *Session) goAway(reason uint32) header { 289 atomic.SwapInt32(&s.localGoAway, 1) 290 hdr := encode(typeGoAway, 0, 0, reason) 291 return hdr 292} 293 294// Ping is used to measure the RTT response time 295func (s *Session) Ping() (dur time.Duration, err error) { 296 // Prepare a ping. 297 s.pingLock.Lock() 298 // If there's an active ping, jump on the bandwagon. 299 if activePing := s.activePing; activePing != nil { 300 s.pingLock.Unlock() 301 return activePing.wait() 302 } 303 304 // Ok, our job to send the ping. 305 activePing := newPing(s.pingID) 306 s.pingID++ 307 s.activePing = activePing 308 s.pingLock.Unlock() 309 310 defer func() { 311 // complete ping promise 312 activePing.finish(dur, err) 313 314 // Unset it. 315 s.pingLock.Lock() 316 s.activePing = nil 317 s.pingLock.Unlock() 318 }() 319 320 // Send the ping request, waiting at most one connection write timeout 321 // to flush it. 322 timer := time.NewTimer(s.config.ConnectionWriteTimeout) 323 defer timer.Stop() 324 select { 325 case s.pingCh <- activePing.id: 326 case <-timer.C: 327 return 0, ErrTimeout 328 case <-s.shutdownCh: 329 return 0, s.shutdownErr 330 } 331 332 // The "time" starts once we've actually sent the ping. Otherwise, we'll 333 // measure the time it takes to flush the queue as well. 334 start := time.Now() 335 336 // Wait for a response, again waiting at most one write timeout. 337 if !timer.Stop() { 338 <-timer.C 339 } 340 timer.Reset(s.config.ConnectionWriteTimeout) 341 select { 342 case <-activePing.pingResponse: 343 case <-timer.C: 344 return 0, ErrTimeout 345 case <-s.shutdownCh: 346 return 0, s.shutdownErr 347 } 348 349 // Compute the RTT 350 return time.Since(start), nil 351} 352 353// startKeepalive starts the keepalive process. 354func (s *Session) startKeepalive() { 355 s.keepaliveLock.Lock() 356 defer s.keepaliveLock.Unlock() 357 s.keepaliveTimer = time.AfterFunc(s.config.KeepAliveInterval, func() { 358 s.keepaliveLock.Lock() 359 if s.keepaliveTimer == nil || s.keepaliveActive { 360 // keepalives have been stopped or a keepalive is active. 361 s.keepaliveLock.Unlock() 362 return 363 } 364 s.keepaliveActive = true 365 s.keepaliveLock.Unlock() 366 367 _, err := s.Ping() 368 369 s.keepaliveLock.Lock() 370 s.keepaliveActive = false 371 if s.keepaliveTimer != nil { 372 s.keepaliveTimer.Reset(s.config.KeepAliveInterval) 373 } 374 s.keepaliveLock.Unlock() 375 376 if err != nil { 377 s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) 378 s.exitErr(ErrKeepAliveTimeout) 379 } 380 }) 381} 382 383// stopKeepalive stops the keepalive process. 384func (s *Session) stopKeepalive() { 385 s.keepaliveLock.Lock() 386 defer s.keepaliveLock.Unlock() 387 if s.keepaliveTimer != nil { 388 s.keepaliveTimer.Stop() 389 s.keepaliveTimer = nil 390 } 391} 392 393func (s *Session) extendKeepalive() { 394 s.keepaliveLock.Lock() 395 if s.keepaliveTimer != nil && !s.keepaliveActive { 396 // Don't stop the timer and drain the channel. This is an 397 // AfterFunc, not a normal timer, and any attempts to drain the 398 // channel will block forever. 399 // 400 // Go will stop the timer for us internally anyways. The docs 401 // say one must stop the timer before calling reset but that's 402 // to ensure that the timer doesn't end up firing immediately 403 // after calling Reset. 404 s.keepaliveTimer.Reset(s.config.KeepAliveInterval) 405 } 406 s.keepaliveLock.Unlock() 407} 408 409// send sends the header and body. 410func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) error { 411 select { 412 case <-s.shutdownCh: 413 return s.shutdownErr 414 default: 415 } 416 417 // duplicate as we're sending this async. 418 buf := pool.Get(headerSize + len(body)) 419 copy(buf[:headerSize], hdr[:]) 420 copy(buf[headerSize:], body) 421 422 select { 423 case <-s.shutdownCh: 424 pool.Put(buf) 425 return s.shutdownErr 426 case s.sendCh <- buf: 427 return nil 428 case <-deadline: 429 pool.Put(buf) 430 return ErrTimeout 431 } 432} 433 434// send is a long running goroutine that sends data 435func (s *Session) send() { 436 if err := s.sendLoop(); err != nil { 437 s.exitErr(err) 438 } 439} 440 441func (s *Session) sendLoop() error { 442 defer close(s.sendDoneCh) 443 444 // Extend the write deadline if we've passed the halfway point. This can 445 // be expensive so this ensures we only have to do this once every 446 // ConnectionWriteTimeout/2 (usually 5s). 447 var lastWriteDeadline time.Time 448 extendWriteDeadline := func() error { 449 now := time.Now() 450 // If over half of the deadline has elapsed, extend it. 451 if now.Add(s.config.ConnectionWriteTimeout / 2).After(lastWriteDeadline) { 452 lastWriteDeadline = now.Add(s.config.ConnectionWriteTimeout) 453 return s.conn.SetWriteDeadline(lastWriteDeadline) 454 } 455 return nil 456 } 457 458 writer := s.conn 459 460 // FIXME: https://github.com/libp2p/go-libp2p/issues/644 461 // Write coalescing is disabled for now. 462 463 //writer := pool.Writer{W: s.conn} 464 465 //var writeTimeout *time.Timer 466 //var writeTimeoutCh <-chan time.Time 467 //if s.config.WriteCoalesceDelay > 0 { 468 // writeTimeout = time.NewTimer(s.config.WriteCoalesceDelay) 469 // defer writeTimeout.Stop() 470 471 // writeTimeoutCh = writeTimeout.C 472 //} else { 473 // ch := make(chan time.Time) 474 // close(ch) 475 // writeTimeoutCh = ch 476 //} 477 478 for { 479 // yield after processing the last message, if we've shutdown. 480 // s.sendCh is a buffered channel and Go doesn't guarantee select order. 481 select { 482 case <-s.shutdownCh: 483 return nil 484 default: 485 } 486 487 // Flushes at least once every 100 microseconds unless we're 488 // constantly writing. 489 var buf []byte 490 select { 491 case buf = <-s.sendCh: 492 case pingID := <-s.pingCh: 493 buf = pool.Get(headerSize) 494 hdr := encode(typePing, flagSYN, 0, pingID) 495 copy(buf, hdr[:]) 496 case pingID := <-s.pongCh: 497 buf = pool.Get(headerSize) 498 hdr := encode(typePing, flagACK, 0, pingID) 499 copy(buf, hdr[:]) 500 case <-s.shutdownCh: 501 return nil 502 //default: 503 // select { 504 // case buf = <-s.sendCh: 505 // case <-s.shutdownCh: 506 // return nil 507 // case <-writeTimeoutCh: 508 // if err := writer.Flush(); err != nil { 509 // if os.IsTimeout(err) { 510 // err = ErrConnectionWriteTimeout 511 // } 512 // return err 513 // } 514 515 // select { 516 // case buf = <-s.sendCh: 517 // case <-s.shutdownCh: 518 // return nil 519 // } 520 521 // if writeTimeout != nil { 522 // writeTimeout.Reset(s.config.WriteCoalesceDelay) 523 // } 524 // } 525 } 526 527 if err := extendWriteDeadline(); err != nil { 528 pool.Put(buf) 529 return err 530 } 531 532 _, err := writer.Write(buf) 533 pool.Put(buf) 534 535 if err != nil { 536 if os.IsTimeout(err) { 537 err = ErrConnectionWriteTimeout 538 } 539 return err 540 } 541 } 542} 543 544// recv is a long running goroutine that accepts new data 545func (s *Session) recv() { 546 if err := s.recvLoop(); err != nil { 547 s.exitErr(err) 548 } 549} 550 551// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type 552var ( 553 handlers = []func(*Session, header) error{ 554 typeData: (*Session).handleStreamMessage, 555 typeWindowUpdate: (*Session).handleStreamMessage, 556 typePing: (*Session).handlePing, 557 typeGoAway: (*Session).handleGoAway, 558 } 559) 560 561// recvLoop continues to receive data until a fatal error is encountered 562func (s *Session) recvLoop() error { 563 defer close(s.recvDoneCh) 564 var hdr header 565 for { 566 // Read the header 567 if _, err := io.ReadFull(s.reader, hdr[:]); err != nil { 568 if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { 569 s.logger.Printf("[ERR] yamux: Failed to read header: %v", err) 570 } 571 return err 572 } 573 574 // Reset the keepalive timer every time we receive data. 575 // There's no reason to keepalive if we're active. Worse, if the 576 // peer is busy sending us stuff, the pong might get stuck 577 // behind a bunch of data. 578 s.extendKeepalive() 579 580 // Verify the version 581 if hdr.Version() != protoVersion { 582 s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version()) 583 return ErrInvalidVersion 584 } 585 586 mt := hdr.MsgType() 587 if mt < typeData || mt > typeGoAway { 588 return ErrInvalidMsgType 589 } 590 591 if err := handlers[mt](s, hdr); err != nil { 592 return err 593 } 594 } 595} 596 597// handleStreamMessage handles either a data or window update frame 598func (s *Session) handleStreamMessage(hdr header) error { 599 // Check for a new stream creation 600 id := hdr.StreamID() 601 flags := hdr.Flags() 602 if flags&flagSYN == flagSYN { 603 if err := s.incomingStream(id); err != nil { 604 return err 605 } 606 } 607 608 // Get the stream 609 s.streamLock.Lock() 610 stream := s.streams[id] 611 s.streamLock.Unlock() 612 613 // If we do not have a stream, likely we sent a RST 614 if stream == nil { 615 // Drain any data on the wire 616 if hdr.MsgType() == typeData && hdr.Length() > 0 { 617 s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id) 618 if _, err := io.CopyN(ioutil.Discard, s.reader, int64(hdr.Length())); err != nil { 619 s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err) 620 return nil 621 } 622 } else { 623 s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr) 624 } 625 return nil 626 } 627 628 // Check if this is a window update 629 if hdr.MsgType() == typeWindowUpdate { 630 if err := stream.incrSendWindow(hdr, flags); err != nil { 631 if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil { 632 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 633 } 634 return err 635 } 636 return nil 637 } 638 639 // Read the new data 640 if err := stream.readData(hdr, flags, s.reader); err != nil { 641 if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil { 642 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 643 } 644 return err 645 } 646 return nil 647} 648 649// handlePing is invoked for a typePing frame 650func (s *Session) handlePing(hdr header) error { 651 flags := hdr.Flags() 652 pingID := hdr.Length() 653 654 // Check if this is a query, respond back in a separate context so we 655 // don't interfere with the receiving thread blocking for the write. 656 if flags&flagSYN == flagSYN { 657 select { 658 case s.pongCh <- pingID: 659 default: 660 s.logger.Printf("[WARN] yamux: dropped ping reply") 661 } 662 return nil 663 } 664 665 // Handle a response 666 s.pingLock.Lock() 667 // If we have an active ping, and this is a response to that active 668 // ping, complete the ping. 669 if s.activePing != nil && s.activePing.id == pingID { 670 // Don't assume that the peer won't send multiple responses for 671 // the same ping. 672 select { 673 case s.activePing.pingResponse <- struct{}{}: 674 default: 675 } 676 } 677 s.pingLock.Unlock() 678 return nil 679} 680 681// handleGoAway is invokde for a typeGoAway frame 682func (s *Session) handleGoAway(hdr header) error { 683 code := hdr.Length() 684 switch code { 685 case goAwayNormal: 686 atomic.SwapInt32(&s.remoteGoAway, 1) 687 case goAwayProtoErr: 688 s.logger.Printf("[ERR] yamux: received protocol error go away") 689 return fmt.Errorf("yamux protocol error") 690 case goAwayInternalErr: 691 s.logger.Printf("[ERR] yamux: received internal error go away") 692 return fmt.Errorf("remote yamux internal error") 693 default: 694 s.logger.Printf("[ERR] yamux: received unexpected go away") 695 return fmt.Errorf("unexpected go away received") 696 } 697 return nil 698} 699 700// incomingStream is used to create a new incoming stream 701func (s *Session) incomingStream(id uint32) error { 702 if s.client != (id%2 == 0) { 703 s.logger.Printf("[ERR] yamux: both endpoints are clients") 704 return fmt.Errorf("both yamux endpoints are clients") 705 } 706 // Reject immediately if we are doing a go away 707 if atomic.LoadInt32(&s.localGoAway) == 1 { 708 hdr := encode(typeWindowUpdate, flagRST, id, 0) 709 return s.sendMsg(hdr, nil, nil) 710 } 711 712 // Allocate a new stream 713 stream := newStream(s, id, streamSYNReceived) 714 715 s.streamLock.Lock() 716 defer s.streamLock.Unlock() 717 718 // Check if stream already exists 719 if _, ok := s.streams[id]; ok { 720 s.logger.Printf("[ERR] yamux: duplicate stream declared") 721 if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil { 722 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 723 } 724 return ErrDuplicateStream 725 } 726 727 // Register the stream 728 s.streams[id] = stream 729 730 // Check if we've exceeded the backlog 731 select { 732 case s.acceptCh <- stream: 733 return nil 734 default: 735 // Backlog exceeded! RST the stream 736 s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset") 737 delete(s.streams, id) 738 hdr := encode(typeWindowUpdate, flagRST, id, 0) 739 return s.sendMsg(hdr, nil, nil) 740 } 741} 742 743// closeStream is used to close a stream once both sides have 744// issued a close. If there was an in-flight SYN and the stream 745// was not yet established, then this will give the credit back. 746func (s *Session) closeStream(id uint32) { 747 s.streamLock.Lock() 748 if _, ok := s.inflight[id]; ok { 749 select { 750 case <-s.synCh: 751 default: 752 s.logger.Printf("[ERR] yamux: SYN tracking out of sync") 753 } 754 delete(s.inflight, id) 755 } 756 delete(s.streams, id) 757 s.streamLock.Unlock() 758} 759 760// establishStream is used to mark a stream that was in the 761// SYN Sent state as established. 762func (s *Session) establishStream(id uint32) { 763 s.streamLock.Lock() 764 if _, ok := s.inflight[id]; ok { 765 delete(s.inflight, id) 766 } else { 767 s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)") 768 } 769 select { 770 case <-s.synCh: 771 default: 772 s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)") 773 } 774 s.streamLock.Unlock() 775} 776