1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <map>
18
19 #include <glog/logging.h>
20
21 #include <boost/python.hpp>
22
23 #include <folly/Memory.h>
24 #include <folly/ScopeGuard.h>
25 #include <folly/SocketAddress.h>
26 #include <folly/Utility.h>
27
28 #include <thrift/lib/cpp/concurrency/PosixThreadFactory.h>
29 #include <thrift/lib/cpp/concurrency/ThreadManager.h>
30 #include <thrift/lib/cpp/protocol/TProtocolTypes.h>
31 #include <thrift/lib/cpp2/async/AsyncProcessor.h>
32 #include <thrift/lib/cpp2/protocol/BinaryProtocol.h>
33 #include <thrift/lib/cpp2/protocol/CompactProtocol.h>
34 #include <thrift/lib/cpp2/server/ThriftServer.h>
35 #include <thrift/lib/cpp2/util/LegacyRequestExpiryGuard.h>
36 #include <thrift/lib/py/server/CppContextData.h>
37 #include <wangle/ssl/SSLContextConfig.h>
38
39 using namespace apache::thrift;
40 using apache::thrift::BaseThriftServer;
41 using apache::thrift::concurrency::PosixThreadFactory;
42 using apache::thrift::concurrency::ThreadManager;
43 using apache::thrift::server::TConnectionContext;
44 using apache::thrift::server::TServerEventHandler;
45 using apache::thrift::server::TServerObserver;
46 using apache::thrift::transport::THeader;
47 using folly::SSLContext;
48 using wangle::SSLCacheOptions;
49 using wangle::SSLContextConfig;
50 using namespace boost::python;
51
// Compatibility shim for _Py_IsFinalizing(), which is used below to avoid
// taking the GIL while the interpreter is shutting down.
//  * Before Python 3.7 the function does not exist: offer no additional
//    protection and report the interpreter as never finalizing.
//  * Python 3.13 removed the private _Py_IsFinalizing() from the public
//    headers in favor of the public Py_IsFinalizing(); map the old name onto
//    the new API there so this file keeps compiling.
#if PY_VERSION_HEX < 0x03070000
#define _Py_IsFinalizing() false
#elif PY_VERSION_HEX >= 0x030D0000 && !defined(_Py_IsFinalizing)
#define _Py_IsFinalizing() Py_IsFinalizing()
#endif
56
57 namespace {
58
59 const std::string kHeaderEx = "uex";
60 const std::string kHeaderExWhat = "uexw";
61
makePythonHeaders(const THeader::StringToStringMap & cppheaders,const Cpp2RequestContext * context)62 object makePythonHeaders(
63 const THeader::StringToStringMap& cppheaders,
64 const Cpp2RequestContext* context) {
65 object headers = dict();
66 for (const auto& it : cppheaders) {
67 headers[it.first] = it.second;
68 }
69 headers[apache::thrift::THeader::CLIENT_TIMEOUT_HEADER] =
70 folly::to<std::string>(
71 std::chrono::milliseconds(context->getRequestTimeout()).count());
72 return headers;
73 }
74
makePythonList(const std::vector<std::string> & vec)75 object makePythonList(const std::vector<std::string>& vec) {
76 list result;
77 for (auto it = vec.begin(); it != vec.end(); ++it) {
78 result.append(*it);
79 }
80 return std::move(result);
81 }
82
getStringAttrSafe(object & pyObject,const char * attrName)83 std::string getStringAttrSafe(object& pyObject, const char* attrName) {
84 object val = pyObject.attr(attrName);
85 if (val.is_none()) {
86 return "";
87 }
88 return extract<std::string>(str(val));
89 }
90
91 template <class T>
getIntAttr(object & pyObject,const char * attrName)92 T getIntAttr(object& pyObject, const char* attrName) {
93 object val = pyObject.attr(attrName);
94 return extract<T>(val);
95 }
96
97 } // namespace
98
99 class CallbackWrapper {
100 public:
call(object obj)101 void call(object obj) { callback_(obj); }
102
setCallback(folly::Function<void (object)> && callback)103 void setCallback(folly::Function<void(object)>&& callback) {
104 callback_ = std::move(callback);
105 }
106
107 private:
108 folly::Function<void(object)> callback_;
109 };
110
111 class CppServerEventHandler : public TServerEventHandler {
112 public:
CppServerEventHandler(object serverEventHandler)113 explicit CppServerEventHandler(object serverEventHandler)
114 : handler_(std::make_shared<object>(serverEventHandler)) {}
115
newConnection(TConnectionContext * ctx)116 void newConnection(TConnectionContext* ctx) override {
117 callPythonHandler(ctx, "newConnection");
118 }
119
connectionDestroyed(TConnectionContext * ctx)120 void connectionDestroyed(TConnectionContext* ctx) override {
121 callPythonHandler(ctx, "connectionDestroyed");
122 }
123
124 private:
callPythonHandler(TConnectionContext * ctx,const char * method)125 void callPythonHandler(TConnectionContext* ctx, const char* method) {
126 if (!_Py_IsFinalizing()) {
127 PyGILState_STATE state = PyGILState_Ensure();
128 SCOPE_EXIT { PyGILState_Release(state); };
129
130 // This cast always succeeds because it is called from Cpp2Connection.
131 Cpp2ConnContext* cpp2Ctx = dynamic_cast<Cpp2ConnContext*>(ctx);
132 auto cd_cls = handler_->attr("CONTEXT_DATA");
133 object contextData = cd_cls();
134 extract<CppContextData&>(contextData)().copyContextContents(cpp2Ctx);
135 auto ctx_cls = handler_->attr("CPP_CONNECTION_CONTEXT");
136 object cppConnContext = ctx_cls(contextData);
137 handler_->attr(method)(cppConnContext);
138 }
139 }
140
141 std::shared_ptr<object> handler_;
142 };
143
144 class PythonCallTimestamps : public TServerObserver::CallTimestamps {
145 public:
set_readEndNow()146 void set_readEndNow() { readEnd = clock::now(); }
get_readEndUsec() const147 uint64_t get_readEndUsec() const noexcept {
148 return to_microseconds(readEnd.time_since_epoch());
149 }
set_processBeginNow()150 void set_processBeginNow() { processBegin = clock::now(); }
get_processBeginUsec() const151 uint64_t get_processBeginUsec() const noexcept {
152 return to_microseconds(processBegin.time_since_epoch());
153 }
set_processEndNow()154 void set_processEndNow() { processEnd = clock::now(); }
get_processEndUsec() const155 uint64_t get_processEndUsec() const noexcept {
156 return to_microseconds(processEnd.time_since_epoch());
157 }
set_writeBeginNow()158 void set_writeBeginNow() { writeBegin = clock::now(); }
get_writeBeginUsec() const159 uint64_t get_writeBeginUsec() const noexcept {
160 return to_microseconds(writeBegin.time_since_epoch());
161 }
set_writeEndNow()162 void set_writeEndNow() { writeEnd = clock::now(); }
get_writeEndUsec() const163 uint64_t get_writeEndUsec() const noexcept {
164 return to_microseconds(writeEnd.time_since_epoch());
165 }
166 };
167
168 class CppServerObserver : public TServerObserver {
169 public:
CppServerObserver(object serverObserver)170 explicit CppServerObserver(object serverObserver)
171 : observer_(serverObserver) {}
172
connAccepted(const wangle::TransportInfo &)173 void connAccepted(const wangle::TransportInfo& /* info */) override {
174 this->call("connAccepted");
175 }
connDropped()176 void connDropped() override { this->call("connDropped"); }
connRejected()177 void connRejected() override { this->call("connRejected"); }
tlsError()178 void tlsError() override { this->call("tlsError"); }
tlsComplete()179 void tlsComplete() override { this->call("tlsComplete"); }
tlsFallback()180 void tlsFallback() override { this->call("tlsFallback"); }
tlsResumption()181 void tlsResumption() override { this->call("tlsResumption"); }
taskKilled()182 void taskKilled() override { this->call("taskKilled"); }
taskTimeout()183 void taskTimeout() override { this->call("taskTimeout"); }
serverOverloaded()184 void serverOverloaded() override { this->call("serverOverloaded"); }
receivedRequest(const std::string *)185 void receivedRequest(const std::string* /*method*/) override {
186 this->call("receivedRequest");
187 }
admittedRequest(const std::string *)188 void admittedRequest(const std::string* /*method*/) override {
189 this->call("admittedRequest");
190 }
queuedRequests(int32_t n)191 void queuedRequests(int32_t n) override { this->call("queuedRequests", n); }
queueTimeout()192 void queueTimeout() override { this->call("queueTimeout"); }
sentReply()193 void sentReply() override { this->call("sentReply"); }
activeRequests(int32_t n)194 void activeRequests(int32_t n) override { this->call("activeRequests", n); }
callCompleted(const CallTimestamps & runtimes)195 void callCompleted(const CallTimestamps& runtimes) override {
196 this->call(
197 "callCompleted",
198 reinterpret_cast<const PythonCallTimestamps&>(runtimes));
199 }
tlsWithClientCert()200 void tlsWithClientCert() override { this->call("tlsWithClientCert"); }
201
202 private:
203 template <class... Types>
call(const char * method_name,Types...args)204 void call(const char* method_name, Types... args) {
205 PyGILState_STATE state = PyGILState_Ensure();
206 SCOPE_EXIT { PyGILState_Release(state); };
207
208 // check if the object has an attribute, because we want to be accepting
209 // if we added a new listener callback and didn't yet update call the
210 // people using this interface.
211 if (!PyObject_HasAttrString(observer_.ptr(), method_name)) {
212 return;
213 }
214
215 try {
216 (void)observer_.attr(method_name)(args...);
217 } catch (const error_already_set&) {
218 // print the error to sys.stderr and carry on, because raising here
219 // would break the server protocol, and raising in Python later
220 // would be extremely disconnected and confusing since it would
221 // happen in apparently unconnected Python code.
222 PyErr_Print();
223 }
224 }
225
226 object observer_;
227 };
228
229 class PythonAsyncProcessor : public AsyncProcessor {
230 public:
PythonAsyncProcessor(std::shared_ptr<object> adapter)231 explicit PythonAsyncProcessor(std::shared_ptr<object> adapter)
232 : adapter_(adapter) {
233 getPythonOnewayMethods();
234 }
235
236 // Create a task and add it to thread manager's queue. Essentially the same
237 // as GeneratedAsyncProcessor's processInThread method.
processSerializedRequest(ResponseChannelRequest::UniquePtr req,apache::thrift::SerializedRequest && serializedRequest,apache::thrift::protocol::PROTOCOL_TYPES protType,Cpp2RequestContext * context,folly::EventBase * eb,apache::thrift::concurrency::ThreadManager * tm)238 void processSerializedRequest(
239 ResponseChannelRequest::UniquePtr req,
240 apache::thrift::SerializedRequest&& serializedRequest,
241 apache::thrift::protocol::PROTOCOL_TYPES protType,
242 Cpp2RequestContext* context,
243 folly::EventBase* eb,
244 apache::thrift::concurrency::ThreadManager* tm) override {
245 auto fname = context->getMethodName();
246 bool oneway = isOnewayMethod(fname);
247
248 if (oneway && !req->isOneway()) {
249 req->sendReply(ResponsePayload{});
250 }
251
252 apache::thrift::LegacyRequestExpiryGuard rh{std::move(req), eb};
253 auto task = [=,
254 buf = apache::thrift::LegacySerializedRequest(
255 protType,
256 context->getProtoSeqId(),
257 context->getMethodName(),
258 std::move(serializedRequest))
259 .buffer,
260 rh = std::move(rh)]() mutable {
261 auto req_up = std::move(rh.req);
262 SCOPE_EXIT {
263 rh.eb->runInEventBaseThread(
264 [req_up = std::move(req_up)]() mutable { req_up = {}; });
265 };
266
267 if (!oneway && !req_up->getShouldStartProcessing()) {
268 return;
269 }
270
271 folly::ByteRange input_range = buf->coalesce();
272 auto input_data = const_cast<unsigned char*>(input_range.data());
273 auto clientType = context->getHeader()->getClientType();
274
275 {
276 PyGILState_STATE state = PyGILState_Ensure();
277 SCOPE_EXIT { PyGILState_Release(state); };
278
279 #if PY_MAJOR_VERSION == 2
280 auto input =
281 handle<>(PyBuffer_FromMemory(input_data, input_range.size()));
282 #else
283 auto input = handle<>(PyMemoryView_FromMemory(
284 reinterpret_cast<char*>(input_data),
285 input_range.size(),
286 PyBUF_READ));
287 #endif
288
289 auto cd_ctor = adapter_->attr("CONTEXT_DATA");
290 object contextData = cd_ctor();
291 extract<CppContextData&>(contextData)().copyContextContents(context);
292
293 auto cb_ctor = adapter_->attr("CALLBACK_WRAPPER");
294 object callbackWrapper = cb_ctor();
295 extract<CallbackWrapper&>(callbackWrapper)().setCallback(
296 [oneway,
297 req_up = std::move(req_up),
298 context,
299 eb = rh.eb,
300 contextData,
301 protType](object output) mutable {
302 // Make sure the request is deleted in evb.
303 SCOPE_EXIT {
304 eb->runInEventBaseThread(
305 [req_up = std::move(req_up)]() mutable { req_up = {}; });
306 };
307
308 // Always called from python so no need to grab GIL.
309 try {
310 std::unique_ptr<folly::IOBuf> outbuf;
311 if (output.is_none()) {
312 throw std::runtime_error(
313 "Unexpected error in processor method");
314 }
315 PyObject* output_ptr = output.ptr();
316 #if PY_MAJOR_VERSION == 2
317 if (PyString_Check(output_ptr)) {
318 int len = extract<int>(output.attr("__len__")());
319 if (len == 0) {
320 return;
321 }
322 outbuf = folly::IOBuf::copyBuffer(
323 extract<const char*>(output), len);
324 } else
325 #endif
326 if (PyBytes_Check(output_ptr)) {
327 int len = PyBytes_Size(output_ptr);
328 if (len == 0) {
329 return;
330 }
331 outbuf = folly::IOBuf::copyBuffer(
332 PyBytes_AsString(output_ptr), len);
333 } else {
334 throw std::runtime_error(
335 "Return from processor "
336 "method is not string or bytes");
337 }
338
339 if (!req_up->isActive()) {
340 return;
341 }
342 CppContextData& cppContextData =
343 extract<CppContextData&>(contextData);
344 if (!cppContextData.getHeaderEx().empty()) {
345 context->getHeader()->setHeader(
346 kHeaderEx, cppContextData.getHeaderEx());
347 }
348 if (!cppContextData.getHeaderExWhat().empty()) {
349 context->getHeader()->setHeader(
350 kHeaderExWhat, cppContextData.getHeaderExWhat());
351 }
352 auto response = LegacySerializedResponse{std::move(outbuf)};
353 auto [mtype, payload] = std::move(response).extractPayload(
354 req_up->includeEnvelope(), protType);
355 payload.transform(context->getHeader()->getWriteTransforms());
356 eb->runInEventBaseThread(
357 [mtype = mtype,
358 req_up = std::move(req_up),
359 payload = std::move(payload)]() mutable {
360 if (mtype == MessageType::T_REPLY) {
361 req_up->sendReply(std::move(payload));
362 } else if (mtype == MessageType::T_EXCEPTION) {
363 req_up->sendException(std::move(payload));
364 } else {
365 LOG(ERROR) << "Invalid type. type=" << uint16_t(mtype);
366 }
367 });
368 } catch (const std::exception& e) {
369 if (!oneway) {
370 req_up->sendErrorWrapped(
371 folly::make_exception_wrapper<TApplicationException>(
372 folly::to<std::string>(
373 "Failed to read response from Python:",
374 e.what())),
375 "python");
376 }
377 }
378 });
379
380 adapter_->attr("call_processor")(
381 input,
382 makePythonHeaders(context->getHeader()->getHeaders(), context),
383 int(clientType),
384 int(protType),
385 contextData,
386 callbackWrapper);
387 }
388 };
389
390 using PriorityThreadManager =
391 apache::thrift::concurrency::PriorityThreadManager;
392 auto ptm = dynamic_cast<PriorityThreadManager*>(tm);
393 if (ptm != nullptr) {
394 ptm->add(
395 getMethodPriority(fname, context),
396 std::make_shared<apache::thrift::concurrency::FunctionRunner>(
397 std::move(task)));
398 return;
399 }
400 tm->add(std::move(task));
401 }
402
403 /**
404 * Get the priority of the request
405 * Check the headers directly in C++ since noone seems to override that logic
406 * Ask python if no priority headers were supplied with the request
407 */
getMethodPriority(std::string const & fname,Cpp2RequestContext * ctx=nullptr)408 concurrency::PRIORITY getMethodPriority(
409 std::string const& fname, Cpp2RequestContext* ctx = nullptr) {
410 if (ctx) {
411 auto requestPriority = ctx->getCallPriority();
412 if (requestPriority != concurrency::PRIORITY::N_PRIORITIES) {
413 VLOG(3) << "Request priority from headers";
414 return requestPriority;
415 }
416 }
417
418 PyGILState_STATE state = PyGILState_Ensure();
419 SCOPE_EXIT { PyGILState_Release(state); };
420
421 try {
422 return static_cast<concurrency::PRIORITY>(
423 extract<int>(adapter_->attr("get_priority")(fname))());
424 } catch (error_already_set&) {
425 // get_priority doesn't exist, or it threw an exception
426 LOG(ERROR) << "Error while calling _ProcessorAdapter.get_priority()";
427 PyErr_Print();
428 }
429
430 return concurrency::PRIORITY::NORMAL;
431 }
432
433 private:
isOnewayMethod(std::string const & fname)434 bool isOnewayMethod(std::string const& fname) {
435 return onewayMethods_.find(fname) != onewayMethods_.end();
436 }
437
getPythonOnewayMethods()438 void getPythonOnewayMethods() {
439 PyGILState_STATE state = PyGILState_Ensure();
440 SCOPE_EXIT { PyGILState_Release(state); };
441 object ret = adapter_->attr("oneway_methods")();
442 if (ret.is_none()) {
443 LOG(ERROR) << "Unexpected error in processor method";
444 return;
445 }
446 tuple t = extract<tuple>(ret);
447 for (int i = 0; i < len(t); i++) {
448 onewayMethods_.insert(extract<std::string>(t[i]));
449 }
450 }
451
452 std::shared_ptr<object> adapter_;
453 std::unordered_set<std::string> onewayMethods_;
454 };
455
456 class PythonAsyncProcessorFactory : public AsyncProcessorFactory {
457 public:
PythonAsyncProcessorFactory(std::shared_ptr<object> adapter)458 explicit PythonAsyncProcessorFactory(std::shared_ptr<object> adapter)
459 : adapter_(adapter) {}
460
getProcessor()461 std::unique_ptr<apache::thrift::AsyncProcessor> getProcessor() override {
462 return std::make_unique<PythonAsyncProcessor>(adapter_);
463 }
464
465 // TODO(T89004867): Call onStartServing() and onStopServing() hooks for
466 // non-C++ thrift servers
getServiceHandlers()467 std::vector<apache::thrift::ServiceHandler*> getServiceHandlers() override {
468 return {};
469 }
470
471 private:
472 std::shared_ptr<object> adapter_;
473 };
474
475 class CppServerWrapper : public ThriftServer {
476 public:
CppServerWrapper()477 CppServerWrapper() {
478 BaseThriftServer::metadata().wrapper = "CppServerWrapper-py";
479 }
480
setAdapter(object adapter)481 void setAdapter(object adapter) {
482 // We use a shared_ptr to manage the adapter so the processor
483 // factory handing won't ever try to manipulate python reference
484 // counts without the GIL.
485 setProcessorFactory(std::make_unique<PythonAsyncProcessorFactory>(
486 std::make_shared<object>(adapter)));
487 }
488
489 // peer to setObserver, but since we want a different argument, avoid
490 // shadowing in our parent class.
setObserverFromPython(object observer)491 void setObserverFromPython(object observer) {
492 setObserver(std::make_shared<CppServerObserver>(observer));
493 }
494
getAddress()495 object getAddress() { return makePythonAddress(ThriftServer::getAddress()); }
496
loop()497 void loop() {
498 PyThreadState* save_state = PyEval_SaveThread();
499 SCOPE_EXIT { PyEval_RestoreThread(save_state); };
500
501 // Thrift main loop. This will run indefinitely, until stop() is
502 // called.
503
504 getServeEventBase()->loopForever();
505 }
506
setup()507 void setup() {
508 PyThreadState* save_state = PyEval_SaveThread();
509 SCOPE_EXIT { PyEval_RestoreThread(save_state); };
510
511 // This check is only useful for C++-based Thrift servers.
512 ThriftServer::setAllowCheckUnimplementedExtraInterfaces(false);
513 ThriftServer::setup();
514 }
515
setCppSSLConfig(object sslConfig)516 void setCppSSLConfig(object sslConfig) {
517 auto certPath = getStringAttrSafe(sslConfig, "cert_path");
518 auto keyPath = getStringAttrSafe(sslConfig, "key_path");
519 if (certPath.empty() ^ keyPath.empty()) {
520 PyErr_SetString(
521 PyExc_ValueError, "certPath and keyPath must both be populated");
522 throw_error_already_set();
523 return;
524 }
525 auto cfg = std::make_shared<SSLContextConfig>();
526 cfg->clientCAFile = getStringAttrSafe(sslConfig, "client_ca_path");
527 if (!certPath.empty()) {
528 auto keyPwPath = getStringAttrSafe(sslConfig, "key_pw_path");
529 cfg->setCertificate(certPath, keyPath, keyPwPath);
530 }
531 cfg->clientVerification =
532 extract<SSLContext::VerifyClientCertificate>(sslConfig.attr("verify"));
533 auto eccCurve = getStringAttrSafe(sslConfig, "ecc_curve_name");
534 if (!eccCurve.empty()) {
535 cfg->eccCurveName = eccCurve;
536 }
537 object sessionContext = sslConfig.attr("session_context");
538 if (!sessionContext.is_none()) {
539 cfg->sessionContext = extract<std::string>(str(sessionContext));
540 }
541
542 object sslVersionAttr = sslConfig.attr("ssl_version");
543 if (!sslVersionAttr.is_none()) {
544 cfg->sslVersion =
545 extract<SSLContext::SSLVersion>(sslConfig.attr("ssl_version"));
546 }
547
548 ThriftServer::setSSLConfig(folly::observer::makeObserver(
549 [cfg, nextProtocolsObserver = ThriftServer::defaultNextProtocols()] {
550 auto cfgWithNextProtocols = *cfg;
551 cfgWithNextProtocols.setNextProtocols(**nextProtocolsObserver);
552 return cfgWithNextProtocols;
553 }));
554
555 setSSLPolicy(extract<SSLPolicy>(sslConfig.attr("ssl_policy")));
556
557 auto ticketFilePath = getStringAttrSafe(sslConfig, "ticket_file_path");
558 ThriftServer::watchTicketPathForChanges(ticketFilePath);
559 }
560
setCppFastOpenOptions(object enabledObj,object tfoMaxQueueObj)561 void setCppFastOpenOptions(object enabledObj, object tfoMaxQueueObj) {
562 bool enabled{extract<bool>(enabledObj)};
563 uint32_t tfoMaxQueue{extract<uint32_t>(tfoMaxQueueObj)};
564 ThriftServer::setFastOpenOptions(enabled, tfoMaxQueue);
565 }
566
useCppExistingSocket(int socket)567 void useCppExistingSocket(int socket) {
568 ThriftServer::useExistingSocket(socket);
569 }
570
setCppSSLCacheOptions(object cacheOptions)571 void setCppSSLCacheOptions(object cacheOptions) {
572 SSLCacheOptions options = {
573 .sslCacheTimeout = std::chrono::seconds(
574 getIntAttr<uint32_t>(cacheOptions, "ssl_cache_timeout_seconds")),
575 .maxSSLCacheSize =
576 getIntAttr<uint64_t>(cacheOptions, "max_ssl_cache_size"),
577 .sslCacheFlushSize =
578 getIntAttr<uint64_t>(cacheOptions, "ssl_cache_flush_size"),
579 .handshakeValidity = std::chrono::seconds(getIntAttr<uint32_t>(
580 cacheOptions, "ssl_handshake_validity_seconds")),
581 };
582 ThriftServer::setSSLCacheOptions(std::move(options));
583 }
584
getCppTicketSeeds()585 object getCppTicketSeeds() {
586 auto seeds = getTicketSeeds();
587 if (!seeds) {
588 return boost::python::object();
589 }
590 boost::python::dict result;
591 result["old"] = makePythonList(seeds->oldSeeds);
592 result["current"] = makePythonList(seeds->currentSeeds);
593 result["new"] = makePythonList(seeds->newSeeds);
594 return std::move(result);
595 }
596
cleanUp()597 void cleanUp() {
598 // Deadlock avoidance: consider a thrift worker thread is doing
599 // something in C++-land having relinquished the GIL. This thread
600 // acquires the GIL, stops the workers, and waits for the worker
601 // threads to complete. The worker thread now finishes its work,
602 // and tries to reacquire the GIL, but deadlocks with the current
603 // thread, which holds the GIL and is waiting for the worker to
604 // complete. So we do cleanUp() without the GIL, and reacquire it
605 // only once thrift is all cleaned up.
606
607 PyThreadState* save_state = PyEval_SaveThread();
608 SCOPE_EXIT { PyEval_RestoreThread(save_state); };
609 ThriftServer::cleanUp();
610 }
611
setIdleTimeout(int timeout)612 void setIdleTimeout(int timeout) {
613 std::chrono::milliseconds ms(timeout);
614 ThriftServer::setIdleTimeout(ms, AttributeSource::OVERRIDE);
615 }
616
setTaskExpireTime(int timeout)617 void setTaskExpireTime(int timeout) {
618 std::chrono::milliseconds ms(timeout);
619 ThriftServer::setTaskExpireTime(ms, AttributeSource::OVERRIDE);
620 }
621
setCppServerEventHandler(object serverEventHandler)622 void setCppServerEventHandler(object serverEventHandler) {
623 setServerEventHandler(
624 std::make_shared<CppServerEventHandler>(serverEventHandler));
625 }
626
setNewSimpleThreadManager(size_t count,size_t)627 void setNewSimpleThreadManager(size_t count, size_t) {
628 auto tm = ThreadManager::newSimpleThreadManager(count);
629 auto poolThreadName = getCPUWorkerThreadName();
630 if (!poolThreadName.empty()) {
631 tm->setNamePrefix(poolThreadName);
632 }
633
634 tm->threadFactory(std::make_shared<PosixThreadFactory>());
635 tm->start();
636 setThreadManager(std::move(tm));
637 }
638
setNewPriorityQueueThreadManager(size_t numThreads)639 void setNewPriorityQueueThreadManager(size_t numThreads) {
640 auto tm = ThreadManager::newPriorityQueueThreadManager(numThreads);
641 auto poolThreadName = getCPUWorkerThreadName();
642 if (!poolThreadName.empty()) {
643 tm->setNamePrefix(poolThreadName);
644 }
645
646 tm->threadFactory(std::make_shared<PosixThreadFactory>());
647 tm->start();
648 setThreadManager(std::move(tm));
649 }
650
setNewPriorityThreadManager(size_t high_important,size_t high,size_t important,size_t normal,size_t best_effort,size_t)651 void setNewPriorityThreadManager(
652 size_t high_important,
653 size_t high,
654 size_t important,
655 size_t normal,
656 size_t best_effort,
657 size_t) {
658 auto tm = PriorityThreadManager::newPriorityThreadManager(
659 {{high_important, high, important, normal, best_effort}});
660 tm->enableCodel(getEnableCodel());
661 auto poolThreadName = getCPUWorkerThreadName();
662 if (!poolThreadName.empty()) {
663 tm->setNamePrefix(poolThreadName);
664 }
665
666 tm->threadFactory(std::make_shared<PosixThreadFactory>());
667 tm->start();
668 setThreadManager(std::move(tm));
669 }
670
671 // this adapts from a std::shared_ptr, which boost::python does not (yet)
672 // support, to a boost::shared_ptr, which it has internal support for.
673 //
674 // the magic is in the custom deleter which takes and releases a refcount on
675 // the std::shared_ptr, instead of doing any local deletion.
getThreadManagerHelper()676 boost::shared_ptr<ThreadManager> getThreadManagerHelper() {
677 auto ptr = this->getThreadManager();
678 return boost::shared_ptr<ThreadManager>(ptr.get(), [ptr](void*) {});
679 }
680
setWorkersJoinTimeout(int seconds)681 void setWorkersJoinTimeout(int seconds) {
682 ThriftServer::setWorkersJoinTimeout(std::chrono::seconds(seconds));
683 }
684
setNumIOWorkerThreads(size_t numIOWorkerThreads)685 void setNumIOWorkerThreads(size_t numIOWorkerThreads) {
686 BaseThriftServer::setNumIOWorkerThreads(
687 numIOWorkerThreads, AttributeSource::OVERRIDE);
688 }
689
setListenBacklog(int listenBacklog)690 void setListenBacklog(int listenBacklog) {
691 BaseThriftServer::setListenBacklog(
692 listenBacklog, AttributeSource::OVERRIDE);
693 }
694
setMaxConnections(uint32_t maxConnections)695 void setMaxConnections(uint32_t maxConnections) {
696 BaseThriftServer::setMaxConnections(
697 maxConnections, AttributeSource::OVERRIDE);
698 }
699
setNumCPUWorkerThreads(size_t numCPUWorkerThreads)700 void setNumCPUWorkerThreads(size_t numCPUWorkerThreads) {
701 BaseThriftServer::setNumCPUWorkerThreads(
702 numCPUWorkerThreads, AttributeSource::OVERRIDE);
703 }
704
setEnableCodel(bool enableCodel)705 void setEnableCodel(bool enableCodel) {
706 BaseThriftServer::setEnableCodel(enableCodel, AttributeSource::OVERRIDE);
707 }
708
setWrapperName(object wrapperName)709 void setWrapperName(object wrapperName) {
710 BaseThriftServer::metadata().wrapper =
711 extract<std::string>(str(wrapperName));
712 }
713
setLanguageFrameworkName(object languageFrameworkName)714 void setLanguageFrameworkName(object languageFrameworkName) {
715 BaseThriftServer::metadata().languageFramework =
716 extract<std::string>(str(languageFrameworkName));
717 }
718
setUnixSocketPath(const char * path)719 void setUnixSocketPath(const char* path) {
720 setAddress(folly::SocketAddress::makeFromPath(path));
721 }
722 };
723
/**
 * Python extension module `CppServerWrapper`.
 *
 * Registers the C++ types from this file (plus the ThriftServer and
 * ThreadManager types they expose) with boost::python. Registration order
 * matters: a class named in `bases<...>` must be registered before the class
 * that derives from it.
 */
BOOST_PYTHON_MODULE(CppServerWrapper) {
  // NOTE(review): per the CPython docs, the GIL is created by Py_Initialize()
  // since 3.7 and this call is deprecated since 3.9 — confirm it is still
  // needed for the Python versions this module supports.
  PyEval_InitThreads();

  // Per-request / per-connection data handed to Python handlers.
  class_<CppContextData>("CppContextData")
      .def("getClientIdentity", &CppContextData::getClientIdentity)
      .def("getPeerAddress", &CppContextData::getPeerAddress)
      .def("getLocalAddress", &CppContextData::getLocalAddress)
      .def("setHeaderEx", &CppContextData::setHeaderEx)
      .def("setHeaderExWhat", &CppContextData::setHeaderExWhat);

  // Lets Python deliver a processor result back to C++ via call().
  class_<CallbackWrapper, boost::noncopyable>("CallbackWrapper")
      .def("call", &CallbackWrapper::call);

  // Base class; registered first so CppServerWrapper can name it in bases<>.
  class_<ThriftServer, boost::noncopyable>("ThriftServer");

  class_<CppServerWrapper, bases<ThriftServer>, boost::noncopyable>(
      "CppServerWrapper")
      // methods added or customized for the python implementation
      .def("setAdapter", &CppServerWrapper::setAdapter)
      .def(
          "setAddress",
          // Selects the (host, port) overload of setAddress.
          static_cast<void (CppServerWrapper::*)(std::string const&, uint16_t)>(
              &CppServerWrapper::setAddress))
      .def("setUnixSocketPath", &CppServerWrapper::setUnixSocketPath)
      .def("setObserver", &CppServerWrapper::setObserverFromPython)
      .def("setIdleTimeout", &CppServerWrapper::setIdleTimeout)
      .def("setTaskExpireTime", &CppServerWrapper::setTaskExpireTime)
      .def("getAddress", &CppServerWrapper::getAddress)
      .def("getPort", &CppServerWrapper::getPort)
      .def("loop", &CppServerWrapper::loop)
      .def("cleanUp", &CppServerWrapper::cleanUp)
      .def(
          "setCppServerEventHandler",
          &CppServerWrapper::setCppServerEventHandler)
      .def(
          "setNewSimpleThreadManager",
          &CppServerWrapper::setNewSimpleThreadManager,
          (arg("count"), arg("pendingTaskCountMax")))
      .def(
          "setNewPriorityQueueThreadManager",
          &CppServerWrapper::setNewPriorityQueueThreadManager,
          (arg("numThreads")))
      .def(
          "setNewPriorityThreadManager",
          &CppServerWrapper::setNewPriorityThreadManager,
          (arg("high_important"),
           arg("high"),
           arg("important"),
           arg("normal"),
           arg("best_effort"),
           arg("maxQueueLen") = 0))
      .def("setCppSSLConfig", &CppServerWrapper::setCppSSLConfig)
      .def("setCppSSLCacheOptions", &CppServerWrapper::setCppSSLCacheOptions)
      .def("setCppFastOpenOptions", &CppServerWrapper::setCppFastOpenOptions)
      .def("getCppTicketSeeds", &CppServerWrapper::getCppTicketSeeds)
      .def("setWorkersJoinTimeout", &CppServerWrapper::setWorkersJoinTimeout)
      .def("useCppExistingSocket", &CppServerWrapper::useCppExistingSocket)

      // methods directly passed to the C++ impl
      .def("setup", &CppServerWrapper::setup)
      .def("setNumCPUWorkerThreads", &CppServerWrapper::setNumCPUWorkerThreads)
      .def("setNumIOWorkerThreads", &CppServerWrapper::setNumIOWorkerThreads)
      .def("setListenBacklog", &CppServerWrapper::setListenBacklog)
      .def("setPort", &CppServerWrapper::setPort)
      .def("setReusePort", &CppServerWrapper::setReusePort)
      .def("stop", &CppServerWrapper::stop)
      .def("setMaxConnections", &CppServerWrapper::setMaxConnections)
      .def("getMaxConnections", &CppServerWrapper::getMaxConnections)
      .def("setEnabled", &CppServerWrapper::setEnabled)

      .def("getLoad", &CppServerWrapper::getLoad)
      .def("getActiveRequests", &CppServerWrapper::getActiveRequests)
      // Returned as boost::shared_ptr (see getThreadManagerHelper).
      .def("getThreadManager", &CppServerWrapper::getThreadManagerHelper)
      .def("setWrapperName", &CppServerWrapper::setWrapperName)
      .def(
          "setLanguageFrameworkName",
          &CppServerWrapper::setLanguageFrameworkName);

  // Read-only view of the server's thread manager; not constructible from
  // Python (no_init).
  class_<ThreadManager, boost::shared_ptr<ThreadManager>, boost::noncopyable>(
      "ThreadManager", no_init)
      .def("idleWorkerCount", &ThreadManager::idleWorkerCount)
      .def("workerCount", &ThreadManager::workerCount)
      .def("pendingTaskCount", &ThreadManager::pendingTaskCount)
      .def("pendingUpstreamTaskCount", &ThreadManager::pendingUpstreamTaskCount)
      .def("totalTaskCount", &ThreadManager::totalTaskCount)
      .def("expiredTaskCount", &ThreadManager::expiredTaskCount)
      .def("clearPending", &ThreadManager::clearPending);

  // Exposed to Python under the name "CallTimestamps"; getters return
  // microseconds.
  class_<PythonCallTimestamps>("CallTimestamps")
      .def("getReadEnd", &PythonCallTimestamps::get_readEndUsec)
      .def("setReadEndNow", &PythonCallTimestamps::set_readEndNow)
      .def("getProcessBegin", &PythonCallTimestamps::get_processBeginUsec)
      .def("setProcessBeginNow", &PythonCallTimestamps::set_processBeginNow)
      .def("getProcessEnd", &PythonCallTimestamps::get_processEndUsec)
      .def("setProcessEndNow", &PythonCallTimestamps::set_processEndNow)
      .def("getWriteBegin", &PythonCallTimestamps::get_writeBeginUsec)
      .def("setWriteBeginNow", &PythonCallTimestamps::set_writeBeginNow)
      .def("getWriteEnd", &PythonCallTimestamps::get_writeEndUsec)
      .def("setWriteEndNow", &PythonCallTimestamps::set_writeEndNow);

  // Enums read back by setCppSSLConfig via extract<>.
  enum_<SSLPolicy>("SSLPolicy")
      .value("DISABLED", SSLPolicy::DISABLED)
      .value("PERMITTED", SSLPolicy::PERMITTED)
      .value("REQUIRED", SSLPolicy::REQUIRED);

  // Note the Python names differ from the C++ enumerators for two values.
  enum_<folly::SSLContext::VerifyClientCertificate>("VerifyClientCertificate")
      .value(
          "IF_PRESENTED",
          folly::SSLContext::VerifyClientCertificate::IF_PRESENTED)
      .value(
          "ALWAYS_VERIFY", folly::SSLContext::VerifyClientCertificate::ALWAYS)
      .value(
          "NONE_DO_NOT_REQUEST",
          folly::SSLContext::VerifyClientCertificate::DO_NOT_REQUEST);

  // Only TLSv1_2 is exposed to Python.
  enum_<folly::SSLContext::SSLVersion>("SSLVersion")
      .value("TLSv1_2", folly::SSLContext::SSLVersion::TLSv1_2);
}
842