1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_
21 #define _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 1
22 
23 #include <boost/lexical_cast.hpp>
24 #include <thrift/transport/TTransport.h>
25 #include <thrift/transport/TVirtualTransport.h>
26 #include <zlib.h>
27 
28 struct z_stream_s;
29 
30 namespace apache
31 {
32 namespace thrift
33 {
34 namespace transport
35 {
36 
37 class TZlibTransportException : public TTransportException
38 {
39 public:
TZlibTransportException(int status,const char * msg)40     TZlibTransportException(int status, const char* msg) :
41         TTransportException(TTransportException::INTERNAL_ERROR,
42                             errorMessage(status, msg)),
43         zlib_status_(status),
44         zlib_msg_(msg == NULL ? "(null)" : msg) {}
45 
~TZlibTransportException()46     virtual ~TZlibTransportException() throw() {}
47 
getZlibStatus()48     int getZlibStatus()
49     {
50         return zlib_status_;
51     }
getZlibMessage()52     std::string getZlibMessage()
53     {
54         return zlib_msg_;
55     }
56 
errorMessage(int status,const char * msg)57     static std::string errorMessage(int status, const char* msg)
58     {
59         std::string rv = "zlib error: ";
60 
61         if (msg)
62         {
63             rv += msg;
64         }
65         else
66         {
67             rv += "(no message)";
68         }
69 
70         rv += " (status = ";
71         rv += boost::lexical_cast<std::string>(status);
72         rv += ")";
73         return rv;
74     }
75 
76     int zlib_status_;
77     std::string zlib_msg_;
78 };
79 
80 /**
81  * This transport uses zlib to compress on write and decompress on read
82  *
83  * TODO(dreiss): Don't do an extra copy of the compressed data if
84  *               the underlying transport is TBuffered or TMemory.
85  *
86  */
87 class TZlibTransport : public TVirtualTransport<TZlibTransport>
88 {
89 public:
90 
91     /**
92      * @param transport    The transport to read compressed data from
93      *                     and write compressed data to.
94      * @param urbuf_size   Uncompressed buffer size for reading.
95      * @param crbuf_size   Compressed buffer size for reading.
96      * @param uwbuf_size   Uncompressed buffer size for writing.
97      * @param cwbuf_size   Compressed buffer size for writing.
98      * @param comp_level   Compression level (0=none[fast], 6=default, 9=max[slow]).
99      */
100     TZlibTransport(boost::shared_ptr<TTransport> transport,
101                    int urbuf_size = DEFAULT_URBUF_SIZE,
102                    int crbuf_size = DEFAULT_CRBUF_SIZE,
103                    int uwbuf_size = DEFAULT_UWBUF_SIZE,
104                    int cwbuf_size = DEFAULT_CWBUF_SIZE,
105                    int16_t comp_level = Z_DEFAULT_COMPRESSION) :
transport_(transport)106         transport_(transport),
107         urpos_(0),
108         uwpos_(0),
109         input_ended_(false),
110         output_finished_(false),
111         urbuf_size_(urbuf_size),
112         crbuf_size_(crbuf_size),
113         uwbuf_size_(uwbuf_size),
114         cwbuf_size_(cwbuf_size),
115         urbuf_(NULL),
116         crbuf_(NULL),
117         uwbuf_(NULL),
118         cwbuf_(NULL),
119         rstream_(NULL),
120         wstream_(NULL),
121         comp_level_(comp_level)
122     {
123         if (uwbuf_size_ < MIN_DIRECT_DEFLATE_SIZE)
124         {
125             // Have to copy this into a local because of a linking issue.
126             int minimum = MIN_DIRECT_DEFLATE_SIZE;
127             throw TTransportException(
128                 TTransportException::BAD_ARGS,
129                 "TZLibTransport: uncompressed write buffer must be at least"
130                 + boost::lexical_cast<std::string>(minimum) + ".");
131         }
132 
133         try
134         {
135             urbuf_ = new uint8_t[urbuf_size];
136             crbuf_ = new uint8_t[crbuf_size];
137             uwbuf_ = new uint8_t[uwbuf_size];
138             cwbuf_ = new uint8_t[cwbuf_size];
139 
140             // Don't call this outside of the constructor.
141             initZlib();
142 
143         }
144         catch (...)
145         {
146             delete[] urbuf_;
147             delete[] crbuf_;
148             delete[] uwbuf_;
149             delete[] cwbuf_;
150             throw;
151         }
152     }
153 
154     // Don't call this outside of the constructor.
155     void initZlib();
156 
157     /**
158      * TZlibTransport destructor.
159      *
160      * Warning: Destroying a TZlibTransport object may discard any written but
161      * unflushed data.  You must explicitly call flush() or finish() to ensure
162      * that data is actually written and flushed to the underlying transport.
163      */
164     ~TZlibTransport();
165 
166     bool isOpen();
167     bool peek();
168 
open()169     void open()
170     {
171         transport_->open();
172     }
173 
close()174     void close()
175     {
176         transport_->close();
177     }
178 
179     uint32_t read(uint8_t* buf, uint32_t len);
180 
181     void write(const uint8_t* buf, uint32_t len);
182 
183     void flush();
184 
185     /**
186      * Finalize the zlib stream.
187      *
188      * This causes zlib to flush any pending write data and write end-of-stream
189      * information, including the checksum.  Once finish() has been called, no
190      * new data can be written to the stream.
191      */
192     void finish();
193 
194     const uint8_t* borrow(uint8_t* buf, uint32_t* len);
195 
196     void consume(uint32_t len);
197 
198     /**
199      * Verify the checksum at the end of the zlib stream.
200      *
201      * This may only be called after all data has been read.
202      * It verifies the checksum that was written by the finish() call.
203      */
204     void verifyChecksum();
205 
206     /**
207      * TODO(someone_smart): Choose smart defaults.
208      */
209     static const int DEFAULT_URBUF_SIZE = 128;
210     static const int DEFAULT_CRBUF_SIZE = 1024;
211     static const int DEFAULT_UWBUF_SIZE = 128;
212     static const int DEFAULT_CWBUF_SIZE = 1024;
213 
214 protected:
215 
216     inline void checkZlibRv(int status, const char* msg);
217     inline void checkZlibRvNothrow(int status, const char* msg);
218     inline int readAvail();
219     void flushToTransport(int flush);
220     void flushToZlib(const uint8_t* buf, int len, int flush);
221     bool readFromZlib();
222 
223 protected:
224     // Writes smaller than this are buffered up.
225     // Larger (or equal) writes are dumped straight to zlib.
226     static const uint32_t MIN_DIRECT_DEFLATE_SIZE = 32;
227 
228     boost::shared_ptr<TTransport> transport_;
229 
230     int urpos_;
231     int uwpos_;
232 
233     /// True iff zlib has reached the end of the input stream.
234     bool input_ended_;
235     /// True iff we have finished the output stream.
236     bool output_finished_;
237 
238     uint32_t urbuf_size_;
239     uint32_t crbuf_size_;
240     uint32_t uwbuf_size_;
241     uint32_t cwbuf_size_;
242 
243     uint8_t* urbuf_;
244     uint8_t* crbuf_;
245     uint8_t* uwbuf_;
246     uint8_t* cwbuf_;
247 
248     struct z_stream_s* rstream_;
249     struct z_stream_s* wstream_;
250 
251     const int comp_level_;
252 };
253 
254 
255 /**
256  * Wraps a transport into a zlibbed one.
257  *
258  */
259 class TZlibTransportFactory : public TTransportFactory
260 {
261 public:
TZlibTransportFactory()262     TZlibTransportFactory() {}
263 
~TZlibTransportFactory()264     virtual ~TZlibTransportFactory() {}
265 
getTransport(boost::shared_ptr<TTransport> trans)266     virtual boost::shared_ptr<TTransport> getTransport(
267         boost::shared_ptr<TTransport> trans)
268     {
269         return boost::shared_ptr<TTransport>(new TZlibTransport(trans));
270     }
271 };
272 
273 
274 }
275 }
276 } // apache::thrift::transport
277 
278 #endif // #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_
279