1package dns
2
3import (
4	"errors"
5	"net"
6	"strconv"
7	"strings"
8)
9
// hexDigit holds the lowercase hexadecimal digits, indexed by nibble value
// (0-15).
const (
	hexDigit = "0123456789abcdef"
)
11
12// Everything is assumed in ClassINET.
13
14// SetReply creates a reply message from a request message.
15func (dns *Msg) SetReply(request *Msg) *Msg {
16	dns.Id = request.Id
17	dns.Response = true
18	dns.Opcode = request.Opcode
19	if dns.Opcode == OpcodeQuery {
20		dns.RecursionDesired = request.RecursionDesired // Copy rd bit
21		dns.CheckingDisabled = request.CheckingDisabled // Copy cd bit
22	}
23	dns.Rcode = RcodeSuccess
24	if len(request.Question) > 0 {
25		dns.Question = make([]Question, 1)
26		dns.Question[0] = request.Question[0]
27	}
28	return dns
29}
30
31// SetQuestion creates a question message, it sets the Question
32// section, generates an Id and sets the RecursionDesired (RD)
33// bit to true.
34func (dns *Msg) SetQuestion(z string, t uint16) *Msg {
35	dns.Id = Id()
36	dns.RecursionDesired = true
37	dns.Question = make([]Question, 1)
38	dns.Question[0] = Question{z, t, ClassINET}
39	return dns
40}
41
42// SetNotify creates a notify message, it sets the Question
43// section, generates an Id and sets the Authoritative (AA)
44// bit to true.
45func (dns *Msg) SetNotify(z string) *Msg {
46	dns.Opcode = OpcodeNotify
47	dns.Authoritative = true
48	dns.Id = Id()
49	dns.Question = make([]Question, 1)
50	dns.Question[0] = Question{z, TypeSOA, ClassINET}
51	return dns
52}
53
54// SetRcode creates an error message suitable for the request.
55func (dns *Msg) SetRcode(request *Msg, rcode int) *Msg {
56	dns.SetReply(request)
57	dns.Rcode = rcode
58	return dns
59}
60
61// SetRcodeFormatError creates a message with FormError set.
62func (dns *Msg) SetRcodeFormatError(request *Msg) *Msg {
63	dns.Rcode = RcodeFormatError
64	dns.Opcode = OpcodeQuery
65	dns.Response = true
66	dns.Authoritative = false
67	dns.Id = request.Id
68	return dns
69}
70
71// SetUpdate makes the message a dynamic update message. It
72// sets the ZONE section to: z, TypeSOA, ClassINET.
73func (dns *Msg) SetUpdate(z string) *Msg {
74	dns.Id = Id()
75	dns.Response = false
76	dns.Opcode = OpcodeUpdate
77	dns.Compress = false // BIND9 cannot handle compression
78	dns.Question = make([]Question, 1)
79	dns.Question[0] = Question{z, TypeSOA, ClassINET}
80	return dns
81}
82
83// SetIxfr creates message for requesting an IXFR.
84func (dns *Msg) SetIxfr(z string, serial uint32, ns, mbox string) *Msg {
85	dns.Id = Id()
86	dns.Question = make([]Question, 1)
87	dns.Ns = make([]RR, 1)
88	s := new(SOA)
89	s.Hdr = RR_Header{z, TypeSOA, ClassINET, defaultTtl, 0}
90	s.Serial = serial
91	s.Ns = ns
92	s.Mbox = mbox
93	dns.Question[0] = Question{z, TypeIXFR, ClassINET}
94	dns.Ns[0] = s
95	return dns
96}
97
98// SetAxfr creates message for requesting an AXFR.
99func (dns *Msg) SetAxfr(z string) *Msg {
100	dns.Id = Id()
101	dns.Question = make([]Question, 1)
102	dns.Question[0] = Question{z, TypeAXFR, ClassINET}
103	return dns
104}
105
106// SetTsig appends a TSIG RR to the message.
107// This is only a skeleton TSIG RR that is added as the last RR in the
108// additional section. The TSIG is calculated when the message is being send.
109func (dns *Msg) SetTsig(z, algo string, fudge uint16, timesigned int64) *Msg {
110	t := new(TSIG)
111	t.Hdr = RR_Header{z, TypeTSIG, ClassANY, 0, 0}
112	t.Algorithm = algo
113	t.Fudge = fudge
114	t.TimeSigned = uint64(timesigned)
115	t.OrigId = dns.Id
116	dns.Extra = append(dns.Extra, t)
117	return dns
118}
119
120// SetEdns0 appends a EDNS0 OPT RR to the message.
121// TSIG should always the last RR in a message.
122func (dns *Msg) SetEdns0(udpsize uint16, do bool) *Msg {
123	e := new(OPT)
124	e.Hdr.Name = "."
125	e.Hdr.Rrtype = TypeOPT
126	e.SetUDPSize(udpsize)
127	if do {
128		e.SetDo()
129	}
130	dns.Extra = append(dns.Extra, e)
131	return dns
132}
133
134// IsTsig checks if the message has a TSIG record as the last record
135// in the additional section. It returns the TSIG record found or nil.
136func (dns *Msg) IsTsig() *TSIG {
137	if len(dns.Extra) > 0 {
138		if dns.Extra[len(dns.Extra)-1].Header().Rrtype == TypeTSIG {
139			return dns.Extra[len(dns.Extra)-1].(*TSIG)
140		}
141	}
142	return nil
143}
144
145// IsEdns0 checks if the message has a EDNS0 (OPT) record, any EDNS0
146// record in the additional section will do. It returns the OPT record
147// found or nil.
148func (dns *Msg) IsEdns0() *OPT {
149	// RFC 6891, Section 6.1.1 allows the OPT record to appear
150	// anywhere in the additional record section, but it's usually at
151	// the end so start there.
152	for i := len(dns.Extra) - 1; i >= 0; i-- {
153		if dns.Extra[i].Header().Rrtype == TypeOPT {
154			return dns.Extra[i].(*OPT)
155		}
156	}
157	return nil
158}
159
160// popEdns0 is like IsEdns0, but it removes the record from the message.
161func (dns *Msg) popEdns0() *OPT {
162	// RFC 6891, Section 6.1.1 allows the OPT record to appear
163	// anywhere in the additional record section, but it's usually at
164	// the end so start there.
165	for i := len(dns.Extra) - 1; i >= 0; i-- {
166		if dns.Extra[i].Header().Rrtype == TypeOPT {
167			opt := dns.Extra[i].(*OPT)
168			dns.Extra = append(dns.Extra[:i], dns.Extra[i+1:]...)
169			return opt
170		}
171	}
172	return nil
173}
174
175// IsDomainName checks if s is a valid domain name, it returns the number of
176// labels and true, when a domain name is valid.  Note that non fully qualified
177// domain name is considered valid, in this case the last label is counted in
178// the number of labels.  When false is returned the number of labels is not
179// defined.  Also note that this function is extremely liberal; almost any
180// string is a valid domain name as the DNS is 8 bit protocol. It checks if each
181// label fits in 63 characters and that the entire name will fit into the 255
182// octet wire format limit.
183func IsDomainName(s string) (labels int, ok bool) {
184	// XXX: The logic in this function was copied from packDomainName and
185	// should be kept in sync with that function.
186
187	const lenmsg = 256
188
189	if len(s) == 0 { // Ok, for instance when dealing with update RR without any rdata.
190		return 0, false
191	}
192
193	s = Fqdn(s)
194
195	// Each dot ends a segment of the name. Except for escaped dots (\.), which
196	// are normal dots.
197
198	var (
199		off    int
200		begin  int
201		wasDot bool
202	)
203	for i := 0; i < len(s); i++ {
204		switch s[i] {
205		case '\\':
206			if off+1 > lenmsg {
207				return labels, false
208			}
209
210			// check for \DDD
211			if i+3 < len(s) && isDigit(s[i+1]) && isDigit(s[i+2]) && isDigit(s[i+3]) {
212				i += 3
213				begin += 3
214			} else {
215				i++
216				begin++
217			}
218
219			wasDot = false
220		case '.':
221			if wasDot {
222				// two dots back to back is not legal
223				return labels, false
224			}
225			wasDot = true
226
227			labelLen := i - begin
228			if labelLen >= 1<<6 { // top two bits of length must be clear
229				return labels, false
230			}
231
232			// off can already (we're in a loop) be bigger than lenmsg
233			// this happens when a name isn't fully qualified
234			off += 1 + labelLen
235			if off > lenmsg {
236				return labels, false
237			}
238
239			labels++
240			begin = i + 1
241		default:
242			wasDot = false
243		}
244	}
245
246	return labels, true
247}
248
249// IsSubDomain checks if child is indeed a child of the parent. If child and parent
250// are the same domain true is returned as well.
251func IsSubDomain(parent, child string) bool {
252	// Entire child is contained in parent
253	return CompareDomainName(parent, child) == CountLabel(parent)
254}
255
256// IsMsg sanity checks buf and returns an error if it isn't a valid DNS packet.
257// The checking is performed on the binary payload.
258func IsMsg(buf []byte) error {
259	// Header
260	if len(buf) < headerSize {
261		return errors.New("dns: bad message header")
262	}
263	// Header: Opcode
264	// TODO(miek): more checks here, e.g. check all header bits.
265	return nil
266}
267
// IsFqdn checks if a domain name is fully qualified, i.e. whether it ends in
// an unescaped dot. A final dot preceded by an odd number of backslashes is
// escaped (it belongs to the last label) and does not count.
func IsFqdn(s string) bool {
	if s == "" || s[len(s)-1] != '.' {
		return false
	}

	// Count the backslashes immediately before the final dot. This works on
	// bytes, not runes. '\\' is ASCII and never occurs inside a multi-byte
	// UTF-8 sequence, so the count is exact for non-ASCII labels too. A
	// rune-start index (as from strings.LastIndexFunc) mis-measures names
	// like "é.".
	escapes := 0
	for i := len(s) - 2; i >= 0 && s[i] == '\\'; i-- {
		escapes++
	}

	// An even count pairs up into literal "\\" escapes, leaving the dot itself
	// unescaped.
	return escapes%2 == 0
}
283
284// IsRRset checks if a set of RRs is a valid RRset as defined by RFC 2181.
285// This means the RRs need to have the same type, name, and class. Returns true
286// if the RR set is valid, otherwise false.
287func IsRRset(rrset []RR) bool {
288	if len(rrset) == 0 {
289		return false
290	}
291	if len(rrset) == 1 {
292		return true
293	}
294	rrHeader := rrset[0].Header()
295	rrType := rrHeader.Rrtype
296	rrClass := rrHeader.Class
297	rrName := rrHeader.Name
298
299	for _, rr := range rrset[1:] {
300		curRRHeader := rr.Header()
301		if curRRHeader.Rrtype != rrType || curRRHeader.Class != rrClass || curRRHeader.Name != rrName {
302			// Mismatch between the records, so this is not a valid rrset for
303			//signing/verifying
304			return false
305		}
306	}
307
308	return true
309}
310
311// Fqdn return the fully qualified domain name from s.
312// If s is already fully qualified, it behaves as the identity function.
313func Fqdn(s string) string {
314	if IsFqdn(s) {
315		return s
316	}
317	return s + "."
318}
319
320// CanonicalName returns the domain name in canonical form. A name in canonical
321// form is lowercase and fully qualified. See Section 6.2 in RFC 4034.
322func CanonicalName(s string) string {
323	return strings.ToLower(Fqdn(s))
324}
325
326// Copied from the official Go code.
327
328// ReverseAddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP
329// address suitable for reverse DNS (PTR) record lookups or an error if it fails
330// to parse the IP address.
331func ReverseAddr(addr string) (arpa string, err error) {
332	ip := net.ParseIP(addr)
333	if ip == nil {
334		return "", &Error{err: "unrecognized address: " + addr}
335	}
336	if v4 := ip.To4(); v4 != nil {
337		buf := make([]byte, 0, net.IPv4len*4+len("in-addr.arpa."))
338		// Add it, in reverse, to the buffer
339		for i := len(v4) - 1; i >= 0; i-- {
340			buf = strconv.AppendInt(buf, int64(v4[i]), 10)
341			buf = append(buf, '.')
342		}
343		// Append "in-addr.arpa." and return (buf already has the final .)
344		buf = append(buf, "in-addr.arpa."...)
345		return string(buf), nil
346	}
347	// Must be IPv6
348	buf := make([]byte, 0, net.IPv6len*4+len("ip6.arpa."))
349	// Add it, in reverse, to the buffer
350	for i := len(ip) - 1; i >= 0; i-- {
351		v := ip[i]
352		buf = append(buf, hexDigit[v&0xF], '.', hexDigit[v>>4], '.')
353	}
354	// Append "ip6.arpa." and return (buf already has the final .)
355	buf = append(buf, "ip6.arpa."...)
356	return string(buf), nil
357}
358
359// String returns the string representation for the type t.
360func (t Type) String() string {
361	if t1, ok := TypeToString[uint16(t)]; ok {
362		return t1
363	}
364	return "TYPE" + strconv.Itoa(int(t))
365}
366
367// String returns the string representation for the class c.
368func (c Class) String() string {
369	if s, ok := ClassToString[uint16(c)]; ok {
370		// Only emit mnemonics when they are unambiguous, specially ANY is in both.
371		if _, ok := StringToType[s]; !ok {
372			return s
373		}
374	}
375	return "CLASS" + strconv.Itoa(int(c))
376}
377
378// String returns the string representation for the name n.
379func (n Name) String() string {
380	return sprintName(string(n))
381}
382