1 ////////////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright 2006 - 2021, Tomas Babej, Paul Beckingham, Federico Hernandez.
4 //
5 // Permission is hereby granted, free of charge, to any person obtaining a copy
6 // of this software and associated documentation files (the "Software"), to deal
7 // in the Software without restriction, including without limitation the rights
8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // copies of the Software, and to permit persons to whom the Software is
10 // furnished to do so, subject to the following conditions:
11 //
12 // The above copyright notice and this permission notice shall be included
13 // in all copies or substantial portions of the Software.
14 //
15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
16 // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 // SOFTWARE.
22 //
23 // https://www.opensource.org/licenses/mit-license.php
24 //
25 ////////////////////////////////////////////////////////////////////////////////
26 
27 #include <cmake.h>
28 
29 #ifdef HAVE_LIBGNUTLS
30 
31 #include <TLSClient.h>
32 #include <iostream>
33 #include <unistd.h>
34 #include <stdio.h>
35 #include <stdlib.h>
36 #include <stdint.h>
37 #include <string.h>
38 #include <sys/socket.h>
39 #include <arpa/inet.h>
40 #include <errno.h>
41 #include <sys/types.h>
42 #include <netdb.h>
43 #include <gnutls/x509.h>
44 #include <shared.h>
45 #include <format.h>
46 
// Size of the stack buffer that TLSClient::recv reads payload chunks into.
#define MAX_BUF 16384

// Forward declaration of the gnutls certificate-verification callback.  It is
// only compiled for gnutls versions older than 3.4.6 (newer versions verify
// through gnutls_session_set_verify_cert in TLSClient::connect) and not older
// than 2.10.0 (where gnutls_certificate_set_verify_function is available).
#if GNUTLS_VERSION_NUMBER < 0x030406
#if GNUTLS_VERSION_NUMBER >= 0x020a00
static int verify_certificate_callback (gnutls_session_t);
#endif
#endif
54 
55 ////////////////////////////////////////////////////////////////////////////////
// Log sink handed to gnutls_global_set_log_function by TLSClient::debug.
// Echoes each gnutls diagnostic to std::cout with the "c: <level> " prefix
// used throughout this file.  No newline is appended here — presumably gnutls
// terminates its own messages (NOTE(review): confirm).
static void gnutls_log_function (int level, const char* message)
{
  std::ostream& sink = std::cout;
  sink << "c: ";
  sink << level;
  sink << ' ';
  sink << message;
}
60 
61 ////////////////////////////////////////////////////////////////////////////////
#if GNUTLS_VERSION_NUMBER < 0x030406
#if GNUTLS_VERSION_NUMBER >= 0x020a00
// Hook registered with gnutls_certificate_set_verify_function in
// TLSClient::init.  The owning TLSClient is recovered from the session's user
// pointer (stored by TLSClient::connect) and verification is delegated to it;
// its return value (0 = accept, negative = reject) is passed back to gnutls.
static int verify_certificate_callback (gnutls_session_t session)
{
  const auto* const owner = static_cast<const TLSClient*> (gnutls_session_get_ptr (session)); // All
  return owner->verify_certificate ();
}
#endif
#endif
71 
72 ////////////////////////////////////////////////////////////////////////////////
~TLSClient()73 TLSClient::~TLSClient ()
74 {
75   gnutls_deinit (_session); // All
76   gnutls_certificate_free_credentials (_credentials); // All
77 #if GNUTLS_VERSION_NUMBER < 0x030300
78   gnutls_global_deinit (); // All
79 #endif
80 
81   if (_socket)
82   {
83     shutdown (_socket, SHUT_RDWR);
84     close (_socket);
85   }
86 }
87 
88 ////////////////////////////////////////////////////////////////////////////////
limit(int max)89 void TLSClient::limit (int max)
90 {
91   _limit = max;
92 }
93 
94 ////////////////////////////////////////////////////////////////////////////////
95 // Calling this method results in all subsequent socket traffic being sent to
96 // std::cout, labelled with 'c: ...'.
debug(int level)97 void TLSClient::debug (int level)
98 {
99   if (level)
100     _debug = true;
101 
102   gnutls_global_set_log_function (gnutls_log_function); // All
103   gnutls_global_set_log_level (level); // All
104 }
105 
106 ////////////////////////////////////////////////////////////////////////////////
trust(const enum trust_level value)107 void TLSClient::trust (const enum trust_level value)
108 {
109   _trust = value;
110   if (_debug)
111   {
112     if (_trust == allow_all)
113       std::cout << "c: INFO Server certificate will be trusted automatically.\n";
114     else if (_trust == ignore_hostname)
115       std::cout << "c: INFO Server certificate will be verified but hostname ignored.\n";
116     else
117       std::cout << "c: INFO Server certificate will be verified.\n";
118   }
119 }
120 
121 ////////////////////////////////////////////////////////////////////////////////
ciphers(const std::string & cipher_list)122 void TLSClient::ciphers (const std::string& cipher_list)
123 {
124   _ciphers = cipher_list;
125 }
126 
127 ////////////////////////////////////////////////////////////////////////////////
init(const std::string & ca,const std::string & cert,const std::string & key)128 void TLSClient::init (
129   const std::string& ca,
130   const std::string& cert,
131   const std::string& key)
132 {
133   _ca   = ca;
134   _cert = cert;
135   _key  = key;
136 
137   int ret;
138 #if GNUTLS_VERSION_NUMBER < 0x030300
139   ret = gnutls_global_init (); // All
140   if (ret < 0)
141     throw format ("TLS init error. {1}", gnutls_strerror (ret)); // All
142 #endif
143 
144   ret = gnutls_certificate_allocate_credentials (&_credentials); // All
145   if (ret < 0)
146     throw format ("TLS allocation error. {1}", gnutls_strerror (ret)); // All
147 
148 #if GNUTLS_VERSION_NUMBER >= 0x030014
149   // Automatic loading of system installed CA certificates.
150   ret = gnutls_certificate_set_x509_system_trust (_credentials); // 3.0.20
151   if (ret < 0)
152     throw format ("Bad System Trust. {1}", gnutls_strerror (ret)); // All
153 #endif
154 
155   if (_ca != "")
156   {
157     // The gnutls_certificate_set_x509_key_file call returns number of
158     // certificates parsed on success (including 0, when no certificate was
159     // found) and negative values on error
160     ret = gnutls_certificate_set_x509_trust_file (_credentials, _ca.c_str (), GNUTLS_X509_FMT_PEM); // All
161     if (ret == 0)
162       throw format ("CA file {1} contains no certificate.", _ca);
163     else if (ret < 0)
164       throw format ("Bad CA file: {1}", gnutls_strerror (ret)); // All
165 
166   }
167 
168   // TODO This may need 0x030111 protection.
169   if (_cert != "" &&
170       _key != "" &&
171       (ret = gnutls_certificate_set_x509_key_file (_credentials, _cert.c_str (), _key.c_str (), GNUTLS_X509_FMT_PEM)) < 0) // 3.1.11
172     throw format ("Bad client CERT/KEY file. {1}", gnutls_strerror (ret)); // All
173 
174 #if GNUTLS_VERSION_NUMBER < 0x030406
175 #if GNUTLS_VERSION_NUMBER >= 0x020a00
176   // The automatic verification for the server certificate with
177   // gnutls_certificate_set_verify_function only works with gnutls
178   // >=2.9.10. So with older versions we should call the verify function
179   // manually after the gnutls handshake.
180   gnutls_certificate_set_verify_function (_credentials, verify_certificate_callback); // 2.10.0
181 #endif
182 #endif
183   ret = gnutls_init (&_session, GNUTLS_CLIENT); // All
184   if (ret < 0)
185     throw format ("TLS client init error. {1}", gnutls_strerror (ret)); // All
186 
187   // Use default priorities unless overridden.
188   if (_ciphers == "")
189     _ciphers = "NORMAL";
190 
191   const char *err;
192   ret = gnutls_priority_set_direct (_session, _ciphers.c_str (), &err); // All
193   if (ret < 0)
194   {
195     if (_debug && ret == GNUTLS_E_INVALID_REQUEST)
196       std::cout << "c: ERROR Priority error at: " << err << '\n';
197 
198     throw format ("Error initializing TLS. {1}", gnutls_strerror (ret)); // All
199   }
200 
201   // Apply the x509 credentials to the current session.
202   ret = gnutls_credentials_set (_session, GNUTLS_CRD_CERTIFICATE, _credentials); // All
203   if (ret < 0)
204     throw format ("TLS credentials error. {1}", gnutls_strerror (ret)); // All
205 }
206 
207 ////////////////////////////////////////////////////////////////////////////////
connect(const std::string & host,const std::string & port)208 void TLSClient::connect (const std::string& host, const std::string& port)
209 {
210   _host = host;
211   _port = port;
212 
213   int ret;
214 #if GNUTLS_VERSION_NUMBER >= 0x030406
215   // For _trust == TLSClient::allow_all we perform no action
216   if (_trust == TLSClient::ignore_hostname)
217     gnutls_session_set_verify_cert (_session, nullptr, 0); // 3.4.6
218   else if (_trust == TLSClient::strict)
219     gnutls_session_set_verify_cert (_session, _host.c_str (), 0); // 3.4.6
220 #endif
221 
222   // SNI.  Only permitted when _host is a DNS name, not an IPv4/6 address.
223   std::string dummyAddress;
224   int dummyPort;
225   if (! isIPv4Address (_host, dummyAddress, dummyPort) &&
226       ! isIPv6Address (_host, dummyAddress, dummyPort))
227   {
228     ret = gnutls_server_name_set (_session, GNUTLS_NAME_DNS, _host.c_str (), _host.length ()); // All
229     if (ret < 0)
230       throw format ("TLS SNI error. {1}", gnutls_strerror (ret)); // All
231   }
232 
233   // Store the TLSClient instance, so that the verification callback can access
234   // it during the handshake below and call the verification method.
235   gnutls_session_set_ptr (_session, (void*) this); // All
236 
237   // use IPv4 or IPv6, does not matter.
238   struct addrinfo hints {};
239   hints.ai_family   = AF_UNSPEC;
240   hints.ai_socktype = SOCK_STREAM;
241   hints.ai_flags    = AI_PASSIVE; // use my IP
242 
243   struct addrinfo* res;
244   ret = ::getaddrinfo (host.c_str (), port.c_str (), &hints, &res);
245   if (ret != 0)
246     throw std::string (::gai_strerror (ret));
247 
248   // Try them all, stop on success.
249   struct addrinfo* p;
250   for (p = res; p != nullptr; p = p->ai_next)
251   {
252     if ((_socket = ::socket (p->ai_family, p->ai_socktype, p->ai_protocol)) == -1)
253       continue;
254 
255     // When a socket is closed, it remains unavailable for a while (netstat -an).
256     // Setting SO_REUSEADDR allows this program to assume control of a closed,
257     // but unavailable socket.
258     int on = 1;
259     if (::setsockopt (_socket,
260                       SOL_SOCKET,
261                       SO_REUSEADDR,
262                       (const void*) &on,
263                       sizeof (on)) == -1)
264       throw std::string (::strerror (errno));
265 
266     if (::connect (_socket, p->ai_addr, p->ai_addrlen) == -1)
267       continue;
268 
269     break;
270   }
271 
272   free (res);
273 
274   if (p == nullptr)
275     throw format ("Could not connect to {1} {2}", host, port);
276 
277 #if GNUTLS_VERSION_NUMBER >= 0x030100
278   gnutls_handshake_set_timeout (_session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); // 3.1.0
279 #endif
280 
281 #if GNUTLS_VERSION_NUMBER >= 0x030109
282   gnutls_transport_set_int (_session, _socket); // 3.1.9
283 #else
284   gnutls_transport_set_ptr (_session, (gnutls_transport_ptr_t) (intptr_t) _socket); // All
285 #endif
286 
287   // Perform the TLS handshake
288   do
289   {
290     ret = gnutls_handshake (_session); // All
291   }
292   while (ret < 0 && gnutls_error_is_fatal (ret) == 0); // All
293 
294   if (ret < 0)
295   {
296 #if GNUTLS_VERSION_NUMBER >= 0x030406
297     if (ret == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR)
298     {
299       auto type = gnutls_certificate_type_get (_session); // All
300       auto status = gnutls_session_get_verify_cert_status (_session); // 3.4.6
301       gnutls_datum_t out;
302       gnutls_certificate_verification_status_print (status, type, &out, 0);  // 3.1.4
303 
304       std::string error {(const char*) out.data};
305       gnutls_free (out.data); // All
306 
307       throw format ("Handshake failed. {1}", error); // All
308     }
309 #else
310     throw format ("Handshake failed. {1}", gnutls_strerror (ret)); // All
311 #endif
312   }
313 
314 #if GNUTLS_VERSION_NUMBER < 0x020a00
315   // The automatic verification for the server certificate with
316   // gnutls_certificate_set_verify_function does only work with gnutls
317   // >=2.10.0. So with older versions we should call the verify function
318   // manually after the gnutls handshake.
319   ret = verify_certificate ();
320   if (ret < 0)
321   {
322     if (_debug)
323       std::cout << "c: ERROR Certificate verification failed.\n";
324     throw format ("Error initializing TLS. {1}", gnutls_strerror (ret)); // All
325   }
326 #endif
327 
328   if (_debug)
329   {
330 #if GNUTLS_VERSION_NUMBER >= 0x03010a
331     char* desc = gnutls_session_get_desc (_session); // 3.1.10
332     std::cout << "c: INFO Handshake was completed: " << desc << '\n';
333     gnutls_free (desc);
334 #else
335     std::cout << "c: INFO Handshake was completed.\n";
336 #endif
337   }
338 }
339 
340 ////////////////////////////////////////////////////////////////////////////////
bye()341 void TLSClient::bye ()
342 {
343   gnutls_bye (_session, GNUTLS_SHUT_RDWR); // All
344 }
345 
346 ////////////////////////////////////////////////////////////////////////////////
verify_certificate() const347 int TLSClient::verify_certificate () const
348 {
349   if (_trust == TLSClient::allow_all)
350     return 0;
351 
352   if (_debug)
353     std::cout << "c: INFO Verifying certificate.\n";
354 
355   // This verification function uses the trusted CAs in the credentials
356   // structure. So you must have installed one or more CA certificates.
357   unsigned int status = 0;
358   const char* hostname = _host.c_str();
359 #if GNUTLS_VERSION_NUMBER >= 0x030104
360   if (_trust == TLSClient::ignore_hostname)
361     hostname = nullptr;
362 
363   int ret = gnutls_certificate_verify_peers3 (_session, hostname, &status); // 3.1.4
364   if (ret < 0)
365   {
366     if (_debug)
367       std::cout << "c: ERROR Certificate verification peers3 failed. " << gnutls_strerror (ret) << '\n'; // All
368     return GNUTLS_E_CERTIFICATE_ERROR;
369   }
370 
371   // status 16450 == 0100000001000010
372   //   GNUTLS_CERT_INVALID             1<<1
373   //   GNUTLS_CERT_SIGNER_NOT_FOUND    1<<6
374   //   GNUTLS_CERT_UNEXPECTED_OWNER    1<<14  Hostname does not match
375 
376   if (_debug && status)
377     std::cout << "c: ERROR Certificate status=" << status << '\n';
378 #else
379   int ret = gnutls_certificate_verify_peers2 (_session, &status); // All
380   if (ret < 0)
381   {
382     if (_debug)
383       std::cout << "c: ERROR Certificate verification peers2 failed. " << gnutls_strerror (ret) << '\n'; // All
384     return GNUTLS_E_CERTIFICATE_ERROR;
385   }
386 
387   if (_debug && status)
388     std::cout << "c: ERROR Certificate status=" << status << '\n';
389 
390   if ((status == 0) && (_trust != TLSClient::ignore_hostname))
391   {
392     if (gnutls_certificate_type_get (_session) == GNUTLS_CRT_X509) // All
393     {
394       const gnutls_datum* cert_list;
395       unsigned int cert_list_size;
396       gnutls_x509_crt cert;
397 
398       cert_list = gnutls_certificate_get_peers (_session, &cert_list_size); // All
399       if (cert_list_size == 0)
400       {
401         if (_debug)
402           std::cout << "c: ERROR Certificate get peers failed. " << gnutls_strerror (ret) << '\n'; // All
403         return GNUTLS_E_CERTIFICATE_ERROR;
404       }
405 
406       ret = gnutls_x509_crt_init (&cert); // All
407       if (ret < 0)
408       {
409         if (_debug)
410           std::cout << "c: ERROR x509 init failed. " << gnutls_strerror (ret) << '\n'; // All
411         return GNUTLS_E_CERTIFICATE_ERROR;
412       }
413 
414       ret = gnutls_x509_crt_import (cert, &cert_list[0], GNUTLS_X509_FMT_DER); // All
415       if (ret < 0)
416       {
417         if (_debug)
418           std::cout << "c: ERROR x509 cert import. " << gnutls_strerror (ret) << '\n'; // All
419         gnutls_x509_crt_deinit(cert); // All
420         return GNUTLS_E_CERTIFICATE_ERROR;
421       }
422 
423       if (gnutls_x509_crt_check_hostname (cert, hostname) == 0) // All
424       {
425         if (_debug)
426           std::cout << "c: ERROR x509 cert check hostname. " << gnutls_strerror (ret) << '\n'; // All
427         gnutls_x509_crt_deinit(cert);
428         return GNUTLS_E_CERTIFICATE_ERROR;
429       }
430     }
431     else
432       return GNUTLS_E_CERTIFICATE_ERROR;
433   }
434 #endif
435 
436 #if GNUTLS_VERSION_NUMBER >= 0x030104
437   gnutls_certificate_type_t type = gnutls_certificate_type_get (_session); // All
438   gnutls_datum_t out;
439   ret = gnutls_certificate_verification_status_print (status, type, &out, 0); // 3.1.4
440   if (ret < 0)
441   {
442     if (_debug)
443       std::cout << "c: ERROR certificate verification status. " << gnutls_strerror (ret) << '\n'; // All
444     return GNUTLS_E_CERTIFICATE_ERROR;
445   }
446 
447   if (_debug)
448     std::cout << "c: INFO " << out.data << '\n';
449   gnutls_free (out.data);
450 #endif
451 
452   if (status != 0)
453     return GNUTLS_E_CERTIFICATE_ERROR;
454 
455   // Continue handshake.
456   return 0;
457 }
458 
459 ////////////////////////////////////////////////////////////////////////////////
send(const std::string & data)460 void TLSClient::send (const std::string& data)
461 {
462   std::string packet = "XXXX" + data;
463 
464   // Encode the length.
465   unsigned long l = packet.length ();
466   packet[0] = l >>24;
467   packet[1] = l >>16;
468   packet[2] = l >>8;
469   packet[3] = l;
470 
471   unsigned int total = 0;
472   unsigned int remaining = packet.length ();
473 
474   while (total < packet.length ())
475   {
476     int status;
477     do
478     {
479       status = gnutls_record_send (_session, packet.c_str () + total, remaining); // All
480     }
481     while (errno == GNUTLS_E_INTERRUPTED ||
482            errno == GNUTLS_E_AGAIN);
483 
484     if (status == -1)
485       break;
486 
487     total     += (unsigned int) status;
488     remaining -= (unsigned int) status;
489   }
490 
491   if (_debug)
492     std::cout << "c: INFO Sending 'XXXX"
493               << data.c_str ()
494               << "' (" << total << " bytes)"
495               << std::endl;
496 }
497 
498 ////////////////////////////////////////////////////////////////////////////////
recv(std::string & data)499 void TLSClient::recv (std::string& data)
500 {
501   data = "";          // No appending of data.
502   int received = 0;
503 
504   // Get the encoded length.
505   unsigned char header[4] {};
506   do
507   {
508     received = gnutls_record_recv (_session, header, 4); // All
509   }
510   while (received > 0 &&
511          (errno == GNUTLS_E_INTERRUPTED ||
512           errno == GNUTLS_E_AGAIN));
513 
514   int total = received;
515 
516   // Decode the length.
517   unsigned long expected = (header[0]<<24) |
518                            (header[1]<<16) |
519                            (header[2]<<8) |
520                             header[3];
521   if (_debug)
522     std::cout << "c: INFO expecting " << expected << " bytes.\n";
523 
524   // TODO This would be a good place to assert 'expected < _limit'.
525 
526   // Arbitrary buffer size.
527   char buffer[MAX_BUF];
528 
529   // Keep reading until no more data.  Concatenate chunks of data if a) the
530   // read was interrupted by a signal, and b) if there is more data than
531   // fits in the buffer.
532   do
533   {
534     do
535     {
536       received = gnutls_record_recv (_session, buffer, MAX_BUF - 1); // All
537     }
538     while (received > 0 &&
539            (errno == GNUTLS_E_INTERRUPTED ||
540             errno == GNUTLS_E_AGAIN));
541 
542     // Other end closed the connection.
543     if (received == 0)
544     {
545       if (_debug)
546         std::cout << "c: INFO Peer has closed the TLS connection\n";
547       break;
548     }
549 
550     // Something happened.
551     if (received < 0 && gnutls_error_is_fatal (received) == 0) // All
552     {
553       if (_debug)
554         std::cout << "c: WARNING " << gnutls_strerror (received) << '\n'; // All
555     }
556     else if (received < 0)
557       throw std::string (gnutls_strerror (received)); // All
558 
559     buffer [received] = '\0';
560     data += buffer;
561     total += received;
562 
563     // Stop at defined limit.
564     if (_limit && total > _limit)
565       break;
566   }
567   while (received > 0 && total < (int) expected);
568 
569   if (_debug)
570     std::cout << "c: INFO Receiving 'XXXX"
571               << data.c_str ()
572               << "' (" << total << " bytes)"
573               << std::endl;
574 }
575 
576 ////////////////////////////////////////////////////////////////////////////////
577 #endif
578