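// Multinomials (multivariate polynomials) over Z.
// A value is either an integer or a list of coefficients, where list entry i
// is the coefficient of x^i and may itself be a list in the next variable.
// e.g. [[1, 2], 3, [4, [5, 6]]] means
//   (1 + 2y) + 3x + (4 + (5 + 6z)y)x^2
// Convenient interchange format for different groups, rings, and fields.

// TODO: Canonicalize, e.g. [[[1]], 0, 0] --> 1.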
7 
8 #include <stdarg.h>
9 #include <stdio.h>
10 #include <stdint.h> // for intptr_t
11 #include <stdlib.h>
12 #include <gmp.h>
13 #include "pbc_utils.h"
14 #include "pbc_field.h"
15 #include "pbc_multiz.h"
16 #include "pbc_random.h"
17 #include "pbc_fp.h"
18 #include "pbc_memory.h"
19 #include "misc/darray.h"
20 
// Per-element data: a multinomial is either a single integer or a list of
// coefficients whose entries are themselves multiz values.
struct multiz_s {
  // Either it's an mpz, or a list of mpzs.
  // Discriminant (T_MPZ or T_ARR) selecting the active union member.
  char type;
  union {
    mpz_t z;     // Active when type == T_MPZ.
    darray_t a;  // Active when type == T_ARR; each item is a multiz.
  };
};
30 
// Tags stored in multiz_s.type: an integer leaf or a coefficient list.
enum { T_MPZ = 0, T_ARR = 1 };
35 
multiz_new_empty_list(void)36 static multiz multiz_new_empty_list(void) {
37   multiz ep = pbc_malloc(sizeof(*ep));
38   ep->type = T_ARR;
39   darray_init(ep->a);
40   return ep;
41 }
42 
multiz_append(element_ptr x,element_ptr e)43 void multiz_append(element_ptr x, element_ptr e) {
44   multiz l = x->data;
45   darray_append(l->a, e->data);
46 }
47 
multiz_new(void)48 static multiz multiz_new(void) {
49   multiz ep = pbc_malloc(sizeof(*ep));
50   ep->type = T_MPZ;
51   mpz_init(ep->z);
52   return ep;
53 }
54 
// Field hook: gives a freshly initialized element the integer value 0.
static void f_init(element_ptr e) {
  multiz zero = multiz_new();
  e->data = zero;
}
58 
// Recursively releases ep: its mpz, or every list entry plus the list
// storage, and finally the node itself.
static void multiz_free(multiz ep) {
  if (T_MPZ == ep->type) {
    mpz_clear(ep->z);
  } else {
    PBC_ASSERT(T_ARR == ep->type, "no such type");
    int count = darray_count(ep->a);
    for (int k = 0; k < count; k++) {
      multiz_free(darray_at(ep->a, k));
    }
    darray_clear(ep->a);
  }
  pbc_free(ep);
}
72 
// Field hook: releases the multinomial owned by e.
static void f_clear(element_ptr e) {
  multiz data = e->data;
  multiz_free(data);
}
76 
// Allocates a new element in e's field whose value is a one-entry list
// holding e's value (added via multiz_append).
element_ptr multiz_new_list(element_ptr e) {
  element_ptr list = pbc_malloc(sizeof(*list));
  element_init_same_as(list, e);
  // Swap the integer payload installed by init for an empty list.
  multiz scalar = list->data;
  list->data = multiz_new_empty_list();
  multiz_free(scalar);
  multiz_append(list, e);
  return list;
}
85 
// Field hook: sets e to the integer op, discarding any list structure.
static void f_set_si(element_ptr e, signed long int op) {
  multiz old = e->data;
  multiz fresh = multiz_new();
  mpz_set_si(fresh->z, op);
  e->data = fresh;
  multiz_free(old);
}
92 
// Field hook: sets e to the integer z, discarding any list structure.
// The new value is built before the old one is freed, so z may point into
// e's current data (e.g. the mpz of an integer-valued e) without being read
// after it has been cleared.
static void f_set_mpz(element_ptr e, mpz_ptr z) {
  multiz old = e->data;
  multiz fresh = multiz_new();
  mpz_set(fresh->z, z);
  e->data = fresh;
  multiz_free(old);
}
99 
// Field hook: resets e to the integer 0.
static void f_set0(element_ptr e) {
  multiz old = e->data;
  e->data = multiz_new();
  multiz_free(old);
}
104 
// Field hook: resets e to the integer 1.
static void f_set1(element_ptr e) {
  multiz one = multiz_new();
  mpz_set_ui(one->z, 1);
  multiz_free(e->data);
  e->data = one;
}
111 
// Prints ep in nested-list notation, e.g. "[1, [2, 3]]". Returns the number
// of characters written, tallied from mpz_out_str plus the brackets and
// ", " separators emitted here.
static size_t multiz_out_str(FILE *stream, int base, multiz ep) {
  if (T_MPZ == ep->type) {
    return mpz_out_str(stream, base, ep->z);
  }
  PBC_ASSERT(T_ARR == ep->type, "no such type");
  size_t written = 0;
  int count = darray_count(ep->a);
  fputc('[', stream);
  written += 1;
  for (int k = 0; k < count; k++) {
    if (k > 0) {
      fputs(", ", stream);
      written += 2;
    }
    written += multiz_out_str(stream, base, darray_at(ep->a, k));
  }
  fputc(']', stream);
  written += 1;
  return written;
}
131 
// Field hook: prints e via multiz_out_str and returns its character count.
static size_t f_out_str(FILE *stream, int base, element_ptr e) {
  multiz data = e->data;
  return multiz_out_str(stream, base, data);
}
135 
// Stores in z the constant term of ep: follows entry 0 of each nested list
// down to an integer leaf and copies it.
void multiz_to_mpz(mpz_ptr z, multiz ep) {
  multiz node = ep;
  while (T_ARR == node->type) {
    node = darray_at(node->a, 0);
  }
  PBC_ASSERT(T_MPZ == node->type, "no such type");
  mpz_set(z, node->z);
}
141 
// Field hook: stores the constant term of a in z.
static void f_to_mpz(mpz_ptr z, element_ptr a) {
  multiz data = a->data;
  multiz_to_mpz(z, data);
}
145 
// Returns the sign (-1, 0 or 1 as mpz_sgn reports it) of ep's constant
// term, reached by following entry 0 of each nested list.
static int multiz_sgn(multiz ep) {
  if (T_ARR == ep->type) return multiz_sgn(darray_at(ep->a, 0));
  PBC_ASSERT(T_MPZ == ep->type, "no such type");
  return mpz_sgn(ep->z);
}
151 
// Field hook: sign of a's constant term.
static int f_sgn(element_ptr a) {
  multiz data = a->data;
  return multiz_sgn(data);
}
155 
// Forward declaration: add_to_x (defined after multiz_new_unary) is the
// darray_forall4 callback multiz_new_unary uses to map over a list, and the
// two functions are mutually recursive.
static void add_to_x(void *data,
                     multiz x,
                     void (*fun)(mpz_t, const mpz_t, void *scope_ptr),
                     void *scope_ptr);
160 
// Returns a newly allocated multiz with the same shape as y, in which each
// integer leaf v is replaced by fun(dst, v, scope_ptr). dst starts as a
// fresh mpz set to 0. scope_ptr is passed through unchanged so callers can
// supply an extra operand (e.g. a multiplier).
static multiz multiz_new_unary(const multiz y,
    void (*fun)(mpz_t, const mpz_t, void *scope_ptr), void *scope_ptr) {
  if (T_MPZ == y->type) {
    multiz x = multiz_new();
    fun(x->z, y->z, scope_ptr);
    return x;
  }
  // The original asserted on an undeclared `ep`, which failed to compile
  // whenever PBC_ASSERT expands its argument.
  PBC_ASSERT(T_ARR == y->type, "no such type");
  multiz x = multiz_new_empty_list();
  // add_to_x recurses into each entry and appends the mapped result to x.
  darray_forall4(y->a,
                 (void(*)(void*,void*,void*,void*))add_to_x,
                 x,
                 fun,
                 scope_ptr);
  return x;
}
183 
// darray_forall4 callback for multiz_new_unary: maps one list entry (data)
// through fun and appends the result to the output list x.
static void add_to_x(void *data,
                     multiz x,
                     void (*fun)(mpz_t, const mpz_t, void *scope_ptr),
                     void *scope_ptr) {
  multiz mapped = multiz_new_unary(data, fun, scope_ptr);
  darray_append(x->a, mapped);
}
190 
// Unary callback for multiz_new_unary: dst = src; scope_ptr is unused.
static void mpzset(mpz_t dst, const mpz_t src, void *scope_ptr) {
  (void) scope_ptr;
  mpz_set(dst, src);
}
195 
// Returns a deep copy of y; the caller owns it.
static multiz multiz_clone(multiz y) {
  void (*copy)(mpz_t, const mpz_t, void *) = mpzset;
  return multiz_new_unary(y, copy, NULL);
}
199 
multiz_new_bin(const multiz a,const multiz b,void (* fun)(mpz_t,const mpz_t,const mpz_t))200 static multiz multiz_new_bin(const multiz a, const multiz b,
201     void (*fun)(mpz_t, const mpz_t, const mpz_t)) {
202   if (T_MPZ == a->type) {
203     if (T_MPZ == b->type) {
204       multiz x = multiz_new();
205       fun(x->z, a->z, b->z);
206       return x;
207     } else {
208       multiz x = multiz_clone(b);
209       multiz z = x;
210       PBC_ASSERT(T_ARR == z->type, "no such type");
211       while(z->type == T_ARR) z = darray_at(z->a, 0);
212       fun(z->z, a->z, z->z);
213       return x;
214     }
215   } else {
216     PBC_ASSERT(T_ARR == a->type, "no such type");
217     if (T_MPZ == b->type) {
218       multiz x = multiz_clone(a);
219       multiz z = x;
220       PBC_ASSERT(T_ARR == z->type, "no such type");
221       while(z->type == T_ARR) z = darray_at(z->a, 0);
222       fun(z->z, b->z, z->z);
223       return x;
224     } else {
225       PBC_ASSERT(T_ARR == b->type, "no such type");
226       int m = darray_count(a->a);
227       int n = darray_count(b->a);
228       int min = m < n ? m : n;
229       int max = m > n ? m : n;
230       multiz x = multiz_new_empty_list();
231       int i;
232       for(i = 0; i < min; i++) {
233         multiz z = multiz_new_bin(darray_at(a->a, i), darray_at(b->a, i), fun);
234         darray_append(x->a, z);
235       }
236       multiz zero = multiz_new();
237       for(; i < max; i++) {
238         multiz z = multiz_new_bin(m > n ? darray_at(a->a, i) : zero,
239                                   n > m ? darray_at(b->a, i) : zero,
240                                   fun);
241         darray_append(x->a, z);
242       }
243       multiz_free(zero);
244       return x;
245     }
246   }
247 }
multiz_new_add(const multiz a,const multiz b)248 static multiz multiz_new_add(const multiz a, const multiz b) {
249   return multiz_new_bin(a, b, mpz_add);
250 }
251 
// Field hook: n = a + b. The sum is built before n's old value is freed, so
// n may alias a or b.
static void f_add(element_ptr n, element_ptr a, element_ptr b) {
  multiz sum = multiz_new_add(a->data, b->data);
  multiz_free(n->data);
  n->data = sum;
}
257 
multiz_new_sub(const multiz a,const multiz b)258 static multiz multiz_new_sub(const multiz a, const multiz b) {
259   return multiz_new_bin(a, b, mpz_sub);
260 }
// Field hook: n = a - b via multiz_new_sub. The difference is built before
// n's old value is freed, so n may alias a or b.
static void f_sub(element_ptr n, element_ptr a, element_ptr b) {
  multiz diff = multiz_new_sub(a->data, b->data);
  multiz_free(n->data);
  n->data = diff;
}
266 
// Unary callback for multiz_new_unary: x = y * z, where z arrives through
// the generic void * scope slot and points to an initialized mpz.
// Every caller invokes this through a (mpz_t, const mpz_t, void *) function
// pointer. Declaring the last parameter as void * makes that the function's
// real type; calling through an incompatible function type is undefined
// behavior.
static void mpzmul(mpz_t x, const mpz_t y, void *z) {
  mpz_mul(x, y, (mpz_srcptr) z);
}
270 
multiz_new_mul(const multiz a,const multiz b)271 static multiz multiz_new_mul(const multiz a, const multiz b) {
272   if (T_MPZ == a->type) {
273     // Multiply each coefficient of b by a->z.
274     return multiz_new_unary(b, (void(*)(mpz_t, const mpz_t, void *))mpzmul, a->z);
275   } else {
276     PBC_ASSERT(T_ARR == a->type, "no such type");
277     if (T_MPZ == b->type) {
278       // Multiply each coefficient of a by b->z.
279       return multiz_new_unary(a, (void(*)(mpz_t, const mpz_t, void *))mpzmul, b->z);
280     } else {
281       PBC_ASSERT(T_ARR == b->type, "no such type");
282       int m = darray_count(a->a);
283       int n = darray_count(b->a);
284       int max = m + n - 1;
285       multiz x = multiz_new_empty_list();
286       int i;
287       multiz zero = multiz_new();
288       for(i = 0; i < max; i++) {
289         multiz z = multiz_new();
290         int j;
291         for (j = 0; j <= i; j++) {
292           multiz y = multiz_new_mul(j < m ? darray_at(a->a, j) : zero,
293                                     i - j < n ? darray_at(b->a, i - j) : zero);
294           multiz t = multiz_new_add(z, y);
295           multiz_free(y);
296           multiz_free(z);
297           z = t;
298         }
299         darray_append(x->a, z);
300       }
301       multiz_free(zero);
302       return x;
303     }
304   }
305 }
// Field hook: n = a * b. The product is built before n's old value is
// freed, so n may alias a or b.
static void f_mul(element_ptr n, element_ptr a, element_ptr b) {
  multiz product = multiz_new_mul(a->data, b->data);
  multiz_free(n->data);
  n->data = product;
}
311 
// Field hook: n = a * z, scaling every integer leaf of a by z.
static void f_mul_mpz(element_ptr n, element_ptr a, mpz_ptr z) {
  void (*scale)(mpz_t, const mpz_t, void *) =
      (void(*)(mpz_t, const mpz_t, void *))mpzmul;
  multiz scaled = multiz_new_unary(a->data, scale, z);
  multiz_free(n->data);
  n->data = scaled;
}
317 
// Unary callback for multiz_new_unary: x = y * (*i), where i points to a
// signed long passed through the generic void * scope slot.
// The parameter is void * so the function's real type matches the
// (mpz_t, const mpz_t, void *) pointer type it is called through. Calling
// through an incompatible function type is undefined behavior.
static void mulsi(mpz_t x, const mpz_t y, void *i) {
  mpz_mul_si(x, y, *(const signed long *) i);
}
321 
// Field hook: n = a * z for a machine integer z, scaling every leaf of a.
static void f_mul_si(element_ptr n, element_ptr a, signed long int z) {
  void (*scale)(mpz_t, const mpz_t, void *) =
      (void(*)(mpz_t, const mpz_t, void *))mulsi;
  multiz scaled = multiz_new_unary(a->data, scale, &z);
  multiz_free(n->data);
  n->data = scaled;
}
327 
// Unary callback for multiz_new_unary: dst = -src; scope_ptr is unused.
static void mpzneg(mpz_t dst, const mpz_t src, void *scope_ptr) {
  (void) scope_ptr;
  mpz_neg(dst, src);
}
332 
// Returns a new multiz with every integer leaf of z negated.
static multiz multiz_new_neg(multiz z) {
  void (*negate)(mpz_t, const mpz_t, void *) = mpzneg;
  return multiz_new_unary(z, negate, NULL);
}
336 
// Field hook: n = deep copy of a. Safe when n and a are the same element.
static void f_set(element_ptr n, element_ptr a) {
  multiz copy = multiz_clone(a->data);
  multiz_free(n->data);
  n->data = copy;
}
342 
// Field hook: n = -a. Safe when n and a are the same element.
static void f_neg(element_ptr n, element_ptr a) {
  multiz negated = multiz_new_neg(a->data);
  multiz_free(n->data);
  n->data = negated;
}
348 
f_div(element_ptr c,element_ptr a,element_ptr b)349 static void f_div(element_ptr c, element_ptr a, element_ptr b) {
350   mpz_t d;
351   mpz_init(d);
352   element_to_mpz(d, b);
353   multiz delme = c->data;
354   c->data = multiz_new_unary(a->data, (void(*)(mpz_t, const mpz_t, void *))mpz_tdiv_q, d);
355   mpz_clear(d);
356   multiz_free(delme);
357 }
358 
359 // Doesn't make sense if order is infinite.
f_random(element_ptr n)360 static void f_random(element_ptr n) {
361   multiz delme = n->data;
362   f_init(n);
363   multiz_free(delme);
364 }
365 
// Field hook: sets n to the non-negative integer whose little-endian byte
// encoding is the len bytes at data.
static void f_from_hash(element_ptr n, void *data, int len) {
  mpz_t value;
  mpz_init(value);
  // order -1: least significant byte first; size 1; nails 0.
  mpz_import(value, len, -1, 1, -1, 0, data);
  f_set_mpz(n, value);
  mpz_clear(value);
}
373 
// Field hook: 1 if n is the bare integer 1, else 0. Lists never count,
// even [1] (see the canonicalization TODO at the top of the file).
static int f_is1(element_ptr n) {
  multiz ep = n->data;
  if (T_MPZ != ep->type) return 0;
  return mpz_cmp_ui(ep->z, 1) == 0;
}
378 
// Returns 1 if m is the bare integer 0, else 0 (lists such as [0] are not
// treated as zero).
int multiz_is0(multiz m) {
  if (T_MPZ != m->type) return 0;
  return mpz_is0(m->z) ? 1 : 0;
}
382 
// Field hook: whether n is the bare integer 0 (see multiz_is0).
static int f_is0(element_ptr n) {
  multiz data = n->data;
  return multiz_is0(data);
}
386 
// Field hook: number of list entries in e, or 0 when e is an integer.
static int f_item_count(element_ptr e) {
  multiz data = e->data;
  return T_MPZ == data->type ? 0 : darray_count(data->a);
}
392 
393 // TODO: Redesign multiz so this doesn't leak.
f_item(element_ptr e,int i)394 static element_ptr f_item(element_ptr e, int i) {
395   multiz z = e->data;
396   if (T_MPZ == z->type) return NULL;
397   element_ptr r = malloc(sizeof(*r));
398   r->field = e->field;
399   r->data = darray_at(z->a, i);
400   return r;
401 }
402 
403 // Usual meaning when both are integers.
404 // Otherwise, compare coefficients.
multiz_cmp(multiz a,multiz b)405 static int multiz_cmp(multiz a, multiz b) {
406   if (T_MPZ == a->type) {
407     if (T_MPZ == b->type) {
408       // Simplest case: both are integers.
409       return mpz_cmp(a->z, b->z);
410     }
411     // Leading coefficient of b.
412     while(T_ARR == b->type) b = darray_last(b->a);
413     PBC_ASSERT(T_MPZ == b->type, "no such type");
414     return -mpz_sgn(b->z);
415   }
416   PBC_ASSERT(T_ARR == a->type, "no such type");
417   if (T_MPZ == b->type) {
418     // Leading coefficient of a.
419     while(T_ARR == a->type) a = darray_last(a->a);
420     PBC_ASSERT(T_MPZ == a->type, "no such type");
421     return mpz_sgn(a->z);
422   }
423   PBC_ASSERT(T_ARR == b->type, "no such type");
424   int m = darray_count(a->a);
425   int n = darray_count(b->a);
426   if (m > n) {
427     // Leading coefficient of a.
428     while(T_ARR == a->type) a = darray_last(a->a);
429     PBC_ASSERT(T_MPZ == a->type, "no such type");
430     return mpz_sgn(a->z);
431   }
432   if (n > m) {
433     // Leading coefficient of b.
434     while(T_ARR == b->type) b = darray_last(b->a);
435     PBC_ASSERT(T_MPZ == b->type, "no such type");
436     return -mpz_sgn(b->z);
437   }
438   for(n--; n >= 0; n--) {
439     int i = multiz_cmp(darray_at(a->a, n), darray_at(b->a, n));
440     if (i) return i;
441   }
442   return 0;
443 }
// Field hook: three-way comparison of x and y via multiz_cmp.
static int f_cmp(element_ptr x, element_ptr y) {
  multiz lhs = x->data;
  multiz rhs = y->data;
  return multiz_cmp(lhs, rhs);
}
447 
// Field hook: nothing to release, since field_init_multiz leaves f->data NULL.
static void f_field_clear(field_t f) {
  UNUSED_VAR(f);
}
449 
450 // OpenSSL convention:
451 //   4 bytes containing length
452 //   followed by number in big-endian, most-significant bit set if negative
453 //   (prepending null byte if necessary)
454 // Positive numbers also the same as mpz_out_raw.
z_to_bytes(unsigned char * data,element_t e)455 static int z_to_bytes(unsigned char *data, element_t e) {
456   mpz_ptr z = e->data;
457   size_t msb = mpz_sizeinbase(z, 2);
458   size_t n = 4;
459   size_t i;
460 
461   if (!(msb % 8)) {
462     data[4] = 0;
463     n++;
464   }
465   if (mpz_sgn(z) < 0) {
466     mpz_export(data + n, NULL, 1, 1, 1, 0, z);
467     data[4] |= 128;
468   } else {
469     mpz_export(data + n, NULL, 1, 1, 1, 0, z);
470   }
471   n += (msb + 7) / 8 - 4;
472   for (i=0; i<4; i++) {
473     data[i] = (n >> 8 * (3 - i));
474   }
475   n += 4;
476 
477   return n;
478 }
479 
// Reads an integer in the z_to_bytes format: a 4-byte big-endian payload
// length, then a big-endian magnitude whose top bit is the sign. Stores the
// integer in e and returns the number of bytes consumed (4 + payload
// length), matching what z_to_bytes returns.
// Unlike the previous version, this does not modify the input buffer, does
// not read data[4] for an empty payload, and writes through the multiz API
// instead of treating e->data as an mpz.
// NOTE(review): the length prefix is trusted; the caller must guarantee the
// buffer really holds that many bytes.
static int z_from_bytes(element_t e, unsigned char *data) {
  size_t n = 0;
  for (size_t i = 0; i < 4; i++) {
    n = (n << 8) | data[i];
  }
  mpz_t z;
  mpz_init(z);
  if (n) {
    mpz_import(z, n, 1, 1, 1, 0, data + 4);
    if (data[4] & 128) {
      // Strip the sign bit (top bit of the first payload byte), then negate.
      mpz_clrbit(z, 8 * n - 1);
      mpz_neg(z, z);
    }
  }
  f_set_mpz(e, z);
  mpz_clear(z);
  return (int) (n + 4);
}
510 
// Size in bytes of the z_to_bytes encoding of a's integer value (the
// constant term reported by element_to_mpz). The encoding has three parts:
//   - a 4-byte length prefix;
//   - the big-endian magnitude, at least one byte, so zero needs one;
//   - one pad byte when the magnitude fills its top byte, since z_to_bytes
//     prepends a null byte then.
// The previous version omitted the pad byte, under-sizing buffers by one,
// and read the multiz as an mpz.
static int z_length_in_bytes(element_ptr a) {
  mpz_t z;
  mpz_init(z);
  multiz_to_mpz(z, a->data);
  size_t msb = mpz_sizeinbase(z, 2);  // 1 for zero
  mpz_clear(z);
  size_t payload = (msb + 7) / 8 + (msb % 8 == 0);
  return (int) (4 + payload);
}
514 
// Field hook: writes a short description of this field to out.
static void f_out_info(FILE *out, field_ptr f) {
  UNUSED_VAR(f);
  fputs("Z multinomials", out);
}
519 
// Field hook: parses s as an integer in the given base with
// pbc_mpz_set_str, stores it in e, and returns pbc_mpz_set_str's result.
// TODO: Square brackets.
static int f_set_str(element_ptr e, const char *s, int base) {
  mpz_t parsed;
  mpz_init(parsed);
  int ret = pbc_mpz_set_str(parsed, s, base);
  f_set_mpz(e, parsed);
  mpz_clear(parsed);
  return ret;
}
529 
// Field hook: sets e to a deep copy of m. m may be e's own data, because
// the copy is made before the old value is freed.
static void f_set_multiz(element_ptr e, multiz m) {
  multiz copy = multiz_clone(m);
  multiz_free(e->data);
  e->data = copy;
}
535 
// Initializes f as the ring of multinomials over Z. field_init() supplies
// the defaults and the hooks below override them. The field order is set
// to 0, no per-field data is kept, and elements have no fixed serialized
// length.
void field_init_multiz(field_ptr f) {
  field_init(f);

  // Element and field lifetime.
  f->init = f_init;
  f->clear = f_clear;
  f->field_clear = f_field_clear;

  // Assignment.
  f->set = f_set;
  f->set0 = f_set0;
  f->set1 = f_set1;
  f->set_si = f_set_si;
  f->set_mpz = f_set_mpz;
  f->set_multiz = f_set_multiz;
  f->set_str = f_set_str;
  f->random = f_random;
  f->from_hash = f_from_hash;

  // Arithmetic.
  f->add = f_add;
  f->sub = f_sub;
  f->neg = f_neg;
  f->mul = f_mul;
  f->mul_mpz = f_mul_mpz;
  f->mul_si = f_mul_si;
  f->div = f_div;

  // Predicates, comparison and structure.
  f->is0 = f_is0;
  f->is1 = f_is1;
  f->sign = f_sgn;
  f->cmp = f_cmp;
  f->item = f_item;
  f->item_count = f_item_count;

  // Conversion, serialization and I/O.
  f->to_mpz = f_to_mpz;
  f->out_str = f_out_str;
  f->out_info = f_out_info;
  f->to_bytes = z_to_bytes;
  f->from_bytes = z_from_bytes;
  f->length_in_bytes = z_length_in_bytes;

  mpz_set_ui(f->order, 0);
  f->data = NULL;
  f->fixed_length_in_bytes = -1;
}
575 
// Returns nonzero iff m is a bare integer rather than a list.
int multiz_is_z(multiz m) {
  return m->type == T_MPZ;
}
579 
// Number of entries in the list m, or -1 when m is not a list.
int multiz_count(multiz m) {
  return T_ARR == m->type ? darray_count(m->a) : -1;
}
584 
// Returns entry i of the list m. The entry is not a copy; m still owns it.
// Requires m to be a list and 0 <= i < multiz_count(m). Both conditions are
// checked via PBC_ASSERT (presumably only in debug builds; confirm in
// pbc_utils.h). Previously a negative i passed the bounds assertion.
multiz multiz_at(multiz m, int i) {
  PBC_ASSERT(T_ARR == m->type, "wrong type");
  PBC_ASSERT(0 <= i && darray_count(m->a) > i, "out of bounds");
  return darray_at(m->a, i);
}
590