xref: /openbsd/usr.bin/ssh/sshkey-xmss.c (revision b6025feb)
1 /* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */
2 /*
3  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include <sys/types.h>
27 #include <sys/uio.h>
28 
29 #include <stdio.h>
30 #include <string.h>
31 #include <unistd.h>
32 #include <fcntl.h>
33 #include <errno.h>
34 
35 #include "ssh2.h"
36 #include "ssherr.h"
37 #include "sshbuf.h"
38 #include "cipher.h"
39 #include "sshkey.h"
40 #include "sshkey-xmss.h"
41 #include "atomicio.h"
42 #include "log.h"
43 
44 #include "xmss_fast.h"
45 
46 /* opaque internal XMSS state */
47 #define XMSS_MAGIC		"xmss-state-v1"
48 #define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
49 struct ssh_xmss_state {
50 	xmss_params	params;
51 	u_int32_t	n, w, h, k;
52 
53 	bds_state	bds;
54 	u_char		*stack;
55 	u_int32_t	stackoffset;
56 	u_char		*stacklevels;
57 	u_char		*auth;
58 	u_char		*keep;
59 	u_char		*th_nodes;
60 	u_char		*retain;
61 	treehash_inst	*treehash;
62 
63 	u_int32_t	idx;		/* state read from file */
64 	u_int32_t	maxidx;		/* restricted # of signatures */
65 	int		have_state;	/* .state file exists */
66 	int		lockfd;		/* locked in sshkey_xmss_get_state() */
67 	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
68 	char		*enc_ciphername;/* encrypt state with cipher */
69 	u_char		*enc_keyiv;	/* encrypt state with key */
70 	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
71 };
72 
/* Forward declarations for helpers defined later in this file. */
int	 sshkey_xmss_init_bds_state(struct sshkey *key);
int	 sshkey_xmss_init_enc_key(struct sshkey *key, const char *ciphername);
void	 sshkey_xmss_free_bds(struct sshkey *key);
int	 sshkey_xmss_get_state_from_file(struct sshkey *key,
	    const char *filename, int *have_file, int printerror);
int	 sshkey_xmss_encrypt_state(const struct sshkey *key,
	    struct sshbuf *in, struct sshbuf **retp);
int	 sshkey_xmss_decrypt_state(const struct sshkey *key,
	    struct sshbuf *in, struct sshbuf **retp);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *key,
	    struct sshbuf *b);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *key, struct sshbuf *b);
84 
/*
 * Log an error through sshlog(), but only when the calling function has a
 * variable or parameter called `printerror` that is non-zero.
 */
#define PRINT(...)							\
	do {								\
		if (printerror)						\
			sshlog(__FILE__, __func__, __LINE__, 0,		\
			    SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__);	\
	} while (0)
87 
88 int
sshkey_xmss_init(struct sshkey * key,const char * name)89 sshkey_xmss_init(struct sshkey *key, const char *name)
90 {
91 	struct ssh_xmss_state *state;
92 
93 	if (key->xmss_state != NULL)
94 		return SSH_ERR_INVALID_FORMAT;
95 	if (name == NULL)
96 		return SSH_ERR_INVALID_FORMAT;
97 	state = calloc(sizeof(struct ssh_xmss_state), 1);
98 	if (state == NULL)
99 		return SSH_ERR_ALLOC_FAIL;
100 	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
101 		state->n = 32;
102 		state->w = 16;
103 		state->h = 10;
104 	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
105 		state->n = 32;
106 		state->w = 16;
107 		state->h = 16;
108 	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
109 		state->n = 32;
110 		state->w = 16;
111 		state->h = 20;
112 	} else {
113 		free(state);
114 		return SSH_ERR_KEY_TYPE_UNKNOWN;
115 	}
116 	if ((key->xmss_name = strdup(name)) == NULL) {
117 		free(state);
118 		return SSH_ERR_ALLOC_FAIL;
119 	}
120 	state->k = 2;	/* XXX hardcoded */
121 	state->lockfd = -1;
122 	if (xmss_set_params(&state->params, state->n, state->h, state->w,
123 	    state->k) != 0) {
124 		free(state);
125 		return SSH_ERR_INVALID_FORMAT;
126 	}
127 	key->xmss_state = state;
128 	return 0;
129 }
130 
131 void
sshkey_xmss_free_state(struct sshkey * key)132 sshkey_xmss_free_state(struct sshkey *key)
133 {
134 	struct ssh_xmss_state *state = key->xmss_state;
135 
136 	sshkey_xmss_free_bds(key);
137 	if (state) {
138 		if (state->enc_keyiv) {
139 			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
140 			free(state->enc_keyiv);
141 		}
142 		free(state->enc_ciphername);
143 		free(state);
144 	}
145 	key->xmss_state = NULL;
146 }
147 
/* Magic string at the start of a serialised BDS state (BDS parameter k = 2). */
#define SSH_XMSS_K2_MAGIC	"k=2"
/*
 * Sizes of the BDS buffers for a state pointer x with hash length n, tree
 * height h and BDS parameter k.  num_treehash() is an element count of
 * treehash_inst; all others are byte counts.  Every use of the argument is
 * parenthesised so any pointer expression may be passed safely.
 */
#define num_stack(x)		(((x)->h + 1) * (x)->n)
#define num_stacklevels(x)	((x)->h + 1)
#define num_auth(x)		((x)->h * (x)->n)
#define num_keep(x)		(((x)->h >> 1) * (x)->n)
#define num_th_nodes(x)		(((x)->h - (x)->k) * (x)->n)
#define num_retain(x)		(((1ULL << (x)->k) - (x)->k - 1) * (x)->n)
#define num_treehash(x)		((x)->h - (x)->k)
156 
157 int
sshkey_xmss_init_bds_state(struct sshkey * key)158 sshkey_xmss_init_bds_state(struct sshkey *key)
159 {
160 	struct ssh_xmss_state *state = key->xmss_state;
161 	u_int32_t i;
162 
163 	state->stackoffset = 0;
164 	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
165 	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
166 	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
167 	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
168 	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
169 	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
170 	    (state->treehash = calloc(num_treehash(state),
171 	    sizeof(treehash_inst))) == NULL) {
172 		sshkey_xmss_free_bds(key);
173 		return SSH_ERR_ALLOC_FAIL;
174 	}
175 	for (i = 0; i < state->h - state->k; i++)
176 		state->treehash[i].node = &state->th_nodes[state->n*i];
177 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
178 	    state->stacklevels, state->auth, state->keep, state->treehash,
179 	    state->retain, 0);
180 	return 0;
181 }
182 
183 void
sshkey_xmss_free_bds(struct sshkey * key)184 sshkey_xmss_free_bds(struct sshkey *key)
185 {
186 	struct ssh_xmss_state *state = key->xmss_state;
187 
188 	if (state == NULL)
189 		return;
190 	free(state->stack);
191 	free(state->stacklevels);
192 	free(state->auth);
193 	free(state->keep);
194 	free(state->th_nodes);
195 	free(state->retain);
196 	free(state->treehash);
197 	state->stack = NULL;
198 	state->stacklevels = NULL;
199 	state->auth = NULL;
200 	state->keep = NULL;
201 	state->th_nodes = NULL;
202 	state->retain = NULL;
203 	state->treehash = NULL;
204 }
205 
206 void *
sshkey_xmss_params(const struct sshkey * key)207 sshkey_xmss_params(const struct sshkey *key)
208 {
209 	struct ssh_xmss_state *state = key->xmss_state;
210 
211 	if (state == NULL)
212 		return NULL;
213 	return &state->params;
214 }
215 
216 void *
sshkey_xmss_bds_state(const struct sshkey * key)217 sshkey_xmss_bds_state(const struct sshkey *key)
218 {
219 	struct ssh_xmss_state *state = key->xmss_state;
220 
221 	if (state == NULL)
222 		return NULL;
223 	return &state->bds;
224 }
225 
226 int
sshkey_xmss_siglen(const struct sshkey * key,size_t * lenp)227 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
228 {
229 	struct ssh_xmss_state *state = key->xmss_state;
230 
231 	if (lenp == NULL)
232 		return SSH_ERR_INVALID_ARGUMENT;
233 	if (state == NULL)
234 		return SSH_ERR_INVALID_FORMAT;
235 	*lenp = 4 + state->n +
236 	    state->params.wots_par.keysize +
237 	    state->h * state->n;
238 	return 0;
239 }
240 
241 size_t
sshkey_xmss_pklen(const struct sshkey * key)242 sshkey_xmss_pklen(const struct sshkey *key)
243 {
244 	struct ssh_xmss_state *state = key->xmss_state;
245 
246 	if (state == NULL)
247 		return 0;
248 	return state->n * 2;
249 }
250 
251 size_t
sshkey_xmss_sklen(const struct sshkey * key)252 sshkey_xmss_sklen(const struct sshkey *key)
253 {
254 	struct ssh_xmss_state *state = key->xmss_state;
255 
256 	if (state == NULL)
257 		return 0;
258 	return state->n * 4 + 4;
259 }
260 
261 int
sshkey_xmss_init_enc_key(struct sshkey * k,const char * ciphername)262 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
263 {
264 	struct ssh_xmss_state *state = k->xmss_state;
265 	const struct sshcipher *cipher;
266 	size_t keylen = 0, ivlen = 0;
267 
268 	if (state == NULL)
269 		return SSH_ERR_INVALID_ARGUMENT;
270 	if ((cipher = cipher_by_name(ciphername)) == NULL)
271 		return SSH_ERR_INTERNAL_ERROR;
272 	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
273 		return SSH_ERR_ALLOC_FAIL;
274 	keylen = cipher_keylen(cipher);
275 	ivlen = cipher_ivlen(cipher);
276 	state->enc_keyiv_len = keylen + ivlen;
277 	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
278 		free(state->enc_ciphername);
279 		state->enc_ciphername = NULL;
280 		return SSH_ERR_ALLOC_FAIL;
281 	}
282 	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
283 	return 0;
284 }
285 
286 int
sshkey_xmss_serialize_enc_key(const struct sshkey * k,struct sshbuf * b)287 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
288 {
289 	struct ssh_xmss_state *state = k->xmss_state;
290 	int r;
291 
292 	if (state == NULL || state->enc_keyiv == NULL ||
293 	    state->enc_ciphername == NULL)
294 		return SSH_ERR_INVALID_ARGUMENT;
295 	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
296 	    (r = sshbuf_put_string(b, state->enc_keyiv,
297 	    state->enc_keyiv_len)) != 0)
298 		return r;
299 	return 0;
300 }
301 
302 int
sshkey_xmss_deserialize_enc_key(struct sshkey * k,struct sshbuf * b)303 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
304 {
305 	struct ssh_xmss_state *state = k->xmss_state;
306 	size_t len;
307 	int r;
308 
309 	if (state == NULL)
310 		return SSH_ERR_INVALID_ARGUMENT;
311 	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
312 	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
313 		return r;
314 	state->enc_keyiv_len = len;
315 	return 0;
316 }
317 
318 int
sshkey_xmss_serialize_pk_info(const struct sshkey * k,struct sshbuf * b,enum sshkey_serialize_rep opts)319 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
320     enum sshkey_serialize_rep opts)
321 {
322 	struct ssh_xmss_state *state = k->xmss_state;
323 	u_char have_info = 1;
324 	u_int32_t idx;
325 	int r;
326 
327 	if (state == NULL)
328 		return SSH_ERR_INVALID_ARGUMENT;
329 	if (opts != SSHKEY_SERIALIZE_INFO)
330 		return 0;
331 	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
332 	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
333 	    (r = sshbuf_put_u32(b, idx)) != 0 ||
334 	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
335 		return r;
336 	return 0;
337 }
338 
339 int
sshkey_xmss_deserialize_pk_info(struct sshkey * k,struct sshbuf * b)340 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
341 {
342 	struct ssh_xmss_state *state = k->xmss_state;
343 	u_char have_info;
344 	int r;
345 
346 	if (state == NULL)
347 		return SSH_ERR_INVALID_ARGUMENT;
348 	/* optional */
349 	if (sshbuf_len(b) == 0)
350 		return 0;
351 	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
352 		return r;
353 	if (have_info != 1)
354 		return SSH_ERR_INVALID_ARGUMENT;
355 	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
356 	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
357 		return r;
358 	return 0;
359 }
360 
361 int
sshkey_xmss_generate_private_key(struct sshkey * k,int bits)362 sshkey_xmss_generate_private_key(struct sshkey *k, int bits)
363 {
364 	int r;
365 	const char *name;
366 
367 	if (bits == 10) {
368 		name = XMSS_SHA2_256_W16_H10_NAME;
369 	} else if (bits == 16) {
370 		name = XMSS_SHA2_256_W16_H16_NAME;
371 	} else if (bits == 20) {
372 		name = XMSS_SHA2_256_W16_H20_NAME;
373 	} else {
374 		name = XMSS_DEFAULT_NAME;
375 	}
376 	if ((r = sshkey_xmss_init(k, name)) != 0 ||
377 	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
378 	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
379 		return r;
380 	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
381 	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
382 		return SSH_ERR_ALLOC_FAIL;
383 	}
384 	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
385 	    sshkey_xmss_params(k));
386 	return 0;
387 }
388 
389 int
sshkey_xmss_get_state_from_file(struct sshkey * k,const char * filename,int * have_file,int printerror)390 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
391     int *have_file, int printerror)
392 {
393 	struct sshbuf *b = NULL, *enc = NULL;
394 	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
395 	u_int32_t len;
396 	unsigned char buf[4], *data = NULL;
397 
398 	*have_file = 0;
399 	if ((fd = open(filename, O_RDONLY)) >= 0) {
400 		*have_file = 1;
401 		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
402 			PRINT("corrupt state file: %s", filename);
403 			goto done;
404 		}
405 		len = PEEK_U32(buf);
406 		if ((data = calloc(len, 1)) == NULL) {
407 			ret = SSH_ERR_ALLOC_FAIL;
408 			goto done;
409 		}
410 		if (atomicio(read, fd, data, len) != len) {
411 			PRINT("cannot read blob: %s", filename);
412 			goto done;
413 		}
414 		if ((enc = sshbuf_from(data, len)) == NULL) {
415 			ret = SSH_ERR_ALLOC_FAIL;
416 			goto done;
417 		}
418 		sshkey_xmss_free_bds(k);
419 		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
420 			ret = r;
421 			goto done;
422 		}
423 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
424 			ret = r;
425 			goto done;
426 		}
427 		ret = 0;
428 	}
429 done:
430 	if (fd != -1)
431 		close(fd);
432 	free(data);
433 	sshbuf_free(enc);
434 	sshbuf_free(b);
435 	return ret;
436 }
437 
438 int
sshkey_xmss_get_state(const struct sshkey * k,int printerror)439 sshkey_xmss_get_state(const struct sshkey *k, int printerror)
440 {
441 	struct ssh_xmss_state *state = k->xmss_state;
442 	u_int32_t idx = 0;
443 	char *filename = NULL;
444 	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
445 	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
446 	int ret = SSH_ERR_INVALID_ARGUMENT, r;
447 
448 	if (state == NULL)
449 		goto done;
450 	/*
451 	 * If maxidx is set, then we are allowed a limited number
452 	 * of signatures, but don't need to access the disk.
453 	 * Otherwise we need to deal with the on-disk state.
454 	 */
455 	if (state->maxidx) {
456 		/* xmss_sk always contains the current state */
457 		idx = PEEK_U32(k->xmss_sk);
458 		if (idx < state->maxidx) {
459 			state->allow_update = 1;
460 			return 0;
461 		}
462 		return SSH_ERR_INVALID_ARGUMENT;
463 	}
464 	if ((filename = k->xmss_filename) == NULL)
465 		goto done;
466 	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
467 	    asprintf(&statefile, "%s.state", filename) == -1 ||
468 	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
469 		ret = SSH_ERR_ALLOC_FAIL;
470 		goto done;
471 	}
472 	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
473 		ret = SSH_ERR_SYSTEM_ERROR;
474 		PRINT("cannot open/create: %s", lockfile);
475 		goto done;
476 	}
477 	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
478 		if (errno != EWOULDBLOCK) {
479 			ret = SSH_ERR_SYSTEM_ERROR;
480 			PRINT("cannot lock: %s", lockfile);
481 			goto done;
482 		}
483 		if (++tries > 10) {
484 			ret = SSH_ERR_SYSTEM_ERROR;
485 			PRINT("giving up on: %s", lockfile);
486 			goto done;
487 		}
488 		usleep(1000*100*tries);
489 	}
490 	/* XXX no longer const */
491 	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
492 	    statefile, &have_state, printerror)) != 0) {
493 		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
494 		    ostatefile, &have_ostate, printerror)) == 0) {
495 			state->allow_update = 1;
496 			r = sshkey_xmss_forward_state(k, 1);
497 			state->idx = PEEK_U32(k->xmss_sk);
498 			state->allow_update = 0;
499 		}
500 	}
501 	if (!have_state && !have_ostate) {
502 		/* check that bds state is initialized */
503 		if (state->bds.auth == NULL)
504 			goto done;
505 		PRINT("start from scratch idx 0: %u", state->idx);
506 	} else if (r != 0) {
507 		ret = r;
508 		goto done;
509 	}
510 	if (state->idx + 1 < state->idx) {
511 		PRINT("state wrap: %u", state->idx);
512 		goto done;
513 	}
514 	state->have_state = have_state;
515 	state->lockfd = lockfd;
516 	state->allow_update = 1;
517 	lockfd = -1;
518 	ret = 0;
519 done:
520 	if (lockfd != -1)
521 		close(lockfd);
522 	free(lockfile);
523 	free(statefile);
524 	free(ostatefile);
525 	return ret;
526 }
527 
528 int
sshkey_xmss_forward_state(const struct sshkey * k,u_int32_t reserve)529 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
530 {
531 	struct ssh_xmss_state *state = k->xmss_state;
532 	u_char *sig = NULL;
533 	size_t required_siglen;
534 	unsigned long long smlen;
535 	u_char data;
536 	int ret, r;
537 
538 	if (state == NULL || !state->allow_update)
539 		return SSH_ERR_INVALID_ARGUMENT;
540 	if (reserve == 0)
541 		return SSH_ERR_INVALID_ARGUMENT;
542 	if (state->idx + reserve <= state->idx)
543 		return SSH_ERR_INVALID_ARGUMENT;
544 	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
545 		return r;
546 	if ((sig = malloc(required_siglen)) == NULL)
547 		return SSH_ERR_ALLOC_FAIL;
548 	while (reserve-- > 0) {
549 		state->idx = PEEK_U32(k->xmss_sk);
550 		smlen = required_siglen;
551 		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
552 		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
553 			r = SSH_ERR_INVALID_ARGUMENT;
554 			break;
555 		}
556 	}
557 	free(sig);
558 	return r;
559 }
560 
561 int
sshkey_xmss_update_state(const struct sshkey * k,int printerror)562 sshkey_xmss_update_state(const struct sshkey *k, int printerror)
563 {
564 	struct ssh_xmss_state *state = k->xmss_state;
565 	struct sshbuf *b = NULL, *enc = NULL;
566 	u_int32_t idx = 0;
567 	unsigned char buf[4];
568 	char *filename = NULL;
569 	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
570 	int fd = -1;
571 	int ret = SSH_ERR_INVALID_ARGUMENT;
572 
573 	if (state == NULL || !state->allow_update)
574 		return ret;
575 	if (state->maxidx) {
576 		/* no update since the number of signatures is limited */
577 		ret = 0;
578 		goto done;
579 	}
580 	idx = PEEK_U32(k->xmss_sk);
581 	if (idx == state->idx) {
582 		/* no signature happened, no need to update */
583 		ret = 0;
584 		goto done;
585 	} else if (idx != state->idx + 1) {
586 		PRINT("more than one signature happened: idx %u state %u",
587 		    idx, state->idx);
588 		goto done;
589 	}
590 	state->idx = idx;
591 	if ((filename = k->xmss_filename) == NULL)
592 		goto done;
593 	if (asprintf(&statefile, "%s.state", filename) == -1 ||
594 	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
595 	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
596 		ret = SSH_ERR_ALLOC_FAIL;
597 		goto done;
598 	}
599 	unlink(nstatefile);
600 	if ((b = sshbuf_new()) == NULL) {
601 		ret = SSH_ERR_ALLOC_FAIL;
602 		goto done;
603 	}
604 	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
605 		PRINT("SERLIALIZE FAILED: %d", ret);
606 		goto done;
607 	}
608 	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
609 		PRINT("ENCRYPT FAILED: %d", ret);
610 		goto done;
611 	}
612 	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
613 		ret = SSH_ERR_SYSTEM_ERROR;
614 		PRINT("open new state file: %s", nstatefile);
615 		goto done;
616 	}
617 	POKE_U32(buf, sshbuf_len(enc));
618 	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
619 		ret = SSH_ERR_SYSTEM_ERROR;
620 		PRINT("write new state file hdr: %s", nstatefile);
621 		close(fd);
622 		goto done;
623 	}
624 	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
625 	    sshbuf_len(enc)) {
626 		ret = SSH_ERR_SYSTEM_ERROR;
627 		PRINT("write new state file data: %s", nstatefile);
628 		close(fd);
629 		goto done;
630 	}
631 	if (fsync(fd) == -1) {
632 		ret = SSH_ERR_SYSTEM_ERROR;
633 		PRINT("sync new state file: %s", nstatefile);
634 		close(fd);
635 		goto done;
636 	}
637 	if (close(fd) == -1) {
638 		ret = SSH_ERR_SYSTEM_ERROR;
639 		PRINT("close new state file: %s", nstatefile);
640 		goto done;
641 	}
642 	if (state->have_state) {
643 		unlink(ostatefile);
644 		if (link(statefile, ostatefile)) {
645 			ret = SSH_ERR_SYSTEM_ERROR;
646 			PRINT("backup state %s to %s", statefile, ostatefile);
647 			goto done;
648 		}
649 	}
650 	if (rename(nstatefile, statefile) == -1) {
651 		ret = SSH_ERR_SYSTEM_ERROR;
652 		PRINT("rename %s to %s", nstatefile, statefile);
653 		goto done;
654 	}
655 	ret = 0;
656 done:
657 	if (state->lockfd != -1) {
658 		close(state->lockfd);
659 		state->lockfd = -1;
660 	}
661 	if (nstatefile)
662 		unlink(nstatefile);
663 	free(statefile);
664 	free(ostatefile);
665 	free(nstatefile);
666 	sshbuf_free(b);
667 	sshbuf_free(enc);
668 	return ret;
669 }
670 
671 int
sshkey_xmss_serialize_state(const struct sshkey * k,struct sshbuf * b)672 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
673 {
674 	struct ssh_xmss_state *state = k->xmss_state;
675 	treehash_inst *th;
676 	u_int32_t i, node;
677 	int r;
678 
679 	if (state == NULL)
680 		return SSH_ERR_INVALID_ARGUMENT;
681 	if (state->stack == NULL)
682 		return SSH_ERR_INVALID_ARGUMENT;
683 	state->stackoffset = state->bds.stackoffset;	/* copy back */
684 	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
685 	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
686 	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
687 	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
688 	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
689 	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
690 	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
691 	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
692 	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
693 	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
694 		return r;
695 	for (i = 0; i < num_treehash(state); i++) {
696 		th = &state->treehash[i];
697 		node = th->node - state->th_nodes;
698 		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
699 		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
700 		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
701 		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
702 		    (r = sshbuf_put_u32(b, node)) != 0)
703 			return r;
704 	}
705 	return 0;
706 }
707 
708 int
sshkey_xmss_serialize_state_opt(const struct sshkey * k,struct sshbuf * b,enum sshkey_serialize_rep opts)709 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
710     enum sshkey_serialize_rep opts)
711 {
712 	struct ssh_xmss_state *state = k->xmss_state;
713 	int r = SSH_ERR_INVALID_ARGUMENT;
714 	u_char have_stack, have_filename, have_enc;
715 
716 	if (state == NULL)
717 		return SSH_ERR_INVALID_ARGUMENT;
718 	if ((r = sshbuf_put_u8(b, opts)) != 0)
719 		return r;
720 	switch (opts) {
721 	case SSHKEY_SERIALIZE_STATE:
722 		r = sshkey_xmss_serialize_state(k, b);
723 		break;
724 	case SSHKEY_SERIALIZE_FULL:
725 		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
726 			return r;
727 		r = sshkey_xmss_serialize_state(k, b);
728 		break;
729 	case SSHKEY_SERIALIZE_SHIELD:
730 		/* all of stack/filename/enc are optional */
731 		have_stack = state->stack != NULL;
732 		if ((r = sshbuf_put_u8(b, have_stack)) != 0)
733 			return r;
734 		if (have_stack) {
735 			state->idx = PEEK_U32(k->xmss_sk);	/* update */
736 			if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
737 				return r;
738 		}
739 		have_filename = k->xmss_filename != NULL;
740 		if ((r = sshbuf_put_u8(b, have_filename)) != 0)
741 			return r;
742 		if (have_filename &&
743 		    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
744 			return r;
745 		have_enc = state->enc_keyiv != NULL;
746 		if ((r = sshbuf_put_u8(b, have_enc)) != 0)
747 			return r;
748 		if (have_enc &&
749 		    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
750 			return r;
751 		if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
752 		    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
753 			return r;
754 		break;
755 	case SSHKEY_SERIALIZE_DEFAULT:
756 		r = 0;
757 		break;
758 	default:
759 		r = SSH_ERR_INVALID_ARGUMENT;
760 		break;
761 	}
762 	return r;
763 }
764 
765 int
sshkey_xmss_deserialize_state(struct sshkey * k,struct sshbuf * b)766 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
767 {
768 	struct ssh_xmss_state *state = k->xmss_state;
769 	treehash_inst *th;
770 	u_int32_t i, lh, node;
771 	size_t ls, lsl, la, lk, ln, lr;
772 	char *magic;
773 	int r = SSH_ERR_INTERNAL_ERROR;
774 
775 	if (state == NULL)
776 		return SSH_ERR_INVALID_ARGUMENT;
777 	if (k->xmss_sk == NULL)
778 		return SSH_ERR_INVALID_ARGUMENT;
779 	if ((state->treehash = calloc(num_treehash(state),
780 	    sizeof(treehash_inst))) == NULL)
781 		return SSH_ERR_ALLOC_FAIL;
782 	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
783 	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
784 	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
785 	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
786 	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
787 	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
788 	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
789 	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
790 	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
791 	    (r = sshbuf_get_u32(b, &lh)) != 0)
792 		goto out;
793 	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
794 		r = SSH_ERR_INVALID_ARGUMENT;
795 		goto out;
796 	}
797 	/* XXX check stackoffset */
798 	if (ls != num_stack(state) ||
799 	    lsl != num_stacklevels(state) ||
800 	    la != num_auth(state) ||
801 	    lk != num_keep(state) ||
802 	    ln != num_th_nodes(state) ||
803 	    lr != num_retain(state) ||
804 	    lh != num_treehash(state)) {
805 		r = SSH_ERR_INVALID_ARGUMENT;
806 		goto out;
807 	}
808 	for (i = 0; i < num_treehash(state); i++) {
809 		th = &state->treehash[i];
810 		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
811 		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
812 		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
813 		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
814 		    (r = sshbuf_get_u32(b, &node)) != 0)
815 			goto out;
816 		if (node < num_th_nodes(state))
817 			th->node = &state->th_nodes[node];
818 	}
819 	POKE_U32(k->xmss_sk, state->idx);
820 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
821 	    state->stacklevels, state->auth, state->keep, state->treehash,
822 	    state->retain, 0);
823 	/* success */
824 	r = 0;
825  out:
826 	free(magic);
827 	return r;
828 }
829 
830 int
sshkey_xmss_deserialize_state_opt(struct sshkey * k,struct sshbuf * b)831 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
832 {
833 	struct ssh_xmss_state *state = k->xmss_state;
834 	enum sshkey_serialize_rep opts;
835 	u_char have_state, have_stack, have_filename, have_enc;
836 	int r;
837 
838 	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
839 		return r;
840 
841 	opts = have_state;
842 	switch (opts) {
843 	case SSHKEY_SERIALIZE_DEFAULT:
844 		r = 0;
845 		break;
846 	case SSHKEY_SERIALIZE_SHIELD:
847 		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
848 			return r;
849 		if (have_stack &&
850 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
851 			return r;
852 		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
853 			return r;
854 		if (have_filename &&
855 		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
856 			return r;
857 		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
858 			return r;
859 		if (have_enc &&
860 		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
861 			return r;
862 		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
863 		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
864 			return r;
865 		break;
866 	case SSHKEY_SERIALIZE_STATE:
867 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
868 			return r;
869 		break;
870 	case SSHKEY_SERIALIZE_FULL:
871 		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
872 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
873 			return r;
874 		break;
875 	default:
876 		r = SSH_ERR_INVALID_FORMAT;
877 		break;
878 	}
879 	return r;
880 }
881 
882 int
sshkey_xmss_encrypt_state(const struct sshkey * k,struct sshbuf * b,struct sshbuf ** retp)883 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
884    struct sshbuf **retp)
885 {
886 	struct ssh_xmss_state *state = k->xmss_state;
887 	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
888 	struct sshcipher_ctx *ciphercontext = NULL;
889 	const struct sshcipher *cipher;
890 	u_char *cp, *key, *iv = NULL;
891 	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
892 	int r = SSH_ERR_INTERNAL_ERROR;
893 
894 	if (retp != NULL)
895 		*retp = NULL;
896 	if (state == NULL ||
897 	    state->enc_keyiv == NULL ||
898 	    state->enc_ciphername == NULL)
899 		return SSH_ERR_INTERNAL_ERROR;
900 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
901 		r = SSH_ERR_INTERNAL_ERROR;
902 		goto out;
903 	}
904 	blocksize = cipher_blocksize(cipher);
905 	keylen = cipher_keylen(cipher);
906 	ivlen = cipher_ivlen(cipher);
907 	authlen = cipher_authlen(cipher);
908 	if (state->enc_keyiv_len != keylen + ivlen) {
909 		r = SSH_ERR_INVALID_FORMAT;
910 		goto out;
911 	}
912 	key = state->enc_keyiv;
913 	if ((encrypted = sshbuf_new()) == NULL ||
914 	    (encoded = sshbuf_new()) == NULL ||
915 	    (padded = sshbuf_new()) == NULL ||
916 	    (iv = malloc(ivlen)) == NULL) {
917 		r = SSH_ERR_ALLOC_FAIL;
918 		goto out;
919 	}
920 
921 	/* replace first 4 bytes of IV with index to ensure uniqueness */
922 	memcpy(iv, key + keylen, ivlen);
923 	POKE_U32(iv, state->idx);
924 
925 	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
926 	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
927 		goto out;
928 
929 	/* padded state will be encrypted */
930 	if ((r = sshbuf_putb(padded, b)) != 0)
931 		goto out;
932 	i = 0;
933 	while (sshbuf_len(padded) % blocksize) {
934 		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
935 			goto out;
936 	}
937 	encrypted_len = sshbuf_len(padded);
938 
939 	/* header including the length of state is used as AAD */
940 	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
941 		goto out;
942 	aadlen = sshbuf_len(encoded);
943 
944 	/* concat header and state */
945 	if ((r = sshbuf_putb(encoded, padded)) != 0)
946 		goto out;
947 
948 	/* reserve space for encryption of encoded data plus auth tag */
949 	/* encrypt at offset addlen */
950 	if ((r = sshbuf_reserve(encrypted,
951 	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
952 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
953 	    iv, ivlen, 1)) != 0 ||
954 	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
955 	    encrypted_len, aadlen, authlen)) != 0)
956 		goto out;
957 
958 	/* success */
959 	r = 0;
960  out:
961 	if (retp != NULL) {
962 		*retp = encrypted;
963 		encrypted = NULL;
964 	}
965 	sshbuf_free(padded);
966 	sshbuf_free(encoded);
967 	sshbuf_free(encrypted);
968 	cipher_free(ciphercontext);
969 	free(iv);
970 	return r;
971 }
972 
/*
 * Decrypt an XMSS state blob using the cipher named by
 * k->xmss_state->enc_ciphername and the key/IV held in enc_keyiv.
 *
 * Layout of `encoded` as parsed below (mirrors what the state encryption
 * routine above writes):
 *   XMSS_MAGIC (including its trailing NUL) || u32 index ||
 *   u32 encrypted_len || ciphertext[encrypted_len] || tag[authlen]
 * The magic, index and length fields together form the AEAD AAD.
 *
 * On success returns 0; if retp is non-NULL, *retp receives a newly
 * allocated buffer holding the decrypted state (padding still present,
 * see XXX below) which the caller must free. If retp is NULL the
 * decrypted data is discarded after the blob has been checked.
 * On failure returns an SSH_ERR_* code and *retp (if given) stays NULL.
 *
 * NOTE: `encoded` is consumed while it is parsed, so its read position is
 * advanced: fully on success, possibly only partially on error.
 */
973 int
sshkey_xmss_decrypt_state(const struct sshkey * k,struct sshbuf * encoded,struct sshbuf ** retp)974 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
975    struct sshbuf **retp)
976 {
977 	struct ssh_xmss_state *state = k->xmss_state;
978 	struct sshbuf *copy = NULL, *decrypted = NULL;
979 	struct sshcipher_ctx *ciphercontext = NULL;
980 	const struct sshcipher *cipher = NULL;
981 	u_char *key, *iv = NULL, *dp;
982 	size_t keylen, ivlen, authlen, aadlen;
983 	u_int blocksize, encrypted_len, index;
984 	int r = SSH_ERR_INTERNAL_ERROR;
985
986 	if (retp != NULL)
987 		*retp = NULL;
	/* nothing allocated yet, so these cases can return directly */
988 	if (state == NULL ||
989 	    state->enc_keyiv == NULL ||
990 	    state->enc_ciphername == NULL)
991 		return SSH_ERR_INTERNAL_ERROR;
992 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
993 		r = SSH_ERR_INVALID_FORMAT;
994 		goto out;
995 	}
996 	blocksize = cipher_blocksize(cipher);
997 	keylen = cipher_keylen(cipher);
998 	ivlen = cipher_ivlen(cipher);
999 	authlen = cipher_authlen(cipher);
	/* enc_keyiv is the cipher key immediately followed by the base IV */
1000 	if (state->enc_keyiv_len != keylen + ivlen) {
1001 		r = SSH_ERR_INTERNAL_ERROR;
1002 		goto out;
1003 	}
1004 	key = state->enc_keyiv;
1005
	/*
	 * `copy` keeps the complete blob: `encoded` is consumed while the
	 * header is parsed, and the full blob (AAD + ciphertext + tag) is
	 * still needed as the cipher_crypt input.
	 */
1006 	if ((copy = sshbuf_fromb(encoded)) == NULL ||
1007 	    (decrypted = sshbuf_new()) == NULL ||
1008 	    (iv = malloc(ivlen)) == NULL) {
1009 		r = SSH_ERR_ALLOC_FAIL;
1010 		goto out;
1011 	}
1012
1013 	/* check magic (sizeof, so the trailing NUL is compared too) */
1014 	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1015 	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1016 		r = SSH_ERR_INVALID_FORMAT;
1017 		goto out;
1018 	}
1019 	/* parse public portion: magic, signature index, ciphertext length */
1020 	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1021 	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1022 	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1023 		goto out;
1024
1025 	/* check size of encrypted key blob: non-zero multiple of blocksize */
1026 	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1027 		r = SSH_ERR_INVALID_FORMAT;
1028 		goto out;
1029 	}
1030 	/* check that an appropriate amount of auth data is present */
	/* (written as a subtraction after a bound check to avoid overflow) */
1031 	if (sshbuf_len(encoded) < authlen ||
1032 	    sshbuf_len(encoded) - authlen < encrypted_len) {
1033 		r = SSH_ERR_INVALID_FORMAT;
1034 		goto out;
1035 	}
1036
	/* the header bytes consumed so far (magic, index, length) are the AAD */
1037 	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1038
1039 	/* replace first 4 bytes of IV with index to ensure uniqueness */
	/*
	 * NOTE(review): `index` comes from the blob and is not compared with
	 * state->idx in this function; any such check must happen in the caller.
	 */
1040 	memcpy(iv, key + keylen, ivlen);
1041 	POKE_U32(iv, index);
1042
1043 	/* decrypt private state of key */
	/*
	 * Input is read from the start of `copy` (AAD, then ciphertext, then
	 * tag). The output reserves aadlen + encrypted_len bytes; the leading
	 * aadlen bytes correspond to the AAD and are stripped further below.
	 * A tag mismatch is expected to be reported as a non-zero return from
	 * cipher_crypt (assumption; see cipher.c).
	 */
1044 	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1045 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
1046 	    iv, ivlen, 0)) != 0 ||
1047 	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1048 	    encrypted_len, aadlen, authlen)) != 0)
1049 		goto out;
1050
1051 	/* there should be no trailing data after the ciphertext and tag */
1052 	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1053 		goto out;
1054 	if (sshbuf_len(encoded) != 0) {
1055 		r = SSH_ERR_INVALID_FORMAT;
1056 		goto out;
1057 	}
1058
1059 	/* remove AAD */
1060 	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1061 		goto out;
1062 	/* XXX encrypted includes unchecked padding */
1063
1064 	/* success */
1065 	r = 0;
	/* hand ownership to the caller; NULLing prevents the free below */
1066 	if (retp != NULL) {
1067 		*retp = decrypted;
1068 		decrypted = NULL;
1069 	}
1070  out:
1071 	cipher_free(ciphercontext);
1072 	sshbuf_free(copy);
1073 	sshbuf_free(decrypted);
1074 	free(iv);
1075 	return r;
1076 }
1077 
1078 u_int32_t
sshkey_xmss_signatures_left(const struct sshkey * k)1079 sshkey_xmss_signatures_left(const struct sshkey *k)
1080 {
1081 	struct ssh_xmss_state *state = k->xmss_state;
1082 	u_int32_t idx;
1083 
1084 	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1085 	    state->maxidx) {
1086 		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1087 		if (idx < state->maxidx)
1088 			return state->maxidx - idx;
1089 	}
1090 	return 0;
1091 }
1092 
1093 int
sshkey_xmss_enable_maxsign(struct sshkey * k,u_int32_t maxsign)1094 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1095 {
1096 	struct ssh_xmss_state *state = k->xmss_state;
1097 
1098 	if (sshkey_type_plain(k->type) != KEY_XMSS)
1099 		return SSH_ERR_INVALID_ARGUMENT;
1100 	if (maxsign == 0)
1101 		return 0;
1102 	if (state->idx + maxsign < state->idx)
1103 		return SSH_ERR_INVALID_ARGUMENT;
1104 	state->maxidx = state->idx + maxsign;
1105 	return 0;
1106 }
1107