1 
2 #include "config.h"
3 #include "dolog.hh"
4 #include "iputils.hh"
5 #include "lock.hh"
6 #include "tcpiohandler.hh"
7 
8 #ifdef HAVE_LIBSODIUM
9 #include <sodium.h>
10 #endif /* HAVE_LIBSODIUM */
11 
12 #ifdef HAVE_DNS_OVER_TLS
13 #ifdef HAVE_LIBSSL
14 
15 #include <openssl/conf.h>
16 #include <openssl/err.h>
17 #include <openssl/rand.h>
18 #include <openssl/ssl.h>
19 #include <openssl/x509v3.h>
20 
21 #include "libssl.hh"
22 
23 class OpenSSLFrontendContext
24 {
25 public:
OpenSSLFrontendContext(const ComboAddress & addr,const TLSConfig & tlsConfig)26   OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
27   {
28     registerOpenSSLUser();
29 
30     d_tlsCtx = libssl_init_server_context(tlsConfig, d_ocspResponses);
31     if (!d_tlsCtx) {
32       ERR_print_errors_fp(stderr);
33       throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
34     }
35   }
36 
cleanup()37   void cleanup()
38   {
39     d_tlsCtx.reset();
40 
41     unregisterOpenSSLUser();
42   }
43 
44   OpenSSLTLSTicketKeysRing d_ticketKeys;
45   std::map<int, std::string> d_ocspResponses;
46   std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
47   std::unique_ptr<FILE, int(*)(FILE*)> d_keyLogFile{nullptr, fclose};
48 };
49 
50 class OpenSSLTLSConnection: public TLSConnection
51 {
52 public:
53   /* server side connection */
OpenSSLTLSConnection(int socket,unsigned int timeout,std::shared_ptr<OpenSSLFrontendContext> feContext)54   OpenSSLTLSConnection(int socket, unsigned int timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(feContext), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
55   {
56     d_socket = socket;
57 
58     if (!s_initTLSConnIndex.test_and_set()) {
59       /* not initialized yet */
60       s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
61       if (s_tlsConnIndex == -1) {
62         throw std::runtime_error("Error getting an index for TLS connection data");
63       }
64     }
65 
66     if (!d_conn) {
67       vinfolog("Error creating TLS object");
68       if (g_verbose) {
69         ERR_print_errors_fp(stderr);
70       }
71       throw std::runtime_error("Error creating TLS object");
72     }
73 
74     if (!SSL_set_fd(d_conn.get(), d_socket)) {
75       throw std::runtime_error("Error assigning socket");
76     }
77 
78     SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this);
79   }
80 
81   /* client-side connection */
OpenSSLTLSConnection(const std::string & hostname,int socket,unsigned int timeout,SSL_CTX * tlsCtx)82   OpenSSLTLSConnection(const std::string& hostname, int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_hostname(hostname), d_timeout(timeout)
83   {
84     d_socket = socket;
85 
86     if (!d_conn) {
87       vinfolog("Error creating TLS object");
88       if (g_verbose) {
89         ERR_print_errors_fp(stderr);
90       }
91       throw std::runtime_error("Error creating TLS object");
92     }
93 
94     if (!SSL_set_fd(d_conn.get(), d_socket)) {
95       throw std::runtime_error("Error assigning socket");
96     }
97 
98 #if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && HAVE_SSL_SET_HOSTFLAGS // grrr libressl
99     SSL_set_hostflags(d_conn.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
100     if (SSL_set1_host(d_conn.get(), d_hostname.c_str()) != 1) {
101       throw std::runtime_error("Error setting TLS hostname for certificate validation");
102     }
103 #elif (OPENSSL_VERSION_NUMBER >= 0x10002000L)
104     X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
105     /* Enable automatic hostname checks */
106     X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
107     if (X509_VERIFY_PARAM_set1_host(param, d_hostname.c_str(), d_hostname.size()) != 1) {
108       throw std::runtime_error("Error setting TLS hostname for certificate validation");
109     }
110 #else
111     /* no hostname validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
112 #endif
113   }
114 
convertIORequestToIOState(int res) const115   IOState convertIORequestToIOState(int res) const
116   {
117     int error = SSL_get_error(d_conn.get(), res);
118     if (error == SSL_ERROR_WANT_READ) {
119       return IOState::NeedRead;
120     }
121     else if (error == SSL_ERROR_WANT_WRITE) {
122       return IOState::NeedWrite;
123     }
124     else if (error == SSL_ERROR_SYSCALL) {
125       throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno)));
126     }
127     else if (error == SSL_ERROR_ZERO_RETURN) {
128       throw std::runtime_error("TLS connection closed by remote end");
129     }
130     else {
131       if (g_verbose) {
132         throw std::runtime_error("Error while processing TLS connection: (" + std::to_string(error) + ") " + libssl_get_error_string());
133       } else {
134         throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error));
135       }
136     }
137   }
138 
handleIORequest(int res,unsigned int timeout)139   void handleIORequest(int res, unsigned int timeout)
140   {
141     auto state = convertIORequestToIOState(res);
142     if (state == IOState::NeedRead) {
143       res = waitForData(d_socket, timeout);
144       if (res == 0) {
145         throw std::runtime_error("Timeout while reading from TLS connection");
146       }
147       else if (res < 0) {
148         throw std::runtime_error("Error waiting to read from TLS connection");
149       }
150     }
151     else if (state == IOState::NeedWrite) {
152       res = waitForRWData(d_socket, false, timeout, 0);
153       if (res == 0) {
154         throw std::runtime_error("Timeout while writing to TLS connection");
155       }
156       else if (res < 0) {
157         throw std::runtime_error("Error waiting to write to TLS connection");
158       }
159     }
160   }
161 
tryConnect(bool fastOpen,const ComboAddress & remote)162   IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
163   {
164     /* sorry */
165     (void) fastOpen;
166     (void) remote;
167 
168     int res = SSL_connect(d_conn.get());
169     if (res == 1) {
170       return IOState::Done;
171     }
172     else if (res < 0) {
173       return convertIORequestToIOState(res);
174     }
175 
176     throw std::runtime_error("Error establishing a TLS connection");
177   }
178 
connect(bool fastOpen,const ComboAddress & remote,unsigned int timeout)179   void connect(bool fastOpen, const ComboAddress& remote, unsigned int timeout) override
180   {
181     /* sorry */
182     (void) fastOpen;
183     (void) remote;
184 
185     time_t start = 0;
186     unsigned int remainingTime = timeout;
187     if (timeout) {
188       start = time(nullptr);
189     }
190 
191     int res = 0;
192     do {
193       res = SSL_connect(d_conn.get());
194       if (res < 0) {
195         handleIORequest(res, remainingTime);
196       }
197 
198       if (timeout) {
199         time_t now = time(nullptr);
200         unsigned int elapsed = now - start;
201         if (now < start || elapsed >= remainingTime) {
202           throw runtime_error("Timeout while establishing TLS connection");
203         }
204         start = now;
205         remainingTime -= elapsed;
206       }
207     }
208     while (res != 1);
209   }
210 
tryHandshake()211   IOState tryHandshake() override
212   {
213     int res = SSL_accept(d_conn.get());
214     if (res == 1) {
215       return IOState::Done;
216     }
217     else if (res < 0) {
218       return convertIORequestToIOState(res);
219     }
220 
221     throw std::runtime_error("Error accepting TLS connection");
222   }
223 
doHandshake()224   void doHandshake() override
225   {
226     int res = 0;
227     do {
228       res = SSL_accept(d_conn.get());
229       if (res < 0) {
230         handleIORequest(res, d_timeout);
231       }
232     }
233     while (res < 0);
234 
235     if (res != 1) {
236       throw std::runtime_error("Error accepting TLS connection");
237     }
238   }
239 
tryWrite(const PacketBuffer & buffer,size_t & pos,size_t toWrite)240   IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
241   {
242     do {
243       int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
244       if (res <= 0) {
245         return convertIORequestToIOState(res);
246       }
247       else {
248         pos += static_cast<size_t>(res);
249       }
250     }
251     while (pos < toWrite);
252     return IOState::Done;
253   }
254 
tryRead(PacketBuffer & buffer,size_t & pos,size_t toRead)255   IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
256   {
257     do {
258       int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
259       if (res <= 0) {
260         return convertIORequestToIOState(res);
261       }
262       else {
263         pos += static_cast<size_t>(res);
264       }
265     }
266     while (pos < toRead);
267     return IOState::Done;
268   }
269 
read(void * buffer,size_t bufferSize,unsigned int readTimeout,unsigned int totalTimeout)270   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
271   {
272     size_t got = 0;
273     time_t start = 0;
274     unsigned int remainingTime = totalTimeout;
275     if (totalTimeout) {
276       start = time(nullptr);
277     }
278 
279     do {
280       int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
281       if (res <= 0) {
282         handleIORequest(res, readTimeout);
283       }
284       else {
285         got += static_cast<size_t>(res);
286       }
287 
288       if (totalTimeout) {
289         time_t now = time(nullptr);
290         unsigned int elapsed = now - start;
291         if (now < start || elapsed >= remainingTime) {
292           throw runtime_error("Timeout while reading data");
293         }
294         start = now;
295         remainingTime -= elapsed;
296       }
297     }
298     while (got < bufferSize);
299 
300     return got;
301   }
302 
write(const void * buffer,size_t bufferSize,unsigned int writeTimeout)303   size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
304   {
305     size_t got = 0;
306     do {
307       int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
308       if (res <= 0) {
309         handleIORequest(res, writeTimeout);
310       }
311       else {
312         got += static_cast<size_t>(res);
313       }
314     }
315     while (got < bufferSize);
316 
317     return got;
318   }
319 
hasBufferedData() const320   bool hasBufferedData() const override
321   {
322     if (d_conn) {
323       return SSL_pending(d_conn.get()) > 0;
324     }
325 
326     return false;
327   }
328 
close()329   void close() override
330   {
331     if (d_conn) {
332       SSL_shutdown(d_conn.get());
333     }
334   }
335 
getServerNameIndication() const336   std::string getServerNameIndication() const override
337   {
338     if (d_conn) {
339       const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
340       if (value) {
341         return std::string(value);
342       }
343     }
344     return std::string();
345   }
346 
getTLSVersion() const347   LibsslTLSVersion getTLSVersion() const override
348   {
349     auto proto = SSL_version(d_conn.get());
350     switch (proto) {
351     case TLS1_VERSION:
352       return LibsslTLSVersion::TLS10;
353     case TLS1_1_VERSION:
354       return LibsslTLSVersion::TLS11;
355     case TLS1_2_VERSION:
356       return LibsslTLSVersion::TLS12;
357 #ifdef TLS1_3_VERSION
358     case TLS1_3_VERSION:
359       return LibsslTLSVersion::TLS13;
360 #endif /* TLS1_3_VERSION */
361     default:
362       return LibsslTLSVersion::Unknown;
363     }
364   }
365 
hasSessionBeenResumed() const366   bool hasSessionBeenResumed() const override
367   {
368     if (d_conn) {
369       return SSL_session_reused(d_conn.get()) != 0;
370     }
371     return false;
372   }
373 
374   static int s_tlsConnIndex;
375 
376 private:
377   static std::atomic_flag s_initTLSConnIndex;
378 
379   std::shared_ptr<OpenSSLFrontendContext> d_feContext;
380   std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
381   std::string d_hostname;
382   unsigned int d_timeout;
383 };
384 
385 std::atomic_flag OpenSSLTLSConnection::s_initTLSConnIndex = ATOMIC_FLAG_INIT;
386 int OpenSSLTLSConnection::s_tlsConnIndex = -1;
387 
388 class OpenSSLTLSIOCtx: public TLSCtx
389 {
390 public:
391   /* server side context */
OpenSSLTLSIOCtx(TLSFrontend & fe)392   OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig)), d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
393   {
394     d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
395 
396     if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) {
397       /* use our own ticket keys handler so we can rotate them */
398       SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
399       libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
400     }
401 
402     if (!d_feContext->d_ocspResponses.empty()) {
403       SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
404       SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
405     }
406 
407     libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters);
408 
409     if (!fe.d_tlsConfig.d_keyLogFile.empty()) {
410       d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile);
411     }
412 
413     try {
414       if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
415         handleTicketsKeyRotation(time(nullptr));
416       }
417       else {
418         OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
419       }
420     }
421     catch (const std::exception& e) {
422       throw;
423     }
424   }
425 
426   /* client side context */
OpenSSLTLSIOCtx(const TLSContextParameters & params)427   OpenSSLTLSIOCtx(const TLSContextParameters& params): d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
428   {
429     int sslOptions =
430       SSL_OP_NO_SSLv2 |
431       SSL_OP_NO_SSLv3 |
432       SSL_OP_NO_COMPRESSION |
433       SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
434       SSL_OP_SINGLE_DH_USE |
435       SSL_OP_SINGLE_ECDH_USE |
436       SSL_OP_CIPHER_SERVER_PREFERENCE;
437 
438     registerOpenSSLUser();
439 
440 #ifdef HAVE_TLS_CLIENT_METHOD
441     d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
442 #else
443     d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
444 #endif
445     if (!d_tlsCtx) {
446       ERR_print_errors_fp(stderr);
447       throw std::runtime_error("Error creating TLS context");
448     }
449 
450     SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
451 #if defined(SSL_CTX_set_ecdh_auto)
452     SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
453 #endif
454 
455     if (!params.d_ciphers.empty()) {
456       if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), params.d_ciphers.c_str()) != 1) {
457         ERR_print_errors_fp(stderr);
458         throw std::runtime_error("Error setting the cipher list to '" + params.d_ciphers + "' for the TLS context");
459       }
460     }
461 #ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
462     if (!params.d_ciphers13.empty()) {
463       if (SSL_CTX_set_ciphersuites(d_tlsCtx.get(), params.d_ciphers13.c_str()) != 1) {
464         ERR_print_errors_fp(stderr);
465         throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + params.d_ciphers13 + "' for the TLS context");
466       }
467     }
468 #endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
469 
470     if (params.d_validateCertificates) {
471       if (params.d_caStore.empty())  {
472         if (SSL_CTX_set_default_verify_paths(d_tlsCtx.get()) != 1) {
473           throw std::runtime_error("Error adding the system's default trusted CAs");
474         }
475       } else {
476         if (SSL_CTX_load_verify_locations(d_tlsCtx.get(), params.d_caStore.c_str(), nullptr) != 1) {
477           throw std::runtime_error("Error adding the trusted CAs file " + params.d_caStore);
478         }
479       }
480 
481       SSL_CTX_set_verify(d_tlsCtx.get(), SSL_VERIFY_PEER, nullptr);
482 #if (OPENSSL_VERSION_NUMBER < 0x10002000L)
483       warnlog("TLS hostname validation requested but not supported for OpenSSL < 1.0.2");
484 #endif
485     }
486   }
487 
~OpenSSLTLSIOCtx()488   ~OpenSSLTLSIOCtx() override
489   {
490     d_tlsCtx.reset();
491     unregisterOpenSSLUser();
492   }
493 
ticketKeyCb(SSL * s,unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE],unsigned char * iv,EVP_CIPHER_CTX * ectx,HMAC_CTX * hctx,int enc)494   static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
495   {
496     OpenSSLFrontendContext* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s));
497     if (ctx == nullptr) {
498       return -1;
499     }
500 
501     int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc);
502     if (enc == 0) {
503       if (ret == 0 || ret == 2) {
504         OpenSSLTLSConnection* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(s, OpenSSLTLSConnection::s_tlsConnIndex));
505         if (conn) {
506           if (ret == 0) {
507             conn->setUnknownTicketKey();
508           }
509           else if (ret == 2) {
510             conn->setResumedFromInactiveTicketKey();
511           }
512         }
513       }
514     }
515 
516     return ret;
517   }
518 
ocspStaplingCb(SSL * ssl,void * arg)519   static int ocspStaplingCb(SSL* ssl, void* arg)
520   {
521     if (ssl == nullptr || arg == nullptr) {
522       return SSL_TLSEXT_ERR_NOACK;
523     }
524     const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg);
525     return libssl_ocsp_stapling_callback(ssl, *ocspMap);
526   }
527 
getConnection(int socket,unsigned int timeout,time_t now)528   std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
529   {
530     handleTicketsKeyRotation(now);
531 
532     return std::make_unique<OpenSSLTLSConnection>(socket, timeout, d_feContext);
533   }
534 
getClientConnection(const std::string & host,int socket,unsigned int timeout)535   std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, unsigned int timeout) override
536   {
537     return std::make_unique<OpenSSLTLSConnection>(host, socket, timeout, d_tlsCtx.get());
538   }
539 
rotateTicketsKey(time_t now)540   void rotateTicketsKey(time_t now) override
541   {
542     d_feContext->d_ticketKeys.rotateTicketsKey(now);
543 
544     if (d_ticketsKeyRotationDelay > 0) {
545       d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
546     }
547   }
548 
loadTicketsKeys(const std::string & keyFile)549   void loadTicketsKeys(const std::string& keyFile) override final
550   {
551     d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
552 
553     if (d_ticketsKeyRotationDelay > 0) {
554       d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
555     }
556   }
557 
getTicketsKeysCount()558   size_t getTicketsKeysCount() override
559   {
560     return d_feContext->d_ticketKeys.getKeysCount();
561   }
562 
563 private:
564   std::shared_ptr<OpenSSLFrontendContext> d_feContext;
565   std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx; // client context
566 };
567 
568 #endif /* HAVE_LIBSSL */
569 
570 #ifdef HAVE_GNUTLS
571 #include <gnutls/gnutls.h>
572 #include <gnutls/x509.h>
573 
/* Asks libsodium to lock the `size` bytes starting at `data` in memory when it
   is available; a no-op otherwise. The result of sodium_mlock() is not
   checked, so this is best-effort. */
static void safe_memory_lock(void* data, size_t size)
{
#ifdef HAVE_LIBSODIUM
  sodium_mlock(data, size);
#else
  /* nothing to do without libsodium: silence unused parameter warnings */
  (void) data;
  (void) size;
#endif
}
580 
/* Wipes the `size` bytes starting at `data`. It uses the first available of:
   - sodium_munlock() (libsodium),
   - explicit_bzero(),
   - explicit_memset(),
   - gnutls_memset(),
   - a memset() repeated until a volatile read-back sees a zero byte. */
static void safe_memory_release(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
  if (size == 0) {
    return;
  }

  volatile unsigned int zeroIndex = 0;
  const volatile unsigned char* bytes = static_cast<const volatile unsigned char*>(data);
  do {
    memset(data, 0, size);
  } while (bytes[zeroIndex] != 0);
#endif
}
604 
605 class GnuTLSTicketsKey
606 {
607 public:
GnuTLSTicketsKey()608   GnuTLSTicketsKey()
609   {
610     if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
611       throw std::runtime_error("Error generating tickets key for TLS context");
612     }
613 
614     safe_memory_lock(d_key.data, d_key.size);
615   }
616 
GnuTLSTicketsKey(const std::string & keyFile)617   GnuTLSTicketsKey(const std::string& keyFile)
618   {
619     /* to be sure we are loading the correct amount of data, which
620        may change between versions, let's generate a correct key first */
621     if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
622       throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
623     }
624 
625     safe_memory_lock(d_key.data, d_key.size);
626 
627     try {
628       ifstream file(keyFile);
629       file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
630 
631       if (file.fail()) {
632         file.close();
633         throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
634       }
635 
636       file.close();
637     }
638     catch (const std::exception& e) {
639       safe_memory_release(d_key.data, d_key.size);
640       gnutls_free(d_key.data);
641       d_key.data = nullptr;
642       throw;
643     }
644   }
645 
~GnuTLSTicketsKey()646   ~GnuTLSTicketsKey()
647   {
648     if (d_key.data != nullptr && d_key.size > 0) {
649       safe_memory_release(d_key.data, d_key.size);
650     }
651     gnutls_free(d_key.data);
652     d_key.data = nullptr;
653   }
getKey() const654   const gnutls_datum_t& getKey() const
655   {
656     return d_key;
657   }
658 
659 private:
660   gnutls_datum_t d_key{nullptr, 0};
661 };
662 
663 class GnuTLSConnection: public TLSConnection
664 {
665 public:
666   /* server side connection */
GnuTLSConnection(int socket,unsigned int timeout,const gnutls_certificate_credentials_t creds,const gnutls_priority_t priorityCache,std::shared_ptr<GnuTLSTicketsKey> & ticketsKey,bool enableTickets)667   GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
668   {
669     unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
670 #ifdef GNUTLS_NO_SIGNAL
671     sslOptions |= GNUTLS_NO_SIGNAL;
672 #endif
673 
674     d_socket = socket;
675 
676     gnutls_session_t conn;
677     if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
678       throw std::runtime_error("Error creating TLS connection");
679     }
680 
681     d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
682     conn = nullptr;
683 
684     if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
685       throw std::runtime_error("Error setting certificate and key to TLS connection");
686     }
687 
688     if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
689       throw std::runtime_error("Error setting ciphers to TLS connection");
690     }
691 
692     if (enableTickets && d_ticketsKey) {
693       const gnutls_datum_t& key = d_ticketsKey->getKey();
694       if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
695         throw std::runtime_error("Error setting the tickets key to TLS connection");
696       }
697     }
698 
699     gnutls_transport_set_int(d_conn.get(), d_socket);
700 
701     /* timeouts are in milliseconds */
702     gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
703     gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
704   }
705 
706   /* client-side connection */
GnuTLSConnection(const std::string & host,int socket,unsigned int timeout,const gnutls_certificate_credentials_t creds,const gnutls_priority_t priorityCache,bool validateCerts)707   GnuTLSConnection(const std::string& host, int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, bool validateCerts): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host)
708   {
709     unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
710 #ifdef GNUTLS_NO_SIGNAL
711     sslOptions |= GNUTLS_NO_SIGNAL;
712 #endif
713 
714     d_socket = socket;
715 
716     gnutls_session_t conn;
717     if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
718       throw std::runtime_error("Error creating TLS connection");
719     }
720 
721     d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
722     conn = nullptr;
723 
724     int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds);
725     if (rc != GNUTLS_E_SUCCESS) {
726       throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc)));
727     }
728 
729     rc = gnutls_priority_set(d_conn.get(), priorityCache);
730     if (rc != GNUTLS_E_SUCCESS) {
731       throw std::runtime_error("Error setting ciphers to TLS connection: " + std::string(gnutls_strerror(rc)));
732     }
733 
734     gnutls_transport_set_int(d_conn.get(), d_socket);
735 
736     /* timeouts are in milliseconds */
737     gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
738     gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
739 
740 #if HAVE_GNUTLS_SESSION_SET_VERIFY_CERT
741     if (validateCerts && !d_host.empty()) {
742       gnutls_session_set_verify_cert(d_conn.get(), d_host.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN);
743       rc = gnutls_server_name_set(d_conn.get(), GNUTLS_NAME_DNS, d_host.c_str(), d_host.size());
744       if (rc != GNUTLS_E_SUCCESS) {
745         throw std::runtime_error("Error setting the SNI value to '" + d_host + "' on TLS connection: " + std::string(gnutls_strerror(rc)));
746       }
747     }
748 #else
749     /* no hostname validation for you */
750 #endif
751   }
752 
tryConnect(bool fastOpen,const ComboAddress & remote)753   IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
754   {
755     int ret = 0;
756 
757     if (fastOpen) {
758 #ifdef HAVE_GNUTLS_TRANSPORT_SET_FASTOPEN
759       gnutls_transport_set_fastopen(d_conn.get(), d_socket, const_cast<struct sockaddr*>(reinterpret_cast<const struct sockaddr*>(&remote)), remote.getSocklen(), 0);
760 #endif
761     }
762 
763     do {
764       ret = gnutls_handshake(d_conn.get());
765       if (ret == GNUTLS_E_SUCCESS) {
766         return IOState::Done;
767       }
768       else if (ret == GNUTLS_E_AGAIN) {
769         int direction = gnutls_record_get_direction(d_conn.get());
770         return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
771       }
772       else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
773         throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
774       }
775     } while (ret == GNUTLS_E_INTERRUPTED);
776 
777     throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
778   }
779 
connect(bool fastOpen,const ComboAddress & remote,unsigned int timeout)780   void connect(bool fastOpen, const ComboAddress& remote, unsigned int timeout) override
781   {
782     time_t start = 0;
783     unsigned int remainingTime = timeout;
784     if (timeout) {
785       start = time(nullptr);
786     }
787 
788     IOState state;
789     do {
790       state = tryConnect(fastOpen, remote);
791       if (state == IOState::Done) {
792         return;
793       }
794       else if (state == IOState::NeedRead) {
795         int result = waitForData(d_socket, remainingTime);
796         if (result <= 0) {
797           throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
798         }
799       }
800       else if (state == IOState::NeedWrite) {
801         int result = waitForRWData(d_socket, false, remainingTime, 0);
802         if (result <= 0) {
803           throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
804         }
805       }
806 
807       if (timeout) {
808         time_t now = time(nullptr);
809         unsigned int elapsed = now - start;
810         if (now < start || elapsed >= remainingTime) {
811           throw runtime_error("Timeout while establishing TLS connection");
812         }
813         start = now;
814         remainingTime -= elapsed;
815       }
816     }
817     while (state != IOState::Done);
818   }
819 
doHandshake()820   void doHandshake() override
821   {
822     int ret = 0;
823     do {
824       ret = gnutls_handshake(d_conn.get());
825       if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
826         throw std::runtime_error("Error accepting a new connection");
827       }
828     }
829     while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
830   }
831 
tryHandshake()832   IOState tryHandshake() override
833   {
834     int ret = 0;
835 
836     do {
837       ret = gnutls_handshake(d_conn.get());
838       if (ret == GNUTLS_E_SUCCESS) {
839         return IOState::Done;
840       }
841       else if (ret == GNUTLS_E_AGAIN) {
842         return IOState::NeedRead;
843       }
844       else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
845         throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
846       }
847     } while (ret == GNUTLS_E_INTERRUPTED);
848 
849     throw std::runtime_error("Error accepting a new connection");
850   }
851 
tryWrite(const PacketBuffer & buffer,size_t & pos,size_t toWrite)852   IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
853   {
854     do {
855       ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
856       if (res == 0) {
857         throw std::runtime_error("Error writing to TLS connection");
858       }
859       else if (res > 0) {
860         pos += static_cast<size_t>(res);
861       }
862       else if (res < 0) {
863         if (gnutls_error_is_fatal(res)) {
864           throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
865         }
866         else if (res == GNUTLS_E_AGAIN) {
867           return IOState::NeedWrite;
868         }
869         warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
870       }
871     }
872     while (pos < toWrite);
873     return IOState::Done;
874   }
875 
tryRead(PacketBuffer & buffer,size_t & pos,size_t toRead)876   IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
877   {
878     do {
879       ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
880       if (res == 0) {
881         throw std::runtime_error("EOF while reading from TLS connection");
882       }
883       else if (res > 0) {
884         pos += static_cast<size_t>(res);
885       }
886       else if (res < 0) {
887         if (gnutls_error_is_fatal(res)) {
888           throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
889         }
890         else if (res == GNUTLS_E_AGAIN) {
891           return IOState::NeedRead;
892         }
893         warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
894       }
895     }
896     while (pos < toRead);
897     return IOState::Done;
898   }
899 
read(void * buffer,size_t bufferSize,unsigned int readTimeout,unsigned int totalTimeout)900   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
901   {
902     size_t got = 0;
903     time_t start = 0;
904     unsigned int remainingTime = totalTimeout;
905     if (totalTimeout) {
906       start = time(nullptr);
907     }
908 
909     do {
910       ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
911       if (res == 0) {
912         throw std::runtime_error("EOF while reading from TLS connection");
913       }
914       else if (res > 0) {
915         got += static_cast<size_t>(res);
916       }
917       else if (res < 0) {
918         if (gnutls_error_is_fatal(res)) {
919           throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
920         }
921         else if (res == GNUTLS_E_AGAIN) {
922           int result = waitForData(d_socket, readTimeout);
923           if (result <= 0) {
924             throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result));
925           }
926         }
927         else {
928           vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
929         }
930       }
931 
932       if (totalTimeout) {
933         time_t now = time(nullptr);
934         unsigned int elapsed = now - start;
935         if (now < start || elapsed >= remainingTime) {
936           throw runtime_error("Timeout while reading data");
937         }
938         start = now;
939         remainingTime -= elapsed;
940       }
941     }
942     while (got < bufferSize);
943 
944     return got;
945   }
946 
write(const void * buffer,size_t bufferSize,unsigned int writeTimeout)947   size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
948   {
949     size_t got = 0;
950 
951     do {
952       ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
953       if (res == 0) {
954         throw std::runtime_error("Error writing to TLS connection");
955       }
956       else if (res > 0) {
957         got += static_cast<size_t>(res);
958       }
959       else if (res < 0) {
960         if (gnutls_error_is_fatal(res)) {
961           throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
962         }
963         else if (res == GNUTLS_E_AGAIN) {
964           int result = waitForRWData(d_socket, false, writeTimeout, 0);
965           if (result <= 0) {
966             throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
967           }
968         }
969         else {
970           vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
971         }
972       }
973     }
974     while (got < bufferSize);
975 
976     return got;
977   }
978 
hasBufferedData() const979   bool hasBufferedData() const override
980   {
981     if (d_conn) {
982       return gnutls_record_check_pending(d_conn.get()) > 0;
983     }
984 
985     return false;
986   }
987 
getServerNameIndication() const988   std::string getServerNameIndication() const override
989   {
990     if (d_conn) {
991       unsigned int type;
992       size_t name_len = 256;
993       std::string sni;
994       sni.resize(name_len);
995 
996       int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
997       if (res == GNUTLS_E_SUCCESS) {
998         sni.resize(name_len);
999         return sni;
1000       }
1001     }
1002     return std::string();
1003   }
1004 
getTLSVersion() const1005   LibsslTLSVersion getTLSVersion() const override
1006   {
1007     auto proto = gnutls_protocol_get_version(d_conn.get());
1008     switch (proto) {
1009     case GNUTLS_TLS1_0:
1010       return LibsslTLSVersion::TLS10;
1011     case GNUTLS_TLS1_1:
1012       return LibsslTLSVersion::TLS11;
1013     case GNUTLS_TLS1_2:
1014       return LibsslTLSVersion::TLS12;
1015 #if GNUTLS_VERSION_NUMBER >= 0x030603
1016     case GNUTLS_TLS1_3:
1017       return LibsslTLSVersion::TLS13;
1018 #endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
1019     default:
1020       return LibsslTLSVersion::Unknown;
1021     }
1022   }
1023 
hasSessionBeenResumed() const1024   bool hasSessionBeenResumed() const override
1025   {
1026     if (d_conn) {
1027       return gnutls_session_is_resumed(d_conn.get()) != 0;
1028     }
1029     return false;
1030   }
1031 
close()1032   void close() override
1033   {
1034     if (d_conn) {
1035       gnutls_bye(d_conn.get(), GNUTLS_SHUT_RDWR);
1036     }
1037   }
1038 
1039 private:
1040   std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
1041   std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
1042   std::string d_host;
1043 };
1044 
1045 class GnuTLSIOCtx: public TLSCtx
1046 {
1047 public:
1048   /* server side context */
GnuTLSIOCtx(TLSFrontend & fe)1049   GnuTLSIOCtx(TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_tlsConfig.d_enableTickets)
1050   {
1051     int rc = 0;
1052     d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
1053 
1054     gnutls_certificate_credentials_t creds;
1055     rc = gnutls_certificate_allocate_credentials(&creds);
1056     if (rc != GNUTLS_E_SUCCESS) {
1057       throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1058     }
1059 
1060     d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
1061     creds = nullptr;
1062 
1063     for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) {
1064       rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
1065       if (rc != GNUTLS_E_SUCCESS) {
1066         throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1067       }
1068     }
1069 
1070     size_t count = 0;
1071     for (const auto& file : fe.d_tlsConfig.d_ocspFiles) {
1072       rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count);
1073       if (rc != GNUTLS_E_SUCCESS) {
1074         throw std::runtime_error("Error loading OCSP response from file '" + file + "' for certificate ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).first + "') and key ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1075       }
1076       ++count;
1077     }
1078 
1079 #if GNUTLS_VERSION_NUMBER >= 0x030600
1080     rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
1081     if (rc != GNUTLS_E_SUCCESS) {
1082       throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1083     }
1084 #endif
1085 
1086     rc = gnutls_priority_init(&d_priorityCache, fe.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : fe.d_tlsConfig.d_ciphers.c_str(), nullptr);
1087     if (rc != GNUTLS_E_SUCCESS) {
1088       throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort());
1089     }
1090 
1091     try {
1092       if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
1093         handleTicketsKeyRotation(time(nullptr));
1094       }
1095       else {
1096         GnuTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
1097       }
1098     }
1099     catch(const std::runtime_error& e) {
1100       throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
1101     }
1102   }
1103 
1104   /* client side context */
GnuTLSIOCtx(const TLSContextParameters & params)1105   GnuTLSIOCtx(const TLSContextParameters& params): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(true), d_validateCerts(params.d_validateCertificates)
1106   {
1107     int rc = 0;
1108 
1109     gnutls_certificate_credentials_t creds;
1110     rc = gnutls_certificate_allocate_credentials(&creds);
1111     if (rc != GNUTLS_E_SUCCESS) {
1112       throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
1113     }
1114 
1115     d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
1116     creds = nullptr;
1117 
1118     if (params.d_validateCertificates) {
1119       if (params.d_caStore.empty()) {
1120         rc = gnutls_certificate_set_x509_system_trust(d_creds.get());
1121         if (rc < 0) {
1122           throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
1123         }
1124       }
1125       else {
1126         rc = gnutls_certificate_set_x509_trust_file(d_creds.get(), params.d_caStore.c_str(), GNUTLS_X509_FMT_PEM);
1127         if (rc < 0) {
1128           throw std::runtime_error("Error adding '" + params.d_caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
1129         }
1130       }
1131     }
1132 
1133     rc = gnutls_priority_init(&d_priorityCache, params.d_ciphers.empty() ? "NORMAL" : params.d_ciphers.c_str(), nullptr);
1134     if (rc != GNUTLS_E_SUCCESS) {
1135       throw std::runtime_error("Error setting up TLS cipher preferences to 'NORMAL' (" + std::string(gnutls_strerror(rc)) + ")");
1136     }
1137   }
1138 
~GnuTLSIOCtx()1139   virtual ~GnuTLSIOCtx() override
1140   {
1141     d_creds.reset();
1142 
1143     if (d_priorityCache) {
1144       gnutls_priority_deinit(d_priorityCache);
1145     }
1146   }
1147 
getConnection(int socket,unsigned int timeout,time_t now)1148   std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
1149   {
1150     handleTicketsKeyRotation(now);
1151 
1152     std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
1153     {
1154       ReadLock rl(&d_lock);
1155       ticketsKey = d_ticketsKey;
1156     }
1157 
1158     return std::make_unique<GnuTLSConnection>(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets);
1159   }
1160 
getClientConnection(const std::string & host,int socket,unsigned int timeout)1161   std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, unsigned int timeout) override
1162   {
1163     return std::make_unique<GnuTLSConnection>(host, socket, timeout, d_creds.get(), d_priorityCache, d_validateCerts);
1164   }
1165 
rotateTicketsKey(time_t now)1166   void rotateTicketsKey(time_t now) override
1167   {
1168     if (!d_enableTickets) {
1169       return;
1170     }
1171 
1172     auto newKey = std::make_shared<GnuTLSTicketsKey>();
1173 
1174     {
1175       WriteLock wl(&d_lock);
1176       d_ticketsKey = newKey;
1177     }
1178 
1179     if (d_ticketsKeyRotationDelay > 0) {
1180       d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
1181     }
1182   }
1183 
loadTicketsKeys(const std::string & file)1184   void loadTicketsKeys(const std::string& file) override final
1185   {
1186     if (!d_enableTickets) {
1187       return;
1188     }
1189 
1190     auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
1191     {
1192       WriteLock wl(&d_lock);
1193       d_ticketsKey = newKey;
1194     }
1195 
1196     if (d_ticketsKeyRotationDelay > 0) {
1197       d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
1198     }
1199   }
1200 
getTicketsKeysCount()1201   size_t getTicketsKeysCount() override
1202   {
1203     ReadLock rl(&d_lock);
1204     return d_ticketsKey != nullptr ? 1 : 0;
1205   }
1206 
1207 private:
1208   std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
1209   gnutls_priority_t d_priorityCache{nullptr};
1210   std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
1211   ReadWriteLock d_lock;
1212   bool d_enableTickets{true};
1213   bool d_validateCerts{true};
1214 };
1215 
1216 #endif /* HAVE_GNUTLS */
1217 
1218 #endif /* HAVE_DNS_OVER_TLS */
1219 
setupTLS()1220 bool TLSFrontend::setupTLS()
1221 {
1222 #ifdef HAVE_DNS_OVER_TLS
1223   std::shared_ptr<TLSCtx> newCtx{nullptr};
1224   /* get the "best" available provider */
1225   if (!d_provider.empty()) {
1226 #ifdef HAVE_GNUTLS
1227     if (d_provider == "gnutls") {
1228       newCtx = std::make_shared<GnuTLSIOCtx>(*this);
1229       std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
1230       return true;
1231     }
1232 #endif /* HAVE_GNUTLS */
1233 #ifdef HAVE_LIBSSL
1234     if (d_provider == "openssl") {
1235       newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1236       std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
1237       return true;
1238     }
1239 #endif /* HAVE_LIBSSL */
1240   }
1241 #ifdef HAVE_LIBSSL
1242   newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
1243 #else /* HAVE_LIBSSL */
1244 #ifdef HAVE_GNUTLS
1245   newCtx = std::make_shared<GnuTLSIOCtx>(*this);
1246 #endif /* HAVE_GNUTLS */
1247 #endif /* HAVE_LIBSSL */
1248 
1249   std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
1250 #endif /* HAVE_DNS_OVER_TLS */
1251   return true;
1252 }
1253 
getTLSContext(const TLSContextParameters & params)1254 std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params)
1255 {
1256 #ifdef HAVE_DNS_OVER_TLS
1257   /* get the "best" available provider */
1258   if (!params.d_provider.empty()) {
1259 #ifdef HAVE_GNUTLS
1260     if (params.d_provider == "gnutls") {
1261       return std::make_shared<GnuTLSIOCtx>(params);
1262     }
1263 #endif /* HAVE_GNUTLS */
1264 #ifdef HAVE_LIBSSL
1265     if (params.d_provider == "openssl") {
1266       return std::make_shared<OpenSSLTLSIOCtx>(params);
1267     }
1268 #endif /* HAVE_LIBSSL */
1269   }
1270 #ifdef HAVE_GNUTLS
1271   return std::make_shared<GnuTLSIOCtx>(params);
1272 #else /* HAVE_GNUTLS */
1273 #ifdef HAVE_LIBSSL
1274   return std::make_shared<OpenSSLTLSIOCtx>(params);
1275 #endif /* HAVE_LIBSSL */
1276 #endif /* HAVE_GNUTLS */
1277 
1278 #endif /* HAVE_DNS_OVER_TLS */
1279   return nullptr;
1280 }
1281