1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_
18 #define _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 1
19 
#include <memory>
#include <string>
#include <utility>

#include <boost/lexical_cast.hpp>
#include <thrift/lib/cpp/transport/TBufferTransports.h>
#include <thrift/lib/cpp/transport/TVirtualTransport.h>
23 
// Opaque forward declaration of zlib's stream state.  TZlibTransport only
// holds pointers to it (rstream_ / wstream_), so the zlib header itself does
// not need to be included here.
struct z_stream_s;
25 
26 namespace apache {
27 namespace thrift {
28 namespace transport {
29 
30 class TZlibTransportException : public TTransportException {
31  public:
TZlibTransportException(int status,const char * msg)32   TZlibTransportException(int status, const char* msg)
33       : TTransportException(
34             TTransportException::INTERNAL_ERROR, errorMessage(status, msg)),
35         zlib_status_(status),
36         zlib_msg_(msg == nullptr ? "(null)" : msg) {}
37 
throw()38   ~TZlibTransportException() throw() override {}
39 
getZlibStatus()40   int getZlibStatus() { return zlib_status_; }
getZlibMessage()41   std::string getZlibMessage() { return zlib_msg_; }
42 
errorMessage(int status,const char * msg)43   static std::string errorMessage(int status, const char* msg) {
44     std::string rv = "zlib error: ";
45     if (msg) {
46       rv += msg;
47     } else {
48       rv += "(no message)";
49     }
50     rv += " (status = ";
51     rv += boost::lexical_cast<std::string>(status);
52     rv += ")";
53     return rv;
54   }
55 
56   int zlib_status_;
57   std::string zlib_msg_;
58 };
59 
60 /**
61  * This transport uses zlib's compressed format on the "far" side.
62  *
63  * There are two kinds of TZlibTransport objects:
64  * - Standalone objects are used to encode self-contained chunks of data
65  *   (like structures).  They include checksums.
66  * - Non-standalone transports are used for RPC.  They are not implemented yet.
67  *
68  * TODO(dreiss): Don't do an extra copy of the compressed data if
69  *               the underlying transport is TBuffered or TMemory.
70  *
71  */
72 class TZlibTransport : public TVirtualTransport<TZlibTransport> {
73  public:
74   /**
75    * @param transport    The transport to read compressed data from
76    *                     and write compressed data to.
77    * @param urbuf_size   Uncompressed buffer size for reading.
78    * @param crbuf_size   Compressed buffer size for reading.
79    * @param uwbuf_size   Uncompressed buffer size for writing.
80    * @param cwbuf_size   Compressed buffer size for writing.
81    *
82    * TODO(dreiss): Write a constructor that isn't a pain.
83    */
84   explicit TZlibTransport(
85       std::shared_ptr<TTransport> transport,
86       size_t urbuf_size = DEFAULT_URBUF_SIZE,
87       size_t crbuf_size = DEFAULT_CRBUF_SIZE,
88       size_t uwbuf_size = DEFAULT_UWBUF_SIZE,
89       size_t cwbuf_size = DEFAULT_CWBUF_SIZE)
transport_(transport)90       : transport_(transport),
91         urpos_(0),
92         uwpos_(0),
93         input_ended_(false),
94         output_finished_(false),
95         urbuf_size_(urbuf_size),
96         crbuf_size_(crbuf_size),
97         uwbuf_size_(uwbuf_size),
98         cwbuf_size_(cwbuf_size),
99         urbuf_(nullptr),
100         crbuf_(nullptr),
101         uwbuf_(nullptr),
102         cwbuf_(nullptr),
103         rstream_(nullptr),
104         wstream_(nullptr) {
105     if (uwbuf_size_ < MIN_DIRECT_DEFLATE_SIZE) {
106       // Have to copy this into a local because of a linking issue.
107       int minimum = MIN_DIRECT_DEFLATE_SIZE;
108       throw TTransportException(
109           TTransportException::BAD_ARGS,
110           "TZLibTransport: uncompressed write buffer must be at least" +
111               boost::lexical_cast<std::string>(minimum) + ".");
112     }
113 
114     try {
115       urbuf_ = new uint8_t[urbuf_size];
116       crbuf_ = new uint8_t[crbuf_size];
117       uwbuf_ = new uint8_t[uwbuf_size];
118       cwbuf_ = new uint8_t[cwbuf_size];
119 
120       // Don't call this outside of the constructor.
121       initZlib();
122 
123     } catch (...) {
124       delete[] urbuf_;
125       delete[] crbuf_;
126       delete[] uwbuf_;
127       delete[] cwbuf_;
128       throw;
129     }
130   }
131 
132   // Don't call this outside of the constructor.
133   void initZlib();
134 
135   /**
136    * TZlibTransport destructor.
137    *
138    * Warning: Destroying a TZlibTransport object may discard any written but
139    * unflushed data.  You must explicitly call flush() or finish() to ensure
140    * that data is actually written and flushed to the underlying transport.
141    */
142   ~TZlibTransport() override;
143 
144   bool isOpen() override;
145   bool peek() override;
146 
open()147   void open() override { transport_->open(); }
148 
close()149   void close() override { transport_->close(); }
150 
151   uint32_t read(uint8_t* buf, uint32_t len);
152 
153   void write(const uint8_t* buf, uint32_t len);
154 
155   void flush() override;
156 
157   /**
158    * Finalize the zlib stream.
159    *
160    * This causes zlib to flush any pending write data and write end-of-stream
161    * information, including the checksum.  Once finish() has been called, no
162    * new data can be written to the stream.
163    */
164   void finish();
165 
166   const uint8_t* borrow(uint8_t* buf, uint32_t* len);
167 
168   void consume(uint32_t len);
169 
170   /**
171    * Verify the checksum at the end of the zlib stream.
172    *
173    * This may only be called after all data has been read.
174    * It verifies the checksum that was written by the finish() call.
175    */
176   void verifyChecksum();
177 
178   /**
179    * TODO(someone_smart): Choose smart defaults.
180    */
181   static const size_t DEFAULT_URBUF_SIZE = 128;
182   static const size_t DEFAULT_CRBUF_SIZE = 1024;
183   static const size_t DEFAULT_UWBUF_SIZE = 128;
184   static const size_t DEFAULT_CWBUF_SIZE = 1024;
185 
186  protected:
187   inline void checkZlibRv(int status, const char* msg);
188   inline void checkZlibRvNothrow(int status, const char* msg);
189   inline size_t readAvail();
190   void flushToTransport(int flush);
191   void flushToZlib(const uint8_t* buf, int len, int flush);
192   bool readFromZlib();
193 
194  protected:
195   // Writes smaller than this are buffered up.
196   // Larger (or equal) writes are dumped straight to zlib.
197   static const size_t MIN_DIRECT_DEFLATE_SIZE = 32;
198 
199   std::shared_ptr<TTransport> transport_;
200 
201   size_t urpos_;
202   size_t uwpos_;
203 
204   /// True iff zlib has reached the end of the input stream.
205   bool input_ended_;
206   /// True iff we have finished the output stream.
207   bool output_finished_;
208 
209   size_t urbuf_size_;
210   size_t crbuf_size_;
211   size_t uwbuf_size_;
212   size_t cwbuf_size_;
213 
214   uint8_t* urbuf_;
215   uint8_t* crbuf_;
216   uint8_t* uwbuf_;
217   uint8_t* cwbuf_;
218 
219   struct z_stream_s* rstream_;
220   struct z_stream_s* wstream_;
221 };
222 
223 /**
224  * Wraps a transport into a zlibbed one.
225  *
226  */
227 class TZlibTransportFactory : public TTransportFactory {
228  public:
TZlibTransportFactory()229   TZlibTransportFactory() {}
230 
~TZlibTransportFactory()231   ~TZlibTransportFactory() override {}
232 
getTransport(std::shared_ptr<TTransport> trans)233   std::shared_ptr<TTransport> getTransport(
234       std::shared_ptr<TTransport> trans) override {
235     return std::shared_ptr<TTransport>(new TZlibTransport(trans));
236   }
237 };
238 
239 /**
240  * Wraps a transport into a framed, zlibbed one.
241  */
242 class TFramedZlibTransportFactory : public TTransportFactory {
243  public:
TFramedZlibTransportFactory()244   TFramedZlibTransportFactory() {}
245 
~TFramedZlibTransportFactory()246   ~TFramedZlibTransportFactory() override {}
247 
getTransport(std::shared_ptr<TTransport> trans)248   std::shared_ptr<TTransport> getTransport(
249       std::shared_ptr<TTransport> trans) override {
250     std::shared_ptr<TTransport> framedTransport(new TFramedTransport(trans));
251     return std::shared_ptr<TTransport>(new TZlibTransport(framedTransport));
252   }
253 };
254 
255 } // namespace transport
256 } // namespace thrift
257 } // namespace apache
258 
259 #endif // #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_
260