1 // Multinomials over Z.
2 // e.g. [[1, 2], 3, [4, [5, 6]]] means
3 // (1 + 2y) + 3 x + (4 + (5 + 6z)y)x^2
4 // Convenient interchange format for different groups, rings, and fields.
5
// TODO: Canonicalize, e.g. [[[1]], 0, 0] --> 1.
7
8 #include <stdarg.h>
9 #include <stdio.h>
10 #include <stdint.h> // for intptr_t
11 #include <stdlib.h>
12 #include <gmp.h>
13 #include "pbc_utils.h"
14 #include "pbc_field.h"
15 #include "pbc_multiz.h"
16 #include "pbc_random.h"
17 #include "pbc_fp.h"
18 #include "pbc_memory.h"
19 #include "misc/darray.h"
20
21 // Per-element data.
22 struct multiz_s {
23 // Either it's an mpz, or a list of mpzs.
24 char type;
25 union {
26 mpz_t z;
27 darray_t a;
28 };
29 };
30
// Values stored in multiz_s.type.
enum {
  T_MPZ = 0,  // Plain integer.
  T_ARR = 1,  // List of multiz.
};
35
multiz_new_empty_list(void)36 static multiz multiz_new_empty_list(void) {
37 multiz ep = pbc_malloc(sizeof(*ep));
38 ep->type = T_ARR;
39 darray_init(ep->a);
40 return ep;
41 }
42
// Appends e's multiz pointer (not a copy) to the list held by x.
// x->data is used as a list without checking its type.
void multiz_append(element_ptr x, element_ptr e) {
  multiz list = x->data;
  multiz item = e->data;
  darray_append(list->a, item);
}
47
multiz_new(void)48 static multiz multiz_new(void) {
49 multiz ep = pbc_malloc(sizeof(*ep));
50 ep->type = T_MPZ;
51 mpz_init(ep->z);
52 return ep;
53 }
54
// Field hook: gives e a freshly allocated integer payload.
static void f_init(element_ptr e) {
  multiz fresh = multiz_new();
  e->data = fresh;
}
58
// Releases ep and, for lists, every item it holds (recursively).
static void multiz_free(multiz ep) {
  if (T_MPZ == ep->type) {
    mpz_clear(ep->z);
  } else {
    PBC_ASSERT(T_ARR == ep->type, "no such type");
    // Free the children first, then the array that referenced them.
    darray_forall(ep->a, (void(*)(void*))multiz_free);
    darray_clear(ep->a);
  }
  pbc_free(ep);
}
72
// Field hook: frees the multiz owned by e.
static void f_clear(element_ptr e) {
  multiz owned = e->data;
  multiz_free(owned);
}
76
// Allocates a new element in e's field whose value is a one-item list
// containing e's multiz. The pointer e->data itself is stored, not a copy.
element_ptr multiz_new_list(element_ptr e) {
  element_ptr wrapper = pbc_malloc(sizeof(*wrapper));
  element_init_same_as(wrapper, e);
  // Replace the payload created by initialization with an empty list.
  multiz_free(wrapper->data);
  multiz list = multiz_new_empty_list();
  wrapper->data = list;
  darray_append(list->a, e->data);
  return wrapper;
}
85
// Field hook: discards e's current value and sets e to the integer op.
static void f_set_si(element_ptr e, signed long int op) {
  multiz_free(e->data);
  multiz fresh = multiz_new();
  mpz_set_si(fresh->z, op);
  e->data = fresh;
}
92
// Field hook: discards e's current value and sets e to the integer z.
static void f_set_mpz(element_ptr e, mpz_ptr z) {
  multiz_free(e->data);
  multiz fresh = multiz_new();
  mpz_set(fresh->z, z);
  e->data = fresh;
}
99
// Field hook: sets e to the integer 0 (a fresh mpz is initialized to 0).
static void f_set0(element_ptr e) {
  multiz_free(e->data);
  e->data = multiz_new();
}
104
// Field hook: sets e to the integer 1.
static void f_set1(element_ptr e) {
  multiz_free(e->data);
  multiz one = multiz_new();
  mpz_set_ui(one->z, 1);
  e->data = one;
}
111
// Writes ep to stream in the given base. Lists are printed as
// "[item, item, ...]". Returns the number of characters accounted for.
static size_t multiz_out_str(FILE *stream, int base, multiz ep) {
  if (T_MPZ == ep->type) {
    return mpz_out_str(stream, base, ep->z);
  }
  PBC_ASSERT(T_ARR == ep->type, "no such type");

  size_t written = 0;
  int count = darray_count(ep->a);
  int k;

  fputc('[', stream);
  written++;
  for (k = 0; k < count; k++) {
    if (k > 0) {
      // Separator goes before every item except the first.
      written += 2;
      fputs(", ", stream);
    }
    written += multiz_out_str(stream, base, darray_at(ep->a, k));
  }
  fputc(']', stream);
  written++;
  return written;
}
131
// Field hook: prints e's multiz to stream in the given base.
static size_t f_out_str(FILE *stream, int base, element_ptr e) {
  multiz value = e->data;
  return multiz_out_str(stream, base, value);
}
135
// Copies into z the integer found by repeatedly taking item 0 of ep until an
// integer is reached.
void multiz_to_mpz(mpz_ptr z, multiz ep) {
  multiz cur;
  for (cur = ep; T_ARR == cur->type; cur = darray_at(cur->a, 0)) {
    // Descend along the first item.
  }
  PBC_ASSERT(T_MPZ == cur->type, "no such type");
  mpz_set(z, cur->z);
}
141
// Field hook: extracts an integer from a via multiz_to_mpz.
static void f_to_mpz(mpz_ptr z, element_ptr a) {
  multiz value = a->data;
  multiz_to_mpz(z, value);
}
145
// Returns the sign of the integer found by repeatedly taking item 0 of ep.
static int multiz_sgn(multiz ep) {
  multiz cur;
  for (cur = ep; T_ARR == cur->type; cur = darray_at(cur->a, 0)) {
    // Descend along the first item.
  }
  PBC_ASSERT(T_MPZ == cur->type, "no such type");
  return mpz_sgn(cur->z);
}
151
// Field hook: sign of a, as defined by multiz_sgn.
static int f_sgn(element_ptr a) {
  multiz value = a->data;
  return multiz_sgn(value);
}
155
156 static void add_to_x(void *data,
157 multiz x,
158 void (*fun)(mpz_t, const mpz_t, void *scope_ptr),
159 void *scope_ptr);
160
multiz_new_unary(const multiz y,void (* fun)(mpz_t,const mpz_t,void * scope_ptr),void * scope_ptr)161 static multiz multiz_new_unary(const multiz y,
162 void (*fun)(mpz_t, const mpz_t, void *scope_ptr), void *scope_ptr) {
163 multiz x = pbc_malloc(sizeof(*x));
164 switch(y->type) {
165 case T_MPZ:
166 x->type = T_MPZ;
167 mpz_init(x->z);
168 fun(x->z, y->z, scope_ptr);
169 break;
170 default:
171 PBC_ASSERT(T_ARR == ep->type, "no such type");
172 x->type = T_ARR;
173 darray_init(x->a);
174 darray_forall4(y->a,
175 (void(*)(void*,void*,void*,void*))add_to_x,
176 x,
177 fun,
178 scope_ptr);
179 break;
180 }
181 return x;
182 }
183
add_to_x(void * data,multiz x,void (* fun)(mpz_t,const mpz_t,void * scope_ptr),void * scope_ptr)184 static void add_to_x(void *data,
185 multiz x,
186 void (*fun)(mpz_t, const mpz_t, void *scope_ptr),
187 void *scope_ptr) {
188 darray_append(x->a, multiz_new_unary(data, fun, scope_ptr));
189 }
190
// multiz_new_unary callback: plain copy; the scope pointer is ignored.
static void mpzset(mpz_t dst, const mpz_t src, void *scope_ptr) {
  mpz_set(dst, src);
  UNUSED_VAR(scope_ptr);
}
195
// Returns a newly allocated deep copy of y.
static multiz multiz_clone(multiz y) {
  // mpzset already has the callback type, so no cast is needed.
  multiz copy = multiz_new_unary(y, mpzset, NULL);
  return copy;
}
199
multiz_new_bin(const multiz a,const multiz b,void (* fun)(mpz_t,const mpz_t,const mpz_t))200 static multiz multiz_new_bin(const multiz a, const multiz b,
201 void (*fun)(mpz_t, const mpz_t, const mpz_t)) {
202 if (T_MPZ == a->type) {
203 if (T_MPZ == b->type) {
204 multiz x = multiz_new();
205 fun(x->z, a->z, b->z);
206 return x;
207 } else {
208 multiz x = multiz_clone(b);
209 multiz z = x;
210 PBC_ASSERT(T_ARR == z->type, "no such type");
211 while(z->type == T_ARR) z = darray_at(z->a, 0);
212 fun(z->z, a->z, z->z);
213 return x;
214 }
215 } else {
216 PBC_ASSERT(T_ARR == a->type, "no such type");
217 if (T_MPZ == b->type) {
218 multiz x = multiz_clone(a);
219 multiz z = x;
220 PBC_ASSERT(T_ARR == z->type, "no such type");
221 while(z->type == T_ARR) z = darray_at(z->a, 0);
222 fun(z->z, b->z, z->z);
223 return x;
224 } else {
225 PBC_ASSERT(T_ARR == b->type, "no such type");
226 int m = darray_count(a->a);
227 int n = darray_count(b->a);
228 int min = m < n ? m : n;
229 int max = m > n ? m : n;
230 multiz x = multiz_new_empty_list();
231 int i;
232 for(i = 0; i < min; i++) {
233 multiz z = multiz_new_bin(darray_at(a->a, i), darray_at(b->a, i), fun);
234 darray_append(x->a, z);
235 }
236 multiz zero = multiz_new();
237 for(; i < max; i++) {
238 multiz z = multiz_new_bin(m > n ? darray_at(a->a, i) : zero,
239 n > m ? darray_at(b->a, i) : zero,
240 fun);
241 darray_append(x->a, z);
242 }
243 multiz_free(zero);
244 return x;
245 }
246 }
247 }
multiz_new_add(const multiz a,const multiz b)248 static multiz multiz_new_add(const multiz a, const multiz b) {
249 return multiz_new_bin(a, b, mpz_add);
250 }
251
// Field hook: n = a + b. The result is built before n's old value is freed,
// so n may be the same element as a or b.
static void f_add(element_ptr n, element_ptr a, element_ptr b) {
  multiz sum = multiz_new_add(a->data, b->data);
  multiz_free(n->data);
  n->data = sum;
}
257
multiz_new_sub(const multiz a,const multiz b)258 static multiz multiz_new_sub(const multiz a, const multiz b) {
259 return multiz_new_bin(a, b, mpz_sub);
260 }
// Field hook: n = a - b. The result is built before n's old value is freed,
// so n may be the same element as a or b.
static void f_sub(element_ptr n, element_ptr a, element_ptr b) {
  multiz diff = multiz_new_sub(a->data, b->data);
  multiz_free(n->data);
  n->data = diff;
}
266
// Thin wrapper around mpz_mul, so it can be cast to the multiz_new_unary
// callback type with the multiplier passed as the scope pointer.
static void mpzmul(mpz_t x, const mpz_t y, const mpz_t z) {
  mpz_mul(x, y, z);
}
270
multiz_new_mul(const multiz a,const multiz b)271 static multiz multiz_new_mul(const multiz a, const multiz b) {
272 if (T_MPZ == a->type) {
273 // Multiply each coefficient of b by a->z.
274 return multiz_new_unary(b, (void(*)(mpz_t, const mpz_t, void *))mpzmul, a->z);
275 } else {
276 PBC_ASSERT(T_ARR == a->type, "no such type");
277 if (T_MPZ == b->type) {
278 // Multiply each coefficient of a by b->z.
279 return multiz_new_unary(a, (void(*)(mpz_t, const mpz_t, void *))mpzmul, b->z);
280 } else {
281 PBC_ASSERT(T_ARR == b->type, "no such type");
282 int m = darray_count(a->a);
283 int n = darray_count(b->a);
284 int max = m + n - 1;
285 multiz x = multiz_new_empty_list();
286 int i;
287 multiz zero = multiz_new();
288 for(i = 0; i < max; i++) {
289 multiz z = multiz_new();
290 int j;
291 for (j = 0; j <= i; j++) {
292 multiz y = multiz_new_mul(j < m ? darray_at(a->a, j) : zero,
293 i - j < n ? darray_at(b->a, i - j) : zero);
294 multiz t = multiz_new_add(z, y);
295 multiz_free(y);
296 multiz_free(z);
297 z = t;
298 }
299 darray_append(x->a, z);
300 }
301 multiz_free(zero);
302 return x;
303 }
304 }
305 }
// Field hook: n = a * b. The result is built before n's old value is freed,
// so n may be the same element as a or b.
static void f_mul(element_ptr n, element_ptr a, element_ptr b) {
  multiz product = multiz_new_mul(a->data, b->data);
  multiz_free(n->data);
  n->data = product;
}
311
// Field hook: n = a * z, multiplying every integer in a by z.
static void f_mul_mpz(element_ptr n, element_ptr a, mpz_ptr z) {
  multiz scaled = multiz_new_unary(a->data,
      (void(*)(mpz_t, const mpz_t, void *))mpzmul, z);
  multiz_free(n->data);
  n->data = scaled;
}
317
// multiz_new_unary callback: x = y * (*i); the multiplier arrives by pointer.
static void mulsi(mpz_t x, const mpz_t y, signed long *i) {
  signed long factor = *i;
  mpz_mul_si(x, y, factor);
}
321
// Field hook: n = a * z, multiplying every integer in a by z.
static void f_mul_si(element_ptr n, element_ptr a, signed long int z) {
  // &z stays valid for the whole call, which is all mulsi needs.
  multiz scaled = multiz_new_unary(a->data,
      (void(*)(mpz_t, const mpz_t, void *))mulsi, &z);
  multiz_free(n->data);
  n->data = scaled;
}
327
// multiz_new_unary callback: negation; the scope pointer is ignored.
static void mpzneg(mpz_t dst, const mpz_t src, void *scope_ptr) {
  mpz_neg(dst, src);
  UNUSED_VAR(scope_ptr);
}
332
// Returns a newly allocated copy of z with every integer negated.
static multiz multiz_new_neg(multiz z) {
  // mpzneg already has the callback type, so no cast is needed.
  multiz negated = multiz_new_unary(z, mpzneg, NULL);
  return negated;
}
336
// Field hook: n = a (deep copy). The copy is made before n's old value is
// freed, so n and a may be the same element.
static void f_set(element_ptr n, element_ptr a) {
  multiz copy = multiz_clone(a->data);
  multiz_free(n->data);
  n->data = copy;
}
342
// Field hook: n = -a. The result is built before n's old value is freed,
// so n and a may be the same element.
static void f_neg(element_ptr n, element_ptr a) {
  multiz negated = multiz_new_neg(a->data);
  multiz_free(n->data);
  n->data = negated;
}
348
// Field hook: c = a / b, where b is first reduced to an integer with
// element_to_mpz and every integer in a is divided by it using truncating
// division (mpz_tdiv_q).
static void f_div(element_ptr c, element_ptr a, element_ptr b) {
  mpz_t divisor;
  mpz_init(divisor);
  element_to_mpz(divisor, b);
  multiz quotient = multiz_new_unary(a->data,
      (void(*)(mpz_t, const mpz_t, void *))mpz_tdiv_q, divisor);
  mpz_clear(divisor);
  multiz_free(c->data);
  c->data = quotient;
}
358
359 // Doesn't make sense if order is infinite.
f_random(element_ptr n)360 static void f_random(element_ptr n) {
361 multiz delme = n->data;
362 f_init(n);
363 multiz_free(delme);
364 }
365
// Field hook: sets n to the integer obtained by importing the len bytes at
// data with mpz_import (word order -1, word size 1, endian -1, no nails).
static void f_from_hash(element_ptr n, void *data, int len) {
  mpz_t imported;
  mpz_init(imported);
  mpz_import(imported, len, -1, 1, -1, 0, data);
  f_set_mpz(n, imported);
  mpz_clear(imported);
}
373
// Field hook: nonzero only when n is a plain integer equal to 1.
// A list is never considered 1 here.
static int f_is1(element_ptr n) {
  multiz value = n->data;
  if (T_MPZ != value->type) return 0;
  return 0 == mpz_cmp_ui(value->z, 1);
}
378
// Returns 1 only when m is a plain integer equal to 0, otherwise 0.
// A list is never considered 0 here.
int multiz_is0(multiz m) {
  if (T_MPZ != m->type) return 0;
  return mpz_is0(m->z) ? 1 : 0;
}
382
// Field hook: zero test, as defined by multiz_is0.
static int f_is0(element_ptr n) {
  multiz value = n->data;
  return multiz_is0(value);
}
386
// Field hook: number of items in e's list, or 0 when e is a plain integer.
static int f_item_count(element_ptr e) {
  multiz value = e->data;
  return T_MPZ == value->type ? 0 : darray_count(value->a);
}
392
393 // TODO: Redesign multiz so this doesn't leak.
f_item(element_ptr e,int i)394 static element_ptr f_item(element_ptr e, int i) {
395 multiz z = e->data;
396 if (T_MPZ == z->type) return NULL;
397 element_ptr r = malloc(sizeof(*r));
398 r->field = e->field;
399 r->data = darray_at(z->a, i);
400 return r;
401 }
402
403 // Usual meaning when both are integers.
404 // Otherwise, compare coefficients.
multiz_cmp(multiz a,multiz b)405 static int multiz_cmp(multiz a, multiz b) {
406 if (T_MPZ == a->type) {
407 if (T_MPZ == b->type) {
408 // Simplest case: both are integers.
409 return mpz_cmp(a->z, b->z);
410 }
411 // Leading coefficient of b.
412 while(T_ARR == b->type) b = darray_last(b->a);
413 PBC_ASSERT(T_MPZ == b->type, "no such type");
414 return -mpz_sgn(b->z);
415 }
416 PBC_ASSERT(T_ARR == a->type, "no such type");
417 if (T_MPZ == b->type) {
418 // Leading coefficient of a.
419 while(T_ARR == a->type) a = darray_last(a->a);
420 PBC_ASSERT(T_MPZ == a->type, "no such type");
421 return mpz_sgn(a->z);
422 }
423 PBC_ASSERT(T_ARR == b->type, "no such type");
424 int m = darray_count(a->a);
425 int n = darray_count(b->a);
426 if (m > n) {
427 // Leading coefficient of a.
428 while(T_ARR == a->type) a = darray_last(a->a);
429 PBC_ASSERT(T_MPZ == a->type, "no such type");
430 return mpz_sgn(a->z);
431 }
432 if (n > m) {
433 // Leading coefficient of b.
434 while(T_ARR == b->type) b = darray_last(b->a);
435 PBC_ASSERT(T_MPZ == b->type, "no such type");
436 return -mpz_sgn(b->z);
437 }
438 for(n--; n >= 0; n--) {
439 int i = multiz_cmp(darray_at(a->a, n), darray_at(b->a, n));
440 if (i) return i;
441 }
442 return 0;
443 }
// Field hook: compares x and y with multiz_cmp.
static int f_cmp(element_ptr x, element_ptr y) {
  multiz lhs = x->data;
  multiz rhs = y->data;
  return multiz_cmp(lhs, rhs);
}
447
// Field hook: does nothing with f.
static void f_field_clear(field_t f) {
  UNUSED_VAR(f);
}
449
450 // OpenSSL convention:
451 // 4 bytes containing length
452 // followed by number in big-endian, most-significant bit set if negative
453 // (prepending null byte if necessary)
454 // Positive numbers also the same as mpz_out_raw.
z_to_bytes(unsigned char * data,element_t e)455 static int z_to_bytes(unsigned char *data, element_t e) {
456 mpz_ptr z = e->data;
457 size_t msb = mpz_sizeinbase(z, 2);
458 size_t n = 4;
459 size_t i;
460
461 if (!(msb % 8)) {
462 data[4] = 0;
463 n++;
464 }
465 if (mpz_sgn(z) < 0) {
466 mpz_export(data + n, NULL, 1, 1, 1, 0, z);
467 data[4] |= 128;
468 } else {
469 mpz_export(data + n, NULL, 1, 1, 1, 0, z);
470 }
471 n += (msb + 7) / 8 - 4;
472 for (i=0; i<4; i++) {
473 data[i] = (n >> 8 * (3 - i));
474 }
475 n += 4;
476
477 return n;
478 }
479
// Inverse of z_to_bytes: reads a 4-byte big-endian payload length, then that
// many big-endian magnitude bytes whose top bit carries the sign, and sets e
// to the resulting integer.
//
// e->data is a multiz, so the value is stored through f_set_mpz rather than
// by writing to e->data as if it were an mpz. The input buffer is left
// unmodified. Returns the number of bytes consumed (header + payload).
static int z_from_bytes(element_t e, unsigned char *data) {
  size_t n = 0;
  size_t i;
  for (i = 0; i < 4; i++) {
    n = (n << 8) | data[i];
  }

  mpz_t z;
  mpz_init(z);
  int neg = 0;
  for (i = 0; i < n; i++) {
    unsigned char byte = data[4 + i];
    if (!i) {
      // The top bit of the first payload byte is the sign, not magnitude.
      neg = byte & 128;
      byte &= 127;
    }
    // Accumulate big-endian: z = z * 256 + byte.
    mpz_mul_2exp(z, z, 8);
    mpz_add_ui(z, z, byte);
  }
  if (neg) mpz_neg(z, z);

  f_set_mpz(e, z);
  mpz_clear(z);
  return (int) (n + 4);
}
510
// Returns the number of bytes z_to_bytes writes for a: a 4-byte header plus
// the payload. The payload is ceil(bits / 8) bytes, plus one pad byte when
// bits is a multiple of 8; both cases equal bits / 8 + 1.
// a->data is a multiz, so the integer is extracted with multiz_to_mpz first.
static int z_length_in_bytes(element_ptr a) {
  mpz_t z;
  mpz_init(z);
  multiz_to_mpz(z, a->data);
  size_t bits = mpz_sizeinbase(z, 2);
  mpz_clear(z);
  return (int) (bits / 8 + 1 + 4);
}
514
// Field hook: writes a short description of this field to out.
static void f_out_info(FILE *out, field_ptr f) {
  fputs("Z multinomials", out);
  UNUSED_VAR(f);
}
519
// Field hook: parses an integer from s in the given base and stores it in e.
// Returns whatever pbc_mpz_set_str returns.
// TODO: Square brackets.
static int f_set_str(element_ptr e, const char *s, int base) {
  mpz_t parsed;
  mpz_init(parsed);
  int consumed = pbc_mpz_set_str(parsed, s, base);
  f_set_mpz(e, parsed);
  mpz_clear(parsed);
  return consumed;
}
529
// Field hook: sets e to a deep copy of m. The copy is made before e's old
// value is freed, so m may be e's own payload.
static void f_set_multiz(element_ptr e, multiz m) {
  multiz copy = multiz_clone(m);
  multiz_free(e->data);
  e->data = copy;
}
535
// Initializes f as the multinomials-over-Z field by installing this file's
// hooks. The order is set to 0 and there is no fixed serialized length.
void field_init_multiz(field_ptr f) {
  field_init(f);

  // Field-level data.
  f->data = NULL;
  f->fixed_length_in_bytes = -1;
  mpz_set_ui(f->order, 0);
  f->field_clear = f_field_clear;
  f->out_info = f_out_info;

  // Element lifetime.
  f->init = f_init;
  f->clear = f_clear;

  // Assignment.
  f->set = f_set;
  f->set0 = f_set0;
  f->set1 = f_set1;
  f->set_si = f_set_si;
  f->set_mpz = f_set_mpz;
  f->set_multiz = f_set_multiz;
  f->set_str = f_set_str;
  f->random = f_random;
  f->from_hash = f_from_hash;

  // Arithmetic.
  f->add = f_add;
  f->sub = f_sub;
  f->neg = f_neg;
  f->mul = f_mul;
  f->mul_mpz = f_mul_mpz;
  f->mul_si = f_mul_si;
  f->div = f_div;

  // Predicates and comparison.
  f->is0 = f_is0;
  f->is1 = f_is1;
  f->sign = f_sgn;
  f->cmp = f_cmp;

  // Conversion, output, and serialization.
  f->to_mpz = f_to_mpz;
  f->out_str = f_out_str;
  f->to_bytes = z_to_bytes;
  f->from_bytes = z_from_bytes;
  f->length_in_bytes = z_length_in_bytes;

  // Item access.
  f->item = f_item;
  f->item_count = f_item_count;
}
575
// Returns 1 when m is a plain integer, 0 otherwise.
int multiz_is_z(multiz m) {
  return m->type == T_MPZ;
}
579
// Returns the number of items in m, or -1 when m is not a list.
int multiz_count(multiz m) {
  return T_ARR == m->type ? darray_count(m->a) : -1;
}
584
// Returns item i of the list m. The type and the upper bound on i are
// checked only through PBC_ASSERT.
multiz multiz_at(multiz m, int i) {
  PBC_ASSERT(T_ARR == m->type, "wrong type");
  PBC_ASSERT(darray_count(m->a) > i, "out of bounds");
  multiz item = darray_at(m->a, i);
  return item;
}
590