1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <folly/portability/GTest.h>
18 
19 #include <folly/SocketAddress.h>
20 #include <folly/io/Cursor.h>
21 #include <folly/io/IOBuf.h>
22 #include <folly/io/IOBufQueue.h>
23 #include <folly/io/async/AsyncSSLSocket.h>
24 #include <folly/io/async/AsyncSocket.h>
25 #include <folly/io/async/AsyncTimeout.h>
26 #include <folly/io/async/AsyncTransport.h>
27 #include <folly/io/async/EventBase.h>
28 #include <folly/io/async/test/SocketPair.h>
29 #include <folly/io/async/test/TestSSLServer.h>
30 #include <folly/lang/Bits.h>
31 #include <thrift/lib/cpp/EventHandlerBase.h>
32 #include <thrift/lib/cpp/async/TAsyncSSLSocket.h>
33 #include <thrift/lib/cpp2/async/Cpp2Channel.h>
34 #include <thrift/lib/cpp2/async/HeaderClientChannel.h>
35 #include <thrift/lib/cpp2/async/HeaderServerChannel.h>
36 #include <thrift/lib/cpp2/async/MessageChannel.h>
37 #include <thrift/lib/cpp2/async/RequestCallback.h>
38 #include <thrift/lib/cpp2/async/RequestChannel.h>
39 #include <thrift/lib/cpp2/async/ResponseChannel.h>
40 #include <thrift/lib/cpp2/async/RpcTypes.h>
41 
42 using namespace apache::thrift;
43 using namespace apache::thrift::async;
44 using namespace apache::thrift::transport;
45 using apache::thrift::ContextStack;
46 using folly::IOBuf;
47 using folly::IOBufQueue;
48 using std::make_unique;
49 using std::shared_ptr;
50 using std::unique_ptr;
51 
makeTestBufImpl(size_t len)52 unique_ptr<IOBuf> makeTestBufImpl(size_t len) {
53   unique_ptr<IOBuf> buf = IOBuf::create(len);
54   buf->IOBuf::append(len);
55   memset(buf->writableData(), char(0x80), len);
56   return LegacySerializedRequest(
57              T_COMPACT_PROTOCOL, "test", SerializedRequest(std::move(buf)))
58       .buffer;
59 }
60 
makeTestBuf(size_t len)61 unique_ptr<IOBuf> makeTestBuf(size_t len) {
62   for (auto requestLen = len; requestLen > 0; --requestLen) {
63     auto buf = makeTestBufImpl(requestLen);
64     if (buf->computeChainDataLength() == len) {
65       return buf;
66     }
67   }
68   LOG(FATAL) << "Can't generate valid legacy request of given length: " << len;
69 }
70 
makeTestSerializedRequest(size_t len)71 SerializedRequest makeTestSerializedRequest(size_t len) {
72   for (auto requestLen = len; requestLen > 0; --requestLen) {
73     unique_ptr<IOBuf> buf = IOBuf::create(requestLen);
74     buf->IOBuf::append(requestLen);
75     memset(buf->writableData(), char(0x80), requestLen);
76     if (LegacySerializedRequest(
77             T_COMPACT_PROTOCOL, "test", SerializedRequest(buf->clone()))
78             .buffer->computeChainDataLength() == len) {
79       return SerializedRequest(std::move(buf));
80     }
81   }
82   LOG(FATAL) << "Can't generate valid serialized request of given length: "
83              << len;
84 }
85 
lengthWithEnvelope(const ClientReceiveState & state)86 size_t lengthWithEnvelope(const ClientReceiveState& state) {
87   return LegacySerializedResponse(
88              state.protocolId(),
89              0,
90              state.messageType(),
91              "test",
92              SerializedResponse(state.serializedResponse().buffer->clone()))
93       .buffer->computeChainDataLength();
94 }
95 
96 class EventBaseAborter : public folly::AsyncTimeout {
97  public:
EventBaseAborter(folly::EventBase * eventBase,uint32_t timeoutMS)98   EventBaseAborter(folly::EventBase* eventBase, uint32_t timeoutMS)
99       : folly::AsyncTimeout(
100             eventBase, folly::AsyncTimeout::InternalEnum::INTERNAL),
101         eventBase_(eventBase) {
102     scheduleTimeout(timeoutMS);
103   }
104 
timeoutExpired()105   void timeoutExpired() noexcept override {
106     ADD_FAILURE();
107     eventBase_->terminateLoopSoon();
108   }
109 
110  private:
111   folly::EventBase* eventBase_;
112 };
113 
114 // Creates/unwraps a framed message (LEN(MSG) | MSG)
115 class TestFramingHandler : public FramingHandler {
116  public:
removeFrame(IOBufQueue * q)117   std::tuple<unique_ptr<IOBuf>, size_t, unique_ptr<THeader>> removeFrame(
118       IOBufQueue* q) override {
119     assert(q);
120     queue_.append(*q);
121     if (!queue_.front() || queue_.front()->empty()) {
122       return make_tuple(std::unique_ptr<IOBuf>(), 0, nullptr);
123     }
124 
125     uint32_t len = queue_.front()->computeChainDataLength();
126 
127     if (len < 4) {
128       size_t remaining = 4 - len;
129       return make_tuple(unique_ptr<IOBuf>(), remaining, nullptr);
130     }
131 
132     folly::io::Cursor c(queue_.front());
133     uint32_t msgLen = c.readBE<uint32_t>();
134     if (len - 4 < msgLen) {
135       size_t remaining = msgLen - (len - 4);
136       return make_tuple(unique_ptr<IOBuf>(), remaining, nullptr);
137     }
138 
139     queue_.trimStart(4);
140     return make_tuple(queue_.split(msgLen), 0, nullptr);
141   }
142 
addFrame(unique_ptr<IOBuf> buf,THeader *)143   unique_ptr<IOBuf> addFrame(unique_ptr<IOBuf> buf, THeader*) override {
144     assert(buf);
145     unique_ptr<IOBuf> framing;
146 
147     if (buf->headroom() > 4) {
148       framing = std::move(buf);
149       buf->prepend(4);
150     } else {
151       framing = IOBuf::create(4);
152       framing->append(4);
153       framing->appendChain(std::move(buf));
154     }
155     folly::io::RWPrivateCursor c(framing.get());
156     c.writeBE<uint32_t>(framing->computeChainDataLength() - 4);
157 
158     return framing;
159   }
160 
161  private:
162   IOBufQueue queue_;
163 };
164 
// Wraps `transport` in a channel of type `Channel` using the channel's
// single-argument newChannel() factory. Specializations below cover channel
// types whose factories need extra arguments.
template <typename Channel>
unique_ptr<Channel, folly::DelayedDestruction::Destructor> createChannel(
    folly::AsyncTransport::UniquePtr transport) {
  return Channel::newChannel(std::move(transport));
}

// Cpp2Channel is created with TestFramingHandler, so messages on the wire use
// its 4-byte length-prefix framing.
template <>
unique_ptr<Cpp2Channel, folly::DelayedDestruction::Destructor> createChannel(
    folly::AsyncTransport::UniquePtr transport) {
  return Cpp2Channel::newChannel(
      std::move(transport), make_unique<TestFramingHandler>());
}

// HeaderClientChannel is created without a Rocket upgrade attempt, so it stays
// on the header protocol. NOTE(review): this specialization presumes
// HeaderClientChannel::LegacyPtr names the same type as the primary template's
// return type for Channel = HeaderClientChannel; confirm.
template <>
HeaderClientChannel::LegacyPtr createChannel(
    folly::AsyncTransport::UniquePtr transport) {
  return HeaderClientChannel::newChannel(
      HeaderClientChannel::WithoutRocketUpgrade{}, std::move(transport));
}
184 
185 template <typename Channel1, typename Channel2>
186 class SocketPairTest {
187  public:
188   struct Config {
189     bool ssl{false};
190   };
191 
SocketPairTest(Config config=Config ())192   SocketPairTest(Config config = Config()) : eventBase_() {
193     folly::SocketPair socketPair;
194 
195     folly::AsyncSocket::UniquePtr socket0, socket1;
196     if (config.ssl) {
197       auto clientCtx = std::make_shared<folly::SSLContext>();
198       auto serverCtx = std::make_shared<folly::SSLContext>();
199       getctx(clientCtx, serverCtx);
200       socket0 = TAsyncSSLSocket::newSocket(
201           clientCtx, &eventBase_, socketPair.extractNetworkSocket0(), false);
202       socket1 = TAsyncSSLSocket::newSocket(
203           serverCtx, &eventBase_, socketPair.extractNetworkSocket1(), true);
204       dynamic_cast<folly::AsyncSSLSocket*>(socket0.get())->sslConn(nullptr);
205       dynamic_cast<folly::AsyncSSLSocket*>(socket1.get())->sslAccept(nullptr);
206     } else {
207       socket0 = folly::AsyncSocket::newSocket(
208           &eventBase_, socketPair.extractNetworkSocket0());
209       socket1 = folly::AsyncSocket::newSocket(
210           &eventBase_, socketPair.extractNetworkSocket1());
211     }
212     socket0_ = socket0.get();
213     socket1_ = socket1.get();
214 
215     channel0_ = createChannel<Channel1>(std::move(socket0));
216     channel1_ = createChannel<Channel2>(std::move(socket1));
217   }
~SocketPairTest()218   virtual ~SocketPairTest() {}
219 
loop(uint32_t timeoutMS)220   void loop(uint32_t timeoutMS) {
221     EventBaseAborter aborter(&eventBase_, timeoutMS);
222     eventBase_.loop();
223   }
224 
run()225   void run() { runWithTimeout(); }
226 
getctx(std::shared_ptr<folly::SSLContext> clientCtx,std::shared_ptr<folly::SSLContext> serverCtx)227   void getctx(
228       std::shared_ptr<folly::SSLContext> clientCtx,
229       std::shared_ptr<folly::SSLContext> serverCtx) {
230     clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
231 
232     serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
233     serverCtx->loadCertificate(folly::kTestCert);
234     serverCtx->loadPrivateKey(folly::kTestKey);
235   }
236 
getFd0()237   int getFd0() { return socket0_->getNetworkSocket().toFd(); }
238 
getFd1()239   int getFd1() { return socket1_->getNetworkSocket().toFd(); }
240 
getSocket0()241   folly::AsyncSocket* getSocket0() { return socket0_; }
242 
getSocket1()243   folly::AsyncSocket* getSocket1() { return socket1_; }
244 
runWithTimeout(uint32_t timeoutMS=6000)245   void runWithTimeout(uint32_t timeoutMS = 6000) {
246     preLoop();
247     loop(timeoutMS);
248     postLoop();
249   }
250 
preLoop()251   virtual void preLoop() {}
postLoop()252   virtual void postLoop() {}
253 
254  protected:
255   folly::EventBase eventBase_;
256   folly::AsyncSocket* socket0_;
257   folly::AsyncSocket* socket1_;
258   unique_ptr<Channel1, folly::DelayedDestruction::Destructor> channel0_;
259   unique_ptr<Channel2, folly::DelayedDestruction::Destructor> channel1_;
260 };
261 
262 class MessageCallback : public MessageChannel::SendCallback,
263                         public MessageChannel::RecvCallback {
264  public:
MessageCallback()265   MessageCallback()
266       : sent_(0),
267         recv_(0),
268         sendError_(0),
269         recvError_(0),
270         recvEOF_(0),
271         recvBytes_(0) {}
272 
sendQueued()273   void sendQueued() override {}
274 
messageSent()275   void messageSent() override { sent_++; }
messageSendError(folly::exception_wrapper &&)276   void messageSendError(folly::exception_wrapper&&) override { sendError_++; }
277 
messageReceived(unique_ptr<IOBuf> && buf,unique_ptr<THeader> &&)278   void messageReceived(
279       unique_ptr<IOBuf>&& buf, unique_ptr<THeader>&&) override {
280     recv_++;
281     recvBytes_ += buf->computeChainDataLength();
282   }
messageChannelEOF()283   void messageChannelEOF() override { recvEOF_++; }
messageReceiveErrorWrapped(folly::exception_wrapper &&)284   void messageReceiveErrorWrapped(folly::exception_wrapper&&) override {
285     sendError_++;
286   }
287 
288   uint32_t sent_;
289   uint32_t recv_;
290   uint32_t sendError_;
291   uint32_t recvError_;
292   uint32_t recvEOF_;
293   size_t recvBytes_;
294 };
295 
296 class TestRequestCallback : public RequestClientCallback, public CloseCallback {
297  public:
TestRequestCallback(bool oneWay=false)298   explicit TestRequestCallback(bool oneWay = false) : oneWay_(oneWay) {}
299 
onRequestSent()300   void onRequestSent() noexcept override {
301     if (oneWay_) {
302       delete this;
303     }
304   }
305 
onResponse(ClientReceiveState && state)306   void onResponse(ClientReceiveState&& state) noexcept override {
307     reply_++;
308     replyBytes_ += lengthWithEnvelope(state);
309     delete this;
310   }
311 
onResponseError(folly::exception_wrapper ex)312   void onResponseError(folly::exception_wrapper ex) noexcept override {
313     EXPECT_TRUE(ex);
314     replyError_++;
315     delete this;
316   }
317 
channelClosed()318   void channelClosed() override { closed_ = true; }
319 
reset()320   static void reset() {
321     closed_ = false;
322     reply_ = 0;
323     replyBytes_ = 0;
324     replyError_ = 0;
325   }
326   static bool closed_;
327   static uint32_t reply_;
328   static uint32_t replyBytes_;
329   static uint32_t replyError_;
330 
331  private:
332   const bool oneWay_;
333 };
334 
335 bool TestRequestCallback::closed_ = false;
336 uint32_t TestRequestCallback::reply_ = 0;
337 uint32_t TestRequestCallback::replyBytes_ = 0;
338 uint32_t TestRequestCallback::replyError_ = 0;
339 
340 class ResponseCallback : public HeaderServerChannel::Callback {
341  public:
ResponseCallback()342   ResponseCallback()
343       : serverClosed_(false), oneway_(0), request_(0), requestBytes_(0) {}
344 
requestReceived(unique_ptr<HeaderServerChannel::HeaderRequest> && req)345   void requestReceived(
346       unique_ptr<HeaderServerChannel::HeaderRequest>&& req) override {
347     request_++;
348     requestBytes_ += req->getBuf()->computeChainDataLength();
349     if (req->isOneway()) {
350       oneway_++;
351     } else {
352       req->sendReply(ResponsePayload::create(req->extractBuf()));
353     }
354   }
355 
channelClosed(folly::exception_wrapper &&)356   void channelClosed(folly::exception_wrapper&&) override {
357     serverClosed_ = true;
358   }
359 
360   bool serverClosed_;
361   uint32_t oneway_;
362   uint32_t request_;
363   uint32_t requestBytes_;
364 };
365 
366 class MessageTest : public SocketPairTest<Cpp2Channel, Cpp2Channel>,
367                     public MessageCallback {
368  public:
MessageTest(size_t len,Config socketConfig=Config ())369   explicit MessageTest(size_t len, Config socketConfig = Config())
370       : SocketPairTest(socketConfig), len_(len), header_(new THeader) {}
371 
preLoop()372   void preLoop() override {
373     channel0_->sendMessage(&sendCallback_, makeTestBuf(len_), header_.get());
374     channel1_->setReceiveCallback(this);
375   }
376 
postLoop()377   void postLoop() override {
378     EXPECT_EQ(sendCallback_.sendError_, 0);
379     EXPECT_EQ(recvError_, 0);
380     EXPECT_EQ(recvEOF_, 0);
381     EXPECT_EQ(recv_, 1);
382     EXPECT_EQ(sendCallback_.sent_, 1);
383     EXPECT_EQ(recvBytes_, len_);
384   }
385 
messageReceived(unique_ptr<IOBuf> && buf,unique_ptr<THeader> && header)386   void messageReceived(
387       unique_ptr<IOBuf>&& buf, unique_ptr<THeader>&& header) override {
388     MessageCallback::messageReceived(std::move(buf), std::move(header));
389     channel1_->setReceiveCallback(nullptr);
390   }
391 
392  private:
393   size_t len_;
394   unique_ptr<THeader> header_;
395   MessageCallback sendCallback_;
396 };
397 
TEST(Channel,Cpp2Channel)398 TEST(Channel, Cpp2Channel) {
399   MessageTest(10).run();
400   MessageTest(100).run();
401   MessageTest(1024 * 1024).run();
402 }
403 
TEST(Channel,Cpp2ChannelSSL)404 TEST(Channel, Cpp2ChannelSSL) {
405   MessageTest::Config socketConfig;
406   socketConfig.ssl = true;
407   MessageTest(10, socketConfig).run();
408   MessageTest(100, socketConfig).run();
409   MessageTest(1024 * 1024, socketConfig).run();
410 }
411 
412 class MessageCloseTest : public SocketPairTest<Cpp2Channel, Cpp2Channel>,
413                          public MessageCallback {
414  public:
MessageCloseTest()415   MessageCloseTest() : header_(new THeader) {}
416 
preLoop()417   void preLoop() override {
418     channel0_->sendMessage(
419         &sendCallback_, makeTestBuf(1024 * 1024), header_.get());
420     // Close the other socket after delay
421     this->eventBase_.runInLoop(
422         std::bind(&folly::AsyncSocket::close, this->socket1_));
423     channel1_->setReceiveCallback(this);
424   }
425 
postLoop()426   void postLoop() override {
427     EXPECT_EQ(sendCallback_.sendError_, 1);
428     EXPECT_EQ(recvError_, 0);
429     EXPECT_EQ(recvEOF_, 1);
430     EXPECT_EQ(recv_, 0);
431     EXPECT_EQ(sendCallback_.sent_, 0);
432   }
433 
messageChannelEOF()434   void messageChannelEOF() override {
435     MessageCallback::messageChannelEOF();
436     channel1_->setReceiveCallback(nullptr);
437   }
438 
439  private:
440   MessageCallback sendCallback_;
441   unique_ptr<THeader> header_;
442 };
443 
TEST(Channel,MessageCloseTest)444 TEST(Channel, MessageCloseTest) {
445   MessageCloseTest().run();
446 }
447 
448 class MessageEOFTest : public SocketPairTest<Cpp2Channel, Cpp2Channel>,
449                        public MessageCallback {
450  public:
MessageEOFTest()451   MessageEOFTest() : header_(new THeader) {}
452 
preLoop()453   void preLoop() override {
454     channel0_->setReceiveCallback(this);
455     channel1_->getTransport()->shutdownWrite();
456     channel0_->sendMessage(this, makeTestBuf(1024 * 1024), header_.get());
457   }
458 
postLoop()459   void postLoop() override {
460     EXPECT_EQ(sendError_, 1);
461     EXPECT_EQ(recvError_, 0);
462     EXPECT_EQ(recvEOF_, 1);
463     EXPECT_EQ(recv_, 0);
464     EXPECT_EQ(sent_, 0);
465   }
466 
467  private:
468   unique_ptr<THeader> header_;
469 };
470 
TEST(Channel,MessageEOFTest)471 TEST(Channel, MessageEOFTest) {
472   MessageEOFTest().run();
473 }
474 
475 class HeaderChannelTest
476     : public SocketPairTest<HeaderClientChannel, HeaderServerChannel>,
477       public TestRequestCallback,
478       public ResponseCallback {
479  public:
HeaderChannelTest(size_t len,Config socketConfig=Config ())480   explicit HeaderChannelTest(size_t len, Config socketConfig = Config())
481       : SocketPairTest(socketConfig), len_(len) {}
482 
483   class Callback : public TestRequestCallback {
484    public:
Callback(HeaderChannelTest * c,bool oneWay)485     Callback(HeaderChannelTest* c, bool oneWay)
486         : TestRequestCallback(oneWay), c_(c) {}
onResponse(ClientReceiveState && state)487     void onResponse(ClientReceiveState&& state) noexcept override {
488       c_->channel1_->setCallback(nullptr);
489       TestRequestCallback::onResponse(std::move(state));
490     }
491 
492    private:
493     HeaderChannelTest* c_;
494   };
495 
preLoop()496   void preLoop() override {
497     TestRequestCallback::reset();
498     channel1_->setCallback(this);
499     channel0_->setCloseCallback(this);
500     RpcOptions options;
501     channel0_->sendRequestNoResponse(
502         options,
503         "test",
504         makeTestSerializedRequest(len_),
505         std::unique_ptr<THeader>(new THeader),
506         RequestClientCallback::Ptr(new Callback(this, true)));
507     channel0_->sendRequestResponse(
508         options,
509         "test",
510         makeTestSerializedRequest(len_),
511         std::unique_ptr<THeader>(new THeader),
512         RequestClientCallback::Ptr(new Callback(this, false)));
513     channel0_->setCloseCallback(nullptr);
514   }
515 
postLoop()516   void postLoop() override {
517     EXPECT_EQ(reply_, 1);
518     EXPECT_EQ(replyError_, 0);
519     EXPECT_EQ(replyBytes_, len_);
520     EXPECT_EQ(closed_, false);
521     EXPECT_EQ(serverClosed_, false);
522     EXPECT_EQ(request_, 2);
523     EXPECT_EQ(requestBytes_, len_ * 2);
524     EXPECT_EQ(oneway_, 1);
525     channel1_->setCallback(nullptr);
526   }
527 
528  private:
529   size_t len_;
530 };
531 
TEST(Channel,HeaderChannelTest)532 TEST(Channel, HeaderChannelTest) {
533   HeaderChannelTest(10).run();
534   HeaderChannelTest(100).run();
535   HeaderChannelTest(1024 * 1024).run();
536 }
537 
TEST(Channel,HeaderChannelTestSSL)538 TEST(Channel, HeaderChannelTestSSL) {
539   HeaderChannelTest::Config socketConfig;
540   socketConfig.ssl = true;
541   HeaderChannelTest(10, socketConfig).run();
542   HeaderChannelTest(100, socketConfig).run();
543   HeaderChannelTest(1024 * 1024, socketConfig).run();
544 }
545 
546 class HeaderChannelClosedTest
547     : public SocketPairTest<HeaderClientChannel, HeaderServerChannel> {
548   //   , public TestRequestCallback
549   //   , public ResponseCallback {
550  public:
HeaderChannelClosedTest()551   explicit HeaderChannelClosedTest() {}
552 
553   class Callback : public RequestClientCallback {
554    public:
Callback(HeaderChannelClosedTest * c)555     explicit Callback(HeaderChannelClosedTest* c) : c_(c) {}
556 
~Callback()557     ~Callback() override { c_->callbackDtor_ = true; }
558 
onResponse(ClientReceiveState &&)559     void onResponse(ClientReceiveState&&) noexcept override {
560       FAIL() << "should not recv reply from closed channel";
561     }
562 
onResponseError(folly::exception_wrapper ew)563     void onResponseError(folly::exception_wrapper ew) noexcept override {
564       EXPECT_TRUE(ew.with_exception([this](const TTransportException& e) {
565         EXPECT_EQ(e.getType(), TTransportException::END_OF_FILE);
566         c_->gotError_ = true;
567       }));
568       delete this;
569     }
570 
571    private:
572     HeaderChannelClosedTest* c_;
573   };
574 
preLoop()575   void preLoop() override {
576     TestRequestCallback::reset();
577     channel1_->getTransport()->shutdownWrite();
578     RpcOptions options;
579     channel0_->sendRequestResponse(
580         options,
581         "test",
582         makeTestSerializedRequest(42),
583         std::make_unique<THeader>(),
584         RequestClientCallback::Ptr(new Callback(this)));
585   }
586 
postLoop()587   void postLoop() override {
588     EXPECT_TRUE(gotError_);
589     EXPECT_TRUE(callbackDtor_);
590   }
591 
592  private:
593   uint32_t seqId_;
594   bool gotError_ = true;
595   bool callbackDtor_ = false;
596 };
597 
TEST(Channel,HeaderChannelClosedTest)598 TEST(Channel, HeaderChannelClosedTest) {
599   HeaderChannelClosedTest().run();
600 }
601 
// Sends two two-way requests (enveloped lengths len_ and len_ + 1) and checks
// that both are answered. The second reply must have an enveloped length of
// len_ + 1.
//
// NOTE(review): Callback::requestReceived below is a member of the nested
// client-side Callback (derived from TestRequestCallback, not a server
// callback), and nothing in SOURCE appears to call it. The server side
// therefore seems to use ResponseCallback::requestReceived, which replies to
// each request immediately. The intended "hold the first request in firstbuf_,
// answer the second first" reordering would then never be exercised. Confirm,
// and consider moving it to InOrderTest as a real override.
class InOrderTest
    : public SocketPairTest<HeaderClientChannel, HeaderServerChannel>,
      public TestRequestCallback,
      public ResponseCallback {
 public:
  explicit InOrderTest(Config socketConfig = Config())
      : SocketPairTest(socketConfig), len_(10) {}

  class Callback : public TestRequestCallback {
   public:
    explicit Callback(InOrderTest* c) : c_(c) {}
    void onResponse(ClientReceiveState&& state) noexcept override {
      // reply_ is the shared static counter. It is only incremented by the
      // base onResponse below, so reply_ == 1 here means this is the second
      // reply.
      if (reply_ == 1) {
        c_->channel1_->setCallback(nullptr);
        // Verify that they came back in the same order
        EXPECT_EQ(lengthWithEnvelope(state), c_->len_ + 1);
      }
      // Deletes `this`.
      TestRequestCallback::onResponse(std::move(state));
    }

    // Intended server-side handler: holds the first request and, when the
    // second arrives, replies to the second and then the first. See the
    // NOTE(review) on the class about whether this is ever invoked.
    void requestReceived(ResponseChannelRequest::UniquePtr rcr) {
      auto req = dynamic_cast<HeaderServerChannel::HeaderRequest*>(rcr.get());
      c_->request_++;
      c_->requestBytes_ += req->getBuf()->computeChainDataLength();
      if (c_->firstbuf_) {
        req->sendReply(ResponsePayload::create(req->extractBuf()));
        auto firstbuf = dynamic_cast<HeaderServerChannel::HeaderRequest*>(
            c_->firstbuf_.get());
        c_->firstbuf_->sendReply(
            ResponsePayload::create(firstbuf->extractBuf()));
      } else {
        c_->firstbuf_ = std::move(rcr);
      }
    }

   private:
    InOrderTest* c_;
  };

  void preLoop() override {
    TestRequestCallback::reset();
    channel1_->setCallback(this);
    RpcOptions options;
    // First request: enveloped length len_.
    channel0_->sendRequestResponse(
        options,
        "test",
        makeTestSerializedRequest(len_),
        std::unique_ptr<THeader>(new THeader),
        RequestClientCallback::Ptr(new Callback(this)));
    // Second request: enveloped length len_ + 1, so replies are
    // distinguishable by size.
    channel0_->sendRequestResponse(
        options,
        "test",
        makeTestSerializedRequest(len_ + 1),
        std::unique_ptr<THeader>(new THeader),
        RequestClientCallback::Ptr(new Callback(this)));
  }

  void postLoop() override {
    EXPECT_EQ(reply_, 2);
    EXPECT_EQ(replyError_, 0);
    EXPECT_EQ(replyBytes_, 2 * len_ + 1);
    EXPECT_EQ(closed_, false);
    EXPECT_EQ(serverClosed_, false);
    EXPECT_EQ(request_, 2);
    EXPECT_EQ(requestBytes_, 2 * len_ + 1);
    EXPECT_EQ(oneway_, 0);
  }

 private:
  // Request held back by Callback::requestReceived (see NOTE(review) above).
  ResponseChannelRequest::UniquePtr firstbuf_;
  size_t len_;
};
674 
TEST(Channel,InOrderTest)675 TEST(Channel, InOrderTest) {
676   InOrderTest().run();
677 }
678 
TEST(Channel,InOrderTestSSL)679 TEST(Channel, InOrderTestSSL) {
680   InOrderTest::Config socketConfig;
681   socketConfig.ssl = true;
682   InOrderTest(socketConfig).run();
683 }
684 
685 class BadSeqIdTest
686     : public SocketPairTest<HeaderClientChannel, HeaderServerChannel>,
687       public TestRequestCallback,
688       public ResponseCallback {
689  public:
BadSeqIdTest(size_t len,Config socketConfig=Config ())690   explicit BadSeqIdTest(size_t len, Config socketConfig = Config())
691       : SocketPairTest(socketConfig), len_(len) {}
692 
693   class Callback : public TestRequestCallback {
694    public:
Callback(BadSeqIdTest * c,bool oneWay)695     Callback(BadSeqIdTest* c, bool oneWay)
696         : TestRequestCallback(oneWay), c_(c) {}
697 
onResponseError(folly::exception_wrapper ew)698     void onResponseError(folly::exception_wrapper ew) noexcept override {
699       c_->channel1_->setCallback(nullptr);
700       TestRequestCallback::onResponseError(std::move(ew));
701     }
702 
703    private:
704     BadSeqIdTest* c_;
705   };
706 
requestReceived(unique_ptr<HeaderServerChannel::HeaderRequest> && req)707   void requestReceived(
708       unique_ptr<HeaderServerChannel::HeaderRequest>&& req) override {
709     request_++;
710     requestBytes_ += req->getBuf()->computeChainDataLength();
711     if (req->isOneway()) {
712       oneway_++;
713       return;
714     }
715     unique_ptr<THeader> header(new THeader);
716     header->setSequenceNumber(-1);
717     HeaderServerChannel::HeaderRequest r(
718         channel1_.get(), req->extractBuf(), std::move(header), {});
719     r.sendReply(ResponsePayload::create(r.extractBuf()));
720   }
721 
preLoop()722   void preLoop() override {
723     TestRequestCallback::reset();
724     channel0_->setTimeout(1000);
725     channel1_->setCallback(this);
726     RpcOptions options;
727     channel0_->sendRequestNoResponse(
728         options,
729         "test",
730         makeTestSerializedRequest(len_),
731         std::unique_ptr<THeader>(new THeader),
732         RequestClientCallback::Ptr(new Callback(this, true)));
733     channel0_->sendRequestResponse(
734         options,
735         "test",
736         makeTestSerializedRequest(len_),
737         std::unique_ptr<THeader>(new THeader),
738         RequestClientCallback::Ptr(new Callback(this, false)));
739   }
740 
postLoop()741   void postLoop() override {
742     EXPECT_EQ(reply_, 0);
743     EXPECT_EQ(replyError_, 1);
744     EXPECT_EQ(replyBytes_, 0);
745     EXPECT_EQ(closed_, false);
746     EXPECT_EQ(serverClosed_, false);
747     EXPECT_EQ(request_, 2);
748     EXPECT_EQ(requestBytes_, len_ * 2);
749     EXPECT_EQ(oneway_, 1);
750   }
751 
752  private:
753   size_t len_;
754 };
755 
TEST(Channel,BadSeqIdTest)756 TEST(Channel, BadSeqIdTest) {
757   BadSeqIdTest(10).run();
758 }
759 
TEST(Channel,BadSeqIdTestSSL)760 TEST(Channel, BadSeqIdTestSSL) {
761   BadSeqIdTest::Config socketConfig;
762   socketConfig.ssl = true;
763   BadSeqIdTest(10, socketConfig).run();
764 }
765 
766 class TimeoutTest
767     : public SocketPairTest<HeaderClientChannel, HeaderServerChannel>,
768       public TestRequestCallback,
769       public ResponseCallback {
770  public:
TimeoutTest(uint32_t timeout,Config socketConfig=Config ())771   explicit TimeoutTest(uint32_t timeout, Config socketConfig = Config())
772       : SocketPairTest(socketConfig), timeout_(timeout), len_(10) {}
773 
preLoop()774   void preLoop() override {
775     TestRequestCallback::reset();
776     channel1_->setCallback(this);
777     channel0_->setTimeout(timeout_);
778     channel0_->setCloseCallback(this);
779     RpcOptions options;
780     channel0_->sendRequestResponse(
781         options,
782         "test",
783         makeTestSerializedRequest(len_),
784         std::unique_ptr<THeader>(new THeader),
785         RequestClientCallback::Ptr(new TestRequestCallback()));
786     channel0_->sendRequestResponse(
787         options,
788         "test",
789         makeTestSerializedRequest(len_),
790         std::unique_ptr<THeader>(new THeader),
791         RequestClientCallback::Ptr(new TestRequestCallback()));
792   }
793 
postLoop()794   void postLoop() override {
795     EXPECT_EQ(reply_, 0);
796     EXPECT_EQ(replyError_, 2);
797     EXPECT_EQ(replyBytes_, 0);
798     EXPECT_EQ(closed_, false); // client timeouts do not close connection
799     EXPECT_EQ(serverClosed_, false);
800     EXPECT_EQ(request_, 2);
801     EXPECT_EQ(requestBytes_, len_ * 2);
802     EXPECT_EQ(oneway_, 0);
803     channel0_->setCloseCallback(nullptr);
804     channel1_->setCallback(nullptr);
805   }
806 
requestReceived(unique_ptr<HeaderServerChannel::HeaderRequest> && req)807   void requestReceived(
808       unique_ptr<HeaderServerChannel::HeaderRequest>&& req) override {
809     request_++;
810     requestBytes_ += req->getBuf()->computeChainDataLength();
811     // Don't respond, let it time out
812     // TestRequestCallback::replyReceived(std::move(buf));
813     channel1_->getEventBase()->tryRunAfterDelay(
814         [&]() {
815           channel1_->setCallback(nullptr);
816           channel0_->setCloseCallback(nullptr);
817         },
818         timeout_ * 2); // enough time for server socket to close also
819   }
820 
821  private:
822   uint32_t timeout_;
823   size_t len_;
824 };
825 
TEST(Channel,TimeoutTest)826 TEST(Channel, TimeoutTest) {
827   TimeoutTest(25).run();
828   TimeoutTest(100).run();
829   TimeoutTest(250).run();
830 }
831 
TEST(Channel,TimeoutTestSSL)832 TEST(Channel, TimeoutTestSSL) {
833   TimeoutTest::Config socketConfig;
834   socketConfig.ssl = true;
835   TimeoutTest(25, socketConfig).run();
836   TimeoutTest(100, socketConfig).run();
837   TimeoutTest(250, socketConfig).run();
838 }
839 
840 // Test client per-call timeout options
841 class OptionsTimeoutTest
842     : public SocketPairTest<HeaderClientChannel, HeaderServerChannel>,
843       public TestRequestCallback,
844       public ResponseCallback {
845  public:
OptionsTimeoutTest(Config socketConfig=Config ())846   explicit OptionsTimeoutTest(Config socketConfig = Config())
847       : SocketPairTest(socketConfig), len_(10) {}
848 
preLoop()849   void preLoop() override {
850     TestRequestCallback::reset();
851     channel1_->setCallback(this);
852     channel0_->setTimeout(1000);
853     RpcOptions options;
854     options.setTimeout(std::chrono::milliseconds(25));
855     channel0_->sendRequestResponse(
856         options,
857         "test",
858         makeTestSerializedRequest(len_),
859         std::unique_ptr<THeader>(new THeader),
860         RequestClientCallback::Ptr(new TestRequestCallback()));
861     // Verify the timeout worked within 10ms
862     channel0_->getEventBase()->tryRunAfterDelay(
863         [&]() { EXPECT_EQ(replyError_, 1); }, 35);
864     // Verify that subsequent successful requests don't delay timeout
865     channel0_->getEventBase()->tryRunAfterDelay(
866         [&]() {
867           RpcOptions options;
868           channel0_->sendRequestResponse(
869               options,
870               "test",
871               makeTestSerializedRequest(len_),
872               std::unique_ptr<THeader>(new THeader),
873               RequestClientCallback::Ptr(new TestRequestCallback()));
874         },
875         20);
876   }
877 
postLoop()878   void postLoop() override {
879     EXPECT_EQ(reply_, 1);
880     EXPECT_EQ(replyError_, 1);
881     EXPECT_EQ(replyBytes_, len_);
882     EXPECT_EQ(closed_, false); // client timeouts do not close connection
883     EXPECT_EQ(serverClosed_, false);
884     EXPECT_EQ(request_, 2);
885     EXPECT_EQ(requestBytes_, len_ * 2);
886     EXPECT_EQ(oneway_, 0);
887   }
888 
requestReceived(unique_ptr<HeaderServerChannel::HeaderRequest> && req)889   void requestReceived(
890       unique_ptr<HeaderServerChannel::HeaderRequest>&& req) override {
891     if (request_ == 0) {
892       request_++;
893       requestBytes_ += req->getBuf()->computeChainDataLength();
894     } else {
895       ResponseCallback::requestReceived(std::move(req));
896       channel1_->setCallback(nullptr);
897     }
898   }
899 
900  private:
901   size_t len_;
902 };
903 
TEST(Channel,OptionsTimeoutTest)904 TEST(Channel, OptionsTimeoutTest) {
905   OptionsTimeoutTest().run();
906 }
907 
TEST(Channel,OptionsTimeoutTestSSL)908 TEST(Channel, OptionsTimeoutTestSSL) {
909   OptionsTimeoutTest::Config socketConfig;
910   socketConfig.ssl = true;
911   OptionsTimeoutTest(socketConfig).run();
912 }
913 
914 class ClientCloseTest
915     : public SocketPairTest<HeaderClientChannel, HeaderServerChannel>,
916       public TestRequestCallback,
917       public ResponseCallback {
918  public:
ClientCloseTest(bool halfClose)919   explicit ClientCloseTest(bool halfClose) : halfClose_(halfClose) {}
920 
preLoop()921   void preLoop() override {
922     TestRequestCallback::reset();
923     channel1_->setCallback(this);
924     channel0_->setCloseCallback(this);
925     if (halfClose_) {
926       channel1_->getEventBase()->tryRunAfterDelay(
927           [&]() { channel1_->getTransport()->shutdownWrite(); }, 10);
928     } else {
929       channel1_->getEventBase()->tryRunAfterDelay(
930           [&]() { channel1_->getTransport()->close(); }, 10);
931     }
932     channel1_->getEventBase()->tryRunAfterDelay(
933         [&]() { channel1_->setCallback(nullptr); }, 20);
934     channel0_->getEventBase()->tryRunAfterDelay(
935         [&]() { channel0_->setCloseCallback(nullptr); }, 20);
936   }
937 
postLoop()938   void postLoop() override {
939     EXPECT_EQ(reply_, 0);
940     EXPECT_EQ(replyError_, 0);
941     EXPECT_EQ(replyBytes_, 0);
942     EXPECT_EQ(closed_, true);
943     EXPECT_EQ(serverClosed_, !halfClose_);
944     EXPECT_EQ(request_, 0);
945     EXPECT_EQ(requestBytes_, 0);
946     EXPECT_EQ(oneway_, 0);
947   }
948 
949  private:
950   bool halfClose_;
951 };
952 
TEST(Channel,ClientCloseTest)953 TEST(Channel, ClientCloseTest) {
954   ClientCloseTest(true).run();
955   ClientCloseTest(false).run();
956 }
957 
958 class ServerCloseTest
959     : public SocketPairTest<HeaderClientChannel, HeaderServerChannel>,
960       public TestRequestCallback,
961       public ResponseCallback {
962  public:
ServerCloseTest(bool halfClose)963   explicit ServerCloseTest(bool halfClose) : halfClose_(halfClose) {}
964 
preLoop()965   void preLoop() override {
966     TestRequestCallback::reset();
967     channel1_->setCallback(this);
968     channel0_->setCloseCallback(this);
969     if (halfClose_) {
970       channel0_->getEventBase()->tryRunAfterDelay(
971           [&]() { channel0_->getTransport()->shutdownWrite(); }, 10);
972     } else {
973       channel0_->getEventBase()->tryRunAfterDelay(
974           [&]() { channel0_->getTransport()->close(); }, 10);
975     }
976     channel1_->getEventBase()->tryRunAfterDelay(
977         [&]() { channel1_->setCallback(nullptr); }, 20);
978     channel0_->getEventBase()->tryRunAfterDelay(
979         [&]() { channel0_->setCloseCallback(nullptr); }, 20);
980   }
981 
postLoop()982   void postLoop() override {
983     EXPECT_EQ(reply_, 0);
984     EXPECT_EQ(replyError_, 0);
985     EXPECT_EQ(replyBytes_, 0);
986     EXPECT_EQ(closed_, !halfClose_);
987     EXPECT_EQ(serverClosed_, true);
988     EXPECT_EQ(request_, 0);
989     EXPECT_EQ(requestBytes_, 0);
990     EXPECT_EQ(oneway_, 0);
991   }
992 
993  private:
994   bool halfClose_;
995 };
996 
TEST(Channel,ServerCloseTest)997 TEST(Channel, ServerCloseTest) {
998   ServerCloseTest(true).run();
999   ServerCloseTest(false).run();
1000 }
1001 
1002 class ClientCloseOnErrorTest;
1003 class InvalidResponseCallback : public HeaderServerChannel::Callback {
1004  public:
InvalidResponseCallback(ClientCloseOnErrorTest * self)1005   explicit InvalidResponseCallback(ClientCloseOnErrorTest* self)
1006       : self_(self), request_(0), requestBytes_(0) {}
1007 
1008   // configuration
closeSocketInResponse(bool value)1009   InvalidResponseCallback& closeSocketInResponse(bool value) {
1010     closeSocketInResponse_ = value;
1011     return *this;
1012   }
1013 
1014   void requestReceived(
1015       unique_ptr<HeaderServerChannel::HeaderRequest>&& req) override;
channelClosed(folly::exception_wrapper &&)1016   void channelClosed(folly::exception_wrapper&&) override {}
1017 
1018  protected:
1019   ClientCloseOnErrorTest* self_;
1020   uint32_t request_;
1021   uint32_t requestBytes_;
1022 
1023   bool closeSocketInResponse_ = false;
1024 };
1025 
1026 class ClientCloseOnErrorTest
1027     : public SocketPairTest<HeaderClientChannel, HeaderServerChannel>,
1028       public TestRequestCallback,
1029       public InvalidResponseCallback {
1030  public:
ClientCloseOnErrorTest()1031   explicit ClientCloseOnErrorTest() : InvalidResponseCallback(this) {}
1032 
1033   // configuration
forcePendingSend(bool value)1034   ClientCloseOnErrorTest& forcePendingSend(bool value) {
1035     forcePendingSend_ = value;
1036     return *this;
1037   }
1038 
closeSocketInResponse(bool value)1039   ClientCloseOnErrorTest& closeSocketInResponse(bool value) {
1040     InvalidResponseCallback::closeSocketInResponse(value);
1041     return *this;
1042   }
1043 
1044   class Callback : public TestRequestCallback {
1045    public:
Callback(ClientCloseOnErrorTest * c)1046     explicit Callback(ClientCloseOnErrorTest* c) : c_(c) {}
1047 
onResponseError(folly::exception_wrapper ew)1048     void onResponseError(folly::exception_wrapper ew) noexcept override {
1049       // force closing the channel on error
1050       c_->channel0_->closeNow();
1051       TestRequestCallback::onResponseError(std::move(ew));
1052     }
1053 
1054    private:
1055     ClientCloseOnErrorTest* c_;
1056   };
1057 
preLoop()1058   void preLoop() override {
1059     TestRequestCallback::reset();
1060 
1061     reqSize_ = 30;
1062     uint32_t ss = sizeof(reqSize_);
1063     if (forcePendingSend_) {
1064       // make request size big enough to not fit into kernel buffer
1065       getsockopt(getFd1(), SOL_SOCKET, SO_RCVBUF, &reqSize_, &ss);
1066       reqSize_++;
1067     }
1068 
1069     channel1_->setCallback(this);
1070     RpcOptions options;
1071     channel0_->sendRequestResponse(
1072         options,
1073         "test",
1074         makeTestSerializedRequest(10),
1075         std::make_unique<THeader>(),
1076         RequestClientCallback::Ptr(new Callback(this)));
1077     channel0_->sendRequestResponse(
1078         options,
1079         "test",
1080         makeTestSerializedRequest(reqSize_),
1081         std::make_unique<THeader>(),
1082         RequestClientCallback::Ptr(new Callback(this)));
1083   }
1084 
postLoop()1085   void postLoop() override {
1086     EXPECT_EQ(reply_, 0);
1087     EXPECT_EQ(replyError_, 2);
1088     EXPECT_EQ(replyBytes_, 0);
1089     EXPECT_EQ(request_, (forcePendingSend_ ? 1 : 2));
1090     EXPECT_EQ(requestBytes_, 10 + (forcePendingSend_ ? 0 : reqSize_));
1091     channel1_->setCallback(nullptr);
1092   }
1093 
1094  private:
1095   bool forcePendingSend_ = false;
1096   int32_t reqSize_;
1097 };
1098 
requestReceived(unique_ptr<HeaderServerChannel::HeaderRequest> && req)1099 void InvalidResponseCallback::requestReceived(
1100     unique_ptr<HeaderServerChannel::HeaderRequest>&& req) {
1101   request_++;
1102   requestBytes_ += req->getBuf()->computeChainDataLength();
1103   if (closeSocketInResponse_) {
1104     self_->getSocket1()->shutdownWrite();
1105   } else {
1106     write(self_->getFd1(), "SSH-", 4);
1107   }
1108 }
1109 
TEST(Channel,ClientCloseOnErrorTest)1110 TEST(Channel, ClientCloseOnErrorTest) {
1111   ClientCloseOnErrorTest()
1112       .forcePendingSend(false)
1113       .closeSocketInResponse(true)
1114       .run();
1115   ClientCloseOnErrorTest()
1116       .forcePendingSend(false)
1117       .closeSocketInResponse(false)
1118       .run();
1119   ClientCloseOnErrorTest()
1120       .forcePendingSend(true)
1121       .closeSocketInResponse(true)
1122       .run();
1123   ClientCloseOnErrorTest()
1124       .forcePendingSend(true)
1125       .closeSocketInResponse(false)
1126       .run();
1127 }
1128 
1129 class DestroyAsyncTransport : public folly::AsyncTransport {
1130  public:
DestroyAsyncTransport()1131   DestroyAsyncTransport() : cb_(nullptr) {}
setReadCB(folly::AsyncTransport::ReadCallback * callback)1132   void setReadCB(folly::AsyncTransport::ReadCallback* callback) override {
1133     cb_ = callback;
1134   }
getReadCallback() const1135   ReadCallback* getReadCallback() const override {
1136     return dynamic_cast<ReadCallback*>(cb_);
1137   }
write(folly::AsyncTransport::WriteCallback *,const void *,size_t,folly::WriteFlags)1138   void write(
1139       folly::AsyncTransport::WriteCallback*,
1140       const void*,
1141       size_t,
1142       folly::WriteFlags) override {}
writev(folly::AsyncTransport::WriteCallback *,const iovec *,size_t,folly::WriteFlags)1143   void writev(
1144       folly::AsyncTransport::WriteCallback*,
1145       const iovec*,
1146       size_t,
1147       folly::WriteFlags) override {}
writeChain(folly::AsyncTransport::WriteCallback *,std::unique_ptr<folly::IOBuf> &&,folly::WriteFlags)1148   void writeChain(
1149       folly::AsyncTransport::WriteCallback*,
1150       std::unique_ptr<folly::IOBuf>&&,
1151       folly::WriteFlags) override {}
close()1152   void close() override {}
closeNow()1153   void closeNow() override {}
shutdownWrite()1154   void shutdownWrite() override {}
shutdownWriteNow()1155   void shutdownWriteNow() override {}
good() const1156   bool good() const override { return true; }
readable() const1157   bool readable() const override { return false; }
connecting() const1158   bool connecting() const override { return false; }
error() const1159   bool error() const override { return false; }
attachEventBase(folly::EventBase *)1160   void attachEventBase(folly::EventBase*) override {}
detachEventBase()1161   void detachEventBase() override {}
isDetachable() const1162   bool isDetachable() const override { return true; }
getEventBase() const1163   folly::EventBase* getEventBase() const override { return nullptr; }
setSendTimeout(uint32_t)1164   void setSendTimeout(uint32_t /* ms */) override {}
getSendTimeout() const1165   uint32_t getSendTimeout() const override { return 0; }
getLocalAddress(folly::SocketAddress *) const1166   void getLocalAddress(folly::SocketAddress*) const override {}
getPeerAddress(folly::SocketAddress *) const1167   void getPeerAddress(folly::SocketAddress*) const override {}
getAppBytesWritten() const1168   size_t getAppBytesWritten() const override { return 0; }
getRawBytesWritten() const1169   size_t getRawBytesWritten() const override { return 0; }
getAppBytesReceived() const1170   size_t getAppBytesReceived() const override { return 0; }
getRawBytesReceived() const1171   size_t getRawBytesReceived() const override { return 0; }
setEorTracking(bool)1172   void setEorTracking(bool /* track */) override {}
isEorTrackingEnabled() const1173   bool isEorTrackingEnabled() const override { return false; }
1174 
invokeEOF()1175   void invokeEOF() { cb_->readEOF(); }
1176 
1177  private:
1178   folly::AsyncTransport::ReadCallback* cb_;
1179 };
1180 
1181 class DestroyRecvCallback : public MessageChannel::RecvCallback {
1182  public:
1183   typedef std::unique_ptr<Cpp2Channel, folly::DelayedDestruction::Destructor>
1184       ChannelPointer;
DestroyRecvCallback(ChannelPointer && channel)1185   explicit DestroyRecvCallback(ChannelPointer&& channel)
1186       : channel_(std::move(channel)), invocations_(0) {
1187     channel_->setReceiveCallback(this);
1188   }
messageReceived(std::unique_ptr<folly::IOBuf> &&,std::unique_ptr<apache::thrift::transport::THeader> &&)1189   void messageReceived(
1190       std::unique_ptr<folly::IOBuf>&&,
1191       std::unique_ptr<apache::thrift::transport::THeader>&&) override {}
messageChannelEOF()1192   void messageChannelEOF() override {
1193     EXPECT_EQ(invocations_, 0);
1194     invocations_++;
1195     channel_.reset();
1196   }
messageReceiveErrorWrapped(folly::exception_wrapper &&)1197   void messageReceiveErrorWrapped(folly::exception_wrapper&&) override {}
1198 
1199  private:
1200   ChannelPointer channel_;
1201   int invocations_;
1202 };
1203 
TEST(Channel,DestroyInEOF)1204 TEST(Channel, DestroyInEOF) {
1205   auto t = new DestroyAsyncTransport();
1206   auto transport = folly::AsyncTransport::UniquePtr(t);
1207   auto channel = createChannel<Cpp2Channel>(std::move(transport));
1208   DestroyRecvCallback drc(std::move(channel));
1209   t->invokeEOF();
1210 }
1211 
1212 class NullCloseCallback : public CloseCallback {
1213  public:
channelClosed()1214   void channelClosed() override {}
1215 };
1216