1package dns 2 3// A client implementation. 4 5import ( 6 "context" 7 "crypto/tls" 8 "encoding/binary" 9 "fmt" 10 "io" 11 "net" 12 "strings" 13 "time" 14) 15 16const ( 17 dnsTimeout time.Duration = 2 * time.Second 18 tcpIdleTimeout time.Duration = 8 * time.Second 19) 20 21// A Conn represents a connection to a DNS server. 22type Conn struct { 23 net.Conn // a net.Conn holding the connection 24 UDPSize uint16 // minimum receive buffer for UDP messages 25 TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) 26 tsigRequestMAC string 27} 28 29// A Client defines parameters for a DNS client. 30type Client struct { 31 Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) 32 UDPSize uint16 // minimum receive buffer for UDP messages 33 TLSConfig *tls.Config // TLS connection configuration 34 Dialer *net.Dialer // a net.Dialer used to set local address, timeouts and more 35 // Timeout is a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout, 36 // WriteTimeout when non-zero. Can be overridden with net.Dialer.Timeout (see Client.ExchangeWithDialer and 37 // Client.Dialer) or context.Context.Deadline (see the deprecated ExchangeContext) 38 Timeout time.Duration 39 DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero 40 ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero 41 WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero 42 TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) 43 SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass 44 group singleflight 45} 46 47// Exchange performs a synchronous UDP query. It sends the message m to the address 48// contained in a and waits for a reply. Exchange does not retry a failed query, nor 49// will it fall back to TCP in case of truncation. 50// See client.Exchange for more information on setting larger buffer sizes. 51func Exchange(m *Msg, a string) (r *Msg, err error) { 52 client := Client{Net: "udp"} 53 r, _, err = client.Exchange(m, a) 54 return r, err 55} 56 57func (c *Client) dialTimeout() time.Duration { 58 if c.Timeout != 0 { 59 return c.Timeout 60 } 61 if c.DialTimeout != 0 { 62 return c.DialTimeout 63 } 64 return dnsTimeout 65} 66 67func (c *Client) readTimeout() time.Duration { 68 if c.ReadTimeout != 0 { 69 return c.ReadTimeout 70 } 71 return dnsTimeout 72} 73 74func (c *Client) writeTimeout() time.Duration { 75 if c.WriteTimeout != 0 { 76 return c.WriteTimeout 77 } 78 return dnsTimeout 79} 80 81// Dial connects to the address on the named network. 82func (c *Client) Dial(address string) (conn *Conn, err error) { 83 // create a new dialer with the appropriate timeout 84 var d net.Dialer 85 if c.Dialer == nil { 86 d = net.Dialer{Timeout: c.getTimeoutForRequest(c.dialTimeout())} 87 } else { 88 d = *c.Dialer 89 } 90 91 network := c.Net 92 if network == "" { 93 network = "udp" 94 } 95 96 useTLS := strings.HasPrefix(network, "tcp") && strings.HasSuffix(network, "-tls") 97 98 conn = new(Conn) 99 if useTLS { 100 network = strings.TrimSuffix(network, "-tls") 101 102 conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig) 103 } else { 104 conn.Conn, err = d.Dial(network, address) 105 } 106 if err != nil { 107 return nil, err 108 } 109 110 return conn, nil 111} 112 113// Exchange performs a synchronous query. It sends the message m to the address 114// contained in a and waits for a reply. Basic use pattern with a *dns.Client: 115// 116// c := new(dns.Client) 117// in, rtt, err := c.Exchange(message, "127.0.0.1:53") 118// 119// Exchange does not retry a failed query, nor will it fall back to TCP in 120// case of truncation. 121// It is up to the caller to create a message that allows for larger responses to be 122// returned. Specifically this means adding an EDNS0 OPT RR that will advertise a larger 123// buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit 124// of 512 bytes 125// To specify a local address or a timeout, the caller has to set the `Client.Dialer` 126// attribute appropriately 127func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) { 128 if !c.SingleInflight { 129 return c.exchange(m, address) 130 } 131 132 q := m.Question[0] 133 key := fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass) 134 r, rtt, err, shared := c.group.Do(key, func() (*Msg, time.Duration, error) { 135 return c.exchange(m, address) 136 }) 137 if r != nil && shared { 138 r = r.Copy() 139 } 140 141 return r, rtt, err 142} 143 144func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { 145 var co *Conn 146 147 co, err = c.Dial(a) 148 149 if err != nil { 150 return nil, 0, err 151 } 152 defer co.Close() 153 154 opt := m.IsEdns0() 155 // If EDNS0 is used use that for size. 156 if opt != nil && opt.UDPSize() >= MinMsgSize { 157 co.UDPSize = opt.UDPSize() 158 } 159 // Otherwise use the client's configured UDP size. 160 if opt == nil && c.UDPSize >= MinMsgSize { 161 co.UDPSize = c.UDPSize 162 } 163 164 co.TsigSecret = c.TsigSecret 165 t := time.Now() 166 // write with the appropriate write timeout 167 co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout()))) 168 if err = co.WriteMsg(m); err != nil { 169 return nil, 0, err 170 } 171 172 co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout()))) 173 r, err = co.ReadMsg() 174 if err == nil && r.Id != m.Id { 175 err = ErrId 176 } 177 rtt = time.Since(t) 178 return r, rtt, err 179} 180 181// ReadMsg reads a message from the connection co. 182// If the received message contains a TSIG record the transaction signature 183// is verified. This method always tries to return the message, however if an 184// error is returned there are no guarantees that the returned message is a 185// valid representation of the packet read. 186func (co *Conn) ReadMsg() (*Msg, error) { 187 p, err := co.ReadMsgHeader(nil) 188 if err != nil { 189 return nil, err 190 } 191 192 m := new(Msg) 193 if err := m.Unpack(p); err != nil { 194 // If an error was returned, we still want to allow the user to use 195 // the message, but naively they can just check err if they don't want 196 // to use an erroneous message 197 return m, err 198 } 199 if t := m.IsTsig(); t != nil { 200 if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { 201 return m, ErrSecret 202 } 203 // Need to work on the original message p, as that was used to calculate the tsig. 204 err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) 205 } 206 return m, err 207} 208 209// ReadMsgHeader reads a DNS message, parses and populates hdr (when hdr is not nil). 210// Returns message as a byte slice to be parsed with Msg.Unpack later on. 211// Note that error handling on the message body is not possible as only the header is parsed. 212func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { 213 var ( 214 p []byte 215 n int 216 err error 217 ) 218 219 if _, ok := co.Conn.(net.PacketConn); ok { 220 if co.UDPSize > MinMsgSize { 221 p = make([]byte, co.UDPSize) 222 } else { 223 p = make([]byte, MinMsgSize) 224 } 225 n, err = co.Read(p) 226 } else { 227 var length uint16 228 if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { 229 return nil, err 230 } 231 232 p = make([]byte, length) 233 n, err = io.ReadFull(co.Conn, p) 234 } 235 236 if err != nil { 237 return nil, err 238 } else if n < headerSize { 239 return nil, ErrShortRead 240 } 241 242 p = p[:n] 243 if hdr != nil { 244 dh, _, err := unpackMsgHdr(p, 0) 245 if err != nil { 246 return nil, err 247 } 248 *hdr = dh 249 } 250 return p, err 251} 252 253// Read implements the net.Conn read method. 254func (co *Conn) Read(p []byte) (n int, err error) { 255 if co.Conn == nil { 256 return 0, ErrConnEmpty 257 } 258 259 if _, ok := co.Conn.(net.PacketConn); ok { 260 // UDP connection 261 return co.Conn.Read(p) 262 } 263 264 var length uint16 265 if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { 266 return 0, err 267 } 268 if int(length) > len(p) { 269 return 0, io.ErrShortBuffer 270 } 271 272 return io.ReadFull(co.Conn, p[:length]) 273} 274 275// WriteMsg sends a message through the connection co. 276// If the message m contains a TSIG record the transaction 277// signature is calculated. 278func (co *Conn) WriteMsg(m *Msg) (err error) { 279 var out []byte 280 if t := m.IsTsig(); t != nil { 281 mac := "" 282 if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { 283 return ErrSecret 284 } 285 out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) 286 // Set for the next read, although only used in zone transfers 287 co.tsigRequestMAC = mac 288 } else { 289 out, err = m.Pack() 290 } 291 if err != nil { 292 return err 293 } 294 _, err = co.Write(out) 295 return err 296} 297 298// Write implements the net.Conn Write method. 299func (co *Conn) Write(p []byte) (int, error) { 300 if len(p) > MaxMsgSize { 301 return 0, &Error{err: "message too large"} 302 } 303 304 if _, ok := co.Conn.(net.PacketConn); ok { 305 return co.Conn.Write(p) 306 } 307 308 l := make([]byte, 2) 309 binary.BigEndian.PutUint16(l, uint16(len(p))) 310 311 n, err := (&net.Buffers{l, p}).WriteTo(co.Conn) 312 return int(n), err 313} 314 315// Return the appropriate timeout for a specific request 316func (c *Client) getTimeoutForRequest(timeout time.Duration) time.Duration { 317 var requestTimeout time.Duration 318 if c.Timeout != 0 { 319 requestTimeout = c.Timeout 320 } else { 321 requestTimeout = timeout 322 } 323 // net.Dialer.Timeout has priority if smaller than the timeouts computed so 324 // far 325 if c.Dialer != nil && c.Dialer.Timeout != 0 { 326 if c.Dialer.Timeout < requestTimeout { 327 requestTimeout = c.Dialer.Timeout 328 } 329 } 330 return requestTimeout 331} 332 333// Dial connects to the address on the named network. 334func Dial(network, address string) (conn *Conn, err error) { 335 conn = new(Conn) 336 conn.Conn, err = net.Dial(network, address) 337 if err != nil { 338 return nil, err 339 } 340 return conn, nil 341} 342 343// ExchangeContext performs a synchronous UDP query, like Exchange. It 344// additionally obeys deadlines from the passed Context. 345func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) { 346 client := Client{Net: "udp"} 347 r, _, err = client.ExchangeContext(ctx, m, a) 348 // ignorint rtt to leave the original ExchangeContext API unchanged, but 349 // this function will go away 350 return r, err 351} 352 353// ExchangeConn performs a synchronous query. It sends the message m via the connection 354// c and waits for a reply. The connection c is not closed by ExchangeConn. 355// Deprecated: This function is going away, but can easily be mimicked: 356// 357// co := &dns.Conn{Conn: c} // c is your net.Conn 358// co.WriteMsg(m) 359// in, _ := co.ReadMsg() 360// co.Close() 361// 362func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { 363 println("dns: ExchangeConn: this function is deprecated") 364 co := new(Conn) 365 co.Conn = c 366 if err = co.WriteMsg(m); err != nil { 367 return nil, err 368 } 369 r, err = co.ReadMsg() 370 if err == nil && r.Id != m.Id { 371 err = ErrId 372 } 373 return r, err 374} 375 376// DialTimeout acts like Dial but takes a timeout. 377func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) { 378 client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}} 379 return client.Dial(address) 380} 381 382// DialWithTLS connects to the address on the named network with TLS. 383func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, err error) { 384 if !strings.HasSuffix(network, "-tls") { 385 network += "-tls" 386 } 387 client := Client{Net: network, TLSConfig: tlsConfig} 388 return client.Dial(address) 389} 390 391// DialTimeoutWithTLS acts like DialWithTLS but takes a timeout. 392func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout time.Duration) (conn *Conn, err error) { 393 if !strings.HasSuffix(network, "-tls") { 394 network += "-tls" 395 } 396 client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig} 397 return client.Dial(address) 398} 399 400// ExchangeContext acts like Exchange, but honors the deadline on the provided 401// context, if present. If there is both a context deadline and a configured 402// timeout on the client, the earliest of the two takes effect. 403func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) { 404 var timeout time.Duration 405 if deadline, ok := ctx.Deadline(); !ok { 406 timeout = 0 407 } else { 408 timeout = time.Until(deadline) 409 } 410 // not passing the context to the underlying calls, as the API does not support 411 // context. For timeouts you should set up Client.Dialer and call Client.Exchange. 412 // TODO(tmthrgd,miekg): this is a race condition. 413 c.Dialer = &net.Dialer{Timeout: timeout} 414 return c.Exchange(m, a) 415} 416