// Package dns implements the core.DNS feature.
2package dns
3
4//go:generate go run github.com/xtls/xray-core/common/errors/errorgen
5
6import (
7	"context"
8	"fmt"
9	"strings"
10	"sync"
11
12	"github.com/xtls/xray-core/app/router"
13	"github.com/xtls/xray-core/common"
14	"github.com/xtls/xray-core/common/errors"
15	"github.com/xtls/xray-core/common/net"
16	"github.com/xtls/xray-core/common/session"
17	"github.com/xtls/xray-core/common/strmatcher"
18	"github.com/xtls/xray-core/features"
19	"github.com/xtls/xray-core/features/dns"
20)
21
22// DNS is a DNS rely server.
23type DNS struct {
24	sync.Mutex
25	tag                    string
26	disableCache           bool
27	disableFallback        bool
28	disableFallbackIfMatch bool
29	ipOption               *dns.IPOption
30	hosts                  *StaticHosts
31	clients                []*Client
32	ctx                    context.Context
33	domainMatcher          strmatcher.IndexMatcher
34	matcherInfos           []*DomainMatcherInfo
35}
36
// DomainMatcherInfo describes what an index returned by DNS.domainMatcher
// refers to: the client owning the matched rule, and the rule's position
// within that client's domains.
type DomainMatcherInfo struct {
	clientIdx     uint16 // index into DNS.clients
	domainRuleIdx uint16 // index into the owning client's domains
}
42
43// New creates a new DNS server with given configuration.
44func New(ctx context.Context, config *Config) (*DNS, error) {
45	var tag string
46	if len(config.Tag) > 0 {
47		tag = config.Tag
48	} else {
49		tag = generateRandomTag()
50	}
51
52	var clientIP net.IP
53	switch len(config.ClientIp) {
54	case 0, net.IPv4len, net.IPv6len:
55		clientIP = net.IP(config.ClientIp)
56	default:
57		return nil, newError("unexpected client IP length ", len(config.ClientIp))
58	}
59
60	var ipOption *dns.IPOption
61	switch config.QueryStrategy {
62	case QueryStrategy_USE_IP:
63		ipOption = &dns.IPOption{
64			IPv4Enable: true,
65			IPv6Enable: true,
66			FakeEnable: false,
67		}
68	case QueryStrategy_USE_IP4:
69		ipOption = &dns.IPOption{
70			IPv4Enable: true,
71			IPv6Enable: false,
72			FakeEnable: false,
73		}
74	case QueryStrategy_USE_IP6:
75		ipOption = &dns.IPOption{
76			IPv4Enable: false,
77			IPv6Enable: true,
78			FakeEnable: false,
79		}
80	}
81
82	hosts, err := NewStaticHosts(config.StaticHosts, config.Hosts)
83	if err != nil {
84		return nil, newError("failed to create hosts").Base(err)
85	}
86
87	clients := []*Client{}
88	domainRuleCount := 0
89	for _, ns := range config.NameServer {
90		domainRuleCount += len(ns.PrioritizedDomain)
91	}
92
93	// MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1
94	matcherInfos := make([]*DomainMatcherInfo, domainRuleCount+1)
95	domainMatcher := &strmatcher.MatcherGroup{}
96	geoipContainer := router.GeoIPMatcherContainer{}
97
98	for _, endpoint := range config.NameServers {
99		features.PrintDeprecatedFeatureWarning("simple DNS server")
100		client, err := NewSimpleClient(ctx, endpoint, clientIP)
101		if err != nil {
102			return nil, newError("failed to create client").Base(err)
103		}
104		clients = append(clients, client)
105	}
106
107	for _, ns := range config.NameServer {
108		clientIdx := len(clients)
109		updateDomain := func(domainRule strmatcher.Matcher, originalRuleIdx int, matcherInfos []*DomainMatcherInfo) error {
110			midx := domainMatcher.Add(domainRule)
111			matcherInfos[midx] = &DomainMatcherInfo{
112				clientIdx:     uint16(clientIdx),
113				domainRuleIdx: uint16(originalRuleIdx),
114			}
115			return nil
116		}
117
118		myClientIP := clientIP
119		switch len(ns.ClientIp) {
120		case net.IPv4len, net.IPv6len:
121			myClientIP = net.IP(ns.ClientIp)
122		}
123		client, err := NewClient(ctx, ns, myClientIP, geoipContainer, &matcherInfos, updateDomain)
124		if err != nil {
125			return nil, newError("failed to create client").Base(err)
126		}
127		clients = append(clients, client)
128	}
129
130	// If there is no DNS client in config, add a `localhost` DNS client
131	if len(clients) == 0 {
132		clients = append(clients, NewLocalDNSClient())
133	}
134
135	return &DNS{
136		tag:                    tag,
137		hosts:                  hosts,
138		ipOption:               ipOption,
139		clients:                clients,
140		ctx:                    ctx,
141		domainMatcher:          domainMatcher,
142		matcherInfos:           matcherInfos,
143		disableCache:           config.DisableCache,
144		disableFallback:        config.DisableFallback,
145		disableFallbackIfMatch: config.DisableFallbackIfMatch,
146	}, nil
147}
148
149// Type implements common.HasType.
150func (*DNS) Type() interface{} {
151	return dns.ClientType()
152}
153
154// Start implements common.Runnable.
155func (s *DNS) Start() error {
156	return nil
157}
158
159// Close implements common.Closable.
160func (s *DNS) Close() error {
161	return nil
162}
163
164// IsOwnLink implements proxy.dns.ownLinkVerifier
165func (s *DNS) IsOwnLink(ctx context.Context) bool {
166	inbound := session.InboundFromContext(ctx)
167	return inbound != nil && inbound.Tag == s.tag
168}
169
170// LookupIP implements dns.Client.
171func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, error) {
172	if domain == "" {
173		return nil, newError("empty domain name")
174	}
175
176	option.IPv4Enable = option.IPv4Enable && s.ipOption.IPv4Enable
177	option.IPv6Enable = option.IPv6Enable && s.ipOption.IPv6Enable
178
179	if !option.IPv4Enable && !option.IPv6Enable {
180		return nil, dns.ErrEmptyResponse
181	}
182
183	// Normalize the FQDN form query
184	if strings.HasSuffix(domain, ".") {
185		domain = domain[:len(domain)-1]
186	}
187
188	// Static host lookup
189	switch addrs := s.hosts.Lookup(domain, option); {
190	case addrs == nil: // Domain not recorded in static host
191		break
192	case len(addrs) == 0: // Domain recorded, but no valid IP returned (e.g. IPv4 address with only IPv6 enabled)
193		return nil, dns.ErrEmptyResponse
194	case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Domain replacement
195		newError("domain replaced: ", domain, " -> ", addrs[0].Domain()).WriteToLog()
196		domain = addrs[0].Domain()
197	default: // Successfully found ip records in static host
198		newError("returning ", len(addrs), " IP(s) for domain ", domain, " -> ", addrs).WriteToLog()
199		return toNetIP(addrs)
200	}
201
202	// Name servers lookup
203	errs := []error{}
204	ctx := session.ContextWithInbound(s.ctx, &session.Inbound{Tag: s.tag})
205	for _, client := range s.sortClients(domain) {
206		if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
207			newError("skip DNS resolution for domain ", domain, " at server ", client.Name()).AtDebug().WriteToLog()
208			continue
209		}
210		ips, err := client.QueryIP(ctx, domain, option, s.disableCache)
211		if len(ips) > 0 {
212			return ips, nil
213		}
214		if err != nil {
215			newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
216			errs = append(errs, err)
217		}
218		if err != context.Canceled && err != context.DeadlineExceeded && err != errExpectedIPNonMatch {
219			return nil, err
220		}
221	}
222
223	return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...))
224}
225
226// GetIPOption implements ClientWithIPOption.
227func (s *DNS) GetIPOption() *dns.IPOption {
228	return s.ipOption
229}
230
231// SetQueryOption implements ClientWithIPOption.
232func (s *DNS) SetQueryOption(isIPv4Enable, isIPv6Enable bool) {
233	s.ipOption.IPv4Enable = isIPv4Enable
234	s.ipOption.IPv6Enable = isIPv6Enable
235}
236
237// SetFakeDNSOption implements ClientWithIPOption.
238func (s *DNS) SetFakeDNSOption(isFakeEnable bool) {
239	s.ipOption.FakeEnable = isFakeEnable
240}
241
242func (s *DNS) sortClients(domain string) []*Client {
243	clients := make([]*Client, 0, len(s.clients))
244	clientUsed := make([]bool, len(s.clients))
245	clientNames := make([]string, 0, len(s.clients))
246	domainRules := []string{}
247
248	// Priority domain matching
249	hasMatch := false
250	for _, match := range s.domainMatcher.Match(domain) {
251		info := s.matcherInfos[match]
252		client := s.clients[info.clientIdx]
253		domainRule := client.domains[info.domainRuleIdx]
254		domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", domainRule, info.clientIdx))
255		if clientUsed[info.clientIdx] {
256			continue
257		}
258		clientUsed[info.clientIdx] = true
259		clients = append(clients, client)
260		clientNames = append(clientNames, client.Name())
261		hasMatch = true
262	}
263
264	if !(s.disableFallback || s.disableFallbackIfMatch && hasMatch) {
265		// Default round-robin query
266		for idx, client := range s.clients {
267			if clientUsed[idx] || client.skipFallback {
268				continue
269			}
270			clientUsed[idx] = true
271			clients = append(clients, client)
272			clientNames = append(clientNames, client.Name())
273		}
274	}
275
276	if len(domainRules) > 0 {
277		newError("domain ", domain, " matches following rules: ", domainRules).AtDebug().WriteToLog()
278	}
279	if len(clientNames) > 0 {
280		newError("domain ", domain, " will use DNS in order: ", clientNames).AtDebug().WriteToLog()
281	}
282
283	if len(clients) == 0 {
284		clients = append(clients, s.clients[0])
285		clientNames = append(clientNames, s.clients[0].Name())
286		newError("domain ", domain, " will use the first DNS: ", clientNames).AtDebug().WriteToLog()
287	}
288
289	return clients
290}
291
292func init() {
293	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
294		return New(ctx, config.(*Config))
295	}))
296}
297