1 /* 2 * Copyright (c) Facebook, Inc. and its affiliates. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #pragma once 18 19 #include <chrono> 20 #include <functional> 21 #include <memory> 22 #include <utility> 23 24 #include <boost/intrusive/unordered_set.hpp> 25 26 #include <folly/IntrusiveList.h> 27 #include <folly/Likely.h> 28 #include <folly/Portability.h> 29 #include <folly/fibers/Baton.h> 30 31 #include <thrift/lib/cpp2/transport/rocket/Types.h> 32 #include <thrift/lib/cpp2/transport/rocket/framing/FrameType.h> 33 #include <thrift/lib/cpp2/transport/rocket/framing/Frames.h> 34 #include <thrift/lib/cpp2/transport/rocket/framing/Serializer.h> 35 #include <thrift/lib/thrift/gen-cpp2/RpcMetadata_types.h> 36 37 namespace apache { 38 namespace thrift { 39 namespace rocket { 40 class RequestContextQueue; 41 42 class RequestContext { 43 public: 44 class WriteSuccessCallback { 45 public: 46 virtual ~WriteSuccessCallback() = default; 47 virtual void onWriteSuccess() noexcept = 0; 48 }; 49 50 enum class State : uint8_t { 51 DEFERRED_INIT, /* still needs to be intialized with server version */ 52 WRITE_NOT_SCHEDULED, 53 WRITE_SCHEDULED, 54 WRITE_SENDING, /* AsyncSocket::writeChain() called, but WriteCallback has 55 not yet fired */ 56 WRITE_SENT, /* Write to socket completed (possibly with error) */ 57 COMPLETE, /* Terminal state. Result stored in responsePayload_ */ 58 }; 59 60 template <class Frame> 61 RequestContext( 62 Frame&& frame, 63 RequestContextQueue& queue, 64 SetupFrame* setupFrame = nullptr, 65 WriteSuccessCallback* writeSuccessCallback = nullptr) queue_(queue)66 : queue_(queue), 67 streamId_(frame.streamId()), 68 frameType_(Frame::frameType()), 69 writeSuccessCallback_(writeSuccessCallback) { 70 serialize(std::forward<Frame>(frame), setupFrame); 71 } 72 73 template <class InitFunc> 74 RequestContext( 75 InitFunc&& initFunc, 76 int32_t serverVersion, 77 StreamId streamId, 78 RequestContextQueue& queue, 79 WriteSuccessCallback* writeSuccessCallback = nullptr) queue_(queue)80 : queue_(queue), 81 streamId_(streamId), 82 writeSuccessCallback_(writeSuccessCallback) { 83 if (UNLIKELY(serverVersion == -1)) { 84 deferredInit_ = std::forward<InitFunc>(initFunc); 85 state_ = State::DEFERRED_INIT; 86 } else { 87 std::tie(serializedFrame_, frameType_) = initFunc(serverVersion); 88 } 89 } 90 91 RequestContext(const RequestContext&) = delete; 92 RequestContext(RequestContext&&) = delete; 93 RequestContext& operator=(const RequestContext&) = delete; 94 RequestContext& operator=(RequestContext&&) = delete; 95 96 // For REQUEST_RESPONSE contexts, where an immediate matching response is 97 // expected 98 FOLLY_NODISCARD folly::Try<Payload> waitForResponse( 99 std::chrono::milliseconds timeout); 100 FOLLY_NODISCARD folly::Try<Payload> getResponse() &&; 101 102 // For request types for which an immediate matching response is not 103 // necessarily expected, e.g., REQUEST_FNF and REQUEST_STREAM 104 FOLLY_NODISCARD folly::Try<void> waitForWriteToComplete(); 105 106 void waitForWriteToCompleteSchedule(folly::fibers::Baton::Waiter* waiter); 107 FOLLY_NODISCARD folly::Try<void> waitForWriteToCompleteResult(); 108 setTimeoutInfo(folly::HHWheelTimer & timer,folly::HHWheelTimer::Callback & callback,std::chrono::milliseconds timeout)109 void setTimeoutInfo( 110 folly::HHWheelTimer& timer, 111 folly::HHWheelTimer::Callback& callback, 112 std::chrono::milliseconds timeout) { 113 timer_ = &timer; 114 timeoutCallback_ = &callback; 115 requestTimeout_ = timeout; 116 } 117 scheduleTimeoutForResponse()118 void scheduleTimeoutForResponse() { 119 DCHECK(isRequestResponse()); 120 // In some edge cases, response may arrive before write to socket finishes. 121 if (state_ != State::COMPLETE && 122 requestTimeout_ != std::chrono::milliseconds::zero()) { 123 timer_->scheduleTimeout(timeoutCallback_, requestTimeout_); 124 } 125 } 126 serializedChain()127 std::unique_ptr<folly::IOBuf> serializedChain() { 128 DCHECK(serializedFrame_); 129 return std::move(serializedFrame_); 130 } 131 state()132 State state() const { return state_; } 133 streamId()134 StreamId streamId() const { return streamId_; } 135 isRequestResponse()136 bool isRequestResponse() const { 137 return frameType_ == FrameType::REQUEST_RESPONSE; 138 } 139 140 void onPayloadFrame(PayloadFrame&& payloadFrame); 141 void onErrorFrame(ErrorFrame&& errorFrame); 142 143 void onWriteSuccess() noexcept; 144 hasPartialPayload()145 bool hasPartialPayload() const { return responsePayload_.hasValue(); } 146 initWithVersion(int32_t serverVersion)147 void initWithVersion(int32_t serverVersion) { 148 if (!deferredInit_) { 149 return; 150 } 151 DCHECK(state_ == State::DEFERRED_INIT); 152 std::tie(serializedFrame_, frameType_) = deferredInit_(serverVersion); 153 DCHECK(serializedFrame_ && frameType_ != FrameType::RESERVED); 154 state_ = State::WRITE_NOT_SCHEDULED; 155 } 156 157 private: 158 RequestContextQueue& queue_; 159 folly::SafeIntrusiveListHook queueHook_; 160 std::unique_ptr<folly::IOBuf> serializedFrame_; 161 StreamId streamId_; 162 FrameType frameType_; 163 State state_{State::WRITE_NOT_SCHEDULED}; 164 bool lastInWriteBatch_{false}; 165 bool isDummyEndOfBatchMarker_{false}; 166 167 boost::intrusive::unordered_set_member_hook<> setHook_; 168 folly::fibers::Baton baton_; 169 std::chrono::milliseconds requestTimeout_{1000}; 170 folly::HHWheelTimer* timer_{nullptr}; 171 folly::HHWheelTimer::Callback* timeoutCallback_{nullptr}; 172 folly::Try<Payload> responsePayload_; 173 WriteSuccessCallback* const writeSuccessCallback_{nullptr}; 174 folly::Function<std::pair<std::unique_ptr<folly::IOBuf>, FrameType>(int32_t)> 175 deferredInit_{nullptr}; 176 177 template <class Frame> serialize(Frame && frame,SetupFrame * setupFrame)178 void serialize(Frame&& frame, SetupFrame* setupFrame) { 179 DCHECK(!serializedFrame_); 180 181 serializedFrame_ = std::move(frame).serialize(); 182 183 if (UNLIKELY(setupFrame != nullptr)) { 184 Serializer writer; 185 std::move(*setupFrame).serialize(writer); 186 auto setupBuffer = std::move(writer).move(); 187 setupBuffer->prependChain(std::move(serializedFrame_)); 188 serializedFrame_ = std::move(setupBuffer); 189 } 190 } 191 RequestContext(RequestContextQueue & queue)192 explicit RequestContext(RequestContextQueue& queue) 193 : queue_(queue), frameType_(FrameType::REQUEST_RESPONSE) {} 194 createDummyEndOfBatchMarker(RequestContextQueue & queue)195 static RequestContext& createDummyEndOfBatchMarker( 196 RequestContextQueue& queue) { 197 auto* rctx = new RequestContext(queue); 198 rctx->lastInWriteBatch_ = true; 199 rctx->isDummyEndOfBatchMarker_ = true; 200 rctx->state_ = State::WRITE_SENDING; 201 return *rctx; 202 } 203 204 struct Equal { operatorEqual205 bool operator()( 206 const RequestContext& ctxa, const RequestContext& ctxb) const noexcept { 207 return ctxa.streamId_ == ctxb.streamId_; 208 } 209 }; 210 211 struct Hash { operatorHash212 size_t operator()(const RequestContext& ctx) const noexcept { 213 return std::hash<StreamId::underlying_type>()( 214 static_cast<uint32_t>(ctx.streamId_)); 215 } 216 }; 217 218 public: 219 using Queue = 220 folly::CountedIntrusiveList<RequestContext, &RequestContext::queueHook_>; 221 222 using UnorderedSet = boost::intrusive::unordered_set< 223 RequestContext, 224 boost::intrusive::member_hook< 225 RequestContext, 226 decltype(setHook_), 227 &RequestContext::setHook_>, 228 boost::intrusive::equal<Equal>, 229 boost::intrusive::hash<Hash>>; 230 231 private: 232 friend class RequestContextQueue; 233 }; 234 235 } // namespace rocket 236 } // namespace thrift 237 } // namespace apache 238