1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Mocks.h"
18 
19 using namespace folly;
20 using namespace testing;
21 using namespace wangle;
22 
23 using TestServer = ServerBootstrap<DefaultPipeline>;
24 using TestClient = ClientBootstrap<DefaultPipeline>;
25 
26 class TestClientPipelineFactory : public PipelineFactory<DefaultPipeline> {
27  public:
newPipeline(std::shared_ptr<AsyncTransport> socket)28   DefaultPipeline::Ptr newPipeline(
29       std::shared_ptr<AsyncTransport> socket) override {
30     // Socket should be connected already
31     EXPECT_TRUE(socket->good());
32 
33     auto pipeline = DefaultPipeline::create();
34     pipeline->addBack(wangle::AsyncSocketHandler(socket));
35     pipeline->finalize();
36     return pipeline;
37   }
38 };
39 
40 class AcceptRoutingHandlerTest : public Test {
41  public:
SetUp()42   void SetUp() override {
43     routingData_.routingData = 'A';
44 
45     downstreamHandler_ = new MockBytesToBytesHandler();
46     downstreamPipelineFactory_ =
47         std::make_shared<MockDownstreamPipelineFactory>(downstreamHandler_);
48 
49     server_ = std::make_unique<TestServer>();
50 
51     // A routing pipeline with a mock routing handler that we can set
52     // expectations on.
53     routingPipeline_ = DefaultPipeline::create();
54 
55     routingDataHandlerFactory_ =
56         std::make_shared<MockRoutingDataHandlerFactory>();
57     acceptRoutingHandler_ = new MockAcceptRoutingHandler(
58         server_.get(),
59         routingDataHandlerFactory_,
60         downstreamPipelineFactory_,
61         routingPipeline_);
62     routingDataHandler_ =
63         new MockRoutingDataHandler(kConnId0, acceptRoutingHandler_);
64     routingDataHandlerFactory_->setRoutingDataHandler(routingDataHandler_);
65 
66     acceptPipeline_ = AcceptPipeline::create();
67     acceptPipeline_->addBack(
68         std::shared_ptr<MockAcceptRoutingHandler>(acceptRoutingHandler_));
69     acceptPipeline_->finalize();
70 
71     // A single threaded IOGroup shared between client and server for a
72     // deterministic event list.
73     auto ioGroup = std::make_shared<IOThreadPoolExecutor>(kNumIOThreads);
74 
75     acceptPipelineFactory_ =
76         std::make_shared<MockAcceptPipelineFactory>(acceptPipeline_);
77     server_->pipeline(acceptPipelineFactory_)->group(ioGroup, ioGroup)->bind(0);
78     server_->getSockets()[0]->getAddress(&address_);
79     VLOG(4) << "Start server at " << address_;
80   }
81 
getEventBase()82   EventBase* getEventBase() {
83     return server_->getIOGroup()->getEventBase();
84   }
85 
clientConnect()86   Future<DefaultPipeline*> clientConnect() {
87     client_ = std::make_shared<TestClient>();
88     client_->pipelineFactory(std::make_shared<TestClientPipelineFactory>());
89     client_->group(server_->getIOGroup());
90     return client_->connect(address_);
91   }
92 
clientConnectAndWrite()93   Future<DefaultPipeline*> clientConnectAndWrite() {
94     auto clientPipelinePromise =
95         std::make_shared<folly::Promise<DefaultPipeline*>>();
96 
97     getEventBase()->runInEventBaseThread([=]() {
98       clientConnect().thenValue([=](DefaultPipeline* clientPipeline) {
99         VLOG(4) << "Client connected. Send data.";
100         auto data = IOBuf::create(1);
101         data->append(1);
102         *(data->writableData()) = 'a';
103         clientPipeline->write(std::move(data)).thenValue([=](auto&&) {
104           clientPipelinePromise->setValue(clientPipeline);
105         });
106       });
107     });
108 
109     return clientPipelinePromise->getFuture();
110   }
111 
clientConnectAndCleanClose()112   Future<DefaultPipeline*> clientConnectAndCleanClose() {
113     auto clientPipelinePromise =
114         std::make_shared<folly::Promise<DefaultPipeline*>>();
115 
116     getEventBase()->runInEventBaseThread([=]() {
117       clientConnectAndWrite().thenValue([=](DefaultPipeline* clientPipeline) {
118         VLOG(4) << "Client close";
119         clientPipeline->close().thenValue(
120             [=](auto&&) { clientPipelinePromise->setValue(clientPipeline); });
121       });
122     });
123 
124     return clientPipelinePromise->getFuture();
125   }
126 
justClientConnect()127   Future<DefaultPipeline*> justClientConnect() {
128     auto clientPipelinePromise =
129         std::make_shared<folly::Promise<DefaultPipeline*>>();
130     getEventBase()->runInEventBaseThread([=]() {
131       clientConnect().thenValue([=](DefaultPipeline* clientPipeline) {
132         clientPipelinePromise->setValue(clientPipeline);
133       });
134     });
135 
136     return clientPipelinePromise->getFuture();
137   }
138 
sendClientException(DefaultPipeline * clientPipeline)139   void sendClientException(DefaultPipeline* clientPipeline) {
140     getEventBase()->runInEventBaseThread([=]() {
141       clientPipeline->writeException(
142           std::runtime_error("Client socket exception, right after connect."));
143     });
144   }
145 
TearDown()146   void TearDown() override {
147     acceptPipeline_.reset();
148     acceptPipelineFactory_->cleanup();
149   }
150 
151  protected:
152   std::unique_ptr<TestServer> server_;
153   std::shared_ptr<MockAcceptPipelineFactory> acceptPipelineFactory_;
154   AcceptPipeline::Ptr acceptPipeline_;
155   DefaultPipeline::Ptr routingPipeline_;
156   std::shared_ptr<MockRoutingDataHandlerFactory> routingDataHandlerFactory_;
157   MockRoutingDataHandler* routingDataHandler_;
158 
159   MockAcceptRoutingHandler* acceptRoutingHandler_;
160   MockBytesToBytesHandler* downstreamHandler_;
161   std::shared_ptr<MockDownstreamPipelineFactory> downstreamPipelineFactory_;
162   SocketAddress address_;
163   RoutingDataHandler<char>::RoutingData routingData_;
164 
165   std::shared_ptr<TestClient> client_;
166 
167   int kConnId0{0};
168   int kNumIOThreads{1};
169 };
170 
TEST_F(AcceptRoutingHandlerTest, ParseRoutingDataSuccess) {
  // Lets this thread wait until the downstream handler has seen EOF on the
  // IO thread.
  boost::barrier downstreamClosed(2);

  // Routing stage: the connection becomes active on the routing data handler,
  // and the first parse attempt reports success.
  EXPECT_CALL(*routingDataHandler_, transportActive(_));
  EXPECT_CALL(*routingDataHandler_, parseRoutingData(_, _))
      .WillOnce(
          Invoke([](folly::IOBufQueue& /*bufQueue*/,
                    MockRoutingDataHandler::RoutingData& /*routingData*/) {
            VLOG(4) << "Parsed routing data";
            return true;
          }));

  // Downstream stage: the downstream pipeline is created, and its handler
  // sees activation, the data, EOF, and finally deactivation.
  EXPECT_CALL(*downstreamHandler_, transportActive(_));
  EXPECT_CALL(*downstreamHandler_, read(_, _))
      .WillOnce(Invoke([](MockBytesToBytesHandler::Context* /*ctx*/,
                          IOBufQueue& /*bufQueue*/) {
        VLOG(4) << "Downstream received a read";
      }));
  EXPECT_CALL(*downstreamHandler_, readEOF(_))
      .WillOnce(
          Invoke([&downstreamClosed](MockBytesToBytesHandler::Context* ctx) {
            VLOG(4) << "Downstream EOF";
            ctx->fireClose();
            downstreamClosed.wait();
          }));
  EXPECT_CALL(*downstreamHandler_, transportInactive(_));

  // Drive the scenario: connect, send one byte, close cleanly.
  clientConnectAndCleanClose();
  downstreamClosed.wait();

  // No routing pipeline is left behind once routing succeeded.
  EXPECT_EQ(0, acceptRoutingHandler_->getRoutingPipelineCount());
}
206 
TEST_F(AcceptRoutingHandlerTest, SocketErrorInRoutingPipeline) {
  // Server receives data, but the routing data handler asks for more, so the
  // connection stays in the routing pipeline.
  boost::barrier barrierConnect(2);
  EXPECT_CALL(*routingDataHandler_, transportActive(_));
  EXPECT_CALL(*routingDataHandler_, parseRoutingData(_, _))
      .WillOnce(
          Invoke([&](folly::IOBufQueue& /*bufQueue*/,
                     MockRoutingDataHandler::RoutingData& /*routingData*/) {
            VLOG(4) << "Need more data to be parse.";
            barrierConnect.wait();
            return false;
          }));

  // Send client request that triggers server processing
  auto futureClientPipeline = clientConnectAndWrite();

  // Wait until the routing pipeline exists and is waiting for more data.
  barrierConnect.wait();

  // gMock requires expectations to be set before the mock is called.
  // Register readException() BEFORE scheduling the client exception below.
  // Otherwise the IO thread may call it first; the call is then unexpected
  // and barrierException.wait() on this thread never returns.
  boost::barrier barrierException(2);
  EXPECT_CALL(*routingDataHandler_, readException(_, _))
      .WillOnce(Invoke([&](MockBytesToBytesHandler::Context* /*ctx*/,
                           folly::exception_wrapper ex) {
        VLOG(4) << "Routing data handler Exception";
        acceptRoutingHandler_->onError(kConnId0, ex);
        barrierException.wait();
      }));

  // Socket exception after routing pipeline had been created
  std::move(futureClientPipeline)
      .thenValue([](DefaultPipeline* clientPipeline) {
        clientPipeline->getTransport()->getEventBase()->runInEventBaseThread(
            [clientPipeline]() {
              clientPipeline->writeException(std::runtime_error(
                  "Socket error while expecting routing data."));
            });
      });
  barrierException.wait();

  // Downstream pipeline is not created. Routing never succeeded, so the
  // handler presumably is still owned by the test. Deleting the mock makes
  // gMock verify the Times(0) expectation.
  EXPECT_CALL(*downstreamHandler_, transportActive(_)).Times(0);
  delete downstreamHandler_;

  // Routing pipeline has been erased
  EXPECT_EQ(0, acceptRoutingHandler_->getRoutingPipelineCount());
}
250 
TEST_F(AcceptRoutingHandlerTest, OnNewConnectionWithBadSocket) {
  // Routing data handler doesn't receive any data
  EXPECT_CALL(*routingDataHandler_, parseRoutingData(_, _)).Times(0);

  // Downstream pipeline must not be created during this test.
  EXPECT_CALL(*downstreamHandler_, transportActive(_)).Times(0);

  // Connect a client and wait until the routing data handler sees the new
  // transport.
  boost::barrier barrierConnect(2);
  EXPECT_CALL(*routingDataHandler_, transportActive(_))
      .WillOnce(Invoke([&](MockBytesToBytesHandler::Context* /*ctx*/) {
        barrierConnect.wait();
      }));
  auto futureClientPipeline = justClientConnect();
  barrierConnect.wait();
  futureClientPipeline.wait();

  // Expect an exception on the routing data handler
  boost::barrier barrierException(2);
  EXPECT_CALL(*routingDataHandler_, readException(_, _))
      .WillOnce(Invoke(
          [&](MockBytesToBytesHandler::Context* /*ctx*/,
              folly::exception_wrapper /*ex*/) { barrierException.wait(); }));
  sendClientException(futureClientPipeline.value());
  barrierException.wait();

  // Delete the downstream mock only now that the scenario has run. gMock
  // verifies expectations when a mock is destroyed. Deleting it before the
  // connection, as before, checked Times(0) against an untouched mock.
  // No routing data was parsed, so the handler presumably was never handed
  // to a downstream pipeline and is still owned by the test.
  delete downstreamHandler_;

  // Routing pipeline has been added
  EXPECT_EQ(1, acceptRoutingHandler_->getRoutingPipelineCount());
}
281 
TEST_F(AcceptRoutingHandlerTest, RoutingPipelineErasedOnlyOnce) {
  // Two outcomes for the same connection arrive back to back. First the
  // accept pipeline reports a client socket exception; then the routing data
  // handler reports successfully parsed routing data. The test makes no
  // explicit assertions: it passes if handling both does not crash.
  const std::runtime_error socketError("An exception from the socket.");
  acceptPipeline_->readException(socketError);
  acceptRoutingHandler_->onRoutingData(kConnId0, routingData_);
}
289