1package dns 2 3import ( 4 "context" 5 "io" 6 "sync" 7 "time" 8 9 "github.com/xtls/xray-core/transport/internet/stat" 10 11 "golang.org/x/net/dns/dnsmessage" 12 13 "github.com/xtls/xray-core/common" 14 "github.com/xtls/xray-core/common/buf" 15 "github.com/xtls/xray-core/common/net" 16 dns_proto "github.com/xtls/xray-core/common/protocol/dns" 17 "github.com/xtls/xray-core/common/session" 18 "github.com/xtls/xray-core/common/signal" 19 "github.com/xtls/xray-core/common/task" 20 "github.com/xtls/xray-core/core" 21 "github.com/xtls/xray-core/features/dns" 22 "github.com/xtls/xray-core/features/policy" 23 "github.com/xtls/xray-core/transport" 24 "github.com/xtls/xray-core/transport/internet" 25) 26 27func init() { 28 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 29 h := new(Handler) 30 if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { 31 return h.Init(config.(*Config), dnsClient, policyManager) 32 }); err != nil { 33 return nil, err 34 } 35 return h, nil 36 })) 37} 38 39type ownLinkVerifier interface { 40 IsOwnLink(ctx context.Context) bool 41} 42 43type Handler struct { 44 client dns.Client 45 ownLinkVerifier ownLinkVerifier 46 server net.Destination 47 timeout time.Duration 48} 49 50func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error { 51 h.client = dnsClient 52 h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle 53 54 if v, ok := dnsClient.(ownLinkVerifier); ok { 55 h.ownLinkVerifier = v 56 } 57 58 if config.Server != nil { 59 h.server = config.Server.AsDestination() 60 } 61 return nil 62} 63 64func (h *Handler) isOwnLink(ctx context.Context) bool { 65 return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) 66} 67 68func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { 69 var parser dnsmessage.Parser 70 header, err := parser.Start(b) 71 if err != nil { 72 newError("parser start").Base(err).WriteToLog() 73 return 74 } 75 76 id = header.ID 77 q, err := parser.Question() 78 if err != nil { 79 newError("question").Base(err).WriteToLog() 80 return 81 } 82 qType = q.Type 83 if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { 84 return 85 } 86 87 domain = q.Name.String() 88 r = true 89 return 90} 91 92// Process implements proxy.Outbound. 93func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { 94 outbound := session.OutboundFromContext(ctx) 95 if outbound == nil || !outbound.Target.IsValid() { 96 return newError("invalid outbound") 97 } 98 99 srcNetwork := outbound.Target.Network 100 101 dest := outbound.Target 102 if h.server.Network != net.Network_Unknown { 103 dest.Network = h.server.Network 104 } 105 if h.server.Address != nil { 106 dest.Address = h.server.Address 107 } 108 if h.server.Port != 0 { 109 dest.Port = h.server.Port 110 } 111 112 newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx)) 113 114 conn := &outboundConn{ 115 dialer: func() (stat.Connection, error) { 116 return d.Dial(ctx, dest) 117 }, 118 connReady: make(chan struct{}, 1), 119 } 120 121 var reader dns_proto.MessageReader 122 var writer dns_proto.MessageWriter 123 if srcNetwork == net.Network_TCP { 124 reader = dns_proto.NewTCPReader(link.Reader) 125 writer = &dns_proto.TCPWriter{ 126 Writer: link.Writer, 127 } 128 } else { 129 reader = &dns_proto.UDPReader{ 130 Reader: link.Reader, 131 } 132 writer = &dns_proto.UDPWriter{ 133 Writer: link.Writer, 134 } 135 } 136 137 var connReader dns_proto.MessageReader 138 var connWriter dns_proto.MessageWriter 139 if dest.Network == net.Network_TCP { 140 connReader = dns_proto.NewTCPReader(buf.NewReader(conn)) 141 connWriter = &dns_proto.TCPWriter{ 142 Writer: buf.NewWriter(conn), 143 } 144 } else { 145 connReader = &dns_proto.UDPReader{ 146 Reader: buf.NewPacketReader(conn), 147 } 148 connWriter = &dns_proto.UDPWriter{ 149 Writer: buf.NewWriter(conn), 150 } 151 } 152 153 ctx, cancel := context.WithCancel(ctx) 154 timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) 155 156 request := func() error { 157 defer conn.Close() 158 159 for { 160 b, err := reader.ReadMessage() 161 if err == io.EOF { 162 return nil 163 } 164 165 if err != nil { 166 return err 167 } 168 169 timer.Update() 170 171 if !h.isOwnLink(ctx) { 172 isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) 173 if isIPQuery { 174 go h.handleIPQuery(id, qType, domain, writer) 175 continue 176 } 177 } 178 179 if err := connWriter.WriteMessage(b); err != nil { 180 return err 181 } 182 } 183 } 184 185 response := func() error { 186 for { 187 b, err := connReader.ReadMessage() 188 if err == io.EOF { 189 return nil 190 } 191 192 if err != nil { 193 return err 194 } 195 196 timer.Update() 197 198 if err := writer.WriteMessage(b); err != nil { 199 return err 200 } 201 } 202 } 203 204 if err := task.Run(ctx, request, response); err != nil { 205 return newError("connection ends").Base(err) 206 } 207 208 return nil 209} 210 211func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { 212 var ips []net.IP 213 var err error 214 215 var ttl uint32 = 600 216 217 switch qType { 218 case dnsmessage.TypeA: 219 ips, err = h.client.LookupIP(domain, dns.IPOption{ 220 IPv4Enable: true, 221 IPv6Enable: false, 222 FakeEnable: true, 223 }) 224 case dnsmessage.TypeAAAA: 225 ips, err = h.client.LookupIP(domain, dns.IPOption{ 226 IPv4Enable: false, 227 IPv6Enable: true, 228 FakeEnable: true, 229 }) 230 } 231 232 rcode := dns.RCodeFromError(err) 233 if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse { 234 newError("ip query").Base(err).WriteToLog() 235 return 236 } 237 238 switch qType { 239 case dnsmessage.TypeA: 240 for i, ip := range ips { 241 ips[i] = ip.To4() 242 } 243 case dnsmessage.TypeAAAA: 244 for i, ip := range ips { 245 ips[i] = ip.To16() 246 } 247 } 248 249 b := buf.New() 250 rawBytes := b.Extend(buf.Size) 251 builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ 252 ID: id, 253 RCode: dnsmessage.RCode(rcode), 254 RecursionAvailable: true, 255 RecursionDesired: true, 256 Response: true, 257 Authoritative: true, 258 }) 259 builder.EnableCompression() 260 common.Must(builder.StartQuestions()) 261 common.Must(builder.Question(dnsmessage.Question{ 262 Name: dnsmessage.MustNewName(domain), 263 Class: dnsmessage.ClassINET, 264 Type: qType, 265 })) 266 common.Must(builder.StartAnswers()) 267 268 rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl} 269 for _, ip := range ips { 270 if len(ip) == net.IPv4len { 271 var r dnsmessage.AResource 272 copy(r.A[:], ip) 273 common.Must(builder.AResource(rHeader, r)) 274 } else { 275 var r dnsmessage.AAAAResource 276 copy(r.AAAA[:], ip) 277 common.Must(builder.AAAAResource(rHeader, r)) 278 } 279 } 280 msgBytes, err := builder.Finish() 281 if err != nil { 282 newError("pack message").Base(err).WriteToLog() 283 b.Release() 284 return 285 } 286 b.Resize(0, int32(len(msgBytes))) 287 288 if err := writer.WriteMessage(b); err != nil { 289 newError("write IP answer").Base(err).WriteToLog() 290 } 291} 292 293type outboundConn struct { 294 access sync.Mutex 295 dialer func() (stat.Connection, error) 296 297 conn net.Conn 298 connReady chan struct{} 299} 300 301func (c *outboundConn) dial() error { 302 conn, err := c.dialer() 303 if err != nil { 304 return err 305 } 306 c.conn = conn 307 c.connReady <- struct{}{} 308 return nil 309} 310 311func (c *outboundConn) Write(b []byte) (int, error) { 312 c.access.Lock() 313 314 if c.conn == nil { 315 if err := c.dial(); err != nil { 316 c.access.Unlock() 317 newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog() 318 return len(b), nil 319 } 320 } 321 322 c.access.Unlock() 323 324 return c.conn.Write(b) 325} 326 327func (c *outboundConn) Read(b []byte) (int, error) { 328 var conn net.Conn 329 c.access.Lock() 330 conn = c.conn 331 c.access.Unlock() 332 333 if conn == nil { 334 _, open := <-c.connReady 335 if !open { 336 return 0, io.EOF 337 } 338 conn = c.conn 339 } 340 341 return conn.Read(b) 342} 343 344func (c *outboundConn) Close() error { 345 c.access.Lock() 346 close(c.connReady) 347 if c.conn != nil { 348 c.conn.Close() 349 } 350 c.access.Unlock() 351 return nil 352} 353