1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <folly/io/async/AsyncSocket.h>
18
19 #include <sys/types.h>
20
21 #include <cerrno>
22 #include <climits>
23 #include <sstream>
24 #include <thread>
25
26 #include <boost/preprocessor/control/if.hpp>
27
28 #include <folly/Exception.h>
29 #include <folly/ExceptionWrapper.h>
30 #include <folly/Format.h>
31 #include <folly/Portability.h>
32 #include <folly/SocketAddress.h>
33 #include <folly/String.h>
34 #include <folly/io/Cursor.h>
35 #include <folly/io/IOBuf.h>
36 #include <folly/io/IOBufQueue.h>
37 #include <folly/io/SocketOptionMap.h>
38 #include <folly/portability/Fcntl.h>
39 #include <folly/portability/Sockets.h>
40 #include <folly/portability/SysUio.h>
41 #include <folly/portability/Unistd.h>
42
43 #if defined(__linux__)
44 #include <linux/if_packet.h>
45 #include <linux/sockios.h>
46 #include <sys/ioctl.h>
47 #endif
48
49 #if FOLLY_HAVE_VLA
50 #define FOLLY_HAVE_VLA_01 1
51 #else
52 #define FOLLY_HAVE_VLA_01 0
53 #endif
54
55 using std::string;
56 using std::unique_ptr;
57
58 namespace fsp = folly::portability::sockets;
59
60 namespace folly {
61
// Compile-time flag: true exactly when FOLLY_HAVE_MSG_ERRQUEUE is defined
// for this build.
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
static constexpr bool msgErrQueueSupported = true;
#else
static constexpr bool msgErrQueueSupported = false;
#endif // FOLLY_HAVE_MSG_ERRQUEUE
68
getSocketClosedLocallyEx()69 static AsyncSocketException const& getSocketClosedLocallyEx() {
70 static auto& ex = *new AsyncSocketException(
71 AsyncSocketException::END_OF_FILE, "socket closed locally");
72 return ex;
73 }
74
getSocketShutdownForWritesEx()75 static AsyncSocketException const& getSocketShutdownForWritesEx() {
76 static auto& ex = *new AsyncSocketException(
77 AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
78 return ex;
79 }
80
namespace {
#if FOLLY_HAVE_SO_TIMESTAMPING
// Returns the sock_extended_err payload of `cmsg` when it is an IP_RECVERR,
// IPV6_RECVERR or PACKET_TX_TIMESTAMP control message; nullptr otherwise.
const sock_extended_err* FOLLY_NULLABLE
cmsgToSockExtendedErr(const cmsghdr& cmsg) {
  const bool isIpRecvErr =
      cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR;
  const bool isIpv6RecvErr =
      cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR;
  const bool isPacketTxTimestamp =
      cmsg.cmsg_level == SOL_PACKET && cmsg.cmsg_type == PACKET_TX_TIMESTAMP;
  if (!isIpRecvErr && !isIpv6RecvErr && !isPacketTxTimestamp) {
    return nullptr;
  }
  return reinterpret_cast<const sock_extended_err*>(CMSG_DATA(&cmsg));
}

// Narrows cmsgToSockExtendedErr() to timestamping notifications only:
// ee_errno must be ENOMSG and ee_origin SO_EE_ORIGIN_TIMESTAMPING.
const sock_extended_err* FOLLY_NULLABLE
cmsgToSockExtendedErrTimestamping(const cmsghdr& cmsg) {
  const sock_extended_err* const err = cmsgToSockExtendedErr(cmsg);
  const bool isTimestamping = err != nullptr && err->ee_errno == ENOMSG &&
      err->ee_origin == SO_EE_ORIGIN_TIMESTAMPING;
  return isTimestamping ? err : nullptr;
}

// Returns the scm_timestamping payload of a SOL_SOCKET / SCM_TIMESTAMPING
// control message; nullptr for any other message.
const scm_timestamping* FOLLY_NULLABLE
cmsgToScmTimestamping(const cmsghdr& cmsg) {
  const bool isScmTimestamping =
      cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_TIMESTAMPING;
  return isScmTimestamping
      ? reinterpret_cast<const struct scm_timestamping*>(CMSG_DATA(&cmsg))
      : nullptr;
}

#endif // FOLLY_HAVE_SO_TIMESTAMPING
} // namespace
117
118 // TODO: It might help performance to provide a version of BytesWriteRequest
119 // that users could derive from, so we can avoid the extra allocation for each
120 // call to write()/writev().
121 //
122 // We would need the version for external users where they provide the iovec
123 // storage space, and only our internal version would allocate it at the end of
124 // the WriteRequest.
125
126 /* The default WriteRequest implementation, used for write(), writev() and
127 * writeChain()
128 *
129 * A new BytesWriteRequest operation is allocated on the heap for all write
130 * operations that cannot be completed immediately.
131 */
132 class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
133 public:
newRequest(AsyncSocket * socket,WriteCallback * callback,const iovec * ops,uint32_t opCount,uint32_t partialWritten,uint32_t bytesWritten,unique_ptr<IOBuf> && ioBuf,WriteFlags flags)134 static BytesWriteRequest* newRequest(
135 AsyncSocket* socket,
136 WriteCallback* callback,
137 const iovec* ops,
138 uint32_t opCount,
139 uint32_t partialWritten,
140 uint32_t bytesWritten,
141 unique_ptr<IOBuf>&& ioBuf,
142 WriteFlags flags) {
143 assert(opCount > 0);
144 // Since we put a variable size iovec array at the end
145 // of each BytesWriteRequest, we have to manually allocate the memory.
146 void* buf =
147 malloc(sizeof(BytesWriteRequest) + (opCount * sizeof(struct iovec)));
148 if (buf == nullptr) {
149 throw std::bad_alloc();
150 }
151
152 return new (buf) BytesWriteRequest(
153 socket,
154 callback,
155 ops,
156 opCount,
157 partialWritten,
158 bytesWritten,
159 std::move(ioBuf),
160 flags);
161 }
162
destroy()163 void destroy() override {
164 socket_->releaseIOBuf(std::move(ioBuf_), releaseIOBufCallback_);
165 this->~BytesWriteRequest();
166 free(this);
167 }
168
performWrite()169 WriteResult performWrite() override {
170 WriteFlags writeFlags = flags_;
171 if (getNext() != nullptr) {
172 writeFlags |= WriteFlags::CORK;
173 }
174
175 socket_->adjustZeroCopyFlags(writeFlags);
176
177 auto writeResult = socket_->performWrite(
178 getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
179 bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
180 if (bytesWritten_) {
181 if (socket_->isZeroCopyRequest(writeFlags)) {
182 if (isComplete()) {
183 socket_->addZeroCopyBuf(std::move(ioBuf_), releaseIOBufCallback_);
184 } else {
185 socket_->addZeroCopyBuf(ioBuf_.get());
186 }
187 } else {
188 // this happens if at least one of the prev requests were sent
189 // with zero copy but not the last one
190 if (isComplete() && zeroCopyRequest_ &&
191 socket_->containsZeroCopyBuf(ioBuf_.get())) {
192 socket_->setZeroCopyBuf(std::move(ioBuf_), releaseIOBufCallback_);
193 }
194 }
195 }
196 return writeResult;
197 }
198
isComplete()199 bool isComplete() override { return opsWritten_ == getOpCount(); }
200
consume()201 void consume() override {
202 // Advance opIndex_ forward by opsWritten_
203 opIndex_ += opsWritten_;
204 assert(opIndex_ < opCount_);
205
206 bool zeroCopyReq = socket_->isZeroCopyRequest(flags_);
207 if (zeroCopyReq) {
208 zeroCopyRequest_ = true;
209 }
210
211 if (!zeroCopyRequest_) {
212 // If we've finished writing any IOBufs, release them
213 // but only if we did not send any of them via zerocopy
214 if (ioBuf_) {
215 for (uint32_t i = opsWritten_; i != 0; --i) {
216 assert(ioBuf_);
217 auto next = ioBuf_->pop();
218 socket_->releaseIOBuf(std::move(ioBuf_), releaseIOBufCallback_);
219 ioBuf_ = std::move(next);
220 }
221 }
222 }
223
224 // Move partialBytes_ forward into the current iovec buffer
225 struct iovec* currentOp = writeOps_ + opIndex_;
226 assert((partialBytes_ < currentOp->iov_len) || (currentOp->iov_len == 0));
227 currentOp->iov_base =
228 reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes_;
229 currentOp->iov_len -= partialBytes_;
230
231 // Increment the totalBytesWritten_ count by bytesWritten_;
232 assert(bytesWritten_ >= 0);
233 totalBytesWritten_ += uint32_t(bytesWritten_);
234 }
235
236 private:
BytesWriteRequest(AsyncSocket * socket,WriteCallback * callback,const struct iovec * ops,uint32_t opCount,uint32_t partialBytes,uint32_t bytesWritten,unique_ptr<IOBuf> && ioBuf,WriteFlags flags)237 BytesWriteRequest(
238 AsyncSocket* socket,
239 WriteCallback* callback,
240 const struct iovec* ops,
241 uint32_t opCount,
242 uint32_t partialBytes,
243 uint32_t bytesWritten,
244 unique_ptr<IOBuf>&& ioBuf,
245 WriteFlags flags)
246 : AsyncSocket::WriteRequest(socket, callback),
247 opCount_(opCount),
248 opIndex_(0),
249 flags_(flags),
250 ioBuf_(std::move(ioBuf)),
251 opsWritten_(0),
252 partialBytes_(partialBytes),
253 bytesWritten_(bytesWritten) {
254 memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
255 zeroCopyRequest_ = socket_->isZeroCopyRequest(flags_);
256 }
257
258 // private destructor, to ensure callers use destroy()
259 ~BytesWriteRequest() override = default;
260
getOps() const261 const struct iovec* getOps() const {
262 assert(opCount_ > opIndex_);
263 return writeOps_ + opIndex_;
264 }
265
getOpCount() const266 uint32_t getOpCount() const {
267 assert(opCount_ > opIndex_);
268 return opCount_ - opIndex_;
269 }
270
271 uint32_t opCount_; ///< number of entries in writeOps_
272 uint32_t opIndex_; ///< current index into writeOps_
273 WriteFlags flags_; ///< set for WriteFlags
274 bool zeroCopyRequest_{
275 false}; ///< if we sent any part of the ioBuf_ with zerocopy
276 unique_ptr<IOBuf> ioBuf_; ///< underlying IOBuf, or nullptr if N/A
277
278 // for consume(), how much we wrote on the last write
279 uint32_t opsWritten_; ///< complete ops written
280 uint32_t partialBytes_; ///< partial bytes of incomplete op written
281 ssize_t bytesWritten_; ///< bytes written altogether
282
283 struct iovec writeOps_[]; ///< write operation(s) list
284 };
285
getDefaultFlags(folly::WriteFlags flags,bool zeroCopyEnabled)286 int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(
287 folly::WriteFlags flags, bool zeroCopyEnabled) noexcept {
288 int msg_flags = MSG_DONTWAIT;
289
290 #ifdef MSG_NOSIGNAL // Linux-only
291 msg_flags |= MSG_NOSIGNAL;
292 #ifdef MSG_MORE
293 if (isSet(flags, WriteFlags::CORK)) {
294 // MSG_MORE tells the kernel we have more data to send, so wait for us to
295 // give it the rest of the data rather than immediately sending a partial
296 // frame, even when TCP_NODELAY is enabled.
297 msg_flags |= MSG_MORE;
298 }
299 #endif // MSG_MORE
300 #endif // MSG_NOSIGNAL
301 if (isSet(flags, WriteFlags::EOR)) {
302 // marks that this is the last byte of a record (response)
303 msg_flags |= MSG_EOR;
304 }
305
306 if (zeroCopyEnabled && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY)) {
307 msg_flags |= MSG_ZEROCOPY;
308 }
309
310 return msg_flags;
311 }
312
getAncillaryData(folly::WriteFlags flags,void * data,const bool byteEventsEnabled)313 void AsyncSocket::SendMsgParamsCallback::getAncillaryData(
314 folly::WriteFlags flags,
315 void* data,
316 const bool byteEventsEnabled) noexcept {
317 auto ancillaryDataSize = getAncillaryDataSize(flags, byteEventsEnabled);
318 if (!ancillaryDataSize) {
319 return;
320 }
321 #if FOLLY_HAVE_SO_TIMESTAMPING
322 CHECK_NOTNULL(data);
323 // this function only handles ancillary data for timestamping
324 //
325 // if getAncillaryDataSize() is overridden and returning a size different
326 // than what we expect, then this function needs to be overridden too, in
327 // order to avoid conflict with how cmsg / msg are written
328 CHECK_EQ(CMSG_LEN(sizeof(uint32_t)), ancillaryDataSize);
329
330 uint32_t sofFlags = 0;
331 if (byteEventsEnabled && isSet(flags, WriteFlags::TIMESTAMP_TX)) {
332 sofFlags = sofFlags | folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE;
333 }
334 if (byteEventsEnabled && isSet(flags, WriteFlags::TIMESTAMP_ACK)) {
335 sofFlags = sofFlags | folly::netops::SOF_TIMESTAMPING_TX_ACK;
336 }
337 if (byteEventsEnabled && isSet(flags, WriteFlags::TIMESTAMP_SCHED)) {
338 sofFlags = sofFlags | folly::netops::SOF_TIMESTAMPING_TX_SCHED;
339 }
340
341 msghdr msg;
342 msg.msg_control = data;
343 msg.msg_controllen = ancillaryDataSize;
344 struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
345 CHECK_NOTNULL(cmsg);
346 cmsg->cmsg_level = SOL_SOCKET;
347 cmsg->cmsg_type = SO_TIMESTAMPING;
348 cmsg->cmsg_len = CMSG_LEN(sizeof(uint32_t));
349 memcpy(CMSG_DATA(cmsg), &sofFlags, sizeof(sofFlags));
350 #else
351 (void)data;
352 #endif // FOLLY_HAVE_SO_TIMESTAMPING
353 return;
354 }
355
getAncillaryDataSize(folly::WriteFlags flags,const bool byteEventsEnabled)356 uint32_t AsyncSocket::SendMsgParamsCallback::getAncillaryDataSize(
357 folly::WriteFlags flags, const bool byteEventsEnabled) noexcept {
358 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
359 if (WriteFlags::NONE != (flags & kWriteFlagsForTimestamping) &&
360 byteEventsEnabled) {
361 return CMSG_LEN(sizeof(uint32_t));
362 }
363 #else
364 (void)flags;
365 (void)byteEventsEnabled;
366 #endif
367 return 0;
368 }
369
370 folly::Optional<AsyncSocket::ByteEvent>
processCmsg(const cmsghdr & cmsg,const size_t rawBytesWritten)371 AsyncSocket::ByteEventHelper::processCmsg(
372 const cmsghdr& cmsg, const size_t rawBytesWritten) {
373 #if FOLLY_HAVE_SO_TIMESTAMPING
374 if (!byteEventsEnabled || maybeEx.has_value()) {
375 return folly::none;
376 }
377 if (!maybeTsState_.has_value()) {
378 maybeTsState_ = TimestampState();
379 }
380 auto& state = maybeTsState_.value();
381 if (auto serrTs = cmsgToSockExtendedErrTimestamping(cmsg)) {
382 if (state.serrReceived) {
383 // already have this part of the message pending
384 throw Exception("already have serr event");
385 }
386 state.serrReceived = true;
387 state.typeRaw = serrTs->ee_info;
388 state.byteOffsetKernel = serrTs->ee_data;
389 } else if (auto scmTs = cmsgToScmTimestamping(cmsg)) {
390 if (state.scmTsReceived) {
391 throw Exception("already have scmTs event");
392 }
393 state.scmTsReceived = true;
394
395 auto timespecToDuration =
396 [](const timespec& ts) -> folly::Optional<std::chrono::nanoseconds> {
397 std::chrono::nanoseconds duration = std::chrono::seconds(ts.tv_sec) +
398 std::chrono::nanoseconds(ts.tv_nsec);
399 if (duration == duration.zero()) {
400 return folly::none;
401 }
402 return duration;
403 };
404 // ts[0] -> software timestamp
405 // ts[1] -> hardware timestamp transformed to userspace time (deprecated)
406 // ts[2] -> hardware timestamp
407 state.maybeSoftwareTs = timespecToDuration(scmTs->ts[0]);
408 state.maybeHardwareTs = timespecToDuration(scmTs->ts[2]);
409 }
410
411 // if we have both components needed for a complete timestamp, build it
412 if (state.serrReceived && state.scmTsReceived) {
413 // cleanup state so that we're ready for next timestamp
414 TimestampState completeState = state;
415 maybeTsState_ = folly::none;
416
417 // map the type
418 folly::Optional<ByteEvent::Type> tsType;
419 switch (completeState.typeRaw) {
420 case folly::netops::SCM_TSTAMP_SND: {
421 tsType = ByteEvent::Type::TX;
422 break;
423 }
424 case folly::netops::SCM_TSTAMP_ACK: {
425 tsType = ByteEvent::Type::ACK;
426 break;
427 }
428 case folly::netops::SCM_TSTAMP_SCHED: {
429 tsType = ByteEvent::Type::SCHED;
430 break;
431 }
432 default:
433 break; // unknown, maybe something new
434 }
435 if (!tsType) {
436 // it's a timestamp, but not one that we're set up to handle
437 // we've cleared our state, loop back around
438 return folly::none;
439 }
440
441 // Calculate the byte offset.
442 //
443 // See documentation for SOF_TIMESTAMPING_OPT_ID for details.
444 //
445 // In summary, two things we have to consider:
446 //
447 // (1) The byte stream offset is relative:
448 // Socket timestamps include the byte stream offset for which the
449 // timestamp applies. There may have been bytes transferred before the
450 // fd was controlled by AsyncSocket. As a result, we don't know the
451 // socket byte stream offset when we enable timestamping.
452 //
453 // To get around this, we set SOF_TIMESTAMPING_OPT_ID when we enable
454 // timestamping via setsockopt. This flag causes the kernel to reset
455 // the offset it uses for timestamps to 0. This allows us to determine
456 // an offset relative to the number of bytes that had been written to
457 // the socket since timestamps were enabled.
458 //
459 // Note that offsets begin at zero; if only a single byte is written
460 // after timestamping is enabled, the offset included in the kernel
461 // cmsg will be 0.
462 //
463 // (2) The byte stream offset is a uint32_t:
464 // Because the kernel uses a uint32_t to store and communicate the
465 // byte stream offset, the offset will wrap every ~4GB. When we get a
466 // timestamp, we need to figure out which byte it is for. We assume
467 // that there will never be more than ~4GB of bytes sent between us
468 // requesting timestamping for a byte and receiving the timestamp;
469 // this is a realistic assumption given CWND and TCP buffer sizes. We
470 // then calculate assuming that the counter has not wrapped since we
471 // sent the byte that we are getting the timestamp for. If the counter
472 // has wrapped, we detect it, and go back one position.
473 const uint64_t bytesPerOffsetWrap =
474 static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1;
475 size_t byteOffset = rawBytesWritten -
476 (rawBytesWritten % bytesPerOffsetWrap) +
477 completeState.byteOffsetKernel + rawBytesWrittenWhenByteEventsEnabled;
478 if (byteOffset > rawBytesWritten) {
479 // kernel's uint32_t var wrapped around; go back one wrap
480 CHECK_GE(byteOffset, bytesPerOffsetWrap);
481 byteOffset = byteOffset - bytesPerOffsetWrap;
482 }
483
484 ByteEvent event = {};
485 event.type = tsType.value();
486 event.offset = byteOffset;
487 event.maybeSoftwareTs = state.maybeSoftwareTs;
488 event.maybeHardwareTs = state.maybeHardwareTs;
489 return event;
490 }
491 #else
492 (void)cmsg;
493 (void)rawBytesWritten;
494 #endif // FOLLY_HAVE_SO_TIMESTAMPING
495 return folly::none;
496 }
497
498 namespace {
499 AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;
500
501 // Based on flags, signal the transparent handler to disable certain functions
disableTransparentFunctions(NetworkSocket fd,bool noTransparentTls,bool noTSocks)502 void disableTransparentFunctions(
503 NetworkSocket fd, bool noTransparentTls, bool noTSocks) {
504 (void)fd;
505 (void)noTransparentTls;
506 (void)noTSocks;
507 #if defined(__linux__)
508 if (noTransparentTls) {
509 // Ignore return value, errors are ok
510 VLOG(5) << "Disabling TTLS for fd " << fd;
511 netops::setsockopt(fd, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
512 }
513 if (noTSocks) {
514 VLOG(5) << "Disabling TSOCKS for fd " << fd;
515 // Ignore return value, errors are ok
516 netops::setsockopt(fd, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
517 }
518 #endif
519 }
520
521 constexpr size_t kSmallIoVecSize = 64;
522
523 } // namespace
524
AsyncSocket()525 AsyncSocket::AsyncSocket()
526 : eventBase_(nullptr),
527 writeTimeout_(this, nullptr),
528 ioHandler_(this, nullptr),
529 immediateReadHandler_(this) {
530 VLOG(5) << "new AsyncSocket()";
531 init();
532 }
533
AsyncSocket(EventBase * evb)534 AsyncSocket::AsyncSocket(EventBase* evb)
535 : eventBase_(evb),
536 writeTimeout_(this, evb),
537 ioHandler_(this, evb),
538 immediateReadHandler_(this) {
539 VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
540 init();
541 }
542
AsyncSocket(EventBase * evb,const folly::SocketAddress & address,uint32_t connectTimeout,bool useZeroCopy)543 AsyncSocket::AsyncSocket(
544 EventBase* evb,
545 const folly::SocketAddress& address,
546 uint32_t connectTimeout,
547 bool useZeroCopy)
548 : AsyncSocket(evb) {
549 setZeroCopy(useZeroCopy);
550 connect(nullptr, address, connectTimeout);
551 }
552
AsyncSocket(EventBase * evb,const std::string & ip,uint16_t port,uint32_t connectTimeout,bool useZeroCopy)553 AsyncSocket::AsyncSocket(
554 EventBase* evb,
555 const std::string& ip,
556 uint16_t port,
557 uint32_t connectTimeout,
558 bool useZeroCopy)
559 : AsyncSocket(evb) {
560 setZeroCopy(useZeroCopy);
561 connect(nullptr, ip, port, connectTimeout);
562 }
563
AsyncSocket(EventBase * evb,NetworkSocket fd,uint32_t zeroCopyBufId,const SocketAddress * peerAddress)564 AsyncSocket::AsyncSocket(
565 EventBase* evb,
566 NetworkSocket fd,
567 uint32_t zeroCopyBufId,
568 const SocketAddress* peerAddress)
569 : zeroCopyBufId_(zeroCopyBufId),
570 state_(StateEnum::ESTABLISHED),
571 fd_(fd),
572 addr_(peerAddress ? *peerAddress : folly::SocketAddress()),
573 eventBase_(evb),
574 writeTimeout_(this, evb),
575 ioHandler_(this, evb, fd),
576 immediateReadHandler_(this) {
577 VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd=" << fd
578 << ", zeroCopyBufId=" << zeroCopyBufId << ")";
579 init();
580 disableTransparentFunctions(fd_, noTransparentTls_, noTSocks_);
581 setCloseOnExec();
582 }
583
AsyncSocket(AsyncSocket * oldAsyncSocket)584 AsyncSocket::AsyncSocket(AsyncSocket* oldAsyncSocket)
585 : zeroCopyBufId_(oldAsyncSocket->getZeroCopyBufId()),
586 state_(oldAsyncSocket->state_),
587 fd_(oldAsyncSocket->detachNetworkSocket()),
588 addr_(oldAsyncSocket->addr_),
589 eventBase_(oldAsyncSocket->getEventBase()),
590 writeTimeout_(this, eventBase_),
591 ioHandler_(this, eventBase_, fd_),
592 immediateReadHandler_(this),
593 appBytesWritten_(oldAsyncSocket->appBytesWritten_),
594 rawBytesWritten_(oldAsyncSocket->rawBytesWritten_),
595 preReceivedData_(std::move(oldAsyncSocket->preReceivedData_)),
596 byteEventHelper_(std::move(oldAsyncSocket->byteEventHelper_)) {
597 VLOG(5) << "move AsyncSocket(" << oldAsyncSocket << "->" << this
598 << ", evb=" << eventBase_ << ", fd=" << fd_
599 << ", zeroCopyBufId=" << zeroCopyBufId_ << ")";
600 init();
601 disableTransparentFunctions(fd_, noTransparentTls_, noTSocks_);
602 setCloseOnExec();
603
604 // inform lifecycle observers to give them an opportunity to unsubscribe from
605 // events for the old socket and subscribe to the new socket; we do not move
606 // the subscription ourselves
607 for (const auto& cb : oldAsyncSocket->lifecycleObservers_) {
608 // only available for observers derived from AsyncSocket::LifecycleObserver
609 if (auto dCb = dynamic_cast<AsyncSocket::LifecycleObserver*>(cb)) {
610 dCb->move(oldAsyncSocket, this);
611 }
612 }
613 }
614
AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket)615 AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket)
616 : AsyncSocket(oldAsyncSocket.get()) {}
617
// init(): member setup shared by every constructor. Constructors that do not
// delegate to another constructor call it directly.
init()620 void AsyncSocket::init() {
621 if (eventBase_) {
622 eventBase_->dcheckIsInEventBaseThread();
623 }
624 eventFlags_ = EventHandler::NONE;
625 sendTimeout_ = 0;
626 maxReadsPerEvent_ = 16;
627 connectCallback_ = nullptr;
628 errMessageCallback_ = nullptr;
629 readAncillaryDataCallback_ = nullptr;
630 readCallback_ = nullptr;
631 writeReqHead_ = nullptr;
632 writeReqTail_ = nullptr;
633 wShutdownSocketSet_.reset();
634 appBytesReceived_ = 0;
635 totalAppBytesScheduledForWrite_ = 0;
636 sendMsgParamCallback_ = &defaultSendMsgParamsCallback;
637 }
638
~AsyncSocket()639 AsyncSocket::~AsyncSocket() {
640 VLOG(7) << "actual destruction of AsyncSocket(this=" << this
641 << ", evb=" << eventBase_ << ", fd=" << fd_ << ", state=" << state_
642 << ")";
643 for (const auto& cb : lifecycleObservers_) {
644 cb->destroy(this);
645 }
646 DCHECK_EQ(allocatedBytesBuffered_, 0);
647 }
648
destroy()649 void AsyncSocket::destroy() {
650 VLOG(5) << "AsyncSocket::destroy(this=" << this << ", evb=" << eventBase_
651 << ", fd=" << fd_ << ", state=" << state_;
652 // When destroy is called, close the socket immediately
653 closeNow();
654
655 // Then call DelayedDestruction::destroy() to take care of
656 // whether or not we need immediate or delayed destruction
657 DelayedDestruction::destroy();
658 }
659
detachNetworkSocket()660 NetworkSocket AsyncSocket::detachNetworkSocket() {
661 VLOG(6) << "AsyncSocket::detachFd(this=" << this << ", fd=" << fd_
662 << ", evb=" << eventBase_ << ", state=" << state_
663 << ", events=" << std::hex << eventFlags_ << ")";
664 for (const auto& cb : lifecycleObservers_) {
665 // only available for observers derived from AsyncSocket::LifecycleObserver
666 if (auto dCb = dynamic_cast<AsyncSocket::LifecycleObserver*>(cb)) {
667 dCb->fdDetach(this);
668 }
669 }
670 // Extract the fd, and set fd_ to -1 first, so closeNow() won't
671 // actually close the descriptor.
672 if (const auto socketSet = wShutdownSocketSet_.lock()) {
673 socketSet->remove(fd_);
674 }
675 auto fd = fd_;
676 fd_ = NetworkSocket();
677 // Call closeNow() to invoke all pending callbacks with an error.
678 closeNow();
679 // Update the EventHandler to stop using this fd.
680 // This can only be done after closeNow() unregisters the handler.
681 ioHandler_.changeHandlerFD(NetworkSocket());
682 return fd;
683 }
684
anyAddress()685 const folly::SocketAddress& AsyncSocket::anyAddress() {
686 static const folly::SocketAddress anyAddress =
687 folly::SocketAddress("0.0.0.0", 0);
688 return anyAddress;
689 }
690
setShutdownSocketSet(const std::weak_ptr<ShutdownSocketSet> & wNewSS)691 void AsyncSocket::setShutdownSocketSet(
692 const std::weak_ptr<ShutdownSocketSet>& wNewSS) {
693 const auto newSS = wNewSS.lock();
694 const auto shutdownSocketSet = wShutdownSocketSet_.lock();
695
696 if (newSS == shutdownSocketSet) {
697 return;
698 }
699
700 if (shutdownSocketSet && fd_ != NetworkSocket()) {
701 shutdownSocketSet->remove(fd_);
702 }
703
704 if (newSS && fd_ != NetworkSocket()) {
705 newSS->add(fd_);
706 }
707
708 wShutdownSocketSet_ = wNewSS;
709 }
710
setCloseOnExec()711 void AsyncSocket::setCloseOnExec() {
712 int rv = netops_->set_socket_close_on_exec(fd_);
713 if (rv != 0) {
714 auto errnoCopy = errno;
715 throw AsyncSocketException(
716 AsyncSocketException::INTERNAL_ERROR,
717 withAddr("failed to set close-on-exec flag"),
718 errnoCopy);
719 }
720 }
721
connect(ConnectCallback * callback,const folly::SocketAddress & address,int timeout,const SocketOptionMap & options,const folly::SocketAddress & bindAddr,const std::string & ifName)722 void AsyncSocket::connect(
723 ConnectCallback* callback,
724 const folly::SocketAddress& address,
725 int timeout,
726 const SocketOptionMap& options,
727 const folly::SocketAddress& bindAddr,
728 const std::string& ifName) noexcept {
729 DestructorGuard dg(this);
730 eventBase_->dcheckIsInEventBaseThread();
731
732 addr_ = address;
733
734 // Make sure we're in the uninitialized state
735 if (state_ != StateEnum::UNINIT) {
736 return invalidState(callback);
737 }
738
739 connectTimeout_ = std::chrono::milliseconds(timeout);
740 connectStartTime_ = std::chrono::steady_clock::now();
741 // Make connect end time at least >= connectStartTime.
742 connectEndTime_ = connectStartTime_;
743
744 assert(fd_ == NetworkSocket());
745 state_ = StateEnum::CONNECTING;
746 connectCallback_ = callback;
747 invokeConnectAttempt();
748
749 sockaddr_storage addrStorage;
750 auto saddr = reinterpret_cast<sockaddr*>(&addrStorage);
751
752 try {
753 // Create the socket
754 // Technically the first parameter should actually be a protocol family
755 // constant (PF_xxx) rather than an address family (AF_xxx), but the
756 // distinction is mainly just historical. In pretty much all
757 // implementations the PF_foo and AF_foo constants are identical.
758 fd_ = netops_->socket(address.getFamily(), SOCK_STREAM, 0);
759 if (fd_ == NetworkSocket()) {
760 auto errnoCopy = errno;
761 throw AsyncSocketException(
762 AsyncSocketException::INTERNAL_ERROR,
763 withAddr("failed to create socket"),
764 errnoCopy);
765 }
766
767 disableTransparentFunctions(fd_, noTransparentTls_, noTSocks_);
768 handleNetworkSocketAttached();
769 setCloseOnExec();
770
771 // Put the socket in non-blocking mode
772 int rv = netops_->set_socket_non_blocking(fd_);
773 if (rv == -1) {
774 auto errnoCopy = errno;
775 throw AsyncSocketException(
776 AsyncSocketException::INTERNAL_ERROR,
777 withAddr("failed to put socket in non-blocking mode"),
778 errnoCopy);
779 }
780
781 #if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
782 // iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
783 rv = fcntl(fd_.toFd(), F_SETNOSIGPIPE, 1);
784 if (rv == -1) {
785 auto errnoCopy = errno;
786 throw AsyncSocketException(
787 AsyncSocketException::INTERNAL_ERROR,
788 "failed to enable F_SETNOSIGPIPE on socket",
789 errnoCopy);
790 }
791 #endif
792
793 // By default, turn on TCP_NODELAY
794 // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
795 // setNoDelay() will log an error message if it fails.
796 // Also set the cached zeroCopyVal_ since it cannot be set earlier if the fd
797 // is not created
798 if (address.getFamily() != AF_UNIX) {
799 (void)setNoDelay(true);
800 setZeroCopy(zeroCopyVal_);
801 }
802
803 // Apply the additional PRE_BIND options if any.
804 applyOptions(options, SocketOptionKey::ApplyPos::PRE_BIND);
805
806 VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
807 << ", fd=" << fd_ << ", host=" << address.describe().c_str();
808
809 // bind the socket to the interface
810 #if defined(__linux__)
811 if (!ifName.empty() &&
812 netops_->setsockopt(
813 fd_,
814 SOL_SOCKET,
815 SO_BINDTODEVICE,
816 ifName.c_str(),
817 ifName.length())) {
818 auto errnoCopy = errno;
819 doClose();
820 throw AsyncSocketException(
821 AsyncSocketException::NOT_OPEN,
822 "failed to bind to device: " + ifName,
823 errnoCopy);
824 }
825 #else
826 (void)ifName;
827 #endif
828
829 // bind the socket
830 if (bindAddr != anyAddress()) {
831 int one = 1;
832 if (netops_->setsockopt(
833 fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
834 auto errnoCopy = errno;
835 doClose();
836 throw AsyncSocketException(
837 AsyncSocketException::NOT_OPEN,
838 "failed to setsockopt prior to bind on " + bindAddr.describe(),
839 errnoCopy);
840 }
841
842 bindAddr.getAddress(&addrStorage);
843
844 if (netops_->bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
845 auto errnoCopy = errno;
846 doClose();
847 throw AsyncSocketException(
848 AsyncSocketException::NOT_OPEN,
849 "failed to bind to async socket: " + bindAddr.describe(),
850 errnoCopy);
851 }
852 }
853
854 // Apply the additional POST_BIND options if any.
855 applyOptions(options, SocketOptionKey::ApplyPos::POST_BIND);
856
857 // Call preConnect hook if any.
858 if (connectCallback_) {
859 connectCallback_->preConnect(fd_);
860 }
861
862 // Perform the connect()
863 address.getAddress(&addrStorage);
864
865 if (tfoEnabled_) {
866 state_ = StateEnum::FAST_OPEN;
867 tfoAttempted_ = true;
868 } else {
869 if (socketConnect(saddr, addr_.getActualSize()) < 0) {
870 return;
871 }
872 }
873
874 // If we're still here the connect() succeeded immediately.
875 // Fall through to call the callback outside of this try...catch block
876 } catch (const AsyncSocketException& ex) {
877 return failConnect(__func__, ex);
878 } catch (const std::exception& ex) {
879 // shouldn't happen, but handle it just in case
880 VLOG(4) << "AsyncSocket::connect(this=" << this << ", fd=" << fd_
881 << "): unexpected " << typeid(ex).name()
882 << " exception: " << ex.what();
883 AsyncSocketException tex(
884 AsyncSocketException::INTERNAL_ERROR,
885 withAddr(string("unexpected exception: ") + ex.what()));
886 return failConnect(__func__, tex);
887 }
888
889 // The connection succeeded immediately
890 // The read callback may not have been set yet, and no writes may be pending
891 // yet, so we don't have to register for any events at the moment.
892 VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
893 assert(errMessageCallback_ == nullptr);
894 assert(readAncillaryDataCallback_ == nullptr);
895 assert(readCallback_ == nullptr);
896 assert(writeReqHead_ == nullptr);
897 if (state_ != StateEnum::FAST_OPEN) {
898 state_ = StateEnum::ESTABLISHED;
899 }
900 invokeConnectSuccess();
901 }
902
socketConnect(const struct sockaddr * saddr,socklen_t len)903 int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
904 int rv = netops_->connect(fd_, saddr, len);
905 if (rv < 0) {
906 auto errnoCopy = errno;
907 if (errnoCopy == EINPROGRESS) {
908 scheduleConnectTimeout();
909 registerForConnectEvents();
910 } else {
911 throw AsyncSocketException(
912 AsyncSocketException::NOT_OPEN,
913 "connect failed (immediately)",
914 errnoCopy);
915 }
916 }
917 return rv;
918 }
919
scheduleConnectTimeout()920 void AsyncSocket::scheduleConnectTimeout() {
921 // Connection in progress.
922 auto timeout = connectTimeout_.count();
923 if (timeout > 0) {
924 // Start a timer in case the connection takes too long.
925 if (!writeTimeout_.scheduleTimeout(uint32_t(timeout))) {
926 throw AsyncSocketException(
927 AsyncSocketException::INTERNAL_ERROR,
928 withAddr("failed to schedule AsyncSocket connect timeout"));
929 }
930 }
931 }
932
registerForConnectEvents()933 void AsyncSocket::registerForConnectEvents() {
934 // Register for write events, so we'll
935 // be notified when the connection finishes/fails.
936 // Note that we don't register for a persistent event here.
937 assert(eventFlags_ == EventHandler::NONE);
938 eventFlags_ = EventHandler::WRITE;
939 if (!ioHandler_.registerHandler(eventFlags_)) {
940 throw AsyncSocketException(
941 AsyncSocketException::INTERNAL_ERROR,
942 withAddr("failed to register AsyncSocket connect handler"));
943 }
944 }
945
connect(ConnectCallback * callback,const string & ip,uint16_t port,int timeout,const SocketOptionMap & options)946 void AsyncSocket::connect(
947 ConnectCallback* callback,
948 const string& ip,
949 uint16_t port,
950 int timeout,
951 const SocketOptionMap& options) noexcept {
952 DestructorGuard dg(this);
953 try {
954 connectCallback_ = callback;
955 connect(callback, folly::SocketAddress(ip, port), timeout, options);
956 } catch (const std::exception& ex) {
957 AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR, ex.what());
958 return failConnect(__func__, tex);
959 }
960 }
961
cancelConnect()962 void AsyncSocket::cancelConnect() {
963 connectCallback_ = nullptr;
964 if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) {
965 closeNow();
966 }
967 }
968
setSendTimeout(uint32_t milliseconds)969 void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
970 sendTimeout_ = milliseconds;
971 if (eventBase_) {
972 eventBase_->dcheckIsInEventBaseThread();
973 }
974
975 // If we are currently pending on write requests, immediately update
976 // writeTimeout_ with the new value.
977 if ((eventFlags_ & EventHandler::WRITE) &&
978 (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) {
979 assert(state_ == StateEnum::ESTABLISHED);
980 assert((shutdownFlags_ & SHUT_WRITE) == 0);
981 if (sendTimeout_ > 0) {
982 if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
983 AsyncSocketException ex(
984 AsyncSocketException::INTERNAL_ERROR,
985 withAddr("failed to reschedule send timeout in setSendTimeout"));
986 return failWrite(__func__, ex);
987 }
988 } else {
989 writeTimeout_.cancelTimeout();
990 }
991 }
992 }
993
setErrMessageCB(ErrMessageCallback * callback)994 void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
995 VLOG(6) << "AsyncSocket::setErrMessageCB() this=" << this << ", fd=" << fd_
996 << ", callback=" << callback << ", state=" << state_;
997
998 // In the latest stable kernel 4.14.3 as of 2017-12-04, unix domain
999 // socket does not support MSG_ERRQUEUE. So recvmsg(MSG_ERRQUEUE)
1000 // will read application data from unix doamin socket as error
1001 // message, which breaks the message flow in application. Feel free
1002 // to remove the next code block if MSG_ERRQUEUE is added for unix
1003 // domain socket in the future.
1004 if (callback != nullptr) {
1005 cacheLocalAddress();
1006 if (localAddr_.getFamily() == AF_UNIX) {
1007 LOG(ERROR) << "Failed to set ErrMessageCallback=" << callback
1008 << " for Unix Doamin Socket where MSG_ERRQUEUE is unsupported,"
1009 << " fd=" << fd_;
1010 return;
1011 }
1012 }
1013
1014 // Short circuit if callback is the same as the existing errMessageCallback_.
1015 if (callback == errMessageCallback_) {
1016 return;
1017 }
1018
1019 if (!msgErrQueueSupported) {
1020 // Per-socket error message queue is not supported on this platform.
1021 return invalidState(callback);
1022 }
1023
1024 DestructorGuard dg(this);
1025 eventBase_->dcheckIsInEventBaseThread();
1026
1027 if (callback == nullptr) {
1028 // We should be able to reset the callback regardless of the
1029 // socket state. It's important to have a reliable callback
1030 // cancellation mechanism.
1031 errMessageCallback_ = callback;
1032 return;
1033 }
1034
1035 switch ((StateEnum)state_) {
1036 case StateEnum::CONNECTING:
1037 case StateEnum::FAST_OPEN:
1038 case StateEnum::ESTABLISHED: {
1039 errMessageCallback_ = callback;
1040 return;
1041 }
1042 case StateEnum::CLOSED:
1043 case StateEnum::ERROR:
1044 // We should never reach here. SHUT_READ should always be set
1045 // if we are in STATE_CLOSED or STATE_ERROR.
1046 assert(false);
1047 return invalidState(callback);
1048 case StateEnum::UNINIT:
1049 // We do not allow setReadCallback() to be called before we start
1050 // connecting.
1051 return invalidState(callback);
1052 }
1053
1054 // We don't put a default case in the switch statement, so that the compiler
1055 // will warn us to update the switch statement if a new state is added.
1056 return invalidState(callback);
1057 }
1058
getErrMessageCallback() const1059 AsyncSocket::ErrMessageCallback* AsyncSocket::getErrMessageCallback() const {
1060 return errMessageCallback_;
1061 }
1062
setReadAncillaryDataCB(ReadAncillaryDataCallback * callback)1063 void AsyncSocket::setReadAncillaryDataCB(ReadAncillaryDataCallback* callback) {
1064 VLOG(6) << "AsyncSocket::setReadAncillaryDataCB() this=" << this
1065 << ", fd=" << fd_ << ", callback=" << callback
1066 << ", state=" << state_;
1067
1068 readAncillaryDataCallback_ = callback;
1069 }
1070
1071 AsyncSocket::ReadAncillaryDataCallback*
getReadAncillaryDataCallback() const1072 AsyncSocket::getReadAncillaryDataCallback() const {
1073 return readAncillaryDataCallback_;
1074 }
1075
setSendMsgParamCB(SendMsgParamsCallback * callback)1076 void AsyncSocket::setSendMsgParamCB(SendMsgParamsCallback* callback) {
1077 sendMsgParamCallback_ = callback;
1078 }
1079
getSendMsgParamsCB() const1080 AsyncSocket::SendMsgParamsCallback* AsyncSocket::getSendMsgParamsCB() const {
1081 return sendMsgParamCallback_;
1082 }
1083
setReadCB(ReadCallback * callback)1084 void AsyncSocket::setReadCB(ReadCallback* callback) {
1085 VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
1086 << ", callback=" << callback << ", state=" << state_;
1087
1088 // Short circuit if callback is the same as the existing readCallback_.
1089 //
1090 // Note that this is needed for proper functioning during some cleanup cases.
1091 // During cleanup we allow setReadCallback(nullptr) to be called even if the
1092 // read callback is already unset and we have been detached from an event
1093 // base. This check prevents us from asserting
1094 // eventBase_->isInEventBaseThread() when eventBase_ is nullptr.
1095 if (callback == readCallback_) {
1096 return;
1097 }
1098
1099 /* We are removing a read callback */
1100 if (callback == nullptr && immediateReadHandler_.isLoopCallbackScheduled()) {
1101 immediateReadHandler_.cancelLoopCallback();
1102 }
1103
1104 if (shutdownFlags_ & SHUT_READ) {
1105 // Reads have already been shut down on this socket.
1106 //
1107 // Allow setReadCallback(nullptr) to be called in this case, but don't
1108 // allow a new callback to be set.
1109 //
1110 // For example, setReadCallback(nullptr) can happen after an error if we
1111 // invoke some other error callback before invoking readError(). The other
1112 // error callback that is invoked first may go ahead and clear the read
1113 // callback before we get a chance to invoke readError().
1114 if (callback != nullptr) {
1115 return invalidState(callback);
1116 }
1117 assert((eventFlags_ & EventHandler::READ) == 0);
1118 readCallback_ = nullptr;
1119 return;
1120 }
1121
1122 DestructorGuard dg(this);
1123 eventBase_->dcheckIsInEventBaseThread();
1124
1125 switch ((StateEnum)state_) {
1126 case StateEnum::CONNECTING:
1127 case StateEnum::FAST_OPEN:
1128 // For convenience, we allow the read callback to be set while we are
1129 // still connecting. We just store the callback for now. Once the
1130 // connection completes we'll register for read events.
1131 readCallback_ = callback;
1132 return;
1133 case StateEnum::ESTABLISHED: {
1134 readCallback_ = callback;
1135 uint16_t oldFlags = eventFlags_;
1136 if (readCallback_) {
1137 eventFlags_ |= EventHandler::READ;
1138 } else {
1139 eventFlags_ &= ~EventHandler::READ;
1140 }
1141
1142 // Update our registration if our flags have changed
1143 if (eventFlags_ != oldFlags) {
1144 // We intentionally ignore the return value here.
1145 // updateEventRegistration() will move us into the error state if it
1146 // fails, and we don't need to do anything else here afterwards.
1147 (void)updateEventRegistration();
1148 }
1149
1150 if (readCallback_) {
1151 checkForImmediateRead();
1152 }
1153 return;
1154 }
1155 case StateEnum::CLOSED:
1156 case StateEnum::ERROR:
1157 // We should never reach here. SHUT_READ should always be set
1158 // if we are in STATE_CLOSED or STATE_ERROR.
1159 assert(false);
1160 return invalidState(callback);
1161 case StateEnum::UNINIT:
1162 // We do not allow setReadCallback() to be called before we start
1163 // connecting.
1164 return invalidState(callback);
1165 }
1166
1167 // We don't put a default case in the switch statement, so that the compiler
1168 // will warn us to update the switch statement if a new state is added.
1169 return invalidState(callback);
1170 }
1171
getReadCallback() const1172 AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
1173 return readCallback_;
1174 }
1175
setZeroCopy(bool enable)1176 bool AsyncSocket::setZeroCopy(bool enable) {
1177 if (msgErrQueueSupported) {
1178 zeroCopyVal_ = enable;
1179
1180 if (fd_ == NetworkSocket()) {
1181 return false;
1182 }
1183
1184 // No-op, bail out early
1185 if (enable == zeroCopyEnabled_) {
1186 return true;
1187 }
1188
1189 int val = enable ? 1 : 0;
1190 int ret =
1191 netops_->setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
1192
1193 // if enable == false, set zeroCopyEnabled_ = false regardless
1194 // if SO_ZEROCOPY is set or not
1195 if (!enable) {
1196 zeroCopyEnabled_ = enable;
1197 return true;
1198 }
1199
1200 /* if the setsockopt failed, try to see if the socket inherited the flag
1201 * since we cannot set SO_ZEROCOPY on a socket s = accept
1202 */
1203 if (ret) {
1204 val = 0;
1205 socklen_t optlen = sizeof(val);
1206 ret = netops_->getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);
1207
1208 if (!ret) {
1209 enable = val != 0;
1210 }
1211 }
1212
1213 if (!ret) {
1214 zeroCopyEnabled_ = enable;
1215
1216 return true;
1217 }
1218 }
1219
1220 return false;
1221 }
1222
setZeroCopyEnableFunc(AsyncWriter::ZeroCopyEnableFunc func)1223 void AsyncSocket::setZeroCopyEnableFunc(AsyncWriter::ZeroCopyEnableFunc func) {
1224 zeroCopyEnableFunc_ = func;
1225 }
1226
setZeroCopyReenableThreshold(size_t threshold)1227 void AsyncSocket::setZeroCopyReenableThreshold(size_t threshold) {
1228 zeroCopyReenableThreshold_ = threshold;
1229 }
1230
isZeroCopyRequest(WriteFlags flags)1231 bool AsyncSocket::isZeroCopyRequest(WriteFlags flags) {
1232 return (zeroCopyEnabled_ && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY));
1233 }
1234
adjustZeroCopyFlags(folly::WriteFlags & flags)1235 void AsyncSocket::adjustZeroCopyFlags(folly::WriteFlags& flags) {
1236 if (!zeroCopyEnabled_) {
1237 // if the zeroCopyReenableCounter_ is > 0
1238 // we try to dec and if we reach 0
1239 // we set zeroCopyEnabled_ to true
1240 if (zeroCopyReenableCounter_) {
1241 if (0 == --zeroCopyReenableCounter_) {
1242 zeroCopyEnabled_ = true;
1243 return;
1244 }
1245 }
1246 flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
1247 }
1248 }
1249
addZeroCopyBuf(std::unique_ptr<folly::IOBuf> && buf,ReleaseIOBufCallback * cb)1250 void AsyncSocket::addZeroCopyBuf(
1251 std::unique_ptr<folly::IOBuf>&& buf, ReleaseIOBufCallback* cb) {
1252 uint32_t id = getNextZeroCopyBufId();
1253 folly::IOBuf* ptr = buf.get();
1254
1255 idZeroCopyBufPtrMap_[id] = ptr;
1256 auto& p = idZeroCopyBufInfoMap_[ptr];
1257 p.count_++;
1258 CHECK(p.buf_.get() == nullptr);
1259 p.buf_ = std::move(buf);
1260 p.cb_ = cb;
1261 }
1262
addZeroCopyBuf(folly::IOBuf * ptr)1263 void AsyncSocket::addZeroCopyBuf(folly::IOBuf* ptr) {
1264 uint32_t id = getNextZeroCopyBufId();
1265 idZeroCopyBufPtrMap_[id] = ptr;
1266
1267 idZeroCopyBufInfoMap_[ptr].count_++;
1268 }
1269
releaseZeroCopyBuf(uint32_t id)1270 void AsyncSocket::releaseZeroCopyBuf(uint32_t id) {
1271 auto iter = idZeroCopyBufPtrMap_.find(id);
1272 CHECK(iter != idZeroCopyBufPtrMap_.end());
1273 auto ptr = iter->second;
1274 auto iter1 = idZeroCopyBufInfoMap_.find(ptr);
1275 CHECK(iter1 != idZeroCopyBufInfoMap_.end());
1276 if (0 == --iter1->second.count_) {
1277 releaseIOBuf(std::move(iter1->second.buf_), iter1->second.cb_);
1278 idZeroCopyBufInfoMap_.erase(iter1);
1279 }
1280
1281 idZeroCopyBufPtrMap_.erase(iter);
1282 }
1283
setZeroCopyBuf(std::unique_ptr<folly::IOBuf> && buf,ReleaseIOBufCallback * cb)1284 void AsyncSocket::setZeroCopyBuf(
1285 std::unique_ptr<folly::IOBuf>&& buf, ReleaseIOBufCallback* cb) {
1286 folly::IOBuf* ptr = buf.get();
1287 auto& p = idZeroCopyBufInfoMap_[ptr];
1288 CHECK(p.buf_.get() == nullptr);
1289
1290 p.buf_ = std::move(buf);
1291 p.cb_ = cb;
1292 }
1293
containsZeroCopyBuf(folly::IOBuf * ptr)1294 bool AsyncSocket::containsZeroCopyBuf(folly::IOBuf* ptr) {
1295 return (idZeroCopyBufInfoMap_.find(ptr) != idZeroCopyBufInfoMap_.end());
1296 }
1297
isZeroCopyMsg(const cmsghdr & cmsg) const1298 bool AsyncSocket::isZeroCopyMsg(const cmsghdr& cmsg) const {
1299 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
1300 if ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
1301 (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
1302 auto serr =
1303 reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
1304 return (
1305 (serr->ee_errno == 0) && (serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY));
1306 }
1307 #endif
1308 (void)cmsg;
1309 return false;
1310 }
1311
processZeroCopyMsg(const cmsghdr & cmsg)1312 void AsyncSocket::processZeroCopyMsg(const cmsghdr& cmsg) {
1313 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
1314 auto serr =
1315 reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
1316 uint32_t hi = serr->ee_data;
1317 uint32_t lo = serr->ee_info;
1318 // disable zero copy if the buffer was actually copied
1319 if ((serr->ee_code & SO_EE_CODE_ZEROCOPY_COPIED) && zeroCopyEnabled_) {
1320 VLOG(2) << "AsyncSocket::processZeroCopyMsg(): setting "
1321 << "zeroCopyEnabled_ = false due to SO_EE_CODE_ZEROCOPY_COPIED "
1322 << "on " << fd_;
1323 zeroCopyEnabled_ = false;
1324 }
1325
1326 for (uint32_t i = lo; i <= hi; i++) {
1327 releaseZeroCopyBuf(i);
1328 }
1329 #else
1330 (void)cmsg;
1331 #endif
1332 }
1333
releaseIOBuf(std::unique_ptr<folly::IOBuf> buf,ReleaseIOBufCallback * callback)1334 void AsyncSocket::releaseIOBuf(
1335 std::unique_ptr<folly::IOBuf> buf, ReleaseIOBufCallback* callback) {
1336 if (!buf) {
1337 return;
1338 }
1339 const size_t allocated = buf->computeChainCapacity();
1340 DCHECK_GE(allocatedBytesBuffered_, allocated);
1341 allocatedBytesBuffered_ -= allocated;
1342 if (callback) {
1343 callback->releaseIOBuf(std::move(buf));
1344 }
1345 }
1346
// Enables socket byte events (SO_TIMESTAMPING-based timestamps) on this socket.
//
// Lazily creates byteEventHelper_. The call is a no-op if byte events are
// already enabled, or if byteEventHelper_->maybeEx already holds a value
// (presumably an earlier failure recorded via failByteEvents() — confirm).
//
// This function reports failure via failByteEvents() rather than throwing:
// every failure is raised internally as an AsyncSocketException, caught at
// the bottom, and forwarded. Failures include:
//  - the socket is not open/good;
//  - the local address cannot be fetched;
//  - the socket is neither AF_INET nor AF_INET6;
//  - SO_TIMESTAMPING cannot be read, or is already non-zero;
//  - setsockopt(SO_TIMESTAMPING) fails;
//  - the build lacks FOLLY_HAVE_SO_TIMESTAMPING.
//
// On success, the helper records getRawBytesWritten() as the baseline. Each
// lifecycle observer whose config has byteEvents set is then notified via
// byteEventsEnabled(this).
void AsyncSocket::enableByteEvents() {
  if (!byteEventHelper_) {
    byteEventHelper_ = std::make_unique<ByteEventHelper>();
  }

  // Already enabled, or a previous attempt left an exception behind.
  if (byteEventHelper_->byteEventsEnabled ||
      byteEventHelper_->maybeEx.has_value()) {
    return;
  }

  try {
#if FOLLY_HAVE_SO_TIMESTAMPING
    // make sure we have a connected IP socket that supports error queues
    // (Unix sockets do not support error queues)
    if (NetworkSocket() == fd_ || !good()) {
      throw AsyncSocketException(
          AsyncSocketException::INVALID_STATE,
          withAddr("failed to enable byte events: "
                   "socket is not open or not in a good state"));
    }
    folly::SocketAddress addr = {};
    try {
      // explicitly fetch local address (instead of using cache)
      // to ensure socket is currently healthy
      addr.setFromLocalAddress(fd_);
    } catch (const std::system_error&) {
      throw AsyncSocketException(
          AsyncSocketException::INVALID_STATE,
          withAddr("failed to enable byte events: "
                   "socket is not open or not in a good state"));
    }
    const auto family = addr.getFamily();
    if (family != AF_INET && family != AF_INET6) {
      throw AsyncSocketException(
          AsyncSocketException::NOT_SUPPORTED,
          withAddr("failed to enable byte events: socket type not supported"));
    }

    // check if timestamping is already enabled on the socket by another source
    {
      uint32_t flags = 0;
      socklen_t len = sizeof(flags);
      const auto ret =
          getSockOptVirtual(SOL_SOCKET, SO_TIMESTAMPING, &flags, &len);
      // errno captured right after the call, before anything can clobber it.
      int getSockOptErrno = errno;
      if (0 != ret) {
        throw AsyncSocketException(
            AsyncSocketException::INTERNAL_ERROR,
            withAddr("failed to enable byte events: "
                     "timestamps may not be supported for this socket type "
                     "or socket be closed"),
            getSockOptErrno);
      }
      // NOTE(review): here getsockopt succeeded, so getSockOptErrno is
      // whatever errno held before; it carries no meaning on this path.
      if (0 != flags) {
        throw AsyncSocketException(
            AsyncSocketException::INTERNAL_ERROR,
            withAddr("failed to enable byte events: "
                     "timestamps may have already been enabled"),
            getSockOptErrno);
      }
    }

    // enable control messages for software and hardware timestamps
    // WriteFlags will determine which messages are generated
    //
    // SOF_TIMESTAMPING_OPT_ID: see discussion in ByteEventHelper::processCmsg
    // SOF_TIMESTAMPING_OPT_TSONLY: only get timestamps, not original packet
    // SOF_TIMESTAMPING_SOFTWARE: get software timestamps if generated
    // SOF_TIMESTAMPING_RAW_HARDWARE: get hardware timestamps if generated
    // SOF_TIMESTAMPING_OPT_TX_SWHW: get both sw + hw timestamps if generated
    const uint32_t flags =
        (folly::netops::SOF_TIMESTAMPING_OPT_ID |
         folly::netops::SOF_TIMESTAMPING_OPT_TSONLY |
         folly::netops::SOF_TIMESTAMPING_SOFTWARE |
         folly::netops::SOF_TIMESTAMPING_RAW_HARDWARE |
         folly::netops::SOF_TIMESTAMPING_OPT_TX_SWHW);
    socklen_t len = sizeof(flags);
    const auto ret =
        setSockOptVirtual(SOL_SOCKET, SO_TIMESTAMPING, &flags, len);
    int setSockOptErrno = errno;
    if (ret == 0) {
      // Success: remember how many raw bytes had been written at this point,
      // then tell interested observers.
      byteEventHelper_->byteEventsEnabled = true;
      byteEventHelper_->rawBytesWrittenWhenByteEventsEnabled =
          getRawBytesWritten();
      for (const auto& observer : lifecycleObservers_) {
        if (observer->getConfig().byteEvents) {
          observer->byteEventsEnabled(this);
        }
      }
      return;
    }

    // failed
    throw AsyncSocketException(
        AsyncSocketException::INTERNAL_ERROR,
        withAddr("failed to enable byte events: setsockopt failed"),
        setSockOptErrno);
#endif // FOLLY_HAVE_SO_TIMESTAMPING
    // unsupported by platform
    // (only reachable when FOLLY_HAVE_SO_TIMESTAMPING is not set; otherwise
    // every path above has already returned or thrown)
    throw AsyncSocketException(
        AsyncSocketException::NOT_SUPPORTED,
        withAddr("failed to enable byte events: platform not supported"));

  } catch (const AsyncSocketException& ex) {
    // All failures funnel through here.
    failByteEvents(ex);
  }
}
1454
write(WriteCallback * callback,const void * buf,size_t bytes,WriteFlags flags)1455 void AsyncSocket::write(
1456 WriteCallback* callback, const void* buf, size_t bytes, WriteFlags flags) {
1457 iovec op;
1458 op.iov_base = const_cast<void*>(buf);
1459 op.iov_len = bytes;
1460 writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), bytes, flags);
1461 }
1462
writev(WriteCallback * callback,const iovec * vec,size_t count,WriteFlags flags)1463 void AsyncSocket::writev(
1464 WriteCallback* callback, const iovec* vec, size_t count, WriteFlags flags) {
1465 size_t totalBytes = 0;
1466 for (size_t i = 0; i < count; ++i) {
1467 totalBytes += vec[i].iov_len;
1468 }
1469 writeImpl(callback, vec, count, unique_ptr<IOBuf>(), totalBytes, flags);
1470 }
1471
writeChain(WriteCallback * callback,unique_ptr<IOBuf> && buf,WriteFlags flags)1472 void AsyncSocket::writeChain(
1473 WriteCallback* callback, unique_ptr<IOBuf>&& buf, WriteFlags flags) {
1474 adjustZeroCopyFlags(flags);
1475
1476 // adjustZeroCopyFlags can set zeroCopyEnabled_ to true
1477 if (zeroCopyEnabled_ && !isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY) &&
1478 zeroCopyEnableFunc_ && zeroCopyEnableFunc_(buf) && buf->isManaged()) {
1479 flags |= WriteFlags::WRITE_MSG_ZEROCOPY;
1480 }
1481
1482 size_t count = buf->countChainElements();
1483 if (count <= kSmallIoVecSize) {
1484 // suppress "warning: variable length array 'vec' is used [-Wvla]"
1485 FOLLY_PUSH_WARNING
1486 FOLLY_GNU_DISABLE_WARNING("-Wvla")
1487 iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA_01, count, kSmallIoVecSize)];
1488 FOLLY_POP_WARNING
1489
1490 writeChainImpl(callback, vec, count, std::move(buf), flags);
1491 } else {
1492 std::unique_ptr<iovec[]> vec(new iovec[count]);
1493 writeChainImpl(callback, vec.get(), count, std::move(buf), flags);
1494 }
1495 }
1496
writeChainImpl(WriteCallback * callback,iovec * vec,size_t count,unique_ptr<IOBuf> && buf,WriteFlags flags)1497 void AsyncSocket::writeChainImpl(
1498 WriteCallback* callback,
1499 iovec* vec,
1500 size_t count,
1501 unique_ptr<IOBuf>&& buf,
1502 WriteFlags flags) {
1503 auto res = buf->fillIov(vec, count);
1504 writeImpl(
1505 callback, vec, res.numIovecs, std::move(buf), res.totalLength, flags);
1506 }
1507
// Common implementation behind write(), writev() and writeChain().
//
// When the socket is ESTABLISHED or FAST_OPEN, is not connecting, and no
// other writes are queued, the data is written immediately. Whatever remains
// (or everything, while connecting or when writes are already queued) is
// wrapped in a BytesWriteRequest and appended to the write queue.
//
// @param callback    Gets writeSuccess() or a failWrite() error; may be null.
// @param vec         iovec array describing the payload.
// @param count       Number of entries in vec.
// @param buf         Optional IOBuf chain that owns the payload memory. Its
//                    chain capacity is added to allocatedBytesBuffered_ up
//                    front, and it is eventually released via releaseIOBuf().
// @param totalBytes  Payload size, added to totalAppBytesScheduledForWrite_.
// @param flags       WriteFlags for this write (e.g. WRITE_MSG_ZEROCOPY).
void AsyncSocket::writeImpl(
    WriteCallback* callback,
    const iovec* vec,
    size_t count,
    unique_ptr<IOBuf>&& buf,
    size_t totalBytes,
    WriteFlags flags) {
  VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
          << ", callback=" << callback << ", count=" << count
          << ", state=" << state_;
  DestructorGuard dg(this);
  unique_ptr<IOBuf> ioBuf(std::move(buf));
  eventBase_->dcheckIsInEventBaseThread();

  auto* releaseIOBufCallback =
      callback ? callback->getReleaseIOBufCallback() : nullptr;

  // Anything ioBuf still owns when we return is released here. Paths that
  // hand the buffer off (zero-copy tracking, the queued WriteRequest, or an
  // explicit release) move it out first, and releaseIOBuf() ignores a null
  // buffer.
  SCOPE_EXIT { releaseIOBuf(std::move(ioBuf), releaseIOBufCallback); };

  totalAppBytesScheduledForWrite_ += totalBytes;
  if (ioBuf) {
    allocatedBytesBuffered_ += ioBuf->computeChainCapacity();
  }

  if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
    // No new writes may be performed after the write side of the socket has
    // been shutdown.
    //
    // We could just call callback->writeError() here to fail just this write.
    // However, fail hard and use invalidState() to fail all outstanding
    // callbacks and move the socket into the error state. There's most likely
    // a bug in the caller's code, so we abort everything rather than trying to
    // proceed as best we can.
    return invalidState(callback);
  }

  // Progress of the immediate write attempt. performWrite() fills in how many
  // iovecs were fully written (countWritten) and a value for the first
  // unfinished one (partialWritten). These seed the queued WriteRequest below.
  uint32_t countWritten = 0;
  uint32_t partialWritten = 0;
  ssize_t bytesWritten = 0;
  // Set when data remains after an immediate write on an established socket:
  // we must then register for WRITE events ourselves.
  bool mustRegister = false;
  if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) &&
      !connecting()) {
    if (writeReqHead_ == nullptr) {
      // If we are established and there are no other writes pending,
      // we can attempt to perform the write immediately.
      assert(writeReqTail_ == nullptr);
      assert((eventFlags_ & EventHandler::WRITE) == 0);

      auto writeResult = performWrite(
          vec, uint32_t(count), flags, &countWritten, &partialWritten);
      bytesWritten = writeResult.writeReturn;
      if (bytesWritten < 0) {
        auto errnoCopy = errno;
        // Prefer the specific exception performWrite() attached, if any.
        if (writeResult.exception) {
          return failWrite(__func__, callback, 0, *writeResult.exception);
        }
        AsyncSocketException ex(
            AsyncSocketException::INTERNAL_ERROR,
            withAddr("writev failed"),
            errnoCopy);
        return failWrite(__func__, callback, 0, ex);
      } else if (countWritten == count) {
        // done, add the whole buffer
        // (zero-copy writes keep the buffer alive until the kernel reports
        // completion; otherwise it can be released right away)
        if (countWritten && isZeroCopyRequest(flags)) {
          addZeroCopyBuf(std::move(ioBuf), releaseIOBufCallback);
        } else {
          releaseIOBuf(std::move(ioBuf), releaseIOBufCallback);
        }

        // We successfully wrote everything.
        // Invoke the callback and return.
        if (callback) {
          callback->writeSuccess();
        }
        return;
      } else { // continue writing the next writeReq
        // add just the ptr
        // (ownership of ioBuf moves into the WriteRequest created below)
        if (bytesWritten && isZeroCopyRequest(flags)) {
          addZeroCopyBuf(ioBuf.get());
        }
      }
      if (!connecting()) {
        // Writes might put the socket back into connecting state
        // if TFO is enabled, and using TFO fails.
        // This means that write timeouts would not be active, however
        // connect timeouts would affect this stage.
        mustRegister = true;
      }
    }
  } else if (!connecting()) {
    // Invalid state for writing
    return invalidState(callback);
  }

  // Create a new WriteRequest to add to the queue
  // (it covers only the iovecs that were not fully written above)
  WriteRequest* req;
  try {
    req = BytesWriteRequest::newRequest(
        this,
        callback,
        vec + countWritten,
        uint32_t(count - countWritten),
        partialWritten,
        uint32_t(bytesWritten),
        std::move(ioBuf),
        flags);
  } catch (const std::exception& ex) {
    // we mainly expect to catch std::bad_alloc here
    AsyncSocketException tex(
        AsyncSocketException::INTERNAL_ERROR,
        withAddr(string("failed to append new WriteRequest: ") + ex.what()));
    return failWrite(__func__, callback, size_t(bytesWritten), tex);
  }
  req->consume();
  // Append to the singly linked write queue.
  if (writeReqTail_ == nullptr) {
    assert(writeReqHead_ == nullptr);
    writeReqHead_ = writeReqTail_ = req;
  } else {
    writeReqTail_->append(req);
    writeReqTail_ = req;
  }

  if (bufferCallback_) {
    bufferCallback_->onEgressBuffered();
  }

  // Register for write events if are established and not currently
  // waiting on write events
  if (mustRegister) {
    assert(state_ == StateEnum::ESTABLISHED);
    assert((eventFlags_ & EventHandler::WRITE) == 0);
    if (!updateEventRegistration(EventHandler::WRITE, 0)) {
      assert(state_ == StateEnum::ERROR);
      return;
    }
    if (sendTimeout_ > 0) {
      // Schedule a timeout to fire if the write takes too long.
      if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
        AsyncSocketException ex(
            AsyncSocketException::INTERNAL_ERROR,
            withAddr("failed to schedule send timeout"));
        return failWrite(__func__, ex);
      }
    }
  }
}
1654
writeRequest(WriteRequest * req)1655 void AsyncSocket::writeRequest(WriteRequest* req) {
1656 if (writeReqTail_ == nullptr) {
1657 assert(writeReqHead_ == nullptr);
1658 writeReqHead_ = writeReqTail_ = req;
1659 req->start();
1660 } else {
1661 writeReqTail_->append(req);
1662 writeReqTail_ = req;
1663 }
1664 }
1665
close()1666 void AsyncSocket::close() {
1667 VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
1668 << ", state=" << state_ << ", shutdownFlags=" << std::hex
1669 << (int)shutdownFlags_;
1670
1671 // close() is only different from closeNow() when there are pending writes
1672 // that need to drain before we can close. In all other cases, just call
1673 // closeNow().
1674 //
1675 // Note that writeReqHead_ can be non-nullptr even in STATE_CLOSED or
1676 // STATE_ERROR if close() is invoked while a previous closeNow() or failure
1677 // is still running. (e.g., If there are multiple pending writes, and we
1678 // call writeError() on the first one, it may call close(). In this case we
1679 // will already be in STATE_CLOSED or STATE_ERROR, but the remaining pending
1680 // writes will still be in the queue.)
1681 //
1682 // We only need to drain pending writes if we are still in STATE_CONNECTING
1683 // or STATE_ESTABLISHED
1684 if ((writeReqHead_ == nullptr) ||
1685 !(state_ == StateEnum::CONNECTING || state_ == StateEnum::ESTABLISHED)) {
1686 closeNow();
1687 return;
1688 }
1689
1690 // Declare a DestructorGuard to ensure that the AsyncSocket cannot be
1691 // destroyed until close() returns.
1692 DestructorGuard dg(this);
1693 eventBase_->dcheckIsInEventBaseThread();
1694
1695 // Since there are write requests pending, we have to set the
1696 // SHUT_WRITE_PENDING flag, and wait to perform the real close until the
1697 // connect finishes and we finish writing these requests.
1698 //
1699 // Set SHUT_READ to indicate that reads are shut down, and set the
1700 // SHUT_WRITE_PENDING flag to mark that we want to shutdown once the
1701 // pending writes complete.
1702 shutdownFlags_ |= (SHUT_READ | SHUT_WRITE_PENDING);
1703
1704 // If a read callback is set, invoke readEOF() immediately to inform it that
1705 // the socket has been closed and no more data can be read.
1706 if (readCallback_) {
1707 // Disable reads if they are enabled
1708 if (!updateEventRegistration(0, EventHandler::READ)) {
1709 // We're now in the error state; callbacks have been cleaned up
1710 assert(state_ == StateEnum::ERROR);
1711 assert(readCallback_ == nullptr);
1712 } else {
1713 ReadCallback* callback = readCallback_;
1714 readCallback_ = nullptr;
1715 callback->readEOF();
1716 }
1717 }
1718 }
1719
// Closes the socket immediately, without draining queued writes.
//
// From ESTABLISHED / CONNECTING / FAST_OPEN, the following happens in order:
//  1. Both directions are marked shut down and the state becomes CLOSED.
//  2. The write/connect timer (writeTimeout_) is cancelled.
//  3. I/O events are unregistered.
//  4. Any scheduled immediate read is cancelled.
//  5. The descriptor is detached from the I/O handler and closed (doClose()).
//  6. invokeConnectErr() and failAllWrites() are called with the "closed
//     locally" exception.
//  7. A remaining read callback receives readEOF().
// The state is changed before any callback runs, so re-entrant calls from
// those callbacks hit the CLOSED no-op case.
//
// In CLOSED or ERROR this is a no-op; in UNINIT it just marks the socket
// CLOSED.
void AsyncSocket::closeNow() {
  VLOG(5) << "AsyncSocket::closeNow(): this=" << this << ", fd_=" << fd_
          << ", state=" << state_ << ", shutdownFlags=" << std::hex
          << (int)shutdownFlags_;
  DestructorGuard dg(this);
  // eventBase_ may be null if the socket was never attached/was detached.
  if (eventBase_) {
    eventBase_->dcheckIsInEventBaseThread();
  }

  switch (state_) {
    case StateEnum::ESTABLISHED:
    case StateEnum::CONNECTING:
    case StateEnum::FAST_OPEN: {
      shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
      state_ = StateEnum::CLOSED;

      // If the write timeout was set, cancel it.
      // (while connecting, this is the connect timer)
      writeTimeout_.cancelTimeout();

      // If we are registered for I/O events, unregister.
      if (eventFlags_ != EventHandler::NONE) {
        eventFlags_ = EventHandler::NONE;
        if (!updateEventRegistration()) {
          // We will have been moved into the error state.
          assert(state_ == StateEnum::ERROR);
          return;
        }
      }

      if (immediateReadHandler_.isLoopCallbackScheduled()) {
        immediateReadHandler_.cancelLoopCallback();
      }

      if (fd_ != NetworkSocket()) {
        // Detach the handler from the descriptor before closing it.
        ioHandler_.changeHandlerFD(NetworkSocket());
        doClose();
      }

      invokeConnectErr(getSocketClosedLocallyEx());

      failAllWrites(getSocketClosedLocallyEx());

      if (readCallback_) {
        // Clear the member before invoking, so a re-entrant call sees no
        // callback.
        ReadCallback* callback = readCallback_;
        readCallback_ = nullptr;
        callback->readEOF();
      }
      return;
    }
    case StateEnum::CLOSED:
      // Do nothing. It's possible that we are being called recursively
      // from inside a callback that we invoked inside another call to close()
      // that is still running.
      return;
    case StateEnum::ERROR:
      // Do nothing. The error handling code has performed (or is performing)
      // cleanup.
      return;
    case StateEnum::UNINIT:
      // Never connected: there is nothing to unregister or notify.
      assert(eventFlags_ == EventHandler::NONE);
      assert(connectCallback_ == nullptr);
      assert(readCallback_ == nullptr);
      assert(writeReqHead_ == nullptr);
      shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
      state_ = StateEnum::CLOSED;
      return;
  }

  // Only reached if state_ holds a value not covered by the switch.
  LOG(DFATAL) << "AsyncSocket::closeNow() (this=" << this << ", fd=" << fd_
              << ") called in unknown state " << state_;
}
1791
closeWithReset()1792 void AsyncSocket::closeWithReset() {
1793 // Enable SO_LINGER, with the linger timeout set to 0.
1794 // This will trigger a TCP reset when we close the socket.
1795 if (fd_ != NetworkSocket()) {
1796 struct linger optLinger = {1, 0};
1797 if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) {
1798 VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER "
1799 << "on " << fd_ << ": errno=" << errno;
1800 }
1801 }
1802
1803 // Then let closeNow() take care of the rest
1804 closeNow();
1805 }
1806
shutdownWrite()1807 void AsyncSocket::shutdownWrite() {
1808 VLOG(5) << "AsyncSocket::shutdownWrite(): this=" << this << ", fd=" << fd_
1809 << ", state=" << state_ << ", shutdownFlags=" << std::hex
1810 << (int)shutdownFlags_;
1811
1812 // If there are no pending writes, shutdownWrite() is identical to
1813 // shutdownWriteNow().
1814 if (writeReqHead_ == nullptr) {
1815 shutdownWriteNow();
1816 return;
1817 }
1818
1819 eventBase_->dcheckIsInEventBaseThread();
1820
1821 // There are pending writes. Set SHUT_WRITE_PENDING so that the actual
1822 // shutdown will be performed once all writes complete.
1823 shutdownFlags_ |= SHUT_WRITE_PENDING;
1824 }
1825
/**
 * Shut down the write side of the socket immediately, failing any queued
 * writes instead of waiting for them to complete.
 *
 * Returns early if SHUT_WRITE is already set, and escalates to closeNow() if
 * SHUT_READ is already set. Otherwise the action depends on state_:
 *  - ESTABLISHED: set SHUT_WRITE, cancel the write timeout, drop WRITE event
 *    registration, shutdown(SHUT_WR) the descriptor, fail queued writes.
 *  - CONNECTING: set SHUT_WRITE_PENDING and fail queued writes; the connect
 *    completion path performs the actual shutdown.
 *  - UNINIT: only set SHUT_WRITE_PENDING.
 *  - FAST_OPEN: set SHUT_WRITE and fail queued writes.
 *  - CLOSED / ERROR: unexpected (SHUT_WRITE should already be set there);
 *    logged and asserted.
 */
void AsyncSocket::shutdownWriteNow() {
  VLOG(5) << "AsyncSocket::shutdownWriteNow(): this=" << this << ", fd=" << fd_
          << ", state=" << state_ << ", shutdownFlags=" << std::hex
          << (int)shutdownFlags_;

  if (shutdownFlags_ & SHUT_WRITE) {
    // Writes are already shutdown; nothing else to do.
    return;
  }

  // If SHUT_READ is already set, just call closeNow() to completely
  // close the socket. This can happen if close() was called with writes
  // pending, and then shutdownWriteNow() is called before all pending writes
  // complete.
  if (shutdownFlags_ & SHUT_READ) {
    closeNow();
    return;
  }

  // Keep this object alive until we return; presumably guards against the
  // write-failure callbacks below destroying the socket mid-call.
  DestructorGuard dg(this);
  if (eventBase_) {
    eventBase_->dcheckIsInEventBaseThread();
  }

  switch (static_cast<StateEnum>(state_)) {
    case StateEnum::ESTABLISHED: {
      shutdownFlags_ |= SHUT_WRITE;

      // If the write timeout was set, cancel it.
      writeTimeout_.cancelTimeout();

      // If we are registered for write events, unregister.
      if (!updateEventRegistration(0, EventHandler::WRITE)) {
        // We will have been moved into the error state.
        assert(state_ == StateEnum::ERROR);
        return;
      }

      // Shutdown writes on the file descriptor
      netops_->shutdown(fd_, SHUT_WR);

      // Immediately fail all write requests
      failAllWrites(getSocketShutdownForWritesEx());
      return;
    }
    case StateEnum::CONNECTING: {
      // Set the SHUT_WRITE_PENDING flag.
      // When the connection completes, it will check this flag,
      // shutdown the write half of the socket, and then set SHUT_WRITE.
      shutdownFlags_ |= SHUT_WRITE_PENDING;

      // Immediately fail all write requests
      failAllWrites(getSocketShutdownForWritesEx());
      return;
    }
    case StateEnum::UNINIT:
      // Callers normally shouldn't call shutdownWriteNow() before the socket
      // even starts connecting. Nonetheless, go ahead and set
      // SHUT_WRITE_PENDING. Once the socket eventually connects it will
      // immediately shut down the write side of the socket.
      shutdownFlags_ |= SHUT_WRITE_PENDING;
      return;
    case StateEnum::FAST_OPEN:
      // In fast open state we haven't called connect yet, and if we shut down
      // the writes, we will never try to call connect, so shut everything
      // down.
      shutdownFlags_ |= SHUT_WRITE;
      // Immediately fail all write requests
      failAllWrites(getSocketShutdownForWritesEx());
      return;
    case StateEnum::CLOSED:
    case StateEnum::ERROR:
      // We should never get here. SHUT_WRITE should always be set
      // in STATE_CLOSED and STATE_ERROR.
      VLOG(4) << "AsyncSocket::shutdownWriteNow() (this=" << this
              << ", fd=" << fd_ << ") in unexpected state " << state_
              << " with SHUT_WRITE not set (" << std::hex << (int)shutdownFlags_
              << ")";
      assert(false);
      return;
  }

  // Only reachable if state_ holds a value not covered by the switch above.
  LOG(DFATAL) << "AsyncSocket::shutdownWriteNow() (this=" << this
              << ", fd=" << fd_ << ") called in unknown state " << state_;
}
1910
readable() const1911 bool AsyncSocket::readable() const {
1912 if (fd_ == NetworkSocket()) {
1913 return false;
1914 }
1915
1916 if (preReceivedData_ && !preReceivedData_->empty()) {
1917 return true;
1918 }
1919 netops::PollDescriptor fds[1];
1920 fds[0].fd = fd_;
1921 fds[0].events = POLLIN;
1922 fds[0].revents = 0;
1923 int rc = netops_->poll(fds, 1, 0);
1924 return rc == 1;
1925 }
1926
writable() const1927 bool AsyncSocket::writable() const {
1928 if (fd_ == NetworkSocket()) {
1929 return false;
1930 }
1931 netops::PollDescriptor fds[1];
1932 fds[0].fd = fd_;
1933 fds[0].events = POLLOUT;
1934 fds[0].revents = 0;
1935 int rc = netops_->poll(fds, 1, 0);
1936 return rc == 1;
1937 }
1938
isPending() const1939 bool AsyncSocket::isPending() const {
1940 return ioHandler_.isPending();
1941 }
1942
hangup() const1943 bool AsyncSocket::hangup() const {
1944 if (fd_ == NetworkSocket()) {
1945 // sanity check, no one should ask for hangup if we are not connected.
1946 assert(false);
1947 return false;
1948 }
1949 #ifdef POLLRDHUP // Linux-only
1950 netops::PollDescriptor fds[1];
1951 fds[0].fd = fd_;
1952 fds[0].events = POLLRDHUP | POLLHUP;
1953 fds[0].revents = 0;
1954 netops_->poll(fds, 1, 0);
1955 return (fds[0].revents & (POLLRDHUP | POLLHUP)) != 0;
1956 #else
1957 return false;
1958 #endif
1959 }
1960
good() const1961 bool AsyncSocket::good() const {
1962 return (
1963 (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN ||
1964 state_ == StateEnum::ESTABLISHED) &&
1965 (shutdownFlags_ == 0) && (eventBase_ != nullptr));
1966 }
1967
error() const1968 bool AsyncSocket::error() const {
1969 return (state_ == StateEnum::ERROR);
1970 }
1971
/**
 * Attach this (currently detached) socket to `eventBase`.
 *
 * Must be called from `eventBase`'s thread while eventBase_ is null. The call
 * proceeds in this order:
 *  1. attach the I/O handler;
 *  2. refresh event registration;
 *  3. attach the write timeout;
 *  4. notify the evb-change callback (if any) and every lifecycle observer.
 */
void AsyncSocket::attachEventBase(EventBase* eventBase) {
  VLOG(5) << "AsyncSocket::attachEventBase(this=" << this << ", fd=" << fd_
          << ", old evb=" << eventBase_ << ", new evb=" << eventBase
          << ", state=" << state_ << ", events=" << std::hex << eventFlags_
          << ")";
  assert(eventBase_ == nullptr);
  eventBase->dcheckIsInEventBaseThread();

  eventBase_ = eventBase;
  ioHandler_.attachEventBase(eventBase);

  // NOTE(review): presumably re-applies the currently requested event flags
  // on the new loop (detachEventBase() unregistered the handler) — confirm.
  updateEventRegistration();

  writeTimeout_.attachEventBase(eventBase);
  if (evbChangeCb_) {
    evbChangeCb_->evbAttached(this);
  }
  for (const auto& cb : lifecycleObservers_) {
    cb->evbAttach(this, eventBase_);
  }
}
1993
/**
 * Detach this socket from its current EventBase.
 *
 * Must be called on the current EventBase's thread. The call proceeds in this
 * order:
 *  1. clear eventBase_;
 *  2. unregister and detach the I/O handler;
 *  3. detach the write timeout;
 *  4. notify the evb-change callback (if any) and every lifecycle observer,
 *     passing observers the EventBase that was detached.
 */
void AsyncSocket::detachEventBase() {
  VLOG(5) << "AsyncSocket::detachEventBase(this=" << this << ", fd=" << fd_
          << ", old evb=" << eventBase_ << ", state=" << state_
          << ", events=" << std::hex << eventFlags_ << ")";
  assert(eventBase_ != nullptr);
  eventBase_->dcheckIsInEventBaseThread();

  // Make a copy of the existing event base, to invoke lifecycle observer
  // callbacks
  EventBase* existingEvb = eventBase_;

  eventBase_ = nullptr;

  // Stop receiving I/O events before detaching the handler from the loop.
  ioHandler_.unregisterHandler();

  ioHandler_.detachEventBase();
  writeTimeout_.detachEventBase();
  if (evbChangeCb_) {
    evbChangeCb_->evbDetached(this);
  }
  for (const auto& cb : lifecycleObservers_) {
    cb->evbDetach(this, existingEvb);
  }
}
2018
isDetachable() const2019 bool AsyncSocket::isDetachable() const {
2020 DCHECK(eventBase_ != nullptr);
2021 eventBase_->dcheckIsInEventBaseThread();
2022
2023 return !writeTimeout_.isScheduled();
2024 }
2025
cacheAddresses()2026 void AsyncSocket::cacheAddresses() {
2027 if (fd_ != NetworkSocket()) {
2028 try {
2029 cacheLocalAddress();
2030 cachePeerAddress();
2031 } catch (const std::system_error& e) {
2032 if (e.code() !=
2033 std::error_code(ENOTCONN, errorCategoryForErrnoDomain())) {
2034 VLOG(2) << "Error caching addresses: " << e.code().value() << ", "
2035 << e.code().message();
2036 }
2037 }
2038 }
2039 }
2040
cacheLocalAddress() const2041 void AsyncSocket::cacheLocalAddress() const {
2042 if (!localAddr_.isInitialized()) {
2043 localAddr_.setFromLocalAddress(fd_);
2044 }
2045 }
2046
cachePeerAddress() const2047 void AsyncSocket::cachePeerAddress() const {
2048 if (!addr_.isInitialized()) {
2049 addr_.setFromPeerAddress(fd_);
2050 }
2051 }
2052
applyOptions(const SocketOptionMap & options,SocketOptionKey::ApplyPos pos)2053 void AsyncSocket::applyOptions(
2054 const SocketOptionMap& options, SocketOptionKey::ApplyPos pos) {
2055 auto result = applySocketOptions(fd_, options, pos);
2056 if (result != 0) {
2057 throw AsyncSocketException(
2058 AsyncSocketException::INTERNAL_ERROR,
2059 withAddr("failed to set socket option"),
2060 result);
2061 }
2062 }
2063
isZeroCopyWriteInProgress() const2064 bool AsyncSocket::isZeroCopyWriteInProgress() const noexcept {
2065 eventBase_->dcheckIsInEventBaseThread();
2066 return (!idZeroCopyBufPtrMap_.empty());
2067 }
2068
getLocalAddress(folly::SocketAddress * address) const2069 void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
2070 cacheLocalAddress();
2071 *address = localAddr_;
2072 }
2073
getPeerAddress(folly::SocketAddress * address) const2074 void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
2075 cachePeerAddress();
2076 *address = addr_;
2077 }
2078
getTFOSucceded() const2079 bool AsyncSocket::getTFOSucceded() const {
2080 return detail::tfo_succeeded(fd_);
2081 }
2082
setNoDelay(bool noDelay)2083 int AsyncSocket::setNoDelay(bool noDelay) {
2084 if (fd_ == NetworkSocket()) {
2085 VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket " << this
2086 << "(state=" << state_ << ")";
2087 return EINVAL;
2088 }
2089
2090 int value = noDelay ? 1 : 0;
2091 if (netops_->setsockopt(
2092 fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
2093 int errnoCopy = errno;
2094 VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket " << this
2095 << " (fd=" << fd_ << ", state=" << state_
2096 << "): " << errnoStr(errnoCopy);
2097 return errnoCopy;
2098 }
2099
2100 return 0;
2101 }
2102
setCongestionFlavor(const std::string & cname)2103 int AsyncSocket::setCongestionFlavor(const std::string& cname) {
2104 #ifndef TCP_CONGESTION
2105 #define TCP_CONGESTION 13
2106 #endif
2107
2108 if (fd_ == NetworkSocket()) {
2109 VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open "
2110 << "socket " << this << "(state=" << state_ << ")";
2111 return EINVAL;
2112 }
2113
2114 if (netops_->setsockopt(
2115 fd_,
2116 IPPROTO_TCP,
2117 TCP_CONGESTION,
2118 cname.c_str(),
2119 socklen_t(cname.length() + 1)) != 0) {
2120 int errnoCopy = errno;
2121 VLOG(2) << "failed to update TCP_CONGESTION option on AsyncSocket " << this
2122 << "(fd=" << fd_ << ", state=" << state_
2123 << "): " << errnoStr(errnoCopy);
2124 return errnoCopy;
2125 }
2126
2127 return 0;
2128 }
2129
setQuickAck(bool quickack)2130 int AsyncSocket::setQuickAck(bool quickack) {
2131 (void)quickack;
2132 if (fd_ == NetworkSocket()) {
2133 VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket " << this
2134 << "(state=" << state_ << ")";
2135 return EINVAL;
2136 }
2137
2138 #ifdef TCP_QUICKACK // Linux-only
2139 int value = quickack ? 1 : 0;
2140 if (netops_->setsockopt(
2141 fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
2142 int errnoCopy = errno;
2143 VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket" << this
2144 << "(fd=" << fd_ << ", state=" << state_
2145 << "): " << errnoStr(errnoCopy);
2146 return errnoCopy;
2147 }
2148
2149 return 0;
2150 #else
2151 return ENOSYS;
2152 #endif
2153 }
2154
setSendBufSize(size_t bufsize)2155 int AsyncSocket::setSendBufSize(size_t bufsize) {
2156 if (fd_ == NetworkSocket()) {
2157 VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket "
2158 << this << "(state=" << state_ << ")";
2159 return EINVAL;
2160 }
2161
2162 if (netops_->setsockopt(
2163 fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) != 0) {
2164 int errnoCopy = errno;
2165 VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket" << this
2166 << "(fd=" << fd_ << ", state=" << state_
2167 << "): " << errnoStr(errnoCopy);
2168 return errnoCopy;
2169 }
2170
2171 return 0;
2172 }
2173
setRecvBufSize(size_t bufsize)2174 int AsyncSocket::setRecvBufSize(size_t bufsize) {
2175 if (fd_ == NetworkSocket()) {
2176 VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket "
2177 << this << "(state=" << state_ << ")";
2178 return EINVAL;
2179 }
2180
2181 if (netops_->setsockopt(
2182 fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) != 0) {
2183 int errnoCopy = errno;
2184 VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket" << this
2185 << "(fd=" << fd_ << ", state=" << state_
2186 << "): " << errnoStr(errnoCopy);
2187 return errnoCopy;
2188 }
2189
2190 return 0;
2191 }
2192
2193 #if defined(__linux__)
getSendBufInUse() const2194 size_t AsyncSocket::getSendBufInUse() const {
2195 if (fd_ == NetworkSocket()) {
2196 std::stringstream issueString;
2197 issueString << "AsyncSocket::getSendBufInUse() called on non-open socket "
2198 << this << "(state=" << state_ << ")";
2199 VLOG(4) << issueString.str();
2200 throw std::logic_error(issueString.str());
2201 }
2202
2203 size_t returnValue = 0;
2204 if (-1 == ::ioctl(fd_.toFd(), SIOCOUTQ, &returnValue)) {
2205 int errnoCopy = errno;
2206 std::stringstream issueString;
2207 issueString << "Failed to get the tx used bytes on Socket: " << this
2208 << "(fd=" << fd_ << ", state=" << state_
2209 << "): " << errnoStr(errnoCopy);
2210 VLOG(2) << issueString.str();
2211 throw std::logic_error(issueString.str());
2212 }
2213
2214 return returnValue;
2215 }
2216
getRecvBufInUse() const2217 size_t AsyncSocket::getRecvBufInUse() const {
2218 if (fd_ == NetworkSocket()) {
2219 std::stringstream issueString;
2220 issueString << "AsyncSocket::getRecvBufInUse() called on non-open socket "
2221 << this << "(state=" << state_ << ")";
2222 VLOG(4) << issueString.str();
2223 throw std::logic_error(issueString.str());
2224 }
2225
2226 size_t returnValue = 0;
2227 if (-1 == ::ioctl(fd_.toFd(), SIOCINQ, &returnValue)) {
2228 std::stringstream issueString;
2229 int errnoCopy = errno;
2230 issueString << "Failed to get the rx used bytes on Socket: " << this
2231 << "(fd=" << fd_ << ", state=" << state_
2232 << "): " << errnoStr(errnoCopy);
2233 VLOG(2) << issueString.str();
2234 throw std::logic_error(issueString.str());
2235 }
2236
2237 return returnValue;
2238 }
2239 #endif
2240
setTCPProfile(int profd)2241 int AsyncSocket::setTCPProfile(int profd) {
2242 if (fd_ == NetworkSocket()) {
2243 VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket " << this
2244 << "(state=" << state_ << ")";
2245 return EINVAL;
2246 }
2247
2248 if (netops_->setsockopt(
2249 fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) != 0) {
2250 int errnoCopy = errno;
2251 VLOG(2) << "failed to set socket namespace option on AsyncSocket" << this
2252 << "(fd=" << fd_ << ", state=" << state_
2253 << "): " << errnoStr(errnoCopy);
2254 return errnoCopy;
2255 }
2256
2257 return 0;
2258 }
2259
ioReady(uint16_t events)2260 void AsyncSocket::ioReady(uint16_t events) noexcept {
2261 VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd=" << fd_
2262 << ", events=" << std::hex << events << ", state=" << state_;
2263 DestructorGuard dg(this);
2264 assert(events & EventHandler::READ_WRITE);
2265 eventBase_->dcheckIsInEventBaseThread();
2266
2267 auto relevantEvents = uint16_t(events & EventHandler::READ_WRITE);
2268 EventBase* originalEventBase = eventBase_;
2269 // If we got there it means that either EventHandler::READ or
2270 // EventHandler::WRITE is set. Any of these flags can
2271 // indicate that there are messages available in the socket
2272 // error message queue.
2273 // Return if we handle any error messages - this is to avoid
2274 // unnecessary read/write calls
2275 if (handleErrMessages()) {
2276 return;
2277 }
2278
2279 // Return now if handleErrMessages() detached us from our EventBase
2280 if (eventBase_ != originalEventBase) {
2281 return;
2282 }
2283
2284 if (relevantEvents == EventHandler::READ) {
2285 handleRead();
2286 } else if (relevantEvents == EventHandler::WRITE) {
2287 handleWrite();
2288 } else if (relevantEvents == EventHandler::READ_WRITE) {
2289 // If both read and write events are ready, process writes first.
2290 handleWrite();
2291
2292 // Return now if handleWrite() detached us from our EventBase
2293 if (eventBase_ != originalEventBase) {
2294 return;
2295 }
2296
2297 // Only call handleRead() if a read callback is still installed.
2298 // (It's possible that the read callback was uninstalled during
2299 // handleWrite().)
2300 if (readCallback_) {
2301 handleRead();
2302 }
2303 } else {
2304 VLOG(4) << "AsyncSocket::ioRead() called with unexpected events "
2305 << std::hex << events << "(this=" << this << ")";
2306 abort();
2307 }
2308 }
2309
performRead(void ** buf,size_t * buflen,size_t *)2310 AsyncSocket::ReadResult AsyncSocket::performRead(
2311 void** buf, size_t* buflen, size_t* /* offset */) {
2312 struct iovec iov;
2313
2314 // Data buffer pointer and length
2315 iov.iov_base = *buf;
2316 iov.iov_len = *buflen;
2317
2318 return performReadInternal(&iov, 1);
2319 }
2320
performReadv(struct iovec * iovs,size_t num)2321 AsyncSocket::ReadResult AsyncSocket::performReadv(
2322 struct iovec* iovs, size_t num) {
2323 return performReadInternal(iovs, num);
2324 }
2325
performReadInternal(struct iovec * iovs,size_t num)2326 AsyncSocket::ReadResult AsyncSocket::performReadInternal(
2327 struct iovec* iovs, size_t num) {
2328 VLOG(5) << "AsyncSocket::performReadInternal() this=" << this
2329 << ", iovs=" << iovs << ", num=" << num;
2330
2331 if (!num) {
2332 return ReadResult(READ_ERROR);
2333 }
2334
2335 if (preReceivedData_ && !preReceivedData_->empty()) {
2336 VLOG(5) << "AsyncSocket::performReadInternal() this=" << this
2337 << ", reading pre-received data";
2338
2339 ssize_t len = 0;
2340 for (size_t i = 0; (i < num) && (!preReceivedData_->empty()); ++i) {
2341 io::Cursor cursor(preReceivedData_.get());
2342 auto ret = cursor.pullAtMost(iovs[i].iov_base, iovs[i].iov_len);
2343 len += ret;
2344
2345 IOBufQueue queue;
2346 queue.append(std::move(preReceivedData_));
2347 queue.trimStart(ret);
2348 preReceivedData_ = queue.move();
2349 }
2350
2351 appBytesReceived_ += len;
2352 return ReadResult(len);
2353 }
2354
2355 ssize_t bytes = 0;
2356
2357 struct msghdr msg;
2358
2359 if (readAncillaryDataCallback_ == nullptr && num == 1) {
2360 bytes = netops_->recv(fd_, iovs[0].iov_base, iovs[0].iov_len, MSG_DONTWAIT);
2361 } else {
2362 if (readAncillaryDataCallback_) {
2363 // Ancillary data buffer and length
2364 msg.msg_control =
2365 readAncillaryDataCallback_->getAncillaryDataCtrlBuffer().data();
2366 msg.msg_controllen =
2367 readAncillaryDataCallback_->getAncillaryDataCtrlBuffer().size();
2368 } else {
2369 msg.msg_control = nullptr;
2370 msg.msg_controllen = 0;
2371 }
2372
2373 // Dest address info
2374 msg.msg_name = nullptr;
2375 msg.msg_namelen = 0;
2376
2377 // Array of data buffers (scatter/gather)
2378 msg.msg_iov = iovs;
2379 msg.msg_iovlen = num;
2380
2381 bytes = netops::recvmsg(fd_, &msg, 0);
2382 }
2383
2384 if (readAncillaryDataCallback_ && (bytes > 0)) {
2385 readAncillaryDataCallback_->ancillaryData(msg);
2386 }
2387
2388 if (bytes < 0) {
2389 if (errno == EAGAIN || errno == EWOULDBLOCK) {
2390 // No more data to read right now.
2391 return ReadResult(READ_BLOCKING);
2392 } else {
2393 return ReadResult(READ_ERROR);
2394 }
2395 } else {
2396 appBytesReceived_ += bytes;
2397 return ReadResult(bytes);
2398 }
2399 }
2400
prepareReadBuffer(void ** buf,size_t * buflen)2401 void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
2402 // no matter what, buffer should be prepared for non-ssl socket
2403 CHECK(readCallback_);
2404 readCallback_->getReadBuffer(buf, buflen);
2405 }
2406
prepareReadBuffers(IOBufIovecBuilder::IoVecVec & iovs)2407 void AsyncSocket::prepareReadBuffers(IOBufIovecBuilder::IoVecVec& iovs) {
2408 // no matter what, buffers should be prepared for non-ssl socket
2409 CHECK(readCallback_);
2410 readCallback_->getReadBuffers(iovs);
2411 }
2412
handleErrMessages()2413 size_t AsyncSocket::handleErrMessages() noexcept {
2414 // This method has non-empty implementation only for platforms
2415 // supporting per-socket error queues.
2416 VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
2417 << ", state=" << state_;
2418 if (errMessageCallback_ == nullptr && idZeroCopyBufPtrMap_.empty() &&
2419 (!byteEventHelper_ || !byteEventHelper_->byteEventsEnabled)) {
2420 VLOG(7) << "AsyncSocket::handleErrMessages(): "
2421 << "no err message callback installed and "
2422 << "ByteEvents not enabled - exiting.";
2423 return 0;
2424 }
2425
2426 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
2427 uint8_t ctrl[1024];
2428 unsigned char data;
2429 struct msghdr msg;
2430 iovec entry;
2431
2432 entry.iov_base = &data;
2433 entry.iov_len = sizeof(data);
2434 msg.msg_iov = &entry;
2435 msg.msg_iovlen = 1;
2436 msg.msg_name = nullptr;
2437 msg.msg_namelen = 0;
2438 msg.msg_control = ctrl;
2439 msg.msg_controllen = sizeof(ctrl);
2440 msg.msg_flags = 0;
2441
2442 int ret;
2443 size_t num = 0;
2444 // the socket may be closed by errMessage callback, so check on each iteration
2445 while (fd_ != NetworkSocket()) {
2446 ret = netops_->recvmsg(fd_, &msg, MSG_ERRQUEUE);
2447 VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
2448
2449 if (ret < 0) {
2450 if (errno != EAGAIN) {
2451 auto errnoCopy = errno;
2452 LOG(ERROR) << "::recvmsg exited with code " << ret
2453 << ", errno: " << errnoCopy << ", fd: " << fd_;
2454 AsyncSocketException ex(
2455 AsyncSocketException::INTERNAL_ERROR,
2456 withAddr("recvmsg() failed"),
2457 errnoCopy);
2458 failErrMessageRead(__func__, ex);
2459 }
2460
2461 return num;
2462 }
2463
2464 for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
2465 cmsg != nullptr && cmsg->cmsg_len != 0;
2466 cmsg = CMSG_NXTHDR(&msg, cmsg)) {
2467 ++num;
2468 if (isZeroCopyMsg(*cmsg)) {
2469 processZeroCopyMsg(*cmsg);
2470 continue;
2471 }
2472
2473 // try to process it as a ByteEvent and forward to observers
2474 //
2475 // observers cannot throw and thus we expect only exceptions from
2476 // ByteEventHelper, but we guard against other cases for safety
2477 if (byteEventHelper_) {
2478 try {
2479 if (const auto maybeByteEvent =
2480 byteEventHelper_->processCmsg(*cmsg, getRawBytesWritten())) {
2481 const auto& byteEvent = maybeByteEvent.value();
2482 for (const auto& observer : lifecycleObservers_) {
2483 if (observer->getConfig().byteEvents) {
2484 observer->byteEvent(this, byteEvent);
2485 }
2486 }
2487 }
2488 } catch (const ByteEventHelper::Exception& behEx) {
2489 // rewrap the ByteEventHelper::Exception with extra information
2490 AsyncSocketException ex(
2491 AsyncSocketException::INTERNAL_ERROR,
2492 withAddr(
2493 string("AsyncSocket::handleErrMessages(), "
2494 "internal exception during ByteEvent processing: ") +
2495 behEx.what()));
2496 failByteEvents(ex);
2497 } catch (const std::exception& ex) {
2498 AsyncSocketException tex(
2499 AsyncSocketException::UNKNOWN,
2500 string("AsyncSocket::handleErrMessages(), "
2501 "unhandled exception during ByteEvent processing, "
2502 "threw exception: ") +
2503 ex.what());
2504 failByteEvents(tex);
2505 } catch (...) {
2506 AsyncSocketException tex(
2507 AsyncSocketException::UNKNOWN,
2508 string("AsyncSocket::handleErrMessages(), "
2509 "unhandled exception during ByteEvent processing, "
2510 "threw non-exception type"));
2511 failByteEvents(tex);
2512 }
2513 }
2514
2515 // even if it is a timestamp, hand it off to the errMessageCallback,
2516 // the application may want it as well.
2517 if (errMessageCallback_) {
2518 errMessageCallback_->errMessage(*cmsg);
2519 }
2520 }
2521 }
2522 return num;
2523 #else
2524 return 0;
2525 #endif // FOLLY_HAVE_MSG_ERRQUEUE
2526 }
2527
processZeroCopyWriteInProgress()2528 bool AsyncSocket::processZeroCopyWriteInProgress() noexcept {
2529 eventBase_->dcheckIsInEventBaseThread();
2530 if (idZeroCopyBufPtrMap_.empty()) {
2531 return true;
2532 }
2533
2534 handleErrMessages();
2535
2536 return idZeroCopyBufPtrMap_.empty();
2537 }
2538
addLifecycleObserver(AsyncTransport::LifecycleObserver * observer)2539 void AsyncSocket::addLifecycleObserver(
2540 AsyncTransport::LifecycleObserver* observer) {
2541 if (eventBase_) {
2542 eventBase_->dcheckIsInEventBaseThread();
2543 }
2544
2545 // adding the same observer multiple times is not allowed
2546 auto& observers = lifecycleObservers_;
2547 CHECK(
2548 std::find(observers.begin(), observers.end(), observer) ==
2549 observers.end());
2550
2551 observers.push_back(observer);
2552 observer->observerAttach(this);
2553 if (observer->getConfig().byteEvents) {
2554 if (byteEventHelper_ && byteEventHelper_->maybeEx.has_value()) {
2555 observer->byteEventsUnavailable(this, *byteEventHelper_->maybeEx);
2556 } else if (byteEventHelper_ && byteEventHelper_->byteEventsEnabled) {
2557 observer->byteEventsEnabled(this);
2558 } else if (state_ == StateEnum::ESTABLISHED) {
2559 enableByteEvents(); // try to enable now
2560 }
2561 // do nothing right now; wait until we're connected
2562 }
2563 }
2564
removeLifecycleObserver(AsyncTransport::LifecycleObserver * observer)2565 bool AsyncSocket::removeLifecycleObserver(
2566 AsyncTransport::LifecycleObserver* observer) {
2567 auto& observers = lifecycleObservers_;
2568 auto it = std::find(observers.begin(), observers.end(), observer);
2569 if (it == observers.end()) {
2570 return false;
2571 }
2572 observer->observerDetach(this);
2573 observers.erase(it);
2574 return true;
2575 }
2576
2577 std::vector<AsyncTransport::LifecycleObserver*>
getLifecycleObservers() const2578 AsyncSocket::getLifecycleObservers() const {
2579 if (eventBase_) {
2580 eventBase_->dcheckIsInEventBaseThread();
2581 }
2582 return std::vector<AsyncTransport::LifecycleObserver*>(
2583 lifecycleObservers_.begin(), lifecycleObservers_.end());
2584 }
2585
splitIovecArray(const size_t startOffset,const size_t endOffset,const iovec * srcVec,const size_t srcCount,iovec * dstVec,size_t & dstCount)2586 void AsyncSocket::splitIovecArray(
2587 const size_t startOffset,
2588 const size_t endOffset,
2589 const iovec* srcVec,
2590 const size_t srcCount,
2591 iovec* dstVec,
2592 size_t& dstCount) {
2593 CHECK_GE(endOffset, startOffset);
2594 CHECK_GE(dstCount, srcCount);
2595 dstCount = 0;
2596
2597 const size_t targetBytes = endOffset - startOffset + 1;
2598 size_t dstBytes = 0;
2599 size_t processedBytes = 0;
2600 for (size_t i = 0; i < srcCount; processedBytes += srcVec[i].iov_len, i++) {
2601 iovec currentOp = srcVec[i];
2602 if (currentOp.iov_len == 0) { // to handle the oddballs
2603 continue;
2604 }
2605
2606 // if we haven't found the start offset yet, see if it is in this op
2607 if (dstCount == 0) {
2608 if (processedBytes + currentOp.iov_len < startOffset + 1) {
2609 continue; // start offset isn't in this op
2610 }
2611
2612 // offset iov_base to get the start offset
2613 const size_t trimFromStart = startOffset - processedBytes;
2614 currentOp.iov_base =
2615 reinterpret_cast<uint8_t*>(currentOp.iov_base) + trimFromStart;
2616 currentOp.iov_len -= trimFromStart;
2617 }
2618
2619 // trim the end of the iovec, if needed
2620 ssize_t trimFromEnd = (dstBytes + currentOp.iov_len) - targetBytes;
2621 if (trimFromEnd > 0) {
2622 currentOp.iov_len -= trimFromEnd;
2623 }
2624
2625 dstVec[dstCount] = currentOp;
2626 dstCount++;
2627 dstBytes += currentOp.iov_len;
2628 CHECK_GE(targetBytes, dstBytes);
2629 if (targetBytes == dstBytes) {
2630 break; // done
2631 }
2632 }
2633
2634 CHECK_EQ(targetBytes, dstBytes);
2635 }
2636
handleRead()2637 void AsyncSocket::handleRead() noexcept {
2638 VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
2639 << ", state=" << state_;
2640 assert(state_ == StateEnum::ESTABLISHED);
2641 assert((shutdownFlags_ & SHUT_READ) == 0);
2642 assert(readCallback_ != nullptr);
2643 assert(eventFlags_ & EventHandler::READ);
2644
2645 // Loop until:
2646 // - a read attempt would block
2647 // - readCallback_ is uninstalled
2648 // - the number of loop iterations exceeds the optional maximum
2649 // - this AsyncSocket is moved to another EventBase
2650 //
2651 // When we invoke readDataAvailable() it may uninstall the readCallback_,
2652 // which is why need to check for it here.
2653 //
2654 // The last bullet point is slightly subtle. readDataAvailable() may also
2655 // detach this socket from this EventBase. However, before
2656 // readDataAvailable() returns another thread may pick it up, attach it to
2657 // a different EventBase, and install another readCallback_. We need to
2658 // exit immediately after readDataAvailable() returns if the eventBase_ has
2659 // changed. (The caller must perform some sort of locking to transfer the
2660 // AsyncSocket between threads properly. This will be sufficient to ensure
2661 // that this thread sees the updated eventBase_ variable after
2662 // readDataAvailable() returns.)
2663 size_t numReads = maxReadsPerEvent_ ? maxReadsPerEvent_ : size_t(-1);
2664 EventBase* originalEventBase = eventBase_;
2665 while (readCallback_ && eventBase_ == originalEventBase && numReads--) {
2666 auto readMode = readCallback_->getReadMode();
2667 // Get the buffer(s) to read into.
2668 void* buf = nullptr;
2669 size_t buflen = 0, offset = 0, num = 0;
2670 IOBufIovecBuilder::IoVecVec iovs; // this can be an Asyncsocket member too
2671
2672 try {
2673 if (readMode == AsyncReader::ReadCallback::ReadMode::ReadVec) {
2674 prepareReadBuffers(iovs);
2675 num = iovs.size();
2676 VLOG(5) << "prepareReadBuffers() bufs=" << iovs.data()
2677 << ", num=" << num;
2678 } else {
2679 prepareReadBuffer(&buf, &buflen);
2680 VLOG(5) << "prepareReadBuffer() buf=" << buf << ", buflen=" << buflen;
2681 }
2682 } catch (const AsyncSocketException& ex) {
2683 return failRead(__func__, ex);
2684 } catch (const std::exception& ex) {
2685 AsyncSocketException tex(
2686 AsyncSocketException::BAD_ARGS,
2687 string("ReadCallback::getReadBuffer() "
2688 "threw exception: ") +
2689 ex.what());
2690 return failRead(__func__, tex);
2691 } catch (...) {
2692 AsyncSocketException ex(
2693 AsyncSocketException::BAD_ARGS,
2694 "ReadCallback::getReadBuffer() threw "
2695 "non-exception type");
2696 return failRead(__func__, ex);
2697 }
2698 if ((num == 0) && (buf == nullptr || buflen == 0)) {
2699 AsyncSocketException ex(
2700 AsyncSocketException::BAD_ARGS,
2701 "ReadCallback::getReadBuffer() returned "
2702 "empty buffer");
2703 return failRead(__func__, ex);
2704 }
2705
2706 // Perform the read
2707 auto readResult = (readMode == AsyncReader::ReadCallback::ReadMode::ReadVec)
2708 ? performReadv(iovs.data(), num)
2709 : performRead(&buf, &buflen, &offset);
2710 auto bytesRead = readResult.readReturn;
2711 VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got "
2712 << bytesRead << " bytes";
2713 if (bytesRead > 0) {
2714 readCallback_->readDataAvailable(size_t(bytesRead));
2715
2716 // Fall through and continue around the loop if the read
2717 // completely filled the available buffer.
2718 // Note that readCallback_ may have been uninstalled or changed inside
2719 // readDataAvailable().
2720 if (size_t(bytesRead) < buflen) {
2721 return;
2722 }
2723 } else if (bytesRead == READ_BLOCKING) {
2724 // No more data to read right now.
2725 return;
2726 } else if (bytesRead == READ_ERROR) {
2727 readErr_ = READ_ERROR;
2728 if (readResult.exception) {
2729 return failRead(__func__, *readResult.exception);
2730 }
2731 auto errnoCopy = errno;
2732 AsyncSocketException ex(
2733 AsyncSocketException::INTERNAL_ERROR,
2734 withAddr("recv() failed"),
2735 errnoCopy);
2736 return failRead(__func__, ex);
2737 } else {
2738 assert(bytesRead == READ_EOF);
2739 readErr_ = READ_EOF;
2740 // EOF
2741 shutdownFlags_ |= SHUT_READ;
2742 if (!updateEventRegistration(0, EventHandler::READ)) {
2743 // we've already been moved into STATE_ERROR
2744 assert(state_ == StateEnum::ERROR);
2745 assert(readCallback_ == nullptr);
2746 return;
2747 }
2748
2749 ReadCallback* callback = readCallback_;
2750 readCallback_ = nullptr;
2751 callback->readEOF();
2752 return;
2753 }
2754 }
2755
2756 if (readCallback_ && eventBase_ == originalEventBase) {
2757 // We might still have data in the socket.
2758 // (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
2759 scheduleImmediateRead();
2760 }
2761 }
2762
2763 /**
2764 * This function attempts to write as much data as possible, until no more data
2765 * can be written.
2766 *
2767 * - If it sends all available data, it unregisters for write events, and stops
2768 * the writeTimeout_.
2769 *
2770 * - If not all of the data can be sent immediately, it reschedules
2771 * writeTimeout_ (if a non-zero timeout is set), and ensures the handler is
2772 * registered for write events.
2773 */
handleWrite()2774 void AsyncSocket::handleWrite() noexcept {
2775 VLOG(5) << "AsyncSocket::handleWrite() this=" << this << ", fd=" << fd_
2776 << ", state=" << state_;
2777 DestructorGuard dg(this);
2778
2779 if (state_ == StateEnum::CONNECTING) {
2780 handleConnect();
2781 return;
2782 }
2783
2784 // Normal write
2785 assert(state_ == StateEnum::ESTABLISHED);
2786 assert((shutdownFlags_ & SHUT_WRITE) == 0);
2787 assert(writeReqHead_ != nullptr);
2788
2789 // Loop until we run out of write requests,
2790 // or until this socket is moved to another EventBase.
2791 // (See the comment in handleRead() explaining how this can happen.)
2792 EventBase* originalEventBase = eventBase_;
2793 while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
2794 auto writeResult = writeReqHead_->performWrite();
2795 if (writeResult.writeReturn < 0) {
2796 if (writeResult.exception) {
2797 return failWrite(__func__, *writeResult.exception);
2798 }
2799 auto errnoCopy = errno;
2800 AsyncSocketException ex(
2801 AsyncSocketException::INTERNAL_ERROR,
2802 withAddr("writev() failed"),
2803 errnoCopy);
2804 return failWrite(__func__, ex);
2805 } else if (writeReqHead_->isComplete()) {
2806 // We finished this request
2807 WriteRequest* req = writeReqHead_;
2808 writeReqHead_ = req->getNext();
2809
2810 if (writeReqHead_ == nullptr) {
2811 writeReqTail_ = nullptr;
2812 // This is the last write request.
2813 // Unregister for write events and cancel the send timer
2814 // before we invoke the callback. We have to update the state properly
2815 // before calling the callback, since it may want to detach us from
2816 // the EventBase.
2817 if (eventFlags_ & EventHandler::WRITE) {
2818 if (!updateEventRegistration(0, EventHandler::WRITE)) {
2819 assert(state_ == StateEnum::ERROR);
2820 return;
2821 }
2822 // Stop the send timeout
2823 writeTimeout_.cancelTimeout();
2824 }
2825 assert(!writeTimeout_.isScheduled());
2826
2827 // If SHUT_WRITE_PENDING is set, we should shutdown the socket after
2828 // we finish sending the last write request.
2829 //
2830 // We have to do this before invoking writeSuccess(), since
2831 // writeSuccess() may detach us from our EventBase.
2832 if (shutdownFlags_ & SHUT_WRITE_PENDING) {
2833 assert(connectCallback_ == nullptr);
2834 shutdownFlags_ |= SHUT_WRITE;
2835
2836 if (shutdownFlags_ & SHUT_READ) {
2837 // Reads have already been shutdown. Fully close the socket and
2838 // move to STATE_CLOSED.
2839 //
2840 // Note: This code currently moves us to STATE_CLOSED even if
2841 // close() hasn't ever been called. This can occur if we have
2842 // received EOF from the peer and shutdownWrite() has been called
2843 // locally. Should we bother staying in STATE_ESTABLISHED in this
2844 // case, until close() is actually called? I can't think of a
2845 // reason why we would need to do so. No other operations besides
2846 // calling close() or destroying the socket can be performed at
2847 // this point.
2848 assert(readCallback_ == nullptr);
2849 state_ = StateEnum::CLOSED;
2850 if (fd_ != NetworkSocket()) {
2851 ioHandler_.changeHandlerFD(NetworkSocket());
2852 doClose();
2853 }
2854 } else {
2855 // Reads are still enabled, so we are only doing a half-shutdown
2856 netops_->shutdown(fd_, SHUT_WR);
2857 }
2858 }
2859 }
2860
2861 // Invoke the callback
2862 WriteCallback* callback = req->getCallback();
2863 req->destroy();
2864 if (callback) {
2865 callback->writeSuccess();
2866 }
2867 // We'll continue around the loop, trying to write another request
2868 } else {
2869 // Partial write.
2870 writeReqHead_->consume();
2871 if (bufferCallback_) {
2872 bufferCallback_->onEgressBuffered();
2873 }
2874 // Stop after a partial write; it's highly likely that a subsequent write
2875 // attempt will just return EAGAIN.
2876 //
2877 // Ensure that we are registered for write events.
2878 if ((eventFlags_ & EventHandler::WRITE) == 0) {
2879 if (!updateEventRegistration(EventHandler::WRITE, 0)) {
2880 assert(state_ == StateEnum::ERROR);
2881 return;
2882 }
2883 }
2884
2885 // Reschedule the send timeout, since we have made some write progress.
2886 if (sendTimeout_ > 0) {
2887 if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
2888 AsyncSocketException ex(
2889 AsyncSocketException::INTERNAL_ERROR,
2890 withAddr("failed to reschedule write timeout"));
2891 return failWrite(__func__, ex);
2892 }
2893 }
2894 return;
2895 }
2896 }
2897 if (!writeReqHead_ && bufferCallback_) {
2898 bufferCallback_->onEgressBufferCleared();
2899 }
2900 }
2901
checkForImmediateRead()2902 void AsyncSocket::checkForImmediateRead() noexcept {
2903 // We currently don't attempt to perform optimistic reads in AsyncSocket.
2904 // (However, note that some subclasses do override this method.)
2905 //
2906 // Simply calling handleRead() here would be bad, as this would call
2907 // readCallback_->getReadBuffer(), forcing the callback to allocate a read
2908 // buffer even though no data may be available. This would waste lots of
2909 // memory, since the buffer will sit around unused until the socket actually
2910 // becomes readable.
2911 //
2912 // Checking if the socket is readable now also seems like it would probably
2913 // be a pessimism. In most cases it probably wouldn't be readable, and we
2914 // would just waste an extra system call. Even if it is readable, waiting to
2915 // find out from libevent on the next event loop doesn't seem that bad.
2916 //
2917 // The exception to this is if we have pre-received data. In that case there
2918 // is definitely data available immediately.
2919 if (preReceivedData_ && !preReceivedData_->empty()) {
2920 handleRead();
2921 }
2922 }
2923
handleInitialReadWrite()2924 void AsyncSocket::handleInitialReadWrite() noexcept {
2925 // Our callers should already be holding a DestructorGuard, but grab
2926 // one here just to make sure, in case one of our calling code paths ever
2927 // changes.
2928 DestructorGuard dg(this);
2929 // If we have a readCallback_, make sure we enable read events. We
2930 // may already be registered for reads if connectSuccess() set
2931 // the read calback.
2932 if (readCallback_ && !(eventFlags_ & EventHandler::READ)) {
2933 assert(state_ == StateEnum::ESTABLISHED);
2934 assert((shutdownFlags_ & SHUT_READ) == 0);
2935 if (!updateEventRegistration(EventHandler::READ, 0)) {
2936 assert(state_ == StateEnum::ERROR);
2937 return;
2938 }
2939 checkForImmediateRead();
2940 } else if (readCallback_ == nullptr) {
2941 // Unregister for read events.
2942 updateEventRegistration(0, EventHandler::READ);
2943 }
2944
2945 // If we have write requests pending, try to send them immediately.
2946 // Since we just finished accepting, there is a very good chance that we can
2947 // write without blocking.
2948 //
2949 // However, we only process them if EventHandler::WRITE is not already set,
2950 // which means that we're already blocked on a write attempt. (This can
2951 // happen if connectSuccess() called write() before returning.)
2952 if (writeReqHead_ && !(eventFlags_ & EventHandler::WRITE)) {
2953 // Call handleWrite() to perform write processing.
2954 handleWrite();
2955 } else if (writeReqHead_ == nullptr) {
2956 // Unregister for write event.
2957 updateEventRegistration(0, EventHandler::WRITE);
2958 }
2959 }
2960
handleConnect()2961 void AsyncSocket::handleConnect() noexcept {
2962 VLOG(5) << "AsyncSocket::handleConnect() this=" << this << ", fd=" << fd_
2963 << ", state=" << state_;
2964 assert(state_ == StateEnum::CONNECTING);
2965 // SHUT_WRITE can never be set while we are still connecting;
2966 // SHUT_WRITE_PENDING may be set, be we only set SHUT_WRITE once the connect
2967 // finishes
2968 assert((shutdownFlags_ & SHUT_WRITE) == 0);
2969
2970 // In case we had a connect timeout, cancel the timeout
2971 writeTimeout_.cancelTimeout();
2972 // We don't use a persistent registration when waiting on a connect event,
2973 // so we have been automatically unregistered now. Update eventFlags_ to
2974 // reflect reality.
2975 assert(eventFlags_ == EventHandler::WRITE);
2976 eventFlags_ = EventHandler::NONE;
2977
2978 // Call getsockopt() to check if the connect succeeded
2979 int error;
2980 socklen_t len = sizeof(error);
2981 int rv = netops_->getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
2982 if (rv != 0) {
2983 auto errnoCopy = errno;
2984 AsyncSocketException ex(
2985 AsyncSocketException::INTERNAL_ERROR,
2986 withAddr("error calling getsockopt() after connect"),
2987 errnoCopy);
2988 VLOG(4) << "AsyncSocket::handleConnect(this=" << this << ", fd=" << fd_
2989 << " host=" << addr_.describe() << ") exception:" << ex.what();
2990 return failConnect(__func__, ex);
2991 }
2992
2993 if (error != 0) {
2994 AsyncSocketException ex(
2995 AsyncSocketException::NOT_OPEN, "connect failed", error);
2996 VLOG(2) << "AsyncSocket::handleConnect(this=" << this << ", fd=" << fd_
2997 << " host=" << addr_.describe() << ") exception: " << ex.what();
2998 return failConnect(__func__, ex);
2999 }
3000
3001 // Move into STATE_ESTABLISHED
3002 state_ = StateEnum::ESTABLISHED;
3003
3004 // If SHUT_WRITE_PENDING is set and we don't have any write requests to
3005 // perform, immediately shutdown the write half of the socket.
3006 if ((shutdownFlags_ & SHUT_WRITE_PENDING) && writeReqHead_ == nullptr) {
3007 // SHUT_READ shouldn't be set. If close() is called on the socket while we
3008 // are still connecting we just abort the connect rather than waiting for
3009 // it to complete.
3010 assert((shutdownFlags_ & SHUT_READ) == 0);
3011 netops_->shutdown(fd_, SHUT_WR);
3012 shutdownFlags_ |= SHUT_WRITE;
3013 }
3014
3015 VLOG(7) << "AsyncSocket " << this << ": fd " << fd_
3016 << "successfully connected; state=" << state_;
3017
3018 // Remember the EventBase we are attached to, before we start invoking any
3019 // callbacks (since the callbacks may call detachEventBase()).
3020 EventBase* originalEventBase = eventBase_;
3021
3022 invokeConnectSuccess();
3023 // Note that the connect callback may have changed our state.
3024 // (set or unset the read callback, called write(), closed the socket, etc.)
3025 // The following code needs to handle these situations correctly.
3026 //
3027 // If the socket has been closed, readCallback_ and writeReqHead_ will
3028 // always be nullptr, so that will prevent us from trying to read or write.
3029 //
3030 // The main thing to check for is if eventBase_ is still originalEventBase.
3031 // If not, we have been detached from this event base, so we shouldn't
3032 // perform any more operations.
3033 if (eventBase_ != originalEventBase) {
3034 return;
3035 }
3036
3037 handleInitialReadWrite();
3038 }
3039
timeoutExpired()3040 void AsyncSocket::timeoutExpired() noexcept {
3041 VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
3042 << "state=" << state_ << ", events=" << std::hex << eventFlags_;
3043 DestructorGuard dg(this);
3044 eventBase_->dcheckIsInEventBaseThread();
3045
3046 if (state_ == StateEnum::CONNECTING) {
3047 // connect() timed out
3048 // Unregister for I/O events.
3049 if (connectCallback_) {
3050 AsyncSocketException ex(
3051 AsyncSocketException::TIMED_OUT,
3052 folly::sformat(
3053 "connect timed out after {}ms", connectTimeout_.count()));
3054 failConnect(__func__, ex);
3055 } else {
3056 // we faced a connect error without a connect callback, which could
3057 // happen due to TFO.
3058 AsyncSocketException ex(
3059 AsyncSocketException::TIMED_OUT, "write timed out during connection");
3060 failWrite(__func__, ex);
3061 }
3062 } else {
3063 // a normal write operation timed out
3064 AsyncSocketException ex(
3065 AsyncSocketException::TIMED_OUT,
3066 folly::sformat("write timed out after {}ms", sendTimeout_));
3067 failWrite(__func__, ex);
3068 }
3069 }
3070
handleNetworkSocketAttached()3071 void AsyncSocket::handleNetworkSocketAttached() {
3072 VLOG(6) << "AsyncSocket::attachFd(this=" << this << ", fd=" << fd_
3073 << ", evb=" << eventBase_ << " , state=" << state_
3074 << ", events=" << std::hex << eventFlags_ << ")";
3075 for (const auto& cb : lifecycleObservers_) {
3076 if (auto dCb = dynamic_cast<AsyncSocket::LifecycleObserver*>(cb)) {
3077 dCb->fdAttach(this);
3078 }
3079 }
3080
3081 if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
3082 shutdownSocketSet->add(fd_);
3083 }
3084 ioHandler_.changeHandlerFD(fd_);
3085 }
3086
tfoSendMsg(NetworkSocket fd,struct msghdr * msg,int msg_flags)3087 ssize_t AsyncSocket::tfoSendMsg(
3088 NetworkSocket fd, struct msghdr* msg, int msg_flags) {
3089 return detail::tfo_sendmsg(fd, msg, msg_flags);
3090 }
3091
sendSocketMessage(const iovec * vec,size_t count,WriteFlags flags)3092 AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
3093 const iovec* vec, size_t count, WriteFlags flags) {
3094 // lambda to gather and merge PrewriteRequests from observers
3095 auto mergePrewriteRequests = [this,
3096 vec,
3097 count,
3098 flags,
3099 maybeVecTotalBytes =
3100 folly::Optional<size_t>()]() mutable {
3101 AsyncTransport::LifecycleObserver::PrewriteRequest mergedRequest = {};
3102 if (lifecycleObservers_.empty()) {
3103 return mergedRequest;
3104 }
3105
3106 // determine total number of bytes in vec, reuse once determined
3107 if (!maybeVecTotalBytes.has_value()) {
3108 maybeVecTotalBytes = 0;
3109 for (size_t i = 0; i < count; ++i) {
3110 maybeVecTotalBytes.value() += vec[i].iov_len;
3111 }
3112 }
3113 auto& vecTotalBytes = maybeVecTotalBytes.value();
3114
3115 const auto startOffset = getRawBytesWritten();
3116 const auto endOffset = getRawBytesWritten() + vecTotalBytes - 1;
3117 const AsyncTransport::LifecycleObserver::PrewriteState prewriteState = [&] {
3118 AsyncTransport::LifecycleObserver::PrewriteState state = {};
3119 state.startOffset = startOffset;
3120 state.endOffset = endOffset;
3121 state.writeFlags = flags;
3122 state.ts = std::chrono::steady_clock::now();
3123 return state;
3124 }();
3125 for (const auto& observer : lifecycleObservers_) {
3126 if (!observer->getConfig().prewrite) {
3127 continue;
3128 }
3129
3130 const auto request = observer->prewrite(this, prewriteState);
3131
3132 mergedRequest.writeFlagsToAdd |= request.writeFlagsToAdd;
3133 if (request.maybeOffsetToSplitWrite.has_value()) {
3134 CHECK_GE(endOffset, request.maybeOffsetToSplitWrite.value());
3135 if (
3136 // case 1: offset not set in merged request
3137 !mergedRequest.maybeOffsetToSplitWrite.has_value() ||
3138 // case 2: offset in merged request > offset in current request
3139 mergedRequest.maybeOffsetToSplitWrite >
3140 request.maybeOffsetToSplitWrite) {
3141 mergedRequest.maybeOffsetToSplitWrite =
3142 request.maybeOffsetToSplitWrite; // update
3143 mergedRequest.writeFlagsToAddAtOffset =
3144 request.writeFlagsToAddAtOffset; // reset
3145 } else if (
3146 // case 3: offset in merged request == offset in current request
3147 request.maybeOffsetToSplitWrite ==
3148 mergedRequest.maybeOffsetToSplitWrite) {
3149 mergedRequest.writeFlagsToAddAtOffset |=
3150 request.writeFlagsToAddAtOffset; // merge
3151 }
3152 // case 4: offset in merged request < offset in current request
3153 // (do nothing)
3154 }
3155 }
3156
3157 // if maybeOffsetToSplitWrite points to end of the vector, remove the split
3158 if (mergedRequest.maybeOffsetToSplitWrite.has_value() && // explicit
3159 mergedRequest.maybeOffsetToSplitWrite == endOffset) {
3160 mergedRequest.maybeOffsetToSplitWrite.reset(); // no split needed
3161 }
3162
3163 return mergedRequest;
3164 };
3165
3166 // lambda to prepare and send a message, and handle byte events
3167 // parameters have L at the end to prevent shadowing warning from gcc
3168 auto prepSendMsg = [this](
3169 const iovec* vecL,
3170 const size_t countL,
3171 const WriteFlags flagsL) {
3172 const bool byteEventsEnabled =
3173 (byteEventHelper_ && byteEventHelper_->byteEventsEnabled &&
3174 !byteEventHelper_->maybeEx.has_value());
3175
3176 struct msghdr msg = {};
3177 msg.msg_name = nullptr;
3178 msg.msg_namelen = 0;
3179 msg.msg_iov = const_cast<struct iovec*>(vecL);
3180 msg.msg_iovlen = std::min<size_t>(countL, kIovMax);
3181 msg.msg_flags = 0; // passed to sendSocketMessage below, it sets them
3182 msg.msg_control = nullptr;
3183 msg.msg_controllen =
3184 sendMsgParamCallback_->getAncillaryDataSize(flagsL, byteEventsEnabled);
3185 CHECK_GE(
3186 AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
3187 msg.msg_controllen);
3188
3189 if (msg.msg_controllen != 0) {
3190 msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
3191 sendMsgParamCallback_->getAncillaryData(
3192 flagsL, msg.msg_control, byteEventsEnabled);
3193 }
3194
3195 const auto prewriteRawBytesWritten = getRawBytesWritten();
3196 int msg_flags = sendMsgParamCallback_->getFlags(flagsL, zeroCopyEnabled_);
3197 auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
3198
3199 if (writeResult.writeReturn < 0 && zeroCopyEnabled_ && errno == ENOBUFS) {
3200 // workaround for running with zerocopy enabled but without a big enough
3201 // memlock value - see ulimit -l
3202 zeroCopyEnabled_ = false;
3203 zeroCopyReenableCounter_ = zeroCopyReenableThreshold_;
3204 msg_flags = sendMsgParamCallback_->getFlags(flagsL, zeroCopyEnabled_);
3205 writeResult = sendSocketMessage(fd_, &msg, msg_flags);
3206 }
3207
3208 if (writeResult.writeReturn > 0 && byteEventsEnabled &&
3209 isSet(flagsL, WriteFlags::TIMESTAMP_WRITE)) {
3210 CHECK_GT(getRawBytesWritten(), prewriteRawBytesWritten); // sanity check
3211 ByteEvent byteEvent = {};
3212 byteEvent.type = ByteEvent::Type::WRITE;
3213 byteEvent.offset = getRawBytesWritten() - 1;
3214 byteEvent.maybeRawBytesWritten = writeResult.writeReturn;
3215 byteEvent.maybeRawBytesTriedToWrite = 0;
3216 for (size_t i = 0; i < countL; ++i) {
3217 byteEvent.maybeRawBytesTriedToWrite.value() += vecL[i].iov_len;
3218 }
3219 byteEvent.maybeWriteFlags = flagsL;
3220 for (const auto& observer : lifecycleObservers_) {
3221 if (observer->getConfig().byteEvents) {
3222 observer->byteEvent(this, byteEvent);
3223 }
3224 }
3225 }
3226
3227 return writeResult;
3228 };
3229
3230 // get PrewriteRequests (if any), merge flags with write flags
3231 const auto prewriteRequest = mergePrewriteRequests();
3232 auto mergedFlags = flags | prewriteRequest.writeFlagsToAdd |
3233 prewriteRequest.writeFlagsToAddAtOffset;
3234
3235 // if no PrewriteRequests, or none requiring the write to be split, proceed
3236 if (!prewriteRequest.maybeOffsetToSplitWrite.has_value()) {
3237 return prepSendMsg(vec, count, mergedFlags);
3238 }
3239
3240 // we need to split the write...
3241 // add CORK flag to inform the OS that more data is on the way...
3242 mergedFlags |= WriteFlags::CORK;
3243
3244 // TODO(bschlinker): When prewrite splits a write, try to continue writing
3245 // after a write returns; this will improve efficiency.
3246 const auto splitWriteAtOffset = *prewriteRequest.maybeOffsetToSplitWrite;
3247 if (count <= kSmallIoVecSize) {
3248 // suppress "warning: variable length array 'vec' is used [-Wvla]"
3249 FOLLY_PUSH_WARNING
3250 FOLLY_GNU_DISABLE_WARNING("-Wvla")
3251 iovec tmpVec[BOOST_PP_IF(FOLLY_HAVE_VLA_01, count, kSmallIoVecSize)];
3252 FOLLY_POP_WARNING
3253
3254 size_t tmpVecCount = count;
3255 splitIovecArray(
3256 0,
3257 splitWriteAtOffset - getRawBytesWritten(),
3258 vec,
3259 count,
3260 tmpVec,
3261 tmpVecCount);
3262 return prepSendMsg(tmpVec, tmpVecCount, mergedFlags);
3263 } else {
3264 auto tmpVecPtr = std::make_unique<iovec[]>(count);
3265 auto tmpVec = tmpVecPtr.get();
3266 size_t tmpVecCount = count;
3267 splitIovecArray(
3268 0,
3269 splitWriteAtOffset - getRawBytesWritten(),
3270 vec,
3271 count,
3272 tmpVec,
3273 tmpVecCount);
3274 return prepSendMsg(tmpVec, tmpVecCount, mergedFlags);
3275 }
3276 }
3277
sendSocketMessage(NetworkSocket fd,struct msghdr * msg,int msg_flags)3278 AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
3279 NetworkSocket fd, struct msghdr* msg, int msg_flags) {
3280 ssize_t totalWritten = 0;
3281 SCOPE_EXIT {
3282 if (totalWritten > 0) {
3283 rawBytesWritten_ += totalWritten;
3284 }
3285 };
3286 if (state_ == StateEnum::FAST_OPEN) {
3287 sockaddr_storage addr;
3288 auto len = addr_.getAddress(&addr);
3289 msg->msg_name = &addr;
3290 msg->msg_namelen = len;
3291 totalWritten = tfoSendMsg(fd_, msg, msg_flags);
3292 if (totalWritten >= 0) {
3293 tfoFinished_ = true;
3294 state_ = StateEnum::ESTABLISHED;
3295 // We schedule this asynchrously so that we don't end up
3296 // invoking initial read or write while a write is in progress.
3297 scheduleInitialReadWrite();
3298 } else if (errno == EINPROGRESS) {
3299 VLOG(4) << "TFO falling back to connecting";
3300 // A normal sendmsg doesn't return EINPROGRESS, however
3301 // TFO might fallback to connecting if there is no
3302 // cookie.
3303 state_ = StateEnum::CONNECTING;
3304 try {
3305 scheduleConnectTimeout();
3306 registerForConnectEvents();
3307 } catch (const AsyncSocketException& ex) {
3308 return WriteResult(
3309 WRITE_ERROR, std::make_unique<AsyncSocketException>(ex));
3310 }
3311 // Let's fake it that no bytes were written and return an errno.
3312 errno = EAGAIN;
3313 totalWritten = -1;
3314 } else if (errno == EOPNOTSUPP) {
3315 // Try falling back to connecting.
3316 VLOG(4) << "TFO not supported";
3317 state_ = StateEnum::CONNECTING;
3318 try {
3319 int ret = socketConnect((const sockaddr*)&addr, len);
3320 if (ret == 0) {
3321 // connect succeeded immediately
3322 // Treat this like no data was written.
3323 state_ = StateEnum::ESTABLISHED;
3324 scheduleInitialReadWrite();
3325 }
3326 // If there was no exception during connections,
3327 // we would return that no bytes were written.
3328 errno = EAGAIN;
3329 totalWritten = -1;
3330 } catch (const AsyncSocketException& ex) {
3331 return WriteResult(
3332 WRITE_ERROR, std::make_unique<AsyncSocketException>(ex));
3333 }
3334 } else if (errno == EAGAIN) {
3335 // Normally sendmsg would indicate that the write would block.
3336 // However in the fast open case, it would indicate that sendmsg
3337 // fell back to a connect. This is a return code from connect()
3338 // instead, and is an error condition indicating no fds available.
3339 return WriteResult(
3340 WRITE_ERROR,
3341 std::make_unique<AsyncSocketException>(
3342 AsyncSocketException::UNKNOWN, "No more free local ports"));
3343 }
3344 } else {
3345 totalWritten = netops_->sendmsg(fd, msg, msg_flags);
3346 }
3347 return WriteResult(totalWritten);
3348 }
3349
performWrite(const iovec * vec,uint32_t count,WriteFlags flags,uint32_t * countWritten,uint32_t * partialWritten)3350 AsyncSocket::WriteResult AsyncSocket::performWrite(
3351 const iovec* vec,
3352 uint32_t count,
3353 WriteFlags flags,
3354 uint32_t* countWritten,
3355 uint32_t* partialWritten) {
3356 auto writeResult = sendSocketMessage(vec, count, flags);
3357 auto totalWritten = writeResult.writeReturn;
3358 if (totalWritten < 0) {
3359 bool tryAgain = (errno == EAGAIN);
3360 #ifdef __APPLE__
3361 // Apple has a bug where doing a second write on a socket which we
3362 // have opened with TFO causes an ENOTCONN to be thrown. However the
3363 // socket is really connected, so treat ENOTCONN as a EAGAIN until
3364 // this bug is fixed.
3365 tryAgain |= (errno == ENOTCONN);
3366 #endif
3367
3368 if (!writeResult.exception && tryAgain) {
3369 // TCP buffer is full; we can't write any more data right now.
3370 *countWritten = 0;
3371 *partialWritten = 0;
3372 return WriteResult(0);
3373 }
3374 // error
3375 *countWritten = 0;
3376 *partialWritten = 0;
3377 return writeResult;
3378 }
3379
3380 appBytesWritten_ += totalWritten;
3381
3382 uint32_t bytesWritten;
3383 uint32_t n;
3384 for (bytesWritten = uint32_t(totalWritten), n = 0; n < count; ++n) {
3385 const iovec* v = vec + n;
3386 if (v->iov_len > bytesWritten) {
3387 // Partial write finished in the middle of this iovec
3388 *countWritten = n;
3389 *partialWritten = bytesWritten;
3390 return WriteResult(totalWritten);
3391 }
3392
3393 bytesWritten -= uint32_t(v->iov_len);
3394 }
3395
3396 assert(bytesWritten == 0);
3397 *countWritten = n;
3398 *partialWritten = 0;
3399 return WriteResult(totalWritten);
3400 }
3401
3402 /**
3403 * Re-register the EventHandler after eventFlags_ has changed.
3404 *
3405 * If an error occurs, fail() is called to move the socket into the error state
3406 * and call all currently installed callbacks. After an error, the
3407 * AsyncSocket is completely unregistered.
3408 *
3409 * @return Returns true on success, or false on error.
3410 */
updateEventRegistration()3411 bool AsyncSocket::updateEventRegistration() {
3412 VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
3413 << ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
3414 << ", events=" << std::hex << eventFlags_;
3415 if (eventFlags_ == EventHandler::NONE) {
3416 if (ioHandler_.isHandlerRegistered()) {
3417 DCHECK(eventBase_ != nullptr);
3418 eventBase_->dcheckIsInEventBaseThread();
3419 }
3420 ioHandler_.unregisterHandler();
3421 return true;
3422 }
3423
3424 eventBase_->dcheckIsInEventBaseThread();
3425
3426 // Always register for persistent events, so we don't have to re-register
3427 // after being called back.
3428 if (!ioHandler_.registerHandler(
3429 uint16_t(eventFlags_ | EventHandler::PERSIST))) {
3430 eventFlags_ = EventHandler::NONE; // we're not registered after error
3431 AsyncSocketException ex(
3432 AsyncSocketException::INTERNAL_ERROR,
3433 withAddr("failed to update AsyncSocket event registration"));
3434 fail("updateEventRegistration", ex);
3435 return false;
3436 }
3437
3438 return true;
3439 }
3440
updateEventRegistration(uint16_t enable,uint16_t disable)3441 bool AsyncSocket::updateEventRegistration(uint16_t enable, uint16_t disable) {
3442 uint16_t oldFlags = eventFlags_;
3443 eventFlags_ |= enable;
3444 eventFlags_ &= ~disable;
3445 if (eventFlags_ == oldFlags) {
3446 return true;
3447 } else {
3448 return updateEventRegistration();
3449 }
3450 }
3451
startFail()3452 void AsyncSocket::startFail() {
3453 // startFail() should only be called once
3454 assert(state_ != StateEnum::ERROR);
3455 assert(getDestructorGuardCount() > 0);
3456 state_ = StateEnum::ERROR;
3457 // Ensure that SHUT_READ and SHUT_WRITE are set,
3458 // so all future attempts to read or write will be rejected
3459 shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
3460
3461 // Cancel any scheduled immediate read.
3462 if (immediateReadHandler_.isLoopCallbackScheduled()) {
3463 immediateReadHandler_.cancelLoopCallback();
3464 }
3465
3466 if (eventFlags_ != EventHandler::NONE) {
3467 eventFlags_ = EventHandler::NONE;
3468 ioHandler_.unregisterHandler();
3469 }
3470 writeTimeout_.cancelTimeout();
3471
3472 if (fd_ != NetworkSocket()) {
3473 ioHandler_.changeHandlerFD(NetworkSocket());
3474 doClose();
3475 }
3476 }
3477
invokeAllErrors(const AsyncSocketException & ex)3478 void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) {
3479 invokeConnectErr(ex);
3480 failAllWrites(ex);
3481
3482 if (readCallback_) {
3483 ReadCallback* callback = readCallback_;
3484 readCallback_ = nullptr;
3485 callback->readErr(ex);
3486 }
3487 }
3488
finishFail()3489 void AsyncSocket::finishFail() {
3490 assert(state_ == StateEnum::ERROR);
3491 assert(getDestructorGuardCount() > 0);
3492
3493 AsyncSocketException ex(
3494 AsyncSocketException::INTERNAL_ERROR,
3495 withAddr("socket closing after error"));
3496 invokeAllErrors(ex);
3497 }
3498
finishFail(const AsyncSocketException & ex)3499 void AsyncSocket::finishFail(const AsyncSocketException& ex) {
3500 assert(state_ == StateEnum::ERROR);
3501 assert(getDestructorGuardCount() > 0);
3502 invokeAllErrors(ex);
3503 }
3504
fail(const char * fn,const AsyncSocketException & ex)3505 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
3506 VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
3507 << ", state=" << state_ << " host=" << addr_.describe()
3508 << "): failed in " << fn << "(): " << ex.what();
3509 startFail();
3510 finishFail(ex);
3511 }
3512
failConnect(const char * fn,const AsyncSocketException & ex)3513 void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
3514 VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
3515 << ", state=" << state_ << " host=" << addr_.describe()
3516 << "): failed while connecting in " << fn << "(): " << ex.what();
3517 startFail();
3518
3519 invokeConnectErr(ex);
3520 finishFail(ex);
3521 }
3522
failRead(const char * fn,const AsyncSocketException & ex)3523 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
3524 VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
3525 << ", state=" << state_ << " host=" << addr_.describe()
3526 << "): failed while reading in " << fn << "(): " << ex.what();
3527 startFail();
3528
3529 if (readCallback_ != nullptr) {
3530 ReadCallback* callback = readCallback_;
3531 readCallback_ = nullptr;
3532 callback->readErr(ex);
3533 }
3534
3535 finishFail(ex);
3536 }
3537
failErrMessageRead(const char * fn,const AsyncSocketException & ex)3538 void AsyncSocket::failErrMessageRead(
3539 const char* fn, const AsyncSocketException& ex) {
3540 VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
3541 << ", state=" << state_ << " host=" << addr_.describe()
3542 << "): failed while reading message in " << fn << "(): " << ex.what();
3543 startFail();
3544
3545 if (errMessageCallback_ != nullptr) {
3546 ErrMessageCallback* callback = errMessageCallback_;
3547 errMessageCallback_ = nullptr;
3548 callback->errMessageError(ex);
3549 }
3550
3551 finishFail(ex);
3552 }
3553
failWrite(const char * fn,const AsyncSocketException & ex)3554 void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
3555 VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
3556 << ", state=" << state_ << " host=" << addr_.describe()
3557 << "): failed while writing in " << fn << "(): " << ex.what();
3558 startFail();
3559
3560 // Only invoke the first write callback, since the error occurred while
3561 // writing this request. Let any other pending write callbacks be invoked in
3562 // finishFail().
3563 if (writeReqHead_ != nullptr) {
3564 WriteRequest* req = writeReqHead_;
3565 writeReqHead_ = req->getNext();
3566 WriteCallback* callback = req->getCallback();
3567 uint32_t bytesWritten = req->getTotalBytesWritten();
3568 req->destroy();
3569 if (callback) {
3570 callback->writeErr(bytesWritten, ex);
3571 }
3572 }
3573
3574 finishFail(ex);
3575 }
3576
failWrite(const char * fn,WriteCallback * callback,size_t bytesWritten,const AsyncSocketException & ex)3577 void AsyncSocket::failWrite(
3578 const char* fn,
3579 WriteCallback* callback,
3580 size_t bytesWritten,
3581 const AsyncSocketException& ex) {
3582 // This version of failWrite() is used when the failure occurs before
3583 // we've added the callback to writeReqHead_.
3584 VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
3585 << ", state=" << state_ << " host=" << addr_.describe()
3586 << "): failed while writing in " << fn << "(): " << ex.what();
3587 if (closeOnFailedWrite_) {
3588 startFail();
3589 }
3590
3591 if (callback != nullptr) {
3592 callback->writeErr(bytesWritten, ex);
3593 }
3594
3595 if (closeOnFailedWrite_) {
3596 finishFail(ex);
3597 }
3598 }
3599
failAllWrites(const AsyncSocketException & ex)3600 void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
3601 // Invoke writeError() on all write callbacks.
3602 // This is used when writes are forcibly shutdown with write requests
3603 // pending, or when an error occurs with writes pending.
3604 while (writeReqHead_ != nullptr) {
3605 WriteRequest* req = writeReqHead_;
3606 writeReqHead_ = req->getNext();
3607 WriteCallback* callback = req->getCallback();
3608 if (callback) {
3609 callback->writeErr(req->getTotalBytesWritten(), ex);
3610 }
3611 req->destroy();
3612 }
3613
3614 // All pending writes have failed - reset totalAppBytesScheduledForWrite_
3615 totalAppBytesScheduledForWrite_ = appBytesWritten_;
3616 }
3617
failByteEvents(const AsyncSocketException & ex)3618 void AsyncSocket::failByteEvents(const AsyncSocketException& ex) {
3619 CHECK(byteEventHelper_) << "failByteEvents called without ByteEventHelper";
3620 byteEventHelper_->maybeEx = ex;
3621 // inform any observers that want ByteEvents
3622 for (const auto& observer : lifecycleObservers_) {
3623 if (observer->getConfig().byteEvents) {
3624 observer->byteEventsUnavailable(this, ex);
3625 }
3626 }
3627 }
3628
invalidState(ConnectCallback * callback)3629 void AsyncSocket::invalidState(ConnectCallback* callback) {
3630 VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
3631 << "): connect() called in invalid state " << state_;
3632
3633 /*
3634 * The invalidState() methods don't use the normal failure mechanisms,
3635 * since we don't know what state we are in. We don't want to call
3636 * startFail()/finishFail() recursively if we are already in the middle of
3637 * cleaning up.
3638 */
3639
3640 AsyncSocketException ex(
3641 AsyncSocketException::ALREADY_OPEN,
3642 "connect() called with socket in invalid state");
3643 connectEndTime_ = std::chrono::steady_clock::now();
3644 if ((state_ == StateEnum::CONNECTING) || (state_ == StateEnum::ERROR)) {
3645 for (const auto& cb : lifecycleObservers_) {
3646 if (auto observer = dynamic_cast<AsyncSocket::LifecycleObserver*>(cb)) {
3647 // inform any lifecycle observes that the connection failed
3648 observer->connectError(this, ex);
3649 }
3650 }
3651 }
3652 if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
3653 if (callback) {
3654 callback->connectErr(ex);
3655 }
3656 } else {
3657 // We can't use failConnect() here since connectCallback_
3658 // may already be set to another callback. Invoke this ConnectCallback
3659 // here; any other connectCallback_ will be invoked in finishFail()
3660 startFail();
3661 if (callback) {
3662 callback->connectErr(ex);
3663 }
3664 finishFail(ex);
3665 }
3666 }
3667
invalidState(ErrMessageCallback * callback)3668 void AsyncSocket::invalidState(ErrMessageCallback* callback) {
3669 VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
3670 << "): setErrMessageCB(" << callback << ") called in invalid state "
3671 << state_;
3672
3673 AsyncSocketException ex(
3674 AsyncSocketException::NOT_OPEN,
3675 msgErrQueueSupported
3676 ? "setErrMessageCB() called with socket in invalid state"
3677 : "This platform does not support socket error message notifications");
3678 if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
3679 if (callback) {
3680 callback->errMessageError(ex);
3681 }
3682 } else {
3683 startFail();
3684 if (callback) {
3685 callback->errMessageError(ex);
3686 }
3687 finishFail(ex);
3688 }
3689 }
3690
invokeConnectErr(const AsyncSocketException & ex)3691 void AsyncSocket::invokeConnectErr(const AsyncSocketException& ex) {
3692 VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
3693 << "): connect err invoked with ex: " << ex.what();
3694 connectEndTime_ = std::chrono::steady_clock::now();
3695 if ((state_ == StateEnum::CONNECTING) || (state_ == StateEnum::ERROR)) {
3696 // invokeConnectErr() can be invoked when state is {FAST_OPEN, CLOSED,
3697 // ESTABLISHED} (!?) and a bunch of other places that are not what this call
3698 // back wants. This seems like a bug but work around here while we explore
3699 // it independently
3700 for (const auto& cb : lifecycleObservers_) {
3701 cb->connectError(this, ex);
3702 }
3703 }
3704 if (connectCallback_) {
3705 ConnectCallback* callback = connectCallback_;
3706 connectCallback_ = nullptr;
3707 callback->connectErr(ex);
3708 }
3709 }
3710
invokeConnectSuccess()3711 void AsyncSocket::invokeConnectSuccess() {
3712 VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
3713 << "): connect success invoked";
3714 connectEndTime_ = std::chrono::steady_clock::now();
3715 bool enableByteEventsForObserver = false;
3716 for (const auto& cb : lifecycleObservers_) {
3717 cb->connectSuccess(this);
3718 enableByteEventsForObserver |= ((cb->getConfig().byteEvents) ? 1 : 0);
3719 }
3720 if (enableByteEventsForObserver) {
3721 enableByteEvents();
3722 }
3723 if (connectCallback_) {
3724 ConnectCallback* callback = connectCallback_;
3725 connectCallback_ = nullptr;
3726 callback->connectSuccess();
3727 }
3728 }
3729
invokeConnectAttempt()3730 void AsyncSocket::invokeConnectAttempt() {
3731 VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
3732 << "): connect attempt";
3733 for (const auto& cb : lifecycleObservers_) {
3734 cb->connectAttempt(this);
3735 }
3736 }
3737
invalidState(ReadCallback * callback)3738 void AsyncSocket::invalidState(ReadCallback* callback) {
3739 VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
3740 << "): setReadCallback(" << callback << ") called in invalid state "
3741 << state_;
3742
3743 AsyncSocketException ex(
3744 AsyncSocketException::NOT_OPEN,
3745 "setReadCallback() called with socket in "
3746 "invalid state");
3747 if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
3748 if (callback) {
3749 callback->readErr(ex);
3750 }
3751 } else {
3752 startFail();
3753 if (callback) {
3754 callback->readErr(ex);
3755 }
3756 finishFail(ex);
3757 }
3758 }
3759
invalidState(WriteCallback * callback)3760 void AsyncSocket::invalidState(WriteCallback* callback) {
3761 VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
3762 << "): write() called in invalid state " << state_;
3763
3764 AsyncSocketException ex(
3765 AsyncSocketException::NOT_OPEN,
3766 withAddr("write() called with socket in invalid state"));
3767 if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
3768 if (callback) {
3769 callback->writeErr(0, ex);
3770 }
3771 } else {
3772 startFail();
3773 if (callback) {
3774 callback->writeErr(0, ex);
3775 }
3776 finishFail(ex);
3777 }
3778 }
3779
doClose()3780 void AsyncSocket::doClose() {
3781 for (const auto& cb : lifecycleObservers_) {
3782 cb->close(this);
3783 }
3784 if (fd_ == NetworkSocket()) {
3785 return;
3786 }
3787 if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
3788 shutdownSocketSet->close(fd_);
3789 } else {
3790 netops_->close(fd_);
3791 }
3792 fd_ = NetworkSocket();
3793
3794 // we also want to clear the zerocopy maps
3795 // if the fd has been closed
3796 idZeroCopyBufPtrMap_.clear();
3797 idZeroCopyBufInfoMap_.clear();
3798 }
3799
operator <<(std::ostream & os,const AsyncSocket::StateEnum & state)3800 std::ostream& operator<<(
3801 std::ostream& os, const AsyncSocket::StateEnum& state) {
3802 os << static_cast<int>(state);
3803 return os;
3804 }
3805
withAddr(folly::StringPiece s)3806 std::string AsyncSocket::withAddr(folly::StringPiece s) {
3807 // Don't use addr_ directly because it may not be initialized
3808 // e.g. if constructed from fd
3809 folly::SocketAddress peer, local;
3810 try {
3811 getLocalAddress(&local);
3812 } catch (...) {
3813 // ignore
3814 }
3815 try {
3816 getPeerAddress(&peer);
3817 } catch (...) {
3818 // ignore
3819 }
3820
3821 return fmt::format(
3822 "{} (peer={}{})",
3823 s,
3824 peer.describe(),
3825 kIsMobile ? "" : fmt::format(", local={}", local.describe()));
3826 }
3827
setBufferCallback(BufferCallback * cb)3828 void AsyncSocket::setBufferCallback(BufferCallback* cb) {
3829 bufferCallback_ = cb;
3830 }
3831
3832 } // namespace folly
3833