1 // Copyright 2016 The Draco Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 #ifndef DRACO_COMPRESSION_ENTROPY_RANS_SYMBOL_ENCODER_H_
16 #define DRACO_COMPRESSION_ENTROPY_RANS_SYMBOL_ENCODER_H_
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <cstring>
21 
22 #include "draco/compression/entropy/ans.h"
23 #include "draco/compression/entropy/rans_symbol_coding.h"
24 #include "draco/core/encoder_buffer.h"
25 #include "draco/core/varint_encoding.h"
26 
27 namespace draco {
28 
29 // A helper class for encoding symbols using the rANS algorithm (see ans.h).
30 // The class can be used to initialize and encode probability table needed by
31 // rANS, and to perform encoding of symbols into the provided EncoderBuffer.
32 template <int unique_symbols_bit_length_t>
33 class RAnsSymbolEncoder {
34  public:
RAnsSymbolEncoder()35   RAnsSymbolEncoder()
36       : num_symbols_(0), num_expected_bits_(0), buffer_offset_(0) {}
37 
38   // Creates a probability table needed by the rANS library and encode it into
39   // the provided buffer.
40   bool Create(const uint64_t *frequencies, int num_symbols,
41               EncoderBuffer *buffer);
42 
43   void StartEncoding(EncoderBuffer *buffer);
EncodeSymbol(uint32_t symbol)44   void EncodeSymbol(uint32_t symbol) {
45     ans_.rans_write(&probability_table_[symbol]);
46   }
47   void EndEncoding(EncoderBuffer *buffer);
48 
49   // rANS requires to encode the input symbols in the reverse order.
needs_reverse_encoding()50   static constexpr bool needs_reverse_encoding() { return true; }
51 
52  private:
53   // Functor used for sorting symbol ids according to their probabilities.
54   // The functor sorts symbol indices that index an underlying map between
55   // symbol ids and their probabilities. We don't sort the probability table
56   // directly, because that would require an additional indirection during the
57   // EncodeSymbol() function.
58   struct ProbabilityLess {
ProbabilityLessProbabilityLess59     explicit ProbabilityLess(const std::vector<rans_sym> *probs)
60         : probabilities(probs) {}
operatorProbabilityLess61     bool operator()(int i, int j) const {
62       return probabilities->at(i).prob < probabilities->at(j).prob;
63     }
64     const std::vector<rans_sym> *probabilities;
65   };
66 
67   // Encodes the probability table into the output buffer.
68   bool EncodeTable(EncoderBuffer *buffer);
69 
70   static constexpr int rans_precision_bits_ =
71       ComputeRAnsPrecisionFromUniqueSymbolsBitLength(
72           unique_symbols_bit_length_t);
73   static constexpr int rans_precision_ = 1 << rans_precision_bits_;
74 
75   std::vector<rans_sym> probability_table_;
76   // The number of symbols in the input alphabet.
77   uint32_t num_symbols_;
78   // Expected number of bits that is needed to encode the input.
79   uint64_t num_expected_bits_;
80 
81   RAnsEncoder<rans_precision_bits_> ans_;
82   // Initial offset of the encoder buffer before any ans data was encoded.
83   uint64_t buffer_offset_;
84 };
85 
86 template <int unique_symbols_bit_length_t>
Create(const uint64_t * frequencies,int num_symbols,EncoderBuffer * buffer)87 bool RAnsSymbolEncoder<unique_symbols_bit_length_t>::Create(
88     const uint64_t *frequencies, int num_symbols, EncoderBuffer *buffer) {
89   // Compute the total of the input frequencies.
90   uint64_t total_freq = 0;
91   int max_valid_symbol = 0;
92   for (int i = 0; i < num_symbols; ++i) {
93     total_freq += frequencies[i];
94     if (frequencies[i] > 0) {
95       max_valid_symbol = i;
96     }
97   }
98   num_symbols = max_valid_symbol + 1;
99   num_symbols_ = num_symbols;
100   probability_table_.resize(num_symbols);
101   const double total_freq_d = static_cast<double>(total_freq);
102   const double rans_precision_d = static_cast<double>(rans_precision_);
103   // Compute probabilities by rescaling the normalized frequencies into interval
104   // [1, rans_precision - 1]. The total probability needs to be equal to
105   // rans_precision.
106   int total_rans_prob = 0;
107   for (int i = 0; i < num_symbols; ++i) {
108     const uint64_t freq = frequencies[i];
109 
110     // Normalized probability.
111     const double prob = static_cast<double>(freq) / total_freq_d;
112 
113     // RAns probability in range of [1, rans_precision - 1].
114     uint32_t rans_prob = static_cast<uint32_t>(prob * rans_precision_d + 0.5f);
115     if (rans_prob == 0 && freq > 0) {
116       rans_prob = 1;
117     }
118     probability_table_[i].prob = rans_prob;
119     total_rans_prob += rans_prob;
120   }
121   // Because of rounding errors, the total precision may not be exactly accurate
122   // and we may need to adjust the entries a little bit.
123   if (total_rans_prob != rans_precision_) {
124     std::vector<int> sorted_probabilities(num_symbols);
125     for (int i = 0; i < num_symbols; ++i) {
126       sorted_probabilities[i] = i;
127     }
128     std::sort(sorted_probabilities.begin(), sorted_probabilities.end(),
129               ProbabilityLess(&probability_table_));
130     if (total_rans_prob < rans_precision_) {
131       // This happens rather infrequently, just add the extra needed precision
132       // to the most frequent symbol.
133       probability_table_[sorted_probabilities.back()].prob +=
134           rans_precision_ - total_rans_prob;
135     } else {
136       // We have over-allocated the precision, which is quite common.
137       // Rescale the probabilities of all symbols.
138       int32_t error = total_rans_prob - rans_precision_;
139       while (error > 0) {
140         const double act_total_prob_d = static_cast<double>(total_rans_prob);
141         const double act_rel_error_d = rans_precision_d / act_total_prob_d;
142         for (int j = num_symbols - 1; j > 0; --j) {
143           int symbol_id = sorted_probabilities[j];
144           if (probability_table_[symbol_id].prob <= 1) {
145             if (j == num_symbols - 1) {
146               return false;  // Most frequent symbol would be empty.
147             }
148             break;
149           }
150           const int32_t new_prob = static_cast<int32_t>(
151               floor(act_rel_error_d *
152                     static_cast<double>(probability_table_[symbol_id].prob)));
153           int32_t fix = probability_table_[symbol_id].prob - new_prob;
154           if (fix == 0u) {
155             fix = 1;
156           }
157           if (fix >= static_cast<int32_t>(probability_table_[symbol_id].prob)) {
158             fix = probability_table_[symbol_id].prob - 1;
159           }
160           if (fix > error) {
161             fix = error;
162           }
163           probability_table_[symbol_id].prob -= fix;
164           total_rans_prob -= fix;
165           error -= fix;
166           if (total_rans_prob == rans_precision_) {
167             break;
168           }
169         }
170       }
171     }
172   }
173 
174   // Compute the cumulative probability (cdf).
175   uint32_t total_prob = 0;
176   for (int i = 0; i < num_symbols; ++i) {
177     probability_table_[i].cum_prob = total_prob;
178     total_prob += probability_table_[i].prob;
179   }
180   if (total_prob != rans_precision_) {
181     return false;
182   }
183 
184   // Estimate the number of bits needed to encode the input.
185   // From Shannon entropy the total number of bits N is:
186   //   N = -sum{i : all_symbols}(F(i) * log2(P(i)))
187   // where P(i) is the normalized probability of symbol i and F(i) is the
188   // symbol's frequency in the input data.
189   double num_bits = 0;
190   for (int i = 0; i < num_symbols; ++i) {
191     if (probability_table_[i].prob == 0) {
192       continue;
193     }
194     const double norm_prob =
195         static_cast<double>(probability_table_[i].prob) / rans_precision_d;
196     num_bits += static_cast<double>(frequencies[i]) * log2(norm_prob);
197   }
198   num_expected_bits_ = static_cast<uint64_t>(ceil(-num_bits));
199   if (!EncodeTable(buffer)) {
200     return false;
201   }
202   return true;
203 }
204 
205 template <int unique_symbols_bit_length_t>
EncodeTable(EncoderBuffer * buffer)206 bool RAnsSymbolEncoder<unique_symbols_bit_length_t>::EncodeTable(
207     EncoderBuffer *buffer) {
208   EncodeVarint(num_symbols_, buffer);
209   // Use varint encoding for the probabilities (first two bits represent the
210   // number of bytes used - 1).
211   for (uint32_t i = 0; i < num_symbols_; ++i) {
212     const uint32_t prob = probability_table_[i].prob;
213     int num_extra_bytes = 0;
214     if (prob >= (1 << 6)) {
215       num_extra_bytes++;
216       if (prob >= (1 << 14)) {
217         num_extra_bytes++;
218         if (prob >= (1 << 22)) {
219           // The maximum number of precision bits is 20 so we should not really
220           // get to this point.
221           return false;
222         }
223       }
224     }
225     if (prob == 0) {
226       // When the probability of the symbol is 0, set the first two bits to 1
227       // (unique identifier) and use the remaining 6 bits to store the offset
228       // to the next symbol with non-zero probability.
229       uint32_t offset = 0;
230       for (; offset < (1 << 6) - 1; ++offset) {
231         // Note: we don't have to check whether the next symbol id is larger
232         // than num_symbols_ because we know that the last symbol always has
233         // non-zero probability.
234         const uint32_t next_prob = probability_table_[i + offset + 1].prob;
235         if (next_prob > 0) {
236           break;
237         }
238       }
239       buffer->Encode(static_cast<uint8_t>((offset << 2) | 3));
240       i += offset;
241     } else {
242       // Encode the first byte (including the number of extra bytes).
243       buffer->Encode(static_cast<uint8_t>((prob << 2) | (num_extra_bytes & 3)));
244       // Encode the extra bytes.
245       for (int b = 0; b < num_extra_bytes; ++b) {
246         buffer->Encode(static_cast<uint8_t>(prob >> (8 * (b + 1) - 2)));
247       }
248     }
249   }
250   return true;
251 }
252 
253 template <int unique_symbols_bit_length_t>
StartEncoding(EncoderBuffer * buffer)254 void RAnsSymbolEncoder<unique_symbols_bit_length_t>::StartEncoding(
255     EncoderBuffer *buffer) {
256   // Allocate extra storage just in case.
257   const uint64_t required_bits = 2 * num_expected_bits_ + 32;
258 
259   buffer_offset_ = buffer->size();
260   const int64_t required_bytes = (required_bits + 7) / 8;
261   buffer->Resize(buffer_offset_ + required_bytes + sizeof(buffer_offset_));
262   uint8_t *const data =
263       reinterpret_cast<uint8_t *>(const_cast<char *>(buffer->data()));
264   ans_.write_init(data + buffer_offset_);
265 }
266 
267 template <int unique_symbols_bit_length_t>
EndEncoding(EncoderBuffer * buffer)268 void RAnsSymbolEncoder<unique_symbols_bit_length_t>::EndEncoding(
269     EncoderBuffer *buffer) {
270   char *const src = const_cast<char *>(buffer->data()) + buffer_offset_;
271 
272   // TODO(fgalligan): Look into changing this to uint32_t as write_end()
273   // returns an int.
274   const uint64_t bytes_written = static_cast<uint64_t>(ans_.write_end());
275   EncoderBuffer var_size_buffer;
276   EncodeVarint(bytes_written, &var_size_buffer);
277   const uint32_t size_len = static_cast<uint32_t>(var_size_buffer.size());
278   char *const dst = src + size_len;
279   memmove(dst, src, bytes_written);
280 
281   // Store the size of the encoded data.
282   memcpy(src, var_size_buffer.data(), size_len);
283 
284   // Resize the buffer to match the number of encoded bytes.
285   buffer->Resize(buffer_offset_ + bytes_written + size_len);
286 }
287 
288 }  // namespace draco
289 
290 #endif  // DRACO_COMPRESSION_ENTROPY_RANS_SYMBOL_ENCODER_H_
291