1
2 #include "config.h"
3 #include "dolog.hh"
4 #include "iputils.hh"
5 #include "lock.hh"
6 #include "tcpiohandler.hh"
7
8 #ifdef HAVE_LIBSODIUM
9 #include <sodium.h>
10 #endif /* HAVE_LIBSODIUM */
11
12 #ifdef HAVE_DNS_OVER_TLS
13 #ifdef HAVE_LIBSSL
14
15 #include <openssl/conf.h>
16 #include <openssl/err.h>
17 #include <openssl/rand.h>
18 #include <openssl/ssl.h>
19 #include <openssl/x509v3.h>
20
21 #include "libssl.hh"
22
23 class OpenSSLFrontendContext
24 {
25 public:
OpenSSLFrontendContext(const ComboAddress & addr,const TLSConfig & tlsConfig)26 OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
27 {
28 registerOpenSSLUser();
29
30 d_tlsCtx = libssl_init_server_context(tlsConfig, d_ocspResponses);
31 if (!d_tlsCtx) {
32 ERR_print_errors_fp(stderr);
33 throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
34 }
35 }
36
cleanup()37 void cleanup()
38 {
39 d_tlsCtx.reset();
40
41 unregisterOpenSSLUser();
42 }
43
44 OpenSSLTLSTicketKeysRing d_ticketKeys;
45 std::map<int, std::string> d_ocspResponses;
46 std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
47 std::unique_ptr<FILE, int(*)(FILE*)> d_keyLogFile{nullptr, fclose};
48 };
49
50 class OpenSSLTLSConnection: public TLSConnection
51 {
52 public:
53 /* server side connection */
OpenSSLTLSConnection(int socket,unsigned int timeout,std::shared_ptr<OpenSSLFrontendContext> feContext)54 OpenSSLTLSConnection(int socket, unsigned int timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(feContext), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
55 {
56 d_socket = socket;
57
58 if (!s_initTLSConnIndex.test_and_set()) {
59 /* not initialized yet */
60 s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
61 if (s_tlsConnIndex == -1) {
62 throw std::runtime_error("Error getting an index for TLS connection data");
63 }
64 }
65
66 if (!d_conn) {
67 vinfolog("Error creating TLS object");
68 if (g_verbose) {
69 ERR_print_errors_fp(stderr);
70 }
71 throw std::runtime_error("Error creating TLS object");
72 }
73
74 if (!SSL_set_fd(d_conn.get(), d_socket)) {
75 throw std::runtime_error("Error assigning socket");
76 }
77
78 SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this);
79 }
80
81 /* client-side connection */
OpenSSLTLSConnection(const std::string & hostname,int socket,unsigned int timeout,SSL_CTX * tlsCtx)82 OpenSSLTLSConnection(const std::string& hostname, int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_hostname(hostname), d_timeout(timeout)
83 {
84 d_socket = socket;
85
86 if (!d_conn) {
87 vinfolog("Error creating TLS object");
88 if (g_verbose) {
89 ERR_print_errors_fp(stderr);
90 }
91 throw std::runtime_error("Error creating TLS object");
92 }
93
94 if (!SSL_set_fd(d_conn.get(), d_socket)) {
95 throw std::runtime_error("Error assigning socket");
96 }
97
98 #if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && HAVE_SSL_SET_HOSTFLAGS // grrr libressl
99 SSL_set_hostflags(d_conn.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
100 if (SSL_set1_host(d_conn.get(), d_hostname.c_str()) != 1) {
101 throw std::runtime_error("Error setting TLS hostname for certificate validation");
102 }
103 #elif (OPENSSL_VERSION_NUMBER >= 0x10002000L)
104 X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
105 /* Enable automatic hostname checks */
106 X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
107 if (X509_VERIFY_PARAM_set1_host(param, d_hostname.c_str(), d_hostname.size()) != 1) {
108 throw std::runtime_error("Error setting TLS hostname for certificate validation");
109 }
110 #else
111 /* no hostname validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
112 #endif
113 }
114
convertIORequestToIOState(int res) const115 IOState convertIORequestToIOState(int res) const
116 {
117 int error = SSL_get_error(d_conn.get(), res);
118 if (error == SSL_ERROR_WANT_READ) {
119 return IOState::NeedRead;
120 }
121 else if (error == SSL_ERROR_WANT_WRITE) {
122 return IOState::NeedWrite;
123 }
124 else if (error == SSL_ERROR_SYSCALL) {
125 throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno)));
126 }
127 else if (error == SSL_ERROR_ZERO_RETURN) {
128 throw std::runtime_error("TLS connection closed by remote end");
129 }
130 else {
131 if (g_verbose) {
132 throw std::runtime_error("Error while processing TLS connection: (" + std::to_string(error) + ") " + libssl_get_error_string());
133 } else {
134 throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error));
135 }
136 }
137 }
138
handleIORequest(int res,unsigned int timeout)139 void handleIORequest(int res, unsigned int timeout)
140 {
141 auto state = convertIORequestToIOState(res);
142 if (state == IOState::NeedRead) {
143 res = waitForData(d_socket, timeout);
144 if (res == 0) {
145 throw std::runtime_error("Timeout while reading from TLS connection");
146 }
147 else if (res < 0) {
148 throw std::runtime_error("Error waiting to read from TLS connection");
149 }
150 }
151 else if (state == IOState::NeedWrite) {
152 res = waitForRWData(d_socket, false, timeout, 0);
153 if (res == 0) {
154 throw std::runtime_error("Timeout while writing to TLS connection");
155 }
156 else if (res < 0) {
157 throw std::runtime_error("Error waiting to write to TLS connection");
158 }
159 }
160 }
161
tryConnect(bool fastOpen,const ComboAddress & remote)162 IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
163 {
164 /* sorry */
165 (void) fastOpen;
166 (void) remote;
167
168 int res = SSL_connect(d_conn.get());
169 if (res == 1) {
170 return IOState::Done;
171 }
172 else if (res < 0) {
173 return convertIORequestToIOState(res);
174 }
175
176 throw std::runtime_error("Error establishing a TLS connection");
177 }
178
connect(bool fastOpen,const ComboAddress & remote,unsigned int timeout)179 void connect(bool fastOpen, const ComboAddress& remote, unsigned int timeout) override
180 {
181 /* sorry */
182 (void) fastOpen;
183 (void) remote;
184
185 time_t start = 0;
186 unsigned int remainingTime = timeout;
187 if (timeout) {
188 start = time(nullptr);
189 }
190
191 int res = 0;
192 do {
193 res = SSL_connect(d_conn.get());
194 if (res < 0) {
195 handleIORequest(res, remainingTime);
196 }
197
198 if (timeout) {
199 time_t now = time(nullptr);
200 unsigned int elapsed = now - start;
201 if (now < start || elapsed >= remainingTime) {
202 throw runtime_error("Timeout while establishing TLS connection");
203 }
204 start = now;
205 remainingTime -= elapsed;
206 }
207 }
208 while (res != 1);
209 }
210
tryHandshake()211 IOState tryHandshake() override
212 {
213 int res = SSL_accept(d_conn.get());
214 if (res == 1) {
215 return IOState::Done;
216 }
217 else if (res < 0) {
218 return convertIORequestToIOState(res);
219 }
220
221 throw std::runtime_error("Error accepting TLS connection");
222 }
223
doHandshake()224 void doHandshake() override
225 {
226 int res = 0;
227 do {
228 res = SSL_accept(d_conn.get());
229 if (res < 0) {
230 handleIORequest(res, d_timeout);
231 }
232 }
233 while (res < 0);
234
235 if (res != 1) {
236 throw std::runtime_error("Error accepting TLS connection");
237 }
238 }
239
tryWrite(const PacketBuffer & buffer,size_t & pos,size_t toWrite)240 IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
241 {
242 do {
243 int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
244 if (res <= 0) {
245 return convertIORequestToIOState(res);
246 }
247 else {
248 pos += static_cast<size_t>(res);
249 }
250 }
251 while (pos < toWrite);
252 return IOState::Done;
253 }
254
tryRead(PacketBuffer & buffer,size_t & pos,size_t toRead)255 IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
256 {
257 do {
258 int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
259 if (res <= 0) {
260 return convertIORequestToIOState(res);
261 }
262 else {
263 pos += static_cast<size_t>(res);
264 }
265 }
266 while (pos < toRead);
267 return IOState::Done;
268 }
269
read(void * buffer,size_t bufferSize,unsigned int readTimeout,unsigned int totalTimeout)270 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
271 {
272 size_t got = 0;
273 time_t start = 0;
274 unsigned int remainingTime = totalTimeout;
275 if (totalTimeout) {
276 start = time(nullptr);
277 }
278
279 do {
280 int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
281 if (res <= 0) {
282 handleIORequest(res, readTimeout);
283 }
284 else {
285 got += static_cast<size_t>(res);
286 }
287
288 if (totalTimeout) {
289 time_t now = time(nullptr);
290 unsigned int elapsed = now - start;
291 if (now < start || elapsed >= remainingTime) {
292 throw runtime_error("Timeout while reading data");
293 }
294 start = now;
295 remainingTime -= elapsed;
296 }
297 }
298 while (got < bufferSize);
299
300 return got;
301 }
302
write(const void * buffer,size_t bufferSize,unsigned int writeTimeout)303 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
304 {
305 size_t got = 0;
306 do {
307 int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
308 if (res <= 0) {
309 handleIORequest(res, writeTimeout);
310 }
311 else {
312 got += static_cast<size_t>(res);
313 }
314 }
315 while (got < bufferSize);
316
317 return got;
318 }
319
hasBufferedData() const320 bool hasBufferedData() const override
321 {
322 if (d_conn) {
323 return SSL_pending(d_conn.get()) > 0;
324 }
325
326 return false;
327 }
328
close()329 void close() override
330 {
331 if (d_conn) {
332 SSL_shutdown(d_conn.get());
333 }
334 }
335
getServerNameIndication() const336 std::string getServerNameIndication() const override
337 {
338 if (d_conn) {
339 const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
340 if (value) {
341 return std::string(value);
342 }
343 }
344 return std::string();
345 }
346
getTLSVersion() const347 LibsslTLSVersion getTLSVersion() const override
348 {
349 auto proto = SSL_version(d_conn.get());
350 switch (proto) {
351 case TLS1_VERSION:
352 return LibsslTLSVersion::TLS10;
353 case TLS1_1_VERSION:
354 return LibsslTLSVersion::TLS11;
355 case TLS1_2_VERSION:
356 return LibsslTLSVersion::TLS12;
357 #ifdef TLS1_3_VERSION
358 case TLS1_3_VERSION:
359 return LibsslTLSVersion::TLS13;
360 #endif /* TLS1_3_VERSION */
361 default:
362 return LibsslTLSVersion::Unknown;
363 }
364 }
365
hasSessionBeenResumed() const366 bool hasSessionBeenResumed() const override
367 {
368 if (d_conn) {
369 return SSL_session_reused(d_conn.get()) != 0;
370 }
371 return false;
372 }
373
374 static int s_tlsConnIndex;
375
376 private:
377 static std::atomic_flag s_initTLSConnIndex;
378
379 std::shared_ptr<OpenSSLFrontendContext> d_feContext;
380 std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
381 std::string d_hostname;
382 unsigned int d_timeout;
383 };
384
/* Out-of-class definitions of OpenSSLTLSConnection's static members: the flag
   guarding the one-time allocation of the ex-data index, and the index itself,
   which stays at -1 until the first server-side connection allocates it. */
std::atomic_flag OpenSSLTLSConnection::s_initTLSConnIndex = ATOMIC_FLAG_INIT;
int OpenSSLTLSConnection::s_tlsConnIndex = -1;
387
388 class OpenSSLTLSIOCtx: public TLSCtx
389 {
390 public:
391 /* server side context */
OpenSSLTLSIOCtx(TLSFrontend & fe)392 OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig)), d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
393 {
394 d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
395
396 if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) {
397 /* use our own ticket keys handler so we can rotate them */
398 SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
399 libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
400 }
401
402 if (!d_feContext->d_ocspResponses.empty()) {
403 SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
404 SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
405 }
406
407 libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters);
408
409 if (!fe.d_tlsConfig.d_keyLogFile.empty()) {
410 d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile);
411 }
412
413 try {
414 if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
415 handleTicketsKeyRotation(time(nullptr));
416 }
417 else {
418 OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
419 }
420 }
421 catch (const std::exception& e) {
422 throw;
423 }
424 }
425
426 /* client side context */
OpenSSLTLSIOCtx(const TLSContextParameters & params)427 OpenSSLTLSIOCtx(const TLSContextParameters& params): d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
428 {
429 int sslOptions =
430 SSL_OP_NO_SSLv2 |
431 SSL_OP_NO_SSLv3 |
432 SSL_OP_NO_COMPRESSION |
433 SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
434 SSL_OP_SINGLE_DH_USE |
435 SSL_OP_SINGLE_ECDH_USE |
436 SSL_OP_CIPHER_SERVER_PREFERENCE;
437
438 registerOpenSSLUser();
439
440 #ifdef HAVE_TLS_CLIENT_METHOD
441 d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
442 #else
443 d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
444 #endif
445 if (!d_tlsCtx) {
446 ERR_print_errors_fp(stderr);
447 throw std::runtime_error("Error creating TLS context");
448 }
449
450 SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
451 #if defined(SSL_CTX_set_ecdh_auto)
452 SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
453 #endif
454
455 if (!params.d_ciphers.empty()) {
456 if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), params.d_ciphers.c_str()) != 1) {
457 ERR_print_errors_fp(stderr);
458 throw std::runtime_error("Error setting the cipher list to '" + params.d_ciphers + "' for the TLS context");
459 }
460 }
461 #ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
462 if (!params.d_ciphers13.empty()) {
463 if (SSL_CTX_set_ciphersuites(d_tlsCtx.get(), params.d_ciphers13.c_str()) != 1) {
464 ERR_print_errors_fp(stderr);
465 throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + params.d_ciphers13 + "' for the TLS context");
466 }
467 }
468 #endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
469
470 if (params.d_validateCertificates) {
471 if (params.d_caStore.empty()) {
472 if (SSL_CTX_set_default_verify_paths(d_tlsCtx.get()) != 1) {
473 throw std::runtime_error("Error adding the system's default trusted CAs");
474 }
475 } else {
476 if (SSL_CTX_load_verify_locations(d_tlsCtx.get(), params.d_caStore.c_str(), nullptr) != 1) {
477 throw std::runtime_error("Error adding the trusted CAs file " + params.d_caStore);
478 }
479 }
480
481 SSL_CTX_set_verify(d_tlsCtx.get(), SSL_VERIFY_PEER, nullptr);
482 #if (OPENSSL_VERSION_NUMBER < 0x10002000L)
483 warnlog("TLS hostname validation requested but not supported for OpenSSL < 1.0.2");
484 #endif
485 }
486 }
487
~OpenSSLTLSIOCtx()488 ~OpenSSLTLSIOCtx() override
489 {
490 d_tlsCtx.reset();
491 unregisterOpenSSLUser();
492 }
493
ticketKeyCb(SSL * s,unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE],unsigned char * iv,EVP_CIPHER_CTX * ectx,HMAC_CTX * hctx,int enc)494 static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
495 {
496 OpenSSLFrontendContext* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s));
497 if (ctx == nullptr) {
498 return -1;
499 }
500
501 int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc);
502 if (enc == 0) {
503 if (ret == 0 || ret == 2) {
504 OpenSSLTLSConnection* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(s, OpenSSLTLSConnection::s_tlsConnIndex));
505 if (conn) {
506 if (ret == 0) {
507 conn->setUnknownTicketKey();
508 }
509 else if (ret == 2) {
510 conn->setResumedFromInactiveTicketKey();
511 }
512 }
513 }
514 }
515
516 return ret;
517 }
518
ocspStaplingCb(SSL * ssl,void * arg)519 static int ocspStaplingCb(SSL* ssl, void* arg)
520 {
521 if (ssl == nullptr || arg == nullptr) {
522 return SSL_TLSEXT_ERR_NOACK;
523 }
524 const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg);
525 return libssl_ocsp_stapling_callback(ssl, *ocspMap);
526 }
527
getConnection(int socket,unsigned int timeout,time_t now)528 std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
529 {
530 handleTicketsKeyRotation(now);
531
532 return std::make_unique<OpenSSLTLSConnection>(socket, timeout, d_feContext);
533 }
534
getClientConnection(const std::string & host,int socket,unsigned int timeout)535 std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, unsigned int timeout) override
536 {
537 return std::make_unique<OpenSSLTLSConnection>(host, socket, timeout, d_tlsCtx.get());
538 }
539
rotateTicketsKey(time_t now)540 void rotateTicketsKey(time_t now) override
541 {
542 d_feContext->d_ticketKeys.rotateTicketsKey(now);
543
544 if (d_ticketsKeyRotationDelay > 0) {
545 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
546 }
547 }
548
loadTicketsKeys(const std::string & keyFile)549 void loadTicketsKeys(const std::string& keyFile) override final
550 {
551 d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
552
553 if (d_ticketsKeyRotationDelay > 0) {
554 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
555 }
556 }
557
getTicketsKeysCount()558 size_t getTicketsKeysCount() override
559 {
560 return d_feContext->d_ticketKeys.getKeysCount();
561 }
562
563 private:
564 std::shared_ptr<OpenSSLFrontendContext> d_feContext;
565 std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx; // client context
566 };
567
568 #endif /* HAVE_LIBSSL */
569
570 #ifdef HAVE_GNUTLS
571 #include <gnutls/gnutls.h>
572 #include <gnutls/x509.h>
573
/* Locks the size bytes at data in memory through libsodium when it is
   available; does nothing otherwise. */
static void safe_memory_lock(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_mlock(data, size);
#else
  /* nothing to lock the memory with */
  (void) data;
  (void) size;
#endif
}
580
/* Counterpart of safe_memory_lock(): hands the size bytes at data to
   sodium_munlock() when libsodium is available, otherwise zeroes them with the
   first available explicit_bzero / explicit_memset / gnutls_memset, falling
   back to a memset loop guarded by a volatile read. */
static void safe_memory_release(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
  if (size == 0) {
    return;
  }

  volatile unsigned int probeIdx = 0;
  volatile unsigned char* bytes = static_cast<volatile unsigned char*>(data);

  /* reading the result back through volatile accesses keeps the memset
     from being seen as a dead store */
  do {
    memset(data, 0, size);
  } while (bytes[probeIdx] != 0);
#endif
}
604
605 class GnuTLSTicketsKey
606 {
607 public:
GnuTLSTicketsKey()608 GnuTLSTicketsKey()
609 {
610 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
611 throw std::runtime_error("Error generating tickets key for TLS context");
612 }
613
614 safe_memory_lock(d_key.data, d_key.size);
615 }
616
GnuTLSTicketsKey(const std::string & keyFile)617 GnuTLSTicketsKey(const std::string& keyFile)
618 {
619 /* to be sure we are loading the correct amount of data, which
620 may change between versions, let's generate a correct key first */
621 if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
622 throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
623 }
624
625 safe_memory_lock(d_key.data, d_key.size);
626
627 try {
628 ifstream file(keyFile);
629 file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
630
631 if (file.fail()) {
632 file.close();
633 throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
634 }
635
636 file.close();
637 }
638 catch (const std::exception& e) {
639 safe_memory_release(d_key.data, d_key.size);
640 gnutls_free(d_key.data);
641 d_key.data = nullptr;
642 throw;
643 }
644 }
645
~GnuTLSTicketsKey()646 ~GnuTLSTicketsKey()
647 {
648 if (d_key.data != nullptr && d_key.size > 0) {
649 safe_memory_release(d_key.data, d_key.size);
650 }
651 gnutls_free(d_key.data);
652 d_key.data = nullptr;
653 }
getKey() const654 const gnutls_datum_t& getKey() const
655 {
656 return d_key;
657 }
658
659 private:
660 gnutls_datum_t d_key{nullptr, 0};
661 };
662
663 class GnuTLSConnection: public TLSConnection
664 {
665 public:
666 /* server side connection */
GnuTLSConnection(int socket,unsigned int timeout,const gnutls_certificate_credentials_t creds,const gnutls_priority_t priorityCache,std::shared_ptr<GnuTLSTicketsKey> & ticketsKey,bool enableTickets)667 GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
668 {
669 unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
670 #ifdef GNUTLS_NO_SIGNAL
671 sslOptions |= GNUTLS_NO_SIGNAL;
672 #endif
673
674 d_socket = socket;
675
676 gnutls_session_t conn;
677 if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
678 throw std::runtime_error("Error creating TLS connection");
679 }
680
681 d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
682 conn = nullptr;
683
684 if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
685 throw std::runtime_error("Error setting certificate and key to TLS connection");
686 }
687
688 if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
689 throw std::runtime_error("Error setting ciphers to TLS connection");
690 }
691
692 if (enableTickets && d_ticketsKey) {
693 const gnutls_datum_t& key = d_ticketsKey->getKey();
694 if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
695 throw std::runtime_error("Error setting the tickets key to TLS connection");
696 }
697 }
698
699 gnutls_transport_set_int(d_conn.get(), d_socket);
700
701 /* timeouts are in milliseconds */
702 gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
703 gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
704 }
705
706 /* client-side connection */
GnuTLSConnection(const std::string & host,int socket,unsigned int timeout,const gnutls_certificate_credentials_t creds,const gnutls_priority_t priorityCache,bool validateCerts)707 GnuTLSConnection(const std::string& host, int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, bool validateCerts): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host)
708 {
709 unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
710 #ifdef GNUTLS_NO_SIGNAL
711 sslOptions |= GNUTLS_NO_SIGNAL;
712 #endif
713
714 d_socket = socket;
715
716 gnutls_session_t conn;
717 if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
718 throw std::runtime_error("Error creating TLS connection");
719 }
720
721 d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
722 conn = nullptr;
723
724 int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds);
725 if (rc != GNUTLS_E_SUCCESS) {
726 throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc)));
727 }
728
729 rc = gnutls_priority_set(d_conn.get(), priorityCache);
730 if (rc != GNUTLS_E_SUCCESS) {
731 throw std::runtime_error("Error setting ciphers to TLS connection: " + std::string(gnutls_strerror(rc)));
732 }
733
734 gnutls_transport_set_int(d_conn.get(), d_socket);
735
736 /* timeouts are in milliseconds */
737 gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
738 gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
739
740 #if HAVE_GNUTLS_SESSION_SET_VERIFY_CERT
741 if (validateCerts && !d_host.empty()) {
742 gnutls_session_set_verify_cert(d_conn.get(), d_host.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN);
743 rc = gnutls_server_name_set(d_conn.get(), GNUTLS_NAME_DNS, d_host.c_str(), d_host.size());
744 if (rc != GNUTLS_E_SUCCESS) {
745 throw std::runtime_error("Error setting the SNI value to '" + d_host + "' on TLS connection: " + std::string(gnutls_strerror(rc)));
746 }
747 }
748 #else
749 /* no hostname validation for you */
750 #endif
751 }
752
tryConnect(bool fastOpen,const ComboAddress & remote)753 IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
754 {
755 int ret = 0;
756
757 if (fastOpen) {
758 #ifdef HAVE_GNUTLS_TRANSPORT_SET_FASTOPEN
759 gnutls_transport_set_fastopen(d_conn.get(), d_socket, const_cast<struct sockaddr*>(reinterpret_cast<const struct sockaddr*>(&remote)), remote.getSocklen(), 0);
760 #endif
761 }
762
763 do {
764 ret = gnutls_handshake(d_conn.get());
765 if (ret == GNUTLS_E_SUCCESS) {
766 return IOState::Done;
767 }
768 else if (ret == GNUTLS_E_AGAIN) {
769 int direction = gnutls_record_get_direction(d_conn.get());
770 return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
771 }
772 else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
773 throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
774 }
775 } while (ret == GNUTLS_E_INTERRUPTED);
776
777 throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
778 }
779
connect(bool fastOpen,const ComboAddress & remote,unsigned int timeout)780 void connect(bool fastOpen, const ComboAddress& remote, unsigned int timeout) override
781 {
782 time_t start = 0;
783 unsigned int remainingTime = timeout;
784 if (timeout) {
785 start = time(nullptr);
786 }
787
788 IOState state;
789 do {
790 state = tryConnect(fastOpen, remote);
791 if (state == IOState::Done) {
792 return;
793 }
794 else if (state == IOState::NeedRead) {
795 int result = waitForData(d_socket, remainingTime);
796 if (result <= 0) {
797 throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
798 }
799 }
800 else if (state == IOState::NeedWrite) {
801 int result = waitForRWData(d_socket, false, remainingTime, 0);
802 if (result <= 0) {
803 throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
804 }
805 }
806
807 if (timeout) {
808 time_t now = time(nullptr);
809 unsigned int elapsed = now - start;
810 if (now < start || elapsed >= remainingTime) {
811 throw runtime_error("Timeout while establishing TLS connection");
812 }
813 start = now;
814 remainingTime -= elapsed;
815 }
816 }
817 while (state != IOState::Done);
818 }
819
doHandshake()820 void doHandshake() override
821 {
822 int ret = 0;
823 do {
824 ret = gnutls_handshake(d_conn.get());
825 if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
826 throw std::runtime_error("Error accepting a new connection");
827 }
828 }
829 while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
830 }
831
tryHandshake()832 IOState tryHandshake() override
833 {
834 int ret = 0;
835
836 do {
837 ret = gnutls_handshake(d_conn.get());
838 if (ret == GNUTLS_E_SUCCESS) {
839 return IOState::Done;
840 }
841 else if (ret == GNUTLS_E_AGAIN) {
842 return IOState::NeedRead;
843 }
844 else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
845 throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
846 }
847 } while (ret == GNUTLS_E_INTERRUPTED);
848
849 throw std::runtime_error("Error accepting a new connection");
850 }
851
tryWrite(const PacketBuffer & buffer,size_t & pos,size_t toWrite)852 IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
853 {
854 do {
855 ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
856 if (res == 0) {
857 throw std::runtime_error("Error writing to TLS connection");
858 }
859 else if (res > 0) {
860 pos += static_cast<size_t>(res);
861 }
862 else if (res < 0) {
863 if (gnutls_error_is_fatal(res)) {
864 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
865 }
866 else if (res == GNUTLS_E_AGAIN) {
867 return IOState::NeedWrite;
868 }
869 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
870 }
871 }
872 while (pos < toWrite);
873 return IOState::Done;
874 }
875
tryRead(PacketBuffer & buffer,size_t & pos,size_t toRead)876 IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
877 {
878 do {
879 ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
880 if (res == 0) {
881 throw std::runtime_error("EOF while reading from TLS connection");
882 }
883 else if (res > 0) {
884 pos += static_cast<size_t>(res);
885 }
886 else if (res < 0) {
887 if (gnutls_error_is_fatal(res)) {
888 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
889 }
890 else if (res == GNUTLS_E_AGAIN) {
891 return IOState::NeedRead;
892 }
893 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
894 }
895 }
896 while (pos < toRead);
897 return IOState::Done;
898 }
899
read(void * buffer,size_t bufferSize,unsigned int readTimeout,unsigned int totalTimeout)900 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
901 {
902 size_t got = 0;
903 time_t start = 0;
904 unsigned int remainingTime = totalTimeout;
905 if (totalTimeout) {
906 start = time(nullptr);
907 }
908
909 do {
910 ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
911 if (res == 0) {
912 throw std::runtime_error("EOF while reading from TLS connection");
913 }
914 else if (res > 0) {
915 got += static_cast<size_t>(res);
916 }
917 else if (res < 0) {
918 if (gnutls_error_is_fatal(res)) {
919 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
920 }
921 else if (res == GNUTLS_E_AGAIN) {
922 int result = waitForData(d_socket, readTimeout);
923 if (result <= 0) {
924 throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result));
925 }
926 }
927 else {
928 vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
929 }
930 }
931
932 if (totalTimeout) {
933 time_t now = time(nullptr);
934 unsigned int elapsed = now - start;
935 if (now < start || elapsed >= remainingTime) {
936 throw runtime_error("Timeout while reading data");
937 }
938 start = now;
939 remainingTime -= elapsed;
940 }
941 }
942 while (got < bufferSize);
943
944 return got;
945 }
946
write(const void * buffer,size_t bufferSize,unsigned int writeTimeout)947 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
948 {
949 size_t got = 0;
950
951 do {
952 ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
953 if (res == 0) {
954 throw std::runtime_error("Error writing to TLS connection");
955 }
956 else if (res > 0) {
957 got += static_cast<size_t>(res);
958 }
959 else if (res < 0) {
960 if (gnutls_error_is_fatal(res)) {
961 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
962 }
963 else if (res == GNUTLS_E_AGAIN) {
964 int result = waitForRWData(d_socket, false, writeTimeout, 0);
965 if (result <= 0) {
966 throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
967 }
968 }
969 else {
970 vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
971 }
972 }
973 }
974 while (got < bufferSize);
975
976 return got;
977 }
978
hasBufferedData() const979 bool hasBufferedData() const override
980 {
981 if (d_conn) {
982 return gnutls_record_check_pending(d_conn.get()) > 0;
983 }
984
985 return false;
986 }
987
getServerNameIndication() const988 std::string getServerNameIndication() const override
989 {
990 if (d_conn) {
991 unsigned int type;
992 size_t name_len = 256;
993 std::string sni;
994 sni.resize(name_len);
995
996 int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
997 if (res == GNUTLS_E_SUCCESS) {
998 sni.resize(name_len);
999 return sni;
1000 }
1001 }
1002 return std::string();
1003 }
1004
getTLSVersion() const1005 LibsslTLSVersion getTLSVersion() const override
1006 {
1007 auto proto = gnutls_protocol_get_version(d_conn.get());
1008 switch (proto) {
1009 case GNUTLS_TLS1_0:
1010 return LibsslTLSVersion::TLS10;
1011 case GNUTLS_TLS1_1:
1012 return LibsslTLSVersion::TLS11;
1013 case GNUTLS_TLS1_2:
1014 return LibsslTLSVersion::TLS12;
1015 #if GNUTLS_VERSION_NUMBER >= 0x030603
1016 case GNUTLS_TLS1_3:
1017 return LibsslTLSVersion::TLS13;
1018 #endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
1019 default:
1020 return LibsslTLSVersion::Unknown;
1021 }
1022 }
1023
hasSessionBeenResumed() const1024 bool hasSessionBeenResumed() const override
1025 {
1026 if (d_conn) {
1027 return gnutls_session_is_resumed(d_conn.get()) != 0;
1028 }
1029 return false;
1030 }
1031
close()1032 void close() override
1033 {
1034 if (d_conn) {
1035 gnutls_bye(d_conn.get(), GNUTLS_SHUT_RDWR);
1036 }
1037 }
1038
1039 private:
1040 std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
1041 std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
1042 std::string d_host;
1043 };
1044
1045 class GnuTLSIOCtx: public TLSCtx
1046 {
1047 public:
1048 /* server side context */
GnuTLSIOCtx(TLSFrontend & fe)1049 GnuTLSIOCtx(TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_tlsConfig.d_enableTickets)
1050 {
1051 int rc = 0;
1052 d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
1053
1054 gnutls_certificate_credentials_t creds;
1055 rc = gnutls_certificate_allocate_credentials(&creds);
1056 if (rc != GNUTLS_E_SUCCESS) {
1057 throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1058 }
1059
1060 d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
1061 creds = nullptr;
1062
1063 for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) {
1064 rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
1065 if (rc != GNUTLS_E_SUCCESS) {
1066 throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1067 }
1068 }
1069
1070 size_t count = 0;
1071 for (const auto& file : fe.d_tlsConfig.d_ocspFiles) {
1072 rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count);
1073 if (rc != GNUTLS_E_SUCCESS) {
1074 throw std::runtime_error("Error loading OCSP response from file '" + file + "' for certificate ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).first + "') and key ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1075 }
1076 ++count;
1077 }
1078
1079 #if GNUTLS_VERSION_NUMBER >= 0x030600
1080 rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
1081 if (rc != GNUTLS_E_SUCCESS) {
1082 throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1083 }
1084 #endif
1085
1086 rc = gnutls_priority_init(&d_priorityCache, fe.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : fe.d_tlsConfig.d_ciphers.c_str(), nullptr);
1087 if (rc != GNUTLS_E_SUCCESS) {
1088 throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort());
1089 }
1090
1091 try {
1092 if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
1093 handleTicketsKeyRotation(time(nullptr));
1094 }
1095 else {
1096 GnuTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
1097 }
1098 }
1099 catch(const std::runtime_error& e) {
1100 throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
1101 }
1102 }
1103
1104 /* client side context */
GnuTLSIOCtx(const TLSContextParameters & params)1105 GnuTLSIOCtx(const TLSContextParameters& params): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(true), d_validateCerts(params.d_validateCertificates)
1106 {
1107 int rc = 0;
1108
1109 gnutls_certificate_credentials_t creds;
1110 rc = gnutls_certificate_allocate_credentials(&creds);
1111 if (rc != GNUTLS_E_SUCCESS) {
1112 throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
1113 }
1114
1115 d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
1116 creds = nullptr;
1117
1118 if (params.d_validateCertificates) {
1119 if (params.d_caStore.empty()) {
1120 rc = gnutls_certificate_set_x509_system_trust(d_creds.get());
1121 if (rc < 0) {
1122 throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
1123 }
1124 }
1125 else {
1126 rc = gnutls_certificate_set_x509_trust_file(d_creds.get(), params.d_caStore.c_str(), GNUTLS_X509_FMT_PEM);
1127 if (rc < 0) {
1128 throw std::runtime_error("Error adding '" + params.d_caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
1129 }
1130 }
1131 }
1132
1133 rc = gnutls_priority_init(&d_priorityCache, params.d_ciphers.empty() ? "NORMAL" : params.d_ciphers.c_str(), nullptr);
1134 if (rc != GNUTLS_E_SUCCESS) {
1135 throw std::runtime_error("Error setting up TLS cipher preferences to 'NORMAL' (" + std::string(gnutls_strerror(rc)) + ")");
1136 }
1137 }
1138
~GnuTLSIOCtx()1139 virtual ~GnuTLSIOCtx() override
1140 {
1141 d_creds.reset();
1142
1143 if (d_priorityCache) {
1144 gnutls_priority_deinit(d_priorityCache);
1145 }
1146 }
1147
getConnection(int socket,unsigned int timeout,time_t now)1148 std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
1149 {
1150 handleTicketsKeyRotation(now);
1151
1152 std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
1153 {
1154 ReadLock rl(&d_lock);
1155 ticketsKey = d_ticketsKey;
1156 }
1157
1158 return std::make_unique<GnuTLSConnection>(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets);
1159 }
1160
getClientConnection(const std::string & host,int socket,unsigned int timeout)1161 std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, unsigned int timeout) override
1162 {
1163 return std::make_unique<GnuTLSConnection>(host, socket, timeout, d_creds.get(), d_priorityCache, d_validateCerts);
1164 }
1165
rotateTicketsKey(time_t now)1166 void rotateTicketsKey(time_t now) override
1167 {
1168 if (!d_enableTickets) {
1169 return;
1170 }
1171
1172 auto newKey = std::make_shared<GnuTLSTicketsKey>();
1173
1174 {
1175 WriteLock wl(&d_lock);
1176 d_ticketsKey = newKey;
1177 }
1178
1179 if (d_ticketsKeyRotationDelay > 0) {
1180 d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
1181 }
1182 }
1183
loadTicketsKeys(const std::string & file)1184 void loadTicketsKeys(const std::string& file) override final
1185 {
1186 if (!d_enableTickets) {
1187 return;
1188 }
1189
1190 auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
1191 {
1192 WriteLock wl(&d_lock);
1193 d_ticketsKey = newKey;
1194 }
1195
1196 if (d_ticketsKeyRotationDelay > 0) {
1197 d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
1198 }
1199 }
1200
getTicketsKeysCount()1201 size_t getTicketsKeysCount() override
1202 {
1203 ReadLock rl(&d_lock);
1204 return d_ticketsKey != nullptr ? 1 : 0;
1205 }
1206
1207 private:
1208 std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
1209 gnutls_priority_t d_priorityCache{nullptr};
1210 std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
1211 ReadWriteLock d_lock;
1212 bool d_enableTickets{true};
1213 bool d_validateCerts{true};
1214 };
1215
1216 #endif /* HAVE_GNUTLS */
1217
1218 #endif /* HAVE_DNS_OVER_TLS */
1219
setupTLS()1220 bool TLSFrontend::setupTLS()
1221 {
1222 #ifdef HAVE_DNS_OVER_TLS
1223 std::shared_ptr<TLSCtx> newCtx{nullptr};
1224 /* get the "best" available provider */
1225 if (!d_provider.empty()) {
1226 #ifdef HAVE_GNUTLS
1227 if (d_provider == "gnutls") {
1228 newCtx = std::make_shared<GnuTLSIOCtx>(*this);
1229 std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
1230 return true;
1231 }
1232 #endif /* HAVE_GNUTLS */
1233 #ifdef HAVE_LIBSSL
1234 if (d_provider == "openssl") {
1235 newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1236 std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
1237 return true;
1238 }
1239 #endif /* HAVE_LIBSSL */
1240 }
1241 #ifdef HAVE_LIBSSL
1242 newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1243 #else /* HAVE_LIBSSL */
1244 #ifdef HAVE_GNUTLS
1245 newCtx = std::make_shared<GnuTLSIOCtx>(*this);
1246 #endif /* HAVE_GNUTLS */
1247 #endif /* HAVE_LIBSSL */
1248
1249 std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
1250 #endif /* HAVE_DNS_OVER_TLS */
1251 return true;
1252 }
1253
getTLSContext(const TLSContextParameters & params)1254 std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params)
1255 {
1256 #ifdef HAVE_DNS_OVER_TLS
1257 /* get the "best" available provider */
1258 if (!params.d_provider.empty()) {
1259 #ifdef HAVE_GNUTLS
1260 if (params.d_provider == "gnutls") {
1261 return std::make_shared<GnuTLSIOCtx>(params);
1262 }
1263 #endif /* HAVE_GNUTLS */
1264 #ifdef HAVE_LIBSSL
1265 if (params.d_provider == "openssl") {
1266 return std::make_shared<OpenSSLTLSIOCtx>(params);
1267 }
1268 #endif /* HAVE_LIBSSL */
1269 }
1270 #ifdef HAVE_GNUTLS
1271 return std::make_shared<GnuTLSIOCtx>(params);
1272 #else /* HAVE_GNUTLS */
1273 #ifdef HAVE_LIBSSL
1274 return std::make_shared<OpenSSLTLSIOCtx>(params);
1275 #endif /* HAVE_LIBSSL */
1276 #endif /* HAVE_GNUTLS */
1277
1278 #endif /* HAVE_DNS_OVER_TLS */
1279 return nullptr;
1280 }
1281