1// Copyright 2016 The CMux Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12// implied. See the License for the specific language governing
13// permissions and limitations under the License.
14
15package cmux
16
17import (
18	"errors"
19	"fmt"
20	"io"
21	"net"
22	"sync"
23	"time"
24)
25
type (
	// Matcher decides whether a connection belongs to a listener, judging
	// only by the bytes it reads from the supplied reader.
	Matcher func(io.Reader) bool

	// MatchWriter is like Matcher but also receives a writer for the
	// connection, so it can respond while deciding (for example, to
	// complete a handshake).
	MatchWriter func(io.Writer, io.Reader) bool

	// ErrorHandler is called with an error and reports whether the mux
	// should continue serving the listener.
	ErrorHandler func(error) bool
)
35
var _ net.Error = ErrNotMatched{}

// ErrNotMatched is returned whenever a connection is not matched by any of
// the matchers registered in the multiplexer. The multiplexer passes it to
// the error handler. Because Temporary reports true, an unmatched connection
// only stops serving if the handler itself returns false.
type ErrNotMatched struct {
	c net.Conn
}

// Error implements the error interface. The message names the remote address
// of the unmatched connection. When no connection is attached (for example,
// the zero value ErrNotMatched{}), it reports "<nil>" instead of panicking
// on a nil net.Conn.
func (e ErrNotMatched) Error() string {
	var addr net.Addr
	if e.c != nil {
		addr = e.c.RemoteAddr()
	}
	return fmt.Sprintf("mux: connection %v not matched by any matcher", addr)
}

// Temporary implements the net.Error interface. It always reports true,
// because failing to match one connection says nothing about later ones.
func (e ErrNotMatched) Temporary() bool { return true }

// Timeout implements the net.Error interface. It always reports false.
func (e ErrNotMatched) Timeout() bool { return false }
54
// errListenerClosed is a string-backed error that also satisfies the
// net.Error method set, always reporting a permanent, non-timeout failure.
type errListenerClosed string

// Error returns the message the value was created from.
func (e errListenerClosed) Error() string {
	return string(e)
}

// Temporary always reports false: a closed listener does not recover.
func (e errListenerClosed) Temporary() bool {
	return false
}

// Timeout always reports false: closing is not a deadline expiry.
func (e errListenerClosed) Timeout() bool {
	return false
}

// ErrListenerClosed is returned from muxListener.Accept when the underlying
// listener is closed, i.e. once the listener's connection queue has been
// closed after Serve returned.
var ErrListenerClosed = errListenerClosed("mux: listener closed")
64
var (
	// ErrServerClosed is returned from muxListener.Accept when the mux
	// server is closed.
	ErrServerClosed = errors.New("mux: server closed")

	// noTimeout is the zero Duration. It names the "no read deadline" value
	// so comparisons against readTimeout read clearly.
	noTimeout time.Duration
)
70
71// New instantiates a new connection multiplexer.
72func New(l net.Listener) CMux {
73	return &cMux{
74		root:        l,
75		bufLen:      1024,
76		errh:        func(_ error) bool { return true },
77		donec:       make(chan struct{}),
78		readTimeout: noTimeout,
79	}
80}
81
82// CMux is a multiplexer for network connections.
83type CMux interface {
84	// Match returns a net.Listener that sees (i.e., accepts) only
85	// the connections matched by at least one of the matcher.
86	//
87	// The order used to call Match determines the priority of matchers.
88	Match(...Matcher) net.Listener
89	// MatchWithWriters returns a net.Listener that accepts only the
90	// connections that matched by at least of the matcher writers.
91	//
92	// Prefer Matchers over MatchWriters, since the latter can write on the
93	// connection before the actual handler.
94	//
95	// The order used to call Match determines the priority of matchers.
96	MatchWithWriters(...MatchWriter) net.Listener
97	// Serve starts multiplexing the listener. Serve blocks and perhaps
98	// should be invoked concurrently within a go routine.
99	Serve() error
100	// Closes cmux server and stops accepting any connections on listener
101	Close()
102	// HandleError registers an error handler that handles listener errors.
103	HandleError(ErrorHandler)
104	// sets a timeout for the read of matchers
105	SetReadTimeout(time.Duration)
106}
107
// matchersListener pairs the matchers registered by one Match or
// MatchWithWriters call with the listener that receives the connections
// they accept.
type matchersListener struct {
	ss []MatchWriter // matchers tried in order; the first returning true wins
	l  muxListener   // listener that matched connections are queued on
}
112
// cMux is the CMux implementation returned by New.
//
// In the visible code only closeDoneChans takes mu. The other fields are
// read and written without locking.
type cMux struct {
	root        net.Listener       // listener Serve accepts connections from
	bufLen      int                // capacity of each matched listener's connc queue
	errh        ErrorHandler       // consulted by handleErr for every error
	sls         []matchersListener // registered listeners, tried in registration order
	readTimeout time.Duration      // read deadline while matching; <= noTimeout disables it
	donec       chan struct{}      // closed by closeDoneChans to stop delivering matched conns
	mu          sync.Mutex         // serializes closeDoneChans
}
122
123func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
124	mws := make([]MatchWriter, 0, len(matchers))
125	for _, m := range matchers {
126		cm := m
127		mws = append(mws, func(w io.Writer, r io.Reader) bool {
128			return cm(r)
129		})
130	}
131	return mws
132}
133
134func (m *cMux) Match(matchers ...Matcher) net.Listener {
135	mws := matchersToMatchWriters(matchers)
136	return m.MatchWithWriters(mws...)
137}
138
139func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener {
140	ml := muxListener{
141		Listener: m.root,
142		connc:    make(chan net.Conn, m.bufLen),
143		donec:    make(chan struct{}),
144	}
145	m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
146	return ml
147}
148
149func (m *cMux) SetReadTimeout(t time.Duration) {
150	m.readTimeout = t
151}
152
153func (m *cMux) Serve() error {
154	var wg sync.WaitGroup
155
156	defer func() {
157		m.closeDoneChans()
158		wg.Wait()
159
160		for _, sl := range m.sls {
161			close(sl.l.connc)
162			// Drain the connections enqueued for the listener.
163			for c := range sl.l.connc {
164				_ = c.Close()
165			}
166		}
167	}()
168
169	for {
170		c, err := m.root.Accept()
171		if err != nil {
172			if !m.handleErr(err) {
173				return err
174			}
175			continue
176		}
177
178		wg.Add(1)
179		go m.serve(c, m.donec, &wg)
180	}
181}
182
183func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
184	defer wg.Done()
185
186	muc := newMuxConn(c)
187	if m.readTimeout > noTimeout {
188		_ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
189	}
190	for _, sl := range m.sls {
191		for _, s := range sl.ss {
192			matched := s(muc.Conn, muc.startSniffing())
193			if matched {
194				muc.doneSniffing()
195				if m.readTimeout > noTimeout {
196					_ = c.SetReadDeadline(time.Time{})
197				}
198				select {
199				case sl.l.connc <- muc:
200				case <-donec:
201					_ = c.Close()
202				}
203				return
204			}
205		}
206	}
207
208	_ = c.Close()
209	err := ErrNotMatched{c: c}
210	if !m.handleErr(err) {
211		_ = m.root.Close()
212	}
213}
214
215func (m *cMux) Close() {
216	m.closeDoneChans()
217}
218
219func (m *cMux) closeDoneChans() {
220	m.mu.Lock()
221	defer m.mu.Unlock()
222
223	select {
224	case <-m.donec:
225		// Already closed. Don't close again
226	default:
227		close(m.donec)
228	}
229	for _, sl := range m.sls {
230		select {
231		case <-sl.l.donec:
232			// Already closed. Don't close again
233		default:
234			close(sl.l.donec)
235		}
236	}
237}
238
239func (m *cMux) HandleError(h ErrorHandler) {
240	m.errh = h
241}
242
243func (m *cMux) handleErr(err error) bool {
244	if !m.errh(err) {
245		return false
246	}
247
248	if ne, ok := err.(net.Error); ok {
249		return ne.Temporary()
250	}
251
252	return false
253}
254
255type muxListener struct {
256	net.Listener
257	connc chan net.Conn
258	donec chan struct{}
259}
260
261func (l muxListener) Accept() (net.Conn, error) {
262	select {
263	case c, ok := <-l.connc:
264		if !ok {
265			return nil, ErrListenerClosed
266		}
267		return c, nil
268	case <-l.donec:
269		return nil, ErrServerClosed
270	}
271}
272
273// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
274type MuxConn struct {
275	net.Conn
276	buf bufferedReader
277}
278
279func newMuxConn(c net.Conn) *MuxConn {
280	return &MuxConn{
281		Conn: c,
282		buf:  bufferedReader{source: c},
283	}
284}
285
286// From the io.Reader documentation:
287//
288// When Read encounters an error or end-of-file condition after
289// successfully reading n > 0 bytes, it returns the number of
290// bytes read.  It may return the (non-nil) error from the same call
291// or return the error (and n == 0) from a subsequent call.
292// An instance of this general case is that a Reader returning
293// a non-zero number of bytes at the end of the input stream may
294// return either err == EOF or err == nil.  The next Read should
295// return 0, EOF.
296func (m *MuxConn) Read(p []byte) (int, error) {
297	return m.buf.Read(p)
298}
299
300func (m *MuxConn) startSniffing() io.Reader {
301	m.buf.reset(true)
302	return &m.buf
303}
304
305func (m *MuxConn) doneSniffing() {
306	m.buf.reset(false)
307}
308