1package udp
2
3import (
4	"context"
5
6	"github.com/xtls/xray-core/common/buf"
7	"github.com/xtls/xray-core/common/net"
8	"github.com/xtls/xray-core/common/protocol/udp"
9	"github.com/xtls/xray-core/transport/internet"
10)
11
12type HubOption func(h *Hub)
13
14func HubCapacity(capacity int) HubOption {
15	return func(h *Hub) {
16		h.capacity = capacity
17	}
18}
19
20func HubReceiveOriginalDestination(r bool) HubOption {
21	return func(h *Hub) {
22		h.recvOrigDest = r
23	}
24}
25
// Hub owns a listening UDP socket and a background read loop (start) that
// turns incoming datagrams into *udp.Packet values delivered via Receive.
type Hub struct {
	// conn is the bound UDP socket used for both reading and WriteTo.
	conn         *net.UDPConn
	// cache carries received packets to the consumer; start closes it when
	// the read loop ends.
	cache        chan *udp.Packet
	// capacity is the buffer size used when cache is created.
	capacity     int
	// recvOrigDest enables parsing the original destination from the
	// control data returned with each read.
	recvOrigDest bool
}
32
// ListenUDP opens a UDP socket on address:port and returns a Hub whose
// background goroutine reads datagrams from it; packets are delivered through
// Receive.
//
// Options are applied in order over the defaults (capacity 256, original
// destination reception off). If streamSettings carries socket settings with
// ReceiveOriginalDestAddress set, original-destination reception is forced on
// regardless of the options.
//
// ctx is only passed to internet.ListenSystemPacket; it does not bound the
// lifetime of the returned Hub. Call Close to stop it.
//
// An error is returned if the configured capacity is negative, if the socket
// cannot be opened, or if the listener yields something other than a
// *net.UDPConn.
func ListenUDP(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, options ...HubOption) (*Hub, error) {
	hub := &Hub{
		capacity:     256,
		recvOrigDest: false,
	}
	for _, opt := range options {
		opt(hub)
	}

	// make(chan, n) panics for negative n; reject it before opening a socket
	// that would otherwise leak.
	if hub.capacity < 0 {
		return nil, newError("invalid UDP hub capacity: ", hub.capacity)
	}

	var sockopt *internet.SocketConfig
	if streamSettings != nil {
		sockopt = streamSettings.SocketSettings
	}
	if sockopt != nil && sockopt.ReceiveOriginalDestAddress {
		hub.recvOrigDest = true
	}

	packetConn, err := internet.ListenSystemPacket(ctx, &net.UDPAddr{
		IP:   address.IP(),
		Port: int(port),
	}, sockopt)
	if err != nil {
		return nil, err
	}

	// Hub stores the socket as *net.UDPConn (ReadUDPMsg, WriteToUDP), so any
	// other PacketConn implementation is rejected instead of panicking here.
	udpConn, ok := packetConn.(*net.UDPConn)
	if !ok {
		_ = packetConn.Close()
		return nil, newError("failed to listen UDP on ", address, ":", port, ": not a UDP connection")
	}
	newError("listening UDP on ", address, ":", port).WriteToLog()
	hub.conn = udpConn
	hub.cache = make(chan *udp.Packet, hub.capacity)

	go hub.start()
	return hub, nil
}
64
65// Close implements net.Listener.
66func (h *Hub) Close() error {
67	h.conn.Close()
68	return nil
69}
70
71func (h *Hub) WriteTo(payload []byte, dest net.Destination) (int, error) {
72	return h.conn.WriteToUDP(payload, &net.UDPAddr{
73		IP:   dest.Address.IP(),
74		Port: int(dest.Port),
75	})
76}
77
78func (h *Hub) start() {
79	c := h.cache
80	defer close(c)
81
82	oobBytes := make([]byte, 256)
83
84	for {
85		buffer := buf.New()
86		var noob int
87		var addr *net.UDPAddr
88		rawBytes := buffer.Extend(buf.Size)
89
90		n, noob, _, addr, err := ReadUDPMsg(h.conn, rawBytes, oobBytes)
91		if err != nil {
92			newError("failed to read UDP msg").Base(err).WriteToLog()
93			buffer.Release()
94			break
95		}
96		buffer.Resize(0, int32(n))
97
98		if buffer.IsEmpty() {
99			buffer.Release()
100			continue
101		}
102
103		payload := &udp.Packet{
104			Payload: buffer,
105			Source:  net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)),
106		}
107		if h.recvOrigDest && noob > 0 {
108			payload.Target = RetrieveOriginalDest(oobBytes[:noob])
109			if payload.Target.IsValid() {
110				newError("UDP original destination: ", payload.Target).AtDebug().WriteToLog()
111			} else {
112				newError("failed to read UDP original destination").WriteToLog()
113			}
114		}
115
116		select {
117		case c <- payload:
118		default:
119			buffer.Release()
120			payload.Payload = nil
121		}
122	}
123}
124
125// Addr implements net.Listener.
126func (h *Hub) Addr() net.Addr {
127	return h.conn.LocalAddr()
128}
129
130func (h *Hub) Receive() <-chan *udp.Packet {
131	return h.cache
132}
133