1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <thrift/lib/cpp2/server/Cpp2Worker.h>
18 
19 #include <vector>
20 
21 #include <glog/logging.h>
22 
23 #include <folly/Overload.h>
24 #include <folly/String.h>
25 #include <folly/io/async/AsyncSSLSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/EventBaseLocal.h>
28 #include <folly/portability/Sockets.h>
29 #include <thrift/lib/cpp/async/TAsyncSSLSocket.h>
30 #include <thrift/lib/cpp/concurrency/Util.h>
31 #include <thrift/lib/cpp2/async/ResponseChannel.h>
32 #include <thrift/lib/cpp2/security/extensions/ThriftParametersContext.h>
33 #include <thrift/lib/cpp2/server/Cpp2Connection.h>
34 #include <thrift/lib/cpp2/server/LoggingEvent.h>
35 #include <thrift/lib/cpp2/server/ThriftServer.h>
36 #include <thrift/lib/cpp2/server/peeking/PeekingManager.h>
37 #include <thrift/lib/thrift/gen-cpp2/RpcMetadata_types.h>
38 #include <wangle/acceptor/EvbHandshakeHelper.h>
39 #include <wangle/acceptor/SSLAcceptorHandshakeHelper.h>
40 #include <wangle/acceptor/UnencryptedAcceptorHandshakeHelper.h>
41 
42 namespace apache {
43 namespace thrift {
44 
45 using namespace apache::thrift::server;
46 using namespace apache::thrift::transport;
47 using namespace apache::thrift::async;
48 using apache::thrift::concurrency::Util;
49 using std::shared_ptr;
50 
namespace {
// Process-wide storage for per-EventBase RequestsRegistry instances.
// Cpp2Worker::initRequestsRegistry() lazily creates (try_emplace) the entry
// for a worker's EventBase and caches a raw pointer to it in
// requestsRegistry_.
folly::LeakySingleton<folly::EventBaseLocal<RequestsRegistry>> registry;
} // namespace
54 
initRequestsRegistry()55 void Cpp2Worker::initRequestsRegistry() {
56   auto* evb = getEventBase();
57   auto memPerReq = server_->getMaxDebugPayloadMemoryPerRequest();
58   auto memPerWorker = server_->getMaxDebugPayloadMemoryPerWorker();
59   auto maxFinished = server_->getMaxFinishedDebugPayloadsPerWorker();
60   std::weak_ptr<Cpp2Worker> self_weak = shared_from_this();
61   evb->runInEventBaseThread([=, self_weak = std::move(self_weak)]() {
62     if (auto self = self_weak.lock()) {
63       self->requestsRegistry_ = &registry.get().try_emplace(
64           *evb, memPerReq, memPerWorker, maxFinished);
65     }
66   });
67 }
68 
onNewConnection(folly::AsyncTransport::UniquePtr sock,const folly::SocketAddress * addr,const std::string & nextProtocolName,wangle::SecureTransportType secureTransportType,const wangle::TransportInfo & tinfo)69 void Cpp2Worker::onNewConnection(
70     folly::AsyncTransport::UniquePtr sock,
71     const folly::SocketAddress* addr,
72     const std::string& nextProtocolName,
73     wangle::SecureTransportType secureTransportType,
74     const wangle::TransportInfo& tinfo) {
75   // This is possible if the connection was accepted before stopListening()
76   // call, but handshake was finished after stopCPUWorkers() call.
77   if (stopping_) {
78     return;
79   }
80 
81   auto* observer = server_->getObserver();
82   uint32_t maxConnection = server_->getMaxConnections();
83   if (maxConnection > 0 &&
84       (getConnectionManager()->getNumConnections() >=
85        maxConnection / server_->getNumIOWorkerThreads())) {
86     if (observer) {
87       observer->connDropped();
88       observer->connRejected();
89     }
90     return;
91   }
92 
93   const auto& func = server_->getZeroCopyEnableFunc();
94   if (func && sock) {
95     sock->setZeroCopy(true);
96     sock->setZeroCopyEnableFunc(func);
97   }
98 
99   // Check the security protocol
100   switch (secureTransportType) {
101     // If no security, peek into the socket to determine type
102     case wangle::SecureTransportType::NONE: {
103       new TransportPeekingManager(
104           shared_from_this(), *addr, tinfo, server_, std::move(sock));
105       break;
106     }
107     case wangle::SecureTransportType::TLS:
108       // Use the announced protocol to determine the correct handler
109       if (!nextProtocolName.empty()) {
110         for (auto& routingHandler : *server_->getRoutingHandlers()) {
111           if (routingHandler->canAcceptEncryptedConnection(nextProtocolName)) {
112             VLOG(4) << "Cpp2Worker: Routing encrypted connection for protocol "
113                     << nextProtocolName;
114             routingHandler->handleConnection(
115                 getConnectionManager(),
116                 std::move(sock),
117                 addr,
118                 tinfo,
119                 shared_from_this());
120             return;
121           }
122         }
123       }
124       if (!getServer()->isDuplex()) {
125         new TransportPeekingManager(
126             shared_from_this(), *addr, tinfo, server_, std::move(sock));
127       } else {
128         handleHeader(std::move(sock), addr, tinfo);
129       }
130       break;
131     default:
132       LOG(ERROR) << "Unsupported Secure Transport Type";
133       break;
134   }
135 }
136 
handleHeader(folly::AsyncTransport::UniquePtr sock,const folly::SocketAddress * addr,const wangle::TransportInfo & tinfo)137 void Cpp2Worker::handleHeader(
138     folly::AsyncTransport::UniquePtr sock,
139     const folly::SocketAddress* addr,
140     const wangle::TransportInfo& tinfo) {
141   auto fd = sock->getUnderlyingTransport<folly::AsyncSocket>()
142                 ->getNetworkSocket()
143                 .toFd();
144   VLOG(4) << "Cpp2Worker: Creating connection for socket " << fd;
145 
146   auto thriftTransport = createThriftTransport(std::move(sock));
147   auto connection = std::make_shared<Cpp2Connection>(
148       std::move(thriftTransport), addr, shared_from_this(), nullptr);
149   Acceptor::addConnection(connection.get());
150   connection->addConnection(connection);
151   connection->start();
152 
153   VLOG(4) << "Cpp2Worker: created connection for socket " << fd;
154 
155   auto observer = server_->getObserver();
156   if (observer) {
157     observer->connAccepted(tinfo);
158     observer->activeConnections(
159         getConnectionManager()->getNumConnections() *
160         server_->getNumIOWorkerThreads());
161   }
162 }
163 
createThriftTransport(folly::AsyncTransport::UniquePtr sock)164 std::shared_ptr<folly::AsyncTransport> Cpp2Worker::createThriftTransport(
165     folly::AsyncTransport::UniquePtr sock) {
166   auto fizzServer = dynamic_cast<fizz::server::AsyncFizzServer*>(sock.get());
167   if (fizzServer) {
168     auto asyncSock = sock->getUnderlyingTransport<folly::AsyncSocket>();
169     if (asyncSock) {
170       markSocketAccepted(asyncSock);
171     }
172     // give up ownership
173     sock.release();
174     return std::shared_ptr<fizz::server::AsyncFizzServer>(
175         fizzServer, fizz::server::AsyncFizzServer::Destructor());
176   }
177 
178   folly::AsyncSocket* tsock =
179       sock->getUnderlyingTransport<folly::AsyncSocket>();
180   CHECK(tsock);
181   markSocketAccepted(tsock);
182   // use custom deleter for std::shared_ptr<folly::AsyncTransport> to allow
183   // socket transfer from header to rocket (if enabled by ThriftFlags)
184   return apache::thrift::transport::detail::convertToShared(std::move(sock));
185 }
186 
markSocketAccepted(folly::AsyncSocket * sock)187 void Cpp2Worker::markSocketAccepted(folly::AsyncSocket* sock) {
188   sock->setShutdownSocketSet(server_->wShutdownSocketSet_);
189 }
190 
plaintextConnectionReady(folly::AsyncSocket::UniquePtr sock,const folly::SocketAddress & clientAddr,wangle::TransportInfo & tinfo)191 void Cpp2Worker::plaintextConnectionReady(
192     folly::AsyncSocket::UniquePtr sock,
193     const folly::SocketAddress& clientAddr,
194     wangle::TransportInfo& tinfo) {
195   sock->setShutdownSocketSet(server_->wShutdownSocketSet_);
196   new CheckTLSPeekingManager(
197       shared_from_this(),
198       clientAddr,
199       tinfo,
200       server_,
201       std::move(sock),
202       server_->getObserverShared());
203 }
204 
useExistingChannel(const std::shared_ptr<HeaderServerChannel> & serverChannel)205 void Cpp2Worker::useExistingChannel(
206     const std::shared_ptr<HeaderServerChannel>& serverChannel) {
207   folly::SocketAddress address;
208 
209   auto conn = std::make_shared<Cpp2Connection>(
210       nullptr, &address, shared_from_this(), serverChannel);
211   Acceptor::getConnectionManager()->addConnection(conn.get(), false);
212   conn->addConnection(conn);
213 
214   conn->start();
215 }
216 
stopDuplex(std::shared_ptr<ThriftServer> myServer)217 void Cpp2Worker::stopDuplex(std::shared_ptr<ThriftServer> myServer) {
218   // They better have given us the correct ThriftServer
219   DCHECK(server_ == myServer.get());
220 
221   // This does not really fully drain everything but at least
222   // prevents the connections from accepting new requests
223   wangle::Acceptor::drainAllConnections();
224 
225   // Capture a shared_ptr to our ThriftServer making sure it will outlive us
226   // Otherwise our raw pointer to it (server_) will be jeopardized.
227   duplexServer_ = myServer;
228 }
229 
updateSSLStats(const folly::AsyncTransport * sock,std::chrono::milliseconds,wangle::SSLErrorEnum error,const folly::exception_wrapper &)230 void Cpp2Worker::updateSSLStats(
231     const folly::AsyncTransport* sock,
232     std::chrono::milliseconds /* acceptLatency */,
233     wangle::SSLErrorEnum error,
234     const folly::exception_wrapper& /*ex*/) noexcept {
235   if (!sock) {
236     return;
237   }
238 
239   auto observer = getServer()->getObserver();
240   if (!observer) {
241     return;
242   }
243 
244   auto fizz = sock->getUnderlyingTransport<fizz::server::AsyncFizzServer>();
245   if (fizz) {
246     if (sock->good() && error == wangle::SSLErrorEnum::NO_ERROR) {
247       observer->tlsComplete();
248       auto pskType = fizz->getState().pskType();
249       if (pskType && *pskType == fizz::PskType::Resumption) {
250         observer->tlsResumption();
251       }
252       if (fizz->getPeerCertificate()) {
253         observer->tlsWithClientCert();
254       }
255     } else {
256       observer->tlsError();
257     }
258   } else {
259     auto socket = sock->getUnderlyingTransport<folly::AsyncSSLSocket>();
260     if (!socket) {
261       return;
262     }
263     if (socket->good() && error == wangle::SSLErrorEnum::NO_ERROR) {
264       observer->tlsComplete();
265       if (socket->getSSLSessionReused()) {
266         observer->tlsResumption();
267       }
268       if (socket->getPeerCertificate()) {
269         observer->tlsWithClientCert();
270       }
271     } else {
272       observer->tlsError();
273     }
274   }
275 }
276 
createSSLHelper(const std::vector<uint8_t> & bytes,const folly::SocketAddress & clientAddr,std::chrono::steady_clock::time_point acceptTime,wangle::TransportInfo & tInfo)277 wangle::AcceptorHandshakeHelper::UniquePtr Cpp2Worker::createSSLHelper(
278     const std::vector<uint8_t>& bytes,
279     const folly::SocketAddress& clientAddr,
280     std::chrono::steady_clock::time_point acceptTime,
281     wangle::TransportInfo& tInfo) {
282   if (accConfig_.fizzConfig.enableFizz) {
283     if (auto parametersContext = getThriftParametersContext()) {
284       fizzPeeker_.setThriftParametersContext(
285           folly::copy_to_shared_ptr(*parametersContext));
286     }
287     return getFizzPeeker()->getHelper(bytes, clientAddr, acceptTime, tInfo);
288   }
289   return defaultPeekingCallback_.getHelper(
290       bytes, clientAddr, acceptTime, tInfo);
291 }
292 
shouldPerformSSL(const std::vector<uint8_t> & bytes,const folly::SocketAddress & clientAddr)293 bool Cpp2Worker::shouldPerformSSL(
294     const std::vector<uint8_t>& bytes, const folly::SocketAddress& clientAddr) {
295   auto sslPolicy = getSSLPolicy();
296   if (sslPolicy == SSLPolicy::REQUIRED) {
297     if (isPlaintextAllowedOnLoopback()) {
298       // loopback clients may still be sending TLS so we need to ensure that
299       // it doesn't appear that way in addition to verifying it's loopback.
300       return !(
301           clientAddr.isLoopbackAddress() && !TLSHelper::looksLikeTLS(bytes));
302     }
303     return true;
304   } else {
305     return sslPolicy != SSLPolicy::DISABLED && TLSHelper::looksLikeTLS(bytes);
306   }
307 }
308 
309 std::optional<ThriftParametersContext>
getThriftParametersContext()310 Cpp2Worker::getThriftParametersContext() {
311   auto thriftConfigBase =
312       folly::get_ptr(accConfig_.customConfigMap, "thrift_tls_config");
313   if (!thriftConfigBase) {
314     return std::nullopt;
315   }
316   assert(static_cast<ThriftTlsConfig*>((*thriftConfigBase).get()));
317   auto thriftConfig = static_cast<ThriftTlsConfig*>((*thriftConfigBase).get());
318   if (!thriftConfig->enableThriftParamsNegotiation) {
319     return std::nullopt;
320   }
321 
322   auto thriftParametersContext = ThriftParametersContext();
323   thriftParametersContext.setUseStopTLS(
324       thriftConfig->enableStopTLS || **ThriftServer::enableStopTLS());
325   return thriftParametersContext;
326 }
327 
getHelper(const std::vector<uint8_t> & bytes,const folly::SocketAddress & clientAddr,std::chrono::steady_clock::time_point acceptTime,wangle::TransportInfo & ti)328 wangle::AcceptorHandshakeHelper::UniquePtr Cpp2Worker::getHelper(
329     const std::vector<uint8_t>& bytes,
330     const folly::SocketAddress& clientAddr,
331     std::chrono::steady_clock::time_point acceptTime,
332     wangle::TransportInfo& ti) {
333   if (!shouldPerformSSL(bytes, clientAddr)) {
334     return wangle::AcceptorHandshakeHelper::UniquePtr(
335         new wangle::UnencryptedAcceptorHandshakeHelper());
336   }
337   return createSSLHelper(bytes, clientAddr, acceptTime, ti);
338 }
339 
requestStop()340 void Cpp2Worker::requestStop() {
341   getEventBase()->runInEventBaseThreadAndWait([&] {
342     if (isStopping()) {
343       return;
344     }
345     cancelQueuedRequests();
346     stopping_.store(true, std::memory_order_relaxed);
347     if (activeRequests_ == 0) {
348       stopBaton_.post();
349     }
350   });
351 }
352 
waitForStop(std::chrono::steady_clock::time_point deadline)353 bool Cpp2Worker::waitForStop(std::chrono::steady_clock::time_point deadline) {
354   if (!stopBaton_.try_wait_until(deadline)) {
355     LOG(ERROR) << "Failed to join outstanding requests.";
356     return false;
357   }
358   return true;
359 }
360 
cancelQueuedRequests()361 void Cpp2Worker::cancelQueuedRequests() {
362   auto eb = getEventBase();
363   eb->dcheckIsInEventBaseThread();
364   for (auto& stub : requestsRegistry_->getActive()) {
365     if (stub.stateMachine_.isActive() &&
366         stub.stateMachine_.tryStopProcessing()) {
367       stub.req_->sendQueueTimeoutResponse();
368     }
369   }
370 }
371 
getActiveRequestsGuard()372 Cpp2Worker::ActiveRequestsGuard Cpp2Worker::getActiveRequestsGuard() {
373   DCHECK(!isStopping() || activeRequests_);
374   ++activeRequests_;
375   return Cpp2Worker::ActiveRequestsGuard(this);
376 }
377 
378 Cpp2Worker::PerServiceMetadata::FindMethodResult
findMethod(std::string_view methodName) const379 Cpp2Worker::PerServiceMetadata::findMethod(std::string_view methodName) const {
380   if (const auto* map =
381           std::get_if<AsyncProcessorFactory::MethodMetadataMap>(&methods_)) {
382     if (auto* m = folly::get_ptr(*map, methodName)) {
383       DCHECK(m->get());
384       return MetadataFound{**m};
385     }
386     return MetadataNotFound{};
387   }
388   if (const auto* wildcard =
389           std::get_if<AsyncProcessorFactory::WildcardMethodMetadataMap>(
390               &methods_)) {
391     if (auto* m = folly::get_ptr(wildcard->knownMethods, methodName)) {
392       DCHECK(m->get());
393       return MetadataFound{**m};
394     }
395     return MetadataFound{AsyncProcessorFactory::kWildcardMethodMetadata};
396   }
397   if (std::holds_alternative<AsyncProcessorFactory::MetadataNotImplemented>(
398           methods_)) {
399     return MetadataNotImplemented{};
400   }
401 
402   LOG(FATAL) << "Invalid CreateMethodMetadataResult from service";
403   folly::assume_unreachable();
404 }
405 
406 std::shared_ptr<folly::RequestContext>
getBaseContextForRequest(const Cpp2Worker::PerServiceMetadata::FindMethodResult & findMethodResult) const407 Cpp2Worker::PerServiceMetadata::getBaseContextForRequest(
408     const Cpp2Worker::PerServiceMetadata::FindMethodResult& findMethodResult)
409     const {
410   using Result = std::shared_ptr<folly::RequestContext>;
411   if (const auto* found =
412           std::get_if<PerServiceMetadata::MetadataFound>(&findMethodResult)) {
413     return processorFactory_.getBaseContextForRequest(found->metadata);
414   }
415   return nullptr;
416 }
417 
418 } // namespace thrift
419 } // namespace apache
420