1 /* 2 * Copyright (c) Facebook, Inc. and its affiliates. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 18 #define _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 1 19 20 #include <boost/lexical_cast.hpp> 21 #include <thrift/lib/cpp/transport/TBufferTransports.h> 22 #include <thrift/lib/cpp/transport/TVirtualTransport.h> 23 24 struct z_stream_s; 25 26 namespace apache { 27 namespace thrift { 28 namespace transport { 29 30 class TZlibTransportException : public TTransportException { 31 public: TZlibTransportException(int status,const char * msg)32 TZlibTransportException(int status, const char* msg) 33 : TTransportException( 34 TTransportException::INTERNAL_ERROR, errorMessage(status, msg)), 35 zlib_status_(status), 36 zlib_msg_(msg == nullptr ? "(null)" : msg) {} 37 throw()38 ~TZlibTransportException() throw() override {} 39 getZlibStatus()40 int getZlibStatus() { return zlib_status_; } getZlibMessage()41 std::string getZlibMessage() { return zlib_msg_; } 42 errorMessage(int status,const char * msg)43 static std::string errorMessage(int status, const char* msg) { 44 std::string rv = "zlib error: "; 45 if (msg) { 46 rv += msg; 47 } else { 48 rv += "(no message)"; 49 } 50 rv += " (status = "; 51 rv += boost::lexical_cast<std::string>(status); 52 rv += ")"; 53 return rv; 54 } 55 56 int zlib_status_; 57 std::string zlib_msg_; 58 }; 59 60 /** 61 * This transport uses zlib's compressed format on the "far" side. 62 * 63 * There are two kinds of TZlibTransport objects: 64 * - Standalone objects are used to encode self-contained chunks of data 65 * (like structures). They include checksums. 66 * - Non-standalone transports are used for RPC. They are not implemented yet. 67 * 68 * TODO(dreiss): Don't do an extra copy of the compressed data if 69 * the underlying transport is TBuffered or TMemory. 70 * 71 */ 72 class TZlibTransport : public TVirtualTransport<TZlibTransport> { 73 public: 74 /** 75 * @param transport The transport to read compressed data from 76 * and write compressed data to. 77 * @param urbuf_size Uncompressed buffer size for reading. 78 * @param crbuf_size Compressed buffer size for reading. 79 * @param uwbuf_size Uncompressed buffer size for writing. 80 * @param cwbuf_size Compressed buffer size for writing. 81 * 82 * TODO(dreiss): Write a constructor that isn't a pain. 83 */ 84 explicit TZlibTransport( 85 std::shared_ptr<TTransport> transport, 86 size_t urbuf_size = DEFAULT_URBUF_SIZE, 87 size_t crbuf_size = DEFAULT_CRBUF_SIZE, 88 size_t uwbuf_size = DEFAULT_UWBUF_SIZE, 89 size_t cwbuf_size = DEFAULT_CWBUF_SIZE) transport_(transport)90 : transport_(transport), 91 urpos_(0), 92 uwpos_(0), 93 input_ended_(false), 94 output_finished_(false), 95 urbuf_size_(urbuf_size), 96 crbuf_size_(crbuf_size), 97 uwbuf_size_(uwbuf_size), 98 cwbuf_size_(cwbuf_size), 99 urbuf_(nullptr), 100 crbuf_(nullptr), 101 uwbuf_(nullptr), 102 cwbuf_(nullptr), 103 rstream_(nullptr), 104 wstream_(nullptr) { 105 if (uwbuf_size_ < MIN_DIRECT_DEFLATE_SIZE) { 106 // Have to copy this into a local because of a linking issue. 107 int minimum = MIN_DIRECT_DEFLATE_SIZE; 108 throw TTransportException( 109 TTransportException::BAD_ARGS, 110 "TZLibTransport: uncompressed write buffer must be at least" + 111 boost::lexical_cast<std::string>(minimum) + "."); 112 } 113 114 try { 115 urbuf_ = new uint8_t[urbuf_size]; 116 crbuf_ = new uint8_t[crbuf_size]; 117 uwbuf_ = new uint8_t[uwbuf_size]; 118 cwbuf_ = new uint8_t[cwbuf_size]; 119 120 // Don't call this outside of the constructor. 121 initZlib(); 122 123 } catch (...) { 124 delete[] urbuf_; 125 delete[] crbuf_; 126 delete[] uwbuf_; 127 delete[] cwbuf_; 128 throw; 129 } 130 } 131 132 // Don't call this outside of the constructor. 133 void initZlib(); 134 135 /** 136 * TZlibTransport destructor. 137 * 138 * Warning: Destroying a TZlibTransport object may discard any written but 139 * unflushed data. You must explicitly call flush() or finish() to ensure 140 * that data is actually written and flushed to the underlying transport. 141 */ 142 ~TZlibTransport() override; 143 144 bool isOpen() override; 145 bool peek() override; 146 open()147 void open() override { transport_->open(); } 148 close()149 void close() override { transport_->close(); } 150 151 uint32_t read(uint8_t* buf, uint32_t len); 152 153 void write(const uint8_t* buf, uint32_t len); 154 155 void flush() override; 156 157 /** 158 * Finalize the zlib stream. 159 * 160 * This causes zlib to flush any pending write data and write end-of-stream 161 * information, including the checksum. Once finish() has been called, no 162 * new data can be written to the stream. 163 */ 164 void finish(); 165 166 const uint8_t* borrow(uint8_t* buf, uint32_t* len); 167 168 void consume(uint32_t len); 169 170 /** 171 * Verify the checksum at the end of the zlib stream. 172 * 173 * This may only be called after all data has been read. 174 * It verifies the checksum that was written by the finish() call. 175 */ 176 void verifyChecksum(); 177 178 /** 179 * TODO(someone_smart): Choose smart defaults. 180 */ 181 static const size_t DEFAULT_URBUF_SIZE = 128; 182 static const size_t DEFAULT_CRBUF_SIZE = 1024; 183 static const size_t DEFAULT_UWBUF_SIZE = 128; 184 static const size_t DEFAULT_CWBUF_SIZE = 1024; 185 186 protected: 187 inline void checkZlibRv(int status, const char* msg); 188 inline void checkZlibRvNothrow(int status, const char* msg); 189 inline size_t readAvail(); 190 void flushToTransport(int flush); 191 void flushToZlib(const uint8_t* buf, int len, int flush); 192 bool readFromZlib(); 193 194 protected: 195 // Writes smaller than this are buffered up. 196 // Larger (or equal) writes are dumped straight to zlib. 197 static const size_t MIN_DIRECT_DEFLATE_SIZE = 32; 198 199 std::shared_ptr<TTransport> transport_; 200 201 size_t urpos_; 202 size_t uwpos_; 203 204 /// True iff zlib has reached the end of the input stream. 205 bool input_ended_; 206 /// True iff we have finished the output stream. 207 bool output_finished_; 208 209 size_t urbuf_size_; 210 size_t crbuf_size_; 211 size_t uwbuf_size_; 212 size_t cwbuf_size_; 213 214 uint8_t* urbuf_; 215 uint8_t* crbuf_; 216 uint8_t* uwbuf_; 217 uint8_t* cwbuf_; 218 219 struct z_stream_s* rstream_; 220 struct z_stream_s* wstream_; 221 }; 222 223 /** 224 * Wraps a transport into a zlibbed one. 225 * 226 */ 227 class TZlibTransportFactory : public TTransportFactory { 228 public: TZlibTransportFactory()229 TZlibTransportFactory() {} 230 ~TZlibTransportFactory()231 ~TZlibTransportFactory() override {} 232 getTransport(std::shared_ptr<TTransport> trans)233 std::shared_ptr<TTransport> getTransport( 234 std::shared_ptr<TTransport> trans) override { 235 return std::shared_ptr<TTransport>(new TZlibTransport(trans)); 236 } 237 }; 238 239 /** 240 * Wraps a transport into a framed, zlibbed one. 241 */ 242 class TFramedZlibTransportFactory : public TTransportFactory { 243 public: TFramedZlibTransportFactory()244 TFramedZlibTransportFactory() {} 245 ~TFramedZlibTransportFactory()246 ~TFramedZlibTransportFactory() override {} 247 getTransport(std::shared_ptr<TTransport> trans)248 std::shared_ptr<TTransport> getTransport( 249 std::shared_ptr<TTransport> trans) override { 250 std::shared_ptr<TTransport> framedTransport(new TFramedTransport(trans)); 251 return std::shared_ptr<TTransport>(new TZlibTransport(framedTransport)); 252 } 253 }; 254 255 } // namespace transport 256 } // namespace thrift 257 } // namespace apache 258 259 #endif // #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 260