1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <folly/io/async/test/AsyncSocketTest2.h>
18
#include <fcntl.h>
#include <sys/types.h>

#include <time.h>
#include <cstring>
#include <iostream>
#include <memory>
#include <thread>
26
27 #include <folly/ExceptionWrapper.h>
28 #include <folly/Random.h>
29 #include <folly/SocketAddress.h>
30 #include <folly/experimental/TestUtil.h>
31 #include <folly/io/IOBuf.h>
32 #include <folly/io/SocketOptionMap.h>
33 #include <folly/io/async/AsyncTimeout.h>
34 #include <folly/io/async/EventBase.h>
35 #include <folly/io/async/ScopedEventBaseThread.h>
36 #include <folly/io/async/test/AsyncSocketTest.h>
37 #include <folly/io/async/test/MockAsyncSocketObserver.h>
38 #include <folly/io/async/test/MockAsyncTransportObserver.h>
39 #include <folly/io/async/test/TFOTest.h>
40 #include <folly/io/async/test/Util.h>
41 #include <folly/net/test/MockNetOpsDispatcher.h>
42 #include <folly/portability/GMock.h>
43 #include <folly/portability/GTest.h>
44 #include <folly/portability/Sockets.h>
45 #include <folly/portability/Unistd.h>
46 #include <folly/synchronization/Baton.h>
47 #include <folly/test/SocketAddressTestHelper.h>
48
49 using std::min;
50 using std::string;
51 using std::unique_ptr;
52 using std::vector;
53 using std::chrono::milliseconds;
54 using testing::MatchesRegex;
55
56 using namespace folly;
57 using namespace folly::test;
58 using namespace testing;
59
60 namespace {
61 // string and corresponding vector with 100 characters
62 const std::string kOneHundredCharacterString(
63 "ThisIsAVeryLongStringThatHas100Characters"
64 "AndIsUniqueEnoughToBeInterestingForTestUsageNowEndOfMessage");
65 const std::vector<uint8_t> kOneHundredCharacterVec(
66 kOneHundredCharacterString.begin(), kOneHundredCharacterString.end());
67
msgFlagsToWriteFlags(const int msg_flags)68 WriteFlags msgFlagsToWriteFlags(const int msg_flags) {
69 WriteFlags flags = WriteFlags::NONE;
70 #ifdef MSG_MORE
71 if (msg_flags & MSG_MORE) {
72 flags = flags | WriteFlags::CORK;
73 }
74 #endif // MSG_MORE
75
76 #ifdef MSG_EOR
77 if (msg_flags & MSG_EOR) {
78 flags = flags | WriteFlags::EOR;
79 }
80 #endif
81
82 #ifdef MSG_ZEROCOPY
83 if (msg_flags & MSG_ZEROCOPY) {
84 flags = flags | WriteFlags::WRITE_MSG_ZEROCOPY;
85 }
86 #endif
87 return flags;
88 }
89
getMsgAncillaryTsFlags(const struct msghdr & msg)90 WriteFlags getMsgAncillaryTsFlags(const struct msghdr& msg) {
91 const struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
92 if (!cmsg || cmsg->cmsg_level != SOL_SOCKET ||
93 cmsg->cmsg_type != SO_TIMESTAMPING ||
94 cmsg->cmsg_len != CMSG_LEN(sizeof(uint32_t))) {
95 return WriteFlags::NONE;
96 }
97
98 const uint32_t* sofFlags =
99 (reinterpret_cast<const uint32_t*>(CMSG_DATA(cmsg)));
100 WriteFlags flags = WriteFlags::NONE;
101 if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SCHED) {
102 flags = flags | WriteFlags::TIMESTAMP_SCHED;
103 }
104 if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE) {
105 flags = flags | WriteFlags::TIMESTAMP_TX;
106 }
107 if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_ACK) {
108 flags = flags | WriteFlags::TIMESTAMP_ACK;
109 }
110
111 return flags;
112 }
113
getMsgAncillaryTsFlags(const struct msghdr * msg)114 WriteFlags getMsgAncillaryTsFlags(const struct msghdr* msg) {
115 return getMsgAncillaryTsFlags(*msg);
116 }
117
118 MATCHER_P(SendmsgMsghdrHasTotalIovLen, len, "") {
119 size_t iovLen = 0;
120 for (size_t i = 0; i < arg.msg_iovlen; i++) {
121 iovLen += arg.msg_iov[i].iov_len;
122 }
123 return len == iovLen;
124 }
125
126 MATCHER_P(SendmsgInvocHasTotalIovLen, len, "") {
127 size_t iovLen = 0;
128 for (const auto& iov : arg.iovs) {
129 iovLen += iov.iov_len;
130 }
131 return len == iovLen;
132 }
133
134 MATCHER_P(SendmsgInvocHasIovFirstByte, firstBytePtr, "") {
135 if (arg.iovs.empty()) {
136 return false;
137 }
138
139 const auto& firstIov = arg.iovs.front();
140 auto iovFirstBytePtr = const_cast<void*>(
141 static_cast<const void*>(reinterpret_cast<uint8_t*>(firstIov.iov_base)));
142 return firstBytePtr == iovFirstBytePtr;
143 }
144
145 MATCHER_P(SendmsgInvocHasIovLastByte, lastBytePtr, "") {
146 if (arg.iovs.empty()) {
147 return false;
148 }
149
150 const auto& lastIov = arg.iovs.back();
151 auto iovLastBytePtr = const_cast<void*>(static_cast<const void*>(
152 reinterpret_cast<uint8_t*>(lastIov.iov_base) + lastIov.iov_len - 1));
153 return lastBytePtr == iovLastBytePtr;
154 }
155
156 MATCHER_P(SendmsgInvocMsgFlagsEq, writeFlags, "") {
157 return writeFlags == arg.writeFlagsInMsgFlags;
158 }
159
160 MATCHER_P(SendmsgInvocAncillaryFlagsEq, writeFlags, "") {
161 return writeFlags == arg.writeFlagsInAncillary;
162 }
163
164 MATCHER_P2(ByteEventMatching, type, offset, "") {
165 if (type != arg.type || (size_t)offset != arg.offset) {
166 return false;
167 }
168 return true;
169 }
170 } // namespace
171
172 class DelayedWrite : public AsyncTimeout {
173 public:
DelayedWrite(const std::shared_ptr<AsyncSocket> & socket,unique_ptr<IOBuf> && bufs,AsyncTransportWrapper::WriteCallback * wcb,bool cork,bool lastWrite=false)174 DelayedWrite(
175 const std::shared_ptr<AsyncSocket>& socket,
176 unique_ptr<IOBuf>&& bufs,
177 AsyncTransportWrapper::WriteCallback* wcb,
178 bool cork,
179 bool lastWrite = false)
180 : AsyncTimeout(socket->getEventBase()),
181 socket_(socket),
182 bufs_(std::move(bufs)),
183 wcb_(wcb),
184 cork_(cork),
185 lastWrite_(lastWrite) {}
186
187 private:
timeoutExpired()188 void timeoutExpired() noexcept override {
189 WriteFlags flags = cork_ ? WriteFlags::CORK : WriteFlags::NONE;
190 socket_->writeChain(wcb_, std::move(bufs_), flags);
191 if (lastWrite_) {
192 socket_->shutdownWrite();
193 }
194 }
195
196 std::shared_ptr<AsyncSocket> socket_;
197 unique_ptr<IOBuf> bufs_;
198 AsyncTransportWrapper::WriteCallback* wcb_;
199 bool cork_;
200 bool lastWrite_;
201 };
202
203 ///////////////////////////////////////////////////////////////////////////
204 // connect() tests
205 ///////////////////////////////////////////////////////////////////////////
206
207 /**
208 * Test connecting to a server
209 */
TEST(AsyncSocketTest,Connect)210 TEST(AsyncSocketTest, Connect) {
211 // Start listening on a local port
212 TestServer server;
213
214 // Connect using a AsyncSocket
215 EventBase evb;
216 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
217 ConnCallback cb;
218 const auto startedAt = std::chrono::steady_clock::now();
219 socket->connect(&cb, server.getAddress(), 30);
220
221 evb.loop();
222 const auto finishedAt = std::chrono::steady_clock::now();
223
224 ASSERT_EQ(cb.state, STATE_SUCCEEDED);
225 EXPECT_LE(0, socket->getConnectTime().count());
226 EXPECT_GE(socket->getConnectStartTime(), startedAt);
227 EXPECT_LE(socket->getConnectStartTime(), socket->getConnectEndTime());
228 EXPECT_LE(socket->getConnectEndTime(), finishedAt);
229 EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
230 }
231
232 enum class TFOState {
233 DISABLED,
234 ENABLED,
235 };
236
237 class AsyncSocketConnectTest : public ::testing::TestWithParam<TFOState> {};
238
getTestingValues()239 std::vector<TFOState> getTestingValues() {
240 std::vector<TFOState> vals;
241 vals.emplace_back(TFOState::DISABLED);
242
243 #if FOLLY_ALLOW_TFO
244 vals.emplace_back(TFOState::ENABLED);
245 #endif
246 return vals;
247 }
248
249 INSTANTIATE_TEST_SUITE_P(
250 ConnectTests,
251 AsyncSocketConnectTest,
252 ::testing::ValuesIn(getTestingValues()));
253
254 /**
255 * Test connecting to a server that isn't listening
256 */
TEST(AsyncSocketTest,ConnectRefused)257 TEST(AsyncSocketTest, ConnectRefused) {
258 EventBase evb;
259
260 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
261
262 // Hopefully nothing is actually listening on this address
263 folly::SocketAddress addr("127.0.0.1", 65535);
264 ConnCallback cb;
265 socket->connect(&cb, addr, 30);
266
267 evb.loop();
268
269 EXPECT_EQ(STATE_FAILED, cb.state);
270 EXPECT_EQ(AsyncSocketException::NOT_OPEN, cb.exception.getType());
271 EXPECT_LE(0, socket->getConnectTime().count());
272 EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
273 }
274
275 /**
276 * Test connection timeout
277 */
TEST(AsyncSocketTest,ConnectTimeout)278 TEST(AsyncSocketTest, ConnectTimeout) {
279 EventBase evb;
280
281 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
282
283 // Try connecting to server that won't respond.
284 //
285 // This depends somewhat on the network where this test is run.
286 // Hopefully this IP will be routable but unresponsive.
287 // (Alternatively, we could try listening on a local raw socket, but that
288 // normally requires root privileges.)
289 auto host = SocketAddressTestHelper::isIPv6Enabled()
290 ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
291 : SocketAddressTestHelper::isIPv4Enabled()
292 ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
293 : nullptr;
294 SocketAddress addr(host, 65535);
295 ConnCallback cb;
296 socket->connect(&cb, addr, 1); // also set a ridiculously small timeout
297
298 evb.loop();
299
300 ASSERT_EQ(cb.state, STATE_FAILED);
301 if (cb.exception.getType() == AsyncSocketException::NOT_OPEN) {
302 // This can happen if we could not route to the IP address picked above.
303 // In this case the connect will fail immediately rather than timing out.
304 // Just skip the test in this case.
305 SKIP() << "do not have a routable but unreachable IP address";
306 }
307 ASSERT_EQ(cb.exception.getType(), AsyncSocketException::TIMED_OUT);
308
309 // Verify that we can still get the peer address after a timeout.
310 // Use case is if the client was created from a client pool, and we want
311 // to log which peer failed.
312 folly::SocketAddress peer;
313 socket->getPeerAddress(&peer);
314 ASSERT_EQ(peer, addr);
315 EXPECT_LE(0, socket->getConnectTime().count());
316 EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(1));
317 }
318
319 /**
320 * Test writing immediately after connecting, without waiting for connect
321 * to finish.
322 */
TEST_P(AsyncSocketConnectTest,ConnectAndWrite)323 TEST_P(AsyncSocketConnectTest, ConnectAndWrite) {
324 TestServer server;
325
326 // connect()
327 EventBase evb;
328 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
329
330 if (GetParam() == TFOState::ENABLED) {
331 socket->enableTFO();
332 }
333
334 ConnCallback ccb;
335 socket->connect(&ccb, server.getAddress(), 30);
336
337 // write()
338 char buf[128];
339 memset(buf, 'a', sizeof(buf));
340 WriteCallback wcb(true /*enableReleaseIOBufCallback*/);
341 // use writeChain so we can pass an IOBuf
342 socket->writeChain(&wcb, IOBuf::copyBuffer(buf, sizeof(buf)));
343
344 // Loop. We don't bother accepting on the server socket yet.
345 // The kernel should be able to buffer the write request so it can succeed.
346 evb.loop();
347
348 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
349 ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
350 ASSERT_EQ(wcb.numIoBufCount, 1);
351 ASSERT_EQ(wcb.numIoBufBytes, sizeof(buf));
352
353 // Make sure the server got a connection and received the data
354 socket->close();
355 server.verifyConnection(buf, sizeof(buf));
356
357 ASSERT_TRUE(socket->isClosedBySelf());
358 ASSERT_FALSE(socket->isClosedByPeer());
359 EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
360 }
361
362 /**
363 * Test connecting using a nullptr connect callback.
364 */
TEST_P(AsyncSocketConnectTest,ConnectNullCallback)365 TEST_P(AsyncSocketConnectTest, ConnectNullCallback) {
366 TestServer server;
367
368 // connect()
369 EventBase evb;
370 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
371 if (GetParam() == TFOState::ENABLED) {
372 socket->enableTFO();
373 }
374
375 socket->connect(nullptr, server.getAddress(), 30);
376
377 // write some data, just so we have some way of verifing
378 // that the socket works correctly after connecting
379 char buf[128];
380 memset(buf, 'a', sizeof(buf));
381 WriteCallback wcb;
382 socket->write(&wcb, buf, sizeof(buf));
383
384 evb.loop();
385
386 ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
387
388 // Make sure the server got a connection and received the data
389 socket->close();
390 server.verifyConnection(buf, sizeof(buf));
391
392 ASSERT_TRUE(socket->isClosedBySelf());
393 ASSERT_FALSE(socket->isClosedByPeer());
394 }
395
396 /**
397 * Test calling both write() and close() immediately after connecting, without
398 * waiting for connect to finish.
399 *
400 * This exercises the STATE_CONNECTING_CLOSING code.
401 */
TEST_P(AsyncSocketConnectTest,ConnectWriteAndClose)402 TEST_P(AsyncSocketConnectTest, ConnectWriteAndClose) {
403 TestServer server;
404
405 // connect()
406 EventBase evb;
407 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
408 if (GetParam() == TFOState::ENABLED) {
409 socket->enableTFO();
410 }
411 ConnCallback ccb;
412 socket->connect(&ccb, server.getAddress(), 30);
413
414 // write()
415 char buf[128];
416 memset(buf, 'a', sizeof(buf));
417 WriteCallback wcb;
418 socket->write(&wcb, buf, sizeof(buf));
419
420 // close()
421 socket->close();
422
423 // Loop. We don't bother accepting on the server socket yet.
424 // The kernel should be able to buffer the write request so it can succeed.
425 evb.loop();
426
427 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
428 ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
429
430 // Make sure the server got a connection and received the data
431 server.verifyConnection(buf, sizeof(buf));
432
433 ASSERT_TRUE(socket->isClosedBySelf());
434 ASSERT_FALSE(socket->isClosedByPeer());
435 }
436
437 /**
438 * Test calling close() immediately after connect()
439 */
TEST(AsyncSocketTest,ConnectAndClose)440 TEST(AsyncSocketTest, ConnectAndClose) {
441 TestServer server;
442
443 // Connect using a AsyncSocket
444 EventBase evb;
445 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
446 ConnCallback ccb;
447 socket->connect(&ccb, server.getAddress(), 30);
448
449 // Hopefully the connect didn't succeed immediately.
450 // If it did, we can't exercise the close-while-connecting code path.
451 if (ccb.state == STATE_SUCCEEDED) {
452 LOG(INFO) << "connect() succeeded immediately; aborting test "
453 "of close-during-connect behavior";
454 return;
455 }
456
457 socket->close();
458
459 // Loop, although there shouldn't be anything to do.
460 evb.loop();
461
462 // Make sure the connection was aborted
463 ASSERT_EQ(ccb.state, STATE_FAILED);
464
465 ASSERT_TRUE(socket->isClosedBySelf());
466 ASSERT_FALSE(socket->isClosedByPeer());
467 }
468
469 /**
470 * Test calling closeNow() immediately after connect()
471 *
472 * This should be identical to the normal close behavior.
473 */
TEST(AsyncSocketTest,ConnectAndCloseNow)474 TEST(AsyncSocketTest, ConnectAndCloseNow) {
475 TestServer server;
476
477 // Connect using a AsyncSocket
478 EventBase evb;
479 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
480 ConnCallback ccb;
481 socket->connect(&ccb, server.getAddress(), 30);
482
483 // Hopefully the connect didn't succeed immediately.
484 // If it did, we can't exercise the close-while-connecting code path.
485 if (ccb.state == STATE_SUCCEEDED) {
486 LOG(INFO) << "connect() succeeded immediately; aborting test "
487 "of closeNow()-during-connect behavior";
488 return;
489 }
490
491 socket->closeNow();
492
493 // Loop, although there shouldn't be anything to do.
494 evb.loop();
495
496 // Make sure the connection was aborted
497 ASSERT_EQ(ccb.state, STATE_FAILED);
498
499 ASSERT_TRUE(socket->isClosedBySelf());
500 ASSERT_FALSE(socket->isClosedByPeer());
501 }
502
503 /**
504 * Test calling both write() and closeNow() immediately after connecting,
505 * without waiting for connect to finish.
506 *
507 * This should abort the pending write.
508 */
TEST(AsyncSocketTest,ConnectWriteAndCloseNow)509 TEST(AsyncSocketTest, ConnectWriteAndCloseNow) {
510 TestServer server;
511
512 // connect()
513 EventBase evb;
514 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
515 ConnCallback ccb;
516 socket->connect(&ccb, server.getAddress(), 30);
517
518 // Hopefully the connect didn't succeed immediately.
519 // If it did, we can't exercise the close-while-connecting code path.
520 if (ccb.state == STATE_SUCCEEDED) {
521 LOG(INFO) << "connect() succeeded immediately; aborting test "
522 "of write-during-connect behavior";
523 return;
524 }
525
526 // write()
527 char buf[128];
528 memset(buf, 'a', sizeof(buf));
529 WriteCallback wcb;
530 socket->write(&wcb, buf, sizeof(buf));
531
532 // close()
533 socket->closeNow();
534
535 // Loop, although there shouldn't be anything to do.
536 evb.loop();
537
538 ASSERT_EQ(ccb.state, STATE_FAILED);
539 ASSERT_EQ(wcb.state, STATE_FAILED);
540
541 ASSERT_TRUE(socket->isClosedBySelf());
542 ASSERT_FALSE(socket->isClosedByPeer());
543 }
544
545 /**
546 * Test installing a read callback immediately, before connect() finishes.
547 */
TEST_P(AsyncSocketConnectTest,ConnectAndRead)548 TEST_P(AsyncSocketConnectTest, ConnectAndRead) {
549 TestServer server;
550
551 // connect()
552 EventBase evb;
553 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
554 if (GetParam() == TFOState::ENABLED) {
555 socket->enableTFO();
556 }
557
558 ConnCallback ccb;
559 socket->connect(&ccb, server.getAddress(), 30);
560
561 ReadCallback rcb;
562 socket->setReadCB(&rcb);
563
564 if (GetParam() == TFOState::ENABLED) {
565 // Trigger a connection
566 socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
567 }
568
569 // Even though we haven't looped yet, we should be able to accept
570 // the connection and send data to it.
571 std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
572 uint8_t buf[128];
573 memset(buf, 'a', sizeof(buf));
574 acceptedSocket->write(buf, sizeof(buf));
575 acceptedSocket->flush();
576 acceptedSocket->close();
577
578 // Loop, although there shouldn't be anything to do.
579 evb.loop();
580
581 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
582 ASSERT_EQ(rcb.buffers.size(), 1);
583 ASSERT_EQ(rcb.buffers[0].length, sizeof(buf));
584 ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);
585
586 ASSERT_FALSE(socket->isClosedBySelf());
587 ASSERT_FALSE(socket->isClosedByPeer());
588 }
589
/**
 * Same flow as ConnectAndRead, but the data is received through a
 * ReadvCallback and checked with its verifyData().
 */
TEST_P(AsyncSocketConnectTest, ConnectAndReadv) {
  TestServer server;

  EventBase evb;
  auto socket = AsyncSocket::newSocket(&evb);
  const bool useTFO = GetParam() == TFOState::ENABLED;
  if (useTFO) {
    socket->enableTFO();
  }

  ConnCallback connectCb;
  socket->connect(&connectCb, server.getAddress(), 30);

  static constexpr size_t kBuffSize = 10;
  static constexpr size_t kLen = 40;
  static constexpr size_t kDataSize = 128;

  ReadvCallback readvCb(kBuffSize, kLen);
  socket->setReadCB(&readvCb);

  if (useTFO) {
    // With TFO enabled, write something to trigger the connection.
    socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
  }

  // The server can accept and send before our event loop has run at all.
  auto peer = server.accept();
  std::string payload(kDataSize, 'A');
  peer->write(
      reinterpret_cast<unsigned char*>(payload.data()), payload.size());
  peer->flush();
  peer->close();

  // Nothing else should be pending, but run the loop.
  evb.loop();

  ASSERT_EQ(connectCb.state, STATE_SUCCEEDED);
  readvCb.verifyData(payload);

  ASSERT_FALSE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}
633
634 /**
635 * Test installing a read callback and then closing immediately before the
636 * connect attempt finishes.
637 */
TEST(AsyncSocketTest,ConnectReadAndClose)638 TEST(AsyncSocketTest, ConnectReadAndClose) {
639 TestServer server;
640
641 // connect()
642 EventBase evb;
643 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
644 ConnCallback ccb;
645 socket->connect(&ccb, server.getAddress(), 30);
646
647 // Hopefully the connect didn't succeed immediately.
648 // If it did, we can't exercise the close-while-connecting code path.
649 if (ccb.state == STATE_SUCCEEDED) {
650 LOG(INFO) << "connect() succeeded immediately; aborting test "
651 "of read-during-connect behavior";
652 return;
653 }
654
655 ReadCallback rcb;
656 socket->setReadCB(&rcb);
657
658 // close()
659 socket->close();
660
661 // Loop, although there shouldn't be anything to do.
662 evb.loop();
663
664 ASSERT_EQ(ccb.state, STATE_FAILED); // we aborted the close attempt
665 ASSERT_EQ(rcb.buffers.size(), 0);
666 ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
667
668 ASSERT_TRUE(socket->isClosedBySelf());
669 ASSERT_FALSE(socket->isClosedByPeer());
670 }
671
672 /**
673 * Test both writing and installing a read callback immediately,
674 * before connect() finishes.
675 */
TEST_P(AsyncSocketConnectTest,ConnectWriteAndRead)676 TEST_P(AsyncSocketConnectTest, ConnectWriteAndRead) {
677 TestServer server;
678
679 // connect()
680 EventBase evb;
681 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
682 if (GetParam() == TFOState::ENABLED) {
683 socket->enableTFO();
684 }
685 ConnCallback ccb;
686 socket->connect(&ccb, server.getAddress(), 30);
687
688 // write()
689 char buf1[128];
690 memset(buf1, 'a', sizeof(buf1));
691 WriteCallback wcb;
692 socket->write(&wcb, buf1, sizeof(buf1));
693
694 // set a read callback
695 ReadCallback rcb;
696 socket->setReadCB(&rcb);
697
698 // Even though we haven't looped yet, we should be able to accept
699 // the connection and send data to it.
700 std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
701 uint8_t buf2[128];
702 memset(buf2, 'b', sizeof(buf2));
703 acceptedSocket->write(buf2, sizeof(buf2));
704 acceptedSocket->flush();
705
706 // shut down the write half of acceptedSocket, so that the AsyncSocket
707 // will stop reading and we can break out of the event loop.
708 netops::shutdown(acceptedSocket->getNetworkSocket(), SHUT_WR);
709
710 // Loop
711 evb.loop();
712
713 // Make sure the connect succeeded
714 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
715
716 // Make sure the AsyncSocket read the data written by the accepted socket
717 ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
718 ASSERT_EQ(rcb.buffers.size(), 1);
719 ASSERT_EQ(rcb.buffers[0].length, sizeof(buf2));
720 ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf2, sizeof(buf2)), 0);
721
722 // Close the AsyncSocket so we'll see EOF on acceptedSocket
723 socket->close();
724
725 // Make sure the accepted socket saw the data written by the AsyncSocket
726 uint8_t readbuf[sizeof(buf1)];
727 acceptedSocket->readAll(readbuf, sizeof(readbuf));
728 ASSERT_EQ(memcmp(buf1, readbuf, sizeof(buf1)), 0);
729 uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
730 ASSERT_EQ(bytesRead, 0);
731
732 ASSERT_FALSE(socket->isClosedBySelf());
733 ASSERT_TRUE(socket->isClosedByPeer());
734 }
735
736 /**
737 * Test writing to the socket then shutting down writes before the connect
738 * attempt finishes.
739 */
TEST(AsyncSocketTest,ConnectWriteAndShutdownWrite)740 TEST(AsyncSocketTest, ConnectWriteAndShutdownWrite) {
741 TestServer server;
742
743 // connect()
744 EventBase evb;
745 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
746 ConnCallback ccb;
747 socket->connect(&ccb, server.getAddress(), 30);
748
749 // Hopefully the connect didn't succeed immediately.
750 // If it did, we can't exercise the write-while-connecting code path.
751 if (ccb.state == STATE_SUCCEEDED) {
752 LOG(INFO) << "connect() succeeded immediately; skipping test";
753 return;
754 }
755
756 // Ask to write some data
757 char wbuf[128];
758 memset(wbuf, 'a', sizeof(wbuf));
759 WriteCallback wcb;
760 socket->write(&wcb, wbuf, sizeof(wbuf));
761 socket->shutdownWrite();
762
763 // Shutdown writes
764 socket->shutdownWrite();
765
766 // Even though we haven't looped yet, we should be able to accept
767 // the connection.
768 std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
769
770 // Since the connection is still in progress, there should be no data to
771 // read yet. Verify that the accepted socket is not readable.
772 netops::PollDescriptor fds[1];
773 fds[0].fd = acceptedSocket->getNetworkSocket();
774 fds[0].events = POLLIN;
775 fds[0].revents = 0;
776 int rc = netops::poll(fds, 1, 0);
777 ASSERT_EQ(rc, 0);
778
779 // Write data to the accepted socket
780 uint8_t acceptedWbuf[192];
781 memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
782 acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
783 acceptedSocket->flush();
784
785 // Loop
786 evb.loop();
787
788 // The loop should have completed the connection, written the queued data,
789 // and shutdown writes on the socket.
790 //
791 // Check that the connection was completed successfully and that the write
792 // callback succeeded.
793 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
794 ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
795
796 // Check that we can read the data that was written to the socket, and that
797 // we see an EOF, since its socket was half-shutdown.
798 uint8_t readbuf[sizeof(wbuf)];
799 acceptedSocket->readAll(readbuf, sizeof(readbuf));
800 ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
801 uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
802 ASSERT_EQ(bytesRead, 0);
803
804 // Close the accepted socket. This will cause it to see EOF
805 // and uninstall the read callback when we loop next.
806 acceptedSocket->close();
807
808 // Install a read callback, then loop again.
809 ReadCallback rcb;
810 socket->setReadCB(&rcb);
811 evb.loop();
812
813 // This loop should have read the data and seen the EOF
814 ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
815 ASSERT_EQ(rcb.buffers.size(), 1);
816 ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
817 ASSERT_EQ(
818 memcmp(rcb.buffers[0].buffer, acceptedWbuf, sizeof(acceptedWbuf)), 0);
819
820 ASSERT_FALSE(socket->isClosedBySelf());
821 ASSERT_FALSE(socket->isClosedByPeer());
822 }
823
824 /**
825 * Test reading, writing, and shutting down writes before the connect attempt
826 * finishes.
827 */
TEST(AsyncSocketTest,ConnectReadWriteAndShutdownWrite)828 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWrite) {
829 TestServer server;
830
831 // connect()
832 EventBase evb;
833 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
834 ConnCallback ccb;
835 socket->connect(&ccb, server.getAddress(), 30);
836
837 // Hopefully the connect didn't succeed immediately.
838 // If it did, we can't exercise the write-while-connecting code path.
839 if (ccb.state == STATE_SUCCEEDED) {
840 LOG(INFO) << "connect() succeeded immediately; skipping test";
841 return;
842 }
843
844 // Install a read callback
845 ReadCallback rcb;
846 socket->setReadCB(&rcb);
847
848 // Ask to write some data
849 char wbuf[128];
850 memset(wbuf, 'a', sizeof(wbuf));
851 WriteCallback wcb;
852 socket->write(&wcb, wbuf, sizeof(wbuf));
853
854 // Shutdown writes
855 socket->shutdownWrite();
856
857 // Even though we haven't looped yet, we should be able to accept
858 // the connection.
859 std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
860
861 // Since the connection is still in progress, there should be no data to
862 // read yet. Verify that the accepted socket is not readable.
863 netops::PollDescriptor fds[1];
864 fds[0].fd = acceptedSocket->getNetworkSocket();
865 fds[0].events = POLLIN;
866 fds[0].revents = 0;
867 int rc = netops::poll(fds, 1, 0);
868 ASSERT_EQ(rc, 0);
869
870 // Write data to the accepted socket
871 uint8_t acceptedWbuf[192];
872 memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
873 acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
874 acceptedSocket->flush();
875 // Shutdown writes to the accepted socket. This will cause it to see EOF
876 // and uninstall the read callback.
877 netops::shutdown(acceptedSocket->getNetworkSocket(), SHUT_WR);
878
879 // Loop
880 evb.loop();
881
882 // The loop should have completed the connection, written the queued data,
883 // shutdown writes on the socket, read the data we wrote to it, and see the
884 // EOF.
885 //
886 // Check that the connection was completed successfully and that the read
887 // and write callbacks were invoked as expected.
888 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
889 ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
890 ASSERT_EQ(rcb.buffers.size(), 1);
891 ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
892 ASSERT_EQ(
893 memcmp(rcb.buffers[0].buffer, acceptedWbuf, sizeof(acceptedWbuf)), 0);
894 ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
895
896 // Check that we can read the data that was written to the socket, and that
897 // we see an EOF, since its socket was half-shutdown.
898 uint8_t readbuf[sizeof(wbuf)];
899 acceptedSocket->readAll(readbuf, sizeof(readbuf));
900 ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
901 uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
902 ASSERT_EQ(bytesRead, 0);
903
904 // Fully close both sockets
905 acceptedSocket->close();
906 socket->close();
907
908 ASSERT_FALSE(socket->isClosedBySelf());
909 ASSERT_TRUE(socket->isClosedByPeer());
910 }
911
912 /**
913 * Test reading, writing, and calling shutdownWriteNow() before the
914 * connect attempt finishes.
915 */
TEST(AsyncSocketTest,ConnectReadWriteAndShutdownWriteNow)916 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWriteNow) {
917 TestServer server;
918
919 // connect()
920 EventBase evb;
921 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
922 ConnCallback ccb;
923 socket->connect(&ccb, server.getAddress(), 30);
924
925 // Hopefully the connect didn't succeed immediately.
926 // If it did, we can't exercise the write-while-connecting code path.
927 if (ccb.state == STATE_SUCCEEDED) {
928 LOG(INFO) << "connect() succeeded immediately; skipping test";
929 return;
930 }
931
932 // Install a read callback
933 ReadCallback rcb;
934 socket->setReadCB(&rcb);
935
936 // Ask to write some data
937 char wbuf[128];
938 memset(wbuf, 'a', sizeof(wbuf));
939 WriteCallback wcb;
940 socket->write(&wcb, wbuf, sizeof(wbuf));
941
942 // Shutdown writes immediately.
943 // This should immediately discard the data that we just tried to write.
944 socket->shutdownWriteNow();
945
946 // Verify that writeError() was invoked on the write callback.
947 ASSERT_EQ(wcb.state, STATE_FAILED);
948 ASSERT_EQ(wcb.bytesWritten, 0);
949
950 // Even though we haven't looped yet, we should be able to accept
951 // the connection.
952 std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
953
954 // Since the connection is still in progress, there should be no data to
955 // read yet. Verify that the accepted socket is not readable.
956 netops::PollDescriptor fds[1];
957 fds[0].fd = acceptedSocket->getNetworkSocket();
958 fds[0].events = POLLIN;
959 fds[0].revents = 0;
960 int rc = netops::poll(fds, 1, 0);
961 ASSERT_EQ(rc, 0);
962
963 // Write data to the accepted socket
964 uint8_t acceptedWbuf[192];
965 memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
966 acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
967 acceptedSocket->flush();
968 // Shutdown writes to the accepted socket. This will cause it to see EOF
969 // and uninstall the read callback.
970 netops::shutdown(acceptedSocket->getNetworkSocket(), SHUT_WR);
971
972 // Loop
973 evb.loop();
974
975 // The loop should have completed the connection, written the queued data,
976 // shutdown writes on the socket, read the data we wrote to it, and see the
977 // EOF.
978 //
979 // Check that the connection was completed successfully and that the read
980 // callback was invoked as expected.
981 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
982 ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
983 ASSERT_EQ(rcb.buffers.size(), 1);
984 ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
985 ASSERT_EQ(
986 memcmp(rcb.buffers[0].buffer, acceptedWbuf, sizeof(acceptedWbuf)), 0);
987
988 // Since we used shutdownWriteNow(), it should have discarded all pending
989 // write data. Verify we see an immediate EOF when reading from the accepted
990 // socket.
991 uint8_t readbuf[sizeof(wbuf)];
992 uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
993 ASSERT_EQ(bytesRead, 0);
994
995 // Fully close both sockets
996 acceptedSocket->close();
997 socket->close();
998
999 ASSERT_FALSE(socket->isClosedBySelf());
1000 ASSERT_TRUE(socket->isClosedByPeer());
1001 }
1002
1003 // Helper function for use in testConnectOptWrite()
1004 // Temporarily disable the read callback
tmpDisableReads(AsyncSocket * socket,ReadCallback * rcb)1005 void tmpDisableReads(AsyncSocket* socket, ReadCallback* rcb) {
1006 // Uninstall the read callback
1007 socket->setReadCB(nullptr);
1008 // Schedule the read callback to be reinstalled after 1ms
1009 socket->getEventBase()->runInLoop(
1010 std::bind(&AsyncSocket::setReadCB, socket, rcb));
1011 }
1012
1013 /**
1014 * Test connect+write, then have the connect callback perform another write.
1015 *
1016 * This tests interaction of the optimistic writing after connect with
1017 * additional write attempts that occur in the connect callback.
1018 */
testConnectOptWrite(size_t size1,size_t size2,bool close=false)1019 void testConnectOptWrite(size_t size1, size_t size2, bool close = false) {
1020 TestServer server;
1021 EventBase evb;
1022 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1023
1024 // connect()
1025 ConnCallback ccb;
1026 socket->connect(&ccb, server.getAddress(), 30);
1027
1028 // Hopefully the connect didn't succeed immediately.
1029 // If it did, we can't exercise the optimistic write code path.
1030 if (ccb.state == STATE_SUCCEEDED) {
1031 LOG(INFO) << "connect() succeeded immediately; aborting test "
1032 "of optimistic write behavior";
1033 return;
1034 }
1035
1036 // Tell the connect callback to perform a write when the connect succeeds
1037 WriteCallback wcb2;
1038 std::unique_ptr<char[]> buf2(new char[size2]);
1039 memset(buf2.get(), 'b', size2);
1040 if (size2 > 0) {
1041 ccb.successCallback = [&] { socket->write(&wcb2, buf2.get(), size2); };
1042 // Tell the second write callback to close the connection when it is done
1043 wcb2.successCallback = [&] { socket->closeNow(); };
1044 }
1045
1046 // Schedule one write() immediately, before the connect finishes
1047 std::unique_ptr<char[]> buf1(new char[size1]);
1048 memset(buf1.get(), 'a', size1);
1049 WriteCallback wcb1;
1050 if (size1 > 0) {
1051 socket->write(&wcb1, buf1.get(), size1);
1052 }
1053
1054 if (close) {
1055 // immediately perform a close, before connect() completes
1056 socket->close();
1057 }
1058
1059 // Start reading from the other endpoint after 10ms.
1060 // If we're using large buffers, we have to read so that the writes don't
1061 // block forever.
1062 std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1063 ReadCallback rcb;
1064 rcb.dataAvailableCallback =
1065 std::bind(tmpDisableReads, acceptedSocket.get(), &rcb);
1066 socket->getEventBase()->tryRunAfterDelay(
1067 std::bind(&AsyncSocket::setReadCB, acceptedSocket.get(), &rcb), 10);
1068
1069 // Loop. We don't bother accepting on the server socket yet.
1070 // The kernel should be able to buffer the write request so it can succeed.
1071 evb.loop();
1072
1073 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1074 if (size1 > 0) {
1075 ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1076 }
1077 if (size2 > 0) {
1078 ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1079 }
1080
1081 socket->close();
1082
1083 // Make sure the read callback received all of the data
1084 size_t bytesRead = 0;
1085 for (const auto& buffer : rcb.buffers) {
1086 size_t start = bytesRead;
1087 bytesRead += buffer.length;
1088 size_t end = bytesRead;
1089 if (start < size1) {
1090 size_t cmpLen = min(size1, end) - start;
1091 ASSERT_EQ(memcmp(buffer.buffer, buf1.get() + start, cmpLen), 0);
1092 }
1093 if (end > size1 && end <= size1 + size2) {
1094 size_t itOffset;
1095 size_t buf2Offset;
1096 size_t cmpLen;
1097 if (start >= size1) {
1098 itOffset = 0;
1099 buf2Offset = start - size1;
1100 cmpLen = end - start;
1101 } else {
1102 itOffset = size1 - start;
1103 buf2Offset = 0;
1104 cmpLen = end - size1;
1105 }
1106 ASSERT_EQ(
1107 memcmp(buffer.buffer + itOffset, buf2.get() + buf2Offset, cmpLen), 0);
1108 }
1109 }
1110 ASSERT_EQ(bytesRead, size1 + size2);
1111 }
1112
TEST(AsyncSocketTest, ConnectCallbackWrite) {
  // Each row is one testConnectOptWrite(size1, size2, close) invocation;
  // the rows run in the order listed.
  const size_t largeSize = 32 * 1024 * 1024;
  struct Scenario {
    size_t size1; // bytes written before connect() completes
    size_t size2; // bytes written from the connect callback
    bool close; // call close() right after the first write
  };
  const Scenario scenarios[] = {
      // Small writes that should both succeed immediately
      {100, 200, false},
      // A large buffer in the connect callback, that should block
      {100, largeSize, false},
      // A large initial write
      {largeSize, 100, false},
      // Two large buffers
      {largeSize, largeSize, false},
      // A small write in the connect callback, but no immediate write
      // before connect completes
      {0, 64, false},
      // A large write in the connect callback, but no immediate write
      // before connect completes
      {0, largeSize, false},
      // Connect, a small write, then immediately call close() before
      // connect completes
      {211, 0, true},
      // Connect, a large immediate write (that will block), then
      // immediately call close() before connect completes
      {largeSize, 0, true},
  };
  for (const auto& scenario : scenarios) {
    testConnectOptWrite(scenario.size1, scenario.size2, scenario.close);
  }
}
1143
1144 ///////////////////////////////////////////////////////////////////////////
1145 // write() related tests
1146 ///////////////////////////////////////////////////////////////////////////
1147
1148 /**
1149 * Test writing using a nullptr callback
1150 */
TEST(AsyncSocketTest,WriteNullCallback)1151 TEST(AsyncSocketTest, WriteNullCallback) {
1152 TestServer server;
1153
1154 // connect()
1155 EventBase evb;
1156 std::shared_ptr<AsyncSocket> socket =
1157 AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1158 evb.loop(); // loop until the socket is connected
1159
1160 // write() with a nullptr callback
1161 char buf[128];
1162 memset(buf, 'a', sizeof(buf));
1163 socket->write(nullptr, buf, sizeof(buf));
1164
1165 evb.loop(); // loop until the data is sent
1166
1167 // Make sure the server got a connection and received the data
1168 socket->close();
1169 server.verifyConnection(buf, sizeof(buf));
1170
1171 ASSERT_TRUE(socket->isClosedBySelf());
1172 ASSERT_FALSE(socket->isClosedByPeer());
1173 }
1174
1175 /**
1176 * Test writing with a send timeout
1177 */
TEST(AsyncSocketTest,WriteTimeout)1178 TEST(AsyncSocketTest, WriteTimeout) {
1179 TestServer server;
1180
1181 // connect()
1182 EventBase evb;
1183 std::shared_ptr<AsyncSocket> socket =
1184 AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1185 evb.loop(); // loop until the socket is connected
1186
1187 // write() a large chunk of data, with no-one on the other end reading.
1188 // Tricky: the kernel caches the connection metrics for recently-used
1189 // routes (see tcp_no_metrics_save) so a freshly opened connection can
1190 // have a send buffer size bigger than wmem_default. This makes the test
1191 // flaky on contbuild if writeLength is < wmem_max (20M on our systems).
1192 size_t writeLength = 32 * 1024 * 1024;
1193 uint32_t timeout = 200;
1194 socket->setSendTimeout(timeout);
1195 std::unique_ptr<char[]> buf(new char[writeLength]);
1196 memset(buf.get(), 'a', writeLength);
1197 WriteCallback wcb;
1198 socket->write(&wcb, buf.get(), writeLength);
1199
1200 TimePoint start;
1201 evb.loop();
1202 TimePoint end;
1203
1204 // Make sure the write attempt timed out as requested
1205 ASSERT_EQ(wcb.state, STATE_FAILED);
1206 ASSERT_EQ(wcb.exception.getType(), AsyncSocketException::TIMED_OUT);
1207
1208 // Check that the write timed out within a reasonable period of time.
1209 // We don't check for exactly the specified timeout, since AsyncSocket only
1210 // times out when it hasn't made progress for that period of time.
1211 //
1212 // On linux, the first write sends a few hundred kb of data, then blocks for
1213 // writability, and then unblocks again after 40ms and is able to write
1214 // another smaller of data before blocking permanently. Therefore it doesn't
1215 // time out until 40ms + timeout.
1216 //
1217 // I haven't fully verified the cause of this, but I believe it probably
1218 // occurs because the receiving end delays sending an ack for up to 40ms.
1219 // (This is the default value for TCP_DELACK_MIN.) Once the sender receives
1220 // the ack, it can send some more data. However, after that point the
1221 // receiver's kernel buffer is full. This 40ms delay happens even with
1222 // TCP_NODELAY and TCP_QUICKACK enabled on both endpoints. However, the
1223 // kernel may be automatically disabling TCP_QUICKACK after receiving some
1224 // data.
1225 //
1226 // For now, we simply check that the timeout occurred within 160ms of
1227 // the requested value.
1228 T_CHECK_TIMEOUT(start, end, milliseconds(timeout), milliseconds(160));
1229 }
1230
1231 /**
1232 * Test writing to a socket that the remote endpoint has closed
1233 */
TEST(AsyncSocketTest,WritePipeError)1234 TEST(AsyncSocketTest, WritePipeError) {
1235 TestServer server;
1236
1237 // connect()
1238 EventBase evb;
1239 std::shared_ptr<AsyncSocket> socket =
1240 AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1241 socket->setSendTimeout(1000);
1242 evb.loop(); // loop until the socket is connected
1243
1244 // accept and immediately close the socket
1245 std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1246 acceptedSocket->close();
1247
1248 // write() a large chunk of data
1249 size_t writeLength = 32 * 1024 * 1024;
1250 std::unique_ptr<char[]> buf(new char[writeLength]);
1251 memset(buf.get(), 'a', writeLength);
1252 WriteCallback wcb;
1253 socket->write(&wcb, buf.get(), writeLength);
1254
1255 evb.loop();
1256
1257 // Make sure the write failed.
1258 // It would be nice if AsyncSocketException could convey the errno value,
1259 // so that we could check for EPIPE
1260 ASSERT_EQ(wcb.state, STATE_FAILED);
1261 ASSERT_EQ(wcb.exception.getType(), AsyncSocketException::INTERNAL_ERROR);
1262 ASSERT_THAT(
1263 wcb.exception.what(),
1264 MatchesRegex(
1265 kIsMobile
1266 ? "AsyncSocketException: writev\\(\\) failed \\(peer=.+\\), type = Internal error, errno = .+ \\(Broken pipe\\)"
1267 : "AsyncSocketException: writev\\(\\) failed \\(peer=.+, local=.+\\), type = Internal error, errno = .+ \\(Broken pipe\\)"));
1268 ASSERT_FALSE(socket->isClosedBySelf());
1269 ASSERT_FALSE(socket->isClosedByPeer());
1270 }
1271
1272 /**
1273 * Test writing to a socket that has its read side closed
1274 */
TEST(AsyncSocketTest,WriteAfterReadEOF)1275 TEST(AsyncSocketTest, WriteAfterReadEOF) {
1276 TestServer server;
1277
1278 // connect()
1279 EventBase evb;
1280 std::shared_ptr<AsyncSocket> socket =
1281 AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1282 evb.loop(); // loop until the socket is connected
1283
1284 // Accept the connection
1285 std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1286 ReadCallback rcb;
1287 acceptedSocket->setReadCB(&rcb);
1288
1289 // Shutdown the write side of client socket (read side of server socket)
1290 socket->shutdownWrite();
1291 evb.loop();
1292
1293 // Check that accepted socket is still writable
1294 ASSERT_FALSE(acceptedSocket->good());
1295 ASSERT_TRUE(acceptedSocket->writable());
1296
1297 // Write data to accepted socket
1298 constexpr size_t simpleBufLength = 5;
1299 char simpleBuf[simpleBufLength];
1300 memset(simpleBuf, 'a', simpleBufLength);
1301 WriteCallback wcb;
1302 acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
1303 evb.loop();
1304
1305 // Make sure we were able to write even after getting a read EOF
1306 ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
1307 ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1308 }
1309
1310 /**
1311 * Test that bytes written is correctly computed in case of write failure
1312 */
TEST(AsyncSocketTest,WriteErrorCallbackBytesWritten)1313 TEST(AsyncSocketTest, WriteErrorCallbackBytesWritten) {
1314 // Send and receive buffer sizes for the sockets.
1315 // Note that Linux will double this value to allow space for bookkeeping
1316 // overhead.
1317 constexpr size_t kSockBufSize = 8 * 1024;
1318 constexpr size_t kEffectiveSockBufSize = 2 * kSockBufSize;
1319
1320 TestServer server(false, kSockBufSize);
1321
1322 SocketOptionMap options{
1323 {{SOL_SOCKET, SO_SNDBUF}, int(kSockBufSize)},
1324 {{SOL_SOCKET, SO_RCVBUF}, int(kSockBufSize)},
1325 {{IPPROTO_TCP, TCP_NODELAY}, 1},
1326 };
1327
1328 // The current thread will be used by the receiver - use a separate thread
1329 // for the sender.
1330 EventBase senderEvb;
1331 std::thread senderThread([&]() { senderEvb.loopForever(); });
1332
1333 ConnCallback ccb;
1334 WriteCallback wcb;
1335 std::shared_ptr<AsyncSocket> socket;
1336
1337 senderEvb.runInEventBaseThreadAndWait([&]() {
1338 socket = AsyncSocket::newSocket(&senderEvb);
1339 socket->connect(&ccb, server.getAddress(), 30, options);
1340 });
1341
1342 // accept the socket on the server side
1343 std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1344
1345 // Send a big (100KB) write so that it is partially written.
1346 constexpr size_t kSendSize = 100 * 1024;
1347 auto const sendBuf = std::vector<char>(kSendSize, 'a');
1348
1349 senderEvb.runInEventBaseThreadAndWait(
1350 [&]() { socket->write(&wcb, sendBuf.data(), kSendSize); });
1351
1352 // Read 20KB of data from the socket to allow the sender to send a bit more
1353 // data after it initially blocks.
1354 constexpr size_t kRecvSize = 20 * 1024;
1355 uint8_t recvBuf[kRecvSize];
1356 auto bytesRead = acceptedSocket->readAll(recvBuf, sizeof(recvBuf));
1357 ASSERT_EQ(kRecvSize, bytesRead);
1358 EXPECT_EQ(0, memcmp(recvBuf, sendBuf.data(), bytesRead));
1359
1360 // We should be able to send at least the amount of data received plus the
1361 // send buffer size. In practice we should probably be able to send
1362 constexpr size_t kMinExpectedBytesWritten = kRecvSize + kSockBufSize;
1363
1364 // We shouldn't be able to send more than the amount of data received plus
1365 // the send buffer size of the sending socket (kEffectiveSockBufSize) plus
1366 // the receive buffer size on the receiving socket (kEffectiveSockBufSize)
1367 constexpr size_t kMaxExpectedBytesWritten =
1368 kRecvSize + kEffectiveSockBufSize + kEffectiveSockBufSize;
1369 static_assert(
1370 kMaxExpectedBytesWritten < kSendSize, "kSendSize set too small");
1371
1372 // Need to delay after receiving 20KB and before closing the receive side so
1373 // that the send side has a chance to fill the send buffer past.
1374 using clock = std::chrono::steady_clock;
1375 auto const deadline = clock::now() + std::chrono::seconds(2);
1376 while (wcb.bytesWritten < kMinExpectedBytesWritten &&
1377 clock::now() < deadline) {
1378 std::this_thread::yield();
1379 }
1380 acceptedSocket->closeWithReset();
1381
1382 senderEvb.terminateLoopSoon();
1383 senderThread.join();
1384 socket.reset();
1385
1386 ASSERT_EQ(STATE_FAILED, wcb.state);
1387 ASSERT_LE(kMinExpectedBytesWritten, wcb.bytesWritten);
1388 ASSERT_GE(kMaxExpectedBytesWritten, wcb.bytesWritten);
1389 }
1390
1391 /**
1392 * Test writing a mix of simple buffers and IOBufs
1393 */
TEST(AsyncSocketTest,WriteIOBuf)1394 TEST(AsyncSocketTest, WriteIOBuf) {
1395 TestServer server;
1396
1397 // connect()
1398 EventBase evb;
1399 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1400 ConnCallback ccb;
1401 socket->connect(&ccb, server.getAddress(), 30);
1402
1403 // Accept the connection
1404 std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1405 ReadCallback rcb;
1406 acceptedSocket->setReadCB(&rcb);
1407
1408 // Check if EOR tracking flag can be set and reset.
1409 EXPECT_FALSE(socket->isEorTrackingEnabled());
1410 socket->setEorTracking(true);
1411 EXPECT_TRUE(socket->isEorTrackingEnabled());
1412 socket->setEorTracking(false);
1413 EXPECT_FALSE(socket->isEorTrackingEnabled());
1414
1415 // Write a simple buffer to the socket
1416 constexpr size_t simpleBufLength = 5;
1417 char simpleBuf[simpleBufLength];
1418 memset(simpleBuf, 'a', simpleBufLength);
1419 WriteCallback wcb;
1420 socket->write(&wcb, simpleBuf, simpleBufLength);
1421
1422 // Write a single-element IOBuf chain
1423 size_t buf1Length = 7;
1424 unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
1425 memset(buf1->writableData(), 'b', buf1Length);
1426 buf1->append(buf1Length);
1427 unique_ptr<IOBuf> buf1Copy(buf1->clone());
1428 WriteCallback wcb2;
1429 socket->writeChain(&wcb2, std::move(buf1));
1430
1431 // Write a multiple-element IOBuf chain
1432 size_t buf2Length = 11;
1433 unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
1434 memset(buf2->writableData(), 'c', buf2Length);
1435 buf2->append(buf2Length);
1436 size_t buf3Length = 13;
1437 unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
1438 memset(buf3->writableData(), 'd', buf3Length);
1439 buf3->append(buf3Length);
1440 buf2->appendToChain(std::move(buf3));
1441 unique_ptr<IOBuf> buf2Copy(buf2->clone());
1442 buf2Copy->coalesce();
1443 WriteCallback wcb3;
1444 socket->writeChain(&wcb3, std::move(buf2));
1445 socket->shutdownWrite();
1446
1447 // Let the reads and writes run to completion
1448 evb.loop();
1449
1450 ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1451 ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1452 ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1453
1454 // Make sure the reader got the right data in the right order
1455 ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
1456 ASSERT_EQ(rcb.buffers.size(), 1);
1457 ASSERT_EQ(
1458 rcb.buffers[0].length,
1459 simpleBufLength + buf1Length + buf2Length + buf3Length);
1460 ASSERT_EQ(memcmp(rcb.buffers[0].buffer, simpleBuf, simpleBufLength), 0);
1461 ASSERT_EQ(
1462 memcmp(
1463 rcb.buffers[0].buffer + simpleBufLength,
1464 buf1Copy->data(),
1465 buf1Copy->length()),
1466 0);
1467 ASSERT_EQ(
1468 memcmp(
1469 rcb.buffers[0].buffer + simpleBufLength + buf1Length,
1470 buf2Copy->data(),
1471 buf2Copy->length()),
1472 0);
1473
1474 acceptedSocket->close();
1475 socket->close();
1476
1477 ASSERT_TRUE(socket->isClosedBySelf());
1478 ASSERT_FALSE(socket->isClosedByPeer());
1479 }
1480
/**
 * Test that a corked write is held back and delivered together with the
 * following write.
 */
TEST(AsyncSocketTest, WriteIOBufCorked) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Accept the connection
  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
  ReadCallback rcb;
  acceptedSocket->setReadCB(&rcb);

  // Do three writes: the first immediately, the second at 100ms with the
  // "cork" flag set, and the third at 140ms. The reader should see the first
  // write arrive by itself, followed by the second and third writes
  // arriving together.
  size_t buf1Length = 5;
  unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
  memset(buf1->writableData(), 'a', buf1Length);
  buf1->append(buf1Length);
  size_t buf2Length = 7;
  unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
  memset(buf2->writableData(), 'b', buf2Length);
  buf2->append(buf2Length);
  size_t buf3Length = 11;
  unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
  memset(buf3->writableData(), 'c', buf3Length);
  buf3->append(buf3Length);
  WriteCallback wcb1;
  socket->writeChain(&wcb1, std::move(buf1));
  WriteCallback wcb2;
  DelayedWrite write2(socket, std::move(buf2), &wcb2, true);
  write2.scheduleTimeout(100);
  WriteCallback wcb3;
  // NOTE(review): the trailing `true` presumably marks this as the last
  // write, ending the stream so the reader sees EOF — confirm against
  // DelayedWrite.
  DelayedWrite write3(socket, std::move(buf3), &wcb3, false, true);
  write3.scheduleTimeout(140);

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
  // Report the write error (if any) as the assertion message rather than
  // throwing it out of the test body.
  ASSERT_EQ(wcb3.state, STATE_SUCCEEDED) << wcb3.exception.what();

  // Make sure the reader got the data with the right grouping
  ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
  ASSERT_EQ(rcb.buffers.size(), 2);
  ASSERT_EQ(rcb.buffers[0].length, buf1Length);
  ASSERT_EQ(rcb.buffers[1].length, buf2Length + buf3Length);

  acceptedSocket->close();
  socket->close();

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}
1541
1542 /**
1543 * Test performing a zero-length write
1544 */
TEST(AsyncSocketTest,ZeroLengthWrite)1545 TEST(AsyncSocketTest, ZeroLengthWrite) {
1546 TestServer server;
1547
1548 // connect()
1549 EventBase evb;
1550 std::shared_ptr<AsyncSocket> socket =
1551 AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1552 evb.loop(); // loop until the socket is connected
1553
1554 auto acceptedSocket = server.acceptAsync(&evb);
1555 ReadCallback rcb;
1556 acceptedSocket->setReadCB(&rcb);
1557
1558 size_t len1 = 1024 * 1024;
1559 size_t len2 = 1024 * 1024;
1560 std::unique_ptr<char[]> buf(new char[len1 + len2]);
1561 memset(buf.get(), 'a', len1);
1562 memset(buf.get() + len1, 'b', len2);
1563
1564 WriteCallback wcb1;
1565 WriteCallback wcb2;
1566 WriteCallback wcb3;
1567 WriteCallback wcb4;
1568 socket->write(&wcb1, buf.get(), 0);
1569 socket->write(&wcb2, buf.get(), len1);
1570 socket->write(&wcb3, buf.get() + len1, 0);
1571 socket->write(&wcb4, buf.get() + len1, len2);
1572 socket->close();
1573
1574 evb.loop(); // loop until the data is sent
1575
1576 ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1577 ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1578 ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1579 ASSERT_EQ(wcb4.state, STATE_SUCCEEDED);
1580 rcb.verifyData(buf.get(), len1 + len2);
1581
1582 ASSERT_TRUE(socket->isClosedBySelf());
1583 ASSERT_FALSE(socket->isClosedByPeer());
1584 }
1585
/**
 * Test a writev() whose iovec array contains zero-length entries between and
 * after the non-empty ones; all data must arrive in order.
 */
TEST(AsyncSocketTest, ZeroLengthWritev) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket =
      AsyncSocket::newSocket(&evb, server.getAddress(), 30);
  evb.loop(); // loop until the socket is connected

  auto acceptedSocket = server.acceptAsync(&evb);
  ReadCallback rcb;
  acceptedSocket->setReadCB(&rcb);

  // Payload: len1 bytes of 'a' followed by len2 bytes of 'b'. The second half
  // must start at buf + len1. Filling from buf itself would overwrite the
  // 'a's, and the data check could no longer detect reordering.
  size_t len1 = 1024 * 1024;
  size_t len2 = 1024 * 1024;
  std::unique_ptr<char[]> buf(new char[len1 + len2]);
  memset(buf.get(), 'a', len1);
  memset(buf.get() + len1, 'b', len2);

  WriteCallback wcb;
  constexpr size_t iovCount = 4;
  struct iovec iov[iovCount];
  iov[0].iov_base = buf.get();
  iov[0].iov_len = len1;
  iov[1].iov_base = buf.get() + len1;
  iov[1].iov_len = 0;
  iov[2].iov_base = buf.get() + len1;
  iov[2].iov_len = len2;
  iov[3].iov_base = buf.get() + len1 + len2;
  iov[3].iov_len = 0;

  socket->writev(&wcb, iov, iovCount);
  socket->close();
  evb.loop(); // loop until the data is sent

  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  rcb.verifyData(buf.get(), len1 + len2);

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}
1627
1628 ///////////////////////////////////////////////////////////////////////////
1629 // close() related tests
1630 ///////////////////////////////////////////////////////////////////////////
1631
1632 /**
1633 * Test calling close() with pending writes when the socket is already closing.
1634 */
TEST(AsyncSocketTest,ClosePendingWritesWhileClosing)1635 TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
1636 TestServer server;
1637
1638 // connect()
1639 EventBase evb;
1640 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1641 ConnCallback ccb;
1642 socket->connect(&ccb, server.getAddress(), 30);
1643
1644 // accept the socket on the server side
1645 std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1646
1647 // Loop to ensure the connect has completed
1648 evb.loop();
1649
1650 // Make sure we are connected
1651 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1652
1653 // Schedule pending writes, until several write attempts have blocked
1654 char buf[128];
1655 memset(buf, 'a', sizeof(buf));
1656 typedef vector<std::shared_ptr<WriteCallback>> WriteCallbackVector;
1657 WriteCallbackVector writeCallbacks;
1658
1659 writeCallbacks.reserve(5);
1660 while (writeCallbacks.size() < 5) {
1661 std::shared_ptr<WriteCallback> wcb(new WriteCallback);
1662
1663 socket->write(wcb.get(), buf, sizeof(buf));
1664 if (wcb->state == STATE_SUCCEEDED) {
1665 // Succeeded immediately. Keep performing more writes
1666 continue;
1667 }
1668
1669 // This write is blocked.
1670 // Have the write callback call close() when writeError() is invoked
1671 wcb->errorCallback = std::bind(&AsyncSocket::close, socket.get());
1672 writeCallbacks.push_back(wcb);
1673 }
1674
1675 // Call closeNow() to immediately fail the pending writes
1676 socket->closeNow();
1677
1678 // Make sure writeError() was invoked on all of the pending write callbacks
1679 for (const auto& writeCallback : writeCallbacks) {
1680 ASSERT_EQ((writeCallback)->state, STATE_FAILED);
1681 }
1682
1683 ASSERT_TRUE(socket->isClosedBySelf());
1684 ASSERT_FALSE(socket->isClosedByPeer());
1685 }
1686
1687 ///////////////////////////////////////////////////////////////////////////
1688 // ImmediateRead related tests
1689 ///////////////////////////////////////////////////////////////////////////
1690
1691 /* AsyncSocket use to verify immediate read works */
1692 class AsyncSocketImmediateRead : public folly::AsyncSocket {
1693 public:
1694 bool immediateReadCalled = false;
AsyncSocketImmediateRead(folly::EventBase * evb)1695 explicit AsyncSocketImmediateRead(folly::EventBase* evb) : AsyncSocket(evb) {}
1696
1697 protected:
checkForImmediateRead()1698 void checkForImmediateRead() noexcept override {
1699 immediateReadCalled = true;
1700 AsyncSocket::handleRead();
1701 }
1702 };
1703
TEST(AsyncSocket, ConnectReadImmediateRead) {
  TestServer server;

  // The client reads with 100-byte buffers and one read per event, so the
  // 300-byte echo cannot be consumed by a single read.
  const size_t maxBufferSz = 100;
  const size_t maxReadsPerEvent = 1;
  const size_t expectedDataSz = maxBufferSz * 3;
  char expectedData[expectedDataSz];
  memset(expectedData, 'j', expectedDataSz);

  EventBase evb;
  ReadCallback rcb(maxBufferSz);
  AsyncSocketImmediateRead socket(&evb);
  socket.connect(nullptr, server.getAddress(), 30);

  evb.loop(); // finish connecting

  socket.setReadCB(&rcb);
  socket.setMaxReadsPerEvent(maxReadsPerEvent);
  socket.immediateReadCalled = false;

  auto acceptedSocket = server.acceptAsync(&evb);

  // Server side: once the full payload has arrived, verify it, echo it back
  // and close.
  ReadCallback rcbServer;
  WriteCallback wcbServer;
  rcbServer.dataAvailableCallback = [&]() {
    if (rcbServer.dataRead() != expectedDataSz) {
      return; // still waiting for the rest of the payload
    }
    rcbServer.verifyData(expectedData, expectedDataSz);
    acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
    acceptedSocket->close();
  };
  acceptedSocket->setReadCB(&rcbServer);

  // Client side: send the payload and let everything run to completion.
  WriteCallback wcb1;
  socket.write(&wcb1, expectedData, expectedDataSz);
  evb.loop();

  ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
  rcb.verifyData(expectedData, expectedDataSz);
  ASSERT_TRUE(socket.immediateReadCalled);

  ASSERT_FALSE(socket.isClosedBySelf());
  ASSERT_FALSE(socket.isClosedByPeer());
}
1749
TEST(AsyncSocket, ConnectReadUninstallRead) {
  TestServer server;

  // Same setup as ConnectReadImmediateRead, but the client removes its read
  // callback after the first delivery.
  const size_t maxBufferSz = 100;
  const size_t maxReadsPerEvent = 1;
  const size_t expectedDataSz = maxBufferSz * 3;
  char expectedData[expectedDataSz];
  memset(expectedData, 'k', expectedDataSz);

  EventBase evb;
  ReadCallback rcb(maxBufferSz);
  AsyncSocketImmediateRead socket(&evb);
  socket.connect(nullptr, server.getAddress(), 30);

  evb.loop(); // finish connecting

  socket.setReadCB(&rcb);
  socket.setMaxReadsPerEvent(maxReadsPerEvent);
  socket.immediateReadCalled = false;

  auto acceptedSocket = server.acceptAsync(&evb);

  // Server side: once the full payload has arrived, verify it, echo it back
  // and close.
  ReadCallback rcbServer;
  WriteCallback wcbServer;
  rcbServer.dataAvailableCallback = [&]() {
    if (rcbServer.dataRead() != expectedDataSz) {
      return; // still waiting for the rest of the payload
    }
    rcbServer.verifyData(expectedData, expectedDataSz);
    acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
    acceptedSocket->close();
  };
  acceptedSocket->setReadCB(&rcbServer);

  // Client side: drop the read callback as soon as data is delivered.
  rcb.dataAvailableCallback = [&]() { socket.setReadCB(nullptr); };

  WriteCallback wcb;
  socket.write(&wcb, expectedData, expectedDataSz);
  evb.loop();
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  // Because the read callback was removed inside the first
  // dataAvailableCallback, only maxBufferSz bytes should have been read, and
  // the immediate-read hook must not have run.
  ASSERT_EQ(rcb.dataRead(), maxBufferSz);
  ASSERT_FALSE(socket.immediateReadCalled);

  ASSERT_FALSE(socket.isClosedBySelf());
  ASSERT_FALSE(socket.isClosedByPeer());
}
1803
1804 // TODO:
1805 // - Test connect() and have the connect callback set the read callback
1806 // - Test connect() and have the connect callback unset the read callback
1807 // - Test reading/writing/closing/destroying the socket in the connect callback
1808 // - Test reading/writing/closing/destroying the socket in the read callback
1809 // - Test reading/writing/closing/destroying the socket in the write callback
1810 // - Test one-way shutdown behavior
1811 // - Test changing the EventBase
1812 //
// - TODO: test multiple threads sharing an AsyncSocket, and detaching from it
//   in connectSuccess(), readDataAvailable(), writeSuccess()
1815
1816 ///////////////////////////////////////////////////////////////////////////
1817 // AsyncServerSocket tests
1818 ///////////////////////////////////////////////////////////////////////////
1819
1820 /**
1821 * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
1822 */
TEST(AsyncSocketTest, ServerAcceptOptions) {
  EventBase evb;

  // Listen on a kernel-chosen port.
  std::shared_ptr<AsyncServerSocket> server(
      AsyncServerSocket::newSocket(&evb));
  server->bind(0);
  server->listen(16);
  folly::SocketAddress listenAddr;
  server->getAddress(&listenAddr);

  // Accept a single connection: the first accept (or accept error)
  // unregisters the callback again.
  TestAcceptCallback acceptCb;
  auto unregister = [&] { server->removeAcceptCallback(&acceptCb, &evb); };
  acceptCb.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        unregister();
      });
  acceptCb.setAcceptErrorFn(
      [&](const std::exception& /* ex */) { unregister(); });
  server->addAcceptCallback(&acceptCb, &evb);
  server->startAccepting();

  // Dial the listener from a client socket on the same EventBase.
  std::shared_ptr<AsyncSocket> client(
      AsyncSocket::newSocket(&evb, listenAddr));

  evb.loop();

  // Exactly START, ACCEPT, STOP must have been recorded.
  auto events = acceptCb.getEvents();
  ASSERT_EQ(events->size(), 3);
  ASSERT_EQ(events->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(events->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(events->at(2).type, TestAcceptCallback::TYPE_STOP);
  auto fd = events->at(1).fd;

#ifndef _WIN32
  // Windows provides no way to ask whether a socket is non-blocking, so the
  // O_NONBLOCK check only runs elsewhere.
  const int fileFlags = fcntl(fd.toFd(), F_GETFL, 0);
  ASSERT_EQ(fileFlags & O_NONBLOCK, O_NONBLOCK);
#endif

#ifndef TCP_NOPUSH
  // The accepted connection must already have TCP_NODELAY enabled.
  int noDelay = 0;
  socklen_t noDelayLen = sizeof(noDelay);
  const int rc =
      netops::getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &noDelay, &noDelayLen);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(noDelay, 1);
#endif
}
1880
1881 /**
1882 * Test AsyncServerSocket::removeAcceptCallback()
1883 */
/*
 * Exercises AsyncServerSocket::removeAcceptCallback() while connections are
 * being dispatched. Seven accept callbacks are registered; some remove other
 * callbacks (before or after themselves in the list), one removes itself, and
 * one destroys the server socket to end the test.
 */
TEST(AsyncSocketTest, RemoveAcceptCallback) {
  // Create a new AsyncServerSocket
  EventBase eventBase;
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Add several accept callbacks
  TestAcceptCallback cb1;
  TestAcceptCallback cb2;
  TestAcceptCallback cb3;
  TestAcceptCallback cb4;
  TestAcceptCallback cb5;
  TestAcceptCallback cb6;
  TestAcceptCallback cb7;

  // Test having callbacks remove other callbacks before them on the list,
  // after them on the list, or removing themselves.
  //
  // Several handlers also open new client connections from inside the accept
  // callback, so more connections arrive while the loop runs. The trailing
  // "// cbN" comments appear to record which callback the authors expect to
  // receive each connection (and "-cbM" the removal that callback then
  // performs). This relies on the round-robin dispatch order described in the
  // NOTE further down.
  //
  // Have callback 2 remove callback 3 and callback 5 the first time it is
  // called.
  int cb2Count = 0;
  cb1.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        // sock2 goes out of scope when this lambda returns. The test still
        // expects the connection it started to be accepted (see the ACCEPT
        // counts checked below).
        std::shared_ptr<AsyncSocket> sock2(AsyncSocket::newSocket(
            &eventBase, serverAddress)); // cb2: -cb3 -cb5
      });
  // cb3 does nothing itself. It is expected to be removed by cb2 before it
  // accepts anything, so it only sees START/STOP.
  cb3.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {});
  cb4.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        std::shared_ptr<AsyncSocket> sock3(
            AsyncSocket::newSocket(&eventBase, serverAddress)); // cb4
      });
  // cb5 is also removed by cb2, so this handler is not expected to run.
  cb5.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        std::shared_ptr<AsyncSocket> sock5(
            AsyncSocket::newSocket(&eventBase, serverAddress)); // cb7: -cb7
      });
  cb2.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        if (cb2Count == 0) {
          serverSocket->removeAcceptCallback(&cb3, nullptr);
          serverSocket->removeAcceptCallback(&cb5, nullptr);
        }
        ++cb2Count;
      });
  // Have callback 6 remove callback 4 the first time it is called,
  // and destroy the server socket the second time it is called
  int cb6Count = 0;
  cb6.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        if (cb6Count == 0) {
          serverSocket->removeAcceptCallback(&cb4, nullptr);
          // Start three more connections to keep the round-robin going.
          std::shared_ptr<AsyncSocket> sock6(
              AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
          std::shared_ptr<AsyncSocket> sock7(
              AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2
          std::shared_ptr<AsyncSocket> sock8(
              AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: stop

        } else {
          // Destroying the server socket from inside its own accept callback
          // is how this test terminates the loop.
          serverSocket.reset();
        }
        ++cb6Count;
      });
  // Have callback 7 remove itself
  cb7.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&cb7, nullptr);
      });

  // Registration order defines the round-robin order: cb1 .. cb7.
  serverSocket->addAcceptCallback(&cb1, &eventBase);
  serverSocket->addAcceptCallback(&cb2, &eventBase);
  serverSocket->addAcceptCallback(&cb3, &eventBase);
  serverSocket->addAcceptCallback(&cb4, &eventBase);
  serverSocket->addAcceptCallback(&cb5, &eventBase);
  serverSocket->addAcceptCallback(&cb6, &eventBase);
  serverSocket->addAcceptCallback(&cb7, &eventBase);
  serverSocket->startAccepting();

  // Make several connections to the socket
  std::shared_ptr<AsyncSocket> sock1(
      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
  std::shared_ptr<AsyncSocket> sock4(
      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: -cb4

  // Loop until we are stopped (cb6 destroys the server socket on its second
  // call).
  eventBase.loop();

  // Check to make sure that the expected callbacks were invoked.
  //
  // NOTE: This code depends on AsyncServerSocket calling all of the
  // AcceptCallbacks in round-robin fashion, in the order that they were
  // added. The code is implemented this way right now, but the API doesn't
  // explicitly require it be done this way. If we change the code not to be
  // exactly round robin in the future, we can simplify the test checks here.
  // (We'll also need to update the termination code, since we expect cb6 to
  // get called twice to terminate the loop.)
  ASSERT_EQ(cb1.getEvents()->size(), 4);
  ASSERT_EQ(cb1.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb1.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb1.getEvents()->at(2).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb1.getEvents()->at(3).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb2.getEvents()->size(), 4);
  ASSERT_EQ(cb2.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb2.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb2.getEvents()->at(2).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb2.getEvents()->at(3).type, TestAcceptCallback::TYPE_STOP);

  // cb3 and cb5 were removed before accepting anything.
  ASSERT_EQ(cb3.getEvents()->size(), 2);
  ASSERT_EQ(cb3.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb3.getEvents()->at(1).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb4.getEvents()->size(), 3);
  ASSERT_EQ(cb4.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb4.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb4.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb5.getEvents()->size(), 2);
  ASSERT_EQ(cb5.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb5.getEvents()->at(1).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb6.getEvents()->size(), 4);
  ASSERT_EQ(cb6.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb6.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb6.getEvents()->at(2).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb6.getEvents()->at(3).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb7.getEvents()->size(), 3);
  ASSERT_EQ(cb7.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb7.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb7.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
}
2022
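/**
 * Test that an accept callback registered with an EventBase whose loop runs in
 * another thread receives its start, accept, and stop notifications on that
 * thread rather than on the thread that registered it.
 */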
TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
  // Create a new AsyncServerSocket
  EventBase eventBase;
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Register a single accept callback. thread_id starts out as this (main)
  // thread's id. acceptStarted must run on a different thread (presumably the
  // loop thread started below) and records that thread's id. The accept and
  // stop notifications must then run on that same thread.
  TestAcceptCallback cb1;
  auto thread_id = std::this_thread::get_id();
  cb1.setAcceptStartedFn([&]() {
    // EXPECT rather than CHECK: a failure is reported by the test framework
    // instead of aborting the whole test binary.
    EXPECT_NE(thread_id, std::this_thread::get_id());
    thread_id = std::this_thread::get_id();
  });
  cb1.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        ASSERT_EQ(thread_id, std::this_thread::get_id());
        // Unregister after the first connection; the test expects the loop to
        // return afterwards.
        serverSocket->removeAcceptCallback(&cb1, &eventBase);
      });
  cb1.setAcceptStoppedFn(
      [&]() { ASSERT_EQ(thread_id, std::this_thread::get_id()); });

  serverSocket->addAcceptCallback(&cb1, &eventBase);
  serverSocket->startAccepting();

  // Make one connection to the socket
  std::shared_ptr<AsyncSocket> sock1(
      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1

  // Run the event loop on another thread and wait for it to finish.
  auto other = std::thread([&]() { eventBase.loop(); });
  other.join();

  // The callback must have seen exactly: start, one accept, stop.
  ASSERT_EQ(cb1.getEvents()->size(), 3);
  ASSERT_EQ(cb1.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb1.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb1.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
}
2077
serverSocketSanityTest(AsyncServerSocket * serverSocket)2078 void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
2079 EventBase* eventBase = serverSocket->getEventBase();
2080 CHECK(eventBase);
2081
2082 // Add a callback to accept one connection then stop accepting
2083 TestAcceptCallback acceptCallback;
2084 acceptCallback.setConnectionAcceptedFn(
2085 [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
2086 serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
2087 });
2088 acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2089 serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
2090 });
2091 serverSocket->addAcceptCallback(&acceptCallback, eventBase);
2092 serverSocket->startAccepting();
2093
2094 // Connect to the server socket
2095 folly::SocketAddress serverAddress;
2096 serverSocket->getAddress(&serverAddress);
2097 AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
2098
2099 // Loop to process all events
2100 eventBase->loop();
2101
2102 // Verify that the server accepted a connection
2103 ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
2104 ASSERT_EQ(
2105 acceptCallback.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
2106 ASSERT_EQ(
2107 acceptCallback.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
2108 ASSERT_EQ(
2109 acceptCallback.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
2110 }
2111
/* Verify that we don't leak sockets if an AsyncSocket is destroyed
 * while there are still writes pending.
 *
 * If destroy() only called close() instead of closeNow(),
 * it would shutdown(writes) on the socket, but the socket would
 * never be close()'d, and it would leak.
 */
TEST(AsyncSocketTest, DestroyCloseTest) {
  TestServer server;

  // connect()
  EventBase clientEB;
  EventBase serverEB;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&clientEB);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Accept the connection
  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&serverEB);
  ReadCallback rcb;
  acceptedSocket->setReadCB(&rcb);

  // The buffer is meant to be larger than the kernel send buffer, so part of
  // the write stays queued inside AsyncSocket. A std::vector owns the memory,
  // so it is released on every exit path.
  const size_t simpleBufLength = 5000000;
  std::vector<char> simpleBuf(simpleBufLength, 'a');
  WriteCallback wcb;

  // Remember the raw fd so it can be probed after the AsyncSocket is gone.
  int fd = acceptedSocket->getNetworkSocket().toFd();

  // Start the write, then destroy both ends without ever running an event
  // loop. Whatever the kernel did not take immediately is still pending.
  acceptedSocket->write(&wcb, simpleBuf.data(), simpleBuf.size());
  socket.reset();
  acceptedSocket.reset();

  // Test that the server-side socket was actually closed: reading from the
  // old fd must fail with EBADF.
  folly::test::msvcSuppressAbortOnInvalidParams([&] {
    ssize_t sz = read(fd, simpleBuf.data(), simpleBuf.size());
    ASSERT_EQ(sz, -1);
    ASSERT_EQ(errno, EBADF);
  });
}
2155
2156 /**
2157 * Test AsyncServerSocket::useExistingSocket()
2158 */
TEST(AsyncSocketTest, ServerExistingSocket) {
  EventBase eventBase;

  // Binds `fd` to the IPv4 wildcard address on a kernel-chosen port and
  // returns netops::bind()'s result (0 on success). The sockaddr_in is
  // value-initialized, so sin_zero (and sin_len, on platforms that have it)
  // does not contain uninitialized stack bytes.
  auto bindToAnyPort = [](NetworkSocket fd) {
    struct sockaddr_in addr {};
    addr.sin_family = AF_INET;
    addr.sin_port = 0;
    addr.sin_addr.s_addr = htonl(INADDR_ANY);
    return netops::bind(
        fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr));
  };

  // Test creating a socket, and letting AsyncServerSocket bind and listen
  {
    // Manually create a socket
    auto fd = netops::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    ASSERT_NE(fd, NetworkSocket());

    // Create a server socket
    AsyncServerSocket::UniquePtr serverSocket(
        new AsyncServerSocket(&eventBase));
    serverSocket->useExistingSocket(fd);
    folly::SocketAddress address;
    serverSocket->getAddress(&address);
    address.setPort(0);
    serverSocket->bind(address);
    serverSocket->listen(16);

    // Make sure the socket works
    serverSocketSanityTest(serverSocket.get());
  }

  // Test creating a socket and binding manually,
  // then letting AsyncServerSocket listen
  {
    // Manually create a socket
    auto fd = netops::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    ASSERT_NE(fd, NetworkSocket());
    // bind
    ASSERT_EQ(bindToAnyPort(fd), 0);
    // Look up the address that we bound to
    folly::SocketAddress boundAddress;
    boundAddress.setFromLocalAddress(fd);

    // Create a server socket
    AsyncServerSocket::UniquePtr serverSocket(
        new AsyncServerSocket(&eventBase));
    serverSocket->useExistingSocket(fd);
    serverSocket->listen(16);

    // Make sure AsyncServerSocket reports the same address that we bound to
    folly::SocketAddress serverSocketAddress;
    serverSocket->getAddress(&serverSocketAddress);
    ASSERT_EQ(boundAddress, serverSocketAddress);

    // Make sure the socket works
    serverSocketSanityTest(serverSocket.get());
  }

  // Test creating a socket, binding and listening manually,
  // then giving it to AsyncServerSocket
  {
    // Manually create a socket
    auto fd = netops::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    ASSERT_NE(fd, NetworkSocket());
    // bind
    ASSERT_EQ(bindToAnyPort(fd), 0);
    // Look up the address that we bound to
    folly::SocketAddress boundAddress;
    boundAddress.setFromLocalAddress(fd);
    // listen
    ASSERT_EQ(netops::listen(fd, 16), 0);

    // Create a server socket
    AsyncServerSocket::UniquePtr serverSocket(
        new AsyncServerSocket(&eventBase));
    serverSocket->useExistingSocket(fd);

    // Make sure AsyncServerSocket reports the same address that we bound to
    folly::SocketAddress serverSocketAddress;
    serverSocket->getAddress(&serverSocketAddress);
    ASSERT_EQ(boundAddress, serverSocketAddress);

    // Make sure the socket works
    serverSocketSanityTest(serverSocket.get());
  }
}
2251
TEST(AsyncSocketTest, UnixDomainSocketTest) {
  EventBase eventBase;

  // Bind the server to a randomly named path whose first byte is NUL (on
  // Linux this selects the abstract socket namespace, so no file is created).
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  string path(1, '\0');
  path += folly::to<string>("/anonymous", folly::Random::rand64());
  folly::SocketAddress serverAddress;
  serverAddress.setFromPath(path);
  serverSocket->bind(serverAddress);
  serverSocket->listen(16);

  // One-shot accept callback: unregister on the first accept or accept error.
  TestAcceptCallback acceptCb;
  auto unregister = [&] {
    serverSocket->removeAcceptCallback(&acceptCb, &eventBase);
  };
  acceptCb.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        unregister();
      });
  acceptCb.setAcceptErrorFn(
      [&](const std::exception& /* ex */) { unregister(); });
  serverSocket->addAcceptCallback(&acceptCb, &eventBase);
  serverSocket->startAccepting();

  // Dial the Unix-domain listener.
  std::shared_ptr<AsyncSocket> client(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  // Exactly START, ACCEPT, STOP must have been recorded.
  auto events = acceptCb.getEvents();
  ASSERT_EQ(events->size(), 3);
  ASSERT_EQ(events->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(events->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(events->at(2).type, TestAcceptCallback::TYPE_STOP);
  auto fd = events->at(1).fd;

#ifndef _WIN32
  // Windows cannot report whether a socket is non-blocking, so the
  // O_NONBLOCK check only runs elsewhere.
  const int fileFlags = fcntl(fd.toFd(), F_GETFL, 0);
  ASSERT_EQ(fileFlags & O_NONBLOCK, O_NONBLOCK);
#endif
}
2301
TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
  EventBase eventBase;
  TestConnectionEventCallback connEvents;

  // Server socket that reports its connection events to connEvents.
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->setConnectionEventCallback(&connEvents);
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // One-shot accept callback registered on the server's own EventBase.
  TestAcceptCallback acceptCb;
  auto unregister = [&] {
    serverSocket->removeAcceptCallback(&acceptCb, nullptr);
  };
  acceptCb.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        unregister();
      });
  acceptCb.setAcceptErrorFn(
      [&](const std::exception& /* ex */) { unregister(); });
  serverSocket->addAcceptCallback(&acceptCb, &eventBase);
  serverSocket->startAccepting();

  // Dial the server.
  std::shared_ptr<AsyncSocket> client(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  // Exactly one accepted connection. The test expects no errors, drops,
  // enqueue/dequeue hand-offs, or backoff in this configuration.
  ASSERT_EQ(connEvents.getConnectionAccepted(), 1);
  ASSERT_EQ(connEvents.getConnectionAcceptedError(), 0);
  ASSERT_EQ(connEvents.getConnectionDropped(), 0);
  ASSERT_EQ(connEvents.getConnectionEnqueuedForAcceptCallback(), 0);
  ASSERT_EQ(connEvents.getConnectionDequeuedByAcceptCallback(), 0);
  ASSERT_EQ(connEvents.getBackoffStarted(), 0);
  ASSERT_EQ(connEvents.getBackoffEnded(), 0);
  ASSERT_EQ(connEvents.getBackoffError(), 0);
}
2344
TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
  EventBase eventBase;
  TestConnectionEventCallback connEvents;

  // Server socket that reports its connection events to connEvents.
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->setConnectionEventCallback(&connEvents);
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // One-shot accept callback, registered with a null EventBase. The
  // assertions below expect it to run without any enqueue/dequeue hand-off.
  TestAcceptCallback acceptCb;
  auto unregister = [&] {
    serverSocket->removeAcceptCallback(&acceptCb, nullptr);
  };
  acceptCb.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        unregister();
      });
  acceptCb.setAcceptErrorFn(
      [&](const std::exception& /* ex */) { unregister(); });
  bool sawStart{false};
  acceptCb.setAcceptStartedFn([&sawStart]() { sawStart = true; });
  bool sawStop{false};
  acceptCb.setAcceptStoppedFn([&sawStop]() { sawStop = true; });
  serverSocket->addAcceptCallback(&acceptCb, nullptr);
  serverSocket->startAccepting();

  // Dial the server.
  std::shared_ptr<AsyncSocket> client(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  // Both lifecycle notifications must have fired.
  ASSERT_TRUE(sawStart);
  ASSERT_TRUE(sawStop);
  // Exactly one accepted connection, with no other connection events.
  ASSERT_EQ(connEvents.getConnectionAccepted(), 1);
  ASSERT_EQ(connEvents.getConnectionAcceptedError(), 0);
  ASSERT_EQ(connEvents.getConnectionDropped(), 0);
  ASSERT_EQ(connEvents.getConnectionEnqueuedForAcceptCallback(), 0);
  ASSERT_EQ(connEvents.getConnectionDequeuedByAcceptCallback(), 0);
  ASSERT_EQ(connEvents.getBackoffStarted(), 0);
  ASSERT_EQ(connEvents.getBackoffEnded(), 0);
  ASSERT_EQ(connEvents.getBackoffError(), 0);
}
2395
TEST(AsyncSocketTest, CallbackInSecondaryEventBase) {
  EventBase eventBase;
  TestConnectionEventCallback connEvents;

  // Server socket that reports its connection events to connEvents.
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->setConnectionEventCallback(&connEvents);
  serverSocket->bind(0);
  serverSocket->listen(16);
  SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // The accept callback is registered on cobThread's EventBase. Its handlers
  // post the removal back to the server's EventBase.
  TestAcceptCallback acceptCb;
  ScopedEventBaseThread cobThread("ioworker_test");
  auto removeViaServerLoop = [&] {
    eventBase.runInEventBaseThread(
        [&] { serverSocket->removeAcceptCallback(&acceptCb, nullptr); });
  };
  acceptCb.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const SocketAddress& /* addr */) {
        removeViaServerLoop();
      });
  acceptCb.setAcceptErrorFn(
      [&](const std::exception& /* ex */) { removeViaServerLoop(); });
  std::atomic<bool> sawStart{false};
  acceptCb.setAcceptStartedFn([&]() { sawStart = true; });
  Baton<> stopped;
  acceptCb.setAcceptStoppedFn([&]() { stopped.post(); });
  serverSocket->addAcceptCallback(&acceptCb, cobThread.getEventBase());
  serverSocket->startAccepting();

  // Dial the server.
  std::shared_ptr<AsyncSocket> client(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  // acceptStopped fires on the worker thread, so wait for it (bounded).
  ASSERT_TRUE(stopped.try_wait_for(std::chrono::seconds(1)));
  ASSERT_TRUE(sawStart);
  // One accepted connection, handed off to the worker thread exactly once.
  ASSERT_EQ(connEvents.getConnectionAccepted(), 1);
  ASSERT_EQ(connEvents.getConnectionAcceptedError(), 0);
  ASSERT_EQ(connEvents.getConnectionDropped(), 0);
  ASSERT_EQ(connEvents.getConnectionEnqueuedForAcceptCallback(), 1);
  ASSERT_EQ(connEvents.getConnectionDequeuedByAcceptCallback(), 1);
  ASSERT_EQ(connEvents.getBackoffStarted(), 0);
  ASSERT_EQ(connEvents.getBackoffEnded(), 0);
  ASSERT_EQ(connEvents.getBackoffError(), 0);
}
2448
2449 /**
2450 * Test AsyncServerSocket::getNumPendingMessagesInQueue()
2451 */
TEST(AsyncSocketTest, NumPendingMessagesInQueue) {
  EventBase eventBase;

  // How many connections the accept callback has handled so far.
  int count = 0;

  // Listen on a kernel-chosen port.
  auto serverSocket(AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // The accept callback runs on cobThread. Queue inspection and removal are
  // both done on the server's EventBase.
  TestAcceptCallback acceptCb;
  folly::ScopedEventBaseThread cobThread("ioworker_test");
  auto stopAccepting = [&] {
    eventBase.runInEventBaseThread(
        [&] { serverSocket->removeAcceptCallback(&acceptCb, nullptr); });
  };
  acceptCb.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        ++count;
        // After handling `count` of the 4 connections, the test expects
        // exactly 4 - count messages still queued.
        eventBase.runInEventBaseThreadAndWait([&] {
          ASSERT_EQ(4 - count, serverSocket->getNumPendingMessagesInQueue());
        });
        if (count == 4) {
          stopAccepting();
        }
      });
  acceptCb.setAcceptErrorFn(
      [&](const std::exception& /* ex */) { stopAccepting(); });
  serverSocket->addAcceptCallback(&acceptCb, cobThread.getEventBase());
  serverSocket->startAccepting();

  // Four clients, hence four connections.
  auto client1(AsyncSocket::newSocket(&eventBase, serverAddress));
  auto client2(AsyncSocket::newSocket(&eventBase, serverAddress));
  auto client3(AsyncSocket::newSocket(&eventBase, serverAddress));
  auto client4(AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();
  ASSERT_EQ(4, count);
}
2496
2497 /**
2498 * Test AsyncTransport::BufferCallback
2499 */
TEST(AsyncSocketTest, BufferTest) {
  TestServer server;

  EventBase evb;
  // A tiny SO_SNDBUF, so the write below has to be buffered by AsyncSocket
  // (checked via the BufferCallback assertions).
  SocketOptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30, option);

  constexpr size_t kPayloadSize = 100 * 1024;
  char payload[kPayloadSize];
  memset(payload, 'c', kPayloadSize);
  WriteCallback wcb;
  BufferCallback bcb(socket.get(), kPayloadSize);
  socket->setBufferCallback(&bcb);
  socket->write(&wcb, payload, kPayloadSize, WriteFlags::NONE);

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  // The socket must have reported both "data buffered" and "buffer drained".
  ASSERT_TRUE(bcb.hasBuffered());
  ASSERT_TRUE(bcb.hasBufferCleared());

  socket->close();
  server.verifyConnection(payload, kPayloadSize);

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}
2529
TEST(AsyncSocketTest, BufferTestChain) {
  TestServer server;

  EventBase evb;
  // A tiny SO_SNDBUF, so the chained write has to be buffered by AsyncSocket.
  SocketOptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30, option);

  // Two 100 KiB segments with distinct fill bytes.
  constexpr size_t kSegmentSize = 100 * 1024;
  char first[kSegmentSize];
  memset(first, 'c', kSegmentSize);
  char second[kSegmentSize];
  memset(second, 'f', kSegmentSize);

  // Link them into a single two-element IOBuf chain.
  auto chain = folly::IOBuf::copyBuffer(first, kSegmentSize);
  chain->appendToChain(folly::IOBuf::copyBuffer(second, kSegmentSize));
  const size_t chainLength = chain->computeChainDataLength();
  ASSERT_EQ(2 * kSegmentSize, chainLength);

  BufferCallback bcb(socket.get(), chainLength);
  socket->setBufferCallback(&bcb);

  WriteCallback wcb;
  socket->writeChain(&wcb, chain->clone(), WriteFlags::NONE);

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  ASSERT_TRUE(bcb.hasBuffered());
  ASSERT_TRUE(bcb.hasBufferCleared());

  socket->close();
  // Flatten the original chain so it can be compared as one contiguous range.
  chain->coalesce();
  server.verifyConnection(
      reinterpret_cast<const char*>(chain->data()), chain->length());

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}
2569
TEST(AsyncSocketTest, BufferCallbackKill) {
  TestServer server;
  EventBase evb;
  // A tiny SO_SNDBUF so the write below gets buffered by AsyncSocket.
  SocketOptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30, option);
  evb.loopOnce();

  char buf[100 * 1024];
  memset(buf, 'c', sizeof(buf));
  BufferCallback bcb(socket.get(), sizeof(buf));
  socket->setBufferCallback(&bcb);
  WriteCallback wcb;
  wcb.successCallback = [&] {
    // The test must hold the only reference, so reset() really destroys the
    // socket from inside its own write callback. use_count() == 1 replaces
    // shared_ptr::unique(), which is deprecated in C++17 and removed in C++20.
    ASSERT_EQ(1, socket.use_count());
    socket.reset();
  };

  // This will trigger AsyncSocket::handleWrite,
  // which calls WriteCallback::writeSuccess,
  // which calls wcb.successCallback above,
  // which tries to delete socket
  // Then, the socket will also try to use this BufferCallback
  // And that should crash us, if there is no DestructorGuard on the stack
  socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
}
2600
2601 #if FOLLY_ALLOW_TFO
TEST(AsyncSocketTest, ConnectTFO) {
  if (!folly::test::isTFOAvailable()) {
    GTEST_SKIP() << "TFO not supported.";
  }

  // Local listener; the `true` argument presumably enables TFO on it
  // (confirm in TestServer).
  TestServer server(true);

  // Client socket with TCP Fast Open enabled.
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->enableTFO();
  ConnCallback connectCb;
  socket->connect(&connectCb, server.getAddress(), 30);

  // 128 bytes the server sends to the client.
  std::array<uint8_t, 128> serverPayload;
  serverPayload.fill('a');

  // 3 bytes the server reads back, and the client data it should receive.
  std::array<uint8_t, 3> received;
  auto clientPayload = IOBuf::copyBuffer("hey");

  // Blocking server side: accept, send the payload, read 3 bytes, close.
  std::thread serverThread([&] {
    auto peer = server.accept();
    peer->write(serverPayload.data(), serverPayload.size());
    peer->flush();
    peer->readAll(received.data(), received.size());
    peer->close();
  });

  evb.loop();

  ASSERT_EQ(connectCb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
  EXPECT_TRUE(socket->getTFOAttempted());

  // Should trigger the connect
  WriteCallback writeCb;
  ReadCallback readCb;
  socket->writeChain(&writeCb, clientPayload->clone());
  socket->setReadCB(&readCb);
  evb.loop();

  serverThread.join();

  EXPECT_EQ(STATE_SUCCEEDED, writeCb.state);
  EXPECT_EQ(
      0, memcmp(received.data(), clientPayload->data(), received.size()));
  EXPECT_EQ(STATE_SUCCEEDED, readCb.state);
  ASSERT_EQ(1, readCb.buffers.size());
  ASSERT_EQ(serverPayload.size(), readCb.buffers[0].length);
  EXPECT_EQ(
      0,
      memcmp(
          readCb.buffers[0].buffer,
          serverPayload.data(),
          serverPayload.size()));
  EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}
2655
TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
  if (!folly::test::isTFOAvailable()) {
    GTEST_SKIP() << "TFO not supported.";
  }

  // Local listener; the `true` argument presumably enables TFO on it
  // (confirm in TestServer).
  TestServer server(true);

  // Client socket with TCP Fast Open enabled. Unlike ConnectTFO, the read
  // callback is installed right after connect() and before any data is sent.
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->enableTFO();
  ConnCallback connectCb;
  socket->connect(&connectCb, server.getAddress(), 30);
  ReadCallback readCb;
  socket->setReadCB(&readCb);

  // 128 bytes the server sends to the client.
  std::array<uint8_t, 128> serverPayload;
  serverPayload.fill('a');

  // 3 bytes the server reads back, and the client data it should receive.
  std::array<uint8_t, 3> received;
  auto clientPayload = IOBuf::copyBuffer("hey");

  // Blocking server side: accept, send the payload, read 3 bytes, close.
  std::thread serverThread([&] {
    auto peer = server.accept();
    peer->write(serverPayload.data(), serverPayload.size());
    peer->flush();
    peer->readAll(received.data(), received.size());
    peer->close();
  });

  evb.loop();

  ASSERT_EQ(connectCb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
  EXPECT_TRUE(socket->getTFOAttempted());

  // Should trigger the connect
  WriteCallback writeCb;
  socket->writeChain(&writeCb, clientPayload->clone());
  evb.loop();

  serverThread.join();

  EXPECT_EQ(STATE_SUCCEEDED, writeCb.state);
  EXPECT_EQ(
      0, memcmp(received.data(), clientPayload->data(), received.size()));
  EXPECT_EQ(STATE_SUCCEEDED, readCb.state);
  ASSERT_EQ(1, readCb.buffers.size());
  ASSERT_EQ(serverPayload.size(), readCb.buffers[0].length);
  EXPECT_EQ(
      0,
      memcmp(
          readCb.buffers[0].buffer,
          serverPayload.data(),
          serverPayload.size()));
  EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}
2709
2710 /**
2711 * Test connecting to a server that isn't listening
2712 */
TEST(AsyncSocketTest,ConnectRefusedImmediatelyTFO)2713 TEST(AsyncSocketTest, ConnectRefusedImmediatelyTFO) {
2714 EventBase evb;
2715
2716 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2717
2718 socket->enableTFO();
2719
2720 // Hopefully nothing is actually listening on this address
2721 folly::SocketAddress addr("::1", 65535);
2722 ConnCallback cb;
2723 socket->connect(&cb, addr, 30);
2724
2725 evb.loop();
2726
2727 WriteCallback write1;
2728 // Trigger the connect if TFO attempt is supported.
2729 socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
2730 WriteCallback write2;
2731 socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
2732 evb.loop();
2733
2734 if (!socket->getTFOFinished()) {
2735 EXPECT_EQ(STATE_FAILED, write1.state);
2736 } else {
2737 EXPECT_EQ(STATE_SUCCEEDED, write1.state);
2738 EXPECT_FALSE(socket->getTFOSucceded());
2739 }
2740
2741 EXPECT_EQ(STATE_FAILED, write2.state);
2742
2743 EXPECT_EQ(STATE_SUCCEEDED, cb.state);
2744 EXPECT_LE(0, socket->getConnectTime().count());
2745 EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
2746 EXPECT_TRUE(socket->getTFOAttempted());
2747 }
2748
2749 /**
2750 * Test calling closeNow() immediately after connecting.
2751 */
TEST(AsyncSocketTest,ConnectWriteAndCloseNowTFO)2752 TEST(AsyncSocketTest, ConnectWriteAndCloseNowTFO) {
2753 TestServer server(true);
2754
2755 // connect()
2756 EventBase evb;
2757 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2758 socket->enableTFO();
2759
2760 ConnCallback ccb;
2761 socket->connect(&ccb, server.getAddress(), 30);
2762
2763 // write()
2764 std::array<char, 128> buf;
2765 memset(buf.data(), 'a', buf.size());
2766
2767 // close()
2768 socket->closeNow();
2769
2770 // Loop, although there shouldn't be anything to do.
2771 evb.loop();
2772
2773 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2774
2775 ASSERT_TRUE(socket->isClosedBySelf());
2776 ASSERT_FALSE(socket->isClosedByPeer());
2777 }
2778
2779 /**
2780 * Test calling close() immediately after connect()
2781 */
TEST(AsyncSocketTest,ConnectAndCloseTFO)2782 TEST(AsyncSocketTest, ConnectAndCloseTFO) {
2783 TestServer server(true);
2784
2785 // Connect using a AsyncSocket
2786 EventBase evb;
2787 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2788 socket->enableTFO();
2789
2790 ConnCallback ccb;
2791 socket->connect(&ccb, server.getAddress(), 30);
2792
2793 socket->close();
2794
2795 // Loop, although there shouldn't be anything to do.
2796 evb.loop();
2797
2798 // Make sure the connection was aborted
2799 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2800
2801 ASSERT_TRUE(socket->isClosedBySelf());
2802 ASSERT_FALSE(socket->isClosedByPeer());
2803 }
2804
// AsyncSocket whose tfoSendMsg() is a gmock method, so tests can script the
// result (return value and errno) of the TFO send.
// NOTE(review): presumably this overrides a virtual tfoSendMsg() hook in
// AsyncSocket — confirm against AsyncSocket.h.
class MockAsyncTFOSocket : public AsyncSocket {
 public:
  // Owning pointer that releases the socket through AsyncSocket's Destructor.
  using UniquePtr = std::unique_ptr<MockAsyncTFOSocket, Destructor>;

  explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}

  MOCK_METHOD3(
      tfoSendMsg, ssize_t(NetworkSocket fd, struct msghdr* msg, int msg_flags));
};
2814
/**
 * tfoSendMsg() fails with EOPNOTSUPP. The write must stay pending at first
 * and then complete, and data must flow in both directions.
 */
TEST(AsyncSocketTest, TestTFOUnsupported) {
  TestServer server(true);

  // Connect using an AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
  WriteCallback write;
  auto sendBuf = IOBuf::copyBuffer("hey");
  socket->writeChain(&write, sendBuf->clone());
  EXPECT_EQ(STATE_WAITING, write.state);

  // Payload the server sends to the client.
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());

  // Filled by the server with the client's payload.
  std::array<uint8_t, 3> readBuf;

  std::thread t([&] {
    std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
    acceptedSocket->write(buf.data(), buf.size());
    acceptedSocket->flush();
    acceptedSocket->readAll(readBuf.data(), readBuf.size());
    acceptedSocket->close();
  });

  evb.loop();

  t.join();
  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
  EXPECT_EQ(STATE_SUCCEEDED, write.state);

  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
  ASSERT_EQ(1, rcb.buffers.size());
  // Use the element count, not sizeof(std::array), for the byte length.
  ASSERT_EQ(buf.size(), rcb.buffers[0].length);
  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
  EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}
2863
/**
 * TFO connect to a (hopefully) closed port, with tfoSendMsg() mocked to do a
 * plain connect(). The refusal is expected to surface later, from the event
 * loop, failing both queued writes while the connect callback reports success.
 */
TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
  EventBase evb;

  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  // Hopefully nothing is actually listening on this address
  folly::SocketAddress fakeAddr("127.0.0.1", 65535);
  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
        sockaddr_storage addr;
        auto len = fakeAddr.getAddress(&addr);
        auto ret = netops::connect(
            fd, reinterpret_cast<const struct sockaddr*>(&addr), len);
        // Capture errno right away. Logging may clobber it, and AsyncSocket
        // inspects errno when tfoSendMsg() returns -1 (see the EOPNOTSUPP /
        // EAGAIN tests), so it must be restored before returning.
        const int savedErrno = errno;
        LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
                  << savedErrno;
        errno = savedErrno;
        return ret;
      }));

  ConnCallback cb;
  socket->connect(&cb, fakeAddr, 30);

  WriteCallback write1;
  // Trigger the connect if TFO attempt is supported.
  socket->writeChain(&write1, IOBuf::copyBuffer("hey"));

  if (socket->getTFOFinished()) {
    // TFO already finished on the first write, so the delayed-refusal path
    // this test targets was not exercised.
    GTEST_SKIP() << "TFO finished on the first write; "
                 << "delayed refusal not exercised.";
  }
  WriteCallback write2;
  // Queued behind write1 while the connect is still in flight.
  socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
  evb.loop();

  EXPECT_EQ(STATE_FAILED, write1.state);
  EXPECT_EQ(STATE_FAILED, write2.state);
  EXPECT_FALSE(socket->getTFOSucceded());

  EXPECT_EQ(STATE_SUCCEEDED, cb.state);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
  EXPECT_TRUE(socket->getTFOAttempted());
}
2908
/**
 * tfoSendMsg() fails with EOPNOTSUPP against an unresponsive address with a
 * 1 ms connect timeout; the write is expected to fail.
 */
TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
  // Try connecting to server that won't respond.
  //
  // This depends somewhat on the network where this test is run.
  // Hopefully this IP will be routable but unresponsive.
  // (Alternatively, we could try listening on a local raw socket, but that
  // normally requires root privileges.)
  auto host = SocketAddressTestHelper::isIPv6Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
      : SocketAddressTestHelper::isIPv4Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
      : nullptr;
  if (host == nullptr) {
    // Without any enabled address family there is no address to target;
    // constructing a SocketAddress from nullptr is not meaningful.
    GTEST_SKIP() << "Neither IPv6 nor IPv4 is enabled.";
  }
  SocketAddress addr(host, 65535);

  // Connect using an AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  // Set a very small timeout
  socket->connect(&ccb, addr, 1);
  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
  WriteCallback write;
  socket->writeChain(&write, IOBuf::copyBuffer("hey"));

  evb.loop();

  EXPECT_EQ(STATE_FAILED, write.state);
}
2945
/**
 * tfoSendMsg() is mocked to perform an ordinary connect() to the server.
 * The write starts out pending, then completes, and data flows both ways.
 */
TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
  TestServer server(true);

  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback connCb;
  socket->connect(&connCb, server.getAddress(), 30);
  ASSERT_EQ(connCb.state, STATE_SUCCEEDED);

  ReadCallback readCb;
  socket->setReadCB(&readCb);

  // Replace the TFO send with a regular connect to the test server.
  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
        sockaddr_storage storage;
        const auto storageLen = server.getAddress().getAddress(&storage);
        return netops::connect(
            fd, reinterpret_cast<const struct sockaddr*>(&storage), storageLen);
      }));
  auto clientPayload = IOBuf::copyBuffer("hey");
  WriteCallback writeCb;
  socket->writeChain(&writeCb, clientPayload->clone());
  EXPECT_EQ(STATE_WAITING, writeCb.state);

  // What the server sends back, and where it stores what it receives.
  std::array<uint8_t, 128> serverPayload;
  serverPayload.fill('a');
  std::array<uint8_t, 3> serverReceived;

  std::thread serverThread([&] {
    auto accepted = server.accept();
    accepted->write(serverPayload.data(), serverPayload.size());
    accepted->flush();
    accepted->readAll(serverReceived.data(), serverReceived.size());
    accepted->close();
  });

  evb.loop();
  serverThread.join();

  EXPECT_EQ(
      0,
      memcmp(
          serverReceived.data(),
          clientPayload->data(),
          serverReceived.size()));

  EXPECT_EQ(STATE_SUCCEEDED, connCb.state);
  EXPECT_EQ(STATE_SUCCEEDED, writeCb.state);

  EXPECT_EQ(STATE_SUCCEEDED, readCb.state);
  ASSERT_EQ(1, readCb.buffers.size());
  ASSERT_EQ(serverPayload.size(), readCb.buffers[0].length);
  EXPECT_EQ(
      0,
      memcmp(
          readCb.buffers[0].buffer,
          serverPayload.data(),
          serverPayload.size()));
}
2998
/**
 * tfoSendMsg() is mocked to do a plain connect() to an unresponsive address
 * with a 1 ms connect timeout; the write is expected to fail.
 */
TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
  // Try connecting to server that won't respond.
  //
  // This depends somewhat on the network where this test is run.
  // Hopefully this IP will be routable but unresponsive.
  // (Alternatively, we could try listening on a local raw socket, but that
  // normally requires root privileges.)
  auto host = SocketAddressTestHelper::isIPv6Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
      : SocketAddressTestHelper::isIPv4Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
      : nullptr;
  if (host == nullptr) {
    // Without any enabled address family there is no address to target;
    // constructing a SocketAddress from nullptr is not meaningful.
    GTEST_SKIP() << "Neither IPv6 nor IPv4 is enabled.";
  }
  SocketAddress addr(host, 65535);

  // Connect using an AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  // Set a very small timeout
  socket->connect(&ccb, addr, 1);
  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
        sockaddr_storage addr2;
        auto len = addr.getAddress(&addr2);
        return netops::connect(
            fd, reinterpret_cast<const struct sockaddr*>(&addr2), len);
      }));
  WriteCallback write;
  socket->writeChain(&write, IOBuf::copyBuffer("hey"));

  evb.loop();

  EXPECT_EQ(STATE_FAILED, write.state);
}
3039
/**
 * A TFO send that fails with EAGAIN must fail the pending write, while the
 * connect callback still reports success.
 */
TEST(AsyncSocketTest, TestTFOEagain) {
  TestServer server(true);

  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback connCb;
  socket->connect(&connCb, server.getAddress(), 30);

  // Script the TFO send to fail with EAGAIN exactly once.
  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(SetErrnoAndReturn(EAGAIN, -1));
  WriteCallback writeCb;
  socket->writeChain(&writeCb, IOBuf::copyBuffer("hey"));

  evb.loop();

  EXPECT_EQ(STATE_SUCCEEDED, connCb.state);
  EXPECT_EQ(STATE_FAILED, writeCb.state);
}
3061
3062 // Sending a large amount of data in the first write which will
3063 // definitely not fit into MSS.
TEST(AsyncSocketTest,ConnectTFOWithBigData)3064 TEST(AsyncSocketTest, ConnectTFOWithBigData) {
3065 if (!folly::test::isTFOAvailable()) {
3066 GTEST_SKIP() << "TFO not supported.";
3067 }
3068
3069 // Start listening on a local port
3070 TestServer server(true);
3071
3072 // Connect using a AsyncSocket
3073 EventBase evb;
3074 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3075 socket->enableTFO();
3076 ConnCallback cb;
3077 socket->connect(&cb, server.getAddress(), 30);
3078
3079 std::array<uint8_t, 128> buf;
3080 memset(buf.data(), 'a', buf.size());
3081
3082 constexpr size_t len = 10 * 1024;
3083 auto sendBuf = IOBuf::create(len);
3084 sendBuf->append(len);
3085 std::array<uint8_t, len> readBuf;
3086
3087 std::thread t([&] {
3088 auto acceptedSocket = server.accept();
3089 acceptedSocket->write(buf.data(), buf.size());
3090 acceptedSocket->flush();
3091 acceptedSocket->readAll(readBuf.data(), readBuf.size());
3092 acceptedSocket->close();
3093 });
3094
3095 evb.loop();
3096
3097 ASSERT_EQ(cb.state, STATE_SUCCEEDED);
3098 EXPECT_LE(0, socket->getConnectTime().count());
3099 EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
3100 EXPECT_TRUE(socket->getTFOAttempted());
3101
3102 // Should trigger the connect
3103 WriteCallback write;
3104 ReadCallback rcb;
3105 socket->writeChain(&write, sendBuf->clone());
3106 socket->setReadCB(&rcb);
3107 evb.loop();
3108
3109 t.join();
3110
3111 EXPECT_EQ(STATE_SUCCEEDED, write.state);
3112 EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
3113 EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
3114 ASSERT_EQ(1, rcb.buffers.size());
3115 ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
3116 EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
3117 EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
3118 }
3119
3120 #endif // FOLLY_ALLOW_TFO
3121
// gmock implementation of AsyncSocket::EvbChangeCallback. The tests below use
// it to check that detachEventBase()/attachEventBase() notify the callback.
class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
 public:
  MOCK_METHOD1(evbAttached, void(AsyncSocket*));
  MOCK_METHOD1(evbDetached, void(AsyncSocket*));
};
3127
/**
 * Detaching and then re-attaching the EventBase must call evbDetached() and
 * then evbAttached() exactly once each, in that order.
 */
TEST(AsyncSocketTest, EvbCallbacks) {
  auto evbChangeCb = std::make_unique<MockEvbChangeCallback>();
  EventBase evb;
  auto socket = AsyncSocket::newSocket(&evb);

  // Enforce detach-before-attach ordering.
  InSequence inOrder;
  EXPECT_CALL(*evbChangeCb, evbDetached(socket.get())).Times(1);
  EXPECT_CALL(*evbChangeCb, evbAttached(socket.get())).Times(1);

  // The socket takes ownership of the callback.
  socket->setEvbChangedCallback(std::move(evbChangeCb));
  socket->detachEventBase();
  socket->attachEventBase(&evb);
}
3141
/**
 * A connected socket with a read callback installed must still be detachable.
 * Detach and attach must notify the EvbChangeCallback, detach first.
 */
TEST(AsyncSocketTest, TestEvbDetachWtRegisteredIOHandlers) {
  // Local server to connect to.
  TestServer server;

  EventBase evb;
  auto socket = AsyncSocket::newSocket(&evb);
  ConnCallback connCb;
  socket->connect(&connCb, server.getAddress(), 30);

  evb.loop();

  ASSERT_EQ(connCb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));

  // Installing a read callback registers I/O handlers; detach/attach must
  // keep working afterwards.
  ReadCallback readCb;
  socket->setReadCB(&readCb);

  auto evbChangeCb = std::make_unique<MockEvbChangeCallback>();
  InSequence inOrder;
  EXPECT_CALL(*evbChangeCb, evbDetached(socket.get())).Times(1);
  EXPECT_CALL(*evbChangeCb, evbAttached(socket.get())).Times(1);

  socket->setEvbChangedCallback(std::move(evbChangeCb));
  EXPECT_TRUE(socket->isDetachable());
  socket->detachEventBase();
  socket->attachEventBase(&evb);

  socket->close();
}
3174
/**
 * After detaching its EventBase, a connected socket with a read callback
 * installed must be destroyable without an attached EventBase.
 */
TEST(AsyncSocketTest, TestEvbDetachThenClose) {
  // Local server to connect to.
  TestServer server;

  EventBase evb;
  auto socket = AsyncSocket::newSocket(&evb);
  ConnCallback connCb;
  socket->connect(&connCb, server.getAddress(), 30);

  evb.loop();

  ASSERT_EQ(connCb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));

  // Register I/O handlers by installing a read callback.
  ReadCallback readCb;
  socket->setReadCB(&readCb);

  auto evbChangeCb = std::make_unique<MockEvbChangeCallback>();
  InSequence inOrder;
  EXPECT_CALL(*evbChangeCb, evbDetached(socket.get())).Times(1);

  socket->setEvbChangedCallback(std::move(evbChangeCb));

  // Should be possible to destroy/call closeNow() without an attached EventBase
  EXPECT_TRUE(socket->isDetachable());
  socket->detachEventBase();
  socket.reset();
}
3206
/**
 * Raw and app byte counters must carry over when an AsyncSocket is
 * move-constructed into a new AsyncSocket.
 */
TEST(AsyncSocket, BytesWrittenWithMove) {
  TestServer server;

  EventBase evb;
  auto original = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  ConnCallback connCb;
  original->connect(&connCb, server.getAddress(), 30);
  auto serverSide = server.accept();

  EXPECT_EQ(0, original->getRawBytesWritten());

  const std::vector<uint8_t> payload(128, 'a');
  WriteCallback writeCb;
  original->write(&writeCb, payload.data(), payload.size());
  evb.loopOnce();
  ASSERT_EQ(writeCb.state, STATE_SUCCEEDED);
  EXPECT_EQ(payload.size(), original->getRawBytesWritten());
  EXPECT_EQ(payload.size(), original->getAppBytesWritten());

  auto moved = AsyncSocket::UniquePtr(new AsyncSocket(std::move(original)));
  EXPECT_EQ(payload.size(), moved->getRawBytesWritten());
  EXPECT_EQ(payload.size(), moved->getAppBytesWritten());
}
3229
3230 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
3231 struct AsyncSocketErrMessageCallbackTestParams {
3232 folly::Optional<int> resetCallbackAfter;
3233 folly::Optional<int> closeSocketAfter;
3234 int gotTimestampExpected{0};
3235 int gotByteSeqExpected{0};
3236 };
3237
3238 class AsyncSocketErrMessageCallbackTest
3239 : public ::testing::TestWithParam<AsyncSocketErrMessageCallbackTestParams> {
3240 public:
3241 static std::vector<AsyncSocketErrMessageCallbackTestParams>
getTestingValues()3242 getTestingValues() {
3243 std::vector<AsyncSocketErrMessageCallbackTestParams> vals;
3244 // each socket err message triggers two socket callbacks:
3245 // (1) timestamp callback
3246 // (2) byteseq callback
3247
3248 // reset callback cases
3249 // resetting the callback should prevent any further callbacks
3250 {
3251 AsyncSocketErrMessageCallbackTestParams params;
3252 params.resetCallbackAfter = 1;
3253 params.gotTimestampExpected = 1;
3254 params.gotByteSeqExpected = 0;
3255 vals.push_back(params);
3256 }
3257 {
3258 AsyncSocketErrMessageCallbackTestParams params;
3259 params.resetCallbackAfter = 2;
3260 params.gotTimestampExpected = 1;
3261 params.gotByteSeqExpected = 1;
3262 vals.push_back(params);
3263 }
3264 {
3265 AsyncSocketErrMessageCallbackTestParams params;
3266 params.resetCallbackAfter = 3;
3267 params.gotTimestampExpected = 2;
3268 params.gotByteSeqExpected = 1;
3269 vals.push_back(params);
3270 }
3271 {
3272 AsyncSocketErrMessageCallbackTestParams params;
3273 params.resetCallbackAfter = 4;
3274 params.gotTimestampExpected = 2;
3275 params.gotByteSeqExpected = 2;
3276 vals.push_back(params);
3277 }
3278
3279 // close socket cases
3280 // closing the socket will prevent callbacks after the current err message
3281 // callbacks (both timestamp and byteseq) are completed
3282 {
3283 AsyncSocketErrMessageCallbackTestParams params;
3284 params.closeSocketAfter = 1;
3285 params.gotTimestampExpected = 1;
3286 params.gotByteSeqExpected = 1;
3287 vals.push_back(params);
3288 }
3289 {
3290 AsyncSocketErrMessageCallbackTestParams params;
3291 params.closeSocketAfter = 2;
3292 params.gotTimestampExpected = 1;
3293 params.gotByteSeqExpected = 1;
3294 vals.push_back(params);
3295 }
3296 {
3297 AsyncSocketErrMessageCallbackTestParams params;
3298 params.closeSocketAfter = 3;
3299 params.gotTimestampExpected = 2;
3300 params.gotByteSeqExpected = 2;
3301 vals.push_back(params);
3302 }
3303 {
3304 AsyncSocketErrMessageCallbackTestParams params;
3305 params.closeSocketAfter = 4;
3306 params.gotTimestampExpected = 2;
3307 params.gotByteSeqExpected = 2;
3308 vals.push_back(params);
3309 }
3310 return vals;
3311 }
3312 };
3313
// Run the ErrMessageCallback test once per parameter set returned by
// AsyncSocketErrMessageCallbackTest::getTestingValues().
INSTANTIATE_TEST_SUITE_P(
    ErrMessageTests,
    AsyncSocketErrMessageCallbackTest,
    ::testing::ValuesIn(AsyncSocketErrMessageCallbackTest::getTestingValues()));
3318
3319 class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
3320 public:
TestErrMessageCallback()3321 TestErrMessageCallback()
3322 : exception_(folly::AsyncSocketException::UNKNOWN, "none") {}
3323
errMessage(const cmsghdr & cmsg)3324 void errMessage(const cmsghdr& cmsg) noexcept override {
3325 if (cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_TIMESTAMPING) {
3326 gotTimestamp_++;
3327 checkResetCallback();
3328 checkCloseSocket();
3329 } else if (
3330 (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
3331 (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
3332 gotByteSeq_++;
3333 checkResetCallback();
3334 checkCloseSocket();
3335 }
3336 }
3337
errMessageError(const folly::AsyncSocketException & ex)3338 void errMessageError(
3339 const folly::AsyncSocketException& ex) noexcept override {
3340 exception_ = ex;
3341 }
3342
checkResetCallback()3343 void checkResetCallback() noexcept {
3344 if (socket_ != nullptr && resetCallbackAfter_ != -1 &&
3345 gotTimestamp_ + gotByteSeq_ == resetCallbackAfter_) {
3346 socket_->setErrMessageCB(nullptr);
3347 }
3348 }
3349
checkCloseSocket()3350 void checkCloseSocket() noexcept {
3351 if (socket_ != nullptr && closeSocketAfter_ != -1 &&
3352 gotTimestamp_ + gotByteSeq_ == closeSocketAfter_) {
3353 socket_->close();
3354 }
3355 }
3356
3357 folly::AsyncSocket* socket_{nullptr};
3358 folly::AsyncSocketException exception_;
3359 int gotTimestamp_{0};
3360 int gotByteSeq_{0};
3361 int resetCallbackAfter_{-1};
3362 int closeSocketAfter_{-1};
3363 };
3364
/**
 * Enables SO_TIMESTAMPING on a connected client socket, writes two packets,
 * and checks how many timestamp / byteseq error-message callbacks arrive.
 * Delivery is stopped after a parameterized number of callbacks, either by
 * resetting the callback or by closing the socket.
 */
TEST_P(AsyncSocketErrMessageCallbackTest, ErrMessageCallback) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);
  LOG(INFO) << "Client socket fd=" << socket->getNetworkSocket();

  // Let the socket connect.
  evb.loop();

  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  // Set read callback to keep the socket subscribed for event
  // notifications, though we're not planning to read anything from
  // this side of the connection.
  ReadCallback rcb(1);
  socket->setReadCB(&rcb);

  // Set up error message (timestamp / byteseq) callbacks.
  TestErrMessageCallback errMsgCB;
  socket->setErrMessageCB(&errMsgCB);
  ASSERT_EQ(
      socket->getErrMessageCallback(),
      static_cast<folly::AsyncSocket::ErrMessageCallback*>(&errMsgCB));

  // Set the number of error messages before the socket is closed or the
  // callback is reset.
  const auto& testParams = GetParam();
  errMsgCB.socket_ = socket.get();
  if (testParams.resetCallbackAfter.has_value()) {
    errMsgCB.resetCallbackAfter_ = testParams.resetCallbackAfter.value();
  }
  if (testParams.closeSocketAfter.has_value()) {
    errMsgCB.closeSocketAfter_ = testParams.closeSocketAfter.value();
  }

  // Enable timestamp notifications
  ASSERT_NE(socket->getNetworkSocket(), NetworkSocket());
  int flags = folly::netops::SOF_TIMESTAMPING_OPT_ID |
      folly::netops::SOF_TIMESTAMPING_OPT_TSONLY |
      folly::netops::SOF_TIMESTAMPING_SOFTWARE |
      folly::netops::SOF_TIMESTAMPING_OPT_CMSG |
      folly::netops::SOF_TIMESTAMPING_TX_SCHED;
  SocketOptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
  EXPECT_EQ(tstampingOpt.apply(socket->getNetworkSocket(), flags), 0);

  // write()
  std::vector<uint8_t> wbuf(128, 'a');
  WriteCallback wcb;
  // Send two packets to get two EOM notifications
  socket->write(&wcb, wbuf.data(), wbuf.size() / 2);
  socket->write(&wcb, wbuf.data() + wbuf.size() / 2, wbuf.size() / 2);

  // Accept the connection.
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
  LOG(INFO) << "Server socket fd=" << acceptedSocket->getNetworkSocket();

  // Loop
  evb.loopOnce();
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  // Check that we can read the data that was written to the socket
  std::vector<uint8_t> rbuf(wbuf.size(), 0);
  uint32_t bytesRead = acceptedSocket->readAll(rbuf.data(), rbuf.size());
  ASSERT_EQ(bytesRead, wbuf.size());
  ASSERT_TRUE(std::equal(wbuf.begin(), wbuf.end(), rbuf.begin()));

  // Close both sockets
  acceptedSocket->close();
  socket->close();

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());

  // Check for the timestamp notifications.
  ASSERT_EQ(
      errMsgCB.exception_.getType(), folly::AsyncSocketException::UNKNOWN);
  ASSERT_EQ(errMsgCB.gotByteSeq_, testParams.gotByteSeqExpected);
  ASSERT_EQ(errMsgCB.gotTimestamp_, testParams.gotTimestampExpected);
}
3448
3449 #endif // FOLLY_HAVE_MSG_ERRQUEUE
3450
3451 #if FOLLY_HAVE_SO_TIMESTAMPING
3452
3453 class AsyncSocketByteEventTest : public ::testing::Test {
3454 protected:
3455 using MockDispatcher = ::testing::NiceMock<netops::test::MockDispatcher>;
3456 using TestObserver = MockAsyncTransportObserverForByteEvents;
3457 using ByteEventType = AsyncTransport::ByteEvent::Type;
3458
3459 /**
3460 * Components of a client connection to TestServer.
3461 *
3462 * Includes EventBase, client's AsyncSocket, and corresponding server socket.
3463 */
3464 class ClientConn {
3465 public:
3466 /**
3467 * Call to sendmsg intercepted and recorded by netops::Dispatcher.
3468 */
3469 struct SendmsgInvocation {
3470 // the iovecs in the msghdr
3471 std::vector<iovec> iovs;
3472
3473 // WriteFlags encoded in msg_flags
3474 WriteFlags writeFlagsInMsgFlags{WriteFlags::NONE};
3475
3476 // WriteFlags encoded in the msghdr's ancillary data
3477 WriteFlags writeFlagsInAncillary{WriteFlags::NONE};
3478 };
3479
ClientConn(std::shared_ptr<TestServer> server,std::shared_ptr<AsyncSocket> socket=nullptr,std::shared_ptr<BlockingSocket> acceptedSocket=nullptr)3480 explicit ClientConn(
3481 std::shared_ptr<TestServer> server,
3482 std::shared_ptr<AsyncSocket> socket = nullptr,
3483 std::shared_ptr<BlockingSocket> acceptedSocket = nullptr)
3484 : server_(std::move(server)),
3485 socket_(std::move(socket)),
3486 acceptedSocket_(std::move(acceptedSocket)) {
3487 if (!socket_) {
3488 socket_ = AsyncSocket::newSocket(&getEventBase());
3489 } else {
3490 setReadCb();
3491 }
3492 socket_->setOverrideNetOpsDispatcher(netOpsDispatcher_);
3493 netOpsDispatcher_->forwardToDefaultImpl();
3494 }
3495
connect()3496 void connect() {
3497 CHECK_NOTNULL(socket_.get());
3498 CHECK_NOTNULL(socket_->getEventBase());
3499 socket_->connect(&connCb_, server_->getAddress(), 30);
3500 socket_->getEventBase()->loop();
3501 ASSERT_EQ(connCb_.state, STATE_SUCCEEDED);
3502 setReadCb();
3503
3504 // accept the socket at the server
3505 acceptedSocket_ = server_->accept();
3506 }
3507
setReadCb()3508 void setReadCb() {
3509 // Due to how libevent works, we currently need to be subscribed to
3510 // EV_READ events in order to get error messages.
3511 //
3512 // TODO(bschlinker): Resolve this with libevent modification.
3513 // See https://github.com/libevent/libevent/issues/1038 for details.
3514 socket_->setReadCB(&readCb_);
3515 }
3516
attachObserver(bool enableByteEvents,bool enablePrewrite=false)3517 std::shared_ptr<NiceMock<TestObserver>> attachObserver(
3518 bool enableByteEvents, bool enablePrewrite = false) {
3519 auto observer = AsyncSocketByteEventTest::attachObserver(
3520 socket_.get(), enableByteEvents, enablePrewrite);
3521 observers_.push_back(observer);
3522 return observer;
3523 }
3524
3525 /**
3526 * Write to client socket and read at server.
3527 */
write(const iovec * iov,const size_t count,const WriteFlags writeFlags)3528 void write(
3529 const iovec* iov, const size_t count, const WriteFlags writeFlags) {
3530 CHECK_NOTNULL(socket_.get());
3531 CHECK_NOTNULL(socket_->getEventBase());
3532
3533 // read buffer for server
3534 std::vector<uint8_t> rbuf(iovsToNumBytes(iov, count), 0);
3535 uint64_t rbufReadBytes = 0;
3536
3537 // write to the client socket, incrementally read at the server
3538 WriteCallback wcb;
3539 socket_->writev(&wcb, iov, count, writeFlags);
3540 while (wcb.state == STATE_WAITING) {
3541 socket_->getEventBase()->loopOnce();
3542 rbufReadBytes += acceptedSocket_->readNoBlock(
3543 rbuf.data() + rbufReadBytes, rbuf.size() - rbufReadBytes);
3544 }
3545 ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3546
3547 // finish reading, then compare
3548 rbufReadBytes += acceptedSocket_->readAll(
3549 rbuf.data() + rbufReadBytes, rbuf.size() - rbufReadBytes);
3550 const auto cBuf = iovsToVector(iov, count);
3551 ASSERT_EQ(rbufReadBytes, cBuf.size());
3552 ASSERT_TRUE(std::equal(cBuf.begin(), cBuf.end(), rbuf.begin()));
3553 }
3554
3555 /**
3556 * Write to client socket and read at server.
3557 */
write(const std::vector<uint8_t> & wbuf,const WriteFlags writeFlags)3558 void write(const std::vector<uint8_t>& wbuf, const WriteFlags writeFlags) {
3559 iovec op;
3560 op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
3561 op.iov_len = wbuf.size();
3562 write(&op, 1, writeFlags);
3563 }
3564
3565 /**
3566 * Write to client socket, echo at server, and wait for echo at client.
3567 *
3568 * Waiting for echo at client ensures that we have given opportunity for
3569 * timestamps to be generated by the kernel.
3570 */
writeAndReflect(const iovec * iov,const size_t count,const WriteFlags writeFlags)3571 void writeAndReflect(
3572 const iovec* iov, const size_t count, const WriteFlags writeFlags) {
3573 write(iov, count, writeFlags);
3574
3575 // reflect
3576 const auto wbuf = iovsToVector(iov, count);
3577 acceptedSocket_->write(wbuf.data(), wbuf.size());
3578 while (wbuf.size() != readCb_.dataRead()) {
3579 socket_->getEventBase()->loopOnce();
3580 }
3581 readCb_.verifyData(wbuf.data(), wbuf.size());
3582 readCb_.clearData();
3583 }
3584
3585 /**
3586 * Write to client socket, echo at server, and wait for echo at client.
3587 *
3588 * Waiting for echo at client ensures that we have given opportunity for
3589 * timestamps to be generated by the kernel.
3590 */
writeAndReflect(const std::vector<uint8_t> & wbuf,const WriteFlags writeFlags)3591 void writeAndReflect(
3592 const std::vector<uint8_t>& wbuf, const WriteFlags writeFlags) {
3593 iovec op = {};
3594 op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
3595 op.iov_len = wbuf.size();
3596 writeAndReflect(&op, 1, writeFlags);
3597 }
3598
getRawSocket()3599 std::shared_ptr<AsyncSocket> getRawSocket() { return socket_; }
3600
getAcceptedSocket()3601 std::shared_ptr<BlockingSocket> getAcceptedSocket() {
3602 return acceptedSocket_;
3603 }
3604
getEventBase()3605 EventBase& getEventBase() {
3606 static EventBase evb; // use same EventBase for all client sockets
3607 return evb;
3608 }
3609
getNetOpsDispatcher() const3610 std::shared_ptr<MockDispatcher> getNetOpsDispatcher() const {
3611 return netOpsDispatcher_;
3612 }
3613
3614 /**
3615 * Get recorded SendmsgInvocations.
3616 */
getSendmsgInvocations()3617 const std::vector<SendmsgInvocation>& getSendmsgInvocations() {
3618 return sendmsgInvocations_;
3619 }
3620
3621 /**
3622 * Expect a call to setsockopt with optname SO_TIMESTAMPING.
3623 */
netOpsExpectTimestampingSetSockOpt()3624 void netOpsExpectTimestampingSetSockOpt() {
3625 // must whitelist other calls
3626 EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
3627 .Times(AnyNumber());
3628 EXPECT_CALL(
3629 *netOpsDispatcher_, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
3630 .Times(1);
3631 }
3632
3633 /**
3634 * Expect NO calls to setsockopt with optname SO_TIMESTAMPING.
3635 */
netOpsExpectNoTimestampingSetSockOpt()3636 void netOpsExpectNoTimestampingSetSockOpt() {
3637 // must whitelist other calls
3638 EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
3639 .Times(AnyNumber());
3640 EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, SO_TIMESTAMPING, _, _))
3641 .Times(0);
3642 }
3643
3644 /**
3645 * Expect sendmsg to be called with the passed WriteFlags in ancillary data.
3646 */
netOpsExpectSendmsgWithAncillaryTsFlags(WriteFlags writeFlags)3647 void netOpsExpectSendmsgWithAncillaryTsFlags(WriteFlags writeFlags) {
3648 auto getMsgAncillaryTsFlags = std::bind(
3649 (WriteFlags(*)(const struct msghdr* msg)) & ::getMsgAncillaryTsFlags,
3650 std::placeholders::_1);
3651 EXPECT_CALL(
3652 *netOpsDispatcher_,
3653 sendmsg(_, ResultOf(getMsgAncillaryTsFlags, Eq(writeFlags)), _))
3654 .WillOnce(DoDefault());
3655 }
3656
3657 /**
3658 * When sendmsg is called, record details and then forward to real sendmsg.
3659 *
3660 * This creates a default action.
3661 */
netOpsOnSendmsgRecordIovecsAndFlagsAndFwd()3662 void netOpsOnSendmsgRecordIovecsAndFlagsAndFwd() {
3663 ON_CALL(*netOpsDispatcher_, sendmsg(_, _, _))
3664 .WillByDefault(::testing::Invoke(
3665 [this](NetworkSocket s, const msghdr* message, int flags) {
3666 recordSendmsgInvocation(s, message, flags);
3667 return netops::Dispatcher::getDefaultInstance()->sendmsg(
3668 s, message, flags);
3669 }));
3670 }
3671
netOpsVerifyAndClearExpectations()3672 void netOpsVerifyAndClearExpectations() {
3673 Mock::VerifyAndClearExpectations(netOpsDispatcher_.get());
3674 }
3675
3676 private:
recordSendmsgInvocation(NetworkSocket,const msghdr * message,int flags)3677 void recordSendmsgInvocation(
3678 NetworkSocket /* s */, const msghdr* message, int flags) {
3679 SendmsgInvocation invoc = {};
3680 invoc.iovs = getMsgIovecs(message);
3681 invoc.writeFlagsInMsgFlags = msgFlagsToWriteFlags(flags);
3682 invoc.writeFlagsInAncillary = getMsgAncillaryTsFlags(message);
3683 sendmsgInvocations_.emplace_back(std::move(invoc));
3684 }
3685
3686 // server
3687 std::shared_ptr<TestServer> server_;
3688
3689 // managed observers
3690 std::vector<std::shared_ptr<TestObserver>> observers_;
3691
3692 // socket components
3693 ConnCallback connCb_;
3694 ReadCallback readCb_;
3695 std::shared_ptr<MockDispatcher> netOpsDispatcher_{
3696 std::make_shared<MockDispatcher>()};
3697 std::shared_ptr<AsyncSocket> socket_;
3698
3699 // accepted socket at server
3700 std::shared_ptr<BlockingSocket> acceptedSocket_;
3701
3702 // sendmsg invocations observed
3703 std::vector<SendmsgInvocation> sendmsgInvocations_;
3704 };
3705
getClientConn()3706 ClientConn getClientConn() { return ClientConn(server_); }
3707
3708 /**
3709 * Static utility functions.
3710 */
3711
attachObserver(AsyncSocket * socket,bool enableByteEvents,bool enablePrewrite=false)3712 static std::shared_ptr<NiceMock<TestObserver>> attachObserver(
3713 AsyncSocket* socket, bool enableByteEvents, bool enablePrewrite = false) {
3714 AsyncTransport::LifecycleObserver::Config config = {};
3715 config.byteEvents = enableByteEvents;
3716 config.prewrite = enablePrewrite;
3717 return std::make_shared<NiceMock<TestObserver>>(socket, config);
3718 }
3719
getHundredBytesOfData()3720 static std::vector<uint8_t> getHundredBytesOfData() {
3721 return std::vector<uint8_t>(
3722 kOneHundredCharacterString.begin(), kOneHundredCharacterString.end());
3723 }
3724
get10KBOfData()3725 static std::vector<uint8_t> get10KBOfData() {
3726 std::vector<uint8_t> vec;
3727 vec.reserve(kOneHundredCharacterString.size() * 100);
3728 for (auto i = 0; i < 100; i++) {
3729 vec.insert(
3730 vec.end(),
3731 kOneHundredCharacterString.begin(),
3732 kOneHundredCharacterString.end());
3733 }
3734 CHECK_EQ(10000, vec.size());
3735 return vec;
3736 }
3737
get1000KBOfData()3738 static std::vector<uint8_t> get1000KBOfData() {
3739 std::vector<uint8_t> vec;
3740 vec.reserve(kOneHundredCharacterString.size() * 10000);
3741 for (auto i = 0; i < 10000; i++) {
3742 vec.insert(
3743 vec.end(),
3744 kOneHundredCharacterString.begin(),
3745 kOneHundredCharacterString.end());
3746 }
3747 CHECK_EQ(1000000, vec.size());
3748 return vec;
3749 }
3750
dropWriteFromFlags(WriteFlags writeFlags)3751 static WriteFlags dropWriteFromFlags(WriteFlags writeFlags) {
3752 return writeFlags & ~WriteFlags::TIMESTAMP_WRITE;
3753 }
3754
/**
 * Copy the iovec array referenced by msg (msg_iov[0, msg_iovlen)) into a
 * vector. The iovecs are copied shallowly; the data they point at is not.
 */
static std::vector<iovec> getMsgIovecs(const struct msghdr& msg) {
  const iovec* const first = msg.msg_iov;
  return std::vector<iovec>(first, first + msg.msg_iovlen);
}

/**
 * Pointer convenience overload of getMsgIovecs(); msg must be non-null.
 */
static std::vector<iovec> getMsgIovecs(const struct msghdr* msg) {
  const struct msghdr& ref = *msg;
  return getMsgIovecs(ref);
}
3766
/**
 * Concatenate the bytes referenced by iov[0, count) into a single vector.
 *
 * Entries with iov_len == 0 are skipped without touching their iov_base.
 */
static std::vector<uint8_t> iovsToVector(
    const iovec* iov, const size_t count) {
  // Size the result up front so the appends below never reallocate.
  size_t totalBytes = 0;
  for (size_t i = 0; i < count; i++) {
    totalBytes += iov[i].iov_len;
  }

  std::vector<uint8_t> vec;
  vec.reserve(totalBytes);
  for (size_t i = 0; i < count; i++) {
    if (iov[i].iov_len == 0) {
      continue;
    }
    // void* -> const uint8_t* only needs static_cast, not reinterpret_cast.
    const auto* ptr = static_cast<const uint8_t*>(iov[i].iov_base);
    vec.insert(vec.end(), ptr, ptr + iov[i].iov_len);
  }
  return vec;
}
3779
/**
 * Returns the total number of bytes described by iov[0, count), i.e. the sum
 * of their iov_len fields.
 */
static size_t iovsToNumBytes(const iovec* iov, const size_t count) {
  size_t total = 0;
  for (const iovec* entry = iov; entry != iov + count; ++entry) {
    total += entry->iov_len;
  }
  return total;
}
3787
filterToWriteEvents(const std::vector<AsyncTransport::ByteEvent> & input)3788 std::vector<AsyncTransport::ByteEvent> filterToWriteEvents(
3789 const std::vector<AsyncTransport::ByteEvent>& input) {
3790 std::vector<AsyncTransport::ByteEvent> result;
3791 std::copy_if(
3792 input.begin(),
3793 input.end(),
3794 std::back_inserter(result),
3795 [](auto& event) {
3796 return event.type == AsyncTransport::ByteEvent::WRITE;
3797 });
3798 return result;
3799 }
3800
3801 // server
3802 std::shared_ptr<TestServer> server_{std::make_shared<TestServer>()};
3803 };
3804
/**
 * msgFlagsToWriteFlags() maps each MSG_* flag (and a combination) onto the
 * matching WriteFlags. Each case is compiled only where the platform defines
 * the MSG_* constant.
 */
TEST_F(AsyncSocketByteEventTest, MsgFlagsToWriteFlags) {
#ifdef MSG_ZEROCOPY
  EXPECT_EQ(WriteFlags::WRITE_MSG_ZEROCOPY, msgFlagsToWriteFlags(MSG_ZEROCOPY));
#endif // MSG_ZEROCOPY

#ifdef MSG_MORE
  EXPECT_EQ(WriteFlags::CORK, msgFlagsToWriteFlags(MSG_MORE));
#endif // MSG_MORE

#ifdef MSG_EOR
  EXPECT_EQ(WriteFlags::EOR, msgFlagsToWriteFlags(MSG_EOR));
#endif // MSG_EOR

#if defined(MSG_MORE) && defined(MSG_EOR)
  // combined msg flags map to the union of the individual WriteFlags
  const auto combined = msgFlagsToWriteFlags(MSG_MORE | MSG_EOR);
  EXPECT_EQ(WriteFlags::CORK | WriteFlags::EOR, combined);
#endif // MSG_MORE && MSG_EOR
}
3824
/**
 * getMsgAncillaryTsFlags() decodes a SOL_SOCKET/SO_TIMESTAMPING control
 * message into the corresponding TIMESTAMP_* WriteFlags.
 */
TEST_F(AsyncSocketByteEventTest, GetMsgAncillaryTsFlags) {
  // Size the control buffer with CMSG_SPACE (header plus padded payload), as
  // cmsg(3) prescribes for buffers; CMSG_LEN is only the cmsg_len value.
  const auto ancillaryDataSize = CMSG_SPACE(sizeof(uint32_t));
  auto ancillaryData = static_cast<char*>(alloca(ancillaryDataSize));

  // Build a msghdr with no payload. When sofFlags is non-zero, attach one
  // SO_TIMESTAMPING control message carrying sofFlags. Every call reuses (and
  // overwrites) the same ancillaryData buffer.
  auto getMsg = [ancillaryDataSize, ancillaryData](uint32_t sofFlags) {
    struct msghdr msg = {};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = nullptr;
    msg.msg_iovlen = 0;
    msg.msg_flags = 0;
    msg.msg_controllen = 0;
    msg.msg_control = nullptr;
    if (sofFlags) {
      // start from a zeroed buffer so no padding bytes are indeterminate
      memset(ancillaryData, 0, ancillaryDataSize);
      msg.msg_controllen = ancillaryDataSize;
      msg.msg_control = ancillaryData;
      struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
      CHECK_NOTNULL(cmsg);
      cmsg->cmsg_level = SOL_SOCKET;
      cmsg->cmsg_type = SO_TIMESTAMPING;
      cmsg->cmsg_len = CMSG_LEN(sizeof(uint32_t));
      memcpy(CMSG_DATA(cmsg), &sofFlags, sizeof(sofFlags));
    }
    return msg;
  };

  // SCHED
  {
    auto msg = getMsg(folly::netops::SOF_TIMESTAMPING_TX_SCHED);
    EXPECT_EQ(WriteFlags::TIMESTAMP_SCHED, getMsgAncillaryTsFlags(msg));
  }

  // TX
  {
    auto msg = getMsg(folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE);
    EXPECT_EQ(WriteFlags::TIMESTAMP_TX, getMsgAncillaryTsFlags(msg));
  }

  // ACK
  {
    auto msg = getMsg(folly::netops::SOF_TIMESTAMPING_TX_ACK);
    EXPECT_EQ(WriteFlags::TIMESTAMP_ACK, getMsgAncillaryTsFlags(msg));
  }

  // SCHED + TX + ACK
  {
    auto msg = getMsg(
        folly::netops::SOF_TIMESTAMPING_TX_SCHED |
        folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE |
        folly::netops::SOF_TIMESTAMPING_TX_ACK);
    EXPECT_EQ(
        WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX |
            WriteFlags::TIMESTAMP_ACK,
        getMsgAncillaryTsFlags(msg));
  }
}
3881
/**
 * An observer with ByteEvents enabled, attached before connect(), triggers
 * SO_TIMESTAMPING at connect and then receives WRITE/SCHED/TX/ACK events for
 * every timestamped write.
 */
TEST_F(AsyncSocketByteEventTest, ObserverAttachedBeforeConnect) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  clientConn.netOpsExpectTimestampingSetSockOpt();
  clientConn.connect();
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // One timestamped write + echo, then check the observer's total event count
  // and that each event type reports expectedOffset as its max offset.
  auto writeAndCheck = [&](size_t expectedEventCount, uint64_t expectedOffset) {
    clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
        dropWriteFromFlags(flags));
    clientConn.writeAndReflect(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();
    EXPECT_THAT(observer->byteEvents, SizeIs(expectedEventCount));
    for (const auto type :
         {ByteEventType::WRITE,
          ByteEventType::SCHED,
          ByteEventType::TX,
          ByteEventType::ACK}) {
      EXPECT_EQ(expectedOffset, observer->maxOffsetForByteEventReceived(type));
    }
  };

  writeAndCheck(4, 0);
  // write again to check offsets
  writeAndCheck(8, 1);
}
3915
/**
 * Connecting without observers issues no SO_TIMESTAMPING. Attaching a
 * ByteEvent-enabled observer afterwards enables it, and later timestamped
 * writes produce WRITE/SCHED/TX/ACK events.
 */
TEST_F(AsyncSocketByteEventTest, ObserverAttachedAfterConnect) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect();
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // One timestamped write + echo, then check the observer's total event count
  // and that each event type reports expectedOffset as its max offset.
  auto writeAndCheck = [&](size_t expectedEventCount, uint64_t expectedOffset) {
    clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
        dropWriteFromFlags(flags));
    clientConn.writeAndReflect(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();
    EXPECT_THAT(observer->byteEvents, SizeIs(expectedEventCount));
    for (const auto type :
         {ByteEventType::WRITE,
          ByteEventType::SCHED,
          ByteEventType::TX,
          ByteEventType::ACK}) {
      EXPECT_EQ(expectedOffset, observer->maxOffsetForByteEventReceived(type));
    }
  };

  writeAndCheck(4, 0);
  // write again to check offsets
  writeAndCheck(8, 1);
}
3952
/**
 * An observer with ByteEvents disabled, attached before connect(), must not
 * cause SO_TIMESTAMPING or any events. A second, ByteEvent-enabled observer
 * attached later turns them on for itself only.
 */
TEST_F(
    AsyncSocketByteEventTest, ObserverAttachedBeforeConnectByteEventsDisabled) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  auto observer = clientConn.attachObserver(false /* enableByteEvents */);
  clientConn.netOpsExpectNoTimestampingSetSockOpt();

  clientConn.connect(); // connect after observer attached
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
      WriteFlags::NONE); // events disabled
  clientConn.writeAndReflect(wbuf, flags);
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  clientConn.netOpsVerifyAndClearExpectations();

  // now enable ByteEvents with another observer, then write again
  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer2 = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled); // observer 1 doesn't want
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled); // should be set
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  // observer2 must not have been handed an "unavailable" exception either
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());
  EXPECT_NE(WriteFlags::NONE, flags);
  EXPECT_NE(WriteFlags::NONE, dropWriteFromFlags(flags));
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // expect no ByteEvents for first observer, four for the second
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  EXPECT_THAT(observer2->byteEvents, SizeIs(4));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
}
3998
/**
 * An observer with ByteEvents disabled, attached after connect(), must not
 * cause SO_TIMESTAMPING or any events. A second, ByteEvent-enabled observer
 * attached later turns them on for itself only.
 */
TEST_F(
    AsyncSocketByteEventTest, ObserverAttachedAfterConnectByteEventsDisabled) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();

  clientConn.connect(); // connect before observer attached

  auto observer = clientConn.attachObserver(false /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
      WriteFlags::NONE); // events disabled
  clientConn.writeAndReflect(wbuf, flags);
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  clientConn.netOpsVerifyAndClearExpectations();

  // now enable ByteEvents with another observer, then write again
  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer2 = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled); // observer 1 doesn't want
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled); // should be set
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  // observer2 must not have been handed an "unavailable" exception either
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());
  EXPECT_NE(WriteFlags::NONE, flags);
  EXPECT_NE(WriteFlags::NONE, dropWriteFromFlags(flags));
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // expect no ByteEvents for first observer, four for the second
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  EXPECT_THAT(observer2->byteEvents, SizeIs(4));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
}
4045
/**
 * A write made before any ByteEvent-enabled observer exists carries no
 * timestamping flags. After attaching one, the next write's events report
 * offset 1, because the first byte was already written.
 */
TEST_F(AsyncSocketByteEventTest, ObserverAttachedAfterWrite) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  // connect with no observer attached: SO_TIMESTAMPING must not be set
  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect();
  clientConn.netOpsVerifyAndClearExpectations();

  // first write: ByteEvents not enabled, so no ancillary timestamp flags
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(WriteFlags::NONE);
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // attaching a ByteEvent-enabled observer now sets SO_TIMESTAMPING
  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // second write: timestamped, events for the byte at offset 1
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  EXPECT_THAT(observer->byteEvents, SizeIs(4));
  for (const auto type :
       {ByteEventType::WRITE,
        ByteEventType::SCHED,
        ByteEventType::TX,
        ByteEventType::ACK}) {
    EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(type));
  }
}
4078
/**
 * Attaching a ByteEvent-enabled observer to a socket that was already closed
 * reports neither "enabled" nor "unavailable".
 */
TEST_F(AsyncSocketByteEventTest, ObserverAttachedAfterClose) {
  auto clientConn = getClientConn();
  clientConn.connect();

  auto rawSocket = clientConn.getRawSocket();
  rawSocket->close();
  EXPECT_TRUE(rawSocket->isClosedBySelf());

  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
}
4090
/**
 * Two ByteEvent-enabled observers, one attached before connect() and one
 * after, both receive the full set of events for a 50-byte write.
 */
TEST_F(AsyncSocketByteEventTest, MultipleObserverAttached) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  // attach observer 1 before connect
  auto clientConn = getClientConn();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  clientConn.netOpsExpectTimestampingSetSockOpt();
  clientConn.connect();
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // attach observer 2 after connect
  auto observer2 = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  // check observer2 itself (not observer 1) for an "unavailable" exception
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());

  // write
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // check observer1: last byte of the 50-byte write is at offset 49
  EXPECT_THAT(observer->byteEvents, SizeIs(4));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));

  // check observer2
  EXPECT_THAT(observer2->byteEvents, SizeIs(4));
  EXPECT_EQ(
      49U, observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(
      49U, observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(49U, observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(49U, observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
}
4133
4134 /**
4135 * Test when kernel offset (uint32_t) wraps around.
4136 */
TEST_F(AsyncSocketByteEventTest,KernelOffsetWrap)4137 TEST_F(AsyncSocketByteEventTest, KernelOffsetWrap) {
4138 auto clientConn = getClientConn();
4139 clientConn.connect();
4140 clientConn.netOpsExpectTimestampingSetSockOpt();
4141 auto observer = clientConn.attachObserver(true /* enableByteEvents */);
4142 EXPECT_EQ(1, observer->byteEventsEnabledCalled);
4143 EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
4144 EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4145 clientConn.netOpsVerifyAndClearExpectations();
4146
4147 const uint64_t wbufSize = 3000000;
4148 const std::vector<uint8_t> wbuf(wbufSize, 'a');
4149
4150 // part 1: write close to the wrap point with no ByteEvents to speed things up
4151 const uint64_t bytesToWritePt1 =
4152 static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) -
4153 (wbufSize * 5);
4154 while (clientConn.getRawSocket()->getRawBytesWritten() < bytesToWritePt1) {
4155 clientConn.write(wbuf, WriteFlags::NONE); // no reflect needed
4156 }
4157
4158 // part 2: write over the wrap point with ByteEvents
4159 const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4160 WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
4161 const uint64_t bytesToWritePt2 =
4162 static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) +
4163 (wbufSize * 5);
4164 while (clientConn.getRawSocket()->getRawBytesWritten() < bytesToWritePt2) {
4165 clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
4166 dropWriteFromFlags(flags));
4167 clientConn.writeAndReflect(wbuf, flags);
4168 clientConn.netOpsVerifyAndClearExpectations();
4169 const uint64_t expectedOffset =
4170 clientConn.getRawSocket()->getRawBytesWritten() - 1;
4171 EXPECT_EQ(
4172 expectedOffset,
4173 observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4174 EXPECT_EQ(
4175 expectedOffset,
4176 observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4177 EXPECT_EQ(
4178 expectedOffset,
4179 observer->maxOffsetForByteEventReceived(ByteEventType::TX));
4180 EXPECT_EQ(
4181 expectedOffset,
4182 observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
4183 }
4184
4185 // part 3: one more write outside of a loop with extra checks
4186 clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
4187 clientConn.writeAndReflect(wbuf, flags);
4188 clientConn.netOpsVerifyAndClearExpectations();
4189 const auto expectedOffset =
4190 clientConn.getRawSocket()->getRawBytesWritten() - 1;
4191 EXPECT_LT(std::numeric_limits<uint32_t>::max(), expectedOffset);
4192 EXPECT_EQ(
4193 expectedOffset,
4194 observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4195 EXPECT_EQ(
4196 expectedOffset,
4197 observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4198 EXPECT_EQ(
4199 expectedOffset,
4200 observer->maxOffsetForByteEventReceived(ByteEventType::TX));
4201 EXPECT_EQ(
4202 expectedOffset,
4203 observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
4204 }
4205
4206 /**
4207 * Ensure that ErrMessageCallback still works when ByteEvents enabled.
4208 */
TEST_F(AsyncSocketByteEventTest,ErrMessageCallbackStillTriggered)4209 TEST_F(AsyncSocketByteEventTest, ErrMessageCallbackStillTriggered) {
4210 auto clientConn = getClientConn();
4211 clientConn.connect();
4212 clientConn.netOpsExpectTimestampingSetSockOpt();
4213 auto observer = clientConn.attachObserver(true /* enableByteEvents */);
4214 EXPECT_EQ(1, observer->byteEventsEnabledCalled);
4215 EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
4216 EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4217 clientConn.netOpsVerifyAndClearExpectations();
4218
4219 TestErrMessageCallback errMsgCB;
4220 clientConn.getRawSocket()->setErrMessageCB(&errMsgCB);
4221
4222 const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4223 WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
4224
4225 std::vector<uint8_t> wbuf(1, 'a');
4226 EXPECT_NE(WriteFlags::NONE, flags);
4227 EXPECT_NE(WriteFlags::NONE, dropWriteFromFlags(flags));
4228 clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
4229 clientConn.writeAndReflect(wbuf, flags);
4230 clientConn.netOpsVerifyAndClearExpectations();
4231
4232 // observer should get events
4233 EXPECT_THAT(observer->byteEvents, SizeIs(4));
4234 EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4235 EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4236 EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
4237 EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
4238
4239 // err message callbach should get events, too
4240 EXPECT_EQ(3, errMsgCB.gotByteSeq_);
4241 EXPECT_EQ(3, errMsgCB.gotTimestamp_);
4242
4243 // write again, more events for both
4244 clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
4245 clientConn.writeAndReflect(wbuf, flags);
4246 clientConn.netOpsVerifyAndClearExpectations();
4247 EXPECT_THAT(observer->byteEvents, SizeIs(8));
4248 EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4249 EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4250 EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
4251 EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
4252 EXPECT_EQ(6, errMsgCB.gotByteSeq_);
4253 EXPECT_EQ(6, errMsgCB.gotTimestamp_);
4254 }
4255
4256 /**
4257 * Ensure that ByteEvents disabled for unix sockets (not supported).
4258 */
TEST_F(AsyncSocketByteEventTest,FailUnixSocket)4259 TEST_F(AsyncSocketByteEventTest, FailUnixSocket) {
4260 std::shared_ptr<NiceMock<TestObserver>> observer;
4261 auto netOpsDispatcher = std::make_shared<MockDispatcher>();
4262
4263 NetworkSocket fd[2];
4264 EXPECT_EQ(netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fd), 0);
4265 ASSERT_NE(fd[0], NetworkSocket());
4266 ASSERT_NE(fd[1], NetworkSocket());
4267 SCOPE_EXIT { netops::close(fd[1]); };
4268
4269 EXPECT_EQ(netops::set_socket_non_blocking(fd[0]), 0);
4270 EXPECT_EQ(netops::set_socket_non_blocking(fd[1]), 0);
4271
4272 auto clientSocketRaw = AsyncSocket::newSocket(nullptr, fd[0]);
4273 auto clientBlockingSocket = BlockingSocket(std::move(clientSocketRaw));
4274 clientBlockingSocket.getSocket()->setOverrideNetOpsDispatcher(
4275 netOpsDispatcher);
4276
4277 // make sure no SO_TIMESTAMPING setsockopt on observer attach
4278 EXPECT_CALL(*netOpsDispatcher, setsockopt(_, _, _, _, _)).Times(AnyNumber());
4279 EXPECT_CALL(
4280 *netOpsDispatcher, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
4281 .Times(0); // no calls
4282 observer = attachObserver(
4283 clientBlockingSocket.getSocket(), true /* enableByteEvents */);
4284 EXPECT_EQ(0, observer->byteEventsEnabledCalled);
4285 EXPECT_EQ(1, observer->byteEventsUnavailableCalled);
4286 EXPECT_TRUE(observer->byteEventsUnavailableCalledEx.has_value());
4287 Mock::VerifyAndClearExpectations(netOpsDispatcher.get());
4288
4289 // do a write, we should see it has no timestamp flags
4290 const std::vector<uint8_t> wbuf(1, 'a');
4291 EXPECT_CALL(*netOpsDispatcher, sendmsg(_, _, _))
4292 .WillOnce(WithArgs<1>(Invoke([](const msghdr* message) {
4293 EXPECT_EQ(WriteFlags::NONE, getMsgAncillaryTsFlags(*message));
4294 return 1;
4295 })));
4296 clientBlockingSocket.write(
4297 wbuf.data(),
4298 wbuf.size(),
4299 WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4300 WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK);
4301 Mock::VerifyAndClearExpectations(netOpsDispatcher.get());
4302 }
4303
4304 /**
4305 * If socket timestamps already enabled, do not enable ByteEvents.
4306 */
TEST_F(AsyncSocketByteEventTest, FailTimestampsAlreadyEnabled) {
  auto clientConn = getClientConn();
  clientConn.connect();

  // enable timestamps via setsockopt before any observer is attached
  const uint32_t flags = folly::netops::SOF_TIMESTAMPING_OPT_ID |
      folly::netops::SOF_TIMESTAMPING_OPT_TSONLY |
      folly::netops::SOF_TIMESTAMPING_SOFTWARE |
      folly::netops::SOF_TIMESTAMPING_RAW_HARDWARE |
      folly::netops::SOF_TIMESTAMPING_OPT_TX_SWHW;
  const auto ret = clientConn.getRawSocket()->setSockOpt(
      SOL_SOCKET, SO_TIMESTAMPING, &flags);
  // everything below depends on timestamps already being enabled; if this
  // precondition fails, stop instead of reporting misleading failures
  ASSERT_EQ(0, ret);

  // attaching an observer that wants ByteEvents must not issue another
  // SO_TIMESTAMPING setsockopt, and must report ByteEvents as unavailable
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(1, observer->byteEventsUnavailableCalled); // fail
  EXPECT_TRUE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // a write requesting timestamps goes out with no timestamp ancillary flags,
  // and the observer receives no ByteEvents
  const std::vector<uint8_t> wbuf(1, 'a');
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(WriteFlags::NONE);
  clientConn.writeAndReflect(
      wbuf,
      WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
          WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK);
  clientConn.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, IsEmpty());
}
4337
4338 /**
4339 * Verify that ByteEvent information is properly copied during socket moves.
4340 */
4341
TEST_F(AsyncSocketByteEventTest, MoveByteEventsEnabled) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  // Expects `obs` to report `expectedOffset` as the highest offset received
  // for each of the WRITE, SCHED, TX and ACK ByteEvent types.
  const auto expectMaxOffsetForAllTypes = [](const auto& obs,
                                             const auto expectedOffset) {
    for (const auto byteEventType :
         {ByteEventType::WRITE,
          ByteEventType::SCHED,
          ByteEventType::TX,
          ByteEventType::ACK}) {
      EXPECT_EQ(
          expectedOffset, obs->maxOffsetForByteEventReceived(byteEventType))
          << "byteEventType=" << static_cast<int>(byteEventType);
    }
  };

  auto clientConn = getClientConn();
  clientConn.connect();

  // observer with ByteEvents enabled
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // move the socket immediately and add an observer with ByteEvents enabled
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
      clientConn.getAcceptedSocket());
  auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());

  // write following move; nothing was written before the move, so the last
  // byte of this 50-byte write is at offset 49
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAndReflect(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
  expectMaxOffsetForAllTypes(observer2, 49U);

  // write again; offsets continue from the previous write
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAndReflect(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
  expectMaxOffsetForAllTypes(observer2, 99U);
}
4410
TEST_F(AsyncSocketByteEventTest, WriteThenMoveByteEventsEnabled) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  // Expects `obs` to report `expectedOffset` as the highest offset received
  // for each of the WRITE, SCHED, TX and ACK ByteEvent types.
  const auto expectMaxOffsetForAllTypes = [](const auto& obs,
                                             const auto expectedOffset) {
    for (const auto byteEventType :
         {ByteEventType::WRITE,
          ByteEventType::SCHED,
          ByteEventType::TX,
          ByteEventType::ACK}) {
      EXPECT_EQ(
          expectedOffset, obs->maxOffsetForByteEventReceived(byteEventType))
          << "byteEventType=" << static_cast<int>(byteEventType);
    }
  };

  auto clientConn = getClientConn();
  clientConn.connect();

  // observer with ByteEvents enabled
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // write before the move
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
  expectMaxOffsetForAllTypes(observer, 49U);

  // now move the socket and add an observer with ByteEvents enabled
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
      clientConn.getAcceptedSocket());
  auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());

  // write following move; offsets must account for the 50 bytes written
  // before the move
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAndReflect(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
  expectMaxOffsetForAllTypes(observer2, 99U);

  // write again
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAndReflect(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
  expectMaxOffsetForAllTypes(observer2, 149U);
}
4501
TEST_F(AsyncSocketByteEventTest, MoveThenEnableByteEvents) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  // Expects `obs` to report `expectedOffset` as the highest offset received
  // for each of the WRITE, SCHED, TX and ACK ByteEvent types.
  const auto expectMaxOffsetForAllTypes = [](const auto& obs,
                                             const auto expectedOffset) {
    for (const auto byteEventType :
         {ByteEventType::WRITE,
          ByteEventType::SCHED,
          ByteEventType::TX,
          ByteEventType::ACK}) {
      EXPECT_EQ(
          expectedOffset, obs->maxOffsetForByteEventReceived(byteEventType))
          << "byteEventType=" << static_cast<int>(byteEventType);
    }
  };

  auto clientConn = getClientConn();
  clientConn.connect();

  // observer with ByteEvents disabled
  auto observer = clientConn.attachObserver(false /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // move the socket immediately and add an observer with ByteEvents enabled
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
      clientConn.getAcceptedSocket());
  auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());

  // write following move; nothing was written before the move, so the last
  // byte of this 50-byte write is at offset 49
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAndReflect(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
  expectMaxOffsetForAllTypes(observer2, 49U);

  // write again; offsets continue from the previous write
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAndReflect(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
  expectMaxOffsetForAllTypes(observer2, 99U);
}
4570
TEST_F(AsyncSocketByteEventTest, WriteThenMoveThenEnableByteEvents) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  // Expects `obs` to report `expectedOffset` as the highest offset received
  // for each of the WRITE, SCHED, TX and ACK ByteEvent types.
  const auto expectMaxOffsetForAllTypes = [](const auto& obs,
                                             const auto expectedOffset) {
    for (const auto byteEventType :
         {ByteEventType::WRITE,
          ByteEventType::SCHED,
          ByteEventType::TX,
          ByteEventType::ACK}) {
      EXPECT_EQ(
          expectedOffset, obs->maxOffsetForByteEventReceived(byteEventType))
          << "byteEventType=" << static_cast<int>(byteEventType);
    }
  };

  auto clientConn = getClientConn();
  clientConn.connect();

  // observer with ByteEvents disabled
  auto observer = clientConn.attachObserver(false /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // write, ByteEvents disabled
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
      WriteFlags::NONE); // events disabled
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // now move the socket and add an observer with ByteEvents enabled
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
      clientConn.getAcceptedSocket());
  auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());

  // write following move; offsets must still account for the 50 bytes
  // written before the move (while ByteEvents were disabled)
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAndReflect(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
  expectMaxOffsetForAllTypes(observer2, 99U);

  // write again
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAndReflect(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
  expectMaxOffsetForAllTypes(observer2, 149U);
}
4645
TEST_F(AsyncSocketByteEventTest, NoObserverMoveThenEnableByteEvents) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  // Checks that `obs` reports `expectedOffset` as the highest offset received
  // for every ByteEvent type (WRITE, SCHED, TX, ACK).
  auto checkMaxOffsets = [](const auto& obs, const auto expectedOffset) {
    for (const auto type :
         {ByteEventType::WRITE,
          ByteEventType::SCHED,
          ByteEventType::TX,
          ByteEventType::ACK}) {
      EXPECT_EQ(expectedOffset, obs->maxOffsetForByteEventReceived(type));
    }
  };

  // original connection; no observer is ever attached to it
  auto clientConn = getClientConn();
  clientConn.connect();

  // move the socket immediately and add an observer with ByteEvents enabled
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
      clientConn.getAcceptedSocket());
  auto observer = clientConn2.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // Writes wbuf on the moved socket, expecting the sendmsg ancillary flags to
  // be `flags` with TIMESTAMP_WRITE dropped.
  auto writeOnMovedSocket = [&]() {
    clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
        dropWriteFromFlags(flags));
    clientConn2.writeAndReflect(wbuf, flags);
    clientConn2.netOpsVerifyAndClearExpectations();
  };

  // first write after the move: last byte lands at offset 49
  writeOnMovedSocket();
  EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
  checkMaxOffsets(observer, 49U);

  // second write: offsets continue, last byte lands at offset 99
  writeOnMovedSocket();
  EXPECT_THAT(observer->byteEvents, SizeIs(Ge(8)));
  checkMaxOffsets(observer, 99U);
}
4710
4711 /**
4712 * Inspect ByteEvent fields, including xTimestampRequested in WRITE events.
4713 *
4714 * See CheckByteEventDetailsRawBytesWrittenAndTriedToWrite and
4715 * AsyncSocketByteEventDetailsTest::CheckByteEventDetails as well.
4716 */
TEST_F(AsyncSocketByteEventTest, CheckByteEventDetails) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  // A ByteEvent timestamp must not lie in the future and must be less than
  // 60 seconds old.
  auto expectRecentTs = [](const auto& ts) {
    EXPECT_GE(std::chrono::steady_clock::now(), ts);
    EXPECT_LT(
        std::chrono::steady_clock::now() - std::chrono::seconds(60), ts);
  };

  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // flags expected in sendmsg's ancillary data (TIMESTAMP_WRITE dropped);
  // they must still request at least one timestamp
  const auto sendmsgTsFlags = dropWriteFromFlags(flags);
  EXPECT_NE(WriteFlags::NONE, sendmsgTsFlags);
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(sendmsgTsFlags);
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // four events expected: WRITE, SCHED, TX, ACK (each inspected below)
  EXPECT_THAT(observer->byteEvents, SizeIs(Eq(4)));

  // the write is a single byte, so every event refers to offset 0
  const auto expectedOffset = wbuf.size() - 1;

  // WRITE: carries the write flags, but no software/hardware timestamps
  {
    auto maybeWriteEvent = observer->getByteEventReceivedWithOffset(
        expectedOffset, ByteEventType::WRITE);
    ASSERT_TRUE(maybeWriteEvent.has_value());
    auto& writeEvent = *maybeWriteEvent;

    EXPECT_EQ(ByteEventType::WRITE, writeEvent.type);
    EXPECT_EQ(expectedOffset, writeEvent.offset);
    expectRecentTs(writeEvent.ts);

    EXPECT_EQ(flags, writeEvent.maybeWriteFlags);
    EXPECT_TRUE(writeEvent.schedTimestampRequestedOnWrite());
    EXPECT_TRUE(writeEvent.txTimestampRequestedOnWrite());
    EXPECT_TRUE(writeEvent.ackTimestampRequestedOnWrite());

    EXPECT_FALSE(writeEvent.maybeSoftwareTs.has_value());
    EXPECT_FALSE(writeEvent.maybeHardwareTs.has_value());

    // maybeRawBytesWritten and maybeRawBytesTriedToWrite are tested in
    // CheckByteEventDetailsRawBytesWrittenAndTriedToWrite
  }

  // SCHED, TX, ACK: no write flags (the *RequestedOnWrite accessors must
  // die), but a software timestamp is present
  for (const auto& kernelEventType :
       {ByteEventType::SCHED, ByteEventType::TX, ByteEventType::ACK}) {
    auto maybeKernelEvent = observer->getByteEventReceivedWithOffset(
        expectedOffset, kernelEventType);
    ASSERT_TRUE(maybeKernelEvent.has_value());
    auto& kernelEvent = *maybeKernelEvent;

    EXPECT_EQ(kernelEventType, kernelEvent.type);
    EXPECT_EQ(expectedOffset, kernelEvent.offset);
    expectRecentTs(kernelEvent.ts);

    EXPECT_FALSE(kernelEvent.maybeWriteFlags.has_value());
    EXPECT_DEATH(kernelEvent.schedTimestampRequestedOnWrite(), ".*");
    EXPECT_DEATH(kernelEvent.txTimestampRequestedOnWrite(), ".*");
    EXPECT_DEATH(kernelEvent.ackTimestampRequestedOnWrite(), ".*");

    EXPECT_TRUE(kernelEvent.maybeSoftwareTs.has_value());
    EXPECT_FALSE(kernelEvent.maybeHardwareTs.has_value());
  }
}
4786
4787 /**
4788 * Inspect ByteEvent fields maybeRawBytesWritten and maybeRawBytesTriedToWrite.
4789 */
TEST_F(
    AsyncSocketByteEventTest,
    CheckByteEventDetailsRawBytesWrittenAndTriedToWrite) {
  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // One expected sendmsg call: the total iovec length AsyncSocket should pass,
  // the value the mocked sendmsg returns, and the WRITE ByteEvent (offset and
  // flags) expected for that call, if any.
  struct ExpectedSendmsgInvocation {
    size_t expectedTotalIovLen{0};
    ssize_t returnVal{0}; // number of bytes written or error val
    folly::Optional<size_t> maybeWriteEventExpectedOffset{};
    folly::Optional<WriteFlags> maybeWriteEventExpectedFlags{};
  };

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;

  // Two 100-byte writes. For each one, sendmsg incrementally accepts the
  // bytes so we can test the values of maybeRawBytesWritten and
  // maybeRawBytesTriedToWrite. Each entry is:
  //   {expectedTotalIovLen, returnVal,
  //    maybeWriteEventExpectedOffset, maybeWriteEventExpectedFlags}
  const std::vector<std::vector<ExpectedSendmsgInvocation>> writes{
      // first write (no splits triggered by observer)
      // bytes written per sendmsg call: 20, 10, 50, -1 (EAGAIN), 11, 9
      {{100, 20, 19, flags},
       {80, 10, 29, flags},
       {70, 50, 79, flags},
       {20, -1, folly::none, flags},
       {20, 11, 90, flags},
       {9, 9, 99, flags}},
      // second write
      // bytes written per sendmsg call: 20, 30, 50
      {{100, 20, 119, flags}, {80, 30, 149, flags}, {50, 50, 199, flags}}};

  for (const auto& expectedSendmsgInvocations : writes) {
    // sendmsg will be called, we return # of bytes written (or -1 + EAGAIN)
    {
      InSequence s;
      for (const auto& expectedInvocation : expectedSendmsgInvocations) {
        EXPECT_CALL(
            *(clientConn.getNetOpsDispatcher()),
            sendmsg(
                _,
                Pointee(SendmsgMsghdrHasTotalIovLen(
                    expectedInvocation.expectedTotalIovLen)),
                _))
            .WillOnce(::testing::InvokeWithoutArgs([expectedInvocation]() {
              if (expectedInvocation.returnVal < 0) {
                errno = EAGAIN; // returning error, set EAGAIN
              }
              return expectedInvocation.returnVal;
            }));
      }
    }

    // write
    // writes will be intercepted, so we don't need to read at other end
    WriteCallback wcb;
    clientConn.getRawSocket()->write(
        &wcb,
        kOneHundredCharacterVec.data(),
        kOneHundredCharacterVec.size(),
        flags);
    while (STATE_WAITING == wcb.state) {
      clientConn.getRawSocket()->getEventBase()->loopOnce();
    }
    ASSERT_EQ(STATE_SUCCEEDED, wcb.state);

    // check write events
    for (const auto& expectedInvocation : expectedSendmsgInvocations) {
      if (expectedInvocation.returnVal < 0) {
        // should be no WriteEvent since the return value was an error
        continue;
      }

      ASSERT_TRUE(expectedInvocation.maybeWriteEventExpectedOffset.has_value());
      const auto& expectedOffset =
          *expectedInvocation.maybeWriteEventExpectedOffset;

      auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
          expectedOffset, ByteEventType::WRITE);
      ASSERT_TRUE(maybeByteEvent.has_value());
      auto& byteEvent = maybeByteEvent.value();

      EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
      EXPECT_EQ(expectedOffset, byteEvent.offset);
      EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
      EXPECT_LT(
          std::chrono::steady_clock::now() - std::chrono::seconds(60),
          byteEvent.ts);

      EXPECT_EQ(
          expectedInvocation.maybeWriteEventExpectedFlags,
          byteEvent.maybeWriteFlags);
      EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());

      EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
      EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());

      // what we really want to test
      EXPECT_EQ(
          folly::to_unsigned(expectedInvocation.returnVal),
          byteEvent.maybeRawBytesWritten);
      EXPECT_EQ(
          expectedInvocation.expectedTotalIovLen,
          byteEvent.maybeRawBytesTriedToWrite);
    }

    // everything for this write should have occurred by now
    clientConn.netOpsVerifyAndClearExpectations();
  }
}
4986
TEST_F(AsyncSocketByteEventTest, SplitIoVecArraySingleIoVec) {
  // Source iovec array: the first entry covers the whole 100-character
  // string. The array is only read by the test, never modified.
  const char* buf = kOneHundredCharacterString.c_str();
  auto getSrcIov = [&buf]() {
    // NOTE(review): two entries are allocated but only the first is filled;
    // the second stays value-initialized (null base, zero length). Confirm
    // this is intended for a "single iovec" test.
    std::vector<struct iovec> srcIov(2);
    srcIov[0].iov_base = const_cast<void*>(static_cast<const void*>(buf));
    srcIov[0].iov_len = kOneHundredCharacterString.size();
    return srcIov;
  };

  std::vector<struct iovec> srcIov = getSrcIov();
  const auto data = srcIov.data();

  // Runs splitIovecArray for [startOffset, endOffset] into a fresh
  // four-entry destination array and returns the array truncated to the
  // number of entries the split produced.
  auto split = [&](size_t startOffset, size_t endOffset) {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        startOffset,
        endOffset,
        data,
        srcIov.size(),
        dstIov.data(),
        dstIovCount);
    dstIov.resize(dstIovCount);
    return dstIov;
  };

  // split 0 -> 0 (first byte)
  {
    const auto dstIov = split(0, 0);
    ASSERT_EQ(1, dstIov.size());
    EXPECT_EQ(1, dstIov[0].iov_len);
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(buf, dstIov[0].iov_base);
  }

  // split 0 -> 49 (50th byte)
  {
    const auto dstIov = split(0, 49);
    ASSERT_EQ(1, dstIov.size());
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(50, dstIov[0].iov_len);
  }

  // split 0 -> 98 (penultimate byte)
  {
    const auto dstIov = split(0, 98);
    ASSERT_EQ(1, dstIov.size());
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(99, dstIov[0].iov_len);
  }

  // split 0 -> 99 (pointless split; the whole source iovec comes back)
  {
    const auto dstIov = split(0, 99);
    ASSERT_EQ(1, dstIov.size());
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
  }
}
5049
TEST_F(AsyncSocketByteEventTest, SplitIoVecArrayMultiIoVecInvalid) {
  // Source array with four entries: the first two cover the 100-character
  // string in 50-byte halves, and the last two stay value-initialized.
  const char* buf = kOneHundredCharacterString.c_str();
  std::vector<struct iovec> srcIov(4);
  for (size_t i = 0; i < 2; ++i) {
    srcIov[i].iov_base =
        const_cast<void*>(static_cast<const void*>(buf + 50 * i));
    srcIov[i].iov_len = 50;
  }

  // A destination array with fewer entries than the source array violates
  // splitIovecArray's contract, so the call is expected to die.
  std::vector<struct iovec> dstIov(1);
  size_t dstIovCount = dstIov.size();
  EXPECT_LT(dstIovCount, srcIov.size());
  EXPECT_DEATH(
      AsyncSocket::splitIovecArray(
          0, 0, srcIov.data(), srcIov.size(), dstIov.data(), dstIovCount),
      ".*");
}
5074
TEST_F(AsyncSocketByteEventTest,SplitIoVecArrayMultiIoVec)5075 TEST_F(AsyncSocketByteEventTest, SplitIoVecArrayMultiIoVec) {
5076 // get srciov from lambda to enable us to keep it const during test
5077 const char* buf = kOneHundredCharacterString.c_str();
5078 auto getSrcIov = [&buf]() {
5079 std::vector<struct iovec> srcIov(4);
5080 srcIov[0].iov_base = const_cast<void*>(static_cast<const void*>(buf));
5081 srcIov[0].iov_len = 25;
5082 srcIov[1].iov_base = const_cast<void*>(static_cast<const void*>(buf + 25));
5083 srcIov[1].iov_len = 25;
5084 srcIov[2].iov_base = const_cast<void*>(static_cast<const void*>(buf + 50));
5085 srcIov[2].iov_len = 25;
5086 srcIov[3].iov_base = const_cast<void*>(static_cast<const void*>(buf + 75));
5087 srcIov[3].iov_len = 25;
5088 return srcIov;
5089 };
5090
5091 std::vector<struct iovec> srcIov = getSrcIov();
5092 const auto data = srcIov.data();
5093
5094 // split 0 -> 0 (first byte)
5095 {
5096 std::vector<struct iovec> dstIov(4);
5097 size_t dstIovCount = dstIov.size();
5098 AsyncSocket::splitIovecArray(
5099 0, 0, data, srcIov.size(), dstIov.data(), dstIovCount);
5100
5101 ASSERT_EQ(1, dstIovCount);
5102 EXPECT_EQ(1, dstIov[0].iov_len);
5103 EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5104 EXPECT_EQ(buf, dstIov[0].iov_base);
5105 }
5106
5107 // split 0 -> 98 (penultimate byte)
5108 {
5109 std::vector<struct iovec> dstIov(4);
5110 size_t dstIovCount = dstIov.size();
5111 AsyncSocket::splitIovecArray(
5112 0, 98, data, srcIov.size(), dstIov.data(), dstIovCount);
5113
5114 ASSERT_EQ(4, dstIovCount);
5115 EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5116 EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
5117 EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
5118 EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
5119 EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
5120 EXPECT_EQ(srcIov[2].iov_len, dstIov[2].iov_len);
5121
5122 // last iovec is different
5123 EXPECT_EQ(24, dstIov[3].iov_len);
5124 EXPECT_EQ(srcIov[3].iov_base, dstIov[3].iov_base);
5125 }
5126
5127 // split 0 -> 99 (pointless split)
5128 {
5129 std::vector<struct iovec> dstIov(4);
5130 size_t dstIovCount = dstIov.size();
5131 AsyncSocket::splitIovecArray(
5132 0, 99, data, srcIov.size(), dstIov.data(), dstIovCount);
5133
5134 ASSERT_EQ(4, dstIovCount);
5135 EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5136 EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
5137 EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
5138 EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
5139 EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
5140 EXPECT_EQ(srcIov[2].iov_len, dstIov[2].iov_len);
5141 EXPECT_EQ(srcIov[3].iov_base, dstIov[3].iov_base);
5142 EXPECT_EQ(srcIov[3].iov_len, dstIov[3].iov_len);
5143 }
5144
5145 //
5146 // test when endOffset is near a iovec boundary
5147 //
5148
5149 // split 0 -> 49 (50th byte)
5150 {
5151 std::vector<struct iovec> dstIov(4);
5152 size_t dstIovCount = dstIov.size();
5153 AsyncSocket::splitIovecArray(
5154 0, 49, data, srcIov.size(), dstIov.data(), dstIovCount);
5155
5156 ASSERT_EQ(2, dstIovCount);
5157 EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5158 EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
5159 EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
5160 EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
5161 }
5162
5163 // split 0 -> 50 (51st byte)
5164 {
5165 std::vector<struct iovec> dstIov(4);
5166 size_t dstIovCount = dstIov.size();
5167 AsyncSocket::splitIovecArray(
5168 0, 50, data, srcIov.size(), dstIov.data(), dstIovCount);
5169
5170 ASSERT_EQ(3, dstIovCount);
5171 EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5172 EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
5173 EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
5174 EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
5175
5176 // last iovec is one byte
5177 EXPECT_EQ(1, dstIov[2].iov_len);
5178 EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
5179 }
5180
5181 // split 0 -> 51 (52nd byte)
5182 {
5183 std::vector<struct iovec> dstIov(4);
5184 size_t dstIovCount = dstIov.size();
5185 AsyncSocket::splitIovecArray(
5186 0, 51, data, srcIov.size(), dstIov.data(), dstIovCount);
5187
5188 ASSERT_EQ(3, dstIovCount);
5189 EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5190 EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
5191 EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
5192 EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
5193
5194 // last iovec is two bytes
5195 EXPECT_EQ(2, dstIov[2].iov_len);
5196 EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
5197 }
5198
5199 //
5200 // test when startOffset is near a iovec boundary
5201 //
5202
5203 // split 49 -> 99
5204 {
5205 std::vector<struct iovec> dstIov(4);
5206 size_t dstIovCount = dstIov.size();
5207 AsyncSocket::splitIovecArray(
5208 49, 99, data, srcIov.size(), dstIov.data(), dstIovCount);
5209
5210 ASSERT_EQ(3, dstIovCount);
5211
5212 // first dst iovec is one byte, starts 24 bytes in to the second src iovec
5213 EXPECT_EQ(1, dstIov[0].iov_len);
5214 EXPECT_EQ(
5215 dstIov[0].iov_base,
5216 const_cast<void*>(static_cast<const void*>(
5217 reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));
5218
5219 // second dst iovec is third src iovec
5220 // third dst iovec is fourth src iovec
5221 EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
5222 EXPECT_EQ(dstIov[1].iov_len, srcIov[2].iov_len);
5223 EXPECT_EQ(dstIov[2].iov_base, srcIov[3].iov_base);
5224 EXPECT_EQ(dstIov[2].iov_len, srcIov[3].iov_len);
5225 }
5226
5227 // split 50 -> 99
5228 {
5229 std::vector<struct iovec> dstIov(4);
5230 size_t dstIovCount = dstIov.size();
5231 AsyncSocket::splitIovecArray(
5232 50, 99, data, srcIov.size(), dstIov.data(), dstIovCount);
5233
5234 ASSERT_EQ(2, dstIovCount);
5235
5236 // first dst iovec is third src iovec
5237 // second dst iovec is fourth src iovec
5238 EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
5239 EXPECT_EQ(dstIov[0].iov_len, srcIov[2].iov_len);
5240 EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
5241 EXPECT_EQ(dstIov[1].iov_len, srcIov[3].iov_len);
5242 }
5243
5244 // split 51 -> 99
5245 {
5246 std::vector<struct iovec> dstIov(4);
5247 size_t dstIovCount = dstIov.size();
5248 AsyncSocket::splitIovecArray(
5249 51, 99, data, srcIov.size(), dstIov.data(), dstIovCount);
5250
5251 ASSERT_EQ(2, dstIovCount);
5252
5253 // first dst iovec is 24 bytes, starts 1 byte in to the third src iovec
5254 EXPECT_EQ(24, dstIov[0].iov_len);
5255 EXPECT_EQ(
5256 dstIov[0].iov_base,
5257 const_cast<void*>(static_cast<const void*>(
5258 reinterpret_cast<uint8_t*>(srcIov[2].iov_base) + 1)));
5259
5260 // second dst iovec is fourth src iovec
5261 EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
5262 EXPECT_EQ(dstIov[1].iov_len, srcIov[3].iov_len);
5263 }
5264
5265 //
5266 // test when startOffset and endOffset are near iovec boundaries
5267 //
5268
5269 // split 49 -> 49
5270 {
5271 std::vector<struct iovec> dstIov(4);
5272 size_t dstIovCount = dstIov.size();
5273 AsyncSocket::splitIovecArray(
5274 49, 49, data, srcIov.size(), dstIov.data(), dstIovCount);
5275
5276 ASSERT_EQ(1, dstIovCount);
5277
5278 // first dst iovec is one byte, starts 24 bytes in to the second src iovec
5279 EXPECT_EQ(1, dstIov[0].iov_len);
5280 EXPECT_EQ(
5281 dstIov[0].iov_base,
5282 const_cast<void*>(static_cast<const void*>(
5283 reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));
5284 }
5285
5286 // split 49 -> 50
5287 {
5288 std::vector<struct iovec> dstIov(4);
5289 size_t dstIovCount = dstIov.size();
5290 AsyncSocket::splitIovecArray(
5291 49, 50, data, srcIov.size(), dstIov.data(), dstIovCount);
5292
5293 ASSERT_EQ(2, dstIovCount);
5294
5295 // first dst iovec is one byte, starts 24 bytes in to the second src iovec
5296 EXPECT_EQ(1, dstIov[0].iov_len);
5297 EXPECT_EQ(
5298 dstIov[0].iov_base,
5299 const_cast<void*>(static_cast<const void*>(
5300 reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));
5301
5302 // second iovec is one byte, starts at the third src iovec
5303 EXPECT_EQ(1, dstIov[1].iov_len);
5304 EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
5305 }
5306
5307 // split 49 -> 51
5308 {
5309 std::vector<struct iovec> dstIov(4);
5310 size_t dstIovCount = dstIov.size();
5311 AsyncSocket::splitIovecArray(
5312 49, 51, data, srcIov.size(), dstIov.data(), dstIovCount);
5313
5314 ASSERT_EQ(2, dstIovCount);
5315
5316 // first dst iovec is one byte, starts 24 bytes in to the second src iovec
5317 EXPECT_EQ(1, dstIov[0].iov_len);
5318 EXPECT_EQ(
5319 dstIov[0].iov_base,
5320 const_cast<void*>(static_cast<const void*>(
5321 reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));
5322
5323 // second iovec is two bytes, starts at the third src iovec
5324 EXPECT_EQ(2, dstIov[1].iov_len);
5325 EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
5326 }
5327
5328 // split 50 -> 50
5329 {
5330 std::vector<struct iovec> dstIov(4);
5331 size_t dstIovCount = dstIov.size();
5332 AsyncSocket::splitIovecArray(
5333 50, 50, data, srcIov.size(), dstIov.data(), dstIovCount);
5334
5335 ASSERT_EQ(1, dstIovCount);
5336
5337 // first dst iovec is one byte, starts at the third src iovec
5338 EXPECT_EQ(1, dstIov[0].iov_len);
5339 EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
5340 }
5341
5342 // split 50 -> 51
5343 {
5344 std::vector<struct iovec> dstIov(4);
5345 size_t dstIovCount = dstIov.size();
5346 AsyncSocket::splitIovecArray(
5347 50, 51, data, srcIov.size(), dstIov.data(), dstIovCount);
5348
5349 ASSERT_EQ(1, dstIovCount);
5350
5351 // first dst iovec is two bytes, starts at the third src iovec
5352 EXPECT_EQ(2, dstIov[0].iov_len);
5353 EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
5354 }
5355
5356 // split 51 -> 51
5357 {
5358 std::vector<struct iovec> dstIov(4);
5359 size_t dstIovCount = dstIov.size();
5360 AsyncSocket::splitIovecArray(
5361 51, 51, data, srcIov.size(), dstIov.data(), dstIovCount);
5362
5363 ASSERT_EQ(1, dstIovCount);
5364
5365 // first dst iovec is one byte, starts 1 byte into the third src iovec
5366 EXPECT_EQ(1, dstIov[0].iov_len);
5367 EXPECT_EQ(
5368 dstIov[0].iov_base,
5369 const_cast<void*>(static_cast<const void*>(
5370 reinterpret_cast<uint8_t*>(srcIov[2].iov_base) + 1)));
5371 }
5372
5373 // split 48 -> 98
5374 {
5375 std::vector<struct iovec> dstIov(4);
5376 size_t dstIovCount = dstIov.size();
5377 AsyncSocket::splitIovecArray(
5378 48, 98, data, srcIov.size(), dstIov.data(), dstIovCount);
5379
5380 ASSERT_EQ(3, dstIovCount);
5381
5382 // first dst iovec is two bytes, starts 23 bytes in to the second src iovec
5383 EXPECT_EQ(2, dstIov[0].iov_len);
5384 EXPECT_EQ(
5385 dstIov[0].iov_base,
5386 const_cast<void*>(static_cast<const void*>(
5387 reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 23)));
5388
5389 // second dst iovec is third src iovec
5390 EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
5391 EXPECT_EQ(dstIov[1].iov_len, srcIov[2].iov_len);
5392
5393 // third dst iovec is 24 bytes, starts at the fourth src iovec
5394 EXPECT_EQ(24, dstIov[2].iov_len);
5395 EXPECT_EQ(dstIov[2].iov_base, srcIov[3].iov_base);
5396 }
5397
5398 // split 49 -> 98
5399 {
5400 std::vector<struct iovec> dstIov(4);
5401 size_t dstIovCount = dstIov.size();
5402 AsyncSocket::splitIovecArray(
5403 49, 98, data, srcIov.size(), dstIov.data(), dstIovCount);
5404
5405 ASSERT_EQ(3, dstIovCount);
5406
5407 // first dst iovec is one byte, starts 24 bytes in to the second src iovec
5408 EXPECT_EQ(1, dstIov[0].iov_len);
5409 EXPECT_EQ(
5410 dstIov[0].iov_base,
5411 const_cast<void*>(static_cast<const void*>(
5412 reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));
5413
5414 // second dst iovec is third src iovec
5415 EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
5416 EXPECT_EQ(dstIov[1].iov_len, srcIov[2].iov_len);
5417
5418 // third dst iovec is 24 bytes, starts at the fourth src iovec
5419 EXPECT_EQ(24, dstIov[2].iov_len);
5420 EXPECT_EQ(dstIov[2].iov_base, srcIov[3].iov_base);
5421 }
5422
5423 // split 50 -> 98
5424 {
5425 std::vector<struct iovec> dstIov(4);
5426 size_t dstIovCount = dstIov.size();
5427 AsyncSocket::splitIovecArray(
5428 50, 98, data, srcIov.size(), dstIov.data(), dstIovCount);
5429
5430 ASSERT_EQ(2, dstIovCount);
5431
5432 // first dst iovec is third src iovec
5433 EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
5434 EXPECT_EQ(dstIov[0].iov_len, srcIov[2].iov_len);
5435
5436 // second dst iovec is 24 bytes, starts at the fourth src iovec
5437 EXPECT_EQ(24, dstIov[1].iov_len);
5438 EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
5439 }
5440
5441 // split 51 -> 98
5442 {
5443 std::vector<struct iovec> dstIov(4);
5444 size_t dstIovCount = dstIov.size();
5445 AsyncSocket::splitIovecArray(
5446 51, 98, data, srcIov.size(), dstIov.data(), dstIovCount);
5447
5448 ASSERT_EQ(2, dstIovCount);
5449
5450 // first dst iovec is 24 bytes, starts 1 byte in to the third src iovec
5451 EXPECT_EQ(24, dstIov[0].iov_len);
5452 EXPECT_EQ(
5453 dstIov[0].iov_base,
5454 const_cast<void*>(static_cast<const void*>(
5455 reinterpret_cast<uint8_t*>(srcIov[2].iov_base) + 1)));
5456
5457 // second dst iovec is 24 bytes, starts at the fourth src iovec
5458 EXPECT_EQ(24, dstIov[1].iov_len);
5459 EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
5460 }
5461 }
5462
/**
 * Exercises the sendmsg matchers (SendmsgInvocHasTotalIovLen,
 * SendmsgMsghdrHasTotalIovLen, SendmsgInvocHasIovFirstByte and
 * SendmsgInvocHasIovLastByte) against hand-built iovec layouts that all point
 * inside kOneHundredCharacterVec.
 */
TEST_F(AsyncSocketByteEventTest, SendmsgMatchers) {
  // Returns an iovec describing [base, base + len).
  const auto makeIov = [](const void* base, size_t len) {
    struct iovec iov = {};
    iov.iov_base = const_cast<void*>(base);
    iov.iov_len = len;
    return iov;
  };

  // Returns a msghdr whose iovec array aliases invoc.iovs; invoc must outlive
  // the returned msghdr.
  const auto makeMsghdr = [](const ClientConn::SendmsgInvocation& invoc) {
    struct msghdr msg = {};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = const_cast<struct iovec*>(invoc.iovs.data());
    msg.msg_iovlen = invoc.iovs.size();
    return msg;
  };

  // first and last byte of kOneHundredCharacterVec
  const auto firstByte = kOneHundredCharacterVec.data();
  const auto lastByte =
      kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() - 1;

  // empty
  {
    const ClientConn::SendmsgInvocation sendmsgInvoc = {};
    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(0)));

    // iov first byte
    EXPECT_THAT(sendmsgInvoc, Not(SendmsgInvocHasIovFirstByte(firstByte)));
    EXPECT_THAT(
        sendmsgInvoc, Not(SendmsgInvocHasIovFirstByte(firstByte + 5)));

    // iov last byte
    EXPECT_THAT(sendmsgInvoc, Not(SendmsgInvocHasIovLastByte(firstByte)));
    EXPECT_THAT(sendmsgInvoc, Not(SendmsgInvocHasIovLastByte(firstByte + 5)));
  }

  // single iov, last byte = last byte of kOneHundredCharacterVec
  {
    const ClientConn::SendmsgInvocation sendmsgInvoc = {
        .iovs = {makeIov(firstByte, kOneHundredCharacterVec.size())}};
    const auto msg = makeMsghdr(sendmsgInvoc);

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(100)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(100)));

    // iov first byte
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovFirstByte(firstByte));
    EXPECT_THAT(sendmsgInvoc, Not(SendmsgInvocHasIovFirstByte(lastByte)));

    // iov last byte
    EXPECT_THAT(sendmsgInvoc, Not(SendmsgInvocHasIovLastByte(firstByte)));
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovLastByte(lastByte));
  }

  // single iov, first and last byte = first byte of kOneHundredCharacterVec
  {
    const ClientConn::SendmsgInvocation sendmsgInvoc = {
        .iovs = {makeIov(firstByte, 1)}};
    const auto msg = makeMsghdr(sendmsgInvoc);

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(1)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(1)));

    // iov first byte
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovFirstByte(firstByte));
    EXPECT_THAT(sendmsgInvoc, Not(SendmsgInvocHasIovFirstByte(lastByte)));

    // iov last byte
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovLastByte(firstByte));
    EXPECT_THAT(sendmsgInvoc, Not(SendmsgInvocHasIovLastByte(lastByte)));
  }

  // single iov, first and last byte = last byte of kOneHundredCharacterVec
  //
  // The iovec must start at the last byte (data() + size() - 1), not at
  // data() + size(): that is one past the end, and a one-byte iovec there
  // would describe memory outside the buffer.
  {
    const ClientConn::SendmsgInvocation sendmsgInvoc = {
        .iovs = {makeIov(lastByte, 1)}};
    const auto msg = makeMsghdr(sendmsgInvoc);

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(1)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(1)));

    // iov first byte
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovFirstByte(lastByte));
    EXPECT_THAT(sendmsgInvoc, Not(SendmsgInvocHasIovFirstByte(firstByte)));

    // iov last byte
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovLastByte(lastByte));
    EXPECT_THAT(sendmsgInvoc, Not(SendmsgInvocHasIovLastByte(firstByte)));
  }

  // two iovs, (0 -> 0, 1 -> 99), last byte = last byte of
  // kOneHundredCharacterVec
  {
    const ClientConn::SendmsgInvocation sendmsgInvoc = {
        .iovs = {makeIov(firstByte, 1), makeIov(firstByte + 1, 99)}};
    const auto msg = makeMsghdr(sendmsgInvoc);

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(100)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(100)));

    // iov first byte
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovFirstByte(firstByte));

    // iov last byte
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovLastByte(lastByte));
  }

  // two iovs, (0 -> 49, 50 -> 99), last byte = last byte of
  // kOneHundredCharacterVec
  {
    const ClientConn::SendmsgInvocation sendmsgInvoc = {
        .iovs = {makeIov(firstByte, 50), makeIov(firstByte + 50, 50)}};
    const auto msg = makeMsghdr(sendmsgInvoc);

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(100)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(100)));

    // iov first byte
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovFirstByte(firstByte));

    // iov last byte
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovLastByte(lastByte));
  }

  // two iovs, (0 -> 49, 50 -> 98), last byte = penultimate byte
  {
    const ClientConn::SendmsgInvocation sendmsgInvoc = {
        .iovs = {makeIov(firstByte, 50), makeIov(firstByte + 50, 49)}};
    const auto msg = makeMsghdr(sendmsgInvoc);

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(99)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(99)));

    // iov first byte
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovFirstByte(firstByte));

    // iov last byte
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasIovLastByte(lastByte - 1));
  }
}
5707
TEST_F(AsyncSocketByteEventTest, SendmsgInvocMsgFlagsEq) {
  // a default-constructed invocation matches only WriteFlags::NONE
  {
    const ClientConn::SendmsgInvocation defaultInvoc;
    EXPECT_THAT(defaultInvoc, SendmsgInvocMsgFlagsEq(WriteFlags::NONE));
    EXPECT_THAT(defaultInvoc, Not(SendmsgInvocMsgFlagsEq(WriteFlags::CORK)));
  }

  // with CORK recorded in the msg flags, only an exact match passes
  {
    ClientConn::SendmsgInvocation corkedInvoc = {};
    corkedInvoc.writeFlagsInMsgFlags = WriteFlags::CORK;
    EXPECT_THAT(corkedInvoc, SendmsgInvocMsgFlagsEq(WriteFlags::CORK));
    EXPECT_THAT(corkedInvoc, Not(SendmsgInvocMsgFlagsEq(WriteFlags::NONE)));

    // a superset of the recorded flags must not match
    EXPECT_THAT(
        corkedInvoc,
        Not(SendmsgInvocMsgFlagsEq(WriteFlags::EOR | WriteFlags::CORK)));
  }
}
5728
TEST_F(AsyncSocketByteEventTest, SendmsgInvocAncillaryFlagsEq) {
  // a default-constructed invocation matches only WriteFlags::NONE
  {
    const ClientConn::SendmsgInvocation defaultInvoc;
    EXPECT_THAT(defaultInvoc, SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE));
    EXPECT_THAT(
        defaultInvoc,
        Not(SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX)));
  }

  // with TIMESTAMP_TX recorded as ancillary flags, only an exact match passes
  {
    ClientConn::SendmsgInvocation txInvoc = {};
    txInvoc.writeFlagsInAncillary = WriteFlags::TIMESTAMP_TX;
    EXPECT_THAT(
        txInvoc, SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX));
    EXPECT_THAT(txInvoc, Not(SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE)));

    // a superset of the recorded flags must not match
    EXPECT_THAT(
        txInvoc,
        Not(SendmsgInvocAncillaryFlagsEq(
            WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK)));
  }
}
5754
TEST_F(AsyncSocketByteEventTest, ByteEventMatching) {
  // WRITE event at offset 0
  {
    AsyncTransport::ByteEvent writeEvent = {};
    writeEvent.type = ByteEventType::WRITE;
    writeEvent.offset = 0;
    EXPECT_THAT(writeEvent, ByteEventMatching(ByteEventType::WRITE, 0));

    // same type, different offset
    EXPECT_THAT(writeEvent, Not(ByteEventMatching(ByteEventType::WRITE, 10)));

    // same offset, different type
    for (const auto otherType :
         {ByteEventType::TX, ByteEventType::ACK, ByteEventType::SCHED}) {
      EXPECT_THAT(writeEvent, Not(ByteEventMatching(otherType, 0)));
    }
  }

  // TX event at offset 10
  {
    AsyncTransport::ByteEvent txEvent = {};
    txEvent.type = ByteEventType::TX;
    txEvent.offset = 10;
    EXPECT_THAT(txEvent, ByteEventMatching(ByteEventType::TX, 10));

    // same type, different offset
    EXPECT_THAT(txEvent, Not(ByteEventMatching(ByteEventType::TX, 0)));

    // same offset, different type
    for (const auto otherType :
         {ByteEventType::WRITE, ByteEventType::ACK, ByteEventType::SCHED}) {
      EXPECT_THAT(txEvent, Not(ByteEventMatching(otherType, 10)));
    }
  }
}
5784
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserver) {
  using PrewriteState = AsyncTransport::LifecycleObserver::PrewriteState;
  using PrewriteRequest = AsyncTransport::LifecycleObserver::PrewriteRequest;

  auto conn = getClientConn();
  conn.connect();
  auto obs = conn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, obs->byteEventsEnabledCalled);
  EXPECT_EQ(0, obs->byteEventsUnavailableCalled);
  EXPECT_FALSE(obs->byteEventsUnavailableCalledEx.has_value());

  conn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  // every timestamp type, requested at each split point
  const auto tsFlags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;

  // split the write at offsets 0, 50 and 98; past 98 no split is requested
  ON_CALL(*obs, prewriteMock(_, _))
      .WillByDefault(
          testing::Invoke([](AsyncTransport*, const PrewriteState& state) {
            PrewriteRequest req;
            req.writeFlagsToAddAtOffset = tsFlags;
            if (state.startOffset == 0) {
              req.maybeOffsetToSplitWrite = 0;
            } else if (state.startOffset <= 50) {
              req.maybeOffsetToSplitWrite = 50;
            } else if (state.startOffset <= 98) {
              req.maybeOffsetToSplitWrite = 98;
            }
            return req;
          }));
  conn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);

  // ancillary flags expected on each sendmsg that ends at a split point
  const auto splitAncillary = dropWriteFromFlags(tsFlags);
  const auto bytes = kOneHundredCharacterVec.data();

  // every segment ending at a split point is corked; the trailing segment is
  // sent with no flags at all
  EXPECT_THAT(
      conn.getSendmsgInvocations(),
      ElementsAre(
          AllOf(
              SendmsgInvocHasIovLastByte(bytes),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(splitAncillary)),
          AllOf(
              SendmsgInvocHasIovLastByte(bytes + 50),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(splitAncillary)),
          AllOf(
              SendmsgInvocHasIovLastByte(bytes + 98),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(splitAncillary)),
          AllOf(
              SendmsgInvocHasIovLastByte(bytes + 99),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));

  // WRITE events appear only at the requested split offsets (timestamp events
  // are verified by other tests); no split was requested for the segment that
  // ends at 99, so no WRITE event is expected there
  EXPECT_THAT(
      filterToWriteEvents(obs->byteEvents),
      ElementsAre(
          ByteEventMatching(ByteEventType::WRITE, 0),
          ByteEventMatching(ByteEventType::WRITE, 50),
          ByteEventMatching(ByteEventType::WRITE, 98)));
}
5847
5848 /**
5849 * Test explicitly that CORK (MSG_MORE) is set if write is split in middle.
5850 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverCorkIfSplitMiddle) {
  using PrewriteState = AsyncTransport::LifecycleObserver::PrewriteState;
  using PrewriteRequest = AsyncTransport::LifecycleObserver::PrewriteRequest;

  auto conn = getClientConn();
  conn.connect();
  auto obs = conn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, obs->byteEventsEnabledCalled);
  EXPECT_EQ(0, obs->byteEventsUnavailableCalled);
  EXPECT_FALSE(obs->byteEventsUnavailableCalledEx.has_value());

  conn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  const auto tsFlags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;

  // request a single split at offset 50, in the middle of the write
  ON_CALL(*obs, prewriteMock(_, _))
      .WillByDefault(
          testing::Invoke([](AsyncTransport*, const PrewriteState& state) {
            PrewriteRequest req;
            req.writeFlagsToAddAtOffset = tsFlags;
            if (state.startOffset <= 50) {
              req.maybeOffsetToSplitWrite = 50;
            }
            return req;
          }));
  conn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);

  // the segment ending at the split is corked; the trailing one is not
  const auto bytes = kOneHundredCharacterVec.data();
  EXPECT_THAT(
      conn.getSendmsgInvocations(),
      ElementsAre(
          AllOf(
              SendmsgInvocHasIovLastByte(bytes + 50),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(tsFlags))),
          AllOf(
              SendmsgInvocHasIovLastByte(bytes + 99),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));

  // a single WRITE event at the split offset (timestamp events are verified
  // by other tests)
  EXPECT_THAT(
      filterToWriteEvents(obs->byteEvents),
      ElementsAre(ByteEventMatching(ByteEventType::WRITE, 50)));
}
5895
5896 /**
5897 * Test explicitly that CORK (MSG_MORE) is NOT set if write is split at end.
5898 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverNoCorkIfSplitAtEnd) {
  using PrewriteState = AsyncTransport::LifecycleObserver::PrewriteState;
  using PrewriteRequest = AsyncTransport::LifecycleObserver::PrewriteRequest;

  auto conn = getClientConn();
  conn.connect();
  auto obs = conn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, obs->byteEventsEnabledCalled);
  EXPECT_EQ(0, obs->byteEventsUnavailableCalled);
  EXPECT_FALSE(obs->byteEventsUnavailableCalledEx.has_value());

  conn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  const auto tsFlags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;

  // request a split at offset 99, the final byte of the write
  ON_CALL(*obs, prewriteMock(_, _))
      .WillByDefault(
          testing::Invoke([](AsyncTransport*, const PrewriteState& state) {
            PrewriteRequest req;
            req.writeFlagsToAddAtOffset = tsFlags;
            if (state.startOffset <= 99) {
              req.maybeOffsetToSplitWrite = 99;
            }
            return req;
          }));
  conn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);

  // splitting at the final byte leaves a single sendmsg; it carries the
  // split-point ancillary flags but must not be corked
  EXPECT_THAT(
      conn.getSendmsgInvocations(),
      ElementsAre(AllOf(
          SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
          SendmsgInvocMsgFlagsEq(WriteFlags::NONE), // no cork!
          SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(tsFlags)))));

  // a single WRITE event at the split offset (timestamp events are verified
  // by other tests)
  EXPECT_THAT(
      filterToWriteEvents(obs->byteEvents),
      ElementsAre(ByteEventMatching(ByteEventType::WRITE, 99)));
}
5938
5939 /**
5940 * Test explicitly that split flags are NOT added if no split.
5941 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverNoSplitFlagsIfNoSplit) {
  using PrewriteState = AsyncTransport::LifecycleObserver::PrewriteState;
  using PrewriteRequest = AsyncTransport::LifecycleObserver::PrewriteRequest;

  auto conn = getClientConn();
  conn.connect();
  auto obs = conn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, obs->byteEventsEnabledCalled);
  EXPECT_EQ(0, obs->byteEventsUnavailableCalled);
  EXPECT_FALSE(obs->byteEventsUnavailableCalledEx.has_value());

  conn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  const auto tsFlags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;

  // the observer asks for flags at a split offset but never names an offset
  ON_CALL(*obs, prewriteMock(_, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*, const PrewriteState& /* state */) {
            PrewriteRequest req;
            req.writeFlagsToAddAtOffset = tsFlags;
            return req;
          }));
  conn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);

  // without a split offset the write goes out in one sendmsg with no flags
  EXPECT_THAT(
      conn.getSendmsgInvocations(),
      ElementsAre(AllOf(
          SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
          SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
          SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));
}
5973
5974 /**
5975 * Test more combinations of prewrite flags, including writeFlagsToAdd.
5976 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverFlagsOnAll) {
  using PrewriteState = AsyncTransport::LifecycleObserver::PrewriteState;
  using PrewriteRequest = AsyncTransport::LifecycleObserver::PrewriteRequest;

  auto conn = getClientConn();
  conn.connect();
  auto obs = conn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, obs->byteEventsEnabledCalled);
  EXPECT_EQ(0, obs->byteEventsUnavailableCalled);
  EXPECT_FALSE(obs->byteEventsUnavailableCalledEx.has_value());

  conn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  // Split points at 0, 10, 20, 30 and 40, each with its own at-offset flag.
  // From the segment ending at 40 onward, TIMESTAMP_WRITE is also requested
  // through writeFlagsToAdd.
  ON_CALL(*obs, prewriteMock(_, _))
      .WillByDefault(
          testing::Invoke([](AsyncTransport*, const PrewriteState& state) {
            PrewriteRequest req;
            if (state.startOffset == 0) {
              req.maybeOffsetToSplitWrite = 0;
              req.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_WRITE;
              return req;
            }
            if (state.startOffset <= 10) {
              req.maybeOffsetToSplitWrite = 10;
              req.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_SCHED;
              return req;
            }
            if (state.startOffset <= 20) {
              req.maybeOffsetToSplitWrite = 20;
              req.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_TX;
              return req;
            }
            if (state.startOffset <= 30) {
              req.maybeOffsetToSplitWrite = 30;
              req.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_ACK;
              return req;
            }
            if (state.startOffset <= 40) {
              req.maybeOffsetToSplitWrite = 40;
              req.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_TX;
              req.writeFlagsToAdd |= WriteFlags::TIMESTAMP_WRITE;
              return req;
            }
            req.writeFlagsToAdd |= WriteFlags::TIMESTAMP_WRITE;
            return req;
          }));
  conn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);

  // each segment ending at a split point is corked and carries its own
  // ancillary flags; the trailing segment is sent with no flags
  const auto bytes = kOneHundredCharacterVec.data();
  EXPECT_THAT(
      conn.getSendmsgInvocations(),
      ElementsAre(
          AllOf(
              SendmsgInvocHasIovLastByte(bytes),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE)),
          AllOf(
              SendmsgInvocHasIovLastByte(bytes + 10),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_SCHED)),
          AllOf(
              SendmsgInvocHasIovLastByte(bytes + 20),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX)),
          AllOf(
              SendmsgInvocHasIovLastByte(bytes + 30),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_ACK)),
          AllOf(
              SendmsgInvocHasIovLastByte(bytes + 40),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX)),
          AllOf(
              SendmsgInvocHasIovLastByte(bytes + 99),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));

  // WRITE events are expected where TIMESTAMP_WRITE was requested: at 0
  // (at-offset flag), and at 40 and 99 (writeFlagsToAdd). Timestamp events
  // are verified by other tests.
  EXPECT_THAT(
      filterToWriteEvents(obs->byteEvents),
      ElementsAre(
          ByteEventMatching(ByteEventType::WRITE, 0),
          ByteEventMatching(ByteEventType::WRITE, 40),
          ByteEventMatching(ByteEventType::WRITE, 99)));
}
6053
6054 /**
6055 * Test merging of write flags with those passed to AsyncSocket::write().
6056 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverFlagsOnWrite) {
  using PrewriteRequest = AsyncTransport::LifecycleObserver::PrewriteRequest;
  using PrewriteState = AsyncTransport::LifecycleObserver::PrewriteState;

  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  // Every write gets WRITE from the observer; the write that starts at
  // offset 0 is additionally split right after byte 0 and gets TX there.
  ON_CALL(*observer, prewriteMock(_, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*, const PrewriteState& state) {
            PrewriteRequest request;
            request.writeFlagsToAdd |= WriteFlags::TIMESTAMP_WRITE;
            if (state.startOffset != 0) {
              return request;
            }
            request.maybeOffsetToSplitWrite = 0;
            request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_TX;
            return request;
          }));

  // The application itself asks for CORK and ACK timestamps.
  clientConn.writeAndReflect(
      kOneHundredCharacterVec, WriteFlags::CORK | WriteFlags::TIMESTAMP_ACK);

  // First sendmsg covers byte 0 only: CORK comes from the split, and the
  // ancillary flags merge the observer's TX with the application's ACK.
  const auto firstSendmsg = AllOf(
      SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()),
      SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
      SendmsgInvocAncillaryFlagsEq(
          WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK));
  // Second sendmsg covers the remainder: the application's CORK passes through.
  const auto secondSendmsg = AllOf(
      SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
      SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
      SendmsgInvocAncillaryFlagsEq(
          dropWriteFromFlags(WriteFlags::TIMESTAMP_ACK)));
  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      ElementsAre(firstSendmsg, secondSendmsg));

  // WRITE events must land on the last byte of each sendmsg; the other
  // timestamp types are covered by other tests.
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      ElementsAre(
          ByteEventMatching(ByteEventType::WRITE, 0),
          ByteEventMatching(ByteEventType::WRITE, 99)));
}
6112
6113 /**
6114 * Test invalid offset for prewrite, ensure death via CHECK.
6115 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverInvalidOffset) {
  using LifecycleObserver = AsyncTransport::LifecycleObserver;

  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  // The observer requests a split at offset 200, which lies beyond the end
  // of the 100-byte write.
  ON_CALL(*observer, prewriteMock(_, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*, const LifecycleObserver::PrewriteState& state) {
            // sanity check that 200 really is past the write's last offset
            EXPECT_GT(200, state.endOffset);
            LifecycleObserver::PrewriteRequest request;
            request.maybeOffsetToSplitWrite = 200; // invalid
            return request;
          }));

  // The invalid split offset must trip a CHECK and kill the process.
  EXPECT_DEATH(
      clientConn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE),
      ".*");
}
6142
6143 /**
6144 * Test prewrite with multiple iovec.
6145 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverTwoIovec) {
  using PrewriteRequest = AsyncTransport::LifecycleObserver::PrewriteRequest;
  using PrewriteState = AsyncTransport::LifecycleObserver::PrewriteState;

  // Describe kOneHundredCharacterVec with two iovecs of 50 bytes each.
  constexpr size_t kBytesPerIovec = 50;
  std::vector<iovec> iovs;
  for (size_t idx = 0; idx < 2; ++idx) {
    iovec iov = {};
    iov.iov_base = const_cast<void*>(static_cast<const void*>(
        kOneHundredCharacterVec.data() + idx * kBytesPerIovec));
    iov.iov_len = kBytesPerIovec;
    iovs.push_back(iov);
  }

  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
  ON_CALL(*observer, prewriteMock(_, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*, const PrewriteState& state) {
            PrewriteRequest request;
            // Candidate split points: first byte, last byte of the first
            // iovec, last byte of the second iovec. Use the first one that
            // is not before the start of the pending write.
            for (const uint64_t splitAt :
                 {uint64_t{0}, uint64_t{49}, uint64_t{99}}) {
              if (state.startOffset <= splitAt) {
                request.maybeOffsetToSplitWrite = splitAt;
                break;
              }
            }
            request.writeFlagsToAddAtOffset = flags;
            return request;
          }));

  clientConn.writeAndReflect(iovs.data(), iovs.size(), WriteFlags::NONE);

  // Each sendmsg ends on a split point and so carries the observer's flags
  // (with WRITE dropped); only the final one is sent without CORK.
  const auto expectedAncillaryFlags = dropWriteFromFlags(flags);
  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      ElementsAre(
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(expectedAncillaryFlags)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 49),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(expectedAncillaryFlags)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(expectedAncillaryFlags))));

  // WRITE events must land on each split point; the other timestamp types
  // are covered by other tests.
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      ElementsAre(
          ByteEventMatching(ByteEventType::WRITE, 0),
          ByteEventMatching(ByteEventType::WRITE, 49),
          ByteEventMatching(ByteEventType::WRITE, 99)));
}
6220
6221 /**
 * Test prewrite with a large number of iovecs to trigger the malloc codepath.
6223 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverManyIovec) {
  // make a long vector, 10000 bytes long
  auto tenThousandByteVec = get10KBOfData();
  ASSERT_THAT(tenThousandByteVec, SizeIs(10000));

  // put each byte in the vector into its own iovec; the large iovec count is
  // what exercises the malloc codepath
  std::vector<iovec> tenThousandIovec;
  tenThousandIovec.reserve(tenThousandByteVec.size());
  for (size_t i = 0; i < tenThousandByteVec.size(); i++) {
    iovec iov = {};
    iov.iov_base = tenThousandByteVec.data() + i;
    iov.iov_len = 1;
    tenThousandIovec.push_back(iov);
  }

  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  // observer splits at offsets 0, 1000 and 5000 and requests all flags at
  // each split; writes starting after offset 5000 get no split offset
  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
  ON_CALL(*observer, prewriteMock(_, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncTransport::LifecycleObserver::PrewriteState& state) {
            AsyncTransport::LifecycleObserver::PrewriteRequest request;
            if (state.startOffset == 0) {
              request.maybeOffsetToSplitWrite = 0;
            } else if (state.startOffset <= 1000) {
              request.maybeOffsetToSplitWrite = 1000;
            } else if (state.startOffset <= 5000) {
              request.maybeOffsetToSplitWrite = 5000;
            }

            request.writeFlagsToAddAtOffset = flags;
            return request;
          }));

  clientConn.writeAndReflect(
      tenThousandIovec.data(), tenThousandIovec.size(), WriteFlags::NONE);

  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      AllOf(
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(tenThousandByteVec.data()),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(tenThousandByteVec.data() + 1000),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(tenThousandByteVec.data() + 5000),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(tenThousandByteVec.data() + 9999),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE)))));

  // verify WRITE events exist at the appropriate locations
  // we verify timestamp events are generated elsewhere
  //
  // should _not_ contain an event for the final byte (offset 9999): no split
  // is requested there and the application passes no flags, so the final
  // sendmsg carries no timestamp flags (checked above)
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      AllOf(
          Contains(ByteEventMatching(ByteEventType::WRITE, 0)),
          Contains(ByteEventMatching(ByteEventType::WRITE, 1000)),
          Contains(ByteEventMatching(ByteEventType::WRITE, 5000)),
          Not(Contains(ByteEventMatching(ByteEventType::WRITE, 9999)))));
}
6301
TEST_F(AsyncSocketByteEventTest, PrewriteMultipleObservers) {
  auto clientConn = getClientConn();
  clientConn.connect();

  // six observers
  // observer1 - 4 have byte events and prewrite enabled
  // observer5 has byte events enabled, but not prewrite
  // observer6 has neither byte events nor prewrite
  auto observer1 = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  auto observer2 = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  auto observer3 = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  auto observer4 = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  auto observer5 = clientConn.attachObserver(
      true /* enableByteEvents */, false /* enablePrewrite */);
  auto observer6 = clientConn.attachObserver(
      false /* enableByteEvents */, false /* enablePrewrite */);

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  // observer 1 wants TX timestamps at 25, 50, 75
  ON_CALL(*observer1, prewriteMock(_, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncTransport::LifecycleObserver::PrewriteState& state) {
            AsyncTransport::LifecycleObserver::PrewriteRequest request;
            if (state.startOffset <= 25) {
              request.maybeOffsetToSplitWrite = 25;
            } else if (state.startOffset <= 50) {
              request.maybeOffsetToSplitWrite = 50;
            } else if (state.startOffset <= 75) {
              request.maybeOffsetToSplitWrite = 75;
            }
            request.writeFlagsToAddAtOffset = WriteFlags::TIMESTAMP_TX;
            return request;
          }));

  // observer 2 wants ACK timestamps at 35, 65, 75
  ON_CALL(*observer2, prewriteMock(_, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncTransport::LifecycleObserver::PrewriteState& state) {
            AsyncTransport::LifecycleObserver::PrewriteRequest request;
            if (state.startOffset <= 35) {
              request.maybeOffsetToSplitWrite = 35;
            } else if (state.startOffset <= 65) {
              request.maybeOffsetToSplitWrite = 65;
            } else if (state.startOffset <= 75) {
              request.maybeOffsetToSplitWrite = 75;
            }
            request.writeFlagsToAddAtOffset = WriteFlags::TIMESTAMP_ACK;
            return request;
          }));

  // observer 3 wants WRITE and SCHED flag on every write that occurs
  ON_CALL(*observer3, prewriteMock(_, _))
      .WillByDefault(
          testing::Invoke([](AsyncTransport*,
                             const AsyncTransport::LifecycleObserver::
                                 PrewriteState& /* state */) {
            AsyncTransport::LifecycleObserver::PrewriteRequest request;
            request.writeFlagsToAdd =
                WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED;
            return request;
          }));

  // observer 4 has prewrite but makes no requests
  ON_CALL(*observer4, prewriteMock(_, _))
      .WillByDefault(
          testing::Invoke([](AsyncTransport*,
                             const AsyncTransport::LifecycleObserver::
                                 PrewriteState& /* state */) {
            AsyncTransport::LifecycleObserver::PrewriteRequest request;
            return request; // empty
          }));

  // no prewrite calls for observer 5 or observer 6 (prewrite not enabled)
  EXPECT_CALL(*observer5, prewriteMock(_, _)).Times(0);
  EXPECT_CALL(*observer6, prewriteMock(_, _)).Times(0);

  // write
  clientConn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);

  // split points are the union of the observers' requested offsets; flags
  // requested at the same offset (75) are merged
  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      ElementsAre(
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 25),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(
                  WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 35),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(
                  WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_ACK)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 50),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(
                  WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 65),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(
                  WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_ACK)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 75),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(
                  WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX |
                  WriteFlags::TIMESTAMP_ACK)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_SCHED))));

  // verify WRITE events exist at the appropriate locations
  // we verify timestamp events are generated elsewhere
  for (const auto& observer : {observer1, observer2, observer3}) {
    EXPECT_THAT(
        filterToWriteEvents(observer->byteEvents),
        ElementsAre(
            ByteEventMatching(ByteEventType::WRITE, 25),
            ByteEventMatching(ByteEventType::WRITE, 35),
            ByteEventMatching(ByteEventType::WRITE, 50),
            ByteEventMatching(ByteEventType::WRITE, 65),
            ByteEventMatching(ByteEventType::WRITE, 75),
            ByteEventMatching(ByteEventType::WRITE, 99)));
  }
}
6436
6437 /**
6438 * Test prewrite with large write that enables testing of timestamps.
6439 *
6440 * We need to use a long vector to ensure that the kernel will not coalesce
6441 * the writes into a single SKB due to MSG_MORE.
6442 */
TEST_F(AsyncSocketByteEventTest, PrewriteTimestampedByteEvents) {
  // need a large block of data (1000 KB) to ensure that MSG_MORE doesn't
  // limit us
  const auto thousandKBVec = get1000KBOfData();
  ASSERT_THAT(thousandKBVec, SizeIs(1000000));

  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  // observer splits at the first byte, the middle byte, and the last byte,
  // and requests all timestamp types on every write
  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
  ON_CALL(*observer, prewriteMock(_, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncTransport::LifecycleObserver::PrewriteState& state) {
            AsyncTransport::LifecycleObserver::PrewriteRequest request;
            if (state.startOffset == 0) {
              request.maybeOffsetToSplitWrite = 0;
            } else if (state.startOffset <= 500000) {
              request.maybeOffsetToSplitWrite = 500000;
            } else {
              request.maybeOffsetToSplitWrite = 999999;
            }

            request.writeFlagsToAdd = flags;
            return request;
          }));

  clientConn.writeAndReflect(thousandKBVec, WriteFlags::NONE);

  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      AllOf(
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(thousandKBVec.data()),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(thousandKBVec.data() + 500000),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(thousandKBVec.data() + 999999),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))))));

  // verify WRITE events exist at the appropriate locations
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      AllOf(
          Contains(ByteEventMatching(ByteEventType::WRITE, 0)),
          Contains(ByteEventMatching(ByteEventType::WRITE, 500000)),
          Contains(ByteEventMatching(ByteEventType::WRITE, 999999))));

  // verify SCHED, TX, and ACK events available at specified locations
  EXPECT_THAT(
      observer->byteEvents,
      AllOf(
          Contains(ByteEventMatching(ByteEventType::SCHED, 0)),
          Contains(ByteEventMatching(ByteEventType::TX, 0)),
          Contains(ByteEventMatching(ByteEventType::ACK, 0)),
          Contains(ByteEventMatching(ByteEventType::SCHED, 500000)),
          Contains(ByteEventMatching(ByteEventType::TX, 500000)),
          Contains(ByteEventMatching(ByteEventType::ACK, 500000)),
          Contains(ByteEventMatching(ByteEventType::SCHED, 999999)),
          Contains(ByteEventMatching(ByteEventType::TX, 999999)),
          Contains(ByteEventMatching(ByteEventType::ACK, 999999))));
}
6517
6518 /**
6519 * Test raw bytes written and bytes tried to write with prewrite.
6520 */
TEST_F(AsyncSocketByteEventTest, PrewriteRawBytesWrittenAndTriedToWrite) {
  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // One expected sendmsg call, the value the mock returns for it, and the
  // WRITE event it should produce (if any).
  struct ExpectedSendmsgInvocation {
    size_t expectedTotalIovLen{0}; // total bytes offered to sendmsg
    ssize_t returnVal{0}; // number of bytes written or error val
    folly::Optional<size_t> maybeWriteEventExpectedOffset{};
    folly::Optional<WriteFlags> maybeWriteEventExpectedFlags{};
  };

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;

  // first write
  //
  // no splits triggered by observer
  //
  // sendmsg will incrementally accept the bytes so we can test the values of
  // maybeRawBytesWritten and maybeRawBytesTriedToWrite
  {
    // bytes written per sendmsg call: 20, 10, 50, -1 (EAGAIN), 11, 9
    const std::vector<ExpectedSendmsgInvocation> expectedSendmsgInvocations{
        // {
        //   expectedTotalIovLen, returnVal,
        //   maybeWriteEventExpectedOffset, maybeWriteEventExpectedFlags
        // },
        {100, 20, 19, flags},
        {80, 10, 29, flags},
        {70, 50, 79, flags},
        {20, -1, folly::none, flags},
        {20, 11, 90, flags},
        {9, 9, 99, flags}};

    // prewrite will be called, we request all events
    EXPECT_CALL(*observer, prewriteMock(_, _))
        .Times(expectedSendmsgInvocations.size())
        .WillRepeatedly(testing::InvokeWithoutArgs([]() {
          AsyncTransport::LifecycleObserver::PrewriteRequest request = {};
          request.writeFlagsToAdd = flags;
          return request;
        }));

    // sendmsg will be called, we return # of bytes written
    {
      InSequence s;
      for (const auto& expectedInvocation : expectedSendmsgInvocations) {
        EXPECT_CALL(
            *(clientConn.getNetOpsDispatcher()),
            sendmsg(
                _,
                Pointee(SendmsgMsghdrHasTotalIovLen(
                    expectedInvocation.expectedTotalIovLen)),
                _))
            .WillOnce(::testing::InvokeWithoutArgs([expectedInvocation]() {
              if (expectedInvocation.returnVal < 0) {
                errno = EAGAIN; // returning error, set EAGAIN
              }
              return expectedInvocation.returnVal;
            }));
      }
    }

    // write
    // writes will be intercepted, so we don't need to read at other end
    WriteCallback wcb;
    clientConn.getRawSocket()->write(
        &wcb,
        kOneHundredCharacterVec.data(),
        kOneHundredCharacterVec.size(),
        WriteFlags::NONE);
    while (STATE_WAITING == wcb.state) {
      clientConn.getRawSocket()->getEventBase()->loopOnce();
    }
    ASSERT_EQ(STATE_SUCCEEDED, wcb.state);

    // check write events
    for (const auto& expectedInvocation : expectedSendmsgInvocations) {
      if (expectedInvocation.returnVal < 0) {
        // should be no WriteEvent since the return value was an error
        continue;
      }

      ASSERT_TRUE(expectedInvocation.maybeWriteEventExpectedOffset.has_value());
      const auto& expectedOffset =
          *expectedInvocation.maybeWriteEventExpectedOffset;

      auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
          expectedOffset, ByteEventType::WRITE);
      ASSERT_TRUE(maybeByteEvent.has_value());
      auto& byteEvent = maybeByteEvent.value();

      EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
      EXPECT_EQ(expectedOffset, byteEvent.offset);
      EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
      EXPECT_LT(
          std::chrono::steady_clock::now() - std::chrono::seconds(60),
          byteEvent.ts);

      EXPECT_EQ(
          expectedInvocation.maybeWriteEventExpectedFlags,
          byteEvent.maybeWriteFlags);
      EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());

      EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
      EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());

      // what we really want to test
      EXPECT_EQ(
          folly::to_unsigned(expectedInvocation.returnVal),
          byteEvent.maybeRawBytesWritten);
      EXPECT_EQ(
          expectedInvocation.expectedTotalIovLen,
          byteEvent.maybeRawBytesTriedToWrite);
    }
  }

  // everything should have occurred by now
  clientConn.netOpsVerifyAndClearExpectations();

  // second write
  //
  // start offset is 100
  //
  // split at 150th byte (offset 149) triggered by observer
  //
  // sendmsg will incrementally accept the bytes so we can test the values of
  // maybeRawBytesWritten and maybeRawBytesTriedToWrite
  {
    // due to the split at the 150th byte (offset 149), we expect sendmsg to
    // only be offered bytes 100 -> 149 until the 150th byte has been written;
    // in addition, the socket only accepts 20 of those 50 bytes on the first
    // attempt.
    //
    // bytes written per sendmsg call: 20, 30, 50
    const std::vector<ExpectedSendmsgInvocation> expectedSendmsgInvocations{
        {50, 20, 119, flags | WriteFlags::CORK},
        {30, 30, 149, flags | WriteFlags::CORK},
        {50, 50, 199, flags}};

    // prewrite will be called, split at the 50th byte of this write
    // (absolute offset = 149, since this write starts at offset 100)
    EXPECT_CALL(*observer, prewriteMock(_, _))
        .Times(expectedSendmsgInvocations.size())
        .WillRepeatedly(testing::Invoke(
            [](AsyncTransport*,
               const AsyncTransport::LifecycleObserver::PrewriteState& state) {
              AsyncTransport::LifecycleObserver::PrewriteRequest request;
              if (state.startOffset <= 149) {
                request.maybeOffsetToSplitWrite = 149; // start offset = 100
              }
              request.writeFlagsToAdd = flags;
              return request;
            }));

    // sendmsg will be called, we return # of bytes written
    {
      InSequence s;
      for (const auto& expectedInvocation : expectedSendmsgInvocations) {
        EXPECT_CALL(
            *(clientConn.getNetOpsDispatcher()),
            sendmsg(
                _,
                Pointee(SendmsgMsghdrHasTotalIovLen(
                    expectedInvocation.expectedTotalIovLen)),
                _))
            .WillOnce(::testing::InvokeWithoutArgs([expectedInvocation]() {
              return expectedInvocation.returnVal;
            }));
      }
    }

    // write
    // writes will be intercepted, so we don't need to read at other end
    WriteCallback wcb;
    clientConn.getRawSocket()->write(
        &wcb,
        kOneHundredCharacterVec.data(),
        kOneHundredCharacterVec.size(),
        WriteFlags::NONE);
    while (STATE_WAITING == wcb.state) {
      clientConn.getRawSocket()->getEventBase()->loopOnce();
    }
    ASSERT_EQ(STATE_SUCCEEDED, wcb.state);

    // check write events
    for (const auto& expectedInvocation : expectedSendmsgInvocations) {
      ASSERT_TRUE(expectedInvocation.maybeWriteEventExpectedOffset.has_value());
      const auto& expectedOffset =
          *expectedInvocation.maybeWriteEventExpectedOffset;

      auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
          expectedOffset, ByteEventType::WRITE);
      ASSERT_TRUE(maybeByteEvent.has_value());
      auto& byteEvent = maybeByteEvent.value();

      EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
      EXPECT_EQ(expectedOffset, byteEvent.offset);
      EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
      EXPECT_LT(
          std::chrono::steady_clock::now() - std::chrono::seconds(60),
          byteEvent.ts);

      EXPECT_EQ(
          expectedInvocation.maybeWriteEventExpectedFlags,
          byteEvent.maybeWriteFlags);
      EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());

      EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
      EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());

      // what we really want to test
      EXPECT_EQ(
          folly::to_unsigned(expectedInvocation.returnVal),
          byteEvent.maybeRawBytesWritten);
      EXPECT_EQ(
          expectedInvocation.expectedTotalIovLen,
          byteEvent.maybeRawBytesTriedToWrite);
    }
  }
}
6750
6751 struct AsyncSocketByteEventDetailsTestParams {
6752 struct WriteParams {
WriteParamsAsyncSocketByteEventDetailsTestParams::WriteParams6753 WriteParams(uint64_t bufferSize, WriteFlags writeFlags)
6754 : bufferSize(bufferSize), writeFlags(writeFlags) {}
6755 uint64_t bufferSize{0};
6756 WriteFlags writeFlags{WriteFlags::NONE};
6757 };
6758
6759 std::vector<WriteParams> writesWithParams;
6760 };
6761
6762 class AsyncSocketByteEventDetailsTest
6763 : public AsyncSocketByteEventTest,
6764 public testing::WithParamInterface<
6765 AsyncSocketByteEventDetailsTestParams> {
6766 public:
getTestingValues()6767 static std::vector<AsyncSocketByteEventDetailsTestParams> getTestingValues() {
6768 const std::array<WriteFlags, 9> writeFlagCombinations{
6769 // SCHED
6770 WriteFlags::TIMESTAMP_SCHED,
6771 // TX
6772 WriteFlags::TIMESTAMP_TX,
6773 // ACK
6774 WriteFlags::TIMESTAMP_ACK,
6775 // SCHED + TX + ACK
6776 WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX |
6777 WriteFlags::TIMESTAMP_ACK,
6778 // WRITE
6779 WriteFlags::TIMESTAMP_WRITE,
6780 // WRITE + SCHED
6781 WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED,
6782 // WRITE + TX
6783 WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_TX,
6784 // WRITE + ACK
6785 WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_ACK,
6786 // WRITE + SCHED + TX + ACK
6787 WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
6788 WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK,
6789 };
6790
6791 std::vector<AsyncSocketByteEventDetailsTestParams> vals;
6792 for (const auto& writeFlags : writeFlagCombinations) {
6793 // write 1 byte
6794 {
6795 AsyncSocketByteEventDetailsTestParams params;
6796 params.writesWithParams.emplace_back(1, writeFlags);
6797 vals.push_back(params);
6798 }
6799
6800 // write 1 byte twice
6801 {
6802 AsyncSocketByteEventDetailsTestParams params;
6803 params.writesWithParams.emplace_back(1, writeFlags);
6804 params.writesWithParams.emplace_back(1, writeFlags);
6805 vals.push_back(params);
6806 }
6807
6808 // write 10 bytes
6809 {
6810 AsyncSocketByteEventDetailsTestParams params;
6811 params.writesWithParams.emplace_back(10, writeFlags);
6812 vals.push_back(params);
6813 }
6814
6815 // write 10 bytes twice
6816 {
6817 AsyncSocketByteEventDetailsTestParams params;
6818 params.writesWithParams.emplace_back(10, writeFlags);
6819 params.writesWithParams.emplace_back(10, writeFlags);
6820 vals.push_back(params);
6821 }
6822 }
6823
6824 return vals;
6825 }
6826 };
6827
// Instantiate each TEST_P of AsyncSocketByteEventDetailsTest once per
// parameter set returned by AsyncSocketByteEventDetailsTest::getTestingValues().
INSTANTIATE_TEST_SUITE_P(
    ByteEventDetailsTest,
    AsyncSocketByteEventDetailsTest,
    ::testing::ValuesIn(AsyncSocketByteEventDetailsTest::getTestingValues()));
6832
6833 /**
 * Inspect ByteEvent fields, including the *TimestampRequestedOnWrite()
 * accessors of WRITE events.
6835 */
TEST_P(AsyncSocketByteEventDetailsTest,CheckByteEventDetails)6836 TEST_P(AsyncSocketByteEventDetailsTest, CheckByteEventDetails) {
6837 auto params = GetParam();
6838
6839 auto clientConn = getClientConn();
6840 clientConn.connect();
6841 auto observer = clientConn.attachObserver(true /* enableByteEvents */);
6842 EXPECT_EQ(1, observer->byteEventsEnabledCalled);
6843 EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
6844 EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
6845
6846 uint64_t expectedNumByteEvents = 0;
6847 for (const auto& writeParams : params.writesWithParams) {
6848 const std::vector<uint8_t> wbuf(writeParams.bufferSize, 'a');
6849 const auto flags = writeParams.writeFlags;
6850 clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
6851 dropWriteFromFlags(flags));
6852 clientConn.writeAndReflect(wbuf, flags);
6853 clientConn.netOpsVerifyAndClearExpectations();
6854 const auto expectedOffset =
6855 clientConn.getRawSocket()->getRawBytesWritten() - 1;
6856
6857 // check WRITE
6858 if ((flags & WriteFlags::TIMESTAMP_WRITE) != WriteFlags::NONE) {
6859 expectedNumByteEvents++;
6860
6861 auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
6862 expectedOffset, ByteEventType::WRITE);
6863 ASSERT_TRUE(maybeByteEvent.has_value());
6864 auto& byteEvent = maybeByteEvent.value();
6865
6866 EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
6867 EXPECT_EQ(expectedOffset, byteEvent.offset);
6868 EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
6869 EXPECT_LT(
6870 std::chrono::steady_clock::now() - std::chrono::seconds(60),
6871 byteEvent.ts);
6872
6873 EXPECT_EQ(flags, byteEvent.maybeWriteFlags);
6874 EXPECT_EQ(
6875 isSet(flags, WriteFlags::TIMESTAMP_SCHED),
6876 byteEvent.schedTimestampRequestedOnWrite());
6877 EXPECT_EQ(
6878 isSet(flags, WriteFlags::TIMESTAMP_TX),
6879 byteEvent.txTimestampRequestedOnWrite());
6880 EXPECT_EQ(
6881 isSet(flags, WriteFlags::TIMESTAMP_ACK),
6882 byteEvent.ackTimestampRequestedOnWrite());
6883
6884 EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
6885 EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
6886 }
6887
6888 // check SCHED, TX, ACK
6889 for (const auto& byteEventType :
6890 {ByteEventType::SCHED, ByteEventType::TX, ByteEventType::ACK}) {
6891 auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
6892 expectedOffset, byteEventType);
6893 switch (byteEventType) {
6894 case ByteEventType::WRITE:
6895 FAIL();
6896 case ByteEventType::SCHED:
6897 if ((flags & WriteFlags::TIMESTAMP_SCHED) == WriteFlags::NONE) {
6898 EXPECT_FALSE(maybeByteEvent.has_value());
6899 continue;
6900 }
6901 break;
6902 case ByteEventType::TX:
6903 if ((flags & WriteFlags::TIMESTAMP_TX) == WriteFlags::NONE) {
6904 EXPECT_FALSE(maybeByteEvent.has_value());
6905 continue;
6906 }
6907 break;
6908 case ByteEventType::ACK:
6909 if ((flags & WriteFlags::TIMESTAMP_ACK) == WriteFlags::NONE) {
6910 EXPECT_FALSE(maybeByteEvent.has_value());
6911 continue;
6912 }
6913 break;
6914 }
6915
6916 expectedNumByteEvents++;
6917 ASSERT_TRUE(maybeByteEvent.has_value());
6918 auto& byteEvent = maybeByteEvent.value();
6919
6920 EXPECT_EQ(byteEventType, byteEvent.type);
6921 EXPECT_EQ(expectedOffset, byteEvent.offset);
6922 EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
6923 EXPECT_LT(
6924 std::chrono::steady_clock::now() - std::chrono::seconds(60),
6925 byteEvent.ts);
6926 EXPECT_FALSE(byteEvent.maybeWriteFlags.has_value());
6927 // don't check *TimestampRequestedOnWrite* fields to avoid CHECK_DEATH,
6928 // already checked in CheckByteEventDetailsApplicationSetsFlags
6929
6930 EXPECT_TRUE(byteEvent.maybeSoftwareTs.has_value());
6931 EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
6932 }
6933 }
6934
6935 // should have at least expectedNumByteEvents
6936 // may be more if writes were split up by kernel
6937 EXPECT_THAT(observer->byteEvents, SizeIs(Ge(expectedNumByteEvents)));
6938 }
6939
6940 class AsyncSocketByteEventHelperTest : public ::testing::Test {
6941 protected:
6942 using ByteEventType = AsyncTransport::ByteEvent::Type;
6943
6944 /**
6945 * Wrapper around a vector containing cmsg header + data.
6946 */
6947 class WrappedCMsg {
6948 public:
WrappedCMsg(std::vector<char> && data)6949 explicit WrappedCMsg(std::vector<char>&& data) : data_(std::move(data)) {}
6950
operator const struct cmsghdr&()6951 operator const struct cmsghdr &() {
6952 return *reinterpret_cast<struct cmsghdr*>(data_.data());
6953 }
6954
6955 protected:
6956 std::vector<char> data_;
6957 };
6958
6959 /**
6960 * Wrapper around a vector containing cmsg header + data.
6961 */
6962 class WrappedSockExtendedErrTsCMsg : public WrappedCMsg {
6963 public:
6964 using WrappedCMsg::WrappedCMsg;
6965
6966 // ts[0] -> software timestamp
6967 // ts[1] -> hardware timestamp transformed to userspace time (deprecated)
6968 // ts[2] -> hardware timestamp
6969
setSoftwareTimestamp(const std::chrono::seconds seconds,const std::chrono::nanoseconds nanoseconds)6970 void setSoftwareTimestamp(
6971 const std::chrono::seconds seconds,
6972 const std::chrono::nanoseconds nanoseconds) {
6973 struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data_.data())};
6974 struct scm_timestamping* tss{
6975 reinterpret_cast<struct scm_timestamping*>(CMSG_DATA(cmsg))};
6976 tss->ts[0].tv_sec = seconds.count();
6977 tss->ts[0].tv_nsec = nanoseconds.count();
6978 }
6979
setHardwareTimestamp(const std::chrono::seconds seconds,const std::chrono::nanoseconds nanoseconds)6980 void setHardwareTimestamp(
6981 const std::chrono::seconds seconds,
6982 const std::chrono::nanoseconds nanoseconds) {
6983 struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data_.data())};
6984 struct scm_timestamping* tss{
6985 reinterpret_cast<struct scm_timestamping*>(CMSG_DATA(cmsg))};
6986 tss->ts[2].tv_sec = seconds.count();
6987 tss->ts[2].tv_nsec = nanoseconds.count();
6988 }
6989 };
6990
cmsgData(int level,int type,size_t len)6991 static std::vector<char> cmsgData(int level, int type, size_t len) {
6992 std::vector<char> data(CMSG_LEN(len), 0);
6993 struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data.data())};
6994 cmsg->cmsg_level = level;
6995 cmsg->cmsg_type = type;
6996 cmsg->cmsg_len = CMSG_LEN(len);
6997 return data;
6998 }
6999
cmsgForSockExtendedErrTimestamping()7000 static WrappedSockExtendedErrTsCMsg cmsgForSockExtendedErrTimestamping() {
7001 return WrappedSockExtendedErrTsCMsg(
7002 cmsgData(SOL_SOCKET, SO_TIMESTAMPING, sizeof(struct scm_timestamping)));
7003 }
7004
cmsgForScmTimestamping(const uint32_t type,const uint32_t kernelByteOffset)7005 static WrappedCMsg cmsgForScmTimestamping(
7006 const uint32_t type, const uint32_t kernelByteOffset) {
7007 auto data = cmsgData(SOL_IP, IP_RECVERR, sizeof(struct sock_extended_err));
7008 struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data.data())};
7009 struct sock_extended_err* serr{
7010 reinterpret_cast<struct sock_extended_err*>(CMSG_DATA(cmsg))};
7011 serr->ee_errno = ENOMSG;
7012 serr->ee_origin = SO_EE_ORIGIN_TIMESTAMPING;
7013 serr->ee_info = type;
7014 serr->ee_data = kernelByteOffset;
7015 return WrappedCMsg(std::move(data));
7016 }
7017 };
7018
/**
 * Byte-offset cmsg first, timestamp cmsg second: the first call alone does
 * not produce a ByteEvent, and the second completes it.
 */
TEST_F(AsyncSocketByteEventHelperTest, ByteOffsetThenTs) {
  AsyncSocket::ByteEventHelper helper{};
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.byteEventsEnabled = true;

  auto offsetCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);
  auto tsCmsg = cmsgForSockExtendedErrTimestamping();
  tsCmsg.setSoftwareTimestamp(
      std::chrono::seconds(59), std::chrono::nanoseconds(11));

  EXPECT_FALSE(helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));
}
7033
/**
 * Timestamp cmsg first, byte-offset cmsg second: the pair can arrive in
 * either order, and the second call completes the ByteEvent.
 */
TEST_F(AsyncSocketByteEventHelperTest, TsThenByteOffset) {
  AsyncSocket::ByteEventHelper helper{};
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.byteEventsEnabled = true;

  auto offsetCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);
  auto tsCmsg = cmsgForSockExtendedErrTimestamping();
  tsCmsg.setSoftwareTimestamp(
      std::chrono::seconds(59), std::chrono::nanoseconds(11));

  EXPECT_FALSE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */));
}
7048
/**
 * With byteEventsEnabled == false, processCmsg produces nothing even for a
 * full offset + timestamp pair. After enabling, the same pair produces an event.
 */
TEST_F(AsyncSocketByteEventHelperTest, ByteEventsDisabled) {
  AsyncSocket::ByteEventHelper helper{};
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.byteEventsEnabled = false;

  auto offsetCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);
  auto tsCmsg = cmsgForSockExtendedErrTimestamping();
  tsCmsg.setSoftwareTimestamp(
      std::chrono::seconds(59), std::chrono::nanoseconds(11));

  // fails because disabled
  EXPECT_FALSE(helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */));
  EXPECT_FALSE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));

  // enable, try again to prove this works
  helper.byteEventsEnabled = true;
  EXPECT_FALSE(helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));
}
7069
/**
 * An offset cmsg with an unknown SCM_TSTAMP_* type is consumed without
 * producing an event, and the following timestamp cmsg does not pair with
 * it. A known type then pairs normally.
 */
TEST_F(AsyncSocketByteEventHelperTest, IgnoreUnsupportedEvent) {
  AsyncSocket::ByteEventHelper helper{};
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.byteEventsEnabled = true;

  auto tsCmsg = cmsgForSockExtendedErrTimestamping();
  tsCmsg.setSoftwareTimestamp(
      std::chrono::seconds(59), std::chrono::nanoseconds(11));

  // imaginary new type of SCM event
  const auto unsupportedScmType = folly::netops::SCM_TSTAMP_ACK + 10;
  auto unsupportedCmsg = cmsgForScmTimestamping(unsupportedScmType, 0);

  // unsupported event is eaten
  EXPECT_FALSE(helper.processCmsg(unsupportedCmsg, 1 /* rawBytesWritten */));
  EXPECT_FALSE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));

  // change type, try again to prove this works
  auto ackCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_ACK, 0);
  EXPECT_FALSE(helper.processCmsg(ackCmsg, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));
}
7092
/**
 * Two byte-offset cmsgs in a row, without a timestamp cmsg in between,
 * make processCmsg throw ByteEventHelper::Exception.
 */
TEST_F(AsyncSocketByteEventHelperTest, ErrorDoubleScmCmsg) {
  AsyncSocket::ByteEventHelper helper{};
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.byteEventsEnabled = true;

  auto offsetCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);

  EXPECT_FALSE(helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */));
  EXPECT_THROW(
      helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */),
      AsyncSocket::ByteEventHelper::Exception);
}
7104
/**
 * Two timestamp cmsgs in a row, without a byte-offset cmsg in between,
 * make processCmsg throw ByteEventHelper::Exception.
 */
TEST_F(AsyncSocketByteEventHelperTest, ErrorDoubleSerrCmsg) {
  AsyncSocket::ByteEventHelper helper{};
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.byteEventsEnabled = true;

  auto tsCmsg = cmsgForSockExtendedErrTimestamping();
  tsCmsg.setSoftwareTimestamp(
      std::chrono::seconds(59), std::chrono::nanoseconds(11));

  EXPECT_FALSE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));
  EXPECT_THROW(
      helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */),
      AsyncSocket::ByteEventHelper::Exception);
}
7119
/**
 * While helper.maybeEx holds an exception, processCmsg produces no events.
 * Clearing it lets the same offset + timestamp pair produce one.
 */
TEST_F(AsyncSocketByteEventHelperTest, ErrorExceptionSet) {
  AsyncSocket::ByteEventHelper helper{};
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.byteEventsEnabled = true;
  helper.maybeEx = AsyncSocketException(
      AsyncSocketException::AsyncSocketExceptionType::UNKNOWN, "");

  auto offsetCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);
  auto tsCmsg = cmsgForSockExtendedErrTimestamping();
  tsCmsg.setSoftwareTimestamp(
      std::chrono::seconds(59), std::chrono::nanoseconds(11));

  // fails due to existing exception
  EXPECT_FALSE(helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */));
  EXPECT_FALSE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));

  // delete the exception, then repeat to prove exception was blocking
  helper.maybeEx = folly::none;
  EXPECT_FALSE(helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));
}
7142
7143 struct AsyncSocketByteEventHelperTimestampTestParams {
AsyncSocketByteEventHelperTimestampTestParamsAsyncSocketByteEventHelperTimestampTestParams7144 AsyncSocketByteEventHelperTimestampTestParams(
7145 uint32_t scmType,
7146 AsyncTransport::ByteEvent::Type expectedByteEventType,
7147 bool includeSoftwareTs,
7148 bool includeHardwareTs)
7149 : scmType(scmType),
7150 expectedByteEventType(expectedByteEventType),
7151 includeSoftwareTs(includeSoftwareTs),
7152 includeHardwareTs(includeHardwareTs) {}
7153 uint32_t scmType{0};
7154 AsyncTransport::ByteEvent::Type expectedByteEventType;
7155 bool includeSoftwareTs{false};
7156 bool includeHardwareTs{false};
7157 };
7158
7159 class AsyncSocketByteEventHelperTimestampTest
7160 : public AsyncSocketByteEventHelperTest,
7161 public testing::WithParamInterface<
7162 AsyncSocketByteEventHelperTimestampTestParams> {
7163 public:
7164 static std::vector<AsyncSocketByteEventHelperTimestampTestParams>
getTestingValues()7165 getTestingValues() {
7166 std::vector<AsyncSocketByteEventHelperTimestampTestParams> vals;
7167
7168 // software + hardware timestamps
7169 {
7170 vals.emplace_back(
7171 folly::netops::SCM_TSTAMP_SCHED, ByteEventType::SCHED, true, true);
7172 vals.emplace_back(
7173 folly::netops::SCM_TSTAMP_SND, ByteEventType::TX, true, true);
7174 vals.emplace_back(
7175 folly::netops::SCM_TSTAMP_ACK, ByteEventType::ACK, true, true);
7176 }
7177
7178 // software ts only
7179 {
7180 vals.emplace_back(
7181 folly::netops::SCM_TSTAMP_SCHED, ByteEventType::SCHED, true, false);
7182 vals.emplace_back(
7183 folly::netops::SCM_TSTAMP_SND, ByteEventType::TX, true, false);
7184 vals.emplace_back(
7185 folly::netops::SCM_TSTAMP_ACK, ByteEventType::ACK, true, false);
7186 }
7187
7188 // hardware ts only
7189 {
7190 vals.emplace_back(
7191 folly::netops::SCM_TSTAMP_SCHED, ByteEventType::SCHED, false, true);
7192 vals.emplace_back(
7193 folly::netops::SCM_TSTAMP_SND, ByteEventType::TX, false, true);
7194 vals.emplace_back(
7195 folly::netops::SCM_TSTAMP_ACK, ByteEventType::ACK, false, true);
7196 }
7197
7198 return vals;
7199 }
7200 };
7201
// Runs every TEST_P of AsyncSocketByteEventHelperTimestampTest once per
// parameter set returned by its getTestingValues().
INSTANTIATE_TEST_SUITE_P(
    ByteEventTimestampTest,
    AsyncSocketByteEventHelperTimestampTest,
    ::testing::ValuesIn(
        AsyncSocketByteEventHelperTimestampTest::getTestingValues()));
7207
7208 /**
7209 * Check timestamp parsing for software and hardware timestamps.
7210 */
TEST_P(AsyncSocketByteEventHelperTimestampTest,CheckEventTimestamps)7211 TEST_P(AsyncSocketByteEventHelperTimestampTest, CheckEventTimestamps) {
7212 const auto softwareTsSec = std::chrono::seconds(59);
7213 const auto softwareTsNs = std::chrono::nanoseconds(11);
7214 const auto hardwareTsSec = std::chrono::seconds(79);
7215 const auto hardwareTsNs = std::chrono::nanoseconds(31);
7216
7217 auto params = GetParam();
7218 auto scmTs = cmsgForScmTimestamping(params.scmType, 0);
7219 auto serrTs = cmsgForSockExtendedErrTimestamping();
7220 if (params.includeSoftwareTs) {
7221 serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);
7222 }
7223 if (params.includeHardwareTs) {
7224 serrTs.setHardwareTimestamp(hardwareTsSec, hardwareTsNs);
7225 }
7226
7227 AsyncSocket::ByteEventHelper helper = {};
7228 helper.byteEventsEnabled = true;
7229 helper.rawBytesWrittenWhenByteEventsEnabled = 0;
7230 folly::Optional<AsyncTransport::ByteEvent> maybeByteEvent;
7231 maybeByteEvent = helper.processCmsg(serrTs, 1 /* rawBytesWritten */);
7232 EXPECT_FALSE(maybeByteEvent.has_value());
7233 maybeByteEvent = helper.processCmsg(scmTs, 1 /* rawBytesWritten */);
7234
7235 // common checks
7236 ASSERT_TRUE(maybeByteEvent.has_value());
7237 const auto& byteEvent = *maybeByteEvent;
7238 EXPECT_EQ(0, byteEvent.offset);
7239 EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
7240
7241 EXPECT_EQ(params.expectedByteEventType, byteEvent.type);
7242 if (params.includeSoftwareTs) {
7243 EXPECT_EQ(softwareTsSec + softwareTsNs, byteEvent.maybeSoftwareTs);
7244 }
7245 if (params.includeHardwareTs) {
7246 EXPECT_EQ(hardwareTsSec + hardwareTsNs, byteEvent.maybeHardwareTs);
7247 }
7248 }
7249
/**
 * One CheckCalculatedOffset case.
 *
 * All fields are zero-initialized by default: previously only the first one
 * was, leaving the others indeterminate in a default-constructed instance.
 */
struct AsyncSocketByteEventHelperOffsetTestParams {
  // bytes AsyncSocket had written when byte events were enabled
  uint64_t rawBytesWrittenWhenByteEventsEnabled{0};
  // offset (in bytes written by AsyncSocket) of the byte being timestamped
  uint64_t byteTimestamped{0};
  // bytes AsyncSocket had written when the timestamp was received
  uint64_t rawBytesWrittenWhenTimestampReceived{0};
};
7255
7256 class AsyncSocketByteEventHelperOffsetTest
7257 : public AsyncSocketByteEventHelperTest,
7258 public testing::WithParamInterface<
7259 AsyncSocketByteEventHelperOffsetTestParams> {
7260 public:
7261 static std::vector<AsyncSocketByteEventHelperOffsetTestParams>
getTestingValues()7262 getTestingValues() {
7263 std::vector<AsyncSocketByteEventHelperOffsetTestParams> vals;
7264 const std::array<uint64_t, 5> rawBytesWrittenWhenByteEventsEnabledVals{
7265 0, 1, 100, 4294967295, 4294967296};
7266 for (const auto& rawBytesWrittenWhenByteEventsEnabled :
7267 rawBytesWrittenWhenByteEventsEnabledVals) {
7268 auto addParams = [&](auto params) {
7269 // check if case is valid based on rawBytesWrittenWhenByteEventsEnabled
7270 if (rawBytesWrittenWhenByteEventsEnabled <= params.byteTimestamped) {
7271 vals.push_back(params);
7272 }
7273 };
7274
7275 // case 1
7276 // bytes sent on receipt of timestamp == byte timestamped
7277 {
7278 AsyncSocketByteEventHelperOffsetTestParams params;
7279 params.rawBytesWrittenWhenByteEventsEnabled =
7280 rawBytesWrittenWhenByteEventsEnabled;
7281 params.byteTimestamped = 0;
7282 params.rawBytesWrittenWhenTimestampReceived = 0;
7283 addParams(params);
7284 }
7285 {
7286 AsyncSocketByteEventHelperOffsetTestParams params;
7287 params.rawBytesWrittenWhenByteEventsEnabled =
7288 rawBytesWrittenWhenByteEventsEnabled;
7289 params.byteTimestamped = 1;
7290 params.rawBytesWrittenWhenTimestampReceived = 1;
7291 addParams(params);
7292 }
7293 {
7294 AsyncSocketByteEventHelperOffsetTestParams params;
7295 params.rawBytesWrittenWhenByteEventsEnabled =
7296 rawBytesWrittenWhenByteEventsEnabled;
7297 params.byteTimestamped = 101;
7298 params.rawBytesWrittenWhenTimestampReceived = 101;
7299 addParams(params);
7300 }
7301
7302 // bytes sent on receipt of timestamp > byte timestamped
7303 {
7304 AsyncSocketByteEventHelperOffsetTestParams params;
7305 params.rawBytesWrittenWhenByteEventsEnabled =
7306 rawBytesWrittenWhenByteEventsEnabled;
7307 params.byteTimestamped = 1;
7308 params.rawBytesWrittenWhenTimestampReceived = 2;
7309 addParams(params);
7310 }
7311 {
7312 AsyncSocketByteEventHelperOffsetTestParams params;
7313 params.rawBytesWrittenWhenByteEventsEnabled =
7314 rawBytesWrittenWhenByteEventsEnabled;
7315 params.byteTimestamped = 101;
7316 params.rawBytesWrittenWhenTimestampReceived = 102;
7317 addParams(params);
7318 }
7319
7320 // case 2
7321 // bytes sent on receipt of timestamp == byte timestamped, boundary test
7322 // (boundary is at 2^32)
7323 {
7324 AsyncSocketByteEventHelperOffsetTestParams params;
7325 params.rawBytesWrittenWhenByteEventsEnabled =
7326 rawBytesWrittenWhenByteEventsEnabled;
7327 params.byteTimestamped = 4294967294;
7328 params.rawBytesWrittenWhenTimestampReceived = 4294967294;
7329 addParams(params);
7330 }
7331 {
7332 AsyncSocketByteEventHelperOffsetTestParams params;
7333 params.rawBytesWrittenWhenByteEventsEnabled =
7334 rawBytesWrittenWhenByteEventsEnabled;
7335 params.byteTimestamped = 4294967295;
7336 params.rawBytesWrittenWhenTimestampReceived = 4294967295;
7337 addParams(params);
7338 }
7339 {
7340 AsyncSocketByteEventHelperOffsetTestParams params;
7341 params.rawBytesWrittenWhenByteEventsEnabled =
7342 rawBytesWrittenWhenByteEventsEnabled;
7343 params.byteTimestamped = 4294967296;
7344 params.rawBytesWrittenWhenTimestampReceived = 4294967296;
7345 addParams(params);
7346 }
7347 {
7348 AsyncSocketByteEventHelperOffsetTestParams params;
7349 params.rawBytesWrittenWhenByteEventsEnabled =
7350 rawBytesWrittenWhenByteEventsEnabled;
7351 params.byteTimestamped = 4294967297;
7352 params.rawBytesWrittenWhenTimestampReceived = 4294967297;
7353 addParams(params);
7354 }
7355 {
7356 AsyncSocketByteEventHelperOffsetTestParams params;
7357 params.rawBytesWrittenWhenByteEventsEnabled =
7358 rawBytesWrittenWhenByteEventsEnabled;
7359 params.byteTimestamped = 4294967298;
7360 params.rawBytesWrittenWhenTimestampReceived = 4294967298;
7361 addParams(params);
7362 }
7363
7364 // case 3
7365 // bytes sent on receipt of timestamp > byte timestamped, boundary test
7366 // (boundary is at 2^32)
7367 {
7368 AsyncSocketByteEventHelperOffsetTestParams params;
7369 params.rawBytesWrittenWhenByteEventsEnabled =
7370 rawBytesWrittenWhenByteEventsEnabled;
7371 params.byteTimestamped = 4294967293;
7372 params.rawBytesWrittenWhenTimestampReceived = 4294967294;
7373 addParams(params);
7374 }
7375 {
7376 AsyncSocketByteEventHelperOffsetTestParams params;
7377 params.rawBytesWrittenWhenByteEventsEnabled =
7378 rawBytesWrittenWhenByteEventsEnabled;
7379 params.byteTimestamped = 4294967294;
7380 params.rawBytesWrittenWhenTimestampReceived = 4294967295;
7381 addParams(params);
7382 }
7383 {
7384 AsyncSocketByteEventHelperOffsetTestParams params;
7385 params.rawBytesWrittenWhenByteEventsEnabled =
7386 rawBytesWrittenWhenByteEventsEnabled;
7387 params.byteTimestamped = 4294967295;
7388 params.rawBytesWrittenWhenTimestampReceived = 4294967296;
7389 addParams(params);
7390 }
7391 {
7392 AsyncSocketByteEventHelperOffsetTestParams params;
7393 params.rawBytesWrittenWhenByteEventsEnabled =
7394 rawBytesWrittenWhenByteEventsEnabled;
7395 params.byteTimestamped = 4294967296;
7396 params.rawBytesWrittenWhenTimestampReceived = 4294967297;
7397 addParams(params);
7398 }
7399
7400 // case 4
7401 // bytes sent on receipt of timestamp > byte timestamped, wrap test
7402 // (boundary is at 2^32)
7403 {
7404 AsyncSocketByteEventHelperOffsetTestParams params;
7405 params.rawBytesWrittenWhenByteEventsEnabled =
7406 rawBytesWrittenWhenByteEventsEnabled;
7407 params.byteTimestamped = 4294967275;
7408 params.rawBytesWrittenWhenTimestampReceived = 4294967305;
7409 addParams(params);
7410 }
7411 {
7412 AsyncSocketByteEventHelperOffsetTestParams params;
7413 params.rawBytesWrittenWhenByteEventsEnabled =
7414 rawBytesWrittenWhenByteEventsEnabled;
7415 params.byteTimestamped = 4294967295;
7416 params.rawBytesWrittenWhenTimestampReceived = 4294967296;
7417 addParams(params);
7418 }
7419 {
7420 AsyncSocketByteEventHelperOffsetTestParams params;
7421 params.rawBytesWrittenWhenByteEventsEnabled =
7422 rawBytesWrittenWhenByteEventsEnabled;
7423 params.byteTimestamped = 4294967285;
7424 params.rawBytesWrittenWhenTimestampReceived = 4294967305;
7425 addParams(params);
7426 }
7427
7428 // case 5
7429 // special case when timestamp enabled when bytes transferred > (2^32)
7430 // bytes sent on receipt of timestamp == byte timestamped, boundary test
7431 // (boundary is at 2^32)
7432 {
7433 AsyncSocketByteEventHelperOffsetTestParams params;
7434 params.rawBytesWrittenWhenByteEventsEnabled =
7435 rawBytesWrittenWhenByteEventsEnabled;
7436 params.byteTimestamped = 6442450943;
7437 params.rawBytesWrittenWhenTimestampReceived = 6442450943;
7438 addParams(params);
7439 }
7440
7441 // case 6
7442 // special case when timestamp enabled when bytes transferred > (2^32)
7443 // bytes sent on receipt of timestamp > byte timestamped, boundary test
7444 // (boundary is at 2^32)
7445 {
7446 AsyncSocketByteEventHelperOffsetTestParams params;
7447 params.rawBytesWrittenWhenByteEventsEnabled =
7448 rawBytesWrittenWhenByteEventsEnabled;
7449 params.byteTimestamped = 6442450943;
7450 params.rawBytesWrittenWhenTimestampReceived = 6442450944;
7451 addParams(params);
7452 }
7453
7454 // case 7
7455 // special case when timestamp enabled when bytes transferred > (2^32)
7456 // bytes sent on receipt of timestamp > byte timestamped, wrap test
7457 // (boundary is at 2^32)
7458 {
7459 AsyncSocketByteEventHelperOffsetTestParams params;
7460 params.rawBytesWrittenWhenByteEventsEnabled =
7461 rawBytesWrittenWhenByteEventsEnabled;
7462 params.byteTimestamped = 6442450943;
7463 params.rawBytesWrittenWhenTimestampReceived = 8589934591;
7464 addParams(params);
7465 }
7466 }
7467
7468 return vals;
7469 }
7470 };
7471
// Runs every TEST_P of AsyncSocketByteEventHelperOffsetTest once per
// parameter set returned by its getTestingValues().
INSTANTIATE_TEST_SUITE_P(
    ByteEventOffsetTest,
    AsyncSocketByteEventHelperOffsetTest,
    ::testing::ValuesIn(
        AsyncSocketByteEventHelperOffsetTest::getTestingValues()));
7477
7478 /**
7479 * Check byte offset handling, including boundary cases.
7480 *
7481 * See AsyncSocket::ByteEventHelper::processCmsg for details.
7482 */
TEST_P(AsyncSocketByteEventHelperOffsetTest,CheckCalculatedOffset)7483 TEST_P(AsyncSocketByteEventHelperOffsetTest, CheckCalculatedOffset) {
7484 auto params = GetParam();
7485
7486 // because we use SOF_TIMESTAMPING_OPT_ID, byte offsets delivered from the
7487 // kernel are offset (relative to bytes written by AsyncSocket) by the number
7488 // of bytes AsyncSocket had written to the socket when enabling timestamps
7489 //
7490 // here we calculate what the kernel offset would be for the given byte offset
7491 const uint64_t bytesPerOffsetWrap =
7492 static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1;
7493
7494 auto kernelByteOffset =
7495 params.byteTimestamped - params.rawBytesWrittenWhenByteEventsEnabled;
7496 if (kernelByteOffset > 0) {
7497 kernelByteOffset = kernelByteOffset % bytesPerOffsetWrap;
7498 }
7499
7500 auto scmTs =
7501 cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, kernelByteOffset);
7502 const auto softwareTsSec = std::chrono::seconds(59);
7503 const auto softwareTsNs = std::chrono::nanoseconds(11);
7504 auto serrTs = cmsgForSockExtendedErrTimestamping();
7505 serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);
7506
7507 AsyncSocket::ByteEventHelper helper = {};
7508 helper.byteEventsEnabled = true;
7509 helper.rawBytesWrittenWhenByteEventsEnabled =
7510 params.rawBytesWrittenWhenByteEventsEnabled;
7511
7512 EXPECT_FALSE(helper.processCmsg(
7513 scmTs,
7514 params.rawBytesWrittenWhenTimestampReceived /* rawBytesWritten */));
7515 const auto maybeByteEvent = helper.processCmsg(
7516 serrTs,
7517 params.rawBytesWrittenWhenTimestampReceived /* rawBytesWritten */);
7518 ASSERT_TRUE(maybeByteEvent.has_value());
7519 const auto& byteEvent = *maybeByteEvent;
7520
7521 EXPECT_EQ(params.byteTimestamped, byteEvent.offset);
7522 EXPECT_EQ(softwareTsSec + softwareTsNs, byteEvent.maybeSoftwareTs);
7523 }
7524
7525 #endif // FOLLY_HAVE_SO_TIMESTAMPING
7526
/**
 * A callback registered through
 * ConstructorCallback<AsyncSocket>::addNewConstructorCallback runs for
 * AsyncSockets constructed after registration. socket1, built before
 * registration, has no lifecycle observers. socket2, built after, has the
 * mock observer attached.
 *
 * NOTE(review): nothing in this test unregisters the constructor callback,
 * and the callback captures a raw pointer to the object owned by
 * `lifecycleCB`, which is destroyed when this test returns. If other tests in
 * the same process construct AsyncSockets afterwards, they would be handed a
 * dangling observer pointer. Confirm that tests run in separate processes or
 * that the callback list is reset between tests.
 */
TEST(AsyncSocket, LifecycleCtorCallback) {
  EventBase evb;
  // create socket and verify that w/o a ctor callback, nothing happens
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_EQ(socket1->getLifecycleObservers().size(), 0);

  // Then register a ctor callback that registers a mock lifecycle observer
  // NB: use nicemock instead of strict b/c the actual lifecycle testing
  // is done below and this simplifies the test
  auto lifecycleCB =
      std::make_shared<NiceMock<MockAsyncSocketLifecycleObserver>>();
  auto lifecycleRawPtr = lifecycleCB.get();
  // register the callback: every AsyncSocket constructed from now on gets
  // the mock observer attached
  ConstructorCallback<AsyncSocket>::addNewConstructorCallback(
      [lifecycleRawPtr](AsyncSocket* s) {
        s->addLifecycleObserver(lifecycleRawPtr);
      });
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_EQ(socket2->getLifecycleObservers().size(), 1);
  EXPECT_THAT(
      socket2->getLifecycleObservers(),
      UnorderedElementsAre(lifecycleCB.get()));
  Mock::VerifyAndClearExpectations(lifecycleCB.get());
}
7551
/**
 * A lifecycle observer is notified when the socket's EventBase is detached
 * and attached (evb -> evb2 -> evb) and when the socket is destroyed. The
 * observer is a StrictMock, so any call not expected below fails the test.
 */
TEST(AsyncSocket, LifecycleObserverDetachAndAttachEvb) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  EventBase evb;
  EventBase evb2;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  // Detach the evb and attach a new evb2
  EXPECT_CALL(*cb, evbDetachMock(socket.get(), &evb));
  socket->detachEventBase();
  EXPECT_EQ(nullptr, socket->getEventBase());
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, evbAttachMock(socket.get(), &evb2));
  socket->attachEventBase(&evb2);
  EXPECT_EQ(&evb2, socket->getEventBase());
  Mock::VerifyAndClearExpectations(cb.get());

  // detach the new evb2 and re-attach the old evb.
  EXPECT_CALL(*cb, evbDetachMock(socket.get(), &evb2));
  socket->detachEventBase();
  EXPECT_EQ(nullptr, socket->getEventBase());
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, evbAttachMock(socket.get(), &evb));
  socket->attachEventBase(&evb);
  EXPECT_EQ(&evb, socket->getEventBase());
  Mock::VerifyAndClearExpectations(cb.get());

  // no connect was attempted in this test; only destroyMock is expected when
  // the socket is destroyed
  InSequence s;
  EXPECT_CALL(*cb, destroyMock(socket.get()));
  socket = nullptr;
  Mock::VerifyAndClearExpectations(cb.get());
}
7589
/**
 * A lifecycle observer attached before connecting to a local TestServer sees
 * connectAttempt, fdAttach and connectSuccess. Destroying the connected
 * socket then delivers close followed by destroy, in that order.
 */
TEST(AsyncSocket, LifecycleObserverAttachThenDestroySocket) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket.get()));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  // destruction must report close before destroy
  InSequence s;
  EXPECT_CALL(*cb, closeMock(socket.get()));
  EXPECT_CALL(*cb, destroyMock(socket.get()));
  socket = nullptr;
  Mock::VerifyAndClearExpectations(cb.get());
}
7614
/**
 * A failed connect (to [::1]:1) delivers connectAttempt, fdAttach, two
 * connectError notifications and close to the lifecycle observer. Destroying
 * the socket afterwards delivers only destroy.
 */
TEST(AsyncSocket, LifecycleObserverAttachThenConnectError) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  // port 1 is expected to be unreachable on localhost
  folly::SocketAddress unreachable{"::1", 1};

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  // the current state machine calls AsyncSocket::invokeConnectionError() twice
  // for this use-case...
  EXPECT_CALL(*cb, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb, connectErrorMock(socket.get(), _)).Times(2);
  EXPECT_CALL(*cb, closeMock(socket.get()));
  socket->connect(nullptr, unreachable, 1);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  // close was already reported above, so only destroy is expected here
  EXPECT_CALL(*cb, destroyMock(socket.get()));
  socket = nullptr;
  Mock::VerifyAndClearExpectations(cb.get());
}
7641
/**
 * Two lifecycle observers attached to the same socket both receive every
 * notification. Because of the InSequence guard, each event must reach cb1
 * before cb2, and the events must arrive in the order listed (connectAttempt,
 * fdAttach, connectSuccess, close, destroy).
 */
TEST(AsyncSocket, LifecycleObserverMultipleAttachThenDestroySocket) {
  auto cb1 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  auto cb2 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb1, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb2, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb2.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(cb1.get(), cb2.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  // from here on, all expectations must be satisfied in declaration order
  InSequence s;
  EXPECT_CALL(*cb1, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb2, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb1, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb2, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb1, connectSuccessMock(socket.get()));
  EXPECT_CALL(*cb2, connectSuccessMock(socket.get()));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb1, closeMock(socket.get()));
  EXPECT_CALL(*cb2, closeMock(socket.get()));
  EXPECT_CALL(*cb1, destroyMock(socket.get()));
  EXPECT_CALL(*cb2, destroyMock(socket.get()));
  socket = nullptr;
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());
}
7683
// Attaching an observer fires observerAttach. Removing it fires
// observerDetach, reports success, and leaves the socket with no lifecycle
// observers.
TEST(AsyncSocket, LifecycleObserverAttachRemove) {
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  AsyncSocket* const sock = socket.get();

  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();

  // attach
  EXPECT_CALL(*observer, observerAttachMock(sock));
  socket->addLifecycleObserver(observer.get());
  Mock::VerifyAndClearExpectations(observer.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(observer.get()));

  // detach
  EXPECT_CALL(*observer, observerDetachMock(sock));
  const bool removed = socket->removeLifecycleObserver(observer.get());
  EXPECT_TRUE(removed);
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
  Mock::VerifyAndClearExpectations(observer.get());
}
7699
// Two observers are attached (cb1, then cb2) and removed in the same order.
// Each removal must detach only the targeted observer and leave the other
// registered.
TEST(AsyncSocket, LifecycleObserverAttachRemoveMultiple) {
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  AsyncSocket* const sock = socket.get();

  using Observer = StrictMock<MockAsyncSocketLifecycleObserver>;

  // Attaches `obs`, expecting exactly one observerAttach callback.
  auto attach = [&](Observer* obs) {
    EXPECT_CALL(*obs, observerAttachMock(sock));
    socket->addLifecycleObserver(obs);
    Mock::VerifyAndClearExpectations(obs);
  };
  // Removes `obs`, expecting success and exactly one observerDetach callback.
  auto detach = [&](Observer* obs) {
    EXPECT_CALL(*obs, observerDetachMock(sock));
    EXPECT_TRUE(socket->removeLifecycleObserver(obs));
    Mock::VerifyAndClearExpectations(obs);
  };

  auto cb1 = std::make_unique<Observer>();
  attach(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));

  auto cb2 = std::make_unique<Observer>();
  attach(cb2.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(cb1.get(), cb2.get()));

  detach(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb2.get()));

  detach(cb2.get());
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
}
7728
// Two observers are attached (cb1, then cb2) and removed in reverse order
// (cb2, then cb1). Each removal must detach only the targeted observer.
TEST(AsyncSocket, LifecycleObserverAttachRemoveMultipleReverse) {
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  AsyncSocket* const sock = socket.get();

  using Observer = StrictMock<MockAsyncSocketLifecycleObserver>;

  // Attaches `obs`, expecting exactly one observerAttach callback.
  auto attach = [&](Observer* obs) {
    EXPECT_CALL(*obs, observerAttachMock(sock));
    socket->addLifecycleObserver(obs);
    Mock::VerifyAndClearExpectations(obs);
  };
  // Removes `obs`, expecting success and exactly one observerDetach callback.
  auto detach = [&](Observer* obs) {
    EXPECT_CALL(*obs, observerDetachMock(sock));
    EXPECT_TRUE(socket->removeLifecycleObserver(obs));
    Mock::VerifyAndClearExpectations(obs);
  };

  auto cb1 = std::make_unique<Observer>();
  attach(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));

  auto cb2 = std::make_unique<Observer>();
  attach(cb2.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(cb1.get(), cb2.get()));

  detach(cb2.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));

  detach(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
}
7757
// Removing an observer that was never attached must report failure. The
// StrictMock fails the test if any observer callback is invoked.
TEST(AsyncSocket, LifecycleObserverRemoveMissing) {
  auto neverAttached =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));

  const bool removed = socket->removeLifecycleObserver(neverAttached.get());
  EXPECT_FALSE(removed);
}
7764
// Attaches two observers, then removes them one at a time (cb2 first). Each
// removal must fire observerDetach only on the removed observer, and both
// removals must report success.
// No TestServer is created: this test never connects.
TEST(AsyncSocket, LifecycleObserverMultipleAttachThenRemove) {
  auto cb1 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  auto cb2 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb1, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb2, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb2.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(cb1.get(), cb2.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb2, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb2.get()));
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  // The last removal must also succeed. The result was previously ignored.
  EXPECT_CALL(*cb1, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb1.get()));
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());
}
7798
// After detachNetworkSocket(), the observer sees fdDetach on socket1.
// A second socket built from the detached fd has no observers, so it produces
// no callbacks. socket1 is then destroyed explicitly, so the destroy event is
// verified at a well-defined point (as in the sibling tests) instead of
// implicitly at scope exit.
TEST(AsyncSocket, LifecycleObserverDetach) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket1.get()));
  socket1->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket1->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket1.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket1.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket1.get()));
  socket1->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, fdDetachMock(socket1.get()));
  auto fd = socket1->detachNetworkSocket();
  Mock::VerifyAndClearExpectations(cb.get());

  // create socket2, then immediately destroy it, should get no callbacks
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(&evb, fd));
  socket2 = nullptr;

  // finally, destroy socket1; with its fd detached, only destroy is expected
  EXPECT_CALL(*cb, destroyMock(socket1.get()));
  socket1 = nullptr;
  Mock::VerifyAndClearExpectations(cb.get());
}
7828
// Moving an AsyncSocket into a new one fires fdDetach and then move on the
// old socket. In this test the observer uses the move callback to unsubscribe
// from the old socket and subscribe to the new one. It then receives close
// and destroy from the new socket only.
TEST(AsyncSocket, LifecycleObserverMoveResubscribe) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket1.get()));
  socket1->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket1->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket1.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket1.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket1.get()));
  socket1->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  // Filled in by the move callback with the address of the new socket.
  AsyncSocket* movedTo = nullptr;

  // Move handler: record the destination, then move the observer over.
  auto resubscribe = [&movedTo, &cb](auto oldSocket, auto newSocket) {
    movedTo = newSocket;
    EXPECT_CALL(*cb, observerDetachMock(oldSocket));
    EXPECT_CALL(*cb, observerAttachMock(newSocket));
    EXPECT_TRUE(oldSocket->removeLifecycleObserver(cb.get()));
    EXPECT_THAT(oldSocket->getLifecycleObservers(), IsEmpty());
    newSocket->addLifecycleObserver(cb.get());
    EXPECT_THAT(
        newSocket->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  };

  {
    InSequence moveSeq;
    EXPECT_CALL(*cb, fdDetachMock(socket1.get()));
    EXPECT_CALL(*cb, moveMock(socket1.get(), Not(socket1.get())))
        .WillOnce(Invoke(resubscribe));
  }
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket1)));
  Mock::VerifyAndClearExpectations(cb.get());
  EXPECT_EQ(socket2.get(), movedTo);

  {
    InSequence teardownSeq;
    EXPECT_CALL(*cb, closeMock(socket2.get()));
    EXPECT_CALL(*cb, destroyMock(socket2.get()));
  }
  socket2 = nullptr;
}
7876
// Moving a socket whose observer does NOT re-subscribe. The observer sees
// fdDetach and move on the old socket, then the old socket's destroy, and
// nothing from the new socket (the StrictMock fails on anything else).
TEST(AsyncSocket, LifecycleObserverMoveDoNotResubscribe) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  AsyncSocket* const oldSock = socket1.get();

  EXPECT_CALL(*cb, observerAttachMock(oldSock));
  socket1->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket1->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(oldSock));
  EXPECT_CALL(*cb, fdAttachMock(oldSock));
  EXPECT_CALL(*cb, connectSuccessMock(oldSock));
  socket1->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  // close will not be called on socket1 because the fd is detached
  AsyncSocket* movedTo = nullptr;
  InSequence s;
  EXPECT_CALL(*cb, fdDetachMock(oldSock));
  EXPECT_CALL(*cb, moveMock(oldSock, Not(oldSock)))
      .WillOnce(Invoke([&movedTo](auto /* oldSocket */, auto newSocket) {
        movedTo = newSocket;
      }));
  EXPECT_CALL(*cb, destroyMock(oldSock));
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket1)));
  Mock::VerifyAndClearExpectations(cb.get());
  EXPECT_EQ(socket2.get(), movedTo);
}
7909
// An observer removed before connect() receives only attach and detach.
// The connect and event-loop run that follow must not invoke it again; the
// StrictMock fails the test on any unexpected call.
TEST(AsyncSocket, LifecycleObserverDetachCallbackImmediately) {
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  AsyncSocket* const sock = socket.get();

  // attach
  EXPECT_CALL(*observer, observerAttachMock(sock));
  socket->addLifecycleObserver(observer.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(observer.get()));
  Mock::VerifyAndClearExpectations(observer.get());

  // detach before any connection activity
  EXPECT_CALL(*observer, observerDetachMock(sock));
  EXPECT_TRUE(socket->removeLifecycleObserver(observer.get()));
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
  Mock::VerifyAndClearExpectations(observer.get());

  // connect anyway; no further observer callbacks are expected
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
}
7930
// An observer that sees a successful connect can then be removed, which fires
// observerDetach. The socket is released at scope exit with no further
// callbacks to the removed observer.
TEST(AsyncSocket, LifecycleObserverDetachCallbackAfterConnect) {
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  AsyncSocket* const sock = socket.get();

  EXPECT_CALL(*observer, observerAttachMock(sock));
  socket->addLifecycleObserver(observer.get());
  Mock::VerifyAndClearExpectations(observer.get());

  // The three callbacks a successful connect() produces.
  EXPECT_CALL(*observer, connectAttemptMock(sock));
  EXPECT_CALL(*observer, fdAttachMock(sock));
  EXPECT_CALL(*observer, connectSuccessMock(sock));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(observer.get());

  EXPECT_CALL(*observer, observerDetachMock(sock));
  const bool removed = socket->removeLifecycleObserver(observer.get());
  EXPECT_TRUE(removed);
  Mock::VerifyAndClearExpectations(observer.get());
}
7952
// After connect and closeNow() (which fires close), the observer can still be
// removed, firing observerDetach.
TEST(AsyncSocket, LifecycleObserverDetachCallbackAfterClose) {
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  AsyncSocket* const sock = socket.get();

  // Verifies and clears whatever was expected of the observer so far.
  auto checkpoint = [&observer]() {
    Mock::VerifyAndClearExpectations(observer.get());
  };

  EXPECT_CALL(*observer, observerAttachMock(sock));
  socket->addLifecycleObserver(observer.get());
  checkpoint();

  EXPECT_CALL(*observer, connectAttemptMock(sock));
  EXPECT_CALL(*observer, fdAttachMock(sock));
  EXPECT_CALL(*observer, connectSuccessMock(sock));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  checkpoint();

  EXPECT_CALL(*observer, closeMock(sock));
  socket->closeNow();
  checkpoint();

  EXPECT_CALL(*observer, observerDetachMock(sock));
  EXPECT_TRUE(socket->removeLifecycleObserver(observer.get()));
  checkpoint();
}
7978
// An observer may remove itself from inside its close callback while the
// socket is being destroyed. It must then receive observerDetach and, having
// been removed, no destroy callback.
TEST(AsyncSocket, LifecycleObserverDetachCallbackcloseDuringDestroy) {
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  AsyncSocket* const sock = socket.get();

  EXPECT_CALL(*observer, observerAttachMock(sock));
  socket->addLifecycleObserver(observer.get());
  Mock::VerifyAndClearExpectations(observer.get());

  EXPECT_CALL(*observer, connectAttemptMock(sock));
  EXPECT_CALL(*observer, fdAttachMock(sock));
  EXPECT_CALL(*observer, connectSuccessMock(sock));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(observer.get());

  InSequence s;
  // Close handler: the observer unsubscribes itself mid-destruction.
  auto removeSelf = [&observer](auto callbackSocket) {
    EXPECT_TRUE(callbackSocket->removeLifecycleObserver(observer.get()));
  };
  EXPECT_CALL(*observer, closeMock(sock)).WillOnce(Invoke(removeSelf));
  EXPECT_CALL(*observer, observerDetachMock(sock));
  socket = nullptr;
  Mock::VerifyAndClearExpectations(observer.get());
}
8005
// Uses a mock of the base AsyncTransport::LifecycleObserver, which has no
// move or fdDetach events, and verifies that the static_cast works as
// expected: moving the socket must not crash. The only event after connect is
// socket1's destroy.
TEST(AsyncSocket, LifecycleObserverBaseClassMoveNoCrash) {
  auto transportObserver =
      std::make_unique<StrictMock<MockAsyncTransportLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  AsyncSocket* const oldSock = socket1.get();

  EXPECT_CALL(*transportObserver, observerAttachMock(oldSock));
  socket1->addLifecycleObserver(transportObserver.get());
  EXPECT_THAT(
      socket1->getLifecycleObservers(),
      UnorderedElementsAre(transportObserver.get()));
  Mock::VerifyAndClearExpectations(transportObserver.get());

  // Base-class observers see no fdAttach event.
  EXPECT_CALL(*transportObserver, connectAttemptMock(oldSock));
  EXPECT_CALL(*transportObserver, connectSuccessMock(oldSock));
  socket1->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(transportObserver.get());

  // we'll see socket1 get destroyed, but nothing else
  EXPECT_CALL(*transportObserver, destroyMock(oldSock));
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket1)));
  Mock::VerifyAndClearExpectations(transportObserver.get());
}
8030
// The first read callback consumes "he" and pushes it back via two
// setPreReceivedData() calls ("h", then "e"). The next read callback must
// then observe the full "hello", pre-received bytes first.
TEST(AsyncSocket, PreReceivedData) {
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> client = AsyncSocket::newSocket(&evb);
  client->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  client->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto serverSide = server.acceptAsync(&evb);

  ReadCallback peekCallback(2);
  ReadCallback readCallback;

  // Stage 2: wait until all five bytes arrived, then check and stop reading.
  readCallback.dataAvailableCallback = [&]() {
    if (readCallback.dataRead() != 5) {
      return;
    }
    readCallback.verifyData("hello", 5);
    serverSide->setReadCB(nullptr);
  };

  // Stage 1: check the first two bytes, push them back, hand off to stage 2.
  peekCallback.dataAvailableCallback = [&]() {
    peekCallback.verifyData("he", 2);
    serverSide->setPreReceivedData(IOBuf::copyBuffer("h"));
    serverSide->setPreReceivedData(IOBuf::copyBuffer("e"));
    serverSide->setReadCB(nullptr);
    serverSide->setReadCB(&readCallback);
  };

  serverSide->setReadCB(&peekCallback);

  evb.loop();
}
8063
// The whole message is consumed, then pushed back with setPreReceivedData().
// The socket must immediately report readable, and the next read callback
// must get "hello" from the pre-received buffer alone.
TEST(AsyncSocket, PreReceivedDataOnly) {
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> client = AsyncSocket::newSocket(&evb);
  client->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  client->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto serverSide = server.acceptAsync(&evb);

  ReadCallback peekCallback;
  ReadCallback readCallback;

  // Second reader: sees the pushed-back bytes, then stops reading.
  readCallback.dataAvailableCallback = [&]() {
    readCallback.verifyData("hello", 5);
    serverSide->setReadCB(nullptr);
  };

  // First reader: consumes everything and pushes it back.
  peekCallback.dataAvailableCallback = [&]() {
    peekCallback.verifyData("hello", 5);
    serverSide->setPreReceivedData(IOBuf::copyBuffer("hello"));
    EXPECT_TRUE(serverSide->readable());
    serverSide->setReadCB(&readCallback);
  };

  serverSide->setReadCB(&peekCallback);

  evb.loop();
}
8093
// Pre-received data can be drained across several read callbacks. A reader
// constructed with 3 takes "hel"; the next reader gets the remaining "lo".
TEST(AsyncSocket, PreReceivedDataPartial) {
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> client = AsyncSocket::newSocket(&evb);
  client->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  client->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto serverSide = server.acceptAsync(&evb);

  ReadCallback peekCallback;
  ReadCallback smallReadCallback(3);
  ReadCallback normalReadCallback;

  // Final reader: receives the tail and stops reading.
  normalReadCallback.dataAvailableCallback = [&]() {
    normalReadCallback.verifyData("lo", 2);
    serverSide->setReadCB(nullptr);
  };

  // Middle reader: receives the first three pre-received bytes.
  smallReadCallback.dataAvailableCallback = [&]() {
    smallReadCallback.verifyData("hel", 3);
    serverSide->setReadCB(&normalReadCallback);
  };

  // First reader: consumes the full message and pushes it back.
  peekCallback.dataAvailableCallback = [&]() {
    peekCallback.verifyData("hello", 5);
    serverSide->setPreReceivedData(IOBuf::copyBuffer("hello"));
    serverSide->setReadCB(&smallReadCallback);
  };

  serverSide->setReadCB(&peekCallback);

  evb.loop();
}
8127
// Pre-received data must survive moving the socket into a new AsyncSocket
// ("takeover"). The new socket's reader must see the full "hello", and after
// the peer resets the connection the new socket must still report the
// original peer address.
TEST(AsyncSocket, PreReceivedDataTakeover) {
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto fd = server.acceptFD();
  SocketAddress peerAddress;
  peerAddress.setFromPeerAddress(fd);
  auto acceptedSocket =
      AsyncSocket::UniquePtr(new AsyncSocket(&evb, fd, 0, &peerAddress));
  AsyncSocket::UniquePtr takeoverSocket;

  ReadCallback peekCallback(3);
  ReadCallback readCallback;
  peekCallback.dataAvailableCallback = [&]() {
    peekCallback.verifyData("hel", 3);
    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
    acceptedSocket->setReadCB(nullptr);
    takeoverSocket =
        AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket)));
    takeoverSocket->setReadCB(&readCallback);
  };
  readCallback.dataAvailableCallback = [&]() {
    readCallback.verifyData("hello", 5);
    takeoverSocket->setReadCB(nullptr);
  };

  acceptedSocket->setReadCB(&peekCallback);

  evb.loop();
  // The takeover only happens inside peekCallback. Fail cleanly instead of
  // dereferencing a null socket below if it never ran.
  ASSERT_NE(takeoverSocket.get(), nullptr);

  // Verify we can still get the peer address after the peer socket is reset.
  socket->closeWithReset();
  evb.loopOnce();
  SocketAddress socketPeerAddress;
  takeoverSocket->getPeerAddress(&socketPeerAddress);
  EXPECT_EQ(socketPeerAddress, peerAddress);
}
8170
8171 #ifdef MSG_NOSIGNAL
// A SendMsgParamsCallback controls the flags for each write. A write without
// MSG_MORE is sent immediately. A following write with MSG_MORE is held back,
// so the peer must see exactly the first buffer.
TEST(AsyncSocketTest, SendMessageFlags) {
  TestServer server;
  TestSendMsgParamsCallback sendMsgCB(
      MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE, 0, nullptr);

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  // Set SendMsgParamsCallback
  socket->setSendMsgParamCB(&sendMsgCB);
  ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB);

  // Write the first portion of data. This data is expected to be
  // sent out immediately.
  std::vector<uint8_t> buf(128, 'a');
  WriteCallback wcb;
  sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL);
  socket->write(&wcb, buf.data(), buf.size());
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  ASSERT_TRUE(sendMsgCB.queriedFlags_);
  ASSERT_FALSE(sendMsgCB.queriedData_);

  // Using different flags for the second write operation.
  // MSG_MORE flag is expected to delay sending this
  // data to the wire.
  sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE);
  socket->write(&wcb, buf.data(), buf.size());
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  ASSERT_TRUE(sendMsgCB.queriedFlags_);
  ASSERT_FALSE(sendMsgCB.queriedData_);

  // Make sure the accepted socket saw only the data from the first write
  // request. Check the byte count first so a short or long read is reported
  // as such rather than as a content mismatch.
  std::vector<uint8_t> readbuf(2 * buf.size());
  uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size());
  ASSERT_EQ(bytesRead, buf.size());
  ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));

  // Close both ends; the client must record that it closed itself.
  acceptedSocket->close();
  socket->close();

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}
8225
// A SendMsgParamsCallback can attach ancillary data to a write. This test
// passes a temporary file's descriptor over a unix socketpair with
// SCM_RIGHTS, then reads the file back through the received descriptor to
// prove it refers to the same file.
TEST(AsyncSocketTest, SendMessageAncillaryData) {
  NetworkSocket fds[2];
  EXPECT_EQ(netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0);

  // "Client" socket
  auto cfd = fds[0];
  ASSERT_NE(cfd, NetworkSocket());

  // "Server" socket
  auto sfd = fds[1];
  ASSERT_NE(sfd, NetworkSocket());
  SCOPE_EXIT { netops::close(sfd); };

  // Instantiate AsyncSocket object for the connected socket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, cfd);

  // Open a temporary file and write a magic string to it.
  // We'll transfer the file handle to test the message parameters
  // callback logic.
  TemporaryFile file(
      StringPiece(), fs::path(), TemporaryFile::Scope::UNLINK_IMMEDIATELY);
  int tmpfd = file.fd();
  ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file";
  std::string magicString("Magic string");
  ASSERT_EQ(
      write(tmpfd, magicString.c_str(), magicString.length()),
      static_cast<ssize_t>(magicString.length()));

  // Build an SCM_RIGHTS control message carrying tmpfd
  union {
    // Space large enough to hold an 'int'
    char control[CMSG_SPACE(sizeof(int))];
    struct cmsghdr cmh;
  } s_u;
  s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int));
  s_u.cmh.cmsg_level = SOL_SOCKET;
  s_u.cmh.cmsg_type = SCM_RIGHTS;
  memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int));

  // Set up the callback providing message parameters
  TestSendMsgParamsCallback sendMsgCB(
      MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control);
  socket->setSendMsgParamCB(&sendMsgCB);

  // We must transmit at least 1 byte of real data in order
  // to send ancillary data
  int s_data = 12345;
  WriteCallback wcb;
  socket->write(&wcb, &s_data, sizeof(s_data));
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  // Receive the message. The control buffer is zero-initialized so that its
  // contents are never indeterminate, even if no control message arrives.
  union {
    // Space large enough to hold an 'int'
    char control[CMSG_SPACE(sizeof(int))];
    struct cmsghdr cmh;
  } r_u{};
  struct msghdr msgh{};
  struct iovec iov{};
  int r_data = 0;

  msgh.msg_control = r_u.control;
  msgh.msg_controllen = sizeof(r_u.control);
  msgh.msg_name = nullptr;
  msgh.msg_namelen = 0;
  msgh.msg_iov = &iov;
  msgh.msg_iovlen = 1;
  iov.iov_base = &r_data;
  iov.iov_len = sizeof(r_data);

  // Receive data
  ASSERT_NE(netops::recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno;

  // Validate the received message. Locate the control message through
  // CMSG_FIRSTHDR so a missing one is detected rather than read as garbage,
  // and fail if the kernel truncated it (the descriptor would be lost).
  ASSERT_EQ(msgh.msg_flags & MSG_CTRUNC, 0);
  struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msgh);
  ASSERT_NE(cmsg, nullptr) << "no ancillary data received";
  ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int)));
  ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
  ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS);
  ASSERT_EQ(r_data, s_data);
  int fd = -1;
  memcpy(&fd, CMSG_DATA(cmsg), sizeof(int));
  ASSERT_NE(fd, -1);
  SCOPE_EXIT { close(fd); };

  std::vector<uint8_t> transferredMagicString(magicString.length() + 1, 0);

  // Reposition to the beginning of the file
  ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));

  // Read the magic string back, and compare it with the original
  ASSERT_EQ(
      static_cast<ssize_t>(magicString.length()),
      read(fd, transferredMagicString.data(), transferredMagicString.size()));
  ASSERT_TRUE(std::equal(
      magicString.begin(), magicString.end(), transferredMagicString.begin()));
}
8322
TEST(AsyncSocketTest, UnixDomainSocketErrMessageCB) {
  // In the latest stable kernel 4.14.3 as of 2017-12-04, Unix Domain
  // Socket (UDS) does not support MSG_ERRQUEUE. So
  // recvmsg(MSG_ERRQUEUE) will read application data from UDS which
  // breaks application message flow. To avoid this problem,
  // AsyncSocket currently disables setErrMessageCB for UDS.
  //
  // This tests two things for UDS
  // 1. setErrMessageCB fails
  // 2. recvmsg(MSG_ERRQUEUE) reads application data
  //
  // Feel free to remove this test if UDS supports MSG_ERRQUEUE in the future.

  NetworkSocket fd[2];
  EXPECT_EQ(netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fd), 0);
  ASSERT_NE(fd[0], NetworkSocket());
  ASSERT_NE(fd[1], NetworkSocket());
  SCOPE_EXIT { netops::close(fd[1]); };

  EXPECT_EQ(netops::set_socket_non_blocking(fd[0]), 0);
  EXPECT_EQ(netops::set_socket_non_blocking(fd[1]), 0);

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, fd[0]);

  // setErrMessageCB should fail for unix domain socket: the callback is not
  // installed.
  TestErrMessageCallback errMsgCB;
  socket->setErrMessageCB(&errMsgCB);
  ASSERT_EQ(socket->getErrMessageCallback(), nullptr);

#ifdef FOLLY_HAVE_MSG_ERRQUEUE
  // The following verifies that MSG_ERRQUEUE does not work for UDS,
  // and recvmsg reads application data
  union {
    // Space large enough to hold an 'int'
    char control[CMSG_SPACE(sizeof(int))];
    struct cmsghdr cmh;
  } r_u;
  struct msghdr msgh;
  struct iovec iov;
  int recv_data = 0;

  msgh.msg_control = r_u.control;
  msgh.msg_controllen = sizeof(r_u.control);
  msgh.msg_name = nullptr;
  msgh.msg_namelen = 0;
  msgh.msg_iov = &iov;
  msgh.msg_iovlen = 1;
  iov.iov_base = &recv_data;
  iov.iov_len = sizeof(recv_data);

  // There is no data, so recvmsg should fail. Capture errno right away so
  // that no intervening call can overwrite it before it is checked.
  auto rc = netops::recvmsg(fd[1], &msgh, MSG_ERRQUEUE);
  const int savedErrno = errno;
  EXPECT_EQ(rc, -1);
  EXPECT_TRUE(savedErrno == EAGAIN || savedErrno == EWOULDBLOCK);

  // provide some application data, error queue should be empty if it exists
  // However, UDS reads application data as error message
  int test_data = 123456;
  WriteCallback wcb;
  socket->write(&wcb, &test_data, sizeof(test_data));
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  recv_data = 0;
  ASSERT_NE(netops::recvmsg(fd[1], &msgh, MSG_ERRQUEUE), -1);
  ASSERT_EQ(recv_data, test_data);
#endif // FOLLY_HAVE_MSG_ERRQUEUE
}
8389
// With TOS reflection enabled on an IPv6 server, the accepted socket must
// carry the client's IPV6_TCLASS value. A second client that connects
// without a bind address must end up with SO_REUSEADDR set through
// TestConnectCallback's preConnect hook.
TEST(AsyncSocketTest, V6TosReflectTest) {
  EventBase eventBase;

  // Create a server socket
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  folly::IPAddress ip("::1");
  std::vector<folly::IPAddress> serverIp;
  serverIp.push_back(ip);
  serverSocket->bind(serverIp, 0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Enable TOS reflect
  serverSocket->setTosReflect(true);

  // Accept one connection, record its fd, then remove the callback so the
  // loop can stop. The fd is recorded here rather than read from
  // getEvents()->at(1). Events accumulate across accept rounds, so a fixed
  // index would keep returning the first connection's fd.
  NetworkSocket lastAcceptedFd;
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket fd, const folly::SocketAddress& /* addr */) {
        lastAcceptedFd = fd;
        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Create a client socket, setsockopt() the TOS before connecting.
  // (Runs on the calling thread; it only starts an async connect.)
  auto connectClient = [](std::shared_ptr<AsyncSocket>& clientSock,
                          ConnCallback* ccb,
                          EventBase* evb,
                          folly::SocketAddress sAddr) {
    clientSock = AsyncSocket::newSocket(evb);
    SocketOptionKey v6Opts = {IPPROTO_IPV6, IPV6_TCLASS};
    SocketOptionMap optionMap;
    optionMap.insert({v6Opts, 0x2c});
    SocketAddress bindAddr("0.0.0.0", 0);
    clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
  };

  std::shared_ptr<AsyncSocket> socket(nullptr);
  ConnCallback cb;
  connectClient(socket, &cb, &eventBase, serverAddress);

  eventBase.loop();
  ASSERT_EQ(cb.state, STATE_SUCCEEDED);

  // Verify if the connection is accepted and if the accepted socket has
  // setsockopt on the TOS for the same value that was on the client socket
  auto fd = lastAcceptedFd;
  ASSERT_NE(fd, NetworkSocket());
  int value = 0;
  socklen_t valueLength = sizeof(value);
  int rc =
      netops::getsockopt(fd, IPPROTO_IPV6, IPV6_TCLASS, &value, &valueLength);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(value, 0x2c);

  // Additional Test for ConnectCallback without bindAddr
  lastAcceptedFd = NetworkSocket();
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  auto newClientSock = AsyncSocket::newSocket(&eventBase);
  TestConnectCallback callback;
  // connect call will not set this SO_REUSEADDR if we do not
  // pass the bindAddress in its call; so we can safely verify this.
  newClientSock->connect(&callback, serverAddress, 30);

  // Collect events
  eventBase.loop();

  // The second connection must have been accepted (a fresh fd, not the
  // first one's).
  auto acceptedFd = lastAcceptedFd;
  ASSERT_NE(acceptedFd, NetworkSocket());
  int reuseAddrVal = 0;
  socklen_t reuseAddrValLen = sizeof(reuseAddrVal);
  // Get the socket created underneath connect call of AsyncSocket
  auto usedSockFd = newClientSock->getNetworkSocket();
  int getOptRet = netops::getsockopt(
      usedSockFd, SOL_SOCKET, SO_REUSEADDR, &reuseAddrVal, &reuseAddrValLen);
  ASSERT_EQ(getOptRet, 0);
  ASSERT_EQ(reuseAddrVal, 1 /* configured through preConnect*/);
}
8473
/**
 * Verifies TOS reflection on an IPv4 listener: with setTosReflect(true), the
 * accepted socket is expected to carry the same IP_TOS value (0x2c) that the
 * client applied to its own socket before connecting.
 */
TEST(AsyncSocketTest, V4TosReflectTest) {
  EventBase eventBase;

  // Create a server socket on the IPv4 loopback, ephemeral port
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  folly::IPAddress ip("127.0.0.1");
  std::vector<folly::IPAddress> serverIp;
  serverIp.push_back(ip);
  serverSocket->bind(serverIp, 0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Enable TOS reflect
  serverSocket->setTosReflect(true);

  // Accept one connection (or one accept error), then remove the callback so
  // that eventBase.loop() below can return.
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Create a client socket and start connecting; IP_TOS is passed through the
  // connect() option map so it is set before the connection is initiated.
  // This runs on the current thread; the connection completes inside
  // eventBase.loop().
  auto connectClient = [](std::shared_ptr<AsyncSocket>& clientSock,
                          ConnCallback* ccb,
                          EventBase* evb,
                          folly::SocketAddress sAddr) {
    clientSock = AsyncSocket::newSocket(evb);
    SocketOptionKey v4Opts = {IPPROTO_IP, IP_TOS};
    SocketOptionMap optionMap;
    optionMap.insert({v4Opts, 0x2c});
    SocketAddress bindAddr("0.0.0.0", 0);
    clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
  };

  std::shared_ptr<AsyncSocket> socket(nullptr);
  ConnCallback cb;
  connectClient(socket, &cb, &eventBase, serverAddress);

  eventBase.loop();

  // The client must have connected; otherwise the checks below are moot.
  ASSERT_EQ(cb.state, STATE_SUCCEEDED);

  // Verify the connection was accepted before reading its fd. The accept
  // event is expected at index 1; checking size and type first turns a
  // missing accept into a clear assertion failure instead of an exception
  // from at().
  auto events = acceptCallback.getEvents();
  ASSERT_GE(events->size(), 2u);
  ASSERT_EQ(events->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  auto fd = events->at(1).fd;
  ASSERT_NE(fd, NetworkSocket());

  // The accepted socket should carry the TOS value set on the client socket
  int value = 0;
  socklen_t valueLength = sizeof(value);
  int rc = netops::getsockopt(fd, IPPROTO_IP, IP_TOS, &value, &valueLength);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(value, 0x2c);
}
8532
/**
 * Verifies that the listener TOS configured with setListenerTos() on an IPv6
 * server socket is applied to accepted connections. The client sets a
 * different traffic class (0x2c); the accepted socket is expected to report
 * the listener's value (0x74).
 */
TEST(AsyncSocketTest, V6AcceptedTosTest) {
  EventBase eventBase;

  // Create a server socket on the IPv6 loopback, ephemeral port
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  folly::IPAddress ip("::1");
  std::vector<folly::IPAddress> serverIp;
  serverIp.push_back(ip);
  serverSocket->bind(serverIp, 0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Set listener TOS to 0x74 i.e. dscp 29
  serverSocket->setListenerTos(0x74);

  // Accept one connection (or one accept error), then remove the callback so
  // that eventBase.loop() below can return.
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Create a client socket and start connecting; IPV6_TCLASS is passed
  // through the connect() option map. This runs on the current thread; the
  // connection completes inside eventBase.loop().
  auto connectClient = [](std::shared_ptr<AsyncSocket>& clientSock,
                          ConnCallback* ccb,
                          EventBase* evb,
                          folly::SocketAddress sAddr) {
    clientSock = AsyncSocket::newSocket(evb);
    SocketOptionKey v6Opts = {IPPROTO_IPV6, IPV6_TCLASS};
    SocketOptionMap optionMap;
    optionMap.insert({v6Opts, 0x2c});
    // NOTE(review): an IPv4 wildcard bind address for an IPv6 connect; this
    // presumably works because AsyncSocket treats 0.0.0.0:0 as "no explicit
    // bind" — confirm.
    SocketAddress bindAddr("0.0.0.0", 0);
    clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
  };

  std::shared_ptr<AsyncSocket> socket(nullptr);
  ConnCallback cb;
  connectClient(socket, &cb, &eventBase, serverAddress);

  eventBase.loop();

  // The client must have connected; otherwise the checks below are moot.
  ASSERT_EQ(cb.state, STATE_SUCCEEDED);

  // Verify the connection was accepted before reading its fd. The accept
  // event is expected at index 1; checking size and type first turns a
  // missing accept into a clear assertion failure instead of an exception
  // from at().
  auto events = acceptCallback.getEvents();
  ASSERT_GE(events->size(), 2u);
  ASSERT_EQ(events->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  auto fd = events->at(1).fd;
  ASSERT_NE(fd, NetworkSocket());

  // The accepted socket should carry the listener's TOS, not the client's
  int value = 0;
  socklen_t valueLength = sizeof(value);
  int rc =
      netops::getsockopt(fd, IPPROTO_IPV6, IPV6_TCLASS, &value, &valueLength);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(value, 0x74);
}
8595
/**
 * Verifies that the listener TOS configured with setListenerTos() on an IPv4
 * server socket is applied to accepted connections. The client sets a
 * different IP_TOS (0x2c); the accepted socket is expected to report the
 * listener's value (0x74).
 */
TEST(AsyncSocketTest, V4AcceptedTosTest) {
  EventBase eventBase;

  // Create a server socket on the IPv4 loopback, ephemeral port
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  folly::IPAddress ip("127.0.0.1");
  std::vector<folly::IPAddress> serverIp;
  serverIp.push_back(ip);
  serverSocket->bind(serverIp, 0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Set listener TOS to 0x74 i.e. dscp 29
  serverSocket->setListenerTos(0x74);

  // Accept one connection (or one accept error), then remove the callback so
  // that eventBase.loop() below can return.
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Create a client socket and start connecting; IP_TOS is passed through the
  // connect() option map. This runs on the current thread; the connection
  // completes inside eventBase.loop().
  auto connectClient = [](std::shared_ptr<AsyncSocket>& clientSock,
                          ConnCallback* ccb,
                          EventBase* evb,
                          folly::SocketAddress sAddr) {
    clientSock = AsyncSocket::newSocket(evb);
    SocketOptionKey v4Opts = {IPPROTO_IP, IP_TOS};
    SocketOptionMap optionMap;
    optionMap.insert({v4Opts, 0x2c});
    SocketAddress bindAddr("0.0.0.0", 0);
    clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
  };

  std::shared_ptr<AsyncSocket> socket(nullptr);
  ConnCallback cb;
  connectClient(socket, &cb, &eventBase, serverAddress);

  eventBase.loop();

  // The client must have connected; otherwise the checks below are moot.
  ASSERT_EQ(cb.state, STATE_SUCCEEDED);

  // Verify the connection was accepted before reading its fd. The accept
  // event is expected at index 1; checking size and type first turns a
  // missing accept into a clear assertion failure instead of an exception
  // from at().
  auto events = acceptCallback.getEvents();
  ASSERT_GE(events->size(), 2u);
  ASSERT_EQ(events->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  auto fd = events->at(1).fd;
  ASSERT_NE(fd, NetworkSocket());

  // The accepted socket should carry the listener's TOS, not the client's
  int value = 0;
  socklen_t valueLength = sizeof(value);
  int rc = netops::getsockopt(fd, IPPROTO_IP, IP_TOS, &value, &valueLength);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(value, 0x74);
}
8657 #endif
8658
8659 #if defined(__linux__)
/**
 * Exercises getRecvBufInUse() / getSendBufInUse(). The client writes a
 * payload that the peer never reads. The test then expects the bytes queued
 * in the receiver's receive buffer plus the sender's send buffer to add up to
 * the payload size, with both sides holding part of it.
 */
TEST(AsyncSocketTest, getBufInUse) {
  EventBase eventBase;
  std::shared_ptr<AsyncServerSocket> server(
      AsyncServerSocket::newSocket(&eventBase));
  server->bind(0);
  server->listen(5);

  std::shared_ptr<AsyncSocket> client = AsyncSocket::newSocket(&eventBase);
  client->connect(nullptr, server->getAddress());

  // Accept by hand on the raw listening fd. Make up to five attempts, running
  // the event loop before each one so the client's connect can progress.
  const NetworkSocket listenFd = server->getNetworkSocket();
  NetworkSocket accepted;
  for (int attempt = 0; attempt < 5 && accepted == NetworkSocket();
       ++attempt) {
    std::this_thread::yield();
    eventBase.loop();
    accepted = netops::accept(listenFd, nullptr, nullptr);
  }

  // Every attempt failed to accept the client connection
  ASSERT_NE(accepted, NetworkSocket());

  auto serverSide = AsyncSocket::newSocket(nullptr, accepted);

  // Smallest receive buffer on the reading side, largest send buffer on the
  // writing side.
  serverSide->setRecvBufSize(0);
  client->setSendBufSize(static_cast<unsigned>(-1));

  // 100,000-byte payload made of repeated decimal digits
  std::string testData;
  for (int chunksLeft = 10000; chunksLeft > 0; --chunksLeft) {
    testData.append("0123456789");
  }

  client->write(
      nullptr, static_cast<const void*>(testData.c_str()), testData.size());

  std::this_thread::yield();
  eventBase.loop();

  const size_t queuedForRead = serverSide->getRecvBufInUse();
  const size_t queuedForSend = client->getSendBufInUse();

  EXPECT_EQ((queuedForRead + queuedForSend), testData.size());
  EXPECT_GT(queuedForRead, 0);
  EXPECT_GT(queuedForSend, 0);
}
8708 #endif
8709
/**
 * Verifies AsyncServerSocket's queue timeout. Two connections are enqueued
 * for an accept callback running on a separate thread. The callback stalls
 * on the first one long enough for the second to exceed the queue timeout.
 * The test expects the second connection to be enqueued but never dequeued.
 */
TEST(AsyncSocketTest, QueueTimeout) {
  // Create a new AsyncServerSocket
  EventBase eventBase;
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  constexpr auto kConnectionTimeout = milliseconds(10);
  // Allow plenty of time for the AsyncSocketServer's event loop to run.
  // This should leave no doubt that the acceptor thread has enough time
  // to dequeue. If the dequeue succeeds, then our expiry code is broken.
  // Declared at test scope rather than inside the accept callback. A task
  // posted to eventBase reads it and may run after the callback has
  // returned, so a reference capture of a callback-local would dangle.
  constexpr auto kEventLoopTime = kConnectionTimeout * 5;
  serverSocket->setQueueTimeout(kConnectionTimeout);

  TestAcceptCallback acceptCb;
  acceptCb.setConnectionAcceptedFn(
      [&, called = false](auto&&...) mutable {
        ASSERT_FALSE(called)
            << "Only the first connection should have been dequeued";
        called = true;
        // After kEventLoopTime, remove the accept callback from the main
        // event loop so eventBase.loop() can return. The delay is captured
        // by value.
        const auto delayMs = kEventLoopTime.count();
        eventBase.runInEventBaseThread([&, delayMs]() {
          eventBase.tryRunAfterDelay(
              [&]() { serverSocket->removeAcceptCallback(&acceptCb, nullptr); },
              delayMs);
        });
        // After the first message is enqueued, sleep long enough so that the
        // second message expires before it has a chance to dequeue.
        std::this_thread::sleep_for(kConnectionTimeout);
      });
  ScopedEventBaseThread acceptThread("ioworker_test");

  TestConnectionEventCallback connectionEventCb;
  serverSocket->setConnectionEventCallback(&connectionEventCb);
  serverSocket->addAcceptCallback(&acceptCb, acceptThread.getEventBase());
  serverSocket->startAccepting();

  std::shared_ptr<AsyncSocket> clientSocket1(
      AsyncSocket::newSocket(&eventBase, serverAddress));
  std::shared_ptr<AsyncSocket> clientSocket2(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  // Loop until we are stopped
  eventBase.loop();

  EXPECT_EQ(connectionEventCb.getConnectionEnqueuedForAcceptCallback(), 2);
  // Since the second message is expired, it should NOT be dequeued
  EXPECT_EQ(connectionEventCb.getConnectionDequeuedByAcceptCallback(), 1);
}
8761
8762 class TestRXTimestampsCallback
8763 : public folly::AsyncSocket::ReadAncillaryDataCallback {
8764 public:
TestRXTimestampsCallback()8765 TestRXTimestampsCallback() {}
ancillaryData(struct msghdr & msgh)8766 void ancillaryData(struct msghdr& msgh) noexcept override {
8767 struct cmsghdr* cmsg;
8768 for (cmsg = CMSG_FIRSTHDR(&msgh); cmsg != nullptr;
8769 cmsg = CMSG_NXTHDR(&msgh, cmsg)) {
8770 if (cmsg->cmsg_level != SOL_SOCKET ||
8771 cmsg->cmsg_type != SO_TIMESTAMPING) {
8772 continue;
8773 }
8774 callCount_++;
8775 timespec* ts = (struct timespec*)CMSG_DATA(cmsg);
8776 actualRxTimestampSec_ = ts[0].tv_sec;
8777 }
8778 }
getAncillaryDataCtrlBuffer()8779 folly::MutableByteRange getAncillaryDataCtrlBuffer() override {
8780 return folly::MutableByteRange(ancillaryDataCtrlBuffer_);
8781 }
8782
8783 uint32_t callCount_{0};
8784 long actualRxTimestampSec_{0};
8785
8786 private:
8787 std::array<uint8_t, 1024> ancillaryDataCtrlBuffer_;
8788 };
8789
8790 /**
8791 * Test read ancillary data callback
8792 */
TEST(AsyncSocketTest,readAncillaryData)8793 TEST(AsyncSocketTest, readAncillaryData) {
8794 TestServer server;
8795
8796 // connect()
8797 EventBase evb;
8798 std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
8799
8800 ConnCallback ccb;
8801 socket->connect(&ccb, server.getAddress(), 1);
8802 LOG(INFO) << "Client socket fd=" << socket->getNetworkSocket();
8803
8804 // Enable rx timestamp notifications
8805 ASSERT_NE(socket->getNetworkSocket(), NetworkSocket());
8806 int flags = folly::netops::SOF_TIMESTAMPING_SOFTWARE |
8807 folly::netops::SOF_TIMESTAMPING_RX_SOFTWARE |
8808 folly::netops::SOF_TIMESTAMPING_RX_HARDWARE;
8809 SocketOptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
8810 EXPECT_EQ(tstampingOpt.apply(socket->getNetworkSocket(), flags), 0);
8811
8812 // Accept the connection.
8813 std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
8814 LOG(INFO) << "Server socket fd=" << acceptedSocket->getNetworkSocket();
8815
8816 // Wait for connection
8817 evb.loop();
8818 ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
8819
8820 TestRXTimestampsCallback rxcb;
8821
8822 // Set read callback
8823 ReadCallback rcb(100);
8824 socket->setReadCB(&rcb);
8825
8826 // Get the timestamp when the message was write
8827 struct timespec currentTime;
8828 clock_gettime(CLOCK_REALTIME, ¤tTime);
8829 long writeTimestampSec = currentTime.tv_sec;
8830
8831 // write bytes from server (acceptedSocket) to client (socket).
8832 std::vector<uint8_t> wbuf(128, 'a');
8833 acceptedSocket->write(wbuf.data(), wbuf.size());
8834
8835 // Wait for reading to complete.
8836 evb.loopOnce();
8837 ASSERT_NE(rcb.buffers.size(), 0);
8838
8839 // Verify that if the callback is not set, it will not be called
8840 ASSERT_EQ(rxcb.callCount_, 0);
8841
8842 // Set up rx timestamp callbacks
8843 socket->setReadAncillaryDataCB(&rxcb);
8844 acceptedSocket->write(wbuf.data(), wbuf.size());
8845
8846 // Wait for reading to complete.
8847 evb.loopOnce();
8848 ASSERT_NE(rcb.buffers.size(), 0);
8849
8850 // Verify that after setting callback, the callback was called
8851 ASSERT_NE(rxcb.callCount_, 0);
8852 // Compare the received timestamp is within an expected range
8853 clock_gettime(CLOCK_REALTIME, ¤tTime);
8854 ASSERT_TRUE(rxcb.actualRxTimestampSec_ <= currentTime.tv_sec);
8855 ASSERT_TRUE(rxcb.actualRxTimestampSec_ >= writeTimestampSec);
8856
8857 // Close both sockets
8858 acceptedSocket->close();
8859 socket->close();
8860
8861 ASSERT_TRUE(socket->isClosedBySelf());
8862 ASSERT_FALSE(socket->isClosedByPeer());
8863 }
8864