1package yamux 2 3import ( 4 "bufio" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "log" 9 "math" 10 "net" 11 "strings" 12 "sync" 13 "sync/atomic" 14 "time" 15) 16 17// Session is used to wrap a reliable ordered connection and to 18// multiplex it into multiple streams. 19type Session struct { 20 // remoteGoAway indicates the remote side does 21 // not want futher connections. Must be first for alignment. 22 remoteGoAway int32 23 24 // localGoAway indicates that we should stop 25 // accepting futher connections. Must be first for alignment. 26 localGoAway int32 27 28 // nextStreamID is the next stream we should 29 // send. This depends if we are a client/server. 30 nextStreamID uint32 31 32 // config holds our configuration 33 config *Config 34 35 // logger is used for our logs 36 logger *log.Logger 37 38 // conn is the underlying connection 39 conn io.ReadWriteCloser 40 41 // bufRead is a buffered reader 42 bufRead *bufio.Reader 43 44 // pings is used to track inflight pings 45 pings map[uint32]chan struct{} 46 pingID uint32 47 pingLock sync.Mutex 48 49 // streams maps a stream id to a stream, and inflight has an entry 50 // for any outgoing stream that has not yet been established. Both are 51 // protected by streamLock. 52 streams map[uint32]*Stream 53 inflight map[uint32]struct{} 54 streamLock sync.Mutex 55 56 // synCh acts like a semaphore. It is sized to the AcceptBacklog which 57 // is assumed to be symmetric between the client and server. This allows 58 // the client to avoid exceeding the backlog and instead blocks the open. 59 synCh chan struct{} 60 61 // acceptCh is used to pass ready streams to the client 62 acceptCh chan *Stream 63 64 // sendCh is used to mark a stream as ready to send, 65 // or to send a header out directly. 66 sendCh chan sendReady 67 68 // recvDoneCh is closed when recv() exits to avoid a race 69 // between stream registration and stream shutdown 70 recvDoneCh chan struct{} 71 72 // shutdown is used to safely close a session 73 shutdown bool 74 shutdownErr error 75 shutdownCh chan struct{} 76 shutdownLock sync.Mutex 77} 78 79// sendReady is used to either mark a stream as ready 80// or to directly send a header 81type sendReady struct { 82 Hdr []byte 83 Body io.Reader 84 Err chan error 85} 86 87// newSession is used to construct a new session 88func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { 89 s := &Session{ 90 config: config, 91 logger: log.New(config.LogOutput, "", log.LstdFlags), 92 conn: conn, 93 bufRead: bufio.NewReader(conn), 94 pings: make(map[uint32]chan struct{}), 95 streams: make(map[uint32]*Stream), 96 inflight: make(map[uint32]struct{}), 97 synCh: make(chan struct{}, config.AcceptBacklog), 98 acceptCh: make(chan *Stream, config.AcceptBacklog), 99 sendCh: make(chan sendReady, 64), 100 recvDoneCh: make(chan struct{}), 101 shutdownCh: make(chan struct{}), 102 } 103 if client { 104 s.nextStreamID = 1 105 } else { 106 s.nextStreamID = 2 107 } 108 go s.recv() 109 go s.send() 110 if config.EnableKeepAlive { 111 go s.keepalive() 112 } 113 return s 114} 115 116// IsClosed does a safe check to see if we have shutdown 117func (s *Session) IsClosed() bool { 118 select { 119 case <-s.shutdownCh: 120 return true 121 default: 122 return false 123 } 124} 125 126// NumStreams returns the number of currently open streams 127func (s *Session) NumStreams() int { 128 s.streamLock.Lock() 129 num := len(s.streams) 130 s.streamLock.Unlock() 131 return num 132} 133 134// Open is used to create a new stream as a net.Conn 135func (s *Session) Open() (net.Conn, error) { 136 conn, err := s.OpenStream() 137 if err != nil { 138 return nil, err 139 } 140 return conn, nil 141} 142 143// OpenStream is used to create a new stream 144func (s *Session) OpenStream() (*Stream, error) { 145 if s.IsClosed() { 146 return nil, ErrSessionShutdown 147 } 148 if atomic.LoadInt32(&s.remoteGoAway) == 1 { 149 return nil, ErrRemoteGoAway 150 } 151 152 // Block if we have too many inflight SYNs 153 select { 154 case s.synCh <- struct{}{}: 155 case <-s.shutdownCh: 156 return nil, ErrSessionShutdown 157 } 158 159GET_ID: 160 // Get an ID, and check for stream exhaustion 161 id := atomic.LoadUint32(&s.nextStreamID) 162 if id >= math.MaxUint32-1 { 163 return nil, ErrStreamsExhausted 164 } 165 if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) { 166 goto GET_ID 167 } 168 169 // Register the stream 170 stream := newStream(s, id, streamInit) 171 s.streamLock.Lock() 172 s.streams[id] = stream 173 s.inflight[id] = struct{}{} 174 s.streamLock.Unlock() 175 176 // Send the window update to create 177 if err := stream.sendWindowUpdate(); err != nil { 178 select { 179 case <-s.synCh: 180 default: 181 s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore") 182 } 183 return nil, err 184 } 185 return stream, nil 186} 187 188// Accept is used to block until the next available stream 189// is ready to be accepted. 190func (s *Session) Accept() (net.Conn, error) { 191 conn, err := s.AcceptStream() 192 if err != nil { 193 return nil, err 194 } 195 return conn, err 196} 197 198// AcceptStream is used to block until the next available stream 199// is ready to be accepted. 200func (s *Session) AcceptStream() (*Stream, error) { 201 select { 202 case stream := <-s.acceptCh: 203 if err := stream.sendWindowUpdate(); err != nil { 204 return nil, err 205 } 206 return stream, nil 207 case <-s.shutdownCh: 208 return nil, s.shutdownErr 209 } 210} 211 212// Close is used to close the session and all streams. 213// Attempts to send a GoAway before closing the connection. 214func (s *Session) Close() error { 215 s.shutdownLock.Lock() 216 defer s.shutdownLock.Unlock() 217 218 if s.shutdown { 219 return nil 220 } 221 s.shutdown = true 222 if s.shutdownErr == nil { 223 s.shutdownErr = ErrSessionShutdown 224 } 225 close(s.shutdownCh) 226 s.conn.Close() 227 <-s.recvDoneCh 228 229 s.streamLock.Lock() 230 defer s.streamLock.Unlock() 231 for _, stream := range s.streams { 232 stream.forceClose() 233 } 234 return nil 235} 236 237// exitErr is used to handle an error that is causing the 238// session to terminate. 239func (s *Session) exitErr(err error) { 240 s.shutdownLock.Lock() 241 if s.shutdownErr == nil { 242 s.shutdownErr = err 243 } 244 s.shutdownLock.Unlock() 245 s.Close() 246} 247 248// GoAway can be used to prevent accepting further 249// connections. It does not close the underlying conn. 250func (s *Session) GoAway() error { 251 return s.waitForSend(s.goAway(goAwayNormal), nil) 252} 253 254// goAway is used to send a goAway message 255func (s *Session) goAway(reason uint32) header { 256 atomic.SwapInt32(&s.localGoAway, 1) 257 hdr := header(make([]byte, headerSize)) 258 hdr.encode(typeGoAway, 0, 0, reason) 259 return hdr 260} 261 262// Ping is used to measure the RTT response time 263func (s *Session) Ping() (time.Duration, error) { 264 // Get a channel for the ping 265 ch := make(chan struct{}) 266 267 // Get a new ping id, mark as pending 268 s.pingLock.Lock() 269 id := s.pingID 270 s.pingID++ 271 s.pings[id] = ch 272 s.pingLock.Unlock() 273 274 // Send the ping request 275 hdr := header(make([]byte, headerSize)) 276 hdr.encode(typePing, flagSYN, 0, id) 277 if err := s.waitForSend(hdr, nil); err != nil { 278 return 0, err 279 } 280 281 // Wait for a response 282 start := time.Now() 283 select { 284 case <-ch: 285 case <-time.After(s.config.ConnectionWriteTimeout): 286 s.pingLock.Lock() 287 delete(s.pings, id) // Ignore it if a response comes later. 288 s.pingLock.Unlock() 289 return 0, ErrTimeout 290 case <-s.shutdownCh: 291 return 0, ErrSessionShutdown 292 } 293 294 // Compute the RTT 295 return time.Now().Sub(start), nil 296} 297 298// keepalive is a long running goroutine that periodically does 299// a ping to keep the connection alive. 300func (s *Session) keepalive() { 301 for { 302 select { 303 case <-time.After(s.config.KeepAliveInterval): 304 _, err := s.Ping() 305 if err != nil { 306 s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) 307 s.exitErr(ErrKeepAliveTimeout) 308 return 309 } 310 case <-s.shutdownCh: 311 return 312 } 313 } 314} 315 316// waitForSendErr waits to send a header, checking for a potential shutdown 317func (s *Session) waitForSend(hdr header, body io.Reader) error { 318 errCh := make(chan error, 1) 319 return s.waitForSendErr(hdr, body, errCh) 320} 321 322// waitForSendErr waits to send a header with optional data, checking for a 323// potential shutdown. Since there's the expectation that sends can happen 324// in a timely manner, we enforce the connection write timeout here. 325func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error { 326 timer := time.NewTimer(s.config.ConnectionWriteTimeout) 327 defer timer.Stop() 328 329 ready := sendReady{Hdr: hdr, Body: body, Err: errCh} 330 select { 331 case s.sendCh <- ready: 332 case <-s.shutdownCh: 333 return ErrSessionShutdown 334 case <-timer.C: 335 return ErrConnectionWriteTimeout 336 } 337 338 select { 339 case err := <-errCh: 340 return err 341 case <-s.shutdownCh: 342 return ErrSessionShutdown 343 case <-timer.C: 344 return ErrConnectionWriteTimeout 345 } 346} 347 348// sendNoWait does a send without waiting. Since there's the expectation that 349// the send happens right here, we enforce the connection write timeout if we 350// can't queue the header to be sent. 351func (s *Session) sendNoWait(hdr header) error { 352 timer := time.NewTimer(s.config.ConnectionWriteTimeout) 353 defer timer.Stop() 354 355 select { 356 case s.sendCh <- sendReady{Hdr: hdr}: 357 return nil 358 case <-s.shutdownCh: 359 return ErrSessionShutdown 360 case <-timer.C: 361 return ErrConnectionWriteTimeout 362 } 363} 364 365// send is a long running goroutine that sends data 366func (s *Session) send() { 367 for { 368 select { 369 case ready := <-s.sendCh: 370 // Send a header if ready 371 if ready.Hdr != nil { 372 sent := 0 373 for sent < len(ready.Hdr) { 374 n, err := s.conn.Write(ready.Hdr[sent:]) 375 if err != nil { 376 s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) 377 asyncSendErr(ready.Err, err) 378 s.exitErr(err) 379 return 380 } 381 sent += n 382 } 383 } 384 385 // Send data from a body if given 386 if ready.Body != nil { 387 _, err := io.Copy(s.conn, ready.Body) 388 if err != nil { 389 s.logger.Printf("[ERR] yamux: Failed to write body: %v", err) 390 asyncSendErr(ready.Err, err) 391 s.exitErr(err) 392 return 393 } 394 } 395 396 // No error, successful send 397 asyncSendErr(ready.Err, nil) 398 case <-s.shutdownCh: 399 return 400 } 401 } 402} 403 404// recv is a long running goroutine that accepts new data 405func (s *Session) recv() { 406 if err := s.recvLoop(); err != nil { 407 s.exitErr(err) 408 } 409} 410 411// recvLoop continues to receive data until a fatal error is encountered 412func (s *Session) recvLoop() error { 413 defer close(s.recvDoneCh) 414 hdr := header(make([]byte, headerSize)) 415 var handler func(header) error 416 for { 417 // Read the header 418 if _, err := io.ReadFull(s.bufRead, hdr); err != nil { 419 if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { 420 s.logger.Printf("[ERR] yamux: Failed to read header: %v", err) 421 } 422 return err 423 } 424 425 // Verify the version 426 if hdr.Version() != protoVersion { 427 s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version()) 428 return ErrInvalidVersion 429 } 430 431 // Switch on the type 432 switch hdr.MsgType() { 433 case typeData: 434 handler = s.handleStreamMessage 435 case typeWindowUpdate: 436 handler = s.handleStreamMessage 437 case typeGoAway: 438 handler = s.handleGoAway 439 case typePing: 440 handler = s.handlePing 441 default: 442 return ErrInvalidMsgType 443 } 444 445 // Invoke the handler 446 if err := handler(hdr); err != nil { 447 return err 448 } 449 } 450} 451 452// handleStreamMessage handles either a data or window update frame 453func (s *Session) handleStreamMessage(hdr header) error { 454 // Check for a new stream creation 455 id := hdr.StreamID() 456 flags := hdr.Flags() 457 if flags&flagSYN == flagSYN { 458 if err := s.incomingStream(id); err != nil { 459 return err 460 } 461 } 462 463 // Get the stream 464 s.streamLock.Lock() 465 stream := s.streams[id] 466 s.streamLock.Unlock() 467 468 // If we do not have a stream, likely we sent a RST 469 if stream == nil { 470 // Drain any data on the wire 471 if hdr.MsgType() == typeData && hdr.Length() > 0 { 472 s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id) 473 if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil { 474 s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err) 475 return nil 476 } 477 } else { 478 s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr) 479 } 480 return nil 481 } 482 483 // Check if this is a window update 484 if hdr.MsgType() == typeWindowUpdate { 485 if err := stream.incrSendWindow(hdr, flags); err != nil { 486 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 487 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 488 } 489 return err 490 } 491 return nil 492 } 493 494 // Read the new data 495 if err := stream.readData(hdr, flags, s.bufRead); err != nil { 496 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 497 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 498 } 499 return err 500 } 501 return nil 502} 503 504// handlePing is invokde for a typePing frame 505func (s *Session) handlePing(hdr header) error { 506 flags := hdr.Flags() 507 pingID := hdr.Length() 508 509 // Check if this is a query, respond back in a separate context so we 510 // don't interfere with the receiving thread blocking for the write. 511 if flags&flagSYN == flagSYN { 512 go func() { 513 hdr := header(make([]byte, headerSize)) 514 hdr.encode(typePing, flagACK, 0, pingID) 515 if err := s.sendNoWait(hdr); err != nil { 516 s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err) 517 } 518 }() 519 return nil 520 } 521 522 // Handle a response 523 s.pingLock.Lock() 524 ch := s.pings[pingID] 525 if ch != nil { 526 delete(s.pings, pingID) 527 close(ch) 528 } 529 s.pingLock.Unlock() 530 return nil 531} 532 533// handleGoAway is invokde for a typeGoAway frame 534func (s *Session) handleGoAway(hdr header) error { 535 code := hdr.Length() 536 switch code { 537 case goAwayNormal: 538 atomic.SwapInt32(&s.remoteGoAway, 1) 539 case goAwayProtoErr: 540 s.logger.Printf("[ERR] yamux: received protocol error go away") 541 return fmt.Errorf("yamux protocol error") 542 case goAwayInternalErr: 543 s.logger.Printf("[ERR] yamux: received internal error go away") 544 return fmt.Errorf("remote yamux internal error") 545 default: 546 s.logger.Printf("[ERR] yamux: received unexpected go away") 547 return fmt.Errorf("unexpected go away received") 548 } 549 return nil 550} 551 552// incomingStream is used to create a new incoming stream 553func (s *Session) incomingStream(id uint32) error { 554 // Reject immediately if we are doing a go away 555 if atomic.LoadInt32(&s.localGoAway) == 1 { 556 hdr := header(make([]byte, headerSize)) 557 hdr.encode(typeWindowUpdate, flagRST, id, 0) 558 return s.sendNoWait(hdr) 559 } 560 561 // Allocate a new stream 562 stream := newStream(s, id, streamSYNReceived) 563 564 s.streamLock.Lock() 565 defer s.streamLock.Unlock() 566 567 // Check if stream already exists 568 if _, ok := s.streams[id]; ok { 569 s.logger.Printf("[ERR] yamux: duplicate stream declared") 570 if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 571 s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 572 } 573 return ErrDuplicateStream 574 } 575 576 // Register the stream 577 s.streams[id] = stream 578 579 // Check if we've exceeded the backlog 580 select { 581 case s.acceptCh <- stream: 582 return nil 583 default: 584 // Backlog exceeded! RST the stream 585 s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset") 586 delete(s.streams, id) 587 stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0) 588 return s.sendNoWait(stream.sendHdr) 589 } 590} 591 592// closeStream is used to close a stream once both sides have 593// issued a close. If there was an in-flight SYN and the stream 594// was not yet established, then this will give the credit back. 595func (s *Session) closeStream(id uint32) { 596 s.streamLock.Lock() 597 if _, ok := s.inflight[id]; ok { 598 select { 599 case <-s.synCh: 600 default: 601 s.logger.Printf("[ERR] yamux: SYN tracking out of sync") 602 } 603 } 604 delete(s.streams, id) 605 s.streamLock.Unlock() 606} 607 608// establishStream is used to mark a stream that was in the 609// SYN Sent state as established. 610func (s *Session) establishStream(id uint32) { 611 s.streamLock.Lock() 612 if _, ok := s.inflight[id]; ok { 613 delete(s.inflight, id) 614 } else { 615 s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)") 616 } 617 select { 618 case <-s.synCh: 619 default: 620 s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)") 621 } 622 s.streamLock.Unlock() 623} 624