1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 21 #define _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 1 22 23 #include <thrift/transport/TTransport.h> 24 #include <thrift/transport/TVirtualTransport.h> 25 #include <thrift/TToString.h> 26 #include <zlib.h> 27 28 struct z_stream_s; 29 30 namespace apache { 31 namespace thrift { 32 namespace transport { 33 34 class TZlibTransportException : public TTransportException { 35 public: TZlibTransportException(int status,const char * msg)36 TZlibTransportException(int status, const char* msg) 37 : TTransportException(TTransportException::INTERNAL_ERROR, errorMessage(status, msg)), 38 zlib_status_(status), 39 zlib_msg_(msg == nullptr ? "(null)" : msg) {} 40 41 ~TZlibTransportException() noexcept override = default; 42 getZlibStatus()43 int getZlibStatus() { return zlib_status_; } getZlibMessage()44 std::string getZlibMessage() { return zlib_msg_; } 45 errorMessage(int status,const char * msg)46 static std::string errorMessage(int status, const char* msg) { 47 std::string rv = "zlib error: "; 48 if (msg) { 49 rv += msg; 50 } else { 51 rv += "(no message)"; 52 } 53 rv += " (status = "; 54 rv += to_string(status); 55 rv += ")"; 56 return rv; 57 } 58 59 int zlib_status_; 60 std::string zlib_msg_; 61 }; 62 63 /** 64 * This transport uses zlib to compress on write and decompress on read 65 * 66 * TODO(dreiss): Don't do an extra copy of the compressed data if 67 * the underlying transport is TBuffered or TMemory. 68 * 69 */ 70 class TZlibTransport : public TVirtualTransport<TZlibTransport> { 71 public: 72 /** 73 * @param transport The transport to read compressed data from 74 * and write compressed data to. 75 * @param urbuf_size Uncompressed buffer size for reading. 76 * @param crbuf_size Compressed buffer size for reading. 77 * @param uwbuf_size Uncompressed buffer size for writing. 78 * @param cwbuf_size Compressed buffer size for writing. 79 * @param comp_level Compression level (0=none[fast], 6=default, 9=max[slow]). 80 */ 81 TZlibTransport(std::shared_ptr<TTransport> transport, 82 int urbuf_size = DEFAULT_URBUF_SIZE, 83 int crbuf_size = DEFAULT_CRBUF_SIZE, 84 int uwbuf_size = DEFAULT_UWBUF_SIZE, 85 int cwbuf_size = DEFAULT_CWBUF_SIZE, 86 int16_t comp_level = Z_DEFAULT_COMPRESSION) transport_(transport)87 : transport_(transport), 88 urpos_(0), 89 uwpos_(0), 90 input_ended_(false), 91 output_finished_(false), 92 urbuf_size_(urbuf_size), 93 crbuf_size_(crbuf_size), 94 uwbuf_size_(uwbuf_size), 95 cwbuf_size_(cwbuf_size), 96 urbuf_(nullptr), 97 crbuf_(nullptr), 98 uwbuf_(nullptr), 99 cwbuf_(nullptr), 100 rstream_(nullptr), 101 wstream_(nullptr), 102 comp_level_(comp_level) { 103 if (uwbuf_size_ < MIN_DIRECT_DEFLATE_SIZE) { 104 // Have to copy this into a local because of a linking issue. 105 int minimum = MIN_DIRECT_DEFLATE_SIZE; 106 throw TTransportException(TTransportException::BAD_ARGS, 107 "TZLibTransport: uncompressed write buffer must be at least" 108 + to_string(minimum) + "."); 109 } 110 111 try { 112 urbuf_ = new uint8_t[urbuf_size]; 113 crbuf_ = new uint8_t[crbuf_size]; 114 uwbuf_ = new uint8_t[uwbuf_size]; 115 cwbuf_ = new uint8_t[cwbuf_size]; 116 117 // Don't call this outside of the constructor. 118 initZlib(); 119 120 } catch (...) { 121 delete[] urbuf_; 122 delete[] crbuf_; 123 delete[] uwbuf_; 124 delete[] cwbuf_; 125 throw; 126 } 127 } 128 129 // Don't call this outside of the constructor. 130 void initZlib(); 131 132 /** 133 * TZlibTransport destructor. 134 * 135 * Warning: Destroying a TZlibTransport object may discard any written but 136 * unflushed data. You must explicitly call flush() or finish() to ensure 137 * that data is actually written and flushed to the underlying transport. 138 */ 139 ~TZlibTransport() override; 140 141 bool isOpen() const override; 142 bool peek() override; 143 open()144 void open() override { transport_->open(); } 145 close()146 void close() override { transport_->close(); } 147 148 uint32_t read(uint8_t* buf, uint32_t len); 149 150 void write(const uint8_t* buf, uint32_t len); 151 152 void flush() override; 153 154 /** 155 * Finalize the zlib stream. 156 * 157 * This causes zlib to flush any pending write data and write end-of-stream 158 * information, including the checksum. Once finish() has been called, no 159 * new data can be written to the stream. 160 */ 161 void finish(); 162 163 const uint8_t* borrow(uint8_t* buf, uint32_t* len); 164 165 void consume(uint32_t len); 166 167 /** 168 * Verify the checksum at the end of the zlib stream. 169 * 170 * This may only be called after all data has been read. 171 * It verifies the checksum that was written by the finish() call. 172 */ 173 void verifyChecksum(); 174 175 /** 176 * TODO(someone_smart): Choose smart defaults. 177 */ 178 static const int DEFAULT_URBUF_SIZE = 128; 179 static const int DEFAULT_CRBUF_SIZE = 1024; 180 static const int DEFAULT_UWBUF_SIZE = 128; 181 static const int DEFAULT_CWBUF_SIZE = 1024; 182 getUnderlyingTransport()183 std::shared_ptr<TTransport> getUnderlyingTransport() const { return transport_; } 184 185 protected: 186 inline void checkZlibRv(int status, const char* msg); 187 inline void checkZlibRvNothrow(int status, const char* msg); 188 inline int readAvail() const; 189 void flushToTransport(int flush); 190 void flushToZlib(const uint8_t* buf, int len, int flush); 191 bool readFromZlib(); 192 193 protected: 194 // Writes smaller than this are buffered up. 195 // Larger (or equal) writes are dumped straight to zlib. 196 static const uint32_t MIN_DIRECT_DEFLATE_SIZE = 32; 197 198 std::shared_ptr<TTransport> transport_; 199 200 int urpos_; 201 int uwpos_; 202 203 /// True iff zlib has reached the end of the input stream. 204 bool input_ended_; 205 /// True iff we have finished the output stream. 206 bool output_finished_; 207 208 uint32_t urbuf_size_; 209 uint32_t crbuf_size_; 210 uint32_t uwbuf_size_; 211 uint32_t cwbuf_size_; 212 213 uint8_t* urbuf_; 214 uint8_t* crbuf_; 215 uint8_t* uwbuf_; 216 uint8_t* cwbuf_; 217 218 struct z_stream_s* rstream_; 219 struct z_stream_s* wstream_; 220 221 const int comp_level_; 222 }; 223 224 /** 225 * Wraps a transport into a zlibbed one. 226 * 227 */ 228 class TZlibTransportFactory : public TTransportFactory { 229 public: 230 TZlibTransportFactory() = default; 231 232 ~TZlibTransportFactory() override = default; 233 getTransport(std::shared_ptr<TTransport> trans)234 std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override { 235 return std::shared_ptr<TTransport>(new TZlibTransport(trans)); 236 } 237 }; 238 } 239 } 240 } // apache::thrift::transport 241 242 #endif // #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 243