1// Package tcp provides a simple multiplexer over TCP.
2package tcp // import "github.com/influxdata/influxdb/tcp"
3
4import (
5	"errors"
6	"fmt"
7	"io"
8	"log"
9	"net"
10	"os"
11	"sync"
12	"time"
13)
14
// DefaultTimeout is the value NewMux assigns to Mux.Timeout: the default
// length of time to wait for a connection's first (header) byte.
const DefaultTimeout = 30 * time.Second
19
20// Mux multiplexes a network connection.
21type Mux struct {
22	mu sync.RWMutex
23	ln net.Listener
24	m  map[byte]*listener
25
26	defaultListener *listener
27
28	wg sync.WaitGroup
29
30	// The amount of time to wait for the first header byte.
31	Timeout time.Duration
32
33	// Out-of-band error logger
34	Logger *log.Logger
35}
36
// replayConn wraps a net.Conn whose first byte has already been consumed by
// the Mux and hands that byte back to the reader before any further data.
type replayConn struct {
	net.Conn
	firstByte     byte
	readFirstbyte bool
}

// Read delivers the saved first byte as a one-byte read on the first call
// with a non-empty buffer; every subsequent call reads straight from the
// wrapped connection. A zero-length read issued before the byte has been
// replayed returns (0, nil) and leaves the byte pending.
func (rc *replayConn) Read(b []byte) (int, error) {
	switch {
	case rc.readFirstbyte:
		return rc.Conn.Read(b)
	case len(b) == 0:
		return 0, nil
	}

	rc.readFirstbyte = true
	b[0] = rc.firstByte
	return 1, nil
}
56
57// NewMux returns a new instance of Mux.
58func NewMux() *Mux {
59	return &Mux{
60		m:       make(map[byte]*listener),
61		Timeout: DefaultTimeout,
62		Logger:  log.New(os.Stderr, "[tcp] ", log.LstdFlags),
63	}
64}
65
66// Serve handles connections from ln and multiplexes then across registered listeners.
67func (mux *Mux) Serve(ln net.Listener) error {
68	mux.mu.Lock()
69	mux.ln = ln
70	mux.mu.Unlock()
71	for {
72		// Wait for the next connection.
73		// If it returns a temporary error then simply retry.
74		// If it returns any other error then exit immediately.
75		conn, err := ln.Accept()
76		if err, ok := err.(interface {
77			Temporary() bool
78		}); ok && err.Temporary() {
79			continue
80		}
81		if err != nil {
82			// Wait for all connections to be demux
83			mux.wg.Wait()
84
85			// Concurrently close all registered listeners.
86			// Because mux.m is keyed by byte, in the worst case we would spawn 256 goroutines here.
87			var wg sync.WaitGroup
88			mux.mu.RLock()
89			for _, ln := range mux.m {
90				wg.Add(1)
91				go func(ln *listener) {
92					defer wg.Done()
93					ln.Close()
94				}(ln)
95			}
96			mux.mu.RUnlock()
97			wg.Wait()
98
99			mux.mu.RLock()
100			dl := mux.defaultListener
101			mux.mu.RUnlock()
102			if dl != nil {
103				dl.Close()
104			}
105
106			return err
107		}
108
109		// Demux in a goroutine to
110		mux.wg.Add(1)
111		go mux.handleConn(conn)
112	}
113}
114
115func (mux *Mux) handleConn(conn net.Conn) {
116	defer mux.wg.Done()
117	// Set a read deadline so connections with no data don't timeout.
118	if err := conn.SetReadDeadline(time.Now().Add(mux.Timeout)); err != nil {
119		conn.Close()
120		mux.Logger.Printf("tcp.Mux: cannot set read deadline: %s", err)
121		return
122	}
123
124	// Read first byte from connection to determine handler.
125	var typ [1]byte
126	if _, err := io.ReadFull(conn, typ[:]); err != nil {
127		conn.Close()
128		mux.Logger.Printf("tcp.Mux: cannot read header byte: %s", err)
129		return
130	}
131
132	// Reset read deadline and let the listener handle that.
133	if err := conn.SetReadDeadline(time.Time{}); err != nil {
134		conn.Close()
135		mux.Logger.Printf("tcp.Mux: cannot reset set read deadline: %s", err)
136		return
137	}
138
139	// Retrieve handler based on first byte.
140	mux.mu.RLock()
141	handler := mux.m[typ[0]]
142	mux.mu.RUnlock()
143
144	if handler == nil {
145		if mux.defaultListener == nil {
146			conn.Close()
147			mux.Logger.Printf("tcp.Mux: handler not registered: %d. Connection from %s closed", typ[0], conn.RemoteAddr())
148			return
149		}
150
151		conn = &replayConn{
152			Conn:      conn,
153			firstByte: typ[0],
154		}
155		handler = mux.defaultListener
156	}
157
158	handler.HandleConn(conn, typ[0])
159}
160
161// Listen returns a listener identified by header.
162// Any connection accepted by mux is multiplexed based on the initial header byte.
163func (mux *Mux) Listen(header byte) net.Listener {
164	mux.mu.Lock()
165	defer mux.mu.Unlock()
166
167	// Ensure two listeners are not created for the same header byte.
168	if _, ok := mux.m[header]; ok {
169		panic(fmt.Sprintf("listener already registered under header byte: %d", header))
170	}
171
172	// Create a new listener and assign it.
173	ln := &listener{
174		c:    make(chan net.Conn),
175		done: make(chan struct{}),
176		mux:  mux,
177	}
178	mux.m[header] = ln
179
180	return ln
181}
182
183// release removes the listener from the mux.
184func (mux *Mux) release(ln *listener) bool {
185	mux.mu.Lock()
186	defer mux.mu.Unlock()
187
188	for b, l := range mux.m {
189		if l == ln {
190			delete(mux.m, b)
191			return true
192		}
193	}
194	return false
195}
196
197// DefaultListener will return a net.Listener that will pass-through any
198// connections with non-registered values for the first byte of the connection.
199// The connections returned from this listener's Accept() method will replay the
200// first byte of the connection as a short first Read().
201//
202// This can be used to pass to an HTTP server, so long as there are no conflicts
203// with registered listener bytes and the first character of the HTTP request:
204// 71 ('G') for GET, etc.
205func (mux *Mux) DefaultListener() net.Listener {
206	mux.mu.Lock()
207	defer mux.mu.Unlock()
208	if mux.defaultListener == nil {
209		mux.defaultListener = &listener{
210			c:    make(chan net.Conn),
211			done: make(chan struct{}),
212			mux:  mux,
213		}
214	}
215
216	return mux.defaultListener
217}
218
219// listener is a receiver for connections received by Mux.
220type listener struct {
221	mux *Mux
222
223	// The done channel is closed before taking a lock on mu to close c.
224	// That way, anyone holding an RLock can release the lock by receiving from done.
225	done chan struct{}
226
227	mu sync.RWMutex
228	c  chan net.Conn
229}
230
231// Accept waits for and returns the next connection to the listener.
232func (ln *listener) Accept() (net.Conn, error) {
233	ln.mu.RLock()
234	defer ln.mu.RUnlock()
235
236	select {
237	case <-ln.done:
238		return nil, errors.New("network connection closed")
239	case conn := <-ln.c:
240		return conn, nil
241	}
242}
243
244// Close removes this listener from the parent mux and closes the channel.
245func (ln *listener) Close() error {
246	if ok := ln.mux.release(ln); ok {
247		// Close done to signal to any RLock holders to release their lock.
248		close(ln.done)
249
250		// Hold a lock while reassigning ln.c to nil
251		// so that attempted sends or receives will block forever.
252		ln.mu.Lock()
253		ln.c = nil
254		ln.mu.Unlock()
255	}
256	return nil
257}
258
259// HandleConn handles the connection, if the listener has not been closed.
260func (ln *listener) HandleConn(conn net.Conn, handlerID byte) {
261	ln.mu.RLock()
262	defer ln.mu.RUnlock()
263
264	// Send connection to handler.  The handler is responsible for closing the connection.
265	timer := time.NewTimer(ln.mux.Timeout)
266	defer timer.Stop()
267
268	select {
269	case <-ln.done:
270		// Receive will return immediately if ln.Close has been called.
271		conn.Close()
272	case ln.c <- conn:
273		// Send will block forever if ln.Close has been called.
274	case <-timer.C:
275		conn.Close()
276		ln.mux.Logger.Printf("tcp.Mux: handler not ready: %d. Connection from %s closed", handlerID, conn.RemoteAddr())
277		return
278	}
279}
280
281// Addr returns the Addr of the listener
282func (ln *listener) Addr() net.Addr {
283	if ln.mux == nil {
284		return nil
285	}
286
287	ln.mux.mu.RLock()
288	defer ln.mux.mu.RUnlock()
289
290	if ln.mux.ln == nil {
291		return nil
292	}
293
294	return ln.mux.ln.Addr()
295}
296
// Dial connects to a remote mux listener and immediately writes the header
// byte, which a remote Mux uses to route the connection to the listener
// registered for that byte.
//
// If the header cannot be written, the freshly opened connection is closed
// and the write error is returned wrapped (inspectable with errors.Is/As).
func Dial(network, address string, header byte) (net.Conn, error) {
	conn, err := net.Dial(network, address)
	if err != nil {
		return nil, err
	}

	if _, err := conn.Write([]byte{header}); err != nil {
		// Don't leak the socket: the caller never receives conn on failure.
		conn.Close()
		return nil, fmt.Errorf("write mux header: %w", err)
	}

	return conn, nil
}
310