1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <proxygen/lib/http/session/HQUpstreamSession.h>
10 #include <wangle/acceptor/ConnectionManager.h>
11 
12 namespace proxygen {
13 
// The session must not track any streams by the time it is destroyed;
// CHECK_EQ aborts the process if any remain.
HQUpstreamSession::~HQUpstreamSession() {
  CHECK_EQ(getNumStreams(), 0);
}
17 
startNow()18 void HQUpstreamSession::startNow() {
19   HQSession::startNow();
20   if (connectCb_ && connectTimeoutMs_.count() > 0) {
21     // Start a timer in case the connection takes too long.
22     getEventBase()->timer().scheduleTimeout(&connectTimeout_,
23                                             connectTimeoutMs_);
24   }
25 }
26 
connectTimeoutExpired()27 void HQUpstreamSession::connectTimeoutExpired() noexcept {
28   VLOG(4) << __func__ << " sess=" << *this << ": connection failed";
29   if (connectCb_) {
30     onConnectionError(std::make_pair(quic::LocalErrorCode::CONNECT_FAILED,
31                                      "connect timeout"));
32   }
33 }
34 
onTransportReady()35 void HQUpstreamSession::onTransportReady() noexcept {
36   HQUpstreamSession::DestructorGuard dg(this);
37   if (!HQSession::onTransportReadyCommon()) {
38     // Something went wrong in onTransportReady, e.g. the ALPN is not supported
39     return;
40   }
41   connectSuccess();
42 }
43 
onFirstPeerPacketProcessed()44 void HQUpstreamSession::onFirstPeerPacketProcessed() noexcept {
45   HQUpstreamSession::DestructorGuard dg(this);
46   if (connectCb_) {
47     connectCb_->onFirstPeerPacketProcessed();
48   }
49 }
50 
connectSuccess()51 void HQUpstreamSession::connectSuccess() noexcept {
52   HQUpstreamSession::DestructorGuard dg(this);
53   if (connectCb_) {
54     connectCb_->connectSuccess();
55   }
56   if (connCbState_ == ConnCallbackState::REPLAY_SAFE) {
57     handleReplaySafe();
58     connCbState_ = ConnCallbackState::DONE;
59   } else {
60     connCbState_ = ConnCallbackState::CONNECT_SUCCESS;
61   }
62 }
63 
onReplaySafe()64 void HQUpstreamSession::onReplaySafe() noexcept {
65   HQUpstreamSession::DestructorGuard dg(this);
66   if (connCbState_ == ConnCallbackState::CONNECT_SUCCESS) {
67     handleReplaySafe();
68     connCbState_ = ConnCallbackState::DONE;
69   } else {
70     connCbState_ = ConnCallbackState::REPLAY_SAFE;
71   }
72 }
73 
handleReplaySafe()74 void HQUpstreamSession::handleReplaySafe() noexcept {
75   HQSession::onReplaySafe();
76   // In the case that zero rtt, onTransportReady is almost called
77   // immediately without proof of network reachability, and onReplaySafe is
78   // expected to be called in 1 rtt time (if success).
79   if (connectCb_) {
80     auto cb = connectCb_;
81     connectCb_ = nullptr;
82     connectTimeout_.cancelTimeout();
83     cb->onReplaySafe();
84   }
85 }
86 
onConnectionEnd()87 void HQUpstreamSession::onConnectionEnd() noexcept {
88   VLOG(4) << __func__ << " sess=" << *this;
89 
90   HQSession::DestructorGuard dg(this);
91   if (connectCb_) {
92     onConnectionErrorHandler(std::make_pair(
93         quic::LocalErrorCode::CONNECT_FAILED, "session destroyed"));
94   }
95   HQSession::onConnectionEnd();
96 }
97 
onConnectionErrorHandler(std::pair<quic::QuicErrorCode,std::string> code)98 void HQUpstreamSession::onConnectionErrorHandler(
99     std::pair<quic::QuicErrorCode, std::string> code) noexcept {
100   // For an upstream connection, any error before onTransportReady gets
101   // notified as a connect error.
102   if (connectCb_) {
103     HQSession::DestructorGuard dg(this);
104     auto cb = connectCb_;
105     connectCb_ = nullptr;
106     cb->connectError(std::move(code));
107     connectTimeout_.cancelTimeout();
108   }
109 }
110 
isDetachable(bool checkSocket) const111 bool HQUpstreamSession::isDetachable(bool checkSocket) const {
112   VLOG(4) << __func__ << " sess=" << *this;
113   // TODO: deal with control streams in h2q
114   if (checkSocket && sock_ && !sock_->isDetachable()) {
115     return false;
116   }
117   return getNumOutgoingStreams() == 0 && getNumIncomingStreams() == 0;
118 }
119 
attachThreadLocals(folly::EventBase * eventBase,folly::SSLContextPtr,const WheelTimerInstance & timeout,HTTPSessionStats * stats,FilterIteratorFn fn,HeaderCodec::Stats * headerCodecStats,HTTPSessionController * controller)120 void HQUpstreamSession::attachThreadLocals(folly::EventBase* eventBase,
121                                            folly::SSLContextPtr,
122                                            const WheelTimerInstance& timeout,
123                                            HTTPSessionStats* stats,
124                                            FilterIteratorFn fn,
125                                            HeaderCodec::Stats* headerCodecStats,
126                                            HTTPSessionController* controller) {
127   // TODO: deal with control streams in h2q
128   VLOG(4) << __func__ << " sess=" << *this;
129   txnEgressQueue_.attachThreadLocals(timeout);
130   setController(controller);
131   setSessionStats(stats);
132   if (sock_) {
133     sock_->attachEventBase(eventBase);
134   }
135   codec_.foreach (fn);
136   setHeaderCodecStats(headerCodecStats);
137   sock_->getEventBase()->runInLoop(this);
138   // The caller MUST re-add the connection to a new connection manager.
139 }
140 
detachThreadLocals(bool)141 void HQUpstreamSession::detachThreadLocals(bool) {
142   VLOG(4) << __func__ << " sess=" << *this;
143   // TODO: deal with control streams in h2q
144   CHECK_EQ(getNumOutgoingStreams(), 0);
145   cancelLoopCallback();
146 
147   // TODO: Pause reads and invoke infocallback
148   // pauseReadsImpl();
149   if (sock_) {
150     sock_->detachEventBase();
151   }
152 
153   txnEgressQueue_.detachThreadLocals();
154   setController(nullptr);
155   setSessionStats(nullptr);
156   // The codec filters *shouldn't* be accessible while the socket is detached,
157   // I hope
158   setHeaderCodecStats(nullptr);
159   auto cm = getConnectionManager();
160   if (cm) {
161     cm->removeConnection(this);
162   }
163 }
164 
onNetworkSwitch(std::unique_ptr<folly::AsyncUDPSocket> newSock)165 void HQUpstreamSession::onNetworkSwitch(
166     std::unique_ptr<folly::AsyncUDPSocket> newSock) noexcept {
167   if (sock_) {
168     sock_->onNetworkSwitch(std::move(newSock));
169   }
170 }
171 
tryBindIngressStreamToTxn(quic::StreamId streamId,hq::PushId pushId,HQIngressPushStream * pushStream)172 bool HQUpstreamSession::tryBindIngressStreamToTxn(
173     quic::StreamId streamId,
174     hq::PushId pushId,
175     HQIngressPushStream* pushStream) {
176   // lookup pending nascent stream id
177   CHECK(pushStream);
178 
179   VLOG(4) << __func__ << " attempting to bind streamID=" << streamId
180           << " to pushID=" << pushId;
181   pushStream->bindTo(streamId);
182 
183 #if DEBUG
184   // Check postconditions - the ingress push stream
185   // should own both the push id and the stream id.
186   // No nascent stream should own the stream id
187   auto streamById = findIngressPushStream(streamId);
188   auto streamByPushId = findIngressPushStreamByPushId(pushId);
189 
190   DCHECK_EQ(streamId, pushStream->getIngressStreamId());
191   DCHECK(streamById) << "Ingress stream must be bound to the streamID="
192                      << streamId;
193   DCHECK(streamByPushId) << "Ingress stream must be found by the pushID="
194                          << pushId;
195   DCHECK_EQ(streamById, streamByPushId) << "Must be same stream";
196 #endif
197 
198   VLOG(4) << __func__ << " successfully bound streamID=" << streamId
199           << " to pushID=" << pushId;
200   return true;
201 }
202 
203 // Called when we receive a push promise
204 HQUpstreamSession::HQStreamTransportBase*
createIngressPushStream(HTTPCodec::StreamID parentId,hq::PushId pushId)205 HQUpstreamSession::createIngressPushStream(HTTPCodec::StreamID parentId,
206                                            hq::PushId pushId) {
207 
208   // Check that a stream with this ID has not been created yet
209   DCHECK(!findIngressPushStreamByPushId(pushId))
210       << "Ingress stream with this push ID already exists pushID=" << pushId;
211 
212   auto matchPair = ingressPushStreams_.emplace(
213       std::piecewise_construct,
214       std::forward_as_tuple(pushId),
215       std::forward_as_tuple(
216           *this,
217           pushId,
218           parentId,
219           getNumTxnServed(),
220           WheelTimerInstance(transactionsTimeout_, getEventBase())));
221 
222   CHECK(matchPair.second) << "Emplacement failed, despite earlier "
223                              "existence check.";
224 
225   auto newIngressPushStream = &matchPair.first->second;
226 
227   // If there is a nascent stream ready to be bound to the newly
228   // created ingress stream, do it now.
229   bool bound = false;
230   auto res = pushIdToStreamId_.find(pushId);
231   if (res == pushIdToStreamId_.end()) {
232     VLOG(4)
233         << __func__ << " pushID=" << pushId
234         << " not found in the lookup table, size=" << pushIdToStreamId_.size();
235   } else {
236     bound =
237         tryBindIngressStreamToTxn(res->second, pushId, newIngressPushStream);
238   }
239 
240   VLOG(4) << "Successfully created new ingress push stream"
241           << " pushID=" << pushId << " parentStreamID=" << parentId
242           << " bound=" << bound << " streamID="
243           << (bound ? newIngressPushStream->getIngressStreamId()
244                     : static_cast<unsigned long>(-1));
245 
246   return newIngressPushStream;
247 }
248 
findPushStream(quic::StreamId streamId)249 HQSession::HQStreamTransportBase* HQUpstreamSession::findPushStream(
250     quic::StreamId streamId) {
251   return findIngressPushStream(streamId);
252 }
253 
254 HQUpstreamSession::HQIngressPushStream* FOLLY_NULLABLE
findIngressPushStream(quic::StreamId streamId)255 HQUpstreamSession::findIngressPushStream(quic::StreamId streamId) {
256   auto res = streamIdToPushId_.find(streamId);
257   if (res == streamIdToPushId_.end()) {
258     return nullptr;
259   } else {
260     return findIngressPushStreamByPushId(res->second);
261   }
262 }
263 
264 HQUpstreamSession::HQIngressPushStream* FOLLY_NULLABLE
findIngressPushStreamByPushId(hq::PushId pushId)265 HQUpstreamSession::findIngressPushStreamByPushId(hq::PushId pushId) {
266   VLOG(4) << __func__ << " looking up ingress push stream by pushID=" << pushId;
267   auto it = ingressPushStreams_.find(pushId);
268   if (it == ingressPushStreams_.end()) {
269     return nullptr;
270   } else {
271     return &it->second;
272   }
273 }
274 
erasePushStream(quic::StreamId streamId)275 bool HQUpstreamSession::erasePushStream(quic::StreamId streamId) {
276   auto res = streamIdToPushId_.find(streamId);
277   if (res != streamIdToPushId_.end()) {
278     auto pushId = res->second;
279     // Ingress push stream may be using the push id
280     // erase it as well if present
281     ingressPushStreams_.erase(pushId);
282 
283     // Unconditionally erase the lookup entry tables
284     streamIdToPushId_.erase(res);
285     pushIdToStreamId_.erase(pushId);
286     return true;
287   }
288   return false;
289 }
290 
numberOfIngressPushStreams() const291 uint32_t HQUpstreamSession::numberOfIngressPushStreams() const {
292   return ingressPushStreams_.size();
293 }
294 
onNewPushStream(quic::StreamId pushStreamId,hq::PushId pushId,size_t toConsume)295 void HQUpstreamSession::onNewPushStream(quic::StreamId pushStreamId,
296                                         hq::PushId pushId,
297                                         size_t toConsume) {
298   VLOG(4) << __func__ << " streamID=" << pushStreamId << " pushId=" << pushId;
299 
300   // TODO: if/when we support client goaway, reject stream if
301   // pushId >= minUnseenIncomingPushId_ after the GOAWAY is sent
302   minUnseenIncomingPushId_ = std::max(minUnseenIncomingPushId_, pushId);
303   DCHECK_GT(toConsume, 0);
304 
305   bool eom = false;
306   if (serverPushLifecycleCb_) {
307     serverPushLifecycleCb_->onNascentPushStreamBegin(pushStreamId, eom);
308   }
309 
310   auto consumeRes = sock_->consume(pushStreamId, toConsume);
311   CHECK(!consumeRes.hasError())
312       << "Unexpected error " << consumeRes.error() << " while consuming "
313       << toConsume << " bytes from stream=" << pushStreamId
314       << " pushId=" << pushId;
315 
316   // Replace the peek callback with a read callback and pause the read callback
317   sock_->setReadCallback(pushStreamId, this);
318   sock_->setPeekCallback(pushStreamId, nullptr);
319   sock_->pauseRead(pushStreamId);
320 
321   // Increment the sequence no to account for the new transport-like stream
322   incrementSeqNo();
323 
324   pushIdToStreamId_.emplace(pushId, pushStreamId);
325   streamIdToPushId_.emplace(pushStreamId, pushId);
326 
327   VLOG(4) << __func__ << " assigned lookup from pushID=" << pushId
328           << " to streamID=" << pushStreamId;
329 
330   // We have successfully read the push id. Notify the testing callbacks
331   if (serverPushLifecycleCb_) {
332     serverPushLifecycleCb_->onNascentPushStream(pushStreamId, pushId, eom);
333   }
334 
335   // If the transaction for the incoming push stream has been created
336   // already, bind the new stream to the transaction
337   auto ingressPushStream = findIngressPushStreamByPushId(pushId);
338 
339   if (ingressPushStream) {
340     auto bound =
341         tryBindIngressStreamToTxn(pushStreamId, pushId, ingressPushStream);
342     VLOG(4) << __func__ << " bound=" << bound << " pushID=" << pushId
343             << " pushStreamID=" << pushStreamId << " to txn ";
344   }
345 }
346 
bindTo(quic::StreamId streamId)347 void HQUpstreamSession::HQIngressPushStream::bindTo(quic::StreamId streamId) {
348   // Ensure the nascent push stream is in correct state
349   // and that its push id matches this stream's push id
350   DCHECK(txn_.getAssocTxnId().has_value());
351   VLOG(4) << __func__ << " Binding streamID=" << streamId
352           << " to txn=" << txn_.getID();
353 #if DEBUG
354   // will throw bad-cast
355   HQUpstreamSession& session = dynamic_cast<HQUpstreamSession&>(session_);
356 #else
357   HQUpstreamSession& session = static_cast<HQUpstreamSession&>(session_);
358 #endif
359   // Initialize this stream's codec with the id of the transport stream
360   auto codec = session.versionUtils_->createCodec(streamId);
361   initCodec(std::move(codec), __func__);
362   DCHECK_EQ(*codecStreamId_, streamId);
363 
364   // Now that the codec is initialized, set the stream ID
365   // of the push stream
366   setIngressStreamId(streamId);
367   DCHECK_EQ(getIngressStreamId(), streamId);
368 
369   // Enable ingress on this stream. Read callback for the stream's
370   // id will be transferred to the HQSession
371   initIngress(__func__);
372 
373   // Re-enable reads
374   session.resumeReadsForPushStream(streamId);
375 
376   // Notify testing callbacks that a full push transaction
377   // has been successfully initialized
378   if (session.serverPushLifecycleCb_) {
379     session.serverPushLifecycleCb_->onPushedTxn(&txn_,
380                                                 streamId,
381                                                 getPushId(),
382                                                 txn_.getAssocTxnId().value(),
383                                                 false /* eof */);
384   }
385 }
386 
387 // This can only be unbound in that it has not received a stream ID yet
eraseUnboundStream(HQStreamTransportBase * hqStream)388 void HQUpstreamSession::eraseUnboundStream(HQStreamTransportBase* hqStream) {
389   auto hqPushIngressStream = dynamic_cast<HQIngressPushStream*>(hqStream);
390   CHECK(hqPushIngressStream)
391       << "Only HQIngressPushStream streams are allowed to be non-bound";
392   // This is what makes it unbound, it also cannot be in the map
393   DCHECK(!hqStream->hasIngressStreamId());
394   auto pushId = hqPushIngressStream->getPushId();
395   DCHECK(pushIdToStreamId_.find(pushId) == pushIdToStreamId_.end());
396   ingressPushStreams_.erase(pushId);
397 }
398 
cleanupUnboundPushStreams(std::vector<quic::StreamId> & streamsToCleanup)399 void HQUpstreamSession::cleanupUnboundPushStreams(
400     std::vector<quic::StreamId>& streamsToCleanup) {
401   for (auto& it : streamIdToPushId_) {
402     auto streamId = it.first;
403     auto pushId = it.second;
404     if (!ingressPushStreams_.count(pushId)) {
405       streamsToCleanup.push_back(streamId);
406     }
407   }
408 }
409 } // namespace proxygen
410