1 /* rans_byte.h originally from https://github.com/rygorous/ryg_rans
2  *
3  * This is a public-domain implementation of several rANS variants. rANS is an
4  * entropy coder from the ANS family, as described in Jarek Duda's paper
5  * "Asymmetric numeral systems" (http://arxiv.org/abs/1311.2540).
6  */
7 
8 /*-------------------------------------------------------------------------- */
9 
10 // Simple byte-aligned rANS encoder/decoder - public domain - Fabian 'ryg' Giesen 2014
11 //
12 // Not intended to be "industrial strength"; just meant to illustrate the general
13 // idea.
14 
15 #ifndef RANS_BYTE_HEADER
16 #define RANS_BYTE_HEADER
17 
18 #include <stdint.h>
19 
// Assertion hook used throughout this header.
//
// If the `assert` macro is already defined at the point where this header is
// included (i.e. <assert.h> was included first), RansAssert forwards to it.
// Otherwise RansAssert(x) expands to a void expression that does not evaluate
// its argument -- the same shape the standard assert has under NDEBUG -- so it
// stays valid wherever an expression is expected, not only as a statement.
#ifdef assert
#define RansAssert assert
#else
#define RansAssert(x) ((void)0)
#endif
25 
26 // READ ME FIRST:
27 //
// This is designed like a typical arithmetic coder API, but there are three
// twists you absolutely should be aware of before you start hacking:
30 //
31 // 1. You need to encode data in *reverse* - last symbol first. rANS works
32 //    like a stack: last in, first out.
33 // 2. Likewise, the encoder outputs bytes *in reverse* - that is, you give
34 //    it a pointer to the *end* of your buffer (exclusive), and it will
35 //    slowly move towards the beginning as more bytes are emitted.
36 // 3. Unlike basically any other entropy coder implementation you might
37 //    have used, you can interleave data from multiple independent rANS
38 //    encoders into the same bytestream without any extra signaling;
39 //    you can also just write some bytes by yourself in the middle if
40 //    you want to. This is in addition to the usual arithmetic encoder
41 //    property of being able to switch models on the fly. Writing raw
42 //    bytes can be useful when you have some data that you know is
43 //    incompressible, and is cheaper than going through the rANS encode
44 //    function. Using multiple rANS coders on the same byte stream wastes
45 //    a few bytes compared to using just one, but execution of two
46 //    independent encoders can happen in parallel on superscalar and
47 //    Out-of-Order CPUs, so this can be *much* faster in tight decoding
48 //    loops.
49 //
50 //    This is why all the rANS functions take the write pointer as an
51 //    argument instead of just storing it in some context struct.
52 
53 // --------------------------------------------------------------------------
54 
// Lower bound of the normalization interval ('l' in the paper).
// Combined with byte-wise emission, the coder state occupies 31 bits rather
// than the full 32. That is deliberate: an exact reciprocal of a 31-bit
// unsigned value fits in a 32-bit unsigned value, which permits some
// optimizations during encoding.
#define RANS_BYTE_L 0x00800000u  // == 1u << 23

// The complete state of a rANS coder is this one 32-bit word.
typedef uint32_t RansState;
63 
64 // Initialize a rANS encoder.
RansEncInit(RansState * r)65 static inline void RansEncInit(RansState* r)
66 {
67     *r = RANS_BYTE_L;
68 }
69 
70 // Renormalize the encoder. Internal function.
RansEncRenorm(RansState x,uint8_t ** pptr,uint32_t freq,uint32_t scale_bits)71 static inline RansState RansEncRenorm(RansState x, uint8_t** pptr, uint32_t freq, uint32_t scale_bits)
72 {
73     uint32_t x_max = ((RANS_BYTE_L >> scale_bits) << 8) * freq; // this turns into a shift.
74     if (x >= x_max) {
75         uint8_t* ptr = *pptr;
76         do {
77             *--ptr = (uint8_t) (x & 0xff);
78             x >>= 8;
79         } while (x >= x_max);
80         *pptr = ptr;
81     }
82     return x;
83 }
84 
85 // Encodes a single symbol with range start "start" and frequency "freq".
86 // All frequencies are assumed to sum to "1 << scale_bits", and the
87 // resulting bytes get written to ptr (which is updated).
88 //
89 // NOTE: With rANS, you need to encode symbols in *reverse order*, i.e. from
90 // beginning to end! Likewise, the output bytestream is written *backwards*:
91 // ptr starts pointing at the end of the output buffer and keeps decrementing.
RansEncPut(RansState * r,uint8_t ** pptr,uint32_t start,uint32_t freq,uint32_t scale_bits)92 static inline void RansEncPut(RansState* r, uint8_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
93 {
94     // renormalize
95     RansState x = RansEncRenorm(*r, pptr, freq, scale_bits);
96 
97     // x = C(s,x)
98     *r = ((x / freq) << scale_bits) + (x % freq) + start;
99 }
100 
101 // Flushes the rANS encoder.
RansEncFlush(RansState * r,uint8_t ** pptr)102 static inline void RansEncFlush(RansState* r, uint8_t** pptr)
103 {
104     uint32_t x = *r;
105     uint8_t* ptr = *pptr;
106 
107     ptr -= 4;
108     ptr[0] = (uint8_t) (x >> 0);
109     ptr[1] = (uint8_t) (x >> 8);
110     ptr[2] = (uint8_t) (x >> 16);
111     ptr[3] = (uint8_t) (x >> 24);
112 
113     *pptr = ptr;
114 }
115 
116 // Initializes a rANS decoder.
117 // Unlike the encoder, the decoder works forwards as you'd expect.
RansDecInit(RansState * r,uint8_t ** pptr)118 static inline void RansDecInit(RansState* r, uint8_t** pptr)
119 {
120     uint32_t x;
121     uint8_t* ptr = *pptr;
122 
123     x  = ptr[0] << 0;
124     x |= ptr[1] << 8;
125     x |= ptr[2] << 16;
126     x |= ptr[3] << 24;
127     ptr += 4;
128 
129     *pptr = ptr;
130     *r = x;
131 }
132 
133 // Returns the current cumulative frequency (map it to a symbol yourself!)
RansDecGet(RansState * r,uint32_t scale_bits)134 static inline uint32_t RansDecGet(RansState* r, uint32_t scale_bits)
135 {
136     return *r & ((1u << scale_bits) - 1);
137 }
138 
139 // Advances in the bit stream by "popping" a single symbol with range start
140 // "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits",
141 // and the resulting bytes get written to ptr (which is updated).
RansDecAdvance(RansState * r,uint8_t ** pptr,uint32_t start,uint32_t freq,uint32_t scale_bits)142 static inline void RansDecAdvance(RansState* r, uint8_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
143 {
144     uint32_t mask = (1u << scale_bits) - 1;
145 
146     // s, x = D(x)
147     uint32_t x = *r;
148     x = freq * (x >> scale_bits) + (x & mask) - start;
149 
150     // renormalize
151     if (x < RANS_BYTE_L) {
152         uint8_t* ptr = *pptr;
153         do x = (x << 8) | *ptr++; while (x < RANS_BYTE_L);
154         *pptr = ptr;
155     }
156 
157     *r = x;
158 }
159 
160 // --------------------------------------------------------------------------
161 
162 // That's all you need for a full encoder; below here are some utility
163 // functions with extra convenience or optimizations.
164 
// Precomputed per-symbol parameters for the fast encoder.
// The (admittedly odd) choice of fields exists to make RansEncPutSymbol as
// cheap as possible.
typedef struct {
    // (Exclusive) upper bound of the pre-normalization interval.
    uint32_t x_max;
    // Fixed-point reciprocal of the frequency.
    uint32_t rcp_freq;
    // Bias added to the state when encoding.
    uint32_t bias;
    // Complement of the frequency: (1 << scale_bits) - freq.
    uint16_t cmpl_freq;
    // Shift applied to the reciprocal product.
    uint16_t rcp_shift;
} RansEncSymbol;
175 
// Per-symbol parameters for the decoder; nothing fancy here.
typedef struct {
    // Start of the symbol's range.
    uint16_t start;
    // Frequency of the symbol.
    uint16_t freq;
} RansDecSymbol;
181 
182 // Initializes an encoder symbol to start "start" and frequency "freq"
RansEncSymbolInit(RansEncSymbol * s,uint32_t start,uint32_t freq,uint32_t scale_bits)183 static inline void RansEncSymbolInit(RansEncSymbol* s, uint32_t start, uint32_t freq, uint32_t scale_bits)
184 {
185     RansAssert(scale_bits <= 16);
186     RansAssert(start <= (1u << scale_bits));
187     RansAssert(freq <= (1u << scale_bits) - start);
188 
189     // Say M := 1 << scale_bits.
190     //
191     // The original encoder does:
192     //   x_new = (x/freq)*M + start + (x%freq)
193     //
194     // The fast encoder does (schematically):
195     //   q     = mul_hi(x, rcp_freq) >> rcp_shift   (division)
196     //   r     = x - q*freq                         (remainder)
197     //   x_new = q*M + bias + r                     (new x)
198     // plugging in r into x_new yields:
199     //   x_new = bias + x + q*(M - freq)
200     //        =: bias + x + q*cmpl_freq             (*)
201     //
202     // and we can just precompute cmpl_freq. Now we just need to
203     // set up our parameters such that the original encoder and
204     // the fast encoder agree.
205 
206     s->x_max = ((RANS_BYTE_L >> scale_bits) << 8) * freq;
207     s->cmpl_freq = (uint16_t) ((1 << scale_bits) - freq);
208     if (freq < 2) {
209         // freq=0 symbols are never valid to encode, so it doesn't matter what
210         // we set our values to.
211         //
212         // freq=1 is tricky, since the reciprocal of 1 is 1; unfortunately,
213         // our fixed-point reciprocal approximation can only multiply by values
214         // smaller than 1.
215         //
216         // So we use the "next best thing": rcp_freq=0xffffffff, rcp_shift=0.
217         // This gives:
218         //   q = mul_hi(x, rcp_freq) >> rcp_shift
219         //     = mul_hi(x, (1<<32) - 1)) >> 0
220         //     = floor(x - x/(2^32))
221         //     = x - 1 if 1 <= x < 2^32
222         // and we know that x>0 (x=0 is never in a valid normalization interval).
223         //
224         // So we now need to choose the other parameters such that
225         //   x_new = x*M + start
226         // plug it in:
227         //     x*M + start                   (desired result)
228         //   = bias + x + q*cmpl_freq        (*)
229         //   = bias + x + (x - 1)*(M - 1)    (plug in q=x-1, cmpl_freq)
230         //   = bias + 1 + (x - 1)*M
231         //   = x*M + (bias + 1 - M)
232         //
233         // so we have start = bias + 1 - M, or equivalently
234         //   bias = start + M - 1.
235         s->rcp_freq = ~0u;
236         s->rcp_shift = 0;
237         s->bias = start + (1 << scale_bits) - 1;
238     } else {
239         // Alverson, "Integer Division using reciprocals"
240         // shift=ceil(log2(freq))
241         uint32_t shift = 0;
242         while (freq > (1u << shift))
243             shift++;
244 
245         s->rcp_freq = (uint32_t) (((1ull << (shift + 31)) + freq-1) / freq);
246         s->rcp_shift = shift - 1;
247 
248         // With these values, 'q' is the correct quotient, so we
249         // have bias=start.
250         s->bias = start;
251     }
252 
253     s->rcp_shift += 32; // Avoid the extra >>32 in RansEncPutSymbol
254 }
255 
256 // Initialize a decoder symbol to start "start" and frequency "freq"
RansDecSymbolInit(RansDecSymbol * s,uint32_t start,uint32_t freq)257 static inline void RansDecSymbolInit(RansDecSymbol* s, uint32_t start, uint32_t freq)
258 {
259     RansAssert(start <= (1 << 16));
260     RansAssert(freq <= (1 << 16) - start);
261     s->start = (uint16_t) start;
262     s->freq = (uint16_t) freq;
263 }
264 
265 // Encodes a given symbol. This is faster than straight RansEnc since we can do
266 // multiplications instead of a divide.
267 //
268 // See RansEncSymbolInit for a description of how this works.
RansEncPutSymbol(RansState * r,uint8_t ** pptr,RansEncSymbol const * sym)269 static inline void RansEncPutSymbol(RansState* r, uint8_t** pptr, RansEncSymbol const* sym)
270 {
271     RansAssert(sym->x_max != 0); // can't encode symbol with freq=0
272 
273     // renormalize
274     uint32_t x = *r;
275     uint32_t x_max = sym->x_max;
276 
277     if (x >= x_max) {
278 	uint8_t* ptr = *pptr;
279 	do {
280 	    *--ptr = (uint8_t) (x & 0xff);
281 	    x >>= 8;
282 	} while (x >= x_max);
283 	*pptr = ptr;
284     }
285 
286     // x = C(s,x)
287     // NOTE: written this way so we get a 32-bit "multiply high" when
288     // available. If you're on a 64-bit platform with cheap multiplies
289     // (e.g. x64), just bake the +32 into rcp_shift.
290     //uint32_t q = (uint32_t) (((uint64_t)x * sym->rcp_freq) >> 32) >> sym->rcp_shift;
291 
292     // The extra >>32 has already been added to RansEncSymbolInit
293     uint32_t q = (uint32_t) (((uint64_t)x * sym->rcp_freq) >> sym->rcp_shift);
294     *r = x + sym->bias + q * sym->cmpl_freq;
295 }
296 
297 // Equivalent to RansDecAdvance that takes a symbol.
RansDecAdvanceSymbol(RansState * r,uint8_t ** pptr,RansDecSymbol const * sym,uint32_t scale_bits)298 static inline void RansDecAdvanceSymbol(RansState* r, uint8_t** pptr, RansDecSymbol const* sym, uint32_t scale_bits)
299 {
300     RansDecAdvance(r, pptr, sym->start, sym->freq, scale_bits);
301 }
302 
303 // Advances in the bit stream by "popping" a single symbol with range start
304 // "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits".
305 // No renormalization or output happens.
RansDecAdvanceStep(RansState * r,uint32_t start,uint32_t freq,uint32_t scale_bits)306 static inline void RansDecAdvanceStep(RansState* r, uint32_t start, uint32_t freq, uint32_t scale_bits)
307 {
308     uint32_t mask = (1u << scale_bits) - 1;
309 
310     // s, x = D(x)
311     uint32_t x = *r;
312     *r = freq * (x >> scale_bits) + (x & mask) - start;
313 }
314 
315 // Equivalent to RansDecAdvanceStep that takes a symbol.
RansDecAdvanceSymbolStep(RansState * r,RansDecSymbol const * sym,uint32_t scale_bits)316 static inline void RansDecAdvanceSymbolStep(RansState* r, RansDecSymbol const* sym, uint32_t scale_bits)
317 {
318     RansDecAdvanceStep(r, sym->start, sym->freq, scale_bits);
319 }
320 
321 // Renormalize.
RansDecRenorm(RansState * r,uint8_t ** pptr)322 static inline void RansDecRenorm(RansState* r, uint8_t** pptr)
323 {
324     // renormalize
325     uint32_t x = *r;
326 
327     if (x < RANS_BYTE_L) {
328         uint8_t* ptr = *pptr;
329         x = (x << 8) | *ptr++;
330         if (x < RANS_BYTE_L)
331             x = (x << 8) | *ptr++;
332         *pptr = ptr;
333     }
334 
335     *r = x;
336 }
337 
338 // Renormalize, with extra checks for falling off the end of the input.
RansDecRenormSafe(RansState * r,uint8_t ** pptr,uint8_t * ptr_end)339 static inline void RansDecRenormSafe(RansState* r, uint8_t** pptr, uint8_t *ptr_end)
340 {
341     uint32_t x = *r;
342     uint8_t* ptr = *pptr;
343     if (x >= RANS_BYTE_L || ptr >= ptr_end) return;
344     x = (x << 8) | *ptr++;
345     if (x < RANS_BYTE_L && ptr < ptr_end)
346         x = (x << 8) | *ptr++;
347     *pptr = ptr;
348     *r = x;
349 }
350 
351 
352 #endif // RANS_BYTE_HEADER
353