1 /* $NetBSD: sshkey-xmss.c,v 1.9 2023/07/26 17:58:16 christos Exp $ */
2 /* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */
3 /*
4 * Copyright (c) 2017 Markus Friedl. All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 * 1. Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 * 2. Redistributions in binary form must reproduce the above copyright
12 * notice, this list of conditions and the following disclaimer in the
13 * documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
17 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
18 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
19 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
20 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
21 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
22 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
24 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26 #include "includes.h"
27 __RCSID("$NetBSD: sshkey-xmss.c,v 1.9 2023/07/26 17:58:16 christos Exp $");
28
29 #include <sys/types.h>
30 #include <sys/uio.h>
31
32 #include <stdio.h>
33 #include <string.h>
34 #include <unistd.h>
35 #include <fcntl.h>
36 #include <errno.h>
37
38 #include "ssh2.h"
39 #include "ssherr.h"
40 #include "sshbuf.h"
41 #include "cipher.h"
42 #include "sshkey.h"
43 #include "sshkey-xmss.h"
44 #include "atomicio.h"
45 #include "log.h"
46
47 #include "xmss_fast.h"
48
49 /* opaque internal XMSS state */
50 #define XMSS_MAGIC "xmss-state-v1"
51 #define XMSS_CIPHERNAME "aes256-gcm@openssh.com"
52 struct ssh_xmss_state {
53 xmss_params params;
54 u_int32_t n, w, h, k;
55
56 bds_state bds;
57 u_char *stack;
58 u_int32_t stackoffset;
59 u_char *stacklevels;
60 u_char *auth;
61 u_char *keep;
62 u_char *th_nodes;
63 u_char *retain;
64 treehash_inst *treehash;
65
66 u_int32_t idx; /* state read from file */
67 u_int32_t maxidx; /* restricted # of signatures */
68 int have_state; /* .state file exists */
69 int lockfd; /* locked in sshkey_xmss_get_state() */
70 u_char allow_update; /* allow sshkey_xmss_update_state() */
71 char *enc_ciphername;/* encrypt state with cipher */
72 u_char *enc_keyiv; /* encrypt state with key */
73 u_int32_t enc_keyiv_len; /* length of enc_keyiv */
74 };
75
/* Forward declarations for helpers that are defined further down. */
int	sshkey_xmss_init_bds_state(struct sshkey *key);
int	sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername);
void	sshkey_xmss_free_bds(struct sshkey *key);
int	sshkey_xmss_get_state_from_file(struct sshkey *k,
	    const char *filename, int *have_file, int printerror);
int	sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
	    struct sshbuf **retp);
int	sshkey_xmss_decrypt_state(const struct sshkey *k,
	    struct sshbuf *encoded, struct sshbuf **retp);
int	sshkey_xmss_serialize_enc_key(const struct sshkey *k,
	    struct sshbuf *b);
int	sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b);

/*
 * Emit an error-level log message, but only when the variable `printerror'
 * in the calling scope is non-zero.
 */
#define PRINT(...) \
	do { \
		if (printerror) \
			sshlog(__FILE__, __func__, __LINE__, \
			    0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); \
	} while (0)
90
91 int
sshkey_xmss_init(struct sshkey * key,const char * name)92 sshkey_xmss_init(struct sshkey *key, const char *name)
93 {
94 struct ssh_xmss_state *state;
95
96 if (key->xmss_state != NULL)
97 return SSH_ERR_INVALID_FORMAT;
98 if (name == NULL)
99 return SSH_ERR_INVALID_FORMAT;
100 state = calloc(sizeof(struct ssh_xmss_state), 1);
101 if (state == NULL)
102 return SSH_ERR_ALLOC_FAIL;
103 if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
104 state->n = 32;
105 state->w = 16;
106 state->h = 10;
107 } else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
108 state->n = 32;
109 state->w = 16;
110 state->h = 16;
111 } else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
112 state->n = 32;
113 state->w = 16;
114 state->h = 20;
115 } else {
116 free(state);
117 return SSH_ERR_KEY_TYPE_UNKNOWN;
118 }
119 if ((key->xmss_name = strdup(name)) == NULL) {
120 free(state);
121 return SSH_ERR_ALLOC_FAIL;
122 }
123 state->k = 2; /* XXX hardcoded */
124 state->lockfd = -1;
125 if (xmss_set_params(&state->params, state->n, state->h, state->w,
126 state->k) != 0) {
127 free(state);
128 return SSH_ERR_INVALID_FORMAT;
129 }
130 key->xmss_state = state;
131 return 0;
132 }
133
134 void
sshkey_xmss_free_state(struct sshkey * key)135 sshkey_xmss_free_state(struct sshkey *key)
136 {
137 struct ssh_xmss_state *state = key->xmss_state;
138
139 sshkey_xmss_free_bds(key);
140 if (state) {
141 if (state->enc_keyiv) {
142 explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
143 free(state->enc_keyiv);
144 }
145 free(state->enc_ciphername);
146 free(state);
147 }
148 key->xmss_state = NULL;
149 }
150
/* Magic string written at the start of a serialized BDS state. */
#define SSH_XMSS_K2_MAGIC	"k=2"

/*
 * Sizes of the BDS buffers for a state pointer `x' (anything exposing the
 * members h, n and k).  The argument is parenthesized so that expressions
 * such as `&obj' or `ptr + 1' expand correctly.  All results are byte
 * counts except num_treehash(), which counts treehash_inst entries.
 */
#define num_stack(x)		((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)	(((x)->h) + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)		(((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
159
160 int
sshkey_xmss_init_bds_state(struct sshkey * key)161 sshkey_xmss_init_bds_state(struct sshkey *key)
162 {
163 struct ssh_xmss_state *state = key->xmss_state;
164 u_int32_t i;
165
166 state->stackoffset = 0;
167 if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
168 (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
169 (state->auth = calloc(num_auth(state), 1)) == NULL ||
170 (state->keep = calloc(num_keep(state), 1)) == NULL ||
171 (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
172 (state->retain = calloc(num_retain(state), 1)) == NULL ||
173 (state->treehash = calloc(num_treehash(state),
174 sizeof(treehash_inst))) == NULL) {
175 sshkey_xmss_free_bds(key);
176 return SSH_ERR_ALLOC_FAIL;
177 }
178 for (i = 0; i < state->h - state->k; i++)
179 state->treehash[i].node = &state->th_nodes[state->n*i];
180 xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
181 state->stacklevels, state->auth, state->keep, state->treehash,
182 state->retain, 0);
183 return 0;
184 }
185
186 void
sshkey_xmss_free_bds(struct sshkey * key)187 sshkey_xmss_free_bds(struct sshkey *key)
188 {
189 struct ssh_xmss_state *state = key->xmss_state;
190
191 if (state == NULL)
192 return;
193 free(state->stack);
194 free(state->stacklevels);
195 free(state->auth);
196 free(state->keep);
197 free(state->th_nodes);
198 free(state->retain);
199 free(state->treehash);
200 state->stack = NULL;
201 state->stacklevels = NULL;
202 state->auth = NULL;
203 state->keep = NULL;
204 state->th_nodes = NULL;
205 state->retain = NULL;
206 state->treehash = NULL;
207 }
208
209 void *
sshkey_xmss_params(const struct sshkey * key)210 sshkey_xmss_params(const struct sshkey *key)
211 {
212 struct ssh_xmss_state *state = key->xmss_state;
213
214 if (state == NULL)
215 return NULL;
216 return &state->params;
217 }
218
219 void *
sshkey_xmss_bds_state(const struct sshkey * key)220 sshkey_xmss_bds_state(const struct sshkey *key)
221 {
222 struct ssh_xmss_state *state = key->xmss_state;
223
224 if (state == NULL)
225 return NULL;
226 return &state->bds;
227 }
228
229 int
sshkey_xmss_siglen(const struct sshkey * key,size_t * lenp)230 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
231 {
232 struct ssh_xmss_state *state = key->xmss_state;
233
234 if (lenp == NULL)
235 return SSH_ERR_INVALID_ARGUMENT;
236 if (state == NULL)
237 return SSH_ERR_INVALID_FORMAT;
238 *lenp = 4 + state->n +
239 state->params.wots_par.keysize +
240 state->h * state->n;
241 return 0;
242 }
243
244 size_t
sshkey_xmss_pklen(const struct sshkey * key)245 sshkey_xmss_pklen(const struct sshkey *key)
246 {
247 struct ssh_xmss_state *state = key->xmss_state;
248
249 if (state == NULL)
250 return 0;
251 return state->n * 2;
252 }
253
254 size_t
sshkey_xmss_sklen(const struct sshkey * key)255 sshkey_xmss_sklen(const struct sshkey *key)
256 {
257 struct ssh_xmss_state *state = key->xmss_state;
258
259 if (state == NULL)
260 return 0;
261 return state->n * 4 + 4;
262 }
263
264 int
sshkey_xmss_init_enc_key(struct sshkey * k,const char * ciphername)265 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
266 {
267 struct ssh_xmss_state *state = k->xmss_state;
268 const struct sshcipher *cipher;
269 size_t keylen = 0, ivlen = 0;
270
271 if (state == NULL)
272 return SSH_ERR_INVALID_ARGUMENT;
273 if ((cipher = cipher_by_name(ciphername)) == NULL)
274 return SSH_ERR_INTERNAL_ERROR;
275 if ((state->enc_ciphername = strdup(ciphername)) == NULL)
276 return SSH_ERR_ALLOC_FAIL;
277 keylen = cipher_keylen(cipher);
278 ivlen = cipher_ivlen(cipher);
279 state->enc_keyiv_len = keylen + ivlen;
280 if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
281 free(state->enc_ciphername);
282 state->enc_ciphername = NULL;
283 return SSH_ERR_ALLOC_FAIL;
284 }
285 arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
286 return 0;
287 }
288
289 int
sshkey_xmss_serialize_enc_key(const struct sshkey * k,struct sshbuf * b)290 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
291 {
292 struct ssh_xmss_state *state = k->xmss_state;
293 int r;
294
295 if (state == NULL || state->enc_keyiv == NULL ||
296 state->enc_ciphername == NULL)
297 return SSH_ERR_INVALID_ARGUMENT;
298 if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
299 (r = sshbuf_put_string(b, state->enc_keyiv,
300 state->enc_keyiv_len)) != 0)
301 return r;
302 return 0;
303 }
304
305 int
sshkey_xmss_deserialize_enc_key(struct sshkey * k,struct sshbuf * b)306 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
307 {
308 struct ssh_xmss_state *state = k->xmss_state;
309 size_t len;
310 int r;
311
312 if (state == NULL)
313 return SSH_ERR_INVALID_ARGUMENT;
314 if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
315 (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
316 return r;
317 state->enc_keyiv_len = len;
318 return 0;
319 }
320
321 int
sshkey_xmss_serialize_pk_info(const struct sshkey * k,struct sshbuf * b,enum sshkey_serialize_rep opts)322 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
323 enum sshkey_serialize_rep opts)
324 {
325 struct ssh_xmss_state *state = k->xmss_state;
326 u_char have_info = 1;
327 u_int32_t idx;
328 int r;
329
330 if (state == NULL)
331 return SSH_ERR_INVALID_ARGUMENT;
332 if (opts != SSHKEY_SERIALIZE_INFO)
333 return 0;
334 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
335 if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
336 (r = sshbuf_put_u32(b, idx)) != 0 ||
337 (r = sshbuf_put_u32(b, state->maxidx)) != 0)
338 return r;
339 return 0;
340 }
341
342 int
sshkey_xmss_deserialize_pk_info(struct sshkey * k,struct sshbuf * b)343 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
344 {
345 struct ssh_xmss_state *state = k->xmss_state;
346 u_char have_info;
347 int r;
348
349 if (state == NULL)
350 return SSH_ERR_INVALID_ARGUMENT;
351 /* optional */
352 if (sshbuf_len(b) == 0)
353 return 0;
354 if ((r = sshbuf_get_u8(b, &have_info)) != 0)
355 return r;
356 if (have_info != 1)
357 return SSH_ERR_INVALID_ARGUMENT;
358 if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
359 (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
360 return r;
361 return 0;
362 }
363
364 int
sshkey_xmss_generate_private_key(struct sshkey * k,int bits)365 sshkey_xmss_generate_private_key(struct sshkey *k, int bits)
366 {
367 int r;
368 const char *name;
369
370 if (bits == 10) {
371 name = XMSS_SHA2_256_W16_H10_NAME;
372 } else if (bits == 16) {
373 name = XMSS_SHA2_256_W16_H16_NAME;
374 } else if (bits == 20) {
375 name = XMSS_SHA2_256_W16_H20_NAME;
376 } else {
377 name = XMSS_DEFAULT_NAME;
378 }
379 if ((r = sshkey_xmss_init(k, name)) != 0 ||
380 (r = sshkey_xmss_init_bds_state(k)) != 0 ||
381 (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
382 return r;
383 if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
384 (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
385 return SSH_ERR_ALLOC_FAIL;
386 }
387 xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
388 sshkey_xmss_params(k));
389 return 0;
390 }
391
392 int
sshkey_xmss_get_state_from_file(struct sshkey * k,const char * filename,int * have_file,int printerror)393 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
394 int *have_file, int printerror)
395 {
396 struct sshbuf *b = NULL, *enc = NULL;
397 int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
398 u_int32_t len;
399 unsigned char buf[4], *data = NULL;
400
401 *have_file = 0;
402 if ((fd = open(filename, O_RDONLY)) >= 0) {
403 *have_file = 1;
404 if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
405 PRINT("corrupt state file: %s", filename);
406 goto done;
407 }
408 len = PEEK_U32(buf);
409 if ((data = calloc(len, 1)) == NULL) {
410 ret = SSH_ERR_ALLOC_FAIL;
411 goto done;
412 }
413 if (atomicio(read, fd, data, len) != len) {
414 PRINT("cannot read blob: %s", filename);
415 goto done;
416 }
417 if ((enc = sshbuf_from(data, len)) == NULL) {
418 ret = SSH_ERR_ALLOC_FAIL;
419 goto done;
420 }
421 sshkey_xmss_free_bds(k);
422 if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
423 ret = r;
424 goto done;
425 }
426 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
427 ret = r;
428 goto done;
429 }
430 ret = 0;
431 }
432 done:
433 if (fd != -1)
434 close(fd);
435 free(data);
436 sshbuf_free(enc);
437 sshbuf_free(b);
438 return ret;
439 }
440
441 int
sshkey_xmss_get_state(const struct sshkey * k,int printerror)442 sshkey_xmss_get_state(const struct sshkey *k, int printerror)
443 {
444 struct ssh_xmss_state *state = k->xmss_state;
445 u_int32_t idx = 0;
446 char *filename = NULL;
447 char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
448 int lockfd = -1, have_state = 0, have_ostate, tries = 0;
449 int ret = SSH_ERR_INVALID_ARGUMENT, r;
450
451 if (state == NULL)
452 goto done;
453 /*
454 * If maxidx is set, then we are allowed a limited number
455 * of signatures, but don't need to access the disk.
456 * Otherwise we need to deal with the on-disk state.
457 */
458 if (state->maxidx) {
459 /* xmss_sk always contains the current state */
460 idx = PEEK_U32(k->xmss_sk);
461 if (idx < state->maxidx) {
462 state->allow_update = 1;
463 return 0;
464 }
465 return SSH_ERR_INVALID_ARGUMENT;
466 }
467 if ((filename = k->xmss_filename) == NULL)
468 goto done;
469 if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
470 asprintf(&statefile, "%s.state", filename) == -1 ||
471 asprintf(&ostatefile, "%s.ostate", filename) == -1) {
472 ret = SSH_ERR_ALLOC_FAIL;
473 goto done;
474 }
475 if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
476 ret = SSH_ERR_SYSTEM_ERROR;
477 PRINT("cannot open/create: %s", lockfile);
478 goto done;
479 }
480 while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
481 if (errno != EWOULDBLOCK) {
482 ret = SSH_ERR_SYSTEM_ERROR;
483 PRINT("cannot lock: %s", lockfile);
484 goto done;
485 }
486 if (++tries > 10) {
487 ret = SSH_ERR_SYSTEM_ERROR;
488 PRINT("giving up on: %s", lockfile);
489 goto done;
490 }
491 usleep(1000*100*tries);
492 }
493 /* XXX no longer const */
494 if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
495 statefile, &have_state, printerror)) != 0) {
496 if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
497 ostatefile, &have_ostate, printerror)) == 0) {
498 state->allow_update = 1;
499 r = sshkey_xmss_forward_state(k, 1);
500 state->idx = PEEK_U32(k->xmss_sk);
501 state->allow_update = 0;
502 }
503 }
504 if (!have_state && !have_ostate) {
505 /* check that bds state is initialized */
506 if (state->bds.auth == NULL)
507 goto done;
508 PRINT("start from scratch idx 0: %u", state->idx);
509 } else if (r != 0) {
510 ret = r;
511 goto done;
512 }
513 if (state->idx + 1 < state->idx) {
514 PRINT("state wrap: %u", state->idx);
515 goto done;
516 }
517 state->have_state = have_state;
518 state->lockfd = lockfd;
519 state->allow_update = 1;
520 lockfd = -1;
521 ret = 0;
522 done:
523 if (lockfd != -1)
524 close(lockfd);
525 free(lockfile);
526 free(statefile);
527 free(ostatefile);
528 return ret;
529 }
530
531 int
sshkey_xmss_forward_state(const struct sshkey * k,u_int32_t reserve)532 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
533 {
534 struct ssh_xmss_state *state = k->xmss_state;
535 u_char *sig = NULL;
536 size_t required_siglen;
537 unsigned long long smlen;
538 u_char data;
539 int ret, r;
540
541 if (state == NULL || !state->allow_update)
542 return SSH_ERR_INVALID_ARGUMENT;
543 if (reserve == 0)
544 return SSH_ERR_INVALID_ARGUMENT;
545 if (state->idx + reserve <= state->idx)
546 return SSH_ERR_INVALID_ARGUMENT;
547 if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
548 return r;
549 if ((sig = malloc(required_siglen)) == NULL)
550 return SSH_ERR_ALLOC_FAIL;
551 while (reserve-- > 0) {
552 state->idx = PEEK_U32(k->xmss_sk);
553 smlen = required_siglen;
554 if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
555 sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
556 r = SSH_ERR_INVALID_ARGUMENT;
557 break;
558 }
559 }
560 free(sig);
561 return r;
562 }
563
564 int
sshkey_xmss_update_state(const struct sshkey * k,int printerror)565 sshkey_xmss_update_state(const struct sshkey *k, int printerror)
566 {
567 struct ssh_xmss_state *state = k->xmss_state;
568 struct sshbuf *b = NULL, *enc = NULL;
569 u_int32_t idx = 0;
570 unsigned char buf[4];
571 char *filename = NULL;
572 char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
573 int fd = -1;
574 int ret = SSH_ERR_INVALID_ARGUMENT;
575
576 if (state == NULL || !state->allow_update)
577 return ret;
578 if (state->maxidx) {
579 /* no update since the number of signatures is limited */
580 ret = 0;
581 goto done;
582 }
583 idx = PEEK_U32(k->xmss_sk);
584 if (idx == state->idx) {
585 /* no signature happened, no need to update */
586 ret = 0;
587 goto done;
588 } else if (idx != state->idx + 1) {
589 PRINT("more than one signature happened: idx %u state %u",
590 idx, state->idx);
591 goto done;
592 }
593 state->idx = idx;
594 if ((filename = k->xmss_filename) == NULL)
595 goto done;
596 if (asprintf(&statefile, "%s.state", filename) == -1 ||
597 asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
598 asprintf(&nstatefile, "%s.nstate", filename) == -1) {
599 ret = SSH_ERR_ALLOC_FAIL;
600 goto done;
601 }
602 unlink(nstatefile);
603 if ((b = sshbuf_new()) == NULL) {
604 ret = SSH_ERR_ALLOC_FAIL;
605 goto done;
606 }
607 if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
608 PRINT("SERLIALIZE FAILED: %d", ret);
609 goto done;
610 }
611 if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
612 PRINT("ENCRYPT FAILED: %d", ret);
613 goto done;
614 }
615 if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
616 ret = SSH_ERR_SYSTEM_ERROR;
617 PRINT("open new state file: %s", nstatefile);
618 goto done;
619 }
620 POKE_U32(buf, sshbuf_len(enc));
621 if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
622 ret = SSH_ERR_SYSTEM_ERROR;
623 PRINT("write new state file hdr: %s", nstatefile);
624 close(fd);
625 goto done;
626 }
627 if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
628 sshbuf_len(enc)) {
629 ret = SSH_ERR_SYSTEM_ERROR;
630 PRINT("write new state file data: %s", nstatefile);
631 close(fd);
632 goto done;
633 }
634 if (fsync(fd) == -1) {
635 ret = SSH_ERR_SYSTEM_ERROR;
636 PRINT("sync new state file: %s", nstatefile);
637 close(fd);
638 goto done;
639 }
640 if (close(fd) == -1) {
641 ret = SSH_ERR_SYSTEM_ERROR;
642 PRINT("close new state file: %s", nstatefile);
643 goto done;
644 }
645 if (state->have_state) {
646 unlink(ostatefile);
647 if (link(statefile, ostatefile)) {
648 ret = SSH_ERR_SYSTEM_ERROR;
649 PRINT("backup state %s to %s", statefile, ostatefile);
650 goto done;
651 }
652 }
653 if (rename(nstatefile, statefile) == -1) {
654 ret = SSH_ERR_SYSTEM_ERROR;
655 PRINT("rename %s to %s", nstatefile, statefile);
656 goto done;
657 }
658 ret = 0;
659 done:
660 if (state->lockfd != -1) {
661 close(state->lockfd);
662 state->lockfd = -1;
663 }
664 if (nstatefile)
665 unlink(nstatefile);
666 free(statefile);
667 free(ostatefile);
668 free(nstatefile);
669 sshbuf_free(b);
670 sshbuf_free(enc);
671 return ret;
672 }
673
674 int
sshkey_xmss_serialize_state(const struct sshkey * k,struct sshbuf * b)675 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
676 {
677 struct ssh_xmss_state *state = k->xmss_state;
678 treehash_inst *th;
679 u_int32_t i, node;
680 int r;
681
682 if (state == NULL)
683 return SSH_ERR_INVALID_ARGUMENT;
684 if (state->stack == NULL)
685 return SSH_ERR_INVALID_ARGUMENT;
686 state->stackoffset = state->bds.stackoffset; /* copy back */
687 if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
688 (r = sshbuf_put_u32(b, state->idx)) != 0 ||
689 (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
690 (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
691 (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
692 (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
693 (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
694 (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
695 (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
696 (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
697 return r;
698 for (i = 0; i < num_treehash(state); i++) {
699 th = &state->treehash[i];
700 node = th->node - state->th_nodes;
701 if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
702 (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
703 (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
704 (r = sshbuf_put_u8(b, th->completed)) != 0 ||
705 (r = sshbuf_put_u32(b, node)) != 0)
706 return r;
707 }
708 return 0;
709 }
710
711 int
sshkey_xmss_serialize_state_opt(const struct sshkey * k,struct sshbuf * b,enum sshkey_serialize_rep opts)712 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
713 enum sshkey_serialize_rep opts)
714 {
715 struct ssh_xmss_state *state = k->xmss_state;
716 int r = SSH_ERR_INVALID_ARGUMENT;
717 u_char have_stack, have_filename, have_enc;
718
719 if (state == NULL)
720 return SSH_ERR_INVALID_ARGUMENT;
721 if ((r = sshbuf_put_u8(b, opts)) != 0)
722 return r;
723 switch (opts) {
724 case SSHKEY_SERIALIZE_STATE:
725 r = sshkey_xmss_serialize_state(k, b);
726 break;
727 case SSHKEY_SERIALIZE_FULL:
728 if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
729 return r;
730 r = sshkey_xmss_serialize_state(k, b);
731 break;
732 case SSHKEY_SERIALIZE_SHIELD:
733 /* all of stack/filename/enc are optional */
734 have_stack = state->stack != NULL;
735 if ((r = sshbuf_put_u8(b, have_stack)) != 0)
736 return r;
737 if (have_stack) {
738 state->idx = PEEK_U32(k->xmss_sk); /* update */
739 if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
740 return r;
741 }
742 have_filename = k->xmss_filename != NULL;
743 if ((r = sshbuf_put_u8(b, have_filename)) != 0)
744 return r;
745 if (have_filename &&
746 (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
747 return r;
748 have_enc = state->enc_keyiv != NULL;
749 if ((r = sshbuf_put_u8(b, have_enc)) != 0)
750 return r;
751 if (have_enc &&
752 (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
753 return r;
754 if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
755 (r = sshbuf_put_u8(b, state->allow_update)) != 0)
756 return r;
757 break;
758 case SSHKEY_SERIALIZE_DEFAULT:
759 r = 0;
760 break;
761 default:
762 r = SSH_ERR_INVALID_ARGUMENT;
763 break;
764 }
765 return r;
766 }
767
768 int
sshkey_xmss_deserialize_state(struct sshkey * k,struct sshbuf * b)769 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
770 {
771 struct ssh_xmss_state *state = k->xmss_state;
772 treehash_inst *th;
773 u_int32_t i, lh, node;
774 size_t ls, lsl, la, lk, ln, lr;
775 char *magic;
776 int r = SSH_ERR_INTERNAL_ERROR;
777
778 if (state == NULL)
779 return SSH_ERR_INVALID_ARGUMENT;
780 if (k->xmss_sk == NULL)
781 return SSH_ERR_INVALID_ARGUMENT;
782 if ((state->treehash = calloc(num_treehash(state),
783 sizeof(treehash_inst))) == NULL)
784 return SSH_ERR_ALLOC_FAIL;
785 if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
786 (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
787 (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
788 (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
789 (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
790 (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
791 (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
792 (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
793 (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
794 (r = sshbuf_get_u32(b, &lh)) != 0)
795 goto out;
796 if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
797 r = SSH_ERR_INVALID_ARGUMENT;
798 goto out;
799 }
800 /* XXX check stackoffset */
801 if (ls != num_stack(state) ||
802 lsl != num_stacklevels(state) ||
803 la != num_auth(state) ||
804 lk != num_keep(state) ||
805 ln != num_th_nodes(state) ||
806 lr != num_retain(state) ||
807 lh != num_treehash(state)) {
808 r = SSH_ERR_INVALID_ARGUMENT;
809 goto out;
810 }
811 for (i = 0; i < num_treehash(state); i++) {
812 th = &state->treehash[i];
813 if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
814 (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
815 (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
816 (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
817 (r = sshbuf_get_u32(b, &node)) != 0)
818 goto out;
819 if (node < num_th_nodes(state))
820 th->node = &state->th_nodes[node];
821 }
822 POKE_U32(k->xmss_sk, state->idx);
823 xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
824 state->stacklevels, state->auth, state->keep, state->treehash,
825 state->retain, 0);
826 /* success */
827 r = 0;
828 out:
829 free(magic);
830 return r;
831 }
832
833 int
sshkey_xmss_deserialize_state_opt(struct sshkey * k,struct sshbuf * b)834 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
835 {
836 struct ssh_xmss_state *state = k->xmss_state;
837 enum sshkey_serialize_rep opts;
838 u_char have_state, have_stack, have_filename, have_enc;
839 int r;
840
841 if ((r = sshbuf_get_u8(b, &have_state)) != 0)
842 return r;
843
844 opts = have_state;
845 switch (opts) {
846 case SSHKEY_SERIALIZE_DEFAULT:
847 r = 0;
848 break;
849 case SSHKEY_SERIALIZE_SHIELD:
850 if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
851 return r;
852 if (have_stack &&
853 (r = sshkey_xmss_deserialize_state(k, b)) != 0)
854 return r;
855 if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
856 return r;
857 if (have_filename &&
858 (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
859 return r;
860 if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
861 return r;
862 if (have_enc &&
863 (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
864 return r;
865 if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
866 (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
867 return r;
868 break;
869 case SSHKEY_SERIALIZE_STATE:
870 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
871 return r;
872 break;
873 case SSHKEY_SERIALIZE_FULL:
874 if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
875 (r = sshkey_xmss_deserialize_state(k, b)) != 0)
876 return r;
877 break;
878 default:
879 r = SSH_ERR_INVALID_FORMAT;
880 break;
881 }
882 return r;
883 }
884
885 int
sshkey_xmss_encrypt_state(const struct sshkey * k,struct sshbuf * b,struct sshbuf ** retp)886 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
887 struct sshbuf **retp)
888 {
889 struct ssh_xmss_state *state = k->xmss_state;
890 struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
891 struct sshcipher_ctx *ciphercontext = NULL;
892 const struct sshcipher *cipher;
893 u_char *cp, *key, *iv = NULL;
894 size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
895 int r = SSH_ERR_INTERNAL_ERROR;
896
897 if (retp != NULL)
898 *retp = NULL;
899 if (state == NULL ||
900 state->enc_keyiv == NULL ||
901 state->enc_ciphername == NULL)
902 return SSH_ERR_INTERNAL_ERROR;
903 if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
904 r = SSH_ERR_INTERNAL_ERROR;
905 goto out;
906 }
907 blocksize = cipher_blocksize(cipher);
908 keylen = cipher_keylen(cipher);
909 ivlen = cipher_ivlen(cipher);
910 authlen = cipher_authlen(cipher);
911 if (state->enc_keyiv_len != keylen + ivlen) {
912 r = SSH_ERR_INVALID_FORMAT;
913 goto out;
914 }
915 key = state->enc_keyiv;
916 if ((encrypted = sshbuf_new()) == NULL ||
917 (encoded = sshbuf_new()) == NULL ||
918 (padded = sshbuf_new()) == NULL ||
919 (iv = malloc(ivlen)) == NULL) {
920 r = SSH_ERR_ALLOC_FAIL;
921 goto out;
922 }
923
924 /* replace first 4 bytes of IV with index to ensure uniqueness */
925 memcpy(iv, key + keylen, ivlen);
926 POKE_U32(iv, state->idx);
927
928 if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
929 (r = sshbuf_put_u32(encoded, state->idx)) != 0)
930 goto out;
931
932 /* padded state will be encrypted */
933 if ((r = sshbuf_putb(padded, b)) != 0)
934 goto out;
935 i = 0;
936 while (sshbuf_len(padded) % blocksize) {
937 if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
938 goto out;
939 }
940 encrypted_len = sshbuf_len(padded);
941
942 /* header including the length of state is used as AAD */
943 if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
944 goto out;
945 aadlen = sshbuf_len(encoded);
946
947 /* concat header and state */
948 if ((r = sshbuf_putb(encoded, padded)) != 0)
949 goto out;
950
951 /* reserve space for encryption of encoded data plus auth tag */
952 /* encrypt at offset addlen */
953 if ((r = sshbuf_reserve(encrypted,
954 encrypted_len + aadlen + authlen, &cp)) != 0 ||
955 (r = cipher_init(&ciphercontext, cipher, key, keylen,
956 iv, ivlen, 1)) != 0 ||
957 (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
958 encrypted_len, aadlen, authlen)) != 0)
959 goto out;
960
961 /* success */
962 r = 0;
963 out:
964 if (retp != NULL) {
965 *retp = encrypted;
966 encrypted = NULL;
967 }
968 sshbuf_free(padded);
969 sshbuf_free(encoded);
970 sshbuf_free(encrypted);
971 cipher_free(ciphercontext);
972 free(iv);
973 return r;
974 }
975
976 int
sshkey_xmss_decrypt_state(const struct sshkey * k,struct sshbuf * encoded,struct sshbuf ** retp)977 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
978 struct sshbuf **retp)
979 {
980 struct ssh_xmss_state *state = k->xmss_state;
981 struct sshbuf *copy = NULL, *decrypted = NULL;
982 struct sshcipher_ctx *ciphercontext = NULL;
983 const struct sshcipher *cipher = NULL;
984 u_char *key, *iv = NULL, *dp;
985 size_t keylen, ivlen, authlen, aadlen;
986 u_int blocksize, encrypted_len, index;
987 int r = SSH_ERR_INTERNAL_ERROR;
988
989 if (retp != NULL)
990 *retp = NULL;
991 if (state == NULL ||
992 state->enc_keyiv == NULL ||
993 state->enc_ciphername == NULL)
994 return SSH_ERR_INTERNAL_ERROR;
995 if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
996 r = SSH_ERR_INVALID_FORMAT;
997 goto out;
998 }
999 blocksize = cipher_blocksize(cipher);
1000 keylen = cipher_keylen(cipher);
1001 ivlen = cipher_ivlen(cipher);
1002 authlen = cipher_authlen(cipher);
1003 if (state->enc_keyiv_len != keylen + ivlen) {
1004 r = SSH_ERR_INTERNAL_ERROR;
1005 goto out;
1006 }
1007 key = state->enc_keyiv;
1008
1009 if ((copy = sshbuf_fromb(encoded)) == NULL ||
1010 (decrypted = sshbuf_new()) == NULL ||
1011 (iv = malloc(ivlen)) == NULL) {
1012 r = SSH_ERR_ALLOC_FAIL;
1013 goto out;
1014 }
1015
1016 /* check magic */
1017 if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1018 memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1019 r = SSH_ERR_INVALID_FORMAT;
1020 goto out;
1021 }
1022 /* parse public portion */
1023 if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1024 (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1025 (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1026 goto out;
1027
1028 /* check size of encrypted key blob */
1029 if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1030 r = SSH_ERR_INVALID_FORMAT;
1031 goto out;
1032 }
1033 /* check that an appropriate amount of auth data is present */
1034 if (sshbuf_len(encoded) < authlen ||
1035 sshbuf_len(encoded) - authlen < encrypted_len) {
1036 r = SSH_ERR_INVALID_FORMAT;
1037 goto out;
1038 }
1039
1040 aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1041
1042 /* replace first 4 bytes of IV with index to ensure uniqueness */
1043 memcpy(iv, key + keylen, ivlen);
1044 POKE_U32(iv, index);
1045
1046 /* decrypt private state of key */
1047 if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1048 (r = cipher_init(&ciphercontext, cipher, key, keylen,
1049 iv, ivlen, 0)) != 0 ||
1050 (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1051 encrypted_len, aadlen, authlen)) != 0)
1052 goto out;
1053
1054 /* there should be no trailing data */
1055 if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1056 goto out;
1057 if (sshbuf_len(encoded) != 0) {
1058 r = SSH_ERR_INVALID_FORMAT;
1059 goto out;
1060 }
1061
1062 /* remove AAD */
1063 if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1064 goto out;
1065 /* XXX encrypted includes unchecked padding */
1066
1067 /* success */
1068 r = 0;
1069 if (retp != NULL) {
1070 *retp = decrypted;
1071 decrypted = NULL;
1072 }
1073 out:
1074 cipher_free(ciphercontext);
1075 sshbuf_free(copy);
1076 sshbuf_free(decrypted);
1077 free(iv);
1078 return r;
1079 }
1080
1081 u_int32_t
sshkey_xmss_signatures_left(const struct sshkey * k)1082 sshkey_xmss_signatures_left(const struct sshkey *k)
1083 {
1084 struct ssh_xmss_state *state = k->xmss_state;
1085 u_int32_t idx;
1086
1087 if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1088 state->maxidx) {
1089 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1090 if (idx < state->maxidx)
1091 return state->maxidx - idx;
1092 }
1093 return 0;
1094 }
1095
1096 int
sshkey_xmss_enable_maxsign(struct sshkey * k,u_int32_t maxsign)1097 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1098 {
1099 struct ssh_xmss_state *state = k->xmss_state;
1100
1101 if (sshkey_type_plain(k->type) != KEY_XMSS)
1102 return SSH_ERR_INVALID_ARGUMENT;
1103 if (maxsign == 0)
1104 return 0;
1105 if (state->idx + maxsign < state->idx)
1106 return SSH_ERR_INVALID_ARGUMENT;
1107 state->maxidx = state->idx + maxsign;
1108 return 0;
1109 }
1110