1package dns 2 3// A client implementation. 4 5import ( 6 "bytes" 7 "crypto/tls" 8 "encoding/binary" 9 "io" 10 "net" 11 "time" 12) 13 14const dnsTimeout time.Duration = 2 * time.Second 15const tcpIdleTimeout time.Duration = 8 * time.Second 16 17// A Conn represents a connection to a DNS server. 18type Conn struct { 19 net.Conn // a net.Conn holding the connection 20 UDPSize uint16 // minimum receive buffer for UDP messages 21 TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified 22 rtt time.Duration 23 t time.Time 24 tsigRequestMAC string 25} 26 27// A Client defines parameters for a DNS client. 28type Client struct { 29 Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) 30 UDPSize uint16 // minimum receive buffer for UDP messages 31 TLSConfig *tls.Config // TLS connection configuration 32 Timeout time.Duration // a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout and WriteTimeout when non-zero 33 DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds - overridden by Timeout when that value is non-zero 34 ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero 35 WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero 36 TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified 37 SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass 38 group singleflight 39} 40 41// Exchange performs a synchronous UDP query. It sends the message m to the address 42// contained in a and waits for a reply. Exchange does not retry a failed query, nor 43// will it fall back to TCP in case of truncation. 44// See client.Exchange for more information on setting larger buffer sizes. 45func Exchange(m *Msg, a string) (r *Msg, err error) { 46 var co *Conn 47 co, err = DialTimeout("udp", a, dnsTimeout) 48 if err != nil { 49 return nil, err 50 } 51 52 defer co.Close() 53 54 opt := m.IsEdns0() 55 // If EDNS0 is used use that for size. 56 if opt != nil && opt.UDPSize() >= MinMsgSize { 57 co.UDPSize = opt.UDPSize() 58 } 59 60 co.SetWriteDeadline(time.Now().Add(dnsTimeout)) 61 if err = co.WriteMsg(m); err != nil { 62 return nil, err 63 } 64 65 co.SetReadDeadline(time.Now().Add(dnsTimeout)) 66 r, err = co.ReadMsg() 67 if err == nil && r.Id != m.Id { 68 err = ErrId 69 } 70 return r, err 71} 72 73// ExchangeConn performs a synchronous query. It sends the message m via the connection 74// c and waits for a reply. The connection c is not closed by ExchangeConn. 75// This function is going away, but can easily be mimicked: 76// 77// co := &dns.Conn{Conn: c} // c is your net.Conn 78// co.WriteMsg(m) 79// in, _ := co.ReadMsg() 80// co.Close() 81// 82func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { 83 println("dns: this function is deprecated") 84 co := new(Conn) 85 co.Conn = c 86 if err = co.WriteMsg(m); err != nil { 87 return nil, err 88 } 89 r, err = co.ReadMsg() 90 if err == nil && r.Id != m.Id { 91 err = ErrId 92 } 93 return r, err 94} 95 96// Exchange performs a synchronous query. It sends the message m to the address 97// contained in a and waits for a reply. Basic use pattern with a *dns.Client: 98// 99// c := new(dns.Client) 100// in, rtt, err := c.Exchange(message, "127.0.0.1:53") 101// 102// Exchange does not retry a failed query, nor will it fall back to TCP in 103// case of truncation. 104// It is up to the caller to create a message that allows for larger responses to be 105// returned. Specifically this means adding an EDNS0 OPT RR that will advertise a larger 106// buffer, see SetEdns0. Messsages without an OPT RR will fallback to the historic limit 107// of 512 bytes. 108func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { 109 if !c.SingleInflight { 110 return c.exchange(m, a) 111 } 112 // This adds a bunch of garbage, TODO(miek). 113 t := "nop" 114 if t1, ok := TypeToString[m.Question[0].Qtype]; ok { 115 t = t1 116 } 117 cl := "nop" 118 if cl1, ok := ClassToString[m.Question[0].Qclass]; ok { 119 cl = cl1 120 } 121 r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { 122 return c.exchange(m, a) 123 }) 124 if err != nil { 125 return r, rtt, err 126 } 127 if shared { 128 return r.Copy(), rtt, nil 129 } 130 return r, rtt, nil 131} 132 133func (c *Client) dialTimeout() time.Duration { 134 if c.Timeout != 0 { 135 return c.Timeout 136 } 137 if c.DialTimeout != 0 { 138 return c.DialTimeout 139 } 140 return dnsTimeout 141} 142 143func (c *Client) readTimeout() time.Duration { 144 if c.ReadTimeout != 0 { 145 return c.ReadTimeout 146 } 147 return dnsTimeout 148} 149 150func (c *Client) writeTimeout() time.Duration { 151 if c.WriteTimeout != 0 { 152 return c.WriteTimeout 153 } 154 return dnsTimeout 155} 156 157func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { 158 var co *Conn 159 network := "udp" 160 tls := false 161 162 switch c.Net { 163 case "tcp-tls": 164 network = "tcp" 165 tls = true 166 case "tcp4-tls": 167 network = "tcp4" 168 tls = true 169 case "tcp6-tls": 170 network = "tcp6" 171 tls = true 172 default: 173 if c.Net != "" { 174 network = c.Net 175 } 176 } 177 178 var deadline time.Time 179 if c.Timeout != 0 { 180 deadline = time.Now().Add(c.Timeout) 181 } 182 183 if tls { 184 co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout()) 185 } else { 186 co, err = DialTimeout(network, a, c.dialTimeout()) 187 } 188 189 if err != nil { 190 return nil, 0, err 191 } 192 defer co.Close() 193 194 opt := m.IsEdns0() 195 // If EDNS0 is used use that for size. 196 if opt != nil && opt.UDPSize() >= MinMsgSize { 197 co.UDPSize = opt.UDPSize() 198 } 199 // Otherwise use the client's configured UDP size. 200 if opt == nil && c.UDPSize >= MinMsgSize { 201 co.UDPSize = c.UDPSize 202 } 203 204 co.TsigSecret = c.TsigSecret 205 co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout())) 206 if err = co.WriteMsg(m); err != nil { 207 return nil, 0, err 208 } 209 210 co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout())) 211 r, err = co.ReadMsg() 212 if err == nil && r.Id != m.Id { 213 err = ErrId 214 } 215 return r, co.rtt, err 216} 217 218// ReadMsg reads a message from the connection co. 219// If the received message contains a TSIG record the transaction 220// signature is verified. 221func (co *Conn) ReadMsg() (*Msg, error) { 222 p, err := co.ReadMsgHeader(nil) 223 if err != nil { 224 return nil, err 225 } 226 227 m := new(Msg) 228 if err := m.Unpack(p); err != nil { 229 // If ErrTruncated was returned, we still want to allow the user to use 230 // the message, but naively they can just check err if they don't want 231 // to use a truncated message 232 if err == ErrTruncated { 233 return m, err 234 } 235 return nil, err 236 } 237 if t := m.IsTsig(); t != nil { 238 if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { 239 return m, ErrSecret 240 } 241 // Need to work on the original message p, as that was used to calculate the tsig. 242 err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) 243 } 244 return m, err 245} 246 247// ReadMsgHeader reads a DNS message, parses and populates hdr (when hdr is not nil). 248// Returns message as a byte slice to be parsed with Msg.Unpack later on. 249// Note that error handling on the message body is not possible as only the header is parsed. 250func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { 251 var ( 252 p []byte 253 n int 254 err error 255 ) 256 257 switch t := co.Conn.(type) { 258 case *net.TCPConn, *tls.Conn: 259 r := t.(io.Reader) 260 261 // First two bytes specify the length of the entire message. 262 l, err := tcpMsgLen(r) 263 if err != nil { 264 return nil, err 265 } 266 p = make([]byte, l) 267 n, err = tcpRead(r, p) 268 co.rtt = time.Since(co.t) 269 default: 270 if co.UDPSize > MinMsgSize { 271 p = make([]byte, co.UDPSize) 272 } else { 273 p = make([]byte, MinMsgSize) 274 } 275 n, err = co.Read(p) 276 co.rtt = time.Since(co.t) 277 } 278 279 if err != nil { 280 return nil, err 281 } else if n < headerSize { 282 return nil, ErrShortRead 283 } 284 285 p = p[:n] 286 if hdr != nil { 287 dh, _, err := unpackMsgHdr(p, 0) 288 if err != nil { 289 return nil, err 290 } 291 *hdr = dh 292 } 293 return p, err 294} 295 296// tcpMsgLen is a helper func to read first two bytes of stream as uint16 packet length. 297func tcpMsgLen(t io.Reader) (int, error) { 298 p := []byte{0, 0} 299 n, err := t.Read(p) 300 if err != nil { 301 return 0, err 302 } 303 if n != 2 { 304 return 0, ErrShortRead 305 } 306 l := binary.BigEndian.Uint16(p) 307 if l == 0 { 308 return 0, ErrShortRead 309 } 310 return int(l), nil 311} 312 313// tcpRead calls TCPConn.Read enough times to fill allocated buffer. 314func tcpRead(t io.Reader, p []byte) (int, error) { 315 n, err := t.Read(p) 316 if err != nil { 317 return n, err 318 } 319 for n < len(p) { 320 j, err := t.Read(p[n:]) 321 if err != nil { 322 return n, err 323 } 324 n += j 325 } 326 return n, err 327} 328 329// Read implements the net.Conn read method. 330func (co *Conn) Read(p []byte) (n int, err error) { 331 if co.Conn == nil { 332 return 0, ErrConnEmpty 333 } 334 if len(p) < 2 { 335 return 0, io.ErrShortBuffer 336 } 337 switch t := co.Conn.(type) { 338 case *net.TCPConn, *tls.Conn: 339 r := t.(io.Reader) 340 341 l, err := tcpMsgLen(r) 342 if err != nil { 343 return 0, err 344 } 345 if l > len(p) { 346 return int(l), io.ErrShortBuffer 347 } 348 return tcpRead(r, p[:l]) 349 } 350 // UDP connection 351 n, err = co.Conn.Read(p) 352 if err != nil { 353 return n, err 354 } 355 return n, err 356} 357 358// WriteMsg sends a message through the connection co. 359// If the message m contains a TSIG record the transaction 360// signature is calculated. 361func (co *Conn) WriteMsg(m *Msg) (err error) { 362 var out []byte 363 if t := m.IsTsig(); t != nil { 364 mac := "" 365 if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { 366 return ErrSecret 367 } 368 out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) 369 // Set for the next read, although only used in zone transfers 370 co.tsigRequestMAC = mac 371 } else { 372 out, err = m.Pack() 373 } 374 if err != nil { 375 return err 376 } 377 co.t = time.Now() 378 if _, err = co.Write(out); err != nil { 379 return err 380 } 381 return nil 382} 383 384// Write implements the net.Conn Write method. 385func (co *Conn) Write(p []byte) (n int, err error) { 386 switch t := co.Conn.(type) { 387 case *net.TCPConn, *tls.Conn: 388 w := t.(io.Writer) 389 390 lp := len(p) 391 if lp < 2 { 392 return 0, io.ErrShortBuffer 393 } 394 if lp > MaxMsgSize { 395 return 0, &Error{err: "message too large"} 396 } 397 l := make([]byte, 2, lp+2) 398 binary.BigEndian.PutUint16(l, uint16(lp)) 399 p = append(l, p...) 400 n, err := io.Copy(w, bytes.NewReader(p)) 401 return int(n), err 402 } 403 n, err = co.Conn.(*net.UDPConn).Write(p) 404 return n, err 405} 406 407// Dial connects to the address on the named network. 408func Dial(network, address string) (conn *Conn, err error) { 409 conn = new(Conn) 410 conn.Conn, err = net.Dial(network, address) 411 if err != nil { 412 return nil, err 413 } 414 return conn, nil 415} 416 417// DialTimeout acts like Dial but takes a timeout. 418func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) { 419 conn = new(Conn) 420 conn.Conn, err = net.DialTimeout(network, address, timeout) 421 if err != nil { 422 return nil, err 423 } 424 return conn, nil 425} 426 427// DialWithTLS connects to the address on the named network with TLS. 428func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, err error) { 429 conn = new(Conn) 430 conn.Conn, err = tls.Dial(network, address, tlsConfig) 431 if err != nil { 432 return nil, err 433 } 434 return conn, nil 435} 436 437// DialTimeoutWithTLS acts like DialWithTLS but takes a timeout. 438func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout time.Duration) (conn *Conn, err error) { 439 var dialer net.Dialer 440 dialer.Timeout = timeout 441 442 conn = new(Conn) 443 conn.Conn, err = tls.DialWithDialer(&dialer, network, address, tlsConfig) 444 if err != nil { 445 return nil, err 446 } 447 return conn, nil 448} 449 450func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time { 451 if deadline.IsZero() { 452 return time.Now().Add(timeout) 453 } 454 return deadline 455} 456