1 #include "libfilezilla/tls_layer.hpp"
2 #include "tls_layer_impl.hpp"
3 #include "libfilezilla/tls_info.hpp"
4 #include "tls_system_trust_store_impl.hpp"
5 
6 #include "libfilezilla/file.hpp"
7 #include "libfilezilla/iputils.hpp"
8 #include "libfilezilla/translate.hpp"
9 #include "libfilezilla/util.hpp"
10 
11 #include <gnutls/x509.h>
12 
13 #include <algorithm>
14 #include <set>
15 
16 #include <string.h>
17 
18 using namespace std::literals;
19 
20 #if DEBUG_SOCKETEVENTS
21 #include <assert.h>
22 
23 namespace fz {
24 bool FZ_PRIVATE_SYMBOL has_pending_event(event_handler * handler, socket_event_source const* const source, socket_event_flag event);
25 }
26 #endif
27 
28 static_assert(GNUTLS_VERSION_NUMBER != 0x030604, "Using TLS 1.3 with this version of GnuTLS does not work, update your version of GnuTLS");
29 
30 namespace fz {
31 
32 namespace {
33 
// GnuTLS priority string used as the base for every session (see init_session()).
// All variants strip RC4, 3DES, MD5 and RSA-MD5 signatures and disable SSL 3.0.
#if FZ_USE_GNUTLS_SYSTEM_CIPHERS
// Start from the system-wide crypto policy (@SYSTEM) and remove weak primitives on top of it.
char const ciphers[] = "@SYSTEM:-ARCFOUR-128:-3DES-CBC:-MD5:-SIGN-RSA-MD5:-VERS-SSL3.0";
#else
	#if GNUTLS_VERSION_NUMBER >= 0x030600
		// Same as the pre-3.6 string minus -CTYPE-OPENPGP; presumably because OpenPGP
		// certificate types no longer exist in GnuTLS 3.6+ (TODO confirm).
		char const ciphers[] = "SECURE256:+SECURE128:-ARCFOUR-128:-3DES-CBC:-MD5:+SIGN-ALL:-SIGN-RSA-MD5:+CTYPE-X509:-VERS-SSL3.0";
	#else
		char const ciphers[] = "SECURE256:+SECURE128:-ARCFOUR-128:-3DES-CBC:-MD5:+SIGN-ALL:-SIGN-RSA-MD5:+CTYPE-X509:-CTYPE-OPENPGP:-VERS-SSL3.0";
	#endif
#endif
43 
// Compile-time switch: set to 1 to forward GnuTLS's own debug log to the
// logger and to enable the push/pull tracing in this file.
#define TLSDEBUG 0
#if TLSDEBUG
// This is quite ugly
// Process-wide sink for the GnuTLS log callback. init() sets it if unset and
// deinit() clears it again if it still points at that instance's logger.
logger_interface* pLogging;

// Callback for gnutls_global_set_log_function: trims the GnuTLS message and
// logs it together with its GnuTLS log level. Ignored if no sink is set.
extern "C" void log_func(int level, char const* msg)
{
	if (!msg || !pLogging) {
		return;
	}
	std::wstring s = to_wstring(msg);
	trim(s);
	pLogging->log(logmsg::debug_debug, L"tls: %d %s", level, s);
}
#endif
58 
remove_verification_events(event_handler * handler,tls_layer const * const source)59 void remove_verification_events(event_handler* handler, tls_layer const* const source)
60 {
61 	if (!handler) {
62 		return;
63 	}
64 
65 	auto event_filter = [&](event_loop::Events::value_type const& ev) -> bool {
66 		if (ev.first != handler) {
67 			return false;
68 		}
69 		else if (ev.second->derived_type() == certificate_verification_event::type()) {
70 			return std::get<0>(static_cast<certificate_verification_event const&>(*ev.second).v_) == source;
71 		}
72 		return false;
73 	};
74 
75 	handler->event_loop_.filter_events(event_filter);
76 }
77 
c_push_function(gnutls_transport_ptr_t ptr,const void * data,size_t len)78 extern "C" ssize_t c_push_function(gnutls_transport_ptr_t ptr, const void* data, size_t len)
79 {
80 	return ((tls_layer_impl*)ptr)->push_function(data, len);
81 }
82 
c_pull_function(gnutls_transport_ptr_t ptr,void * data,size_t len)83 extern "C" ssize_t c_pull_function(gnutls_transport_ptr_t ptr, void* data, size_t len)
84 {
85 	return ((tls_layer_impl*)ptr)->pull_function(data, len);
86 }
87 }
88 
89 class tls_layerCallbacks
90 {
91 public:
handshake_hook_func(gnutls_session_t session,unsigned int htype,unsigned int post,unsigned int incoming)92 	static int handshake_hook_func(gnutls_session_t session, unsigned int htype, unsigned int post, unsigned int incoming)
93 	{
94 		if (!session) {
95 			return 0;
96 		}
97 		auto* tls = reinterpret_cast<tls_layer_impl*>(gnutls_session_get_ptr(session));
98 		if (!tls) {
99 			return 0;
100 		}
101 
102 		char const* prefix;
103 		if (incoming) {
104 			if (post) {
105 				prefix = "Processed";
106 			}
107 			else {
108 				prefix = "Received";
109 			}
110 		}
111 		else {
112 			if (post) {
113 				prefix = "Sent";
114 			}
115 			else {
116 				prefix = "About to send";
117 			}
118 		}
119 
120 		char const* name = gnutls_handshake_description_get_name(static_cast<gnutls_handshake_description_t>(htype));
121 
122 		tls->logger_.log(logmsg::debug_debug, L"TLS handshake: %s %s", prefix, name);
123 
124 		return 0;
125 	}
126 
store_session(void * ptr,gnutls_datum_t const & key,gnutls_datum_t const & data)127 	static int store_session(void* ptr, gnutls_datum_t const& key, gnutls_datum_t const& data)
128 	{
129 		auto* tls = reinterpret_cast<tls_layer_impl*>(ptr);
130 		if (!tls) {
131 			return 0;
132 		}
133 		if (!key.size || !data.size) {
134 			return 0;
135 		}
136 		tls->session_db_key_.resize(key.size);
137 		memcpy(tls->session_db_key_.data(), key.data, key.size);
138 		tls->session_db_data_.resize(data.size);
139 		memcpy(tls->session_db_data_.data(), data.data, data.size);
140 
141 		return 0;
142 	}
143 
retrieve_session(void * ptr,gnutls_datum_t key)144 	static gnutls_datum_t retrieve_session(void *ptr, gnutls_datum_t key)
145 	{
146 		auto* tls = reinterpret_cast<tls_layer_impl*>(ptr);
147 		if (!tls) {
148 			return {};
149 		}
150 		if (!key.size) {
151 			return {};
152 		}
153 
154 		if (key.size == tls->session_db_key_.size() && !memcmp(tls->session_db_key_.data(), key.data, key.size)) {
155 			gnutls_datum_t d{};
156 			d.data = reinterpret_cast<unsigned char*>(gnutls_malloc(tls->session_db_data_.size()));
157 			if (d.data) {
158 				d.size = tls->session_db_data_.size();
159 				memcpy(d.data, tls->session_db_data_.data(), d.size);
160 			}
161 			return d;
162 		}
163 
164 		return gnutls_datum_t{};
165 	}
166 };
167 
168 namespace {
handshake_hook_func(gnutls_session_t session,unsigned int htype,unsigned int post,unsigned int incoming,gnutls_datum_t const *)169 extern "C" int handshake_hook_func(gnutls_session_t session, unsigned int htype, unsigned int post, unsigned int incoming, gnutls_datum_t const*)
170 {
171 	return tls_layerCallbacks::handshake_hook_func(session, htype, post, incoming);
172 }
173 
db_store_func(void * ptr,gnutls_datum_t key,gnutls_datum_t data)174 extern "C" int db_store_func(void *ptr, gnutls_datum_t key, gnutls_datum_t data)
175 {
176 	return tls_layerCallbacks::store_session(ptr, key, data);
177 }
178 
db_retr_func(void * ptr,gnutls_datum_t key)179 extern "C" gnutls_datum_t db_retr_func(void *ptr, gnutls_datum_t key)
180 {
181 	return tls_layerCallbacks::retrieve_session(ptr, key);
182 }
183 
to_string(gnutls_datum_t const & d)184 std::string to_string(gnutls_datum_t const& d)
185 {
186 	if (d.data && d.size) {
187 		return std::string(d.data, d.data + d.size);
188 	}
189 	return {};
190 }
191 
to_view(gnutls_datum_t const & d)192 std::string_view to_view(gnutls_datum_t const& d)
193 {
194 	if (d.data && d.size) {
195 		return std::string_view(reinterpret_cast<char const*>(d.data), d.size);
196 	}
197 	return {};
198 }
199 
200 struct datum_holder final : gnutls_datum_t
201 {
datum_holderfz::__anon85fcda0c0311::datum_holder202 	datum_holder() {
203 		data = nullptr;
204 		size = 0;
205 	}
206 
~datum_holderfz::__anon85fcda0c0311::datum_holder207 	~datum_holder() {
208 		gnutls_free(data);
209 	}
210 
clearfz::__anon85fcda0c0311::datum_holder211 	void clear()
212 	{
213 		gnutls_free(data);
214 		data = nullptr;
215 		size = 0;
216 	}
217 
218 	datum_holder(datum_holder const&) = delete;
219 	datum_holder& operator=(datum_holder const&) = delete;
220 
to_stringfz::__anon85fcda0c0311::datum_holder221 	std::string to_string() const {
222 		return data ? std::string(data, data + size) : std::string();
223 	}
224 
to_string_viewfz::__anon85fcda0c0311::datum_holder225 	std::string_view to_string_view() const {
226 		return data ? std::string_view(reinterpret_cast<char *>(data), size) : std::string_view();
227 	}
228 };
229 
clone_cert(gnutls_x509_crt_t in,gnutls_x509_crt_t & out)230 void clone_cert(gnutls_x509_crt_t in, gnutls_x509_crt_t &out)
231 {
232 	gnutls_x509_crt_deinit(out);
233 	out = nullptr;
234 
235 	if (in) {
236 		datum_holder der;
237 		if (gnutls_x509_crt_export2(in, GNUTLS_X509_FMT_DER, &der) == GNUTLS_E_SUCCESS) {
238 			gnutls_x509_crt_init(&out);
239 			if (gnutls_x509_crt_import(out, &der, GNUTLS_X509_FMT_DER) != GNUTLS_E_SUCCESS) {
240 				gnutls_x509_crt_deinit(out);
241 				out = nullptr;
242 			}
243 		}
244 	}
245 }
246 }
247 
tls_layer_impl(tls_layer & layer,tls_system_trust_store * systemTrustStore,logger_interface & logger)248 tls_layer_impl::tls_layer_impl(tls_layer& layer, tls_system_trust_store* systemTrustStore, logger_interface & logger)
249 	: tls_layer_(layer)
250 	, logger_(logger)
251 	, system_trust_store_(systemTrustStore)
252 {
253 }
254 
// Releases everything this instance holds (session, credentials, GnuTLS
// global reference) via deinit().
tls_layer_impl::~tls_layer_impl()
{
	deinit();
}
259 
init()260 bool tls_layer_impl::init()
261 {
262 	// This function initializes GnuTLS
263 	if (!initialized_) {
264 		initialized_ = true;
265 		int res = gnutls_global_init();
266 		if (res) {
267 			log_error(res, L"gnutls_global_init");
268 			deinit();
269 			return false;
270 		}
271 
272 #if TLSDEBUG
273 		if (!pLogging) {
274 			pLogging = &logger_;
275 			gnutls_global_set_log_function(log_func);
276 			gnutls_global_set_log_level(99);
277 		}
278 #endif
279 	}
280 
281 	if (!cert_credentials_) {
282 		int res = gnutls_certificate_allocate_credentials(&cert_credentials_);
283 		if (res < 0) {
284 			log_error(res, L"gnutls_certificate_allocate_credentials");
285 			deinit();
286 			return false;
287 		}
288 	}
289 
290 	return true;
291 }
292 
read_certificates_file(native_string const & certsfile,logger_interface * logger)293 std::string read_certificates_file(native_string const& certsfile, logger_interface * logger)
294 {
295 	file cf(certsfile, file::reading, file::existing);
296 	if (!cf.opened()) {
297 		if (logger) {
298 			logger->log(logmsg::error, fztranslate("Could not open certificate file."));
299 		}
300 		return {};
301 	}
302 	int64_t const cs = cf.size();
303 	if (cs < 0 || cs > 1024 * 1024) {
304 		if (logger) {
305 			logger->log(logmsg::error, fztranslate("Certificate file too big."));
306 		}
307 		return {};
308 	}
309 	std::string c;
310 	c.resize(cs);
311 	auto read = cf.read(c.data(), cs);
312 	if (read != cs) {
313 		if (logger) {
314 			logger->log(logmsg::error, fztranslate("Could not read certificate file."));
315 		}
316 		return {};
317 	}
318 	return c;
319 }
320 
set_certificate_file(native_string const & keyfile,native_string const & certsfile,native_string const & password,bool pem)321 bool tls_layer_impl::set_certificate_file(native_string const& keyfile, native_string const& certsfile, native_string const& password, bool pem)
322 {
323 	// Load the files ourselves instead of calling gnutls_certificate_set_x509_key_file2
324 	// as it takes narrow strings on MSW, thus being unable to open all files.
325 
326 	file kf(keyfile, file::reading, file::existing);
327 	if (!kf.opened()) {
328 		logger_.log(logmsg::error, fztranslate("Could not open key file."));
329 		return false;
330 	}
331 	int64_t const ks = kf.size();
332 	if (ks < 0 || ks > 1024 * 1024) {
333 		logger_.log(logmsg::error, fztranslate("Key file too big."));
334 		return false;
335 	}
336 	std::string k;
337 	k.resize(ks);
338 	auto read = kf.read(k.data(), ks);
339 	if (read != ks) {
340 		logger_.log(logmsg::error, fztranslate("Could not read key file."));
341 		return false;
342 	}
343 
344 	std::string c = read_certificates_file(certsfile, &logger_);
345 	if (c.empty()) {
346 		return false;
347 	}
348 
349 	return set_certificate(k, c, password, pem);
350 }
351 
set_certificate(std::string_view const & key,std::string_view const & certs,native_string const & password,bool pem)352 bool tls_layer_impl::set_certificate(std::string_view const& key, std::string_view const& certs, native_string const& password, bool pem)
353 {
354 	if (!init()) {
355 		return false;
356 	}
357 
358 	if (!cert_credentials_) {
359 		return false;
360 	}
361 
362 	gnutls_datum_t c;
363 	c.data = const_cast<unsigned char*>(reinterpret_cast<unsigned char const*>(certs.data()));
364 	c.size = certs.size();
365 
366 	gnutls_datum_t k;
367 	k.data = const_cast<unsigned char*>(reinterpret_cast<unsigned char const*>(key.data()));
368 	k.size = key.size();
369 
370 	int res = gnutls_certificate_set_x509_key_mem2(cert_credentials_, &c,
371 		&k, pem ? GNUTLS_X509_FMT_PEM : GNUTLS_X509_FMT_DER, password.empty() ? nullptr : to_utf8(password).c_str(), 0);
372 	if (res < 0) {
373 		log_error(res, L"gnutls_certificate_set_x509_key_mem2");
374 		deinit();
375 		return false;
376 	}
377 
378 	return true;
379 }
380 // Convert them all to PEM
381 
init_session(bool client)382 bool tls_layer_impl::init_session(bool client)
383 {
384 	if (!cert_credentials_) {
385 		deinit();
386 		return false;
387 	}
388 
389 	int res = gnutls_init(&session_, client ? GNUTLS_CLIENT : GNUTLS_SERVER);
390 	if (res) {
391 		log_error(res, L"gnutls_init");
392 		deinit();
393 		return false;
394 	}
395 
396 	if (!client) {
397 		if (ticket_key_.empty()) {
398 			datum_holder h;
399 			res = gnutls_session_ticket_key_generate(&h);
400 			if (res) {
401 				log_error(res, L"gnutls_session_ticket_key_generate");
402 				deinit();
403 				return false;
404 			}
405 			ticket_key_.assign(h.data, h.data + h.size);
406 		}
407 
408 		gnutls_datum_t k;
409 		k.data = ticket_key_.data();
410 		k.size = ticket_key_.size();
411 		res = gnutls_session_ticket_enable_server(session_, &k);
412 		if (res) {
413 			log_error(res, L"gnutls_session_ticket_enable_server");
414 			deinit();
415 			return false;
416 		}
417 	}
418 
419 	// For use in callbacks
420 	gnutls_session_set_ptr(session_, this);
421 	gnutls_db_set_ptr(session_, this);
422 
423 	// Even though the name gnutls_db_set_cache_expiration
424 	// implies expiration of some cache, it also governs
425 	// the actual session lifetime, independend whether the
426 	// session is cached or not.
427 	gnutls_db_set_cache_expiration(session_, 100000000);
428 
429 	if (!client) {
430 		gnutls_db_set_ptr(session_, this);
431 		gnutls_db_set_store_function(session_, &db_store_func);
432 		gnutls_db_set_retrieve_function(session_, &db_retr_func);
433 	}
434 
435 	std::string prio = ciphers;
436 	switch (min_tls_ver_) {
437 	case tls_ver::v1_3:
438 		prio += ":-VERS-TLS1.2";
439 		// Fallthrough
440 	case tls_ver::v1_2:
441 		prio += ":-VERS-TLS1.1";
442 		// Fallthrough
443 	case tls_ver::v1_1:
444 		prio += ":-VERS-TLS1.0";
445 		break;
446 	default:
447 		break;
448 	}
449 
450 	if (max_tls_ver_) {
451 		switch (*max_tls_ver_) {
452 		case tls_ver::v1_0:
453 			prio += ":-VERS-TLS1.1";
454 			// Fallthrough
455 		case tls_ver::v1_1:
456 			prio += ":-VERS-TLS1.2";
457 			// Fallthrough
458 		case tls_ver::v1_2:
459 #if GNUTLS_VERSION_NUMBER >= 0x030603
460 			prio += ":-VERS-TLS1.3";
461 #endif
462 			break;
463 		default:
464 			break;
465 		}
466 	}
467 
468 	res = gnutls_priority_set_direct(session_, prio.c_str(), nullptr);
469 	if (res) {
470 		log_error(res, L"gnutls_priority_set_direct");
471 		deinit();
472 		return false;
473 	}
474 
475 	gnutls_dh_set_prime_bits(session_, 1024);
476 
477 	gnutls_credentials_set(session_, GNUTLS_CRD_CERTIFICATE, cert_credentials_);
478 
479 	// Setup transport functions
480 	gnutls_transport_set_push_function(session_, c_push_function);
481 	gnutls_transport_set_pull_function(session_, c_pull_function);
482 	gnutls_transport_set_ptr(session_, (gnutls_transport_ptr_t)this);
483 
484 	if (!do_set_alpn()) {
485 		deinit();
486 		return false;
487 	}
488 
489 	return true;
490 }
491 
// Tears down all GnuTLS state owned by this instance, in order:
//  1. the session,
//  2. the certificate credentials,
//  3. the GnuTLS global reference taken by init().
// It then:
//  - forgets the server ticket key,
//  - marks the layer as failed,
//  - detaches the TLSDEBUG log sink if it is this instance's logger,
//  - removes queued certificate verification events for the verification
//    handler before forgetting that handler.
// Every step is guarded, so calling this repeatedly is harmless.
void tls_layer_impl::deinit()
{
	deinit_session();

	if (cert_credentials_) {
		gnutls_certificate_free_credentials(cert_credentials_);
		cert_credentials_ = nullptr;
	}

	if (initialized_) {
		initialized_ = false;
		gnutls_global_deinit();
	}

	ticket_key_.clear();

	state_ = socket_state::failed;

#if TLSDEBUG
	if (pLogging == &logger_) {
		pLogging = nullptr;
	}
#endif

	// Make sure no verification request for this layer is delivered later.
	remove_verification_events(verification_handler_, &tls_layer_);
	verification_handler_ = nullptr;
}
519 
520 
deinit_session()521 void tls_layer_impl::deinit_session()
522 {
523 	if (session_) {
524 		gnutls_deinit(session_);
525 		session_ = nullptr;
526 	}
527 }
528 
529 
log_error(int code,std::wstring const & function,logmsg::type logLevel)530 void tls_layer_impl::log_error(int code, std::wstring const& function, logmsg::type logLevel)
531 {
532 	if (logLevel < logmsg::debug_warning && state_ >= socket_state::shut_down && shutdown_silence_read_errors_) {
533 		logLevel = logmsg::debug_warning;
534 	}
535 
536 	if (code == GNUTLS_E_WARNING_ALERT_RECEIVED || code == GNUTLS_E_FATAL_ALERT_RECEIVED) {
537 		log_alert(logLevel);
538 	}
539 	else if (code == GNUTLS_E_PULL_ERROR) {
540 		if (function.empty()) {
541 			logger_.log(logmsg::debug_warning, L"GnuTLS could not read from socket: %s", socket_error_description(socket_error_));
542 		}
543 		else {
544 			logger_.log(logmsg::debug_warning, L"GnuTLS could not read from socket in %s: %s", function, socket_error_description(socket_error_));
545 		}
546 	}
547 	else if (code == GNUTLS_E_PUSH_ERROR) {
548 		if (function.empty()) {
549 			logger_.log(logmsg::debug_warning, L"GnuTLS could not write to socket: %s", socket_error_description(socket_error_));
550 		}
551 		else {
552 			logger_.log(logmsg::debug_warning, L"GnuTLS could not write to socket in %s: %s", function, socket_error_description(socket_error_));
553 		}
554 	}
555 	else {
556 		char const* error = gnutls_strerror(code);
557 		if (error) {
558 			if (function.empty()) {
559 				logger_.log(logLevel, fztranslate("GnuTLS error %d: %s"), code, error);
560 			}
561 			else {
562 				logger_.log(logLevel, fztranslate("GnuTLS error %d in %s: %s"), code, function, error);
563 			}
564 		}
565 		else {
566 			if (function.empty()) {
567 				logger_.log(logLevel, fztranslate("GnuTLS error %d"), code);
568 			}
569 			else {
570 				logger_.log(logLevel, fztranslate("GnuTLS error %d in %s"), code, function);
571 			}
572 		}
573 	}
574 }
575 
log_alert(logmsg::type logLevel)576 void tls_layer_impl::log_alert(logmsg::type logLevel)
577 {
578 	gnutls_alert_description_t last_alert = gnutls_alert_get(session_);
579 	char const* alert = gnutls_alert_get_name(last_alert);
580 	if (alert) {
581 		logger_.log(logLevel,
582 					server_ ? fztranslate("Received TLS alert from the client: %s (%d)") : fztranslate("Received TLS alert from the server: %s (%d)"),
583 					alert, last_alert);
584 	}
585 	else {
586 		logger_.log(logLevel,
587 					server_ ? fztranslate("Received unknown TLS alert %d from the client") : fztranslate("Received unknown TLS alert %d from the server"),
588 					last_alert);
589 	}
590 }
591 
push_function(void const * data,size_t len)592 ssize_t tls_layer_impl::push_function(void const* data, size_t len)
593 {
594 #if TLSDEBUG
595 	logger_.log(logmsg::debug_debug, L"tls_layer_impl::push_function(%d)", len);
596 #endif
597 	if (!can_write_to_socket_) {
598 		gnutls_transport_set_errno(session_, EAGAIN);
599 		return -1;
600 	}
601 
602 	int error;
603 	int written = tls_layer_.next_layer_.write(data, static_cast<unsigned int>(len), error);
604 
605 	if (written < 0) {
606 		can_write_to_socket_ = false;
607 		if (error != EAGAIN) {
608 			socket_error_ = error;
609 		}
610 		gnutls_transport_set_errno(session_, error);
611 #if TLSDEBUG
612 		logger_.log(logmsg::debug_debug, L"  returning -1 due to %d", error);
613 #endif
614 		return -1;
615 	}
616 
617 #if TLSDEBUG
618 	logger_.log(logmsg::debug_debug, L"  returning %d", written);
619 #endif
620 
621 	return written;
622 }
623 
pull_function(void * data,size_t len)624 ssize_t tls_layer_impl::pull_function(void* data, size_t len)
625 {
626 #if TLSDEBUG
627 	logger_.log(logmsg::debug_debug, L"tls_layer_impl::pull_function(%d)",  (int)len);
628 #endif
629 
630 	if (!can_read_from_socket_) {
631 		gnutls_transport_set_errno(session_, EAGAIN);
632 		return -1;
633 	}
634 
635 	int error;
636 	int read = tls_layer_.next_layer_.read(data, static_cast<unsigned int>(len), error);
637 	if (read < 0) {
638 		if (error != EAGAIN) {
639 			socket_error_ = error;
640 		}
641 		else {
642 			can_read_from_socket_ = false;
643 		}
644 		gnutls_transport_set_errno(session_, error);
645 #if TLSDEBUG
646 		logger_.log(logmsg::debug_debug, L"  returning -1 due to %d", error);
647 #endif
648 		return -1;
649 	}
650 
651 	if (!read) {
652 		socket_eof_ = true;
653 	}
654 
655 #if TLSDEBUG
656 	logger_.log(logmsg::debug_debug, L"  returning %d", read);
657 #endif
658 
659 	return read;
660 }
661 
operator ()(event_base const & ev)662 void tls_layer_impl::operator()(event_base const& ev)
663 {
664 	dispatch<socket_event, hostaddress_event>(ev, this
665 		, &tls_layer_impl::on_socket_event
666 		, &tls_layer_impl::forward_hostaddress_event);
667 }
668 
// Relays a hostaddress_event received by this handler to the owning tls_layer
// without modification.
void tls_layer_impl::forward_hostaddress_event(socket_event_source* source, std::string const& address)
{
	tls_layer_.forward_hostaddress_event(source, address);
}
673 
on_socket_event(socket_event_source * s,socket_event_flag t,int error)674 void tls_layer_impl::on_socket_event(socket_event_source* s, socket_event_flag t, int error)
675 {
676 	if (!session_) {
677 		return;
678 	}
679 
680 	if (t == socket_event_flag::connection_next) {
681 		tls_layer_.forward_socket_event(s, t, error);
682 		return;
683 	}
684 
685 	if (error) {
686 		socket_error_ = error;
687 		deinit();
688 		tls_layer_.forward_socket_event(s, t, error);
689 		return;
690 	}
691 
692 	switch (t)
693 	{
694 	case socket_event_flag::read:
695 		on_read();
696 		break;
697 	case socket_event_flag::write:
698 		on_send();
699 		break;
700 	case socket_event_flag::connection:
701 		if (hostname_.empty()) {
702 			set_hostname(tls_layer_.next_layer_.peer_host());
703 		}
704 		on_send();
705 		break;
706 	default:
707 		break;
708 	}
709 }
710 
on_read()711 void tls_layer_impl::on_read()
712 {
713 	logger_.log(logmsg::debug_debug, L"tls_layer_impl::on_read()");
714 
715 #if DEBUG_SOCKETEVENTS
716 	assert(!can_read_from_socket_);
717 #endif
718 	can_read_from_socket_ = true;
719 
720 	if (!session_) {
721 		return;
722 	}
723 
724 	if (state_ == socket_state::connecting) {
725 		continue_handshake();
726 	}
727 	else if (state_ == socket_state::connected || state_ == socket_state::shutting_down || state_ == socket_state::shut_down) {
728 #if DEBUG_SOCKETEVENTS
729 		assert(!debug_can_read_);
730 		debug_can_read_ = true;
731 #endif
732 		if (tls_layer_.event_handler_) {
733 			tls_layer_.event_handler_->send_event<socket_event>(&tls_layer_, socket_event_flag::read, 0);
734 		}
735 	}
736 }
737 
on_send()738 void tls_layer_impl::on_send()
739 {
740 	logger_.log(logmsg::debug_debug, L"tls_layer_impl::on_send()");
741 
742 	can_write_to_socket_ = true;
743 
744 	if (!session_) {
745 		return;
746 	}
747 
748 	if (state_ == socket_state::connecting) {
749 		continue_handshake();
750 	}
751 	else if (state_ == socket_state::shutting_down) {
752 		int res = continue_write();
753 		if (res) {
754 			return;
755 		}
756 
757 		res = continue_shutdown();
758 		if (res != EAGAIN) {
759 			if (tls_layer_.event_handler_) {
760 				tls_layer_.event_handler_->send_event<socket_event>(&tls_layer_, socket_event_flag::write, res);
761 			}
762 		}
763 	}
764 	else if (state_ == socket_state::connected) {
765 		continue_write();
766 	}
767 }
768 
// Flushes the buffered plaintext in send_buffer_ through gnutls_record_send
// until it is empty or the socket would block.
// Returns:
//  - 0 when nothing was pending or everything was sent;
//  - EAGAIN when the socket is not writable right now;
//  - ECONNABORTED after calling failure() on a send error.
// If write_blocked_by_send_buffer_ was set (presumably by write() refusing data
// while the buffer was full; TODO confirm), it is cleared once the buffer
// drains, and a write event is sent to the owner's handler if still connected.
int tls_layer_impl::continue_write()
{
	if (send_buffer_.empty()) {
		return 0;
	}

	do {
		// Retry interrupted/would-block sends as long as the socket still accepts data.
		ssize_t res = GNUTLS_E_AGAIN;
		while ((res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN) && can_write_to_socket_) {
			res = gnutls_record_send(session_, send_buffer_.get(), send_buffer_.size());
		}

		if (res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN) {
			return EAGAIN;
		}

		if (res < 0) {
			failure(static_cast<int>(res), true);
			return ECONNABORTED;
		}

		// Sanity check: GnuTLS must never report more bytes than were handed in.
		if (static_cast<size_t>(res) > send_buffer_.size()) {
			logger_.log(logmsg::error, L"gnutls_record_send has processed more data than it has buffered");
			failure(0, true);
			return ECONNABORTED;
		}

		send_buffer_.consume(static_cast<size_t>(res));
	}
	while (!send_buffer_.empty());

	if (write_blocked_by_send_buffer_) {
		write_blocked_by_send_buffer_ = false;

		if (state_ == socket_state::connected) {
#if DEBUG_SOCKETEVENTS
			assert(!debug_can_write_);
			debug_can_write_ = true;
#endif
			if (tls_layer_.event_handler_) {
				tls_layer_.event_handler_->send_event<socket_event>(&tls_layer_, socket_event_flag::write, 0);
			}
		}
	}

	return 0;
}
816 
resumed_session() const817 bool tls_layer_impl::resumed_session() const
818 {
819 	return gnutls_session_is_resumed(session_) != 0;
820 }
821 
client_handshake(std::vector<uint8_t> const & session_to_resume,native_string const & session_hostname,std::vector<uint8_t> const & required_certificate,event_handler * const verification_handler)822 bool tls_layer_impl::client_handshake(std::vector<uint8_t> const& session_to_resume, native_string const& session_hostname, std::vector<uint8_t> const& required_certificate, event_handler *const verification_handler)
823 {
824 	logger_.log(logmsg::debug_verbose, L"tls_layer_impl::client_handshake()");
825 
826 	if (state_ != socket_state::none) {
827 		logger_.log(logmsg::debug_warning, L"Called tls_layer_impl::client_handshake on a socket that isn't idle");
828 		return false;
829 	}
830 
831 	if (!init() || !init_session(true)) {
832 		return false;
833 	}
834 
835 	state_ = socket_state::connecting;
836 
837 	if (!required_certificate.empty()) {
838 		std::string_view v(reinterpret_cast<char const*>(required_certificate.data()), required_certificate.size());
839 		size_t i = v.find_first_not_of("-");
840 		size_t p = v.find("BEGIN ");
841 		if (i != std::string_view::npos && i >= 4 && i == p) {
842 			// It's PEM
843 			gnutls_datum_t in;
844 			in.data = const_cast<unsigned char*>(reinterpret_cast<unsigned char const*>(required_certificate.data()));
845 			in.size = required_certificate.size();
846 
847 			datum_holder der;
848 			gnutls_pem_base64_decode2(nullptr, &in, &der);
849 
850 			required_certificate_.assign(der.data, der.data + der.size);
851 		}
852 		else {
853 			// Must be DER
854 			required_certificate_ = required_certificate;
855 		}
856 	}
857 
858 	verification_handler_ = verification_handler;
859 
860 	if (!session_to_resume.empty()) {
861 		int res = gnutls_session_set_data(session_, session_to_resume.data(), session_to_resume.size());
862 		if (res) {
863 			logger_.log(logmsg::debug_info, L"gnutls_session_set_data failed: %d. Going to reinitialize session.", res);
864 			deinit_session();
865 			if (!init_session(true)) {
866 				return false;
867 			}
868 		}
869 		else {
870 			logger_.log(logmsg::debug_info, L"Trying to resume existing TLS session.");
871 		}
872 	}
873 
874 	if (logger_.should_log(logmsg::debug_debug)) {
875 		gnutls_handshake_set_hook_function(session_, GNUTLS_HANDSHAKE_ANY, GNUTLS_HOOK_BOTH, &handshake_hook_func);
876 	}
877 
878 	if (!session_hostname.empty()) {
879 		set_hostname(session_hostname);
880 	}
881 	else if (!hostname_.empty()) {
882 		set_hostname(hostname_);
883 	}
884 
885 	if (tls_layer_.next_layer_.get_state() == socket_state::none || tls_layer_.next_layer_.get_state() == socket_state::connecting) {
886 		// Wait until the socket gets connected
887 		return true;
888 	}
889 	else if (tls_layer_.next_layer_.get_state() != socket_state::connected) {
890 		// We're too late
891 		return false;
892 	}
893 
894 	if (hostname_.empty()) {
895 		set_hostname(tls_layer_.next_layer_.peer_host());
896 	}
897 	return continue_handshake() == EAGAIN;
898 }
899 
900 namespace {
// Reads one length-prefixed field from the buffer [p, end). The field is a
// native-endian size_t length followed by that many payload bytes, which are
// copied into `target`. On success `p` is advanced past the field and true is
// returned. Returns false if the buffer is too short for the length prefix or
// for the payload; in the latter case `p` has already moved past the prefix.
bool extract_with_size(uint8_t const* &p, uint8_t const* const end, std::vector<uint8_t>& target)
{
	size_t const remaining = static_cast<size_t>(end - p);
	if (remaining < sizeof(size_t)) {
		return false;
	}

	size_t len;
	memcpy(&len, p, sizeof(len));
	p += sizeof(len);

	if (remaining - sizeof(len) < len) {
		return false;
	}

	target.assign(p, p + len);
	p += len;
	return true;
}
919 }
920 
server_handshake(std::vector<uint8_t> const & session_to_resume,std::string_view const & preamble)921 bool tls_layer_impl::server_handshake(std::vector<uint8_t> const& session_to_resume, std::string_view const& preamble)
922 {
923 	logger_.log(logmsg::debug_verbose, L"tls_layer_impl::server_handshake()");
924 
925 	if (state_ != socket_state::none) {
926 		logger_.log(logmsg::debug_warning, L"Called tls_layer_impl::server_handshake on a socket that isn't idle");
927 		return false;
928 	}
929 
930 	server_ = true;
931 
932 	if (!session_to_resume.empty()) {
933 		auto const* p = session_to_resume.data();
934 		auto const* const end = p + session_to_resume.size();
935 		if (!extract_with_size(p, end, ticket_key_)) {
936 			return false;
937 		}
938 		if (!extract_with_size(p, end, session_db_key_)) {
939 			return false;
940 		}
941 		if (!extract_with_size(p, end, session_db_data_)) {
942 			return false;
943 		}
944 	}
945 
946 	if (!init() || !init_session(false)) {
947 		return false;
948 	}
949 
950 	state_ = socket_state::connecting;
951 
952 	if (logger_.should_log(logmsg::debug_debug)) {
953 		gnutls_handshake_set_hook_function(session_, GNUTLS_HANDSHAKE_ANY, GNUTLS_HOOK_BOTH, &handshake_hook_func);
954 	}
955 
956 	if (tls_layer_.next_layer_.get_state() == socket_state::none || tls_layer_.next_layer_.get_state() == socket_state::connecting) {
957 		// Wait until the socket gets connected
958 		return true;
959 	}
960 	else if (tls_layer_.next_layer_.get_state() != socket_state::connected) {
961 		// We're too late
962 		return false;
963 	}
964 
965 	preamble_.append(preamble);
966 
967 	return continue_handshake() == EAGAIN;
968 }
969 
// Advances the handshake as far as current socket readiness allows.
// It first writes any pending preamble_ bytes directly to the next layer. It
// then calls gnutls_handshake() repeatedly, for as long as the direction
// GnuTLS is waiting on (read or write) is currently available.
// Returns:
//  - ENOTCONN if there is no session or no handshake is in progress;
//  - EAGAIN if more socket readiness is needed;
//  - on success, 0 for a server (the layer becomes connected and connection
//    and, if readable, read events are sent), or the result of
//    verify_certificate() for a client;
//  - otherwise, after failure() was called, the socket error or ECONNABORTED
//    (or the raw write error for a failed preamble write).
int tls_layer_impl::continue_handshake()
{
	logger_.log(logmsg::debug_verbose, L"tls_layer_impl::continue_handshake()");
	if (!session_ || state_ != socket_state::connecting) {
		return ENOTCONN;
	}

	// The preamble goes out unencrypted, before any TLS record.
	while (!preamble_.empty()) {
		if (!can_write_to_socket_) {
			return EAGAIN;
		}

		int error{};
		int written = tls_layer_.next_layer_.write(preamble_.get(), static_cast<int>(preamble_.size()), error);
		if (written < 0) {
			can_write_to_socket_ = false;
			if (error != EAGAIN) {
				socket_error_ = error;
				failure(0, true);
			}
			return error;
		}
		preamble_.consume(static_cast<size_t>(written));
	}

	int res = gnutls_handshake(session_);
	while (res == GNUTLS_E_AGAIN || res == GNUTLS_E_INTERRUPTED) {
		// Only retry if the socket is ready in the direction GnuTLS is blocked on.
		if (!(gnutls_record_get_direction(session_) ? can_write_to_socket_ : can_read_from_socket_)) {
			break;
		}
		res = gnutls_handshake(session_);
	}
	if (!res) {
		logger_.log(logmsg::debug_info, L"TLS Handshake successful");
		handshake_successful_ = true;

		if (resumed_session()) {
			logger_.log(logmsg::debug_info, L"TLS Session resumed");
		}

		std::string const protocol = get_protocol();
		std::string const keyExchange = get_key_exchange();
		std::string const cipherName = get_cipher();
		std::string const macName = get_mac();

		logger_.log(logmsg::debug_info, L"Protocol: %s, Key exchange: %s, Cipher: %s, MAC: %s", protocol, keyExchange, cipherName, macName);

		if (is_client()) {
			// The client only becomes connected once the peer certificate is verified.
			return verify_certificate();
		}
		else {
			state_ = socket_state::connected;

#if DEBUG_SOCKETEVENTS
			if (can_read_from_socket_) {
				assert(!debug_can_read_);
				debug_can_read_ = true;
			}
			assert(!debug_can_write_);
			debug_can_write_ = true;
#endif
			if (tls_layer_.event_handler_) {
				tls_layer_.event_handler_->send_event<socket_event>(&tls_layer_, socket_event_flag::connection, 0);
				if (can_read_from_socket_) {
					tls_layer_.event_handler_->send_event<socket_event>(&tls_layer_, socket_event_flag::read, 0);
				}
			}
		}

		return 0;
	}
	else if (res == GNUTLS_E_AGAIN || res == GNUTLS_E_INTERRUPTED) {
		if (!socket_error_) {
			return EAGAIN;
		}

		// GnuTLS has a writev() emulation that ignores trailing errors if
		// at least some data got sent
		res = GNUTLS_E_PUSH_ERROR;
	}

	failure(res, true);

	return socket_error_ ? socket_error_ : ECONNABORTED;
}
1055 
read(void * buffer,unsigned int len,int & error)1056 int tls_layer_impl::read(void *buffer, unsigned int len, int& error)
1057 {
1058 	if (state_ == socket_state::connecting) {
1059 		error = EAGAIN;
1060 		return -1;
1061 	}
1062 	else if (state_ != socket_state::connected && state_ != socket_state::shutting_down && state_ != socket_state::shut_down) {
1063 		error = ENOTCONN;
1064 		return -1;
1065 	}
1066 
1067 #if DEBUG_SOCKETEVENTS
1068 	assert(debug_can_read_);
1069 	assert(!has_pending_event(tls_layer_.event_handler_, &tls_layer_, socket_event_flag::read));
1070 #endif
1071 
1072 	int res = do_call_gnutls_record_recv(buffer, len);
1073 	if (res >= 0) {
1074 		error = 0;
1075 		return res;
1076 	}
1077 	else if (res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN) {
1078 #if DEBUG_SOCKETEVENTS
1079 		debug_can_read_ = false;
1080 #endif
1081 		error = EAGAIN;
1082 	}
1083 	else {
1084 		failure(res, false, L"gnutls_record_recv");
1085 		error = socket_error_ ? socket_error_ : ECONNABORTED;
1086 	}
1087 
1088 	return -1;
1089 }
1090 
write(void const * buffer,unsigned int len,int & error)1091 int tls_layer_impl::write(void const* buffer, unsigned int len, int& error)
1092 {
1093 	if (state_ == socket_state::connecting) {
1094 		error = EAGAIN;
1095 		return -1;
1096 	}
1097 	else if (state_ == socket_state::shutting_down || state_ == socket_state::shut_down) {
1098 		error = ESHUTDOWN;
1099 		return -1;
1100 	}
1101 	else if (state_ != socket_state::connected) {
1102 		error = ENOTCONN;
1103 		return -1;
1104 	}
1105 
1106 #if DEBUG_SOCKETEVENTS
1107 	assert(debug_can_write_);
1108 	assert(!has_pending_event(tls_layer_.event_handler_, &tls_layer_, socket_event_flag::write));
1109 #endif
1110 
1111 	if (!send_buffer_.empty()) {
1112 		write_blocked_by_send_buffer_ = true;
1113 #if DEBUG_SOCKETEVENTS
1114 		debug_can_write_ = false;
1115 #endif
1116 		error = EAGAIN;
1117 		return -1;
1118 	}
1119 
1120 	ssize_t res = gnutls_record_send(session_, buffer, len);
1121 
1122 	while ((res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN) && can_write_to_socket_) {
1123 		res = gnutls_record_send(session_, nullptr, 0);
1124 	}
1125 
1126 	if (res >= 0) {
1127 		error = 0;
1128 		return static_cast<int>(res);
1129 	}
1130 
1131 	if (res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN) {
1132 		if (!socket_error_) {
1133 			// Unfortunately we can't return EAGAIN here as GnuTLS has already consumed some data.
1134 			// With our semantics, EAGAIN means nothing has been handed off yet.
1135 			// Thus remember up to gnutls_record_get_max_size bytes from the input.
1136 			unsigned int max = static_cast<unsigned int>(gnutls_record_get_max_size(session_));
1137 			if (len > max) {
1138 				len = max;
1139 			}
1140 			send_buffer_.append(reinterpret_cast<unsigned char const*>(buffer), len);
1141 			return static_cast<int>(len);
1142 		}
1143 
1144 		// GnuTLS has a writev() emulation that ignores trailing errors if
1145 		// at least some data got sent
1146 		res = GNUTLS_E_PUSH_ERROR;
1147 	}
1148 
1149 	failure(static_cast<int>(res), false, L"gnutls_record_send");
1150 	error = socket_error_ ? socket_error_ : ECONNABORTED;
1151 	return -1;
1152 }
1153 
failure(int code,bool send_close,std::wstring const & function)1154 void tls_layer_impl::failure(int code, bool send_close, std::wstring const& function)
1155 {
1156 	logger_.log(logmsg::debug_debug, L"tls_layer_impl::failure(%d)", code);
1157 	if (code) {
1158 		log_error(code, function);
1159 		if (socket_eof_) {
1160 			if (code == GNUTLS_E_UNEXPECTED_PACKET_LENGTH
1161 #ifdef GNUTLS_E_PREMATURE_TERMINATION
1162 				|| code == GNUTLS_E_PREMATURE_TERMINATION
1163 #endif
1164 				)
1165 			{
1166 				if (state_ != socket_state::shut_down || !shutdown_silence_read_errors_) {
1167 					logger_.log(logmsg::status, server_ ? fztranslate("Client did not properly shut down TLS connection") : fztranslate("Server did not properly shut down TLS connection"));
1168 				}
1169 			}
1170 		}
1171 	}
1172 
1173 	auto const oldState = state_;
1174 
1175 	deinit();
1176 
1177 	if (send_close && tls_layer_.event_handler_) {
1178 		int error = socket_error_;
1179 		if (!error) {
1180 			error = ECONNABORTED;
1181 		}
1182 		if (oldState == socket_state::connecting) {
1183 			tls_layer_.event_handler_->send_event<socket_event>(&tls_layer_, socket_event_flag::connection, error);
1184 		}
1185 		else {
1186 			tls_layer_.event_handler_->send_event<socket_event>(&tls_layer_, socket_event_flag::read, error);
1187 		}
1188 	}
1189 }
1190 
shutdown()1191 int tls_layer_impl::shutdown()
1192 {
1193 	logger_.log(logmsg::debug_verbose, L"tls_layer_impl::shutdown()");
1194 
1195 	if (state_ == socket_state::shut_down) {
1196 		return 0;
1197 	}
1198 	else if (state_ == socket_state::shutting_down) {
1199 		return EAGAIN;
1200 	}
1201 	else if (state_ != socket_state::connected) {
1202 		return ENOTCONN;
1203 	}
1204 
1205 	state_ = socket_state::shutting_down;
1206 
1207 	if (!send_buffer_.empty()) {
1208 		logger_.log(logmsg::debug_verbose, L"Postponing shutdown, send_buffer_ not empty");
1209 		return EAGAIN;
1210 	}
1211 
1212 	return continue_shutdown();
1213 }
1214 
continue_shutdown()1215 int tls_layer_impl::continue_shutdown()
1216 {
1217 	logger_.log(logmsg::debug_verbose, L"tls_layer_impl::continue_shutdown()");
1218 
1219 	if (!sent_closure_alert_) {
1220 		int res = gnutls_bye(session_, GNUTLS_SHUT_WR);
1221 		while ((res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN) && can_write_to_socket_) {
1222 			res = gnutls_bye(session_, GNUTLS_SHUT_WR);
1223 		}
1224 		if (res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN) {
1225 			if (!socket_error_) {
1226 				return EAGAIN;
1227 			}
1228 
1229 			// GnuTLS has a writev() emulation that ignores trailing errors if
1230 			// at least some data got sent
1231 			res = GNUTLS_E_PUSH_ERROR;
1232 		}
1233 		if (res) {
1234 			failure(res, false, L"gnutls_bye");
1235 			return socket_error_ ? socket_error_ : ECONNABORTED;
1236 		}
1237 		sent_closure_alert_ = true;
1238 	}
1239 
1240 	int res = tls_layer_.next_layer_.shutdown();
1241 	if (res == EAGAIN) {
1242 		return EAGAIN;
1243 	}
1244 
1245 	if (!res) {
1246 		state_ = socket_state::shut_down;
1247 	}
1248 	else {
1249 		socket_error_ = res;
1250 		failure(0, false);
1251 	}
1252 	return res;
1253 }
1254 
set_verification_result(bool trusted)1255 void tls_layer_impl::set_verification_result(bool trusted)
1256 {
1257 	if (state_ != socket_state::connecting && !handshake_successful_) {
1258 		logger_.log(logmsg::debug_warning, L"TrustCurrentCert called at wrong time.");
1259 		return;
1260 	}
1261 
1262 	remove_verification_events(verification_handler_, &tls_layer_);
1263 	verification_handler_ = nullptr;
1264 
1265 	if (trusted) {
1266 		state_ = socket_state::connected;
1267 
1268 #if DEBUG_SOCKETEVENTS
1269 		if (can_read_from_socket_) {
1270 			assert(!debug_can_read_);
1271 			debug_can_read_ = true;
1272 		}
1273 		assert(!debug_can_write_);
1274 		debug_can_write_ = true;
1275 #endif
1276 		if (tls_layer_.event_handler_) {
1277 			tls_layer_.event_handler_->send_event<socket_event>(&tls_layer_, socket_event_flag::connection, 0);
1278 			if (can_read_from_socket_) {
1279 				tls_layer_.event_handler_->send_event<socket_event>(&tls_layer_, socket_event_flag::read, 0);
1280 			}
1281 		}
1282 
1283 		return;
1284 	}
1285 
1286 	logger_.log(logmsg::error, fztranslate("Remote certificate not trusted."));
1287 	failure(0, true);
1288 }
1289 
bin2hex(unsigned char const * in,size_t size)1290 static std::string bin2hex(unsigned char const* in, size_t size)
1291 {
1292 	std::string str;
1293 	str.reserve(size * 3);
1294 	for (size_t i = 0; i < size; ++i) {
1295 		if (i) {
1296 			str += ':';
1297 		}
1298 		str += int_to_hex_char<char>(in[i] >> 4);
1299 		str += int_to_hex_char<char>(in[i] & 0xf);
1300 	}
1301 
1302 	return str;
1303 }
1304 
1305 
extract_cert(gnutls_x509_crt_t const & cert,x509_certificate & out,bool last,logger_interface * logger)1306 bool tls_layer_impl::extract_cert(gnutls_x509_crt_t const& cert, x509_certificate& out, bool last, logger_interface * logger)
1307 {
1308 	datetime expiration_time(gnutls_x509_crt_get_expiration_time(cert), datetime::seconds);
1309 	datetime activation_time(gnutls_x509_crt_get_activation_time(cert), datetime::seconds);
1310 
1311 	if (!activation_time || !expiration_time || expiration_time < activation_time) {
1312 		if (logger) {
1313 			logger->log(logmsg::error, fztranslate("Could not extract validity period of certificate"));
1314 		}
1315 		return false;
1316 	}
1317 
1318 	// Get the serial number of the certificate
1319 	unsigned char buffer[40];
1320 	size_t size = sizeof(buffer);
1321 	int res = gnutls_x509_crt_get_serial(cert, buffer, &size);
1322 	if (res != 0) {
1323 		size = 0;
1324 	}
1325 
1326 	auto serial = bin2hex(buffer, size);
1327 
1328 	unsigned int pk_bits;
1329 	int pkAlgo = gnutls_x509_crt_get_pk_algorithm(cert, &pk_bits);
1330 	std::string pk_algo_name;
1331 	if (pkAlgo >= 0) {
1332 		char const* pAlgo = gnutls_pk_algorithm_get_name((gnutls_pk_algorithm_t)pkAlgo);
1333 		if (pAlgo) {
1334 			pk_algo_name = pAlgo;
1335 		}
1336 	}
1337 
1338 	int signAlgo = gnutls_x509_crt_get_signature_algorithm(cert);
1339 	std::string signAlgoName;
1340 	if (signAlgo >= 0) {
1341 		char const* pAlgo = gnutls_sign_algorithm_get_name((gnutls_sign_algorithm_t)signAlgo);
1342 		if (pAlgo) {
1343 			signAlgoName = pAlgo;
1344 		}
1345 	}
1346 
1347 	std::string subject, issuer;
1348 
1349 	datum_holder raw_subject;
1350 	res = gnutls_x509_crt_get_dn3(cert, &raw_subject, 0);
1351 	if (!res) {
1352 		subject = raw_subject.to_string_view();
1353 	}
1354 	else {
1355 		if (logger) {
1356 			logger->log(logmsg::debug_warning, "gnutls_x509_crt_get_dn3 failed with %d", res);
1357 		}
1358 	}
1359 	if (subject.empty()) {
1360 		if (logger) {
1361 			logger->log(logmsg::error, fztranslate("Could not get distinguished name of certificate subject, gnutls_x509_get_dn failed"));
1362 		}
1363 		return false;
1364 	}
1365 
1366 	std::vector<x509_certificate::subject_name> alt_subject_names = get_cert_subject_alt_names(cert);
1367 
1368 	datum_holder raw_issuer;
1369 	res = gnutls_x509_crt_get_issuer_dn3(cert, &raw_issuer, 0);
1370 	if (!res) {
1371 		issuer = raw_issuer.to_string_view();
1372 	}
1373 	else {
1374 		if (logger) {
1375 			logger->log(logmsg::debug_warning, "gnutls_x509_crt_get_issuer_dn3 failed with %d", res);
1376 		}
1377 	}
1378 	if (issuer.empty() ) {
1379 		if (logger) {
1380 			logger->log(logmsg::error, fztranslate("Could not get distinguished name of certificate issuer, gnutls_x509_get_issuer_dn failed"));
1381 		}
1382 		return false;
1383 	}
1384 
1385 	std::string fingerprint_sha256;
1386 	std::string fingerprint_sha1;
1387 
1388 	unsigned char digest[100];
1389 	size = sizeof(digest) - 1;
1390 	if (!gnutls_x509_crt_get_fingerprint(cert, GNUTLS_DIG_SHA256, digest, &size)) {
1391 		digest[size] = 0;
1392 		fingerprint_sha256 = bin2hex(digest, size);
1393 	}
1394 	size = sizeof(digest) - 1;
1395 	if (!gnutls_x509_crt_get_fingerprint(cert, GNUTLS_DIG_SHA1, digest, &size)) {
1396 		digest[size] = 0;
1397 		fingerprint_sha1 = bin2hex(digest, size);
1398 	}
1399 
1400 	datum_holder der;
1401 	if (gnutls_x509_crt_export2(cert, GNUTLS_X509_FMT_DER, &der) != GNUTLS_E_SUCCESS || !der.data || !der.size) {
1402 		if (logger) {
1403 			logger->log(logmsg::error, L"gnutls_x509_crt_get_issuer_dn");
1404 		}
1405 		return false;
1406 	}
1407 	std::vector<uint8_t> data(der.data, der.data + der.size);
1408 
1409 	out = x509_certificate(
1410 		std::move(data),
1411 		activation_time, expiration_time,
1412 		serial,
1413 		pk_algo_name, pk_bits,
1414 		signAlgoName,
1415 		fingerprint_sha256,
1416 		fingerprint_sha1,
1417 		issuer,
1418 		subject,
1419 		std::move(alt_subject_names),
1420 		last ? gnutls_x509_crt_check_issuer(cert, cert) : false);
1421 
1422 	return true;
1423 }
1424 
1425 
get_cert_subject_alt_names(gnutls_x509_crt_t cert)1426 std::vector<x509_certificate::subject_name> tls_layer_impl::get_cert_subject_alt_names(gnutls_x509_crt_t cert)
1427 {
1428 	std::vector<x509_certificate::subject_name> ret;
1429 
1430 	char san[4096];
1431 	for (unsigned int i = 0; i < 10000; ++i) { // I assume this is a sane limit
1432 		size_t san_size = sizeof(san) - 1;
1433 		int const type_or_error = gnutls_x509_crt_get_subject_alt_name(cert, i, san, &san_size, nullptr);
1434 		if (type_or_error == GNUTLS_E_SHORT_MEMORY_BUFFER) {
1435 			continue;
1436 		}
1437 		else if (type_or_error < 0) {
1438 			break;
1439 		}
1440 
1441 		if (type_or_error == GNUTLS_SAN_DNSNAME || type_or_error == GNUTLS_SAN_RFC822NAME) {
1442 			std::string dns = san;
1443 			if (!dns.empty()) {
1444 				ret.emplace_back(x509_certificate::subject_name{std::move(dns), type_or_error == GNUTLS_SAN_DNSNAME});
1445 			}
1446 		}
1447 		else if (type_or_error == GNUTLS_SAN_IPADDRESS) {
1448 			std::string ip = socket::address_to_string(san, static_cast<int>(san_size));
1449 			if (!ip.empty()) {
1450 				ret.emplace_back(x509_certificate::subject_name{std::move(ip), false});
1451 			}
1452 		}
1453 	}
1454 	return ret;
1455 }
1456 
certificate_is_blacklisted(cert_list_holder const & certs)1457 bool tls_layer_impl::certificate_is_blacklisted(cert_list_holder const& certs)
1458 {
1459 	for (size_t i = 0; i < certs.certs_size; ++i) {
1460 		if (certificate_is_blacklisted(certs.certs[i])) {
1461 			return true;
1462 		}
1463 	}
1464 	return false;
1465 }
1466 
certificate_is_blacklisted(gnutls_x509_crt_t const & cert)1467 bool tls_layer_impl::certificate_is_blacklisted(gnutls_x509_crt_t const& cert)
1468 {
1469 	static std::set<std::string, std::less<>> const bad_authority_key_ids = {
1470 		std::string("\xF4\x94\xBF\xDE\x50\xB6\xDB\x6B\x24\x3D\x9E\xF7\xBE\x3A\xAE\x36\xD7\xFB\x0E\x05", 20) // Nation-wide MITM in Kazakhstan
1471 	};
1472 
1473 	char buf[256];
1474 	unsigned int critical{};
1475 	size_t size = sizeof(buf);
1476 	int res = gnutls_x509_crt_get_authority_key_id(cert, buf, &size, &critical);
1477 	if (!res) {
1478 		auto it = bad_authority_key_ids.find(std::string_view(buf, size));
1479 		if (it != bad_authority_key_ids.cend()) {
1480 			return true;
1481 		}
1482 	}
1483 
1484 	return false;
1485 }
1486 
get_algorithm_warnings() const1487 int tls_layer_impl::get_algorithm_warnings() const
1488 {
1489 	int algorithmWarnings{};
1490 
1491 	switch (gnutls_protocol_get_version(session_))
1492 	{
1493 		case GNUTLS_SSL3:
1494 		case GNUTLS_VERSION_UNKNOWN:
1495 			algorithmWarnings |= tls_session_info::tlsver;
1496 			break;
1497 		default:
1498 			break;
1499 	}
1500 
1501 	switch (gnutls_cipher_get(session_)) {
1502 		case GNUTLS_CIPHER_UNKNOWN:
1503 		case GNUTLS_CIPHER_NULL:
1504 		case GNUTLS_CIPHER_ARCFOUR_128:
1505 		case GNUTLS_CIPHER_3DES_CBC:
1506 		case GNUTLS_CIPHER_ARCFOUR_40:
1507 		case GNUTLS_CIPHER_RC2_40_CBC:
1508 		case GNUTLS_CIPHER_DES_CBC:
1509 			algorithmWarnings |= tls_session_info::cipher;
1510 			break;
1511 		default:
1512 			break;
1513 	}
1514 
1515 	switch (gnutls_mac_get(session_)) {
1516 		case GNUTLS_MAC_UNKNOWN:
1517 		case GNUTLS_MAC_NULL:
1518 		case GNUTLS_MAC_MD5:
1519 		case GNUTLS_MAC_MD2:
1520 		case GNUTLS_MAC_UMAC_96:
1521 			algorithmWarnings |= tls_session_info::mac;
1522 			break;
1523 		default:
1524 			break;
1525 	}
1526 
1527 	switch (gnutls_kx_get(session_)) {
1528 		case GNUTLS_KX_UNKNOWN:
1529 		case GNUTLS_KX_ANON_DH:
1530 		case GNUTLS_KX_RSA_EXPORT:
1531 		case GNUTLS_KX_ANON_ECDH:
1532 			algorithmWarnings |= tls_session_info::kex;
1533 		default:
1534 			break;
1535 	}
1536 
1537 	return algorithmWarnings;
1538 }
1539 
load_certificates(std::string_view const & in,bool pem,gnutls_x509_crt_t * & certs,unsigned int & certs_size,bool & sort)1540 int tls_layer_impl::load_certificates(std::string_view const& in, bool pem, gnutls_x509_crt_t *& certs, unsigned int & certs_size, bool & sort)
1541 {
1542 	gnutls_datum_t dpem;
1543 	dpem.data = reinterpret_cast<unsigned char*>(const_cast<char *>(in.data()));
1544 	dpem.size = in.size();
1545 	unsigned int flags{};
1546 	if (sort) {
1547 		flags |= GNUTLS_X509_CRT_LIST_FAIL_IF_UNSORTED;
1548 	}
1549 
1550 	int res = gnutls_x509_crt_list_import2(&certs, &certs_size, &dpem, pem ? GNUTLS_X509_FMT_PEM : GNUTLS_X509_FMT_DER, flags);
1551 	if (res == GNUTLS_E_CERTIFICATE_LIST_UNSORTED) {
1552 		sort = false;
1553 		flags |= GNUTLS_X509_CRT_LIST_SORT;
1554 		res = gnutls_x509_crt_list_import2(&certs, &certs_size, &dpem, pem ? GNUTLS_X509_FMT_PEM : GNUTLS_X509_FMT_DER, flags);
1555 	}
1556 
1557 	if (res != GNUTLS_E_SUCCESS) {
1558 		certs = nullptr;
1559 		certs_size = 0;
1560 	}
1561 	return res;
1562 }
1563 
get_sorted_peer_certificates(gnutls_x509_crt_t * & certs,unsigned int & certs_size)1564 bool tls_layer_impl::get_sorted_peer_certificates(gnutls_x509_crt_t *& certs, unsigned int & certs_size)
1565 {
1566 	certs = nullptr;
1567 	certs_size = 0;
1568 
1569 	// First get unsorted list of peer certificates in DER
1570 	unsigned int cert_list_size;
1571 	const gnutls_datum_t* cert_list = gnutls_certificate_get_peers(session_, &cert_list_size);
1572 	if (!cert_list || !cert_list_size) {
1573 		logger_.log(logmsg::error, fztranslate("gnutls_certificate_get_peers returned no certificates"));
1574 		return false;
1575 	}
1576 
1577 	// Convert them all to PEM
1578 	// Avoid gnutls_pem_base64_encode2, excessive allocations.
1579 	auto constexpr header = "-----BEGIN CERTIFICATE-----\r\n"sv;
1580 	auto constexpr footer = "\r\n-----END CERTIFICATE-----\r\n"sv;
1581 
1582 	size_t cap = cert_list_size * header.size() + footer.size();
1583 	for (unsigned i = 0; i < cert_list_size; ++i) {
1584 		cap += ((cert_list[i].size + 2) / 3) * 4;
1585 	}
1586 
1587 	std::string pem;
1588 	pem.reserve(cap);
1589 
1590 	for (unsigned i = 0; i < cert_list_size; ++i) {
1591 		pem += header;
1592 		base64_encode_append(pem, to_view(cert_list[i]), base64_type::standard, true);
1593 		pem += footer;
1594 	}
1595 
1596 	// And now import the certificates
1597 	bool sort = true;
1598 	int res = load_certificates(pem, true, certs, certs_size, sort);
1599 	if (res == GNUTLS_E_CERTIFICATE_LIST_UNSORTED) {
1600 		logger_.log(logmsg::error, fztranslate("Could not sort peer certificates"));
1601 	}
1602 	else if (!sort) {
1603 		logger_.log(logmsg::error, fztranslate("Server sent unsorted certificate chain in violation of the TLS specifications"));
1604 	}
1605 
1606 	return res == GNUTLS_E_SUCCESS;
1607 }
1608 
log_verification_error(int status)1609 void tls_layer_impl::log_verification_error(int status)
1610 {
1611 	gnutls_datum_t buffer{};
1612 	gnutls_certificate_verification_status_print(status, GNUTLS_CRT_X509, &buffer, 0);
1613 	if (buffer.data) {
1614 		logger_.log(logmsg::debug_warning, L"Gnutls Verification status: %s", buffer.data);
1615 		gnutls_free(buffer.data);
1616 	}
1617 
1618 	if (status & GNUTLS_CERT_REVOKED) {
1619 		logger_.log(logmsg::error, fztranslate("Beware! Certificate has been revoked"));
1620 
1621 		// The remaining errors are no longer of interest
1622 		return;
1623 	}
1624 	if (status & GNUTLS_CERT_SIGNATURE_FAILURE) {
1625 		logger_.log(logmsg::error, fztranslate("Certificate signature verification failed"));
1626 		status &= ~GNUTLS_CERT_SIGNATURE_FAILURE;
1627 	}
1628 	if (status & GNUTLS_CERT_INSECURE_ALGORITHM) {
1629 		logger_.log(logmsg::error, fztranslate("A certificate in the chain was signed using an insecure algorithm"));
1630 		status &= ~GNUTLS_CERT_INSECURE_ALGORITHM;
1631 	}
1632 	if (status & GNUTLS_CERT_SIGNER_NOT_CA) {
1633 		logger_.log(logmsg::error, fztranslate("An issuer in the certificate chain is not a certificate authority"));
1634 		status &= ~GNUTLS_CERT_SIGNER_NOT_CA;
1635 	}
1636 	if (status & GNUTLS_CERT_UNEXPECTED_OWNER) {
1637 		logger_.log(logmsg::error, fztranslate("The server's hostname does not match the certificate's hostname"));
1638 		status &= ~GNUTLS_CERT_UNEXPECTED_OWNER;
1639 	}
1640 #ifdef GNUTLS_CERT_MISSING_OCSP_STATUS
1641 	if (status & GNUTLS_CERT_MISSING_OCSP_STATUS) {
1642 		logger_.log(logmsg::error, fztranslate("The certificate requires the server to include an OCSP status in its response, but the OCSP status is missing."));
1643 		status &= ~GNUTLS_CERT_MISSING_OCSP_STATUS;
1644 	}
1645 #endif
1646 	if (status) {
1647 		if (status == GNUTLS_CERT_INVALID) {
1648 			logger_.log(logmsg::error, fztranslate("Received certificate chain could not be verified."));
1649 		}
1650 		else {
1651 			logger_.log(logmsg::error, fztranslate("Received certificate chain could not be verified. Verification status is %d."), status);
1652 		}
1653 	}
1654 
1655 }
1656 
verify_certificate()1657 int tls_layer_impl::verify_certificate()
1658 {
1659 	logger_.log(logmsg::debug_verbose, L"tls_layer_impl::verify_certificate()");
1660 
1661 	if (state_ != socket_state::connecting) {
1662 		logger_.log(logmsg::debug_warning, L"verify_certificate called at wrong time");
1663 		return ENOTCONN;
1664 	}
1665 
1666 	if (gnutls_certificate_type_get(session_) != GNUTLS_CRT_X509) {
1667 		logger_.log(logmsg::error, fztranslate("Unsupported certificate type"));
1668 		failure(0, true);
1669 		return EOPNOTSUPP;
1670 	}
1671 
1672 	cert_list_holder certs;
1673 	if (!get_sorted_peer_certificates(certs.certs, certs.certs_size)) {
1674 		failure(0, true);
1675 		return EINVAL;
1676 	}
1677 
1678 	if (certificate_is_blacklisted(certs)) {
1679 		logger_.log(logmsg::error, fztranslate("Man-in-the-Middle attack detected, aborting connection."));
1680 		failure(0, true);
1681 		return EINVAL;
1682 	}
1683 
1684 	if (!required_certificate_.empty()) {
1685 		datum_holder cert_der{};
1686 		int res = gnutls_x509_crt_export2(certs.certs[0], GNUTLS_X509_FMT_DER, &cert_der);
1687 		if (res != GNUTLS_E_SUCCESS) {
1688 			failure(res, true, L"gnutls_x509_crt_export2");
1689 			return ECONNABORTED;
1690 		}
1691 
1692 		if (required_certificate_.size() != cert_der.size ||
1693 			memcmp(required_certificate_.data(), cert_der.data, cert_der.size))
1694 		{
1695 			logger_.log(logmsg::error, fztranslate("Certificate of connection does not match expected certificate."));
1696 			failure(0, true);
1697 			return EINVAL;
1698 		}
1699 
1700 		set_verification_result(true);
1701 
1702 		if (state_ != socket_state::connected && state_ != socket_state::shutting_down && state_ != socket_state::shut_down) {
1703 			return ECONNABORTED;
1704 		}
1705 		return 0;
1706 	}
1707 
1708 	bool const uses_hostname = !hostname_.empty() && get_address_type(hostname_) == address_type::unknown;
1709 
1710 	bool systemTrust = false;
1711 	bool hostnameMismatch = false;
1712 
1713 	// Our trust-model is user-guided TOFU on the host's certificate.
1714 	//
1715 	// First we verify it against the system trust store.
1716 	//
1717 	// If that fails, we validate the certificate chain sent by the server
1718 	// allowing three impairments:
1719 	// - Hostname mismatch
1720 	// - Out of validity
1721 	// - Signer not found
1722 	//
1723 	// In any case, actual trust decision is done later by the user.
1724 
1725 
1726 	// First, check system trust
1727 	if (uses_hostname && system_trust_store_) {
1728 
1729 		auto lease = system_trust_store_->impl_->lease();
1730 		auto cred = std::get<0>(lease);
1731 		if (cred) {
1732 			gnutls_credentials_set(session_, GNUTLS_CRD_CERTIFICATE, cred);
1733 			unsigned int status = 0;
1734 			int verifyResult = gnutls_certificate_verify_peers3(session_, to_utf8(hostname_).c_str(), &status);
1735 			gnutls_credentials_set(session_, GNUTLS_CRD_CERTIFICATE, cert_credentials_);
1736 			std::get<1>(lease).unlock();
1737 
1738 			if (verifyResult < 0) {
1739 				logger_.log(logmsg::debug_warning, L"gnutls_certificate_verify_peers2 returned %d with status %u", verifyResult, status);
1740 				logger_.log(logmsg::error, fztranslate("Failed to verify peer certificate"));
1741 				failure(0, true);
1742 				return EINVAL;
1743 			}
1744 
1745 			if (!status) {
1746 				systemTrust = true;
1747 			}
1748 		}
1749 		else {
1750 			std::get<1>(lease).unlock();
1751 			logger_.log(logmsg::debug_warning, L"System trust store could not be loaded");
1752 		}
1753 	}
1754 
1755 	if (!verification_handler_) {
1756 		set_verification_result(systemTrust);
1757 		return systemTrust ? 0 : ECONNABORTED;
1758 	}
1759 	else {
1760 		if (!systemTrust) {
1761 			// System trust store cannot verify this certificate. Allow three impairments:
1762 			//
1763 			// 1. For now, add the highest certificate from the chain to trust list. Otherwise
1764 			// gnutls_certificate_verify_peers2 always stops with GNUTLS_CERT_SIGNER_NOT_FOUND
1765 			// at the highest certificate in the chain.
1766 			gnutls_x509_crt_t root{};
1767 			clone_cert(certs.certs[certs.certs_size - 1], root);
1768 			if (!root) {
1769 				logger_.log(logmsg::error, fztranslate("Could not copy certificate"));
1770 				failure(0, true);
1771 				return ECONNABORTED;
1772 			}
1773 
1774 			gnutls_x509_trust_list_t tlist;
1775 			gnutls_certificate_get_trust_list(cert_credentials_, &tlist);
1776 			if (gnutls_x509_trust_list_add_cas(tlist, &root, 1, 0) != 1) {
1777 				logger_.log(logmsg::error, fztranslate("Could not add certificate to temporary trust list"));
1778 				failure(0, true);
1779 				return ECONNABORTED;
1780 			}
1781 
1782 			// 2. Also disable time checks. We allow expired/not yet valid certificates, though only
1783 			// after explicit user confirmation.
1784 			gnutls_certificate_set_verify_flags(cert_credentials_, gnutls_certificate_get_verify_flags(cert_credentials_) | GNUTLS_VERIFY_DISABLE_TIME_CHECKS | GNUTLS_VERIFY_DISABLE_TRUSTED_TIME_CHECKS);
1785 
1786 			unsigned int status = 0;
1787 			int verifyResult = gnutls_certificate_verify_peers2(session_, &status);
1788 
1789 			if (verifyResult < 0) {
1790 				logger_.log(logmsg::debug_warning, L"gnutls_certificate_verify_peers2 returned %d with status %u", verifyResult, status);
1791 				logger_.log(logmsg::error, fztranslate("Failed to verify peer certificate"));
1792 				failure(0, true);
1793 				return EINVAL;
1794 			}
1795 
1796 			if (status != 0) {
1797 				log_verification_error(status);
1798 
1799 				failure(0, true);
1800 				return EINVAL;
1801 			}
1802 
1803 			// 3. Hostname mismatch
1804 			if (uses_hostname) {
1805 				if (!gnutls_x509_crt_check_hostname(certs.certs[0], to_utf8(hostname_).c_str())) {
1806 					hostnameMismatch = true;
1807 					logger_.log(logmsg::debug_warning, L"Hostname does not match certificate SANs");
1808 				}
1809 			}
1810 		}
1811 
1812 		logger_.log(logmsg::status, fztranslate("Verifying certificate..."));
1813 
1814 		std::vector<x509_certificate> certificates;
1815 		certificates.reserve(certs.certs_size);
1816 		for (unsigned int i = 0; i < certs.certs_size; ++i) {
1817 			x509_certificate cert;
1818 			if (extract_cert(certs.certs[i], cert, i + 1 == certs.certs_size, &logger_)) {
1819 				certificates.emplace_back(std::move(cert));
1820 			}
1821 			else {
1822 				failure(0, true);
1823 				return ECONNABORTED;
1824 			}
1825 		}
1826 
1827 		// Lengthen incomplete chains to the root using the trust store.
1828 		if (!certificates.empty() && !certificates.back().self_signed() && system_trust_store_) {
1829 			auto lease = system_trust_store_->impl_->lease();
1830 			auto cred = std::get<0>(lease);
1831 			if (cred) {
1832 				gnutls_x509_crt_t cert = certs.certs[certs.certs_size - 1];
1833 				while (!certificates.back().self_signed()) {
1834 					gnutls_x509_crt_t issuer{};
1835 					if (gnutls_certificate_get_issuer(cred, cert, &issuer, 0) || !issuer) {
1836 						break;
1837 					}
1838 
1839 					// Why is this cert even in the trust store? Antivirus MITM?
1840 					if (certificate_is_blacklisted(issuer)) {
1841 						logger_.log(logmsg::error, fztranslate("Man-in-the-Middle attack detected, aborting connection."));
1842 						failure(0, true);
1843 						return EINVAL;
1844 					}
1845 
1846 					x509_certificate out;
1847 					if (!extract_cert(issuer, out, true, &logger_)) {
1848 						failure(0, true);
1849 						return ECONNABORTED;
1850 					}
1851 					certificates.push_back(out);
1852 					cert = issuer;
1853 				}
1854 			}
1855 		}
1856 
1857 		int const algorithmWarnings = get_algorithm_warnings();
1858 
1859 		int error;
1860 		auto port = tls_layer_.peer_port(error);
1861 		if (port == -1) {
1862 			socket_error_ = error;
1863 			failure(0, true);
1864 			return ECONNABORTED;
1865 		}
1866 
1867 		tls_session_info session_info(
1868 			to_utf8(to_wstring(hostname_)),
1869 			port,
1870 			get_protocol(),
1871 			get_key_exchange(),
1872 			get_cipher(),
1873 			get_mac(),
1874 			algorithmWarnings,
1875 			std::move(certificates),
1876 			systemTrust,
1877 			hostnameMismatch
1878 		);
1879 
1880 		verification_handler_->send_event<certificate_verification_event>(&tls_layer_, std::move(session_info));
1881 
1882 		return EAGAIN;
1883 	}
1884 }
1885 
get_protocol() const1886 std::string tls_layer_impl::get_protocol() const
1887 {
1888 	std::string ret;
1889 
1890 	char const* s = gnutls_protocol_get_name(gnutls_protocol_get_version(session_));
1891 	if (s && *s) {
1892 		ret = s;
1893 	}
1894 
1895 	if (ret.empty()) {
1896 		ret = to_utf8(fztranslate("unknown"));
1897 	}
1898 
1899 	return ret;
1900 }
1901 
get_key_exchange() const1902 std::string tls_layer_impl::get_key_exchange() const
1903 {
1904 	std::string ret;
1905 
1906 	char const* s{};
1907 	gnutls_kx_algorithm_t alg = gnutls_kx_get(session_);
1908 	bool const dh = (alg == GNUTLS_KX_DHE_RSA || alg == GNUTLS_KX_DHE_DSS);
1909 	bool const ecdh = (alg == GNUTLS_KX_ECDHE_RSA || alg == GNUTLS_KX_ECDHE_ECDSA);
1910 	if (dh || ecdh) {
1911 		char const* const signature_name = gnutls_sign_get_name(static_cast<gnutls_sign_algorithm_t>(gnutls_sign_algorithm_get(session_)));
1912 		ret = (ecdh ? "ECDHE" : "DHE");
1913 #if GNUTLS_VERSION_NUMBER >= 0x030600
1914 		s = gnutls_group_get_name(gnutls_group_get(session_));
1915 		if (s && *s) {
1916 			ret += "-";
1917 			ret += s;
1918 		}
1919 #endif
1920 		if (signature_name && *signature_name) {
1921 			ret += "-";
1922 			ret += signature_name;
1923 		}
1924 	}
1925 	else {
1926 		s = gnutls_kx_get_name(alg);
1927 		if (s && *s) {
1928 			ret = s;
1929 		}
1930 	}
1931 
1932 
1933 	if (ret.empty()) {
1934 		ret = to_utf8(fztranslate("unknown"));
1935 	}
1936 
1937 	return ret;
1938 }
1939 
get_cipher() const1940 std::string tls_layer_impl::get_cipher() const
1941 {
1942 	std::string ret;
1943 
1944 	char const* cipher = gnutls_cipher_get_name(gnutls_cipher_get(session_));
1945 	if (cipher && *cipher) {
1946 		ret = cipher;
1947 	}
1948 
1949 	if (ret.empty()) {
1950 		ret = to_utf8(fztranslate("unknown"));
1951 	}
1952 
1953 	return ret;
1954 }
1955 
get_mac() const1956 std::string tls_layer_impl::get_mac() const
1957 {
1958 	std::string ret;
1959 
1960 	char const* mac = gnutls_mac_get_name(gnutls_mac_get(session_));
1961 	if (mac && *mac) {
1962 		ret = mac;
1963 	}
1964 
1965 	if (ret.empty()) {
1966 		ret = to_utf8(fztranslate("unknown"));
1967 	}
1968 
1969 	return ret;
1970 }
1971 
list_tls_ciphers(std::string const & priority)1972 std::string tls_layer_impl::list_tls_ciphers(std::string const& priority)
1973 {
1974 	auto list = sprintf("Ciphers for %s:\n", priority.empty() ? ciphers : priority);
1975 
1976 	gnutls_priority_t pcache;
1977 	char const* err = nullptr;
1978 	int ret = gnutls_priority_init(&pcache, priority.empty() ? ciphers : priority.c_str(), &err);
1979 	if (ret < 0) {
1980 		list += sprintf("gnutls_priority_init failed with code %d: %s", ret, err ? err : "Unknown error");
1981 		return list;
1982 	}
1983 	else {
1984 		for (unsigned int i = 0; ; ++i) {
1985 			unsigned int idx;
1986 			ret = gnutls_priority_get_cipher_suite_index(pcache, i, &idx);
1987 			if (ret == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE) {
1988 				break;
1989 			}
1990 			if (ret == GNUTLS_E_UNKNOWN_CIPHER_SUITE) {
1991 				continue;
1992 			}
1993 
1994 			gnutls_protocol_t version;
1995 			unsigned char id[2];
1996 			char const* name = gnutls_cipher_suite_info(idx, id, nullptr, nullptr, nullptr, &version);
1997 
1998 			if (name != nullptr) {
1999 				list += sprintf(
2000 					"%-50s    0x%02x, 0x%02x    %s\n",
2001 					name,
2002 					(unsigned char)id[0],
2003 					(unsigned char)id[1],
2004 					gnutls_protocol_get_name(version));
2005 			}
2006 		}
2007 	}
2008 
2009 	return list;
2010 }
2011 
do_call_gnutls_record_recv(void * data,size_t len)2012 int tls_layer_impl::do_call_gnutls_record_recv(void* data, size_t len)
2013 {
2014 	ssize_t res = gnutls_record_recv(session_, data, len);
2015 	while ((res == GNUTLS_E_AGAIN || res == GNUTLS_E_INTERRUPTED) && can_read_from_socket_ && !gnutls_record_get_direction(session_)) {
2016 		// Spurious EAGAIN. Can happen if GnuTLS gets a partial
2017 		// record and the socket got closed.
2018 		// The unexpected close is being ignored in this case, unless
2019 		// gnutls_record_recv is being called again.
2020 		// Manually call gnutls_record_recv as in case of eof on the socket,
2021 		// we are not getting another receive event.
2022 		logger_.log(logmsg::debug_verbose, L"gnutls_record_recv returned spurious EAGAIN");
2023 		res = gnutls_record_recv(session_, data, len);
2024 	}
2025 
2026 	if ((res == GNUTLS_E_AGAIN || res == GNUTLS_E_INTERRUPTED) && socket_error_) {
2027 		res = GNUTLS_E_PULL_ERROR;
2028 	}
2029 
2030 	return static_cast<int>(res);
2031 }
2032 
get_gnutls_version()2033 std::string tls_layer_impl::get_gnutls_version()
2034 {
2035 	char const* v = gnutls_check_version(nullptr);
2036 	if (!v || !*v) {
2037 		return "unknown";
2038 	}
2039 
2040 	return v;
2041 }
2042 
set_hostname(native_string const & host)2043 void tls_layer_impl::set_hostname(native_string const& host)
2044 {
2045 	hostname_ = host;
2046 	if (session_ && !hostname_.empty() && get_address_type(hostname_) == address_type::unknown) {
2047 		auto const utf8 = to_utf8(hostname_);
2048 		if (!utf8.empty()) {
2049 			int res = gnutls_server_name_set(session_, GNUTLS_NAME_DNS, utf8.c_str(), utf8.size());
2050 			if (res) {
2051 				log_error(res, L"gnutls_server_name_set", logmsg::debug_warning);
2052 			}
2053 		}
2054 	}
2055 }
2056 
get_hostname() const2057 native_string tls_layer_impl::get_hostname() const
2058 {
2059 	if (!session_) {
2060 		return {};
2061 	}
2062 
2063 	size_t len{};
2064 	unsigned int type{};
2065 	unsigned int i{};
2066 	int ret;
2067 	do {
2068 		ret = gnutls_server_name_get(session_, nullptr, &len, &type, i++);
2069 	}
2070 	while (ret == GNUTLS_E_SHORT_MEMORY_BUFFER && type != GNUTLS_NAME_DNS);
2071 
2072 	if (ret == GNUTLS_E_SHORT_MEMORY_BUFFER) {
2073 		std::string name;
2074 		name.resize(len - 1);
2075 		ret = gnutls_server_name_get(session_, name.data(), &len, &type, --i);
2076 		if (!ret) {
2077 			return fz::to_native(name);
2078 		}
2079 	}
2080 
2081 	return {};
2082 }
2083 
connect(native_string const & host,unsigned int port,address_type family)2084 int tls_layer_impl::connect(native_string const& host, unsigned int port, address_type family)
2085 {
2086 	if (hostname_.empty()) {
2087 		set_hostname(host);
2088 	}
2089 
2090 	return tls_layer_.next_layer_.connect(host, port, family);
2091 }
2092 
namespace {
// Writes d at p as a native-endian size_t length prefix followed by d's bytes,
// then advances p past everything written. The caller must guarantee room for
// sizeof(size_t) + d.size() bytes.
void append_with_size(uint8_t *& p, std::vector<uint8_t> const& d)
{
	size_t const len = d.size();
	memcpy(p, &len, sizeof(len));
	p += sizeof(len);
	if (!d.empty()) {
		memcpy(p, d.data(), len);
		p += len;
	}
}
}
2105 
get_session_parameters() const2106 std::vector<uint8_t> tls_layer_impl::get_session_parameters() const
2107 {
2108 	std::vector<uint8_t> ret;
2109 
2110 	if (is_client()) {
2111 		datum_holder d;
2112 		int res = gnutls_session_get_data2(session_, &d);
2113 		if (res) {
2114 			logger_.log(logmsg::debug_warning, L"gnutls_session_get_data2 failed: %d", res);
2115 		}
2116 		else {
2117 			ret.assign(d.data, d.data + d.size);
2118 		}
2119 	}
2120 	else {
2121 		ret.resize(sizeof(size_t) * 3 + ticket_key_.size() + session_db_key_.size() + session_db_data_.size());
2122 		auto* p = ret.data();
2123 		append_with_size(p, ticket_key_);
2124 		append_with_size(p, session_db_key_);
2125 		append_with_size(p, session_db_data_);
2126 	}
2127 
2128 	return ret;
2129 }
2130 
get_raw_certificate() const2131 std::vector<uint8_t> tls_layer_impl::get_raw_certificate() const
2132 {
2133 	std::vector<uint8_t> ret;
2134 
2135 	// Implicitly trust certificate of primary socket
2136 	unsigned int cert_list_size;
2137 	gnutls_datum_t const* const cert_list = gnutls_certificate_get_peers(session_, &cert_list_size);
2138 	if (cert_list && cert_list_size) {
2139 		ret.assign(cert_list[0].data, cert_list[0].data + cert_list[0].size);
2140 	}
2141 
2142 	return ret;
2143 }
2144 
generate_selfsigned_certificate(native_string const & password,std::string const & distinguished_name,std::vector<std::string> const & hostnames)2145 std::pair<std::string, std::string> tls_layer_impl::generate_selfsigned_certificate(native_string const& password, std::string const& distinguished_name, std::vector<std::string> const& hostnames)
2146 {
2147 	std::pair<std::string, std::string> ret;
2148 
2149 	gnutls_x509_privkey_t priv;
2150 	int res = gnutls_x509_privkey_init(&priv);
2151 	if (res) {
2152 		return ret;
2153 	}
2154 
2155 	auto fmt = GNUTLS_PK_ECDSA;
2156 	unsigned int bits = gnutls_sec_param_to_pk_bits(fmt, GNUTLS_SEC_PARAM_HIGH);
2157 	if (fmt == GNUTLS_PK_RSA && bits < 2048) {
2158 		bits = 2048;
2159 	}
2160 
2161 	res = gnutls_x509_privkey_generate(priv, fmt, bits, 0);
2162 	if (res) {
2163 		gnutls_x509_privkey_deinit(priv);
2164 		return ret;
2165 	}
2166 
2167 	datum_holder kh;
2168 
2169 	if (password.empty()) {
2170 		res = gnutls_x509_privkey_export2(priv, GNUTLS_X509_FMT_PEM, &kh);
2171 	}
2172 	else {
2173 		res = gnutls_x509_privkey_export2_pkcs8(priv, GNUTLS_X509_FMT_PEM, to_utf8(password).c_str(), 0, &kh);
2174 	}
2175 	if (res) {
2176 		gnutls_x509_privkey_deinit(priv);
2177 		return ret;
2178 	}
2179 
2180 	gnutls_x509_crt_t crt;
2181 	res = gnutls_x509_crt_init(&crt);
2182 	if (res) {
2183 		gnutls_x509_privkey_deinit(priv);
2184 		return ret;
2185 	}
2186 
2187 	res = gnutls_x509_crt_set_version(crt, 3);
2188 	if (res) {
2189 		gnutls_x509_privkey_deinit(priv);
2190 		gnutls_x509_crt_deinit(crt);
2191 		return ret;
2192 	}
2193 
2194 	res = gnutls_x509_crt_set_key(crt, priv);
2195 	if (res) {
2196 		gnutls_x509_privkey_deinit(priv);
2197 		gnutls_x509_crt_deinit(crt);
2198 		return ret;
2199 	}
2200 
2201 	char const* out{};
2202 	res = gnutls_x509_crt_set_dn(crt, distinguished_name.c_str(), &out);
2203 	if (res) {
2204 		gnutls_x509_privkey_deinit(priv);
2205 		gnutls_x509_crt_deinit(crt);
2206 		return ret;
2207 	}
2208 
2209 	for (auto const& hostname : hostnames) {
2210 		res = gnutls_x509_crt_set_subject_alt_name(crt, GNUTLS_SAN_DNSNAME, hostname.c_str(), hostname.size(), GNUTLS_FSAN_APPEND);
2211 		if (res) {
2212 			gnutls_x509_privkey_deinit(priv);
2213 			gnutls_x509_crt_deinit(crt);
2214 			return ret;
2215 		}
2216 	}
2217 
2218 	res = gnutls_x509_crt_set_serial(crt, random_bytes(20).data(), 20);
2219 	if (res) {
2220 		gnutls_x509_privkey_deinit(priv);
2221 		gnutls_x509_crt_deinit(crt);
2222 		return ret;
2223 	}
2224 
2225 	auto const now = datetime::now();
2226 
2227 	res = gnutls_x509_crt_set_activation_time(crt, (now - duration::from_minutes(5)).get_time_t());
2228 	if (res) {
2229 		gnutls_x509_privkey_deinit(priv);
2230 		gnutls_x509_crt_deinit(crt);
2231 		return ret;
2232 	}
2233 	res = gnutls_x509_crt_set_expiration_time(crt, (now + duration::from_days(366)).get_time_t());
2234 	if (res) {
2235 		gnutls_x509_privkey_deinit(priv);
2236 		gnutls_x509_crt_deinit(crt);
2237 		return ret;
2238 	}
2239 
2240 	res = gnutls_x509_crt_set_key_usage(crt, GNUTLS_KEY_DIGITAL_SIGNATURE | GNUTLS_KEY_KEY_ENCIPHERMENT);
2241 	if (res) {
2242 		gnutls_x509_privkey_deinit(priv);
2243 		gnutls_x509_crt_deinit(crt);
2244 		return ret;
2245 	}
2246 
2247 	res = gnutls_x509_crt_set_basic_constraints(crt, 0, -1);
2248 	if (res) {
2249 		gnutls_x509_privkey_deinit(priv);
2250 		gnutls_x509_crt_deinit(crt);
2251 		return ret;
2252 	}
2253 
2254 	res = gnutls_x509_crt_sign2(crt, crt, priv, GNUTLS_DIG_SHA256, 0);
2255 	if (res) {
2256 		gnutls_x509_privkey_deinit(priv);
2257 		gnutls_x509_crt_deinit(crt);
2258 		return ret;
2259 	}
2260 
2261 	datum_holder ch;
2262 	res = gnutls_x509_crt_export2(crt, GNUTLS_X509_FMT_PEM, &ch);
2263 	if (res) {
2264 		gnutls_x509_privkey_deinit(priv);
2265 		gnutls_x509_crt_deinit(crt);
2266 		return ret;
2267 	}
2268 
2269 	gnutls_x509_privkey_deinit(priv);
2270 	gnutls_x509_crt_deinit(crt);
2271 	ret.first = kh.to_string();
2272 	ret.second = ch.to_string();
2273 
2274 	return ret;
2275 }
2276 
generate_csr(native_string const & password,std::string const & distinguished_name,std::vector<std::string> const & hostnames,bool csr_as_pem)2277 std::pair<std::string, std::string> tls_layer_impl::generate_csr(native_string const& password, std::string const& distinguished_name, std::vector<std::string> const& hostnames, bool csr_as_pem)
2278 {
2279 	std::pair<std::string, std::string> ret;
2280 
2281 	gnutls_x509_privkey_t priv;
2282 	int res = gnutls_x509_privkey_init(&priv);
2283 	if (res) {
2284 		return ret;
2285 	}
2286 
2287 	auto fmt = GNUTLS_PK_ECDSA;
2288 	unsigned int bits = gnutls_sec_param_to_pk_bits(fmt, GNUTLS_SEC_PARAM_HIGH);
2289 	if (fmt == GNUTLS_PK_RSA && bits < 2048) {
2290 		bits = 2048;
2291 	}
2292 
2293 	res = gnutls_x509_privkey_generate(priv, fmt, bits, 0);
2294 	if (res) {
2295 		gnutls_x509_privkey_deinit(priv);
2296 		return ret;
2297 	}
2298 
2299 	datum_holder kh;
2300 
2301 	if (password.empty()) {
2302 		res = gnutls_x509_privkey_export2(priv, GNUTLS_X509_FMT_PEM, &kh);
2303 	}
2304 	else {
2305 		res = gnutls_x509_privkey_export2_pkcs8(priv, GNUTLS_X509_FMT_PEM, to_utf8(password).c_str(), 0, &kh);
2306 	}
2307 	if (res) {
2308 		gnutls_x509_privkey_deinit(priv);
2309 		return ret;
2310 	}
2311 
2312 	gnutls_x509_crq_t crq;
2313 	res = gnutls_x509_crq_init(&crq);
2314 	if (res) {
2315 		gnutls_x509_privkey_deinit(priv);
2316 		return ret;
2317 	}
2318 
2319 	res = gnutls_x509_crq_set_version(crq, 3);
2320 	if (res) {
2321 		gnutls_x509_privkey_deinit(priv);
2322 		gnutls_x509_crq_deinit(crq);
2323 		return ret;
2324 	}
2325 
2326 	res = gnutls_x509_crq_set_key(crq, priv);
2327 	if (res) {
2328 		gnutls_x509_privkey_deinit(priv);
2329 		gnutls_x509_crq_deinit(crq);
2330 		return ret;
2331 	}
2332 
2333 	char const* out{};
2334 	res = gnutls_x509_crq_set_dn(crq, distinguished_name.c_str(), &out);
2335 	if (res) {
2336 		gnutls_x509_privkey_deinit(priv);
2337 		gnutls_x509_crq_deinit(crq);
2338 		return ret;
2339 	}
2340 
2341 	for (auto const& hostname : hostnames) {
2342 		res = gnutls_x509_crq_set_subject_alt_name(crq, GNUTLS_SAN_DNSNAME, hostname.c_str(), hostname.size(), GNUTLS_FSAN_APPEND);
2343 		if (res) {
2344 			gnutls_x509_privkey_deinit(priv);
2345 			gnutls_x509_crq_deinit(crq);
2346 			return ret;
2347 		}
2348 	}
2349 
2350 	res = gnutls_x509_crq_set_key_usage(crq, GNUTLS_KEY_DIGITAL_SIGNATURE | GNUTLS_KEY_KEY_ENCIPHERMENT);
2351 	if (res) {
2352 		gnutls_x509_privkey_deinit(priv);
2353 		gnutls_x509_crq_deinit(crq);
2354 		return ret;
2355 	}
2356 
2357 	res = gnutls_x509_crq_set_basic_constraints(crq, 0, -1);
2358 	if (res) {
2359 		gnutls_x509_privkey_deinit(priv);
2360 		gnutls_x509_crq_deinit(crq);
2361 		return ret;
2362 	}
2363 
2364 	res = gnutls_x509_crq_sign2(crq, priv, GNUTLS_DIG_SHA256, 0);
2365 	if (res) {
2366 		gnutls_x509_privkey_deinit(priv);
2367 		gnutls_x509_crq_deinit(crq);
2368 		return ret;
2369 	}
2370 
2371 	datum_holder ch;
2372 	res = gnutls_x509_crq_export2(crq, csr_as_pem ? GNUTLS_X509_FMT_PEM : GNUTLS_X509_FMT_DER, &ch);
2373 	if (res) {
2374 		gnutls_x509_privkey_deinit(priv);
2375 		gnutls_x509_crq_deinit(crq);
2376 		return ret;
2377 	}
2378 
2379 	gnutls_x509_privkey_deinit(priv);
2380 	gnutls_x509_crq_deinit(crq);
2381 	ret.first = kh.to_string();
2382 	ret.second = ch.to_string();
2383 
2384 	return ret;
2385 }
2386 
shutdown_read()2387 int tls_layer_impl::shutdown_read()
2388 {
2389 	if (!can_read_from_socket_) {
2390 		return EAGAIN;
2391 	}
2392 
2393 	char c{};
2394 	int error{};
2395 	int res = tls_layer_.next_layer_.read(&c, 1, error);
2396 	if (!res) {
2397 		return tls_layer_.next_layer_.shutdown_read();
2398 	}
2399 	else if (res > 0) {
2400 		// Have to fail the connection as we have now discarded data.
2401 		return ECONNABORTED;
2402 	}
2403 
2404 	if (error == EAGAIN) {
2405 		can_read_from_socket_ = false;
2406 #if DEBUG_SOCKETEVENTS
2407 		debug_can_read_ = false;
2408 #endif
2409 	}
2410 
2411 	return error;
2412 }
2413 
set_event_handler(event_handler * pEvtHandler,fz::socket_event_flag retrigger_block)2414 void tls_layer_impl::set_event_handler(event_handler* pEvtHandler, fz::socket_event_flag retrigger_block)
2415 {
2416 	write_blocked_by_send_buffer_ = false;
2417 
2418 	fz::socket_event_flag const pending = change_socket_event_handler(tls_layer_.event_handler_, pEvtHandler, &tls_layer_, retrigger_block);
2419 	tls_layer_.event_handler_ = pEvtHandler;
2420 
2421 	if (pEvtHandler) {
2422 		if (can_write_to_socket_ && (state_ == socket_state::connected || state_ == socket_state::shutting_down) && !(pending & (socket_event_flag::write | socket_event_flag::connection)) && !(retrigger_block & socket_event_flag::write)) {
2423 			pEvtHandler->send_event<socket_event>(&tls_layer_, socket_event_flag::write, 0);
2424 #if DEBUG_SOCKETEVENTS
2425 			assert(debug_can_write_);
2426 #endif
2427 		}
2428 		if (can_read_from_socket_ && (state_ == socket_state::connected || state_ == socket_state::shutting_down || state_ == socket_state::shut_down)) {
2429 			if (!(pending & socket_event_flag::read) && !(retrigger_block & socket_event_flag::read)) {
2430 				pEvtHandler->send_event<socket_event>(&tls_layer_, socket_event_flag::read, 0);
2431 #if DEBUG_SOCKETEVENTS
2432 				assert(debug_can_read_);
2433 #endif
2434 			}
2435 		}
2436 	}
2437 
2438 }
2439 
do_set_alpn()2440 bool tls_layer_impl::do_set_alpn()
2441 {
2442 	if (alpn_.empty()) {
2443 		return true;
2444 	}
2445 
2446 	gnutls_datum_t * data = new gnutls_datum_t[alpn_.size()];
2447 	for (size_t i = 0; i < alpn_.size(); ++i) {
2448 		data[i].data = reinterpret_cast<unsigned char *>(const_cast<char*>(alpn_[i].c_str()));
2449 		data[i].size = alpn_[i].size();
2450 	}
2451 	int res = gnutls_alpn_set_protocols(session_, data, alpn_.size(), GNUTLS_ALPN_MANDATORY);
2452 	delete [] data;
2453 
2454 	if (res) {
2455 		log_error(res, L"gnutls_alpn_set_protocols");
2456 	}
2457 	return res == 0;
2458 }
2459 
get_alpn() const2460 std::string tls_layer_impl::get_alpn() const
2461 {
2462 	if (session_) {
2463 		gnutls_datum_t protocol;
2464 		if (!gnutls_alpn_get_selected_protocol(session_, &protocol)) {
2465 			return to_string(protocol);
2466 		}
2467 	}
2468 	return {};
2469 }
2470 
set_min_tls_ver(tls_ver ver)2471 void tls_layer_impl::set_min_tls_ver(tls_ver ver)
2472 {
2473 	min_tls_ver_ = ver;
2474 }
2475 
set_max_tls_ver(tls_ver ver)2476 void tls_layer_impl::set_max_tls_ver(tls_ver ver)
2477 {
2478 	max_tls_ver_ = ver;
2479 }
2480 
2481 }
2482