1package tcp_test
2
3import (
4	"bytes"
5	"fmt"
6	"io"
7	"io/ioutil"
8	"log"
9	"net"
10	"strings"
11	"sync"
12	"testing"
13	"testing/quick"
14	"time"
15
16	"github.com/influxdata/influxdb/tcp"
17)
18
19// Ensure the muxer can split a listener's connections across multiple listeners.
20func TestMux(t *testing.T) {
21	if err := quick.Check(func(n uint8, msg []byte) bool {
22		if testing.Verbose() {
23			if len(msg) == 0 {
24				log.Printf("n=%d, <no message>", n)
25			} else {
26				log.Printf("n=%d, hdr=%d, len=%d", n, msg[0], len(msg))
27			}
28		}
29
30		var wg sync.WaitGroup
31
32		// Open single listener on random port.
33		tcpListener, err := net.Listen("tcp", "127.0.0.1:0")
34		if err != nil {
35			t.Fatal(err)
36		}
37		defer tcpListener.Close()
38
39		// Setup muxer & listeners.
40		mux := tcp.NewMux()
41		mux.Timeout = 200 * time.Millisecond
42		if !testing.Verbose() {
43			mux.Logger = log.New(ioutil.Discard, "", 0)
44		}
45
46		errC := make(chan error)
47		for i := uint8(0); i < n; i++ {
48			ln := mux.Listen(byte(i))
49
50			wg.Add(1)
51			go func(i uint8, ln net.Listener) {
52				defer wg.Done()
53
54				// Wait for a connection for this listener.
55				conn, err := ln.Accept()
56				if conn != nil {
57					defer conn.Close()
58				}
59
60				// If there is no message or the header byte
61				// doesn't match then expect close.
62				if len(msg) == 0 || msg[0] != byte(i) {
63					if err == nil || err.Error() != "network connection closed" {
64						errC <- fmt.Errorf("unexpected error: %s", err)
65						return
66					}
67					return
68				}
69
70				// If the header byte matches this listener
71				// then expect a connection and read the message.
72				var buf bytes.Buffer
73				if _, err := io.CopyN(&buf, conn, int64(len(msg)-1)); err != nil {
74					errC <- err
75					return
76				} else if !bytes.Equal(msg[1:], buf.Bytes()) {
77					errC <- fmt.Errorf("message mismatch:\n\nexp=%x\n\ngot=%x\n\n", msg[1:], buf.Bytes())
78					return
79				}
80
81				// Write response.
82				if _, err := conn.Write([]byte("OK")); err != nil {
83					errC <- err
84					return
85				}
86			}(i, ln)
87		}
88
89		// Begin serving from the listener.
90		go mux.Serve(tcpListener)
91
92		// Write message to TCP listener and read OK response.
93		conn, err := net.Dial("tcp", tcpListener.Addr().String())
94		if err != nil {
95			t.Fatal(err)
96		} else if _, err = conn.Write(msg); err != nil {
97			t.Fatal(err)
98		}
99
100		// Read the response into the buffer.
101		var resp [2]byte
102		_, err = io.ReadFull(conn, resp[:])
103
104		// If the message header is less than n then expect a response.
105		// Otherwise we should get an EOF because the mux closed.
106		if len(msg) > 0 && uint8(msg[0]) < n {
107			if string(resp[:]) != `OK` {
108				t.Fatalf("unexpected response: %s", resp[:])
109			}
110		} else {
111			if err == nil || (err != io.EOF && !(strings.Contains(err.Error(), "connection reset by peer") ||
112				strings.Contains(err.Error(), "closed by the remote host"))) {
113				t.Fatalf("unexpected error: %s", err)
114			}
115		}
116
117		// Close connection.
118		if err := conn.Close(); err != nil {
119			t.Fatal(err)
120		}
121
122		// Close original TCP listener and wait for all goroutines to close.
123		tcpListener.Close()
124
125		go func() {
126			wg.Wait()
127			close(errC)
128		}()
129
130		ok := true
131		for err := range errC {
132			if err != nil {
133				ok = false
134				t.Error(err)
135			}
136		}
137
138		return ok
139	}, nil); err != nil {
140		t.Error(err)
141	}
142}
143
144// Ensure two handlers cannot be registered for the same header byte.
145func TestMux_Listen_ErrAlreadyRegistered(t *testing.T) {
146	defer func() {
147		if r := recover(); r != `listener already registered under header byte: 5` {
148			t.Fatalf("unexpected recover: %#v", r)
149		}
150	}()
151
152	// Register two listeners with the same header byte.
153	mux := tcp.NewMux()
154	mux.Listen(5)
155	mux.Listen(5)
156}
157
158// Ensure that closing a listener from mux.Listen releases an Accept call and
159// deregisters the mux.
160func TestMux_Close(t *testing.T) {
161	listener, err := net.Listen("tcp", "127.0.0.1:0")
162	if err != nil {
163		t.Fatalf("unexpected error: %s", err)
164	}
165
166	done := make(chan struct{})
167	mux := tcp.NewMux()
168	go func() {
169		mux.Serve(listener)
170		close(done)
171	}()
172	l := mux.Listen(5)
173
174	closed := make(chan struct{})
175	go func() {
176		_, err := l.Accept()
177		if err == nil || !strings.Contains(err.Error(), "connection closed") {
178			t.Errorf("unexpected error: %s", err)
179		}
180		close(closed)
181	}()
182	l.Close()
183
184	timer := time.NewTimer(100 * time.Millisecond)
185	select {
186	case <-closed:
187		timer.Stop()
188	case <-timer.C:
189		t.Errorf("timeout while waiting for the mux to close")
190	}
191
192	// We should now be able to register a new listener at the same byte
193	// without causing a panic.
194	defer func() {
195		if r := recover(); r != nil {
196			t.Fatalf("unexpected recover: %#v", r)
197		}
198	}()
199	l = mux.Listen(5)
200
201	// Verify that closing the listener does not cause a panic.
202	listener.Close()
203	timer = time.NewTimer(100 * time.Millisecond)
204	select {
205	case <-done:
206		timer.Stop()
207		// This should not panic.
208		l.Close()
209	case <-timer.C:
210		t.Errorf("timeout while waiting for the mux to close")
211	}
212}
213