/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/io/async/AsyncSocket.h>

#include <sys/types.h>

#include <cerrno>
#include <climits>
#include <sstream>
#include <thread>

#include <boost/preprocessor/control/if.hpp>

#include <folly/Exception.h>
#include <folly/ExceptionWrapper.h>
#include <folly/Format.h>
#include <folly/Portability.h>
#include <folly/SocketAddress.h>
#include <folly/String.h>
#include <folly/io/Cursor.h>
#include <folly/io/IOBuf.h>
#include <folly/io/IOBufQueue.h>
#include <folly/io/SocketOptionMap.h>
#include <folly/portability/Fcntl.h>
#include <folly/portability/Sockets.h>
#include <folly/portability/SysUio.h>
#include <folly/portability/Unistd.h>

#if defined(__linux__)
#include <linux/if_packet.h>
#include <linux/sockios.h>
#include <sys/ioctl.h>
#endif

#if FOLLY_HAVE_VLA
#define FOLLY_HAVE_VLA_01 1
#else
#define FOLLY_HAVE_VLA_01 0
#endif
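// FOLLY_HAVE_VLA_01 is a plain 0/1 literal so that it can feed BOOST_PP_IF in
// writeChain() below: with VLA support the on-stack iovec array is sized to
// the exact chain length; otherwise a fixed kSmallIoVecSize array is used.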

using std::string;
using std::unique_ptr;

namespace fsp = folly::portability::sockets;

namespace folly {

static constexpr bool msgErrQueueSupported =
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
    true;
#else
    false;
#endif // FOLLY_HAVE_MSG_ERRQUEUE

static AsyncSocketException const& getSocketClosedLocallyEx() {
  static auto& ex = *new AsyncSocketException(
      AsyncSocketException::END_OF_FILE, "socket closed locally");
  return ex;
}

static AsyncSocketException const& getSocketShutdownForWritesEx() {
  static auto& ex = *new AsyncSocketException(
      AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
  return ex;
}

namespace {
#if FOLLY_HAVE_SO_TIMESTAMPING
const sock_extended_err* FOLLY_NULLABLE
cmsgToSockExtendedErr(const cmsghdr& cmsg) {
  if ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
      (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR) ||
      (cmsg.cmsg_level == SOL_PACKET &&
       cmsg.cmsg_type == PACKET_TX_TIMESTAMP)) {
    return reinterpret_cast<const sock_extended_err*>(CMSG_DATA(&cmsg));
  }
  (void)cmsg;
  return nullptr;
}

const sock_extended_err* FOLLY_NULLABLE
cmsgToSockExtendedErrTimestamping(const cmsghdr& cmsg) {
  const auto serr = cmsgToSockExtendedErr(cmsg);
  if (serr && serr->ee_errno == ENOMSG &&
      serr->ee_origin == SO_EE_ORIGIN_TIMESTAMPING) {
    return serr;
  }
  (void)cmsg;
  return nullptr;
}

const scm_timestamping* FOLLY_NULLABLE
cmsgToScmTimestamping(const cmsghdr& cmsg) {
  if (cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_TIMESTAMPING) {
    return reinterpret_cast<const struct scm_timestamping*>(CMSG_DATA(&cmsg));
  }
  (void)cmsg;
  return nullptr;
}

#endif // FOLLY_HAVE_SO_TIMESTAMPING
} // namespace

// TODO: It might help performance to provide a version of BytesWriteRequest
// that users could derive from, so we can avoid the extra allocation for each
// call to write()/writev().
//
// We would need the version for external users where they provide the iovec
// storage space, and only our internal version would allocate it at the end of
// the WriteRequest.

/* The default WriteRequest implementation, used for write(), writev() and
 * writeChain()
 *
 * A new BytesWriteRequest operation is allocated on the heap for all write
 * operations that cannot be completed immediately.
 */
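/* Illustrative memory layout (a sketch, not normative): newRequest() below
 * malloc()s a single block holding the object followed by its flexible
 * writeOps_ array, so a request with opCount == 3 looks roughly like:
 *
 *   +--------------------------+----------+----------+----------+
 *   | BytesWriteRequest fields | iovec[0] | iovec[1] | iovec[2] |
 *   +--------------------------+----------+----------+----------+
 *
 * This is also why destroy() invokes the destructor explicitly and then
 * free()s the block rather than using delete.
 */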
class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
 public:
  static BytesWriteRequest* newRequest(
      AsyncSocket* socket,
      WriteCallback* callback,
      const iovec* ops,
      uint32_t opCount,
      uint32_t partialWritten,
      uint32_t bytesWritten,
      unique_ptr<IOBuf>&& ioBuf,
      WriteFlags flags) {
    assert(opCount > 0);
    // Since we put a variable size iovec array at the end
    // of each BytesWriteRequest, we have to manually allocate the memory.
    void* buf =
        malloc(sizeof(BytesWriteRequest) + (opCount * sizeof(struct iovec)));
    if (buf == nullptr) {
      throw std::bad_alloc();
    }

    return new (buf) BytesWriteRequest(
        socket,
        callback,
        ops,
        opCount,
        partialWritten,
        bytesWritten,
        std::move(ioBuf),
        flags);
  }

  void destroy() override {
    socket_->releaseIOBuf(std::move(ioBuf_), releaseIOBufCallback_);
    this->~BytesWriteRequest();
    free(this);
  }

  WriteResult performWrite() override {
    WriteFlags writeFlags = flags_;
    if (getNext() != nullptr) {
      writeFlags |= WriteFlags::CORK;
    }

    socket_->adjustZeroCopyFlags(writeFlags);

    auto writeResult = socket_->performWrite(
        getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
    bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
    if (bytesWritten_) {
      if (socket_->isZeroCopyRequest(writeFlags)) {
        if (isComplete()) {
          socket_->addZeroCopyBuf(std::move(ioBuf_), releaseIOBufCallback_);
        } else {
          socket_->addZeroCopyBuf(ioBuf_.get());
        }
      } else {
        // this happens if at least one of the previous requests was sent
        // with zero copy, but the last one was not
        if (isComplete() && zeroCopyRequest_ &&
            socket_->containsZeroCopyBuf(ioBuf_.get())) {
          socket_->setZeroCopyBuf(std::move(ioBuf_), releaseIOBufCallback_);
        }
      }
    }
    return writeResult;
  }

  bool isComplete() override { return opsWritten_ == getOpCount(); }

  void consume() override {
    // Advance opIndex_ forward by opsWritten_
    opIndex_ += opsWritten_;
    assert(opIndex_ < opCount_);

    bool zeroCopyReq = socket_->isZeroCopyRequest(flags_);
    if (zeroCopyReq) {
      zeroCopyRequest_ = true;
    }

    if (!zeroCopyRequest_) {
      // If we've finished writing any IOBufs, release them
      // but only if we did not send any of them via zerocopy
      if (ioBuf_) {
        for (uint32_t i = opsWritten_; i != 0; --i) {
          assert(ioBuf_);
          auto next = ioBuf_->pop();
          socket_->releaseIOBuf(std::move(ioBuf_), releaseIOBufCallback_);
          ioBuf_ = std::move(next);
        }
      }
    }

    // Move partialBytes_ forward into the current iovec buffer
    struct iovec* currentOp = writeOps_ + opIndex_;
    assert((partialBytes_ < currentOp->iov_len) || (currentOp->iov_len == 0));
    currentOp->iov_base =
        reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes_;
    currentOp->iov_len -= partialBytes_;

    // Increment the totalBytesWritten_ count by bytesWritten_
    assert(bytesWritten_ >= 0);
    totalBytesWritten_ += uint32_t(bytesWritten_);
  }

 private:
  BytesWriteRequest(
      AsyncSocket* socket,
      WriteCallback* callback,
      const struct iovec* ops,
      uint32_t opCount,
      uint32_t partialBytes,
      uint32_t bytesWritten,
      unique_ptr<IOBuf>&& ioBuf,
      WriteFlags flags)
      : AsyncSocket::WriteRequest(socket, callback),
        opCount_(opCount),
        opIndex_(0),
        flags_(flags),
        ioBuf_(std::move(ioBuf)),
        opsWritten_(0),
        partialBytes_(partialBytes),
        bytesWritten_(bytesWritten) {
    memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
    zeroCopyRequest_ = socket_->isZeroCopyRequest(flags_);
  }

  // private destructor, to ensure callers use destroy()
  ~BytesWriteRequest() override = default;

  const struct iovec* getOps() const {
    assert(opCount_ > opIndex_);
    return writeOps_ + opIndex_;
  }

  uint32_t getOpCount() const {
    assert(opCount_ > opIndex_);
    return opCount_ - opIndex_;
  }

  uint32_t opCount_; ///< number of entries in writeOps_
  uint32_t opIndex_; ///< current index into writeOps_
  WriteFlags flags_; ///< write flags for this request
  bool zeroCopyRequest_{
      false}; ///< if we sent any part of the ioBuf_ with zerocopy
  unique_ptr<IOBuf> ioBuf_; ///< underlying IOBuf, or nullptr if N/A

  // for consume(), how much we wrote on the last write
  uint32_t opsWritten_; ///< complete ops written
  uint32_t partialBytes_; ///< partial bytes of incomplete op written
  ssize_t bytesWritten_; ///< bytes written altogether

  struct iovec writeOps_[]; ///< write operation(s) list
};

int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(
    folly::WriteFlags flags, bool zeroCopyEnabled) noexcept {
  int msg_flags = MSG_DONTWAIT;

#ifdef MSG_NOSIGNAL // Linux-only
  msg_flags |= MSG_NOSIGNAL;
#ifdef MSG_MORE
  if (isSet(flags, WriteFlags::CORK)) {
    // MSG_MORE tells the kernel we have more data to send, so wait for us to
    // give it the rest of the data rather than immediately sending a partial
    // frame, even when TCP_NODELAY is enabled.
    msg_flags |= MSG_MORE;
  }
#endif // MSG_MORE
#endif // MSG_NOSIGNAL
  if (isSet(flags, WriteFlags::EOR)) {
    // marks that this is the last byte of a record (response)
    msg_flags |= MSG_EOR;
  }

  if (zeroCopyEnabled && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY)) {
    msg_flags |= MSG_ZEROCOPY;
  }

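  // For illustration (assuming a Linux build where MSG_NOSIGNAL and MSG_MORE
  // are defined): a corked, zero-copy-enabled write ends up returning
  // MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE | MSG_ZEROCOPY.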
  return msg_flags;
}

void AsyncSocket::SendMsgParamsCallback::getAncillaryData(
    folly::WriteFlags flags,
    void* data,
    const bool byteEventsEnabled) noexcept {
  auto ancillaryDataSize = getAncillaryDataSize(flags, byteEventsEnabled);
  if (!ancillaryDataSize) {
    return;
  }
#if FOLLY_HAVE_SO_TIMESTAMPING
  CHECK_NOTNULL(data);
  // this function only handles ancillary data for timestamping
  //
  // if getAncillaryDataSize() is overridden and returning a size different
  // than what we expect, then this function needs to be overridden too, in
  // order to avoid conflict with how cmsg / msg are written
  CHECK_EQ(CMSG_LEN(sizeof(uint32_t)), ancillaryDataSize);

  uint32_t sofFlags = 0;
  if (byteEventsEnabled && isSet(flags, WriteFlags::TIMESTAMP_TX)) {
    sofFlags = sofFlags | folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE;
  }
  if (byteEventsEnabled && isSet(flags, WriteFlags::TIMESTAMP_ACK)) {
    sofFlags = sofFlags | folly::netops::SOF_TIMESTAMPING_TX_ACK;
  }
  if (byteEventsEnabled && isSet(flags, WriteFlags::TIMESTAMP_SCHED)) {
    sofFlags = sofFlags | folly::netops::SOF_TIMESTAMPING_TX_SCHED;
  }

  msghdr msg;
  msg.msg_control = data;
  msg.msg_controllen = ancillaryDataSize;
  struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
  CHECK_NOTNULL(cmsg);
  cmsg->cmsg_level = SOL_SOCKET;
  cmsg->cmsg_type = SO_TIMESTAMPING;
  cmsg->cmsg_len = CMSG_LEN(sizeof(uint32_t));
  memcpy(CMSG_DATA(cmsg), &sofFlags, sizeof(sofFlags));
#else
  (void)data;
#endif // FOLLY_HAVE_SO_TIMESTAMPING
  return;
}

uint32_t AsyncSocket::SendMsgParamsCallback::getAncillaryDataSize(
    folly::WriteFlags flags, const bool byteEventsEnabled) noexcept {
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
  if (WriteFlags::NONE != (flags & kWriteFlagsForTimestamping) &&
      byteEventsEnabled) {
    return CMSG_LEN(sizeof(uint32_t));
  }
#else
  (void)flags;
  (void)byteEventsEnabled;
#endif
  return 0;
}

folly::Optional<AsyncSocket::ByteEvent>
AsyncSocket::ByteEventHelper::processCmsg(
    const cmsghdr& cmsg, const size_t rawBytesWritten) {
#if FOLLY_HAVE_SO_TIMESTAMPING
  if (!byteEventsEnabled || maybeEx.has_value()) {
    return folly::none;
  }
  if (!maybeTsState_.has_value()) {
    maybeTsState_ = TimestampState();
  }
  auto& state = maybeTsState_.value();
  if (auto serrTs = cmsgToSockExtendedErrTimestamping(cmsg)) {
    if (state.serrReceived) {
      // already have this part of the message pending
      throw Exception("already have serr event");
    }
    state.serrReceived = true;
    state.typeRaw = serrTs->ee_info;
    state.byteOffsetKernel = serrTs->ee_data;
  } else if (auto scmTs = cmsgToScmTimestamping(cmsg)) {
    if (state.scmTsReceived) {
      throw Exception("already have scmTs event");
    }
    state.scmTsReceived = true;

    auto timespecToDuration =
        [](const timespec& ts) -> folly::Optional<std::chrono::nanoseconds> {
      std::chrono::nanoseconds duration = std::chrono::seconds(ts.tv_sec) +
          std::chrono::nanoseconds(ts.tv_nsec);
      if (duration == duration.zero()) {
        return folly::none;
      }
      return duration;
    };
    // ts[0] -> software timestamp
    // ts[1] -> hardware timestamp transformed to userspace time (deprecated)
    // ts[2] -> hardware timestamp
    state.maybeSoftwareTs = timespecToDuration(scmTs->ts[0]);
    state.maybeHardwareTs = timespecToDuration(scmTs->ts[2]);
  }

  // if we have both components needed for a complete timestamp, build it
  if (state.serrReceived && state.scmTsReceived) {
    // cleanup state so that we're ready for next timestamp
    TimestampState completeState = state;
    maybeTsState_ = folly::none;

    // map the type
    folly::Optional<ByteEvent::Type> tsType;
    switch (completeState.typeRaw) {
      case folly::netops::SCM_TSTAMP_SND: {
        tsType = ByteEvent::Type::TX;
        break;
      }
      case folly::netops::SCM_TSTAMP_ACK: {
        tsType = ByteEvent::Type::ACK;
        break;
      }
      case folly::netops::SCM_TSTAMP_SCHED: {
        tsType = ByteEvent::Type::SCHED;
        break;
      }
      default:
        break; // unknown, maybe something new
    }
    if (!tsType) {
      // it's a timestamp, but not one that we're set up to handle
      // we've cleared our state, loop back around
      return folly::none;
    }

    // Calculate the byte offset.
    //
    // See documentation for SOF_TIMESTAMPING_OPT_ID for details.
    //
    // In summary, two things we have to consider:
    //
    //   (1) The byte stream offset is relative:
    //       Socket timestamps include the byte stream offset for which the
    //       timestamp applies. There may have been bytes transferred before the
    //       fd was controlled by AsyncSocket. As a result, we don't know the
    //       socket byte stream offset when we enable timestamping.
    //
    //       To get around this, we set SOF_TIMESTAMPING_OPT_ID when we enable
    //       timestamping via setsockopt. This flag causes the kernel to reset
    //       the offset it uses for timestamps to 0. This allows us to determine
    //       an offset relative to the number of bytes that had been written to
    //       the socket since timestamps were enabled.
    //
    //       Note that offsets begin at zero; if only a single byte is written
    //       after timestamping is enabled, the offset included in the kernel
    //       cmsg will be 0.
    //
    //   (2) The byte stream offset is a uint32_t:
    //       Because the kernel uses a uint32_t to store and communicate the
    //       byte stream offset, the offset will wrap every ~4GB. When we get a
    //       timestamp, we need to figure out which byte it is for. We assume
    //       that there will never be more than ~4GB of bytes sent between us
    //       requesting timestamping for a byte and receiving the timestamp;
    //       this is a realistic assumption given CWND and TCP buffer sizes. We
    //       then calculate assuming that the counter has not wrapped since we
    //       sent the byte that we are getting the timestamp for. If the counter
    //       has wrapped, we detect it, and go back one position.
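    //
    // Worked example with hypothetical values (W = 2^32, the wrap period),
    // assuming rawBytesWrittenWhenByteEventsEnabled == 0:
    //
    //   - rawBytesWritten == W + 1000 and the kernel reports offset 900:
    //     the computation below yields W + 900, which is <= W + 1000, so the
    //     timestamp is for byte W + 900.
    //   - rawBytesWritten == W + 1000 and the kernel reports W - 50 (a byte
    //     sent just before the counter wrapped): the first computation gives
    //     2W - 50 > W + 1000, so we step back one wrap to byte W - 50.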
    const uint64_t bytesPerOffsetWrap =
        static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1;
    size_t byteOffset = rawBytesWritten -
        (rawBytesWritten % bytesPerOffsetWrap) +
        completeState.byteOffsetKernel + rawBytesWrittenWhenByteEventsEnabled;
    if (byteOffset > rawBytesWritten) {
      // kernel's uint32_t var wrapped around; go back one wrap
      CHECK_GE(byteOffset, bytesPerOffsetWrap);
      byteOffset = byteOffset - bytesPerOffsetWrap;
    }

    ByteEvent event = {};
    event.type = tsType.value();
    event.offset = byteOffset;
    // use completeState: the shared state in maybeTsState_ was reset above
    event.maybeSoftwareTs = completeState.maybeSoftwareTs;
    event.maybeHardwareTs = completeState.maybeHardwareTs;
    return event;
  }
#else
  (void)cmsg;
  (void)rawBytesWritten;
#endif // FOLLY_HAVE_SO_TIMESTAMPING
  return folly::none;
}

namespace {
AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;

// Based on flags, signal the transparent handler to disable certain functions
void disableTransparentFunctions(
    NetworkSocket fd, bool noTransparentTls, bool noTSocks) {
  (void)fd;
  (void)noTransparentTls;
  (void)noTSocks;
#if defined(__linux__)
  if (noTransparentTls) {
    // Ignore return value, errors are ok
    VLOG(5) << "Disabling TTLS for fd " << fd;
    netops::setsockopt(fd, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
  }
  if (noTSocks) {
    VLOG(5) << "Disabling TSOCKS for fd " << fd;
    // Ignore return value, errors are ok
    netops::setsockopt(fd, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
  }
#endif
}

constexpr size_t kSmallIoVecSize = 64;

} // namespace

AsyncSocket::AsyncSocket()
    : eventBase_(nullptr),
      writeTimeout_(this, nullptr),
      ioHandler_(this, nullptr),
      immediateReadHandler_(this) {
  VLOG(5) << "new AsyncSocket()";
  init();
}

AsyncSocket::AsyncSocket(EventBase* evb)
    : eventBase_(evb),
      writeTimeout_(this, evb),
      ioHandler_(this, evb),
      immediateReadHandler_(this) {
  VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
  init();
}

AsyncSocket::AsyncSocket(
    EventBase* evb,
    const folly::SocketAddress& address,
    uint32_t connectTimeout,
    bool useZeroCopy)
    : AsyncSocket(evb) {
  setZeroCopy(useZeroCopy);
  connect(nullptr, address, connectTimeout);
}

AsyncSocket::AsyncSocket(
    EventBase* evb,
    const std::string& ip,
    uint16_t port,
    uint32_t connectTimeout,
    bool useZeroCopy)
    : AsyncSocket(evb) {
  setZeroCopy(useZeroCopy);
  connect(nullptr, ip, port, connectTimeout);
}

AsyncSocket::AsyncSocket(
    EventBase* evb,
    NetworkSocket fd,
    uint32_t zeroCopyBufId,
    const SocketAddress* peerAddress)
    : zeroCopyBufId_(zeroCopyBufId),
      state_(StateEnum::ESTABLISHED),
      fd_(fd),
      addr_(peerAddress ? *peerAddress : folly::SocketAddress()),
      eventBase_(evb),
      writeTimeout_(this, evb),
      ioHandler_(this, evb, fd),
      immediateReadHandler_(this) {
  VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd=" << fd
          << ", zeroCopyBufId=" << zeroCopyBufId << ")";
  init();
  disableTransparentFunctions(fd_, noTransparentTls_, noTSocks_);
  setCloseOnExec();
}

AsyncSocket::AsyncSocket(AsyncSocket* oldAsyncSocket)
    : zeroCopyBufId_(oldAsyncSocket->getZeroCopyBufId()),
      state_(oldAsyncSocket->state_),
      fd_(oldAsyncSocket->detachNetworkSocket()),
      addr_(oldAsyncSocket->addr_),
      eventBase_(oldAsyncSocket->getEventBase()),
      writeTimeout_(this, eventBase_),
      ioHandler_(this, eventBase_, fd_),
      immediateReadHandler_(this),
      appBytesWritten_(oldAsyncSocket->appBytesWritten_),
      rawBytesWritten_(oldAsyncSocket->rawBytesWritten_),
      preReceivedData_(std::move(oldAsyncSocket->preReceivedData_)),
      byteEventHelper_(std::move(oldAsyncSocket->byteEventHelper_)) {
  VLOG(5) << "move AsyncSocket(" << oldAsyncSocket << "->" << this
          << ", evb=" << eventBase_ << ", fd=" << fd_
          << ", zeroCopyBufId=" << zeroCopyBufId_ << ")";
  init();
  disableTransparentFunctions(fd_, noTransparentTls_, noTSocks_);
  setCloseOnExec();

  // inform lifecycle observers to give them an opportunity to unsubscribe from
  // events for the old socket and subscribe to the new socket; we do not move
  // the subscription ourselves
  for (const auto& cb : oldAsyncSocket->lifecycleObservers_) {
    // only available for observers derived from AsyncSocket::LifecycleObserver
    if (auto dCb = dynamic_cast<AsyncSocket::LifecycleObserver*>(cb)) {
      dCb->move(oldAsyncSocket, this);
    }
  }
}

AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket)
    : AsyncSocket(oldAsyncSocket.get()) {}

// init() performs the common initialization shared by all constructors.
void AsyncSocket::init() {
  if (eventBase_) {
    eventBase_->dcheckIsInEventBaseThread();
  }
  eventFlags_ = EventHandler::NONE;
  sendTimeout_ = 0;
  maxReadsPerEvent_ = 16;
  connectCallback_ = nullptr;
  errMessageCallback_ = nullptr;
  readAncillaryDataCallback_ = nullptr;
  readCallback_ = nullptr;
  writeReqHead_ = nullptr;
  writeReqTail_ = nullptr;
  wShutdownSocketSet_.reset();
  appBytesReceived_ = 0;
  totalAppBytesScheduledForWrite_ = 0;
  sendMsgParamCallback_ = &defaultSendMsgParamsCallback;
}

AsyncSocket::~AsyncSocket() {
  VLOG(7) << "actual destruction of AsyncSocket(this=" << this
          << ", evb=" << eventBase_ << ", fd=" << fd_ << ", state=" << state_
          << ")";
  for (const auto& cb : lifecycleObservers_) {
    cb->destroy(this);
  }
  DCHECK_EQ(allocatedBytesBuffered_, 0);
}

void AsyncSocket::destroy() {
  VLOG(5) << "AsyncSocket::destroy(this=" << this << ", evb=" << eventBase_
          << ", fd=" << fd_ << ", state=" << state_;
  // When destroy is called, close the socket immediately
  closeNow();

  // Then call DelayedDestruction::destroy() to take care of
  // whether or not we need immediate or delayed destruction
  DelayedDestruction::destroy();
}

NetworkSocket AsyncSocket::detachNetworkSocket() {
  VLOG(6) << "AsyncSocket::detachFd(this=" << this << ", fd=" << fd_
          << ", evb=" << eventBase_ << ", state=" << state_
          << ", events=" << std::hex << eventFlags_ << ")";
  for (const auto& cb : lifecycleObservers_) {
    // only available for observers derived from AsyncSocket::LifecycleObserver
    if (auto dCb = dynamic_cast<AsyncSocket::LifecycleObserver*>(cb)) {
      dCb->fdDetach(this);
    }
  }
  // Extract the fd, and set fd_ to -1 first, so closeNow() won't
  // actually close the descriptor.
  if (const auto socketSet = wShutdownSocketSet_.lock()) {
    socketSet->remove(fd_);
  }
  auto fd = fd_;
  fd_ = NetworkSocket();
  // Call closeNow() to invoke all pending callbacks with an error.
  closeNow();
  // Update the EventHandler to stop using this fd.
  // This can only be done after closeNow() unregisters the handler.
  ioHandler_.changeHandlerFD(NetworkSocket());
  return fd;
}

const folly::SocketAddress& AsyncSocket::anyAddress() {
  static const folly::SocketAddress anyAddress =
      folly::SocketAddress("0.0.0.0", 0);
  return anyAddress;
}

void AsyncSocket::setShutdownSocketSet(
    const std::weak_ptr<ShutdownSocketSet>& wNewSS) {
  const auto newSS = wNewSS.lock();
  const auto shutdownSocketSet = wShutdownSocketSet_.lock();

  if (newSS == shutdownSocketSet) {
    return;
  }

  if (shutdownSocketSet && fd_ != NetworkSocket()) {
    shutdownSocketSet->remove(fd_);
  }

  if (newSS && fd_ != NetworkSocket()) {
    newSS->add(fd_);
  }

  wShutdownSocketSet_ = wNewSS;
}

void AsyncSocket::setCloseOnExec() {
  int rv = netops_->set_socket_close_on_exec(fd_);
  if (rv != 0) {
    auto errnoCopy = errno;
    throw AsyncSocketException(
        AsyncSocketException::INTERNAL_ERROR,
        withAddr("failed to set close-on-exec flag"),
        errnoCopy);
  }
}

void AsyncSocket::connect(
    ConnectCallback* callback,
    const folly::SocketAddress& address,
    int timeout,
    const SocketOptionMap& options,
    const folly::SocketAddress& bindAddr,
    const std::string& ifName) noexcept {
  DestructorGuard dg(this);
  eventBase_->dcheckIsInEventBaseThread();

  addr_ = address;

  // Make sure we're in the uninitialized state
  if (state_ != StateEnum::UNINIT) {
    return invalidState(callback);
  }

  connectTimeout_ = std::chrono::milliseconds(timeout);
  connectStartTime_ = std::chrono::steady_clock::now();
  // Make connect end time at least >= connectStartTime.
  connectEndTime_ = connectStartTime_;

  assert(fd_ == NetworkSocket());
  state_ = StateEnum::CONNECTING;
  connectCallback_ = callback;
  invokeConnectAttempt();

  sockaddr_storage addrStorage;
  auto saddr = reinterpret_cast<sockaddr*>(&addrStorage);

  try {
    // Create the socket
    // Technically the first parameter should actually be a protocol family
    // constant (PF_xxx) rather than an address family (AF_xxx), but the
    // distinction is mainly just historical.  In pretty much all
    // implementations the PF_foo and AF_foo constants are identical.
    fd_ = netops_->socket(address.getFamily(), SOCK_STREAM, 0);
    if (fd_ == NetworkSocket()) {
      auto errnoCopy = errno;
      throw AsyncSocketException(
          AsyncSocketException::INTERNAL_ERROR,
          withAddr("failed to create socket"),
          errnoCopy);
    }

    disableTransparentFunctions(fd_, noTransparentTls_, noTSocks_);
    handleNetworkSocketAttached();
    setCloseOnExec();

    // Put the socket in non-blocking mode
    int rv = netops_->set_socket_non_blocking(fd_);
    if (rv == -1) {
      auto errnoCopy = errno;
      throw AsyncSocketException(
          AsyncSocketException::INTERNAL_ERROR,
          withAddr("failed to put socket in non-blocking mode"),
          errnoCopy);
    }

#if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
    // iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
    rv = fcntl(fd_.toFd(), F_SETNOSIGPIPE, 1);
    if (rv == -1) {
      auto errnoCopy = errno;
      throw AsyncSocketException(
          AsyncSocketException::INTERNAL_ERROR,
          "failed to enable F_SETNOSIGPIPE on socket",
          errnoCopy);
    }
#endif

    // By default, turn on TCP_NODELAY
    // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
    // setNoDelay() will log an error message if it fails.
    // Also set the cached zeroCopyVal_ since it cannot be set earlier if the fd
    // is not created
    if (address.getFamily() != AF_UNIX) {
      (void)setNoDelay(true);
      setZeroCopy(zeroCopyVal_);
    }

    // Apply the additional PRE_BIND options if any.
    applyOptions(options, SocketOptionKey::ApplyPos::PRE_BIND);

    VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
            << ", fd=" << fd_ << ", host=" << address.describe().c_str();

    // bind the socket to the interface
#if defined(__linux__)
    if (!ifName.empty() &&
        netops_->setsockopt(
            fd_,
            SOL_SOCKET,
            SO_BINDTODEVICE,
            ifName.c_str(),
            ifName.length())) {
      auto errnoCopy = errno;
      doClose();
      throw AsyncSocketException(
          AsyncSocketException::NOT_OPEN,
          "failed to bind to device: " + ifName,
          errnoCopy);
    }
#else
    (void)ifName;
#endif

    // bind the socket
    if (bindAddr != anyAddress()) {
      int one = 1;
      if (netops_->setsockopt(
              fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
        auto errnoCopy = errno;
        doClose();
        throw AsyncSocketException(
            AsyncSocketException::NOT_OPEN,
            "failed to setsockopt prior to bind on " + bindAddr.describe(),
            errnoCopy);
      }

      bindAddr.getAddress(&addrStorage);

      if (netops_->bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
        auto errnoCopy = errno;
        doClose();
        throw AsyncSocketException(
            AsyncSocketException::NOT_OPEN,
            "failed to bind to async socket: " + bindAddr.describe(),
            errnoCopy);
      }
    }

    // Apply the additional POST_BIND options if any.
    applyOptions(options, SocketOptionKey::ApplyPos::POST_BIND);

    // Call preConnect hook if any.
    if (connectCallback_) {
      connectCallback_->preConnect(fd_);
    }

    // Perform the connect()
    address.getAddress(&addrStorage);

    if (tfoEnabled_) {
      state_ = StateEnum::FAST_OPEN;
      tfoAttempted_ = true;
    } else {
      if (socketConnect(saddr, addr_.getActualSize()) < 0) {
        return;
      }
    }

    // If we're still here the connect() succeeded immediately.
    // Fall through to call the callback outside of this try...catch block
  } catch (const AsyncSocketException& ex) {
    return failConnect(__func__, ex);
  } catch (const std::exception& ex) {
    // shouldn't happen, but handle it just in case
    VLOG(4) << "AsyncSocket::connect(this=" << this << ", fd=" << fd_
            << "): unexpected " << typeid(ex).name()
            << " exception: " << ex.what();
    AsyncSocketException tex(
        AsyncSocketException::INTERNAL_ERROR,
        withAddr(string("unexpected exception: ") + ex.what()));
    return failConnect(__func__, tex);
  }

  // The connection succeeded immediately
  // The read callback may not have been set yet, and no writes may be pending
  // yet, so we don't have to register for any events at the moment.
  VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
  assert(errMessageCallback_ == nullptr);
  assert(readAncillaryDataCallback_ == nullptr);
  assert(readCallback_ == nullptr);
  assert(writeReqHead_ == nullptr);
  if (state_ != StateEnum::FAST_OPEN) {
    state_ = StateEnum::ESTABLISHED;
  }
  invokeConnectSuccess();
}

int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
  int rv = netops_->connect(fd_, saddr, len);
  if (rv < 0) {
    auto errnoCopy = errno;
    if (errnoCopy == EINPROGRESS) {
      scheduleConnectTimeout();
      registerForConnectEvents();
    } else {
      throw AsyncSocketException(
          AsyncSocketException::NOT_OPEN,
          "connect failed (immediately)",
          errnoCopy);
    }
  }
  return rv;
}

void AsyncSocket::scheduleConnectTimeout() {
  // Connection in progress.
  auto timeout = connectTimeout_.count();
  if (timeout > 0) {
    // Start a timer in case the connection takes too long.
    if (!writeTimeout_.scheduleTimeout(uint32_t(timeout))) {
      throw AsyncSocketException(
          AsyncSocketException::INTERNAL_ERROR,
          withAddr("failed to schedule AsyncSocket connect timeout"));
    }
  }
}

void AsyncSocket::registerForConnectEvents() {
  // Register for write events, so we'll
  // be notified when the connection finishes/fails.
  // Note that we don't register for a persistent event here.
  assert(eventFlags_ == EventHandler::NONE);
  eventFlags_ = EventHandler::WRITE;
  if (!ioHandler_.registerHandler(eventFlags_)) {
    throw AsyncSocketException(
        AsyncSocketException::INTERNAL_ERROR,
        withAddr("failed to register AsyncSocket connect handler"));
  }
}

void AsyncSocket::connect(
    ConnectCallback* callback,
    const string& ip,
    uint16_t port,
    int timeout,
    const SocketOptionMap& options) noexcept {
  DestructorGuard dg(this);
  try {
    connectCallback_ = callback;
    connect(callback, folly::SocketAddress(ip, port), timeout, options);
  } catch (const std::exception& ex) {
    AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR, ex.what());
    return failConnect(__func__, tex);
  }
}

void AsyncSocket::cancelConnect() {
  connectCallback_ = nullptr;
  if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) {
    closeNow();
  }
}

void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
  sendTimeout_ = milliseconds;
  if (eventBase_) {
    eventBase_->dcheckIsInEventBaseThread();
  }

  // If we are currently pending on write requests, immediately update
  // writeTimeout_ with the new value.
  if ((eventFlags_ & EventHandler::WRITE) &&
      (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) {
    assert(state_ == StateEnum::ESTABLISHED);
    assert((shutdownFlags_ & SHUT_WRITE) == 0);
    if (sendTimeout_ > 0) {
      if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
        AsyncSocketException ex(
            AsyncSocketException::INTERNAL_ERROR,
            withAddr("failed to reschedule send timeout in setSendTimeout"));
        return failWrite(__func__, ex);
      }
    } else {
      writeTimeout_.cancelTimeout();
    }
  }
}

void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
  VLOG(6) << "AsyncSocket::setErrMessageCB() this=" << this << ", fd=" << fd_
          << ", callback=" << callback << ", state=" << state_;

  // In the latest stable kernel 4.14.3 as of 2017-12-04, unix domain
  // sockets do not support MSG_ERRQUEUE. So recvmsg(MSG_ERRQUEUE)
  // will read application data from a unix domain socket as an error
  // message, which breaks the message flow in the application.  Feel free
  // to remove the next code block if MSG_ERRQUEUE is added for unix
  // domain sockets in the future.
  if (callback != nullptr) {
    cacheLocalAddress();
    if (localAddr_.getFamily() == AF_UNIX) {
      LOG(ERROR) << "Failed to set ErrMessageCallback=" << callback
                 << " for Unix Domain Socket where MSG_ERRQUEUE is unsupported,"
                 << " fd=" << fd_;
      return;
    }
  }
  }

  // Short circuit if callback is the same as the existing errMessageCallback_.
  if (callback == errMessageCallback_) {
    return;
  }

  if (!msgErrQueueSupported) {
    // Per-socket error message queue is not supported on this platform.
    return invalidState(callback);
  }

  DestructorGuard dg(this);
  eventBase_->dcheckIsInEventBaseThread();

  if (callback == nullptr) {
    // We should be able to reset the callback regardless of the
    // socket state. It's important to have a reliable callback
    // cancellation mechanism.
    errMessageCallback_ = callback;
    return;
  }

  switch ((StateEnum)state_) {
    case StateEnum::CONNECTING:
    case StateEnum::FAST_OPEN:
    case StateEnum::ESTABLISHED: {
      errMessageCallback_ = callback;
      return;
    }
    case StateEnum::CLOSED:
    case StateEnum::ERROR:
      // We should never reach here.  SHUT_READ should always be set
      // if we are in STATE_CLOSED or STATE_ERROR.
      assert(false);
      return invalidState(callback);
    case StateEnum::UNINIT:
      // We do not allow setErrMessageCB() to be called before we start
      // connecting.
      return invalidState(callback);
  }

  // We don't put a default case in the switch statement, so that the compiler
  // will warn us to update the switch statement if a new state is added.
  return invalidState(callback);
}

AsyncSocket::ErrMessageCallback* AsyncSocket::getErrMessageCallback() const {
  return errMessageCallback_;
}

void AsyncSocket::setReadAncillaryDataCB(ReadAncillaryDataCallback* callback) {
  VLOG(6) << "AsyncSocket::setReadAncillaryDataCB() this=" << this
          << ", fd=" << fd_ << ", callback=" << callback
          << ", state=" << state_;

  readAncillaryDataCallback_ = callback;
}

AsyncSocket::ReadAncillaryDataCallback*
AsyncSocket::getReadAncillaryDataCallback() const {
  return readAncillaryDataCallback_;
}

void AsyncSocket::setSendMsgParamCB(SendMsgParamsCallback* callback) {
  sendMsgParamCallback_ = callback;
}

AsyncSocket::SendMsgParamsCallback* AsyncSocket::getSendMsgParamsCB() const {
  return sendMsgParamCallback_;
}

void AsyncSocket::setReadCB(ReadCallback* callback) {
  VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
          << ", callback=" << callback << ", state=" << state_;

  // Short circuit if callback is the same as the existing readCallback_.
  //
  // Note that this is needed for proper functioning during some cleanup cases.
  // During cleanup we allow setReadCallback(nullptr) to be called even if the
  // read callback is already unset and we have been detached from an event
  // base.  This check prevents us from asserting
  // eventBase_->isInEventBaseThread() when eventBase_ is nullptr.
  if (callback == readCallback_) {
    return;
  }

  /* We are removing a read callback */
  if (callback == nullptr && immediateReadHandler_.isLoopCallbackScheduled()) {
    immediateReadHandler_.cancelLoopCallback();
  }

  if (shutdownFlags_ & SHUT_READ) {
    // Reads have already been shut down on this socket.
    //
    // Allow setReadCallback(nullptr) to be called in this case, but don't
    // allow a new callback to be set.
    //
    // For example, setReadCallback(nullptr) can happen after an error if we
    // invoke some other error callback before invoking readError().  The other
    // error callback that is invoked first may go ahead and clear the read
    // callback before we get a chance to invoke readError().
    if (callback != nullptr) {
      return invalidState(callback);
    }
    assert((eventFlags_ & EventHandler::READ) == 0);
    readCallback_ = nullptr;
    return;
  }

  DestructorGuard dg(this);
  eventBase_->dcheckIsInEventBaseThread();

  switch ((StateEnum)state_) {
    case StateEnum::CONNECTING:
    case StateEnum::FAST_OPEN:
      // For convenience, we allow the read callback to be set while we are
      // still connecting.  We just store the callback for now.  Once the
      // connection completes we'll register for read events.
      readCallback_ = callback;
      return;
    case StateEnum::ESTABLISHED: {
      readCallback_ = callback;
      uint16_t oldFlags = eventFlags_;
      if (readCallback_) {
        eventFlags_ |= EventHandler::READ;
      } else {
        eventFlags_ &= ~EventHandler::READ;
      }

      // Update our registration if our flags have changed
      if (eventFlags_ != oldFlags) {
        // We intentionally ignore the return value here.
        // updateEventRegistration() will move us into the error state if it
        // fails, and we don't need to do anything else here afterwards.
        (void)updateEventRegistration();
      }

      if (readCallback_) {
        checkForImmediateRead();
      }
      return;
    }
    case StateEnum::CLOSED:
    case StateEnum::ERROR:
      // We should never reach here.  SHUT_READ should always be set
      // if we are in STATE_CLOSED or STATE_ERROR.
      assert(false);
      return invalidState(callback);
    case StateEnum::UNINIT:
      // We do not allow setReadCallback() to be called before we start
      // connecting.
      return invalidState(callback);
  }

  // We don't put a default case in the switch statement, so that the compiler
  // will warn us to update the switch statement if a new state is added.
  return invalidState(callback);
}

AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
  return readCallback_;
}

bool AsyncSocket::setZeroCopy(bool enable) {
  if (msgErrQueueSupported) {
    zeroCopyVal_ = enable;

    if (fd_ == NetworkSocket()) {
      return false;
    }

    // No-op, bail out early
    if (enable == zeroCopyEnabled_) {
      return true;
    }

    int val = enable ? 1 : 0;
    int ret =
        netops_->setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));

    // if enable == false, set zeroCopyEnabled_ = false regardless
    // of whether SO_ZEROCOPY was set successfully or not
    if (!enable) {
      zeroCopyEnabled_ = enable;
      return true;
    }

    /* if the setsockopt failed, try to see if the socket inherited the flag,
     * since we cannot set SO_ZEROCOPY on a socket returned by accept()
     */
    if (ret) {
      val = 0;
      socklen_t optlen = sizeof(val);
      ret = netops_->getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);

      if (!ret) {
        enable = val != 0;
      }
    }

    if (!ret) {
      zeroCopyEnabled_ = enable;

      return true;
    }
  }

  return false;
}

void AsyncSocket::setZeroCopyEnableFunc(AsyncWriter::ZeroCopyEnableFunc func) {
  zeroCopyEnableFunc_ = func;
}

void AsyncSocket::setZeroCopyReenableThreshold(size_t threshold) {
  zeroCopyReenableThreshold_ = threshold;
}

bool AsyncSocket::isZeroCopyRequest(WriteFlags flags) {
  return (zeroCopyEnabled_ && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY));
}

void AsyncSocket::adjustZeroCopyFlags(folly::WriteFlags& flags) {
  if (!zeroCopyEnabled_) {
    // if zeroCopyReenableCounter_ is > 0, decrement it; when it reaches 0,
    // re-enable zero copy by setting zeroCopyEnabled_ back to true
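    //
    // illustrative example (the counter is set elsewhere, outside this
    // section): if a copied zerocopy send disabled zero copy and left
    // zeroCopyReenableCounter_ == 3, the next two writes through here have
    // WRITE_MSG_ZEROCOPY stripped, and the third re-enables zero copy and
    // leaves the flag intact.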
    if (zeroCopyReenableCounter_) {
      if (0 == --zeroCopyReenableCounter_) {
        zeroCopyEnabled_ = true;
        return;
      }
    }
    flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
  }
}

void AsyncSocket::addZeroCopyBuf(
    std::unique_ptr<folly::IOBuf>&& buf, ReleaseIOBufCallback* cb) {
  uint32_t id = getNextZeroCopyBufId();
  folly::IOBuf* ptr = buf.get();

  idZeroCopyBufPtrMap_[id] = ptr;
  auto& p = idZeroCopyBufInfoMap_[ptr];
  p.count_++;
  CHECK(p.buf_.get() == nullptr);
  p.buf_ = std::move(buf);
  p.cb_ = cb;
}

void AsyncSocket::addZeroCopyBuf(folly::IOBuf* ptr) {
  uint32_t id = getNextZeroCopyBufId();
  idZeroCopyBufPtrMap_[id] = ptr;

  idZeroCopyBufInfoMap_[ptr].count_++;
}

void AsyncSocket::releaseZeroCopyBuf(uint32_t id) {
  auto iter = idZeroCopyBufPtrMap_.find(id);
  CHECK(iter != idZeroCopyBufPtrMap_.end());
  auto ptr = iter->second;
  auto iter1 = idZeroCopyBufInfoMap_.find(ptr);
  CHECK(iter1 != idZeroCopyBufInfoMap_.end());
  if (0 == --iter1->second.count_) {
    releaseIOBuf(std::move(iter1->second.buf_), iter1->second.cb_);
    idZeroCopyBufInfoMap_.erase(iter1);
  }

  idZeroCopyBufPtrMap_.erase(iter);
}

void AsyncSocket::setZeroCopyBuf(
    std::unique_ptr<folly::IOBuf>&& buf, ReleaseIOBufCallback* cb) {
  folly::IOBuf* ptr = buf.get();
  auto& p = idZeroCopyBufInfoMap_[ptr];
  CHECK(p.buf_.get() == nullptr);

  p.buf_ = std::move(buf);
  p.cb_ = cb;
}

bool AsyncSocket::containsZeroCopyBuf(folly::IOBuf* ptr) {
  return (idZeroCopyBufInfoMap_.find(ptr) != idZeroCopyBufInfoMap_.end());
}

bool AsyncSocket::isZeroCopyMsg(const cmsghdr& cmsg) const {
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
  if ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
      (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
    auto serr =
        reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
    return (
        (serr->ee_errno == 0) && (serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY));
  }
#endif
  (void)cmsg;
  return false;
}

void AsyncSocket::processZeroCopyMsg(const cmsghdr& cmsg) {
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
  auto serr =
      reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
  uint32_t hi = serr->ee_data;
  uint32_t lo = serr->ee_info;
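  // the kernel acknowledges a contiguous range of zerocopy sends in one
  // notification: buffer ids [lo, hi] inclusive. For example (hypothetical
  // values), lo == 3 and hi == 5 releases buffers 3, 4 and 5 below.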
  // disable zero copy if the buffer was actually copied
  if ((serr->ee_code & SO_EE_CODE_ZEROCOPY_COPIED) && zeroCopyEnabled_) {
    VLOG(2) << "AsyncSocket::processZeroCopyMsg(): setting "
            << "zeroCopyEnabled_ = false due to SO_EE_CODE_ZEROCOPY_COPIED "
            << "on " << fd_;
    zeroCopyEnabled_ = false;
  }

  for (uint32_t i = lo; i <= hi; i++) {
    releaseZeroCopyBuf(i);
  }
#else
  (void)cmsg;
#endif
}

void AsyncSocket::releaseIOBuf(
    std::unique_ptr<folly::IOBuf> buf, ReleaseIOBufCallback* callback) {
  if (!buf) {
    return;
  }
  const size_t allocated = buf->computeChainCapacity();
  DCHECK_GE(allocatedBytesBuffered_, allocated);
  allocatedBytesBuffered_ -= allocated;
  if (callback) {
    callback->releaseIOBuf(std::move(buf));
  }
}

void AsyncSocket::enableByteEvents() {
  if (!byteEventHelper_) {
    byteEventHelper_ = std::make_unique<ByteEventHelper>();
  }

  if (byteEventHelper_->byteEventsEnabled ||
      byteEventHelper_->maybeEx.has_value()) {
    return;
  }

  try {
#if FOLLY_HAVE_SO_TIMESTAMPING
    // make sure we have a connected IP socket that supports error queues
    // (Unix sockets do not support error queues)
    if (NetworkSocket() == fd_ || !good()) {
      throw AsyncSocketException(
          AsyncSocketException::INVALID_STATE,
          withAddr("failed to enable byte events: "
                   "socket is not open or not in a good state"));
    }
    folly::SocketAddress addr = {};
    try {
      // explicitly fetch local address (instead of using cache)
      // to ensure socket is currently healthy
      addr.setFromLocalAddress(fd_);
    } catch (const std::system_error&) {
      throw AsyncSocketException(
          AsyncSocketException::INVALID_STATE,
          withAddr("failed to enable byte events: "
                   "socket is not open or not in a good state"));
    }
    const auto family = addr.getFamily();
    if (family != AF_INET && family != AF_INET6) {
      throw AsyncSocketException(
          AsyncSocketException::NOT_SUPPORTED,
          withAddr("failed to enable byte events: socket type not supported"));
    }

    // check if timestamping is already enabled on the socket by another source
    {
      uint32_t flags = 0;
      socklen_t len = sizeof(flags);
      const auto ret =
          getSockOptVirtual(SOL_SOCKET, SO_TIMESTAMPING, &flags, &len);
      int getSockOptErrno = errno;
      if (0 != ret) {
        throw AsyncSocketException(
            AsyncSocketException::INTERNAL_ERROR,
            withAddr("failed to enable byte events: "
                     "timestamps may not be supported for this socket type "
                     "or the socket may be closed"),
1398             getSockOptErrno);
1399       }
1400       if (0 != flags) {
1401         throw AsyncSocketException(
1402             AsyncSocketException::INTERNAL_ERROR,
1403             withAddr("failed to enable byte events: "
1404                      "timestamps may have already been enabled"),
1405             getSockOptErrno);
1406       }
1407     }
1408 
1409     // enable control messages for software and hardware timestamps
1410     // WriteFlags will determine which messages are generated
1411     //
1412     // SOF_TIMESTAMPING_OPT_ID: see discussion in ByteEventHelper::processCmsg
1413     // SOF_TIMESTAMPING_OPT_TSONLY: only get timestamps, not original packet
1414     // SOF_TIMESTAMPING_SOFTWARE: get software timestamps if generated
1415     // SOF_TIMESTAMPING_RAW_HARDWARE: get hardware timestamps if generated
1416     // SOF_TIMESTAMPING_OPT_TX_SWHW: get both sw + hw timestamps if generated
1417     const uint32_t flags =
1418         (folly::netops::SOF_TIMESTAMPING_OPT_ID |
1419          folly::netops::SOF_TIMESTAMPING_OPT_TSONLY |
1420          folly::netops::SOF_TIMESTAMPING_SOFTWARE |
1421          folly::netops::SOF_TIMESTAMPING_RAW_HARDWARE |
1422          folly::netops::SOF_TIMESTAMPING_OPT_TX_SWHW);
1423     socklen_t len = sizeof(flags);
1424     const auto ret =
1425         setSockOptVirtual(SOL_SOCKET, SO_TIMESTAMPING, &flags, len);
1426     int setSockOptErrno = errno;
1427     if (ret == 0) {
1428       byteEventHelper_->byteEventsEnabled = true;
1429       byteEventHelper_->rawBytesWrittenWhenByteEventsEnabled =
1430           getRawBytesWritten();
1431       for (const auto& observer : lifecycleObservers_) {
1432         if (observer->getConfig().byteEvents) {
1433           observer->byteEventsEnabled(this);
1434         }
1435       }
1436       return;
1437     }
1438 
1439     // failed
1440     throw AsyncSocketException(
1441         AsyncSocketException::INTERNAL_ERROR,
1442         withAddr("failed to enable byte events: setsockopt failed"),
1443         setSockOptErrno);
1444 #endif // FOLLY_HAVE_SO_TIMESTAMPING
1445     // unsupported by platform
1446     throw AsyncSocketException(
1447         AsyncSocketException::NOT_SUPPORTED,
1448         withAddr("failed to enable byte events: platform not supported"));
1449 
1450   } catch (const AsyncSocketException& ex) {
1451     failByteEvents(ex);
1452   }
1453 }
1454 
write(WriteCallback * callback,const void * buf,size_t bytes,WriteFlags flags)1455 void AsyncSocket::write(
1456     WriteCallback* callback, const void* buf, size_t bytes, WriteFlags flags) {
1457   iovec op;
1458   op.iov_base = const_cast<void*>(buf);
1459   op.iov_len = bytes;
1460   writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), bytes, flags);
1461 }
1462 
writev(WriteCallback * callback,const iovec * vec,size_t count,WriteFlags flags)1463 void AsyncSocket::writev(
1464     WriteCallback* callback, const iovec* vec, size_t count, WriteFlags flags) {
1465   size_t totalBytes = 0;
1466   for (size_t i = 0; i < count; ++i) {
1467     totalBytes += vec[i].iov_len;
1468   }
1469   writeImpl(callback, vec, count, unique_ptr<IOBuf>(), totalBytes, flags);
1470 }
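
// Illustrative usage of the scatter/gather API above; hdr, hdrLen, body,
// bodyLen, and writeCb are hypothetical caller-owned objects:
//
//   iovec ops[2];
//   ops[0].iov_base = hdr;
//   ops[0].iov_len = hdrLen;
//   ops[1].iov_base = body;
//   ops[1].iov_len = bodyLen;
//   socket->writev(writeCb, ops, 2);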
1471 
1472 void AsyncSocket::writeChain(
1473     WriteCallback* callback, unique_ptr<IOBuf>&& buf, WriteFlags flags) {
1474   adjustZeroCopyFlags(flags);
1475 
1476   // adjustZeroCopyFlags can set zeroCopyEnabled_ to true
1477   if (zeroCopyEnabled_ && !isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY) &&
1478       zeroCopyEnableFunc_ && zeroCopyEnableFunc_(buf) && buf->isManaged()) {
1479     flags |= WriteFlags::WRITE_MSG_ZEROCOPY;
1480   }
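
  // Zero-copy writes (MSG_ZEROCOPY) require the buffer memory to remain
  // valid until the kernel reports completion on the error queue, which is
  // why only managed IOBufs are eligible above.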
1481 
1482   size_t count = buf->countChainElements();
1483   if (count <= kSmallIoVecSize) {
1484     // suppress "warning: variable length array 'vec' is used [-Wvla]"
1485     FOLLY_PUSH_WARNING
1486     FOLLY_GNU_DISABLE_WARNING("-Wvla")
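    // With VLA support, the array is sized to exactly `count` entries;
    // otherwise it falls back to a fixed kSmallIoVecSize-entry array
    // (count <= kSmallIoVecSize is guaranteed by the branch above).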
1487     iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA_01, count, kSmallIoVecSize)];
1488     FOLLY_POP_WARNING
1489 
1490     writeChainImpl(callback, vec, count, std::move(buf), flags);
1491   } else {
1492     std::unique_ptr<iovec[]> vec(new iovec[count]);
1493     writeChainImpl(callback, vec.get(), count, std::move(buf), flags);
1494   }
1495 }
1496 
1497 void AsyncSocket::writeChainImpl(
1498     WriteCallback* callback,
1499     iovec* vec,
1500     size_t count,
1501     unique_ptr<IOBuf>&& buf,
1502     WriteFlags flags) {
1503   auto res = buf->fillIov(vec, count);
1504   writeImpl(
1505       callback, vec, res.numIovecs, std::move(buf), res.totalLength, flags);
1506 }
1507 
1508 void AsyncSocket::writeImpl(
1509     WriteCallback* callback,
1510     const iovec* vec,
1511     size_t count,
1512     unique_ptr<IOBuf>&& buf,
1513     size_t totalBytes,
1514     WriteFlags flags) {
1515   VLOG(6) << "AsyncSocket::writeImpl() this=" << this << ", fd=" << fd_
1516           << ", callback=" << callback << ", count=" << count
1517           << ", state=" << state_;
1518   DestructorGuard dg(this);
1519   unique_ptr<IOBuf> ioBuf(std::move(buf));
1520   eventBase_->dcheckIsInEventBaseThread();
1521 
1522   auto* releaseIOBufCallback =
1523       callback ? callback->getReleaseIOBufCallback() : nullptr;
1524 
1525   SCOPE_EXIT { releaseIOBuf(std::move(ioBuf), releaseIOBufCallback); };
1526 
1527   totalAppBytesScheduledForWrite_ += totalBytes;
1528   if (ioBuf) {
1529     allocatedBytesBuffered_ += ioBuf->computeChainCapacity();
1530   }
1531 
1532   if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
1533     // No new writes may be performed after the write side of the socket has
1534     // been shutdown.
1535     //
1536     // We could just call callback->writeError() here to fail just this write.
1537     // However, fail hard and use invalidState() to fail all outstanding
1538     // callbacks and move the socket into the error state.  There's most likely
1539     // a bug in the caller's code, so we abort everything rather than trying to
1540     // proceed as best we can.
1541     return invalidState(callback);
1542   }
1543 
1544   uint32_t countWritten = 0;
1545   uint32_t partialWritten = 0;
1546   ssize_t bytesWritten = 0;
1547   bool mustRegister = false;
1548   if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) &&
1549       !connecting()) {
1550     if (writeReqHead_ == nullptr) {
1551       // If we are established and there are no other writes pending,
1552       // we can attempt to perform the write immediately.
1553       assert(writeReqTail_ == nullptr);
1554       assert((eventFlags_ & EventHandler::WRITE) == 0);
1555 
1556       auto writeResult = performWrite(
1557           vec, uint32_t(count), flags, &countWritten, &partialWritten);
1558       bytesWritten = writeResult.writeReturn;
1559       if (bytesWritten < 0) {
1560         auto errnoCopy = errno;
1561         if (writeResult.exception) {
1562           return failWrite(__func__, callback, 0, *writeResult.exception);
1563         }
1564         AsyncSocketException ex(
1565             AsyncSocketException::INTERNAL_ERROR,
1566             withAddr("writev failed"),
1567             errnoCopy);
1568         return failWrite(__func__, callback, 0, ex);
1569       } else if (countWritten == count) {
1570         // done, add the whole buffer
1571         if (countWritten && isZeroCopyRequest(flags)) {
1572           addZeroCopyBuf(std::move(ioBuf), releaseIOBufCallback);
1573         } else {
1574           releaseIOBuf(std::move(ioBuf), releaseIOBufCallback);
1575         }
1576 
1577         // We successfully wrote everything.
1578         // Invoke the callback and return.
1579         if (callback) {
1580           callback->writeSuccess();
1581         }
1582         return;
1583       } else { // continue writing the next writeReq
1584         // add just the ptr
1585         if (bytesWritten && isZeroCopyRequest(flags)) {
1586           addZeroCopyBuf(ioBuf.get());
1587         }
1588       }
1589       if (!connecting()) {
1590         // A write might put the socket back into the connecting state
1591         // if TFO is enabled and the TFO attempt fails.
1592         // In that case write timeouts would not be active; however,
1593         // connect timeouts would still apply at this stage.
1594         mustRegister = true;
1595       }
1596     }
1597   } else if (!connecting()) {
1598     // Invalid state for writing
1599     return invalidState(callback);
1600   }
1601 
1602   // Create a new WriteRequest to add to the queue
1603   WriteRequest* req;
1604   try {
1605     req = BytesWriteRequest::newRequest(
1606         this,
1607         callback,
1608         vec + countWritten,
1609         uint32_t(count - countWritten),
1610         partialWritten,
1611         uint32_t(bytesWritten),
1612         std::move(ioBuf),
1613         flags);
1614   } catch (const std::exception& ex) {
1615     // we mainly expect to catch std::bad_alloc here
1616     AsyncSocketException tex(
1617         AsyncSocketException::INTERNAL_ERROR,
1618         withAddr(string("failed to append new WriteRequest: ") + ex.what()));
1619     return failWrite(__func__, callback, size_t(bytesWritten), tex);
1620   }
1621   req->consume();
1622   if (writeReqTail_ == nullptr) {
1623     assert(writeReqHead_ == nullptr);
1624     writeReqHead_ = writeReqTail_ = req;
1625   } else {
1626     writeReqTail_->append(req);
1627     writeReqTail_ = req;
1628   }
1629 
1630   if (bufferCallback_) {
1631     bufferCallback_->onEgressBuffered();
1632   }
1633 
1634   // Register for write events if we are established and not currently
1635   // waiting on write events
1636   if (mustRegister) {
1637     assert(state_ == StateEnum::ESTABLISHED);
1638     assert((eventFlags_ & EventHandler::WRITE) == 0);
1639     if (!updateEventRegistration(EventHandler::WRITE, 0)) {
1640       assert(state_ == StateEnum::ERROR);
1641       return;
1642     }
1643     if (sendTimeout_ > 0) {
1644       // Schedule a timeout to fire if the write takes too long.
1645       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
1646         AsyncSocketException ex(
1647             AsyncSocketException::INTERNAL_ERROR,
1648             withAddr("failed to schedule send timeout"));
1649         return failWrite(__func__, ex);
1650       }
1651     }
1652   }
1653 }
1654 
1655 void AsyncSocket::writeRequest(WriteRequest* req) {
1656   if (writeReqTail_ == nullptr) {
1657     assert(writeReqHead_ == nullptr);
1658     writeReqHead_ = writeReqTail_ = req;
1659     req->start();
1660   } else {
1661     writeReqTail_->append(req);
1662     writeReqTail_ = req;
1663   }
1664 }
1665 
1666 void AsyncSocket::close() {
1667   VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
1668           << ", state=" << state_ << ", shutdownFlags=" << std::hex
1669           << (int)shutdownFlags_;
1670 
1671   // close() is only different from closeNow() when there are pending writes
1672   // that need to drain before we can close.  In all other cases, just call
1673   // closeNow().
1674   //
1675   // Note that writeReqHead_ can be non-nullptr even in STATE_CLOSED or
1676   // STATE_ERROR if close() is invoked while a previous closeNow() or failure
1677   // is still running.  (e.g., if there are multiple pending writes, and we
1678   // call writeError() on the first one, it may call close().  In this case we
1679   // will already be in STATE_CLOSED or STATE_ERROR, but the remaining pending
1680   // writes will still be in the queue.)
1681   //
1682   // We only need to drain pending writes if we are still in STATE_CONNECTING
1683   // or STATE_ESTABLISHED
1684   if ((writeReqHead_ == nullptr) ||
1685       !(state_ == StateEnum::CONNECTING || state_ == StateEnum::ESTABLISHED)) {
1686     closeNow();
1687     return;
1688   }
1689 
1690   // Declare a DestructorGuard to ensure that the AsyncSocket cannot be
1691   // destroyed until close() returns.
1692   DestructorGuard dg(this);
1693   eventBase_->dcheckIsInEventBaseThread();
1694 
1695   // Since there are write requests pending, we have to set the
1696   // SHUT_WRITE_PENDING flag, and wait to perform the real close until the
1697   // connect finishes and we finish writing these requests.
1698   //
1699   // Set SHUT_READ to indicate that reads are shut down, and set the
1700   // SHUT_WRITE_PENDING flag to mark that we want to shutdown once the
1701   // pending writes complete.
1702   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE_PENDING);
1703 
1704   // If a read callback is set, invoke readEOF() immediately to inform it that
1705   // the socket has been closed and no more data can be read.
1706   if (readCallback_) {
1707     // Disable reads if they are enabled
1708     if (!updateEventRegistration(0, EventHandler::READ)) {
1709       // We're now in the error state; callbacks have been cleaned up
1710       assert(state_ == StateEnum::ERROR);
1711       assert(readCallback_ == nullptr);
1712     } else {
1713       ReadCallback* callback = readCallback_;
1714       readCallback_ = nullptr;
1715       callback->readEOF();
1716     }
1717   }
1718 }
1719 
1720 void AsyncSocket::closeNow() {
1721   VLOG(5) << "AsyncSocket::closeNow(): this=" << this << ", fd_=" << fd_
1722           << ", state=" << state_ << ", shutdownFlags=" << std::hex
1723           << (int)shutdownFlags_;
1724   DestructorGuard dg(this);
1725   if (eventBase_) {
1726     eventBase_->dcheckIsInEventBaseThread();
1727   }
1728 
1729   switch (state_) {
1730     case StateEnum::ESTABLISHED:
1731     case StateEnum::CONNECTING:
1732     case StateEnum::FAST_OPEN: {
1733       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
1734       state_ = StateEnum::CLOSED;
1735 
1736       // If the write timeout was set, cancel it.
1737       writeTimeout_.cancelTimeout();
1738 
1739       // If we are registered for I/O events, unregister.
1740       if (eventFlags_ != EventHandler::NONE) {
1741         eventFlags_ = EventHandler::NONE;
1742         if (!updateEventRegistration()) {
1743           // We will have been moved into the error state.
1744           assert(state_ == StateEnum::ERROR);
1745           return;
1746         }
1747       }
1748 
1749       if (immediateReadHandler_.isLoopCallbackScheduled()) {
1750         immediateReadHandler_.cancelLoopCallback();
1751       }
1752 
1753       if (fd_ != NetworkSocket()) {
1754         ioHandler_.changeHandlerFD(NetworkSocket());
1755         doClose();
1756       }
1757 
1758       invokeConnectErr(getSocketClosedLocallyEx());
1759 
1760       failAllWrites(getSocketClosedLocallyEx());
1761 
1762       if (readCallback_) {
1763         ReadCallback* callback = readCallback_;
1764         readCallback_ = nullptr;
1765         callback->readEOF();
1766       }
1767       return;
1768     }
1769     case StateEnum::CLOSED:
1770       // Do nothing.  It's possible that we are being called recursively
1771       // from inside a callback that we invoked inside another call to close()
1772       // that is still running.
1773       return;
1774     case StateEnum::ERROR:
1775       // Do nothing.  The error handling code has performed (or is performing)
1776       // cleanup.
1777       return;
1778     case StateEnum::UNINIT:
1779       assert(eventFlags_ == EventHandler::NONE);
1780       assert(connectCallback_ == nullptr);
1781       assert(readCallback_ == nullptr);
1782       assert(writeReqHead_ == nullptr);
1783       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
1784       state_ = StateEnum::CLOSED;
1785       return;
1786   }
1787 
1788   LOG(DFATAL) << "AsyncSocket::closeNow() (this=" << this << ", fd=" << fd_
1789               << ") called in unknown state " << state_;
1790 }
1791 
1792 void AsyncSocket::closeWithReset() {
1793   // Enable SO_LINGER, with the linger timeout set to 0.
1794   // This will trigger a TCP reset when we close the socket.
1795   if (fd_ != NetworkSocket()) {
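    // l_onoff = 1 with l_linger = 0: closing the socket discards any unsent
    // data and sends a TCP RST instead of the normal FIN handshake.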
1796     struct linger optLinger = {1, 0};
1797     if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) {
1798       VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER "
1799               << "on " << fd_ << ": errno=" << errno;
1800     }
1801   }
1802 
1803   // Then let closeNow() take care of the rest
1804   closeNow();
1805 }
1806 
1807 void AsyncSocket::shutdownWrite() {
1808   VLOG(5) << "AsyncSocket::shutdownWrite(): this=" << this << ", fd=" << fd_
1809           << ", state=" << state_ << ", shutdownFlags=" << std::hex
1810           << (int)shutdownFlags_;
1811 
1812   // If there are no pending writes, shutdownWrite() is identical to
1813   // shutdownWriteNow().
1814   if (writeReqHead_ == nullptr) {
1815     shutdownWriteNow();
1816     return;
1817   }
1818 
1819   eventBase_->dcheckIsInEventBaseThread();
1820 
1821   // There are pending writes.  Set SHUT_WRITE_PENDING so that the actual
1822   // shutdown will be performed once all writes complete.
1823   shutdownFlags_ |= SHUT_WRITE_PENDING;
1824 }
1825 
1826 void AsyncSocket::shutdownWriteNow() {
1827   VLOG(5) << "AsyncSocket::shutdownWriteNow(): this=" << this << ", fd=" << fd_
1828           << ", state=" << state_ << ", shutdownFlags=" << std::hex
1829           << (int)shutdownFlags_;
1830 
1831   if (shutdownFlags_ & SHUT_WRITE) {
1832     // Writes are already shutdown; nothing else to do.
1833     return;
1834   }
1835 
1836   // If SHUT_READ is already set, just call closeNow() to completely
1837   // close the socket.  This can happen if close() was called with writes
1838   // pending, and then shutdownWriteNow() is called before all pending writes
1839   // complete.
1840   if (shutdownFlags_ & SHUT_READ) {
1841     closeNow();
1842     return;
1843   }
1844 
1845   DestructorGuard dg(this);
1846   if (eventBase_) {
1847     eventBase_->dcheckIsInEventBaseThread();
1848   }
1849 
1850   switch (static_cast<StateEnum>(state_)) {
1851     case StateEnum::ESTABLISHED: {
1852       shutdownFlags_ |= SHUT_WRITE;
1853 
1854       // If the write timeout was set, cancel it.
1855       writeTimeout_.cancelTimeout();
1856 
1857       // If we are registered for write events, unregister.
1858       if (!updateEventRegistration(0, EventHandler::WRITE)) {
1859         // We will have been moved into the error state.
1860         assert(state_ == StateEnum::ERROR);
1861         return;
1862       }
1863 
1864       // Shutdown writes on the file descriptor
1865       netops_->shutdown(fd_, SHUT_WR);
1866 
1867       // Immediately fail all write requests
1868       failAllWrites(getSocketShutdownForWritesEx());
1869       return;
1870     }
1871     case StateEnum::CONNECTING: {
1872       // Set the SHUT_WRITE_PENDING flag.
1873       // When the connection completes, it will check this flag,
1874       // shutdown the write half of the socket, and then set SHUT_WRITE.
1875       shutdownFlags_ |= SHUT_WRITE_PENDING;
1876 
1877       // Immediately fail all write requests
1878       failAllWrites(getSocketShutdownForWritesEx());
1879       return;
1880     }
1881     case StateEnum::UNINIT:
1882       // Callers normally shouldn't call shutdownWriteNow() before the socket
1883       // even starts connecting.  Nonetheless, go ahead and set
1884       // SHUT_WRITE_PENDING.  Once the socket eventually connects it will
1885       // immediately shut down the write side of the socket.
1886       shutdownFlags_ |= SHUT_WRITE_PENDING;
1887       return;
1888     case StateEnum::FAST_OPEN:
1889       // In the fast open state we haven't called connect() yet; if we shut
1890       // down writes now, we will never try to connect, so shut everything down
1891       shutdownFlags_ |= SHUT_WRITE;
1892       // Immediately fail all write requests
1893       failAllWrites(getSocketShutdownForWritesEx());
1894       return;
1895     case StateEnum::CLOSED:
1896     case StateEnum::ERROR:
1897       // We should never get here.  SHUT_WRITE should always be set
1898       // in STATE_CLOSED and STATE_ERROR.
1899       VLOG(4) << "AsyncSocket::shutdownWriteNow() (this=" << this
1900               << ", fd=" << fd_ << ") in unexpected state " << state_
1901               << " with SHUT_WRITE not set (" << std::hex << (int)shutdownFlags_
1902               << ")";
1903       assert(false);
1904       return;
1905   }
1906 
1907   LOG(DFATAL) << "AsyncSocket::shutdownWriteNow() (this=" << this
1908               << ", fd=" << fd_ << ") called in unknown state " << state_;
1909 }
1910 
1911 bool AsyncSocket::readable() const {
1912   if (fd_ == NetworkSocket()) {
1913     return false;
1914   }
1915 
1916   if (preReceivedData_ && !preReceivedData_->empty()) {
1917     return true;
1918   }
1919   netops::PollDescriptor fds[1];
1920   fds[0].fd = fd_;
1921   fds[0].events = POLLIN;
1922   fds[0].revents = 0;
1923   int rc = netops_->poll(fds, 1, 0);
1924   return rc == 1;
1925 }
1926 
1927 bool AsyncSocket::writable() const {
1928   if (fd_ == NetworkSocket()) {
1929     return false;
1930   }
1931   netops::PollDescriptor fds[1];
1932   fds[0].fd = fd_;
1933   fds[0].events = POLLOUT;
1934   fds[0].revents = 0;
1935   int rc = netops_->poll(fds, 1, 0);
1936   return rc == 1;
1937 }
1938 
1939 bool AsyncSocket::isPending() const {
1940   return ioHandler_.isPending();
1941 }
1942 
1943 bool AsyncSocket::hangup() const {
1944   if (fd_ == NetworkSocket()) {
1945     // sanity check, no one should ask for hangup if we are not connected.
1946     assert(false);
1947     return false;
1948   }
1949 #ifdef POLLRDHUP // Linux-only
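  // POLLRDHUP reports that the peer has shut down its write side (or closed
  // the connection entirely), letting us detect a hangup without reading.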
1950   netops::PollDescriptor fds[1];
1951   fds[0].fd = fd_;
1952   fds[0].events = POLLRDHUP | POLLHUP;
1953   fds[0].revents = 0;
1954   netops_->poll(fds, 1, 0);
1955   return (fds[0].revents & (POLLRDHUP | POLLHUP)) != 0;
1956 #else
1957   return false;
1958 #endif
1959 }
1960 
1961 bool AsyncSocket::good() const {
1962   return (
1963       (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN ||
1964        state_ == StateEnum::ESTABLISHED) &&
1965       (shutdownFlags_ == 0) && (eventBase_ != nullptr));
1966 }
1967 
1968 bool AsyncSocket::error() const {
1969   return (state_ == StateEnum::ERROR);
1970 }
1971 
1972 void AsyncSocket::attachEventBase(EventBase* eventBase) {
1973   VLOG(5) << "AsyncSocket::attachEventBase(this=" << this << ", fd=" << fd_
1974           << ", old evb=" << eventBase_ << ", new evb=" << eventBase
1975           << ", state=" << state_ << ", events=" << std::hex << eventFlags_
1976           << ")";
1977   assert(eventBase_ == nullptr);
1978   eventBase->dcheckIsInEventBaseThread();
1979 
1980   eventBase_ = eventBase;
1981   ioHandler_.attachEventBase(eventBase);
1982 
1983   updateEventRegistration();
1984 
1985   writeTimeout_.attachEventBase(eventBase);
1986   if (evbChangeCb_) {
1987     evbChangeCb_->evbAttached(this);
1988   }
1989   for (const auto& cb : lifecycleObservers_) {
1990     cb->evbAttach(this, eventBase_);
1991   }
1992 }
1993 
1994 void AsyncSocket::detachEventBase() {
1995   VLOG(5) << "AsyncSocket::detachEventBase(this=" << this << ", fd=" << fd_
1996           << ", old evb=" << eventBase_ << ", state=" << state_
1997           << ", events=" << std::hex << eventFlags_ << ")";
1998   assert(eventBase_ != nullptr);
1999   eventBase_->dcheckIsInEventBaseThread();
2000 
2001   // Save the existing event base pointer so we can still pass it to the
2002   // lifecycle observer callbacks after clearing eventBase_
2003   EventBase* existingEvb = eventBase_;
2004 
2005   eventBase_ = nullptr;
2006 
2007   ioHandler_.unregisterHandler();
2008 
2009   ioHandler_.detachEventBase();
2010   writeTimeout_.detachEventBase();
2011   if (evbChangeCb_) {
2012     evbChangeCb_->evbDetached(this);
2013   }
2014   for (const auto& cb : lifecycleObservers_) {
2015     cb->evbDetach(this, existingEvb);
2016   }
2017 }
2018 
2019 bool AsyncSocket::isDetachable() const {
2020   DCHECK(eventBase_ != nullptr);
2021   eventBase_->dcheckIsInEventBaseThread();
2022 
2023   return !writeTimeout_.isScheduled();
2024 }
2025 
2026 void AsyncSocket::cacheAddresses() {
2027   if (fd_ != NetworkSocket()) {
2028     try {
2029       cacheLocalAddress();
2030       cachePeerAddress();
2031     } catch (const std::system_error& e) {
2032       if (e.code() !=
2033           std::error_code(ENOTCONN, errorCategoryForErrnoDomain())) {
2034         VLOG(2) << "Error caching addresses: " << e.code().value() << ", "
2035                 << e.code().message();
2036       }
2037     }
2038   }
2039 }
2040 
2041 void AsyncSocket::cacheLocalAddress() const {
2042   if (!localAddr_.isInitialized()) {
2043     localAddr_.setFromLocalAddress(fd_);
2044   }
2045 }
2046 
2047 void AsyncSocket::cachePeerAddress() const {
2048   if (!addr_.isInitialized()) {
2049     addr_.setFromPeerAddress(fd_);
2050   }
2051 }
2052 
2053 void AsyncSocket::applyOptions(
2054     const SocketOptionMap& options, SocketOptionKey::ApplyPos pos) {
2055   auto result = applySocketOptions(fd_, options, pos);
2056   if (result != 0) {
2057     throw AsyncSocketException(
2058         AsyncSocketException::INTERNAL_ERROR,
2059         withAddr("failed to set socket option"),
2060         result);
2061   }
2062 }
2063 
2064 bool AsyncSocket::isZeroCopyWriteInProgress() const noexcept {
2065   eventBase_->dcheckIsInEventBaseThread();
2066   return (!idZeroCopyBufPtrMap_.empty());
2067 }
2068 
2069 void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
2070   cacheLocalAddress();
2071   *address = localAddr_;
2072 }
2073 
2074 void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
2075   cachePeerAddress();
2076   *address = addr_;
2077 }
2078 
2079 bool AsyncSocket::getTFOSucceded() const {
2080   return detail::tfo_succeeded(fd_);
2081 }
2082 
2083 int AsyncSocket::setNoDelay(bool noDelay) {
2084   if (fd_ == NetworkSocket()) {
2085     VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket " << this
2086             << "(state=" << state_ << ")";
2087     return EINVAL;
2088   }
2089 
2090   int value = noDelay ? 1 : 0;
2091   if (netops_->setsockopt(
2092           fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
2093     int errnoCopy = errno;
2094     VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket " << this
2095             << " (fd=" << fd_ << ", state=" << state_
2096             << "): " << errnoStr(errnoCopy);
2097     return errnoCopy;
2098   }
2099 
2100   return 0;
2101 }
2102 
2103 int AsyncSocket::setCongestionFlavor(const std::string& cname) {
2104 #ifndef TCP_CONGESTION
2105 #define TCP_CONGESTION 13
2106 #endif
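  // Illustrative usage: setCongestionFlavor("bbr") asks the kernel to use
  // its BBR congestion control module for this socket, if it is available.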
2107 
2108   if (fd_ == NetworkSocket()) {
2109     VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open "
2110             << "socket " << this << "(state=" << state_ << ")";
2111     return EINVAL;
2112   }
2113 
2114   if (netops_->setsockopt(
2115           fd_,
2116           IPPROTO_TCP,
2117           TCP_CONGESTION,
2118           cname.c_str(),
2119           socklen_t(cname.length() + 1)) != 0) {
2120     int errnoCopy = errno;
2121     VLOG(2) << "failed to update TCP_CONGESTION option on AsyncSocket " << this
2122             << "(fd=" << fd_ << ", state=" << state_
2123             << "): " << errnoStr(errnoCopy);
2124     return errnoCopy;
2125   }
2126 
2127   return 0;
2128 }
2129 
2130 int AsyncSocket::setQuickAck(bool quickack) {
2131   (void)quickack;
2132   if (fd_ == NetworkSocket()) {
2133     VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket " << this
2134             << "(state=" << state_ << ")";
2135     return EINVAL;
2136   }
2137 
2138 #ifdef TCP_QUICKACK // Linux-only
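  // Note: TCP_QUICKACK is not sticky; the kernel may re-enable delayed ACKs
  // later, so callers needing it persistently must re-apply the option.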
2139   int value = quickack ? 1 : 0;
2140   if (netops_->setsockopt(
2141           fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
2142     int errnoCopy = errno;
2143     VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket " << this
2144             << "(fd=" << fd_ << ", state=" << state_
2145             << "): " << errnoStr(errnoCopy);
2146     return errnoCopy;
2147   }
2148 
2149   return 0;
2150 #else
2151   return ENOSYS;
2152 #endif
2153 }
2154 
2155 int AsyncSocket::setSendBufSize(size_t bufsize) {
2156   if (fd_ == NetworkSocket()) {
2157     VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket "
2158             << this << "(state=" << state_ << ")";
2159     return EINVAL;
2160   }
2161 
2162   if (netops_->setsockopt(
2163           fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) != 0) {
2164     int errnoCopy = errno;
2165     VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket " << this
2166             << "(fd=" << fd_ << ", state=" << state_
2167             << "): " << errnoStr(errnoCopy);
2168     return errnoCopy;
2169   }
2170 
2171   return 0;
2172 }
2173 
2174 int AsyncSocket::setRecvBufSize(size_t bufsize) {
2175   if (fd_ == NetworkSocket()) {
2176     VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket "
2177             << this << "(state=" << state_ << ")";
2178     return EINVAL;
2179   }
2180 
2181   if (netops_->setsockopt(
2182           fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) != 0) {
2183     int errnoCopy = errno;
2184     VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket " << this
2185             << "(fd=" << fd_ << ", state=" << state_
2186             << "): " << errnoStr(errnoCopy);
2187     return errnoCopy;
2188   }
2189 
2190   return 0;
2191 }
2192 
2193 #if defined(__linux__)
2194 size_t AsyncSocket::getSendBufInUse() const {
2195   if (fd_ == NetworkSocket()) {
2196     std::stringstream issueString;
2197     issueString << "AsyncSocket::getSendBufInUse() called on non-open socket "
2198                 << this << "(state=" << state_ << ")";
2199     VLOG(4) << issueString.str();
2200     throw std::logic_error(issueString.str());
2201   }
2202 
2203   size_t returnValue = 0;
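  // SIOCOUTQ reports the number of bytes in the kernel send queue that have
  // not yet been acknowledged by the peer.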
2204   if (-1 == ::ioctl(fd_.toFd(), SIOCOUTQ, &returnValue)) {
2205     int errnoCopy = errno;
2206     std::stringstream issueString;
2207     issueString << "Failed to get the tx used bytes on Socket: " << this
2208                 << "(fd=" << fd_ << ", state=" << state_
2209                 << "): " << errnoStr(errnoCopy);
2210     VLOG(2) << issueString.str();
2211     throw std::logic_error(issueString.str());
2212   }
2213 
2214   return returnValue;
2215 }
2216 
2217 size_t AsyncSocket::getRecvBufInUse() const {
2218   if (fd_ == NetworkSocket()) {
2219     std::stringstream issueString;
2220     issueString << "AsyncSocket::getRecvBufInUse() called on non-open socket "
2221                 << this << "(state=" << state_ << ")";
2222     VLOG(4) << issueString.str();
2223     throw std::logic_error(issueString.str());
2224   }
2225 
2226   size_t returnValue = 0;
2227   if (-1 == ::ioctl(fd_.toFd(), SIOCINQ, &returnValue)) {
2228     std::stringstream issueString;
2229     int errnoCopy = errno;
2230     issueString << "Failed to get the rx used bytes on Socket: " << this
2231                 << "(fd=" << fd_ << ", state=" << state_
2232                 << "): " << errnoStr(errnoCopy);
2233     VLOG(2) << issueString.str();
2234     throw std::logic_error(issueString.str());
2235   }
2236 
2237   return returnValue;
2238 }
2239 #endif
2240 
2241 int AsyncSocket::setTCPProfile(int profd) {
2242   if (fd_ == NetworkSocket()) {
2243     VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket " << this
2244             << "(state=" << state_ << ")";
2245     return EINVAL;
2246   }
2247 
2248   if (netops_->setsockopt(
2249           fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) != 0) {
2250     int errnoCopy = errno;
2251     VLOG(2) << "failed to set socket namespace option on AsyncSocket " << this
2252             << "(fd=" << fd_ << ", state=" << state_
2253             << "): " << errnoStr(errnoCopy);
2254     return errnoCopy;
2255   }
2256 
2257   return 0;
2258 }
2259 
2260 void AsyncSocket::ioReady(uint16_t events) noexcept {
2261   VLOG(7) << "AsyncSocket::ioReady() this=" << this << ", fd=" << fd_
2262           << ", events=" << std::hex << events << ", state=" << state_;
2263   DestructorGuard dg(this);
2264   assert(events & EventHandler::READ_WRITE);
2265   eventBase_->dcheckIsInEventBaseThread();
2266 
2267   auto relevantEvents = uint16_t(events & EventHandler::READ_WRITE);
2268   EventBase* originalEventBase = eventBase_;
2269   // If we got here it means that either EventHandler::READ or
2270   // EventHandler::WRITE is set. Either of these flags can
2271   // indicate that there are messages available in the socket
2272   // error message queue.
2273   // Return if we handled any error messages; this avoids
2274   // unnecessary read/write calls
2275   if (handleErrMessages()) {
2276     return;
2277   }
2278 
2279   // Return now if handleErrMessages() detached us from our EventBase
2280   if (eventBase_ != originalEventBase) {
2281     return;
2282   }
2283 
2284   if (relevantEvents == EventHandler::READ) {
2285     handleRead();
2286   } else if (relevantEvents == EventHandler::WRITE) {
2287     handleWrite();
2288   } else if (relevantEvents == EventHandler::READ_WRITE) {
2289     // If both read and write events are ready, process writes first.
2290     handleWrite();
2291 
2292     // Return now if handleWrite() detached us from our EventBase
2293     if (eventBase_ != originalEventBase) {
2294       return;
2295     }
2296 
2297     // Only call handleRead() if a read callback is still installed.
2298     // (It's possible that the read callback was uninstalled during
2299     // handleWrite().)
2300     if (readCallback_) {
2301       handleRead();
2302     }
2303   } else {
2304     VLOG(4) << "AsyncSocket::ioReady() called with unexpected events "
2305             << std::hex << events << " (this=" << this << ")";
2306     abort();
2307   }
2308 }
2309 
2310 AsyncSocket::ReadResult AsyncSocket::performRead(
2311     void** buf, size_t* buflen, size_t* /* offset */) {
2312   struct iovec iov;
2313 
2314   // Data buffer pointer and length
2315   iov.iov_base = *buf;
2316   iov.iov_len = *buflen;
2317 
2318   return performReadInternal(&iov, 1);
2319 }
2320 
2321 AsyncSocket::ReadResult AsyncSocket::performReadv(
2322     struct iovec* iovs, size_t num) {
2323   return performReadInternal(iovs, num);
2324 }
2325 
2326 AsyncSocket::ReadResult AsyncSocket::performReadInternal(
2327     struct iovec* iovs, size_t num) {
2328   VLOG(5) << "AsyncSocket::performReadInternal() this=" << this
2329           << ", iovs=" << iovs << ", num=" << num;
2330 
2331   if (!num) {
2332     return ReadResult(READ_ERROR);
2333   }
2334 
2335   if (preReceivedData_ && !preReceivedData_->empty()) {
2336     VLOG(5) << "AsyncSocket::performReadInternal() this=" << this
2337             << ", reading pre-received data";
2338 
2339     ssize_t len = 0;
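    // Copy as much pre-received data as fits into each iovec, trimming the
    // consumed bytes from the front of preReceivedData_ after each copy.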
2340     for (size_t i = 0; (i < num) && (!preReceivedData_->empty()); ++i) {
2341       io::Cursor cursor(preReceivedData_.get());
2342       auto ret = cursor.pullAtMost(iovs[i].iov_base, iovs[i].iov_len);
2343       len += ret;
2344 
2345       IOBufQueue queue;
2346       queue.append(std::move(preReceivedData_));
2347       queue.trimStart(ret);
2348       preReceivedData_ = queue.move();
2349     }
2350 
2351     appBytesReceived_ += len;
2352     return ReadResult(len);
2353   }
2354 
2355   ssize_t bytes = 0;
2356 
2357   struct msghdr msg;
2358 
2359   if (readAncillaryDataCallback_ == nullptr && num == 1) {
2360     bytes = netops_->recv(fd_, iovs[0].iov_base, iovs[0].iov_len, MSG_DONTWAIT);
2361   } else {
2362     if (readAncillaryDataCallback_) {
2363       // Ancillary data buffer and length
2364       msg.msg_control =
2365           readAncillaryDataCallback_->getAncillaryDataCtrlBuffer().data();
2366       msg.msg_controllen =
2367           readAncillaryDataCallback_->getAncillaryDataCtrlBuffer().size();
2368     } else {
2369       msg.msg_control = nullptr;
2370       msg.msg_controllen = 0;
2371     }
2372 
2373     // Dest address info
2374     msg.msg_name = nullptr;
2375     msg.msg_namelen = 0;
2376 
2377     // Array of data buffers (scatter/gather)
2378     msg.msg_iov = iovs;
2379     msg.msg_iovlen = num;
2380 
2381     bytes = netops::recvmsg(fd_, &msg, 0);
2382   }
2383 
2384   if (readAncillaryDataCallback_ && (bytes > 0)) {
2385     readAncillaryDataCallback_->ancillaryData(msg);
2386   }
2387 
2388   if (bytes < 0) {
2389     if (errno == EAGAIN || errno == EWOULDBLOCK) {
2390       // No more data to read right now.
2391       return ReadResult(READ_BLOCKING);
2392     } else {
2393       return ReadResult(READ_ERROR);
2394     }
2395   } else {
2396     appBytesReceived_ += bytes;
2397     return ReadResult(bytes);
2398   }
2399 }
2400 
2401 void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
2402   // no matter what, a buffer should be prepared for a non-SSL socket
2403   CHECK(readCallback_);
2404   readCallback_->getReadBuffer(buf, buflen);
2405 }
2406 
2407 void AsyncSocket::prepareReadBuffers(IOBufIovecBuilder::IoVecVec& iovs) {
2408   // no matter what, buffers should be prepared for a non-SSL socket
2409   CHECK(readCallback_);
2410   readCallback_->getReadBuffers(iovs);
2411 }
2412 
2413 size_t AsyncSocket::handleErrMessages() noexcept {
2414   // This method has non-empty implementation only for platforms
2415   // supporting per-socket error queues.
2416   VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
2417           << ", state=" << state_;
2418   if (errMessageCallback_ == nullptr && idZeroCopyBufPtrMap_.empty() &&
2419       (!byteEventHelper_ || !byteEventHelper_->byteEventsEnabled)) {
2420     VLOG(7) << "AsyncSocket::handleErrMessages(): "
2421             << "no err message callback installed and "
2422             << "ByteEvents not enabled - exiting.";
2423     return 0;
2424   }
2425 
2426 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
2427   uint8_t ctrl[1024];
2428   unsigned char data;
2429   struct msghdr msg;
2430   iovec entry;
2431 
2432   entry.iov_base = &data;
2433   entry.iov_len = sizeof(data);
2434   msg.msg_iov = &entry;
2435   msg.msg_iovlen = 1;
2436   msg.msg_name = nullptr;
2437   msg.msg_namelen = 0;
2438   msg.msg_control = ctrl;
2439   msg.msg_controllen = sizeof(ctrl);
2440   msg.msg_flags = 0;
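  // The payload buffer is a single byte: error-queue messages carry their
  // information in control messages, and with SOF_TIMESTAMPING_OPT_TSONLY
  // set no original packet data is returned.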
2441 
2442   int ret;
2443   size_t num = 0;
2444   // the socket may be closed by the errMessage callback, so check on each iteration
2445   while (fd_ != NetworkSocket()) {
2446     ret = netops_->recvmsg(fd_, &msg, MSG_ERRQUEUE);
2447     VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
2448 
2449     if (ret < 0) {
2450       if (errno != EAGAIN) {
2451         auto errnoCopy = errno;
2452         LOG(ERROR) << "::recvmsg exited with code " << ret
2453                    << ", errno: " << errnoCopy << ", fd: " << fd_;
2454         AsyncSocketException ex(
2455             AsyncSocketException::INTERNAL_ERROR,
2456             withAddr("recvmsg() failed"),
2457             errnoCopy);
2458         failErrMessageRead(__func__, ex);
2459       }
2460 
2461       return num;
2462     }
2463 
2464     for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
2465          cmsg != nullptr && cmsg->cmsg_len != 0;
2466          cmsg = CMSG_NXTHDR(&msg, cmsg)) {
2467       ++num;
2468       if (isZeroCopyMsg(*cmsg)) {
2469         processZeroCopyMsg(*cmsg);
2470         continue;
2471       }
2472 
2473       // try to process it as a ByteEvent and forward to observers
2474       //
2475       // observers cannot throw and thus we expect only exceptions from
2476       // ByteEventHelper, but we guard against other cases for safety
2477       if (byteEventHelper_) {
2478         try {
2479           if (const auto maybeByteEvent =
2480                   byteEventHelper_->processCmsg(*cmsg, getRawBytesWritten())) {
2481             const auto& byteEvent = maybeByteEvent.value();
2482             for (const auto& observer : lifecycleObservers_) {
2483               if (observer->getConfig().byteEvents) {
2484                 observer->byteEvent(this, byteEvent);
2485               }
2486             }
2487           }
2488         } catch (const ByteEventHelper::Exception& behEx) {
2489           // rewrap the ByteEventHelper::Exception with extra information
2490           AsyncSocketException ex(
2491               AsyncSocketException::INTERNAL_ERROR,
2492               withAddr(
2493                   string("AsyncSocket::handleErrMessages(), "
2494                          "internal exception during ByteEvent processing: ") +
2495                   behEx.what()));
2496           failByteEvents(ex);
2497         } catch (const std::exception& ex) {
2498           AsyncSocketException tex(
2499               AsyncSocketException::UNKNOWN,
2500               string("AsyncSocket::handleErrMessages(), "
2501                      "unhandled exception during ByteEvent processing, "
2502                      "threw exception: ") +
2503                   ex.what());
2504           failByteEvents(tex);
2505         } catch (...) {
2506           AsyncSocketException tex(
2507               AsyncSocketException::UNKNOWN,
2508               string("AsyncSocket::handleErrMessages(), "
2509                      "unhandled exception during ByteEvent processing, "
2510                      "threw non-exception type"));
2511           failByteEvents(tex);
2512         }
2513       }
2514 
2515       // even if it is a timestamp, hand it off to the errMessageCallback,
2516       // the application may want it as well.
2517       if (errMessageCallback_) {
2518         errMessageCallback_->errMessage(*cmsg);
2519       }
2520     }
2521   }
2522   return num;
2523 #else
2524   return 0;
2525 #endif // FOLLY_HAVE_MSG_ERRQUEUE
2526 }
2527 
2528 bool AsyncSocket::processZeroCopyWriteInProgress() noexcept {
2529   eventBase_->dcheckIsInEventBaseThread();
2530   if (idZeroCopyBufPtrMap_.empty()) {
2531     return true;
2532   }
2533 
2534   handleErrMessages();
2535 
2536   return idZeroCopyBufPtrMap_.empty();
2537 }
2538 
2539 void AsyncSocket::addLifecycleObserver(
2540     AsyncTransport::LifecycleObserver* observer) {
2541   if (eventBase_) {
2542     eventBase_->dcheckIsInEventBaseThread();
2543   }
2544 
2545   // adding the same observer multiple times is not allowed
2546   auto& observers = lifecycleObservers_;
2547   CHECK(
2548       std::find(observers.begin(), observers.end(), observer) ==
2549       observers.end());
2550 
2551   observers.push_back(observer);
2552   observer->observerAttach(this);
2553   if (observer->getConfig().byteEvents) {
2554     if (byteEventHelper_ && byteEventHelper_->maybeEx.has_value()) {
2555       observer->byteEventsUnavailable(this, *byteEventHelper_->maybeEx);
2556     } else if (byteEventHelper_ && byteEventHelper_->byteEventsEnabled) {
2557       observer->byteEventsEnabled(this);
2558     } else if (state_ == StateEnum::ESTABLISHED) {
2559       enableByteEvents(); // try to enable now
2560     }
2561     // otherwise, do nothing right now; wait until we're connected
2562   }
2563 }
2564 
2565 bool AsyncSocket::removeLifecycleObserver(
2566     AsyncTransport::LifecycleObserver* observer) {
2567   auto& observers = lifecycleObservers_;
2568   auto it = std::find(observers.begin(), observers.end(), observer);
2569   if (it == observers.end()) {
2570     return false;
2571   }
2572   observer->observerDetach(this);
2573   observers.erase(it);
2574   return true;
2575 }
2576 
2577 std::vector<AsyncTransport::LifecycleObserver*>
2578 AsyncSocket::getLifecycleObservers() const {
2579   if (eventBase_) {
2580     eventBase_->dcheckIsInEventBaseThread();
2581   }
2582   return std::vector<AsyncTransport::LifecycleObserver*>(
2583       lifecycleObservers_.begin(), lifecycleObservers_.end());
2584 }
2585 
2586 void AsyncSocket::splitIovecArray(
2587     const size_t startOffset,
2588     const size_t endOffset,
2589     const iovec* srcVec,
2590     const size_t srcCount,
2591     iovec* dstVec,
2592     size_t& dstCount) {
2593   CHECK_GE(endOffset, startOffset);
2594   CHECK_GE(dstCount, srcCount);
2595   dstCount = 0;
2596 
2597   const size_t targetBytes = endOffset - startOffset + 1;
2598   size_t dstBytes = 0;
2599   size_t processedBytes = 0;
2600   for (size_t i = 0; i < srcCount; processedBytes += srcVec[i].iov_len, i++) {
2601     iovec currentOp = srcVec[i];
2602     if (currentOp.iov_len == 0) { // skip empty iovecs
2603       continue;
2604     }
2605 
2606     // if we haven't found the start offset yet, see if it is in this op
2607     if (dstCount == 0) {
2608       if (processedBytes + currentOp.iov_len < startOffset + 1) {
2609         continue; // start offset isn't in this op
2610       }
2611 
2612       // offset iov_base to get the start offset
2613       const size_t trimFromStart = startOffset - processedBytes;
2614       currentOp.iov_base =
2615           reinterpret_cast<uint8_t*>(currentOp.iov_base) + trimFromStart;
2616       currentOp.iov_len -= trimFromStart;
2617     }
2618 
2619     // trim the end of the iovec, if needed
2620     ssize_t trimFromEnd = (dstBytes + currentOp.iov_len) - targetBytes;
2621     if (trimFromEnd > 0) {
2622       currentOp.iov_len -= trimFromEnd;
2623     }
2624 
2625     dstVec[dstCount] = currentOp;
2626     dstCount++;
2627     dstBytes += currentOp.iov_len;
2628     CHECK_GE(targetBytes, dstBytes);
2629     if (targetBytes == dstBytes) {
2630       break; // done
2631     }
2632   }
2633 
2634   CHECK_EQ(targetBytes, dstBytes);
2635 }
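
// Worked example (illustrative only): with srcVec holding two 3-byte iovecs
// ("abc" and "def"), splitIovecArray(2, 4, ...) yields dstVec = {"c", "de"}
// and dstCount = 2, since startOffset and endOffset are inclusive offsets.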
2636 
2637 void AsyncSocket::handleRead() noexcept {
2638   VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
2639           << ", state=" << state_;
2640   assert(state_ == StateEnum::ESTABLISHED);
2641   assert((shutdownFlags_ & SHUT_READ) == 0);
2642   assert(readCallback_ != nullptr);
2643   assert(eventFlags_ & EventHandler::READ);
2644 
2645   // Loop until:
2646   // - a read attempt would block
2647   // - readCallback_ is uninstalled
2648   // - the number of loop iterations exceeds the optional maximum
2649   // - this AsyncSocket is moved to another EventBase
2650   //
2651   // When we invoke readDataAvailable() it may uninstall the readCallback_,
2652   // which is why we need to check for it here.
2653   //
2654   // The last bullet point is slightly subtle.  readDataAvailable() may also
2655   // detach this socket from this EventBase.  However, before
2656   // readDataAvailable() returns another thread may pick it up, attach it to
2657   // a different EventBase, and install another readCallback_.  We need to
2658   // exit immediately after readDataAvailable() returns if the eventBase_ has
2659   // changed.  (The caller must perform some sort of locking to transfer the
2660   // AsyncSocket between threads properly.  This will be sufficient to ensure
2661   // that this thread sees the updated eventBase_ variable after
2662   // readDataAvailable() returns.)
2663   size_t numReads = maxReadsPerEvent_ ? maxReadsPerEvent_ : size_t(-1);
2664   EventBase* originalEventBase = eventBase_;
2665   while (readCallback_ && eventBase_ == originalEventBase && numReads--) {
2666     auto readMode = readCallback_->getReadMode();
2667     // Get the buffer(s) to read into.
2668     void* buf = nullptr;
2669     size_t buflen = 0, offset = 0, num = 0;
2670     IOBufIovecBuilder::IoVecVec iovs; // this could be an AsyncSocket member too
2671 
2672     try {
2673       if (readMode == AsyncReader::ReadCallback::ReadMode::ReadVec) {
2674         prepareReadBuffers(iovs);
2675         num = iovs.size();
2676         VLOG(5) << "prepareReadBuffers() bufs=" << iovs.data()
2677                 << ", num=" << num;
2678       } else {
2679         prepareReadBuffer(&buf, &buflen);
2680         VLOG(5) << "prepareReadBuffer() buf=" << buf << ", buflen=" << buflen;
2681       }
2682     } catch (const AsyncSocketException& ex) {
2683       return failRead(__func__, ex);
2684     } catch (const std::exception& ex) {
2685       AsyncSocketException tex(
2686           AsyncSocketException::BAD_ARGS,
2687           string("ReadCallback::getReadBuffer() "
2688                  "threw exception: ") +
2689               ex.what());
2690       return failRead(__func__, tex);
2691     } catch (...) {
2692       AsyncSocketException ex(
2693           AsyncSocketException::BAD_ARGS,
2694           "ReadCallback::getReadBuffer() threw "
2695           "non-exception type");
2696       return failRead(__func__, ex);
2697     }
2698     if ((num == 0) && (buf == nullptr || buflen == 0)) {
2699       AsyncSocketException ex(
2700           AsyncSocketException::BAD_ARGS,
2701           "ReadCallback::getReadBuffer() returned "
2702           "empty buffer");
2703       return failRead(__func__, ex);
2704     }
2705 
2706     // Perform the read
2707     auto readResult = (readMode == AsyncReader::ReadCallback::ReadMode::ReadVec)
2708         ? performReadv(iovs.data(), num)
2709         : performRead(&buf, &buflen, &offset);
2710     auto bytesRead = readResult.readReturn;
2711     VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got "
2712             << bytesRead << " bytes";
2713     if (bytesRead > 0) {
2714       readCallback_->readDataAvailable(size_t(bytesRead));
2715 
2716       // Fall through and continue around the loop if the read
2717       // completely filled the available buffer.
2718       // Note that readCallback_ may have been uninstalled or changed inside
2719       // readDataAvailable().
2720       if (size_t(bytesRead) < buflen) {
2721         return;
2722       }
2723     } else if (bytesRead == READ_BLOCKING) {
2724       // No more data to read right now.
2725       return;
2726     } else if (bytesRead == READ_ERROR) {
2727       readErr_ = READ_ERROR;
2728       if (readResult.exception) {
2729         return failRead(__func__, *readResult.exception);
2730       }
2731       auto errnoCopy = errno;
2732       AsyncSocketException ex(
2733           AsyncSocketException::INTERNAL_ERROR,
2734           withAddr("recv() failed"),
2735           errnoCopy);
2736       return failRead(__func__, ex);
2737     } else {
2738       assert(bytesRead == READ_EOF);
2739       readErr_ = READ_EOF;
2740       // EOF
2741       shutdownFlags_ |= SHUT_READ;
2742       if (!updateEventRegistration(0, EventHandler::READ)) {
2743         // we've already been moved into STATE_ERROR
2744         assert(state_ == StateEnum::ERROR);
2745         assert(readCallback_ == nullptr);
2746         return;
2747       }
2748 
2749       ReadCallback* callback = readCallback_;
2750       readCallback_ = nullptr;
2751       callback->readEOF();
2752       return;
2753     }
2754   }
2755 
2756   if (readCallback_ && eventBase_ == originalEventBase) {
2757     // We might still have data in the socket.
2758     // (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
2759     scheduleImmediateRead();
2760   }
2761 }
2762 
2763 /**
2764  * This function attempts to write as much data as possible, until no more data
2765  * can be written.
2766  *
2767  * - If it sends all available data, it unregisters for write events, and stops
2768  *   the writeTimeout_.
2769  *
2770  * - If not all of the data can be sent immediately, it reschedules
2771  *   writeTimeout_ (if a non-zero timeout is set), and ensures the handler is
2772  *   registered for write events.
2773  */
2774 void AsyncSocket::handleWrite() noexcept {
2775   VLOG(5) << "AsyncSocket::handleWrite() this=" << this << ", fd=" << fd_
2776           << ", state=" << state_;
2777   DestructorGuard dg(this);
2778 
2779   if (state_ == StateEnum::CONNECTING) {
2780     handleConnect();
2781     return;
2782   }
2783 
2784   // Normal write
2785   assert(state_ == StateEnum::ESTABLISHED);
2786   assert((shutdownFlags_ & SHUT_WRITE) == 0);
2787   assert(writeReqHead_ != nullptr);
2788 
2789   // Loop until we run out of write requests,
2790   // or until this socket is moved to another EventBase.
2791   // (See the comment in handleRead() explaining how this can happen.)
2792   EventBase* originalEventBase = eventBase_;
2793   while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
2794     auto writeResult = writeReqHead_->performWrite();
2795     if (writeResult.writeReturn < 0) {
2796       if (writeResult.exception) {
2797         return failWrite(__func__, *writeResult.exception);
2798       }
2799       auto errnoCopy = errno;
2800       AsyncSocketException ex(
2801           AsyncSocketException::INTERNAL_ERROR,
2802           withAddr("writev() failed"),
2803           errnoCopy);
2804       return failWrite(__func__, ex);
2805     } else if (writeReqHead_->isComplete()) {
2806       // We finished this request
2807       WriteRequest* req = writeReqHead_;
2808       writeReqHead_ = req->getNext();
2809 
2810       if (writeReqHead_ == nullptr) {
2811         writeReqTail_ = nullptr;
2812         // This is the last write request.
2813         // Unregister for write events and cancel the send timer
2814         // before we invoke the callback.  We have to update the state properly
2815         // before calling the callback, since it may want to detach us from
2816         // the EventBase.
2817         if (eventFlags_ & EventHandler::WRITE) {
2818           if (!updateEventRegistration(0, EventHandler::WRITE)) {
2819             assert(state_ == StateEnum::ERROR);
2820             return;
2821           }
2822           // Stop the send timeout
2823           writeTimeout_.cancelTimeout();
2824         }
2825         assert(!writeTimeout_.isScheduled());
2826 
2827         // If SHUT_WRITE_PENDING is set, we should shutdown the socket after
2828         // we finish sending the last write request.
2829         //
2830         // We have to do this before invoking writeSuccess(), since
2831         // writeSuccess() may detach us from our EventBase.
2832         if (shutdownFlags_ & SHUT_WRITE_PENDING) {
2833           assert(connectCallback_ == nullptr);
2834           shutdownFlags_ |= SHUT_WRITE;
2835 
2836           if (shutdownFlags_ & SHUT_READ) {
2837             // Reads have already been shutdown.  Fully close the socket and
2838             // move to STATE_CLOSED.
2839             //
2840             // Note: This code currently moves us to STATE_CLOSED even if
2841             // close() hasn't ever been called.  This can occur if we have
2842             // received EOF from the peer and shutdownWrite() has been called
2843             // locally.  Should we bother staying in STATE_ESTABLISHED in this
2844             // case, until close() is actually called?  I can't think of a
2845             // reason why we would need to do so.  No other operations besides
2846             // calling close() or destroying the socket can be performed at
2847             // this point.
2848             assert(readCallback_ == nullptr);
2849             state_ = StateEnum::CLOSED;
2850             if (fd_ != NetworkSocket()) {
2851               ioHandler_.changeHandlerFD(NetworkSocket());
2852               doClose();
2853             }
2854           } else {
2855             // Reads are still enabled, so we are only doing a half-shutdown
2856             netops_->shutdown(fd_, SHUT_WR);
2857           }
2858         }
2859       }
2860 
2861       // Invoke the callback
2862       WriteCallback* callback = req->getCallback();
2863       req->destroy();
2864       if (callback) {
2865         callback->writeSuccess();
2866       }
2867       // We'll continue around the loop, trying to write another request
2868     } else {
2869       // Partial write.
2870       writeReqHead_->consume();
2871       if (bufferCallback_) {
2872         bufferCallback_->onEgressBuffered();
2873       }
2874       // Stop after a partial write; it's highly likely that a subsequent write
2875       // attempt will just return EAGAIN.
2876       //
2877       // Ensure that we are registered for write events.
2878       if ((eventFlags_ & EventHandler::WRITE) == 0) {
2879         if (!updateEventRegistration(EventHandler::WRITE, 0)) {
2880           assert(state_ == StateEnum::ERROR);
2881           return;
2882         }
2883       }
2884 
2885       // Reschedule the send timeout, since we have made some write progress.
2886       if (sendTimeout_ > 0) {
2887         if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
2888           AsyncSocketException ex(
2889               AsyncSocketException::INTERNAL_ERROR,
2890               withAddr("failed to reschedule write timeout"));
2891           return failWrite(__func__, ex);
2892         }
2893       }
2894       return;
2895     }
2896   }
2897   if (!writeReqHead_ && bufferCallback_) {
2898     bufferCallback_->onEgressBufferCleared();
2899   }
2900 }
2901 
2902 void AsyncSocket::checkForImmediateRead() noexcept {
2903   // We currently don't attempt to perform optimistic reads in AsyncSocket.
2904   // (However, note that some subclasses do override this method.)
2905   //
2906   // Simply calling handleRead() here would be bad, as this would call
2907   // readCallback_->getReadBuffer(), forcing the callback to allocate a read
2908   // buffer even though no data may be available.  This would waste lots of
2909   // memory, since the buffer will sit around unused until the socket actually
2910   // becomes readable.
2911   //
2912   // Checking if the socket is readable now also seems like it would probably
2913   // be a pessimism.  In most cases it probably wouldn't be readable, and we
2914   // would just waste an extra system call.  Even if it is readable, waiting to
2915   // find out from libevent on the next event loop doesn't seem that bad.
2916   //
2917   // The exception to this is if we have pre-received data. In that case there
2918   // is definitely data available immediately.
2919   if (preReceivedData_ && !preReceivedData_->empty()) {
2920     handleRead();
2921   }
2922 }
2923 
2924 void AsyncSocket::handleInitialReadWrite() noexcept {
2925   // Our callers should already be holding a DestructorGuard, but grab
2926   // one here just to make sure, in case one of our calling code paths ever
2927   // changes.
2928   DestructorGuard dg(this);
2929   // If we have a readCallback_, make sure we enable read events.  We
2930   // may already be registered for reads if connectSuccess() set
2931   // the read calback.
2932   if (readCallback_ && !(eventFlags_ & EventHandler::READ)) {
2933     assert(state_ == StateEnum::ESTABLISHED);
2934     assert((shutdownFlags_ & SHUT_READ) == 0);
2935     if (!updateEventRegistration(EventHandler::READ, 0)) {
2936       assert(state_ == StateEnum::ERROR);
2937       return;
2938     }
2939     checkForImmediateRead();
2940   } else if (readCallback_ == nullptr) {
2941     // Unregister for read events.
2942     updateEventRegistration(0, EventHandler::READ);
2943   }
2944 
2945   // If we have write requests pending, try to send them immediately.
2946   // Since we just finished accepting, there is a very good chance that we can
2947   // write without blocking.
2948   //
2949   // However, we only process them if EventHandler::WRITE is not already set,
2950   // which means that we're already blocked on a write attempt.  (This can
2951   // happen if connectSuccess() called write() before returning.)
2952   if (writeReqHead_ && !(eventFlags_ & EventHandler::WRITE)) {
2953     // Call handleWrite() to perform write processing.
2954     handleWrite();
2955   } else if (writeReqHead_ == nullptr) {
2956     // Unregister for write event.
2957     updateEventRegistration(0, EventHandler::WRITE);
2958   }
2959 }
2960 
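/**
 * Handle the completion of a non-blocking connect().
 *
 * The kernel reports completion by marking the socket writable; the actual
 * result must then be fetched with getsockopt(SOL_SOCKET, SO_ERROR).  A
 * minimal sketch of the standard pattern implemented below:
 *
 *   int err = 0;
 *   socklen_t len = sizeof(err);
 *   getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &len);
 *   if (err != 0) { ... treat err as an errno-style connect failure ... }
 */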
void AsyncSocket::handleConnect() noexcept {
  VLOG(5) << "AsyncSocket::handleConnect() this=" << this << ", fd=" << fd_
          << ", state=" << state_;
  assert(state_ == StateEnum::CONNECTING);
  // SHUT_WRITE can never be set while we are still connecting;
  // SHUT_WRITE_PENDING may be set, but we only set SHUT_WRITE once the
  // connect finishes.
  assert((shutdownFlags_ & SHUT_WRITE) == 0);

  // In case we had a connect timeout, cancel it now.
  writeTimeout_.cancelTimeout();
  // We don't use a persistent registration when waiting on a connect event,
  // so we have been automatically unregistered now.  Update eventFlags_ to
  // reflect reality.
  assert(eventFlags_ == EventHandler::WRITE);
  eventFlags_ = EventHandler::NONE;

  // Call getsockopt() to check if the connect succeeded.
  int error;
  socklen_t len = sizeof(error);
  int rv = netops_->getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
  if (rv != 0) {
    auto errnoCopy = errno;
    AsyncSocketException ex(
        AsyncSocketException::INTERNAL_ERROR,
        withAddr("error calling getsockopt() after connect"),
        errnoCopy);
    VLOG(4) << "AsyncSocket::handleConnect(this=" << this << ", fd=" << fd_
            << " host=" << addr_.describe() << ") exception: " << ex.what();
    return failConnect(__func__, ex);
  }

  if (error != 0) {
    AsyncSocketException ex(
        AsyncSocketException::NOT_OPEN, "connect failed", error);
    VLOG(2) << "AsyncSocket::handleConnect(this=" << this << ", fd=" << fd_
            << " host=" << addr_.describe() << ") exception: " << ex.what();
    return failConnect(__func__, ex);
  }

  // Move into StateEnum::ESTABLISHED.
  state_ = StateEnum::ESTABLISHED;

  // If SHUT_WRITE_PENDING is set and we don't have any write requests to
  // perform, immediately shut down the write half of the socket.
  if ((shutdownFlags_ & SHUT_WRITE_PENDING) && writeReqHead_ == nullptr) {
    // SHUT_READ shouldn't be set.  If close() is called on the socket while
    // we are still connecting, we just abort the connect rather than waiting
    // for it to complete.
    assert((shutdownFlags_ & SHUT_READ) == 0);
    netops_->shutdown(fd_, SHUT_WR);
    shutdownFlags_ |= SHUT_WRITE;
  }

  VLOG(7) << "AsyncSocket " << this << ": fd " << fd_
          << " successfully connected; state=" << state_;

  // Remember the EventBase we are attached to, before we start invoking any
  // callbacks (since the callbacks may call detachEventBase()).
  EventBase* originalEventBase = eventBase_;

  invokeConnectSuccess();
  // Note that the connect callback may have changed our state.
  // (set or unset the read callback, called write(), closed the socket, etc.)
  // The following code needs to handle these situations correctly.
  //
  // If the socket has been closed, readCallback_ and writeReqHead_ will
  // always be nullptr, so that will prevent us from trying to read or write.
  //
  // The main thing to check for is if eventBase_ is still originalEventBase.
  // If not, we have been detached from this event base, so we shouldn't
  // perform any more operations.
  if (eventBase_ != originalEventBase) {
    return;
  }

  handleInitialReadWrite();
}

void AsyncSocket::timeoutExpired() noexcept {
  VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
          << "state=" << state_ << ", events=" << std::hex << eventFlags_;
  DestructorGuard dg(this);
  eventBase_->dcheckIsInEventBaseThread();

  if (state_ == StateEnum::CONNECTING) {
    // connect() timed out.
    // Unregister for I/O events.
    if (connectCallback_) {
      AsyncSocketException ex(
          AsyncSocketException::TIMED_OUT,
          folly::sformat(
              "connect timed out after {}ms", connectTimeout_.count()));
      failConnect(__func__, ex);
    } else {
      // We hit a connect timeout without a connect callback installed;
      // this can happen due to TFO.
      AsyncSocketException ex(
          AsyncSocketException::TIMED_OUT, "write timed out during connection");
      failWrite(__func__, ex);
    }
  } else {
    // A normal write operation timed out.
    AsyncSocketException ex(
        AsyncSocketException::TIMED_OUT,
        folly::sformat("write timed out after {}ms", sendTimeout_));
    failWrite(__func__, ex);
  }
}

void AsyncSocket::handleNetworkSocketAttached() {
  VLOG(6) << "AsyncSocket::attachFd(this=" << this << ", fd=" << fd_
          << ", evb=" << eventBase_ << ", state=" << state_
          << ", events=" << std::hex << eventFlags_ << ")";
  for (const auto& cb : lifecycleObservers_) {
    if (auto dCb = dynamic_cast<AsyncSocket::LifecycleObserver*>(cb)) {
      dCb->fdAttach(this);
    }
  }

  if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
    shutdownSocketSet->add(fd_);
  }
  ioHandler_.changeHandlerFD(fd_);
}

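// Thin wrapper around the platform TFO sendmsg() shim.  Keeping this as a
// separate method lets the system-call path be substituted (e.g., by
// subclasses or tests); that rationale is an assumption based on how the
// method is used in this file.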
ssize_t AsyncSocket::tfoSendMsg(
    NetworkSocket fd, struct msghdr* msg, int msg_flags) {
  return detail::tfo_sendmsg(fd, msg, msg_flags);
}

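/**
 * Top half of the write path: gather PrewriteRequests from any
 * LifecycleObservers, then hand the (possibly split) iovec array to
 * prepSendMsg() below.
 *
 * Observers may ask for extra write flags and/or for the write to be split
 * at a given byte offset (e.g., so a timestamp can be requested for a
 * specific byte).  When a split occurs, the prefix is sent with
 * WriteFlags::CORK set to tell the OS that more data is on the way.
 */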
AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
    const iovec* vec, size_t count, WriteFlags flags) {
  // Lambda to gather and merge PrewriteRequests from observers.
  auto mergePrewriteRequests = [this,
                                vec,
                                count,
                                flags,
                                maybeVecTotalBytes =
                                    folly::Optional<size_t>()]() mutable {
    AsyncTransport::LifecycleObserver::PrewriteRequest mergedRequest = {};
    if (lifecycleObservers_.empty()) {
      return mergedRequest;
    }

    // Determine the total number of bytes in vec once, then reuse it.
    if (!maybeVecTotalBytes.has_value()) {
      maybeVecTotalBytes = 0;
      for (size_t i = 0; i < count; ++i) {
        maybeVecTotalBytes.value() += vec[i].iov_len;
      }
    }
    auto& vecTotalBytes = maybeVecTotalBytes.value();

    const auto startOffset = getRawBytesWritten();
    const auto endOffset = getRawBytesWritten() + vecTotalBytes - 1;
    const AsyncTransport::LifecycleObserver::PrewriteState prewriteState = [&] {
      AsyncTransport::LifecycleObserver::PrewriteState state = {};
      state.startOffset = startOffset;
      state.endOffset = endOffset;
      state.writeFlags = flags;
      state.ts = std::chrono::steady_clock::now();
      return state;
    }();
    for (const auto& observer : lifecycleObservers_) {
      if (!observer->getConfig().prewrite) {
        continue;
      }

      const auto request = observer->prewrite(this, prewriteState);

      mergedRequest.writeFlagsToAdd |= request.writeFlagsToAdd;
      if (request.maybeOffsetToSplitWrite.has_value()) {
        CHECK_GE(endOffset, request.maybeOffsetToSplitWrite.value());
        if (
            // case 1: offset not set in merged request
            !mergedRequest.maybeOffsetToSplitWrite.has_value() ||
            // case 2: offset in merged request > offset in current request
            mergedRequest.maybeOffsetToSplitWrite >
                request.maybeOffsetToSplitWrite) {
          mergedRequest.maybeOffsetToSplitWrite =
              request.maybeOffsetToSplitWrite; // update
          mergedRequest.writeFlagsToAddAtOffset =
              request.writeFlagsToAddAtOffset; // reset
        } else if (
            // case 3: offset in merged request == offset in current request
            request.maybeOffsetToSplitWrite ==
            mergedRequest.maybeOffsetToSplitWrite) {
          mergedRequest.writeFlagsToAddAtOffset |=
              request.writeFlagsToAddAtOffset; // merge
        }
        // case 4: offset in merged request < offset in current request
        // (do nothing)
      }
    }
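    // Worked example (illustrative): suppose two observers return split
    // offsets 100 and 150 for a write whose endOffset is 199.  Case 2 keeps
    // the smaller offset (100) and resets the at-offset flags to that
    // observer's; had both returned 100, case 3 would OR their flags together.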

    // If maybeOffsetToSplitWrite points to the end of the vector, drop the
    // split.
    if (mergedRequest.maybeOffsetToSplitWrite.has_value() && // explicit
        mergedRequest.maybeOffsetToSplitWrite == endOffset) {
      mergedRequest.maybeOffsetToSplitWrite.reset(); // no split needed
    }

    return mergedRequest;
  };

  // Lambda to prepare and send a message and handle byte events.
  // Parameters have an L suffix to prevent shadowing warnings from gcc.
  auto prepSendMsg = [this](
                         const iovec* vecL,
                         const size_t countL,
                         const WriteFlags flagsL) {
    const bool byteEventsEnabled =
        (byteEventHelper_ && byteEventHelper_->byteEventsEnabled &&
         !byteEventHelper_->maybeEx.has_value());

    struct msghdr msg = {};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = const_cast<struct iovec*>(vecL);
    msg.msg_iovlen = std::min<size_t>(countL, kIovMax);
    msg.msg_flags = 0; // passed to sendSocketMessage below, which sets them
    msg.msg_control = nullptr;
    msg.msg_controllen =
        sendMsgParamCallback_->getAncillaryDataSize(flagsL, byteEventsEnabled);
    CHECK_GE(
        AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
        msg.msg_controllen);

    if (msg.msg_controllen != 0) {
      msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
      sendMsgParamCallback_->getAncillaryData(
          flagsL, msg.msg_control, byteEventsEnabled);
    }

    const auto prewriteRawBytesWritten = getRawBytesWritten();
    int msg_flags = sendMsgParamCallback_->getFlags(flagsL, zeroCopyEnabled_);
    auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);

    if (writeResult.writeReturn < 0 && zeroCopyEnabled_ && errno == ENOBUFS) {
      // Workaround for running with zero-copy enabled but with a memlock
      // limit (see `ulimit -l`) that is too small: disable zero-copy and
      // retry the send as a regular write.
      zeroCopyEnabled_ = false;
      zeroCopyReenableCounter_ = zeroCopyReenableThreshold_;
      msg_flags = sendMsgParamCallback_->getFlags(flagsL, zeroCopyEnabled_);
      writeResult = sendSocketMessage(fd_, &msg, msg_flags);
    }

    if (writeResult.writeReturn > 0 && byteEventsEnabled &&
        isSet(flagsL, WriteFlags::TIMESTAMP_WRITE)) {
      CHECK_GT(getRawBytesWritten(), prewriteRawBytesWritten); // sanity check
      ByteEvent byteEvent = {};
      byteEvent.type = ByteEvent::Type::WRITE;
      byteEvent.offset = getRawBytesWritten() - 1; // offset of last byte sent
      byteEvent.maybeRawBytesWritten = writeResult.writeReturn;
      byteEvent.maybeRawBytesTriedToWrite = 0;
      for (size_t i = 0; i < countL; ++i) {
        byteEvent.maybeRawBytesTriedToWrite.value() += vecL[i].iov_len;
      }
      byteEvent.maybeWriteFlags = flagsL;
      for (const auto& observer : lifecycleObservers_) {
        if (observer->getConfig().byteEvents) {
          observer->byteEvent(this, byteEvent);
        }
      }
    }

    return writeResult;
  };

  // Get PrewriteRequests (if any) and merge their flags with the write flags.
  const auto prewriteRequest = mergePrewriteRequests();
  auto mergedFlags = flags | prewriteRequest.writeFlagsToAdd |
      prewriteRequest.writeFlagsToAddAtOffset;

  // If there are no PrewriteRequests, or none requiring the write to be
  // split, proceed with a single send.
  if (!prewriteRequest.maybeOffsetToSplitWrite.has_value()) {
    return prepSendMsg(vec, count, mergedFlags);
  }

  // We need to split the write.
  // Add the CORK flag to inform the OS that more data is on the way.
  mergedFlags |= WriteFlags::CORK;

  // TODO(bschlinker): When prewrite splits a write, try to continue writing
  // after a write returns; this will improve efficiency.
  const auto splitWriteAtOffset = *prewriteRequest.maybeOffsetToSplitWrite;
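  // Illustrative arithmetic (assuming splitIovecArray() takes an inclusive
  // byte range): if getRawBytesWritten() is 1000 and the observer asked to
  // split at offset 1099, the calls below carve the first 100 bytes of this
  // write (absolute offsets [1000, 1099]) into a temporary iovec array,
  // which is then sent first with CORK set.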
  if (count <= kSmallIoVecSize) {
    // Suppress "warning: variable length array 'vec' is used [-Wvla]".
    FOLLY_PUSH_WARNING
    FOLLY_GNU_DISABLE_WARNING("-Wvla")
    iovec tmpVec[BOOST_PP_IF(FOLLY_HAVE_VLA_01, count, kSmallIoVecSize)];
    FOLLY_POP_WARNING

    size_t tmpVecCount = count;
    splitIovecArray(
        0,
        splitWriteAtOffset - getRawBytesWritten(),
        vec,
        count,
        tmpVec,
        tmpVecCount);
    return prepSendMsg(tmpVec, tmpVecCount, mergedFlags);
  } else {
    auto tmpVecPtr = std::make_unique<iovec[]>(count);
    auto tmpVec = tmpVecPtr.get();
    size_t tmpVecCount = count;
    splitIovecArray(
        0,
        splitWriteAtOffset - getRawBytesWritten(),
        vec,
        count,
        tmpVec,
        tmpVecCount);
    return prepSendMsg(tmpVec, tmpVecCount, mergedFlags);
  }
}

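/**
 * Low-level send, aware of the TCP Fast Open state machine.
 *
 * In the FAST_OPEN state the connect is folded into the first send (on
 * Linux, a sendmsg() with MSG_FASTOPEN).  The error codes below therefore
 * follow connect() semantics as well as send() semantics: EINPROGRESS and
 * EOPNOTSUPP fall back to a regular connect(), while EAGAIN means that no
 * free local ports were available rather than "would block".
 */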
AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
    NetworkSocket fd, struct msghdr* msg, int msg_flags) {
  ssize_t totalWritten = 0;
  SCOPE_EXIT {
    if (totalWritten > 0) {
      rawBytesWritten_ += totalWritten;
    }
  };
  if (state_ == StateEnum::FAST_OPEN) {
    sockaddr_storage addr;
    auto len = addr_.getAddress(&addr);
    msg->msg_name = &addr;
    msg->msg_namelen = len;
    totalWritten = tfoSendMsg(fd_, msg, msg_flags);
    if (totalWritten >= 0) {
      tfoFinished_ = true;
      state_ = StateEnum::ESTABLISHED;
      // We schedule this asynchronously so that we don't end up
      // invoking the initial read or write while a write is in progress.
      scheduleInitialReadWrite();
    } else if (errno == EINPROGRESS) {
      VLOG(4) << "TFO falling back to connecting";
      // A normal sendmsg() doesn't return EINPROGRESS; however, TFO might
      // fall back to connecting if there is no cookie.
      state_ = StateEnum::CONNECTING;
      try {
        scheduleConnectTimeout();
        registerForConnectEvents();
      } catch (const AsyncSocketException& ex) {
        return WriteResult(
            WRITE_ERROR, std::make_unique<AsyncSocketException>(ex));
      }
      // Pretend that no bytes were written and report EAGAIN to the caller.
      errno = EAGAIN;
      totalWritten = -1;
    } else if (errno == EOPNOTSUPP) {
      // Try falling back to connecting.
      VLOG(4) << "TFO not supported";
      state_ = StateEnum::CONNECTING;
      try {
        int ret = socketConnect((const sockaddr*)&addr, len);
        if (ret == 0) {
          // The connect succeeded immediately.
          // Treat this like no data was written.
          state_ = StateEnum::ESTABLISHED;
          scheduleInitialReadWrite();
        }
        // If there was no exception during the connect, report that no
        // bytes were written.
        errno = EAGAIN;
        totalWritten = -1;
      } catch (const AsyncSocketException& ex) {
        return WriteResult(
            WRITE_ERROR, std::make_unique<AsyncSocketException>(ex));
      }
    } else if (errno == EAGAIN) {
      // Normally sendmsg() would indicate that the write would block.
      // However, in the fast open case it indicates that sendmsg() fell back
      // to a connect; this is a return code from connect() instead, and is
      // an error condition indicating that no free local ports are available.
      return WriteResult(
          WRITE_ERROR,
          std::make_unique<AsyncSocketException>(
              AsyncSocketException::UNKNOWN, "No more free local ports"));
    }
  } else {
    totalWritten = netops_->sendmsg(fd, msg, msg_flags);
  }
  return WriteResult(totalWritten);
}

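/**
 * Write a vector of buffers and report how far the write got.
 *
 * On success, *countWritten is the number of iovec entries fully written and
 * *partialWritten is the number of bytes written from the next entry.
 * Worked example: with iovec lengths {100, 50, 200}, a sendmsg() return of
 * 120 yields *countWritten = 1 and *partialWritten = 20.
 */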
AsyncSocket::WriteResult AsyncSocket::performWrite(
    const iovec* vec,
    uint32_t count,
    WriteFlags flags,
    uint32_t* countWritten,
    uint32_t* partialWritten) {
  auto writeResult = sendSocketMessage(vec, count, flags);
  auto totalWritten = writeResult.writeReturn;
  if (totalWritten < 0) {
    bool tryAgain = (errno == EAGAIN);
#ifdef __APPLE__
    // Apple has a bug where doing a second write on a socket which we
    // have opened with TFO causes an ENOTCONN to be thrown.  However, the
    // socket is really connected, so treat ENOTCONN as an EAGAIN until
    // this bug is fixed.
    tryAgain |= (errno == ENOTCONN);
#endif

    if (!writeResult.exception && tryAgain) {
      // The TCP send buffer is full; we can't write any more data right now.
      *countWritten = 0;
      *partialWritten = 0;
      return WriteResult(0);
    }
    // A real error occurred.
    *countWritten = 0;
    *partialWritten = 0;
    return writeResult;
  }

  appBytesWritten_ += totalWritten;

  uint32_t bytesWritten;
  uint32_t n;
  for (bytesWritten = uint32_t(totalWritten), n = 0; n < count; ++n) {
    const iovec* v = vec + n;
    if (v->iov_len > bytesWritten) {
      // The partial write ended in the middle of this iovec.
      *countWritten = n;
      *partialWritten = bytesWritten;
      return WriteResult(totalWritten);
    }

    bytesWritten -= uint32_t(v->iov_len);
  }

  assert(bytesWritten == 0);
  *countWritten = n;
  *partialWritten = 0;
  return WriteResult(totalWritten);
}

/**
 * Re-register the EventHandler after eventFlags_ has changed.
 *
 * If an error occurs, fail() is called to move the socket into the error
 * state and call all currently installed callbacks.  After an error, the
 * AsyncSocket is completely unregistered.
 *
 * @return Returns true on success, or false on error.
 */
bool AsyncSocket::updateEventRegistration() {
  VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
          << ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
          << ", events=" << std::hex << eventFlags_;
  if (eventFlags_ == EventHandler::NONE) {
    if (ioHandler_.isHandlerRegistered()) {
      DCHECK(eventBase_ != nullptr);
      eventBase_->dcheckIsInEventBaseThread();
    }
    ioHandler_.unregisterHandler();
    return true;
  }

  eventBase_->dcheckIsInEventBaseThread();

  // Always register for persistent events, so we don't have to re-register
  // after being called back.
  if (!ioHandler_.registerHandler(
          uint16_t(eventFlags_ | EventHandler::PERSIST))) {
    eventFlags_ = EventHandler::NONE; // we're not registered after an error
    AsyncSocketException ex(
        AsyncSocketException::INTERNAL_ERROR,
        withAddr("failed to update AsyncSocket event registration"));
    fail("updateEventRegistration", ex);
    return false;
  }

  return true;
}

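// Convenience overload: set the bits in `enable`, clear the bits in
// `disable`, and re-register only if the flags actually changed.  For
// example, updateEventRegistration(EventHandler::READ, 0) turns read events
// on, and updateEventRegistration(0, EventHandler::READ) turns them off.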
bool AsyncSocket::updateEventRegistration(uint16_t enable, uint16_t disable) {
  uint16_t oldFlags = eventFlags_;
  eventFlags_ |= enable;
  eventFlags_ &= ~disable;
  if (eventFlags_ == oldFlags) {
    return true;
  } else {
    return updateEventRegistration();
  }
}

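/**
 * Failure handling is a two-phase protocol:
 *   - startFail() moves the socket into StateEnum::ERROR, sets
 *     SHUT_READ | SHUT_WRITE, unregisters I/O handlers, and closes the fd;
 *   - finishFail() delivers the error to every installed callback.
 * fail() and the fail*() helpers below wrap both phases, invoking the most
 * relevant callback (connect/read/write) in between.
 */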
void AsyncSocket::startFail() {
  // startFail() should only be called once.
  assert(state_ != StateEnum::ERROR);
  assert(getDestructorGuardCount() > 0);
  state_ = StateEnum::ERROR;
  // Ensure that SHUT_READ and SHUT_WRITE are set,
  // so all future attempts to read or write will be rejected.
  shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);

  // Cancel any scheduled immediate read.
  if (immediateReadHandler_.isLoopCallbackScheduled()) {
    immediateReadHandler_.cancelLoopCallback();
  }

  if (eventFlags_ != EventHandler::NONE) {
    eventFlags_ = EventHandler::NONE;
    ioHandler_.unregisterHandler();
  }
  writeTimeout_.cancelTimeout();

  if (fd_ != NetworkSocket()) {
    ioHandler_.changeHandlerFD(NetworkSocket());
    doClose();
  }
}

void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) {
  invokeConnectErr(ex);
  failAllWrites(ex);

  if (readCallback_) {
    ReadCallback* callback = readCallback_;
    readCallback_ = nullptr;
    callback->readErr(ex);
  }
}

void AsyncSocket::finishFail() {
  assert(state_ == StateEnum::ERROR);
  assert(getDestructorGuardCount() > 0);

  AsyncSocketException ex(
      AsyncSocketException::INTERNAL_ERROR,
      withAddr("socket closing after error"));
  invokeAllErrors(ex);
}

void AsyncSocket::finishFail(const AsyncSocketException& ex) {
  assert(state_ == StateEnum::ERROR);
  assert(getDestructorGuardCount() > 0);
  invokeAllErrors(ex);
}

void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
  VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << ", state=" << state_ << " host=" << addr_.describe()
          << "): failed in " << fn << "(): " << ex.what();
  startFail();
  finishFail(ex);
}

void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
  VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << ", state=" << state_ << " host=" << addr_.describe()
          << "): failed while connecting in " << fn << "(): " << ex.what();
  startFail();

  invokeConnectErr(ex);
  finishFail(ex);
}

void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
  VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << ", state=" << state_ << " host=" << addr_.describe()
          << "): failed while reading in " << fn << "(): " << ex.what();
  startFail();

  if (readCallback_ != nullptr) {
    ReadCallback* callback = readCallback_;
    readCallback_ = nullptr;
    callback->readErr(ex);
  }

  finishFail(ex);
}

void AsyncSocket::failErrMessageRead(
    const char* fn, const AsyncSocketException& ex) {
  VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << ", state=" << state_ << " host=" << addr_.describe()
          << "): failed while reading message in " << fn << "(): " << ex.what();
  startFail();

  if (errMessageCallback_ != nullptr) {
    ErrMessageCallback* callback = errMessageCallback_;
    errMessageCallback_ = nullptr;
    callback->errMessageError(ex);
  }

  finishFail(ex);
}

void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
  VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << ", state=" << state_ << " host=" << addr_.describe()
          << "): failed while writing in " << fn << "(): " << ex.what();
  startFail();

  // Only invoke the first write callback, since the error occurred while
  // writing this request.  Let any other pending write callbacks be invoked
  // in finishFail().
  if (writeReqHead_ != nullptr) {
    WriteRequest* req = writeReqHead_;
    writeReqHead_ = req->getNext();
    WriteCallback* callback = req->getCallback();
    uint32_t bytesWritten = req->getTotalBytesWritten();
    req->destroy();
    if (callback) {
      callback->writeErr(bytesWritten, ex);
    }
  }

  finishFail(ex);
}

void AsyncSocket::failWrite(
    const char* fn,
    WriteCallback* callback,
    size_t bytesWritten,
    const AsyncSocketException& ex) {
  // This version of failWrite() is used when the failure occurs before
  // we've added the callback to writeReqHead_.
  VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << ", state=" << state_ << " host=" << addr_.describe()
          << "): failed while writing in " << fn << "(): " << ex.what();
  if (closeOnFailedWrite_) {
    startFail();
  }

  if (callback != nullptr) {
    callback->writeErr(bytesWritten, ex);
  }

  if (closeOnFailedWrite_) {
    finishFail(ex);
  }
}

void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
  // Invoke writeErr() on all write callbacks.
  // This is used when writes are forcibly shut down with write requests
  // pending, or when an error occurs with writes pending.
  while (writeReqHead_ != nullptr) {
    WriteRequest* req = writeReqHead_;
    writeReqHead_ = req->getNext();
    WriteCallback* callback = req->getCallback();
    if (callback) {
      callback->writeErr(req->getTotalBytesWritten(), ex);
    }
    req->destroy();
  }

  // All pending writes have failed; reset totalAppBytesScheduledForWrite_.
  totalAppBytesScheduledForWrite_ = appBytesWritten_;
}

void AsyncSocket::failByteEvents(const AsyncSocketException& ex) {
  CHECK(byteEventHelper_) << "failByteEvents called without ByteEventHelper";
  byteEventHelper_->maybeEx = ex;
  // Inform any observers that want ByteEvents.
  for (const auto& observer : lifecycleObservers_) {
    if (observer->getConfig().byteEvents) {
      observer->byteEventsUnavailable(this, ex);
    }
  }
}

void AsyncSocket::invalidState(ConnectCallback* callback) {
  VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << "): connect() called in invalid state " << state_;

  /*
   * The invalidState() methods don't use the normal failure mechanisms,
   * since we don't know what state we are in.  We don't want to call
   * startFail()/finishFail() recursively if we are already in the middle of
   * cleaning up.
   */

  AsyncSocketException ex(
      AsyncSocketException::ALREADY_OPEN,
      "connect() called with socket in invalid state");
  connectEndTime_ = std::chrono::steady_clock::now();
  if ((state_ == StateEnum::CONNECTING) || (state_ == StateEnum::ERROR)) {
    for (const auto& cb : lifecycleObservers_) {
      if (auto observer = dynamic_cast<AsyncSocket::LifecycleObserver*>(cb)) {
        // Inform any lifecycle observers that the connection failed.
        observer->connectError(this, ex);
      }
    }
  }
  if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
    if (callback) {
      callback->connectErr(ex);
    }
  } else {
    // We can't use failConnect() here since connectCallback_ may already be
    // set to another callback.  Invoke this ConnectCallback here; any other
    // connectCallback_ will be invoked in finishFail().
    startFail();
    if (callback) {
      callback->connectErr(ex);
    }
    finishFail(ex);
  }
}

void AsyncSocket::invalidState(ErrMessageCallback* callback) {
  VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << "): setErrMessageCB(" << callback << ") called in invalid state "
          << state_;

  AsyncSocketException ex(
      AsyncSocketException::NOT_OPEN,
      msgErrQueueSupported
          ? "setErrMessageCB() called with socket in invalid state"
          : "This platform does not support socket error message notifications");
  if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
    if (callback) {
      callback->errMessageError(ex);
    }
  } else {
    startFail();
    if (callback) {
      callback->errMessageError(ex);
    }
    finishFail(ex);
  }
}

void AsyncSocket::invokeConnectErr(const AsyncSocketException& ex) {
  VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << "): connect err invoked with ex: " << ex.what();
  connectEndTime_ = std::chrono::steady_clock::now();
  if ((state_ == StateEnum::CONNECTING) || (state_ == StateEnum::ERROR)) {
    // invokeConnectErr() can also be reached in states {FAST_OPEN, CLOSED,
    // ESTABLISHED} (!?) and from a bunch of other places that are not what
    // this callback expects.  This seems like a bug; work around it here
    // while we explore it independently.
    for (const auto& cb : lifecycleObservers_) {
      cb->connectError(this, ex);
    }
  }
  if (connectCallback_) {
    ConnectCallback* callback = connectCallback_;
    connectCallback_ = nullptr;
    callback->connectErr(ex);
  }
}

void AsyncSocket::invokeConnectSuccess() {
  VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << "): connect success invoked";
  connectEndTime_ = std::chrono::steady_clock::now();
  bool enableByteEventsForObserver = false;
  for (const auto& cb : lifecycleObservers_) {
    cb->connectSuccess(this);
    enableByteEventsForObserver |= ((cb->getConfig().byteEvents) ? 1 : 0);
  }
  if (enableByteEventsForObserver) {
    enableByteEvents();
  }
  if (connectCallback_) {
    ConnectCallback* callback = connectCallback_;
    connectCallback_ = nullptr;
    callback->connectSuccess();
  }
}

void AsyncSocket::invokeConnectAttempt() {
  VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << "): connect attempt";
  for (const auto& cb : lifecycleObservers_) {
    cb->connectAttempt(this);
  }
}

void AsyncSocket::invalidState(ReadCallback* callback) {
  VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << "): setReadCallback(" << callback << ") called in invalid state "
          << state_;

  AsyncSocketException ex(
      AsyncSocketException::NOT_OPEN,
      "setReadCallback() called with socket in invalid state");
  if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
    if (callback) {
      callback->readErr(ex);
    }
  } else {
    startFail();
    if (callback) {
      callback->readErr(ex);
    }
    finishFail(ex);
  }
}

void AsyncSocket::invalidState(WriteCallback* callback) {
  VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
          << "): write() called in invalid state " << state_;

  AsyncSocketException ex(
      AsyncSocketException::NOT_OPEN,
      withAddr("write() called with socket in invalid state"));
  if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
    if (callback) {
      callback->writeErr(0, ex);
    }
  } else {
    startFail();
    if (callback) {
      callback->writeErr(0, ex);
    }
    finishFail(ex);
  }
}

void AsyncSocket::doClose() {
  for (const auto& cb : lifecycleObservers_) {
    cb->close(this);
  }
  if (fd_ == NetworkSocket()) {
    return;
  }
  if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
    shutdownSocketSet->close(fd_);
  } else {
    netops_->close(fd_);
  }
  fd_ = NetworkSocket();

  // We also want to clear the zero-copy maps if the fd has been closed.
  idZeroCopyBufPtrMap_.clear();
  idZeroCopyBufInfoMap_.clear();
}

std::ostream& operator<<(
    std::ostream& os, const AsyncSocket::StateEnum& state) {
  os << static_cast<int>(state);
  return os;
}

std::string AsyncSocket::withAddr(folly::StringPiece s) {
  // Don't use addr_ directly because it may not be initialized,
  // e.g. if the socket was constructed from an existing fd.
  folly::SocketAddress peer, local;
  try {
    getLocalAddress(&local);
  } catch (...) {
    // ignore
  }
  try {
    getPeerAddress(&peer);
  } catch (...) {
    // ignore
  }

  return fmt::format(
      "{} (peer={}{})",
      s,
      peer.describe(),
      kIsMobile ? "" : fmt::format(", local={}", local.describe()));
}

void AsyncSocket::setBufferCallback(BufferCallback* cb) {
  bufferCallback_ = cb;
}

} // namespace folly