1package utp
2
3import (
4	"context"
5	"errors"
6	"io"
7	"log"
8	"math/rand"
9	"net"
10	"sync"
11	"time"
12)
13
14var (
15	_ net.Listener   = &Socket{}
16	_ net.PacketConn = &Socket{}
17)
18
// connKey uniquely identifies any uTP connection on top of the underlying
// packet stream. It is the key type of Socket.conns: remoteAddr is the string
// form of the peer's address, and connID is the connection ID the Conn is
// filed under (registerConn uses the Conn's recv ID). The fields are filled
// positionally by callers, so their order must not change.
type connKey struct {
	remoteAddr resolvedAddrStr
	connID     uint16
}
25
26// A Socket wraps a net.PacketConn, diverting uTP packets to its child uTP
27// Conns.
28type Socket struct {
29	pc    net.PacketConn
30	conns map[connKey]*Conn
31
32	backlogNotEmpty Event
33	backlog         map[syn]net.Addr
34
35	closed    Event
36	destroyed Event
37
38	wgReadWrite sync.WaitGroup
39
40	unusedReads chan read
41	connDeadlines
42	// If a read error occurs on the underlying net.PacketConn, it is put
43	// here. This is because reading is done in its own goroutine to dispatch
44	// to uTP Conns.
45	ReadErr error
46
47	onAttach func(remote net.Addr)
48	onDetach func(remote net.Addr)
49	onMutex  sync.RWMutex
50}
51
52func (s *Socket) OnAttach(f func(remote net.Addr)) {
53	s.onMutex.Lock()
54	defer s.onMutex.Unlock()
55	s.onAttach = f
56}
57
58func (s *Socket) OnDetach(f func(remote net.Addr)) {
59	s.onMutex.Lock()
60	defer s.onMutex.Unlock()
61	s.onDetach = f
62}
63
// listenPacket opens a packet connection on the given network and address.
// It is a thin wrapper around net.ListenPacket.
func listenPacket(network, addr string) (pc net.PacketConn, err error) {
	pc, err = net.ListenPacket(network, addr)
	return pc, err
}
67
68// NewSocket creates a net.PacketConn with the given network and address, and
69// returns a Socket dispatching on it.
70func NewSocket(network, addr string) (s *Socket, err error) {
71	if network == "" {
72		network = "udp"
73	}
74	pc, err := listenPacket(network, addr)
75	if err != nil {
76		return
77	}
78	return NewSocketFromPacketConn(pc)
79}
80
81// Create a Socket, using the provided net.PacketConn. If you want to retain
82// use of the net.PacketConn after the Socket closes it, override the
83// net.PacketConn's Close method, or use NetSocketFromPacketConnNoClose.
84func NewSocketFromPacketConn(pc net.PacketConn) (s *Socket, err error) {
85	s = &Socket{
86		backlog:     make(map[syn]net.Addr, backlog),
87		pc:          pc,
88		unusedReads: make(chan read, 100),
89		wgReadWrite: sync.WaitGroup{},
90	}
91	mu.Lock()
92	sockets[s] = struct{}{}
93	mu.Unlock()
94	go s.reader()
95	return
96}
97
98// Create a Socket using the provided PacketConn, that doesn't close the
99// PacketConn when the Socket is closed.
100func NewSocketFromPacketConnNoClose(pc net.PacketConn) (s *Socket, err error) {
101	return NewSocketFromPacketConn(packetConnNopCloser{pc})
102}
103
104func (s *Socket) unusedRead(read read) {
105	unusedReads.Add(1)
106	select {
107	case s.unusedReads <- read:
108	default:
109		// Drop the packet.
110		unusedReadsDropped.Add(1)
111	}
112}
113
114func (s *Socket) pushBacklog(syn syn, addr net.Addr) {
115	if _, ok := s.backlog[syn]; ok {
116		return
117	}
118	// Pop a pseudo-random syn to make room. TODO: Use missinggo/orderedmap,
119	// coz that's what is wanted here.
120	for k, v := range s.backlog {
121		if len(s.backlog) < backlog {
122			break
123		}
124		delete(s.backlog, k)
125		// A syn is sent on the remote's recv_id, so this is where we can send
126		// the reset.
127		s.reset(v, k.seq_nr, k.conn_id)
128	}
129	s.backlog[syn] = addr
130	s.backlogChanged()
131}
132
133func (s *Socket) reader() {
134	mu.Lock()
135	defer mu.Unlock()
136	defer s.destroy()
137	var b [maxRecvSize]byte
138	for {
139		s.wgReadWrite.Add(1)
140		mu.Unlock()
141		n, addr, err := s.pc.ReadFrom(b[:])
142		s.wgReadWrite.Done()
143		mu.Lock()
144		if s.destroyed.IsSet() {
145			return
146		}
147		if err != nil {
148			log.Printf("error reading Socket PacketConn: %s", err)
149			s.ReadErr = err
150			return
151		}
152		s.handleReceivedPacket(read{
153			append([]byte(nil), b[:n]...),
154			addr,
155		})
156	}
157}
158
159func receivedUTPPacketSize(n int) {
160	if n > largestReceivedUTPPacket {
161		largestReceivedUTPPacket = n
162		largestReceivedUTPPacketExpvar.Set(int64(n))
163	}
164}
165
166func (s *Socket) connForRead(h header, from net.Addr) (c *Conn, ok bool) {
167	c, ok = s.conns[connKey{
168		resolvedAddrStr(from.String()),
169		func() uint16 {
170			if h.Type == stSyn {
171				// SYNs have a ConnID one lower than the eventual recvID, and we index
172				// the connections with that, so use it for the lookup.
173				return h.ConnID + 1
174			} else {
175				return h.ConnID
176			}
177		}(),
178	}]
179	return
180}
181
182func (s *Socket) handlePacketReceivedForEstablishedConn(h header, from net.Addr, data []byte, c *Conn) {
183	if h.Type == stSyn {
184		if h.ConnID == c.send_id-2 {
185			// This is a SYN for connection that cannot exist locally. The
186			// connection the remote wants to establish here with the proposed
187			// recv_id, already has an existing connection that was dialled
188			// *out* from this socket, which is why the send_id is 1 higher,
189			// rather than 1 lower than the recv_id.
190			log.Print("resetting conflicting syn")
191			s.reset(from, h.SeqNr, h.ConnID)
192			return
193		} else if h.ConnID != c.send_id {
194			panic("bad assumption")
195		}
196	}
197	c.receivePacket(h, data)
198}
199
200func (s *Socket) handleReceivedPacket(p read) {
201	if len(p.data) < 20 {
202		s.unusedRead(p)
203		return
204	}
205	var h header
206	hEnd, err := h.Unmarshal(p.data)
207	if err != nil || h.Type > stMax || h.Version != 1 {
208		s.unusedRead(p)
209		return
210	}
211	if c, ok := s.connForRead(h, p.from); ok {
212		receivedUTPPacketSize(len(p.data))
213		s.handlePacketReceivedForEstablishedConn(h, p.from, p.data[hEnd:], c)
214		return
215	}
216	// Packet doesn't belong to an existing connection.
217	switch h.Type {
218	case stSyn:
219		s.pushBacklog(syn{
220			seq_nr:  h.SeqNr,
221			conn_id: h.ConnID,
222			addr:    p.from.String(),
223		}, p.from)
224		return
225	case stReset:
226		// Could be a late arriving packet for a Conn we're already done with.
227		// If it was for an existing connection, we would have handled it
228		// earlier.
229	default:
230		unexpectedPacketsRead.Add(1)
231		// This is an unexpected packet. We'll send a reset, but also pass it
232		// on. I don't think you can reset on the received packets ConnID if
233		// it isn't a SYN, as the send_id will differ in this case.
234		s.reset(p.from, h.SeqNr, h.ConnID)
235		// Connection initiated by remote.
236		s.reset(p.from, h.SeqNr, h.ConnID-1)
237		// Connection initiated locally.
238		s.reset(p.from, h.SeqNr, h.ConnID+1)
239	}
240	s.unusedRead(p)
241}
242
243// Send a reset in response to a packet with the given header.
244func (s *Socket) reset(addr net.Addr, ackNr, connId uint16) {
245	b := make([]byte, 0, maxHeaderSize)
246	h := header{
247		Type:    stReset,
248		Version: 1,
249		ConnID:  connId,
250		AckNr:   ackNr,
251	}
252	b = b[:h.Marshal(b)]
253	go s.writeTo(b, addr)
254}
255
256// Return a recv_id that should be free. Handling the case where it isn't is
257// deferred to a more appropriate function.
258func (s *Socket) newConnID(remoteAddr resolvedAddrStr) (id uint16) {
259	// Rather than use math.Rand, which requires generating all the IDs up
260	// front and allocating a slice, we do it on the stack, generating the IDs
261	// only as required. To do this, we use the fact that the array is
262	// default-initialized. IDs that are 0, are actually their index in the
263	// array. IDs that are non-zero, are +1 from their intended ID.
264	var idsBack [0x10000]int
265	ids := idsBack[:]
266	for len(ids) != 0 {
267		// Pick the next ID from the untried ids.
268		i := rand.Intn(len(ids))
269		id = uint16(ids[i])
270		// If it's zero, then treat it as though the index i was the ID.
271		// Otherwise the value we get is the ID+1.
272		if id == 0 {
273			id = uint16(i)
274		} else {
275			id--
276		}
277		// Check there's no connection using this ID for its recv_id...
278		_, ok1 := s.conns[connKey{remoteAddr, id}]
279		// and if we're connecting to our own Socket, that there isn't a Conn
280		// already receiving on what will correspond to our send_id. Note that
281		// we just assume that we could be connecting to our own Socket. This
282		// will halve the available connection IDs to each distinct remote
283		// address. Presumably that's ~0x8000, down from ~0x10000.
284		_, ok2 := s.conns[connKey{remoteAddr, id + 1}]
285		_, ok4 := s.conns[connKey{remoteAddr, id - 1}]
286		if !ok1 && !ok2 && !ok4 {
287			return
288		}
289		// The set of possible IDs is shrinking. The highest one will be lost, so
290		// it's moved to the location of the one we just tried.
291		ids[i] = len(ids) // Conveniently already +1.
292		// And shrink.
293		ids = ids[:len(ids)-1]
294	}
295	return
296}
297
var (
	// The unspecified ("any") addresses.
	zeroipv4 = net.ParseIP("0.0.0.0")
	zeroipv6 = net.ParseIP("::")

	// The loopback addresses substituted for them by realRemoteAddr.
	ipv4lo = mustResolveUDP("127.0.0.1")
	ipv6lo = mustResolveUDP("::1")
)

// mustResolveUDP resolves addr with net.ResolveIPAddr and returns the
// resulting IP (despite the name, no UDP port is involved). It panics if
// resolution fails, so it is only suitable for known-good literals.
func mustResolveUDP(addr string) net.IP {
	u, err := net.ResolveIPAddr("ip", addr)
	if err != nil {
		panic(err)
	}
	return u.IP
}

// realRemoteAddr maps an unspecified UDP address (0.0.0.0 or ::) to the
// loopback address of the same family, keeping port and zone. Every other
// address, including non-UDP and nil ones, is returned as is.
//
// The argument is never modified: when a substitution is needed a new
// *net.UDPAddr is returned, so an address owned by the caller keeps its
// value.
func realRemoteAddr(addr net.Addr) net.Addr {
	udpAddr, ok := addr.(*net.UDPAddr)
	if !ok || udpAddr == nil {
		return addr
	}
	var loopback net.IP
	switch {
	case udpAddr.IP.Equal(zeroipv4):
		loopback = ipv4lo
	case udpAddr.IP.Equal(zeroipv6):
		loopback = ipv6lo
	default:
		return addr
	}
	fixed := *udpAddr
	fixed.IP = loopback
	return &fixed
}
326
327func (s *Socket) newConn(addr net.Addr) (c *Conn) {
328	addr = realRemoteAddr(addr)
329
330	c = &Conn{
331		socket:           s,
332		remoteSocketAddr: addr,
333		created:          time.Now(),
334	}
335	c.sendPendingSendSendStateTimer = StoppedFuncTimer(c.sendPendingSendStateTimerCallback)
336	c.packetReadTimeoutTimer = time.AfterFunc(packetReadTimeout, c.receivePacketTimeoutCallback)
337	return
338}
339
340func (s *Socket) Dial(addr string) (net.Conn, error) {
341	return s.DialContext(context.Background(), "", addr)
342}
343
344func (s *Socket) DialAddr(netAddr net.Addr) (net.Conn, error) {
345	return s.DialAddrContext(context.Background(), netAddr)
346}
347
348func (s *Socket) resolveAddr(network, addr string) (net.Addr, error) {
349	n := s.network()
350	if network != "" {
351		n = network
352	}
353	return net.ResolveUDPAddr(n, addr)
354}
355
356func (s *Socket) network() string {
357	return s.pc.LocalAddr().Network()
358}
359
360func (s *Socket) startOutboundConn(addr net.Addr) (c *Conn, err error) {
361	mu.Lock()
362	defer mu.Unlock()
363	c = s.newConn(addr)
364	c.recv_id = s.newConnID(resolvedAddrStr(c.RemoteAddr().String()))
365	c.send_id = c.recv_id + 1
366	if logLevel >= 1 {
367		log.Printf("dial registering addr: %s", c.RemoteAddr().String())
368	}
369	if !s.registerConn(c.recv_id, resolvedAddrStr(c.RemoteAddr().String()), c) {
370		err = errors.New("couldn't register new connection")
371		log.Println(c.recv_id, c.RemoteAddr().String())
372		for k, c := range s.conns {
373			log.Println(k, c, c.age())
374		}
375		log.Printf("that's %d connections", len(s.conns))
376	}
377	if err != nil {
378		return
379	}
380	c.seq_nr = 1
381	c.writeSyn()
382	return
383}
384
385func (s *Socket) DialContext(ctx context.Context, network, addr string) (nc net.Conn, err error) {
386	netAddr, err := s.resolveAddr(network, addr)
387	if err != nil {
388		return
389	}
390
391	return s.DialAddrContext(ctx, netAddr)
392}
393
394func (s *Socket) DialAddrContext(ctx context.Context, netAddr net.Addr) (nc net.Conn, err error) {
395	c, err := s.startOutboundConn(netAddr)
396	if err != nil {
397		return
398	}
399
400	connErr := make(chan error, 1)
401	go func() {
402		connErr <- c.recvSynAck()
403	}()
404	select {
405	case err = <-connErr:
406	case <-ctx.Done():
407		err = ctx.Err()
408	}
409	if err != nil {
410		mu.Lock()
411		c.destroy(errors.New("dial timeout"))
412		mu.Unlock()
413		return
414	}
415	mu.Lock()
416	c.updateCanWrite()
417	mu.Unlock()
418	//nc = pproffd.WrapNetConn(c)
419	nc = c
420	return
421}
422
423func (me *Socket) writeTo(b []byte, addr net.Addr) (n int, err error) {
424	apdc := artificialPacketDropChance
425	if apdc != 0 {
426		if rand.Float64() < apdc {
427			n = len(b)
428			return
429		}
430	}
431	n, err = me.pc.WriteTo(b, addr)
432	return
433}
434
435// Returns true if the connection was newly registered, false otherwise.
436func (s *Socket) registerConn(recvID uint16, remoteAddr resolvedAddrStr, c *Conn) bool {
437	if s.conns == nil {
438		s.conns = make(map[connKey]*Conn)
439	}
440	key := connKey{remoteAddr, recvID}
441	if _, ok := s.conns[key]; ok {
442		return false
443	}
444	c.connKey = key
445	s.conns[key] = c
446	s.onMutex.RLock()
447	defer s.onMutex.RUnlock()
448	if s.onAttach != nil {
449		go s.onAttach(c.remoteSocketAddr)
450	}
451	return true
452}
453
454func (s *Socket) backlogChanged() {
455	if len(s.backlog) != 0 {
456		s.backlogNotEmpty.Set()
457	} else {
458		s.backlogNotEmpty.Clear()
459	}
460}
461
462func (s *Socket) nextSyn() (syn syn, addr net.Addr, err error) {
463	for {
464		WaitEvents(&mu, &s.closed, &s.backlogNotEmpty, &s.destroyed)
465		if s.closed.IsSet() {
466			err = errClosed
467			return
468		}
469		if s.destroyed.IsSet() {
470			err = s.ReadErr
471			return
472		}
473		for k, v := range s.backlog {
474			syn = k
475			addr = v
476			delete(s.backlog, k)
477			s.backlogChanged()
478			return
479		}
480	}
481}
482
483// ACK a SYN, and return a new Conn for it. ok is false if the SYN is bad, and
484// the Conn invalid.
485func (s *Socket) ackSyn(syn syn, addr net.Addr) (c *Conn, ok bool) {
486	c = s.newConn(addr)
487	c.send_id = syn.conn_id
488	c.recv_id = c.send_id + 1
489	c.seq_nr = uint16(rand.Int())
490	c.lastAck = c.seq_nr - 1
491	c.ack_nr = syn.seq_nr
492	c.synAcked = true
493	c.updateCanWrite()
494	if !s.registerConn(c.recv_id, resolvedAddrStr(addr.String()), c) {
495		// SYN that triggered this accept duplicates existing connection.
496		// Ack again in case the SYN was a resend.
497		c = s.conns[connKey{resolvedAddrStr(addr.String()), c.recv_id}]
498		if c.send_id != syn.conn_id {
499			panic(":|")
500		}
501		c.sendState()
502		return
503	}
504	c.sendState()
505	ok = true
506	return
507}
508
509// Accept and return a new uTP connection.
510func (s *Socket) Accept() (net.Conn, error) {
511	mu.Lock()
512	defer mu.Unlock()
513	for {
514		syn, addr, err := s.nextSyn()
515		if err != nil {
516			return nil, err
517		}
518		c, ok := s.ackSyn(syn, addr)
519		if ok {
520			c.updateCanWrite()
521			return c, nil
522		}
523	}
524}
525
526// The address we're listening on for new uTP connections.
527func (s *Socket) Addr() net.Addr {
528	return s.pc.LocalAddr()
529}
530
531func (s *Socket) CloseNow() error {
532	mu.Lock()
533	defer mu.Unlock()
534	s.closed.Set()
535	for _, c := range s.conns {
536		c.closeNow()
537	}
538	s.destroy()
539	s.wgReadWrite.Wait()
540	return nil
541}
542
543func (s *Socket) Close() error {
544	mu.Lock()
545	defer mu.Unlock()
546	s.closed.Set()
547	s.lazyDestroy()
548	return nil
549}
550
551func (s *Socket) lazyDestroy() {
552	if len(s.conns) != 0 {
553		return
554	}
555	if !s.closed.IsSet() {
556		return
557	}
558	s.destroy()
559}
560
561func (s *Socket) destroy() {
562	delete(sockets, s)
563	s.destroyed.Set()
564	s.pc.Close()
565	for _, c := range s.conns {
566		c.destroy(errors.New("Socket destroyed"))
567	}
568}
569
570func (s *Socket) LocalAddr() net.Addr {
571	return s.pc.LocalAddr()
572}
573
574func (s *Socket) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
575	select {
576	case read, ok := <-s.unusedReads:
577		if !ok {
578			err = io.EOF
579			return
580		}
581		n = copy(p, read.data)
582		addr = read.from
583		return
584	case <-s.connDeadlines.read.passed.LockedChan(&mu):
585		err = errTimeout
586		return
587	}
588}
589
590func (s *Socket) WriteTo(b []byte, addr net.Addr) (n int, err error) {
591	mu.Lock()
592	if s.connDeadlines.write.passed.IsSet() {
593		err = errTimeout
594	}
595	s.wgReadWrite.Add(1)
596	defer s.wgReadWrite.Done()
597	mu.Unlock()
598	if err != nil {
599		return
600	}
601	return s.pc.WriteTo(b, addr)
602}
603