1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <map>
18 
19 #include <glog/logging.h>
20 
21 #include <boost/python.hpp>
22 
23 #include <folly/Memory.h>
24 #include <folly/ScopeGuard.h>
25 #include <folly/SocketAddress.h>
26 #include <folly/Utility.h>
27 
28 #include <thrift/lib/cpp/concurrency/PosixThreadFactory.h>
29 #include <thrift/lib/cpp/concurrency/ThreadManager.h>
30 #include <thrift/lib/cpp/protocol/TProtocolTypes.h>
31 #include <thrift/lib/cpp2/async/AsyncProcessor.h>
32 #include <thrift/lib/cpp2/protocol/BinaryProtocol.h>
33 #include <thrift/lib/cpp2/protocol/CompactProtocol.h>
34 #include <thrift/lib/cpp2/server/ThriftServer.h>
35 #include <thrift/lib/cpp2/util/LegacyRequestExpiryGuard.h>
36 #include <thrift/lib/py/server/CppContextData.h>
37 #include <wangle/ssl/SSLContextConfig.h>
38 
39 using namespace apache::thrift;
40 using apache::thrift::BaseThriftServer;
41 using apache::thrift::concurrency::PosixThreadFactory;
42 using apache::thrift::concurrency::ThreadManager;
43 using apache::thrift::server::TConnectionContext;
44 using apache::thrift::server::TServerEventHandler;
45 using apache::thrift::server::TServerObserver;
46 using apache::thrift::transport::THeader;
47 using folly::SSLContext;
48 using wangle::SSLCacheOptions;
49 using wangle::SSLContextConfig;
50 using namespace boost::python;
51 
// _Py_IsFinalizing() is used below to avoid calling into Python while the
// interpreter is shutting down.
// - Before 3.7 the function is not available, so offer no additional
//   protection.
// - Python 3.13 replaced the private _Py_IsFinalizing() with the public
//   Py_IsFinalizing(), so map the old spelling onto the new one there.
#if PY_VERSION_HEX < 0x03070000
#define _Py_IsFinalizing() false
#elif PY_VERSION_HEX >= 0x030D0000
#define _Py_IsFinalizing() Py_IsFinalizing()
#endif
56 
57 namespace {
58 
59 const std::string kHeaderEx = "uex";
60 const std::string kHeaderExWhat = "uexw";
61 
makePythonHeaders(const THeader::StringToStringMap & cppheaders,const Cpp2RequestContext * context)62 object makePythonHeaders(
63     const THeader::StringToStringMap& cppheaders,
64     const Cpp2RequestContext* context) {
65   object headers = dict();
66   for (const auto& it : cppheaders) {
67     headers[it.first] = it.second;
68   }
69   headers[apache::thrift::THeader::CLIENT_TIMEOUT_HEADER] =
70       folly::to<std::string>(
71           std::chrono::milliseconds(context->getRequestTimeout()).count());
72   return headers;
73 }
74 
makePythonList(const std::vector<std::string> & vec)75 object makePythonList(const std::vector<std::string>& vec) {
76   list result;
77   for (auto it = vec.begin(); it != vec.end(); ++it) {
78     result.append(*it);
79   }
80   return std::move(result);
81 }
82 
getStringAttrSafe(object & pyObject,const char * attrName)83 std::string getStringAttrSafe(object& pyObject, const char* attrName) {
84   object val = pyObject.attr(attrName);
85   if (val.is_none()) {
86     return "";
87   }
88   return extract<std::string>(str(val));
89 }
90 
91 template <class T>
getIntAttr(object & pyObject,const char * attrName)92 T getIntAttr(object& pyObject, const char* attrName) {
93   object val = pyObject.attr(attrName);
94   return extract<T>(val);
95 }
96 
97 } // namespace
98 
99 class CallbackWrapper {
100  public:
call(object obj)101   void call(object obj) { callback_(obj); }
102 
setCallback(folly::Function<void (object)> && callback)103   void setCallback(folly::Function<void(object)>&& callback) {
104     callback_ = std::move(callback);
105   }
106 
107  private:
108   folly::Function<void(object)> callback_;
109 };
110 
111 class CppServerEventHandler : public TServerEventHandler {
112  public:
CppServerEventHandler(object serverEventHandler)113   explicit CppServerEventHandler(object serverEventHandler)
114       : handler_(std::make_shared<object>(serverEventHandler)) {}
115 
newConnection(TConnectionContext * ctx)116   void newConnection(TConnectionContext* ctx) override {
117     callPythonHandler(ctx, "newConnection");
118   }
119 
connectionDestroyed(TConnectionContext * ctx)120   void connectionDestroyed(TConnectionContext* ctx) override {
121     callPythonHandler(ctx, "connectionDestroyed");
122   }
123 
124  private:
callPythonHandler(TConnectionContext * ctx,const char * method)125   void callPythonHandler(TConnectionContext* ctx, const char* method) {
126     if (!_Py_IsFinalizing()) {
127       PyGILState_STATE state = PyGILState_Ensure();
128       SCOPE_EXIT { PyGILState_Release(state); };
129 
130       // This cast always succeeds because it is called from Cpp2Connection.
131       Cpp2ConnContext* cpp2Ctx = dynamic_cast<Cpp2ConnContext*>(ctx);
132       auto cd_cls = handler_->attr("CONTEXT_DATA");
133       object contextData = cd_cls();
134       extract<CppContextData&>(contextData)().copyContextContents(cpp2Ctx);
135       auto ctx_cls = handler_->attr("CPP_CONNECTION_CONTEXT");
136       object cppConnContext = ctx_cls(contextData);
137       handler_->attr(method)(cppConnContext);
138     }
139   }
140 
141   std::shared_ptr<object> handler_;
142 };
143 
144 class PythonCallTimestamps : public TServerObserver::CallTimestamps {
145  public:
set_readEndNow()146   void set_readEndNow() { readEnd = clock::now(); }
get_readEndUsec() const147   uint64_t get_readEndUsec() const noexcept {
148     return to_microseconds(readEnd.time_since_epoch());
149   }
set_processBeginNow()150   void set_processBeginNow() { processBegin = clock::now(); }
get_processBeginUsec() const151   uint64_t get_processBeginUsec() const noexcept {
152     return to_microseconds(processBegin.time_since_epoch());
153   }
set_processEndNow()154   void set_processEndNow() { processEnd = clock::now(); }
get_processEndUsec() const155   uint64_t get_processEndUsec() const noexcept {
156     return to_microseconds(processEnd.time_since_epoch());
157   }
set_writeBeginNow()158   void set_writeBeginNow() { writeBegin = clock::now(); }
get_writeBeginUsec() const159   uint64_t get_writeBeginUsec() const noexcept {
160     return to_microseconds(writeBegin.time_since_epoch());
161   }
set_writeEndNow()162   void set_writeEndNow() { writeEnd = clock::now(); }
get_writeEndUsec() const163   uint64_t get_writeEndUsec() const noexcept {
164     return to_microseconds(writeEnd.time_since_epoch());
165   }
166 };
167 
168 class CppServerObserver : public TServerObserver {
169  public:
CppServerObserver(object serverObserver)170   explicit CppServerObserver(object serverObserver)
171       : observer_(serverObserver) {}
172 
connAccepted(const wangle::TransportInfo &)173   void connAccepted(const wangle::TransportInfo& /* info */) override {
174     this->call("connAccepted");
175   }
connDropped()176   void connDropped() override { this->call("connDropped"); }
connRejected()177   void connRejected() override { this->call("connRejected"); }
tlsError()178   void tlsError() override { this->call("tlsError"); }
tlsComplete()179   void tlsComplete() override { this->call("tlsComplete"); }
tlsFallback()180   void tlsFallback() override { this->call("tlsFallback"); }
tlsResumption()181   void tlsResumption() override { this->call("tlsResumption"); }
taskKilled()182   void taskKilled() override { this->call("taskKilled"); }
taskTimeout()183   void taskTimeout() override { this->call("taskTimeout"); }
serverOverloaded()184   void serverOverloaded() override { this->call("serverOverloaded"); }
receivedRequest(const std::string *)185   void receivedRequest(const std::string* /*method*/) override {
186     this->call("receivedRequest");
187   }
admittedRequest(const std::string *)188   void admittedRequest(const std::string* /*method*/) override {
189     this->call("admittedRequest");
190   }
queuedRequests(int32_t n)191   void queuedRequests(int32_t n) override { this->call("queuedRequests", n); }
queueTimeout()192   void queueTimeout() override { this->call("queueTimeout"); }
sentReply()193   void sentReply() override { this->call("sentReply"); }
activeRequests(int32_t n)194   void activeRequests(int32_t n) override { this->call("activeRequests", n); }
callCompleted(const CallTimestamps & runtimes)195   void callCompleted(const CallTimestamps& runtimes) override {
196     this->call(
197         "callCompleted",
198         reinterpret_cast<const PythonCallTimestamps&>(runtimes));
199   }
tlsWithClientCert()200   void tlsWithClientCert() override { this->call("tlsWithClientCert"); }
201 
202  private:
203   template <class... Types>
call(const char * method_name,Types...args)204   void call(const char* method_name, Types... args) {
205     PyGILState_STATE state = PyGILState_Ensure();
206     SCOPE_EXIT { PyGILState_Release(state); };
207 
208     // check if the object has an attribute, because we want to be accepting
209     // if we added a new listener callback and didn't yet update call the
210     // people using this interface.
211     if (!PyObject_HasAttrString(observer_.ptr(), method_name)) {
212       return;
213     }
214 
215     try {
216       (void)observer_.attr(method_name)(args...);
217     } catch (const error_already_set&) {
218       // print the error to sys.stderr and carry on, because raising here
219       // would break the server protocol, and raising in Python later
220       // would be extremely disconnected and confusing since it would
221       // happen in apparently unconnected Python code.
222       PyErr_Print();
223     }
224   }
225 
226   object observer_;
227 };
228 
229 class PythonAsyncProcessor : public AsyncProcessor {
230  public:
PythonAsyncProcessor(std::shared_ptr<object> adapter)231   explicit PythonAsyncProcessor(std::shared_ptr<object> adapter)
232       : adapter_(adapter) {
233     getPythonOnewayMethods();
234   }
235 
236   // Create a task and add it to thread manager's queue. Essentially the same
237   // as GeneratedAsyncProcessor's processInThread method.
processSerializedRequest(ResponseChannelRequest::UniquePtr req,apache::thrift::SerializedRequest && serializedRequest,apache::thrift::protocol::PROTOCOL_TYPES protType,Cpp2RequestContext * context,folly::EventBase * eb,apache::thrift::concurrency::ThreadManager * tm)238   void processSerializedRequest(
239       ResponseChannelRequest::UniquePtr req,
240       apache::thrift::SerializedRequest&& serializedRequest,
241       apache::thrift::protocol::PROTOCOL_TYPES protType,
242       Cpp2RequestContext* context,
243       folly::EventBase* eb,
244       apache::thrift::concurrency::ThreadManager* tm) override {
245     auto fname = context->getMethodName();
246     bool oneway = isOnewayMethod(fname);
247 
248     if (oneway && !req->isOneway()) {
249       req->sendReply(ResponsePayload{});
250     }
251 
252     apache::thrift::LegacyRequestExpiryGuard rh{std::move(req), eb};
253     auto task = [=,
254                  buf = apache::thrift::LegacySerializedRequest(
255                            protType,
256                            context->getProtoSeqId(),
257                            context->getMethodName(),
258                            std::move(serializedRequest))
259                            .buffer,
260                  rh = std::move(rh)]() mutable {
261       auto req_up = std::move(rh.req);
262       SCOPE_EXIT {
263         rh.eb->runInEventBaseThread(
264             [req_up = std::move(req_up)]() mutable { req_up = {}; });
265       };
266 
267       if (!oneway && !req_up->getShouldStartProcessing()) {
268         return;
269       }
270 
271       folly::ByteRange input_range = buf->coalesce();
272       auto input_data = const_cast<unsigned char*>(input_range.data());
273       auto clientType = context->getHeader()->getClientType();
274 
275       {
276         PyGILState_STATE state = PyGILState_Ensure();
277         SCOPE_EXIT { PyGILState_Release(state); };
278 
279 #if PY_MAJOR_VERSION == 2
280         auto input =
281             handle<>(PyBuffer_FromMemory(input_data, input_range.size()));
282 #else
283         auto input = handle<>(PyMemoryView_FromMemory(
284             reinterpret_cast<char*>(input_data),
285             input_range.size(),
286             PyBUF_READ));
287 #endif
288 
289         auto cd_ctor = adapter_->attr("CONTEXT_DATA");
290         object contextData = cd_ctor();
291         extract<CppContextData&>(contextData)().copyContextContents(context);
292 
293         auto cb_ctor = adapter_->attr("CALLBACK_WRAPPER");
294         object callbackWrapper = cb_ctor();
295         extract<CallbackWrapper&>(callbackWrapper)().setCallback(
296             [oneway,
297              req_up = std::move(req_up),
298              context,
299              eb = rh.eb,
300              contextData,
301              protType](object output) mutable {
302               // Make sure the request is deleted in evb.
303               SCOPE_EXIT {
304                 eb->runInEventBaseThread(
305                     [req_up = std::move(req_up)]() mutable { req_up = {}; });
306               };
307 
308               // Always called from python so no need to grab GIL.
309               try {
310                 std::unique_ptr<folly::IOBuf> outbuf;
311                 if (output.is_none()) {
312                   throw std::runtime_error(
313                       "Unexpected error in processor method");
314                 }
315                 PyObject* output_ptr = output.ptr();
316 #if PY_MAJOR_VERSION == 2
317                 if (PyString_Check(output_ptr)) {
318                   int len = extract<int>(output.attr("__len__")());
319                   if (len == 0) {
320                     return;
321                   }
322                   outbuf = folly::IOBuf::copyBuffer(
323                       extract<const char*>(output), len);
324                 } else
325 #endif
326                     if (PyBytes_Check(output_ptr)) {
327                   int len = PyBytes_Size(output_ptr);
328                   if (len == 0) {
329                     return;
330                   }
331                   outbuf = folly::IOBuf::copyBuffer(
332                       PyBytes_AsString(output_ptr), len);
333                 } else {
334                   throw std::runtime_error(
335                       "Return from processor "
336                       "method is not string or bytes");
337                 }
338 
339                 if (!req_up->isActive()) {
340                   return;
341                 }
342                 CppContextData& cppContextData =
343                     extract<CppContextData&>(contextData);
344                 if (!cppContextData.getHeaderEx().empty()) {
345                   context->getHeader()->setHeader(
346                       kHeaderEx, cppContextData.getHeaderEx());
347                 }
348                 if (!cppContextData.getHeaderExWhat().empty()) {
349                   context->getHeader()->setHeader(
350                       kHeaderExWhat, cppContextData.getHeaderExWhat());
351                 }
352                 auto response = LegacySerializedResponse{std::move(outbuf)};
353                 auto [mtype, payload] = std::move(response).extractPayload(
354                     req_up->includeEnvelope(), protType);
355                 payload.transform(context->getHeader()->getWriteTransforms());
356                 eb->runInEventBaseThread(
357                     [mtype = mtype,
358                      req_up = std::move(req_up),
359                      payload = std::move(payload)]() mutable {
360                       if (mtype == MessageType::T_REPLY) {
361                         req_up->sendReply(std::move(payload));
362                       } else if (mtype == MessageType::T_EXCEPTION) {
363                         req_up->sendException(std::move(payload));
364                       } else {
365                         LOG(ERROR) << "Invalid type. type=" << uint16_t(mtype);
366                       }
367                     });
368               } catch (const std::exception& e) {
369                 if (!oneway) {
370                   req_up->sendErrorWrapped(
371                       folly::make_exception_wrapper<TApplicationException>(
372                           folly::to<std::string>(
373                               "Failed to read response from Python:",
374                               e.what())),
375                       "python");
376                 }
377               }
378             });
379 
380         adapter_->attr("call_processor")(
381             input,
382             makePythonHeaders(context->getHeader()->getHeaders(), context),
383             int(clientType),
384             int(protType),
385             contextData,
386             callbackWrapper);
387       }
388     };
389 
390     using PriorityThreadManager =
391         apache::thrift::concurrency::PriorityThreadManager;
392     auto ptm = dynamic_cast<PriorityThreadManager*>(tm);
393     if (ptm != nullptr) {
394       ptm->add(
395           getMethodPriority(fname, context),
396           std::make_shared<apache::thrift::concurrency::FunctionRunner>(
397               std::move(task)));
398       return;
399     }
400     tm->add(std::move(task));
401   }
402 
403   /**
404    * Get the priority of the request
405    * Check the headers directly in C++ since noone seems to override that logic
406    * Ask python if no priority headers were supplied with the request
407    */
getMethodPriority(std::string const & fname,Cpp2RequestContext * ctx=nullptr)408   concurrency::PRIORITY getMethodPriority(
409       std::string const& fname, Cpp2RequestContext* ctx = nullptr) {
410     if (ctx) {
411       auto requestPriority = ctx->getCallPriority();
412       if (requestPriority != concurrency::PRIORITY::N_PRIORITIES) {
413         VLOG(3) << "Request priority from headers";
414         return requestPriority;
415       }
416     }
417 
418     PyGILState_STATE state = PyGILState_Ensure();
419     SCOPE_EXIT { PyGILState_Release(state); };
420 
421     try {
422       return static_cast<concurrency::PRIORITY>(
423           extract<int>(adapter_->attr("get_priority")(fname))());
424     } catch (error_already_set&) {
425       // get_priority doesn't exist, or it threw an exception
426       LOG(ERROR) << "Error while calling _ProcessorAdapter.get_priority()";
427       PyErr_Print();
428     }
429 
430     return concurrency::PRIORITY::NORMAL;
431   }
432 
433  private:
isOnewayMethod(std::string const & fname)434   bool isOnewayMethod(std::string const& fname) {
435     return onewayMethods_.find(fname) != onewayMethods_.end();
436   }
437 
getPythonOnewayMethods()438   void getPythonOnewayMethods() {
439     PyGILState_STATE state = PyGILState_Ensure();
440     SCOPE_EXIT { PyGILState_Release(state); };
441     object ret = adapter_->attr("oneway_methods")();
442     if (ret.is_none()) {
443       LOG(ERROR) << "Unexpected error in processor method";
444       return;
445     }
446     tuple t = extract<tuple>(ret);
447     for (int i = 0; i < len(t); i++) {
448       onewayMethods_.insert(extract<std::string>(t[i]));
449     }
450   }
451 
452   std::shared_ptr<object> adapter_;
453   std::unordered_set<std::string> onewayMethods_;
454 };
455 
456 class PythonAsyncProcessorFactory : public AsyncProcessorFactory {
457  public:
PythonAsyncProcessorFactory(std::shared_ptr<object> adapter)458   explicit PythonAsyncProcessorFactory(std::shared_ptr<object> adapter)
459       : adapter_(adapter) {}
460 
getProcessor()461   std::unique_ptr<apache::thrift::AsyncProcessor> getProcessor() override {
462     return std::make_unique<PythonAsyncProcessor>(adapter_);
463   }
464 
465   // TODO(T89004867): Call onStartServing() and onStopServing() hooks for
466   // non-C++ thrift servers
getServiceHandlers()467   std::vector<apache::thrift::ServiceHandler*> getServiceHandlers() override {
468     return {};
469   }
470 
471  private:
472   std::shared_ptr<object> adapter_;
473 };
474 
475 class CppServerWrapper : public ThriftServer {
476  public:
CppServerWrapper()477   CppServerWrapper() {
478     BaseThriftServer::metadata().wrapper = "CppServerWrapper-py";
479   }
480 
setAdapter(object adapter)481   void setAdapter(object adapter) {
482     // We use a shared_ptr to manage the adapter so the processor
483     // factory handing won't ever try to manipulate python reference
484     // counts without the GIL.
485     setProcessorFactory(std::make_unique<PythonAsyncProcessorFactory>(
486         std::make_shared<object>(adapter)));
487   }
488 
489   // peer to setObserver, but since we want a different argument, avoid
490   // shadowing in our parent class.
setObserverFromPython(object observer)491   void setObserverFromPython(object observer) {
492     setObserver(std::make_shared<CppServerObserver>(observer));
493   }
494 
getAddress()495   object getAddress() { return makePythonAddress(ThriftServer::getAddress()); }
496 
loop()497   void loop() {
498     PyThreadState* save_state = PyEval_SaveThread();
499     SCOPE_EXIT { PyEval_RestoreThread(save_state); };
500 
501     // Thrift main loop.  This will run indefinitely, until stop() is
502     // called.
503 
504     getServeEventBase()->loopForever();
505   }
506 
setup()507   void setup() {
508     PyThreadState* save_state = PyEval_SaveThread();
509     SCOPE_EXIT { PyEval_RestoreThread(save_state); };
510 
511     // This check is only useful for C++-based Thrift servers.
512     ThriftServer::setAllowCheckUnimplementedExtraInterfaces(false);
513     ThriftServer::setup();
514   }
515 
setCppSSLConfig(object sslConfig)516   void setCppSSLConfig(object sslConfig) {
517     auto certPath = getStringAttrSafe(sslConfig, "cert_path");
518     auto keyPath = getStringAttrSafe(sslConfig, "key_path");
519     if (certPath.empty() ^ keyPath.empty()) {
520       PyErr_SetString(
521           PyExc_ValueError, "certPath and keyPath must both be populated");
522       throw_error_already_set();
523       return;
524     }
525     auto cfg = std::make_shared<SSLContextConfig>();
526     cfg->clientCAFile = getStringAttrSafe(sslConfig, "client_ca_path");
527     if (!certPath.empty()) {
528       auto keyPwPath = getStringAttrSafe(sslConfig, "key_pw_path");
529       cfg->setCertificate(certPath, keyPath, keyPwPath);
530     }
531     cfg->clientVerification =
532         extract<SSLContext::VerifyClientCertificate>(sslConfig.attr("verify"));
533     auto eccCurve = getStringAttrSafe(sslConfig, "ecc_curve_name");
534     if (!eccCurve.empty()) {
535       cfg->eccCurveName = eccCurve;
536     }
537     object sessionContext = sslConfig.attr("session_context");
538     if (!sessionContext.is_none()) {
539       cfg->sessionContext = extract<std::string>(str(sessionContext));
540     }
541 
542     object sslVersionAttr = sslConfig.attr("ssl_version");
543     if (!sslVersionAttr.is_none()) {
544       cfg->sslVersion =
545           extract<SSLContext::SSLVersion>(sslConfig.attr("ssl_version"));
546     }
547 
548     ThriftServer::setSSLConfig(folly::observer::makeObserver(
549         [cfg, nextProtocolsObserver = ThriftServer::defaultNextProtocols()] {
550           auto cfgWithNextProtocols = *cfg;
551           cfgWithNextProtocols.setNextProtocols(**nextProtocolsObserver);
552           return cfgWithNextProtocols;
553         }));
554 
555     setSSLPolicy(extract<SSLPolicy>(sslConfig.attr("ssl_policy")));
556 
557     auto ticketFilePath = getStringAttrSafe(sslConfig, "ticket_file_path");
558     ThriftServer::watchTicketPathForChanges(ticketFilePath);
559   }
560 
setCppFastOpenOptions(object enabledObj,object tfoMaxQueueObj)561   void setCppFastOpenOptions(object enabledObj, object tfoMaxQueueObj) {
562     bool enabled{extract<bool>(enabledObj)};
563     uint32_t tfoMaxQueue{extract<uint32_t>(tfoMaxQueueObj)};
564     ThriftServer::setFastOpenOptions(enabled, tfoMaxQueue);
565   }
566 
useCppExistingSocket(int socket)567   void useCppExistingSocket(int socket) {
568     ThriftServer::useExistingSocket(socket);
569   }
570 
setCppSSLCacheOptions(object cacheOptions)571   void setCppSSLCacheOptions(object cacheOptions) {
572     SSLCacheOptions options = {
573         .sslCacheTimeout = std::chrono::seconds(
574             getIntAttr<uint32_t>(cacheOptions, "ssl_cache_timeout_seconds")),
575         .maxSSLCacheSize =
576             getIntAttr<uint64_t>(cacheOptions, "max_ssl_cache_size"),
577         .sslCacheFlushSize =
578             getIntAttr<uint64_t>(cacheOptions, "ssl_cache_flush_size"),
579         .handshakeValidity = std::chrono::seconds(getIntAttr<uint32_t>(
580             cacheOptions, "ssl_handshake_validity_seconds")),
581     };
582     ThriftServer::setSSLCacheOptions(std::move(options));
583   }
584 
getCppTicketSeeds()585   object getCppTicketSeeds() {
586     auto seeds = getTicketSeeds();
587     if (!seeds) {
588       return boost::python::object();
589     }
590     boost::python::dict result;
591     result["old"] = makePythonList(seeds->oldSeeds);
592     result["current"] = makePythonList(seeds->currentSeeds);
593     result["new"] = makePythonList(seeds->newSeeds);
594     return std::move(result);
595   }
596 
cleanUp()597   void cleanUp() {
598     // Deadlock avoidance: consider a thrift worker thread is doing
599     // something in C++-land having relinquished the GIL.  This thread
600     // acquires the GIL, stops the workers, and waits for the worker
601     // threads to complete.  The worker thread now finishes its work,
602     // and tries to reacquire the GIL, but deadlocks with the current
603     // thread, which holds the GIL and is waiting for the worker to
604     // complete.  So we do cleanUp() without the GIL, and reacquire it
605     // only once thrift is all cleaned up.
606 
607     PyThreadState* save_state = PyEval_SaveThread();
608     SCOPE_EXIT { PyEval_RestoreThread(save_state); };
609     ThriftServer::cleanUp();
610   }
611 
setIdleTimeout(int timeout)612   void setIdleTimeout(int timeout) {
613     std::chrono::milliseconds ms(timeout);
614     ThriftServer::setIdleTimeout(ms, AttributeSource::OVERRIDE);
615   }
616 
setTaskExpireTime(int timeout)617   void setTaskExpireTime(int timeout) {
618     std::chrono::milliseconds ms(timeout);
619     ThriftServer::setTaskExpireTime(ms, AttributeSource::OVERRIDE);
620   }
621 
setCppServerEventHandler(object serverEventHandler)622   void setCppServerEventHandler(object serverEventHandler) {
623     setServerEventHandler(
624         std::make_shared<CppServerEventHandler>(serverEventHandler));
625   }
626 
setNewSimpleThreadManager(size_t count,size_t)627   void setNewSimpleThreadManager(size_t count, size_t) {
628     auto tm = ThreadManager::newSimpleThreadManager(count);
629     auto poolThreadName = getCPUWorkerThreadName();
630     if (!poolThreadName.empty()) {
631       tm->setNamePrefix(poolThreadName);
632     }
633 
634     tm->threadFactory(std::make_shared<PosixThreadFactory>());
635     tm->start();
636     setThreadManager(std::move(tm));
637   }
638 
setNewPriorityQueueThreadManager(size_t numThreads)639   void setNewPriorityQueueThreadManager(size_t numThreads) {
640     auto tm = ThreadManager::newPriorityQueueThreadManager(numThreads);
641     auto poolThreadName = getCPUWorkerThreadName();
642     if (!poolThreadName.empty()) {
643       tm->setNamePrefix(poolThreadName);
644     }
645 
646     tm->threadFactory(std::make_shared<PosixThreadFactory>());
647     tm->start();
648     setThreadManager(std::move(tm));
649   }
650 
setNewPriorityThreadManager(size_t high_important,size_t high,size_t important,size_t normal,size_t best_effort,size_t)651   void setNewPriorityThreadManager(
652       size_t high_important,
653       size_t high,
654       size_t important,
655       size_t normal,
656       size_t best_effort,
657       size_t) {
658     auto tm = PriorityThreadManager::newPriorityThreadManager(
659         {{high_important, high, important, normal, best_effort}});
660     tm->enableCodel(getEnableCodel());
661     auto poolThreadName = getCPUWorkerThreadName();
662     if (!poolThreadName.empty()) {
663       tm->setNamePrefix(poolThreadName);
664     }
665 
666     tm->threadFactory(std::make_shared<PosixThreadFactory>());
667     tm->start();
668     setThreadManager(std::move(tm));
669   }
670 
671   // this adapts from a std::shared_ptr, which boost::python does not (yet)
672   // support, to a boost::shared_ptr, which it has internal support for.
673   //
674   // the magic is in the custom deleter which takes and releases a refcount on
675   // the std::shared_ptr, instead of doing any local deletion.
getThreadManagerHelper()676   boost::shared_ptr<ThreadManager> getThreadManagerHelper() {
677     auto ptr = this->getThreadManager();
678     return boost::shared_ptr<ThreadManager>(ptr.get(), [ptr](void*) {});
679   }
680 
setWorkersJoinTimeout(int seconds)681   void setWorkersJoinTimeout(int seconds) {
682     ThriftServer::setWorkersJoinTimeout(std::chrono::seconds(seconds));
683   }
684 
setNumIOWorkerThreads(size_t numIOWorkerThreads)685   void setNumIOWorkerThreads(size_t numIOWorkerThreads) {
686     BaseThriftServer::setNumIOWorkerThreads(
687         numIOWorkerThreads, AttributeSource::OVERRIDE);
688   }
689 
setListenBacklog(int listenBacklog)690   void setListenBacklog(int listenBacklog) {
691     BaseThriftServer::setListenBacklog(
692         listenBacklog, AttributeSource::OVERRIDE);
693   }
694 
setMaxConnections(uint32_t maxConnections)695   void setMaxConnections(uint32_t maxConnections) {
696     BaseThriftServer::setMaxConnections(
697         maxConnections, AttributeSource::OVERRIDE);
698   }
699 
setNumCPUWorkerThreads(size_t numCPUWorkerThreads)700   void setNumCPUWorkerThreads(size_t numCPUWorkerThreads) {
701     BaseThriftServer::setNumCPUWorkerThreads(
702         numCPUWorkerThreads, AttributeSource::OVERRIDE);
703   }
704 
setEnableCodel(bool enableCodel)705   void setEnableCodel(bool enableCodel) {
706     BaseThriftServer::setEnableCodel(enableCodel, AttributeSource::OVERRIDE);
707   }
708 
setWrapperName(object wrapperName)709   void setWrapperName(object wrapperName) {
710     BaseThriftServer::metadata().wrapper =
711         extract<std::string>(str(wrapperName));
712   }
713 
setLanguageFrameworkName(object languageFrameworkName)714   void setLanguageFrameworkName(object languageFrameworkName) {
715     BaseThriftServer::metadata().languageFramework =
716         extract<std::string>(str(languageFrameworkName));
717   }
718 
setUnixSocketPath(const char * path)719   void setUnixSocketPath(const char* path) {
720     setAddress(folly::SocketAddress::makeFromPath(path));
721   }
722 };
723 
BOOST_PYTHON_MODULE(CppServerWrapper)724 BOOST_PYTHON_MODULE(CppServerWrapper) {
725   PyEval_InitThreads();
726 
727   class_<CppContextData>("CppContextData")
728       .def("getClientIdentity", &CppContextData::getClientIdentity)
729       .def("getPeerAddress", &CppContextData::getPeerAddress)
730       .def("getLocalAddress", &CppContextData::getLocalAddress)
731       .def("setHeaderEx", &CppContextData::setHeaderEx)
732       .def("setHeaderExWhat", &CppContextData::setHeaderExWhat);
733 
734   class_<CallbackWrapper, boost::noncopyable>("CallbackWrapper")
735       .def("call", &CallbackWrapper::call);
736 
737   class_<ThriftServer, boost::noncopyable>("ThriftServer");
738 
739   class_<CppServerWrapper, bases<ThriftServer>, boost::noncopyable>(
740       "CppServerWrapper")
741       // methods added or customized for the python implementation
742       .def("setAdapter", &CppServerWrapper::setAdapter)
743       .def(
744           "setAddress",
745           static_cast<void (CppServerWrapper::*)(std::string const&, uint16_t)>(
746               &CppServerWrapper::setAddress))
747       .def("setUnixSocketPath", &CppServerWrapper::setUnixSocketPath)
748       .def("setObserver", &CppServerWrapper::setObserverFromPython)
749       .def("setIdleTimeout", &CppServerWrapper::setIdleTimeout)
750       .def("setTaskExpireTime", &CppServerWrapper::setTaskExpireTime)
751       .def("getAddress", &CppServerWrapper::getAddress)
752       .def("getPort", &CppServerWrapper::getPort)
753       .def("loop", &CppServerWrapper::loop)
754       .def("cleanUp", &CppServerWrapper::cleanUp)
755       .def(
756           "setCppServerEventHandler",
757           &CppServerWrapper::setCppServerEventHandler)
758       .def(
759           "setNewSimpleThreadManager",
760           &CppServerWrapper::setNewSimpleThreadManager,
761           (arg("count"), arg("pendingTaskCountMax")))
762       .def(
763           "setNewPriorityQueueThreadManager",
764           &CppServerWrapper::setNewPriorityQueueThreadManager,
765           (arg("numThreads")))
766       .def(
767           "setNewPriorityThreadManager",
768           &CppServerWrapper::setNewPriorityThreadManager,
769           (arg("high_important"),
770            arg("high"),
771            arg("important"),
772            arg("normal"),
773            arg("best_effort"),
774            arg("maxQueueLen") = 0))
775       .def("setCppSSLConfig", &CppServerWrapper::setCppSSLConfig)
776       .def("setCppSSLCacheOptions", &CppServerWrapper::setCppSSLCacheOptions)
777       .def("setCppFastOpenOptions", &CppServerWrapper::setCppFastOpenOptions)
778       .def("getCppTicketSeeds", &CppServerWrapper::getCppTicketSeeds)
779       .def("setWorkersJoinTimeout", &CppServerWrapper::setWorkersJoinTimeout)
780       .def("useCppExistingSocket", &CppServerWrapper::useCppExistingSocket)
781 
782       // methods directly passed to the C++ impl
783       .def("setup", &CppServerWrapper::setup)
784       .def("setNumCPUWorkerThreads", &CppServerWrapper::setNumCPUWorkerThreads)
785       .def("setNumIOWorkerThreads", &CppServerWrapper::setNumIOWorkerThreads)
786       .def("setListenBacklog", &CppServerWrapper::setListenBacklog)
787       .def("setPort", &CppServerWrapper::setPort)
788       .def("setReusePort", &CppServerWrapper::setReusePort)
789       .def("stop", &CppServerWrapper::stop)
790       .def("setMaxConnections", &CppServerWrapper::setMaxConnections)
791       .def("getMaxConnections", &CppServerWrapper::getMaxConnections)
792       .def("setEnabled", &CppServerWrapper::setEnabled)
793 
794       .def("getLoad", &CppServerWrapper::getLoad)
795       .def("getActiveRequests", &CppServerWrapper::getActiveRequests)
796       .def("getThreadManager", &CppServerWrapper::getThreadManagerHelper)
797       .def("setWrapperName", &CppServerWrapper::setWrapperName)
798       .def(
799           "setLanguageFrameworkName",
800           &CppServerWrapper::setLanguageFrameworkName);
801 
802   class_<ThreadManager, boost::shared_ptr<ThreadManager>, boost::noncopyable>(
803       "ThreadManager", no_init)
804       .def("idleWorkerCount", &ThreadManager::idleWorkerCount)
805       .def("workerCount", &ThreadManager::workerCount)
806       .def("pendingTaskCount", &ThreadManager::pendingTaskCount)
807       .def("pendingUpstreamTaskCount", &ThreadManager::pendingUpstreamTaskCount)
808       .def("totalTaskCount", &ThreadManager::totalTaskCount)
809       .def("expiredTaskCount", &ThreadManager::expiredTaskCount)
810       .def("clearPending", &ThreadManager::clearPending);
811 
812   class_<PythonCallTimestamps>("CallTimestamps")
813       .def("getReadEnd", &PythonCallTimestamps::get_readEndUsec)
814       .def("setReadEndNow", &PythonCallTimestamps::set_readEndNow)
815       .def("getProcessBegin", &PythonCallTimestamps::get_processBeginUsec)
816       .def("setProcessBeginNow", &PythonCallTimestamps::set_processBeginNow)
817       .def("getProcessEnd", &PythonCallTimestamps::get_processEndUsec)
818       .def("setProcessEndNow", &PythonCallTimestamps::set_processEndNow)
819       .def("getWriteBegin", &PythonCallTimestamps::get_writeBeginUsec)
820       .def("setWriteBeginNow", &PythonCallTimestamps::set_writeBeginNow)
821       .def("getWriteEnd", &PythonCallTimestamps::get_writeEndUsec)
822       .def("setWriteEndNow", &PythonCallTimestamps::set_writeEndNow);
823 
824   enum_<SSLPolicy>("SSLPolicy")
825       .value("DISABLED", SSLPolicy::DISABLED)
826       .value("PERMITTED", SSLPolicy::PERMITTED)
827       .value("REQUIRED", SSLPolicy::REQUIRED);
828 
829   enum_<folly::SSLContext::VerifyClientCertificate>("VerifyClientCertificate")
830       .value(
831           "IF_PRESENTED",
832           folly::SSLContext::VerifyClientCertificate::IF_PRESENTED)
833       .value(
834           "ALWAYS_VERIFY", folly::SSLContext::VerifyClientCertificate::ALWAYS)
835       .value(
836           "NONE_DO_NOT_REQUEST",
837           folly::SSLContext::VerifyClientCertificate::DO_NOT_REQUEST);
838 
839   enum_<folly::SSLContext::SSLVersion>("SSLVersion")
840       .value("TLSv1_2", folly::SSLContext::SSLVersion::TLSv1_2);
841 }
842