xref: /openbsd/sys/lib/libz/infback.c (revision 76d0caae)
1 /*	$OpenBSD: infback.c,v 1.8 2021/07/22 16:40:20 tb Exp $ */
2 /* infback.c -- inflate using a call-back interface
3  * Copyright (C) 1995-2016 Mark Adler
4  * For conditions of distribution and use, see copyright notice in zlib.h
5  */
6 
7 /*
8    This code is largely copied from inflate.c.  Normally either infback.o or
9    inflate.o would be linked into an application--not both.  The interface
10    with inffast.c is retained so that optimized assembler-coded versions of
11    inflate_fast() can be used with either inflate.c or infback.c.
12  */
13 
14 #include "zutil.h"
15 #include "inftrees.h"
16 #include "inflate.h"
17 #include "inffast.h"
18 
19 /* function prototypes */
20 local void fixedtables OF((struct inflate_state FAR *state));
21 
22 /*
23    strm provides memory allocation functions in zalloc and zfree, or
24    Z_NULL to use the library memory allocation functions.
25 
26    windowBits is in the range 8..15, and window is a user-supplied
27    window and output buffer that is 2**windowBits bytes.
28  */
29 int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
30 z_streamp strm;
31 int windowBits;
32 unsigned char FAR *window;
33 const char *version;
34 int stream_size;
35 {
36     struct inflate_state FAR *state;
37 
38     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
39         stream_size != (int)(sizeof(z_stream)))
40         return Z_VERSION_ERROR;
41     if (strm == Z_NULL || window == Z_NULL ||
42         windowBits < 8 || windowBits > 15)
43         return Z_STREAM_ERROR;
44     strm->msg = Z_NULL;                 /* in case we return an error */
45     if (strm->zalloc == (alloc_func)0) {
46 #ifdef Z_SOLO
47         return Z_STREAM_ERROR;
48 #else
49         strm->zalloc = zcalloc;
50         strm->opaque = (voidpf)0;
51 #endif
52     }
53     if (strm->zfree == (free_func)0)
54 #ifdef Z_SOLO
55         return Z_STREAM_ERROR;
56 #else
57     strm->zfree = zcfree;
58 #endif
59     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
60                                                sizeof(struct inflate_state));
61     if (state == Z_NULL) return Z_MEM_ERROR;
62     Tracev((stderr, "inflate: allocated\n"));
63     strm->state = (struct internal_state FAR *)state;
64     state->dmax = 32768U;
65     state->wbits = (uInt)windowBits;
66     state->wsize = 1U << windowBits;
67     state->window = window;
68     state->wnext = 0;
69     state->whave = 0;
70     return Z_OK;
71 }
72 
73 /*
74    Return state with length and distance decoding tables and index sizes set to
75    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
76    If BUILDFIXED is defined, then instead this routine builds the tables the
77    first time it's called, and returns those tables the first time and
78    thereafter.  This reduces the size of the code by about 2K bytes, in
79    exchange for a little execution time.  However, BUILDFIXED should not be
80    used for threaded applications, since the rewriting of the tables and virgin
81    may not be thread-safe.
82  */
83 local void fixedtables(state)
84 struct inflate_state FAR *state;
85 {
86 #ifdef BUILDFIXED
87     static int virgin = 1;
88     static code *lenfix, *distfix;
89     static code fixed[544];
90 
91     /* build fixed huffman tables if first call (may not be thread safe) */
92     if (virgin) {
93         unsigned sym, bits;
94         static code *next;
95 
96         /* literal/length table */
97         sym = 0;
98         while (sym < 144) state->lens[sym++] = 8;
99         while (sym < 256) state->lens[sym++] = 9;
100         while (sym < 280) state->lens[sym++] = 7;
101         while (sym < 288) state->lens[sym++] = 8;
102         next = fixed;
103         lenfix = next;
104         bits = 9;
105         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
106 
107         /* distance table */
108         sym = 0;
109         while (sym < 32) state->lens[sym++] = 5;
110         distfix = next;
111         bits = 5;
112         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
113 
114         /* do this just once */
115         virgin = 0;
116     }
117 #else /* !BUILDFIXED */
118 #   include "inffixed.h"
119 #endif /* BUILDFIXED */
120     state->lencode = lenfix;
121     state->lenbits = 9;
122     state->distcode = distfix;
123     state->distbits = 5;
124 }
125 
/* Macros for inflateBack(): */

/* Reload the local registers from strm and state after inflate_fast() */
#define LOAD() \
    do { \
        hold = state->hold; \
        bits = state->bits; \
        next = strm->next_in; \
        have = strm->avail_in; \
        put = strm->next_out; \
        left = strm->avail_out; \
    } while (0)

/* Store the local registers into strm and state for inflate_fast() */
#define RESTORE() \
    do { \
        state->hold = hold; \
        state->bits = bits; \
        strm->next_in = next; \
        strm->avail_in = have; \
        strm->next_out = put; \
        strm->avail_out = left; \
    } while (0)

/* Empty the input bit accumulator */
#define INITBITS() \
    do { \
        bits = 0; \
        hold = 0; \
    } while (0)

/* Make sure at least one input byte is available, asking in() for more once
   the current buffer is used up.  If in() supplies nothing, set next to
   Z_NULL (so the caller can tell an input failure from an output failure)
   and leave inflateBack() with Z_BUF_ERROR. */
#define PULL() \
    do { \
        if (have == 0 && (have = in(in_desc, &next)) == 0) { \
            next = Z_NULL; \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)

/* Add the next input byte above the bits already in the accumulator; may
   leave inflateBack() through PULL() when no input can be had. */
#define PULLBYTE() \
    do { \
        PULL(); \
        hold += (unsigned long)(*next) << bits; \
        next++; \
        have--; \
        bits += 8; \
    } while (0)

/* Pull bytes until the accumulator holds at least n bits; may leave
   inflateBack() through PULL() when input runs out. */
#define NEEDBITS(n) \
    do { \
        for (; bits < (unsigned)(n); ) \
            PULLBYTE(); \
    } while (0)

/* The low n bits of the accumulator (n < 16) */
#define BITS(n) \
    ((unsigned)hold & ((1U << (n)) - 1U))

/* Discard the low n bits of the accumulator */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* Discard 0..7 bits so the accumulator lines up with a byte boundary */
#define BYTEBITS() \
    do { \
        unsigned pad_bits = bits & 7; \
        hold >>= pad_bits; \
        bits -= pad_bits; \
    } while (0)

/* Make sure there is room for output.  When the window is full, hand all of
   it to out(), mark the whole window as history, and resume writing at its
   start; if out() reports failure, leave inflateBack() with Z_BUF_ERROR. */
#define ROOM() \
    do { \
        if (left == 0) { \
            put = state->window; \
            state->whave = left = state->wsize; \
            if (out(out_desc, put, left) != 0) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
223 
224 /*
225    strm provides the memory allocation functions and window buffer on input,
226    and provides information on the unused input on return.  For Z_DATA_ERROR
227    returns, strm will also provide an error message.
228 
229    in() and out() are the call-back input and output functions.  When
230    inflateBack() needs more input, it calls in().  When inflateBack() has
231    filled the window with output, or when it completes with data in the
232    window, it calls out() to write out the data.  The application must not
233    change the provided input until in() is called again or inflateBack()
234    returns.  The application must not change the window/output buffer until
235    inflateBack() returns.
236 
237    in() and out() are called with a descriptor parameter provided in the
238    inflateBack() call.  This parameter can be a structure that provides the
239    information required to do the read or write, as well as accumulated
240    information on the input and output such as totals and check values.
241 
   in() should return zero on failure.  out() should return non-zero on
   failure.  If either in() or out() fails, then inflateBack() returns
   Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
   was in() or out() that caused the error.  Otherwise, inflateBack()
   returns Z_STREAM_END on success or Z_DATA_ERROR for a deflate format
   error.  (Z_MEM_ERROR, for failure to allocate the state, comes from
   inflateBackInit_(); inflateBack() itself allocates no memory.)
   inflateBack() can also return Z_STREAM_ERROR if the input parameters
   are not correct, i.e. strm is Z_NULL or the state was not initialized.
250  */
251 int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
252 z_streamp strm;
253 in_func in;
254 void FAR *in_desc;
255 out_func out;
256 void FAR *out_desc;
257 {
258     struct inflate_state FAR *state;
259     z_const unsigned char FAR *next;    /* next input */
260     unsigned char FAR *put;     /* next output */
261     unsigned have, left;        /* available input and output */
262     unsigned long hold;         /* bit buffer */
263     unsigned bits;              /* bits in bit buffer */
264     unsigned copy;              /* number of stored or match bytes to copy */
265     unsigned char FAR *from;    /* where to copy match bytes from */
266     code here;                  /* current decoding table entry */
267     code last;                  /* parent table entry */
268     unsigned len;               /* length to copy for repeats, bits to drop */
269     int ret;                    /* return code */
270     static const unsigned short order[19] = /* permutation of code lengths */
271         {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
272 
273     /* Check that the strm exists and that the state was initialized */
274     if (strm == Z_NULL || strm->state == Z_NULL)
275         return Z_STREAM_ERROR;
276     state = (struct inflate_state FAR *)strm->state;
277 
278     /* Reset the state */
279     strm->msg = Z_NULL;
280     state->mode = TYPE;
281     state->last = 0;
282     state->whave = 0;
283     next = strm->next_in;
284     have = next != Z_NULL ? strm->avail_in : 0;
285     hold = 0;
286     bits = 0;
287     put = state->window;
288     left = state->wsize;
289 
290     /* Inflate until end of block marked as last */
291     for (;;)
292         switch (state->mode) {
293         case TYPE:
294             /* determine and dispatch block type */
295             if (state->last) {
296                 BYTEBITS();
297                 state->mode = DONE;
298                 break;
299             }
300             NEEDBITS(3);
301             state->last = BITS(1);
302             DROPBITS(1);
303             switch (BITS(2)) {
304             case 0:                             /* stored block */
305                 Tracev((stderr, "inflate:     stored block%s\n",
306                         state->last ? " (last)" : ""));
307                 state->mode = STORED;
308                 break;
309             case 1:                             /* fixed block */
310                 fixedtables(state);
311                 Tracev((stderr, "inflate:     fixed codes block%s\n",
312                         state->last ? " (last)" : ""));
313                 state->mode = LEN;              /* decode codes */
314                 break;
315             case 2:                             /* dynamic block */
316                 Tracev((stderr, "inflate:     dynamic codes block%s\n",
317                         state->last ? " (last)" : ""));
318                 state->mode = TABLE;
319                 break;
320             case 3:
321 #ifdef SMALL
322                 strm->msg = "error";
323 #else
324                 strm->msg = (char *)"invalid block type";
325 #endif
326                 state->mode = BAD;
327             }
328             DROPBITS(2);
329             break;
330 
331         case STORED:
332             /* get and verify stored block length */
333             BYTEBITS();                         /* go to byte boundary */
334             NEEDBITS(32);
335             if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
336 #ifdef SMALL
337                 strm->msg = "error";
338 #else
339                 strm->msg = (char *)"invalid stored block lengths";
340 #endif
341                 state->mode = BAD;
342                 break;
343             }
344             state->length = (unsigned)hold & 0xffff;
345             Tracev((stderr, "inflate:       stored length %u\n",
346                     state->length));
347             INITBITS();
348 
349             /* copy stored block from input to output */
350             while (state->length != 0) {
351                 copy = state->length;
352                 PULL();
353                 ROOM();
354                 if (copy > have) copy = have;
355                 if (copy > left) copy = left;
356                 zmemcpy(put, next, copy);
357                 have -= copy;
358                 next += copy;
359                 left -= copy;
360                 put += copy;
361                 state->length -= copy;
362             }
363             Tracev((stderr, "inflate:       stored end\n"));
364             state->mode = TYPE;
365             break;
366 
367         case TABLE:
368             /* get dynamic table entries descriptor */
369             NEEDBITS(14);
370             state->nlen = BITS(5) + 257;
371             DROPBITS(5);
372             state->ndist = BITS(5) + 1;
373             DROPBITS(5);
374             state->ncode = BITS(4) + 4;
375             DROPBITS(4);
376 #ifndef PKZIP_BUG_WORKAROUND
377             if (state->nlen > 286 || state->ndist > 30) {
378 #ifdef SMALL
379                 strm->msg = "error";
380 #else
381                 strm->msg = (char *)"too many length or distance symbols";
382 #endif
383                 state->mode = BAD;
384                 break;
385             }
386 #endif
387             Tracev((stderr, "inflate:       table sizes ok\n"));
388 
389             /* get code length code lengths (not a typo) */
390             state->have = 0;
391             while (state->have < state->ncode) {
392                 NEEDBITS(3);
393                 state->lens[order[state->have++]] = (unsigned short)BITS(3);
394                 DROPBITS(3);
395             }
396             while (state->have < 19)
397                 state->lens[order[state->have++]] = 0;
398             state->next = state->codes;
399             state->lencode = (code const FAR *)(state->next);
400             state->lenbits = 7;
401             ret = inflate_table(CODES, state->lens, 19, &(state->next),
402                                 &(state->lenbits), state->work);
403             if (ret) {
404 #ifdef SMALL
405                 strm->msg = "error";
406 #else
407                 strm->msg = (char *)"invalid code lengths set";
408 #endif
409                 state->mode = BAD;
410                 break;
411             }
412             Tracev((stderr, "inflate:       code lengths ok\n"));
413 
414             /* get length and distance code code lengths */
415             state->have = 0;
416             while (state->have < state->nlen + state->ndist) {
417                 for (;;) {
418                     here = state->lencode[BITS(state->lenbits)];
419                     if ((unsigned)(here.bits) <= bits) break;
420                     PULLBYTE();
421                 }
422                 if (here.val < 16) {
423                     DROPBITS(here.bits);
424                     state->lens[state->have++] = here.val;
425                 }
426                 else {
427                     if (here.val == 16) {
428                         NEEDBITS(here.bits + 2);
429                         DROPBITS(here.bits);
430                         if (state->have == 0) {
431 #ifdef SMALL
432                             strm->msg = "error";
433 #else
434                             strm->msg = (char *)"invalid bit length repeat";
435 #endif
436                             state->mode = BAD;
437                             break;
438                         }
439                         len = (unsigned)(state->lens[state->have - 1]);
440                         copy = 3 + BITS(2);
441                         DROPBITS(2);
442                     }
443                     else if (here.val == 17) {
444                         NEEDBITS(here.bits + 3);
445                         DROPBITS(here.bits);
446                         len = 0;
447                         copy = 3 + BITS(3);
448                         DROPBITS(3);
449                     }
450                     else {
451                         NEEDBITS(here.bits + 7);
452                         DROPBITS(here.bits);
453                         len = 0;
454                         copy = 11 + BITS(7);
455                         DROPBITS(7);
456                     }
457                     if (state->have + copy > state->nlen + state->ndist) {
458 #ifdef SMALL
459                         strm->msg = "error";
460 #else
461                         strm->msg = (char *)"invalid bit length repeat";
462 #endif
463                         state->mode = BAD;
464                         break;
465                     }
466                     while (copy--)
467                         state->lens[state->have++] = (unsigned short)len;
468                 }
469             }
470 
471             /* handle error breaks in while */
472             if (state->mode == BAD) break;
473 
474             /* check for end-of-block code (better have one) */
475             if (state->lens[256] == 0) {
476 #ifdef SMALL
477                 strm->msg = "error";
478 #else
479                 strm->msg = (char *)"invalid code -- missing end-of-block";
480 #endif
481                 state->mode = BAD;
482                 break;
483             }
484 
485             /* build code tables -- note: do not change the lenbits or distbits
486                values here (9 and 6) without reading the comments in inftrees.h
487                concerning the ENOUGH constants, which depend on those values */
488             state->next = state->codes;
489             state->lencode = (code const FAR *)(state->next);
490             state->lenbits = 9;
491             ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
492                                 &(state->lenbits), state->work);
493             if (ret) {
494 #ifdef SMALL
495                 strm->msg = "error";
496 #else
497                 strm->msg = (char *)"invalid literal/lengths set";
498 #endif
499                 state->mode = BAD;
500                 break;
501             }
502             state->distcode = (code const FAR *)(state->next);
503             state->distbits = 6;
504             ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
505                             &(state->next), &(state->distbits), state->work);
506             if (ret) {
507 #ifdef SMALL
508                 strm->msg = "error";
509 #else
510                 strm->msg = (char *)"invalid distances set";
511 #endif
512                 state->mode = BAD;
513                 break;
514             }
515             Tracev((stderr, "inflate:       codes ok\n"));
516             state->mode = LEN;
517 
518         case LEN:
519 #ifndef SLOW
520             /* use inflate_fast() if we have enough input and output */
521             if (have >= 6 && left >= 258) {
522                 RESTORE();
523                 if (state->whave < state->wsize)
524                     state->whave = state->wsize - left;
525                 inflate_fast(strm, state->wsize);
526                 LOAD();
527                 break;
528             }
529 #endif
530 
531             /* get a literal, length, or end-of-block code */
532             for (;;) {
533                 here = state->lencode[BITS(state->lenbits)];
534                 if ((unsigned)(here.bits) <= bits) break;
535                 PULLBYTE();
536             }
537             if (here.op && (here.op & 0xf0) == 0) {
538                 last = here;
539                 for (;;) {
540                     here = state->lencode[last.val +
541                             (BITS(last.bits + last.op) >> last.bits)];
542                     if ((unsigned)(last.bits + here.bits) <= bits) break;
543                     PULLBYTE();
544                 }
545                 DROPBITS(last.bits);
546             }
547             DROPBITS(here.bits);
548             state->length = (unsigned)here.val;
549 
550             /* process literal */
551             if (here.op == 0) {
552                 Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
553                         "inflate:         literal '%c'\n" :
554                         "inflate:         literal 0x%02x\n", here.val));
555                 ROOM();
556                 *put++ = (unsigned char)(state->length);
557                 left--;
558                 state->mode = LEN;
559                 break;
560             }
561 
562             /* process end of block */
563             if (here.op & 32) {
564                 Tracevv((stderr, "inflate:         end of block\n"));
565                 state->mode = TYPE;
566                 break;
567             }
568 
569             /* invalid code */
570             if (here.op & 64) {
571 #ifdef SMALL
572                 strm->msg = "error";
573 #else
574                 strm->msg = (char *)"invalid literal/length code";
575 #endif
576                 state->mode = BAD;
577                 break;
578             }
579 
580             /* length code -- get extra bits, if any */
581             state->extra = (unsigned)(here.op) & 15;
582             if (state->extra != 0) {
583                 NEEDBITS(state->extra);
584                 state->length += BITS(state->extra);
585                 DROPBITS(state->extra);
586             }
587             Tracevv((stderr, "inflate:         length %u\n", state->length));
588 
589             /* get distance code */
590             for (;;) {
591                 here = state->distcode[BITS(state->distbits)];
592                 if ((unsigned)(here.bits) <= bits) break;
593                 PULLBYTE();
594             }
595             if ((here.op & 0xf0) == 0) {
596                 last = here;
597                 for (;;) {
598                     here = state->distcode[last.val +
599                             (BITS(last.bits + last.op) >> last.bits)];
600                     if ((unsigned)(last.bits + here.bits) <= bits) break;
601                     PULLBYTE();
602                 }
603                 DROPBITS(last.bits);
604             }
605             DROPBITS(here.bits);
606             if (here.op & 64) {
607 #ifdef SMALL
608                 strm->msg = "error";
609 #else
610                 strm->msg = (char *)"invalid distance code";
611 #endif
612                 state->mode = BAD;
613                 break;
614             }
615             state->offset = (unsigned)here.val;
616 
617             /* get distance extra bits, if any */
618             state->extra = (unsigned)(here.op) & 15;
619             if (state->extra != 0) {
620                 NEEDBITS(state->extra);
621                 state->offset += BITS(state->extra);
622                 DROPBITS(state->extra);
623             }
624             if (state->offset > state->wsize - (state->whave < state->wsize ?
625                                                 left : 0)) {
626 #ifdef SMALL
627                 strm->msg = "error";
628 #else
629                 strm->msg = (char *)"invalid distance too far back";
630 #endif
631                 state->mode = BAD;
632                 break;
633             }
634             Tracevv((stderr, "inflate:         distance %u\n", state->offset));
635 
636             /* copy match from window to output */
637             do {
638                 ROOM();
639                 copy = state->wsize - state->offset;
640                 if (copy < left) {
641                     from = put + copy;
642                     copy = left - copy;
643                 }
644                 else {
645                     from = put - state->offset;
646                     copy = left;
647                 }
648                 if (copy > state->length) copy = state->length;
649                 state->length -= copy;
650                 left -= copy;
651                 do {
652                     *put++ = *from++;
653                 } while (--copy);
654             } while (state->length != 0);
655             break;
656 
657         case DONE:
658             /* inflate stream terminated properly -- write leftover output */
659             ret = Z_STREAM_END;
660             if (left < state->wsize) {
661                 if (out(out_desc, state->window, state->wsize - left))
662                     ret = Z_BUF_ERROR;
663             }
664             goto inf_leave;
665 
666         case BAD:
667             ret = Z_DATA_ERROR;
668             goto inf_leave;
669 
670         default:                /* can't happen, but makes compilers happy */
671             ret = Z_STREAM_ERROR;
672             goto inf_leave;
673         }
674 
675     /* Return unused input */
676   inf_leave:
677     strm->next_in = next;
678     strm->avail_in = have;
679     return ret;
680 }
681 
682 int ZEXPORT inflateBackEnd(strm)
683 z_streamp strm;
684 {
685     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
686         return Z_STREAM_ERROR;
687     ZFREE(strm, strm->state, sizeof(struct inflate_state));
688     strm->state = Z_NULL;
689     Tracev((stderr, "inflate: end\n"));
690     return Z_OK;
691 }
692