1 /* Program for computing integer expressions using the GNU Multiple Precision
2    Arithmetic Library.
3 
4 Copyright 1997, 1999-2002, 2005, 2008, 2012, 2015 Free Software Foundation, Inc.
5 
6 This program is free software; you can redistribute it and/or modify it under
7 the terms of the GNU General Public License as published by the Free Software
8 Foundation; either version 3 of the License, or (at your option) any later
9 version.
10 
11 This program is distributed in the hope that it will be useful, but WITHOUT ANY
12 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
13 PARTICULAR PURPOSE.  See the GNU General Public License for more details.
14 
15 You should have received a copy of the GNU General Public License along with
16 this program.  If not, see https://www.gnu.org/licenses/.  */
17 
18 
/* This expression evaluator works by building an expression tree (using a
20    recursive descent parser) which is then evaluated.  The expression tree is
21    useful since we want to optimize certain expressions (like a^b % c).
22 
23    Usage: pexpr [options] expr ...
24    (Assuming you called the executable `pexpr' of course.)
25 
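   Command line options:

   -b        print output in binary
   -o        print output in octal
   -d        print output in decimal (the default)
   -x        print output in hexadecimal
   -X        print output in hexadecimal, using upper-case digits
   -b<NUM>   print output in base NUM (2 to 62)
   -t        print timing information
   -v        print the version of GMP pexpr is linked to
   -html     output html
   -wml      output wml
   -split    split long output lines after every 80 digits
   -noprint  do not print results, only their approximate number of digits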
37 */
38 
39 /* Define LIMIT_RESOURCE_USAGE if you want to make sure the program doesn't
40    use up extensive resources (cpu, memory).  Useful for the GMP demo on the
41    GMP web site, since we cannot load the server too much.  */
42 
43 #include "pexpr-config.h"
44 
45 #include <string.h>
46 #include <stdio.h>
47 #include <stdlib.h>
48 #include <setjmp.h>
49 #include <signal.h>
50 #include <ctype.h>
51 
52 #include <time.h>
53 #include <sys/types.h>
54 #include <sys/time.h>
55 #if HAVE_SYS_RESOURCE_H
56 #include <sys/resource.h>
57 #endif
58 
59 #include "gmp.h"
60 
/* SunOS 4 and HPUX 9 don't define a canonical SIGSTKSZ, use a default. */
#ifndef SIGSTKSZ
#define SIGSTKSZ  4096
#endif


/* Execute FUNC (a statement or expression) and store in T the user CPU
   time it took, in the milliseconds returned by cputime ().  The macro's
   own local stays out of the reserved double-underscore namespace.  Its
   name is also chosen so it will not clash with anything FUNC refers to.  */
#define TIME(t,func)							\
  do { int pexpr_time_start_;						\
    pexpr_time_start_ = cputime ();					\
    {func;}								\
    (t) = cputime () - pexpr_time_start_;				\
  } while (0)
74 
75 /* GMP version 1.x compatibility.  */
76 #if ! (__GNU_MP_VERSION >= 2)
77 typedef MP_INT __mpz_struct;
78 typedef __mpz_struct mpz_t[1];
79 typedef __mpz_struct *mpz_ptr;
80 #define mpz_fdiv_q	mpz_div
81 #define mpz_fdiv_r	mpz_mod
82 #define mpz_tdiv_q_2exp	mpz_div_2exp
83 #define mpz_sgn(Z) ((Z)->size < 0 ? -1 : (Z)->size > 0)
84 #endif
85 
86 /* GMP version 2.0 compatibility.  */
87 #if ! (__GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1)
88 #define mpz_swap(a,b) \
89   do { __mpz_struct __t; __t = *a; *a = *b; *b = __t;} while (0)
90 #endif
91 
/* Error recovery point.  main () arms it with setjmp before parsing and
   again before evaluating each expression.  On error, the parser and the
   evaluator set `error' and longjmp here.  Parse errors pass the position
   of the error (a char pointer converted to int) as the longjmp value;
   evaluation errors pass 1.  */
jmp_buf errjmpbuf;

/* Node kinds of the expression tree.  NOP only terminates the fns[] table;
   LIT nodes hold an integer literal.  */
enum op_t {NOP, LIT, NEG, NOT, PLUS, MINUS, MULT, DIV, MOD, REM, INVMOD, POW,
	   AND, IOR, XOR, SLL, SRA, POPCNT, HAMDIST, GCD, LCM, SQRT, ROOT, FAC,
	   LOG, LOG2, FERMAT, MERSENNE, FIBONACCI, RANDOM, NEXTPRIME, BINOM,
	   TIMING};

/* Type for the expression tree.  A LIT node stores its value in
   operands.val.  Every other node uses operands.ops, where rhs is NULL for
   unary operators.  */
struct expr
{
  enum op_t op;
  union
  {
    struct {struct expr *lhs, *rhs;} ops;
    mpz_t val;
  } operands;
};

typedef struct expr *expr_t;

/* Handler that setup_error_handler installs for fatal signals.  */
void cleanup_and_exit (int);

/* Recursive-descent parser (expr > term > power > factor) and its helpers.
   The parse functions store the tree through their expr_t * argument and
   return a pointer just past the text they consumed.  */
char *skipspace (char *);
void makeexp (expr_t *, enum op_t, expr_t, expr_t);
void free_expr (expr_t);
char *expr (char *, expr_t *);
char *term (char *, expr_t *);
char *power (char *, expr_t *);
char *factor (char *, expr_t *);
int match (char *, char *);
int matchp (char *, char *);
int cputime (void);

/* Tree evaluators: store the value of the tree in the first argument.  The
   _mod_ variant is used for the `mod' operator and presumably reduces
   modulo its third argument (see its definition).  */
void mpz_eval_expr (mpz_ptr, expr_t);
void mpz_eval_mod_expr (mpz_ptr, expr_t, mpz_ptr);

/* Message describing the most recent parse or evaluation error.  */
char *error;
/* Cleared by -noprint: report only the approximate size of each result.  */
int flag_print = 1;
/* Set by -t.  */
int print_timing = 0;
/* Set by -html.  */
int flag_html = 0;
/* Set by -wml.  */
int flag_wml = 0;
/* Set by -split: break output into 80-digit lines.  */
int flag_splitup_output = 0;
/* Markup appended at the end of output lines: "<br>" for -html, "<br/>"
   for -wml.  */
char *newline = "";
/* Random state used by random(); seeded in main.  */
gmp_randstate_t rstate;
136 
137 
138 
/* cputime() returns user CPU time measured in milliseconds.  The
   configuration selects one implementation, in order of preference:
   getrusage, clock, or a stub returning 0 (so -t then reports 0 ms).
   Nothing is defined here when HAVE_CPUTIME is set.  If the underlying
   call fails, 0 is returned rather than garbage.  */
#if ! HAVE_CPUTIME
#if HAVE_GETRUSAGE
int
cputime (void)
{
  struct rusage rus;

  /* Use RUSAGE_SELF rather than a literal 0; POSIX only names the
     constant and does not fix its value.  */
  if (getrusage (RUSAGE_SELF, &rus) != 0)
    return 0;
  return rus.ru_utime.tv_sec * 1000 + rus.ru_utime.tv_usec / 1000;
}
#else
#if HAVE_CLOCK
int
cputime (void)
{
  clock_t ticks = clock ();

  /* clock () reports an unavailable processor time as (clock_t) -1.  */
  if (ticks == (clock_t) -1)
    return 0;
  /* For slow tick rates multiply first to keep precision.  For fast ones
     divide first to keep the intermediate value small.  */
  if (CLOCKS_PER_SEC < 100000)
    return ticks * 1000 / CLOCKS_PER_SEC;
  return ticks / (CLOCKS_PER_SEC / 1000);
}
#else
int
cputime (void)
{
  return 0;
}
#endif
#endif
#endif
168 
169 
/* Helper for stack_downwards_p: return 1 if a local of this (callee) frame
   lies at a lower address than XP, a local of the caller's frame, else 0.
   The addresses are compared as integers.  A relational comparison of
   pointers to distinct objects is undefined behaviour in C, whereas a
   pointer-to-integer conversion is only implementation-defined.  */
int
stack_downwards_helper (char *xp)
{
  char y;
  return (unsigned long) &y < (unsigned long) xp;
}

/* Return nonzero if the machine stack appears to grow towards lower
   addresses.  The probe compares a local of this function with one in a
   called function.  */
int
stack_downwards_p (void)
{
  /* Call through a volatile function pointer so the compiler cannot inline
     the helper into this frame, which would make the comparison
     meaningless.  */
  int (*volatile helper) (char *) = stack_downwards_helper;
  char x;
  return helper (&x);
}
182 
183 
184 void
setup_error_handler(void)185 setup_error_handler (void)
186 {
187 #if HAVE_SIGACTION
188   struct sigaction act;
189   act.sa_handler = cleanup_and_exit;
190   sigemptyset (&(act.sa_mask));
191 #define SIGNAL(sig)  sigaction (sig, &act, NULL)
192 #else
193   struct { int sa_flags; } act;
194 #define SIGNAL(sig)  signal (sig, cleanup_and_exit)
195 #endif
196   act.sa_flags = 0;
197 
198   /* Set up a stack for signal handling.  A typical cause of error is stack
199      overflow, and in such situation a signal can not be delivered on the
200      overflown stack.  */
201 #if HAVE_SIGALTSTACK
202   {
203     /* AIX uses stack_t, MacOS uses struct sigaltstack, various other
204        systems have both. */
205 #if HAVE_STACK_T
206     stack_t s;
207 #else
208     struct sigaltstack s;
209 #endif
210     s.ss_sp = malloc (SIGSTKSZ);
211     s.ss_size = SIGSTKSZ;
212     s.ss_flags = 0;
213     if (sigaltstack (&s, NULL) != 0)
214       perror("sigaltstack");
215     act.sa_flags = SA_ONSTACK;
216   }
217 #else
218 #if HAVE_SIGSTACK
219   {
220     struct sigstack s;
221     s.ss_sp = malloc (SIGSTKSZ);
222     if (stack_downwards_p ())
223       s.ss_sp += SIGSTKSZ;
224     s.ss_onstack = 0;
225     if (sigstack (&s, NULL) != 0)
226       perror("sigstack");
227     act.sa_flags = SA_ONSTACK;
228   }
229 #else
230 #endif
231 #endif
232 
233 #ifdef LIMIT_RESOURCE_USAGE
234   {
235     struct rlimit limit;
236 
237     limit.rlim_cur = limit.rlim_max = 0;
238     setrlimit (RLIMIT_CORE, &limit);
239 
240     limit.rlim_cur = 3;
241     limit.rlim_max = 4;
242     setrlimit (RLIMIT_CPU, &limit);
243 
244     limit.rlim_cur = limit.rlim_max = 16 * 1024 * 1024;
245     setrlimit (RLIMIT_DATA, &limit);
246 
247     getrlimit (RLIMIT_STACK, &limit);
248     limit.rlim_cur = 4 * 1024 * 1024;
249     setrlimit (RLIMIT_STACK, &limit);
250 
251     SIGNAL (SIGXCPU);
252   }
253 #endif /* LIMIT_RESOURCE_USAGE */
254 
255   SIGNAL (SIGILL);
256   SIGNAL (SIGSEGV);
257 #ifdef SIGBUS /* not in mingw */
258   SIGNAL (SIGBUS);
259 #endif
260   SIGNAL (SIGFPE);
261   SIGNAL (SIGABRT);
262 }
263 
264 int
main(int argc,char ** argv)265 main (int argc, char **argv)
266 {
267   struct expr *e;
268   int i;
269   mpz_t r;
270   int errcode = 0;
271   char *str;
272   int base = 10;
273 
274   setup_error_handler ();
275 
276   gmp_randinit (rstate, GMP_RAND_ALG_LC, 128);
277 
278   {
279 #if HAVE_GETTIMEOFDAY
280     struct timeval tv;
281     gettimeofday (&tv, NULL);
282     gmp_randseed_ui (rstate, tv.tv_sec + tv.tv_usec);
283 #else
284     time_t t;
285     time (&t);
286     gmp_randseed_ui (rstate, t);
287 #endif
288   }
289 
290   mpz_init (r);
291 
292   while (argc > 1 && argv[1][0] == '-')
293     {
294       char *arg = argv[1];
295 
296       if (arg[1] >= '0' && arg[1] <= '9')
297 	break;
298 
299       if (arg[1] == 't')
300 	print_timing = 1;
301       else if (arg[1] == 'b' && arg[2] >= '0' && arg[2] <= '9')
302 	{
303 	  base = atoi (arg + 2);
304 	  if (base < 2 || base > 62)
305 	    {
306 	      fprintf (stderr, "error: invalid output base\n");
307 	      exit (-1);
308 	    }
309 	}
310       else if (arg[1] == 'b' && arg[2] == 0)
311 	base = 2;
312       else if (arg[1] == 'x' && arg[2] == 0)
313 	base = 16;
314       else if (arg[1] == 'X' && arg[2] == 0)
315 	base = -16;
316       else if (arg[1] == 'o' && arg[2] == 0)
317 	base = 8;
318       else if (arg[1] == 'd' && arg[2] == 0)
319 	base = 10;
320       else if (arg[1] == 'v' && arg[2] == 0)
321 	{
322 	  printf ("pexpr linked to gmp %s\n", __gmp_version);
323 	}
324       else if (strcmp (arg, "-html") == 0)
325 	{
326 	  flag_html = 1;
327 	  newline = "<br>";
328 	}
329       else if (strcmp (arg, "-wml") == 0)
330 	{
331 	  flag_wml = 1;
332 	  newline = "<br/>";
333 	}
334       else if (strcmp (arg, "-split") == 0)
335 	{
336 	  flag_splitup_output = 1;
337 	}
338       else if (strcmp (arg, "-noprint") == 0)
339 	{
340 	  flag_print = 0;
341 	}
342       else
343 	{
344 	  fprintf (stderr, "error: unknown option `%s'\n", arg);
345 	  exit (-1);
346 	}
347       argv++;
348       argc--;
349     }
350 
351   for (i = 1; i < argc; i++)
352     {
353       int s;
354       int jmpval;
355 
356       /* Set up error handler for parsing expression.  */
357       jmpval = setjmp (errjmpbuf);
358       if (jmpval != 0)
359 	{
360 	  fprintf (stderr, "error: %s%s\n", error, newline);
361 	  fprintf (stderr, "       %s%s\n", argv[i], newline);
362 	  if (! flag_html)
363 	    {
364 	      /* ??? Dunno how to align expression position with arrow in
365 		 HTML ??? */
366 	      fprintf (stderr, "       ");
367 	      for (s = jmpval - (long) argv[i]; --s >= 0; )
368 		putc (' ', stderr);
369 	      fprintf (stderr, "^\n");
370 	    }
371 
372 	  errcode |= 1;
373 	  continue;
374 	}
375 
376       str = expr (argv[i], &e);
377 
378       if (str[0] != 0)
379 	{
380 	  fprintf (stderr,
381 		   "error: garbage where end of expression expected%s\n",
382 		   newline);
383 	  fprintf (stderr, "       %s%s\n", argv[i], newline);
384 	  if (! flag_html)
385 	    {
386 	      /* ??? Dunno how to align expression position with arrow in
387 		 HTML ??? */
388 	      fprintf (stderr, "        ");
389 	      for (s = str - argv[i]; --s; )
390 		putc (' ', stderr);
391 	      fprintf (stderr, "^\n");
392 	    }
393 
394 	  errcode |= 1;
395 	  free_expr (e);
396 	  continue;
397 	}
398 
399       /* Set up error handler for evaluating expression.  */
400       if (setjmp (errjmpbuf))
401 	{
402 	  fprintf (stderr, "error: %s%s\n", error, newline);
403 	  fprintf (stderr, "       %s%s\n", argv[i], newline);
404 	  if (! flag_html)
405 	    {
406 	      /* ??? Dunno how to align expression position with arrow in
407 		 HTML ??? */
408 	      fprintf (stderr, "       ");
409 	      for (s = str - argv[i]; --s >= 0; )
410 		putc (' ', stderr);
411 	      fprintf (stderr, "^\n");
412 	    }
413 
414 	  errcode |= 2;
415 	  continue;
416 	}
417 
418       if (print_timing)
419 	{
420 	  int t;
421 	  TIME (t, mpz_eval_expr (r, e));
422 	  printf ("computation took %d ms%s\n", t, newline);
423 	}
424       else
425 	mpz_eval_expr (r, e);
426 
427       if (flag_print)
428 	{
429 	  size_t out_len;
430 	  char *tmp, *s;
431 
432 	  out_len = mpz_sizeinbase (r, base >= 0 ? base : -base) + 2;
433 #ifdef LIMIT_RESOURCE_USAGE
434 	  if (out_len > 100000)
435 	    {
436 	      printf ("result is about %ld digits, not printing it%s\n",
437 		      (long) out_len - 3, newline);
438 	      exit (-2);
439 	    }
440 #endif
441 	  tmp = malloc (out_len);
442 
443 	  if (print_timing)
444 	    {
445 	      int t;
446 	      printf ("output conversion ");
447 	      TIME (t, mpz_get_str (tmp, base, r));
448 	      printf ("took %d ms%s\n", t, newline);
449 	    }
450 	  else
451 	    mpz_get_str (tmp, base, r);
452 
453 	  out_len = strlen (tmp);
454 	  if (flag_splitup_output)
455 	    {
456 	      for (s = tmp; out_len > 80; s += 80)
457 		{
458 		  fwrite (s, 1, 80, stdout);
459 		  printf ("%s\n", newline);
460 		  out_len -= 80;
461 		}
462 
463 	      fwrite (s, 1, out_len, stdout);
464 	    }
465 	  else
466 	    {
467 	      fwrite (tmp, 1, out_len, stdout);
468 	    }
469 
470 	  free (tmp);
471 	  printf ("%s\n", newline);
472 	}
473       else
474 	{
475 	  printf ("result is approximately %ld digits%s\n",
476 		  (long) mpz_sizeinbase (r, base >= 0 ? base : -base),
477 		  newline);
478 	}
479 
480       free_expr (e);
481     }
482 
483   mpz_clear (r);
484 
485   exit (errcode);
486 }
487 
488 char *
expr(char * str,expr_t * e)489 expr (char *str, expr_t *e)
490 {
491   expr_t e2;
492 
493   str = skipspace (str);
494   if (str[0] == '+')
495     {
496       str = term (str + 1, e);
497     }
498   else if (str[0] == '-')
499     {
500       str = term (str + 1, e);
501       makeexp (e, NEG, *e, NULL);
502     }
503   else if (str[0] == '~')
504     {
505       str = term (str + 1, e);
506       makeexp (e, NOT, *e, NULL);
507     }
508   else
509     {
510       str = term (str, e);
511     }
512 
513   for (;;)
514     {
515       str = skipspace (str);
516       switch (str[0])
517 	{
518 	case 'p':
519 	  if (match ("plus", str))
520 	    {
521 	      str = term (str + 4, &e2);
522 	      makeexp (e, PLUS, *e, e2);
523 	    }
524 	  else
525 	    return str;
526 	  break;
527 	case 'm':
528 	  if (match ("minus", str))
529 	    {
530 	      str = term (str + 5, &e2);
531 	      makeexp (e, MINUS, *e, e2);
532 	    }
533 	  else
534 	    return str;
535 	  break;
536 	case '+':
537 	  str = term (str + 1, &e2);
538 	  makeexp (e, PLUS, *e, e2);
539 	  break;
540 	case '-':
541 	  str = term (str + 1, &e2);
542 	  makeexp (e, MINUS, *e, e2);
543 	  break;
544 	default:
545 	  return str;
546 	}
547     }
548 }
549 
550 char *
term(char * str,expr_t * e)551 term (char *str, expr_t *e)
552 {
553   expr_t e2;
554 
555   str = power (str, e);
556   for (;;)
557     {
558       str = skipspace (str);
559       switch (str[0])
560 	{
561 	case 'm':
562 	  if (match ("mul", str))
563 	    {
564 	      str = power (str + 3, &e2);
565 	      makeexp (e, MULT, *e, e2);
566 	      break;
567 	    }
568 	  if (match ("mod", str))
569 	    {
570 	      str = power (str + 3, &e2);
571 	      makeexp (e, MOD, *e, e2);
572 	      break;
573 	    }
574 	  return str;
575 	case 'd':
576 	  if (match ("div", str))
577 	    {
578 	      str = power (str + 3, &e2);
579 	      makeexp (e, DIV, *e, e2);
580 	      break;
581 	    }
582 	  return str;
583 	case 'r':
584 	  if (match ("rem", str))
585 	    {
586 	      str = power (str + 3, &e2);
587 	      makeexp (e, REM, *e, e2);
588 	      break;
589 	    }
590 	  return str;
591 	case 'i':
592 	  if (match ("invmod", str))
593 	    {
594 	      str = power (str + 6, &e2);
595 	      makeexp (e, REM, *e, e2);
596 	      break;
597 	    }
598 	  return str;
599 	case 't':
600 	  if (match ("times", str))
601 	    {
602 	      str = power (str + 5, &e2);
603 	      makeexp (e, MULT, *e, e2);
604 	      break;
605 	    }
606 	  if (match ("thru", str))
607 	    {
608 	      str = power (str + 4, &e2);
609 	      makeexp (e, DIV, *e, e2);
610 	      break;
611 	    }
612 	  if (match ("through", str))
613 	    {
614 	      str = power (str + 7, &e2);
615 	      makeexp (e, DIV, *e, e2);
616 	      break;
617 	    }
618 	  return str;
619 	case '*':
620 	  str = power (str + 1, &e2);
621 	  makeexp (e, MULT, *e, e2);
622 	  break;
623 	case '/':
624 	  str = power (str + 1, &e2);
625 	  makeexp (e, DIV, *e, e2);
626 	  break;
627 	case '%':
628 	  str = power (str + 1, &e2);
629 	  makeexp (e, MOD, *e, e2);
630 	  break;
631 	default:
632 	  return str;
633 	}
634     }
635 }
636 
637 char *
power(char * str,expr_t * e)638 power (char *str, expr_t *e)
639 {
640   expr_t e2;
641 
642   str = factor (str, e);
643   while (str[0] == '!')
644     {
645       str++;
646       makeexp (e, FAC, *e, NULL);
647     }
648   str = skipspace (str);
649   if (str[0] == '^')
650     {
651       str = power (str + 1, &e2);
652       makeexp (e, POW, *e, e2);
653     }
654   return str;
655 }
656 
657 int
match(char * s,char * str)658 match (char *s, char *str)
659 {
660   char *ostr = str;
661   int i;
662 
663   for (i = 0; s[i] != 0; i++)
664     {
665       if (str[i] != s[i])
666 	return 0;
667     }
668   str = skipspace (str + i);
669   return str - ostr;
670 }
671 
/* If STR begins with the word S followed (after optional spaces) by '(',
   return the number of characters up to and including that '('.
   Otherwise return 0.  */
int
matchp (char *s, char *str)
{
  size_t wordlen = strlen (s);
  char *after;

  if (strncmp (str, s, wordlen) != 0)
    return 0;
  after = skipspace (str + wordlen);
  return after[0] == '(' ? (int) (after - str) + 1 : 0;
}
688 
/* Entry in fns[], the table of functions callable as NAME(args).  */
struct functions
{
  /* Function name as written in expressions.  */
  char *spelling;
  /* Operator of the tree node built for a call.  */
  enum op_t op;
  int arity; /* 1 or 2 means real arity; 0 means two or more arguments,
		combined left to right: f(a,b,c) = f(f(a,b),c).  */
};
695 
696 struct functions fns[] =
697 {
698   {"sqrt", SQRT, 1},
699 #if __GNU_MP_VERSION >= 2
700   {"root", ROOT, 2},
701   {"popc", POPCNT, 1},
702   {"hamdist", HAMDIST, 2},
703 #endif
704   {"gcd", GCD, 0},
705 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
706   {"lcm", LCM, 0},
707 #endif
708   {"and", AND, 0},
709   {"ior", IOR, 0},
710 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
711   {"xor", XOR, 0},
712 #endif
713   {"plus", PLUS, 0},
714   {"pow", POW, 2},
715   {"minus", MINUS, 2},
716   {"mul", MULT, 0},
717   {"div", DIV, 2},
718   {"mod", MOD, 2},
719   {"rem", REM, 2},
720 #if __GNU_MP_VERSION >= 2
721   {"invmod", INVMOD, 2},
722 #endif
723   {"log", LOG, 2},
724   {"log2", LOG2, 1},
725   {"F", FERMAT, 1},
726   {"M", MERSENNE, 1},
727   {"fib", FIBONACCI, 1},
728   {"Fib", FIBONACCI, 1},
729   {"random", RANDOM, 1},
730   {"nextprime", NEXTPRIME, 1},
731   {"binom", BINOM, 2},
732   {"binomial", BINOM, 2},
733   {"fac", FAC, 1},
734   {"fact", FAC, 1},
735   {"factorial", FAC, 1},
736   {"time", TIMING, 1},
737   {"", NOP, 0}
738 };
739 
740 char *
factor(char * str,expr_t * e)741 factor (char *str, expr_t *e)
742 {
743   expr_t e1, e2;
744 
745   str = skipspace (str);
746 
747   if (isalpha (str[0]))
748     {
749       int i;
750       int cnt;
751 
752       for (i = 0; fns[i].op != NOP; i++)
753 	{
754 	  if (fns[i].arity == 1)
755 	    {
756 	      cnt = matchp (fns[i].spelling, str);
757 	      if (cnt != 0)
758 		{
759 		  str = expr (str + cnt, &e1);
760 		  str = skipspace (str);
761 		  if (str[0] != ')')
762 		    {
763 		      error = "expected `)'";
764 		      longjmp (errjmpbuf, (int) (long) str);
765 		    }
766 		  makeexp (e, fns[i].op, e1, NULL);
767 		  return str + 1;
768 		}
769 	    }
770 	}
771 
772       for (i = 0; fns[i].op != NOP; i++)
773 	{
774 	  if (fns[i].arity != 1)
775 	    {
776 	      cnt = matchp (fns[i].spelling, str);
777 	      if (cnt != 0)
778 		{
779 		  str = expr (str + cnt, &e1);
780 		  str = skipspace (str);
781 
782 		  if (str[0] != ',')
783 		    {
784 		      error = "expected `,' and another operand";
785 		      longjmp (errjmpbuf, (int) (long) str);
786 		    }
787 
788 		  str = skipspace (str + 1);
789 		  str = expr (str, &e2);
790 		  str = skipspace (str);
791 
792 		  if (fns[i].arity == 0)
793 		    {
794 		      while (str[0] == ',')
795 			{
796 			  makeexp (&e1, fns[i].op, e1, e2);
797 			  str = skipspace (str + 1);
798 			  str = expr (str, &e2);
799 			  str = skipspace (str);
800 			}
801 		    }
802 
803 		  if (str[0] != ')')
804 		    {
805 		      error = "expected `)'";
806 		      longjmp (errjmpbuf, (int) (long) str);
807 		    }
808 
809 		  makeexp (e, fns[i].op, e1, e2);
810 		  return str + 1;
811 		}
812 	    }
813 	}
814     }
815 
816   if (str[0] == '(')
817     {
818       str = expr (str + 1, e);
819       str = skipspace (str);
820       if (str[0] != ')')
821 	{
822 	  error = "expected `)'";
823 	  longjmp (errjmpbuf, (int) (long) str);
824 	}
825       str++;
826     }
827   else if (str[0] >= '0' && str[0] <= '9')
828     {
829       expr_t res;
830       char *s, *sc;
831 
832       res = malloc (sizeof (struct expr));
833       res -> op = LIT;
834       mpz_init (res->operands.val);
835 
836       s = str;
837       while (isalnum (str[0]))
838 	str++;
839       sc = malloc (str - s + 1);
840       memcpy (sc, s, str - s);
841       sc[str - s] = 0;
842 
843       mpz_set_str (res->operands.val, sc, 0);
844       *e = res;
845       free (sc);
846     }
847   else
848     {
849       error = "operand expected";
850       longjmp (errjmpbuf, (int) (long) str);
851     }
852   return str;
853 }
854 
/* Return a pointer to the first character of STR that is not a blank
   (' ').  Only the space character is skipped; tabs and newlines are not.  */
char *
skipspace (char *str)
{
  return str + strspn (str, " ");
}
862 
863 /* Make a new expression with operation OP and right hand side
864    RHS and left hand side lhs.  Put the result in R.  */
865 void
makeexp(expr_t * r,enum op_t op,expr_t lhs,expr_t rhs)866 makeexp (expr_t *r, enum op_t op, expr_t lhs, expr_t rhs)
867 {
868   expr_t res;
869   res = malloc (sizeof (struct expr));
870   res -> op = op;
871   res -> operands.ops.lhs = lhs;
872   res -> operands.ops.rhs = rhs;
873   *r = res;
874   return;
875 }
876 
877 /* Free the memory used by expression E.  */
878 void
free_expr(expr_t e)879 free_expr (expr_t e)
880 {
881   if (e->op != LIT)
882     {
883       free_expr (e->operands.ops.lhs);
884       if (e->operands.ops.rhs != NULL)
885 	free_expr (e->operands.ops.rhs);
886     }
887   else
888     {
889       mpz_clear (e->operands.val);
890     }
891 }
892 
893 /* Evaluate the expression E and put the result in R.  */
894 void
mpz_eval_expr(mpz_ptr r,expr_t e)895 mpz_eval_expr (mpz_ptr r, expr_t e)
896 {
897   mpz_t lhs, rhs;
898 
899   switch (e->op)
900     {
901     case LIT:
902       mpz_set (r, e->operands.val);
903       return;
904     case PLUS:
905       mpz_init (lhs); mpz_init (rhs);
906       mpz_eval_expr (lhs, e->operands.ops.lhs);
907       mpz_eval_expr (rhs, e->operands.ops.rhs);
908       mpz_add (r, lhs, rhs);
909       mpz_clear (lhs); mpz_clear (rhs);
910       return;
911     case MINUS:
912       mpz_init (lhs); mpz_init (rhs);
913       mpz_eval_expr (lhs, e->operands.ops.lhs);
914       mpz_eval_expr (rhs, e->operands.ops.rhs);
915       mpz_sub (r, lhs, rhs);
916       mpz_clear (lhs); mpz_clear (rhs);
917       return;
918     case MULT:
919       mpz_init (lhs); mpz_init (rhs);
920       mpz_eval_expr (lhs, e->operands.ops.lhs);
921       mpz_eval_expr (rhs, e->operands.ops.rhs);
922       mpz_mul (r, lhs, rhs);
923       mpz_clear (lhs); mpz_clear (rhs);
924       return;
925     case DIV:
926       mpz_init (lhs); mpz_init (rhs);
927       mpz_eval_expr (lhs, e->operands.ops.lhs);
928       mpz_eval_expr (rhs, e->operands.ops.rhs);
929       mpz_fdiv_q (r, lhs, rhs);
930       mpz_clear (lhs); mpz_clear (rhs);
931       return;
932     case MOD:
933       mpz_init (rhs);
934       mpz_eval_expr (rhs, e->operands.ops.rhs);
935       mpz_abs (rhs, rhs);
936       mpz_eval_mod_expr (r, e->operands.ops.lhs, rhs);
937       mpz_clear (rhs);
938       return;
939     case REM:
940       /* Check if lhs operand is POW expression and optimize for that case.  */
941       if (e->operands.ops.lhs->op == POW)
942 	{
943 	  mpz_t powlhs, powrhs;
944 	  mpz_init (powlhs);
945 	  mpz_init (powrhs);
946 	  mpz_init (rhs);
947 	  mpz_eval_expr (powlhs, e->operands.ops.lhs->operands.ops.lhs);
948 	  mpz_eval_expr (powrhs, e->operands.ops.lhs->operands.ops.rhs);
949 	  mpz_eval_expr (rhs, e->operands.ops.rhs);
950 	  mpz_powm (r, powlhs, powrhs, rhs);
951 	  if (mpz_cmp_si (rhs, 0L) < 0)
952 	    mpz_neg (r, r);
953 	  mpz_clear (powlhs);
954 	  mpz_clear (powrhs);
955 	  mpz_clear (rhs);
956 	  return;
957 	}
958 
959       mpz_init (lhs); mpz_init (rhs);
960       mpz_eval_expr (lhs, e->operands.ops.lhs);
961       mpz_eval_expr (rhs, e->operands.ops.rhs);
962       mpz_fdiv_r (r, lhs, rhs);
963       mpz_clear (lhs); mpz_clear (rhs);
964       return;
965 #if __GNU_MP_VERSION >= 2
966     case INVMOD:
967       mpz_init (lhs); mpz_init (rhs);
968       mpz_eval_expr (lhs, e->operands.ops.lhs);
969       mpz_eval_expr (rhs, e->operands.ops.rhs);
970       mpz_invert (r, lhs, rhs);
971       mpz_clear (lhs); mpz_clear (rhs);
972       return;
973 #endif
974     case POW:
975       mpz_init (lhs); mpz_init (rhs);
976       mpz_eval_expr (lhs, e->operands.ops.lhs);
977       if (mpz_cmpabs_ui (lhs, 1) <= 0)
978 	{
979 	  /* For 0^rhs and 1^rhs, we just need to verify that
980 	     rhs is well-defined.  For (-1)^rhs we need to
981 	     determine (rhs mod 2).  For simplicity, compute
982 	     (rhs mod 2) for all three cases.  */
983 	  expr_t two, et;
984 	  two = malloc (sizeof (struct expr));
985 	  two -> op = LIT;
986 	  mpz_init_set_ui (two->operands.val, 2L);
987 	  makeexp (&et, MOD, e->operands.ops.rhs, two);
988 	  e->operands.ops.rhs = et;
989 	}
990 
991       mpz_eval_expr (rhs, e->operands.ops.rhs);
992       if (mpz_cmp_si (rhs, 0L) == 0)
993 	/* x^0 is 1 */
994 	mpz_set_ui (r, 1L);
995       else if (mpz_cmp_si (lhs, 0L) == 0)
996 	/* 0^y (where y != 0) is 0 */
997 	mpz_set_ui (r, 0L);
998       else if (mpz_cmp_ui (lhs, 1L) == 0)
999 	/* 1^y is 1 */
1000 	mpz_set_ui (r, 1L);
1001       else if (mpz_cmp_si (lhs, -1L) == 0)
1002 	/* (-1)^y just depends on whether y is even or odd */
1003 	mpz_set_si (r, (mpz_get_ui (rhs) & 1) ? -1L : 1L);
1004       else if (mpz_cmp_si (rhs, 0L) < 0)
1005 	/* x^(-n) is 0 */
1006 	mpz_set_ui (r, 0L);
1007       else
1008 	{
1009 	  unsigned long int cnt;
1010 	  unsigned long int y;
1011 	  /* error if exponent does not fit into an unsigned long int.  */
1012 	  if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1013 	    goto pow_err;
1014 
1015 	  y = mpz_get_ui (rhs);
1016 	  /* x^y == (x/(2^c))^y * 2^(c*y) */
1017 #if __GNU_MP_VERSION >= 2
1018 	  cnt = mpz_scan1 (lhs, 0);
1019 #else
1020 	  cnt = 0;
1021 #endif
1022 	  if (cnt != 0)
1023 	    {
1024 	      if (y * cnt / cnt != y)
1025 		goto pow_err;
1026 	      mpz_tdiv_q_2exp (lhs, lhs, cnt);
1027 	      mpz_pow_ui (r, lhs, y);
1028 	      mpz_mul_2exp (r, r, y * cnt);
1029 	    }
1030 	  else
1031 	    mpz_pow_ui (r, lhs, y);
1032 	}
1033       mpz_clear (lhs); mpz_clear (rhs);
1034       return;
1035     pow_err:
1036       error = "result of `pow' operator too large";
1037       mpz_clear (lhs); mpz_clear (rhs);
1038       longjmp (errjmpbuf, 1);
1039     case GCD:
1040       mpz_init (lhs); mpz_init (rhs);
1041       mpz_eval_expr (lhs, e->operands.ops.lhs);
1042       mpz_eval_expr (rhs, e->operands.ops.rhs);
1043       mpz_gcd (r, lhs, rhs);
1044       mpz_clear (lhs); mpz_clear (rhs);
1045       return;
1046 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1047     case LCM:
1048       mpz_init (lhs); mpz_init (rhs);
1049       mpz_eval_expr (lhs, e->operands.ops.lhs);
1050       mpz_eval_expr (rhs, e->operands.ops.rhs);
1051       mpz_lcm (r, lhs, rhs);
1052       mpz_clear (lhs); mpz_clear (rhs);
1053       return;
1054 #endif
1055     case AND:
1056       mpz_init (lhs); mpz_init (rhs);
1057       mpz_eval_expr (lhs, e->operands.ops.lhs);
1058       mpz_eval_expr (rhs, e->operands.ops.rhs);
1059       mpz_and (r, lhs, rhs);
1060       mpz_clear (lhs); mpz_clear (rhs);
1061       return;
1062     case IOR:
1063       mpz_init (lhs); mpz_init (rhs);
1064       mpz_eval_expr (lhs, e->operands.ops.lhs);
1065       mpz_eval_expr (rhs, e->operands.ops.rhs);
1066       mpz_ior (r, lhs, rhs);
1067       mpz_clear (lhs); mpz_clear (rhs);
1068       return;
1069 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1070     case XOR:
1071       mpz_init (lhs); mpz_init (rhs);
1072       mpz_eval_expr (lhs, e->operands.ops.lhs);
1073       mpz_eval_expr (rhs, e->operands.ops.rhs);
1074       mpz_xor (r, lhs, rhs);
1075       mpz_clear (lhs); mpz_clear (rhs);
1076       return;
1077 #endif
1078     case NEG:
1079       mpz_eval_expr (r, e->operands.ops.lhs);
1080       mpz_neg (r, r);
1081       return;
1082     case NOT:
1083       mpz_eval_expr (r, e->operands.ops.lhs);
1084       mpz_com (r, r);
1085       return;
1086     case SQRT:
1087       mpz_init (lhs);
1088       mpz_eval_expr (lhs, e->operands.ops.lhs);
1089       if (mpz_sgn (lhs) < 0)
1090 	{
1091 	  error = "cannot take square root of negative numbers";
1092 	  mpz_clear (lhs);
1093 	  longjmp (errjmpbuf, 1);
1094 	}
1095       mpz_sqrt (r, lhs);
1096       return;
1097 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1098     case ROOT:
1099       mpz_init (lhs); mpz_init (rhs);
1100       mpz_eval_expr (lhs, e->operands.ops.lhs);
1101       mpz_eval_expr (rhs, e->operands.ops.rhs);
1102       if (mpz_sgn (rhs) <= 0)
1103 	{
1104 	  error = "cannot take non-positive root orders";
1105 	  mpz_clear (lhs); mpz_clear (rhs);
1106 	  longjmp (errjmpbuf, 1);
1107 	}
1108       if (mpz_sgn (lhs) < 0 && (mpz_get_ui (rhs) & 1) == 0)
1109 	{
1110 	  error = "cannot take even root orders of negative numbers";
1111 	  mpz_clear (lhs); mpz_clear (rhs);
1112 	  longjmp (errjmpbuf, 1);
1113 	}
1114 
1115       {
1116 	unsigned long int nth = mpz_get_ui (rhs);
1117 	if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1118 	  {
1119 	    /* If we are asked to take an awfully large root order, cheat and
1120 	       ask for the largest order we can pass to mpz_root.  This saves
1121 	       some error prone special cases.  */
1122 	    nth = ~(unsigned long int) 0;
1123 	  }
1124 	mpz_root (r, lhs, nth);
1125       }
1126       mpz_clear (lhs); mpz_clear (rhs);
1127       return;
1128 #endif
1129     case FAC:
1130       mpz_eval_expr (r, e->operands.ops.lhs);
1131       if (mpz_size (r) > 1)
1132 	{
1133 	  error = "result of `!' operator too large";
1134 	  longjmp (errjmpbuf, 1);
1135 	}
1136       mpz_fac_ui (r, mpz_get_ui (r));
1137       return;
1138 #if __GNU_MP_VERSION >= 2
1139     case POPCNT:
1140       mpz_eval_expr (r, e->operands.ops.lhs);
1141       { long int cnt;
1142 	cnt = mpz_popcount (r);
1143 	mpz_set_si (r, cnt);
1144       }
1145       return;
1146     case HAMDIST:
1147       { long int cnt;
1148 	mpz_init (lhs); mpz_init (rhs);
1149 	mpz_eval_expr (lhs, e->operands.ops.lhs);
1150 	mpz_eval_expr (rhs, e->operands.ops.rhs);
1151 	cnt = mpz_hamdist (lhs, rhs);
1152 	mpz_clear (lhs); mpz_clear (rhs);
1153 	mpz_set_si (r, cnt);
1154       }
1155       return;
1156 #endif
1157     case LOG2:
1158       mpz_eval_expr (r, e->operands.ops.lhs);
1159       { unsigned long int cnt;
1160 	if (mpz_sgn (r) <= 0)
1161 	  {
1162 	    error = "logarithm of non-positive number";
1163 	    longjmp (errjmpbuf, 1);
1164 	  }
1165 	cnt = mpz_sizeinbase (r, 2);
1166 	mpz_set_ui (r, cnt - 1);
1167       }
1168       return;
1169     case LOG:
1170       { unsigned long int cnt;
1171 	mpz_init (lhs); mpz_init (rhs);
1172 	mpz_eval_expr (lhs, e->operands.ops.lhs);
1173 	mpz_eval_expr (rhs, e->operands.ops.rhs);
1174 	if (mpz_sgn (lhs) <= 0)
1175 	  {
1176 	    error = "logarithm of non-positive number";
1177 	    mpz_clear (lhs); mpz_clear (rhs);
1178 	    longjmp (errjmpbuf, 1);
1179 	  }
1180 	if (mpz_cmp_ui (rhs, 256) >= 0)
1181 	  {
1182 	    error = "logarithm base too large";
1183 	    mpz_clear (lhs); mpz_clear (rhs);
1184 	    longjmp (errjmpbuf, 1);
1185 	  }
1186 	cnt = mpz_sizeinbase (lhs, mpz_get_ui (rhs));
1187 	mpz_set_ui (r, cnt - 1);
1188 	mpz_clear (lhs); mpz_clear (rhs);
1189       }
1190       return;
1191     case FERMAT:
1192       {
1193 	unsigned long int t;
1194 	mpz_init (lhs);
1195 	mpz_eval_expr (lhs, e->operands.ops.lhs);
1196 	t = (unsigned long int) 1 << mpz_get_ui (lhs);
1197 	if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0 || t == 0)
1198 	  {
1199 	    error = "too large Mersenne number index";
1200 	    mpz_clear (lhs);
1201 	    longjmp (errjmpbuf, 1);
1202 	  }
1203 	mpz_set_ui (r, 1);
1204 	mpz_mul_2exp (r, r, t);
1205 	mpz_add_ui (r, r, 1);
1206 	mpz_clear (lhs);
1207       }
1208       return;
1209     case MERSENNE:
1210       mpz_init (lhs);
1211       mpz_eval_expr (lhs, e->operands.ops.lhs);
1212       if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0)
1213 	{
1214 	  error = "too large Mersenne number index";
1215 	  mpz_clear (lhs);
1216 	  longjmp (errjmpbuf, 1);
1217 	}
1218       mpz_set_ui (r, 1);
1219       mpz_mul_2exp (r, r, mpz_get_ui (lhs));
1220       mpz_sub_ui (r, r, 1);
1221       mpz_clear (lhs);
1222       return;
1223     case FIBONACCI:
1224       { mpz_t t;
1225 	unsigned long int n, i;
1226 	mpz_init (lhs);
1227 	mpz_eval_expr (lhs, e->operands.ops.lhs);
1228 	if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1229 	  {
1230 	    error = "Fibonacci index out of range";
1231 	    mpz_clear (lhs);
1232 	    longjmp (errjmpbuf, 1);
1233 	  }
1234 	n = mpz_get_ui (lhs);
1235 	mpz_clear (lhs);
1236 
1237 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1238 	mpz_fib_ui (r, n);
1239 #else
1240 	mpz_init_set_ui (t, 1);
1241 	mpz_set_ui (r, 1);
1242 
1243 	if (n <= 2)
1244 	  mpz_set_ui (r, 1);
1245 	else
1246 	  {
1247 	    for (i = 3; i <= n; i++)
1248 	      {
1249 		mpz_add (t, t, r);
1250 		mpz_swap (t, r);
1251 	      }
1252 	  }
1253 	mpz_clear (t);
1254 #endif
1255       }
1256       return;
1257     case RANDOM:
1258       {
1259 	unsigned long int n;
1260 	mpz_init (lhs);
1261 	mpz_eval_expr (lhs, e->operands.ops.lhs);
1262 	if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1263 	  {
1264 	    error = "random number size out of range";
1265 	    mpz_clear (lhs);
1266 	    longjmp (errjmpbuf, 1);
1267 	  }
1268 	n = mpz_get_ui (lhs);
1269 	mpz_clear (lhs);
1270 	mpz_urandomb (r, rstate, n);
1271       }
1272       return;
1273     case NEXTPRIME:
1274       {
1275 	mpz_eval_expr (r, e->operands.ops.lhs);
1276 	mpz_nextprime (r, r);
1277       }
1278       return;
1279     case BINOM:
1280       mpz_init (lhs); mpz_init (rhs);
1281       mpz_eval_expr (lhs, e->operands.ops.lhs);
1282       mpz_eval_expr (rhs, e->operands.ops.rhs);
1283       {
1284 	unsigned long int k;
1285 	if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1286 	  {
1287 	    error = "k too large in (n over k) expression";
1288 	    mpz_clear (lhs); mpz_clear (rhs);
1289 	    longjmp (errjmpbuf, 1);
1290 	  }
1291 	k = mpz_get_ui (rhs);
1292 	mpz_bin_ui (r, lhs, k);
1293       }
1294       mpz_clear (lhs); mpz_clear (rhs);
1295       return;
1296     case TIMING:
1297       {
1298 	int t0;
1299 	t0 = cputime ();
1300 	mpz_eval_expr (r, e->operands.ops.lhs);
1301 	printf ("time: %d\n", cputime () - t0);
1302       }
1303       return;
1304     default:
1305       abort ();
1306     }
1307 }
1308 
1309 /* Evaluate the expression E modulo MOD and put the result in R.  */
1310 void
mpz_eval_mod_expr(mpz_ptr r,expr_t e,mpz_ptr mod)1311 mpz_eval_mod_expr (mpz_ptr r, expr_t e, mpz_ptr mod)
1312 {
1313   mpz_t lhs, rhs;
1314 
1315   switch (e->op)
1316     {
1317       case POW:
1318 	mpz_init (lhs); mpz_init (rhs);
1319 	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1320 	mpz_eval_expr (rhs, e->operands.ops.rhs);
1321 	mpz_powm (r, lhs, rhs, mod);
1322 	mpz_clear (lhs); mpz_clear (rhs);
1323 	return;
1324       case PLUS:
1325 	mpz_init (lhs); mpz_init (rhs);
1326 	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1327 	mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1328 	mpz_add (r, lhs, rhs);
1329 	if (mpz_cmp_si (r, 0L) < 0)
1330 	  mpz_add (r, r, mod);
1331 	else if (mpz_cmp (r, mod) >= 0)
1332 	  mpz_sub (r, r, mod);
1333 	mpz_clear (lhs); mpz_clear (rhs);
1334 	return;
1335       case MINUS:
1336 	mpz_init (lhs); mpz_init (rhs);
1337 	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1338 	mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1339 	mpz_sub (r, lhs, rhs);
1340 	if (mpz_cmp_si (r, 0L) < 0)
1341 	  mpz_add (r, r, mod);
1342 	else if (mpz_cmp (r, mod) >= 0)
1343 	  mpz_sub (r, r, mod);
1344 	mpz_clear (lhs); mpz_clear (rhs);
1345 	return;
1346       case MULT:
1347 	mpz_init (lhs); mpz_init (rhs);
1348 	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1349 	mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1350 	mpz_mul (r, lhs, rhs);
1351 	mpz_mod (r, r, mod);
1352 	mpz_clear (lhs); mpz_clear (rhs);
1353 	return;
1354       default:
1355 	mpz_init (lhs);
1356 	mpz_eval_expr (lhs, e);
1357 	mpz_mod (r, lhs, mod);
1358 	mpz_clear (lhs);
1359 	return;
1360     }
1361 }
1362 
1363 void
cleanup_and_exit(int sig)1364 cleanup_and_exit (int sig)
1365 {
1366   switch (sig) {
1367 #ifdef LIMIT_RESOURCE_USAGE
1368   case SIGXCPU:
1369     printf ("expression took too long to evaluate%s\n", newline);
1370     break;
1371 #endif
1372   case SIGFPE:
1373     printf ("divide by zero%s\n", newline);
1374     break;
1375   default:
1376     printf ("expression required too much memory to evaluate%s\n", newline);
1377     break;
1378   }
1379   exit (-2);
1380 }
1381