1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef THRIFT_ASYNC_TZLIBASYNCCHANNEL_H_
18 #define THRIFT_ASYNC_TZLIBASYNCCHANNEL_H_ 1
19 
20 #include <folly/io/async/AsyncTransport.h>
21 #include <thrift/lib/cpp/async/TAsyncEventChannel.h>
22 #include <thrift/lib/cpp/transport/TZlibTransport.h>
23 
24 namespace apache {
25 namespace thrift {
26 namespace async {
27 
28 class TZlibAsyncChannel : public TAsyncEventChannel {
29  public:
30   explicit TZlibAsyncChannel(
31       const std::shared_ptr<TAsyncEventChannel>& channel);
32 
33   /**
34    * Helper function to create a shared_ptr<TZlibAsyncChannel>.
35    *
36    * This passes in the correct destructor object, since TZlibAsyncChannel's
37    * destructor is protected and cannot be invoked directly.
38    */
newChannel(const std::shared_ptr<TAsyncEventChannel> & channel)39   static std::shared_ptr<TZlibAsyncChannel> newChannel(
40       const std::shared_ptr<TAsyncEventChannel>& channel) {
41     return std::shared_ptr<TZlibAsyncChannel>(
42         new TZlibAsyncChannel(channel), Destructor());
43   }
readable()44   bool readable() const override { return channel_->readable(); }
good()45   bool good() const override { return channel_->good(); }
error()46   bool error() const override { return channel_->error(); }
timedOut()47   bool timedOut() const override { return channel_->timedOut(); }
isIdle()48   bool isIdle() const override { return channel_->isIdle(); }
49 
50   void sendMessage(
51       const VoidCallback& cob,
52       const VoidCallback& errorCob,
53       transport::TMemoryBuffer* message) override;
54   void recvMessage(
55       const VoidCallback& cob,
56       const VoidCallback& errorCob,
57       transport::TMemoryBuffer* message) override;
58   void sendAndRecvMessage(
59       const VoidCallback& cob,
60       const VoidCallback& errorCob,
61       transport::TMemoryBuffer* sendBuf,
62       transport::TMemoryBuffer* recvBuf) override;
63 
getTransport()64   std::shared_ptr<folly::AsyncTransport> getTransport() override {
65     return channel_->getTransport();
66   }
67 
attachEventBase(folly::EventBase * eventBase)68   void attachEventBase(folly::EventBase* eventBase) override {
69     channel_->attachEventBase(eventBase);
70   }
detachEventBase()71   void detachEventBase() override { channel_->detachEventBase(); }
72 
getRecvTimeout()73   uint32_t getRecvTimeout() const override {
74     return channel_->getRecvTimeout();
75   }
76 
setRecvTimeout(uint32_t milliseconds)77   void setRecvTimeout(uint32_t milliseconds) override {
78     channel_->setRecvTimeout(milliseconds);
79   }
80 
cancelCallbacks()81   void cancelCallbacks() override {
82     sendRequest_.cancelCallbacks();
83     recvRequest_.cancelCallbacks();
84   }
85 
86  protected:
87   /**
88    * Protected destructor.
89    *
90    * Users of TZlibAsyncChannel must never delete it directly.  Instead,
91    * invoke destroy().
92    */
~TZlibAsyncChannel()93   ~TZlibAsyncChannel() override {}
94 
95  private:
96   class SendRequest {
97    public:
98     SendRequest();
99 
isSet()100     bool isSet() const { return static_cast<bool>(callback_); }
101 
102     void set(
103         const VoidCallback& callback,
104         const VoidCallback& errorCallback,
105         transport::TMemoryBuffer* message);
106 
107     void send(TAsyncEventChannel* channel);
108 
cancelCallbacks()109     void cancelCallbacks() {
110       callback_ = VoidCallback();
111       errorCallback_ = VoidCallback();
112     }
113 
114    private:
115     void invokeCallback(VoidCallback callback);
116     void sendSuccess();
117     void sendError();
118 
119     std::shared_ptr<transport::TMemoryBuffer> compressedBuffer_;
120     transport::TZlibTransport zlibTransport_;
121     VoidCallback sendSuccess_;
122     VoidCallback sendError_;
123 
124     VoidCallback callback_;
125     VoidCallback errorCallback_;
126   };
127 
128   class RecvRequest {
129    public:
130     RecvRequest();
131 
isSet()132     bool isSet() const { return static_cast<bool>(callback_); }
133 
134     void set(
135         const VoidCallback& callback,
136         const VoidCallback& errorCallback,
137         transport::TMemoryBuffer* message);
138 
139     void recv(TAsyncEventChannel* channel);
140 
cancelCallbacks()141     void cancelCallbacks() {
142       callback_ = VoidCallback();
143       errorCallback_ = VoidCallback();
144     }
145 
146    private:
147     void invokeCallback(VoidCallback callback);
148     void recvSuccess();
149     void recvError();
150 
151     std::shared_ptr<transport::TMemoryBuffer> compressedBuffer_;
152     transport::TZlibTransport zlibTransport_;
153     VoidCallback recvSuccess_;
154     VoidCallback recvError_;
155 
156     VoidCallback callback_;
157     VoidCallback errorCallback_;
158     transport::TMemoryBuffer* callbackBuffer_;
159   };
160 
161   std::shared_ptr<TAsyncEventChannel> channel_;
162 
163   // TODO: support multiple pending send requests
164   SendRequest sendRequest_;
165   RecvRequest recvRequest_;
166 };
167 
168 } // namespace async
169 } // namespace thrift
170 } // namespace apache
171 
172 #endif // THRIFT_ASYNC_TZLIBASYNCCHANNEL_H_
173