xref: /openbsd/usr.bin/ssh/kex.c (revision d415bd75)
1 /* $OpenBSD: kex.c,v 1.182 2023/10/11 04:46:29 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 
27 #include <sys/types.h>
28 #include <errno.h>
29 #include <signal.h>
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <unistd.h>
34 #include <poll.h>
35 
36 #ifdef WITH_OPENSSL
37 #include <openssl/crypto.h>
38 #endif
39 
40 #include "ssh.h"
41 #include "ssh2.h"
42 #include "atomicio.h"
43 #include "version.h"
44 #include "packet.h"
45 #include "compat.h"
46 #include "cipher.h"
47 #include "sshkey.h"
48 #include "kex.h"
49 #include "log.h"
50 #include "mac.h"
51 #include "match.h"
52 #include "misc.h"
53 #include "dispatch.h"
54 #include "monitor.h"
55 #include "myproposal.h"
56 
57 #include "ssherr.h"
58 #include "sshbuf.h"
59 #include "digest.h"
60 #include "xmalloc.h"
61 
62 /* prototype */
63 static int kex_choose_conf(struct ssh *);
64 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
65 
66 static const char * const proposal_names[PROPOSAL_MAX] = {
67 	"KEX algorithms",
68 	"host key algorithms",
69 	"ciphers ctos",
70 	"ciphers stoc",
71 	"MACs ctos",
72 	"MACs stoc",
73 	"compression ctos",
74 	"compression stoc",
75 	"languages ctos",
76 	"languages stoc",
77 };
78 
79 struct kexalg {
80 	char *name;
81 	u_int type;
82 	int ec_nid;
83 	int hash_alg;
84 };
85 static const struct kexalg kexalgs[] = {
86 #ifdef WITH_OPENSSL
87 	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
88 	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
89 	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
90 	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
91 	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
92 	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
93 	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
94 	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
95 	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
96 	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
97 	    SSH_DIGEST_SHA384 },
98 	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
99 	    SSH_DIGEST_SHA512 },
100 #endif
101 	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
102 	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
103 	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
104 	    SSH_DIGEST_SHA512 },
105 	{ NULL, 0, -1, -1},
106 };
107 
108 char *
109 kex_alg_list(char sep)
110 {
111 	char *ret = NULL, *tmp;
112 	size_t nlen, rlen = 0;
113 	const struct kexalg *k;
114 
115 	for (k = kexalgs; k->name != NULL; k++) {
116 		if (ret != NULL)
117 			ret[rlen++] = sep;
118 		nlen = strlen(k->name);
119 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
120 			free(ret);
121 			return NULL;
122 		}
123 		ret = tmp;
124 		memcpy(ret + rlen, k->name, nlen + 1);
125 		rlen += nlen;
126 	}
127 	return ret;
128 }
129 
130 static const struct kexalg *
131 kex_alg_by_name(const char *name)
132 {
133 	const struct kexalg *k;
134 
135 	for (k = kexalgs; k->name != NULL; k++) {
136 		if (strcmp(k->name, name) == 0)
137 			return k;
138 	}
139 	return NULL;
140 }
141 
142 /* Validate KEX method name list */
143 int
144 kex_names_valid(const char *names)
145 {
146 	char *s, *cp, *p;
147 
148 	if (names == NULL || strcmp(names, "") == 0)
149 		return 0;
150 	if ((s = cp = strdup(names)) == NULL)
151 		return 0;
152 	for ((p = strsep(&cp, ",")); p && *p != '\0';
153 	    (p = strsep(&cp, ","))) {
154 		if (kex_alg_by_name(p) == NULL) {
155 			error("Unsupported KEX algorithm \"%.100s\"", p);
156 			free(s);
157 			return 0;
158 		}
159 	}
160 	debug3("kex names ok: [%s]", names);
161 	free(s);
162 	return 1;
163 }
164 
165 /*
166  * Concatenate algorithm names, avoiding duplicates in the process.
167  * Caller must free returned string.
168  */
169 char *
170 kex_names_cat(const char *a, const char *b)
171 {
172 	char *ret = NULL, *tmp = NULL, *cp, *p, *m;
173 	size_t len;
174 
175 	if (a == NULL || *a == '\0')
176 		return strdup(b);
177 	if (b == NULL || *b == '\0')
178 		return strdup(a);
179 	if (strlen(b) > 1024*1024)
180 		return NULL;
181 	len = strlen(a) + strlen(b) + 2;
182 	if ((tmp = cp = strdup(b)) == NULL ||
183 	    (ret = calloc(1, len)) == NULL) {
184 		free(tmp);
185 		return NULL;
186 	}
187 	strlcpy(ret, a, len);
188 	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
189 		if ((m = match_list(ret, p, NULL)) != NULL) {
190 			free(m);
191 			continue; /* Algorithm already present */
192 		}
193 		if (strlcat(ret, ",", len) >= len ||
194 		    strlcat(ret, p, len) >= len) {
195 			free(tmp);
196 			free(ret);
197 			return NULL; /* Shouldn't happen */
198 		}
199 	}
200 	free(tmp);
201 	return ret;
202 }
203 
204 /*
205  * Assemble a list of algorithms from a default list and a string from a
206  * configuration file. The user-provided string may begin with '+' to
207  * indicate that it should be appended to the default, '-' that the
208  * specified names should be removed, or '^' that they should be placed
209  * at the head.
210  */
211 int
212 kex_assemble_names(char **listp, const char *def, const char *all)
213 {
214 	char *cp, *tmp, *patterns;
215 	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
216 	int r = SSH_ERR_INTERNAL_ERROR;
217 
218 	if (listp == NULL || def == NULL || all == NULL)
219 		return SSH_ERR_INVALID_ARGUMENT;
220 
221 	if (*listp == NULL || **listp == '\0') {
222 		if ((*listp = strdup(def)) == NULL)
223 			return SSH_ERR_ALLOC_FAIL;
224 		return 0;
225 	}
226 
227 	list = *listp;
228 	*listp = NULL;
229 	if (*list == '+') {
230 		/* Append names to default list */
231 		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
232 			r = SSH_ERR_ALLOC_FAIL;
233 			goto fail;
234 		}
235 		free(list);
236 		list = tmp;
237 	} else if (*list == '-') {
238 		/* Remove names from default list */
239 		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
240 			r = SSH_ERR_ALLOC_FAIL;
241 			goto fail;
242 		}
243 		free(list);
244 		/* filtering has already been done */
245 		return 0;
246 	} else if (*list == '^') {
247 		/* Place names at head of default list */
248 		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
249 			r = SSH_ERR_ALLOC_FAIL;
250 			goto fail;
251 		}
252 		free(list);
253 		list = tmp;
254 	} else {
255 		/* Explicit list, overrides default - just use "list" as is */
256 	}
257 
258 	/*
259 	 * The supplied names may be a pattern-list. For the -list case,
260 	 * the patterns are applied above. For the +list and explicit list
261 	 * cases we need to do it now.
262 	 */
263 	ret = NULL;
264 	if ((patterns = opatterns = strdup(list)) == NULL) {
265 		r = SSH_ERR_ALLOC_FAIL;
266 		goto fail;
267 	}
268 	/* Apply positive (i.e. non-negated) patterns from the list */
269 	while ((cp = strsep(&patterns, ",")) != NULL) {
270 		if (*cp == '!') {
271 			/* negated matches are not supported here */
272 			r = SSH_ERR_INVALID_ARGUMENT;
273 			goto fail;
274 		}
275 		free(matching);
276 		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
277 			r = SSH_ERR_ALLOC_FAIL;
278 			goto fail;
279 		}
280 		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
281 			r = SSH_ERR_ALLOC_FAIL;
282 			goto fail;
283 		}
284 		free(ret);
285 		ret = tmp;
286 	}
287 	if (ret == NULL || *ret == '\0') {
288 		/* An empty name-list is an error */
289 		/* XXX better error code? */
290 		r = SSH_ERR_INVALID_ARGUMENT;
291 		goto fail;
292 	}
293 
294 	/* success */
295 	*listp = ret;
296 	ret = NULL;
297 	r = 0;
298 
299  fail:
300 	free(matching);
301 	free(opatterns);
302 	free(list);
303 	free(ret);
304 	return r;
305 }
306 
307 /*
308  * Fill out a proposal array with dynamically allocated values, which may
309  * be modified as required for compatibility reasons.
310  * Any of the options may be NULL, in which case the default is used.
311  * Array contents must be freed by calling kex_proposal_free_entries.
312  */
313 void
314 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
315     const char *kexalgos, const char *ciphers, const char *macs,
316     const char *comp, const char *hkalgs)
317 {
318 	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
319 	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
320 	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
321 	u_int i;
322 
323 	if (prop == NULL)
324 		fatal_f("proposal missing");
325 
326 	for (i = 0; i < PROPOSAL_MAX; i++) {
327 		switch(i) {
328 		case PROPOSAL_KEX_ALGS:
329 			prop[i] = compat_kex_proposal(ssh,
330 			    kexalgos ? kexalgos : defprop[i]);
331 			break;
332 		case PROPOSAL_ENC_ALGS_CTOS:
333 		case PROPOSAL_ENC_ALGS_STOC:
334 			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
335 			break;
336 		case PROPOSAL_MAC_ALGS_CTOS:
337 		case PROPOSAL_MAC_ALGS_STOC:
338 			prop[i]  = xstrdup(macs ? macs : defprop[i]);
339 			break;
340 		case PROPOSAL_COMP_ALGS_CTOS:
341 		case PROPOSAL_COMP_ALGS_STOC:
342 			prop[i] = xstrdup(comp ? comp : defprop[i]);
343 			break;
344 		case PROPOSAL_SERVER_HOST_KEY_ALGS:
345 			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
346 			break;
347 		default:
348 			prop[i] = xstrdup(defprop[i]);
349 		}
350 	}
351 }
352 
353 void
354 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
355 {
356 	u_int i;
357 
358 	for (i = 0; i < PROPOSAL_MAX; i++)
359 		free(prop[i]);
360 }
361 
362 /* put algorithm proposal into buffer */
363 int
364 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
365 {
366 	u_int i;
367 	int r;
368 
369 	sshbuf_reset(b);
370 
371 	/*
372 	 * add a dummy cookie, the cookie will be overwritten by
373 	 * kex_send_kexinit(), each time a kexinit is set
374 	 */
375 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
376 		if ((r = sshbuf_put_u8(b, 0)) != 0)
377 			return r;
378 	}
379 	for (i = 0; i < PROPOSAL_MAX; i++) {
380 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
381 			return r;
382 	}
383 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
384 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
385 		return r;
386 	return 0;
387 }
388 
389 /* parse buffer and return algorithm proposal */
390 int
391 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
392 {
393 	struct sshbuf *b = NULL;
394 	u_char v;
395 	u_int i;
396 	char **proposal = NULL;
397 	int r;
398 
399 	*propp = NULL;
400 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
401 		return SSH_ERR_ALLOC_FAIL;
402 	if ((b = sshbuf_fromb(raw)) == NULL) {
403 		r = SSH_ERR_ALLOC_FAIL;
404 		goto out;
405 	}
406 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
407 		error_fr(r, "consume cookie");
408 		goto out;
409 	}
410 	/* extract kex init proposal strings */
411 	for (i = 0; i < PROPOSAL_MAX; i++) {
412 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
413 			error_fr(r, "parse proposal %u", i);
414 			goto out;
415 		}
416 		debug2("%s: %s", proposal_names[i], proposal[i]);
417 	}
418 	/* first kex follows / reserved */
419 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
420 	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
421 		error_fr(r, "parse");
422 		goto out;
423 	}
424 	if (first_kex_follows != NULL)
425 		*first_kex_follows = v;
426 	debug2("first_kex_follows %d ", v);
427 	debug2("reserved %u ", i);
428 	r = 0;
429 	*propp = proposal;
430  out:
431 	if (r != 0 && proposal != NULL)
432 		kex_prop_free(proposal);
433 	sshbuf_free(b);
434 	return r;
435 }
436 
437 void
438 kex_prop_free(char **proposal)
439 {
440 	u_int i;
441 
442 	if (proposal == NULL)
443 		return;
444 	for (i = 0; i < PROPOSAL_MAX; i++)
445 		free(proposal[i]);
446 	free(proposal);
447 }
448 
449 int
450 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
451 {
452 	int r;
453 
454 	error("kex protocol error: type %d seq %u", type, seq);
455 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
456 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
457 	    (r = sshpkt_send(ssh)) != 0)
458 		return r;
459 	return 0;
460 }
461 
462 static void
463 kex_reset_dispatch(struct ssh *ssh)
464 {
465 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
466 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
467 }
468 
469 static int
470 kex_send_ext_info(struct ssh *ssh)
471 {
472 	int r;
473 	char *algs;
474 
475 	debug("Sending SSH2_MSG_EXT_INFO");
476 	if ((algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
477 		return SSH_ERR_ALLOC_FAIL;
478 	/* XXX filter algs list by allowed pubkey/hostbased types */
479 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
480 	    (r = sshpkt_put_u32(ssh, 3)) != 0 ||
481 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
482 	    (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
483 	    (r = sshpkt_put_cstring(ssh,
484 	    "publickey-hostbound@openssh.com")) != 0 ||
485 	    (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
486 	    (r = sshpkt_put_cstring(ssh, "ping@openssh.com")) != 0 ||
487 	    (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
488 	    (r = sshpkt_send(ssh)) != 0) {
489 		error_fr(r, "compose");
490 		goto out;
491 	}
492 	/* success */
493 	r = 0;
494  out:
495 	free(algs);
496 	return r;
497 }
498 
499 int
500 kex_send_newkeys(struct ssh *ssh)
501 {
502 	int r;
503 
504 	kex_reset_dispatch(ssh);
505 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
506 	    (r = sshpkt_send(ssh)) != 0)
507 		return r;
508 	debug("SSH2_MSG_NEWKEYS sent");
509 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
510 	if (ssh->kex->ext_info_c && (ssh->kex->flags & KEX_INITIAL) != 0)
511 		if ((r = kex_send_ext_info(ssh)) != 0)
512 			return r;
513 	debug("expecting SSH2_MSG_NEWKEYS");
514 	return 0;
515 }
516 
517 /* Check whether an ext_info value contains the expected version string */
518 static int
519 kex_ext_info_check_ver(struct kex *kex, const char *name,
520     const u_char *val, size_t len, const char *want_ver, u_int flag)
521 {
522 	if (memchr(val, '\0', len) != NULL) {
523 		error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
524 		return SSH_ERR_INVALID_FORMAT;
525 	}
526 	debug_f("%s=<%s>", name, val);
527 	if (strcmp(val, want_ver) == 0)
528 		kex->flags |= flag;
529 	else
530 		debug_f("unsupported version of %s extension", name);
531 	return 0;
532 }
533 
534 int
535 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
536 {
537 	struct kex *kex = ssh->kex;
538 	u_int32_t i, ninfo;
539 	char *name;
540 	u_char *val;
541 	size_t vlen;
542 	int r;
543 
544 	debug("SSH2_MSG_EXT_INFO received");
545 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
546 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
547 		return r;
548 	if (ninfo >= 1024) {
549 		error("SSH2_MSG_EXT_INFO with too many entries, expected "
550 		    "<=1024, received %u", ninfo);
551 		return SSH_ERR_INVALID_FORMAT;
552 	}
553 	for (i = 0; i < ninfo; i++) {
554 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
555 			return r;
556 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
557 			free(name);
558 			return r;
559 		}
560 		if (strcmp(name, "server-sig-algs") == 0) {
561 			/* Ensure no \0 lurking in value */
562 			if (memchr(val, '\0', vlen) != NULL) {
563 				error_f("nul byte in %s", name);
564 				free(name);
565 				free(val);
566 				return SSH_ERR_INVALID_FORMAT;
567 			}
568 			debug_f("%s=<%s>", name, val);
569 			kex->server_sig_algs = val;
570 			val = NULL;
571 		} else if (strcmp(name,
572 		    "publickey-hostbound@openssh.com") == 0) {
573 			if ((r = kex_ext_info_check_ver(kex, name, val, vlen,
574 			    "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
575 				free(name);
576 				free(val);
577 				return r;
578 			}
579 		} else if (strcmp(name, "ping@openssh.com") == 0) {
580 			if ((r = kex_ext_info_check_ver(kex, name, val, vlen,
581 			    "0", KEX_HAS_PING)) != 0) {
582 				free(name);
583 				free(val);
584 				return r;
585 			}
586 		} else
587 			debug_f("%s (unrecognised)", name);
588 		free(name);
589 		free(val);
590 	}
591 	return sshpkt_get_end(ssh);
592 }
593 
594 static int
595 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
596 {
597 	struct kex *kex = ssh->kex;
598 	int r;
599 
600 	debug("SSH2_MSG_NEWKEYS received");
601 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
602 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
603 	if ((r = sshpkt_get_end(ssh)) != 0)
604 		return r;
605 	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
606 		return r;
607 	kex->done = 1;
608 	kex->flags &= ~KEX_INITIAL;
609 	sshbuf_reset(kex->peer);
610 	/* sshbuf_reset(kex->my); */
611 	kex->flags &= ~KEX_INIT_SENT;
612 	free(kex->name);
613 	kex->name = NULL;
614 	return 0;
615 }
616 
617 int
618 kex_send_kexinit(struct ssh *ssh)
619 {
620 	u_char *cookie;
621 	struct kex *kex = ssh->kex;
622 	int r;
623 
624 	if (kex == NULL) {
625 		error_f("no kex");
626 		return SSH_ERR_INTERNAL_ERROR;
627 	}
628 	if (kex->flags & KEX_INIT_SENT)
629 		return 0;
630 	kex->done = 0;
631 
632 	/* generate a random cookie */
633 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
634 		error_f("bad kex length: %zu < %d",
635 		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
636 		return SSH_ERR_INVALID_FORMAT;
637 	}
638 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
639 		error_f("buffer error");
640 		return SSH_ERR_INTERNAL_ERROR;
641 	}
642 	arc4random_buf(cookie, KEX_COOKIE_LEN);
643 
644 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
645 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
646 	    (r = sshpkt_send(ssh)) != 0) {
647 		error_fr(r, "compose reply");
648 		return r;
649 	}
650 	debug("SSH2_MSG_KEXINIT sent");
651 	kex->flags |= KEX_INIT_SENT;
652 	return 0;
653 }
654 
655 int
656 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
657 {
658 	struct kex *kex = ssh->kex;
659 	const u_char *ptr;
660 	u_int i;
661 	size_t dlen;
662 	int r;
663 
664 	debug("SSH2_MSG_KEXINIT received");
665 	if (kex == NULL) {
666 		error_f("no kex");
667 		return SSH_ERR_INTERNAL_ERROR;
668 	}
669 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
670 	ptr = sshpkt_ptr(ssh, &dlen);
671 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
672 		return r;
673 
674 	/* discard packet */
675 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
676 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
677 			error_fr(r, "discard cookie");
678 			return r;
679 		}
680 	}
681 	for (i = 0; i < PROPOSAL_MAX; i++) {
682 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
683 			error_fr(r, "discard proposal");
684 			return r;
685 		}
686 	}
687 	/*
688 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
689 	 * KEX method has the server move first, but a server might be using
690 	 * a custom method or one that we otherwise don't support. We should
691 	 * be prepared to remember first_kex_follows here so we can eat a
692 	 * packet later.
693 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
694 	 * for cases where the server *doesn't* go first. I guess we should
695 	 * ignore it when it is set for these cases, which is what we do now.
696 	 */
697 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
698 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
699 	    (r = sshpkt_get_end(ssh)) != 0)
700 			return r;
701 
702 	if (!(kex->flags & KEX_INIT_SENT))
703 		if ((r = kex_send_kexinit(ssh)) != 0)
704 			return r;
705 	if ((r = kex_choose_conf(ssh)) != 0)
706 		return r;
707 
708 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
709 		return (kex->kex[kex->kex_type])(ssh);
710 
711 	error_f("unknown kex type %u", kex->kex_type);
712 	return SSH_ERR_INTERNAL_ERROR;
713 }
714 
715 struct kex *
716 kex_new(void)
717 {
718 	struct kex *kex;
719 
720 	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
721 	    (kex->peer = sshbuf_new()) == NULL ||
722 	    (kex->my = sshbuf_new()) == NULL ||
723 	    (kex->client_version = sshbuf_new()) == NULL ||
724 	    (kex->server_version = sshbuf_new()) == NULL ||
725 	    (kex->session_id = sshbuf_new()) == NULL) {
726 		kex_free(kex);
727 		return NULL;
728 	}
729 	return kex;
730 }
731 
732 void
733 kex_free_newkeys(struct newkeys *newkeys)
734 {
735 	if (newkeys == NULL)
736 		return;
737 	if (newkeys->enc.key) {
738 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
739 		free(newkeys->enc.key);
740 		newkeys->enc.key = NULL;
741 	}
742 	if (newkeys->enc.iv) {
743 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
744 		free(newkeys->enc.iv);
745 		newkeys->enc.iv = NULL;
746 	}
747 	free(newkeys->enc.name);
748 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
749 	free(newkeys->comp.name);
750 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
751 	mac_clear(&newkeys->mac);
752 	if (newkeys->mac.key) {
753 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
754 		free(newkeys->mac.key);
755 		newkeys->mac.key = NULL;
756 	}
757 	free(newkeys->mac.name);
758 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
759 	freezero(newkeys, sizeof(*newkeys));
760 }
761 
762 void
763 kex_free(struct kex *kex)
764 {
765 	u_int mode;
766 
767 	if (kex == NULL)
768 		return;
769 
770 #ifdef WITH_OPENSSL
771 	DH_free(kex->dh);
772 	EC_KEY_free(kex->ec_client_key);
773 #endif
774 	for (mode = 0; mode < MODE_MAX; mode++) {
775 		kex_free_newkeys(kex->newkeys[mode]);
776 		kex->newkeys[mode] = NULL;
777 	}
778 	sshbuf_free(kex->peer);
779 	sshbuf_free(kex->my);
780 	sshbuf_free(kex->client_version);
781 	sshbuf_free(kex->server_version);
782 	sshbuf_free(kex->client_pub);
783 	sshbuf_free(kex->session_id);
784 	sshbuf_free(kex->initial_sig);
785 	sshkey_free(kex->initial_hostkey);
786 	free(kex->failed_choice);
787 	free(kex->hostkey_alg);
788 	free(kex->name);
789 	free(kex);
790 }
791 
792 int
793 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
794 {
795 	int r;
796 
797 	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
798 		return r;
799 	ssh->kex->flags = KEX_INITIAL;
800 	kex_reset_dispatch(ssh);
801 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
802 	return 0;
803 }
804 
805 int
806 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
807 {
808 	int r;
809 
810 	if ((r = kex_ready(ssh, proposal)) != 0)
811 		return r;
812 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
813 		kex_free(ssh->kex);
814 		ssh->kex = NULL;
815 		return r;
816 	}
817 	return 0;
818 }
819 
820 /*
821  * Request key re-exchange, returns 0 on success or a ssherr.h error
822  * code otherwise. Must not be called if KEX is incomplete or in-progress.
823  */
824 int
825 kex_start_rekex(struct ssh *ssh)
826 {
827 	if (ssh->kex == NULL) {
828 		error_f("no kex");
829 		return SSH_ERR_INTERNAL_ERROR;
830 	}
831 	if (ssh->kex->done == 0) {
832 		error_f("requested twice");
833 		return SSH_ERR_INTERNAL_ERROR;
834 	}
835 	ssh->kex->done = 0;
836 	return kex_send_kexinit(ssh);
837 }
838 
839 static int
840 choose_enc(struct sshenc *enc, char *client, char *server)
841 {
842 	char *name = match_list(client, server, NULL);
843 
844 	if (name == NULL)
845 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
846 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
847 		error_f("unsupported cipher %s", name);
848 		free(name);
849 		return SSH_ERR_INTERNAL_ERROR;
850 	}
851 	enc->name = name;
852 	enc->enabled = 0;
853 	enc->iv = NULL;
854 	enc->iv_len = cipher_ivlen(enc->cipher);
855 	enc->key = NULL;
856 	enc->key_len = cipher_keylen(enc->cipher);
857 	enc->block_size = cipher_blocksize(enc->cipher);
858 	return 0;
859 }
860 
861 static int
862 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
863 {
864 	char *name = match_list(client, server, NULL);
865 
866 	if (name == NULL)
867 		return SSH_ERR_NO_MAC_ALG_MATCH;
868 	if (mac_setup(mac, name) < 0) {
869 		error_f("unsupported MAC %s", name);
870 		free(name);
871 		return SSH_ERR_INTERNAL_ERROR;
872 	}
873 	mac->name = name;
874 	mac->key = NULL;
875 	mac->enabled = 0;
876 	return 0;
877 }
878 
879 static int
880 choose_comp(struct sshcomp *comp, char *client, char *server)
881 {
882 	char *name = match_list(client, server, NULL);
883 
884 	if (name == NULL)
885 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
886 #ifdef WITH_ZLIB
887 	if (strcmp(name, "zlib@openssh.com") == 0) {
888 		comp->type = COMP_DELAYED;
889 	} else if (strcmp(name, "zlib") == 0) {
890 		comp->type = COMP_ZLIB;
891 	} else
892 #endif	/* WITH_ZLIB */
893 	if (strcmp(name, "none") == 0) {
894 		comp->type = COMP_NONE;
895 	} else {
896 		error_f("unsupported compression scheme %s", name);
897 		free(name);
898 		return SSH_ERR_INTERNAL_ERROR;
899 	}
900 	comp->name = name;
901 	return 0;
902 }
903 
904 static int
905 choose_kex(struct kex *k, char *client, char *server)
906 {
907 	const struct kexalg *kexalg;
908 
909 	k->name = match_list(client, server, NULL);
910 
911 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
912 	if (k->name == NULL)
913 		return SSH_ERR_NO_KEX_ALG_MATCH;
914 	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
915 		error_f("unsupported KEX method %s", k->name);
916 		return SSH_ERR_INTERNAL_ERROR;
917 	}
918 	k->kex_type = kexalg->type;
919 	k->hash_alg = kexalg->hash_alg;
920 	k->ec_nid = kexalg->ec_nid;
921 	return 0;
922 }
923 
924 static int
925 choose_hostkeyalg(struct kex *k, char *client, char *server)
926 {
927 	free(k->hostkey_alg);
928 	k->hostkey_alg = match_list(client, server, NULL);
929 
930 	debug("kex: host key algorithm: %s",
931 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
932 	if (k->hostkey_alg == NULL)
933 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
934 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
935 	if (k->hostkey_type == KEY_UNSPEC) {
936 		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
937 		return SSH_ERR_INTERNAL_ERROR;
938 	}
939 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
940 	return 0;
941 }
942 
943 static int
944 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
945 {
946 	static int check[] = {
947 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
948 	};
949 	int *idx;
950 	char *p;
951 
952 	for (idx = &check[0]; *idx != -1; idx++) {
953 		if ((p = strchr(my[*idx], ',')) != NULL)
954 			*p = '\0';
955 		if ((p = strchr(peer[*idx], ',')) != NULL)
956 			*p = '\0';
957 		if (strcmp(my[*idx], peer[*idx]) != 0) {
958 			debug2("proposal mismatch: my %s peer %s",
959 			    my[*idx], peer[*idx]);
960 			return (0);
961 		}
962 	}
963 	debug2("proposals match");
964 	return (1);
965 }
966 
967 /* returns non-zero if proposal contains any algorithm from algs */
968 static int
969 has_any_alg(const char *proposal, const char *algs)
970 {
971 	char *cp;
972 
973 	if ((cp = match_list(proposal, algs, NULL)) == NULL)
974 		return 0;
975 	free(cp);
976 	return 1;
977 }
978 
979 static int
980 kex_choose_conf(struct ssh *ssh)
981 {
982 	struct kex *kex = ssh->kex;
983 	struct newkeys *newkeys;
984 	char **my = NULL, **peer = NULL;
985 	char **cprop, **sprop;
986 	int nenc, nmac, ncomp;
987 	u_int mode, ctos, need, dh_need, authlen;
988 	int r, first_kex_follows;
989 
990 	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
991 	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
992 		goto out;
993 	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
994 	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
995 		goto out;
996 
997 	if (kex->server) {
998 		cprop=peer;
999 		sprop=my;
1000 	} else {
1001 		cprop=my;
1002 		sprop=peer;
1003 	}
1004 
1005 	/* Check whether client supports ext_info_c */
1006 	if (kex->server && (kex->flags & KEX_INITIAL)) {
1007 		char *ext;
1008 
1009 		ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
1010 		kex->ext_info_c = (ext != NULL);
1011 		free(ext);
1012 	}
1013 
1014 	/* Check whether client supports rsa-sha2 algorithms */
1015 	if (kex->server && (kex->flags & KEX_INITIAL)) {
1016 		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1017 		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
1018 			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
1019 		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1020 		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
1021 			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
1022 	}
1023 
1024 	/* Algorithm Negotiation */
1025 	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
1026 	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
1027 		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
1028 		peer[PROPOSAL_KEX_ALGS] = NULL;
1029 		goto out;
1030 	}
1031 	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
1032 	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
1033 		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
1034 		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
1035 		goto out;
1036 	}
1037 	for (mode = 0; mode < MODE_MAX; mode++) {
1038 		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
1039 			r = SSH_ERR_ALLOC_FAIL;
1040 			goto out;
1041 		}
1042 		kex->newkeys[mode] = newkeys;
1043 		ctos = (!kex->server && mode == MODE_OUT) ||
1044 		    (kex->server && mode == MODE_IN);
1045 		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1046 		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1047 		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1048 		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1049 		    sprop[nenc])) != 0) {
1050 			kex->failed_choice = peer[nenc];
1051 			peer[nenc] = NULL;
1052 			goto out;
1053 		}
1054 		authlen = cipher_authlen(newkeys->enc.cipher);
1055 		/* ignore mac for authenticated encryption */
1056 		if (authlen == 0 &&
1057 		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1058 		    sprop[nmac])) != 0) {
1059 			kex->failed_choice = peer[nmac];
1060 			peer[nmac] = NULL;
1061 			goto out;
1062 		}
1063 		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1064 		    sprop[ncomp])) != 0) {
1065 			kex->failed_choice = peer[ncomp];
1066 			peer[ncomp] = NULL;
1067 			goto out;
1068 		}
1069 		debug("kex: %s cipher: %s MAC: %s compression: %s",
1070 		    ctos ? "client->server" : "server->client",
1071 		    newkeys->enc.name,
1072 		    authlen == 0 ? newkeys->mac.name : "<implicit>",
1073 		    newkeys->comp.name);
1074 	}
1075 	need = dh_need = 0;
1076 	for (mode = 0; mode < MODE_MAX; mode++) {
1077 		newkeys = kex->newkeys[mode];
1078 		need = MAXIMUM(need, newkeys->enc.key_len);
1079 		need = MAXIMUM(need, newkeys->enc.block_size);
1080 		need = MAXIMUM(need, newkeys->enc.iv_len);
1081 		need = MAXIMUM(need, newkeys->mac.key_len);
1082 		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1083 		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1084 		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1085 		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1086 	}
1087 	/* XXX need runden? */
1088 	kex->we_need = need;
1089 	kex->dh_need = dh_need;
1090 
1091 	/* ignore the next message if the proposals do not match */
1092 	if (first_kex_follows && !proposals_match(my, peer))
1093 		ssh->dispatch_skip_packets = 1;
1094 	r = 0;
1095  out:
1096 	kex_prop_free(my);
1097 	kex_prop_free(peer);
1098 	return r;
1099 }
1100 
1101 static int
1102 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1103     const struct sshbuf *shared_secret, u_char **keyp)
1104 {
1105 	struct kex *kex = ssh->kex;
1106 	struct ssh_digest_ctx *hashctx = NULL;
1107 	char c = id;
1108 	u_int have;
1109 	size_t mdsz;
1110 	u_char *digest;
1111 	int r;
1112 
1113 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1114 		return SSH_ERR_INVALID_ARGUMENT;
1115 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1116 		r = SSH_ERR_ALLOC_FAIL;
1117 		goto out;
1118 	}
1119 
1120 	/* K1 = HASH(K || H || "A" || session_id) */
1121 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1122 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1123 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1124 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1125 	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1126 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1127 		r = SSH_ERR_LIBCRYPTO_ERROR;
1128 		error_f("KEX hash failed");
1129 		goto out;
1130 	}
1131 	ssh_digest_free(hashctx);
1132 	hashctx = NULL;
1133 
1134 	/*
1135 	 * expand key:
1136 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1137 	 * Key = K1 || K2 || ... || Kn
1138 	 */
1139 	for (have = mdsz; need > have; have += mdsz) {
1140 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1141 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1142 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1143 		    ssh_digest_update(hashctx, digest, have) != 0 ||
1144 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1145 			error_f("KDF failed");
1146 			r = SSH_ERR_LIBCRYPTO_ERROR;
1147 			goto out;
1148 		}
1149 		ssh_digest_free(hashctx);
1150 		hashctx = NULL;
1151 	}
1152 #ifdef DEBUG_KEX
1153 	fprintf(stderr, "key '%c'== ", c);
1154 	dump_digest("key", digest, need);
1155 #endif
1156 	*keyp = digest;
1157 	digest = NULL;
1158 	r = 0;
1159  out:
1160 	free(digest);
1161 	ssh_digest_free(hashctx);
1162 	return r;
1163 }
1164 
1165 #define NKEYS	6
1166 int
1167 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1168     const struct sshbuf *shared_secret)
1169 {
1170 	struct kex *kex = ssh->kex;
1171 	u_char *keys[NKEYS];
1172 	u_int i, j, mode, ctos;
1173 	int r;
1174 
1175 	/* save initial hash as session id */
1176 	if ((kex->flags & KEX_INITIAL) != 0) {
1177 		if (sshbuf_len(kex->session_id) != 0) {
1178 			error_f("already have session ID at kex");
1179 			return SSH_ERR_INTERNAL_ERROR;
1180 		}
1181 		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1182 			return r;
1183 	} else if (sshbuf_len(kex->session_id) == 0) {
1184 		error_f("no session ID in rekex");
1185 		return SSH_ERR_INTERNAL_ERROR;
1186 	}
1187 	for (i = 0; i < NKEYS; i++) {
1188 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1189 		    shared_secret, &keys[i])) != 0) {
1190 			for (j = 0; j < i; j++)
1191 				free(keys[j]);
1192 			return r;
1193 		}
1194 	}
1195 	for (mode = 0; mode < MODE_MAX; mode++) {
1196 		ctos = (!kex->server && mode == MODE_OUT) ||
1197 		    (kex->server && mode == MODE_IN);
1198 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1199 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1200 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1201 	}
1202 	return 0;
1203 }
1204 
1205 int
1206 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1207 {
1208 	struct kex *kex = ssh->kex;
1209 
1210 	*pubp = NULL;
1211 	*prvp = NULL;
1212 	if (kex->load_host_public_key == NULL ||
1213 	    kex->load_host_private_key == NULL) {
1214 		error_f("missing hostkey loader");
1215 		return SSH_ERR_INVALID_ARGUMENT;
1216 	}
1217 	*pubp = kex->load_host_public_key(kex->hostkey_type,
1218 	    kex->hostkey_nid, ssh);
1219 	*prvp = kex->load_host_private_key(kex->hostkey_type,
1220 	    kex->hostkey_nid, ssh);
1221 	if (*pubp == NULL)
1222 		return SSH_ERR_NO_HOSTKEY_LOADED;
1223 	return 0;
1224 }
1225 
1226 int
1227 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1228 {
1229 	struct kex *kex = ssh->kex;
1230 
1231 	if (kex->verify_host_key == NULL) {
1232 		error_f("missing hostkey verifier");
1233 		return SSH_ERR_INVALID_ARGUMENT;
1234 	}
1235 	if (server_host_key->type != kex->hostkey_type ||
1236 	    (kex->hostkey_type == KEY_ECDSA &&
1237 	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1238 		return SSH_ERR_KEY_TYPE_MISMATCH;
1239 	if (kex->verify_host_key(server_host_key, ssh) == -1)
1240 		return  SSH_ERR_SIGNATURE_INVALID;
1241 	return 0;
1242 }
1243 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/*
 * Debug helper: print the label "msg" on its own line to stderr, then
 * dump "len" bytes of "digest" via sshbuf_dump_data().
 */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
1252 
1253 /*
1254  * Send a plaintext error message to the peer, suffixed by \r\n.
1255  * Only used during banner exchange, and there only for the server.
1256  */
1257 static void
1258 send_error(struct ssh *ssh, char *msg)
1259 {
1260 	char *crnl = "\r\n";
1261 
1262 	if (!ssh->kex->server)
1263 		return;
1264 
1265 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1266 	    msg, strlen(msg)) != strlen(msg) ||
1267 	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1268 	    crnl, strlen(crnl)) != strlen(crnl))
1269 		error_f("write: %.100s", strerror(errno));
1270 }
1271 
1272 /*
1273  * Sends our identification string and waits for the peer's. Will block for
1274  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1275  * Returns on 0 success or a ssherr.h code on failure.
1276  */
1277 int
1278 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1279     const char *version_addendum)
1280 {
1281 	int remote_major, remote_minor, mismatch, oerrno = 0;
1282 	size_t len, n;
1283 	int r, expect_nl;
1284 	u_char c;
1285 	struct sshbuf *our_version = ssh->kex->server ?
1286 	    ssh->kex->server_version : ssh->kex->client_version;
1287 	struct sshbuf *peer_version = ssh->kex->server ?
1288 	    ssh->kex->client_version : ssh->kex->server_version;
1289 	char *our_version_string = NULL, *peer_version_string = NULL;
1290 	char *cp, *remote_version = NULL;
1291 
1292 	/* Prepare and send our banner */
1293 	sshbuf_reset(our_version);
1294 	if (version_addendum != NULL && *version_addendum == '\0')
1295 		version_addendum = NULL;
1296 	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%s%s%s\r\n",
1297 	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1298 	    version_addendum == NULL ? "" : " ",
1299 	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1300 		oerrno = errno;
1301 		error_fr(r, "sshbuf_putf");
1302 		goto out;
1303 	}
1304 
1305 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1306 	    sshbuf_mutable_ptr(our_version),
1307 	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1308 		oerrno = errno;
1309 		debug_f("write: %.100s", strerror(errno));
1310 		r = SSH_ERR_SYSTEM_ERROR;
1311 		goto out;
1312 	}
1313 	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1314 		oerrno = errno;
1315 		error_fr(r, "sshbuf_consume_end");
1316 		goto out;
1317 	}
1318 	our_version_string = sshbuf_dup_string(our_version);
1319 	if (our_version_string == NULL) {
1320 		error_f("sshbuf_dup_string failed");
1321 		r = SSH_ERR_ALLOC_FAIL;
1322 		goto out;
1323 	}
1324 	debug("Local version string %.100s", our_version_string);
1325 
1326 	/* Read other side's version identification. */
1327 	for (n = 0; ; n++) {
1328 		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1329 			send_error(ssh, "No SSH identification string "
1330 			    "received.");
1331 			error_f("No SSH version received in first %u lines "
1332 			    "from server", SSH_MAX_PRE_BANNER_LINES);
1333 			r = SSH_ERR_INVALID_FORMAT;
1334 			goto out;
1335 		}
1336 		sshbuf_reset(peer_version);
1337 		expect_nl = 0;
1338 		for (;;) {
1339 			if (timeout_ms > 0) {
1340 				r = waitrfd(ssh_packet_get_connection_in(ssh),
1341 				    &timeout_ms, NULL);
1342 				if (r == -1 && errno == ETIMEDOUT) {
1343 					send_error(ssh, "Timed out waiting "
1344 					    "for SSH identification string.");
1345 					error("Connection timed out during "
1346 					    "banner exchange");
1347 					r = SSH_ERR_CONN_TIMEOUT;
1348 					goto out;
1349 				} else if (r == -1) {
1350 					oerrno = errno;
1351 					error_f("%s", strerror(errno));
1352 					r = SSH_ERR_SYSTEM_ERROR;
1353 					goto out;
1354 				}
1355 			}
1356 
1357 			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1358 			    &c, 1);
1359 			if (len != 1 && errno == EPIPE) {
1360 				verbose_f("Connection closed by remote host");
1361 				r = SSH_ERR_CONN_CLOSED;
1362 				goto out;
1363 			} else if (len != 1) {
1364 				oerrno = errno;
1365 				error_f("read: %.100s", strerror(errno));
1366 				r = SSH_ERR_SYSTEM_ERROR;
1367 				goto out;
1368 			}
1369 			if (c == '\r') {
1370 				expect_nl = 1;
1371 				continue;
1372 			}
1373 			if (c == '\n')
1374 				break;
1375 			if (c == '\0' || expect_nl) {
1376 				verbose_f("banner line contains invalid "
1377 				    "characters");
1378 				goto invalid;
1379 			}
1380 			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1381 				oerrno = errno;
1382 				error_fr(r, "sshbuf_put");
1383 				goto out;
1384 			}
1385 			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1386 				verbose_f("banner line too long");
1387 				goto invalid;
1388 			}
1389 		}
1390 		/* Is this an actual protocol banner? */
1391 		if (sshbuf_len(peer_version) > 4 &&
1392 		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1393 			break;
1394 		/* If not, then just log the line and continue */
1395 		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1396 			error_f("sshbuf_dup_string failed");
1397 			r = SSH_ERR_ALLOC_FAIL;
1398 			goto out;
1399 		}
1400 		/* Do not accept lines before the SSH ident from a client */
1401 		if (ssh->kex->server) {
1402 			verbose_f("client sent invalid protocol identifier "
1403 			    "\"%.256s\"", cp);
1404 			free(cp);
1405 			goto invalid;
1406 		}
1407 		debug_f("banner line %zu: %s", n, cp);
1408 		free(cp);
1409 	}
1410 	peer_version_string = sshbuf_dup_string(peer_version);
1411 	if (peer_version_string == NULL)
1412 		fatal_f("sshbuf_dup_string failed");
1413 	/* XXX must be same size for sscanf */
1414 	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1415 		error_f("calloc failed");
1416 		r = SSH_ERR_ALLOC_FAIL;
1417 		goto out;
1418 	}
1419 
1420 	/*
1421 	 * Check that the versions match.  In future this might accept
1422 	 * several versions and set appropriate flags to handle them.
1423 	 */
1424 	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1425 	    &remote_major, &remote_minor, remote_version) != 3) {
1426 		error("Bad remote protocol version identification: '%.100s'",
1427 		    peer_version_string);
1428  invalid:
1429 		send_error(ssh, "Invalid SSH identification string.");
1430 		r = SSH_ERR_INVALID_FORMAT;
1431 		goto out;
1432 	}
1433 	debug("Remote protocol version %d.%d, remote software version %.100s",
1434 	    remote_major, remote_minor, remote_version);
1435 	compat_banner(ssh, remote_version);
1436 
1437 	mismatch = 0;
1438 	switch (remote_major) {
1439 	case 2:
1440 		break;
1441 	case 1:
1442 		if (remote_minor != 99)
1443 			mismatch = 1;
1444 		break;
1445 	default:
1446 		mismatch = 1;
1447 		break;
1448 	}
1449 	if (mismatch) {
1450 		error("Protocol major versions differ: %d vs. %d",
1451 		    PROTOCOL_MAJOR_2, remote_major);
1452 		send_error(ssh, "Protocol major versions differ.");
1453 		r = SSH_ERR_NO_PROTOCOL_VERSION;
1454 		goto out;
1455 	}
1456 
1457 	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1458 		logit("probed from %s port %d with %s.  Don't panic.",
1459 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1460 		    peer_version_string);
1461 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1462 		goto out;
1463 	}
1464 	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1465 		logit("scanned from %s port %d with %s.  Don't panic.",
1466 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1467 		    peer_version_string);
1468 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1469 		goto out;
1470 	}
1471 	/* success */
1472 	r = 0;
1473  out:
1474 	free(our_version_string);
1475 	free(peer_version_string);
1476 	free(remote_version);
1477 	if (r == SSH_ERR_SYSTEM_ERROR)
1478 		errno = oerrno;
1479 	return r;
1480 }
1481 
1482