1package udp
2
3import (
4	"context"
5	"io"
6	"sync"
7	"time"
8
9	"github.com/xtls/xray-core/common/signal/done"
10
11	"github.com/xtls/xray-core/common"
12	"github.com/xtls/xray-core/common/buf"
13	"github.com/xtls/xray-core/common/net"
14	"github.com/xtls/xray-core/common/protocol/udp"
15	"github.com/xtls/xray-core/common/session"
16	"github.com/xtls/xray-core/common/signal"
17	"github.com/xtls/xray-core/features/routing"
18	"github.com/xtls/xray-core/transport"
19)
20
// ResponseCallback receives each response packet read back from a dispatched
// connection. The Dispatcher invokes it with the context of the connection the
// packet arrived on, and sets packet.Source to that connection's destination.
type ResponseCallback func(ctx context.Context, packet *udp.Packet)

// connEntry tracks one link opened through the routing dispatcher for a single
// destination:
//   - link is the reader/writer pair returned by routing.Dispatcher.Dispatch;
//   - timer is the inactivity timer, refreshed each time response data is read;
//   - cancel cancels the connection's context and removes the entry from the
//     Dispatcher's table.
type connEntry struct {
	link   *transport.Link
	timer  signal.ActivityUpdater
	cancel context.CancelFunc
}

// Dispatcher multiplexes UDP payloads onto per-destination links obtained from
// a routing.Dispatcher, keeping one link per destination in conns and
// forwarding every response to callback. The embedded mutex guards conns; the
// code in this file only takes the exclusive Lock.
type Dispatcher struct {
	sync.RWMutex
	conns      map[net.Destination]*connEntry
	dispatcher routing.Dispatcher
	callback   ResponseCallback
}
35
36func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
37	return &Dispatcher{
38		conns:      make(map[net.Destination]*connEntry),
39		dispatcher: dispatcher,
40		callback:   callback,
41	}
42}
43
44func (v *Dispatcher) RemoveRay(dest net.Destination) {
45	v.Lock()
46	defer v.Unlock()
47	if conn, found := v.conns[dest]; found {
48		common.Close(conn.link.Reader)
49		common.Close(conn.link.Writer)
50		delete(v.conns, dest)
51	}
52}
53
54func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *connEntry {
55	v.Lock()
56	defer v.Unlock()
57
58	if entry, found := v.conns[dest]; found {
59		return entry
60	}
61
62	newError("establishing new connection for ", dest).WriteToLog()
63
64	ctx, cancel := context.WithCancel(ctx)
65	removeRay := func() {
66		cancel()
67		v.RemoveRay(dest)
68	}
69	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
70	link, _ := v.dispatcher.Dispatch(ctx, dest)
71	entry := &connEntry{
72		link:   link,
73		timer:  timer,
74		cancel: removeRay,
75	}
76	v.conns[dest] = entry
77	go handleInput(ctx, entry, dest, v.callback)
78	return entry
79}
80
81func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) {
82	// TODO: Add user to destString
83	newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx))
84
85	conn := v.getInboundRay(ctx, destination)
86	outputStream := conn.link.Writer
87	if outputStream != nil {
88		if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
89			newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
90			conn.cancel()
91			return
92		}
93	}
94}
95
96func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) {
97	defer conn.cancel()
98
99	input := conn.link.Reader
100	timer := conn.timer
101
102	for {
103		select {
104		case <-ctx.Done():
105			return
106		default:
107		}
108
109		mb, err := input.ReadMultiBuffer()
110		if err != nil {
111			newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx))
112			return
113		}
114		timer.Update()
115		for _, b := range mb {
116			callback(ctx, &udp.Packet{
117				Payload: b,
118				Source:  dest,
119			})
120		}
121	}
122}
123
// dispatcherConn adapts a Dispatcher to the net.PacketConn interface. WriteTo
// dispatches outgoing packets through dispatcher. Responses are pushed by
// callback into cache, a bounded queue that ReadFrom drains; callback drops
// packets when the queue is full. done is closed by Close, after which
// ReadFrom reports io.EOF and callback may discard responses.
type dispatcherConn struct {
	dispatcher *Dispatcher
	cache      chan *udp.Packet
	done       *done.Instance
}
129
130func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) {
131	c := &dispatcherConn{
132		cache: make(chan *udp.Packet, 16),
133		done:  done.New(),
134	}
135
136	d := NewDispatcher(dispatcher, c.callback)
137	c.dispatcher = d
138	return c, nil
139}
140
141func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) {
142	select {
143	case <-c.done.Wait():
144		packet.Payload.Release()
145		return
146	case c.cache <- packet:
147	default:
148		packet.Payload.Release()
149		return
150	}
151}
152
153func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) {
154	select {
155	case <-c.done.Wait():
156		return 0, nil, io.EOF
157	case packet := <-c.cache:
158		n := copy(p, packet.Payload.Bytes())
159		return n, &net.UDPAddr{
160			IP:   packet.Source.Address.IP(),
161			Port: int(packet.Source.Port),
162		}, nil
163	}
164}
165
166func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {
167	buffer := buf.New()
168	raw := buffer.Extend(buf.Size)
169	n := copy(raw, p)
170	buffer.Resize(0, int32(n))
171
172	ctx := context.Background()
173	c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer)
174	return n, nil
175}
176
177func (c *dispatcherConn) Close() error {
178	return c.done.Close()
179}
180
181func (c *dispatcherConn) LocalAddr() net.Addr {
182	return &net.UDPAddr{
183		IP:   []byte{0, 0, 0, 0},
184		Port: 0,
185	}
186}
187
188func (c *dispatcherConn) SetDeadline(t time.Time) error {
189	return nil
190}
191
192func (c *dispatcherConn) SetReadDeadline(t time.Time) error {
193	return nil
194}
195
196func (c *dispatcherConn) SetWriteDeadline(t time.Time) error {
197	return nil
198}
199