1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/tsi/ssl_transport_security.h"
22 #include <sys/socket.h>
23 #include <limits.h>
24 #include <string.h>
25 
26 /* TODO(jboeuf): refactor inet_ntop into a portability header. */
/* Note: to whoever reads this and tries to refactor it: this
   can't be in grpc; it has to be in gpr. */
29 #ifdef GPR_WINDOWS
30 #include <ws2tcpip.h>
31 #else
32 #include <arpa/inet.h>
33 #include <sys/socket.h>
34 #endif
35 
36 #include <string>
37 
38 #include <grpc/grpc_security.h>
39 #include <grpc/support/alloc.h>
40 #include <grpc/support/log.h>
41 #include <grpc/support/string_util.h>
42 #include <grpc/support/sync.h>
43 #include <grpc/support/thd_id.h>
44 
45 #include "absl/strings/match.h"
46 #include "absl/strings/string_view.h"
47 
48 #pragma clang diagnostic push
49 #pragma clang diagnostic ignored "-Wmodule-import-in-extern-c"
50 extern "C" {
51 #include <openssl/bio.h>
52 #include <openssl/crypto.h> /* For OPENSSL_free */
53 #include <openssl/engine.h>
54 #include <openssl/err.h>
55 #include <openssl/ssl.h>
56 #include <openssl/tls1.h>
57 #include <openssl/x509.h>
58 #include <openssl/x509v3.h>
59 }
60 #pragma clang diagnostic pop
61 
62 #include "src/core/lib/gpr/useful.h"
63 #include "src/core/tsi/ssl/session_cache/ssl_session_cache.h"
64 #include "src/core/tsi/ssl_types.h"
65 #include "src/core/tsi/transport_security.h"
66 
67 /* --- Constants. ---*/
68 
/* Upper and lower bounds for the maximum protected frame size.
   NOTE(review): 16384 presumably mirrors the TLS maximum plaintext record
   size (2^14); how the bounds are applied is not visible in this chunk. */
#define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND 16384
#define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND 1024
/* Initial size, presumably of tsi_ssl_handshaker::outgoing_bytes_buffer. */
#define TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE 1024

/* ALPN support is assumed for OpenSSL 1.0.2 (0x10002000L) and later. */
#if OPENSSL_VERSION_NUMBER >= 0x10002000L
#define TSI_OPENSSL_ALPN_SUPPORT 1
#else
#define TSI_OPENSSL_ALPN_SUPPORT 0
#endif

/* Fixed estimate of the bytes added when a frame is protected.
   TODO(jboeuf): I have not found a way to get this number dynamically from the
   SSL structure. This is what we would ultimately want though... */
#define TSI_SSL_MAX_PROTECTION_OVERHEAD 100
82 
83 /* --- Structure definitions. ---*/
84 
/* Wrapper around an OpenSSL X509_STORE holding root certificates. */
struct tsi_ssl_root_certs_store {
  X509_STORE* store;
};
88 
/* Common header embedded as |base| in the client and server handshaker
   factories: a table of factory operations (defined elsewhere) and a
   reference count. */
struct tsi_ssl_handshaker_factory {
  const tsi_ssl_handshaker_factory_vtable* vtable;
  gpr_refcount refcount;
};
93 
/* Client-side handshaker factory.
   NOTE(review): |base| is the first member, presumably so a pointer to this
   struct can be used as a tsi_ssl_handshaker_factory* — confirm against the
   casts elsewhere in this file before reordering members. */
struct tsi_ssl_client_handshaker_factory {
  tsi_ssl_handshaker_factory base;
  /* SSL context used for client connections. */
  SSL_CTX* ssl_context;
  /* ALPN protocol list buffer and its length in bytes. */
  unsigned char* alpn_protocol_list;
  size_t alpn_protocol_list_length;
  /* Ref-counted TLS session cache. */
  grpc_core::RefCountedPtr<tsi::SslSessionLRUCache> session_cache;
};
101 
struct tsi_ssl_server_handshaker_factory {
  /* Several contexts to support SNI.
     The tsi_peer array contains the subject names of the server certificates
     associated with the contexts at the same index.  */
  tsi_ssl_handshaker_factory base;
  /* Array of ssl_context_count SSL contexts. */
  SSL_CTX** ssl_contexts;
  /* Parallel to ssl_contexts: subject names of each context's certificate. */
  tsi_peer* ssl_context_x509_subject_names;
  size_t ssl_context_count;
  /* ALPN protocol list buffer and its length in bytes. */
  unsigned char* alpn_protocol_list;
  size_t alpn_protocol_list_length;
};
113 
/* Per-connection handshaker state. */
struct tsi_ssl_handshaker {
  tsi_handshaker base;
  SSL* ssl;
  /* BIO through which TLS bytes are exchanged with the transport —
     presumably one end of a BIO pair; TODO confirm. */
  BIO* network_io;
  tsi_result result;
  /* Buffer presumably holding bytes to send to the peer, and its capacity. */
  unsigned char* outgoing_bytes_buffer;
  size_t outgoing_bytes_buffer_size;
  /* Factory that created this handshaker. */
  tsi_ssl_handshaker_factory* factory_ref;
};
/* Outcome of a handshake: the SSL connection and its network BIO.
   NOTE(review): |unused_bytes| presumably holds bytes received after the
   handshake that were not consumed by it — confirm. */
struct tsi_ssl_handshaker_result {
  tsi_handshaker_result base;
  SSL* ssl;
  BIO* network_io;
  unsigned char* unused_bytes;
  size_t unused_bytes_size;
};
/* Frame protector wrapping an SSL connection and its network BIO.
   NOTE(review): |buffer| (capacity |buffer_size|) presumably accumulates
   plaintext awaiting protection, with |buffer_offset| bytes in use — confirm
   against the protector implementation (not in this chunk). */
struct tsi_ssl_frame_protector {
  tsi_frame_protector base;
  SSL* ssl;
  BIO* network_io;
  unsigned char* buffer;
  size_t buffer_size;
  size_t buffer_offset;
};
138 /* --- Library Initialization. ---*/
139 
/* One-time initialization guard; presumably used with init_openssl (the
   gpr_once call site is not in this chunk). */
static gpr_once g_init_openssl_once = GPR_ONCE_INIT;
/* SSL_CTX ex_data index reserved by init_openssl(); -1 until then. */
static int g_ssl_ctx_ex_factory_index = -1;
/* Session id context bytes: "grpc". */
static const unsigned char kSslSessionIdContext[] = {'g', 'r', 'p', 'c'};
#ifndef OPENSSL_IS_BORINGSSL
/* Private keys starting with this prefix are loaded through an OpenSSL ENGINE
   using the format "engine:<engine_id>:<key_id>". */
static const char kSslEnginePrefix[] = "engine:";
#endif
146 
147 #if OPENSSL_VERSION_NUMBER < 0x10100000
148 static gpr_mu* g_openssl_mutexes = nullptr;
149 static void openssl_locking_cb(int mode, int type, const char* file,
150                                int line) GRPC_UNUSED;
151 static unsigned long openssl_thread_id_cb(void) GRPC_UNUSED;
152 
openssl_locking_cb(int mode,int type,const char * file,int line)153 static void openssl_locking_cb(int mode, int type, const char* file, int line) {
154   if (mode & CRYPTO_LOCK) {
155     gpr_mu_lock(&g_openssl_mutexes[type]);
156   } else {
157     gpr_mu_unlock(&g_openssl_mutexes[type]);
158   }
159 }
160 
openssl_thread_id_cb(void)161 static unsigned long openssl_thread_id_cb(void) {
162   return static_cast<unsigned long>(gpr_thd_currentid());
163 }
164 #endif
165 
init_openssl(void)166 static void init_openssl(void) {
167 #if OPENSSL_VERSION_NUMBER >= 0x10100000
168   OPENSSL_init_ssl(0, nullptr);
169 #else
170   SSL_library_init();
171   SSL_load_error_strings();
172   OpenSSL_add_all_algorithms();
173 #endif
174 #if OPENSSL_VERSION_NUMBER < 0x10100000
175   if (!CRYPTO_get_locking_callback()) {
176     int num_locks = CRYPTO_num_locks();
177     GPR_ASSERT(num_locks > 0);
178     g_openssl_mutexes = static_cast<gpr_mu*>(
179         gpr_malloc(static_cast<size_t>(num_locks) * sizeof(gpr_mu)));
180     for (int i = 0; i < num_locks; i++) {
181       gpr_mu_init(&g_openssl_mutexes[i]);
182     }
183     CRYPTO_set_locking_callback(openssl_locking_cb);
184     CRYPTO_set_id_callback(openssl_thread_id_cb);
185   } else {
186     gpr_log(GPR_INFO, "OpenSSL callback has already been set.");
187   }
188 #endif
189   g_ssl_ctx_ex_factory_index =
190       SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
191   GPR_ASSERT(g_ssl_ctx_ex_factory_index != -1);
192 }
193 
194 /* --- Ssl utils. ---*/
195 
ssl_error_string(int error)196 static const char* ssl_error_string(int error) {
197   switch (error) {
198     case SSL_ERROR_NONE:
199       return "SSL_ERROR_NONE";
200     case SSL_ERROR_ZERO_RETURN:
201       return "SSL_ERROR_ZERO_RETURN";
202     case SSL_ERROR_WANT_READ:
203       return "SSL_ERROR_WANT_READ";
204     case SSL_ERROR_WANT_WRITE:
205       return "SSL_ERROR_WANT_WRITE";
206     case SSL_ERROR_WANT_CONNECT:
207       return "SSL_ERROR_WANT_CONNECT";
208     case SSL_ERROR_WANT_ACCEPT:
209       return "SSL_ERROR_WANT_ACCEPT";
210     case SSL_ERROR_WANT_X509_LOOKUP:
211       return "SSL_ERROR_WANT_X509_LOOKUP";
212     case SSL_ERROR_SYSCALL:
213       return "SSL_ERROR_SYSCALL";
214     case SSL_ERROR_SSL:
215       return "SSL_ERROR_SSL";
216     default:
217       return "Unknown error";
218   }
219 }
220 
221 /* TODO(jboeuf): Remove when we are past the debugging phase with this code. */
ssl_log_where_info(const SSL * ssl,int where,int flag,const char * msg)222 static void ssl_log_where_info(const SSL* ssl, int where, int flag,
223                                const char* msg) {
224   if ((where & flag) && GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) {
225     gpr_log(GPR_INFO, "%20.20s - %30.30s  - %5.10s", msg,
226             SSL_state_string_long(ssl), SSL_state_string(ssl));
227   }
228 }
229 
230 /* Used for debugging. TODO(jboeuf): Remove when code is mature enough. */
ssl_info_callback(const SSL * ssl,int where,int ret)231 static void ssl_info_callback(const SSL* ssl, int where, int ret) {
232   if (ret == 0) {
233     gpr_log(GPR_ERROR, "ssl_info_callback: error occurred.\n");
234     return;
235   }
236 
237   ssl_log_where_info(ssl, where, SSL_CB_LOOP, "LOOP");
238   ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_START, "HANDSHAKE START");
239   ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_DONE, "HANDSHAKE DONE");
240 }
241 
242 /* Returns 1 if name looks like an IP address, 0 otherwise.
243    This is a very rough heuristic, and only handles IPv6 in hexadecimal form. */
looks_like_ip_address(absl::string_view name)244 static int looks_like_ip_address(absl::string_view name) {
245   size_t dot_count = 0;
246   size_t num_size = 0;
247   for (size_t i = 0; i < name.size(); ++i) {
248     if (name[i] == ':') {
249       /* IPv6 Address in hexadecimal form, : is not allowed in DNS names. */
250       return 1;
251     }
252     if (name[i] >= '0' && name[i] <= '9') {
253       if (num_size > 3) return 0;
254       num_size++;
255     } else if (name[i] == '.') {
256       if (dot_count > 3 || num_size == 0) return 0;
257       dot_count++;
258       num_size = 0;
259     } else {
260       return 0;
261     }
262   }
263   if (dot_count < 3 || num_size == 0) return 0;
264   return 1;
265 }
266 
267 /* Gets the subject CN from an X509 cert. */
ssl_get_x509_common_name(X509 * cert,unsigned char ** utf8,size_t * utf8_size)268 static tsi_result ssl_get_x509_common_name(X509* cert, unsigned char** utf8,
269                                            size_t* utf8_size) {
270   int common_name_index = -1;
271   X509_NAME_ENTRY* common_name_entry = nullptr;
272   ASN1_STRING* common_name_asn1 = nullptr;
273   X509_NAME* subject_name = X509_get_subject_name(cert);
274   int utf8_returned_size = 0;
275   if (subject_name == nullptr) {
276     gpr_log(GPR_INFO, "Could not get subject name from certificate.");
277     return TSI_NOT_FOUND;
278   }
279   common_name_index =
280       X509_NAME_get_index_by_NID(subject_name, NID_commonName, -1);
281   if (common_name_index == -1) {
282     gpr_log(GPR_INFO, "Could not get common name of subject from certificate.");
283     return TSI_NOT_FOUND;
284   }
285   common_name_entry = X509_NAME_get_entry(subject_name, common_name_index);
286   if (common_name_entry == nullptr) {
287     gpr_log(GPR_ERROR, "Could not get common name entry from certificate.");
288     return TSI_INTERNAL_ERROR;
289   }
290   common_name_asn1 = X509_NAME_ENTRY_get_data(common_name_entry);
291   if (common_name_asn1 == nullptr) {
292     gpr_log(GPR_ERROR,
293             "Could not get common name entry asn1 from certificate.");
294     return TSI_INTERNAL_ERROR;
295   }
296   utf8_returned_size = ASN1_STRING_to_UTF8(utf8, common_name_asn1);
297   if (utf8_returned_size < 0) {
298     gpr_log(GPR_ERROR, "Could not extract utf8 from asn1 string.");
299     return TSI_OUT_OF_RESOURCES;
300   }
301   *utf8_size = static_cast<size_t>(utf8_returned_size);
302   return TSI_OK;
303 }
304 
305 /* Gets the subject CN of an X509 cert as a tsi_peer_property. */
peer_property_from_x509_common_name(X509 * cert,tsi_peer_property * property)306 static tsi_result peer_property_from_x509_common_name(
307     X509* cert, tsi_peer_property* property) {
308   unsigned char* common_name;
309   size_t common_name_size;
310   tsi_result result =
311       ssl_get_x509_common_name(cert, &common_name, &common_name_size);
312   if (result != TSI_OK) {
313     if (result == TSI_NOT_FOUND) {
314       common_name = nullptr;
315       common_name_size = 0;
316     } else {
317       return result;
318     }
319   }
320   result = tsi_construct_string_peer_property(
321       TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY,
322       common_name == nullptr ? "" : reinterpret_cast<const char*>(common_name),
323       common_name_size, property);
324   OPENSSL_free(common_name);
325   return result;
326 }
327 
328 /* Gets the X509 cert in PEM format as a tsi_peer_property. */
add_pem_certificate(X509 * cert,tsi_peer_property * property)329 static tsi_result add_pem_certificate(X509* cert, tsi_peer_property* property) {
330   BIO* bio = BIO_new(BIO_s_mem());
331   if (!PEM_write_bio_X509(bio, cert)) {
332     BIO_free(bio);
333     return TSI_INTERNAL_ERROR;
334   }
335   char* contents;
336   long len = BIO_get_mem_data(bio, &contents);
337   if (len <= 0) {
338     BIO_free(bio);
339     return TSI_INTERNAL_ERROR;
340   }
341   tsi_result result = tsi_construct_string_peer_property(
342       TSI_X509_PEM_CERT_PROPERTY, (const char*)contents,
343       static_cast<size_t>(len), property);
344   BIO_free(bio);
345   return result;
346 }
347 
348 /* Gets the subject SANs from an X509 cert as a tsi_peer_property. */
add_subject_alt_names_properties_to_peer(tsi_peer * peer,GENERAL_NAMES * subject_alt_names,size_t subject_alt_name_count,int * current_insert_index)349 static tsi_result add_subject_alt_names_properties_to_peer(
350     tsi_peer* peer, GENERAL_NAMES* subject_alt_names,
351     size_t subject_alt_name_count, int* current_insert_index) {
352   size_t i;
353   tsi_result result = TSI_OK;
354 
355   for (i = 0; i < subject_alt_name_count; i++) {
356     GENERAL_NAME* subject_alt_name =
357         sk_GENERAL_NAME_value(subject_alt_names, TSI_SIZE_AS_SIZE(i));
358     if (subject_alt_name->type == GEN_DNS ||
359         subject_alt_name->type == GEN_EMAIL ||
360         subject_alt_name->type == GEN_URI) {
361       unsigned char* name = nullptr;
362       int name_size;
363       if (subject_alt_name->type == GEN_DNS) {
364         name_size = ASN1_STRING_to_UTF8(&name, subject_alt_name->d.dNSName);
365       } else if (subject_alt_name->type == GEN_EMAIL) {
366         name_size = ASN1_STRING_to_UTF8(&name, subject_alt_name->d.rfc822Name);
367       } else {
368         name_size = ASN1_STRING_to_UTF8(
369             &name, subject_alt_name->d.uniformResourceIdentifier);
370       }
371       if (name_size < 0) {
372         gpr_log(GPR_ERROR, "Could not get utf8 from asn1 string.");
373         result = TSI_INTERNAL_ERROR;
374         break;
375       }
376       result = tsi_construct_string_peer_property(
377           TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY,
378           reinterpret_cast<const char*>(name), static_cast<size_t>(name_size),
379           &peer->properties[(*current_insert_index)++]);
380       if (result != TSI_OK) {
381         OPENSSL_free(name);
382         break;
383       }
384       if (subject_alt_name->type == GEN_URI) {
385         result = tsi_construct_string_peer_property(
386             TSI_X509_URI_PEER_PROPERTY, reinterpret_cast<const char*>(name),
387             static_cast<size_t>(name_size),
388             &peer->properties[(*current_insert_index)++]);
389       }
390       OPENSSL_free(name);
391     } else if (subject_alt_name->type == GEN_IPADD) {
392       char ntop_buf[INET6_ADDRSTRLEN];
393       int af;
394 
395       if (subject_alt_name->d.iPAddress->length == 4) {
396         af = AF_INET;
397       } else if (subject_alt_name->d.iPAddress->length == 16) {
398         af = AF_INET6;
399       } else {
400         gpr_log(GPR_ERROR, "SAN IP Address contained invalid IP");
401         result = TSI_INTERNAL_ERROR;
402         break;
403       }
404       const char* name = inet_ntop(af, subject_alt_name->d.iPAddress->data,
405                                    ntop_buf, INET6_ADDRSTRLEN);
406       if (name == nullptr) {
407         gpr_log(GPR_ERROR, "Could not get IP string from asn1 octet.");
408         result = TSI_INTERNAL_ERROR;
409         break;
410       }
411 
412       result = tsi_construct_string_peer_property_from_cstring(
413           TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, name,
414           &peer->properties[(*current_insert_index)++]);
415     } else {
416       result = tsi_construct_string_peer_property_from_cstring(
417           TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, "other types of SAN",
418           &peer->properties[(*current_insert_index)++]);
419     }
420     if (result != TSI_OK) break;
421   }
422   return result;
423 }
424 
425 /* Gets information about the peer's X509 cert as a tsi_peer object. */
peer_from_x509(X509 * cert,int include_certificate_type,tsi_peer * peer)426 static tsi_result peer_from_x509(X509* cert, int include_certificate_type,
427                                  tsi_peer* peer) {
428   /* TODO(jboeuf): Maybe add more properties. */
429   GENERAL_NAMES* subject_alt_names = static_cast<GENERAL_NAMES*>(
430       X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
431   int subject_alt_name_count =
432       (subject_alt_names != nullptr)
433           ? static_cast<int>(sk_GENERAL_NAME_num(subject_alt_names))
434           : 0;
435   size_t property_count;
436   tsi_result result;
437   GPR_ASSERT(subject_alt_name_count >= 0);
438   property_count = (include_certificate_type ? static_cast<size_t>(1) : 0) +
439                    2 /* common name, certificate */ +
440                    static_cast<size_t>(subject_alt_name_count);
441   for (int i = 0; i < subject_alt_name_count; i++) {
442     GENERAL_NAME* subject_alt_name =
443         sk_GENERAL_NAME_value(subject_alt_names, TSI_SIZE_AS_SIZE(i));
444     if (subject_alt_name->type == GEN_URI) {
445       property_count += 1;
446     }
447   }
448   result = tsi_construct_peer(property_count, peer);
449   if (result != TSI_OK) return result;
450   int current_insert_index = 0;
451   do {
452     if (include_certificate_type) {
453       result = tsi_construct_string_peer_property_from_cstring(
454           TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE,
455           &peer->properties[current_insert_index++]);
456       if (result != TSI_OK) break;
457     }
458     result = peer_property_from_x509_common_name(
459         cert, &peer->properties[current_insert_index++]);
460     if (result != TSI_OK) break;
461 
462     result =
463         add_pem_certificate(cert, &peer->properties[current_insert_index++]);
464     if (result != TSI_OK) break;
465 
466     if (subject_alt_name_count != 0) {
467       result = add_subject_alt_names_properties_to_peer(
468           peer, subject_alt_names, static_cast<size_t>(subject_alt_name_count),
469           &current_insert_index);
470       if (result != TSI_OK) break;
471     }
472   } while (false);
473 
474   if (subject_alt_names != nullptr) {
475     sk_GENERAL_NAME_pop_free(subject_alt_names, GENERAL_NAME_free);
476   }
477   if (result != TSI_OK) tsi_peer_destruct(peer);
478 
479   GPR_ASSERT((int)peer->property_count == current_insert_index);
480   return result;
481 }
482 
483 /* Logs the SSL error stack. */
log_ssl_error_stack(void)484 static void log_ssl_error_stack(void) {
485   unsigned long err;
486   while ((err = ERR_get_error()) != 0) {
487     char details[256];
488     ERR_error_string_n(static_cast<uint32_t>(err), details, sizeof(details));
489     gpr_log(GPR_ERROR, "%s", details);
490   }
491 }
492 
493 /* Performs an SSL_read and handle errors. */
do_ssl_read(SSL * ssl,unsigned char * unprotected_bytes,size_t * unprotected_bytes_size)494 static tsi_result do_ssl_read(SSL* ssl, unsigned char* unprotected_bytes,
495                               size_t* unprotected_bytes_size) {
496   int read_from_ssl;
497   GPR_ASSERT(*unprotected_bytes_size <= INT_MAX);
498   read_from_ssl = SSL_read(ssl, unprotected_bytes,
499                            static_cast<int>(*unprotected_bytes_size));
500   if (read_from_ssl <= 0) {
501     read_from_ssl = SSL_get_error(ssl, read_from_ssl);
502     switch (read_from_ssl) {
503       case SSL_ERROR_ZERO_RETURN: /* Received a close_notify alert. */
504       case SSL_ERROR_WANT_READ:   /* We need more data to finish the frame. */
505         *unprotected_bytes_size = 0;
506         return TSI_OK;
507       case SSL_ERROR_WANT_WRITE:
508         gpr_log(
509             GPR_ERROR,
510             "Peer tried to renegotiate SSL connection. This is unsupported.");
511         return TSI_UNIMPLEMENTED;
512       case SSL_ERROR_SSL:
513         gpr_log(GPR_ERROR, "Corruption detected.");
514         log_ssl_error_stack();
515         return TSI_DATA_CORRUPTED;
516       default:
517         gpr_log(GPR_ERROR, "SSL_read failed with error %s.",
518                 ssl_error_string(read_from_ssl));
519         return TSI_PROTOCOL_FAILURE;
520     }
521   }
522   *unprotected_bytes_size = static_cast<size_t>(read_from_ssl);
523   return TSI_OK;
524 }
525 
526 /* Performs an SSL_write and handle errors. */
do_ssl_write(SSL * ssl,unsigned char * unprotected_bytes,size_t unprotected_bytes_size)527 static tsi_result do_ssl_write(SSL* ssl, unsigned char* unprotected_bytes,
528                                size_t unprotected_bytes_size) {
529   int ssl_write_result;
530   GPR_ASSERT(unprotected_bytes_size <= INT_MAX);
531   ssl_write_result = SSL_write(ssl, unprotected_bytes,
532                                static_cast<int>(unprotected_bytes_size));
533   if (ssl_write_result < 0) {
534     ssl_write_result = SSL_get_error(ssl, ssl_write_result);
535     if (ssl_write_result == SSL_ERROR_WANT_READ) {
536       gpr_log(GPR_ERROR,
537               "Peer tried to renegotiate SSL connection. This is unsupported.");
538       return TSI_UNIMPLEMENTED;
539     } else {
540       gpr_log(GPR_ERROR, "SSL_write failed with error %s.",
541               ssl_error_string(ssl_write_result));
542       return TSI_INTERNAL_ERROR;
543     }
544   }
545   return TSI_OK;
546 }
547 
548 /* Loads an in-memory PEM certificate chain into the SSL context. */
ssl_ctx_use_certificate_chain(SSL_CTX * context,const char * pem_cert_chain,size_t pem_cert_chain_size)549 static tsi_result ssl_ctx_use_certificate_chain(SSL_CTX* context,
550                                                 const char* pem_cert_chain,
551                                                 size_t pem_cert_chain_size) {
552   tsi_result result = TSI_OK;
553   X509* certificate = nullptr;
554   BIO* pem;
555   GPR_ASSERT(pem_cert_chain_size <= INT_MAX);
556   pem = BIO_new_mem_buf((void*)pem_cert_chain,
557                         static_cast<int>(pem_cert_chain_size));
558   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
559 
560   do {
561     certificate = PEM_read_bio_X509_AUX(pem, nullptr, nullptr, (void*)"");
562     if (certificate == nullptr) {
563       result = TSI_INVALID_ARGUMENT;
564       break;
565     }
566     if (!SSL_CTX_use_certificate(context, certificate)) {
567       result = TSI_INVALID_ARGUMENT;
568       break;
569     }
570     while (true) {
571       X509* certificate_authority =
572           PEM_read_bio_X509(pem, nullptr, nullptr, (void*)"");
573       if (certificate_authority == nullptr) {
574         ERR_clear_error();
575         break; /* Done reading. */
576       }
577       if (!SSL_CTX_add_extra_chain_cert(context, certificate_authority)) {
578         X509_free(certificate_authority);
579         result = TSI_INVALID_ARGUMENT;
580         break;
581       }
582       /* We don't need to free certificate_authority as its ownership has been
583          transferred to the context. That is not the case for certificate
584          though.
585        */
586     }
587   } while (false);
588 
589   if (certificate != nullptr) X509_free(certificate);
590   BIO_free(pem);
591   return result;
592 }
593 
#ifndef OPENSSL_IS_BORINGSSL
/* Loads a private key held by an OpenSSL ENGINE into |context|.
   |pem_key| is not PEM here. It is a reference of the form
   "engine:<engine_id>:<key_id>", parsed with strchr, so it must be
   NUL-terminated. |pem_key_size| is not used.
   Steps:
   1. Look up the engine by id. If that fails, load it through the "dynamic"
      engine from the current working directory.
   2. Make the engine the default for all methods and initialize it.
   3. Load |key_id| from the engine and install it in |context|.
   Returns TSI_OK on success and TSI_INVALID_ARGUMENT on any failure.
   NOTE(review): ENGINE_init() is not paired with ENGINE_finish(); only the
   structural reference is dropped with ENGINE_free(). Confirm whether the
   functional reference is intentionally kept alive for the loaded key. */
static tsi_result ssl_ctx_use_engine_private_key(SSL_CTX* context,
                                                 const char* pem_key,
                                                 size_t pem_key_size) {
  tsi_result result = TSI_OK;
  EVP_PKEY* private_key = nullptr;
  ENGINE* engine = nullptr;
  char* engine_name = nullptr;
  // Parse key which is in following format engine:<engine_id>:<key_id>
  do {
    // Skip the "engine:" prefix; the visible caller (ssl_ctx_use_private_key)
    // has already checked that it is present.
    char* engine_start = (char*)pem_key + strlen(kSslEnginePrefix);
    char* engine_end = (char*)strchr(engine_start, ':');
    if (engine_end == nullptr) {
      result = TSI_INVALID_ARGUMENT;
      break;
    }
    char* key_id = engine_end + 1;
    int engine_name_length = engine_end - engine_start;
    if (engine_name_length == 0) {
      result = TSI_INVALID_ARGUMENT;
      break;
    }
    // NUL-terminated copy of <engine_id> (gpr_zalloc zero-fills the buffer).
    engine_name = static_cast<char*>(gpr_zalloc(engine_name_length + 1));
    memcpy(engine_name, engine_start, engine_name_length);
    // Logs the engine id, not the key itself.
    gpr_log(GPR_DEBUG, "ENGINE key: %s", engine_name);
    ENGINE_load_dynamic();
    engine = ENGINE_by_id(engine_name);
    if (engine == nullptr) {
      // If not available at ENGINE_DIR, use dynamic to load from
      // current working directory.
      engine = ENGINE_by_id("dynamic");
      if (engine == nullptr) {
        gpr_log(GPR_ERROR, "Cannot load dynamic engine");
        result = TSI_INVALID_ARGUMENT;
        break;
      }
      if (!ENGINE_ctrl_cmd_string(engine, "ID", engine_name, 0) ||
          !ENGINE_ctrl_cmd_string(engine, "DIR_LOAD", "2", 0) ||
          !ENGINE_ctrl_cmd_string(engine, "DIR_ADD", ".", 0) ||
          !ENGINE_ctrl_cmd_string(engine, "LIST_ADD", "1", 0) ||
          !ENGINE_ctrl_cmd_string(engine, "LOAD", NULL, 0)) {
        gpr_log(GPR_ERROR, "Cannot find engine");
        result = TSI_INVALID_ARGUMENT;
        break;
      }
    }
    if (!ENGINE_set_default(engine, ENGINE_METHOD_ALL)) {
      gpr_log(GPR_ERROR, "ENGINE_set_default with ENGINE_METHOD_ALL failed");
      result = TSI_INVALID_ARGUMENT;
      break;
    }
    if (!ENGINE_init(engine)) {
      gpr_log(GPR_ERROR, "ENGINE_init failed");
      result = TSI_INVALID_ARGUMENT;
      break;
    }
    private_key = ENGINE_load_private_key(engine, key_id, 0, 0);
    if (private_key == nullptr) {
      gpr_log(GPR_ERROR, "ENGINE_load_private_key failed");
      result = TSI_INVALID_ARGUMENT;
      break;
    }
    if (!SSL_CTX_use_PrivateKey(context, private_key)) {
      gpr_log(GPR_ERROR, "SSL_CTX_use_PrivateKey failed");
      result = TSI_INVALID_ARGUMENT;
      break;
    }
  } while (0);
  // Shared cleanup for the success and every failure path.
  if (engine != nullptr) ENGINE_free(engine);
  if (private_key != nullptr) EVP_PKEY_free(private_key);
  if (engine_name != nullptr) gpr_free(engine_name);
  return result;
}
#endif /* OPENSSL_IS_BORINGSSL */
668 
ssl_ctx_use_pem_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)669 static tsi_result ssl_ctx_use_pem_private_key(SSL_CTX* context,
670                                               const char* pem_key,
671                                               size_t pem_key_size) {
672   tsi_result result = TSI_OK;
673   EVP_PKEY* private_key = nullptr;
674   BIO* pem;
675   GPR_ASSERT(pem_key_size <= INT_MAX);
676   pem = BIO_new_mem_buf((void*)pem_key, static_cast<int>(pem_key_size));
677   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
678   do {
679     private_key = PEM_read_bio_PrivateKey(pem, nullptr, nullptr, (void*)"");
680     if (private_key == nullptr) {
681       result = TSI_INVALID_ARGUMENT;
682       break;
683     }
684     if (!SSL_CTX_use_PrivateKey(context, private_key)) {
685       result = TSI_INVALID_ARGUMENT;
686       break;
687     }
688   } while (false);
689   if (private_key != nullptr) EVP_PKEY_free(private_key);
690   BIO_free(pem);
691   return result;
692 }
693 
694 /* Loads an in-memory PEM private key into the SSL context. */
ssl_ctx_use_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)695 static tsi_result ssl_ctx_use_private_key(SSL_CTX* context, const char* pem_key,
696                                           size_t pem_key_size) {
697 // BoringSSL does not have ENGINE support
698 #ifndef OPENSSL_IS_BORINGSSL
699   if (strncmp(pem_key, kSslEnginePrefix, strlen(kSslEnginePrefix)) == 0) {
700     return ssl_ctx_use_engine_private_key(context, pem_key, pem_key_size);
701   } else
702 #endif /* OPENSSL_IS_BORINGSSL */
703   {
704     return ssl_ctx_use_pem_private_key(context, pem_key, pem_key_size);
705   }
706 }
707 
708 /* Loads in-memory PEM verification certs into the SSL context and optionally
709    returns the verification cert names (root_names can be NULL). */
x509_store_load_certs(X509_STORE * cert_store,const char * pem_roots,size_t pem_roots_size,STACK_OF (X509_NAME)** root_names)710 static tsi_result x509_store_load_certs(X509_STORE* cert_store,
711                                         const char* pem_roots,
712                                         size_t pem_roots_size,
713                                         STACK_OF(X509_NAME) * *root_names) {
714   tsi_result result = TSI_OK;
715   size_t num_roots = 0;
716   X509* root = nullptr;
717   X509_NAME* root_name = nullptr;
718   BIO* pem;
719   GPR_ASSERT(pem_roots_size <= INT_MAX);
720   pem = BIO_new_mem_buf((void*)pem_roots, static_cast<int>(pem_roots_size));
721   if (cert_store == nullptr) return TSI_INVALID_ARGUMENT;
722   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
723   if (root_names != nullptr) {
724     *root_names = sk_X509_NAME_new_null();
725     if (*root_names == nullptr) return TSI_OUT_OF_RESOURCES;
726   }
727 
728   while (true) {
729     root = PEM_read_bio_X509_AUX(pem, nullptr, nullptr, (void*)"");
730     if (root == nullptr) {
731       ERR_clear_error();
732       break; /* We're at the end of stream. */
733     }
734     if (root_names != nullptr) {
735       root_name = X509_get_subject_name(root);
736       if (root_name == nullptr) {
737         gpr_log(GPR_ERROR, "Could not get name from root certificate.");
738         result = TSI_INVALID_ARGUMENT;
739         break;
740       }
741       root_name = X509_NAME_dup(root_name);
742       if (root_name == nullptr) {
743         result = TSI_OUT_OF_RESOURCES;
744         break;
745       }
746       sk_X509_NAME_push(*root_names, root_name);
747       root_name = nullptr;
748     }
749     ERR_clear_error();
750     if (!X509_STORE_add_cert(cert_store, root)) {
751       unsigned long error = ERR_get_error();
752       if (ERR_GET_LIB(error) != ERR_LIB_X509 ||
753           ERR_GET_REASON(error) != X509_R_CERT_ALREADY_IN_HASH_TABLE) {
754         gpr_log(GPR_ERROR, "Could not add root certificate to ssl context.");
755         result = TSI_INTERNAL_ERROR;
756         break;
757       }
758     }
759     X509_free(root);
760     num_roots++;
761   }
762   if (num_roots == 0) {
763     gpr_log(GPR_ERROR, "Could not load any root certificate.");
764     result = TSI_INVALID_ARGUMENT;
765   }
766 
767   if (result != TSI_OK) {
768     if (root != nullptr) X509_free(root);
769     if (root_names != nullptr) {
770       sk_X509_NAME_pop_free(*root_names, X509_NAME_free);
771       *root_names = nullptr;
772       if (root_name != nullptr) X509_NAME_free(root_name);
773     }
774   }
775   BIO_free(pem);
776   return result;
777 }
778 
ssl_ctx_load_verification_certs(SSL_CTX * context,const char * pem_roots,size_t pem_roots_size,STACK_OF (X509_NAME)** root_name)779 static tsi_result ssl_ctx_load_verification_certs(SSL_CTX* context,
780                                                   const char* pem_roots,
781                                                   size_t pem_roots_size,
782                                                   STACK_OF(X509_NAME) *
783                                                       *root_name) {
784   X509_STORE* cert_store = SSL_CTX_get_cert_store(context);
785   X509_STORE_set_flags(cert_store,
786                        X509_V_FLAG_PARTIAL_CHAIN | X509_V_FLAG_TRUSTED_FIRST);
787   return x509_store_load_certs(cert_store, pem_roots, pem_roots_size,
788                                root_name);
789 }
790 
791 /* Populates the SSL context with a private key and a cert chain, and sets the
792    cipher list and the ephemeral ECDH key. */
populate_ssl_context(SSL_CTX * context,const tsi_ssl_pem_key_cert_pair * key_cert_pair,const char * cipher_list)793 static tsi_result populate_ssl_context(
794     SSL_CTX* context, const tsi_ssl_pem_key_cert_pair* key_cert_pair,
795     const char* cipher_list) {
796   tsi_result result = TSI_OK;
797   if (key_cert_pair != nullptr) {
798     if (key_cert_pair->cert_chain != nullptr) {
799       result = ssl_ctx_use_certificate_chain(context, key_cert_pair->cert_chain,
800                                              strlen(key_cert_pair->cert_chain));
801       if (result != TSI_OK) {
802         gpr_log(GPR_ERROR, "Invalid cert chain file.");
803         return result;
804       }
805     }
806     if (key_cert_pair->private_key != nullptr) {
807       result = ssl_ctx_use_private_key(context, key_cert_pair->private_key,
808                                        strlen(key_cert_pair->private_key));
809       if (result != TSI_OK || !SSL_CTX_check_private_key(context)) {
810         gpr_log(GPR_ERROR, "Invalid private key.");
811         return result != TSI_OK ? result : TSI_INVALID_ARGUMENT;
812       }
813     }
814   }
815   if ((cipher_list != nullptr) &&
816       !SSL_CTX_set_cipher_list(context, cipher_list)) {
817     gpr_log(GPR_ERROR, "Invalid cipher list: %s.", cipher_list);
818     return TSI_INVALID_ARGUMENT;
819   }
820   {
821     EC_KEY* ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
822     if (!SSL_CTX_set_tmp_ecdh(context, ecdh)) {
823       gpr_log(GPR_ERROR, "Could not set ephemeral ECDH key.");
824       EC_KEY_free(ecdh);
825       return TSI_INTERNAL_ERROR;
826     }
827     SSL_CTX_set_options(context, SSL_OP_SINGLE_ECDH_USE);
828     EC_KEY_free(ecdh);
829   }
830   return TSI_OK;
831 }
832 
833 /* Extracts the CN and the SANs from an X509 cert as a peer object. */
tsi_ssl_extract_x509_subject_names_from_pem_cert(const char * pem_cert,tsi_peer * peer)834 tsi_result tsi_ssl_extract_x509_subject_names_from_pem_cert(
835     const char* pem_cert, tsi_peer* peer) {
836   tsi_result result = TSI_OK;
837   X509* cert = nullptr;
838   BIO* pem;
839   pem = BIO_new_mem_buf((void*)pem_cert, static_cast<int>(strlen(pem_cert)));
840   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
841 
842   cert = PEM_read_bio_X509(pem, nullptr, nullptr, (void*)"");
843   if (cert == nullptr) {
844     gpr_log(GPR_ERROR, "Invalid certificate");
845     result = TSI_INVALID_ARGUMENT;
846   } else {
847     result = peer_from_x509(cert, 0, peer);
848   }
849   if (cert != nullptr) X509_free(cert);
850   BIO_free(pem);
851   return result;
852 }
853 
854 /* Builds the alpn protocol name list according to rfc 7301. */
build_alpn_protocol_name_list(const char ** alpn_protocols,uint16_t num_alpn_protocols,unsigned char ** protocol_name_list,size_t * protocol_name_list_length)855 static tsi_result build_alpn_protocol_name_list(
856     const char** alpn_protocols, uint16_t num_alpn_protocols,
857     unsigned char** protocol_name_list, size_t* protocol_name_list_length) {
858   uint16_t i;
859   unsigned char* current;
860   *protocol_name_list = nullptr;
861   *protocol_name_list_length = 0;
862   if (num_alpn_protocols == 0) return TSI_INVALID_ARGUMENT;
863   for (i = 0; i < num_alpn_protocols; i++) {
864     size_t length =
865         alpn_protocols[i] == nullptr ? 0 : strlen(alpn_protocols[i]);
866     if (length == 0 || length > 255) {
867       gpr_log(GPR_ERROR, "Invalid protocol name length: %d.",
868               static_cast<int>(length));
869       return TSI_INVALID_ARGUMENT;
870     }
871     *protocol_name_list_length += length + 1;
872   }
873   *protocol_name_list =
874       static_cast<unsigned char*>(gpr_malloc(*protocol_name_list_length));
875   if (*protocol_name_list == nullptr) return TSI_OUT_OF_RESOURCES;
876   current = *protocol_name_list;
877   for (i = 0; i < num_alpn_protocols; i++) {
878     size_t length = strlen(alpn_protocols[i]);
879     *(current++) = static_cast<uint8_t>(length); /* max checked above. */
880     memcpy(current, alpn_protocols[i], length);
881     current += length;
882   }
883   /* Safety check. */
884   if ((current < *protocol_name_list) ||
885       (static_cast<uintptr_t>(current - *protocol_name_list) !=
886        *protocol_name_list_length)) {
887     return TSI_INTERNAL_ERROR;
888   }
889   return TSI_OK;
890 }
891 
892 // The verification callback is used for clients that don't really care about
893 // the server's certificate, but we need to pull it anyway, in case a higher
894 // layer wants to look at it. In this case the verification may fail, but
895 // we don't really care.
NullVerifyCallback(int,X509_STORE_CTX *)896 static int NullVerifyCallback(int /*preverify_ok*/, X509_STORE_CTX* /*ctx*/) {
897   return 1;
898 }
899 
900 // Sets the min and max TLS version of |ssl_context| to |min_tls_version| and
901 // |max_tls_version|, respectively. Calling this method is a no-op when using
902 // OpenSSL versions < 1.1.
tsi_set_min_and_max_tls_versions(SSL_CTX * ssl_context,tsi_tls_version min_tls_version,tsi_tls_version max_tls_version)903 static tsi_result tsi_set_min_and_max_tls_versions(
904     SSL_CTX* ssl_context, tsi_tls_version min_tls_version,
905     tsi_tls_version max_tls_version) {
906   if (ssl_context == nullptr) {
907     gpr_log(GPR_INFO,
908             "Invalid nullptr argument to |tsi_set_min_and_max_tls_versions|.");
909     return TSI_INVALID_ARGUMENT;
910   }
911 #if OPENSSL_VERSION_NUMBER >= 0x10100000
912   // Set the min TLS version of the SSL context.
913   switch (min_tls_version) {
914     case tsi_tls_version::TSI_TLS1_2:
915       SSL_CTX_set_min_proto_version(ssl_context, TLS1_2_VERSION);
916       break;
917 #if defined(TLS1_3_VERSION)
918     case tsi_tls_version::TSI_TLS1_3:
919       SSL_CTX_set_min_proto_version(ssl_context, TLS1_3_VERSION);
920       break;
921 #endif
922     default:
923       gpr_log(GPR_INFO, "TLS version is not supported.");
924       return TSI_FAILED_PRECONDITION;
925   }
926   // Set the max TLS version of the SSL context.
927   switch (max_tls_version) {
928     case tsi_tls_version::TSI_TLS1_2:
929       SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION);
930       break;
931 #if defined(TLS1_3_VERSION)
932     case tsi_tls_version::TSI_TLS1_3:
933       SSL_CTX_set_max_proto_version(ssl_context, TLS1_3_VERSION);
934       break;
935 #endif
936     default:
937       gpr_log(GPR_INFO, "TLS version is not supported.");
938       return TSI_FAILED_PRECONDITION;
939   }
940 #endif
941   return TSI_OK;
942 }
943 
944 /* --- tsi_ssl_root_certs_store methods implementation. ---*/
945 
tsi_ssl_root_certs_store_create(const char * pem_roots)946 tsi_ssl_root_certs_store* tsi_ssl_root_certs_store_create(
947     const char* pem_roots) {
948   if (pem_roots == nullptr) {
949     gpr_log(GPR_ERROR, "The root certificates are empty.");
950     return nullptr;
951   }
952   tsi_ssl_root_certs_store* root_store = static_cast<tsi_ssl_root_certs_store*>(
953       gpr_zalloc(sizeof(tsi_ssl_root_certs_store)));
954   if (root_store == nullptr) {
955     gpr_log(GPR_ERROR, "Could not allocate buffer for ssl_root_certs_store.");
956     return nullptr;
957   }
958   root_store->store = X509_STORE_new();
959   if (root_store->store == nullptr) {
960     gpr_log(GPR_ERROR, "Could not allocate buffer for X509_STORE.");
961     gpr_free(root_store);
962     return nullptr;
963   }
964   tsi_result result = x509_store_load_certs(root_store->store, pem_roots,
965                                             strlen(pem_roots), nullptr);
966   if (result != TSI_OK) {
967     gpr_log(GPR_ERROR, "Could not load root certificates.");
968     X509_STORE_free(root_store->store);
969     gpr_free(root_store);
970     return nullptr;
971   }
972   return root_store;
973 }
974 
tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store * self)975 void tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store* self) {
976   if (self == nullptr) return;
977   X509_STORE_free(self->store);
978   gpr_free(self);
979 }
980 
981 /* --- tsi_ssl_session_cache methods implementation. ---*/
982 
tsi_ssl_session_cache_create_lru(size_t capacity)983 tsi_ssl_session_cache* tsi_ssl_session_cache_create_lru(size_t capacity) {
984   /* Pointer will be dereferenced by unref call. */
985   return reinterpret_cast<tsi_ssl_session_cache*>(
986       tsi::SslSessionLRUCache::Create(capacity).release());
987 }
988 
tsi_ssl_session_cache_ref(tsi_ssl_session_cache * cache)989 void tsi_ssl_session_cache_ref(tsi_ssl_session_cache* cache) {
990   /* Pointer will be dereferenced by unref call. */
991   reinterpret_cast<tsi::SslSessionLRUCache*>(cache)->Ref().release();
992 }
993 
tsi_ssl_session_cache_unref(tsi_ssl_session_cache * cache)994 void tsi_ssl_session_cache_unref(tsi_ssl_session_cache* cache) {
995   reinterpret_cast<tsi::SslSessionLRUCache*>(cache)->Unref();
996 }
997 
998 /* --- tsi_frame_protector methods implementation. ---*/
999 
ssl_protector_protect(tsi_frame_protector * self,const unsigned char * unprotected_bytes,size_t * unprotected_bytes_size,unsigned char * protected_output_frames,size_t * protected_output_frames_size)1000 static tsi_result ssl_protector_protect(tsi_frame_protector* self,
1001                                         const unsigned char* unprotected_bytes,
1002                                         size_t* unprotected_bytes_size,
1003                                         unsigned char* protected_output_frames,
1004                                         size_t* protected_output_frames_size) {
1005   tsi_ssl_frame_protector* impl =
1006       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1007   int read_from_ssl;
1008   size_t available;
1009   tsi_result result = TSI_OK;
1010 
1011   /* First see if we have some pending data in the SSL BIO. */
1012   int pending_in_ssl = static_cast<int>(BIO_pending(impl->network_io));
1013   if (pending_in_ssl > 0) {
1014     *unprotected_bytes_size = 0;
1015     GPR_ASSERT(*protected_output_frames_size <= INT_MAX);
1016     read_from_ssl = BIO_read(impl->network_io, protected_output_frames,
1017                              static_cast<int>(*protected_output_frames_size));
1018     if (read_from_ssl < 0) {
1019       gpr_log(GPR_ERROR,
1020               "Could not read from BIO even though some data is pending");
1021       return TSI_INTERNAL_ERROR;
1022     }
1023     *protected_output_frames_size = static_cast<size_t>(read_from_ssl);
1024     return TSI_OK;
1025   }
1026 
1027   /* Now see if we can send a complete frame. */
1028   available = impl->buffer_size - impl->buffer_offset;
1029   if (available > *unprotected_bytes_size) {
1030     /* If we cannot, just copy the data in our internal buffer. */
1031     memcpy(impl->buffer + impl->buffer_offset, unprotected_bytes,
1032            *unprotected_bytes_size);
1033     impl->buffer_offset += *unprotected_bytes_size;
1034     *protected_output_frames_size = 0;
1035     return TSI_OK;
1036   }
1037 
1038   /* If we can, prepare the buffer, send it to SSL_write and read. */
1039   memcpy(impl->buffer + impl->buffer_offset, unprotected_bytes, available);
1040   result = do_ssl_write(impl->ssl, impl->buffer, impl->buffer_size);
1041   if (result != TSI_OK) return result;
1042 
1043   GPR_ASSERT(*protected_output_frames_size <= INT_MAX);
1044   read_from_ssl = BIO_read(impl->network_io, protected_output_frames,
1045                            static_cast<int>(*protected_output_frames_size));
1046   if (read_from_ssl < 0) {
1047     gpr_log(GPR_ERROR, "Could not read from BIO after SSL_write.");
1048     return TSI_INTERNAL_ERROR;
1049   }
1050   *protected_output_frames_size = static_cast<size_t>(read_from_ssl);
1051   *unprotected_bytes_size = available;
1052   impl->buffer_offset = 0;
1053   return TSI_OK;
1054 }
1055 
ssl_protector_protect_flush(tsi_frame_protector * self,unsigned char * protected_output_frames,size_t * protected_output_frames_size,size_t * still_pending_size)1056 static tsi_result ssl_protector_protect_flush(
1057     tsi_frame_protector* self, unsigned char* protected_output_frames,
1058     size_t* protected_output_frames_size, size_t* still_pending_size) {
1059   tsi_result result = TSI_OK;
1060   tsi_ssl_frame_protector* impl =
1061       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1062   int read_from_ssl = 0;
1063   int pending;
1064 
1065   if (impl->buffer_offset != 0) {
1066     result = do_ssl_write(impl->ssl, impl->buffer, impl->buffer_offset);
1067     if (result != TSI_OK) return result;
1068     impl->buffer_offset = 0;
1069   }
1070 
1071   pending = static_cast<int>(BIO_pending(impl->network_io));
1072   GPR_ASSERT(pending >= 0);
1073   *still_pending_size = static_cast<size_t>(pending);
1074   if (*still_pending_size == 0) return TSI_OK;
1075 
1076   GPR_ASSERT(*protected_output_frames_size <= INT_MAX);
1077   read_from_ssl = BIO_read(impl->network_io, protected_output_frames,
1078                            static_cast<int>(*protected_output_frames_size));
1079   if (read_from_ssl <= 0) {
1080     gpr_log(GPR_ERROR, "Could not read from BIO after SSL_write.");
1081     return TSI_INTERNAL_ERROR;
1082   }
1083   *protected_output_frames_size = static_cast<size_t>(read_from_ssl);
1084   pending = static_cast<int>(BIO_pending(impl->network_io));
1085   GPR_ASSERT(pending >= 0);
1086   *still_pending_size = static_cast<size_t>(pending);
1087   return TSI_OK;
1088 }
1089 
ssl_protector_unprotect(tsi_frame_protector * self,const unsigned char * protected_frames_bytes,size_t * protected_frames_bytes_size,unsigned char * unprotected_bytes,size_t * unprotected_bytes_size)1090 static tsi_result ssl_protector_unprotect(
1091     tsi_frame_protector* self, const unsigned char* protected_frames_bytes,
1092     size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes,
1093     size_t* unprotected_bytes_size) {
1094   tsi_result result = TSI_OK;
1095   int written_into_ssl = 0;
1096   size_t output_bytes_size = *unprotected_bytes_size;
1097   size_t output_bytes_offset = 0;
1098   tsi_ssl_frame_protector* impl =
1099       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1100 
1101   /* First, try to read remaining data from ssl. */
1102   result = do_ssl_read(impl->ssl, unprotected_bytes, unprotected_bytes_size);
1103   if (result != TSI_OK) return result;
1104   if (*unprotected_bytes_size == output_bytes_size) {
1105     /* We have read everything we could and cannot process any more input. */
1106     *protected_frames_bytes_size = 0;
1107     return TSI_OK;
1108   }
1109   output_bytes_offset = *unprotected_bytes_size;
1110   unprotected_bytes += output_bytes_offset;
1111   *unprotected_bytes_size = output_bytes_size - output_bytes_offset;
1112 
1113   /* Then, try to write some data to ssl. */
1114   GPR_ASSERT(*protected_frames_bytes_size <= INT_MAX);
1115   written_into_ssl = BIO_write(impl->network_io, protected_frames_bytes,
1116                                static_cast<int>(*protected_frames_bytes_size));
1117   if (written_into_ssl < 0) {
1118     gpr_log(GPR_ERROR, "Sending protected frame to ssl failed with %d",
1119             written_into_ssl);
1120     return TSI_INTERNAL_ERROR;
1121   }
1122   *protected_frames_bytes_size = static_cast<size_t>(written_into_ssl);
1123 
1124   /* Now try to read some data again. */
1125   result = do_ssl_read(impl->ssl, unprotected_bytes, unprotected_bytes_size);
1126   if (result == TSI_OK) {
1127     /* Don't forget to output the total number of bytes read. */
1128     *unprotected_bytes_size += output_bytes_offset;
1129   }
1130   return result;
1131 }
1132 
ssl_protector_destroy(tsi_frame_protector * self)1133 static void ssl_protector_destroy(tsi_frame_protector* self) {
1134   tsi_ssl_frame_protector* impl =
1135       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1136   if (impl->buffer != nullptr) gpr_free(impl->buffer);
1137   if (impl->ssl != nullptr) SSL_free(impl->ssl);
1138   if (impl->network_io != nullptr) BIO_free(impl->network_io);
1139   gpr_free(self);
1140 }
1141 
1142 static const tsi_frame_protector_vtable frame_protector_vtable = {
1143     ssl_protector_protect,
1144     ssl_protector_protect_flush,
1145     ssl_protector_unprotect,
1146     ssl_protector_destroy,
1147 };
1148 
/* --- tsi_ssl_handshaker_factory methods implementation. --- */
1150 
tsi_ssl_handshaker_factory_destroy(tsi_ssl_handshaker_factory * self)1151 static void tsi_ssl_handshaker_factory_destroy(
1152     tsi_ssl_handshaker_factory* self) {
1153   if (self == nullptr) return;
1154 
1155   if (self->vtable != nullptr && self->vtable->destroy != nullptr) {
1156     self->vtable->destroy(self);
1157   }
1158   /* Note, we don't free(self) here because this object is always directly
1159    * embedded in another object. If tsi_ssl_handshaker_factory_init allocates
1160    * any memory, it should be free'd here. */
1161 }
1162 
tsi_ssl_handshaker_factory_ref(tsi_ssl_handshaker_factory * self)1163 static tsi_ssl_handshaker_factory* tsi_ssl_handshaker_factory_ref(
1164     tsi_ssl_handshaker_factory* self) {
1165   if (self == nullptr) return nullptr;
1166   gpr_refn(&self->refcount, 1);
1167   return self;
1168 }
1169 
tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory * self)1170 static void tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory* self) {
1171   if (self == nullptr) return;
1172 
1173   if (gpr_unref(&self->refcount)) {
1174     tsi_ssl_handshaker_factory_destroy(self);
1175   }
1176 }
1177 
1178 static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = {nullptr};
1179 
1180 /* Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for
1181  * allocating memory for the factory. */
tsi_ssl_handshaker_factory_init(tsi_ssl_handshaker_factory * factory)1182 static void tsi_ssl_handshaker_factory_init(
1183     tsi_ssl_handshaker_factory* factory) {
1184   GPR_ASSERT(factory != nullptr);
1185 
1186   factory->vtable = &handshaker_factory_vtable;
1187   gpr_ref_init(&factory->refcount, 1);
1188 }
1189 
1190 /* Gets the X509 cert chain in PEM format as a tsi_peer_property. */
tsi_ssl_get_cert_chain_contents(STACK_OF (X509)* peer_chain,tsi_peer_property * property)1191 tsi_result tsi_ssl_get_cert_chain_contents(STACK_OF(X509) * peer_chain,
1192                                            tsi_peer_property* property) {
1193   BIO* bio = BIO_new(BIO_s_mem());
1194   const auto peer_chain_len = sk_X509_num(peer_chain);
1195   for (auto i = decltype(peer_chain_len){0}; i < peer_chain_len; i++) {
1196     if (!PEM_write_bio_X509(bio, sk_X509_value(peer_chain, i))) {
1197       BIO_free(bio);
1198       return TSI_INTERNAL_ERROR;
1199     }
1200   }
1201   char* contents;
1202   long len = BIO_get_mem_data(bio, &contents);
1203   if (len <= 0) {
1204     BIO_free(bio);
1205     return TSI_INTERNAL_ERROR;
1206   }
1207   tsi_result result = tsi_construct_string_peer_property(
1208       TSI_X509_PEM_CERT_CHAIN_PROPERTY, (const char*)contents,
1209       static_cast<size_t>(len), property);
1210   BIO_free(bio);
1211   return result;
1212 }
1213 
1214 /* --- tsi_handshaker_result methods implementation. ---*/
ssl_handshaker_result_extract_peer(const tsi_handshaker_result * self,tsi_peer * peer)1215 static tsi_result ssl_handshaker_result_extract_peer(
1216     const tsi_handshaker_result* self, tsi_peer* peer) {
1217   tsi_result result = TSI_OK;
1218   const unsigned char* alpn_selected = nullptr;
1219   unsigned int alpn_selected_len;
1220   const tsi_ssl_handshaker_result* impl =
1221       reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
1222   X509* peer_cert = SSL_get_peer_certificate(impl->ssl);
1223   if (peer_cert != nullptr) {
1224     result = peer_from_x509(peer_cert, 1, peer);
1225     X509_free(peer_cert);
1226     if (result != TSI_OK) return result;
1227   }
1228 #if TSI_OPENSSL_ALPN_SUPPORT
1229   SSL_get0_alpn_selected(impl->ssl, &alpn_selected, &alpn_selected_len);
1230 #endif /* TSI_OPENSSL_ALPN_SUPPORT */
1231   if (alpn_selected == nullptr) {
1232     /* Try npn. */
1233     SSL_get0_next_proto_negotiated(impl->ssl, &alpn_selected,
1234                                    &alpn_selected_len);
1235   }
1236   // When called on the client side, the stack also contains the
1237   // peer's certificate; When called on the server side,
1238   // the peer's certificate is not present in the stack
1239   STACK_OF(X509)* peer_chain = SSL_get_peer_cert_chain(impl->ssl);
1240   // 1 is for session reused property.
1241   size_t new_property_count = peer->property_count + 3;
1242   if (alpn_selected != nullptr) new_property_count++;
1243   if (peer_chain != nullptr) new_property_count++;
1244   tsi_peer_property* new_properties = static_cast<tsi_peer_property*>(
1245       gpr_zalloc(sizeof(*new_properties) * new_property_count));
1246   for (size_t i = 0; i < peer->property_count; i++) {
1247     new_properties[i] = peer->properties[i];
1248   }
1249   if (peer->properties != nullptr) gpr_free(peer->properties);
1250   peer->properties = new_properties;
1251   // Add peer chain if available
1252   if (peer_chain != nullptr) {
1253     result = tsi_ssl_get_cert_chain_contents(
1254         peer_chain, &peer->properties[peer->property_count]);
1255     if (result == TSI_OK) peer->property_count++;
1256   }
1257   if (alpn_selected != nullptr) {
1258     result = tsi_construct_string_peer_property(
1259         TSI_SSL_ALPN_SELECTED_PROTOCOL,
1260         reinterpret_cast<const char*>(alpn_selected), alpn_selected_len,
1261         &peer->properties[peer->property_count]);
1262     if (result != TSI_OK) return result;
1263     peer->property_count++;
1264   }
1265   // Add security_level peer property.
1266   result = tsi_construct_string_peer_property_from_cstring(
1267       TSI_SECURITY_LEVEL_PEER_PROPERTY,
1268       tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY),
1269       &peer->properties[peer->property_count]);
1270   if (result != TSI_OK) return result;
1271   peer->property_count++;
1272 
1273   const char* session_reused = SSL_session_reused(impl->ssl) ? "true" : "false";
1274   result = tsi_construct_string_peer_property_from_cstring(
1275       TSI_SSL_SESSION_REUSED_PEER_PROPERTY, session_reused,
1276       &peer->properties[peer->property_count]);
1277   if (result != TSI_OK) return result;
1278   peer->property_count++;
1279   return result;
1280 }
1281 
ssl_handshaker_result_create_frame_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_frame_protector ** protector)1282 static tsi_result ssl_handshaker_result_create_frame_protector(
1283     const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
1284     tsi_frame_protector** protector) {
1285   size_t actual_max_output_protected_frame_size =
1286       TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
1287   tsi_ssl_handshaker_result* impl =
1288       reinterpret_cast<tsi_ssl_handshaker_result*>(
1289           const_cast<tsi_handshaker_result*>(self));
1290   tsi_ssl_frame_protector* protector_impl =
1291       static_cast<tsi_ssl_frame_protector*>(
1292           gpr_zalloc(sizeof(*protector_impl)));
1293 
1294   if (max_output_protected_frame_size != nullptr) {
1295     if (*max_output_protected_frame_size >
1296         TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND) {
1297       *max_output_protected_frame_size =
1298           TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
1299     } else if (*max_output_protected_frame_size <
1300                TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND) {
1301       *max_output_protected_frame_size =
1302           TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND;
1303     }
1304     actual_max_output_protected_frame_size = *max_output_protected_frame_size;
1305   }
1306   protector_impl->buffer_size =
1307       actual_max_output_protected_frame_size - TSI_SSL_MAX_PROTECTION_OVERHEAD;
1308   protector_impl->buffer =
1309       static_cast<unsigned char*>(gpr_malloc(protector_impl->buffer_size));
1310   if (protector_impl->buffer == nullptr) {
1311     gpr_log(GPR_ERROR,
1312             "Could not allocated buffer for tsi_ssl_frame_protector.");
1313     gpr_free(protector_impl);
1314     return TSI_INTERNAL_ERROR;
1315   }
1316 
1317   /* Transfer ownership of ssl and network_io to the frame protector. */
1318   protector_impl->ssl = impl->ssl;
1319   impl->ssl = nullptr;
1320   protector_impl->network_io = impl->network_io;
1321   impl->network_io = nullptr;
1322   protector_impl->base.vtable = &frame_protector_vtable;
1323   *protector = &protector_impl->base;
1324   return TSI_OK;
1325 }
1326 
ssl_handshaker_result_get_unused_bytes(const tsi_handshaker_result * self,const unsigned char ** bytes,size_t * bytes_size)1327 static tsi_result ssl_handshaker_result_get_unused_bytes(
1328     const tsi_handshaker_result* self, const unsigned char** bytes,
1329     size_t* bytes_size) {
1330   const tsi_ssl_handshaker_result* impl =
1331       reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
1332   *bytes_size = impl->unused_bytes_size;
1333   *bytes = impl->unused_bytes;
1334   return TSI_OK;
1335 }
1336 
ssl_handshaker_result_destroy(tsi_handshaker_result * self)1337 static void ssl_handshaker_result_destroy(tsi_handshaker_result* self) {
1338   tsi_ssl_handshaker_result* impl =
1339       reinterpret_cast<tsi_ssl_handshaker_result*>(self);
1340   SSL_free(impl->ssl);
1341   BIO_free(impl->network_io);
1342   gpr_free(impl->unused_bytes);
1343   gpr_free(impl);
1344 }
1345 
1346 static const tsi_handshaker_result_vtable handshaker_result_vtable = {
1347     ssl_handshaker_result_extract_peer,
1348     nullptr, /* create_zero_copy_grpc_protector */
1349     ssl_handshaker_result_create_frame_protector,
1350     ssl_handshaker_result_get_unused_bytes,
1351     ssl_handshaker_result_destroy,
1352 };
1353 
ssl_handshaker_result_create(tsi_ssl_handshaker * handshaker,unsigned char * unused_bytes,size_t unused_bytes_size,tsi_handshaker_result ** handshaker_result)1354 static tsi_result ssl_handshaker_result_create(
1355     tsi_ssl_handshaker* handshaker, unsigned char* unused_bytes,
1356     size_t unused_bytes_size, tsi_handshaker_result** handshaker_result) {
1357   if (handshaker == nullptr || handshaker_result == nullptr ||
1358       (unused_bytes_size > 0 && unused_bytes == nullptr)) {
1359     return TSI_INVALID_ARGUMENT;
1360   }
1361   tsi_ssl_handshaker_result* result =
1362       static_cast<tsi_ssl_handshaker_result*>(gpr_zalloc(sizeof(*result)));
1363   result->base.vtable = &handshaker_result_vtable;
1364   /* Transfer ownership of ssl and network_io to the handshaker result. */
1365   result->ssl = handshaker->ssl;
1366   handshaker->ssl = nullptr;
1367   result->network_io = handshaker->network_io;
1368   handshaker->network_io = nullptr;
1369   /* Transfer ownership of |unused_bytes| to the handshaker result. */
1370   result->unused_bytes = unused_bytes;
1371   result->unused_bytes_size = unused_bytes_size;
1372   *handshaker_result = &result->base;
1373   return TSI_OK;
1374 }
1375 
1376 /* --- tsi_handshaker methods implementation. ---*/
1377 
ssl_handshaker_get_bytes_to_send_to_peer(tsi_ssl_handshaker * impl,unsigned char * bytes,size_t * bytes_size)1378 static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(
1379     tsi_ssl_handshaker* impl, unsigned char* bytes, size_t* bytes_size) {
1380   int bytes_read_from_ssl = 0;
1381   if (bytes == nullptr || bytes_size == nullptr || *bytes_size == 0 ||
1382       *bytes_size > INT_MAX) {
1383     return TSI_INVALID_ARGUMENT;
1384   }
1385   GPR_ASSERT(*bytes_size <= INT_MAX);
1386   bytes_read_from_ssl =
1387       BIO_read(impl->network_io, bytes, static_cast<int>(*bytes_size));
1388   if (bytes_read_from_ssl < 0) {
1389     *bytes_size = 0;
1390     if (!BIO_should_retry(impl->network_io)) {
1391       impl->result = TSI_INTERNAL_ERROR;
1392       return impl->result;
1393     } else {
1394       return TSI_OK;
1395     }
1396   }
1397   *bytes_size = static_cast<size_t>(bytes_read_from_ssl);
1398   return BIO_pending(impl->network_io) == 0 ? TSI_OK : TSI_INCOMPLETE_DATA;
1399 }
1400 
ssl_handshaker_get_result(tsi_ssl_handshaker * impl)1401 static tsi_result ssl_handshaker_get_result(tsi_ssl_handshaker* impl) {
1402   if ((impl->result == TSI_HANDSHAKE_IN_PROGRESS) &&
1403       SSL_is_init_finished(impl->ssl)) {
1404     impl->result = TSI_OK;
1405   }
1406   return impl->result;
1407 }
1408 
ssl_handshaker_process_bytes_from_peer(tsi_ssl_handshaker * impl,const unsigned char * bytes,size_t * bytes_size)1409 static tsi_result ssl_handshaker_process_bytes_from_peer(
1410     tsi_ssl_handshaker* impl, const unsigned char* bytes, size_t* bytes_size) {
1411   int bytes_written_into_ssl_size = 0;
1412   if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
1413     return TSI_INVALID_ARGUMENT;
1414   }
1415   GPR_ASSERT(*bytes_size <= INT_MAX);
1416   bytes_written_into_ssl_size =
1417       BIO_write(impl->network_io, bytes, static_cast<int>(*bytes_size));
1418   if (bytes_written_into_ssl_size < 0) {
1419     gpr_log(GPR_ERROR, "Could not write to memory BIO.");
1420     impl->result = TSI_INTERNAL_ERROR;
1421     return impl->result;
1422   }
1423   *bytes_size = static_cast<size_t>(bytes_written_into_ssl_size);
1424 
1425   if (ssl_handshaker_get_result(impl) != TSI_HANDSHAKE_IN_PROGRESS) {
1426     impl->result = TSI_OK;
1427     return impl->result;
1428   } else {
1429     /* Get ready to get some bytes from SSL. */
1430     int ssl_result = SSL_do_handshake(impl->ssl);
1431     ssl_result = SSL_get_error(impl->ssl, ssl_result);
1432     switch (ssl_result) {
1433       case SSL_ERROR_WANT_READ:
1434         if (BIO_pending(impl->network_io) == 0) {
1435           /* We need more data. */
1436           return TSI_INCOMPLETE_DATA;
1437         } else {
1438           return TSI_OK;
1439         }
1440       case SSL_ERROR_NONE:
1441         return TSI_OK;
1442       default: {
1443         char err_str[256];
1444         ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
1445         gpr_log(GPR_ERROR, "Handshake failed with fatal error %s: %s.",
1446                 ssl_error_string(ssl_result), err_str);
1447         impl->result = TSI_PROTOCOL_FAILURE;
1448         return impl->result;
1449       }
1450     }
1451   }
1452 }
1453 
ssl_handshaker_destroy(tsi_handshaker * self)1454 static void ssl_handshaker_destroy(tsi_handshaker* self) {
1455   tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1456   SSL_free(impl->ssl);
1457   BIO_free(impl->network_io);
1458   gpr_free(impl->outgoing_bytes_buffer);
1459   tsi_ssl_handshaker_factory_unref(impl->factory_ref);
1460   gpr_free(impl);
1461 }
1462 
1463 // Removes the bytes remaining in |impl->SSL|'s read BIO and writes them to
1464 // |bytes_remaining|.
ssl_bytes_remaining(tsi_ssl_handshaker * impl,unsigned char ** bytes_remaining,size_t * bytes_remaining_size)1465 static tsi_result ssl_bytes_remaining(tsi_ssl_handshaker* impl,
1466                                       unsigned char** bytes_remaining,
1467                                       size_t* bytes_remaining_size) {
1468   if (impl == nullptr || bytes_remaining == nullptr ||
1469       bytes_remaining_size == nullptr) {
1470     return TSI_INVALID_ARGUMENT;
1471   }
1472   // Atempt to read all of the bytes in SSL's read BIO. These bytes should
1473   // contain application data records that were appended to a handshake record
1474   // containing the ClientFinished or ServerFinished message.
1475   size_t bytes_in_ssl = BIO_pending(SSL_get_rbio(impl->ssl));
1476   if (bytes_in_ssl == 0) return TSI_OK;
1477   *bytes_remaining = static_cast<uint8_t*>(gpr_malloc(bytes_in_ssl));
1478   int bytes_read = BIO_read(SSL_get_rbio(impl->ssl), *bytes_remaining,
1479                             static_cast<int>(bytes_in_ssl));
1480   // If an unexpected number of bytes were read, return an error status and free
1481   // all of the bytes that were read.
1482   if (bytes_read < 0 || static_cast<size_t>(bytes_read) != bytes_in_ssl) {
1483     gpr_log(GPR_ERROR,
1484             "Failed to read the expected number of bytes from SSL object.");
1485     gpr_free(*bytes_remaining);
1486     *bytes_remaining = nullptr;
1487     return TSI_INTERNAL_ERROR;
1488   }
1489   *bytes_remaining_size = static_cast<size_t>(bytes_read);
1490   return TSI_OK;
1491 }
1492 
ssl_handshaker_next(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char ** bytes_to_send,size_t * bytes_to_send_size,tsi_handshaker_result ** handshaker_result,tsi_handshaker_on_next_done_cb,void *)1493 static tsi_result ssl_handshaker_next(
1494     tsi_handshaker* self, const unsigned char* received_bytes,
1495     size_t received_bytes_size, const unsigned char** bytes_to_send,
1496     size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result,
1497     tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/) {
1498   /* Input sanity check.  */
1499   if ((received_bytes_size > 0 && received_bytes == nullptr) ||
1500       bytes_to_send == nullptr || bytes_to_send_size == nullptr ||
1501       handshaker_result == nullptr) {
1502     return TSI_INVALID_ARGUMENT;
1503   }
1504   /* If there are received bytes, process them first.  */
1505   tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1506   tsi_result status = TSI_OK;
1507   size_t bytes_consumed = received_bytes_size;
1508   if (received_bytes_size > 0) {
1509     status = ssl_handshaker_process_bytes_from_peer(impl, received_bytes,
1510                                                     &bytes_consumed);
1511     if (status != TSI_OK) return status;
1512   }
1513   /* Get bytes to send to the peer, if available.  */
1514   size_t offset = 0;
1515   do {
1516     size_t to_send_size = impl->outgoing_bytes_buffer_size - offset;
1517     status = ssl_handshaker_get_bytes_to_send_to_peer(
1518         impl, impl->outgoing_bytes_buffer + offset, &to_send_size);
1519     offset += to_send_size;
1520     if (status == TSI_INCOMPLETE_DATA) {
1521       impl->outgoing_bytes_buffer_size *= 2;
1522       impl->outgoing_bytes_buffer = static_cast<unsigned char*>(gpr_realloc(
1523           impl->outgoing_bytes_buffer, impl->outgoing_bytes_buffer_size));
1524     }
1525   } while (status == TSI_INCOMPLETE_DATA);
1526   if (status != TSI_OK) return status;
1527   *bytes_to_send = impl->outgoing_bytes_buffer;
1528   *bytes_to_send_size = offset;
1529   /* If handshake completes, create tsi_handshaker_result.  */
1530   if (ssl_handshaker_get_result(impl) == TSI_HANDSHAKE_IN_PROGRESS) {
1531     *handshaker_result = nullptr;
1532   } else {
1533     // Any bytes that remain in |impl->ssl|'s read BIO after the handshake is
1534     // complete must be extracted and set to the unused bytes of the handshaker
1535     // result. This indicates to the gRPC stack that there are bytes from the
1536     // peer that must be processed.
1537     unsigned char* unused_bytes = nullptr;
1538     size_t unused_bytes_size = 0;
1539     status = ssl_bytes_remaining(impl, &unused_bytes, &unused_bytes_size);
1540     if (status != TSI_OK) return status;
1541     if (unused_bytes_size > received_bytes_size) {
1542       gpr_log(GPR_ERROR, "More unused bytes than received bytes.");
1543       gpr_free(unused_bytes);
1544       return TSI_INTERNAL_ERROR;
1545     }
1546     status = ssl_handshaker_result_create(impl, unused_bytes, unused_bytes_size,
1547                                           handshaker_result);
1548     if (status == TSI_OK) {
1549       /* Indicates that the handshake has completed and that a handshaker_result
1550        * has been created. */
1551       self->handshaker_result_created = true;
1552     }
1553   }
1554   return status;
1555 }
1556 
1557 static const tsi_handshaker_vtable handshaker_vtable = {
1558     nullptr, /* get_bytes_to_send_to_peer -- deprecated */
1559     nullptr, /* process_bytes_from_peer   -- deprecated */
1560     nullptr, /* get_result                -- deprecated */
1561     nullptr, /* extract_peer              -- deprecated */
1562     nullptr, /* create_frame_protector    -- deprecated */
1563     ssl_handshaker_destroy,
1564     ssl_handshaker_next,
1565     nullptr, /* shutdown */
1566 };
1567 
1568 /* --- tsi_ssl_handshaker_factory common methods. --- */
1569 
tsi_ssl_handshaker_resume_session(SSL * ssl,tsi::SslSessionLRUCache * session_cache)1570 static void tsi_ssl_handshaker_resume_session(
1571     SSL* ssl, tsi::SslSessionLRUCache* session_cache) {
1572   const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1573   if (server_name == nullptr) {
1574     return;
1575   }
1576   tsi::SslSessionPtr session = session_cache->Get(server_name);
1577   if (session != nullptr) {
1578     // SSL_set_session internally increments reference counter.
1579     SSL_set_session(ssl, session.get());
1580   }
1581 }
1582 
create_tsi_ssl_handshaker(SSL_CTX * ctx,int is_client,const char * server_name_indication,tsi_ssl_handshaker_factory * factory,tsi_handshaker ** handshaker)1583 static tsi_result create_tsi_ssl_handshaker(SSL_CTX* ctx, int is_client,
1584                                             const char* server_name_indication,
1585                                             tsi_ssl_handshaker_factory* factory,
1586                                             tsi_handshaker** handshaker) {
1587   SSL* ssl = SSL_new(ctx);
1588   BIO* network_io = nullptr;
1589   BIO* ssl_io = nullptr;
1590   tsi_ssl_handshaker* impl = nullptr;
1591   *handshaker = nullptr;
1592   if (ctx == nullptr) {
1593     gpr_log(GPR_ERROR, "SSL Context is null. Should never happen.");
1594     return TSI_INTERNAL_ERROR;
1595   }
1596   if (ssl == nullptr) {
1597     return TSI_OUT_OF_RESOURCES;
1598   }
1599   SSL_set_info_callback(ssl, ssl_info_callback);
1600 
1601   if (!BIO_new_bio_pair(&network_io, 0, &ssl_io, 0)) {
1602     gpr_log(GPR_ERROR, "BIO_new_bio_pair failed.");
1603     SSL_free(ssl);
1604     return TSI_OUT_OF_RESOURCES;
1605   }
1606   SSL_set_bio(ssl, ssl_io, ssl_io);
1607   if (is_client) {
1608     int ssl_result;
1609     SSL_set_connect_state(ssl);
1610     if (server_name_indication != nullptr) {
1611       if (!SSL_set_tlsext_host_name(ssl, server_name_indication)) {
1612         gpr_log(GPR_ERROR, "Invalid server name indication %s.",
1613                 server_name_indication);
1614         SSL_free(ssl);
1615         BIO_free(network_io);
1616         return TSI_INTERNAL_ERROR;
1617       }
1618     }
1619     tsi_ssl_client_handshaker_factory* client_factory =
1620         reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory);
1621     if (client_factory->session_cache != nullptr) {
1622       tsi_ssl_handshaker_resume_session(ssl,
1623                                         client_factory->session_cache.get());
1624     }
1625     ssl_result = SSL_do_handshake(ssl);
1626     ssl_result = SSL_get_error(ssl, ssl_result);
1627     if (ssl_result != SSL_ERROR_WANT_READ) {
1628       gpr_log(GPR_ERROR,
1629               "Unexpected error received from first SSL_do_handshake call: %s",
1630               ssl_error_string(ssl_result));
1631       SSL_free(ssl);
1632       BIO_free(network_io);
1633       return TSI_INTERNAL_ERROR;
1634     }
1635   } else {
1636     SSL_set_accept_state(ssl);
1637   }
1638 
1639   impl = static_cast<tsi_ssl_handshaker*>(gpr_zalloc(sizeof(*impl)));
1640   impl->ssl = ssl;
1641   impl->network_io = network_io;
1642   impl->result = TSI_HANDSHAKE_IN_PROGRESS;
1643   impl->outgoing_bytes_buffer_size =
1644       TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE;
1645   impl->outgoing_bytes_buffer =
1646       static_cast<unsigned char*>(gpr_zalloc(impl->outgoing_bytes_buffer_size));
1647   impl->base.vtable = &handshaker_vtable;
1648   impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory);
1649   *handshaker = &impl->base;
1650   return TSI_OK;
1651 }
1652 
select_protocol_list(const unsigned char ** out,unsigned char * outlen,const unsigned char * client_list,size_t client_list_len,const unsigned char * server_list,size_t server_list_len)1653 static int select_protocol_list(const unsigned char** out,
1654                                 unsigned char* outlen,
1655                                 const unsigned char* client_list,
1656                                 size_t client_list_len,
1657                                 const unsigned char* server_list,
1658                                 size_t server_list_len) {
1659   const unsigned char* client_current = client_list;
1660   while (static_cast<unsigned int>(client_current - client_list) <
1661          client_list_len) {
1662     unsigned char client_current_len = *(client_current++);
1663     const unsigned char* server_current = server_list;
1664     while ((server_current >= server_list) &&
1665            static_cast<uintptr_t>(server_current - server_list) <
1666                server_list_len) {
1667       unsigned char server_current_len = *(server_current++);
1668       if ((client_current_len == server_current_len) &&
1669           !memcmp(client_current, server_current, server_current_len)) {
1670         *out = server_current;
1671         *outlen = server_current_len;
1672         return SSL_TLSEXT_ERR_OK;
1673       }
1674       server_current += server_current_len;
1675     }
1676     client_current += client_current_len;
1677   }
1678   return SSL_TLSEXT_ERR_NOACK;
1679 }
1680 
1681 /* --- tsi_ssl_client_handshaker_factory methods implementation. --- */
1682 
tsi_ssl_client_handshaker_factory_create_handshaker(tsi_ssl_client_handshaker_factory * self,const char * server_name_indication,tsi_handshaker ** handshaker)1683 tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
1684     tsi_ssl_client_handshaker_factory* self, const char* server_name_indication,
1685     tsi_handshaker** handshaker) {
1686   return create_tsi_ssl_handshaker(self->ssl_context, 1, server_name_indication,
1687                                    &self->base, handshaker);
1688 }
1689 
tsi_ssl_client_handshaker_factory_unref(tsi_ssl_client_handshaker_factory * self)1690 void tsi_ssl_client_handshaker_factory_unref(
1691     tsi_ssl_client_handshaker_factory* self) {
1692   if (self == nullptr) return;
1693   tsi_ssl_handshaker_factory_unref(&self->base);
1694 }
1695 
tsi_ssl_client_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)1696 static void tsi_ssl_client_handshaker_factory_destroy(
1697     tsi_ssl_handshaker_factory* factory) {
1698   if (factory == nullptr) return;
1699   tsi_ssl_client_handshaker_factory* self =
1700       reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory);
1701   if (self->ssl_context != nullptr) SSL_CTX_free(self->ssl_context);
1702   if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
1703   self->session_cache.reset();
1704   gpr_free(self);
1705 }
1706 
client_handshaker_factory_npn_callback(SSL *,unsigned char ** out,unsigned char * outlen,const unsigned char * in,unsigned int inlen,void * arg)1707 static int client_handshaker_factory_npn_callback(
1708     SSL* /*ssl*/, unsigned char** out, unsigned char* outlen,
1709     const unsigned char* in, unsigned int inlen, void* arg) {
1710   tsi_ssl_client_handshaker_factory* factory =
1711       static_cast<tsi_ssl_client_handshaker_factory*>(arg);
1712   return select_protocol_list((const unsigned char**)out, outlen,
1713                               factory->alpn_protocol_list,
1714                               factory->alpn_protocol_list_length, in, inlen);
1715 }
1716 
1717 /* --- tsi_ssl_server_handshaker_factory methods implementation. --- */
1718 
tsi_ssl_server_handshaker_factory_create_handshaker(tsi_ssl_server_handshaker_factory * self,tsi_handshaker ** handshaker)1719 tsi_result tsi_ssl_server_handshaker_factory_create_handshaker(
1720     tsi_ssl_server_handshaker_factory* self, tsi_handshaker** handshaker) {
1721   if (self->ssl_context_count == 0) return TSI_INVALID_ARGUMENT;
1722   /* Create the handshaker with the first context. We will switch if needed
1723      because of SNI in ssl_server_handshaker_factory_servername_callback.  */
1724   return create_tsi_ssl_handshaker(self->ssl_contexts[0], 0, nullptr,
1725                                    &self->base, handshaker);
1726 }
1727 
tsi_ssl_server_handshaker_factory_unref(tsi_ssl_server_handshaker_factory * self)1728 void tsi_ssl_server_handshaker_factory_unref(
1729     tsi_ssl_server_handshaker_factory* self) {
1730   if (self == nullptr) return;
1731   tsi_ssl_handshaker_factory_unref(&self->base);
1732 }
1733 
tsi_ssl_server_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)1734 static void tsi_ssl_server_handshaker_factory_destroy(
1735     tsi_ssl_handshaker_factory* factory) {
1736   if (factory == nullptr) return;
1737   tsi_ssl_server_handshaker_factory* self =
1738       reinterpret_cast<tsi_ssl_server_handshaker_factory*>(factory);
1739   size_t i;
1740   for (i = 0; i < self->ssl_context_count; i++) {
1741     if (self->ssl_contexts[i] != nullptr) {
1742       SSL_CTX_free(self->ssl_contexts[i]);
1743       tsi_peer_destruct(&self->ssl_context_x509_subject_names[i]);
1744     }
1745   }
1746   if (self->ssl_contexts != nullptr) gpr_free(self->ssl_contexts);
1747   if (self->ssl_context_x509_subject_names != nullptr) {
1748     gpr_free(self->ssl_context_x509_subject_names);
1749   }
1750   if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
1751   gpr_free(self);
1752 }
1753 
does_entry_match_name(absl::string_view entry,absl::string_view name)1754 static int does_entry_match_name(absl::string_view entry,
1755                                  absl::string_view name) {
1756   if (entry.empty()) return 0;
1757 
1758   /* Take care of '.' terminations. */
1759   if (name.back() == '.') {
1760     name.remove_suffix(1);
1761   }
1762   if (entry.back() == '.') {
1763     entry.remove_suffix(1);
1764     if (entry.empty()) return 0;
1765   }
1766 
1767   if (absl::EqualsIgnoreCase(name, entry)) {
1768     return 1; /* Perfect match. */
1769   }
1770   if (entry.front() != '*') return 0;
1771 
1772   /* Wildchar subdomain matching. */
1773   if (entry.size() < 3 || entry[1] != '.') { /* At least *.x */
1774     gpr_log(GPR_ERROR, "Invalid wildchar entry.");
1775     return 0;
1776   }
1777   size_t name_subdomain_pos = name.find('.');
1778   if (name_subdomain_pos == absl::string_view::npos) return 0;
1779   if (name_subdomain_pos >= name.size() - 2) return 0;
1780   absl::string_view name_subdomain =
1781       name.substr(name_subdomain_pos + 1); /* Starts after the dot. */
1782   entry.remove_prefix(2);                  /* Remove *. */
1783   size_t dot = name_subdomain.find('.');
1784   if (dot == absl::string_view::npos || dot == name_subdomain.size() - 1) {
1785     gpr_log(GPR_ERROR, "Invalid toplevel subdomain: %s",
1786             std::string(name_subdomain).c_str());
1787     return 0;
1788   }
1789   if (name_subdomain.back() == '.') {
1790     name_subdomain.remove_suffix(1);
1791   }
1792   return !entry.empty() && absl::EqualsIgnoreCase(name_subdomain, entry);
1793 }
1794 
ssl_server_handshaker_factory_servername_callback(SSL * ssl,int *,void * arg)1795 static int ssl_server_handshaker_factory_servername_callback(SSL* ssl,
1796                                                              int* /*ap*/,
1797                                                              void* arg) {
1798   tsi_ssl_server_handshaker_factory* impl =
1799       static_cast<tsi_ssl_server_handshaker_factory*>(arg);
1800   size_t i = 0;
1801   const char* servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1802   if (servername == nullptr || strlen(servername) == 0) {
1803     return SSL_TLSEXT_ERR_NOACK;
1804   }
1805 
1806   for (i = 0; i < impl->ssl_context_count; i++) {
1807     if (tsi_ssl_peer_matches_name(&impl->ssl_context_x509_subject_names[i],
1808                                   servername)) {
1809       SSL_set_SSL_CTX(ssl, impl->ssl_contexts[i]);
1810       return SSL_TLSEXT_ERR_OK;
1811     }
1812   }
1813   gpr_log(GPR_ERROR, "No match found for server name: %s.", servername);
1814   return SSL_TLSEXT_ERR_NOACK;
1815 }
1816 
1817 #if TSI_OPENSSL_ALPN_SUPPORT
server_handshaker_factory_alpn_callback(SSL *,const unsigned char ** out,unsigned char * outlen,const unsigned char * in,unsigned int inlen,void * arg)1818 static int server_handshaker_factory_alpn_callback(
1819     SSL* /*ssl*/, const unsigned char** out, unsigned char* outlen,
1820     const unsigned char* in, unsigned int inlen, void* arg) {
1821   tsi_ssl_server_handshaker_factory* factory =
1822       static_cast<tsi_ssl_server_handshaker_factory*>(arg);
1823   return select_protocol_list(out, outlen, in, inlen,
1824                               factory->alpn_protocol_list,
1825                               factory->alpn_protocol_list_length);
1826 }
1827 #endif /* TSI_OPENSSL_ALPN_SUPPORT */
1828 
server_handshaker_factory_npn_advertised_callback(SSL *,const unsigned char ** out,unsigned int * outlen,void * arg)1829 static int server_handshaker_factory_npn_advertised_callback(
1830     SSL* /*ssl*/, const unsigned char** out, unsigned int* outlen, void* arg) {
1831   tsi_ssl_server_handshaker_factory* factory =
1832       static_cast<tsi_ssl_server_handshaker_factory*>(arg);
1833   *out = factory->alpn_protocol_list;
1834   GPR_ASSERT(factory->alpn_protocol_list_length <= UINT_MAX);
1835   *outlen = static_cast<unsigned int>(factory->alpn_protocol_list_length);
1836   return SSL_TLSEXT_ERR_OK;
1837 }
1838 
1839 /// This callback is called when new \a session is established and ready to
1840 /// be cached. This session can be reused for new connections to similar
1841 /// servers at later point of time.
1842 /// It's intended to be used with SSL_CTX_sess_set_new_cb function.
1843 ///
1844 /// It returns 1 if callback takes ownership over \a session and 0 otherwise.
server_handshaker_factory_new_session_callback(SSL * ssl,SSL_SESSION * session)1845 static int server_handshaker_factory_new_session_callback(
1846     SSL* ssl, SSL_SESSION* session) {
1847   SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl);
1848   if (ssl_context == nullptr) {
1849     return 0;
1850   }
1851   void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index);
1852   tsi_ssl_client_handshaker_factory* factory =
1853       static_cast<tsi_ssl_client_handshaker_factory*>(arg);
1854   const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1855   if (server_name == nullptr) {
1856     return 0;
1857   }
1858   factory->session_cache->Put(server_name, tsi::SslSessionPtr(session));
1859   // Return 1 to indicate transferred ownership over the given session.
1860   return 1;
1861 }
1862 
1863 /* --- tsi_ssl_handshaker_factory constructors. --- */
1864 
1865 static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = {
1866     tsi_ssl_client_handshaker_factory_destroy};
1867 
tsi_create_ssl_client_handshaker_factory(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pair,const char * pem_root_certs,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_client_handshaker_factory ** factory)1868 tsi_result tsi_create_ssl_client_handshaker_factory(
1869     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pair,
1870     const char* pem_root_certs, const char* cipher_suites,
1871     const char** alpn_protocols, uint16_t num_alpn_protocols,
1872     tsi_ssl_client_handshaker_factory** factory) {
1873   tsi_ssl_client_handshaker_options options;
1874   options.pem_key_cert_pair = pem_key_cert_pair;
1875   options.pem_root_certs = pem_root_certs;
1876   options.cipher_suites = cipher_suites;
1877   options.alpn_protocols = alpn_protocols;
1878   options.num_alpn_protocols = num_alpn_protocols;
1879   return tsi_create_ssl_client_handshaker_factory_with_options(&options,
1880                                                                factory);
1881 }
1882 
tsi_create_ssl_client_handshaker_factory_with_options(const tsi_ssl_client_handshaker_options * options,tsi_ssl_client_handshaker_factory ** factory)1883 tsi_result tsi_create_ssl_client_handshaker_factory_with_options(
1884     const tsi_ssl_client_handshaker_options* options,
1885     tsi_ssl_client_handshaker_factory** factory) {
1886   SSL_CTX* ssl_context = nullptr;
1887   tsi_ssl_client_handshaker_factory* impl = nullptr;
1888   tsi_result result = TSI_OK;
1889 
1890   gpr_once_init(&g_init_openssl_once, init_openssl);
1891 
1892   if (factory == nullptr) return TSI_INVALID_ARGUMENT;
1893   *factory = nullptr;
1894   if (options->pem_root_certs == nullptr && options->root_store == nullptr) {
1895     return TSI_INVALID_ARGUMENT;
1896   }
1897 
1898 #if OPENSSL_VERSION_NUMBER >= 0x10100000
1899   ssl_context = SSL_CTX_new(TLS_method());
1900 #else
1901   ssl_context = SSL_CTX_new(TLSv1_2_method());
1902 #endif
1903   result = tsi_set_min_and_max_tls_versions(
1904       ssl_context, options->min_tls_version, options->max_tls_version);
1905   if (result != TSI_OK) return result;
1906   if (ssl_context == nullptr) {
1907     gpr_log(GPR_ERROR, "Could not create ssl context.");
1908     return TSI_INVALID_ARGUMENT;
1909   }
1910 
1911   impl = static_cast<tsi_ssl_client_handshaker_factory*>(
1912       gpr_zalloc(sizeof(*impl)));
1913   tsi_ssl_handshaker_factory_init(&impl->base);
1914   impl->base.vtable = &client_handshaker_factory_vtable;
1915   impl->ssl_context = ssl_context;
1916   if (options->session_cache != nullptr) {
1917     // Unref is called manually on factory destruction.
1918     impl->session_cache =
1919         reinterpret_cast<tsi::SslSessionLRUCache*>(options->session_cache)
1920             ->Ref();
1921     SSL_CTX_set_ex_data(ssl_context, g_ssl_ctx_ex_factory_index, impl);
1922     SSL_CTX_sess_set_new_cb(ssl_context,
1923                             server_handshaker_factory_new_session_callback);
1924     SSL_CTX_set_session_cache_mode(ssl_context, SSL_SESS_CACHE_CLIENT);
1925   }
1926 
1927   do {
1928     result = populate_ssl_context(ssl_context, options->pem_key_cert_pair,
1929                                   options->cipher_suites);
1930     if (result != TSI_OK) break;
1931 
1932 #if OPENSSL_VERSION_NUMBER >= 0x10100000 && !(defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2070000fL)
1933     // X509_STORE_up_ref is only available since OpenSSL 1.1.
1934     if (options->root_store != nullptr) {
1935       X509_STORE_up_ref(options->root_store->store);
1936       SSL_CTX_set_cert_store(ssl_context, options->root_store->store);
1937     }
1938 #endif
1939     if (OPENSSL_VERSION_NUMBER < 0x10100000 || options->root_store == nullptr) {
1940       result = ssl_ctx_load_verification_certs(
1941           ssl_context, options->pem_root_certs, strlen(options->pem_root_certs),
1942           nullptr);
1943       if (result != TSI_OK) {
1944         gpr_log(GPR_ERROR, "Cannot load server root certificates.");
1945         break;
1946       }
1947     }
1948 
1949     if (options->num_alpn_protocols != 0) {
1950       result = build_alpn_protocol_name_list(
1951           options->alpn_protocols, options->num_alpn_protocols,
1952           &impl->alpn_protocol_list, &impl->alpn_protocol_list_length);
1953       if (result != TSI_OK) {
1954         gpr_log(GPR_ERROR, "Building alpn list failed with error %s.",
1955                 tsi_result_to_string(result));
1956         break;
1957       }
1958 #if TSI_OPENSSL_ALPN_SUPPORT
1959       GPR_ASSERT(impl->alpn_protocol_list_length < UINT_MAX);
1960       if (SSL_CTX_set_alpn_protos(
1961               ssl_context, impl->alpn_protocol_list,
1962               static_cast<unsigned int>(impl->alpn_protocol_list_length))) {
1963         gpr_log(GPR_ERROR, "Could not set alpn protocol list to context.");
1964         result = TSI_INVALID_ARGUMENT;
1965         break;
1966       }
1967 #endif /* TSI_OPENSSL_ALPN_SUPPORT */
1968       SSL_CTX_set_next_proto_select_cb(
1969           ssl_context, client_handshaker_factory_npn_callback, impl);
1970     }
1971   } while (false);
1972   if (result != TSI_OK) {
1973     tsi_ssl_handshaker_factory_unref(&impl->base);
1974     return result;
1975   }
1976   if (options->skip_server_certificate_verification) {
1977     SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, NullVerifyCallback);
1978   } else {
1979     SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, nullptr);
1980   }
1981   /* TODO(jboeuf): Add revocation verification. */
1982 
1983   *factory = impl;
1984   return TSI_OK;
1985 }
1986 
1987 static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = {
1988     tsi_ssl_server_handshaker_factory_destroy};
1989 
tsi_create_ssl_server_handshaker_factory(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pairs,size_t num_key_cert_pairs,const char * pem_client_root_certs,int force_client_auth,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_server_handshaker_factory ** factory)1990 tsi_result tsi_create_ssl_server_handshaker_factory(
1991     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
1992     size_t num_key_cert_pairs, const char* pem_client_root_certs,
1993     int force_client_auth, const char* cipher_suites,
1994     const char** alpn_protocols, uint16_t num_alpn_protocols,
1995     tsi_ssl_server_handshaker_factory** factory) {
1996   return tsi_create_ssl_server_handshaker_factory_ex(
1997       pem_key_cert_pairs, num_key_cert_pairs, pem_client_root_certs,
1998       force_client_auth ? TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY
1999                         : TSI_DONT_REQUEST_CLIENT_CERTIFICATE,
2000       cipher_suites, alpn_protocols, num_alpn_protocols, factory);
2001 }
2002 
tsi_create_ssl_server_handshaker_factory_ex(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pairs,size_t num_key_cert_pairs,const char * pem_client_root_certs,tsi_client_certificate_request_type client_certificate_request,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_server_handshaker_factory ** factory)2003 tsi_result tsi_create_ssl_server_handshaker_factory_ex(
2004     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
2005     size_t num_key_cert_pairs, const char* pem_client_root_certs,
2006     tsi_client_certificate_request_type client_certificate_request,
2007     const char* cipher_suites, const char** alpn_protocols,
2008     uint16_t num_alpn_protocols, tsi_ssl_server_handshaker_factory** factory) {
2009   tsi_ssl_server_handshaker_options options;
2010   options.pem_key_cert_pairs = pem_key_cert_pairs;
2011   options.num_key_cert_pairs = num_key_cert_pairs;
2012   options.pem_client_root_certs = pem_client_root_certs;
2013   options.client_certificate_request = client_certificate_request;
2014   options.cipher_suites = cipher_suites;
2015   options.alpn_protocols = alpn_protocols;
2016   options.num_alpn_protocols = num_alpn_protocols;
2017   return tsi_create_ssl_server_handshaker_factory_with_options(&options,
2018                                                                factory);
2019 }
2020 
tsi_create_ssl_server_handshaker_factory_with_options(const tsi_ssl_server_handshaker_options * options,tsi_ssl_server_handshaker_factory ** factory)2021 tsi_result tsi_create_ssl_server_handshaker_factory_with_options(
2022     const tsi_ssl_server_handshaker_options* options,
2023     tsi_ssl_server_handshaker_factory** factory) {
2024   tsi_ssl_server_handshaker_factory* impl = nullptr;
2025   tsi_result result = TSI_OK;
2026   size_t i = 0;
2027 
2028   gpr_once_init(&g_init_openssl_once, init_openssl);
2029 
2030   if (factory == nullptr) return TSI_INVALID_ARGUMENT;
2031   *factory = nullptr;
2032   if (options->num_key_cert_pairs == 0 ||
2033       options->pem_key_cert_pairs == nullptr) {
2034     return TSI_INVALID_ARGUMENT;
2035   }
2036 
2037   impl = static_cast<tsi_ssl_server_handshaker_factory*>(
2038       gpr_zalloc(sizeof(*impl)));
2039   tsi_ssl_handshaker_factory_init(&impl->base);
2040   impl->base.vtable = &server_handshaker_factory_vtable;
2041 
2042   impl->ssl_contexts = static_cast<SSL_CTX**>(
2043       gpr_zalloc(options->num_key_cert_pairs * sizeof(SSL_CTX*)));
2044   impl->ssl_context_x509_subject_names = static_cast<tsi_peer*>(
2045       gpr_zalloc(options->num_key_cert_pairs * sizeof(tsi_peer)));
2046   if (impl->ssl_contexts == nullptr ||
2047       impl->ssl_context_x509_subject_names == nullptr) {
2048     tsi_ssl_handshaker_factory_unref(&impl->base);
2049     return TSI_OUT_OF_RESOURCES;
2050   }
2051   impl->ssl_context_count = options->num_key_cert_pairs;
2052 
2053   if (options->num_alpn_protocols > 0) {
2054     result = build_alpn_protocol_name_list(
2055         options->alpn_protocols, options->num_alpn_protocols,
2056         &impl->alpn_protocol_list, &impl->alpn_protocol_list_length);
2057     if (result != TSI_OK) {
2058       tsi_ssl_handshaker_factory_unref(&impl->base);
2059       return result;
2060     }
2061   }
2062 
2063   for (i = 0; i < options->num_key_cert_pairs; i++) {
2064     do {
2065 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2066       impl->ssl_contexts[i] = SSL_CTX_new(TLS_method());
2067 #else
2068       impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method());
2069 #endif
2070       result = tsi_set_min_and_max_tls_versions(impl->ssl_contexts[i],
2071                                                 options->min_tls_version,
2072                                                 options->max_tls_version);
2073       if (result != TSI_OK) return result;
2074       if (impl->ssl_contexts[i] == nullptr) {
2075         gpr_log(GPR_ERROR, "Could not create ssl context.");
2076         result = TSI_OUT_OF_RESOURCES;
2077         break;
2078       }
2079       result = populate_ssl_context(impl->ssl_contexts[i],
2080                                     &options->pem_key_cert_pairs[i],
2081                                     options->cipher_suites);
2082       if (result != TSI_OK) break;
2083 
2084       // TODO(elessar): Provide ability to disable session ticket keys.
2085 
2086       // Allow client cache sessions (it's needed for OpenSSL only).
2087       int set_sid_ctx_result = SSL_CTX_set_session_id_context(
2088           impl->ssl_contexts[i], kSslSessionIdContext,
2089           GPR_ARRAY_SIZE(kSslSessionIdContext));
2090       if (set_sid_ctx_result == 0) {
2091         gpr_log(GPR_ERROR, "Failed to set session id context.");
2092         result = TSI_INTERNAL_ERROR;
2093         break;
2094       }
2095 
2096       if (options->session_ticket_key != nullptr) {
2097         if (SSL_CTX_set_tlsext_ticket_keys(
2098                 impl->ssl_contexts[i],
2099                 const_cast<char*>(options->session_ticket_key),
2100                 options->session_ticket_key_size) == 0) {
2101           gpr_log(GPR_ERROR, "Invalid STEK size.");
2102           result = TSI_INVALID_ARGUMENT;
2103           break;
2104         }
2105       }
2106 
2107       if (options->pem_client_root_certs != nullptr) {
2108         STACK_OF(X509_NAME)* root_names = nullptr;
2109         result = ssl_ctx_load_verification_certs(
2110             impl->ssl_contexts[i], options->pem_client_root_certs,
2111             strlen(options->pem_client_root_certs), &root_names);
2112         if (result != TSI_OK) {
2113           gpr_log(GPR_ERROR, "Invalid verification certs.");
2114           break;
2115         }
2116         SSL_CTX_set_client_CA_list(impl->ssl_contexts[i], root_names);
2117       }
2118       switch (options->client_certificate_request) {
2119         case TSI_DONT_REQUEST_CLIENT_CERTIFICATE:
2120           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr);
2121           break;
2122         case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
2123           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER,
2124                              NullVerifyCallback);
2125           break;
2126         case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
2127           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr);
2128           break;
2129         case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
2130           SSL_CTX_set_verify(impl->ssl_contexts[i],
2131                              SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
2132                              NullVerifyCallback);
2133           break;
2134         case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
2135           SSL_CTX_set_verify(impl->ssl_contexts[i],
2136                              SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
2137                              nullptr);
2138           break;
2139       }
2140       /* TODO(jboeuf): Add revocation verification. */
2141 
2142       result = tsi_ssl_extract_x509_subject_names_from_pem_cert(
2143           options->pem_key_cert_pairs[i].cert_chain,
2144           &impl->ssl_context_x509_subject_names[i]);
2145       if (result != TSI_OK) break;
2146 
2147       SSL_CTX_set_tlsext_servername_callback(
2148           impl->ssl_contexts[i],
2149           ssl_server_handshaker_factory_servername_callback);
2150       SSL_CTX_set_tlsext_servername_arg(impl->ssl_contexts[i], impl);
2151 #if TSI_OPENSSL_ALPN_SUPPORT
2152       SSL_CTX_set_alpn_select_cb(impl->ssl_contexts[i],
2153                                  server_handshaker_factory_alpn_callback, impl);
2154 #endif /* TSI_OPENSSL_ALPN_SUPPORT */
2155       SSL_CTX_set_next_protos_advertised_cb(
2156           impl->ssl_contexts[i],
2157           server_handshaker_factory_npn_advertised_callback, impl);
2158     } while (false);
2159 
2160     if (result != TSI_OK) {
2161       tsi_ssl_handshaker_factory_unref(&impl->base);
2162       return result;
2163     }
2164   }
2165 
2166   *factory = impl;
2167   return TSI_OK;
2168 }
2169 
2170 /* --- tsi_ssl utils. --- */
2171 
tsi_ssl_peer_matches_name(const tsi_peer * peer,absl::string_view name)2172 int tsi_ssl_peer_matches_name(const tsi_peer* peer, absl::string_view name) {
2173   size_t i = 0;
2174   size_t san_count = 0;
2175   const tsi_peer_property* cn_property = nullptr;
2176   int like_ip = looks_like_ip_address(name);
2177 
2178   /* Check the SAN first. */
2179   for (i = 0; i < peer->property_count; i++) {
2180     const tsi_peer_property* property = &peer->properties[i];
2181     if (property->name == nullptr) continue;
2182     if (strcmp(property->name,
2183                TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) {
2184       san_count++;
2185 
2186       absl::string_view entry(property->value.data, property->value.length);
2187       if (!like_ip && does_entry_match_name(entry, name)) {
2188         return 1;
2189       } else if (like_ip && name == entry) {
2190         /* IP Addresses are exact matches only. */
2191         return 1;
2192       }
2193     } else if (strcmp(property->name,
2194                       TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY) == 0) {
2195       cn_property = property;
2196     }
2197   }
2198 
2199   /* If there's no SAN, try the CN, but only if its not like an IP Address */
2200   if (san_count == 0 && cn_property != nullptr && !like_ip) {
2201     if (does_entry_match_name(absl::string_view(cn_property->value.data,
2202                                                 cn_property->value.length),
2203                               name)) {
2204       return 1;
2205     }
2206   }
2207 
2208   return 0; /* Not found. */
2209 }
2210 
2211 /* --- Testing support. --- */
tsi_ssl_handshaker_factory_swap_vtable(tsi_ssl_handshaker_factory * factory,tsi_ssl_handshaker_factory_vtable * new_vtable)2212 const tsi_ssl_handshaker_factory_vtable* tsi_ssl_handshaker_factory_swap_vtable(
2213     tsi_ssl_handshaker_factory* factory,
2214     tsi_ssl_handshaker_factory_vtable* new_vtable) {
2215   GPR_ASSERT(factory != nullptr);
2216   GPR_ASSERT(factory->vtable != nullptr);
2217 
2218   const tsi_ssl_handshaker_factory_vtable* orig_vtable = factory->vtable;
2219   factory->vtable = new_vtable;
2220   return orig_vtable;
2221 }
2222