1 //---------------------------------------------------------
2 // Copyright 2015 Ontario Institute for Cancer Research
3 // Written by Matei David (matei@cs.toronto.edu)
4 //---------------------------------------------------------
5 
6 // Reference:
7 // http://stackoverflow.com/questions/14086417/how-to-write-custom-input-stream-in-c
8 
9 #ifndef __ZSTR_HPP
10 #define __ZSTR_HPP
11 
12 #include <cassert>
13 #include <fstream>
14 #include <sstream>
15 #include <zlib.h>
16 #include "strict_fstream.hpp"
17 
18 namespace zstr
19 {
20 
21 /// Exception class thrown by failed zlib operations.
22 class Exception
23     : public std::exception
24 {
25 public:
Exception(z_stream * zstrm_p,int ret)26     Exception(z_stream * zstrm_p, int ret)
27         : _msg("zlib: ")
28     {
29         switch (ret)
30         {
31         case Z_STREAM_ERROR:
32             _msg += "Z_STREAM_ERROR: ";
33             break;
34         case Z_DATA_ERROR:
35             _msg += "Z_DATA_ERROR: ";
36             break;
37         case Z_MEM_ERROR:
38             _msg += "Z_MEM_ERROR: ";
39             break;
40         case Z_VERSION_ERROR:
41             _msg += "Z_VERSION_ERROR: ";
42             break;
43         case Z_BUF_ERROR:
44             _msg += "Z_BUF_ERROR: ";
45             break;
46         default:
47             std::ostringstream oss;
48             oss << ret;
49             _msg += "[" + oss.str() + "]: ";
50             break;
51         }
52         _msg += zstrm_p->msg;
53     }
Exception(const std::string msg)54     Exception(const std::string msg) : _msg(msg) {}
what() const55     const char * what() const noexcept { return _msg.c_str(); }
56 private:
57     std::string _msg;
58 }; // class Exception
59 
60 namespace detail
61 {
62 
63 class z_stream_wrapper
64     : public z_stream
65 {
66 public:
z_stream_wrapper(bool _is_input=true,int _level=Z_DEFAULT_COMPRESSION)67     z_stream_wrapper(bool _is_input = true, int _level = Z_DEFAULT_COMPRESSION)
68         : is_input(_is_input)
69     {
70         this->zalloc = Z_NULL;
71         this->zfree = Z_NULL;
72         this->opaque = Z_NULL;
73         int ret;
74         if (is_input)
75         {
76             this->avail_in = 0;
77             this->next_in = Z_NULL;
78             ret = inflateInit2(this, -MAX_WBITS);
79         }
80         else
81         {
82             ret = deflateInit2(this, 8, Z_DEFLATED, -MAX_WBITS, _level, Z_DEFAULT_STRATEGY);
83         }
84         if (ret != Z_OK) throw Exception(this, ret);
85     }
~z_stream_wrapper()86     ~z_stream_wrapper()
87     {
88         if (is_input)
89         {
90             inflateEnd(this);
91         }
92         else
93         {
94             deflateEnd(this);
95         }
96     }
97 private:
98     bool is_input;
99 }; // class z_stream_wrapper
100 
101 } // namespace detail
102 
103 class istreambuf
104     : public std::streambuf
105 {
106 public:
istreambuf(std::streambuf * _sbuf_p,std::size_t _buff_size=default_buff_size,bool _auto_detect=false)107     istreambuf(std::streambuf * _sbuf_p,
108                std::size_t _buff_size = default_buff_size, bool _auto_detect = false)
109         : sbuf_p(_sbuf_p),
110           zstrm_p(nullptr),
111           buff_size(_buff_size),
112           auto_detect(_auto_detect),
113           auto_detect_run(false),
114           is_text(false)
115     {
116         assert(sbuf_p);
117         in_buff = new char [buff_size];
118         in_buff_start = in_buff;
119         in_buff_end = in_buff;
120         out_buff = new char [buff_size];
121         setg(out_buff, out_buff, out_buff);
122     }
123 
124     istreambuf(const istreambuf &) = delete;
125     istreambuf(istreambuf &&) = default;
126     istreambuf & operator = (const istreambuf &) = delete;
127     istreambuf & operator = (istreambuf &&) = default;
128 
~istreambuf()129     virtual ~istreambuf()
130     {
131         delete [] in_buff;
132         delete [] out_buff;
133         if (zstrm_p) delete zstrm_p;
134     }
135 
underflow()136     virtual std::streambuf::int_type underflow()
137     {
138         if (this->gptr() == this->egptr())
139         {
140             // pointers for free region in output buffer
141             char * out_buff_free_start = out_buff;
142             do
143             {
144                 // read more input if none available
145                 if (in_buff_start == in_buff_end)
146                 {
147                     // empty input buffer: refill from the start
148                     in_buff_start = in_buff;
149                     std::streamsize sz = sbuf_p->sgetn(in_buff, (std::streamsize) buff_size);
150                     in_buff_end = in_buff + sz;
151                     if (in_buff_end == in_buff_start) break; // end of input
152                 }
153                 // auto detect if the stream contains text or deflate data
154                 if (auto_detect && !auto_detect_run)
155                 {
156                     auto_detect_run = true;
157                     unsigned char b0 = *reinterpret_cast<unsigned char *>(in_buff_start);
158                     unsigned char b1 = *reinterpret_cast<unsigned char *>(in_buff_start + 1);
159                     // Ref:
160                     // http://en.wikipedia.org/wiki/Gzip
161                     // http://stackoverflow.com/questions/9050260/what-does-a-zlib-header-look-like
162                     is_text = ! (in_buff_start + 2 <= in_buff_end
163                                  && ((b0 == 0x1F && b1 == 0x8B)         // gzip header
164                                      || (b0 == 0x78 && (b1 == 0x01      // zlib header
165                                                         || b1 == 0x9C
166                                                         || b1 == 0xDA))));
167                 }
168                 if (is_text)
169                 {
170                     // simply swap in_buff and out_buff, and adjust pointers
171                     assert(in_buff_start == in_buff);
172                     std::swap(in_buff, out_buff);
173                     out_buff_free_start = in_buff_end;
174                     in_buff_start = in_buff;
175                     in_buff_end = in_buff;
176                 }
177                 else
178                 {
179                     // run inflate() on input
180                     if (!zstrm_p)
181                     {
182                         zstrm_p = new detail::z_stream_wrapper(true);
183                     }
184                     zstrm_p->next_in = reinterpret_cast<decltype(zstrm_p->next_in)>(in_buff_start);
185                     zstrm_p->avail_in = static_cast<decltype(zstrm_p->avail_in)>(in_buff_end - in_buff_start);
186                     zstrm_p->next_out = reinterpret_cast<decltype(zstrm_p->next_out)>(out_buff_free_start);
187                     zstrm_p->avail_out = static_cast<decltype(zstrm_p->avail_out)>((out_buff + buff_size) - out_buff_free_start);
188                     int ret = inflate(zstrm_p, Z_NO_FLUSH);
189                     // process return code
190                     if (ret != Z_OK && ret != Z_STREAM_END) throw Exception(zstrm_p, ret);
191                     // update in&out pointers following inflate()
192                     in_buff_start = reinterpret_cast< decltype(in_buff_start) >(zstrm_p->next_in);
193                     in_buff_end = in_buff_start + zstrm_p->avail_in;
194                     out_buff_free_start = reinterpret_cast< decltype(out_buff_free_start) >(zstrm_p->next_out);
195                     assert(out_buff_free_start + zstrm_p->avail_out == out_buff + buff_size);
196                     // if stream ended, deallocate inflator
197                     if (ret == Z_STREAM_END)
198                     {
199                         delete zstrm_p;
200                         zstrm_p = nullptr;
201                     }
202                 }
203             } while (out_buff_free_start == out_buff);
204             // 2 exit conditions:
205             // - end of input: there might or might not be output available
206             // - out_buff_free_start != out_buff: output available
207             this->setg(out_buff, out_buff, out_buff_free_start);
208         }
209         return this->gptr() == this->egptr()
210             ? traits_type::eof()
211             : traits_type::to_int_type(*this->gptr());
212     }
213 private:
214     std::streambuf * sbuf_p;
215     char * in_buff;
216     char * in_buff_start;
217     char * in_buff_end;
218     char * out_buff;
219     detail::z_stream_wrapper * zstrm_p;
220     std::size_t buff_size;
221     bool auto_detect;
222     bool auto_detect_run;
223     bool is_text;
224 
225     static const std::size_t default_buff_size = (std::size_t)1 << 20;
226 }; // class istreambuf
227 
228 class ostreambuf
229     : public std::streambuf
230 {
231 public:
ostreambuf(std::streambuf * _sbuf_p,std::size_t _buff_size=default_buff_size,int _level=Z_DEFAULT_COMPRESSION)232     ostreambuf(std::streambuf * _sbuf_p,
233                std::size_t _buff_size = default_buff_size, int _level = Z_DEFAULT_COMPRESSION)
234         : sbuf_p(_sbuf_p),
235           zstrm_p(new detail::z_stream_wrapper(false, _level)),
236           buff_size(_buff_size)
237     {
238         assert(sbuf_p);
239         in_buff = new char [buff_size];
240         out_buff = new char [buff_size];
241         setp(in_buff, in_buff + buff_size);
242     }
243 
244     ostreambuf(const ostreambuf &) = delete;
245     ostreambuf(ostreambuf &&) = default;
246     ostreambuf & operator = (const ostreambuf &) = delete;
247     ostreambuf & operator = (ostreambuf &&) = default;
248 
deflate_loop(int flush)249     int deflate_loop(int flush)
250     {
251         while (true)
252         {
253             zstrm_p->next_out = reinterpret_cast<decltype(zstrm_p->next_out)>(out_buff);
254             zstrm_p->avail_out = static_cast<decltype(zstrm_p->avail_out)>(buff_size);
255             int ret = deflate(zstrm_p, flush);
256             if (ret != Z_OK && ret != Z_STREAM_END && ret != Z_BUF_ERROR) throw Exception(zstrm_p, ret);
257             std::streamsize sz = sbuf_p->sputn(out_buff, reinterpret_cast<decltype(out_buff)>(zstrm_p->next_out) - out_buff);
258             if (sz != reinterpret_cast<decltype(out_buff)>(zstrm_p->next_out) - out_buff)
259             {
260                 // there was an error in the sink stream
261                 return -1;
262             }
263             if (ret == Z_STREAM_END || ret == Z_BUF_ERROR || sz == 0)
264             {
265                 break;
266             }
267         }
268         return 0;
269     }
270 
~ostreambuf()271     virtual ~ostreambuf()
272     {
273         // flush the zlib stream
274         //
275         // NOTE: Errors here (sync() return value not 0) are ignored, because we
276         // cannot throw in a destructor. This mirrors the behaviour of
277         // std::basic_filebuf::~basic_filebuf(). To see an exception on error,
278         // close the ofstream with an explicit call to close(), and do not rely
279         // on the implicit call in the destructor.
280         //
281         sync();
282         delete [] in_buff;
283         delete [] out_buff;
284         delete zstrm_p;
285     }
overflow(std::streambuf::int_type c=traits_type::eof ())286     virtual std::streambuf::int_type overflow(std::streambuf::int_type c = traits_type::eof())
287     {
288         zstrm_p->next_in = reinterpret_cast< decltype(zstrm_p->next_in) >(pbase());
289         zstrm_p->avail_in = static_cast<decltype(zstrm_p->avail_in)>(pptr() - pbase());
290         while (zstrm_p->avail_in > 0)
291         {
292             int r = deflate_loop(Z_NO_FLUSH);
293             if (r != 0)
294             {
295                 setp(nullptr, nullptr);
296                 return traits_type::eof();
297             }
298         }
299         setp(in_buff, in_buff + buff_size);
300         return traits_type::eq_int_type(c, traits_type::eof()) ? traits_type::eof() : sputc((char) c);
301     }
sync()302     virtual int sync()
303     {
304         // first, call overflow to clear in_buff
305         overflow();
306         if (! pptr()) return -1;
307         // then, call deflate asking to finish the zlib stream
308         zstrm_p->next_in = nullptr;
309         zstrm_p->avail_in = 0;
310         if (deflate_loop(Z_FINISH) != 0) return -1;
311         deflateReset(zstrm_p);
312         return 0;
313     }
314 private:
315     std::streambuf * sbuf_p;
316     char * in_buff;
317     char * out_buff;
318     detail::z_stream_wrapper * zstrm_p;
319     std::size_t buff_size;
320 
321     static const std::size_t default_buff_size = (std::size_t)1 << 20;
322 }; // class ostreambuf
323 
324 class istream
325     : public std::istream
326 {
327 public:
istream(std::istream & is)328     istream(std::istream & is)
329         : std::istream(new istreambuf(is.rdbuf()))
330     {
331         exceptions(std::ios_base::badbit);
332     }
333 
istream(std::istream & is,std::size_t buff_size)334     istream(std::istream & is, std::size_t buff_size)
335         : std::istream(new istreambuf(is.rdbuf(), buff_size))
336     {
337         exceptions(std::ios_base::badbit);
338     }
339 
istream(std::streambuf * sbuf_p)340     explicit istream(std::streambuf * sbuf_p)
341         : std::istream(new istreambuf(sbuf_p))
342     {
343         exceptions(std::ios_base::badbit);
344     }
~istream()345     virtual ~istream()
346     {
347         delete rdbuf();
348     }
349 }; // class istream
350 
351 class ostream
352     : public std::ostream
353 {
354 public:
ostream(std::ostream & os)355     ostream(std::ostream & os)
356         : std::ostream(new ostreambuf(os.rdbuf()))
357     {
358         exceptions(std::ios_base::badbit);
359     }
ostream(std::ostream & os,int compression_level)360     ostream(std::ostream & os, int compression_level)
361         : std::ostream(new ostreambuf(os.rdbuf(), (std::size_t) 1 << 20, compression_level))
362     {
363         exceptions(std::ios_base::badbit);
364     }
ostream(std::streambuf * sbuf_p)365     explicit ostream(std::streambuf * sbuf_p)
366         : std::ostream(new ostreambuf(sbuf_p))
367     {
368         exceptions(std::ios_base::badbit);
369     }
~ostream()370     virtual ~ostream()
371     {
372         delete rdbuf();
373     }
374 }; // class ostream
375 
376 namespace detail
377 {
378 
/// Holds a file stream of type FStream_Type as a member. Classes list this
/// holder as a base ahead of the std stream base so that the file stream is
/// constructed first and its rdbuf() can be passed to the stream base.
template < typename FStream_Type >
struct strict_fstream_holder
{
    FStream_Type _fs; ///< the wrapped file stream

    /// Construct the held file stream from a file name and an open mode.
    strict_fstream_holder(const std::string& filename,
                          std::ios_base::openmode mode = std::ios_base::in)
        : _fs(filename, mode)
    {
    }
}; // class strict_fstream_holder
387 
388 } // namespace detail
389 
390 class ifstream
391     : private detail::strict_fstream_holder< strict_fstream::ifstream >,
392       public std::istream
393 {
394 public:
ifstream(const std::string & filename,std::ios_base::openmode mode=std::ios_base::in)395     explicit ifstream(const std::string& filename, std::ios_base::openmode mode = std::ios_base::in)
396         : detail::strict_fstream_holder< strict_fstream::ifstream >(filename, mode),
397           std::istream(new istreambuf(_fs.rdbuf()))
398     {
399         exceptions(std::ios_base::badbit);
400     }
~ifstream()401     virtual ~ifstream()
402     {
403         if (rdbuf()) delete rdbuf();
404     }
405 }; // class ifstream
406 
407 class ofstream
408     : private detail::strict_fstream_holder< strict_fstream::ofstream >,
409       public std::ostream
410 {
411 public:
ofstream(const std::string & filename,std::ios_base::openmode mode=std::ios_base::out)412     explicit ofstream(const std::string& filename, std::ios_base::openmode mode = std::ios_base::out)
413         : detail::strict_fstream_holder< strict_fstream::ofstream >(filename, mode | std::ios_base::binary),
414           std::ostream(new ostreambuf(_fs.rdbuf()))
415     {
416         exceptions(std::ios_base::badbit);
417     }
~ofstream()418     virtual ~ofstream()
419     {
420         if (rdbuf()) delete rdbuf();
421     }
422 }; // class ofstream
423 
424 } // namespace zstr
425 
426 #endif