1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <folly/io/async/test/AsyncSSLSocketTest.h>
18
19 #include <fcntl.h>
20 #include <signal.h>
21 #include <sys/types.h>
22
#include <cerrno>
#include <fstream>
#include <iostream>
#include <list>
#include <set>
#include <thread>
28
29 #include <folly/SocketAddress.h>
30 #include <folly/String.h>
31 #include <folly/io/Cursor.h>
32 #include <folly/io/async/AsyncPipe.h>
33 #include <folly/io/async/AsyncSSLSocket.h>
34 #include <folly/io/async/EventBase.h>
35 #include <folly/io/async/EventBaseThread.h>
36 #include <folly/io/async/ScopedEventBaseThread.h>
37 #include <folly/io/async/ssl/OpenSSLTransportCertificate.h>
38 #include <folly/io/async/test/BlockingSocket.h>
39 #include <folly/io/async/test/MockAsyncTransportObserver.h>
40 #include <folly/io/async/test/TFOTest.h>
41 #include <folly/net/NetOps.h>
42 #include <folly/net/NetworkSocket.h>
43 #include <folly/net/test/MockNetOpsDispatcher.h>
44 #include <folly/portability/GMock.h>
45 #include <folly/portability/GTest.h>
46 #include <folly/portability/OpenSSL.h>
47 #include <folly/portability/Unistd.h>
48 #include <folly/ssl/Init.h>
49
50 #ifdef __linux__
51 #include <dlfcn.h>
52 #endif
53
54 #if FOLLY_OPENSSL_IS_110
55 #include <openssl/async.h>
56 #endif
57
58 using std::cerr;
59 using std::endl;
60 using std::string;
61
62 using namespace testing;
63
64 #if defined __linux__
65 namespace {
66
// Signature of setsockopt(); used to hold a pointer to libc's original
// implementation once it has been looked up.
using setsockopt_ptr = int (*)(int, int, int, const void*, socklen_t);
setsockopt_ptr real_setsockopt_ = nullptr;
70
71 // global struct to initialize before main runs. we can init within a test,
72 // or in main, but this method seems to be least intrsive and universal
73 struct GlobalStatic {
GlobalStatic__anon28101f9d0111::GlobalStatic74 GlobalStatic() {
75 real_setsockopt_ = (setsockopt_ptr)dlsym(RTLD_NEXT, "setsockopt");
76 }
reset__anon28101f9d0111::GlobalStatic77 void reset() noexcept { ttlsDisabledSet.clear(); }
78 // for each fd, tracks whether TTLS is disabled or not
79 std::unordered_set<folly::NetworkSocket /* fd */> ttlsDisabledSet;
80 };
81
82 // the constructor will be called before main() which is all we care about
83 GlobalStatic globalStatic;
84
85 } // namespace
86
87 // we intercept setsoctopt to test setting NO_TRANSPARENT_TLS opt
88 // this name has to be global
setsockopt(int sockfd,int level,int optname,const void * optval,socklen_t optlen)89 int setsockopt(
90 int sockfd, int level, int optname, const void* optval, socklen_t optlen) {
91 if (optname == SO_NO_TRANSPARENT_TLS) {
92 globalStatic.ttlsDisabledSet.insert(folly::NetworkSocket::fromFd(sockfd));
93 return 0;
94 }
95 return real_setsockopt_(sockfd, level, optname, optval, optlen);
96 }
97 #endif
98
99 namespace folly {
// Namespace-scope definitions of SSLClient's static constexpr data members
// (presumably declared with their initializers in AsyncSSLSocketTest.h, which
// is not visible here). Before C++17 such definitions are required whenever
// the members are odr-used; from C++17 on they are redundant but harmless.
constexpr size_t SSLClient::kMaxReadBufferSz;
constexpr size_t SSLClient::kMaxReadsPerEvent;
102
getfds(NetworkSocket fds[2])103 void getfds(NetworkSocket fds[2]) {
104 if (netops::socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
105 FAIL() << "failed to create socketpair: " << errnoStr(errno);
106 }
107 for (int idx = 0; idx < 2; ++idx) {
108 if (netops::set_socket_non_blocking(fds[idx]) != 0) {
109 FAIL() << "failed to put socket " << idx
110 << " in non-blocking mode: " << errnoStr(errno);
111 }
112 }
113 }
114
getctx(std::shared_ptr<folly::SSLContext> clientCtx,std::shared_ptr<folly::SSLContext> serverCtx)115 void getctx(
116 std::shared_ptr<folly::SSLContext> clientCtx,
117 std::shared_ptr<folly::SSLContext> serverCtx) {
118 clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
119
120 serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
121 serverCtx->loadCertificate(kTestCert);
122 serverCtx->loadPrivateKey(kTestKey);
123 }
124
sslsocketpair(EventBase * eventBase,AsyncSSLSocket::UniquePtr * clientSock,AsyncSSLSocket::UniquePtr * serverSock)125 void sslsocketpair(
126 EventBase* eventBase,
127 AsyncSSLSocket::UniquePtr* clientSock,
128 AsyncSSLSocket::UniquePtr* serverSock) {
129 auto clientCtx = std::make_shared<folly::SSLContext>();
130 auto serverCtx = std::make_shared<folly::SSLContext>();
131 NetworkSocket fds[2];
132 getfds(fds);
133 getctx(clientCtx, serverCtx);
134 clientSock->reset(new AsyncSSLSocket(clientCtx, eventBase, fds[0], false));
135 serverSock->reset(new AsyncSSLSocket(serverCtx, eventBase, fds[1], true));
136
137 // (*clientSock)->setSendTimeout(100);
138 // (*serverSock)->setSendTimeout(100);
139 }
140
141 // client protocol filters
/**
 * Protocol filter that always selects "ponies".
 *
 * *client is pointed at a static, length-prefixed protocol entry (one length
 * byte followed by the name; the length byte does not count itself) and
 * *client_len is set to the size of that whole entry. The server's protocol
 * list is ignored. Always returns true.
 */
bool clientProtoFilterPickPony(
    unsigned char** client,
    unsigned int* client_len,
    const unsigned char*,
    unsigned int) {
  static unsigned char ponies[] = {6, 'p', 'o', 'n', 'i', 'e', 's'};
  *client_len = static_cast<unsigned int>(sizeof(ponies));
  *client = ponies;
  return true;
}
154
/**
 * Protocol filter that never selects anything: all arguments are ignored,
 * the output parameters are left untouched, and false is returned.
 */
bool clientProtoFilterPickNone(
    unsigned char** /* client */,
    unsigned int* /* client_len */,
    const unsigned char* /* server */,
    unsigned int /* server_len */) {
  constexpr bool kPicked = false;
  return kPicked;
}
159
getFileAsBuf(const char * fileName)160 std::string getFileAsBuf(const char* fileName) {
161 std::string buffer;
162 folly::readFile(fileName, buffer);
163 return buffer;
164 }
165
166 /**
167 * Test connecting to, writing to, reading from, and closing the
168 * connection to the SSL server.
169 */
TEST(AsyncSSLSocketTest,ConnectWriteReadClose)170 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
171 // Start listening on a local port
172 WriteCallbackBase writeCallback;
173 ReadCallback readCallback(&writeCallback);
174 HandshakeCallback handshakeCallback(&readCallback);
175 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
176 TestSSLServer server(&acceptCallback);
177
178 // Set up SSL context.
179 std::shared_ptr<SSLContext> sslContext(new SSLContext());
180 sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
181 // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
182 // sslContext->authenticate(true, false);
183
184 // connect
185 auto socket =
186 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
187 socket->open(std::chrono::milliseconds(10000));
188
189 // write()
190 uint8_t buf[128];
191 memset(buf, 'a', sizeof(buf));
192 socket->write(buf, sizeof(buf));
193
194 // read()
195 uint8_t readbuf[128];
196 uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
197 EXPECT_EQ(bytesRead, 128);
198 EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
199
200 // close()
201 socket->close();
202
203 cerr << "ConnectWriteReadClose test completed" << endl;
204 EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
205 }
206
/**
 * Same flow as ConnectWriteReadClose, but additionally exercises
 * AsyncSSLSocket::readable() after a partial read.
 */
TEST(AsyncSSLSocketTest, ConnectWriteReadCloseReadable) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL context. No trusted certificates are loaded and peer
  // authentication is not switched on.
  auto sslContext = std::make_shared<SSLContext>();
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->open(std::chrono::milliseconds(10000));

  // write()
  uint8_t buf[128];
  memset(buf, 'a', sizeof(buf));
  socket->write(buf, sizeof(buf));

  // read()
  uint8_t readbuf[128];
  // The TLS record includes the full 128 bytes. Even though we only read 1
  // byte out of the socket, the rest of the record is decrypted and buffered
  // in the underlying SSL session.
  uint32_t bytesRead = socket->read(readbuf, 1);
  EXPECT_EQ(bytesRead, 1);
  // The socket has no data pending in the kernel (base-class check) ...
  EXPECT_FALSE(socket->getSocket()->AsyncSocket::readable());
  // ... but the socket still reports itself as readable.
  EXPECT_TRUE(socket->getSocket()->readable());
  bytesRead += socket->readAll(readbuf + 1, sizeof(readbuf) - 1);
  EXPECT_EQ(bytesRead, 128);
  EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);

  // close()
  socket->close();

  // Report this test's own name; the message used to be a copy of the
  // ConnectWriteReadClose one, which made the log output misleading.
  cerr << "ConnectWriteReadCloseReadable test completed" << endl;
  EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
}
254
255 /**
256 * Check that zero copy options are a noop under AsyncSSLSocket since they
257 * aren't supported.
258 */
TEST(AsyncSSLSocketTest,ZeroCopy)259 TEST(AsyncSSLSocketTest, ZeroCopy) {
260 // Set up SSL context.
261 std::shared_ptr<SSLContext> sslContext(new SSLContext());
262 sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
263
264 auto socket = AsyncSSLSocket::newSocket(sslContext, /*evb=*/nullptr);
265 EXPECT_FALSE(socket->setZeroCopy(true));
266 EXPECT_FALSE(socket->getZeroCopy());
267 }
268
269 /**
270 * Same as above simple test, but with a large read len to test
271 * clamping behavior.
272 */
TEST(AsyncSSLSocketTest,ConnectWriteReadLargeClose)273 TEST(AsyncSSLSocketTest, ConnectWriteReadLargeClose) {
274 // Start listening on a local port
275 WriteCallbackBase writeCallback;
276 ReadCallback readCallback(&writeCallback);
277 HandshakeCallback handshakeCallback(&readCallback);
278 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
279 TestSSLServer server(&acceptCallback);
280
281 // Set up SSL context.
282 std::shared_ptr<SSLContext> sslContext(new SSLContext());
283 sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
284 // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
285 // sslContext->authenticate(true, false);
286
287 // connect
288 auto socket =
289 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
290 socket->open(std::chrono::milliseconds(10000));
291
292 // write()
293 uint8_t buf[128];
294 memset(buf, 'a', sizeof(buf));
295 socket->write(buf, sizeof(buf));
296
297 // read()
298 uint8_t readbuf[128];
299 // we will fake the read len but that should be fine
300 size_t readLen = 1L << 33;
301 uint32_t bytesRead = socket->read(readbuf, readLen);
302 EXPECT_EQ(bytesRead, 128);
303 EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
304
305 // close()
306 socket->close();
307
308 cerr << "ConnectWriteReadClose test completed" << endl;
309 EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
310 }
311
312 /**
313 * Test reading after server close.
314 */
TEST(AsyncSSLSocketTest,ReadAfterClose)315 TEST(AsyncSSLSocketTest, ReadAfterClose) {
316 // Start listening on a local port
317 WriteCallbackBase writeCallback;
318 ReadEOFCallback readCallback(&writeCallback);
319 HandshakeCallback handshakeCallback(&readCallback);
320 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
321 auto server = std::make_unique<TestSSLServer>(&acceptCallback);
322
323 // Set up SSL context.
324 auto sslContext = std::make_shared<SSLContext>();
325 sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
326
327 auto socket =
328 std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
329 socket->open();
330
331 // This should trigger an EOF on the client.
332 auto evb = handshakeCallback.getSocket()->getEventBase();
333 evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
334 std::array<uint8_t, 128> readbuf;
335 auto bytesRead = socket->read(readbuf.data(), readbuf.size());
336 EXPECT_EQ(0, bytesRead);
337 }
338
/**
 * Test bad renegotiation: after a successful handshake the client requests a
 * renegotiation through OpenSSL directly, and the server side is expected to
 * flag it (server.renegotiationError_).
 *
 * Not built against BoringSSL.
 */
#if !defined(OPENSSL_IS_BORINGSSL)
TEST(AsyncSSLSocketTest, Renegotiate) {
  // Client and server sockets share one EventBase and talk over a local
  // socket pair.
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();
  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());
  getctx(clientCtx, dfServerCtx);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
  // NOTE(review): the meaning of the two boolean arguments is defined by
  // SSLHandshakeClient, which is not visible here.
  SSLHandshakeClient client(std::move(clientSock), true, true);
  RenegotiatingServer server(std::move(serverSock));

  // Drive the loop until the initial handshake has either succeeded or
  // failed.
  while (!client.handshakeSuccess_ && !client.handshakeError_) {
    eventBase.loopOnce();
  }

  ASSERT_TRUE(client.handshakeSuccess_);

  // Take the socket back from the client helper and detach it from
  // eventBase before handing it to a BlockingSocket below.
  auto sslSock = std::move(client).moveSocket();
  sslSock->detachEventBase();
  // This is nasty, however we don't want to add support for
  // renegotiation in AsyncSSLSocket.
  // The const_cast is needed because getSSL() hands out a const SSL*.
  SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));

  auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));

  // Keep eventBase (and with it the server socket) running on a separate
  // thread while this thread does the blocking write.
  std::thread t([&]() { eventBase.loopForever(); });

  // Trigger the renegotiation.
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());
  try {
    socket->write(buf.data(), buf.size());
  } catch (AsyncSocketException& e) {
    // A client-side failure is tolerated; only the server's view is asserted.
    LOG(INFO) << "client got error " << e.what();
  }
  // Stop the background loop and wait for its thread before touching
  // eventBase from this thread again.
  eventBase.terminateLoopSoon();
  t.join();

  // Run the loop once more on this thread before checking the result.
  eventBase.loop();
  ASSERT_TRUE(server.renegotiationError_);
}
#endif
389
390 /**
391 * Negative test for handshakeError().
392 */
TEST(AsyncSSLSocketTest,HandshakeError)393 TEST(AsyncSSLSocketTest, HandshakeError) {
394 // Start listening on a local port
395 WriteCallbackBase writeCallback;
396 WriteErrorCallback readCallback(&writeCallback);
397 HandshakeCallback handshakeCallback(&readCallback);
398 HandshakeErrorCallback acceptCallback(&handshakeCallback);
399 TestSSLServer server(&acceptCallback);
400
401 // Set up SSL context.
402 std::shared_ptr<SSLContext> sslContext(new SSLContext());
403 sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
404
405 // connect
406 auto socket =
407 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
408 // read()
409 bool ex = false;
410 try {
411 socket->open();
412
413 uint8_t readbuf[128];
414 uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
415 LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
416 } catch (AsyncSocketException&) {
417 ex = true;
418 }
419 EXPECT_TRUE(ex);
420
421 // close()
422 socket->close();
423 cerr << "HandshakeError test completed" << endl;
424 }
425
426 /**
427 * Negative test for readError().
428 */
TEST(AsyncSSLSocketTest,ReadError)429 TEST(AsyncSSLSocketTest, ReadError) {
430 // Start listening on a local port
431 WriteCallbackBase writeCallback;
432 ReadErrorCallback readCallback(&writeCallback);
433 HandshakeCallback handshakeCallback(&readCallback);
434 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
435 TestSSLServer server(&acceptCallback);
436
437 // Set up SSL context.
438 std::shared_ptr<SSLContext> sslContext(new SSLContext());
439 sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
440
441 // connect
442 auto socket =
443 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
444 socket->open();
445
446 // write something to trigger ssl handshake
447 uint8_t buf[128];
448 memset(buf, 'a', sizeof(buf));
449 socket->write(buf, sizeof(buf));
450
451 socket->close();
452 cerr << "ReadError test completed" << endl;
453 }
454
455 /**
456 * Negative test for writeError().
457 */
TEST(AsyncSSLSocketTest,WriteError)458 TEST(AsyncSSLSocketTest, WriteError) {
459 // Start listening on a local port
460 WriteCallbackBase writeCallback;
461 WriteErrorCallback readCallback(&writeCallback);
462 HandshakeCallback handshakeCallback(&readCallback);
463 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
464 TestSSLServer server(&acceptCallback);
465
466 // Set up SSL context.
467 std::shared_ptr<SSLContext> sslContext(new SSLContext());
468 sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
469
470 // connect
471 auto socket =
472 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
473 socket->open();
474
475 // write something to trigger ssl handshake
476 uint8_t buf[128];
477 memset(buf, 'a', sizeof(buf));
478 socket->write(buf, sizeof(buf));
479
480 socket->close();
481 cerr << "WriteError test completed" << endl;
482 }
483
484 /**
485 * Test a socket with TCP_NODELAY unset.
486 */
TEST(AsyncSSLSocketTest,SocketWithDelay)487 TEST(AsyncSSLSocketTest, SocketWithDelay) {
488 // Start listening on a local port
489 WriteCallbackBase writeCallback;
490 ReadCallback readCallback(&writeCallback);
491 HandshakeCallback handshakeCallback(&readCallback);
492 SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
493 TestSSLServer server(&acceptCallback);
494
495 // Set up SSL context.
496 std::shared_ptr<SSLContext> sslContext(new SSLContext());
497 sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
498
499 // connect
500 auto socket =
501 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
502 socket->open();
503
504 // write()
505 uint8_t buf[128];
506 memset(buf, 'a', sizeof(buf));
507 socket->write(buf, sizeof(buf));
508
509 // read()
510 uint8_t readbuf[128];
511 uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
512 EXPECT_EQ(bytesRead, 128);
513 EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
514
515 // close()
516 socket->close();
517
518 cerr << "SocketWithDelay test completed" << endl;
519 }
520
521 #if FOLLY_OPENSSL_HAS_ALPN
522 class NextProtocolTest : public Test {
523 // For matching protos
524 public:
SetUp()525 void SetUp() override { getctx(clientCtx, serverCtx); }
526
connect(bool unset=false)527 void connect(bool unset = false) {
528 getfds(fds);
529
530 if (unset) {
531 // unsetting NPN for any of [client, server] is enough to make NPN not
532 // work
533 clientCtx->unsetNextProtocols();
534 }
535
536 AsyncSSLSocket::UniquePtr clientSock(
537 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
538 AsyncSSLSocket::UniquePtr serverSock(
539 new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
540 client = std::make_unique<AlpnClient>(std::move(clientSock));
541 server = std::make_unique<AlpnServer>(std::move(serverSock));
542
543 eventBase.loop();
544 }
545
expectProtocol(const std::string & proto)546 void expectProtocol(const std::string& proto) {
547 expectHandshakeSuccess();
548 EXPECT_NE(client->nextProtoLength, 0);
549 EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
550 EXPECT_EQ(
551 memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
552 0);
553 string selected((const char*)client->nextProto, client->nextProtoLength);
554 EXPECT_EQ(proto, selected);
555 }
556
expectNoProtocol()557 void expectNoProtocol() {
558 expectHandshakeSuccess();
559 EXPECT_EQ(client->nextProtoLength, 0);
560 EXPECT_EQ(server->nextProtoLength, 0);
561 EXPECT_EQ(client->nextProto, nullptr);
562 EXPECT_EQ(server->nextProto, nullptr);
563 }
564
expectHandshakeSuccess()565 void expectHandshakeSuccess() {
566 EXPECT_FALSE(client->except.has_value())
567 << "client handshake error: " << client->except->what();
568 EXPECT_FALSE(server->except.has_value())
569 << "server handshake error: " << server->except->what();
570 }
571
expectHandshakeError()572 void expectHandshakeError() {
573 EXPECT_TRUE(client->except.has_value())
574 << "Expected client handshake error!";
575 EXPECT_TRUE(server->except.has_value())
576 << "Expected server handshake error!";
577 }
578
579 EventBase eventBase;
580 std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
581 std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
582 NetworkSocket fds[2];
583 std::unique_ptr<AlpnClient> client;
584 std::unique_ptr<AlpnServer> server;
585 };
586
// Client and server share exactly one protocol ("baz"); it must be selected.
TEST_F(NextProtocolTest, AlpnTestOverlap) {
  const std::list<std::string> clientProtocols{"blub", "baz"};
  const std::list<std::string> serverProtocols{"foo", "bar", "baz"};
  clientCtx->setAdvertisedNextProtocols(clientProtocols);
  serverCtx->setAdvertisedNextProtocols(serverProtocols);

  connect();

  expectProtocol("baz");
}
595
// Same protocol lists as AlpnTestOverlap, but the client's protocols are
// unset before the event loop runs, so no protocol may be negotiated.
TEST_F(NextProtocolTest, AlpnTestUnset) {
  const std::list<std::string> clientProtocols{"blub", "baz"};
  const std::list<std::string> serverProtocols{"foo", "bar", "baz"};
  clientCtx->setAdvertisedNextProtocols(clientProtocols);
  serverCtx->setAdvertisedNextProtocols(serverProtocols);

  constexpr bool kUnsetClientProtocols = true;
  connect(kUnsetClientProtocols);

  expectNoProtocol();
}
606
// The protocol lists have nothing in common: the handshake still succeeds,
// but without a negotiated protocol.
TEST_F(NextProtocolTest, AlpnTestNoOverlap) {
  const std::list<std::string> clientProtocols{"blub"};
  const std::list<std::string> serverProtocols{"foo", "bar", "baz"};
  clientCtx->setAdvertisedNextProtocols(clientProtocols);
  serverCtx->setAdvertisedNextProtocols(serverProtocols);

  connect();

  expectNoProtocol();
}
614
// The server advertises two equally weighted randomized protocol lists. Over
// many connections both "foo" and "bar" must show up as the selected
// protocol.
TEST_F(NextProtocolTest, RandomizedAlpnTest) {
  // With two equally weighted choices, all tries landing on the same one has
  // probability 2 * 2^-64 = 2^-63, which can be considered negligible.
  constexpr int kTries = 64;

  clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
  serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}});

  std::set<std::string> seenProtocols;
  for (int attempt = 0; attempt < kTries; ++attempt) {
    connect();

    // Both sides must agree on a non-empty protocol for every connection.
    EXPECT_NE(client->nextProtoLength, 0);
    EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
    EXPECT_EQ(
        memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
        0);
    seenProtocols.insert(std::string(
        reinterpret_cast<const char*>(client->nextProto),
        client->nextProtoLength));
    expectHandshakeSuccess();
  }
  EXPECT_EQ(seenProtocols.size(), 2);
}
638
// Mismatches are disallowed on the server, but the client advertises nothing
// at all: the handshake succeeds without a protocol and the server sees an
// empty client ALPN list.
TEST_F(NextProtocolTest, AlpnNotAllowMismatchNoClientProtocol) {
  const std::list<std::string> noProtocols;
  const std::list<std::string> serverProtocols{"foo", "bar", "baz"};
  clientCtx->setAdvertisedNextProtocols(noProtocols);
  serverCtx->setAdvertisedNextProtocols(serverProtocols);
  serverCtx->setAlpnAllowMismatch(false);

  connect();

  expectHandshakeSuccess();
  expectNoProtocol();
  const std::vector<std::string> expectedClientAlpns;
  EXPECT_EQ(server->getClientAlpns(), expectedClientAlpns);
}
650
// Mismatches are disallowed on the server and the lists overlap in "baz":
// "baz" is selected and the server sees the client's full ALPN list.
TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithOverlap) {
  const std::list<std::string> clientProtocols{"blub", "baz"};
  const std::list<std::string> serverProtocols{"foo", "bar", "baz"};
  clientCtx->setAdvertisedNextProtocols(clientProtocols);
  serverCtx->setAdvertisedNextProtocols(serverProtocols);
  serverCtx->setAlpnAllowMismatch(false);

  connect();

  expectProtocol("baz");
  const std::vector<std::string> expectedClientAlpns{"blub", "baz"};
  EXPECT_EQ(server->getClientAlpns(), expectedClientAlpns);
}
662
// Mismatches are disallowed on the server and the lists do not overlap: the
// handshake must fail on both sides, while the server still records what the
// client offered.
TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithoutOverlap) {
  const std::list<std::string> clientProtocols{"blub"};
  const std::list<std::string> serverProtocols{"foo", "bar", "baz"};
  clientCtx->setAdvertisedNextProtocols(clientProtocols);
  serverCtx->setAdvertisedNextProtocols(serverProtocols);
  serverCtx->setAlpnAllowMismatch(false);

  connect();

  expectHandshakeError();
  const std::vector<std::string> expectedClientAlpns{"blub"};
  EXPECT_EQ(server->getClientAlpns(), expectedClientAlpns);
}
673
674 #endif
675
676 #ifndef OPENSSL_NO_TLSEXT
677 /**
678 * 1. Client sends TLSEXT_HOSTNAME in client hello.
679 * 2. Server found a match SSL_CTX and use this SSL_CTX to
680 * continue the SSL handshake.
681 * 3. Server sends back TLSEXT_HOSTNAME in server hello.
682 */
TEST(AsyncSSLSocketTest,SNITestMatch)683 TEST(AsyncSSLSocketTest, SNITestMatch) {
684 EventBase eventBase;
685 std::shared_ptr<SSLContext> clientCtx(new SSLContext);
686 std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
687 // Use the same SSLContext to continue the handshake after
688 // tlsext_hostname match.
689 std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
690 const std::string serverName("xyz.newdev.facebook.com");
691 NetworkSocket fds[2];
692 getfds(fds);
693 getctx(clientCtx, dfServerCtx);
694
695 AsyncSSLSocket::UniquePtr clientSock(
696 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
697 AsyncSSLSocket::UniquePtr serverSock(
698 new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
699 SNIClient client(std::move(clientSock));
700 SNIServer server(
701 std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
702
703 eventBase.loop();
704
705 EXPECT_TRUE(client.serverNameMatch);
706 EXPECT_TRUE(server.serverNameMatch);
707 }
708
709 /**
710 * 1. Client sends TLSEXT_HOSTNAME in client hello.
711 * 2. Server cannot find a matching SSL_CTX and continue to use
712 * the current SSL_CTX to do the handshake.
713 * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
714 */
TEST(AsyncSSLSocketTest,SNITestNotMatch)715 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
716 EventBase eventBase;
717 std::shared_ptr<SSLContext> clientCtx(new SSLContext);
718 std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
719 // Use the same SSLContext to continue the handshake after
720 // tlsext_hostname match.
721 std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
722 const std::string clientRequestingServerName("foo.com");
723 const std::string serverExpectedServerName("xyz.newdev.facebook.com");
724
725 NetworkSocket fds[2];
726 getfds(fds);
727 getctx(clientCtx, dfServerCtx);
728
729 AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
730 clientCtx, &eventBase, fds[0], clientRequestingServerName));
731 AsyncSSLSocket::UniquePtr serverSock(
732 new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
733 SNIClient client(std::move(clientSock));
734 SNIServer server(
735 std::move(serverSock),
736 dfServerCtx,
737 hskServerCtx,
738 serverExpectedServerName);
739
740 eventBase.loop();
741
742 EXPECT_TRUE(!client.serverNameMatch);
743 EXPECT_TRUE(!server.serverNameMatch);
744 }
745 /**
746 * 1. Client sends TLSEXT_HOSTNAME in client hello.
747 * 2. We then change the serverName.
748 * 3. We expect that we get 'false' as the result for serNameMatch.
749 */
750
TEST(AsyncSSLSocketTest,SNITestChangeServerName)751 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
752 EventBase eventBase;
753 std::shared_ptr<SSLContext> clientCtx(new SSLContext);
754 std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
755 // Use the same SSLContext to continue the handshake after
756 // tlsext_hostname match.
757 std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
758 const std::string serverName("xyz.newdev.facebook.com");
759 NetworkSocket fds[2];
760 getfds(fds);
761 getctx(clientCtx, dfServerCtx);
762
763 AsyncSSLSocket::UniquePtr clientSock(
764 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
765 // Change the server name
766 std::string newName("new.com");
767 clientSock->setServerName(newName);
768 AsyncSSLSocket::UniquePtr serverSock(
769 new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
770 SNIClient client(std::move(clientSock));
771 SNIServer server(
772 std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
773
774 eventBase.loop();
775
776 EXPECT_TRUE(!client.serverNameMatch);
777 }
778
779 /**
780 * 1. Client does not send TLSEXT_HOSTNAME in client hello.
781 * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
782 */
TEST(AsyncSSLSocketTest,SNITestClientHelloNoHostname)783 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
784 EventBase eventBase;
785 std::shared_ptr<SSLContext> clientCtx(new SSLContext);
786 std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
787 // Use the same SSLContext to continue the handshake after
788 // tlsext_hostname match.
789 std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
790 const std::string serverExpectedServerName("xyz.newdev.facebook.com");
791
792 NetworkSocket fds[2];
793 getfds(fds);
794 getctx(clientCtx, dfServerCtx);
795
796 AsyncSSLSocket::UniquePtr clientSock(
797 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
798 AsyncSSLSocket::UniquePtr serverSock(
799 new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
800 SNIClient client(std::move(clientSock));
801 SNIServer server(
802 std::move(serverSock),
803 dfServerCtx,
804 hskServerCtx,
805 serverExpectedServerName);
806
807 eventBase.loop();
808
809 EXPECT_TRUE(!client.serverNameMatch);
810 EXPECT_TRUE(!server.serverNameMatch);
811 }
812
813 /**
814 * 1. Create an SSLContext that does not have an ALPN
815 * 2. Use AsyncSSLSocket::setSupportedApplicationProtocols on the client and
816 * server, and assert that a common ALPN was negotiated.
817 */
TEST(AsyncSSLSocketTest,SetSupportedApplicationProtocols)818 TEST(AsyncSSLSocketTest, SetSupportedApplicationProtocols) {
819 EventBase eventBase;
820 auto clientCtx = std::make_shared<SSLContext>();
821 auto dfServerCtx = std::make_shared<SSLContext>();
822 // Use the same SSLContext to continue the handshake after
823 // tlsext_hostname match.
824 auto hskServerCtx = std::make_shared<SSLContext>();
825 const std::string serverExpectedServerName("xyz.newdev.facebook.com");
826
827 NetworkSocket fds[2];
828 getfds(fds);
829 getctx(clientCtx, dfServerCtx);
830
831 AsyncSSLSocket::UniquePtr clientSock(
832 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
833 AsyncSSLSocket::UniquePtr serverSock(
834 new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
835
836 std::vector<std::string> protocols;
837 protocols.push_back("rs");
838
839 clientSock->setSupportedApplicationProtocols(protocols);
840 serverSock->setSupportedApplicationProtocols(protocols);
841
842 SNIClient client(std::move(clientSock));
843 SNIServer server(
844 std::move(serverSock),
845 dfServerCtx,
846 hskServerCtx,
847 serverExpectedServerName);
848
849 eventBase.loop();
850
851 EXPECT_TRUE(
852 client.getApplicationProtocol().compare(
853 server.getApplicationProtocol()) == 0);
854 }
855
856 #endif
857 /**
858 * Test SSL client socket
859 */
TEST(AsyncSSLSocketTest,SSLClientTest)860 TEST(AsyncSSLSocketTest, SSLClientTest) {
861 // Start listening on a local port
862 WriteCallbackBase writeCallback;
863 ReadCallback readCallback(&writeCallback);
864 HandshakeCallback handshakeCallback(&readCallback);
865 SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
866 TestSSLServer server(&acceptCallback);
867
868 // Set up SSL client
869 EventBase eventBase;
870 auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
871 client->setSSLOptions(SSL_OP_NO_TICKET);
872
873 client->connect();
874 EventBaseAborter eba(&eventBase, 3000);
875 eventBase.loop();
876
877 EXPECT_EQ(client->getMiss(), 1);
878 EXPECT_EQ(client->getHit(), 0);
879
880 cerr << "SSLClientTest test completed" << endl;
881 }
882
883 /**
884 * Test SSL client socket session re-use
885 */
TEST(AsyncSSLSocketTest,SSLClientTestReuse)886 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
887 // Start listening on a local port
888 WriteCallbackBase writeCallback;
889 ReadCallback readCallback(&writeCallback);
890 HandshakeCallback handshakeCallback(&readCallback);
891 SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
892 TestSSLServer server(&acceptCallback);
893
894 // Set up SSL client
895 EventBase eventBase;
896 auto client =
897 std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
898 client->setSSLOptions(SSL_OP_NO_TICKET);
899
900 client->connect();
901 EventBaseAborter eba(&eventBase, 3000);
902 eventBase.loop();
903
904 EXPECT_EQ(client->getMiss(), 1);
905 EXPECT_EQ(client->getHit(), 9);
906
907 cerr << "SSLClientTestReuse test completed" << endl;
908 }
909
910 /**
911 * Test SSL client socket timeout
912 */
TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
  // Server side: the handshake callback is told to expect an error, and the
  // accept callback is the handshake-timeout variant.
  EmptyReadCallback serverReadCb;
  HandshakeCallback serverHandshakeCb(
      &serverReadCb, HandshakeCallback::EXPECT_ERROR);
  HandshakeTimeoutCallback serverAcceptCb(&serverHandshakeCb);
  TestSSLServer server(&serverAcceptCb);

  // Client side: count of 1 plus an extra argument of 10 (presumably the
  // client timeout -- confirm against SSLClient's constructor).
  EventBase evb;
  const auto sslClient =
      std::make_shared<SSLClient>(&evb, server.getAddress(), 1, 10);
  sslClient->setSSLOptions(SSL_OP_NO_TICKET);

  sslClient->connect(true /* write before connect completes */);
  // Guard so that a stuck event loop cannot hang the test.
  EventBaseAborter loopGuard(&evb, 3000);
  evb.loop();

  usleep(100000);
  // The connectError callback must come before any queued writeError
  // callbacks; this mirrors what AsyncSocket does.
  EXPECT_EQ(sslClient->getWriteAfterConnectErrors(), 1);
  EXPECT_EQ(sslClient->getErrors(), 1);
  EXPECT_EQ(sslClient->getMiss(), 0);
  EXPECT_EQ(sslClient->getHit(), 0);

  cerr << "SSLClientTimeoutTest test completed" << endl;
}
941
942 class PerLoopReadCallback : public AsyncTransport::ReadCallback {
943 public:
getReadBuffer(void ** bufReturn,size_t * lenReturn)944 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
945 *bufReturn = buf_.data();
946 *lenReturn = buf_.size();
947 }
948
readDataAvailable(size_t len)949 void readDataAvailable(size_t len) noexcept override {
950 VLOG(3) << "Read of size: " << len;
951 s_->setReadCB(nullptr);
952 s_->getEventBase()->runInLoop([this]() { s_->setReadCB(this); });
953 }
954
readErr(const AsyncSocketException &)955 void readErr(const AsyncSocketException&) noexcept override {}
956
readEOF()957 void readEOF() noexcept override {}
958
setSocket(AsyncSocket * s)959 void setSocket(AsyncSocket* s) { s_ = s; }
960
961 private:
962 AsyncSocket* s_;
963 std::array<uint8_t, 1000> buf_;
964 };
965
966 class CloseNotifyConnector : public AsyncSocket::ConnectCallback {
967 public:
CloseNotifyConnector(EventBase * evb,const SocketAddress & addr)968 CloseNotifyConnector(EventBase* evb, const SocketAddress& addr) {
969 evb_ = evb;
970 ssl_ = AsyncSSLSocket::newSocket(std::make_shared<SSLContext>(), evb_);
971 ssl_->connect(this, addr);
972 }
973
connectSuccess()974 void connectSuccess() noexcept override {
975 ssl_->writeChain(nullptr, IOBuf::copyBuffer("hi"));
976 auto ssl = const_cast<SSL*>(ssl_->getSSL());
977 SSL_shutdown(ssl);
978 auto fd = ssl_->detachNetworkSocket();
979 tcp_.reset(new AsyncSocket(evb_, fd), AsyncSocket::Destructor());
980 evb_->runAfterDelay(
981 [this]() {
982 perLoopReads_.setSocket(tcp_.get());
983 tcp_->setReadCB(&perLoopReads_);
984 evb_->runAfterDelay([this]() { tcp_->closeNow(); }, 10);
985 },
986 100);
987 }
988
connectErr(const AsyncSocketException & ex)989 void connectErr(const AsyncSocketException& ex) noexcept override {
990 FAIL() << ex.what();
991 }
992
993 private:
994 EventBase* evb_;
995 std::shared_ptr<AsyncSSLSocket> ssl_;
996 std::shared_ptr<AsyncSocket> tcp_;
997 PerLoopReadCallback perLoopReads_;
998 };
999
1000 class ErrorCheckingWriteCallback : public AsyncSocket::WriteCallback {
1001 public:
writeSuccess()1002 void writeSuccess() noexcept override {}
1003
writeErr(size_t,const AsyncSocketException & ex)1004 void writeErr(size_t, const AsyncSocketException& ex) noexcept override {
1005 LOG(ERROR) << "write error: " << ex.what();
1006 EXPECT_NE(
1007 ex.getType(),
1008 AsyncSocketException::AsyncSocketExceptionType::SSL_ERROR);
1009 }
1010 };
1011
1012 class WriteOnEofReadCallback : public ReadCallback {
1013 public:
1014 using ReadCallback::ReadCallback;
1015
readEOF()1016 void readEOF() noexcept override {
1017 LOG(INFO) << "Got EOF";
1018 auto chain = IOBuf::create(0);
1019 for (size_t i = 0; i < 1000 * 1000; i++) {
1020 auto buf = IOBuf::create(10);
1021 buf->append(10);
1022 memset(buf->writableData(), 'x', 10);
1023 chain->prependChain(std::move(buf));
1024 }
1025 socket_->writeChain(&writeCallback_, std::move(chain));
1026 }
1027
readErr(const AsyncSocketException & ex)1028 void readErr(const AsyncSocketException& ex) noexcept override {
1029 LOG(ERROR) << ex.what();
1030 }
1031
1032 private:
1033 ErrorCheckingWriteCallback writeCallback_;
1034 };
1035
/*
 * The client (CloseNotifyConnector) calls SSL_shutdown() right after
 * connecting; the server reacts to EOF with a large write
 * (WriteOnEofReadCallback). The only assertions live inside those helper
 * callbacks.
 */
TEST(AsyncSSLSocketTest, EarlyCloseNotify) {
  WriteOnEofReadCallback eofWriter(nullptr);
  HandshakeCallback serverHandshakeCb(&eofWriter);
  SSLServerAcceptCallback serverAcceptCb(&serverHandshakeCb);
  TestSSLServer server(&serverAcceptCb);

  EventBase evb;
  // Starts connecting from its constructor.
  CloseNotifyConnector connector(&evb, server.getAddress());

  evb.loop();
}
1047
1048 /**
1049 * Verify Client Ciphers obtained using SSL MSG Callback.
1050 */
TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();

  // Server: verifies its peer and offers three ciphers.
  serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
  serverCtx->loadPrivateKey(kTestKey);
  serverCtx->loadCertificate(kTestCert);
  serverCtx->loadTrustedCertificates(kTestCA);
  serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);

  // Client: verifies its peer and offers two ciphers, AES256-SHA first.
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->ciphers("AES256-SHA:AES128-SHA");
  clientCtx->loadPrivateKey(kTestKey);
  clientCtx->loadCertificate(kTestCert);
  clientCtx->loadTrustedCertificates(kTestCA);

  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);

  eventBase.loop();

  // The cipher list seen by the server differs by TLS library: the
  // non-BoringSSL build additionally reports "00ff".
#if defined(OPENSSL_IS_BORINGSSL)
  EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
#else
  EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
#endif
  EXPECT_EQ(server.chosenCipher_, "AES256-SHA");

  // Both sides must have verified and completed without error.
  EXPECT_TRUE(client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_FALSE(client.handshakeError_);
  EXPECT_TRUE(server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_FALSE(server.handshakeError_);
}
1094
1095 /**
1096 * Verify that server is able to get client cert by getPeerCert() API.
1097 */
TEST(AsyncSSLSocketTest, GetClientCertificate) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();

  // Server presents the test cert and trusts the client test CA.
  serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
  serverCtx->loadPrivateKey(kTestKey);
  serverCtx->loadCertificate(kTestCert);
  serverCtx->loadTrustedCertificates(kClientTestCA);
  serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(kClientTestCA);

  // Client presents the client test cert and trusts the test CA.
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->ciphers("AES256-SHA:AES128-SHA");
  clientCtx->loadPrivateKey(kClientTestKey);
  clientCtx->loadCertificate(kClientTestCert);
  clientCtx->loadTrustedCertificates(kTestCA);

  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);

  eventBase.loop();

  // Handshake should succeed.
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeSuccess_);

  // Reclaim the sockets from SSLHandshakeBase.
  auto cliSocket = std::move(client).moveSocket();
  auto srvSocket = std::move(server).moveSocket();

  // The null checks below are fatal gtest assertions rather than glog
  // CHECKs: a missing certificate then fails only this test instead of
  // aborting the whole test binary.

  // Client cert retrieved from server side.
  auto serverPeerCert = srvSocket->getPeerCertificate();
  ASSERT_TRUE(serverPeerCert != nullptr);

  // Client cert retrieved from client side.
  auto clientSelfCert = cliSocket->getSelfCertificate();
  ASSERT_TRUE(clientSelfCert != nullptr);

  auto serverX509 =
      folly::OpenSSLTransportCertificate::tryExtractX509(serverPeerCert);
  ASSERT_TRUE(serverX509 != nullptr);

  auto clientX509 =
      folly::OpenSSLTransportCertificate::tryExtractX509(clientSelfCert);
  ASSERT_TRUE(clientX509 != nullptr);

  // The two certs should be the same.
  EXPECT_EQ(0, X509_cmp(clientX509.get(), serverX509.get()));
}
1155
/*
 * Feeds a hand-built ClientHello to the parsing callback in a single call
 * and checks that the parsed major/minor version match what was written.
 */
TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
  EventBase eventBase;
  auto ctx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);

  constexpr std::size_t kBufLen = 42;
  const uint8_t majorVersion = 18;
  const uint8_t minorVersion = 25;

  // Build the message: 1 type byte, 3 bytes (16-bit zero + 38) that read as
  // a length of 38, then 38 bytes of body = 2 version bytes + 32 skipped
  // bytes + 4 zero bytes. Total: 42 bytes.
  auto buf = IOBuf::create(kBufLen);
  buf->append(kBufLen);
  folly::io::RWPrivateCursor cursor(buf.get());
  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
  cursor.write<uint16_t>(0);
  cursor.write<uint8_t>(38);
  cursor.write<uint8_t>(majorVersion);
  cursor.write<uint8_t>(minorVersion);
  cursor.skip(32);
  cursor.write<uint32_t>(0);

  SSL* ssl = ctx->createSSL();
  SCOPE_EXIT { SSL_free(ssl); };
  AsyncSSLSocket::UniquePtr sock(
      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
  sock->enableClientHelloParsing();

  // Test client hello parsing in one packet
  AsyncSSLSocket::clientHelloParsingCallback(
      0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
  buf.reset();

  auto parsedClientHello = sock->getClientHelloInfo();
  // Fatal on purpose: the checks below dereference the pointer.
  ASSERT_TRUE(parsedClientHello != nullptr);
  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}
1195
/*
 * Feeds a hand-built ClientHello to the parsing callback in two calls, the
 * first carrying only 2 bytes, and checks the parsed major/minor version.
 */
TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
  EventBase eventBase;
  auto ctx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);

  constexpr std::size_t kBufLen = 42;
  const uint8_t majorVersion = 18;
  const uint8_t minorVersion = 25;

  // Build the message: 1 type byte, 3 bytes (16-bit zero + 38) that read as
  // a length of 38, then 38 bytes of body = 2 version bytes + 32 skipped
  // bytes + 4 zero bytes. Total: 42 bytes.
  auto buf = IOBuf::create(kBufLen);
  buf->append(kBufLen);
  folly::io::RWPrivateCursor cursor(buf.get());
  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
  cursor.write<uint16_t>(0);
  cursor.write<uint8_t>(38);
  cursor.write<uint8_t>(majorVersion);
  cursor.write<uint8_t>(minorVersion);
  cursor.skip(32);
  cursor.write<uint32_t>(0);

  SSL* ssl = ctx->createSSL();
  SCOPE_EXIT { SSL_free(ssl); };
  AsyncSSLSocket::UniquePtr sock(
      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
  sock->enableClientHelloParsing();

  // Passes one fragment to the parser from its own temporary copy, so the
  // parser cannot rely on the fragment's memory outliving the call.
  auto feedFragment = [&](const uint8_t* data, std::size_t len) {
    auto fragment = folly::IOBuf::copyBuffer(data, len);
    AsyncSSLSocket::clientHelloParsingCallback(
        0,
        0,
        SSL3_RT_HANDSHAKE,
        fragment->data(),
        fragment->length(),
        ssl,
        sock.get());
  };

  // Test parsing with two packets with first packet size < 3
  constexpr std::size_t kFirstLen = 2;
  feedFragment(buf->data(), kFirstLen);
  feedFragment(buf->data() + kFirstLen, buf->length() - kFirstLen);

  auto parsedClientHello = sock->getClientHelloInfo();
  // Fatal on purpose: the checks below dereference the pointer.
  ASSERT_TRUE(parsedClientHello != nullptr);
  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}
1252
/*
 * Feeds a hand-built ClientHello to the parsing callback in fragments of at
 * most 3 bytes and checks the parsed major/minor version.
 */
TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
  EventBase eventBase;
  auto ctx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);

  constexpr std::size_t kBufLen = 42;
  const uint8_t majorVersion = 18;
  const uint8_t minorVersion = 25;

  // Build the message: 1 type byte, 3 bytes (16-bit zero + 38) that read as
  // a length of 38, then 38 bytes of body = 2 version bytes + 32 skipped
  // bytes + 4 zero bytes. Total: 42 bytes.
  auto buf = IOBuf::create(kBufLen);
  buf->append(kBufLen);
  folly::io::RWPrivateCursor cursor(buf.get());
  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
  cursor.write<uint16_t>(0);
  cursor.write<uint8_t>(38);
  cursor.write<uint8_t>(majorVersion);
  cursor.write<uint8_t>(minorVersion);
  cursor.skip(32);
  cursor.write<uint32_t>(0);

  SSL* ssl = ctx->createSSL();
  SCOPE_EXIT { SSL_free(ssl); };
  AsyncSSLSocket::UniquePtr sock(
      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
  sock->enableClientHelloParsing();

  // Test parsing with multiple small packets
  constexpr std::size_t kFragmentLen = 3;
  for (std::size_t i = 0; i < buf->length(); i += kFragmentLen) {
    // Each fragment is handed over from its own short-lived copy.
    auto bufCopy = folly::IOBuf::copyBuffer(
        buf->data() + i, std::min(kFragmentLen, buf->length() - i));
    AsyncSSLSocket::clientHelloParsingCallback(
        0,
        0,
        SSL3_RT_HANDSHAKE,
        bufCopy->data(),
        bufCopy->length(),
        ssl,
        sock.get());
    bufCopy.reset();
  }

  auto parsedClientHello = sock->getClientHelloInfo();
  // Fatal on purpose: the checks below dereference the pointer.
  ASSERT_TRUE(parsedClientHello != nullptr);
  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}
1302
1303 /**
 * Verify successful behavior of SSL certificate validation.
1305 */
TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();

  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());
  getctx(clientCtx, dfServerCtx);

  // Both contexts request peer verification.
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  // The trust roots are loaded after the client helper has been constructed;
  // keep this ordering.
  clientCtx->loadTrustedCertificates(kTestCA);

  SSLHandshakeServer server(std::move(serverSock), true, true);

  eventBase.loop();

  // The client verified the server and the handshake succeeded.
  EXPECT_TRUE(client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_FALSE(client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  // The server succeeded without its verify callback firing.
  EXPECT_FALSE(server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_FALSE(server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}
1339
1340 /**
1341 * Verify that the client's verification callback is able to fail SSL
1342 * connection establishment.
1343 */
TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();

  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());
  getctx(clientCtx, dfServerCtx);

  // Both contexts request peer verification.
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));

  // Differs from the success test only in the last flag (false); presumably
  // that is the result the client's verify callback returns -- confirm
  // against SSLHandshakeBase.
  SSLHandshakeClient client(std::move(clientSock), true, false);
  clientCtx->loadTrustedCertificates(kTestCA);

  SSLHandshakeServer server(std::move(serverSock), true, true);

  eventBase.loop();

  // The client's verify callback ran and the handshake failed on both ends.
  EXPECT_TRUE(client.handshakeVerify_);
  EXPECT_FALSE(client.handshakeSuccess_);
  EXPECT_TRUE(client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_FALSE(server.handshakeVerify_);
  EXPECT_FALSE(server.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}
1377
1378 /**
1379 * Verify that the client successfully handshakes when
1380 * CertificateIdentityVerifier is set and returns with no exception.
1381 */
TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierReturns) {
  EventBase evb;
  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  getctx(clientCtx, serverCtx);
  // The client socket defaults to USE_CTX, so verification has to be enabled
  // on the context itself, together with the root certificate.
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->loadTrustedCertificates(kTestCA);

  // Basic server whose callbacks carry a few EXPECTs of their own; the
  // handshake callback expects a successful handshake.
  ReadCallback serverReadCb(nullptr);
  HandshakeCallback serverHandshakeCb(&serverReadCb);
  SSLServerAcceptCallback serverAcceptCb(&serverHandshakeCb);
  TestSSLServer server(&serverAcceptCb, serverCtx);

  // Successful verification is a Try holding folly::unit.
  Try<Unit> okResult{unit};

  auto verifier = std::make_shared<MockCertificateIdentityVerifier>();

  // The verifier must be consulted exactly once, for the leaf certificate
  // (kTestCert).
  EXPECT_CALL(
      *verifier,
      verifyLeafImpl(Property(
          &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
      .WillOnce(Return(ByMove(okResult)));

  AsyncSSLSocket::Options clientOpts;
  clientOpts.verifier = std::move(verifier);

  // Connect to the server and handshake.
  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &evb, std::move(clientOpts)));
  clientSock->connect(nullptr, server.getAddress(), 0);

  // Send some bytes so the server-side ReadCallback's EXPECTs are met.
  std::array<uint8_t, 128> payload;
  payload.fill('a');
  clientSock->write(nullptr, payload.data(), payload.size());

  evb.loop();

  clientSock->close();
}
1430
1431 class TestCertificateIdentityVerifierException
1432 : public CertificateIdentityVerifierException {
1433 public:
TestCertificateIdentityVerifierException(const char * content)1434 explicit TestCertificateIdentityVerifierException(const char* content)
1435 : CertificateIdentityVerifierException(content) {}
1436 };
1437
1438 /**
1439 * Verify that the client fails to connect during handshake because
1440 * CertificateIdentityVerifier returns a failure while verifying the server's
1441 * leaf certificate.
1442 */
TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierFailsToConnect) {
  EventBase evb;
  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  getctx(clientCtx, serverCtx);
  // The client socket defaults to USE_CTX, so verification has to be enabled
  // on the context itself, together with the root certificate.
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->loadTrustedCertificates(kTestCA);

  // Basic server whose callbacks carry a few EXPECTs of their own; here the
  // handshake callback expects the handshake to fail.
  ReadCallback serverReadCb(nullptr);
  HandshakeCallback serverHandshakeCb(
      &serverReadCb, HandshakeCallback::ExpectType::EXPECT_ERROR);
  SSLServerAcceptCallback serverAcceptCb(&serverHandshakeCb);
  TestSSLServer server(&serverAcceptCb, serverCtx);

  auto verifier = std::make_shared<MockCertificateIdentityVerifier>();

  // Failed verification is a Try holding an exception.
  TestCertificateIdentityVerifierException failure{"a failed test reason"};
  Try<Unit> failedResult{failure};

  // The verifier must be consulted exactly once, for the leaf certificate
  // (kTestCert), and it reports the failure.
  EXPECT_CALL(
      *verifier,
      verifyLeafImpl(Property(
          &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
      .WillOnce(Return(ByMove(failedResult)));

  AsyncSSLSocket::Options clientOpts;
  clientOpts.verifier = std::move(verifier);

  // Connect to the server and handshake.
  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &evb, std::move(clientOpts)));
  clientSock->connect(nullptr, server.getAddress(), 0);

  evb.loop();

  clientSock->close();
}
1487
1488 /**
1489 * Verify that the client's CertificateIdentityVerifier is not invoked if
1490 * OpenSSL's verification fails. (With no HandshakeCB.)
1491 */
TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierNotInvokedX509Failure) {
  EventBase evb;
  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  getctx(clientCtx, serverCtx);
  // The client socket defaults to USE_CTX, so verification has to be enabled
  // on the context itself.
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  // The root certificate is deliberately NOT loaded, so that the server
  // certificate is rejected.

  // Basic server whose callbacks carry a few EXPECTs of their own; the
  // handshake callback expects the handshake to fail.
  ReadCallback serverReadCb(nullptr);
  HandshakeCallback serverHandshakeCb(
      &serverReadCb, HandshakeCallback::ExpectType::EXPECT_ERROR);
  SSLServerAcceptCallback serverAcceptCb(&serverHandshakeCb);
  TestSSLServer server(&serverAcceptCb, serverCtx);

  // StrictMock with no expectations: any call into the verifier fails the
  // test.
  auto verifier =
      std::make_shared<StrictMock<MockCertificateIdentityVerifier>>();

  AsyncSSLSocket::Options clientOpts;
  clientOpts.verifier = std::move(verifier);

  // Connect to the server and handshake.
  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &evb, std::move(clientOpts)));
  clientSock->connect(nullptr, server.getAddress(), 0);

  evb.loop();

  clientSock->close();
}
1525
1526 /**
1527 * Verify that the client CertificateIdentityVerifier is not invoked if
1528 * HandshakeCB::handshakeVer verification fails.
1529 */
TEST(
    AsyncSSLSocketTest,
    SSLCertificateIdentityVerifierNotInvokedHandshakeCBFailure) {
  EventBase evb;
  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  getctx(clientCtx, serverCtx);
  // The client socket defaults to USE_CTX, so verification has to be enabled
  // on the context itself, together with the root certificate.
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->loadTrustedCertificates(kTestCA);

  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());

  AsyncSocket::UniquePtr plainClient(new AsyncSocket(&evb, fds[0]));
  AsyncSocket::UniquePtr plainServer(new AsyncSocket(&evb, fds[1]));

  // StrictMock with no expectations: any call into the verifier fails the
  // test.
  auto verifier =
      std::make_shared<StrictMock<MockCertificateIdentityVerifier>>();

  AsyncSSLSocket::Options clientOpts;
  clientOpts.verifier = verifier;

  AsyncSSLSocket::Options serverOpts;
  serverOpts.isServer = true;

  AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
      clientCtx, std::move(plainClient), std::move(clientOpts)));
  AsyncSSLSocket::UniquePtr serverSock(new AsyncSSLSocket(
      serverCtx, std::move(plainServer), std::move(serverOpts)));

  serverSock->sslAccept(nullptr, std::chrono::milliseconds::zero());

  StrictMock<MockHandshakeCB> clientHandshakeCB;

  // Verification hook: pass every certificate through unchanged except the
  // end-entity one (depth 0), which is normally verified successfully and is
  // forced to be treated as a failure here.
  auto rejectLeafOnly = [](auto&&, bool preverifyOk, auto&& storeCtx) {
    if (X509_STORE_CTX_get_error_depth(storeCtx) != 0) {
      return preverifyOk;
    }
    EXPECT_TRUE(preverifyOk);
    return false;
  };
  EXPECT_CALL(clientHandshakeCB, handshakeVerImpl(clientSock.get(), true, _))
      .Times(AtLeast(1))
      .WillRepeatedly(Invoke(rejectLeafOnly));

  // The error callback firing is what proves the handshake failed.
  EXPECT_CALL(clientHandshakeCB, handshakeErrImpl(clientSock.get(), _));

  clientSock->sslConn(&clientHandshakeCB);

  evb.loop();

  clientSock->close();
  serverSock->close();
}
1590
1591 /**
 * Verify that a CertificateIdentityVerifier is also invoked on a server
 * socket when peer verification is requested.
1594 */
TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierSucceedsOnServer) {
  EventBase evb;
  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  getctx(clientCtx, serverCtx);

  // Client: the socket defaults to USE_CTX, so verification is enabled on
  // the context, together with the root certificate.
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->loadTrustedCertificates(kTestCA);
  // The client also needs an identity of its own; for convenience it reuses
  // the server's certificate and key.
  clientCtx->loadCertificate(kTestCert);
  clientCtx->loadPrivateKey(kTestKey);

  // Server: told to verify the client.
  serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  serverCtx->loadTrustedCertificates(kTestCA);

  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());

  AsyncSocket::UniquePtr plainClient(new AsyncSocket(&evb, fds[0]));
  AsyncSocket::UniquePtr plainServer(new AsyncSocket(&evb, fds[1]));

  // Each side's verifier must be consulted exactly once and reports success.
  auto clientVerifier = std::make_shared<MockCertificateIdentityVerifier>();
  EXPECT_CALL(
      *clientVerifier,
      verifyLeafImpl(Property(
          &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
      .WillOnce(Return(Try<Unit>{unit}));

  auto serverVerifier =
      std::make_shared<StrictMock<MockCertificateIdentityVerifier>>();
  EXPECT_CALL(
      *serverVerifier,
      verifyLeafImpl(Property(
          &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
      .WillOnce(Return(Try<Unit>{unit}));

  AsyncSSLSocket::Options clientOpts;
  clientOpts.verifier = clientVerifier;

  AsyncSSLSocket::Options serverOpts;
  serverOpts.isServer = true;
  serverOpts.verifier = serverVerifier;

  AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
      clientCtx, std::move(plainClient), std::move(clientOpts)));
  AsyncSSLSocket::UniquePtr serverSock(new AsyncSSLSocket(
      serverCtx, std::move(plainServer), std::move(serverOpts)));

  // Neither side installs a HandshakeCB.
  serverSock->sslAccept(nullptr, std::chrono::milliseconds::zero());
  clientSock->sslConn(nullptr);

  evb.loop();

  clientSock->close();
  serverSock->close();
}
1655
1656 /**
 * Verify that the options in SSLContext can be overridden in
 * sslConnect/Accept, i.e. specifying that no validation should be performed
 * allows an otherwise-invalid certificate to be accepted and doesn't fire
 * the validation callback.
1661 */
TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();

  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());
  getctx(clientCtx, dfServerCtx);

  // Both contexts ask for verification; the *NoVerify helpers below are the
  // ones expected to override that.
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));

  SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
  clientCtx->loadTrustedCertificates(kTestCA);

  SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);

  eventBase.loop();

  // Neither verify callback fired, and both handshakes succeeded.
  EXPECT_FALSE(client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_FALSE(client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_FALSE(server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_FALSE(server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}
1695
1696 /**
1697 * Verify that the options in SSLContext can be overridden in
1698 * sslConnect/Accept. Enable verification even if context says otherwise.
1699 * Test requireClientCert with client cert
1700 */
TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
  EventBase evb;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();

  // Settings shared by both ends: verification disabled at the context
  // level, same cipher string, same key/cert and trust roots.
  const auto applySharedSettings = [](const std::shared_ptr<SSLContext>& ctx) {
    ctx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
    ctx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
    ctx->loadPrivateKey(kTestKey);
    ctx->loadCertificate(kTestCert);
    ctx->loadTrustedCertificates(kTestCA);
  };

  applySharedSettings(serverCtx);
  // Only the server sets the supported client-certificate CA names.
  serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);

  applySharedSettings(clientCtx);

  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &evb, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &evb, fds[1], true));

  // The *DoVerify helpers are expected to turn verification on despite the
  // NO_VERIFY contexts.
  SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
  SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);

  evb.loop();

  // Both verify callbacks fired and both handshakes succeeded.
  EXPECT_TRUE(client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_FALSE(client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_TRUE(server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_FALSE(server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}
1740
1741 /**
1742 * Verify that the client's verification callback is able to override
1743 * the preverification failure and allow a successful connection.
1744 */
TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();

  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());
  getctx(clientCtx, dfServerCtx);

  // Both contexts request peer verification. Unlike the validation-success
  // test, no trusted certificates are loaded into the client context here.
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));

  // Flags (false, true): presumably "preverification is expected to fail,
  // but the callback answers success" -- confirm against SSLHandshakeBase.
  SSLHandshakeClient client(std::move(clientSock), false, true);
  SSLHandshakeServer server(std::move(serverSock), true, true);

  eventBase.loop();

  // The client's verify callback ran and the handshake still succeeded.
  EXPECT_TRUE(client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_FALSE(client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_FALSE(server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_FALSE(server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}
1776
1777 /**
1778 * Verify that specifying that no validation should be performed allows an
1779 * otherwise-invalid certificate to be accepted and doesn't fire the validation
1780 * callback.
1781 */
TEST(AsyncSSLSocketTest,SSLHandshakeValidationSkip)1782 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1783 EventBase eventBase;
1784 auto clientCtx = std::make_shared<SSLContext>();
1785 auto dfServerCtx = std::make_shared<SSLContext>();
1786
1787 NetworkSocket fds[2];
1788 getfds(fds);
1789 getctx(clientCtx, dfServerCtx);
1790
1791 clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1792 dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1793
1794 AsyncSSLSocket::UniquePtr clientSock(
1795 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1796 AsyncSSLSocket::UniquePtr serverSock(
1797 new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1798
1799 SSLHandshakeClient client(std::move(clientSock), false, false);
1800 SSLHandshakeServer server(std::move(serverSock), false, false);
1801
1802 eventBase.loop();
1803
1804 EXPECT_TRUE(!client.handshakeVerify_);
1805 EXPECT_TRUE(client.handshakeSuccess_);
1806 EXPECT_TRUE(!client.handshakeError_);
1807 EXPECT_LE(0, client.handshakeTime.count());
1808 EXPECT_TRUE(!server.handshakeVerify_);
1809 EXPECT_TRUE(server.handshakeSuccess_);
1810 EXPECT_TRUE(!server.handshakeError_);
1811 EXPECT_LE(0, server.handshakeTime.count());
1812 }
1813
1814 /**
1815 * Test requireClientCert with client cert
1816 */
TEST(AsyncSSLSocketTest,ClientCertHandshakeSuccess)1817 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1818 EventBase eventBase;
1819 auto clientCtx = std::make_shared<SSLContext>();
1820 auto serverCtx = std::make_shared<SSLContext>();
1821 serverCtx->setVerificationOption(
1822 SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1823 serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1824 serverCtx->loadPrivateKey(kTestKey);
1825 serverCtx->loadCertificate(kTestCert);
1826 serverCtx->loadTrustedCertificates(kTestCA);
1827 serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);
1828
1829 clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1830 clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1831 clientCtx->loadPrivateKey(kTestKey);
1832 clientCtx->loadCertificate(kTestCert);
1833 clientCtx->loadTrustedCertificates(kTestCA);
1834
1835 NetworkSocket fds[2];
1836 getfds(fds);
1837
1838 AsyncSSLSocket::UniquePtr clientSock(
1839 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1840 AsyncSSLSocket::UniquePtr serverSock(
1841 new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1842
1843 SSLHandshakeClient client(std::move(clientSock), true, true);
1844 SSLHandshakeServer server(std::move(serverSock), true, true);
1845
1846 eventBase.loop();
1847
1848 EXPECT_TRUE(client.handshakeVerify_);
1849 EXPECT_TRUE(client.handshakeSuccess_);
1850 EXPECT_FALSE(client.handshakeError_);
1851 EXPECT_LE(0, client.handshakeTime.count());
1852 EXPECT_TRUE(server.handshakeVerify_);
1853 EXPECT_TRUE(server.handshakeSuccess_);
1854 EXPECT_FALSE(server.handshakeError_);
1855 EXPECT_LE(0, server.handshakeTime.count());
1856
1857 // check certificates
1858 auto clientSsl = std::move(client).moveSocket();
1859 auto serverSsl = std::move(server).moveSocket();
1860
1861 auto clientPeer = clientSsl->getPeerCertificate();
1862 auto clientSelf = clientSsl->getSelfCertificate();
1863 auto serverPeer = serverSsl->getPeerCertificate();
1864 auto serverSelf = serverSsl->getSelfCertificate();
1865
1866 EXPECT_NE(clientPeer, nullptr);
1867 EXPECT_NE(clientSelf, nullptr);
1868 EXPECT_NE(serverPeer, nullptr);
1869 EXPECT_NE(serverSelf, nullptr);
1870
1871 EXPECT_EQ(clientPeer->getIdentity(), serverSelf->getIdentity());
1872 EXPECT_EQ(clientSelf->getIdentity(), serverPeer->getIdentity());
1873 }
1874
1875 /**
1876 * Test requireClientCert with no client cert
1877 */
TEST(AsyncSSLSocketTest,NoClientCertHandshakeError)1878 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1879 EventBase eventBase;
1880 auto clientCtx = std::make_shared<SSLContext>();
1881 auto serverCtx = std::make_shared<SSLContext>();
1882 serverCtx->setVerificationOption(
1883 SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1884 serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1885 serverCtx->loadPrivateKey(kTestKey);
1886 serverCtx->loadCertificate(kTestCert);
1887 serverCtx->loadTrustedCertificates(kTestCA);
1888 serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);
1889 clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1890 clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1891
1892 NetworkSocket fds[2];
1893 getfds(fds);
1894
1895 AsyncSSLSocket::UniquePtr clientSock(
1896 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1897 AsyncSSLSocket::UniquePtr serverSock(
1898 new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1899
1900 SSLHandshakeClient client(std::move(clientSock), false, false);
1901 SSLHandshakeServer server(std::move(serverSock), false, false);
1902
1903 eventBase.loop();
1904
1905 EXPECT_FALSE(server.handshakeVerify_);
1906 EXPECT_FALSE(server.handshakeSuccess_);
1907 EXPECT_TRUE(server.handshakeError_);
1908 EXPECT_LE(0, client.handshakeTime.count());
1909 EXPECT_LE(0, server.handshakeTime.count());
1910 }
1911
1912 /**
1913 * Test OpenSSL 1.1.0's async functionality
1914 */
1915 #if FOLLY_OPENSSL_IS_110
1916
/**
 * Create a pipe and put both of its ends into non-blocking mode.
 *
 * @param pipefds  Output array; on success [0] is the read end and [1] the
 *                 write end, as filled in by pipe().
 * @throws std::runtime_error if the pipe cannot be created or if either end
 *         cannot be switched to O_NONBLOCK. In the latter case both
 *         descriptors are closed before throwing so that they do not leak.
 */
static void makeNonBlockingPipe(int pipefds[2]) {
  if (pipe(pipefds) != 0) {
    throw std::runtime_error("Cannot create pipe");
  }
  for (int i = 0; i < 2; ++i) {
    if (::fcntl(pipefds[i], F_SETFL, O_NONBLOCK) != 0) {
      // The caller never sees usable descriptors on failure, so release
      // them here.
      ::close(pipefds[0]);
      ::close(pipefds[1]);
      throw std::runtime_error("Cannot set pipe to nonblocking");
    }
  }
}
1928
// Custom RSA private key encryption method
//
// The three indices below are RSA ex-data slots. They start at -1 and are
// assigned by setupCustomRSA() via RSA_get_ex_new_index().

// Slot on the dummy RSA key that holds the real RSA* used for the actual
// private-key operation.
static int kRSAExIndex = -1;
// Slot on the dummy RSA key that holds the EventBase* on which the private-key
// operation is scheduled.
static int kRSAEvbExIndex = -1;
// Slot on the dummy RSA key that optionally holds an AsyncSSLSocket*; when
// set, customRsaPrivEnc() closes that socket before scheduling the operation.
static int kRSASocketExIndex = -1;
// Key passed to ASYNC_WAIT_CTX_set_wait_fd() when registering the wait fd.
static constexpr StringPiece kEngineId = "AsyncSSLSocketTest";
1934
/**
 * RSA private-encrypt hook that performs the real operation on another
 * EventBase while the calling OpenSSL ASYNC job is paused.
 *
 * The function reads three ex-data slots from `rsa`: the EventBase to run the
 * work on, the real RSA key, and an optional AsyncSSLSocket. It registers the
 * read end of a fresh non-blocking pipe as the job's wait fd, schedules the
 * default PKCS#1 private-encrypt on the EventBase, and pauses the job. The
 * scheduled work writes a single byte to the pipe when it finishes.
 *
 * @param flen     Length of the input, forwarded to the default method.
 * @param from     Input buffer, forwarded to the default method.
 * @param to       Output buffer, forwarded to the default method.
 * @param rsa      Dummy RSA key carrying the ex-data described above.
 * @param padding  Padding mode, forwarded to the default method.
 * @return The value produced by the default private-encrypt method (0 if the
 *         scheduled work has not stored a result by the time the job resumes).
 * @throws std::runtime_error when not called inside an ASYNC job, when the
 *         pipe cannot be created, or when the wait fd cannot be registered.
 */
static int customRsaPrivEnc(
    int flen,
    const unsigned char* from,
    unsigned char* to,
    RSA* rsa,
    int padding) {
  LOG(INFO) << "rsa_priv_enc";
  // EventBase on which the real private-key operation will run; mandatory.
  EventBase* asyncJobEvb =
      reinterpret_cast<EventBase*>(RSA_get_ex_data(rsa, kRSAEvbExIndex));
  CHECK(asyncJobEvb);

  // The key that actually holds the private factors; mandatory.
  RSA* actualRSA = reinterpret_cast<RSA*>(RSA_get_ex_data(rsa, kRSAExIndex));
  CHECK(actualRSA);

  // Optional: only set by tests that want the socket closed mid-operation.
  AsyncSSLSocket* socket = reinterpret_cast<AsyncSSLSocket*>(
      RSA_get_ex_data(rsa, kRSASocketExIndex));

  ASYNC_JOB* job = ASYNC_get_current_job();
  if (job == nullptr) {
    throw std::runtime_error("Expected call in job context");
  }
  ASYNC_WAIT_CTX* waitctx = ASYNC_get_wait_ctx(job);
  OSSL_ASYNC_FD pipefds[2] = {0, 0};
  makeNonBlockingPipe(pipefds);
  // Register the read end as the wait fd; no custom data and no cleanup
  // callback are supplied (both nullptr).
  // NOTE(review): nothing in this function closes pipefds[0] - confirm
  // whether it is released elsewhere.
  if (!ASYNC_WAIT_CTX_set_wait_fd(
          waitctx, kEngineId.data(), pipefds[0], nullptr, nullptr)) {
    throw std::runtime_error("Cannot set wait fd");
  }
  int ret = 0;
  // The scheduled lambda stores its result through this pointer into the
  // local `ret`.
  // NOTE(review): this presumably relies on this frame staying alive while
  // the job is paused - confirm.
  int* retptr = &ret;

  // The write end of the pipe is handed to an AsyncPipeWriter bound to the
  // job EventBase.
  auto hand = folly::NetworkSocket::native_handle_type(pipefds[1]);
  auto asyncPipeWriter = folly::AsyncPipeWriter::newWriter(
      asyncJobEvb, folly::NetworkSocket(hand));

  if (socket) {
    // Close the socket before the operation is even scheduled, so the result
    // arrives for an already-closed socket.
    LOG(INFO) << "Got a socket passed in, closing it...";
    socket->closeNow();
  }
  asyncJobEvb->runInEventBaseThread([retptr = retptr,
                                     flen = flen,
                                     from = from,
                                     to = to,
                                     padding = padding,
                                     actualRSA = actualRSA,
                                     writer = std::move(asyncPipeWriter)]() {
    LOG(INFO) << "Running job";
    // Delegate to OpenSSL's stock PKCS#1 implementation using the real key.
    *retptr = RSA_meth_get_priv_enc(RSA_PKCS1_OpenSSL())(
        flen, from, to, actualRSA, padding);
    LOG(INFO) << "Finished job, writing to pipe";
    // One byte on the pipe marks completion: 1 for a positive result, else 0.
    uint8_t byte = *retptr > 0 ? 1 : 0;
    writer->write(nullptr, &byte, 1);
  });

  LOG(INFO) << "About to pause job";

  // Yield back to the caller of the ASYNC job; execution continues below once
  // the job is resumed.
  ASYNC_pause_job();
  LOG(INFO) << "Resumed job with ret: " << ret;
  return ret;
}
1995
rsaFree(void *,void * ptr,CRYPTO_EX_DATA *,int,long,void *)1996 void rsaFree(void*, void* ptr, CRYPTO_EX_DATA*, int, long, void*) {
1997 LOG(INFO) << "RSA_free is called with ptr " << std::hex << ptr;
1998 if (ptr == nullptr) {
1999 LOG(INFO) << "Returning early from rsaFree because ptr is null";
2000 return;
2001 }
2002 RSA* rsa = (RSA*)ptr;
2003 auto meth = RSA_get_method(rsa);
2004 if (meth != RSA_get_default_method()) {
2005 auto nonconst = const_cast<RSA_METHOD*>(meth);
2006 RSA_meth_free(nonconst);
2007 RSA_set_method(rsa, RSA_get_default_method());
2008 }
2009 RSA_free(rsa);
2010 }
2011
2012 struct RSAPointers {
2013 RSA* actualrsa{nullptr};
2014 RSA* dummyrsa{nullptr};
2015 RSA_METHOD* meth{nullptr};
2016 };
2017
RSAPointersFree(RSAPointers * p)2018 inline void RSAPointersFree(RSAPointers* p) {
2019 if (p->meth && p->dummyrsa && RSA_get_method(p->dummyrsa) == p->meth) {
2020 RSA_set_method(p->dummyrsa, RSA_get_default_method());
2021 }
2022
2023 if (p->meth) {
2024 LOG(INFO) << "Freeing meth";
2025 RSA_meth_free(p->meth);
2026 }
2027
2028 if (p->actualrsa) {
2029 LOG(INFO) << "Freeing actualrsa";
2030 RSA_free(p->actualrsa);
2031 }
2032
2033 if (p->dummyrsa) {
2034 LOG(INFO) << "Freeing dummyrsa";
2035 RSA_free(p->dummyrsa);
2036 }
2037
2038 delete p;
2039 }
2040
// Deleter type that routes destruction of an RSAPointers through
// RSAPointersFree(), for use as the deleter of a std::unique_ptr.
using RSAPointersDeleter =
    folly::static_function_deleter<RSAPointers, RSAPointersFree>;
2043
setupCustomRSA(const char * certPath,const char * keyPath,EventBase * jobEvb)2044 std::unique_ptr<RSAPointers, RSAPointersDeleter> setupCustomRSA(
2045 const char* certPath, const char* keyPath, EventBase* jobEvb) {
2046 auto certPEM = getFileAsBuf(certPath);
2047 auto keyPEM = getFileAsBuf(keyPath);
2048
2049 ssl::BioUniquePtr certBio(
2050 BIO_new_mem_buf((void*)certPEM.data(), certPEM.size()));
2051 ssl::BioUniquePtr keyBio(
2052 BIO_new_mem_buf((void*)keyPEM.data(), keyPEM.size()));
2053
2054 ssl::X509UniquePtr cert(
2055 PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
2056 ssl::EvpPkeyUniquePtr evpPkey(
2057 PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
2058 ssl::EvpPkeyUniquePtr publicEvpPkey(X509_get_pubkey(cert.get()));
2059
2060 std::unique_ptr<RSAPointers, RSAPointersDeleter> ret(new RSAPointers());
2061
2062 RSA* actualrsa = EVP_PKEY_get1_RSA(evpPkey.get());
2063 LOG(INFO) << "actualrsa ptr " << std::hex << (void*)actualrsa;
2064 RSA* dummyrsa = EVP_PKEY_get1_RSA(publicEvpPkey.get());
2065 if (dummyrsa == nullptr) {
2066 throw std::runtime_error("Couldn't get RSA cert public factors");
2067 }
2068 RSA_METHOD* meth = RSA_meth_dup(RSA_get_default_method());
2069 if (meth == nullptr || RSA_meth_set1_name(meth, "Async RSA method") == 0 ||
2070 RSA_meth_set_priv_enc(meth, customRsaPrivEnc) == 0 ||
2071 RSA_meth_set_flags(meth, RSA_METHOD_FLAG_NO_CHECK) == 0) {
2072 throw std::runtime_error("Cannot create async RSA_METHOD");
2073 }
2074 RSA_set_method(dummyrsa, meth);
2075 RSA_set_flags(dummyrsa, RSA_FLAG_EXT_PKEY);
2076
2077 kRSAExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
2078 kRSAEvbExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
2079 kRSASocketExIndex =
2080 RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
2081 CHECK_NE(kRSAExIndex, -1);
2082 CHECK_NE(kRSAEvbExIndex, -1);
2083 CHECK_NE(kRSASocketExIndex, -1);
2084 RSA_set_ex_data(dummyrsa, kRSAExIndex, actualrsa);
2085 RSA_set_ex_data(dummyrsa, kRSAEvbExIndex, jobEvb);
2086
2087 ret->actualrsa = actualrsa;
2088 ret->dummyrsa = dummyrsa;
2089 ret->meth = meth;
2090
2091 return ret;
2092 }
2093
// TODO: disabled under ASAN, which doesn't play nicely with ASYNC for some reason
2095 #ifndef FOLLY_SANITIZE_ADDRESS
// Happy path for SSL_MODE_ASYNC: the server's private-key operation is
// carried out by the custom RSA method on a separate EventBase thread, and
// both sides must still complete the handshake.
TEST(AsyncSSLSocketTest, OpenSSL110AsyncTest) {
  const char* const cipherList = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";

  ASYNC_init_thread(1, 1);
  EventBase evb;
  ScopedEventBaseThread jobThread;
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  cliCtx->ciphers(cipherList);

  srvCtx->ciphers(cipherList);
  srvCtx->loadCertificate(kTestCert);
  srvCtx->loadTrustedCertificates(kTestCA);
  srvCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);

  auto rsaKeys = setupCustomRSA(kTestCert, kTestKey, jobThread.getEventBase());
  CHECK(rsaKeys->dummyrsa);
  SSL_CTX* const rawSrvCtx = srvCtx->getSSLCtx();
  // up-refs dummyrsa
  SSL_CTX_use_RSAPrivateKey(rawSrvCtx, rsaKeys->dummyrsa);
  SSL_CTX_set_mode(rawSrvCtx, SSL_MODE_ASYNC);

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock{
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false)};
  AsyncSSLSocket::UniquePtr srvSock{
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true)};

  SSLHandshakeClient cli(std::move(cliSock), false, false);
  SSLHandshakeServer srv(std::move(srvSock), false, false);

  evb.loop();

  EXPECT_TRUE(srv.handshakeSuccess_);
  EXPECT_TRUE(cli.handshakeSuccess_);
  ASYNC_cleanup_thread();
}
2134
// Failure path for SSL_MODE_ASYNC: the custom RSA method is backed by a
// private key that does not belong to the server certificate, so both sides
// must report a handshake error.
TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestFailure) {
  const char* const cipherList = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";

  ASYNC_init_thread(1, 1);
  EventBase evb;
  ScopedEventBaseThread jobThread;
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  cliCtx->ciphers(cipherList);

  srvCtx->ciphers(cipherList);
  srvCtx->loadCertificate(kTestCert);
  srvCtx->loadTrustedCertificates(kTestCA);
  srvCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);

  // Set the wrong key for the cert
  auto rsaKeys =
      setupCustomRSA(kTestCert, kClientTestKey, jobThread.getEventBase());
  CHECK(rsaKeys->dummyrsa);
  SSL_CTX* const rawSrvCtx = srvCtx->getSSLCtx();
  SSL_CTX_use_RSAPrivateKey(rawSrvCtx, rsaKeys->dummyrsa);
  SSL_CTX_set_mode(rawSrvCtx, SSL_MODE_ASYNC);

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock{
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false)};
  AsyncSSLSocket::UniquePtr srvSock{
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true)};

  SSLHandshakeClient cli(std::move(cliSock), false, false);
  SSLHandshakeServer srv(std::move(srvSock), false, false);

  evb.loop();

  EXPECT_TRUE(srv.handshakeError_);
  EXPECT_TRUE(cli.handshakeError_);
  ASYNC_cleanup_thread();
}
2172
// SSL_MODE_ASYNC with the server socket closed while the private-key
// operation is still outstanding: the server socket is stored on the dummy
// key so the custom RSA method closes it, and both sides must report a
// handshake error.
TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestClosedWithCallbackPending) {
  const char* const cipherList = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";

  ASYNC_init_thread(1, 1);
  EventBase evb;
  std::optional<EventBaseThread> jobThread;
  jobThread.emplace();
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  cliCtx->ciphers(cipherList);

  srvCtx->ciphers(cipherList);
  srvCtx->loadCertificate(kTestCert);
  srvCtx->loadTrustedCertificates(kTestCA);
  srvCtx->setSupportedClientCertificateAuthorityNamesFromFile(kTestCA);

  auto rsaKeys = setupCustomRSA(kTestCert, kTestKey, jobThread->getEventBase());
  CHECK(rsaKeys->dummyrsa);
  SSL_CTX* const rawSrvCtx = srvCtx->getSSLCtx();
  // up-refs dummyrsa
  SSL_CTX_use_RSAPrivateKey(rawSrvCtx, rsaKeys->dummyrsa);
  SSL_CTX_set_mode(rawSrvCtx, SSL_MODE_ASYNC);

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock{
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false)};
  AsyncSSLSocket::UniquePtr srvSock{
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true)};

  // Hand the server socket to the RSA hook so it gets closed mid-handshake.
  AsyncSSLSocket* const rawSrvSock = srvSock.get();
  RSA_set_ex_data(rsaKeys->dummyrsa, kRSASocketExIndex, rawSrvSock);

  SSLHandshakeClient cli(std::move(cliSock), false, false);
  SSLHandshakeServer srv(std::move(srvSock), false, false);

  evb.loop();

  EXPECT_TRUE(srv.handshakeError_);
  EXPECT_TRUE(cli.handshakeError_);
  ASYNC_cleanup_thread();
  jobThread.reset();
}
2215 #endif // FOLLY_SANITIZE_ADDRESS
2216
2217 #endif // FOLLY_OPENSSL_IS_110
2218
// Loading a certificate and private key into an SSLContext from in-memory PEM
// buffers must yield the same common name and key size as parsing those
// buffers directly with OpenSSL.
TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
  auto certPem = getFileAsBuf(kTestCert);
  auto keyPem = getFileAsBuf(kTestKey);

  // Reference values: parse the PEM data straight through memory BIOs.
  ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
  BIO_write(certBio.get(), certPem.data(), certPem.size());
  ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
  BIO_write(keyBio.get(), keyPem.data(), keyPem.size());

  ssl::X509UniquePtr parsedCert(
      PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
  ssl::EvpPkeyUniquePtr parsedKey(
      PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
  certBio.reset();
  keyBio.reset();

  auto expectedCommonName =
      folly::ssl::OpenSSLUtils::getCommonName(parsedCert.get());
  auto expectedKeyBits = EVP_PKEY_bits(parsedKey.get());
  parsedCert.reset();
  parsedKey.reset();

  // Now load the very same buffers through SSLContext.
  auto sslCtx = std::make_shared<SSLContext>();
  sslCtx->loadPrivateKeyFromBufferPEM(keyPem);
  sslCtx->loadCertificateFromBufferPEM(certPem);
  sslCtx->loadTrustedCertificates(kTestCA);

  ssl::SSLUniquePtr sslHandle(sslCtx->createSSL());

  auto loadedCert = SSL_get_certificate(sslHandle.get());
  auto loadedKey = SSL_get_privatekey(sslHandle.get());

  auto actualCommonName = folly::ssl::OpenSSLUtils::getCommonName(loadedCert);
  auto actualKeyBits = EVP_PKEY_bits(loadedKey);

  // Check that the key and cert have the expected properties
  EXPECT_EQ(expectedCommonName, actualCommonName);
  EXPECT_EQ(expectedKeyBits, actualKeyBits);
}
2260
// getMinWriteSize() must report 1500 on a fresh socket and afterwards
// whatever value was last passed to setMinWriteSize().
TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
  EventBase evb;

  auto ctx = std::make_shared<SSLContext>();
  ctx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  AsyncSSLSocket::UniquePtr sock{new AsyncSSLSocket(ctx, &evb)};

  // Value reported before anything has been set.
  EXPECT_EQ(1500, sock->getMinWriteSize());

  sock->setMinWriteSize(0);
  EXPECT_EQ(0, sock->getMinWriteSize());

  sock->setMinWriteSize(50000);
  EXPECT_EQ(50000, sock->getMinWriteSize());
}
2278
2279 class ReadCallbackTerminator : public ReadCallback {
2280 public:
ReadCallbackTerminator(EventBase * base,WriteCallbackBase * wcb)2281 ReadCallbackTerminator(EventBase* base, WriteCallbackBase* wcb)
2282 : ReadCallback(wcb), base_(base) {}
2283
2284 // Do not write data back, terminate the loop.
readDataAvailable(size_t len)2285 void readDataAvailable(size_t len) noexcept override {
2286 std::cerr << "readDataAvailable, len " << len << std::endl;
2287
2288 currentBuffer.length = len;
2289
2290 buffers.push_back(currentBuffer);
2291 currentBuffer.reset();
2292 state = STATE_SUCCEEDED;
2293
2294 socket_->setReadCB(nullptr);
2295 base_->terminateLoopSoon();
2296 }
2297
2298 private:
2299 EventBase* base_;
2300 };
2301
2302 /**
2303 * Test a full unencrypted codepath
2304 */
TEST(AsyncSSLSocketTest,UnencryptedTest)2305 TEST(AsyncSSLSocketTest, UnencryptedTest) {
2306 EventBase base;
2307
2308 auto clientCtx = std::make_shared<folly::SSLContext>();
2309 auto serverCtx = std::make_shared<folly::SSLContext>();
2310 NetworkSocket fds[2];
2311 getfds(fds);
2312 getctx(clientCtx, serverCtx);
2313 auto client =
2314 AsyncSSLSocket::newSocket(clientCtx, &base, fds[0], false, true);
2315 std::shared_ptr<AsyncSSLSocket> server =
2316 AsyncSSLSocket::newSocket(serverCtx, &base, fds[1], true, true);
2317
2318 ReadCallbackTerminator readCallback(&base, nullptr);
2319 server->setReadCB(&readCallback);
2320 readCallback.setSocket(server);
2321
2322 uint8_t buf[128];
2323 memset(buf, 'a', sizeof(buf));
2324 client->write(nullptr, buf, sizeof(buf));
2325
2326 // Check that bytes are unencrypted
2327 char c;
2328 EXPECT_EQ(1, netops::recv(fds[1], &c, 1, MSG_PEEK));
2329 EXPECT_EQ('a', c);
2330
2331 EventBaseAborter eba(&base, 3000);
2332 base.loop();
2333
2334 EXPECT_EQ(1, readCallback.buffers.size());
2335 EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
2336
2337 server->setReadCB(&readCallback);
2338
2339 // Unencrypted
2340 server->sslAccept(nullptr);
2341 client->sslConn(nullptr);
2342
2343 // Do NOT wait for handshake, writing should be queued and happen after
2344
2345 client->write(nullptr, buf, sizeof(buf));
2346
2347 // Check that bytes are *not* unencrypted
2348 char c2;
2349 EXPECT_EQ(1, netops::recv(fds[1], &c2, 1, MSG_PEEK));
2350 EXPECT_NE('a', c2);
2351
2352 base.loop();
2353
2354 EXPECT_EQ(2, readCallback.buffers.size());
2355 EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
2356 }
2357
// Connect to a TestSSLServer with a socket created in unencrypted mode,
// verify it stays unencrypted after connect(), then upgrade with sslConn()
// and verify the TLS session gets established.
TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
  auto cliCtx = std::make_shared<folly::SSLContext>();
  auto srvCtx = std::make_shared<folly::SSLContext>();
  getctx(cliCtx, srvCtx);

  // Server-side callback chain: accept -> handshake -> read -> write.
  WriteCallbackBase writeCb;
  ReadCallback readCb(&writeCb);
  HandshakeCallback handshakeCb(&readCb);
  SSLServerAcceptCallback acceptCb(&handshakeCb);
  TestSSLServer sslServer(&acceptCb);

  EventBase evb;
  auto sock = AsyncSSLSocket::newSocket(cliCtx, &evb, true);
  sock->connect(nullptr, sslServer.getAddress(), 0);

  evb.loop();
  EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, sock->getSSLState());

  sock->sslConn(nullptr);
  evb.loop();
  EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, sock->getSSLState());

  // write()
  std::array<uint8_t, 128> payload;
  payload.fill('a');
  sock->write(nullptr, payload.data(), payload.size());

  sock->close();
}
2388
2389 /**
2390 * Test acceptrunner in various situations
2391 */
TEST(AsyncSSLSocketTest,SSLAcceptRunnerBasic)2392 TEST(AsyncSSLSocketTest, SSLAcceptRunnerBasic) {
2393 EventBase eventBase;
2394 auto clientCtx = std::make_shared<SSLContext>();
2395 auto serverCtx = std::make_shared<SSLContext>();
2396 serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2397 serverCtx->loadPrivateKey(kTestKey);
2398 serverCtx->loadCertificate(kTestCert);
2399
2400 clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
2401 clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2402 clientCtx->loadTrustedCertificates(kTestCA);
2403
2404 NetworkSocket fds[2];
2405 getfds(fds);
2406
2407 AsyncSSLSocket::UniquePtr clientSock(
2408 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
2409 AsyncSSLSocket::UniquePtr serverSock(
2410 new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
2411
2412 serverCtx->sslAcceptRunner(std::make_unique<SSLAcceptEvbRunner>(&eventBase));
2413
2414 SSLHandshakeClient client(std::move(clientSock), true, true);
2415 SSLHandshakeServer server(std::move(serverSock), true, true);
2416
2417 eventBase.loop();
2418
2419 EXPECT_TRUE(client.handshakeSuccess_);
2420 EXPECT_FALSE(client.handshakeError_);
2421 EXPECT_LE(0, client.handshakeTime.count());
2422 EXPECT_TRUE(server.handshakeSuccess_);
2423 EXPECT_FALSE(server.handshakeError_);
2424 EXPECT_LE(0, server.handshakeTime.count());
2425 }
2426
// With an SSLAcceptErrorRunner installed on the server context, the
// handshake is expected to fail on both sides.
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptError) {
  const char* const cipherList = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";

  EventBase evb;
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  cliCtx->ciphers(cipherList);
  cliCtx->loadTrustedCertificates(kTestCA);

  srvCtx->ciphers(cipherList);
  srvCtx->loadPrivateKey(kTestKey);
  srvCtx->loadCertificate(kTestCert);

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock{
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false)};
  AsyncSSLSocket::UniquePtr srvSock{
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true)};

  auto runner = std::make_unique<SSLAcceptErrorRunner>(&evb);
  srvCtx->sslAcceptRunner(std::move(runner));

  SSLHandshakeClient cli(std::move(cliSock), true, true);
  SSLHandshakeServer srv(std::move(srvSock), true, true);

  evb.loop();

  EXPECT_FALSE(cli.handshakeSuccess_);
  EXPECT_TRUE(cli.handshakeError_);
  EXPECT_FALSE(srv.handshakeSuccess_);
  EXPECT_TRUE(srv.handshakeError_);
}
2460
// With an SSLAcceptCloseRunner (which is given the server socket) installed
// on the server context, the handshake is expected to fail on both sides.
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptClose) {
  const char* const cipherList = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";

  EventBase evb;
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  cliCtx->ciphers(cipherList);
  cliCtx->loadTrustedCertificates(kTestCA);

  srvCtx->ciphers(cipherList);
  srvCtx->loadPrivateKey(kTestKey);
  srvCtx->loadCertificate(kTestCert);

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock{
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false)};
  AsyncSSLSocket::UniquePtr srvSock{
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true)};

  // The runner receives a non-owning pointer to the server socket.
  auto runner = std::make_unique<SSLAcceptCloseRunner>(&evb, srvSock.get());
  srvCtx->sslAcceptRunner(std::move(runner));

  SSLHandshakeClient cli(std::move(cliSock), true, true);
  SSLHandshakeServer srv(std::move(srvSock), true, true);

  evb.loop();

  EXPECT_FALSE(cli.handshakeSuccess_);
  EXPECT_TRUE(cli.handshakeError_);
  EXPECT_FALSE(srv.handshakeSuccess_);
  EXPECT_TRUE(srv.handshakeError_);
}
2494
// With an SSLAcceptDestroyRunner (which is given the server handshake helper)
// installed on the server context, the handshake is expected to fail on both
// sides.
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptDestroy) {
  const char* const cipherList = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";

  EventBase evb;
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  cliCtx->ciphers(cipherList);
  cliCtx->loadTrustedCertificates(kTestCA);

  srvCtx->ciphers(cipherList);
  srvCtx->loadPrivateKey(kTestKey);
  srvCtx->loadCertificate(kTestCert);

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock{
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false)};
  AsyncSSLSocket::UniquePtr srvSock{
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true)};

  SSLHandshakeClient cli(std::move(cliSock), true, true);
  SSLHandshakeServer srv(std::move(srvSock), true, true);

  // Installed only after the helpers exist because the runner needs the
  // server helper's address.
  auto runner = std::make_unique<SSLAcceptDestroyRunner>(&evb, &srv);
  srvCtx->sslAcceptRunner(std::move(runner));

  evb.loop();

  EXPECT_FALSE(cli.handshakeSuccess_);
  EXPECT_TRUE(cli.handshakeError_);
  EXPECT_FALSE(srv.handshakeSuccess_);
  EXPECT_TRUE(srv.handshakeError_);
}
2528
// With an SSLAcceptFiberRunner installed on the server context, the
// handshake is expected to succeed on both sides.
TEST(AsyncSSLSocketTest, SSLAcceptRunnerFiber) {
  const char* const cipherList = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";

  EventBase evb;
  auto cliCtx = std::make_shared<SSLContext>();
  auto srvCtx = std::make_shared<SSLContext>();

  cliCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  cliCtx->ciphers(cipherList);
  cliCtx->loadTrustedCertificates(kTestCA);

  srvCtx->ciphers(cipherList);
  srvCtx->loadPrivateKey(kTestKey);
  srvCtx->loadCertificate(kTestCert);

  NetworkSocket sockPair[2];
  getfds(sockPair);

  AsyncSSLSocket::UniquePtr cliSock{
      new AsyncSSLSocket(cliCtx, &evb, sockPair[0], false)};
  AsyncSSLSocket::UniquePtr srvSock{
      new AsyncSSLSocket(srvCtx, &evb, sockPair[1], true)};

  SSLHandshakeClient cli(std::move(cliSock), true, true);
  SSLHandshakeServer srv(std::move(srvSock), true, true);

  auto runner = std::make_unique<SSLAcceptFiberRunner>(&evb);
  srvCtx->sslAcceptRunner(std::move(runner));

  evb.loop();

  EXPECT_TRUE(cli.handshakeSuccess_);
  EXPECT_FALSE(cli.handshakeError_);
  EXPECT_TRUE(srv.handshakeSuccess_);
  EXPECT_FALSE(srv.handshakeError_);
}
2562
newCloseCb(SSL * ssl,SSL_SESSION *)2563 static int newCloseCb(SSL* ssl, SSL_SESSION*) {
2564 AsyncSSLSocket::getFromSSL(ssl)->closeNow();
2565 return 1;
2566 }
2567
// "Get session" callback: closes the AsyncSSLSocket that owns `ssl` right away
// and returns nullptr, i.e. it never supplies a cached session.
// The two signatures differ only in the constness of the second parameter,
// selected by FOLLY_OPENSSL_IS_110.
#if FOLLY_OPENSSL_IS_110
static SSL_SESSION* getCloseCb(SSL* ssl, const unsigned char*, int, int*) {
#else
static SSL_SESSION* getCloseCb(SSL* ssl, unsigned char*, int, int*) {
#endif
  AsyncSSLSocket::getFromSSL(ssl)->closeNow();
  return nullptr;
}
2576
2577 TEST(AsyncSSLSocketTest, SSLAcceptRunnerFiberCloseSessionCb) {
2578 EventBase eventBase;
2579 auto clientCtx = std::make_shared<SSLContext>();
2580 auto serverCtx = std::make_shared<SSLContext>();
2581 serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2582 serverCtx->loadPrivateKey(kTestKey);
2583 serverCtx->loadCertificate(kTestCert);
2584 SSL_CTX_set_session_cache_mode(
2585 serverCtx->getSSLCtx(),
2586 SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
2587 SSL_CTX_sess_set_new_cb(serverCtx->getSSLCtx(), &newCloseCb);
2588 SSL_CTX_sess_set_get_cb(serverCtx->getSSLCtx(), &getCloseCb);
2589 serverCtx->sslAcceptRunner(
2590 std::make_unique<SSLAcceptFiberRunner>(&eventBase));
2591
2592 clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
2593 clientCtx->ciphers("AES128-SHA256");
2594 clientCtx->loadTrustedCertificates(kTestCA);
2595 clientCtx->setOptions(SSL_OP_NO_TICKET);
2596
2597 NetworkSocket fds[2];
2598 getfds(fds);
2599
2600 AsyncSSLSocket::UniquePtr clientSock(
2601 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
2602 AsyncSSLSocket::UniquePtr serverSock(
2603 new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
2604
2605 SSLHandshakeClient client(std::move(clientSock), true, true);
2606 SSLHandshakeServer server(std::move(serverSock), true, true);
2607
2608 eventBase.loop();
2609
2610 // As close() is called during session callbacks, client sees it as a
2611 // successful connection
2612 EXPECT_TRUE(client.handshakeSuccess_);
2613 EXPECT_FALSE(client.handshakeError_);
2614 EXPECT_FALSE(server.handshakeSuccess_);
2615 EXPECT_TRUE(server.handshakeError_);
2616 }
2617
2618 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
2619 // Start listening on a local port
2620 WriteCallbackBase writeCallback;
2621 WriteErrorCallback readCallback(&writeCallback);
2622 HandshakeCallback handshakeCallback(
2623 &readCallback, HandshakeCallback::EXPECT_ERROR);
2624 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2625 TestSSLServer server(&acceptCallback);
2626
2627 auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
2628 socket->open();
2629 uint8_t buf[3] = {0x16, 0x03, 0x01};
2630 socket->write(buf, sizeof(buf));
2631 socket->closeWithReset();
2632
2633 handshakeCallback.waitForHandshake();
2634 EXPECT_NE(
2635 handshakeCallback.errorString_.find("Network error"), std::string::npos);
2636 EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
2637 }
2638
2639 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
2640 // Start listening on a local port
2641 WriteCallbackBase writeCallback;
2642 WriteErrorCallback readCallback(&writeCallback);
2643 HandshakeCallback handshakeCallback(
2644 &readCallback, HandshakeCallback::EXPECT_ERROR);
2645 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2646 TestSSLServer server(&acceptCallback);
2647
2648 auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
2649 socket->open();
2650 uint8_t buf[3] = {0x16, 0x03, 0x01};
2651 socket->write(buf, sizeof(buf));
2652 socket->close();
2653
2654 handshakeCallback.waitForHandshake();
2655 #if FOLLY_OPENSSL_IS_110
2656 EXPECT_NE(
2657 handshakeCallback.errorString_.find("Network error"), std::string::npos);
2658 #else
2659 EXPECT_NE(
2660 handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
2661 #endif
2662 }
2663
2664 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
2665 // Start listening on a local port
2666 WriteCallbackBase writeCallback;
2667 WriteErrorCallback readCallback(&writeCallback);
2668 HandshakeCallback handshakeCallback(
2669 &readCallback, HandshakeCallback::EXPECT_ERROR);
2670 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2671 TestSSLServer server(&acceptCallback);
2672
2673 auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
2674 socket->open();
2675 uint8_t buf[256] = {0x16, 0x03};
2676 memset(buf + 2, 'a', sizeof(buf) - 2);
2677 socket->write(buf, sizeof(buf));
2678 socket->close();
2679
2680 handshakeCallback.waitForHandshake();
2681 EXPECT_NE(
2682 handshakeCallback.errorString_.find("SSL routines"), std::string::npos);
2683 #if defined(OPENSSL_IS_BORINGSSL)
2684 EXPECT_NE(
2685 handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
2686 std::string::npos);
2687 #elif FOLLY_OPENSSL_IS_110
2688 EXPECT_NE(
2689 handshakeCallback.errorString_.find("packet length too long"),
2690 std::string::npos);
2691 #else
2692 EXPECT_NE(
2693 handshakeCallback.errorString_.find("unknown protocol"),
2694 std::string::npos);
2695 #endif
2696 }
2697
2698 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
2699 using folly::ssl::OpenSSLUtils;
2700 EXPECT_EQ(
2701 OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
2702 // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
2703 EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
2704 // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
2705 EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
2706 }
2707
2708 #if defined __linux__
2709 /**
2710 * Ensure TransparentTLS flag is disabled with AsyncSSLSocket
2711 */
2712 TEST(AsyncSSLSocketTest, TTLSDisabled) {
2713 // clear all setsockopt tracking history
2714 globalStatic.reset();
2715
2716 // Start listening on a local port
2717 WriteCallbackBase writeCallback;
2718 ReadCallback readCallback(&writeCallback);
2719 HandshakeCallback handshakeCallback(&readCallback);
2720 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2721 TestSSLServer server(&acceptCallback, false);
2722
2723 // Set up SSL context.
2724 auto sslContext = std::make_shared<SSLContext>();
2725
2726 // connect
2727 auto socket =
2728 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2729 socket->open();
2730
2731 EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getNetworkSocket()));
2732
2733 // write()
2734 std::array<uint8_t, 128> buf;
2735 memset(buf.data(), 'a', buf.size());
2736 socket->write(buf.data(), buf.size());
2737
2738 // close()
2739 socket->close();
2740 }
2741 #endif
2742
2743 #if FOLLY_ALLOW_TFO
2744
2745 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
2746 public:
2747 using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
2748
2749 explicit MockAsyncTFOSSLSocket(
2750 std::shared_ptr<folly::SSLContext> sslCtx, EventBase* evb)
2751 : AsyncSSLSocket(sslCtx, evb) {}
2752
2753 MOCK_METHOD3(
2754 tfoSendMsg, ssize_t(NetworkSocket fd, struct msghdr* msg, int msg_flags));
2755 };
2756
2757 #if defined __linux__
2758 /**
2759 * Ensure TransparentTLS flag is disabled with AsyncSSLSocket + TFO
2760 */
2761 TEST(AsyncSSLSocketTest, TTLSDisabledWithTFO) {
2762 // clear all setsockopt tracking history
2763 globalStatic.reset();
2764
2765 // Start listening on a local port
2766 WriteCallbackBase writeCallback;
2767 ReadCallback readCallback(&writeCallback);
2768 HandshakeCallback handshakeCallback(&readCallback);
2769 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2770 TestSSLServer server(&acceptCallback, true);
2771
2772 // Set up SSL context.
2773 auto sslContext = std::make_shared<SSLContext>();
2774
2775 // connect
2776 auto socket =
2777 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2778 socket->enableTFO();
2779 socket->open();
2780
2781 EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getNetworkSocket()));
2782
2783 // write()
2784 std::array<uint8_t, 128> buf;
2785 memset(buf.data(), 'a', buf.size());
2786 socket->write(buf.data(), buf.size());
2787
2788 // close()
2789 socket->close();
2790 }
2791 #endif
2792
2793 /**
2794 * Test connecting to, writing to, reading from, and closing the
2795 * connection to the SSL server with TFO.
2796 */
2797 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
2798 // Start listening on a local port
2799 WriteCallbackBase writeCallback;
2800 ReadCallback readCallback(&writeCallback);
2801 HandshakeCallback handshakeCallback(&readCallback);
2802 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2803 TestSSLServer server(&acceptCallback, true);
2804
2805 // Set up SSL context.
2806 auto sslContext = std::make_shared<SSLContext>();
2807
2808 // connect
2809 auto socket =
2810 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2811 socket->enableTFO();
2812 socket->open();
2813
2814 // write()
2815 std::array<uint8_t, 128> buf;
2816 memset(buf.data(), 'a', buf.size());
2817 socket->write(buf.data(), buf.size());
2818
2819 // read()
2820 std::array<uint8_t, 128> readbuf;
2821 uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2822 EXPECT_EQ(bytesRead, 128);
2823 EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
2824
2825 // close()
2826 socket->close();
2827 }
2828
2829 /**
2830 * Test connecting to, writing to, reading from, and closing the
2831 * connection to the SSL server with TFO.
2832 */
2833 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
2834 // Start listening on a local port
2835 WriteCallbackBase writeCallback;
2836 ReadCallback readCallback(&writeCallback);
2837 HandshakeCallback handshakeCallback(&readCallback);
2838 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2839 TestSSLServer server(&acceptCallback, false);
2840
2841 // Set up SSL context.
2842 auto sslContext = std::make_shared<SSLContext>();
2843
2844 // connect
2845 auto socket =
2846 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2847 socket->enableTFO();
2848 socket->open();
2849
2850 // write()
2851 std::array<uint8_t, 128> buf;
2852 memset(buf.data(), 'a', buf.size());
2853 socket->write(buf.data(), buf.size());
2854
2855 // read()
2856 std::array<uint8_t, 128> readbuf;
2857 uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2858 EXPECT_EQ(bytesRead, 128);
2859 EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
2860
2861 // close()
2862 socket->close();
2863 }
2864
2865 class ConnCallback : public AsyncSocket::ConnectCallback {
2866 public:
2867 void connectSuccess() noexcept override { state = State::SUCCESS; }
2868
2869 void connectErr(const AsyncSocketException& ex) noexcept override {
2870 state = State::ERROR;
2871 error = ex.what();
2872 }
2873
2874 enum class State { WAITING, SUCCESS, ERROR };
2875
2876 State state{State::WAITING};
2877 std::string error;
2878 };
2879
2880 template <class Cardinality>
2881 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
2882 EventBase* evb, const SocketAddress& address, Cardinality cardinality) {
2883 // Set up SSL context.
2884 auto sslContext = std::make_shared<SSLContext>();
2885
2886 // connect
2887 auto socket = MockAsyncTFOSSLSocket::UniquePtr(
2888 new MockAsyncTFOSSLSocket(sslContext, evb));
2889 socket->enableTFO();
2890
2891 EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2892 .Times(cardinality)
2893 .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
2894 sockaddr_storage addr;
2895 auto len = address.getAddress(&addr);
2896 return netops::connect(fd, (const struct sockaddr*)&addr, len);
2897 }));
2898 return socket;
2899 }
2900
2901 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
2902 if (!folly::test::isTFOAvailable()) {
2903 GTEST_SKIP() << "TFO not supported.";
2904 }
2905
2906 // Start listening on a local port
2907 WriteCallbackBase writeCallback;
2908 ReadCallback readCallback(&writeCallback);
2909 HandshakeCallback handshakeCallback(&readCallback);
2910 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2911 TestSSLServer server(&acceptCallback, true);
2912
2913 EventBase evb;
2914
2915 auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
2916 ConnCallback ccb;
2917 socket->connect(&ccb, server.getAddress(), 30);
2918
2919 evb.loop();
2920 EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
2921
2922 evb.runInEventBaseThread([&] { socket->detachEventBase(); });
2923 evb.loop();
2924
2925 BlockingSocket sock(std::move(socket));
2926 // write()
2927 std::array<uint8_t, 128> buf;
2928 memset(buf.data(), 'a', buf.size());
2929 sock.write(buf.data(), buf.size());
2930
2931 // read()
2932 std::array<uint8_t, 128> readbuf;
2933 uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
2934 EXPECT_EQ(bytesRead, 128);
2935 EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
2936
2937 // close()
2938 sock.close();
2939 }
2940
2941 #if !defined(OPENSSL_IS_BORINGSSL)
2942 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
2943 // Start listening on a local port
2944 ConnectTimeoutCallback acceptCallback;
2945 TestSSLServer server(&acceptCallback, true);
2946
2947 // Set up SSL context.
2948 auto sslContext = std::make_shared<SSLContext>();
2949
2950 // connect
2951 auto socket =
2952 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2953 socket->enableTFO();
2954 EXPECT_THROW(
2955 socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
2956 }
2957 #endif
2958
2959 #if !defined(OPENSSL_IS_BORINGSSL)
2960 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
2961 // Start listening on a local port
2962 ConnectTimeoutCallback acceptCallback;
2963 TestSSLServer server(&acceptCallback, true);
2964
2965 EventBase evb;
2966
2967 auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
2968 ConnCallback ccb;
2969 // Set a short timeout
2970 socket->connect(&ccb, server.getAddress(), 1);
2971
2972 evb.loop();
2973 EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2974 }
2975 #endif
2976
2977 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
2978 // Start listening on a local port
2979 EmptyReadCallback readCallback;
2980 HandshakeCallback handshakeCallback(
2981 &readCallback, HandshakeCallback::EXPECT_ERROR);
2982 HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
2983 TestSSLServer server(&acceptCallback, true);
2984
2985 EventBase evb;
2986
2987 auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
2988 ConnCallback ccb;
2989 socket->connect(&ccb, server.getAddress(), 100);
2990
2991 evb.loop();
2992 EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2993 EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
2994 }
2995
2996 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
2997 // Start listening on a local port
2998 EventBase evb;
2999
3000 // Hopefully nothing is listening on this address
3001 SocketAddress addr("127.0.0.1", 65535);
3002 auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
3003 ConnCallback ccb;
3004 socket->connect(&ccb, addr, 100);
3005
3006 evb.loop();
3007 EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
3008 EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
3009 }
3010
3011 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
3012 EventBase eventBase;
3013 auto clientCtx = std::make_shared<SSLContext>();
3014 auto dfServerCtx = std::make_shared<SSLContext>();
3015 std::array<NetworkSocket, 2> fds;
3016 getfds(fds.data());
3017 getctx(clientCtx, dfServerCtx);
3018
3019 AsyncSSLSocket::UniquePtr clientSockPtr(
3020 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
3021 AsyncSSLSocket::UniquePtr serverSockPtr(
3022 new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
3023 auto clientSock = clientSockPtr.get();
3024 auto serverSock = serverSockPtr.get();
3025 SSLHandshakeClient client(std::move(clientSockPtr), true, true);
3026
3027 // Steal some data from the server.
3028 std::array<uint8_t, 10> buf;
3029 auto bytesReceived = netops::recv(fds[1], buf.data(), buf.size(), 0);
3030 checkUnixError(bytesReceived, "recv failed");
3031
3032 serverSock->setPreReceivedData(
3033 IOBuf::wrapBuffer(ByteRange(buf.data(), bytesReceived)));
3034 SSLHandshakeServer server(std::move(serverSockPtr), true, true);
3035 while (!client.handshakeSuccess_ && !client.handshakeError_) {
3036 eventBase.loopOnce();
3037 }
3038
3039 EXPECT_TRUE(client.handshakeSuccess_);
3040 EXPECT_TRUE(server.handshakeSuccess_);
3041 EXPECT_EQ(
3042 serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
3043 }
3044
3045 TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
3046 EventBase eventBase;
3047 auto clientCtx = std::make_shared<SSLContext>();
3048 auto dfServerCtx = std::make_shared<SSLContext>();
3049 std::array<NetworkSocket, 2> fds;
3050 getfds(fds.data());
3051 getctx(clientCtx, dfServerCtx);
3052
3053 AsyncSSLSocket::UniquePtr clientSockPtr(
3054 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
3055 AsyncSocket::UniquePtr serverSockPtr(new AsyncSocket(&eventBase, fds[1]));
3056 auto clientSock = clientSockPtr.get();
3057 auto serverSock = serverSockPtr.get();
3058 SSLHandshakeClient client(std::move(clientSockPtr), true, true);
3059
3060 // Steal some data from the server.
3061 std::array<uint8_t, 10> buf;
3062 auto bytesReceived = netops::recv(fds[1], buf.data(), buf.size(), 0);
3063 checkUnixError(bytesReceived, "recv failed");
3064
3065 serverSock->setPreReceivedData(
3066 IOBuf::wrapBuffer(ByteRange(buf.data(), bytesReceived)));
3067 AsyncSSLSocket::UniquePtr serverSSLSockPtr(
3068 new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
3069 auto serverSSLSock = serverSSLSockPtr.get();
3070 SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
3071 while (!client.handshakeSuccess_ && !client.handshakeError_) {
3072 eventBase.loopOnce();
3073 }
3074
3075 EXPECT_TRUE(client.handshakeSuccess_);
3076 EXPECT_TRUE(server.handshakeSuccess_);
3077 EXPECT_EQ(
3078 serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
3079 }
3080
3081 /**
3082 * Test overriding the flags passed to "sendmsg()" system call,
3083 * and verifying that write requests fail properly.
3084 */
3085 TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
3086 // Start listening on a local port
3087 SendMsgFlagsCallback msgCallback;
3088 ExpectWriteErrorCallback writeCallback(&msgCallback);
3089 ReadCallback readCallback(&writeCallback);
3090 HandshakeCallback handshakeCallback(&readCallback);
3091 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
3092 TestSSLServer server(&acceptCallback);
3093
3094 // Set up SSL context.
3095 auto sslContext = std::make_shared<SSLContext>();
3096 sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
3097
3098 // connect
3099 auto socket =
3100 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
3101 socket->open();
3102
3103 // Setting flags to "-1" to trigger "Invalid argument" error
3104 // on attempt to use this flags in sendmsg() system call.
3105 msgCallback.resetFlags(-1);
3106
3107 // write()
3108 std::vector<uint8_t> buf(128, 'a');
3109 ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
3110
3111 // close()
3112 socket->close();
3113
3114 cerr << "SendMsgParamsCallback test completed" << endl;
3115 }
3116
3117 #if FOLLY_HAVE_SO_TIMESTAMPING
3118
/**
 * Fixture for ByteEvent (socket timestamping) tests on AsyncSSLSocket.
 *
 * SetUp() starts a TestSSLServer; getClientConn() returns a ClientConn
 * helper wired to that server, with a mock netops dispatcher installed on
 * the client socket so tests can set expectations on setsockopt()/sendmsg().
 */
class AsyncSSLSocketByteEventTest : public ::testing::Test {
 protected:
  using MockDispatcher = ::testing::NiceMock<netops::test::MockDispatcher>;
  using TestObserver = test::MockAsyncTransportObserverForByteEvents;
  using ByteEventType = AsyncTransport::ByteEvent::Type;

  /**
   * Components of a client connection to the TestSSLServer.
   *
   * Holds the server handle, the client's AsyncSSLSocket, its connect/read
   * callbacks, the mock netops dispatcher and any attached observers.
   */
  class ClientConn {
   public:
    /**
     * @param server  Server this connection will connect() to.
     * @param socket  Optional pre-built client socket; when null, a new one
     *                is created from getSslContext() and getEventBase().
     */
    explicit ClientConn(
        std::shared_ptr<TestSSLServer> server,
        std::shared_ptr<AsyncSSLSocket> socket = nullptr)
        : server_(std::move(server)), socket_(std::move(socket)) {
      if (!socket_) {
        socket_ = AsyncSSLSocket::newSocket(getSslContext(), &getEventBase());
      }
      // Route the socket's netops calls through the mock dispatcher.
      socket_->setOverrideNetOpsDispatcher(netOpsDispatcher_);
      netOpsDispatcher_->forwardToDefaultImpl();
    }

    // Closes the client socket if one is still held.
    ~ClientConn() {
      if (socket_) {
        socket_->close();
      }
    }

    /**
     * Connects to the server, runs the socket's EventBase loop, asserts the
     * connect callback reports SUCCESS, then installs the read callback.
     */
    void connect() {
      CHECK_NOTNULL(socket_.get());
      CHECK_NOTNULL(socket_->getEventBase());
      socket_->connect(&connCb_, server_->getAddress(), 30);
      socket_->getEventBase()->loop();
      ASSERT_EQ(connCb_.state, ConnCallback::State::SUCCESS);
      setReadCb();
    }

    // Installs readCb_ on the socket and gives the callback the socket.
    void setReadCb() {
      // Due to how libevent works, we currently need to be subscribed to
      // EV_READ events in order to get error messages.
      //
      // TODO(bschlinker): Resolve this with libevent modification.
      // See https://github.com/libevent/libevent/issues/1038 for details.
      socket_->setReadCB(&readCb_);
      readCb_.setSocket(socket_);
    }

    /**
     * Attaches a new observer to the client socket (via the fixture's static
     * attachObserver()), keeps a reference to it in `observers`, and returns
     * it.
     */
    std::shared_ptr<NiceMock<TestObserver>> attachObserver(
        bool enableByteEvents) {
      auto observer = AsyncSSLSocketByteEventTest::attachObserver(
          socket_.get(), enableByteEvents);
      observers.push_back(observer);
      return observer;
    }

    /**
     * Write to client socket, echo at server, and wait for echo at client.
     *
     * Waiting for echo at client ensures that we have given opportunity for
     * timestamps to be generated by the kernel.
     */
    void writeAndReflect(
        const std::vector<uint8_t>& wbuf, const WriteFlags writeFlags) {
      CHECK_NOTNULL(socket_.get());
      CHECK_NOTNULL(socket_->getEventBase());

      // write to the client socket
      WriteCallbackBase wcb;
      socket_->write(&wcb, wbuf.data(), wbuf.size(), writeFlags);
      // drive the loop until the write callback leaves STATE_WAITING
      while (wcb.state == STATE_WAITING) {
        socket_->getEventBase()->loopOnce();
      }
      ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

      // TestSSLServer reads and reflects for us

      // read reflection at client
      while (wbuf.size() != readCb_.dataRead()) {
        socket_->getEventBase()->loopOnce();
      }
      readCb_.verifyData(wbuf.data(), wbuf.size());
      // reset so the next writeAndReflect() starts from an empty read buffer
      readCb_.clearData();
    }

    // Returns the client socket (shared ownership, despite the name).
    std::shared_ptr<AsyncSSLSocket> getRawSocket() { return socket_; }

    // Returns the SSLContext, created once (function-local static) and
    // reused on every call.
    std::shared_ptr<SSLContext> getSslContext() {
      static std::shared_ptr<SSLContext> sslContext = initSslContext();
      return sslContext;
    }

    EventBase& getEventBase() {
      static EventBase evb; // use same EventBase for all client sockets
      return evb;
    }

    // Expects exactly one setsockopt(SOL_SOCKET, SO_TIMESTAMPING) call.
    void netOpsExpectTimestampingSetSockOpt() {
      // must allow any other setsockopt calls
      EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
          .Times(AnyNumber());
      EXPECT_CALL(
          *netOpsDispatcher_, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
          .Times(1);
    }

    // Expects that setsockopt(SOL_SOCKET, SO_TIMESTAMPING) is never called.
    void netOpsExpectNoTimestampingSetSockOpt() {
      // must allow any other setsockopt calls
      EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
          .Times(AnyNumber());
      EXPECT_CALL(
          *netOpsDispatcher_, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
          .Times(0);
    }

    /**
     * Expects one sendmsg() whose control message decodes (through
     * getMsgWriteFlags()) to `writeFlags`; the call is then passed on to the
     * base Dispatcher's sendmsg().
     */
    void netOpsExpectWriteWithFlags(WriteFlags writeFlags) {
      EXPECT_CALL(*netOpsDispatcher_, sendmsg(_, _, _))
          .WillOnce(Invoke(
              [this, writeFlags](
                  NetworkSocket socket, const msghdr* message, int flags) {
                EXPECT_EQ(writeFlags, getMsgWriteFlags(*message));
                return netOpsDispatcher_->netops::Dispatcher::sendmsg(
                    socket, message, flags);
              }));
    }

    // Verifies and clears all expectations set on the mock dispatcher.
    void netOpsVerifyAndClearExpectations() {
      Mock::VerifyAndClearExpectations(netOpsDispatcher_.get());
    }

    /**
     * Static utilities.
     */

    // Builds the SSLContext used by client sockets.
    static std::shared_ptr<SSLContext> initSslContext() {
      auto sslContext = std::make_shared<SSLContext>();
      sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
      return sslContext;
    }

   private:
    // server
    std::shared_ptr<TestSSLServer> server_;

    // managed observers
    std::vector<std::shared_ptr<TestObserver>> observers;

    // socket components
    ConnCallback connCb_;
    ReadCallback readCb_;
    std::shared_ptr<MockDispatcher> netOpsDispatcher_{
        std::make_shared<MockDispatcher>()};
    std::shared_ptr<AsyncSSLSocket> socket_;
  };

  // Returns a ClientConn bound to the fixture's server (not yet connected).
  ClientConn getClientConn() { return ClientConn(server_); }

  // Builds the server-side callback chain (write -> read -> handshake ->
  // accept) and starts the TestSSLServer with it.
  void SetUp() override {
    serverWriteCb_ = std::make_unique<WriteCallbackBase>();
    serverReadCb_ = std::make_unique<ReadCallback>(serverWriteCb_.get());
    serverHandshakeCb_ =
        std::make_unique<HandshakeCallback>(serverReadCb_.get());
    serverAcceptCb_ =
        std::make_unique<SSLServerAcceptCallback>(serverHandshakeCb_.get());
    server_ = std::make_shared<TestSSLServer>(serverAcceptCb_.get());
  }

  /**
   * Static utility functions.
   */

  // Creates an observer for `socket` with config.byteEvents set to
  // `enableByteEvents`.
  static std::shared_ptr<NiceMock<TestObserver>> attachObserver(
      AsyncSocket* socket, bool enableByteEvents) {
    AsyncTransport::LifecycleObserver::Config config = {};
    config.byteEvents = enableByteEvents;
    return std::make_shared<NiceMock<TestObserver>>(socket, config);
  }

  /**
   * Decodes the timestamping WriteFlags carried in the first control message
   * of `msg`.
   *
   * Returns WriteFlags::NONE when there is no control message or when the
   * first one is not a SOL_SOCKET / SO_TIMESTAMPING message holding exactly
   * one uint32_t.
   */
  static WriteFlags getMsgWriteFlags(const struct msghdr& msg) {
    const struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
    if (!cmsg || cmsg->cmsg_level != SOL_SOCKET ||
        cmsg->cmsg_type != SO_TIMESTAMPING ||
        cmsg->cmsg_len != CMSG_LEN(sizeof(uint32_t))) {
      return WriteFlags::NONE;
    }

    // Map each SOF_TIMESTAMPING_TX_* bit to the matching WriteFlags value.
    const uint32_t* sofFlags =
        (reinterpret_cast<const uint32_t*>(CMSG_DATA(cmsg)));
    WriteFlags flags = WriteFlags::NONE;
    if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SCHED) {
      flags = flags | WriteFlags::TIMESTAMP_SCHED;
    }
    if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE) {
      flags = flags | WriteFlags::TIMESTAMP_TX;
    }
    if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_ACK) {
      flags = flags | WriteFlags::TIMESTAMP_ACK;
    }

    return flags;
  }

  // Returns `writeFlags` with TIMESTAMP_WRITE cleared. getMsgWriteFlags()
  // never reports TIMESTAMP_WRITE, so tests use this to build the value they
  // expect to see at sendmsg().
  static WriteFlags dropWriteFromFlags(WriteFlags writeFlags) {
    return writeFlags & ~WriteFlags::TIMESTAMP_WRITE;
  }

  // server components
  std::unique_ptr<WriteCallbackBase> serverWriteCb_;
  std::unique_ptr<ReadCallback> serverReadCb_;
  std::unique_ptr<HandshakeCallback> serverHandshakeCb_;
  std::unique_ptr<SSLServerAcceptCallback> serverAcceptCb_;
  std::shared_ptr<TestSSLServer> server_;
};
3332
3333 TEST_F(AsyncSSLSocketByteEventTest, ObserverAttachedBeforeConnect) {
3334 const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
3335 WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
3336 const std::vector<uint8_t> wbuf(1, 'a');
3337
3338 auto clientConn = getClientConn();
3339 auto observer = clientConn.attachObserver(true /* enableByteEvents */);
3340 clientConn.netOpsExpectTimestampingSetSockOpt();
3341 clientConn.connect();
3342 EXPECT_EQ(1, observer->byteEventsEnabledCalled);
3343 EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
3344 EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
3345 clientConn.netOpsVerifyAndClearExpectations();
3346
3347 {
3348 clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3349 clientConn.writeAndReflect(wbuf, flags);
3350 clientConn.netOpsVerifyAndClearExpectations();
3351
3352 // may have more than four new ByteEvents if write split further by kernel
3353 EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
3354
3355 // due to SSL overhead, offset will not be 0
3356 auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3357 EXPECT_EQ(
3358 offsetExpected,
3359 observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3360 EXPECT_EQ(
3361 offsetExpected,
3362 observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3363 EXPECT_EQ(
3364 offsetExpected,
3365 observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3366 EXPECT_EQ(
3367 offsetExpected,
3368 observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3369 }
3370
3371 // write again to check offsets
3372 {
3373 const auto startNumByteEvents = observer->byteEvents.size();
3374 clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3375 clientConn.writeAndReflect(wbuf, flags);
3376 clientConn.netOpsVerifyAndClearExpectations();
3377
3378 // may have more than four new ByteEvents if write split further by kernel
3379 EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));
3380
3381 // due to SSL overhead, offset will not be 1
3382 auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3383 EXPECT_EQ(
3384 offsetExpected,
3385 observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3386 EXPECT_EQ(
3387 offsetExpected,
3388 observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3389 EXPECT_EQ(
3390 offsetExpected,
3391 observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3392 EXPECT_EQ(
3393 offsetExpected,
3394 observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3395 }
3396 }
3397
3398 TEST_F(AsyncSSLSocketByteEventTest, ObserverAttachedAfterConnect) {
3399 const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
3400 WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
3401 const std::vector<uint8_t> wbuf(1, 'a');
3402
3403 auto clientConn = getClientConn();
3404 clientConn.netOpsExpectNoTimestampingSetSockOpt();
3405 clientConn.connect();
3406 clientConn.netOpsVerifyAndClearExpectations();
3407
3408 clientConn.netOpsExpectTimestampingSetSockOpt();
3409 auto observer = clientConn.attachObserver(true /* enableByteEvents */);
3410 EXPECT_EQ(1, observer->byteEventsEnabledCalled);
3411 EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
3412 EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
3413 clientConn.netOpsVerifyAndClearExpectations();
3414
3415 {
3416 clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3417 clientConn.writeAndReflect(wbuf, flags);
3418 clientConn.netOpsVerifyAndClearExpectations();
3419
3420 // may have more than four new ByteEvents if write split further by kernel
3421 EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
3422
3423 // due to SSL overhead, offset will not be 0
3424 auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3425 EXPECT_EQ(
3426 offsetExpected,
3427 observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3428 EXPECT_EQ(
3429 offsetExpected,
3430 observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3431 EXPECT_EQ(
3432 offsetExpected,
3433 observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3434 EXPECT_EQ(
3435 offsetExpected,
3436 observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3437 }
3438
3439 // write again to check offsets
3440 {
3441 const auto startNumByteEvents = observer->byteEvents.size();
3442 clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3443 clientConn.writeAndReflect(wbuf, flags);
3444 clientConn.netOpsVerifyAndClearExpectations();
3445
3446 // may have more than four new ByteEvents if write split further by kernel
3447 EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));
3448
3449 // due to SSL overhead, offset will not be 1
3450 auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3451 EXPECT_EQ(
3452 offsetExpected,
3453 observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3454 EXPECT_EQ(
3455 offsetExpected,
3456 observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3457 EXPECT_EQ(
3458 offsetExpected,
3459 observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3460 EXPECT_EQ(
3461 offsetExpected,
3462 observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3463 }
3464 }
3465
3466 TEST_F(AsyncSSLSocketByteEventTest, MultiByteWrites) {
3467 const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
3468 WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
3469
3470 auto clientConn = getClientConn();
3471 auto observer = clientConn.attachObserver(true /* enableByteEvents */);
3472 clientConn.netOpsExpectTimestampingSetSockOpt();
3473 clientConn.connect();
3474 EXPECT_EQ(1, observer->byteEventsEnabledCalled);
3475 EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
3476 EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
3477 clientConn.netOpsVerifyAndClearExpectations();
3478
3479 // write 20 bytes
3480 {
3481 std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes
3482
3483 clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3484 clientConn.writeAndReflect(wbuf, flags);
3485 clientConn.netOpsVerifyAndClearExpectations();
3486
3487 // may have more than four new ByteEvents if write split further by kernel
3488 EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
3489
3490 // due to SSL overhead, offset will not be 0
3491 auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3492 EXPECT_EQ(
3493 offsetExpected,
3494 observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3495 EXPECT_EQ(
3496 offsetExpected,
3497 observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3498 EXPECT_EQ(
3499 offsetExpected,
3500 observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3501 EXPECT_EQ(
3502 offsetExpected,
3503 observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3504 }
3505
3506 // write 40 bytes
3507 {
3508 std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes
3509
3510 const auto startNumByteEvents = observer->byteEvents.size();
3511 clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3512 clientConn.writeAndReflect(wbuf, flags);
3513 clientConn.netOpsVerifyAndClearExpectations();
3514
3515 // may have more than four new ByteEvents if write split further by kernel
3516 EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));
3517
3518 // due to SSL overhead, offset will not be 1
3519 auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3520 EXPECT_EQ(
3521 offsetExpected,
3522 observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3523 EXPECT_EQ(
3524 offsetExpected,
3525 observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3526 EXPECT_EQ(
3527 offsetExpected,
3528 observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3529 EXPECT_EQ(
3530 offsetExpected,
3531 observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3532 }
3533 }
3534
3535 TEST_F(AsyncSSLSocketByteEventTest, MultiByteWritesEnableSecondWrite) {
3536 const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
3537 WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
3538
3539 auto clientConn = getClientConn();
3540 clientConn.netOpsExpectNoTimestampingSetSockOpt();
3541 clientConn.connect();
3542 clientConn.netOpsVerifyAndClearExpectations();
3543
3544 // write 20 bytes with no ByteEvents / observer
3545 {
3546 std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes
3547 clientConn.netOpsExpectWriteWithFlags(WriteFlags::NONE);
3548 clientConn.writeAndReflect(wbuf, flags);
3549 clientConn.netOpsVerifyAndClearExpectations();
3550 }
3551
3552 // enable observer
3553 clientConn.netOpsExpectTimestampingSetSockOpt();
3554 auto observer = clientConn.attachObserver(true /* enableByteEvents */);
3555 EXPECT_EQ(1, observer->byteEventsEnabledCalled);
3556 EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
3557 EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
3558 clientConn.netOpsVerifyAndClearExpectations();
3559
3560 // write 40 bytes
3561 {
3562 std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes
3563
3564 const auto startNumByteEvents = observer->byteEvents.size();
3565 clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
3566 clientConn.writeAndReflect(wbuf, flags);
3567 clientConn.netOpsVerifyAndClearExpectations();
3568
3569 // may have more than four new ByteEvents if write split further by kernel
3570 EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));
3571
3572 // due to SSL overhead, offset will not be 1
3573 auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
3574 EXPECT_EQ(
3575 offsetExpected,
3576 observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
3577 EXPECT_EQ(
3578 offsetExpected,
3579 observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
3580 EXPECT_EQ(
3581 offsetExpected,
3582 observer->maxOffsetForByteEventReceived(ByteEventType::TX));
3583 EXPECT_EQ(
3584 offsetExpected,
3585 observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
3586 }
3587 }
3588
3589 #endif // FOLLY_HAVE_SO_TIMESTAMPING
3590
3591 #endif // __linux__
3592
3593 TEST(AsyncSSLSocketTest, TestSNIClientHelloBehavior) {
3594 EventBase eventBase;
3595 auto serverCtx = std::make_shared<SSLContext>();
3596 auto clientCtx = std::make_shared<SSLContext>();
3597 serverCtx->loadPrivateKey(kTestKey);
3598 serverCtx->loadCertificate(kTestCert);
3599
3600 clientCtx->setSessionCacheContext("test context");
3601 clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
3602 std::shared_ptr<folly::ssl::SSLSession> resumptionSession = nullptr;
3603
3604 {
3605 std::array<NetworkSocket, 2> fds;
3606 getfds(fds.data());
3607
3608 AsyncSSLSocket::UniquePtr clientSock(
3609 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
3610 AsyncSSLSocket::UniquePtr serverSock(
3611 new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
3612
3613 // Client sends SNI that doesn't match anything the server cert advertises
3614 clientSock->setServerName("Foobar");
3615
3616 SSLHandshakeServerParseClientHello server(
3617 std::move(serverSock), true, true);
3618 SSLHandshakeClient client(std::move(clientSock), true, true);
3619 eventBase.loop();
3620
3621 serverSock = std::move(server).moveSocket();
3622 auto chi = serverSock->getClientHelloInfo();
3623 ASSERT_NE(chi, nullptr);
3624 EXPECT_EQ(
3625 std::string("Foobar"), std::string(serverSock->getSSLServerName()));
3626
3627 // create another client, resuming with the prior session, but under a
3628 // different common name.
3629 clientSock = std::move(client).moveSocket();
3630 resumptionSession = clientSock->getSSLSession();
3631 }
3632
3633 {
3634 std::array<NetworkSocket, 2> fds;
3635 getfds(fds.data());
3636
3637 AsyncSSLSocket::UniquePtr clientSock(
3638 new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
3639 AsyncSSLSocket::UniquePtr serverSock(
3640 new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
3641
3642 clientSock->setSSLSession(resumptionSession);
3643 clientSock->setServerName("Baz");
3644 SSLHandshakeServerParseClientHello server(
3645 std::move(serverSock), true, true);
3646 SSLHandshakeClient client(std::move(clientSock), true, true);
3647 eventBase.loop();
3648
3649 serverSock = std::move(server).moveSocket();
3650 clientSock = std::move(client).moveSocket();
3651 EXPECT_TRUE(clientSock->getSSLSessionReused());
3652
3653 // OpenSSL 1.1.1 changes the semantics of SSL_get_servername
3654 // in
3655 // https://github.com/openssl/openssl/commit/1c4aa31d79821dee9be98e915159d52cc30d8403
3656 //
3657 // Previously, the SNI would be taken from the ClientHello.
3658 // Now, the SNI will be taken from the established session.
3659 //
3660 // But the session that was established with the client (prior handshake)
3661 // would not have set the server name field because the SNI that the client
3662 // requested ("Foobar") did not match any of the SANs that the server was
3663 // presenting ("127.0.0.1")
3664 //
3665 // To preserve this 1.1.0 behavior, getSSLServerName() should return the
3666 // parsed ClientHello servername. This test asserts this behavior.
3667 auto sni = serverSock->getSSLServerName();
3668 ASSERT_NE(sni, nullptr);
3669
3670 std::string sniStr(sni);
3671 EXPECT_EQ(sniStr, std::string("Baz"));
3672 }
3673 }
3674
3675 TEST(AsyncSSLSocketTest, BytesWrittenWithMove) {
3676 WriteCallbackBase writeCallback;
3677 ReadCallback readCallback(&writeCallback);
3678 HandshakeCallback handshakeCallback(&readCallback);
3679 SSLServerAcceptCallback acceptCallback(&handshakeCallback);
3680 TestSSLServer server(&acceptCallback);
3681
3682 auto sslContext = std::make_shared<SSLContext>();
3683 sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
3684 auto socket1 =
3685 std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
3686 socket1->open(std::chrono::milliseconds(10000));
3687
3688 // write
3689 std::vector<uint8_t> wbuf(128, 'a');
3690 socket1->write(wbuf.data(), wbuf.size());
3691 const auto socket1AppBytes = socket1->getSocket()->getAppBytesWritten();
3692 const auto socket1RawBytes = socket1->getSocket()->getRawBytesWritten();
3693 EXPECT_EQ(128, socket1AppBytes);
3694 EXPECT_LT(128, socket1RawBytes);
3695
3696 // read reflection
3697 std::vector<uint8_t> readbuf(wbuf.size());
3698 uint32_t bytesRead = socket1->readAll(readbuf.data(), readbuf.size());
3699 EXPECT_EQ(bytesRead, wbuf.size());
3700
3701 // additional sanity checks on virtuals
3702 EXPECT_EQ(
3703 socket1->getSSLSocket()->getRawBytesWritten(),
3704 socket1->getSocket()->getRawBytesWritten());
3705 EXPECT_EQ(128, socket1->getSocket()->getAppBytesWritten());
3706 EXPECT_EQ(128, socket1->getSSLSocket()->getAppBytesWritten());
3707
3708 // move to another AsyncSSLSocket
3709 AsyncSSLSocket::UniquePtr socket2(
3710 new AsyncSSLSocket(sslContext, socket1->getSocket()));
3711 EXPECT_EQ(socket1AppBytes, socket2->getAppBytesWritten());
3712 EXPECT_EQ(socket1RawBytes, socket2->getRawBytesWritten());
3713
3714 // move to an AsyncSocket
3715 AsyncSocket::UniquePtr socket3(new AsyncSocket(std::move(socket2)));
3716 EXPECT_EQ(socket1AppBytes, socket3->getAppBytesWritten());
3717 EXPECT_EQ(socket1RawBytes, socket3->getRawBytesWritten());
3718 }
3719
3720 } // namespace folly
3721
3722 #ifdef SIGPIPE
3723 ///////////////////////////////////////////////////////////////////////////
3724 // init_unit_test_suite
3725 ///////////////////////////////////////////////////////////////////////////
namespace {
// Sets the process-wide SIGPIPE disposition to "ignore" when constructed.
// Presumably this keeps writes to a peer-closed socket from terminating the
// test binary, so they surface as write errors instead.
struct Initializer {
  Initializer() {
    signal(SIGPIPE, SIG_IGN);
  }
};

// Namespace-scope instance: its constructor runs during static
// initialization, so SIGPIPE is already ignored when main() starts.
Initializer initializer;
} // namespace
3732 #endif
3733