1package yamux
2
3import (
4	"bufio"
5	"fmt"
6	"io"
7	"io/ioutil"
8	"log"
9	"math"
10	"net"
11	"strings"
12	"sync"
13	"sync/atomic"
14	"time"
15)
16
// Session is used to wrap a reliable ordered connection and to
// multiplex it into multiple streams.
type Session struct {
	// remoteGoAway is set to 1 (atomically) once the remote side sends a
	// normal GoAway, meaning it does not want further streams; OpenStream
	// then fails with ErrRemoteGoAway.
	// NOTE(review): this used to say "Must be first for alignment". Only
	// 64-bit atomics need special alignment on 32-bit platforms, so this
	// looks historical for an int32 — confirm before reordering fields.
	remoteGoAway int32

	// localGoAway is set to 1 (atomically) once we have produced a GoAway
	// frame; from then on incomingStream refuses new streams with an RST.
	localGoAway int32

	// nextStreamID is the next stream ID OpenStream will claim. It is
	// advanced atomically by 2 and starts at 1 for a client and 2 for a
	// server, so the two sides never pick the same ID.
	nextStreamID uint32

	// config holds our configuration
	config *Config

	// logger is used for our logs
	logger *log.Logger

	// conn is the underlying connection
	conn io.ReadWriteCloser

	// bufRead is a buffered reader over conn; the receive loop reads frame
	// headers and frame payloads through it.
	bufRead *bufio.Reader

	// pings maps each outstanding ping ID to the channel its Ping call is
	// waiting on, and pingID is the next ID to hand out. Both are guarded
	// by pingLock.
	pings    map[uint32]chan struct{}
	pingID   uint32
	pingLock sync.Mutex

	// streams maps a stream id to a stream, and inflight has an entry
	// for any outgoing stream that has not yet been established. Both are
	// protected by streamLock.
	streams    map[uint32]*Stream
	inflight   map[uint32]struct{}
	streamLock sync.Mutex

	// synCh acts like a semaphore. It is sized to the AcceptBacklog which
	// is assumed to be symmetric between the client and server. This allows
	// the client to avoid exceeding the backlog and instead blocks the open.
	synCh chan struct{}

	// acceptCh carries newly opened incoming streams to AcceptStream; its
	// capacity is config.AcceptBacklog and a full channel means the
	// incoming stream is refused.
	acceptCh chan *Stream

	// sendCh queues frames (a header plus an optional body) for the send
	// goroutine, which writes them to conn in order.
	sendCh chan sendReady

	// recvDoneCh is closed when recv() exits to avoid a race
	// between stream registration and stream shutdown
	recvDoneCh chan struct{}

	// shutdown and shutdownErr are guarded by shutdownLock; shutdownCh is
	// closed exactly once, by Close, to wake everything waiting on the
	// session.
	shutdown     bool
	shutdownErr  error
	shutdownCh   chan struct{}
	shutdownLock sync.Mutex
}
78
// sendReady is one unit of work for the send goroutine: an optional header
// and an optional body, written to the connection in that order.
type sendReady struct {
	// Hdr, if non-nil, is written first.
	Hdr []byte
	// Body, if non-nil, is written after Hdr.
	Body []byte
	// Err receives the outcome of the write (nil on success) via
	// asyncSendErr. It may be nil — sendNoWait does not set it — presumably
	// asyncSendErr tolerates that; confirm against its definition.
	Err chan error
}
86
87// newSession is used to construct a new session
88func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
89	logger := config.Logger
90	if logger == nil {
91		logger = log.New(config.LogOutput, "", log.LstdFlags)
92	}
93
94	s := &Session{
95		config:     config,
96		logger:     logger,
97		conn:       conn,
98		bufRead:    bufio.NewReader(conn),
99		pings:      make(map[uint32]chan struct{}),
100		streams:    make(map[uint32]*Stream),
101		inflight:   make(map[uint32]struct{}),
102		synCh:      make(chan struct{}, config.AcceptBacklog),
103		acceptCh:   make(chan *Stream, config.AcceptBacklog),
104		sendCh:     make(chan sendReady, 64),
105		recvDoneCh: make(chan struct{}),
106		shutdownCh: make(chan struct{}),
107	}
108	if client {
109		s.nextStreamID = 1
110	} else {
111		s.nextStreamID = 2
112	}
113	go s.recv()
114	go s.send()
115	if config.EnableKeepAlive {
116		go s.keepalive()
117	}
118	return s
119}
120
121// IsClosed does a safe check to see if we have shutdown
122func (s *Session) IsClosed() bool {
123	select {
124	case <-s.shutdownCh:
125		return true
126	default:
127		return false
128	}
129}
130
131// CloseChan returns a read-only channel which is closed as
132// soon as the session is closed.
133func (s *Session) CloseChan() <-chan struct{} {
134	return s.shutdownCh
135}
136
137// NumStreams returns the number of currently open streams
138func (s *Session) NumStreams() int {
139	s.streamLock.Lock()
140	num := len(s.streams)
141	s.streamLock.Unlock()
142	return num
143}
144
145// Open is used to create a new stream as a net.Conn
146func (s *Session) Open() (net.Conn, error) {
147	conn, err := s.OpenStream()
148	if err != nil {
149		return nil, err
150	}
151	return conn, nil
152}
153
154// OpenStream is used to create a new stream
155func (s *Session) OpenStream() (*Stream, error) {
156	if s.IsClosed() {
157		return nil, ErrSessionShutdown
158	}
159	if atomic.LoadInt32(&s.remoteGoAway) == 1 {
160		return nil, ErrRemoteGoAway
161	}
162
163	// Block if we have too many inflight SYNs
164	select {
165	case s.synCh <- struct{}{}:
166	case <-s.shutdownCh:
167		return nil, ErrSessionShutdown
168	}
169
170GET_ID:
171	// Get an ID, and check for stream exhaustion
172	id := atomic.LoadUint32(&s.nextStreamID)
173	if id >= math.MaxUint32-1 {
174		return nil, ErrStreamsExhausted
175	}
176	if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
177		goto GET_ID
178	}
179
180	// Register the stream
181	stream := newStream(s, id, streamInit)
182	s.streamLock.Lock()
183	s.streams[id] = stream
184	s.inflight[id] = struct{}{}
185	s.streamLock.Unlock()
186
187	if s.config.StreamOpenTimeout > 0 {
188		go s.setOpenTimeout(stream)
189	}
190
191	// Send the window update to create
192	if err := stream.sendWindowUpdate(); err != nil {
193		select {
194		case <-s.synCh:
195		default:
196			s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
197		}
198		return nil, err
199	}
200	return stream, nil
201}
202
203// setOpenTimeout implements a timeout for streams that are opened but not established.
204// If the StreamOpenTimeout is exceeded we assume the peer is unable to ACK,
205// and close the session.
206// The number of running timers is bounded by the capacity of the synCh.
207func (s *Session) setOpenTimeout(stream *Stream) {
208	timer := time.NewTimer(s.config.StreamOpenTimeout)
209	defer timer.Stop()
210
211	select {
212	case <-stream.establishCh:
213		return
214	case <-s.shutdownCh:
215		return
216	case <-timer.C:
217		// Timeout reached while waiting for ACK.
218		// Close the session to force connection re-establishment.
219		s.logger.Printf("[ERR] yamux: aborted stream open (destination=%s): %v", s.RemoteAddr().String(), ErrTimeout.err)
220		s.Close()
221	}
222}
223
224// Accept is used to block until the next available stream
225// is ready to be accepted.
226func (s *Session) Accept() (net.Conn, error) {
227	conn, err := s.AcceptStream()
228	if err != nil {
229		return nil, err
230	}
231	return conn, err
232}
233
234// AcceptStream is used to block until the next available stream
235// is ready to be accepted.
236func (s *Session) AcceptStream() (*Stream, error) {
237	select {
238	case stream := <-s.acceptCh:
239		if err := stream.sendWindowUpdate(); err != nil {
240			return nil, err
241		}
242		return stream, nil
243	case <-s.shutdownCh:
244		return nil, s.shutdownErr
245	}
246}
247
// Close is used to close the session and all streams. It does NOT send a
// GoAway to the peer (use GoAway for that). Close is idempotent: calls after
// the first return nil without doing anything. It always returns nil.
//
// Shutdown order:
//  1. Mark the session shut down and record ErrSessionShutdown as the exit
//     error, unless exitErr already recorded a more specific one.
//  2. Close shutdownCh so every waiter wakes up.
//  3. Close the connection.
//  4. Wait for the recv goroutine to exit.
//  5. Force-close every registered stream.
func (s *Session) Close() error {
	s.shutdownLock.Lock()
	defer s.shutdownLock.Unlock()

	if s.shutdown {
		return nil
	}
	s.shutdown = true
	if s.shutdownErr == nil {
		s.shutdownErr = ErrSessionShutdown
	}
	close(s.shutdownCh)
	// Closing conn is expected to make the recv goroutine's blocked read
	// fail so it exits; the Close error is ignored.
	s.conn.Close()
	// Wait for recv to exit so it cannot register a new stream after the
	// loop below has force-closed the existing ones.
	<-s.recvDoneCh

	s.streamLock.Lock()
	defer s.streamLock.Unlock()
	for _, stream := range s.streams {
		stream.forceClose()
	}
	return nil
}
272
273// exitErr is used to handle an error that is causing the
274// session to terminate.
275func (s *Session) exitErr(err error) {
276	s.shutdownLock.Lock()
277	if s.shutdownErr == nil {
278		s.shutdownErr = err
279	}
280	s.shutdownLock.Unlock()
281	s.Close()
282}
283
284// GoAway can be used to prevent accepting further
285// connections. It does not close the underlying conn.
286func (s *Session) GoAway() error {
287	return s.waitForSend(s.goAway(goAwayNormal), nil)
288}
289
290// goAway is used to send a goAway message
291func (s *Session) goAway(reason uint32) header {
292	atomic.SwapInt32(&s.localGoAway, 1)
293	hdr := header(make([]byte, headerSize))
294	hdr.encode(typeGoAway, 0, 0, reason)
295	return hdr
296}
297
298// Ping is used to measure the RTT response time
299func (s *Session) Ping() (time.Duration, error) {
300	// Get a channel for the ping
301	ch := make(chan struct{})
302
303	// Get a new ping id, mark as pending
304	s.pingLock.Lock()
305	id := s.pingID
306	s.pingID++
307	s.pings[id] = ch
308	s.pingLock.Unlock()
309
310	// Send the ping request
311	hdr := header(make([]byte, headerSize))
312	hdr.encode(typePing, flagSYN, 0, id)
313	if err := s.waitForSend(hdr, nil); err != nil {
314		return 0, err
315	}
316
317	// Wait for a response
318	start := time.Now()
319	select {
320	case <-ch:
321	case <-time.After(s.config.ConnectionWriteTimeout):
322		s.pingLock.Lock()
323		delete(s.pings, id) // Ignore it if a response comes later.
324		s.pingLock.Unlock()
325		return 0, ErrTimeout
326	case <-s.shutdownCh:
327		return 0, ErrSessionShutdown
328	}
329
330	// Compute the RTT
331	return time.Now().Sub(start), nil
332}
333
334// keepalive is a long running goroutine that periodically does
335// a ping to keep the connection alive.
336func (s *Session) keepalive() {
337	for {
338		select {
339		case <-time.After(s.config.KeepAliveInterval):
340			_, err := s.Ping()
341			if err != nil {
342				if err != ErrSessionShutdown {
343					s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
344					s.exitErr(ErrKeepAliveTimeout)
345				}
346				return
347			}
348		case <-s.shutdownCh:
349			return
350		}
351	}
352}
353
354// waitForSendErr waits to send a header, checking for a potential shutdown
355func (s *Session) waitForSend(hdr header, body []byte) error {
356	errCh := make(chan error, 1)
357	return s.waitForSendErr(hdr, body, errCh)
358}
359
360// waitForSendErr waits to send a header with optional data, checking for a
361// potential shutdown. Since there's the expectation that sends can happen
362// in a timely manner, we enforce the connection write timeout here.
363func (s *Session) waitForSendErr(hdr header, body []byte, errCh chan error) error {
364	t := timerPool.Get()
365	timer := t.(*time.Timer)
366	timer.Reset(s.config.ConnectionWriteTimeout)
367	defer func() {
368		timer.Stop()
369		select {
370		case <-timer.C:
371		default:
372		}
373		timerPool.Put(t)
374	}()
375
376	ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
377	select {
378	case s.sendCh <- ready:
379	case <-s.shutdownCh:
380		return ErrSessionShutdown
381	case <-timer.C:
382		return ErrConnectionWriteTimeout
383	}
384
385	select {
386	case err := <-errCh:
387		return err
388	case <-s.shutdownCh:
389		return ErrSessionShutdown
390	case <-timer.C:
391		return ErrConnectionWriteTimeout
392	}
393}
394
395// sendNoWait does a send without waiting. Since there's the expectation that
396// the send happens right here, we enforce the connection write timeout if we
397// can't queue the header to be sent.
398func (s *Session) sendNoWait(hdr header) error {
399	t := timerPool.Get()
400	timer := t.(*time.Timer)
401	timer.Reset(s.config.ConnectionWriteTimeout)
402	defer func() {
403		timer.Stop()
404		select {
405		case <-timer.C:
406		default:
407		}
408		timerPool.Put(t)
409	}()
410
411	select {
412	case s.sendCh <- sendReady{Hdr: hdr}:
413		return nil
414	case <-s.shutdownCh:
415		return ErrSessionShutdown
416	case <-timer.C:
417		return ErrConnectionWriteTimeout
418	}
419}
420
421// send is a long running goroutine that sends data
422func (s *Session) send() {
423	for {
424		select {
425		case ready := <-s.sendCh:
426			// Send a header if ready
427			if ready.Hdr != nil {
428				sent := 0
429				for sent < len(ready.Hdr) {
430					n, err := s.conn.Write(ready.Hdr[sent:])
431					if err != nil {
432						s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
433						asyncSendErr(ready.Err, err)
434						s.exitErr(err)
435						return
436					}
437					sent += n
438				}
439			}
440
441			// Send data from a body if given
442			if ready.Body != nil {
443				_, err := s.conn.Write(ready.Body)
444				if err != nil {
445					s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
446					asyncSendErr(ready.Err, err)
447					s.exitErr(err)
448					return
449				}
450			}
451
452			// No error, successful send
453			asyncSendErr(ready.Err, nil)
454		case <-s.shutdownCh:
455			return
456		}
457	}
458}
459
460// recv is a long running goroutine that accepts new data
461func (s *Session) recv() {
462	if err := s.recvLoop(); err != nil {
463		s.exitErr(err)
464	}
465}
466
467// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type
468var (
469	handlers = []func(*Session, header) error{
470		typeData:         (*Session).handleStreamMessage,
471		typeWindowUpdate: (*Session).handleStreamMessage,
472		typePing:         (*Session).handlePing,
473		typeGoAway:       (*Session).handleGoAway,
474	}
475)
476
477// recvLoop continues to receive data until a fatal error is encountered
478func (s *Session) recvLoop() error {
479	defer close(s.recvDoneCh)
480	hdr := header(make([]byte, headerSize))
481	for {
482		// Read the header
483		if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
484			if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
485				s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
486			}
487			return err
488		}
489
490		// Verify the version
491		if hdr.Version() != protoVersion {
492			s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
493			return ErrInvalidVersion
494		}
495
496		mt := hdr.MsgType()
497		if mt < typeData || mt > typeGoAway {
498			return ErrInvalidMsgType
499		}
500
501		if err := handlers[mt](s, hdr); err != nil {
502			return err
503		}
504	}
505}
506
507// handleStreamMessage handles either a data or window update frame
508func (s *Session) handleStreamMessage(hdr header) error {
509	// Check for a new stream creation
510	id := hdr.StreamID()
511	flags := hdr.Flags()
512	if flags&flagSYN == flagSYN {
513		if err := s.incomingStream(id); err != nil {
514			return err
515		}
516	}
517
518	// Get the stream
519	s.streamLock.Lock()
520	stream := s.streams[id]
521	s.streamLock.Unlock()
522
523	// If we do not have a stream, likely we sent a RST
524	if stream == nil {
525		// Drain any data on the wire
526		if hdr.MsgType() == typeData && hdr.Length() > 0 {
527			s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
528			if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
529				s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
530				return nil
531			}
532		} else {
533			s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
534		}
535		return nil
536	}
537
538	// Check if this is a window update
539	if hdr.MsgType() == typeWindowUpdate {
540		if err := stream.incrSendWindow(hdr, flags); err != nil {
541			if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
542				s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
543			}
544			return err
545		}
546		return nil
547	}
548
549	// Read the new data
550	if err := stream.readData(hdr, flags, s.bufRead); err != nil {
551		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
552			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
553		}
554		return err
555	}
556	return nil
557}
558
559// handlePing is invokde for a typePing frame
560func (s *Session) handlePing(hdr header) error {
561	flags := hdr.Flags()
562	pingID := hdr.Length()
563
564	// Check if this is a query, respond back in a separate context so we
565	// don't interfere with the receiving thread blocking for the write.
566	if flags&flagSYN == flagSYN {
567		go func() {
568			hdr := header(make([]byte, headerSize))
569			hdr.encode(typePing, flagACK, 0, pingID)
570			if err := s.sendNoWait(hdr); err != nil {
571				s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
572			}
573		}()
574		return nil
575	}
576
577	// Handle a response
578	s.pingLock.Lock()
579	ch := s.pings[pingID]
580	if ch != nil {
581		delete(s.pings, pingID)
582		close(ch)
583	}
584	s.pingLock.Unlock()
585	return nil
586}
587
588// handleGoAway is invokde for a typeGoAway frame
589func (s *Session) handleGoAway(hdr header) error {
590	code := hdr.Length()
591	switch code {
592	case goAwayNormal:
593		atomic.SwapInt32(&s.remoteGoAway, 1)
594	case goAwayProtoErr:
595		s.logger.Printf("[ERR] yamux: received protocol error go away")
596		return fmt.Errorf("yamux protocol error")
597	case goAwayInternalErr:
598		s.logger.Printf("[ERR] yamux: received internal error go away")
599		return fmt.Errorf("remote yamux internal error")
600	default:
601		s.logger.Printf("[ERR] yamux: received unexpected go away")
602		return fmt.Errorf("unexpected go away received")
603	}
604	return nil
605}
606
607// incomingStream is used to create a new incoming stream
608func (s *Session) incomingStream(id uint32) error {
609	// Reject immediately if we are doing a go away
610	if atomic.LoadInt32(&s.localGoAway) == 1 {
611		hdr := header(make([]byte, headerSize))
612		hdr.encode(typeWindowUpdate, flagRST, id, 0)
613		return s.sendNoWait(hdr)
614	}
615
616	// Allocate a new stream
617	stream := newStream(s, id, streamSYNReceived)
618
619	s.streamLock.Lock()
620	defer s.streamLock.Unlock()
621
622	// Check if stream already exists
623	if _, ok := s.streams[id]; ok {
624		s.logger.Printf("[ERR] yamux: duplicate stream declared")
625		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
626			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
627		}
628		return ErrDuplicateStream
629	}
630
631	// Register the stream
632	s.streams[id] = stream
633
634	// Check if we've exceeded the backlog
635	select {
636	case s.acceptCh <- stream:
637		return nil
638	default:
639		// Backlog exceeded! RST the stream
640		s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
641		delete(s.streams, id)
642		stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
643		return s.sendNoWait(stream.sendHdr)
644	}
645}
646
647// closeStream is used to close a stream once both sides have
648// issued a close. If there was an in-flight SYN and the stream
649// was not yet established, then this will give the credit back.
650func (s *Session) closeStream(id uint32) {
651	s.streamLock.Lock()
652	if _, ok := s.inflight[id]; ok {
653		select {
654		case <-s.synCh:
655		default:
656			s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
657		}
658	}
659	delete(s.streams, id)
660	s.streamLock.Unlock()
661}
662
663// establishStream is used to mark a stream that was in the
664// SYN Sent state as established.
665func (s *Session) establishStream(id uint32) {
666	s.streamLock.Lock()
667	if _, ok := s.inflight[id]; ok {
668		delete(s.inflight, id)
669	} else {
670		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
671	}
672	select {
673	case <-s.synCh:
674	default:
675		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
676	}
677	s.streamLock.Unlock()
678}
679