1package dns
2
3import (
4	"context"
5	"io"
6	"sync"
7	"time"
8
9	"github.com/xtls/xray-core/transport/internet/stat"
10
11	"golang.org/x/net/dns/dnsmessage"
12
13	"github.com/xtls/xray-core/common"
14	"github.com/xtls/xray-core/common/buf"
15	"github.com/xtls/xray-core/common/net"
16	dns_proto "github.com/xtls/xray-core/common/protocol/dns"
17	"github.com/xtls/xray-core/common/session"
18	"github.com/xtls/xray-core/common/signal"
19	"github.com/xtls/xray-core/common/task"
20	"github.com/xtls/xray-core/core"
21	"github.com/xtls/xray-core/features/dns"
22	"github.com/xtls/xray-core/features/policy"
23	"github.com/xtls/xray-core/transport"
24	"github.com/xtls/xray-core/transport/internet"
25)
26
27func init() {
28	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
29		h := new(Handler)
30		if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error {
31			return h.Init(config.(*Config), dnsClient, policyManager)
32		}); err != nil {
33			return nil, err
34		}
35		return h, nil
36	}))
37}
38
// ownLinkVerifier is optionally implemented by the dns.Client passed to
// Handler.Init. IsOwnLink reports whether ctx belongs to a link the client
// itself created. Handler forwards queries on such links upstream instead of
// answering them locally (presumably to avoid the client resolving through
// itself — confirm against the client implementation).
type ownLinkVerifier interface {
	IsOwnLink(ctx context.Context) bool
}
42
// Handler is the DNS outbound handler. It answers A/AAAA queries through the
// DNS client feature and relays all other DNS messages to an upstream server.
type Handler struct {
	// client resolves the A/AAAA queries that are answered locally.
	client          dns.Client
	// ownLinkVerifier is set by Init only when client implements it; may be nil.
	ownLinkVerifier ownLinkVerifier
	// server holds optional overrides (network, address, port) for the upstream
	// destination; zero-valued fields leave the original target unchanged.
	server          net.Destination
	// timeout is the inactivity timeout, taken from the user level's
	// ConnectionIdle policy.
	timeout         time.Duration
}
49
50func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error {
51	h.client = dnsClient
52	h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle
53
54	if v, ok := dnsClient.(ownLinkVerifier); ok {
55		h.ownLinkVerifier = v
56	}
57
58	if config.Server != nil {
59		h.server = config.Server.AsDestination()
60	}
61	return nil
62}
63
64func (h *Handler) isOwnLink(ctx context.Context) bool {
65	return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx)
66}
67
68func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) {
69	var parser dnsmessage.Parser
70	header, err := parser.Start(b)
71	if err != nil {
72		newError("parser start").Base(err).WriteToLog()
73		return
74	}
75
76	id = header.ID
77	q, err := parser.Question()
78	if err != nil {
79		newError("question").Base(err).WriteToLog()
80		return
81	}
82	qType = q.Type
83	if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA {
84		return
85	}
86
87	domain = q.Name.String()
88	r = true
89	return
90}
91
92// Process implements proxy.Outbound.
93func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error {
94	outbound := session.OutboundFromContext(ctx)
95	if outbound == nil || !outbound.Target.IsValid() {
96		return newError("invalid outbound")
97	}
98
99	srcNetwork := outbound.Target.Network
100
101	dest := outbound.Target
102	if h.server.Network != net.Network_Unknown {
103		dest.Network = h.server.Network
104	}
105	if h.server.Address != nil {
106		dest.Address = h.server.Address
107	}
108	if h.server.Port != 0 {
109		dest.Port = h.server.Port
110	}
111
112	newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx))
113
114	conn := &outboundConn{
115		dialer: func() (stat.Connection, error) {
116			return d.Dial(ctx, dest)
117		},
118		connReady: make(chan struct{}, 1),
119	}
120
121	var reader dns_proto.MessageReader
122	var writer dns_proto.MessageWriter
123	if srcNetwork == net.Network_TCP {
124		reader = dns_proto.NewTCPReader(link.Reader)
125		writer = &dns_proto.TCPWriter{
126			Writer: link.Writer,
127		}
128	} else {
129		reader = &dns_proto.UDPReader{
130			Reader: link.Reader,
131		}
132		writer = &dns_proto.UDPWriter{
133			Writer: link.Writer,
134		}
135	}
136
137	var connReader dns_proto.MessageReader
138	var connWriter dns_proto.MessageWriter
139	if dest.Network == net.Network_TCP {
140		connReader = dns_proto.NewTCPReader(buf.NewReader(conn))
141		connWriter = &dns_proto.TCPWriter{
142			Writer: buf.NewWriter(conn),
143		}
144	} else {
145		connReader = &dns_proto.UDPReader{
146			Reader: buf.NewPacketReader(conn),
147		}
148		connWriter = &dns_proto.UDPWriter{
149			Writer: buf.NewWriter(conn),
150		}
151	}
152
153	ctx, cancel := context.WithCancel(ctx)
154	timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)
155
156	request := func() error {
157		defer conn.Close()
158
159		for {
160			b, err := reader.ReadMessage()
161			if err == io.EOF {
162				return nil
163			}
164
165			if err != nil {
166				return err
167			}
168
169			timer.Update()
170
171			if !h.isOwnLink(ctx) {
172				isIPQuery, domain, id, qType := parseIPQuery(b.Bytes())
173				if isIPQuery {
174					go h.handleIPQuery(id, qType, domain, writer)
175					continue
176				}
177			}
178
179			if err := connWriter.WriteMessage(b); err != nil {
180				return err
181			}
182		}
183	}
184
185	response := func() error {
186		for {
187			b, err := connReader.ReadMessage()
188			if err == io.EOF {
189				return nil
190			}
191
192			if err != nil {
193				return err
194			}
195
196			timer.Update()
197
198			if err := writer.WriteMessage(b); err != nil {
199				return err
200			}
201		}
202	}
203
204	if err := task.Run(ctx, request, response); err != nil {
205		return newError("connection ends").Base(err)
206	}
207
208	return nil
209}
210
211func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
212	var ips []net.IP
213	var err error
214
215	var ttl uint32 = 600
216
217	switch qType {
218	case dnsmessage.TypeA:
219		ips, err = h.client.LookupIP(domain, dns.IPOption{
220			IPv4Enable: true,
221			IPv6Enable: false,
222			FakeEnable: true,
223		})
224	case dnsmessage.TypeAAAA:
225		ips, err = h.client.LookupIP(domain, dns.IPOption{
226			IPv4Enable: false,
227			IPv6Enable: true,
228			FakeEnable: true,
229		})
230	}
231
232	rcode := dns.RCodeFromError(err)
233	if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse {
234		newError("ip query").Base(err).WriteToLog()
235		return
236	}
237
238	switch qType {
239	case dnsmessage.TypeA:
240		for i, ip := range ips {
241			ips[i] = ip.To4()
242		}
243	case dnsmessage.TypeAAAA:
244		for i, ip := range ips {
245			ips[i] = ip.To16()
246		}
247	}
248
249	b := buf.New()
250	rawBytes := b.Extend(buf.Size)
251	builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
252		ID:                 id,
253		RCode:              dnsmessage.RCode(rcode),
254		RecursionAvailable: true,
255		RecursionDesired:   true,
256		Response:           true,
257		Authoritative:      true,
258	})
259	builder.EnableCompression()
260	common.Must(builder.StartQuestions())
261	common.Must(builder.Question(dnsmessage.Question{
262		Name:  dnsmessage.MustNewName(domain),
263		Class: dnsmessage.ClassINET,
264		Type:  qType,
265	}))
266	common.Must(builder.StartAnswers())
267
268	rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl}
269	for _, ip := range ips {
270		if len(ip) == net.IPv4len {
271			var r dnsmessage.AResource
272			copy(r.A[:], ip)
273			common.Must(builder.AResource(rHeader, r))
274		} else {
275			var r dnsmessage.AAAAResource
276			copy(r.AAAA[:], ip)
277			common.Must(builder.AAAAResource(rHeader, r))
278		}
279	}
280	msgBytes, err := builder.Finish()
281	if err != nil {
282		newError("pack message").Base(err).WriteToLog()
283		b.Release()
284		return
285	}
286	b.Resize(0, int32(len(msgBytes)))
287
288	if err := writer.WriteMessage(b); err != nil {
289		newError("write IP answer").Base(err).WriteToLog()
290	}
291}
292
// outboundConn is a lazily dialed connection to the upstream DNS server.
// The underlying connection is created by the first Write. A Read issued
// before that waits on connReady until the dial succeeds or Close is called.
type outboundConn struct {
	// access guards conn and the dial performed in Write.
	access sync.Mutex
	// dialer opens the upstream connection.
	dialer func() (stat.Connection, error)

	// conn is nil until the first successful dial.
	conn      net.Conn
	// connReady receives one value after a successful dial and is closed by Close.
	connReady chan struct{}
}
300
301func (c *outboundConn) dial() error {
302	conn, err := c.dialer()
303	if err != nil {
304		return err
305	}
306	c.conn = conn
307	c.connReady <- struct{}{}
308	return nil
309}
310
311func (c *outboundConn) Write(b []byte) (int, error) {
312	c.access.Lock()
313
314	if c.conn == nil {
315		if err := c.dial(); err != nil {
316			c.access.Unlock()
317			newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog()
318			return len(b), nil
319		}
320	}
321
322	c.access.Unlock()
323
324	return c.conn.Write(b)
325}
326
327func (c *outboundConn) Read(b []byte) (int, error) {
328	var conn net.Conn
329	c.access.Lock()
330	conn = c.conn
331	c.access.Unlock()
332
333	if conn == nil {
334		_, open := <-c.connReady
335		if !open {
336			return 0, io.EOF
337		}
338		conn = c.conn
339	}
340
341	return conn.Read(b)
342}
343
344func (c *outboundConn) Close() error {
345	c.access.Lock()
346	close(c.connReady)
347	if c.conn != nil {
348		c.conn.Close()
349	}
350	c.access.Unlock()
351	return nil
352}
353