1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <algorithm>
#include <atomic>
#include <chrono>
#include <deque>
#include <memory>
#include <optional>
#include <string>
#include <vector>
21 
22 #include <folly/portability/GTest.h>
23 
24 #include <folly/Conv.h>
25 #include <folly/Try.h>
26 #include <folly/fibers/Baton.h>
27 #include <folly/fibers/Fiber.h>
28 #include <folly/fibers/FiberManager.h>
29 #include <folly/fibers/FiberManagerMap.h>
30 #include <folly/futures/Future.h>
31 #include <folly/io/Cursor.h>
32 #include <folly/io/IOBufQueue.h>
33 #include <folly/io/async/AsyncSocket.h>
34 #include <folly/io/async/EventBase.h>
35 
36 #include <thrift/lib/cpp2/async/ClientBufferedStream.h>
37 #include <thrift/lib/cpp2/async/RocketClientChannel.h>
38 #include <thrift/lib/cpp2/async/ServerStream.h>
39 #include <thrift/lib/cpp2/test/gen-cpp2/TestService.h>
40 #include <thrift/lib/cpp2/transport/rocket/test/util/TestUtil.h>
41 #include <thrift/lib/cpp2/util/ScopedServerInterfaceThread.h>
42 
43 using namespace apache::thrift;
44 
45 namespace {
46 class Handler : public test::TestServiceSvIf {
47  public:
semifuture_sendResponse(int64_t size)48   folly::SemiFuture<std::unique_ptr<std::string>> semifuture_sendResponse(
49       int64_t size) final {
50     lastTimeoutMsec_ =
51         getConnectionContext()->getHeader()->getClientTimeout().count();
52     return folly::makeSemiFuture()
53         .delayed(std::chrono::milliseconds(sleepDelayMsec_))
54         .defer([size](auto&&) {
55           return std::make_unique<std::string>(folly::to<std::string>(size));
56         });
57   }
58 
semifuture_noResponse(int64_t)59   folly::SemiFuture<folly::Unit> semifuture_noResponse(int64_t) final {
60     lastTimeoutMsec_ =
61         getConnectionContext()->getHeader()->getClientTimeout().count();
62     return folly::makeSemiFuture();
63   }
64 
semifuture_echoIOBuf(std::unique_ptr<folly::IOBuf> iobuf)65   folly::SemiFuture<std::unique_ptr<test::IOBufPtr>> semifuture_echoIOBuf(
66       std::unique_ptr<folly::IOBuf> iobuf) final {
67     return folly::makeSemiFuture(
68         std::make_unique<test::IOBufPtr>(std::move(iobuf)));
69   }
70 
semifuture_noResponseIOBuf(std::unique_ptr<folly::IOBuf>)71   folly::SemiFuture<folly::Unit> semifuture_noResponseIOBuf(
72       std::unique_ptr<folly::IOBuf>) final {
73     return folly::makeSemiFuture();
74   }
75 
echoIOBufAsByteStream(std::unique_ptr<folly::IOBuf> iobuf,int32_t delayMs)76   ServerStream<int8_t> echoIOBufAsByteStream(
77       std::unique_ptr<folly::IOBuf> iobuf, int32_t delayMs) final {
78     auto [stream, publisher] = ServerStream<int8_t>::createPublisher();
79     std::ignore = folly::makeSemiFuture()
80                       .delayed(std::chrono::milliseconds(delayMs))
81                       .via(getThreadManager())
82                       .thenValue([publisher = std::move(publisher),
83                                   iobuf = std::move(iobuf)](auto&&) mutable {
84                         folly::io::Cursor cursor(iobuf.get());
85                         int8_t byte;
86                         while (cursor.tryRead(byte)) {
87                           publisher.next(byte);
88                         }
89                         std::move(publisher).complete();
90                       });
91     return std::move(stream);
92   }
93 
getLastTimeoutMsec() const94   int32_t getLastTimeoutMsec() const { return lastTimeoutMsec_; }
setSleepDelayMs(int32_t delay)95   void setSleepDelayMs(int32_t delay) { sleepDelayMsec_ = delay; }
96 
97  private:
98   int32_t lastTimeoutMsec_{-1};
99   int32_t sleepDelayMsec_{0};
100 };
101 
102 class RocketClientChannelTest : public testing::Test {
103  public:
104   template <typename F>
makeClient(folly::EventBase & evb,F && configureChannel)105   test::TestServiceAsyncClient makeClient(
106       folly::EventBase& evb, F&& configureChannel) {
107     auto channel =
108         RocketClientChannel::newChannel(folly::AsyncSocket::UniquePtr(
109             new folly::AsyncSocket(&evb, runner_.getAddress())));
110     configureChannel(*channel);
111     return test::TestServiceAsyncClient(std::move(channel));
112   }
113 
makeClient(folly::EventBase & evb)114   test::TestServiceAsyncClient makeClient(folly::EventBase& evb) {
115     return makeClient(evb, [](auto&) {});
116   }
117 
118  protected:
119   std::shared_ptr<Handler> handler_{std::make_shared<Handler>()};
120   ScopedServerInterfaceThread runner_{handler_};
121 };
122 } // namespace
123 
TEST_F(RocketClientChannelTest,SyncThread)124 TEST_F(RocketClientChannelTest, SyncThread) {
125   folly::EventBase evb;
126   auto client = makeClient(evb);
127 
128   std::string response;
129   client.sync_sendResponse(response, 123);
130   EXPECT_EQ("123", response);
131 }
132 
TEST_F(RocketClientChannelTest,SyncFiber)133 TEST_F(RocketClientChannelTest, SyncFiber) {
134   folly::EventBase evb;
135   auto& fm = folly::fibers::getFiberManager(evb);
136   auto client = makeClient(evb);
137 
138   size_t responses = 0;
139   fm.addTaskFinally(
140       [&client] {
141         std::string response;
142         client.sync_sendResponse(response, 123);
143         return response;
144       },
145       [&responses](folly::Try<std::string>&& tryResponse) {
146         EXPECT_TRUE(tryResponse.hasValue());
147         EXPECT_EQ("123", *tryResponse);
148         ++responses;
149       });
150   while (fm.hasTasks()) {
151     evb.loopOnce();
152   }
153   EXPECT_EQ(1, responses);
154 }
155 
TEST_F(RocketClientChannelTest,SyncThreadOneWay)156 TEST_F(RocketClientChannelTest, SyncThreadOneWay) {
157   folly::EventBase evb;
158   auto client = makeClient(evb);
159   client.sync_noResponse(123);
160 }
161 
TEST_F(RocketClientChannelTest,SyncFiberOneWay)162 TEST_F(RocketClientChannelTest, SyncFiberOneWay) {
163   folly::EventBase evb;
164   auto& fm = folly::fibers::getFiberManager(evb);
165   auto client = makeClient(evb);
166 
167   size_t sent = 0;
168   fm.addTaskFinally(
169       [&client] { client.sync_noResponse(123); },
170       [&sent](folly::Try<void>&& tryResponse) {
171         EXPECT_TRUE(tryResponse.hasValue());
172         ++sent;
173       });
174   while (fm.hasTasks()) {
175     evb.loopOnce();
176   }
177   EXPECT_EQ(1, sent);
178 }
179 
TEST_F(RocketClientChannelTest,SyncThreadCheckTimeoutPropagated)180 TEST_F(RocketClientChannelTest, SyncThreadCheckTimeoutPropagated) {
181   folly::EventBase evb;
182   auto client = makeClient(evb);
183 
184   RpcOptions opts;
185   std::string response;
186   // Ensure that normally, the timeout value gets propagated.
187   opts.setTimeout(std::chrono::milliseconds(100));
188   client.sync_sendResponse(opts, response, 123);
189   EXPECT_EQ("123", response);
190   EXPECT_EQ(100, handler_->getLastTimeoutMsec());
191   // And when we set client-only, it's not propagated.
192   opts.setClientOnlyTimeouts(true);
193   client.sync_sendResponse(opts, response, 456);
194   EXPECT_EQ("456", response);
195   EXPECT_EQ(0, handler_->getLastTimeoutMsec());
196 
197   // Double-check that client enforces the timeouts in both cases.
198   handler_->setSleepDelayMs(200);
199   ASSERT_ANY_THROW(client.sync_sendResponse(opts, response, 456));
200   opts.setClientOnlyTimeouts(false);
201   ASSERT_ANY_THROW(client.sync_sendResponse(opts, response, 456));
202 
203   // Ensure that a 0 timeout is actually infinite
204   auto infiniteTimeoutClient =
205       makeClient(evb, [](auto& channel) { channel.setTimeout(0); });
206   opts.setTimeout(std::chrono::milliseconds::zero());
207   handler_->setSleepDelayMs(300);
208   infiniteTimeoutClient.sync_sendResponse(opts, response, 456);
209   EXPECT_EQ("456", response);
210 }
211 
TEST_F(RocketClientChannelTest,ThriftClientLifetime)212 TEST_F(RocketClientChannelTest, ThriftClientLifetime) {
213   folly::EventBase evb;
214   folly::Optional<test::TestServiceAsyncClient> client = makeClient(evb);
215 
216   auto& fm = folly::fibers::getFiberManager(evb);
217   auto future = fm.addTaskFuture([&] {
218     std::string response;
219     client->sync_sendResponse(response, 123);
220     EXPECT_EQ("123", response);
221   });
222 
223   // Trigger request sending.
224   evb.loopOnce();
225 
226   // Reset the client.
227   client.reset();
228 
229   // Wait for the response.
230   std::move(future).getVia(&evb);
231 }
232 
TEST_F(RocketClientChannelTest,LargeRequestResponse)233 TEST_F(RocketClientChannelTest, LargeRequestResponse) {
234   // send and receive large IOBufs to test rocket parser correctness in handling
235   // large (larger than kMaxBufferSize) payloads
236   folly::EventBase evb;
237   auto client = makeClient(evb);
238 
239   auto orig = std::string(1024 * 1024, 'x');
240   auto iobuf = folly::IOBuf::copyBuffer(orig);
241 
242   test::IOBufPtr response;
243   client.sync_echoIOBuf(
244       RpcOptions().setTimeout(std::chrono::seconds(30)), response, *iobuf);
245   EXPECT_EQ(
246       response->computeChainDataLength(), iobuf->computeChainDataLength());
247   auto res = response->moveToFbString();
248   EXPECT_EQ(orig, res);
249 }
250 
251 namespace {
252 
echoSync(test::TestServiceAsyncClient & client,size_t nbytes,std::optional<std::chrono::milliseconds> timeout=std::nullopt)253 folly::SemiFuture<std::unique_ptr<folly::IOBuf>> echoSync(
254     test::TestServiceAsyncClient& client,
255     size_t nbytes,
256     std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
257   auto& fm =
258       folly::fibers::getFiberManager(*client.getChannel()->getEventBase());
259   return fm.addTaskFuture([&, nbytes, timeout] {
260     auto iobuf = folly::IOBuf::copyBuffer(std::string(nbytes, 'x'));
261     test::IOBufPtr response;
262     client.sync_echoIOBuf(
263         RpcOptions().setTimeout(timeout.value_or(std::chrono::seconds(30))),
264         response,
265         *iobuf);
266     return response;
267   });
268 }
269 
echoSemiFuture(test::TestServiceAsyncClient & client,size_t nbytes,std::optional<std::chrono::milliseconds> timeout=std::nullopt)270 folly::SemiFuture<std::unique_ptr<folly::IOBuf>> echoSemiFuture(
271     test::TestServiceAsyncClient& client,
272     size_t nbytes,
273     std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
274   return folly::makeSemiFutureWith([&] {
275     auto iobuf = folly::IOBuf::copyBuffer(std::string(nbytes, 'x'));
276     auto options =
277         RpcOptions().setTimeout(timeout.value_or(std::chrono::seconds(30)));
278     return client.semifuture_echoIOBuf(options, *iobuf);
279   });
280 }
281 
noResponseIOBufSync(test::TestServiceAsyncClient & client,size_t nbytes)282 folly::SemiFuture<folly::Unit> noResponseIOBufSync(
283     test::TestServiceAsyncClient& client, size_t nbytes) {
284   auto& fm =
285       folly::fibers::getFiberManager(*client.getChannel()->getEventBase());
286   return fm.addTaskFuture([&, nbytes] {
287     auto iobuf = folly::IOBuf::copyBuffer(std::string(nbytes, 'x'));
288     client.sync_noResponseIOBuf(
289         RpcOptions().setTimeout(std::chrono::seconds(30)), *iobuf);
290   });
291 }
292 
noResponseIOBufSemiFuture(test::TestServiceAsyncClient & client,size_t nbytes)293 folly::SemiFuture<folly::Unit> noResponseIOBufSemiFuture(
294     test::TestServiceAsyncClient& client, size_t nbytes) {
295   return folly::makeSemiFutureWith([&] {
296     auto iobuf = folly::IOBuf::copyBuffer(std::string(nbytes, 'x'));
297     auto options = RpcOptions().setTimeout(std::chrono::seconds(30));
298     client.semifuture_noResponseIOBuf(options, *iobuf);
299   });
300 }
301 
echoIOBufAsByteStreamSync(test::TestServiceAsyncClient & client,size_t nbytes)302 folly::SemiFuture<ClientBufferedStream<int8_t>> echoIOBufAsByteStreamSync(
303     test::TestServiceAsyncClient& client, size_t nbytes) {
304   auto& fm =
305       folly::fibers::getFiberManager(*client.getChannel()->getEventBase());
306   return fm.addTaskFuture([&, nbytes] {
307     auto iobuf = folly::IOBuf::copyBuffer(std::string(nbytes, 'x'));
308     return client.sync_echoIOBufAsByteStream(
309         RpcOptions().setTimeout(std::chrono::seconds(30)),
310         *iobuf,
311         0 /* delayMs */);
312   });
313 }
314 
echoIOBufAsByteStreamSemiFuture(test::TestServiceAsyncClient & client,size_t nbytes)315 folly::SemiFuture<ClientBufferedStream<int8_t>> echoIOBufAsByteStreamSemiFuture(
316     test::TestServiceAsyncClient& client, size_t nbytes) {
317   return folly::makeSemiFutureWith([&] {
318     auto iobuf = folly::IOBuf::copyBuffer(std::string(nbytes, 'x'));
319     auto options = RpcOptions().setTimeout(std::chrono::seconds(30));
320     return client.semifuture_echoIOBufAsByteStream(
321         options, *iobuf, 0 /* delayMs */);
322   });
323 }
324 } // namespace
325 
TEST_F(RocketClientChannelTest,BatchedWriteFastFirstResponseFiberSync)326 TEST_F(RocketClientChannelTest, BatchedWriteFastFirstResponseFiberSync) {
327   folly::EventBase evb;
328   auto* slowWritingSocket = new SlowWritingSocket(&evb, runner_.getAddress());
329   test::TestServiceAsyncClient client(RocketClientChannel::newChannel(
330       folly::AsyncSocket::UniquePtr(slowWritingSocket)));
331 
332   // Allow first requests to be written completely to the socket quickly, but
333   // hold off on sending the complete second request.
334   slowWritingSocket->delayWritingAfterFirstNBytes(2000);
335 
336   std::vector<folly::SemiFuture<folly::Unit>> futures;
337   auto sf =
338       folly::makeSemiFuture()
339           .delayed(std::chrono::seconds(2))
340           .via(&evb)
341           .thenValue([&](auto&&) { slowWritingSocket->flushBufferedWrites(); });
342   futures.push_back(std::move(sf));
343 
344   for (size_t i = 0; i < 5; ++i) {
345     sf = echoSync(client, 25).via(&evb).thenTry([](auto&& response) {
346       EXPECT_TRUE(response.hasValue());
347       EXPECT_EQ(25, response.value()->computeChainDataLength());
348     });
349     futures.push_back(std::move(sf));
350 
351     sf = noResponseIOBufSync(client, 25).via(&evb).thenTry([](auto&& response) {
352       EXPECT_TRUE(response.hasValue());
353     });
354     futures.push_back(std::move(sf));
355 
356     sf = echoIOBufAsByteStreamSync(client, 25)
357              .via(&evb)
358              .thenTry([&](auto&& stream) {
359                EXPECT_TRUE(stream.hasValue());
360                return std::move(*stream)
361                    .subscribeExTry(
362                        &evb,
363                        [](auto&& next) {
364                          EXPECT_FALSE(next.hasException())
365                              << next.exception().what();
366                        })
367                    .futureJoin();
368              });
369     futures.push_back(std::move(sf));
370   }
371 
372   sf = echoSync(client, 2000).via(&evb).thenTry([](auto&& response) {
373     EXPECT_TRUE(response.hasValue());
374     EXPECT_EQ(2000, response.value()->computeChainDataLength());
375   });
376   futures.push_back(std::move(sf));
377 
378   folly::collectAllUnsafe(std::move(futures)).getVia(&evb);
379 }
380 
TEST_F(RocketClientChannelTest,BatchedWriteFastFirstResponseSemiFuture)381 TEST_F(RocketClientChannelTest, BatchedWriteFastFirstResponseSemiFuture) {
382   folly::EventBase evb;
383   auto* slowWritingSocket = new SlowWritingSocket(&evb, runner_.getAddress());
384   test::TestServiceAsyncClient client(RocketClientChannel::newChannel(
385       folly::AsyncSocket::UniquePtr(slowWritingSocket)));
386 
387   // Allow first requests to be written completely to the socket quickly, but
388   // hold off on sending the complete second request.
389   slowWritingSocket->delayWritingAfterFirstNBytes(2000);
390 
391   std::vector<folly::SemiFuture<folly::Unit>> futures;
392   auto sf =
393       folly::makeSemiFuture()
394           .delayed(std::chrono::seconds(2))
395           .via(&evb)
396           .thenValue([&](auto&&) { slowWritingSocket->flushBufferedWrites(); });
397   futures.push_back(std::move(sf));
398 
399   for (size_t i = 0; i < 5; ++i) {
400     sf = echoSemiFuture(client, 25).via(&evb).thenTry([&](auto&& response) {
401       EXPECT_TRUE(response.hasValue());
402       EXPECT_EQ(25, response.value()->computeChainDataLength());
403     });
404     futures.push_back(std::move(sf));
405 
406     sf = noResponseIOBufSemiFuture(client, 25)
407              .via(&evb)
408              .thenTry(
409                  [&](auto&& response) { EXPECT_TRUE(response.hasValue()); });
410     futures.push_back(std::move(sf));
411 
412     sf = echoIOBufAsByteStreamSemiFuture(client, 25)
413              .via(&evb)
414              .thenTry([&](auto&& stream) {
415                EXPECT_TRUE(stream.hasValue());
416                return std::move(*stream)
417                    .subscribeExTry(
418                        &evb,
419                        [](auto&& next) {
420                          EXPECT_FALSE(next.hasException())
421                              << next.exception().what();
422                        })
423                    .futureJoin();
424              });
425     futures.push_back(std::move(sf));
426   }
427 
428   sf = echoSemiFuture(client, 2000).via(&evb).thenTry([&](auto&& response) {
429     EXPECT_TRUE(response.hasValue());
430     EXPECT_EQ(2000, response.value()->computeChainDataLength());
431   });
432   futures.push_back(std::move(sf));
433 
434   folly::collectAllUnsafe(std::move(futures)).getVia(&evb);
435 }
436 
437 namespace {
doFailLastRequestsInBatchFiber(const folly::SocketAddress & serverAddr,folly::Optional<size_t> failLastRequestWithNBytesWritten=folly::none)438 void doFailLastRequestsInBatchFiber(
439     const folly::SocketAddress& serverAddr,
440     folly::Optional<size_t> failLastRequestWithNBytesWritten = folly::none) {
441   folly::EventBase evb;
442   auto* slowWritingSocket = new SlowWritingSocket(&evb, serverAddr);
443   test::TestServiceAsyncClient client(RocketClientChannel::newChannel(
444       folly::AsyncSocket::UniquePtr(slowWritingSocket)));
445 
446   // Allow first requests to be written completely to the socket quickly, but
447   // hold off on sending the complete second request.
448   slowWritingSocket->delayWritingAfterFirstNBytes(2000);
449 
450   std::vector<folly::SemiFuture<folly::Unit>> futures;
451   auto sf = folly::makeSemiFuture()
452                 .delayed(std::chrono::seconds(2))
453                 .via(&evb)
454                 .thenValue([&](auto&&) {
455                   slowWritingSocket->errorOutBufferedWrites(
456                       failLastRequestWithNBytesWritten);
457                 });
458   futures.push_back(std::move(sf));
459 
460   for (size_t i = 0; i < 5; ++i) {
461     sf = echoSync(client, 25).via(&evb).thenTry([](auto&& response) {
462       EXPECT_TRUE(response.hasValue());
463       EXPECT_EQ(25, response.value()->computeChainDataLength());
464     });
465     futures.push_back(std::move(sf));
466 
467     sf = noResponseIOBufSync(client, 25).via(&evb).thenTry([](auto&& response) {
468       EXPECT_FALSE(response.hasValue());
469     });
470     futures.push_back(std::move(sf));
471 
472     sf = echoIOBufAsByteStreamSync(client, 25)
473              .via(&evb)
474              .thenTry([&](auto&& stream) {
475                EXPECT_TRUE(stream.hasValue());
476                return std::move(*stream)
477                    .subscribeExTry(
478                        &evb,
479                        [](auto&& next) {
480                          EXPECT_FALSE(next.hasException())
481                              << next.exception().what();
482                        })
483                    .futureJoin();
484              });
485     futures.push_back(std::move(sf));
486   }
487 
488   for (size_t i = 0; i < 5; ++i) {
489     sf = echoSync(client, 2000).via(&evb).thenTry([](auto&& response) {
490       EXPECT_TRUE(response.hasException());
491       EXPECT_TRUE(
492           response.exception()
493               .template is_compatible_with<transport::TTransportException>());
494     });
495     futures.push_back(std::move(sf));
496 
497     sf = echoIOBufAsByteStreamSync(client, 2000)
498              .via(&evb)
499              .thenTry(
500                  [&](auto&& stream) { EXPECT_TRUE(stream.hasException()); });
501     futures.push_back(std::move(sf));
502   }
503 
504   folly::collectAllUnsafe(std::move(futures)).getVia(&evb);
505 }
506 
doFailLastRequestsInBatchSemiFuture(const folly::SocketAddress & serverAddr,folly::Optional<size_t> failLastRequestWithNBytesWritten=folly::none)507 void doFailLastRequestsInBatchSemiFuture(
508     const folly::SocketAddress& serverAddr,
509     folly::Optional<size_t> failLastRequestWithNBytesWritten = folly::none) {
510   folly::EventBase evb;
511   auto* slowWritingSocket = new SlowWritingSocket(&evb, serverAddr);
512   test::TestServiceAsyncClient client(RocketClientChannel::newChannel(
513       folly::AsyncSocket::UniquePtr(slowWritingSocket)));
514 
515   // Allow first requests to be written completely to the socket quickly, but
516   // hold off on sending the complete second request.
517   slowWritingSocket->delayWritingAfterFirstNBytes(2000);
518 
519   std::vector<folly::SemiFuture<folly::Unit>> futures;
520   auto sf = folly::makeSemiFuture()
521                 .delayed(std::chrono::seconds(2))
522                 .via(&evb)
523                 .thenValue([&](auto&&) {
524                   slowWritingSocket->errorOutBufferedWrites(
525                       failLastRequestWithNBytesWritten);
526                 });
527   futures.push_back(std::move(sf));
528 
529   for (size_t i = 0; i < 5; ++i) {
530     sf = echoSemiFuture(client, 25).via(&evb).thenTry([&](auto&& response) {
531       EXPECT_TRUE(response.hasValue());
532       EXPECT_EQ(25, response.value()->computeChainDataLength());
533     });
534     futures.push_back(std::move(sf));
535 
536     sf = noResponseIOBufSemiFuture(client, 25)
537              .via(&evb)
538              .thenTry(
539                  [&](auto&& response) { EXPECT_TRUE(response.hasValue()); });
540     futures.push_back(std::move(sf));
541 
542     sf = echoIOBufAsByteStreamSemiFuture(client, 25)
543              .via(&evb)
544              .thenTry([&](auto&& stream) {
545                EXPECT_TRUE(stream.hasValue());
546                return std::move(*stream)
547                    .subscribeExTry(
548                        &evb,
549                        [](auto&& next) {
550                          EXPECT_FALSE(next.hasException())
551                              << next.exception().what();
552                        })
553                    .futureJoin();
554              });
555     futures.push_back(std::move(sf));
556   }
557 
558   for (size_t i = 0; i < 5; ++i) {
559     sf = echoSemiFuture(client, 2000).via(&evb).thenTry([&](auto&& response) {
560       EXPECT_TRUE(response.hasException());
561       EXPECT_TRUE(
562           response.exception()
563               .template is_compatible_with<transport::TTransportException>());
564     });
565     futures.push_back(std::move(sf));
566 
567     sf = echoIOBufAsByteStreamSemiFuture(client, 2000)
568              .via(&evb)
569              .thenTry(
570                  [&](auto&& stream) { EXPECT_TRUE(stream.hasException()); });
571     futures.push_back(std::move(sf));
572   }
573 
574   folly::collectAllUnsafe(std::move(futures)).getVia(&evb);
575 }
576 } // namespace
577 
TEST_F(RocketClientChannelTest,FailLastRequestInBatchFiberSync)578 TEST_F(RocketClientChannelTest, FailLastRequestInBatchFiberSync) {
579   doFailLastRequestsInBatchFiber(runner_.getAddress());
580 }
581 
TEST_F(RocketClientChannelTest,FailLastRequestWithZeroBytesWrittenFiberSync)582 TEST_F(RocketClientChannelTest, FailLastRequestWithZeroBytesWrittenFiberSync) {
583   doFailLastRequestsInBatchFiber(
584       runner_.getAddress(), folly::Optional<size_t>(0));
585 }
586 
TEST_F(RocketClientChannelTest,FailLastRequestInBatchSemiFuture)587 TEST_F(RocketClientChannelTest, FailLastRequestInBatchSemiFuture) {
588   doFailLastRequestsInBatchSemiFuture(runner_.getAddress());
589 }
590 
TEST_F(RocketClientChannelTest,FailLastRequestWithZeroBytesWrittenSemiFuture)591 TEST_F(RocketClientChannelTest, FailLastRequestWithZeroBytesWrittenSemiFuture) {
592   doFailLastRequestsInBatchSemiFuture(
593       runner_.getAddress(), folly::Optional<size_t>(0));
594 }
595 
// Checks two things while all writes are held back. Requests whose client
// timeout is shorter than the flush delay must fail with a TIMED_OUT
// TTransportException. Requests with long timeouts in the same batches must
// still succeed and echo their payload once the buffered writes are flushed.
TEST_F(
    RocketClientChannelTest, BatchedWriteRequestResponseWithFastClientTimeout) {
  folly::EventBase evb;
  auto* slowWritingSocket = new SlowWritingSocket(&evb, runner_.getAddress());
  test::TestServiceAsyncClient client(RocketClientChannel::newChannel(
      folly::AsyncSocket::UniquePtr(slowWritingSocket)));

  // Hold off on writing any requests. This ensures that this test exercises the
  // path where a client request timeout fires while the request is still in the
  // WRITE_SENDING queue.
  slowWritingSocket->delayWritingAfterFirstNBytes(1);

  std::vector<folly::SemiFuture<folly::Unit>> futures;
  // Buffered writes are released only after flushDelay. Any request whose
  // timeout is shorter than this is expected to time out first.
  const std::chrono::seconds flushDelay(2);
  auto sf =
      folly::makeSemiFuture()
          .delayed(flushDelay)
          .via(&evb)
          .thenValue([&](auto&&) { slowWritingSocket->flushBufferedWrites(); });
  futures.push_back(std::move(sf));

  // expectedResponseSize == 0 means "expect a TIMED_OUT TTransportException".
  // Any other value means "expect success with exactly that many bytes echoed".
  auto checkResponse = [](const auto& response, size_t expectedResponseSize) {
    if (expectedResponseSize == 0) {
      EXPECT_TRUE(response.hasException());
      EXPECT_TRUE(
          response.exception()
              .template is_compatible_with<transport::TTransportException>());
      response.exception()
          .template with_exception<transport::TTransportException>(
              [](const auto& tex) {
                EXPECT_EQ(
                    transport::TTransportException::TTransportExceptionType::
                        TIMED_OUT,
                    tex.getType());
              });
    } else {
      EXPECT_TRUE(response.hasValue());
      EXPECT_EQ(
          expectedResponseSize, response.value()->computeChainDataLength());
    }
  };

  // Over several event loops, force some timeouts to fire before any socket
  // writes complete at varying positions within each batch of requests.
  // 50ms entries expire well before flushDelay; 10000ms entries outlive it.
  std::vector<uint32_t> timeouts = {50, 50, 10000, 10000, 10000, 10000};
  for (size_t requestSize = 20, loops = 0; loops < 20; ++loops) {
    for (uint32_t timeoutMs : timeouts) {
      const std::chrono::milliseconds timeout(timeoutMs);

      // Each timeout value is exercised through both the fiber/sync API and
      // the SemiFuture API, using the same request size.
      sf = echoSync(client, requestSize, timeout)
               .via(&evb)
               .thenTry([&checkResponse,
                         responseSize = timeout < flushDelay ? 0 : requestSize](
                            auto&& response) {
                 checkResponse(response, responseSize);
               });
      futures.push_back(std::move(sf));

      sf = echoSemiFuture(client, requestSize, timeout)
               .via(&evb)
               .thenTry([&checkResponse,
                         responseSize = timeout < flushDelay ? 0 : requestSize](
                            auto&& response) {
                 checkResponse(response, responseSize);
               });
      futures.push_back(std::move(sf));

      // Distinct sizes per timeout slot; presumably so a response delivered to
      // the wrong request would be caught by the size check.
      ++requestSize;
    }

    // Start writing the current batch of requests and ensure a new batch is
    // started next iteration
    evb.loopOnce();
    evb.loopOnce();

    // Shift the timeout pattern left by one so the short-timeout requests sit
    // at a different position in the next batch.
    std::rotate(timeouts.begin(), timeouts.begin() + 1, timeouts.end());
  }

  folly::collectAllUnsafe(std::move(futures)).getVia(&evb);
}
676 
TEST_F(RocketClientChannelTest,StreamInitialResponseBeforeBatchedWriteFails)677 TEST_F(RocketClientChannelTest, StreamInitialResponseBeforeBatchedWriteFails) {
678   folly::EventBase evb;
679   auto* slowWritingSocket = new SlowWritingSocket(&evb, runner_.getAddress());
680   test::TestServiceAsyncClient client(RocketClientChannel::newChannel(
681       folly::AsyncSocket::UniquePtr(slowWritingSocket)));
682 
683   // Ensure the first request is written completely to the socket quickly, but
684   // force the write for the whole batch of requests to fail.
685   slowWritingSocket->delayWritingAfterFirstNBytes(1000);
686 
687   std::vector<folly::SemiFuture<folly::Unit>> futures;
688   auto sf = folly::makeSemiFuture()
689                 .delayed(std::chrono::seconds(1))
690                 .via(&evb)
691                 .thenValue([&](auto&&) {
692                   slowWritingSocket->errorOutBufferedWrites(
693                       folly::Optional<size_t>(0));
694                 });
695   futures.push_back(std::move(sf));
696 
697   // Keep the stream alive on both client and server until the end of the test
698   std::optional<ClientBufferedStream<signed char>::Subscription> subscription;
699   sf = folly::makeSemiFutureWith([&] {
700          auto iobuf = folly::IOBuf::copyBuffer(std::string(25, 'x'));
701          auto options = RpcOptions().setTimeout(std::chrono::seconds(30));
702          return client.semifuture_echoIOBufAsByteStream(
703              options, *iobuf, 2000 /* delayMs */);
704        })
705            .via(&evb)
706            .thenTry([&](auto&& stream) {
707              subscription.emplace(
708                  std::move(*stream).subscribeExTry(&evb, [](auto&&) {}));
709            });
710   futures.push_back(std::move(sf));
711 
712   // Include more requests in the write batch
713   for (size_t i = 0; i < 10; ++i) {
714     sf = echoSemiFuture(client, 1000).via(&evb).thenTry([&](auto&& response) {
715       EXPECT_TRUE(response.hasException());
716       EXPECT_TRUE(
717           response.exception()
718               .template is_compatible_with<transport::TTransportException>());
719     });
720     futures.push_back(std::move(sf));
721   }
722 
723   folly::collectAllUnsafe(std::move(futures)).getVia(&evb);
724   subscription->cancel();
725   std::move(*subscription).join();
726 }
727