1package dns
2
3import (
4	"bytes"
5	"crypto"
6	"crypto/dsa"
7	"crypto/ecdsa"
8	"crypto/elliptic"
9	_ "crypto/md5"
10	"crypto/rand"
11	"crypto/rsa"
12	_ "crypto/sha1"
13	_ "crypto/sha256"
14	_ "crypto/sha512"
15	"encoding/asn1"
16	"encoding/hex"
17	"math/big"
18	"sort"
19	"strings"
20	"time"
21)
22
// DNSSEC algorithm codes, as stored in the Algorithm field of DNSKEY, DS
// and RRSIG records. Codes 4, 9 and 11 are intentionally left unnamed
// (RFC 6725, section 2.1).
const (
	RSAMD5           uint8 = 1
	DH               uint8 = 2
	DSA              uint8 = 3
	RSASHA1          uint8 = 5 // 4 is skipped, RFC 6725, section 2.1
	DSANSEC3SHA1     uint8 = 6
	RSASHA1NSEC3SHA1 uint8 = 7
	RSASHA256        uint8 = 8
	RSASHA512        uint8 = 10 // 9 is skipped, RFC 6725, section 2.1
	ECCGOST          uint8 = 12 // 11 is skipped, RFC 6725, section 2.1
	ECDSAP256SHA256  uint8 = 13
	ECDSAP384SHA384  uint8 = 14
	INDIRECT         uint8 = 252
	PRIVATEDNS       uint8 = 253 // Private (experimental keys)
	PRIVATEOID       uint8 = 254
)
44
45// Map for algorithm names.
46var AlgorithmToString = map[uint8]string{
47	RSAMD5:           "RSAMD5",
48	DH:               "DH",
49	DSA:              "DSA",
50	RSASHA1:          "RSASHA1",
51	DSANSEC3SHA1:     "DSA-NSEC3-SHA1",
52	RSASHA1NSEC3SHA1: "RSASHA1-NSEC3-SHA1",
53	RSASHA256:        "RSASHA256",
54	RSASHA512:        "RSASHA512",
55	ECCGOST:          "ECC-GOST",
56	ECDSAP256SHA256:  "ECDSAP256SHA256",
57	ECDSAP384SHA384:  "ECDSAP384SHA384",
58	INDIRECT:         "INDIRECT",
59	PRIVATEDNS:       "PRIVATEDNS",
60	PRIVATEOID:       "PRIVATEOID",
61}
62
63// Map of algorithm strings.
64var StringToAlgorithm = reverseInt8(AlgorithmToString)
65
66// Map of algorithm crypto hashes.
67var AlgorithmToHash = map[uint8]crypto.Hash{
68	RSAMD5:           crypto.MD5, // Deprecated in RFC 6725
69	RSASHA1:          crypto.SHA1,
70	RSASHA1NSEC3SHA1: crypto.SHA1,
71	RSASHA256:        crypto.SHA256,
72	ECDSAP256SHA256:  crypto.SHA256,
73	ECDSAP384SHA384:  crypto.SHA384,
74	RSASHA512:        crypto.SHA512,
75}
76
// DNSSEC digest type codes, used as the DigestType of a DS record.
const (
	SHA1   uint8 = 1 // RFC 4034
	SHA256 uint8 = 2 // RFC 4509
	GOST94 uint8 = 3 // RFC 5933
	SHA384 uint8 = 4 // Experimental
	SHA512 uint8 = 5 // Experimental
)
86
87// Map for hash names.
88var HashToString = map[uint8]string{
89	SHA1:   "SHA1",
90	SHA256: "SHA256",
91	GOST94: "GOST94",
92	SHA384: "SHA384",
93	SHA512: "SHA512",
94}
95
96// Map of hash strings.
97var StringToHash = reverseInt8(HashToString)
98
// Bit masks for the Flags field of a DNSKEY record.
const (
	SEP    = 0x0001
	REVOKE = 0x0080
	ZONE   = 0x0100
)
105
// rrsigWireFmt holds the RRSIG rdata without the signature itself.
// The RRSIG needs to be converted to wire format with that part of the
// rdata missing, both when signing and when verifying. Use this struct to
// ease the conversion (and re-use the pack/unpack functions).
//
// The fields are filled in one by one from the RRSIG by Sign and Verify and
// then handed to PackStruct; keep them in this order.
type rrsigWireFmt struct {
	TypeCovered uint16
	Algorithm   uint8
	Labels      uint8
	OrigTtl     uint32
	Expiration  uint32
	Inception   uint32
	KeyTag      uint16
	SignerName  string `dns:"domain-name"`
	/* No Signature */
}
120
// dnskeyWireFmt is used for converting a DNSKEY's rdata to wire format.
// KeyTag and ToDS copy the four rdata fields of the DNSKEY into it and pack
// it with PackStruct.
type dnskeyWireFmt struct {
	Flags     uint16
	Protocol  uint8
	Algorithm uint8
	PublicKey string `dns:"base64"`
	/* Nothing is left out */
}
129
// divRoundUp divides a by b and rounds the quotient up instead of down
// (for non-negative a and positive b).
func divRoundUp(a, b int) int {
	padded := a + (b - 1)
	return padded / b
}
133
134// KeyTag calculates the keytag (or key-id) of the DNSKEY.
135func (k *DNSKEY) KeyTag() uint16 {
136	if k == nil {
137		return 0
138	}
139	var keytag int
140	switch k.Algorithm {
141	case RSAMD5:
142		// Look at the bottom two bytes of the modules, which the last
143		// item in the pubkey. We could do this faster by looking directly
144		// at the base64 values. But I'm lazy.
145		modulus, _ := fromBase64([]byte(k.PublicKey))
146		if len(modulus) > 1 {
147			x, _ := unpackUint16(modulus, len(modulus)-2)
148			keytag = int(x)
149		}
150	default:
151		keywire := new(dnskeyWireFmt)
152		keywire.Flags = k.Flags
153		keywire.Protocol = k.Protocol
154		keywire.Algorithm = k.Algorithm
155		keywire.PublicKey = k.PublicKey
156		wire := make([]byte, DefaultMsgSize)
157		n, err := PackStruct(keywire, wire, 0)
158		if err != nil {
159			return 0
160		}
161		wire = wire[:n]
162		for i, v := range wire {
163			if i&1 != 0 {
164				keytag += int(v) // must be larger than uint32
165			} else {
166				keytag += int(v) << 8
167			}
168		}
169		keytag += (keytag >> 16) & 0xFFFF
170		keytag &= 0xFFFF
171	}
172	return uint16(keytag)
173}
174
175// ToDS converts a DNSKEY record to a DS record.
176func (k *DNSKEY) ToDS(h uint8) *DS {
177	if k == nil {
178		return nil
179	}
180	ds := new(DS)
181	ds.Hdr.Name = k.Hdr.Name
182	ds.Hdr.Class = k.Hdr.Class
183	ds.Hdr.Rrtype = TypeDS
184	ds.Hdr.Ttl = k.Hdr.Ttl
185	ds.Algorithm = k.Algorithm
186	ds.DigestType = h
187	ds.KeyTag = k.KeyTag()
188
189	keywire := new(dnskeyWireFmt)
190	keywire.Flags = k.Flags
191	keywire.Protocol = k.Protocol
192	keywire.Algorithm = k.Algorithm
193	keywire.PublicKey = k.PublicKey
194	wire := make([]byte, DefaultMsgSize)
195	n, err := PackStruct(keywire, wire, 0)
196	if err != nil {
197		return nil
198	}
199	wire = wire[:n]
200
201	owner := make([]byte, 255)
202	off, err1 := PackDomainName(strings.ToLower(k.Hdr.Name), owner, 0, nil, false)
203	if err1 != nil {
204		return nil
205	}
206	owner = owner[:off]
207	// RFC4034:
208	// digest = digest_algorithm( DNSKEY owner name | DNSKEY RDATA);
209	// "|" denotes concatenation
210	// DNSKEY RDATA = Flags | Protocol | Algorithm | Public Key.
211
212	// digest buffer
213	digest := append(owner, wire...) // another copy
214
215	var hash crypto.Hash
216	switch h {
217	case SHA1:
218		hash = crypto.SHA1
219	case SHA256:
220		hash = crypto.SHA256
221	case SHA384:
222		hash = crypto.SHA384
223	case SHA512:
224		hash = crypto.SHA512
225	default:
226		return nil
227	}
228
229	s := hash.New()
230	s.Write(digest)
231	ds.Digest = hex.EncodeToString(s.Sum(nil))
232	return ds
233}
234
235// ToCDNSKEY converts a DNSKEY record to a CDNSKEY record.
236func (k *DNSKEY) ToCDNSKEY() *CDNSKEY {
237	c := &CDNSKEY{DNSKEY: *k}
238	c.Hdr = *k.Hdr.copyHeader()
239	c.Hdr.Rrtype = TypeCDNSKEY
240	return c
241}
242
243// ToCDS converts a DS record to a CDS record.
244func (d *DS) ToCDS() *CDS {
245	c := &CDS{DS: *d}
246	c.Hdr = *d.Hdr.copyHeader()
247	c.Hdr.Rrtype = TypeCDS
248	return c
249}
250
251// Sign signs an RRSet. The signature needs to be filled in with the values:
252// Inception, Expiration, KeyTag, SignerName and Algorithm.  The rest is copied
253// from the RRset. Sign returns a non-nill error when the signing went OK.
254// There is no check if RRSet is a proper (RFC 2181) RRSet.  If OrigTTL is non
255// zero, it is used as-is, otherwise the TTL of the RRset is used as the
256// OrigTTL.
257func (rr *RRSIG) Sign(k crypto.Signer, rrset []RR) error {
258	if k == nil {
259		return ErrPrivKey
260	}
261	// s.Inception and s.Expiration may be 0 (rollover etc.), the rest must be set
262	if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
263		return ErrKey
264	}
265
266	rr.Hdr.Rrtype = TypeRRSIG
267	rr.Hdr.Name = rrset[0].Header().Name
268	rr.Hdr.Class = rrset[0].Header().Class
269	if rr.OrigTtl == 0 { // If set don't override
270		rr.OrigTtl = rrset[0].Header().Ttl
271	}
272	rr.TypeCovered = rrset[0].Header().Rrtype
273	rr.Labels = uint8(CountLabel(rrset[0].Header().Name))
274
275	if strings.HasPrefix(rrset[0].Header().Name, "*") {
276		rr.Labels-- // wildcard, remove from label count
277	}
278
279	sigwire := new(rrsigWireFmt)
280	sigwire.TypeCovered = rr.TypeCovered
281	sigwire.Algorithm = rr.Algorithm
282	sigwire.Labels = rr.Labels
283	sigwire.OrigTtl = rr.OrigTtl
284	sigwire.Expiration = rr.Expiration
285	sigwire.Inception = rr.Inception
286	sigwire.KeyTag = rr.KeyTag
287	// For signing, lowercase this name
288	sigwire.SignerName = strings.ToLower(rr.SignerName)
289
290	// Create the desired binary blob
291	signdata := make([]byte, DefaultMsgSize)
292	n, err := PackStruct(sigwire, signdata, 0)
293	if err != nil {
294		return err
295	}
296	signdata = signdata[:n]
297	wire, err := rawSignatureData(rrset, rr)
298	if err != nil {
299		return err
300	}
301	signdata = append(signdata, wire...)
302
303	hash, ok := AlgorithmToHash[rr.Algorithm]
304	if !ok {
305		return ErrAlg
306	}
307
308	h := hash.New()
309	h.Write(signdata)
310
311	signature, err := sign(k, h.Sum(nil), hash, rr.Algorithm)
312	if err != nil {
313		return err
314	}
315
316	rr.Signature = toBase64(signature)
317
318	return nil
319}
320
321func sign(k crypto.Signer, hashed []byte, hash crypto.Hash, alg uint8) ([]byte, error) {
322	signature, err := k.Sign(rand.Reader, hashed, hash)
323	if err != nil {
324		return nil, err
325	}
326
327	switch alg {
328	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512:
329		return signature, nil
330
331	case ECDSAP256SHA256, ECDSAP384SHA384:
332		ecdsaSignature := &struct {
333			R, S *big.Int
334		}{}
335		if _, err := asn1.Unmarshal(signature, ecdsaSignature); err != nil {
336			return nil, err
337		}
338
339		var intlen int
340		switch alg {
341		case ECDSAP256SHA256:
342			intlen = 32
343		case ECDSAP384SHA384:
344			intlen = 48
345		}
346
347		signature := intToBytes(ecdsaSignature.R, intlen)
348		signature = append(signature, intToBytes(ecdsaSignature.S, intlen)...)
349		return signature, nil
350
351	// There is no defined interface for what a DSA backed crypto.Signer returns
352	case DSA, DSANSEC3SHA1:
353		// 	t := divRoundUp(divRoundUp(p.PublicKey.Y.BitLen(), 8)-64, 8)
354		// 	signature := []byte{byte(t)}
355		// 	signature = append(signature, intToBytes(r1, 20)...)
356		// 	signature = append(signature, intToBytes(s1, 20)...)
357		// 	rr.Signature = signature
358	}
359
360	return nil, ErrAlg
361}
362
363// Verify validates an RRSet with the signature and key. This is only the
364// cryptographic test, the signature validity period must be checked separately.
365// This function copies the rdata of some RRs (to lowercase domain names) for the validation to work.
366func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
367	// First the easy checks
368	if !IsRRset(rrset) {
369		return ErrRRset
370	}
371	if rr.KeyTag != k.KeyTag() {
372		return ErrKey
373	}
374	if rr.Hdr.Class != k.Hdr.Class {
375		return ErrKey
376	}
377	if rr.Algorithm != k.Algorithm {
378		return ErrKey
379	}
380	if strings.ToLower(rr.SignerName) != strings.ToLower(k.Hdr.Name) {
381		return ErrKey
382	}
383	if k.Protocol != 3 {
384		return ErrKey
385	}
386
387	// IsRRset checked that we have at least one RR and that the RRs in
388	// the set have consistent type, class, and name. Also check that type and
389	// class matches the RRSIG record.
390	if rrset[0].Header().Class != rr.Hdr.Class {
391		return ErrRRset
392	}
393	if rrset[0].Header().Rrtype != rr.TypeCovered {
394		return ErrRRset
395	}
396
397	// RFC 4035 5.3.2.  Reconstructing the Signed Data
398	// Copy the sig, except the rrsig data
399	sigwire := new(rrsigWireFmt)
400	sigwire.TypeCovered = rr.TypeCovered
401	sigwire.Algorithm = rr.Algorithm
402	sigwire.Labels = rr.Labels
403	sigwire.OrigTtl = rr.OrigTtl
404	sigwire.Expiration = rr.Expiration
405	sigwire.Inception = rr.Inception
406	sigwire.KeyTag = rr.KeyTag
407	sigwire.SignerName = strings.ToLower(rr.SignerName)
408	// Create the desired binary blob
409	signeddata := make([]byte, DefaultMsgSize)
410	n, err := PackStruct(sigwire, signeddata, 0)
411	if err != nil {
412		return err
413	}
414	signeddata = signeddata[:n]
415	wire, err := rawSignatureData(rrset, rr)
416	if err != nil {
417		return err
418	}
419	signeddata = append(signeddata, wire...)
420
421	sigbuf := rr.sigBuf()           // Get the binary signature data
422	if rr.Algorithm == PRIVATEDNS { // PRIVATEOID
423		// TODO(miek)
424		// remove the domain name and assume its ours?
425	}
426
427	hash, ok := AlgorithmToHash[rr.Algorithm]
428	if !ok {
429		return ErrAlg
430	}
431
432	switch rr.Algorithm {
433	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512, RSAMD5:
434		// TODO(mg): this can be done quicker, ie. cache the pubkey data somewhere??
435		pubkey := k.publicKeyRSA() // Get the key
436		if pubkey == nil {
437			return ErrKey
438		}
439
440		h := hash.New()
441		h.Write(signeddata)
442		return rsa.VerifyPKCS1v15(pubkey, hash, h.Sum(nil), sigbuf)
443
444	case ECDSAP256SHA256, ECDSAP384SHA384:
445		pubkey := k.publicKeyECDSA()
446		if pubkey == nil {
447			return ErrKey
448		}
449
450		// Split sigbuf into the r and s coordinates
451		r := new(big.Int).SetBytes(sigbuf[:len(sigbuf)/2])
452		s := new(big.Int).SetBytes(sigbuf[len(sigbuf)/2:])
453
454		h := hash.New()
455		h.Write(signeddata)
456		if ecdsa.Verify(pubkey, h.Sum(nil), r, s) {
457			return nil
458		}
459		return ErrSig
460
461	default:
462		return ErrAlg
463	}
464}
465
466// ValidityPeriod uses RFC1982 serial arithmetic to calculate
467// if a signature period is valid. If t is the zero time, the
468// current time is taken other t is. Returns true if the signature
469// is valid at the given time, otherwise returns false.
470func (rr *RRSIG) ValidityPeriod(t time.Time) bool {
471	var utc int64
472	if t.IsZero() {
473		utc = time.Now().UTC().Unix()
474	} else {
475		utc = t.UTC().Unix()
476	}
477	modi := (int64(rr.Inception) - utc) / year68
478	mode := (int64(rr.Expiration) - utc) / year68
479	ti := int64(rr.Inception) + (modi * year68)
480	te := int64(rr.Expiration) + (mode * year68)
481	return ti <= utc && utc <= te
482}
483
484// Return the signatures base64 encodedig sigdata as a byte slice.
485func (rr *RRSIG) sigBuf() []byte {
486	sigbuf, err := fromBase64([]byte(rr.Signature))
487	if err != nil {
488		return nil
489	}
490	return sigbuf
491}
492
493// publicKeyRSA returns the RSA public key from a DNSKEY record.
494func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey {
495	keybuf, err := fromBase64([]byte(k.PublicKey))
496	if err != nil {
497		return nil
498	}
499
500	// RFC 2537/3110, section 2. RSA Public KEY Resource Records
501	// Length is in the 0th byte, unless its zero, then it
502	// it in bytes 1 and 2 and its a 16 bit number
503	explen := uint16(keybuf[0])
504	keyoff := 1
505	if explen == 0 {
506		explen = uint16(keybuf[1])<<8 | uint16(keybuf[2])
507		keyoff = 3
508	}
509	pubkey := new(rsa.PublicKey)
510
511	pubkey.N = big.NewInt(0)
512	shift := uint64((explen - 1) * 8)
513	expo := uint64(0)
514	for i := int(explen - 1); i > 0; i-- {
515		expo += uint64(keybuf[keyoff+i]) << shift
516		shift -= 8
517	}
518	// Remainder
519	expo += uint64(keybuf[keyoff])
520	if expo > 2<<31 {
521		// Larger expo than supported.
522		// println("dns: F5 primes (or larger) are not supported")
523		return nil
524	}
525	pubkey.E = int(expo)
526
527	pubkey.N.SetBytes(keybuf[keyoff+int(explen):])
528	return pubkey
529}
530
531// publicKeyECDSA returns the Curve public key from the DNSKEY record.
532func (k *DNSKEY) publicKeyECDSA() *ecdsa.PublicKey {
533	keybuf, err := fromBase64([]byte(k.PublicKey))
534	if err != nil {
535		return nil
536	}
537	pubkey := new(ecdsa.PublicKey)
538	switch k.Algorithm {
539	case ECDSAP256SHA256:
540		pubkey.Curve = elliptic.P256()
541		if len(keybuf) != 64 {
542			// wrongly encoded key
543			return nil
544		}
545	case ECDSAP384SHA384:
546		pubkey.Curve = elliptic.P384()
547		if len(keybuf) != 96 {
548			// Wrongly encoded key
549			return nil
550		}
551	}
552	pubkey.X = big.NewInt(0)
553	pubkey.X.SetBytes(keybuf[:len(keybuf)/2])
554	pubkey.Y = big.NewInt(0)
555	pubkey.Y.SetBytes(keybuf[len(keybuf)/2:])
556	return pubkey
557}
558
559func (k *DNSKEY) publicKeyDSA() *dsa.PublicKey {
560	keybuf, err := fromBase64([]byte(k.PublicKey))
561	if err != nil {
562		return nil
563	}
564	if len(keybuf) < 22 {
565		return nil
566	}
567	t, keybuf := int(keybuf[0]), keybuf[1:]
568	size := 64 + t*8
569	q, keybuf := keybuf[:20], keybuf[20:]
570	if len(keybuf) != 3*size {
571		return nil
572	}
573	p, keybuf := keybuf[:size], keybuf[size:]
574	g, y := keybuf[:size], keybuf[size:]
575	pubkey := new(dsa.PublicKey)
576	pubkey.Parameters.Q = big.NewInt(0).SetBytes(q)
577	pubkey.Parameters.P = big.NewInt(0).SetBytes(p)
578	pubkey.Parameters.G = big.NewInt(0).SetBytes(g)
579	pubkey.Y = big.NewInt(0).SetBytes(y)
580	return pubkey
581}
582
583type wireSlice [][]byte
584
585func (p wireSlice) Len() int      { return len(p) }
586func (p wireSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
587func (p wireSlice) Less(i, j int) bool {
588	_, ioff, _ := UnpackDomainName(p[i], 0)
589	_, joff, _ := UnpackDomainName(p[j], 0)
590	return bytes.Compare(p[i][ioff+10:], p[j][joff+10:]) < 0
591}
592
593// Return the raw signature data.
594func rawSignatureData(rrset []RR, s *RRSIG) (buf []byte, err error) {
595	wires := make(wireSlice, len(rrset))
596	for i, r := range rrset {
597		r1 := r.copy()
598		r1.Header().Ttl = s.OrigTtl
599		labels := SplitDomainName(r1.Header().Name)
600		// 6.2. Canonical RR Form. (4) - wildcards
601		if len(labels) > int(s.Labels) {
602			// Wildcard
603			r1.Header().Name = "*." + strings.Join(labels[len(labels)-int(s.Labels):], ".") + "."
604		}
605		// RFC 4034: 6.2.  Canonical RR Form. (2) - domain name to lowercase
606		r1.Header().Name = strings.ToLower(r1.Header().Name)
607		// 6.2. Canonical RR Form. (3) - domain rdata to lowercase.
608		//   NS, MD, MF, CNAME, SOA, MB, MG, MR, PTR,
609		//   HINFO, MINFO, MX, RP, AFSDB, RT, SIG, PX, NXT, NAPTR, KX,
610		//   SRV, DNAME, A6
611		//
612		// RFC 6840 - Clarifications and Implementation Notes for DNS Security (DNSSEC):
613		//	Section 6.2 of [RFC4034] also erroneously lists HINFO as a record
614		//	that needs conversion to lowercase, and twice at that.  Since HINFO
615		//	records contain no domain names, they are not subject to case
616		//	conversion.
617		switch x := r1.(type) {
618		case *NS:
619			x.Ns = strings.ToLower(x.Ns)
620		case *CNAME:
621			x.Target = strings.ToLower(x.Target)
622		case *SOA:
623			x.Ns = strings.ToLower(x.Ns)
624			x.Mbox = strings.ToLower(x.Mbox)
625		case *MB:
626			x.Mb = strings.ToLower(x.Mb)
627		case *MG:
628			x.Mg = strings.ToLower(x.Mg)
629		case *MR:
630			x.Mr = strings.ToLower(x.Mr)
631		case *PTR:
632			x.Ptr = strings.ToLower(x.Ptr)
633		case *MINFO:
634			x.Rmail = strings.ToLower(x.Rmail)
635			x.Email = strings.ToLower(x.Email)
636		case *MX:
637			x.Mx = strings.ToLower(x.Mx)
638		case *NAPTR:
639			x.Replacement = strings.ToLower(x.Replacement)
640		case *KX:
641			x.Exchanger = strings.ToLower(x.Exchanger)
642		case *SRV:
643			x.Target = strings.ToLower(x.Target)
644		case *DNAME:
645			x.Target = strings.ToLower(x.Target)
646		}
647		// 6.2. Canonical RR Form. (5) - origTTL
648		wire := make([]byte, r1.len()+1) // +1 to be safe(r)
649		off, err1 := PackRR(r1, wire, 0, nil, false)
650		if err1 != nil {
651			return nil, err1
652		}
653		wire = wire[:off]
654		wires[i] = wire
655	}
656	sort.Sort(wires)
657	for i, wire := range wires {
658		if i > 0 && bytes.Equal(wire, wires[i-1]) {
659			continue
660		}
661		buf = append(buf, wire...)
662	}
663	return buf, nil
664}
665