1 /*
2  * Copyright (C) 2016 American Civil Liberties Union (ACLU)
3  *               2016-2018 CZ.NIC, z.s.p.o
4  *
5  * Initial Author: Daniel Kahn Gillmor <dkg@fifthhorseman.net>
6  *                 Ondřej Surý <ondrej@sury.org>
7  *
8  * SPDX-License-Identifier: GPL-3.0-or-later
9  */
10 
11 #include <gnutls/abstract.h>
12 #include <gnutls/crypto.h>
13 #include <gnutls/gnutls.h>
14 #include <gnutls/x509.h>
15 #include <uv.h>
16 
17 #include <errno.h>
18 #include <stdlib.h>
19 
20 #include "contrib/ucw/lib.h"
21 #include "contrib/base64.h"
22 #include "daemon/io.h"
23 #include "daemon/tls.h"
24 #include "daemon/worker.h"
25 #include "daemon/session.h"
26 
27 #define EPHEMERAL_CERT_EXPIRATION_SECONDS_RENEW_BEFORE (60*60*24*7)
28 #define GNUTLS_PIN_MIN_VERSION  0x030400
29 
/** @internal Emit a debug log message to the TLSCLIENT group when cl_side
 * is true and to the TLS group otherwise.
 * The body is wrapped in do { } while (0) so that the macro expands to
 * exactly one statement and cannot capture a following `else`; the call
 * site supplies the terminating semicolon. */
#define VERBOSE_MSG(cl_side, ...) \
	do { \
		if (cl_side) \
			kr_log_debug(TLSCLIENT, __VA_ARGS__); \
		else \
			kr_log_debug(TLS, __VA_ARGS__); \
	} while (0)
35 
36 /** @internal Debugging facility. */
37 #ifdef DEBUG
38 #define DEBUG_MSG(...) kr_log_debug(TLS, __VA_ARGS__)
39 #else
40 #define DEBUG_MSG(...)
41 #endif
42 
/** Context of one asynchronous write queued by kres_gnutls_vec_push().
 * It is allocated in a single malloc() together with the trailing data
 * and released in on_write_complete(). */
struct async_write_ctx {
	uv_write_t write_req;     /* request object passed to uv_write() */
	struct tls_common_ctx *t; /* TLS context whose write_queue_size counts this write */
	char buf[];               /* flexible array member: owned copy of the data to send */
};
48 
49 static int client_verify_certificate(gnutls_session_t tls_session);
50 
51 /**
52  * Set mandatory security settings from
53  * https://tools.ietf.org/html/draft-ietf-dprive-dtls-and-tls-profiles-11#section-9
54  * Performance optimizations are not implemented at the moment.
55  */
kres_gnutls_set_priority(gnutls_session_t session)56 static int kres_gnutls_set_priority(gnutls_session_t session) {
57 	static const char * const priorities =
58 		"NORMAL:" /* GnuTLS defaults */
59 		"-VERS-TLS1.0:-VERS-TLS1.1:" /* TLS 1.2 and higher */
60 		 /* Some distros by default allow features that are considered
61 		  * too insecure nowadays, so let's disable them explicitly. */
62 		"-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
63 	const char *errpos = NULL;
64 	int err = gnutls_priority_set_direct(session, priorities, &errpos);
65 	if (err != GNUTLS_E_SUCCESS) {
66 		kr_log_error(TLS, "setting priority '%s' failed at character %zd (...'%s') with %s (%d)\n",
67 			     priorities, errpos - priorities, errpos, gnutls_strerror_name(err), err);
68 	}
69 	return err;
70 }
71 
/** GnuTLS pull (read) transport callback.
 *
 * Hands GnuTLS the not yet consumed part of the input chunk stored in the
 * TLS context (t->buf, t->nread, t->consumed).
 *
 * \param h   transport pointer, i.e. the struct tls_common_ctx
 * \param buf destination buffer
 * \param len capacity of buf
 * \return number of bytes copied; -1 with errno EAGAIN when nothing is left,
 *         or -1 with errno EFAULT when the context is missing
 */
static ssize_t kres_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len)
{
	struct tls_common_ctx *t = (struct tls_common_ctx *)h;
	if (kr_fails_assert(t)) {
		errno = EFAULT;
		return -1;
	}

	const ssize_t pending = t->nread - t->consumed;
	DEBUG_MSG("[%s] pull wanted: %zu available: %zu\n",
		  t->client_side ? "tls_client" : "tls", len, pending);
	if (t->nread <= t->consumed) {
		/* Everything was handed over already. */
		errno = EAGAIN;
		return -1;
	}

	/* Copy as much as is available, but never more than requested. */
	const ssize_t chunk = MIN(pending, len);
	memcpy(buf, t->buf + t->consumed, chunk);
	t->consumed += chunk;
	return chunk;
}
93 
/** libuv completion callback of the asynchronous writes queued by
 * kres_gnutls_vec_push(): decrements the owner's queue counter and frees
 * the write context stored in req->data. The status is not inspected. */
static void on_write_complete(uv_write_t *req, int status)
{
	if (kr_fails_assert(req->data))
		return;
	struct async_write_ctx *wr = (struct async_write_ctx *)req->data;
	struct tls_common_ctx *owner = wr->t;
	if (!owner->write_queue_size) {
		/* A completion without a matching queued write. */
		kr_assert(false);
	} else {
		owner->write_queue_size -= 1;
	}
	free(wr);
}
106 
stream_queue_is_empty(struct tls_common_ctx * t)107 static bool stream_queue_is_empty(struct tls_common_ctx *t)
108 {
109 	return (t->write_queue_size == 0);
110 }
111 
/** GnuTLS vectored push (write) transport callback.
 *
 * Tries an immediate uv_try_write() when no asynchronous write is pending;
 * whatever could not be written immediately is copied into owned memory and
 * queued through uv_write().
 *
 * \param h      transport pointer, i.e. the struct tls_common_ctx
 * \param iov    buffers to send
 * \param iovcnt number of entries in iov
 * \return total number of bytes accepted (written or queued), 0 when iovcnt
 *         is 0, or -1 with errno set (EFAULT, EIO or ENOMEM)
 */
static ssize_t kres_gnutls_vec_push(gnutls_transport_ptr_t h, const giovec_t * iov, int iovcnt)
{
	struct tls_common_ctx *t = (struct tls_common_ctx *)h;
	if (kr_fails_assert(t)) {
		errno = EFAULT;
		return -1;
	}

	if (iovcnt == 0) {
		return 0;
	}

	if (kr_fails_assert(t->session)) {
		errno = EFAULT;
		return -1;
	}
	uv_stream_t *handle = (uv_stream_t *)session_get_handle(t->session);
	if (kr_fails_assert(handle && handle->type == UV_TCP)) {
		errno = EFAULT;
		return -1;
	}

	/*
	 * This is a little bit complicated. There are two different writes:
	 * 1. Immediate, these don't need to own the buffered data and return immediately
	 * 2. Asynchronous, these need to own the buffers until the write completes
	 * In order to avoid copying the buffer, an immediate write is tried first if possible.
	 * If it isn't possible to write the data without queueing, an asynchronous write
	 * is created (with copied buffered data).
	 */

	size_t total_len = 0;
	uv_buf_t uv_buf[iovcnt];
	for (int i = 0; i < iovcnt; ++i) {
		uv_buf[i].base = iov[i].iov_base;
		uv_buf[i].len = iov[i].iov_len;
		total_len += iov[i].iov_len;
	}

	/* Try to perform the immediate write first to avoid copy */
	int ret = 0;
	if (stream_queue_is_empty(t)) {
		ret = uv_try_write(handle, uv_buf, iovcnt);
		DEBUG_MSG("[%s] push %zu <%p> = %d\n",
		    t->client_side ? "tls_client" : "tls", total_len, h, ret);
		/* from libuv documentation -
		   uv_try_write will return either:
		     > 0: number of bytes written (can be less than the supplied buffer size).
		     < 0: negative error code (UV_EAGAIN is returned if no data can be sent immediately).
		*/
		if (ret >= 0 && (size_t)ret == total_len) {
			/* All the data were buffered by libuv.
			 * Return. */
			return ret;
		}

		if (ret < 0 && ret != UV_EAGAIN) {
			/* uv_try_write() has returned error code other than UV_EAGAIN.
			 * Return. */
			VERBOSE_MSG(t->client_side, "uv_try_write error: %s\n",
					uv_strerror(ret));
			ret = -1;
			errno = EIO;
			return ret;
		}
		/* Since we are here expression below is true
		 * (ret != total_len) && (ret >= 0 || ret == UV_EAGAIN)
		 * or the same
		 * (ret != total_len && ret >= 0) || (ret != total_len && ret == UV_EAGAIN)
		 * i.e. either occurs partial write or UV_EAGAIN.
		 * Proceed and copy data amount to owned memory and perform async write.
		 */
		if (ret == UV_EAGAIN) {
			/* No data were buffered, so we must buffer all the data. */
			ret = 0;
		}
	}

	/* Fallback when the queue is full, and it's not possible to do an immediate write */
	char *p = malloc(sizeof(struct async_write_ctx) + total_len - ret);
	if (p != NULL) {
		struct async_write_ctx *async_ctx = (struct async_write_ctx *)p;
		/* Save pointer to session tls context */
		async_ctx->t = t;
		char *buf = async_ctx->buf;
		/* Skip data written in the partial write */
		size_t to_skip = ret;
		/* Copy the buffer into owned memory */
		size_t off = 0;
		for (int i = 0; i < iovcnt; ++i) {
			if (to_skip > 0) {
				/* Ignore current buffer if it's all skipped */
				if (to_skip >= uv_buf[i].len) {
					to_skip -= uv_buf[i].len;
					continue;
				}
				/* Skip only part of the buffer */
				uv_buf[i].base += to_skip;
				uv_buf[i].len -= to_skip;
				to_skip = 0;
			}
			memcpy(buf + off, uv_buf[i].base, uv_buf[i].len);
			off += uv_buf[i].len;
		}
		uv_buf[0].base = buf;
		uv_buf[0].len = off;

		/* Create an asynchronous write request */
		uv_write_t *write_req = &async_ctx->write_req;
		memset(write_req, 0, sizeof(uv_write_t));
		write_req->data = p;

		/* Perform an asynchronous write with a callback.
		 * Keep the status in its own variable: `ret` still holds the
		 * number of bytes from the partial write, not an error code. */
		const int write_status = uv_write(write_req, handle, uv_buf, 1, on_write_complete);
		if (write_status == 0) {
			ret = total_len;
			t->write_queue_size += 1;
		} else {
			free(p);
			VERBOSE_MSG(t->client_side, "uv_write error: %s\n",
					uv_strerror(write_status));
			errno = EIO;
			ret = -1;
		}
	} else {
		errno = ENOMEM;
		ret = -1;
	}

	DEBUG_MSG("[%s] queued %zu <%p> = %d\n",
	    t->client_side ? "tls_client" : "tls", total_len, h, ret);

	return ret;
}
245 
246 /** Perform TLS handshake and handle error codes according to the documentation.
247   * See See https://gnutls.org/manual/html_node/TLS-handshake.html#TLS-handshake
248   * The function returns kr_ok() or success or non fatal error, kr_error(EAGAIN) on blocking, or kr_error(EIO) on fatal error.
249   */
tls_handshake(struct tls_common_ctx * ctx,tls_handshake_cb handshake_cb)250 static int tls_handshake(struct tls_common_ctx *ctx, tls_handshake_cb handshake_cb) {
251 	struct session *session = ctx->session;
252 
253 	int err = gnutls_handshake(ctx->tls_session);
254 	if (err == GNUTLS_E_SUCCESS) {
255 		/* Handshake finished, return success */
256 		ctx->handshake_state = TLS_HS_DONE;
257 		struct sockaddr *peer = session_get_peer(session);
258 		VERBOSE_MSG(ctx->client_side, "TLS handshake with %s has completed\n",
259 				kr_straddr(peer));
260 		if (handshake_cb) {
261 			if (handshake_cb(session, 0) != kr_ok()) {
262 				return kr_error(EIO);
263 			}
264 		}
265 	} else if (err == GNUTLS_E_AGAIN) {
266 		return kr_error(EAGAIN);
267 	} else if (gnutls_error_is_fatal(err)) {
268 		/* Fatal errors, return error as it's not recoverable */
269 		VERBOSE_MSG(ctx->client_side, "gnutls_handshake failed: %s (%d)\n",
270 				gnutls_strerror_name(err), err);
271 		/* Notify the peer about handshake failure via an alert. */
272 		gnutls_alert_send_appropriate(ctx->tls_session, err);
273 		if (handshake_cb) {
274 			handshake_cb(session, -1);
275 		}
276 		return kr_error(EIO);
277 	} else if (err == GNUTLS_E_WARNING_ALERT_RECEIVED) {
278 		/* Handle warning when in verbose mode */
279 		const char *alert_name = gnutls_alert_get_name(gnutls_alert_get(ctx->tls_session));
280 		if (alert_name != NULL) {
281 			struct sockaddr *peer = session_get_peer(session);
282 			VERBOSE_MSG(ctx->client_side, "TLS alert from %s received: %s\n",
283 					kr_straddr(peer), alert_name);
284 		}
285 	}
286 	return kr_ok();
287 }
288 
289 
tls_new(struct worker_ctx * worker)290 struct tls_ctx *tls_new(struct worker_ctx *worker)
291 {
292 	if (kr_fails_assert(worker && worker->engine))
293 		return NULL;
294 
295 	struct network *net = &worker->engine->net;
296 	if (!net->tls_credentials) {
297 		net->tls_credentials = tls_get_ephemeral_credentials(worker->engine);
298 		if (!net->tls_credentials) {
299 			kr_log_error(TLS, "X.509 credentials are missing, and ephemeral credentials failed; no TLS\n");
300 			return NULL;
301 		}
302 		kr_log_info(TLS, "Using ephemeral TLS credentials\n");
303 		tls_credentials_log_pins(net->tls_credentials);
304 	}
305 
306 	time_t now = time(NULL);
307 	if (net->tls_credentials->valid_until != GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION) {
308 		if (net->tls_credentials->ephemeral_servicename) {
309 			/* ephemeral cert: refresh if due to expire within a week */
310 			if (now >= net->tls_credentials->valid_until - EPHEMERAL_CERT_EXPIRATION_SECONDS_RENEW_BEFORE) {
311 				struct tls_credentials *newcreds = tls_get_ephemeral_credentials(worker->engine);
312 				if (newcreds) {
313 					tls_credentials_release(net->tls_credentials);
314 					net->tls_credentials = newcreds;
315 					kr_log_info(TLS, "Renewed expiring ephemeral X.509 cert\n");
316 				} else {
317 					kr_log_error(TLS, "Failed to renew expiring ephemeral X.509 cert, using existing one\n");
318 				}
319 			}
320 		} else {
321 			/* non-ephemeral cert: warn once when certificate expires */
322 			if (now >= net->tls_credentials->valid_until) {
323 				kr_log_error(TLS, "X.509 certificate has expired!\n");
324 				net->tls_credentials->valid_until = GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION;
325 			}
326 		}
327 	}
328 
329 	struct tls_ctx *tls = calloc(1, sizeof(struct tls_ctx));
330 	if (tls == NULL) {
331 		kr_log_error(TLS, "failed to allocate TLS context\n");
332 		return NULL;
333 	}
334 
335 	int err = gnutls_init(&tls->c.tls_session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
336 	if (err != GNUTLS_E_SUCCESS) {
337 		kr_log_error(TLS, "gnutls_init(): %s (%d)\n", gnutls_strerror_name(err), err);
338 		tls_free(tls);
339 		return NULL;
340 	}
341 	tls->credentials = tls_credentials_reserve(net->tls_credentials);
342 	err = gnutls_credentials_set(tls->c.tls_session, GNUTLS_CRD_CERTIFICATE,
343 				     tls->credentials->credentials);
344 	if (err != GNUTLS_E_SUCCESS) {
345 		kr_log_error(TLS, "gnutls_credentials_set(): %s (%d)\n", gnutls_strerror_name(err), err);
346 		tls_free(tls);
347 		return NULL;
348 	}
349 	if (kres_gnutls_set_priority(tls->c.tls_session) != GNUTLS_E_SUCCESS) {
350 		tls_free(tls);
351 		return NULL;
352 	}
353 
354 	tls->c.worker = worker;
355 	tls->c.client_side = false;
356 
357 	gnutls_transport_set_pull_function(tls->c.tls_session, kres_gnutls_pull);
358 	gnutls_transport_set_vec_push_function(tls->c.tls_session, kres_gnutls_vec_push);
359 	gnutls_transport_set_ptr(tls->c.tls_session, tls);
360 
361 	if (net->tls_session_ticket_ctx) {
362 		tls_session_ticket_enable(net->tls_session_ticket_ctx,
363 					  tls->c.tls_session);
364 	}
365 
366 	return tls;
367 }
368 
tls_close(struct tls_common_ctx * ctx)369 void tls_close(struct tls_common_ctx *ctx)
370 {
371 	if (ctx == NULL || ctx->tls_session == NULL || kr_fails_assert(ctx->session))
372 		return;
373 
374 	if (ctx->handshake_state == TLS_HS_DONE) {
375 		const struct sockaddr *peer = session_get_peer(ctx->session);
376 		VERBOSE_MSG(ctx->client_side, "closing tls connection to `%s`\n",
377 			       kr_straddr(peer));
378 		ctx->handshake_state = TLS_HS_CLOSING;
379 		gnutls_bye(ctx->tls_session, GNUTLS_SHUT_RDWR);
380 	}
381 }
382 
tls_free(struct tls_ctx * tls)383 void tls_free(struct tls_ctx *tls)
384 {
385 	if (!tls) {
386 		return;
387 	}
388 
389 	if (tls->c.tls_session) {
390 		/* Don't terminate TLS connection, just tear it down */
391 		gnutls_deinit(tls->c.tls_session);
392 		tls->c.tls_session = NULL;
393 	}
394 
395 	tls_credentials_release(tls->credentials);
396 	free(tls);
397 }
398 
/** Send one DNS message through the TLS session attached to the handle.
 *
 * The message is framed as a 2-byte length in network byte order followed
 * by pkt->wire; both parts are corked and flushed together.
 *
 * \param req    write request; its handle field is set before calling cb
 * \param handle libuv handle whose data points to the session
 * \param pkt    packet to send
 * \param cb     callback invoked as cb(req, 0) once GnuTLS accepted the data
 * \return kr_ok(), kr_error(EINVAL) on bad arguments, kr_error(EAGAIN) on a
 *         non-fatal uncork error, kr_error(EIO) otherwise
 */
int tls_write(uv_write_t *req, uv_handle_t *handle, knot_pkt_t *pkt, uv_write_cb cb)
{
	if (!pkt || !handle || !handle->data) {
		return kr_error(EINVAL);
	}

	struct session *s = handle->data;
	struct tls_common_ctx *tls_ctx = session_tls_get_common_ctx(s);

	if (kr_fails_assert(tls_ctx && session_flags(s)->outgoing == tls_ctx->client_side))
		return kr_error(EINVAL);

	const uint16_t pkt_size = htons(pkt->size);
	gnutls_session_t tls_session = tls_ctx->tls_session;

	gnutls_record_cork(tls_session);
	ssize_t count = 0;
	/* Parenthesize the assignments: `<` binds tighter than `=`, so without
	 * the parentheses `count` would get the 0/1 comparison result instead
	 * of the GnuTLS return value that is logged below. */
	if ((count = gnutls_record_send(tls_session, &pkt_size, sizeof(pkt_size))) < 0 ||
	    (count = gnutls_record_send(tls_session, pkt->wire, pkt->size)) < 0) {
		VERBOSE_MSG(tls_ctx->client_side, "gnutls_record_send failed: %s (%zd)\n",
				gnutls_strerror_name((int)count), count);
		return kr_error(EIO);
	}

	const ssize_t submitted = sizeof(pkt_size) + pkt->size;

	int ret = gnutls_record_uncork(tls_session, GNUTLS_RECORD_WAIT);
	if (ret < 0) {
		if (!gnutls_error_is_fatal(ret)) {
			return kr_error(EAGAIN);
		} else {
			VERBOSE_MSG(tls_ctx->client_side, "gnutls_record_uncork failed: %s (%d)\n",
					gnutls_strerror_name(ret), ret);
			return kr_error(EIO);
		}
	}

	if (ret != submitted) {
		kr_log_error(TLS, "gnutls_record_uncork didn't send all data (%d of %zd)\n", ret, submitted);
		return kr_error(EIO);
	}

	/* The data is now accepted in gnutls internal buffers, the message can be treated as sent */
	req->handle = (uv_stream_t *)handle;
	cb(req, 0);

	return kr_ok();
}
447 
/** Feed a chunk of received network data into the TLS session and store
 * the decrypted payload into the session wire buffer.
 *
 * The chunk is exposed to GnuTLS through kres_gnutls_pull() via
 * tls_p->buf / nread / consumed. The handshake is driven first if it has
 * not finished yet; decrypted records are then read until the chunk is
 * used up.
 *
 * \param s     session the data belongs to
 * \param buf   received data; must be the context's own recv_buf
 * \param nread number of bytes in buf
 * \return number of decrypted bytes stored into the wire buffer, 0 when
 *         more data are needed to continue the handshake, or a kr_error()
 *         code: ENOSYS without a TLS context, EINVAL on inconsistent
 *         arguments, EIO on a GnuTLS failure, ENOSPC when the input could
 *         not be fully consumed
 */
ssize_t tls_process_input_data(struct session *s, const uint8_t *buf, ssize_t nread)
{
	struct tls_common_ctx *tls_p = session_tls_get_common_ctx(s);
	if (!tls_p) {
		return kr_error(ENOSYS);
	}

	if (kr_fails_assert(tls_p->session == s))
		return kr_error(EINVAL);
	/* NOTE(review): nread is signed while sizeof() is unsigned, so a negative
	 * nread is converted to a huge value and fails this check, although the
	 * assignment below clamps negative values to 0 - confirm which behaviour
	 * is intended. */
	const bool ok = tls_p->recv_buf == buf && nread <= sizeof(tls_p->recv_buf);
	if (kr_fails_assert(ok)) /* don't risk overflowing the buffer if we have a mistake somewhere */
		return kr_error(EINVAL);

	/* Publish the chunk for kres_gnutls_pull(). */
	tls_p->buf = buf;
	tls_p->nread = nread >= 0 ? nread : 0;
	tls_p->consumed = 0;

	/* Ensure TLS handshake is performed before receiving data.
	 * See https://www.gnutls.org/manual/html_node/TLS-handshake.html */
	while (tls_p->handshake_state <= TLS_HS_IN_PROGRESS) {
		int err = tls_handshake(tls_p, tls_p->handshake_cb);
		if (err == kr_error(EAGAIN)) {
			return 0; /* Wait for more data */
		} else if (err != kr_ok()) {
			return err;
		}
	}

	/* See https://gnutls.org/manual/html_node/Data-transfer-and-termination.html#Data-transfer-and-termination */
	ssize_t submitted = 0;
	uint8_t *wire_buf = session_wirebuf_get_free_start(s);
	size_t wire_buf_size = session_wirebuf_get_free_size(s);
	while (true) {
		ssize_t count = gnutls_record_recv(tls_p->tls_session, wire_buf, wire_buf_size);
		if (count == GNUTLS_E_AGAIN) {
			if (tls_p->consumed == tls_p->nread) {
				/* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */
				break; /* No more data available in this libuv buffer */
			}
			continue;
		} else if (count == GNUTLS_E_INTERRUPTED) {
			continue;
		} else if (count == GNUTLS_E_REHANDSHAKE) {
			/* See https://www.gnutls.org/manual/html_node/Re_002dauthentication.html */
			struct sockaddr *peer = session_get_peer(s);
			VERBOSE_MSG(tls_p->client_side, "TLS rehandshake with %s has started\n",
					kr_straddr(peer));
			tls_set_hs_state(tls_p, TLS_HS_IN_PROGRESS);
			int err = kr_ok();
			while (tls_p->handshake_state <= TLS_HS_IN_PROGRESS) {
				err = tls_handshake(tls_p, tls_p->handshake_cb);
				if (err == kr_error(EAGAIN)) {
					break;
				} else if (err != kr_ok()) {
					return err;
				}
			}
			if (err == kr_error(EAGAIN)) {
				/* pull function is out of data */
				break;
			}
			/* There can be data available, check it. */
			continue;
		} else if (count < 0) {
			VERBOSE_MSG(tls_p->client_side, "gnutls_record_recv failed: %s (%zd)\n",
					gnutls_strerror_name(count), count);
			return kr_error(EIO);
		} else if (count == 0) {
			break;
		}
		DEBUG_MSG("[%s] received %zd data\n", tls_p->client_side ? "tls_client" : "tls", count);
		/* Advance in the wire buffer past the freshly decrypted bytes. */
		wire_buf += count;
		wire_buf_size -= count;
		submitted += count;
		if (wire_buf_size == 0 && tls_p->consumed != tls_p->nread) {
			/* session buffer is full
			 * whereas not all the data were consumed */
			return kr_error(ENOSPC);
		}
	}
	/* Here all data must be consumed. */
	if (tls_p->consumed != tls_p->nread) {
		/* Something went wrong, better return error.
		 * This is most probably because gnutls_record_recv() did not
		 * consume all available network data by calling kres_gnutls_pull().
		 * TODO assess the need for buffering of data amount.
		 */
		return kr_error(ENOSPC);
	}
	return submitted;
}
539 
540 #if TLS_CAN_USE_PINS
541 /*
542   DNS-over-TLS Out of band key-pinned authentication profile uses the
543   same form of pins as HPKP:
544 
545   e.g.  pin-sha256="FHkyLhvI0n70E47cJlRTamTrnYVcsYdjUGbr79CfAVI="
546 
547   DNS-over-TLS OOB key-pins: https://tools.ietf.org/html/rfc7858#appendix-A
548   HPKP pin reference:        https://tools.ietf.org/html/rfc7469#appendix-A
549 */
550 #define PINLEN  ((((32) * 8 + 4)/6) + 3 + 1)
551 
552 /* Compute pin_sha256 for the certificate.
553  * It may be in raw format - just TLS_SHA256_RAW_LEN bytes without termination,
554  * or it may be a base64 0-terminated string requiring up to
555  * TLS_SHA256_BASE64_BUFLEN bytes.
556  * \return error code */
get_oob_key_pin(gnutls_x509_crt_t crt,char * outchar,ssize_t outchar_len,bool raw)557 static int get_oob_key_pin(gnutls_x509_crt_t crt, char *outchar, ssize_t outchar_len, bool raw)
558 {
559 	if (kr_fails_assert(!raw || outchar_len >= TLS_SHA256_RAW_LEN)) {
560 		return kr_error(ENOSPC);
561 		/* With !raw we have check inside kr_base64_encode. */
562 	}
563 	gnutls_pubkey_t key;
564 	int err = gnutls_pubkey_init(&key);
565 	if (err != GNUTLS_E_SUCCESS) return err;
566 
567 	gnutls_datum_t datum = { .data = NULL, .size = 0 };
568 	err = gnutls_pubkey_import_x509(key, crt, 0);
569 	if (err != GNUTLS_E_SUCCESS) goto leave;
570 
571 	err = gnutls_pubkey_export2(key, GNUTLS_X509_FMT_DER, &datum);
572 	if (err != GNUTLS_E_SUCCESS) goto leave;
573 
574 	char raw_pin[TLS_SHA256_RAW_LEN]; /* TMP buffer if raw == false */
575 	err = gnutls_hash_fast(GNUTLS_DIG_SHA256, datum.data, datum.size,
576 				(raw ? outchar : raw_pin));
577 	if (err != GNUTLS_E_SUCCESS || raw/*success*/)
578 		goto leave;
579 	/* Convert to non-raw. */
580 	err = kr_base64_encode((uint8_t *)raw_pin, sizeof(raw_pin),
581 			    (uint8_t *)outchar, outchar_len);
582 	if (err >= 0 && err < outchar_len) {
583 		err = GNUTLS_E_SUCCESS;
584 		outchar[err] = '\0'; /* kr_base64_encode() doesn't do it */
585 	} else if (kr_fails_assert(err < 0)) {
586 		err = kr_error(ENOSPC); /* base64 fits but '\0' doesn't */
587 		outchar[outchar_len - 1] = '\0';
588 	}
589 leave:
590 	gnutls_free(datum.data);
591 	gnutls_pubkey_deinit(key);
592 	return err;
593 }
594 
tls_credentials_log_pins(struct tls_credentials * tls_credentials)595 void tls_credentials_log_pins(struct tls_credentials *tls_credentials)
596 {
597 	for (int index = 0;; index++) {
598 		gnutls_x509_crt_t *certs = NULL;
599 		unsigned int cert_count = 0;
600 		int err = gnutls_certificate_get_x509_crt(tls_credentials->credentials,
601 							index, &certs, &cert_count);
602 		if (err != GNUTLS_E_SUCCESS) {
603 			if (err != GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE) {
604 				kr_log_error(TLS, "could not get X.509 certificates (%d) %s\n",
605 						err, gnutls_strerror_name(err));
606 			}
607 			return;
608 		}
609 
610 		for (int i = 0; i < cert_count; i++) {
611 			char pin[TLS_SHA256_BASE64_BUFLEN] = { 0 };
612 			err = get_oob_key_pin(certs[i], pin, sizeof(pin), false);
613 			if (err != GNUTLS_E_SUCCESS) {
614 				kr_log_error(TLS, "could not calculate RFC 7858 OOB key-pin from cert %d (%d) %s\n",
615 						i, err, gnutls_strerror_name(err));
616 			} else {
617 				kr_log_info(TLS, "RFC 7858 OOB key-pin (%d): pin-sha256=\"%s\"\n",
618 						i, pin);
619 			}
620 			gnutls_x509_crt_deinit(certs[i]);
621 		}
622 		gnutls_free(certs);
623 	}
624 }
625 #else
tls_credentials_log_pins(struct tls_credentials * tls_credentials)626 void tls_credentials_log_pins(struct tls_credentials *tls_credentials)
627 {
628 	kr_log_debug(TLS, "could not calculate RFC 7858 OOB key-pin; GnuTLS 3.4.0+ required\n");
629 }
630 #endif
631 
/** Replace the heap string *where_ptr by a copy of `with`.
 *
 * The old string is freed only after the copy succeeded; passing NULL as
 * `with` just clears the target.
 * \return kr_ok(), or kr_error(ENOMEM) with *where_ptr left untouched
 */
static int str_replace(char **where_ptr, const char *with)
{
	char *fresh = NULL;
	if (with != NULL) {
		fresh = strdup(with);
		if (fresh == NULL)
			return kr_error(ENOMEM);
	}

	free(*where_ptr);
	*where_ptr = fresh;
	return kr_ok();
}
643 
/** Read the expiration time of the first certificate in the credentials
 * (index 0 of list 0).
 * \return the expiration time, or GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION
 *         when the certificate cannot be obtained or parsed (logged)
 */
static time_t _get_end_entity_expiration(gnutls_certificate_credentials_t creds)
{
	time_t expires = GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION;
	gnutls_x509_crt_t cert = NULL;
	gnutls_datum_t raw;

	int rc = gnutls_certificate_get_crt_raw(creds, 0, 0, &raw);
	if (rc != GNUTLS_E_SUCCESS) {
		kr_log_error(TLS, "failed to get cert to check expiration: (%d) %s\n",
			     rc, gnutls_strerror_name(rc));
	} else if ((rc = gnutls_x509_crt_init(&cert)) != GNUTLS_E_SUCCESS) {
		kr_log_error(TLS, "failed to initialize cert: (%d) %s\n",
			     rc, gnutls_strerror_name(rc));
	} else if ((rc = gnutls_x509_crt_import(cert, &raw, GNUTLS_X509_FMT_DER)) != GNUTLS_E_SUCCESS) {
		kr_log_error(TLS, "failed to construct cert while checking expiration: (%d) %s\n",
			     rc, gnutls_strerror_name(rc));
	} else {
		expires = gnutls_x509_crt_get_expiration_time(cert);
	}

	/* The raw datum is not freed; g_c_get_crt_raw() says to treat it as
	 * constant. */
	gnutls_x509_crt_deinit(cert);
	return expires;
}
674 
tls_certificate_set(struct network * net,const char * tls_cert,const char * tls_key)675 int tls_certificate_set(struct network *net, const char *tls_cert, const char *tls_key)
676 {
677 	if (!net) {
678 		return kr_error(EINVAL);
679 	}
680 
681 	struct tls_credentials *tls_credentials = calloc(1, sizeof(*tls_credentials));
682 	if (tls_credentials == NULL) {
683 		return kr_error(ENOMEM);
684 	}
685 
686 	int err = 0;
687 	if ((err = gnutls_certificate_allocate_credentials(&tls_credentials->credentials)) != GNUTLS_E_SUCCESS) {
688 		kr_log_error(TLS, "gnutls_certificate_allocate_credentials() failed: (%d) %s\n",
689 			     err, gnutls_strerror_name(err));
690 		tls_credentials_free(tls_credentials);
691 		return kr_error(ENOMEM);
692 	}
693 	if ((err = gnutls_certificate_set_x509_system_trust(tls_credentials->credentials)) < 0) {
694 		if (err != GNUTLS_E_UNIMPLEMENTED_FEATURE) {
695 			kr_log_warning(TLS, "warning: gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n",
696 				     err, gnutls_strerror_name(err));
697 			tls_credentials_free(tls_credentials);
698 			return err;
699 		}
700 	}
701 
702 	if ((str_replace(&tls_credentials->tls_cert, tls_cert) != 0) ||
703 	    (str_replace(&tls_credentials->tls_key, tls_key) != 0)) {
704 		tls_credentials_free(tls_credentials);
705 		return kr_error(ENOMEM);
706 	}
707 
708 	if ((err = gnutls_certificate_set_x509_key_file(tls_credentials->credentials,
709 							tls_cert, tls_key, GNUTLS_X509_FMT_PEM)) != GNUTLS_E_SUCCESS) {
710 		tls_credentials_free(tls_credentials);
711 		kr_log_error(TLS, "gnutls_certificate_set_x509_key_file(%s,%s) failed: %d (%s)\n",
712 			     tls_cert, tls_key, err, gnutls_strerror_name(err));
713 		return kr_error(EINVAL);
714 	}
715 	/* record the expiration date: */
716 	tls_credentials->valid_until = _get_end_entity_expiration(tls_credentials->credentials);
717 
718 	/* Exchange the x509 credentials */
719 	struct tls_credentials *old_credentials = net->tls_credentials;
720 
721 	/* Start using the new x509_credentials */
722 	net->tls_credentials = tls_credentials;
723 	tls_credentials_log_pins(net->tls_credentials);
724 
725 	if (old_credentials) {
726 		err = tls_credentials_release(old_credentials);
727 		if (err != kr_error(EBUSY)) {
728 			return err;
729 		}
730 	}
731 
732 	return kr_ok();
733 }
734 
tls_credentials_reserve(struct tls_credentials * tls_credentials)735 struct tls_credentials *tls_credentials_reserve(struct tls_credentials *tls_credentials) {
736 	if (!tls_credentials) {
737 		return NULL;
738 	}
739 	tls_credentials->count++;
740 	return tls_credentials;
741 }
742 
tls_credentials_release(struct tls_credentials * tls_credentials)743 int tls_credentials_release(struct tls_credentials *tls_credentials) {
744 	if (!tls_credentials) {
745 		return kr_error(EINVAL);
746 	}
747 	if (--tls_credentials->count < 0) {
748 		tls_credentials_free(tls_credentials);
749 	} else {
750 		return kr_error(EBUSY);
751 	}
752 	return kr_ok();
753 }
754 
tls_credentials_free(struct tls_credentials * tls_credentials)755 void tls_credentials_free(struct tls_credentials *tls_credentials) {
756 	if (!tls_credentials) {
757 		return;
758 	}
759 
760 	if (tls_credentials->credentials) {
761 		gnutls_certificate_free_credentials(tls_credentials->credentials);
762 	}
763 	if (tls_credentials->tls_cert) {
764 		free(tls_credentials->tls_cert);
765 	}
766 	if (tls_credentials->tls_key) {
767 		free(tls_credentials->tls_key);
768 	}
769 	if (tls_credentials->ephemeral_servicename) {
770 		free(tls_credentials->ephemeral_servicename);
771 	}
772 	free(tls_credentials);
773 }
774 
/** Drop one reference to the client parameters; with the last reference
 * gone, free everything the entry owns and the entry itself.
 * NULL is accepted and ignored. */
void tls_client_param_unref(tls_client_param_t *entry)
{
	if (!entry || kr_fails_assert(entry->refs)) return;
	entry->refs -= 1;
	if (entry->refs != 0) return; /* still referenced elsewhere */

	DEBUG_MSG("freeing TLS parameters %p\n", (void *)entry);

	/* CA file names */
	int i = 0;
	while (i < entry->ca_files.len) {
		free_const(entry->ca_files.at[i]);
		++i;
	}
	array_clear(entry->ca_files);

	free_const(entry->hostname);

	/* certificate pins */
	i = 0;
	while (i < entry->pins.len) {
		free_const(entry->pins.at[i]);
		++i;
	}
	array_clear(entry->pins);

	if (entry->credentials != NULL) {
		gnutls_certificate_free_credentials(entry->credentials);
	}

	if (entry->session_data.data != NULL) {
		gnutls_free(entry->session_data.data);
	}

	free(entry);
}
param_free(void ** param,void * null)805 static int param_free(void **param, void *null)
806 {
807 	if (kr_fails_assert(param && *param))
808 		return -1;
809 	tls_client_param_unref(*param);
810 	return 0;
811 }
/** Release every entry of the client-parameter map and then the map
 * itself. NULL is accepted and ignored. */
void tls_client_params_free(tls_client_params_t *params)
{
	if (params == NULL) {
		return;
	}
	trie_apply(params, param_free, NULL);
	trie_free(params);
}
818 
tls_client_param_new()819 tls_client_param_t * tls_client_param_new()
820 {
821 	tls_client_param_t *e = calloc(1, sizeof(*e));
822 	if (kr_fails_assert(e))
823 		return NULL;
824 	/* Note: those array_t don't need further initialization. */
825 	e->refs = 1;
826 	int ret = gnutls_certificate_allocate_credentials(&e->credentials);
827 	if (ret != GNUTLS_E_SUCCESS) {
828 		kr_log_error(TLSCLIENT, "error: gnutls_certificate_allocate_credentials() fails (%s)\n",
829 			     gnutls_strerror_name(ret));
830 		free(e);
831 		return NULL;
832 	}
833 	gnutls_certificate_set_verify_function(e->credentials, client_verify_certificate);
834 	return e;
835 }
836 
837 /**
838  * Convert an IP address and port number to binary key.
839  *
840  * \precond buffer \param key must have sufficient size
841  * \param addr[in]
842  * \param len[out] output length
843  * \param key[out] output buffer
844  */
construct_key(const union inaddr * addr,uint32_t * len,char * key)845 static bool construct_key(const union inaddr *addr, uint32_t *len, char *key)
846 {
847 	switch (addr->ip.sa_family) {
848 	case AF_INET:
849 		memcpy(key, &addr->ip4.sin_port, sizeof(addr->ip4.sin_port));
850 		memcpy(key + sizeof(addr->ip4.sin_port), &addr->ip4.sin_addr,
851 			sizeof(addr->ip4.sin_addr));
852 		*len = sizeof(addr->ip4.sin_port) + sizeof(addr->ip4.sin_addr);
853 		return true;
854 	case AF_INET6:
855 		memcpy(key, &addr->ip6.sin6_port, sizeof(addr->ip6.sin6_port));
856 		memcpy(key + sizeof(addr->ip6.sin6_port), &addr->ip6.sin6_addr,
857 			sizeof(addr->ip6.sin6_addr));
858 		*len = sizeof(addr->ip6.sin6_port) + sizeof(addr->ip6.sin6_addr);
859 		return true;
860 	default:
861 		kr_assert(!EINVAL);
862 		return false;
863 	}
864 }
tls_client_param_getptr(tls_client_params_t ** params,const struct sockaddr * addr,bool do_insert)865 tls_client_param_t ** tls_client_param_getptr(tls_client_params_t **params,
866 				const struct sockaddr *addr, bool do_insert)
867 {
868 	if (kr_fails_assert(params && addr))
869 		return NULL;
870 	/* We accept NULL for empty map; ensure the map exists if needed. */
871 	if (!*params) {
872 		if (!do_insert) return NULL;
873 		*params = trie_create(NULL);
874 		if (kr_fails_assert(*params))
875 			return NULL;
876 	}
877 	/* Construct the key. */
878 	const union inaddr *ia = (const union inaddr *)addr;
879 	char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)];
880 	uint32_t len;
881 	if (!construct_key(ia, &len, key))
882 		return NULL;
883 	/* Get the entry. */
884 	return (tls_client_param_t **)
885 		(do_insert ? trie_get_ins : trie_get_try)(*params, key, len);
886 }
887 
/**
 * Remove the TLS client parameters stored for the given address (with port).
 *
 * On success the entry is deleted from the map and tls_client_param_unref()
 * is called on the value that was stored there.
 *
 * \param params map of client parameters; NULL is treated as an empty map
 * \param addr   address (AF_INET or AF_INET6) identifying the entry
 * \return kr_ok() on success, kr_error(EINVAL) for a missing or unsupported
 *         address, kr_error(ENOENT) for a NULL map, or the error reported
 *         by trie_del() wrapped in kr_error()
 */
int tls_client_param_remove(tls_client_params_t *params, const struct sockaddr *addr)
{
	/* construct_key() dereferences the address unconditionally. */
	if (kr_fails_assert(addr))
		return kr_error(EINVAL);
	/* A NULL map means "empty" elsewhere in this file, so there is nothing
	 * to delete; do not hand a NULL trie to trie_del(). */
	if (params == NULL)
		return kr_error(ENOENT);

	const union inaddr *ia = (const union inaddr *)addr;
	/* The buffer is sized for the larger (IPv6) key. */
	char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)];
	uint32_t len;
	if (!construct_key(ia, &len, key))
		return kr_error(EINVAL);

	trie_val_t param_ptr;
	int ret = trie_del(params, key, len, &param_ptr);
	if (ret != KNOT_EOK)
		return kr_error(ret);

	/* The removed value is released via the usual unref path. */
	tls_client_param_unref(param_ptr);
	return kr_ok();
}
902 
903 /**
904  * Verify that at least one certificate in the certificate chain matches
905  * at least one certificate pin in the non-empty params->pins array.
906  * \returns GNUTLS_E_SUCCESS if pin matches, any other value is an error
907  */
client_verify_pin(const unsigned int cert_list_size,const gnutls_datum_t * cert_list,tls_client_param_t * params)908 static int client_verify_pin(const unsigned int cert_list_size,
909 				const gnutls_datum_t *cert_list,
910 				tls_client_param_t *params)
911 {
912 	if (kr_fails_assert(params->pins.len > 0))
913 		return GNUTLS_E_CERTIFICATE_ERROR;
914 #if TLS_CAN_USE_PINS
915 	for (int i = 0; i < cert_list_size; i++) {
916 		gnutls_x509_crt_t cert;
917 		int ret = gnutls_x509_crt_init(&cert);
918 		if (ret != GNUTLS_E_SUCCESS) {
919 			return ret;
920 		}
921 
922 		ret = gnutls_x509_crt_import(cert, &cert_list[i], GNUTLS_X509_FMT_DER);
923 		if (ret != GNUTLS_E_SUCCESS) {
924 			gnutls_x509_crt_deinit(cert);
925 			return ret;
926 		}
927 
928 	#ifdef DEBUG
929 		if (kr_log_is_debug(TLS, NULL)) {
930 			char pin_base64[TLS_SHA256_BASE64_BUFLEN];
931 			/* DEBUG: additionally compute and print the base64 pin.
932 			 * Not very efficient, but that's OK for DEBUG. */
933 			ret = get_oob_key_pin(cert, pin_base64, sizeof(pin_base64), false);
934 			if (ret == GNUTLS_E_SUCCESS) {
935 				DEBUG_MSG("[tls_client] received pin: %s\n", pin_base64);
936 			} else {
937 				DEBUG_MSG("[tls_client] failed to convert received pin\n");
938 				/* Now we hope that `ret` below can't differ. */
939 			}
940 		}
941 	#endif
942 		char cert_pin[TLS_SHA256_RAW_LEN];
943 		/* Get raw pin and compare. */
944 		ret = get_oob_key_pin(cert, cert_pin, sizeof(cert_pin), true);
945 		gnutls_x509_crt_deinit(cert);
946 		if (ret != GNUTLS_E_SUCCESS) {
947 			return ret;
948 		}
949 		for (size_t j = 0; j < params->pins.len; ++j) {
950 			const uint8_t *pin = params->pins.at[j];
951 			if (memcmp(cert_pin, pin, TLS_SHA256_RAW_LEN) != 0)
952 				continue; /* mismatch */
953 			DEBUG_MSG("[tls_client] matched a configured pin no. %zd\n", j);
954 			return GNUTLS_E_SUCCESS;
955 		}
956 		DEBUG_MSG("[tls_client] none of %zd configured pin(s) matched\n",
957 				params->pins.len);
958 	}
959 
960 	kr_log_error(TLSCLIENT, "no pin matched: %zu pins * %d certificates\n",
961 			params->pins.len, cert_list_size);
962 	return GNUTLS_E_CERTIFICATE_ERROR;
963 
964 #else /* TLS_CAN_USE_PINS */
965 	kr_log_error(TLSCLIENT, "internal inconsistency: TLS_CAN_USE_PINS\n");
966 	kr_assert(false);
967 	return GNUTLS_E_CERTIFICATE_ERROR;
968 #endif
969 }
970 
971 /**
972  * Verify that \param tls_session contains a valid X.509 certificate chain
973  * with given hostname.
974  *
975  * \returns GNUTLS_E_SUCCESS if certificate chain is valid, any other value is an error
976  */
client_verify_certchain(gnutls_session_t tls_session,const char * hostname)977 static int client_verify_certchain(gnutls_session_t tls_session, const char *hostname)
978 {
979 	if (kr_fails_assert(hostname)) {
980 		kr_log_error(TLSCLIENT, "internal config inconsistency: no hostname set\n");
981 		return GNUTLS_E_CERTIFICATE_ERROR;
982 	}
983 
984 	unsigned int status;
985 	int ret = gnutls_certificate_verify_peers3(tls_session, hostname, &status);
986 	if ((ret == GNUTLS_E_SUCCESS) && (status == 0)) {
987 		return GNUTLS_E_SUCCESS;
988 	}
989 
990 	if (ret == GNUTLS_E_SUCCESS) {
991 		gnutls_datum_t msg;
992 		ret = gnutls_certificate_verification_status_print(
993 			status, gnutls_certificate_type_get(tls_session), &msg, 0);
994 		if (ret == GNUTLS_E_SUCCESS) {
995 			kr_log_error(TLSCLIENT, "failed to verify peer certificate: "
996 					"%s\n", msg.data);
997 			gnutls_free(msg.data);
998 		} else {
999 			kr_log_error(TLSCLIENT, "failed to verify peer certificate: "
1000 					"unable to print reason: %s (%s)\n",
1001 					gnutls_strerror(ret), gnutls_strerror_name(ret));
1002 		} /* gnutls_certificate_verification_status_print end */
1003 	} else {
1004 		kr_log_error(TLSCLIENT, "failed to verify peer certificate: "
1005 			     "gnutls_certificate_verify_peers3 error: %s (%s)\n",
1006 			     gnutls_strerror(ret), gnutls_strerror_name(ret));
1007 	} /* gnutls_certificate_verify_peers3 end */
1008 	return GNUTLS_E_CERTIFICATE_ERROR;
1009 }
1010 
1011 /**
1012  * Verify that actual TLS security parameters of \param tls_session
1013  * match requirements provided by user in tls_session->params.
1014  * \returns GNUTLS_E_SUCCESS if requirements were met, any other value is an error
1015  */
client_verify_certificate(gnutls_session_t tls_session)1016 static int client_verify_certificate(gnutls_session_t tls_session)
1017 {
1018 	struct tls_client_ctx *ctx = gnutls_session_get_ptr(tls_session);
1019 	if (kr_fails_assert(ctx->params))
1020 		return GNUTLS_E_CERTIFICATE_ERROR;
1021 
1022 	if (ctx->params->insecure) {
1023 		return GNUTLS_E_SUCCESS;
1024 	}
1025 
1026 	gnutls_certificate_type_t cert_type = gnutls_certificate_type_get(tls_session);
1027 	if (cert_type != GNUTLS_CRT_X509) {
1028 		kr_log_error(TLSCLIENT, "invalid certificate type %i has been received\n",
1029 			     cert_type);
1030 		return GNUTLS_E_CERTIFICATE_ERROR;
1031 	}
1032 	unsigned int cert_list_size = 0;
1033 	const gnutls_datum_t *cert_list =
1034 		gnutls_certificate_get_peers(tls_session, &cert_list_size);
1035 	if (cert_list == NULL || cert_list_size == 0) {
1036 		kr_log_error(TLSCLIENT, "empty certificate list\n");
1037 		return GNUTLS_E_CERTIFICATE_ERROR;
1038 	}
1039 
1040 	if (ctx->params->pins.len > 0)
1041 		/* check hash of the certificate but ignore everything else */
1042 		return client_verify_pin(cert_list_size, cert_list, ctx->params);
1043 	else
1044 		return client_verify_certchain(ctx->c.tls_session, ctx->params->hostname);
1045 }
1046 
/**
 * Allocate and set up a TLS client context for the given parameters.
 *
 * A non-blocking GnuTLS client session is created (with false start when
 * GnuTLS defines GNUTLS_ENABLE_FALSE_START), the resolver's priority settings
 * are applied, the credentials from the parameters are attached and the
 * server name is set if a hostname is configured.
 *
 * \param entry  client parameters; its reference count is incremented
 * \param worker worker stored into the context
 * \return the new context, or NULL on allocation or GnuTLS failure
 */
struct tls_client_ctx *tls_client_ctx_new(tls_client_param_t *entry,
					    struct worker_ctx *worker)
{
	struct tls_client_ctx *ctx = calloc(1, sizeof(*ctx));
	if (ctx == NULL)
		return NULL;

	unsigned int flags = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
#ifdef GNUTLS_ENABLE_FALSE_START
	flags |= GNUTLS_ENABLE_FALSE_START;
#endif

	int rc = gnutls_init(&ctx->c.tls_session, flags);
	if (rc != GNUTLS_E_SUCCESS)
		goto fail;

	rc = kres_gnutls_set_priority(ctx->c.tls_session);
	if (rc != GNUTLS_E_SUCCESS)
		goto fail;

	/* The credentials belong to the parameters, so hold a reference on them
	 * to keep them alive for as long as the session may use them. */
	++(entry->refs);
	ctx->params = entry;

	rc = gnutls_credentials_set(ctx->c.tls_session, GNUTLS_CRD_CERTIFICATE,
	                            entry->credentials);
	if (!entry->hostname) {
		kr_log_debug(TLSCLIENT, "no hostname\n");
	} else if (rc == GNUTLS_E_SUCCESS) {
		rc = gnutls_server_name_set(ctx->c.tls_session, GNUTLS_NAME_DNS,
					entry->hostname, strlen(entry->hostname));
		kr_log_debug(TLSCLIENT, "set hostname, ret = %d\n", rc);
	}
	if (rc != GNUTLS_E_SUCCESS)
		goto fail;

	ctx->c.worker = worker;
	ctx->c.client_side = true;

	gnutls_transport_set_pull_function(ctx->c.tls_session, kres_gnutls_pull);
	gnutls_transport_set_vec_push_function(ctx->c.tls_session, kres_gnutls_vec_push);
	gnutls_transport_set_ptr(ctx->c.tls_session, ctx);
	return ctx;

fail:
	/* Single cleanup path for every failure after the allocation. */
	tls_client_ctx_free(ctx);
	return NULL;
}
1098 
tls_client_ctx_free(struct tls_client_ctx * ctx)1099 void tls_client_ctx_free(struct tls_client_ctx *ctx)
1100 {
1101 	if (ctx == NULL) {
1102 		return;
1103 	}
1104 
1105 	if (ctx->c.tls_session != NULL) {
1106 		gnutls_deinit(ctx->c.tls_session);
1107 		ctx->c.tls_session = NULL;
1108 	}
1109 
1110 	/* Must decrease the refcount for referenced parameters */
1111 	tls_client_param_unref(ctx->params);
1112 
1113 	free (ctx);
1114 }
1115 
/**
 * Pull-timeout callback for GnuTLS: report how much input is already buffered.
 *
 * Nothing is waited for; the \a ms argument is not used.
 *
 * \param h  transport pointer, interpreted as struct tls_common_ctx *
 * \param ms timeout requested by GnuTLS (ignored)
 * \return the number of buffered but not yet consumed bytes (narrowed to int),
 *         or -1 with errno set to EAGAIN when there are none,
 *         or -1 with errno set to EFAULT when \a h is NULL
 */
int  tls_pull_timeout_func(gnutls_transport_ptr_t h, unsigned int ms)
{
	struct tls_common_ctx *t = (struct tls_common_ctx *)h;
	if (kr_fails_assert(t)) {
		errno = EFAULT;
		return -1;
	}
	ssize_t avail = t->nread - t->consumed;
	/* `avail` is signed, so it has to be printed with %zd. */
	DEBUG_MSG("[%s] timeout check: available: %zd\n",
		  t->client_side ? "tls_client" : "tls", avail);
	if (avail <= 0) {
		errno = EAGAIN;
		return -1;
	}
	return avail;
}
1132 
/**
 * Bind a TLS client context to an outgoing TCP session and run the handshake.
 *
 * The GnuTLS session is configured (session pointer, handshake timeout,
 * pull-timeout callback), the context is attached to \a session, and any
 * session data stored in the client parameters is passed to
 * gnutls_session_set_data().  Then tls_handshake() is called repeatedly
 * while the handshake state is at most TLS_HS_IN_PROGRESS.
 *
 * \param client_ctx   TLS client context to use
 * \param session      session to connect; must be outgoing and backed by a UV_TCP handle
 * \param handshake_cb callback stored in the context and passed to tls_handshake()
 * \return kr_ok() once the loop ends, kr_error(EINVAL) on invalid arguments,
 *         or the first non-ok value returned by tls_handshake()
 */
int tls_client_connect_start(struct tls_client_ctx *client_ctx,
			     struct session *session,
			     tls_handshake_cb handshake_cb)
{
	if (!session || !client_ctx)
		return kr_error(EINVAL);

	if (kr_fails_assert(session_flags(session)->outgoing && session_get_handle(session)->type == UV_TCP))
		return kr_error(EINVAL);

	struct tls_common_ctx *common = &client_ctx->c;

	gnutls_session_set_ptr(common->tls_session, client_ctx);
	gnutls_handshake_set_timeout(common->tls_session, common->worker->engine->net.tcp.tls_handshake_timeout);
	gnutls_transport_set_pull_timeout_function(common->tls_session, tls_pull_timeout_func);
	session_tls_set_client_ctx(session, client_ctx);
	common->handshake_cb = handshake_cb;
	common->handshake_state = TLS_HS_IN_PROGRESS;
	common->session = session;

	/* Hand over session data kept in the parameters, if there is any. */
	tls_client_param_t *params = client_ctx->params;
	if (params->session_data.data != NULL) {
		gnutls_session_set_data(common->tls_session, params->session_data.data,
					params->session_data.size);
	}

	/* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */
	while (common->handshake_state <= TLS_HS_IN_PROGRESS) {
		const int rc = tls_handshake(common, handshake_cb);
		if (rc != kr_ok())
			return rc;
	}
	return kr_ok();
}
1168 
tls_get_hs_state(const struct tls_common_ctx * ctx)1169 tls_hs_state_t tls_get_hs_state(const struct tls_common_ctx *ctx)
1170 {
1171 	return ctx->handshake_state;
1172 }
1173 
/**
 * Store a new handshake state in the TLS context.
 *
 * \return kr_ok() on success, kr_error(EINVAL) if \a state is TLS_HS_LAST
 *         or greater (the context is left unchanged in that case)
 */
int tls_set_hs_state(struct tls_common_ctx *ctx, tls_hs_state_t state)
{
	if (state < TLS_HS_LAST) {
		ctx->handshake_state = state;
		return kr_ok();
	}
	return kr_error(EINVAL);
}
1182 
tls_client_ctx_set_session(struct tls_client_ctx * ctx,struct session * session)1183 int tls_client_ctx_set_session(struct tls_client_ctx *ctx, struct session *session)
1184 {
1185 	if (!ctx) {
1186 		return kr_error(EINVAL);
1187 	}
1188 	ctx->c.session = session;
1189 	return kr_ok();
1190 }
1191 
1192 #undef DEBUG_MSG
1193 #undef VERBOSE_MSG
1194