1package utp 2 3import ( 4 "context" 5 "errors" 6 "io" 7 "log" 8 "math/rand" 9 "net" 10 "sync" 11 "time" 12) 13 14var ( 15 _ net.Listener = &Socket{} 16 _ net.PacketConn = &Socket{} 17) 18 19// Uniquely identifies any uTP connection on top of the underlying packet 20// stream. 21type connKey struct { 22 remoteAddr resolvedAddrStr 23 connID uint16 24} 25 26// A Socket wraps a net.PacketConn, diverting uTP packets to its child uTP 27// Conns. 28type Socket struct { 29 pc net.PacketConn 30 conns map[connKey]*Conn 31 32 backlogNotEmpty Event 33 backlog map[syn]net.Addr 34 35 closed Event 36 destroyed Event 37 38 wgReadWrite sync.WaitGroup 39 40 unusedReads chan read 41 connDeadlines 42 // If a read error occurs on the underlying net.PacketConn, it is put 43 // here. This is because reading is done in its own goroutine to dispatch 44 // to uTP Conns. 45 ReadErr error 46 47 onAttach func(remote net.Addr) 48 onDetach func(remote net.Addr) 49 onMutex sync.RWMutex 50} 51 52func (s *Socket) OnAttach(f func(remote net.Addr)) { 53 s.onMutex.Lock() 54 defer s.onMutex.Unlock() 55 s.onAttach = f 56} 57 58func (s *Socket) OnDetach(f func(remote net.Addr)) { 59 s.onMutex.Lock() 60 defer s.onMutex.Unlock() 61 s.onDetach = f 62} 63 64func listenPacket(network, addr string) (pc net.PacketConn, err error) { 65 return net.ListenPacket(network, addr) 66} 67 68// NewSocket creates a net.PacketConn with the given network and address, and 69// returns a Socket dispatching on it. 70func NewSocket(network, addr string) (s *Socket, err error) { 71 if network == "" { 72 network = "udp" 73 } 74 pc, err := listenPacket(network, addr) 75 if err != nil { 76 return 77 } 78 return NewSocketFromPacketConn(pc) 79} 80 81// Create a Socket, using the provided net.PacketConn. If you want to retain 82// use of the net.PacketConn after the Socket closes it, override the 83// net.PacketConn's Close method, or use NetSocketFromPacketConnNoClose. 84func NewSocketFromPacketConn(pc net.PacketConn) (s *Socket, err error) { 85 s = &Socket{ 86 backlog: make(map[syn]net.Addr, backlog), 87 pc: pc, 88 unusedReads: make(chan read, 100), 89 wgReadWrite: sync.WaitGroup{}, 90 } 91 mu.Lock() 92 sockets[s] = struct{}{} 93 mu.Unlock() 94 go s.reader() 95 return 96} 97 98// Create a Socket using the provided PacketConn, that doesn't close the 99// PacketConn when the Socket is closed. 100func NewSocketFromPacketConnNoClose(pc net.PacketConn) (s *Socket, err error) { 101 return NewSocketFromPacketConn(packetConnNopCloser{pc}) 102} 103 104func (s *Socket) unusedRead(read read) { 105 unusedReads.Add(1) 106 select { 107 case s.unusedReads <- read: 108 default: 109 // Drop the packet. 110 unusedReadsDropped.Add(1) 111 } 112} 113 114func (s *Socket) pushBacklog(syn syn, addr net.Addr) { 115 if _, ok := s.backlog[syn]; ok { 116 return 117 } 118 // Pop a pseudo-random syn to make room. TODO: Use missinggo/orderedmap, 119 // coz that's what is wanted here. 120 for k, v := range s.backlog { 121 if len(s.backlog) < backlog { 122 break 123 } 124 delete(s.backlog, k) 125 // A syn is sent on the remote's recv_id, so this is where we can send 126 // the reset. 127 s.reset(v, k.seq_nr, k.conn_id) 128 } 129 s.backlog[syn] = addr 130 s.backlogChanged() 131} 132 133func (s *Socket) reader() { 134 mu.Lock() 135 defer mu.Unlock() 136 defer s.destroy() 137 var b [maxRecvSize]byte 138 for { 139 s.wgReadWrite.Add(1) 140 mu.Unlock() 141 n, addr, err := s.pc.ReadFrom(b[:]) 142 s.wgReadWrite.Done() 143 mu.Lock() 144 if s.destroyed.IsSet() { 145 return 146 } 147 if err != nil { 148 log.Printf("error reading Socket PacketConn: %s", err) 149 s.ReadErr = err 150 return 151 } 152 s.handleReceivedPacket(read{ 153 append([]byte(nil), b[:n]...), 154 addr, 155 }) 156 } 157} 158 159func receivedUTPPacketSize(n int) { 160 if n > largestReceivedUTPPacket { 161 largestReceivedUTPPacket = n 162 largestReceivedUTPPacketExpvar.Set(int64(n)) 163 } 164} 165 166func (s *Socket) connForRead(h header, from net.Addr) (c *Conn, ok bool) { 167 c, ok = s.conns[connKey{ 168 resolvedAddrStr(from.String()), 169 func() uint16 { 170 if h.Type == stSyn { 171 // SYNs have a ConnID one lower than the eventual recvID, and we index 172 // the connections with that, so use it for the lookup. 173 return h.ConnID + 1 174 } else { 175 return h.ConnID 176 } 177 }(), 178 }] 179 return 180} 181 182func (s *Socket) handlePacketReceivedForEstablishedConn(h header, from net.Addr, data []byte, c *Conn) { 183 if h.Type == stSyn { 184 if h.ConnID == c.send_id-2 { 185 // This is a SYN for connection that cannot exist locally. The 186 // connection the remote wants to establish here with the proposed 187 // recv_id, already has an existing connection that was dialled 188 // *out* from this socket, which is why the send_id is 1 higher, 189 // rather than 1 lower than the recv_id. 190 log.Print("resetting conflicting syn") 191 s.reset(from, h.SeqNr, h.ConnID) 192 return 193 } else if h.ConnID != c.send_id { 194 panic("bad assumption") 195 } 196 } 197 c.receivePacket(h, data) 198} 199 200func (s *Socket) handleReceivedPacket(p read) { 201 if len(p.data) < 20 { 202 s.unusedRead(p) 203 return 204 } 205 var h header 206 hEnd, err := h.Unmarshal(p.data) 207 if err != nil || h.Type > stMax || h.Version != 1 { 208 s.unusedRead(p) 209 return 210 } 211 if c, ok := s.connForRead(h, p.from); ok { 212 receivedUTPPacketSize(len(p.data)) 213 s.handlePacketReceivedForEstablishedConn(h, p.from, p.data[hEnd:], c) 214 return 215 } 216 // Packet doesn't belong to an existing connection. 217 switch h.Type { 218 case stSyn: 219 s.pushBacklog(syn{ 220 seq_nr: h.SeqNr, 221 conn_id: h.ConnID, 222 addr: p.from.String(), 223 }, p.from) 224 return 225 case stReset: 226 // Could be a late arriving packet for a Conn we're already done with. 227 // If it was for an existing connection, we would have handled it 228 // earlier. 229 default: 230 unexpectedPacketsRead.Add(1) 231 // This is an unexpected packet. We'll send a reset, but also pass it 232 // on. I don't think you can reset on the received packets ConnID if 233 // it isn't a SYN, as the send_id will differ in this case. 234 s.reset(p.from, h.SeqNr, h.ConnID) 235 // Connection initiated by remote. 236 s.reset(p.from, h.SeqNr, h.ConnID-1) 237 // Connection initiated locally. 238 s.reset(p.from, h.SeqNr, h.ConnID+1) 239 } 240 s.unusedRead(p) 241} 242 243// Send a reset in response to a packet with the given header. 244func (s *Socket) reset(addr net.Addr, ackNr, connId uint16) { 245 b := make([]byte, 0, maxHeaderSize) 246 h := header{ 247 Type: stReset, 248 Version: 1, 249 ConnID: connId, 250 AckNr: ackNr, 251 } 252 b = b[:h.Marshal(b)] 253 go s.writeTo(b, addr) 254} 255 256// Return a recv_id that should be free. Handling the case where it isn't is 257// deferred to a more appropriate function. 258func (s *Socket) newConnID(remoteAddr resolvedAddrStr) (id uint16) { 259 // Rather than use math.Rand, which requires generating all the IDs up 260 // front and allocating a slice, we do it on the stack, generating the IDs 261 // only as required. To do this, we use the fact that the array is 262 // default-initialized. IDs that are 0, are actually their index in the 263 // array. IDs that are non-zero, are +1 from their intended ID. 264 var idsBack [0x10000]int 265 ids := idsBack[:] 266 for len(ids) != 0 { 267 // Pick the next ID from the untried ids. 268 i := rand.Intn(len(ids)) 269 id = uint16(ids[i]) 270 // If it's zero, then treat it as though the index i was the ID. 271 // Otherwise the value we get is the ID+1. 272 if id == 0 { 273 id = uint16(i) 274 } else { 275 id-- 276 } 277 // Check there's no connection using this ID for its recv_id... 278 _, ok1 := s.conns[connKey{remoteAddr, id}] 279 // and if we're connecting to our own Socket, that there isn't a Conn 280 // already receiving on what will correspond to our send_id. Note that 281 // we just assume that we could be connecting to our own Socket. This 282 // will halve the available connection IDs to each distinct remote 283 // address. Presumably that's ~0x8000, down from ~0x10000. 284 _, ok2 := s.conns[connKey{remoteAddr, id + 1}] 285 _, ok4 := s.conns[connKey{remoteAddr, id - 1}] 286 if !ok1 && !ok2 && !ok4 { 287 return 288 } 289 // The set of possible IDs is shrinking. The highest one will be lost, so 290 // it's moved to the location of the one we just tried. 291 ids[i] = len(ids) // Conveniently already +1. 292 // And shrink. 293 ids = ids[:len(ids)-1] 294 } 295 return 296} 297 298var ( 299 zeroipv4 = net.ParseIP("0.0.0.0") 300 zeroipv6 = net.ParseIP("::") 301 302 ipv4lo = mustResolveUDP("127.0.0.1") 303 ipv6lo = mustResolveUDP("::1") 304) 305 306func mustResolveUDP(addr string) net.IP { 307 u, err := net.ResolveIPAddr("ip", addr) 308 if err != nil { 309 panic(err) 310 } 311 return u.IP 312} 313 314func realRemoteAddr(addr net.Addr) net.Addr { 315 udpAddr, ok := addr.(*net.UDPAddr) 316 if ok { 317 if udpAddr.IP.Equal(zeroipv4) { 318 udpAddr.IP = ipv4lo 319 } 320 if udpAddr.IP.Equal(zeroipv6) { 321 udpAddr.IP = ipv6lo 322 } 323 } 324 return addr 325} 326 327func (s *Socket) newConn(addr net.Addr) (c *Conn) { 328 addr = realRemoteAddr(addr) 329 330 c = &Conn{ 331 socket: s, 332 remoteSocketAddr: addr, 333 created: time.Now(), 334 } 335 c.sendPendingSendSendStateTimer = StoppedFuncTimer(c.sendPendingSendStateTimerCallback) 336 c.packetReadTimeoutTimer = time.AfterFunc(packetReadTimeout, c.receivePacketTimeoutCallback) 337 return 338} 339 340func (s *Socket) Dial(addr string) (net.Conn, error) { 341 return s.DialContext(context.Background(), "", addr) 342} 343 344func (s *Socket) DialAddr(netAddr net.Addr) (net.Conn, error) { 345 return s.DialAddrContext(context.Background(), netAddr) 346} 347 348func (s *Socket) resolveAddr(network, addr string) (net.Addr, error) { 349 n := s.network() 350 if network != "" { 351 n = network 352 } 353 return net.ResolveUDPAddr(n, addr) 354} 355 356func (s *Socket) network() string { 357 return s.pc.LocalAddr().Network() 358} 359 360func (s *Socket) startOutboundConn(addr net.Addr) (c *Conn, err error) { 361 mu.Lock() 362 defer mu.Unlock() 363 c = s.newConn(addr) 364 c.recv_id = s.newConnID(resolvedAddrStr(c.RemoteAddr().String())) 365 c.send_id = c.recv_id + 1 366 if logLevel >= 1 { 367 log.Printf("dial registering addr: %s", c.RemoteAddr().String()) 368 } 369 if !s.registerConn(c.recv_id, resolvedAddrStr(c.RemoteAddr().String()), c) { 370 err = errors.New("couldn't register new connection") 371 log.Println(c.recv_id, c.RemoteAddr().String()) 372 for k, c := range s.conns { 373 log.Println(k, c, c.age()) 374 } 375 log.Printf("that's %d connections", len(s.conns)) 376 } 377 if err != nil { 378 return 379 } 380 c.seq_nr = 1 381 c.writeSyn() 382 return 383} 384 385func (s *Socket) DialContext(ctx context.Context, network, addr string) (nc net.Conn, err error) { 386 netAddr, err := s.resolveAddr(network, addr) 387 if err != nil { 388 return 389 } 390 391 return s.DialAddrContext(ctx, netAddr) 392} 393 394func (s *Socket) DialAddrContext(ctx context.Context, netAddr net.Addr) (nc net.Conn, err error) { 395 c, err := s.startOutboundConn(netAddr) 396 if err != nil { 397 return 398 } 399 400 connErr := make(chan error, 1) 401 go func() { 402 connErr <- c.recvSynAck() 403 }() 404 select { 405 case err = <-connErr: 406 case <-ctx.Done(): 407 err = ctx.Err() 408 } 409 if err != nil { 410 mu.Lock() 411 c.destroy(errors.New("dial timeout")) 412 mu.Unlock() 413 return 414 } 415 mu.Lock() 416 c.updateCanWrite() 417 mu.Unlock() 418 //nc = pproffd.WrapNetConn(c) 419 nc = c 420 return 421} 422 423func (me *Socket) writeTo(b []byte, addr net.Addr) (n int, err error) { 424 apdc := artificialPacketDropChance 425 if apdc != 0 { 426 if rand.Float64() < apdc { 427 n = len(b) 428 return 429 } 430 } 431 n, err = me.pc.WriteTo(b, addr) 432 return 433} 434 435// Returns true if the connection was newly registered, false otherwise. 436func (s *Socket) registerConn(recvID uint16, remoteAddr resolvedAddrStr, c *Conn) bool { 437 if s.conns == nil { 438 s.conns = make(map[connKey]*Conn) 439 } 440 key := connKey{remoteAddr, recvID} 441 if _, ok := s.conns[key]; ok { 442 return false 443 } 444 c.connKey = key 445 s.conns[key] = c 446 s.onMutex.RLock() 447 defer s.onMutex.RUnlock() 448 if s.onAttach != nil { 449 go s.onAttach(c.remoteSocketAddr) 450 } 451 return true 452} 453 454func (s *Socket) backlogChanged() { 455 if len(s.backlog) != 0 { 456 s.backlogNotEmpty.Set() 457 } else { 458 s.backlogNotEmpty.Clear() 459 } 460} 461 462func (s *Socket) nextSyn() (syn syn, addr net.Addr, err error) { 463 for { 464 WaitEvents(&mu, &s.closed, &s.backlogNotEmpty, &s.destroyed) 465 if s.closed.IsSet() { 466 err = errClosed 467 return 468 } 469 if s.destroyed.IsSet() { 470 err = s.ReadErr 471 return 472 } 473 for k, v := range s.backlog { 474 syn = k 475 addr = v 476 delete(s.backlog, k) 477 s.backlogChanged() 478 return 479 } 480 } 481} 482 483// ACK a SYN, and return a new Conn for it. ok is false if the SYN is bad, and 484// the Conn invalid. 485func (s *Socket) ackSyn(syn syn, addr net.Addr) (c *Conn, ok bool) { 486 c = s.newConn(addr) 487 c.send_id = syn.conn_id 488 c.recv_id = c.send_id + 1 489 c.seq_nr = uint16(rand.Int()) 490 c.lastAck = c.seq_nr - 1 491 c.ack_nr = syn.seq_nr 492 c.synAcked = true 493 c.updateCanWrite() 494 if !s.registerConn(c.recv_id, resolvedAddrStr(addr.String()), c) { 495 // SYN that triggered this accept duplicates existing connection. 496 // Ack again in case the SYN was a resend. 497 c = s.conns[connKey{resolvedAddrStr(addr.String()), c.recv_id}] 498 if c.send_id != syn.conn_id { 499 panic(":|") 500 } 501 c.sendState() 502 return 503 } 504 c.sendState() 505 ok = true 506 return 507} 508 509// Accept and return a new uTP connection. 510func (s *Socket) Accept() (net.Conn, error) { 511 mu.Lock() 512 defer mu.Unlock() 513 for { 514 syn, addr, err := s.nextSyn() 515 if err != nil { 516 return nil, err 517 } 518 c, ok := s.ackSyn(syn, addr) 519 if ok { 520 c.updateCanWrite() 521 return c, nil 522 } 523 } 524} 525 526// The address we're listening on for new uTP connections. 527func (s *Socket) Addr() net.Addr { 528 return s.pc.LocalAddr() 529} 530 531func (s *Socket) CloseNow() error { 532 mu.Lock() 533 defer mu.Unlock() 534 s.closed.Set() 535 for _, c := range s.conns { 536 c.closeNow() 537 } 538 s.destroy() 539 s.wgReadWrite.Wait() 540 return nil 541} 542 543func (s *Socket) Close() error { 544 mu.Lock() 545 defer mu.Unlock() 546 s.closed.Set() 547 s.lazyDestroy() 548 return nil 549} 550 551func (s *Socket) lazyDestroy() { 552 if len(s.conns) != 0 { 553 return 554 } 555 if !s.closed.IsSet() { 556 return 557 } 558 s.destroy() 559} 560 561func (s *Socket) destroy() { 562 delete(sockets, s) 563 s.destroyed.Set() 564 s.pc.Close() 565 for _, c := range s.conns { 566 c.destroy(errors.New("Socket destroyed")) 567 } 568} 569 570func (s *Socket) LocalAddr() net.Addr { 571 return s.pc.LocalAddr() 572} 573 574func (s *Socket) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 575 select { 576 case read, ok := <-s.unusedReads: 577 if !ok { 578 err = io.EOF 579 return 580 } 581 n = copy(p, read.data) 582 addr = read.from 583 return 584 case <-s.connDeadlines.read.passed.LockedChan(&mu): 585 err = errTimeout 586 return 587 } 588} 589 590func (s *Socket) WriteTo(b []byte, addr net.Addr) (n int, err error) { 591 mu.Lock() 592 if s.connDeadlines.write.passed.IsSet() { 593 err = errTimeout 594 } 595 s.wgReadWrite.Add(1) 596 defer s.wgReadWrite.Done() 597 mu.Unlock() 598 if err != nil { 599 return 600 } 601 return s.pc.WriteTo(b, addr) 602} 603