1 /* $OpenBSD: sshkey-xmss.c,v 1.8 2019/11/13 07:53:10 markus Exp $ */ 2 /* 3 * Copyright (c) 2017 Markus Friedl. All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions 7 * are met: 8 * 1. Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * 2. Redistributions in binary form must reproduce the above copyright 11 * notice, this list of conditions and the following disclaimer in the 12 * documentation and/or other materials provided with the distribution. 13 * 14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 15 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 16 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 17 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, 18 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 19 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 20 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 21 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 23 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 */ 25 26 #include <sys/types.h> 27 #include <sys/uio.h> 28 29 #include <stdio.h> 30 #include <string.h> 31 #include <unistd.h> 32 #include <fcntl.h> 33 #include <errno.h> 34 35 #include "ssh2.h" 36 #include "ssherr.h" 37 #include "sshbuf.h" 38 #include "cipher.h" 39 #include "sshkey.h" 40 #include "sshkey-xmss.h" 41 #include "atomicio.h" 42 43 #include "xmss_fast.h" 44 45 /* opaque internal XMSS state */ 46 #define XMSS_MAGIC "xmss-state-v1" 47 #define XMSS_CIPHERNAME "aes256-gcm@openssh.com" 48 struct ssh_xmss_state { 49 xmss_params params; 50 u_int32_t n, w, h, k; 51 52 bds_state bds; 53 u_char *stack; 54 u_int32_t stackoffset; 55 u_char *stacklevels; 56 u_char *auth; 57 u_char *keep; 58 u_char *th_nodes; 59 u_char *retain; 60 treehash_inst *treehash; 61 62 u_int32_t idx; /* state read from file */ 63 u_int32_t maxidx; /* restricted # of signatures */ 64 int have_state; /* .state file exists */ 65 int lockfd; /* locked in sshkey_xmss_get_state() */ 66 u_char allow_update; /* allow sshkey_xmss_update_state() */ 67 char *enc_ciphername;/* encrypt state with cipher */ 68 u_char *enc_keyiv; /* encrypt state with key */ 69 u_int32_t enc_keyiv_len; /* length of enc_keyiv */ 70 }; 71 72 int sshkey_xmss_init_bds_state(struct sshkey *); 73 int sshkey_xmss_init_enc_key(struct sshkey *, const char *); 74 void sshkey_xmss_free_bds(struct sshkey *); 75 int sshkey_xmss_get_state_from_file(struct sshkey *, const char *, 76 int *, sshkey_printfn *); 77 int sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *, 78 struct sshbuf **); 79 int sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *, 80 struct sshbuf **); 81 int sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *); 82 int sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *); 83 84 #define PRINT(s...) do { if (pr) pr(s); } while (0) 85 86 int 87 sshkey_xmss_init(struct sshkey *key, const char *name) 88 { 89 struct ssh_xmss_state *state; 90 91 if (key->xmss_state != NULL) 92 return SSH_ERR_INVALID_FORMAT; 93 if (name == NULL) 94 return SSH_ERR_INVALID_FORMAT; 95 state = calloc(sizeof(struct ssh_xmss_state), 1); 96 if (state == NULL) 97 return SSH_ERR_ALLOC_FAIL; 98 if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) { 99 state->n = 32; 100 state->w = 16; 101 state->h = 10; 102 } else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) { 103 state->n = 32; 104 state->w = 16; 105 state->h = 16; 106 } else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) { 107 state->n = 32; 108 state->w = 16; 109 state->h = 20; 110 } else { 111 free(state); 112 return SSH_ERR_KEY_TYPE_UNKNOWN; 113 } 114 if ((key->xmss_name = strdup(name)) == NULL) { 115 free(state); 116 return SSH_ERR_ALLOC_FAIL; 117 } 118 state->k = 2; /* XXX hardcoded */ 119 state->lockfd = -1; 120 if (xmss_set_params(&state->params, state->n, state->h, state->w, 121 state->k) != 0) { 122 free(state); 123 return SSH_ERR_INVALID_FORMAT; 124 } 125 key->xmss_state = state; 126 return 0; 127 } 128 129 void 130 sshkey_xmss_free_state(struct sshkey *key) 131 { 132 struct ssh_xmss_state *state = key->xmss_state; 133 134 sshkey_xmss_free_bds(key); 135 if (state) { 136 if (state->enc_keyiv) { 137 explicit_bzero(state->enc_keyiv, state->enc_keyiv_len); 138 free(state->enc_keyiv); 139 } 140 free(state->enc_ciphername); 141 free(state); 142 } 143 key->xmss_state = NULL; 144 } 145 146 #define SSH_XMSS_K2_MAGIC "k=2" 147 #define num_stack(x) ((x->h+1)*(x->n)) 148 #define num_stacklevels(x) (x->h+1) 149 #define num_auth(x) ((x->h)*(x->n)) 150 #define num_keep(x) ((x->h >> 1)*(x->n)) 151 #define num_th_nodes(x) ((x->h - x->k)*(x->n)) 152 #define num_retain(x) (((1ULL << x->k) - x->k - 1) * (x->n)) 153 #define num_treehash(x) ((x->h) - (x->k)) 154 155 int 156 sshkey_xmss_init_bds_state(struct sshkey *key) 157 { 158 struct ssh_xmss_state *state = key->xmss_state; 159 u_int32_t i; 160 161 state->stackoffset = 0; 162 if ((state->stack = calloc(num_stack(state), 1)) == NULL || 163 (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL || 164 (state->auth = calloc(num_auth(state), 1)) == NULL || 165 (state->keep = calloc(num_keep(state), 1)) == NULL || 166 (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL || 167 (state->retain = calloc(num_retain(state), 1)) == NULL || 168 (state->treehash = calloc(num_treehash(state), 169 sizeof(treehash_inst))) == NULL) { 170 sshkey_xmss_free_bds(key); 171 return SSH_ERR_ALLOC_FAIL; 172 } 173 for (i = 0; i < state->h - state->k; i++) 174 state->treehash[i].node = &state->th_nodes[state->n*i]; 175 xmss_set_bds_state(&state->bds, state->stack, state->stackoffset, 176 state->stacklevels, state->auth, state->keep, state->treehash, 177 state->retain, 0); 178 return 0; 179 } 180 181 void 182 sshkey_xmss_free_bds(struct sshkey *key) 183 { 184 struct ssh_xmss_state *state = key->xmss_state; 185 186 if (state == NULL) 187 return; 188 free(state->stack); 189 free(state->stacklevels); 190 free(state->auth); 191 free(state->keep); 192 free(state->th_nodes); 193 free(state->retain); 194 free(state->treehash); 195 state->stack = NULL; 196 state->stacklevels = NULL; 197 state->auth = NULL; 198 state->keep = NULL; 199 state->th_nodes = NULL; 200 state->retain = NULL; 201 state->treehash = NULL; 202 } 203 204 void * 205 sshkey_xmss_params(const struct sshkey *key) 206 { 207 struct ssh_xmss_state *state = key->xmss_state; 208 209 if (state == NULL) 210 return NULL; 211 return &state->params; 212 } 213 214 void * 215 sshkey_xmss_bds_state(const struct sshkey *key) 216 { 217 struct ssh_xmss_state *state = key->xmss_state; 218 219 if (state == NULL) 220 return NULL; 221 return &state->bds; 222 } 223 224 int 225 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp) 226 { 227 struct ssh_xmss_state *state = key->xmss_state; 228 229 if (lenp == NULL) 230 return SSH_ERR_INVALID_ARGUMENT; 231 if (state == NULL) 232 return SSH_ERR_INVALID_FORMAT; 233 *lenp = 4 + state->n + 234 state->params.wots_par.keysize + 235 state->h * state->n; 236 return 0; 237 } 238 239 size_t 240 sshkey_xmss_pklen(const struct sshkey *key) 241 { 242 struct ssh_xmss_state *state = key->xmss_state; 243 244 if (state == NULL) 245 return 0; 246 return state->n * 2; 247 } 248 249 size_t 250 sshkey_xmss_sklen(const struct sshkey *key) 251 { 252 struct ssh_xmss_state *state = key->xmss_state; 253 254 if (state == NULL) 255 return 0; 256 return state->n * 4 + 4; 257 } 258 259 int 260 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername) 261 { 262 struct ssh_xmss_state *state = k->xmss_state; 263 const struct sshcipher *cipher; 264 size_t keylen = 0, ivlen = 0; 265 266 if (state == NULL) 267 return SSH_ERR_INVALID_ARGUMENT; 268 if ((cipher = cipher_by_name(ciphername)) == NULL) 269 return SSH_ERR_INTERNAL_ERROR; 270 if ((state->enc_ciphername = strdup(ciphername)) == NULL) 271 return SSH_ERR_ALLOC_FAIL; 272 keylen = cipher_keylen(cipher); 273 ivlen = cipher_ivlen(cipher); 274 state->enc_keyiv_len = keylen + ivlen; 275 if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) { 276 free(state->enc_ciphername); 277 state->enc_ciphername = NULL; 278 return SSH_ERR_ALLOC_FAIL; 279 } 280 arc4random_buf(state->enc_keyiv, state->enc_keyiv_len); 281 return 0; 282 } 283 284 int 285 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b) 286 { 287 struct ssh_xmss_state *state = k->xmss_state; 288 int r; 289 290 if (state == NULL || state->enc_keyiv == NULL || 291 state->enc_ciphername == NULL) 292 return SSH_ERR_INVALID_ARGUMENT; 293 if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 || 294 (r = sshbuf_put_string(b, state->enc_keyiv, 295 state->enc_keyiv_len)) != 0) 296 return r; 297 return 0; 298 } 299 300 int 301 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b) 302 { 303 struct ssh_xmss_state *state = k->xmss_state; 304 size_t len; 305 int r; 306 307 if (state == NULL) 308 return SSH_ERR_INVALID_ARGUMENT; 309 if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 || 310 (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0) 311 return r; 312 state->enc_keyiv_len = len; 313 return 0; 314 } 315 316 int 317 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b, 318 enum sshkey_serialize_rep opts) 319 { 320 struct ssh_xmss_state *state = k->xmss_state; 321 u_char have_info = 1; 322 u_int32_t idx; 323 int r; 324 325 if (state == NULL) 326 return SSH_ERR_INVALID_ARGUMENT; 327 if (opts != SSHKEY_SERIALIZE_INFO) 328 return 0; 329 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx; 330 if ((r = sshbuf_put_u8(b, have_info)) != 0 || 331 (r = sshbuf_put_u32(b, idx)) != 0 || 332 (r = sshbuf_put_u32(b, state->maxidx)) != 0) 333 return r; 334 return 0; 335 } 336 337 int 338 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b) 339 { 340 struct ssh_xmss_state *state = k->xmss_state; 341 u_char have_info; 342 int r; 343 344 if (state == NULL) 345 return SSH_ERR_INVALID_ARGUMENT; 346 /* optional */ 347 if (sshbuf_len(b) == 0) 348 return 0; 349 if ((r = sshbuf_get_u8(b, &have_info)) != 0) 350 return r; 351 if (have_info != 1) 352 return SSH_ERR_INVALID_ARGUMENT; 353 if ((r = sshbuf_get_u32(b, &state->idx)) != 0 || 354 (r = sshbuf_get_u32(b, &state->maxidx)) != 0) 355 return r; 356 return 0; 357 } 358 359 int 360 sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits) 361 { 362 int r; 363 const char *name; 364 365 if (bits == 10) { 366 name = XMSS_SHA2_256_W16_H10_NAME; 367 } else if (bits == 16) { 368 name = XMSS_SHA2_256_W16_H16_NAME; 369 } else if (bits == 20) { 370 name = XMSS_SHA2_256_W16_H20_NAME; 371 } else { 372 name = XMSS_DEFAULT_NAME; 373 } 374 if ((r = sshkey_xmss_init(k, name)) != 0 || 375 (r = sshkey_xmss_init_bds_state(k)) != 0 || 376 (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0) 377 return r; 378 if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL || 379 (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) { 380 return SSH_ERR_ALLOC_FAIL; 381 } 382 xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k), 383 sshkey_xmss_params(k)); 384 return 0; 385 } 386 387 int 388 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename, 389 int *have_file, sshkey_printfn *pr) 390 { 391 struct sshbuf *b = NULL, *enc = NULL; 392 int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1; 393 u_int32_t len; 394 unsigned char buf[4], *data = NULL; 395 396 *have_file = 0; 397 if ((fd = open(filename, O_RDONLY)) >= 0) { 398 *have_file = 1; 399 if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) { 400 PRINT("%s: corrupt state file: %s", __func__, filename); 401 goto done; 402 } 403 len = PEEK_U32(buf); 404 if ((data = calloc(len, 1)) == NULL) { 405 ret = SSH_ERR_ALLOC_FAIL; 406 goto done; 407 } 408 if (atomicio(read, fd, data, len) != len) { 409 PRINT("%s: cannot read blob: %s", __func__, filename); 410 goto done; 411 } 412 if ((enc = sshbuf_from(data, len)) == NULL) { 413 ret = SSH_ERR_ALLOC_FAIL; 414 goto done; 415 } 416 sshkey_xmss_free_bds(k); 417 if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) { 418 ret = r; 419 goto done; 420 } 421 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) { 422 ret = r; 423 goto done; 424 } 425 ret = 0; 426 } 427 done: 428 if (fd != -1) 429 close(fd); 430 free(data); 431 sshbuf_free(enc); 432 sshbuf_free(b); 433 return ret; 434 } 435 436 int 437 sshkey_xmss_get_state(const struct sshkey *k, sshkey_printfn *pr) 438 { 439 struct ssh_xmss_state *state = k->xmss_state; 440 u_int32_t idx = 0; 441 char *filename = NULL; 442 char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL; 443 int lockfd = -1, have_state = 0, have_ostate, tries = 0; 444 int ret = SSH_ERR_INVALID_ARGUMENT, r; 445 446 if (state == NULL) 447 goto done; 448 /* 449 * If maxidx is set, then we are allowed a limited number 450 * of signatures, but don't need to access the disk. 451 * Otherwise we need to deal with the on-disk state. 452 */ 453 if (state->maxidx) { 454 /* xmss_sk always contains the current state */ 455 idx = PEEK_U32(k->xmss_sk); 456 if (idx < state->maxidx) { 457 state->allow_update = 1; 458 return 0; 459 } 460 return SSH_ERR_INVALID_ARGUMENT; 461 } 462 if ((filename = k->xmss_filename) == NULL) 463 goto done; 464 if (asprintf(&lockfile, "%s.lock", filename) == -1 || 465 asprintf(&statefile, "%s.state", filename) == -1 || 466 asprintf(&ostatefile, "%s.ostate", filename) == -1) { 467 ret = SSH_ERR_ALLOC_FAIL; 468 goto done; 469 } 470 if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) { 471 ret = SSH_ERR_SYSTEM_ERROR; 472 PRINT("%s: cannot open/create: %s", __func__, lockfile); 473 goto done; 474 } 475 while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) { 476 if (errno != EWOULDBLOCK) { 477 ret = SSH_ERR_SYSTEM_ERROR; 478 PRINT("%s: cannot lock: %s", __func__, lockfile); 479 goto done; 480 } 481 if (++tries > 10) { 482 ret = SSH_ERR_SYSTEM_ERROR; 483 PRINT("%s: giving up on: %s", __func__, lockfile); 484 goto done; 485 } 486 usleep(1000*100*tries); 487 } 488 /* XXX no longer const */ 489 if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k, 490 statefile, &have_state, pr)) != 0) { 491 if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k, 492 ostatefile, &have_ostate, pr)) == 0) { 493 state->allow_update = 1; 494 r = sshkey_xmss_forward_state(k, 1); 495 state->idx = PEEK_U32(k->xmss_sk); 496 state->allow_update = 0; 497 } 498 } 499 if (!have_state && !have_ostate) { 500 /* check that bds state is initialized */ 501 if (state->bds.auth == NULL) 502 goto done; 503 PRINT("%s: start from scratch idx 0: %u", __func__, state->idx); 504 } else if (r != 0) { 505 ret = r; 506 goto done; 507 } 508 if (state->idx + 1 < state->idx) { 509 PRINT("%s: state wrap: %u", __func__, state->idx); 510 goto done; 511 } 512 state->have_state = have_state; 513 state->lockfd = lockfd; 514 state->allow_update = 1; 515 lockfd = -1; 516 ret = 0; 517 done: 518 if (lockfd != -1) 519 close(lockfd); 520 free(lockfile); 521 free(statefile); 522 free(ostatefile); 523 return ret; 524 } 525 526 int 527 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve) 528 { 529 struct ssh_xmss_state *state = k->xmss_state; 530 u_char *sig = NULL; 531 size_t required_siglen; 532 unsigned long long smlen; 533 u_char data; 534 int ret, r; 535 536 if (state == NULL || !state->allow_update) 537 return SSH_ERR_INVALID_ARGUMENT; 538 if (reserve == 0) 539 return SSH_ERR_INVALID_ARGUMENT; 540 if (state->idx + reserve <= state->idx) 541 return SSH_ERR_INVALID_ARGUMENT; 542 if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0) 543 return r; 544 if ((sig = malloc(required_siglen)) == NULL) 545 return SSH_ERR_ALLOC_FAIL; 546 while (reserve-- > 0) { 547 state->idx = PEEK_U32(k->xmss_sk); 548 smlen = required_siglen; 549 if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k), 550 sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) { 551 r = SSH_ERR_INVALID_ARGUMENT; 552 break; 553 } 554 } 555 free(sig); 556 return r; 557 } 558 559 int 560 sshkey_xmss_update_state(const struct sshkey *k, sshkey_printfn *pr) 561 { 562 struct ssh_xmss_state *state = k->xmss_state; 563 struct sshbuf *b = NULL, *enc = NULL; 564 u_int32_t idx = 0; 565 unsigned char buf[4]; 566 char *filename = NULL; 567 char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL; 568 int fd = -1; 569 int ret = SSH_ERR_INVALID_ARGUMENT; 570 571 if (state == NULL || !state->allow_update) 572 return ret; 573 if (state->maxidx) { 574 /* no update since the number of signatures is limited */ 575 ret = 0; 576 goto done; 577 } 578 idx = PEEK_U32(k->xmss_sk); 579 if (idx == state->idx) { 580 /* no signature happened, no need to update */ 581 ret = 0; 582 goto done; 583 } else if (idx != state->idx + 1) { 584 PRINT("%s: more than one signature happened: idx %u state %u", 585 __func__, idx, state->idx); 586 goto done; 587 } 588 state->idx = idx; 589 if ((filename = k->xmss_filename) == NULL) 590 goto done; 591 if (asprintf(&statefile, "%s.state", filename) == -1 || 592 asprintf(&ostatefile, "%s.ostate", filename) == -1 || 593 asprintf(&nstatefile, "%s.nstate", filename) == -1) { 594 ret = SSH_ERR_ALLOC_FAIL; 595 goto done; 596 } 597 unlink(nstatefile); 598 if ((b = sshbuf_new()) == NULL) { 599 ret = SSH_ERR_ALLOC_FAIL; 600 goto done; 601 } 602 if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) { 603 PRINT("%s: SERLIALIZE FAILED: %d", __func__, ret); 604 goto done; 605 } 606 if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) { 607 PRINT("%s: ENCRYPT FAILED: %d", __func__, ret); 608 goto done; 609 } 610 if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) { 611 ret = SSH_ERR_SYSTEM_ERROR; 612 PRINT("%s: open new state file: %s", __func__, nstatefile); 613 goto done; 614 } 615 POKE_U32(buf, sshbuf_len(enc)); 616 if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) { 617 ret = SSH_ERR_SYSTEM_ERROR; 618 PRINT("%s: write new state file hdr: %s", __func__, nstatefile); 619 close(fd); 620 goto done; 621 } 622 if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) != 623 sshbuf_len(enc)) { 624 ret = SSH_ERR_SYSTEM_ERROR; 625 PRINT("%s: write new state file data: %s", __func__, nstatefile); 626 close(fd); 627 goto done; 628 } 629 if (fsync(fd) == -1) { 630 ret = SSH_ERR_SYSTEM_ERROR; 631 PRINT("%s: sync new state file: %s", __func__, nstatefile); 632 close(fd); 633 goto done; 634 } 635 if (close(fd) == -1) { 636 ret = SSH_ERR_SYSTEM_ERROR; 637 PRINT("%s: close new state file: %s", __func__, nstatefile); 638 goto done; 639 } 640 if (state->have_state) { 641 unlink(ostatefile); 642 if (link(statefile, ostatefile)) { 643 ret = SSH_ERR_SYSTEM_ERROR; 644 PRINT("%s: backup state %s to %s", __func__, statefile, 645 ostatefile); 646 goto done; 647 } 648 } 649 if (rename(nstatefile, statefile) == -1) { 650 ret = SSH_ERR_SYSTEM_ERROR; 651 PRINT("%s: rename %s to %s", __func__, nstatefile, statefile); 652 goto done; 653 } 654 ret = 0; 655 done: 656 if (state->lockfd != -1) { 657 close(state->lockfd); 658 state->lockfd = -1; 659 } 660 if (nstatefile) 661 unlink(nstatefile); 662 free(statefile); 663 free(ostatefile); 664 free(nstatefile); 665 sshbuf_free(b); 666 sshbuf_free(enc); 667 return ret; 668 } 669 670 int 671 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b) 672 { 673 struct ssh_xmss_state *state = k->xmss_state; 674 treehash_inst *th; 675 u_int32_t i, node; 676 int r; 677 678 if (state == NULL) 679 return SSH_ERR_INVALID_ARGUMENT; 680 if (state->stack == NULL) 681 return SSH_ERR_INVALID_ARGUMENT; 682 state->stackoffset = state->bds.stackoffset; /* copy back */ 683 if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 || 684 (r = sshbuf_put_u32(b, state->idx)) != 0 || 685 (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 || 686 (r = sshbuf_put_u32(b, state->stackoffset)) != 0 || 687 (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 || 688 (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 || 689 (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 || 690 (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 || 691 (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 || 692 (r = sshbuf_put_u32(b, num_treehash(state))) != 0) 693 return r; 694 for (i = 0; i < num_treehash(state); i++) { 695 th = &state->treehash[i]; 696 node = th->node - state->th_nodes; 697 if ((r = sshbuf_put_u32(b, th->h)) != 0 || 698 (r = sshbuf_put_u32(b, th->next_idx)) != 0 || 699 (r = sshbuf_put_u32(b, th->stackusage)) != 0 || 700 (r = sshbuf_put_u8(b, th->completed)) != 0 || 701 (r = sshbuf_put_u32(b, node)) != 0) 702 return r; 703 } 704 return 0; 705 } 706 707 int 708 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b, 709 enum sshkey_serialize_rep opts) 710 { 711 struct ssh_xmss_state *state = k->xmss_state; 712 int r = SSH_ERR_INVALID_ARGUMENT; 713 u_char have_stack, have_filename, have_enc; 714 715 if (state == NULL) 716 return SSH_ERR_INVALID_ARGUMENT; 717 if ((r = sshbuf_put_u8(b, opts)) != 0) 718 return r; 719 switch (opts) { 720 case SSHKEY_SERIALIZE_STATE: 721 r = sshkey_xmss_serialize_state(k, b); 722 break; 723 case SSHKEY_SERIALIZE_FULL: 724 if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0) 725 return r; 726 r = sshkey_xmss_serialize_state(k, b); 727 break; 728 case SSHKEY_SERIALIZE_SHIELD: 729 /* all of stack/filename/enc are optional */ 730 have_stack = state->stack != NULL; 731 if ((r = sshbuf_put_u8(b, have_stack)) != 0) 732 return r; 733 if (have_stack) { 734 state->idx = PEEK_U32(k->xmss_sk); /* update */ 735 if ((r = sshkey_xmss_serialize_state(k, b)) != 0) 736 return r; 737 } 738 have_filename = k->xmss_filename != NULL; 739 if ((r = sshbuf_put_u8(b, have_filename)) != 0) 740 return r; 741 if (have_filename && 742 (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0) 743 return r; 744 have_enc = state->enc_keyiv != NULL; 745 if ((r = sshbuf_put_u8(b, have_enc)) != 0) 746 return r; 747 if (have_enc && 748 (r = sshkey_xmss_serialize_enc_key(k, b)) != 0) 749 return r; 750 if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 || 751 (r = sshbuf_put_u8(b, state->allow_update)) != 0) 752 return r; 753 break; 754 case SSHKEY_SERIALIZE_DEFAULT: 755 r = 0; 756 break; 757 default: 758 r = SSH_ERR_INVALID_ARGUMENT; 759 break; 760 } 761 return r; 762 } 763 764 int 765 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b) 766 { 767 struct ssh_xmss_state *state = k->xmss_state; 768 treehash_inst *th; 769 u_int32_t i, lh, node; 770 size_t ls, lsl, la, lk, ln, lr; 771 char *magic; 772 int r = SSH_ERR_INTERNAL_ERROR; 773 774 if (state == NULL) 775 return SSH_ERR_INVALID_ARGUMENT; 776 if (k->xmss_sk == NULL) 777 return SSH_ERR_INVALID_ARGUMENT; 778 if ((state->treehash = calloc(num_treehash(state), 779 sizeof(treehash_inst))) == NULL) 780 return SSH_ERR_ALLOC_FAIL; 781 if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 || 782 (r = sshbuf_get_u32(b, &state->idx)) != 0 || 783 (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 || 784 (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 || 785 (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 || 786 (r = sshbuf_get_string(b, &state->auth, &la)) != 0 || 787 (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 || 788 (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 || 789 (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 || 790 (r = sshbuf_get_u32(b, &lh)) != 0) 791 goto out; 792 if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) { 793 r = SSH_ERR_INVALID_ARGUMENT; 794 goto out; 795 } 796 /* XXX check stackoffset */ 797 if (ls != num_stack(state) || 798 lsl != num_stacklevels(state) || 799 la != num_auth(state) || 800 lk != num_keep(state) || 801 ln != num_th_nodes(state) || 802 lr != num_retain(state) || 803 lh != num_treehash(state)) { 804 r = SSH_ERR_INVALID_ARGUMENT; 805 goto out; 806 } 807 for (i = 0; i < num_treehash(state); i++) { 808 th = &state->treehash[i]; 809 if ((r = sshbuf_get_u32(b, &th->h)) != 0 || 810 (r = sshbuf_get_u32(b, &th->next_idx)) != 0 || 811 (r = sshbuf_get_u32(b, &th->stackusage)) != 0 || 812 (r = sshbuf_get_u8(b, &th->completed)) != 0 || 813 (r = sshbuf_get_u32(b, &node)) != 0) 814 goto out; 815 if (node < num_th_nodes(state)) 816 th->node = &state->th_nodes[node]; 817 } 818 POKE_U32(k->xmss_sk, state->idx); 819 xmss_set_bds_state(&state->bds, state->stack, state->stackoffset, 820 state->stacklevels, state->auth, state->keep, state->treehash, 821 state->retain, 0); 822 /* success */ 823 r = 0; 824 out: 825 free(magic); 826 return r; 827 } 828 829 int 830 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b) 831 { 832 struct ssh_xmss_state *state = k->xmss_state; 833 enum sshkey_serialize_rep opts; 834 u_char have_state, have_stack, have_filename, have_enc; 835 int r; 836 837 if ((r = sshbuf_get_u8(b, &have_state)) != 0) 838 return r; 839 840 opts = have_state; 841 switch (opts) { 842 case SSHKEY_SERIALIZE_DEFAULT: 843 r = 0; 844 break; 845 case SSHKEY_SERIALIZE_SHIELD: 846 if ((r = sshbuf_get_u8(b, &have_stack)) != 0) 847 return r; 848 if (have_stack && 849 (r = sshkey_xmss_deserialize_state(k, b)) != 0) 850 return r; 851 if ((r = sshbuf_get_u8(b, &have_filename)) != 0) 852 return r; 853 if (have_filename && 854 (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0) 855 return r; 856 if ((r = sshbuf_get_u8(b, &have_enc)) != 0) 857 return r; 858 if (have_enc && 859 (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0) 860 return r; 861 if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 || 862 (r = sshbuf_get_u8(b, &state->allow_update)) != 0) 863 return r; 864 break; 865 case SSHKEY_SERIALIZE_STATE: 866 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) 867 return r; 868 break; 869 case SSHKEY_SERIALIZE_FULL: 870 if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 || 871 (r = sshkey_xmss_deserialize_state(k, b)) != 0) 872 return r; 873 break; 874 default: 875 r = SSH_ERR_INVALID_FORMAT; 876 break; 877 } 878 return r; 879 } 880 881 int 882 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b, 883 struct sshbuf **retp) 884 { 885 struct ssh_xmss_state *state = k->xmss_state; 886 struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL; 887 struct sshcipher_ctx *ciphercontext = NULL; 888 const struct sshcipher *cipher; 889 u_char *cp, *key, *iv = NULL; 890 size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen; 891 int r = SSH_ERR_INTERNAL_ERROR; 892 893 if (retp != NULL) 894 *retp = NULL; 895 if (state == NULL || 896 state->enc_keyiv == NULL || 897 state->enc_ciphername == NULL) 898 return SSH_ERR_INTERNAL_ERROR; 899 if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) { 900 r = SSH_ERR_INTERNAL_ERROR; 901 goto out; 902 } 903 blocksize = cipher_blocksize(cipher); 904 keylen = cipher_keylen(cipher); 905 ivlen = cipher_ivlen(cipher); 906 authlen = cipher_authlen(cipher); 907 if (state->enc_keyiv_len != keylen + ivlen) { 908 r = SSH_ERR_INVALID_FORMAT; 909 goto out; 910 } 911 key = state->enc_keyiv; 912 if ((encrypted = sshbuf_new()) == NULL || 913 (encoded = sshbuf_new()) == NULL || 914 (padded = sshbuf_new()) == NULL || 915 (iv = malloc(ivlen)) == NULL) { 916 r = SSH_ERR_ALLOC_FAIL; 917 goto out; 918 } 919 920 /* replace first 4 bytes of IV with index to ensure uniqueness */ 921 memcpy(iv, key + keylen, ivlen); 922 POKE_U32(iv, state->idx); 923 924 if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 || 925 (r = sshbuf_put_u32(encoded, state->idx)) != 0) 926 goto out; 927 928 /* padded state will be encrypted */ 929 if ((r = sshbuf_putb(padded, b)) != 0) 930 goto out; 931 i = 0; 932 while (sshbuf_len(padded) % blocksize) { 933 if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0) 934 goto out; 935 } 936 encrypted_len = sshbuf_len(padded); 937 938 /* header including the length of state is used as AAD */ 939 if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0) 940 goto out; 941 aadlen = sshbuf_len(encoded); 942 943 /* concat header and state */ 944 if ((r = sshbuf_putb(encoded, padded)) != 0) 945 goto out; 946 947 /* reserve space for encryption of encoded data plus auth tag */ 948 /* encrypt at offset addlen */ 949 if ((r = sshbuf_reserve(encrypted, 950 encrypted_len + aadlen + authlen, &cp)) != 0 || 951 (r = cipher_init(&ciphercontext, cipher, key, keylen, 952 iv, ivlen, 1)) != 0 || 953 (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded), 954 encrypted_len, aadlen, authlen)) != 0) 955 goto out; 956 957 /* success */ 958 r = 0; 959 out: 960 if (retp != NULL) { 961 *retp = encrypted; 962 encrypted = NULL; 963 } 964 sshbuf_free(padded); 965 sshbuf_free(encoded); 966 sshbuf_free(encrypted); 967 cipher_free(ciphercontext); 968 free(iv); 969 return r; 970 } 971 972 int 973 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded, 974 struct sshbuf **retp) 975 { 976 struct ssh_xmss_state *state = k->xmss_state; 977 struct sshbuf *copy = NULL, *decrypted = NULL; 978 struct sshcipher_ctx *ciphercontext = NULL; 979 const struct sshcipher *cipher = NULL; 980 u_char *key, *iv = NULL, *dp; 981 size_t keylen, ivlen, authlen, aadlen; 982 u_int blocksize, encrypted_len, index; 983 int r = SSH_ERR_INTERNAL_ERROR; 984 985 if (retp != NULL) 986 *retp = NULL; 987 if (state == NULL || 988 state->enc_keyiv == NULL || 989 state->enc_ciphername == NULL) 990 return SSH_ERR_INTERNAL_ERROR; 991 if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) { 992 r = SSH_ERR_INVALID_FORMAT; 993 goto out; 994 } 995 blocksize = cipher_blocksize(cipher); 996 keylen = cipher_keylen(cipher); 997 ivlen = cipher_ivlen(cipher); 998 authlen = cipher_authlen(cipher); 999 if (state->enc_keyiv_len != keylen + ivlen) { 1000 r = SSH_ERR_INTERNAL_ERROR; 1001 goto out; 1002 } 1003 key = state->enc_keyiv; 1004 1005 if ((copy = sshbuf_fromb(encoded)) == NULL || 1006 (decrypted = sshbuf_new()) == NULL || 1007 (iv = malloc(ivlen)) == NULL) { 1008 r = SSH_ERR_ALLOC_FAIL; 1009 goto out; 1010 } 1011 1012 /* check magic */ 1013 if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) || 1014 memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) { 1015 r = SSH_ERR_INVALID_FORMAT; 1016 goto out; 1017 } 1018 /* parse public portion */ 1019 if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 || 1020 (r = sshbuf_get_u32(encoded, &index)) != 0 || 1021 (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0) 1022 goto out; 1023 1024 /* check size of encrypted key blob */ 1025 if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) { 1026 r = SSH_ERR_INVALID_FORMAT; 1027 goto out; 1028 } 1029 /* check that an appropriate amount of auth data is present */ 1030 if (sshbuf_len(encoded) < authlen || 1031 sshbuf_len(encoded) - authlen < encrypted_len) { 1032 r = SSH_ERR_INVALID_FORMAT; 1033 goto out; 1034 } 1035 1036 aadlen = sshbuf_len(copy) - sshbuf_len(encoded); 1037 1038 /* replace first 4 bytes of IV with index to ensure uniqueness */ 1039 memcpy(iv, key + keylen, ivlen); 1040 POKE_U32(iv, index); 1041 1042 /* decrypt private state of key */ 1043 if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 || 1044 (r = cipher_init(&ciphercontext, cipher, key, keylen, 1045 iv, ivlen, 0)) != 0 || 1046 (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy), 1047 encrypted_len, aadlen, authlen)) != 0) 1048 goto out; 1049 1050 /* there should be no trailing data */ 1051 if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0) 1052 goto out; 1053 if (sshbuf_len(encoded) != 0) { 1054 r = SSH_ERR_INVALID_FORMAT; 1055 goto out; 1056 } 1057 1058 /* remove AAD */ 1059 if ((r = sshbuf_consume(decrypted, aadlen)) != 0) 1060 goto out; 1061 /* XXX encrypted includes unchecked padding */ 1062 1063 /* success */ 1064 r = 0; 1065 if (retp != NULL) { 1066 *retp = decrypted; 1067 decrypted = NULL; 1068 } 1069 out: 1070 cipher_free(ciphercontext); 1071 sshbuf_free(copy); 1072 sshbuf_free(decrypted); 1073 free(iv); 1074 return r; 1075 } 1076 1077 u_int32_t 1078 sshkey_xmss_signatures_left(const struct sshkey *k) 1079 { 1080 struct ssh_xmss_state *state = k->xmss_state; 1081 u_int32_t idx; 1082 1083 if (sshkey_type_plain(k->type) == KEY_XMSS && state && 1084 state->maxidx) { 1085 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx; 1086 if (idx < state->maxidx) 1087 return state->maxidx - idx; 1088 } 1089 return 0; 1090 } 1091 1092 int 1093 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign) 1094 { 1095 struct ssh_xmss_state *state = k->xmss_state; 1096 1097 if (sshkey_type_plain(k->type) != KEY_XMSS) 1098 return SSH_ERR_INVALID_ARGUMENT; 1099 if (maxsign == 0) 1100 return 0; 1101 if (state->idx + maxsign < state->idx) 1102 return SSH_ERR_INVALID_ARGUMENT; 1103 state->maxidx = state->idx + maxsign; 1104 return 0; 1105 } 1106