1// Copyright 2016 The CMux Authors. All rights reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12// implied. See the License for the specific language governing 13// permissions and limitations under the License. 14 15package cmux 16 17import ( 18 "errors" 19 "fmt" 20 "io" 21 "net" 22 "sync" 23 "time" 24) 25 26// Matcher matches a connection based on its content. 27type Matcher func(io.Reader) bool 28 29// MatchWriter is a match that can also write response (say to do handshake). 30type MatchWriter func(io.Writer, io.Reader) bool 31 32// ErrorHandler handles an error and returns whether 33// the mux should continue serving the listener. 34type ErrorHandler func(error) bool 35 36var _ net.Error = ErrNotMatched{} 37 38// ErrNotMatched is returned whenever a connection is not matched by any of 39// the matchers registered in the multiplexer. 40type ErrNotMatched struct { 41 c net.Conn 42} 43 44func (e ErrNotMatched) Error() string { 45 return fmt.Sprintf("mux: connection %v not matched by an matcher", 46 e.c.RemoteAddr()) 47} 48 49// Temporary implements the net.Error interface. 50func (e ErrNotMatched) Temporary() bool { return true } 51 52// Timeout implements the net.Error interface. 53func (e ErrNotMatched) Timeout() bool { return false } 54 55type errListenerClosed string 56 57func (e errListenerClosed) Error() string { return string(e) } 58func (e errListenerClosed) Temporary() bool { return false } 59func (e errListenerClosed) Timeout() bool { return false } 60 61// ErrListenerClosed is returned from muxListener.Accept when the underlying 62// listener is closed. 63var ErrListenerClosed = errListenerClosed("mux: listener closed") 64 65// ErrServerClosed is returned from muxListener.Accept when mux server is closed. 66var ErrServerClosed = errors.New("mux: server closed") 67 68// for readability of readTimeout 69var noTimeout time.Duration 70 71// New instantiates a new connection multiplexer. 72func New(l net.Listener) CMux { 73 return &cMux{ 74 root: l, 75 bufLen: 1024, 76 errh: func(_ error) bool { return true }, 77 donec: make(chan struct{}), 78 readTimeout: noTimeout, 79 } 80} 81 82// CMux is a multiplexer for network connections. 83type CMux interface { 84 // Match returns a net.Listener that sees (i.e., accepts) only 85 // the connections matched by at least one of the matcher. 86 // 87 // The order used to call Match determines the priority of matchers. 88 Match(...Matcher) net.Listener 89 // MatchWithWriters returns a net.Listener that accepts only the 90 // connections that matched by at least of the matcher writers. 91 // 92 // Prefer Matchers over MatchWriters, since the latter can write on the 93 // connection before the actual handler. 94 // 95 // The order used to call Match determines the priority of matchers. 96 MatchWithWriters(...MatchWriter) net.Listener 97 // Serve starts multiplexing the listener. Serve blocks and perhaps 98 // should be invoked concurrently within a go routine. 99 Serve() error 100 // Closes cmux server and stops accepting any connections on listener 101 Close() 102 // HandleError registers an error handler that handles listener errors. 103 HandleError(ErrorHandler) 104 // sets a timeout for the read of matchers 105 SetReadTimeout(time.Duration) 106} 107 108type matchersListener struct { 109 ss []MatchWriter 110 l muxListener 111} 112 113type cMux struct { 114 root net.Listener 115 bufLen int 116 errh ErrorHandler 117 sls []matchersListener 118 readTimeout time.Duration 119 donec chan struct{} 120 mu sync.Mutex 121} 122 123func matchersToMatchWriters(matchers []Matcher) []MatchWriter { 124 mws := make([]MatchWriter, 0, len(matchers)) 125 for _, m := range matchers { 126 cm := m 127 mws = append(mws, func(w io.Writer, r io.Reader) bool { 128 return cm(r) 129 }) 130 } 131 return mws 132} 133 134func (m *cMux) Match(matchers ...Matcher) net.Listener { 135 mws := matchersToMatchWriters(matchers) 136 return m.MatchWithWriters(mws...) 137} 138 139func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener { 140 ml := muxListener{ 141 Listener: m.root, 142 connc: make(chan net.Conn, m.bufLen), 143 donec: make(chan struct{}), 144 } 145 m.sls = append(m.sls, matchersListener{ss: matchers, l: ml}) 146 return ml 147} 148 149func (m *cMux) SetReadTimeout(t time.Duration) { 150 m.readTimeout = t 151} 152 153func (m *cMux) Serve() error { 154 var wg sync.WaitGroup 155 156 defer func() { 157 m.closeDoneChans() 158 wg.Wait() 159 160 for _, sl := range m.sls { 161 close(sl.l.connc) 162 // Drain the connections enqueued for the listener. 163 for c := range sl.l.connc { 164 _ = c.Close() 165 } 166 } 167 }() 168 169 for { 170 c, err := m.root.Accept() 171 if err != nil { 172 if !m.handleErr(err) { 173 return err 174 } 175 continue 176 } 177 178 wg.Add(1) 179 go m.serve(c, m.donec, &wg) 180 } 181} 182 183func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { 184 defer wg.Done() 185 186 muc := newMuxConn(c) 187 if m.readTimeout > noTimeout { 188 _ = c.SetReadDeadline(time.Now().Add(m.readTimeout)) 189 } 190 for _, sl := range m.sls { 191 for _, s := range sl.ss { 192 matched := s(muc.Conn, muc.startSniffing()) 193 if matched { 194 muc.doneSniffing() 195 if m.readTimeout > noTimeout { 196 _ = c.SetReadDeadline(time.Time{}) 197 } 198 select { 199 case sl.l.connc <- muc: 200 case <-donec: 201 _ = c.Close() 202 } 203 return 204 } 205 } 206 } 207 208 _ = c.Close() 209 err := ErrNotMatched{c: c} 210 if !m.handleErr(err) { 211 _ = m.root.Close() 212 } 213} 214 215func (m *cMux) Close() { 216 m.closeDoneChans() 217} 218 219func (m *cMux) closeDoneChans() { 220 m.mu.Lock() 221 defer m.mu.Unlock() 222 223 select { 224 case <-m.donec: 225 // Already closed. Don't close again 226 default: 227 close(m.donec) 228 } 229 for _, sl := range m.sls { 230 select { 231 case <-sl.l.donec: 232 // Already closed. Don't close again 233 default: 234 close(sl.l.donec) 235 } 236 } 237} 238 239func (m *cMux) HandleError(h ErrorHandler) { 240 m.errh = h 241} 242 243func (m *cMux) handleErr(err error) bool { 244 if !m.errh(err) { 245 return false 246 } 247 248 if ne, ok := err.(net.Error); ok { 249 return ne.Temporary() 250 } 251 252 return false 253} 254 255type muxListener struct { 256 net.Listener 257 connc chan net.Conn 258 donec chan struct{} 259} 260 261func (l muxListener) Accept() (net.Conn, error) { 262 select { 263 case c, ok := <-l.connc: 264 if !ok { 265 return nil, ErrListenerClosed 266 } 267 return c, nil 268 case <-l.donec: 269 return nil, ErrServerClosed 270 } 271} 272 273// MuxConn wraps a net.Conn and provides transparent sniffing of connection data. 274type MuxConn struct { 275 net.Conn 276 buf bufferedReader 277} 278 279func newMuxConn(c net.Conn) *MuxConn { 280 return &MuxConn{ 281 Conn: c, 282 buf: bufferedReader{source: c}, 283 } 284} 285 286// From the io.Reader documentation: 287// 288// When Read encounters an error or end-of-file condition after 289// successfully reading n > 0 bytes, it returns the number of 290// bytes read. It may return the (non-nil) error from the same call 291// or return the error (and n == 0) from a subsequent call. 292// An instance of this general case is that a Reader returning 293// a non-zero number of bytes at the end of the input stream may 294// return either err == EOF or err == nil. The next Read should 295// return 0, EOF. 296func (m *MuxConn) Read(p []byte) (int, error) { 297 return m.buf.Read(p) 298} 299 300func (m *MuxConn) startSniffing() io.Reader { 301 m.buf.reset(true) 302 return &m.buf 303} 304 305func (m *MuxConn) doneSniffing() { 306 m.buf.reset(false) 307} 308