xref: /openbsd/lib/libz/infback.c (revision dda28197)
1 /* infback.c -- inflate using a call-back interface
2  * Copyright (C) 1995-2005 Mark Adler
3  * For conditions of distribution and use, see copyright notice in zlib.h
4  */
5 
6 /*
7    This code is largely copied from inflate.c.  Normally either infback.o or
8    inflate.o would be linked into an application--not both.  The interface
9    with inffast.c is retained so that optimized assembler-coded versions of
10    inflate_fast() can be used with either inflate.c or infback.c.
11  */
12 
13 #include "zutil.h"
14 #include "inftrees.h"
15 #include "inflate.h"
16 #include "inffast.h"
17 
18 /* function prototypes */
19 local void fixedtables OF((struct inflate_state FAR *state));
20 
21 /*
22    strm provides memory allocation functions in zalloc and zfree, or
23    Z_NULL to use the library memory allocation functions.
24 
25    windowBits is in the range 8..15, and window is a user-supplied
26    window and output buffer that is 2**windowBits bytes.
27  */
28 int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
29 z_streamp strm;
30 int windowBits;
31 unsigned char FAR *window;
32 const char *version;
33 int stream_size;
34 {
35     struct inflate_state FAR *state;
36 
37     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
38         stream_size != (int)(sizeof(z_stream)))
39         return Z_VERSION_ERROR;
40     if (strm == Z_NULL || window == Z_NULL ||
41         windowBits < 8 || windowBits > 15)
42         return Z_STREAM_ERROR;
43     strm->msg = Z_NULL;                 /* in case we return an error */
44     if (strm->zalloc == (alloc_func)0) {
45         strm->zalloc = zcalloc;
46         strm->opaque = (voidpf)0;
47     }
48     if (strm->zfree == (free_func)0) strm->zfree = zcfree;
49     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
50                                                sizeof(struct inflate_state));
51     if (state == Z_NULL) return Z_MEM_ERROR;
52     Tracev((stderr, "inflate: allocated\n"));
53     strm->state = (struct internal_state FAR *)state;
54     state->dmax = 32768U;
55     state->wbits = windowBits;
56     state->wsize = 1U << windowBits;
57     state->window = window;
58     state->write = 0;
59     state->whave = 0;
60     return Z_OK;
61 }
62 
63 /*
64    Return state with length and distance decoding tables and index sizes set to
65    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
66    If BUILDFIXED is defined, then instead this routine builds the tables the
67    first time it's called, and returns those tables the first time and
68    thereafter.  This reduces the size of the code by about 2K bytes, in
69    exchange for a little execution time.  However, BUILDFIXED should not be
70    used for threaded applications, since the rewriting of the tables and virgin
71    may not be thread-safe.
72  */
73 local void fixedtables(state)
74 struct inflate_state FAR *state;
75 {
76 #ifdef BUILDFIXED
77     static int virgin = 1;
78     static code *lenfix, *distfix;
79     static code fixed[544];
80 
81     /* build fixed huffman tables if first call (may not be thread safe) */
82     if (virgin) {
83         unsigned sym, bits;
84         static code *next;
85 
86         /* literal/length table */
87         sym = 0;
88         while (sym < 144) state->lens[sym++] = 8;
89         while (sym < 256) state->lens[sym++] = 9;
90         while (sym < 280) state->lens[sym++] = 7;
91         while (sym < 288) state->lens[sym++] = 8;
92         next = fixed;
93         lenfix = next;
94         bits = 9;
95         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
96 
97         /* distance table */
98         sym = 0;
99         while (sym < 32) state->lens[sym++] = 5;
100         distfix = next;
101         bits = 5;
102         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
103 
104         /* do this just once */
105         virgin = 0;
106     }
107 #else /* !BUILDFIXED */
108 #   include "inffixed.h"
109 #endif /* BUILDFIXED */
110     state->lencode = lenfix;
111     state->lenbits = 9;
112     state->distcode = distfix;
113     state->distbits = 5;
114 }
115 
/* Macros for inflateBack().  They operate on inflateBack()'s local
   "register" copies of the stream: next/have (input), put/left (output),
   hold/bits (bit accumulator), plus strm, state, ret, in/in_desc and
   out/out_desc, and they exit through the inf_leave label on failure. */

/* Reload the local registers after inflate_fast() has advanced them */
#define LOAD() \
    do { \
        hold = state->hold; \
        bits = state->bits; \
        next = strm->next_in; \
        have = strm->avail_in; \
        put = strm->next_out; \
        left = strm->avail_out; \
    } while (0)

/* Publish the local registers to strm/state before calling inflate_fast() */
#define RESTORE() \
    do { \
        state->hold = hold; \
        state->bits = bits; \
        strm->next_in = next; \
        strm->avail_in = have; \
        strm->next_out = put; \
        strm->avail_out = left; \
    } while (0)

/* Empty the bit accumulator */
#define INITBITS() \
    do { \
        bits = 0; \
        hold = 0; \
    } while (0)

/* Make sure at least one input byte is on hand, asking in() for more when
   the current buffer is exhausted.  If in() supplies nothing, next is set
   to Z_NULL (so the caller can tell in() failed) and inflateBack() leaves
   with Z_BUF_ERROR. */
#define PULL() \
    do { \
        if (have == 0 && (have = in(in_desc, &next)) == 0) { \
            next = Z_NULL; \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)

/* Append one input byte above the current contents of the bit accumulator,
   leaving inflateBack() with Z_BUF_ERROR if no input can be obtained. */
#define PULLBYTE() \
    do { \
        PULL(); \
        hold += (unsigned long)(*next) << bits; \
        next++; \
        bits += 8; \
        have--; \
    } while (0)

/* Keep pulling bytes until the accumulator holds at least n bits; leaves
   inflateBack() with Z_BUF_ERROR if input runs out first. */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) \
            PULLBYTE(); \
    } while (0)

/* Value of the lowest n bits of the accumulator (requires n < 16) */
#define BITS(n) \
    ((unsigned)hold & ((1U << (n)) - 1U))

/* Discard the lowest n bits of the accumulator */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* Discard 0..7 bits so that the accumulator is aligned to a byte boundary */
#define BYTEBITS() \
    do { \
        hold >>= (bits & 7); \
        bits -= (bits & 7); \
    } while (0)

/* Guarantee room for output: when the window is full, hand all of it to
   out() and restart at its beginning.  If out() reports failure (non-zero),
   inflateBack() leaves with Z_BUF_ERROR. */
#define ROOM() \
    do { \
        if (left == 0) { \
            put = state->window; \
            left = state->wsize; \
            state->whave = left; \
            if (out(out_desc, put, left) != 0) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
213 
214 /*
215    strm provides the memory allocation functions and window buffer on input,
216    and provides information on the unused input on return.  For Z_DATA_ERROR
217    returns, strm will also provide an error message.
218 
219    in() and out() are the call-back input and output functions.  When
220    inflateBack() needs more input, it calls in().  When inflateBack() has
221    filled the window with output, or when it completes with data in the
222    window, it calls out() to write out the data.  The application must not
223    change the provided input until in() is called again or inflateBack()
224    returns.  The application must not change the window/output buffer until
225    inflateBack() returns.
226 
227    in() and out() are called with a descriptor parameter provided in the
228    inflateBack() call.  This parameter can be a structure that provides the
229    information required to do the read or write, as well as accumulated
230    information on the input and output such as totals and check values.
231 
232    in() should return zero on failure.  out() should return non-zero on
233    failure.  If either in() or out() fails, than inflateBack() returns a
234    Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
235    was in() or out() that caused in the error.  Otherwise,  inflateBack()
236    returns Z_STREAM_END on success, Z_DATA_ERROR for an deflate format
237    error, or Z_MEM_ERROR if it could not allocate memory for the state.
238    inflateBack() can also return Z_STREAM_ERROR if the input parameters
239    are not correct, i.e. strm is Z_NULL or the state was not initialized.
240  */
241 int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
242 z_streamp strm;
243 in_func in;
244 void FAR *in_desc;
245 out_func out;
246 void FAR *out_desc;
247 {
248     struct inflate_state FAR *state;
249     z_const unsigned char FAR *next;    /* next input */
250     unsigned char FAR *put;     /* next output */
251     unsigned have, left;        /* available input and output */
252     unsigned long hold;         /* bit buffer */
253     unsigned bits;              /* bits in bit buffer */
254     unsigned copy;              /* number of stored or match bytes to copy */
255     unsigned char FAR *from;    /* where to copy match bytes from */
256     code this;                  /* current decoding table entry */
257     code last;                  /* parent table entry */
258     unsigned len;               /* length to copy for repeats, bits to drop */
259     int ret;                    /* return code */
260     static const unsigned short order[19] = /* permutation of code lengths */
261         {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
262 
263     /* Check that the strm exists and that the state was initialized */
264     if (strm == Z_NULL || strm->state == Z_NULL)
265         return Z_STREAM_ERROR;
266     state = (struct inflate_state FAR *)strm->state;
267 
268     /* Reset the state */
269     strm->msg = Z_NULL;
270     state->mode = TYPE;
271     state->last = 0;
272     state->whave = 0;
273     next = strm->next_in;
274     have = next != Z_NULL ? strm->avail_in : 0;
275     hold = 0;
276     bits = 0;
277     put = state->window;
278     left = state->wsize;
279 
280     /* Inflate until end of block marked as last */
281     for (;;)
282         switch (state->mode) {
283         case TYPE:
284             /* determine and dispatch block type */
285             if (state->last) {
286                 BYTEBITS();
287                 state->mode = DONE;
288                 break;
289             }
290             NEEDBITS(3);
291             state->last = BITS(1);
292             DROPBITS(1);
293             switch (BITS(2)) {
294             case 0:                             /* stored block */
295                 Tracev((stderr, "inflate:     stored block%s\n",
296                         state->last ? " (last)" : ""));
297                 state->mode = STORED;
298                 break;
299             case 1:                             /* fixed block */
300                 fixedtables(state);
301                 Tracev((stderr, "inflate:     fixed codes block%s\n",
302                         state->last ? " (last)" : ""));
303                 state->mode = LEN;              /* decode codes */
304                 break;
305             case 2:                             /* dynamic block */
306                 Tracev((stderr, "inflate:     dynamic codes block%s\n",
307                         state->last ? " (last)" : ""));
308                 state->mode = TABLE;
309                 break;
310             case 3:
311 #ifdef SMALL
312 		strm->msg = "error";
313 #else
314                 strm->msg = (char *)"invalid block type";
315 #endif
316                 state->mode = BAD;
317             }
318             DROPBITS(2);
319             break;
320 
321         case STORED:
322             /* get and verify stored block length */
323             BYTEBITS();                         /* go to byte boundary */
324             NEEDBITS(32);
325             if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
326 #ifdef SMALL
327 		strm->msg = "error";
328 #else
329                 strm->msg = (char *)"invalid stored block lengths";
330 #endif
331                 state->mode = BAD;
332                 break;
333             }
334             state->length = (unsigned)hold & 0xffff;
335             Tracev((stderr, "inflate:       stored length %u\n",
336                     state->length));
337             INITBITS();
338 
339             /* copy stored block from input to output */
340             while (state->length != 0) {
341                 copy = state->length;
342                 PULL();
343                 ROOM();
344                 if (copy > have) copy = have;
345                 if (copy > left) copy = left;
346                 zmemcpy(put, next, copy);
347                 have -= copy;
348                 next += copy;
349                 left -= copy;
350                 put += copy;
351                 state->length -= copy;
352             }
353             Tracev((stderr, "inflate:       stored end\n"));
354             state->mode = TYPE;
355             break;
356 
357         case TABLE:
358             /* get dynamic table entries descriptor */
359             NEEDBITS(14);
360             state->nlen = BITS(5) + 257;
361             DROPBITS(5);
362             state->ndist = BITS(5) + 1;
363             DROPBITS(5);
364             state->ncode = BITS(4) + 4;
365             DROPBITS(4);
366 #ifndef PKZIP_BUG_WORKAROUND
367             if (state->nlen > 286 || state->ndist > 30) {
368 #ifdef SMALL
369 		strm->msg = "error";
370 #else
371                 strm->msg = (char *)"too many length or distance symbols";
372 #endif
373                 state->mode = BAD;
374                 break;
375             }
376 #endif
377             Tracev((stderr, "inflate:       table sizes ok\n"));
378 
379             /* get code length code lengths (not a typo) */
380             state->have = 0;
381             while (state->have < state->ncode) {
382                 NEEDBITS(3);
383                 state->lens[order[state->have++]] = (unsigned short)BITS(3);
384                 DROPBITS(3);
385             }
386             while (state->have < 19)
387                 state->lens[order[state->have++]] = 0;
388             state->next = state->codes;
389             state->lencode = (code const FAR *)(state->next);
390             state->lenbits = 7;
391             ret = inflate_table(CODES, state->lens, 19, &(state->next),
392                                 &(state->lenbits), state->work);
393             if (ret) {
394                 strm->msg = (char *)"invalid code lengths set";
395                 state->mode = BAD;
396                 break;
397             }
398             Tracev((stderr, "inflate:       code lengths ok\n"));
399 
400             /* get length and distance code code lengths */
401             state->have = 0;
402             while (state->have < state->nlen + state->ndist) {
403                 for (;;) {
404                     this = state->lencode[BITS(state->lenbits)];
405                     if ((unsigned)(this.bits) <= bits) break;
406                     PULLBYTE();
407                 }
408                 if (this.val < 16) {
409                     NEEDBITS(this.bits);
410                     DROPBITS(this.bits);
411                     state->lens[state->have++] = this.val;
412                 }
413                 else {
414                     if (this.val == 16) {
415                         NEEDBITS(this.bits + 2);
416                         DROPBITS(this.bits);
417                         if (state->have == 0) {
418                             strm->msg = (char *)"invalid bit length repeat";
419                             state->mode = BAD;
420                             break;
421                         }
422                         len = (unsigned)(state->lens[state->have - 1]);
423                         copy = 3 + BITS(2);
424                         DROPBITS(2);
425                     }
426                     else if (this.val == 17) {
427                         NEEDBITS(this.bits + 3);
428                         DROPBITS(this.bits);
429                         len = 0;
430                         copy = 3 + BITS(3);
431                         DROPBITS(3);
432                     }
433                     else {
434                         NEEDBITS(this.bits + 7);
435                         DROPBITS(this.bits);
436                         len = 0;
437                         copy = 11 + BITS(7);
438                         DROPBITS(7);
439                     }
440                     if (state->have + copy > state->nlen + state->ndist) {
441                         strm->msg = (char *)"invalid bit length repeat";
442                         state->mode = BAD;
443                         break;
444                     }
445                     while (copy--)
446                         state->lens[state->have++] = (unsigned short)len;
447                 }
448             }
449 
450             /* handle error breaks in while */
451             if (state->mode == BAD) break;
452 
453             /* build code tables */
454             state->next = state->codes;
455             state->lencode = (code const FAR *)(state->next);
456             state->lenbits = 9;
457             ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
458                                 &(state->lenbits), state->work);
459             if (ret) {
460                 strm->msg = (char *)"invalid literal/lengths set";
461                 state->mode = BAD;
462                 break;
463             }
464             state->distcode = (code const FAR *)(state->next);
465             state->distbits = 6;
466             ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
467                             &(state->next), &(state->distbits), state->work);
468             if (ret) {
469                 strm->msg = (char *)"invalid distances set";
470                 state->mode = BAD;
471                 break;
472             }
473             Tracev((stderr, "inflate:       codes ok\n"));
474             state->mode = LEN;
475 
476         case LEN:
477 #ifndef SLOW
478             /* use inflate_fast() if we have enough input and output */
479             if (have >= 6 && left >= 258) {
480                 RESTORE();
481                 if (state->whave < state->wsize)
482                     state->whave = state->wsize - left;
483                 inflate_fast(strm, state->wsize);
484                 LOAD();
485                 break;
486             }
487 #endif
488 
489             /* get a literal, length, or end-of-block code */
490             for (;;) {
491                 this = state->lencode[BITS(state->lenbits)];
492                 if ((unsigned)(this.bits) <= bits) break;
493                 PULLBYTE();
494             }
495             if (this.op && (this.op & 0xf0) == 0) {
496                 last = this;
497                 for (;;) {
498                     this = state->lencode[last.val +
499                             (BITS(last.bits + last.op) >> last.bits)];
500                     if ((unsigned)(last.bits + this.bits) <= bits) break;
501                     PULLBYTE();
502                 }
503                 DROPBITS(last.bits);
504             }
505             DROPBITS(this.bits);
506             state->length = (unsigned)this.val;
507 
508             /* process literal */
509             if (this.op == 0) {
510                 Tracevv((stderr, this.val >= 0x20 && this.val < 0x7f ?
511                         "inflate:         literal '%c'\n" :
512                         "inflate:         literal 0x%02x\n", this.val));
513                 ROOM();
514                 *put++ = (unsigned char)(state->length);
515                 left--;
516                 state->mode = LEN;
517                 break;
518             }
519 
520             /* process end of block */
521             if (this.op & 32) {
522                 Tracevv((stderr, "inflate:         end of block\n"));
523                 state->mode = TYPE;
524                 break;
525             }
526 
527             /* invalid code */
528             if (this.op & 64) {
529                 strm->msg = (char *)"invalid literal/length code";
530                 state->mode = BAD;
531                 break;
532             }
533 
534             /* length code -- get extra bits, if any */
535             state->extra = (unsigned)(this.op) & 15;
536             if (state->extra != 0) {
537                 NEEDBITS(state->extra);
538                 state->length += BITS(state->extra);
539                 DROPBITS(state->extra);
540             }
541             Tracevv((stderr, "inflate:         length %u\n", state->length));
542 
543             /* get distance code */
544             for (;;) {
545                 this = state->distcode[BITS(state->distbits)];
546                 if ((unsigned)(this.bits) <= bits) break;
547                 PULLBYTE();
548             }
549             if ((this.op & 0xf0) == 0) {
550                 last = this;
551                 for (;;) {
552                     this = state->distcode[last.val +
553                             (BITS(last.bits + last.op) >> last.bits)];
554                     if ((unsigned)(last.bits + this.bits) <= bits) break;
555                     PULLBYTE();
556                 }
557                 DROPBITS(last.bits);
558             }
559             DROPBITS(this.bits);
560             if (this.op & 64) {
561                 strm->msg = (char *)"invalid distance code";
562                 state->mode = BAD;
563                 break;
564             }
565             state->offset = (unsigned)this.val;
566 
567             /* get distance extra bits, if any */
568             state->extra = (unsigned)(this.op) & 15;
569             if (state->extra != 0) {
570                 NEEDBITS(state->extra);
571                 state->offset += BITS(state->extra);
572                 DROPBITS(state->extra);
573             }
574             if (state->offset > state->wsize - (state->whave < state->wsize ?
575                                                 left : 0)) {
576                 strm->msg = (char *)"invalid distance too far back";
577                 state->mode = BAD;
578                 break;
579             }
580             Tracevv((stderr, "inflate:         distance %u\n", state->offset));
581 
582             /* copy match from window to output */
583             do {
584                 ROOM();
585                 copy = state->wsize - state->offset;
586                 if (copy < left) {
587                     from = put + copy;
588                     copy = left - copy;
589                 }
590                 else {
591                     from = put - state->offset;
592                     copy = left;
593                 }
594                 if (copy > state->length) copy = state->length;
595                 state->length -= copy;
596                 left -= copy;
597                 do {
598                     *put++ = *from++;
599                 } while (--copy);
600             } while (state->length != 0);
601             break;
602 
603         case DONE:
604             /* inflate stream terminated properly -- write leftover output */
605             ret = Z_STREAM_END;
606             if (left < state->wsize) {
607                 if (out(out_desc, state->window, state->wsize - left))
608                     ret = Z_BUF_ERROR;
609             }
610             goto inf_leave;
611 
612         case BAD:
613             ret = Z_DATA_ERROR;
614             goto inf_leave;
615 
616         default:                /* can't happen, but makes compilers happy */
617             ret = Z_STREAM_ERROR;
618             goto inf_leave;
619         }
620 
621     /* Return unused input */
622   inf_leave:
623     strm->next_in = next;
624     strm->avail_in = have;
625     return ret;
626 }
627 
628 int ZEXPORT inflateBackEnd(strm)
629 z_streamp strm;
630 {
631     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
632         return Z_STREAM_ERROR;
633     ZFREE(strm, strm->state);
634     strm->state = Z_NULL;
635     Tracev((stderr, "inflate: end\n"));
636     return Z_OK;
637 }
638