1package dns
2
3import (
4	"bufio"
5	"crypto"
6	"crypto/ecdsa"
7	"crypto/ed25519"
8	"crypto/rsa"
9	"io"
10	"math/big"
11	"strconv"
12	"strings"
13)
14
15// NewPrivateKey returns a PrivateKey by parsing the string s.
16// s should be in the same form of the BIND private key files.
17func (k *DNSKEY) NewPrivateKey(s string) (crypto.PrivateKey, error) {
18	if s == "" || s[len(s)-1] != '\n' { // We need a closing newline
19		return k.ReadPrivateKey(strings.NewReader(s+"\n"), "")
20	}
21	return k.ReadPrivateKey(strings.NewReader(s), "")
22}
23
24// ReadPrivateKey reads a private key from the io.Reader q. The string file is
25// only used in error reporting.
26// The public key must be known, because some cryptographic algorithms embed
27// the public inside the privatekey.
28func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (crypto.PrivateKey, error) {
29	m, err := parseKey(q, file)
30	if m == nil {
31		return nil, err
32	}
33	if _, ok := m["private-key-format"]; !ok {
34		return nil, ErrPrivKey
35	}
36	if m["private-key-format"] != "v1.2" && m["private-key-format"] != "v1.3" {
37		return nil, ErrPrivKey
38	}
39	// TODO(mg): check if the pubkey matches the private key
40	algo, err := strconv.ParseUint(strings.SplitN(m["algorithm"], " ", 2)[0], 10, 8)
41	if err != nil {
42		return nil, ErrPrivKey
43	}
44	switch uint8(algo) {
45	case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512:
46		priv, err := readPrivateKeyRSA(m)
47		if err != nil {
48			return nil, err
49		}
50		pub := k.publicKeyRSA()
51		if pub == nil {
52			return nil, ErrKey
53		}
54		priv.PublicKey = *pub
55		return priv, nil
56	case ECDSAP256SHA256, ECDSAP384SHA384:
57		priv, err := readPrivateKeyECDSA(m)
58		if err != nil {
59			return nil, err
60		}
61		pub := k.publicKeyECDSA()
62		if pub == nil {
63			return nil, ErrKey
64		}
65		priv.PublicKey = *pub
66		return priv, nil
67	case ED25519:
68		return readPrivateKeyED25519(m)
69	default:
70		return nil, ErrAlg
71	}
72}
73
74// Read a private key (file) string and create a public key. Return the private key.
75func readPrivateKeyRSA(m map[string]string) (*rsa.PrivateKey, error) {
76	p := new(rsa.PrivateKey)
77	p.Primes = []*big.Int{nil, nil}
78	for k, v := range m {
79		switch k {
80		case "modulus", "publicexponent", "privateexponent", "prime1", "prime2":
81			v1, err := fromBase64([]byte(v))
82			if err != nil {
83				return nil, err
84			}
85			switch k {
86			case "modulus":
87				p.PublicKey.N = new(big.Int).SetBytes(v1)
88			case "publicexponent":
89				i := new(big.Int).SetBytes(v1)
90				p.PublicKey.E = int(i.Int64()) // int64 should be large enough
91			case "privateexponent":
92				p.D = new(big.Int).SetBytes(v1)
93			case "prime1":
94				p.Primes[0] = new(big.Int).SetBytes(v1)
95			case "prime2":
96				p.Primes[1] = new(big.Int).SetBytes(v1)
97			}
98		case "exponent1", "exponent2", "coefficient":
99			// not used in Go (yet)
100		case "created", "publish", "activate":
101			// not used in Go (yet)
102		}
103	}
104	return p, nil
105}
106
107func readPrivateKeyECDSA(m map[string]string) (*ecdsa.PrivateKey, error) {
108	p := new(ecdsa.PrivateKey)
109	p.D = new(big.Int)
110	// TODO: validate that the required flags are present
111	for k, v := range m {
112		switch k {
113		case "privatekey":
114			v1, err := fromBase64([]byte(v))
115			if err != nil {
116				return nil, err
117			}
118			p.D.SetBytes(v1)
119		case "created", "publish", "activate":
120			/* not used in Go (yet) */
121		}
122	}
123	return p, nil
124}
125
126func readPrivateKeyED25519(m map[string]string) (ed25519.PrivateKey, error) {
127	var p ed25519.PrivateKey
128	// TODO: validate that the required flags are present
129	for k, v := range m {
130		switch k {
131		case "privatekey":
132			p1, err := fromBase64([]byte(v))
133			if err != nil {
134				return nil, err
135			}
136			if len(p1) != ed25519.SeedSize {
137				return nil, ErrPrivKey
138			}
139			p = ed25519.NewKeyFromSeed(p1)
140		case "created", "publish", "activate":
141			/* not used in Go (yet) */
142		}
143	}
144	return p, nil
145}
146
147// parseKey reads a private key from r. It returns a map[string]string,
148// with the key-value pairs, or an error when the file is not correct.
149func parseKey(r io.Reader, file string) (map[string]string, error) {
150	m := make(map[string]string)
151	var k string
152
153	c := newKLexer(r)
154
155	for l, ok := c.Next(); ok; l, ok = c.Next() {
156		// It should alternate
157		switch l.value {
158		case zKey:
159			k = l.token
160		case zValue:
161			if k == "" {
162				return nil, &ParseError{file, "no private key seen", l}
163			}
164
165			m[strings.ToLower(k)] = l.token
166			k = ""
167		}
168	}
169
170	// Surface any read errors from r.
171	if err := c.Err(); err != nil {
172		return nil, &ParseError{file: file, err: err.Error()}
173	}
174
175	return m, nil
176}
177
// klexer is the lexer state used while tokenizing a private key file into
// keys and values.
type klexer struct {
	br io.ByteReader // input, consumed one byte at a time

	readErr error // first error returned by br, including io.EOF

	// Position of the most recently read byte, used when reporting errors.
	line, column int

	key bool // true while the next token to deliver is a key
	eol bool // end-of-line: the last byte read was a newline
}
190
191func newKLexer(r io.Reader) *klexer {
192	br, ok := r.(io.ByteReader)
193	if !ok {
194		br = bufio.NewReaderSize(r, 1024)
195	}
196
197	return &klexer{
198		br: br,
199
200		line: 1,
201
202		key: true,
203	}
204}
205
206func (kl *klexer) Err() error {
207	if kl.readErr == io.EOF {
208		return nil
209	}
210
211	return kl.readErr
212}
213
214// readByte returns the next byte from the input
215func (kl *klexer) readByte() (byte, bool) {
216	if kl.readErr != nil {
217		return 0, false
218	}
219
220	c, err := kl.br.ReadByte()
221	if err != nil {
222		kl.readErr = err
223		return 0, false
224	}
225
226	// delay the newline handling until the next token is delivered,
227	// fixes off-by-one errors when reporting a parse error.
228	if kl.eol {
229		kl.line++
230		kl.column = 0
231		kl.eol = false
232	}
233
234	if c == '\n' {
235		kl.eol = true
236	} else {
237		kl.column++
238	}
239
240	return c, true
241}
242
243func (kl *klexer) Next() (lex, bool) {
244	var (
245		l lex
246
247		str strings.Builder
248
249		commt bool
250	)
251
252	for x, ok := kl.readByte(); ok; x, ok = kl.readByte() {
253		l.line, l.column = kl.line, kl.column
254
255		switch x {
256		case ':':
257			if commt || !kl.key {
258				break
259			}
260
261			kl.key = false
262
263			// Next token is a space, eat it
264			kl.readByte()
265
266			l.value = zKey
267			l.token = str.String()
268			return l, true
269		case ';':
270			commt = true
271		case '\n':
272			if commt {
273				// Reset a comment
274				commt = false
275			}
276
277			if kl.key && str.Len() == 0 {
278				// ignore empty lines
279				break
280			}
281
282			kl.key = true
283
284			l.value = zValue
285			l.token = str.String()
286			return l, true
287		default:
288			if commt {
289				break
290			}
291
292			str.WriteByte(x)
293		}
294	}
295
296	if kl.readErr != nil && kl.readErr != io.EOF {
297		// Don't return any tokens after a read error occurs.
298		return lex{value: zEOF}, false
299	}
300
301	if str.Len() > 0 {
302		// Send remainder
303		l.value = zValue
304		l.token = str.String()
305		return l, true
306	}
307
308	return lex{value: zEOF}, false
309}
310