1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef THRIFT_ASYNC_CPP2CONNCONTEXT_H_
18 #define THRIFT_ASYNC_CPP2CONNCONTEXT_H_ 1
19 
20 #include <memory>
21 #include <string_view>
22 
23 #include <folly/CancellationToken.h>
24 #include <folly/MapUtil.h>
25 #include <folly/Memory.h>
26 #include <folly/Optional.h>
27 #include <folly/SocketAddress.h>
28 #include <folly/io/async/AsyncSocket.h>
29 #include <folly/io/async/AsyncTransport.h>
30 #include <folly/io/async/ssl/OpenSSLTransportCertificate.h>
31 #include <thrift/lib/cpp/concurrency/ThreadManager.h>
32 #include <thrift/lib/cpp/server/TConnectionContext.h>
33 #include <thrift/lib/cpp/server/TServerObserver.h>
34 #include <thrift/lib/cpp/transport/THeader.h>
35 #include <thrift/lib/cpp2/async/Interaction.h>
36 #include <wangle/ssl/SSLUtil.h>
37 
38 using apache::thrift::concurrency::PriorityThreadManager;
39 
40 namespace apache {
41 namespace thrift {
42 
43 namespace rocket {
44 class ThriftRocketServerHandler;
45 }
46 
47 using ClientIdentityHook = std::function<std::unique_ptr<void, void (*)(void*)>(
48     const folly::AsyncTransport* transport,
49     X509* cert,
50     const folly::SocketAddress& peerAddress)>;
51 
52 class RequestChannel;
53 class TClientBase;
54 class Cpp2Worker;
55 
56 class ClientMetadataRef {
57  public:
ClientMetadataRef(const ClientMetadata & md)58   explicit ClientMetadataRef(const ClientMetadata& md) : md_(md) {}
59   std::optional<std::string_view> getAgent();
60   std::optional<std::string_view> getHostname();
61   std::optional<std::string_view> getOtherMetadataField(std::string_view key);
62 
63  private:
64   const ClientMetadata& md_;
65 };
66 
67 class Cpp2ConnContext : public apache::thrift::server::TConnectionContext {
68  public:
69   enum class TransportType {
70     HEADER,
71     ROCKET,
72     HTTP2,
73   };
74 
75   explicit Cpp2ConnContext(
76       const folly::SocketAddress* address = nullptr,
77       const folly::AsyncTransport* transport = nullptr,
78       folly::EventBaseManager* manager = nullptr,
79       const std::shared_ptr<RequestChannel>& duplexChannel = nullptr,
80       const std::shared_ptr<X509> peerCert = nullptr /*overridden from socket*/,
81       apache::thrift::ClientIdentityHook clientIdentityHook = nullptr,
82       const Cpp2Worker* worker = nullptr)
manager_(manager)83       : manager_(manager),
84         duplexChannel_(duplexChannel),
85         transport_(transport),
86         worker_(worker) {
87     if (address) {
88       peerAddress_ = *address;
89     }
90     X509* x509 = peerCert.get();
91     if (transport) {
92       // require worker to be passed when wrapping a real connection
93       DCHECK(worker != nullptr);
94       transport->getLocalAddress(&localAddress_);
95       auto cert = transport->getPeerCertificate();
96       if (cert) {
97         auto osslCert =
98             dynamic_cast<const folly::OpenSSLTransportCertificate*>(cert);
99         x509 = osslCert ? osslCert->getX509().get() : nullptr;
100       }
101       securityProtocol_ = transport->getSecurityProtocol();
102 
103       if (localAddress_.getFamily() == AF_UNIX) {
104         auto wrapper = transport->getUnderlyingTransport<folly::AsyncSocket>();
105         if (wrapper) {
106           peerCred_ = PeerCred::queryFromSocket(wrapper->getNetworkSocket());
107         }
108       }
109     }
110 
111     if (clientIdentityHook) {
112       peerIdentities_ = clientIdentityHook(transport, x509, peerAddress_);
113     }
114   }
115 
~Cpp2ConnContext()116   ~Cpp2ConnContext() override { DCHECK(tiles_.empty()); }
117   Cpp2ConnContext(Cpp2ConnContext&&) = default;
118   Cpp2ConnContext& operator=(Cpp2ConnContext&&) = default;
119 
reset()120   void reset() {
121     peerAddress_.reset();
122     localAddress_.reset();
123     userData_.reset();
124   }
125 
getPeerAddress()126   const folly::SocketAddress* getPeerAddress() const final {
127     return &peerAddress_;
128   }
129 
getLocalAddress()130   const folly::SocketAddress* getLocalAddress() const { return &localAddress_; }
131 
setLocalAddress(const folly::SocketAddress & localAddress)132   void setLocalAddress(const folly::SocketAddress& localAddress) {
133     localAddress_ = localAddress;
134   }
135 
setRequestHeader(apache::thrift::transport::THeader * header)136   void setRequestHeader(apache::thrift::transport::THeader* header) {
137     header_ = header;
138   }
139 
getEventBaseManager()140   folly::EventBaseManager* getEventBaseManager() override { return manager_; }
141 
getPeerCommonName()142   std::string getPeerCommonName() const {
143     if (!transport_) {
144       return "";
145     }
146     auto osslCert = dynamic_cast<const folly::OpenSSLTransportCertificate*>(
147         transport_->getPeerCertificate());
148     if (!osslCert) {
149       return "";
150     }
151     return folly::ssl::OpenSSLUtils::getCommonName(osslCert->getX509().get());
152   }
153 
154   template <typename Client>
getDuplexClient()155   std::shared_ptr<Client> getDuplexClient() {
156     DCHECK(duplexChannel_);
157     auto client = std::dynamic_pointer_cast<Client>(duplexClient_);
158     if (!client) {
159       duplexClient_.reset(new Client(duplexChannel_));
160       client = std::dynamic_pointer_cast<Client>(duplexClient_);
161     }
162     return client;
163   }
164 
getSecurityProtocol()165   virtual const std::string& getSecurityProtocol() const {
166     return securityProtocol_;
167   }
168 
getPeerIdentities()169   virtual void* getPeerIdentities() const { return peerIdentities_.get(); }
170 
getTransport()171   virtual const folly::AsyncTransport* getTransport() const {
172     return transport_;
173   }
174 
175   /**
176    * Get the user data field.
177    */
getUserData()178   void* getUserData() const override { return userData_.get(); }
179 
180   /**
181    * Set the user data field.
182    *
183    * @param data         The new value for the user data field.
184    *
185    * @return Returns the old user data value.
186    */
setUserData(folly::erased_unique_ptr data)187   void* setUserData(folly::erased_unique_ptr data) override {
188     auto oldData = userData_.release();
189     userData_ = std::move(data);
190     return oldData;
191   }
192   using TConnectionContext::setUserData;
193 
194 #ifndef _WIN32
195   struct PeerEffectiveCreds {
196     pid_t pid;
197     uid_t uid;
198     gid_t gid;
199   };
200 
201   /**
202    * Returns the connecting process ID, effective user ID, and effective user ID
203    * of the unix socket peer. The connection may have terminated since that
204    * time, so the PID may no longer map to a running process or the same process
205    * that initially connected. Returns nullopt for TCP, on Windows, and if there
206    * was an error retrieving the peer creds. In that case, call
207    * `getPeerCredError` for the reason.
208    *
209    * On macOS, the pid field contains the effective pid. On Linux, there is no
210    * distinction.
211    */
getPeerEffectiveCreds()212   folly::Optional<PeerEffectiveCreds> getPeerEffectiveCreds() const {
213     return peerCred_.getPeerEffectiveCreds();
214   }
215 #endif
216 
217   /**
218    * If the peer effective pid or uid are not available, it's possible
219    * retrieving the information failed. Produce an error message with the
220    * reason.
221    */
getPeerCredError()222   folly::Optional<std::string> getPeerCredError() const {
223     return peerCred_.getError();
224   }
225 
226   /**
227    * Retrieve a new folly::CancellationToken that will be signaled when the
228    * connection is closed.
229    */
getCancellationToken()230   folly::CancellationToken getCancellationToken() const {
231     return cancellationSource_.getToken();
232   }
233 
234   /**
235    * Signal that the connection has been closed.
236    *
237    * This is intended to be called by the thrift server implementation code.
238    *
239    * Note that this will cause any CancellationCallback functions that have been
240    * registered to run immediately in this thread.  If any of these callbacks
241    * throw this will cause program termination.
242    */
connectionClosed()243   void connectionClosed() { cancellationSource_.requestCancellation(); }
244 
getWorker()245   const Cpp2Worker* getWorker() const { return worker_; }
246 
getTransportType()247   std::optional<TransportType> getTransportType() const {
248     return transportType_;
249   }
250 
getClientType()251   std::optional<CLIENT_TYPE> getClientType() const { return clientType_; }
252 
getClientMetadataRef()253   std::optional<ClientMetadataRef> getClientMetadataRef() const {
254     if (!clientMetadata_) {
255       return {};
256     }
257     return ClientMetadataRef{*clientMetadata_};
258   }
259 
getInterfaceKind()260   InterfaceKind getInterfaceKind() const { return interfaceKind_; }
261 
262  private:
263   /**
264    * Adds interaction to interaction map
265    * Returns false and destroys tile if id is in use
266    */
addTile(int64_t id,TilePtr tile)267   bool addTile(int64_t id, TilePtr tile) {
268     return tiles_.try_emplace(id, std::move(tile)).second;
269   }
270   /**
271    * Removes interaction from map
272    * Returns old value
273    */
removeTile(int64_t id)274   TilePtr removeTile(int64_t id) {
275     auto it = tiles_.find(id);
276     if (it == tiles_.end()) {
277       return {};
278     }
279     auto ret = std::move(it->second);
280     tiles_.erase(it);
281     return ret;
282   }
283   /**
284    * Replaces interaction if id is present in map.
285    * Destroys passed-in tile otherwise.
286    */
tryReplaceTile(int64_t id,TilePtr tile)287   void tryReplaceTile(int64_t id, TilePtr tile) {
288     auto it = tiles_.find(id);
289     if (it != tiles_.end()) {
290       it->second = std::move(tile);
291     }
292   }
293   /**
294    * Gets tile from map
295    * Throws std::out_of_range if not found
296    */
getTile(int64_t id)297   Tile& getTile(int64_t id) { return *tiles_.at(id); }
298   friend class GeneratedAsyncProcessor;
299   friend class Tile;
300   friend class TilePromise;
301 
setTransportType(TransportType transportType)302   void setTransportType(TransportType transportType) {
303     transportType_ = transportType;
304   }
305 
setClientType(CLIENT_TYPE clientType)306   void setClientType(CLIENT_TYPE clientType) { clientType_ = clientType; }
307 
readSetupMetadata(const RequestSetupMetadata & meta)308   void readSetupMetadata(const RequestSetupMetadata& meta) {
309     if (const auto& md = meta.clientMetadata()) {
310       setClientMetadata(*md);
311     }
312     if (auto interfaceKind = meta.interfaceKind()) {
313       interfaceKind_ = *interfaceKind;
314     }
315   }
316 
setClientMetadata(const ClientMetadata & md)317   void setClientMetadata(const ClientMetadata& md) { clientMetadata_ = md; }
318 
319   friend class Cpp2Connection;
320   friend class rocket::ThriftRocketServerHandler;
321   friend class HTTP2RoutingHandler;
322 
323   /**
324    * Platform-independent representation of unix domain socket peer credentials,
325    * e.g. ucred on Linux and xucred on macOS.
326    *
327    * Null implementation on Windows.
328    */
329   class PeerCred {
330    public:
331 #ifndef _WIN32
332     using StatusOrPid = pid_t;
333 #else
334     // Even on Windows, differentiate between not initialized (not unix
335     // domain socket), and unsupported platform.
336     using StatusOrPid = int;
337 #endif
338 
339     /**
340      * pid_t is guaranteed to be signed, so reserve non-positive values as
341      * sentinels that indicate credential validity.
342      * While negative pid_t values are possible, they are used to refer
343      * to process groups and thus cannot occur in a process identifier.
344      * Linux and macOS allow user IDs to span the entire range of a uint32_t,
345      * so sentinal values must be stored in pid_t.
346      */
347     enum Validity : StatusOrPid {
348       NotInitialized = -1,
349       ErrorRetrieving = -2,
350       UnsupportedPlatform = -3,
351     };
352 
353     PeerCred() = default;
354     PeerCred(const PeerCred&) = default;
355     PeerCred& operator=(const PeerCred&) = default;
356 
357     /**
358      * Query a socket for peer credentials.
359      */
360     static PeerCred queryFromSocket(folly::NetworkSocket socket);
361 
362 #ifndef _WIN32
getPeerEffectiveCreds()363     folly::Optional<PeerEffectiveCreds> getPeerEffectiveCreds() const {
364       return hasCredentials()
365           ? folly::make_optional(PeerEffectiveCreds{pid_, uid_, gid_})
366           : folly::none;
367     }
368 #endif
369 
370     /**
371      * If retrieving the effective credentials failed, return a string
372      * containing the reason.
373      */
374     folly::Optional<std::string> getError() const;
375 
376    private:
PeerCred(Validity validity)377     explicit PeerCred(Validity validity) : pid_{validity} {}
378 
379 #ifndef _WIN32
PeerCred(pid_t pid,uid_t uid,gid_t gid)380     explicit PeerCred(pid_t pid, uid_t uid, gid_t gid)
381         : pid_{pid}, uid_{uid}, gid_{gid} {}
382 #endif
383 
hasCredentials()384     bool hasCredentials() const { return pid_ >= 0; }
385 
386     StatusOrPid pid_ = Validity::NotInitialized;
387 #ifndef _WIN32
388     uid_t uid_ = 0;
389     gid_t gid_ = 0;
390 #endif
391   };
392 
393   folly::erased_unique_ptr userData_{folly::empty_erased_unique_ptr()};
394   folly::SocketAddress peerAddress_;
395   folly::SocketAddress localAddress_;
396   folly::EventBaseManager* manager_;
397   std::shared_ptr<RequestChannel> duplexChannel_;
398   std::shared_ptr<TClientBase> duplexClient_;
399   folly::erased_unique_ptr peerIdentities_{folly::empty_erased_unique_ptr()};
400   std::string securityProtocol_;
401   const folly::AsyncTransport* transport_;
402   PeerCred peerCred_;
403   // A CancellationSource that will be signaled when the connection is closed.
404   folly::CancellationSource cancellationSource_;
405   folly::F14FastMap<int64_t, TilePtr> tiles_;
406   const Cpp2Worker* worker_;
407   InterfaceKind interfaceKind_{InterfaceKind::USER};
408   std::optional<TransportType> transportType_;
409   std::optional<CLIENT_TYPE> clientType_;
410   std::optional<ClientMetadata> clientMetadata_;
411 };
412 
413 class Cpp2ClientRequestContext
414     : public apache::thrift::server::TConnectionContext {
415  public:
Cpp2ClientRequestContext(transport::THeader * header)416   explicit Cpp2ClientRequestContext(transport::THeader* header)
417       : TConnectionContext(header) {}
418 
setRequestHeader(transport::THeader * header)419   void setRequestHeader(transport::THeader* header) { header_ = header; }
420 };
421 
422 // Request-specific context
423 class Cpp2RequestContext : public apache::thrift::server::TConnectionContext {
424  public:
425   explicit Cpp2RequestContext(
426       Cpp2ConnContext* ctx,
427       apache::thrift::transport::THeader* header = nullptr,
428       std::string methodName = std::string{})
TConnectionContext(header)429       : TConnectionContext(header),
430         ctx_(ctx),
431         methodName_(std::move(methodName)) {}
432 
setConnectionContext(Cpp2ConnContext * ctx)433   void setConnectionContext(Cpp2ConnContext* ctx) { ctx_ = ctx; }
434 
435   // Forward all connection-specific information
getPeerAddress()436   const folly::SocketAddress* getPeerAddress() const override {
437     return ctx_->getPeerAddress();
438   }
439 
getLocalAddress()440   const folly::SocketAddress* getLocalAddress() const {
441     return ctx_->getLocalAddress();
442   }
443 
reset()444   void reset() { ctx_->reset(); }
445 
getCallPriority()446   concurrency::PRIORITY getCallPriority() const {
447     return header_->getCallPriority();
448   }
449 
getRequestExecutionScope()450   concurrency::ThreadManager::ExecutionScope getRequestExecutionScope() const {
451     return executionScope_;
452   }
453 
setRequestExecutionScope(concurrency::ThreadManager::ExecutionScope scope)454   void setRequestExecutionScope(
455       concurrency::ThreadManager::ExecutionScope scope) {
456     executionScope_ = std::move(scope);
457   }
458 
getTransforms()459   virtual std::vector<uint16_t>& getTransforms() {
460     return header_->getWriteTransforms();
461   }
462 
getEventBaseManager()463   folly::EventBaseManager* getEventBaseManager() override {
464     return ctx_->getEventBaseManager();
465   }
466 
getUserData()467   void* getUserData() const override { return ctx_->getUserData(); }
468 
setUserData(folly::erased_unique_ptr data)469   void* setUserData(folly::erased_unique_ptr data) override {
470     return ctx_->setUserData(std::move(data));
471   }
472   using TConnectionContext::setUserData;
473 
474   // This data is set on a per request basis.
getRequestData()475   void* getRequestData() const { return requestData_.get(); }
476 
477   // Returns the old request data context so the caller can clean up
478   folly::erased_unique_ptr setRequestData(
479       void* data, void (*destructor)(void*) = no_op_destructor) {
480     return std::exchange(requestData_, {data, destructor});
481   }
setRequestData(folly::erased_unique_ptr data)482   folly::erased_unique_ptr setRequestData(folly::erased_unique_ptr data) {
483     return std::exchange(requestData_, std::move(data));
484   }
485 
getConnectionContext()486   virtual Cpp2ConnContext* getConnectionContext() const { return ctx_; }
487 
getRequestTimeout()488   std::chrono::milliseconds getRequestTimeout() const {
489     return requestTimeout_;
490   }
491 
setRequestTimeout(std::chrono::milliseconds requestTimeout)492   void setRequestTimeout(std::chrono::milliseconds requestTimeout) {
493     requestTimeout_ = requestTimeout;
494   }
495 
getMethodName()496   const std::string& getMethodName() const { return methodName_; }
497 
releaseMethodName()498   std::string releaseMethodName() { return std::move(methodName_); }
499 
setProtoSeqId(int32_t protoSeqId)500   void setProtoSeqId(int32_t protoSeqId) { protoSeqId_ = protoSeqId; }
501 
getProtoSeqId()502   int32_t getProtoSeqId() { return protoSeqId_; }
503 
setInteractionId(int64_t id)504   void setInteractionId(int64_t id) { interactionId_ = id; }
505 
getInteractionId()506   int64_t getInteractionId() { return interactionId_; }
507 
setInteractionCreate(InteractionCreate interactionCreate)508   void setInteractionCreate(InteractionCreate interactionCreate) {
509     interactionCreate_ = std::move(interactionCreate);
510   }
511 
getInteractionCreate()512   folly::Optional<InteractionCreate>& getInteractionCreate() {
513     return interactionCreate_;
514   }
515 
getTimestamps()516   server::TServerObserver::PreHandlerTimestamps& getTimestamps() {
517     return timestamps_;
518   }
519 
520   // This lets us avoid having different signatures in the processor map.
521   // Should remove if we decide to split out interaction methods.
setTile(TilePtr && tile)522   void setTile(TilePtr&& tile) {
523     DCHECK(tile);
524     tile_ = std::move(tile);
525   }
releaseTile()526   TilePtr releaseTile() { return std::move(tile_); }
527 
clientId()528   const std::string* clientId() const {
529     if (auto header = getHeader(); header && header->clientId()) {
530       return &*header->clientId();
531     }
532     if (auto headers = getHeadersPtr()) {
533       return folly::get_ptr(*headers, transport::THeader::kClientId);
534     }
535     return nullptr;
536   }
537 
538  protected:
539   apache::thrift::server::TServerObserver::CallTimestamps timestamps_;
540 
541  private:
542   Cpp2ConnContext* ctx_;
543   folly::erased_unique_ptr requestData_{nullptr, nullptr};
544   std::chrono::milliseconds requestTimeout_{0};
545   std::string methodName_;
546   int32_t protoSeqId_{0};
547   int64_t interactionId_{0};
548   folly::Optional<InteractionCreate> interactionCreate_;
549   TilePtr tile_;
550   concurrency::ThreadManager::ExecutionScope executionScope_{
551       concurrency::PRIORITY::NORMAL};
552 };
553 
554 } // namespace thrift
555 } // namespace apache
556 
557 #endif // #ifndef THRIFT_ASYNC_CPP2CONNCONTEXT_H_
558