1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 21 #define _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 1 22 23 #include <thrift/transport/TTransport.h> 24 #include <thrift/transport/TVirtualTransport.h> 25 #include <thrift/TToString.h> 26 #include <zlib.h> 27 28 struct z_stream_s; 29 30 namespace apache { 31 namespace thrift { 32 namespace transport { 33 34 class TZlibTransportException : public TTransportException { 35 public: TZlibTransportException(int status,const char * msg)36 TZlibTransportException(int status, const char* msg) 37 : TTransportException(TTransportException::INTERNAL_ERROR, errorMessage(status, msg)), 38 zlib_status_(status), 39 zlib_msg_(msg == nullptr ? "(null)" : msg) {} 40 41 ~TZlibTransportException() noexcept override = default; 42 getZlibStatus()43 int getZlibStatus() { return zlib_status_; } getZlibMessage()44 std::string getZlibMessage() { return zlib_msg_; } 45 errorMessage(int status,const char * msg)46 static std::string errorMessage(int status, const char* msg) { 47 std::string rv = "zlib error: "; 48 if (msg) { 49 rv += msg; 50 } else { 51 rv += "(no message)"; 52 } 53 rv += " (status = "; 54 rv += to_string(status); 55 rv += ")"; 56 return rv; 57 } 58 59 int zlib_status_; 60 std::string zlib_msg_; 61 }; 62 63 /** 64 * This transport uses zlib to compress on write and decompress on read 65 * 66 * TODO(dreiss): Don't do an extra copy of the compressed data if 67 * the underlying transport is TBuffered or TMemory. 68 * 69 */ 70 class TZlibTransport : public TVirtualTransport<TZlibTransport> { 71 public: 72 /** 73 * @param transport The transport to read compressed data from 74 * and write compressed data to. 75 * @param urbuf_size Uncompressed buffer size for reading. 76 * @param crbuf_size Compressed buffer size for reading. 77 * @param uwbuf_size Uncompressed buffer size for writing. 78 * @param cwbuf_size Compressed buffer size for writing. 79 * @param comp_level Compression level (0=none[fast], 6=default, 9=max[slow]). 80 */ 81 TZlibTransport(std::shared_ptr<TTransport> transport, 82 int urbuf_size = DEFAULT_URBUF_SIZE, 83 int crbuf_size = DEFAULT_CRBUF_SIZE, 84 int uwbuf_size = DEFAULT_UWBUF_SIZE, 85 int cwbuf_size = DEFAULT_CWBUF_SIZE, 86 int16_t comp_level = Z_DEFAULT_COMPRESSION, 87 std::shared_ptr<TConfiguration> config = nullptr) TVirtualTransport(config)88 : TVirtualTransport(config), 89 transport_(transport), 90 urpos_(0), 91 uwpos_(0), 92 input_ended_(false), 93 output_finished_(false), 94 urbuf_size_(urbuf_size), 95 crbuf_size_(crbuf_size), 96 uwbuf_size_(uwbuf_size), 97 cwbuf_size_(cwbuf_size), 98 urbuf_(nullptr), 99 crbuf_(nullptr), 100 uwbuf_(nullptr), 101 cwbuf_(nullptr), 102 rstream_(nullptr), 103 wstream_(nullptr), 104 comp_level_(comp_level) { 105 if (uwbuf_size_ < MIN_DIRECT_DEFLATE_SIZE) { 106 // Have to copy this into a local because of a linking issue. 107 int minimum = MIN_DIRECT_DEFLATE_SIZE; 108 throw TTransportException(TTransportException::BAD_ARGS, 109 "TZLibTransport: uncompressed write buffer must be at least" 110 + to_string(minimum) + "."); 111 } 112 113 try { 114 urbuf_ = new uint8_t[urbuf_size]; 115 crbuf_ = new uint8_t[crbuf_size]; 116 uwbuf_ = new uint8_t[uwbuf_size]; 117 cwbuf_ = new uint8_t[cwbuf_size]; 118 119 // Don't call this outside of the constructor. 120 initZlib(); 121 122 } catch (...) { 123 delete[] urbuf_; 124 delete[] crbuf_; 125 delete[] uwbuf_; 126 delete[] cwbuf_; 127 throw; 128 } 129 } 130 131 // Don't call this outside of the constructor. 132 void initZlib(); 133 134 /** 135 * TZlibTransport destructor. 136 * 137 * Warning: Destroying a TZlibTransport object may discard any written but 138 * unflushed data. You must explicitly call flush() or finish() to ensure 139 * that data is actually written and flushed to the underlying transport. 140 */ 141 ~TZlibTransport() override; 142 143 bool isOpen() const override; 144 bool peek() override; 145 open()146 void open() override { transport_->open(); } 147 close()148 void close() override { transport_->close(); } 149 150 uint32_t read(uint8_t* buf, uint32_t len); 151 152 void write(const uint8_t* buf, uint32_t len); 153 154 void flush() override; 155 156 /** 157 * Finalize the zlib stream. 158 * 159 * This causes zlib to flush any pending write data and write end-of-stream 160 * information, including the checksum. Once finish() has been called, no 161 * new data can be written to the stream. 162 */ 163 void finish(); 164 165 const uint8_t* borrow(uint8_t* buf, uint32_t* len); 166 167 void consume(uint32_t len); 168 169 /** 170 * Verify the checksum at the end of the zlib stream. 171 * 172 * This may only be called after all data has been read. 173 * It verifies the checksum that was written by the finish() call. 174 */ 175 void verifyChecksum(); 176 177 /** 178 * TODO(someone_smart): Choose smart defaults. 179 */ 180 static const int DEFAULT_URBUF_SIZE = 128; 181 static const int DEFAULT_CRBUF_SIZE = 1024; 182 static const int DEFAULT_UWBUF_SIZE = 128; 183 static const int DEFAULT_CWBUF_SIZE = 1024; 184 getUnderlyingTransport()185 std::shared_ptr<TTransport> getUnderlyingTransport() const { return transport_; } 186 187 protected: 188 inline void checkZlibRv(int status, const char* msg); 189 inline void checkZlibRvNothrow(int status, const char* msg); 190 inline int readAvail() const; 191 void flushToTransport(int flush); 192 void flushToZlib(const uint8_t* buf, int len, int flush); 193 bool readFromZlib(); 194 195 protected: 196 // Writes smaller than this are buffered up. 197 // Larger (or equal) writes are dumped straight to zlib. 198 static const uint32_t MIN_DIRECT_DEFLATE_SIZE = 32; 199 200 std::shared_ptr<TTransport> transport_; 201 202 int urpos_; 203 int uwpos_; 204 205 /// True iff zlib has reached the end of the input stream. 206 bool input_ended_; 207 /// True iff we have finished the output stream. 208 bool output_finished_; 209 210 uint32_t urbuf_size_; 211 uint32_t crbuf_size_; 212 uint32_t uwbuf_size_; 213 uint32_t cwbuf_size_; 214 215 uint8_t* urbuf_; 216 uint8_t* crbuf_; 217 uint8_t* uwbuf_; 218 uint8_t* cwbuf_; 219 220 struct z_stream_s* rstream_; 221 struct z_stream_s* wstream_; 222 223 const int comp_level_; 224 }; 225 226 /** 227 * Wraps a transport into a zlibbed one. 228 * 229 */ 230 class TZlibTransportFactory : public TTransportFactory { 231 public: 232 TZlibTransportFactory() = default; 233 234 /** 235 * Wraps a transport factory into a zlibbed one. 236 */ 237 TZlibTransportFactory(std::shared_ptr<TTransportFactory> transportFactory); 238 239 ~TZlibTransportFactory() override = default; 240 241 std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override; 242 243 protected: 244 std::shared_ptr<TTransportFactory> transportFactory_; 245 }; 246 247 } 248 } 249 } // apache::thrift::transport 250 251 #endif // #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 252