1 /*
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <fizz/protocol/AsyncFizzBase.h>
10
11 #include <folly/Conv.h>
12 #include <folly/io/Cursor.h>
13
14 namespace fizz {
15
16 using folly::AsyncSocketException;
17
18 /**
19 * Min and max read buffer sizes when using non-movable buffer.
20 */
21 static const uint32_t kMinReadSize = 1460;
22 static const uint32_t kMaxReadSize = 4000;
23
24 /**
25 * Buffer size above which we should unset our read callback to apply back
26 * pressure on the transport.
27 */
28 static const uint32_t kMaxBufSize = 64 * 1024;
29
30 /**
31 * Buffer size above which we should break up shared writes, to avoid storing
32 * entire unencrypted and encrypted buffer simultaneously.
33 */
34 static const uint32_t kPartialWriteThreshold = 128 * 1024;
35
/**
 * Constructs the fizz wrapper around an underlying transport.
 *
 * Note on initialization order: `transport` is moved into the wrapped base
 * class first, so the subsequent `transport_->getEventBase()` reads the
 * member owned by that base, not the moved-from parameter. Similarly,
 * `ioVecQueue_` reads `transportOptions_`, which relies on the members being
 * declared in that order in the header — confirm if the header changes.
 */
AsyncFizzBase::AsyncFizzBase(
    folly::AsyncTransportWrapper::UniquePtr transport,
    TransportOptions options)
    : folly::WriteChainAsyncTransportWrapper<folly::AsyncTransportWrapper>(
          std::move(transport)),
      handshakeTimeout_(*this, transport_->getEventBase()),
      transportOptions_(std::move(options)),
      ioVecQueue_(folly::IOBufIovecBuilder::Options().setBlockSize(
          transportOptions_.readVecBlockSize)) {
  setReadMode(transportOptions_.readMode);
}
47
// Detach our callbacks from the underlying transport so it cannot call back
// into this (soon-destroyed) object, and unlink any still-queued write
// request so it does not later dereference a dangling base pointer.
AsyncFizzBase::~AsyncFizzBase() {
  transport_->setEventCallback(nullptr);
  transport_->setReadCB(nullptr);
  if (tailWriteRequest_) {
    tailWriteRequest_->unlinkFromBase();
  }
}
55
/**
 * DelayedDestruction entry point: closes the transport immediately and
 * detaches our callbacks before deferring the actual deletion to the base
 * class (which may delay it while destructor guards are outstanding).
 */
void AsyncFizzBase::destroy() {
  transport_->closeNow();
  transport_->setEventCallback(nullptr);
  transport_->setReadCB(nullptr);
  DelayedDestruction::destroy();
}
62
getReadCallback() const63 AsyncFizzBase::ReadCallback* AsyncFizzBase::getReadCallback() const {
64 return readCallback_;
65 }
66
/**
 * Installs (or clears) the application read callback.
 *
 * Installing a callback immediately flushes any plaintext buffered while no
 * callback was present, fails fast if the transport is already in a bad
 * state, and re-arms transport-level reads (which may have been paused for
 * back pressure — see checkBufLen()).
 */
void AsyncFizzBase::setReadCB(AsyncFizzBase::ReadCallback* callback) {
  readCallback_ = callback;

  if (readCallback_) {
    // Deliver previously buffered app data before anything else.
    if (appDataBuf_) {
      deliverAppData(nullptr);
    }

    if (!good()) {
      AsyncSocketException ex(
          AsyncSocketException::NOT_OPEN,
          "setReadCB() called with transport in bad state");
      deliverError(ex);
    } else {
      // The read callback may have been unset earlier if our buffer was full.
      startTransportReads();
    }
  }
}
86
/**
 * Wraps a pending application write so it can be flushed in chunks.
 *
 * The full chain is held in data_; the tail request of the queue carries the
 * buffered-byte count for the entire chain of queued requests (see append()).
 */
AsyncFizzBase::QueuedWriteRequest::QueuedWriteRequest(
    AsyncFizzBase* base,
    folly::AsyncTransportWrapper::WriteCallback* callback,
    std::unique_ptr<folly::IOBuf> data,
    folly::WriteFlags flags)
    : asyncFizzBase_(base), callback_(callback), flags_(flags) {
  data_.append(std::move(data));
  entireChainBytesBuffered = data_.chainLength();
}
96
/**
 * Writes the next chunk (up to kPartialWriteThreshold bytes) of this request.
 *
 * CORK is added on every chunk except the last so the transport can coalesce
 * them. The chunk's length is deducted from the tail request's chain-wide
 * buffered-byte accounting before handing the data off for encryption.
 */
void AsyncFizzBase::QueuedWriteRequest::startWriting() {
  auto buf = data_.splitAtMost(kPartialWriteThreshold);

  auto flags = flags_;
  if (!data_.empty()) {
    // More chunks follow; hint the transport to batch this one.
    flags |= folly::WriteFlags::CORK;
  }
  size_t len = buf->computeChainDataLength();
  dataWritten_ += len;

  CHECK(asyncFizzBase_);
  CHECK(asyncFizzBase_->tailWriteRequest_);
  asyncFizzBase_->tailWriteRequest_->entireChainBytesBuffered -= len;
  asyncFizzBase_->writeAppData(this, std::move(buf), flags);
}
112
append(QueuedWriteRequest * request)113 void AsyncFizzBase::QueuedWriteRequest::append(QueuedWriteRequest* request) {
114 DCHECK(!next_);
115 next_ = request;
116 next_->entireChainBytesBuffered += entireChainBytesBuffered;
117 entireChainBytesBuffered = 0;
118 }
119
unlinkFromBase()120 void AsyncFizzBase::QueuedWriteRequest::unlinkFromBase() {
121 asyncFizzBase_ = nullptr;
122 }
123
/**
 * Invoked when a chunk of this request finishes writing. Continues with the
 * next chunk if data remains; otherwise completes this request, notifies its
 * callback, and starts the next queued request.
 */
void AsyncFizzBase::QueuedWriteRequest::writeSuccess() noexcept {
  if (!data_.empty()) {
    startWriting();
  } else {
    advanceOnBase();
    // Copy everything we need to locals before deleting ourselves.
    auto callback = callback_;
    auto next = next_;
    auto base = asyncFizzBase_;
    delete this;

    // Keep the base alive across the callbacks below.
    // NOTE(review): base is nullptr if unlinkFromBase() ran; this relies on
    // folly's DestructorGuard tolerating a null argument — confirm.
    DelayedDestruction::DestructorGuard dg(base);

    if (callback) {
      callback->writeSuccess();
    }
    if (next) {
      next->startWriting();
    }
  }
}
144
/**
 * Invoked when a chunk of this request fails to write. Fails this request
 * and every request queued behind it.
 */
void AsyncFizzBase::QueuedWriteRequest::writeErr(
    size_t /* written */,
    const folly::AsyncSocketException& ex) noexcept {
  // Deliver the error to all queued writes, starting with this one. We avoid
  // recursively calling writeErr as that can cause excessive stack usage if
  // there are a large number of queued writes.
  QueuedWriteRequest* errorToDeliver = this;
  while (errorToDeliver) {
    errorToDeliver = errorToDeliver->deliverSingleWriteErr(ex);
  }
}
156
/**
 * Fails this single request: detaches it from the base if it is the tail,
 * destroys it, and reports the bytes already written to its callback.
 *
 * @return The next queued request (for the caller's loop), or nullptr.
 */
AsyncFizzBase::QueuedWriteRequest*
AsyncFizzBase::QueuedWriteRequest::deliverSingleWriteErr(
    const folly::AsyncSocketException& ex) {
  advanceOnBase();
  // Copy fields to locals before deleting ourselves.
  auto callback = callback_;
  auto next = next_;
  auto dataWritten = dataWritten_;
  delete this;

  if (callback) {
    callback->writeErr(dataWritten, ex);
  }

  return next;
}
172
advanceOnBase()173 void AsyncFizzBase::QueuedWriteRequest::advanceOnBase() {
174 if (!next_ && asyncFizzBase_) {
175 CHECK_EQ(asyncFizzBase_->tailWriteRequest_, this);
176 asyncFizzBase_->tailWriteRequest_ = nullptr;
177 }
178 }
179
/**
 * Queues or directly performs an application write.
 *
 * Large shared buffers — or writes issued while the transport already has
 * raw bytes buffered — are wrapped in a QueuedWriteRequest so they can be
 * encrypted and flushed in kPartialWriteThreshold-sized chunks. Once any
 * request is queued, all later writes queue behind it to preserve ordering.
 */
void AsyncFizzBase::writeChain(
    folly::AsyncTransportWrapper::WriteCallback* callback,
    std::unique_ptr<folly::IOBuf>&& buf,
    folly::WriteFlags flags) {
  auto writeSize = buf->computeChainDataLength();
  appBytesWritten_ += writeSize;

  // We want to split up and queue large writes to avoid simultaneously storing
  // unencrypted and encrypted large buffer in memory. We can skip this if the
  // buffer is unshared (because we can encrypt in-place). We also skip this
  // when sending early data to avoid the possibility of splitting writes
  // between early data and normal data.
  bool largeWrite = writeSize > kPartialWriteThreshold;
  bool transportBuffering = transport_->getRawBytesBuffered() > 0;
  bool needsToQueue = (largeWrite || transportBuffering) && buf->isShared() &&
      !connecting() && isReplaySafe();
  if (tailWriteRequest_ || needsToQueue) {
    auto newWriteRequest =
        new QueuedWriteRequest(this, callback, std::move(buf), flags);

    if (tailWriteRequest_) {
      // A request is already in flight: chain behind it to keep write order.
      tailWriteRequest_->append(newWriteRequest);
      tailWriteRequest_ = newWriteRequest;
    } else {
      tailWriteRequest_ = newWriteRequest;
      newWriteRequest->startWriting();
    }
  } else {
    writeAppData(callback, std::move(buf), flags);
  }
}
211
getAppBytesWritten() const212 size_t AsyncFizzBase::getAppBytesWritten() const {
213 return appBytesWritten_;
214 }
215
getAppBytesReceived() const216 size_t AsyncFizzBase::getAppBytesReceived() const {
217 return appBytesReceived_;
218 }
219
getAppBytesBuffered() const220 size_t AsyncFizzBase::getAppBytesBuffered() const {
221 return transport_->getAppBytesBuffered() +
222 (tailWriteRequest_ ? tailWriteRequest_->getEntireChainBytesBuffered()
223 : 0);
224 }
225
startTransportReads()226 void AsyncFizzBase::startTransportReads() {
227 if (transportOptions_.registerEventCallback) {
228 transport_->setEventCallback(this);
229 }
230 transport_->setReadCB(this);
231 }
232
startHandshakeTimeout(std::chrono::milliseconds timeout)233 void AsyncFizzBase::startHandshakeTimeout(std::chrono::milliseconds timeout) {
234 handshakeTimeout_.scheduleTimeout(timeout);
235 }
236
cancelHandshakeTimeout()237 void AsyncFizzBase::cancelHandshakeTimeout() {
238 handshakeTimeout_.cancelTimeout();
239 }
240
/**
 * Delivers decrypted application data to the read callback, or buffers it in
 * appDataBuf_ when no callback is installed (or the callback is removed
 * mid-delivery).
 *
 * @param data Newly received plaintext, or nullptr to flush only previously
 *             buffered data (see setReadCB()).
 */
void AsyncFizzBase::deliverAppData(std::unique_ptr<folly::IOBuf> data) {
  if (data) {
    appBytesReceived_ += data->computeChainDataLength();
  }

  // Prepend anything previously buffered so delivery stays in order.
  if (appDataBuf_) {
    if (data) {
      appDataBuf_->prependChain(std::move(data));
    }
    data = std::move(appDataBuf_);
  }

  while (readCallback_ && data) {
    if (readCallback_->isBufferMovable()) {
      // Movable-buffer callbacks take ownership of the whole chain at once.
      return readCallback_->readBufferAvailable(std::move(data));
    } else {
      // Non-movable path: copy into caller-provided buffers chunk by chunk.
      folly::io::Cursor cursor(data.get());
      size_t available = 0;
      // Re-check readCallback_ each iteration: readDataAvailable() may
      // install a different callback (or none) reentrantly.
      while ((available = cursor.totalLength()) != 0 && readCallback_ &&
             !readCallback_->isBufferMovable()) {
        void* buf = nullptr;
        size_t buflen = 0;
        try {
          readCallback_->getReadBuffer(&buf, &buflen);
        } catch (const AsyncSocketException& ase) {
          return deliverError(ase);
        } catch (const std::exception& e) {
          AsyncSocketException ase(
              AsyncSocketException::BAD_ARGS,
              folly::to<std::string>("getReadBuffer() threw ", e.what()));
          return deliverError(ase);
        } catch (...) {
          AsyncSocketException ase(
              AsyncSocketException::BAD_ARGS,
              "getReadBuffer() threw unknown exception");
          return deliverError(ase);
        }
        if (buflen == 0 || buf == nullptr) {
          AsyncSocketException ase(
              AsyncSocketException::BAD_ARGS,
              "getReadBuffer() returned empty buffer");
          return deliverError(ase);
        }

        size_t bytesToRead = std::min(buflen, available);
        cursor.pull(buf, bytesToRead);
        readCallback_->readDataAvailable(bytesToRead);
      }

      // If we have data left, it means the read callback changed and we need
      // to save the remaining data (if any)
      if (available != 0) {
        std::unique_ptr<folly::IOBuf> remainingData;
        cursor.clone(remainingData, available);
        data = std::move(remainingData);
      } else {
        // Out of data. Reset the data pointer to end the loop
        data.reset();
      }
    }
  }

  // Callback gone (or never set): stash whatever remains for later delivery.
  if (data) {
    appDataBuf_ = std::move(data);
  }

  checkBufLen();
}
309
deliverError(const AsyncSocketException & ex,bool closeTransport)310 void AsyncFizzBase::deliverError(
311 const AsyncSocketException& ex,
312 bool closeTransport) {
313 DelayedDestruction::DestructorGuard dg(this);
314
315 if (readCallback_) {
316 auto readCallback = readCallback_;
317 readCallback_ = nullptr;
318 if (ex.getType() == AsyncSocketException::END_OF_FILE) {
319 readCallback->readEOF();
320 } else {
321 readCallback->readErr(ex);
322 }
323 }
324
325 // Clear the secret callback too.
326 if (secretCallback_) {
327 secretCallback_ = nullptr;
328 }
329
330 if (closeTransport) {
331 transport_->close();
332 }
333 }
334
/**
 * MsgHdr implementation for folly's EventRecvmsgCallback path: points a
 * single iovec at the fizz read buffer and routes recvmsg completions back
 * into AsyncFizzBase::eventRecvmsgCallback().
 */
class AsyncFizzBase::FizzMsgHdr : public folly::EventRecvmsgCallback::MsgHdr {
  FizzMsgHdr() = delete;

 public:
  ~FizzMsgHdr() override = default;
  explicit FizzMsgHdr(AsyncFizzBase* fizzBase) {
    // arg_/freeFunc_/cbFunc_ are the base-class hook fields consumed by folly.
    arg_ = fizzBase;
    freeFunc_ = FizzMsgHdr::free;
    cbFunc_ = FizzMsgHdr::cb;
  }

  // Re-points the msghdr at a fresh read region before each recvmsg.
  void reset() {
    data_ = msghdr{};
    auto base = static_cast<AsyncFizzBase*>(arg_);
    base->getReadBuffer(&iov_.iov_base, &iov_.iov_len);
    data_.msg_iov = &iov_;
    data_.msg_iovlen = 1;
  }

  // Deallocation hook invoked by folly when the header is discarded.
  static void free(folly::EventRecvmsgCallback::MsgHdr* msgHdr) {
    delete msgHdr;
  }

  // Completion hook: forward the recvmsg result to the owning AsyncFizzBase.
  static void cb(folly::EventRecvmsgCallback::MsgHdr* msgHdr, int res) {
    static_cast<AsyncFizzBase*>(msgHdr->arg_)
        ->eventRecvmsgCallback(static_cast<FizzMsgHdr*>(msgHdr), res);
  }

 private:
  iovec iov_;
};
366
allocateData()367 folly::EventRecvmsgCallback::MsgHdr* AsyncFizzBase::allocateData() {
368 auto* ret = msgHdr_.release();
369 if (!ret) {
370 ret = new FizzMsgHdr(this);
371 }
372 ret->reset();
373 return ret;
374 }
375
/**
 * Completion handler for a recvmsg issued through the event-callback path.
 *
 * @param msgHdr The header, recycled back into msgHdr_ for reuse at the end.
 * @param res    recvmsg result: >0 bytes read, 0 EOF, <0 negated errno.
 */
void AsyncFizzBase::eventRecvmsgCallback(FizzMsgHdr* msgHdr, int res) {
  DelayedDestruction::DestructorGuard dg(this);
  if (res > 0) {
    // Data was read directly into the preallocated region; commit it.
    transportReadBuf_.postallocate(res);
    transportDataAvailable();
    checkBufLen();
  } else if (res == 0) {
    readEOF();
  } else {
    // res is a negated errno; pass the positive value to the exception.
    AsyncSocketException ex(
        AsyncSocketException::INTERNAL_ERROR, "event recv failed", (0 - res));
    deliverError(ex);
  }
  msgHdr_.reset(msgHdr);
}
391
/**
 * ReadCallback (non-movable path): hands the transport a contiguous region
 * of transportReadBuf_ to read into, capped by readSizeHint_ when
 * record-aligned reads are in effect (see below).
 */
void AsyncFizzBase::getReadBuffer(void** bufReturn, size_t* lenReturn) {
  std::pair<void*, uint32_t> readSpace =
      transportReadBuf_.preallocate(kMinReadSize, kMaxReadSize);
  *bufReturn = readSpace.first;

  // `readSizeHint_`, if zero, indicates that we do not care about how much
  // data we read from the underlying socket.
  //
  // `readSizeHint_`, if nonzero, indicates the maximum amount of data we
  // want to read from the underlying socket. This is necessary for kTLS,
  // where we want to ensure that when ReportHandshakeSuccess is called, we
  // are at a known point in the TCP stream, so we can let the kernel start
  // decrypting records for us.
  //
  // For transport with "record aligned reads", we initially set `readSizeHint_`
  // equal to the size of the TLS record header. Subsequently, the state machine
  // will tell us exactly how much data is required to complete the record
  // in WaitForData actions.
  if (readSizeHint_ > 0) {
    // preallocate() guarantees at least kMinReadSize bytes, so this cap is
    // always within the returned region.
    *lenReturn = std::min(
        static_cast<decltype(readSizeHint_)>(kMinReadSize), readSizeHint_);
  } else {
    *lenReturn = readSpace.second;
  }
}
417
getReadBuffers(folly::IOBufIovecBuilder::IoVecVec & iovs)418 void AsyncFizzBase::getReadBuffers(folly::IOBufIovecBuilder::IoVecVec& iovs) {
419 ioVecQueue_.allocateBuffers(iovs, kMaxReadSize);
420 }
421
readDataAvailable(size_t len)422 void AsyncFizzBase::readDataAvailable(size_t len) noexcept {
423 DelayedDestruction::DestructorGuard dg(this);
424
425 if (getReadMode() == folly::AsyncTransport::ReadCallback::ReadMode::ReadVec) {
426 auto tmp = ioVecQueue_.extractIOBufChain(len);
427 transportReadBuf_.append(std::move(tmp));
428 } else {
429 transportReadBuf_.postallocate(len);
430 }
431 transportDataAvailable();
432 checkBufLen();
433 }
434
isBufferMovable()435 bool AsyncFizzBase::isBufferMovable() noexcept {
436 return true;
437 }
438
readBufferAvailable(std::unique_ptr<folly::IOBuf> data)439 void AsyncFizzBase::readBufferAvailable(
440 std::unique_ptr<folly::IOBuf> data) noexcept {
441 DelayedDestruction::DestructorGuard dg(this);
442
443 transportReadBuf_.append(std::move(data));
444 transportDataAvailable();
445 checkBufLen();
446 }
447
readEOF()448 void AsyncFizzBase::readEOF() noexcept {
449 AsyncSocketException eof(AsyncSocketException::END_OF_FILE, "readEOF()");
450 transportError(eof);
451 }
452
readErr(const folly::AsyncSocketException & ex)453 void AsyncFizzBase::readErr(const folly::AsyncSocketException& ex) noexcept {
454 transportError(ex);
455 }
456
writeSuccess()457 void AsyncFizzBase::writeSuccess() noexcept {}
458
writeErr(size_t,const folly::AsyncSocketException & ex)459 void AsyncFizzBase::writeErr(
460 size_t /* bytesWritten */,
461 const folly::AsyncSocketException& ex) noexcept {
462 transportError(ex);
463 }
464
checkBufLen()465 void AsyncFizzBase::checkBufLen() {
466 if (!readCallback_ &&
467 (transportReadBuf_.chainLength() >= kMaxBufSize ||
468 (appDataBuf_ && appDataBuf_->computeChainDataLength() >= kMaxBufSize))) {
469 transport_->setEventCallback(nullptr);
470 transport_->setReadCB(nullptr);
471 }
472 }
473
handshakeTimeoutExpired()474 void AsyncFizzBase::handshakeTimeoutExpired() noexcept {
475 AsyncSocketException eof(
476 AsyncSocketException::TIMED_OUT, "handshake timeout expired");
477 transportError(eof);
478 }
479
/**
 * Handles the peer tearing down the TLS session.
 *
 * During the handshake this is an error. Afterwards, a registered
 * endOfTLSCallback_ decides what happens (it may keep the socket open);
 * otherwise we mimic a normal EOF: signal the read callback and close the
 * transport.
 *
 * @param endOfData Any data received alongside the end of the session.
 */
void AsyncFizzBase::endOfTLS(std::unique_ptr<folly::IOBuf> endOfData) noexcept {
  DelayedDestruction::DestructorGuard dg(this);

  if (connecting()) {
    AsyncSocketException ex(
        AsyncSocketException::INVALID_STATE,
        "tls connection torn down while connecting");
    transportError(ex);
    return;
  }

  if (endOfTLSCallback_) {
    endOfTLSCallback_->endOfTLS(this, std::move(endOfData));
  } else {
    // The end of TLS callback may not want the socket to be closed but by
    // default read callbacks often close on EOF, as such we defer to the setter
    // of the end of tls callback to apply the appropriate behaviour if it's set
    if (readCallback_) {
      // Clear before invoking so a reentrant setReadCB() is preserved.
      auto readCallback = readCallback_;
      readCallback_ = nullptr;
      readCallback->readEOF();
    }
    transport_->close();
  }
}
505
// The below maps the secret type to the appropriate secret callback function.
namespace {
/**
 * Dispatches a derived secret to the SecretCallback method matching its
 * handshake phase and secret kind. Holds the secret bytes by reference, so
 * the visitor must not outlive the buffer passed to the constructor.
 */
class SecretVisitor {
 public:
  explicit SecretVisitor(
      AsyncFizzBase::SecretCallback* cb,
      const std::vector<uint8_t>& secretBuf)
      : callback_(cb), secretBuf_(secretBuf) {}
  // Outer dispatch on the handshake phase of the secret.
  void operator()(const SecretType& secretType) {
    switch (secretType.type()) {
      case SecretType::Type::EarlySecrets_E:
        operator()(*secretType.asEarlySecrets());
        break;
      case SecretType::Type::HandshakeSecrets_E:
        operator()(*secretType.asHandshakeSecrets());
        break;
      case SecretType::Type::MasterSecrets_E:
        operator()(*secretType.asMasterSecrets());
        break;
      case SecretType::Type::AppTrafficSecrets_E:
        operator()(*secretType.asAppTrafficSecrets());
        break;
    }
  }

  // Phase-specific dispatchers: one callback method per secret kind.
  void operator()(const EarlySecrets& secret) {
    switch (secret) {
      case EarlySecrets::ExternalPskBinder:
        callback_->externalPskBinderAvailable(secretBuf_);
        return;
      case EarlySecrets::ResumptionPskBinder:
        callback_->resumptionPskBinderAvailable(secretBuf_);
        return;
      case EarlySecrets::ClientEarlyTraffic:
        callback_->clientEarlyTrafficSecretAvailable(secretBuf_);
        return;
      case EarlySecrets::EarlyExporter:
        callback_->earlyExporterSecretAvailable(secretBuf_);
        return;
    }
  }
  void operator()(const HandshakeSecrets& secret) {
    switch (secret) {
      case HandshakeSecrets::ClientHandshakeTraffic:
        callback_->clientHandshakeTrafficSecretAvailable(secretBuf_);
        return;
      case HandshakeSecrets::ServerHandshakeTraffic:
        callback_->serverHandshakeTrafficSecretAvailable(secretBuf_);
        return;
    }
  }
  void operator()(const MasterSecrets& secret) {
    switch (secret) {
      case MasterSecrets::ExporterMaster:
        callback_->exporterMasterSecretAvailable(secretBuf_);
        return;
      case MasterSecrets::ResumptionMaster:
        callback_->resumptionMasterSecretAvailable(secretBuf_);
        return;
    }
  }
  void operator()(const AppTrafficSecrets& secret) {
    switch (secret) {
      case AppTrafficSecrets::ClientAppTraffic:
        callback_->clientAppTrafficSecretAvailable(secretBuf_);
        return;
      case AppTrafficSecrets::ServerAppTraffic:
        callback_->serverAppTrafficSecretAvailable(secretBuf_);
        return;
    }
  }

 private:
  AsyncFizzBase::SecretCallback* callback_;
  const std::vector<uint8_t>& secretBuf_;
};
} // namespace
583
secretAvailable(const DerivedSecret & secret)584 void AsyncFizzBase::secretAvailable(const DerivedSecret& secret) noexcept {
585 if (secretCallback_) {
586 SecretVisitor visitor(secretCallback_, secret.secret);
587 visitor(secret.type);
588 }
589 }
590 } // namespace fizz
591