xref: /openbsd/lib/libz/infback.c (revision 4fb9ab68)
1 /* infback.c -- inflate using a call-back interface
2  * Copyright (C) 1995-2022 Mark Adler
3  * For conditions of distribution and use, see copyright notice in zlib.h
4  */
5 
6 /*
7    This code is largely copied from inflate.c.  Normally either infback.o or
8    inflate.o would be linked into an application--not both.  The interface
9    with inffast.c is retained so that optimized assembler-coded versions of
10    inflate_fast() can be used with either inflate.c or infback.c.
11  */
12 
13 #include "zutil.h"
14 #include "inftrees.h"
15 #include "inflate.h"
16 #include "inffast.h"
17 
18 /*
19    strm provides memory allocation functions in zalloc and zfree, or
20    Z_NULL to use the library memory allocation functions.
21 
22    windowBits is in the range 8..15, and window is a user-supplied
23    window and output buffer that is 2**windowBits bytes.
24  */
25 int ZEXPORT inflateBackInit_(z_streamp strm, int windowBits,
26                              unsigned char FAR *window, const char *version,
27                              int stream_size) {
28     struct inflate_state FAR *state;
29 
30     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
31         stream_size != (int)(sizeof(z_stream)))
32         return Z_VERSION_ERROR;
33     if (strm == Z_NULL || window == Z_NULL ||
34         windowBits < 8 || windowBits > 15)
35         return Z_STREAM_ERROR;
36     strm->msg = Z_NULL;                 /* in case we return an error */
37     if (strm->zalloc == (alloc_func)0) {
38 #ifdef Z_SOLO
39         return Z_STREAM_ERROR;
40 #else
41         strm->zalloc = zcalloc;
42         strm->opaque = (voidpf)0;
43 #endif
44     }
45     if (strm->zfree == (free_func)0)
46 #ifdef Z_SOLO
47         return Z_STREAM_ERROR;
48 #else
49     strm->zfree = zcfree;
50 #endif
51     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
52                                                sizeof(struct inflate_state));
53     if (state == Z_NULL) return Z_MEM_ERROR;
54     Tracev((stderr, "inflate: allocated\n"));
55     strm->state = (struct internal_state FAR *)state;
56     state->dmax = 32768U;
57     state->wbits = (uInt)windowBits;
58     state->wsize = 1U << windowBits;
59     state->window = window;
60     state->wnext = 0;
61     state->whave = 0;
62     state->sane = 1;
63     return Z_OK;
64 }
65 
66 /*
67    Return state with length and distance decoding tables and index sizes set to
68    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
69    If BUILDFIXED is defined, then instead this routine builds the tables the
70    first time it's called, and returns those tables the first time and
71    thereafter.  This reduces the size of the code by about 2K bytes, in
72    exchange for a little execution time.  However, BUILDFIXED should not be
73    used for threaded applications, since the rewriting of the tables and virgin
74    may not be thread-safe.
75  */
76 local void fixedtables(struct inflate_state FAR *state) {
77 #ifdef BUILDFIXED
78     static int virgin = 1;
79     static code *lenfix, *distfix;
80     static code fixed[544];
81 
82     /* build fixed huffman tables if first call (may not be thread safe) */
83     if (virgin) {
84         unsigned sym, bits;
85         static code *next;
86 
87         /* literal/length table */
88         sym = 0;
89         while (sym < 144) state->lens[sym++] = 8;
90         while (sym < 256) state->lens[sym++] = 9;
91         while (sym < 280) state->lens[sym++] = 7;
92         while (sym < 288) state->lens[sym++] = 8;
93         next = fixed;
94         lenfix = next;
95         bits = 9;
96         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
97 
98         /* distance table */
99         sym = 0;
100         while (sym < 32) state->lens[sym++] = 5;
101         distfix = next;
102         bits = 5;
103         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
104 
105         /* do this just once */
106         virgin = 0;
107     }
108 #else /* !BUILDFIXED */
109 #   include "inffixed.h"
110 #endif /* BUILDFIXED */
111     state->lencode = lenfix;
112     state->lenbits = 9;
113     state->distcode = distfix;
114     state->distbits = 5;
115 }
116 
/* Macros for inflateBack(): */

/*
   Each macro below works directly on inflateBack()'s local variables
   (strm, state, next, have, put, left, hold, bits, ret, in, in_desc, out and
   out_desc).  PULL(), PULLBYTE(), NEEDBITS() and ROOM() may also jump to its
   inf_leave label, so these macros are only meaningful inside that function.
 */

/* Reload the local copies after inflate_fast() has advanced strm/state. */
#define LOAD() \
    do { \
        next = strm->next_in; \
        have = strm->avail_in; \
        put = strm->next_out; \
        left = strm->avail_out; \
        bits = state->bits; \
        hold = state->hold; \
    } while (0)

/* Publish the local copies to strm/state before calling inflate_fast(). */
#define RESTORE() \
    do { \
        strm->next_in = next; \
        strm->avail_in = have; \
        strm->next_out = put; \
        strm->avail_out = left; \
        state->bits = bits; \
        state->hold = hold; \
    } while (0)

/* Empty the bit accumulator. */
#define INITBITS() \
    do { \
        bits = 0; \
        hold = 0; \
    } while (0)

/* Ensure at least one input byte is buffered, calling in() when the buffer
   is empty.  If in() supplies nothing, next is set to Z_NULL (so callers can
   tell it was in() that failed) and inflateBack() leaves with Z_BUF_ERROR. */
#define PULL() \
    do { \
        if (have == 0 && (have = in(in_desc, &next)) == 0) { \
            next = Z_NULL; \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)

/* Append the next input byte above the bits already in the accumulator;
   leaves inflateBack() with Z_BUF_ERROR (via PULL()) if no input arrives. */
#define PULLBYTE() \
    do { \
        PULL(); \
        hold += (unsigned long)(*next) << bits; \
        bits += 8; \
        next++; \
        have--; \
    } while (0)

/* Pull whole bytes until the accumulator holds at least n bits. */
#define NEEDBITS(n) \
    do { \
        for (; bits < (unsigned)(n); ) \
            PULLBYTE(); \
    } while (0)

/* Peek at, without consuming, the low n bits of the accumulator (n < 16). */
#define BITS(n) \
    (((1U << (n)) - 1U) & (unsigned)hold)

/* Consume the low n bits of the accumulator. */
#define DROPBITS(n) \
    do { \
        hold = hold >> (n); \
        bits = bits - (unsigned)(n); \
    } while (0)

/* Discard 0..7 bits so the accumulator ends on a byte boundary. */
#define BYTEBITS() \
    do { \
        hold >>= bits & 7; \
        bits &= ~7U; \
    } while (0)

/* Ensure there is room for at least one output byte.  When the window is
   full, whave is set to wsize, the whole window is handed to out(), and
   output restarts at the window's beginning; if out() reports failure,
   inflateBack() leaves with Z_BUF_ERROR. */
#define ROOM() \
    do { \
        if (left == 0) { \
            state->whave = state->wsize; \
            left = state->wsize; \
            put = state->window; \
            if (out(out_desc, put, left) != 0) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
214 
215 /*
   inflateBack() decodes a raw deflate stream. It obtains input through the
   in() call-back and delivers output through the out() call-back, rather
   than using strm's next_out/avail_out buffer.

   strm must have been prepared by inflateBackInit_(); its state supplies the
   window, which doubles as the output buffer.  Every call starts a new
   stream: the decoding mode, last-block flag, window fill count and bit
   accumulator are all reset on entry.  Input that is already on hand may be
   passed in strm->next_in/strm->avail_in (avail_in is ignored when next_in
   is Z_NULL).  On return, strm->next_in/strm->avail_in describe the input
   that was not consumed.  For Z_DATA_ERROR returns, strm->msg points to an
   error message.

   in(in_desc, &buf) is called whenever more input is needed.  It returns
   the number of bytes it made available at buf; zero means failure.  The
   application must not change that input until in() is called again or
   inflateBack() returns.

   out(out_desc, buf, len) is called each time the window fills.  Apart from
   the initial argument check, every return path also passes any output
   still in the window to out() before returning.  A non-zero return from
   out() means failure.  The application must not change the window/output
   buffer until inflateBack() returns.

   in_desc and out_desc are passed through unchanged.  They can point to
   whatever the call-backs need to do the read or write, as well as to
   accumulated information such as totals and check values.

   Return values:
     Z_STREAM_END    the last block was decoded and out() accepted all output
     Z_BUF_ERROR     in() or out() failed; strm->next_in is Z_NULL if it was
                     in() that failed
     Z_DATA_ERROR    the input is not a valid deflate stream
     Z_STREAM_ERROR  strm is Z_NULL or its state was not initialized
   Memory is allocated only by inflateBackInit_(), so this function never
   returns Z_MEM_ERROR.
 */
242 int ZEXPORT inflateBack(z_streamp strm, in_func in, void FAR *in_desc,
243                         out_func out, void FAR *out_desc) {
244     struct inflate_state FAR *state;
245     z_const unsigned char FAR *next;    /* next input */
246     unsigned char FAR *put;     /* next output */
247     unsigned have, left;        /* available input and output */
248     unsigned long hold;         /* bit buffer */
249     unsigned bits;              /* bits in bit buffer */
250     unsigned copy;              /* number of stored or match bytes to copy */
251     unsigned char FAR *from;    /* where to copy match bytes from */
252     code here;                  /* current decoding table entry */
253     code last;                  /* parent table entry */
254     unsigned len;               /* code length value repeated by codes 16-18 */
255     int ret;                    /* return code */
256     static const unsigned short order[19] = /* permutation of code lengths */
257         {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
258
259     /* Check that the strm exists and that the state was initialized */
260     if (strm == Z_NULL || strm->state == Z_NULL)
261         return Z_STREAM_ERROR;
262     state = (struct inflate_state FAR *)strm->state;
263
264     /* Reset the state: every call decodes a new stream into an empty window */
265     strm->msg = Z_NULL;
266     state->mode = TYPE;
267     state->last = 0;
268     state->whave = 0;
269     next = strm->next_in;
270     have = next != Z_NULL ? strm->avail_in : 0;
271     hold = 0;
272     bits = 0;
273     put = state->window;
274     left = state->wsize;
275
276     /* Inflate until end of block marked as last */
277     for (;;)
278         switch (state->mode) {
279         case TYPE:
280             /* determine and dispatch block type */
                /* after the last block, skip padding to a byte boundary */
281             if (state->last) {
282                 BYTEBITS();
283                 state->mode = DONE;
284                 break;
285             }
                /* 1-bit last-block flag followed by the 2-bit block type */
286             NEEDBITS(3);
287             state->last = BITS(1);
288             DROPBITS(1);
289             switch (BITS(2)) {
290             case 0:                             /* stored block */
291                 Tracev((stderr, "inflate:     stored block%s\n",
292                         state->last ? " (last)" : ""));
293                 state->mode = STORED;
294                 break;
295             case 1:                             /* fixed block */
296                 fixedtables(state);
297                 Tracev((stderr, "inflate:     fixed codes block%s\n",
298                         state->last ? " (last)" : ""));
299                 state->mode = LEN;              /* decode codes */
300                 break;
301             case 2:                             /* dynamic block */
302                 Tracev((stderr, "inflate:     dynamic codes block%s\n",
303                         state->last ? " (last)" : ""));
304                 state->mode = TABLE;
305                 break;
306             case 3:
307 #ifdef SMALL
308                 strm->msg = (z_const char *)"error";
309 #else
310                 strm->msg = (z_const char *)"invalid block type";
311 #endif
312                 state->mode = BAD;
313             }
314             DROPBITS(2);
315             break;
316
317         case STORED:
318             /* get and verify stored block length */
319             BYTEBITS();                         /* go to byte boundary */
320             NEEDBITS(32);
                /* the low 16 bits (length) must equal the one's complement
                   of the next 16 bits */
321             if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
322 #ifdef SMALL
323                 strm->msg = (z_const char *)"error";
324 #else
325                 strm->msg = (z_const char *)"invalid stored block lengths";
326 #endif
327                 state->mode = BAD;
328                 break;
329             }
330             state->length = (unsigned)hold & 0xffff;
331             Tracev((stderr, "inflate:       stored length %u\n",
332                     state->length));
                /* discard the length fields; the stored bytes are then
                   copied straight from next, bypassing the bit buffer */
333             INITBITS();
334
335             /* copy stored block from input to output */
336             while (state->length != 0) {
337                 copy = state->length;
338                 PULL();
339                 ROOM();
340                 if (copy > have) copy = have;
341                 if (copy > left) copy = left;
342                 zmemcpy(put, next, copy);
343                 have -= copy;
344                 next += copy;
345                 left -= copy;
346                 put += copy;
347                 state->length -= copy;
348             }
349             Tracev((stderr, "inflate:       stored end\n"));
350             state->mode = TYPE;
351             break;
352
353         case TABLE:
354             /* get dynamic table entries descriptor */
                /* 5 bits: literal/length count - 257, 5 bits: distance
                   count - 1, 4 bits: code length code count - 4 */
355             NEEDBITS(14);
356             state->nlen = BITS(5) + 257;
357             DROPBITS(5);
358             state->ndist = BITS(5) + 1;
359             DROPBITS(5);
360             state->ncode = BITS(4) + 4;
361             DROPBITS(4);
362 #ifndef PKZIP_BUG_WORKAROUND
363             if (state->nlen > 286 || state->ndist > 30) {
364 #ifdef SMALL
365                 strm->msg = (z_const char *)"error";
366 #else
367                 strm->msg = (z_const char *)"too many length or distance symbols";
368 #endif
369                 state->mode = BAD;
370                 break;
371             }
372 #endif
373             Tracev((stderr, "inflate:       table sizes ok\n"));
374
375             /* get code length code lengths (not a typo) */
376             state->have = 0;
377             while (state->have < state->ncode) {
378                 NEEDBITS(3);
379                 state->lens[order[state->have++]] = (unsigned short)BITS(3);
380                 DROPBITS(3);
381             }
                /* code length codes that were not transmitted get length 0 */
382             while (state->have < 19)
383                 state->lens[order[state->have++]] = 0;
384             state->next = state->codes;
385             state->lencode = (code const FAR *)(state->next);
386             state->lenbits = 7;
387             ret = inflate_table(CODES, state->lens, 19, &(state->next),
388                                 &(state->lenbits), state->work);
389             if (ret) {
390 #ifdef SMALL
391                 strm->msg = (z_const char *)"error";
392 #else
393                 strm->msg = (z_const char *)"invalid code lengths set";
394 #endif
395                 state->mode = BAD;
396                 break;
397             }
398             Tracev((stderr, "inflate:       code lengths ok\n"));
399
400             /* get length and distance code code lengths */
401             state->have = 0;
402             while (state->have < state->nlen + state->ndist) {
403                 for (;;) {
404                     here = state->lencode[BITS(state->lenbits)];
405                     if ((unsigned)(here.bits) <= bits) break;
406                     PULLBYTE();
407                 }
                    /* 0..15 are lengths themselves; 16 repeats the previous
                       length 3..6 times, 17 repeats 0 for 3..10 times and
                       18 repeats 0 for 11..138 times */
408                 if (here.val < 16) {
409                     DROPBITS(here.bits);
410                     state->lens[state->have++] = here.val;
411                 }
412                 else {
413                     if (here.val == 16) {
414                         NEEDBITS(here.bits + 2);
415                         DROPBITS(here.bits);
416                         if (state->have == 0) {
417 #ifdef SMALL
418                             strm->msg = (z_const char *)"error";
419 #else
420                             strm->msg = (z_const char *)"invalid bit length repeat";
421 #endif
422                             state->mode = BAD;
423                             break;
424                         }
425                         len = (unsigned)(state->lens[state->have - 1]);
426                         copy = 3 + BITS(2);
427                         DROPBITS(2);
428                     }
429                     else if (here.val == 17) {
430                         NEEDBITS(here.bits + 3);
431                         DROPBITS(here.bits);
432                         len = 0;
433                         copy = 3 + BITS(3);
434                         DROPBITS(3);
435                     }
436                     else {
437                         NEEDBITS(here.bits + 7);
438                         DROPBITS(here.bits);
439                         len = 0;
440                         copy = 11 + BITS(7);
441                         DROPBITS(7);
442                     }
443                     if (state->have + copy > state->nlen + state->ndist) {
444 #ifdef SMALL
445                         strm->msg = (z_const char *)"error";
446 #else
447                         strm->msg = (z_const char *)"invalid bit length repeat";
448 #endif
449                         state->mode = BAD;
450                         break;
451                     }
452                     while (copy--)
453                         state->lens[state->have++] = (unsigned short)len;
454                 }
455             }
456
457             /* handle error breaks in while */
458             if (state->mode == BAD) break;
459
460             /* check for end-of-block code (better have one) */
461             if (state->lens[256] == 0) {
462 #ifdef SMALL
463                 strm->msg = (z_const char *)"error";
464 #else
465                 strm->msg = (z_const char *)"invalid code -- missing end-of-block";
466 #endif
467                 state->mode = BAD;
468                 break;
469             }
470
471             /* build code tables -- note: do not change the lenbits or distbits
472                values here (9 and 6) without reading the comments in inftrees.h
473                concerning the ENOUGH constants, which depend on those values */
474             state->next = state->codes;
475             state->lencode = (code const FAR *)(state->next);
476             state->lenbits = 9;
477             ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
478                                 &(state->lenbits), state->work);
479             if (ret) {
480 #ifdef SMALL
481                 strm->msg = (z_const char *)"error";
482 #else
483                 strm->msg = (z_const char *)"invalid literal/lengths set";
484 #endif
485                 state->mode = BAD;
486                 break;
487             }
488             state->distcode = (code const FAR *)(state->next);
489             state->distbits = 6;
490             ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
491                             &(state->next), &(state->distbits), state->work);
492             if (ret) {
493 #ifdef SMALL
494                 strm->msg = (z_const char *)"error";
495 #else
496                 strm->msg = (z_const char *)"invalid distances set";
497 #endif
498                 state->mode = BAD;
499                 break;
500             }
501             Tracev((stderr, "inflate:       codes ok\n"));
502             state->mode = LEN;
503                 /* fallthrough */
504
505         case LEN:
506 #ifndef SLOW
507             /* use inflate_fast() if we have enough input and output */
                /* inflate_fast() works on strm/state directly, so the local
                   copies are written back with RESTORE() and reloaded with
                   LOAD().  Until the window has filled once, whave is first
                   set to the number of bytes written so far.  Presumably
                   inflate_fast() uses whave to validate distances, and the
                   6/258 thresholds match its entry requirements; see
                   inffast.c to confirm. */
508             if (have >= 6 && left >= 258) {
509                 RESTORE();
510                 if (state->whave < state->wsize)
511                     state->whave = state->wsize - left;
512                 inflate_fast(strm, state->wsize);
513                 LOAD();
514                 break;
515             }
516 #endif
517
518             /* get a literal, length, or end-of-block code */
519             for (;;) {
520                 here = state->lencode[BITS(state->lenbits)];
521                 if ((unsigned)(here.bits) <= bits) break;
522                 PULLBYTE();
523             }
                /* a non-zero op with the high four bits clear links to a
                   sub-table at last.val, indexed by the next last.op bits */
524             if (here.op && (here.op & 0xf0) == 0) {
525                 last = here;
526                 for (;;) {
527                     here = state->lencode[last.val +
528                             (BITS(last.bits + last.op) >> last.bits)];
529                     if ((unsigned)(last.bits + here.bits) <= bits) break;
530                     PULLBYTE();
531                 }
532                 DROPBITS(last.bits);
533             }
534             DROPBITS(here.bits);
535             state->length = (unsigned)here.val;
536
537             /* process literal */
538             if (here.op == 0) {
539                 Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
540                         "inflate:         literal '%c'\n" :
541                         "inflate:         literal 0x%02x\n", here.val));
542                 ROOM();
543                 *put++ = (unsigned char)(state->length);
544                 left--;
545                 state->mode = LEN;
546                 break;
547             }
548
549             /* process end of block */
550             if (here.op & 32) {
551                 Tracevv((stderr, "inflate:         end of block\n"));
552                 state->mode = TYPE;
553                 break;
554             }
555
556             /* invalid code */
557             if (here.op & 64) {
558 #ifdef SMALL
559                 strm->msg = (z_const char *)"error";
560 #else
561                 strm->msg = (z_const char *)"invalid literal/length code";
562 #endif
563                 state->mode = BAD;
564                 break;
565             }
566
567             /* length code -- get extra bits, if any */
568             state->extra = (unsigned)(here.op) & 15;
569             if (state->extra != 0) {
570                 NEEDBITS(state->extra);
571                 state->length += BITS(state->extra);
572                 DROPBITS(state->extra);
573             }
574             Tracevv((stderr, "inflate:         length %u\n", state->length));
575
576             /* get distance code */
577             for (;;) {
578                 here = state->distcode[BITS(state->distbits)];
579                 if ((unsigned)(here.bits) <= bits) break;
580                 PULLBYTE();
581             }
582             if ((here.op & 0xf0) == 0) {
583                 last = here;
584                 for (;;) {
585                     here = state->distcode[last.val +
586                             (BITS(last.bits + last.op) >> last.bits)];
587                     if ((unsigned)(last.bits + here.bits) <= bits) break;
588                     PULLBYTE();
589                 }
590                 DROPBITS(last.bits);
591             }
592             DROPBITS(here.bits);
593             if (here.op & 64) {
594 #ifdef SMALL
595                 strm->msg = (z_const char *)"error";
596 #else
597                 strm->msg = (z_const char *)"invalid distance code";
598 #endif
599                 state->mode = BAD;
600                 break;
601             }
602             state->offset = (unsigned)here.val;
603
604             /* get distance extra bits, if any */
605             state->extra = (unsigned)(here.op) & 15;
606             if (state->extra != 0) {
607                 NEEDBITS(state->extra);
608                 state->offset += BITS(state->extra);
609                 DROPBITS(state->extra);
610             }
                /* until the window has filled once (whave < wsize), only
                   wsize - left bytes of history exist; a distance reaching
                   further back than that is invalid */
611             if (state->offset > state->wsize - (state->whave < state->wsize ?
612                                                 left : 0)) {
613 #ifdef SMALL
614                 strm->msg = (z_const char *)"error";
615 #else
616                 strm->msg = (z_const char *)"invalid distance too far back";
617 #endif
618                 state->mode = BAD;
619                 break;
620             }
621             Tracevv((stderr, "inflate:         distance %u\n", state->offset));
622
623             /* copy match from window to output */
                /* When the source would start before the beginning of the
                   window buffer (wsize - offset < left), it wraps around to
                   the buffer's end.  That pass copies only up to the end of
                   the buffer.  Every pass is also limited by the output space
                   left and by the match length remaining. */
624             do {
625                 ROOM();
626                 copy = state->wsize - state->offset;
627                 if (copy < left) {
628                     from = put + copy;
629                     copy = left - copy;
630                 }
631                 else {
632                     from = put - state->offset;
633                     copy = left;
634                 }
635                 if (copy > state->length) copy = state->length;
636                 state->length -= copy;
637                 left -= copy;
638                 do {
639                     *put++ = *from++;
640                 } while (--copy);
641             } while (state->length != 0);
642             break;
643
644         case DONE:
645             /* inflate stream terminated properly */
646             ret = Z_STREAM_END;
647             goto inf_leave;
648
649         case BAD:
650             ret = Z_DATA_ERROR;
651             goto inf_leave;
652
653         default:
654             /* can't happen, but makes compilers happy */
655             ret = Z_STREAM_ERROR;
656             goto inf_leave;
657         }
658
659     /* Write leftover output and return unused input */
660   inf_leave:
        /* an out() failure here only replaces Z_STREAM_END; any other
           return code is passed back unchanged */
661     if (left < state->wsize) {
662         if (out(out_desc, state->window, state->wsize - left) &&
663             ret == Z_STREAM_END)
664             ret = Z_BUF_ERROR;
665     }
666     strm->next_in = next;
667     strm->avail_in = have;
668     return ret;
669 }
670 
671 int ZEXPORT inflateBackEnd(z_streamp strm) {
672     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
673         return Z_STREAM_ERROR;
674     ZFREE(strm, strm->state);
675     strm->state = Z_NULL;
676     Tracev((stderr, "inflate: end\n"));
677     return Z_OK;
678 }
679