1 /*
2  * uhub - A tiny ADC p2p connection hub
3  * Copyright (C) 2007-2014, Jan Vidar Krey
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation; either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  *
18  */
19 
#include "uhub.h"
#include "network/common.h"
#include "network/tls.h"
#include "network/backend.h"

#include <limits.h>
24 
25 #ifdef SSL_SUPPORT
26 #ifdef SSL_USE_OPENSSL
27 
/* Global traffic counters; defined outside this file. add_io_stats() feeds them
 * the byte deltas observed on each connection's BIO. */
void net_stats_add_tx(size_t bytes);
void net_stats_add_rx(size_t bytes);
30 
/* Per-connection OpenSSL state; stored in net_connection::ssl and retrieved via get_handle(). */
struct net_ssl_openssl
{
	/* Session object created by SSL_new() in net_con_ssl_handshake(). */
	SSL* ssl;
	/* The session's read BIO (SSL_get_rbio()); its byte counters are sampled by add_io_stats(). */
	BIO* bio;
	/* Current TLS state; changed through net_ssl_set_state(). */
	enum ssl_state state;
	/* Events last requested by the owning connection through net_ssl_update(). */
	int events;
	/* Events OpenSSL needs before a pending read (or handshake step) can progress.
	 * Set by handle_openssl_error() when its read flag is non-zero; cleared after a successful SSL_read(). */
	int ssl_read_events;
	/* Events OpenSSL needs before a pending write can progress.
	 * Set by handle_openssl_error() when its read flag is zero; cleared after a successful SSL_write(). */
	int ssl_write_events;
	/* NOTE(review): not referenced anywhere in this file's visible code. */
	uint32_t flags;
	/* BIO read counter value already reported through net_stats_add_rx(). */
	size_t bytes_rx;
	/* BIO write counter value already reported through net_stats_add_tx(). */
	size_t bytes_tx;
};
43 
/* Backend-specific data behind the opaque struct ssl_context_handle. */
struct net_context_openssl
{
	/* OpenSSL context shared by every session created in net_con_ssl_handshake(). */
	SSL_CTX* ssl;
};
48 
get_handle(struct net_connection * con)49 static struct net_ssl_openssl* get_handle(struct net_connection* con)
50 {
51 	uhub_assert(con);
52 	return (struct net_ssl_openssl*) con->ssl;
53 }
54 
#ifdef DEBUG
/*
 * Translate a TLS state into its enumerator name for debug logging.
 * An out-of-range value trips an assertion and yields "(UNKNOWN STATE)".
 */
static const char* get_state_str(enum ssl_state state)
{
	const char* name = NULL;

	/* No default label on purpose, so the compiler can warn about unhandled enumerators. */
	switch (state)
	{
		case tls_st_none:           name = "tls_st_none";           break;
		case tls_st_error:          name = "tls_st_error";          break;
		case tls_st_accepting:      name = "tls_st_accepting";      break;
		case tls_st_connecting:     name = "tls_st_connecting";     break;
		case tls_st_connected:      name = "tls_st_connected";      break;
		case tls_st_disconnecting:  name = "tls_st_disconnecting";  break;
	}

	if (name == NULL)
	{
		uhub_assert(!"This should not happen - invalid state!");
		name = "(UNKNOWN STATE)";
	}
	return name;
}
#endif
71 
net_ssl_set_state(struct net_ssl_openssl * handle,enum ssl_state new_state)72 static void net_ssl_set_state(struct net_ssl_openssl* handle, enum ssl_state new_state)
73 {
74 	LOG_DEBUG("net_ssl_set_state(): prev_state=%s, new_state=%s", get_state_str(handle->state), get_state_str(new_state));
75 	handle->state = new_state;
76 }
77 
net_ssl_get_provider()78 const char* net_ssl_get_provider()
79 {
80 	return OPENSSL_VERSION_TEXT;
81 }
82 
net_ssl_library_init()83 int net_ssl_library_init()
84 {
85 	LOG_TRACE("Initializing OpenSSL...");
86 	SSL_library_init();
87 	SSL_load_error_strings();
88 	return 1;
89 }
90 
net_ssl_library_shutdown()91 int net_ssl_library_shutdown()
92 {
93 	ERR_clear_error();
94 #if OPENSSL_VERSION_NUMBER < 0x10100000L
95 	ERR_remove_state(0);
96 #endif
97 
98 	ENGINE_cleanup();
99 	CONF_modules_unload(1);
100 
101 	ERR_free_strings();
102 	EVP_cleanup();
103 	CRYPTO_cleanup_all_ex_data();
104 
105 	// sk_SSL_COMP_free(SSL_COMP_get_compression_methods());
106 	return 1;
107 }
108 
add_io_stats(struct net_ssl_openssl * handle)109 static void add_io_stats(struct net_ssl_openssl* handle)
110 {
111 #if OPENSSL_VERSION_NUMBER < 0x10100000L
112 	unsigned long num_read = handle->bio->num_read;
113 	unsigned long num_write = handle->bio->num_write;
114 #else
115 	unsigned long num_read = BIO_number_read(handle->bio);
116 	unsigned long num_write = BIO_number_written(handle->bio);
117 #endif
118 
119 	if (num_read > handle->bytes_rx)
120 	{
121 		net_stats_add_rx(num_read - handle->bytes_rx);
122 		handle->bytes_rx = num_read;
123 	}
124 
125 	if (num_write > handle->bytes_tx)
126 	{
127 		net_stats_add_tx(num_write - handle->bytes_tx);
128 		handle->bytes_tx = num_write;
129 	}
130 }
131 
get_ssl_method(const char * tls_version)132 static const SSL_METHOD* get_ssl_method(const char* tls_version)
133 {
134 	if (!tls_version || !*tls_version)
135 	{
136 		LOG_ERROR("tls_version is not set.");
137 		return 0;
138 	}
139 
140 #if OPENSSL_VERSION_NUMBER < 0x10100000L
141 	if (!strcmp(tls_version, "1.0"))
142 	  return TLSv1_method();
143 #if OPENSSL_VERSION_NUMBER >= 0x1000100fL
144 	if (!strcmp(tls_version, "1.1"))
145 	  return TLSv1_1_method();
146 	if (!strcmp(tls_version, "1.2"))
147 	  return TLSv1_2_method();
148 #endif
149 
150 	LOG_ERROR("Unable to recognize tls_version.");
151 	return 0;
152 #else
153 	LOG_WARN("tls_version is obsolete, and should not be used.");
154 	return TLS_method();
155 #endif
156 }
157 
158 /**
159  * Create a new SSL context.
160  */
net_ssl_context_create(const char * tls_version,const char * tls_ciphersuite)161 struct ssl_context_handle* net_ssl_context_create(const char* tls_version, const char* tls_ciphersuite)
162 {
163 	struct net_context_openssl* ctx = (struct net_context_openssl*) hub_malloc_zero(sizeof(struct net_context_openssl));
164 	const SSL_METHOD* ssl_method = get_ssl_method(tls_version);
165 
166 	if (!ssl_method)
167 	{
168 		hub_free(ctx);
169 		return 0;
170 	}
171 
172 	ctx->ssl = SSL_CTX_new(ssl_method);
173 
174 	/* Disable SSLv2 */
175 	SSL_CTX_set_options(ctx->ssl, SSL_OP_NO_SSLv2);
176 
177 // #ifdef SSL_OP_NO_SSLv3
178 	/* Disable SSLv3 */
179 	SSL_CTX_set_options(ctx->ssl, SSL_OP_NO_SSLv3);
180 // #endif
181 
182 	// FIXME: Why did we need this again?
183 	SSL_CTX_set_quiet_shutdown(ctx->ssl, 1);
184 
185 #ifdef SSL_OP_NO_COMPRESSION
186 	/* Disable compression */
187 	LOG_TRACE("Disabling SSL compression."); /* "CRIME" attack */
188 	SSL_CTX_set_options(ctx->ssl, SSL_OP_NO_COMPRESSION);
189 #endif
190 
191 	/* Set preferred cipher suite */
192 	if (SSL_CTX_set_cipher_list(ctx->ssl, tls_ciphersuite) != 1)
193 	{
194 		LOG_ERROR("Unable to set cipher suite.");
195 		SSL_CTX_free(ctx->ssl);
196 		hub_free(ctx);
197 		return 0;
198 	}
199 
200 	return (struct ssl_context_handle*) ctx;
201 }
202 
net_ssl_context_destroy(struct ssl_context_handle * ctx_)203 void net_ssl_context_destroy(struct ssl_context_handle* ctx_)
204 {
205 	struct net_context_openssl* ctx = (struct net_context_openssl*) ctx_;
206 	SSL_CTX_free(ctx->ssl);
207 	hub_free(ctx);
208 }
209 
ssl_load_certificate(struct ssl_context_handle * ctx_,const char * pem_file)210 int ssl_load_certificate(struct ssl_context_handle* ctx_, const char* pem_file)
211 {
212 	struct net_context_openssl* ctx = (struct net_context_openssl*) ctx_;
213 	if (SSL_CTX_use_certificate_chain_file(ctx->ssl, pem_file) < 0)
214 	{
215 		LOG_ERROR("SSL_CTX_use_certificate_chain_file: %s", ERR_error_string(ERR_get_error(), NULL));
216 		return 0;
217 	}
218 
219 	return 1;
220 }
221 
ssl_load_private_key(struct ssl_context_handle * ctx_,const char * pem_file)222 int ssl_load_private_key(struct ssl_context_handle* ctx_, const char* pem_file)
223 {
224 	struct net_context_openssl* ctx = (struct net_context_openssl*) ctx_;
225 	if (SSL_CTX_use_PrivateKey_file(ctx->ssl, pem_file, SSL_FILETYPE_PEM) < 0)
226 	{
227 		LOG_ERROR("SSL_CTX_use_PrivateKey_file: %s", ERR_error_string(ERR_get_error(), NULL));
228 		return 0;
229 	}
230 	return 1;
231 }
232 
ssl_check_private_key(struct ssl_context_handle * ctx_)233 int ssl_check_private_key(struct ssl_context_handle* ctx_)
234 {
235 	struct net_context_openssl* ctx = (struct net_context_openssl*) ctx_;
236 	if (SSL_CTX_check_private_key(ctx->ssl) != 1)
237 	{
238 		LOG_FATAL("SSL_CTX_check_private_key: Private key does not match the certificate public key: %s", ERR_error_string(ERR_get_error(), NULL));
239 		return 0;
240 	}
241 	return 1;
242 }
243 
handle_openssl_error(struct net_connection * con,int ret,int read)244 static int handle_openssl_error(struct net_connection* con, int ret, int read)
245 {
246 	struct net_ssl_openssl* handle = get_handle(con);
247 	int err = SSL_get_error(handle->ssl, ret);
248 	switch (err)
249 	{
250 		case SSL_ERROR_ZERO_RETURN:
251 			// Not really an error, but SSL was shut down.
252 			return -1;
253 
254 		case SSL_ERROR_WANT_READ:
255 			if (read)
256 				handle->ssl_read_events = NET_EVENT_READ;
257 			else
258 				handle->ssl_write_events = NET_EVENT_READ;
259 			return 0;
260 
261 		case SSL_ERROR_WANT_WRITE:
262 			if (read)
263 				handle->ssl_read_events = NET_EVENT_WRITE;
264 			else
265 				handle->ssl_write_events = NET_EVENT_WRITE;
266 			return 0;
267 
268 		case SSL_ERROR_SSL:
269 			net_ssl_set_state(handle, tls_st_error);
270 			return -2;
271 
272 		case SSL_ERROR_SYSCALL:
273 			net_ssl_set_state(handle, tls_st_error);
274 			return -2;
275 	}
276 
277 	return -2;
278 }
279 
net_con_ssl_accept(struct net_connection * con)280 ssize_t net_con_ssl_accept(struct net_connection* con)
281 {
282 	struct net_ssl_openssl* handle = get_handle(con);
283 	ssize_t ret;
284 	net_ssl_set_state(handle, tls_st_accepting);
285 
286 	ret = SSL_accept(handle->ssl);
287 	LOG_PROTO("SSL_accept() ret=%d", ret);
288 	if (ret > 0)
289 	{
290 		net_con_update(con, NET_EVENT_READ);
291 		net_ssl_set_state(handle, tls_st_connected);
292 		return ret;
293 	}
294 	return handle_openssl_error(con, ret, tls_st_accepting);
295 }
296 
net_con_ssl_connect(struct net_connection * con)297 ssize_t net_con_ssl_connect(struct net_connection* con)
298 {
299 	struct net_ssl_openssl* handle = get_handle(con);
300 	ssize_t ret;
301 	net_ssl_set_state(handle, tls_st_connecting);
302 
303 	ret = SSL_connect(handle->ssl);
304 	LOG_PROTO("SSL_connect() ret=%d", ret);
305 
306 	if (ret > 0)
307 	{
308 		net_con_update(con, NET_EVENT_READ);
309 		net_ssl_set_state(handle, tls_st_connected);
310 		return ret;
311 	}
312 
313 	ret = handle_openssl_error(con, ret, tls_st_connecting);
314 	LOG_ERROR("net_con_ssl_connect: ret=%d", ret);
315 	return ret;
316 }
317 
net_con_ssl_handshake(struct net_connection * con,enum net_con_ssl_mode ssl_mode,struct ssl_context_handle * ssl_ctx)318 ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode ssl_mode, struct ssl_context_handle* ssl_ctx)
319 {
320 	uhub_assert(con);
321 	uhub_assert(ssl_ctx);
322 
323 	struct net_context_openssl* ctx = (struct net_context_openssl*) ssl_ctx;
324 	struct net_ssl_openssl* handle = (struct net_ssl_openssl*) hub_malloc_zero(sizeof(struct net_ssl_openssl));
325 
326 	if (ssl_mode == net_con_ssl_mode_server)
327 	{
328 		handle->ssl = SSL_new(ctx->ssl);
329 		if (!handle->ssl)
330 		{
331 			LOG_ERROR("Unable to create new SSL stream\n");
332 			return -1;
333 		}
334 		SSL_set_fd(handle->ssl, con->sd);
335 		handle->bio = SSL_get_rbio(handle->ssl);
336 		con->ssl = (struct ssl_handle*) handle;
337 		return net_con_ssl_accept(con);
338 	}
339 	else
340 	{
341 		handle->ssl = SSL_new(ctx->ssl);
342 		SSL_set_fd(handle->ssl, con->sd);
343 		handle->bio = SSL_get_rbio(handle->ssl);
344 		con->ssl = (struct ssl_handle*) handle;
345 		return net_con_ssl_connect(con);
346 	}
347 }
348 
net_ssl_send(struct net_connection * con,const void * buf,size_t len)349 ssize_t net_ssl_send(struct net_connection* con, const void* buf, size_t len)
350 {
351 	struct net_ssl_openssl* handle = get_handle(con);
352 
353 	LOG_TRACE("net_ssl_send(), state=%d", (int) handle->state);
354 
355 	if (handle->state == tls_st_error)
356 		return -2;
357 
358 	uhub_assert(handle->state == tls_st_connected);
359 
360 
361 	ERR_clear_error();
362 	ssize_t ret = SSL_write(handle->ssl, buf, len);
363 	add_io_stats(handle);
364 	LOG_PROTO("SSL_write(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret);
365 	if (ret > 0)
366 		handle->ssl_write_events = 0;
367 	else
368 		ret = handle_openssl_error(con, ret, 0);
369 
370 	net_ssl_update(con, handle->events);  // Update backend only
371 	return ret;
372 }
373 
net_ssl_recv(struct net_connection * con,void * buf,size_t len)374 ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len)
375 {
376 	struct net_ssl_openssl* handle = get_handle(con);
377 	ssize_t ret;
378 
379 	if (handle->state == tls_st_error)
380 		return -2;
381 
382 	if (handle->state == tls_st_accepting || handle->state == tls_st_connecting)
383 		return -1;
384 
385 	uhub_assert(handle->state == tls_st_connected);
386 
387 	ERR_clear_error();
388 
389 	ret = SSL_read(handle->ssl, buf, len);
390 	add_io_stats(handle);
391 	LOG_PROTO("SSL_read(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret);
392 	if (ret > 0)
393 		handle->ssl_read_events = 0;
394 	else
395 		ret = handle_openssl_error(con, ret, 1);
396 
397 	net_ssl_update(con, handle->events);  // Update backend only
398 	return ret;
399 }
400 
net_ssl_update(struct net_connection * con,int events)401 void net_ssl_update(struct net_connection* con, int events)
402 {
403 	struct net_ssl_openssl* handle = get_handle(con);
404 	handle->events = events;
405 	net_backend_update(con, handle->events | handle->ssl_read_events | handle->ssl_write_events);
406 }
407 
net_ssl_shutdown(struct net_connection * con)408 void net_ssl_shutdown(struct net_connection* con)
409 {
410 	struct net_ssl_openssl* handle = get_handle(con);
411 	if (handle)
412 	{
413 		SSL_shutdown(handle->ssl);
414 		SSL_clear(handle->ssl);
415 	}
416 }
417 
net_ssl_destroy(struct net_connection * con)418 void net_ssl_destroy(struct net_connection* con)
419 {
420 	struct net_ssl_openssl* handle = get_handle(con);
421 	LOG_TRACE("net_ssl_destroy: %p", con);
422 	SSL_free(handle->ssl);
423 	hub_free(handle);
424 }
425 
net_ssl_callback(struct net_connection * con,int events)426 void net_ssl_callback(struct net_connection* con, int events)
427 {
428 	struct net_ssl_openssl* handle = get_handle(con);
429 	int ret;
430 
431 	switch (handle->state)
432 	{
433 		case tls_st_none:
434 			con->callback(con, events, con->ptr);
435 			break;
436 
437 		case tls_st_error:
438 			con->callback(con, NET_EVENT_ERROR, con->ptr);
439 			break;
440 
441 		case tls_st_accepting:
442 			if (net_con_ssl_accept(con) != 0)
443 				con->callback(con, NET_EVENT_READ, con->ptr);
444 			break;
445 
446 		case tls_st_connecting:
447 			ret = net_con_ssl_connect(con);
448 			if (ret == 0)
449 				return;
450 
451 			if (ret > 0)
452 			{
453 				LOG_DEBUG("%p SSL connected!", con);
454 				con->callback(con, NET_EVENT_READ, con->ptr);
455 			}
456 			else
457 			{
458 				LOG_DEBUG("%p SSL handshake failed!", con);
459 				con->callback(con, NET_EVENT_ERROR, con->ptr);
460 			}
461 			break;
462 
463 		case tls_st_connected:
464 			if (handle->ssl_read_events & events)
465 				events |= NET_EVENT_READ;
466 			if (handle->ssl_write_events & events)
467 				events |= NET_EVENT_WRITE;
468 			con->callback(con, events, con->ptr);
469 			break;
470 
471 		case tls_st_disconnecting:
472 			return;
473 	}
474 }
475 
net_ssl_get_tls_version(struct net_connection * con)476 const char* net_ssl_get_tls_version(struct net_connection* con)
477 {
478 	struct net_ssl_openssl* handle = get_handle(con);
479 	return SSL_get_version(handle->ssl);
480 }
481 
net_ssl_get_tls_cipher(struct net_connection * con)482 const char* net_ssl_get_tls_cipher(struct net_connection* con)
483 {
484 	struct net_ssl_openssl* handle = get_handle(con);
485 	const SSL_CIPHER *cipher = SSL_get_current_cipher(handle->ssl);
486 	return SSL_CIPHER_get_name(cipher);
487 }
488 
489 #endif /* SSL_USE_OPENSSL */
490 #endif /* SSL_SUPPORT */
491 
492