1package yamux
2
3import (
4	"bufio"
5	"fmt"
6	"io"
7	"io/ioutil"
8	"log"
9	"math"
10	"net"
11	"strings"
12	"sync"
13	"sync/atomic"
14	"time"
15)
16
// Session is used to wrap a reliable ordered connection and to
// multiplex it into multiple streams.
type Session struct {
	// remoteGoAway is set to 1 once the remote side has told us it does
	// not want further connections. It is accessed atomically and kept
	// at the top of the struct for alignment.
	remoteGoAway int32

	// localGoAway is set to 1 once we should stop accepting further
	// connections. It is accessed atomically and kept at the top of the
	// struct for alignment.
	localGoAway int32

	// nextStreamID is the ID the next locally opened stream will use.
	// It starts at 1 for a client and 2 for a server and advances by 2,
	// and is accessed atomically.
	nextStreamID uint32

	// config holds our configuration
	config *Config

	// logger is used for our logs
	logger *log.Logger

	// conn is the underlying connection
	conn io.ReadWriteCloser

	// bufRead is a buffered reader wrapping conn
	bufRead *bufio.Reader

	// pings tracks in-flight pings by ID; pingID is the next ping ID to
	// hand out. Both are protected by pingLock.
	pings    map[uint32]chan struct{}
	pingID   uint32
	pingLock sync.Mutex

	// streams maps a stream id to a stream, and inflight has an entry
	// for any outgoing stream that has not yet been established. Both are
	// protected by streamLock.
	streams    map[uint32]*Stream
	inflight   map[uint32]struct{}
	streamLock sync.Mutex

	// synCh acts like a semaphore. It is sized to the AcceptBacklog which
	// is assumed to be symmetric between the client and server. This allows
	// the client to avoid exceeding the backlog and instead blocks the open.
	synCh chan struct{}

	// acceptCh is used to pass newly received streams to AcceptStream
	acceptCh chan *Stream

	// sendCh is used to mark a stream as ready to send,
	// or to send a header out directly.
	sendCh chan sendReady

	// recvDoneCh is closed when recv() exits to avoid a race
	// between stream registration and stream shutdown
	recvDoneCh chan struct{}

	// shutdown is used to safely close a session. shutdownCh is closed
	// exactly once by Close; shutdown and shutdownErr are written under
	// shutdownLock.
	shutdown     bool
	shutdownErr  error
	shutdownCh   chan struct{}
	shutdownLock sync.Mutex
}
78
// sendReady is a unit of work for the send goroutine. It is used to either
// mark a stream as ready or to directly send a header.
//
// Hdr, when non-nil, is written to the connection first. Body, when non-nil,
// is then copied to the connection in full. Err, if set, is handed the
// outcome of the write via asyncSendErr (nil on success).
type sendReady struct {
	Hdr  []byte
	Body io.Reader
	Err  chan error
}
86
87// newSession is used to construct a new session
88func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
89	s := &Session{
90		config:     config,
91		logger:     log.New(config.LogOutput, "", log.LstdFlags),
92		conn:       conn,
93		bufRead:    bufio.NewReader(conn),
94		pings:      make(map[uint32]chan struct{}),
95		streams:    make(map[uint32]*Stream),
96		inflight:   make(map[uint32]struct{}),
97		synCh:      make(chan struct{}, config.AcceptBacklog),
98		acceptCh:   make(chan *Stream, config.AcceptBacklog),
99		sendCh:     make(chan sendReady, 64),
100		recvDoneCh: make(chan struct{}),
101		shutdownCh: make(chan struct{}),
102	}
103	if client {
104		s.nextStreamID = 1
105	} else {
106		s.nextStreamID = 2
107	}
108	go s.recv()
109	go s.send()
110	if config.EnableKeepAlive {
111		go s.keepalive()
112	}
113	return s
114}
115
116// IsClosed does a safe check to see if we have shutdown
117func (s *Session) IsClosed() bool {
118	select {
119	case <-s.shutdownCh:
120		return true
121	default:
122		return false
123	}
124}
125
126// NumStreams returns the number of currently open streams
127func (s *Session) NumStreams() int {
128	s.streamLock.Lock()
129	num := len(s.streams)
130	s.streamLock.Unlock()
131	return num
132}
133
134// Open is used to create a new stream as a net.Conn
135func (s *Session) Open() (net.Conn, error) {
136	conn, err := s.OpenStream()
137	if err != nil {
138		return nil, err
139	}
140	return conn, nil
141}
142
143// OpenStream is used to create a new stream
144func (s *Session) OpenStream() (*Stream, error) {
145	if s.IsClosed() {
146		return nil, ErrSessionShutdown
147	}
148	if atomic.LoadInt32(&s.remoteGoAway) == 1 {
149		return nil, ErrRemoteGoAway
150	}
151
152	// Block if we have too many inflight SYNs
153	select {
154	case s.synCh <- struct{}{}:
155	case <-s.shutdownCh:
156		return nil, ErrSessionShutdown
157	}
158
159GET_ID:
160	// Get an ID, and check for stream exhaustion
161	id := atomic.LoadUint32(&s.nextStreamID)
162	if id >= math.MaxUint32-1 {
163		return nil, ErrStreamsExhausted
164	}
165	if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
166		goto GET_ID
167	}
168
169	// Register the stream
170	stream := newStream(s, id, streamInit)
171	s.streamLock.Lock()
172	s.streams[id] = stream
173	s.inflight[id] = struct{}{}
174	s.streamLock.Unlock()
175
176	// Send the window update to create
177	if err := stream.sendWindowUpdate(); err != nil {
178		select {
179		case <-s.synCh:
180		default:
181			s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
182		}
183		return nil, err
184	}
185	return stream, nil
186}
187
188// Accept is used to block until the next available stream
189// is ready to be accepted.
190func (s *Session) Accept() (net.Conn, error) {
191	conn, err := s.AcceptStream()
192	if err != nil {
193		return nil, err
194	}
195	return conn, err
196}
197
198// AcceptStream is used to block until the next available stream
199// is ready to be accepted.
200func (s *Session) AcceptStream() (*Stream, error) {
201	select {
202	case stream := <-s.acceptCh:
203		if err := stream.sendWindowUpdate(); err != nil {
204			return nil, err
205		}
206		return stream, nil
207	case <-s.shutdownCh:
208		return nil, s.shutdownErr
209	}
210}
211
// Close is used to close the session and all streams. It marks the session
// as shut down, closes shutdownCh and the underlying connection, waits for
// the receive loop to exit, and then force-closes every stream that is still
// registered. Calls after the first one are no-ops. It always returns nil.
//
// NOTE(review): no GoAway frame is sent on this path; a caller that wants
// the peer notified presumably needs to call GoAway first — confirm.
func (s *Session) Close() error {
	s.shutdownLock.Lock()
	defer s.shutdownLock.Unlock()

	// Already closed: nothing left to do.
	if s.shutdown {
		return nil
	}
	s.shutdown = true
	// Keep an error recorded earlier by exitErr; otherwise report a
	// plain shutdown.
	if s.shutdownErr == nil {
		s.shutdownErr = ErrSessionShutdown
	}
	close(s.shutdownCh)
	// The error from closing the connection is discarded.
	s.conn.Close()
	// recvLoop closes recvDoneCh when it returns; wait for that so no new
	// stream can be registered while the existing ones are torn down.
	<-s.recvDoneCh

	s.streamLock.Lock()
	defer s.streamLock.Unlock()
	for _, stream := range s.streams {
		stream.forceClose()
	}
	return nil
}
236
237// exitErr is used to handle an error that is causing the
238// session to terminate.
239func (s *Session) exitErr(err error) {
240	s.shutdownLock.Lock()
241	if s.shutdownErr == nil {
242		s.shutdownErr = err
243	}
244	s.shutdownLock.Unlock()
245	s.Close()
246}
247
248// GoAway can be used to prevent accepting further
249// connections. It does not close the underlying conn.
250func (s *Session) GoAway() error {
251	return s.waitForSend(s.goAway(goAwayNormal), nil)
252}
253
254// goAway is used to send a goAway message
255func (s *Session) goAway(reason uint32) header {
256	atomic.SwapInt32(&s.localGoAway, 1)
257	hdr := header(make([]byte, headerSize))
258	hdr.encode(typeGoAway, 0, 0, reason)
259	return hdr
260}
261
262// Ping is used to measure the RTT response time
263func (s *Session) Ping() (time.Duration, error) {
264	// Get a channel for the ping
265	ch := make(chan struct{})
266
267	// Get a new ping id, mark as pending
268	s.pingLock.Lock()
269	id := s.pingID
270	s.pingID++
271	s.pings[id] = ch
272	s.pingLock.Unlock()
273
274	// Send the ping request
275	hdr := header(make([]byte, headerSize))
276	hdr.encode(typePing, flagSYN, 0, id)
277	if err := s.waitForSend(hdr, nil); err != nil {
278		return 0, err
279	}
280
281	// Wait for a response
282	start := time.Now()
283	select {
284	case <-ch:
285	case <-time.After(s.config.ConnectionWriteTimeout):
286		s.pingLock.Lock()
287		delete(s.pings, id) // Ignore it if a response comes later.
288		s.pingLock.Unlock()
289		return 0, ErrTimeout
290	case <-s.shutdownCh:
291		return 0, ErrSessionShutdown
292	}
293
294	// Compute the RTT
295	return time.Now().Sub(start), nil
296}
297
298// keepalive is a long running goroutine that periodically does
299// a ping to keep the connection alive.
300func (s *Session) keepalive() {
301	for {
302		select {
303		case <-time.After(s.config.KeepAliveInterval):
304			_, err := s.Ping()
305			if err != nil {
306				s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
307				s.exitErr(ErrKeepAliveTimeout)
308				return
309			}
310		case <-s.shutdownCh:
311			return
312		}
313	}
314}
315
316// waitForSendErr waits to send a header, checking for a potential shutdown
317func (s *Session) waitForSend(hdr header, body io.Reader) error {
318	errCh := make(chan error, 1)
319	return s.waitForSendErr(hdr, body, errCh)
320}
321
322// waitForSendErr waits to send a header with optional data, checking for a
323// potential shutdown. Since there's the expectation that sends can happen
324// in a timely manner, we enforce the connection write timeout here.
325func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
326	timer := time.NewTimer(s.config.ConnectionWriteTimeout)
327	defer timer.Stop()
328
329	ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
330	select {
331	case s.sendCh <- ready:
332	case <-s.shutdownCh:
333		return ErrSessionShutdown
334	case <-timer.C:
335		return ErrConnectionWriteTimeout
336	}
337
338	select {
339	case err := <-errCh:
340		return err
341	case <-s.shutdownCh:
342		return ErrSessionShutdown
343	case <-timer.C:
344		return ErrConnectionWriteTimeout
345	}
346}
347
348// sendNoWait does a send without waiting. Since there's the expectation that
349// the send happens right here, we enforce the connection write timeout if we
350// can't queue the header to be sent.
351func (s *Session) sendNoWait(hdr header) error {
352	timer := time.NewTimer(s.config.ConnectionWriteTimeout)
353	defer timer.Stop()
354
355	select {
356	case s.sendCh <- sendReady{Hdr: hdr}:
357		return nil
358	case <-s.shutdownCh:
359		return ErrSessionShutdown
360	case <-timer.C:
361		return ErrConnectionWriteTimeout
362	}
363}
364
365// send is a long running goroutine that sends data
366func (s *Session) send() {
367	for {
368		select {
369		case ready := <-s.sendCh:
370			// Send a header if ready
371			if ready.Hdr != nil {
372				sent := 0
373				for sent < len(ready.Hdr) {
374					n, err := s.conn.Write(ready.Hdr[sent:])
375					if err != nil {
376						s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
377						asyncSendErr(ready.Err, err)
378						s.exitErr(err)
379						return
380					}
381					sent += n
382				}
383			}
384
385			// Send data from a body if given
386			if ready.Body != nil {
387				_, err := io.Copy(s.conn, ready.Body)
388				if err != nil {
389					s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
390					asyncSendErr(ready.Err, err)
391					s.exitErr(err)
392					return
393				}
394			}
395
396			// No error, successful send
397			asyncSendErr(ready.Err, nil)
398		case <-s.shutdownCh:
399			return
400		}
401	}
402}
403
404// recv is a long running goroutine that accepts new data
405func (s *Session) recv() {
406	if err := s.recvLoop(); err != nil {
407		s.exitErr(err)
408	}
409}
410
411// recvLoop continues to receive data until a fatal error is encountered
412func (s *Session) recvLoop() error {
413	defer close(s.recvDoneCh)
414	hdr := header(make([]byte, headerSize))
415	var handler func(header) error
416	for {
417		// Read the header
418		if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
419			if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
420				s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
421			}
422			return err
423		}
424
425		// Verify the version
426		if hdr.Version() != protoVersion {
427			s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
428			return ErrInvalidVersion
429		}
430
431		// Switch on the type
432		switch hdr.MsgType() {
433		case typeData:
434			handler = s.handleStreamMessage
435		case typeWindowUpdate:
436			handler = s.handleStreamMessage
437		case typeGoAway:
438			handler = s.handleGoAway
439		case typePing:
440			handler = s.handlePing
441		default:
442			return ErrInvalidMsgType
443		}
444
445		// Invoke the handler
446		if err := handler(hdr); err != nil {
447			return err
448		}
449	}
450}
451
452// handleStreamMessage handles either a data or window update frame
453func (s *Session) handleStreamMessage(hdr header) error {
454	// Check for a new stream creation
455	id := hdr.StreamID()
456	flags := hdr.Flags()
457	if flags&flagSYN == flagSYN {
458		if err := s.incomingStream(id); err != nil {
459			return err
460		}
461	}
462
463	// Get the stream
464	s.streamLock.Lock()
465	stream := s.streams[id]
466	s.streamLock.Unlock()
467
468	// If we do not have a stream, likely we sent a RST
469	if stream == nil {
470		// Drain any data on the wire
471		if hdr.MsgType() == typeData && hdr.Length() > 0 {
472			s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
473			if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
474				s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
475				return nil
476			}
477		} else {
478			s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
479		}
480		return nil
481	}
482
483	// Check if this is a window update
484	if hdr.MsgType() == typeWindowUpdate {
485		if err := stream.incrSendWindow(hdr, flags); err != nil {
486			if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
487				s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
488			}
489			return err
490		}
491		return nil
492	}
493
494	// Read the new data
495	if err := stream.readData(hdr, flags, s.bufRead); err != nil {
496		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
497			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
498		}
499		return err
500	}
501	return nil
502}
503
504// handlePing is invokde for a typePing frame
505func (s *Session) handlePing(hdr header) error {
506	flags := hdr.Flags()
507	pingID := hdr.Length()
508
509	// Check if this is a query, respond back in a separate context so we
510	// don't interfere with the receiving thread blocking for the write.
511	if flags&flagSYN == flagSYN {
512		go func() {
513			hdr := header(make([]byte, headerSize))
514			hdr.encode(typePing, flagACK, 0, pingID)
515			if err := s.sendNoWait(hdr); err != nil {
516				s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
517			}
518		}()
519		return nil
520	}
521
522	// Handle a response
523	s.pingLock.Lock()
524	ch := s.pings[pingID]
525	if ch != nil {
526		delete(s.pings, pingID)
527		close(ch)
528	}
529	s.pingLock.Unlock()
530	return nil
531}
532
533// handleGoAway is invokde for a typeGoAway frame
534func (s *Session) handleGoAway(hdr header) error {
535	code := hdr.Length()
536	switch code {
537	case goAwayNormal:
538		atomic.SwapInt32(&s.remoteGoAway, 1)
539	case goAwayProtoErr:
540		s.logger.Printf("[ERR] yamux: received protocol error go away")
541		return fmt.Errorf("yamux protocol error")
542	case goAwayInternalErr:
543		s.logger.Printf("[ERR] yamux: received internal error go away")
544		return fmt.Errorf("remote yamux internal error")
545	default:
546		s.logger.Printf("[ERR] yamux: received unexpected go away")
547		return fmt.Errorf("unexpected go away received")
548	}
549	return nil
550}
551
552// incomingStream is used to create a new incoming stream
553func (s *Session) incomingStream(id uint32) error {
554	// Reject immediately if we are doing a go away
555	if atomic.LoadInt32(&s.localGoAway) == 1 {
556		hdr := header(make([]byte, headerSize))
557		hdr.encode(typeWindowUpdate, flagRST, id, 0)
558		return s.sendNoWait(hdr)
559	}
560
561	// Allocate a new stream
562	stream := newStream(s, id, streamSYNReceived)
563
564	s.streamLock.Lock()
565	defer s.streamLock.Unlock()
566
567	// Check if stream already exists
568	if _, ok := s.streams[id]; ok {
569		s.logger.Printf("[ERR] yamux: duplicate stream declared")
570		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
571			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
572		}
573		return ErrDuplicateStream
574	}
575
576	// Register the stream
577	s.streams[id] = stream
578
579	// Check if we've exceeded the backlog
580	select {
581	case s.acceptCh <- stream:
582		return nil
583	default:
584		// Backlog exceeded! RST the stream
585		s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
586		delete(s.streams, id)
587		stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
588		return s.sendNoWait(stream.sendHdr)
589	}
590}
591
592// closeStream is used to close a stream once both sides have
593// issued a close. If there was an in-flight SYN and the stream
594// was not yet established, then this will give the credit back.
595func (s *Session) closeStream(id uint32) {
596	s.streamLock.Lock()
597	if _, ok := s.inflight[id]; ok {
598		select {
599		case <-s.synCh:
600		default:
601			s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
602		}
603	}
604	delete(s.streams, id)
605	s.streamLock.Unlock()
606}
607
608// establishStream is used to mark a stream that was in the
609// SYN Sent state as established.
610func (s *Session) establishStream(id uint32) {
611	s.streamLock.Lock()
612	if _, ok := s.inflight[id]; ok {
613		delete(s.inflight, id)
614	} else {
615		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
616	}
617	select {
618	case <-s.synCh:
619	default:
620		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
621	}
622	s.streamLock.Unlock()
623}
624