// DNS packet assembly, see RFC 1035. Converting from - Unpack() -
// and to - Pack() - wire format.
// The packers and unpackers in this file take a (msg []byte, off int)
// and return the new offset together with an error. On failure the
// domain-name and RR routines return len(msg) as the offset, so that a
// following unpacker will also fail. This lets callers defer error
// checks until the end of a packing sequence.

package dns

//go:generate go run msg_generate.go

import (
	"crypto/rand"
	"encoding/binary"
	"fmt"
	"math/big"
	"strconv"
	"strings"
)

const (
	maxCompressionOffset    = 2 << 13 // We have 14 bits for the compression pointer
	maxDomainNameWireOctets = 255     // See RFC 1035 section 2.3.4

	// This is the maximum number of compression pointers that should occur in a
	// semantically valid message. Each label in a domain name must be at least one
	// octet and is separated by a period. The root label won't be represented by a
	// compression pointer to a compression pointer, hence the -2 to exclude the
	// smallest valid root label.
	//
	// It is possible to construct a valid message that has more compression pointers
	// than this, and still doesn't loop, by pointing to a previous pointer. This is
	// not something a well written implementation should ever do, so we leave them
	// to trip the maximum compression pointer check.
	maxCompressionPointers = (maxDomainNameWireOctets+1)/2 - 2

	// This is the maximum length of a domain name in presentation format. The
	// maximum wire length of a domain name is 255 octets (see above), with the
	// maximum label length being 63. The wire format requires one extra byte over
	// the presentation format, reducing the number of octets by 1. Each label in
	// the name will be separated by a single period, with each octet in the label
	// expanding to at most 4 bytes (\DDD). If all other labels are of the maximum
	// length, then the final label can only be 61 octets long to not exceed the
	// maximum allowed wire length.
	maxDomainNamePresentationLength = 61*4 + 1 + 63*4 + 1 + 63*4 + 1 + 63*4 + 1
)

// Errors defined in this package.
var (
	// ErrAlg indicates an error with the (DNSSEC) algorithm.
	ErrAlg error = &Error{err: "bad algorithm"}
	// ErrAuth indicates an error in the TSIG authentication.
	ErrAuth error = &Error{err: "bad authentication"}
	// ErrBuf indicates that the buffer used is too small for the message.
	ErrBuf error = &Error{err: "buffer size too small"}
	// ErrConnEmpty indicates a connection is being used before it is initialized.
	ErrConnEmpty error = &Error{err: "conn has no connection"}
	// ErrExtendedRcode is returned when packing a message whose Rcode does
	// not fit in four bits while the message carries no OPT record.
	ErrExtendedRcode error = &Error{err: "bad extended rcode"}
	// ErrFqdn indicates that a domain name does not have a closing dot.
	ErrFqdn error = &Error{err: "domain must be fully qualified"}
	// ErrId indicates there is a mismatch with the message's ID.
	ErrId error = &Error{err: "id mismatch"}
	// ErrKeyAlg indicates that the algorithm in the key is not valid.
	ErrKeyAlg error = &Error{err: "bad key algorithm"}
	// ErrKey indicates a bad key.
	ErrKey error = &Error{err: "bad key"}
	// ErrKeySize indicates a bad key size.
	ErrKeySize error = &Error{err: "bad key size"}
	// ErrLongDomain indicates that a domain name exceeded
	// maxDomainNameWireOctets octets in wire format.
	ErrLongDomain error = &Error{err: fmt.Sprintf("domain name exceeded %d wire-format octets", maxDomainNameWireOctets)}
	// ErrNoSig indicates that no signature was found.
	ErrNoSig error = &Error{err: "no signature found"}
	// ErrPrivKey indicates a bad private key.
	ErrPrivKey error = &Error{err: "bad private key"}
	// ErrRcode is returned when packing a message whose Rcode is outside
	// the range 0..0xFFF.
	ErrRcode error = &Error{err: "bad rcode"}
	// ErrRdata indicates bad rdata; the name packers in this file also use
	// it for malformed labels.
	ErrRdata error = &Error{err: "bad rdata"}
	// ErrRRset indicates a bad rrset.
	ErrRRset error = &Error{err: "bad rrset"}
	// ErrSecret indicates that no secrets are defined.
	ErrSecret error = &Error{err: "no secrets defined"}
	// ErrShortRead indicates a short read.
	ErrShortRead error = &Error{err: "short read"}
	// ErrSig indicates that a signature can not be cryptographically validated.
	ErrSig error = &Error{err: "bad signature"}
	// ErrSoa indicates that no SOA RR was seen when doing zone transfers.
	ErrSoa error = &Error{err: "no SOA"}
	// ErrTime indicates a timing error in TSIG authentication.
	ErrTime error = &Error{err: "bad time"}
)

// Id by default returns a 16-bit random number to be used as a message id. The
// number is drawn from a cryptographically secure random number generator.
// This being a variable the function can be reassigned to a custom function.
// For instance, to make it return a static value for testing:
//
//	dns.Id = func() uint16 { return 3 }
var Id = id

// id returns a 16-bit random number to be used as a message id. It reads
// two bytes from crypto/rand and panics if the random source fails.
func id() uint16 {
	var output uint16
	err := binary.Read(rand.Reader, binary.BigEndian, &output)
	if err != nil {
		panic("dns: reading random id failed: " + err.Error())
	}
	return output
}

// MsgHdr is a manually-unpacked version of (id, bits).
type MsgHdr struct {
	Id                 uint16
	Response           bool
	Opcode             int
	Authoritative      bool
	Truncated          bool
	RecursionDesired   bool
	RecursionAvailable bool
	Zero               bool
	AuthenticatedData  bool
	CheckingDisabled   bool
	Rcode              int
}

// Msg contains the layout of a DNS message.
type Msg struct {
	MsgHdr
	Compress bool       `json:"-"` // If true, the message will be compressed when converted to wire format.
	Question []Question // Holds the RR(s) of the question section.
	Answer   []RR       // Holds the RR(s) of the answer section.
	Ns       []RR       // Holds the RR(s) of the authority section.
	Extra    []RR       // Holds the RR(s) of the additional section.
}

// ClassToString maps Classes to strings for each CLASS wire type.
var ClassToString = map[uint16]string{
	ClassINET:   "IN",
	ClassCSNET:  "CS",
	ClassCHAOS:  "CH",
	ClassHESIOD: "HS",
	ClassNONE:   "NONE",
	ClassANY:    "ANY",
}

// OpcodeToString maps Opcodes to strings.
129var OpcodeToString = map[int]string{ 130 OpcodeQuery: "QUERY", 131 OpcodeIQuery: "IQUERY", 132 OpcodeStatus: "STATUS", 133 OpcodeNotify: "NOTIFY", 134 OpcodeUpdate: "UPDATE", 135} 136 137// RcodeToString maps Rcodes to strings. 138var RcodeToString = map[int]string{ 139 RcodeSuccess: "NOERROR", 140 RcodeFormatError: "FORMERR", 141 RcodeServerFailure: "SERVFAIL", 142 RcodeNameError: "NXDOMAIN", 143 RcodeNotImplemented: "NOTIMP", 144 RcodeRefused: "REFUSED", 145 RcodeYXDomain: "YXDOMAIN", // See RFC 2136 146 RcodeYXRrset: "YXRRSET", 147 RcodeNXRrset: "NXRRSET", 148 RcodeNotAuth: "NOTAUTH", 149 RcodeNotZone: "NOTZONE", 150 RcodeBadSig: "BADSIG", // Also known as RcodeBadVers, see RFC 6891 151 // RcodeBadVers: "BADVERS", 152 RcodeBadKey: "BADKEY", 153 RcodeBadTime: "BADTIME", 154 RcodeBadMode: "BADMODE", 155 RcodeBadName: "BADNAME", 156 RcodeBadAlg: "BADALG", 157 RcodeBadTrunc: "BADTRUNC", 158 RcodeBadCookie: "BADCOOKIE", 159} 160 161// compressionMap is used to allow a more efficient compression map 162// to be used for internal packDomainName calls without changing the 163// signature or functionality of public API. 164// 165// In particular, map[string]uint16 uses 25% less per-entry memory 166// than does map[string]int. 167type compressionMap struct { 168 ext map[string]int // external callers 169 int map[string]uint16 // internal callers 170} 171 172func (m compressionMap) valid() bool { 173 return m.int != nil || m.ext != nil 174} 175 176func (m compressionMap) insert(s string, pos int) { 177 if m.ext != nil { 178 m.ext[s] = pos 179 } else { 180 m.int[s] = uint16(pos) 181 } 182} 183 184func (m compressionMap) find(s string) (int, bool) { 185 if m.ext != nil { 186 pos, ok := m.ext[s] 187 return pos, ok 188 } 189 190 pos, ok := m.int[s] 191 return int(pos), ok 192} 193 194// Domain names are a sequence of counted strings 195// split at the dots. They end with a zero-length string. 196 197// PackDomainName packs a domain name s into msg[off:]. 
// If compression is wanted compress must be true and the compression
// map needs to hold a mapping between domain names and offsets
// pointing into msg.
func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
	return packDomainName(s, msg, off, compressionMap{ext: compression}, compress)
}

// packDomainName writes the presentation-format name s into msg starting at
// off and returns the offset just past what it wrote.
//
// An empty s writes nothing and returns off unchanged. A name that is not
// fully qualified yields ErrFqdn; an empty label (two consecutive dots) or a
// label of 64 or more octets yields ErrRdata; running out of room in msg
// yields ErrBuf. Every error is returned together with len(msg) as offset.
//
// When compression holds a map, each name suffix encountered at an offset
// below maxCompressionOffset is recorded in it. If compress is also true and
// a suffix is already present, a two-byte pointer to the recorded offset is
// emitted instead of the remaining labels.
func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
	// XXX: A logical copy of this function exists in IsDomainName and
	// should be kept in sync with this function.

	ls := len(s)
	if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
		return off, nil
	}

	// If not fully qualified, error out.
	if !IsFqdn(s) {
		return len(msg), ErrFqdn
	}

	// Each dot ends a segment of the name.
	// We trade each dot byte for a length byte.
	// Except for escaped dots (\.), which are normal dots.
	// There is also a trailing zero.

	// Compression
	// pointer stays -1 until a usable compression target has been found.
	pointer := -1

	// Emit sequence of counted strings, chopping at dots.
	//
	// begin is the index (in bs once it exists, otherwise in s) where the
	// current label starts. bs is a mutable copy of s that is only created
	// once the first backslash is seen; escapes are collapsed inside it in
	// place, which is why ls shrinks as the loop progresses. compOff counts
	// the bytes removed that way, so compBegin = begin + compOff is the
	// position of the current label in the untouched string s, which is what
	// the compression map is keyed on. wasDot remembers whether the previous
	// byte was an unescaped dot.
	var (
		begin     int
		compBegin int
		compOff   int
		bs        []byte
		wasDot    bool
	)
loop:
	for i := 0; i < ls; i++ {
		var c byte
		if bs == nil {
			c = s[i]
		} else {
			c = bs[i]
		}

		switch c {
		case '\\':
			if off+1 > len(msg) {
				return len(msg), ErrBuf
			}

			if bs == nil {
				bs = []byte(s)
			}

			// check for \DDD
			if i+3 < ls && isDigit(bs[i+1]) && isDigit(bs[i+2]) && isDigit(bs[i+3]) {
				// Replace the four bytes "\DDD" by the single byte they denote
				// and shift the remainder of the name left by three.
				bs[i] = dddToByte(bs[i+1:])
				copy(bs[i+1:ls-3], bs[i+4:])
				ls -= 3
				compOff += 3
			} else {
				// Plain "\X": drop the backslash; the escaped byte now sits at
				// index i and is skipped by the loop increment, so an escaped
				// dot is never treated as a label separator.
				copy(bs[i:ls-1], bs[i+1:])
				ls--
				compOff++
			}

			wasDot = false
		case '.':
			if wasDot {
				// two dots back to back is not legal
				return len(msg), ErrRdata
			}
			wasDot = true

			labelLen := i - begin
			if labelLen >= 1<<6 { // top two bits of length must be clear
				return len(msg), ErrRdata
			}

			// off can already (we're in a loop) be bigger than len(msg)
			// this happens when a name isn't fully qualified
			if off+1+labelLen > len(msg) {
				return len(msg), ErrBuf
			}

			// Don't try to compress '.'
			// We should only compress when compress is true, but we should also still pick
			// up names that can be used for *future* compression(s).
			if compression.valid() && !isRootLabel(s, bs, begin, ls) {
				if p, ok := compression.find(s[compBegin:]); ok {
					// The first hit is the longest matching dname
					// keep the pointer offset we get back and store
					// the offset of the current name, because that's
					// where we need to insert the pointer later

					// If compress is true, we're allowed to compress this dname
					if compress {
						pointer = p // Where to point to
						break loop
					}
				} else if off < maxCompressionOffset {
					// Only offsets smaller than maxCompressionOffset can be used.
					compression.insert(s[compBegin:], off)
				}
			}

			// The following is covered by the length check above.
			msg[off] = byte(labelLen)

			if bs == nil {
				copy(msg[off+1:], s[begin:i])
			} else {
				copy(msg[off+1:], bs[begin:i])
			}
			off += 1 + labelLen

			begin = i + 1
			compBegin = begin + compOff
		default:
			wasDot = false
		}
	}

	// Root label is special
	// For "." the loop above already wrote the single zero length byte.
	if isRootLabel(s, bs, 0, ls) {
		return off, nil
	}

	// If we did compression and we find something add the pointer here
	if pointer != -1 {
		// We have two bytes (14 bits) to put the pointer in
		// Offsets recorded by this function are below maxCompressionOffset,
		// so the XOR with 0xC000 sets the two marker bits.
		binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000))
		return off + 2, nil
	}

	// Terminating zero-length label. The returned offset advances even when
	// there was no room left to store the byte.
	if off < len(msg) {
		msg[off] = 0
	}

	return off + 1, nil
}

// isRootLabel returns whether s or bs, from off to end, is the root
// label ".".
//
// If bs is nil, s will be checked, otherwise bs will be checked.
func isRootLabel(s string, bs []byte, off, end int) bool {
	if bs == nil {
		return s[off:end] == "."
	}

	return end-off == 1 && bs[off] == '.'
}

// Unpack a domain name.
// In addition to the simple sequences of counted strings above,
// domain names are allowed to refer to strings elsewhere in the
// packet, to avoid repeating common suffixes when returning
// many entries in a single domain. The pointers are marked
// by a length byte with the top two bits set. Ignoring those
// two bits, that byte and the next give a 14 bit offset from msg[0]
// where we should pick up the trail.
// Note that if we jump elsewhere in the packet,
// we return off1 == the offset after the first pointer we found,
// which is where the next record will start.
// In theory, the pointers are only allowed to jump backward.
// We let them jump anywhere and stop jumping after a while.

// UnpackDomainName unpacks a domain name into a string. It returns
// the name, the new offset into msg and any error that occurred.
370// 371// When an error is encountered, the unpacked name will be discarded 372// and len(msg) will be returned as the offset. 373func UnpackDomainName(msg []byte, off int) (string, int, error) { 374 s := make([]byte, 0, maxDomainNamePresentationLength) 375 off1 := 0 376 lenmsg := len(msg) 377 budget := maxDomainNameWireOctets 378 ptr := 0 // number of pointers followed 379Loop: 380 for { 381 if off >= lenmsg { 382 return "", lenmsg, ErrBuf 383 } 384 c := int(msg[off]) 385 off++ 386 switch c & 0xC0 { 387 case 0x00: 388 if c == 0x00 { 389 // end of name 390 break Loop 391 } 392 // literal string 393 if off+c > lenmsg { 394 return "", lenmsg, ErrBuf 395 } 396 budget -= c + 1 // +1 for the label separator 397 if budget <= 0 { 398 return "", lenmsg, ErrLongDomain 399 } 400 for _, b := range msg[off : off+c] { 401 switch b { 402 case '.', '(', ')', ';', ' ', '@': 403 fallthrough 404 case '"', '\\': 405 s = append(s, '\\', b) 406 default: 407 if b < ' ' || b > '~' { // unprintable, use \DDD 408 s = append(s, escapeByte(b)...) 409 } else { 410 s = append(s, b) 411 } 412 } 413 } 414 s = append(s, '.') 415 off += c 416 case 0xC0: 417 // pointer to somewhere else in msg. 418 // remember location after first ptr, 419 // since that's how many bytes we consumed. 420 // also, don't follow too many pointers -- 421 // maybe there's a loop. 
422 if off >= lenmsg { 423 return "", lenmsg, ErrBuf 424 } 425 c1 := msg[off] 426 off++ 427 if ptr == 0 { 428 off1 = off 429 } 430 if ptr++; ptr > maxCompressionPointers { 431 return "", lenmsg, &Error{err: "too many compression pointers"} 432 } 433 // pointer should guarantee that it advances and points forwards at least 434 // but the condition on previous three lines guarantees that it's 435 // at least loop-free 436 off = (c^0xC0)<<8 | int(c1) 437 default: 438 // 0x80 and 0x40 are reserved 439 return "", lenmsg, ErrRdata 440 } 441 } 442 if ptr == 0 { 443 off1 = off 444 } 445 if len(s) == 0 { 446 return ".", off1, nil 447 } 448 return string(s), off1, nil 449} 450 451func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) { 452 if len(txt) == 0 { 453 if offset >= len(msg) { 454 return offset, ErrBuf 455 } 456 msg[offset] = 0 457 return offset, nil 458 } 459 var err error 460 for _, s := range txt { 461 if len(s) > len(tmp) { 462 return offset, ErrBuf 463 } 464 offset, err = packTxtString(s, msg, offset, tmp) 465 if err != nil { 466 return offset, err 467 } 468 } 469 return offset, nil 470} 471 472func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) { 473 lenByteOffset := offset 474 if offset >= len(msg) || len(s) > len(tmp) { 475 return offset, ErrBuf 476 } 477 offset++ 478 bs := tmp[:len(s)] 479 copy(bs, s) 480 for i := 0; i < len(bs); i++ { 481 if len(msg) <= offset { 482 return offset, ErrBuf 483 } 484 if bs[i] == '\\' { 485 i++ 486 if i == len(bs) { 487 break 488 } 489 // check for \DDD 490 if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) { 491 msg[offset] = dddToByte(bs[i:]) 492 i += 2 493 } else { 494 msg[offset] = bs[i] 495 } 496 } else { 497 msg[offset] = bs[i] 498 } 499 offset++ 500 } 501 l := offset - lenByteOffset - 1 502 if l > 255 { 503 return offset, &Error{err: "string exceeded 255 bytes in txt"} 504 } 505 msg[lenByteOffset] = byte(l) 506 return offset, nil 507} 508 509func 
packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) { 510 if offset >= len(msg) || len(s) > len(tmp) { 511 return offset, ErrBuf 512 } 513 bs := tmp[:len(s)] 514 copy(bs, s) 515 for i := 0; i < len(bs); i++ { 516 if len(msg) <= offset { 517 return offset, ErrBuf 518 } 519 if bs[i] == '\\' { 520 i++ 521 if i == len(bs) { 522 break 523 } 524 // check for \DDD 525 if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) { 526 msg[offset] = dddToByte(bs[i:]) 527 i += 2 528 } else { 529 msg[offset] = bs[i] 530 } 531 } else { 532 msg[offset] = bs[i] 533 } 534 offset++ 535 } 536 return offset, nil 537} 538 539func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) { 540 off = off0 541 var s string 542 for off < len(msg) && err == nil { 543 s, off, err = unpackString(msg, off) 544 if err == nil { 545 ss = append(ss, s) 546 } 547 } 548 return 549} 550 551// Helpers for dealing with escaped bytes 552func isDigit(b byte) bool { return b >= '0' && b <= '9' } 553 554func dddToByte(s []byte) byte { 555 _ = s[2] // bounds check hint to compiler; see golang.org/issue/14808 556 return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0')) 557} 558 559func dddStringToByte(s string) byte { 560 _ = s[2] // bounds check hint to compiler; see golang.org/issue/14808 561 return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0')) 562} 563 564// Helper function for packing and unpacking 565func intToBytes(i *big.Int, length int) []byte { 566 buf := i.Bytes() 567 if len(buf) < length { 568 b := make([]byte, length) 569 copy(b[length-len(buf):], buf) 570 return b 571 } 572 return buf 573} 574 575// PackRR packs a resource record rr into msg[off:]. 576// See PackDomainName for documentation about the compression. 
577func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { 578 headerEnd, off1, err := packRR(rr, msg, off, compressionMap{ext: compression}, compress) 579 if err == nil { 580 // packRR no longer sets the Rdlength field on the rr, but 581 // callers might be expecting it so we set it here. 582 rr.Header().Rdlength = uint16(off1 - headerEnd) 583 } 584 return off1, err 585} 586 587func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) { 588 if rr == nil { 589 return len(msg), len(msg), &Error{err: "nil rr"} 590 } 591 592 headerEnd, err = rr.Header().packHeader(msg, off, compression, compress) 593 if err != nil { 594 return headerEnd, len(msg), err 595 } 596 597 off1, err = rr.pack(msg, headerEnd, compression, compress) 598 if err != nil { 599 return headerEnd, len(msg), err 600 } 601 602 rdlength := off1 - headerEnd 603 if int(uint16(rdlength)) != rdlength { // overflow 604 return headerEnd, len(msg), ErrRdata 605 } 606 607 // The RDLENGTH field is the last field in the header and we set it here. 608 binary.BigEndian.PutUint16(msg[headerEnd-2:], uint16(rdlength)) 609 return headerEnd, off1, nil 610} 611 612// UnpackRR unpacks msg[off:] into an RR. 613func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) { 614 h, off, msg, err := unpackHeader(msg, off) 615 if err != nil { 616 return nil, len(msg), err 617 } 618 619 return UnpackRRWithHeader(h, msg, off) 620} 621 622// UnpackRRWithHeader unpacks the record type specific payload given an existing 623// RR_Header. 
624func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) { 625 if newFn, ok := TypeToRR[h.Rrtype]; ok { 626 rr = newFn() 627 *rr.Header() = h 628 } else { 629 rr = &RFC3597{Hdr: h} 630 } 631 632 if noRdata(h) { 633 return rr, off, nil 634 } 635 636 end := off + int(h.Rdlength) 637 638 off, err = rr.unpack(msg, off) 639 if err != nil { 640 return nil, end, err 641 } 642 if off != end { 643 return &h, end, &Error{err: "bad rdlength"} 644 } 645 646 return rr, off, nil 647} 648 649// unpackRRslice unpacks msg[off:] into an []RR. 650// If we cannot unpack the whole array, then it will return nil 651func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) { 652 var r RR 653 // Don't pre-allocate, l may be under attacker control 654 var dst []RR 655 for i := 0; i < l; i++ { 656 off1 := off 657 r, off, err = UnpackRR(msg, off) 658 if err != nil { 659 off = len(msg) 660 break 661 } 662 // If offset does not increase anymore, l is a lie 663 if off1 == off { 664 l = i 665 break 666 } 667 dst = append(dst, r) 668 } 669 if err != nil && off == len(msg) { 670 dst = nil 671 } 672 return dst, off, err 673} 674 675// Convert a MsgHdr to a string, with dig-like headers: 676// 677//;; opcode: QUERY, status: NOERROR, id: 48404 678// 679//;; flags: qr aa rd ra; 680func (h *MsgHdr) String() string { 681 if h == nil { 682 return "<nil> MsgHdr" 683 } 684 685 s := ";; opcode: " + OpcodeToString[h.Opcode] 686 s += ", status: " + RcodeToString[h.Rcode] 687 s += ", id: " + strconv.Itoa(int(h.Id)) + "\n" 688 689 s += ";; flags:" 690 if h.Response { 691 s += " qr" 692 } 693 if h.Authoritative { 694 s += " aa" 695 } 696 if h.Truncated { 697 s += " tc" 698 } 699 if h.RecursionDesired { 700 s += " rd" 701 } 702 if h.RecursionAvailable { 703 s += " ra" 704 } 705 if h.Zero { // Hmm 706 s += " z" 707 } 708 if h.AuthenticatedData { 709 s += " ad" 710 } 711 if h.CheckingDisabled { 712 s += " cd" 713 } 714 715 s += ";" 716 return s 717} 718 719// 
Pack packs a Msg: it is converted to to wire format. 720// If the dns.Compress is true the message will be in compressed wire format. 721func (dns *Msg) Pack() (msg []byte, err error) { 722 return dns.PackBuffer(nil) 723} 724 725// PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated. 726func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) { 727 // If this message can't be compressed, avoid filling the 728 // compression map and creating garbage. 729 if dns.Compress && dns.isCompressible() { 730 compression := make(map[string]uint16) // Compression pointer mappings. 731 return dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true) 732 } 733 734 return dns.packBufferWithCompressionMap(buf, compressionMap{}, false) 735} 736 737// packBufferWithCompressionMap packs a Msg, using the given buffer buf. 738func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compressionMap, compress bool) (msg []byte, err error) { 739 if dns.Rcode < 0 || dns.Rcode > 0xFFF { 740 return nil, ErrRcode 741 } 742 743 // Set extended rcode unconditionally if we have an opt, this will allow 744 // reseting the extended rcode bits if they need to. 745 if opt := dns.IsEdns0(); opt != nil { 746 opt.SetExtendedRcode(uint16(dns.Rcode)) 747 } else if dns.Rcode > 0xF { 748 // If Rcode is an extended one and opt is nil, error out. 749 return nil, ErrExtendedRcode 750 } 751 752 // Convert convenient Msg into wire-like Header. 
753 var dh Header 754 dh.Id = dns.Id 755 dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF) 756 if dns.Response { 757 dh.Bits |= _QR 758 } 759 if dns.Authoritative { 760 dh.Bits |= _AA 761 } 762 if dns.Truncated { 763 dh.Bits |= _TC 764 } 765 if dns.RecursionDesired { 766 dh.Bits |= _RD 767 } 768 if dns.RecursionAvailable { 769 dh.Bits |= _RA 770 } 771 if dns.Zero { 772 dh.Bits |= _Z 773 } 774 if dns.AuthenticatedData { 775 dh.Bits |= _AD 776 } 777 if dns.CheckingDisabled { 778 dh.Bits |= _CD 779 } 780 781 dh.Qdcount = uint16(len(dns.Question)) 782 dh.Ancount = uint16(len(dns.Answer)) 783 dh.Nscount = uint16(len(dns.Ns)) 784 dh.Arcount = uint16(len(dns.Extra)) 785 786 // We need the uncompressed length here, because we first pack it and then compress it. 787 msg = buf 788 uncompressedLen := msgLenWithCompressionMap(dns, nil) 789 if packLen := uncompressedLen + 1; len(msg) < packLen { 790 msg = make([]byte, packLen) 791 } 792 793 // Pack it in: header and then the pieces. 794 off := 0 795 off, err = dh.pack(msg, off, compression, compress) 796 if err != nil { 797 return nil, err 798 } 799 for _, r := range dns.Question { 800 off, err = r.pack(msg, off, compression, compress) 801 if err != nil { 802 return nil, err 803 } 804 } 805 for _, r := range dns.Answer { 806 _, off, err = packRR(r, msg, off, compression, compress) 807 if err != nil { 808 return nil, err 809 } 810 } 811 for _, r := range dns.Ns { 812 _, off, err = packRR(r, msg, off, compression, compress) 813 if err != nil { 814 return nil, err 815 } 816 } 817 for _, r := range dns.Extra { 818 _, off, err = packRR(r, msg, off, compression, compress) 819 if err != nil { 820 return nil, err 821 } 822 } 823 return msg[:off], nil 824} 825 826func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) { 827 // If we are at the end of the message we should return *just* the 828 // header. This can still be useful to the caller. 9.9.9.9 sends these 829 // when responding with REFUSED for instance. 
830 if off == len(msg) { 831 // reset sections before returning 832 dns.Question, dns.Answer, dns.Ns, dns.Extra = nil, nil, nil, nil 833 return nil 834 } 835 836 // Qdcount, Ancount, Nscount, Arcount can't be trusted, as they are 837 // attacker controlled. This means we can't use them to pre-allocate 838 // slices. 839 dns.Question = nil 840 for i := 0; i < int(dh.Qdcount); i++ { 841 off1 := off 842 var q Question 843 q, off, err = unpackQuestion(msg, off) 844 if err != nil { 845 return err 846 } 847 if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie! 848 dh.Qdcount = uint16(i) 849 break 850 } 851 dns.Question = append(dns.Question, q) 852 } 853 854 dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off) 855 // The header counts might have been wrong so we need to update it 856 dh.Ancount = uint16(len(dns.Answer)) 857 if err == nil { 858 dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off) 859 } 860 // The header counts might have been wrong so we need to update it 861 dh.Nscount = uint16(len(dns.Ns)) 862 if err == nil { 863 dns.Extra, off, err = unpackRRslice(int(dh.Arcount), msg, off) 864 } 865 // The header counts might have been wrong so we need to update it 866 dh.Arcount = uint16(len(dns.Extra)) 867 868 // Set extended Rcode 869 if opt := dns.IsEdns0(); opt != nil { 870 dns.Rcode |= opt.ExtendedRcode() 871 } 872 873 if off != len(msg) { 874 // TODO(miek) make this an error? 875 // use PackOpt to let people tell how detailed the error reporting should be? 876 // println("dns: extra bytes in dns packet", off, "<", len(msg)) 877 } 878 return err 879 880} 881 882// Unpack unpacks a binary message to a Msg structure. 883func (dns *Msg) Unpack(msg []byte) (err error) { 884 dh, off, err := unpackMsgHdr(msg, 0) 885 if err != nil { 886 return err 887 } 888 889 dns.setHdr(dh) 890 return dns.unpack(dh, msg, off) 891} 892 893// Convert a complete message to a string with dig-like output. 
894func (dns *Msg) String() string { 895 if dns == nil { 896 return "<nil> MsgHdr" 897 } 898 s := dns.MsgHdr.String() + " " 899 s += "QUERY: " + strconv.Itoa(len(dns.Question)) + ", " 900 s += "ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", " 901 s += "AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", " 902 s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n" 903 if len(dns.Question) > 0 { 904 s += "\n;; QUESTION SECTION:\n" 905 for _, r := range dns.Question { 906 s += r.String() + "\n" 907 } 908 } 909 if len(dns.Answer) > 0 { 910 s += "\n;; ANSWER SECTION:\n" 911 for _, r := range dns.Answer { 912 if r != nil { 913 s += r.String() + "\n" 914 } 915 } 916 } 917 if len(dns.Ns) > 0 { 918 s += "\n;; AUTHORITY SECTION:\n" 919 for _, r := range dns.Ns { 920 if r != nil { 921 s += r.String() + "\n" 922 } 923 } 924 } 925 if len(dns.Extra) > 0 { 926 s += "\n;; ADDITIONAL SECTION:\n" 927 for _, r := range dns.Extra { 928 if r != nil { 929 s += r.String() + "\n" 930 } 931 } 932 } 933 return s 934} 935 936// isCompressible returns whether the msg may be compressible. 937func (dns *Msg) isCompressible() bool { 938 // If we only have one question, there is nothing we can ever compress. 939 return len(dns.Question) > 1 || len(dns.Answer) > 0 || 940 len(dns.Ns) > 0 || len(dns.Extra) > 0 941} 942 943// Len returns the message length when in (un)compressed wire format. 944// If dns.Compress is true compression it is taken into account. Len() 945// is provided to be a faster way to get the size of the resulting packet, 946// than packing it, measuring the size and discarding the buffer. 947func (dns *Msg) Len() int { 948 // If this message can't be compressed, avoid filling the 949 // compression map and creating garbage. 
950 if dns.Compress && dns.isCompressible() { 951 compression := make(map[string]struct{}) 952 return msgLenWithCompressionMap(dns, compression) 953 } 954 955 return msgLenWithCompressionMap(dns, nil) 956} 957 958func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int { 959 l := headerSize 960 961 for _, r := range dns.Question { 962 l += r.len(l, compression) 963 } 964 for _, r := range dns.Answer { 965 if r != nil { 966 l += r.len(l, compression) 967 } 968 } 969 for _, r := range dns.Ns { 970 if r != nil { 971 l += r.len(l, compression) 972 } 973 } 974 for _, r := range dns.Extra { 975 if r != nil { 976 l += r.len(l, compression) 977 } 978 } 979 980 return l 981} 982 983func domainNameLen(s string, off int, compression map[string]struct{}, compress bool) int { 984 if s == "" || s == "." { 985 return 1 986 } 987 988 escaped := strings.Contains(s, "\\") 989 990 if compression != nil && (compress || off < maxCompressionOffset) { 991 // compressionLenSearch will insert the entry into the compression 992 // map if it doesn't contain it. 
993 if l, ok := compressionLenSearch(compression, s, off); ok && compress { 994 if escaped { 995 return escapedNameLen(s[:l]) + 2 996 } 997 998 return l + 2 999 } 1000 } 1001 1002 if escaped { 1003 return escapedNameLen(s) + 1 1004 } 1005 1006 return len(s) + 1 1007} 1008 1009func escapedNameLen(s string) int { 1010 nameLen := len(s) 1011 for i := 0; i < len(s); i++ { 1012 if s[i] != '\\' { 1013 continue 1014 } 1015 1016 if i+3 < len(s) && isDigit(s[i+1]) && isDigit(s[i+2]) && isDigit(s[i+3]) { 1017 nameLen -= 3 1018 i += 3 1019 } else { 1020 nameLen-- 1021 i++ 1022 } 1023 } 1024 1025 return nameLen 1026} 1027 1028func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, bool) { 1029 for off, end := 0, false; !end; off, end = NextLabel(s, off) { 1030 if _, ok := c[s[off:]]; ok { 1031 return off, true 1032 } 1033 1034 if msgOff+off < maxCompressionOffset { 1035 c[s[off:]] = struct{}{} 1036 } 1037 } 1038 1039 return 0, false 1040} 1041 1042// Copy returns a new RR which is a deep-copy of r. 1043func Copy(r RR) RR { return r.copy() } 1044 1045// Len returns the length (in octets) of the uncompressed RR in wire format. 1046func Len(r RR) int { return r.len(0, nil) } 1047 1048// Copy returns a new *Msg which is a deep-copy of dns. 1049func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) } 1050 1051// CopyTo copies the contents to the provided message using a deep-copy and returns the copy. 
1052func (dns *Msg) CopyTo(r1 *Msg) *Msg { 1053 r1.MsgHdr = dns.MsgHdr 1054 r1.Compress = dns.Compress 1055 1056 if len(dns.Question) > 0 { 1057 r1.Question = make([]Question, len(dns.Question)) 1058 copy(r1.Question, dns.Question) // TODO(miek): Question is an immutable value, ok to do a shallow-copy 1059 } 1060 1061 rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra)) 1062 r1.Answer, rrArr = rrArr[:0:len(dns.Answer)], rrArr[len(dns.Answer):] 1063 r1.Ns, rrArr = rrArr[:0:len(dns.Ns)], rrArr[len(dns.Ns):] 1064 r1.Extra = rrArr[:0:len(dns.Extra)] 1065 1066 for _, r := range dns.Answer { 1067 r1.Answer = append(r1.Answer, r.copy()) 1068 } 1069 1070 for _, r := range dns.Ns { 1071 r1.Ns = append(r1.Ns, r.copy()) 1072 } 1073 1074 for _, r := range dns.Extra { 1075 r1.Extra = append(r1.Extra, r.copy()) 1076 } 1077 1078 return r1 1079} 1080 1081func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) { 1082 off, err := packDomainName(q.Name, msg, off, compression, compress) 1083 if err != nil { 1084 return off, err 1085 } 1086 off, err = packUint16(q.Qtype, msg, off) 1087 if err != nil { 1088 return off, err 1089 } 1090 off, err = packUint16(q.Qclass, msg, off) 1091 if err != nil { 1092 return off, err 1093 } 1094 return off, nil 1095} 1096 1097func unpackQuestion(msg []byte, off int) (Question, int, error) { 1098 var ( 1099 q Question 1100 err error 1101 ) 1102 q.Name, off, err = UnpackDomainName(msg, off) 1103 if err != nil { 1104 return q, off, err 1105 } 1106 if off == len(msg) { 1107 return q, off, nil 1108 } 1109 q.Qtype, off, err = unpackUint16(msg, off) 1110 if err != nil { 1111 return q, off, err 1112 } 1113 if off == len(msg) { 1114 return q, off, nil 1115 } 1116 q.Qclass, off, err = unpackUint16(msg, off) 1117 if off == len(msg) { 1118 return q, off, nil 1119 } 1120 return q, off, err 1121} 1122 1123func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) { 1124 
off, err := packUint16(dh.Id, msg, off) 1125 if err != nil { 1126 return off, err 1127 } 1128 off, err = packUint16(dh.Bits, msg, off) 1129 if err != nil { 1130 return off, err 1131 } 1132 off, err = packUint16(dh.Qdcount, msg, off) 1133 if err != nil { 1134 return off, err 1135 } 1136 off, err = packUint16(dh.Ancount, msg, off) 1137 if err != nil { 1138 return off, err 1139 } 1140 off, err = packUint16(dh.Nscount, msg, off) 1141 if err != nil { 1142 return off, err 1143 } 1144 off, err = packUint16(dh.Arcount, msg, off) 1145 if err != nil { 1146 return off, err 1147 } 1148 return off, nil 1149} 1150 1151func unpackMsgHdr(msg []byte, off int) (Header, int, error) { 1152 var ( 1153 dh Header 1154 err error 1155 ) 1156 dh.Id, off, err = unpackUint16(msg, off) 1157 if err != nil { 1158 return dh, off, err 1159 } 1160 dh.Bits, off, err = unpackUint16(msg, off) 1161 if err != nil { 1162 return dh, off, err 1163 } 1164 dh.Qdcount, off, err = unpackUint16(msg, off) 1165 if err != nil { 1166 return dh, off, err 1167 } 1168 dh.Ancount, off, err = unpackUint16(msg, off) 1169 if err != nil { 1170 return dh, off, err 1171 } 1172 dh.Nscount, off, err = unpackUint16(msg, off) 1173 if err != nil { 1174 return dh, off, err 1175 } 1176 dh.Arcount, off, err = unpackUint16(msg, off) 1177 if err != nil { 1178 return dh, off, err 1179 } 1180 return dh, off, nil 1181} 1182 1183// setHdr set the header in the dns using the binary data in dh. 1184func (dns *Msg) setHdr(dh Header) { 1185 dns.Id = dh.Id 1186 dns.Response = dh.Bits&_QR != 0 1187 dns.Opcode = int(dh.Bits>>11) & 0xF 1188 dns.Authoritative = dh.Bits&_AA != 0 1189 dns.Truncated = dh.Bits&_TC != 0 1190 dns.RecursionDesired = dh.Bits&_RD != 0 1191 dns.RecursionAvailable = dh.Bits&_RA != 0 1192 dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite. 
1193 dns.AuthenticatedData = dh.Bits&_AD != 0 1194 dns.CheckingDisabled = dh.Bits&_CD != 0 1195 dns.Rcode = int(dh.Bits & 0xF) 1196} 1197