1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <thrift/lib/cpp2/server/Cpp2Worker.h>
18
19 #include <vector>
20
21 #include <glog/logging.h>
22
23 #include <folly/Overload.h>
24 #include <folly/String.h>
25 #include <folly/io/async/AsyncSSLSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/EventBaseLocal.h>
28 #include <folly/portability/Sockets.h>
29 #include <thrift/lib/cpp/async/TAsyncSSLSocket.h>
30 #include <thrift/lib/cpp/concurrency/Util.h>
31 #include <thrift/lib/cpp2/async/ResponseChannel.h>
32 #include <thrift/lib/cpp2/security/extensions/ThriftParametersContext.h>
33 #include <thrift/lib/cpp2/server/Cpp2Connection.h>
34 #include <thrift/lib/cpp2/server/LoggingEvent.h>
35 #include <thrift/lib/cpp2/server/ThriftServer.h>
36 #include <thrift/lib/cpp2/server/peeking/PeekingManager.h>
37 #include <thrift/lib/thrift/gen-cpp2/RpcMetadata_types.h>
38 #include <wangle/acceptor/EvbHandshakeHelper.h>
39 #include <wangle/acceptor/SSLAcceptorHandshakeHelper.h>
40 #include <wangle/acceptor/UnencryptedAcceptorHandshakeHelper.h>
41
42 namespace apache {
43 namespace thrift {
44
45 using namespace apache::thrift::server;
46 using namespace apache::thrift::transport;
47 using namespace apache::thrift::async;
48 using apache::thrift::concurrency::Util;
49 using std::shared_ptr;
50
namespace {
// File-local holder of one RequestsRegistry per EventBase. It is a
// folly::LeakySingleton (created on first use, never destroyed); workers obtain
// their EventBase's registry from it in Cpp2Worker::initRequestsRegistry().
folly::LeakySingleton<folly::EventBaseLocal<RequestsRegistry>> registry;
} // namespace
54
initRequestsRegistry()55 void Cpp2Worker::initRequestsRegistry() {
56 auto* evb = getEventBase();
57 auto memPerReq = server_->getMaxDebugPayloadMemoryPerRequest();
58 auto memPerWorker = server_->getMaxDebugPayloadMemoryPerWorker();
59 auto maxFinished = server_->getMaxFinishedDebugPayloadsPerWorker();
60 std::weak_ptr<Cpp2Worker> self_weak = shared_from_this();
61 evb->runInEventBaseThread([=, self_weak = std::move(self_weak)]() {
62 if (auto self = self_weak.lock()) {
63 self->requestsRegistry_ = ®istry.get().try_emplace(
64 *evb, memPerReq, memPerWorker, maxFinished);
65 }
66 });
67 }
68
onNewConnection(folly::AsyncTransport::UniquePtr sock,const folly::SocketAddress * addr,const std::string & nextProtocolName,wangle::SecureTransportType secureTransportType,const wangle::TransportInfo & tinfo)69 void Cpp2Worker::onNewConnection(
70 folly::AsyncTransport::UniquePtr sock,
71 const folly::SocketAddress* addr,
72 const std::string& nextProtocolName,
73 wangle::SecureTransportType secureTransportType,
74 const wangle::TransportInfo& tinfo) {
75 // This is possible if the connection was accepted before stopListening()
76 // call, but handshake was finished after stopCPUWorkers() call.
77 if (stopping_) {
78 return;
79 }
80
81 auto* observer = server_->getObserver();
82 uint32_t maxConnection = server_->getMaxConnections();
83 if (maxConnection > 0 &&
84 (getConnectionManager()->getNumConnections() >=
85 maxConnection / server_->getNumIOWorkerThreads())) {
86 if (observer) {
87 observer->connDropped();
88 observer->connRejected();
89 }
90 return;
91 }
92
93 const auto& func = server_->getZeroCopyEnableFunc();
94 if (func && sock) {
95 sock->setZeroCopy(true);
96 sock->setZeroCopyEnableFunc(func);
97 }
98
99 // Check the security protocol
100 switch (secureTransportType) {
101 // If no security, peek into the socket to determine type
102 case wangle::SecureTransportType::NONE: {
103 new TransportPeekingManager(
104 shared_from_this(), *addr, tinfo, server_, std::move(sock));
105 break;
106 }
107 case wangle::SecureTransportType::TLS:
108 // Use the announced protocol to determine the correct handler
109 if (!nextProtocolName.empty()) {
110 for (auto& routingHandler : *server_->getRoutingHandlers()) {
111 if (routingHandler->canAcceptEncryptedConnection(nextProtocolName)) {
112 VLOG(4) << "Cpp2Worker: Routing encrypted connection for protocol "
113 << nextProtocolName;
114 routingHandler->handleConnection(
115 getConnectionManager(),
116 std::move(sock),
117 addr,
118 tinfo,
119 shared_from_this());
120 return;
121 }
122 }
123 }
124 if (!getServer()->isDuplex()) {
125 new TransportPeekingManager(
126 shared_from_this(), *addr, tinfo, server_, std::move(sock));
127 } else {
128 handleHeader(std::move(sock), addr, tinfo);
129 }
130 break;
131 default:
132 LOG(ERROR) << "Unsupported Secure Transport Type";
133 break;
134 }
135 }
136
handleHeader(folly::AsyncTransport::UniquePtr sock,const folly::SocketAddress * addr,const wangle::TransportInfo & tinfo)137 void Cpp2Worker::handleHeader(
138 folly::AsyncTransport::UniquePtr sock,
139 const folly::SocketAddress* addr,
140 const wangle::TransportInfo& tinfo) {
141 auto fd = sock->getUnderlyingTransport<folly::AsyncSocket>()
142 ->getNetworkSocket()
143 .toFd();
144 VLOG(4) << "Cpp2Worker: Creating connection for socket " << fd;
145
146 auto thriftTransport = createThriftTransport(std::move(sock));
147 auto connection = std::make_shared<Cpp2Connection>(
148 std::move(thriftTransport), addr, shared_from_this(), nullptr);
149 Acceptor::addConnection(connection.get());
150 connection->addConnection(connection);
151 connection->start();
152
153 VLOG(4) << "Cpp2Worker: created connection for socket " << fd;
154
155 auto observer = server_->getObserver();
156 if (observer) {
157 observer->connAccepted(tinfo);
158 observer->activeConnections(
159 getConnectionManager()->getNumConnections() *
160 server_->getNumIOWorkerThreads());
161 }
162 }
163
createThriftTransport(folly::AsyncTransport::UniquePtr sock)164 std::shared_ptr<folly::AsyncTransport> Cpp2Worker::createThriftTransport(
165 folly::AsyncTransport::UniquePtr sock) {
166 auto fizzServer = dynamic_cast<fizz::server::AsyncFizzServer*>(sock.get());
167 if (fizzServer) {
168 auto asyncSock = sock->getUnderlyingTransport<folly::AsyncSocket>();
169 if (asyncSock) {
170 markSocketAccepted(asyncSock);
171 }
172 // give up ownership
173 sock.release();
174 return std::shared_ptr<fizz::server::AsyncFizzServer>(
175 fizzServer, fizz::server::AsyncFizzServer::Destructor());
176 }
177
178 folly::AsyncSocket* tsock =
179 sock->getUnderlyingTransport<folly::AsyncSocket>();
180 CHECK(tsock);
181 markSocketAccepted(tsock);
182 // use custom deleter for std::shared_ptr<folly::AsyncTransport> to allow
183 // socket transfer from header to rocket (if enabled by ThriftFlags)
184 return apache::thrift::transport::detail::convertToShared(std::move(sock));
185 }
186
markSocketAccepted(folly::AsyncSocket * sock)187 void Cpp2Worker::markSocketAccepted(folly::AsyncSocket* sock) {
188 sock->setShutdownSocketSet(server_->wShutdownSocketSet_);
189 }
190
plaintextConnectionReady(folly::AsyncSocket::UniquePtr sock,const folly::SocketAddress & clientAddr,wangle::TransportInfo & tinfo)191 void Cpp2Worker::plaintextConnectionReady(
192 folly::AsyncSocket::UniquePtr sock,
193 const folly::SocketAddress& clientAddr,
194 wangle::TransportInfo& tinfo) {
195 sock->setShutdownSocketSet(server_->wShutdownSocketSet_);
196 new CheckTLSPeekingManager(
197 shared_from_this(),
198 clientAddr,
199 tinfo,
200 server_,
201 std::move(sock),
202 server_->getObserverShared());
203 }
204
useExistingChannel(const std::shared_ptr<HeaderServerChannel> & serverChannel)205 void Cpp2Worker::useExistingChannel(
206 const std::shared_ptr<HeaderServerChannel>& serverChannel) {
207 folly::SocketAddress address;
208
209 auto conn = std::make_shared<Cpp2Connection>(
210 nullptr, &address, shared_from_this(), serverChannel);
211 Acceptor::getConnectionManager()->addConnection(conn.get(), false);
212 conn->addConnection(conn);
213
214 conn->start();
215 }
216
stopDuplex(std::shared_ptr<ThriftServer> myServer)217 void Cpp2Worker::stopDuplex(std::shared_ptr<ThriftServer> myServer) {
218 // They better have given us the correct ThriftServer
219 DCHECK(server_ == myServer.get());
220
221 // This does not really fully drain everything but at least
222 // prevents the connections from accepting new requests
223 wangle::Acceptor::drainAllConnections();
224
225 // Capture a shared_ptr to our ThriftServer making sure it will outlive us
226 // Otherwise our raw pointer to it (server_) will be jeopardized.
227 duplexServer_ = myServer;
228 }
229
updateSSLStats(const folly::AsyncTransport * sock,std::chrono::milliseconds,wangle::SSLErrorEnum error,const folly::exception_wrapper &)230 void Cpp2Worker::updateSSLStats(
231 const folly::AsyncTransport* sock,
232 std::chrono::milliseconds /* acceptLatency */,
233 wangle::SSLErrorEnum error,
234 const folly::exception_wrapper& /*ex*/) noexcept {
235 if (!sock) {
236 return;
237 }
238
239 auto observer = getServer()->getObserver();
240 if (!observer) {
241 return;
242 }
243
244 auto fizz = sock->getUnderlyingTransport<fizz::server::AsyncFizzServer>();
245 if (fizz) {
246 if (sock->good() && error == wangle::SSLErrorEnum::NO_ERROR) {
247 observer->tlsComplete();
248 auto pskType = fizz->getState().pskType();
249 if (pskType && *pskType == fizz::PskType::Resumption) {
250 observer->tlsResumption();
251 }
252 if (fizz->getPeerCertificate()) {
253 observer->tlsWithClientCert();
254 }
255 } else {
256 observer->tlsError();
257 }
258 } else {
259 auto socket = sock->getUnderlyingTransport<folly::AsyncSSLSocket>();
260 if (!socket) {
261 return;
262 }
263 if (socket->good() && error == wangle::SSLErrorEnum::NO_ERROR) {
264 observer->tlsComplete();
265 if (socket->getSSLSessionReused()) {
266 observer->tlsResumption();
267 }
268 if (socket->getPeerCertificate()) {
269 observer->tlsWithClientCert();
270 }
271 } else {
272 observer->tlsError();
273 }
274 }
275 }
276
createSSLHelper(const std::vector<uint8_t> & bytes,const folly::SocketAddress & clientAddr,std::chrono::steady_clock::time_point acceptTime,wangle::TransportInfo & tInfo)277 wangle::AcceptorHandshakeHelper::UniquePtr Cpp2Worker::createSSLHelper(
278 const std::vector<uint8_t>& bytes,
279 const folly::SocketAddress& clientAddr,
280 std::chrono::steady_clock::time_point acceptTime,
281 wangle::TransportInfo& tInfo) {
282 if (accConfig_.fizzConfig.enableFizz) {
283 if (auto parametersContext = getThriftParametersContext()) {
284 fizzPeeker_.setThriftParametersContext(
285 folly::copy_to_shared_ptr(*parametersContext));
286 }
287 return getFizzPeeker()->getHelper(bytes, clientAddr, acceptTime, tInfo);
288 }
289 return defaultPeekingCallback_.getHelper(
290 bytes, clientAddr, acceptTime, tInfo);
291 }
292
shouldPerformSSL(const std::vector<uint8_t> & bytes,const folly::SocketAddress & clientAddr)293 bool Cpp2Worker::shouldPerformSSL(
294 const std::vector<uint8_t>& bytes, const folly::SocketAddress& clientAddr) {
295 auto sslPolicy = getSSLPolicy();
296 if (sslPolicy == SSLPolicy::REQUIRED) {
297 if (isPlaintextAllowedOnLoopback()) {
298 // loopback clients may still be sending TLS so we need to ensure that
299 // it doesn't appear that way in addition to verifying it's loopback.
300 return !(
301 clientAddr.isLoopbackAddress() && !TLSHelper::looksLikeTLS(bytes));
302 }
303 return true;
304 } else {
305 return sslPolicy != SSLPolicy::DISABLED && TLSHelper::looksLikeTLS(bytes);
306 }
307 }
308
309 std::optional<ThriftParametersContext>
getThriftParametersContext()310 Cpp2Worker::getThriftParametersContext() {
311 auto thriftConfigBase =
312 folly::get_ptr(accConfig_.customConfigMap, "thrift_tls_config");
313 if (!thriftConfigBase) {
314 return std::nullopt;
315 }
316 assert(static_cast<ThriftTlsConfig*>((*thriftConfigBase).get()));
317 auto thriftConfig = static_cast<ThriftTlsConfig*>((*thriftConfigBase).get());
318 if (!thriftConfig->enableThriftParamsNegotiation) {
319 return std::nullopt;
320 }
321
322 auto thriftParametersContext = ThriftParametersContext();
323 thriftParametersContext.setUseStopTLS(
324 thriftConfig->enableStopTLS || **ThriftServer::enableStopTLS());
325 return thriftParametersContext;
326 }
327
getHelper(const std::vector<uint8_t> & bytes,const folly::SocketAddress & clientAddr,std::chrono::steady_clock::time_point acceptTime,wangle::TransportInfo & ti)328 wangle::AcceptorHandshakeHelper::UniquePtr Cpp2Worker::getHelper(
329 const std::vector<uint8_t>& bytes,
330 const folly::SocketAddress& clientAddr,
331 std::chrono::steady_clock::time_point acceptTime,
332 wangle::TransportInfo& ti) {
333 if (!shouldPerformSSL(bytes, clientAddr)) {
334 return wangle::AcceptorHandshakeHelper::UniquePtr(
335 new wangle::UnencryptedAcceptorHandshakeHelper());
336 }
337 return createSSLHelper(bytes, clientAddr, acceptTime, ti);
338 }
339
requestStop()340 void Cpp2Worker::requestStop() {
341 getEventBase()->runInEventBaseThreadAndWait([&] {
342 if (isStopping()) {
343 return;
344 }
345 cancelQueuedRequests();
346 stopping_.store(true, std::memory_order_relaxed);
347 if (activeRequests_ == 0) {
348 stopBaton_.post();
349 }
350 });
351 }
352
waitForStop(std::chrono::steady_clock::time_point deadline)353 bool Cpp2Worker::waitForStop(std::chrono::steady_clock::time_point deadline) {
354 if (!stopBaton_.try_wait_until(deadline)) {
355 LOG(ERROR) << "Failed to join outstanding requests.";
356 return false;
357 }
358 return true;
359 }
360
cancelQueuedRequests()361 void Cpp2Worker::cancelQueuedRequests() {
362 auto eb = getEventBase();
363 eb->dcheckIsInEventBaseThread();
364 for (auto& stub : requestsRegistry_->getActive()) {
365 if (stub.stateMachine_.isActive() &&
366 stub.stateMachine_.tryStopProcessing()) {
367 stub.req_->sendQueueTimeoutResponse();
368 }
369 }
370 }
371
getActiveRequestsGuard()372 Cpp2Worker::ActiveRequestsGuard Cpp2Worker::getActiveRequestsGuard() {
373 DCHECK(!isStopping() || activeRequests_);
374 ++activeRequests_;
375 return Cpp2Worker::ActiveRequestsGuard(this);
376 }
377
378 Cpp2Worker::PerServiceMetadata::FindMethodResult
findMethod(std::string_view methodName) const379 Cpp2Worker::PerServiceMetadata::findMethod(std::string_view methodName) const {
380 if (const auto* map =
381 std::get_if<AsyncProcessorFactory::MethodMetadataMap>(&methods_)) {
382 if (auto* m = folly::get_ptr(*map, methodName)) {
383 DCHECK(m->get());
384 return MetadataFound{**m};
385 }
386 return MetadataNotFound{};
387 }
388 if (const auto* wildcard =
389 std::get_if<AsyncProcessorFactory::WildcardMethodMetadataMap>(
390 &methods_)) {
391 if (auto* m = folly::get_ptr(wildcard->knownMethods, methodName)) {
392 DCHECK(m->get());
393 return MetadataFound{**m};
394 }
395 return MetadataFound{AsyncProcessorFactory::kWildcardMethodMetadata};
396 }
397 if (std::holds_alternative<AsyncProcessorFactory::MetadataNotImplemented>(
398 methods_)) {
399 return MetadataNotImplemented{};
400 }
401
402 LOG(FATAL) << "Invalid CreateMethodMetadataResult from service";
403 folly::assume_unreachable();
404 }
405
406 std::shared_ptr<folly::RequestContext>
getBaseContextForRequest(const Cpp2Worker::PerServiceMetadata::FindMethodResult & findMethodResult) const407 Cpp2Worker::PerServiceMetadata::getBaseContextForRequest(
408 const Cpp2Worker::PerServiceMetadata::FindMethodResult& findMethodResult)
409 const {
410 using Result = std::shared_ptr<folly::RequestContext>;
411 if (const auto* found =
412 std::get_if<PerServiceMetadata::MetadataFound>(&findMethodResult)) {
413 return processorFactory_.getBaseContextForRequest(found->metadata);
414 }
415 return nullptr;
416 }
417
418 } // namespace thrift
419 } // namespace apache
420