1 /*
2 Copyright (c) 2019 - 2021, Ilan Schnell; All Rights Reserved
3 bitarray is published under the PSF license.
4
5 This file contains the C implementation of some useful utility functions.
6
7 Author: Ilan Schnell
8 */
9
10 #define PY_SSIZE_T_CLEAN
11 #include "Python.h"
12 #include "pythoncapi_compat.h"
13 #include "bitarray.h"
14
15 #define IS_LE(a) ((a)->endian == ENDIAN_LITTLE)
16 #define IS_BE(a) ((a)->endian == ENDIAN_BIG)
17
18 /* set using the Python module function _set_bato() */
19 static PyObject *bitarray_type_obj = NULL;
20
21 /* Return 0 if obj is bitarray. If not, return -1 and set an exception. */
22 static int
ensure_bitarray(PyObject * obj)23 ensure_bitarray(PyObject *obj)
24 {
25 int t;
26
27 if (bitarray_type_obj == NULL)
28 Py_FatalError("bitarray_type_obj not set");
29 t = PyObject_IsInstance(obj, bitarray_type_obj);
30 if (t < 0)
31 return -1;
32 if (t == 0) {
33 PyErr_Format(PyExc_TypeError, "bitarray expected, not %s",
34 Py_TYPE(obj)->tp_name);
35 return -1;
36 }
37 return 0;
38 }
39
40 /* ensure object is a bitarray of given length */
41 static int
ensure_ba_of_length(PyObject * a,const Py_ssize_t n)42 ensure_ba_of_length(PyObject *a, const Py_ssize_t n)
43 {
44 if (ensure_bitarray(a) < 0)
45 return -1;
46 if (((bitarrayobject *) a)->nbits != n) {
47 PyErr_SetString(PyExc_ValueError, "size mismatch");
48 return -1;
49 }
50 return 0;
51 }
52
53 /* ------------------------------- count_n ----------------------------- */
54
55 /* return the smallest index i for which a.count(1, 0, i) == n, or when
56 n exceeds the total count return -1 */
57 static Py_ssize_t
count_to_n(bitarrayobject * a,Py_ssize_t n)58 count_to_n(bitarrayobject *a, Py_ssize_t n)
59 {
60 const Py_ssize_t nbits = a->nbits;
61 Py_ssize_t i = 0; /* index */
62 Py_ssize_t j = 0; /* total count up to index */
63 Py_ssize_t block_start, block_stop, k, m;
64
65 assert(0 <= n && n <= nbits);
66 if (n == 0)
67 return 0;
68
69 #define BLOCK_BITS 8192
70 /* by counting big blocks we save comparisons */
71 while (i + BLOCK_BITS < nbits) {
72 m = 0;
73 assert(i % 8 == 0);
74 block_start = i >> 3;
75 block_stop = block_start + (BLOCK_BITS >> 3);
76 assert(block_stop <= Py_SIZE(a));
77 for (k = block_start; k < block_stop; k++)
78 m += bitcount_lookup[(unsigned char) a->ob_item[k]];
79 if (j + m >= n)
80 break;
81 j += m;
82 i += BLOCK_BITS;
83 }
84 #undef BLOCK_BITS
85
86 while (i + 8 < nbits) {
87 k = i >> 3;
88 assert(k < Py_SIZE(a));
89 m = bitcount_lookup[(unsigned char) a->ob_item[k]];
90 if (j + m >= n)
91 break;
92 j += m;
93 i += 8;
94 }
95
96 while (j < n && i < nbits ) {
97 j += getbit(a, i);
98 i++;
99 }
100 if (j < n)
101 return -1;
102
103 return i;
104 }
105
106 static PyObject *
count_n(PyObject * module,PyObject * args)107 count_n(PyObject *module, PyObject *args)
108 {
109 PyObject *a;
110 Py_ssize_t n, i;
111
112 if (!PyArg_ParseTuple(args, "On:count_n", &a, &n))
113 return NULL;
114 if (ensure_bitarray(a) < 0)
115 return NULL;
116
117 if (n < 0) {
118 PyErr_SetString(PyExc_ValueError, "non-negative integer expected");
119 return NULL;
120 }
121 #define aa ((bitarrayobject *) a)
122 if (n > aa->nbits) {
123 PyErr_SetString(PyExc_ValueError, "n larger than bitarray size");
124 return NULL;
125 }
126 i = count_to_n(aa, n); /* do actual work here */
127 #undef aa
128 if (i < 0) {
129 PyErr_SetString(PyExc_ValueError, "n exceeds total count");
130 return NULL;
131 }
132 return PyLong_FromSsize_t(i);
133 }
134
135 PyDoc_STRVAR(count_n_doc,
136 "count_n(a, n, /) -> int\n\
137 \n\
138 Return lowest index `i` for which `a[:i].count() == n`.\n\
139 Raises `ValueError`, when n exceeds total count (`a.count()`).");
140
141 /* ----------------------------- right index --------------------------- */
142
143 /* return index of highest occurrence of vi in self[a:b], -1 when not found */
144 static Py_ssize_t
find_last(bitarrayobject * self,int vi,Py_ssize_t a,Py_ssize_t b)145 find_last(bitarrayobject *self, int vi, Py_ssize_t a, Py_ssize_t b)
146 {
147 const Py_ssize_t n = b - a;
148 Py_ssize_t res, i;
149
150 assert(0 <= a && a <= self->nbits);
151 assert(0 <= b && b <= self->nbits);
152 assert(0 <= vi && vi <= 1);
153 if (n <= 0)
154 return -1;
155
156 /* the logic here is the same as in find_bit() in _bitarray.c */
157 #ifdef PY_UINT64_T
158 if (n > 64) {
159 const Py_ssize_t word_a = (a + 63) / 64;
160 const Py_ssize_t word_b = b / 64;
161 const PY_UINT64_T w = vi ? 0 : ~0;
162
163 if ((res = find_last(self, vi, 64 * word_b, b)) >= 0)
164 return res;
165
166 for (i = word_b - 1; i >= word_a; i--) { /* skip uint64 words */
167 if (w ^ ((PY_UINT64_T *) self->ob_item)[i])
168 return find_last(self, vi, 64 * i, 64 * i + 64);
169 }
170 return find_last(self, vi, a, 64 * word_a);
171 }
172 #endif
173 if (n > 8) {
174 const Py_ssize_t byte_a = BYTES(a);
175 const Py_ssize_t byte_b = b / 8;
176 const char c = vi ? 0 : ~0;
177
178 if ((res = find_last(self, vi, BITS(byte_b), b)) >= 0)
179 return res;
180
181 for (i = byte_b - 1; i >= byte_a; i--) { /* skip bytes */
182 assert_byte_in_range(self, i);
183 if (c ^ self->ob_item[i])
184 return find_last(self, vi, BITS(i), BITS(i) + 8);
185 }
186 return find_last(self, vi, a, BITS(byte_a));
187 }
188 assert(n <= 8);
189 for (i = b - 1; i >= a; i--) {
190 if (getbit(self, i) == vi)
191 return i;
192 }
193 return -1;
194 }
195
196 static PyObject *
r_index(PyObject * module,PyObject * args)197 r_index(PyObject *module, PyObject *args)
198 {
199 PyObject *value = Py_True, *a;
200 Py_ssize_t start = 0, stop = PY_SSIZE_T_MAX, res;
201 int vi;
202
203 if (!PyArg_ParseTuple(args, "O|Onn:rindex", &a, &value, &start, &stop))
204 return NULL;
205 if (ensure_bitarray(a) < 0)
206 return NULL;
207 if ((vi = pybit_as_int(value)) < 0)
208 return NULL;
209
210 #define aa ((bitarrayobject *) a)
211 normalize_index(aa->nbits, &start);
212 normalize_index(aa->nbits, &stop);
213 res = find_last(aa, vi, start, stop);
214 #undef aa
215 if (res < 0)
216 return PyErr_Format(PyExc_ValueError, "%d not in bitarray", vi);
217
218 return PyLong_FromSsize_t(res);
219 }
220
221 PyDoc_STRVAR(rindex_doc,
222 "rindex(bitarray, value=1, start=0, stop=<end of array>, /) -> int\n\
223 \n\
224 Return the rightmost (highest) index of `value` in bitarray.\n\
225 Raises `ValueError` if the value is not present.");
226
227 /* --------------------------- unary functions ------------------------- */
228
229 static PyObject *
parity(PyObject * module,PyObject * a)230 parity(PyObject *module, PyObject *a)
231 {
232 unsigned char par = 0;
233 Py_ssize_t i;
234
235 if (ensure_bitarray(a) < 0)
236 return NULL;
237
238 #define aa ((bitarrayobject *) a)
239 for (i = 0; i < aa->nbits / 8; i++)
240 par ^= aa->ob_item[i];
241 if (aa->nbits % 8)
242 par ^= zeroed_last_byte(aa);
243 #undef aa
244
245 return PyLong_FromLong((long) bitcount_lookup[par] % 2);
246 }
247
248 PyDoc_STRVAR(parity_doc,
249 "parity(a, /) -> int\n\
250 \n\
251 Return the parity of bitarray `a`.\n\
252 This is equivalent to `a.count() % 2` (but more efficient).");
253
254 /* --------------------------- binary functions ------------------------ */
255
/* selects the operation binary_function() performs on its two operands */
enum kernel_type {
    KERN_cand = 0,              /* count bitwise and -> int */
    KERN_cor = 1,               /* count bitwise or -> int */
    KERN_cxor = 2,              /* count bitwise xor -> int */
    KERN_subset = 3             /* is subset -> bool */
};
262
263 static PyObject *
binary_function(PyObject * args,enum kernel_type kern,const char * format)264 binary_function(PyObject *args, enum kernel_type kern, const char *format)
265 {
266 Py_ssize_t res = 0, s, i;
267 PyObject *a, *b;
268 unsigned char c;
269 int r;
270
271 if (!PyArg_ParseTuple(args, format, &a, &b))
272 return NULL;
273 if (ensure_bitarray(a) < 0 || ensure_bitarray(b) < 0)
274 return NULL;
275
276 #define aa ((bitarrayobject *) a)
277 #define bb ((bitarrayobject *) b)
278 if (aa->nbits != bb->nbits) {
279 PyErr_SetString(PyExc_ValueError,
280 "bitarrays of equal length expected");
281 return NULL;
282 }
283 if (aa->endian != bb->endian) {
284 PyErr_SetString(PyExc_ValueError,
285 "bitarrays of equal endianness expected");
286 return NULL;
287 }
288 s = aa->nbits / 8; /* number of whole bytes in buffer */
289 r = aa->nbits % 8; /* remaining bits */
290
291 switch (kern) {
292 case KERN_cand:
293 for (i = 0; i < s; i++) {
294 c = aa->ob_item[i] & bb->ob_item[i];
295 res += bitcount_lookup[c];
296 }
297 if (r) {
298 c = zeroed_last_byte(aa) & zeroed_last_byte(bb);
299 res += bitcount_lookup[c];
300 }
301 break;
302
303 case KERN_cor:
304 for (i = 0; i < s; i++) {
305 c = aa->ob_item[i] | bb->ob_item[i];
306 res += bitcount_lookup[c];
307 }
308 if (r) {
309 c = zeroed_last_byte(aa) | zeroed_last_byte(bb);
310 res += bitcount_lookup[c];
311 }
312 break;
313
314 case KERN_cxor:
315 for (i = 0; i < s; i++) {
316 c = aa->ob_item[i] ^ bb->ob_item[i];
317 res += bitcount_lookup[c];
318 }
319 if (r) {
320 c = zeroed_last_byte(aa) ^ zeroed_last_byte(bb);
321 res += bitcount_lookup[c];
322 }
323 break;
324
325 case KERN_subset:
326 for (i = 0; i < s; i++) {
327 if ((aa->ob_item[i] & bb->ob_item[i]) != aa->ob_item[i])
328 Py_RETURN_FALSE;
329 }
330 if (r) {
331 if ((zeroed_last_byte(aa) & zeroed_last_byte(bb)) !=
332 zeroed_last_byte(aa))
333 Py_RETURN_FALSE;
334 }
335 Py_RETURN_TRUE;
336
337 default: /* cannot happen */
338 return NULL;
339 }
340 #undef aa
341 #undef bb
342 return PyLong_FromSsize_t(res);
343 }
344
345 #define COUNT_FUNC(oper, ochar) \
346 static PyObject * \
347 count_ ## oper (PyObject *module, PyObject *args) \
348 { \
349 return binary_function(args, KERN_c ## oper, "OO:count_" #oper); \
350 } \
351 PyDoc_STRVAR(count_ ## oper ## _doc, \
352 "count_" #oper "(a, b, /) -> int\n\
353 \n\
354 Return `(a " ochar " b).count()` in a memory efficient manner,\n\
355 as no intermediate bitarray object gets created.")
356
357 COUNT_FUNC(and, "&"); /* count_and */
358 COUNT_FUNC(or, "|"); /* count_or */
359 COUNT_FUNC(xor, "^"); /* count_xor */
360
361
362 static PyObject *
subset(PyObject * module,PyObject * args)363 subset(PyObject *module, PyObject *args)
364 {
365 return binary_function(args, KERN_subset, "OO:subset");
366 }
367
368 PyDoc_STRVAR(subset_doc,
369 "subset(a, b, /) -> bool\n\
370 \n\
371 Return `True` if bitarray `a` is a subset of bitarray `b`.\n\
372 `subset(a, b)` is equivalent to `(a & b).count() == a.count()` but is more\n\
373 efficient since we can stop as soon as one mismatch is found, and no\n\
374 intermediate bitarray object gets created.");
375
376 /* ---------------------------- serialization -------------------------- */
377
378 static PyObject *
serialize(PyObject * module,PyObject * a)379 serialize(PyObject *module, PyObject *a)
380 {
381 PyObject *result;
382 Py_ssize_t nbytes;
383 char *str;
384
385 if (ensure_bitarray(a) < 0)
386 return NULL;
387
388 nbytes = Py_SIZE(a);
389 result = PyBytes_FromStringAndSize(NULL, nbytes + 1);
390 if (result == NULL)
391 return NULL;
392
393 str = PyBytes_AsString(result);
394 #define aa ((bitarrayobject *) a)
395 *str = (char) (16 * IS_BE(aa) + setunused(aa));
396 memcpy(str + 1, aa->ob_item, (size_t) nbytes);
397 #undef aa
398 return result;
399 }
400
401 PyDoc_STRVAR(serialize_doc,
402 "serialize(bitarray, /) -> bytes\n\
403 \n\
404 Return a serialized representation of the bitarray, which may be passed to\n\
405 `deserialize()`. It efficiently represents the bitarray object (including\n\
406 its endianness) and is guaranteed not to change in future releases.");
407
408 /* ----------------------------- hexadecimal --------------------------- */
409
/* digit characters for bases up to 16, indexed by digit value */
static const char hexdigits[] = "0123456789abcdef";

/* Return the value (0..15) of hexadecimal digit c, accepting both lower
   and upper case letters.  Returns -1 when c is not a hex digit. */
static int
hex_to_int(char c)
{
    int value = -1;

    if (c >= '0' && c <= '9')
        value = c - '0';
    else if (c >= 'a' && c <= 'f')
        value = 10 + (c - 'a');
    else if (c >= 'A' && c <= 'F')
        value = 10 + (c - 'A');

    return value;
}
423
424 static PyObject *
ba2hex(PyObject * module,PyObject * a)425 ba2hex(PyObject *module, PyObject *a)
426 {
427 PyObject *result;
428 size_t i, strsize;
429 char *str;
430 int le, be;
431
432 if (ensure_bitarray(a) < 0)
433 return NULL;
434
435 #define aa ((bitarrayobject *) a)
436 if (aa->nbits % 4) {
437 PyErr_SetString(PyExc_ValueError, "bitarray length not multiple of 4");
438 return NULL;
439 }
440
441 /* strsize = aa->nbits / 4; could make strsize odd */
442 strsize = 2 * Py_SIZE(a);
443 str = (char *) PyMem_Malloc(strsize);
444 if (str == NULL)
445 return PyErr_NoMemory();
446
447 le = IS_LE(aa);
448 be = IS_BE(aa);
449 for (i = 0; i < strsize; i += 2) {
450 unsigned char c = aa->ob_item[i / 2];
451
452 str[i + le] = hexdigits[c >> 4];
453 str[i + be] = hexdigits[0x0f & c];
454 }
455 assert((size_t) aa->nbits / 4 <= strsize);
456 result = Py_BuildValue("s#", str, aa->nbits / 4);
457 #undef aa
458 PyMem_Free((void *) str);
459 return result;
460 }
461
462 PyDoc_STRVAR(ba2hex_doc,
463 "ba2hex(bitarray, /) -> hexstr\n\
464 \n\
465 Return a string containing the hexadecimal representation of\n\
466 the bitarray (which has to be multiple of 4 in length).");
467
468
469 /* Translate hexadecimal digits into the bitarray's buffer.
470 Each digit corresponds to 4 bits in the bitarray.
471 The number of digits may be odd. */
472 static PyObject *
hex2ba(PyObject * module,PyObject * args)473 hex2ba(PyObject *module, PyObject *args)
474 {
475 PyObject *a;
476 char *str;
477 Py_ssize_t i, strsize;
478 int le, be;
479
480 if (!PyArg_ParseTuple(args, "Os#", &a, &str, &strsize))
481 return NULL;
482 if (ensure_ba_of_length(a, 4 * strsize) < 0)
483 return NULL;
484
485 #define aa ((bitarrayobject *) a)
486 le = IS_LE(aa);
487 be = IS_BE(aa);
488 assert(le + be == 1 && str[strsize] == 0);
489 for (i = 0; i < strsize; i += 2) {
490 int x = hex_to_int(str[i + le]);
491 int y = hex_to_int(str[i + be]);
492
493 if (x < 0 || y < 0) {
494 /* ignore the terminating NUL - happends when strsize is odd */
495 if (i + le == strsize) /* str[i+le] is NUL */
496 x = 0;
497 if (i + be == strsize) /* str[i+be] is NUL */
498 y = 0;
499 /* there is an invalid byte - or (non-terminating) NUL */
500 if (x < 0 || y < 0) {
501 PyErr_SetString(PyExc_ValueError,
502 "non-hexadecimal digit found");
503 return NULL;
504 }
505 }
506 assert(0 <= x && x < 16 && 0 <= y && y < 16);
507 aa->ob_item[i / 2] = x << 4 | y;
508 }
509 #undef aa
510 Py_RETURN_NONE;
511 }
512
513 /* ----------------------- base 2, 4, 8, 16, 32, 64 -------------------- */
514
/* RFC 4648 Base32 alphabet */
static const char base32_alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";

/* standard base 64 alphabet */
static const char base64_alphabet[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/* Return the value of digit c in base n, or -1 when c is not a valid
   digit for that base.  Bases 32 and 64 follow the alphabets above; every
   other base goes through hex_to_int() and accepts values below n. */
static int
digit_to_int(char c, int n)
{
    int value;

    if (n == 32) {
        if (c >= 'A' && c <= 'Z')
            return c - 'A';
        if (c >= '2' && c <= '7')
            return 26 + (c - '2');
        return -1;
    }

    if (n == 64) {
        if (c >= 'A' && c <= 'Z')
            return c - 'A';
        if (c >= 'a' && c <= 'z')
            return 26 + (c - 'a');
        if (c >= '0' && c <= '9')
            return 52 + (c - '0');
        if (c == '+')
            return 62;
        if (c == '/')
            return 63;
        return -1;
    }

    /* base 2, 4, 8, 16 */
    value = hex_to_int(c);
    return value < n ? value : -1;
}
555
556 /* return m = log2(n) for m = 1..6 */
557 static int
base_to_length(int n)558 base_to_length(int n)
559 {
560 int m;
561
562 for (m = 1; m < 7; m++) {
563 if (n == (1 << m))
564 return m;
565 }
566 PyErr_Format(PyExc_ValueError,
567 "base must be 2, 4, 8, 16, 32 or 64, not %d", n);
568 return -1;
569 }
570
571 static PyObject *
ba2base(PyObject * module,PyObject * args)572 ba2base(PyObject *module, PyObject *args)
573 {
574 const char *alphabet;
575 PyObject *result, *a;
576 size_t i, strsize;
577 char *str;
578 int n, m, le;
579
580 if (!PyArg_ParseTuple(args, "iO:ba2base", &n, &a))
581 return NULL;
582 if ((m = base_to_length(n)) < 0)
583 return NULL;
584 if (ensure_bitarray(a) < 0)
585 return NULL;
586
587 switch (n) {
588 case 32: alphabet = base32_alphabet; break;
589 case 64: alphabet = base64_alphabet; break;
590 default: alphabet = hexdigits;
591 }
592
593 #define aa ((bitarrayobject *) a)
594 if (aa->nbits % m)
595 return PyErr_Format(PyExc_ValueError,
596 "bitarray length must be multiple of %d", m);
597
598 strsize = aa->nbits / m;
599 if ((str = (char *) PyMem_Malloc(strsize)) == NULL)
600 return PyErr_NoMemory();
601
602 le = IS_LE(aa);
603 for (i = 0; i < strsize; i++) {
604 int j, k, x = 0;
605
606 for (j = 0; j < m; j++) {
607 k = le ? j : (m - j - 1);
608 x |= getbit(aa, i * m + k) << j;
609 }
610 str[i] = alphabet[x];
611 }
612 result = Py_BuildValue("s#", str, strsize);
613 #undef aa
614 PyMem_Free((void *) str);
615 return result;
616 }
617
618 PyDoc_STRVAR(ba2base_doc,
619 "ba2base(n, bitarray, /) -> str\n\
620 \n\
621 Return a string containing the base `n` ASCII representation of\n\
622 the bitarray. Allowed values for `n` are 2, 4, 8, 16, 32 and 64.\n\
623 The bitarray has to be multiple of length 1, 2, 3, 4, 5 or 6 respectively.\n\
624 For `n=16` (hexadecimal), `ba2hex()` will be much faster, as `ba2base()`\n\
625 does not take advantage of byte level operations.\n\
626 For `n=32` the RFC 4648 Base32 alphabet is used, and for `n=64` the\n\
627 standard base 64 alphabet is used.");
628
629
630 /* Translate ASCII digits into the bitarray's buffer.
631 The (Python) arguments to this functions are:
632 - base n, one of 2, 4, 8, 16, 32, 64 (n=2^m where m bits per digit)
633 - bitarray (of length m * len(s)) whose buffer is written into
634 - byte object s containing the ASCII digits
635 */
636 static PyObject *
base2ba(PyObject * module,PyObject * args)637 base2ba(PyObject *module, PyObject *args)
638 {
639 PyObject *a = NULL;
640 Py_ssize_t i, strsize = 0;
641 char *str = NULL;
642 int n, m, le;
643
644 if (!PyArg_ParseTuple(args, "i|Os#", &n, &a, &str, &strsize))
645 return NULL;
646 if ((m = base_to_length(n)) < 0)
647 return NULL;
648 if (a == NULL)
649 return PyLong_FromLong(m);
650 if (ensure_ba_of_length(a, m * strsize) < 0)
651 return NULL;
652
653 #define aa ((bitarrayobject *) a)
654 memset(aa->ob_item, 0x00, (size_t) Py_SIZE(a));
655
656 le = IS_LE(aa);
657 for (i = 0; i < strsize; i++) {
658 int j, k, d = digit_to_int(str[i], n);
659
660 if (d < 0) {
661 unsigned char c = str[i];
662 return PyErr_Format(PyExc_ValueError, "invalid digit found for "
663 "base %d, got '%c' (0x%02x)", n, c, c);
664 }
665 for (j = 0; j < m; j++) {
666 k = le ? j : (m - j - 1);
667 setbit(aa, i * m + k, d & (1 << j));
668 }
669 }
670 #undef aa
671 Py_RETURN_NONE;
672 }
673
674 /* ------------------- variable length bitarray format ----------------- */
675
676 /* grow buffer by at least one byte */
677 static int
grow_buffer(bitarrayobject * self)678 grow_buffer(bitarrayobject *self)
679 {
680 size_t newsize = Py_SIZE(self) + 1;
681
682 assert_nbits(self);
683 assert(self->allocated >= Py_SIZE(self));
684 assert(self->ob_exports == 0);
685 assert(self->buffer == NULL);
686 assert(self->readonly == 0);
687
688 /* standard growth pattern */
689 newsize += (newsize >> 4) + (newsize < 8 ? 3 : 7);
690
691 self->ob_item = PyMem_Realloc(self->ob_item, newsize);
692 if (self->ob_item == NULL) {
693 PyErr_NoMemory();
694 return -1;
695 }
696 Py_SET_SIZE(self, newsize);
697 self->allocated = newsize;
698 self->nbits = 8 * newsize;
699 return 0;
700 }
701
702 /* PADBITS is always 3 - the number of bits that represent the number of
703 padding bits. The actual number of padding bits is called 'padding'
704 below, and is in range(0, 7).
705 Also note that 'padding' refers to the pad bits within the variable
706 length format, which is not the same as the pad bits of the actual
707 bitarray. For example, b'\x10' has padding = 1, and decodes to
708 bitarray('000'), which has 5 pad bits. */
709 #define PADBITS 3
710
711 /* consume iterator while decoding bytes into bitarray */
712 static PyObject *
vl_decode(PyObject * module,PyObject * args)713 vl_decode(PyObject *module, PyObject *args)
714 {
715 PyObject *iter, *item, *a;
716 Py_ssize_t padding = 0; /* number of pad bits read from header byte */
717 Py_ssize_t i = 0; /* bit counter */
718 unsigned char b = 0x80; /* empty stream will raise StopIteration */
719 Py_ssize_t k;
720
721 if (!PyArg_ParseTuple(args, "OO", &iter, &a))
722 return NULL;
723 if (!PyIter_Check(iter))
724 return PyErr_Format(PyExc_TypeError, "iterator or bytes expected, "
725 "got '%s'", Py_TYPE(iter)->tp_name);
726
727 #define aa ((bitarrayobject *) a)
728 while ((item = PyIter_Next(iter))) {
729 #ifdef IS_PY3K
730 if (PyLong_Check(item))
731 b = (unsigned char) PyLong_AsLong(item);
732 #else
733 if (PyBytes_Check(item))
734 b = (unsigned char) *PyBytes_AS_STRING(item);
735 #endif
736 else {
737 PyErr_Format(PyExc_TypeError, "int (byte) iterator expected, "
738 "got '%s' element", Py_TYPE(item)->tp_name);
739 Py_DECREF(item);
740 return NULL;
741 }
742 Py_DECREF(item);
743
744 if (i + 6 >= aa->nbits && grow_buffer(aa) < 0)
745 return NULL;
746 assert(i + 6 < aa->nbits);
747
748 if (i == 0) {
749 padding = (b & 0x70) >> 4;
750 if (padding >= 7 || ((b & 0x80) == 0 && padding > 4))
751 return PyErr_Format(PyExc_ValueError,
752 "invalid header byte: 0x%02x", b);
753 for (k = 0; k < 4; k++)
754 setbit(aa, i++, (0x08 >> k) & b);
755 }
756 else {
757 for (k = 0; k < 7; k++)
758 setbit(aa, i++, (0x40 >> k) & b);
759 }
760 if ((b & 0x80) == 0)
761 break;
762 }
763 /* set final length of bitarray */
764 aa->nbits = i - padding;
765 Py_SET_SIZE(a, BYTES(aa->nbits));
766 assert_nbits(aa);
767 #undef aa
768
769 if (PyErr_Occurred()) /* from PyIter_Next() */
770 return NULL;
771
772 if (b & 0x80) {
773 k = (i + PADBITS) / 7;
774 return PyErr_Format(PyExc_StopIteration,
775 "no terminating byte found, bytes read: %zd", k);
776 }
777 Py_RETURN_NONE;
778 }
779
780 static PyObject *
vl_encode(PyObject * module,PyObject * a)781 vl_encode(PyObject *module, PyObject *a)
782 {
783 PyObject *result;
784 Py_ssize_t padding, n, m, i;
785 Py_ssize_t j = 0; /* byte conter */
786 char *str;
787
788 if (ensure_bitarray(a) < 0)
789 return NULL;
790
791 #define aa ((bitarrayobject *) a)
792 n = (aa->nbits + PADBITS + 6) / 7; /* number of resulting bytes */
793 m = 7 * n - PADBITS; /* number of bits resulting bytes can hold */
794 padding = m - aa->nbits; /* number of pad bits */
795 assert(0 <= padding && padding < 7);
796
797 result = PyBytes_FromStringAndSize(NULL, n);
798 if (result == NULL)
799 return NULL;
800
801 str = PyBytes_AsString(result);
802 str[0] = aa->nbits > 4 ? 0x80 : 0x00; /* leading bit */
803 str[0] |= padding << 4; /* encode padding */
804 for (i = 0; i < 4 && i < aa->nbits; i++)
805 str[0] |= (0x08 >> i) * getbit(aa, i);
806
807 for (i = 4; i < aa->nbits; i++) {
808 int k = (i - 4) % 7;
809
810 if (k == 0) {
811 j++;
812 str[j] = j < n - 1 ? 0x80 : 0x00; /* leading bit */
813 }
814 str[j] |= (0x40 >> k) * getbit(aa, i);
815 }
816 #undef aa
817 assert(j == n - 1);
818
819 return result;
820 }
821
822 PyDoc_STRVAR(vl_encode_doc,
823 "vl_encode(bitarray, /) -> bytes\n\
824 \n\
825 Return variable length binary representation of bitarray.\n\
826 This representation is useful for efficiently storing small bitarray\n\
827 in a binary stream. Use `vl_decode()` for decoding.");
828
829 /* --------------------------------------------------------------------- */
830
831 /* Set bitarray_type_obj (bato). This function must be called before any
832 other Python function in this module. */
833 static PyObject *
set_bato(PyObject * module,PyObject * obj)834 set_bato(PyObject *module, PyObject *obj)
835 {
836 bitarray_type_obj = obj;
837 Py_RETURN_NONE;
838 }
839
840 static PyMethodDef module_functions[] = {
841 {"count_n", (PyCFunction) count_n, METH_VARARGS, count_n_doc},
842 {"rindex", (PyCFunction) r_index, METH_VARARGS, rindex_doc},
843 {"parity", (PyCFunction) parity, METH_O, parity_doc},
844 {"count_and", (PyCFunction) count_and, METH_VARARGS, count_and_doc},
845 {"count_or", (PyCFunction) count_or, METH_VARARGS, count_or_doc},
846 {"count_xor", (PyCFunction) count_xor, METH_VARARGS, count_xor_doc},
847 {"subset", (PyCFunction) subset, METH_VARARGS, subset_doc},
848 {"serialize", (PyCFunction) serialize, METH_O, serialize_doc},
849 {"ba2hex", (PyCFunction) ba2hex, METH_O, ba2hex_doc},
850 {"_hex2ba", (PyCFunction) hex2ba, METH_VARARGS, 0},
851 {"ba2base", (PyCFunction) ba2base, METH_VARARGS, ba2base_doc},
852 {"_base2ba", (PyCFunction) base2ba, METH_VARARGS, 0},
853 {"vl_encode", (PyCFunction) vl_encode, METH_O, vl_encode_doc},
854 {"_vl_decode",(PyCFunction) vl_decode, METH_VARARGS, 0},
855 {"_set_bato", (PyCFunction) set_bato, METH_O, 0},
856 {NULL, NULL} /* sentinel */
857 };
858
859 /******************************* Install Module ***************************/
860
861 #ifdef IS_PY3K
862 static PyModuleDef moduledef = {
863 PyModuleDef_HEAD_INIT, "_util", 0, -1, module_functions,
864 };
865 #endif
866
867 PyMODINIT_FUNC
868 #ifdef IS_PY3K
PyInit__util(void)869 PyInit__util(void)
870 #else
871 init_util(void)
872 #endif
873 {
874 PyObject *m;
875
876 #ifdef IS_PY3K
877 m = PyModule_Create(&moduledef);
878 if (m == NULL)
879 return NULL;
880 return m;
881 #else
882 m = Py_InitModule3("_util", module_functions, 0);
883 if (m == NULL)
884 return;
885 #endif
886 }
887