1 ////////////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright 2006 - 2015, Göteborg Bit Factory.
4 //
5 // Permission is hereby granted, free of charge, to any person obtaining a copy
6 // of this software and associated documentation files (the "Software"), to deal
7 // in the Software without restriction, including without limitation the rights
8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // copies of the Software, and to permit persons to whom the Software is
10 // furnished to do so, subject to the following conditions:
11 //
12 // The above copyright notice and this permission notice shall be included
13 // in all copies or substantial portions of the Software.
14 //
15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
16 // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 // SOFTWARE.
22 //
23 // http://www.opensource.org/licenses/mit-license.php
24 //
25 ////////////////////////////////////////////////////////////////////////////////
26 
27 #include <cmake.h>
28 
29 #ifdef HAVE_LIBGNUTLS
30 
31 #include <iostream>
32 #include <unistd.h>
33 #include <stdio.h>
34 #include <stdlib.h>
35 #include <stdint.h>
36 #include <string.h>
37 #include <sys/socket.h>
38 #include <arpa/inet.h>
39 #if (defined OPENBSD || defined SOLARIS || defined NETBSD)
40 #include <errno.h>
41 #else
42 #include <sys/errno.h>
43 #endif
44 #include <sys/types.h>
45 #include <sys/socket.h>
46 #include <netdb.h>
47 #include <TLSClient.h>
48 #include <gnutls/x509.h>
49 #include <text.h>
50 
// Size of the stack buffer recv() uses when reading TLS records.
#define MAX_BUF 16384

// Certificate-verification hook that init() registers with gnutls (gnutls
// >= 2.9.10); it is defined further down this file.
static int verify_certificate_callback (gnutls_session_t);
54 
55 ////////////////////////////////////////////////////////////////////////////////
// Log sink installed by TLSClient::debug(): echoes each gnutls log message to
// std::cout, prefixed with "c: " and the message's log level.  The message is
// written verbatim; no newline is appended here.
static void gnutls_log_function (int level, const char* message)
{
  std::ostream& sink = std::cout;
  sink << "c: " << level;
  sink << ' ' << message;
}
60 
61 ////////////////////////////////////////////////////////////////////////////////
verify_certificate_callback(gnutls_session_t session)62 static int verify_certificate_callback (gnutls_session_t session)
63 {
64   const TLSClient* client = (TLSClient*) gnutls_session_get_ptr (session);
65   return client->verify_certificate ();
66 }
67 
68 ////////////////////////////////////////////////////////////////////////////////
TLSClient()69 TLSClient::TLSClient ()
70 : _ca ("")
71 , _cert ("")
72 , _key ("")
73 , _host ("")
74 , _port ("")
75 , _session(0)
76 , _socket (0)
77 , _limit (0)
78 , _debug (false)
79 , _trust(strict)
80 {
81 }
82 
83 ////////////////////////////////////////////////////////////////////////////////
~TLSClient()84 TLSClient::~TLSClient ()
85 {
86   gnutls_deinit (_session);
87   gnutls_certificate_free_credentials (_credentials);
88   gnutls_global_deinit ();
89 
90   if (_socket)
91   {
92     shutdown (_socket, SHUT_RDWR);
93     close (_socket);
94   }
95 }
96 
97 ////////////////////////////////////////////////////////////////////////////////
limit(int max)98 void TLSClient::limit (int max)
99 {
100   _limit = max;
101 }
102 
103 ////////////////////////////////////////////////////////////////////////////////
104 // Calling this method results in all subsequent socket traffic being sent to
105 // std::cout, labelled with 'c: ...'.
debug(int level)106 void TLSClient::debug (int level)
107 {
108   if (level)
109     _debug = true;
110 
111   gnutls_global_set_log_function (gnutls_log_function);
112   gnutls_global_set_log_level (level);
113 }
114 
115 ////////////////////////////////////////////////////////////////////////////////
trust(const enum trust_level value)116 void TLSClient::trust (const enum trust_level value)
117 {
118   _trust = value;
119   if (_debug)
120   {
121     if (_trust == allow_all)
122       std::cout << "c: INFO Server certificate will be trusted automatically.\n";
123     else if (_trust == ignore_hostname)
124       std::cout << "c: INFO Server certificate will be verified but hostname ignored.\n";
125     else
126       std::cout << "c: INFO Server certificate will be verified.\n";
127   }
128 }
129 
130 ////////////////////////////////////////////////////////////////////////////////
ciphers(const std::string & cipher_list)131 void TLSClient::ciphers (const std::string& cipher_list)
132 {
133   _ciphers = cipher_list;
134 }
135 
136 ////////////////////////////////////////////////////////////////////////////////
init(const std::string & ca,const std::string & cert,const std::string & key)137 void TLSClient::init (
138   const std::string& ca,
139   const std::string& cert,
140   const std::string& key)
141 {
142   _ca   = ca;
143   _cert = cert;
144   _key  = key;
145 
146   int ret = gnutls_global_init ();
147   if (ret < 0)
148     throw format ("TLS init error. {1}", gnutls_strerror (ret));
149 
150   ret = gnutls_certificate_allocate_credentials (&_credentials);
151   if (ret < 0)
152     throw format ("TLS allocation error. {1}", gnutls_strerror (ret));
153 
154   if (_ca != "" &&
155       (ret = gnutls_certificate_set_x509_trust_file (_credentials, _ca.c_str (), GNUTLS_X509_FMT_PEM)) < 0)
156     throw format ("Bad CA file. {1}", gnutls_strerror (ret));
157 
158   if (_cert != "" &&
159       _key != "" &&
160       (ret = gnutls_certificate_set_x509_key_file (_credentials, _cert.c_str (), _key.c_str (), GNUTLS_X509_FMT_PEM)) < 0)
161     throw format ("Bad CERT file. {1}", gnutls_strerror (ret));
162 
163 #if GNUTLS_VERSION_NUMBER >= 0x02090a
164   // The automatic verification for the server certificate with
165   // gnutls_certificate_set_verify_function only works with gnutls
166   // >=2.9.10. So with older versions we should call the verify function
167   // manually after the gnutls handshake.
168   gnutls_certificate_set_verify_function (_credentials, verify_certificate_callback);
169 #endif
170   ret = gnutls_init (&_session, GNUTLS_CLIENT);
171   if (ret < 0)
172     throw format ("TLS client init error. {1}", gnutls_strerror (ret));
173 
174   // Use default priorities unless overridden.
175   if (_ciphers == "")
176     _ciphers = "NORMAL";
177 
178   const char *err;
179   ret = gnutls_priority_set_direct (_session, _ciphers.c_str (), &err);
180   if (ret < 0)
181   {
182     if (_debug && ret == GNUTLS_E_INVALID_REQUEST)
183       std::cout << "c: ERROR Priority error at: " << err << "\n";
184 
185     throw format ("Error initializing TLS. {1}", gnutls_strerror (ret));
186   }
187 
188   // Apply the x509 credentials to the current session.
189   ret = gnutls_credentials_set (_session, GNUTLS_CRD_CERTIFICATE, _credentials);
190   if (ret < 0)
191     throw format ("TLS credentials error. {1}", gnutls_strerror (ret));
192 }
193 
194 ////////////////////////////////////////////////////////////////////////////////
connect(const std::string & host,const std::string & port)195 void TLSClient::connect (const std::string& host, const std::string& port)
196 {
197   _host = host;
198   _port = port;
199 
200   // Store the TLSClient instance, so that the verification callback can access
201   // it during the handshake below and call the verifcation method.
202   gnutls_session_set_ptr (_session, (void*) this);
203 
204   // use IPv4 or IPv6, does not matter.
205   struct addrinfo hints = {0};
206   hints.ai_family   = AF_UNSPEC;
207   hints.ai_socktype = SOCK_STREAM;
208   hints.ai_flags    = AI_PASSIVE; // use my IP
209 
210   struct addrinfo* res;
211   int ret = ::getaddrinfo (host.c_str (), port.c_str (), &hints, &res);
212   if (ret != 0)
213     throw std::string (::gai_strerror (ret));
214 
215   // Try them all, stop on success.
216   struct addrinfo* p;
217   for (p = res; p != NULL; p = p->ai_next)
218   {
219     if ((_socket = ::socket (p->ai_family, p->ai_socktype, p->ai_protocol)) == -1)
220       continue;
221 
222     // When a socket is closed, it remains unavailable for a while (netstat -an).
223     // Setting SO_REUSEADDR allows this program to assume control of a closed,
224     // but unavailable socket.
225     int on = 1;
226     if (::setsockopt (_socket,
227                       SOL_SOCKET,
228                       SO_REUSEADDR,
229                       (const void*) &on,
230                       sizeof (on)) == -1)
231       throw std::string (::strerror (errno));
232 
233     if (::connect (_socket, p->ai_addr, p->ai_addrlen) == -1)
234       continue;
235 
236     break;
237   }
238 
239   free (res);
240 
241   if (p == NULL)
242     throw format ("Could not connect to {1} {2}", host, port);
243 
244 #if GNUTLS_VERSION_NUMBER >= 0x030109
245   gnutls_transport_set_int (_session, _socket);
246 #else
247   gnutls_transport_set_ptr (_session, (gnutls_transport_ptr_t) (intptr_t) _socket);
248 #endif
249 
250   // Perform the TLS handshake
251   do
252   {
253     ret = gnutls_handshake (_session);
254   }
255   while (ret < 0 && gnutls_error_is_fatal (ret) == 0);
256   if (ret < 0)
257     throw format ("Handshake failed. {1}", gnutls_strerror (ret));
258 
259 #if GNUTLS_VERSION_NUMBER < 0x02090a
260   // The automatic verification for the server certificate with
261   // gnutls_certificate_set_verify_function does only work with gnutls
262   // >=2.9.10. So with older versions we should call the verify function
263   // manually after the gnutls handshake.
264   ret = verify_certificate ();
265   if (ret < 0)
266   {
267     if (_debug)
268       std::cout << "c: ERROR Certificate verification failed.\n";
269     throw format ("Error Initializing TLS. {1}", gnutls_strerror (ret));
270   }
271 #endif
272 
273   if (_debug)
274   {
275 #if GNUTLS_VERSION_NUMBER >= 0x03010a
276     char* desc = gnutls_session_get_desc (_session);
277     std::cout << "c: INFO Handshake was completed: " << desc << "\n";
278     gnutls_free (desc);
279 #else
280     std::cout << "c: INFO Handshake was completed.\n";
281 #endif
282   }
283 }
284 
285 ////////////////////////////////////////////////////////////////////////////////
bye()286 void TLSClient::bye ()
287 {
288   gnutls_bye (_session, GNUTLS_SHUT_RDWR);
289 }
290 
291 ////////////////////////////////////////////////////////////////////////////////
verify_certificate() const292 int TLSClient::verify_certificate () const
293 {
294   if (_trust == TLSClient::allow_all)
295     return 0;
296 
297   if (_debug)
298     std::cout << "c: INFO Verifying certificate.\n";
299 
300   // This verification function uses the trusted CAs in the credentials
301   // structure. So you must have installed one or more CA certificates.
302   unsigned int status = 0;
303   const char* hostname = _host.c_str();
304 #if GNUTLS_VERSION_NUMBER >= 0x030104
305   if (_trust == TLSClient::ignore_hostname)
306     hostname = NULL;
307 
308   int ret = gnutls_certificate_verify_peers3 (_session, hostname, &status);
309   if (ret < 0)
310   {
311     if (_debug)
312       std::cout << "c: ERROR Certificate verification peers3 failed. " << gnutls_strerror (ret) << "\n";
313     return GNUTLS_E_CERTIFICATE_ERROR;
314   }
315 
316   // status 16450 == 0100000001000010
317   //   GNUTLS_CERT_INVALID             1<<1
318   //   GNUTLS_CERT_SIGNER_NOT_FOUND    1<<6
319   //   GNUTLS_CERT_UNEXPECTED_OWNER    1<<14  Hostname does not match
320 
321   if (_debug && status)
322     std::cout << "c: ERROR Certificate status=" << status << "\n";
323 #else
324   int ret = gnutls_certificate_verify_peers2 (_session, &status);
325   if (ret < 0)
326   {
327     if (_debug)
328       std::cout << "c: ERROR Certificate verification peers2 failed. " << gnutls_strerror (ret) << "\n";
329     return GNUTLS_E_CERTIFICATE_ERROR;
330   }
331 
332   if (_debug && status)
333     std::cout << "c: ERROR Certificate status=" << status << "\n";
334 
335   if ((status == 0) && (_trust != TLSClient::ignore_hostname))
336   {
337     if (gnutls_certificate_type_get (_session) == GNUTLS_CRT_X509)
338     {
339       const gnutls_datum* cert_list;
340       unsigned int cert_list_size;
341       gnutls_x509_crt cert;
342 
343       cert_list = gnutls_certificate_get_peers (_session, &cert_list_size);
344       if (cert_list_size == 0)
345       {
346         if (_debug)
347           std::cout << "c: ERROR Certificate get peers failed. " << gnutls_strerror (ret) << "\n";
348         return GNUTLS_E_CERTIFICATE_ERROR;
349       }
350 
351       ret = gnutls_x509_crt_init (&cert);
352       if (ret < 0)
353       {
354         if (_debug)
355           std::cout << "c: ERROR x509 init failed. " << gnutls_strerror (ret) << "\n";
356         return GNUTLS_E_CERTIFICATE_ERROR;
357       }
358 
359       ret = gnutls_x509_crt_import (cert, &cert_list[0], GNUTLS_X509_FMT_DER);
360       if (ret < 0)
361       {
362         if (_debug)
363           std::cout << "c: ERROR x509 cert import. " << gnutls_strerror (ret) << "\n";
364         gnutls_x509_crt_deinit(cert);
365         return GNUTLS_E_CERTIFICATE_ERROR;
366       }
367 
368       if (gnutls_x509_crt_check_hostname (cert, hostname) == 0)
369       {
370         if (_debug)
371           std::cout << "c: ERROR x509 cert check hostname. " << gnutls_strerror (ret) << "\n";
372         gnutls_x509_crt_deinit(cert);
373         return GNUTLS_E_CERTIFICATE_ERROR;
374       }
375     }
376     else
377       return GNUTLS_E_CERTIFICATE_ERROR;
378   }
379 #endif
380 
381 #if GNUTLS_VERSION_NUMBER >= 0x030104
382   gnutls_certificate_type_t type = gnutls_certificate_type_get (_session);
383   gnutls_datum_t out;
384   ret = gnutls_certificate_verification_status_print (status, type, &out, 0);
385   if (ret < 0)
386   {
387     if (_debug)
388       std::cout << "c: ERROR certificate verification status. " << gnutls_strerror (ret) << "\n";
389     return GNUTLS_E_CERTIFICATE_ERROR;
390   }
391 
392   if (_debug)
393     std::cout << "c: INFO " << out.data << "\n";
394   gnutls_free (out.data);
395 #endif
396 
397   if (status != 0)
398     return GNUTLS_E_CERTIFICATE_ERROR;
399 
400   // Continue handshake.
401   return 0;
402 }
403 
404 ////////////////////////////////////////////////////////////////////////////////
send(const std::string & data)405 void TLSClient::send (const std::string& data)
406 {
407   std::string packet = "XXXX" + data;
408 
409   // Encode the length.
410   unsigned long l = packet.length ();
411   packet[0] = l >>24;
412   packet[1] = l >>16;
413   packet[2] = l >>8;
414   packet[3] = l;
415 
416   unsigned int total = 0;
417   unsigned int remaining = packet.length ();
418 
419   while (total < packet.length ())
420   {
421     int status;
422     do
423     {
424       status = gnutls_record_send (_session, packet.c_str () + total, remaining);
425     }
426     while (errno == GNUTLS_E_INTERRUPTED ||
427            errno == GNUTLS_E_AGAIN);
428 
429     if (status == -1)
430       break;
431 
432     total     += (unsigned int) status;
433     remaining -= (unsigned int) status;
434   }
435 
436   if (_debug)
437     std::cout << "c: INFO Sending 'XXXX"
438               << data.c_str ()
439               << "' (" << total << " bytes)"
440               << std::endl;
441 }
442 
443 ////////////////////////////////////////////////////////////////////////////////
recv(std::string & data)444 void TLSClient::recv (std::string& data)
445 {
446   data = "";          // No appending of data.
447   int received = 0;
448 
449   // Get the encoded length.
450   unsigned char header[4] = {0};
451   do
452   {
453     received = gnutls_record_recv (_session, header, 4);
454   }
455   while (received > 0 &&
456          (errno == GNUTLS_E_INTERRUPTED ||
457           errno == GNUTLS_E_AGAIN));
458 
459   int total = received;
460 
461   // Decode the length.
462   unsigned long expected = (header[0]<<24) |
463                            (header[1]<<16) |
464                            (header[2]<<8) |
465                             header[3];
466   if (_debug)
467     std::cout << "c: INFO expecting " << expected << " bytes.\n";
468 
469   // TODO This would be a good place to assert 'expected < _limit'.
470 
471   // Arbitrary buffer size.
472   char buffer[MAX_BUF];
473 
474   // Keep reading until no more data.  Concatenate chunks of data if a) the
475   // read was interrupted by a signal, and b) if there is more data than
476   // fits in the buffer.
477   do
478   {
479     do
480     {
481       received = gnutls_record_recv (_session, buffer, MAX_BUF - 1);
482     }
483     while (received > 0 &&
484            (errno == GNUTLS_E_INTERRUPTED ||
485             errno == GNUTLS_E_AGAIN));
486 
487     // Other end closed the connection.
488     if (received == 0)
489     {
490       if (_debug)
491         std::cout << "c: INFO Peer has closed the TLS connection\n";
492       break;
493     }
494 
495     // Something happened.
496     if (received < 0 && gnutls_error_is_fatal (received) == 0)
497     {
498       if (_debug)
499       std::cout << "c: WARNING " << gnutls_strerror (received) << "\n";
500     }
501     else if (received < 0)
502       throw std::string (gnutls_strerror (received));
503 
504     buffer [received] = '\0';
505     data += buffer;
506     total += received;
507 
508     // Stop at defined limit.
509     if (_limit && total > _limit)
510       break;
511   }
512   while (received > 0 && total < (int) expected);
513 
514   if (_debug)
515     std::cout << "c: INFO Receiving 'XXXX"
516               << data.c_str ()
517               << "' (" << total << " bytes)"
518               << std::endl;
519 }
520 
521 ////////////////////////////////////////////////////////////////////////////////
522 #endif
523