1 /*
2  *  Copyright (c) 2016, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree. An additional grant
7  *  of patent rights can be found in the PATENTS file in the same directory.
8  */
9 
10 #ifndef FATAL_INCLUDE_fatal_codec_varint_h
11 #define FATAL_INCLUDE_fatal_codec_varint_h
12 
#include <fatal/math/numerics.h>

#include <array>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>

#include <cassert>
#include <cstddef>
20 
21 namespace fatal {
22 namespace detail {
23 namespace varint_impl {
24 
// Upper bound on the number of data units taken by the varint encoding of a
// value that is `ValueSize` bits wide, when each data unit carries `DataSize`
// bits of payload (the unit itself being `DataSize + 1` bits wide).
//
// This is `ValueSize / DataSize` rounded up, and never less than one unit.
template <std::size_t ValueSize, std::size_t DataSize>
using size = std::integral_constant<
  std::size_t,
  (ValueSize <= DataSize) ? 1 : 1 + (ValueSize - 1) / DataSize
>;
32 
33 auto constexpr byte_size = data_bits<char>::value;
34 auto constexpr byte_payload = byte_size - 1;
35 
// Compile-time sanity checks for `size`, using `byte_payload` payload bits per
// data unit. The expected unit counts are those of a 7-bit payload, that is,
// the value width in bits divided by 7 and rounded up. Widths 1 through 70
// are checked exhaustively.
static_assert(size<1, byte_payload>::value == 1, "size mismatch");
static_assert(size<2, byte_payload>::value == 1, "size mismatch");
static_assert(size<3, byte_payload>::value == 1, "size mismatch");
static_assert(size<4, byte_payload>::value == 1, "size mismatch");
static_assert(size<5, byte_payload>::value == 1, "size mismatch");
static_assert(size<6, byte_payload>::value == 1, "size mismatch");
static_assert(size<7, byte_payload>::value == 1, "size mismatch");
static_assert(size<8, byte_payload>::value == 2, "size mismatch");
static_assert(size<9, byte_payload>::value == 2, "size mismatch");
static_assert(size<10, byte_payload>::value == 2, "size mismatch");
static_assert(size<11, byte_payload>::value == 2, "size mismatch");
static_assert(size<12, byte_payload>::value == 2, "size mismatch");
static_assert(size<13, byte_payload>::value == 2, "size mismatch");
static_assert(size<14, byte_payload>::value == 2, "size mismatch");
static_assert(size<15, byte_payload>::value == 3, "size mismatch");
static_assert(size<16, byte_payload>::value == 3, "size mismatch");
static_assert(size<17, byte_payload>::value == 3, "size mismatch");
static_assert(size<18, byte_payload>::value == 3, "size mismatch");
static_assert(size<19, byte_payload>::value == 3, "size mismatch");
static_assert(size<20, byte_payload>::value == 3, "size mismatch");
static_assert(size<21, byte_payload>::value == 3, "size mismatch");
static_assert(size<22, byte_payload>::value == 4, "size mismatch");
static_assert(size<23, byte_payload>::value == 4, "size mismatch");
static_assert(size<24, byte_payload>::value == 4, "size mismatch");
static_assert(size<25, byte_payload>::value == 4, "size mismatch");
static_assert(size<26, byte_payload>::value == 4, "size mismatch");
static_assert(size<27, byte_payload>::value == 4, "size mismatch");
static_assert(size<28, byte_payload>::value == 4, "size mismatch");
static_assert(size<29, byte_payload>::value == 5, "size mismatch");
static_assert(size<30, byte_payload>::value == 5, "size mismatch");
static_assert(size<31, byte_payload>::value == 5, "size mismatch");
static_assert(size<32, byte_payload>::value == 5, "size mismatch");
static_assert(size<33, byte_payload>::value == 5, "size mismatch");
static_assert(size<34, byte_payload>::value == 5, "size mismatch");
static_assert(size<35, byte_payload>::value == 5, "size mismatch");
static_assert(size<36, byte_payload>::value == 6, "size mismatch");
static_assert(size<37, byte_payload>::value == 6, "size mismatch");
static_assert(size<38, byte_payload>::value == 6, "size mismatch");
static_assert(size<39, byte_payload>::value == 6, "size mismatch");
static_assert(size<40, byte_payload>::value == 6, "size mismatch");
static_assert(size<41, byte_payload>::value == 6, "size mismatch");
static_assert(size<42, byte_payload>::value == 6, "size mismatch");
static_assert(size<43, byte_payload>::value == 7, "size mismatch");
static_assert(size<44, byte_payload>::value == 7, "size mismatch");
static_assert(size<45, byte_payload>::value == 7, "size mismatch");
static_assert(size<46, byte_payload>::value == 7, "size mismatch");
static_assert(size<47, byte_payload>::value == 7, "size mismatch");
static_assert(size<48, byte_payload>::value == 7, "size mismatch");
static_assert(size<49, byte_payload>::value == 7, "size mismatch");
static_assert(size<50, byte_payload>::value == 8, "size mismatch");
static_assert(size<51, byte_payload>::value == 8, "size mismatch");
static_assert(size<52, byte_payload>::value == 8, "size mismatch");
static_assert(size<53, byte_payload>::value == 8, "size mismatch");
static_assert(size<54, byte_payload>::value == 8, "size mismatch");
static_assert(size<55, byte_payload>::value == 8, "size mismatch");
static_assert(size<56, byte_payload>::value == 8, "size mismatch");
static_assert(size<57, byte_payload>::value == 9, "size mismatch");
static_assert(size<58, byte_payload>::value == 9, "size mismatch");
static_assert(size<59, byte_payload>::value == 9, "size mismatch");
static_assert(size<60, byte_payload>::value == 9, "size mismatch");
static_assert(size<61, byte_payload>::value == 9, "size mismatch");
static_assert(size<62, byte_payload>::value == 9, "size mismatch");
static_assert(size<63, byte_payload>::value == 9, "size mismatch");
static_assert(size<64, byte_payload>::value == 10, "size mismatch");
static_assert(size<65, byte_payload>::value == 10, "size mismatch");
static_assert(size<66, byte_payload>::value == 10, "size mismatch");
static_assert(size<67, byte_payload>::value == 10, "size mismatch");
static_assert(size<68, byte_payload>::value == 10, "size mismatch");
static_assert(size<69, byte_payload>::value == 10, "size mismatch");
static_assert(size<70, byte_payload>::value == 10, "size mismatch");

// The same check for value widths that are whole multiples of `byte_size`
// (1 to 16 times). The expected counts again correspond to 8-bit units
// carrying 7 bits of payload each.
static_assert(size<byte_size * 1, byte_payload>::value == 2, "size mismatch");
static_assert(size<byte_size * 2, byte_payload>::value == 3, "size mismatch");
static_assert(size<byte_size * 3, byte_payload>::value == 4, "size mismatch");
static_assert(size<byte_size * 4, byte_payload>::value == 5, "size mismatch");
static_assert(size<byte_size * 5, byte_payload>::value == 6, "size mismatch");
static_assert(size<byte_size * 6, byte_payload>::value == 7, "size mismatch");
static_assert(size<byte_size * 7, byte_payload>::value == 8, "size mismatch");
static_assert(size<byte_size * 8, byte_payload>::value == 10, "size mismatch");
static_assert(size<byte_size * 9, byte_payload>::value == 11, "size mismatch");
static_assert(size<byte_size * 10, byte_payload>::value == 12, "size mismatch");
static_assert(size<byte_size * 11, byte_payload>::value == 13, "size mismatch");
static_assert(size<byte_size * 12, byte_payload>::value == 14, "size mismatch");
static_assert(size<byte_size * 13, byte_payload>::value == 15, "size mismatch");
static_assert(size<byte_size * 14, byte_payload>::value == 16, "size mismatch");
static_assert(size<byte_size * 15, byte_payload>::value == 18, "size mismatch");
static_assert(size<byte_size * 16, byte_payload>::value == 19, "size mismatch");
123 
124 template <typename TData>
125 struct data_traits {
126   using data_unit = TData;
127   using unsigned_unit = typename std::make_unsigned<data_unit>::type;
128 
129   static_assert(
130     std::is_integral<data_unit>::value,
131     "expected integral data unit"
132   );
133 
134   static_assert(
135     data_bits<unsigned_unit>::value == data_bits<data_unit>::value,
136     "unsupported data unit"
137   );
138 
139   static_assert(
140     data_bits<unsigned_unit>::value > 1,
141     "at least 2 bits of data needed"
142   );
143 
144   using data_size = data_bits<unsigned_unit>;
145 
146   using payload_size = std::integral_constant<
147     std::size_t, data_size::value - 1
148   >;
149 
150   using continuation_bit = std::integral_constant<
151     unsigned_unit, static_cast<unsigned_unit>(1) << payload_size::value
152   >;
153 
154   static_assert(
155     most_significant_bit<continuation_bit::value>::value == data_size::value,
156     "invalid continuation bit"
157   );
158 
159   static_assert(
160     pop_count<continuation_bit::value>::value == 1,
161     "invalid continuation bit"
162   );
163 
164   using filter_mask = std::integral_constant<
165     unsigned_unit,
166     static_cast<unsigned_unit>(~continuation_bit::value)
167   >;
168 
169   static_assert(
170     most_significant_bit<filter_mask::value>::value == payload_size::value,
171     "invalid filter mask"
172   );
173 
174   static_assert(
175     pop_count<filter_mask::value>::value == payload_size::value,
176     "invalid filter mask"
177   );
178 
179   // the maximum amount of TData units taken by the varint encoding of a type T
180   template <typename T>
181   class size_for {
182     static_assert(std::is_integral<T>::value, "expected an integral");
183     static_assert(
184       std::is_same<typename std::decay<T>::type, T>::value,
185       "plain type expected"
186     );
187 
188   public:
189     using type = size<data_bits<T>::value, payload_size::value>;
190   };
191 
fromdata_traits192   static unsigned_unit from(fast_pass<data_unit> value) {
193     return *reinterpret_cast<unsigned_unit const *>(std::addressof(value));
194   }
195 
todata_traits196   static data_unit to(fast_pass<unsigned_unit> value) {
197     return *reinterpret_cast<data_unit const *>(std::addressof(value));
198   }
199 };
200 
// Adapts a value of type `T` to the unsigned representation that gets encoded.
//
// This primary template handles unsigned integrals, which are encoded as they
// are: both `pre` (applied before encoding) and `post` (applied after
// decoding) hand the value back unchanged.
template <bool, typename T>
struct value_traits {
  static_assert(std::is_unsigned<T>::value, "implementation mismatch");

  // type seen by the user of the codec
  using external = T;
  // type the codec operates on - the very same type in the unsigned case
  using internal = T;

  static internal pre(external value) noexcept {
    return value;
  }

  static external post(internal value) noexcept {
    return value;
  }
};
211 
212 template <typename T>
213 struct value_traits<true, T> {
214   using external = T;
215   using internal = typename std::make_unsigned<external>::type;
216 
217   static_assert(std::is_signed<external>::value, "implementation mismatch");
218   static_assert(sizeof(external) == sizeof(internal), "invalid integral");
219 
220   static internal pre(fast_pass<external> value) noexcept {
221     auto ivalue = *reinterpret_cast<internal const *>(std::addressof(value));
222     return internal(ivalue << 1) | internal(value < 0 ? 1 : 0);
223   }
224 
225   static external post(internal value) noexcept {
226     auto const shift = (data_bits<internal>::value - 1);
227     value = internal(value >> 1) | internal((value & 1) << shift);
228     return *reinterpret_cast<external const *>(std::addressof(value));
229   }
230 };
231 
232 } // namespace varint_impl {
233 } // namespace detail {
234 
235 // TODO: DOCUMENT
236 template <typename T>
237 struct varint {
238   using type = T;
239 
240 private:
241   using shift_counter = smallest_fast_unsigned_integral<
242     most_significant_bit<data_bits<type>::value>::value
243   >;
244 
245   using value_traits = detail::varint_impl::value_traits<
246     std::is_signed<type>::value, type
247   >;
248   using internal = typename value_traits::internal;
249 
250 public:
251   // largest amount of `TData` written when encoding a value of type `type`
252   template <typename TData>
253   using max_size = typename detail::varint_impl::data_traits<TData>
254     ::template size_for<type>::type;
255 
256   // an automatically allocated buffer of `TData` able
257   // to hold any encoding of a value of type `type`
258   template <typename TData = char>
259   using automatic_buffer = std::array<TData, max_size<TData>::value>;
260 
261   struct encoder {
262     explicit encoder(fast_pass<type> value) noexcept:
263       value_(value_traits::pre(value))
264     {}
265 
266     // TODO: RETURN SIZE?
267     template <typename TOutputIterator>
268     TOutputIterator operator ()(
269       TOutputIterator begin,
270       TOutputIterator const end
271     ) noexcept {
272       using traits = detail::varint_impl::data_traits<
273         typename std::iterator_traits<TOutputIterator>::value_type
274       >;
275       using unsigned_unit = typename traits::unsigned_unit;
276 
277       for (; begin != end; std::advance(begin, 1)) {
278         unsigned_unit data = value_ & traits::filter_mask::value;
279 
280         value_ >>= traits::payload_size::value;
281 
282         if (value_) {
283           *begin = traits::to(data | traits::continuation_bit::value);
284         } else {
285           *begin = traits::to(data);
286           continued_ = false;
287           return std::next(begin);
288         }
289       }
290 
291       return begin;
292     }
293 
294     void reset(fast_pass<type> value) noexcept {
295       value_ = value;
296       continued_ = true;
297     }
298 
299     // returns true if the encoding is done, false if it needs more data
300     bool done() const noexcept { return !continued_; }
301 
302     // returns false if the encoding is done, true if it needs more data
303     bool operator !() const noexcept { return continued_; }
304 
305     // returns true if the encoding is done, false if it needs more data
306     explicit operator bool() const noexcept { return !continued_; }
307 
308   private:
309     internal value_;
310     bool continued_ = true;
311   };
312 
313   // returns the iterator `i` to the first unused element of the output
314   // buffer such that [out, i) represents the data that have been encoded
315   // buffer must be able to fit at least `max_size<decltype(*out)>` elements
316   // TODO: RETURN SIZE?
317   template <typename TOutputIterator>
318   static TOutputIterator encode(
319     type value,
320     TOutputIterator out
321   ) noexcept {
322     using traits = detail::varint_impl::data_traits<
323       typename std::iterator_traits<TOutputIterator>::value_type
324     >;
325     using unsigned_unit = typename traits::unsigned_unit;
326 
327     for (internal x = value_traits::pre(value); ; std::advance(out, 1)) {
328       unsigned_unit data = x & traits::filter_mask::value;
329 
330       x >>= traits::payload_size::value;
331 
332       if (x) {
333         *out = traits::to(data | traits::continuation_bit::value);
334       } else {
335         *out = traits::to(data);
336         return std::next(out);
337       }
338     }
339   }
340 
341   struct decoder {
342     // returns the iterator `i` to the first unused element such
343     // that [begin, i) represents the data that have been decoded
344     template <typename TInputIterator>
345     TInputIterator operator ()(
346       TInputIterator begin,
347       TInputIterator const end
348     ) noexcept {
349       using traits = detail::varint_impl::data_traits<
350         typename std::iterator_traits<TInputIterator>::value_type
351       >;
352 
353       for (; continuation_ && begin != end; std::advance(begin, 1)) {
354         assert(continuation_);
355 
356         value_ |= static_cast<internal>(
357           traits::from(*begin) & traits::filter_mask::value
358         ) << shift_;
359         shift_ += traits::payload_size::value;
360 
361         continuation_ = traits::from(*begin) & traits::continuation_bit::value;
362       }
363 
364       return begin;
365     }
366 
367     // resets the internal structure of this decoder as if
368     // no data had been fed to it
369     void reset() noexcept {
370       value_ = 0;
371       shift_ = 0;
372       continuation_ = true;
373     }
374 
375     // the value decoded so far
376     type value() const noexcept { return value_traits::post(value_); }
377 
378     // returns true if the decoding is done, false if it needs more data
379     bool done() const noexcept { return !continuation_; }
380 
381     // returns false if the decoding is done, true if it needs more data
382     bool operator !() const noexcept { return continuation_; }
383 
384     // returns true if the decoding is done, false if it needs more data
385     explicit operator bool() const noexcept { return !continuation_; }
386 
387   private:
388     internal value_ = 0;
389     shift_counter shift_ = 0;
390     bool continuation_ = true;
391   };
392 
393   template <typename TInputIterator>
394   static std::pair<type, bool> decode(
395     TInputIterator begin,
396     TInputIterator end
397   ) noexcept {
398     decoder decode;
399     decode(begin, end);
400     return std::make_pair(decode.value(), decode.done());
401   }
402 
403   // TODO: bike-shed
404   template <typename TInputIterator>
405   static std::pair<type, bool> tracking_decode(
406     TInputIterator &begin,
407     TInputIterator end
408   ) noexcept {
409     decoder decode;
410     begin = decode(begin, end);
411     return std::make_pair(decode.value(), decode.done());
412   }
413 };
414 
415 } // namespace fatal {
416 
417 #endif // FATAL_INCLUDE_fatal_codec_varint_h
418