1 /*
2 * uhub - A tiny ADC p2p connection hub
3 * Copyright (C) 2007-2014, Jan Vidar Krey
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation; either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 *
18 */
19
#include "uhub.h"
#include "network/common.h"
#include "network/tls.h"
#include "network/backend.h"

#include <limits.h>
24
25 #ifdef SSL_SUPPORT
26 #ifdef SSL_USE_OPENSSL
27
/* Traffic counters fed by add_io_stats(); they are not defined in this file. */
extern void net_stats_add_tx(size_t bytes);
extern void net_stats_add_rx(size_t bytes);
30
31 struct net_ssl_openssl
32 {
33 SSL* ssl;
34 BIO* bio;
35 enum ssl_state state;
36 int events;
37 int ssl_read_events;
38 int ssl_write_events;
39 uint32_t flags;
40 size_t bytes_rx;
41 size_t bytes_tx;
42 };
43
44 struct net_context_openssl
45 {
46 SSL_CTX* ssl;
47 };
48
get_handle(struct net_connection * con)49 static struct net_ssl_openssl* get_handle(struct net_connection* con)
50 {
51 uhub_assert(con);
52 return (struct net_ssl_openssl*) con->ssl;
53 }
54
#ifdef DEBUG
/**
 * Translate a TLS state into its enumerator name for debug logging.
 *
 * @param state state to translate.
 * @return the enumerator name; for a value outside the enum an assertion
 *         fires and "(UNKNOWN STATE)" is returned.
 */
static const char* get_state_str(enum ssl_state state)
{
    const char* name = NULL;

    switch (state)
    {
        case tls_st_none:          name = "tls_st_none";          break;
        case tls_st_error:         name = "tls_st_error";         break;
        case tls_st_accepting:     name = "tls_st_accepting";     break;
        case tls_st_connecting:    name = "tls_st_connecting";    break;
        case tls_st_connected:     name = "tls_st_connected";     break;
        case tls_st_disconnecting: name = "tls_st_disconnecting"; break;
    }

    if (name)
        return name;

    uhub_assert(!"This should not happen - invalid state!");
    return "(UNKNOWN STATE)";
}
#endif
71
net_ssl_set_state(struct net_ssl_openssl * handle,enum ssl_state new_state)72 static void net_ssl_set_state(struct net_ssl_openssl* handle, enum ssl_state new_state)
73 {
74 LOG_DEBUG("net_ssl_set_state(): prev_state=%s, new_state=%s", get_state_str(handle->state), get_state_str(new_state));
75 handle->state = new_state;
76 }
77
net_ssl_get_provider()78 const char* net_ssl_get_provider()
79 {
80 return OPENSSL_VERSION_TEXT;
81 }
82
/**
 * Initialize the OpenSSL library and load its error strings.
 *
 * @return always 1.
 */
int net_ssl_library_init()
{
    LOG_TRACE("Initializing OpenSSL...");

    SSL_library_init();
    SSL_load_error_strings();

    return 1;
}
90
/**
 * Release the global resources held by the OpenSSL library.
 *
 * @return always 1.
 */
int net_ssl_library_shutdown()
{
    /* Error queue first. ERR_remove_state() is only called for OpenSSL older than 1.1.0. */
    ERR_clear_error();
#if OPENSSL_VERSION_NUMBER < 0x10100000L
    ERR_remove_state(0);
#endif

    /* Engines and loaded configuration modules. */
    ENGINE_cleanup();
    CONF_modules_unload(1);

    /* Error strings, digest/cipher tables and ex_data. */
    ERR_free_strings();
    EVP_cleanup();
    CRYPTO_cleanup_all_ex_data();

    /* Left disabled: sk_SSL_COMP_free(SSL_COMP_get_compression_methods()); */
    return 1;
}
108
add_io_stats(struct net_ssl_openssl * handle)109 static void add_io_stats(struct net_ssl_openssl* handle)
110 {
111 #if OPENSSL_VERSION_NUMBER < 0x10100000L
112 unsigned long num_read = handle->bio->num_read;
113 unsigned long num_write = handle->bio->num_write;
114 #else
115 unsigned long num_read = BIO_number_read(handle->bio);
116 unsigned long num_write = BIO_number_written(handle->bio);
117 #endif
118
119 if (num_read > handle->bytes_rx)
120 {
121 net_stats_add_rx(num_read - handle->bytes_rx);
122 handle->bytes_rx = num_read;
123 }
124
125 if (num_write > handle->bytes_tx)
126 {
127 net_stats_add_tx(num_write - handle->bytes_tx);
128 handle->bytes_tx = num_write;
129 }
130 }
131
get_ssl_method(const char * tls_version)132 static const SSL_METHOD* get_ssl_method(const char* tls_version)
133 {
134 if (!tls_version || !*tls_version)
135 {
136 LOG_ERROR("tls_version is not set.");
137 return 0;
138 }
139
140 #if OPENSSL_VERSION_NUMBER < 0x10100000L
141 if (!strcmp(tls_version, "1.0"))
142 return TLSv1_method();
143 #if OPENSSL_VERSION_NUMBER >= 0x1000100fL
144 if (!strcmp(tls_version, "1.1"))
145 return TLSv1_1_method();
146 if (!strcmp(tls_version, "1.2"))
147 return TLSv1_2_method();
148 #endif
149
150 LOG_ERROR("Unable to recognize tls_version.");
151 return 0;
152 #else
153 LOG_WARN("tls_version is obsolete, and should not be used.");
154 return TLS_method();
155 #endif
156 }
157
158 /**
159 * Create a new SSL context.
160 */
net_ssl_context_create(const char * tls_version,const char * tls_ciphersuite)161 struct ssl_context_handle* net_ssl_context_create(const char* tls_version, const char* tls_ciphersuite)
162 {
163 struct net_context_openssl* ctx = (struct net_context_openssl*) hub_malloc_zero(sizeof(struct net_context_openssl));
164 const SSL_METHOD* ssl_method = get_ssl_method(tls_version);
165
166 if (!ssl_method)
167 {
168 hub_free(ctx);
169 return 0;
170 }
171
172 ctx->ssl = SSL_CTX_new(ssl_method);
173
174 /* Disable SSLv2 */
175 SSL_CTX_set_options(ctx->ssl, SSL_OP_NO_SSLv2);
176
177 // #ifdef SSL_OP_NO_SSLv3
178 /* Disable SSLv3 */
179 SSL_CTX_set_options(ctx->ssl, SSL_OP_NO_SSLv3);
180 // #endif
181
182 // FIXME: Why did we need this again?
183 SSL_CTX_set_quiet_shutdown(ctx->ssl, 1);
184
185 #ifdef SSL_OP_NO_COMPRESSION
186 /* Disable compression */
187 LOG_TRACE("Disabling SSL compression."); /* "CRIME" attack */
188 SSL_CTX_set_options(ctx->ssl, SSL_OP_NO_COMPRESSION);
189 #endif
190
191 /* Set preferred cipher suite */
192 if (SSL_CTX_set_cipher_list(ctx->ssl, tls_ciphersuite) != 1)
193 {
194 LOG_ERROR("Unable to set cipher suite.");
195 SSL_CTX_free(ctx->ssl);
196 hub_free(ctx);
197 return 0;
198 }
199
200 return (struct ssl_context_handle*) ctx;
201 }
202
net_ssl_context_destroy(struct ssl_context_handle * ctx_)203 void net_ssl_context_destroy(struct ssl_context_handle* ctx_)
204 {
205 struct net_context_openssl* ctx = (struct net_context_openssl*) ctx_;
206 SSL_CTX_free(ctx->ssl);
207 hub_free(ctx);
208 }
209
ssl_load_certificate(struct ssl_context_handle * ctx_,const char * pem_file)210 int ssl_load_certificate(struct ssl_context_handle* ctx_, const char* pem_file)
211 {
212 struct net_context_openssl* ctx = (struct net_context_openssl*) ctx_;
213 if (SSL_CTX_use_certificate_chain_file(ctx->ssl, pem_file) < 0)
214 {
215 LOG_ERROR("SSL_CTX_use_certificate_chain_file: %s", ERR_error_string(ERR_get_error(), NULL));
216 return 0;
217 }
218
219 return 1;
220 }
221
ssl_load_private_key(struct ssl_context_handle * ctx_,const char * pem_file)222 int ssl_load_private_key(struct ssl_context_handle* ctx_, const char* pem_file)
223 {
224 struct net_context_openssl* ctx = (struct net_context_openssl*) ctx_;
225 if (SSL_CTX_use_PrivateKey_file(ctx->ssl, pem_file, SSL_FILETYPE_PEM) < 0)
226 {
227 LOG_ERROR("SSL_CTX_use_PrivateKey_file: %s", ERR_error_string(ERR_get_error(), NULL));
228 return 0;
229 }
230 return 1;
231 }
232
ssl_check_private_key(struct ssl_context_handle * ctx_)233 int ssl_check_private_key(struct ssl_context_handle* ctx_)
234 {
235 struct net_context_openssl* ctx = (struct net_context_openssl*) ctx_;
236 if (SSL_CTX_check_private_key(ctx->ssl) != 1)
237 {
238 LOG_FATAL("SSL_CTX_check_private_key: Private key does not match the certificate public key: %s", ERR_error_string(ERR_get_error(), NULL));
239 return 0;
240 }
241 return 1;
242 }
243
handle_openssl_error(struct net_connection * con,int ret,int read)244 static int handle_openssl_error(struct net_connection* con, int ret, int read)
245 {
246 struct net_ssl_openssl* handle = get_handle(con);
247 int err = SSL_get_error(handle->ssl, ret);
248 switch (err)
249 {
250 case SSL_ERROR_ZERO_RETURN:
251 // Not really an error, but SSL was shut down.
252 return -1;
253
254 case SSL_ERROR_WANT_READ:
255 if (read)
256 handle->ssl_read_events = NET_EVENT_READ;
257 else
258 handle->ssl_write_events = NET_EVENT_READ;
259 return 0;
260
261 case SSL_ERROR_WANT_WRITE:
262 if (read)
263 handle->ssl_read_events = NET_EVENT_WRITE;
264 else
265 handle->ssl_write_events = NET_EVENT_WRITE;
266 return 0;
267
268 case SSL_ERROR_SSL:
269 net_ssl_set_state(handle, tls_st_error);
270 return -2;
271
272 case SSL_ERROR_SYSCALL:
273 net_ssl_set_state(handle, tls_st_error);
274 return -2;
275 }
276
277 return -2;
278 }
279
net_con_ssl_accept(struct net_connection * con)280 ssize_t net_con_ssl_accept(struct net_connection* con)
281 {
282 struct net_ssl_openssl* handle = get_handle(con);
283 ssize_t ret;
284 net_ssl_set_state(handle, tls_st_accepting);
285
286 ret = SSL_accept(handle->ssl);
287 LOG_PROTO("SSL_accept() ret=%d", ret);
288 if (ret > 0)
289 {
290 net_con_update(con, NET_EVENT_READ);
291 net_ssl_set_state(handle, tls_st_connected);
292 return ret;
293 }
294 return handle_openssl_error(con, ret, tls_st_accepting);
295 }
296
net_con_ssl_connect(struct net_connection * con)297 ssize_t net_con_ssl_connect(struct net_connection* con)
298 {
299 struct net_ssl_openssl* handle = get_handle(con);
300 ssize_t ret;
301 net_ssl_set_state(handle, tls_st_connecting);
302
303 ret = SSL_connect(handle->ssl);
304 LOG_PROTO("SSL_connect() ret=%d", ret);
305
306 if (ret > 0)
307 {
308 net_con_update(con, NET_EVENT_READ);
309 net_ssl_set_state(handle, tls_st_connected);
310 return ret;
311 }
312
313 ret = handle_openssl_error(con, ret, tls_st_connecting);
314 LOG_ERROR("net_con_ssl_connect: ret=%d", ret);
315 return ret;
316 }
317
net_con_ssl_handshake(struct net_connection * con,enum net_con_ssl_mode ssl_mode,struct ssl_context_handle * ssl_ctx)318 ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode ssl_mode, struct ssl_context_handle* ssl_ctx)
319 {
320 uhub_assert(con);
321 uhub_assert(ssl_ctx);
322
323 struct net_context_openssl* ctx = (struct net_context_openssl*) ssl_ctx;
324 struct net_ssl_openssl* handle = (struct net_ssl_openssl*) hub_malloc_zero(sizeof(struct net_ssl_openssl));
325
326 if (ssl_mode == net_con_ssl_mode_server)
327 {
328 handle->ssl = SSL_new(ctx->ssl);
329 if (!handle->ssl)
330 {
331 LOG_ERROR("Unable to create new SSL stream\n");
332 return -1;
333 }
334 SSL_set_fd(handle->ssl, con->sd);
335 handle->bio = SSL_get_rbio(handle->ssl);
336 con->ssl = (struct ssl_handle*) handle;
337 return net_con_ssl_accept(con);
338 }
339 else
340 {
341 handle->ssl = SSL_new(ctx->ssl);
342 SSL_set_fd(handle->ssl, con->sd);
343 handle->bio = SSL_get_rbio(handle->ssl);
344 con->ssl = (struct ssl_handle*) handle;
345 return net_con_ssl_connect(con);
346 }
347 }
348
net_ssl_send(struct net_connection * con,const void * buf,size_t len)349 ssize_t net_ssl_send(struct net_connection* con, const void* buf, size_t len)
350 {
351 struct net_ssl_openssl* handle = get_handle(con);
352
353 LOG_TRACE("net_ssl_send(), state=%d", (int) handle->state);
354
355 if (handle->state == tls_st_error)
356 return -2;
357
358 uhub_assert(handle->state == tls_st_connected);
359
360
361 ERR_clear_error();
362 ssize_t ret = SSL_write(handle->ssl, buf, len);
363 add_io_stats(handle);
364 LOG_PROTO("SSL_write(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret);
365 if (ret > 0)
366 handle->ssl_write_events = 0;
367 else
368 ret = handle_openssl_error(con, ret, 0);
369
370 net_ssl_update(con, handle->events); // Update backend only
371 return ret;
372 }
373
net_ssl_recv(struct net_connection * con,void * buf,size_t len)374 ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len)
375 {
376 struct net_ssl_openssl* handle = get_handle(con);
377 ssize_t ret;
378
379 if (handle->state == tls_st_error)
380 return -2;
381
382 if (handle->state == tls_st_accepting || handle->state == tls_st_connecting)
383 return -1;
384
385 uhub_assert(handle->state == tls_st_connected);
386
387 ERR_clear_error();
388
389 ret = SSL_read(handle->ssl, buf, len);
390 add_io_stats(handle);
391 LOG_PROTO("SSL_read(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret);
392 if (ret > 0)
393 handle->ssl_read_events = 0;
394 else
395 ret = handle_openssl_error(con, ret, 1);
396
397 net_ssl_update(con, handle->events); // Update backend only
398 return ret;
399 }
400
net_ssl_update(struct net_connection * con,int events)401 void net_ssl_update(struct net_connection* con, int events)
402 {
403 struct net_ssl_openssl* handle = get_handle(con);
404 handle->events = events;
405 net_backend_update(con, handle->events | handle->ssl_read_events | handle->ssl_write_events);
406 }
407
net_ssl_shutdown(struct net_connection * con)408 void net_ssl_shutdown(struct net_connection* con)
409 {
410 struct net_ssl_openssl* handle = get_handle(con);
411 if (handle)
412 {
413 SSL_shutdown(handle->ssl);
414 SSL_clear(handle->ssl);
415 }
416 }
417
net_ssl_destroy(struct net_connection * con)418 void net_ssl_destroy(struct net_connection* con)
419 {
420 struct net_ssl_openssl* handle = get_handle(con);
421 LOG_TRACE("net_ssl_destroy: %p", con);
422 SSL_free(handle->ssl);
423 hub_free(handle);
424 }
425
net_ssl_callback(struct net_connection * con,int events)426 void net_ssl_callback(struct net_connection* con, int events)
427 {
428 struct net_ssl_openssl* handle = get_handle(con);
429 int ret;
430
431 switch (handle->state)
432 {
433 case tls_st_none:
434 con->callback(con, events, con->ptr);
435 break;
436
437 case tls_st_error:
438 con->callback(con, NET_EVENT_ERROR, con->ptr);
439 break;
440
441 case tls_st_accepting:
442 if (net_con_ssl_accept(con) != 0)
443 con->callback(con, NET_EVENT_READ, con->ptr);
444 break;
445
446 case tls_st_connecting:
447 ret = net_con_ssl_connect(con);
448 if (ret == 0)
449 return;
450
451 if (ret > 0)
452 {
453 LOG_DEBUG("%p SSL connected!", con);
454 con->callback(con, NET_EVENT_READ, con->ptr);
455 }
456 else
457 {
458 LOG_DEBUG("%p SSL handshake failed!", con);
459 con->callback(con, NET_EVENT_ERROR, con->ptr);
460 }
461 break;
462
463 case tls_st_connected:
464 if (handle->ssl_read_events & events)
465 events |= NET_EVENT_READ;
466 if (handle->ssl_write_events & events)
467 events |= NET_EVENT_WRITE;
468 con->callback(con, events, con->ptr);
469 break;
470
471 case tls_st_disconnecting:
472 return;
473 }
474 }
475
net_ssl_get_tls_version(struct net_connection * con)476 const char* net_ssl_get_tls_version(struct net_connection* con)
477 {
478 struct net_ssl_openssl* handle = get_handle(con);
479 return SSL_get_version(handle->ssl);
480 }
481
net_ssl_get_tls_cipher(struct net_connection * con)482 const char* net_ssl_get_tls_cipher(struct net_connection* con)
483 {
484 struct net_ssl_openssl* handle = get_handle(con);
485 const SSL_CIPHER *cipher = SSL_get_current_cipher(handle->ssl);
486 return SSL_CIPHER_get_name(cipher);
487 }
488
489 #endif /* SSL_USE_OPENSSL */
490 #endif /* SSL_SUPPORT */
491
492