1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <chrono>
20 #include <functional>
21 #include <memory>
22 #include <utility>
23 
24 #include <boost/intrusive/unordered_set.hpp>
25 
26 #include <folly/IntrusiveList.h>
27 #include <folly/Likely.h>
28 #include <folly/Portability.h>
29 #include <folly/fibers/Baton.h>
30 
31 #include <thrift/lib/cpp2/transport/rocket/Types.h>
32 #include <thrift/lib/cpp2/transport/rocket/framing/FrameType.h>
33 #include <thrift/lib/cpp2/transport/rocket/framing/Frames.h>
34 #include <thrift/lib/cpp2/transport/rocket/framing/Serializer.h>
35 #include <thrift/lib/thrift/gen-cpp2/RpcMetadata_types.h>
36 
37 namespace apache {
38 namespace thrift {
39 namespace rocket {
40 class RequestContextQueue;
41 
42 class RequestContext {
43  public:
44   class WriteSuccessCallback {
45    public:
46     virtual ~WriteSuccessCallback() = default;
47     virtual void onWriteSuccess() noexcept = 0;
48   };
49 
50   enum class State : uint8_t {
51     DEFERRED_INIT, /* still needs to be intialized with server version */
52     WRITE_NOT_SCHEDULED,
53     WRITE_SCHEDULED,
54     WRITE_SENDING, /* AsyncSocket::writeChain() called, but WriteCallback has
55                       not yet fired */
56     WRITE_SENT, /* Write to socket completed (possibly with error) */
57     COMPLETE, /* Terminal state. Result stored in responsePayload_ */
58   };
59 
60   template <class Frame>
61   RequestContext(
62       Frame&& frame,
63       RequestContextQueue& queue,
64       SetupFrame* setupFrame = nullptr,
65       WriteSuccessCallback* writeSuccessCallback = nullptr)
queue_(queue)66       : queue_(queue),
67         streamId_(frame.streamId()),
68         frameType_(Frame::frameType()),
69         writeSuccessCallback_(writeSuccessCallback) {
70     serialize(std::forward<Frame>(frame), setupFrame);
71   }
72 
73   template <class InitFunc>
74   RequestContext(
75       InitFunc&& initFunc,
76       int32_t serverVersion,
77       StreamId streamId,
78       RequestContextQueue& queue,
79       WriteSuccessCallback* writeSuccessCallback = nullptr)
queue_(queue)80       : queue_(queue),
81         streamId_(streamId),
82         writeSuccessCallback_(writeSuccessCallback) {
83     if (UNLIKELY(serverVersion == -1)) {
84       deferredInit_ = std::forward<InitFunc>(initFunc);
85       state_ = State::DEFERRED_INIT;
86     } else {
87       std::tie(serializedFrame_, frameType_) = initFunc(serverVersion);
88     }
89   }
90 
91   RequestContext(const RequestContext&) = delete;
92   RequestContext(RequestContext&&) = delete;
93   RequestContext& operator=(const RequestContext&) = delete;
94   RequestContext& operator=(RequestContext&&) = delete;
95 
96   // For REQUEST_RESPONSE contexts, where an immediate matching response is
97   // expected
98   FOLLY_NODISCARD folly::Try<Payload> waitForResponse(
99       std::chrono::milliseconds timeout);
100   FOLLY_NODISCARD folly::Try<Payload> getResponse() &&;
101 
102   // For request types for which an immediate matching response is not
103   // necessarily expected, e.g., REQUEST_FNF and REQUEST_STREAM
104   FOLLY_NODISCARD folly::Try<void> waitForWriteToComplete();
105 
106   void waitForWriteToCompleteSchedule(folly::fibers::Baton::Waiter* waiter);
107   FOLLY_NODISCARD folly::Try<void> waitForWriteToCompleteResult();
108 
setTimeoutInfo(folly::HHWheelTimer & timer,folly::HHWheelTimer::Callback & callback,std::chrono::milliseconds timeout)109   void setTimeoutInfo(
110       folly::HHWheelTimer& timer,
111       folly::HHWheelTimer::Callback& callback,
112       std::chrono::milliseconds timeout) {
113     timer_ = &timer;
114     timeoutCallback_ = &callback;
115     requestTimeout_ = timeout;
116   }
117 
scheduleTimeoutForResponse()118   void scheduleTimeoutForResponse() {
119     DCHECK(isRequestResponse());
120     // In some edge cases, response may arrive before write to socket finishes.
121     if (state_ != State::COMPLETE &&
122         requestTimeout_ != std::chrono::milliseconds::zero()) {
123       timer_->scheduleTimeout(timeoutCallback_, requestTimeout_);
124     }
125   }
126 
serializedChain()127   std::unique_ptr<folly::IOBuf> serializedChain() {
128     DCHECK(serializedFrame_);
129     return std::move(serializedFrame_);
130   }
131 
state()132   State state() const { return state_; }
133 
streamId()134   StreamId streamId() const { return streamId_; }
135 
isRequestResponse()136   bool isRequestResponse() const {
137     return frameType_ == FrameType::REQUEST_RESPONSE;
138   }
139 
140   void onPayloadFrame(PayloadFrame&& payloadFrame);
141   void onErrorFrame(ErrorFrame&& errorFrame);
142 
143   void onWriteSuccess() noexcept;
144 
hasPartialPayload()145   bool hasPartialPayload() const { return responsePayload_.hasValue(); }
146 
initWithVersion(int32_t serverVersion)147   void initWithVersion(int32_t serverVersion) {
148     if (!deferredInit_) {
149       return;
150     }
151     DCHECK(state_ == State::DEFERRED_INIT);
152     std::tie(serializedFrame_, frameType_) = deferredInit_(serverVersion);
153     DCHECK(serializedFrame_ && frameType_ != FrameType::RESERVED);
154     state_ = State::WRITE_NOT_SCHEDULED;
155   }
156 
157  private:
158   RequestContextQueue& queue_;
159   folly::SafeIntrusiveListHook queueHook_;
160   std::unique_ptr<folly::IOBuf> serializedFrame_;
161   StreamId streamId_;
162   FrameType frameType_;
163   State state_{State::WRITE_NOT_SCHEDULED};
164   bool lastInWriteBatch_{false};
165   bool isDummyEndOfBatchMarker_{false};
166 
167   boost::intrusive::unordered_set_member_hook<> setHook_;
168   folly::fibers::Baton baton_;
169   std::chrono::milliseconds requestTimeout_{1000};
170   folly::HHWheelTimer* timer_{nullptr};
171   folly::HHWheelTimer::Callback* timeoutCallback_{nullptr};
172   folly::Try<Payload> responsePayload_;
173   WriteSuccessCallback* const writeSuccessCallback_{nullptr};
174   folly::Function<std::pair<std::unique_ptr<folly::IOBuf>, FrameType>(int32_t)>
175       deferredInit_{nullptr};
176 
177   template <class Frame>
serialize(Frame && frame,SetupFrame * setupFrame)178   void serialize(Frame&& frame, SetupFrame* setupFrame) {
179     DCHECK(!serializedFrame_);
180 
181     serializedFrame_ = std::move(frame).serialize();
182 
183     if (UNLIKELY(setupFrame != nullptr)) {
184       Serializer writer;
185       std::move(*setupFrame).serialize(writer);
186       auto setupBuffer = std::move(writer).move();
187       setupBuffer->prependChain(std::move(serializedFrame_));
188       serializedFrame_ = std::move(setupBuffer);
189     }
190   }
191 
RequestContext(RequestContextQueue & queue)192   explicit RequestContext(RequestContextQueue& queue)
193       : queue_(queue), frameType_(FrameType::REQUEST_RESPONSE) {}
194 
createDummyEndOfBatchMarker(RequestContextQueue & queue)195   static RequestContext& createDummyEndOfBatchMarker(
196       RequestContextQueue& queue) {
197     auto* rctx = new RequestContext(queue);
198     rctx->lastInWriteBatch_ = true;
199     rctx->isDummyEndOfBatchMarker_ = true;
200     rctx->state_ = State::WRITE_SENDING;
201     return *rctx;
202   }
203 
204   struct Equal {
operatorEqual205     bool operator()(
206         const RequestContext& ctxa, const RequestContext& ctxb) const noexcept {
207       return ctxa.streamId_ == ctxb.streamId_;
208     }
209   };
210 
211   struct Hash {
operatorHash212     size_t operator()(const RequestContext& ctx) const noexcept {
213       return std::hash<StreamId::underlying_type>()(
214           static_cast<uint32_t>(ctx.streamId_));
215     }
216   };
217 
218  public:
219   using Queue =
220       folly::CountedIntrusiveList<RequestContext, &RequestContext::queueHook_>;
221 
222   using UnorderedSet = boost::intrusive::unordered_set<
223       RequestContext,
224       boost::intrusive::member_hook<
225           RequestContext,
226           decltype(setHook_),
227           &RequestContext::setHook_>,
228       boost::intrusive::equal<Equal>,
229       boost::intrusive::hash<Hash>>;
230 
231  private:
232   friend class RequestContextQueue;
233 };
234 
235 } // namespace rocket
236 } // namespace thrift
237 } // namespace apache
238