1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <thrift/lib/cpp2/transport/rocket/test/util/TestUtil.h>
18 
19 #include <thrift/lib/cpp2/async/RocketClientChannel.h>
20 #include <thrift/lib/cpp2/transport/core/testutil/TAsyncSocketIntercepted.h>
21 
22 DECLARE_int32(num_client_connections);
23 DECLARE_string(transport); // ConnectionManager depends on this flag.
24 
25 namespace apache {
26 namespace thrift {
27 
createServer(std::shared_ptr<AsyncProcessorFactory> processorFactory,uint16_t & port,int maxRequests,std::string transport)28 std::unique_ptr<ThriftServer> TestSetup::createServer(
29     std::shared_ptr<AsyncProcessorFactory> processorFactory,
30     uint16_t& port,
31     int maxRequests,
32     std::string transport) {
33   // override the default
34   FLAGS_transport = transport; // client's transport
35   observer_ = std::make_shared<FakeServerObserver>();
36 
37   auto server = std::make_unique<ThriftServer>();
38   if (maxRequests > 0) {
39     server->setMaxRequests(maxRequests);
40   }
41   server->setObserver(observer_);
42   server->setPort(0);
43   server->setNumIOWorkerThreads(numIOThreads_);
44   server->setNumCPUWorkerThreads(numWorkerThreads_);
45   if (queueTimeout_.has_value()) {
46     server->setQueueTimeout(*queueTimeout_);
47   }
48   if (idleTimeout_.has_value()) {
49     server->setIdleTimeout(*idleTimeout_);
50   }
51   if (taskExpireTime_.has_value()) {
52     server->setTaskExpireTime(*taskExpireTime_);
53   }
54   if (streamExpireTime_.has_value()) {
55     server->setStreamExpireTime(*streamExpireTime_);
56   }
57 
58   server->setProcessorFactory(processorFactory);
59 
60   auto eventHandler = std::make_shared<TestEventHandler>();
61   server->setServerEventHandler(eventHandler);
62   server->setup();
63 
64   // Get the port that the server has bound to
65   port = eventHandler->waitForPortAssignment();
66   return server;
67 }
68 
connectToServer(uint16_t port,folly::Function<void ()> onDetachable,folly::Function<void (TAsyncSocketIntercepted &)> socketSetup)69 RequestChannel::Ptr TestSetup::connectToServer(
70     uint16_t port,
71     folly::Function<void()> onDetachable,
72     folly::Function<void(TAsyncSocketIntercepted&)> socketSetup) {
73   CHECK_GT(port, 0) << "Check if the server has started already";
74   return PooledRequestChannel::newChannel(
75       evbThread_.getEventBase(),
76       ioThread_,
77       [port,
78        onDetachable = std::move(onDetachable),
79        socketSetup = std::move(socketSetup)](folly::EventBase& evb) mutable
80       -> std::unique_ptr<ClientChannel, folly::DelayedDestruction::Destructor> {
81         auto socket = folly::AsyncSocket::UniquePtr(
82             new TAsyncSocketIntercepted(&evb, "::1", port));
83         if (socketSetup) {
84           socketSetup(*static_cast<TAsyncSocketIntercepted*>(socket.get()));
85         }
86 
87         ClientChannel::Ptr channel =
88             RocketClientChannel::newChannel(std::move(socket));
89 
90         if (onDetachable) {
91           channel->setOnDetachable(std::move(onDetachable));
92         }
93         return channel;
94       });
95 }
96 
97 } // namespace thrift
98 } // namespace apache
99