1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <thrift/lib/cpp2/transport/rocket/test/util/TestUtil.h>
18
19 #include <thrift/lib/cpp2/async/RocketClientChannel.h>
20 #include <thrift/lib/cpp2/transport/core/testutil/TAsyncSocketIntercepted.h>
21
22 DECLARE_int32(num_client_connections);
23 DECLARE_string(transport); // ConnectionManager depends on this flag.
24
25 namespace apache {
26 namespace thrift {
27
createServer(std::shared_ptr<AsyncProcessorFactory> processorFactory,uint16_t & port,int maxRequests,std::string transport)28 std::unique_ptr<ThriftServer> TestSetup::createServer(
29 std::shared_ptr<AsyncProcessorFactory> processorFactory,
30 uint16_t& port,
31 int maxRequests,
32 std::string transport) {
33 // override the default
34 FLAGS_transport = transport; // client's transport
35 observer_ = std::make_shared<FakeServerObserver>();
36
37 auto server = std::make_unique<ThriftServer>();
38 if (maxRequests > 0) {
39 server->setMaxRequests(maxRequests);
40 }
41 server->setObserver(observer_);
42 server->setPort(0);
43 server->setNumIOWorkerThreads(numIOThreads_);
44 server->setNumCPUWorkerThreads(numWorkerThreads_);
45 if (queueTimeout_.has_value()) {
46 server->setQueueTimeout(*queueTimeout_);
47 }
48 if (idleTimeout_.has_value()) {
49 server->setIdleTimeout(*idleTimeout_);
50 }
51 if (taskExpireTime_.has_value()) {
52 server->setTaskExpireTime(*taskExpireTime_);
53 }
54 if (streamExpireTime_.has_value()) {
55 server->setStreamExpireTime(*streamExpireTime_);
56 }
57
58 server->setProcessorFactory(processorFactory);
59
60 auto eventHandler = std::make_shared<TestEventHandler>();
61 server->setServerEventHandler(eventHandler);
62 server->setup();
63
64 // Get the port that the server has bound to
65 port = eventHandler->waitForPortAssignment();
66 return server;
67 }
68
connectToServer(uint16_t port,folly::Function<void ()> onDetachable,folly::Function<void (TAsyncSocketIntercepted &)> socketSetup)69 RequestChannel::Ptr TestSetup::connectToServer(
70 uint16_t port,
71 folly::Function<void()> onDetachable,
72 folly::Function<void(TAsyncSocketIntercepted&)> socketSetup) {
73 CHECK_GT(port, 0) << "Check if the server has started already";
74 return PooledRequestChannel::newChannel(
75 evbThread_.getEventBase(),
76 ioThread_,
77 [port,
78 onDetachable = std::move(onDetachable),
79 socketSetup = std::move(socketSetup)](folly::EventBase& evb) mutable
80 -> std::unique_ptr<ClientChannel, folly::DelayedDestruction::Destructor> {
81 auto socket = folly::AsyncSocket::UniquePtr(
82 new TAsyncSocketIntercepted(&evb, "::1", port));
83 if (socketSetup) {
84 socketSetup(*static_cast<TAsyncSocketIntercepted*>(socket.get()));
85 }
86
87 ClientChannel::Ptr channel =
88 RocketClientChannel::newChannel(std::move(socket));
89
90 if (onDetachable) {
91 channel->setOnDetachable(std::move(onDetachable));
92 }
93 return channel;
94 });
95 }
96
97 } // namespace thrift
98 } // namespace apache
99