1package yamux
2
3import (
4	"bufio"
5	"context"
6	"fmt"
7	"io"
8	"io/ioutil"
9	"log"
10	"math"
11	"net"
12	"os"
13	"strings"
14	"sync"
15	"sync/atomic"
16	"time"
17
18	"github.com/libp2p/go-buffer-pool"
19)
20
// Session is used to wrap a reliable ordered connection and to
// multiplex it into multiple streams.
type Session struct {
	// remoteGoAway is set to 1 (atomically) when the remote side sent a
	// normal go away, i.e. it does not want further streams. Kept first in
	// the struct per the original alignment note.
	remoteGoAway int32

	// localGoAway is set to 1 (atomically) once we have sent a go away;
	// incoming streams are then rejected. Kept first in the struct per the
	// original alignment note.
	localGoAway int32

	// nextStreamID is the next stream ID we will use for an outbound stream.
	// It starts at 1 for a client and 2 for a server and advances by two,
	// so clients use odd IDs and servers even IDs. Accessed atomically.
	nextStreamID uint32

	// config holds our configuration
	config *Config

	// logger is used for our logs
	logger *log.Logger

	// conn is the underlying connection
	conn net.Conn

	// reader is where frames are read from: a bufio.Reader over conn when
	// newSession was given a positive read buffer size, otherwise conn
	// itself.
	reader io.Reader

	// Ping bookkeeping, guarded by pingLock: pingID is the ID the next ping
	// will use, and activePing is the ping currently in flight (nil if none).
	pingLock   sync.Mutex
	pingID     uint32
	activePing *ping

	// streams maps a stream id to a stream, and inflight has an entry
	// for any outgoing stream that has not yet been established. Both are
	// protected by streamLock.
	streams    map[uint32]*Stream
	inflight   map[uint32]struct{}
	streamLock sync.Mutex

	// synCh acts like a semaphore. It is sized to the AcceptBacklog which
	// is assumed to be symmetric between the client and server. This allows
	// the client to avoid exceeding the backlog and instead blocks the open.
	synCh chan struct{}

	// acceptCh passes incoming streams from the receive loop to
	// AcceptStream. It is buffered to AcceptBacklog.
	acceptCh chan *Stream

	// sendCh carries fully encoded frames from sendMsg to the send loop.
	sendCh chan []byte

	// pongCh carries IDs of received pings that must be acknowledged, and
	// pingCh carries IDs of pings to send; the send loop encodes both.
	pongCh, pingCh chan uint32

	// recvDoneCh is closed when recv() exits to avoid a race
	// between stream registration and stream shutdown
	recvDoneCh chan struct{}

	// sendDoneCh is closed when send() exits, so Close can wait for the send
	// loop to stop before force-closing streams (avoids a race between
	// returning from a Stream.Write and exiting from the send loop).
	sendDoneCh chan struct{}

	// client is true if we're the client and our stream IDs should be odd.
	client bool

	// Shutdown state, guarded by shutdownLock. shutdownErr is the error
	// reported to callers once the session is closed; shutdownCh is closed
	// when shutdown begins.
	shutdown     bool
	shutdownErr  error
	shutdownCh   chan struct{}
	shutdownLock sync.Mutex

	// Keepalive state, guarded by keepaliveLock. keepaliveTimer is nil when
	// keepalives are disabled or have been stopped; keepaliveActive is true
	// while a keepalive ping is in progress.
	keepaliveLock   sync.Mutex
	keepaliveTimer  *time.Timer
	keepaliveActive bool
}
98
99// newSession is used to construct a new session
100func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Session {
101	var reader io.Reader = conn
102	if readBuf > 0 {
103		reader = bufio.NewReaderSize(reader, readBuf)
104	}
105	s := &Session{
106		config:     config,
107		client:     client,
108		logger:     log.New(config.LogOutput, "", log.LstdFlags),
109		conn:       conn,
110		reader:     reader,
111		streams:    make(map[uint32]*Stream),
112		inflight:   make(map[uint32]struct{}),
113		synCh:      make(chan struct{}, config.AcceptBacklog),
114		acceptCh:   make(chan *Stream, config.AcceptBacklog),
115		sendCh:     make(chan []byte, 64),
116		pongCh:     make(chan uint32, config.PingBacklog),
117		pingCh:     make(chan uint32),
118		recvDoneCh: make(chan struct{}),
119		sendDoneCh: make(chan struct{}),
120		shutdownCh: make(chan struct{}),
121	}
122	if client {
123		s.nextStreamID = 1
124	} else {
125		s.nextStreamID = 2
126	}
127	if config.EnableKeepAlive {
128		s.startKeepalive()
129	}
130	go s.recv()
131	go s.send()
132	return s
133}
134
135// IsClosed does a safe check to see if we have shutdown
136func (s *Session) IsClosed() bool {
137	select {
138	case <-s.shutdownCh:
139		return true
140	default:
141		return false
142	}
143}
144
145// CloseChan returns a read-only channel which is closed as
146// soon as the session is closed.
147func (s *Session) CloseChan() <-chan struct{} {
148	return s.shutdownCh
149}
150
151// NumStreams returns the number of currently open streams
152func (s *Session) NumStreams() int {
153	s.streamLock.Lock()
154	num := len(s.streams)
155	s.streamLock.Unlock()
156	return num
157}
158
159// Open is used to create a new stream as a net.Conn
160func (s *Session) Open(ctx context.Context) (net.Conn, error) {
161	conn, err := s.OpenStream(ctx)
162	if err != nil {
163		return nil, err
164	}
165	return conn, nil
166}
167
168// OpenStream is used to create a new stream
169func (s *Session) OpenStream(ctx context.Context) (*Stream, error) {
170	if s.IsClosed() {
171		return nil, s.shutdownErr
172	}
173	if atomic.LoadInt32(&s.remoteGoAway) == 1 {
174		return nil, ErrRemoteGoAway
175	}
176
177	// Block if we have too many inflight SYNs
178	select {
179	case s.synCh <- struct{}{}:
180	case <-ctx.Done():
181		return nil, ctx.Err()
182	case <-s.shutdownCh:
183		return nil, s.shutdownErr
184	}
185
186GET_ID:
187	// Get an ID, and check for stream exhaustion
188	id := atomic.LoadUint32(&s.nextStreamID)
189	if id >= math.MaxUint32-1 {
190		return nil, ErrStreamsExhausted
191	}
192	if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
193		goto GET_ID
194	}
195
196	// Register the stream
197	stream := newStream(s, id, streamInit)
198	s.streamLock.Lock()
199	s.streams[id] = stream
200	s.inflight[id] = struct{}{}
201	s.streamLock.Unlock()
202
203	// Send the window update to create
204	if err := stream.sendWindowUpdate(); err != nil {
205		select {
206		case <-s.synCh:
207		default:
208			s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
209		}
210		return nil, err
211	}
212	return stream, nil
213}
214
215// Accept is used to block until the next available stream
216// is ready to be accepted.
217func (s *Session) Accept() (net.Conn, error) {
218	conn, err := s.AcceptStream()
219	if err != nil {
220		return nil, err
221	}
222	return conn, err
223}
224
225// AcceptStream is used to block until the next available stream
226// is ready to be accepted.
227func (s *Session) AcceptStream() (*Stream, error) {
228	for {
229		select {
230		case stream := <-s.acceptCh:
231			if err := stream.sendWindowUpdate(); err != nil {
232				// don't return accept errors.
233				s.logger.Printf("[WARN] error sending window update before accepting: %s", err)
234				continue
235			}
236			return stream, nil
237		case <-s.shutdownCh:
238			return nil, s.shutdownErr
239		}
240	}
241}
242
// Close shuts down the session and force-closes every registered stream.
// It is idempotent: once the session is shut down, further calls return nil
// immediately. Note that no GoAway frame is sent here; the underlying
// connection is simply closed.
//
// Shutdown order:
//  1. Record ErrSessionShutdown as the shutdown error, unless exitErr already
//     stored a more specific cause.
//  2. Close shutdownCh to wake every waiter.
//  3. Close the connection and stop keepalives.
//  4. Wait for both the receive and send goroutines to exit.
//  5. Force-close all streams.
//
// Always returns nil.
func (s *Session) Close() error {
	s.shutdownLock.Lock()
	defer s.shutdownLock.Unlock()

	if s.shutdown {
		return nil
	}
	s.shutdown = true
	if s.shutdownErr == nil {
		s.shutdownErr = ErrSessionShutdown
	}
	close(s.shutdownCh)
	// Closing conn unblocks a recv loop stuck in Read and a send loop stuck
	// in Write. NOTE(review): the error from conn.Close is discarded.
	s.conn.Close()
	s.stopKeepalive()
	// Both loops close their done channel (via defer) before calling
	// exitErr, so waiting here while holding shutdownLock cannot deadlock
	// with their exitErr -> Close path.
	<-s.recvDoneCh
	<-s.sendDoneCh

	s.streamLock.Lock()
	defer s.streamLock.Unlock()
	for _, stream := range s.streams {
		stream.forceClose()
	}
	return nil
}
269
270// exitErr is used to handle an error that is causing the
271// session to terminate.
272func (s *Session) exitErr(err error) {
273	s.shutdownLock.Lock()
274	if s.shutdownErr == nil {
275		s.shutdownErr = err
276	}
277	s.shutdownLock.Unlock()
278	s.Close()
279}
280
281// GoAway can be used to prevent accepting further
282// connections. It does not close the underlying conn.
283func (s *Session) GoAway() error {
284	return s.sendMsg(s.goAway(goAwayNormal), nil, nil)
285}
286
287// goAway is used to send a goAway message
288func (s *Session) goAway(reason uint32) header {
289	atomic.SwapInt32(&s.localGoAway, 1)
290	hdr := encode(typeGoAway, 0, 0, reason)
291	return hdr
292}
293
// Ping measures the round-trip time to the remote side.
//
// Concurrent callers share a single in-flight ping: if one is already
// active, the caller waits on it via activePing.wait(). Otherwise this call
// sends a new ping. It waits at most ConnectionWriteTimeout for the send
// loop to take the ping, and at most another ConnectionWriteTimeout for the
// reply. Returns ErrTimeout on either timeout, or the session's shutdown
// error if the session closes first.
//
// The results are named so the deferred function can publish the final
// (dur, err) to any waiters via activePing.finish.
func (s *Session) Ping() (dur time.Duration, err error) {
	// Prepare a ping.
	s.pingLock.Lock()
	// If there's an active ping, jump on the bandwagon.
	if activePing := s.activePing; activePing != nil {
		s.pingLock.Unlock()
		return activePing.wait()
	}

	// Ok, our job to send the ping.
	activePing := newPing(s.pingID)
	s.pingID++
	s.activePing = activePing
	s.pingLock.Unlock()

	defer func() {
		// complete ping promise
		activePing.finish(dur, err)

		// Unset it.
		s.pingLock.Lock()
		s.activePing = nil
		s.pingLock.Unlock()
	}()

	// Send the ping request, waiting at most one connection write timeout
	// to flush it.
	timer := time.NewTimer(s.config.ConnectionWriteTimeout)
	defer timer.Stop()
	select {
	case s.pingCh <- activePing.id:
	case <-timer.C:
		return 0, ErrTimeout
	case <-s.shutdownCh:
		return 0, s.shutdownErr
	}

	// The "time" starts once the send loop has taken the ping (pingCh is
	// unbuffered). Otherwise, we'd measure the time it takes to flush the
	// queue as well.
	start := time.Now()

	// Wait for a response, again waiting at most one write timeout.
	// On this path timer.C has not been received from, so if Stop reports
	// the timer already fired, its value is drained before Reset to avoid
	// a spurious immediate timeout.
	if !timer.Stop() {
		<-timer.C
	}
	timer.Reset(s.config.ConnectionWriteTimeout)
	select {
	case <-activePing.pingResponse:
	case <-timer.C:
		return 0, ErrTimeout
	case <-s.shutdownCh:
		return 0, s.shutdownErr
	}

	// Compute the RTT
	return time.Since(start), nil
}
352
// startKeepalive arms a timer that sends a keepalive ping every
// KeepAliveInterval.
//
// The timer callback runs on its own goroutine (time.AfterFunc). It skips
// the ping when keepalives were stopped (keepaliveTimer is nil) or a
// keepalive ping is already running. After each ping it re-arms the timer,
// unless keepalives were stopped in the meantime.
//
// If the ping fails, the error is logged and the session is torn down via
// exitErr with ErrKeepAliveTimeout, whatever the underlying error was.
// NOTE(review): if the session is already closing, Ping returns the shutdown
// error, so this can log a spurious keepalive failure. The extra exitErr is
// harmless because Close is idempotent.
func (s *Session) startKeepalive() {
	s.keepaliveLock.Lock()
	defer s.keepaliveLock.Unlock()
	s.keepaliveTimer = time.AfterFunc(s.config.KeepAliveInterval, func() {
		s.keepaliveLock.Lock()
		if s.keepaliveTimer == nil || s.keepaliveActive {
			// keepalives have been stopped or a keepalive is active.
			s.keepaliveLock.Unlock()
			return
		}
		// Mark the ping as active so extendKeepalive leaves the timer alone
		// while it runs.
		s.keepaliveActive = true
		s.keepaliveLock.Unlock()

		// Ping is called without keepaliveLock held; it may block for up to
		// two write timeouts.
		_, err := s.Ping()

		s.keepaliveLock.Lock()
		s.keepaliveActive = false
		// Re-arm under the lock so this cannot race with stopKeepalive
		// clearing keepaliveTimer.
		if s.keepaliveTimer != nil {
			s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
		}
		s.keepaliveLock.Unlock()

		if err != nil {
			s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
			s.exitErr(ErrKeepAliveTimeout)
		}
	})
}
382
383// stopKeepalive stops the keepalive process.
384func (s *Session) stopKeepalive() {
385	s.keepaliveLock.Lock()
386	defer s.keepaliveLock.Unlock()
387	if s.keepaliveTimer != nil {
388		s.keepaliveTimer.Stop()
389		s.keepaliveTimer = nil
390	}
391}
392
393func (s *Session) extendKeepalive() {
394	s.keepaliveLock.Lock()
395	if s.keepaliveTimer != nil && !s.keepaliveActive {
396		// Don't stop the timer and drain the channel. This is an
397		// AfterFunc, not a normal timer, and any attempts to drain the
398		// channel will block forever.
399		//
400		// Go will stop the timer for us internally anyways. The docs
401		// say one must stop the timer before calling reset but that's
402		// to ensure that the timer doesn't end up firing immediately
403		// after calling Reset.
404		s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
405	}
406	s.keepaliveLock.Unlock()
407}
408
409// send sends the header and body.
410func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) error {
411	select {
412	case <-s.shutdownCh:
413		return s.shutdownErr
414	default:
415	}
416
417	// duplicate as we're sending this async.
418	buf := pool.Get(headerSize + len(body))
419	copy(buf[:headerSize], hdr[:])
420	copy(buf[headerSize:], body)
421
422	select {
423	case <-s.shutdownCh:
424		pool.Put(buf)
425		return s.shutdownErr
426	case s.sendCh <- buf:
427		return nil
428	case <-deadline:
429		pool.Put(buf)
430		return ErrTimeout
431	}
432}
433
434// send is a long running goroutine that sends data
435func (s *Session) send() {
436	if err := s.sendLoop(); err != nil {
437		s.exitErr(err)
438	}
439}
440
// sendLoop is the body of the send goroutine. It writes frames straight to
// conn, one Write per frame:
//   - queued frames from sendCh;
//   - ping requests (SYN) built from the IDs on pingCh;
//   - ping replies (ACK) built from the IDs on pongCh.
//
// Every buffer it takes is returned to the pool after the write. It returns
// nil when the session shuts down, otherwise the first error from setting
// the write deadline or writing; a write timeout is reported as
// ErrConnectionWriteTimeout. sendDoneCh is closed on exit.
// NOTE(review): frames still queued in sendCh at shutdown are dropped
// without being returned to the pool.
func (s *Session) sendLoop() error {
	defer close(s.sendDoneCh)

	// Extend the write deadline if we've passed the halfway point. This can
	// be expensive so this ensures we only have to do this once every
	// ConnectionWriteTimeout/2.
	// NOTE(review): a ConnectionWriteTimeout of 0 would set the deadline to
	// "now" and make every write time out; presumably the config is
	// validated elsewhere — confirm.
	var lastWriteDeadline time.Time
	extendWriteDeadline := func() error {
		now := time.Now()
		// If over half of the deadline has elapsed, extend it.
		if now.Add(s.config.ConnectionWriteTimeout / 2).After(lastWriteDeadline) {
			lastWriteDeadline = now.Add(s.config.ConnectionWriteTimeout)
			return s.conn.SetWriteDeadline(lastWriteDeadline)
		}
		return nil
	}

	writer := s.conn

	// FIXME: https://github.com/libp2p/go-libp2p/issues/644
	// Write coalescing is disabled for now. The commented-out code below (and
	// the commented-out default branch in the loop) is the old coalescing
	// path, kept for when it is re-enabled.

	//writer := pool.Writer{W: s.conn}

	//var writeTimeout *time.Timer
	//var writeTimeoutCh <-chan time.Time
	//if s.config.WriteCoalesceDelay > 0 {
	//	writeTimeout = time.NewTimer(s.config.WriteCoalesceDelay)
	//	defer writeTimeout.Stop()

	//	writeTimeoutCh = writeTimeout.C
	//} else {
	//	ch := make(chan time.Time)
	//	close(ch)
	//	writeTimeoutCh = ch
	//}

	for {
		// Stop promptly once shut down. s.sendCh is a buffered channel, and
		// when several select cases are ready Go picks one at random, so
		// check shutdown on its own first.
		select {
		case <-s.shutdownCh:
			return nil
		default:
		}

		// Wait for the next frame to write. With coalescing disabled there
		// is no buffering: each frame is written as soon as it is received.
		var buf []byte
		select {
		case buf = <-s.sendCh:
		case pingID := <-s.pingCh:
			buf = pool.Get(headerSize)
			hdr := encode(typePing, flagSYN, 0, pingID)
			copy(buf, hdr[:])
		case pingID := <-s.pongCh:
			buf = pool.Get(headerSize)
			hdr := encode(typePing, flagACK, 0, pingID)
			copy(buf, hdr[:])
		case <-s.shutdownCh:
			return nil
			//default:
			//	select {
			//	case buf = <-s.sendCh:
			//	case <-s.shutdownCh:
			//		return nil
			//	case <-writeTimeoutCh:
			//		if err := writer.Flush(); err != nil {
			//			if os.IsTimeout(err) {
			//				err = ErrConnectionWriteTimeout
			//			}
			//			return err
			//		}

			//		select {
			//		case buf = <-s.sendCh:
			//		case <-s.shutdownCh:
			//			return nil
			//		}

			//		if writeTimeout != nil {
			//			writeTimeout.Reset(s.config.WriteCoalesceDelay)
			//		}
			//	}
		}

		if err := extendWriteDeadline(); err != nil {
			pool.Put(buf)
			return err
		}

		_, err := writer.Write(buf)
		pool.Put(buf)

		if err != nil {
			if os.IsTimeout(err) {
				err = ErrConnectionWriteTimeout
			}
			return err
		}
	}
}
543
544// recv is a long running goroutine that accepts new data
545func (s *Session) recv() {
546	if err := s.recvLoop(); err != nil {
547		s.exitErr(err)
548	}
549}
550
551// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type
552var (
553	handlers = []func(*Session, header) error{
554		typeData:         (*Session).handleStreamMessage,
555		typeWindowUpdate: (*Session).handleStreamMessage,
556		typePing:         (*Session).handlePing,
557		typeGoAway:       (*Session).handleGoAway,
558	}
559)
560
561// recvLoop continues to receive data until a fatal error is encountered
562func (s *Session) recvLoop() error {
563	defer close(s.recvDoneCh)
564	var hdr header
565	for {
566		// Read the header
567		if _, err := io.ReadFull(s.reader, hdr[:]); err != nil {
568			if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
569				s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
570			}
571			return err
572		}
573
574		// Reset the keepalive timer every time we receive data.
575		// There's no reason to keepalive if we're active. Worse, if the
576		// peer is busy sending us stuff, the pong might get stuck
577		// behind a bunch of data.
578		s.extendKeepalive()
579
580		// Verify the version
581		if hdr.Version() != protoVersion {
582			s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
583			return ErrInvalidVersion
584		}
585
586		mt := hdr.MsgType()
587		if mt < typeData || mt > typeGoAway {
588			return ErrInvalidMsgType
589		}
590
591		if err := handlers[mt](s, hdr); err != nil {
592			return err
593		}
594	}
595}
596
597// handleStreamMessage handles either a data or window update frame
598func (s *Session) handleStreamMessage(hdr header) error {
599	// Check for a new stream creation
600	id := hdr.StreamID()
601	flags := hdr.Flags()
602	if flags&flagSYN == flagSYN {
603		if err := s.incomingStream(id); err != nil {
604			return err
605		}
606	}
607
608	// Get the stream
609	s.streamLock.Lock()
610	stream := s.streams[id]
611	s.streamLock.Unlock()
612
613	// If we do not have a stream, likely we sent a RST
614	if stream == nil {
615		// Drain any data on the wire
616		if hdr.MsgType() == typeData && hdr.Length() > 0 {
617			s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
618			if _, err := io.CopyN(ioutil.Discard, s.reader, int64(hdr.Length())); err != nil {
619				s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
620				return nil
621			}
622		} else {
623			s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
624		}
625		return nil
626	}
627
628	// Check if this is a window update
629	if hdr.MsgType() == typeWindowUpdate {
630		if err := stream.incrSendWindow(hdr, flags); err != nil {
631			if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil {
632				s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
633			}
634			return err
635		}
636		return nil
637	}
638
639	// Read the new data
640	if err := stream.readData(hdr, flags, s.reader); err != nil {
641		if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil {
642			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
643		}
644		return err
645	}
646	return nil
647}
648
649// handlePing is invoked for a typePing frame
650func (s *Session) handlePing(hdr header) error {
651	flags := hdr.Flags()
652	pingID := hdr.Length()
653
654	// Check if this is a query, respond back in a separate context so we
655	// don't interfere with the receiving thread blocking for the write.
656	if flags&flagSYN == flagSYN {
657		select {
658		case s.pongCh <- pingID:
659		default:
660			s.logger.Printf("[WARN] yamux: dropped ping reply")
661		}
662		return nil
663	}
664
665	// Handle a response
666	s.pingLock.Lock()
667	// If we have an active ping, and this is a response to that active
668	// ping, complete the ping.
669	if s.activePing != nil && s.activePing.id == pingID {
670		// Don't assume that the peer won't send multiple responses for
671		// the same ping.
672		select {
673		case s.activePing.pingResponse <- struct{}{}:
674		default:
675		}
676	}
677	s.pingLock.Unlock()
678	return nil
679}
680
681// handleGoAway is invokde for a typeGoAway frame
682func (s *Session) handleGoAway(hdr header) error {
683	code := hdr.Length()
684	switch code {
685	case goAwayNormal:
686		atomic.SwapInt32(&s.remoteGoAway, 1)
687	case goAwayProtoErr:
688		s.logger.Printf("[ERR] yamux: received protocol error go away")
689		return fmt.Errorf("yamux protocol error")
690	case goAwayInternalErr:
691		s.logger.Printf("[ERR] yamux: received internal error go away")
692		return fmt.Errorf("remote yamux internal error")
693	default:
694		s.logger.Printf("[ERR] yamux: received unexpected go away")
695		return fmt.Errorf("unexpected go away received")
696	}
697	return nil
698}
699
700// incomingStream is used to create a new incoming stream
701func (s *Session) incomingStream(id uint32) error {
702	if s.client != (id%2 == 0) {
703		s.logger.Printf("[ERR] yamux: both endpoints are clients")
704		return fmt.Errorf("both yamux endpoints are clients")
705	}
706	// Reject immediately if we are doing a go away
707	if atomic.LoadInt32(&s.localGoAway) == 1 {
708		hdr := encode(typeWindowUpdate, flagRST, id, 0)
709		return s.sendMsg(hdr, nil, nil)
710	}
711
712	// Allocate a new stream
713	stream := newStream(s, id, streamSYNReceived)
714
715	s.streamLock.Lock()
716	defer s.streamLock.Unlock()
717
718	// Check if stream already exists
719	if _, ok := s.streams[id]; ok {
720		s.logger.Printf("[ERR] yamux: duplicate stream declared")
721		if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil {
722			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
723		}
724		return ErrDuplicateStream
725	}
726
727	// Register the stream
728	s.streams[id] = stream
729
730	// Check if we've exceeded the backlog
731	select {
732	case s.acceptCh <- stream:
733		return nil
734	default:
735		// Backlog exceeded! RST the stream
736		s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
737		delete(s.streams, id)
738		hdr := encode(typeWindowUpdate, flagRST, id, 0)
739		return s.sendMsg(hdr, nil, nil)
740	}
741}
742
743// closeStream is used to close a stream once both sides have
744// issued a close. If there was an in-flight SYN and the stream
745// was not yet established, then this will give the credit back.
746func (s *Session) closeStream(id uint32) {
747	s.streamLock.Lock()
748	if _, ok := s.inflight[id]; ok {
749		select {
750		case <-s.synCh:
751		default:
752			s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
753		}
754		delete(s.inflight, id)
755	}
756	delete(s.streams, id)
757	s.streamLock.Unlock()
758}
759
760// establishStream is used to mark a stream that was in the
761// SYN Sent state as established.
762func (s *Session) establishStream(id uint32) {
763	s.streamLock.Lock()
764	if _, ok := s.inflight[id]; ok {
765		delete(s.inflight, id)
766	} else {
767		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
768	}
769	select {
770	case <-s.synCh:
771	default:
772		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
773	}
774	s.streamLock.Unlock()
775}
776