1 ////////////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright 2006 - 2015, Göteborg Bit Factory.
4 //
5 // Permission is hereby granted, free of charge, to any person obtaining a copy
6 // of this software and associated documentation files (the "Software"), to deal
7 // in the Software without restriction, including without limitation the rights
8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // copies of the Software, and to permit persons to whom the Software is
10 // furnished to do so, subject to the following conditions:
11 //
12 // The above copyright notice and this permission notice shall be included
13 // in all copies or substantial portions of the Software.
14 //
15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
16 // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 // SOFTWARE.
22 //
23 // http://www.opensource.org/licenses/mit-license.php
24 //
25 ////////////////////////////////////////////////////////////////////////////////
26
27 #include <cmake.h>
28
29 #ifdef HAVE_LIBGNUTLS
30
31 #include <iostream>
32 #include <unistd.h>
33 #include <stdio.h>
34 #include <stdlib.h>
35 #include <stdint.h>
36 #include <string.h>
37 #include <sys/socket.h>
38 #include <arpa/inet.h>
39 #if (defined OPENBSD || defined SOLARIS || defined NETBSD)
40 #include <errno.h>
41 #else
42 #include <sys/errno.h>
43 #endif
44 #include <sys/types.h>
45 #include <sys/socket.h>
46 #include <netdb.h>
47 #include <TLSClient.h>
48 #include <gnutls/x509.h>
49 #include <text.h>
50
51 #define MAX_BUF 16384
52
53 static int verify_certificate_callback (gnutls_session_t);
54
55 ////////////////////////////////////////////////////////////////////////////////
// Log sink handed to gnutls by TLSClient::debug(): echoes each gnutls
// diagnostic to std::cout prefixed with "c: <level> ".  No newline is added
// here; presumably gnutls messages already carry their own terminator.
static void gnutls_log_function (int level, const char* message)
{
  std::cout << "c: " << level << ' ' << message;
}
60
61 ////////////////////////////////////////////////////////////////////////////////
verify_certificate_callback(gnutls_session_t session)62 static int verify_certificate_callback (gnutls_session_t session)
63 {
64 const TLSClient* client = (TLSClient*) gnutls_session_get_ptr (session);
65 return client->verify_certificate ();
66 }
67
68 ////////////////////////////////////////////////////////////////////////////////
TLSClient()69 TLSClient::TLSClient ()
70 : _ca ("")
71 , _cert ("")
72 , _key ("")
73 , _host ("")
74 , _port ("")
75 , _session(0)
76 , _socket (0)
77 , _limit (0)
78 , _debug (false)
79 , _trust(strict)
80 {
81 }
82
83 ////////////////////////////////////////////////////////////////////////////////
~TLSClient()84 TLSClient::~TLSClient ()
85 {
86 gnutls_deinit (_session);
87 gnutls_certificate_free_credentials (_credentials);
88 gnutls_global_deinit ();
89
90 if (_socket)
91 {
92 shutdown (_socket, SHUT_RDWR);
93 close (_socket);
94 }
95 }
96
97 ////////////////////////////////////////////////////////////////////////////////
limit(int max)98 void TLSClient::limit (int max)
99 {
100 _limit = max;
101 }
102
103 ////////////////////////////////////////////////////////////////////////////////
104 // Calling this method results in all subsequent socket traffic being sent to
105 // std::cout, labelled with 'c: ...'.
debug(int level)106 void TLSClient::debug (int level)
107 {
108 if (level)
109 _debug = true;
110
111 gnutls_global_set_log_function (gnutls_log_function);
112 gnutls_global_set_log_level (level);
113 }
114
115 ////////////////////////////////////////////////////////////////////////////////
trust(const enum trust_level value)116 void TLSClient::trust (const enum trust_level value)
117 {
118 _trust = value;
119 if (_debug)
120 {
121 if (_trust == allow_all)
122 std::cout << "c: INFO Server certificate will be trusted automatically.\n";
123 else if (_trust == ignore_hostname)
124 std::cout << "c: INFO Server certificate will be verified but hostname ignored.\n";
125 else
126 std::cout << "c: INFO Server certificate will be verified.\n";
127 }
128 }
129
130 ////////////////////////////////////////////////////////////////////////////////
ciphers(const std::string & cipher_list)131 void TLSClient::ciphers (const std::string& cipher_list)
132 {
133 _ciphers = cipher_list;
134 }
135
136 ////////////////////////////////////////////////////////////////////////////////
init(const std::string & ca,const std::string & cert,const std::string & key)137 void TLSClient::init (
138 const std::string& ca,
139 const std::string& cert,
140 const std::string& key)
141 {
142 _ca = ca;
143 _cert = cert;
144 _key = key;
145
146 int ret = gnutls_global_init ();
147 if (ret < 0)
148 throw format ("TLS init error. {1}", gnutls_strerror (ret));
149
150 ret = gnutls_certificate_allocate_credentials (&_credentials);
151 if (ret < 0)
152 throw format ("TLS allocation error. {1}", gnutls_strerror (ret));
153
154 if (_ca != "" &&
155 (ret = gnutls_certificate_set_x509_trust_file (_credentials, _ca.c_str (), GNUTLS_X509_FMT_PEM)) < 0)
156 throw format ("Bad CA file. {1}", gnutls_strerror (ret));
157
158 if (_cert != "" &&
159 _key != "" &&
160 (ret = gnutls_certificate_set_x509_key_file (_credentials, _cert.c_str (), _key.c_str (), GNUTLS_X509_FMT_PEM)) < 0)
161 throw format ("Bad CERT file. {1}", gnutls_strerror (ret));
162
163 #if GNUTLS_VERSION_NUMBER >= 0x02090a
164 // The automatic verification for the server certificate with
165 // gnutls_certificate_set_verify_function only works with gnutls
166 // >=2.9.10. So with older versions we should call the verify function
167 // manually after the gnutls handshake.
168 gnutls_certificate_set_verify_function (_credentials, verify_certificate_callback);
169 #endif
170 ret = gnutls_init (&_session, GNUTLS_CLIENT);
171 if (ret < 0)
172 throw format ("TLS client init error. {1}", gnutls_strerror (ret));
173
174 // Use default priorities unless overridden.
175 if (_ciphers == "")
176 _ciphers = "NORMAL";
177
178 const char *err;
179 ret = gnutls_priority_set_direct (_session, _ciphers.c_str (), &err);
180 if (ret < 0)
181 {
182 if (_debug && ret == GNUTLS_E_INVALID_REQUEST)
183 std::cout << "c: ERROR Priority error at: " << err << "\n";
184
185 throw format ("Error initializing TLS. {1}", gnutls_strerror (ret));
186 }
187
188 // Apply the x509 credentials to the current session.
189 ret = gnutls_credentials_set (_session, GNUTLS_CRD_CERTIFICATE, _credentials);
190 if (ret < 0)
191 throw format ("TLS credentials error. {1}", gnutls_strerror (ret));
192 }
193
194 ////////////////////////////////////////////////////////////////////////////////
connect(const std::string & host,const std::string & port)195 void TLSClient::connect (const std::string& host, const std::string& port)
196 {
197 _host = host;
198 _port = port;
199
200 // Store the TLSClient instance, so that the verification callback can access
201 // it during the handshake below and call the verifcation method.
202 gnutls_session_set_ptr (_session, (void*) this);
203
204 // use IPv4 or IPv6, does not matter.
205 struct addrinfo hints = {0};
206 hints.ai_family = AF_UNSPEC;
207 hints.ai_socktype = SOCK_STREAM;
208 hints.ai_flags = AI_PASSIVE; // use my IP
209
210 struct addrinfo* res;
211 int ret = ::getaddrinfo (host.c_str (), port.c_str (), &hints, &res);
212 if (ret != 0)
213 throw std::string (::gai_strerror (ret));
214
215 // Try them all, stop on success.
216 struct addrinfo* p;
217 for (p = res; p != NULL; p = p->ai_next)
218 {
219 if ((_socket = ::socket (p->ai_family, p->ai_socktype, p->ai_protocol)) == -1)
220 continue;
221
222 // When a socket is closed, it remains unavailable for a while (netstat -an).
223 // Setting SO_REUSEADDR allows this program to assume control of a closed,
224 // but unavailable socket.
225 int on = 1;
226 if (::setsockopt (_socket,
227 SOL_SOCKET,
228 SO_REUSEADDR,
229 (const void*) &on,
230 sizeof (on)) == -1)
231 throw std::string (::strerror (errno));
232
233 if (::connect (_socket, p->ai_addr, p->ai_addrlen) == -1)
234 continue;
235
236 break;
237 }
238
239 free (res);
240
241 if (p == NULL)
242 throw format ("Could not connect to {1} {2}", host, port);
243
244 #if GNUTLS_VERSION_NUMBER >= 0x030109
245 gnutls_transport_set_int (_session, _socket);
246 #else
247 gnutls_transport_set_ptr (_session, (gnutls_transport_ptr_t) (intptr_t) _socket);
248 #endif
249
250 // Perform the TLS handshake
251 do
252 {
253 ret = gnutls_handshake (_session);
254 }
255 while (ret < 0 && gnutls_error_is_fatal (ret) == 0);
256 if (ret < 0)
257 throw format ("Handshake failed. {1}", gnutls_strerror (ret));
258
259 #if GNUTLS_VERSION_NUMBER < 0x02090a
260 // The automatic verification for the server certificate with
261 // gnutls_certificate_set_verify_function does only work with gnutls
262 // >=2.9.10. So with older versions we should call the verify function
263 // manually after the gnutls handshake.
264 ret = verify_certificate ();
265 if (ret < 0)
266 {
267 if (_debug)
268 std::cout << "c: ERROR Certificate verification failed.\n";
269 throw format ("Error Initializing TLS. {1}", gnutls_strerror (ret));
270 }
271 #endif
272
273 if (_debug)
274 {
275 #if GNUTLS_VERSION_NUMBER >= 0x03010a
276 char* desc = gnutls_session_get_desc (_session);
277 std::cout << "c: INFO Handshake was completed: " << desc << "\n";
278 gnutls_free (desc);
279 #else
280 std::cout << "c: INFO Handshake was completed.\n";
281 #endif
282 }
283 }
284
285 ////////////////////////////////////////////////////////////////////////////////
bye()286 void TLSClient::bye ()
287 {
288 gnutls_bye (_session, GNUTLS_SHUT_RDWR);
289 }
290
291 ////////////////////////////////////////////////////////////////////////////////
verify_certificate() const292 int TLSClient::verify_certificate () const
293 {
294 if (_trust == TLSClient::allow_all)
295 return 0;
296
297 if (_debug)
298 std::cout << "c: INFO Verifying certificate.\n";
299
300 // This verification function uses the trusted CAs in the credentials
301 // structure. So you must have installed one or more CA certificates.
302 unsigned int status = 0;
303 const char* hostname = _host.c_str();
304 #if GNUTLS_VERSION_NUMBER >= 0x030104
305 if (_trust == TLSClient::ignore_hostname)
306 hostname = NULL;
307
308 int ret = gnutls_certificate_verify_peers3 (_session, hostname, &status);
309 if (ret < 0)
310 {
311 if (_debug)
312 std::cout << "c: ERROR Certificate verification peers3 failed. " << gnutls_strerror (ret) << "\n";
313 return GNUTLS_E_CERTIFICATE_ERROR;
314 }
315
316 // status 16450 == 0100000001000010
317 // GNUTLS_CERT_INVALID 1<<1
318 // GNUTLS_CERT_SIGNER_NOT_FOUND 1<<6
319 // GNUTLS_CERT_UNEXPECTED_OWNER 1<<14 Hostname does not match
320
321 if (_debug && status)
322 std::cout << "c: ERROR Certificate status=" << status << "\n";
323 #else
324 int ret = gnutls_certificate_verify_peers2 (_session, &status);
325 if (ret < 0)
326 {
327 if (_debug)
328 std::cout << "c: ERROR Certificate verification peers2 failed. " << gnutls_strerror (ret) << "\n";
329 return GNUTLS_E_CERTIFICATE_ERROR;
330 }
331
332 if (_debug && status)
333 std::cout << "c: ERROR Certificate status=" << status << "\n";
334
335 if ((status == 0) && (_trust != TLSClient::ignore_hostname))
336 {
337 if (gnutls_certificate_type_get (_session) == GNUTLS_CRT_X509)
338 {
339 const gnutls_datum* cert_list;
340 unsigned int cert_list_size;
341 gnutls_x509_crt cert;
342
343 cert_list = gnutls_certificate_get_peers (_session, &cert_list_size);
344 if (cert_list_size == 0)
345 {
346 if (_debug)
347 std::cout << "c: ERROR Certificate get peers failed. " << gnutls_strerror (ret) << "\n";
348 return GNUTLS_E_CERTIFICATE_ERROR;
349 }
350
351 ret = gnutls_x509_crt_init (&cert);
352 if (ret < 0)
353 {
354 if (_debug)
355 std::cout << "c: ERROR x509 init failed. " << gnutls_strerror (ret) << "\n";
356 return GNUTLS_E_CERTIFICATE_ERROR;
357 }
358
359 ret = gnutls_x509_crt_import (cert, &cert_list[0], GNUTLS_X509_FMT_DER);
360 if (ret < 0)
361 {
362 if (_debug)
363 std::cout << "c: ERROR x509 cert import. " << gnutls_strerror (ret) << "\n";
364 gnutls_x509_crt_deinit(cert);
365 return GNUTLS_E_CERTIFICATE_ERROR;
366 }
367
368 if (gnutls_x509_crt_check_hostname (cert, hostname) == 0)
369 {
370 if (_debug)
371 std::cout << "c: ERROR x509 cert check hostname. " << gnutls_strerror (ret) << "\n";
372 gnutls_x509_crt_deinit(cert);
373 return GNUTLS_E_CERTIFICATE_ERROR;
374 }
375 }
376 else
377 return GNUTLS_E_CERTIFICATE_ERROR;
378 }
379 #endif
380
381 #if GNUTLS_VERSION_NUMBER >= 0x030104
382 gnutls_certificate_type_t type = gnutls_certificate_type_get (_session);
383 gnutls_datum_t out;
384 ret = gnutls_certificate_verification_status_print (status, type, &out, 0);
385 if (ret < 0)
386 {
387 if (_debug)
388 std::cout << "c: ERROR certificate verification status. " << gnutls_strerror (ret) << "\n";
389 return GNUTLS_E_CERTIFICATE_ERROR;
390 }
391
392 if (_debug)
393 std::cout << "c: INFO " << out.data << "\n";
394 gnutls_free (out.data);
395 #endif
396
397 if (status != 0)
398 return GNUTLS_E_CERTIFICATE_ERROR;
399
400 // Continue handshake.
401 return 0;
402 }
403
404 ////////////////////////////////////////////////////////////////////////////////
send(const std::string & data)405 void TLSClient::send (const std::string& data)
406 {
407 std::string packet = "XXXX" + data;
408
409 // Encode the length.
410 unsigned long l = packet.length ();
411 packet[0] = l >>24;
412 packet[1] = l >>16;
413 packet[2] = l >>8;
414 packet[3] = l;
415
416 unsigned int total = 0;
417 unsigned int remaining = packet.length ();
418
419 while (total < packet.length ())
420 {
421 int status;
422 do
423 {
424 status = gnutls_record_send (_session, packet.c_str () + total, remaining);
425 }
426 while (errno == GNUTLS_E_INTERRUPTED ||
427 errno == GNUTLS_E_AGAIN);
428
429 if (status == -1)
430 break;
431
432 total += (unsigned int) status;
433 remaining -= (unsigned int) status;
434 }
435
436 if (_debug)
437 std::cout << "c: INFO Sending 'XXXX"
438 << data.c_str ()
439 << "' (" << total << " bytes)"
440 << std::endl;
441 }
442
443 ////////////////////////////////////////////////////////////////////////////////
recv(std::string & data)444 void TLSClient::recv (std::string& data)
445 {
446 data = ""; // No appending of data.
447 int received = 0;
448
449 // Get the encoded length.
450 unsigned char header[4] = {0};
451 do
452 {
453 received = gnutls_record_recv (_session, header, 4);
454 }
455 while (received > 0 &&
456 (errno == GNUTLS_E_INTERRUPTED ||
457 errno == GNUTLS_E_AGAIN));
458
459 int total = received;
460
461 // Decode the length.
462 unsigned long expected = (header[0]<<24) |
463 (header[1]<<16) |
464 (header[2]<<8) |
465 header[3];
466 if (_debug)
467 std::cout << "c: INFO expecting " << expected << " bytes.\n";
468
469 // TODO This would be a good place to assert 'expected < _limit'.
470
471 // Arbitrary buffer size.
472 char buffer[MAX_BUF];
473
474 // Keep reading until no more data. Concatenate chunks of data if a) the
475 // read was interrupted by a signal, and b) if there is more data than
476 // fits in the buffer.
477 do
478 {
479 do
480 {
481 received = gnutls_record_recv (_session, buffer, MAX_BUF - 1);
482 }
483 while (received > 0 &&
484 (errno == GNUTLS_E_INTERRUPTED ||
485 errno == GNUTLS_E_AGAIN));
486
487 // Other end closed the connection.
488 if (received == 0)
489 {
490 if (_debug)
491 std::cout << "c: INFO Peer has closed the TLS connection\n";
492 break;
493 }
494
495 // Something happened.
496 if (received < 0 && gnutls_error_is_fatal (received) == 0)
497 {
498 if (_debug)
499 std::cout << "c: WARNING " << gnutls_strerror (received) << "\n";
500 }
501 else if (received < 0)
502 throw std::string (gnutls_strerror (received));
503
504 buffer [received] = '\0';
505 data += buffer;
506 total += received;
507
508 // Stop at defined limit.
509 if (_limit && total > _limit)
510 break;
511 }
512 while (received > 0 && total < (int) expected);
513
514 if (_debug)
515 std::cout << "c: INFO Receiving 'XXXX"
516 << data.c_str ()
517 << "' (" << total << " bytes)"
518 << std::endl;
519 }
520
521 ////////////////////////////////////////////////////////////////////////////////
522 #endif
523