1 /*
2  * This file Copyright (C) 2008-2014 Mnemosyne LLC
3  *
4  * It may be used under the GNU GPL versions 2 or 3
5  * or any future license endorsed by Mnemosyne LLC.
6  *
7  */
8 
9 #include <string.h> /* memset */
10 
11 #include "transmission.h"
12 #include "bitfield.h"
13 #include "tr-assert.h"
14 #include "utils.h" /* tr_new0() */
15 
16 tr_bitfield const TR_BITFIELD_INIT =
17 {
18     .bits = NULL,
19     .alloc_count = 0,
20     .bit_count = 0,
21     .true_count = 0,
22     .have_all_hint = false,
23     .have_none_hint = false
24 };
25 
26 /****
27 *****
28 ****/
29 
/*
 * Byte popcount lookup table: trueBitCount[v] is the number of 1 bits in the
 * byte value v (0..255). Used to count set bits one byte at a time.
 */
static int8_t const trueBitCount[256] =
{
    0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
    1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
    1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
    2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
    1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
    2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
    2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
    3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
    1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
    2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
    2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
    3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
    2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
    3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
    3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
    4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8
};
49 
/* Returns the total number of set bits across every allocated byte of b->bits. */
static size_t countArray(tr_bitfield const* b)
{
    size_t total = 0;

    for (size_t byte = 0; byte < b->alloc_count; ++byte)
    {
        total += trueBitCount[b->bits[byte]];
    }

    return total;
}
62 
/*
 * Counts the set bits in the half-open bit range [begin, end) by reading
 * b->bits directly (bits are packed MSB-first within each byte).
 * Bits that lie past the allocated array are treated as 0.
 * Returns 0 for an empty range (begin >= end) or when bit_count is 0.
 */
static size_t countRange(tr_bitfield const* b, size_t begin, size_t end)
{
    /* An empty range holds no bits; checking first also keeps `end - 1`
     * below from underflowing when end == 0. */
    if (begin >= end || b->bit_count == 0)
    {
        return 0;
    }

    size_t const first_byte = begin >> 3U;

    if (first_byte >= b->alloc_count)
    {
        return 0;
    }

    TR_ASSERT(b->bits != NULL);

    size_t const last_byte = (end - 1) >> 3U;
    /* bits of first_byte that precede `begin` */
    unsigned const first_shift = (unsigned)(begin - first_byte * 8);
    /* bits of last_byte that come at or after `end` */
    unsigned const last_shift = (unsigned)((last_byte + 1) * 8 - end);
    size_t ret = 0;

    if (first_byte == last_byte)
    {
        uint8_t val = b->bits[first_byte];

        /* drop the high-order bits before `begin` ... */
        val <<= first_shift;
        val >>= first_shift;
        /* ... and the low-order bits at or after `end` */
        val >>= last_shift;
        val <<= last_shift;

        ret += trueBitCount[val];
    }
    else
    {
        uint8_t val;

        /* first byte: keep only the bits from `begin` onward */
        val = b->bits[first_byte];
        val <<= first_shift;
        val >>= first_shift;
        ret += trueBitCount[val];

        /* whole bytes strictly between first_byte and last_byte */
        size_t const walk_end = MIN(b->alloc_count, last_byte);

        for (size_t i = first_byte + 1; i < walk_end; ++i)
        {
            ret += trueBitCount[b->bits[i]];
        }

        /* last byte: may lie past the array, in which case it counts as 0 */
        if (last_byte < b->alloc_count)
        {
            val = b->bits[last_byte];
            val >>= last_shift;
            val <<= last_shift;
            ret += trueBitCount[val];
        }
    }

    /* The range holds end - begin bits (begin < end here, so no underflow). */
    TR_ASSERT(ret <= end - begin);
    return ret;
}
128 
/*
 * Counts the set bits in [begin, end). The "have all" and "have none" states
 * are answered directly; otherwise the bit array is scanned.
 */
size_t tr_bitfieldCountRange(tr_bitfield const* b, size_t begin, size_t end)
{
    size_t count;

    if (tr_bitfieldHasAll(b))
    {
        count = end - begin;
    }
    else if (tr_bitfieldHasNone(b))
    {
        count = 0;
    }
    else
    {
        count = countRange(b, begin, end);
    }

    return count;
}
143 
/*
 * Returns true if bit `n` is set. In the "have all" state every bit reads as
 * set; otherwise bits beyond the allocated array read as unset.
 */
bool tr_bitfieldHas(tr_bitfield const* b, size_t n)
{
    if (tr_bitfieldHasAll(b))
    {
        return true;
    }

    size_t const byte_index = n >> 3U;

    if (tr_bitfieldHasNone(b) || byte_index >= b->alloc_count)
    {
        return false;
    }

    /* bits are packed MSB-first: bit 0 of a byte is its 0x80 bit */
    uint8_t const mask = (uint8_t)(0x80U >> (n & 7U));
    return (b->bits[byte_index] & mask) != 0;
}
163 
164 /***
165 ****
166 ***/
167 
#ifdef TR_ENABLE_ASSERTS

/*
 * Debug-only consistency check of a bitfield's internal invariants.
 * Each invariant is enforced with TR_ASSERT. The function itself always
 * returns true so callers can write TR_ASSERT(tr_bitfieldIsValid(b)).
 * Invariants checked:
 *   - the array pointer and its byte count are either both set or both empty;
 *   - when an array exists, true_count equals the number of set bits in it.
 */
static bool tr_bitfieldIsValid(tr_bitfield const* b)
{
    TR_ASSERT(b != NULL);
    TR_ASSERT((b->alloc_count == 0) == (b->bits == NULL));
    TR_ASSERT(b->bits == NULL || b->true_count == countArray(b));

    return true;
}

#endif
180 
/* Returns the cached number of set bits (true_count). */
size_t tr_bitfieldCountTrueBits(tr_bitfield const* b)
{
    TR_ASSERT(tr_bitfieldIsValid(b));

    size_t const true_bits = b->true_count;
    return true_bits;
}
187 
/* Returns the number of bytes needed to hold `bit_count` bits, rounding up.
 * Written with divide/modulo so it cannot overflow even for SIZE_MAX. */
static size_t get_bytes_needed(size_t bit_count)
{
    size_t const whole_bytes = bit_count / 8;
    size_t const partial_byte = (bit_count % 8 != 0) ? 1 : 0;

    return whole_bytes + partial_byte;
}
192 
/*
 * Sets the first `bit_count` bits of `array` to 1 (MSB-first) and clears the
 * spare low-order bits of the final byte. Writes exactly
 * get_bytes_needed(bit_count) bytes; writes nothing when bit_count is 0.
 */
static void set_all_true(uint8_t* array, size_t bit_count)
{
    size_t const byte_count = get_bytes_needed(bit_count);

    if (byte_count == 0)
    {
        return;
    }

    size_t const last = byte_count - 1;
    memset(array, 0xFF, last);

    /* bits of the final byte that lie past bit_count must stay 0 */
    size_t const spare_bits = byte_count * 8 - bit_count;
    array[last] = (uint8_t)(0xFF << spare_bits);
}
205 
/*
 * Returns a newly allocated copy of the bitfield as a packed (MSB-first) byte
 * array of exactly get_bytes_needed(b->bit_count) bytes, and stores that
 * length in *byte_count. The caller owns the returned buffer.
 *
 * - If b has an array, its bytes are copied (never more than the output holds).
 * - Otherwise, in the "have all" state, the output is filled with ones.
 * - Otherwise the output stays zero-filled as allocated by tr_new0().
 */
void* tr_bitfieldGetRaw(tr_bitfield const* b, size_t* byte_count)
{
    TR_ASSERT(b->bit_count > 0);

    size_t const n = get_bytes_needed(b->bit_count);
    uint8_t* bits = tr_new0(uint8_t, n);

    if (b->alloc_count != 0)
    {
        TR_ASSERT(b->alloc_count <= n);

        /* An unbounded tr_bitfieldSetRaw() can leave alloc_count larger than n.
         * Clamp the copy so release builds (where TR_ASSERT is inert) cannot
         * overflow the n-byte output buffer. */
        size_t const to_copy = b->alloc_count < n ? b->alloc_count : n;
        memcpy(bits, b->bits, to_copy);
    }
    else if (tr_bitfieldHasAll(b))
    {
        set_all_true(bits, b->bit_count);
    }

    *byte_count = n;
    return bits;
}
226 
/*
 * Grows b->bits (it never shrinks) so that at least `n` bits are addressable.
 * Newly added bytes are zeroed.
 *
 * In the "have all" state the array is also sized to cover true_count bits.
 * After growing, the first true_count bits are set, so the materialized array
 * agrees with that state.
 */
static void tr_bitfieldEnsureBitsAlloced(tr_bitfield* b, size_t n)
{
    bool const has_all = tr_bitfieldHasAll(b);
    size_t const wanted_bits = has_all ? MAX(n, b->true_count) : n;
    size_t const bytes_needed = get_bytes_needed(wanted_bits);

    if (b->alloc_count >= bytes_needed)
    {
        return;
    }

    size_t const old_count = b->alloc_count;

    b->bits = tr_renew(uint8_t, b->bits, bytes_needed);
    memset(b->bits + old_count, 0, bytes_needed - old_count);
    b->alloc_count = bytes_needed;

    if (has_all)
    {
        set_all_true(b->bits, b->true_count);
    }
}
253 
/*
 * Ensures the zero-based bit index `nth` is backed by storage.
 * Returns false without allocating when nth + 1 would overflow size_t.
 */
static bool tr_bitfieldEnsureNthBitAlloced(tr_bitfield* b, size_t nth)
{
    /* indices are zero-based, so nth + 1 bits are needed to hold bit nth */
    bool const representable = nth != SIZE_MAX;

    if (representable)
    {
        tr_bitfieldEnsureBitsAlloced(b, nth + 1);
    }

    return representable;
}
265 
/* Releases the bit array and resets it to the empty (NULL, 0) state.
 * true_count and the have-all/have-none hints are left untouched. */
static void tr_bitfieldFreeArray(tr_bitfield* b)
{
    void* const old_bits = b->bits;

    b->alloc_count = 0;
    b->bits = NULL;

    tr_free(old_bits);
}
272 
/*
 * Records `n` as the number of set bits. If the bitfield then reports
 * "have all" or "have none", the bit array is redundant and is released.
 */
static void tr_bitfieldSetTrueCount(tr_bitfield* b, size_t n)
{
    TR_ASSERT(b->bit_count == 0 || n <= b->bit_count);

    b->true_count = n;

    bool const is_uniform = tr_bitfieldHasAll(b) || tr_bitfieldHasNone(b);

    if (is_uniform)
    {
        tr_bitfieldFreeArray(b);
    }

    TR_ASSERT(tr_bitfieldIsValid(b));
}
286 
/* Recounts the set bits in the array and stores the result as true_count. */
static void tr_bitfieldRebuildTrueCount(tr_bitfield* b)
{
    size_t const counted = countArray(b);

    tr_bitfieldSetTrueCount(b, counted);
}
291 
/* Adds `i` to the cached set-bit count. */
static void tr_bitfieldIncTrueCount(tr_bitfield* b, size_t i)
{
    TR_ASSERT(b->bit_count == 0 || i <= b->bit_count);
    TR_ASSERT(b->bit_count == 0 || b->true_count <= b->bit_count - i);

    size_t const updated = b->true_count + i;
    tr_bitfieldSetTrueCount(b, updated);
}
299 
/* Subtracts `i` from the cached set-bit count. */
static void tr_bitfieldDecTrueCount(tr_bitfield* b, size_t i)
{
    TR_ASSERT(b->bit_count == 0 || i <= b->bit_count);
    TR_ASSERT(b->bit_count == 0 || b->true_count >= i);

    size_t const updated = b->true_count - i;
    tr_bitfieldSetTrueCount(b, updated);
}
307 
308 /****
309 *****
310 ****/
311 
/*
 * Initializes `b` as a bitfield of `bit_count` bits with nothing set.
 * No array is allocated and neither hint is raised.
 */
void tr_bitfieldConstruct(tr_bitfield* b, size_t bit_count)
{
    b->bits = NULL;
    b->alloc_count = 0;
    b->bit_count = bit_count;
    b->true_count = 0;
    b->have_none_hint = false;
    b->have_all_hint = false;

    TR_ASSERT(tr_bitfieldIsValid(b));
}
323 
/* Puts `b` into the "have none" state: array released, zero set bits,
 * have_none_hint raised and have_all_hint cleared. */
void tr_bitfieldSetHasNone(tr_bitfield* b)
{
    tr_bitfieldFreeArray(b);

    b->have_none_hint = true;
    b->have_all_hint = false;
    b->true_count = 0;

    TR_ASSERT(tr_bitfieldIsValid(b));
}
333 
/* Puts `b` into the "have all" state: array released, true_count set to
 * bit_count, have_all_hint raised and have_none_hint cleared. */
void tr_bitfieldSetHasAll(tr_bitfield* b)
{
    tr_bitfieldFreeArray(b);

    b->have_all_hint = true;
    b->have_none_hint = false;
    b->true_count = b->bit_count;

    TR_ASSERT(tr_bitfieldIsValid(b));
}
343 
/*
 * Makes `b` mirror `src`. src's "have all" / "have none" states are copied as
 * states. Otherwise src's bytes are copied with bounded=true, i.e. truncated
 * to b's own bit_count.
 */
void tr_bitfieldSetFromBitfield(tr_bitfield* b, tr_bitfield const* src)
{
    if (tr_bitfieldHasAll(src))
    {
        tr_bitfieldSetHasAll(b);
        return;
    }

    if (tr_bitfieldHasNone(src))
    {
        tr_bitfieldSetHasNone(b);
        return;
    }

    tr_bitfieldSetRaw(b, src->bits, src->alloc_count, true);
}
359 
/*
 * Replaces b's contents with a copy of the first `byte_count` bytes of `bits`
 * (packed MSB-first) and recomputes true_count from them.
 *
 * When `bounded` is true:
 * - the copy is truncated to get_bytes_needed(b->bit_count) bytes;
 * - any spare bits past bit_count in the final byte are cleared.
 * Input shorter than that is accepted; the missing trailing bits read as 0.
 */
void tr_bitfieldSetRaw(tr_bitfield* b, void const* bits, size_t byte_count, bool bounded)
{
    tr_bitfieldFreeArray(b);
    b->true_count = 0;

    if (bounded)
    {
        byte_count = MIN(byte_count, get_bytes_needed(b->bit_count));
    }

    b->bits = tr_memdup(bits, byte_count);
    b->alloc_count = byte_count;

    /* Ensure the excess bits are set to '0'.
     * Only the byte that straddles bit_count has spare bits. That byte exists
     * only if the input covered all get_bytes_needed(bit_count) bytes. Testing
     * this explicitly avoids the previous `byte_count * 8 - bit_count`, which
     * went negative for short input and produced an undefined negative shift. */
    size_t const partial_bits = b->bit_count & 7U;

    if (bounded && partial_bits != 0 && byte_count == get_bytes_needed(b->bit_count))
    {
        size_t const excess_bit_count = 8 - partial_bits;

        b->bits[b->alloc_count - 1] &= (uint8_t)(0xff << excess_bit_count);
    }

    tr_bitfieldRebuildTrueCount(b);
}
389 
/*
 * Rebuilds `b` from an array of `n` booleans: bit i is set iff flags[i] is
 * true. Previous contents, including a "have all" state, are discarded.
 */
void tr_bitfieldSetFromFlags(tr_bitfield* b, bool const* flags, size_t n)
{
    size_t trueCount = 0;

    tr_bitfieldFreeArray(b);

    /* Reset the count before allocating. If `b` was in the "have all" state,
     * tr_bitfieldEnsureBitsAlloced() would otherwise pre-fill the new array
     * with true_count ones. False flags never clear those bits, so the array
     * would disagree with the final true_count. With true_count == 0 the new
     * array starts zeroed. */
    b->true_count = 0;
    tr_bitfieldEnsureBitsAlloced(b, n);

    for (size_t i = 0; i < n; ++i)
    {
        if (flags[i])
        {
            ++trueCount;
            b->bits[i >> 3U] |= (0x80 >> (i & 7U));
        }
    }

    tr_bitfieldSetTrueCount(b, trueCount);
}
408 
/* Sets bit `nth`. Does nothing if the bit already reads as set (including
 * the "have all" state) or if storage for it cannot be ensured. */
void tr_bitfieldAdd(tr_bitfield* b, size_t nth)
{
    if (tr_bitfieldHas(b, nth))
    {
        return;
    }

    if (!tr_bitfieldEnsureNthBitAlloced(b, nth))
    {
        return;
    }

    /* bits are packed MSB-first within each byte */
    uint8_t const mask = (uint8_t)(0x80U >> (nth & 7U));
    b->bits[nth >> 3U] |= mask;
    tr_bitfieldIncTrueCount(b, 1);
}
417 
418 /* Sets bit range [begin, end) to 1 */
tr_bitfieldAddRange(tr_bitfield * b,size_t begin,size_t end)419 void tr_bitfieldAddRange(tr_bitfield* b, size_t begin, size_t end)
420 {
421     size_t sb;
422     size_t eb;
423     unsigned char sm;
424     unsigned char em;
425     size_t const diff = (end - begin) - tr_bitfieldCountRange(b, begin, end);
426 
427     if (diff == 0)
428     {
429         return;
430     }
431 
432     end--;
433 
434     if (end >= b->bit_count || begin > end)
435     {
436         return;
437     }
438 
439     sb = begin >> 3;
440     sm = ~(0xff << (8 - (begin & 7)));
441     eb = end >> 3;
442     em = 0xff << (7 - (end & 7));
443 
444     if (!tr_bitfieldEnsureNthBitAlloced(b, end))
445     {
446         return;
447     }
448 
449     if (sb == eb)
450     {
451         b->bits[sb] |= sm & em;
452     }
453     else
454     {
455         b->bits[sb] |= sm;
456         b->bits[eb] |= em;
457 
458         if (++sb < eb)
459         {
460             memset(b->bits + sb, 0xff, eb - sb);
461         }
462     }
463 
464     tr_bitfieldIncTrueCount(b, diff);
465 }
466 
/*
 * Clears bit `nth` if it reads as set. In the "have all" state, the storage
 * step first materializes an all-ones array so that one bit can be cleared.
 */
void tr_bitfieldRem(tr_bitfield* b, size_t nth)
{
    TR_ASSERT(tr_bitfieldIsValid(b));

    if (!tr_bitfieldHas(b, nth) || !tr_bitfieldEnsureNthBitAlloced(b, nth))
    {
        return;
    }

    /* bits are packed MSB-first within each byte */
    uint8_t const bit = (uint8_t)(0x80U >> (nth & 7U));
    b->bits[nth >> 3U] &= (uint8_t)~bit;
    tr_bitfieldDecTrueCount(b, 1);
}
477 
478 /* Clears bit range [begin, end) to 0 */
tr_bitfieldRemRange(tr_bitfield * b,size_t begin,size_t end)479 void tr_bitfieldRemRange(tr_bitfield* b, size_t begin, size_t end)
480 {
481     size_t sb;
482     size_t eb;
483     unsigned char sm;
484     unsigned char em;
485     size_t const diff = tr_bitfieldCountRange(b, begin, end);
486 
487     if (diff == 0)
488     {
489         return;
490     }
491 
492     end--;
493 
494     if (end >= b->bit_count || begin > end)
495     {
496         return;
497     }
498 
499     sb = begin >> 3;
500     sm = 0xff << (8 - (begin & 7));
501     eb = end >> 3;
502     em = ~(0xff << (7 - (end & 7)));
503 
504     if (!tr_bitfieldEnsureNthBitAlloced(b, end))
505     {
506         return;
507     }
508 
509     if (sb == eb)
510     {
511         b->bits[sb] &= sm | em;
512     }
513     else
514     {
515         b->bits[sb] &= sm;
516         b->bits[eb] &= em;
517 
518         if (++sb < eb)
519         {
520             memset(b->bits + sb, 0, eb - sb);
521         }
522     }
523 
524     tr_bitfieldDecTrueCount(b, diff);
525 }
526