1 /*
2 ** SPDX-License-Identifier: BSD-3-Clause
3 ** Copyright Contributors to the OpenEXR Project.
4 */
5 
6 #include "internal_huf.h"
7 
8 #include "internal_memory.h"
9 
10 #include <stddef.h>
11 #include <stdint.h>
12 #include <string.h>
13 
// Width, in bits, of the symbols being encoded.
#define HUF_ENCBITS 16
// Number of code bits used to index the primary decoding table.
#define HUF_DECBITS 14

// Encoding table entries: one per 16-bit symbol plus one extra slot,
// which hufBuildEncTable() uses for its run-length pseudo-symbol.
#define HUF_ENCSIZE ((1 << HUF_ENCBITS) + 1)
// Number of entries in the primary decoding table.
#define HUF_DECSIZE (1 << HUF_DECBITS)
// Mask that reduces a value to a valid decoding table index.
#define HUF_DECMASK (HUF_DECSIZE - 1)
20 
// One entry of the primary decoding table filled in by hufBuildDecTable().
// An entry describes either a single short code (len != 0) or the set of
// long codes that share the entry's HUF_DECBITS-bit prefix (p != NULL).
typedef struct _HufDec
{
    int32_t   len; // short code: code length in bits; 0 otherwise
    uint32_t  lit; // short code: decoded symbol; long codes: entry count of p
    uint32_t* p;   // long codes: symbols whose codes map to this entry
} HufDec;
27 
28 /**************************************/
29 
// Extract the code length stored in the six low bits of a packed table entry.
static inline int
hufLength (uint64_t code)
{
    const uint64_t lengthMask = 0x3f;

    return (int) (code & lengthMask);
}
35 
// Extract the code value stored above the six length bits of a packed
// table entry.
static inline uint64_t
hufCode (uint64_t code)
{
    const unsigned lengthBits = 6;

    return code >> lengthBits;
}
41 
// Append the nBits low bits of 'bits' to the bit accumulator (*c, holding
// *lc pending bits) and flush every complete byte to *outptr, most
// significant bits first.  *c, *lc and *outptr are updated.
static inline void
outputBits (int nBits, uint64_t bits, uint64_t* c, int* lc, uint8_t** outptr)
{
    uint64_t acc     = (*c << nBits) | bits;
    int      pending = *lc + nBits;
    uint8_t* dst     = *outptr;

    while (pending >= 8)
    {
        pending -= 8;
        *dst++ = (uint8_t) (acc >> pending);
    }

    *c      = acc;
    *lc     = pending;
    *outptr = dst;
}
54 
// Read the next nBits bits from the input.
//
//   nBits  i : number of bits to read; must be small enough that the
//              bits fit in the 64-bit accumulator (well below 64)
//   c      io: bit accumulator
//   lc     io: number of valid (not yet consumed) bits in *c
//   inptr  io: read position; advanced by one byte per refill
//
// Returns the bits read, right-aligned.  The caller is responsible for
// making sure enough input bytes remain; no bounds check is done here.
static inline uint64_t
getBits (uint32_t nBits, uint64_t* c, uint32_t* lc, const uint8_t** inptr)
{
    const uint8_t* in = *inptr;
    while (*lc < nBits)
    {
        *c = (*c << 8) | (uint64_t) (*in++);
        *lc += 8;
    }

    *inptr = in;
    *lc -= nBits;

    // Build the mask in 64 bits: shifting a plain int by 31 or more
    // bits is undefined and would truncate wide reads.
    return (*c >> *lc) & ((((uint64_t) 1) << nBits) - 1);
}
69 
70 //
71 // ENCODING TABLE BUILDING & (UN)PACKING
72 //
73 
74 //
75 // Build a "canonical" Huffman code table:
76 //	- for each (uncompressed) symbol, hcode contains the length
77 //	  of the corresponding code (in the compressed data)
78 //	- canonical codes are computed and stored in hcode
79 //	- the rules for constructing canonical codes are as follows:
80 //	  * shorter codes (if filled with zeroes to the right)
81 //	    have a numerically higher value than longer codes
82 //	  * for codes with the same length, numerical values
83 //	    increase with numerical symbol values
84 //	- because the canonical code table can be constructed from
85 //	  symbol lengths alone, the code table can be transmitted
86 //	  without sending the actual code values
87 //	- see http://www.compressconsult.com/huffman/
88 //
89 
90 static void
hufCanonicalCodeTable(uint64_t * hcode)91 hufCanonicalCodeTable (uint64_t* hcode)
92 {
93     uint64_t n[59];
94 
95     //
96     // For each i from 0 through 58, count the
97     // number of different codes of length i, and
98     // store the count in n[i].
99     //
100 
101     for (int i = 0; i <= 58; ++i)
102         n[i] = 0;
103 
104     for (int i = 0; i < HUF_ENCSIZE; ++i)
105         n[hcode[i]] += 1;
106 
107     //
108     // For each i from 58 through 1, compute the
109     // numerically lowest code with length i, and
110     // store that code in n[i].
111     //
112 
113     uint64_t c = 0;
114 
115     for (int i = 58; i > 0; --i)
116     {
117         uint64_t nc = ((c + n[i]) >> 1);
118         n[i]        = c;
119         c           = nc;
120     }
121 
122     //
123     // hcode[i] contains the length, l, of the
124     // code for symbol i.  Assign the next available
125     // code of length l to the symbol and store both
126     // l and the code in hcode[i].
127     //
128 
129     for (int i = 0; i < HUF_ENCSIZE; ++i)
130     {
131         uint64_t l = hcode[i];
132 
133         if (l > 0) hcode[i] = l | (n[l]++ << 6);
134     }
135 }
136 
137 //
138 // Compute Huffman codes (based on frq input) and store them in frq:
139 //	- code structure is : [63:lsb - 6:msb] | [5-0: bit length];
140 //	- max code length is 58 bits;
141 //	- codes outside the range [im-iM] have a null length (unused values);
142 //	- original frequencies are destroyed;
143 //	- encoding tables are used by hufEncode() and hufBuildDecTable();
144 //
// NB: The tie-break "((*a == *b) && (a > b))" below was added to ensure
//     that elements in the heap with the same value are sorted by index.
//     This ensures that the STL make_heap()/pop_heap()/push_heap() methods
//     produce a resultant sorted heap that is identical across OSes.
//
150 
// Heap ordering predicate: returns 1 when *a is greater than *b, with
// ties broken by comparing the addresses themselves, and 0 otherwise.
static inline int
FHeapCompare (uint64_t* a, uint64_t* b)
{
    if (*a != *b) return *a > *b;

    return a > b;
}
156 
// Sift 'value' up from holeIndex towards topIndex: parents that compare
// greater than value (per FHeapCompare) are moved down into the hole,
// and value is stored where the hole ends up.
static inline void
intern_push_heap (
    uint64_t** first, size_t holeIndex, size_t topIndex, uint64_t* value)
{
    while (holeIndex > topIndex)
    {
        const size_t up = (holeIndex - 1) / 2;

        if (!FHeapCompare (first[up], value)) break;

        first[holeIndex] = first[up];
        holeIndex        = up;
    }

    first[holeIndex] = value;
}
170 
// Restore the heap property for the len-element heap at 'first' after
// the entry at holeIndex was vacated: the hole is moved down to a leaf
// by pulling up the preferred child at each level, then 'value' is
// sifted back up from there (no higher than the original hole).
static inline void
adjust_heap (uint64_t** first, size_t holeIndex, size_t len, uint64_t* value)
{
    const size_t start = holeIndex;
    size_t       hole  = holeIndex;

    // Nodes with an index below (len - 1) / 2 have two children.
    while (hole < (len - 1) / 2)
    {
        const size_t right = 2 * (hole + 1);
        const size_t left  = right - 1;
        const size_t pick =
            FHeapCompare (first[right], first[left]) ? left : right;

        first[hole] = first[pick];
        hole        = pick;
    }

    // With an even element count the last parent has a left child only.
    if ((len & 1) == 0 && hole == (len - 2) / 2)
    {
        const size_t only = 2 * (hole + 1) - 1;

        first[hole] = first[only];
        hole        = only;
    }

    intern_push_heap (first, hole, start, value);
}
195 
// Insert the element at last[-1] into the heap occupying [first, last - 1).
static inline void
push_heap (uint64_t** first, uint64_t** last)
{
    const size_t tail = (size_t) (last - first) - 1;

    intern_push_heap (first, tail, 0, first[tail]);
}
202 
// Store the heap's top element in *result and re-insert the element that
// previously lived in *result into the heap occupying [first, last).
static inline void
intern_pop_heap (uint64_t** first, uint64_t** last, uint64_t** result)
{
    uint64_t* const displaced = *result;
    const size_t    count     = (size_t) (last - first);

    *result = first[0];
    adjust_heap (first, 0, count, displaced);
}
210 
// Move the top element of the heap [first, last) to last[-1] and leave
// a valid heap in [first, last - 1).  Heaps with fewer than two
// elements are left untouched.
static inline void
pop_heap (uint64_t** first, uint64_t** last)
{
    if (last - first <= 1) return;

    uint64_t** const tail = last - 1;

    intern_pop_heap (first, tail, tail);
}
220 
// Arrange the len pointers at 'first' into a heap by sifting every
// parent node, from the last one back to the root.
static void
make_heap (uint64_t** first, uint64_t len)
{
    if (len < 2) return;

    for (size_t parent = (len - 2) / 2;; --parent)
    {
        adjust_heap (first, parent, len, first[parent]);

        if (parent == 0) break;
    }
}
237 
static void
hufBuildEncTable (
    uint64_t*  frq,   // io: symbol frequencies in, packed codes out [HUF_ENCSIZE]
    uint32_t*  im,    //  o: min index with a non-zero frequency
    uint32_t*  iM,    //  o: index of the run-length pseudo-symbol
    uint32_t*  hlink, // scratch: list links between symbols [HUF_ENCSIZE]
    uint64_t** fHeap, // scratch: heap of pointers into frq [HUF_ENCSIZE]
    uint64_t*  scode) // scratch: code lengths, then canonical codes [HUF_ENCSIZE]
{
    //
    // This function assumes that when it is called, array frq
    // indicates the frequency of all possible symbols in the data
    // that are to be Huffman-encoded.  (frq[i] contains the number
    // of occurrences of symbol i in the data.)
    //
    // The loop below does three things:
    //
    // 1) Finds the minimum and maximum indices that point
    //    to non-zero entries in frq:
    //
    //     frq[im] != 0, and frq[i] == 0 for all i < im
    //     frq[iM] != 0, and frq[i] == 0 for all i > iM
    //
    // 2) Fills array fHeap with pointers to all non-zero
    //    entries in frq.
    //
    // 3) Initializes array hlink such that hlink[i] == i
    //    for all array entries.
    //

    *im = 0;

    // NOTE: this scan has no upper bound; it relies on the caller
    // passing at least one non-zero frequency.
    while (!frq[*im])
        (*im)++;

    uint32_t nf = 0;

    for (uint32_t i = *im; i < HUF_ENCSIZE; i++)
    {
        hlink[i] = i;

        if (frq[i])
        {
            fHeap[nf] = &frq[i];
            ++nf;
            *iM = i;
        }
    }

    //
    // Add a pseudo-symbol, with a frequency count of 1, to frq;
    // adjust the fHeap and hlink array accordingly.  Function
    // hufEncode() uses the pseudo-symbol for run-length encoding.
    //

    (*iM)++;
    frq[*iM]  = 1;
    fHeap[nf] = &frq[*iM];
    ++nf;

    //
    // Build an array, scode, such that scode[i] contains the number
    // of bits assigned to symbol i.  Conceptually this is done by
    // constructing a tree whose leaves are the symbols with non-zero
    // frequency:
    //
    //     Make a heap that contains all symbols with a non-zero frequency,
    //     with the least frequent symbol on top.
    //
    //     Repeat until only one symbol is left on the heap:
    //
    //         Take the two least frequent symbols off the top of the heap.
    //         Create a new node that has first two nodes as children, and
    //         whose frequency is the sum of the frequencies of the first
    //         two nodes.  Put the new node back into the heap.
    //
    // The last node left on the heap is the root of the tree.  For each
    // leaf node, the distance between the root and the leaf is the length
    // of the code for the corresponding symbol.
    //
    // The loop below doesn't actually build the tree; instead we compute
    // the distances of the leaves from the root on the fly.  When a new
    // node is added to the heap, then that node's descendants are linked
    // into a single linear list that starts at the new node, and the code
    // lengths of the descendants (that is, their distance from the root
    // of the tree) are incremented by one.
    //

    make_heap (fHeap, nf);

    memset (scode, 0, sizeof (uint64_t) * HUF_ENCSIZE);

    while (nf > 1)
    {
        //
        // Find the indices, mm and m, of the two smallest non-zero frq
        // values in fHeap, add the smallest frq to the second-smallest
        // frq, and remove the smallest frq value from fHeap.
        //

        uint32_t mm = (uint32_t) (fHeap[0] - frq);
        pop_heap (&fHeap[0], &fHeap[nf]);
        --nf;

        // This second pop moves the pointer to frq[m] into fHeap[nf - 1]
        // without shrinking nf; frq[m] is then updated in place and
        // push_heap() re-inserts that same slot with the merged frequency.
        uint32_t m = (uint32_t) (fHeap[0] - frq);
        pop_heap (&fHeap[0], &fHeap[nf]);

        frq[m] += frq[mm];
        push_heap (&fHeap[0], &fHeap[nf]);

        //
        // The entries in scode are linked into lists with the
        // entries in hlink serving as "next" pointers and with
        // the end of a list marked by hlink[j] == j.
        //
        // Traverse the lists that start at scode[m] and scode[mm].
        // For each element visited, increment the length of the
        // corresponding code by one bit. (If we visit scode[j]
        // during the traversal, then the code for symbol j becomes
        // one bit longer.)
        //
        // Merge the lists that start at scode[m] and scode[mm]
        // into a single list that starts at scode[m].
        //

        //
        // Add a bit to all codes in the first list.
        //

        for (uint32_t j = m;; j = hlink[j])
        {
            scode[j]++;

            if (hlink[j] == j)
            {
                //
                // Merge the two lists.
                //

                hlink[j] = mm;
                break;
            }
        }

        //
        // Add a bit to all codes in the second list
        //

        for (uint32_t j = mm;; j = hlink[j])
        {
            scode[j]++;

            if (hlink[j] == j) break;
        }
    }

    //
    // Build a canonical Huffman code table, replacing the code
    // lengths in scode with (code, code length) pairs.  Copy the
    // code table from scode into frq.
    //

    hufCanonicalCodeTable (scode);
    memcpy (frq, scode, sizeof (uint64_t) * HUF_ENCSIZE);
}
403 
404 //
405 // Pack an encoding table:
406 //	- only code lengths, not actual codes, are stored
407 //	- runs of zeroes are compressed as follows:
408 //
409 //	  unpacked		packed
410 //	  --------------------------------
411 //	  1 zero		0	(6 bits)
412 //	  2 zeroes		59
413 //	  3 zeroes		60
414 //	  4 zeroes		61
415 //	  5 zeroes		62
416 //	  n zeroes (6 or more)	63 n-6	(6 + 8 bits)
417 //
418 
// Packed 6-bit value for the shortest compressed run (2 zeroes); the
// values up to LONG_ZEROCODE_RUN - 1 stand for runs of 3, 4 and 5 zeroes.
#define SHORT_ZEROCODE_RUN 59
// Packed 6-bit value that introduces a long run; an 8-bit count follows.
#define LONG_ZEROCODE_RUN 63
// Shortest run written with LONG_ZEROCODE_RUN: 2 + 63 - 59 = 6 zeroes.
#define SHORTEST_LONG_RUN (2 + LONG_ZEROCODE_RUN - SHORT_ZEROCODE_RUN)
// Longest run one long code can describe: 255 + 6 = 261 zeroes.
#define LONGEST_LONG_RUN (255 + SHORTEST_LONG_RUN)
423 
424 static void
hufPackEncTable(const uint64_t * hcode,uint32_t im,uint32_t iM,uint8_t ** pcode)425 hufPackEncTable (
426     const uint64_t* hcode, // i : encoding table [HUF_ENCSIZE]
427     uint32_t        im,    // i : min hcode index
428     uint32_t        iM,    // i : max hcode index
429     uint8_t**       pcode)       //  o: ptr to packed table (updated)
430 {
431     uint8_t* out = *pcode;
432     uint64_t c   = 0;
433     int      lc  = 0;
434 
435     for (; im <= iM; im++)
436     {
437         int l = hufLength (hcode[im]);
438 
439         if (l == 0)
440         {
441             uint64_t zerun = 1;
442 
443             while ((im < iM) && (zerun < LONGEST_LONG_RUN))
444             {
445                 if (hufLength (hcode[im + 1]) > 0) break;
446                 im++;
447                 zerun++;
448             }
449 
450             if (zerun >= 2)
451             {
452                 if (zerun >= SHORTEST_LONG_RUN)
453                 {
454                     outputBits (6, LONG_ZEROCODE_RUN, &c, &lc, &out);
455                     outputBits (8, zerun - SHORTEST_LONG_RUN, &c, &lc, &out);
456                 }
457                 else
458                 {
459                     outputBits (
460                         6, SHORT_ZEROCODE_RUN + zerun - 2, &c, &lc, &out);
461                 }
462                 continue;
463             }
464         }
465 
466         outputBits (6, (uint64_t) l, &c, &lc, &out);
467     }
468 
469     if (lc > 0) *out++ = (uint8_t) (c << (8 - lc));
470 
471     *pcode = out;
472 }
473 
474 //
475 // Unpack an encoding table packed by hufPackEncTable():
476 //
477 
478 static exr_result_t
hufUnpackEncTable(const uint8_t ** pcode,uint64_t * nLeft,uint32_t im,uint32_t iM,uint64_t * hcode)479 hufUnpackEncTable (
480     const uint8_t** pcode, // io: ptr to packed table (updated)
481     uint64_t*       nLeft, // io: input size (in bytes), bytes left
482     uint32_t        im,    // i : min hcode index
483     uint32_t        iM,    // i : max hcode index
484     uint64_t*       hcode) // o : encoding table [HUF_ENCSIZE]
485 {
486     memset (hcode, 0, sizeof (uint64_t) * HUF_ENCSIZE);
487 
488     const uint8_t* p  = *pcode;
489     uint64_t       c  = 0;
490     uint64_t       ni = *nLeft;
491     uint64_t       nr;
492     uint32_t       lc = 0;
493 
494     for (; im <= iM; im++)
495     {
496         nr = (((uintptr_t) p) - ((uintptr_t) *pcode));
497         if (lc < 6 && nr >= ni) return EXR_ERR_OUT_OF_MEMORY;
498 
499         uint64_t l = hcode[im] = getBits (6, &c, &lc, &p); // code length
500 
501         if (l == (uint64_t) LONG_ZEROCODE_RUN)
502         {
503             nr = (((uintptr_t) p) - ((uintptr_t) *pcode));
504             if (lc < 8 && nr >= ni) return EXR_ERR_OUT_OF_MEMORY;
505 
506             uint64_t zerun = getBits (8, &c, &lc, &p) + SHORTEST_LONG_RUN;
507 
508             if (im + zerun > iM + 1) return EXR_ERR_CORRUPT_CHUNK;
509 
510             while (zerun--)
511                 hcode[im++] = 0;
512 
513             im--;
514         }
515         else if (l >= (uint64_t) SHORT_ZEROCODE_RUN)
516         {
517             uint64_t zerun = l - SHORT_ZEROCODE_RUN + 2;
518 
519             if (im + zerun > iM + 1) return EXR_ERR_CORRUPT_CHUNK;
520 
521             while (zerun--)
522                 hcode[im++] = 0;
523 
524             im--;
525         }
526     }
527 
528     nr = (((uintptr_t) p) - ((uintptr_t) *pcode));
529     *nLeft -= nr;
530     *pcode = p;
531 
532     hufCanonicalCodeTable (hcode);
533     return EXR_ERR_SUCCESS;
534 }
535 
536 //
537 // DECODING TABLE BUILDING
538 //
539 
540 //
541 // Clear a newly allocated decoding table so that it contains only zeroes.
542 //
543 
544 static void
hufClearDecTable(HufDec * hdecod)545 hufClearDecTable (HufDec* hdecod)
546 {
547     memset (hdecod, 0, sizeof (HufDec) * HUF_DECSIZE);
548 }
549 
550 //
551 // Build a decoding hash table based on the encoding table hcode:
552 //	- short codes (<= HUF_DECBITS) are resolved with a single table access;
//	- long code entry allocations are not optimized, because long codes are
//	  infrequent;
555 //	- decoding tables are used by hufDecode();
556 //
557 
558 static exr_result_t
hufBuildDecTable(const uint64_t * hcode,uint32_t im,uint32_t iM,HufDec * hdecod)559 hufBuildDecTable (
560     const uint64_t* hcode, uint32_t im, uint32_t iM, HufDec* hdecod)
561 {
562     //
563     // Init hashtable & loop on all codes.
564     // Assumes that hufClearDecTable(hdecod) has already been called.
565     //
566 
567     for (; im <= iM; im++)
568     {
569         uint64_t c = hufCode (hcode[im]);
570         int      l = hufLength (hcode[im]);
571 
572         if (c >> l)
573         {
574             //
575             // Error: c is supposed to be an l-bit code,
576             // but c contains a value that is greater
577             // than the largest l-bit number.
578             //
579 
580             return EXR_ERR_CORRUPT_CHUNK;
581         }
582 
583         if (l > HUF_DECBITS)
584         {
585             //
586             // Long code: add a secondary entry
587             //
588 
589             HufDec* pl = hdecod + (c >> (l - HUF_DECBITS));
590 
591             if (pl->len)
592             {
593                 //
594                 // Error: a short code has already
595                 // been stored in table entry *pl.
596                 //
597 
598                 return EXR_ERR_CORRUPT_CHUNK;
599             }
600 
601             pl->lit++;
602 
603             if (pl->p)
604             {
605                 uint32_t* p = pl->p;
606                 pl->p       = (uint32_t*) internal_exr_alloc (
607                     sizeof (uint32_t) * pl->lit);
608 
609                 if (pl->p)
610                 {
611                     for (uint32_t i = 0; i < pl->lit - 1; ++i)
612                         pl->p[i] = p[i];
613                 }
614 
615                 internal_exr_free (p);
616             }
617             else
618             {
619                 pl->p = (uint32_t*) internal_exr_alloc (sizeof (uint32_t));
620             }
621 
622             if (!pl->p) return EXR_ERR_OUT_OF_MEMORY;
623 
624             pl->p[pl->lit - 1] = im;
625         }
626         else if (l)
627         {
628             //
629             // Short code: init all primary entries
630             //
631 
632             HufDec* pl = hdecod + (c << (HUF_DECBITS - l));
633 
634             for (uint64_t i = 1 << (HUF_DECBITS - l); i > 0; i--, pl++)
635             {
636                 if (pl->len || pl->p)
637                 {
638                     //
639                     // Error: a short code or a long code has
640                     // already been stored in table entry *pl.
641                     //
642 
643                     return EXR_ERR_CORRUPT_CHUNK;
644                 }
645 
646                 pl->len = (int32_t) l;
647                 pl->lit = im;
648             }
649         }
650     }
651     return EXR_ERR_SUCCESS;
652 }
653 
654 //
655 // Free the long code entries of a decoding table built by hufBuildDecTable()
656 //
657 
658 static void
hufFreeDecTable(HufDec * hdecod)659 hufFreeDecTable (HufDec* hdecod) // io: Decoding table
660 {
661     for (int i = 0; i < HUF_DECSIZE; i++)
662     {
663         if (hdecod[i].p)
664         {
665             internal_exr_free (hdecod[i].p);
666             hdecod[i].p = NULL;
667         }
668     }
669 }
670 
671 //
672 // ENCODING
673 //
674 
// Write the code held in a packed (code << 6 | length) table entry to
// the output bit stream.
static inline void
outputCode (uint64_t code, uint64_t* c, int* lc, uint8_t** out)
{
    const int      nBits = hufLength (code);
    const uint64_t bits  = hufCode (code);

    outputBits (nBits, bits, c, lc, out);
}
680 
// Emit a symbol that occurred runCount + 1 times in a row, either as
// "symbol, run-length code, 8-bit count" or as plain repetitions of the
// symbol's code, whichever the length comparison below favours.
static inline void
sendCode (
    uint64_t  sCode,
    int       runCount,
    uint64_t  runCode,
    uint64_t* c,
    int*      lc,
    uint8_t** out)
{
    const int symBits = hufLength (sCode);
    const int rleBits = symBits + hufLength (runCode) + 8;

    if (rleBits < symBits * runCount)
    {
        outputCode (sCode, c, lc, out);
        outputCode (runCode, c, lc, out);
        outputBits (8, (uint64_t) runCount, c, lc, out);
        return;
    }

    // Not worth a run: write the symbol itself runCount + 1 times.
    for (int k = 0; k <= runCount; ++k)
        outputCode (sCode, c, lc, out);
}
703 
704 //
705 // Encode (compress) ni values based on the Huffman encoding table hcode:
706 //
707 
// Huffman-encode the ni values at 'in' into 'out' using table hcode;
// rlc is the index of the run-length code.  Reads in[0] unconditionally,
// so ni must be at least 1.  Returns the number of bits written.
static inline uint64_t
hufEncode (
    const uint64_t* hcode,
    const uint16_t* in,
    const uint64_t  ni,
    uint32_t        rlc,
    uint8_t*        out)
{
    uint8_t* const outStart = out;
    const uint64_t runCode  = hcode[rlc];
    uint64_t       acc      = 0; // bits not yet written to out
    int            accBits  = 0; // number of valid bits in acc (LSB)
    uint16_t       cur      = in[0];
    int            repeats  = 0; // extra occurrences of cur seen so far

    for (uint64_t i = 1; i < ni; i++)
    {
        const uint16_t next = in[i];

        //
        // Either extend the current run (capped at 255 repeats)
        // or flush it and start a new one.
        //

        if (next != cur || repeats >= 255)
        {
            sendCode (hcode[cur], repeats, runCode, &acc, &accBits, &out);
            repeats = 0;
        }
        else
        {
            repeats++;
        }

        cur = next;
    }

    //
    // Flush the final run and any partial byte.
    //

    sendCode (hcode[cur], repeats, runCode, &acc, &accBits, &out);

    if (accBits) *out = (acc << (8 - accBits)) & 0xff;

    return (((uintptr_t) out) - ((uintptr_t) outStart)) * 8 +
           (uint64_t) (accBits);
}
752 
753 //
754 // DECODING
755 //
756 
757 //
758 // In order to force the compiler to inline them,
759 // getChar() and getCode() are implemented as macros
760 // instead of "inline" functions.
761 //
762 
//
// Shift the next input byte into the bit accumulator c, advance the
// read pointer in, and add 8 to the bit count lc.  No bounds check is
// done here; the caller must make sure a byte is available.
//
// The body is wrapped in do { } while (0) so that the macro expands to
// a single statement and stays correct inside an unbraced if/else.
//
#define getChar(c, lc, in)                                                     \
    do                                                                         \
    {                                                                          \
        (c) = ((c) << 8) | (uint64_t) (*(in)++);                               \
        (lc) += 8;                                                             \
    } while (0)
766 
//
// Store decoded symbol po in the output, expanding run-length codes.
//
// If po equals rlc, an 8-bit repeat count is taken from the bit stream
// (refilling c from in if fewer than 8 bits are buffered) and the most
// recently written output value is repeated that many times.  Any other
// symbol is appended to the output as is.
//
// The macro returns from the ENCLOSING function with
// EXR_ERR_OUT_OF_MEMORY when the input is exhausted or a run has no
// preceding output value, and with EXR_ERR_CORRUPT_CHUNK when the
// output would overflow [ob, oe).
//
// The range checks compare distances and pointers that are already in
// range, so no pointer outside [ob, oe] is ever formed.
//
#define getCode(po, rlc, c, lc, in, ie, out, ob, oe)                           \
    do                                                                         \
    {                                                                          \
        if ((po) == (rlc))                                                     \
        {                                                                      \
            if ((lc) < 8)                                                      \
            {                                                                  \
                if ((in) >= (ie)) return EXR_ERR_OUT_OF_MEMORY;                \
                getChar (c, lc, in);                                           \
            }                                                                  \
                                                                               \
            (lc) -= 8;                                                         \
                                                                               \
            uint8_t cs = (uint8_t) ((c) >> (lc));                              \
                                                                               \
            if ((oe) - (out) < cs)                                             \
                return EXR_ERR_CORRUPT_CHUNK;                                  \
            else if ((out) <= (ob))                                            \
                return EXR_ERR_OUT_OF_MEMORY;                                  \
                                                                               \
            uint16_t s = (out)[-1];                                            \
                                                                               \
            while (cs-- > 0)                                                   \
                *(out)++ = s;                                                  \
        }                                                                      \
        else if ((out) < (oe))                                                 \
        {                                                                      \
            *(out)++ = (uint16_t) (po);                                        \
        }                                                                      \
        else                                                                   \
        {                                                                      \
            return EXR_ERR_CORRUPT_CHUNK;                                      \
        }                                                                      \
    } while (0)
801 
802 //
803 // Decode (uncompress) ni bits based on encoding & decoding tables:
804 //
805 
// Returns EXR_ERR_SUCCESS when exactly 'no' items were produced.  Note
// that the getCode() macro used below can return from this function
// directly with an error code.
static exr_result_t
hufDecode (
    const uint64_t* hcode,  // i : encoding table
    const HufDec*   hdecod, // i : decoding table
    const uint8_t*  in,     // i : compressed input buffer
    uint64_t        ni,     // i : input size (in bits)
    uint32_t        rlc,    // i : run-length code
    uint64_t        no,     // i : expected output size (count of uint16 items)
    uint16_t*       out)    //  o: uncompressed output buffer
{
    uint64_t       c    = 0; // bits read from in but not yet decoded
    int            lc   = 0; // number of valid bits in c
    uint16_t*      outb = out; // start of the output buffer
    uint16_t*      oe   = out + no; // one past the last output item
    const uint8_t* ie   = in + (ni + 7) / 8; // input byte size

    //
    // Loop on input bytes
    //

    while (in < ie)
    {
        getChar (c, lc, in);

        //
        // Access decoding table
        //

        while (lc >= HUF_DECBITS)
        {
            // Index the table with the next HUF_DECBITS undecoded bits.
            uint64_t      decoffset = (c >> (lc - HUF_DECBITS)) & HUF_DECMASK;
            const HufDec* pl        = hdecod + decoffset;

            if (pl->len)
            {
                //
                // Get short code
                //

                if (pl->len > lc) return EXR_ERR_CORRUPT_CHUNK;

                lc -= pl->len;
                getCode (pl->lit, rlc, c, lc, in, ie, out, outb, oe);
            }
            else
            {
                uint32_t        j;
                const uint32_t* decbuf = pl->p;
                if (!pl->p) return EXR_ERR_CORRUPT_CHUNK; // wrong code

                //
                // Search long code
                //

                // Try each candidate symbol sharing this table entry and
                // compare its full code against the buffered bits.
                for (j = 0; j < pl->lit; j++)
                {
                    int l = hufLength (hcode[decbuf[j]]);

                    while (lc < l && in < ie) // get more bits
                    {
                        getChar (c, lc, in);
                    }

                    if (lc >= l)
                    {
                        if (hufCode (hcode[decbuf[j]]) ==
                            ((c >> (lc - l)) & (((uint64_t) (1) << l) - 1)))
                        {
                            //
                            // Found : get long code
                            //

                            lc -= l;
                            getCode (
                                decbuf[j], rlc, c, lc, in, ie, out, outb, oe);
                            break;
                        }
                    }
                }

                // No candidate matched the buffered bits.
                if (j == pl->lit) return EXR_ERR_CORRUPT_CHUNK;
            }
        }
    }

    //
    // Get remaining (short) codes
    //

    // Drop the padding bits that fill up the last input byte.
    uint64_t i = (8 - ni) & 7;
    c >>= i;
    lc -= i;

    while (lc > 0)
    {
        // Left-align the remaining bits to form a table index.
        uint64_t      decoffset = (c << (HUF_DECBITS - lc)) & HUF_DECMASK;
        const HufDec* pl        = hdecod + decoffset;

        if (pl->len)
        {
            if (pl->len > lc) return EXR_ERR_CORRUPT_CHUNK;
            lc -= pl->len;
            getCode (pl->lit, rlc, c, lc, in, ie, out, outb, oe);
        }
        else
            return EXR_ERR_CORRUPT_CHUNK;
    }

    // The stream must have produced exactly 'no' output items.
    if (out != oe) return EXR_ERR_OUT_OF_MEMORY;
    return EXR_ERR_SUCCESS;
}
917 
918 static inline void
countFrequencies(uint64_t * freq,const uint16_t * data,uint64_t n)919 countFrequencies (uint64_t* freq, const uint16_t* data, uint64_t n)
920 {
921     memset (freq, 0, HUF_ENCSIZE * sizeof (uint64_t));
922     for (uint64_t i = 0; i < n; ++i)
923         ++freq[data[i]];
924 }
925 
// Store a 32-bit value at b in little-endian byte order.
static inline void
writeUInt (uint8_t* b, uint32_t i)
{
    for (int k = 0; k < 4; ++k)
        b[k] = (uint8_t) (i >> (8 * k));
}
934 
// Load a 32-bit value stored at b in little-endian byte order.
static inline uint32_t
readUInt (const uint8_t* b)
{
    uint32_t value = 0;

    // Start with the most significant byte (b[3]) and shift it upwards.
    for (int k = 3; k >= 0; --k)
        value = (value << 8) | (uint32_t) b[k];

    return value;
}
942 
943 /**************************************/
944 
945 uint64_t
internal_exr_huf_compress_spare_bytes(void)946 internal_exr_huf_compress_spare_bytes (void)
947 {
948     uint64_t ret = 0;
949     ret += HUF_ENCSIZE * sizeof (uint64_t);  // freq
950     ret += HUF_ENCSIZE * sizeof (int);       // hlink
951     ret += HUF_ENCSIZE * sizeof (uint64_t*); // fheap
952     ret += HUF_ENCSIZE * sizeof (uint64_t);  // scode
953     return ret;
954 }
955 
956 uint64_t
internal_exr_huf_decompress_spare_bytes(void)957 internal_exr_huf_decompress_spare_bytes (void)
958 {
959     uint64_t ret = 0;
960     ret += HUF_ENCSIZE * sizeof (uint64_t); // freq
961     ret += HUF_DECSIZE * sizeof (HufDec);   // hdec
962     //    ret += HUF_ENCSIZE * sizeof (uint64_t*); // fheap
963     //    ret += HUF_ENCSIZE * sizeof (uint64_t);  // scode
964     return ret;
965 }
966 
967 exr_result_t
internal_huf_compress(uint64_t * encbytes,void * out,uint64_t outsz,const uint16_t * raw,uint64_t nRaw,void * spare,uint64_t sparebytes)968 internal_huf_compress (
969     uint64_t*       encbytes,
970     void*           out,
971     uint64_t        outsz,
972     const uint16_t* raw,
973     uint64_t        nRaw,
974     void*           spare,
975     uint64_t        sparebytes)
976 {
977     uint64_t*  freq;
978     uint32_t*  hlink;
979     uint64_t** fHeap;
980     uint64_t*  scode;
981     uint32_t   im = 0;
982     uint32_t   iM = 0;
983     uint32_t   tableLength, nBits, dataLength;
984     uint8_t*   dataStart;
985     uint8_t*   compressed = (uint8_t*) out;
986     uint8_t*   tableStart = compressed + 20;
987     uint8_t*   tableEnd   = tableStart;
988 
989     if (nRaw == 0)
990     {
991         *encbytes = 0;
992         return EXR_ERR_SUCCESS;
993     }
994 
995     (void) outsz;
996     if (sparebytes != internal_exr_huf_compress_spare_bytes ())
997         return EXR_ERR_INVALID_ARGUMENT;
998 
999     freq  = (uint64_t*) spare;
1000     scode = freq + HUF_ENCSIZE;
1001     fHeap = (uint64_t**) (scode + HUF_ENCSIZE);
1002     hlink = (uint32_t*) (fHeap + HUF_ENCSIZE);
1003 
1004     countFrequencies (freq, raw, nRaw);
1005 
1006     hufBuildEncTable (freq, &im, &iM, hlink, fHeap, scode);
1007 
1008     hufPackEncTable (freq, im, iM, &tableEnd);
1009 
1010     tableLength =
1011         (uint32_t) (((uintptr_t) tableEnd) - ((uintptr_t) tableStart));
1012     dataStart = tableEnd;
1013 
1014     nBits      = (uint32_t) hufEncode (freq, raw, nRaw, iM, dataStart);
1015     dataLength = (nBits + 7) / 8;
1016 
1017     writeUInt (compressed, im);
1018     writeUInt (compressed + 4, iM);
1019     writeUInt (compressed + 8, tableLength);
1020     writeUInt (compressed + 12, nBits);
1021     writeUInt (compressed + 16, 0); // room for future extensions
1022 
1023     *encbytes =
1024         (((uintptr_t) dataStart) + ((uintptr_t) dataLength) -
1025          ((uintptr_t) compressed));
1026     return EXR_ERR_SUCCESS;
1027 }
1028 
1029 exr_result_t
internal_huf_decompress(const uint8_t * compressed,uint64_t nCompressed,uint16_t * raw,uint64_t nRaw,void * spare,uint64_t sparebytes)1030 internal_huf_decompress (
1031     const uint8_t* compressed,
1032     uint64_t       nCompressed,
1033     uint16_t*      raw,
1034     uint64_t       nRaw,
1035     void*          spare,
1036     uint64_t       sparebytes)
1037 {
1038     uint32_t       im, iM, nBits;
1039     uint64_t       nBytes;
1040     const uint8_t* ptr;
1041     exr_result_t   rv;
1042 
1043     //
1044     // need at least 20 bytes for header
1045     //
1046     if (nCompressed < 20)
1047     {
1048         if (nRaw != 0) return EXR_ERR_INVALID_ARGUMENT;
1049         return EXR_ERR_SUCCESS;
1050     }
1051 
1052     if (sparebytes != internal_exr_huf_decompress_spare_bytes ())
1053         return EXR_ERR_INVALID_ARGUMENT;
1054 
1055     im = readUInt (compressed);
1056     iM = readUInt (compressed + 4);
1057     // uint32_t tableLength = readUInt (compressed + 8);
1058     nBits = readUInt (compressed + 12);
1059     // uint32_t future = readUInt (compressed + 16);
1060 
1061     if (im >= HUF_ENCSIZE || iM >= HUF_ENCSIZE) return EXR_ERR_CORRUPT_CHUNK;
1062 
1063     ptr = compressed + 20;
1064 
1065     nBytes = (((uint64_t) (nBits) + 7)) / 8;
1066     if (ptr + nBytes > compressed + nCompressed) return EXR_ERR_OUT_OF_MEMORY;
1067 
1068     //
1069     // Fast decoder needs at least 2x64-bits of compressed data, and
1070     // needs to be run-able on this platform. Otherwise, fall back
1071     // to the original decoder
1072     //
1073 #if 0
1074     if (FastHufDecoder::enabled () && nBits > 128)
1075     {
1076         FastHufDecoder fhd (ptr, nCompressed - (ptr - compressed), im, iM, iM);
1077 
1078         // must be nBytes remaining in buffer
1079         if (ptr - compressed + nBytes > static_cast<uint64_t> (nCompressed))
1080         {
1081             notEnoughData ();
1082             return;
1083         }
1084 
1085         rv = fhd.decode (ptr, nBits, raw, nRaw);
1086     }
1087     else
1088 #endif
1089     {
1090         uint64_t* freq     = (uint64_t*) spare;
1091         HufDec*   hdec     = (HufDec*) (freq + HUF_ENCSIZE);
1092         uint64_t  nLeft    = nCompressed - 20;
1093         uint64_t  nTableSz = 0;
1094 
1095         hufClearDecTable (hdec);
1096         hufUnpackEncTable (&ptr, &nLeft, im, iM, freq);
1097 
1098         if (nBits > 8 * nLeft) return EXR_ERR_CORRUPT_CHUNK;
1099 
1100         rv = hufBuildDecTable (freq, im, iM, hdec);
1101         if (rv == EXR_ERR_SUCCESS)
1102             rv = hufDecode (freq, hdec, ptr, nBits, iM, nRaw, raw);
1103 
1104         hufFreeDecTable (hdec);
1105     }
1106     return rv;
1107 }
1108