1package tcp
2
3import (
4	"bufio"
5	"bytes"
6	"crypto/tls"
7	"errors"
8	"io"
9	"net"
10	"net/http"
11	"strings"
12	"time"
13
14	"github.com/traefik/traefik/v2/pkg/log"
15	"github.com/traefik/traefik/v2/pkg/types"
16)
17
// defaultBufSize is the largest amount of data (TLS record header included)
// that clientHelloServerName peeks through the reader it is given; a bigger
// record makes it switch to a larger bufio.Reader. The value equals the
// default buffer size of bufio.NewReader.
const defaultBufSize = 4096
19
// Router is a TCP router.
//
// It dispatches incoming connections in ServeTCP using the following fields:
//   - routingTable: handlers keyed by lower-cased SNI host name (see AddRoute);
//     the "*" key is the fallback for TLS connections matching no other entry.
//   - httpForwarder: handler receiving non-TLS connections when no
//     catchAllNoTLS handler is set.
//   - httpsForwarder: handler receiving TLS connections that match no entry of
//     routingTable (built by HTTPSForwarder).
//   - httpHandler / httpsHandler: HTTP handlers merely stored here and exposed
//     through GetHTTPHandler / GetHTTPSHandler.
//   - httpsTLSConfig: TLS configuration used when no host-specific one exists.
//   - catchAllNoTLS: handler receiving non-TLS connections.
//   - hostHTTPTLSConfig: host-specific TLS configurations registered through
//     AddRouteHTTPTLS.
type Router struct {
	routingTable      map[string]Handler
	httpForwarder     Handler
	httpsForwarder    Handler
	httpHandler       http.Handler
	httpsHandler      http.Handler
	httpsTLSConfig    *tls.Config // default TLS config
	catchAllNoTLS     Handler
	hostHTTPTLSConfig map[string]*tls.Config // TLS configs keyed by SNI
}
31
32// GetTLSGetClientInfo is called after a ClientHello is received from a client.
33func (r *Router) GetTLSGetClientInfo() func(info *tls.ClientHelloInfo) (*tls.Config, error) {
34	return func(info *tls.ClientHelloInfo) (*tls.Config, error) {
35		if tlsConfig, ok := r.hostHTTPTLSConfig[info.ServerName]; ok {
36			return tlsConfig, nil
37		}
38		return r.httpsTLSConfig, nil
39	}
40}
41
42// ServeTCP forwards the connection to the right TCP/HTTP handler.
43func (r *Router) ServeTCP(conn WriteCloser) {
44	// FIXME -- Check if ProxyProtocol changes the first bytes of the request
45
46	if r.catchAllNoTLS != nil && len(r.routingTable) == 0 {
47		r.catchAllNoTLS.ServeTCP(conn)
48		return
49	}
50
51	br := bufio.NewReader(conn)
52	serverName, tls, peeked, err := clientHelloServerName(br)
53	if err != nil {
54		conn.Close()
55		return
56	}
57
58	// Remove read/write deadline and delegate this to underlying tcp server (for now only handled by HTTP Server)
59	err = conn.SetReadDeadline(time.Time{})
60	if err != nil {
61		log.WithoutContext().Errorf("Error while setting read deadline: %v", err)
62	}
63
64	err = conn.SetWriteDeadline(time.Time{})
65	if err != nil {
66		log.WithoutContext().Errorf("Error while setting write deadline: %v", err)
67	}
68
69	if !tls {
70		switch {
71		case r.catchAllNoTLS != nil:
72			r.catchAllNoTLS.ServeTCP(r.GetConn(conn, peeked))
73		case r.httpForwarder != nil:
74			r.httpForwarder.ServeTCP(r.GetConn(conn, peeked))
75		default:
76			conn.Close()
77		}
78		return
79	}
80
81	// FIXME Optimize and test the routing table before helloServerName
82	serverName = types.CanonicalDomain(serverName)
83	if r.routingTable != nil && serverName != "" {
84		if target, ok := r.routingTable[serverName]; ok {
85			target.ServeTCP(r.GetConn(conn, peeked))
86			return
87		}
88	}
89
90	// FIXME Needs tests
91	if target, ok := r.routingTable["*"]; ok {
92		target.ServeTCP(r.GetConn(conn, peeked))
93		return
94	}
95
96	if r.httpsForwarder != nil {
97		r.httpsForwarder.ServeTCP(r.GetConn(conn, peeked))
98	} else {
99		conn.Close()
100	}
101}
102
103// AddRoute defines a handler for a given sniHost (* is the only valid option).
104func (r *Router) AddRoute(sniHost string, target Handler) {
105	if r.routingTable == nil {
106		r.routingTable = map[string]Handler{}
107	}
108	r.routingTable[strings.ToLower(sniHost)] = target
109}
110
111// AddRouteTLS defines a handler for a given sniHost and sets the matching tlsConfig.
112func (r *Router) AddRouteTLS(sniHost string, target Handler, config *tls.Config) {
113	r.AddRoute(sniHost, &TLSHandler{
114		Next:   target,
115		Config: config,
116	})
117}
118
119// AddRouteHTTPTLS defines a handler for a given sniHost and sets the matching tlsConfig.
120func (r *Router) AddRouteHTTPTLS(sniHost string, config *tls.Config) {
121	if r.hostHTTPTLSConfig == nil {
122		r.hostHTTPTLSConfig = map[string]*tls.Config{}
123	}
124	r.hostHTTPTLSConfig[sniHost] = config
125}
126
// AddCatchAllNoTLS defines the fallback tcp handler, i.e. the handler that
// ServeTCP uses for connections that are not TLS. It replaces any handler
// set previously.
func (r *Router) AddCatchAllNoTLS(handler Handler) {
	r.catchAllNoTLS = handler
}
131
132// GetConn creates a connection proxy with a peeked string.
133func (r *Router) GetConn(conn WriteCloser, peeked string) WriteCloser {
134	// FIXME should it really be on Router ?
135	conn = &Conn{
136		Peeked:      []byte(peeked),
137		WriteCloser: conn,
138	}
139	return conn
140}
141
// GetHTTPHandler gets the attached http handler, i.e. the one stored by
// HTTPHandler (nil when none was attached).
func (r *Router) GetHTTPHandler() http.Handler {
	return r.httpHandler
}
146
// GetHTTPSHandler gets the attached https handler, i.e. the one stored by
// HTTPSHandler (nil when none was attached).
func (r *Router) GetHTTPSHandler() http.Handler {
	return r.httpsHandler
}
151
// HTTPForwarder sets the tcp handler that will forward the connections to an http handler.
// ServeTCP hands non-TLS connections to it when no catch-all handler is set.
func (r *Router) HTTPForwarder(handler Handler) {
	r.httpForwarder = handler
}
156
157// HTTPSForwarder sets the tcp handler that will forward the TLS connections to an http handler.
158func (r *Router) HTTPSForwarder(handler Handler) {
159	for sniHost, tlsConf := range r.hostHTTPTLSConfig {
160		r.AddRouteTLS(sniHost, handler, tlsConf)
161	}
162
163	r.httpsForwarder = &TLSHandler{
164		Next:   handler,
165		Config: r.httpsTLSConfig,
166	}
167}
168
// HTTPHandler attaches http handlers on the router.
// The handler is only stored; it is handed back by GetHTTPHandler.
func (r *Router) HTTPHandler(handler http.Handler) {
	r.httpHandler = handler
}
173
// HTTPSHandler attaches https handlers on the router.
// The handler is only stored (it is handed back by GetHTTPSHandler), and
// config becomes the default TLS configuration: the one used by the callback
// of GetTLSGetClientInfo for hosts without a specific configuration, and by
// HTTPSForwarder for the default HTTPS forwarder.
func (r *Router) HTTPSHandler(handler http.Handler, config *tls.Config) {
	r.httpsHandler = handler
	r.httpsTLSConfig = config
}
179
180// Conn is a connection proxy that handles Peeked bytes.
181type Conn struct {
182	// Peeked are the bytes that have been read from Conn for the
183	// purposes of route matching, but have not yet been consumed
184	// by Read calls. It set to nil by Read when fully consumed.
185	Peeked []byte
186
187	// Conn is the underlying connection.
188	// It can be type asserted against *net.TCPConn or other types
189	// as needed. It should not be read from directly unless
190	// Peeked is nil.
191	WriteCloser
192}
193
194// Read reads bytes from the connection (using the buffer prior to actually reading).
195func (c *Conn) Read(p []byte) (n int, err error) {
196	if len(c.Peeked) > 0 {
197		n = copy(p, c.Peeked)
198		c.Peeked = c.Peeked[n:]
199		if len(c.Peeked) == 0 {
200			c.Peeked = nil
201		}
202		return n, nil
203	}
204	return c.WriteCloser.Read(p)
205}
206
207// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
208// without consuming any bytes from br.
209// On any error, the empty string is returned.
210func clientHelloServerName(br *bufio.Reader) (string, bool, string, error) {
211	hdr, err := br.Peek(1)
212	if err != nil {
213		var opErr *net.OpError
214		if !errors.Is(err, io.EOF) && (!errors.As(err, &opErr) || opErr.Timeout()) {
215			log.WithoutContext().Debugf("Error while Peeking first byte: %s", err)
216		}
217
218		return "", false, "", err
219	}
220
221	// No valid TLS record has a type of 0x80, however SSLv2 handshakes
222	// start with a uint16 length where the MSB is set and the first record
223	// is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
224	// an SSLv2 client.
225	const recordTypeSSLv2 = 0x80
226	const recordTypeHandshake = 0x16
227	if hdr[0] != recordTypeHandshake {
228		if hdr[0] == recordTypeSSLv2 {
229			// we consider SSLv2 as TLS and it will be refuse by real TLS handshake.
230			return "", true, getPeeked(br), nil
231		}
232		return "", false, getPeeked(br), nil // Not TLS.
233	}
234
235	const recordHeaderLen = 5
236	hdr, err = br.Peek(recordHeaderLen)
237	if err != nil {
238		log.Errorf("Error while Peeking hello: %s", err)
239		return "", false, getPeeked(br), nil
240	}
241
242	recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3]
243
244	if recordHeaderLen+recLen > defaultBufSize {
245		br = bufio.NewReaderSize(br, recordHeaderLen+recLen)
246	}
247
248	helloBytes, err := br.Peek(recordHeaderLen + recLen)
249	if err != nil {
250		log.Errorf("Error while Hello: %s", err)
251		return "", true, getPeeked(br), nil
252	}
253
254	sni := ""
255	server := tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{
256		GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
257			sni = hello.ServerName
258			return nil, nil
259		},
260	})
261	_ = server.Handshake()
262
263	return sni, true, getPeeked(br), nil
264}
265
266func getPeeked(br *bufio.Reader) string {
267	peeked, err := br.Peek(br.Buffered())
268	if err != nil {
269		log.Errorf("Could not get anything: %s", err)
270		return ""
271	}
272	return string(peeked)
273}
274
// sniSniffConn is a net.Conn that only supports reading from r: writes always
// fail, and any other method panics because the embedded net.Conn is nil.
type sniSniffConn struct {
	r        io.Reader
	net.Conn // nil; crash on any unexpected use
}

// Read reads from the underlying reader.
func (c sniSniffConn) Read(buf []byte) (int, error) {
	return c.r.Read(buf)
}

// Write never writes anything and always fails with io.EOF.
func (sniSniffConn) Write([]byte) (int, error) {
	return 0, io.EOF
}
287