/*
   Copyright (c) 2019 - 2021, Ilan Schnell; All Rights Reserved
   bitarray is published under the PSF license.

   This file contains the C implementation of some useful utility functions.

   Author: Ilan Schnell
*/
9 
10 #define PY_SSIZE_T_CLEAN
11 #include "Python.h"
12 #include "pythoncapi_compat.h"
13 #include "bitarray.h"
14 
15 #define IS_LE(a)  ((a)->endian == ENDIAN_LITTLE)
16 #define IS_BE(a)  ((a)->endian == ENDIAN_BIG)
17 
18 /* set using the Python module function _set_bato() */
19 static PyObject *bitarray_type_obj = NULL;
20 
21 /* Return 0 if obj is bitarray.  If not, return -1 and set an exception. */
22 static int
ensure_bitarray(PyObject * obj)23 ensure_bitarray(PyObject *obj)
24 {
25     int t;
26 
27     if (bitarray_type_obj == NULL)
28         Py_FatalError("bitarray_type_obj not set");
29     t = PyObject_IsInstance(obj, bitarray_type_obj);
30     if (t < 0)
31         return -1;
32     if (t == 0) {
33         PyErr_Format(PyExc_TypeError, "bitarray expected, not %s",
34                      Py_TYPE(obj)->tp_name);
35         return -1;
36     }
37     return 0;
38 }
39 
40 /* ensure object is a bitarray of given length */
41 static int
ensure_ba_of_length(PyObject * a,const Py_ssize_t n)42 ensure_ba_of_length(PyObject *a, const Py_ssize_t n)
43 {
44     if (ensure_bitarray(a) < 0)
45         return -1;
46     if (((bitarrayobject *) a)->nbits != n) {
47         PyErr_SetString(PyExc_ValueError, "size mismatch");
48         return -1;
49     }
50     return 0;
51 }
52 
53 /* ------------------------------- count_n ----------------------------- */
54 
55 /* return the smallest index i for which a.count(1, 0, i) == n, or when
56    n exceeds the total count return -1  */
57 static Py_ssize_t
count_to_n(bitarrayobject * a,Py_ssize_t n)58 count_to_n(bitarrayobject *a, Py_ssize_t n)
59 {
60     const Py_ssize_t nbits = a->nbits;
61     Py_ssize_t i = 0;        /* index */
62     Py_ssize_t j = 0;        /* total count up to index */
63     Py_ssize_t block_start, block_stop, k, m;
64 
65     assert(0 <= n && n <= nbits);
66     if (n == 0)
67         return 0;
68 
69 #define BLOCK_BITS  8192
70     /* by counting big blocks we save comparisons */
71     while (i + BLOCK_BITS < nbits) {
72         m = 0;
73         assert(i % 8 == 0);
74         block_start = i >> 3;
75         block_stop = block_start + (BLOCK_BITS >> 3);
76         assert(block_stop <= Py_SIZE(a));
77         for (k = block_start; k < block_stop; k++)
78             m += bitcount_lookup[(unsigned char) a->ob_item[k]];
79         if (j + m >= n)
80             break;
81         j += m;
82         i += BLOCK_BITS;
83     }
84 #undef BLOCK_BITS
85 
86     while (i + 8 < nbits) {
87         k = i >> 3;
88         assert(k < Py_SIZE(a));
89         m = bitcount_lookup[(unsigned char) a->ob_item[k]];
90         if (j + m >= n)
91             break;
92         j += m;
93         i += 8;
94     }
95 
96     while (j < n && i < nbits ) {
97         j += getbit(a, i);
98         i++;
99     }
100     if (j < n)
101         return -1;
102 
103     return i;
104 }
105 
106 static PyObject *
count_n(PyObject * module,PyObject * args)107 count_n(PyObject *module, PyObject *args)
108 {
109     PyObject *a;
110     Py_ssize_t n, i;
111 
112     if (!PyArg_ParseTuple(args, "On:count_n", &a, &n))
113         return NULL;
114     if (ensure_bitarray(a) < 0)
115         return NULL;
116 
117     if (n < 0) {
118         PyErr_SetString(PyExc_ValueError, "non-negative integer expected");
119         return NULL;
120     }
121 #define aa  ((bitarrayobject *) a)
122     if (n > aa->nbits)  {
123         PyErr_SetString(PyExc_ValueError, "n larger than bitarray size");
124         return NULL;
125     }
126     i = count_to_n(aa, n);        /* do actual work here */
127 #undef aa
128     if (i < 0) {
129         PyErr_SetString(PyExc_ValueError, "n exceeds total count");
130         return NULL;
131     }
132     return PyLong_FromSsize_t(i);
133 }
134 
135 PyDoc_STRVAR(count_n_doc,
136 "count_n(a, n, /) -> int\n\
137 \n\
138 Return lowest index `i` for which `a[:i].count() == n`.\n\
139 Raises `ValueError`, when n exceeds total count (`a.count()`).");
140 
141 /* ----------------------------- right index --------------------------- */
142 
143 /* return index of highest occurrence of vi in self[a:b], -1 when not found */
144 static Py_ssize_t
find_last(bitarrayobject * self,int vi,Py_ssize_t a,Py_ssize_t b)145 find_last(bitarrayobject *self, int vi, Py_ssize_t a, Py_ssize_t b)
146 {
147     const Py_ssize_t n = b - a;
148     Py_ssize_t res, i;
149 
150     assert(0 <= a && a <= self->nbits);
151     assert(0 <= b && b <= self->nbits);
152     assert(0 <= vi && vi <= 1);
153     if (n <= 0)
154         return -1;
155 
156     /* the logic here is the same as in find_bit() in _bitarray.c */
157 #ifdef PY_UINT64_T
158     if (n > 64) {
159         const Py_ssize_t word_a = (a + 63) / 64;
160         const Py_ssize_t word_b = b / 64;
161         const PY_UINT64_T w = vi ? 0 : ~0;
162 
163         if ((res = find_last(self, vi, 64 * word_b, b)) >= 0)
164             return res;
165 
166         for (i = word_b - 1; i >= word_a; i--) {  /* skip uint64 words */
167             if (w ^ ((PY_UINT64_T *) self->ob_item)[i])
168                 return find_last(self, vi, 64 * i, 64 * i + 64);
169         }
170         return find_last(self, vi, a, 64 * word_a);
171     }
172 #endif
173     if (n > 8) {
174         const Py_ssize_t byte_a = BYTES(a);
175         const Py_ssize_t byte_b = b / 8;
176         const char c = vi ? 0 : ~0;
177 
178         if ((res = find_last(self, vi, BITS(byte_b), b)) >= 0)
179             return res;
180 
181         for (i = byte_b - 1; i >= byte_a; i--) {  /* skip bytes */
182             assert_byte_in_range(self, i);
183             if (c ^ self->ob_item[i])
184                 return find_last(self, vi, BITS(i), BITS(i) + 8);
185         }
186         return find_last(self, vi, a, BITS(byte_a));
187     }
188     assert(n <= 8);
189     for (i = b - 1; i >= a; i--) {
190         if (getbit(self, i) == vi)
191             return i;
192     }
193     return -1;
194 }
195 
196 static PyObject *
r_index(PyObject * module,PyObject * args)197 r_index(PyObject *module, PyObject *args)
198 {
199     PyObject *value = Py_True, *a;
200     Py_ssize_t start = 0, stop = PY_SSIZE_T_MAX, res;
201     int vi;
202 
203     if (!PyArg_ParseTuple(args, "O|Onn:rindex", &a, &value, &start, &stop))
204         return NULL;
205     if (ensure_bitarray(a) < 0)
206         return NULL;
207     if ((vi = pybit_as_int(value)) < 0)
208         return NULL;
209 
210 #define aa  ((bitarrayobject *) a)
211     normalize_index(aa->nbits, &start);
212     normalize_index(aa->nbits, &stop);
213     res = find_last(aa, vi, start, stop);
214 #undef aa
215     if (res < 0)
216         return PyErr_Format(PyExc_ValueError, "%d not in bitarray", vi);
217 
218     return PyLong_FromSsize_t(res);
219 }
220 
221 PyDoc_STRVAR(rindex_doc,
222 "rindex(bitarray, value=1, start=0, stop=<end of array>, /) -> int\n\
223 \n\
224 Return the rightmost (highest) index of `value` in bitarray.\n\
225 Raises `ValueError` if the value is not present.");
226 
227 /* --------------------------- unary functions ------------------------- */
228 
229 static PyObject *
parity(PyObject * module,PyObject * a)230 parity(PyObject *module, PyObject *a)
231 {
232     unsigned char par = 0;
233     Py_ssize_t i;
234 
235     if (ensure_bitarray(a) < 0)
236         return NULL;
237 
238 #define aa  ((bitarrayobject *) a)
239     for (i = 0; i < aa->nbits / 8; i++)
240         par ^= aa->ob_item[i];
241     if (aa->nbits % 8)
242         par ^= zeroed_last_byte(aa);
243 #undef aa
244 
245     return PyLong_FromLong((long) bitcount_lookup[par] % 2);
246 }
247 
248 PyDoc_STRVAR(parity_doc,
249 "parity(a, /) -> int\n\
250 \n\
251 Return the parity of bitarray `a`.\n\
252 This is equivalent to `a.count() % 2` (but more efficient).");
253 
254 /* --------------------------- binary functions ------------------------ */
255 
256 enum kernel_type {
257     KERN_cand,     /* count bitwise and -> int */
258     KERN_cor,      /* count bitwise or -> int */
259     KERN_cxor,     /* count bitwise xor -> int */
260     KERN_subset,   /* is subset -> bool */
261 };
262 
263 static PyObject *
binary_function(PyObject * args,enum kernel_type kern,const char * format)264 binary_function(PyObject *args, enum kernel_type kern, const char *format)
265 {
266     Py_ssize_t res = 0, s, i;
267     PyObject *a, *b;
268     unsigned char c;
269     int r;
270 
271     if (!PyArg_ParseTuple(args, format, &a, &b))
272         return NULL;
273     if (ensure_bitarray(a) < 0 || ensure_bitarray(b) < 0)
274         return NULL;
275 
276 #define aa  ((bitarrayobject *) a)
277 #define bb  ((bitarrayobject *) b)
278     if (aa->nbits != bb->nbits) {
279         PyErr_SetString(PyExc_ValueError,
280                         "bitarrays of equal length expected");
281         return NULL;
282     }
283     if (aa->endian != bb->endian) {
284         PyErr_SetString(PyExc_ValueError,
285                         "bitarrays of equal endianness expected");
286         return NULL;
287     }
288     s = aa->nbits / 8;       /* number of whole bytes in buffer */
289     r = aa->nbits % 8;       /* remaining bits  */
290 
291     switch (kern) {
292     case KERN_cand:
293         for (i = 0; i < s; i++) {
294             c = aa->ob_item[i] & bb->ob_item[i];
295             res += bitcount_lookup[c];
296         }
297         if (r) {
298             c = zeroed_last_byte(aa) & zeroed_last_byte(bb);
299             res += bitcount_lookup[c];
300         }
301         break;
302 
303     case KERN_cor:
304         for (i = 0; i < s; i++) {
305             c = aa->ob_item[i] | bb->ob_item[i];
306             res += bitcount_lookup[c];
307         }
308         if (r) {
309             c = zeroed_last_byte(aa) | zeroed_last_byte(bb);
310             res += bitcount_lookup[c];
311         }
312         break;
313 
314     case KERN_cxor:
315         for (i = 0; i < s; i++) {
316             c = aa->ob_item[i] ^ bb->ob_item[i];
317             res += bitcount_lookup[c];
318         }
319         if (r) {
320             c = zeroed_last_byte(aa) ^ zeroed_last_byte(bb);
321             res += bitcount_lookup[c];
322         }
323         break;
324 
325     case KERN_subset:
326         for (i = 0; i < s; i++) {
327             if ((aa->ob_item[i] & bb->ob_item[i]) != aa->ob_item[i])
328                 Py_RETURN_FALSE;
329         }
330         if (r) {
331             if ((zeroed_last_byte(aa) & zeroed_last_byte(bb)) !=
332                  zeroed_last_byte(aa))
333                 Py_RETURN_FALSE;
334         }
335         Py_RETURN_TRUE;
336 
337     default:  /* cannot happen */
338         return NULL;
339     }
340 #undef aa
341 #undef bb
342     return PyLong_FromSsize_t(res);
343 }
344 
345 #define COUNT_FUNC(oper, ochar)                                         \
346 static PyObject *                                                       \
347 count_ ## oper (PyObject *module, PyObject *args)                       \
348 {                                                                       \
349     return binary_function(args, KERN_c ## oper, "OO:count_" #oper);    \
350 }                                                                       \
351 PyDoc_STRVAR(count_ ## oper ## _doc,                                    \
352 "count_" #oper "(a, b, /) -> int\n\
353 \n\
354 Return `(a " ochar " b).count()` in a memory efficient manner,\n\
355 as no intermediate bitarray object gets created.")
356 
357 COUNT_FUNC(and, "&");           /* count_and */
358 COUNT_FUNC(or,  "|");           /* count_or  */
359 COUNT_FUNC(xor, "^");           /* count_xor */
360 
361 
362 static PyObject *
subset(PyObject * module,PyObject * args)363 subset(PyObject *module, PyObject *args)
364 {
365     return binary_function(args, KERN_subset, "OO:subset");
366 }
367 
368 PyDoc_STRVAR(subset_doc,
369 "subset(a, b, /) -> bool\n\
370 \n\
371 Return `True` if bitarray `a` is a subset of bitarray `b`.\n\
372 `subset(a, b)` is equivalent to `(a & b).count() == a.count()` but is more\n\
373 efficient since we can stop as soon as one mismatch is found, and no\n\
374 intermediate bitarray object gets created.");
375 
376 /* ---------------------------- serialization -------------------------- */
377 
378 static PyObject *
serialize(PyObject * module,PyObject * a)379 serialize(PyObject *module, PyObject *a)
380 {
381     PyObject *result;
382     Py_ssize_t nbytes;
383     char *str;
384 
385     if (ensure_bitarray(a) < 0)
386         return NULL;
387 
388     nbytes = Py_SIZE(a);
389     result = PyBytes_FromStringAndSize(NULL, nbytes + 1);
390     if (result == NULL)
391         return NULL;
392 
393     str = PyBytes_AsString(result);
394 #define aa  ((bitarrayobject *) a)
395     *str = (char) (16 * IS_BE(aa) + setunused(aa));
396     memcpy(str + 1, aa->ob_item, (size_t) nbytes);
397 #undef aa
398     return result;
399 }
400 
401 PyDoc_STRVAR(serialize_doc,
402 "serialize(bitarray, /) -> bytes\n\
403 \n\
404 Return a serialized representation of the bitarray, which may be passed to\n\
405 `deserialize()`.  It efficiently represents the bitarray object (including\n\
406 its endianness) and is guaranteed not to change in future releases.");
407 
/* ----------------------------- hexadecimal --------------------------- */

/* lowercase hexadecimal digits, indexed by their value */
static const char hexdigits[] = "0123456789abcdef";

/* Return the value (0..15) of the hexadecimal digit c.  Letters are
   accepted in lower and upper case.  Return -1 when c is not a
   hexadecimal digit. */
static int
hex_to_int(char c)
{
    int value = -1;

    if (c >= 'a' && c <= 'f')
        value = 10 + (c - 'a');
    else if (c >= 'A' && c <= 'F')
        value = 10 + (c - 'A');
    else if (c >= '0' && c <= '9')
        value = c - '0';

    return value;
}
423 
424 static PyObject *
ba2hex(PyObject * module,PyObject * a)425 ba2hex(PyObject *module, PyObject *a)
426 {
427     PyObject *result;
428     size_t i, strsize;
429     char *str;
430     int le, be;
431 
432     if (ensure_bitarray(a) < 0)
433         return NULL;
434 
435 #define aa  ((bitarrayobject *) a)
436     if (aa->nbits % 4) {
437         PyErr_SetString(PyExc_ValueError, "bitarray length not multiple of 4");
438         return NULL;
439     }
440 
441     /* strsize = aa->nbits / 4;  could make strsize odd */
442     strsize = 2 * Py_SIZE(a);
443     str = (char *) PyMem_Malloc(strsize);
444     if (str == NULL)
445         return PyErr_NoMemory();
446 
447     le = IS_LE(aa);
448     be = IS_BE(aa);
449     for (i = 0; i < strsize; i += 2) {
450         unsigned char c = aa->ob_item[i / 2];
451 
452         str[i + le] = hexdigits[c >> 4];
453         str[i + be] = hexdigits[0x0f & c];
454     }
455     assert((size_t) aa->nbits / 4 <= strsize);
456     result = Py_BuildValue("s#", str, aa->nbits / 4);
457 #undef aa
458     PyMem_Free((void *) str);
459     return result;
460 }
461 
462 PyDoc_STRVAR(ba2hex_doc,
463 "ba2hex(bitarray, /) -> hexstr\n\
464 \n\
465 Return a string containing the hexadecimal representation of\n\
466 the bitarray (which has to be multiple of 4 in length).");
467 
468 
469 /* Translate hexadecimal digits into the bitarray's buffer.
470    Each digit corresponds to 4 bits in the bitarray.
471    The number of digits may be odd. */
472 static PyObject *
hex2ba(PyObject * module,PyObject * args)473 hex2ba(PyObject *module, PyObject *args)
474 {
475     PyObject *a;
476     char *str;
477     Py_ssize_t i, strsize;
478     int le, be;
479 
480     if (!PyArg_ParseTuple(args, "Os#", &a, &str, &strsize))
481         return NULL;
482     if (ensure_ba_of_length(a, 4 * strsize) < 0)
483         return NULL;
484 
485 #define aa  ((bitarrayobject *) a)
486     le = IS_LE(aa);
487     be = IS_BE(aa);
488     assert(le + be == 1 && str[strsize] == 0);
489     for (i = 0; i < strsize; i += 2) {
490         int x = hex_to_int(str[i + le]);
491         int y = hex_to_int(str[i + be]);
492 
493         if (x < 0 || y < 0) {
494             /* ignore the terminating NUL - happends when strsize is odd */
495             if (i + le == strsize) /* str[i+le] is NUL */
496                 x = 0;
497             if (i + be == strsize) /* str[i+be] is NUL */
498                 y = 0;
499             /* there is an invalid byte - or (non-terminating) NUL */
500             if (x < 0 || y < 0) {
501                 PyErr_SetString(PyExc_ValueError,
502                                 "non-hexadecimal digit found");
503                 return NULL;
504             }
505         }
506         assert(0 <= x && x < 16 && 0 <= y && y < 16);
507         aa->ob_item[i / 2] = x << 4 | y;
508     }
509 #undef aa
510     Py_RETURN_NONE;
511 }
512 
/* ----------------------- base 2, 4, 8, 16, 32, 64 -------------------- */

/* RFC 4648 Base32 alphabet */
static const char base32_alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";

/* standard base 64 alphabet */
static const char base64_alphabet[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/* Return the value of digit c in base n, or -1 when c is not a valid
   digit in that base.  Bases 32 and 64 use the alphabets above; any other
   base is handled via hex_to_int(), accepting values below n only. */
static int
digit_to_int(char c, int n)
{
    int value;

    if (n == 32) {
        if (c >= 'A' && c <= 'Z')
            return c - 'A';
        if (c >= '2' && c <= '7')
            return 26 + (c - '2');
        return -1;
    }

    if (n == 64) {
        if (c >= 'A' && c <= 'Z')
            return c - 'A';
        if (c >= 'a' && c <= 'z')
            return 26 + (c - 'a');
        if (c >= '0' && c <= '9')
            return 52 + (c - '0');
        if (c == '+')
            return 62;
        if (c == '/')
            return 63;
        return -1;
    }

    /* base 2, 4, 8, 16 - the -1 from hex_to_int() is passed on as is */
    value = hex_to_int(c);
    return value < n ? value : -1;
}
555 
556 /* return m = log2(n) for m = 1..6 */
557 static int
base_to_length(int n)558 base_to_length(int n)
559 {
560     int m;
561 
562     for (m = 1; m < 7; m++) {
563         if (n == (1 << m))
564             return m;
565     }
566     PyErr_Format(PyExc_ValueError,
567                  "base must be 2, 4, 8, 16, 32 or 64, not %d", n);
568     return -1;
569 }
570 
571 static PyObject *
ba2base(PyObject * module,PyObject * args)572 ba2base(PyObject *module, PyObject *args)
573 {
574     const char *alphabet;
575     PyObject *result, *a;
576     size_t i, strsize;
577     char *str;
578     int n, m, le;
579 
580     if (!PyArg_ParseTuple(args, "iO:ba2base", &n, &a))
581         return NULL;
582     if ((m = base_to_length(n)) < 0)
583         return NULL;
584     if (ensure_bitarray(a) < 0)
585         return NULL;
586 
587     switch (n) {
588     case 32: alphabet = base32_alphabet; break;
589     case 64: alphabet = base64_alphabet; break;
590     default: alphabet = hexdigits;
591     }
592 
593 #define aa  ((bitarrayobject *) a)
594     if (aa->nbits % m)
595         return PyErr_Format(PyExc_ValueError,
596                             "bitarray length must be multiple of %d", m);
597 
598     strsize = aa->nbits / m;
599     if ((str = (char *) PyMem_Malloc(strsize)) == NULL)
600         return PyErr_NoMemory();
601 
602     le = IS_LE(aa);
603     for (i = 0; i < strsize; i++) {
604         int j, k, x = 0;
605 
606         for (j = 0; j < m; j++) {
607             k = le ? j : (m - j - 1);
608             x |= getbit(aa, i * m + k) << j;
609         }
610         str[i] = alphabet[x];
611     }
612     result = Py_BuildValue("s#", str, strsize);
613 #undef aa
614     PyMem_Free((void *) str);
615     return result;
616 }
617 
618 PyDoc_STRVAR(ba2base_doc,
619 "ba2base(n, bitarray, /) -> str\n\
620 \n\
621 Return a string containing the base `n` ASCII representation of\n\
622 the bitarray.  Allowed values for `n` are 2, 4, 8, 16, 32 and 64.\n\
623 The bitarray has to be multiple of length 1, 2, 3, 4, 5 or 6 respectively.\n\
624 For `n=16` (hexadecimal), `ba2hex()` will be much faster, as `ba2base()`\n\
625 does not take advantage of byte level operations.\n\
626 For `n=32` the RFC 4648 Base32 alphabet is used, and for `n=64` the\n\
627 standard base 64 alphabet is used.");
628 
629 
630 /* Translate ASCII digits into the bitarray's buffer.
631    The (Python) arguments to this functions are:
632    - base n, one of 2, 4, 8, 16, 32, 64  (n=2^m   where m bits per digit)
633    - bitarray (of length m * len(s)) whose buffer is written into
634    - byte object s containing the ASCII digits
635 */
636 static PyObject *
base2ba(PyObject * module,PyObject * args)637 base2ba(PyObject *module, PyObject *args)
638 {
639     PyObject *a = NULL;
640     Py_ssize_t i, strsize = 0;
641     char *str = NULL;
642     int n, m, le;
643 
644     if (!PyArg_ParseTuple(args, "i|Os#", &n, &a, &str, &strsize))
645         return NULL;
646     if ((m = base_to_length(n)) < 0)
647         return NULL;
648     if (a == NULL)
649         return PyLong_FromLong(m);
650     if (ensure_ba_of_length(a, m * strsize) < 0)
651         return NULL;
652 
653 #define aa  ((bitarrayobject *) a)
654     memset(aa->ob_item, 0x00, (size_t) Py_SIZE(a));
655 
656     le = IS_LE(aa);
657     for (i = 0; i < strsize; i++) {
658         int j, k, d = digit_to_int(str[i], n);
659 
660         if (d < 0) {
661             unsigned char c = str[i];
662             return PyErr_Format(PyExc_ValueError, "invalid digit found for "
663                                 "base %d, got '%c' (0x%02x)", n, c, c);
664         }
665         for (j = 0; j < m; j++) {
666             k = le ? j : (m - j - 1);
667             setbit(aa, i * m + k, d & (1 << j));
668         }
669     }
670 #undef aa
671     Py_RETURN_NONE;
672 }
673 
674 /* ------------------- variable length bitarray format ----------------- */
675 
676 /* grow buffer by at least one byte */
677 static int
grow_buffer(bitarrayobject * self)678 grow_buffer(bitarrayobject *self)
679 {
680     size_t newsize = Py_SIZE(self) + 1;
681 
682     assert_nbits(self);
683     assert(self->allocated >= Py_SIZE(self));
684     assert(self->ob_exports == 0);
685     assert(self->buffer == NULL);
686     assert(self->readonly == 0);
687 
688     /* standard growth pattern */
689     newsize += (newsize >> 4) + (newsize < 8 ? 3 : 7);
690 
691     self->ob_item = PyMem_Realloc(self->ob_item, newsize);
692     if (self->ob_item == NULL) {
693         PyErr_NoMemory();
694         return -1;
695     }
696     Py_SET_SIZE(self, newsize);
697     self->allocated = newsize;
698     self->nbits = 8 * newsize;
699     return 0;
700 }
701 
702 /* PADBITS is always 3 - the number of bits that represent the number of
703    padding bits.  The actual number of padding bits is called 'padding'
704    below, and is in range(0, 7).
705    Also note that 'padding' refers to the pad bits within the variable
706    length format, which is not the same as the pad bits of the actual
707    bitarray.  For example, b'\x10' has padding = 1, and decodes to
708    bitarray('000'), which has 5 pad bits. */
709 #define PADBITS  3
710 
711 /* consume iterator while decoding bytes into bitarray */
712 static PyObject *
vl_decode(PyObject * module,PyObject * args)713 vl_decode(PyObject *module, PyObject *args)
714 {
715     PyObject *iter, *item, *a;
716     Py_ssize_t padding = 0;  /* number of pad bits read from header byte */
717     Py_ssize_t i = 0;        /* bit counter */
718     unsigned char b = 0x80;  /* empty stream will raise StopIteration */
719     Py_ssize_t k;
720 
721     if (!PyArg_ParseTuple(args, "OO", &iter, &a))
722         return NULL;
723     if (!PyIter_Check(iter))
724         return PyErr_Format(PyExc_TypeError, "iterator or bytes expected, "
725                             "got '%s'", Py_TYPE(iter)->tp_name);
726 
727 #define aa  ((bitarrayobject *) a)
728     while ((item = PyIter_Next(iter))) {
729 #ifdef IS_PY3K
730         if (PyLong_Check(item))
731             b = (unsigned char) PyLong_AsLong(item);
732 #else
733         if (PyBytes_Check(item))
734             b = (unsigned char) *PyBytes_AS_STRING(item);
735 #endif
736         else {
737             PyErr_Format(PyExc_TypeError, "int (byte) iterator expected, "
738                          "got '%s' element", Py_TYPE(item)->tp_name);
739             Py_DECREF(item);
740             return NULL;
741         }
742         Py_DECREF(item);
743 
744         if (i + 6 >= aa->nbits && grow_buffer(aa) < 0)
745             return NULL;
746         assert(i + 6 < aa->nbits);
747 
748         if (i == 0) {
749             padding = (b & 0x70) >> 4;
750             if (padding >= 7 || ((b & 0x80) == 0 && padding > 4))
751                 return PyErr_Format(PyExc_ValueError,
752                                     "invalid header byte: 0x%02x", b);
753             for (k = 0; k < 4; k++)
754                 setbit(aa, i++, (0x08 >> k) & b);
755         }
756         else {
757             for (k = 0; k < 7; k++)
758                 setbit(aa, i++, (0x40 >> k) & b);
759         }
760         if ((b & 0x80) == 0)
761             break;
762     }
763     /* set final length of bitarray */
764     aa->nbits = i - padding;
765     Py_SET_SIZE(a, BYTES(aa->nbits));
766     assert_nbits(aa);
767 #undef aa
768 
769     if (PyErr_Occurred())       /* from PyIter_Next() */
770         return NULL;
771 
772     if (b & 0x80) {
773         k = (i + PADBITS) / 7;
774         return PyErr_Format(PyExc_StopIteration,
775                             "no terminating byte found, bytes read: %zd", k);
776     }
777     Py_RETURN_NONE;
778 }
779 
780 static PyObject *
vl_encode(PyObject * module,PyObject * a)781 vl_encode(PyObject *module, PyObject *a)
782 {
783     PyObject *result;
784     Py_ssize_t padding, n, m, i;
785     Py_ssize_t j = 0;           /* byte conter */
786     char *str;
787 
788     if (ensure_bitarray(a) < 0)
789         return NULL;
790 
791 #define aa  ((bitarrayobject *) a)
792     n = (aa->nbits + PADBITS + 6) / 7;  /* number of resulting bytes */
793     m = 7 * n - PADBITS;      /* number of bits resulting bytes can hold */
794     padding = m - aa->nbits;  /* number of pad bits */
795     assert(0 <= padding && padding < 7);
796 
797     result = PyBytes_FromStringAndSize(NULL, n);
798     if (result == NULL)
799         return NULL;
800 
801     str = PyBytes_AsString(result);
802     str[0] = aa->nbits > 4 ? 0x80 : 0x00;  /* leading bit */
803     str[0] |= padding << 4;                /* encode padding */
804     for (i = 0; i < 4 && i < aa->nbits; i++)
805         str[0] |= (0x08 >> i) * getbit(aa, i);
806 
807     for (i = 4; i < aa->nbits; i++) {
808         int k = (i - 4) % 7;
809 
810         if (k == 0) {
811             j++;
812             str[j] = j < n - 1 ? 0x80 : 0x00;  /* leading bit */
813         }
814         str[j] |= (0x40 >> k) * getbit(aa, i);
815     }
816 #undef aa
817     assert(j == n - 1);
818 
819     return result;
820 }
821 
822 PyDoc_STRVAR(vl_encode_doc,
823 "vl_encode(bitarray, /) -> bytes\n\
824 \n\
825 Return variable length binary representation of bitarray.\n\
826 This representation is useful for efficiently storing small bitarray\n\
827 in a binary stream.  Use `vl_decode()` for decoding.");
828 
829 /* --------------------------------------------------------------------- */
830 
831 /* Set bitarray_type_obj (bato).  This function must be called before any
832    other Python function in this module. */
833 static PyObject *
set_bato(PyObject * module,PyObject * obj)834 set_bato(PyObject *module, PyObject *obj)
835 {
836     bitarray_type_obj = obj;
837     Py_RETURN_NONE;
838 }
839 
840 static PyMethodDef module_functions[] = {
841     {"count_n",   (PyCFunction) count_n,   METH_VARARGS, count_n_doc},
842     {"rindex",    (PyCFunction) r_index,   METH_VARARGS, rindex_doc},
843     {"parity",    (PyCFunction) parity,    METH_O,       parity_doc},
844     {"count_and", (PyCFunction) count_and, METH_VARARGS, count_and_doc},
845     {"count_or",  (PyCFunction) count_or,  METH_VARARGS, count_or_doc},
846     {"count_xor", (PyCFunction) count_xor, METH_VARARGS, count_xor_doc},
847     {"subset",    (PyCFunction) subset,    METH_VARARGS, subset_doc},
848     {"serialize", (PyCFunction) serialize, METH_O,       serialize_doc},
849     {"ba2hex",    (PyCFunction) ba2hex,    METH_O,       ba2hex_doc},
850     {"_hex2ba",   (PyCFunction) hex2ba,    METH_VARARGS, 0},
851     {"ba2base",   (PyCFunction) ba2base,   METH_VARARGS, ba2base_doc},
852     {"_base2ba",  (PyCFunction) base2ba,   METH_VARARGS, 0},
853     {"vl_encode", (PyCFunction) vl_encode, METH_O,       vl_encode_doc},
854     {"_vl_decode",(PyCFunction) vl_decode, METH_VARARGS, 0},
855     {"_set_bato", (PyCFunction) set_bato,  METH_O,       0},
856     {NULL,        NULL}  /* sentinel */
857 };
858 
859 /******************************* Install Module ***************************/
860 
861 #ifdef IS_PY3K
862 static PyModuleDef moduledef = {
863     PyModuleDef_HEAD_INIT, "_util", 0, -1, module_functions,
864 };
865 #endif
866 
867 PyMODINIT_FUNC
868 #ifdef IS_PY3K
PyInit__util(void)869 PyInit__util(void)
870 #else
871 init_util(void)
872 #endif
873 {
874     PyObject *m;
875 
876 #ifdef IS_PY3K
877     m = PyModule_Create(&moduledef);
878     if (m == NULL)
879         return NULL;
880     return m;
881 #else
882     m = Py_InitModule3("_util", module_functions, 0);
883     if (m == NULL)
884         return;
885 #endif
886 }
887