1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 module thrift.transport.zlib;
21
22 import core.bitop : bswap;
23 import etc.c.zlib;
24 import std.algorithm : min;
25 import std.array : empty;
26 import std.conv : to;
27 import std.exception : enforce;
28 import thrift.base;
29 import thrift.transport.base;
30
/**
 * zlib transport. Compresses (deflates) data before writing it to the
 * underlying transport, and decompresses (inflates) it after reading.
 *
 * Four buffers are used: an uncompressed and a compressed buffer for each of
 * the read and the write direction. Small writes are collected in the
 * uncompressed write buffer before being handed to zlib.
 */
final class TZlibTransport : TBaseTransport {
  // These defaults have yet to be optimized.
  enum DEFAULT_URBUF_SIZE = 128;
  enum DEFAULT_CRBUF_SIZE = 1024;
  enum DEFAULT_UWBUF_SIZE = 128;
  enum DEFAULT_CWBUF_SIZE = 1024;

  /**
   * Constructs a new zlib transport.
   *
   * Params:
   *   transport = The underlying transport to wrap.
   *   urbufSize = The size of the uncompressed reading buffer, in bytes.
   *   crbufSize = The size of the compressed reading buffer, in bytes.
   *   uwbufSize = The size of the uncompressed writing buffer, in bytes.
   *     Must be at least MIN_DIRECT_DEFLATE_SIZE.
   *   cwbufSize = The size of the compressed writing buffer, in bytes.
   *
   * Throws: TTransportException (BAD_ARGS) if uwbufSize is too small,
   *   TZlibException if zlib fails to initialize one of its streams.
   */
  this(
    TTransport transport,
    size_t urbufSize = DEFAULT_URBUF_SIZE,
    size_t crbufSize = DEFAULT_CRBUF_SIZE,
    size_t uwbufSize = DEFAULT_UWBUF_SIZE,
    size_t cwbufSize = DEFAULT_CWBUF_SIZE
  ) {
    transport_ = transport;

    // write() buffers every chunk of up to MIN_DIRECT_DEFLATE_SIZE bytes, so
    // the buffer must be able to hold at least one such chunk.
    enforce(uwbufSize >= MIN_DIRECT_DEFLATE_SIZE, new TTransportException(
      "TZLibTransport: uncompressed write buffer must be at least " ~
      to!string(MIN_DIRECT_DEFLATE_SIZE) ~ " bytes in size.",
      TTransportException.Type.BAD_ARGS));

    urbuf_ = new ubyte[urbufSize];
    crbuf_ = new ubyte[crbufSize];
    uwbuf_ = new ubyte[uwbufSize];
    cwbuf_ = new ubyte[cwbufSize];

    rstream_ = new z_stream;
    rstream_.next_in = crbuf_.ptr;
    rstream_.avail_in = 0;
    rstream_.next_out = urbuf_.ptr;
    rstream_.avail_out = to!uint(urbuf_.length);

    wstream_ = new z_stream;
    wstream_.next_in = uwbuf_.ptr;
    wstream_.avail_in = 0;
    wstream_.next_out = cwbuf_.ptr;
    // The output space of the write stream is cwbuf_, so its length (and not
    // the length of the compressed *read* buffer) is what zlib may fill.
    wstream_.avail_out = to!uint(cwbuf_.length);

    zlibEnforce(inflateInit(rstream_), rstream_);
    scope (failure) {
      // Do not leave the inflate state behind if deflateInit fails below.
      zlibLogError(inflateEnd(rstream_), rstream_);
    }

    zlibEnforce(deflateInit(wstream_, Z_DEFAULT_COMPRESSION), wstream_);
  }

  ~this() {
    // The stream pointers are still null if the constructor bailed out before
    // allocating them; zlibLogError would dereference them in that case.
    if (rstream_ !is null) {
      zlibLogError(inflateEnd(rstream_), rstream_);
    }

    if (wstream_ !is null) {
      auto result = deflateEnd(wstream_);
      // Z_DATA_ERROR may indicate unflushed data, so just ignore it.
      if (result != Z_DATA_ERROR) {
        zlibLogError(result, wstream_);
      }
    }
  }

  /**
   * Returns the wrapped transport.
   */
  TTransport underlyingTransport() @property {
    return transport_;
  }

  /**
   * Whether there is still decompressed data buffered or the underlying
   * transport reports being open.
   */
  override bool isOpen() @property {
    return readAvail > 0 || transport_.isOpen;
  }

  /**
   * Whether there is still decompressed data buffered or the underlying
   * transport's peek() reports pending data.
   */
  override bool peek() {
    return readAvail > 0 || transport_.peek();
  }

  /// Opens the underlying transport.
  override void open() {
    transport_.open();
  }

  /// Closes the underlying transport.
  override void close() {
    transport_.close();
  }

  /**
   * Reads decompressed data into buf.
   *
   * Params:
   *   buf = Destination slice; as many bytes as currently obtainable (up to
   *     buf.length) are stored at its beginning.
   *
   * Returns: The number of bytes written to buf, possibly less than
   *   buf.length if the stream ended or no more input is available.
   *
   * Throws: TZlibException if zlib reports an error while inflating.
   */
  override size_t read(ubyte[] buf) {
    // The C++ implementation suggests to skip urbuf on big reads in future
    // versions, we would benefit from it as well.
    auto origLen = buf.length;
    while (true) {
      auto give = min(readAvail, buf.length);

      // If std.range.put was optimized for slicable ranges, it could be used
      // here as well.
      buf[0 .. give] = urbuf_[urpos_ .. urpos_ + give];
      buf = buf[give .. $];
      urpos_ += give;

      auto need = buf.length;
      if (need == 0) {
        // We could manage to get all the data requested.
        return origLen;
      }

      if (inputEnded_ || (need < origLen && rstream_.avail_in == 0)) {
        // We didn't fill buf completely, but there is no more data available
        // without going back to the underlying transport.
        return origLen - need;
      }

      // Refill our buffer by reading more data through zlib.
      rstream_.next_out = urbuf_.ptr;
      rstream_.avail_out = to!uint(urbuf_.length);
      urpos_ = 0;

      if (!readFromZlib()) {
        // Couldn't get more data from the underlying transport.
        return origLen - need;
      }
    }
  }

  /**
   * Queues buf for compression.
   *
   * Chunks of up to MIN_DIRECT_DEFLATE_SIZE bytes are collected in the
   * uncompressed write buffer; larger chunks are passed to zlib directly
   * (after whatever has been buffered so far, to keep the byte order).
   *
   * Throws: TTransportException (BAD_ARGS) if finish() was already called,
   *   TZlibException if zlib reports an error while deflating.
   */
  override void write(in ubyte[] buf) {
    enforce(!outputFinished_, new TTransportException(
      "write() called after finish()", TTransportException.Type.BAD_ARGS));

    auto len = buf.length;
    if (len > MIN_DIRECT_DEFLATE_SIZE) {
      flushToZlib(uwbuf_[0 .. uwpos_], Z_NO_FLUSH);
      uwpos_ = 0;
      flushToZlib(buf, Z_NO_FLUSH);
    } else if (len > 0) {
      if (uwbuf_.length - uwpos_ < len) {
        // Not enough room left, make space by compressing what we have.
        flushToZlib(uwbuf_[0 .. uwpos_], Z_NO_FLUSH);
        uwpos_ = 0;
      }
      uwbuf_[uwpos_ .. uwpos_ + len] = buf[];
      uwpos_ += len;
    }
  }

  /**
   * Compresses all pending data with Z_SYNC_FLUSH, writes it to the
   * underlying transport and flushes that transport.
   *
   * Throws: TTransportException (BAD_ARGS) if finish() was already called.
   */
  override void flush() {
    enforce(!outputFinished_, new TTransportException(
      "flush() called after finish()", TTransportException.Type.BAD_ARGS));

    flushToTransport(Z_SYNC_FLUSH);
  }

  /**
   * Gives direct access to already decompressed data without consuming it.
   *
   * Params:
   *   buf = Unused by this implementation.
   *   len = The minimum number of bytes requested.
   *
   * Returns: A slice of the internal buffer covering exactly the decompressed
   *   bytes that have not been consumed yet (at least len bytes), or null if
   *   fewer than len bytes are currently buffered.
   */
  override const(ubyte)[] borrow(ubyte* buf, size_t len) {
    auto avail = readAvail;
    if (len <= avail) {
      // Only expose the part of urbuf_ zlib has actually filled; the tail of
      // the buffer beyond that holds stale bytes from earlier refills.
      return urbuf_[urpos_ .. urpos_ + avail];
    }
    return null;
  }

  /**
   * Marks len bytes of the decompressed buffer as read.
   *
   * Throws: TTransportException (BAD_ARGS) if fewer than len bytes are
   *   buffered.
   */
  override void consume(size_t len) {
    enforce(readAvail >= len, new TTransportException(
      "consume() did not follow a borrow().", TTransportException.Type.BAD_ARGS));
    urpos_ += len;
  }

  /**
   * Finalize the zlib stream.
   *
   * This causes zlib to flush any pending write data and write end-of-stream
   * information, including the checksum. Once finish() has been called, no
   * new data can be written to the stream.
   *
   * Throws: TTransportException (BAD_ARGS) if the stream was already
   *   finished.
   */
  void finish() {
    enforce(!outputFinished_, new TTransportException(
      "finish() called on already finished TZlibTransport",
      TTransportException.Type.BAD_ARGS));
    flushToTransport(Z_FINISH);
  }

  /**
   * Verify the checksum at the end of the zlib stream (by finish()).
   *
   * May only be called after all data has been read.
   *
   * Throws: TTransportException when the checksum is corrupted or there is
   * still unread data left.
   */
  void verifyChecksum() {
    // If zlib has already reported the end of the stream, the checksum has
    // been verified as well, so there is nothing left to do.
    if (inputEnded_) return;

    enforce(!readAvail, new TTransportException(
      "verifyChecksum() called before end of zlib stream",
      TTransportException.Type.CORRUPTED_DATA));

    rstream_.next_out = urbuf_.ptr;
    rstream_.avail_out = to!uint(urbuf_.length);
    urpos_ = 0;

    // readFromZlib() will throw an exception if the checksum is bad.
    enforce(readFromZlib(), new TTransportException(
      "checksum not available yet in verifyChecksum()",
      TTransportException.Type.CORRUPTED_DATA));

    enforce(inputEnded_, new TTransportException(
      "verifyChecksum() called before end of zlib stream",
      TTransportException.Type.CORRUPTED_DATA));

    // If we get here, we are at the end of the stream and thus zlib has
    // successfully verified the checksum.
  }

private:
  /// Number of decompressed bytes in urbuf_ that have not been handed out yet.
  size_t readAvail() const @property {
    return urbuf_.length - rstream_.avail_out - urpos_;
  }

  /**
   * Inflates more data into urbuf_, fetching compressed input from the
   * underlying transport first if zlib has none left.
   *
   * Returns: false if the underlying transport did not deliver any data,
   *   true otherwise. Sets inputEnded_ once zlib reports the end of the
   *   stream.
   */
  bool readFromZlib() {
    assert(!inputEnded_);

    if (rstream_.avail_in == 0) {
      // zlib has used up all the compressed data we provided in crbuf, read
      // some more from the underlying transport.
      auto got = transport_.read(crbuf_);
      if (got == 0) return false;
      rstream_.next_in = crbuf_.ptr;
      rstream_.avail_in = to!uint(got);
    }

    // We have some compressed data now, uncompress it.
    auto zlib_result = inflate(rstream_, Z_SYNC_FLUSH);
    if (zlib_result == Z_STREAM_END) {
      inputEnded_ = true;
    } else {
      zlibEnforce(zlib_result, rstream_);
    }

    return true;
  }

  /**
   * Compresses everything pending using the given zlib flush type, hands the
   * compressed bytes to the underlying transport and flushes it.
   */
  void flushToTransport(int type) {
    // Compress remaining data in uwbuf_ to cwbuf_.
    flushToZlib(uwbuf_[0 .. uwpos_], type);
    uwpos_ = 0;

    // Write all compressed data to the transport.
    transport_.write(cwbuf_[0 .. $ - wstream_.avail_out]);
    wstream_.next_out = cwbuf_.ptr;
    wstream_.avail_out = to!uint(cwbuf_.length);

    // Flush the transport.
    transport_.flush();
  }

  /**
   * Feeds buf to deflate() with the given flush type, writing cwbuf_ to the
   * underlying transport whenever zlib has filled it completely.
   */
  void flushToZlib(in ubyte[] buf, int type) {
    wstream_.next_in = cast(ubyte*)buf.ptr; // zlib only reads, cast is safe.
    wstream_.avail_in = to!uint(buf.length);

    while (true) {
      if (type == Z_NO_FLUSH && wstream_.avail_in == 0) {
        // Nothing (more) to compress and no flush requested.
        break;
      }

      if (wstream_.avail_out == 0) {
        // cwbuf has been exhausted by zlib, flush to the underlying transport.
        transport_.write(cwbuf_);
        wstream_.next_out = cwbuf_.ptr;
        wstream_.avail_out = to!uint(cwbuf_.length);
      }

      auto zlib_result = deflate(wstream_, type);

      if (type == Z_FINISH && zlib_result == Z_STREAM_END) {
        assert(wstream_.avail_in == 0);
        outputFinished_ = true;
        break;
      }

      zlibEnforce(zlib_result, wstream_);

      // A flush is complete once all input is consumed and zlib left room in
      // the output buffer (i.e. it has no more pending output).
      if ((type == Z_SYNC_FLUSH || type == Z_FULL_FLUSH) &&
        wstream_.avail_in == 0 && wstream_.avail_out != 0) {
        break;
      }
    }
  }

  /// Throws a TZlibException unless status is Z_OK.
  static void zlibEnforce(int status, z_stream* stream) {
    if (status != Z_OK) {
      throw new TZlibException(status, stream.msg);
    }
  }

  /// Logs (instead of throwing) a zlib failure; used where throwing is not
  /// an option, i.e. during cleanup.
  static void zlibLogError(int status, z_stream* stream) {
    if (status != Z_OK) {
      logError("TZlibTransport: zlib failure in destructor: %s",
        TZlibException.errorMessage(status, stream.msg));
    }
  }

  // Writes of up to this many bytes are buffered up (due to zlib handling
  // overhead). Larger writes are dumped straight to zlib.
  enum MIN_DIRECT_DEFLATE_SIZE = 32;

  TTransport transport_;
  z_stream* rstream_;
  z_stream* wstream_;

  /// Whether zlib has reached the end of the input stream.
  bool inputEnded_;

  /// Whether the output stream was already finish()ed.
  bool outputFinished_;

  /// Compressed input data buffer.
  ubyte[] crbuf_;

  /// Uncompressed input data buffer.
  ubyte[] urbuf_;
  size_t urpos_;

  /// Uncompressed output data buffer (where small writes are accumulated
  /// before handing over to zlib).
  ubyte[] uwbuf_;
  size_t uwpos_;

  /// Compressed output data buffer (filled by zlib, we flush it to the
  /// underlying transport).
  ubyte[] cwbuf_;
}
365
/**
 * Wraps given transports into TZlibTransports.
 *
 * This is the TWrapperTransportFactory template instantiated with
 * TZlibTransport as the wrapping transport type.
 */
alias TWrapperTransportFactory!TZlibTransport TZlibTransportFactory;
370
/**
 * A TTransportException of type INTERNAL_ERROR that carries the details of a
 * failure reported by zlib.
 */
class TZlibException : TTransportException {
  /**
   * Params:
   *   statusCode = The status code returned by the failing zlib call.
   *   msg = The zero-terminated message provided by zlib, may be null.
   */
  this(int statusCode, const(char)* msg) {
    super(errorMessage(statusCode, msg), TTransportException.Type.INTERNAL_ERROR);
    zlibStatusCode = statusCode;
    if (msg is null) {
      zlibMsg = "(null)";
    } else {
      zlibMsg = to!string(msg);
    }
  }

  /// The raw zlib status code.
  int zlibStatusCode;

  /// The zlib message, or "(null)" if zlib did not provide one.
  string zlibMsg;

  /**
   * Builds the human-readable description for a zlib failure from the status
   * code and the (possibly null) message.
   */
  static string errorMessage(int statusCode, const(char)* msg) {
    string detail = "(no message)";
    if (msg !is null) {
      detail = to!string(msg);
    }
    return "zlib error: " ~ detail ~
      " (status code = " ~ to!string(statusCode) ~ ")";
  }
}
398
version(unittest)399 version (unittest) {
400 import std.exception : collectException;
401 import thrift.transport.memory;
402 }
403
// Round trip: whatever is written and finish()ed must read back unchanged,
// and the checksum must verify afterwards.
unittest {
  immutable ubyte[] payload = [1, 2, 3, 4, 5];

  auto memory = new TMemoryBuffer;
  auto transport = new TZlibTransport(memory);

  transport.write(payload);
  transport.finish();

  auto readBack = new ubyte[payload.length];
  transport.readAll(readBack);
  enforce(payload == readBack);
  transport.verifyChecksum();
}
418
// A transport that is created and destroyed without write() ever being
// called must not emit any data to the underlying transport.
unittest {
  auto memory = new TMemoryBuffer;
  {
    // Scope-allocated so that the destructor runs at the end of this block.
    scope idle = new TZlibTransport(memory);
  }
  auto emitted = memory.getContents();
  enforce(emitted.length == 0);
}
427
// Once finish() has been called, any further write(), flush() or finish()
// must be rejected with a BAD_ARGS exception.
unittest {
  auto memory = new TMemoryBuffer;
  auto transport = new TZlibTransport(memory);

  transport.write([1, 2, 3, 4, 5]);
  transport.finish();

  auto writeError = collectException!TTransportException(transport.write([6]));
  enforce(writeError && writeError.type == TTransportException.Type.BAD_ARGS);

  auto flushError = collectException!TTransportException(transport.flush());
  enforce(flushError && flushError.type == TTransportException.Type.BAD_ARGS);

  auto finishError = collectException!TTransportException(transport.finish());
  enforce(finishError && finishError.type == TTransportException.Type.BAD_ARGS);
}
445
// Verifying the checksum must also work when the payload has been read
// completely and the checksum requires refilling the read buffer.
unittest {
  immutable ubyte[] payload = [1, 2, 3, 4, 5];

  auto memory = new TMemoryBuffer;
  auto writer = new TZlibTransport(memory);
  writer.write(payload);
  writer.finish();

  // The last byte belongs to the checksum, so a compressed read buffer that
  // is one byte too small forces a second read for it.
  auto compressedReadSize = memory.getContents().length - 1;
  auto reader = new TZlibTransport(memory, TZlibTransport.DEFAULT_URBUF_SIZE,
    compressedReadSize);

  auto readBack = new ubyte[payload.length];
  reader.readAll(readBack);
  enforce(payload == readBack);

  reader.verifyChecksum();
}
465
// verifyChecksum() must throw CORRUPTED_DATA if the end of the stream has
// been truncated or tampered with.
unittest {
  import std.stdio;
  import thrift.transport.range;

  immutable ubyte[] payload = [1, 2, 3, 4, 5];

  auto memory = new TMemoryBuffer;
  auto writer = new TZlibTransport(memory);
  writer.write(payload);
  writer.finish();

  // Reads the payload from the given (damaged) stream and expects the
  // checksum verification to fail afterwards.
  void expectChecksumFailure(const(ubyte)[] damaged) {
    auto reader = new TZlibTransport(tInputRangeTransport(damaged));
    auto readBack = new ubyte[payload.length];
    try {
      reader.readAll(readBack);

      // If it does read without complaining, the result should be correct.
      enforce(readBack == payload);
    } catch (TZlibException e) {}

    auto failure = collectException!TTransportException(reader.verifyChecksum());
    enforce(failure && failure.type == TTransportException.Type.CORRUPTED_DATA);
  }

  // Stream with the last byte cut off.
  expectChecksumFailure(memory.getContents()[0 .. $ - 1]);

  // Stream with the last byte altered.
  auto tampered = memory.getContents().dup;
  ++tampered[$ - 1];
  expectChecksumFailure(tampered);
}
498