xref: /openbsd/usr.bin/ssh/sshkey-xmss.c (revision 09467b48)
1 /* $OpenBSD: sshkey-xmss.c,v 1.8 2019/11/13 07:53:10 markus Exp $ */
2 /*
3  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include <sys/types.h>
27 #include <sys/uio.h>
28 
29 #include <stdio.h>
30 #include <string.h>
31 #include <unistd.h>
32 #include <fcntl.h>
33 #include <errno.h>
34 
35 #include "ssh2.h"
36 #include "ssherr.h"
37 #include "sshbuf.h"
38 #include "cipher.h"
39 #include "sshkey.h"
40 #include "sshkey-xmss.h"
41 #include "atomicio.h"
42 
43 #include "xmss_fast.h"
44 
45 /* opaque internal XMSS state */
46 #define XMSS_MAGIC		"xmss-state-v1"
47 #define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
48 struct ssh_xmss_state {
49 	xmss_params	params;
50 	u_int32_t	n, w, h, k;
51 
52 	bds_state	bds;
53 	u_char		*stack;
54 	u_int32_t	stackoffset;
55 	u_char		*stacklevels;
56 	u_char		*auth;
57 	u_char		*keep;
58 	u_char		*th_nodes;
59 	u_char		*retain;
60 	treehash_inst	*treehash;
61 
62 	u_int32_t	idx;		/* state read from file */
63 	u_int32_t	maxidx;		/* restricted # of signatures */
64 	int		have_state;	/* .state file exists */
65 	int		lockfd;		/* locked in sshkey_xmss_get_state() */
66 	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
67 	char		*enc_ciphername;/* encrypt state with cipher */
68 	u_char		*enc_keyiv;	/* encrypt state with key */
69 	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
70 };
71 
72 int	 sshkey_xmss_init_bds_state(struct sshkey *);
73 int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
74 void	 sshkey_xmss_free_bds(struct sshkey *);
75 int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
76 	    int *, sshkey_printfn *);
77 int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
78 	    struct sshbuf **);
79 int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
80 	    struct sshbuf **);
81 int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
82 int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
83 
/*
 * Forward a printf-style message to the optional callback "pr", which must
 * be in scope at the point of use; does nothing when pr is NULL.
 * Uses the standard C99 variadic form instead of the GNU "s..." extension.
 */
#define PRINT(...) do { if (pr) pr(__VA_ARGS__); } while (0)
85 
86 int
87 sshkey_xmss_init(struct sshkey *key, const char *name)
88 {
89 	struct ssh_xmss_state *state;
90 
91 	if (key->xmss_state != NULL)
92 		return SSH_ERR_INVALID_FORMAT;
93 	if (name == NULL)
94 		return SSH_ERR_INVALID_FORMAT;
95 	state = calloc(sizeof(struct ssh_xmss_state), 1);
96 	if (state == NULL)
97 		return SSH_ERR_ALLOC_FAIL;
98 	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
99 		state->n = 32;
100 		state->w = 16;
101 		state->h = 10;
102 	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
103 		state->n = 32;
104 		state->w = 16;
105 		state->h = 16;
106 	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
107 		state->n = 32;
108 		state->w = 16;
109 		state->h = 20;
110 	} else {
111 		free(state);
112 		return SSH_ERR_KEY_TYPE_UNKNOWN;
113 	}
114 	if ((key->xmss_name = strdup(name)) == NULL) {
115 		free(state);
116 		return SSH_ERR_ALLOC_FAIL;
117 	}
118 	state->k = 2;	/* XXX hardcoded */
119 	state->lockfd = -1;
120 	if (xmss_set_params(&state->params, state->n, state->h, state->w,
121 	    state->k) != 0) {
122 		free(state);
123 		return SSH_ERR_INVALID_FORMAT;
124 	}
125 	key->xmss_state = state;
126 	return 0;
127 }
128 
129 void
130 sshkey_xmss_free_state(struct sshkey *key)
131 {
132 	struct ssh_xmss_state *state = key->xmss_state;
133 
134 	sshkey_xmss_free_bds(key);
135 	if (state) {
136 		if (state->enc_keyiv) {
137 			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
138 			free(state->enc_keyiv);
139 		}
140 		free(state->enc_ciphername);
141 		free(state);
142 	}
143 	key->xmss_state = NULL;
144 }
145 
/* Magic string written at the start of a serialized BDS state. */
#define SSH_XMSS_K2_MAGIC	"k=2"
/*
 * Sizes of the BDS buffers, computed from the h, n and k members of the
 * structure (x) points to.  The argument is parenthesised so that any
 * pointer expression, e.g. "&s" or "p + 1", can be passed safely.
 */
#define num_stack(x)		((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)	(((x)->h) + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)		(((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
154 
155 int
156 sshkey_xmss_init_bds_state(struct sshkey *key)
157 {
158 	struct ssh_xmss_state *state = key->xmss_state;
159 	u_int32_t i;
160 
161 	state->stackoffset = 0;
162 	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
163 	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
164 	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
165 	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
166 	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
167 	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
168 	    (state->treehash = calloc(num_treehash(state),
169 	    sizeof(treehash_inst))) == NULL) {
170 		sshkey_xmss_free_bds(key);
171 		return SSH_ERR_ALLOC_FAIL;
172 	}
173 	for (i = 0; i < state->h - state->k; i++)
174 		state->treehash[i].node = &state->th_nodes[state->n*i];
175 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
176 	    state->stacklevels, state->auth, state->keep, state->treehash,
177 	    state->retain, 0);
178 	return 0;
179 }
180 
181 void
182 sshkey_xmss_free_bds(struct sshkey *key)
183 {
184 	struct ssh_xmss_state *state = key->xmss_state;
185 
186 	if (state == NULL)
187 		return;
188 	free(state->stack);
189 	free(state->stacklevels);
190 	free(state->auth);
191 	free(state->keep);
192 	free(state->th_nodes);
193 	free(state->retain);
194 	free(state->treehash);
195 	state->stack = NULL;
196 	state->stacklevels = NULL;
197 	state->auth = NULL;
198 	state->keep = NULL;
199 	state->th_nodes = NULL;
200 	state->retain = NULL;
201 	state->treehash = NULL;
202 }
203 
204 void *
205 sshkey_xmss_params(const struct sshkey *key)
206 {
207 	struct ssh_xmss_state *state = key->xmss_state;
208 
209 	if (state == NULL)
210 		return NULL;
211 	return &state->params;
212 }
213 
214 void *
215 sshkey_xmss_bds_state(const struct sshkey *key)
216 {
217 	struct ssh_xmss_state *state = key->xmss_state;
218 
219 	if (state == NULL)
220 		return NULL;
221 	return &state->bds;
222 }
223 
224 int
225 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
226 {
227 	struct ssh_xmss_state *state = key->xmss_state;
228 
229 	if (lenp == NULL)
230 		return SSH_ERR_INVALID_ARGUMENT;
231 	if (state == NULL)
232 		return SSH_ERR_INVALID_FORMAT;
233 	*lenp = 4 + state->n +
234 	    state->params.wots_par.keysize +
235 	    state->h * state->n;
236 	return 0;
237 }
238 
239 size_t
240 sshkey_xmss_pklen(const struct sshkey *key)
241 {
242 	struct ssh_xmss_state *state = key->xmss_state;
243 
244 	if (state == NULL)
245 		return 0;
246 	return state->n * 2;
247 }
248 
249 size_t
250 sshkey_xmss_sklen(const struct sshkey *key)
251 {
252 	struct ssh_xmss_state *state = key->xmss_state;
253 
254 	if (state == NULL)
255 		return 0;
256 	return state->n * 4 + 4;
257 }
258 
259 int
260 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
261 {
262 	struct ssh_xmss_state *state = k->xmss_state;
263 	const struct sshcipher *cipher;
264 	size_t keylen = 0, ivlen = 0;
265 
266 	if (state == NULL)
267 		return SSH_ERR_INVALID_ARGUMENT;
268 	if ((cipher = cipher_by_name(ciphername)) == NULL)
269 		return SSH_ERR_INTERNAL_ERROR;
270 	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
271 		return SSH_ERR_ALLOC_FAIL;
272 	keylen = cipher_keylen(cipher);
273 	ivlen = cipher_ivlen(cipher);
274 	state->enc_keyiv_len = keylen + ivlen;
275 	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
276 		free(state->enc_ciphername);
277 		state->enc_ciphername = NULL;
278 		return SSH_ERR_ALLOC_FAIL;
279 	}
280 	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
281 	return 0;
282 }
283 
284 int
285 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
286 {
287 	struct ssh_xmss_state *state = k->xmss_state;
288 	int r;
289 
290 	if (state == NULL || state->enc_keyiv == NULL ||
291 	    state->enc_ciphername == NULL)
292 		return SSH_ERR_INVALID_ARGUMENT;
293 	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
294 	    (r = sshbuf_put_string(b, state->enc_keyiv,
295 	    state->enc_keyiv_len)) != 0)
296 		return r;
297 	return 0;
298 }
299 
300 int
301 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
302 {
303 	struct ssh_xmss_state *state = k->xmss_state;
304 	size_t len;
305 	int r;
306 
307 	if (state == NULL)
308 		return SSH_ERR_INVALID_ARGUMENT;
309 	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
310 	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
311 		return r;
312 	state->enc_keyiv_len = len;
313 	return 0;
314 }
315 
316 int
317 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
318     enum sshkey_serialize_rep opts)
319 {
320 	struct ssh_xmss_state *state = k->xmss_state;
321 	u_char have_info = 1;
322 	u_int32_t idx;
323 	int r;
324 
325 	if (state == NULL)
326 		return SSH_ERR_INVALID_ARGUMENT;
327 	if (opts != SSHKEY_SERIALIZE_INFO)
328 		return 0;
329 	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
330 	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
331 	    (r = sshbuf_put_u32(b, idx)) != 0 ||
332 	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
333 		return r;
334 	return 0;
335 }
336 
337 int
338 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
339 {
340 	struct ssh_xmss_state *state = k->xmss_state;
341 	u_char have_info;
342 	int r;
343 
344 	if (state == NULL)
345 		return SSH_ERR_INVALID_ARGUMENT;
346 	/* optional */
347 	if (sshbuf_len(b) == 0)
348 		return 0;
349 	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
350 		return r;
351 	if (have_info != 1)
352 		return SSH_ERR_INVALID_ARGUMENT;
353 	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
354 	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
355 		return r;
356 	return 0;
357 }
358 
359 int
360 sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
361 {
362 	int r;
363 	const char *name;
364 
365 	if (bits == 10) {
366 		name = XMSS_SHA2_256_W16_H10_NAME;
367 	} else if (bits == 16) {
368 		name = XMSS_SHA2_256_W16_H16_NAME;
369 	} else if (bits == 20) {
370 		name = XMSS_SHA2_256_W16_H20_NAME;
371 	} else {
372 		name = XMSS_DEFAULT_NAME;
373 	}
374 	if ((r = sshkey_xmss_init(k, name)) != 0 ||
375 	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
376 	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
377 		return r;
378 	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
379 	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
380 		return SSH_ERR_ALLOC_FAIL;
381 	}
382 	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
383 	    sshkey_xmss_params(k));
384 	return 0;
385 }
386 
387 int
388 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
389     int *have_file, sshkey_printfn *pr)
390 {
391 	struct sshbuf *b = NULL, *enc = NULL;
392 	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
393 	u_int32_t len;
394 	unsigned char buf[4], *data = NULL;
395 
396 	*have_file = 0;
397 	if ((fd = open(filename, O_RDONLY)) >= 0) {
398 		*have_file = 1;
399 		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
400 			PRINT("%s: corrupt state file: %s", __func__, filename);
401 			goto done;
402 		}
403 		len = PEEK_U32(buf);
404 		if ((data = calloc(len, 1)) == NULL) {
405 			ret = SSH_ERR_ALLOC_FAIL;
406 			goto done;
407 		}
408 		if (atomicio(read, fd, data, len) != len) {
409 			PRINT("%s: cannot read blob: %s", __func__, filename);
410 			goto done;
411 		}
412 		if ((enc = sshbuf_from(data, len)) == NULL) {
413 			ret = SSH_ERR_ALLOC_FAIL;
414 			goto done;
415 		}
416 		sshkey_xmss_free_bds(k);
417 		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
418 			ret = r;
419 			goto done;
420 		}
421 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
422 			ret = r;
423 			goto done;
424 		}
425 		ret = 0;
426 	}
427 done:
428 	if (fd != -1)
429 		close(fd);
430 	free(data);
431 	sshbuf_free(enc);
432 	sshbuf_free(b);
433 	return ret;
434 }
435 
436 int
437 sshkey_xmss_get_state(const struct sshkey *k, sshkey_printfn *pr)
438 {
439 	struct ssh_xmss_state *state = k->xmss_state;
440 	u_int32_t idx = 0;
441 	char *filename = NULL;
442 	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
443 	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
444 	int ret = SSH_ERR_INVALID_ARGUMENT, r;
445 
446 	if (state == NULL)
447 		goto done;
448 	/*
449 	 * If maxidx is set, then we are allowed a limited number
450 	 * of signatures, but don't need to access the disk.
451 	 * Otherwise we need to deal with the on-disk state.
452 	 */
453 	if (state->maxidx) {
454 		/* xmss_sk always contains the current state */
455 		idx = PEEK_U32(k->xmss_sk);
456 		if (idx < state->maxidx) {
457 			state->allow_update = 1;
458 			return 0;
459 		}
460 		return SSH_ERR_INVALID_ARGUMENT;
461 	}
462 	if ((filename = k->xmss_filename) == NULL)
463 		goto done;
464 	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
465 	    asprintf(&statefile, "%s.state", filename) == -1 ||
466 	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
467 		ret = SSH_ERR_ALLOC_FAIL;
468 		goto done;
469 	}
470 	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
471 		ret = SSH_ERR_SYSTEM_ERROR;
472 		PRINT("%s: cannot open/create: %s", __func__, lockfile);
473 		goto done;
474 	}
475 	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
476 		if (errno != EWOULDBLOCK) {
477 			ret = SSH_ERR_SYSTEM_ERROR;
478 			PRINT("%s: cannot lock: %s", __func__, lockfile);
479 			goto done;
480 		}
481 		if (++tries > 10) {
482 			ret = SSH_ERR_SYSTEM_ERROR;
483 			PRINT("%s: giving up on: %s", __func__, lockfile);
484 			goto done;
485 		}
486 		usleep(1000*100*tries);
487 	}
488 	/* XXX no longer const */
489 	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
490 	    statefile, &have_state, pr)) != 0) {
491 		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
492 		    ostatefile, &have_ostate, pr)) == 0) {
493 			state->allow_update = 1;
494 			r = sshkey_xmss_forward_state(k, 1);
495 			state->idx = PEEK_U32(k->xmss_sk);
496 			state->allow_update = 0;
497 		}
498 	}
499 	if (!have_state && !have_ostate) {
500 		/* check that bds state is initialized */
501 		if (state->bds.auth == NULL)
502 			goto done;
503 		PRINT("%s: start from scratch idx 0: %u", __func__, state->idx);
504 	} else if (r != 0) {
505 		ret = r;
506 		goto done;
507 	}
508 	if (state->idx + 1 < state->idx) {
509 		PRINT("%s: state wrap: %u", __func__, state->idx);
510 		goto done;
511 	}
512 	state->have_state = have_state;
513 	state->lockfd = lockfd;
514 	state->allow_update = 1;
515 	lockfd = -1;
516 	ret = 0;
517 done:
518 	if (lockfd != -1)
519 		close(lockfd);
520 	free(lockfile);
521 	free(statefile);
522 	free(ostatefile);
523 	return ret;
524 }
525 
526 int
527 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
528 {
529 	struct ssh_xmss_state *state = k->xmss_state;
530 	u_char *sig = NULL;
531 	size_t required_siglen;
532 	unsigned long long smlen;
533 	u_char data;
534 	int ret, r;
535 
536 	if (state == NULL || !state->allow_update)
537 		return SSH_ERR_INVALID_ARGUMENT;
538 	if (reserve == 0)
539 		return SSH_ERR_INVALID_ARGUMENT;
540 	if (state->idx + reserve <= state->idx)
541 		return SSH_ERR_INVALID_ARGUMENT;
542 	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
543 		return r;
544 	if ((sig = malloc(required_siglen)) == NULL)
545 		return SSH_ERR_ALLOC_FAIL;
546 	while (reserve-- > 0) {
547 		state->idx = PEEK_U32(k->xmss_sk);
548 		smlen = required_siglen;
549 		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
550 		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
551 			r = SSH_ERR_INVALID_ARGUMENT;
552 			break;
553 		}
554 	}
555 	free(sig);
556 	return r;
557 }
558 
559 int
560 sshkey_xmss_update_state(const struct sshkey *k, sshkey_printfn *pr)
561 {
562 	struct ssh_xmss_state *state = k->xmss_state;
563 	struct sshbuf *b = NULL, *enc = NULL;
564 	u_int32_t idx = 0;
565 	unsigned char buf[4];
566 	char *filename = NULL;
567 	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
568 	int fd = -1;
569 	int ret = SSH_ERR_INVALID_ARGUMENT;
570 
571 	if (state == NULL || !state->allow_update)
572 		return ret;
573 	if (state->maxidx) {
574 		/* no update since the number of signatures is limited */
575 		ret = 0;
576 		goto done;
577 	}
578 	idx = PEEK_U32(k->xmss_sk);
579 	if (idx == state->idx) {
580 		/* no signature happened, no need to update */
581 		ret = 0;
582 		goto done;
583 	} else if (idx != state->idx + 1) {
584 		PRINT("%s: more than one signature happened: idx %u state %u",
585 		     __func__, idx, state->idx);
586 		goto done;
587 	}
588 	state->idx = idx;
589 	if ((filename = k->xmss_filename) == NULL)
590 		goto done;
591 	if (asprintf(&statefile, "%s.state", filename) == -1 ||
592 	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
593 	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
594 		ret = SSH_ERR_ALLOC_FAIL;
595 		goto done;
596 	}
597 	unlink(nstatefile);
598 	if ((b = sshbuf_new()) == NULL) {
599 		ret = SSH_ERR_ALLOC_FAIL;
600 		goto done;
601 	}
602 	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
603 		PRINT("%s: SERLIALIZE FAILED: %d", __func__, ret);
604 		goto done;
605 	}
606 	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
607 		PRINT("%s: ENCRYPT FAILED: %d", __func__, ret);
608 		goto done;
609 	}
610 	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
611 		ret = SSH_ERR_SYSTEM_ERROR;
612 		PRINT("%s: open new state file: %s", __func__, nstatefile);
613 		goto done;
614 	}
615 	POKE_U32(buf, sshbuf_len(enc));
616 	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
617 		ret = SSH_ERR_SYSTEM_ERROR;
618 		PRINT("%s: write new state file hdr: %s", __func__, nstatefile);
619 		close(fd);
620 		goto done;
621 	}
622 	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
623 	    sshbuf_len(enc)) {
624 		ret = SSH_ERR_SYSTEM_ERROR;
625 		PRINT("%s: write new state file data: %s", __func__, nstatefile);
626 		close(fd);
627 		goto done;
628 	}
629 	if (fsync(fd) == -1) {
630 		ret = SSH_ERR_SYSTEM_ERROR;
631 		PRINT("%s: sync new state file: %s", __func__, nstatefile);
632 		close(fd);
633 		goto done;
634 	}
635 	if (close(fd) == -1) {
636 		ret = SSH_ERR_SYSTEM_ERROR;
637 		PRINT("%s: close new state file: %s", __func__, nstatefile);
638 		goto done;
639 	}
640 	if (state->have_state) {
641 		unlink(ostatefile);
642 		if (link(statefile, ostatefile)) {
643 			ret = SSH_ERR_SYSTEM_ERROR;
644 			PRINT("%s: backup state %s to %s", __func__, statefile,
645 			    ostatefile);
646 			goto done;
647 		}
648 	}
649 	if (rename(nstatefile, statefile) == -1) {
650 		ret = SSH_ERR_SYSTEM_ERROR;
651 		PRINT("%s: rename %s to %s", __func__, nstatefile, statefile);
652 		goto done;
653 	}
654 	ret = 0;
655 done:
656 	if (state->lockfd != -1) {
657 		close(state->lockfd);
658 		state->lockfd = -1;
659 	}
660 	if (nstatefile)
661 		unlink(nstatefile);
662 	free(statefile);
663 	free(ostatefile);
664 	free(nstatefile);
665 	sshbuf_free(b);
666 	sshbuf_free(enc);
667 	return ret;
668 }
669 
670 int
671 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
672 {
673 	struct ssh_xmss_state *state = k->xmss_state;
674 	treehash_inst *th;
675 	u_int32_t i, node;
676 	int r;
677 
678 	if (state == NULL)
679 		return SSH_ERR_INVALID_ARGUMENT;
680 	if (state->stack == NULL)
681 		return SSH_ERR_INVALID_ARGUMENT;
682 	state->stackoffset = state->bds.stackoffset;	/* copy back */
683 	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
684 	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
685 	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
686 	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
687 	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
688 	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
689 	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
690 	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
691 	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
692 	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
693 		return r;
694 	for (i = 0; i < num_treehash(state); i++) {
695 		th = &state->treehash[i];
696 		node = th->node - state->th_nodes;
697 		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
698 		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
699 		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
700 		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
701 		    (r = sshbuf_put_u32(b, node)) != 0)
702 			return r;
703 	}
704 	return 0;
705 }
706 
707 int
708 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
709     enum sshkey_serialize_rep opts)
710 {
711 	struct ssh_xmss_state *state = k->xmss_state;
712 	int r = SSH_ERR_INVALID_ARGUMENT;
713 	u_char have_stack, have_filename, have_enc;
714 
715 	if (state == NULL)
716 		return SSH_ERR_INVALID_ARGUMENT;
717 	if ((r = sshbuf_put_u8(b, opts)) != 0)
718 		return r;
719 	switch (opts) {
720 	case SSHKEY_SERIALIZE_STATE:
721 		r = sshkey_xmss_serialize_state(k, b);
722 		break;
723 	case SSHKEY_SERIALIZE_FULL:
724 		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
725 			return r;
726 		r = sshkey_xmss_serialize_state(k, b);
727 		break;
728 	case SSHKEY_SERIALIZE_SHIELD:
729 		/* all of stack/filename/enc are optional */
730 		have_stack = state->stack != NULL;
731 		if ((r = sshbuf_put_u8(b, have_stack)) != 0)
732 			return r;
733 		if (have_stack) {
734 			state->idx = PEEK_U32(k->xmss_sk);	/* update */
735 			if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
736 				return r;
737 		}
738 		have_filename = k->xmss_filename != NULL;
739 		if ((r = sshbuf_put_u8(b, have_filename)) != 0)
740 			return r;
741 		if (have_filename &&
742 		    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
743 			return r;
744 		have_enc = state->enc_keyiv != NULL;
745 		if ((r = sshbuf_put_u8(b, have_enc)) != 0)
746 			return r;
747 		if (have_enc &&
748 		    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
749 			return r;
750 		if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
751 		    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
752 			return r;
753 		break;
754 	case SSHKEY_SERIALIZE_DEFAULT:
755 		r = 0;
756 		break;
757 	default:
758 		r = SSH_ERR_INVALID_ARGUMENT;
759 		break;
760 	}
761 	return r;
762 }
763 
764 int
765 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
766 {
767 	struct ssh_xmss_state *state = k->xmss_state;
768 	treehash_inst *th;
769 	u_int32_t i, lh, node;
770 	size_t ls, lsl, la, lk, ln, lr;
771 	char *magic;
772 	int r = SSH_ERR_INTERNAL_ERROR;
773 
774 	if (state == NULL)
775 		return SSH_ERR_INVALID_ARGUMENT;
776 	if (k->xmss_sk == NULL)
777 		return SSH_ERR_INVALID_ARGUMENT;
778 	if ((state->treehash = calloc(num_treehash(state),
779 	    sizeof(treehash_inst))) == NULL)
780 		return SSH_ERR_ALLOC_FAIL;
781 	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
782 	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
783 	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
784 	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
785 	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
786 	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
787 	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
788 	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
789 	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
790 	    (r = sshbuf_get_u32(b, &lh)) != 0)
791 		goto out;
792 	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
793 		r = SSH_ERR_INVALID_ARGUMENT;
794 		goto out;
795 	}
796 	/* XXX check stackoffset */
797 	if (ls != num_stack(state) ||
798 	    lsl != num_stacklevels(state) ||
799 	    la != num_auth(state) ||
800 	    lk != num_keep(state) ||
801 	    ln != num_th_nodes(state) ||
802 	    lr != num_retain(state) ||
803 	    lh != num_treehash(state)) {
804 		r = SSH_ERR_INVALID_ARGUMENT;
805 		goto out;
806 	}
807 	for (i = 0; i < num_treehash(state); i++) {
808 		th = &state->treehash[i];
809 		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
810 		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
811 		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
812 		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
813 		    (r = sshbuf_get_u32(b, &node)) != 0)
814 			goto out;
815 		if (node < num_th_nodes(state))
816 			th->node = &state->th_nodes[node];
817 	}
818 	POKE_U32(k->xmss_sk, state->idx);
819 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
820 	    state->stacklevels, state->auth, state->keep, state->treehash,
821 	    state->retain, 0);
822 	/* success */
823 	r = 0;
824  out:
825 	free(magic);
826 	return r;
827 }
828 
829 int
830 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
831 {
832 	struct ssh_xmss_state *state = k->xmss_state;
833 	enum sshkey_serialize_rep opts;
834 	u_char have_state, have_stack, have_filename, have_enc;
835 	int r;
836 
837 	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
838 		return r;
839 
840 	opts = have_state;
841 	switch (opts) {
842 	case SSHKEY_SERIALIZE_DEFAULT:
843 		r = 0;
844 		break;
845 	case SSHKEY_SERIALIZE_SHIELD:
846 		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
847 			return r;
848 		if (have_stack &&
849 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
850 			return r;
851 		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
852 			return r;
853 		if (have_filename &&
854 		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
855 			return r;
856 		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
857 			return r;
858 		if (have_enc &&
859 		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
860 			return r;
861 		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
862 		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
863 			return r;
864 		break;
865 	case SSHKEY_SERIALIZE_STATE:
866 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
867 			return r;
868 		break;
869 	case SSHKEY_SERIALIZE_FULL:
870 		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
871 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
872 			return r;
873 		break;
874 	default:
875 		r = SSH_ERR_INVALID_FORMAT;
876 		break;
877 	}
878 	return r;
879 }
880 
881 int
882 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
883    struct sshbuf **retp)
884 {
885 	struct ssh_xmss_state *state = k->xmss_state;
886 	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
887 	struct sshcipher_ctx *ciphercontext = NULL;
888 	const struct sshcipher *cipher;
889 	u_char *cp, *key, *iv = NULL;
890 	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
891 	int r = SSH_ERR_INTERNAL_ERROR;
892 
893 	if (retp != NULL)
894 		*retp = NULL;
895 	if (state == NULL ||
896 	    state->enc_keyiv == NULL ||
897 	    state->enc_ciphername == NULL)
898 		return SSH_ERR_INTERNAL_ERROR;
899 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
900 		r = SSH_ERR_INTERNAL_ERROR;
901 		goto out;
902 	}
903 	blocksize = cipher_blocksize(cipher);
904 	keylen = cipher_keylen(cipher);
905 	ivlen = cipher_ivlen(cipher);
906 	authlen = cipher_authlen(cipher);
907 	if (state->enc_keyiv_len != keylen + ivlen) {
908 		r = SSH_ERR_INVALID_FORMAT;
909 		goto out;
910 	}
911 	key = state->enc_keyiv;
912 	if ((encrypted = sshbuf_new()) == NULL ||
913 	    (encoded = sshbuf_new()) == NULL ||
914 	    (padded = sshbuf_new()) == NULL ||
915 	    (iv = malloc(ivlen)) == NULL) {
916 		r = SSH_ERR_ALLOC_FAIL;
917 		goto out;
918 	}
919 
920 	/* replace first 4 bytes of IV with index to ensure uniqueness */
921 	memcpy(iv, key + keylen, ivlen);
922 	POKE_U32(iv, state->idx);
923 
924 	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
925 	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
926 		goto out;
927 
928 	/* padded state will be encrypted */
929 	if ((r = sshbuf_putb(padded, b)) != 0)
930 		goto out;
931 	i = 0;
932 	while (sshbuf_len(padded) % blocksize) {
933 		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
934 			goto out;
935 	}
936 	encrypted_len = sshbuf_len(padded);
937 
938 	/* header including the length of state is used as AAD */
939 	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
940 		goto out;
941 	aadlen = sshbuf_len(encoded);
942 
943 	/* concat header and state */
944 	if ((r = sshbuf_putb(encoded, padded)) != 0)
945 		goto out;
946 
947 	/* reserve space for encryption of encoded data plus auth tag */
948 	/* encrypt at offset addlen */
949 	if ((r = sshbuf_reserve(encrypted,
950 	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
951 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
952 	    iv, ivlen, 1)) != 0 ||
953 	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
954 	    encrypted_len, aadlen, authlen)) != 0)
955 		goto out;
956 
957 	/* success */
958 	r = 0;
959  out:
960 	if (retp != NULL) {
961 		*retp = encrypted;
962 		encrypted = NULL;
963 	}
964 	sshbuf_free(padded);
965 	sshbuf_free(encoded);
966 	sshbuf_free(encrypted);
967 	cipher_free(ciphercontext);
968 	free(iv);
969 	return r;
970 }
971 
972 int
973 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
974    struct sshbuf **retp)
975 {
976 	struct ssh_xmss_state *state = k->xmss_state;
977 	struct sshbuf *copy = NULL, *decrypted = NULL;
978 	struct sshcipher_ctx *ciphercontext = NULL;
979 	const struct sshcipher *cipher = NULL;
980 	u_char *key, *iv = NULL, *dp;
981 	size_t keylen, ivlen, authlen, aadlen;
982 	u_int blocksize, encrypted_len, index;
983 	int r = SSH_ERR_INTERNAL_ERROR;
984 
985 	if (retp != NULL)
986 		*retp = NULL;
987 	if (state == NULL ||
988 	    state->enc_keyiv == NULL ||
989 	    state->enc_ciphername == NULL)
990 		return SSH_ERR_INTERNAL_ERROR;
991 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
992 		r = SSH_ERR_INVALID_FORMAT;
993 		goto out;
994 	}
995 	blocksize = cipher_blocksize(cipher);
996 	keylen = cipher_keylen(cipher);
997 	ivlen = cipher_ivlen(cipher);
998 	authlen = cipher_authlen(cipher);
999 	if (state->enc_keyiv_len != keylen + ivlen) {
1000 		r = SSH_ERR_INTERNAL_ERROR;
1001 		goto out;
1002 	}
1003 	key = state->enc_keyiv;
1004 
1005 	if ((copy = sshbuf_fromb(encoded)) == NULL ||
1006 	    (decrypted = sshbuf_new()) == NULL ||
1007 	    (iv = malloc(ivlen)) == NULL) {
1008 		r = SSH_ERR_ALLOC_FAIL;
1009 		goto out;
1010 	}
1011 
1012 	/* check magic */
1013 	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1014 	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1015 		r = SSH_ERR_INVALID_FORMAT;
1016 		goto out;
1017 	}
1018 	/* parse public portion */
1019 	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1020 	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1021 	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1022 		goto out;
1023 
1024 	/* check size of encrypted key blob */
1025 	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1026 		r = SSH_ERR_INVALID_FORMAT;
1027 		goto out;
1028 	}
1029 	/* check that an appropriate amount of auth data is present */
1030 	if (sshbuf_len(encoded) < authlen ||
1031 	    sshbuf_len(encoded) - authlen < encrypted_len) {
1032 		r = SSH_ERR_INVALID_FORMAT;
1033 		goto out;
1034 	}
1035 
1036 	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1037 
1038 	/* replace first 4 bytes of IV with index to ensure uniqueness */
1039 	memcpy(iv, key + keylen, ivlen);
1040 	POKE_U32(iv, index);
1041 
1042 	/* decrypt private state of key */
1043 	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1044 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
1045 	    iv, ivlen, 0)) != 0 ||
1046 	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1047 	    encrypted_len, aadlen, authlen)) != 0)
1048 		goto out;
1049 
1050 	/* there should be no trailing data */
1051 	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1052 		goto out;
1053 	if (sshbuf_len(encoded) != 0) {
1054 		r = SSH_ERR_INVALID_FORMAT;
1055 		goto out;
1056 	}
1057 
1058 	/* remove AAD */
1059 	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1060 		goto out;
1061 	/* XXX encrypted includes unchecked padding */
1062 
1063 	/* success */
1064 	r = 0;
1065 	if (retp != NULL) {
1066 		*retp = decrypted;
1067 		decrypted = NULL;
1068 	}
1069  out:
1070 	cipher_free(ciphercontext);
1071 	sshbuf_free(copy);
1072 	sshbuf_free(decrypted);
1073 	free(iv);
1074 	return r;
1075 }
1076 
1077 u_int32_t
1078 sshkey_xmss_signatures_left(const struct sshkey *k)
1079 {
1080 	struct ssh_xmss_state *state = k->xmss_state;
1081 	u_int32_t idx;
1082 
1083 	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1084 	    state->maxidx) {
1085 		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1086 		if (idx < state->maxidx)
1087 			return state->maxidx - idx;
1088 	}
1089 	return 0;
1090 }
1091 
1092 int
1093 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1094 {
1095 	struct ssh_xmss_state *state = k->xmss_state;
1096 
1097 	if (sshkey_type_plain(k->type) != KEY_XMSS)
1098 		return SSH_ERR_INVALID_ARGUMENT;
1099 	if (maxsign == 0)
1100 		return 0;
1101 	if (state->idx + maxsign < state->idx)
1102 		return SSH_ERR_INVALID_ARGUMENT;
1103 	state->maxidx = state->idx + maxsign;
1104 	return 0;
1105 }
1106