1package remotedialer 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "net" 9 "os" 10 "strconv" 11 "strings" 12 "sync" 13 "sync/atomic" 14 "time" 15 16 "github.com/gorilla/websocket" 17 "github.com/sirupsen/logrus" 18) 19 20type Session struct { 21 sync.Mutex 22 23 nextConnID int64 24 clientKey string 25 sessionKey int64 26 conn *wsConn 27 conns map[int64]*connection 28 remoteClientKeys map[string]map[int]bool 29 auth ConnectAuthorizer 30 pingCancel context.CancelFunc 31 pingWait sync.WaitGroup 32 dialer Dialer 33 client bool 34} 35 36// PrintTunnelData No tunnel logging by default 37var PrintTunnelData bool 38 39func init() { 40 if os.Getenv("CATTLE_TUNNEL_DATA_DEBUG") == "true" { 41 PrintTunnelData = true 42 } 43} 44 45func NewClientSession(auth ConnectAuthorizer, conn *websocket.Conn) *Session { 46 return &Session{ 47 clientKey: "client", 48 conn: newWSConn(conn), 49 conns: map[int64]*connection{}, 50 auth: auth, 51 client: true, 52 } 53} 54 55func newSession(sessionKey int64, clientKey string, conn *websocket.Conn) *Session { 56 return &Session{ 57 nextConnID: 1, 58 clientKey: clientKey, 59 sessionKey: sessionKey, 60 conn: newWSConn(conn), 61 conns: map[int64]*connection{}, 62 remoteClientKeys: map[string]map[int]bool{}, 63 } 64} 65 66func (s *Session) startPings(rootCtx context.Context) { 67 ctx, cancel := context.WithCancel(rootCtx) 68 s.pingCancel = cancel 69 s.pingWait.Add(1) 70 71 go func() { 72 defer s.pingWait.Done() 73 74 t := time.NewTicker(PingWriteInterval) 75 defer t.Stop() 76 77 for { 78 select { 79 case <-ctx.Done(): 80 return 81 case <-t.C: 82 s.conn.Lock() 83 if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(PingWaitDuration)); err != nil { 84 logrus.WithError(err).Error("Error writing ping") 85 } 86 logrus.Debug("Wrote ping") 87 s.conn.Unlock() 88 } 89 } 90 }() 91} 92 93func (s *Session) stopPings() { 94 if s.pingCancel == nil { 95 return 96 } 97 98 s.pingCancel() 99 s.pingWait.Wait() 100} 101 102func (s *Session) Serve(ctx context.Context) (int, error) { 103 if s.client { 104 s.startPings(ctx) 105 } 106 107 for { 108 msType, reader, err := s.conn.NextReader() 109 if err != nil { 110 return 400, err 111 } 112 113 if msType != websocket.BinaryMessage { 114 return 400, errWrongMessageType 115 } 116 117 if err := s.serveMessage(ctx, reader); err != nil { 118 return 500, err 119 } 120 } 121} 122 123func (s *Session) serveMessage(ctx context.Context, reader io.Reader) error { 124 message, err := newServerMessage(reader) 125 if err != nil { 126 return err 127 } 128 129 if PrintTunnelData { 130 logrus.Debug("REQUEST ", message) 131 } 132 133 if message.messageType == Connect { 134 if s.auth == nil || !s.auth(message.proto, message.address) { 135 return errors.New("connect not allowed") 136 } 137 s.clientConnect(ctx, message) 138 return nil 139 } 140 141 s.Lock() 142 if message.messageType == AddClient && s.remoteClientKeys != nil { 143 err := s.addRemoteClient(message.address) 144 s.Unlock() 145 return err 146 } else if message.messageType == RemoveClient { 147 err := s.removeRemoteClient(message.address) 148 s.Unlock() 149 return err 150 } 151 conn := s.conns[message.connID] 152 s.Unlock() 153 154 if conn == nil { 155 if message.messageType == Data { 156 err := fmt.Errorf("connection not found %s/%d/%d", s.clientKey, s.sessionKey, message.connID) 157 newErrorMessage(message.connID, err).WriteTo(defaultDeadline(), s.conn) 158 } 159 return nil 160 } 161 162 switch message.messageType { 163 case Data: 164 if err := conn.OnData(message); err != nil { 165 s.closeConnection(message.connID, err) 166 } 167 case Error: 168 s.closeConnection(message.connID, message.Err()) 169 } 170 171 return nil 172} 173 174func defaultDeadline() time.Time { 175 return time.Now().Add(time.Minute) 176} 177 178func parseAddress(address string) (string, int, error) { 179 parts := strings.SplitN(address, "/", 2) 180 if len(parts) != 2 { 181 return "", 0, errors.New("not / separated") 182 } 183 v, err := strconv.Atoi(parts[1]) 184 return parts[0], v, err 185} 186 187func (s *Session) addRemoteClient(address string) error { 188 clientKey, sessionKey, err := parseAddress(address) 189 if err != nil { 190 return fmt.Errorf("invalid remote Session %s: %v", address, err) 191 } 192 193 keys := s.remoteClientKeys[clientKey] 194 if keys == nil { 195 keys = map[int]bool{} 196 s.remoteClientKeys[clientKey] = keys 197 } 198 keys[sessionKey] = true 199 200 if PrintTunnelData { 201 logrus.Debugf("ADD REMOTE CLIENT %s, SESSION %d", address, s.sessionKey) 202 } 203 204 return nil 205} 206 207func (s *Session) removeRemoteClient(address string) error { 208 clientKey, sessionKey, err := parseAddress(address) 209 if err != nil { 210 return fmt.Errorf("invalid remote Session %s: %v", address, err) 211 } 212 213 keys := s.remoteClientKeys[clientKey] 214 delete(keys, int(sessionKey)) 215 if len(keys) == 0 { 216 delete(s.remoteClientKeys, clientKey) 217 } 218 219 if PrintTunnelData { 220 logrus.Debugf("REMOVE REMOTE CLIENT %s, SESSION %d", address, s.sessionKey) 221 } 222 223 return nil 224} 225 226func (s *Session) closeConnection(connID int64, err error) { 227 s.Lock() 228 conn := s.conns[connID] 229 delete(s.conns, connID) 230 if PrintTunnelData { 231 logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) 232 } 233 s.Unlock() 234 235 if conn != nil { 236 conn.tunnelClose(err) 237 } 238} 239 240func (s *Session) clientConnect(ctx context.Context, message *message) { 241 conn := newConnection(message.connID, s, message.proto, message.address) 242 243 s.Lock() 244 s.conns[message.connID] = conn 245 if PrintTunnelData { 246 logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) 247 } 248 s.Unlock() 249 250 go clientDial(ctx, s.dialer, conn, message) 251} 252 253type connResult struct { 254 conn net.Conn 255 err error 256} 257 258func (s *Session) Dial(ctx context.Context, proto, address string) (net.Conn, error) { 259 return s.serverConnectContext(ctx, proto, address) 260} 261 262func (s *Session) serverConnectContext(ctx context.Context, proto, address string) (net.Conn, error) { 263 deadline, ok := ctx.Deadline() 264 if ok { 265 return s.serverConnect(deadline, proto, address) 266 } 267 268 result := make(chan connResult, 1) 269 go func() { 270 c, err := s.serverConnect(defaultDeadline(), proto, address) 271 result <- connResult{conn: c, err: err} 272 }() 273 274 select { 275 case <-ctx.Done(): 276 // We don't want to orphan an open connection so we wait for the result and immediately close it 277 go func() { 278 r := <-result 279 if r.err == nil { 280 r.conn.Close() 281 } 282 }() 283 return nil, ctx.Err() 284 case r := <-result: 285 return r.conn, r.err 286 } 287} 288 289func (s *Session) serverConnect(deadline time.Time, proto, address string) (net.Conn, error) { 290 connID := atomic.AddInt64(&s.nextConnID, 1) 291 conn := newConnection(connID, s, proto, address) 292 293 s.Lock() 294 s.conns[connID] = conn 295 if PrintTunnelData { 296 logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) 297 } 298 s.Unlock() 299 300 _, err := s.writeMessage(deadline, newConnect(connID, proto, address)) 301 if err != nil { 302 s.closeConnection(connID, err) 303 return nil, err 304 } 305 306 return conn, err 307} 308 309func (s *Session) writeMessage(deadline time.Time, message *message) (int, error) { 310 if PrintTunnelData { 311 logrus.Debug("WRITE ", message) 312 } 313 return message.WriteTo(deadline, s.conn) 314} 315 316func (s *Session) Close() { 317 s.Lock() 318 defer s.Unlock() 319 320 s.stopPings() 321 322 for _, connection := range s.conns { 323 connection.tunnelClose(errors.New("tunnel disconnect")) 324 } 325 326 s.conns = map[int64]*connection{} 327} 328 329func (s *Session) sessionAdded(clientKey string, sessionKey int64) { 330 client := fmt.Sprintf("%s/%d", clientKey, sessionKey) 331 _, err := s.writeMessage(time.Time{}, newAddClient(client)) 332 if err != nil { 333 s.conn.conn.Close() 334 } 335} 336 337func (s *Session) sessionRemoved(clientKey string, sessionKey int64) { 338 client := fmt.Sprintf("%s/%d", clientKey, sessionKey) 339 _, err := s.writeMessage(time.Time{}, newRemoveClient(client)) 340 if err != nil { 341 s.conn.conn.Close() 342 } 343} 344