1package remotedialer
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"io"
8	"net"
9	"os"
10	"strconv"
11	"strings"
12	"sync"
13	"sync/atomic"
14	"time"
15
16	"github.com/gorilla/websocket"
17	"github.com/sirupsen/logrus"
18)
19
// Session is one end of a remotedialer tunnel carried over a single websocket
// connection. Many logical connections, keyed by connection ID, are
// multiplexed over that websocket.
//
// The embedded mutex guards conns and remoteClientKeys.
type Session struct {
	sync.Mutex

	// nextConnID is advanced with atomic.AddInt64 to allocate IDs for
	// connections opened from this side. NOTE(review): it sits right after the
	// embedded mutex; field order may matter for 64-bit atomic alignment on
	// 32-bit platforms, so do not move it further down without checking.
	nextConnID int64
	clientKey  string
	sessionKey int64
	// conn is the wrapped websocket all messages are read from and written to.
	conn *wsConn
	// conns holds the live logical connections by connection ID.
	conns map[int64]*connection
	// remoteClientKeys records client/session pairs announced by the peer via
	// AddClient/RemoveClient messages: clientKey -> set of session keys.
	remoteClientKeys map[string]map[int]bool
	// auth decides whether an incoming Connect request may be served.
	auth ConnectAuthorizer
	// pingCancel and pingWait stop and await the ping goroutine.
	pingCancel context.CancelFunc
	pingWait   sync.WaitGroup
	// dialer is handed to clientDial for incoming Connect requests.
	dialer Dialer
	// client is true for sessions created by NewClientSession; such sessions
	// send pings while serving.
	client bool
}
35
// PrintTunnelData turns on verbose debug logging of tunnel messages and
// connection bookkeeping. It is false by default. Package initialisation
// enables it when the CATTLE_TUNNEL_DATA_DEBUG environment variable is
// exactly "true".
var PrintTunnelData bool

func init() {
	value, present := os.LookupEnv("CATTLE_TUNNEL_DATA_DEBUG")
	if present && value == "true" {
		PrintTunnelData = true
	}
}
44
45func NewClientSession(auth ConnectAuthorizer, conn *websocket.Conn) *Session {
46	return &Session{
47		clientKey: "client",
48		conn:      newWSConn(conn),
49		conns:     map[int64]*connection{},
50		auth:      auth,
51		client:    true,
52	}
53}
54
55func newSession(sessionKey int64, clientKey string, conn *websocket.Conn) *Session {
56	return &Session{
57		nextConnID:       1,
58		clientKey:        clientKey,
59		sessionKey:       sessionKey,
60		conn:             newWSConn(conn),
61		conns:            map[int64]*connection{},
62		remoteClientKeys: map[string]map[int]bool{},
63	}
64}
65
66func (s *Session) startPings(rootCtx context.Context) {
67	ctx, cancel := context.WithCancel(rootCtx)
68	s.pingCancel = cancel
69	s.pingWait.Add(1)
70
71	go func() {
72		defer s.pingWait.Done()
73
74		t := time.NewTicker(PingWriteInterval)
75		defer t.Stop()
76
77		for {
78			select {
79			case <-ctx.Done():
80				return
81			case <-t.C:
82				s.conn.Lock()
83				if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(PingWaitDuration)); err != nil {
84					logrus.WithError(err).Error("Error writing ping")
85				}
86				logrus.Debug("Wrote ping")
87				s.conn.Unlock()
88			}
89		}
90	}()
91}
92
93func (s *Session) stopPings() {
94	if s.pingCancel == nil {
95		return
96	}
97
98	s.pingCancel()
99	s.pingWait.Wait()
100}
101
102func (s *Session) Serve(ctx context.Context) (int, error) {
103	if s.client {
104		s.startPings(ctx)
105	}
106
107	for {
108		msType, reader, err := s.conn.NextReader()
109		if err != nil {
110			return 400, err
111		}
112
113		if msType != websocket.BinaryMessage {
114			return 400, errWrongMessageType
115		}
116
117		if err := s.serveMessage(ctx, reader); err != nil {
118			return 500, err
119		}
120	}
121}
122
123func (s *Session) serveMessage(ctx context.Context, reader io.Reader) error {
124	message, err := newServerMessage(reader)
125	if err != nil {
126		return err
127	}
128
129	if PrintTunnelData {
130		logrus.Debug("REQUEST ", message)
131	}
132
133	if message.messageType == Connect {
134		if s.auth == nil || !s.auth(message.proto, message.address) {
135			return errors.New("connect not allowed")
136		}
137		s.clientConnect(ctx, message)
138		return nil
139	}
140
141	s.Lock()
142	if message.messageType == AddClient && s.remoteClientKeys != nil {
143		err := s.addRemoteClient(message.address)
144		s.Unlock()
145		return err
146	} else if message.messageType == RemoveClient {
147		err := s.removeRemoteClient(message.address)
148		s.Unlock()
149		return err
150	}
151	conn := s.conns[message.connID]
152	s.Unlock()
153
154	if conn == nil {
155		if message.messageType == Data {
156			err := fmt.Errorf("connection not found %s/%d/%d", s.clientKey, s.sessionKey, message.connID)
157			newErrorMessage(message.connID, err).WriteTo(defaultDeadline(), s.conn)
158		}
159		return nil
160	}
161
162	switch message.messageType {
163	case Data:
164		if err := conn.OnData(message); err != nil {
165			s.closeConnection(message.connID, err)
166		}
167	case Error:
168		s.closeConnection(message.connID, message.Err())
169	}
170
171	return nil
172}
173
// defaultDeadline returns the deadline used when the caller has none: one
// minute from now.
func defaultDeadline() time.Time {
	const defaultTimeout = 60 * time.Second
	return time.Now().Add(defaultTimeout)
}
177
// parseAddress splits an address of the form "clientKey/sessionKey" at the
// first '/'. The part after the slash must parse as a decimal int.
//
// Without a '/' it returns ("", 0, error). If the suffix is not a number it
// returns the client key, 0, and the strconv error.
func parseAddress(address string) (string, int, error) {
	slash := strings.Index(address, "/")
	if slash < 0 {
		return "", 0, errors.New("not / separated")
	}

	clientKey := address[:slash]
	sessionKey, err := strconv.Atoi(address[slash+1:])
	return clientKey, sessionKey, err
}
186
187func (s *Session) addRemoteClient(address string) error {
188	clientKey, sessionKey, err := parseAddress(address)
189	if err != nil {
190		return fmt.Errorf("invalid remote Session %s: %v", address, err)
191	}
192
193	keys := s.remoteClientKeys[clientKey]
194	if keys == nil {
195		keys = map[int]bool{}
196		s.remoteClientKeys[clientKey] = keys
197	}
198	keys[sessionKey] = true
199
200	if PrintTunnelData {
201		logrus.Debugf("ADD REMOTE CLIENT %s, SESSION %d", address, s.sessionKey)
202	}
203
204	return nil
205}
206
207func (s *Session) removeRemoteClient(address string) error {
208	clientKey, sessionKey, err := parseAddress(address)
209	if err != nil {
210		return fmt.Errorf("invalid remote Session %s: %v", address, err)
211	}
212
213	keys := s.remoteClientKeys[clientKey]
214	delete(keys, int(sessionKey))
215	if len(keys) == 0 {
216		delete(s.remoteClientKeys, clientKey)
217	}
218
219	if PrintTunnelData {
220		logrus.Debugf("REMOVE REMOTE CLIENT %s, SESSION %d", address, s.sessionKey)
221	}
222
223	return nil
224}
225
226func (s *Session) closeConnection(connID int64, err error) {
227	s.Lock()
228	conn := s.conns[connID]
229	delete(s.conns, connID)
230	if PrintTunnelData {
231		logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
232	}
233	s.Unlock()
234
235	if conn != nil {
236		conn.tunnelClose(err)
237	}
238}
239
240func (s *Session) clientConnect(ctx context.Context, message *message) {
241	conn := newConnection(message.connID, s, message.proto, message.address)
242
243	s.Lock()
244	s.conns[message.connID] = conn
245	if PrintTunnelData {
246		logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
247	}
248	s.Unlock()
249
250	go clientDial(ctx, s.dialer, conn, message)
251}
252
// connResult carries the outcome of a serverConnect call that ran on another
// goroutine back to serverConnectContext.
type connResult struct {
	conn net.Conn // the new logical connection; nil when err is set
	err  error    // error from serverConnect, if any
}
257
258func (s *Session) Dial(ctx context.Context, proto, address string) (net.Conn, error) {
259	return s.serverConnectContext(ctx, proto, address)
260}
261
262func (s *Session) serverConnectContext(ctx context.Context, proto, address string) (net.Conn, error) {
263	deadline, ok := ctx.Deadline()
264	if ok {
265		return s.serverConnect(deadline, proto, address)
266	}
267
268	result := make(chan connResult, 1)
269	go func() {
270		c, err := s.serverConnect(defaultDeadline(), proto, address)
271		result <- connResult{conn: c, err: err}
272	}()
273
274	select {
275	case <-ctx.Done():
276		// We don't want to orphan an open connection so we wait for the result and immediately close it
277		go func() {
278			r := <-result
279			if r.err == nil {
280				r.conn.Close()
281			}
282		}()
283		return nil, ctx.Err()
284	case r := <-result:
285		return r.conn, r.err
286	}
287}
288
289func (s *Session) serverConnect(deadline time.Time, proto, address string) (net.Conn, error) {
290	connID := atomic.AddInt64(&s.nextConnID, 1)
291	conn := newConnection(connID, s, proto, address)
292
293	s.Lock()
294	s.conns[connID] = conn
295	if PrintTunnelData {
296		logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
297	}
298	s.Unlock()
299
300	_, err := s.writeMessage(deadline, newConnect(connID, proto, address))
301	if err != nil {
302		s.closeConnection(connID, err)
303		return nil, err
304	}
305
306	return conn, err
307}
308
309func (s *Session) writeMessage(deadline time.Time, message *message) (int, error) {
310	if PrintTunnelData {
311		logrus.Debug("WRITE ", message)
312	}
313	return message.WriteTo(deadline, s.conn)
314}
315
316func (s *Session) Close() {
317	s.Lock()
318	defer s.Unlock()
319
320	s.stopPings()
321
322	for _, connection := range s.conns {
323		connection.tunnelClose(errors.New("tunnel disconnect"))
324	}
325
326	s.conns = map[int64]*connection{}
327}
328
329func (s *Session) sessionAdded(clientKey string, sessionKey int64) {
330	client := fmt.Sprintf("%s/%d", clientKey, sessionKey)
331	_, err := s.writeMessage(time.Time{}, newAddClient(client))
332	if err != nil {
333		s.conn.conn.Close()
334	}
335}
336
337func (s *Session) sessionRemoved(clientKey string, sessionKey int64) {
338	client := fmt.Sprintf("%s/%d", clientKey, sessionKey)
339	_, err := s.writeMessage(time.Time{}, newRemoveClient(client))
340	if err != nil {
341		s.conn.conn.Close()
342	}
343}
344