1 /*
2  *  Copyright (c) 2018-present, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <fizz/protocol/AsyncFizzBase.h>
10 
11 #include <folly/Conv.h>
12 #include <folly/io/Cursor.h>
13 
14 namespace fizz {
15 
16 using folly::AsyncSocketException;
17 
/**
 * Lower and upper bounds used when preallocating space in the transport read
 * queue for a transport that does not hand us movable buffers.
 */
constexpr uint32_t kMinReadSize = 1460;
constexpr uint32_t kMaxReadSize = 4000;

/**
 * Once buffered (undelivered) data reaches this many bytes while no read
 * callback is installed, the transport read callback is removed so the peer
 * experiences back pressure.
 */
constexpr uint32_t kMaxBufSize = 64 * 1024;

/**
 * Shared write buffers larger than this are split into chunks, so that the
 * whole plaintext buffer and its whole encrypted copy are never held in memory
 * at the same time.
 */
constexpr uint32_t kPartialWriteThreshold = 128 * 1024;
35 
AsyncFizzBase(folly::AsyncTransportWrapper::UniquePtr transport,TransportOptions options)36 AsyncFizzBase::AsyncFizzBase(
37     folly::AsyncTransportWrapper::UniquePtr transport,
38     TransportOptions options)
39     : folly::WriteChainAsyncTransportWrapper<folly::AsyncTransportWrapper>(
40           std::move(transport)),
41       handshakeTimeout_(*this, transport_->getEventBase()),
42       transportOptions_(std::move(options)),
43       ioVecQueue_(folly::IOBufIovecBuilder::Options().setBlockSize(
44           transportOptions_.readVecBlockSize)) {
45   setReadMode(transportOptions_.readMode);
46 }
47 
~AsyncFizzBase()48 AsyncFizzBase::~AsyncFizzBase() {
49   transport_->setEventCallback(nullptr);
50   transport_->setReadCB(nullptr);
51   if (tailWriteRequest_) {
52     tailWriteRequest_->unlinkFromBase();
53   }
54 }
55 
destroy()56 void AsyncFizzBase::destroy() {
57   transport_->closeNow();
58   transport_->setEventCallback(nullptr);
59   transport_->setReadCB(nullptr);
60   DelayedDestruction::destroy();
61 }
62 
getReadCallback() const63 AsyncFizzBase::ReadCallback* AsyncFizzBase::getReadCallback() const {
64   return readCallback_;
65 }
66 
setReadCB(AsyncFizzBase::ReadCallback * callback)67 void AsyncFizzBase::setReadCB(AsyncFizzBase::ReadCallback* callback) {
68   readCallback_ = callback;
69 
70   if (readCallback_) {
71     if (appDataBuf_) {
72       deliverAppData(nullptr);
73     }
74 
75     if (!good()) {
76       AsyncSocketException ex(
77           AsyncSocketException::NOT_OPEN,
78           "setReadCB() called with transport in bad state");
79       deliverError(ex);
80     } else {
81       // The read callback may have been unset earlier if our buffer was full.
82       startTransportReads();
83     }
84   }
85 }
86 
QueuedWriteRequest(AsyncFizzBase * base,folly::AsyncTransportWrapper::WriteCallback * callback,std::unique_ptr<folly::IOBuf> data,folly::WriteFlags flags)87 AsyncFizzBase::QueuedWriteRequest::QueuedWriteRequest(
88     AsyncFizzBase* base,
89     folly::AsyncTransportWrapper::WriteCallback* callback,
90     std::unique_ptr<folly::IOBuf> data,
91     folly::WriteFlags flags)
92     : asyncFizzBase_(base), callback_(callback), flags_(flags) {
93   data_.append(std::move(data));
94   entireChainBytesBuffered = data_.chainLength();
95 }
96 
startWriting()97 void AsyncFizzBase::QueuedWriteRequest::startWriting() {
98   auto buf = data_.splitAtMost(kPartialWriteThreshold);
99 
100   auto flags = flags_;
101   if (!data_.empty()) {
102     flags |= folly::WriteFlags::CORK;
103   }
104   size_t len = buf->computeChainDataLength();
105   dataWritten_ += len;
106 
107   CHECK(asyncFizzBase_);
108   CHECK(asyncFizzBase_->tailWriteRequest_);
109   asyncFizzBase_->tailWriteRequest_->entireChainBytesBuffered -= len;
110   asyncFizzBase_->writeAppData(this, std::move(buf), flags);
111 }
112 
append(QueuedWriteRequest * request)113 void AsyncFizzBase::QueuedWriteRequest::append(QueuedWriteRequest* request) {
114   DCHECK(!next_);
115   next_ = request;
116   next_->entireChainBytesBuffered += entireChainBytesBuffered;
117   entireChainBytesBuffered = 0;
118 }
119 
unlinkFromBase()120 void AsyncFizzBase::QueuedWriteRequest::unlinkFromBase() {
121   asyncFizzBase_ = nullptr;
122 }
123 
writeSuccess()124 void AsyncFizzBase::QueuedWriteRequest::writeSuccess() noexcept {
125   if (!data_.empty()) {
126     startWriting();
127   } else {
128     advanceOnBase();
129     auto callback = callback_;
130     auto next = next_;
131     auto base = asyncFizzBase_;
132     delete this;
133 
134     DelayedDestruction::DestructorGuard dg(base);
135 
136     if (callback) {
137       callback->writeSuccess();
138     }
139     if (next) {
140       next->startWriting();
141     }
142   }
143 }
144 
writeErr(size_t,const folly::AsyncSocketException & ex)145 void AsyncFizzBase::QueuedWriteRequest::writeErr(
146     size_t /* written */,
147     const folly::AsyncSocketException& ex) noexcept {
148   // Deliver the error to all queued writes, starting with this one. We avoid
149   // recursively calling writeErr as that can cause excesssive stack usage if
150   // there are a large number of queued writes.
151   QueuedWriteRequest* errorToDeliver = this;
152   while (errorToDeliver) {
153     errorToDeliver = errorToDeliver->deliverSingleWriteErr(ex);
154   }
155 }
156 
157 AsyncFizzBase::QueuedWriteRequest*
deliverSingleWriteErr(const folly::AsyncSocketException & ex)158 AsyncFizzBase::QueuedWriteRequest::deliverSingleWriteErr(
159     const folly::AsyncSocketException& ex) {
160   advanceOnBase();
161   auto callback = callback_;
162   auto next = next_;
163   auto dataWritten = dataWritten_;
164   delete this;
165 
166   if (callback) {
167     callback->writeErr(dataWritten, ex);
168   }
169 
170   return next;
171 }
172 
advanceOnBase()173 void AsyncFizzBase::QueuedWriteRequest::advanceOnBase() {
174   if (!next_ && asyncFizzBase_) {
175     CHECK_EQ(asyncFizzBase_->tailWriteRequest_, this);
176     asyncFizzBase_->tailWriteRequest_ = nullptr;
177   }
178 }
179 
writeChain(folly::AsyncTransportWrapper::WriteCallback * callback,std::unique_ptr<folly::IOBuf> && buf,folly::WriteFlags flags)180 void AsyncFizzBase::writeChain(
181     folly::AsyncTransportWrapper::WriteCallback* callback,
182     std::unique_ptr<folly::IOBuf>&& buf,
183     folly::WriteFlags flags) {
184   auto writeSize = buf->computeChainDataLength();
185   appBytesWritten_ += writeSize;
186 
187   // We want to split up and queue large writes to avoid simultaneously storing
188   // unencrypted and encrypted large buffer in memory. We can skip this if the
189   // buffer is unshared (because we can encrypt in-place). We also skip this
190   // when sending early data to avoid the possibility of splitting writes
191   // between early data and normal data.
192   bool largeWrite = writeSize > kPartialWriteThreshold;
193   bool transportBuffering = transport_->getRawBytesBuffered() > 0;
194   bool needsToQueue = (largeWrite || transportBuffering) && buf->isShared() &&
195       !connecting() && isReplaySafe();
196   if (tailWriteRequest_ || needsToQueue) {
197     auto newWriteRequest =
198         new QueuedWriteRequest(this, callback, std::move(buf), flags);
199 
200     if (tailWriteRequest_) {
201       tailWriteRequest_->append(newWriteRequest);
202       tailWriteRequest_ = newWriteRequest;
203     } else {
204       tailWriteRequest_ = newWriteRequest;
205       newWriteRequest->startWriting();
206     }
207   } else {
208     writeAppData(callback, std::move(buf), flags);
209   }
210 }
211 
getAppBytesWritten() const212 size_t AsyncFizzBase::getAppBytesWritten() const {
213   return appBytesWritten_;
214 }
215 
getAppBytesReceived() const216 size_t AsyncFizzBase::getAppBytesReceived() const {
217   return appBytesReceived_;
218 }
219 
getAppBytesBuffered() const220 size_t AsyncFizzBase::getAppBytesBuffered() const {
221   return transport_->getAppBytesBuffered() +
222       (tailWriteRequest_ ? tailWriteRequest_->getEntireChainBytesBuffered()
223                          : 0);
224 }
225 
startTransportReads()226 void AsyncFizzBase::startTransportReads() {
227   if (transportOptions_.registerEventCallback) {
228     transport_->setEventCallback(this);
229   }
230   transport_->setReadCB(this);
231 }
232 
startHandshakeTimeout(std::chrono::milliseconds timeout)233 void AsyncFizzBase::startHandshakeTimeout(std::chrono::milliseconds timeout) {
234   handshakeTimeout_.scheduleTimeout(timeout);
235 }
236 
cancelHandshakeTimeout()237 void AsyncFizzBase::cancelHandshakeTimeout() {
238   handshakeTimeout_.cancelTimeout();
239 }
240 
deliverAppData(std::unique_ptr<folly::IOBuf> data)241 void AsyncFizzBase::deliverAppData(std::unique_ptr<folly::IOBuf> data) {
242   if (data) {
243     appBytesReceived_ += data->computeChainDataLength();
244   }
245 
246   if (appDataBuf_) {
247     if (data) {
248       appDataBuf_->prependChain(std::move(data));
249     }
250     data = std::move(appDataBuf_);
251   }
252 
253   while (readCallback_ && data) {
254     if (readCallback_->isBufferMovable()) {
255       return readCallback_->readBufferAvailable(std::move(data));
256     } else {
257       folly::io::Cursor cursor(data.get());
258       size_t available = 0;
259       while ((available = cursor.totalLength()) != 0 && readCallback_ &&
260              !readCallback_->isBufferMovable()) {
261         void* buf = nullptr;
262         size_t buflen = 0;
263         try {
264           readCallback_->getReadBuffer(&buf, &buflen);
265         } catch (const AsyncSocketException& ase) {
266           return deliverError(ase);
267         } catch (const std::exception& e) {
268           AsyncSocketException ase(
269               AsyncSocketException::BAD_ARGS,
270               folly::to<std::string>("getReadBuffer() threw ", e.what()));
271           return deliverError(ase);
272         } catch (...) {
273           AsyncSocketException ase(
274               AsyncSocketException::BAD_ARGS,
275               "getReadBuffer() threw unknown exception");
276           return deliverError(ase);
277         }
278         if (buflen == 0 || buf == nullptr) {
279           AsyncSocketException ase(
280               AsyncSocketException::BAD_ARGS,
281               "getReadBuffer() returned empty buffer");
282           return deliverError(ase);
283         }
284 
285         size_t bytesToRead = std::min(buflen, available);
286         cursor.pull(buf, bytesToRead);
287         readCallback_->readDataAvailable(bytesToRead);
288       }
289 
290       // If we have data left, it means the read callback changed and we need
291       // to save the remaining data (if any)
292       if (available != 0) {
293         std::unique_ptr<folly::IOBuf> remainingData;
294         cursor.clone(remainingData, available);
295         data = std::move(remainingData);
296       } else {
297         // Out of data. Reset the data pointer to end the loop
298         data.reset();
299       }
300     }
301   }
302 
303   if (data) {
304     appDataBuf_ = std::move(data);
305   }
306 
307   checkBufLen();
308 }
309 
deliverError(const AsyncSocketException & ex,bool closeTransport)310 void AsyncFizzBase::deliverError(
311     const AsyncSocketException& ex,
312     bool closeTransport) {
313   DelayedDestruction::DestructorGuard dg(this);
314 
315   if (readCallback_) {
316     auto readCallback = readCallback_;
317     readCallback_ = nullptr;
318     if (ex.getType() == AsyncSocketException::END_OF_FILE) {
319       readCallback->readEOF();
320     } else {
321       readCallback->readErr(ex);
322     }
323   }
324 
325   // Clear the secret callback too.
326   if (secretCallback_) {
327     secretCallback_ = nullptr;
328   }
329 
330   if (closeTransport) {
331     transport_->close();
332   }
333 }
334 
335 class AsyncFizzBase::FizzMsgHdr : public folly::EventRecvmsgCallback::MsgHdr {
336   FizzMsgHdr() = delete;
337 
338  public:
339   ~FizzMsgHdr() override = default;
FizzMsgHdr(AsyncFizzBase * fizzBase)340   explicit FizzMsgHdr(AsyncFizzBase* fizzBase) {
341     arg_ = fizzBase;
342     freeFunc_ = FizzMsgHdr::free;
343     cbFunc_ = FizzMsgHdr::cb;
344   }
345 
reset()346   void reset() {
347     data_ = msghdr{};
348     auto base = static_cast<AsyncFizzBase*>(arg_);
349     base->getReadBuffer(&iov_.iov_base, &iov_.iov_len);
350     data_.msg_iov = &iov_;
351     data_.msg_iovlen = 1;
352   }
353 
free(folly::EventRecvmsgCallback::MsgHdr * msgHdr)354   static void free(folly::EventRecvmsgCallback::MsgHdr* msgHdr) {
355     delete msgHdr;
356   }
357 
cb(folly::EventRecvmsgCallback::MsgHdr * msgHdr,int res)358   static void cb(folly::EventRecvmsgCallback::MsgHdr* msgHdr, int res) {
359     static_cast<AsyncFizzBase*>(msgHdr->arg_)
360         ->eventRecvmsgCallback(static_cast<FizzMsgHdr*>(msgHdr), res);
361   }
362 
363  private:
364   iovec iov_;
365 };
366 
allocateData()367 folly::EventRecvmsgCallback::MsgHdr* AsyncFizzBase::allocateData() {
368   auto* ret = msgHdr_.release();
369   if (!ret) {
370     ret = new FizzMsgHdr(this);
371   }
372   ret->reset();
373   return ret;
374 }
375 
eventRecvmsgCallback(FizzMsgHdr * msgHdr,int res)376 void AsyncFizzBase::eventRecvmsgCallback(FizzMsgHdr* msgHdr, int res) {
377   DelayedDestruction::DestructorGuard dg(this);
378   if (res > 0) {
379     transportReadBuf_.postallocate(res);
380     transportDataAvailable();
381     checkBufLen();
382   } else if (res == 0) {
383     readEOF();
384   } else {
385     AsyncSocketException ex(
386         AsyncSocketException::INTERNAL_ERROR, "event recv failed", (0 - res));
387     deliverError(ex);
388   }
389   msgHdr_.reset(msgHdr);
390 }
391 
getReadBuffer(void ** bufReturn,size_t * lenReturn)392 void AsyncFizzBase::getReadBuffer(void** bufReturn, size_t* lenReturn) {
393   std::pair<void*, uint32_t> readSpace =
394       transportReadBuf_.preallocate(kMinReadSize, kMaxReadSize);
395   *bufReturn = readSpace.first;
396 
397   // `readSizeHint_`, if zero, indicates that we do not care about how much
398   // data we read from the underlying socket.
399   //
400   // `readSizeHint_`, if nonzero, indicates the maximum amount of data we
401   // want to read from the underlying socket. This is necessary for kTLS,
402   // where we want to ensure that when ReportHandshakeSuccess is called, we
403   // are at a known point in the TCP stream, so we can let the kernel start
404   // decrypting records for us.
405   //
406   // For transport with "record aligned reads", we initially set `readSizeHint_`
407   // equal to the size of the TLS record header. Subsequently, the state machine
408   // will tell us exactly how much data is required to complete the record
409   // in WaitForData actions.
410   if (readSizeHint_ > 0) {
411     *lenReturn = std::min(
412         static_cast<decltype(readSizeHint_)>(kMinReadSize), readSizeHint_);
413   } else {
414     *lenReturn = readSpace.second;
415   }
416 }
417 
getReadBuffers(folly::IOBufIovecBuilder::IoVecVec & iovs)418 void AsyncFizzBase::getReadBuffers(folly::IOBufIovecBuilder::IoVecVec& iovs) {
419   ioVecQueue_.allocateBuffers(iovs, kMaxReadSize);
420 }
421 
readDataAvailable(size_t len)422 void AsyncFizzBase::readDataAvailable(size_t len) noexcept {
423   DelayedDestruction::DestructorGuard dg(this);
424 
425   if (getReadMode() == folly::AsyncTransport::ReadCallback::ReadMode::ReadVec) {
426     auto tmp = ioVecQueue_.extractIOBufChain(len);
427     transportReadBuf_.append(std::move(tmp));
428   } else {
429     transportReadBuf_.postallocate(len);
430   }
431   transportDataAvailable();
432   checkBufLen();
433 }
434 
isBufferMovable()435 bool AsyncFizzBase::isBufferMovable() noexcept {
436   return true;
437 }
438 
readBufferAvailable(std::unique_ptr<folly::IOBuf> data)439 void AsyncFizzBase::readBufferAvailable(
440     std::unique_ptr<folly::IOBuf> data) noexcept {
441   DelayedDestruction::DestructorGuard dg(this);
442 
443   transportReadBuf_.append(std::move(data));
444   transportDataAvailable();
445   checkBufLen();
446 }
447 
readEOF()448 void AsyncFizzBase::readEOF() noexcept {
449   AsyncSocketException eof(AsyncSocketException::END_OF_FILE, "readEOF()");
450   transportError(eof);
451 }
452 
readErr(const folly::AsyncSocketException & ex)453 void AsyncFizzBase::readErr(const folly::AsyncSocketException& ex) noexcept {
454   transportError(ex);
455 }
456 
writeSuccess()457 void AsyncFizzBase::writeSuccess() noexcept {}
458 
writeErr(size_t,const folly::AsyncSocketException & ex)459 void AsyncFizzBase::writeErr(
460     size_t /* bytesWritten */,
461     const folly::AsyncSocketException& ex) noexcept {
462   transportError(ex);
463 }
464 
checkBufLen()465 void AsyncFizzBase::checkBufLen() {
466   if (!readCallback_ &&
467       (transportReadBuf_.chainLength() >= kMaxBufSize ||
468        (appDataBuf_ && appDataBuf_->computeChainDataLength() >= kMaxBufSize))) {
469     transport_->setEventCallback(nullptr);
470     transport_->setReadCB(nullptr);
471   }
472 }
473 
handshakeTimeoutExpired()474 void AsyncFizzBase::handshakeTimeoutExpired() noexcept {
475   AsyncSocketException eof(
476       AsyncSocketException::TIMED_OUT, "handshake timeout expired");
477   transportError(eof);
478 }
479 
endOfTLS(std::unique_ptr<folly::IOBuf> endOfData)480 void AsyncFizzBase::endOfTLS(std::unique_ptr<folly::IOBuf> endOfData) noexcept {
481   DelayedDestruction::DestructorGuard dg(this);
482 
483   if (connecting()) {
484     AsyncSocketException ex(
485         AsyncSocketException::INVALID_STATE,
486         "tls connection torn down while connecting");
487     transportError(ex);
488     return;
489   }
490 
491   if (endOfTLSCallback_) {
492     endOfTLSCallback_->endOfTLS(this, std::move(endOfData));
493   } else {
494     // The end of TLS callback may not want the socket to be closed but by
495     // default read callbacks often close on EOF, as such we defer to the setter
496     // of the end of tls callback to apply the appropriate behaviour if it's set
497     if (readCallback_) {
498       auto readCallback = readCallback_;
499       readCallback_ = nullptr;
500       readCallback->readEOF();
501     }
502     transport_->close();
503   }
504 }
505 
506 // The below maps the secret type to the appropriate secret callback function.
507 namespace {
508 class SecretVisitor {
509  public:
SecretVisitor(AsyncFizzBase::SecretCallback * cb,const std::vector<uint8_t> & secretBuf)510   explicit SecretVisitor(
511       AsyncFizzBase::SecretCallback* cb,
512       const std::vector<uint8_t>& secretBuf)
513       : callback_(cb), secretBuf_(secretBuf) {}
operator ()(const SecretType & secretType)514   void operator()(const SecretType& secretType) {
515     switch (secretType.type()) {
516       case SecretType::Type::EarlySecrets_E:
517         operator()(*secretType.asEarlySecrets());
518         break;
519       case SecretType::Type::HandshakeSecrets_E:
520         operator()(*secretType.asHandshakeSecrets());
521         break;
522       case SecretType::Type::MasterSecrets_E:
523         operator()(*secretType.asMasterSecrets());
524         break;
525       case SecretType::Type::AppTrafficSecrets_E:
526         operator()(*secretType.asAppTrafficSecrets());
527         break;
528     }
529   }
530 
operator ()(const EarlySecrets & secret)531   void operator()(const EarlySecrets& secret) {
532     switch (secret) {
533       case EarlySecrets::ExternalPskBinder:
534         callback_->externalPskBinderAvailable(secretBuf_);
535         return;
536       case EarlySecrets::ResumptionPskBinder:
537         callback_->resumptionPskBinderAvailable(secretBuf_);
538         return;
539       case EarlySecrets::ClientEarlyTraffic:
540         callback_->clientEarlyTrafficSecretAvailable(secretBuf_);
541         return;
542       case EarlySecrets::EarlyExporter:
543         callback_->earlyExporterSecretAvailable(secretBuf_);
544         return;
545     }
546   }
operator ()(const HandshakeSecrets & secret)547   void operator()(const HandshakeSecrets& secret) {
548     switch (secret) {
549       case HandshakeSecrets::ClientHandshakeTraffic:
550         callback_->clientHandshakeTrafficSecretAvailable(secretBuf_);
551         return;
552       case HandshakeSecrets::ServerHandshakeTraffic:
553         callback_->serverHandshakeTrafficSecretAvailable(secretBuf_);
554         return;
555     }
556   }
operator ()(const MasterSecrets & secret)557   void operator()(const MasterSecrets& secret) {
558     switch (secret) {
559       case MasterSecrets::ExporterMaster:
560         callback_->exporterMasterSecretAvailable(secretBuf_);
561         return;
562       case MasterSecrets::ResumptionMaster:
563         callback_->resumptionMasterSecretAvailable(secretBuf_);
564         return;
565     }
566   }
operator ()(const AppTrafficSecrets & secret)567   void operator()(const AppTrafficSecrets& secret) {
568     switch (secret) {
569       case AppTrafficSecrets::ClientAppTraffic:
570         callback_->clientAppTrafficSecretAvailable(secretBuf_);
571         return;
572       case AppTrafficSecrets::ServerAppTraffic:
573         callback_->serverAppTrafficSecretAvailable(secretBuf_);
574         return;
575     }
576   }
577 
578  private:
579   AsyncFizzBase::SecretCallback* callback_;
580   const std::vector<uint8_t>& secretBuf_;
581 };
582 } // namespace
583 
secretAvailable(const DerivedSecret & secret)584 void AsyncFizzBase::secretAvailable(const DerivedSecret& secret) noexcept {
585   if (secretCallback_) {
586     SecretVisitor visitor(secretCallback_, secret.secret);
587     visitor(secret.type);
588   }
589 }
590 } // namespace fizz
591