1 /* $OpenBSD: sshkey-xmss.c,v 1.11 2021/04/03 06:18:41 djm Exp $ */ 2 /* 3 * Copyright (c) 2017 Markus Friedl. All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions 7 * are met: 8 * 1. Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * 2. Redistributions in binary form must reproduce the above copyright 11 * notice, this list of conditions and the following disclaimer in the 12 * documentation and/or other materials provided with the distribution. 13 * 14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 15 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 16 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 17 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, 18 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 19 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 20 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 21 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 23 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 */ 25 26 #include <sys/types.h> 27 #include <sys/uio.h> 28 29 #include <stdio.h> 30 #include <string.h> 31 #include <unistd.h> 32 #include <fcntl.h> 33 #include <errno.h> 34 35 #include "ssh2.h" 36 #include "ssherr.h" 37 #include "sshbuf.h" 38 #include "cipher.h" 39 #include "sshkey.h" 40 #include "sshkey-xmss.h" 41 #include "atomicio.h" 42 #include "log.h" 43 44 #include "xmss_fast.h" 45 46 /* opaque internal XMSS state */ 47 #define XMSS_MAGIC "xmss-state-v1" 48 #define XMSS_CIPHERNAME "aes256-gcm@openssh.com" 49 struct ssh_xmss_state { 50 xmss_params params; 51 u_int32_t n, w, h, k; 52 53 bds_state bds; 54 u_char *stack; 55 u_int32_t stackoffset; 56 u_char *stacklevels; 57 u_char *auth; 58 u_char *keep; 59 u_char *th_nodes; 60 u_char *retain; 61 treehash_inst *treehash; 62 63 u_int32_t idx; /* state read from file */ 64 u_int32_t maxidx; /* restricted # of signatures */ 65 int have_state; /* .state file exists */ 66 int lockfd; /* locked in sshkey_xmss_get_state() */ 67 u_char allow_update; /* allow sshkey_xmss_update_state() */ 68 char *enc_ciphername;/* encrypt state with cipher */ 69 u_char *enc_keyiv; /* encrypt state with key */ 70 u_int32_t enc_keyiv_len; /* length of enc_keyiv */ 71 }; 72 73 int sshkey_xmss_init_bds_state(struct sshkey *); 74 int sshkey_xmss_init_enc_key(struct sshkey *, const char *); 75 void sshkey_xmss_free_bds(struct sshkey *); 76 int sshkey_xmss_get_state_from_file(struct sshkey *, const char *, 77 int *, int); 78 int sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *, 79 struct sshbuf **); 80 int sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *, 81 struct sshbuf **); 82 int sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *); 83 int sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *); 84 85 #define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \ 86 0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); } while (0) 87 88 int 89 sshkey_xmss_init(struct sshkey *key, const char *name) 90 { 91 struct ssh_xmss_state *state; 92 93 if (key->xmss_state != NULL) 94 return SSH_ERR_INVALID_FORMAT; 95 if (name == NULL) 96 return SSH_ERR_INVALID_FORMAT; 97 state = calloc(sizeof(struct ssh_xmss_state), 1); 98 if (state == NULL) 99 return SSH_ERR_ALLOC_FAIL; 100 if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) { 101 state->n = 32; 102 state->w = 16; 103 state->h = 10; 104 } else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) { 105 state->n = 32; 106 state->w = 16; 107 state->h = 16; 108 } else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) { 109 state->n = 32; 110 state->w = 16; 111 state->h = 20; 112 } else { 113 free(state); 114 return SSH_ERR_KEY_TYPE_UNKNOWN; 115 } 116 if ((key->xmss_name = strdup(name)) == NULL) { 117 free(state); 118 return SSH_ERR_ALLOC_FAIL; 119 } 120 state->k = 2; /* XXX hardcoded */ 121 state->lockfd = -1; 122 if (xmss_set_params(&state->params, state->n, state->h, state->w, 123 state->k) != 0) { 124 free(state); 125 return SSH_ERR_INVALID_FORMAT; 126 } 127 key->xmss_state = state; 128 return 0; 129 } 130 131 void 132 sshkey_xmss_free_state(struct sshkey *key) 133 { 134 struct ssh_xmss_state *state = key->xmss_state; 135 136 sshkey_xmss_free_bds(key); 137 if (state) { 138 if (state->enc_keyiv) { 139 explicit_bzero(state->enc_keyiv, state->enc_keyiv_len); 140 free(state->enc_keyiv); 141 } 142 free(state->enc_ciphername); 143 free(state); 144 } 145 key->xmss_state = NULL; 146 } 147 148 #define SSH_XMSS_K2_MAGIC "k=2" 149 #define num_stack(x) ((x->h+1)*(x->n)) 150 #define num_stacklevels(x) (x->h+1) 151 #define num_auth(x) ((x->h)*(x->n)) 152 #define num_keep(x) ((x->h >> 1)*(x->n)) 153 #define num_th_nodes(x) ((x->h - x->k)*(x->n)) 154 #define num_retain(x) (((1ULL << x->k) - x->k - 1) * (x->n)) 155 #define num_treehash(x) ((x->h) - (x->k)) 156 157 int 158 sshkey_xmss_init_bds_state(struct sshkey *key) 159 { 160 struct ssh_xmss_state *state = key->xmss_state; 161 u_int32_t i; 162 163 state->stackoffset = 0; 164 if ((state->stack = calloc(num_stack(state), 1)) == NULL || 165 (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL || 166 (state->auth = calloc(num_auth(state), 1)) == NULL || 167 (state->keep = calloc(num_keep(state), 1)) == NULL || 168 (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL || 169 (state->retain = calloc(num_retain(state), 1)) == NULL || 170 (state->treehash = calloc(num_treehash(state), 171 sizeof(treehash_inst))) == NULL) { 172 sshkey_xmss_free_bds(key); 173 return SSH_ERR_ALLOC_FAIL; 174 } 175 for (i = 0; i < state->h - state->k; i++) 176 state->treehash[i].node = &state->th_nodes[state->n*i]; 177 xmss_set_bds_state(&state->bds, state->stack, state->stackoffset, 178 state->stacklevels, state->auth, state->keep, state->treehash, 179 state->retain, 0); 180 return 0; 181 } 182 183 void 184 sshkey_xmss_free_bds(struct sshkey *key) 185 { 186 struct ssh_xmss_state *state = key->xmss_state; 187 188 if (state == NULL) 189 return; 190 free(state->stack); 191 free(state->stacklevels); 192 free(state->auth); 193 free(state->keep); 194 free(state->th_nodes); 195 free(state->retain); 196 free(state->treehash); 197 state->stack = NULL; 198 state->stacklevels = NULL; 199 state->auth = NULL; 200 state->keep = NULL; 201 state->th_nodes = NULL; 202 state->retain = NULL; 203 state->treehash = NULL; 204 } 205 206 void * 207 sshkey_xmss_params(const struct sshkey *key) 208 { 209 struct ssh_xmss_state *state = key->xmss_state; 210 211 if (state == NULL) 212 return NULL; 213 return &state->params; 214 } 215 216 void * 217 sshkey_xmss_bds_state(const struct sshkey *key) 218 { 219 struct ssh_xmss_state *state = key->xmss_state; 220 221 if (state == NULL) 222 return NULL; 223 return &state->bds; 224 } 225 226 int 227 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp) 228 { 229 struct ssh_xmss_state *state = key->xmss_state; 230 231 if (lenp == NULL) 232 return SSH_ERR_INVALID_ARGUMENT; 233 if (state == NULL) 234 return SSH_ERR_INVALID_FORMAT; 235 *lenp = 4 + state->n + 236 state->params.wots_par.keysize + 237 state->h * state->n; 238 return 0; 239 } 240 241 size_t 242 sshkey_xmss_pklen(const struct sshkey *key) 243 { 244 struct ssh_xmss_state *state = key->xmss_state; 245 246 if (state == NULL) 247 return 0; 248 return state->n * 2; 249 } 250 251 size_t 252 sshkey_xmss_sklen(const struct sshkey *key) 253 { 254 struct ssh_xmss_state *state = key->xmss_state; 255 256 if (state == NULL) 257 return 0; 258 return state->n * 4 + 4; 259 } 260 261 int 262 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername) 263 { 264 struct ssh_xmss_state *state = k->xmss_state; 265 const struct sshcipher *cipher; 266 size_t keylen = 0, ivlen = 0; 267 268 if (state == NULL) 269 return SSH_ERR_INVALID_ARGUMENT; 270 if ((cipher = cipher_by_name(ciphername)) == NULL) 271 return SSH_ERR_INTERNAL_ERROR; 272 if ((state->enc_ciphername = strdup(ciphername)) == NULL) 273 return SSH_ERR_ALLOC_FAIL; 274 keylen = cipher_keylen(cipher); 275 ivlen = cipher_ivlen(cipher); 276 state->enc_keyiv_len = keylen + ivlen; 277 if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) { 278 free(state->enc_ciphername); 279 state->enc_ciphername = NULL; 280 return SSH_ERR_ALLOC_FAIL; 281 } 282 arc4random_buf(state->enc_keyiv, state->enc_keyiv_len); 283 return 0; 284 } 285 286 int 287 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b) 288 { 289 struct ssh_xmss_state *state = k->xmss_state; 290 int r; 291 292 if (state == NULL || state->enc_keyiv == NULL || 293 state->enc_ciphername == NULL) 294 return SSH_ERR_INVALID_ARGUMENT; 295 if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 || 296 (r = sshbuf_put_string(b, state->enc_keyiv, 297 state->enc_keyiv_len)) != 0) 298 return r; 299 return 0; 300 } 301 302 int 303 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b) 304 { 305 struct ssh_xmss_state *state = k->xmss_state; 306 size_t len; 307 int r; 308 309 if (state == NULL) 310 return SSH_ERR_INVALID_ARGUMENT; 311 if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 || 312 (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0) 313 return r; 314 state->enc_keyiv_len = len; 315 return 0; 316 } 317 318 int 319 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b, 320 enum sshkey_serialize_rep opts) 321 { 322 struct ssh_xmss_state *state = k->xmss_state; 323 u_char have_info = 1; 324 u_int32_t idx; 325 int r; 326 327 if (state == NULL) 328 return SSH_ERR_INVALID_ARGUMENT; 329 if (opts != SSHKEY_SERIALIZE_INFO) 330 return 0; 331 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx; 332 if ((r = sshbuf_put_u8(b, have_info)) != 0 || 333 (r = sshbuf_put_u32(b, idx)) != 0 || 334 (r = sshbuf_put_u32(b, state->maxidx)) != 0) 335 return r; 336 return 0; 337 } 338 339 int 340 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b) 341 { 342 struct ssh_xmss_state *state = k->xmss_state; 343 u_char have_info; 344 int r; 345 346 if (state == NULL) 347 return SSH_ERR_INVALID_ARGUMENT; 348 /* optional */ 349 if (sshbuf_len(b) == 0) 350 return 0; 351 if ((r = sshbuf_get_u8(b, &have_info)) != 0) 352 return r; 353 if (have_info != 1) 354 return SSH_ERR_INVALID_ARGUMENT; 355 if ((r = sshbuf_get_u32(b, &state->idx)) != 0 || 356 (r = sshbuf_get_u32(b, &state->maxidx)) != 0) 357 return r; 358 return 0; 359 } 360 361 int 362 sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits) 363 { 364 int r; 365 const char *name; 366 367 if (bits == 10) { 368 name = XMSS_SHA2_256_W16_H10_NAME; 369 } else if (bits == 16) { 370 name = XMSS_SHA2_256_W16_H16_NAME; 371 } else if (bits == 20) { 372 name = XMSS_SHA2_256_W16_H20_NAME; 373 } else { 374 name = XMSS_DEFAULT_NAME; 375 } 376 if ((r = sshkey_xmss_init(k, name)) != 0 || 377 (r = sshkey_xmss_init_bds_state(k)) != 0 || 378 (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0) 379 return r; 380 if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL || 381 (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) { 382 return SSH_ERR_ALLOC_FAIL; 383 } 384 xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k), 385 sshkey_xmss_params(k)); 386 return 0; 387 } 388 389 int 390 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename, 391 int *have_file, int printerror) 392 { 393 struct sshbuf *b = NULL, *enc = NULL; 394 int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1; 395 u_int32_t len; 396 unsigned char buf[4], *data = NULL; 397 398 *have_file = 0; 399 if ((fd = open(filename, O_RDONLY)) >= 0) { 400 *have_file = 1; 401 if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) { 402 PRINT("corrupt state file: %s", filename); 403 goto done; 404 } 405 len = PEEK_U32(buf); 406 if ((data = calloc(len, 1)) == NULL) { 407 ret = SSH_ERR_ALLOC_FAIL; 408 goto done; 409 } 410 if (atomicio(read, fd, data, len) != len) { 411 PRINT("cannot read blob: %s", filename); 412 goto done; 413 } 414 if ((enc = sshbuf_from(data, len)) == NULL) { 415 ret = SSH_ERR_ALLOC_FAIL; 416 goto done; 417 } 418 sshkey_xmss_free_bds(k); 419 if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) { 420 ret = r; 421 goto done; 422 } 423 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) { 424 ret = r; 425 goto done; 426 } 427 ret = 0; 428 } 429 done: 430 if (fd != -1) 431 close(fd); 432 free(data); 433 sshbuf_free(enc); 434 sshbuf_free(b); 435 return ret; 436 } 437 438 int 439 sshkey_xmss_get_state(const struct sshkey *k, int printerror) 440 { 441 struct ssh_xmss_state *state = k->xmss_state; 442 u_int32_t idx = 0; 443 char *filename = NULL; 444 char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL; 445 int lockfd = -1, have_state = 0, have_ostate, tries = 0; 446 int ret = SSH_ERR_INVALID_ARGUMENT, r; 447 448 if (state == NULL) 449 goto done; 450 /* 451 * If maxidx is set, then we are allowed a limited number 452 * of signatures, but don't need to access the disk. 453 * Otherwise we need to deal with the on-disk state. 454 */ 455 if (state->maxidx) { 456 /* xmss_sk always contains the current state */ 457 idx = PEEK_U32(k->xmss_sk); 458 if (idx < state->maxidx) { 459 state->allow_update = 1; 460 return 0; 461 } 462 return SSH_ERR_INVALID_ARGUMENT; 463 } 464 if ((filename = k->xmss_filename) == NULL) 465 goto done; 466 if (asprintf(&lockfile, "%s.lock", filename) == -1 || 467 asprintf(&statefile, "%s.state", filename) == -1 || 468 asprintf(&ostatefile, "%s.ostate", filename) == -1) { 469 ret = SSH_ERR_ALLOC_FAIL; 470 goto done; 471 } 472 if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) { 473 ret = SSH_ERR_SYSTEM_ERROR; 474 PRINT("cannot open/create: %s", lockfile); 475 goto done; 476 } 477 while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) { 478 if (errno != EWOULDBLOCK) { 479 ret = SSH_ERR_SYSTEM_ERROR; 480 PRINT("cannot lock: %s", lockfile); 481 goto done; 482 } 483 if (++tries > 10) { 484 ret = SSH_ERR_SYSTEM_ERROR; 485 PRINT("giving up on: %s", lockfile); 486 goto done; 487 } 488 usleep(1000*100*tries); 489 } 490 /* XXX no longer const */ 491 if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k, 492 statefile, &have_state, printerror)) != 0) { 493 if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k, 494 ostatefile, &have_ostate, printerror)) == 0) { 495 state->allow_update = 1; 496 r = sshkey_xmss_forward_state(k, 1); 497 state->idx = PEEK_U32(k->xmss_sk); 498 state->allow_update = 0; 499 } 500 } 501 if (!have_state && !have_ostate) { 502 /* check that bds state is initialized */ 503 if (state->bds.auth == NULL) 504 goto done; 505 PRINT("start from scratch idx 0: %u", state->idx); 506 } else if (r != 0) { 507 ret = r; 508 goto done; 509 } 510 if (state->idx + 1 < state->idx) { 511 PRINT("state wrap: %u", state->idx); 512 goto done; 513 } 514 state->have_state = have_state; 515 state->lockfd = lockfd; 516 state->allow_update = 1; 517 lockfd = -1; 518 ret = 0; 519 done: 520 if (lockfd != -1) 521 close(lockfd); 522 free(lockfile); 523 free(statefile); 524 free(ostatefile); 525 return ret; 526 } 527 528 int 529 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve) 530 { 531 struct ssh_xmss_state *state = k->xmss_state; 532 u_char *sig = NULL; 533 size_t required_siglen; 534 unsigned long long smlen; 535 u_char data; 536 int ret, r; 537 538 if (state == NULL || !state->allow_update) 539 return SSH_ERR_INVALID_ARGUMENT; 540 if (reserve == 0) 541 return SSH_ERR_INVALID_ARGUMENT; 542 if (state->idx + reserve <= state->idx) 543 return SSH_ERR_INVALID_ARGUMENT; 544 if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0) 545 return r; 546 if ((sig = malloc(required_siglen)) == NULL) 547 return SSH_ERR_ALLOC_FAIL; 548 while (reserve-- > 0) { 549 state->idx = PEEK_U32(k->xmss_sk); 550 smlen = required_siglen; 551 if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k), 552 sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) { 553 r = SSH_ERR_INVALID_ARGUMENT; 554 break; 555 } 556 } 557 free(sig); 558 return r; 559 } 560 561 int 562 sshkey_xmss_update_state(const struct sshkey *k, int printerror) 563 { 564 struct ssh_xmss_state *state = k->xmss_state; 565 struct sshbuf *b = NULL, *enc = NULL; 566 u_int32_t idx = 0; 567 unsigned char buf[4]; 568 char *filename = NULL; 569 char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL; 570 int fd = -1; 571 int ret = SSH_ERR_INVALID_ARGUMENT; 572 573 if (state == NULL || !state->allow_update) 574 return ret; 575 if (state->maxidx) { 576 /* no update since the number of signatures is limited */ 577 ret = 0; 578 goto done; 579 } 580 idx = PEEK_U32(k->xmss_sk); 581 if (idx == state->idx) { 582 /* no signature happened, no need to update */ 583 ret = 0; 584 goto done; 585 } else if (idx != state->idx + 1) { 586 PRINT("more than one signature happened: idx %u state %u", 587 idx, state->idx); 588 goto done; 589 } 590 state->idx = idx; 591 if ((filename = k->xmss_filename) == NULL) 592 goto done; 593 if (asprintf(&statefile, "%s.state", filename) == -1 || 594 asprintf(&ostatefile, "%s.ostate", filename) == -1 || 595 asprintf(&nstatefile, "%s.nstate", filename) == -1) { 596 ret = SSH_ERR_ALLOC_FAIL; 597 goto done; 598 } 599 unlink(nstatefile); 600 if ((b = sshbuf_new()) == NULL) { 601 ret = SSH_ERR_ALLOC_FAIL; 602 goto done; 603 } 604 if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) { 605 PRINT("SERLIALIZE FAILED: %d", ret); 606 goto done; 607 } 608 if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) { 609 PRINT("ENCRYPT FAILED: %d", ret); 610 goto done; 611 } 612 if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) { 613 ret = SSH_ERR_SYSTEM_ERROR; 614 PRINT("open new state file: %s", nstatefile); 615 goto done; 616 } 617 POKE_U32(buf, sshbuf_len(enc)); 618 if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) { 619 ret = SSH_ERR_SYSTEM_ERROR; 620 PRINT("write new state file hdr: %s", nstatefile); 621 close(fd); 622 goto done; 623 } 624 if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) != 625 sshbuf_len(enc)) { 626 ret = SSH_ERR_SYSTEM_ERROR; 627 PRINT("write new state file data: %s", nstatefile); 628 close(fd); 629 goto done; 630 } 631 if (fsync(fd) == -1) { 632 ret = SSH_ERR_SYSTEM_ERROR; 633 PRINT("sync new state file: %s", nstatefile); 634 close(fd); 635 goto done; 636 } 637 if (close(fd) == -1) { 638 ret = SSH_ERR_SYSTEM_ERROR; 639 PRINT("close new state file: %s", nstatefile); 640 goto done; 641 } 642 if (state->have_state) { 643 unlink(ostatefile); 644 if (link(statefile, ostatefile)) { 645 ret = SSH_ERR_SYSTEM_ERROR; 646 PRINT("backup state %s to %s", statefile, ostatefile); 647 goto done; 648 } 649 } 650 if (rename(nstatefile, statefile) == -1) { 651 ret = SSH_ERR_SYSTEM_ERROR; 652 PRINT("rename %s to %s", nstatefile, statefile); 653 goto done; 654 } 655 ret = 0; 656 done: 657 if (state->lockfd != -1) { 658 close(state->lockfd); 659 state->lockfd = -1; 660 } 661 if (nstatefile) 662 unlink(nstatefile); 663 free(statefile); 664 free(ostatefile); 665 free(nstatefile); 666 sshbuf_free(b); 667 sshbuf_free(enc); 668 return ret; 669 } 670 671 int 672 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b) 673 { 674 struct ssh_xmss_state *state = k->xmss_state; 675 treehash_inst *th; 676 u_int32_t i, node; 677 int r; 678 679 if (state == NULL) 680 return SSH_ERR_INVALID_ARGUMENT; 681 if (state->stack == NULL) 682 return SSH_ERR_INVALID_ARGUMENT; 683 state->stackoffset = state->bds.stackoffset; /* copy back */ 684 if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 || 685 (r = sshbuf_put_u32(b, state->idx)) != 0 || 686 (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 || 687 (r = sshbuf_put_u32(b, state->stackoffset)) != 0 || 688 (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 || 689 (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 || 690 (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 || 691 (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 || 692 (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 || 693 (r = sshbuf_put_u32(b, num_treehash(state))) != 0) 694 return r; 695 for (i = 0; i < num_treehash(state); i++) { 696 th = &state->treehash[i]; 697 node = th->node - state->th_nodes; 698 if ((r = sshbuf_put_u32(b, th->h)) != 0 || 699 (r = sshbuf_put_u32(b, th->next_idx)) != 0 || 700 (r = sshbuf_put_u32(b, th->stackusage)) != 0 || 701 (r = sshbuf_put_u8(b, th->completed)) != 0 || 702 (r = sshbuf_put_u32(b, node)) != 0) 703 return r; 704 } 705 return 0; 706 } 707 708 int 709 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b, 710 enum sshkey_serialize_rep opts) 711 { 712 struct ssh_xmss_state *state = k->xmss_state; 713 int r = SSH_ERR_INVALID_ARGUMENT; 714 u_char have_stack, have_filename, have_enc; 715 716 if (state == NULL) 717 return SSH_ERR_INVALID_ARGUMENT; 718 if ((r = sshbuf_put_u8(b, opts)) != 0) 719 return r; 720 switch (opts) { 721 case SSHKEY_SERIALIZE_STATE: 722 r = sshkey_xmss_serialize_state(k, b); 723 break; 724 case SSHKEY_SERIALIZE_FULL: 725 if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0) 726 return r; 727 r = sshkey_xmss_serialize_state(k, b); 728 break; 729 case SSHKEY_SERIALIZE_SHIELD: 730 /* all of stack/filename/enc are optional */ 731 have_stack = state->stack != NULL; 732 if ((r = sshbuf_put_u8(b, have_stack)) != 0) 733 return r; 734 if (have_stack) { 735 state->idx = PEEK_U32(k->xmss_sk); /* update */ 736 if ((r = sshkey_xmss_serialize_state(k, b)) != 0) 737 return r; 738 } 739 have_filename = k->xmss_filename != NULL; 740 if ((r = sshbuf_put_u8(b, have_filename)) != 0) 741 return r; 742 if (have_filename && 743 (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0) 744 return r; 745 have_enc = state->enc_keyiv != NULL; 746 if ((r = sshbuf_put_u8(b, have_enc)) != 0) 747 return r; 748 if (have_enc && 749 (r = sshkey_xmss_serialize_enc_key(k, b)) != 0) 750 return r; 751 if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 || 752 (r = sshbuf_put_u8(b, state->allow_update)) != 0) 753 return r; 754 break; 755 case SSHKEY_SERIALIZE_DEFAULT: 756 r = 0; 757 break; 758 default: 759 r = SSH_ERR_INVALID_ARGUMENT; 760 break; 761 } 762 return r; 763 } 764 765 int 766 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b) 767 { 768 struct ssh_xmss_state *state = k->xmss_state; 769 treehash_inst *th; 770 u_int32_t i, lh, node; 771 size_t ls, lsl, la, lk, ln, lr; 772 char *magic; 773 int r = SSH_ERR_INTERNAL_ERROR; 774 775 if (state == NULL) 776 return SSH_ERR_INVALID_ARGUMENT; 777 if (k->xmss_sk == NULL) 778 return SSH_ERR_INVALID_ARGUMENT; 779 if ((state->treehash = calloc(num_treehash(state), 780 sizeof(treehash_inst))) == NULL) 781 return SSH_ERR_ALLOC_FAIL; 782 if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 || 783 (r = sshbuf_get_u32(b, &state->idx)) != 0 || 784 (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 || 785 (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 || 786 (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 || 787 (r = sshbuf_get_string(b, &state->auth, &la)) != 0 || 788 (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 || 789 (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 || 790 (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 || 791 (r = sshbuf_get_u32(b, &lh)) != 0) 792 goto out; 793 if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) { 794 r = SSH_ERR_INVALID_ARGUMENT; 795 goto out; 796 } 797 /* XXX check stackoffset */ 798 if (ls != num_stack(state) || 799 lsl != num_stacklevels(state) || 800 la != num_auth(state) || 801 lk != num_keep(state) || 802 ln != num_th_nodes(state) || 803 lr != num_retain(state) || 804 lh != num_treehash(state)) { 805 r = SSH_ERR_INVALID_ARGUMENT; 806 goto out; 807 } 808 for (i = 0; i < num_treehash(state); i++) { 809 th = &state->treehash[i]; 810 if ((r = sshbuf_get_u32(b, &th->h)) != 0 || 811 (r = sshbuf_get_u32(b, &th->next_idx)) != 0 || 812 (r = sshbuf_get_u32(b, &th->stackusage)) != 0 || 813 (r = sshbuf_get_u8(b, &th->completed)) != 0 || 814 (r = sshbuf_get_u32(b, &node)) != 0) 815 goto out; 816 if (node < num_th_nodes(state)) 817 th->node = &state->th_nodes[node]; 818 } 819 POKE_U32(k->xmss_sk, state->idx); 820 xmss_set_bds_state(&state->bds, state->stack, state->stackoffset, 821 state->stacklevels, state->auth, state->keep, state->treehash, 822 state->retain, 0); 823 /* success */ 824 r = 0; 825 out: 826 free(magic); 827 return r; 828 } 829 830 int 831 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b) 832 { 833 struct ssh_xmss_state *state = k->xmss_state; 834 enum sshkey_serialize_rep opts; 835 u_char have_state, have_stack, have_filename, have_enc; 836 int r; 837 838 if ((r = sshbuf_get_u8(b, &have_state)) != 0) 839 return r; 840 841 opts = have_state; 842 switch (opts) { 843 case SSHKEY_SERIALIZE_DEFAULT: 844 r = 0; 845 break; 846 case SSHKEY_SERIALIZE_SHIELD: 847 if ((r = sshbuf_get_u8(b, &have_stack)) != 0) 848 return r; 849 if (have_stack && 850 (r = sshkey_xmss_deserialize_state(k, b)) != 0) 851 return r; 852 if ((r = sshbuf_get_u8(b, &have_filename)) != 0) 853 return r; 854 if (have_filename && 855 (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0) 856 return r; 857 if ((r = sshbuf_get_u8(b, &have_enc)) != 0) 858 return r; 859 if (have_enc && 860 (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0) 861 return r; 862 if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 || 863 (r = sshbuf_get_u8(b, &state->allow_update)) != 0) 864 return r; 865 break; 866 case SSHKEY_SERIALIZE_STATE: 867 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) 868 return r; 869 break; 870 case SSHKEY_SERIALIZE_FULL: 871 if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 || 872 (r = sshkey_xmss_deserialize_state(k, b)) != 0) 873 return r; 874 break; 875 default: 876 r = SSH_ERR_INVALID_FORMAT; 877 break; 878 } 879 return r; 880 } 881 882 int 883 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b, 884 struct sshbuf **retp) 885 { 886 struct ssh_xmss_state *state = k->xmss_state; 887 struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL; 888 struct sshcipher_ctx *ciphercontext = NULL; 889 const struct sshcipher *cipher; 890 u_char *cp, *key, *iv = NULL; 891 size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen; 892 int r = SSH_ERR_INTERNAL_ERROR; 893 894 if (retp != NULL) 895 *retp = NULL; 896 if (state == NULL || 897 state->enc_keyiv == NULL || 898 state->enc_ciphername == NULL) 899 return SSH_ERR_INTERNAL_ERROR; 900 if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) { 901 r = SSH_ERR_INTERNAL_ERROR; 902 goto out; 903 } 904 blocksize = cipher_blocksize(cipher); 905 keylen = cipher_keylen(cipher); 906 ivlen = cipher_ivlen(cipher); 907 authlen = cipher_authlen(cipher); 908 if (state->enc_keyiv_len != keylen + ivlen) { 909 r = SSH_ERR_INVALID_FORMAT; 910 goto out; 911 } 912 key = state->enc_keyiv; 913 if ((encrypted = sshbuf_new()) == NULL || 914 (encoded = sshbuf_new()) == NULL || 915 (padded = sshbuf_new()) == NULL || 916 (iv = malloc(ivlen)) == NULL) { 917 r = SSH_ERR_ALLOC_FAIL; 918 goto out; 919 } 920 921 /* replace first 4 bytes of IV with index to ensure uniqueness */ 922 memcpy(iv, key + keylen, ivlen); 923 POKE_U32(iv, state->idx); 924 925 if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 || 926 (r = sshbuf_put_u32(encoded, state->idx)) != 0) 927 goto out; 928 929 /* padded state will be encrypted */ 930 if ((r = sshbuf_putb(padded, b)) != 0) 931 goto out; 932 i = 0; 933 while (sshbuf_len(padded) % blocksize) { 934 if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0) 935 goto out; 936 } 937 encrypted_len = sshbuf_len(padded); 938 939 /* header including the length of state is used as AAD */ 940 if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0) 941 goto out; 942 aadlen = sshbuf_len(encoded); 943 944 /* concat header and state */ 945 if ((r = sshbuf_putb(encoded, padded)) != 0) 946 goto out; 947 948 /* reserve space for encryption of encoded data plus auth tag */ 949 /* encrypt at offset addlen */ 950 if ((r = sshbuf_reserve(encrypted, 951 encrypted_len + aadlen + authlen, &cp)) != 0 || 952 (r = cipher_init(&ciphercontext, cipher, key, keylen, 953 iv, ivlen, 1)) != 0 || 954 (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded), 955 encrypted_len, aadlen, authlen)) != 0) 956 goto out; 957 958 /* success */ 959 r = 0; 960 out: 961 if (retp != NULL) { 962 *retp = encrypted; 963 encrypted = NULL; 964 } 965 sshbuf_free(padded); 966 sshbuf_free(encoded); 967 sshbuf_free(encrypted); 968 cipher_free(ciphercontext); 969 free(iv); 970 return r; 971 } 972 973 int 974 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded, 975 struct sshbuf **retp) 976 { 977 struct ssh_xmss_state *state = k->xmss_state; 978 struct sshbuf *copy = NULL, *decrypted = NULL; 979 struct sshcipher_ctx *ciphercontext = NULL; 980 const struct sshcipher *cipher = NULL; 981 u_char *key, *iv = NULL, *dp; 982 size_t keylen, ivlen, authlen, aadlen; 983 u_int blocksize, encrypted_len, index; 984 int r = SSH_ERR_INTERNAL_ERROR; 985 986 if (retp != NULL) 987 *retp = NULL; 988 if (state == NULL || 989 state->enc_keyiv == NULL || 990 state->enc_ciphername == NULL) 991 return SSH_ERR_INTERNAL_ERROR; 992 if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) { 993 r = SSH_ERR_INVALID_FORMAT; 994 goto out; 995 } 996 blocksize = cipher_blocksize(cipher); 997 keylen = cipher_keylen(cipher); 998 ivlen = cipher_ivlen(cipher); 999 authlen = cipher_authlen(cipher); 1000 if (state->enc_keyiv_len != keylen + ivlen) { 1001 r = SSH_ERR_INTERNAL_ERROR; 1002 goto out; 1003 } 1004 key = state->enc_keyiv; 1005 1006 if ((copy = sshbuf_fromb(encoded)) == NULL || 1007 (decrypted = sshbuf_new()) == NULL || 1008 (iv = malloc(ivlen)) == NULL) { 1009 r = SSH_ERR_ALLOC_FAIL; 1010 goto out; 1011 } 1012 1013 /* check magic */ 1014 if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) || 1015 memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) { 1016 r = SSH_ERR_INVALID_FORMAT; 1017 goto out; 1018 } 1019 /* parse public portion */ 1020 if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 || 1021 (r = sshbuf_get_u32(encoded, &index)) != 0 || 1022 (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0) 1023 goto out; 1024 1025 /* check size of encrypted key blob */ 1026 if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) { 1027 r = SSH_ERR_INVALID_FORMAT; 1028 goto out; 1029 } 1030 /* check that an appropriate amount of auth data is present */ 1031 if (sshbuf_len(encoded) < authlen || 1032 sshbuf_len(encoded) - authlen < encrypted_len) { 1033 r = SSH_ERR_INVALID_FORMAT; 1034 goto out; 1035 } 1036 1037 aadlen = sshbuf_len(copy) - sshbuf_len(encoded); 1038 1039 /* replace first 4 bytes of IV with index to ensure uniqueness */ 1040 memcpy(iv, key + keylen, ivlen); 1041 POKE_U32(iv, index); 1042 1043 /* decrypt private state of key */ 1044 if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 || 1045 (r = cipher_init(&ciphercontext, cipher, key, keylen, 1046 iv, ivlen, 0)) != 0 || 1047 (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy), 1048 encrypted_len, aadlen, authlen)) != 0) 1049 goto out; 1050 1051 /* there should be no trailing data */ 1052 if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0) 1053 goto out; 1054 if (sshbuf_len(encoded) != 0) { 1055 r = SSH_ERR_INVALID_FORMAT; 1056 goto out; 1057 } 1058 1059 /* remove AAD */ 1060 if ((r = sshbuf_consume(decrypted, aadlen)) != 0) 1061 goto out; 1062 /* XXX encrypted includes unchecked padding */ 1063 1064 /* success */ 1065 r = 0; 1066 if (retp != NULL) { 1067 *retp = decrypted; 1068 decrypted = NULL; 1069 } 1070 out: 1071 cipher_free(ciphercontext); 1072 sshbuf_free(copy); 1073 sshbuf_free(decrypted); 1074 free(iv); 1075 return r; 1076 } 1077 1078 u_int32_t 1079 sshkey_xmss_signatures_left(const struct sshkey *k) 1080 { 1081 struct ssh_xmss_state *state = k->xmss_state; 1082 u_int32_t idx; 1083 1084 if (sshkey_type_plain(k->type) == KEY_XMSS && state && 1085 state->maxidx) { 1086 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx; 1087 if (idx < state->maxidx) 1088 return state->maxidx - idx; 1089 } 1090 return 0; 1091 } 1092 1093 int 1094 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign) 1095 { 1096 struct ssh_xmss_state *state = k->xmss_state; 1097 1098 if (sshkey_type_plain(k->type) != KEY_XMSS) 1099 return SSH_ERR_INVALID_ARGUMENT; 1100 if (maxsign == 0) 1101 return 0; 1102 if (state->idx + maxsign < state->idx) 1103 return SSH_ERR_INVALID_ARGUMENT; 1104 state->maxidx = state->idx + maxsign; 1105 return 0; 1106 } 1107