1 // Copyright 2016 The Draco Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 #ifndef DRACO_CORE_ANS_H_
16 #define DRACO_CORE_ANS_H_
17 // An implementation of Asymmetric Numeral Systems (rANS).
18 // See http://arxiv.org/abs/1311.2540v2 for more information on rANS.
// This file is based on libvpx's ans.h.
20
21 #include <vector>
22
23 #define DRACO_ANS_DIVIDE_BY_MULTIPLY 1
24 #if DRACO_ANS_DIVIDE_BY_MULTIPLY
25 #include "draco/core/divide.h"
26 #endif
27 #include "draco/core/macros.h"
28
29 namespace draco {
30
#if DRACO_ANS_DIVIDE_BY_MULTIPLY

// DRACO_ANS_DIVREM(quotient, remainder, dividend, divisor) stores
// dividend / divisor in |quotient| and the matching remainder in |remainder|.
// DRACO_ANS_DIV(dividend, divisor) yields just the quotient.
//
// This variant routes the division through fastdiv() and derives the
// remainder from the quotient. Every macro argument is parenthesized so that
// compound expressions expand correctly. Note that |dividend| and |divisor|
// are evaluated more than once, so they must be free of side effects.
#define DRACO_ANS_DIVREM(quotient, remainder, dividend, divisor) \
  do {                                                           \
    (quotient) = fastdiv((dividend), (divisor));                 \
    (remainder) = (dividend) - (quotient) * (divisor);           \
  } while (0)
#define DRACO_ANS_DIV(dividend, divisor) fastdiv((dividend), (divisor))
#else
// Plain integer division variant with the same contract as above.
#define DRACO_ANS_DIVREM(quotient, remainder, dividend, divisor) \
  do {                                                           \
    (quotient) = (dividend) / (divisor);                         \
    (remainder) = (dividend) % (divisor);                        \
  } while (0)
#define DRACO_ANS_DIV(dividend, divisor) ((dividend) / (divisor))
#endif
47
// Encoder-side state shared by the ANS write routines.
struct AnsCoder {
  // All fields start zeroed/null; the write-init routines set them up.
  AnsCoder() {}

  // Output buffer the encoder emits bytes into.
  uint8_t *buf = nullptr;
  // Index in |buf| of the next byte to be written.
  int buf_offset = 0;
  // Current coder state.
  uint32_t state = 0;
};
54
// Decoder-side state shared by the ANS read routines.
struct AnsDecoder {
  // All fields start zeroed/null; the read-init routines set them up.
  AnsDecoder() {}

  // Encoded input; the readers consume it from the back towards the front.
  const uint8_t *buf = nullptr;
  // Number of bytes of |buf| that have not been consumed yet.
  int buf_offset = 0;
  // Current decoder state.
  uint32_t state = 0;
};
61
// 8-bit probability used by the binary (rABS/uABS) coders, expressed on a
// scale of DRACO_ANS_P8_PRECISION.
using AnsP8 = uint8_t;
// Denominator of AnsP8 probabilities ("m" in the paper's notation).
#define DRACO_ANS_P8_PRECISION 256u
// Initial coder state and lower bound of the normalized state range.
#define DRACO_ANS_L_BASE (4096u)
// Radix used when moving state to/from the buffer (one byte at a time).
#define DRACO_ANS_IO_BASE 256
66
// Reads a 16-bit little-endian unsigned value from the two bytes at |vmem|.
// Declared inline like the other header-local byte helpers, and each byte is
// widened to uint32_t before it is shifted.
static inline uint32_t mem_get_le16(const void *vmem) {
  const uint8_t *const mem = static_cast<const uint8_t *>(vmem);
  return (static_cast<uint32_t>(mem[1]) << 8) | static_cast<uint32_t>(mem[0]);
}
75
// Reads a 24-bit little-endian unsigned value from the three bytes at |vmem|.
// Declared inline like the other header-local byte helpers, and each byte is
// widened to uint32_t before it is shifted.
static inline uint32_t mem_get_le24(const void *vmem) {
  const uint8_t *const mem = static_cast<const uint8_t *>(vmem);
  return (static_cast<uint32_t>(mem[2]) << 16) |
         (static_cast<uint32_t>(mem[1]) << 8) | static_cast<uint32_t>(mem[0]);
}
85
// Reads a 32-bit little-endian unsigned value from the four bytes at |vmem|.
//
// Each byte is widened to uint32_t before shifting. Shifting the promoted
// (signed) int by 24 would move a set top bit into the sign bit, which is not
// portable across language versions.
static inline uint32_t mem_get_le32(const void *vmem) {
  const uint8_t *const mem = static_cast<const uint8_t *>(vmem);
  return (static_cast<uint32_t>(mem[3]) << 24) |
         (static_cast<uint32_t>(mem[2]) << 16) |
         (static_cast<uint32_t>(mem[1]) << 8) | static_cast<uint32_t>(mem[0]);
}
96
// Stores the low 16 bits of |val| at |vmem| in little-endian byte order.
static inline void mem_put_le16(void *vmem, uint32_t val) {
  uint8_t *const out = static_cast<uint8_t *>(vmem);
  const uint8_t low = static_cast<uint8_t>(val & 0xff);
  const uint8_t high = static_cast<uint8_t>((val >> 8) & 0xff);
  out[0] = low;
  out[1] = high;
}
103
// Stores the low 24 bits of |val| at |vmem| in little-endian byte order.
static inline void mem_put_le24(void *vmem, uint32_t val) {
  uint8_t *const out = static_cast<uint8_t *>(vmem);
  for (int i = 0; i < 3; ++i) {
    out[i] = static_cast<uint8_t>((val >> (8 * i)) & 0xff);
  }
}
111
// Stores all 32 bits of |val| at |vmem| in little-endian byte order.
static inline void mem_put_le32(void *vmem, uint32_t val) {
  uint8_t *const out = static_cast<uint8_t *>(vmem);
  for (int i = 0; i < 4; ++i) {
    out[i] = static_cast<uint8_t>((val >> (8 * i)) & 0xff);
  }
}
120
ans_write_init(struct AnsCoder * const ans,uint8_t * const buf)121 static inline void ans_write_init(struct AnsCoder *const ans,
122 uint8_t *const buf) {
123 ans->buf = buf;
124 ans->buf_offset = 0;
125 ans->state = DRACO_ANS_L_BASE;
126 }
127
ans_write_end(struct AnsCoder * const ans)128 static inline int ans_write_end(struct AnsCoder *const ans) {
129 uint32_t state;
130 DRACO_DCHECK_GE(ans->state, DRACO_ANS_L_BASE);
131 DRACO_DCHECK_LT(ans->state, DRACO_ANS_L_BASE * DRACO_ANS_IO_BASE);
132 state = ans->state - DRACO_ANS_L_BASE;
133 if (state < (1 << 6)) {
134 ans->buf[ans->buf_offset] = (0x00 << 6) + state;
135 return ans->buf_offset + 1;
136 } else if (state < (1 << 14)) {
137 mem_put_le16(ans->buf + ans->buf_offset, (0x01 << 14) + state);
138 return ans->buf_offset + 2;
139 } else if (state < (1 << 22)) {
140 mem_put_le24(ans->buf + ans->buf_offset, (0x02 << 22) + state);
141 return ans->buf_offset + 3;
142 } else {
143 DRACO_DCHECK(0 && "State is too large to be serialized");
144 return ans->buf_offset;
145 }
146 }
147
148 // rABS with descending spread.
149 // p or p0 takes the place of l_s from the paper.
150 // DRACO_ANS_P8_PRECISION is m.
rabs_desc_write(struct AnsCoder * ans,int val,AnsP8 p0)151 static inline void rabs_desc_write(struct AnsCoder *ans, int val, AnsP8 p0) {
152 const AnsP8 p = DRACO_ANS_P8_PRECISION - p0;
153 const unsigned l_s = val ? p : p0;
154 unsigned quot, rem;
155 if (ans->state >=
156 DRACO_ANS_L_BASE / DRACO_ANS_P8_PRECISION * DRACO_ANS_IO_BASE * l_s) {
157 ans->buf[ans->buf_offset++] = ans->state % DRACO_ANS_IO_BASE;
158 ans->state /= DRACO_ANS_IO_BASE;
159 }
160 DRACO_ANS_DIVREM(quot, rem, ans->state, l_s);
161 ans->state = quot * DRACO_ANS_P8_PRECISION + rem + (val ? 0 : p);
162 }
163
164 #define DRACO_ANS_IMPL1 0
165 #define UNPREDICTABLE(x) x
rabs_desc_read(struct AnsDecoder * ans,AnsP8 p0)166 static inline int rabs_desc_read(struct AnsDecoder *ans, AnsP8 p0) {
167 int val;
168 #if DRACO_ANS_IMPL1
169 unsigned l_s;
170 #else
171 unsigned quot, rem, x, xn;
172 #endif
173 const AnsP8 p = DRACO_ANS_P8_PRECISION - p0;
174 if (ans->state < DRACO_ANS_L_BASE && ans->buf_offset > 0) {
175 ans->state = ans->state * DRACO_ANS_IO_BASE + ans->buf[--ans->buf_offset];
176 }
177 #if DRACO_ANS_IMPL1
178 val = ans->state % DRACO_ANS_P8_PRECISION < p;
179 l_s = val ? p : p0;
180 ans->state = (ans->state / DRACO_ANS_P8_PRECISION) * l_s +
181 ans->state % DRACO_ANS_P8_PRECISION - (!val * p);
182 #else
183 x = ans->state;
184 quot = x / DRACO_ANS_P8_PRECISION;
185 rem = x % DRACO_ANS_P8_PRECISION;
186 xn = quot * p;
187 val = rem < p;
188 if (UNPREDICTABLE(val)) {
189 ans->state = xn + rem;
190 } else {
191 // ans->state = quot * p0 + rem - p;
192 ans->state = x - xn - p;
193 }
194 #endif
195 return val;
196 }
197
198 // rABS with ascending spread.
199 // p or p0 takes the place of l_s from the paper.
200 // DRACO_ANS_P8_PRECISION is m.
rabs_asc_write(struct AnsCoder * ans,int val,AnsP8 p0)201 static inline void rabs_asc_write(struct AnsCoder *ans, int val, AnsP8 p0) {
202 const AnsP8 p = DRACO_ANS_P8_PRECISION - p0;
203 const unsigned l_s = val ? p : p0;
204 unsigned quot, rem;
205 if (ans->state >=
206 DRACO_ANS_L_BASE / DRACO_ANS_P8_PRECISION * DRACO_ANS_IO_BASE * l_s) {
207 ans->buf[ans->buf_offset++] = ans->state % DRACO_ANS_IO_BASE;
208 ans->state /= DRACO_ANS_IO_BASE;
209 }
210 DRACO_ANS_DIVREM(quot, rem, ans->state, l_s);
211 ans->state = quot * DRACO_ANS_P8_PRECISION + rem + (val ? p0 : 0);
212 }
213
rabs_asc_read(struct AnsDecoder * ans,AnsP8 p0)214 static inline int rabs_asc_read(struct AnsDecoder *ans, AnsP8 p0) {
215 int val;
216 #if DRACO_ANS_IMPL1
217 unsigned l_s;
218 #else
219 unsigned quot, rem, x, xn;
220 #endif
221 const AnsP8 p = DRACO_ANS_P8_PRECISION - p0;
222 if (ans->state < DRACO_ANS_L_BASE) {
223 ans->state = ans->state * DRACO_ANS_IO_BASE + ans->buf[--ans->buf_offset];
224 }
225 #if DRACO_ANS_IMPL1
226 val = ans->state % DRACO_ANS_P8_PRECISION < p;
227 l_s = val ? p : p0;
228 ans->state = (ans->state / DRACO_ANS_P8_PRECISION) * l_s +
229 ans->state % DRACO_ANS_P8_PRECISION - (!val * p);
230 #else
231 x = ans->state;
232 quot = x / DRACO_ANS_P8_PRECISION;
233 rem = x % DRACO_ANS_P8_PRECISION;
234 xn = quot * p;
235 val = rem >= p0;
236 if (UNPREDICTABLE(val)) {
237 ans->state = xn + rem - p0;
238 } else {
239 // ans->state = quot * p0 + rem - p0;
240 ans->state = x - xn;
241 }
242 #endif
243 return val;
244 }
245
246 #define rabs_read rabs_desc_read
247 #define rabs_write rabs_desc_write
248
249 // uABS with normalization.
uabs_write(struct AnsCoder * ans,int val,AnsP8 p0)250 static inline void uabs_write(struct AnsCoder *ans, int val, AnsP8 p0) {
251 AnsP8 p = DRACO_ANS_P8_PRECISION - p0;
252 const unsigned l_s = val ? p : p0;
253 while (ans->state >=
254 DRACO_ANS_L_BASE / DRACO_ANS_P8_PRECISION * DRACO_ANS_IO_BASE * l_s) {
255 ans->buf[ans->buf_offset++] = ans->state % DRACO_ANS_IO_BASE;
256 ans->state /= DRACO_ANS_IO_BASE;
257 }
258 if (!val)
259 ans->state = DRACO_ANS_DIV(ans->state * DRACO_ANS_P8_PRECISION, p0);
260 else
261 ans->state =
262 DRACO_ANS_DIV((ans->state + 1) * DRACO_ANS_P8_PRECISION + p - 1, p) - 1;
263 }
264
uabs_read(struct AnsDecoder * ans,AnsP8 p0)265 static inline int uabs_read(struct AnsDecoder *ans, AnsP8 p0) {
266 AnsP8 p = DRACO_ANS_P8_PRECISION - p0;
267 int s;
268 // unsigned int xp1;
269 unsigned xp, sp;
270 unsigned state = ans->state;
271 while (state < DRACO_ANS_L_BASE && ans->buf_offset > 0) {
272 state = state * DRACO_ANS_IO_BASE + ans->buf[--ans->buf_offset];
273 }
274 sp = state * p;
275 // xp1 = (sp + p) / DRACO_ANS_P8_PRECISION;
276 xp = sp / DRACO_ANS_P8_PRECISION;
277 // s = xp1 - xp;
278 s = (sp & 0xFF) >= p0;
279 if (UNPREDICTABLE(s))
280 ans->state = xp;
281 else
282 ans->state = state - xp;
283 return s;
284 }
285
uabs_read_bit(struct AnsDecoder * ans)286 static inline int uabs_read_bit(struct AnsDecoder *ans) {
287 int s;
288 unsigned state = ans->state;
289 while (state < DRACO_ANS_L_BASE && ans->buf_offset > 0) {
290 state = state * DRACO_ANS_IO_BASE + ans->buf[--ans->buf_offset];
291 }
292 s = static_cast<int>(state & 1);
293 ans->state = state >> 1;
294 return s;
295 }
296
ans_read_init(struct AnsDecoder * const ans,const uint8_t * const buf,int offset)297 static inline int ans_read_init(struct AnsDecoder *const ans,
298 const uint8_t *const buf, int offset) {
299 unsigned x;
300 if (offset < 1)
301 return 1;
302 ans->buf = buf;
303 x = buf[offset - 1] >> 6;
304 if (x == 0) {
305 ans->buf_offset = offset - 1;
306 ans->state = buf[offset - 1] & 0x3F;
307 } else if (x == 1) {
308 if (offset < 2)
309 return 1;
310 ans->buf_offset = offset - 2;
311 ans->state = mem_get_le16(buf + offset - 2) & 0x3FFF;
312 } else if (x == 2) {
313 if (offset < 3)
314 return 1;
315 ans->buf_offset = offset - 3;
316 ans->state = mem_get_le24(buf + offset - 3) & 0x3FFFFF;
317 } else {
318 return 1;
319 }
320 ans->state += DRACO_ANS_L_BASE;
321 if (ans->state >= DRACO_ANS_L_BASE * DRACO_ANS_IO_BASE)
322 return 1;
323 return 0;
324 }
325
ans_read_end(struct AnsDecoder * const ans)326 static inline int ans_read_end(struct AnsDecoder *const ans) {
327 return ans->state == DRACO_ANS_L_BASE;
328 }
329
ans_reader_has_error(const struct AnsDecoder * const ans)330 static inline int ans_reader_has_error(const struct AnsDecoder *const ans) {
331 return ans->state < DRACO_ANS_L_BASE && ans->buf_offset == 0;
332 }
333
// Probability entry for a single rANS symbol.
struct rans_sym {
  // Frequency of the symbol (l_s in the paper's notation).
  uint32_t prob;
  // Sum of the frequencies of all preceding symbols; the symbol's own
  // frequency is not included.
  uint32_t cum_prob;
};
338
339 // Class for performing rANS encoding using a desired number of precision bits.
340 // The max number of precision bits is currently 19. The actual number of
341 // symbols in the input alphabet should be (much) smaller than that, otherwise
342 // the compression rate may suffer.
343 template <int rans_precision_bits_t>
344 class RAnsEncoder {
345 public:
RAnsEncoder()346 RAnsEncoder() {}
347
348 // Provides the input buffer where the data is going to be stored.
write_init(uint8_t * const buf)349 inline void write_init(uint8_t *const buf) {
350 ans_.buf = buf;
351 ans_.buf_offset = 0;
352 ans_.state = l_rans_base;
353 }
354
355 // Needs to be called after all symbols are encoded.
write_end()356 inline int write_end() {
357 uint32_t state;
358 DRACO_DCHECK_GE(ans_.state, l_rans_base);
359 DRACO_DCHECK_LT(ans_.state, l_rans_base * DRACO_ANS_IO_BASE);
360 state = ans_.state - l_rans_base;
361 if (state < (1 << 6)) {
362 ans_.buf[ans_.buf_offset] = (0x00 << 6) + state;
363 return ans_.buf_offset + 1;
364 } else if (state < (1 << 14)) {
365 mem_put_le16(ans_.buf + ans_.buf_offset, (0x01 << 14) + state);
366 return ans_.buf_offset + 2;
367 } else if (state < (1 << 22)) {
368 mem_put_le24(ans_.buf + ans_.buf_offset, (0x02 << 22) + state);
369 return ans_.buf_offset + 3;
370 } else if (state < (1 << 30)) {
371 mem_put_le32(ans_.buf + ans_.buf_offset, (0x03u << 30u) + state);
372 return ans_.buf_offset + 4;
373 } else {
374 DRACO_DCHECK(0 && "State is too large to be serialized");
375 return ans_.buf_offset;
376 }
377 }
378
379 // rANS with normalization.
380 // sym->prob takes the place of l_s from the paper.
381 // rans_precision is m.
rans_write(const struct rans_sym * const sym)382 inline void rans_write(const struct rans_sym *const sym) {
383 const uint32_t p = sym->prob;
384 while (ans_.state >= l_rans_base / rans_precision * DRACO_ANS_IO_BASE * p) {
385 ans_.buf[ans_.buf_offset++] = ans_.state % DRACO_ANS_IO_BASE;
386 ans_.state /= DRACO_ANS_IO_BASE;
387 }
388 // TODO(ostava): The division and multiplication should be optimized.
389 ans_.state =
390 (ans_.state / p) * rans_precision + ans_.state % p + sym->cum_prob;
391 }
392
393 private:
394 static constexpr int rans_precision = 1 << rans_precision_bits_t;
395 static constexpr int l_rans_base = rans_precision * 4;
396 AnsCoder ans_;
397 };
398
// Decoded-symbol record: the symbol value plus its probability entry.
struct rans_dec_sym {
  // The symbol's value.
  uint32_t val;
  // Frequency of the symbol.
  uint32_t prob;
  // Sum of the frequencies of all preceding symbols; the symbol's own
  // frequency is not included.
  uint32_t cum_prob;
};
404
405 // Class for performing rANS decoding using a desired number of precision bits.
406 // The number of precision bits needs to be the same as with the RAnsEncoder
407 // that was used to encode the input data.
408 template <int rans_precision_bits_t>
409 class RAnsDecoder {
410 public:
RAnsDecoder()411 RAnsDecoder() {}
412
413 // Initializes the decoder from the input buffer. The |offset| specifies the
414 // number of bytes encoded by the encoder. A non zero return value is an
415 // error.
read_init(const uint8_t * const buf,int offset)416 inline int read_init(const uint8_t *const buf, int offset) {
417 unsigned x;
418 if (offset < 1)
419 return 1;
420 ans_.buf = buf;
421 x = buf[offset - 1] >> 6;
422 if (x == 0) {
423 ans_.buf_offset = offset - 1;
424 ans_.state = buf[offset - 1] & 0x3F;
425 } else if (x == 1) {
426 if (offset < 2)
427 return 1;
428 ans_.buf_offset = offset - 2;
429 ans_.state = mem_get_le16(buf + offset - 2) & 0x3FFF;
430 } else if (x == 2) {
431 if (offset < 3)
432 return 1;
433 ans_.buf_offset = offset - 3;
434 ans_.state = mem_get_le24(buf + offset - 3) & 0x3FFFFF;
435 } else if (x == 3) {
436 ans_.buf_offset = offset - 4;
437 ans_.state = mem_get_le32(buf + offset - 4) & 0x3FFFFFFF;
438 } else {
439 return 1;
440 }
441 ans_.state += l_rans_base;
442 if (ans_.state >= l_rans_base * DRACO_ANS_IO_BASE)
443 return 1;
444 return 0;
445 }
446
read_end()447 inline int read_end() { return ans_.state == l_rans_base; }
448
reader_has_error()449 inline int reader_has_error() {
450 return ans_.state < l_rans_base && ans_.buf_offset == 0;
451 }
452
rans_read()453 inline int rans_read() {
454 unsigned rem;
455 unsigned quo;
456 struct rans_dec_sym sym;
457 while (ans_.state < l_rans_base && ans_.buf_offset > 0) {
458 ans_.state = ans_.state * DRACO_ANS_IO_BASE + ans_.buf[--ans_.buf_offset];
459 }
460 // |rans_precision| is a power of two compile time constant, and the below
461 // division and modulo are going to be optimized by the compiler.
462 quo = ans_.state / rans_precision;
463 rem = ans_.state % rans_precision;
464 fetch_sym(&sym, rem);
465 ans_.state = quo * sym.prob + rem - sym.cum_prob;
466 return sym.val;
467 }
468
469 // Construct a lookup table with |rans_precision| number of entries.
470 // Returns false if the table couldn't be built (because of wrong input data).
rans_build_look_up_table(const uint32_t token_probs[],uint32_t num_symbols)471 inline bool rans_build_look_up_table(const uint32_t token_probs[],
472 uint32_t num_symbols) {
473 lut_table_.resize(rans_precision);
474 probability_table_.resize(num_symbols);
475 uint32_t cum_prob = 0;
476 uint32_t act_prob = 0;
477 for (uint32_t i = 0; i < num_symbols; ++i) {
478 probability_table_[i].prob = token_probs[i];
479 probability_table_[i].cum_prob = cum_prob;
480 cum_prob += token_probs[i];
481 if (cum_prob > rans_precision) {
482 return false;
483 }
484 for (uint32_t j = act_prob; j < cum_prob; ++j) {
485 lut_table_[j] = i;
486 }
487 act_prob = cum_prob;
488 }
489 if (cum_prob != rans_precision) {
490 return false;
491 }
492 return true;
493 }
494
495 private:
fetch_sym(struct rans_dec_sym * out,uint32_t rem)496 inline void fetch_sym(struct rans_dec_sym *out, uint32_t rem) {
497 uint32_t symbol = lut_table_[rem];
498 out->val = symbol;
499 out->prob = probability_table_[symbol].prob;
500 out->cum_prob = probability_table_[symbol].cum_prob;
501 }
502
503 static constexpr int rans_precision = 1 << rans_precision_bits_t;
504 static constexpr int l_rans_base = rans_precision * 4;
505 std::vector<uint32_t> lut_table_;
506 std::vector<rans_sym> probability_table_;
507 AnsDecoder ans_;
508 };
509
510 #undef DRACO_ANS_DIVREM
511 #undef DRACO_ANS_P8_PRECISION
512 #undef DRACO_ANS_L_BASE
513 #undef DRACO_ANS_IO_BASE
514
515 } // namespace draco
516
517 #endif // DRACO_CORE_ANS_H_
518