1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_
21 #define _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 1
22 
23 #include <thrift/transport/TTransport.h>
24 #include <thrift/transport/TVirtualTransport.h>
25 #include <thrift/TToString.h>
26 #include <zlib.h>
27 
28 struct z_stream_s;
29 
30 namespace apache {
31 namespace thrift {
32 namespace transport {
33 
34 class TZlibTransportException : public TTransportException {
35 public:
TZlibTransportException(int status,const char * msg)36   TZlibTransportException(int status, const char* msg)
37     : TTransportException(TTransportException::INTERNAL_ERROR, errorMessage(status, msg)),
38       zlib_status_(status),
39       zlib_msg_(msg == nullptr ? "(null)" : msg) {}
40 
41   ~TZlibTransportException() noexcept override = default;
42 
getZlibStatus()43   int getZlibStatus() { return zlib_status_; }
getZlibMessage()44   std::string getZlibMessage() { return zlib_msg_; }
45 
errorMessage(int status,const char * msg)46   static std::string errorMessage(int status, const char* msg) {
47     std::string rv = "zlib error: ";
48     if (msg) {
49       rv += msg;
50     } else {
51       rv += "(no message)";
52     }
53     rv += " (status = ";
54     rv += to_string(status);
55     rv += ")";
56     return rv;
57   }
58 
59   int zlib_status_;
60   std::string zlib_msg_;
61 };
62 
63 /**
64  * This transport uses zlib to compress on write and decompress on read
65  *
66  * TODO(dreiss): Don't do an extra copy of the compressed data if
67  *               the underlying transport is TBuffered or TMemory.
68  *
69  */
70 class TZlibTransport : public TVirtualTransport<TZlibTransport> {
71 public:
72   /**
73    * @param transport    The transport to read compressed data from
74    *                     and write compressed data to.
75    * @param urbuf_size   Uncompressed buffer size for reading.
76    * @param crbuf_size   Compressed buffer size for reading.
77    * @param uwbuf_size   Uncompressed buffer size for writing.
78    * @param cwbuf_size   Compressed buffer size for writing.
79    * @param comp_level   Compression level (0=none[fast], 6=default, 9=max[slow]).
80    */
81   TZlibTransport(std::shared_ptr<TTransport> transport,
82                  int urbuf_size = DEFAULT_URBUF_SIZE,
83                  int crbuf_size = DEFAULT_CRBUF_SIZE,
84                  int uwbuf_size = DEFAULT_UWBUF_SIZE,
85                  int cwbuf_size = DEFAULT_CWBUF_SIZE,
86                  int16_t comp_level = Z_DEFAULT_COMPRESSION,
87                  std::shared_ptr<TConfiguration> config = nullptr)
TVirtualTransport(config)88     : TVirtualTransport(config),
89       transport_(transport),
90       urpos_(0),
91       uwpos_(0),
92       input_ended_(false),
93       output_finished_(false),
94       urbuf_size_(urbuf_size),
95       crbuf_size_(crbuf_size),
96       uwbuf_size_(uwbuf_size),
97       cwbuf_size_(cwbuf_size),
98       urbuf_(nullptr),
99       crbuf_(nullptr),
100       uwbuf_(nullptr),
101       cwbuf_(nullptr),
102       rstream_(nullptr),
103       wstream_(nullptr),
104       comp_level_(comp_level) {
105     if (uwbuf_size_ < MIN_DIRECT_DEFLATE_SIZE) {
106       // Have to copy this into a local because of a linking issue.
107       int minimum = MIN_DIRECT_DEFLATE_SIZE;
108       throw TTransportException(TTransportException::BAD_ARGS,
109                                 "TZLibTransport: uncompressed write buffer must be at least"
110                                 + to_string(minimum) + ".");
111     }
112 
113     try {
114       urbuf_ = new uint8_t[urbuf_size];
115       crbuf_ = new uint8_t[crbuf_size];
116       uwbuf_ = new uint8_t[uwbuf_size];
117       cwbuf_ = new uint8_t[cwbuf_size];
118 
119       // Don't call this outside of the constructor.
120       initZlib();
121 
122     } catch (...) {
123       delete[] urbuf_;
124       delete[] crbuf_;
125       delete[] uwbuf_;
126       delete[] cwbuf_;
127       throw;
128     }
129   }
130 
131   // Don't call this outside of the constructor.
132   void initZlib();
133 
134   /**
135    * TZlibTransport destructor.
136    *
137    * Warning: Destroying a TZlibTransport object may discard any written but
138    * unflushed data.  You must explicitly call flush() or finish() to ensure
139    * that data is actually written and flushed to the underlying transport.
140    */
141   ~TZlibTransport() override;
142 
143   bool isOpen() const override;
144   bool peek() override;
145 
open()146   void open() override { transport_->open(); }
147 
close()148   void close() override { transport_->close(); }
149 
150   uint32_t read(uint8_t* buf, uint32_t len);
151 
152   void write(const uint8_t* buf, uint32_t len);
153 
154   void flush() override;
155 
156   /**
157    * Finalize the zlib stream.
158    *
159    * This causes zlib to flush any pending write data and write end-of-stream
160    * information, including the checksum.  Once finish() has been called, no
161    * new data can be written to the stream.
162    */
163   void finish();
164 
165   const uint8_t* borrow(uint8_t* buf, uint32_t* len);
166 
167   void consume(uint32_t len);
168 
169   /**
170    * Verify the checksum at the end of the zlib stream.
171    *
172    * This may only be called after all data has been read.
173    * It verifies the checksum that was written by the finish() call.
174    */
175   void verifyChecksum();
176 
177   /**
178    * TODO(someone_smart): Choose smart defaults.
179    */
180   static const int DEFAULT_URBUF_SIZE = 128;
181   static const int DEFAULT_CRBUF_SIZE = 1024;
182   static const int DEFAULT_UWBUF_SIZE = 128;
183   static const int DEFAULT_CWBUF_SIZE = 1024;
184 
getUnderlyingTransport()185   std::shared_ptr<TTransport> getUnderlyingTransport() const { return transport_; }
186 
187 protected:
188   inline void checkZlibRv(int status, const char* msg);
189   inline void checkZlibRvNothrow(int status, const char* msg);
190   inline int readAvail() const;
191   void flushToTransport(int flush);
192   void flushToZlib(const uint8_t* buf, int len, int flush);
193   bool readFromZlib();
194 
195 protected:
196   // Writes smaller than this are buffered up.
197   // Larger (or equal) writes are dumped straight to zlib.
198   static const uint32_t MIN_DIRECT_DEFLATE_SIZE = 32;
199 
200   std::shared_ptr<TTransport> transport_;
201 
202   int urpos_;
203   int uwpos_;
204 
205   /// True iff zlib has reached the end of the input stream.
206   bool input_ended_;
207   /// True iff we have finished the output stream.
208   bool output_finished_;
209 
210   uint32_t urbuf_size_;
211   uint32_t crbuf_size_;
212   uint32_t uwbuf_size_;
213   uint32_t cwbuf_size_;
214 
215   uint8_t* urbuf_;
216   uint8_t* crbuf_;
217   uint8_t* uwbuf_;
218   uint8_t* cwbuf_;
219 
220   struct z_stream_s* rstream_;
221   struct z_stream_s* wstream_;
222 
223   const int comp_level_;
224 };
225 
226 /**
227  * Wraps a transport into a zlibbed one.
228  *
229  */
230 class TZlibTransportFactory : public TTransportFactory {
231 public:
232   TZlibTransportFactory() = default;
233 
234   /**
235    * Wraps a transport factory into a zlibbed one.
236    */
237   TZlibTransportFactory(std::shared_ptr<TTransportFactory> transportFactory);
238 
239   ~TZlibTransportFactory() override = default;
240 
241   std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override;
242 
243 protected:
244   std::shared_ptr<TTransportFactory> transportFactory_;
245 };
246 
247 }
248 }
249 } // apache::thrift::transport
250 
251 #endif // #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_
252