1 /*	$NetBSD: sshkey-xmss.c,v 1.9 2023/07/26 17:58:16 christos Exp $	*/
2 /* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */
3 /*
4  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
17  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
18  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
19  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
20  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
21  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
22  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
24  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  */
26 #include "includes.h"
27 __RCSID("$NetBSD: sshkey-xmss.c,v 1.9 2023/07/26 17:58:16 christos Exp $");
28 
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/uio.h>
31 
32 #include <stdio.h>
33 #include <string.h>
34 #include <unistd.h>
35 #include <fcntl.h>
36 #include <errno.h>
37 
38 #include "ssh2.h"
39 #include "ssherr.h"
40 #include "sshbuf.h"
41 #include "cipher.h"
42 #include "sshkey.h"
43 #include "sshkey-xmss.h"
44 #include "atomicio.h"
45 #include "log.h"
46 
47 #include "xmss_fast.h"
48 
/* opaque internal XMSS state */
/*
 * Magic string (including its terminating NUL, since sizeof() is used)
 * at the start of the blob built by sshkey_xmss_encrypt_state() and
 * checked by sshkey_xmss_decrypt_state().
 */
#define XMSS_MAGIC		"xmss-state-v1"
/* Cipher passed to sshkey_xmss_init_enc_key() for new private keys. */
#define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
/* Per-key XMSS state, stored in key->xmss_state. */
struct ssh_xmss_state {
	xmss_params	params;		/* filled in by xmss_set_params() */
	u_int32_t	n, w, h, k;	/* n/w/h chosen by key name; k = 2 */

	/*
	 * BDS traversal state.  xmss_set_bds_state() points bds at the
	 * buffers below; their sizes come from the num_*() macros.
	 */
	bds_state	bds;
	u_char		*stack;
	u_int32_t	stackoffset;	/* copied from bds.stackoffset on save */
	u_char		*stacklevels;
	u_char		*auth;
	u_char		*keep;
	u_char		*th_nodes;	/* backing store for treehash[i].node */
	u_char		*retain;
	treehash_inst	*treehash;

	u_int32_t	idx;		/* state read from file */
	u_int32_t	maxidx;		/* restricted # of signatures */
	int		have_state;	/* .state file exists */
	int		lockfd;		/* locked in sshkey_xmss_get_state() */
	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
	char		*enc_ciphername;/* encrypt state with cipher */
	u_char		*enc_keyiv;	/* encrypt state with key */
	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
};
75 
/* Prototypes for non-static helpers defined later in this file. */
int	 sshkey_xmss_init_bds_state(struct sshkey *);
int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
void	 sshkey_xmss_free_bds(struct sshkey *);
int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
	    int *, int);
int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);

/*
 * Log an error through sshlog() at SYSLOG_LEVEL_ERROR, tagged with the
 * caller's file, function and line.  Only logs when a variable named
 * 'printerror' in the calling scope is nonzero.
 */
#define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \
    0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); } while (0)
90 
91 int
sshkey_xmss_init(struct sshkey * key,const char * name)92 sshkey_xmss_init(struct sshkey *key, const char *name)
93 {
94 	struct ssh_xmss_state *state;
95 
96 	if (key->xmss_state != NULL)
97 		return SSH_ERR_INVALID_FORMAT;
98 	if (name == NULL)
99 		return SSH_ERR_INVALID_FORMAT;
100 	state = calloc(sizeof(struct ssh_xmss_state), 1);
101 	if (state == NULL)
102 		return SSH_ERR_ALLOC_FAIL;
103 	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
104 		state->n = 32;
105 		state->w = 16;
106 		state->h = 10;
107 	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
108 		state->n = 32;
109 		state->w = 16;
110 		state->h = 16;
111 	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
112 		state->n = 32;
113 		state->w = 16;
114 		state->h = 20;
115 	} else {
116 		free(state);
117 		return SSH_ERR_KEY_TYPE_UNKNOWN;
118 	}
119 	if ((key->xmss_name = strdup(name)) == NULL) {
120 		free(state);
121 		return SSH_ERR_ALLOC_FAIL;
122 	}
123 	state->k = 2;	/* XXX hardcoded */
124 	state->lockfd = -1;
125 	if (xmss_set_params(&state->params, state->n, state->h, state->w,
126 	    state->k) != 0) {
127 		free(state);
128 		return SSH_ERR_INVALID_FORMAT;
129 	}
130 	key->xmss_state = state;
131 	return 0;
132 }
133 
134 void
sshkey_xmss_free_state(struct sshkey * key)135 sshkey_xmss_free_state(struct sshkey *key)
136 {
137 	struct ssh_xmss_state *state = key->xmss_state;
138 
139 	sshkey_xmss_free_bds(key);
140 	if (state) {
141 		if (state->enc_keyiv) {
142 			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
143 			free(state->enc_keyiv);
144 		}
145 		free(state->enc_ciphername);
146 		free(state);
147 	}
148 	key->xmss_state = NULL;
149 }
150 
/* Magic string that opens every serialized BDS state (only k = 2). */
#define SSH_XMSS_K2_MAGIC	"k=2"
/*
 * Sizes of the BDS buffers derived from the parameters of a
 * struct ssh_xmss_state; 'x' is a pointer to one.  num_treehash() is an
 * element count (of treehash_inst).  The others are byte counts, and
 * stacklevels uses one byte per level.  Arguments are parenthesized so
 * any pointer expression may be passed.
 */
#define num_stack(x)		((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)	(((x)->h) + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)		(((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
159 
160 int
sshkey_xmss_init_bds_state(struct sshkey * key)161 sshkey_xmss_init_bds_state(struct sshkey *key)
162 {
163 	struct ssh_xmss_state *state = key->xmss_state;
164 	u_int32_t i;
165 
166 	state->stackoffset = 0;
167 	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
168 	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
169 	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
170 	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
171 	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
172 	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
173 	    (state->treehash = calloc(num_treehash(state),
174 	    sizeof(treehash_inst))) == NULL) {
175 		sshkey_xmss_free_bds(key);
176 		return SSH_ERR_ALLOC_FAIL;
177 	}
178 	for (i = 0; i < state->h - state->k; i++)
179 		state->treehash[i].node = &state->th_nodes[state->n*i];
180 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
181 	    state->stacklevels, state->auth, state->keep, state->treehash,
182 	    state->retain, 0);
183 	return 0;
184 }
185 
186 void
sshkey_xmss_free_bds(struct sshkey * key)187 sshkey_xmss_free_bds(struct sshkey *key)
188 {
189 	struct ssh_xmss_state *state = key->xmss_state;
190 
191 	if (state == NULL)
192 		return;
193 	free(state->stack);
194 	free(state->stacklevels);
195 	free(state->auth);
196 	free(state->keep);
197 	free(state->th_nodes);
198 	free(state->retain);
199 	free(state->treehash);
200 	state->stack = NULL;
201 	state->stacklevels = NULL;
202 	state->auth = NULL;
203 	state->keep = NULL;
204 	state->th_nodes = NULL;
205 	state->retain = NULL;
206 	state->treehash = NULL;
207 }
208 
209 void *
sshkey_xmss_params(const struct sshkey * key)210 sshkey_xmss_params(const struct sshkey *key)
211 {
212 	struct ssh_xmss_state *state = key->xmss_state;
213 
214 	if (state == NULL)
215 		return NULL;
216 	return &state->params;
217 }
218 
219 void *
sshkey_xmss_bds_state(const struct sshkey * key)220 sshkey_xmss_bds_state(const struct sshkey *key)
221 {
222 	struct ssh_xmss_state *state = key->xmss_state;
223 
224 	if (state == NULL)
225 		return NULL;
226 	return &state->bds;
227 }
228 
229 int
sshkey_xmss_siglen(const struct sshkey * key,size_t * lenp)230 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
231 {
232 	struct ssh_xmss_state *state = key->xmss_state;
233 
234 	if (lenp == NULL)
235 		return SSH_ERR_INVALID_ARGUMENT;
236 	if (state == NULL)
237 		return SSH_ERR_INVALID_FORMAT;
238 	*lenp = 4 + state->n +
239 	    state->params.wots_par.keysize +
240 	    state->h * state->n;
241 	return 0;
242 }
243 
244 size_t
sshkey_xmss_pklen(const struct sshkey * key)245 sshkey_xmss_pklen(const struct sshkey *key)
246 {
247 	struct ssh_xmss_state *state = key->xmss_state;
248 
249 	if (state == NULL)
250 		return 0;
251 	return state->n * 2;
252 }
253 
254 size_t
sshkey_xmss_sklen(const struct sshkey * key)255 sshkey_xmss_sklen(const struct sshkey *key)
256 {
257 	struct ssh_xmss_state *state = key->xmss_state;
258 
259 	if (state == NULL)
260 		return 0;
261 	return state->n * 4 + 4;
262 }
263 
264 int
sshkey_xmss_init_enc_key(struct sshkey * k,const char * ciphername)265 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
266 {
267 	struct ssh_xmss_state *state = k->xmss_state;
268 	const struct sshcipher *cipher;
269 	size_t keylen = 0, ivlen = 0;
270 
271 	if (state == NULL)
272 		return SSH_ERR_INVALID_ARGUMENT;
273 	if ((cipher = cipher_by_name(ciphername)) == NULL)
274 		return SSH_ERR_INTERNAL_ERROR;
275 	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
276 		return SSH_ERR_ALLOC_FAIL;
277 	keylen = cipher_keylen(cipher);
278 	ivlen = cipher_ivlen(cipher);
279 	state->enc_keyiv_len = keylen + ivlen;
280 	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
281 		free(state->enc_ciphername);
282 		state->enc_ciphername = NULL;
283 		return SSH_ERR_ALLOC_FAIL;
284 	}
285 	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
286 	return 0;
287 }
288 
289 int
sshkey_xmss_serialize_enc_key(const struct sshkey * k,struct sshbuf * b)290 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
291 {
292 	struct ssh_xmss_state *state = k->xmss_state;
293 	int r;
294 
295 	if (state == NULL || state->enc_keyiv == NULL ||
296 	    state->enc_ciphername == NULL)
297 		return SSH_ERR_INVALID_ARGUMENT;
298 	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
299 	    (r = sshbuf_put_string(b, state->enc_keyiv,
300 	    state->enc_keyiv_len)) != 0)
301 		return r;
302 	return 0;
303 }
304 
305 int
sshkey_xmss_deserialize_enc_key(struct sshkey * k,struct sshbuf * b)306 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
307 {
308 	struct ssh_xmss_state *state = k->xmss_state;
309 	size_t len;
310 	int r;
311 
312 	if (state == NULL)
313 		return SSH_ERR_INVALID_ARGUMENT;
314 	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
315 	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
316 		return r;
317 	state->enc_keyiv_len = len;
318 	return 0;
319 }
320 
321 int
sshkey_xmss_serialize_pk_info(const struct sshkey * k,struct sshbuf * b,enum sshkey_serialize_rep opts)322 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
323     enum sshkey_serialize_rep opts)
324 {
325 	struct ssh_xmss_state *state = k->xmss_state;
326 	u_char have_info = 1;
327 	u_int32_t idx;
328 	int r;
329 
330 	if (state == NULL)
331 		return SSH_ERR_INVALID_ARGUMENT;
332 	if (opts != SSHKEY_SERIALIZE_INFO)
333 		return 0;
334 	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
335 	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
336 	    (r = sshbuf_put_u32(b, idx)) != 0 ||
337 	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
338 		return r;
339 	return 0;
340 }
341 
342 int
sshkey_xmss_deserialize_pk_info(struct sshkey * k,struct sshbuf * b)343 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
344 {
345 	struct ssh_xmss_state *state = k->xmss_state;
346 	u_char have_info;
347 	int r;
348 
349 	if (state == NULL)
350 		return SSH_ERR_INVALID_ARGUMENT;
351 	/* optional */
352 	if (sshbuf_len(b) == 0)
353 		return 0;
354 	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
355 		return r;
356 	if (have_info != 1)
357 		return SSH_ERR_INVALID_ARGUMENT;
358 	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
359 	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
360 		return r;
361 	return 0;
362 }
363 
364 int
sshkey_xmss_generate_private_key(struct sshkey * k,int bits)365 sshkey_xmss_generate_private_key(struct sshkey *k, int bits)
366 {
367 	int r;
368 	const char *name;
369 
370 	if (bits == 10) {
371 		name = XMSS_SHA2_256_W16_H10_NAME;
372 	} else if (bits == 16) {
373 		name = XMSS_SHA2_256_W16_H16_NAME;
374 	} else if (bits == 20) {
375 		name = XMSS_SHA2_256_W16_H20_NAME;
376 	} else {
377 		name = XMSS_DEFAULT_NAME;
378 	}
379 	if ((r = sshkey_xmss_init(k, name)) != 0 ||
380 	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
381 	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
382 		return r;
383 	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
384 	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
385 		return SSH_ERR_ALLOC_FAIL;
386 	}
387 	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
388 	    sshkey_xmss_params(k));
389 	return 0;
390 }
391 
392 int
sshkey_xmss_get_state_from_file(struct sshkey * k,const char * filename,int * have_file,int printerror)393 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
394     int *have_file, int printerror)
395 {
396 	struct sshbuf *b = NULL, *enc = NULL;
397 	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
398 	u_int32_t len;
399 	unsigned char buf[4], *data = NULL;
400 
401 	*have_file = 0;
402 	if ((fd = open(filename, O_RDONLY)) >= 0) {
403 		*have_file = 1;
404 		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
405 			PRINT("corrupt state file: %s", filename);
406 			goto done;
407 		}
408 		len = PEEK_U32(buf);
409 		if ((data = calloc(len, 1)) == NULL) {
410 			ret = SSH_ERR_ALLOC_FAIL;
411 			goto done;
412 		}
413 		if (atomicio(read, fd, data, len) != len) {
414 			PRINT("cannot read blob: %s", filename);
415 			goto done;
416 		}
417 		if ((enc = sshbuf_from(data, len)) == NULL) {
418 			ret = SSH_ERR_ALLOC_FAIL;
419 			goto done;
420 		}
421 		sshkey_xmss_free_bds(k);
422 		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
423 			ret = r;
424 			goto done;
425 		}
426 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
427 			ret = r;
428 			goto done;
429 		}
430 		ret = 0;
431 	}
432 done:
433 	if (fd != -1)
434 		close(fd);
435 	free(data);
436 	sshbuf_free(enc);
437 	sshbuf_free(b);
438 	return ret;
439 }
440 
441 int
sshkey_xmss_get_state(const struct sshkey * k,int printerror)442 sshkey_xmss_get_state(const struct sshkey *k, int printerror)
443 {
444 	struct ssh_xmss_state *state = k->xmss_state;
445 	u_int32_t idx = 0;
446 	char *filename = NULL;
447 	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
448 	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
449 	int ret = SSH_ERR_INVALID_ARGUMENT, r;
450 
451 	if (state == NULL)
452 		goto done;
453 	/*
454 	 * If maxidx is set, then we are allowed a limited number
455 	 * of signatures, but don't need to access the disk.
456 	 * Otherwise we need to deal with the on-disk state.
457 	 */
458 	if (state->maxidx) {
459 		/* xmss_sk always contains the current state */
460 		idx = PEEK_U32(k->xmss_sk);
461 		if (idx < state->maxidx) {
462 			state->allow_update = 1;
463 			return 0;
464 		}
465 		return SSH_ERR_INVALID_ARGUMENT;
466 	}
467 	if ((filename = k->xmss_filename) == NULL)
468 		goto done;
469 	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
470 	    asprintf(&statefile, "%s.state", filename) == -1 ||
471 	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
472 		ret = SSH_ERR_ALLOC_FAIL;
473 		goto done;
474 	}
475 	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
476 		ret = SSH_ERR_SYSTEM_ERROR;
477 		PRINT("cannot open/create: %s", lockfile);
478 		goto done;
479 	}
480 	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
481 		if (errno != EWOULDBLOCK) {
482 			ret = SSH_ERR_SYSTEM_ERROR;
483 			PRINT("cannot lock: %s", lockfile);
484 			goto done;
485 		}
486 		if (++tries > 10) {
487 			ret = SSH_ERR_SYSTEM_ERROR;
488 			PRINT("giving up on: %s", lockfile);
489 			goto done;
490 		}
491 		usleep(1000*100*tries);
492 	}
493 	/* XXX no longer const */
494 	if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
495 	    statefile, &have_state, printerror)) != 0) {
496 		if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
497 		    ostatefile, &have_ostate, printerror)) == 0) {
498 			state->allow_update = 1;
499 			r = sshkey_xmss_forward_state(k, 1);
500 			state->idx = PEEK_U32(k->xmss_sk);
501 			state->allow_update = 0;
502 		}
503 	}
504 	if (!have_state && !have_ostate) {
505 		/* check that bds state is initialized */
506 		if (state->bds.auth == NULL)
507 			goto done;
508 		PRINT("start from scratch idx 0: %u", state->idx);
509 	} else if (r != 0) {
510 		ret = r;
511 		goto done;
512 	}
513 	if (state->idx + 1 < state->idx) {
514 		PRINT("state wrap: %u", state->idx);
515 		goto done;
516 	}
517 	state->have_state = have_state;
518 	state->lockfd = lockfd;
519 	state->allow_update = 1;
520 	lockfd = -1;
521 	ret = 0;
522 done:
523 	if (lockfd != -1)
524 		close(lockfd);
525 	free(lockfile);
526 	free(statefile);
527 	free(ostatefile);
528 	return ret;
529 }
530 
531 int
sshkey_xmss_forward_state(const struct sshkey * k,u_int32_t reserve)532 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
533 {
534 	struct ssh_xmss_state *state = k->xmss_state;
535 	u_char *sig = NULL;
536 	size_t required_siglen;
537 	unsigned long long smlen;
538 	u_char data;
539 	int ret, r;
540 
541 	if (state == NULL || !state->allow_update)
542 		return SSH_ERR_INVALID_ARGUMENT;
543 	if (reserve == 0)
544 		return SSH_ERR_INVALID_ARGUMENT;
545 	if (state->idx + reserve <= state->idx)
546 		return SSH_ERR_INVALID_ARGUMENT;
547 	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
548 		return r;
549 	if ((sig = malloc(required_siglen)) == NULL)
550 		return SSH_ERR_ALLOC_FAIL;
551 	while (reserve-- > 0) {
552 		state->idx = PEEK_U32(k->xmss_sk);
553 		smlen = required_siglen;
554 		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
555 		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
556 			r = SSH_ERR_INVALID_ARGUMENT;
557 			break;
558 		}
559 	}
560 	free(sig);
561 	return r;
562 }
563 
564 int
sshkey_xmss_update_state(const struct sshkey * k,int printerror)565 sshkey_xmss_update_state(const struct sshkey *k, int printerror)
566 {
567 	struct ssh_xmss_state *state = k->xmss_state;
568 	struct sshbuf *b = NULL, *enc = NULL;
569 	u_int32_t idx = 0;
570 	unsigned char buf[4];
571 	char *filename = NULL;
572 	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
573 	int fd = -1;
574 	int ret = SSH_ERR_INVALID_ARGUMENT;
575 
576 	if (state == NULL || !state->allow_update)
577 		return ret;
578 	if (state->maxidx) {
579 		/* no update since the number of signatures is limited */
580 		ret = 0;
581 		goto done;
582 	}
583 	idx = PEEK_U32(k->xmss_sk);
584 	if (idx == state->idx) {
585 		/* no signature happened, no need to update */
586 		ret = 0;
587 		goto done;
588 	} else if (idx != state->idx + 1) {
589 		PRINT("more than one signature happened: idx %u state %u",
590 		    idx, state->idx);
591 		goto done;
592 	}
593 	state->idx = idx;
594 	if ((filename = k->xmss_filename) == NULL)
595 		goto done;
596 	if (asprintf(&statefile, "%s.state", filename) == -1 ||
597 	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
598 	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
599 		ret = SSH_ERR_ALLOC_FAIL;
600 		goto done;
601 	}
602 	unlink(nstatefile);
603 	if ((b = sshbuf_new()) == NULL) {
604 		ret = SSH_ERR_ALLOC_FAIL;
605 		goto done;
606 	}
607 	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
608 		PRINT("SERLIALIZE FAILED: %d", ret);
609 		goto done;
610 	}
611 	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
612 		PRINT("ENCRYPT FAILED: %d", ret);
613 		goto done;
614 	}
615 	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
616 		ret = SSH_ERR_SYSTEM_ERROR;
617 		PRINT("open new state file: %s", nstatefile);
618 		goto done;
619 	}
620 	POKE_U32(buf, sshbuf_len(enc));
621 	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
622 		ret = SSH_ERR_SYSTEM_ERROR;
623 		PRINT("write new state file hdr: %s", nstatefile);
624 		close(fd);
625 		goto done;
626 	}
627 	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
628 	    sshbuf_len(enc)) {
629 		ret = SSH_ERR_SYSTEM_ERROR;
630 		PRINT("write new state file data: %s", nstatefile);
631 		close(fd);
632 		goto done;
633 	}
634 	if (fsync(fd) == -1) {
635 		ret = SSH_ERR_SYSTEM_ERROR;
636 		PRINT("sync new state file: %s", nstatefile);
637 		close(fd);
638 		goto done;
639 	}
640 	if (close(fd) == -1) {
641 		ret = SSH_ERR_SYSTEM_ERROR;
642 		PRINT("close new state file: %s", nstatefile);
643 		goto done;
644 	}
645 	if (state->have_state) {
646 		unlink(ostatefile);
647 		if (link(statefile, ostatefile)) {
648 			ret = SSH_ERR_SYSTEM_ERROR;
649 			PRINT("backup state %s to %s", statefile, ostatefile);
650 			goto done;
651 		}
652 	}
653 	if (rename(nstatefile, statefile) == -1) {
654 		ret = SSH_ERR_SYSTEM_ERROR;
655 		PRINT("rename %s to %s", nstatefile, statefile);
656 		goto done;
657 	}
658 	ret = 0;
659 done:
660 	if (state->lockfd != -1) {
661 		close(state->lockfd);
662 		state->lockfd = -1;
663 	}
664 	if (nstatefile)
665 		unlink(nstatefile);
666 	free(statefile);
667 	free(ostatefile);
668 	free(nstatefile);
669 	sshbuf_free(b);
670 	sshbuf_free(enc);
671 	return ret;
672 }
673 
674 int
sshkey_xmss_serialize_state(const struct sshkey * k,struct sshbuf * b)675 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
676 {
677 	struct ssh_xmss_state *state = k->xmss_state;
678 	treehash_inst *th;
679 	u_int32_t i, node;
680 	int r;
681 
682 	if (state == NULL)
683 		return SSH_ERR_INVALID_ARGUMENT;
684 	if (state->stack == NULL)
685 		return SSH_ERR_INVALID_ARGUMENT;
686 	state->stackoffset = state->bds.stackoffset;	/* copy back */
687 	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
688 	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
689 	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
690 	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
691 	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
692 	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
693 	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
694 	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
695 	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
696 	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
697 		return r;
698 	for (i = 0; i < num_treehash(state); i++) {
699 		th = &state->treehash[i];
700 		node = th->node - state->th_nodes;
701 		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
702 		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
703 		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
704 		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
705 		    (r = sshbuf_put_u32(b, node)) != 0)
706 			return r;
707 	}
708 	return 0;
709 }
710 
711 int
sshkey_xmss_serialize_state_opt(const struct sshkey * k,struct sshbuf * b,enum sshkey_serialize_rep opts)712 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
713     enum sshkey_serialize_rep opts)
714 {
715 	struct ssh_xmss_state *state = k->xmss_state;
716 	int r = SSH_ERR_INVALID_ARGUMENT;
717 	u_char have_stack, have_filename, have_enc;
718 
719 	if (state == NULL)
720 		return SSH_ERR_INVALID_ARGUMENT;
721 	if ((r = sshbuf_put_u8(b, opts)) != 0)
722 		return r;
723 	switch (opts) {
724 	case SSHKEY_SERIALIZE_STATE:
725 		r = sshkey_xmss_serialize_state(k, b);
726 		break;
727 	case SSHKEY_SERIALIZE_FULL:
728 		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
729 			return r;
730 		r = sshkey_xmss_serialize_state(k, b);
731 		break;
732 	case SSHKEY_SERIALIZE_SHIELD:
733 		/* all of stack/filename/enc are optional */
734 		have_stack = state->stack != NULL;
735 		if ((r = sshbuf_put_u8(b, have_stack)) != 0)
736 			return r;
737 		if (have_stack) {
738 			state->idx = PEEK_U32(k->xmss_sk);	/* update */
739 			if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
740 				return r;
741 		}
742 		have_filename = k->xmss_filename != NULL;
743 		if ((r = sshbuf_put_u8(b, have_filename)) != 0)
744 			return r;
745 		if (have_filename &&
746 		    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
747 			return r;
748 		have_enc = state->enc_keyiv != NULL;
749 		if ((r = sshbuf_put_u8(b, have_enc)) != 0)
750 			return r;
751 		if (have_enc &&
752 		    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
753 			return r;
754 		if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
755 		    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
756 			return r;
757 		break;
758 	case SSHKEY_SERIALIZE_DEFAULT:
759 		r = 0;
760 		break;
761 	default:
762 		r = SSH_ERR_INVALID_ARGUMENT;
763 		break;
764 	}
765 	return r;
766 }
767 
768 int
sshkey_xmss_deserialize_state(struct sshkey * k,struct sshbuf * b)769 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
770 {
771 	struct ssh_xmss_state *state = k->xmss_state;
772 	treehash_inst *th;
773 	u_int32_t i, lh, node;
774 	size_t ls, lsl, la, lk, ln, lr;
775 	char *magic;
776 	int r = SSH_ERR_INTERNAL_ERROR;
777 
778 	if (state == NULL)
779 		return SSH_ERR_INVALID_ARGUMENT;
780 	if (k->xmss_sk == NULL)
781 		return SSH_ERR_INVALID_ARGUMENT;
782 	if ((state->treehash = calloc(num_treehash(state),
783 	    sizeof(treehash_inst))) == NULL)
784 		return SSH_ERR_ALLOC_FAIL;
785 	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
786 	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
787 	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
788 	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
789 	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
790 	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
791 	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
792 	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
793 	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
794 	    (r = sshbuf_get_u32(b, &lh)) != 0)
795 		goto out;
796 	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
797 		r = SSH_ERR_INVALID_ARGUMENT;
798 		goto out;
799 	}
800 	/* XXX check stackoffset */
801 	if (ls != num_stack(state) ||
802 	    lsl != num_stacklevels(state) ||
803 	    la != num_auth(state) ||
804 	    lk != num_keep(state) ||
805 	    ln != num_th_nodes(state) ||
806 	    lr != num_retain(state) ||
807 	    lh != num_treehash(state)) {
808 		r = SSH_ERR_INVALID_ARGUMENT;
809 		goto out;
810 	}
811 	for (i = 0; i < num_treehash(state); i++) {
812 		th = &state->treehash[i];
813 		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
814 		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
815 		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
816 		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
817 		    (r = sshbuf_get_u32(b, &node)) != 0)
818 			goto out;
819 		if (node < num_th_nodes(state))
820 			th->node = &state->th_nodes[node];
821 	}
822 	POKE_U32(k->xmss_sk, state->idx);
823 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
824 	    state->stacklevels, state->auth, state->keep, state->treehash,
825 	    state->retain, 0);
826 	/* success */
827 	r = 0;
828  out:
829 	free(magic);
830 	return r;
831 }
832 
833 int
sshkey_xmss_deserialize_state_opt(struct sshkey * k,struct sshbuf * b)834 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
835 {
836 	struct ssh_xmss_state *state = k->xmss_state;
837 	enum sshkey_serialize_rep opts;
838 	u_char have_state, have_stack, have_filename, have_enc;
839 	int r;
840 
841 	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
842 		return r;
843 
844 	opts = have_state;
845 	switch (opts) {
846 	case SSHKEY_SERIALIZE_DEFAULT:
847 		r = 0;
848 		break;
849 	case SSHKEY_SERIALIZE_SHIELD:
850 		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
851 			return r;
852 		if (have_stack &&
853 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
854 			return r;
855 		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
856 			return r;
857 		if (have_filename &&
858 		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
859 			return r;
860 		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
861 			return r;
862 		if (have_enc &&
863 		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
864 			return r;
865 		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
866 		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
867 			return r;
868 		break;
869 	case SSHKEY_SERIALIZE_STATE:
870 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
871 			return r;
872 		break;
873 	case SSHKEY_SERIALIZE_FULL:
874 		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
875 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
876 			return r;
877 		break;
878 	default:
879 		r = SSH_ERR_INVALID_FORMAT;
880 		break;
881 	}
882 	return r;
883 }
884 
885 int
sshkey_xmss_encrypt_state(const struct sshkey * k,struct sshbuf * b,struct sshbuf ** retp)886 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
887    struct sshbuf **retp)
888 {
889 	struct ssh_xmss_state *state = k->xmss_state;
890 	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
891 	struct sshcipher_ctx *ciphercontext = NULL;
892 	const struct sshcipher *cipher;
893 	u_char *cp, *key, *iv = NULL;
894 	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
895 	int r = SSH_ERR_INTERNAL_ERROR;
896 
897 	if (retp != NULL)
898 		*retp = NULL;
899 	if (state == NULL ||
900 	    state->enc_keyiv == NULL ||
901 	    state->enc_ciphername == NULL)
902 		return SSH_ERR_INTERNAL_ERROR;
903 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
904 		r = SSH_ERR_INTERNAL_ERROR;
905 		goto out;
906 	}
907 	blocksize = cipher_blocksize(cipher);
908 	keylen = cipher_keylen(cipher);
909 	ivlen = cipher_ivlen(cipher);
910 	authlen = cipher_authlen(cipher);
911 	if (state->enc_keyiv_len != keylen + ivlen) {
912 		r = SSH_ERR_INVALID_FORMAT;
913 		goto out;
914 	}
915 	key = state->enc_keyiv;
916 	if ((encrypted = sshbuf_new()) == NULL ||
917 	    (encoded = sshbuf_new()) == NULL ||
918 	    (padded = sshbuf_new()) == NULL ||
919 	    (iv = malloc(ivlen)) == NULL) {
920 		r = SSH_ERR_ALLOC_FAIL;
921 		goto out;
922 	}
923 
924 	/* replace first 4 bytes of IV with index to ensure uniqueness */
925 	memcpy(iv, key + keylen, ivlen);
926 	POKE_U32(iv, state->idx);
927 
928 	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
929 	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
930 		goto out;
931 
932 	/* padded state will be encrypted */
933 	if ((r = sshbuf_putb(padded, b)) != 0)
934 		goto out;
935 	i = 0;
936 	while (sshbuf_len(padded) % blocksize) {
937 		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
938 			goto out;
939 	}
940 	encrypted_len = sshbuf_len(padded);
941 
942 	/* header including the length of state is used as AAD */
943 	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
944 		goto out;
945 	aadlen = sshbuf_len(encoded);
946 
947 	/* concat header and state */
948 	if ((r = sshbuf_putb(encoded, padded)) != 0)
949 		goto out;
950 
951 	/* reserve space for encryption of encoded data plus auth tag */
952 	/* encrypt at offset addlen */
953 	if ((r = sshbuf_reserve(encrypted,
954 	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
955 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
956 	    iv, ivlen, 1)) != 0 ||
957 	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
958 	    encrypted_len, aadlen, authlen)) != 0)
959 		goto out;
960 
961 	/* success */
962 	r = 0;
963  out:
964 	if (retp != NULL) {
965 		*retp = encrypted;
966 		encrypted = NULL;
967 	}
968 	sshbuf_free(padded);
969 	sshbuf_free(encoded);
970 	sshbuf_free(encrypted);
971 	cipher_free(ciphercontext);
972 	free(iv);
973 	return r;
974 }
975 
976 int
sshkey_xmss_decrypt_state(const struct sshkey * k,struct sshbuf * encoded,struct sshbuf ** retp)977 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
978    struct sshbuf **retp)
979 {
980 	struct ssh_xmss_state *state = k->xmss_state;
981 	struct sshbuf *copy = NULL, *decrypted = NULL;
982 	struct sshcipher_ctx *ciphercontext = NULL;
983 	const struct sshcipher *cipher = NULL;
984 	u_char *key, *iv = NULL, *dp;
985 	size_t keylen, ivlen, authlen, aadlen;
986 	u_int blocksize, encrypted_len, index;
987 	int r = SSH_ERR_INTERNAL_ERROR;
988 
989 	if (retp != NULL)
990 		*retp = NULL;
991 	if (state == NULL ||
992 	    state->enc_keyiv == NULL ||
993 	    state->enc_ciphername == NULL)
994 		return SSH_ERR_INTERNAL_ERROR;
995 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
996 		r = SSH_ERR_INVALID_FORMAT;
997 		goto out;
998 	}
999 	blocksize = cipher_blocksize(cipher);
1000 	keylen = cipher_keylen(cipher);
1001 	ivlen = cipher_ivlen(cipher);
1002 	authlen = cipher_authlen(cipher);
1003 	if (state->enc_keyiv_len != keylen + ivlen) {
1004 		r = SSH_ERR_INTERNAL_ERROR;
1005 		goto out;
1006 	}
1007 	key = state->enc_keyiv;
1008 
1009 	if ((copy = sshbuf_fromb(encoded)) == NULL ||
1010 	    (decrypted = sshbuf_new()) == NULL ||
1011 	    (iv = malloc(ivlen)) == NULL) {
1012 		r = SSH_ERR_ALLOC_FAIL;
1013 		goto out;
1014 	}
1015 
1016 	/* check magic */
1017 	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1018 	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1019 		r = SSH_ERR_INVALID_FORMAT;
1020 		goto out;
1021 	}
1022 	/* parse public portion */
1023 	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1024 	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1025 	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1026 		goto out;
1027 
1028 	/* check size of encrypted key blob */
1029 	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1030 		r = SSH_ERR_INVALID_FORMAT;
1031 		goto out;
1032 	}
1033 	/* check that an appropriate amount of auth data is present */
1034 	if (sshbuf_len(encoded) < authlen ||
1035 	    sshbuf_len(encoded) - authlen < encrypted_len) {
1036 		r = SSH_ERR_INVALID_FORMAT;
1037 		goto out;
1038 	}
1039 
1040 	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1041 
1042 	/* replace first 4 bytes of IV with index to ensure uniqueness */
1043 	memcpy(iv, key + keylen, ivlen);
1044 	POKE_U32(iv, index);
1045 
1046 	/* decrypt private state of key */
1047 	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1048 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
1049 	    iv, ivlen, 0)) != 0 ||
1050 	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1051 	    encrypted_len, aadlen, authlen)) != 0)
1052 		goto out;
1053 
1054 	/* there should be no trailing data */
1055 	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1056 		goto out;
1057 	if (sshbuf_len(encoded) != 0) {
1058 		r = SSH_ERR_INVALID_FORMAT;
1059 		goto out;
1060 	}
1061 
1062 	/* remove AAD */
1063 	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1064 		goto out;
1065 	/* XXX encrypted includes unchecked padding */
1066 
1067 	/* success */
1068 	r = 0;
1069 	if (retp != NULL) {
1070 		*retp = decrypted;
1071 		decrypted = NULL;
1072 	}
1073  out:
1074 	cipher_free(ciphercontext);
1075 	sshbuf_free(copy);
1076 	sshbuf_free(decrypted);
1077 	free(iv);
1078 	return r;
1079 }
1080 
1081 u_int32_t
sshkey_xmss_signatures_left(const struct sshkey * k)1082 sshkey_xmss_signatures_left(const struct sshkey *k)
1083 {
1084 	struct ssh_xmss_state *state = k->xmss_state;
1085 	u_int32_t idx;
1086 
1087 	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1088 	    state->maxidx) {
1089 		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1090 		if (idx < state->maxidx)
1091 			return state->maxidx - idx;
1092 	}
1093 	return 0;
1094 }
1095 
1096 int
sshkey_xmss_enable_maxsign(struct sshkey * k,u_int32_t maxsign)1097 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1098 {
1099 	struct ssh_xmss_state *state = k->xmss_state;
1100 
1101 	if (sshkey_type_plain(k->type) != KEY_XMSS)
1102 		return SSH_ERR_INVALID_ARGUMENT;
1103 	if (maxsign == 0)
1104 		return 0;
1105 	if (state->idx + maxsign < state->idx)
1106 		return SSH_ERR_INVALID_ARGUMENT;
1107 	state->maxidx = state->idx + maxsign;
1108 	return 0;
1109 }
1110