1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <folly/io/async/test/AsyncSocketTest2.h>

#include <fcntl.h>
#include <sys/types.h>

#include <time.h>
#include <cstring>
#include <iostream>
#include <memory>
#include <thread>

#include <folly/ExceptionWrapper.h>
#include <folly/Random.h>
#include <folly/SocketAddress.h>
#include <folly/experimental/TestUtil.h>
#include <folly/io/IOBuf.h>
#include <folly/io/SocketOptionMap.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/io/async/test/AsyncSocketTest.h>
#include <folly/io/async/test/MockAsyncSocketObserver.h>
#include <folly/io/async/test/MockAsyncTransportObserver.h>
#include <folly/io/async/test/TFOTest.h>
#include <folly/io/async/test/Util.h>
#include <folly/net/test/MockNetOpsDispatcher.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>
#include <folly/portability/Sockets.h>
#include <folly/portability/Unistd.h>
#include <folly/synchronization/Baton.h>
#include <folly/test/SocketAddressTestHelper.h>
48 
49 using std::min;
50 using std::string;
51 using std::unique_ptr;
52 using std::vector;
53 using std::chrono::milliseconds;
54 using testing::MatchesRegex;
55 
56 using namespace folly;
57 using namespace folly::test;
58 using namespace testing;
59 
60 namespace {
61 // string and corresponding vector with 100 characters
62 const std::string kOneHundredCharacterString(
63     "ThisIsAVeryLongStringThatHas100Characters"
64     "AndIsUniqueEnoughToBeInterestingForTestUsageNowEndOfMessage");
65 const std::vector<uint8_t> kOneHundredCharacterVec(
66     kOneHundredCharacterString.begin(), kOneHundredCharacterString.end());
67 
msgFlagsToWriteFlags(const int msg_flags)68 WriteFlags msgFlagsToWriteFlags(const int msg_flags) {
69   WriteFlags flags = WriteFlags::NONE;
70 #ifdef MSG_MORE
71   if (msg_flags & MSG_MORE) {
72     flags = flags | WriteFlags::CORK;
73   }
74 #endif // MSG_MORE
75 
76 #ifdef MSG_EOR
77   if (msg_flags & MSG_EOR) {
78     flags = flags | WriteFlags::EOR;
79   }
80 #endif
81 
82 #ifdef MSG_ZEROCOPY
83   if (msg_flags & MSG_ZEROCOPY) {
84     flags = flags | WriteFlags::WRITE_MSG_ZEROCOPY;
85   }
86 #endif
87   return flags;
88 }
89 
getMsgAncillaryTsFlags(const struct msghdr & msg)90 WriteFlags getMsgAncillaryTsFlags(const struct msghdr& msg) {
91   const struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
92   if (!cmsg || cmsg->cmsg_level != SOL_SOCKET ||
93       cmsg->cmsg_type != SO_TIMESTAMPING ||
94       cmsg->cmsg_len != CMSG_LEN(sizeof(uint32_t))) {
95     return WriteFlags::NONE;
96   }
97 
98   const uint32_t* sofFlags =
99       (reinterpret_cast<const uint32_t*>(CMSG_DATA(cmsg)));
100   WriteFlags flags = WriteFlags::NONE;
101   if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SCHED) {
102     flags = flags | WriteFlags::TIMESTAMP_SCHED;
103   }
104   if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE) {
105     flags = flags | WriteFlags::TIMESTAMP_TX;
106   }
107   if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_ACK) {
108     flags = flags | WriteFlags::TIMESTAMP_ACK;
109   }
110 
111   return flags;
112 }
113 
getMsgAncillaryTsFlags(const struct msghdr * msg)114 WriteFlags getMsgAncillaryTsFlags(const struct msghdr* msg) {
115   return getMsgAncillaryTsFlags(*msg);
116 }
117 
118 MATCHER_P(SendmsgMsghdrHasTotalIovLen, len, "") {
119   size_t iovLen = 0;
120   for (size_t i = 0; i < arg.msg_iovlen; i++) {
121     iovLen += arg.msg_iov[i].iov_len;
122   }
123   return len == iovLen;
124 }
125 
126 MATCHER_P(SendmsgInvocHasTotalIovLen, len, "") {
127   size_t iovLen = 0;
128   for (const auto& iov : arg.iovs) {
129     iovLen += iov.iov_len;
130   }
131   return len == iovLen;
132 }
133 
134 MATCHER_P(SendmsgInvocHasIovFirstByte, firstBytePtr, "") {
135   if (arg.iovs.empty()) {
136     return false;
137   }
138 
139   const auto& firstIov = arg.iovs.front();
140   auto iovFirstBytePtr = const_cast<void*>(
141       static_cast<const void*>(reinterpret_cast<uint8_t*>(firstIov.iov_base)));
142   return firstBytePtr == iovFirstBytePtr;
143 }
144 
145 MATCHER_P(SendmsgInvocHasIovLastByte, lastBytePtr, "") {
146   if (arg.iovs.empty()) {
147     return false;
148   }
149 
150   const auto& lastIov = arg.iovs.back();
151   auto iovLastBytePtr = const_cast<void*>(static_cast<const void*>(
152       reinterpret_cast<uint8_t*>(lastIov.iov_base) + lastIov.iov_len - 1));
153   return lastBytePtr == iovLastBytePtr;
154 }
155 
156 MATCHER_P(SendmsgInvocMsgFlagsEq, writeFlags, "") {
157   return writeFlags == arg.writeFlagsInMsgFlags;
158 }
159 
160 MATCHER_P(SendmsgInvocAncillaryFlagsEq, writeFlags, "") {
161   return writeFlags == arg.writeFlagsInAncillary;
162 }
163 
164 MATCHER_P2(ByteEventMatching, type, offset, "") {
165   if (type != arg.type || (size_t)offset != arg.offset) {
166     return false;
167   }
168   return true;
169 }
170 } // namespace
171 
172 class DelayedWrite : public AsyncTimeout {
173  public:
DelayedWrite(const std::shared_ptr<AsyncSocket> & socket,unique_ptr<IOBuf> && bufs,AsyncTransportWrapper::WriteCallback * wcb,bool cork,bool lastWrite=false)174   DelayedWrite(
175       const std::shared_ptr<AsyncSocket>& socket,
176       unique_ptr<IOBuf>&& bufs,
177       AsyncTransportWrapper::WriteCallback* wcb,
178       bool cork,
179       bool lastWrite = false)
180       : AsyncTimeout(socket->getEventBase()),
181         socket_(socket),
182         bufs_(std::move(bufs)),
183         wcb_(wcb),
184         cork_(cork),
185         lastWrite_(lastWrite) {}
186 
187  private:
timeoutExpired()188   void timeoutExpired() noexcept override {
189     WriteFlags flags = cork_ ? WriteFlags::CORK : WriteFlags::NONE;
190     socket_->writeChain(wcb_, std::move(bufs_), flags);
191     if (lastWrite_) {
192       socket_->shutdownWrite();
193     }
194   }
195 
196   std::shared_ptr<AsyncSocket> socket_;
197   unique_ptr<IOBuf> bufs_;
198   AsyncTransportWrapper::WriteCallback* wcb_;
199   bool cork_;
200   bool lastWrite_;
201 };
202 
203 ///////////////////////////////////////////////////////////////////////////
204 // connect() tests
205 ///////////////////////////////////////////////////////////////////////////
206 
207 /**
208  * Test connecting to a server
209  */
TEST(AsyncSocketTest,Connect)210 TEST(AsyncSocketTest, Connect) {
211   // Start listening on a local port
212   TestServer server;
213 
214   // Connect using a AsyncSocket
215   EventBase evb;
216   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
217   ConnCallback cb;
218   const auto startedAt = std::chrono::steady_clock::now();
219   socket->connect(&cb, server.getAddress(), 30);
220 
221   evb.loop();
222   const auto finishedAt = std::chrono::steady_clock::now();
223 
224   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
225   EXPECT_LE(0, socket->getConnectTime().count());
226   EXPECT_GE(socket->getConnectStartTime(), startedAt);
227   EXPECT_LE(socket->getConnectStartTime(), socket->getConnectEndTime());
228   EXPECT_LE(socket->getConnectEndTime(), finishedAt);
229   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
230 }
231 
/**
 * Whether a parameterized connect test enables TCP Fast Open on the client
 * socket (via AsyncSocket::enableTFO()) before connecting.
 */
enum class TFOState {
  DISABLED,
  ENABLED,
};

// Value-parameterized fixture for connect tests run under each TFOState
// returned by getTestingValues().
class AsyncSocketConnectTest : public ::testing::TestWithParam<TFOState> {};
238 
getTestingValues()239 std::vector<TFOState> getTestingValues() {
240   std::vector<TFOState> vals;
241   vals.emplace_back(TFOState::DISABLED);
242 
243 #if FOLLY_ALLOW_TFO
244   vals.emplace_back(TFOState::ENABLED);
245 #endif
246   return vals;
247 }
248 
249 INSTANTIATE_TEST_SUITE_P(
250     ConnectTests,
251     AsyncSocketConnectTest,
252     ::testing::ValuesIn(getTestingValues()));
253 
254 /**
255  * Test connecting to a server that isn't listening
256  */
TEST(AsyncSocketTest,ConnectRefused)257 TEST(AsyncSocketTest, ConnectRefused) {
258   EventBase evb;
259 
260   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
261 
262   // Hopefully nothing is actually listening on this address
263   folly::SocketAddress addr("127.0.0.1", 65535);
264   ConnCallback cb;
265   socket->connect(&cb, addr, 30);
266 
267   evb.loop();
268 
269   EXPECT_EQ(STATE_FAILED, cb.state);
270   EXPECT_EQ(AsyncSocketException::NOT_OPEN, cb.exception.getType());
271   EXPECT_LE(0, socket->getConnectTime().count());
272   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
273 }
274 
275 /**
276  * Test connection timeout
277  */
TEST(AsyncSocketTest,ConnectTimeout)278 TEST(AsyncSocketTest, ConnectTimeout) {
279   EventBase evb;
280 
281   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
282 
283   // Try connecting to server that won't respond.
284   //
285   // This depends somewhat on the network where this test is run.
286   // Hopefully this IP will be routable but unresponsive.
287   // (Alternatively, we could try listening on a local raw socket, but that
288   // normally requires root privileges.)
289   auto host = SocketAddressTestHelper::isIPv6Enabled()
290       ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
291       : SocketAddressTestHelper::isIPv4Enabled()
292       ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
293       : nullptr;
294   SocketAddress addr(host, 65535);
295   ConnCallback cb;
296   socket->connect(&cb, addr, 1); // also set a ridiculously small timeout
297 
298   evb.loop();
299 
300   ASSERT_EQ(cb.state, STATE_FAILED);
301   if (cb.exception.getType() == AsyncSocketException::NOT_OPEN) {
302     // This can happen if we could not route to the IP address picked above.
303     // In this case the connect will fail immediately rather than timing out.
304     // Just skip the test in this case.
305     SKIP() << "do not have a routable but unreachable IP address";
306   }
307   ASSERT_EQ(cb.exception.getType(), AsyncSocketException::TIMED_OUT);
308 
309   // Verify that we can still get the peer address after a timeout.
310   // Use case is if the client was created from a client pool, and we want
311   // to log which peer failed.
312   folly::SocketAddress peer;
313   socket->getPeerAddress(&peer);
314   ASSERT_EQ(peer, addr);
315   EXPECT_LE(0, socket->getConnectTime().count());
316   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(1));
317 }
318 
319 /**
320  * Test writing immediately after connecting, without waiting for connect
321  * to finish.
322  */
TEST_P(AsyncSocketConnectTest,ConnectAndWrite)323 TEST_P(AsyncSocketConnectTest, ConnectAndWrite) {
324   TestServer server;
325 
326   // connect()
327   EventBase evb;
328   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
329 
330   if (GetParam() == TFOState::ENABLED) {
331     socket->enableTFO();
332   }
333 
334   ConnCallback ccb;
335   socket->connect(&ccb, server.getAddress(), 30);
336 
337   // write()
338   char buf[128];
339   memset(buf, 'a', sizeof(buf));
340   WriteCallback wcb(true /*enableReleaseIOBufCallback*/);
341   // use writeChain so we can pass an IOBuf
342   socket->writeChain(&wcb, IOBuf::copyBuffer(buf, sizeof(buf)));
343 
344   // Loop.  We don't bother accepting on the server socket yet.
345   // The kernel should be able to buffer the write request so it can succeed.
346   evb.loop();
347 
348   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
349   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
350   ASSERT_EQ(wcb.numIoBufCount, 1);
351   ASSERT_EQ(wcb.numIoBufBytes, sizeof(buf));
352 
353   // Make sure the server got a connection and received the data
354   socket->close();
355   server.verifyConnection(buf, sizeof(buf));
356 
357   ASSERT_TRUE(socket->isClosedBySelf());
358   ASSERT_FALSE(socket->isClosedByPeer());
359   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
360 }
361 
362 /**
363  * Test connecting using a nullptr connect callback.
364  */
TEST_P(AsyncSocketConnectTest,ConnectNullCallback)365 TEST_P(AsyncSocketConnectTest, ConnectNullCallback) {
366   TestServer server;
367 
368   // connect()
369   EventBase evb;
370   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
371   if (GetParam() == TFOState::ENABLED) {
372     socket->enableTFO();
373   }
374 
375   socket->connect(nullptr, server.getAddress(), 30);
376 
377   // write some data, just so we have some way of verifing
378   // that the socket works correctly after connecting
379   char buf[128];
380   memset(buf, 'a', sizeof(buf));
381   WriteCallback wcb;
382   socket->write(&wcb, buf, sizeof(buf));
383 
384   evb.loop();
385 
386   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
387 
388   // Make sure the server got a connection and received the data
389   socket->close();
390   server.verifyConnection(buf, sizeof(buf));
391 
392   ASSERT_TRUE(socket->isClosedBySelf());
393   ASSERT_FALSE(socket->isClosedByPeer());
394 }
395 
396 /**
397  * Test calling both write() and close() immediately after connecting, without
398  * waiting for connect to finish.
399  *
400  * This exercises the STATE_CONNECTING_CLOSING code.
401  */
TEST_P(AsyncSocketConnectTest,ConnectWriteAndClose)402 TEST_P(AsyncSocketConnectTest, ConnectWriteAndClose) {
403   TestServer server;
404 
405   // connect()
406   EventBase evb;
407   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
408   if (GetParam() == TFOState::ENABLED) {
409     socket->enableTFO();
410   }
411   ConnCallback ccb;
412   socket->connect(&ccb, server.getAddress(), 30);
413 
414   // write()
415   char buf[128];
416   memset(buf, 'a', sizeof(buf));
417   WriteCallback wcb;
418   socket->write(&wcb, buf, sizeof(buf));
419 
420   // close()
421   socket->close();
422 
423   // Loop.  We don't bother accepting on the server socket yet.
424   // The kernel should be able to buffer the write request so it can succeed.
425   evb.loop();
426 
427   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
428   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
429 
430   // Make sure the server got a connection and received the data
431   server.verifyConnection(buf, sizeof(buf));
432 
433   ASSERT_TRUE(socket->isClosedBySelf());
434   ASSERT_FALSE(socket->isClosedByPeer());
435 }
436 
437 /**
438  * Test calling close() immediately after connect()
439  */
TEST(AsyncSocketTest,ConnectAndClose)440 TEST(AsyncSocketTest, ConnectAndClose) {
441   TestServer server;
442 
443   // Connect using a AsyncSocket
444   EventBase evb;
445   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
446   ConnCallback ccb;
447   socket->connect(&ccb, server.getAddress(), 30);
448 
449   // Hopefully the connect didn't succeed immediately.
450   // If it did, we can't exercise the close-while-connecting code path.
451   if (ccb.state == STATE_SUCCEEDED) {
452     LOG(INFO) << "connect() succeeded immediately; aborting test "
453                  "of close-during-connect behavior";
454     return;
455   }
456 
457   socket->close();
458 
459   // Loop, although there shouldn't be anything to do.
460   evb.loop();
461 
462   // Make sure the connection was aborted
463   ASSERT_EQ(ccb.state, STATE_FAILED);
464 
465   ASSERT_TRUE(socket->isClosedBySelf());
466   ASSERT_FALSE(socket->isClosedByPeer());
467 }
468 
469 /**
470  * Test calling closeNow() immediately after connect()
471  *
472  * This should be identical to the normal close behavior.
473  */
TEST(AsyncSocketTest,ConnectAndCloseNow)474 TEST(AsyncSocketTest, ConnectAndCloseNow) {
475   TestServer server;
476 
477   // Connect using a AsyncSocket
478   EventBase evb;
479   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
480   ConnCallback ccb;
481   socket->connect(&ccb, server.getAddress(), 30);
482 
483   // Hopefully the connect didn't succeed immediately.
484   // If it did, we can't exercise the close-while-connecting code path.
485   if (ccb.state == STATE_SUCCEEDED) {
486     LOG(INFO) << "connect() succeeded immediately; aborting test "
487                  "of closeNow()-during-connect behavior";
488     return;
489   }
490 
491   socket->closeNow();
492 
493   // Loop, although there shouldn't be anything to do.
494   evb.loop();
495 
496   // Make sure the connection was aborted
497   ASSERT_EQ(ccb.state, STATE_FAILED);
498 
499   ASSERT_TRUE(socket->isClosedBySelf());
500   ASSERT_FALSE(socket->isClosedByPeer());
501 }
502 
503 /**
504  * Test calling both write() and closeNow() immediately after connecting,
505  * without waiting for connect to finish.
506  *
507  * This should abort the pending write.
508  */
TEST(AsyncSocketTest,ConnectWriteAndCloseNow)509 TEST(AsyncSocketTest, ConnectWriteAndCloseNow) {
510   TestServer server;
511 
512   // connect()
513   EventBase evb;
514   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
515   ConnCallback ccb;
516   socket->connect(&ccb, server.getAddress(), 30);
517 
518   // Hopefully the connect didn't succeed immediately.
519   // If it did, we can't exercise the close-while-connecting code path.
520   if (ccb.state == STATE_SUCCEEDED) {
521     LOG(INFO) << "connect() succeeded immediately; aborting test "
522                  "of write-during-connect behavior";
523     return;
524   }
525 
526   // write()
527   char buf[128];
528   memset(buf, 'a', sizeof(buf));
529   WriteCallback wcb;
530   socket->write(&wcb, buf, sizeof(buf));
531 
532   // close()
533   socket->closeNow();
534 
535   // Loop, although there shouldn't be anything to do.
536   evb.loop();
537 
538   ASSERT_EQ(ccb.state, STATE_FAILED);
539   ASSERT_EQ(wcb.state, STATE_FAILED);
540 
541   ASSERT_TRUE(socket->isClosedBySelf());
542   ASSERT_FALSE(socket->isClosedByPeer());
543 }
544 
545 /**
546  * Test installing a read callback immediately, before connect() finishes.
547  */
TEST_P(AsyncSocketConnectTest,ConnectAndRead)548 TEST_P(AsyncSocketConnectTest, ConnectAndRead) {
549   TestServer server;
550 
551   // connect()
552   EventBase evb;
553   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
554   if (GetParam() == TFOState::ENABLED) {
555     socket->enableTFO();
556   }
557 
558   ConnCallback ccb;
559   socket->connect(&ccb, server.getAddress(), 30);
560 
561   ReadCallback rcb;
562   socket->setReadCB(&rcb);
563 
564   if (GetParam() == TFOState::ENABLED) {
565     // Trigger a connection
566     socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
567   }
568 
569   // Even though we haven't looped yet, we should be able to accept
570   // the connection and send data to it.
571   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
572   uint8_t buf[128];
573   memset(buf, 'a', sizeof(buf));
574   acceptedSocket->write(buf, sizeof(buf));
575   acceptedSocket->flush();
576   acceptedSocket->close();
577 
578   // Loop, although there shouldn't be anything to do.
579   evb.loop();
580 
581   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
582   ASSERT_EQ(rcb.buffers.size(), 1);
583   ASSERT_EQ(rcb.buffers[0].length, sizeof(buf));
584   ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);
585 
586   ASSERT_FALSE(socket->isClosedBySelf());
587   ASSERT_FALSE(socket->isClosedByPeer());
588 }
589 
TEST_P(AsyncSocketConnectTest,ConnectAndReadv)590 TEST_P(AsyncSocketConnectTest, ConnectAndReadv) {
591   TestServer server;
592 
593   // connect()
594   EventBase evb;
595   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
596   if (GetParam() == TFOState::ENABLED) {
597     socket->enableTFO();
598   }
599 
600   ConnCallback ccb;
601   socket->connect(&ccb, server.getAddress(), 30);
602 
603   static constexpr size_t kBuffSize = 10;
604   static constexpr size_t kLen = 40;
605   static constexpr size_t kDataSize = 128;
606 
607   ReadvCallback rcb(kBuffSize, kLen);
608   socket->setReadCB(&rcb);
609 
610   if (GetParam() == TFOState::ENABLED) {
611     // Trigger a connection
612     socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
613   }
614 
615   // Even though we haven't looped yet, we should be able to accept
616   // the connection and send data to it.
617   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
618   std::string data(kDataSize, 'A');
619   acceptedSocket->write(
620       reinterpret_cast<unsigned char*>(data.data()), data.size());
621   acceptedSocket->flush();
622   acceptedSocket->close();
623 
624   // Loop, although there shouldn't be anything to do.
625   evb.loop();
626 
627   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
628   rcb.verifyData(data);
629 
630   ASSERT_FALSE(socket->isClosedBySelf());
631   ASSERT_FALSE(socket->isClosedByPeer());
632 }
633 
634 /**
635  * Test installing a read callback and then closing immediately before the
636  * connect attempt finishes.
637  */
TEST(AsyncSocketTest,ConnectReadAndClose)638 TEST(AsyncSocketTest, ConnectReadAndClose) {
639   TestServer server;
640 
641   // connect()
642   EventBase evb;
643   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
644   ConnCallback ccb;
645   socket->connect(&ccb, server.getAddress(), 30);
646 
647   // Hopefully the connect didn't succeed immediately.
648   // If it did, we can't exercise the close-while-connecting code path.
649   if (ccb.state == STATE_SUCCEEDED) {
650     LOG(INFO) << "connect() succeeded immediately; aborting test "
651                  "of read-during-connect behavior";
652     return;
653   }
654 
655   ReadCallback rcb;
656   socket->setReadCB(&rcb);
657 
658   // close()
659   socket->close();
660 
661   // Loop, although there shouldn't be anything to do.
662   evb.loop();
663 
664   ASSERT_EQ(ccb.state, STATE_FAILED); // we aborted the close attempt
665   ASSERT_EQ(rcb.buffers.size(), 0);
666   ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
667 
668   ASSERT_TRUE(socket->isClosedBySelf());
669   ASSERT_FALSE(socket->isClosedByPeer());
670 }
671 
672 /**
673  * Test both writing and installing a read callback immediately,
674  * before connect() finishes.
675  */
TEST_P(AsyncSocketConnectTest,ConnectWriteAndRead)676 TEST_P(AsyncSocketConnectTest, ConnectWriteAndRead) {
677   TestServer server;
678 
679   // connect()
680   EventBase evb;
681   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
682   if (GetParam() == TFOState::ENABLED) {
683     socket->enableTFO();
684   }
685   ConnCallback ccb;
686   socket->connect(&ccb, server.getAddress(), 30);
687 
688   // write()
689   char buf1[128];
690   memset(buf1, 'a', sizeof(buf1));
691   WriteCallback wcb;
692   socket->write(&wcb, buf1, sizeof(buf1));
693 
694   // set a read callback
695   ReadCallback rcb;
696   socket->setReadCB(&rcb);
697 
698   // Even though we haven't looped yet, we should be able to accept
699   // the connection and send data to it.
700   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
701   uint8_t buf2[128];
702   memset(buf2, 'b', sizeof(buf2));
703   acceptedSocket->write(buf2, sizeof(buf2));
704   acceptedSocket->flush();
705 
706   // shut down the write half of acceptedSocket, so that the AsyncSocket
707   // will stop reading and we can break out of the event loop.
708   netops::shutdown(acceptedSocket->getNetworkSocket(), SHUT_WR);
709 
710   // Loop
711   evb.loop();
712 
713   // Make sure the connect succeeded
714   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
715 
716   // Make sure the AsyncSocket read the data written by the accepted socket
717   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
718   ASSERT_EQ(rcb.buffers.size(), 1);
719   ASSERT_EQ(rcb.buffers[0].length, sizeof(buf2));
720   ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf2, sizeof(buf2)), 0);
721 
722   // Close the AsyncSocket so we'll see EOF on acceptedSocket
723   socket->close();
724 
725   // Make sure the accepted socket saw the data written by the AsyncSocket
726   uint8_t readbuf[sizeof(buf1)];
727   acceptedSocket->readAll(readbuf, sizeof(readbuf));
728   ASSERT_EQ(memcmp(buf1, readbuf, sizeof(buf1)), 0);
729   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
730   ASSERT_EQ(bytesRead, 0);
731 
732   ASSERT_FALSE(socket->isClosedBySelf());
733   ASSERT_TRUE(socket->isClosedByPeer());
734 }
735 
736 /**
737  * Test writing to the socket then shutting down writes before the connect
738  * attempt finishes.
739  */
TEST(AsyncSocketTest,ConnectWriteAndShutdownWrite)740 TEST(AsyncSocketTest, ConnectWriteAndShutdownWrite) {
741   TestServer server;
742 
743   // connect()
744   EventBase evb;
745   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
746   ConnCallback ccb;
747   socket->connect(&ccb, server.getAddress(), 30);
748 
749   // Hopefully the connect didn't succeed immediately.
750   // If it did, we can't exercise the write-while-connecting code path.
751   if (ccb.state == STATE_SUCCEEDED) {
752     LOG(INFO) << "connect() succeeded immediately; skipping test";
753     return;
754   }
755 
756   // Ask to write some data
757   char wbuf[128];
758   memset(wbuf, 'a', sizeof(wbuf));
759   WriteCallback wcb;
760   socket->write(&wcb, wbuf, sizeof(wbuf));
761   socket->shutdownWrite();
762 
763   // Shutdown writes
764   socket->shutdownWrite();
765 
766   // Even though we haven't looped yet, we should be able to accept
767   // the connection.
768   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
769 
770   // Since the connection is still in progress, there should be no data to
771   // read yet.  Verify that the accepted socket is not readable.
772   netops::PollDescriptor fds[1];
773   fds[0].fd = acceptedSocket->getNetworkSocket();
774   fds[0].events = POLLIN;
775   fds[0].revents = 0;
776   int rc = netops::poll(fds, 1, 0);
777   ASSERT_EQ(rc, 0);
778 
779   // Write data to the accepted socket
780   uint8_t acceptedWbuf[192];
781   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
782   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
783   acceptedSocket->flush();
784 
785   // Loop
786   evb.loop();
787 
788   // The loop should have completed the connection, written the queued data,
789   // and shutdown writes on the socket.
790   //
791   // Check that the connection was completed successfully and that the write
792   // callback succeeded.
793   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
794   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
795 
796   // Check that we can read the data that was written to the socket, and that
797   // we see an EOF, since its socket was half-shutdown.
798   uint8_t readbuf[sizeof(wbuf)];
799   acceptedSocket->readAll(readbuf, sizeof(readbuf));
800   ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
801   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
802   ASSERT_EQ(bytesRead, 0);
803 
804   // Close the accepted socket.  This will cause it to see EOF
805   // and uninstall the read callback when we loop next.
806   acceptedSocket->close();
807 
808   // Install a read callback, then loop again.
809   ReadCallback rcb;
810   socket->setReadCB(&rcb);
811   evb.loop();
812 
813   // This loop should have read the data and seen the EOF
814   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
815   ASSERT_EQ(rcb.buffers.size(), 1);
816   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
817   ASSERT_EQ(
818       memcmp(rcb.buffers[0].buffer, acceptedWbuf, sizeof(acceptedWbuf)), 0);
819 
820   ASSERT_FALSE(socket->isClosedBySelf());
821   ASSERT_FALSE(socket->isClosedByPeer());
822 }
823 
824 /**
825  * Test reading, writing, and shutting down writes before the connect attempt
826  * finishes.
827  */
TEST(AsyncSocketTest,ConnectReadWriteAndShutdownWrite)828 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWrite) {
829   TestServer server;
830 
831   // connect()
832   EventBase evb;
833   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
834   ConnCallback ccb;
835   socket->connect(&ccb, server.getAddress(), 30);
836 
837   // Hopefully the connect didn't succeed immediately.
838   // If it did, we can't exercise the write-while-connecting code path.
839   if (ccb.state == STATE_SUCCEEDED) {
840     LOG(INFO) << "connect() succeeded immediately; skipping test";
841     return;
842   }
843 
844   // Install a read callback
845   ReadCallback rcb;
846   socket->setReadCB(&rcb);
847 
848   // Ask to write some data
849   char wbuf[128];
850   memset(wbuf, 'a', sizeof(wbuf));
851   WriteCallback wcb;
852   socket->write(&wcb, wbuf, sizeof(wbuf));
853 
854   // Shutdown writes
855   socket->shutdownWrite();
856 
857   // Even though we haven't looped yet, we should be able to accept
858   // the connection.
859   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
860 
861   // Since the connection is still in progress, there should be no data to
862   // read yet.  Verify that the accepted socket is not readable.
863   netops::PollDescriptor fds[1];
864   fds[0].fd = acceptedSocket->getNetworkSocket();
865   fds[0].events = POLLIN;
866   fds[0].revents = 0;
867   int rc = netops::poll(fds, 1, 0);
868   ASSERT_EQ(rc, 0);
869 
870   // Write data to the accepted socket
871   uint8_t acceptedWbuf[192];
872   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
873   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
874   acceptedSocket->flush();
875   // Shutdown writes to the accepted socket.  This will cause it to see EOF
876   // and uninstall the read callback.
877   netops::shutdown(acceptedSocket->getNetworkSocket(), SHUT_WR);
878 
879   // Loop
880   evb.loop();
881 
882   // The loop should have completed the connection, written the queued data,
883   // shutdown writes on the socket, read the data we wrote to it, and see the
884   // EOF.
885   //
886   // Check that the connection was completed successfully and that the read
887   // and write callbacks were invoked as expected.
888   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
889   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
890   ASSERT_EQ(rcb.buffers.size(), 1);
891   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
892   ASSERT_EQ(
893       memcmp(rcb.buffers[0].buffer, acceptedWbuf, sizeof(acceptedWbuf)), 0);
894   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
895 
896   // Check that we can read the data that was written to the socket, and that
897   // we see an EOF, since its socket was half-shutdown.
898   uint8_t readbuf[sizeof(wbuf)];
899   acceptedSocket->readAll(readbuf, sizeof(readbuf));
900   ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
901   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
902   ASSERT_EQ(bytesRead, 0);
903 
904   // Fully close both sockets
905   acceptedSocket->close();
906   socket->close();
907 
908   ASSERT_FALSE(socket->isClosedBySelf());
909   ASSERT_TRUE(socket->isClosedByPeer());
910 }
911 
912 /**
913  * Test reading, writing, and calling shutdownWriteNow() before the
914  * connect attempt finishes.
915  */
TEST(AsyncSocketTest,ConnectReadWriteAndShutdownWriteNow)916 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWriteNow) {
917   TestServer server;
918 
919   // connect()
920   EventBase evb;
921   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
922   ConnCallback ccb;
923   socket->connect(&ccb, server.getAddress(), 30);
924 
925   // Hopefully the connect didn't succeed immediately.
926   // If it did, we can't exercise the write-while-connecting code path.
927   if (ccb.state == STATE_SUCCEEDED) {
928     LOG(INFO) << "connect() succeeded immediately; skipping test";
929     return;
930   }
931 
932   // Install a read callback
933   ReadCallback rcb;
934   socket->setReadCB(&rcb);
935 
936   // Ask to write some data
937   char wbuf[128];
938   memset(wbuf, 'a', sizeof(wbuf));
939   WriteCallback wcb;
940   socket->write(&wcb, wbuf, sizeof(wbuf));
941 
942   // Shutdown writes immediately.
943   // This should immediately discard the data that we just tried to write.
944   socket->shutdownWriteNow();
945 
946   // Verify that writeError() was invoked on the write callback.
947   ASSERT_EQ(wcb.state, STATE_FAILED);
948   ASSERT_EQ(wcb.bytesWritten, 0);
949 
950   // Even though we haven't looped yet, we should be able to accept
951   // the connection.
952   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
953 
954   // Since the connection is still in progress, there should be no data to
955   // read yet.  Verify that the accepted socket is not readable.
956   netops::PollDescriptor fds[1];
957   fds[0].fd = acceptedSocket->getNetworkSocket();
958   fds[0].events = POLLIN;
959   fds[0].revents = 0;
960   int rc = netops::poll(fds, 1, 0);
961   ASSERT_EQ(rc, 0);
962 
963   // Write data to the accepted socket
964   uint8_t acceptedWbuf[192];
965   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
966   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
967   acceptedSocket->flush();
968   // Shutdown writes to the accepted socket.  This will cause it to see EOF
969   // and uninstall the read callback.
970   netops::shutdown(acceptedSocket->getNetworkSocket(), SHUT_WR);
971 
972   // Loop
973   evb.loop();
974 
975   // The loop should have completed the connection, written the queued data,
976   // shutdown writes on the socket, read the data we wrote to it, and see the
977   // EOF.
978   //
979   // Check that the connection was completed successfully and that the read
980   // callback was invoked as expected.
981   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
982   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
983   ASSERT_EQ(rcb.buffers.size(), 1);
984   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
985   ASSERT_EQ(
986       memcmp(rcb.buffers[0].buffer, acceptedWbuf, sizeof(acceptedWbuf)), 0);
987 
988   // Since we used shutdownWriteNow(), it should have discarded all pending
989   // write data.  Verify we see an immediate EOF when reading from the accepted
990   // socket.
991   uint8_t readbuf[sizeof(wbuf)];
992   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
993   ASSERT_EQ(bytesRead, 0);
994 
995   // Fully close both sockets
996   acceptedSocket->close();
997   socket->close();
998 
999   ASSERT_FALSE(socket->isClosedBySelf());
1000   ASSERT_TRUE(socket->isClosedByPeer());
1001 }
1002 
1003 // Helper function for use in testConnectOptWrite()
1004 // Temporarily disable the read callback
tmpDisableReads(AsyncSocket * socket,ReadCallback * rcb)1005 void tmpDisableReads(AsyncSocket* socket, ReadCallback* rcb) {
1006   // Uninstall the read callback
1007   socket->setReadCB(nullptr);
1008   // Schedule the read callback to be reinstalled after 1ms
1009   socket->getEventBase()->runInLoop(
1010       std::bind(&AsyncSocket::setReadCB, socket, rcb));
1011 }
1012 
1013 /**
1014  * Test connect+write, then have the connect callback perform another write.
1015  *
1016  * This tests interaction of the optimistic writing after connect with
1017  * additional write attempts that occur in the connect callback.
1018  */
testConnectOptWrite(size_t size1,size_t size2,bool close=false)1019 void testConnectOptWrite(size_t size1, size_t size2, bool close = false) {
1020   TestServer server;
1021   EventBase evb;
1022   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1023 
1024   // connect()
1025   ConnCallback ccb;
1026   socket->connect(&ccb, server.getAddress(), 30);
1027 
1028   // Hopefully the connect didn't succeed immediately.
1029   // If it did, we can't exercise the optimistic write code path.
1030   if (ccb.state == STATE_SUCCEEDED) {
1031     LOG(INFO) << "connect() succeeded immediately; aborting test "
1032                  "of optimistic write behavior";
1033     return;
1034   }
1035 
1036   // Tell the connect callback to perform a write when the connect succeeds
1037   WriteCallback wcb2;
1038   std::unique_ptr<char[]> buf2(new char[size2]);
1039   memset(buf2.get(), 'b', size2);
1040   if (size2 > 0) {
1041     ccb.successCallback = [&] { socket->write(&wcb2, buf2.get(), size2); };
1042     // Tell the second write callback to close the connection when it is done
1043     wcb2.successCallback = [&] { socket->closeNow(); };
1044   }
1045 
1046   // Schedule one write() immediately, before the connect finishes
1047   std::unique_ptr<char[]> buf1(new char[size1]);
1048   memset(buf1.get(), 'a', size1);
1049   WriteCallback wcb1;
1050   if (size1 > 0) {
1051     socket->write(&wcb1, buf1.get(), size1);
1052   }
1053 
1054   if (close) {
1055     // immediately perform a close, before connect() completes
1056     socket->close();
1057   }
1058 
1059   // Start reading from the other endpoint after 10ms.
1060   // If we're using large buffers, we have to read so that the writes don't
1061   // block forever.
1062   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1063   ReadCallback rcb;
1064   rcb.dataAvailableCallback =
1065       std::bind(tmpDisableReads, acceptedSocket.get(), &rcb);
1066   socket->getEventBase()->tryRunAfterDelay(
1067       std::bind(&AsyncSocket::setReadCB, acceptedSocket.get(), &rcb), 10);
1068 
1069   // Loop.  We don't bother accepting on the server socket yet.
1070   // The kernel should be able to buffer the write request so it can succeed.
1071   evb.loop();
1072 
1073   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1074   if (size1 > 0) {
1075     ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1076   }
1077   if (size2 > 0) {
1078     ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1079   }
1080 
1081   socket->close();
1082 
1083   // Make sure the read callback received all of the data
1084   size_t bytesRead = 0;
1085   for (const auto& buffer : rcb.buffers) {
1086     size_t start = bytesRead;
1087     bytesRead += buffer.length;
1088     size_t end = bytesRead;
1089     if (start < size1) {
1090       size_t cmpLen = min(size1, end) - start;
1091       ASSERT_EQ(memcmp(buffer.buffer, buf1.get() + start, cmpLen), 0);
1092     }
1093     if (end > size1 && end <= size1 + size2) {
1094       size_t itOffset;
1095       size_t buf2Offset;
1096       size_t cmpLen;
1097       if (start >= size1) {
1098         itOffset = 0;
1099         buf2Offset = start - size1;
1100         cmpLen = end - start;
1101       } else {
1102         itOffset = size1 - start;
1103         buf2Offset = 0;
1104         cmpLen = end - size1;
1105       }
1106       ASSERT_EQ(
1107           memcmp(buffer.buffer + itOffset, buf2.get() + buf2Offset, cmpLen), 0);
1108     }
1109   }
1110   ASSERT_EQ(bytesRead, size1 + size2);
1111 }
1112 
/**
 * Run testConnectOptWrite() over a matrix of pre-connect write sizes,
 * connect-callback write sizes, and early-close settings.
 */
TEST(AsyncSocketTest, ConnectCallbackWrite) {
  // Large enough that a write of this size is expected to block.
  const size_t largeSize = 32 * 1024 * 1024;

  struct OptWriteCase {
    size_t preConnectBytes; // written before connect() completes
    size_t callbackBytes; // written from the connect callback
    bool closeEarly; // call close() before connect() completes
  };

  const OptWriteCase cases[] = {
      // Small writes that should both succeed immediately
      {100, 200, false},
      // Large buffer in the connect callback, which should block
      {100, largeSize, false},
      // Large initial write
      {largeSize, 100, false},
      // Two large buffers
      {largeSize, largeSize, false},
      // Small write in the connect callback, with no write before connect
      // completes
      {0, 64, false},
      // Large write in the connect callback, with no write before connect
      // completes
      {0, largeSize, false},
      // Small write, then close() before connect completes
      {211, 0, true},
      // Large write (which will block), then close() before connect completes
      {largeSize, 0, true},
  };

  for (const auto& c : cases) {
    testConnectOptWrite(c.preConnectBytes, c.callbackBytes, c.closeEarly);
  }
}
1143 
1144 ///////////////////////////////////////////////////////////////////////////
1145 // write() related tests
1146 ///////////////////////////////////////////////////////////////////////////
1147 
1148 /**
1149  * Test writing using a nullptr callback
1150  */
TEST(AsyncSocketTest,WriteNullCallback)1151 TEST(AsyncSocketTest, WriteNullCallback) {
1152   TestServer server;
1153 
1154   // connect()
1155   EventBase evb;
1156   std::shared_ptr<AsyncSocket> socket =
1157       AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1158   evb.loop(); // loop until the socket is connected
1159 
1160   // write() with a nullptr callback
1161   char buf[128];
1162   memset(buf, 'a', sizeof(buf));
1163   socket->write(nullptr, buf, sizeof(buf));
1164 
1165   evb.loop(); // loop until the data is sent
1166 
1167   // Make sure the server got a connection and received the data
1168   socket->close();
1169   server.verifyConnection(buf, sizeof(buf));
1170 
1171   ASSERT_TRUE(socket->isClosedBySelf());
1172   ASSERT_FALSE(socket->isClosedByPeer());
1173 }
1174 
1175 /**
1176  * Test writing with a send timeout
1177  */
TEST(AsyncSocketTest,WriteTimeout)1178 TEST(AsyncSocketTest, WriteTimeout) {
1179   TestServer server;
1180 
1181   // connect()
1182   EventBase evb;
1183   std::shared_ptr<AsyncSocket> socket =
1184       AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1185   evb.loop(); // loop until the socket is connected
1186 
1187   // write() a large chunk of data, with no-one on the other end reading.
1188   // Tricky: the kernel caches the connection metrics for recently-used
1189   // routes (see tcp_no_metrics_save) so a freshly opened connection can
1190   // have a send buffer size bigger than wmem_default.  This makes the test
1191   // flaky on contbuild if writeLength is < wmem_max (20M on our systems).
1192   size_t writeLength = 32 * 1024 * 1024;
1193   uint32_t timeout = 200;
1194   socket->setSendTimeout(timeout);
1195   std::unique_ptr<char[]> buf(new char[writeLength]);
1196   memset(buf.get(), 'a', writeLength);
1197   WriteCallback wcb;
1198   socket->write(&wcb, buf.get(), writeLength);
1199 
1200   TimePoint start;
1201   evb.loop();
1202   TimePoint end;
1203 
1204   // Make sure the write attempt timed out as requested
1205   ASSERT_EQ(wcb.state, STATE_FAILED);
1206   ASSERT_EQ(wcb.exception.getType(), AsyncSocketException::TIMED_OUT);
1207 
1208   // Check that the write timed out within a reasonable period of time.
1209   // We don't check for exactly the specified timeout, since AsyncSocket only
1210   // times out when it hasn't made progress for that period of time.
1211   //
1212   // On linux, the first write sends a few hundred kb of data, then blocks for
1213   // writability, and then unblocks again after 40ms and is able to write
1214   // another smaller of data before blocking permanently.  Therefore it doesn't
1215   // time out until 40ms + timeout.
1216   //
1217   // I haven't fully verified the cause of this, but I believe it probably
1218   // occurs because the receiving end delays sending an ack for up to 40ms.
1219   // (This is the default value for TCP_DELACK_MIN.)  Once the sender receives
1220   // the ack, it can send some more data.  However, after that point the
1221   // receiver's kernel buffer is full.  This 40ms delay happens even with
1222   // TCP_NODELAY and TCP_QUICKACK enabled on both endpoints.  However, the
1223   // kernel may be automatically disabling TCP_QUICKACK after receiving some
1224   // data.
1225   //
1226   // For now, we simply check that the timeout occurred within 160ms of
1227   // the requested value.
1228   T_CHECK_TIMEOUT(start, end, milliseconds(timeout), milliseconds(160));
1229 }
1230 
1231 /**
1232  * Test writing to a socket that the remote endpoint has closed
1233  */
TEST(AsyncSocketTest,WritePipeError)1234 TEST(AsyncSocketTest, WritePipeError) {
1235   TestServer server;
1236 
1237   // connect()
1238   EventBase evb;
1239   std::shared_ptr<AsyncSocket> socket =
1240       AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1241   socket->setSendTimeout(1000);
1242   evb.loop(); // loop until the socket is connected
1243 
1244   // accept and immediately close the socket
1245   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1246   acceptedSocket->close();
1247 
1248   // write() a large chunk of data
1249   size_t writeLength = 32 * 1024 * 1024;
1250   std::unique_ptr<char[]> buf(new char[writeLength]);
1251   memset(buf.get(), 'a', writeLength);
1252   WriteCallback wcb;
1253   socket->write(&wcb, buf.get(), writeLength);
1254 
1255   evb.loop();
1256 
1257   // Make sure the write failed.
1258   // It would be nice if AsyncSocketException could convey the errno value,
1259   // so that we could check for EPIPE
1260   ASSERT_EQ(wcb.state, STATE_FAILED);
1261   ASSERT_EQ(wcb.exception.getType(), AsyncSocketException::INTERNAL_ERROR);
1262   ASSERT_THAT(
1263       wcb.exception.what(),
1264       MatchesRegex(
1265           kIsMobile
1266               ? "AsyncSocketException: writev\\(\\) failed \\(peer=.+\\), type = Internal error, errno = .+ \\(Broken pipe\\)"
1267               : "AsyncSocketException: writev\\(\\) failed \\(peer=.+, local=.+\\), type = Internal error, errno = .+ \\(Broken pipe\\)"));
1268   ASSERT_FALSE(socket->isClosedBySelf());
1269   ASSERT_FALSE(socket->isClosedByPeer());
1270 }
1271 
1272 /**
1273  * Test writing to a socket that has its read side closed
1274  */
TEST(AsyncSocketTest,WriteAfterReadEOF)1275 TEST(AsyncSocketTest, WriteAfterReadEOF) {
1276   TestServer server;
1277 
1278   // connect()
1279   EventBase evb;
1280   std::shared_ptr<AsyncSocket> socket =
1281       AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1282   evb.loop(); // loop until the socket is connected
1283 
1284   // Accept the connection
1285   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1286   ReadCallback rcb;
1287   acceptedSocket->setReadCB(&rcb);
1288 
1289   // Shutdown the write side of client socket (read side of server socket)
1290   socket->shutdownWrite();
1291   evb.loop();
1292 
1293   // Check that accepted socket is still writable
1294   ASSERT_FALSE(acceptedSocket->good());
1295   ASSERT_TRUE(acceptedSocket->writable());
1296 
1297   // Write data to accepted socket
1298   constexpr size_t simpleBufLength = 5;
1299   char simpleBuf[simpleBufLength];
1300   memset(simpleBuf, 'a', simpleBufLength);
1301   WriteCallback wcb;
1302   acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
1303   evb.loop();
1304 
1305   // Make sure we were able to write even after getting a read EOF
1306   ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
1307   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1308 }
1309 
1310 /**
1311  * Test that bytes written is correctly computed in case of write failure
1312  */
TEST(AsyncSocketTest,WriteErrorCallbackBytesWritten)1313 TEST(AsyncSocketTest, WriteErrorCallbackBytesWritten) {
1314   // Send and receive buffer sizes for the sockets.
1315   // Note that Linux will double this value to allow space for bookkeeping
1316   // overhead.
1317   constexpr size_t kSockBufSize = 8 * 1024;
1318   constexpr size_t kEffectiveSockBufSize = 2 * kSockBufSize;
1319 
1320   TestServer server(false, kSockBufSize);
1321 
1322   SocketOptionMap options{
1323       {{SOL_SOCKET, SO_SNDBUF}, int(kSockBufSize)},
1324       {{SOL_SOCKET, SO_RCVBUF}, int(kSockBufSize)},
1325       {{IPPROTO_TCP, TCP_NODELAY}, 1},
1326   };
1327 
1328   // The current thread will be used by the receiver - use a separate thread
1329   // for the sender.
1330   EventBase senderEvb;
1331   std::thread senderThread([&]() { senderEvb.loopForever(); });
1332 
1333   ConnCallback ccb;
1334   WriteCallback wcb;
1335   std::shared_ptr<AsyncSocket> socket;
1336 
1337   senderEvb.runInEventBaseThreadAndWait([&]() {
1338     socket = AsyncSocket::newSocket(&senderEvb);
1339     socket->connect(&ccb, server.getAddress(), 30, options);
1340   });
1341 
1342   // accept the socket on the server side
1343   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1344 
1345   // Send a big (100KB) write so that it is partially written.
1346   constexpr size_t kSendSize = 100 * 1024;
1347   auto const sendBuf = std::vector<char>(kSendSize, 'a');
1348 
1349   senderEvb.runInEventBaseThreadAndWait(
1350       [&]() { socket->write(&wcb, sendBuf.data(), kSendSize); });
1351 
1352   // Read 20KB of data from the socket to allow the sender to send a bit more
1353   // data after it initially blocks.
1354   constexpr size_t kRecvSize = 20 * 1024;
1355   uint8_t recvBuf[kRecvSize];
1356   auto bytesRead = acceptedSocket->readAll(recvBuf, sizeof(recvBuf));
1357   ASSERT_EQ(kRecvSize, bytesRead);
1358   EXPECT_EQ(0, memcmp(recvBuf, sendBuf.data(), bytesRead));
1359 
1360   // We should be able to send at least the amount of data received plus the
1361   // send buffer size.  In practice we should probably be able to send
1362   constexpr size_t kMinExpectedBytesWritten = kRecvSize + kSockBufSize;
1363 
1364   // We shouldn't be able to send more than the amount of data received plus
1365   // the send buffer size of the sending socket (kEffectiveSockBufSize) plus
1366   // the receive buffer size on the receiving socket (kEffectiveSockBufSize)
1367   constexpr size_t kMaxExpectedBytesWritten =
1368       kRecvSize + kEffectiveSockBufSize + kEffectiveSockBufSize;
1369   static_assert(
1370       kMaxExpectedBytesWritten < kSendSize, "kSendSize set too small");
1371 
1372   // Need to delay after receiving 20KB and before closing the receive side so
1373   // that the send side has a chance to fill the send buffer past.
1374   using clock = std::chrono::steady_clock;
1375   auto const deadline = clock::now() + std::chrono::seconds(2);
1376   while (wcb.bytesWritten < kMinExpectedBytesWritten &&
1377          clock::now() < deadline) {
1378     std::this_thread::yield();
1379   }
1380   acceptedSocket->closeWithReset();
1381 
1382   senderEvb.terminateLoopSoon();
1383   senderThread.join();
1384   socket.reset();
1385 
1386   ASSERT_EQ(STATE_FAILED, wcb.state);
1387   ASSERT_LE(kMinExpectedBytesWritten, wcb.bytesWritten);
1388   ASSERT_GE(kMaxExpectedBytesWritten, wcb.bytesWritten);
1389 }
1390 
1391 /**
1392  * Test writing a mix of simple buffers and IOBufs
1393  */
TEST(AsyncSocketTest,WriteIOBuf)1394 TEST(AsyncSocketTest, WriteIOBuf) {
1395   TestServer server;
1396 
1397   // connect()
1398   EventBase evb;
1399   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1400   ConnCallback ccb;
1401   socket->connect(&ccb, server.getAddress(), 30);
1402 
1403   // Accept the connection
1404   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1405   ReadCallback rcb;
1406   acceptedSocket->setReadCB(&rcb);
1407 
1408   // Check if EOR tracking flag can be set and reset.
1409   EXPECT_FALSE(socket->isEorTrackingEnabled());
1410   socket->setEorTracking(true);
1411   EXPECT_TRUE(socket->isEorTrackingEnabled());
1412   socket->setEorTracking(false);
1413   EXPECT_FALSE(socket->isEorTrackingEnabled());
1414 
1415   // Write a simple buffer to the socket
1416   constexpr size_t simpleBufLength = 5;
1417   char simpleBuf[simpleBufLength];
1418   memset(simpleBuf, 'a', simpleBufLength);
1419   WriteCallback wcb;
1420   socket->write(&wcb, simpleBuf, simpleBufLength);
1421 
1422   // Write a single-element IOBuf chain
1423   size_t buf1Length = 7;
1424   unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
1425   memset(buf1->writableData(), 'b', buf1Length);
1426   buf1->append(buf1Length);
1427   unique_ptr<IOBuf> buf1Copy(buf1->clone());
1428   WriteCallback wcb2;
1429   socket->writeChain(&wcb2, std::move(buf1));
1430 
1431   // Write a multiple-element IOBuf chain
1432   size_t buf2Length = 11;
1433   unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
1434   memset(buf2->writableData(), 'c', buf2Length);
1435   buf2->append(buf2Length);
1436   size_t buf3Length = 13;
1437   unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
1438   memset(buf3->writableData(), 'd', buf3Length);
1439   buf3->append(buf3Length);
1440   buf2->appendToChain(std::move(buf3));
1441   unique_ptr<IOBuf> buf2Copy(buf2->clone());
1442   buf2Copy->coalesce();
1443   WriteCallback wcb3;
1444   socket->writeChain(&wcb3, std::move(buf2));
1445   socket->shutdownWrite();
1446 
1447   // Let the reads and writes run to completion
1448   evb.loop();
1449 
1450   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1451   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1452   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1453 
1454   // Make sure the reader got the right data in the right order
1455   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
1456   ASSERT_EQ(rcb.buffers.size(), 1);
1457   ASSERT_EQ(
1458       rcb.buffers[0].length,
1459       simpleBufLength + buf1Length + buf2Length + buf3Length);
1460   ASSERT_EQ(memcmp(rcb.buffers[0].buffer, simpleBuf, simpleBufLength), 0);
1461   ASSERT_EQ(
1462       memcmp(
1463           rcb.buffers[0].buffer + simpleBufLength,
1464           buf1Copy->data(),
1465           buf1Copy->length()),
1466       0);
1467   ASSERT_EQ(
1468       memcmp(
1469           rcb.buffers[0].buffer + simpleBufLength + buf1Length,
1470           buf2Copy->data(),
1471           buf2Copy->length()),
1472       0);
1473 
1474   acceptedSocket->close();
1475   socket->close();
1476 
1477   ASSERT_TRUE(socket->isClosedBySelf());
1478   ASSERT_FALSE(socket->isClosedByPeer());
1479 }
1480 
TEST(AsyncSocketTest,WriteIOBufCorked)1481 TEST(AsyncSocketTest, WriteIOBufCorked) {
1482   TestServer server;
1483 
1484   // connect()
1485   EventBase evb;
1486   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1487   ConnCallback ccb;
1488   socket->connect(&ccb, server.getAddress(), 30);
1489 
1490   // Accept the connection
1491   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1492   ReadCallback rcb;
1493   acceptedSocket->setReadCB(&rcb);
1494 
1495   // Do three writes, 100ms apart, with the "cork" flag set
1496   // on the second write.  The reader should see the first write
1497   // arrive by itself, followed by the second and third writes
1498   // arriving together.
1499   size_t buf1Length = 5;
1500   unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
1501   memset(buf1->writableData(), 'a', buf1Length);
1502   buf1->append(buf1Length);
1503   size_t buf2Length = 7;
1504   unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
1505   memset(buf2->writableData(), 'b', buf2Length);
1506   buf2->append(buf2Length);
1507   size_t buf3Length = 11;
1508   unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
1509   memset(buf3->writableData(), 'c', buf3Length);
1510   buf3->append(buf3Length);
1511   WriteCallback wcb1;
1512   socket->writeChain(&wcb1, std::move(buf1));
1513   WriteCallback wcb2;
1514   DelayedWrite write2(socket, std::move(buf2), &wcb2, true);
1515   write2.scheduleTimeout(100);
1516   WriteCallback wcb3;
1517   DelayedWrite write3(socket, std::move(buf3), &wcb3, false, true);
1518   write3.scheduleTimeout(140);
1519 
1520   evb.loop();
1521   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1522   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1523   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1524   if (wcb3.state != STATE_SUCCEEDED) {
1525     throw(wcb3.exception);
1526   }
1527   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1528 
1529   // Make sure the reader got the data with the right grouping
1530   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
1531   ASSERT_EQ(rcb.buffers.size(), 2);
1532   ASSERT_EQ(rcb.buffers[0].length, buf1Length);
1533   ASSERT_EQ(rcb.buffers[1].length, buf2Length + buf3Length);
1534 
1535   acceptedSocket->close();
1536   socket->close();
1537 
1538   ASSERT_TRUE(socket->isClosedBySelf());
1539   ASSERT_FALSE(socket->isClosedByPeer());
1540 }
1541 
1542 /**
1543  * Test performing a zero-length write
1544  */
TEST(AsyncSocketTest,ZeroLengthWrite)1545 TEST(AsyncSocketTest, ZeroLengthWrite) {
1546   TestServer server;
1547 
1548   // connect()
1549   EventBase evb;
1550   std::shared_ptr<AsyncSocket> socket =
1551       AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1552   evb.loop(); // loop until the socket is connected
1553 
1554   auto acceptedSocket = server.acceptAsync(&evb);
1555   ReadCallback rcb;
1556   acceptedSocket->setReadCB(&rcb);
1557 
1558   size_t len1 = 1024 * 1024;
1559   size_t len2 = 1024 * 1024;
1560   std::unique_ptr<char[]> buf(new char[len1 + len2]);
1561   memset(buf.get(), 'a', len1);
1562   memset(buf.get() + len1, 'b', len2);
1563 
1564   WriteCallback wcb1;
1565   WriteCallback wcb2;
1566   WriteCallback wcb3;
1567   WriteCallback wcb4;
1568   socket->write(&wcb1, buf.get(), 0);
1569   socket->write(&wcb2, buf.get(), len1);
1570   socket->write(&wcb3, buf.get() + len1, 0);
1571   socket->write(&wcb4, buf.get() + len1, len2);
1572   socket->close();
1573 
1574   evb.loop(); // loop until the data is sent
1575 
1576   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1577   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1578   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1579   ASSERT_EQ(wcb4.state, STATE_SUCCEEDED);
1580   rcb.verifyData(buf.get(), len1 + len2);
1581 
1582   ASSERT_TRUE(socket->isClosedBySelf());
1583   ASSERT_FALSE(socket->isClosedByPeer());
1584 }
1585 
/**
 * Test writev() with zero-length iovec entries interleaved among non-empty
 * ones. The write must succeed and deliver exactly the non-empty data.
 */
TEST(AsyncSocketTest, ZeroLengthWritev) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket =
      AsyncSocket::newSocket(&evb, server.getAddress(), 30);
  evb.loop(); // loop until the socket is connected

  auto acceptedSocket = server.acceptAsync(&evb);
  ReadCallback rcb;
  acceptedSocket->setReadCB(&rcb);

  size_t len1 = 1024 * 1024;
  size_t len2 = 1024 * 1024;
  std::unique_ptr<char[]> buf(new char[len1 + len2]);
  // Fill the first half with 'a' and the second half with 'b'. The second
  // fill must start at offset len1; otherwise it overwrites the 'a's and
  // leaves the second half uninitialized.
  memset(buf.get(), 'a', len1);
  memset(buf.get() + len1, 'b', len2);

  // iovecs: [len1 bytes][empty][len2 bytes][empty]
  WriteCallback wcb;
  constexpr size_t iovCount = 4;
  struct iovec iov[iovCount];
  iov[0].iov_base = buf.get();
  iov[0].iov_len = len1;
  iov[1].iov_base = buf.get() + len1;
  iov[1].iov_len = 0;
  iov[2].iov_base = buf.get() + len1;
  iov[2].iov_len = len2;
  iov[3].iov_base = buf.get() + len1 + len2;
  iov[3].iov_len = 0;

  socket->writev(&wcb, iov, iovCount);
  socket->close();
  evb.loop(); // loop until the data is sent

  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  rcb.verifyData(buf.get(), len1 + len2);

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}
1627 
1628 ///////////////////////////////////////////////////////////////////////////
1629 // close() related tests
1630 ///////////////////////////////////////////////////////////////////////////
1631 
1632 /**
1633  * Test calling close() with pending writes when the socket is already closing.
1634  */
TEST(AsyncSocketTest,ClosePendingWritesWhileClosing)1635 TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
1636   TestServer server;
1637 
1638   // connect()
1639   EventBase evb;
1640   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1641   ConnCallback ccb;
1642   socket->connect(&ccb, server.getAddress(), 30);
1643 
1644   // accept the socket on the server side
1645   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1646 
1647   // Loop to ensure the connect has completed
1648   evb.loop();
1649 
1650   // Make sure we are connected
1651   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1652 
1653   // Schedule pending writes, until several write attempts have blocked
1654   char buf[128];
1655   memset(buf, 'a', sizeof(buf));
1656   typedef vector<std::shared_ptr<WriteCallback>> WriteCallbackVector;
1657   WriteCallbackVector writeCallbacks;
1658 
1659   writeCallbacks.reserve(5);
1660   while (writeCallbacks.size() < 5) {
1661     std::shared_ptr<WriteCallback> wcb(new WriteCallback);
1662 
1663     socket->write(wcb.get(), buf, sizeof(buf));
1664     if (wcb->state == STATE_SUCCEEDED) {
1665       // Succeeded immediately.  Keep performing more writes
1666       continue;
1667     }
1668 
1669     // This write is blocked.
1670     // Have the write callback call close() when writeError() is invoked
1671     wcb->errorCallback = std::bind(&AsyncSocket::close, socket.get());
1672     writeCallbacks.push_back(wcb);
1673   }
1674 
1675   // Call closeNow() to immediately fail the pending writes
1676   socket->closeNow();
1677 
1678   // Make sure writeError() was invoked on all of the pending write callbacks
1679   for (const auto& writeCallback : writeCallbacks) {
1680     ASSERT_EQ((writeCallback)->state, STATE_FAILED);
1681   }
1682 
1683   ASSERT_TRUE(socket->isClosedBySelf());
1684   ASSERT_FALSE(socket->isClosedByPeer());
1685 }
1686 
1687 ///////////////////////////////////////////////////////////////////////////
1688 // ImmediateRead related tests
1689 ///////////////////////////////////////////////////////////////////////////
1690 
1691 /* AsyncSocket use to verify immediate read works */
1692 class AsyncSocketImmediateRead : public folly::AsyncSocket {
1693  public:
1694   bool immediateReadCalled = false;
AsyncSocketImmediateRead(folly::EventBase * evb)1695   explicit AsyncSocketImmediateRead(folly::EventBase* evb) : AsyncSocket(evb) {}
1696 
1697  protected:
checkForImmediateRead()1698   void checkForImmediateRead() noexcept override {
1699     immediateReadCalled = true;
1700     AsyncSocket::handleRead();
1701   }
1702 };
1703 
/**
 * Echo test using AsyncSocketImmediateRead. The client writes a payload, and
 * the server echoes it back and closes. The client must receive the full
 * payload, and its checkForImmediateRead() hook must have fired.
 */
TEST(AsyncSocket, ConnectReadImmediateRead) {
  TestServer server;

  constexpr size_t kReadBufSize = 100;
  constexpr size_t kReadsPerEvent = 1;
  constexpr size_t kPayloadSize = kReadBufSize * 3;
  char payload[kPayloadSize];
  memset(payload, 'j', kPayloadSize);

  EventBase evb;
  ReadCallback clientReadCb(kReadBufSize);
  AsyncSocketImmediateRead client(&evb);
  client.connect(nullptr, server.getAddress(), 30);
  evb.loop(); // run until connected

  // Install the reader, cap reads per event, and start tracking immediate
  // reads from this point on.
  client.setReadCB(&clientReadCb);
  client.setMaxReadsPerEvent(kReadsPerEvent);
  client.immediateReadCalled = false;

  auto serverSide = server.acceptAsync(&evb);

  // Once the server has the whole payload, it echoes it back and closes.
  ReadCallback serverReadCb;
  WriteCallback serverWriteCb;
  serverReadCb.dataAvailableCallback = [&]() {
    if (serverReadCb.dataRead() != kPayloadSize) {
      return;
    }
    serverReadCb.verifyData(payload, kPayloadSize);
    serverSide->write(&serverWriteCb, payload, kPayloadSize);
    serverSide->close();
  };
  serverSide->setReadCB(&serverReadCb);

  // Send the payload from the client.
  WriteCallback clientWriteCb;
  client.write(&clientWriteCb, payload, kPayloadSize);
  evb.loop();
  ASSERT_EQ(clientWriteCb.state, STATE_SUCCEEDED);
  clientReadCb.verifyData(payload, kPayloadSize);
  ASSERT_EQ(client.immediateReadCalled, true);

  ASSERT_FALSE(client.isClosedBySelf());
  ASSERT_FALSE(client.isClosedByPeer());
}
1749 
/**
 * Echo test using AsyncSocketImmediateRead, where the client uninstalls its
 * read callback from inside dataAvailableCallback. Only one read buffer's
 * worth of data should be consumed, and checkForImmediateRead() must not run.
 */
TEST(AsyncSocket, ConnectReadUninstallRead) {
  TestServer server;

  constexpr size_t kReadBufSize = 100;
  constexpr size_t kReadsPerEvent = 1;
  constexpr size_t kPayloadSize = kReadBufSize * 3;
  char payload[kPayloadSize];
  memset(payload, 'k', kPayloadSize);

  EventBase evb;
  ReadCallback clientReadCb(kReadBufSize);
  AsyncSocketImmediateRead client(&evb);
  client.connect(nullptr, server.getAddress(), 30);
  evb.loop(); // run until connected

  // Install the reader, cap reads per event, and start tracking immediate
  // reads from this point on.
  client.setReadCB(&clientReadCb);
  client.setMaxReadsPerEvent(kReadsPerEvent);
  client.immediateReadCalled = false;

  auto serverSide = server.acceptAsync(&evb);

  // Once the server has the whole payload, it echoes it back and closes.
  ReadCallback serverReadCb;
  WriteCallback serverWriteCb;
  serverReadCb.dataAvailableCallback = [&]() {
    if (serverReadCb.dataRead() != kPayloadSize) {
      return;
    }
    serverReadCb.verifyData(payload, kPayloadSize);
    serverSide->write(&serverWriteCb, payload, kPayloadSize);
    serverSide->close();
  };
  serverSide->setReadCB(&serverReadCb);

  // As soon as the client receives data, stop reading.
  clientReadCb.dataAvailableCallback = [&]() { client.setReadCB(nullptr); };

  // Send the payload from the client.
  WriteCallback clientWriteCb;
  client.write(&clientWriteCb, payload, kPayloadSize);
  evb.loop();
  ASSERT_EQ(clientWriteCb.state, STATE_SUCCEEDED);

  // We should have read only kReadBufSize bytes, because the read callback
  // was removed inside dataAvailableCallback.
  ASSERT_EQ(clientReadCb.dataRead(), kReadBufSize);
  ASSERT_EQ(client.immediateReadCalled, false);

  ASSERT_FALSE(client.isClosedBySelf());
  ASSERT_FALSE(client.isClosedByPeer());
}
1803 
1804 // TODO:
1805 // - Test connect() and have the connect callback set the read callback
1806 // - Test connect() and have the connect callback unset the read callback
1807 // - Test reading/writing/closing/destroying the socket in the connect callback
1808 // - Test reading/writing/closing/destroying the socket in the read callback
1809 // - Test reading/writing/closing/destroying the socket in the write callback
1810 // - Test one-way shutdown behavior
1811 // - Test changing the EventBase
1812 //
// - TODO: test multiple threads sharing an AsyncSocket, and detaching from it
1814 //   in connectSuccess(), readDataAvailable(), writeSuccess()
1815 
1816 ///////////////////////////////////////////////////////////////////////////
1817 // AsyncServerSocket tests
1818 ///////////////////////////////////////////////////////////////////////////
1819 
1820 /**
1821  * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
1822  */
TEST(AsyncSocketTest,ServerAcceptOptions)1823 TEST(AsyncSocketTest, ServerAcceptOptions) {
1824   EventBase eventBase;
1825 
1826   // Create a server socket
1827   std::shared_ptr<AsyncServerSocket> serverSocket(
1828       AsyncServerSocket::newSocket(&eventBase));
1829   serverSocket->bind(0);
1830   serverSocket->listen(16);
1831   folly::SocketAddress serverAddress;
1832   serverSocket->getAddress(&serverAddress);
1833 
1834   // Add a callback to accept one connection then stop the loop
1835   TestAcceptCallback acceptCallback;
1836   acceptCallback.setConnectionAcceptedFn(
1837       [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
1838         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
1839       });
1840   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
1841     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
1842   });
1843   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
1844   serverSocket->startAccepting();
1845 
1846   // Connect to the server socket
1847   std::shared_ptr<AsyncSocket> socket(
1848       AsyncSocket::newSocket(&eventBase, serverAddress));
1849 
1850   eventBase.loop();
1851 
1852   // Verify that the server accepted a connection
1853   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
1854   ASSERT_EQ(
1855       acceptCallback.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
1856   ASSERT_EQ(
1857       acceptCallback.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
1858   ASSERT_EQ(
1859       acceptCallback.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
1860   auto fd = acceptCallback.getEvents()->at(1).fd;
1861 
1862 #ifndef _WIN32
1863   // It is not possible to check if a socket is already in non-blocking mode on
1864   // Windows. Yes really. The accepted connection should already be in
1865   // non-blocking mode
1866   int flags = fcntl(fd.toFd(), F_GETFL, 0);
1867   ASSERT_EQ(flags & O_NONBLOCK, O_NONBLOCK);
1868 #endif
1869 
1870 #ifndef TCP_NOPUSH
1871   // The accepted connection should already have TCP_NODELAY set
1872   int value;
1873   socklen_t valueLength = sizeof(value);
1874   int rc =
1875       netops::getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
1876   ASSERT_EQ(rc, 0);
1877   ASSERT_EQ(value, 1);
1878 #endif
1879 }
1880 
/**
 * Test AsyncServerSocket::removeAcceptCallback() when it is called from
 * inside accept callbacks. The cases are removing callbacks earlier in the
 * callback list, later in the list, and the calling callback itself. The
 * test ends by destroying the server socket from within a callback.
 */
TEST(AsyncSocketTest, RemoveAcceptCallback) {
  // Create a new AsyncServerSocket
  EventBase eventBase;
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Add several accept callbacks
  TestAcceptCallback cb1;
  TestAcceptCallback cb2;
  TestAcceptCallback cb3;
  TestAcceptCallback cb4;
  TestAcceptCallback cb5;
  TestAcceptCallback cb6;
  TestAcceptCallback cb7;

  // Test having callbacks remove other callbacks before them on the list,
  // after them on the list, or removing themselves.
  //
  // Several handlers open a new client connection so that more accepts keep
  // arriving and the round-robin dispatch reaches later callbacks. Each
  // trailing "// cbN: ..." comment names the callback the author expected to
  // receive that connection and what it then does ("-cbN" = removes cbN).
  // NOTE(review): the exact connection-to-callback mapping depends on the
  // order in which connections are accepted; only the per-callback event
  // sequences at the end of the test are actually asserted.
  //
  // Have callback 2 remove callback 3 and callback 5 the first time it is
  // called.
  int cb2Count = 0;
  cb1.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        // Open another client connection; sock2 itself is released when
        // this lambda returns.
        std::shared_ptr<AsyncSocket> sock2(AsyncSocket::newSocket(
            &eventBase, serverAddress)); // cb2: -cb3 -cb5
      });
  // cb2 removes cb3 before cb3 accepts anything (cb3 is expected to record
  // only START and STOP), so this handler is not expected to run.
  cb3.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {});
  cb4.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        std::shared_ptr<AsyncSocket> sock3(
            AsyncSocket::newSocket(&eventBase, serverAddress)); // cb4
      });
  // Like cb3, cb5 is removed by cb2 before it accepts anything (only START
  // and STOP are expected), so this handler is not expected to run.
  cb5.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        std::shared_ptr<AsyncSocket> sock5(
            AsyncSocket::newSocket(&eventBase, serverAddress)); // cb7: -cb7
      });
  cb2.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        if (cb2Count == 0) {
          serverSocket->removeAcceptCallback(&cb3, nullptr);
          serverSocket->removeAcceptCallback(&cb5, nullptr);
        }
        ++cb2Count;
      });
  // Have callback 6 remove callback 4 the first time it is called (and open
  // three more client connections), and destroy the server socket the
  // second time it is called. Destroying the server socket is what ends the
  // loop; every callback still registered is then expected to record STOP.
  int cb6Count = 0;
  cb6.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        if (cb6Count == 0) {
          serverSocket->removeAcceptCallback(&cb4, nullptr);
          std::shared_ptr<AsyncSocket> sock6(
              AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
          std::shared_ptr<AsyncSocket> sock7(
              AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2
          std::shared_ptr<AsyncSocket> sock8(
              AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: stop

        } else {
          serverSocket.reset();
        }
        ++cb6Count;
      });
  // Have callback 7 remove itself
  cb7.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&cb7, nullptr);
      });

  serverSocket->addAcceptCallback(&cb1, &eventBase);
  serverSocket->addAcceptCallback(&cb2, &eventBase);
  serverSocket->addAcceptCallback(&cb3, &eventBase);
  serverSocket->addAcceptCallback(&cb4, &eventBase);
  serverSocket->addAcceptCallback(&cb5, &eventBase);
  serverSocket->addAcceptCallback(&cb6, &eventBase);
  serverSocket->addAcceptCallback(&cb7, &eventBase);
  serverSocket->startAccepting();

  // Make the two initial connections; the handlers above open the rest.
  std::shared_ptr<AsyncSocket> sock1(
      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
  std::shared_ptr<AsyncSocket> sock4(
      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: -cb4

  // Loop until cb6 destroys the server socket.
  eventBase.loop();

  // Check to make sure that the expected callbacks were invoked.
  //
  // NOTE: This code depends on AsyncServerSocket calling the
  // AcceptCallbacks in round-robin fashion, in the order in which they were
  // added. The code is implemented this way right now, but the API doesn't
  // explicitly require it. If we change the code not to be exactly round
  // robin in the future, we can simplify the test checks here.
  // (We'll also need to update the termination code, since we expect cb6 to
  // get called twice to terminate the loop.)
  ASSERT_EQ(cb1.getEvents()->size(), 4);
  ASSERT_EQ(cb1.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb1.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb1.getEvents()->at(2).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb1.getEvents()->at(3).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb2.getEvents()->size(), 4);
  ASSERT_EQ(cb2.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb2.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb2.getEvents()->at(2).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb2.getEvents()->at(3).type, TestAcceptCallback::TYPE_STOP);

  // Removed by cb2 before ever accepting.
  ASSERT_EQ(cb3.getEvents()->size(), 2);
  ASSERT_EQ(cb3.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb3.getEvents()->at(1).type, TestAcceptCallback::TYPE_STOP);

  // Accepted once, then removed by cb6.
  ASSERT_EQ(cb4.getEvents()->size(), 3);
  ASSERT_EQ(cb4.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb4.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb4.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);

  // Removed by cb2 before ever accepting.
  ASSERT_EQ(cb5.getEvents()->size(), 2);
  ASSERT_EQ(cb5.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb5.getEvents()->at(1).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb6.getEvents()->size(), 4);
  ASSERT_EQ(cb6.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb6.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb6.getEvents()->at(2).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb6.getEvents()->at(3).type, TestAcceptCallback::TYPE_STOP);

  // Accepted once, then removed itself.
  ASSERT_EQ(cb7.getEvents()->size(), 3);
  ASSERT_EQ(cb7.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb7.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb7.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
}
2022 
2023 /**
2024  * Test AsyncServerSocket::removeAcceptCallback()
2025  */
TEST(AsyncSocketTest,OtherThreadAcceptCallback)2026 TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
2027   // Create a new AsyncServerSocket
2028   EventBase eventBase;
2029   std::shared_ptr<AsyncServerSocket> serverSocket(
2030       AsyncServerSocket::newSocket(&eventBase));
2031   serverSocket->bind(0);
2032   serverSocket->listen(16);
2033   folly::SocketAddress serverAddress;
2034   serverSocket->getAddress(&serverAddress);
2035 
2036   // Add several accept callbacks
2037   TestAcceptCallback cb1;
2038   auto thread_id = std::this_thread::get_id();
2039   cb1.setAcceptStartedFn([&]() {
2040     CHECK_NE(thread_id, std::this_thread::get_id());
2041     thread_id = std::this_thread::get_id();
2042   });
2043   cb1.setConnectionAcceptedFn(
2044       [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
2045         ASSERT_EQ(thread_id, std::this_thread::get_id());
2046         serverSocket->removeAcceptCallback(&cb1, &eventBase);
2047       });
2048   cb1.setAcceptStoppedFn(
2049       [&]() { ASSERT_EQ(thread_id, std::this_thread::get_id()); });
2050 
2051   // Test having callbacks remove other callbacks before them on the list,
2052   serverSocket->addAcceptCallback(&cb1, &eventBase);
2053   serverSocket->startAccepting();
2054 
2055   // Make several connections to the socket
2056   std::shared_ptr<AsyncSocket> sock1(
2057       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
2058 
2059   // Loop in another thread
2060   auto other = std::thread([&]() { eventBase.loop(); });
2061   other.join();
2062 
2063   // Check to make sure that the expected callbacks were invoked.
2064   //
2065   // NOTE: This code depends on the AsyncServerSocket operating calling all of
2066   // the AcceptCallbacks in round-robin fashion, in the order that they were
2067   // added.  The code is implemented this way right now, but the API doesn't
2068   // explicitly require it be done this way.  If we change the code not to be
2069   // exactly round robin in the future, we can simplify the test checks here.
2070   // (We'll also need to update the termination code, since we expect cb6 to
2071   // get called twice to terminate the loop.)
2072   ASSERT_EQ(cb1.getEvents()->size(), 3);
2073   ASSERT_EQ(cb1.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
2074   ASSERT_EQ(cb1.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
2075   ASSERT_EQ(cb1.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
2076 }
2077 
serverSocketSanityTest(AsyncServerSocket * serverSocket)2078 void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
2079   EventBase* eventBase = serverSocket->getEventBase();
2080   CHECK(eventBase);
2081 
2082   // Add a callback to accept one connection then stop accepting
2083   TestAcceptCallback acceptCallback;
2084   acceptCallback.setConnectionAcceptedFn(
2085       [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
2086         serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
2087       });
2088   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2089     serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
2090   });
2091   serverSocket->addAcceptCallback(&acceptCallback, eventBase);
2092   serverSocket->startAccepting();
2093 
2094   // Connect to the server socket
2095   folly::SocketAddress serverAddress;
2096   serverSocket->getAddress(&serverAddress);
2097   AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
2098 
2099   // Loop to process all events
2100   eventBase->loop();
2101 
2102   // Verify that the server accepted a connection
2103   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
2104   ASSERT_EQ(
2105       acceptCallback.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
2106   ASSERT_EQ(
2107       acceptCallback.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
2108   ASSERT_EQ(
2109       acceptCallback.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
2110 }
2111 
2112 /* Verify that we don't leak sockets if we are destroyed()
2113  * and there are still writes pending
2114  *
2115  * If destroy() only calls close() instead of closeNow(),
2116  * it would shutdown(writes) on the socket, but it would
2117  * never be close()'d, and the socket would leak
2118  */
TEST(AsyncSocketTest,DestroyCloseTest)2119 TEST(AsyncSocketTest, DestroyCloseTest) {
2120   TestServer server;
2121 
2122   // connect()
2123   EventBase clientEB;
2124   EventBase serverEB;
2125   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&clientEB);
2126   ConnCallback ccb;
2127   socket->connect(&ccb, server.getAddress(), 30);
2128 
2129   // Accept the connection
2130   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&serverEB);
2131   ReadCallback rcb;
2132   acceptedSocket->setReadCB(&rcb);
2133 
2134   // Write a large buffer to the socket that is larger than kernel buffer
2135   size_t simpleBufLength = 5000000;
2136   char* simpleBuf = new char[simpleBufLength];
2137   memset(simpleBuf, 'a', simpleBufLength);
2138   WriteCallback wcb;
2139 
2140   // Let the reads and writes run to completion
2141   int fd = acceptedSocket->getNetworkSocket().toFd();
2142 
2143   acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
2144   socket.reset();
2145   acceptedSocket.reset();
2146 
2147   // Test that server socket was closed
2148   folly::test::msvcSuppressAbortOnInvalidParams([&] {
2149     ssize_t sz = read(fd, simpleBuf, simpleBufLength);
2150     ASSERT_EQ(sz, -1);
2151     ASSERT_EQ(errno, EBADF);
2152   });
2153   delete[] simpleBuf;
2154 }
2155 
2156 /**
2157  * Test AsyncServerSocket::useExistingSocket()
2158  */
TEST(AsyncSocketTest,ServerExistingSocket)2159 TEST(AsyncSocketTest, ServerExistingSocket) {
2160   EventBase eventBase;
2161 
2162   // Test creating a socket, and letting AsyncServerSocket bind and listen
2163   {
2164     // Manually create a socket
2165     auto fd = netops::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2166     ASSERT_NE(fd, NetworkSocket());
2167 
2168     // Create a server socket
2169     AsyncServerSocket::UniquePtr serverSocket(
2170         new AsyncServerSocket(&eventBase));
2171     serverSocket->useExistingSocket(fd);
2172     folly::SocketAddress address;
2173     serverSocket->getAddress(&address);
2174     address.setPort(0);
2175     serverSocket->bind(address);
2176     serverSocket->listen(16);
2177 
2178     // Make sure the socket works
2179     serverSocketSanityTest(serverSocket.get());
2180   }
2181 
2182   // Test creating a socket and binding manually,
2183   // then letting AsyncServerSocket listen
2184   {
2185     // Manually create a socket
2186     auto fd = netops::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2187     ASSERT_NE(fd, NetworkSocket());
2188     // bind
2189     struct sockaddr_in addr;
2190     addr.sin_family = AF_INET;
2191     addr.sin_port = 0;
2192     addr.sin_addr.s_addr = INADDR_ANY;
2193     ASSERT_EQ(
2194         netops::bind(
2195             fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
2196         0);
2197     // Look up the address that we bound to
2198     folly::SocketAddress boundAddress;
2199     boundAddress.setFromLocalAddress(fd);
2200 
2201     // Create a server socket
2202     AsyncServerSocket::UniquePtr serverSocket(
2203         new AsyncServerSocket(&eventBase));
2204     serverSocket->useExistingSocket(fd);
2205     serverSocket->listen(16);
2206 
2207     // Make sure AsyncServerSocket reports the same address that we bound to
2208     folly::SocketAddress serverSocketAddress;
2209     serverSocket->getAddress(&serverSocketAddress);
2210     ASSERT_EQ(boundAddress, serverSocketAddress);
2211 
2212     // Make sure the socket works
2213     serverSocketSanityTest(serverSocket.get());
2214   }
2215 
2216   // Test creating a socket, binding and listening manually,
2217   // then giving it to AsyncServerSocket
2218   {
2219     // Manually create a socket
2220     auto fd = netops::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2221     ASSERT_NE(fd, NetworkSocket());
2222     // bind
2223     struct sockaddr_in addr;
2224     addr.sin_family = AF_INET;
2225     addr.sin_port = 0;
2226     addr.sin_addr.s_addr = INADDR_ANY;
2227     ASSERT_EQ(
2228         netops::bind(
2229             fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
2230         0);
2231     // Look up the address that we bound to
2232     folly::SocketAddress boundAddress;
2233     boundAddress.setFromLocalAddress(fd);
2234     // listen
2235     ASSERT_EQ(netops::listen(fd, 16), 0);
2236 
2237     // Create a server socket
2238     AsyncServerSocket::UniquePtr serverSocket(
2239         new AsyncServerSocket(&eventBase));
2240     serverSocket->useExistingSocket(fd);
2241 
2242     // Make sure AsyncServerSocket reports the same address that we bound to
2243     folly::SocketAddress serverSocketAddress;
2244     serverSocket->getAddress(&serverSocketAddress);
2245     ASSERT_EQ(boundAddress, serverSocketAddress);
2246 
2247     // Make sure the socket works
2248     serverSocketSanityTest(serverSocket.get());
2249   }
2250 }
2251 
TEST(AsyncSocketTest, UnixDomainSocketTest) {
  EventBase eventBase;

  // Server bound to a randomly named Unix-domain address. The leading NUL
  // byte makes it an abstract socket address (Linux), so nothing is created
  // on the filesystem.
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  string path(1, '\0');
  path += folly::to<string>("/anonymous", folly::Random::rand64());
  folly::SocketAddress serverAddress;
  serverAddress.setFromPath(path);
  serverSocket->bind(serverAddress);
  serverSocket->listen(16);

  // Take one connection (or one error), then unregister so the loop ends.
  TestAcceptCallback acceptCallback;
  auto stopAccepting = [&] {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  };
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        stopAccepting();
      });
  acceptCallback.setAcceptErrorFn(
      [&](const std::exception& /* ex */) { stopAccepting(); });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // One client connection.
  std::shared_ptr<AsyncSocket> socket(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  // Expect start, exactly one accept, then stop.
  auto events = acceptCallback.getEvents();
  ASSERT_EQ(events->size(), 3);
  ASSERT_EQ(events->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(events->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(events->at(2).type, TestAcceptCallback::TYPE_STOP);
  auto fd = events->at(1).fd;

#ifndef _WIN32
  // Windows offers no way to ask whether a socket is already non-blocking;
  // elsewhere the accepted socket should already have O_NONBLOCK set.
  int fdFlags = fcntl(fd.toFd(), F_GETFL, 0);
  ASSERT_EQ(fdFlags & O_NONBLOCK, O_NONBLOCK);
#endif
}
2301 
TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
  EventBase eventBase;
  TestConnectionEventCallback connectionEventCallback;

  // Server on an ephemeral port that reports connection events to
  // connectionEventCallback.
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->setConnectionEventCallback(&connectionEventCallback);
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Take one connection (or one error), then unregister so the loop ends.
  TestAcceptCallback acceptCallback;
  auto unregister = [&] {
    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
  };
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        unregister();
      });
  acceptCallback.setAcceptErrorFn(
      [&](const std::exception& /* ex */) { unregister(); });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // One client connection.
  std::shared_ptr<AsyncSocket> socket(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  // Only the "accepted" counter should move. In particular nothing is
  // enqueued for (or dequeued by) the accept callback, which is registered on
  // the server's own EventBase.
  ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
  ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
  ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
  ASSERT_EQ(
      connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
  ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}
2344 
TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
  EventBase eventBase;
  TestConnectionEventCallback connectionEventCallback;

  // Server on an ephemeral port that reports connection events to
  // connectionEventCallback.
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->setConnectionEventCallback(&connectionEventCallback);
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Take one connection (or one error), then unregister so the loop ends.
  TestAcceptCallback acceptCallback;
  auto unregister = [&] {
    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
  };
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        unregister();
      });
  acceptCallback.setAcceptErrorFn(
      [&](const std::exception& /* ex */) { unregister(); });
  bool acceptStartedFlag{false};
  acceptCallback.setAcceptStartedFn([&]() { acceptStartedFlag = true; });
  bool acceptStoppedFlag{false};
  acceptCallback.setAcceptStoppedFn([&]() { acceptStoppedFlag = true; });
  // A null EventBase registers the callback without a separate EventBase
  // (the "primary" case this test is named after).
  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
  serverSocket->startAccepting();

  // One client connection.
  std::shared_ptr<AsyncSocket> socket(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  // Both lifecycle notifications were delivered.
  ASSERT_TRUE(acceptStartedFlag);
  ASSERT_TRUE(acceptStoppedFlag);
  // Exactly one accept, with nothing enqueued/dequeued for the callback.
  ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
  ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
  ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
  ASSERT_EQ(
      connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
  ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}
2395 
TEST(AsyncSocketTest, CallbackInSecondaryEventBase) {
  EventBase eventBase;
  TestConnectionEventCallback connectionEventCallback;

  // Server on an ephemeral port that reports connection events to
  // connectionEventCallback.
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->setConnectionEventCallback(&connectionEventCallback);
  serverSocket->bind(0);
  serverSocket->listen(16);
  SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // The accept callback is registered on cobThread's EventBase, so its
  // notifications run on that thread. Unregistering is posted back to the
  // server's own EventBase.
  TestAcceptCallback acceptCallback;
  ScopedEventBaseThread cobThread("ioworker_test");
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const SocketAddress& /* addr */) {
        eventBase.runInEventBaseThread([&] {
          serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
        });
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    eventBase.runInEventBaseThread([&] {
      serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
    });
  });
  // Set on cobThread, read on this thread once the loop is done.
  std::atomic<bool> acceptStarted{false};
  acceptCallback.setAcceptStartedFn([&]() { acceptStarted = true; });
  Baton<> acceptStopped;
  acceptCallback.setAcceptStoppedFn([&]() { acceptStopped.post(); });
  serverSocket->addAcceptCallback(&acceptCallback, cobThread.getEventBase());
  serverSocket->startAccepting();

  // One client connection.
  std::shared_ptr<AsyncSocket> socket(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  // The stop notification runs on cobThread and may still be in flight,
  // so wait for it (up to one second).
  ASSERT_TRUE(acceptStopped.try_wait_for(std::chrono::seconds(1)));
  ASSERT_TRUE(acceptStarted);
  // One accept, handed across threads: one enqueue and one matching dequeue.
  ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
  ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
  ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
  ASSERT_EQ(
      connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
  ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
  ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}
2448 
/**
 * Test AsyncServerSocket::getNumPendingMessagesInQueue()
 *
 * Four clients connect while accepted connections are dispatched to a
 * separate ScopedEventBaseThread. Each time the accept callback fires, it
 * checks on the server's EventBase thread that the value reported by
 * getNumPendingMessagesInQueue() equals 4 minus the number of connections
 * delivered so far.
 */
TEST(AsyncSocketTest, NumPendingMessagesInQueue) {
  EventBase eventBase;

  // Counter of how many connections have been accepted.
  // Incremented on cobThread. It is read on the eventBase thread only while
  // cobThread is blocked in runInEventBaseThreadAndWait(), and after loop().
  int count = 0;

  // Create a server socket
  auto serverSocket(AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Add a callback to accept connections
  TestAcceptCallback acceptCallback;
  // The accept callback is registered with this thread's EventBase (see
  // addAcceptCallback() below), so it runs off the main thread.
  folly::ScopedEventBaseThread cobThread("ioworker_test");
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        count++;
        // Query the server socket from its own EventBase thread and block
        // until the check has run.
        eventBase.runInEventBaseThreadAndWait([&] {
          ASSERT_EQ(4 - count, serverSocket->getNumPendingMessagesInQueue());
        });
        if (count == 4) {
          // All expected connections seen: detach the callback (presumably
          // this lets eventBase.loop() below run out of work and return).
          eventBase.runInEventBaseThread([&] {
            serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
          });
        }
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    // On an accept error, detach the callback on the server's thread.
    eventBase.runInEventBaseThread(
        [&] { serverSocket->removeAcceptCallback(&acceptCallback, nullptr); });
  });
  serverSocket->addAcceptCallback(&acceptCallback, cobThread.getEventBase());
  serverSocket->startAccepting();

  // Connect to the server socket, 4 clients, there are 4 connections
  auto socket1(AsyncSocket::newSocket(&eventBase, serverAddress));
  auto socket2(AsyncSocket::newSocket(&eventBase, serverAddress));
  auto socket3(AsyncSocket::newSocket(&eventBase, serverAddress));
  auto socket4(AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();
  ASSERT_EQ(4, count);
}
2496 
2497 /**
2498  * Test AsyncTransport::BufferCallback
2499  */
TEST(AsyncSocketTest,BufferTest)2500 TEST(AsyncSocketTest, BufferTest) {
2501   TestServer server;
2502 
2503   EventBase evb;
2504   SocketOptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
2505   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2506   ConnCallback ccb;
2507   socket->connect(&ccb, server.getAddress(), 30, option);
2508 
2509   char buf[100 * 1024];
2510   memset(buf, 'c', sizeof(buf));
2511   WriteCallback wcb;
2512   BufferCallback bcb(socket.get(), sizeof(buf));
2513   socket->setBufferCallback(&bcb);
2514   socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
2515 
2516   evb.loop();
2517   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2518   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
2519 
2520   ASSERT_TRUE(bcb.hasBuffered());
2521   ASSERT_TRUE(bcb.hasBufferCleared());
2522 
2523   socket->close();
2524   server.verifyConnection(buf, sizeof(buf));
2525 
2526   ASSERT_TRUE(socket->isClosedBySelf());
2527   ASSERT_FALSE(socket->isClosedByPeer());
2528 }
2529 
/**
 * Same as BufferTest, but the payload is a two-element IOBuf chain
 * (100KB of 'c' followed by 100KB of 'f') written with writeChain().
 */
TEST(AsyncSocketTest, BufferTestChain) {
  TestServer server;

  EventBase evb;
  const SocketOptionMap tinySendBuffer{{{SOL_SOCKET, SO_SNDBUF}, 128}};
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback connectCb;
  socket->connect(&connectCb, server.getAddress(), 30, tinySendBuffer);

  // Two 100KB segments with distinct fill bytes, kept on the heap instead of
  // as large stack arrays.
  const string firstPart(100 * 1024, 'c');
  const string secondPart(100 * 1024, 'f');

  auto chain = folly::IOBuf::copyBuffer(firstPart);
  chain->appendToChain(folly::IOBuf::copyBuffer(secondPart));
  const size_t totalLength = chain->computeChainDataLength();
  ASSERT_EQ(firstPart.size() + secondPart.size(), totalLength);

  BufferCallback bufferCb(socket.get(), totalLength);
  socket->setBufferCallback(&bufferCb);

  WriteCallback writeCb;
  // Send a clone so `chain` remains available for verification below.
  socket->writeChain(&writeCb, chain->clone(), WriteFlags::NONE);

  evb.loop();
  ASSERT_EQ(connectCb.state, STATE_SUCCEEDED);
  ASSERT_EQ(writeCb.state, STATE_SUCCEEDED);

  ASSERT_TRUE(bufferCb.hasBuffered());
  ASSERT_TRUE(bufferCb.hasBufferCleared());

  socket->close();
  // Flatten the chain so the server can compare one contiguous range.
  chain->coalesce();
  server.verifyConnection(
      reinterpret_cast<const char*>(chain->data()), chain->length());

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}
2569 
/**
 * Test that destroying the socket from inside WriteCallback::writeSuccess,
 * while a BufferCallback is installed, does not crash.
 */
TEST(AsyncSocketTest, BufferCallbackKill) {
  TestServer server;
  EventBase evb;
  SocketOptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30, option);
  evb.loopOnce();

  char buf[100 * 1024];
  memset(buf, 'c', sizeof(buf));
  BufferCallback bcb(socket.get(), sizeof(buf));
  socket->setBufferCallback(&bcb);
  WriteCallback wcb;
  wcb.successCallback = [&] {
    // This test must hold the only reference so that reset() really destroys
    // the socket. shared_ptr::unique() is deprecated in C++17 and removed in
    // C++20; use_count() == 1 is the equivalent check.
    ASSERT_EQ(1, socket.use_count());
    socket.reset();
  };

  // This will trigger AsyncSocket::handleWrite,
  // which calls WriteCallback::writeSuccess,
  // which calls wcb.successCallback above,
  // which tries to delete socket
  // Then, the socket will also try to use this BufferCallback
  // And that should crash us, if there is no DestructorGuard on the stack
  socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
}
2600 
2601 #if FOLLY_ALLOW_TFO
/**
 * Test a TFO connection to a local TFO-enabled server. connect() reports
 * success up front, the first write triggers the actual connect, and data
 * then flows in both directions.
 */
TEST(AsyncSocketTest, ConnectTFO) {
  if (!folly::test::isTFOAvailable()) {
    GTEST_SKIP() << "TFO not supported.";
  }

  // Start listening on a local port
  TestServer server(true);

  // Connect using a AsyncSocket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->enableTFO();
  ConnCallback cb;
  socket->connect(&cb, server.getAddress(), 30);

  // Payload the server sends to the client.
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());

  // Server-side destination for the client's 3-byte "hey".
  std::array<uint8_t, 3> readBuf;
  auto sendBuf = IOBuf::copyBuffer("hey");

  // Server side: accept, send `buf`, then read the client's message.
  std::thread t([&] {
    auto acceptedSocket = server.accept();
    acceptedSocket->write(buf.data(), buf.size());
    acceptedSocket->flush();
    acceptedSocket->readAll(readBuf.data(), readBuf.size());
    acceptedSocket->close();
  });

  evb.loop();

  // NOTE(review): if this ASSERT fails, the test returns while `t` is still
  // joinable, and std::thread's destructor calls std::terminate instead of
  // reporting a normal test failure.
  ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
  EXPECT_TRUE(socket->getTFOAttempted());

  // Should trigger the connect
  WriteCallback write;
  ReadCallback rcb;
  socket->writeChain(&write, sendBuf->clone());
  socket->setReadCB(&rcb);
  evb.loop();

  t.join();

  EXPECT_EQ(STATE_SUCCEEDED, write.state);
  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
  ASSERT_EQ(1, rcb.buffers.size());
  ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
  EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}
2655 
/**
 * Same as ConnectTFO, except the read callback is installed right after
 * connect(), before the first write triggers the actual connection.
 */
TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
  if (!folly::test::isTFOAvailable()) {
    GTEST_SKIP() << "TFO not supported.";
  }

  // Start listening on a local port
  TestServer server(true);

  // Connect using a AsyncSocket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->enableTFO();
  ConnCallback cb;
  socket->connect(&cb, server.getAddress(), 30);
  // Installed early on purpose: this ordering is what the test exercises.
  ReadCallback rcb;
  socket->setReadCB(&rcb);

  // Payload the server sends to the client.
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());

  // Server-side destination for the client's 3-byte "hey".
  std::array<uint8_t, 3> readBuf;
  auto sendBuf = IOBuf::copyBuffer("hey");

  // Server side: accept, send `buf`, then read the client's message.
  std::thread t([&] {
    auto acceptedSocket = server.accept();
    acceptedSocket->write(buf.data(), buf.size());
    acceptedSocket->flush();
    acceptedSocket->readAll(readBuf.data(), readBuf.size());
    acceptedSocket->close();
  });

  evb.loop();

  // NOTE(review): if this ASSERT fails, the test returns while `t` is still
  // joinable, and std::thread's destructor calls std::terminate.
  ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
  EXPECT_TRUE(socket->getTFOAttempted());

  // Should trigger the connect
  WriteCallback write;
  socket->writeChain(&write, sendBuf->clone());
  evb.loop();

  t.join();

  EXPECT_EQ(STATE_SUCCEEDED, write.state);
  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
  ASSERT_EQ(1, rcb.buffers.size());
  ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
  EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}
2709 
2710 /**
2711  * Test connecting to a server that isn't listening
2712  */
TEST(AsyncSocketTest,ConnectRefusedImmediatelyTFO)2713 TEST(AsyncSocketTest, ConnectRefusedImmediatelyTFO) {
2714   EventBase evb;
2715 
2716   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2717 
2718   socket->enableTFO();
2719 
2720   // Hopefully nothing is actually listening on this address
2721   folly::SocketAddress addr("::1", 65535);
2722   ConnCallback cb;
2723   socket->connect(&cb, addr, 30);
2724 
2725   evb.loop();
2726 
2727   WriteCallback write1;
2728   // Trigger the connect if TFO attempt is supported.
2729   socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
2730   WriteCallback write2;
2731   socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
2732   evb.loop();
2733 
2734   if (!socket->getTFOFinished()) {
2735     EXPECT_EQ(STATE_FAILED, write1.state);
2736   } else {
2737     EXPECT_EQ(STATE_SUCCEEDED, write1.state);
2738     EXPECT_FALSE(socket->getTFOSucceded());
2739   }
2740 
2741   EXPECT_EQ(STATE_FAILED, write2.state);
2742 
2743   EXPECT_EQ(STATE_SUCCEEDED, cb.state);
2744   EXPECT_LE(0, socket->getConnectTime().count());
2745   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
2746   EXPECT_TRUE(socket->getTFOAttempted());
2747 }
2748 
2749 /**
2750  * Test calling closeNow() immediately after connecting.
2751  */
TEST(AsyncSocketTest,ConnectWriteAndCloseNowTFO)2752 TEST(AsyncSocketTest, ConnectWriteAndCloseNowTFO) {
2753   TestServer server(true);
2754 
2755   // connect()
2756   EventBase evb;
2757   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2758   socket->enableTFO();
2759 
2760   ConnCallback ccb;
2761   socket->connect(&ccb, server.getAddress(), 30);
2762 
2763   // write()
2764   std::array<char, 128> buf;
2765   memset(buf.data(), 'a', buf.size());
2766 
2767   // close()
2768   socket->closeNow();
2769 
2770   // Loop, although there shouldn't be anything to do.
2771   evb.loop();
2772 
2773   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2774 
2775   ASSERT_TRUE(socket->isClosedBySelf());
2776   ASSERT_FALSE(socket->isClosedByPeer());
2777 }
2778 
/**
 * Test calling close() immediately after connect() with TFO enabled.
 */
TEST(AsyncSocketTest, ConnectAndCloseTFO) {
  TestServer server(true);

  // Connect using a AsyncSocket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->enableTFO();

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  socket->close();

  // Loop, although there shouldn't be anything to do.
  evb.loop();

  // With TFO the connect callback has already reported success (the actual
  // connect is deferred to the first write), so closing does not turn it
  // into a failure.
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}
2804 
/**
 * AsyncSocket whose tfoSendMsg() is a gmock mock. Tests use it to script the
 * outcome of the TFO send, e.g. a given errno or a plain connect().
 */
class MockAsyncTFOSocket : public AsyncSocket {
 public:
  // Owning pointer that uses AsyncSocket's Destructor functor as its deleter.
  using UniquePtr = std::unique_ptr<MockAsyncTFOSocket, Destructor>;

  explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}

  // Presumably overrides the virtual AsyncSocket::tfoSendMsg hook (confirm
  // in AsyncSocket.h). Tests configure it with EXPECT_CALL.
  MOCK_METHOD3(
      tfoSendMsg, ssize_t(NetworkSocket fd, struct msghdr* msg, int msg_flags));
};
2814 
/**
 * Test that a TFO send failing with EOPNOTSUPP does not fail the write. The
 * write is expected to stay pending and later succeed (presumably via a
 * fallback to a regular connect), and data then flows in both directions.
 */
TEST(AsyncSocketTest, TestTFOUnsupported) {
  TestServer server(true);

  // Connect using a AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);
  // With TFO enabled, connect() reports success without looping the evb.
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  // Make the TFO send report "operation not supported".
  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
  WriteCallback write;
  auto sendBuf = IOBuf::copyBuffer("hey");
  socket->writeChain(&write, sendBuf->clone());
  // The write must not fail outright; it stays pending.
  EXPECT_EQ(STATE_WAITING, write.state);

  // Payload the server sends to the client.
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());

  // Server-side destination for the client's 3-byte "hey".
  std::array<uint8_t, 3> readBuf;

  // Server side: accept, send `buf`, then read the client's message.
  std::thread t([&] {
    std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
    acceptedSocket->write(buf.data(), buf.size());
    acceptedSocket->flush();
    acceptedSocket->readAll(readBuf.data(), readBuf.size());
    acceptedSocket->close();
  });

  evb.loop();

  t.join();
  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
  EXPECT_EQ(STATE_SUCCEEDED, write.state);

  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
  ASSERT_EQ(1, rcb.buffers.size());
  ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
  EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}
2863 
/**
 * The TFO send is scripted to do a plain connect() to an address where,
 * hopefully, nothing listens. The connect callback reports success, while
 * both queued writes are expected to fail.
 */
TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
  EventBase evb;

  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  // Hopefully this fails
  folly::SocketAddress fakeAddr("127.0.0.1", 65535);
  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
        sockaddr_storage addr;
        auto len = fakeAddr.getAddress(&addr);
        auto ret = netops::connect(
            fd, reinterpret_cast<const struct sockaddr*>(&addr), len);
        // Capture errno right after connect(). The socket presumably inspects
        // errno when tfoSendMsg() returns -1 (other tests drive that path
        // with SetErrnoAndReturn), so logging must not change it.
        const int savedErrno = errno;
        LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
                  << savedErrno;
        errno = savedErrno;
        return ret;
      }));

  // Hopefully nothing is actually listening on this address
  ConnCallback cb;
  socket->connect(&cb, fakeAddr, 30);

  WriteCallback write1;
  // Trigger the connect if TFO attempt is supported.
  socket->writeChain(&write1, IOBuf::copyBuffer("hey"));

  if (socket->getTFOFinished()) {
    // TFO already finished, so the delayed-refusal path this test targets
    // cannot be exercised. Report a skip rather than a silent pass.
    GTEST_SKIP() << "TFO finished before the refusal could be observed.";
  }
  WriteCallback write2;
  // Trigger the connect if TFO attempt is supported.
  socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
  evb.loop();

  EXPECT_EQ(STATE_FAILED, write1.state);
  EXPECT_EQ(STATE_FAILED, write2.state);
  EXPECT_FALSE(socket->getTFOSucceded());

  EXPECT_EQ(STATE_SUCCEEDED, cb.state);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
  EXPECT_TRUE(socket->getTFOAttempted());
}
2908 
/**
 * The TFO send fails with EOPNOTSUPP while connecting, with a 1ms connect
 * timeout, to a remote address that is expected not to respond. The pending
 * write is expected to fail.
 */
TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
  // Try connecting to server that won't respond.
  //
  // This depends somewhat on the network where this test is run.
  // Hopefully this IP will be routable but unresponsive.
  // (Alternatively, we could try listening on a local raw socket, but that
  // normally requires root privileges.)
  auto host = SocketAddressTestHelper::isIPv6Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
      : SocketAddressTestHelper::isIPv4Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
      : nullptr;
  if (host == nullptr) {
    // No IP family is enabled, so there is no address to try. Skip instead of
    // passing a null host to SocketAddress.
    GTEST_SKIP() << "Neither IPv6 nor IPv4 is enabled.";
  }
  SocketAddress addr(host, 65535);

  // Connect using a AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  // Set a very small timeout
  socket->connect(&ccb, addr, 1);
  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
  WriteCallback write;
  socket->writeChain(&write, IOBuf::copyBuffer("hey"));

  evb.loop();

  EXPECT_EQ(STATE_FAILED, write.state);
}
2945 
/**
 * The TFO send is replaced by a plain connect() to the local server,
 * emulating a fallback. The queued write and the server's reply must both
 * be delivered.
 */
TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
  TestServer server(true);

  // Connect using a AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
        // Emulate the fallback with a regular connect() to the server.
        sockaddr_storage addr;
        auto len = server.getAddress().getAddress(&addr);
        return netops::connect(
            fd, reinterpret_cast<const struct sockaddr*>(&addr), len);
      }));
  WriteCallback write;
  auto sendBuf = IOBuf::copyBuffer("hey");
  socket->writeChain(&write, sendBuf->clone());
  EXPECT_EQ(STATE_WAITING, write.state);

  // Payload the server sends to the client.
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());

  // Server-side destination for the client's 3-byte "hey".
  std::array<uint8_t, 3> readBuf;

  // Server side: accept, send `buf`, then read the client's message.
  std::thread t([&] {
    std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
    acceptedSocket->write(buf.data(), buf.size());
    acceptedSocket->flush();
    acceptedSocket->readAll(readBuf.data(), readBuf.size());
    acceptedSocket->close();
  });

  evb.loop();

  t.join();
  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));

  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
  EXPECT_EQ(STATE_SUCCEEDED, write.state);

  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
  ASSERT_EQ(1, rcb.buffers.size());
  ASSERT_EQ(buf.size(), rcb.buffers[0].length);
  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
}
2998 
/**
 * The TFO send is replaced by a plain connect(), with a 1ms connect timeout,
 * to a remote address expected not to respond. The pending write is
 * expected to fail.
 */
TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
  // Try connecting to server that won't respond.
  //
  // This depends somewhat on the network where this test is run.
  // Hopefully this IP will be routable but unresponsive.
  // (Alternatively, we could try listening on a local raw socket, but that
  // normally requires root privileges.)
  auto host = SocketAddressTestHelper::isIPv6Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
      : SocketAddressTestHelper::isIPv4Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
      : nullptr;
  if (host == nullptr) {
    // No IP family is enabled, so there is no address to try. Skip instead of
    // passing a null host to SocketAddress.
    GTEST_SKIP() << "Neither IPv6 nor IPv4 is enabled.";
  }
  SocketAddress addr(host, 65535);

  // Connect using a AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  // Set a very small timeout
  socket->connect(&ccb, addr, 1);
  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
        // Emulate the fallback with a regular connect() to `addr`.
        sockaddr_storage addr2;
        auto len = addr.getAddress(&addr2);
        return netops::connect(
            fd, reinterpret_cast<const struct sockaddr*>(&addr2), len);
      }));
  WriteCallback write;
  socket->writeChain(&write, IOBuf::copyBuffer("hey"));

  evb.loop();

  EXPECT_EQ(STATE_FAILED, write.state);
}
3039 
/**
 * When the TFO send fails with EAGAIN, the connect callback still reports
 * success, but the pending write is failed.
 */
TEST(AsyncSocketTest, TestTFOEagain) {
  TestServer server(true);

  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback connectCb;
  socket->connect(&connectCb, server.getAddress(), 30);

  // Script the single TFO send attempt to fail with EAGAIN.
  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(SetErrnoAndReturn(EAGAIN, -1));
  WriteCallback writeCb;
  socket->writeChain(&writeCb, IOBuf::copyBuffer("hey"));

  evb.loop();

  EXPECT_EQ(STATE_SUCCEEDED, connectCb.state);
  EXPECT_EQ(STATE_FAILED, writeCb.state);
}
3061 
3062 // Sending a large amount of data in the first write which will
3063 // definitely not fit into MSS.
TEST(AsyncSocketTest,ConnectTFOWithBigData)3064 TEST(AsyncSocketTest, ConnectTFOWithBigData) {
3065   if (!folly::test::isTFOAvailable()) {
3066     GTEST_SKIP() << "TFO not supported.";
3067   }
3068 
3069   // Start listening on a local port
3070   TestServer server(true);
3071 
3072   // Connect using a AsyncSocket
3073   EventBase evb;
3074   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3075   socket->enableTFO();
3076   ConnCallback cb;
3077   socket->connect(&cb, server.getAddress(), 30);
3078 
3079   std::array<uint8_t, 128> buf;
3080   memset(buf.data(), 'a', buf.size());
3081 
3082   constexpr size_t len = 10 * 1024;
3083   auto sendBuf = IOBuf::create(len);
3084   sendBuf->append(len);
3085   std::array<uint8_t, len> readBuf;
3086 
3087   std::thread t([&] {
3088     auto acceptedSocket = server.accept();
3089     acceptedSocket->write(buf.data(), buf.size());
3090     acceptedSocket->flush();
3091     acceptedSocket->readAll(readBuf.data(), readBuf.size());
3092     acceptedSocket->close();
3093   });
3094 
3095   evb.loop();
3096 
3097   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
3098   EXPECT_LE(0, socket->getConnectTime().count());
3099   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
3100   EXPECT_TRUE(socket->getTFOAttempted());
3101 
3102   // Should trigger the connect
3103   WriteCallback write;
3104   ReadCallback rcb;
3105   socket->writeChain(&write, sendBuf->clone());
3106   socket->setReadCB(&rcb);
3107   evb.loop();
3108 
3109   t.join();
3110 
3111   EXPECT_EQ(STATE_SUCCEEDED, write.state);
3112   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
3113   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
3114   ASSERT_EQ(1, rcb.buffers.size());
3115   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
3116   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
3117   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
3118 }
3119 
3120 #endif // FOLLY_ALLOW_TFO
3121 
// gmock implementation of AsyncSocket::EvbChangeCallback. Tests use it to
// check the evbDetached()/evbAttached() notifications a socket emits.
class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
 public:
  MOCK_METHOD1(evbAttached, void(AsyncSocket*));
  MOCK_METHOD1(evbDetached, void(AsyncSocket*));
};
3127 
/**
 * Detaching and then re-attaching the EventBase must notify the registered
 * EvbChangeCallback: evbDetached first, then evbAttached.
 */
TEST(AsyncSocketTest, EvbCallbacks) {
  auto evbChangeCb = std::make_unique<MockEvbChangeCallback>();
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  AsyncSocket* const rawSocket = socket.get();

  // The expectations below must be met in declaration order.
  InSequence seq;
  EXPECT_CALL(*evbChangeCb, evbDetached(rawSocket)).Times(1);
  EXPECT_CALL(*evbChangeCb, evbAttached(rawSocket)).Times(1);

  // The socket takes ownership of the mock.
  socket->setEvbChangedCallback(std::move(evbChangeCb));
  socket->detachEventBase();
  socket->attachEventBase(&evb);
}
3141 
/**
 * Once a read callback (and thus its IO handlers) is registered, the socket
 * must still be detachable. Detach followed by attach must notify the
 * EvbChangeCallback in that order.
 */
TEST(AsyncSocketTest, TestEvbDetachWtRegisteredIOHandlers) {
  // Start listening on a local port
  TestServer server;

  // Connect an AsyncSocket to it and wait for the connect to complete.
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback connectCb;
  socket->connect(&connectCb, server.getAddress(), 30);
  evb.loop();

  ASSERT_EQ(connectCb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));

  // Registering a read callback installs the socket's IO handlers.
  ReadCallback readCb;
  socket->setReadCB(&readCb);

  auto evbChangeCb = std::make_unique<MockEvbChangeCallback>();
  AsyncSocket* const rawSocket = socket.get();
  InSequence seq;
  EXPECT_CALL(*evbChangeCb, evbDetached(rawSocket)).Times(1);
  EXPECT_CALL(*evbChangeCb, evbAttached(rawSocket)).Times(1);

  socket->setEvbChangedCallback(std::move(evbChangeCb));
  EXPECT_TRUE(socket->isDetachable());
  socket->detachEventBase();
  socket->attachEventBase(&evb);

  socket->close();
}
3174 
/**
 * Test that a connected socket with registered IO handlers can be detached
 * from its EventBase and then destroyed while still detached.
 */
TEST(AsyncSocketTest, TestEvbDetachThenClose) {
  // Start listening on a local port
  TestServer server;

  // Connect an AsyncSocket to the server
  EventBase evb;
  auto socket = AsyncSocket::newSocket(&evb);
  ConnCallback cb;
  socket->connect(&cb, server.getAddress(), 30);

  evb.loop();

  ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));

  // After the ioHandlers are registered, still should be able to detach/attach
  ReadCallback rcb;
  socket->setReadCB(&rcb);

  auto cbEvbChg = std::make_unique<MockEvbChangeCallback>();
  InSequence seq;
  // Only a detach notification is expected; the socket is never re-attached.
  EXPECT_CALL(*cbEvbChg, evbDetached(socket.get())).Times(1);

  socket->setEvbChangedCallback(std::move(cbEvbChg));

  // Should be possible to destroy/call closeNow() without an attached EventBase
  EXPECT_TRUE(socket->isDetachable());
  socket->detachEventBase();
  // Drop this test's reference while the socket is detached.
  socket.reset();
}
3206 
/**
 * Raw and app byte counters must carry over when an AsyncSocket is
 * move-constructed from another AsyncSocket::UniquePtr.
 */
TEST(AsyncSocket, BytesWrittenWithMove) {
  TestServer server;

  EventBase evb;
  auto original = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  ConnCallback connCb;
  original->connect(&connCb, server.getAddress(), 30);
  std::shared_ptr<BlockingSocket> peer = server.accept();

  EXPECT_EQ(0, original->getRawBytesWritten());

  const std::vector<uint8_t> payload(128, 'a');
  WriteCallback writeCb;
  original->write(&writeCb, payload.data(), payload.size());
  evb.loopOnce();
  ASSERT_EQ(writeCb.state, STATE_SUCCEEDED);

  // Both counters must equal the payload size, before and after the move.
  auto expectCounters = [&](AsyncSocket& sock) {
    EXPECT_EQ(payload.size(), sock.getRawBytesWritten());
    EXPECT_EQ(payload.size(), sock.getAppBytesWritten());
  };
  expectCounters(*original);

  auto moved = AsyncSocket::UniquePtr(new AsyncSocket(std::move(original)));
  expectCounters(*moved);
}
3229 
3230 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
// One parameter set for AsyncSocketErrMessageCallbackTest.
struct AsyncSocketErrMessageCallbackTestParams {
  // If set, unregister the ErrMessageCallback once this many callbacks
  // (timestamp + byteseq combined) have been delivered.
  folly::Optional<int> resetCallbackAfter;
  // If set, close the socket once this many callbacks (combined) have been
  // delivered.
  folly::Optional<int> closeSocketAfter;
  // Expected number of timestamp callbacks; presumably compared against
  // TestErrMessageCallback's counter at the end of the test.
  int gotTimestampExpected{0};
  // Expected number of byteseq (IP_RECVERR / IPV6_RECVERR) callbacks.
  int gotByteSeqExpected{0};
};
3237 
3238 class AsyncSocketErrMessageCallbackTest
3239     : public ::testing::TestWithParam<AsyncSocketErrMessageCallbackTestParams> {
3240  public:
3241   static std::vector<AsyncSocketErrMessageCallbackTestParams>
getTestingValues()3242   getTestingValues() {
3243     std::vector<AsyncSocketErrMessageCallbackTestParams> vals;
3244     // each socket err message triggers two socket callbacks:
3245     //   (1) timestamp callback
3246     //   (2) byteseq callback
3247 
3248     // reset callback cases
3249     // resetting the callback should prevent any further callbacks
3250     {
3251       AsyncSocketErrMessageCallbackTestParams params;
3252       params.resetCallbackAfter = 1;
3253       params.gotTimestampExpected = 1;
3254       params.gotByteSeqExpected = 0;
3255       vals.push_back(params);
3256     }
3257     {
3258       AsyncSocketErrMessageCallbackTestParams params;
3259       params.resetCallbackAfter = 2;
3260       params.gotTimestampExpected = 1;
3261       params.gotByteSeqExpected = 1;
3262       vals.push_back(params);
3263     }
3264     {
3265       AsyncSocketErrMessageCallbackTestParams params;
3266       params.resetCallbackAfter = 3;
3267       params.gotTimestampExpected = 2;
3268       params.gotByteSeqExpected = 1;
3269       vals.push_back(params);
3270     }
3271     {
3272       AsyncSocketErrMessageCallbackTestParams params;
3273       params.resetCallbackAfter = 4;
3274       params.gotTimestampExpected = 2;
3275       params.gotByteSeqExpected = 2;
3276       vals.push_back(params);
3277     }
3278 
3279     // close socket cases
3280     // closing the socket will prevent callbacks after the current err message
3281     // callbacks (both timestamp and byteseq) are completed
3282     {
3283       AsyncSocketErrMessageCallbackTestParams params;
3284       params.closeSocketAfter = 1;
3285       params.gotTimestampExpected = 1;
3286       params.gotByteSeqExpected = 1;
3287       vals.push_back(params);
3288     }
3289     {
3290       AsyncSocketErrMessageCallbackTestParams params;
3291       params.closeSocketAfter = 2;
3292       params.gotTimestampExpected = 1;
3293       params.gotByteSeqExpected = 1;
3294       vals.push_back(params);
3295     }
3296     {
3297       AsyncSocketErrMessageCallbackTestParams params;
3298       params.closeSocketAfter = 3;
3299       params.gotTimestampExpected = 2;
3300       params.gotByteSeqExpected = 2;
3301       vals.push_back(params);
3302     }
3303     {
3304       AsyncSocketErrMessageCallbackTestParams params;
3305       params.closeSocketAfter = 4;
3306       params.gotTimestampExpected = 2;
3307       params.gotByteSeqExpected = 2;
3308       vals.push_back(params);
3309     }
3310     return vals;
3311   }
3312 };
3313 
// Runs every AsyncSocketErrMessageCallbackTest once per parameter set
// returned by getTestingValues().
INSTANTIATE_TEST_SUITE_P(
    ErrMessageTests,
    AsyncSocketErrMessageCallbackTest,
    ::testing::ValuesIn(AsyncSocketErrMessageCallbackTest::getTestingValues()));
3318 
3319 class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
3320  public:
TestErrMessageCallback()3321   TestErrMessageCallback()
3322       : exception_(folly::AsyncSocketException::UNKNOWN, "none") {}
3323 
errMessage(const cmsghdr & cmsg)3324   void errMessage(const cmsghdr& cmsg) noexcept override {
3325     if (cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_TIMESTAMPING) {
3326       gotTimestamp_++;
3327       checkResetCallback();
3328       checkCloseSocket();
3329     } else if (
3330         (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
3331         (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
3332       gotByteSeq_++;
3333       checkResetCallback();
3334       checkCloseSocket();
3335     }
3336   }
3337 
errMessageError(const folly::AsyncSocketException & ex)3338   void errMessageError(
3339       const folly::AsyncSocketException& ex) noexcept override {
3340     exception_ = ex;
3341   }
3342 
checkResetCallback()3343   void checkResetCallback() noexcept {
3344     if (socket_ != nullptr && resetCallbackAfter_ != -1 &&
3345         gotTimestamp_ + gotByteSeq_ == resetCallbackAfter_) {
3346       socket_->setErrMessageCB(nullptr);
3347     }
3348   }
3349 
checkCloseSocket()3350   void checkCloseSocket() noexcept {
3351     if (socket_ != nullptr && closeSocketAfter_ != -1 &&
3352         gotTimestamp_ + gotByteSeq_ == closeSocketAfter_) {
3353       socket_->close();
3354     }
3355   }
3356 
3357   folly::AsyncSocket* socket_{nullptr};
3358   folly::AsyncSocketException exception_;
3359   int gotTimestamp_{0};
3360   int gotByteSeq_{0};
3361   int resetCallbackAfter_{-1};
3362   int closeSocketAfter_{-1};
3363 };
3364 
TEST_P(AsyncSocketErrMessageCallbackTest,ErrMessageCallback)3365 TEST_P(AsyncSocketErrMessageCallbackTest, ErrMessageCallback) {
3366   TestServer server;
3367 
3368   // connect()
3369   EventBase evb;
3370   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3371 
3372   ConnCallback ccb;
3373   socket->connect(&ccb, server.getAddress(), 30);
3374   LOG(INFO) << "Client socket fd=" << socket->getNetworkSocket();
3375 
3376   // Let the socket
3377   evb.loop();
3378 
3379   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
3380 
3381   // Set read callback to keep the socket subscribed for event
3382   // notifications. Though we're no planning to read anything from
3383   // this side of the connection.
3384   ReadCallback rcb(1);
3385   socket->setReadCB(&rcb);
3386 
3387   // Set up timestamp callbacks
3388   TestErrMessageCallback errMsgCB;
3389   socket->setErrMessageCB(&errMsgCB);
3390   ASSERT_EQ(
3391       socket->getErrMessageCallback(),
3392       static_cast<folly::AsyncSocket::ErrMessageCallback*>(&errMsgCB));
3393 
3394   // set the number of error messages before socket is closed or callback reset
3395   const auto testParams = GetParam();
3396   errMsgCB.socket_ = socket.get();
3397   if (testParams.resetCallbackAfter.has_value()) {
3398     errMsgCB.resetCallbackAfter_ = testParams.resetCallbackAfter.value();
3399   }
3400   if (testParams.closeSocketAfter.has_value()) {
3401     errMsgCB.closeSocketAfter_ = testParams.closeSocketAfter.value();
3402   }
3403 
3404   // Enable timestamp notifications
3405   ASSERT_NE(socket->getNetworkSocket(), NetworkSocket());
3406   int flags = folly::netops::SOF_TIMESTAMPING_OPT_ID |
3407       folly::netops::SOF_TIMESTAMPING_OPT_TSONLY |
3408       folly::netops::SOF_TIMESTAMPING_SOFTWARE |
3409       folly::netops::SOF_TIMESTAMPING_OPT_CMSG |
3410       folly::netops::SOF_TIMESTAMPING_TX_SCHED;
3411   SocketOptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
3412   EXPECT_EQ(tstampingOpt.apply(socket->getNetworkSocket(), flags), 0);
3413 
3414   // write()
3415   std::vector<uint8_t> wbuf(128, 'a');
3416   WriteCallback wcb;
3417   // Send two packets to get two EOM notifications
3418   socket->write(&wcb, wbuf.data(), wbuf.size() / 2);
3419   socket->write(&wcb, wbuf.data() + wbuf.size() / 2, wbuf.size() / 2);
3420 
3421   // Accept the connection.
3422   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
3423   LOG(INFO) << "Server socket fd=" << acceptedSocket->getNetworkSocket();
3424 
3425   // Loop
3426   evb.loopOnce();
3427   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3428 
3429   // Check that we can read the data that was written to the socket
3430   std::vector<uint8_t> rbuf(wbuf.size(), 0);
3431   uint32_t bytesRead = acceptedSocket->readAll(rbuf.data(), rbuf.size());
3432   ASSERT_EQ(bytesRead, wbuf.size());
3433   ASSERT_TRUE(std::equal(wbuf.begin(), wbuf.end(), rbuf.begin()));
3434 
3435   // Close both sockets
3436   acceptedSocket->close();
3437   socket->close();
3438 
3439   ASSERT_TRUE(socket->isClosedBySelf());
3440   ASSERT_FALSE(socket->isClosedByPeer());
3441 
3442   // Check for the timestamp notifications.
3443   ASSERT_EQ(
3444       errMsgCB.exception_.getType(), folly::AsyncSocketException::UNKNOWN);
3445   ASSERT_EQ(errMsgCB.gotByteSeq_, testParams.gotByteSeqExpected);
3446   ASSERT_EQ(errMsgCB.gotTimestamp_, testParams.gotTimestampExpected);
3447 }
3448 
3449 #endif // FOLLY_HAVE_MSG_ERRQUEUE
3450 
3451 #if FOLLY_HAVE_SO_TIMESTAMPING
3452 
3453 class AsyncSocketByteEventTest : public ::testing::Test {
3454  protected:
3455   using MockDispatcher = ::testing::NiceMock<netops::test::MockDispatcher>;
3456   using TestObserver = MockAsyncTransportObserverForByteEvents;
3457   using ByteEventType = AsyncTransport::ByteEvent::Type;
3458 
3459   /**
3460    * Components of a client connection to TestServer.
3461    *
3462    * Includes EventBase, client's AsyncSocket, and corresponding server socket.
3463    */
3464   class ClientConn {
3465    public:
3466     /**
3467      * Call to sendmsg intercepted and recorded by netops::Dispatcher.
3468      */
3469     struct SendmsgInvocation {
3470       // the iovecs in the msghdr
3471       std::vector<iovec> iovs;
3472 
3473       // WriteFlags encoded in msg_flags
3474       WriteFlags writeFlagsInMsgFlags{WriteFlags::NONE};
3475 
3476       // WriteFlags encoded in the msghdr's ancillary data
3477       WriteFlags writeFlagsInAncillary{WriteFlags::NONE};
3478     };
3479 
ClientConn(std::shared_ptr<TestServer> server,std::shared_ptr<AsyncSocket> socket=nullptr,std::shared_ptr<BlockingSocket> acceptedSocket=nullptr)3480     explicit ClientConn(
3481         std::shared_ptr<TestServer> server,
3482         std::shared_ptr<AsyncSocket> socket = nullptr,
3483         std::shared_ptr<BlockingSocket> acceptedSocket = nullptr)
3484         : server_(std::move(server)),
3485           socket_(std::move(socket)),
3486           acceptedSocket_(std::move(acceptedSocket)) {
3487       if (!socket_) {
3488         socket_ = AsyncSocket::newSocket(&getEventBase());
3489       } else {
3490         setReadCb();
3491       }
3492       socket_->setOverrideNetOpsDispatcher(netOpsDispatcher_);
3493       netOpsDispatcher_->forwardToDefaultImpl();
3494     }
3495 
connect()3496     void connect() {
3497       CHECK_NOTNULL(socket_.get());
3498       CHECK_NOTNULL(socket_->getEventBase());
3499       socket_->connect(&connCb_, server_->getAddress(), 30);
3500       socket_->getEventBase()->loop();
3501       ASSERT_EQ(connCb_.state, STATE_SUCCEEDED);
3502       setReadCb();
3503 
3504       // accept the socket at the server
3505       acceptedSocket_ = server_->accept();
3506     }
3507 
setReadCb()3508     void setReadCb() {
3509       // Due to how libevent works, we currently need to be subscribed to
3510       // EV_READ events in order to get error messages.
3511       //
3512       // TODO(bschlinker): Resolve this with libevent modification.
3513       // See https://github.com/libevent/libevent/issues/1038 for details.
3514       socket_->setReadCB(&readCb_);
3515     }
3516 
attachObserver(bool enableByteEvents,bool enablePrewrite=false)3517     std::shared_ptr<NiceMock<TestObserver>> attachObserver(
3518         bool enableByteEvents, bool enablePrewrite = false) {
3519       auto observer = AsyncSocketByteEventTest::attachObserver(
3520           socket_.get(), enableByteEvents, enablePrewrite);
3521       observers_.push_back(observer);
3522       return observer;
3523     }
3524 
3525     /**
3526      * Write to client socket and read at server.
3527      */
write(const iovec * iov,const size_t count,const WriteFlags writeFlags)3528     void write(
3529         const iovec* iov, const size_t count, const WriteFlags writeFlags) {
3530       CHECK_NOTNULL(socket_.get());
3531       CHECK_NOTNULL(socket_->getEventBase());
3532 
3533       // read buffer for server
3534       std::vector<uint8_t> rbuf(iovsToNumBytes(iov, count), 0);
3535       uint64_t rbufReadBytes = 0;
3536 
3537       // write to the client socket, incrementally read at the server
3538       WriteCallback wcb;
3539       socket_->writev(&wcb, iov, count, writeFlags);
3540       while (wcb.state == STATE_WAITING) {
3541         socket_->getEventBase()->loopOnce();
3542         rbufReadBytes += acceptedSocket_->readNoBlock(
3543             rbuf.data() + rbufReadBytes, rbuf.size() - rbufReadBytes);
3544       }
3545       ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3546 
3547       // finish reading, then compare
3548       rbufReadBytes += acceptedSocket_->readAll(
3549           rbuf.data() + rbufReadBytes, rbuf.size() - rbufReadBytes);
3550       const auto cBuf = iovsToVector(iov, count);
3551       ASSERT_EQ(rbufReadBytes, cBuf.size());
3552       ASSERT_TRUE(std::equal(cBuf.begin(), cBuf.end(), rbuf.begin()));
3553     }
3554 
3555     /**
3556      * Write to client socket and read at server.
3557      */
write(const std::vector<uint8_t> & wbuf,const WriteFlags writeFlags)3558     void write(const std::vector<uint8_t>& wbuf, const WriteFlags writeFlags) {
3559       iovec op;
3560       op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
3561       op.iov_len = wbuf.size();
3562       write(&op, 1, writeFlags);
3563     }
3564 
3565     /**
3566      * Write to client socket, echo at server, and wait for echo at client.
3567      *
3568      * Waiting for echo at client ensures that we have given opportunity for
3569      * timestamps to be generated by the kernel.
3570      */
writeAndReflect(const iovec * iov,const size_t count,const WriteFlags writeFlags)3571     void writeAndReflect(
3572         const iovec* iov, const size_t count, const WriteFlags writeFlags) {
3573       write(iov, count, writeFlags);
3574 
3575       // reflect
3576       const auto wbuf = iovsToVector(iov, count);
3577       acceptedSocket_->write(wbuf.data(), wbuf.size());
3578       while (wbuf.size() != readCb_.dataRead()) {
3579         socket_->getEventBase()->loopOnce();
3580       }
3581       readCb_.verifyData(wbuf.data(), wbuf.size());
3582       readCb_.clearData();
3583     }
3584 
3585     /**
3586      * Write to client socket, echo at server, and wait for echo at client.
3587      *
3588      * Waiting for echo at client ensures that we have given opportunity for
3589      * timestamps to be generated by the kernel.
3590      */
writeAndReflect(const std::vector<uint8_t> & wbuf,const WriteFlags writeFlags)3591     void writeAndReflect(
3592         const std::vector<uint8_t>& wbuf, const WriteFlags writeFlags) {
3593       iovec op = {};
3594       op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
3595       op.iov_len = wbuf.size();
3596       writeAndReflect(&op, 1, writeFlags);
3597     }
3598 
getRawSocket()3599     std::shared_ptr<AsyncSocket> getRawSocket() { return socket_; }
3600 
getAcceptedSocket()3601     std::shared_ptr<BlockingSocket> getAcceptedSocket() {
3602       return acceptedSocket_;
3603     }
3604 
getEventBase()3605     EventBase& getEventBase() {
3606       static EventBase evb; // use same EventBase for all client sockets
3607       return evb;
3608     }
3609 
getNetOpsDispatcher() const3610     std::shared_ptr<MockDispatcher> getNetOpsDispatcher() const {
3611       return netOpsDispatcher_;
3612     }
3613 
3614     /**
3615      * Get recorded SendmsgInvocations.
3616      */
getSendmsgInvocations()3617     const std::vector<SendmsgInvocation>& getSendmsgInvocations() {
3618       return sendmsgInvocations_;
3619     }
3620 
3621     /**
3622      * Expect a call to setsockopt with optname SO_TIMESTAMPING.
3623      */
netOpsExpectTimestampingSetSockOpt()3624     void netOpsExpectTimestampingSetSockOpt() {
3625       // must whitelist other calls
3626       EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
3627           .Times(AnyNumber());
3628       EXPECT_CALL(
3629           *netOpsDispatcher_, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
3630           .Times(1);
3631     }
3632 
3633     /**
3634      * Expect NO calls to setsockopt with optname SO_TIMESTAMPING.
3635      */
netOpsExpectNoTimestampingSetSockOpt()3636     void netOpsExpectNoTimestampingSetSockOpt() {
3637       // must whitelist other calls
3638       EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
3639           .Times(AnyNumber());
3640       EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, SO_TIMESTAMPING, _, _))
3641           .Times(0);
3642     }
3643 
3644     /**
3645      * Expect sendmsg to be called with the passed WriteFlags in ancillary data.
3646      */
netOpsExpectSendmsgWithAncillaryTsFlags(WriteFlags writeFlags)3647     void netOpsExpectSendmsgWithAncillaryTsFlags(WriteFlags writeFlags) {
3648       auto getMsgAncillaryTsFlags = std::bind(
3649           (WriteFlags(*)(const struct msghdr* msg)) & ::getMsgAncillaryTsFlags,
3650           std::placeholders::_1);
3651       EXPECT_CALL(
3652           *netOpsDispatcher_,
3653           sendmsg(_, ResultOf(getMsgAncillaryTsFlags, Eq(writeFlags)), _))
3654           .WillOnce(DoDefault());
3655     }
3656 
3657     /**
3658      * When sendmsg is called, record details and then forward to real sendmsg.
3659      *
3660      * This creates a default action.
3661      */
netOpsOnSendmsgRecordIovecsAndFlagsAndFwd()3662     void netOpsOnSendmsgRecordIovecsAndFlagsAndFwd() {
3663       ON_CALL(*netOpsDispatcher_, sendmsg(_, _, _))
3664           .WillByDefault(::testing::Invoke(
3665               [this](NetworkSocket s, const msghdr* message, int flags) {
3666                 recordSendmsgInvocation(s, message, flags);
3667                 return netops::Dispatcher::getDefaultInstance()->sendmsg(
3668                     s, message, flags);
3669               }));
3670     }
3671 
netOpsVerifyAndClearExpectations()3672     void netOpsVerifyAndClearExpectations() {
3673       Mock::VerifyAndClearExpectations(netOpsDispatcher_.get());
3674     }
3675 
3676    private:
recordSendmsgInvocation(NetworkSocket,const msghdr * message,int flags)3677     void recordSendmsgInvocation(
3678         NetworkSocket /* s */, const msghdr* message, int flags) {
3679       SendmsgInvocation invoc = {};
3680       invoc.iovs = getMsgIovecs(message);
3681       invoc.writeFlagsInMsgFlags = msgFlagsToWriteFlags(flags);
3682       invoc.writeFlagsInAncillary = getMsgAncillaryTsFlags(message);
3683       sendmsgInvocations_.emplace_back(std::move(invoc));
3684     }
3685 
3686     // server
3687     std::shared_ptr<TestServer> server_;
3688 
3689     // managed observers
3690     std::vector<std::shared_ptr<TestObserver>> observers_;
3691 
3692     // socket components
3693     ConnCallback connCb_;
3694     ReadCallback readCb_;
3695     std::shared_ptr<MockDispatcher> netOpsDispatcher_{
3696         std::make_shared<MockDispatcher>()};
3697     std::shared_ptr<AsyncSocket> socket_;
3698 
3699     // accepted socket at server
3700     std::shared_ptr<BlockingSocket> acceptedSocket_;
3701 
3702     // sendmsg invocations observed
3703     std::vector<SendmsgInvocation> sendmsgInvocations_;
3704   };
3705 
getClientConn()3706   ClientConn getClientConn() { return ClientConn(server_); }
3707 
3708   /**
3709    * Static utility functions.
3710    */
3711 
attachObserver(AsyncSocket * socket,bool enableByteEvents,bool enablePrewrite=false)3712   static std::shared_ptr<NiceMock<TestObserver>> attachObserver(
3713       AsyncSocket* socket, bool enableByteEvents, bool enablePrewrite = false) {
3714     AsyncTransport::LifecycleObserver::Config config = {};
3715     config.byteEvents = enableByteEvents;
3716     config.prewrite = enablePrewrite;
3717     return std::make_shared<NiceMock<TestObserver>>(socket, config);
3718   }
3719 
getHundredBytesOfData()3720   static std::vector<uint8_t> getHundredBytesOfData() {
3721     return std::vector<uint8_t>(
3722         kOneHundredCharacterString.begin(), kOneHundredCharacterString.end());
3723   }
3724 
get10KBOfData()3725   static std::vector<uint8_t> get10KBOfData() {
3726     std::vector<uint8_t> vec;
3727     vec.reserve(kOneHundredCharacterString.size() * 100);
3728     for (auto i = 0; i < 100; i++) {
3729       vec.insert(
3730           vec.end(),
3731           kOneHundredCharacterString.begin(),
3732           kOneHundredCharacterString.end());
3733     }
3734     CHECK_EQ(10000, vec.size());
3735     return vec;
3736   }
3737 
get1000KBOfData()3738   static std::vector<uint8_t> get1000KBOfData() {
3739     std::vector<uint8_t> vec;
3740     vec.reserve(kOneHundredCharacterString.size() * 10000);
3741     for (auto i = 0; i < 10000; i++) {
3742       vec.insert(
3743           vec.end(),
3744           kOneHundredCharacterString.begin(),
3745           kOneHundredCharacterString.end());
3746     }
3747     CHECK_EQ(1000000, vec.size());
3748     return vec;
3749   }
3750 
dropWriteFromFlags(WriteFlags writeFlags)3751   static WriteFlags dropWriteFromFlags(WriteFlags writeFlags) {
3752     return writeFlags & ~WriteFlags::TIMESTAMP_WRITE;
3753   }
3754 
getMsgIovecs(const struct msghdr & msg)3755   static std::vector<iovec> getMsgIovecs(const struct msghdr& msg) {
3756     std::vector<iovec> iovecs;
3757     for (size_t i = 0; i < msg.msg_iovlen; i++) {
3758       iovecs.emplace_back(msg.msg_iov[i]);
3759     }
3760     return iovecs;
3761   }
3762 
getMsgIovecs(const struct msghdr * msg)3763   static std::vector<iovec> getMsgIovecs(const struct msghdr* msg) {
3764     return getMsgIovecs(*msg);
3765   }
3766 
iovsToVector(const iovec * iov,const size_t count)3767   static std::vector<uint8_t> iovsToVector(
3768       const iovec* iov, const size_t count) {
3769     std::vector<uint8_t> vec;
3770     for (size_t i = 0; i < count; i++) {
3771       if (iov[i].iov_len == 0) {
3772         continue;
3773       }
3774       const auto ptr = reinterpret_cast<uint8_t*>(iov[i].iov_base);
3775       vec.insert(vec.end(), ptr, ptr + iov[i].iov_len);
3776     }
3777     return vec;
3778   }
3779 
iovsToNumBytes(const iovec * iov,const size_t count)3780   static size_t iovsToNumBytes(const iovec* iov, const size_t count) {
3781     size_t bytes = 0;
3782     for (size_t i = 0; i < count; i++) {
3783       bytes += iov[i].iov_len;
3784     }
3785     return bytes;
3786   }
3787 
filterToWriteEvents(const std::vector<AsyncTransport::ByteEvent> & input)3788   std::vector<AsyncTransport::ByteEvent> filterToWriteEvents(
3789       const std::vector<AsyncTransport::ByteEvent>& input) {
3790     std::vector<AsyncTransport::ByteEvent> result;
3791     std::copy_if(
3792         input.begin(),
3793         input.end(),
3794         std::back_inserter(result),
3795         [](auto& event) {
3796           return event.type == AsyncTransport::ByteEvent::WRITE;
3797         });
3798     return result;
3799   }
3800 
3801   // server
3802   std::shared_ptr<TestServer> server_{std::make_shared<TestServer>()};
3803 };
3804 
TEST_F(AsyncSocketByteEventTest, MsgFlagsToWriteFlags) {
  // Table of (msg_flags value, WriteFlags expected from msgFlagsToWriteFlags);
  // only flags defined on this platform are exercised.
  std::vector<std::pair<int, WriteFlags>> cases;

#ifdef MSG_MORE
  cases.emplace_back(MSG_MORE, WriteFlags::CORK);
#endif // MSG_MORE

#ifdef MSG_EOR
  cases.emplace_back(MSG_EOR, WriteFlags::EOR);
#endif

#ifdef MSG_ZEROCOPY
  cases.emplace_back(MSG_ZEROCOPY, WriteFlags::WRITE_MSG_ZEROCOPY);
#endif

#if defined(MSG_MORE) && defined(MSG_EOR)
  // combined msg_flags map to the union of the individual WriteFlags
  cases.emplace_back(MSG_MORE | MSG_EOR, WriteFlags::CORK | WriteFlags::EOR);
#endif

  for (const auto& [msgFlags, expectedWriteFlags] : cases) {
    EXPECT_EQ(expectedWriteFlags, msgFlagsToWriteFlags(msgFlags));
  }
}
3824 
TEST_F(AsyncSocketByteEventTest, GetMsgAncillaryTsFlags) {
  // Scratch space for one control message carrying a uint32_t payload,
  // reused by every msghdr built below.
  const auto controlLen = CMSG_LEN(sizeof(uint32_t));
  auto controlBuf = reinterpret_cast<char*>(alloca(controlLen));

  // Build a payload-less msghdr; when sofFlags is non-zero, attach a single
  // SOL_SOCKET / SO_TIMESTAMPING control message holding sofFlags.
  auto makeMsg = [&](uint32_t sofFlags) {
    struct msghdr msg = {};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = nullptr;
    msg.msg_iovlen = 0;
    msg.msg_flags = 0;
    msg.msg_controllen = 0;
    msg.msg_control = nullptr;
    if (sofFlags == 0) {
      return msg;
    }

    msg.msg_controllen = controlLen;
    msg.msg_control = controlBuf;
    struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
    CHECK_NOTNULL(cmsg);
    cmsg->cmsg_level = SOL_SOCKET;
    cmsg->cmsg_type = SO_TIMESTAMPING;
    cmsg->cmsg_len = CMSG_LEN(sizeof(uint32_t));
    memcpy(CMSG_DATA(cmsg), &sofFlags, sizeof(sofFlags));
    return msg;
  };

  // Check that the SOF_* flags in the ancillary data decode to `expected`.
  auto expectTsFlags = [&](uint32_t sofFlags, WriteFlags expected) {
    auto msg = makeMsg(sofFlags);
    EXPECT_EQ(expected, getMsgAncillaryTsFlags(msg));
  };

  expectTsFlags(
      folly::netops::SOF_TIMESTAMPING_TX_SCHED, WriteFlags::TIMESTAMP_SCHED);
  expectTsFlags(
      folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE, WriteFlags::TIMESTAMP_TX);
  expectTsFlags(
      folly::netops::SOF_TIMESTAMPING_TX_ACK, WriteFlags::TIMESTAMP_ACK);
  expectTsFlags(
      folly::netops::SOF_TIMESTAMPING_TX_SCHED |
          folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE |
          folly::netops::SOF_TIMESTAMPING_TX_ACK,
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX |
          WriteFlags::TIMESTAMP_ACK);
}
3881 
TEST_F(AsyncSocketByteEventTest, ObserverAttachedBeforeConnect) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  // An observer with ByteEvents enabled is attached before connect(), so
  // SO_TIMESTAMPING must be set exactly once while connecting.
  auto conn = getClientConn();
  auto observer = conn.attachObserver(true /* enableByteEvents */);
  conn.netOpsExpectTimestampingSetSockOpt();
  conn.connect();
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  conn.netOpsVerifyAndClearExpectations();

  // Two one-byte writes: each yields one event of every type, and write
  // number `writeIdx` (0-based) is reported at byte offset `writeIdx`.
  for (uint64_t writeIdx = 0; writeIdx < 2; ++writeIdx) {
    conn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
    conn.writeAndReflect(wbuf, flags);
    conn.netOpsVerifyAndClearExpectations();
    EXPECT_THAT(observer->byteEvents, SizeIs(4 * (writeIdx + 1)));
    for (const auto type :
         {ByteEventType::WRITE,
          ByteEventType::SCHED,
          ByteEventType::TX,
          ByteEventType::ACK}) {
      EXPECT_EQ(writeIdx, observer->maxOffsetForByteEventReceived(type));
    }
  }
}
3915 
TEST_F(AsyncSocketByteEventTest, ObserverAttachedAfterConnect) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  // Without any observer, connecting must not touch SO_TIMESTAMPING.
  auto conn = getClientConn();
  conn.netOpsExpectNoTimestampingSetSockOpt();
  conn.connect();
  conn.netOpsVerifyAndClearExpectations();

  // Attaching a ByteEvent-enabled observer to the connected socket enables
  // timestamping right away.
  conn.netOpsExpectTimestampingSetSockOpt();
  auto observer = conn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  conn.netOpsVerifyAndClearExpectations();

  // Two one-byte writes: each yields one event of every type, and write
  // number `writeIdx` (0-based) is reported at byte offset `writeIdx`.
  for (uint64_t writeIdx = 0; writeIdx < 2; ++writeIdx) {
    conn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
    conn.writeAndReflect(wbuf, flags);
    conn.netOpsVerifyAndClearExpectations();
    EXPECT_THAT(observer->byteEvents, SizeIs(4 * (writeIdx + 1)));
    for (const auto type :
         {ByteEventType::WRITE,
          ByteEventType::SCHED,
          ByteEventType::TX,
          ByteEventType::ACK}) {
      EXPECT_EQ(writeIdx, observer->maxOffsetForByteEventReceived(type));
    }
  }
}
3952 
TEST_F(
    AsyncSocketByteEventTest, ObserverAttachedBeforeConnectByteEventsDisabled) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  // observer attached before connect, but with ByteEvents disabled, so
  // connecting must not enable SO_TIMESTAMPING
  auto clientConn = getClientConn();
  auto observer = clientConn.attachObserver(false /* enableByteEvents */);
  clientConn.netOpsExpectNoTimestampingSetSockOpt();

  clientConn.connect(); // connect after observer attached
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // a write requesting timestamps must not carry timestamping ancillary data
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
      WriteFlags::NONE); // events disabled
  clientConn.writeAndReflect(wbuf, flags);
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  clientConn.netOpsVerifyAndClearExpectations();

  // now enable ByteEvents with another observer, then write again
  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer2 = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled); // observer 1 doesn't want
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled); // should be set
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());
  EXPECT_NE(WriteFlags::NONE, flags);
  EXPECT_NE(WriteFlags::NONE, dropWriteFromFlags(flags));
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // expect no ByteEvents for first observer, four for the second; the second
  // write is the socket's second byte, hence offset 1
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  EXPECT_THAT(observer2->byteEvents, SizeIs(4));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
}
3998 
TEST_F(
    AsyncSocketByteEventTest, ObserverAttachedAfterConnectByteEventsDisabled) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();

  clientConn.connect(); // connect before observer attached

  // an observer with ByteEvents disabled must not enable SO_TIMESTAMPING
  auto observer = clientConn.attachObserver(false /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // a write requesting timestamps must not carry timestamping ancillary data
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
      WriteFlags::NONE); // events disabled
  clientConn.writeAndReflect(wbuf, flags);
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  clientConn.netOpsVerifyAndClearExpectations();

  // now enable ByteEvents with another observer, then write again
  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer2 = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled); // observer 1 doesn't want
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled); // should be set
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());
  EXPECT_NE(WriteFlags::NONE, flags);
  EXPECT_NE(WriteFlags::NONE, dropWriteFromFlags(flags));
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // expect no ByteEvents for first observer, four for the second; the second
  // write is the socket's second byte, hence offset 1
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  EXPECT_THAT(observer2->byteEvents, SizeIs(4));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
}
4045 
TEST_F(AsyncSocketByteEventTest, ObserverAttachedAfterWrite) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto conn = getClientConn();

  // No observer yet: connecting must not enable SO_TIMESTAMPING...
  conn.netOpsExpectNoTimestampingSetSockOpt();
  conn.connect(); // connect before observer attached
  conn.netOpsVerifyAndClearExpectations();

  // ...and a write requesting timestamps carries no timestamping ancillary
  // data.
  conn.netOpsExpectSendmsgWithAncillaryTsFlags(
      WriteFlags::NONE); // events disabled
  conn.writeAndReflect(wbuf, flags);
  conn.netOpsVerifyAndClearExpectations();

  // Attaching a ByteEvent-enabled observer enables timestamping immediately.
  conn.netOpsExpectTimestampingSetSockOpt();
  auto observer = conn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  conn.netOpsVerifyAndClearExpectations();

  // The next write is timestamped; it is the socket's second byte, so every
  // event type is reported at offset 1.
  conn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  conn.writeAndReflect(wbuf, flags);
  conn.netOpsVerifyAndClearExpectations();

  EXPECT_THAT(observer->byteEvents, SizeIs(4));
  for (const auto type :
       {ByteEventType::WRITE,
        ByteEventType::SCHED,
        ByteEventType::TX,
        ByteEventType::ACK}) {
    EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(type));
  }
}
4078 
TEST_F(AsyncSocketByteEventTest, ObserverAttachedAfterClose) {
  auto conn = getClientConn();
  conn.connect();

  // close the client socket before any observer exists
  const auto rawSocket = conn.getRawSocket();
  rawSocket->close();
  EXPECT_TRUE(rawSocket->isClosedBySelf());

  // an observer attached to an already-closed socket gets neither the
  // enabled nor the unavailable notification
  const auto observer = conn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
}
4090 
TEST_F(AsyncSocketByteEventTest, MultipleObserverAttached) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  // attach observer 1 before connect
  auto clientConn = getClientConn();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  clientConn.netOpsExpectTimestampingSetSockOpt();
  clientConn.connect();
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // attach observer 2 after connect
  auto observer2 = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());

  // write
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // both observers see one event of each type for the 50-byte write, with
  // the last byte at offset 49
  // check observer1
  EXPECT_THAT(observer->byteEvents, SizeIs(4));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));

  // check observer2
  EXPECT_THAT(observer2->byteEvents, SizeIs(4));
  EXPECT_EQ(
      49U, observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(
      49U, observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(49U, observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(49U, observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
}
4133 
4134 /**
4135  * Test when kernel offset (uint32_t) wraps around.
4136  */
TEST_F(AsyncSocketByteEventTest,KernelOffsetWrap)4137 TEST_F(AsyncSocketByteEventTest, KernelOffsetWrap) {
4138   auto clientConn = getClientConn();
4139   clientConn.connect();
4140   clientConn.netOpsExpectTimestampingSetSockOpt();
4141   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
4142   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
4143   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
4144   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4145   clientConn.netOpsVerifyAndClearExpectations();
4146 
4147   const uint64_t wbufSize = 3000000;
4148   const std::vector<uint8_t> wbuf(wbufSize, 'a');
4149 
4150   // part 1: write close to the wrap point with no ByteEvents to speed things up
4151   const uint64_t bytesToWritePt1 =
4152       static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) -
4153       (wbufSize * 5);
4154   while (clientConn.getRawSocket()->getRawBytesWritten() < bytesToWritePt1) {
4155     clientConn.write(wbuf, WriteFlags::NONE); // no reflect needed
4156   }
4157 
4158   // part 2: write over the wrap point with ByteEvents
4159   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4160       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
4161   const uint64_t bytesToWritePt2 =
4162       static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) +
4163       (wbufSize * 5);
4164   while (clientConn.getRawSocket()->getRawBytesWritten() < bytesToWritePt2) {
4165     clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
4166         dropWriteFromFlags(flags));
4167     clientConn.writeAndReflect(wbuf, flags);
4168     clientConn.netOpsVerifyAndClearExpectations();
4169     const uint64_t expectedOffset =
4170         clientConn.getRawSocket()->getRawBytesWritten() - 1;
4171     EXPECT_EQ(
4172         expectedOffset,
4173         observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4174     EXPECT_EQ(
4175         expectedOffset,
4176         observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4177     EXPECT_EQ(
4178         expectedOffset,
4179         observer->maxOffsetForByteEventReceived(ByteEventType::TX));
4180     EXPECT_EQ(
4181         expectedOffset,
4182         observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
4183   }
4184 
4185   // part 3: one more write outside of a loop with extra checks
4186   clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
4187   clientConn.writeAndReflect(wbuf, flags);
4188   clientConn.netOpsVerifyAndClearExpectations();
4189   const auto expectedOffset =
4190       clientConn.getRawSocket()->getRawBytesWritten() - 1;
4191   EXPECT_LT(std::numeric_limits<uint32_t>::max(), expectedOffset);
4192   EXPECT_EQ(
4193       expectedOffset,
4194       observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4195   EXPECT_EQ(
4196       expectedOffset,
4197       observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4198   EXPECT_EQ(
4199       expectedOffset,
4200       observer->maxOffsetForByteEventReceived(ByteEventType::TX));
4201   EXPECT_EQ(
4202       expectedOffset,
4203       observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
4204 }
4205 
4206 /**
4207  * Ensure that ErrMessageCallback still works when ByteEvents enabled.
4208  */
TEST_F(AsyncSocketByteEventTest,ErrMessageCallbackStillTriggered)4209 TEST_F(AsyncSocketByteEventTest, ErrMessageCallbackStillTriggered) {
4210   auto clientConn = getClientConn();
4211   clientConn.connect();
4212   clientConn.netOpsExpectTimestampingSetSockOpt();
4213   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
4214   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
4215   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
4216   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4217   clientConn.netOpsVerifyAndClearExpectations();
4218 
4219   TestErrMessageCallback errMsgCB;
4220   clientConn.getRawSocket()->setErrMessageCB(&errMsgCB);
4221 
4222   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4223       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
4224 
4225   std::vector<uint8_t> wbuf(1, 'a');
4226   EXPECT_NE(WriteFlags::NONE, flags);
4227   EXPECT_NE(WriteFlags::NONE, dropWriteFromFlags(flags));
4228   clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
4229   clientConn.writeAndReflect(wbuf, flags);
4230   clientConn.netOpsVerifyAndClearExpectations();
4231 
4232   // observer should get events
4233   EXPECT_THAT(observer->byteEvents, SizeIs(4));
4234   EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4235   EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4236   EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
4237   EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
4238 
4239   // err message callbach should get events, too
4240   EXPECT_EQ(3, errMsgCB.gotByteSeq_);
4241   EXPECT_EQ(3, errMsgCB.gotTimestamp_);
4242 
4243   // write again, more events for both
4244   clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
4245   clientConn.writeAndReflect(wbuf, flags);
4246   clientConn.netOpsVerifyAndClearExpectations();
4247   EXPECT_THAT(observer->byteEvents, SizeIs(8));
4248   EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4249   EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4250   EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
4251   EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
4252   EXPECT_EQ(6, errMsgCB.gotByteSeq_);
4253   EXPECT_EQ(6, errMsgCB.gotTimestamp_);
4254 }
4255 
4256 /**
4257  * Ensure that ByteEvents disabled for unix sockets (not supported).
4258  */
TEST_F(AsyncSocketByteEventTest,FailUnixSocket)4259 TEST_F(AsyncSocketByteEventTest, FailUnixSocket) {
4260   std::shared_ptr<NiceMock<TestObserver>> observer;
4261   auto netOpsDispatcher = std::make_shared<MockDispatcher>();
4262 
4263   NetworkSocket fd[2];
4264   EXPECT_EQ(netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fd), 0);
4265   ASSERT_NE(fd[0], NetworkSocket());
4266   ASSERT_NE(fd[1], NetworkSocket());
4267   SCOPE_EXIT { netops::close(fd[1]); };
4268 
4269   EXPECT_EQ(netops::set_socket_non_blocking(fd[0]), 0);
4270   EXPECT_EQ(netops::set_socket_non_blocking(fd[1]), 0);
4271 
4272   auto clientSocketRaw = AsyncSocket::newSocket(nullptr, fd[0]);
4273   auto clientBlockingSocket = BlockingSocket(std::move(clientSocketRaw));
4274   clientBlockingSocket.getSocket()->setOverrideNetOpsDispatcher(
4275       netOpsDispatcher);
4276 
4277   // make sure no SO_TIMESTAMPING setsockopt on observer attach
4278   EXPECT_CALL(*netOpsDispatcher, setsockopt(_, _, _, _, _)).Times(AnyNumber());
4279   EXPECT_CALL(
4280       *netOpsDispatcher, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
4281       .Times(0); // no calls
4282   observer = attachObserver(
4283       clientBlockingSocket.getSocket(), true /* enableByteEvents */);
4284   EXPECT_EQ(0, observer->byteEventsEnabledCalled);
4285   EXPECT_EQ(1, observer->byteEventsUnavailableCalled);
4286   EXPECT_TRUE(observer->byteEventsUnavailableCalledEx.has_value());
4287   Mock::VerifyAndClearExpectations(netOpsDispatcher.get());
4288 
4289   // do a write, we should see it has no timestamp flags
4290   const std::vector<uint8_t> wbuf(1, 'a');
4291   EXPECT_CALL(*netOpsDispatcher, sendmsg(_, _, _))
4292       .WillOnce(WithArgs<1>(Invoke([](const msghdr* message) {
4293         EXPECT_EQ(WriteFlags::NONE, getMsgAncillaryTsFlags(*message));
4294         return 1;
4295       })));
4296   clientBlockingSocket.write(
4297       wbuf.data(),
4298       wbuf.size(),
4299       WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4300           WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK);
4301   Mock::VerifyAndClearExpectations(netOpsDispatcher.get());
4302 }
4303 
4304 /**
4305  * If socket timestamps already enabled, do not enable ByteEvents.
4306  */
TEST_F(AsyncSocketByteEventTest,FailTimestampsAlreadyEnabled)4307 TEST_F(AsyncSocketByteEventTest, FailTimestampsAlreadyEnabled) {
4308   auto clientConn = getClientConn();
4309   clientConn.connect();
4310 
4311   // enable timestamps via setsockopt
4312   const uint32_t flags = folly::netops::SOF_TIMESTAMPING_OPT_ID |
4313       folly::netops::SOF_TIMESTAMPING_OPT_TSONLY |
4314       folly::netops::SOF_TIMESTAMPING_SOFTWARE |
4315       folly::netops::SOF_TIMESTAMPING_RAW_HARDWARE |
4316       folly::netops::SOF_TIMESTAMPING_OPT_TX_SWHW;
4317   const auto ret = clientConn.getRawSocket()->setSockOpt(
4318       SOL_SOCKET, SO_TIMESTAMPING, &flags);
4319   EXPECT_EQ(0, ret);
4320 
4321   clientConn.netOpsExpectNoTimestampingSetSockOpt();
4322   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
4323   EXPECT_EQ(0, observer->byteEventsEnabledCalled);
4324   EXPECT_EQ(1, observer->byteEventsUnavailableCalled); // fail
4325   EXPECT_TRUE(observer->byteEventsUnavailableCalledEx.has_value());
4326   clientConn.netOpsVerifyAndClearExpectations();
4327 
4328   std::vector<uint8_t> wbuf(1, 'a');
4329   clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(WriteFlags::NONE);
4330   clientConn.writeAndReflect(
4331       wbuf,
4332       WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4333           WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK);
4334   clientConn.netOpsVerifyAndClearExpectations();
4335   EXPECT_THAT(observer->byteEvents, IsEmpty());
4336 }
4337 
4338 /**
4339  * Verify that ByteEvent information is properly copied during socket moves.
4340  */
4341 
TEST_F(AsyncSocketByteEventTest,MoveByteEventsEnabled)4342 TEST_F(AsyncSocketByteEventTest, MoveByteEventsEnabled) {
4343   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4344       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
4345   const std::vector<uint8_t> wbuf(50, 'a');
4346 
4347   auto clientConn = getClientConn();
4348   clientConn.connect();
4349 
4350   // observer with ByteEvents enabled
4351   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
4352   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
4353   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
4354   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4355 
4356   // move the socket immediately and add an observer with ByteEvents enabled
4357   auto clientConn2 = ClientConn(
4358       server_,
4359       AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
4360       clientConn.getAcceptedSocket());
4361   auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
4362   EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
4363   EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
4364   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4365 
4366   // write following move, make sure the offsets are correct
4367   clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
4368       dropWriteFromFlags(flags));
4369   clientConn2.writeAndReflect(wbuf, flags);
4370   clientConn2.netOpsVerifyAndClearExpectations();
4371   EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
4372   {
4373     const auto expectedOffset = 49U;
4374     EXPECT_EQ(
4375         expectedOffset,
4376         observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4377     EXPECT_EQ(
4378         expectedOffset,
4379         observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4380     EXPECT_EQ(
4381         expectedOffset,
4382         observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
4383     EXPECT_EQ(
4384         expectedOffset,
4385         observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
4386   }
4387 
4388   // write again
4389   clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
4390       dropWriteFromFlags(flags));
4391   clientConn2.writeAndReflect(wbuf, flags);
4392   clientConn2.netOpsVerifyAndClearExpectations();
4393   EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
4394   {
4395     const auto expectedOffset = 99U;
4396     EXPECT_EQ(
4397         expectedOffset,
4398         observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4399     EXPECT_EQ(
4400         expectedOffset,
4401         observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4402     EXPECT_EQ(
4403         expectedOffset,
4404         observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
4405     EXPECT_EQ(
4406         expectedOffset,
4407         observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
4408   }
4409 }
4410 
TEST_F(AsyncSocketByteEventTest,WriteThenMoveByteEventsEnabled)4411 TEST_F(AsyncSocketByteEventTest, WriteThenMoveByteEventsEnabled) {
4412   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4413       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
4414   const std::vector<uint8_t> wbuf(50, 'a');
4415 
4416   auto clientConn = getClientConn();
4417   clientConn.connect();
4418 
4419   // observer with ByteEvents enabled
4420   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
4421   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
4422   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
4423   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4424 
4425   // write
4426   clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
4427   clientConn.writeAndReflect(wbuf, flags);
4428   clientConn.netOpsVerifyAndClearExpectations();
4429   EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
4430   {
4431     const auto expectedOffset = 49U;
4432     EXPECT_EQ(
4433         expectedOffset,
4434         observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4435     EXPECT_EQ(
4436         expectedOffset,
4437         observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4438     EXPECT_EQ(
4439         expectedOffset,
4440         observer->maxOffsetForByteEventReceived(ByteEventType::TX));
4441     EXPECT_EQ(
4442         expectedOffset,
4443         observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
4444   }
4445 
4446   // now move the socket and add an observer with ByteEvents enabled
4447   auto clientConn2 = ClientConn(
4448       server_,
4449       AsyncSocket::UniquePtr(
4450           new AsyncSocket(std::move(clientConn.getRawSocket().get()))),
4451       clientConn.getAcceptedSocket());
4452   auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
4453   EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
4454   EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
4455   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4456 
4457   // write following move, make sure the offsets are correct
4458   clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
4459       dropWriteFromFlags(flags));
4460   clientConn2.writeAndReflect(wbuf, flags);
4461   clientConn2.netOpsVerifyAndClearExpectations();
4462   EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
4463   {
4464     const auto expectedOffset = 99U;
4465     EXPECT_EQ(
4466         expectedOffset,
4467         observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4468     EXPECT_EQ(
4469         expectedOffset,
4470         observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4471     EXPECT_EQ(
4472         expectedOffset,
4473         observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
4474     EXPECT_EQ(
4475         expectedOffset,
4476         observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
4477   }
4478 
4479   // write again
4480   clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
4481       dropWriteFromFlags(flags));
4482   clientConn2.writeAndReflect(wbuf, flags);
4483   clientConn2.netOpsVerifyAndClearExpectations();
4484   EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
4485   {
4486     const auto expectedOffset = 149U;
4487     EXPECT_EQ(
4488         expectedOffset,
4489         observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4490     EXPECT_EQ(
4491         expectedOffset,
4492         observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4493     EXPECT_EQ(
4494         expectedOffset,
4495         observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
4496     EXPECT_EQ(
4497         expectedOffset,
4498         observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
4499   }
4500 }
4501 
TEST_F(AsyncSocketByteEventTest,MoveThenEnableByteEvents)4502 TEST_F(AsyncSocketByteEventTest, MoveThenEnableByteEvents) {
4503   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4504       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
4505   const std::vector<uint8_t> wbuf(50, 'a');
4506 
4507   auto clientConn = getClientConn();
4508   clientConn.connect();
4509 
4510   // observer with ByteEvents disabled
4511   auto observer = clientConn.attachObserver(false /* enableByteEvents */);
4512   EXPECT_EQ(0, observer->byteEventsEnabledCalled);
4513   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
4514   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4515 
4516   // move the socket immediately and add an observer with ByteEvents enabled
4517   auto clientConn2 = ClientConn(
4518       server_,
4519       AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
4520       clientConn.getAcceptedSocket());
4521   auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
4522   EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
4523   EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
4524   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4525 
4526   // write following move, make sure the offsets are correct
4527   clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
4528       dropWriteFromFlags(flags));
4529   clientConn2.writeAndReflect(wbuf, flags);
4530   clientConn2.netOpsVerifyAndClearExpectations();
4531   EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
4532   {
4533     const auto expectedOffset = 49U;
4534     EXPECT_EQ(
4535         expectedOffset,
4536         observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4537     EXPECT_EQ(
4538         expectedOffset,
4539         observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4540     EXPECT_EQ(
4541         expectedOffset,
4542         observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
4543     EXPECT_EQ(
4544         expectedOffset,
4545         observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
4546   }
4547 
4548   // write again
4549   clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
4550       dropWriteFromFlags(flags));
4551   clientConn2.writeAndReflect(wbuf, flags);
4552   clientConn2.netOpsVerifyAndClearExpectations();
4553   EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
4554   {
4555     const auto expectedOffset = 99U;
4556     EXPECT_EQ(
4557         expectedOffset,
4558         observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4559     EXPECT_EQ(
4560         expectedOffset,
4561         observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4562     EXPECT_EQ(
4563         expectedOffset,
4564         observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
4565     EXPECT_EQ(
4566         expectedOffset,
4567         observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
4568   }
4569 }
4570 
TEST_F(AsyncSocketByteEventTest,WriteThenMoveThenEnableByteEvents)4571 TEST_F(AsyncSocketByteEventTest, WriteThenMoveThenEnableByteEvents) {
4572   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4573       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
4574   const std::vector<uint8_t> wbuf(50, 'a');
4575 
4576   auto clientConn = getClientConn();
4577   clientConn.connect();
4578 
4579   // observer with ByteEvents disabled
4580   auto observer = clientConn.attachObserver(false /* enableByteEvents */);
4581   EXPECT_EQ(0, observer->byteEventsEnabledCalled);
4582   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
4583   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4584 
4585   // write, ByteEvents disabled
4586   clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
4587       WriteFlags::NONE); // events diabled
4588   clientConn.writeAndReflect(wbuf, flags);
4589   clientConn.netOpsVerifyAndClearExpectations();
4590 
4591   // now move the socket and add an observer with ByteEvents enabled
4592   auto clientConn2 = ClientConn(
4593       server_,
4594       AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
4595       clientConn.getAcceptedSocket());
4596   auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
4597   EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
4598   EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
4599   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4600 
4601   // write following move, make sure the offsets are correct
4602   clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
4603       dropWriteFromFlags(flags));
4604   clientConn2.writeAndReflect(wbuf, flags);
4605   clientConn2.netOpsVerifyAndClearExpectations();
4606   EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
4607   {
4608     const auto expectedOffset = 99U;
4609     EXPECT_EQ(
4610         expectedOffset,
4611         observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4612     EXPECT_EQ(
4613         expectedOffset,
4614         observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4615     EXPECT_EQ(
4616         expectedOffset,
4617         observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
4618     EXPECT_EQ(
4619         expectedOffset,
4620         observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
4621   }
4622 
4623   // write again
4624   clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
4625       dropWriteFromFlags(flags));
4626   clientConn2.writeAndReflect(wbuf, flags);
4627   clientConn2.netOpsVerifyAndClearExpectations();
4628   EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
4629   {
4630     const auto expectedOffset = 149U;
4631     EXPECT_EQ(
4632         expectedOffset,
4633         observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4634     EXPECT_EQ(
4635         expectedOffset,
4636         observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4637     EXPECT_EQ(
4638         expectedOffset,
4639         observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
4640     EXPECT_EQ(
4641         expectedOffset,
4642         observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
4643   }
4644 }
4645 
TEST_F(AsyncSocketByteEventTest,NoObserverMoveThenEnableByteEvents)4646 TEST_F(AsyncSocketByteEventTest, NoObserverMoveThenEnableByteEvents) {
4647   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4648       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
4649   const std::vector<uint8_t> wbuf(50, 'a');
4650 
4651   auto clientConn = getClientConn();
4652   clientConn.connect();
4653 
4654   // no observer
4655 
4656   // move the socket immediately and add an observer with ByteEvents enabled
4657   auto clientConn2 = ClientConn(
4658       server_,
4659       AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
4660       clientConn.getAcceptedSocket());
4661   auto observer = clientConn2.attachObserver(true /* enableByteEvents */);
4662   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
4663   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
4664   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4665 
4666   // write following move, make sure the offsets are correct
4667   clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
4668       dropWriteFromFlags(flags));
4669   clientConn2.writeAndReflect(wbuf, flags);
4670   clientConn2.netOpsVerifyAndClearExpectations();
4671   EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
4672   {
4673     const auto expectedOffset = 49U;
4674     EXPECT_EQ(
4675         expectedOffset,
4676         observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4677     EXPECT_EQ(
4678         expectedOffset,
4679         observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4680     EXPECT_EQ(
4681         expectedOffset,
4682         observer->maxOffsetForByteEventReceived(ByteEventType::TX));
4683     EXPECT_EQ(
4684         expectedOffset,
4685         observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
4686   }
4687 
4688   // write again
4689   clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
4690       dropWriteFromFlags(flags));
4691   clientConn2.writeAndReflect(wbuf, flags);
4692   clientConn2.netOpsVerifyAndClearExpectations();
4693   EXPECT_THAT(observer->byteEvents, SizeIs(Ge(8)));
4694   {
4695     const auto expectedOffset = 99U;
4696     EXPECT_EQ(
4697         expectedOffset,
4698         observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
4699     EXPECT_EQ(
4700         expectedOffset,
4701         observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
4702     EXPECT_EQ(
4703         expectedOffset,
4704         observer->maxOffsetForByteEventReceived(ByteEventType::TX));
4705     EXPECT_EQ(
4706         expectedOffset,
4707         observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
4708   }
4709 }
4710 
4711 /**
4712  * Inspect ByteEvent fields, including xTimestampRequested in WRITE events.
4713  *
4714  * See CheckByteEventDetailsRawBytesWrittenAndTriedToWrite and
4715  * AsyncSocketByteEventDetailsTest::CheckByteEventDetails as well.
4716  */
TEST_F(AsyncSocketByteEventTest,CheckByteEventDetails)4717 TEST_F(AsyncSocketByteEventTest, CheckByteEventDetails) {
4718   const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
4719       WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
4720   const std::vector<uint8_t> wbuf(1, 'a');
4721 
4722   auto clientConn = getClientConn();
4723   clientConn.connect();
4724   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
4725   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
4726   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
4727   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4728 
4729   EXPECT_NE(WriteFlags::NONE, dropWriteFromFlags(flags));
4730   clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
4731   clientConn.writeAndReflect(wbuf, flags);
4732   clientConn.netOpsVerifyAndClearExpectations();
4733   EXPECT_THAT(observer->byteEvents, SizeIs(Eq(4)));
4734   const auto expectedOffset = wbuf.size() - 1;
4735 
4736   // check WRITE
4737   {
4738     auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
4739         expectedOffset, ByteEventType::WRITE);
4740     ASSERT_TRUE(maybeByteEvent.has_value());
4741     auto& byteEvent = maybeByteEvent.value();
4742 
4743     EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
4744     EXPECT_EQ(expectedOffset, byteEvent.offset);
4745     EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
4746     EXPECT_LT(
4747         std::chrono::steady_clock::now() - std::chrono::seconds(60),
4748         byteEvent.ts);
4749 
4750     EXPECT_EQ(flags, byteEvent.maybeWriteFlags);
4751     EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
4752     EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
4753     EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());
4754 
4755     EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
4756     EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
4757 
4758     // maybeRawBytesWritten and maybeRawBytesTriedToWrite are tested in
4759     // CheckByteEventDetailsRawBytesWrittenAndTriedToWrite
4760   }
4761 
4762   // check SCHED, TX, ACK
4763   for (const auto& byteEventType :
4764        {ByteEventType::SCHED, ByteEventType::TX, ByteEventType::ACK}) {
4765     auto maybeByteEvent =
4766         observer->getByteEventReceivedWithOffset(expectedOffset, byteEventType);
4767     ASSERT_TRUE(maybeByteEvent.has_value());
4768     auto& byteEvent = maybeByteEvent.value();
4769 
4770     EXPECT_EQ(byteEventType, byteEvent.type);
4771     EXPECT_EQ(expectedOffset, byteEvent.offset);
4772     EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
4773     EXPECT_LT(
4774         std::chrono::steady_clock::now() - std::chrono::seconds(60),
4775         byteEvent.ts);
4776 
4777     EXPECT_FALSE(byteEvent.maybeWriteFlags.has_value());
4778     EXPECT_DEATH(byteEvent.schedTimestampRequestedOnWrite(), ".*");
4779     EXPECT_DEATH(byteEvent.txTimestampRequestedOnWrite(), ".*");
4780     EXPECT_DEATH(byteEvent.ackTimestampRequestedOnWrite(), ".*");
4781 
4782     EXPECT_TRUE(byteEvent.maybeSoftwareTs.has_value());
4783     EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
4784   }
4785 }
4786 
4787 /**
4788  * Inspect ByteEvent fields maybeRawBytesWritten and maybeRawBytesTriedToWrite.
4789  */
TEST_F(AsyncSocketByteEventTest,CheckByteEventDetailsRawBytesWrittenAndTriedToWrite)4790 TEST_F(
4791     AsyncSocketByteEventTest,
4792     CheckByteEventDetailsRawBytesWrittenAndTriedToWrite) {
4793   auto clientConn = getClientConn();
4794   clientConn.connect();
4795   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
4796   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
4797   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
4798   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
4799 
4800   struct ExpectedSendmsgInvocation {
4801     size_t expectedTotalIovLen{0};
4802     ssize_t returnVal{0}; // number of bytes written or error val
4803     folly::Optional<size_t> maybeWriteEventExpectedOffset{};
4804     folly::Optional<WriteFlags> maybeWriteEventExpectedFlags{};
4805   };
4806 
4807   const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
4808       WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
4809 
4810   // first write
4811   //
4812   // no splits triggered by observer
4813   //
4814   // sendmsg will incrementally accept the bytes so we can test the values of
4815   // maybeRawBytesWritten and maybeRawBytesTriedToWrite
4816   {
4817     // bytes written per sendmsg call: 20, 10, 50, -1 (EAGAIN), 11, 99
4818     const std::vector<ExpectedSendmsgInvocation> expectedSendmsgInvocations{
4819         // {
4820         //    expectedTotalIovLen, returnVal,
4821         //    maybeWriteEventExpectedOffset, maybeWriteEventExpectedFlags
4822         // },
4823         {100, 20, 19, flags},
4824         {80, 10, 29, flags},
4825         {70, 50, 79, flags},
4826         {20, -1, folly::none, flags},
4827         {20, 11, 90, flags},
4828         {9, 9, 99, flags}};
4829 
4830     // sendmsg will be called, we return # of bytes written
4831     {
4832       InSequence s;
4833       for (const auto& expectedInvocation : expectedSendmsgInvocations) {
4834         EXPECT_CALL(
4835             *(clientConn.getNetOpsDispatcher()),
4836             sendmsg(
4837                 _,
4838                 Pointee(SendmsgMsghdrHasTotalIovLen(
4839                     expectedInvocation.expectedTotalIovLen)),
4840                 _))
4841             .WillOnce(::testing::InvokeWithoutArgs([expectedInvocation]() {
4842               if (expectedInvocation.returnVal < 0) {
4843                 errno = EAGAIN; // returning error, set EAGAIN
4844               }
4845               return expectedInvocation.returnVal;
4846             }));
4847       }
4848     }
4849 
4850     // write
4851     // writes will be intercepted, so we don't need to read at other end
4852     WriteCallback wcb;
4853     clientConn.getRawSocket()->write(
4854         &wcb,
4855         kOneHundredCharacterVec.data(),
4856         kOneHundredCharacterVec.size(),
4857         flags);
4858     while (STATE_WAITING == wcb.state) {
4859       clientConn.getRawSocket()->getEventBase()->loopOnce();
4860     }
4861     ASSERT_EQ(STATE_SUCCEEDED, wcb.state);
4862 
4863     // check write events
4864     for (const auto& expectedInvocation : expectedSendmsgInvocations) {
4865       if (expectedInvocation.returnVal < 0) {
4866         // should be no WriteEvent since the return value was an error
4867         continue;
4868       }
4869 
4870       ASSERT_TRUE(expectedInvocation.maybeWriteEventExpectedOffset.has_value());
4871       const auto& expectedOffset =
4872           *expectedInvocation.maybeWriteEventExpectedOffset;
4873 
4874       auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
4875           expectedOffset, ByteEventType::WRITE);
4876       ASSERT_TRUE(maybeByteEvent.has_value());
4877       auto& byteEvent = maybeByteEvent.value();
4878 
4879       EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
4880       EXPECT_EQ(expectedOffset, byteEvent.offset);
4881       EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
4882       EXPECT_LT(
4883           std::chrono::steady_clock::now() - std::chrono::seconds(60),
4884           byteEvent.ts);
4885 
4886       EXPECT_EQ(
4887           expectedInvocation.maybeWriteEventExpectedFlags,
4888           byteEvent.maybeWriteFlags);
4889       EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
4890       EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
4891       EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());
4892 
4893       EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
4894       EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
4895 
4896       // what we really want to test
4897       EXPECT_EQ(
4898           folly::to_unsigned(expectedInvocation.returnVal),
4899           byteEvent.maybeRawBytesWritten);
4900       EXPECT_EQ(
4901           expectedInvocation.expectedTotalIovLen,
4902           byteEvent.maybeRawBytesTriedToWrite);
4903     }
4904   }
4905 
4906   // everything should have occurred by now
4907   clientConn.netOpsVerifyAndClearExpectations();
4908 
4909   // second write
4910   //
4911   // sendmsg will incrementally accept the bytes so we can test the values of
4912   // maybeRawBytesWritten and maybeRawBytesTriedToWrite
4913   {
4914     // bytes written per sendmsg call: 20, 30, 50
4915     const std::vector<ExpectedSendmsgInvocation> expectedSendmsgInvocations{
4916         {100, 20, 119, flags}, {80, 30, 149, flags}, {50, 50, 199, flags}};
4917 
4918     // sendmsg will be called, we return # of bytes written
4919     {
4920       InSequence s;
4921       for (const auto& expectedInvocation : expectedSendmsgInvocations) {
4922         EXPECT_CALL(
4923             *(clientConn.getNetOpsDispatcher()),
4924             sendmsg(
4925                 _,
4926                 Pointee(SendmsgMsghdrHasTotalIovLen(
4927                     expectedInvocation.expectedTotalIovLen)),
4928                 _))
4929             .WillOnce(::testing::InvokeWithoutArgs([expectedInvocation]() {
4930               return expectedInvocation.returnVal;
4931             }));
4932       }
4933     }
4934 
4935     // write
4936     // writes will be intercepted, so we don't need to read at other end
4937     WriteCallback wcb;
4938     clientConn.getRawSocket()->write(
4939         &wcb,
4940         kOneHundredCharacterVec.data(),
4941         kOneHundredCharacterVec.size(),
4942         flags);
4943     while (STATE_WAITING == wcb.state) {
4944       clientConn.getRawSocket()->getEventBase()->loopOnce();
4945     }
4946     ASSERT_EQ(STATE_SUCCEEDED, wcb.state);
4947 
4948     // check write events
4949     for (const auto& expectedInvocation : expectedSendmsgInvocations) {
4950       ASSERT_TRUE(expectedInvocation.maybeWriteEventExpectedOffset.has_value());
4951       const auto& expectedOffset =
4952           *expectedInvocation.maybeWriteEventExpectedOffset;
4953 
4954       auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
4955           expectedOffset, ByteEventType::WRITE);
4956       ASSERT_TRUE(maybeByteEvent.has_value());
4957       auto& byteEvent = maybeByteEvent.value();
4958 
4959       EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
4960       EXPECT_EQ(expectedOffset, byteEvent.offset);
4961       EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
4962       EXPECT_LT(
4963           std::chrono::steady_clock::now() - std::chrono::seconds(60),
4964           byteEvent.ts);
4965 
4966       EXPECT_EQ(
4967           expectedInvocation.maybeWriteEventExpectedFlags,
4968           byteEvent.maybeWriteFlags);
4969       EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
4970       EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
4971       EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());
4972 
4973       EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
4974       EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
4975 
4976       // what we really want to test
4977       EXPECT_EQ(
4978           folly::to_unsigned(expectedInvocation.returnVal),
4979           byteEvent.maybeRawBytesWritten);
4980       EXPECT_EQ(
4981           expectedInvocation.expectedTotalIovLen,
4982           byteEvent.maybeRawBytesTriedToWrite);
4983     }
4984   }
4985 }
4986 
TEST_F(AsyncSocketByteEventTest,SplitIoVecArraySingleIoVec)4987 TEST_F(AsyncSocketByteEventTest, SplitIoVecArraySingleIoVec) {
4988   // get srciov from lambda to enable us to keep it const during test
4989   const char* buf = kOneHundredCharacterString.c_str();
4990   auto getSrcIov = [&buf]() {
4991     std::vector<struct iovec> srcIov(2);
4992     srcIov[0].iov_base = const_cast<void*>(static_cast<const void*>(buf));
4993     srcIov[0].iov_len = kOneHundredCharacterString.size();
4994     return srcIov;
4995   };
4996 
4997   std::vector<struct iovec> srcIov = getSrcIov();
4998   const auto data = srcIov.data();
4999 
5000   // split 0 -> 0 (first byte)
5001   {
5002     std::vector<struct iovec> dstIov(4);
5003     size_t dstIovCount = dstIov.size();
5004     AsyncSocket::splitIovecArray(
5005         0, 0, data, srcIov.size(), dstIov.data(), dstIovCount);
5006 
5007     ASSERT_EQ(1, dstIovCount);
5008     EXPECT_EQ(1, dstIov[0].iov_len);
5009     EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5010     EXPECT_EQ(buf, dstIov[0].iov_base);
5011   }
5012 
5013   // split 0 -> 49 (50th byte)
5014   {
5015     std::vector<struct iovec> dstIov(4);
5016     size_t dstIovCount = dstIov.size();
5017     AsyncSocket::splitIovecArray(
5018         0, 49, data, srcIov.size(), dstIov.data(), dstIovCount);
5019 
5020     ASSERT_EQ(1, dstIovCount);
5021     EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5022     EXPECT_EQ(50, dstIov[0].iov_len);
5023   }
5024 
5025   // split 0 -> 98 (penultimate byte)
5026   {
5027     std::vector<struct iovec> dstIov(4);
5028     size_t dstIovCount = dstIov.size();
5029     AsyncSocket::splitIovecArray(
5030         0, 98, data, srcIov.size(), dstIov.data(), dstIovCount);
5031 
5032     ASSERT_EQ(1, dstIovCount);
5033     EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5034     EXPECT_EQ(99, dstIov[0].iov_len);
5035   }
5036 
5037   // split 0 -> 99 (pointless split)
5038   {
5039     std::vector<struct iovec> dstIov(4);
5040     size_t dstIovCount = dstIov.size();
5041     AsyncSocket::splitIovecArray(
5042         0, 99, data, srcIov.size(), dstIov.data(), dstIovCount);
5043 
5044     ASSERT_EQ(1, dstIovCount);
5045     EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5046     EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
5047   }
5048 }
5049 
TEST_F(AsyncSocketByteEventTest, SplitIoVecArrayMultiIoVecInvalid) {
  // get srciov from lambda to enable us to keep it const during test
  const char* buf = kOneHundredCharacterString.c_str();
  auto getSrcIov = [&buf]() {
    // four entries; only the first two (50 bytes each) are populated, the
    // remaining two are value-initialized (zero length, null base)
    std::vector<struct iovec> srcIov(4);
    srcIov[0].iov_base = const_cast<void*>(static_cast<const void*>(buf));
    srcIov[0].iov_len = 50;
    srcIov[1].iov_base = const_cast<void*>(static_cast<const void*>(buf + 50));
    srcIov[1].iov_len = 50;
    return srcIov;
  };

  // declared const, per the comment above
  const std::vector<struct iovec> srcIov = getSrcIov();
  const auto data = srcIov.data();

  // dstIov.size() < srcIov.size(); this is not allowed
  //
  // a 0 -> 0 split only needs a single destination entry, yet the call is
  // still expected to die: the destination capacity is checked against the
  // number of source entries, not against the size of the requested split
  std::vector<struct iovec> dstIov(1);
  size_t dstIovCount = dstIov.size();
  EXPECT_LT(dstIovCount, srcIov.size());
  EXPECT_DEATH(
      AsyncSocket::splitIovecArray(
          0, 0, data, srcIov.size(), dstIov.data(), dstIovCount),
      ".*");
}
5074 
TEST_F(AsyncSocketByteEventTest,SplitIoVecArrayMultiIoVec)5075 TEST_F(AsyncSocketByteEventTest, SplitIoVecArrayMultiIoVec) {
5076   // get srciov from lambda to enable us to keep it const during test
5077   const char* buf = kOneHundredCharacterString.c_str();
5078   auto getSrcIov = [&buf]() {
5079     std::vector<struct iovec> srcIov(4);
5080     srcIov[0].iov_base = const_cast<void*>(static_cast<const void*>(buf));
5081     srcIov[0].iov_len = 25;
5082     srcIov[1].iov_base = const_cast<void*>(static_cast<const void*>(buf + 25));
5083     srcIov[1].iov_len = 25;
5084     srcIov[2].iov_base = const_cast<void*>(static_cast<const void*>(buf + 50));
5085     srcIov[2].iov_len = 25;
5086     srcIov[3].iov_base = const_cast<void*>(static_cast<const void*>(buf + 75));
5087     srcIov[3].iov_len = 25;
5088     return srcIov;
5089   };
5090 
5091   std::vector<struct iovec> srcIov = getSrcIov();
5092   const auto data = srcIov.data();
5093 
5094   // split 0 -> 0 (first byte)
5095   {
5096     std::vector<struct iovec> dstIov(4);
5097     size_t dstIovCount = dstIov.size();
5098     AsyncSocket::splitIovecArray(
5099         0, 0, data, srcIov.size(), dstIov.data(), dstIovCount);
5100 
5101     ASSERT_EQ(1, dstIovCount);
5102     EXPECT_EQ(1, dstIov[0].iov_len);
5103     EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5104     EXPECT_EQ(buf, dstIov[0].iov_base);
5105   }
5106 
5107   // split 0 -> 98 (penultimate byte)
5108   {
5109     std::vector<struct iovec> dstIov(4);
5110     size_t dstIovCount = dstIov.size();
5111     AsyncSocket::splitIovecArray(
5112         0, 98, data, srcIov.size(), dstIov.data(), dstIovCount);
5113 
5114     ASSERT_EQ(4, dstIovCount);
5115     EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5116     EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
5117     EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
5118     EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
5119     EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
5120     EXPECT_EQ(srcIov[2].iov_len, dstIov[2].iov_len);
5121 
5122     // last iovec is different
5123     EXPECT_EQ(24, dstIov[3].iov_len);
5124     EXPECT_EQ(srcIov[3].iov_base, dstIov[3].iov_base);
5125   }
5126 
5127   // split 0 -> 99 (pointless split)
5128   {
5129     std::vector<struct iovec> dstIov(4);
5130     size_t dstIovCount = dstIov.size();
5131     AsyncSocket::splitIovecArray(
5132         0, 99, data, srcIov.size(), dstIov.data(), dstIovCount);
5133 
5134     ASSERT_EQ(4, dstIovCount);
5135     EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5136     EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
5137     EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
5138     EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
5139     EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
5140     EXPECT_EQ(srcIov[2].iov_len, dstIov[2].iov_len);
5141     EXPECT_EQ(srcIov[3].iov_base, dstIov[3].iov_base);
5142     EXPECT_EQ(srcIov[3].iov_len, dstIov[3].iov_len);
5143   }
5144 
5145   //
5146   // test when endOffset is near a iovec boundary
5147   //
5148 
5149   // split 0 -> 49 (50th byte)
5150   {
5151     std::vector<struct iovec> dstIov(4);
5152     size_t dstIovCount = dstIov.size();
5153     AsyncSocket::splitIovecArray(
5154         0, 49, data, srcIov.size(), dstIov.data(), dstIovCount);
5155 
5156     ASSERT_EQ(2, dstIovCount);
5157     EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5158     EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
5159     EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
5160     EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
5161   }
5162 
5163   // split 0 -> 50 (51st byte)
5164   {
5165     std::vector<struct iovec> dstIov(4);
5166     size_t dstIovCount = dstIov.size();
5167     AsyncSocket::splitIovecArray(
5168         0, 50, data, srcIov.size(), dstIov.data(), dstIovCount);
5169 
5170     ASSERT_EQ(3, dstIovCount);
5171     EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5172     EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
5173     EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
5174     EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
5175 
5176     // last iovec is one byte
5177     EXPECT_EQ(1, dstIov[2].iov_len);
5178     EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
5179   }
5180 
5181   // split 0 -> 51 (52nd byte)
5182   {
5183     std::vector<struct iovec> dstIov(4);
5184     size_t dstIovCount = dstIov.size();
5185     AsyncSocket::splitIovecArray(
5186         0, 51, data, srcIov.size(), dstIov.data(), dstIovCount);
5187 
5188     ASSERT_EQ(3, dstIovCount);
5189     EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
5190     EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
5191     EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
5192     EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
5193 
5194     // last iovec is two bytes
5195     EXPECT_EQ(2, dstIov[2].iov_len);
5196     EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
5197   }
5198 
5199   //
5200   // test when startOffset is near a iovec boundary
5201   //
5202 
5203   // split 49 -> 99
5204   {
5205     std::vector<struct iovec> dstIov(4);
5206     size_t dstIovCount = dstIov.size();
5207     AsyncSocket::splitIovecArray(
5208         49, 99, data, srcIov.size(), dstIov.data(), dstIovCount);
5209 
5210     ASSERT_EQ(3, dstIovCount);
5211 
5212     // first dst iovec is one byte, starts 24 bytes in to the second src iovec
5213     EXPECT_EQ(1, dstIov[0].iov_len);
5214     EXPECT_EQ(
5215         dstIov[0].iov_base,
5216         const_cast<void*>(static_cast<const void*>(
5217             reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));
5218 
5219     // second dst iovec is third src iovec
5220     // third dst iovec is fourth src iovec
5221     EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
5222     EXPECT_EQ(dstIov[1].iov_len, srcIov[2].iov_len);
5223     EXPECT_EQ(dstIov[2].iov_base, srcIov[3].iov_base);
5224     EXPECT_EQ(dstIov[2].iov_len, srcIov[3].iov_len);
5225   }
5226 
5227   // split 50 -> 99
5228   {
5229     std::vector<struct iovec> dstIov(4);
5230     size_t dstIovCount = dstIov.size();
5231     AsyncSocket::splitIovecArray(
5232         50, 99, data, srcIov.size(), dstIov.data(), dstIovCount);
5233 
5234     ASSERT_EQ(2, dstIovCount);
5235 
5236     // first dst iovec is third src iovec
5237     // second dst iovec is fourth src iovec
5238     EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
5239     EXPECT_EQ(dstIov[0].iov_len, srcIov[2].iov_len);
5240     EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
5241     EXPECT_EQ(dstIov[1].iov_len, srcIov[3].iov_len);
5242   }
5243 
5244   // split 51 -> 99
5245   {
5246     std::vector<struct iovec> dstIov(4);
5247     size_t dstIovCount = dstIov.size();
5248     AsyncSocket::splitIovecArray(
5249         51, 99, data, srcIov.size(), dstIov.data(), dstIovCount);
5250 
5251     ASSERT_EQ(2, dstIovCount);
5252 
5253     // first dst iovec is 24 bytes, starts 1 byte in to the third src iovec
5254     EXPECT_EQ(24, dstIov[0].iov_len);
5255     EXPECT_EQ(
5256         dstIov[0].iov_base,
5257         const_cast<void*>(static_cast<const void*>(
5258             reinterpret_cast<uint8_t*>(srcIov[2].iov_base) + 1)));
5259 
5260     // second dst iovec is fourth src iovec
5261     EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
5262     EXPECT_EQ(dstIov[1].iov_len, srcIov[3].iov_len);
5263   }
5264 
5265   //
5266   // test when startOffset and endOffset are near iovec boundaries
5267   //
5268 
5269   // split 49 -> 49
5270   {
5271     std::vector<struct iovec> dstIov(4);
5272     size_t dstIovCount = dstIov.size();
5273     AsyncSocket::splitIovecArray(
5274         49, 49, data, srcIov.size(), dstIov.data(), dstIovCount);
5275 
5276     ASSERT_EQ(1, dstIovCount);
5277 
5278     // first dst iovec is one byte, starts 24 bytes in to the second src iovec
5279     EXPECT_EQ(1, dstIov[0].iov_len);
5280     EXPECT_EQ(
5281         dstIov[0].iov_base,
5282         const_cast<void*>(static_cast<const void*>(
5283             reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));
5284   }
5285 
5286   // split 49 -> 50
5287   {
5288     std::vector<struct iovec> dstIov(4);
5289     size_t dstIovCount = dstIov.size();
5290     AsyncSocket::splitIovecArray(
5291         49, 50, data, srcIov.size(), dstIov.data(), dstIovCount);
5292 
5293     ASSERT_EQ(2, dstIovCount);
5294 
5295     // first dst iovec is one byte, starts 24 bytes in to the second src iovec
5296     EXPECT_EQ(1, dstIov[0].iov_len);
5297     EXPECT_EQ(
5298         dstIov[0].iov_base,
5299         const_cast<void*>(static_cast<const void*>(
5300             reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));
5301 
5302     // second iovec is one byte, starts at the third src iovec
5303     EXPECT_EQ(1, dstIov[1].iov_len);
5304     EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
5305   }
5306 
5307   // split 49 -> 51
5308   {
5309     std::vector<struct iovec> dstIov(4);
5310     size_t dstIovCount = dstIov.size();
5311     AsyncSocket::splitIovecArray(
5312         49, 51, data, srcIov.size(), dstIov.data(), dstIovCount);
5313 
5314     ASSERT_EQ(2, dstIovCount);
5315 
5316     // first dst iovec is one byte, starts 24 bytes in to the second src iovec
5317     EXPECT_EQ(1, dstIov[0].iov_len);
5318     EXPECT_EQ(
5319         dstIov[0].iov_base,
5320         const_cast<void*>(static_cast<const void*>(
5321             reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));
5322 
5323     // second iovec is two bytes, starts at the third src iovec
5324     EXPECT_EQ(2, dstIov[1].iov_len);
5325     EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
5326   }
5327 
5328   // split 50 -> 50
5329   {
5330     std::vector<struct iovec> dstIov(4);
5331     size_t dstIovCount = dstIov.size();
5332     AsyncSocket::splitIovecArray(
5333         50, 50, data, srcIov.size(), dstIov.data(), dstIovCount);
5334 
5335     ASSERT_EQ(1, dstIovCount);
5336 
5337     // first dst iovec is one byte, starts at the third src iovec
5338     EXPECT_EQ(1, dstIov[0].iov_len);
5339     EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
5340   }
5341 
5342   // split 50 -> 51
5343   {
5344     std::vector<struct iovec> dstIov(4);
5345     size_t dstIovCount = dstIov.size();
5346     AsyncSocket::splitIovecArray(
5347         50, 51, data, srcIov.size(), dstIov.data(), dstIovCount);
5348 
5349     ASSERT_EQ(1, dstIovCount);
5350 
5351     // first dst iovec is two bytes, starts at the third src iovec
5352     EXPECT_EQ(2, dstIov[0].iov_len);
5353     EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
5354   }
5355 
5356   // split 51 -> 51
5357   {
5358     std::vector<struct iovec> dstIov(4);
5359     size_t dstIovCount = dstIov.size();
5360     AsyncSocket::splitIovecArray(
5361         51, 51, data, srcIov.size(), dstIov.data(), dstIovCount);
5362 
5363     ASSERT_EQ(1, dstIovCount);
5364 
5365     // first dst iovec is one byte, starts 1 byte into the third src iovec
5366     EXPECT_EQ(1, dstIov[0].iov_len);
5367     EXPECT_EQ(
5368         dstIov[0].iov_base,
5369         const_cast<void*>(static_cast<const void*>(
5370             reinterpret_cast<uint8_t*>(srcIov[2].iov_base) + 1)));
5371   }
5372 
5373   // split 48 -> 98
5374   {
5375     std::vector<struct iovec> dstIov(4);
5376     size_t dstIovCount = dstIov.size();
5377     AsyncSocket::splitIovecArray(
5378         48, 98, data, srcIov.size(), dstIov.data(), dstIovCount);
5379 
5380     ASSERT_EQ(3, dstIovCount);
5381 
5382     // first dst iovec is two bytes, starts 23 bytes in to the second src iovec
5383     EXPECT_EQ(2, dstIov[0].iov_len);
5384     EXPECT_EQ(
5385         dstIov[0].iov_base,
5386         const_cast<void*>(static_cast<const void*>(
5387             reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 23)));
5388 
5389     // second dst iovec is third src iovec
5390     EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
5391     EXPECT_EQ(dstIov[1].iov_len, srcIov[2].iov_len);
5392 
5393     // third dst iovec is 24 bytes, starts at the fourth src iovec
5394     EXPECT_EQ(24, dstIov[2].iov_len);
5395     EXPECT_EQ(dstIov[2].iov_base, srcIov[3].iov_base);
5396   }
5397 
5398   // split 49 -> 98
5399   {
5400     std::vector<struct iovec> dstIov(4);
5401     size_t dstIovCount = dstIov.size();
5402     AsyncSocket::splitIovecArray(
5403         49, 98, data, srcIov.size(), dstIov.data(), dstIovCount);
5404 
5405     ASSERT_EQ(3, dstIovCount);
5406 
5407     // first dst iovec is one byte, starts 24 bytes in to the second src iovec
5408     EXPECT_EQ(1, dstIov[0].iov_len);
5409     EXPECT_EQ(
5410         dstIov[0].iov_base,
5411         const_cast<void*>(static_cast<const void*>(
5412             reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));
5413 
5414     // second dst iovec is third src iovec
5415     EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
5416     EXPECT_EQ(dstIov[1].iov_len, srcIov[2].iov_len);
5417 
5418     // third dst iovec is 24 bytes, starts at the fourth src iovec
5419     EXPECT_EQ(24, dstIov[2].iov_len);
5420     EXPECT_EQ(dstIov[2].iov_base, srcIov[3].iov_base);
5421   }
5422 
5423   // split 50 -> 98
5424   {
5425     std::vector<struct iovec> dstIov(4);
5426     size_t dstIovCount = dstIov.size();
5427     AsyncSocket::splitIovecArray(
5428         50, 98, data, srcIov.size(), dstIov.data(), dstIovCount);
5429 
5430     ASSERT_EQ(2, dstIovCount);
5431 
5432     // first dst iovec is third src iovec
5433     EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
5434     EXPECT_EQ(dstIov[0].iov_len, srcIov[2].iov_len);
5435 
5436     // second dst iovec is 24 bytes, starts at the fourth src iovec
5437     EXPECT_EQ(24, dstIov[1].iov_len);
5438     EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
5439   }
5440 
5441   // split 51 -> 98
5442   {
5443     std::vector<struct iovec> dstIov(4);
5444     size_t dstIovCount = dstIov.size();
5445     AsyncSocket::splitIovecArray(
5446         51, 98, data, srcIov.size(), dstIov.data(), dstIovCount);
5447 
5448     ASSERT_EQ(2, dstIovCount);
5449 
5450     // first dst iovec is 24 bytes, starts 1 byte in to the third src iovec
5451     EXPECT_EQ(24, dstIov[0].iov_len);
5452     EXPECT_EQ(
5453         dstIov[0].iov_base,
5454         const_cast<void*>(static_cast<const void*>(
5455             reinterpret_cast<uint8_t*>(srcIov[2].iov_base) + 1)));
5456 
5457     // second dst iovec is 24 bytes, starts at the fourth src iovec
5458     EXPECT_EQ(24, dstIov[1].iov_len);
5459     EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
5460   }
5461 }
5462 
TEST_F(AsyncSocketByteEventTest,SendmsgMatchers)5463 TEST_F(AsyncSocketByteEventTest, SendmsgMatchers) {
5464   // empty
5465   {
5466     const ClientConn::SendmsgInvocation sendmsgInvoc = {};
5467     // length
5468     EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(0)));
5469 
5470     // iov first byte
5471     EXPECT_THAT(
5472         sendmsgInvoc,
5473         Not(SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data())));
5474     EXPECT_THAT(
5475         sendmsgInvoc,
5476         Not(SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data() + 5)));
5477 
5478     // iov last byte
5479     EXPECT_THAT(
5480         sendmsgInvoc,
5481         Not(SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data())));
5482     EXPECT_THAT(
5483         sendmsgInvoc,
5484         Not(SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 5)));
5485   }
5486 
5487   // single iov, last byte = end of kOneHundredCharacterVec
5488   {
5489     struct iovec iov = {};
5490     iov.iov_base = const_cast<void*>(
5491         static_cast<const void*>((kOneHundredCharacterVec.data())));
5492     iov.iov_len = kOneHundredCharacterVec.size();
5493     const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov}};
5494 
5495     struct msghdr msg = {};
5496     msg.msg_name = nullptr;
5497     msg.msg_namelen = 0;
5498     msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
5499     msg.msg_iovlen = sendmsgInvoc.iovs.size();
5500 
5501     // length
5502     EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(100)));
5503     EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(100)));
5504 
5505     // iov first byte
5506     EXPECT_THAT(
5507         sendmsgInvoc,
5508         SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data()));
5509     EXPECT_THAT(
5510         sendmsgInvoc,
5511         Not(SendmsgInvocHasIovFirstByte(
5512             kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
5513             1)));
5514 
5515     // iov last byte
5516     EXPECT_THAT(
5517         sendmsgInvoc,
5518         Not(SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data())));
5519     EXPECT_THAT(
5520         sendmsgInvoc,
5521         SendmsgInvocHasIovLastByte(
5522             kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
5523             1));
5524   }
5525 
5526   // single iov, first and last byte = start of kOneHundredCharacterVec
5527   {
5528     struct iovec iov = {};
5529     iov.iov_base = const_cast<void*>(
5530         static_cast<const void*>((kOneHundredCharacterVec.data())));
5531     iov.iov_len = 1;
5532     const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov}};
5533 
5534     struct msghdr msg = {};
5535     msg.msg_name = nullptr;
5536     msg.msg_namelen = 0;
5537     msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
5538     msg.msg_iovlen = sendmsgInvoc.iovs.size();
5539 
5540     // length
5541     EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(1)));
5542     EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(1)));
5543 
5544     // iov first byte
5545     EXPECT_THAT(
5546         sendmsgInvoc,
5547         SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data()));
5548     EXPECT_THAT(
5549         sendmsgInvoc,
5550         Not(SendmsgInvocHasIovFirstByte(
5551             kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
5552             1)));
5553 
5554     // iov last byte
5555     EXPECT_THAT(
5556         sendmsgInvoc,
5557         SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()));
5558     EXPECT_THAT(
5559         sendmsgInvoc,
5560         Not(SendmsgInvocHasIovLastByte(
5561             kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
5562             1)));
5563   }
5564 
5565   // single iov, first and last byte = end of kOneHundredCharacterVec
5566   {
5567     struct iovec iov = {};
5568     iov.iov_base = const_cast<void*>(static_cast<const void*>(
5569         (kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size())));
5570     iov.iov_len = 1;
5571     const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov}};
5572 
5573     struct msghdr msg = {};
5574     msg.msg_name = nullptr;
5575     msg.msg_namelen = 0;
5576     msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
5577     msg.msg_iovlen = sendmsgInvoc.iovs.size();
5578 
5579     // length
5580     EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(1)));
5581     EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(1)));
5582 
5583     // iov first byte
5584     EXPECT_THAT(
5585         sendmsgInvoc,
5586         SendmsgInvocHasIovFirstByte(
5587             kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size()));
5588 
5589     // iov last byte
5590     EXPECT_THAT(
5591         sendmsgInvoc,
5592         SendmsgInvocHasIovLastByte(
5593             kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size()));
5594   }
5595 
5596   // two iov, (0 -> 0, 1 - > 99), last byte = end of kOneHundredCharacterVec
5597   {
5598     struct iovec iov1 = {};
5599     iov1.iov_base = const_cast<void*>(
5600         static_cast<const void*>((kOneHundredCharacterVec.data())));
5601     iov1.iov_len = 1;
5602 
5603     struct iovec iov2 = {};
5604     iov2.iov_base = const_cast<void*>(
5605         static_cast<const void*>((kOneHundredCharacterVec.data() + 1)));
5606     iov2.iov_len = 99;
5607 
5608     const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov1, iov2}};
5609 
5610     struct msghdr msg = {};
5611     msg.msg_name = nullptr;
5612     msg.msg_namelen = 0;
5613     msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
5614     msg.msg_iovlen = sendmsgInvoc.iovs.size();
5615 
5616     // length
5617     EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(100)));
5618     EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(100)));
5619 
5620     // iov first byte
5621     EXPECT_THAT(
5622         sendmsgInvoc,
5623         SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data()));
5624 
5625     // iov last byte
5626     EXPECT_THAT(
5627         sendmsgInvoc,
5628         SendmsgInvocHasIovLastByte(
5629             kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
5630             1));
5631   }
5632 
5633   // two iov, (0 -> 49, 50 - > 99), last byte = end of kOneHundredCharacterVec
5634   {
5635     struct iovec iov1 = {};
5636     iov1.iov_base = const_cast<void*>(
5637         static_cast<const void*>((kOneHundredCharacterVec.data())));
5638     iov1.iov_len = 50;
5639 
5640     struct iovec iov2 = {};
5641     iov2.iov_base = const_cast<void*>(
5642         static_cast<const void*>((kOneHundredCharacterVec.data() + 50)));
5643     iov2.iov_len = 50;
5644 
5645     const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov1, iov2}};
5646 
5647     struct msghdr msg = {};
5648     msg.msg_name = nullptr;
5649     msg.msg_namelen = 0;
5650     msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
5651     msg.msg_iovlen = sendmsgInvoc.iovs.size();
5652 
5653     // length
5654     EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(100)));
5655     EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(100)));
5656 
5657     // iov first byte
5658     EXPECT_THAT(
5659         sendmsgInvoc,
5660         SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data()));
5661 
5662     // iov last byte
5663     EXPECT_THAT(
5664         sendmsgInvoc,
5665         SendmsgInvocHasIovLastByte(
5666             kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
5667             1));
5668   }
5669 
5670   // two iov, (0 -> 49, 50 - > 98), last byte = penultimate byte
5671   {
5672     struct iovec iov1 = {};
5673     iov1.iov_base = const_cast<void*>(
5674         static_cast<const void*>((kOneHundredCharacterVec.data())));
5675     iov1.iov_len = 50;
5676 
5677     struct iovec iov2 = {};
5678     iov2.iov_base = const_cast<void*>(
5679         static_cast<const void*>((kOneHundredCharacterVec.data() + 50)));
5680     iov2.iov_len = 49;
5681 
5682     const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov1, iov2}};
5683 
5684     struct msghdr msg = {};
5685     msg.msg_name = nullptr;
5686     msg.msg_namelen = 0;
5687     msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
5688     msg.msg_iovlen = sendmsgInvoc.iovs.size();
5689 
5690     // length
5691     EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(99)));
5692     EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(99)));
5693 
5694     // iov first byte
5695     EXPECT_THAT(
5696         sendmsgInvoc,
5697         SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data()));
5698 
5699     // iov last byte
5700     EXPECT_THAT(
5701         sendmsgInvoc,
5702         SendmsgInvocHasIovLastByte(
5703             kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
5704             2));
5705   }
5706 }
5707 
TEST_F(AsyncSocketByteEventTest, SendmsgInvocMsgFlagsEq) {
  // A default-constructed invocation records no msg flags: it matches NONE and
  // must not match CORK.
  {
    const ClientConn::SendmsgInvocation defaultInvoc;
    EXPECT_THAT(defaultInvoc, Not(SendmsgInvocMsgFlagsEq(WriteFlags::CORK)));
    EXPECT_THAT(defaultInvoc, SendmsgInvocMsgFlagsEq(WriteFlags::NONE));
  }

  // An invocation with CORK in its msg flags matches exactly CORK and nothing
  // else.
  {
    ClientConn::SendmsgInvocation corkInvoc = {};
    corkInvoc.writeFlagsInMsgFlags = WriteFlags::CORK;
    EXPECT_THAT(corkInvoc, SendmsgInvocMsgFlagsEq(WriteFlags::CORK));
    EXPECT_THAT(corkInvoc, Not(SendmsgInvocMsgFlagsEq(WriteFlags::NONE)));

    // the matcher requires an exact match, so a superset must not match
    const auto corkAndEor = WriteFlags::EOR | WriteFlags::CORK;
    EXPECT_THAT(corkInvoc, Not(SendmsgInvocMsgFlagsEq(corkAndEor)));
  }
}
5728 
TEST_F(AsyncSocketByteEventTest, SendmsgInvocAncillaryFlagsEq) {
  // A default-constructed invocation records no ancillary flags: it matches
  // NONE and must not match TIMESTAMP_TX.
  {
    const ClientConn::SendmsgInvocation defaultInvoc;
    EXPECT_THAT(
        defaultInvoc,
        Not(SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX)));
    EXPECT_THAT(defaultInvoc, SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE));
  }

  // An invocation with TIMESTAMP_TX in its ancillary flags matches exactly
  // TIMESTAMP_TX and nothing else.
  {
    ClientConn::SendmsgInvocation txInvoc = {};
    txInvoc.writeFlagsInAncillary = WriteFlags::TIMESTAMP_TX;
    EXPECT_THAT(
        txInvoc, SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX));
    EXPECT_THAT(txInvoc, Not(SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE)));

    // the matcher requires an exact match, so a superset must not match
    const auto txAndAck = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
    EXPECT_THAT(txInvoc, Not(SendmsgInvocAncillaryFlagsEq(txAndAck)));
  }
}
5754 
TEST_F(AsyncSocketByteEventTest, ByteEventMatching) {
  // A WRITE event at offset 0 matches only the exact (WRITE, 0) pair.
  {
    AsyncTransport::ByteEvent writeEvent = {};
    writeEvent.type = ByteEventType::WRITE;
    writeEvent.offset = 0;
    EXPECT_THAT(writeEvent, ByteEventMatching(ByteEventType::WRITE, 0));

    // same type, different offset
    EXPECT_THAT(writeEvent, Not(ByteEventMatching(ByteEventType::WRITE, 10)));

    // same offset, different type
    for (const auto otherType :
         {ByteEventType::TX, ByteEventType::ACK, ByteEventType::SCHED}) {
      EXPECT_THAT(writeEvent, Not(ByteEventMatching(otherType, 0)));
    }
  }

  // A TX event at offset 10 matches only the exact (TX, 10) pair.
  {
    AsyncTransport::ByteEvent txEvent = {};
    txEvent.type = ByteEventType::TX;
    txEvent.offset = 10;
    EXPECT_THAT(txEvent, ByteEventMatching(ByteEventType::TX, 10));

    // same type, different offset
    EXPECT_THAT(txEvent, Not(ByteEventMatching(ByteEventType::TX, 0)));

    // same offset, different type
    for (const auto otherType :
         {ByteEventType::WRITE, ByteEventType::ACK, ByteEventType::SCHED}) {
      EXPECT_THAT(txEvent, Not(ByteEventMatching(otherType, 10)));
    }
  }
}
5784 
TEST_F(AsyncSocketByteEventTest,PrewriteSingleObserver)5785 TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserver) {
5786   auto clientConn = getClientConn();
5787   clientConn.connect();
5788   auto observer = clientConn.attachObserver(
5789       true /* enableByteEvents */, true /* enablePrewrite */);
5790   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
5791   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
5792   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
5793 
5794   clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
5795 
5796   const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
5797       WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
5798   ON_CALL(*observer, prewriteMock(_, _))
5799       .WillByDefault(testing::Invoke(
5800           [](AsyncTransport*,
5801              const AsyncTransport::LifecycleObserver::PrewriteState& state) {
5802             AsyncTransport::LifecycleObserver::PrewriteRequest request;
5803             if (state.startOffset == 0) {
5804               request.maybeOffsetToSplitWrite = 0;
5805             } else if (state.startOffset <= 50) {
5806               request.maybeOffsetToSplitWrite = 50;
5807             } else if (state.startOffset <= 98) {
5808               request.maybeOffsetToSplitWrite = 98;
5809             }
5810 
5811             request.writeFlagsToAddAtOffset = flags;
5812             return request;
5813           }));
5814   clientConn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);
5815 
5816   EXPECT_THAT(
5817       clientConn.getSendmsgInvocations(),
5818       ElementsAre(
5819           AllOf(
5820               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()),
5821               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
5822               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
5823           AllOf(
5824               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 50),
5825               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
5826               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
5827           AllOf(
5828               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 98),
5829               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
5830               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
5831           AllOf(
5832               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
5833               SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
5834               SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));
5835 
5836   // verify WRITE events exist at the appropriate locations
5837   // we verify timestamp events are generated elsewhere
5838   //
5839   // should _not_ contain events for 99 as no prewrite for that
5840   EXPECT_THAT(
5841       filterToWriteEvents(observer->byteEvents),
5842       ElementsAre(
5843           ByteEventMatching(ByteEventType::WRITE, 0),
5844           ByteEventMatching(ByteEventType::WRITE, 50),
5845           ByteEventMatching(ByteEventType::WRITE, 98)));
5846 }
5847 
5848 /**
5849  * Test explicitly that CORK (MSG_MORE) is set if write is split in middle.
5850  */
TEST_F(AsyncSocketByteEventTest,PrewriteSingleObserverCorkIfSplitMiddle)5851 TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverCorkIfSplitMiddle) {
5852   auto clientConn = getClientConn();
5853   clientConn.connect();
5854   auto observer = clientConn.attachObserver(
5855       true /* enableByteEvents */, true /* enablePrewrite */);
5856   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
5857   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
5858   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
5859 
5860   clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
5861 
5862   const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
5863       WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
5864   ON_CALL(*observer, prewriteMock(_, _))
5865       .WillByDefault(testing::Invoke(
5866           [](AsyncTransport*,
5867              const AsyncTransport::LifecycleObserver::PrewriteState& state) {
5868             AsyncTransport::LifecycleObserver::PrewriteRequest request;
5869             if (state.startOffset <= 50) {
5870               request.maybeOffsetToSplitWrite = 50;
5871             }
5872             request.writeFlagsToAddAtOffset = flags;
5873             return request;
5874           }));
5875   clientConn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);
5876 
5877   EXPECT_THAT(
5878       clientConn.getSendmsgInvocations(),
5879       ElementsAre(
5880           AllOf(
5881               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 50),
5882               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
5883               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
5884           AllOf(
5885               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
5886               SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
5887               SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));
5888 
5889   // verify WRITE events exist at the appropriate locations
5890   // we verify timestamp events are generated elsewhere
5891   EXPECT_THAT(
5892       filterToWriteEvents(observer->byteEvents),
5893       ElementsAre(ByteEventMatching(ByteEventType::WRITE, 50)));
5894 }
5895 
5896 /**
5897  * Test explicitly that CORK (MSG_MORE) is set if write is split in middle.
5898  */
TEST_F(AsyncSocketByteEventTest,PrewriteSingleObserverNoCorkIfSplitAtEnd)5899 TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverNoCorkIfSplitAtEnd) {
5900   auto clientConn = getClientConn();
5901   clientConn.connect();
5902   auto observer = clientConn.attachObserver(
5903       true /* enableByteEvents */, true /* enablePrewrite */);
5904   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
5905   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
5906   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
5907 
5908   clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
5909 
5910   const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
5911       WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
5912   ON_CALL(*observer, prewriteMock(_, _))
5913       .WillByDefault(testing::Invoke(
5914           [](AsyncTransport*,
5915              const AsyncTransport::LifecycleObserver::PrewriteState& state) {
5916             AsyncTransport::LifecycleObserver::PrewriteRequest request;
5917             if (state.startOffset <= 99) {
5918               request.maybeOffsetToSplitWrite = 99;
5919             }
5920             request.writeFlagsToAddAtOffset = flags;
5921             return request;
5922           }));
5923   clientConn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);
5924 
5925   EXPECT_THAT(
5926       clientConn.getSendmsgInvocations(),
5927       ElementsAre(AllOf(
5928           SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
5929           SendmsgInvocMsgFlagsEq(WriteFlags::NONE), // no cork!
5930           SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))));
5931 
5932   // verify WRITE events exist at the appropriate locations
5933   // we verify timestamp events are generated elsewhere
5934   EXPECT_THAT(
5935       filterToWriteEvents(observer->byteEvents),
5936       ElementsAre(ByteEventMatching(ByteEventType::WRITE, 99)));
5937 }
5938 
5939 /**
5940  * Test explicitly that split flags are NOT added if no split.
5941  */
TEST_F(AsyncSocketByteEventTest,PrewriteSingleObserverNoSplitFlagsIfNoSplit)5942 TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverNoSplitFlagsIfNoSplit) {
5943   auto clientConn = getClientConn();
5944   clientConn.connect();
5945   auto observer = clientConn.attachObserver(
5946       true /* enableByteEvents */, true /* enablePrewrite */);
5947   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
5948   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
5949   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
5950 
5951   clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
5952 
5953   const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
5954       WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
5955   ON_CALL(*observer, prewriteMock(_, _))
5956       .WillByDefault(
5957           testing::Invoke([](AsyncTransport*,
5958                              const AsyncTransport::LifecycleObserver::
5959                                  PrewriteState& /* state */) {
5960             AsyncTransport::LifecycleObserver::PrewriteRequest request;
5961             request.writeFlagsToAddAtOffset = flags;
5962             return request;
5963           }));
5964   clientConn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);
5965 
5966   EXPECT_THAT(
5967       clientConn.getSendmsgInvocations(),
5968       ElementsAre(AllOf(
5969           SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
5970           SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
5971           SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));
5972 }
5973 
5974 /**
5975  * Test more combinations of prewrite flags, including writeFlagsToAdd.
5976  */
TEST_F(AsyncSocketByteEventTest,PrewriteSingleObserverFlagsOnAll)5977 TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverFlagsOnAll) {
5978   auto clientConn = getClientConn();
5979   clientConn.connect();
5980   auto observer = clientConn.attachObserver(
5981       true /* enableByteEvents */, true /* enablePrewrite */);
5982   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
5983   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
5984   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
5985 
5986   clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
5987   ON_CALL(*observer, prewriteMock(_, _))
5988       .WillByDefault(testing::Invoke(
5989           [](AsyncTransport*,
5990              const AsyncTransport::LifecycleObserver::PrewriteState& state) {
5991             AsyncTransport::LifecycleObserver::PrewriteRequest request;
5992             if (state.startOffset == 0) {
5993               request.maybeOffsetToSplitWrite = 0;
5994               request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_WRITE;
5995             } else if (state.startOffset <= 10) {
5996               request.maybeOffsetToSplitWrite = 10;
5997               request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_SCHED;
5998             } else if (state.startOffset <= 20) {
5999               request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_TX;
6000               request.maybeOffsetToSplitWrite = 20;
6001             } else if (state.startOffset <= 30) {
6002               request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_ACK;
6003               request.maybeOffsetToSplitWrite = 30;
6004             } else if (state.startOffset <= 40) {
6005               request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_TX;
6006               request.writeFlagsToAdd |= WriteFlags::TIMESTAMP_WRITE;
6007               request.maybeOffsetToSplitWrite = 40;
6008             } else {
6009               request.writeFlagsToAdd |= WriteFlags::TIMESTAMP_WRITE;
6010             }
6011 
6012             return request;
6013           }));
6014   clientConn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);
6015 
6016   EXPECT_THAT(
6017       clientConn.getSendmsgInvocations(),
6018       ElementsAre(
6019           AllOf(
6020               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()),
6021               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6022               SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE)),
6023           AllOf(
6024               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 10),
6025               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6026               SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_SCHED)),
6027           AllOf(
6028               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 20),
6029               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6030               SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX)),
6031           AllOf(
6032               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 30),
6033               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6034               SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_ACK)),
6035           AllOf(
6036               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 40),
6037               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6038               SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX)),
6039           AllOf(
6040               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
6041               SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
6042               SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));
6043 
6044   // verify WRITE events exist at the appropriate locations
6045   // we verify timestamp events are generated elsewhere
6046   EXPECT_THAT(
6047       filterToWriteEvents(observer->byteEvents),
6048       ElementsAre(
6049           ByteEventMatching(ByteEventType::WRITE, 0),
6050           ByteEventMatching(ByteEventType::WRITE, 40),
6051           ByteEventMatching(ByteEventType::WRITE, 99)));
6052 }
6053 
6054 /**
6055  * Test merging of write flags with those passed to AsyncSocket::write().
6056  */
TEST_F(AsyncSocketByteEventTest,PrewriteSingleObserverFlagsOnWrite)6057 TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverFlagsOnWrite) {
6058   auto clientConn = getClientConn();
6059   clientConn.connect();
6060   auto observer = clientConn.attachObserver(
6061       true /* enableByteEvents */, true /* enablePrewrite */);
6062   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
6063   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
6064   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
6065 
6066   clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
6067 
6068   // first byte, observer adds TX and WRITE, onwards, it just adds WRITE
6069   ON_CALL(*observer, prewriteMock(_, _))
6070       .WillByDefault(testing::Invoke(
6071           [](AsyncTransport*,
6072              const AsyncTransport::LifecycleObserver::PrewriteState& state) {
6073             AsyncTransport::LifecycleObserver::PrewriteRequest request;
6074             if (state.startOffset == 0) {
6075               request.maybeOffsetToSplitWrite = 0;
6076               request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_TX;
6077             }
6078             request.writeFlagsToAdd |= WriteFlags::TIMESTAMP_WRITE;
6079 
6080             return request;
6081           }));
6082 
6083   // application does a write with ACK and CORK set
6084   clientConn.writeAndReflect(
6085       kOneHundredCharacterVec, WriteFlags::CORK | WriteFlags::TIMESTAMP_ACK);
6086 
6087   // make sure we have the merge
6088   //   first write, TX is added
6089   //   second write, CORK is passed through
6090   EXPECT_THAT(
6091       clientConn.getSendmsgInvocations(),
6092       ElementsAre(
6093           AllOf(
6094               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()),
6095               SendmsgInvocMsgFlagsEq(WriteFlags::CORK), // set by split
6096               SendmsgInvocAncillaryFlagsEq(
6097                   WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK)),
6098           AllOf(
6099               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
6100               SendmsgInvocMsgFlagsEq(WriteFlags::CORK), // still set
6101               SendmsgInvocAncillaryFlagsEq(
6102                   dropWriteFromFlags(WriteFlags::TIMESTAMP_ACK)))));
6103 
6104   // verify WRITE events exist at the appropriate locations
6105   // we verify timestamp events are generated elsewhere
6106   EXPECT_THAT(
6107       filterToWriteEvents(observer->byteEvents),
6108       ElementsAre(
6109           ByteEventMatching(ByteEventType::WRITE, 0),
6110           ByteEventMatching(ByteEventType::WRITE, 99)));
6111 }
6112 
6113 /**
6114  * Test invalid offset for prewrite, ensure death via CHECK.
6115  */
TEST_F(AsyncSocketByteEventTest,PrewriteSingleObserverInvalidOffset)6116 TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverInvalidOffset) {
6117   auto clientConn = getClientConn();
6118   clientConn.connect();
6119   auto observer = clientConn.attachObserver(
6120       true /* enableByteEvents */, true /* enablePrewrite */);
6121   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
6122   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
6123   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
6124 
6125   clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
6126 
6127   ON_CALL(*observer, prewriteMock(_, _))
6128       .WillByDefault(testing::Invoke(
6129           [](AsyncTransport*,
6130              const AsyncTransport::LifecycleObserver::PrewriteState& state) {
6131             AsyncTransport::LifecycleObserver::PrewriteRequest request;
6132             EXPECT_GT(200, state.endOffset);
6133             request.maybeOffsetToSplitWrite = 200; // invalid
6134             return request;
6135           }));
6136 
6137   // check will fail due to invalid offset
6138   EXPECT_DEATH(
6139       clientConn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE),
6140       ".*");
6141 }
6142 
6143 /**
6144  * Test prewrite with multiple iovec.
6145  */
TEST_F(AsyncSocketByteEventTest,PrewriteSingleObserverTwoIovec)6146 TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverTwoIovec) {
6147   // two iovec, each with half of the kOneHundredCharacterVec
6148   std::vector<iovec> iovs;
6149   {
6150     iovec iov = {};
6151     iov.iov_base = const_cast<void*>(
6152         static_cast<const void*>((kOneHundredCharacterVec.data())));
6153     iov.iov_len = 50;
6154     iovs.push_back(iov);
6155   }
6156   {
6157     iovec iov = {};
6158     iov.iov_base = const_cast<void*>(
6159         static_cast<const void*>((kOneHundredCharacterVec.data() + 50)));
6160     iov.iov_len = 50;
6161     iovs.push_back(iov);
6162   }
6163 
6164   auto clientConn = getClientConn();
6165   clientConn.connect();
6166   auto observer = clientConn.attachObserver(
6167       true /* enableByteEvents */, true /* enablePrewrite */);
6168   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
6169   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
6170   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
6171 
6172   clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
6173 
6174   const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
6175       WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
6176   ON_CALL(*observer, prewriteMock(_, _))
6177       .WillByDefault(testing::Invoke(
6178           [](AsyncTransport*,
6179              const AsyncTransport::LifecycleObserver::PrewriteState& state) {
6180             AsyncTransport::LifecycleObserver::PrewriteRequest request;
6181             if (state.startOffset == 0) {
6182               request.maybeOffsetToSplitWrite = 0;
6183             } else if (state.startOffset <= 49) {
6184               request.maybeOffsetToSplitWrite = 49;
6185             } else if (state.startOffset <= 99) {
6186               request.maybeOffsetToSplitWrite = 99;
6187             }
6188 
6189             request.writeFlagsToAddAtOffset = flags;
6190             return request;
6191           }));
6192 
6193   clientConn.writeAndReflect(iovs.data(), iovs.size(), WriteFlags::NONE);
6194 
6195   EXPECT_THAT(
6196       clientConn.getSendmsgInvocations(),
6197       ElementsAre(
6198           AllOf(
6199               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()),
6200               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6201               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
6202           AllOf(
6203               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 49),
6204               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6205               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
6206           AllOf(
6207               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
6208               SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
6209               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))));
6210 
6211   // verify WRITE events exist at the appropriate locations
6212   // we verify timestamp events are generated elsewhere
6213   EXPECT_THAT(
6214       filterToWriteEvents(observer->byteEvents),
6215       ElementsAre(
6216           ByteEventMatching(ByteEventType::WRITE, 0),
6217           ByteEventMatching(ByteEventType::WRITE, 49),
6218           ByteEventMatching(ByteEventType::WRITE, 99)));
6219 }
6220 
6221 /**
6222  * Test prewrite with large number of iovec to trigger malloc codepath.
6223  */
TEST_F(AsyncSocketByteEventTest,PrewriteSingleObserverManyIovec)6224 TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverManyIovec) {
6225   // make a long vector, 10000 bytes long
6226   auto tenThousandByteVec = get10KBOfData();
6227   ASSERT_THAT(tenThousandByteVec, SizeIs(10000));
6228 
6229   // put each byte in the vector into its own iovec
6230   std::vector<iovec> tenThousandIovec;
6231   for (size_t i = 0; i < tenThousandByteVec.size(); i++) {
6232     iovec iov = {};
6233     iov.iov_base = tenThousandByteVec.data() + i;
6234     iov.iov_len = 1;
6235     tenThousandIovec.push_back(iov);
6236   }
6237 
6238   auto clientConn = getClientConn();
6239   clientConn.connect();
6240   auto observer = clientConn.attachObserver(
6241       true /* enableByteEvents */, true /* enablePrewrite */);
6242   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
6243   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
6244   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
6245 
6246   clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
6247 
6248   const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
6249       WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
6250   ON_CALL(*observer, prewriteMock(_, _))
6251       .WillByDefault(testing::Invoke(
6252           [](AsyncTransport*,
6253              const AsyncTransport::LifecycleObserver::PrewriteState& state) {
6254             AsyncTransport::LifecycleObserver::PrewriteRequest request;
6255             if (state.startOffset == 0) {
6256               request.maybeOffsetToSplitWrite = 0;
6257             } else if (state.startOffset <= 1000) {
6258               request.maybeOffsetToSplitWrite = 1000;
6259             } else if (state.startOffset <= 5000) {
6260               request.maybeOffsetToSplitWrite = 5000;
6261             }
6262 
6263             request.writeFlagsToAddAtOffset = flags;
6264             return request;
6265           }));
6266 
6267   clientConn.writeAndReflect(
6268       tenThousandIovec.data(), tenThousandIovec.size(), WriteFlags::NONE);
6269 
6270   EXPECT_THAT(
6271       clientConn.getSendmsgInvocations(),
6272       AllOf(
6273           Contains(AllOf(
6274               SendmsgInvocHasIovLastByte(tenThousandByteVec.data()),
6275               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6276               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
6277           Contains(AllOf(
6278               SendmsgInvocHasIovLastByte(tenThousandByteVec.data() + 1000),
6279               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6280               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
6281           Contains(AllOf(
6282               SendmsgInvocHasIovLastByte(tenThousandByteVec.data() + 5000),
6283               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6284               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
6285           Contains(AllOf(
6286               SendmsgInvocHasIovLastByte(tenThousandByteVec.data() + 9999),
6287               SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
6288               SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE)))));
6289 
6290   // verify WRITE events exist at the appropriate locations
6291   // we verify timestamp events are generated elsewhere
6292   //
6293   // should _not_ contain events for 99 as no prewrite for that
6294   EXPECT_THAT(
6295       filterToWriteEvents(observer->byteEvents),
6296       AllOf(
6297           Contains(ByteEventMatching(ByteEventType::WRITE, 0)),
6298           Contains(ByteEventMatching(ByteEventType::WRITE, 1000)),
6299           Contains(ByteEventMatching(ByteEventType::WRITE, 5000))));
6300 }
6301 
TEST_F(AsyncSocketByteEventTest,PrewriteMultipleObservers)6302 TEST_F(AsyncSocketByteEventTest, PrewriteMultipleObservers) {
6303   auto clientConn = getClientConn();
6304   clientConn.connect();
6305 
6306   // five observers
6307   // observer1 - 4 have byte events and prewrite enabled
6308   // observer5 has byte events enabled
6309   // observer6 has neither byte events or prewrite
6310   auto observer1 = clientConn.attachObserver(
6311       true /* enableByteEvents */, true /* enablePrewrite */);
6312   auto observer2 = clientConn.attachObserver(
6313       true /* enableByteEvents */, true /* enablePrewrite */);
6314   auto observer3 = clientConn.attachObserver(
6315       true /* enableByteEvents */, true /* enablePrewrite */);
6316   auto observer4 = clientConn.attachObserver(
6317       true /* enableByteEvents */, true /* enablePrewrite */);
6318   auto observer5 = clientConn.attachObserver(
6319       true /* enableByteEvents */, false /* enablePrewrite */);
6320   auto observer6 = clientConn.attachObserver(
6321       false /* enableByteEvents */, false /* enablePrewrite */);
6322 
6323   clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
6324 
6325   // observer 1 wants TX timestamps at 25, 50, 75
6326   ON_CALL(*observer1, prewriteMock(_, _))
6327       .WillByDefault(testing::Invoke(
6328           [](AsyncTransport*,
6329              const AsyncTransport::LifecycleObserver::PrewriteState& state) {
6330             AsyncTransport::LifecycleObserver::PrewriteRequest request;
6331             if (state.startOffset <= 25) {
6332               request.maybeOffsetToSplitWrite = 25;
6333             } else if (state.startOffset <= 50) {
6334               request.maybeOffsetToSplitWrite = 50;
6335             } else if (state.startOffset <= 75) {
6336               request.maybeOffsetToSplitWrite = 75;
6337             }
6338             request.writeFlagsToAddAtOffset = WriteFlags::TIMESTAMP_TX;
6339             return request;
6340           }));
6341 
6342   // observer 2 wants ACK timestamps at 35, 65, 75
6343   ON_CALL(*observer2, prewriteMock(_, _))
6344       .WillByDefault(testing::Invoke(
6345           [](AsyncTransport*,
6346              const AsyncTransport::LifecycleObserver::PrewriteState& state) {
6347             AsyncTransport::LifecycleObserver::PrewriteRequest request;
6348             if (state.startOffset <= 35) {
6349               request.maybeOffsetToSplitWrite = 35;
6350             } else if (state.startOffset <= 65) {
6351               request.maybeOffsetToSplitWrite = 65;
6352             } else if (state.startOffset <= 75) {
6353               request.maybeOffsetToSplitWrite = 75;
6354             }
6355             request.writeFlagsToAddAtOffset = WriteFlags::TIMESTAMP_ACK;
6356             return request;
6357           }));
6358 
6359   // observer 3 wants WRITE and SCHED flag on every write that occurs
6360   ON_CALL(*observer3, prewriteMock(_, _))
6361       .WillByDefault(
6362           testing::Invoke([](AsyncTransport*,
6363                              const AsyncTransport::LifecycleObserver::
6364                                  PrewriteState& /* state */) {
6365             AsyncTransport::LifecycleObserver::PrewriteRequest request;
6366             request.writeFlagsToAdd =
6367                 WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED;
6368             return request;
6369           }));
6370 
6371   // observer 4 has prewrite but makes no requests
6372   ON_CALL(*observer4, prewriteMock(_, _))
6373       .WillByDefault(
6374           testing::Invoke([](AsyncTransport*,
6375                              const AsyncTransport::LifecycleObserver::
6376                                  PrewriteState& /* state */) {
6377             AsyncTransport::LifecycleObserver::PrewriteRequest request;
6378             return request; // empty
6379           }));
6380 
6381   // no calls for observer 5 or observer 6
6382   EXPECT_CALL(*observer5, prewriteMock(_, _)).Times(0);
6383   EXPECT_CALL(*observer6, prewriteMock(_, _)).Times(0);
6384 
6385   // write
6386   clientConn.writeAndReflect(kOneHundredCharacterVec, WriteFlags::NONE);
6387 
6388   EXPECT_THAT(
6389       clientConn.getSendmsgInvocations(),
6390       ElementsAre(
6391           AllOf(
6392               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 25),
6393               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6394               SendmsgInvocAncillaryFlagsEq(
6395                   WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX)),
6396           AllOf(
6397               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 35),
6398               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6399               SendmsgInvocAncillaryFlagsEq(
6400                   WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_ACK)),
6401           AllOf(
6402               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 50),
6403               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6404               SendmsgInvocAncillaryFlagsEq(
6405                   WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX)),
6406           AllOf(
6407               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 65),
6408               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6409               SendmsgInvocAncillaryFlagsEq(
6410                   WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_ACK)),
6411           AllOf(
6412               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 75),
6413               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6414               SendmsgInvocAncillaryFlagsEq(
6415                   WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX |
6416                   WriteFlags::TIMESTAMP_ACK)),
6417           AllOf(
6418               SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
6419               SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
6420               SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_SCHED))));
6421 
6422   // verify WRITE events exist at the appropriate locations
6423   // we verify timestamp events are generated elsewhere
6424   for (const auto& observer : {observer1, observer2, observer3}) {
6425     EXPECT_THAT(
6426         filterToWriteEvents(observer->byteEvents),
6427         ElementsAre(
6428             ByteEventMatching(ByteEventType::WRITE, 25),
6429             ByteEventMatching(ByteEventType::WRITE, 35),
6430             ByteEventMatching(ByteEventType::WRITE, 50),
6431             ByteEventMatching(ByteEventType::WRITE, 65),
6432             ByteEventMatching(ByteEventType::WRITE, 75),
6433             ByteEventMatching(ByteEventType::WRITE, 99)));
6434   }
6435 }
6436 
6437 /**
6438  * Test prewrite with large write that enables testing of timestamps.
6439  *
6440  * We need to use a long vector to ensure that the kernel will not coalesce
6441  * the writes into a single SKB due to MSG_MORE.
6442  */
TEST_F(AsyncSocketByteEventTest,PrewriteTimestampedByteEvents)6443 TEST_F(AsyncSocketByteEventTest, PrewriteTimestampedByteEvents) {
6444   // need a large block of data to ensure that MSG_MORE doesn't limit us
6445   const auto hundredKBVec = get1000KBOfData();
6446   ASSERT_THAT(hundredKBVec, SizeIs(1000000));
6447 
6448   auto clientConn = getClientConn();
6449   clientConn.connect();
6450   auto observer = clientConn.attachObserver(
6451       true /* enableByteEvents */, true /* enablePrewrite */);
6452   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
6453   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
6454   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
6455 
6456   clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
6457 
6458   const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
6459       WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
6460   ON_CALL(*observer, prewriteMock(_, _))
6461       .WillByDefault(testing::Invoke(
6462           [](AsyncTransport*,
6463              const AsyncTransport::LifecycleObserver::PrewriteState& state) {
6464             AsyncTransport::LifecycleObserver::PrewriteRequest request;
6465             if (state.startOffset == 0) {
6466               request.maybeOffsetToSplitWrite = 0;
6467             } else if (state.startOffset <= 500000) {
6468               request.maybeOffsetToSplitWrite = 500000;
6469             } else {
6470               request.maybeOffsetToSplitWrite = 999999;
6471             }
6472 
6473             request.writeFlagsToAdd = flags;
6474             return request;
6475           }));
6476 
6477   clientConn.writeAndReflect(hundredKBVec, WriteFlags::NONE);
6478 
6479   EXPECT_THAT(
6480       clientConn.getSendmsgInvocations(),
6481       AllOf(
6482           Contains(AllOf(
6483               SendmsgInvocHasIovLastByte(hundredKBVec.data()),
6484               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6485               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
6486           Contains(AllOf(
6487               SendmsgInvocHasIovLastByte(hundredKBVec.data() + 500000),
6488               SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
6489               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
6490           Contains(AllOf(
6491               SendmsgInvocHasIovLastByte(hundredKBVec.data() + 999999),
6492               SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
6493               SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))))));
6494 
6495   // verify WRITE events exist at the appropriate locations
6496   EXPECT_THAT(
6497       filterToWriteEvents(observer->byteEvents),
6498       AllOf(
6499           Contains(ByteEventMatching(ByteEventType::WRITE, 0)),
6500           Contains(ByteEventMatching(ByteEventType::WRITE, 500000)),
6501           Contains(ByteEventMatching(ByteEventType::WRITE, 999999))));
6502 
6503   // verify SCHED, TX, and ACK events available at specified locations
6504   EXPECT_THAT(
6505       observer->byteEvents,
6506       AllOf(
6507           Contains(ByteEventMatching(ByteEventType::SCHED, 0)),
6508           Contains(ByteEventMatching(ByteEventType::TX, 0)),
6509           Contains(ByteEventMatching(ByteEventType::ACK, 0)),
6510           Contains(ByteEventMatching(ByteEventType::SCHED, 500000)),
6511           Contains(ByteEventMatching(ByteEventType::TX, 500000)),
6512           Contains(ByteEventMatching(ByteEventType::ACK, 500000)),
6513           Contains(ByteEventMatching(ByteEventType::SCHED, 999999)),
6514           Contains(ByteEventMatching(ByteEventType::TX, 999999)),
6515           Contains(ByteEventMatching(ByteEventType::ACK, 999999))));
6516 }
6517 
6518 /**
6519  * Test raw bytes written and bytes tried to write with prewrite.
6520  */
TEST_F(AsyncSocketByteEventTest,PrewriteRawBytesWrittenAndTriedToWrite)6521 TEST_F(AsyncSocketByteEventTest, PrewriteRawBytesWrittenAndTriedToWrite) {
6522   auto clientConn = getClientConn();
6523   clientConn.connect();
6524   auto observer = clientConn.attachObserver(
6525       true /* enableByteEvents */, true /* enablePrewrite */);
6526   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
6527   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
6528   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
6529 
6530   struct ExpectedSendmsgInvocation {
6531     size_t expectedTotalIovLen{0};
6532     ssize_t returnVal{0}; // number of bytes written or error val
6533     folly::Optional<size_t> maybeWriteEventExpectedOffset{};
6534     folly::Optional<WriteFlags> maybeWriteEventExpectedFlags{};
6535   };
6536 
6537   const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
6538       WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
6539 
6540   // first write
6541   //
6542   // no splits triggered by observer
6543   //
6544   // sendmsg will incrementally accept the bytes so we can test the values of
6545   // maybeRawBytesWritten and maybeRawBytesTriedToWrite
6546   {
6547     // bytes written per sendmsg call: 20, 10, 50, -1 (EAGAIN), 11, 99
6548     const std::vector<ExpectedSendmsgInvocation> expectedSendmsgInvocations{
6549         // {
6550         //    expectedTotalIovLen, returnVal,
6551         //    maybeWriteEventExpectedOffset, maybeWriteEventExpectedFlags
6552         // },
6553         {100, 20, 19, flags},
6554         {80, 10, 29, flags},
6555         {70, 50, 79, flags},
6556         {20, -1, folly::none, flags},
6557         {20, 11, 90, flags},
6558         {9, 9, 99, flags}};
6559 
6560     // prewrite will be called, we request all events
6561     EXPECT_CALL(*observer, prewriteMock(_, _))
6562         .Times(expectedSendmsgInvocations.size())
6563         .WillRepeatedly(testing::InvokeWithoutArgs([]() {
6564           AsyncTransport::LifecycleObserver::PrewriteRequest request = {};
6565           request.writeFlagsToAdd = flags;
6566           return request;
6567         }));
6568 
6569     // sendmsg will be called, we return # of bytes written
6570     {
6571       InSequence s;
6572       for (const auto& expectedInvocation : expectedSendmsgInvocations) {
6573         EXPECT_CALL(
6574             *(clientConn.getNetOpsDispatcher()),
6575             sendmsg(
6576                 _,
6577                 Pointee(SendmsgMsghdrHasTotalIovLen(
6578                     expectedInvocation.expectedTotalIovLen)),
6579                 _))
6580             .WillOnce(::testing::InvokeWithoutArgs([expectedInvocation]() {
6581               if (expectedInvocation.returnVal < 0) {
6582                 errno = EAGAIN; // returning error, set EAGAIN
6583               }
6584               return expectedInvocation.returnVal;
6585             }));
6586       }
6587     }
6588 
6589     // write
6590     // writes will be intercepted, so we don't need to read at other end
6591     WriteCallback wcb;
6592     clientConn.getRawSocket()->write(
6593         &wcb,
6594         kOneHundredCharacterVec.data(),
6595         kOneHundredCharacterVec.size(),
6596         WriteFlags::NONE);
6597     while (STATE_WAITING == wcb.state) {
6598       clientConn.getRawSocket()->getEventBase()->loopOnce();
6599     }
6600     ASSERT_EQ(STATE_SUCCEEDED, wcb.state);
6601 
6602     // check write events
6603     for (const auto& expectedInvocation : expectedSendmsgInvocations) {
6604       if (expectedInvocation.returnVal < 0) {
6605         // should be no WriteEvent since the return value was an error
6606         continue;
6607       }
6608 
6609       ASSERT_TRUE(expectedInvocation.maybeWriteEventExpectedOffset.has_value());
6610       const auto& expectedOffset =
6611           *expectedInvocation.maybeWriteEventExpectedOffset;
6612 
6613       auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
6614           expectedOffset, ByteEventType::WRITE);
6615       ASSERT_TRUE(maybeByteEvent.has_value());
6616       auto& byteEvent = maybeByteEvent.value();
6617 
6618       EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
6619       EXPECT_EQ(expectedOffset, byteEvent.offset);
6620       EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
6621       EXPECT_LT(
6622           std::chrono::steady_clock::now() - std::chrono::seconds(60),
6623           byteEvent.ts);
6624 
6625       EXPECT_EQ(
6626           expectedInvocation.maybeWriteEventExpectedFlags,
6627           byteEvent.maybeWriteFlags);
6628       EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
6629       EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
6630       EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());
6631 
6632       EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
6633       EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
6634 
6635       // what we really want to test
6636       EXPECT_EQ(
6637           folly::to_unsigned(expectedInvocation.returnVal),
6638           byteEvent.maybeRawBytesWritten);
6639       EXPECT_EQ(
6640           expectedInvocation.expectedTotalIovLen,
6641           byteEvent.maybeRawBytesTriedToWrite);
6642     }
6643   }
6644 
6645   // everything should have occurred by now
6646   clientConn.netOpsVerifyAndClearExpectations();
6647 
6648   // second write
6649   //
6650   // start offset is 100
6651   //
6652   // split at 150th byte triggered by observer
6653   //
6654   // sendmsg will incrementally accept the bytes so we can test the values of
6655   // maybeRawBytesWritten and maybeRawBytesTriedToWrite
6656   {
6657     // due to the split at the 150th byte, we expect sendmsg invocation to
6658     // only be called with bytes 100 -> 150 until after the 150th byte has been
6659     // written; in addition, the socket only accepts 20 of the 50 bytes the
6660     // first write.
6661     //
6662     // bytes written per sendmsg call: 20, 30, 50
6663     const std::vector<ExpectedSendmsgInvocation> expectedSendmsgInvocations{
6664         {50, 20, 119, flags | WriteFlags::CORK},
6665         {30, 30, 149, flags | WriteFlags::CORK},
6666         {50, 50, 199, flags}};
6667 
6668     // prewrite will be called, split at 50th byte (offset = 49)
6669     EXPECT_CALL(*observer, prewriteMock(_, _))
6670         .Times(expectedSendmsgInvocations.size())
6671         .WillRepeatedly(testing::Invoke(
6672             [](AsyncTransport*,
6673                const AsyncTransport::LifecycleObserver::PrewriteState& state) {
6674               AsyncTransport::LifecycleObserver::PrewriteRequest request;
6675               if (state.startOffset <= 149) {
6676                 request.maybeOffsetToSplitWrite = 149; // start offset = 100
6677               }
6678               request.writeFlagsToAdd = flags;
6679               return request;
6680             }));
6681 
6682     // sendmsg will be called, we return # of bytes written
6683     {
6684       InSequence s;
6685       for (const auto& expectedInvocation : expectedSendmsgInvocations) {
6686         EXPECT_CALL(
6687             *(clientConn.getNetOpsDispatcher()),
6688             sendmsg(
6689                 _,
6690                 Pointee(SendmsgMsghdrHasTotalIovLen(
6691                     expectedInvocation.expectedTotalIovLen)),
6692                 _))
6693             .WillOnce(::testing::InvokeWithoutArgs([expectedInvocation]() {
6694               return expectedInvocation.returnVal;
6695             }));
6696       }
6697     }
6698 
6699     // write
6700     // writes will be intercepted, so we don't need to read at other end
6701     WriteCallback wcb;
6702     clientConn.getRawSocket()->write(
6703         &wcb,
6704         kOneHundredCharacterVec.data(),
6705         kOneHundredCharacterVec.size(),
6706         WriteFlags::NONE);
6707     while (STATE_WAITING == wcb.state) {
6708       clientConn.getRawSocket()->getEventBase()->loopOnce();
6709     }
6710     ASSERT_EQ(STATE_SUCCEEDED, wcb.state);
6711 
6712     // check write events
6713     for (const auto& expectedInvocation : expectedSendmsgInvocations) {
6714       ASSERT_TRUE(expectedInvocation.maybeWriteEventExpectedOffset.has_value());
6715       const auto& expectedOffset =
6716           *expectedInvocation.maybeWriteEventExpectedOffset;
6717 
6718       auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
6719           expectedOffset, ByteEventType::WRITE);
6720       ASSERT_TRUE(maybeByteEvent.has_value());
6721       auto& byteEvent = maybeByteEvent.value();
6722 
6723       EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
6724       EXPECT_EQ(expectedOffset, byteEvent.offset);
6725       EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
6726       EXPECT_LT(
6727           std::chrono::steady_clock::now() - std::chrono::seconds(60),
6728           byteEvent.ts);
6729 
6730       EXPECT_EQ(
6731           expectedInvocation.maybeWriteEventExpectedFlags,
6732           byteEvent.maybeWriteFlags);
6733       EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
6734       EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
6735       EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());
6736 
6737       EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
6738       EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
6739 
6740       // what we really want to test
6741       EXPECT_EQ(
6742           folly::to_unsigned(expectedInvocation.returnVal),
6743           byteEvent.maybeRawBytesWritten);
6744       EXPECT_EQ(
6745           expectedInvocation.expectedTotalIovLen,
6746           byteEvent.maybeRawBytesTriedToWrite);
6747     }
6748   }
6749 }
6750 
6751 struct AsyncSocketByteEventDetailsTestParams {
6752   struct WriteParams {
WriteParamsAsyncSocketByteEventDetailsTestParams::WriteParams6753     WriteParams(uint64_t bufferSize, WriteFlags writeFlags)
6754         : bufferSize(bufferSize), writeFlags(writeFlags) {}
6755     uint64_t bufferSize{0};
6756     WriteFlags writeFlags{WriteFlags::NONE};
6757   };
6758 
6759   std::vector<WriteParams> writesWithParams;
6760 };
6761 
6762 class AsyncSocketByteEventDetailsTest
6763     : public AsyncSocketByteEventTest,
6764       public testing::WithParamInterface<
6765           AsyncSocketByteEventDetailsTestParams> {
6766  public:
getTestingValues()6767   static std::vector<AsyncSocketByteEventDetailsTestParams> getTestingValues() {
6768     const std::array<WriteFlags, 9> writeFlagCombinations{
6769         // SCHED
6770         WriteFlags::TIMESTAMP_SCHED,
6771         // TX
6772         WriteFlags::TIMESTAMP_TX,
6773         // ACK
6774         WriteFlags::TIMESTAMP_ACK,
6775         // SCHED + TX + ACK
6776         WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX |
6777             WriteFlags::TIMESTAMP_ACK,
6778         // WRITE
6779         WriteFlags::TIMESTAMP_WRITE,
6780         // WRITE + SCHED
6781         WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED,
6782         // WRITE + TX
6783         WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_TX,
6784         // WRITE + ACK
6785         WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_ACK,
6786         // WRITE + SCHED + TX + ACK
6787         WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
6788             WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK,
6789     };
6790 
6791     std::vector<AsyncSocketByteEventDetailsTestParams> vals;
6792     for (const auto& writeFlags : writeFlagCombinations) {
6793       // write 1 byte
6794       {
6795         AsyncSocketByteEventDetailsTestParams params;
6796         params.writesWithParams.emplace_back(1, writeFlags);
6797         vals.push_back(params);
6798       }
6799 
6800       // write 1 byte twice
6801       {
6802         AsyncSocketByteEventDetailsTestParams params;
6803         params.writesWithParams.emplace_back(1, writeFlags);
6804         params.writesWithParams.emplace_back(1, writeFlags);
6805         vals.push_back(params);
6806       }
6807 
6808       // write 10 bytes
6809       {
6810         AsyncSocketByteEventDetailsTestParams params;
6811         params.writesWithParams.emplace_back(10, writeFlags);
6812         vals.push_back(params);
6813       }
6814 
6815       // write 10 bytes twice
6816       {
6817         AsyncSocketByteEventDetailsTestParams params;
6818         params.writesWithParams.emplace_back(10, writeFlags);
6819         params.writesWithParams.emplace_back(10, writeFlags);
6820         vals.push_back(params);
6821       }
6822     }
6823 
6824     return vals;
6825   }
6826 };
6827 
6828 INSTANTIATE_TEST_SUITE_P(
6829     ByteEventDetailsTest,
6830     AsyncSocketByteEventDetailsTest,
6831     ::testing::ValuesIn(AsyncSocketByteEventDetailsTest::getTestingValues()));
6832 
6833 /**
6834  * Inspect ByteEvent fields, including xTimestampRequested in WRITE events.
6835  */
TEST_P(AsyncSocketByteEventDetailsTest,CheckByteEventDetails)6836 TEST_P(AsyncSocketByteEventDetailsTest, CheckByteEventDetails) {
6837   auto params = GetParam();
6838 
6839   auto clientConn = getClientConn();
6840   clientConn.connect();
6841   auto observer = clientConn.attachObserver(true /* enableByteEvents */);
6842   EXPECT_EQ(1, observer->byteEventsEnabledCalled);
6843   EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
6844   EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
6845 
6846   uint64_t expectedNumByteEvents = 0;
6847   for (const auto& writeParams : params.writesWithParams) {
6848     const std::vector<uint8_t> wbuf(writeParams.bufferSize, 'a');
6849     const auto flags = writeParams.writeFlags;
6850     clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
6851         dropWriteFromFlags(flags));
6852     clientConn.writeAndReflect(wbuf, flags);
6853     clientConn.netOpsVerifyAndClearExpectations();
6854     const auto expectedOffset =
6855         clientConn.getRawSocket()->getRawBytesWritten() - 1;
6856 
6857     // check WRITE
6858     if ((flags & WriteFlags::TIMESTAMP_WRITE) != WriteFlags::NONE) {
6859       expectedNumByteEvents++;
6860 
6861       auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
6862           expectedOffset, ByteEventType::WRITE);
6863       ASSERT_TRUE(maybeByteEvent.has_value());
6864       auto& byteEvent = maybeByteEvent.value();
6865 
6866       EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
6867       EXPECT_EQ(expectedOffset, byteEvent.offset);
6868       EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
6869       EXPECT_LT(
6870           std::chrono::steady_clock::now() - std::chrono::seconds(60),
6871           byteEvent.ts);
6872 
6873       EXPECT_EQ(flags, byteEvent.maybeWriteFlags);
6874       EXPECT_EQ(
6875           isSet(flags, WriteFlags::TIMESTAMP_SCHED),
6876           byteEvent.schedTimestampRequestedOnWrite());
6877       EXPECT_EQ(
6878           isSet(flags, WriteFlags::TIMESTAMP_TX),
6879           byteEvent.txTimestampRequestedOnWrite());
6880       EXPECT_EQ(
6881           isSet(flags, WriteFlags::TIMESTAMP_ACK),
6882           byteEvent.ackTimestampRequestedOnWrite());
6883 
6884       EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
6885       EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
6886     }
6887 
6888     // check SCHED, TX, ACK
6889     for (const auto& byteEventType :
6890          {ByteEventType::SCHED, ByteEventType::TX, ByteEventType::ACK}) {
6891       auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
6892           expectedOffset, byteEventType);
6893       switch (byteEventType) {
6894         case ByteEventType::WRITE:
6895           FAIL();
6896         case ByteEventType::SCHED:
6897           if ((flags & WriteFlags::TIMESTAMP_SCHED) == WriteFlags::NONE) {
6898             EXPECT_FALSE(maybeByteEvent.has_value());
6899             continue;
6900           }
6901           break;
6902         case ByteEventType::TX:
6903           if ((flags & WriteFlags::TIMESTAMP_TX) == WriteFlags::NONE) {
6904             EXPECT_FALSE(maybeByteEvent.has_value());
6905             continue;
6906           }
6907           break;
6908         case ByteEventType::ACK:
6909           if ((flags & WriteFlags::TIMESTAMP_ACK) == WriteFlags::NONE) {
6910             EXPECT_FALSE(maybeByteEvent.has_value());
6911             continue;
6912           }
6913           break;
6914       }
6915 
6916       expectedNumByteEvents++;
6917       ASSERT_TRUE(maybeByteEvent.has_value());
6918       auto& byteEvent = maybeByteEvent.value();
6919 
6920       EXPECT_EQ(byteEventType, byteEvent.type);
6921       EXPECT_EQ(expectedOffset, byteEvent.offset);
6922       EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
6923       EXPECT_LT(
6924           std::chrono::steady_clock::now() - std::chrono::seconds(60),
6925           byteEvent.ts);
6926       EXPECT_FALSE(byteEvent.maybeWriteFlags.has_value());
6927       // don't check *TimestampRequestedOnWrite* fields to avoid CHECK_DEATH,
6928       // already checked in CheckByteEventDetailsApplicationSetsFlags
6929 
6930       EXPECT_TRUE(byteEvent.maybeSoftwareTs.has_value());
6931       EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
6932     }
6933   }
6934 
6935   // should have at least expectedNumByteEvents
6936   // may be more if writes were split up by kernel
6937   EXPECT_THAT(observer->byteEvents, SizeIs(Ge(expectedNumByteEvents)));
6938 }
6939 
6940 class AsyncSocketByteEventHelperTest : public ::testing::Test {
6941  protected:
6942   using ByteEventType = AsyncTransport::ByteEvent::Type;
6943 
6944   /**
6945    * Wrapper around a vector containing cmsg header + data.
6946    */
6947   class WrappedCMsg {
6948    public:
WrappedCMsg(std::vector<char> && data)6949     explicit WrappedCMsg(std::vector<char>&& data) : data_(std::move(data)) {}
6950 
operator const struct cmsghdr&()6951     operator const struct cmsghdr &() {
6952       return *reinterpret_cast<struct cmsghdr*>(data_.data());
6953     }
6954 
6955    protected:
6956     std::vector<char> data_;
6957   };
6958 
6959   /**
6960    * Wrapper around a vector containing cmsg header + data.
6961    */
6962   class WrappedSockExtendedErrTsCMsg : public WrappedCMsg {
6963    public:
6964     using WrappedCMsg::WrappedCMsg;
6965 
6966     // ts[0] -> software timestamp
6967     // ts[1] -> hardware timestamp transformed to userspace time (deprecated)
6968     // ts[2] -> hardware timestamp
6969 
setSoftwareTimestamp(const std::chrono::seconds seconds,const std::chrono::nanoseconds nanoseconds)6970     void setSoftwareTimestamp(
6971         const std::chrono::seconds seconds,
6972         const std::chrono::nanoseconds nanoseconds) {
6973       struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data_.data())};
6974       struct scm_timestamping* tss{
6975           reinterpret_cast<struct scm_timestamping*>(CMSG_DATA(cmsg))};
6976       tss->ts[0].tv_sec = seconds.count();
6977       tss->ts[0].tv_nsec = nanoseconds.count();
6978     }
6979 
setHardwareTimestamp(const std::chrono::seconds seconds,const std::chrono::nanoseconds nanoseconds)6980     void setHardwareTimestamp(
6981         const std::chrono::seconds seconds,
6982         const std::chrono::nanoseconds nanoseconds) {
6983       struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data_.data())};
6984       struct scm_timestamping* tss{
6985           reinterpret_cast<struct scm_timestamping*>(CMSG_DATA(cmsg))};
6986       tss->ts[2].tv_sec = seconds.count();
6987       tss->ts[2].tv_nsec = nanoseconds.count();
6988     }
6989   };
6990 
cmsgData(int level,int type,size_t len)6991   static std::vector<char> cmsgData(int level, int type, size_t len) {
6992     std::vector<char> data(CMSG_LEN(len), 0);
6993     struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data.data())};
6994     cmsg->cmsg_level = level;
6995     cmsg->cmsg_type = type;
6996     cmsg->cmsg_len = CMSG_LEN(len);
6997     return data;
6998   }
6999 
cmsgForSockExtendedErrTimestamping()7000   static WrappedSockExtendedErrTsCMsg cmsgForSockExtendedErrTimestamping() {
7001     return WrappedSockExtendedErrTsCMsg(
7002         cmsgData(SOL_SOCKET, SO_TIMESTAMPING, sizeof(struct scm_timestamping)));
7003   }
7004 
cmsgForScmTimestamping(const uint32_t type,const uint32_t kernelByteOffset)7005   static WrappedCMsg cmsgForScmTimestamping(
7006       const uint32_t type, const uint32_t kernelByteOffset) {
7007     auto data = cmsgData(SOL_IP, IP_RECVERR, sizeof(struct sock_extended_err));
7008     struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data.data())};
7009     struct sock_extended_err* serr{
7010         reinterpret_cast<struct sock_extended_err*>(CMSG_DATA(cmsg))};
7011     serr->ee_errno = ENOMSG;
7012     serr->ee_origin = SO_EE_ORIGIN_TIMESTAMPING;
7013     serr->ee_info = type;
7014     serr->ee_data = kernelByteOffset;
7015     return WrappedCMsg(std::move(data));
7016   }
7017 };
7018 
/**
 * The byte-offset cmsg arrives first, then the timestamp cmsg: only the
 * second call is expected to report a completed event.
 */
TEST_F(AsyncSocketByteEventHelperTest, ByteOffsetThenTs) {
  // SCM_TSTAMP_SND event for kernel byte offset 0
  auto offsetCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);

  // software timestamp of 59s + 11ns
  auto tsCmsg = cmsgForSockExtendedErrTimestamping();
  tsCmsg.setSoftwareTimestamp(
      std::chrono::seconds(59), std::chrono::nanoseconds(11));

  AsyncSocket::ByteEventHelper helper{};
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.byteEventsEnabled = true;

  EXPECT_FALSE(helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));
}
7033 
/**
 * The timestamp cmsg arrives first, then the byte-offset cmsg: only the
 * second call is expected to report a completed event.
 */
TEST_F(AsyncSocketByteEventHelperTest, TsThenByteOffset) {
  // SCM_TSTAMP_SND event for kernel byte offset 0
  auto offsetCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);

  // software timestamp of 59s + 11ns
  auto tsCmsg = cmsgForSockExtendedErrTimestamping();
  tsCmsg.setSoftwareTimestamp(
      std::chrono::seconds(59), std::chrono::nanoseconds(11));

  AsyncSocket::ByteEventHelper helper{};
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.byteEventsEnabled = true;

  EXPECT_FALSE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */));
}
7048 
/**
 * With byteEventsEnabled = false, processCmsg is expected to report nothing;
 * after enabling, the same pair of cmsgs completes an event.
 */
TEST_F(AsyncSocketByteEventHelperTest, ByteEventsDisabled) {
  // SCM_TSTAMP_SND event for kernel byte offset 0
  auto offsetCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);

  // software timestamp of 59s + 11ns
  auto tsCmsg = cmsgForSockExtendedErrTimestamping();
  tsCmsg.setSoftwareTimestamp(
      std::chrono::seconds(59), std::chrono::nanoseconds(11));

  AsyncSocket::ByteEventHelper helper{};
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.byteEventsEnabled = false;

  // fails because disabled
  EXPECT_FALSE(helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */));
  EXPECT_FALSE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));

  // enable, try again to prove this works
  helper.byteEventsEnabled = true;
  EXPECT_FALSE(helper.processCmsg(offsetCmsg, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(tsCmsg, 1 /* rawBytesWritten */));
}
7069 
TEST_F(AsyncSocketByteEventHelperTest, IgnoreUnsupportedEvent) {
  // An SCM type the helper does not know about is consumed without producing
  // an event; a known type (ACK) then works normally.
  constexpr size_t kRawBytesWritten = 1;
  const std::chrono::seconds kSoftwareTsSec{59};
  const std::chrono::nanoseconds kSoftwareTsNs{11};

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;

  // imaginary new type of SCM event
  const auto unknownScmType = folly::netops::SCM_TSTAMP_ACK + 10;
  auto offsetCmsg = cmsgForScmTimestamping(unknownScmType, 0);
  auto timestampCmsg = cmsgForSockExtendedErrTimestamping();
  timestampCmsg.setSoftwareTimestamp(kSoftwareTsSec, kSoftwareTsNs);

  // the unsupported event is eaten
  EXPECT_FALSE(helper.processCmsg(offsetCmsg, kRawBytesWritten));
  EXPECT_FALSE(helper.processCmsg(timestampCmsg, kRawBytesWritten));

  // switch to a supported type and repeat to prove the type was the blocker
  offsetCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_ACK, 0);
  EXPECT_FALSE(helper.processCmsg(offsetCmsg, kRawBytesWritten));
  EXPECT_TRUE(helper.processCmsg(timestampCmsg, kRawBytesWritten));
}
7092 
TEST_F(AsyncSocketByteEventHelperTest, ErrorDoubleScmCmsg) {
  // Two byte-offset cmsgs in a row (no timestamp cmsg in between) are expected
  // to make the helper throw ByteEventHelper::Exception.
  constexpr size_t kRawBytesWritten = 1;

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;

  auto offsetCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);
  EXPECT_FALSE(helper.processCmsg(offsetCmsg, kRawBytesWritten));
  EXPECT_THROW(
      helper.processCmsg(offsetCmsg, kRawBytesWritten),
      AsyncSocket::ByteEventHelper::Exception);
}
7104 
TEST_F(AsyncSocketByteEventHelperTest, ErrorDoubleSerrCmsg) {
  // Two timestamp cmsgs in a row (no byte-offset cmsg in between) are expected
  // to make the helper throw ByteEventHelper::Exception.
  constexpr size_t kRawBytesWritten = 1;
  const std::chrono::seconds kSoftwareTsSec{59};
  const std::chrono::nanoseconds kSoftwareTsNs{11};

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;

  auto timestampCmsg = cmsgForSockExtendedErrTimestamping();
  timestampCmsg.setSoftwareTimestamp(kSoftwareTsSec, kSoftwareTsNs);

  EXPECT_FALSE(helper.processCmsg(timestampCmsg, kRawBytesWritten));
  EXPECT_THROW(
      helper.processCmsg(timestampCmsg, kRawBytesWritten),
      AsyncSocket::ByteEventHelper::Exception);
}
7119 
TEST_F(AsyncSocketByteEventHelperTest, ErrorExceptionSet) {
  // Once the helper has a stored exception (maybeEx) it produces no events;
  // clearing the exception restores normal processing.
  constexpr size_t kRawBytesWritten = 1;
  const std::chrono::seconds kSoftwareTsSec{59};
  const std::chrono::nanoseconds kSoftwareTsNs{11};

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.maybeEx = AsyncSocketException(
      AsyncSocketException::AsyncSocketExceptionType::UNKNOWN, "");

  auto offsetCmsg = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);
  auto timestampCmsg = cmsgForSockExtendedErrTimestamping();
  timestampCmsg.setSoftwareTimestamp(kSoftwareTsSec, kSoftwareTsNs);

  // blocked by the pre-existing exception
  EXPECT_FALSE(helper.processCmsg(offsetCmsg, kRawBytesWritten));
  EXPECT_FALSE(helper.processCmsg(timestampCmsg, kRawBytesWritten));

  // clear it and repeat to prove the exception was the blocker
  helper.maybeEx = folly::none;
  EXPECT_FALSE(helper.processCmsg(offsetCmsg, kRawBytesWritten));
  EXPECT_TRUE(helper.processCmsg(timestampCmsg, kRawBytesWritten));
}
7142 
7143 struct AsyncSocketByteEventHelperTimestampTestParams {
AsyncSocketByteEventHelperTimestampTestParamsAsyncSocketByteEventHelperTimestampTestParams7144   AsyncSocketByteEventHelperTimestampTestParams(
7145       uint32_t scmType,
7146       AsyncTransport::ByteEvent::Type expectedByteEventType,
7147       bool includeSoftwareTs,
7148       bool includeHardwareTs)
7149       : scmType(scmType),
7150         expectedByteEventType(expectedByteEventType),
7151         includeSoftwareTs(includeSoftwareTs),
7152         includeHardwareTs(includeHardwareTs) {}
7153   uint32_t scmType{0};
7154   AsyncTransport::ByteEvent::Type expectedByteEventType;
7155   bool includeSoftwareTs{false};
7156   bool includeHardwareTs{false};
7157 };
7158 
7159 class AsyncSocketByteEventHelperTimestampTest
7160     : public AsyncSocketByteEventHelperTest,
7161       public testing::WithParamInterface<
7162           AsyncSocketByteEventHelperTimestampTestParams> {
7163  public:
7164   static std::vector<AsyncSocketByteEventHelperTimestampTestParams>
getTestingValues()7165   getTestingValues() {
7166     std::vector<AsyncSocketByteEventHelperTimestampTestParams> vals;
7167 
7168     // software + hardware timestamps
7169     {
7170       vals.emplace_back(
7171           folly::netops::SCM_TSTAMP_SCHED, ByteEventType::SCHED, true, true);
7172       vals.emplace_back(
7173           folly::netops::SCM_TSTAMP_SND, ByteEventType::TX, true, true);
7174       vals.emplace_back(
7175           folly::netops::SCM_TSTAMP_ACK, ByteEventType::ACK, true, true);
7176     }
7177 
7178     // software ts only
7179     {
7180       vals.emplace_back(
7181           folly::netops::SCM_TSTAMP_SCHED, ByteEventType::SCHED, true, false);
7182       vals.emplace_back(
7183           folly::netops::SCM_TSTAMP_SND, ByteEventType::TX, true, false);
7184       vals.emplace_back(
7185           folly::netops::SCM_TSTAMP_ACK, ByteEventType::ACK, true, false);
7186     }
7187 
7188     // hardware ts only
7189     {
7190       vals.emplace_back(
7191           folly::netops::SCM_TSTAMP_SCHED, ByteEventType::SCHED, false, true);
7192       vals.emplace_back(
7193           folly::netops::SCM_TSTAMP_SND, ByteEventType::TX, false, true);
7194       vals.emplace_back(
7195           folly::netops::SCM_TSTAMP_ACK, ByteEventType::ACK, false, true);
7196     }
7197 
7198     return vals;
7199   }
7200 };
7201 
// Instantiate every TEST_P of AsyncSocketByteEventHelperTimestampTest once per
// entry returned by getTestingValues().
INSTANTIATE_TEST_SUITE_P(
    ByteEventTimestampTest,
    AsyncSocketByteEventHelperTimestampTest,
    ::testing::ValuesIn(
        AsyncSocketByteEventHelperTimestampTest::getTestingValues()));
7207 
7208 /**
7209  * Check timestamp parsing for software and hardware timestamps.
7210  */
TEST_P(AsyncSocketByteEventHelperTimestampTest,CheckEventTimestamps)7211 TEST_P(AsyncSocketByteEventHelperTimestampTest, CheckEventTimestamps) {
7212   const auto softwareTsSec = std::chrono::seconds(59);
7213   const auto softwareTsNs = std::chrono::nanoseconds(11);
7214   const auto hardwareTsSec = std::chrono::seconds(79);
7215   const auto hardwareTsNs = std::chrono::nanoseconds(31);
7216 
7217   auto params = GetParam();
7218   auto scmTs = cmsgForScmTimestamping(params.scmType, 0);
7219   auto serrTs = cmsgForSockExtendedErrTimestamping();
7220   if (params.includeSoftwareTs) {
7221     serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);
7222   }
7223   if (params.includeHardwareTs) {
7224     serrTs.setHardwareTimestamp(hardwareTsSec, hardwareTsNs);
7225   }
7226 
7227   AsyncSocket::ByteEventHelper helper = {};
7228   helper.byteEventsEnabled = true;
7229   helper.rawBytesWrittenWhenByteEventsEnabled = 0;
7230   folly::Optional<AsyncTransport::ByteEvent> maybeByteEvent;
7231   maybeByteEvent = helper.processCmsg(serrTs, 1 /* rawBytesWritten */);
7232   EXPECT_FALSE(maybeByteEvent.has_value());
7233   maybeByteEvent = helper.processCmsg(scmTs, 1 /* rawBytesWritten */);
7234 
7235   // common checks
7236   ASSERT_TRUE(maybeByteEvent.has_value());
7237   const auto& byteEvent = *maybeByteEvent;
7238   EXPECT_EQ(0, byteEvent.offset);
7239   EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
7240 
7241   EXPECT_EQ(params.expectedByteEventType, byteEvent.type);
7242   if (params.includeSoftwareTs) {
7243     EXPECT_EQ(softwareTsSec + softwareTsNs, byteEvent.maybeSoftwareTs);
7244   }
7245   if (params.includeHardwareTs) {
7246     EXPECT_EQ(hardwareTsSec + hardwareTsNs, byteEvent.maybeHardwareTs);
7247   }
7248 }
7249 
/**
 * One parameterization of AsyncSocketByteEventHelperOffsetTest.
 *
 * Every field has a default initializer. The test builders default-construct
 * this struct before filling it in, and no member is ever left indeterminate.
 */
struct AsyncSocketByteEventHelperOffsetTestParams {
  // raw bytes AsyncSocket had written when byte events were enabled
  uint64_t rawBytesWrittenWhenByteEventsEnabled{0};
  // offset of the byte that was timestamped (AsyncSocket numbering)
  uint64_t byteTimestamped{0};
  // raw bytes AsyncSocket had written when the timestamp was received
  uint64_t rawBytesWrittenWhenTimestampReceived{0};
};
7255 
7256 class AsyncSocketByteEventHelperOffsetTest
7257     : public AsyncSocketByteEventHelperTest,
7258       public testing::WithParamInterface<
7259           AsyncSocketByteEventHelperOffsetTestParams> {
7260  public:
7261   static std::vector<AsyncSocketByteEventHelperOffsetTestParams>
getTestingValues()7262   getTestingValues() {
7263     std::vector<AsyncSocketByteEventHelperOffsetTestParams> vals;
7264     const std::array<uint64_t, 5> rawBytesWrittenWhenByteEventsEnabledVals{
7265         0, 1, 100, 4294967295, 4294967296};
7266     for (const auto& rawBytesWrittenWhenByteEventsEnabled :
7267          rawBytesWrittenWhenByteEventsEnabledVals) {
7268       auto addParams = [&](auto params) {
7269         // check if case is valid based on rawBytesWrittenWhenByteEventsEnabled
7270         if (rawBytesWrittenWhenByteEventsEnabled <= params.byteTimestamped) {
7271           vals.push_back(params);
7272         }
7273       };
7274 
7275       // case 1
7276       // bytes sent on receipt of timestamp == byte timestamped
7277       {
7278         AsyncSocketByteEventHelperOffsetTestParams params;
7279         params.rawBytesWrittenWhenByteEventsEnabled =
7280             rawBytesWrittenWhenByteEventsEnabled;
7281         params.byteTimestamped = 0;
7282         params.rawBytesWrittenWhenTimestampReceived = 0;
7283         addParams(params);
7284       }
7285       {
7286         AsyncSocketByteEventHelperOffsetTestParams params;
7287         params.rawBytesWrittenWhenByteEventsEnabled =
7288             rawBytesWrittenWhenByteEventsEnabled;
7289         params.byteTimestamped = 1;
7290         params.rawBytesWrittenWhenTimestampReceived = 1;
7291         addParams(params);
7292       }
7293       {
7294         AsyncSocketByteEventHelperOffsetTestParams params;
7295         params.rawBytesWrittenWhenByteEventsEnabled =
7296             rawBytesWrittenWhenByteEventsEnabled;
7297         params.byteTimestamped = 101;
7298         params.rawBytesWrittenWhenTimestampReceived = 101;
7299         addParams(params);
7300       }
7301 
7302       // bytes sent on receipt of timestamp > byte timestamped
7303       {
7304         AsyncSocketByteEventHelperOffsetTestParams params;
7305         params.rawBytesWrittenWhenByteEventsEnabled =
7306             rawBytesWrittenWhenByteEventsEnabled;
7307         params.byteTimestamped = 1;
7308         params.rawBytesWrittenWhenTimestampReceived = 2;
7309         addParams(params);
7310       }
7311       {
7312         AsyncSocketByteEventHelperOffsetTestParams params;
7313         params.rawBytesWrittenWhenByteEventsEnabled =
7314             rawBytesWrittenWhenByteEventsEnabled;
7315         params.byteTimestamped = 101;
7316         params.rawBytesWrittenWhenTimestampReceived = 102;
7317         addParams(params);
7318       }
7319 
7320       // case 2
7321       // bytes sent on receipt of timestamp == byte timestamped, boundary test
7322       // (boundary is at 2^32)
7323       {
7324         AsyncSocketByteEventHelperOffsetTestParams params;
7325         params.rawBytesWrittenWhenByteEventsEnabled =
7326             rawBytesWrittenWhenByteEventsEnabled;
7327         params.byteTimestamped = 4294967294;
7328         params.rawBytesWrittenWhenTimestampReceived = 4294967294;
7329         addParams(params);
7330       }
7331       {
7332         AsyncSocketByteEventHelperOffsetTestParams params;
7333         params.rawBytesWrittenWhenByteEventsEnabled =
7334             rawBytesWrittenWhenByteEventsEnabled;
7335         params.byteTimestamped = 4294967295;
7336         params.rawBytesWrittenWhenTimestampReceived = 4294967295;
7337         addParams(params);
7338       }
7339       {
7340         AsyncSocketByteEventHelperOffsetTestParams params;
7341         params.rawBytesWrittenWhenByteEventsEnabled =
7342             rawBytesWrittenWhenByteEventsEnabled;
7343         params.byteTimestamped = 4294967296;
7344         params.rawBytesWrittenWhenTimestampReceived = 4294967296;
7345         addParams(params);
7346       }
7347       {
7348         AsyncSocketByteEventHelperOffsetTestParams params;
7349         params.rawBytesWrittenWhenByteEventsEnabled =
7350             rawBytesWrittenWhenByteEventsEnabled;
7351         params.byteTimestamped = 4294967297;
7352         params.rawBytesWrittenWhenTimestampReceived = 4294967297;
7353         addParams(params);
7354       }
7355       {
7356         AsyncSocketByteEventHelperOffsetTestParams params;
7357         params.rawBytesWrittenWhenByteEventsEnabled =
7358             rawBytesWrittenWhenByteEventsEnabled;
7359         params.byteTimestamped = 4294967298;
7360         params.rawBytesWrittenWhenTimestampReceived = 4294967298;
7361         addParams(params);
7362       }
7363 
7364       // case 3
7365       // bytes sent on receipt of timestamp > byte timestamped, boundary test
7366       // (boundary is at 2^32)
7367       {
7368         AsyncSocketByteEventHelperOffsetTestParams params;
7369         params.rawBytesWrittenWhenByteEventsEnabled =
7370             rawBytesWrittenWhenByteEventsEnabled;
7371         params.byteTimestamped = 4294967293;
7372         params.rawBytesWrittenWhenTimestampReceived = 4294967294;
7373         addParams(params);
7374       }
7375       {
7376         AsyncSocketByteEventHelperOffsetTestParams params;
7377         params.rawBytesWrittenWhenByteEventsEnabled =
7378             rawBytesWrittenWhenByteEventsEnabled;
7379         params.byteTimestamped = 4294967294;
7380         params.rawBytesWrittenWhenTimestampReceived = 4294967295;
7381         addParams(params);
7382       }
7383       {
7384         AsyncSocketByteEventHelperOffsetTestParams params;
7385         params.rawBytesWrittenWhenByteEventsEnabled =
7386             rawBytesWrittenWhenByteEventsEnabled;
7387         params.byteTimestamped = 4294967295;
7388         params.rawBytesWrittenWhenTimestampReceived = 4294967296;
7389         addParams(params);
7390       }
7391       {
7392         AsyncSocketByteEventHelperOffsetTestParams params;
7393         params.rawBytesWrittenWhenByteEventsEnabled =
7394             rawBytesWrittenWhenByteEventsEnabled;
7395         params.byteTimestamped = 4294967296;
7396         params.rawBytesWrittenWhenTimestampReceived = 4294967297;
7397         addParams(params);
7398       }
7399 
7400       // case 4
7401       // bytes sent on receipt of timestamp > byte timestamped, wrap test
7402       // (boundary is at 2^32)
7403       {
7404         AsyncSocketByteEventHelperOffsetTestParams params;
7405         params.rawBytesWrittenWhenByteEventsEnabled =
7406             rawBytesWrittenWhenByteEventsEnabled;
7407         params.byteTimestamped = 4294967275;
7408         params.rawBytesWrittenWhenTimestampReceived = 4294967305;
7409         addParams(params);
7410       }
7411       {
7412         AsyncSocketByteEventHelperOffsetTestParams params;
7413         params.rawBytesWrittenWhenByteEventsEnabled =
7414             rawBytesWrittenWhenByteEventsEnabled;
7415         params.byteTimestamped = 4294967295;
7416         params.rawBytesWrittenWhenTimestampReceived = 4294967296;
7417         addParams(params);
7418       }
7419       {
7420         AsyncSocketByteEventHelperOffsetTestParams params;
7421         params.rawBytesWrittenWhenByteEventsEnabled =
7422             rawBytesWrittenWhenByteEventsEnabled;
7423         params.byteTimestamped = 4294967285;
7424         params.rawBytesWrittenWhenTimestampReceived = 4294967305;
7425         addParams(params);
7426       }
7427 
7428       // case 5
7429       // special case when timestamp enabled when bytes transferred > (2^32)
7430       // bytes sent on receipt of timestamp == byte timestamped, boundary test
7431       // (boundary is at 2^32)
7432       {
7433         AsyncSocketByteEventHelperOffsetTestParams params;
7434         params.rawBytesWrittenWhenByteEventsEnabled =
7435             rawBytesWrittenWhenByteEventsEnabled;
7436         params.byteTimestamped = 6442450943;
7437         params.rawBytesWrittenWhenTimestampReceived = 6442450943;
7438         addParams(params);
7439       }
7440 
7441       // case 6
7442       // special case when timestamp enabled when bytes transferred > (2^32)
7443       // bytes sent on receipt of timestamp > byte timestamped, boundary test
7444       // (boundary is at 2^32)
7445       {
7446         AsyncSocketByteEventHelperOffsetTestParams params;
7447         params.rawBytesWrittenWhenByteEventsEnabled =
7448             rawBytesWrittenWhenByteEventsEnabled;
7449         params.byteTimestamped = 6442450943;
7450         params.rawBytesWrittenWhenTimestampReceived = 6442450944;
7451         addParams(params);
7452       }
7453 
7454       // case 7
7455       // special case when timestamp enabled when bytes transferred > (2^32)
7456       // bytes sent on receipt of timestamp > byte timestamped, wrap test
7457       // (boundary is at 2^32)
7458       {
7459         AsyncSocketByteEventHelperOffsetTestParams params;
7460         params.rawBytesWrittenWhenByteEventsEnabled =
7461             rawBytesWrittenWhenByteEventsEnabled;
7462         params.byteTimestamped = 6442450943;
7463         params.rawBytesWrittenWhenTimestampReceived = 8589934591;
7464         addParams(params);
7465       }
7466     }
7467 
7468     return vals;
7469   }
7470 };
7471 
// Instantiate every TEST_P of AsyncSocketByteEventHelperOffsetTest once per
// entry returned by getTestingValues().
INSTANTIATE_TEST_SUITE_P(
    ByteEventOffsetTest,
    AsyncSocketByteEventHelperOffsetTest,
    ::testing::ValuesIn(
        AsyncSocketByteEventHelperOffsetTest::getTestingValues()));
7477 
7478 /**
7479  * Check byte offset handling, including boundary cases.
7480  *
7481  * See AsyncSocket::ByteEventHelper::processCmsg for details.
7482  */
TEST_P(AsyncSocketByteEventHelperOffsetTest,CheckCalculatedOffset)7483 TEST_P(AsyncSocketByteEventHelperOffsetTest, CheckCalculatedOffset) {
7484   auto params = GetParam();
7485 
7486   // because we use SOF_TIMESTAMPING_OPT_ID, byte offsets delivered from the
7487   // kernel are offset (relative to bytes written by AsyncSocket) by the number
7488   // of bytes AsyncSocket had written to the socket when enabling timestamps
7489   //
7490   // here we calculate what the kernel offset would be for the given byte offset
7491   const uint64_t bytesPerOffsetWrap =
7492       static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1;
7493 
7494   auto kernelByteOffset =
7495       params.byteTimestamped - params.rawBytesWrittenWhenByteEventsEnabled;
7496   if (kernelByteOffset > 0) {
7497     kernelByteOffset = kernelByteOffset % bytesPerOffsetWrap;
7498   }
7499 
7500   auto scmTs =
7501       cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, kernelByteOffset);
7502   const auto softwareTsSec = std::chrono::seconds(59);
7503   const auto softwareTsNs = std::chrono::nanoseconds(11);
7504   auto serrTs = cmsgForSockExtendedErrTimestamping();
7505   serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);
7506 
7507   AsyncSocket::ByteEventHelper helper = {};
7508   helper.byteEventsEnabled = true;
7509   helper.rawBytesWrittenWhenByteEventsEnabled =
7510       params.rawBytesWrittenWhenByteEventsEnabled;
7511 
7512   EXPECT_FALSE(helper.processCmsg(
7513       scmTs,
7514       params.rawBytesWrittenWhenTimestampReceived /* rawBytesWritten */));
7515   const auto maybeByteEvent = helper.processCmsg(
7516       serrTs,
7517       params.rawBytesWrittenWhenTimestampReceived /* rawBytesWritten */);
7518   ASSERT_TRUE(maybeByteEvent.has_value());
7519   const auto& byteEvent = *maybeByteEvent;
7520 
7521   EXPECT_EQ(params.byteTimestamped, byteEvent.offset);
7522   EXPECT_EQ(softwareTsSec + softwareTsNs, byteEvent.maybeSoftwareTs);
7523 }
7524 
7525 #endif // FOLLY_HAVE_SO_TIMESTAMPING
7526 
TEST(AsyncSocket, LifecycleCtorCallback) {
  EventBase evb;
  // create socket and verify that w/o a ctor callback, nothing happens
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_EQ(socket1->getLifecycleObservers().size(), 0);

  // Then register a ctor callback that registers a mock lifecycle observer
  // NB: use nicemock instead of strict b/c the actual lifecycle testing
  // is done below and this simplifies the test
  auto lifecycleCB =
      std::make_shared<NiceMock<MockAsyncSocketLifecycleObserver>>();

  // This test never removes the constructor callback, so it can outlive
  // lifecycleCB. Capture a weak_ptr instead of a raw pointer. Once lifecycleCB
  // is destroyed, sockets constructed later (e.g., by other tests in the same
  // binary) neither receive a dangling observer pointer nor pick up an
  // unexpected observer.
  std::weak_ptr<NiceMock<MockAsyncSocketLifecycleObserver>> weakLifecycleCB =
      lifecycleCB;
  ConstructorCallback<AsyncSocket>::addNewConstructorCallback(
      [weakLifecycleCB](AsyncSocket* s) {
        if (auto observer = weakLifecycleCB.lock()) {
          s->addLifecycleObserver(observer.get());
        }
      });

  // a socket constructed while lifecycleCB is alive gets the observer attached
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_EQ(socket2->getLifecycleObservers().size(), 1);
  EXPECT_THAT(
      socket2->getLifecycleObservers(),
      UnorderedElementsAre(lifecycleCB.get()));
  Mock::VerifyAndClearExpectations(lifecycleCB.get());
}
7551 
TEST(AsyncSocket, LifecycleObserverDetachAndAttachEvb) {
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  EventBase evb;
  EventBase evb2;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));

  EXPECT_CALL(*observer, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(observer.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(observer.get()));
  Mock::VerifyAndClearExpectations(observer.get());

  // Detaches the socket from `from` and attaches it to `to`, checking that the
  // observer sees each step and the socket reports the right EventBase.
  const auto moveSocket = [&](EventBase* from, EventBase* to) {
    EXPECT_CALL(*observer, evbDetachMock(socket.get(), from));
    socket->detachEventBase();
    EXPECT_EQ(nullptr, socket->getEventBase());
    Mock::VerifyAndClearExpectations(observer.get());

    EXPECT_CALL(*observer, evbAttachMock(socket.get(), to));
    socket->attachEventBase(to);
    EXPECT_EQ(to, socket->getEventBase());
    Mock::VerifyAndClearExpectations(observer.get());
  };

  moveSocket(&evb, &evb2); // move onto the new evb2
  moveSocket(&evb2, &evb); // and back onto the original evb

  // destroying the socket notifies the observer
  EXPECT_CALL(*observer, destroyMock(socket.get()));
  socket.reset();
  Mock::VerifyAndClearExpectations(observer.get());
}
7589 
TEST(AsyncSocket, LifecycleObserverAttachThenDestroySocket) {
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  auto* const rawSocket = socket.get();

  // attaching is itself reported to the observer
  EXPECT_CALL(*observer, observerAttachMock(rawSocket));
  socket->addLifecycleObserver(observer.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(observer.get()));
  Mock::VerifyAndClearExpectations(observer.get());

  // a successful connect reports the attempt, the fd attach and the success
  EXPECT_CALL(*observer, connectAttemptMock(rawSocket));
  EXPECT_CALL(*observer, fdAttachMock(rawSocket));
  EXPECT_CALL(*observer, connectSuccessMock(rawSocket));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(observer.get());

  // destroying the connected socket reports close, then destroy
  InSequence seq;
  EXPECT_CALL(*observer, closeMock(rawSocket));
  EXPECT_CALL(*observer, destroyMock(rawSocket));
  socket.reset();
  Mock::VerifyAndClearExpectations(observer.get());
}
7614 
TEST(AsyncSocket, LifecycleObserverAttachThenConnectError) {
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  // port 1 on localhost is expected to be unreachable
  const folly::SocketAddress unreachable{"::1", 1};

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  auto* const rawSocket = socket.get();

  EXPECT_CALL(*observer, observerAttachMock(rawSocket));
  socket->addLifecycleObserver(observer.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(observer.get()));
  Mock::VerifyAndClearExpectations(observer.get());

  // the current state machine calls AsyncSocket::invokeConnectionError() twice
  // for this use-case...
  EXPECT_CALL(*observer, connectAttemptMock(rawSocket));
  EXPECT_CALL(*observer, fdAttachMock(rawSocket));
  EXPECT_CALL(*observer, connectErrorMock(rawSocket, _)).Times(2);
  EXPECT_CALL(*observer, closeMock(rawSocket));
  socket->connect(nullptr, unreachable, 1);
  evb.loop();
  Mock::VerifyAndClearExpectations(observer.get());

  // the already-closed socket only reports destroy
  EXPECT_CALL(*observer, destroyMock(rawSocket));
  socket.reset();
  Mock::VerifyAndClearExpectations(observer.get());
}
7641 
TEST(AsyncSocket, LifecycleObserverMultipleAttachThenDestroySocket) {
  auto first = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  auto second =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  auto* const rawSocket = socket.get();

  // Verifies and clears pending expectations on both observers.
  const auto verifyAndClearBoth = [&] {
    Mock::VerifyAndClearExpectations(first.get());
    Mock::VerifyAndClearExpectations(second.get());
  };

  EXPECT_CALL(*first, observerAttachMock(rawSocket));
  socket->addLifecycleObserver(first.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(first.get()));
  verifyAndClearBoth();

  EXPECT_CALL(*second, observerAttachMock(rawSocket));
  socket->addLifecycleObserver(second.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(first.get(), second.get()));
  verifyAndClearBoth();

  // From here on, each lifecycle event must reach the observers in attach
  // order, and events must arrive in the listed order.
  InSequence seq;
  EXPECT_CALL(*first, connectAttemptMock(rawSocket));
  EXPECT_CALL(*second, connectAttemptMock(rawSocket));
  EXPECT_CALL(*first, fdAttachMock(rawSocket));
  EXPECT_CALL(*second, fdAttachMock(rawSocket));
  EXPECT_CALL(*first, connectSuccessMock(rawSocket));
  EXPECT_CALL(*second, connectSuccessMock(rawSocket));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  verifyAndClearBoth();

  EXPECT_CALL(*first, closeMock(rawSocket));
  EXPECT_CALL(*second, closeMock(rawSocket));
  EXPECT_CALL(*first, destroyMock(rawSocket));
  EXPECT_CALL(*second, destroyMock(rawSocket));
  socket.reset();
  verifyAndClearBoth();
}
7683 
TEST(AsyncSocket, LifecycleObserverAttachRemove) {
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  auto* const rawSocket = socket.get();

  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();

  // attach: the observer is notified and listed
  EXPECT_CALL(*observer, observerAttachMock(rawSocket));
  socket->addLifecycleObserver(observer.get());
  Mock::VerifyAndClearExpectations(observer.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(observer.get()));

  // remove: the observer is notified and the list becomes empty
  EXPECT_CALL(*observer, observerDetachMock(rawSocket));
  EXPECT_TRUE(socket->removeLifecycleObserver(observer.get()));
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
  Mock::VerifyAndClearExpectations(observer.get());
}
7699 
TEST(AsyncSocket, LifecycleObserverAttachRemoveMultiple) {
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));

  // Attaches an observer and checks it was notified.
  const auto attach = [&](auto& observer) {
    EXPECT_CALL(*observer, observerAttachMock(socket.get()));
    socket->addLifecycleObserver(observer.get());
    Mock::VerifyAndClearExpectations(observer.get());
  };
  // Removes an observer and checks it was notified and reported as removed.
  const auto remove = [&](auto& observer) {
    EXPECT_CALL(*observer, observerDetachMock(socket.get()));
    EXPECT_TRUE(socket->removeLifecycleObserver(observer.get()));
    Mock::VerifyAndClearExpectations(observer.get());
  };

  auto first = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  attach(first);
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(first.get()));

  auto second =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  attach(second);
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(first.get(), second.get()));

  // remove in attach order
  remove(first);
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(second.get()));

  remove(second);
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
}
7728 
TEST(AsyncSocket, LifecycleObserverAttachRemoveMultipleReverse) {
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));

  // Attaches an observer and checks it was notified.
  const auto attach = [&](auto& observer) {
    EXPECT_CALL(*observer, observerAttachMock(socket.get()));
    socket->addLifecycleObserver(observer.get());
    Mock::VerifyAndClearExpectations(observer.get());
  };
  // Removes an observer and checks it was notified and reported as removed.
  const auto remove = [&](auto& observer) {
    EXPECT_CALL(*observer, observerDetachMock(socket.get()));
    EXPECT_TRUE(socket->removeLifecycleObserver(observer.get()));
    Mock::VerifyAndClearExpectations(observer.get());
  };

  auto first = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  attach(first);
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(first.get()));

  auto second =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  attach(second);
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(first.get(), second.get()));

  // remove in reverse attach order
  remove(second);
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(first.get()));

  remove(first);
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
}
7757 
TEST(AsyncSocket, LifecycleObserverRemoveMissing) {
  // Removing an observer that was never attached returns false. Because the
  // observer is a StrictMock, any notification to it would fail the test.
  auto neverAttached =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  const bool removed = socket->removeLifecycleObserver(neverAttached.get());
  EXPECT_FALSE(removed);
}
7764 
TEST(AsyncSocket, LifecycleObserverMultipleAttachThenRemove) {
  // Attach two observers, then remove them in the same order they were
  // attached. After each step, check the observer list, and for removals
  // also check the return value.
  auto cb1 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  auto cb2 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb1, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb2, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb2.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(cb1.get(), cb2.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb2, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb2.get()));
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb1, observerDetachMock(socket.get()));
  // Removing an attached observer must report success, as for cb2 above.
  EXPECT_TRUE(socket->removeLifecycleObserver(cb1.get()));
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());
}
7798 
TEST(AsyncSocket, LifecycleObserverDetach) {
  // The observer is attached to socket1. After socket1 gives up its fd via
  // detachNetworkSocket(), a second socket built on that fd must deliver no
  // callbacks to the observer. The observer only sees socket1's destruction.
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*observer, observerAttachMock(socket1.get()));
  socket1->addLifecycleObserver(observer.get());
  EXPECT_THAT(
      socket1->getLifecycleObservers(), UnorderedElementsAre(observer.get()));
  Mock::VerifyAndClearExpectations(observer.get());

  // Connecting must report an attempt, an fd attach, and success.
  EXPECT_CALL(*observer, connectAttemptMock(socket1.get()));
  EXPECT_CALL(*observer, fdAttachMock(socket1.get()));
  EXPECT_CALL(*observer, connectSuccessMock(socket1.get()));
  socket1->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(observer.get());

  EXPECT_CALL(*observer, fdDetachMock(socket1.get()));
  const auto detachedFd = socket1->detachNetworkSocket();
  Mock::VerifyAndClearExpectations(observer.get());

  // socket2 adopts the fd and is destroyed at the end of this scope; the
  // observer must see nothing from it.
  {
    auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(&evb, detachedFd));
  }

  // Finally, socket1 is destroyed when the test returns.
  EXPECT_CALL(*observer, destroyMock(socket1.get()));
}
7828 
TEST(AsyncSocket, LifecycleObserverMoveResubscribe) {
  // Move socket1 into a new AsyncSocket. Inside the move callback, the
  // observer detaches from the old socket and attaches to the new one, so
  // the later close/destroy events are reported for the new socket only.
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket1.get()));
  socket1->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket1->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket1.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket1.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket1.get()));
  socket1->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  AsyncSocket* socket2PtrCapturedMoved = nullptr;
  {
    // fdDetach on socket1 must come before the move notification.
    InSequence seq;
    EXPECT_CALL(*cb, fdDetachMock(socket1.get()));
    EXPECT_CALL(*cb, moveMock(socket1.get(), Not(socket1.get())))
        .WillOnce(Invoke(
            [&socket2PtrCapturedMoved, &cb](auto oldSocket, auto newSocket) {
              socket2PtrCapturedMoved = newSocket;
              // Re-home the observer: detach from old, attach to new.
              EXPECT_CALL(*cb, observerDetachMock(oldSocket));
              EXPECT_CALL(*cb, observerAttachMock(newSocket));
              EXPECT_TRUE(oldSocket->removeLifecycleObserver(cb.get()));
              EXPECT_THAT(oldSocket->getLifecycleObservers(), IsEmpty());
              newSocket->addLifecycleObserver(cb.get());
              EXPECT_THAT(
                  newSocket->getLifecycleObservers(),
                  UnorderedElementsAre(cb.get()));
            }));
  }
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket1)));
  Mock::VerifyAndClearExpectations(cb.get());
  EXPECT_EQ(socket2.get(), socket2PtrCapturedMoved);

  {
    InSequence seq;
    EXPECT_CALL(*cb, closeMock(socket2.get()));
    EXPECT_CALL(*cb, destroyMock(socket2.get()));
  }
  socket2.reset();
}
7876 
TEST(AsyncSocket, LifecycleObserverMoveDoNotResubscribe) {
  // Move socket1 into a new AsyncSocket without migrating the observer. It
  // stays on socket1 and sees socket1's destruction, but no close.
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*observer, observerAttachMock(socket1.get()));
  socket1->addLifecycleObserver(observer.get());
  EXPECT_THAT(
      socket1->getLifecycleObservers(), UnorderedElementsAre(observer.get()));
  Mock::VerifyAndClearExpectations(observer.get());

  EXPECT_CALL(*observer, connectAttemptMock(socket1.get()));
  EXPECT_CALL(*observer, fdAttachMock(socket1.get()));
  EXPECT_CALL(*observer, connectSuccessMock(socket1.get()));
  socket1->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(observer.get());

  // close will not be called on socket1 because the fd is detached; the
  // expected order is fd detach, move, then socket1's destruction.
  AsyncSocket* socket2PtrCapturedMoved = nullptr;
  {
    InSequence seq;
    EXPECT_CALL(*observer, fdDetachMock(socket1.get()));
    EXPECT_CALL(*observer, moveMock(socket1.get(), Not(socket1.get())))
        .WillOnce(Invoke(
            [&socket2PtrCapturedMoved](auto /* oldSocket */, auto newSocket) {
              socket2PtrCapturedMoved = newSocket;
            }));
    EXPECT_CALL(*observer, destroyMock(socket1.get()));
  }
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket1)));
  Mock::VerifyAndClearExpectations(observer.get());
  EXPECT_EQ(socket2.get(), socket2PtrCapturedMoved);
}
7909 
TEST(AsyncSocket, LifecycleObserverDetachCallbackImmediately) {
  // Attach an observer and remove it straight away. The socket then connects,
  // and the StrictMock fails the test if the removed observer hears about it.
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  AsyncSocket* const sock = socket.get();

  EXPECT_CALL(*observer, observerAttachMock(sock));
  socket->addLifecycleObserver(observer.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(), UnorderedElementsAre(observer.get()));
  Mock::VerifyAndClearExpectations(observer.get());

  EXPECT_CALL(*observer, observerDetachMock(sock));
  EXPECT_TRUE(socket->removeLifecycleObserver(observer.get()));
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
  Mock::VerifyAndClearExpectations(observer.get());

  // Drive a full connect; no further observer callbacks are expected.
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
}
7930 
TEST(AsyncSocket, LifecycleObserverDetachCallbackAfterConnect) {
  // Remove the observer after a successful connect. It must receive only
  // observerDetach; the socket is destroyed at scope exit and the StrictMock
  // rejects any later callback. The observer list is checked after attach and
  // after removal, as in the neighbouring observer tests.
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket.get()));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb.get()));
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
  Mock::VerifyAndClearExpectations(cb.get());
}
7952 
TEST(AsyncSocket, LifecycleObserverDetachCallbackAfterClose) {
  // Remove the observer only after the socket has connected and been closed.
  // Each phase's callbacks are verified before the next phase starts.
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  EXPECT_CALL(*observer, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(observer.get());
  Mock::VerifyAndClearExpectations(observer.get());

  // Phase 1: connect.
  EXPECT_CALL(*observer, connectAttemptMock(socket.get()));
  EXPECT_CALL(*observer, fdAttachMock(socket.get()));
  EXPECT_CALL(*observer, connectSuccessMock(socket.get()));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(observer.get());

  // Phase 2: closeNow() reports the close before returning.
  EXPECT_CALL(*observer, closeMock(socket.get()));
  socket->closeNow();
  Mock::VerifyAndClearExpectations(observer.get());

  // Phase 3: an observer can still be removed from a closed socket.
  EXPECT_CALL(*observer, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(observer.get()));
  Mock::VerifyAndClearExpectations(observer.get());
}
7978 
TEST(AsyncSocket, LifecycleObserverDetachCallbackcloseDuringDestroy) {
  // Destroying a connected socket reports close. The observer removes itself
  // from inside that close callback, which must then yield observerDetach.
  auto observer =
      std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  EXPECT_CALL(*observer, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(observer.get());
  Mock::VerifyAndClearExpectations(observer.get());

  EXPECT_CALL(*observer, connectAttemptMock(socket.get()));
  EXPECT_CALL(*observer, fdAttachMock(socket.get()));
  EXPECT_CALL(*observer, connectSuccessMock(socket.get()));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(observer.get());

  {
    InSequence seq;
    EXPECT_CALL(*observer, closeMock(socket.get()))
        .WillOnce(Invoke([&observer](auto callbackSocket) {
          EXPECT_TRUE(callbackSocket->removeLifecycleObserver(observer.get()));
        }));
    EXPECT_CALL(*observer, observerDetachMock(socket.get()));
  }
  socket.reset();
  Mock::VerifyAndClearExpectations(observer.get());
}
8005 
TEST(AsyncSocket, LifecycleObserverBaseClassMoveNoCrash) {
  // The observer here mocks the base AsyncTransport::LifecycleObserver, which
  // has no move or fdDetach events. Moving socket1 must still work; per the
  // original note, this checks that the internal static_cast behaves.
  auto baseObserver =
      std::make_unique<StrictMock<MockAsyncTransportLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*baseObserver, observerAttachMock(socket1.get()));
  socket1->addLifecycleObserver(baseObserver.get());
  EXPECT_THAT(
      socket1->getLifecycleObservers(),
      UnorderedElementsAre(baseObserver.get()));
  Mock::VerifyAndClearExpectations(baseObserver.get());

  // Unlike the AsyncSocket-observer tests, no fdAttach is expected here.
  EXPECT_CALL(*baseObserver, connectAttemptMock(socket1.get()));
  EXPECT_CALL(*baseObserver, connectSuccessMock(socket1.get()));
  socket1->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(baseObserver.get());

  // Only socket1's destruction is observed; socket2 never gets the observer.
  EXPECT_CALL(*baseObserver, destroyMock(socket1.get()));
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket1)));
  Mock::VerifyAndClearExpectations(baseObserver.get());
}
8030 
TEST(AsyncSocket, PreReceivedData) {
  // Data handed back with setPreReceivedData(), here in two separate chunks,
  // must be delivered ahead of the bytes still unread on the socket.
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto acceptedSocket = server.acceptAsync(&evb);

  // peekCallback (constructed with 2) sees just "he". It puts those bytes
  // back as pre-received data and switches the socket over to readCallback.
  ReadCallback peekCallback(2);
  ReadCallback readCallback;
  peekCallback.dataAvailableCallback =
      [&peekCallback, &acceptedSocket, &readCallback]() {
        peekCallback.verifyData("he", 2);
        acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("h"));
        acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("e"));
        acceptedSocket->setReadCB(nullptr);
        acceptedSocket->setReadCB(&readCallback);
      };
  readCallback.dataAvailableCallback = [&readCallback, &acceptedSocket]() {
    // Keep reading until the full payload has been seen.
    if (readCallback.dataRead() != 5) {
      return;
    }
    readCallback.verifyData("hello", 5);
    acceptedSocket->setReadCB(nullptr);
  };

  acceptedSocket->setReadCB(&peekCallback);

  evb.loop();
}
8063 
TEST(AsyncSocket, PreReceivedDataOnly) {
  // The peek callback consumes the whole payload and hands all of it back,
  // so the second reader is served entirely from pre-received data.
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto acceptedSocket = server.acceptAsync(&evb);

  ReadCallback peekCallback;
  ReadCallback readCallback;
  peekCallback.dataAvailableCallback =
      [&peekCallback, &acceptedSocket, &readCallback]() {
        peekCallback.verifyData("hello", 5);
        acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
        // Pending pre-received data alone must make the socket readable.
        EXPECT_TRUE(acceptedSocket->readable());
        acceptedSocket->setReadCB(&readCallback);
      };
  readCallback.dataAvailableCallback = [&readCallback, &acceptedSocket]() {
    readCallback.verifyData("hello", 5);
    acceptedSocket->setReadCB(nullptr);
  };

  acceptedSocket->setReadCB(&peekCallback);

  evb.loop();
}
8093 
TEST(AsyncSocket, PreReceivedDataPartial) {
  // Pre-received data can be drained across several read callbacks.
  // smallReadCallback (constructed with 3) takes "hel"; the rest ("lo") goes
  // to normalReadCallback.
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto acceptedSocket = server.acceptAsync(&evb);

  ReadCallback peekCallback;
  ReadCallback smallReadCallback(3);
  ReadCallback normalReadCallback;
  peekCallback.dataAvailableCallback =
      [&peekCallback, &acceptedSocket, &smallReadCallback]() {
        peekCallback.verifyData("hello", 5);
        acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
        acceptedSocket->setReadCB(&smallReadCallback);
      };
  smallReadCallback.dataAvailableCallback =
      [&smallReadCallback, &acceptedSocket, &normalReadCallback]() {
        smallReadCallback.verifyData("hel", 3);
        acceptedSocket->setReadCB(&normalReadCallback);
      };
  normalReadCallback.dataAvailableCallback =
      [&normalReadCallback, &acceptedSocket]() {
        normalReadCallback.verifyData("lo", 2);
        acceptedSocket->setReadCB(nullptr);
      };

  acceptedSocket->setReadCB(&peekCallback);

  evb.loop();
}
8127 
TEST(AsyncSocket, PreReceivedDataTakeover) {
  // Pre-received data set on a socket must survive moving that socket into a
  // new AsyncSocket. The new socket must also keep the peer address the
  // original was constructed with, even after the peer resets the connection.
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto fd = server.acceptFD();
  SocketAddress peerAddress;
  peerAddress.setFromPeerAddress(fd);
  auto acceptedSocket =
      AsyncSocket::UniquePtr(new AsyncSocket(&evb, fd, 0, &peerAddress));
  AsyncSocket::UniquePtr takeoverSocket;

  ReadCallback peekCallback(3);
  ReadCallback readCallback;
  peekCallback.dataAvailableCallback = [&]() {
    peekCallback.verifyData("hel", 3);
    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
    acceptedSocket->setReadCB(nullptr);
    takeoverSocket =
        AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket)));
    takeoverSocket->setReadCB(&readCallback);
  };
  readCallback.dataAvailableCallback = [&]() {
    readCallback.verifyData("hello", 5);
    takeoverSocket->setReadCB(nullptr);
  };

  acceptedSocket->setReadCB(&peekCallback);

  evb.loop();
  // The takeover happens inside the peek callback. If that callback never
  // ran, fail here cleanly rather than dereferencing a null pointer below.
  ASSERT_NE(takeoverSocket.get(), nullptr);

  // Verify we can still get the peer address after the peer socket is reset.
  socket->closeWithReset();
  evb.loopOnce();
  SocketAddress socketPeerAddress;
  takeoverSocket->getPeerAddress(&socketPeerAddress);
  EXPECT_EQ(socketPeerAddress, peerAddress);
}
8170 
8171 #ifdef MSG_NOSIGNAL
TEST(AsyncSocketTest, SendMessageFlags) {
  // Checks that per-write flags from a SendMsgParamsCallback reach the send
  // path. Data written with MSG_MORE is expected to stay corked, so the peer
  // sees only the first write.
  TestServer server;
  TestSendMsgParamsCallback sendMsgCB(
      MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE, 0, nullptr);

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  // Set SendMsgParamsCallback
  socket->setSendMsgParamCB(&sendMsgCB);
  ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB);

  // Write the first portion of data. This data is expected to be
  // sent out immediately.
  std::vector<uint8_t> buf(128, 'a');
  WriteCallback wcb;
  sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL);
  socket->write(&wcb, buf.data(), buf.size());
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  ASSERT_TRUE(sendMsgCB.queriedFlags_);
  ASSERT_FALSE(sendMsgCB.queriedData_);

  // Using different flags for the second write operation.
  // MSG_MORE flag is expected to delay sending this
  // data to the wire.
  sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE);
  socket->write(&wcb, buf.data(), buf.size());
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  ASSERT_TRUE(sendMsgCB.queriedFlags_);
  ASSERT_FALSE(sendMsgCB.queriedData_);

  // Make sure the accepted socket saw only the data from
  // the first write request.
  std::vector<uint8_t> readbuf(2 * buf.size());
  uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size());
  // Check the length first, so a short or long read is reported as such
  // rather than as a content mismatch against the zero-filled buffer.
  ASSERT_EQ(bytesRead, buf.size());
  ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));

  // Close both ends; the client socket must record that it closed itself.
  acceptedSocket->close();
  socket->close();

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}
8225 
TEST(AsyncSocketTest, SendMessageAncillaryData) {
  // Verifies that ancillary (control) data supplied by a
  // SendMsgParamsCallback is sent along with a write. An SCM_RIGHTS message
  // carrying a temp-file descriptor goes over an AF_UNIX socketpair, and the
  // receiver reads the file through the descriptor it gets back.
  NetworkSocket fds[2];
  EXPECT_EQ(netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0);

  // "Client" socket. It is not closed explicitly here; presumably the
  // AsyncSocket below closes it on destruction.
  auto cfd = fds[0];
  ASSERT_NE(cfd, NetworkSocket());

  // "Server" socket
  auto sfd = fds[1];
  ASSERT_NE(sfd, NetworkSocket());
  SCOPE_EXIT { netops::close(sfd); };

  // Instantiate AsyncSocket object for the connected socket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, cfd);

  // Open a temporary file and write a magic string to it
  // We'll transfer the file handle to test the message parameters
  // callback logic.
  TemporaryFile file(
      StringPiece(), fs::path(), TemporaryFile::Scope::UNLINK_IMMEDIATELY);
  int tmpfd = file.fd();
  ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file";
  std::string magicString("Magic string");
  ASSERT_EQ(
      write(tmpfd, magicString.c_str(), magicString.length()),
      magicString.length());

  // Send message
  // The union gives the control buffer correct alignment for cmsghdr.
  union {
    // Space large enough to hold an 'int'
    char control[CMSG_SPACE(sizeof(int))];
    struct cmsghdr cmh;
  } s_u;
  // A single SCM_RIGHTS header whose payload is one descriptor (tmpfd).
  s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int));
  s_u.cmh.cmsg_level = SOL_SOCKET;
  s_u.cmh.cmsg_type = SCM_RIGHTS;
  memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int));

  // Set up the callback providing message parameters
  // (send flags plus the control buffer built above).
  TestSendMsgParamsCallback sendMsgCB(
      MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control);
  socket->setSendMsgParamCB(&sendMsgCB);

  // We must transmit at least 1 byte of real data in order
  // to send ancillary data
  int s_data = 12345;
  WriteCallback wcb;
  socket->write(&wcb, &s_data, sizeof(s_data));
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  // Receive the message
  union {
    // Space large enough to hold an 'int'
    char control[CMSG_SPACE(sizeof(int))];
    struct cmsghdr cmh;
  } r_u;
  struct msghdr msgh;
  struct iovec iov;
  int r_data = 0;

  // One iovec for the int payload, plus r_u as the control buffer.
  msgh.msg_control = r_u.control;
  msgh.msg_controllen = sizeof(r_u.control);
  msgh.msg_name = nullptr;
  msgh.msg_namelen = 0;
  msgh.msg_iov = &iov;
  msgh.msg_iovlen = 1;
  iov.iov_base = &r_data;
  iov.iov_len = sizeof(r_data);

  // Receive data
  ASSERT_NE(netops::recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno;

  // Validate the received message
  // The first (and only) control header is read directly from the start of
  // the control buffer.
  ASSERT_EQ(r_u.cmh.cmsg_len, CMSG_LEN(sizeof(int)));
  ASSERT_EQ(r_u.cmh.cmsg_level, SOL_SOCKET);
  ASSERT_EQ(r_u.cmh.cmsg_type, SCM_RIGHTS);
  ASSERT_EQ(r_data, s_data);
  // Extract the descriptor installed in this process by SCM_RIGHTS.
  int fd = 0;
  memcpy(&fd, CMSG_DATA(&r_u.cmh), sizeof(int));
  ASSERT_NE(fd, 0);
  SCOPE_EXIT { close(fd); };

  std::vector<uint8_t> transferredMagicString(magicString.length() + 1, 0);

  // Reposition to the beginning of the file
  // The received descriptor is expected to share tmpfd's file offset, which
  // the earlier write() left at the end of the file.
  ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));

  // Read the magic string back, and compare it with the original
  // Ask for one byte more than was written, so that any extra file content
  // would make the length check fail.
  ASSERT_EQ(
      magicString.length(),
      read(fd, transferredMagicString.data(), transferredMagicString.size()));
  ASSERT_TRUE(std::equal(
      magicString.begin(), magicString.end(), transferredMagicString.begin()));
}
8322 
TEST(AsyncSocketTest,UnixDomainSocketErrMessageCB)8323 TEST(AsyncSocketTest, UnixDomainSocketErrMessageCB) {
8324   // In the latest stable kernel 4.14.3 as of 2017-12-04, Unix Domain
8325   // Socket (UDS) does not support MSG_ERRQUEUE. So
8326   // recvmsg(MSG_ERRQUEUE) will read application data from UDS which
8327   // breaks application message flow.  To avoid this problem,
8328   // AsyncSocket currently disables setErrMessageCB for UDS.
8329   //
8330   // This tests two things for UDS
8331   // 1. setErrMessageCB fails
8332   // 2. recvmsg(MSG_ERRQUEUE) reads application data
8333   //
8334   // Feel free to remove this test if UDS supports MSG_ERRQUEUE in the future.
8335 
8336   NetworkSocket fd[2];
8337   EXPECT_EQ(netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fd), 0);
8338   ASSERT_NE(fd[0], NetworkSocket());
8339   ASSERT_NE(fd[1], NetworkSocket());
8340   SCOPE_EXIT { netops::close(fd[1]); };
8341 
8342   EXPECT_EQ(netops::set_socket_non_blocking(fd[0]), 0);
8343   EXPECT_EQ(netops::set_socket_non_blocking(fd[1]), 0);
8344 
8345   EventBase evb;
8346   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, fd[0]);
8347 
8348   // setErrMessageCB should fail for unix domain socket
8349   TestErrMessageCallback errMsgCB;
8350   ASSERT_NE(&errMsgCB, nullptr);
8351   socket->setErrMessageCB(&errMsgCB);
8352   ASSERT_EQ(socket->getErrMessageCallback(), nullptr);
8353 
8354 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
8355   // The following verifies that MSG_ERRQUEUE does not work for UDS,
8356   // and recvmsg reads application data
8357   union {
8358     // Space large enough to hold an 'int'
8359     char control[CMSG_SPACE(sizeof(int))];
8360     struct cmsghdr cmh;
8361   } r_u;
8362   struct msghdr msgh;
8363   struct iovec iov;
8364   int recv_data = 0;
8365 
8366   msgh.msg_control = r_u.control;
8367   msgh.msg_controllen = sizeof(r_u.control);
8368   msgh.msg_name = nullptr;
8369   msgh.msg_namelen = 0;
8370   msgh.msg_iov = &iov;
8371   msgh.msg_iovlen = 1;
8372   iov.iov_base = &recv_data;
8373   iov.iov_len = sizeof(recv_data);
8374 
8375   // there is no data, recvmsg should fail
8376   EXPECT_EQ(netops::recvmsg(fd[1], &msgh, MSG_ERRQUEUE), -1);
8377   EXPECT_TRUE(errno == EAGAIN || errno == EWOULDBLOCK);
8378 
8379   // provide some application data, error queue should be empty if it exists
8380   // However, UDS reads application data as error message
8381   int test_data = 123456;
8382   WriteCallback wcb;
8383   socket->write(&wcb, &test_data, sizeof(test_data));
8384   recv_data = 0;
8385   ASSERT_NE(netops::recvmsg(fd[1], &msgh, MSG_ERRQUEUE), -1);
8386   ASSERT_EQ(recv_data, test_data);
8387 #endif // FOLLY_HAVE_MSG_ERRQUEUE
8388 }
8389 
TEST(AsyncSocketTest,V6TosReflectTest)8390 TEST(AsyncSocketTest, V6TosReflectTest) {
8391   EventBase eventBase;
8392 
8393   // Create a server socket
8394   std::shared_ptr<AsyncServerSocket> serverSocket(
8395       AsyncServerSocket::newSocket(&eventBase));
8396   folly::IPAddress ip("::1");
8397   std::vector<folly::IPAddress> serverIp;
8398   serverIp.push_back(ip);
8399   serverSocket->bind(serverIp, 0);
8400   serverSocket->listen(16);
8401   folly::SocketAddress serverAddress;
8402   serverSocket->getAddress(&serverAddress);
8403 
8404   // Enable TOS reflect
8405   serverSocket->setTosReflect(true);
8406 
8407   // Add a callback to accept one connection then stop the loop
8408   TestAcceptCallback acceptCallback;
8409   acceptCallback.setConnectionAcceptedFn(
8410       [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
8411         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
8412       });
8413   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
8414     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
8415   });
8416   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
8417   serverSocket->startAccepting();
8418 
8419   // Create a client socket, setsockopt() the TOS before connecting
8420   auto clientThread = [](std::shared_ptr<AsyncSocket>& clientSock,
8421                          ConnCallback* ccb,
8422                          EventBase* evb,
8423                          folly::SocketAddress sAddr) {
8424     clientSock = AsyncSocket::newSocket(evb);
8425     SocketOptionKey v6Opts = {IPPROTO_IPV6, IPV6_TCLASS};
8426     SocketOptionMap optionMap;
8427     optionMap.insert({v6Opts, 0x2c});
8428     SocketAddress bindAddr("0.0.0.0", 0);
8429     clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
8430   };
8431 
8432   std::shared_ptr<AsyncSocket> socket(nullptr);
8433   ConnCallback cb;
8434   clientThread(socket, &cb, &eventBase, serverAddress);
8435 
8436   eventBase.loop();
8437 
8438   // Verify if the connection is accepted and if the accepted socket has
8439   // setsockopt on the TOS for the same value that was on the client socket
8440   auto fd = acceptCallback.getEvents()->at(1).fd;
8441   ASSERT_NE(fd, NetworkSocket());
8442   int value;
8443   socklen_t valueLength = sizeof(value);
8444   int rc =
8445       netops::getsockopt(fd, IPPROTO_IPV6, IPV6_TCLASS, &value, &valueLength);
8446   ASSERT_EQ(rc, 0);
8447   ASSERT_EQ(value, 0x2c);
8448 
8449   // Additional Test for ConnectCallback without bindAddr
8450   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
8451   serverSocket->startAccepting();
8452 
8453   auto newClientSock = AsyncSocket::newSocket(&eventBase);
8454   TestConnectCallback callback;
8455   // connect call will not set this SO_REUSEADDR if we do not
8456   // pass the bindAddress in its call; so we can safely verify this.
8457   newClientSock->connect(&callback, serverAddress, 30);
8458 
8459   // Collect events
8460   eventBase.loop();
8461 
8462   auto acceptedFd = acceptCallback.getEvents()->at(1).fd;
8463   ASSERT_NE(acceptedFd, NetworkSocket());
8464   int reuseAddrVal;
8465   socklen_t reuseAddrValLen = sizeof(reuseAddrVal);
8466   // Get the socket created underneath connect call of AsyncSocket
8467   auto usedSockFd = newClientSock->getNetworkSocket();
8468   int getOptRet = netops::getsockopt(
8469       usedSockFd, SOL_SOCKET, SO_REUSEADDR, &reuseAddrVal, &reuseAddrValLen);
8470   ASSERT_EQ(getOptRet, 0);
8471   ASSERT_EQ(reuseAddrVal, 1 /* configured through preConnect*/);
8472 }
8473 
TEST(AsyncSocketTest, V4TosReflectTest) {
  EventBase eventBase;

  // Server: listen on the IPv4 loopback with TOS reflection enabled, so an
  // accepted socket should carry the TOS value the client connected with.
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  std::vector<folly::IPAddress> serverIp{folly::IPAddress("127.0.0.1")};
  serverSocket->bind(serverIp, 0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);
  serverSocket->setTosReflect(true);

  // Accept one connection (or observe one error), then detach the callback
  // so the event loop can drain and return.
  TestAcceptCallback acceptCallback;
  auto detachAcceptCallback = [&] {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  };
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        detachAcceptCallback();
      });
  acceptCallback.setAcceptErrorFn(
      [&](const std::exception& /* ex */) { detachAcceptCallback(); });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Client: request IP_TOS = 0x2c through connect()'s socket-option map.
  auto connectClient = [](std::shared_ptr<AsyncSocket>& clientSock,
                          ConnCallback* ccb,
                          EventBase* evb,
                          folly::SocketAddress sAddr) {
    clientSock = AsyncSocket::newSocket(evb);
    SocketOptionKey tosOpt = {IPPROTO_IP, IP_TOS};
    SocketOptionMap optionMap;
    optionMap.insert({tosOpt, 0x2c});
    SocketAddress bindAddr("0.0.0.0", 0);
    clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
  };

  std::shared_ptr<AsyncSocket> socket(nullptr);
  ConnCallback cb;
  connectClient(socket, &cb, &eventBase, serverAddress);

  eventBase.loop();

  // Event index 1 is expected to be the accept. The accepted fd must report
  // the same TOS value the client socket was configured with.
  const auto acceptedFd = acceptCallback.getEvents()->at(1).fd;
  ASSERT_NE(acceptedFd, NetworkSocket());
  int tos;
  socklen_t tosLen = sizeof(tos);
  const int rc = netops::getsockopt(acceptedFd, IPPROTO_IP, IP_TOS, &tos, &tosLen);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(tos, 0x2c);
}
8532 
TEST(AsyncSocketTest,V6AcceptedTosTest)8533 TEST(AsyncSocketTest, V6AcceptedTosTest) {
8534   EventBase eventBase;
8535 
8536   // This test verifies if the ListenerTos set on a socket is
8537   // propagated properly to accepted socket connections
8538 
8539   // Create a server socket
8540   std::shared_ptr<AsyncServerSocket> serverSocket(
8541       AsyncServerSocket::newSocket(&eventBase));
8542   folly::IPAddress ip("::1");
8543   std::vector<folly::IPAddress> serverIp;
8544   serverIp.push_back(ip);
8545   serverSocket->bind(serverIp, 0);
8546   serverSocket->listen(16);
8547   folly::SocketAddress serverAddress;
8548   serverSocket->getAddress(&serverAddress);
8549 
8550   // Set listener TOS to 0x74 i.e. dscp 29
8551   serverSocket->setListenerTos(0x74);
8552 
8553   // Add a callback to accept one connection then stop the loop
8554   TestAcceptCallback acceptCallback;
8555   acceptCallback.setConnectionAcceptedFn(
8556       [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
8557         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
8558       });
8559   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
8560     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
8561   });
8562   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
8563   serverSocket->startAccepting();
8564 
8565   // Create a client socket, setsockopt() the TOS before connecting
8566   auto clientThread = [](std::shared_ptr<AsyncSocket>& clientSock,
8567                          ConnCallback* ccb,
8568                          EventBase* evb,
8569                          folly::SocketAddress sAddr) {
8570     clientSock = AsyncSocket::newSocket(evb);
8571     SocketOptionKey v6Opts = {IPPROTO_IPV6, IPV6_TCLASS};
8572     SocketOptionMap optionMap;
8573     optionMap.insert({v6Opts, 0x2c});
8574     SocketAddress bindAddr("0.0.0.0", 0);
8575     clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
8576   };
8577 
8578   std::shared_ptr<AsyncSocket> socket(nullptr);
8579   ConnCallback cb;
8580   clientThread(socket, &cb, &eventBase, serverAddress);
8581 
8582   eventBase.loop();
8583 
8584   // Verify if the connection is accepted and if the accepted socket has
8585   // setsockopt on the TOS for the same value that the listener was set to
8586   auto fd = acceptCallback.getEvents()->at(1).fd;
8587   ASSERT_NE(fd, NetworkSocket());
8588   int value;
8589   socklen_t valueLength = sizeof(value);
8590   int rc =
8591       netops::getsockopt(fd, IPPROTO_IPV6, IPV6_TCLASS, &value, &valueLength);
8592   ASSERT_EQ(rc, 0);
8593   ASSERT_EQ(value, 0x74);
8594 }
8595 
TEST(AsyncSocketTest,V4AcceptedTosTest)8596 TEST(AsyncSocketTest, V4AcceptedTosTest) {
8597   EventBase eventBase;
8598 
8599   // This test verifies if the ListenerTos set on a socket is
8600   // propagated properly to accepted socket connections
8601 
8602   // Create a server socket
8603   std::shared_ptr<AsyncServerSocket> serverSocket(
8604       AsyncServerSocket::newSocket(&eventBase));
8605   folly::IPAddress ip("127.0.0.1");
8606   std::vector<folly::IPAddress> serverIp;
8607   serverIp.push_back(ip);
8608   serverSocket->bind(serverIp, 0);
8609   serverSocket->listen(16);
8610   folly::SocketAddress serverAddress;
8611   serverSocket->getAddress(&serverAddress);
8612 
8613   // Set listener TOS to 0x74 i.e. dscp 29
8614   serverSocket->setListenerTos(0x74);
8615 
8616   // Add a callback to accept one connection then stop the loop
8617   TestAcceptCallback acceptCallback;
8618   acceptCallback.setConnectionAcceptedFn(
8619       [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
8620         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
8621       });
8622   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
8623     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
8624   });
8625   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
8626   serverSocket->startAccepting();
8627 
8628   // Create a client socket, setsockopt() the TOS before connecting
8629   auto clientThread = [](std::shared_ptr<AsyncSocket>& clientSock,
8630                          ConnCallback* ccb,
8631                          EventBase* evb,
8632                          folly::SocketAddress sAddr) {
8633     clientSock = AsyncSocket::newSocket(evb);
8634     SocketOptionKey v4Opts = {IPPROTO_IP, IP_TOS};
8635     SocketOptionMap optionMap;
8636     optionMap.insert({v4Opts, 0x2c});
8637     SocketAddress bindAddr("0.0.0.0", 0);
8638     clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
8639   };
8640 
8641   std::shared_ptr<AsyncSocket> socket(nullptr);
8642   ConnCallback cb;
8643   clientThread(socket, &cb, &eventBase, serverAddress);
8644 
8645   eventBase.loop();
8646 
8647   // Verify if the connection is accepted and if the accepted socket has
8648   // setsockopt on the TOS for the same value that the listener was set to
8649   auto fd = acceptCallback.getEvents()->at(1).fd;
8650   ASSERT_NE(fd, NetworkSocket());
8651   int value;
8652   socklen_t valueLength = sizeof(value);
8653   int rc = netops::getsockopt(fd, IPPROTO_IP, IP_TOS, &value, &valueLength);
8654   ASSERT_EQ(rc, 0);
8655   ASSERT_EQ(value, 0x74);
8656 }
8657 #endif
8658 
8659 #if defined(__linux__)
TEST(AsyncSocketTest, getBufInUse) {
  // Writes 100,000 bytes from `client` to a peer that never reads, then checks
  // that getSendBufInUse() on the writer plus getRecvBufInUse() on the reader
  // accounts for every byte, with both sides holding some of the data.
  EventBase eventBase;
  std::shared_ptr<AsyncServerSocket> server(
      AsyncServerSocket::newSocket(&eventBase));
  server->bind(0);
  server->listen(5);

  std::shared_ptr<AsyncSocket> client = AsyncSocket::newSocket(&eventBase);
  client->connect(nullptr, server->getAddress());

  // accept() directly on the listening fd, running the event loop before each
  // attempt; give up after five tries.
  const NetworkSocket listenFd = server->getNetworkSocket();
  NetworkSocket accepted;
  for (uint64_t attempt = 0; attempt < 5; ++attempt) {
    std::this_thread::yield();
    eventBase.loop();
    accepted = netops::accept(listenFd, nullptr, nullptr);
    if (accepted != NetworkSocket()) {
      break;
    }
  }

  // Every attempt failed to produce a connection: nothing more to test.
  ASSERT_TRUE(accepted != NetworkSocket());

  auto clientAccepted = AsyncSocket::newSocket(nullptr, accepted);

  // Smallest possible receive buffer on the reader, largest possible send
  // buffer on the writer.
  clientAccepted->setRecvBufSize(0);
  client->setSendBufSize(static_cast<unsigned>(-1));

  // Payload: "0123456789" repeated 10,000 times.
  std::string testData;
  for (int rep = 0; rep < 10000; ++rep) {
    testData.append("0123456789");
  }

  client->write(nullptr, testData.data(), testData.size());

  std::this_thread::yield();
  eventBase.loop();

  const size_t recvBufSize = clientAccepted->getRecvBufInUse();
  const size_t sendBufSize = client->getSendBufInUse();

  // Nothing reads from clientAccepted, so the whole payload must be split
  // between the two buffers.
  EXPECT_EQ((recvBufSize + sendBufSize), testData.size());
  EXPECT_GT(recvBufSize, 0);
  EXPECT_GT(sendBufSize, 0);
}
8708 #endif
8709 
// Verifies AsyncServerSocket::setQueueTimeout(): a connection that waits in
// the accept-callback queue for longer than the queue timeout is expected to
// expire and never be dequeued by the accept callback.
//
// Two clients connect. The accept callback (on acceptThread) sleeps for the
// queue timeout while handling the first connection, so the second one
// should sit in the queue past its deadline.
TEST(AsyncSocketTest, QueueTimeout) {
  // Create a new AsyncServerSocket
  EventBase eventBase;
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  constexpr auto kConnectionTimeout = milliseconds(10);
  serverSocket->setQueueTimeout(kConnectionTimeout);

  // This callback runs on acceptThread's EventBase (see addAcceptCallback
  // below), not on `eventBase`. `called` is an init-capture of the mutable
  // closure, so it persists across invocations of this same callback.
  TestAcceptCallback acceptCb;
  acceptCb.setConnectionAcceptedFn(
      [&, called = false](auto&&...) mutable {
        ASSERT_FALSE(called)
            << "Only the first connection should have been dequeued";
        called = true;
        // Allow plenty of time for the AsyncServerSocket's event loop to run.
        // This should leave no doubt that the acceptor thread has enough time
        // to dequeue. If the dequeue succeeds, then our expiry code is broken.
        constexpr auto kEventLoopTime = kConnectionTimeout * 5;
        // Schedule the callback's removal on the main `eventBase` after
        // kEventLoopTime. kEventLoopTime is already a milliseconds duration;
        // .count() passes it as a plain millisecond count.
        eventBase.runInEventBaseThread([&]() {
          eventBase.tryRunAfterDelay(
              [&]() { serverSocket->removeAcceptCallback(&acceptCb, nullptr); },
              milliseconds(kEventLoopTime).count());
        });
        // After the first connection is enqueued, sleep long enough so that
        // the second connection expires before it has a chance to dequeue.
        std::this_thread::sleep_for(kConnectionTimeout);
      });
  ScopedEventBaseThread acceptThread("ioworker_test");

  // Counts enqueue/dequeue events so the test can tell an expired
  // connection apart from a delivered one.
  TestConnectionEventCallback connectionEventCb;
  serverSocket->setConnectionEventCallback(&connectionEventCb);
  serverSocket->addAcceptCallback(&acceptCb, acceptThread.getEventBase());
  serverSocket->startAccepting();

  std::shared_ptr<AsyncSocket> clientSocket1(
      AsyncSocket::newSocket(&eventBase, serverAddress));
  std::shared_ptr<AsyncSocket> clientSocket2(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  // Loop until we are stopped
  eventBase.loop();

  // Both connections reached the accept-callback queue...
  EXPECT_EQ(connectionEventCb.getConnectionEnqueuedForAcceptCallback(), 2);
  // Since the second connection expired, it should NOT be dequeued
  EXPECT_EQ(connectionEventCb.getConnectionDequeuedByAcceptCallback(), 1);
}
8761 
8762 class TestRXTimestampsCallback
8763     : public folly::AsyncSocket::ReadAncillaryDataCallback {
8764  public:
TestRXTimestampsCallback()8765   TestRXTimestampsCallback() {}
ancillaryData(struct msghdr & msgh)8766   void ancillaryData(struct msghdr& msgh) noexcept override {
8767     struct cmsghdr* cmsg;
8768     for (cmsg = CMSG_FIRSTHDR(&msgh); cmsg != nullptr;
8769          cmsg = CMSG_NXTHDR(&msgh, cmsg)) {
8770       if (cmsg->cmsg_level != SOL_SOCKET ||
8771           cmsg->cmsg_type != SO_TIMESTAMPING) {
8772         continue;
8773       }
8774       callCount_++;
8775       timespec* ts = (struct timespec*)CMSG_DATA(cmsg);
8776       actualRxTimestampSec_ = ts[0].tv_sec;
8777     }
8778   }
getAncillaryDataCtrlBuffer()8779   folly::MutableByteRange getAncillaryDataCtrlBuffer() override {
8780     return folly::MutableByteRange(ancillaryDataCtrlBuffer_);
8781   }
8782 
8783   uint32_t callCount_{0};
8784   long actualRxTimestampSec_{0};
8785 
8786  private:
8787   std::array<uint8_t, 1024> ancillaryDataCtrlBuffer_;
8788 };
8789 
8790 /**
8791  * Test read ancillary data callback
8792  */
TEST(AsyncSocketTest,readAncillaryData)8793 TEST(AsyncSocketTest, readAncillaryData) {
8794   TestServer server;
8795 
8796   // connect()
8797   EventBase evb;
8798   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
8799 
8800   ConnCallback ccb;
8801   socket->connect(&ccb, server.getAddress(), 1);
8802   LOG(INFO) << "Client socket fd=" << socket->getNetworkSocket();
8803 
8804   // Enable rx timestamp notifications
8805   ASSERT_NE(socket->getNetworkSocket(), NetworkSocket());
8806   int flags = folly::netops::SOF_TIMESTAMPING_SOFTWARE |
8807       folly::netops::SOF_TIMESTAMPING_RX_SOFTWARE |
8808       folly::netops::SOF_TIMESTAMPING_RX_HARDWARE;
8809   SocketOptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
8810   EXPECT_EQ(tstampingOpt.apply(socket->getNetworkSocket(), flags), 0);
8811 
8812   // Accept the connection.
8813   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
8814   LOG(INFO) << "Server socket fd=" << acceptedSocket->getNetworkSocket();
8815 
8816   // Wait for connection
8817   evb.loop();
8818   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
8819 
8820   TestRXTimestampsCallback rxcb;
8821 
8822   // Set read callback
8823   ReadCallback rcb(100);
8824   socket->setReadCB(&rcb);
8825 
8826   // Get the timestamp when the message was write
8827   struct timespec currentTime;
8828   clock_gettime(CLOCK_REALTIME, &currentTime);
8829   long writeTimestampSec = currentTime.tv_sec;
8830 
8831   // write bytes from server (acceptedSocket) to client (socket).
8832   std::vector<uint8_t> wbuf(128, 'a');
8833   acceptedSocket->write(wbuf.data(), wbuf.size());
8834 
8835   // Wait for reading to complete.
8836   evb.loopOnce();
8837   ASSERT_NE(rcb.buffers.size(), 0);
8838 
8839   // Verify that if the callback is not set, it will not be called
8840   ASSERT_EQ(rxcb.callCount_, 0);
8841 
8842   // Set up rx timestamp callbacks
8843   socket->setReadAncillaryDataCB(&rxcb);
8844   acceptedSocket->write(wbuf.data(), wbuf.size());
8845 
8846   // Wait for reading to complete.
8847   evb.loopOnce();
8848   ASSERT_NE(rcb.buffers.size(), 0);
8849 
8850   // Verify that after setting callback, the callback was called
8851   ASSERT_NE(rxcb.callCount_, 0);
8852   // Compare the received timestamp is within an expected range
8853   clock_gettime(CLOCK_REALTIME, &currentTime);
8854   ASSERT_TRUE(rxcb.actualRxTimestampSec_ <= currentTime.tv_sec);
8855   ASSERT_TRUE(rxcb.actualRxTimestampSec_ >= writeTimestampSec);
8856 
8857   // Close both sockets
8858   acceptedSocket->close();
8859   socket->close();
8860 
8861   ASSERT_TRUE(socket->isClosedBySelf());
8862   ASSERT_FALSE(socket->isClosedByPeer());
8863 }
8864