1package dns
2
3import (
4	"crypto"
5	"crypto/dsa"
6	"crypto/ecdsa"
7	"crypto/rsa"
8	"math/big"
9	"strings"
10	"time"
11)
12
13// Sign signs a dns.Msg. It fills the signature with the appropriate data.
14// The SIG record should have the SignerName, KeyTag, Algorithm, Inception
15// and Expiration set.
16func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) {
17	if k == nil {
18		return nil, ErrPrivKey
19	}
20	if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
21		return nil, ErrKey
22	}
23	rr.Header().Rrtype = TypeSIG
24	rr.Header().Class = ClassANY
25	rr.Header().Ttl = 0
26	rr.Header().Name = "."
27	rr.OrigTtl = 0
28	rr.TypeCovered = 0
29	rr.Labels = 0
30
31	buf := make([]byte, m.Len()+rr.len())
32	mbuf, err := m.PackBuffer(buf)
33	if err != nil {
34		return nil, err
35	}
36	if &buf[0] != &mbuf[0] {
37		return nil, ErrBuf
38	}
39	off, err := PackRR(rr, buf, len(mbuf), nil, false)
40	if err != nil {
41		return nil, err
42	}
43	buf = buf[:off:cap(buf)]
44
45	hash, ok := AlgorithmToHash[rr.Algorithm]
46	if !ok {
47		return nil, ErrAlg
48	}
49
50	hasher := hash.New()
51	// Write SIG rdata
52	hasher.Write(buf[len(mbuf)+1+2+2+4+2:])
53	// Write message
54	hasher.Write(buf[:len(mbuf)])
55
56	signature, err := sign(k, hasher.Sum(nil), hash, rr.Algorithm)
57	if err != nil {
58		return nil, err
59	}
60
61	rr.Signature = toBase64(signature)
62	sig := string(signature)
63
64	buf = append(buf, sig...)
65	if len(buf) > int(^uint16(0)) {
66		return nil, ErrBuf
67	}
68	// Adjust sig data length
69	rdoff := len(mbuf) + 1 + 2 + 2 + 4
70	rdlen, _ := unpackUint16(buf, rdoff)
71	rdlen += uint16(len(sig))
72	buf[rdoff], buf[rdoff+1] = packUint16(rdlen)
73	// Adjust additional count
74	adc, _ := unpackUint16(buf, 10)
75	adc++
76	buf[10], buf[11] = packUint16(adc)
77	return buf, nil
78}
79
80// Verify validates the message buf using the key k.
81// It's assumed that buf is a valid message from which rr was unpacked.
82func (rr *SIG) Verify(k *KEY, buf []byte) error {
83	if k == nil {
84		return ErrKey
85	}
86	if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
87		return ErrKey
88	}
89
90	var hash crypto.Hash
91	switch rr.Algorithm {
92	case DSA, RSASHA1:
93		hash = crypto.SHA1
94	case RSASHA256, ECDSAP256SHA256:
95		hash = crypto.SHA256
96	case ECDSAP384SHA384:
97		hash = crypto.SHA384
98	case RSASHA512:
99		hash = crypto.SHA512
100	default:
101		return ErrAlg
102	}
103	hasher := hash.New()
104
105	buflen := len(buf)
106	qdc, _ := unpackUint16(buf, 4)
107	anc, _ := unpackUint16(buf, 6)
108	auc, _ := unpackUint16(buf, 8)
109	adc, offset := unpackUint16(buf, 10)
110	var err error
111	for i := uint16(0); i < qdc && offset < buflen; i++ {
112		_, offset, err = UnpackDomainName(buf, offset)
113		if err != nil {
114			return err
115		}
116		// Skip past Type and Class
117		offset += 2 + 2
118	}
119	for i := uint16(1); i < anc+auc+adc && offset < buflen; i++ {
120		_, offset, err = UnpackDomainName(buf, offset)
121		if err != nil {
122			return err
123		}
124		// Skip past Type, Class and TTL
125		offset += 2 + 2 + 4
126		if offset+1 >= buflen {
127			continue
128		}
129		var rdlen uint16
130		rdlen, offset = unpackUint16(buf, offset)
131		offset += int(rdlen)
132	}
133	if offset >= buflen {
134		return &Error{err: "overflowing unpacking signed message"}
135	}
136
137	// offset should be just prior to SIG
138	bodyend := offset
139	// owner name SHOULD be root
140	_, offset, err = UnpackDomainName(buf, offset)
141	if err != nil {
142		return err
143	}
144	// Skip Type, Class, TTL, RDLen
145	offset += 2 + 2 + 4 + 2
146	sigstart := offset
147	// Skip Type Covered, Algorithm, Labels, Original TTL
148	offset += 2 + 1 + 1 + 4
149	if offset+4+4 >= buflen {
150		return &Error{err: "overflow unpacking signed message"}
151	}
152	expire := uint32(buf[offset])<<24 | uint32(buf[offset+1])<<16 | uint32(buf[offset+2])<<8 | uint32(buf[offset+3])
153	offset += 4
154	incept := uint32(buf[offset])<<24 | uint32(buf[offset+1])<<16 | uint32(buf[offset+2])<<8 | uint32(buf[offset+3])
155	offset += 4
156	now := uint32(time.Now().Unix())
157	if now < incept || now > expire {
158		return ErrTime
159	}
160	// Skip key tag
161	offset += 2
162	var signername string
163	signername, offset, err = UnpackDomainName(buf, offset)
164	if err != nil {
165		return err
166	}
167	// If key has come from the DNS name compression might
168	// have mangled the case of the name
169	if strings.ToLower(signername) != strings.ToLower(k.Header().Name) {
170		return &Error{err: "signer name doesn't match key name"}
171	}
172	sigend := offset
173	hasher.Write(buf[sigstart:sigend])
174	hasher.Write(buf[:10])
175	hasher.Write([]byte{
176		byte((adc - 1) << 8),
177		byte(adc - 1),
178	})
179	hasher.Write(buf[12:bodyend])
180
181	hashed := hasher.Sum(nil)
182	sig := buf[sigend:]
183	switch k.Algorithm {
184	case DSA:
185		pk := k.publicKeyDSA()
186		sig = sig[1:]
187		r := big.NewInt(0)
188		r.SetBytes(sig[:len(sig)/2])
189		s := big.NewInt(0)
190		s.SetBytes(sig[len(sig)/2:])
191		if pk != nil {
192			if dsa.Verify(pk, hashed, r, s) {
193				return nil
194			}
195			return ErrSig
196		}
197	case RSASHA1, RSASHA256, RSASHA512:
198		pk := k.publicKeyRSA()
199		if pk != nil {
200			return rsa.VerifyPKCS1v15(pk, hash, hashed, sig)
201		}
202	case ECDSAP256SHA256, ECDSAP384SHA384:
203		pk := k.publicKeyECDSA()
204		r := big.NewInt(0)
205		r.SetBytes(sig[:len(sig)/2])
206		s := big.NewInt(0)
207		s.SetBytes(sig[len(sig)/2:])
208		if pk != nil {
209			if ecdsa.Verify(pk, hashed, r, s) {
210				return nil
211			}
212			return ErrSig
213		}
214	}
215	return ErrKeyAlg
216}
217