1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <folly/io/async/test/AsyncSSLSocketTest.h>
18 
19 #include <fcntl.h>
20 #include <signal.h>
21 #include <sys/types.h>
22 
23 #include <fstream>
24 #include <iostream>
25 #include <list>
26 #include <set>
27 #include <thread>
28 
29 #include <folly/SocketAddress.h>
30 #include <folly/String.h>
31 #include <folly/io/Cursor.h>
32 #include <folly/io/async/AsyncPipe.h>
33 #include <folly/io/async/AsyncSSLSocket.h>
34 #include <folly/io/async/EventBase.h>
35 #include <folly/io/async/EventBaseThread.h>
36 #include <folly/io/async/ScopedEventBaseThread.h>
37 #include <folly/io/async/ssl/OpenSSLTransportCertificate.h>
38 #include <folly/io/async/test/BlockingSocket.h>
39 #include <folly/io/async/test/MockAsyncTransportObserver.h>
40 #include <folly/io/async/test/TFOTest.h>
41 #include <folly/net/NetOps.h>
42 #include <folly/net/NetworkSocket.h>
43 #include <folly/net/test/MockNetOpsDispatcher.h>
44 #include <folly/portability/GMock.h>
45 #include <folly/portability/GTest.h>
46 #include <folly/portability/OpenSSL.h>
47 #include <folly/portability/Unistd.h>
48 #include <folly/ssl/Init.h>
49 
50 #ifdef __linux__
51 #include <dlfcn.h>
52 #endif
53 
54 #if FOLLY_OPENSSL_IS_110
55 #include <openssl/async.h>
56 #endif
57 
58 using std::cerr;
59 using std::endl;
60 using std::string;
61 
62 using namespace testing;
63 
64 #if defined __linux__
65 namespace {
66 
67 // to store libc's original setsockopt()
68 typedef int (*setsockopt_ptr)(int, int, int, const void*, socklen_t);
69 setsockopt_ptr real_setsockopt_ = nullptr;
70 
71 // global struct to initialize before main runs. we can init within a test,
72 // or in main, but this method seems to be least intrsive and universal
73 struct GlobalStatic {
GlobalStatic__anon28101f9d0111::GlobalStatic74   GlobalStatic() {
75     real_setsockopt_ = (setsockopt_ptr)dlsym(RTLD_NEXT, "setsockopt");
76   }
reset__anon28101f9d0111::GlobalStatic77   void reset() noexcept { ttlsDisabledSet.clear(); }
78   // for each fd, tracks whether TTLS is disabled or not
79   std::unordered_set<folly::NetworkSocket /* fd */> ttlsDisabledSet;
80 };
81 
82 // the constructor will be called before main() which is all we care about
83 GlobalStatic globalStatic;
84 
85 } // namespace
86 
87 // we intercept setsoctopt to test setting NO_TRANSPARENT_TLS opt
88 // this name has to be global
setsockopt(int sockfd,int level,int optname,const void * optval,socklen_t optlen)89 int setsockopt(
90     int sockfd, int level, int optname, const void* optval, socklen_t optlen) {
91   if (optname == SO_NO_TRANSPARENT_TLS) {
92     globalStatic.ttlsDisabledSet.insert(folly::NetworkSocket::fromFd(sockfd));
93     return 0;
94   }
95   return real_setsockopt_(sockfd, level, optname, optval, optlen);
96 }
97 #endif
98 
99 namespace folly {
// Out-of-line definitions of SSLClient's static constexpr members. They are
// required when these members are ODR-used under pre-C++17 rules; presumably
// the in-class declarations carry the initializers (confirm in the header),
// in which case these are redundant but harmless under C++17.
constexpr size_t SSLClient::kMaxReadBufferSz;
constexpr size_t SSLClient::kMaxReadsPerEvent;
102 
getfds(NetworkSocket fds[2])103 void getfds(NetworkSocket fds[2]) {
104   if (netops::socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
105     FAIL() << "failed to create socketpair: " << errnoStr(errno);
106   }
107   for (int idx = 0; idx < 2; ++idx) {
108     if (netops::set_socket_non_blocking(fds[idx]) != 0) {
109       FAIL() << "failed to put socket " << idx
110              << " in non-blocking mode: " << errnoStr(errno);
111     }
112   }
113 }
114 
getctx(std::shared_ptr<folly::SSLContext> clientCtx,std::shared_ptr<folly::SSLContext> serverCtx)115 void getctx(
116     std::shared_ptr<folly::SSLContext> clientCtx,
117     std::shared_ptr<folly::SSLContext> serverCtx) {
118   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
119 
120   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
121   serverCtx->loadCertificate(kTestCert);
122   serverCtx->loadPrivateKey(kTestKey);
123 }
124 
sslsocketpair(EventBase * eventBase,AsyncSSLSocket::UniquePtr * clientSock,AsyncSSLSocket::UniquePtr * serverSock)125 void sslsocketpair(
126     EventBase* eventBase,
127     AsyncSSLSocket::UniquePtr* clientSock,
128     AsyncSSLSocket::UniquePtr* serverSock) {
129   auto clientCtx = std::make_shared<folly::SSLContext>();
130   auto serverCtx = std::make_shared<folly::SSLContext>();
131   NetworkSocket fds[2];
132   getfds(fds);
133   getctx(clientCtx, serverCtx);
134   clientSock->reset(new AsyncSSLSocket(clientCtx, eventBase, fds[0], false));
135   serverSock->reset(new AsyncSSLSocket(serverCtx, eventBase, fds[1], true));
136 
137   // (*clientSock)->setSendTimeout(100);
138   // (*serverSock)->setSendTimeout(100);
139 }
140 
141 // client protocol filters
clientProtoFilterPickPony(unsigned char ** client,unsigned int * client_len,const unsigned char *,unsigned int)142 bool clientProtoFilterPickPony(
143     unsigned char** client,
144     unsigned int* client_len,
145     const unsigned char*,
146     unsigned int) {
147   // the protocol string in length prefixed byte string. the
148   // length byte is not included in the length
149   static unsigned char p[7] = {6, 'p', 'o', 'n', 'i', 'e', 's'};
150   *client = p;
151   *client_len = 7;
152   return true;
153 }
154 
clientProtoFilterPickNone(unsigned char **,unsigned int *,const unsigned char *,unsigned int)155 bool clientProtoFilterPickNone(
156     unsigned char**, unsigned int*, const unsigned char*, unsigned int) {
157   return false;
158 }
159 
getFileAsBuf(const char * fileName)160 std::string getFileAsBuf(const char* fileName) {
161   std::string buffer;
162   folly::readFile(fileName, buffer);
163   return buffer;
164 }
165 
166 /**
167  * Test connecting to, writing to, reading from, and closing the
168  * connection to the SSL server.
169  */
TEST(AsyncSSLSocketTest,ConnectWriteReadClose)170 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
171   // Start listening on a local port
172   WriteCallbackBase writeCallback;
173   ReadCallback readCallback(&writeCallback);
174   HandshakeCallback handshakeCallback(&readCallback);
175   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
176   TestSSLServer server(&acceptCallback);
177 
178   // Set up SSL context.
179   std::shared_ptr<SSLContext> sslContext(new SSLContext());
180   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
181   // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
182   // sslContext->authenticate(true, false);
183 
184   // connect
185   auto socket =
186       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
187   socket->open(std::chrono::milliseconds(10000));
188 
189   // write()
190   uint8_t buf[128];
191   memset(buf, 'a', sizeof(buf));
192   socket->write(buf, sizeof(buf));
193 
194   // read()
195   uint8_t readbuf[128];
196   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
197   EXPECT_EQ(bytesRead, 128);
198   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
199 
200   // close()
201   socket->close();
202 
203   cerr << "ConnectWriteReadClose test completed" << endl;
204   EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
205 }
206 
TEST(AsyncSSLSocketTest,ConnectWriteReadCloseReadable)207 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseReadable) {
208   // Same as above, but test AsyncSSLSocket::readable along the way
209 
210   // Start listening on a local port
211   WriteCallbackBase writeCallback;
212   ReadCallback readCallback(&writeCallback);
213   HandshakeCallback handshakeCallback(&readCallback);
214   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
215   TestSSLServer server(&acceptCallback);
216 
217   // Set up SSL context.
218   std::shared_ptr<SSLContext> sslContext(new SSLContext());
219   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
220   // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
221   // sslContext->authenticate(true, false);
222 
223   // connect
224   auto socket =
225       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
226   socket->open(std::chrono::milliseconds(10000));
227 
228   // write()
229   uint8_t buf[128];
230   memset(buf, 'a', sizeof(buf));
231   socket->write(buf, sizeof(buf));
232 
233   // read()
234   uint8_t readbuf[128];
235   // The TLS record includes the full 128 bytes.  Even though we only read 1
236   // byte out of the socket, the rest of the full record decrypted and buffered
237   // in the underlying SSL session.
238   uint32_t bytesRead = socket->read(readbuf, 1);
239   EXPECT_EQ(bytesRead, 1);
240   // The socket has no data pending in the kernel
241   EXPECT_FALSE(socket->getSocket()->AsyncSocket::readable());
242   // But the socket is readable
243   EXPECT_TRUE(socket->getSocket()->readable());
244   bytesRead += socket->readAll(readbuf + 1, sizeof(readbuf) - 1);
245   EXPECT_EQ(bytesRead, 128);
246   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
247 
248   // close()
249   socket->close();
250 
251   cerr << "ConnectWriteReadClose test completed" << endl;
252   EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
253 }
254 
255 /**
256  * Check that zero copy options are a noop under AsyncSSLSocket since they
257  * aren't supported.
258  */
TEST(AsyncSSLSocketTest,ZeroCopy)259 TEST(AsyncSSLSocketTest, ZeroCopy) {
260   // Set up SSL context.
261   std::shared_ptr<SSLContext> sslContext(new SSLContext());
262   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
263 
264   auto socket = AsyncSSLSocket::newSocket(sslContext, /*evb=*/nullptr);
265   EXPECT_FALSE(socket->setZeroCopy(true));
266   EXPECT_FALSE(socket->getZeroCopy());
267 }
268 
269 /**
270  * Same as above simple test, but with a large read len to test
271  * clamping behavior.
272  */
TEST(AsyncSSLSocketTest,ConnectWriteReadLargeClose)273 TEST(AsyncSSLSocketTest, ConnectWriteReadLargeClose) {
274   // Start listening on a local port
275   WriteCallbackBase writeCallback;
276   ReadCallback readCallback(&writeCallback);
277   HandshakeCallback handshakeCallback(&readCallback);
278   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
279   TestSSLServer server(&acceptCallback);
280 
281   // Set up SSL context.
282   std::shared_ptr<SSLContext> sslContext(new SSLContext());
283   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
284   // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
285   // sslContext->authenticate(true, false);
286 
287   // connect
288   auto socket =
289       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
290   socket->open(std::chrono::milliseconds(10000));
291 
292   // write()
293   uint8_t buf[128];
294   memset(buf, 'a', sizeof(buf));
295   socket->write(buf, sizeof(buf));
296 
297   // read()
298   uint8_t readbuf[128];
299   // we will fake the read len but that should be fine
300   size_t readLen = 1L << 33;
301   uint32_t bytesRead = socket->read(readbuf, readLen);
302   EXPECT_EQ(bytesRead, 128);
303   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
304 
305   // close()
306   socket->close();
307 
308   cerr << "ConnectWriteReadClose test completed" << endl;
309   EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
310 }
311 
312 /**
313  * Test reading after server close.
314  */
TEST(AsyncSSLSocketTest,ReadAfterClose)315 TEST(AsyncSSLSocketTest, ReadAfterClose) {
316   // Start listening on a local port
317   WriteCallbackBase writeCallback;
318   ReadEOFCallback readCallback(&writeCallback);
319   HandshakeCallback handshakeCallback(&readCallback);
320   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
321   auto server = std::make_unique<TestSSLServer>(&acceptCallback);
322 
323   // Set up SSL context.
324   auto sslContext = std::make_shared<SSLContext>();
325   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
326 
327   auto socket =
328       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
329   socket->open();
330 
331   // This should trigger an EOF on the client.
332   auto evb = handshakeCallback.getSocket()->getEventBase();
333   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
334   std::array<uint8_t, 128> readbuf;
335   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
336   EXPECT_EQ(0, bytesRead);
337 }
338 
339 /**
340  * Test bad renegotiation
341  */
342 #if !defined(OPENSSL_IS_BORINGSSL)
TEST(AsyncSSLSocketTest,Renegotiate)343 TEST(AsyncSSLSocketTest, Renegotiate) {
344   EventBase eventBase;
345   auto clientCtx = std::make_shared<SSLContext>();
346   auto dfServerCtx = std::make_shared<SSLContext>();
347   std::array<NetworkSocket, 2> fds;
348   getfds(fds.data());
349   getctx(clientCtx, dfServerCtx);
350 
351   AsyncSSLSocket::UniquePtr clientSock(
352       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
353   AsyncSSLSocket::UniquePtr serverSock(
354       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
355   SSLHandshakeClient client(std::move(clientSock), true, true);
356   RenegotiatingServer server(std::move(serverSock));
357 
358   while (!client.handshakeSuccess_ && !client.handshakeError_) {
359     eventBase.loopOnce();
360   }
361 
362   ASSERT_TRUE(client.handshakeSuccess_);
363 
364   auto sslSock = std::move(client).moveSocket();
365   sslSock->detachEventBase();
366   // This is nasty, however we don't want to add support for
367   // renegotiation in AsyncSSLSocket.
368   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
369 
370   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
371 
372   std::thread t([&]() { eventBase.loopForever(); });
373 
374   // Trigger the renegotiation.
375   std::array<uint8_t, 128> buf;
376   memset(buf.data(), 'a', buf.size());
377   try {
378     socket->write(buf.data(), buf.size());
379   } catch (AsyncSocketException& e) {
380     LOG(INFO) << "client got error " << e.what();
381   }
382   eventBase.terminateLoopSoon();
383   t.join();
384 
385   eventBase.loop();
386   ASSERT_TRUE(server.renegotiationError_);
387 }
388 #endif
389 
390 /**
391  * Negative test for handshakeError().
392  */
TEST(AsyncSSLSocketTest,HandshakeError)393 TEST(AsyncSSLSocketTest, HandshakeError) {
394   // Start listening on a local port
395   WriteCallbackBase writeCallback;
396   WriteErrorCallback readCallback(&writeCallback);
397   HandshakeCallback handshakeCallback(&readCallback);
398   HandshakeErrorCallback acceptCallback(&handshakeCallback);
399   TestSSLServer server(&acceptCallback);
400 
401   // Set up SSL context.
402   std::shared_ptr<SSLContext> sslContext(new SSLContext());
403   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
404 
405   // connect
406   auto socket =
407       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
408   // read()
409   bool ex = false;
410   try {
411     socket->open();
412 
413     uint8_t readbuf[128];
414     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
415     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
416   } catch (AsyncSocketException&) {
417     ex = true;
418   }
419   EXPECT_TRUE(ex);
420 
421   // close()
422   socket->close();
423   cerr << "HandshakeError test completed" << endl;
424 }
425 
426 /**
427  * Negative test for readError().
428  */
TEST(AsyncSSLSocketTest,ReadError)429 TEST(AsyncSSLSocketTest, ReadError) {
430   // Start listening on a local port
431   WriteCallbackBase writeCallback;
432   ReadErrorCallback readCallback(&writeCallback);
433   HandshakeCallback handshakeCallback(&readCallback);
434   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
435   TestSSLServer server(&acceptCallback);
436 
437   // Set up SSL context.
438   std::shared_ptr<SSLContext> sslContext(new SSLContext());
439   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
440 
441   // connect
442   auto socket =
443       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
444   socket->open();
445 
446   // write something to trigger ssl handshake
447   uint8_t buf[128];
448   memset(buf, 'a', sizeof(buf));
449   socket->write(buf, sizeof(buf));
450 
451   socket->close();
452   cerr << "ReadError test completed" << endl;
453 }
454 
455 /**
456  * Negative test for writeError().
457  */
TEST(AsyncSSLSocketTest,WriteError)458 TEST(AsyncSSLSocketTest, WriteError) {
459   // Start listening on a local port
460   WriteCallbackBase writeCallback;
461   WriteErrorCallback readCallback(&writeCallback);
462   HandshakeCallback handshakeCallback(&readCallback);
463   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
464   TestSSLServer server(&acceptCallback);
465 
466   // Set up SSL context.
467   std::shared_ptr<SSLContext> sslContext(new SSLContext());
468   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
469 
470   // connect
471   auto socket =
472       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
473   socket->open();
474 
475   // write something to trigger ssl handshake
476   uint8_t buf[128];
477   memset(buf, 'a', sizeof(buf));
478   socket->write(buf, sizeof(buf));
479 
480   socket->close();
481   cerr << "WriteError test completed" << endl;
482 }
483 
484 /**
485  * Test a socket with TCP_NODELAY unset.
486  */
TEST(AsyncSSLSocketTest,SocketWithDelay)487 TEST(AsyncSSLSocketTest, SocketWithDelay) {
488   // Start listening on a local port
489   WriteCallbackBase writeCallback;
490   ReadCallback readCallback(&writeCallback);
491   HandshakeCallback handshakeCallback(&readCallback);
492   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
493   TestSSLServer server(&acceptCallback);
494 
495   // Set up SSL context.
496   std::shared_ptr<SSLContext> sslContext(new SSLContext());
497   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
498 
499   // connect
500   auto socket =
501       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
502   socket->open();
503 
504   // write()
505   uint8_t buf[128];
506   memset(buf, 'a', sizeof(buf));
507   socket->write(buf, sizeof(buf));
508 
509   // read()
510   uint8_t readbuf[128];
511   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
512   EXPECT_EQ(bytesRead, 128);
513   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
514 
515   // close()
516   socket->close();
517 
518   cerr << "SocketWithDelay test completed" << endl;
519 }
520 
521 #if FOLLY_OPENSSL_HAS_ALPN
522 class NextProtocolTest : public Test {
523   // For matching protos
524  public:
SetUp()525   void SetUp() override { getctx(clientCtx, serverCtx); }
526 
connect(bool unset=false)527   void connect(bool unset = false) {
528     getfds(fds);
529 
530     if (unset) {
531       // unsetting NPN for any of [client, server] is enough to make NPN not
532       // work
533       clientCtx->unsetNextProtocols();
534     }
535 
536     AsyncSSLSocket::UniquePtr clientSock(
537         new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
538     AsyncSSLSocket::UniquePtr serverSock(
539         new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
540     client = std::make_unique<AlpnClient>(std::move(clientSock));
541     server = std::make_unique<AlpnServer>(std::move(serverSock));
542 
543     eventBase.loop();
544   }
545 
expectProtocol(const std::string & proto)546   void expectProtocol(const std::string& proto) {
547     expectHandshakeSuccess();
548     EXPECT_NE(client->nextProtoLength, 0);
549     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
550     EXPECT_EQ(
551         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
552         0);
553     string selected((const char*)client->nextProto, client->nextProtoLength);
554     EXPECT_EQ(proto, selected);
555   }
556 
expectNoProtocol()557   void expectNoProtocol() {
558     expectHandshakeSuccess();
559     EXPECT_EQ(client->nextProtoLength, 0);
560     EXPECT_EQ(server->nextProtoLength, 0);
561     EXPECT_EQ(client->nextProto, nullptr);
562     EXPECT_EQ(server->nextProto, nullptr);
563   }
564 
expectHandshakeSuccess()565   void expectHandshakeSuccess() {
566     EXPECT_FALSE(client->except.has_value())
567         << "client handshake error: " << client->except->what();
568     EXPECT_FALSE(server->except.has_value())
569         << "server handshake error: " << server->except->what();
570   }
571 
expectHandshakeError()572   void expectHandshakeError() {
573     EXPECT_TRUE(client->except.has_value())
574         << "Expected client handshake error!";
575     EXPECT_TRUE(server->except.has_value())
576         << "Expected server handshake error!";
577   }
578 
579   EventBase eventBase;
580   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
581   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
582   NetworkSocket fds[2];
583   std::unique_ptr<AlpnClient> client;
584   std::unique_ptr<AlpnServer> server;
585 };
586 
TEST_F(NextProtocolTest,AlpnTestOverlap)587 TEST_F(NextProtocolTest, AlpnTestOverlap) {
588   clientCtx->setAdvertisedNextProtocols({"blub", "baz"});
589   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
590 
591   connect();
592 
593   expectProtocol("baz");
594 }
595 
TEST_F(NextProtocolTest,AlpnTestUnset)596 TEST_F(NextProtocolTest, AlpnTestUnset) {
597   // Identical to above test, except that we want unset NPN before
598   // looping.
599   clientCtx->setAdvertisedNextProtocols({"blub", "baz"});
600   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
601 
602   connect(true /* unset */);
603 
604   expectNoProtocol();
605 }
606 
TEST_F(NextProtocolTest,AlpnTestNoOverlap)607 TEST_F(NextProtocolTest, AlpnTestNoOverlap) {
608   clientCtx->setAdvertisedNextProtocols({"blub"});
609   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
610   connect();
611 
612   expectNoProtocol();
613 }
614 
TEST_F(NextProtocolTest,RandomizedAlpnTest)615 TEST_F(NextProtocolTest, RandomizedAlpnTest) {
616   // Probability that this test will fail is 2^-64, which could be considered
617   // as negligible.
618   const int kTries = 64;
619 
620   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
621   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}});
622 
623   std::set<string> selectedProtocols;
624   for (int i = 0; i < kTries; ++i) {
625     connect();
626 
627     EXPECT_NE(client->nextProtoLength, 0);
628     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
629     EXPECT_EQ(
630         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
631         0);
632     string selected((const char*)client->nextProto, client->nextProtoLength);
633     selectedProtocols.insert(selected);
634     expectHandshakeSuccess();
635   }
636   EXPECT_EQ(selectedProtocols.size(), 2);
637 }
638 
TEST_F(NextProtocolTest,AlpnNotAllowMismatchNoClientProtocol)639 TEST_F(NextProtocolTest, AlpnNotAllowMismatchNoClientProtocol) {
640   clientCtx->setAdvertisedNextProtocols({});
641   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
642   serverCtx->setAlpnAllowMismatch(false);
643 
644   connect();
645 
646   expectHandshakeSuccess();
647   expectNoProtocol();
648   EXPECT_EQ(server->getClientAlpns(), std::vector<std::string>({}));
649 }
650 
TEST_F(NextProtocolTest,AlpnNotAllowMismatchWithOverlap)651 TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithOverlap) {
652   clientCtx->setAdvertisedNextProtocols({"blub", "baz"});
653   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
654   serverCtx->setAlpnAllowMismatch(false);
655 
656   connect();
657 
658   expectProtocol("baz");
659   EXPECT_EQ(
660       server->getClientAlpns(), std::vector<std::string>({"blub", "baz"}));
661 }
662 
TEST_F(NextProtocolTest,AlpnNotAllowMismatchWithoutOverlap)663 TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithoutOverlap) {
664   clientCtx->setAdvertisedNextProtocols({"blub"});
665   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
666   serverCtx->setAlpnAllowMismatch(false);
667 
668   connect();
669 
670   expectHandshakeError();
671   EXPECT_EQ(server->getClientAlpns(), std::vector<std::string>({"blub"}));
672 }
673 
674 #endif
675 
676 #ifndef OPENSSL_NO_TLSEXT
677 /**
678  * 1. Client sends TLSEXT_HOSTNAME in client hello.
 * 2. Server finds a matching SSL_CTX and uses it to
 *    continue the SSL handshake.
681  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
682  */
TEST(AsyncSSLSocketTest,SNITestMatch)683 TEST(AsyncSSLSocketTest, SNITestMatch) {
684   EventBase eventBase;
685   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
686   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
687   // Use the same SSLContext to continue the handshake after
688   // tlsext_hostname match.
689   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
690   const std::string serverName("xyz.newdev.facebook.com");
691   NetworkSocket fds[2];
692   getfds(fds);
693   getctx(clientCtx, dfServerCtx);
694 
695   AsyncSSLSocket::UniquePtr clientSock(
696       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
697   AsyncSSLSocket::UniquePtr serverSock(
698       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
699   SNIClient client(std::move(clientSock));
700   SNIServer server(
701       std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
702 
703   eventBase.loop();
704 
705   EXPECT_TRUE(client.serverNameMatch);
706   EXPECT_TRUE(server.serverNameMatch);
707 }
708 
709 /**
710  * 1. Client sends TLSEXT_HOSTNAME in client hello.
 * 2. Server cannot find a matching SSL_CTX and continues to use
 *    the current SSL_CTX for the handshake.
713  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
714  */
TEST(AsyncSSLSocketTest,SNITestNotMatch)715 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
716   EventBase eventBase;
717   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
718   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
719   // Use the same SSLContext to continue the handshake after
720   // tlsext_hostname match.
721   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
722   const std::string clientRequestingServerName("foo.com");
723   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
724 
725   NetworkSocket fds[2];
726   getfds(fds);
727   getctx(clientCtx, dfServerCtx);
728 
729   AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
730       clientCtx, &eventBase, fds[0], clientRequestingServerName));
731   AsyncSSLSocket::UniquePtr serverSock(
732       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
733   SNIClient client(std::move(clientSock));
734   SNIServer server(
735       std::move(serverSock),
736       dfServerCtx,
737       hskServerCtx,
738       serverExpectedServerName);
739 
740   eventBase.loop();
741 
742   EXPECT_TRUE(!client.serverNameMatch);
743   EXPECT_TRUE(!server.serverNameMatch);
744 }
745 /**
746  * 1. Client sends TLSEXT_HOSTNAME in client hello.
747  * 2. We then change the serverName.
 * 3. We expect the client's serverNameMatch to be 'false'.
749  */
750 
TEST(AsyncSSLSocketTest,SNITestChangeServerName)751 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
752   EventBase eventBase;
753   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
754   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
755   // Use the same SSLContext to continue the handshake after
756   // tlsext_hostname match.
757   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
758   const std::string serverName("xyz.newdev.facebook.com");
759   NetworkSocket fds[2];
760   getfds(fds);
761   getctx(clientCtx, dfServerCtx);
762 
763   AsyncSSLSocket::UniquePtr clientSock(
764       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
765   // Change the server name
766   std::string newName("new.com");
767   clientSock->setServerName(newName);
768   AsyncSSLSocket::UniquePtr serverSock(
769       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
770   SNIClient client(std::move(clientSock));
771   SNIServer server(
772       std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
773 
774   eventBase.loop();
775 
776   EXPECT_TRUE(!client.serverNameMatch);
777 }
778 
779 /**
780  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
781  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
782  */
TEST(AsyncSSLSocketTest,SNITestClientHelloNoHostname)783 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
784   EventBase eventBase;
785   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
786   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
787   // Use the same SSLContext to continue the handshake after
788   // tlsext_hostname match.
789   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
790   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
791 
792   NetworkSocket fds[2];
793   getfds(fds);
794   getctx(clientCtx, dfServerCtx);
795 
796   AsyncSSLSocket::UniquePtr clientSock(
797       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
798   AsyncSSLSocket::UniquePtr serverSock(
799       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
800   SNIClient client(std::move(clientSock));
801   SNIServer server(
802       std::move(serverSock),
803       dfServerCtx,
804       hskServerCtx,
805       serverExpectedServerName);
806 
807   eventBase.loop();
808 
809   EXPECT_TRUE(!client.serverNameMatch);
810   EXPECT_TRUE(!server.serverNameMatch);
811 }
812 
813 /**
 * 1. Create SSLContexts that do not have any ALPN protocols configured.
 * 2. Use AsyncSSLSocket::setSupportedApplicationProtocols on the client and
 * server, and assert that both ends report the same negotiated ALPN.
817  */
TEST(AsyncSSLSocketTest,SetSupportedApplicationProtocols)818 TEST(AsyncSSLSocketTest, SetSupportedApplicationProtocols) {
819   EventBase eventBase;
820   auto clientCtx = std::make_shared<SSLContext>();
821   auto dfServerCtx = std::make_shared<SSLContext>();
822   // Use the same SSLContext to continue the handshake after
823   // tlsext_hostname match.
824   auto hskServerCtx = std::make_shared<SSLContext>();
825   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
826 
827   NetworkSocket fds[2];
828   getfds(fds);
829   getctx(clientCtx, dfServerCtx);
830 
831   AsyncSSLSocket::UniquePtr clientSock(
832       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
833   AsyncSSLSocket::UniquePtr serverSock(
834       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
835 
836   std::vector<std::string> protocols;
837   protocols.push_back("rs");
838 
839   clientSock->setSupportedApplicationProtocols(protocols);
840   serverSock->setSupportedApplicationProtocols(protocols);
841 
842   SNIClient client(std::move(clientSock));
843   SNIServer server(
844       std::move(serverSock),
845       dfServerCtx,
846       hskServerCtx,
847       serverExpectedServerName);
848 
849   eventBase.loop();
850 
851   EXPECT_TRUE(
852       client.getApplicationProtocol().compare(
853           server.getApplicationProtocol()) == 0);
854 }
855 
856 #endif
857 /**
858  * Test SSL client socket
859  */
TEST(AsyncSSLSocketTest,SSLClientTest)860 TEST(AsyncSSLSocketTest, SSLClientTest) {
861   // Start listening on a local port
862   WriteCallbackBase writeCallback;
863   ReadCallback readCallback(&writeCallback);
864   HandshakeCallback handshakeCallback(&readCallback);
865   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
866   TestSSLServer server(&acceptCallback);
867 
868   // Set up SSL client
869   EventBase eventBase;
870   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
871   client->setSSLOptions(SSL_OP_NO_TICKET);
872 
873   client->connect();
874   EventBaseAborter eba(&eventBase, 3000);
875   eventBase.loop();
876 
877   EXPECT_EQ(client->getMiss(), 1);
878   EXPECT_EQ(client->getHit(), 0);
879 
880   cerr << "SSLClientTest test completed" << endl;
881 }
882 
883 /**
884  * Test SSL client socket session re-use
885  */
TEST(AsyncSSLSocketTest,SSLClientTestReuse)886 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
887   // Start listening on a local port
888   WriteCallbackBase writeCallback;
889   ReadCallback readCallback(&writeCallback);
890   HandshakeCallback handshakeCallback(&readCallback);
891   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
892   TestSSLServer server(&acceptCallback);
893 
894   // Set up SSL client
895   EventBase eventBase;
896   auto client =
897       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
898   client->setSSLOptions(SSL_OP_NO_TICKET);
899 
900   client->connect();
901   EventBaseAborter eba(&eventBase, 3000);
902   eventBase.loop();
903 
904   EXPECT_EQ(client->getMiss(), 1);
905   EXPECT_EQ(client->getHit(), 9);
906 
907   cerr << "SSLClientTestReuse test completed" << endl;
908 }
909 
910 /**
911  * Test SSL client socket timeout
912  */
TEST(AsyncSSLSocketTest,SSLClientTimeoutTest)913 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
914   // Start listening on a local port
915   EmptyReadCallback readCallback;
916   HandshakeCallback handshakeCallback(
917       &readCallback, HandshakeCallback::EXPECT_ERROR);
918   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
919   TestSSLServer server(&acceptCallback);
920 
921   // Set up SSL client
922   EventBase eventBase;
923   auto client =
924       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
925   client->setSSLOptions(SSL_OP_NO_TICKET);
926 
927   client->connect(true /* write before connect completes */);
928   EventBaseAborter eba(&eventBase, 3000);
929   eventBase.loop();
930 
931   usleep(100000);
932   // This is checking that the connectError callback precedes any queued
933   // writeError callbacks.  This matches AsyncSocket's behavior
934   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
935   EXPECT_EQ(client->getErrors(), 1);
936   EXPECT_EQ(client->getMiss(), 0);
937   EXPECT_EQ(client->getHit(), 0);
938 
939   cerr << "SSLClientTimeoutTest test completed" << endl;
940 }
941 
942 class PerLoopReadCallback : public AsyncTransport::ReadCallback {
943  public:
getReadBuffer(void ** bufReturn,size_t * lenReturn)944   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
945     *bufReturn = buf_.data();
946     *lenReturn = buf_.size();
947   }
948 
readDataAvailable(size_t len)949   void readDataAvailable(size_t len) noexcept override {
950     VLOG(3) << "Read of size: " << len;
951     s_->setReadCB(nullptr);
952     s_->getEventBase()->runInLoop([this]() { s_->setReadCB(this); });
953   }
954 
readErr(const AsyncSocketException &)955   void readErr(const AsyncSocketException&) noexcept override {}
956 
readEOF()957   void readEOF() noexcept override {}
958 
setSocket(AsyncSocket * s)959   void setSocket(AsyncSocket* s) { s_ = s; }
960 
961  private:
962   AsyncSocket* s_;
963   std::array<uint8_t, 1000> buf_;
964 };
965 
966 class CloseNotifyConnector : public AsyncSocket::ConnectCallback {
967  public:
CloseNotifyConnector(EventBase * evb,const SocketAddress & addr)968   CloseNotifyConnector(EventBase* evb, const SocketAddress& addr) {
969     evb_ = evb;
970     ssl_ = AsyncSSLSocket::newSocket(std::make_shared<SSLContext>(), evb_);
971     ssl_->connect(this, addr);
972   }
973 
connectSuccess()974   void connectSuccess() noexcept override {
975     ssl_->writeChain(nullptr, IOBuf::copyBuffer("hi"));
976     auto ssl = const_cast<SSL*>(ssl_->getSSL());
977     SSL_shutdown(ssl);
978     auto fd = ssl_->detachNetworkSocket();
979     tcp_.reset(new AsyncSocket(evb_, fd), AsyncSocket::Destructor());
980     evb_->runAfterDelay(
981         [this]() {
982           perLoopReads_.setSocket(tcp_.get());
983           tcp_->setReadCB(&perLoopReads_);
984           evb_->runAfterDelay([this]() { tcp_->closeNow(); }, 10);
985         },
986         100);
987   }
988 
connectErr(const AsyncSocketException & ex)989   void connectErr(const AsyncSocketException& ex) noexcept override {
990     FAIL() << ex.what();
991   }
992 
993  private:
994   EventBase* evb_;
995   std::shared_ptr<AsyncSSLSocket> ssl_;
996   std::shared_ptr<AsyncSocket> tcp_;
997   PerLoopReadCallback perLoopReads_;
998 };
999 
1000 class ErrorCheckingWriteCallback : public AsyncSocket::WriteCallback {
1001  public:
writeSuccess()1002   void writeSuccess() noexcept override {}
1003 
writeErr(size_t,const AsyncSocketException & ex)1004   void writeErr(size_t, const AsyncSocketException& ex) noexcept override {
1005     LOG(ERROR) << "write error: " << ex.what();
1006     EXPECT_NE(
1007         ex.getType(),
1008         AsyncSocketException::AsyncSocketExceptionType::SSL_ERROR);
1009   }
1010 };
1011 
1012 class WriteOnEofReadCallback : public ReadCallback {
1013  public:
1014   using ReadCallback::ReadCallback;
1015 
readEOF()1016   void readEOF() noexcept override {
1017     LOG(INFO) << "Got EOF";
1018     auto chain = IOBuf::create(0);
1019     for (size_t i = 0; i < 1000 * 1000; i++) {
1020       auto buf = IOBuf::create(10);
1021       buf->append(10);
1022       memset(buf->writableData(), 'x', 10);
1023       chain->prependChain(std::move(buf));
1024     }
1025     socket_->writeChain(&writeCallback_, std::move(chain));
1026   }
1027 
readErr(const AsyncSocketException & ex)1028   void readErr(const AsyncSocketException& ex) noexcept override {
1029     LOG(ERROR) << ex.what();
1030   }
1031 
1032  private:
1033   ErrorCheckingWriteCallback writeCallback_;
1034 };
1035 
// The client (CloseNotifyConnector) sends data plus an early close_notify and
// then keeps the connection open as raw TCP. The server writes a large payload
// when it sees EOF; that write must not fail with an SSL_ERROR.
TEST(AsyncSSLSocketTest, EarlyCloseNotify) {
  WriteOnEofReadCallback serverReadCb(nullptr);
  HandshakeCallback serverHandshakeCb(&serverReadCb);
  SSLServerAcceptCallback serverAcceptCb(&serverHandshakeCb);
  TestSSLServer server(&serverAcceptCb);

  EventBase evb;
  CloseNotifyConnector connector(&evb, server.getAddress());

  evb.loop();
}
1047 
1048 /**
1049  * Verify Client Ciphers obtained using SSL MSG Callback.
1050  */
TEST(AsyncSSLSocketTest,SSLParseClientHelloSuccess)1051 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1052   EventBase eventBase;
1053   auto clientCtx = std::make_shared<SSLContext>();
1054   auto serverCtx = std::make_shared<SSLContext>();
1055   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1056   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1057   serverCtx->loadPrivateKey(kTestKey);
1058   serverCtx->loadCertificate(kTestCert);
1059   serverCtx->loadTrustedCertificates(kTestCA);
1060   serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);
1061 
1062   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1063   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1064   clientCtx->loadPrivateKey(kTestKey);
1065   clientCtx->loadCertificate(kTestCert);
1066   clientCtx->loadTrustedCertificates(kTestCA);
1067 
1068   NetworkSocket fds[2];
1069   getfds(fds);
1070 
1071   AsyncSSLSocket::UniquePtr clientSock(
1072       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1073   AsyncSSLSocket::UniquePtr serverSock(
1074       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1075 
1076   SSLHandshakeClient client(std::move(clientSock), true, true);
1077   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1078 
1079   eventBase.loop();
1080 
1081 #if defined(OPENSSL_IS_BORINGSSL)
1082   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1083 #else
1084   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1085 #endif
1086   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1087   EXPECT_TRUE(client.handshakeVerify_);
1088   EXPECT_TRUE(client.handshakeSuccess_);
1089   EXPECT_TRUE(!client.handshakeError_);
1090   EXPECT_TRUE(server.handshakeVerify_);
1091   EXPECT_TRUE(server.handshakeSuccess_);
1092   EXPECT_TRUE(!server.handshakeError_);
1093 }
1094 
1095 /**
1096  * Verify that server is able to get client cert by getPeerCert() API.
1097  */
TEST(AsyncSSLSocketTest,GetClientCertificate)1098 TEST(AsyncSSLSocketTest, GetClientCertificate) {
1099   EventBase eventBase;
1100   auto clientCtx = std::make_shared<SSLContext>();
1101   auto serverCtx = std::make_shared<SSLContext>();
1102   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1103   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1104   serverCtx->loadPrivateKey(kTestKey);
1105   serverCtx->loadCertificate(kTestCert);
1106   serverCtx->loadTrustedCertificates(kClientTestCA);
1107   serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(kClientTestCA);
1108 
1109   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1110   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1111   clientCtx->loadPrivateKey(kClientTestKey);
1112   clientCtx->loadCertificate(kClientTestCert);
1113   clientCtx->loadTrustedCertificates(kTestCA);
1114 
1115   std::array<NetworkSocket, 2> fds;
1116   getfds(fds.data());
1117 
1118   AsyncSSLSocket::UniquePtr clientSock(
1119       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1120   AsyncSSLSocket::UniquePtr serverSock(
1121       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1122 
1123   SSLHandshakeClient client(std::move(clientSock), true, true);
1124   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1125 
1126   eventBase.loop();
1127 
1128   // Handshake should succeed.
1129   EXPECT_TRUE(client.handshakeSuccess_);
1130   EXPECT_TRUE(server.handshakeSuccess_);
1131 
1132   // Reclaim the sockets from SSLHandshakeBase.
1133   auto cliSocket = std::move(client).moveSocket();
1134   auto srvSocket = std::move(server).moveSocket();
1135 
1136   // Client cert retrieved from server side.
1137   auto serverPeerCert = srvSocket->getPeerCertificate();
1138   CHECK(serverPeerCert);
1139 
1140   // Client cert retrieved from client side.
1141   auto clientSelfCert = cliSocket->getSelfCertificate();
1142   CHECK(clientSelfCert);
1143 
1144   auto serverX509 =
1145       folly::OpenSSLTransportCertificate::tryExtractX509(serverPeerCert);
1146   CHECK(serverX509);
1147 
1148   auto clientX509 =
1149       folly::OpenSSLTransportCertificate::tryExtractX509(clientSelfCert);
1150   CHECK(clientX509);
1151 
1152   // The two certs should be the same.
1153   EXPECT_EQ(0, X509_cmp(clientX509.get(), serverX509.get()));
1154 }
1155 
// Feed a hand-built 42-byte ClientHello handshake message to
// AsyncSSLSocket::clientHelloParsingCallback in a single call. The parsed
// major/minor version must match what was written.
TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
  EventBase eventBase;
  auto ctx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);
  // Only fds[0] is handed to the AsyncSSLSocket below; close the peer end
  // ourselves so the test does not leak a descriptor.
  SCOPE_EXIT { folly::netops::close(fds[1]); };

  const size_t bufLen = 42;
  const uint8_t majorVersion = 18;
  const uint8_t minorVersion = 25;

  // Buffer layout:
  //  - handshake type + 24-bit length (38),
  //  - 2 version bytes,
  //  - 32 unwritten bytes,
  //  - 4 zero bytes (presumably the empty session-id / cipher-suite /
  //    compression-method lengths).
  auto buf = IOBuf::create(bufLen);
  buf->append(bufLen);
  folly::io::RWPrivateCursor cursor(buf.get());
  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
  cursor.write<uint16_t>(0);
  cursor.write<uint8_t>(38);
  cursor.write<uint8_t>(majorVersion);
  cursor.write<uint8_t>(minorVersion);
  cursor.skip(32);
  cursor.write<uint32_t>(0);

  SSL* ssl = ctx->createSSL();
  SCOPE_EXIT { SSL_free(ssl); };
  AsyncSSLSocket::UniquePtr sock(
      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
  sock->enableClientHelloParsing();

  // Test client hello parsing in one packet
  AsyncSSLSocket::clientHelloParsingCallback(
      0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
  buf.reset();

  auto parsedClientHello = sock->getClientHelloInfo();
  // ASSERT so a null result fails cleanly instead of being dereferenced.
  ASSERT_TRUE(parsedClientHello != nullptr);
  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}
1195 
// Feed the same hand-built 42-byte ClientHello to
// AsyncSSLSocket::clientHelloParsingCallback as two fragments. The first
// fragment is only 2 bytes, shorter than the 3-byte minimum. The reassembled
// message must still yield the written major/minor version.
TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
  EventBase eventBase;
  auto ctx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);
  // Only fds[0] is handed to the AsyncSSLSocket below; close the peer end
  // ourselves so the test does not leak a descriptor.
  SCOPE_EXIT { folly::netops::close(fds[1]); };

  const size_t bufLen = 42;
  const uint8_t majorVersion = 18;
  const uint8_t minorVersion = 25;

  // Buffer layout:
  //  - handshake type + 24-bit length (38),
  //  - 2 version bytes,
  //  - 32 unwritten bytes,
  //  - 4 zero bytes (presumably the empty session-id / cipher-suite /
  //    compression-method lengths).
  auto buf = IOBuf::create(bufLen);
  buf->append(bufLen);
  folly::io::RWPrivateCursor cursor(buf.get());
  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
  cursor.write<uint16_t>(0);
  cursor.write<uint8_t>(38);
  cursor.write<uint8_t>(majorVersion);
  cursor.write<uint8_t>(minorVersion);
  cursor.skip(32);
  cursor.write<uint32_t>(0);

  SSL* ssl = ctx->createSSL();
  SCOPE_EXIT { SSL_free(ssl); };
  AsyncSSLSocket::UniquePtr sock(
      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
  sock->enableClientHelloParsing();

  // Test parsing with two packets with first packet size < 3
  auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
  AsyncSSLSocket::clientHelloParsingCallback(
      0,
      0,
      SSL3_RT_HANDSHAKE,
      bufCopy->data(),
      bufCopy->length(),
      ssl,
      sock.get());
  bufCopy.reset();
  bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
  AsyncSSLSocket::clientHelloParsingCallback(
      0,
      0,
      SSL3_RT_HANDSHAKE,
      bufCopy->data(),
      bufCopy->length(),
      ssl,
      sock.get());
  bufCopy.reset();

  auto parsedClientHello = sock->getClientHelloInfo();
  // ASSERT so a null result fails cleanly instead of being dereferenced.
  ASSERT_TRUE(parsedClientHello != nullptr);
  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}
1252 
// Feed the same hand-built 42-byte ClientHello to
// AsyncSSLSocket::clientHelloParsingCallback in 3-byte fragments. The
// reassembled message must still yield the written major/minor version.
TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
  EventBase eventBase;
  auto ctx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);
  // Only fds[0] is handed to the AsyncSSLSocket below; close the peer end
  // ourselves so the test does not leak a descriptor.
  SCOPE_EXIT { folly::netops::close(fds[1]); };

  const size_t bufLen = 42;
  const uint8_t majorVersion = 18;
  const uint8_t minorVersion = 25;

  // Buffer layout:
  //  - handshake type + 24-bit length (38),
  //  - 2 version bytes,
  //  - 32 unwritten bytes,
  //  - 4 zero bytes (presumably the empty session-id / cipher-suite /
  //    compression-method lengths).
  auto buf = IOBuf::create(bufLen);
  buf->append(bufLen);
  folly::io::RWPrivateCursor cursor(buf.get());
  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
  cursor.write<uint16_t>(0);
  cursor.write<uint8_t>(38);
  cursor.write<uint8_t>(majorVersion);
  cursor.write<uint8_t>(minorVersion);
  cursor.skip(32);
  cursor.write<uint32_t>(0);

  SSL* ssl = ctx->createSSL();
  SCOPE_EXIT { SSL_free(ssl); };
  AsyncSSLSocket::UniquePtr sock(
      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
  sock->enableClientHelloParsing();

  // Test parsing with multiple small packets; the last one may be shorter.
  const std::size_t kChunkLen = 3;
  for (std::size_t offset = 0; offset < buf->length(); offset += kChunkLen) {
    auto chunk = folly::IOBuf::copyBuffer(
        buf->data() + offset, std::min(kChunkLen, buf->length() - offset));
    AsyncSSLSocket::clientHelloParsingCallback(
        0,
        0,
        SSL3_RT_HANDSHAKE,
        chunk->data(),
        chunk->length(),
        ssl,
        sock.get());
  }

  auto parsedClientHello = sock->getClientHelloInfo();
  // ASSERT so a null result fails cleanly instead of being dereferenced.
  ASSERT_TRUE(parsedClientHello != nullptr);
  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}
1302 
1303 /**
1304  * Verify sucessful behavior of SSL certificate validation.
1305  */
TEST(AsyncSSLSocketTest,SSLHandshakeValidationSuccess)1306 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1307   EventBase eventBase;
1308   auto clientCtx = std::make_shared<SSLContext>();
1309   auto dfServerCtx = std::make_shared<SSLContext>();
1310 
1311   NetworkSocket fds[2];
1312   getfds(fds);
1313   getctx(clientCtx, dfServerCtx);
1314 
1315   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1316   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1317 
1318   AsyncSSLSocket::UniquePtr clientSock(
1319       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1320   AsyncSSLSocket::UniquePtr serverSock(
1321       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1322 
1323   SSLHandshakeClient client(std::move(clientSock), true, true);
1324   clientCtx->loadTrustedCertificates(kTestCA);
1325 
1326   SSLHandshakeServer server(std::move(serverSock), true, true);
1327 
1328   eventBase.loop();
1329 
1330   EXPECT_TRUE(client.handshakeVerify_);
1331   EXPECT_TRUE(client.handshakeSuccess_);
1332   EXPECT_TRUE(!client.handshakeError_);
1333   EXPECT_LE(0, client.handshakeTime.count());
1334   EXPECT_TRUE(!server.handshakeVerify_);
1335   EXPECT_TRUE(server.handshakeSuccess_);
1336   EXPECT_TRUE(!server.handshakeError_);
1337   EXPECT_LE(0, server.handshakeTime.count());
1338 }
1339 
1340 /**
1341  * Verify that the client's verification callback is able to fail SSL
1342  * connection establishment.
1343  */
TEST(AsyncSSLSocketTest,SSLHandshakeValidationFailure)1344 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1345   EventBase eventBase;
1346   auto clientCtx = std::make_shared<SSLContext>();
1347   auto dfServerCtx = std::make_shared<SSLContext>();
1348 
1349   NetworkSocket fds[2];
1350   getfds(fds);
1351   getctx(clientCtx, dfServerCtx);
1352 
1353   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1354   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1355 
1356   AsyncSSLSocket::UniquePtr clientSock(
1357       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1358   AsyncSSLSocket::UniquePtr serverSock(
1359       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1360 
1361   SSLHandshakeClient client(std::move(clientSock), true, false);
1362   clientCtx->loadTrustedCertificates(kTestCA);
1363 
1364   SSLHandshakeServer server(std::move(serverSock), true, true);
1365 
1366   eventBase.loop();
1367 
1368   EXPECT_TRUE(client.handshakeVerify_);
1369   EXPECT_TRUE(!client.handshakeSuccess_);
1370   EXPECT_TRUE(client.handshakeError_);
1371   EXPECT_LE(0, client.handshakeTime.count());
1372   EXPECT_TRUE(!server.handshakeVerify_);
1373   EXPECT_TRUE(!server.handshakeSuccess_);
1374   EXPECT_TRUE(server.handshakeError_);
1375   EXPECT_LE(0, server.handshakeTime.count());
1376 }
1377 
1378 /**
1379  * Verify that the client successfully handshakes when
1380  * CertificateIdentityVerifier is set and returns with no exception.
1381  */
TEST(AsyncSSLSocketTest,SSLCertificateIdentityVerifierReturns)1382 TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierReturns) {
1383   EventBase eventBase;
1384   auto clientCtx = std::make_shared<folly::SSLContext>();
1385   auto serverCtx = std::make_shared<folly::SSLContext>();
1386   getctx(clientCtx, serverCtx);
1387   // the client socket will default to USE_CTX, so set VERIFY here
1388   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1389   // load root certificate
1390   clientCtx->loadTrustedCertificates(kTestCA);
1391 
1392   // prepare a basic server (callbacks have a few EXPECTS to fullfil)
1393   ReadCallback readCallback(nullptr);
1394   // expects successful handshake
1395   HandshakeCallback handshakeCallback(&readCallback);
1396   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1397   TestSSLServer server(&acceptCallback, serverCtx);
1398 
1399   // return success in the Try with folly::unit
1400   Try<Unit> verifyResult{unit};
1401 
1402   std::shared_ptr<MockCertificateIdentityVerifier> verifier =
1403       std::make_shared<MockCertificateIdentityVerifier>();
1404 
1405   // expecting to only verify once, with the leaf certificate
1406   // (kTestCert)
1407   EXPECT_CALL(
1408       *verifier,
1409       verifyLeafImpl(Property(
1410           &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
1411       .WillOnce(Return(ByMove(verifyResult)));
1412 
1413   AsyncSSLSocket::Options opts;
1414   opts.verifier = std::move(verifier);
1415 
1416   // connect to server and handshake
1417   AsyncSSLSocket::UniquePtr socket(
1418       new AsyncSSLSocket(clientCtx, &eventBase, std::move(opts)));
1419   socket->connect(nullptr, server.getAddress(), 0);
1420 
1421   // write to satisfy server ReadCallback EXPECTs
1422   std::array<uint8_t, 128> buf;
1423   memset(buf.data(), 'a', buf.size());
1424   socket->write(nullptr, buf.data(), buf.size());
1425 
1426   eventBase.loop();
1427 
1428   socket->close();
1429 }
1430 
1431 class TestCertificateIdentityVerifierException
1432     : public CertificateIdentityVerifierException {
1433  public:
TestCertificateIdentityVerifierException(const char * content)1434   explicit TestCertificateIdentityVerifierException(const char* content)
1435       : CertificateIdentityVerifierException(content) {}
1436 };
1437 
1438 /**
1439  * Verify that the client fails to connect during handshake because
1440  * CertificateIdentityVerifier returns a failure while verifying the server's
1441  * leaf certificate.
1442  */
TEST(AsyncSSLSocketTest,SSLCertificateIdentityVerifierFailsToConnect)1443 TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierFailsToConnect) {
1444   EventBase eventBase;
1445   auto clientCtx = std::make_shared<folly::SSLContext>();
1446   auto serverCtx = std::make_shared<folly::SSLContext>();
1447   getctx(clientCtx, serverCtx);
1448   // the client socket will default to USE_CTX, so set VERIFY here
1449   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1450   // load root certificate
1451   clientCtx->loadTrustedCertificates(kTestCA);
1452 
1453   // prepare a basic server (callbacks have a few EXPECTS to fullfil)
1454   ReadCallback readCallback(nullptr);
1455   // expects a failed handshake
1456   HandshakeCallback handshakeCallback(
1457       &readCallback, HandshakeCallback::ExpectType::EXPECT_ERROR);
1458   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1459   TestSSLServer server(&acceptCallback, serverCtx);
1460 
1461   std::shared_ptr<MockCertificateIdentityVerifier> verifier =
1462       std::make_shared<MockCertificateIdentityVerifier>();
1463 
1464   // return a failed result Try
1465   TestCertificateIdentityVerifierException failed{"a failed test reason"};
1466   Try<Unit> result{failed};
1467 
1468   // expecting to only verify once, with the leaf certificate (kTestCert)
1469   EXPECT_CALL(
1470       *verifier,
1471       verifyLeafImpl(Property(
1472           &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
1473       .WillOnce(Return(ByMove(result)));
1474 
1475   AsyncSSLSocket::Options opts;
1476   opts.verifier = std::move(verifier);
1477 
1478   // connect to server and handshake
1479   AsyncSSLSocket::UniquePtr socket(
1480       new AsyncSSLSocket(clientCtx, &eventBase, std::move(opts)));
1481   socket->connect(nullptr, server.getAddress(), 0);
1482 
1483   eventBase.loop();
1484 
1485   socket->close();
1486 }
1487 
1488 /**
1489  * Verify that the client's CertificateIdentityVerifier is not invoked if
1490  * OpenSSL's verification fails. (With no HandshakeCB.)
1491  */
TEST(AsyncSSLSocketTest,SSLCertificateIdentityVerifierNotInvokedX509Failure)1492 TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierNotInvokedX509Failure) {
1493   EventBase eventBase;
1494   auto clientCtx = std::make_shared<folly::SSLContext>();
1495   auto serverCtx = std::make_shared<folly::SSLContext>();
1496   getctx(clientCtx, serverCtx);
1497   // the client socket will default to USE_CTX, so set VERIFY here
1498   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1499   // DO NOT load root certificate, so that server certificate is rejected
1500 
1501   // prepare a basic server (callbacks have a few EXPECTS to fullfil)
1502   ReadCallback readCallback(nullptr);
1503   // expects successful handshake
1504   HandshakeCallback handshakeCallback(
1505       &readCallback, HandshakeCallback::ExpectType::EXPECT_ERROR);
1506   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1507   TestSSLServer server(&acceptCallback, serverCtx);
1508 
1509   // should not get called
1510   std::shared_ptr<StrictMock<MockCertificateIdentityVerifier>> verifier =
1511       std::make_shared<StrictMock<MockCertificateIdentityVerifier>>();
1512 
1513   AsyncSSLSocket::Options opts;
1514   opts.verifier = std::move(verifier);
1515 
1516   // connect to server and handshake
1517   AsyncSSLSocket::UniquePtr socket(
1518       new AsyncSSLSocket(clientCtx, &eventBase, std::move(opts)));
1519   socket->connect(nullptr, server.getAddress(), 0);
1520 
1521   eventBase.loop();
1522 
1523   socket->close();
1524 }
1525 
1526 /**
1527  * Verify that the client CertificateIdentityVerifier is not invoked if
1528  * HandshakeCB::handshakeVer verification fails.
1529  */
TEST(AsyncSSLSocketTest,SSLCertificateIdentityVerifierNotInvokedHandshakeCBFailure)1530 TEST(
1531     AsyncSSLSocketTest,
1532     SSLCertificateIdentityVerifierNotInvokedHandshakeCBFailure) {
1533   EventBase eventBase;
1534   auto clientCtx = std::make_shared<folly::SSLContext>();
1535   auto serverCtx = std::make_shared<folly::SSLContext>();
1536   getctx(clientCtx, serverCtx);
1537   // the client socket will default to USE_CTX, so set VERIFY here
1538   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1539   // load root certificate
1540   clientCtx->loadTrustedCertificates(kTestCA);
1541 
1542   NetworkSocket fds[2];
1543   getfds(fds);
1544 
1545   AsyncSocket::UniquePtr rawClient(new AsyncSocket(&eventBase, fds[0]));
1546   AsyncSocket::UniquePtr rawServer(new AsyncSocket(&eventBase, fds[1]));
1547 
1548   // should not be invoked
1549   std::shared_ptr<StrictMock<MockCertificateIdentityVerifier>> verifier =
1550       std::make_shared<StrictMock<MockCertificateIdentityVerifier>>();
1551 
1552   AsyncSSLSocket::Options clientOpts;
1553   clientOpts.verifier = verifier;
1554 
1555   AsyncSSLSocket::Options serverOpts;
1556   serverOpts.isServer = true;
1557 
1558   AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
1559       clientCtx, std::move(rawClient), std::move(clientOpts)));
1560   AsyncSSLSocket::UniquePtr serverSock(new AsyncSSLSocket(
1561       serverCtx, std::move(rawServer), std::move(serverOpts)));
1562 
1563   serverSock->sslAccept(nullptr, std::chrono::milliseconds::zero());
1564 
1565   StrictMock<MockHandshakeCB> clientHandshakeCB;
1566 
1567   // Force the end entity certificate, which normally is successfully verified,
1568   // to be considered as unsuccessful
1569   EXPECT_CALL(clientHandshakeCB, handshakeVerImpl(clientSock.get(), true, _))
1570       .Times(AtLeast(1))
1571       .WillRepeatedly(Invoke([&](auto&&, bool preverifyOk, auto&& ctx) {
1572         auto currentDepth = X509_STORE_CTX_get_error_depth(ctx);
1573         if (currentDepth == 0) {
1574           EXPECT_TRUE(preverifyOk);
1575           return false;
1576         }
1577         return preverifyOk;
1578       }));
1579 
1580   // failure callback to verify handshake failed
1581   EXPECT_CALL(clientHandshakeCB, handshakeErrImpl(clientSock.get(), _));
1582 
1583   clientSock->sslConn(&clientHandshakeCB);
1584 
1585   eventBase.loop();
1586 
1587   clientSock->close();
1588   serverSock->close();
1589 }
1590 
1591 /**
1592  * Verify that the client CertificateIdentityVerifier is invoked on a server
1593  * socket when peer verification is requested.
1594  */
TEST(AsyncSSLSocketTest,SSLCertificateIdentityVerifierSucceedsOnServer)1595 TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierSucceedsOnServer) {
1596   EventBase eventBase;
1597   auto clientCtx = std::make_shared<folly::SSLContext>();
1598   auto serverCtx = std::make_shared<folly::SSLContext>();
1599   getctx(clientCtx, serverCtx);
1600   // the client socket will default to USE_CTX, so set VERIFY here
1601   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1602   // load root certificate
1603   clientCtx->loadTrustedCertificates(kTestCA);
1604   // load identity and key on client, it's the same identity as server just for
1605   // convenience
1606   clientCtx->loadCertificate(kTestCert);
1607   clientCtx->loadPrivateKey(kTestKey);
1608   // instruct server to verify client
1609   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1610   serverCtx->loadTrustedCertificates(kTestCA);
1611 
1612   NetworkSocket fds[2];
1613   getfds(fds);
1614 
1615   AsyncSocket::UniquePtr rawClient(new AsyncSocket(&eventBase, fds[0]));
1616   AsyncSocket::UniquePtr rawServer(new AsyncSocket(&eventBase, fds[1]));
1617 
1618   // client and server verifiers should verify only once each
1619   std::shared_ptr<MockCertificateIdentityVerifier> clientVerifier =
1620       std::make_shared<MockCertificateIdentityVerifier>();
1621   EXPECT_CALL(
1622       *clientVerifier,
1623       verifyLeafImpl(Property(
1624           &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
1625       .WillOnce(Return(Try<Unit>{unit}));
1626   std::shared_ptr<StrictMock<MockCertificateIdentityVerifier>> serverVerifier =
1627       std::make_shared<StrictMock<MockCertificateIdentityVerifier>>();
1628   EXPECT_CALL(
1629       *serverVerifier,
1630       verifyLeafImpl(Property(
1631           &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
1632       .WillOnce(Return(Try<Unit>{unit}));
1633 
1634   AsyncSSLSocket::Options clientOpts;
1635   clientOpts.verifier = clientVerifier;
1636 
1637   AsyncSSLSocket::Options serverOpts;
1638   serverOpts.isServer = true;
1639   serverOpts.verifier = serverVerifier;
1640 
1641   AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
1642       clientCtx, std::move(rawClient), std::move(clientOpts)));
1643   AsyncSSLSocket::UniquePtr serverSock(new AsyncSSLSocket(
1644       serverCtx, std::move(rawServer), std::move(serverOpts)));
1645 
1646   // no HandshakeCBs anywhere
1647   serverSock->sslAccept(nullptr, std::chrono::milliseconds::zero());
1648   clientSock->sslConn(nullptr);
1649 
1650   eventBase.loop();
1651 
1652   clientSock->close();
1653   serverSock->close();
1654 }
1655 
1656 /**
1657  * Verify that the options in SSLContext can be overridden in
1658  * sslConnect/Accept.i.e specifying that no validation should be performed
1659  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1660  * the validation callback.
1661  */
TEST(AsyncSSLSocketTest,OverrideSSLCtxDisableVerify)1662 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1663   EventBase eventBase;
1664   auto clientCtx = std::make_shared<SSLContext>();
1665   auto dfServerCtx = std::make_shared<SSLContext>();
1666 
1667   NetworkSocket fds[2];
1668   getfds(fds);
1669   getctx(clientCtx, dfServerCtx);
1670 
1671   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1672   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1673 
1674   AsyncSSLSocket::UniquePtr clientSock(
1675       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1676   AsyncSSLSocket::UniquePtr serverSock(
1677       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1678 
1679   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1680   clientCtx->loadTrustedCertificates(kTestCA);
1681 
1682   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1683 
1684   eventBase.loop();
1685 
1686   EXPECT_TRUE(!client.handshakeVerify_);
1687   EXPECT_TRUE(client.handshakeSuccess_);
1688   EXPECT_TRUE(!client.handshakeError_);
1689   EXPECT_LE(0, client.handshakeTime.count());
1690   EXPECT_TRUE(!server.handshakeVerify_);
1691   EXPECT_TRUE(server.handshakeSuccess_);
1692   EXPECT_TRUE(!server.handshakeError_);
1693   EXPECT_LE(0, server.handshakeTime.count());
1694 }
1695 
1696 /**
1697  * Verify that the options in SSLContext can be overridden in
1698  * sslConnect/Accept. Enable verification even if context says otherwise.
1699  * Test requireClientCert with client cert
1700  */
TEST(AsyncSSLSocketTest,OverrideSSLCtxEnableVerify)1701 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1702   EventBase eventBase;
1703   auto clientCtx = std::make_shared<SSLContext>();
1704   auto serverCtx = std::make_shared<SSLContext>();
1705   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1706   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1707   serverCtx->loadPrivateKey(kTestKey);
1708   serverCtx->loadCertificate(kTestCert);
1709   serverCtx->loadTrustedCertificates(kTestCA);
1710   serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);
1711 
1712   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1713   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1714   clientCtx->loadPrivateKey(kTestKey);
1715   clientCtx->loadCertificate(kTestCert);
1716   clientCtx->loadTrustedCertificates(kTestCA);
1717 
1718   NetworkSocket fds[2];
1719   getfds(fds);
1720 
1721   AsyncSSLSocket::UniquePtr clientSock(
1722       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1723   AsyncSSLSocket::UniquePtr serverSock(
1724       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1725 
1726   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1727   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1728 
1729   eventBase.loop();
1730 
1731   EXPECT_TRUE(client.handshakeVerify_);
1732   EXPECT_TRUE(client.handshakeSuccess_);
1733   EXPECT_FALSE(client.handshakeError_);
1734   EXPECT_LE(0, client.handshakeTime.count());
1735   EXPECT_TRUE(server.handshakeVerify_);
1736   EXPECT_TRUE(server.handshakeSuccess_);
1737   EXPECT_FALSE(server.handshakeError_);
1738   EXPECT_LE(0, server.handshakeTime.count());
1739 }
1740 
1741 /**
1742  * Verify that the client's verification callback is able to override
1743  * the preverification failure and allow a successful connection.
1744  */
TEST(AsyncSSLSocketTest,SSLHandshakeValidationOverride)1745 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1746   EventBase eventBase;
1747   auto clientCtx = std::make_shared<SSLContext>();
1748   auto dfServerCtx = std::make_shared<SSLContext>();
1749 
1750   NetworkSocket fds[2];
1751   getfds(fds);
1752   getctx(clientCtx, dfServerCtx);
1753 
1754   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1755   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1756 
1757   AsyncSSLSocket::UniquePtr clientSock(
1758       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1759   AsyncSSLSocket::UniquePtr serverSock(
1760       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1761 
1762   SSLHandshakeClient client(std::move(clientSock), false, true);
1763   SSLHandshakeServer server(std::move(serverSock), true, true);
1764 
1765   eventBase.loop();
1766 
1767   EXPECT_TRUE(client.handshakeVerify_);
1768   EXPECT_TRUE(client.handshakeSuccess_);
1769   EXPECT_TRUE(!client.handshakeError_);
1770   EXPECT_LE(0, client.handshakeTime.count());
1771   EXPECT_TRUE(!server.handshakeVerify_);
1772   EXPECT_TRUE(server.handshakeSuccess_);
1773   EXPECT_TRUE(!server.handshakeError_);
1774   EXPECT_LE(0, server.handshakeTime.count());
1775 }
1776 
1777 /**
1778  * Verify that specifying that no validation should be performed allows an
1779  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1780  * callback.
1781  */
TEST(AsyncSSLSocketTest,SSLHandshakeValidationSkip)1782 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1783   EventBase eventBase;
1784   auto clientCtx = std::make_shared<SSLContext>();
1785   auto dfServerCtx = std::make_shared<SSLContext>();
1786 
1787   NetworkSocket fds[2];
1788   getfds(fds);
1789   getctx(clientCtx, dfServerCtx);
1790 
1791   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1792   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1793 
1794   AsyncSSLSocket::UniquePtr clientSock(
1795       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1796   AsyncSSLSocket::UniquePtr serverSock(
1797       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1798 
1799   SSLHandshakeClient client(std::move(clientSock), false, false);
1800   SSLHandshakeServer server(std::move(serverSock), false, false);
1801 
1802   eventBase.loop();
1803 
1804   EXPECT_TRUE(!client.handshakeVerify_);
1805   EXPECT_TRUE(client.handshakeSuccess_);
1806   EXPECT_TRUE(!client.handshakeError_);
1807   EXPECT_LE(0, client.handshakeTime.count());
1808   EXPECT_TRUE(!server.handshakeVerify_);
1809   EXPECT_TRUE(server.handshakeSuccess_);
1810   EXPECT_TRUE(!server.handshakeError_);
1811   EXPECT_LE(0, server.handshakeTime.count());
1812 }
1813 
1814 /**
1815  * Test requireClientCert with client cert
1816  */
TEST(AsyncSSLSocketTest,ClientCertHandshakeSuccess)1817 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1818   EventBase eventBase;
1819   auto clientCtx = std::make_shared<SSLContext>();
1820   auto serverCtx = std::make_shared<SSLContext>();
1821   serverCtx->setVerificationOption(
1822       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1823   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1824   serverCtx->loadPrivateKey(kTestKey);
1825   serverCtx->loadCertificate(kTestCert);
1826   serverCtx->loadTrustedCertificates(kTestCA);
1827   serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);
1828 
1829   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1830   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1831   clientCtx->loadPrivateKey(kTestKey);
1832   clientCtx->loadCertificate(kTestCert);
1833   clientCtx->loadTrustedCertificates(kTestCA);
1834 
1835   NetworkSocket fds[2];
1836   getfds(fds);
1837 
1838   AsyncSSLSocket::UniquePtr clientSock(
1839       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1840   AsyncSSLSocket::UniquePtr serverSock(
1841       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1842 
1843   SSLHandshakeClient client(std::move(clientSock), true, true);
1844   SSLHandshakeServer server(std::move(serverSock), true, true);
1845 
1846   eventBase.loop();
1847 
1848   EXPECT_TRUE(client.handshakeVerify_);
1849   EXPECT_TRUE(client.handshakeSuccess_);
1850   EXPECT_FALSE(client.handshakeError_);
1851   EXPECT_LE(0, client.handshakeTime.count());
1852   EXPECT_TRUE(server.handshakeVerify_);
1853   EXPECT_TRUE(server.handshakeSuccess_);
1854   EXPECT_FALSE(server.handshakeError_);
1855   EXPECT_LE(0, server.handshakeTime.count());
1856 
1857   // check certificates
1858   auto clientSsl = std::move(client).moveSocket();
1859   auto serverSsl = std::move(server).moveSocket();
1860 
1861   auto clientPeer = clientSsl->getPeerCertificate();
1862   auto clientSelf = clientSsl->getSelfCertificate();
1863   auto serverPeer = serverSsl->getPeerCertificate();
1864   auto serverSelf = serverSsl->getSelfCertificate();
1865 
1866   EXPECT_NE(clientPeer, nullptr);
1867   EXPECT_NE(clientSelf, nullptr);
1868   EXPECT_NE(serverPeer, nullptr);
1869   EXPECT_NE(serverSelf, nullptr);
1870 
1871   EXPECT_EQ(clientPeer->getIdentity(), serverSelf->getIdentity());
1872   EXPECT_EQ(clientSelf->getIdentity(), serverPeer->getIdentity());
1873 }
1874 
1875 /**
1876  * Test requireClientCert with no client cert
1877  */
TEST(AsyncSSLSocketTest,NoClientCertHandshakeError)1878 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1879   EventBase eventBase;
1880   auto clientCtx = std::make_shared<SSLContext>();
1881   auto serverCtx = std::make_shared<SSLContext>();
1882   serverCtx->setVerificationOption(
1883       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1884   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1885   serverCtx->loadPrivateKey(kTestKey);
1886   serverCtx->loadCertificate(kTestCert);
1887   serverCtx->loadTrustedCertificates(kTestCA);
1888   serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);
1889   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1890   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1891 
1892   NetworkSocket fds[2];
1893   getfds(fds);
1894 
1895   AsyncSSLSocket::UniquePtr clientSock(
1896       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1897   AsyncSSLSocket::UniquePtr serverSock(
1898       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1899 
1900   SSLHandshakeClient client(std::move(clientSock), false, false);
1901   SSLHandshakeServer server(std::move(serverSock), false, false);
1902 
1903   eventBase.loop();
1904 
1905   EXPECT_FALSE(server.handshakeVerify_);
1906   EXPECT_FALSE(server.handshakeSuccess_);
1907   EXPECT_TRUE(server.handshakeError_);
1908   EXPECT_LE(0, client.handshakeTime.count());
1909   EXPECT_LE(0, server.handshakeTime.count());
1910 }
1911 
1912 /**
1913  * Test OpenSSL 1.1.0's async functionality
1914  */
1915 #if FOLLY_OPENSSL_IS_110
1916 
makeNonBlockingPipe(int pipefds[2])1917 static void makeNonBlockingPipe(int pipefds[2]) {
1918   if (pipe(pipefds) != 0) {
1919     throw std::runtime_error("Cannot create pipe");
1920   }
1921   if (::fcntl(pipefds[0], F_SETFL, O_NONBLOCK) != 0) {
1922     throw std::runtime_error("Cannot set pipe to nonblocking");
1923   }
1924   if (::fcntl(pipefds[1], F_SETFL, O_NONBLOCK) != 0) {
1925     throw std::runtime_error("Cannot set pipe to nonblocking");
1926   }
1927 }
1928 
// State shared by the custom async RSA private-key hook (customRsaPrivEnc)
// and the tests that install it.
//
// RSA ex-data slot indices. They stay -1 until setupCustomRSA allocates them:
//  - kRSAExIndex: the real RSA key that performs the private-key operation.
//  - kRSAEvbExIndex: the EventBase on which that operation is run.
//  - kRSASocketExIndex: optional AsyncSSLSocket that customRsaPrivEnc closes
//    before scheduling the operation. It is only set by tests that want it.
static int kRSAExIndex = -1;
static int kRSAEvbExIndex = -1;
static int kRSASocketExIndex = -1;
// Key under which the pipe read end is registered as the ASYNC job's wait fd.
static constexpr StringPiece kEngineId = "AsyncSSLSocketTest";
1934 
customRsaPrivEnc(int flen,const unsigned char * from,unsigned char * to,RSA * rsa,int padding)1935 static int customRsaPrivEnc(
1936     int flen,
1937     const unsigned char* from,
1938     unsigned char* to,
1939     RSA* rsa,
1940     int padding) {
1941   LOG(INFO) << "rsa_priv_enc";
1942   EventBase* asyncJobEvb =
1943       reinterpret_cast<EventBase*>(RSA_get_ex_data(rsa, kRSAEvbExIndex));
1944   CHECK(asyncJobEvb);
1945 
1946   RSA* actualRSA = reinterpret_cast<RSA*>(RSA_get_ex_data(rsa, kRSAExIndex));
1947   CHECK(actualRSA);
1948 
1949   AsyncSSLSocket* socket = reinterpret_cast<AsyncSSLSocket*>(
1950       RSA_get_ex_data(rsa, kRSASocketExIndex));
1951 
1952   ASYNC_JOB* job = ASYNC_get_current_job();
1953   if (job == nullptr) {
1954     throw std::runtime_error("Expected call in job context");
1955   }
1956   ASYNC_WAIT_CTX* waitctx = ASYNC_get_wait_ctx(job);
1957   OSSL_ASYNC_FD pipefds[2] = {0, 0};
1958   makeNonBlockingPipe(pipefds);
1959   if (!ASYNC_WAIT_CTX_set_wait_fd(
1960           waitctx, kEngineId.data(), pipefds[0], nullptr, nullptr)) {
1961     throw std::runtime_error("Cannot set wait fd");
1962   }
1963   int ret = 0;
1964   int* retptr = &ret;
1965 
1966   auto hand = folly::NetworkSocket::native_handle_type(pipefds[1]);
1967   auto asyncPipeWriter = folly::AsyncPipeWriter::newWriter(
1968       asyncJobEvb, folly::NetworkSocket(hand));
1969 
1970   if (socket) {
1971     LOG(INFO) << "Got a socket passed in, closing it...";
1972     socket->closeNow();
1973   }
1974   asyncJobEvb->runInEventBaseThread([retptr = retptr,
1975                                      flen = flen,
1976                                      from = from,
1977                                      to = to,
1978                                      padding = padding,
1979                                      actualRSA = actualRSA,
1980                                      writer = std::move(asyncPipeWriter)]() {
1981     LOG(INFO) << "Running job";
1982     *retptr = RSA_meth_get_priv_enc(RSA_PKCS1_OpenSSL())(
1983         flen, from, to, actualRSA, padding);
1984     LOG(INFO) << "Finished job, writing to pipe";
1985     uint8_t byte = *retptr > 0 ? 1 : 0;
1986     writer->write(nullptr, &byte, 1);
1987   });
1988 
1989   LOG(INFO) << "About to pause job";
1990 
1991   ASYNC_pause_job();
1992   LOG(INFO) << "Resumed job with ret: " << ret;
1993   return ret;
1994 }
1995 
rsaFree(void *,void * ptr,CRYPTO_EX_DATA *,int,long,void *)1996 void rsaFree(void*, void* ptr, CRYPTO_EX_DATA*, int, long, void*) {
1997   LOG(INFO) << "RSA_free is called with ptr " << std::hex << ptr;
1998   if (ptr == nullptr) {
1999     LOG(INFO) << "Returning early from rsaFree because ptr is null";
2000     return;
2001   }
2002   RSA* rsa = (RSA*)ptr;
2003   auto meth = RSA_get_method(rsa);
2004   if (meth != RSA_get_default_method()) {
2005     auto nonconst = const_cast<RSA_METHOD*>(meth);
2006     RSA_meth_free(nonconst);
2007     RSA_set_method(rsa, RSA_get_default_method());
2008   }
2009   RSA_free(rsa);
2010 }
2011 
2012 struct RSAPointers {
2013   RSA* actualrsa{nullptr};
2014   RSA* dummyrsa{nullptr};
2015   RSA_METHOD* meth{nullptr};
2016 };
2017 
RSAPointersFree(RSAPointers * p)2018 inline void RSAPointersFree(RSAPointers* p) {
2019   if (p->meth && p->dummyrsa && RSA_get_method(p->dummyrsa) == p->meth) {
2020     RSA_set_method(p->dummyrsa, RSA_get_default_method());
2021   }
2022 
2023   if (p->meth) {
2024     LOG(INFO) << "Freeing meth";
2025     RSA_meth_free(p->meth);
2026   }
2027 
2028   if (p->actualrsa) {
2029     LOG(INFO) << "Freeing actualrsa";
2030     RSA_free(p->actualrsa);
2031   }
2032 
2033   if (p->dummyrsa) {
2034     LOG(INFO) << "Freeing dummyrsa";
2035     RSA_free(p->dummyrsa);
2036   }
2037 
2038   delete p;
2039 }
2040 
2041 using RSAPointersDeleter =
2042     folly::static_function_deleter<RSAPointers, RSAPointersFree>;
2043 
setupCustomRSA(const char * certPath,const char * keyPath,EventBase * jobEvb)2044 std::unique_ptr<RSAPointers, RSAPointersDeleter> setupCustomRSA(
2045     const char* certPath, const char* keyPath, EventBase* jobEvb) {
2046   auto certPEM = getFileAsBuf(certPath);
2047   auto keyPEM = getFileAsBuf(keyPath);
2048 
2049   ssl::BioUniquePtr certBio(
2050       BIO_new_mem_buf((void*)certPEM.data(), certPEM.size()));
2051   ssl::BioUniquePtr keyBio(
2052       BIO_new_mem_buf((void*)keyPEM.data(), keyPEM.size()));
2053 
2054   ssl::X509UniquePtr cert(
2055       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
2056   ssl::EvpPkeyUniquePtr evpPkey(
2057       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
2058   ssl::EvpPkeyUniquePtr publicEvpPkey(X509_get_pubkey(cert.get()));
2059 
2060   std::unique_ptr<RSAPointers, RSAPointersDeleter> ret(new RSAPointers());
2061 
2062   RSA* actualrsa = EVP_PKEY_get1_RSA(evpPkey.get());
2063   LOG(INFO) << "actualrsa ptr " << std::hex << (void*)actualrsa;
2064   RSA* dummyrsa = EVP_PKEY_get1_RSA(publicEvpPkey.get());
2065   if (dummyrsa == nullptr) {
2066     throw std::runtime_error("Couldn't get RSA cert public factors");
2067   }
2068   RSA_METHOD* meth = RSA_meth_dup(RSA_get_default_method());
2069   if (meth == nullptr || RSA_meth_set1_name(meth, "Async RSA method") == 0 ||
2070       RSA_meth_set_priv_enc(meth, customRsaPrivEnc) == 0 ||
2071       RSA_meth_set_flags(meth, RSA_METHOD_FLAG_NO_CHECK) == 0) {
2072     throw std::runtime_error("Cannot create async RSA_METHOD");
2073   }
2074   RSA_set_method(dummyrsa, meth);
2075   RSA_set_flags(dummyrsa, RSA_FLAG_EXT_PKEY);
2076 
2077   kRSAExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
2078   kRSAEvbExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
2079   kRSASocketExIndex =
2080       RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
2081   CHECK_NE(kRSAExIndex, -1);
2082   CHECK_NE(kRSAEvbExIndex, -1);
2083   CHECK_NE(kRSASocketExIndex, -1);
2084   RSA_set_ex_data(dummyrsa, kRSAExIndex, actualrsa);
2085   RSA_set_ex_data(dummyrsa, kRSAEvbExIndex, jobEvb);
2086 
2087   ret->actualrsa = actualrsa;
2088   ret->dummyrsa = dummyrsa;
2089   ret->meth = meth;
2090 
2091   return ret;
2092 }
2093 
// TODO: disabled under ASAN because ASAN doesn't play nice with ASYNC for some reason
2095 #ifndef FOLLY_SANITIZE_ADDRESS
/**
 * The server's private-key operation goes through the async RSA method
 * (SSL_MODE_ASYNC). With the matching key the handshake must succeed on
 * both ends.
 */
TEST(AsyncSSLSocketTest, OpenSSL110AsyncTest) {
  ASYNC_init_thread(1, 1);
  EventBase evb;
  // Its EventBase executes the real RSA operation for the async job.
  ScopedEventBaseThread rsaThread;
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();
  srvCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  srvCtx->loadCertificate(kTestCert);
  srvCtx->loadTrustedCertificates(kTestCA);
  srvCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);

  // The server key is the async dummy key rather than kTestKey loaded directly.
  auto customRsa =
      setupCustomRSA(kTestCert, kTestKey, rsaThread.getEventBase());
  CHECK(customRsa->dummyrsa);
  // up-refs dummyrsa
  SSL_CTX_use_RSAPrivateKey(srvCtx->getSSLCtx(), customRsa->dummyrsa);
  SSL_CTX_set_mode(srvCtx->getSSLCtx(), SSL_MODE_ASYNC);

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  cliCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock(
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false));
  AsyncSSLSocket::UniquePtr srvSock(
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true));

  SSLHandshakeClient cli(std::move(cliSock), false, false);
  SSLHandshakeServer srv(std::move(srvSock), false, false);

  evb.loop();

  EXPECT_TRUE(srv.handshakeSuccess_);
  EXPECT_TRUE(cli.handshakeSuccess_);
  ASYNC_cleanup_thread();
}
2134 
/**
 * Same setup as OpenSSL110AsyncTest, but the async key's real private
 * factors come from kClientTestKey, which does not match kTestCert. The
 * handshake must therefore fail on both ends.
 */
TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestFailure) {
  ASYNC_init_thread(1, 1);
  EventBase evb;
  ScopedEventBaseThread rsaThread;
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();
  srvCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  srvCtx->loadCertificate(kTestCert);
  srvCtx->loadTrustedCertificates(kTestCA);
  srvCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);

  // Deliberately mismatched: certificate from kTestCert, key from
  // kClientTestKey.
  auto customRsa = setupCustomRSA(
      kTestCert, kClientTestKey, rsaThread.getEventBase());
  CHECK(customRsa->dummyrsa);
  SSL_CTX_use_RSAPrivateKey(srvCtx->getSSLCtx(), customRsa->dummyrsa);
  SSL_CTX_set_mode(srvCtx->getSSLCtx(), SSL_MODE_ASYNC);

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  cliCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock(
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false));
  AsyncSSLSocket::UniquePtr srvSock(
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true));

  SSLHandshakeClient cli(std::move(cliSock), false, false);
  SSLHandshakeServer srv(std::move(srvSock), false, false);

  evb.loop();

  EXPECT_TRUE(srv.handshakeError_);
  EXPECT_TRUE(cli.handshakeError_);
  ASYNC_cleanup_thread();
}
2172 
/**
 * The async RSA hook is given the server socket through ex-data and closes
 * it while the private-key operation is still pending. Both ends must then
 * report a handshake error.
 */
TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestClosedWithCallbackPending) {
  ASYNC_init_thread(1, 1);
  EventBase evb;
  // Held in an optional so it can be torn down explicitly at the very end.
  std::optional<EventBaseThread> rsaThread;
  rsaThread.emplace();
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();
  srvCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  srvCtx->loadCertificate(kTestCert);
  srvCtx->loadTrustedCertificates(kTestCA);
  srvCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);

  auto customRsa =
      setupCustomRSA(kTestCert, kTestKey, rsaThread->getEventBase());
  CHECK(customRsa->dummyrsa);
  // up-refs dummyrsa
  SSL_CTX_use_RSAPrivateKey(srvCtx->getSSLCtx(), customRsa->dummyrsa);
  SSL_CTX_set_mode(srvCtx->getSSLCtx(), SSL_MODE_ASYNC);

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  cliCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock(
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false));
  AsyncSSLSocket::UniquePtr srvSock(
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true));

  // Let the async hook find (and close) the server socket mid-operation.
  RSA_set_ex_data(customRsa->dummyrsa, kRSASocketExIndex, srvSock.get());

  SSLHandshakeClient cli(std::move(cliSock), false, false);
  SSLHandshakeServer srv(std::move(srvSock), false, false);

  evb.loop();

  EXPECT_TRUE(srv.handshakeError_);
  EXPECT_TRUE(cli.handshakeError_);
  ASYNC_cleanup_thread();
  rsaThread.reset();
}
2215 #endif // FOLLY_SANITIZE_ADDRESS
2216 
2217 #endif // FOLLY_OPENSSL_IS_110
2218 
/**
 * Verify that a certificate and private key loaded into an SSLContext from
 * in-memory PEM buffers are what SSL objects created from that context
 * carry. The test compares the common name and key size with the values
 * parsed directly from the same buffers.
 */
TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
  using folly::ssl::OpenSSLUtils;
  auto cert = getFileAsBuf(kTestCert);
  auto key = getFileAsBuf(kTestKey);

  ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
  BIO_write(certBio.get(), cert.data(), cert.size());
  ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
  BIO_write(keyBio.get(), key.data(), key.size());

  // Create SSL structs from buffers to get properties
  ssl::X509UniquePtr certStruct(
      PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
  ssl::EvpPkeyUniquePtr keyStruct(
      PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
  // Fail the test cleanly rather than dereferencing null below.
  ASSERT_NE(certStruct.get(), nullptr);
  ASSERT_NE(keyStruct.get(), nullptr);
  certBio = nullptr;
  keyBio = nullptr;

  auto origCommonName = OpenSSLUtils::getCommonName(certStruct.get());
  auto origKeySize = EVP_PKEY_bits(keyStruct.get());
  certStruct = nullptr;
  keyStruct = nullptr;

  auto ctx = std::make_shared<SSLContext>();
  ctx->loadPrivateKeyFromBufferPEM(key);
  ctx->loadCertificateFromBufferPEM(cert);
  ctx->loadTrustedCertificates(kTestCA);

  ssl::SSLUniquePtr ssl(ctx->createSSL());
  ASSERT_NE(ssl.get(), nullptr);

  auto newCert = SSL_get_certificate(ssl.get());
  auto newKey = SSL_get_privatekey(ssl.get());
  ASSERT_NE(newCert, nullptr);
  ASSERT_NE(newKey, nullptr);

  // Get properties from SSL struct
  auto newCommonName = OpenSSLUtils::getCommonName(newCert);
  auto newKeySize = EVP_PKEY_bits(newKey);

  // Check that the key and cert have the expected properties
  EXPECT_EQ(origCommonName, newCommonName);
  EXPECT_EQ(origKeySize, newKeySize);
}
2260 
/**
 * The minimum write size defaults to 1500 and round-trips through
 * setMinWriteSize/getMinWriteSize, including zero and a large value.
 */
TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
  EventBase evb;

  auto ctx = std::make_shared<SSLContext>();
  ctx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  // An unconnected SSL socket is enough to exercise the accessors.
  AsyncSSLSocket::UniquePtr sock(new AsyncSSLSocket(ctx, &evb));

  EXPECT_EQ(1500, sock->getMinWriteSize());

  for (size_t wanted : {size_t(0), size_t(50000)}) {
    sock->setMinWriteSize(wanted);
    EXPECT_EQ(wanted, sock->getMinWriteSize());
  }
}
2278 
2279 class ReadCallbackTerminator : public ReadCallback {
2280  public:
ReadCallbackTerminator(EventBase * base,WriteCallbackBase * wcb)2281   ReadCallbackTerminator(EventBase* base, WriteCallbackBase* wcb)
2282       : ReadCallback(wcb), base_(base) {}
2283 
2284   // Do not write data back, terminate the loop.
readDataAvailable(size_t len)2285   void readDataAvailable(size_t len) noexcept override {
2286     std::cerr << "readDataAvailable, len " << len << std::endl;
2287 
2288     currentBuffer.length = len;
2289 
2290     buffers.push_back(currentBuffer);
2291     currentBuffer.reset();
2292     state = STATE_SUCCEEDED;
2293 
2294     socket_->setReadCB(nullptr);
2295     base_->terminateLoopSoon();
2296   }
2297 
2298  private:
2299   EventBase* base_;
2300 };
2301 
2302 /**
2303  * Test a full unencrypted codepath
2304  */
TEST(AsyncSSLSocketTest,UnencryptedTest)2305 TEST(AsyncSSLSocketTest, UnencryptedTest) {
2306   EventBase base;
2307 
2308   auto clientCtx = std::make_shared<folly::SSLContext>();
2309   auto serverCtx = std::make_shared<folly::SSLContext>();
2310   NetworkSocket fds[2];
2311   getfds(fds);
2312   getctx(clientCtx, serverCtx);
2313   auto client =
2314       AsyncSSLSocket::newSocket(clientCtx, &base, fds[0], false, true);
2315   std::shared_ptr<AsyncSSLSocket> server =
2316       AsyncSSLSocket::newSocket(serverCtx, &base, fds[1], true, true);
2317 
2318   ReadCallbackTerminator readCallback(&base, nullptr);
2319   server->setReadCB(&readCallback);
2320   readCallback.setSocket(server);
2321 
2322   uint8_t buf[128];
2323   memset(buf, 'a', sizeof(buf));
2324   client->write(nullptr, buf, sizeof(buf));
2325 
2326   // Check that bytes are unencrypted
2327   char c;
2328   EXPECT_EQ(1, netops::recv(fds[1], &c, 1, MSG_PEEK));
2329   EXPECT_EQ('a', c);
2330 
2331   EventBaseAborter eba(&base, 3000);
2332   base.loop();
2333 
2334   EXPECT_EQ(1, readCallback.buffers.size());
2335   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
2336 
2337   server->setReadCB(&readCallback);
2338 
2339   // Unencrypted
2340   server->sslAccept(nullptr);
2341   client->sslConn(nullptr);
2342 
2343   // Do NOT wait for handshake, writing should be queued and happen after
2344 
2345   client->write(nullptr, buf, sizeof(buf));
2346 
2347   // Check that bytes are *not* unencrypted
2348   char c2;
2349   EXPECT_EQ(1, netops::recv(fds[1], &c2, 1, MSG_PEEK));
2350   EXPECT_NE('a', c2);
2351 
2352   base.loop();
2353 
2354   EXPECT_EQ(2, readCallback.buffers.size());
2355   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
2356 }
2357 
/**
 * A client socket connected in unencrypted mode stays unencrypted until
 * sslConn() is called. After that it reaches STATE_ESTABLISHED and can write.
 */
TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
  auto cliCtx = std::make_shared<folly::SSLContext>();
  auto srvCtx = std::make_shared<folly::SSLContext>();
  getctx(cliCtx, srvCtx);

  // Server-side callback chain: accept -> handshake -> read -> write.
  WriteCallbackBase writeCB;
  ReadCallback readCB(&writeCB);
  HandshakeCallback handshakeCB(&readCB);
  SSLServerAcceptCallback acceptCB(&handshakeCB);
  TestSSLServer server(&acceptCB);

  EventBase evb;
  std::shared_ptr<AsyncSSLSocket> sock =
      AsyncSSLSocket::newSocket(cliCtx, &evb, true);
  sock->connect(nullptr, server.getAddress(), 0);

  evb.loop();

  // Connected, but TLS has not been started yet.
  EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, sock->getSSLState());

  // Upgrade the existing connection to TLS.
  sock->sslConn(nullptr);
  evb.loop();
  EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, sock->getSSLState());

  std::array<uint8_t, 128> payload;
  payload.fill('a');
  sock->write(nullptr, payload.data(), payload.size());

  sock->close();
}
2388 
2389 /**
2390  * Test acceptrunner in various situations
2391  */
TEST(AsyncSSLSocketTest,SSLAcceptRunnerBasic)2392 TEST(AsyncSSLSocketTest, SSLAcceptRunnerBasic) {
2393   EventBase eventBase;
2394   auto clientCtx = std::make_shared<SSLContext>();
2395   auto serverCtx = std::make_shared<SSLContext>();
2396   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2397   serverCtx->loadPrivateKey(kTestKey);
2398   serverCtx->loadCertificate(kTestCert);
2399 
2400   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
2401   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2402   clientCtx->loadTrustedCertificates(kTestCA);
2403 
2404   NetworkSocket fds[2];
2405   getfds(fds);
2406 
2407   AsyncSSLSocket::UniquePtr clientSock(
2408       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
2409   AsyncSSLSocket::UniquePtr serverSock(
2410       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
2411 
2412   serverCtx->sslAcceptRunner(std::make_unique<SSLAcceptEvbRunner>(&eventBase));
2413 
2414   SSLHandshakeClient client(std::move(clientSock), true, true);
2415   SSLHandshakeServer server(std::move(serverSock), true, true);
2416 
2417   eventBase.loop();
2418 
2419   EXPECT_TRUE(client.handshakeSuccess_);
2420   EXPECT_FALSE(client.handshakeError_);
2421   EXPECT_LE(0, client.handshakeTime.count());
2422   EXPECT_TRUE(server.handshakeSuccess_);
2423   EXPECT_FALSE(server.handshakeError_);
2424   EXPECT_LE(0, server.handshakeTime.count());
2425 }
2426 
/**
 * With an SSLAcceptErrorRunner installed on the server context, the
 * handshake fails on both ends.
 */
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptError) {
  EventBase evb;
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();
  srvCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  srvCtx->loadPrivateKey(kTestKey);
  srvCtx->loadCertificate(kTestCert);

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  cliCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  cliCtx->loadTrustedCertificates(kTestCA);

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock(
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false));
  AsyncSSLSocket::UniquePtr srvSock(
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true));

  srvCtx->sslAcceptRunner(std::make_unique<SSLAcceptErrorRunner>(&evb));

  SSLHandshakeClient cli(std::move(cliSock), true, true);
  SSLHandshakeServer srv(std::move(srvSock), true, true);

  evb.loop();

  const auto expectFailure = [](const auto& side) {
    EXPECT_FALSE(side.handshakeSuccess_);
    EXPECT_TRUE(side.handshakeError_);
  };
  expectFailure(cli);
  expectFailure(srv);
}
2460 
/**
 * With an SSLAcceptCloseRunner (given the server socket) installed on the
 * server context, the handshake fails on both ends.
 */
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptClose) {
  EventBase evb;
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();
  srvCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  srvCtx->loadPrivateKey(kTestKey);
  srvCtx->loadCertificate(kTestCert);

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  cliCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  cliCtx->loadTrustedCertificates(kTestCA);

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock(
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false));
  AsyncSSLSocket::UniquePtr srvSock(
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true));

  // The runner needs the raw server socket, so hand it over before the
  // socket is moved into the handshake helper.
  srvCtx->sslAcceptRunner(
      std::make_unique<SSLAcceptCloseRunner>(&evb, srvSock.get()));

  SSLHandshakeClient cli(std::move(cliSock), true, true);
  SSLHandshakeServer srv(std::move(srvSock), true, true);

  evb.loop();

  const auto expectFailure = [](const auto& side) {
    EXPECT_FALSE(side.handshakeSuccess_);
    EXPECT_TRUE(side.handshakeError_);
  };
  expectFailure(cli);
  expectFailure(srv);
}
2494 
/**
 * Installs SSLAcceptDestroyRunner, bound to the SSLHandshakeServer, after the
 * handshake helpers exist; both ends are expected to report a handshake error.
 */
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptDestroy) {
  constexpr auto kCipherList = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";

  EventBase evb;
  auto clientContext = std::make_shared<SSLContext>();
  auto serverContext = std::make_shared<SSLContext>();

  serverContext->ciphers(kCipherList);
  serverContext->loadPrivateKey(kTestKey);
  serverContext->loadCertificate(kTestCert);

  clientContext->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientContext->ciphers(kCipherList);
  clientContext->loadTrustedCertificates(kTestCA);

  NetworkSocket socketPair[2];
  getfds(socketPair);

  AsyncSSLSocket::UniquePtr clientSocket(
      new AsyncSSLSocket(clientContext, &evb, socketPair[0], false));
  AsyncSSLSocket::UniquePtr serverSocket(
      new AsyncSSLSocket(serverContext, &evb, socketPair[1], true));

  SSLHandshakeClient client(std::move(clientSocket), true, true);
  SSLHandshakeServer server(std::move(serverSocket), true, true);

  // The runner is handed the handshake helper itself, so it can only be
  // created once `server` exists.
  serverContext->sslAcceptRunner(
      std::make_unique<SSLAcceptDestroyRunner>(&evb, &server));

  evb.loop();

  EXPECT_FALSE(client.handshakeSuccess_);
  EXPECT_TRUE(client.handshakeError_);
  EXPECT_FALSE(server.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeError_);
}
2528 
/**
 * Runs the server-side accept through SSLAcceptFiberRunner; the handshake is
 * expected to succeed on both ends with no errors.
 */
TEST(AsyncSSLSocketTest, SSLAcceptRunnerFiber) {
  constexpr auto kCipherList = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";

  EventBase evb;
  auto clientContext = std::make_shared<SSLContext>();
  auto serverContext = std::make_shared<SSLContext>();

  serverContext->ciphers(kCipherList);
  serverContext->loadPrivateKey(kTestKey);
  serverContext->loadCertificate(kTestCert);

  clientContext->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientContext->ciphers(kCipherList);
  clientContext->loadTrustedCertificates(kTestCA);

  NetworkSocket socketPair[2];
  getfds(socketPair);

  AsyncSSLSocket::UniquePtr clientSocket(
      new AsyncSSLSocket(clientContext, &evb, socketPair[0], false));
  AsyncSSLSocket::UniquePtr serverSocket(
      new AsyncSSLSocket(serverContext, &evb, socketPair[1], true));

  SSLHandshakeClient client(std::move(clientSocket), true, true);
  SSLHandshakeServer server(std::move(serverSocket), true, true);

  serverContext->sslAcceptRunner(
      std::make_unique<SSLAcceptFiberRunner>(&evb));

  evb.loop();

  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_FALSE(client.handshakeError_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_FALSE(server.handshakeError_);
}
2562 
newCloseCb(SSL * ssl,SSL_SESSION *)2563 static int newCloseCb(SSL* ssl, SSL_SESSION*) {
2564   AsyncSSLSocket::getFromSSL(ssl)->closeNow();
2565   return 1;
2566 }
2567 
2568 #if FOLLY_OPENSSL_IS_110
getCloseCb(SSL * ssl,const unsigned char *,int,int *)2569 static SSL_SESSION* getCloseCb(SSL* ssl, const unsigned char*, int, int*) {
2570 #else
2571 static SSL_SESSION* getCloseCb(SSL* ssl, unsigned char*, int, int*) {
2572 #endif
2573   AsyncSSLSocket::getFromSSL(ssl)->closeNow();
2574   return nullptr;
2575 } // namespace folly
2576 
2577 TEST(AsyncSSLSocketTest, SSLAcceptRunnerFiberCloseSessionCb) {
2578   EventBase eventBase;
2579   auto clientCtx = std::make_shared<SSLContext>();
2580   auto serverCtx = std::make_shared<SSLContext>();
2581   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2582   serverCtx->loadPrivateKey(kTestKey);
2583   serverCtx->loadCertificate(kTestCert);
2584   SSL_CTX_set_session_cache_mode(
2585       serverCtx->getSSLCtx(),
2586       SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
2587   SSL_CTX_sess_set_new_cb(serverCtx->getSSLCtx(), &newCloseCb);
2588   SSL_CTX_sess_set_get_cb(serverCtx->getSSLCtx(), &getCloseCb);
2589   serverCtx->sslAcceptRunner(
2590       std::make_unique<SSLAcceptFiberRunner>(&eventBase));
2591 
2592   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
2593   clientCtx->ciphers("AES128-SHA256");
2594   clientCtx->loadTrustedCertificates(kTestCA);
2595   clientCtx->setOptions(SSL_OP_NO_TICKET);
2596 
2597   NetworkSocket fds[2];
2598   getfds(fds);
2599 
2600   AsyncSSLSocket::UniquePtr clientSock(
2601       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
2602   AsyncSSLSocket::UniquePtr serverSock(
2603       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
2604 
2605   SSLHandshakeClient client(std::move(clientSock), true, true);
2606   SSLHandshakeServer server(std::move(serverSock), true, true);
2607 
2608   eventBase.loop();
2609 
2610   // As close() is called during session callbacks, client sees it as a
2611   // successful connection
2612   EXPECT_TRUE(client.handshakeSuccess_);
2613   EXPECT_FALSE(client.handshakeError_);
2614   EXPECT_FALSE(server.handshakeSuccess_);
2615   EXPECT_TRUE(server.handshakeError_);
2616 }
2617 
2618 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
2619   // Start listening on a local port
2620   WriteCallbackBase writeCallback;
2621   WriteErrorCallback readCallback(&writeCallback);
2622   HandshakeCallback handshakeCallback(
2623       &readCallback, HandshakeCallback::EXPECT_ERROR);
2624   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2625   TestSSLServer server(&acceptCallback);
2626 
2627   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
2628   socket->open();
2629   uint8_t buf[3] = {0x16, 0x03, 0x01};
2630   socket->write(buf, sizeof(buf));
2631   socket->closeWithReset();
2632 
2633   handshakeCallback.waitForHandshake();
2634   EXPECT_NE(
2635       handshakeCallback.errorString_.find("Network error"), std::string::npos);
2636   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
2637 }
2638 
2639 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
2640   // Start listening on a local port
2641   WriteCallbackBase writeCallback;
2642   WriteErrorCallback readCallback(&writeCallback);
2643   HandshakeCallback handshakeCallback(
2644       &readCallback, HandshakeCallback::EXPECT_ERROR);
2645   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2646   TestSSLServer server(&acceptCallback);
2647 
2648   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
2649   socket->open();
2650   uint8_t buf[3] = {0x16, 0x03, 0x01};
2651   socket->write(buf, sizeof(buf));
2652   socket->close();
2653 
2654   handshakeCallback.waitForHandshake();
2655 #if FOLLY_OPENSSL_IS_110
2656   EXPECT_NE(
2657       handshakeCallback.errorString_.find("Network error"), std::string::npos);
2658 #else
2659   EXPECT_NE(
2660       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
2661 #endif
2662 }
2663 
2664 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
2665   // Start listening on a local port
2666   WriteCallbackBase writeCallback;
2667   WriteErrorCallback readCallback(&writeCallback);
2668   HandshakeCallback handshakeCallback(
2669       &readCallback, HandshakeCallback::EXPECT_ERROR);
2670   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2671   TestSSLServer server(&acceptCallback);
2672 
2673   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
2674   socket->open();
2675   uint8_t buf[256] = {0x16, 0x03};
2676   memset(buf + 2, 'a', sizeof(buf) - 2);
2677   socket->write(buf, sizeof(buf));
2678   socket->close();
2679 
2680   handshakeCallback.waitForHandshake();
2681   EXPECT_NE(
2682       handshakeCallback.errorString_.find("SSL routines"), std::string::npos);
2683 #if defined(OPENSSL_IS_BORINGSSL)
2684   EXPECT_NE(
2685       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
2686       std::string::npos);
2687 #elif FOLLY_OPENSSL_IS_110
2688   EXPECT_NE(
2689       handshakeCallback.errorString_.find("packet length too long"),
2690       std::string::npos);
2691 #else
2692   EXPECT_NE(
2693       handshakeCallback.errorString_.find("unknown protocol"),
2694       std::string::npos);
2695 #endif
2696 }
2697 
2698 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
2699   using folly::ssl::OpenSSLUtils;
2700   EXPECT_EQ(
2701       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
2702   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
2703   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
2704   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
2705   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
2706 }
2707 
2708 #if defined __linux__
2709 /**
2710  * Ensure TransparentTLS flag is disabled with AsyncSSLSocket
2711  */
2712 TEST(AsyncSSLSocketTest, TTLSDisabled) {
2713   // clear all setsockopt tracking history
2714   globalStatic.reset();
2715 
2716   // Start listening on a local port
2717   WriteCallbackBase writeCallback;
2718   ReadCallback readCallback(&writeCallback);
2719   HandshakeCallback handshakeCallback(&readCallback);
2720   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2721   TestSSLServer server(&acceptCallback, false);
2722 
2723   // Set up SSL context.
2724   auto sslContext = std::make_shared<SSLContext>();
2725 
2726   // connect
2727   auto socket =
2728       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2729   socket->open();
2730 
2731   EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getNetworkSocket()));
2732 
2733   // write()
2734   std::array<uint8_t, 128> buf;
2735   memset(buf.data(), 'a', buf.size());
2736   socket->write(buf.data(), buf.size());
2737 
2738   // close()
2739   socket->close();
2740 }
2741 #endif
2742 
2743 #if FOLLY_ALLOW_TFO
2744 
2745 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
2746  public:
2747   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
2748 
2749   explicit MockAsyncTFOSSLSocket(
2750       std::shared_ptr<folly::SSLContext> sslCtx, EventBase* evb)
2751       : AsyncSSLSocket(sslCtx, evb) {}
2752 
2753   MOCK_METHOD3(
2754       tfoSendMsg, ssize_t(NetworkSocket fd, struct msghdr* msg, int msg_flags));
2755 };
2756 
2757 #if defined __linux__
2758 /**
2759  * Ensure TransparentTLS flag is disabled with AsyncSSLSocket + TFO
2760  */
2761 TEST(AsyncSSLSocketTest, TTLSDisabledWithTFO) {
2762   // clear all setsockopt tracking history
2763   globalStatic.reset();
2764 
2765   // Start listening on a local port
2766   WriteCallbackBase writeCallback;
2767   ReadCallback readCallback(&writeCallback);
2768   HandshakeCallback handshakeCallback(&readCallback);
2769   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2770   TestSSLServer server(&acceptCallback, true);
2771 
2772   // Set up SSL context.
2773   auto sslContext = std::make_shared<SSLContext>();
2774 
2775   // connect
2776   auto socket =
2777       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2778   socket->enableTFO();
2779   socket->open();
2780 
2781   EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getNetworkSocket()));
2782 
2783   // write()
2784   std::array<uint8_t, 128> buf;
2785   memset(buf.data(), 'a', buf.size());
2786   socket->write(buf.data(), buf.size());
2787 
2788   // close()
2789   socket->close();
2790 }
2791 #endif
2792 
2793 /**
2794  * Test connecting to, writing to, reading from, and closing the
2795  * connection to the SSL server with TFO.
2796  */
2797 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
2798   // Start listening on a local port
2799   WriteCallbackBase writeCallback;
2800   ReadCallback readCallback(&writeCallback);
2801   HandshakeCallback handshakeCallback(&readCallback);
2802   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2803   TestSSLServer server(&acceptCallback, true);
2804 
2805   // Set up SSL context.
2806   auto sslContext = std::make_shared<SSLContext>();
2807 
2808   // connect
2809   auto socket =
2810       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2811   socket->enableTFO();
2812   socket->open();
2813 
2814   // write()
2815   std::array<uint8_t, 128> buf;
2816   memset(buf.data(), 'a', buf.size());
2817   socket->write(buf.data(), buf.size());
2818 
2819   // read()
2820   std::array<uint8_t, 128> readbuf;
2821   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2822   EXPECT_EQ(bytesRead, 128);
2823   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
2824 
2825   // close()
2826   socket->close();
2827 }
2828 
2829 /**
2830  * Test connecting to, writing to, reading from, and closing the
2831  * connection to the SSL server with TFO.
2832  */
2833 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
2834   // Start listening on a local port
2835   WriteCallbackBase writeCallback;
2836   ReadCallback readCallback(&writeCallback);
2837   HandshakeCallback handshakeCallback(&readCallback);
2838   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2839   TestSSLServer server(&acceptCallback, false);
2840 
2841   // Set up SSL context.
2842   auto sslContext = std::make_shared<SSLContext>();
2843 
2844   // connect
2845   auto socket =
2846       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2847   socket->enableTFO();
2848   socket->open();
2849 
2850   // write()
2851   std::array<uint8_t, 128> buf;
2852   memset(buf.data(), 'a', buf.size());
2853   socket->write(buf.data(), buf.size());
2854 
2855   // read()
2856   std::array<uint8_t, 128> readbuf;
2857   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2858   EXPECT_EQ(bytesRead, 128);
2859   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
2860 
2861   // close()
2862   socket->close();
2863 }
2864 
2865 class ConnCallback : public AsyncSocket::ConnectCallback {
2866  public:
2867   void connectSuccess() noexcept override { state = State::SUCCESS; }
2868 
2869   void connectErr(const AsyncSocketException& ex) noexcept override {
2870     state = State::ERROR;
2871     error = ex.what();
2872   }
2873 
2874   enum class State { WAITING, SUCCESS, ERROR };
2875 
2876   State state{State::WAITING};
2877   std::string error;
2878 };
2879 
2880 template <class Cardinality>
2881 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
2882     EventBase* evb, const SocketAddress& address, Cardinality cardinality) {
2883   // Set up SSL context.
2884   auto sslContext = std::make_shared<SSLContext>();
2885 
2886   // connect
2887   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
2888       new MockAsyncTFOSSLSocket(sslContext, evb));
2889   socket->enableTFO();
2890 
2891   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2892       .Times(cardinality)
2893       .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
2894         sockaddr_storage addr;
2895         auto len = address.getAddress(&addr);
2896         return netops::connect(fd, (const struct sockaddr*)&addr, len);
2897       }));
2898   return socket;
2899 }
2900 
2901 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
2902   if (!folly::test::isTFOAvailable()) {
2903     GTEST_SKIP() << "TFO not supported.";
2904   }
2905 
2906   // Start listening on a local port
2907   WriteCallbackBase writeCallback;
2908   ReadCallback readCallback(&writeCallback);
2909   HandshakeCallback handshakeCallback(&readCallback);
2910   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2911   TestSSLServer server(&acceptCallback, true);
2912 
2913   EventBase evb;
2914 
2915   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
2916   ConnCallback ccb;
2917   socket->connect(&ccb, server.getAddress(), 30);
2918 
2919   evb.loop();
2920   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
2921 
2922   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
2923   evb.loop();
2924 
2925   BlockingSocket sock(std::move(socket));
2926   // write()
2927   std::array<uint8_t, 128> buf;
2928   memset(buf.data(), 'a', buf.size());
2929   sock.write(buf.data(), buf.size());
2930 
2931   // read()
2932   std::array<uint8_t, 128> readbuf;
2933   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
2934   EXPECT_EQ(bytesRead, 128);
2935   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
2936 
2937   // close()
2938   sock.close();
2939 }
2940 
2941 #if !defined(OPENSSL_IS_BORINGSSL)
2942 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
2943   // Start listening on a local port
2944   ConnectTimeoutCallback acceptCallback;
2945   TestSSLServer server(&acceptCallback, true);
2946 
2947   // Set up SSL context.
2948   auto sslContext = std::make_shared<SSLContext>();
2949 
2950   // connect
2951   auto socket =
2952       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2953   socket->enableTFO();
2954   EXPECT_THROW(
2955       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
2956 }
2957 #endif
2958 
2959 #if !defined(OPENSSL_IS_BORINGSSL)
2960 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
2961   // Start listening on a local port
2962   ConnectTimeoutCallback acceptCallback;
2963   TestSSLServer server(&acceptCallback, true);
2964 
2965   EventBase evb;
2966 
2967   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
2968   ConnCallback ccb;
2969   // Set a short timeout
2970   socket->connect(&ccb, server.getAddress(), 1);
2971 
2972   evb.loop();
2973   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2974 }
2975 #endif
2976 
2977 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
2978   // Start listening on a local port
2979   EmptyReadCallback readCallback;
2980   HandshakeCallback handshakeCallback(
2981       &readCallback, HandshakeCallback::EXPECT_ERROR);
2982   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
2983   TestSSLServer server(&acceptCallback, true);
2984 
2985   EventBase evb;
2986 
2987   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
2988   ConnCallback ccb;
2989   socket->connect(&ccb, server.getAddress(), 100);
2990 
2991   evb.loop();
2992   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2993   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
2994 }
2995 
2996 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
2997   // Start listening on a local port
2998   EventBase evb;
2999 
3000   // Hopefully nothing is listening on this address
3001   SocketAddress addr("127.0.0.1", 65535);
3002   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
3003   ConnCallback ccb;
3004   socket->connect(&ccb, addr, 100);
3005 
3006   evb.loop();
3007   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
3008   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
3009 }
3010 
3011 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
3012   EventBase eventBase;
3013   auto clientCtx = std::make_shared<SSLContext>();
3014   auto dfServerCtx = std::make_shared<SSLContext>();
3015   std::array<NetworkSocket, 2> fds;
3016   getfds(fds.data());
3017   getctx(clientCtx, dfServerCtx);
3018 
3019   AsyncSSLSocket::UniquePtr clientSockPtr(
3020       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
3021   AsyncSSLSocket::UniquePtr serverSockPtr(
3022       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
3023   auto clientSock = clientSockPtr.get();
3024   auto serverSock = serverSockPtr.get();
3025   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
3026 
3027   // Steal some data from the server.
3028   std::array<uint8_t, 10> buf;
3029   auto bytesReceived = netops::recv(fds[1], buf.data(), buf.size(), 0);
3030   checkUnixError(bytesReceived, "recv failed");
3031 
3032   serverSock->setPreReceivedData(
3033       IOBuf::wrapBuffer(ByteRange(buf.data(), bytesReceived)));
3034   SSLHandshakeServer server(std::move(serverSockPtr), true, true);
3035   while (!client.handshakeSuccess_ && !client.handshakeError_) {
3036     eventBase.loopOnce();
3037   }
3038 
3039   EXPECT_TRUE(client.handshakeSuccess_);
3040   EXPECT_TRUE(server.handshakeSuccess_);
3041   EXPECT_EQ(
3042       serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
3043 }
3044 
3045 TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
3046   EventBase eventBase;
3047   auto clientCtx = std::make_shared<SSLContext>();
3048   auto dfServerCtx = std::make_shared<SSLContext>();
3049   std::array<NetworkSocket, 2> fds;
3050   getfds(fds.data());
3051   getctx(clientCtx, dfServerCtx);
3052 
3053   AsyncSSLSocket::UniquePtr clientSockPtr(
3054       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
3055   AsyncSocket::UniquePtr serverSockPtr(new AsyncSocket(&eventBase, fds[1]));
3056   auto clientSock = clientSockPtr.get();
3057   auto serverSock = serverSockPtr.get();
3058   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
3059 
3060   // Steal some data from the server.
3061   std::array<uint8_t, 10> buf;
3062   auto bytesReceived = netops::recv(fds[1], buf.data(), buf.size(), 0);
3063   checkUnixError(bytesReceived, "recv failed");
3064 
3065   serverSock->setPreReceivedData(
3066       IOBuf::wrapBuffer(ByteRange(buf.data(), bytesReceived)));
3067   AsyncSSLSocket::UniquePtr serverSSLSockPtr(
3068       new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
3069   auto serverSSLSock = serverSSLSockPtr.get();
3070   SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
3071   while (!client.handshakeSuccess_ && !client.handshakeError_) {
3072     eventBase.loopOnce();
3073   }
3074 
3075   EXPECT_TRUE(client.handshakeSuccess_);
3076   EXPECT_TRUE(server.handshakeSuccess_);
3077   EXPECT_EQ(
3078       serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
3079 }
3080 
3081 /**
3082  * Test overriding the flags passed to "sendmsg()" system call,
3083  * and verifying that write requests fail properly.
3084  */
3085 TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
3086   // Start listening on a local port
3087   SendMsgFlagsCallback msgCallback;
3088   ExpectWriteErrorCallback writeCallback(&msgCallback);
3089   ReadCallback readCallback(&writeCallback);
3090   HandshakeCallback handshakeCallback(&readCallback);
3091   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
3092   TestSSLServer server(&acceptCallback);
3093 
3094   // Set up SSL context.
3095   auto sslContext = std::make_shared<SSLContext>();
3096   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
3097 
3098   // connect
3099   auto socket =
3100       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
3101   socket->open();
3102 
3103   // Setting flags to "-1" to trigger "Invalid argument" error
3104   // on attempt to use this flags in sendmsg() system call.
3105   msgCallback.resetFlags(-1);
3106 
3107   // write()
3108   std::vector<uint8_t> buf(128, 'a');
3109   ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
3110 
3111   // close()
3112   socket->close();
3113 
3114   cerr << "SendMsgParamsCallback test completed" << endl;
3115 }
3116 
3117 #if FOLLY_HAVE_SO_TIMESTAMPING
3118 
/**
 * Fixture for AsyncSSLSocket byte-event (SO_TIMESTAMPING) tests.
 *
 * SetUp() starts a fresh TestSSLServer for each test; tests create client
 * connections to it with getClientConn().
 */
class AsyncSSLSocketByteEventTest : public ::testing::Test {
 protected:
  using MockDispatcher = ::testing::NiceMock<netops::test::MockDispatcher>;
  using TestObserver = test::MockAsyncTransportObserverForByteEvents;
  using ByteEventType = AsyncTransport::ByteEvent::Type;

  /**
   * Components of a client connection to a TestSSLServer.
   *
   * Holds the client's AsyncSSLSocket (by default created on an EventBase that
   * is shared by all ClientConn instances), its connect and read callbacks, a
   * mock netops dispatcher that the socket's netops calls are routed through,
   * and the observers attached via attachObserver().
   */
  class ClientConn {
   public:
    /**
     * @param server  server that connect() will connect to.
     * @param socket  optional pre-built socket; when null, a new socket is
     *                created with the shared client SSLContext and EventBase.
     */
    explicit ClientConn(
        std::shared_ptr<TestSSLServer> server,
        std::shared_ptr<AsyncSSLSocket> socket = nullptr)
        : server_(std::move(server)), socket_(std::move(socket)) {
      if (!socket_) {
        socket_ = AsyncSSLSocket::newSocket(getSslContext(), &getEventBase());
      }
      // Route the socket's netops calls through the mock; forwardToDefaultImpl()
      // is expected to make the mock delegate to the real netops::Dispatcher
      // unless a test sets explicit expectations.
      socket_->setOverrideNetOpsDispatcher(netOpsDispatcher_);
      netOpsDispatcher_->forwardToDefaultImpl();
    }

    // Closes the client socket, if any, when the connection goes away.
    ~ClientConn() {
      if (socket_) {
        socket_->close();
      }
    }

    /**
     * Connects to the server (30 passed as the connect timeout), runs the
     * socket's EventBase loop, asserts that the connect succeeded, and then
     * installs the read callback.
     */
    void connect() {
      CHECK_NOTNULL(socket_.get());
      CHECK_NOTNULL(socket_->getEventBase());
      socket_->connect(&connCb_, server_->getAddress(), 30);
      socket_->getEventBase()->loop();
      ASSERT_EQ(connCb_.state, ConnCallback::State::SUCCESS);
      setReadCb();
    }

    // Installs readCb_ as the socket's read callback.
    void setReadCb() {
      // Due to how libevent works, we currently need to be subscribed to
      // EV_READ events in order to get error messages.
      //
      // TODO(bschlinker): Resolve this with libevent modification.
      // See https://github.com/libevent/libevent/issues/1038 for details.
      socket_->setReadCB(&readCb_);
      readCb_.setSocket(socket_);
    }

    /**
     * Creates an observer for this connection's socket (see the fixture's
     * static attachObserver()) and keeps a reference to it in `observers`
     * for the lifetime of this ClientConn.
     */
    std::shared_ptr<NiceMock<TestObserver>> attachObserver(
        bool enableByteEvents) {
      auto observer = AsyncSSLSocketByteEventTest::attachObserver(
          socket_.get(), enableByteEvents);
      observers.push_back(observer);
      return observer;
    }

    /**
     * Write to client socket, echo at server, and wait for echo at client.
     *
     * Waiting for echo at client ensures that we have given opportunity for
     * timestamps to be generated by the kernel.
     *
     * @param wbuf        bytes to write; the same bytes must come back.
     * @param writeFlags  flags passed through to AsyncSocket::write().
     */
    void writeAndReflect(
        const std::vector<uint8_t>& wbuf, const WriteFlags writeFlags) {
      CHECK_NOTNULL(socket_.get());
      CHECK_NOTNULL(socket_->getEventBase());

      // write to the client socket
      WriteCallbackBase wcb;
      socket_->write(&wcb, wbuf.data(), wbuf.size(), writeFlags);
      while (wcb.state == STATE_WAITING) {
        socket_->getEventBase()->loopOnce();
      }
      ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

      // TestSSLServer reads and reflects for us

      // read reflection at client
      while (wbuf.size() != readCb_.dataRead()) {
        socket_->getEventBase()->loopOnce();
      }
      readCb_.verifyData(wbuf.data(), wbuf.size());
      readCb_.clearData();
    }

    std::shared_ptr<AsyncSSLSocket> getRawSocket() { return socket_; }

    // Client SSLContext, created once (function-local static) and shared by
    // all ClientConn instances.
    std::shared_ptr<SSLContext> getSslContext() {
      static std::shared_ptr<SSLContext> sslContext = initSslContext();
      return sslContext;
    }

    EventBase& getEventBase() {
      static EventBase evb; // use same EventBase for all client sockets
      return evb;
    }

    // Expects exactly one setsockopt(SOL_SOCKET, SO_TIMESTAMPING) call.
    void netOpsExpectTimestampingSetSockOpt() {
      // must allow other calls; gmock tries newer expectations first, so this
      // catch-all has to be declared before the specific one below
      EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
          .Times(AnyNumber());
      EXPECT_CALL(
          *netOpsDispatcher_, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
          .Times(1);
    }

    // Expects no setsockopt(SOL_SOCKET, SO_TIMESTAMPING) call at all.
    void netOpsExpectNoTimestampingSetSockOpt() {
      // must allow other calls; gmock tries newer expectations first, so this
      // catch-all has to be declared before the specific one below
      EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
          .Times(AnyNumber());
      EXPECT_CALL(
          *netOpsDispatcher_, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
          .Times(0);
    }

    /**
     * Expects one sendmsg() call (a single WillOnce() with no Times() implies
     * Times(1) in gmock) whose SO_TIMESTAMPING control message decodes to
     * `writeFlags`; the call is then forwarded to the base
     * netops::Dispatcher::sendmsg().
     */
    void netOpsExpectWriteWithFlags(WriteFlags writeFlags) {
      EXPECT_CALL(*netOpsDispatcher_, sendmsg(_, _, _))
          .WillOnce(Invoke(
              [this, writeFlags](
                  NetworkSocket socket, const msghdr* message, int flags) {
                EXPECT_EQ(writeFlags, getMsgWriteFlags(*message));
                return netOpsDispatcher_->netops::Dispatcher::sendmsg(
                    socket, message, flags);
              }));
    }

    // Verifies and then clears all expectations set on the mock dispatcher.
    void netOpsVerifyAndClearExpectations() {
      Mock::VerifyAndClearExpectations(netOpsDispatcher_.get());
    }

    /**
     * Static utilities.
     */
    // Builds the client SSLContext used by getSslContext().
    static std::shared_ptr<SSLContext> initSslContext() {
      auto sslContext = std::make_shared<SSLContext>();
      sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
      return sslContext;
    }

   private:
    // server
    std::shared_ptr<TestSSLServer> server_;

    // managed observers (kept alive for the lifetime of this ClientConn)
    std::vector<std::shared_ptr<TestObserver>> observers;

    // socket components
    ConnCallback connCb_;
    ReadCallback readCb_;
    std::shared_ptr<MockDispatcher> netOpsDispatcher_{
        std::make_shared<MockDispatcher>()};
    std::shared_ptr<AsyncSSLSocket> socket_;
  };

  // Returns a new, not yet connected, client connection to server_.
  ClientConn getClientConn() { return ClientConn(server_); }

  // Builds the server callback chain (accept -> handshake -> read -> write)
  // and starts a TestSSLServer using it.
  void SetUp() override {
    serverWriteCb_ = std::make_unique<WriteCallbackBase>();
    serverReadCb_ = std::make_unique<ReadCallback>(serverWriteCb_.get());
    serverHandshakeCb_ =
        std::make_unique<HandshakeCallback>(serverReadCb_.get());
    serverAcceptCb_ =
        std::make_unique<SSLServerAcceptCallback>(serverHandshakeCb_.get());
    server_ = std::make_shared<TestSSLServer>(serverAcceptCb_.get());
  }

  /**
   * Static utility functions.
   */

  // Creates a NiceMock observer for `socket`, with byte events enabled in its
  // LifecycleObserver config iff `enableByteEvents` is true.
  static std::shared_ptr<NiceMock<TestObserver>> attachObserver(
      AsyncSocket* socket, bool enableByteEvents) {
    AsyncTransport::LifecycleObserver::Config config = {};
    config.byteEvents = enableByteEvents;
    return std::make_shared<NiceMock<TestObserver>>(socket, config);
  }

  /**
   * Decodes the timestamping WriteFlags requested by a sendmsg() call.
   *
   * Only the first control message is inspected. Returns WriteFlags::NONE
   * unless it is a SOL_SOCKET/SO_TIMESTAMPING message carrying exactly one
   * uint32_t; otherwise maps the SOF_TIMESTAMPING_TX_{SCHED,SOFTWARE,ACK} bits
   * to TIMESTAMP_{SCHED,TX,ACK}. TIMESTAMP_WRITE is never produced.
   */
  static WriteFlags getMsgWriteFlags(const struct msghdr& msg) {
    const struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
    if (!cmsg || cmsg->cmsg_level != SOL_SOCKET ||
        cmsg->cmsg_type != SO_TIMESTAMPING ||
        cmsg->cmsg_len != CMSG_LEN(sizeof(uint32_t))) {
      return WriteFlags::NONE;
    }

    const uint32_t* sofFlags =
        (reinterpret_cast<const uint32_t*>(CMSG_DATA(cmsg)));
    WriteFlags flags = WriteFlags::NONE;
    if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SCHED) {
      flags = flags | WriteFlags::TIMESTAMP_SCHED;
    }
    if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE) {
      flags = flags | WriteFlags::TIMESTAMP_TX;
    }
    if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_ACK) {
      flags = flags | WriteFlags::TIMESTAMP_ACK;
    }

    return flags;
  }

  // Clears TIMESTAMP_WRITE, which getMsgWriteFlags() can never report, so the
  // result can be compared with what is decoded from sendmsg().
  static WriteFlags dropWriteFromFlags(WriteFlags writeFlags) {
    return writeFlags & ~WriteFlags::TIMESTAMP_WRITE;
  }

  // server components
  std::unique_ptr<WriteCallbackBase> serverWriteCb_;
  std::unique_ptr<ReadCallback> serverReadCb_;
  std::unique_ptr<HandshakeCallback> serverHandshakeCb_;
  std::unique_ptr<SSLServerAcceptCallback> serverAcceptCb_;
  std::shared_ptr<TestSSLServer> server_;
};
3332 
3333 TEST_F(AsyncSSLSocketByteEventTest, ObserverAttachedBeforeConnect) {
3334   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
3335       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
3336   const std::vector<uint8_t> wbuf(1, 'a');
3337 
3338   auto clientConn = getClientConn();
3339   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
3340   clientConn.netOpsExpectTimestampingSetSockOpt();
3341   clientConn.connect();
3342   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
3343   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
3344   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
3345   clientConn.netOpsVerifyAndClearExpectations();
3346 
3347   {
3348     clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3349     clientConn.writeAndReflect(wbuf, flags);
3350     clientConn.netOpsVerifyAndClearExpectations();
3351 
3352     // may have more than four new ByteEvents if write split further by kernel
3353     EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
3354 
3355     // due to SSL overhead, offset will not be 0
3356     auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3357     EXPECT_EQ(
3358         offsetExpected,
3359         observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3360     EXPECT_EQ(
3361         offsetExpected,
3362         observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3363     EXPECT_EQ(
3364         offsetExpected,
3365         observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3366     EXPECT_EQ(
3367         offsetExpected,
3368         observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3369   }
3370 
3371   // write again to check offsets
3372   {
3373     const auto startNumByteEvents = observer->byteEvents.size();
3374     clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3375     clientConn.writeAndReflect(wbuf, flags);
3376     clientConn.netOpsVerifyAndClearExpectations();
3377 
3378     // may have more than four new ByteEvents if write split further by kernel
3379     EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));
3380 
3381     // due to SSL overhead, offset will not be 1
3382     auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3383     EXPECT_EQ(
3384         offsetExpected,
3385         observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3386     EXPECT_EQ(
3387         offsetExpected,
3388         observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3389     EXPECT_EQ(
3390         offsetExpected,
3391         observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3392     EXPECT_EQ(
3393         offsetExpected,
3394         observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3395   }
3396 }
3397 
3398 TEST_F(AsyncSSLSocketByteEventTest, ObserverAttachedAfterConnect) {
3399   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
3400       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
3401   const std::vector<uint8_t> wbuf(1, 'a');
3402 
3403   auto clientConn = getClientConn();
3404   clientConn.netOpsExpectNoTimestampingSetSockOpt();
3405   clientConn.connect();
3406   clientConn.netOpsVerifyAndClearExpectations();
3407 
3408   clientConn.netOpsExpectTimestampingSetSockOpt();
3409   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
3410   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
3411   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
3412   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
3413   clientConn.netOpsVerifyAndClearExpectations();
3414 
3415   {
3416     clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3417     clientConn.writeAndReflect(wbuf, flags);
3418     clientConn.netOpsVerifyAndClearExpectations();
3419 
3420     // may have more than four new ByteEvents if write split further by kernel
3421     EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
3422 
3423     // due to SSL overhead, offset will not be 0
3424     auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3425     EXPECT_EQ(
3426         offsetExpected,
3427         observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3428     EXPECT_EQ(
3429         offsetExpected,
3430         observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3431     EXPECT_EQ(
3432         offsetExpected,
3433         observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3434     EXPECT_EQ(
3435         offsetExpected,
3436         observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3437   }
3438 
3439   // write again to check offsets
3440   {
3441     const auto startNumByteEvents = observer->byteEvents.size();
3442     clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3443     clientConn.writeAndReflect(wbuf, flags);
3444     clientConn.netOpsVerifyAndClearExpectations();
3445 
3446     // may have more than four new ByteEvents if write split further by kernel
3447     EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));
3448 
3449     // due to SSL overhead, offset will not be 1
3450     auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3451     EXPECT_EQ(
3452         offsetExpected,
3453         observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3454     EXPECT_EQ(
3455         offsetExpected,
3456         observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3457     EXPECT_EQ(
3458         offsetExpected,
3459         observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3460     EXPECT_EQ(
3461         offsetExpected,
3462         observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3463   }
3464 }
3465 
3466 TEST_F(AsyncSSLSocketByteEventTest, MultiByteWrites) {
3467   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
3468       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
3469 
3470   auto clientConn = getClientConn();
3471   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
3472   clientConn.netOpsExpectTimestampingSetSockOpt();
3473   clientConn.connect();
3474   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
3475   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
3476   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
3477   clientConn.netOpsVerifyAndClearExpectations();
3478 
3479   // write 20 bytes
3480   {
3481     std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes
3482 
3483     clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3484     clientConn.writeAndReflect(wbuf, flags);
3485     clientConn.netOpsVerifyAndClearExpectations();
3486 
3487     // may have more than four new ByteEvents if write split further by kernel
3488     EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
3489 
3490     // due to SSL overhead, offset will not be 0
3491     auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3492     EXPECT_EQ(
3493         offsetExpected,
3494         observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3495     EXPECT_EQ(
3496         offsetExpected,
3497         observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3498     EXPECT_EQ(
3499         offsetExpected,
3500         observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3501     EXPECT_EQ(
3502         offsetExpected,
3503         observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3504   }
3505 
3506   // write 40 bytes
3507   {
3508     std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes
3509 
3510     const auto startNumByteEvents = observer->byteEvents.size();
3511     clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3512     clientConn.writeAndReflect(wbuf, flags);
3513     clientConn.netOpsVerifyAndClearExpectations();
3514 
3515     // may have more than four new ByteEvents if write split further by kernel
3516     EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));
3517 
3518     // due to SSL overhead, offset will not be 1
3519     auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3520     EXPECT_EQ(
3521         offsetExpected,
3522         observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3523     EXPECT_EQ(
3524         offsetExpected,
3525         observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3526     EXPECT_EQ(
3527         offsetExpected,
3528         observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3529     EXPECT_EQ(
3530         offsetExpected,
3531         observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3532   }
3533 }
3534 
3535 TEST_F(AsyncSSLSocketByteEventTest, MultiByteWritesEnableSecondWrite) {
3536   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
3537       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
3538 
3539   auto clientConn = getClientConn();
3540   clientConn.netOpsExpectNoTimestampingSetSockOpt();
3541   clientConn.connect();
3542   clientConn.netOpsVerifyAndClearExpectations();
3543 
3544   // write 20 bytes with no ByteEvents / observer
3545   {
3546     std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes
3547     clientConn.netOpsExpectWriteWithFlags(WriteFlags::NONE);
3548     clientConn.writeAndReflect(wbuf, flags);
3549     clientConn.netOpsVerifyAndClearExpectations();
3550   }
3551 
3552   // enable observer
3553   clientConn.netOpsExpectTimestampingSetSockOpt();
3554   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
3555   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
3556   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
3557   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
3558   clientConn.netOpsVerifyAndClearExpectations();
3559 
3560   // write 40 bytes
3561   {
3562     std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes
3563 
3564     const auto startNumByteEvents = observer->byteEvents.size();
3565     clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3566     clientConn.writeAndReflect(wbuf, flags);
3567     clientConn.netOpsVerifyAndClearExpectations();
3568 
3569     // may have more than four new ByteEvents if write split further by kernel
3570     EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));
3571 
3572     // due to SSL overhead, offset will not be 1
3573     auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3574     EXPECT_EQ(
3575         offsetExpected,
3576         observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3577     EXPECT_EQ(
3578         offsetExpected,
3579         observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3580     EXPECT_EQ(
3581         offsetExpected,
3582         observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3583     EXPECT_EQ(
3584         offsetExpected,
3585         observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3586   }
3587 }
3588 
3589 #endif // FOLLY_HAVE_SO_TIMESTAMPING
3590 
3591 #endif // __linux__
3592 
3593 TEST(AsyncSSLSocketTest, TestSNIClientHelloBehavior) {
3594   EventBase eventBase;
3595   auto serverCtx = std::make_shared<SSLContext>();
3596   auto clientCtx = std::make_shared<SSLContext>();
3597   serverCtx->loadPrivateKey(kTestKey);
3598   serverCtx->loadCertificate(kTestCert);
3599 
3600   clientCtx->setSessionCacheContext("test context");
3601   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
3602   std::shared_ptr<folly::ssl::SSLSession> resumptionSession = nullptr;
3603 
3604   {
3605     std::array<NetworkSocket, 2> fds;
3606     getfds(fds.data());
3607 
3608     AsyncSSLSocket::UniquePtr clientSock(
3609         new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
3610     AsyncSSLSocket::UniquePtr serverSock(
3611         new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
3612 
3613     // Client sends SNI that doesn't match anything the server cert advertises
3614     clientSock->setServerName("Foobar");
3615 
3616     SSLHandshakeServerParseClientHello server(
3617         std::move(serverSock), true, true);
3618     SSLHandshakeClient client(std::move(clientSock), true, true);
3619     eventBase.loop();
3620 
3621     serverSock = std::move(server).moveSocket();
3622     auto chi = serverSock->getClientHelloInfo();
3623     ASSERT_NE(chi, nullptr);
3624     EXPECT_EQ(
3625         std::string("Foobar"), std::string(serverSock->getSSLServerName()));
3626 
3627     // create another client, resuming with the prior session, but under a
3628     // different common name.
3629     clientSock = std::move(client).moveSocket();
3630     resumptionSession = clientSock->getSSLSession();
3631   }
3632 
3633   {
3634     std::array<NetworkSocket, 2> fds;
3635     getfds(fds.data());
3636 
3637     AsyncSSLSocket::UniquePtr clientSock(
3638         new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
3639     AsyncSSLSocket::UniquePtr serverSock(
3640         new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
3641 
3642     clientSock->setSSLSession(resumptionSession);
3643     clientSock->setServerName("Baz");
3644     SSLHandshakeServerParseClientHello server(
3645         std::move(serverSock), true, true);
3646     SSLHandshakeClient client(std::move(clientSock), true, true);
3647     eventBase.loop();
3648 
3649     serverSock = std::move(server).moveSocket();
3650     clientSock = std::move(client).moveSocket();
3651     EXPECT_TRUE(clientSock->getSSLSessionReused());
3652 
3653     // OpenSSL 1.1.1 changes the semantics of SSL_get_servername
3654     // in
3655     // https://github.com/openssl/openssl/commit/1c4aa31d79821dee9be98e915159d52cc30d8403
3656     //
3657     // Previously, the SNI would be taken from the ClientHello.
3658     // Now, the SNI will be taken from the established session.
3659     //
3660     // But the session that was established with the client (prior handshake)
3661     // would not have set the server name field because the SNI that the client
3662     // requested ("Foobar") did not match any of the SANs that the server was
3663     // presenting ("127.0.0.1")
3664     //
3665     // To preserve this 1.1.0 behavior, getSSLServerName() should return the
3666     // parsed ClientHello servername. This test asserts this behavior.
3667     auto sni = serverSock->getSSLServerName();
3668     ASSERT_NE(sni, nullptr);
3669 
3670     std::string sniStr(sni);
3671     EXPECT_EQ(sniStr, std::string("Baz"));
3672   }
3673 }
3674 
3675 TEST(AsyncSSLSocketTest, BytesWrittenWithMove) {
3676   WriteCallbackBase writeCallback;
3677   ReadCallback readCallback(&writeCallback);
3678   HandshakeCallback handshakeCallback(&readCallback);
3679   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
3680   TestSSLServer server(&acceptCallback);
3681 
3682   auto sslContext = std::make_shared<SSLContext>();
3683   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
3684   auto socket1 =
3685       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
3686   socket1->open(std::chrono::milliseconds(10000));
3687 
3688   // write
3689   std::vector<uint8_t> wbuf(128, 'a');
3690   socket1->write(wbuf.data(), wbuf.size());
3691   const auto socket1AppBytes = socket1->getSocket()->getAppBytesWritten();
3692   const auto socket1RawBytes = socket1->getSocket()->getRawBytesWritten();
3693   EXPECT_EQ(128, socket1AppBytes);
3694   EXPECT_LT(128, socket1RawBytes);
3695 
3696   // read reflection
3697   std::vector<uint8_t> readbuf(wbuf.size());
3698   uint32_t bytesRead = socket1->readAll(readbuf.data(), readbuf.size());
3699   EXPECT_EQ(bytesRead, wbuf.size());
3700 
3701   // additional sanity checks on virtuals
3702   EXPECT_EQ(
3703       socket1->getSSLSocket()->getRawBytesWritten(),
3704       socket1->getSocket()->getRawBytesWritten());
3705   EXPECT_EQ(128, socket1->getSocket()->getAppBytesWritten());
3706   EXPECT_EQ(128, socket1->getSSLSocket()->getAppBytesWritten());
3707 
3708   // move to another AsyncSSLSocket
3709   AsyncSSLSocket::UniquePtr socket2(
3710       new AsyncSSLSocket(sslContext, socket1->getSocket()));
3711   EXPECT_EQ(socket1AppBytes, socket2->getAppBytesWritten());
3712   EXPECT_EQ(socket1RawBytes, socket2->getRawBytesWritten());
3713 
3714   // move to an AsyncSocket
3715   AsyncSocket::UniquePtr socket3(new AsyncSocket(std::move(socket2)));
3716   EXPECT_EQ(socket1AppBytes, socket3->getAppBytesWritten());
3717   EXPECT_EQ(socket1RawBytes, socket3->getRawBytesWritten());
3718 }
3719 
3720 } // namespace folly
3721 
#ifdef SIGPIPE
///////////////////////////////////////////////////////////////////////////
// init_unit_test_suite
///////////////////////////////////////////////////////////////////////////
namespace {
// During static initialization, sets SIGPIPE to SIG_IGN for the whole test
// binary. Writing to a socket whose peer has closed then fails with EPIPE
// instead of terminating the process.
struct SigpipeIgnorer {
  SigpipeIgnorer() {
    signal(SIGPIPE, SIG_IGN);
  }
};
const SigpipeIgnorer sigpipeIgnorer;
} // namespace
#endif
3733