1package dns
2
3import (
4	"bufio"
5	"crypto"
6	"crypto/ecdsa"
7	"crypto/rsa"
8	"io"
9	"math/big"
10	"strconv"
11	"strings"
12
13	"golang.org/x/crypto/ed25519"
14)
15
16// NewPrivateKey returns a PrivateKey by parsing the string s.
17// s should be in the same form of the BIND private key files.
18func (k *DNSKEY) NewPrivateKey(s string) (crypto.PrivateKey, error) {
19	if s == "" || s[len(s)-1] != '\n' { // We need a closing newline
20		return k.ReadPrivateKey(strings.NewReader(s+"\n"), "")
21	}
22	return k.ReadPrivateKey(strings.NewReader(s), "")
23}
24
25// ReadPrivateKey reads a private key from the io.Reader q. The string file is
26// only used in error reporting.
27// The public key must be known, because some cryptographic algorithms embed
28// the public inside the privatekey.
29func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (crypto.PrivateKey, error) {
30	m, err := parseKey(q, file)
31	if m == nil {
32		return nil, err
33	}
34	if _, ok := m["private-key-format"]; !ok {
35		return nil, ErrPrivKey
36	}
37	if m["private-key-format"] != "v1.2" && m["private-key-format"] != "v1.3" {
38		return nil, ErrPrivKey
39	}
40	// TODO(mg): check if the pubkey matches the private key
41	algo, err := strconv.ParseUint(strings.SplitN(m["algorithm"], " ", 2)[0], 10, 8)
42	if err != nil {
43		return nil, ErrPrivKey
44	}
45	switch uint8(algo) {
46	case RSAMD5, DSA, DSANSEC3SHA1:
47		return nil, ErrAlg
48	case RSASHA1:
49		fallthrough
50	case RSASHA1NSEC3SHA1:
51		fallthrough
52	case RSASHA256:
53		fallthrough
54	case RSASHA512:
55		priv, err := readPrivateKeyRSA(m)
56		if err != nil {
57			return nil, err
58		}
59		pub := k.publicKeyRSA()
60		if pub == nil {
61			return nil, ErrKey
62		}
63		priv.PublicKey = *pub
64		return priv, nil
65	case ECCGOST:
66		return nil, ErrPrivKey
67	case ECDSAP256SHA256:
68		fallthrough
69	case ECDSAP384SHA384:
70		priv, err := readPrivateKeyECDSA(m)
71		if err != nil {
72			return nil, err
73		}
74		pub := k.publicKeyECDSA()
75		if pub == nil {
76			return nil, ErrKey
77		}
78		priv.PublicKey = *pub
79		return priv, nil
80	case ED25519:
81		return readPrivateKeyED25519(m)
82	default:
83		return nil, ErrPrivKey
84	}
85}
86
87// Read a private key (file) string and create a public key. Return the private key.
88func readPrivateKeyRSA(m map[string]string) (*rsa.PrivateKey, error) {
89	p := new(rsa.PrivateKey)
90	p.Primes = []*big.Int{nil, nil}
91	for k, v := range m {
92		switch k {
93		case "modulus", "publicexponent", "privateexponent", "prime1", "prime2":
94			v1, err := fromBase64([]byte(v))
95			if err != nil {
96				return nil, err
97			}
98			switch k {
99			case "modulus":
100				p.PublicKey.N = new(big.Int).SetBytes(v1)
101			case "publicexponent":
102				i := new(big.Int).SetBytes(v1)
103				p.PublicKey.E = int(i.Int64()) // int64 should be large enough
104			case "privateexponent":
105				p.D = new(big.Int).SetBytes(v1)
106			case "prime1":
107				p.Primes[0] = new(big.Int).SetBytes(v1)
108			case "prime2":
109				p.Primes[1] = new(big.Int).SetBytes(v1)
110			}
111		case "exponent1", "exponent2", "coefficient":
112			// not used in Go (yet)
113		case "created", "publish", "activate":
114			// not used in Go (yet)
115		}
116	}
117	return p, nil
118}
119
120func readPrivateKeyECDSA(m map[string]string) (*ecdsa.PrivateKey, error) {
121	p := new(ecdsa.PrivateKey)
122	p.D = new(big.Int)
123	// TODO: validate that the required flags are present
124	for k, v := range m {
125		switch k {
126		case "privatekey":
127			v1, err := fromBase64([]byte(v))
128			if err != nil {
129				return nil, err
130			}
131			p.D.SetBytes(v1)
132		case "created", "publish", "activate":
133			/* not used in Go (yet) */
134		}
135	}
136	return p, nil
137}
138
139func readPrivateKeyED25519(m map[string]string) (ed25519.PrivateKey, error) {
140	var p ed25519.PrivateKey
141	// TODO: validate that the required flags are present
142	for k, v := range m {
143		switch k {
144		case "privatekey":
145			p1, err := fromBase64([]byte(v))
146			if err != nil {
147				return nil, err
148			}
149			if len(p1) != ed25519.SeedSize {
150				return nil, ErrPrivKey
151			}
152			p = ed25519.NewKeyFromSeed(p1)
153		case "created", "publish", "activate":
154			/* not used in Go (yet) */
155		}
156	}
157	return p, nil
158}
159
160// parseKey reads a private key from r. It returns a map[string]string,
161// with the key-value pairs, or an error when the file is not correct.
162func parseKey(r io.Reader, file string) (map[string]string, error) {
163	m := make(map[string]string)
164	var k string
165
166	c := newKLexer(r)
167
168	for l, ok := c.Next(); ok; l, ok = c.Next() {
169		// It should alternate
170		switch l.value {
171		case zKey:
172			k = l.token
173		case zValue:
174			if k == "" {
175				return nil, &ParseError{file, "no private key seen", l}
176			}
177
178			m[strings.ToLower(k)] = l.token
179			k = ""
180		}
181	}
182
183	// Surface any read errors from r.
184	if err := c.Err(); err != nil {
185		return nil, &ParseError{file: file, err: err.Error()}
186	}
187
188	return m, nil
189}
190
// klexer is the lexer state used to split private key input into key and
// value tokens.
type klexer struct {
	// br supplies the input one byte at a time.
	br io.ByteReader

	// readErr holds the first error returned by br, io.EOF included. Once it
	// is set no further bytes are read.
	readErr error

	// line and column give the position of the byte read last. The line
	// counter is only advanced when the first byte after a newline is read.
	line   int
	column int

	// key is true while the lexer is collecting a key and false while it is
	// collecting the value that follows it.
	key bool

	// eol records that the byte read last was a newline (end-of-line).
	eol bool
}
203
204func newKLexer(r io.Reader) *klexer {
205	br, ok := r.(io.ByteReader)
206	if !ok {
207		br = bufio.NewReaderSize(r, 1024)
208	}
209
210	return &klexer{
211		br: br,
212
213		line: 1,
214
215		key: true,
216	}
217}
218
219func (kl *klexer) Err() error {
220	if kl.readErr == io.EOF {
221		return nil
222	}
223
224	return kl.readErr
225}
226
227// readByte returns the next byte from the input
228func (kl *klexer) readByte() (byte, bool) {
229	if kl.readErr != nil {
230		return 0, false
231	}
232
233	c, err := kl.br.ReadByte()
234	if err != nil {
235		kl.readErr = err
236		return 0, false
237	}
238
239	// delay the newline handling until the next token is delivered,
240	// fixes off-by-one errors when reporting a parse error.
241	if kl.eol {
242		kl.line++
243		kl.column = 0
244		kl.eol = false
245	}
246
247	if c == '\n' {
248		kl.eol = true
249	} else {
250		kl.column++
251	}
252
253	return c, true
254}
255
256func (kl *klexer) Next() (lex, bool) {
257	var (
258		l lex
259
260		str strings.Builder
261
262		commt bool
263	)
264
265	for x, ok := kl.readByte(); ok; x, ok = kl.readByte() {
266		l.line, l.column = kl.line, kl.column
267
268		switch x {
269		case ':':
270			if commt || !kl.key {
271				break
272			}
273
274			kl.key = false
275
276			// Next token is a space, eat it
277			kl.readByte()
278
279			l.value = zKey
280			l.token = str.String()
281			return l, true
282		case ';':
283			commt = true
284		case '\n':
285			if commt {
286				// Reset a comment
287				commt = false
288			}
289
290			if kl.key && str.Len() == 0 {
291				// ignore empty lines
292				break
293			}
294
295			kl.key = true
296
297			l.value = zValue
298			l.token = str.String()
299			return l, true
300		default:
301			if commt {
302				break
303			}
304
305			str.WriteByte(x)
306		}
307	}
308
309	if kl.readErr != nil && kl.readErr != io.EOF {
310		// Don't return any tokens after a read error occurs.
311		return lex{value: zEOF}, false
312	}
313
314	if str.Len() > 0 {
315		// Send remainder
316		l.value = zValue
317		l.token = str.String()
318		return l, true
319	}
320
321	return lex{value: zEOF}, false
322}
323