1 /* 2 * Copyright (c) Facebook, Inc. and its affiliates. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef THRIFT_ASYNC_CPP2CONNCONTEXT_H_ 18 #define THRIFT_ASYNC_CPP2CONNCONTEXT_H_ 1 19 20 #include <memory> 21 #include <string_view> 22 23 #include <folly/CancellationToken.h> 24 #include <folly/MapUtil.h> 25 #include <folly/Memory.h> 26 #include <folly/Optional.h> 27 #include <folly/SocketAddress.h> 28 #include <folly/io/async/AsyncSocket.h> 29 #include <folly/io/async/AsyncTransport.h> 30 #include <folly/io/async/ssl/OpenSSLTransportCertificate.h> 31 #include <thrift/lib/cpp/concurrency/ThreadManager.h> 32 #include <thrift/lib/cpp/server/TConnectionContext.h> 33 #include <thrift/lib/cpp/server/TServerObserver.h> 34 #include <thrift/lib/cpp/transport/THeader.h> 35 #include <thrift/lib/cpp2/async/Interaction.h> 36 #include <wangle/ssl/SSLUtil.h> 37 38 using apache::thrift::concurrency::PriorityThreadManager; 39 40 namespace apache { 41 namespace thrift { 42 43 namespace rocket { 44 class ThriftRocketServerHandler; 45 } 46 47 using ClientIdentityHook = std::function<std::unique_ptr<void, void (*)(void*)>( 48 const folly::AsyncTransport* transport, 49 X509* cert, 50 const folly::SocketAddress& peerAddress)>; 51 52 class RequestChannel; 53 class TClientBase; 54 class Cpp2Worker; 55 56 class ClientMetadataRef { 57 public: ClientMetadataRef(const ClientMetadata & md)58 explicit ClientMetadataRef(const ClientMetadata& md) : md_(md) {} 59 std::optional<std::string_view> getAgent(); 60 std::optional<std::string_view> getHostname(); 61 std::optional<std::string_view> getOtherMetadataField(std::string_view key); 62 63 private: 64 const ClientMetadata& md_; 65 }; 66 67 class Cpp2ConnContext : public apache::thrift::server::TConnectionContext { 68 public: 69 enum class TransportType { 70 HEADER, 71 ROCKET, 72 HTTP2, 73 }; 74 75 explicit Cpp2ConnContext( 76 const folly::SocketAddress* address = nullptr, 77 const folly::AsyncTransport* transport = nullptr, 78 folly::EventBaseManager* manager = nullptr, 79 const std::shared_ptr<RequestChannel>& duplexChannel = nullptr, 80 const std::shared_ptr<X509> peerCert = nullptr /*overridden from socket*/, 81 apache::thrift::ClientIdentityHook clientIdentityHook = nullptr, 82 const Cpp2Worker* worker = nullptr) manager_(manager)83 : manager_(manager), 84 duplexChannel_(duplexChannel), 85 transport_(transport), 86 worker_(worker) { 87 if (address) { 88 peerAddress_ = *address; 89 } 90 X509* x509 = peerCert.get(); 91 if (transport) { 92 // require worker to be passed when wrapping a real connection 93 DCHECK(worker != nullptr); 94 transport->getLocalAddress(&localAddress_); 95 auto cert = transport->getPeerCertificate(); 96 if (cert) { 97 auto osslCert = 98 dynamic_cast<const folly::OpenSSLTransportCertificate*>(cert); 99 x509 = osslCert ? osslCert->getX509().get() : nullptr; 100 } 101 securityProtocol_ = transport->getSecurityProtocol(); 102 103 if (localAddress_.getFamily() == AF_UNIX) { 104 auto wrapper = transport->getUnderlyingTransport<folly::AsyncSocket>(); 105 if (wrapper) { 106 peerCred_ = PeerCred::queryFromSocket(wrapper->getNetworkSocket()); 107 } 108 } 109 } 110 111 if (clientIdentityHook) { 112 peerIdentities_ = clientIdentityHook(transport, x509, peerAddress_); 113 } 114 } 115 ~Cpp2ConnContext()116 ~Cpp2ConnContext() override { DCHECK(tiles_.empty()); } 117 Cpp2ConnContext(Cpp2ConnContext&&) = default; 118 Cpp2ConnContext& operator=(Cpp2ConnContext&&) = default; 119 reset()120 void reset() { 121 peerAddress_.reset(); 122 localAddress_.reset(); 123 userData_.reset(); 124 } 125 getPeerAddress()126 const folly::SocketAddress* getPeerAddress() const final { 127 return &peerAddress_; 128 } 129 getLocalAddress()130 const folly::SocketAddress* getLocalAddress() const { return &localAddress_; } 131 setLocalAddress(const folly::SocketAddress & localAddress)132 void setLocalAddress(const folly::SocketAddress& localAddress) { 133 localAddress_ = localAddress; 134 } 135 setRequestHeader(apache::thrift::transport::THeader * header)136 void setRequestHeader(apache::thrift::transport::THeader* header) { 137 header_ = header; 138 } 139 getEventBaseManager()140 folly::EventBaseManager* getEventBaseManager() override { return manager_; } 141 getPeerCommonName()142 std::string getPeerCommonName() const { 143 if (!transport_) { 144 return ""; 145 } 146 auto osslCert = dynamic_cast<const folly::OpenSSLTransportCertificate*>( 147 transport_->getPeerCertificate()); 148 if (!osslCert) { 149 return ""; 150 } 151 return folly::ssl::OpenSSLUtils::getCommonName(osslCert->getX509().get()); 152 } 153 154 template <typename Client> getDuplexClient()155 std::shared_ptr<Client> getDuplexClient() { 156 DCHECK(duplexChannel_); 157 auto client = std::dynamic_pointer_cast<Client>(duplexClient_); 158 if (!client) { 159 duplexClient_.reset(new Client(duplexChannel_)); 160 client = std::dynamic_pointer_cast<Client>(duplexClient_); 161 } 162 return client; 163 } 164 getSecurityProtocol()165 virtual const std::string& getSecurityProtocol() const { 166 return securityProtocol_; 167 } 168 getPeerIdentities()169 virtual void* getPeerIdentities() const { return peerIdentities_.get(); } 170 getTransport()171 virtual const folly::AsyncTransport* getTransport() const { 172 return transport_; 173 } 174 175 /** 176 * Get the user data field. 177 */ getUserData()178 void* getUserData() const override { return userData_.get(); } 179 180 /** 181 * Set the user data field. 182 * 183 * @param data The new value for the user data field. 184 * 185 * @return Returns the old user data value. 186 */ setUserData(folly::erased_unique_ptr data)187 void* setUserData(folly::erased_unique_ptr data) override { 188 auto oldData = userData_.release(); 189 userData_ = std::move(data); 190 return oldData; 191 } 192 using TConnectionContext::setUserData; 193 194 #ifndef _WIN32 195 struct PeerEffectiveCreds { 196 pid_t pid; 197 uid_t uid; 198 gid_t gid; 199 }; 200 201 /** 202 * Returns the connecting process ID, effective user ID, and effective user ID 203 * of the unix socket peer. The connection may have terminated since that 204 * time, so the PID may no longer map to a running process or the same process 205 * that initially connected. Returns nullopt for TCP, on Windows, and if there 206 * was an error retrieving the peer creds. In that case, call 207 * `getPeerCredError` for the reason. 208 * 209 * On macOS, the pid field contains the effective pid. On Linux, there is no 210 * distinction. 211 */ getPeerEffectiveCreds()212 folly::Optional<PeerEffectiveCreds> getPeerEffectiveCreds() const { 213 return peerCred_.getPeerEffectiveCreds(); 214 } 215 #endif 216 217 /** 218 * If the peer effective pid or uid are not available, it's possible 219 * retrieving the information failed. Produce an error message with the 220 * reason. 221 */ getPeerCredError()222 folly::Optional<std::string> getPeerCredError() const { 223 return peerCred_.getError(); 224 } 225 226 /** 227 * Retrieve a new folly::CancellationToken that will be signaled when the 228 * connection is closed. 229 */ getCancellationToken()230 folly::CancellationToken getCancellationToken() const { 231 return cancellationSource_.getToken(); 232 } 233 234 /** 235 * Signal that the connection has been closed. 236 * 237 * This is intended to be called by the thrift server implementation code. 238 * 239 * Note that this will cause any CancellationCallback functions that have been 240 * registered to run immediately in this thread. If any of these callbacks 241 * throw this will cause program termination. 242 */ connectionClosed()243 void connectionClosed() { cancellationSource_.requestCancellation(); } 244 getWorker()245 const Cpp2Worker* getWorker() const { return worker_; } 246 getTransportType()247 std::optional<TransportType> getTransportType() const { 248 return transportType_; 249 } 250 getClientType()251 std::optional<CLIENT_TYPE> getClientType() const { return clientType_; } 252 getClientMetadataRef()253 std::optional<ClientMetadataRef> getClientMetadataRef() const { 254 if (!clientMetadata_) { 255 return {}; 256 } 257 return ClientMetadataRef{*clientMetadata_}; 258 } 259 getInterfaceKind()260 InterfaceKind getInterfaceKind() const { return interfaceKind_; } 261 262 private: 263 /** 264 * Adds interaction to interaction map 265 * Returns false and destroys tile if id is in use 266 */ addTile(int64_t id,TilePtr tile)267 bool addTile(int64_t id, TilePtr tile) { 268 return tiles_.try_emplace(id, std::move(tile)).second; 269 } 270 /** 271 * Removes interaction from map 272 * Returns old value 273 */ removeTile(int64_t id)274 TilePtr removeTile(int64_t id) { 275 auto it = tiles_.find(id); 276 if (it == tiles_.end()) { 277 return {}; 278 } 279 auto ret = std::move(it->second); 280 tiles_.erase(it); 281 return ret; 282 } 283 /** 284 * Replaces interaction if id is present in map. 285 * Destroys passed-in tile otherwise. 286 */ tryReplaceTile(int64_t id,TilePtr tile)287 void tryReplaceTile(int64_t id, TilePtr tile) { 288 auto it = tiles_.find(id); 289 if (it != tiles_.end()) { 290 it->second = std::move(tile); 291 } 292 } 293 /** 294 * Gets tile from map 295 * Throws std::out_of_range if not found 296 */ getTile(int64_t id)297 Tile& getTile(int64_t id) { return *tiles_.at(id); } 298 friend class GeneratedAsyncProcessor; 299 friend class Tile; 300 friend class TilePromise; 301 setTransportType(TransportType transportType)302 void setTransportType(TransportType transportType) { 303 transportType_ = transportType; 304 } 305 setClientType(CLIENT_TYPE clientType)306 void setClientType(CLIENT_TYPE clientType) { clientType_ = clientType; } 307 readSetupMetadata(const RequestSetupMetadata & meta)308 void readSetupMetadata(const RequestSetupMetadata& meta) { 309 if (const auto& md = meta.clientMetadata()) { 310 setClientMetadata(*md); 311 } 312 if (auto interfaceKind = meta.interfaceKind()) { 313 interfaceKind_ = *interfaceKind; 314 } 315 } 316 setClientMetadata(const ClientMetadata & md)317 void setClientMetadata(const ClientMetadata& md) { clientMetadata_ = md; } 318 319 friend class Cpp2Connection; 320 friend class rocket::ThriftRocketServerHandler; 321 friend class HTTP2RoutingHandler; 322 323 /** 324 * Platform-independent representation of unix domain socket peer credentials, 325 * e.g. ucred on Linux and xucred on macOS. 326 * 327 * Null implementation on Windows. 328 */ 329 class PeerCred { 330 public: 331 #ifndef _WIN32 332 using StatusOrPid = pid_t; 333 #else 334 // Even on Windows, differentiate between not initialized (not unix 335 // domain socket), and unsupported platform. 336 using StatusOrPid = int; 337 #endif 338 339 /** 340 * pid_t is guaranteed to be signed, so reserve non-positive values as 341 * sentinels that indicate credential validity. 342 * While negative pid_t values are possible, they are used to refer 343 * to process groups and thus cannot occur in a process identifier. 344 * Linux and macOS allow user IDs to span the entire range of a uint32_t, 345 * so sentinal values must be stored in pid_t. 346 */ 347 enum Validity : StatusOrPid { 348 NotInitialized = -1, 349 ErrorRetrieving = -2, 350 UnsupportedPlatform = -3, 351 }; 352 353 PeerCred() = default; 354 PeerCred(const PeerCred&) = default; 355 PeerCred& operator=(const PeerCred&) = default; 356 357 /** 358 * Query a socket for peer credentials. 359 */ 360 static PeerCred queryFromSocket(folly::NetworkSocket socket); 361 362 #ifndef _WIN32 getPeerEffectiveCreds()363 folly::Optional<PeerEffectiveCreds> getPeerEffectiveCreds() const { 364 return hasCredentials() 365 ? folly::make_optional(PeerEffectiveCreds{pid_, uid_, gid_}) 366 : folly::none; 367 } 368 #endif 369 370 /** 371 * If retrieving the effective credentials failed, return a string 372 * containing the reason. 373 */ 374 folly::Optional<std::string> getError() const; 375 376 private: PeerCred(Validity validity)377 explicit PeerCred(Validity validity) : pid_{validity} {} 378 379 #ifndef _WIN32 PeerCred(pid_t pid,uid_t uid,gid_t gid)380 explicit PeerCred(pid_t pid, uid_t uid, gid_t gid) 381 : pid_{pid}, uid_{uid}, gid_{gid} {} 382 #endif 383 hasCredentials()384 bool hasCredentials() const { return pid_ >= 0; } 385 386 StatusOrPid pid_ = Validity::NotInitialized; 387 #ifndef _WIN32 388 uid_t uid_ = 0; 389 gid_t gid_ = 0; 390 #endif 391 }; 392 393 folly::erased_unique_ptr userData_{folly::empty_erased_unique_ptr()}; 394 folly::SocketAddress peerAddress_; 395 folly::SocketAddress localAddress_; 396 folly::EventBaseManager* manager_; 397 std::shared_ptr<RequestChannel> duplexChannel_; 398 std::shared_ptr<TClientBase> duplexClient_; 399 folly::erased_unique_ptr peerIdentities_{folly::empty_erased_unique_ptr()}; 400 std::string securityProtocol_; 401 const folly::AsyncTransport* transport_; 402 PeerCred peerCred_; 403 // A CancellationSource that will be signaled when the connection is closed. 404 folly::CancellationSource cancellationSource_; 405 folly::F14FastMap<int64_t, TilePtr> tiles_; 406 const Cpp2Worker* worker_; 407 InterfaceKind interfaceKind_{InterfaceKind::USER}; 408 std::optional<TransportType> transportType_; 409 std::optional<CLIENT_TYPE> clientType_; 410 std::optional<ClientMetadata> clientMetadata_; 411 }; 412 413 class Cpp2ClientRequestContext 414 : public apache::thrift::server::TConnectionContext { 415 public: Cpp2ClientRequestContext(transport::THeader * header)416 explicit Cpp2ClientRequestContext(transport::THeader* header) 417 : TConnectionContext(header) {} 418 setRequestHeader(transport::THeader * header)419 void setRequestHeader(transport::THeader* header) { header_ = header; } 420 }; 421 422 // Request-specific context 423 class Cpp2RequestContext : public apache::thrift::server::TConnectionContext { 424 public: 425 explicit Cpp2RequestContext( 426 Cpp2ConnContext* ctx, 427 apache::thrift::transport::THeader* header = nullptr, 428 std::string methodName = std::string{}) TConnectionContext(header)429 : TConnectionContext(header), 430 ctx_(ctx), 431 methodName_(std::move(methodName)) {} 432 setConnectionContext(Cpp2ConnContext * ctx)433 void setConnectionContext(Cpp2ConnContext* ctx) { ctx_ = ctx; } 434 435 // Forward all connection-specific information getPeerAddress()436 const folly::SocketAddress* getPeerAddress() const override { 437 return ctx_->getPeerAddress(); 438 } 439 getLocalAddress()440 const folly::SocketAddress* getLocalAddress() const { 441 return ctx_->getLocalAddress(); 442 } 443 reset()444 void reset() { ctx_->reset(); } 445 getCallPriority()446 concurrency::PRIORITY getCallPriority() const { 447 return header_->getCallPriority(); 448 } 449 getRequestExecutionScope()450 concurrency::ThreadManager::ExecutionScope getRequestExecutionScope() const { 451 return executionScope_; 452 } 453 setRequestExecutionScope(concurrency::ThreadManager::ExecutionScope scope)454 void setRequestExecutionScope( 455 concurrency::ThreadManager::ExecutionScope scope) { 456 executionScope_ = std::move(scope); 457 } 458 getTransforms()459 virtual std::vector<uint16_t>& getTransforms() { 460 return header_->getWriteTransforms(); 461 } 462 getEventBaseManager()463 folly::EventBaseManager* getEventBaseManager() override { 464 return ctx_->getEventBaseManager(); 465 } 466 getUserData()467 void* getUserData() const override { return ctx_->getUserData(); } 468 setUserData(folly::erased_unique_ptr data)469 void* setUserData(folly::erased_unique_ptr data) override { 470 return ctx_->setUserData(std::move(data)); 471 } 472 using TConnectionContext::setUserData; 473 474 // This data is set on a per request basis. getRequestData()475 void* getRequestData() const { return requestData_.get(); } 476 477 // Returns the old request data context so the caller can clean up 478 folly::erased_unique_ptr setRequestData( 479 void* data, void (*destructor)(void*) = no_op_destructor) { 480 return std::exchange(requestData_, {data, destructor}); 481 } setRequestData(folly::erased_unique_ptr data)482 folly::erased_unique_ptr setRequestData(folly::erased_unique_ptr data) { 483 return std::exchange(requestData_, std::move(data)); 484 } 485 getConnectionContext()486 virtual Cpp2ConnContext* getConnectionContext() const { return ctx_; } 487 getRequestTimeout()488 std::chrono::milliseconds getRequestTimeout() const { 489 return requestTimeout_; 490 } 491 setRequestTimeout(std::chrono::milliseconds requestTimeout)492 void setRequestTimeout(std::chrono::milliseconds requestTimeout) { 493 requestTimeout_ = requestTimeout; 494 } 495 getMethodName()496 const std::string& getMethodName() const { return methodName_; } 497 releaseMethodName()498 std::string releaseMethodName() { return std::move(methodName_); } 499 setProtoSeqId(int32_t protoSeqId)500 void setProtoSeqId(int32_t protoSeqId) { protoSeqId_ = protoSeqId; } 501 getProtoSeqId()502 int32_t getProtoSeqId() { return protoSeqId_; } 503 setInteractionId(int64_t id)504 void setInteractionId(int64_t id) { interactionId_ = id; } 505 getInteractionId()506 int64_t getInteractionId() { return interactionId_; } 507 setInteractionCreate(InteractionCreate interactionCreate)508 void setInteractionCreate(InteractionCreate interactionCreate) { 509 interactionCreate_ = std::move(interactionCreate); 510 } 511 getInteractionCreate()512 folly::Optional<InteractionCreate>& getInteractionCreate() { 513 return interactionCreate_; 514 } 515 getTimestamps()516 server::TServerObserver::PreHandlerTimestamps& getTimestamps() { 517 return timestamps_; 518 } 519 520 // This lets us avoid having different signatures in the processor map. 521 // Should remove if we decide to split out interaction methods. setTile(TilePtr && tile)522 void setTile(TilePtr&& tile) { 523 DCHECK(tile); 524 tile_ = std::move(tile); 525 } releaseTile()526 TilePtr releaseTile() { return std::move(tile_); } 527 clientId()528 const std::string* clientId() const { 529 if (auto header = getHeader(); header && header->clientId()) { 530 return &*header->clientId(); 531 } 532 if (auto headers = getHeadersPtr()) { 533 return folly::get_ptr(*headers, transport::THeader::kClientId); 534 } 535 return nullptr; 536 } 537 538 protected: 539 apache::thrift::server::TServerObserver::CallTimestamps timestamps_; 540 541 private: 542 Cpp2ConnContext* ctx_; 543 folly::erased_unique_ptr requestData_{nullptr, nullptr}; 544 std::chrono::milliseconds requestTimeout_{0}; 545 std::string methodName_; 546 int32_t protoSeqId_{0}; 547 int64_t interactionId_{0}; 548 folly::Optional<InteractionCreate> interactionCreate_; 549 TilePtr tile_; 550 concurrency::ThreadManager::ExecutionScope executionScope_{ 551 concurrency::PRIORITY::NORMAL}; 552 }; 553 554 } // namespace thrift 555 } // namespace apache 556 557 #endif // #ifndef THRIFT_ASYNC_CPP2CONNCONTEXT_H_ 558