1package yamux 2 3import ( 4 "bufio" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "log" 9 "math" 10 "net" 11 "strings" 12 "sync" 13 "sync/atomic" 14 "time" 15) 16 17// Session is used to wrap a reliable ordered connection and to 18// multiplex it into multiple streams. 19type Session struct { 20 // remoteGoAway indicates the remote side does 21 // not want futher connections. Must be first for alignment. 22 remoteGoAway int32 23 24 // localGoAway indicates that we should stop 25 // accepting futher connections. Must be first for alignment. 26 localGoAway int32 27 28 // nextStreamID is the next stream we should 29 // send. This depends if we are a client/server. 30 nextStreamID uint32 31 32 // config holds our configuration 33 config *Config 34 35 // logger is used for our logs 36 logger *log.Logger 37 38 // conn is the underlying connection 39 conn io.ReadWriteCloser 40 41 // bufRead is a buffered reader 42 bufRead *bufio.Reader 43 44 // pings is used to track inflight pings 45 pings map[uint32]chan struct{} 46 pingID uint32 47 pingLock sync.Mutex 48 49 // streams maps a stream id to a stream, and inflight has an entry 50 // for any outgoing stream that has not yet been established. Both are 51 // protected by streamLock. 52 streams map[uint32]*Stream 53 inflight map[uint32]struct{} 54 streamLock sync.Mutex 55 56 // synCh acts like a semaphore. It is sized to the AcceptBacklog which 57 // is assumed to be symmetric between the client and server. This allows 58 // the client to avoid exceeding the backlog and instead blocks the open. 59 synCh chan struct{} 60 61 // acceptCh is used to pass ready streams to the client 62 acceptCh chan *Stream 63 64 // sendCh is used to mark a stream as ready to send, 65 // or to send a header out directly. 66 sendCh chan sendReady 67 68 // recvDoneCh is closed when recv() exits to avoid a race 69 // between stream registration and stream shutdown 70 recvDoneCh chan struct{} 71 72 // shutdown is used to safely close a session 73 shutdown bool 74 shutdownErr error 75 shutdownCh chan struct{} 76 shutdownLock sync.Mutex 77} 78 79// sendReady is used to either mark a stream as ready 80// or to directly send a header 81type sendReady struct { 82 Hdr []byte 83 Body io.Reader 84 Err chan error 85} 86 87// newSession is used to construct a new session 88func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { 89 logger := config.Logger 90 if logger == nil { 91 logger = log.New(config.LogOutput, "", log.LstdFlags) 92 } 93 94 s := &Session{ 95 config: config, 96 logger: logger, 97 conn: conn, 98 bufRead: bufio.NewReader(conn), 99 pings: make(map[uint32]chan struct{}), 100 streams: make(map[uint32]*Stream), 101 inflight: make(map[uint32]struct{}), 102 synCh: make(chan struct{}, config.AcceptBacklog), 103 acceptCh: make(chan *Stream, config.AcceptBacklog), 104 sendCh: make(chan sendReady, 64), 105 recvDoneCh: make(chan struct{}), 106 shutdownCh: make(chan struct{}), 107 } 108 if client { 109 s.nextStreamID = 1 110 } else { 111 s.nextStreamID = 2 112 } 113 go s.recv() 114 go s.send() 115 if config.EnableKeepAlive { 116 go s.keepalive() 117 } 118 return s 119} 120 121// IsClosed does a safe check to see if we have shutdown 122func (s *Session) IsClosed() bool { 123 select { 124 case <-s.shutdownCh: 125 return true 126 default: 127 return false 128 } 129} 130 131// CloseChan returns a read-only channel which is closed as 132// soon as the session is closed. 133func (s *Session) CloseChan() <-chan struct{} { 134 return s.shutdownCh 135} 136 137// NumStreams returns the number of currently open streams 138func (s *Session) NumStreams() int { 139 s.streamLock.Lock() 140 num := len(s.streams) 141 s.streamLock.Unlock() 142 return num 143} 144 145// Open is used to create a new stream as a net.Conn 146func (s *Session) Open() (net.Conn, error) { 147 conn, err := s.OpenStream() 148 if err != nil { 149 return nil, err 150 } 151 return conn, nil 152} 153 154// OpenStream is used to create a new stream 155func (s *Session) OpenStream() (*Stream, error) { 156 if s.IsClosed() { 157 return nil, ErrSessionShutdown 158 } 159 if atomic.LoadInt32(&s.remoteGoAway) == 1 { 160 return nil, ErrRemoteGoAway 161 } 162 163 // Block if we have too many inflight SYNs 164 select { 165 case s.synCh <- struct{}{}: 166 case <-s.shutdownCh: 167 return nil, ErrSessionShutdown 168 } 169 170GET_ID: 171 // Get an ID, and check for stream exhaustion 172 id := atomic.LoadUint32(&s.nextStreamID) 173 if id >= math.MaxUint32-1 { 174 return nil, ErrStreamsExhausted 175 } 176 if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) { 177 goto GET_ID 178 } 179 180 // Register the stream 181 stream := newStream(s, id, streamInit) 182 s.streamLock.Lock() 183 s.streams[id] = stream 184 s.inflight[id] = struct{}{} 185 s.streamLock.Unlock() 186 187 // Send the window update to create 188 if err := stream.sendWindowUpdate(); err != nil { 189 select { 190 case <-s.synCh: 191 default: 192 s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore") 193 } 194 return nil, err 195 } 196 return stream, nil 197} 198 199// Accept is used to block until the next available stream 200// is ready to be accepted. 201func (s *Session) Accept() (net.Conn, error) { 202 conn, err := s.AcceptStream() 203 if err != nil { 204 return nil, err 205 } 206 return conn, err 207} 208 209// AcceptStream is used to block until the next available stream 210// is ready to be accepted. 211func (s *Session) AcceptStream() (*Stream, error) { 212 select { 213 case stream := <-s.acceptCh: 214 if err := stream.sendWindowUpdate(); err != nil { 215 return nil, err 216 } 217 return stream, nil 218 case <-s.shutdownCh: 219 return nil, s.shutdownErr 220 } 221} 222 223// Close is used to close the session and all streams. 224// Attempts to send a GoAway before closing the connection. 225func (s *Session) Close() error { 226 s.shutdownLock.Lock() 227 defer s.shutdownLock.Unlock() 228 229 if s.shutdown { 230 return nil 231 } 232 s.shutdown = true 233 if s.shutdownErr == nil { 234 s.shutdownErr = ErrSessionShutdown 235 } 236 close(s.shutdownCh) 237 s.conn.Close() 238 <-s.recvDoneCh 239 240 s.streamLock.Lock() 241 defer s.streamLock.Unlock() 242 for _, stream := range s.streams { 243 stream.forceClose() 244 } 245 return nil 246} 247 248// exitErr is used to handle an error that is causing the 249// session to terminate. 250func (s *Session) exitErr(err error) { 251 s.shutdownLock.Lock() 252 if s.shutdownErr == nil { 253 s.shutdownErr = err 254 } 255 s.shutdownLock.Unlock() 256 s.Close() 257} 258 259// GoAway can be used to prevent accepting further 260// connections. It does not close the underlying conn. 261func (s *Session) GoAway() error { 262 return s.waitForSend(s.goAway(goAwayNormal), nil) 263} 264 265// goAway is used to send a goAway message 266func (s *Session) goAway(reason uint32) header { 267 atomic.SwapInt32(&s.localGoAway, 1) 268 hdr := header(make([]byte, headerSize)) 269 hdr.encode(typeGoAway, 0, 0, reason) 270 return hdr 271} 272 273// Ping is used to measure the RTT response time 274func (s *Session) Ping() (time.Duration, error) { 275 // Get a channel for the ping 276 ch := make(chan struct{}) 277 278 // Get a new ping id, mark as pending 279 s.pingLock.Lock() 280 id := s.pingID 281 s.pingID++ 282 s.pings[id] = ch 283 s.pingLock.Unlock() 284 285 // Send the ping request 286 hdr := header(make([]byte, headerSize)) 287 hdr.encode(typePing, flagSYN, 0, id) 288 if err := s.waitForSend(hdr, nil); err != nil { 289 return 0, err 290 } 291 292 // Wait for a response 293 start := time.Now() 294 select { 295 case <-ch: 296 case <-time.After(s.config.ConnectionWriteTimeout): 297 s.pingLock.Lock() 298 delete(s.pings, id) // Ignore it if a response comes later. 299 s.pingLock.Unlock() 300 return 0, ErrTimeout 301 case <-s.shutdownCh: 302 return 0, ErrSessionShutdown 303 } 304 305 // Compute the RTT 306 return time.Now().Sub(start), nil 307} 308 309// keepalive is a long running goroutine that periodically does 310// a ping to keep the connection alive. 311func (s *Session) keepalive() { 312 for { 313 select { 314 case <-time.After(s.config.KeepAliveInterval): 315 _, err := s.Ping() 316 if err != nil { 317 if err != ErrSessionShutdown { 318 s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) 319 s.exitErr(ErrKeepAliveTimeout) 320 } 321 return 322 } 323 case <-s.shutdownCh: 324 return 325 } 326 } 327} 328 329// waitForSendErr waits to send a header, checking for a potential shutdown 330func (s *Session) waitForSend(hdr header, body io.Reader) error { 331 errCh := make(chan error, 1) 332 return s.waitForSendErr(hdr, body, errCh) 333} 334 335// waitForSendErr waits to send a header with optional data, checking for a 336// potential shutdown. Since there's the expectation that sends can happen 337// in a timely manner, we enforce the connection write timeout here. 338func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error { 339 t := timerPool.Get() 340 timer := t.(*time.Timer) 341 timer.Reset(s.config.ConnectionWriteTimeout) 342 defer func() { 343 timer.Stop() 344 select { 345 case <-timer.C: 346 default: 347 } 348 timerPool.Put(t) 349 }() 350 351 ready := sendReady{Hdr: hdr, Body: body, Err: errCh} 352 select { 353 case s.sendCh <- ready: 354 case <-s.shutdownCh: 355 return ErrSessionShutdown 356 case <-timer.C: 357 return ErrConnectionWriteTimeout 358 } 359 360 select { 361 case err := <-errCh: 362 return err 363 case <-s.shutdownCh: 364 return ErrSessionShutdown 365 case <-timer.C: 366 return ErrConnectionWriteTimeout 367 } 368} 369 370// sendNoWait does a send without waiting. Since there's the expectation that 371// the send happens right here, we enforce the connection write timeout if we 372// can't queue the header to be sent. 373func (s *Session) sendNoWait(hdr header) error { 374 t := timerPool.Get() 375 timer := t.(*time.Timer) 376 timer.Reset(s.config.ConnectionWriteTimeout) 377 defer func() { 378 timer.Stop() 379 select { 380 case <-timer.C: 381 default: 382 } 383 timerPool.Put(t) 384 }() 385 386 select { 387 case s.sendCh <- sendReady{Hdr: hdr}: 388 return nil 389 case <-s.shutdownCh: 390 return ErrSessionShutdown 391 case <-timer.C: 392 return ErrConnectionWriteTimeout 393 } 394} 395 396// send is a long running goroutine that sends data 397func (s *Session) send() { 398 for { 399 select { 400 case ready := <-s.sendCh: 401 // Send a header if ready 402 if ready.Hdr != nil { 403 sent := 0 404 for sent < len(ready.Hdr) { 405 n, err := s.conn.Write(ready.Hdr[sent:]) 406 if err != nil { 407 s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) 408 asyncSendErr(ready.Err, err) 409 s.exitErr(err) 410 return 411 } 412 sent += n 413 } 414 } 415 416 // Send data from a body if given 417 if ready.Body != nil { 418 _, err := io.Copy(s.conn, ready.Body) 419 if err != nil { 420 s.logger.Printf("[ERR] yamux: Failed to write body: %v", err) 421 asyncSendErr(ready.Err, err) 422 s.exitErr(err) 423 return 424 } 425 } 426 427 // No error, successful send 428 asyncSendErr(ready.Err, nil) 429 case <-s.shutdownCh: 430 return 431 } 432 } 433} 434 435// recv is a long running goroutine that accepts new data 436func (s *Session) recv() { 437 if err := s.recvLoop(); err != nil { 438 s.exitErr(err) 439 } 440} 441 442// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type 443var ( 444 handlers = []func(*Session, header) error{ 445 typeData: (*Session).handleStreamMessage, 446 typeWindowUpdate: (*Session).handleStreamMessage, 447 typePing: (*Session).handlePing, 448 typeGoAway: (*Session).handleGoAway, 449 } 450) 451 452// recvLoop continues to receive data until a fatal error is encountered 453func (s *Session) recvLoop() error { 454 defer close(s.recvDoneCh) 455 hdr := header(make([]byte, headerSize)) 456 for { 457 // Read the header 458 if _, err := io.ReadFull(s.bufRead, hdr); err != nil { 459 if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { 460 s.logger.Printf("[ERR] yamux: Failed to read header: %v", err) 461 } 462 return err 463 } 464 465 // Verify the version 466 if hdr.Version() != protoVersion { 467 s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version()) 468 return ErrInvalidVersion 469 } 470 471 mt := hdr.MsgType() 472 if mt < typeData || mt > typeGoAway { 473 return ErrInvalidMsgType 474 } 475 476 if err := handlers[mt](s, hdr); err != nil { 477 return err 478 } 479 } 480} 481 482// handleStreamMessage handles either a data or window update frame 483func (s *Session) handleStreamMessage(hdr header) error { 484 // Check for a new stream creation 485 id := hdr.StreamID() 486 flags := hdr.Flags() 487 if flags&flagSYN == flagSYN { 488 if err := s.incomingStream(id); err != nil { 489 return err 490 } 491 } 492 493 // Get the stream 494 s.streamLock.Lock() 495 stream := s.streams[id] 496 s.streamLock.Unlock() 497 498 // If we do not have a stream, likely we sent a RST 499 if stream == nil { 500 // Drain any data on the wire 501 if hdr.MsgType() == typeData && hdr.Length() > 0 { 502 s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id) 503 if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil { 504 s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err) 505 return nil 506 } 507 } else { 508 s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr) 509 } 510 return nil 511 } 512 513 // Check if this is a window update 514 if hdr.MsgType() == typeWindowUpdate { 515 if err := stream.incrSendWindow(hdr, flags); err != nil { 516 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 517 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 518 } 519 return err 520 } 521 return nil 522 } 523 524 // Read the new data 525 if err := stream.readData(hdr, flags, s.bufRead); err != nil { 526 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 527 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 528 } 529 return err 530 } 531 return nil 532} 533 534// handlePing is invokde for a typePing frame 535func (s *Session) handlePing(hdr header) error { 536 flags := hdr.Flags() 537 pingID := hdr.Length() 538 539 // Check if this is a query, respond back in a separate context so we 540 // don't interfere with the receiving thread blocking for the write. 541 if flags&flagSYN == flagSYN { 542 go func() { 543 hdr := header(make([]byte, headerSize)) 544 hdr.encode(typePing, flagACK, 0, pingID) 545 if err := s.sendNoWait(hdr); err != nil { 546 s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err) 547 } 548 }() 549 return nil 550 } 551 552 // Handle a response 553 s.pingLock.Lock() 554 ch := s.pings[pingID] 555 if ch != nil { 556 delete(s.pings, pingID) 557 close(ch) 558 } 559 s.pingLock.Unlock() 560 return nil 561} 562 563// handleGoAway is invokde for a typeGoAway frame 564func (s *Session) handleGoAway(hdr header) error { 565 code := hdr.Length() 566 switch code { 567 case goAwayNormal: 568 atomic.SwapInt32(&s.remoteGoAway, 1) 569 case goAwayProtoErr: 570 s.logger.Printf("[ERR] yamux: received protocol error go away") 571 return fmt.Errorf("yamux protocol error") 572 case goAwayInternalErr: 573 s.logger.Printf("[ERR] yamux: received internal error go away") 574 return fmt.Errorf("remote yamux internal error") 575 default: 576 s.logger.Printf("[ERR] yamux: received unexpected go away") 577 return fmt.Errorf("unexpected go away received") 578 } 579 return nil 580} 581 582// incomingStream is used to create a new incoming stream 583func (s *Session) incomingStream(id uint32) error { 584 // Reject immediately if we are doing a go away 585 if atomic.LoadInt32(&s.localGoAway) == 1 { 586 hdr := header(make([]byte, headerSize)) 587 hdr.encode(typeWindowUpdate, flagRST, id, 0) 588 return s.sendNoWait(hdr) 589 } 590 591 // Allocate a new stream 592 stream := newStream(s, id, streamSYNReceived) 593 594 s.streamLock.Lock() 595 defer s.streamLock.Unlock() 596 597 // Check if stream already exists 598 if _, ok := s.streams[id]; ok { 599 s.logger.Printf("[ERR] yamux: duplicate stream declared") 600 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 601 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 602 } 603 return ErrDuplicateStream 604 } 605 606 // Register the stream 607 s.streams[id] = stream 608 609 // Check if we've exceeded the backlog 610 select { 611 case s.acceptCh <- stream: 612 return nil 613 default: 614 // Backlog exceeded! RST the stream 615 s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset") 616 delete(s.streams, id) 617 stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0) 618 return s.sendNoWait(stream.sendHdr) 619 } 620} 621 622// closeStream is used to close a stream once both sides have 623// issued a close. If there was an in-flight SYN and the stream 624// was not yet established, then this will give the credit back. 625func (s *Session) closeStream(id uint32) { 626 s.streamLock.Lock() 627 if _, ok := s.inflight[id]; ok { 628 select { 629 case <-s.synCh: 630 default: 631 s.logger.Printf("[ERR] yamux: SYN tracking out of sync") 632 } 633 } 634 delete(s.streams, id) 635 s.streamLock.Unlock() 636} 637 638// establishStream is used to mark a stream that was in the 639// SYN Sent state as established. 640func (s *Session) establishStream(id uint32) { 641 s.streamLock.Lock() 642 if _, ok := s.inflight[id]; ok { 643 delete(s.inflight, id) 644 } else { 645 s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)") 646 } 647 select { 648 case <-s.synCh: 649 default: 650 s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)") 651 } 652 s.streamLock.Unlock() 653} 654