1*1dcdf01fSchristos /*
2*1dcdf01fSchristos  * Copyright 1995-2018 The OpenSSL Project Authors. All Rights Reserved.
3*1dcdf01fSchristos  *
4*1dcdf01fSchristos  * Licensed under the OpenSSL license (the "License").  You may not use
5*1dcdf01fSchristos  * this file except in compliance with the License.  You can obtain a copy
6*1dcdf01fSchristos  * in the file LICENSE in the source distribution or at
7*1dcdf01fSchristos  * https://www.openssl.org/source/license.html
8*1dcdf01fSchristos  */
9*1dcdf01fSchristos 
10*1dcdf01fSchristos #if !defined(__STDC_FORMAT_MACROS)
11*1dcdf01fSchristos #define __STDC_FORMAT_MACROS
12*1dcdf01fSchristos #endif
13*1dcdf01fSchristos 
14*1dcdf01fSchristos #include "packeted_bio.h"
15*1dcdf01fSchristos #include <openssl/e_os2.h>
16*1dcdf01fSchristos 
17*1dcdf01fSchristos #if !defined(OPENSSL_SYS_WINDOWS)
18*1dcdf01fSchristos #include <arpa/inet.h>
19*1dcdf01fSchristos #include <netinet/in.h>
20*1dcdf01fSchristos #include <netinet/tcp.h>
21*1dcdf01fSchristos #include <signal.h>
22*1dcdf01fSchristos #include <sys/socket.h>
23*1dcdf01fSchristos #include <sys/time.h>
24*1dcdf01fSchristos #include <unistd.h>
25*1dcdf01fSchristos #else
26*1dcdf01fSchristos #include <io.h>
27*1dcdf01fSchristos OPENSSL_MSVC_PRAGMA(warning(push, 3))
28*1dcdf01fSchristos #include <winsock2.h>
29*1dcdf01fSchristos #include <ws2tcpip.h>
30*1dcdf01fSchristos OPENSSL_MSVC_PRAGMA(warning(pop))
31*1dcdf01fSchristos 
32*1dcdf01fSchristos OPENSSL_MSVC_PRAGMA(comment(lib, "Ws2_32.lib"))
33*1dcdf01fSchristos #endif
34*1dcdf01fSchristos 
35*1dcdf01fSchristos #include <assert.h>
36*1dcdf01fSchristos #include <inttypes.h>
37*1dcdf01fSchristos #include <string.h>
38*1dcdf01fSchristos 
39*1dcdf01fSchristos #include <openssl/bio.h>
40*1dcdf01fSchristos #include <openssl/buffer.h>
41*1dcdf01fSchristos #include <openssl/bn.h>
42*1dcdf01fSchristos #include <openssl/crypto.h>
43*1dcdf01fSchristos #include <openssl/dh.h>
44*1dcdf01fSchristos #include <openssl/err.h>
45*1dcdf01fSchristos #include <openssl/evp.h>
46*1dcdf01fSchristos #include <openssl/hmac.h>
47*1dcdf01fSchristos #include <openssl/objects.h>
48*1dcdf01fSchristos #include <openssl/rand.h>
49*1dcdf01fSchristos #include <openssl/ssl.h>
50*1dcdf01fSchristos #include <openssl/x509.h>
51*1dcdf01fSchristos 
52*1dcdf01fSchristos #include <memory>
53*1dcdf01fSchristos #include <string>
54*1dcdf01fSchristos #include <vector>
55*1dcdf01fSchristos 
56*1dcdf01fSchristos #include "async_bio.h"
57*1dcdf01fSchristos #include "test_config.h"
58*1dcdf01fSchristos 
59*1dcdf01fSchristos namespace bssl {
60*1dcdf01fSchristos 
#if !defined(OPENSSL_SYS_WINDOWS)
// POSIX has no closesocket(); provide the Winsock spelling so the rest of the
// file can close sockets identically on every platform.
static int closesocket(int sock) {
  const int rc = close(sock);
  return rc;
}

// Reports the most recent socket error for |func| using errno.
static void PrintSocketError(const char *func) {
  perror(func);
}
#else
// Reports the most recent socket error for |func| using the Winsock error code.
static void PrintSocketError(const char *func) {
  const int err = WSAGetLastError();
  fprintf(stderr, "%s: %d\n", func, err);
}
#endif
74*1dcdf01fSchristos 
// Prints a one-line usage message for |program| to stderr and returns the
// exit status the caller should use for invalid arguments.
static int Usage(const char *program) {
  static const int kBadUsageStatus = 1;
  fprintf(stderr, "Usage: %s [flags...]\n", program);
  return kBadUsageStatus;
}
79*1dcdf01fSchristos 
80*1dcdf01fSchristos struct TestState {
81*1dcdf01fSchristos   // async_bio is async BIO which pauses reads and writes.
82*1dcdf01fSchristos   BIO *async_bio = nullptr;
83*1dcdf01fSchristos   // packeted_bio is the packeted BIO which simulates read timeouts.
84*1dcdf01fSchristos   BIO *packeted_bio = nullptr;
85*1dcdf01fSchristos   bool cert_ready = false;
86*1dcdf01fSchristos   bool handshake_done = false;
87*1dcdf01fSchristos   // private_key is the underlying private key used when testing custom keys.
88*1dcdf01fSchristos   bssl::UniquePtr<EVP_PKEY> private_key;
89*1dcdf01fSchristos   bool got_new_session = false;
90*1dcdf01fSchristos   bssl::UniquePtr<SSL_SESSION> new_session;
91*1dcdf01fSchristos   bool ticket_decrypt_done = false;
92*1dcdf01fSchristos   bool alpn_select_done = false;
93*1dcdf01fSchristos };
94*1dcdf01fSchristos 
TestStateExFree(void * parent,void * ptr,CRYPTO_EX_DATA * ad,int index,long argl,void * argp)95*1dcdf01fSchristos static void TestStateExFree(void *parent, void *ptr, CRYPTO_EX_DATA *ad,
96*1dcdf01fSchristos                             int index, long argl, void *argp) {
97*1dcdf01fSchristos   delete ((TestState *)ptr);
98*1dcdf01fSchristos }
99*1dcdf01fSchristos 
100*1dcdf01fSchristos static int g_config_index = 0;
101*1dcdf01fSchristos static int g_state_index = 0;
102*1dcdf01fSchristos 
SetTestConfig(SSL * ssl,const TestConfig * config)103*1dcdf01fSchristos static bool SetTestConfig(SSL *ssl, const TestConfig *config) {
104*1dcdf01fSchristos   return SSL_set_ex_data(ssl, g_config_index, (void *)config) == 1;
105*1dcdf01fSchristos }
106*1dcdf01fSchristos 
GetTestConfig(const SSL * ssl)107*1dcdf01fSchristos static const TestConfig *GetTestConfig(const SSL *ssl) {
108*1dcdf01fSchristos   return (const TestConfig *)SSL_get_ex_data(ssl, g_config_index);
109*1dcdf01fSchristos }
110*1dcdf01fSchristos 
SetTestState(SSL * ssl,std::unique_ptr<TestState> state)111*1dcdf01fSchristos static bool SetTestState(SSL *ssl, std::unique_ptr<TestState> state) {
112*1dcdf01fSchristos   // |SSL_set_ex_data| takes ownership of |state| only on success.
113*1dcdf01fSchristos   if (SSL_set_ex_data(ssl, g_state_index, state.get()) == 1) {
114*1dcdf01fSchristos     state.release();
115*1dcdf01fSchristos     return true;
116*1dcdf01fSchristos   }
117*1dcdf01fSchristos   return false;
118*1dcdf01fSchristos }
119*1dcdf01fSchristos 
GetTestState(const SSL * ssl)120*1dcdf01fSchristos static TestState *GetTestState(const SSL *ssl) {
121*1dcdf01fSchristos   return (TestState *)SSL_get_ex_data(ssl, g_state_index);
122*1dcdf01fSchristos }
123*1dcdf01fSchristos 
LoadCertificate(const std::string & file)124*1dcdf01fSchristos static bssl::UniquePtr<X509> LoadCertificate(const std::string &file) {
125*1dcdf01fSchristos   bssl::UniquePtr<BIO> bio(BIO_new(BIO_s_file()));
126*1dcdf01fSchristos   if (!bio || !BIO_read_filename(bio.get(), file.c_str())) {
127*1dcdf01fSchristos     return nullptr;
128*1dcdf01fSchristos   }
129*1dcdf01fSchristos   return bssl::UniquePtr<X509>(PEM_read_bio_X509(bio.get(), NULL, NULL, NULL));
130*1dcdf01fSchristos }
131*1dcdf01fSchristos 
LoadPrivateKey(const std::string & file)132*1dcdf01fSchristos static bssl::UniquePtr<EVP_PKEY> LoadPrivateKey(const std::string &file) {
133*1dcdf01fSchristos   bssl::UniquePtr<BIO> bio(BIO_new(BIO_s_file()));
134*1dcdf01fSchristos   if (!bio || !BIO_read_filename(bio.get(), file.c_str())) {
135*1dcdf01fSchristos     return nullptr;
136*1dcdf01fSchristos   }
137*1dcdf01fSchristos   return bssl::UniquePtr<EVP_PKEY>(
138*1dcdf01fSchristos       PEM_read_bio_PrivateKey(bio.get(), NULL, NULL, NULL));
139*1dcdf01fSchristos }
140*1dcdf01fSchristos 
// Free is a deleter for buffers obtained from malloc(), e.g. for use as
// std::unique_ptr<T, Free<T>>.
template<typename T>
struct Free {
  // The deleter is stateless, so the call operator is const (usable through a
  // const deleter). It is noexcept because free() cannot throw.
  void operator()(T *buf) const noexcept {
    free(buf);
  }
};
147*1dcdf01fSchristos 
GetCertificate(SSL * ssl,bssl::UniquePtr<X509> * out_x509,bssl::UniquePtr<EVP_PKEY> * out_pkey)148*1dcdf01fSchristos static bool GetCertificate(SSL *ssl, bssl::UniquePtr<X509> *out_x509,
149*1dcdf01fSchristos                            bssl::UniquePtr<EVP_PKEY> *out_pkey) {
150*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
151*1dcdf01fSchristos 
152*1dcdf01fSchristos   if (!config->key_file.empty()) {
153*1dcdf01fSchristos     *out_pkey = LoadPrivateKey(config->key_file.c_str());
154*1dcdf01fSchristos     if (!*out_pkey) {
155*1dcdf01fSchristos       return false;
156*1dcdf01fSchristos     }
157*1dcdf01fSchristos   }
158*1dcdf01fSchristos   if (!config->cert_file.empty()) {
159*1dcdf01fSchristos     *out_x509 = LoadCertificate(config->cert_file.c_str());
160*1dcdf01fSchristos     if (!*out_x509) {
161*1dcdf01fSchristos       return false;
162*1dcdf01fSchristos     }
163*1dcdf01fSchristos   }
164*1dcdf01fSchristos   return true;
165*1dcdf01fSchristos }
166*1dcdf01fSchristos 
InstallCertificate(SSL * ssl)167*1dcdf01fSchristos static bool InstallCertificate(SSL *ssl) {
168*1dcdf01fSchristos   bssl::UniquePtr<X509> x509;
169*1dcdf01fSchristos   bssl::UniquePtr<EVP_PKEY> pkey;
170*1dcdf01fSchristos   if (!GetCertificate(ssl, &x509, &pkey)) {
171*1dcdf01fSchristos     return false;
172*1dcdf01fSchristos   }
173*1dcdf01fSchristos 
174*1dcdf01fSchristos   if (pkey && !SSL_use_PrivateKey(ssl, pkey.get())) {
175*1dcdf01fSchristos     return false;
176*1dcdf01fSchristos   }
177*1dcdf01fSchristos 
178*1dcdf01fSchristos   if (x509 && !SSL_use_certificate(ssl, x509.get())) {
179*1dcdf01fSchristos     return false;
180*1dcdf01fSchristos   }
181*1dcdf01fSchristos 
182*1dcdf01fSchristos   return true;
183*1dcdf01fSchristos }
184*1dcdf01fSchristos 
ClientCertCallback(SSL * ssl,X509 ** out_x509,EVP_PKEY ** out_pkey)185*1dcdf01fSchristos static int ClientCertCallback(SSL *ssl, X509 **out_x509, EVP_PKEY **out_pkey) {
186*1dcdf01fSchristos   if (GetTestConfig(ssl)->async && !GetTestState(ssl)->cert_ready) {
187*1dcdf01fSchristos     return -1;
188*1dcdf01fSchristos   }
189*1dcdf01fSchristos 
190*1dcdf01fSchristos   bssl::UniquePtr<X509> x509;
191*1dcdf01fSchristos   bssl::UniquePtr<EVP_PKEY> pkey;
192*1dcdf01fSchristos   if (!GetCertificate(ssl, &x509, &pkey)) {
193*1dcdf01fSchristos     return -1;
194*1dcdf01fSchristos   }
195*1dcdf01fSchristos 
196*1dcdf01fSchristos   // Return zero for no certificate.
197*1dcdf01fSchristos   if (!x509) {
198*1dcdf01fSchristos     return 0;
199*1dcdf01fSchristos   }
200*1dcdf01fSchristos 
201*1dcdf01fSchristos   // Asynchronous private keys are not supported with client_cert_cb.
202*1dcdf01fSchristos   *out_x509 = x509.release();
203*1dcdf01fSchristos   *out_pkey = pkey.release();
204*1dcdf01fSchristos   return 1;
205*1dcdf01fSchristos }
206*1dcdf01fSchristos 
VerifySucceed(X509_STORE_CTX * store_ctx,void * arg)207*1dcdf01fSchristos static int VerifySucceed(X509_STORE_CTX *store_ctx, void *arg) {
208*1dcdf01fSchristos   return 1;
209*1dcdf01fSchristos }
210*1dcdf01fSchristos 
VerifyFail(X509_STORE_CTX * store_ctx,void * arg)211*1dcdf01fSchristos static int VerifyFail(X509_STORE_CTX *store_ctx, void *arg) {
212*1dcdf01fSchristos   X509_STORE_CTX_set_error(store_ctx, X509_V_ERR_APPLICATION_VERIFICATION);
213*1dcdf01fSchristos   return 0;
214*1dcdf01fSchristos }
215*1dcdf01fSchristos 
NextProtosAdvertisedCallback(SSL * ssl,const uint8_t ** out,unsigned int * out_len,void * arg)216*1dcdf01fSchristos static int NextProtosAdvertisedCallback(SSL *ssl, const uint8_t **out,
217*1dcdf01fSchristos                                         unsigned int *out_len, void *arg) {
218*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
219*1dcdf01fSchristos   if (config->advertise_npn.empty()) {
220*1dcdf01fSchristos     return SSL_TLSEXT_ERR_NOACK;
221*1dcdf01fSchristos   }
222*1dcdf01fSchristos 
223*1dcdf01fSchristos   *out = (const uint8_t*)config->advertise_npn.data();
224*1dcdf01fSchristos   *out_len = config->advertise_npn.size();
225*1dcdf01fSchristos   return SSL_TLSEXT_ERR_OK;
226*1dcdf01fSchristos }
227*1dcdf01fSchristos 
NextProtoSelectCallback(SSL * ssl,uint8_t ** out,uint8_t * outlen,const uint8_t * in,unsigned inlen,void * arg)228*1dcdf01fSchristos static int NextProtoSelectCallback(SSL* ssl, uint8_t** out, uint8_t* outlen,
229*1dcdf01fSchristos                                    const uint8_t* in, unsigned inlen, void* arg) {
230*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
231*1dcdf01fSchristos   if (config->select_next_proto.empty()) {
232*1dcdf01fSchristos     return SSL_TLSEXT_ERR_NOACK;
233*1dcdf01fSchristos   }
234*1dcdf01fSchristos 
235*1dcdf01fSchristos   *out = (uint8_t*)config->select_next_proto.data();
236*1dcdf01fSchristos   *outlen = config->select_next_proto.size();
237*1dcdf01fSchristos   return SSL_TLSEXT_ERR_OK;
238*1dcdf01fSchristos }
239*1dcdf01fSchristos 
AlpnSelectCallback(SSL * ssl,const uint8_t ** out,uint8_t * outlen,const uint8_t * in,unsigned inlen,void * arg)240*1dcdf01fSchristos static int AlpnSelectCallback(SSL* ssl, const uint8_t** out, uint8_t* outlen,
241*1dcdf01fSchristos                               const uint8_t* in, unsigned inlen, void* arg) {
242*1dcdf01fSchristos   if (GetTestState(ssl)->alpn_select_done) {
243*1dcdf01fSchristos     fprintf(stderr, "AlpnSelectCallback called after completion.\n");
244*1dcdf01fSchristos     exit(1);
245*1dcdf01fSchristos   }
246*1dcdf01fSchristos 
247*1dcdf01fSchristos   GetTestState(ssl)->alpn_select_done = true;
248*1dcdf01fSchristos 
249*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
250*1dcdf01fSchristos   if (config->decline_alpn) {
251*1dcdf01fSchristos     return SSL_TLSEXT_ERR_NOACK;
252*1dcdf01fSchristos   }
253*1dcdf01fSchristos 
254*1dcdf01fSchristos   if (!config->expected_advertised_alpn.empty() &&
255*1dcdf01fSchristos       (config->expected_advertised_alpn.size() != inlen ||
256*1dcdf01fSchristos        memcmp(config->expected_advertised_alpn.data(),
257*1dcdf01fSchristos               in, inlen) != 0)) {
258*1dcdf01fSchristos     fprintf(stderr, "bad ALPN select callback inputs\n");
259*1dcdf01fSchristos     exit(1);
260*1dcdf01fSchristos   }
261*1dcdf01fSchristos 
262*1dcdf01fSchristos   *out = (const uint8_t*)config->select_alpn.data();
263*1dcdf01fSchristos   *outlen = config->select_alpn.size();
264*1dcdf01fSchristos   return SSL_TLSEXT_ERR_OK;
265*1dcdf01fSchristos }
266*1dcdf01fSchristos 
PskClientCallback(SSL * ssl,const char * hint,char * out_identity,unsigned max_identity_len,uint8_t * out_psk,unsigned max_psk_len)267*1dcdf01fSchristos static unsigned PskClientCallback(SSL *ssl, const char *hint,
268*1dcdf01fSchristos                                   char *out_identity,
269*1dcdf01fSchristos                                   unsigned max_identity_len,
270*1dcdf01fSchristos                                   uint8_t *out_psk, unsigned max_psk_len) {
271*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
272*1dcdf01fSchristos 
273*1dcdf01fSchristos   if (config->psk_identity.empty()) {
274*1dcdf01fSchristos     if (hint != nullptr) {
275*1dcdf01fSchristos       fprintf(stderr, "Server PSK hint was non-null.\n");
276*1dcdf01fSchristos       return 0;
277*1dcdf01fSchristos     }
278*1dcdf01fSchristos   } else if (hint == nullptr ||
279*1dcdf01fSchristos              strcmp(hint, config->psk_identity.c_str()) != 0) {
280*1dcdf01fSchristos     fprintf(stderr, "Server PSK hint did not match.\n");
281*1dcdf01fSchristos     return 0;
282*1dcdf01fSchristos   }
283*1dcdf01fSchristos 
284*1dcdf01fSchristos   // Account for the trailing '\0' for the identity.
285*1dcdf01fSchristos   if (config->psk_identity.size() >= max_identity_len ||
286*1dcdf01fSchristos       config->psk.size() > max_psk_len) {
287*1dcdf01fSchristos     fprintf(stderr, "PSK buffers too small\n");
288*1dcdf01fSchristos     return 0;
289*1dcdf01fSchristos   }
290*1dcdf01fSchristos 
291*1dcdf01fSchristos   BUF_strlcpy(out_identity, config->psk_identity.c_str(),
292*1dcdf01fSchristos               max_identity_len);
293*1dcdf01fSchristos   memcpy(out_psk, config->psk.data(), config->psk.size());
294*1dcdf01fSchristos   return config->psk.size();
295*1dcdf01fSchristos }
296*1dcdf01fSchristos 
PskServerCallback(SSL * ssl,const char * identity,uint8_t * out_psk,unsigned max_psk_len)297*1dcdf01fSchristos static unsigned PskServerCallback(SSL *ssl, const char *identity,
298*1dcdf01fSchristos                                   uint8_t *out_psk, unsigned max_psk_len) {
299*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
300*1dcdf01fSchristos 
301*1dcdf01fSchristos   if (strcmp(identity, config->psk_identity.c_str()) != 0) {
302*1dcdf01fSchristos     fprintf(stderr, "Client PSK identity did not match.\n");
303*1dcdf01fSchristos     return 0;
304*1dcdf01fSchristos   }
305*1dcdf01fSchristos 
306*1dcdf01fSchristos   if (config->psk.size() > max_psk_len) {
307*1dcdf01fSchristos     fprintf(stderr, "PSK buffers too small\n");
308*1dcdf01fSchristos     return 0;
309*1dcdf01fSchristos   }
310*1dcdf01fSchristos 
311*1dcdf01fSchristos   memcpy(out_psk, config->psk.data(), config->psk.size());
312*1dcdf01fSchristos   return config->psk.size();
313*1dcdf01fSchristos }
314*1dcdf01fSchristos 
CertCallback(SSL * ssl,void * arg)315*1dcdf01fSchristos static int CertCallback(SSL *ssl, void *arg) {
316*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
317*1dcdf01fSchristos 
318*1dcdf01fSchristos   // Check the CertificateRequest metadata is as expected.
319*1dcdf01fSchristos   //
320*1dcdf01fSchristos   // TODO(davidben): Test |SSL_get_client_CA_list|.
321*1dcdf01fSchristos   if (!SSL_is_server(ssl) &&
322*1dcdf01fSchristos       !config->expected_certificate_types.empty()) {
323*1dcdf01fSchristos     const uint8_t *certificate_types;
324*1dcdf01fSchristos     size_t certificate_types_len =
325*1dcdf01fSchristos         SSL_get0_certificate_types(ssl, &certificate_types);
326*1dcdf01fSchristos     if (certificate_types_len != config->expected_certificate_types.size() ||
327*1dcdf01fSchristos         memcmp(certificate_types,
328*1dcdf01fSchristos                config->expected_certificate_types.data(),
329*1dcdf01fSchristos                certificate_types_len) != 0) {
330*1dcdf01fSchristos       fprintf(stderr, "certificate types mismatch\n");
331*1dcdf01fSchristos       return 0;
332*1dcdf01fSchristos     }
333*1dcdf01fSchristos   }
334*1dcdf01fSchristos 
335*1dcdf01fSchristos   // The certificate will be installed via other means.
336*1dcdf01fSchristos   if (!config->async ||
337*1dcdf01fSchristos       config->use_old_client_cert_callback) {
338*1dcdf01fSchristos     return 1;
339*1dcdf01fSchristos   }
340*1dcdf01fSchristos 
341*1dcdf01fSchristos   if (!GetTestState(ssl)->cert_ready) {
342*1dcdf01fSchristos     return -1;
343*1dcdf01fSchristos   }
344*1dcdf01fSchristos   if (!InstallCertificate(ssl)) {
345*1dcdf01fSchristos     return 0;
346*1dcdf01fSchristos   }
347*1dcdf01fSchristos   return 1;
348*1dcdf01fSchristos }
349*1dcdf01fSchristos 
InfoCallback(const SSL * ssl,int type,int val)350*1dcdf01fSchristos static void InfoCallback(const SSL *ssl, int type, int val) {
351*1dcdf01fSchristos   if (type == SSL_CB_HANDSHAKE_DONE) {
352*1dcdf01fSchristos     if (GetTestConfig(ssl)->handshake_never_done) {
353*1dcdf01fSchristos       fprintf(stderr, "Handshake unexpectedly completed.\n");
354*1dcdf01fSchristos       // Abort before any expected error code is printed, to ensure the overall
355*1dcdf01fSchristos       // test fails.
356*1dcdf01fSchristos       abort();
357*1dcdf01fSchristos     }
358*1dcdf01fSchristos     GetTestState(ssl)->handshake_done = true;
359*1dcdf01fSchristos 
360*1dcdf01fSchristos     // Callbacks may be called again on a new handshake.
361*1dcdf01fSchristos     GetTestState(ssl)->ticket_decrypt_done = false;
362*1dcdf01fSchristos     GetTestState(ssl)->alpn_select_done = false;
363*1dcdf01fSchristos   }
364*1dcdf01fSchristos }
365*1dcdf01fSchristos 
NewSessionCallback(SSL * ssl,SSL_SESSION * session)366*1dcdf01fSchristos static int NewSessionCallback(SSL *ssl, SSL_SESSION *session) {
367*1dcdf01fSchristos   GetTestState(ssl)->got_new_session = true;
368*1dcdf01fSchristos   GetTestState(ssl)->new_session.reset(session);
369*1dcdf01fSchristos   return 1;
370*1dcdf01fSchristos }
371*1dcdf01fSchristos 
TicketKeyCallback(SSL * ssl,uint8_t * key_name,uint8_t * iv,EVP_CIPHER_CTX * ctx,HMAC_CTX * hmac_ctx,int encrypt)372*1dcdf01fSchristos static int TicketKeyCallback(SSL *ssl, uint8_t *key_name, uint8_t *iv,
373*1dcdf01fSchristos                              EVP_CIPHER_CTX *ctx, HMAC_CTX *hmac_ctx,
374*1dcdf01fSchristos                              int encrypt) {
375*1dcdf01fSchristos   if (!encrypt) {
376*1dcdf01fSchristos     if (GetTestState(ssl)->ticket_decrypt_done) {
377*1dcdf01fSchristos       fprintf(stderr, "TicketKeyCallback called after completion.\n");
378*1dcdf01fSchristos       return -1;
379*1dcdf01fSchristos     }
380*1dcdf01fSchristos 
381*1dcdf01fSchristos     GetTestState(ssl)->ticket_decrypt_done = true;
382*1dcdf01fSchristos   }
383*1dcdf01fSchristos 
384*1dcdf01fSchristos   // This is just test code, so use the all-zeros key.
385*1dcdf01fSchristos   static const uint8_t kZeros[16] = {0};
386*1dcdf01fSchristos 
387*1dcdf01fSchristos   if (encrypt) {
388*1dcdf01fSchristos     memcpy(key_name, kZeros, sizeof(kZeros));
389*1dcdf01fSchristos     RAND_bytes(iv, 16);
390*1dcdf01fSchristos   } else if (memcmp(key_name, kZeros, 16) != 0) {
391*1dcdf01fSchristos     return 0;
392*1dcdf01fSchristos   }
393*1dcdf01fSchristos 
394*1dcdf01fSchristos   if (!HMAC_Init_ex(hmac_ctx, kZeros, sizeof(kZeros), EVP_sha256(), NULL) ||
395*1dcdf01fSchristos       !EVP_CipherInit_ex(ctx, EVP_aes_128_cbc(), NULL, kZeros, iv, encrypt)) {
396*1dcdf01fSchristos     return -1;
397*1dcdf01fSchristos   }
398*1dcdf01fSchristos 
399*1dcdf01fSchristos   if (!encrypt) {
400*1dcdf01fSchristos     return GetTestConfig(ssl)->renew_ticket ? 2 : 1;
401*1dcdf01fSchristos   }
402*1dcdf01fSchristos   return 1;
403*1dcdf01fSchristos }
404*1dcdf01fSchristos 
// Extension type number handled by the custom-extension callbacks.
static const uint16_t kCustomExtensionValue = 1234;
// Opaque callback arguments: the extension value itself (for "add") and the
// value plus one (for "parse"). They are carried through void* so that each
// callback can check it received the argument it expects.
static void *const kCustomExtensionAddArg =
    reinterpret_cast<void *>(static_cast<uintptr_t>(kCustomExtensionValue));
static void *const kCustomExtensionParseArg =
    reinterpret_cast<void *>(static_cast<uintptr_t>(kCustomExtensionValue + 1));
// Payload the add callback emits and the parse callback requires (the trailing
// NUL is not part of the payload).
static const char kCustomExtensionContents[] = "custom extension";
413*1dcdf01fSchristos 
CustomExtensionAddCallback(SSL * ssl,unsigned extension_value,const uint8_t ** out,size_t * out_len,int * out_alert_value,void * add_arg)414*1dcdf01fSchristos static int CustomExtensionAddCallback(SSL *ssl, unsigned extension_value,
415*1dcdf01fSchristos                                       const uint8_t **out, size_t *out_len,
416*1dcdf01fSchristos                                       int *out_alert_value, void *add_arg) {
417*1dcdf01fSchristos   if (extension_value != kCustomExtensionValue ||
418*1dcdf01fSchristos       add_arg != kCustomExtensionAddArg) {
419*1dcdf01fSchristos     abort();
420*1dcdf01fSchristos   }
421*1dcdf01fSchristos 
422*1dcdf01fSchristos   if (GetTestConfig(ssl)->custom_extension_skip) {
423*1dcdf01fSchristos     return 0;
424*1dcdf01fSchristos   }
425*1dcdf01fSchristos   if (GetTestConfig(ssl)->custom_extension_fail_add) {
426*1dcdf01fSchristos     return -1;
427*1dcdf01fSchristos   }
428*1dcdf01fSchristos 
429*1dcdf01fSchristos   *out = reinterpret_cast<const uint8_t*>(kCustomExtensionContents);
430*1dcdf01fSchristos   *out_len = sizeof(kCustomExtensionContents) - 1;
431*1dcdf01fSchristos 
432*1dcdf01fSchristos   return 1;
433*1dcdf01fSchristos }
434*1dcdf01fSchristos 
CustomExtensionFreeCallback(SSL * ssl,unsigned extension_value,const uint8_t * out,void * add_arg)435*1dcdf01fSchristos static void CustomExtensionFreeCallback(SSL *ssl, unsigned extension_value,
436*1dcdf01fSchristos                                         const uint8_t *out, void *add_arg) {
437*1dcdf01fSchristos   if (extension_value != kCustomExtensionValue ||
438*1dcdf01fSchristos       add_arg != kCustomExtensionAddArg ||
439*1dcdf01fSchristos       out != reinterpret_cast<const uint8_t *>(kCustomExtensionContents)) {
440*1dcdf01fSchristos     abort();
441*1dcdf01fSchristos   }
442*1dcdf01fSchristos }
443*1dcdf01fSchristos 
CustomExtensionParseCallback(SSL * ssl,unsigned extension_value,const uint8_t * contents,size_t contents_len,int * out_alert_value,void * parse_arg)444*1dcdf01fSchristos static int CustomExtensionParseCallback(SSL *ssl, unsigned extension_value,
445*1dcdf01fSchristos                                         const uint8_t *contents,
446*1dcdf01fSchristos                                         size_t contents_len,
447*1dcdf01fSchristos                                         int *out_alert_value, void *parse_arg) {
448*1dcdf01fSchristos   if (extension_value != kCustomExtensionValue ||
449*1dcdf01fSchristos       parse_arg != kCustomExtensionParseArg) {
450*1dcdf01fSchristos     abort();
451*1dcdf01fSchristos   }
452*1dcdf01fSchristos 
453*1dcdf01fSchristos   if (contents_len != sizeof(kCustomExtensionContents) - 1 ||
454*1dcdf01fSchristos       memcmp(contents, kCustomExtensionContents, contents_len) != 0) {
455*1dcdf01fSchristos     *out_alert_value = SSL_AD_DECODE_ERROR;
456*1dcdf01fSchristos     return 0;
457*1dcdf01fSchristos   }
458*1dcdf01fSchristos 
459*1dcdf01fSchristos   return 1;
460*1dcdf01fSchristos }
461*1dcdf01fSchristos 
ServerNameCallback(SSL * ssl,int * out_alert,void * arg)462*1dcdf01fSchristos static int ServerNameCallback(SSL *ssl, int *out_alert, void *arg) {
463*1dcdf01fSchristos   // SNI must be accessible from the SNI callback.
464*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
465*1dcdf01fSchristos   const char *server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
466*1dcdf01fSchristos   if (server_name == nullptr ||
467*1dcdf01fSchristos       std::string(server_name) != config->expected_server_name) {
468*1dcdf01fSchristos     fprintf(stderr, "servername mismatch (got %s; want %s)\n", server_name,
469*1dcdf01fSchristos             config->expected_server_name.c_str());
470*1dcdf01fSchristos     return SSL_TLSEXT_ERR_ALERT_FATAL;
471*1dcdf01fSchristos   }
472*1dcdf01fSchristos 
473*1dcdf01fSchristos   return SSL_TLSEXT_ERR_OK;
474*1dcdf01fSchristos }
475*1dcdf01fSchristos 
476*1dcdf01fSchristos // Connect returns a new socket connected to localhost on |port| or -1 on
477*1dcdf01fSchristos // error.
Connect(uint16_t port)478*1dcdf01fSchristos static int Connect(uint16_t port) {
479*1dcdf01fSchristos   int sock = socket(AF_INET, SOCK_STREAM, 0);
480*1dcdf01fSchristos   if (sock == -1) {
481*1dcdf01fSchristos     PrintSocketError("socket");
482*1dcdf01fSchristos     return -1;
483*1dcdf01fSchristos   }
484*1dcdf01fSchristos   int nodelay = 1;
485*1dcdf01fSchristos   if (setsockopt(sock, IPPROTO_TCP, TCP_NODELAY,
486*1dcdf01fSchristos           reinterpret_cast<const char*>(&nodelay), sizeof(nodelay)) != 0) {
487*1dcdf01fSchristos     PrintSocketError("setsockopt");
488*1dcdf01fSchristos     closesocket(sock);
489*1dcdf01fSchristos     return -1;
490*1dcdf01fSchristos   }
491*1dcdf01fSchristos   sockaddr_in sin;
492*1dcdf01fSchristos   memset(&sin, 0, sizeof(sin));
493*1dcdf01fSchristos   sin.sin_family = AF_INET;
494*1dcdf01fSchristos   sin.sin_port = htons(port);
495*1dcdf01fSchristos   if (!inet_pton(AF_INET, "127.0.0.1", &sin.sin_addr)) {
496*1dcdf01fSchristos     PrintSocketError("inet_pton");
497*1dcdf01fSchristos     closesocket(sock);
498*1dcdf01fSchristos     return -1;
499*1dcdf01fSchristos   }
500*1dcdf01fSchristos   if (connect(sock, reinterpret_cast<const sockaddr*>(&sin),
501*1dcdf01fSchristos               sizeof(sin)) != 0) {
502*1dcdf01fSchristos     PrintSocketError("connect");
503*1dcdf01fSchristos     closesocket(sock);
504*1dcdf01fSchristos     return -1;
505*1dcdf01fSchristos   }
506*1dcdf01fSchristos   return sock;
507*1dcdf01fSchristos }
508*1dcdf01fSchristos 
509*1dcdf01fSchristos class SocketCloser {
510*1dcdf01fSchristos  public:
SocketCloser(int sock)511*1dcdf01fSchristos   explicit SocketCloser(int sock) : sock_(sock) {}
~SocketCloser()512*1dcdf01fSchristos   ~SocketCloser() {
513*1dcdf01fSchristos     // Half-close and drain the socket before releasing it. This seems to be
514*1dcdf01fSchristos     // necessary for graceful shutdown on Windows. It will also avoid write
515*1dcdf01fSchristos     // failures in the test runner.
516*1dcdf01fSchristos #if defined(OPENSSL_SYS_WINDOWS)
517*1dcdf01fSchristos     shutdown(sock_, SD_SEND);
518*1dcdf01fSchristos #else
519*1dcdf01fSchristos     shutdown(sock_, SHUT_WR);
520*1dcdf01fSchristos #endif
521*1dcdf01fSchristos     while (true) {
522*1dcdf01fSchristos       char buf[1024];
523*1dcdf01fSchristos       if (recv(sock_, buf, sizeof(buf), 0) <= 0) {
524*1dcdf01fSchristos         break;
525*1dcdf01fSchristos       }
526*1dcdf01fSchristos     }
527*1dcdf01fSchristos     closesocket(sock_);
528*1dcdf01fSchristos   }
529*1dcdf01fSchristos 
530*1dcdf01fSchristos  private:
531*1dcdf01fSchristos   const int sock_;
532*1dcdf01fSchristos };
533*1dcdf01fSchristos 
SetupCtx(const TestConfig * config)534*1dcdf01fSchristos static bssl::UniquePtr<SSL_CTX> SetupCtx(const TestConfig *config) {
535*1dcdf01fSchristos   const char sess_id_ctx[] = "ossl_shim";
536*1dcdf01fSchristos   bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(
537*1dcdf01fSchristos       config->is_dtls ? DTLS_method() : TLS_method()));
538*1dcdf01fSchristos   if (!ssl_ctx) {
539*1dcdf01fSchristos     return nullptr;
540*1dcdf01fSchristos   }
541*1dcdf01fSchristos 
542*1dcdf01fSchristos   SSL_CTX_set_security_level(ssl_ctx.get(), 0);
543*1dcdf01fSchristos #if 0
544*1dcdf01fSchristos   /* Disabled for now until we have some TLS1.3 support */
545*1dcdf01fSchristos   // Enable TLS 1.3 for tests.
546*1dcdf01fSchristos   if (!config->is_dtls &&
547*1dcdf01fSchristos       !SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION)) {
548*1dcdf01fSchristos     return nullptr;
549*1dcdf01fSchristos   }
550*1dcdf01fSchristos #else
551*1dcdf01fSchristos   /* Ensure we don't negotiate TLSv1.3 until we can handle it */
552*1dcdf01fSchristos   if (!config->is_dtls &&
553*1dcdf01fSchristos       !SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_2_VERSION)) {
554*1dcdf01fSchristos     return nullptr;
555*1dcdf01fSchristos   }
556*1dcdf01fSchristos #endif
557*1dcdf01fSchristos 
558*1dcdf01fSchristos   std::string cipher_list = "ALL";
559*1dcdf01fSchristos   if (!config->cipher.empty()) {
560*1dcdf01fSchristos     cipher_list = config->cipher;
561*1dcdf01fSchristos     SSL_CTX_set_options(ssl_ctx.get(), SSL_OP_CIPHER_SERVER_PREFERENCE);
562*1dcdf01fSchristos   }
563*1dcdf01fSchristos   if (!SSL_CTX_set_cipher_list(ssl_ctx.get(), cipher_list.c_str())) {
564*1dcdf01fSchristos     return nullptr;
565*1dcdf01fSchristos   }
566*1dcdf01fSchristos 
567*1dcdf01fSchristos   DH *tmpdh;
568*1dcdf01fSchristos 
569*1dcdf01fSchristos   if (config->use_sparse_dh_prime) {
570*1dcdf01fSchristos     BIGNUM *p, *g;
571*1dcdf01fSchristos     p = BN_new();
572*1dcdf01fSchristos     g = BN_new();
573*1dcdf01fSchristos     tmpdh = DH_new();
574*1dcdf01fSchristos     if (p == NULL || g == NULL || tmpdh == NULL) {
575*1dcdf01fSchristos         BN_free(p);
576*1dcdf01fSchristos         BN_free(g);
577*1dcdf01fSchristos         DH_free(tmpdh);
578*1dcdf01fSchristos         return nullptr;
579*1dcdf01fSchristos     }
580*1dcdf01fSchristos     // This prime number is 2^1024 + 643 – a value just above a power of two.
581*1dcdf01fSchristos     // Because of its form, values modulo it are essentially certain to be one
582*1dcdf01fSchristos     // byte shorter. This is used to test padding of these values.
583*1dcdf01fSchristos     if (BN_hex2bn(
584*1dcdf01fSchristos             &p,
585*1dcdf01fSchristos             "1000000000000000000000000000000000000000000000000000000000000000"
586*1dcdf01fSchristos             "0000000000000000000000000000000000000000000000000000000000000000"
587*1dcdf01fSchristos             "0000000000000000000000000000000000000000000000000000000000000000"
588*1dcdf01fSchristos             "0000000000000000000000000000000000000000000000000000000000000028"
589*1dcdf01fSchristos             "3") == 0 ||
590*1dcdf01fSchristos         !BN_set_word(g, 2)) {
591*1dcdf01fSchristos       BN_free(p);
592*1dcdf01fSchristos       BN_free(g);
593*1dcdf01fSchristos       DH_free(tmpdh);
594*1dcdf01fSchristos       return nullptr;
595*1dcdf01fSchristos     }
596*1dcdf01fSchristos     DH_set0_pqg(tmpdh, p, NULL, g);
597*1dcdf01fSchristos   } else {
598*1dcdf01fSchristos       tmpdh = DH_get_2048_256();
599*1dcdf01fSchristos   }
600*1dcdf01fSchristos 
601*1dcdf01fSchristos   bssl::UniquePtr<DH> dh(tmpdh);
602*1dcdf01fSchristos 
603*1dcdf01fSchristos   if (!dh || !SSL_CTX_set_tmp_dh(ssl_ctx.get(), dh.get())) {
604*1dcdf01fSchristos     return nullptr;
605*1dcdf01fSchristos   }
606*1dcdf01fSchristos 
607*1dcdf01fSchristos   SSL_CTX_set_session_cache_mode(ssl_ctx.get(), SSL_SESS_CACHE_BOTH);
608*1dcdf01fSchristos 
609*1dcdf01fSchristos   if (config->use_old_client_cert_callback) {
610*1dcdf01fSchristos     SSL_CTX_set_client_cert_cb(ssl_ctx.get(), ClientCertCallback);
611*1dcdf01fSchristos   }
612*1dcdf01fSchristos 
613*1dcdf01fSchristos   SSL_CTX_set_npn_advertised_cb(
614*1dcdf01fSchristos       ssl_ctx.get(), NextProtosAdvertisedCallback, NULL);
615*1dcdf01fSchristos   if (!config->select_next_proto.empty()) {
616*1dcdf01fSchristos     SSL_CTX_set_next_proto_select_cb(ssl_ctx.get(), NextProtoSelectCallback,
617*1dcdf01fSchristos                                      NULL);
618*1dcdf01fSchristos   }
619*1dcdf01fSchristos 
620*1dcdf01fSchristos   if (!config->select_alpn.empty() || config->decline_alpn) {
621*1dcdf01fSchristos     SSL_CTX_set_alpn_select_cb(ssl_ctx.get(), AlpnSelectCallback, NULL);
622*1dcdf01fSchristos   }
623*1dcdf01fSchristos 
624*1dcdf01fSchristos   SSL_CTX_set_info_callback(ssl_ctx.get(), InfoCallback);
625*1dcdf01fSchristos   SSL_CTX_sess_set_new_cb(ssl_ctx.get(), NewSessionCallback);
626*1dcdf01fSchristos 
627*1dcdf01fSchristos   if (config->use_ticket_callback) {
628*1dcdf01fSchristos     SSL_CTX_set_tlsext_ticket_key_cb(ssl_ctx.get(), TicketKeyCallback);
629*1dcdf01fSchristos   }
630*1dcdf01fSchristos 
631*1dcdf01fSchristos   if (config->enable_client_custom_extension &&
632*1dcdf01fSchristos       !SSL_CTX_add_client_custom_ext(
633*1dcdf01fSchristos           ssl_ctx.get(), kCustomExtensionValue, CustomExtensionAddCallback,
634*1dcdf01fSchristos           CustomExtensionFreeCallback, kCustomExtensionAddArg,
635*1dcdf01fSchristos           CustomExtensionParseCallback, kCustomExtensionParseArg)) {
636*1dcdf01fSchristos     return nullptr;
637*1dcdf01fSchristos   }
638*1dcdf01fSchristos 
639*1dcdf01fSchristos   if (config->enable_server_custom_extension &&
640*1dcdf01fSchristos       !SSL_CTX_add_server_custom_ext(
641*1dcdf01fSchristos           ssl_ctx.get(), kCustomExtensionValue, CustomExtensionAddCallback,
642*1dcdf01fSchristos           CustomExtensionFreeCallback, kCustomExtensionAddArg,
643*1dcdf01fSchristos           CustomExtensionParseCallback, kCustomExtensionParseArg)) {
644*1dcdf01fSchristos     return nullptr;
645*1dcdf01fSchristos   }
646*1dcdf01fSchristos 
647*1dcdf01fSchristos   if (config->verify_fail) {
648*1dcdf01fSchristos     SSL_CTX_set_cert_verify_callback(ssl_ctx.get(), VerifyFail, NULL);
649*1dcdf01fSchristos   } else {
650*1dcdf01fSchristos     SSL_CTX_set_cert_verify_callback(ssl_ctx.get(), VerifySucceed, NULL);
651*1dcdf01fSchristos   }
652*1dcdf01fSchristos 
653*1dcdf01fSchristos   if (config->use_null_client_ca_list) {
654*1dcdf01fSchristos     SSL_CTX_set_client_CA_list(ssl_ctx.get(), nullptr);
655*1dcdf01fSchristos   }
656*1dcdf01fSchristos 
657*1dcdf01fSchristos   if (!SSL_CTX_set_session_id_context(ssl_ctx.get(),
658*1dcdf01fSchristos                                       (const unsigned char *)sess_id_ctx,
659*1dcdf01fSchristos                                       sizeof(sess_id_ctx) - 1))
660*1dcdf01fSchristos     return nullptr;
661*1dcdf01fSchristos 
662*1dcdf01fSchristos   if (!config->expected_server_name.empty()) {
663*1dcdf01fSchristos     SSL_CTX_set_tlsext_servername_callback(ssl_ctx.get(), ServerNameCallback);
664*1dcdf01fSchristos   }
665*1dcdf01fSchristos 
666*1dcdf01fSchristos   return ssl_ctx;
667*1dcdf01fSchristos }
668*1dcdf01fSchristos 
669*1dcdf01fSchristos // RetryAsync is called after a failed operation on |ssl| with return code
670*1dcdf01fSchristos // |ret|. If the operation should be retried, it simulates one asynchronous
671*1dcdf01fSchristos // event and returns true. Otherwise it returns false.
RetryAsync(SSL * ssl,int ret)672*1dcdf01fSchristos static bool RetryAsync(SSL *ssl, int ret) {
673*1dcdf01fSchristos   // No error; don't retry.
674*1dcdf01fSchristos   if (ret >= 0) {
675*1dcdf01fSchristos     return false;
676*1dcdf01fSchristos   }
677*1dcdf01fSchristos 
678*1dcdf01fSchristos   TestState *test_state = GetTestState(ssl);
679*1dcdf01fSchristos   assert(GetTestConfig(ssl)->async);
680*1dcdf01fSchristos 
681*1dcdf01fSchristos   if (test_state->packeted_bio != nullptr &&
682*1dcdf01fSchristos       PacketedBioAdvanceClock(test_state->packeted_bio)) {
683*1dcdf01fSchristos     // The DTLS retransmit logic silently ignores write failures. So the test
684*1dcdf01fSchristos     // may progress, allow writes through synchronously.
685*1dcdf01fSchristos     AsyncBioEnforceWriteQuota(test_state->async_bio, false);
686*1dcdf01fSchristos     int timeout_ret = DTLSv1_handle_timeout(ssl);
687*1dcdf01fSchristos     AsyncBioEnforceWriteQuota(test_state->async_bio, true);
688*1dcdf01fSchristos 
689*1dcdf01fSchristos     if (timeout_ret < 0) {
690*1dcdf01fSchristos       fprintf(stderr, "Error retransmitting.\n");
691*1dcdf01fSchristos       return false;
692*1dcdf01fSchristos     }
693*1dcdf01fSchristos     return true;
694*1dcdf01fSchristos   }
695*1dcdf01fSchristos 
696*1dcdf01fSchristos   // See if we needed to read or write more. If so, allow one byte through on
697*1dcdf01fSchristos   // the appropriate end to maximally stress the state machine.
698*1dcdf01fSchristos   switch (SSL_get_error(ssl, ret)) {
699*1dcdf01fSchristos     case SSL_ERROR_WANT_READ:
700*1dcdf01fSchristos       AsyncBioAllowRead(test_state->async_bio, 1);
701*1dcdf01fSchristos       return true;
702*1dcdf01fSchristos     case SSL_ERROR_WANT_WRITE:
703*1dcdf01fSchristos       AsyncBioAllowWrite(test_state->async_bio, 1);
704*1dcdf01fSchristos       return true;
705*1dcdf01fSchristos     case SSL_ERROR_WANT_X509_LOOKUP:
706*1dcdf01fSchristos       test_state->cert_ready = true;
707*1dcdf01fSchristos       return true;
708*1dcdf01fSchristos     default:
709*1dcdf01fSchristos       return false;
710*1dcdf01fSchristos   }
711*1dcdf01fSchristos }
712*1dcdf01fSchristos 
713*1dcdf01fSchristos // DoRead reads from |ssl|, resolving any asynchronous operations. It returns
714*1dcdf01fSchristos // the result value of the final |SSL_read| call.
DoRead(SSL * ssl,uint8_t * out,size_t max_out)715*1dcdf01fSchristos static int DoRead(SSL *ssl, uint8_t *out, size_t max_out) {
716*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
717*1dcdf01fSchristos   TestState *test_state = GetTestState(ssl);
718*1dcdf01fSchristos   int ret;
719*1dcdf01fSchristos   do {
720*1dcdf01fSchristos     if (config->async) {
721*1dcdf01fSchristos       // The DTLS retransmit logic silently ignores write failures. So the test
722*1dcdf01fSchristos       // may progress, allow writes through synchronously. |SSL_read| may
723*1dcdf01fSchristos       // trigger a retransmit, so disconnect the write quota.
724*1dcdf01fSchristos       AsyncBioEnforceWriteQuota(test_state->async_bio, false);
725*1dcdf01fSchristos     }
726*1dcdf01fSchristos     ret = config->peek_then_read ? SSL_peek(ssl, out, max_out)
727*1dcdf01fSchristos                                  : SSL_read(ssl, out, max_out);
728*1dcdf01fSchristos     if (config->async) {
729*1dcdf01fSchristos       AsyncBioEnforceWriteQuota(test_state->async_bio, true);
730*1dcdf01fSchristos     }
731*1dcdf01fSchristos   } while (config->async && RetryAsync(ssl, ret));
732*1dcdf01fSchristos 
733*1dcdf01fSchristos   if (config->peek_then_read && ret > 0) {
734*1dcdf01fSchristos     std::unique_ptr<uint8_t[]> buf(new uint8_t[static_cast<size_t>(ret)]);
735*1dcdf01fSchristos 
736*1dcdf01fSchristos     // SSL_peek should synchronously return the same data.
737*1dcdf01fSchristos     int ret2 = SSL_peek(ssl, buf.get(), ret);
738*1dcdf01fSchristos     if (ret2 != ret ||
739*1dcdf01fSchristos         memcmp(buf.get(), out, ret) != 0) {
740*1dcdf01fSchristos       fprintf(stderr, "First and second SSL_peek did not match.\n");
741*1dcdf01fSchristos       return -1;
742*1dcdf01fSchristos     }
743*1dcdf01fSchristos 
744*1dcdf01fSchristos     // SSL_read should synchronously return the same data and consume it.
745*1dcdf01fSchristos     ret2 = SSL_read(ssl, buf.get(), ret);
746*1dcdf01fSchristos     if (ret2 != ret ||
747*1dcdf01fSchristos         memcmp(buf.get(), out, ret) != 0) {
748*1dcdf01fSchristos       fprintf(stderr, "SSL_peek and SSL_read did not match.\n");
749*1dcdf01fSchristos       return -1;
750*1dcdf01fSchristos     }
751*1dcdf01fSchristos   }
752*1dcdf01fSchristos 
753*1dcdf01fSchristos   return ret;
754*1dcdf01fSchristos }
755*1dcdf01fSchristos 
756*1dcdf01fSchristos // WriteAll writes |in_len| bytes from |in| to |ssl|, resolving any asynchronous
757*1dcdf01fSchristos // operations. It returns the result of the final |SSL_write| call.
WriteAll(SSL * ssl,const uint8_t * in,size_t in_len)758*1dcdf01fSchristos static int WriteAll(SSL *ssl, const uint8_t *in, size_t in_len) {
759*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
760*1dcdf01fSchristos   int ret;
761*1dcdf01fSchristos   do {
762*1dcdf01fSchristos     ret = SSL_write(ssl, in, in_len);
763*1dcdf01fSchristos     if (ret > 0) {
764*1dcdf01fSchristos       in += ret;
765*1dcdf01fSchristos       in_len -= ret;
766*1dcdf01fSchristos     }
767*1dcdf01fSchristos   } while ((config->async && RetryAsync(ssl, ret)) || (ret > 0 && in_len > 0));
768*1dcdf01fSchristos   return ret;
769*1dcdf01fSchristos }
770*1dcdf01fSchristos 
771*1dcdf01fSchristos // DoShutdown calls |SSL_shutdown|, resolving any asynchronous operations. It
772*1dcdf01fSchristos // returns the result of the final |SSL_shutdown| call.
DoShutdown(SSL * ssl)773*1dcdf01fSchristos static int DoShutdown(SSL *ssl) {
774*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
775*1dcdf01fSchristos   int ret;
776*1dcdf01fSchristos   do {
777*1dcdf01fSchristos     ret = SSL_shutdown(ssl);
778*1dcdf01fSchristos   } while (config->async && RetryAsync(ssl, ret));
779*1dcdf01fSchristos   return ret;
780*1dcdf01fSchristos }
781*1dcdf01fSchristos 
GetProtocolVersion(const SSL * ssl)782*1dcdf01fSchristos static uint16_t GetProtocolVersion(const SSL *ssl) {
783*1dcdf01fSchristos   uint16_t version = SSL_version(ssl);
784*1dcdf01fSchristos   if (!SSL_is_dtls(ssl)) {
785*1dcdf01fSchristos     return version;
786*1dcdf01fSchristos   }
787*1dcdf01fSchristos   return 0x0201 + ~version;
788*1dcdf01fSchristos }
789*1dcdf01fSchristos 
790*1dcdf01fSchristos // CheckHandshakeProperties checks, immediately after |ssl| completes its
791*1dcdf01fSchristos // initial handshake (or False Starts), whether all the properties are
792*1dcdf01fSchristos // consistent with the test configuration and invariants.
CheckHandshakeProperties(SSL * ssl,bool is_resume)793*1dcdf01fSchristos static bool CheckHandshakeProperties(SSL *ssl, bool is_resume) {
794*1dcdf01fSchristos   const TestConfig *config = GetTestConfig(ssl);
795*1dcdf01fSchristos 
796*1dcdf01fSchristos   if (SSL_get_current_cipher(ssl) == nullptr) {
797*1dcdf01fSchristos     fprintf(stderr, "null cipher after handshake\n");
798*1dcdf01fSchristos     return false;
799*1dcdf01fSchristos   }
800*1dcdf01fSchristos 
801*1dcdf01fSchristos   if (is_resume &&
802*1dcdf01fSchristos       (!!SSL_session_reused(ssl) == config->expect_session_miss)) {
803*1dcdf01fSchristos     fprintf(stderr, "session was%s reused\n",
804*1dcdf01fSchristos             SSL_session_reused(ssl) ? "" : " not");
805*1dcdf01fSchristos     return false;
806*1dcdf01fSchristos   }
807*1dcdf01fSchristos 
808*1dcdf01fSchristos   if (!GetTestState(ssl)->handshake_done) {
809*1dcdf01fSchristos     fprintf(stderr, "handshake was not completed\n");
810*1dcdf01fSchristos     return false;
811*1dcdf01fSchristos   }
812*1dcdf01fSchristos 
813*1dcdf01fSchristos   if (!config->is_server) {
814*1dcdf01fSchristos     bool expect_new_session =
815*1dcdf01fSchristos         !config->expect_no_session &&
816*1dcdf01fSchristos         (!SSL_session_reused(ssl) || config->expect_ticket_renewal) &&
817*1dcdf01fSchristos         // Session tickets are sent post-handshake in TLS 1.3.
818*1dcdf01fSchristos         GetProtocolVersion(ssl) < TLS1_3_VERSION;
819*1dcdf01fSchristos     if (expect_new_session != GetTestState(ssl)->got_new_session) {
820*1dcdf01fSchristos       fprintf(stderr,
821*1dcdf01fSchristos               "new session was%s cached, but we expected the opposite\n",
822*1dcdf01fSchristos               GetTestState(ssl)->got_new_session ? "" : " not");
823*1dcdf01fSchristos       return false;
824*1dcdf01fSchristos     }
825*1dcdf01fSchristos   }
826*1dcdf01fSchristos 
827*1dcdf01fSchristos   if (!config->expected_server_name.empty()) {
828*1dcdf01fSchristos     const char *server_name =
829*1dcdf01fSchristos         SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
830*1dcdf01fSchristos     if (server_name == nullptr ||
831*1dcdf01fSchristos             std::string(server_name) != config->expected_server_name) {
832*1dcdf01fSchristos       fprintf(stderr, "servername mismatch (got %s; want %s)\n",
833*1dcdf01fSchristos               server_name, config->expected_server_name.c_str());
834*1dcdf01fSchristos       return false;
835*1dcdf01fSchristos     }
836*1dcdf01fSchristos   }
837*1dcdf01fSchristos 
838*1dcdf01fSchristos   if (!config->expected_next_proto.empty()) {
839*1dcdf01fSchristos     const uint8_t *next_proto;
840*1dcdf01fSchristos     unsigned next_proto_len;
841*1dcdf01fSchristos     SSL_get0_next_proto_negotiated(ssl, &next_proto, &next_proto_len);
842*1dcdf01fSchristos     if (next_proto_len != config->expected_next_proto.size() ||
843*1dcdf01fSchristos         memcmp(next_proto, config->expected_next_proto.data(),
844*1dcdf01fSchristos                next_proto_len) != 0) {
845*1dcdf01fSchristos       fprintf(stderr, "negotiated next proto mismatch\n");
846*1dcdf01fSchristos       return false;
847*1dcdf01fSchristos     }
848*1dcdf01fSchristos   }
849*1dcdf01fSchristos 
850*1dcdf01fSchristos   if (!config->expected_alpn.empty()) {
851*1dcdf01fSchristos     const uint8_t *alpn_proto;
852*1dcdf01fSchristos     unsigned alpn_proto_len;
853*1dcdf01fSchristos     SSL_get0_alpn_selected(ssl, &alpn_proto, &alpn_proto_len);
854*1dcdf01fSchristos     if (alpn_proto_len != config->expected_alpn.size() ||
855*1dcdf01fSchristos         memcmp(alpn_proto, config->expected_alpn.data(),
856*1dcdf01fSchristos                alpn_proto_len) != 0) {
857*1dcdf01fSchristos       fprintf(stderr, "negotiated alpn proto mismatch\n");
858*1dcdf01fSchristos       return false;
859*1dcdf01fSchristos     }
860*1dcdf01fSchristos   }
861*1dcdf01fSchristos 
862*1dcdf01fSchristos   if (config->expect_extended_master_secret) {
863*1dcdf01fSchristos     if (!SSL_get_extms_support(ssl)) {
864*1dcdf01fSchristos       fprintf(stderr, "No EMS for connection when expected");
865*1dcdf01fSchristos       return false;
866*1dcdf01fSchristos     }
867*1dcdf01fSchristos   }
868*1dcdf01fSchristos 
869*1dcdf01fSchristos   if (config->expect_verify_result) {
870*1dcdf01fSchristos     int expected_verify_result = config->verify_fail ?
871*1dcdf01fSchristos       X509_V_ERR_APPLICATION_VERIFICATION :
872*1dcdf01fSchristos       X509_V_OK;
873*1dcdf01fSchristos 
874*1dcdf01fSchristos     if (SSL_get_verify_result(ssl) != expected_verify_result) {
875*1dcdf01fSchristos       fprintf(stderr, "Wrong certificate verification result\n");
876*1dcdf01fSchristos       return false;
877*1dcdf01fSchristos     }
878*1dcdf01fSchristos   }
879*1dcdf01fSchristos 
880*1dcdf01fSchristos   if (!config->psk.empty()) {
881*1dcdf01fSchristos     if (SSL_get_peer_cert_chain(ssl) != nullptr) {
882*1dcdf01fSchristos       fprintf(stderr, "Received peer certificate on a PSK cipher.\n");
883*1dcdf01fSchristos       return false;
884*1dcdf01fSchristos     }
885*1dcdf01fSchristos   } else if (!config->is_server || config->require_any_client_certificate) {
886*1dcdf01fSchristos     if (SSL_get_peer_certificate(ssl) == nullptr) {
887*1dcdf01fSchristos       fprintf(stderr, "Received no peer certificate but expected one.\n");
888*1dcdf01fSchristos       return false;
889*1dcdf01fSchristos     }
890*1dcdf01fSchristos   }
891*1dcdf01fSchristos 
892*1dcdf01fSchristos   return true;
893*1dcdf01fSchristos }
894*1dcdf01fSchristos 
895*1dcdf01fSchristos // DoExchange runs a test SSL exchange against the peer. On success, it returns
896*1dcdf01fSchristos // true and sets |*out_session| to the negotiated SSL session. If the test is a
897*1dcdf01fSchristos // resumption attempt, |is_resume| is true and |session| is the session from the
898*1dcdf01fSchristos // previous exchange.
DoExchange(bssl::UniquePtr<SSL_SESSION> * out_session,SSL_CTX * ssl_ctx,const TestConfig * config,bool is_resume,SSL_SESSION * session)899*1dcdf01fSchristos static bool DoExchange(bssl::UniquePtr<SSL_SESSION> *out_session,
900*1dcdf01fSchristos                        SSL_CTX *ssl_ctx, const TestConfig *config,
901*1dcdf01fSchristos                        bool is_resume, SSL_SESSION *session) {
902*1dcdf01fSchristos   bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx));
903*1dcdf01fSchristos   if (!ssl) {
904*1dcdf01fSchristos     return false;
905*1dcdf01fSchristos   }
906*1dcdf01fSchristos 
907*1dcdf01fSchristos   if (!SetTestConfig(ssl.get(), config) ||
908*1dcdf01fSchristos       !SetTestState(ssl.get(), std::unique_ptr<TestState>(new TestState))) {
909*1dcdf01fSchristos     return false;
910*1dcdf01fSchristos   }
911*1dcdf01fSchristos 
912*1dcdf01fSchristos   if (config->fallback_scsv &&
913*1dcdf01fSchristos       !SSL_set_mode(ssl.get(), SSL_MODE_SEND_FALLBACK_SCSV)) {
914*1dcdf01fSchristos     return false;
915*1dcdf01fSchristos   }
916*1dcdf01fSchristos   // Install the certificate synchronously if nothing else will handle it.
917*1dcdf01fSchristos   if (!config->use_old_client_cert_callback &&
918*1dcdf01fSchristos       !config->async &&
919*1dcdf01fSchristos       !InstallCertificate(ssl.get())) {
920*1dcdf01fSchristos     return false;
921*1dcdf01fSchristos   }
922*1dcdf01fSchristos   SSL_set_cert_cb(ssl.get(), CertCallback, nullptr);
923*1dcdf01fSchristos   if (config->require_any_client_certificate) {
924*1dcdf01fSchristos     SSL_set_verify(ssl.get(), SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
925*1dcdf01fSchristos                    NULL);
926*1dcdf01fSchristos   }
927*1dcdf01fSchristos   if (config->verify_peer) {
928*1dcdf01fSchristos     SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, NULL);
929*1dcdf01fSchristos   }
930*1dcdf01fSchristos   if (config->partial_write) {
931*1dcdf01fSchristos     SSL_set_mode(ssl.get(), SSL_MODE_ENABLE_PARTIAL_WRITE);
932*1dcdf01fSchristos   }
933*1dcdf01fSchristos   if (config->no_tls13) {
934*1dcdf01fSchristos     SSL_set_options(ssl.get(), SSL_OP_NO_TLSv1_3);
935*1dcdf01fSchristos   }
936*1dcdf01fSchristos   if (config->no_tls12) {
937*1dcdf01fSchristos     SSL_set_options(ssl.get(), SSL_OP_NO_TLSv1_2);
938*1dcdf01fSchristos   }
939*1dcdf01fSchristos   if (config->no_tls11) {
940*1dcdf01fSchristos     SSL_set_options(ssl.get(), SSL_OP_NO_TLSv1_1);
941*1dcdf01fSchristos   }
942*1dcdf01fSchristos   if (config->no_tls1) {
943*1dcdf01fSchristos     SSL_set_options(ssl.get(), SSL_OP_NO_TLSv1);
944*1dcdf01fSchristos   }
945*1dcdf01fSchristos   if (config->no_ssl3) {
946*1dcdf01fSchristos     SSL_set_options(ssl.get(), SSL_OP_NO_SSLv3);
947*1dcdf01fSchristos   }
948*1dcdf01fSchristos   if (!config->host_name.empty() &&
949*1dcdf01fSchristos       !SSL_set_tlsext_host_name(ssl.get(), config->host_name.c_str())) {
950*1dcdf01fSchristos     return false;
951*1dcdf01fSchristos   }
952*1dcdf01fSchristos   if (!config->advertise_alpn.empty() &&
953*1dcdf01fSchristos       SSL_set_alpn_protos(ssl.get(),
954*1dcdf01fSchristos                           (const uint8_t *)config->advertise_alpn.data(),
955*1dcdf01fSchristos                           config->advertise_alpn.size()) != 0) {
956*1dcdf01fSchristos     return false;
957*1dcdf01fSchristos   }
958*1dcdf01fSchristos   if (!config->psk.empty()) {
959*1dcdf01fSchristos     SSL_set_psk_client_callback(ssl.get(), PskClientCallback);
960*1dcdf01fSchristos     SSL_set_psk_server_callback(ssl.get(), PskServerCallback);
961*1dcdf01fSchristos   }
962*1dcdf01fSchristos   if (!config->psk_identity.empty() &&
963*1dcdf01fSchristos       !SSL_use_psk_identity_hint(ssl.get(), config->psk_identity.c_str())) {
964*1dcdf01fSchristos     return false;
965*1dcdf01fSchristos   }
966*1dcdf01fSchristos   if (!config->srtp_profiles.empty() &&
967*1dcdf01fSchristos       SSL_set_tlsext_use_srtp(ssl.get(), config->srtp_profiles.c_str())) {
968*1dcdf01fSchristos     return false;
969*1dcdf01fSchristos   }
970*1dcdf01fSchristos   if (config->min_version != 0 &&
971*1dcdf01fSchristos       !SSL_set_min_proto_version(ssl.get(), (uint16_t)config->min_version)) {
972*1dcdf01fSchristos     return false;
973*1dcdf01fSchristos   }
974*1dcdf01fSchristos   if (config->max_version != 0 &&
975*1dcdf01fSchristos       !SSL_set_max_proto_version(ssl.get(), (uint16_t)config->max_version)) {
976*1dcdf01fSchristos     return false;
977*1dcdf01fSchristos   }
978*1dcdf01fSchristos   if (config->mtu != 0) {
979*1dcdf01fSchristos     SSL_set_options(ssl.get(), SSL_OP_NO_QUERY_MTU);
980*1dcdf01fSchristos     SSL_set_mtu(ssl.get(), config->mtu);
981*1dcdf01fSchristos   }
982*1dcdf01fSchristos   if (config->renegotiate_freely) {
983*1dcdf01fSchristos     // This is always on for OpenSSL.
984*1dcdf01fSchristos   }
985*1dcdf01fSchristos   if (!config->check_close_notify) {
986*1dcdf01fSchristos     SSL_set_quiet_shutdown(ssl.get(), 1);
987*1dcdf01fSchristos   }
988*1dcdf01fSchristos   if (config->p384_only) {
989*1dcdf01fSchristos     int nid = NID_secp384r1;
990*1dcdf01fSchristos     if (!SSL_set1_curves(ssl.get(), &nid, 1)) {
991*1dcdf01fSchristos       return false;
992*1dcdf01fSchristos     }
993*1dcdf01fSchristos   }
994*1dcdf01fSchristos   if (config->enable_all_curves) {
995*1dcdf01fSchristos     static const int kAllCurves[] = {
996*1dcdf01fSchristos       NID_X25519, NID_X9_62_prime256v1, NID_X448, NID_secp521r1, NID_secp384r1
997*1dcdf01fSchristos     };
998*1dcdf01fSchristos     if (!SSL_set1_curves(ssl.get(), kAllCurves,
999*1dcdf01fSchristos                          OPENSSL_ARRAY_SIZE(kAllCurves))) {
1000*1dcdf01fSchristos       return false;
1001*1dcdf01fSchristos     }
1002*1dcdf01fSchristos   }
1003*1dcdf01fSchristos   if (config->max_cert_list > 0) {
1004*1dcdf01fSchristos     SSL_set_max_cert_list(ssl.get(), config->max_cert_list);
1005*1dcdf01fSchristos   }
1006*1dcdf01fSchristos 
1007*1dcdf01fSchristos   if (!config->async) {
1008*1dcdf01fSchristos     SSL_set_mode(ssl.get(), SSL_MODE_AUTO_RETRY);
1009*1dcdf01fSchristos   }
1010*1dcdf01fSchristos 
1011*1dcdf01fSchristos   int sock = Connect(config->port);
1012*1dcdf01fSchristos   if (sock == -1) {
1013*1dcdf01fSchristos     return false;
1014*1dcdf01fSchristos   }
1015*1dcdf01fSchristos   SocketCloser closer(sock);
1016*1dcdf01fSchristos 
1017*1dcdf01fSchristos   bssl::UniquePtr<BIO> bio(BIO_new_socket(sock, BIO_NOCLOSE));
1018*1dcdf01fSchristos   if (!bio) {
1019*1dcdf01fSchristos     return false;
1020*1dcdf01fSchristos   }
1021*1dcdf01fSchristos   if (config->is_dtls) {
1022*1dcdf01fSchristos     bssl::UniquePtr<BIO> packeted = PacketedBioCreate(!config->async);
1023*1dcdf01fSchristos     if (!packeted) {
1024*1dcdf01fSchristos       return false;
1025*1dcdf01fSchristos     }
1026*1dcdf01fSchristos     GetTestState(ssl.get())->packeted_bio = packeted.get();
1027*1dcdf01fSchristos     BIO_push(packeted.get(), bio.release());
1028*1dcdf01fSchristos     bio = std::move(packeted);
1029*1dcdf01fSchristos   }
1030*1dcdf01fSchristos   if (config->async) {
1031*1dcdf01fSchristos     bssl::UniquePtr<BIO> async_scoped =
1032*1dcdf01fSchristos         config->is_dtls ? AsyncBioCreateDatagram() : AsyncBioCreate();
1033*1dcdf01fSchristos     if (!async_scoped) {
1034*1dcdf01fSchristos       return false;
1035*1dcdf01fSchristos     }
1036*1dcdf01fSchristos     BIO_push(async_scoped.get(), bio.release());
1037*1dcdf01fSchristos     GetTestState(ssl.get())->async_bio = async_scoped.get();
1038*1dcdf01fSchristos     bio = std::move(async_scoped);
1039*1dcdf01fSchristos   }
1040*1dcdf01fSchristos   SSL_set_bio(ssl.get(), bio.get(), bio.get());
1041*1dcdf01fSchristos   bio.release();  // SSL_set_bio takes ownership.
1042*1dcdf01fSchristos 
1043*1dcdf01fSchristos   if (session != NULL) {
1044*1dcdf01fSchristos     if (!config->is_server) {
1045*1dcdf01fSchristos       if (SSL_set_session(ssl.get(), session) != 1) {
1046*1dcdf01fSchristos         return false;
1047*1dcdf01fSchristos       }
1048*1dcdf01fSchristos     }
1049*1dcdf01fSchristos   }
1050*1dcdf01fSchristos 
1051*1dcdf01fSchristos #if 0
1052*1dcdf01fSchristos   // KNOWN BUG: OpenSSL's SSL_get_current_cipher behaves incorrectly when
1053*1dcdf01fSchristos   // offering resumption.
1054*1dcdf01fSchristos   if (SSL_get_current_cipher(ssl.get()) != nullptr) {
1055*1dcdf01fSchristos     fprintf(stderr, "non-null cipher before handshake\n");
1056*1dcdf01fSchristos     return false;
1057*1dcdf01fSchristos   }
1058*1dcdf01fSchristos #endif
1059*1dcdf01fSchristos 
1060*1dcdf01fSchristos   int ret;
1061*1dcdf01fSchristos   if (config->implicit_handshake) {
1062*1dcdf01fSchristos     if (config->is_server) {
1063*1dcdf01fSchristos       SSL_set_accept_state(ssl.get());
1064*1dcdf01fSchristos     } else {
1065*1dcdf01fSchristos       SSL_set_connect_state(ssl.get());
1066*1dcdf01fSchristos     }
1067*1dcdf01fSchristos   } else {
1068*1dcdf01fSchristos     do {
1069*1dcdf01fSchristos       if (config->is_server) {
1070*1dcdf01fSchristos         ret = SSL_accept(ssl.get());
1071*1dcdf01fSchristos       } else {
1072*1dcdf01fSchristos         ret = SSL_connect(ssl.get());
1073*1dcdf01fSchristos       }
1074*1dcdf01fSchristos     } while (config->async && RetryAsync(ssl.get(), ret));
1075*1dcdf01fSchristos     if (ret != 1 ||
1076*1dcdf01fSchristos         !CheckHandshakeProperties(ssl.get(), is_resume)) {
1077*1dcdf01fSchristos       return false;
1078*1dcdf01fSchristos     }
1079*1dcdf01fSchristos 
1080*1dcdf01fSchristos     // Reset the state to assert later that the callback isn't called in
1081*1dcdf01fSchristos     // renegotiations.
1082*1dcdf01fSchristos     GetTestState(ssl.get())->got_new_session = false;
1083*1dcdf01fSchristos   }
1084*1dcdf01fSchristos 
1085*1dcdf01fSchristos   if (config->export_keying_material > 0) {
1086*1dcdf01fSchristos     std::vector<uint8_t> result(
1087*1dcdf01fSchristos         static_cast<size_t>(config->export_keying_material));
1088*1dcdf01fSchristos     if (SSL_export_keying_material(
1089*1dcdf01fSchristos             ssl.get(), result.data(), result.size(),
1090*1dcdf01fSchristos             config->export_label.data(), config->export_label.size(),
1091*1dcdf01fSchristos             reinterpret_cast<const uint8_t*>(config->export_context.data()),
1092*1dcdf01fSchristos             config->export_context.size(), config->use_export_context) != 1) {
1093*1dcdf01fSchristos       fprintf(stderr, "failed to export keying material\n");
1094*1dcdf01fSchristos       return false;
1095*1dcdf01fSchristos     }
1096*1dcdf01fSchristos     if (WriteAll(ssl.get(), result.data(), result.size()) < 0) {
1097*1dcdf01fSchristos       return false;
1098*1dcdf01fSchristos     }
1099*1dcdf01fSchristos   }
1100*1dcdf01fSchristos 
1101*1dcdf01fSchristos   if (config->write_different_record_sizes) {
1102*1dcdf01fSchristos     if (config->is_dtls) {
1103*1dcdf01fSchristos       fprintf(stderr, "write_different_record_sizes not supported for DTLS\n");
1104*1dcdf01fSchristos       return false;
1105*1dcdf01fSchristos     }
1106*1dcdf01fSchristos     // This mode writes a number of different record sizes in an attempt to
1107*1dcdf01fSchristos     // trip up the CBC record splitting code.
1108*1dcdf01fSchristos     static const size_t kBufLen = 32769;
1109*1dcdf01fSchristos     std::unique_ptr<uint8_t[]> buf(new uint8_t[kBufLen]);
1110*1dcdf01fSchristos     memset(buf.get(), 0x42, kBufLen);
1111*1dcdf01fSchristos     static const size_t kRecordSizes[] = {
1112*1dcdf01fSchristos         0, 1, 255, 256, 257, 16383, 16384, 16385, 32767, 32768, 32769};
1113*1dcdf01fSchristos     for (size_t i = 0; i < OPENSSL_ARRAY_SIZE(kRecordSizes); i++) {
1114*1dcdf01fSchristos       const size_t len = kRecordSizes[i];
1115*1dcdf01fSchristos       if (len > kBufLen) {
1116*1dcdf01fSchristos         fprintf(stderr, "Bad kRecordSizes value.\n");
1117*1dcdf01fSchristos         return false;
1118*1dcdf01fSchristos       }
1119*1dcdf01fSchristos       if (WriteAll(ssl.get(), buf.get(), len) < 0) {
1120*1dcdf01fSchristos         return false;
1121*1dcdf01fSchristos       }
1122*1dcdf01fSchristos     }
1123*1dcdf01fSchristos   } else {
1124*1dcdf01fSchristos     if (config->shim_writes_first) {
1125*1dcdf01fSchristos       if (WriteAll(ssl.get(), reinterpret_cast<const uint8_t *>("hello"),
1126*1dcdf01fSchristos                    5) < 0) {
1127*1dcdf01fSchristos         return false;
1128*1dcdf01fSchristos       }
1129*1dcdf01fSchristos     }
1130*1dcdf01fSchristos     if (!config->shim_shuts_down) {
1131*1dcdf01fSchristos       for (;;) {
1132*1dcdf01fSchristos         static const size_t kBufLen = 16384;
1133*1dcdf01fSchristos         std::unique_ptr<uint8_t[]> buf(new uint8_t[kBufLen]);
1134*1dcdf01fSchristos 
1135*1dcdf01fSchristos         // Read only 512 bytes at a time in TLS to ensure records may be
1136*1dcdf01fSchristos         // returned in multiple reads.
1137*1dcdf01fSchristos         int n = DoRead(ssl.get(), buf.get(), config->is_dtls ? kBufLen : 512);
1138*1dcdf01fSchristos         int err = SSL_get_error(ssl.get(), n);
1139*1dcdf01fSchristos         if (err == SSL_ERROR_ZERO_RETURN ||
1140*1dcdf01fSchristos             (n == 0 && err == SSL_ERROR_SYSCALL)) {
1141*1dcdf01fSchristos           if (n != 0) {
1142*1dcdf01fSchristos             fprintf(stderr, "Invalid SSL_get_error output\n");
1143*1dcdf01fSchristos             return false;
1144*1dcdf01fSchristos           }
1145*1dcdf01fSchristos           // Stop on either clean or unclean shutdown.
1146*1dcdf01fSchristos           break;
1147*1dcdf01fSchristos         } else if (err != SSL_ERROR_NONE) {
1148*1dcdf01fSchristos           if (n > 0) {
1149*1dcdf01fSchristos             fprintf(stderr, "Invalid SSL_get_error output\n");
1150*1dcdf01fSchristos             return false;
1151*1dcdf01fSchristos           }
1152*1dcdf01fSchristos           return false;
1153*1dcdf01fSchristos         }
1154*1dcdf01fSchristos         // Successfully read data.
1155*1dcdf01fSchristos         if (n <= 0) {
1156*1dcdf01fSchristos           fprintf(stderr, "Invalid SSL_get_error output\n");
1157*1dcdf01fSchristos           return false;
1158*1dcdf01fSchristos         }
1159*1dcdf01fSchristos 
1160*1dcdf01fSchristos         // After a successful read, with or without False Start, the handshake
1161*1dcdf01fSchristos         // must be complete.
1162*1dcdf01fSchristos         if (!GetTestState(ssl.get())->handshake_done) {
1163*1dcdf01fSchristos           fprintf(stderr, "handshake was not completed after SSL_read\n");
1164*1dcdf01fSchristos           return false;
1165*1dcdf01fSchristos         }
1166*1dcdf01fSchristos 
1167*1dcdf01fSchristos         for (int i = 0; i < n; i++) {
1168*1dcdf01fSchristos           buf[i] ^= 0xff;
1169*1dcdf01fSchristos         }
1170*1dcdf01fSchristos         if (WriteAll(ssl.get(), buf.get(), n) < 0) {
1171*1dcdf01fSchristos           return false;
1172*1dcdf01fSchristos         }
1173*1dcdf01fSchristos       }
1174*1dcdf01fSchristos     }
1175*1dcdf01fSchristos   }
1176*1dcdf01fSchristos 
1177*1dcdf01fSchristos   if (!config->is_server &&
1178*1dcdf01fSchristos       !config->implicit_handshake &&
1179*1dcdf01fSchristos       // Session tickets are sent post-handshake in TLS 1.3.
1180*1dcdf01fSchristos       GetProtocolVersion(ssl.get()) < TLS1_3_VERSION &&
1181*1dcdf01fSchristos       GetTestState(ssl.get())->got_new_session) {
1182*1dcdf01fSchristos     fprintf(stderr, "new session was established after the handshake\n");
1183*1dcdf01fSchristos     return false;
1184*1dcdf01fSchristos   }
1185*1dcdf01fSchristos 
1186*1dcdf01fSchristos   if (GetProtocolVersion(ssl.get()) >= TLS1_3_VERSION && !config->is_server) {
1187*1dcdf01fSchristos     bool expect_new_session =
1188*1dcdf01fSchristos         !config->expect_no_session && !config->shim_shuts_down;
1189*1dcdf01fSchristos     if (expect_new_session != GetTestState(ssl.get())->got_new_session) {
1190*1dcdf01fSchristos       fprintf(stderr,
1191*1dcdf01fSchristos               "new session was%s cached, but we expected the opposite\n",
1192*1dcdf01fSchristos               GetTestState(ssl.get())->got_new_session ? "" : " not");
1193*1dcdf01fSchristos       return false;
1194*1dcdf01fSchristos     }
1195*1dcdf01fSchristos   }
1196*1dcdf01fSchristos 
1197*1dcdf01fSchristos   if (out_session) {
1198*1dcdf01fSchristos     *out_session = std::move(GetTestState(ssl.get())->new_session);
1199*1dcdf01fSchristos   }
1200*1dcdf01fSchristos 
1201*1dcdf01fSchristos   ret = DoShutdown(ssl.get());
1202*1dcdf01fSchristos 
1203*1dcdf01fSchristos   if (config->shim_shuts_down && config->check_close_notify) {
1204*1dcdf01fSchristos     // We initiate shutdown, so |SSL_shutdown| will return in two stages. First
1205*1dcdf01fSchristos     // it returns zero when our close_notify is sent, then one when the peer's
1206*1dcdf01fSchristos     // is received.
1207*1dcdf01fSchristos     if (ret != 0) {
1208*1dcdf01fSchristos       fprintf(stderr, "Unexpected SSL_shutdown result: %d != 0\n", ret);
1209*1dcdf01fSchristos       return false;
1210*1dcdf01fSchristos     }
1211*1dcdf01fSchristos     ret = DoShutdown(ssl.get());
1212*1dcdf01fSchristos   }
1213*1dcdf01fSchristos 
1214*1dcdf01fSchristos   if (ret != 1) {
1215*1dcdf01fSchristos     fprintf(stderr, "Unexpected SSL_shutdown result: %d != 1\n", ret);
1216*1dcdf01fSchristos     return false;
1217*1dcdf01fSchristos   }
1218*1dcdf01fSchristos 
1219*1dcdf01fSchristos   if (SSL_total_renegotiations(ssl.get()) !=
1220*1dcdf01fSchristos       config->expect_total_renegotiations) {
1221*1dcdf01fSchristos     fprintf(stderr, "Expected %d renegotiations, got %ld\n",
1222*1dcdf01fSchristos             config->expect_total_renegotiations,
1223*1dcdf01fSchristos             SSL_total_renegotiations(ssl.get()));
1224*1dcdf01fSchristos     return false;
1225*1dcdf01fSchristos   }
1226*1dcdf01fSchristos 
1227*1dcdf01fSchristos   return true;
1228*1dcdf01fSchristos }
1229*1dcdf01fSchristos 
// Scope guard that writes a fixed end-of-output marker to stderr when it is
// destroyed. Main() creates one first, so the marker is the last thing the
// shim itself prints; any stderr text after it (e.g. from a sanitizer) can be
// told apart from the shim's own diagnostics.
class StderrDelimiter {
 public:
  ~StderrDelimiter() {
    static const char kDoneMarker[] = "--- DONE ---\n";
    fputs(kDoneMarker, stderr);
  }
};
1234*1dcdf01fSchristos 
Main(int argc,char ** argv)1235*1dcdf01fSchristos static int Main(int argc, char **argv) {
1236*1dcdf01fSchristos   // To distinguish ASan's output from ours, add a trailing message to stderr.
1237*1dcdf01fSchristos   // Anything following this line will be considered an error.
1238*1dcdf01fSchristos   StderrDelimiter delimiter;
1239*1dcdf01fSchristos 
1240*1dcdf01fSchristos #if defined(OPENSSL_SYS_WINDOWS)
1241*1dcdf01fSchristos   /* Initialize Winsock. */
1242*1dcdf01fSchristos   WORD wsa_version = MAKEWORD(2, 2);
1243*1dcdf01fSchristos   WSADATA wsa_data;
1244*1dcdf01fSchristos   int wsa_err = WSAStartup(wsa_version, &wsa_data);
1245*1dcdf01fSchristos   if (wsa_err != 0) {
1246*1dcdf01fSchristos     fprintf(stderr, "WSAStartup failed: %d\n", wsa_err);
1247*1dcdf01fSchristos     return 1;
1248*1dcdf01fSchristos   }
1249*1dcdf01fSchristos   if (wsa_data.wVersion != wsa_version) {
1250*1dcdf01fSchristos     fprintf(stderr, "Didn't get expected version: %x\n", wsa_data.wVersion);
1251*1dcdf01fSchristos     return 1;
1252*1dcdf01fSchristos   }
1253*1dcdf01fSchristos #else
1254*1dcdf01fSchristos   signal(SIGPIPE, SIG_IGN);
1255*1dcdf01fSchristos #endif
1256*1dcdf01fSchristos 
1257*1dcdf01fSchristos   OPENSSL_init_crypto(0, NULL);
1258*1dcdf01fSchristos   OPENSSL_init_ssl(0, NULL);
1259*1dcdf01fSchristos   g_config_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
1260*1dcdf01fSchristos   g_state_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, TestStateExFree);
1261*1dcdf01fSchristos   if (g_config_index < 0 || g_state_index < 0) {
1262*1dcdf01fSchristos     return 1;
1263*1dcdf01fSchristos   }
1264*1dcdf01fSchristos 
1265*1dcdf01fSchristos   TestConfig config;
1266*1dcdf01fSchristos   if (!ParseConfig(argc - 1, argv + 1, &config)) {
1267*1dcdf01fSchristos     return Usage(argv[0]);
1268*1dcdf01fSchristos   }
1269*1dcdf01fSchristos 
1270*1dcdf01fSchristos   bssl::UniquePtr<SSL_CTX> ssl_ctx = SetupCtx(&config);
1271*1dcdf01fSchristos   if (!ssl_ctx) {
1272*1dcdf01fSchristos     ERR_print_errors_fp(stderr);
1273*1dcdf01fSchristos     return 1;
1274*1dcdf01fSchristos   }
1275*1dcdf01fSchristos 
1276*1dcdf01fSchristos   bssl::UniquePtr<SSL_SESSION> session;
1277*1dcdf01fSchristos   for (int i = 0; i < config.resume_count + 1; i++) {
1278*1dcdf01fSchristos     bool is_resume = i > 0;
1279*1dcdf01fSchristos     if (is_resume && !config.is_server && !session) {
1280*1dcdf01fSchristos       fprintf(stderr, "No session to offer.\n");
1281*1dcdf01fSchristos       return 1;
1282*1dcdf01fSchristos     }
1283*1dcdf01fSchristos 
1284*1dcdf01fSchristos     bssl::UniquePtr<SSL_SESSION> offer_session = std::move(session);
1285*1dcdf01fSchristos     if (!DoExchange(&session, ssl_ctx.get(), &config, is_resume,
1286*1dcdf01fSchristos                     offer_session.get())) {
1287*1dcdf01fSchristos       fprintf(stderr, "Connection %d failed.\n", i + 1);
1288*1dcdf01fSchristos       ERR_print_errors_fp(stderr);
1289*1dcdf01fSchristos       return 1;
1290*1dcdf01fSchristos     }
1291*1dcdf01fSchristos   }
1292*1dcdf01fSchristos 
1293*1dcdf01fSchristos   return 0;
1294*1dcdf01fSchristos }
1295*1dcdf01fSchristos 
1296*1dcdf01fSchristos }  // namespace bssl
1297*1dcdf01fSchristos 
main(int argc,char ** argv)1298*1dcdf01fSchristos int main(int argc, char **argv) {
1299*1dcdf01fSchristos   return bssl::Main(argc, argv);
1300*1dcdf01fSchristos }
1301