1package dns 2 3// A client implementation. 4 5import ( 6 "bytes" 7 "context" 8 "crypto/tls" 9 "encoding/binary" 10 "fmt" 11 "io" 12 "io/ioutil" 13 "net" 14 "net/http" 15 "strings" 16 "time" 17) 18 19const ( 20 dnsTimeout time.Duration = 2 * time.Second 21 tcpIdleTimeout time.Duration = 8 * time.Second 22 23 dohMimeType = "application/dns-message" 24) 25 26// A Conn represents a connection to a DNS server. 27type Conn struct { 28 net.Conn // a net.Conn holding the connection 29 UDPSize uint16 // minimum receive buffer for UDP messages 30 TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) 31 tsigRequestMAC string 32} 33 34// A Client defines parameters for a DNS client. 35type Client struct { 36 Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) 37 UDPSize uint16 // minimum receive buffer for UDP messages 38 TLSConfig *tls.Config // TLS connection configuration 39 Dialer *net.Dialer // a net.Dialer used to set local address, timeouts and more 40 // Timeout is a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout, 41 // WriteTimeout when non-zero. Can be overridden with net.Dialer.Timeout (see Client.ExchangeWithDialer and 42 // Client.Dialer) or context.Context.Deadline (see the deprecated ExchangeContext) 43 Timeout time.Duration 44 DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero 45 ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero 46 WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero 47 HTTPClient *http.Client // The http.Client to use for DNS-over-HTTPS 48 TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) 49 SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass 50 group singleflight 51} 52 53// Exchange performs a synchronous UDP query. It sends the message m to the address 54// contained in a and waits for a reply. Exchange does not retry a failed query, nor 55// will it fall back to TCP in case of truncation. 56// See client.Exchange for more information on setting larger buffer sizes. 57func Exchange(m *Msg, a string) (r *Msg, err error) { 58 client := Client{Net: "udp"} 59 r, _, err = client.Exchange(m, a) 60 return r, err 61} 62 63func (c *Client) dialTimeout() time.Duration { 64 if c.Timeout != 0 { 65 return c.Timeout 66 } 67 if c.DialTimeout != 0 { 68 return c.DialTimeout 69 } 70 return dnsTimeout 71} 72 73func (c *Client) readTimeout() time.Duration { 74 if c.ReadTimeout != 0 { 75 return c.ReadTimeout 76 } 77 return dnsTimeout 78} 79 80func (c *Client) writeTimeout() time.Duration { 81 if c.WriteTimeout != 0 { 82 return c.WriteTimeout 83 } 84 return dnsTimeout 85} 86 87// Dial connects to the address on the named network. 88func (c *Client) Dial(address string) (conn *Conn, err error) { 89 // create a new dialer with the appropriate timeout 90 var d net.Dialer 91 if c.Dialer == nil { 92 d = net.Dialer{Timeout: c.getTimeoutForRequest(c.dialTimeout())} 93 } else { 94 d = *c.Dialer 95 } 96 97 network := c.Net 98 if network == "" { 99 network = "udp" 100 } 101 102 useTLS := strings.HasPrefix(network, "tcp") && strings.HasSuffix(network, "-tls") 103 104 conn = new(Conn) 105 if useTLS { 106 network = strings.TrimSuffix(network, "-tls") 107 108 conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig) 109 } else { 110 conn.Conn, err = d.Dial(network, address) 111 } 112 if err != nil { 113 return nil, err 114 } 115 116 return conn, nil 117} 118 119// Exchange performs a synchronous query. It sends the message m to the address 120// contained in a and waits for a reply. Basic use pattern with a *dns.Client: 121// 122// c := new(dns.Client) 123// in, rtt, err := c.Exchange(message, "127.0.0.1:53") 124// 125// Exchange does not retry a failed query, nor will it fall back to TCP in 126// case of truncation. 127// It is up to the caller to create a message that allows for larger responses to be 128// returned. Specifically this means adding an EDNS0 OPT RR that will advertise a larger 129// buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit 130// of 512 bytes 131// To specify a local address or a timeout, the caller has to set the `Client.Dialer` 132// attribute appropriately 133func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) { 134 if !c.SingleInflight { 135 if c.Net == "https" { 136 // TODO(tmthrgd): pipe timeouts into exchangeDOH 137 return c.exchangeDOH(context.TODO(), m, address) 138 } 139 140 return c.exchange(m, address) 141 } 142 143 t := "nop" 144 if t1, ok := TypeToString[m.Question[0].Qtype]; ok { 145 t = t1 146 } 147 cl := "nop" 148 if cl1, ok := ClassToString[m.Question[0].Qclass]; ok { 149 cl = cl1 150 } 151 r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { 152 if c.Net == "https" { 153 // TODO(tmthrgd): pipe timeouts into exchangeDOH 154 return c.exchangeDOH(context.TODO(), m, address) 155 } 156 157 return c.exchange(m, address) 158 }) 159 if r != nil && shared { 160 r = r.Copy() 161 } 162 return r, rtt, err 163} 164 165func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { 166 var co *Conn 167 168 co, err = c.Dial(a) 169 170 if err != nil { 171 return nil, 0, err 172 } 173 defer co.Close() 174 175 opt := m.IsEdns0() 176 // If EDNS0 is used use that for size. 177 if opt != nil && opt.UDPSize() >= MinMsgSize { 178 co.UDPSize = opt.UDPSize() 179 } 180 // Otherwise use the client's configured UDP size. 181 if opt == nil && c.UDPSize >= MinMsgSize { 182 co.UDPSize = c.UDPSize 183 } 184 185 co.TsigSecret = c.TsigSecret 186 t := time.Now() 187 // write with the appropriate write timeout 188 co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout()))) 189 if err = co.WriteMsg(m); err != nil { 190 return nil, 0, err 191 } 192 193 co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout()))) 194 r, err = co.ReadMsg() 195 if err == nil && r.Id != m.Id { 196 err = ErrId 197 } 198 rtt = time.Since(t) 199 return r, rtt, err 200} 201 202func (c *Client) exchangeDOH(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) { 203 p, err := m.Pack() 204 if err != nil { 205 return nil, 0, err 206 } 207 208 req, err := http.NewRequest(http.MethodPost, a, bytes.NewReader(p)) 209 if err != nil { 210 return nil, 0, err 211 } 212 213 req.Header.Set("Content-Type", dohMimeType) 214 req.Header.Set("Accept", dohMimeType) 215 216 hc := http.DefaultClient 217 if c.HTTPClient != nil { 218 hc = c.HTTPClient 219 } 220 221 if ctx != context.Background() && ctx != context.TODO() { 222 req = req.WithContext(ctx) 223 } 224 225 t := time.Now() 226 227 resp, err := hc.Do(req) 228 if err != nil { 229 return nil, 0, err 230 } 231 defer closeHTTPBody(resp.Body) 232 233 if resp.StatusCode != http.StatusOK { 234 return nil, 0, fmt.Errorf("dns: server returned HTTP %d error: %q", resp.StatusCode, resp.Status) 235 } 236 237 if ct := resp.Header.Get("Content-Type"); ct != dohMimeType { 238 return nil, 0, fmt.Errorf("dns: unexpected Content-Type %q; expected %q", ct, dohMimeType) 239 } 240 241 p, err = ioutil.ReadAll(resp.Body) 242 if err != nil { 243 return nil, 0, err 244 } 245 246 rtt = time.Since(t) 247 248 r = new(Msg) 249 if err := r.Unpack(p); err != nil { 250 return r, 0, err 251 } 252 253 // TODO: TSIG? Is it even supported over DoH? 254 255 return r, rtt, nil 256} 257 258func closeHTTPBody(r io.ReadCloser) error { 259 io.Copy(ioutil.Discard, io.LimitReader(r, 8<<20)) 260 return r.Close() 261} 262 263// ReadMsg reads a message from the connection co. 264// If the received message contains a TSIG record the transaction signature 265// is verified. This method always tries to return the message, however if an 266// error is returned there are no guarantees that the returned message is a 267// valid representation of the packet read. 268func (co *Conn) ReadMsg() (*Msg, error) { 269 p, err := co.ReadMsgHeader(nil) 270 if err != nil { 271 return nil, err 272 } 273 274 m := new(Msg) 275 if err := m.Unpack(p); err != nil { 276 // If an error was returned, we still want to allow the user to use 277 // the message, but naively they can just check err if they don't want 278 // to use an erroneous message 279 return m, err 280 } 281 if t := m.IsTsig(); t != nil { 282 if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { 283 return m, ErrSecret 284 } 285 // Need to work on the original message p, as that was used to calculate the tsig. 286 err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) 287 } 288 return m, err 289} 290 291// ReadMsgHeader reads a DNS message, parses and populates hdr (when hdr is not nil). 292// Returns message as a byte slice to be parsed with Msg.Unpack later on. 293// Note that error handling on the message body is not possible as only the header is parsed. 294func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { 295 var ( 296 p []byte 297 n int 298 err error 299 ) 300 301 switch t := co.Conn.(type) { 302 case *net.TCPConn, *tls.Conn: 303 r := t.(io.Reader) 304 305 // First two bytes specify the length of the entire message. 306 l, err := tcpMsgLen(r) 307 if err != nil { 308 return nil, err 309 } 310 p = make([]byte, l) 311 n, err = tcpRead(r, p) 312 default: 313 if co.UDPSize > MinMsgSize { 314 p = make([]byte, co.UDPSize) 315 } else { 316 p = make([]byte, MinMsgSize) 317 } 318 n, err = co.Read(p) 319 } 320 321 if err != nil { 322 return nil, err 323 } else if n < headerSize { 324 return nil, ErrShortRead 325 } 326 327 p = p[:n] 328 if hdr != nil { 329 dh, _, err := unpackMsgHdr(p, 0) 330 if err != nil { 331 return nil, err 332 } 333 *hdr = dh 334 } 335 return p, err 336} 337 338// tcpMsgLen is a helper func to read first two bytes of stream as uint16 packet length. 339func tcpMsgLen(t io.Reader) (int, error) { 340 p := []byte{0, 0} 341 n, err := t.Read(p) 342 if err != nil { 343 return 0, err 344 } 345 346 // As seen with my local router/switch, returns 1 byte on the above read, 347 // resulting a a ShortRead. Just write it out (instead of loop) and read the 348 // other byte. 349 if n == 1 { 350 n1, err := t.Read(p[1:]) 351 if err != nil { 352 return 0, err 353 } 354 n += n1 355 } 356 357 if n != 2 { 358 return 0, ErrShortRead 359 } 360 l := binary.BigEndian.Uint16(p) 361 if l == 0 { 362 return 0, ErrShortRead 363 } 364 return int(l), nil 365} 366 367// tcpRead calls TCPConn.Read enough times to fill allocated buffer. 368func tcpRead(t io.Reader, p []byte) (int, error) { 369 n, err := t.Read(p) 370 if err != nil { 371 return n, err 372 } 373 for n < len(p) { 374 j, err := t.Read(p[n:]) 375 if err != nil { 376 return n, err 377 } 378 n += j 379 } 380 return n, err 381} 382 383// Read implements the net.Conn read method. 384func (co *Conn) Read(p []byte) (n int, err error) { 385 if co.Conn == nil { 386 return 0, ErrConnEmpty 387 } 388 if len(p) < 2 { 389 return 0, io.ErrShortBuffer 390 } 391 switch t := co.Conn.(type) { 392 case *net.TCPConn, *tls.Conn: 393 r := t.(io.Reader) 394 395 l, err := tcpMsgLen(r) 396 if err != nil { 397 return 0, err 398 } 399 if l > len(p) { 400 return int(l), io.ErrShortBuffer 401 } 402 return tcpRead(r, p[:l]) 403 } 404 // UDP connection 405 n, err = co.Conn.Read(p) 406 if err != nil { 407 return n, err 408 } 409 return n, err 410} 411 412// WriteMsg sends a message through the connection co. 413// If the message m contains a TSIG record the transaction 414// signature is calculated. 415func (co *Conn) WriteMsg(m *Msg) (err error) { 416 var out []byte 417 if t := m.IsTsig(); t != nil { 418 mac := "" 419 if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { 420 return ErrSecret 421 } 422 out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) 423 // Set for the next read, although only used in zone transfers 424 co.tsigRequestMAC = mac 425 } else { 426 out, err = m.Pack() 427 } 428 if err != nil { 429 return err 430 } 431 if _, err = co.Write(out); err != nil { 432 return err 433 } 434 return nil 435} 436 437// Write implements the net.Conn Write method. 438func (co *Conn) Write(p []byte) (n int, err error) { 439 switch t := co.Conn.(type) { 440 case *net.TCPConn, *tls.Conn: 441 w := t.(io.Writer) 442 443 lp := len(p) 444 if lp < 2 { 445 return 0, io.ErrShortBuffer 446 } 447 if lp > MaxMsgSize { 448 return 0, &Error{err: "message too large"} 449 } 450 l := make([]byte, 2, lp+2) 451 binary.BigEndian.PutUint16(l, uint16(lp)) 452 p = append(l, p...) 453 n, err := io.Copy(w, bytes.NewReader(p)) 454 return int(n), err 455 } 456 n, err = co.Conn.Write(p) 457 return n, err 458} 459 460// Return the appropriate timeout for a specific request 461func (c *Client) getTimeoutForRequest(timeout time.Duration) time.Duration { 462 var requestTimeout time.Duration 463 if c.Timeout != 0 { 464 requestTimeout = c.Timeout 465 } else { 466 requestTimeout = timeout 467 } 468 // net.Dialer.Timeout has priority if smaller than the timeouts computed so 469 // far 470 if c.Dialer != nil && c.Dialer.Timeout != 0 { 471 if c.Dialer.Timeout < requestTimeout { 472 requestTimeout = c.Dialer.Timeout 473 } 474 } 475 return requestTimeout 476} 477 478// Dial connects to the address on the named network. 479func Dial(network, address string) (conn *Conn, err error) { 480 conn = new(Conn) 481 conn.Conn, err = net.Dial(network, address) 482 if err != nil { 483 return nil, err 484 } 485 return conn, nil 486} 487 488// ExchangeContext performs a synchronous UDP query, like Exchange. It 489// additionally obeys deadlines from the passed Context. 490func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) { 491 client := Client{Net: "udp"} 492 r, _, err = client.ExchangeContext(ctx, m, a) 493 // ignorint rtt to leave the original ExchangeContext API unchanged, but 494 // this function will go away 495 return r, err 496} 497 498// ExchangeConn performs a synchronous query. It sends the message m via the connection 499// c and waits for a reply. The connection c is not closed by ExchangeConn. 500// This function is going away, but can easily be mimicked: 501// 502// co := &dns.Conn{Conn: c} // c is your net.Conn 503// co.WriteMsg(m) 504// in, _ := co.ReadMsg() 505// co.Close() 506// 507func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { 508 println("dns: ExchangeConn: this function is deprecated") 509 co := new(Conn) 510 co.Conn = c 511 if err = co.WriteMsg(m); err != nil { 512 return nil, err 513 } 514 r, err = co.ReadMsg() 515 if err == nil && r.Id != m.Id { 516 err = ErrId 517 } 518 return r, err 519} 520 521// DialTimeout acts like Dial but takes a timeout. 522func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) { 523 client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}} 524 conn, err = client.Dial(address) 525 if err != nil { 526 return nil, err 527 } 528 return conn, nil 529} 530 531// DialWithTLS connects to the address on the named network with TLS. 532func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, err error) { 533 if !strings.HasSuffix(network, "-tls") { 534 network += "-tls" 535 } 536 client := Client{Net: network, TLSConfig: tlsConfig} 537 conn, err = client.Dial(address) 538 539 if err != nil { 540 return nil, err 541 } 542 return conn, nil 543} 544 545// DialTimeoutWithTLS acts like DialWithTLS but takes a timeout. 546func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout time.Duration) (conn *Conn, err error) { 547 if !strings.HasSuffix(network, "-tls") { 548 network += "-tls" 549 } 550 client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig} 551 conn, err = client.Dial(address) 552 if err != nil { 553 return nil, err 554 } 555 return conn, nil 556} 557 558// ExchangeContext acts like Exchange, but honors the deadline on the provided 559// context, if present. If there is both a context deadline and a configured 560// timeout on the client, the earliest of the two takes effect. 561func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) { 562 if !c.SingleInflight && c.Net == "https" { 563 return c.exchangeDOH(ctx, m, a) 564 } 565 566 var timeout time.Duration 567 if deadline, ok := ctx.Deadline(); !ok { 568 timeout = 0 569 } else { 570 timeout = time.Until(deadline) 571 } 572 // not passing the context to the underlying calls, as the API does not support 573 // context. For timeouts you should set up Client.Dialer and call Client.Exchange. 574 // TODO(tmthrgd): this is a race condition 575 c.Dialer = &net.Dialer{Timeout: timeout} 576 return c.Exchange(m, a) 577} 578