1 #include "shared.h"
2 
3 #ifdef LTM_MTEST_REAL_RAND
4 #define LTM_MTEST_RAND_SEED  time(NULL)
5 #else
6 #define LTM_MTEST_RAND_SEED  23
7 #endif
8 
draw(mp_int * a)9 static void draw(mp_int *a)
10 {
11    ndraw(a, "");
12 }
13 
14 #define FGETS(str, size, stream) \
15    { \
16       char *ret = fgets(str, size, stream); \
17       if (!ret) { fprintf(stderr, "\n%d: fgets failed\n", __LINE__); goto LBL_ERR; } \
18    }
19 
mtest_opponent(void)20 static int mtest_opponent(void)
21 {
22    char cmd[4096];
23    char buf[4096];
24    int ix;
25    unsigned rr;
26    mp_int a, b, c, d, e, f;
27    unsigned long expt_n, add_n, sub_n, mul_n, div_n, sqr_n, mul2d_n, div2d_n,
28             gcd_n, lcm_n, inv_n, div2_n, mul2_n, add_d_n, sub_d_n;
29 
30    srand(LTM_MTEST_RAND_SEED);
31 
32    if (mp_init_multi(&a, &b, &c, &d, &e, &f, NULL)!= MP_OKAY)
33       return EXIT_FAILURE;
34 
35    div2_n = mul2_n = inv_n = expt_n = lcm_n = gcd_n = add_n =
36                                          sub_n = mul_n = div_n = sqr_n = mul2d_n = div2d_n = add_d_n = sub_d_n = 0;
37 
38 #ifndef MP_FIXED_CUTOFFS
39    /* force KARA and TOOM to enable despite cutoffs */
40    KARATSUBA_SQR_CUTOFF = KARATSUBA_MUL_CUTOFF = 8;
41    TOOM_SQR_CUTOFF = TOOM_MUL_CUTOFF = 16;
42 #endif
43 
44    for (;;) {
45       /* randomly clear and re-init one variable, this has the affect of triming the alloc space */
46       switch (abs(rand()) % 7) {
47       case 0:
48          mp_clear(&a);
49          mp_init(&a);
50          break;
51       case 1:
52          mp_clear(&b);
53          mp_init(&b);
54          break;
55       case 2:
56          mp_clear(&c);
57          mp_init(&c);
58          break;
59       case 3:
60          mp_clear(&d);
61          mp_init(&d);
62          break;
63       case 4:
64          mp_clear(&e);
65          mp_init(&e);
66          break;
67       case 5:
68          mp_clear(&f);
69          mp_init(&f);
70          break;
71       case 6:
72          break;        /* don't clear any */
73       }
74 
75 
76       printf("%4lu/%4lu/%4lu/%4lu/%4lu/%4lu/%4lu/%4lu/%4lu/%4lu/%4lu/%4lu/%4lu/%4lu/%4lu ",
77              add_n, sub_n, mul_n, div_n, sqr_n, mul2d_n, div2d_n, gcd_n, lcm_n,
78              expt_n, inv_n, div2_n, mul2_n, add_d_n, sub_d_n);
79       FGETS(cmd, 4095, stdin);
80       cmd[strlen(cmd) - 1u] = '\0';
81       printf("%-6s ]\r", cmd);
82       fflush(stdout);
83       if (strcmp(cmd, "mul2d") == 0) {
84          ++mul2d_n;
85          FGETS(buf, 4095, stdin);
86          mp_read_radix(&a, buf, 64);
87          FGETS(buf, 4095, stdin);
88          sscanf(buf, "%u", &rr);
89          FGETS(buf, 4095, stdin);
90          mp_read_radix(&b, buf, 64);
91 
92          mp_mul_2d(&a, (int)rr, &a);
93          a.sign = b.sign;
94          if (mp_cmp(&a, &b) != MP_EQ) {
95             printf("mul2d failed, rr == %u\n", rr);
96             draw(&a);
97             draw(&b);
98             goto LBL_ERR;
99          }
100       } else if (strcmp(cmd, "div2d") == 0) {
101          ++div2d_n;
102          FGETS(buf, 4095, stdin);
103          mp_read_radix(&a, buf, 64);
104          FGETS(buf, 4095, stdin);
105          sscanf(buf, "%u", &rr);
106          FGETS(buf, 4095, stdin);
107          mp_read_radix(&b, buf, 64);
108 
109          mp_div_2d(&a, (int)rr, &a, &e);
110          a.sign = b.sign;
111          if ((a.used == b.used) && (a.used == 0)) {
112             a.sign = b.sign = MP_ZPOS;
113          }
114          if (mp_cmp(&a, &b) != MP_EQ) {
115             printf("div2d failed, rr == %u\n", rr);
116             draw(&a);
117             draw(&b);
118             goto LBL_ERR;
119          }
120       } else if (strcmp(cmd, "add") == 0) {
121          ++add_n;
122          FGETS(buf, 4095, stdin);
123          mp_read_radix(&a, buf, 64);
124          FGETS(buf, 4095, stdin);
125          mp_read_radix(&b, buf, 64);
126          FGETS(buf, 4095, stdin);
127          mp_read_radix(&c, buf, 64);
128          mp_copy(&a, &d);
129          mp_add(&d, &b, &d);
130          if (mp_cmp(&c, &d) != MP_EQ) {
131             printf("add %lu failure!\n", add_n);
132             draw(&a);
133             draw(&b);
134             draw(&c);
135             draw(&d);
136             goto LBL_ERR;
137          }
138 
139          /* test the sign/unsigned storage functions */
140 
141          rr = (unsigned)mp_sbin_size(&c);
142          mp_to_sbin(&c, (unsigned char *) cmd, (size_t)rr, NULL);
143          memset(cmd + rr, rand() & 0xFF, sizeof(cmd) - rr);
144          mp_from_sbin(&d, (unsigned char *) cmd, (size_t)rr);
145          if (mp_cmp(&c, &d) != MP_EQ) {
146             printf("mp_signed_bin failure!\n");
147             draw(&c);
148             draw(&d);
149             goto LBL_ERR;
150          }
151 
152          rr = (unsigned)mp_ubin_size(&c);
153          mp_to_ubin(&c, (unsigned char *) cmd, (size_t)rr, NULL);
154          memset(cmd + rr, rand() & 0xFF, sizeof(cmd) - rr);
155          mp_from_ubin(&d, (unsigned char *) cmd, (size_t)rr);
156          if (mp_cmp_mag(&c, &d) != MP_EQ) {
157             printf("mp_unsigned_bin failure!\n");
158             draw(&c);
159             draw(&d);
160             goto LBL_ERR;
161          }
162 
163       } else if (strcmp(cmd, "sub") == 0) {
164          ++sub_n;
165          FGETS(buf, 4095, stdin);
166          mp_read_radix(&a, buf, 64);
167          FGETS(buf, 4095, stdin);
168          mp_read_radix(&b, buf, 64);
169          FGETS(buf, 4095, stdin);
170          mp_read_radix(&c, buf, 64);
171          mp_copy(&a, &d);
172          mp_sub(&d, &b, &d);
173          if (mp_cmp(&c, &d) != MP_EQ) {
174             printf("sub %lu failure!\n", sub_n);
175             draw(&a);
176             draw(&b);
177             draw(&c);
178             draw(&d);
179             goto LBL_ERR;
180          }
181       } else if (strcmp(cmd, "mul") == 0) {
182          ++mul_n;
183          FGETS(buf, 4095, stdin);
184          mp_read_radix(&a, buf, 64);
185          FGETS(buf, 4095, stdin);
186          mp_read_radix(&b, buf, 64);
187          FGETS(buf, 4095, stdin);
188          mp_read_radix(&c, buf, 64);
189          mp_copy(&a, &d);
190          mp_mul(&d, &b, &d);
191          if (mp_cmp(&c, &d) != MP_EQ) {
192             printf("mul %lu failure!\n", mul_n);
193             draw(&a);
194             draw(&b);
195             draw(&c);
196             draw(&d);
197             goto LBL_ERR;
198          }
199       } else if (strcmp(cmd, "div") == 0) {
200          ++div_n;
201          FGETS(buf, 4095, stdin);
202          mp_read_radix(&a, buf, 64);
203          FGETS(buf, 4095, stdin);
204          mp_read_radix(&b, buf, 64);
205          FGETS(buf, 4095, stdin);
206          mp_read_radix(&c, buf, 64);
207          FGETS(buf, 4095, stdin);
208          mp_read_radix(&d, buf, 64);
209 
210          mp_div(&a, &b, &e, &f);
211          if ((mp_cmp(&c, &e) != MP_EQ) || (mp_cmp(&d, &f) != MP_EQ)) {
212             printf("div %lu %d, %d, failure!\n", div_n, mp_cmp(&c, &e),
213                    mp_cmp(&d, &f));
214             draw(&a);
215             draw(&b);
216             draw(&c);
217             draw(&d);
218             draw(&e);
219             draw(&f);
220             goto LBL_ERR;
221          }
222 
223       } else if (strcmp(cmd, "sqr") == 0) {
224          ++sqr_n;
225          FGETS(buf, 4095, stdin);
226          mp_read_radix(&a, buf, 64);
227          FGETS(buf, 4095, stdin);
228          mp_read_radix(&b, buf, 64);
229          mp_copy(&a, &c);
230          mp_sqr(&c, &c);
231          if (mp_cmp(&b, &c) != MP_EQ) {
232             printf("sqr %lu failure!\n", sqr_n);
233             draw(&a);
234             draw(&b);
235             draw(&c);
236             goto LBL_ERR;
237          }
238       } else if (strcmp(cmd, "gcd") == 0) {
239          ++gcd_n;
240          FGETS(buf, 4095, stdin);
241          mp_read_radix(&a, buf, 64);
242          FGETS(buf, 4095, stdin);
243          mp_read_radix(&b, buf, 64);
244          FGETS(buf, 4095, stdin);
245          mp_read_radix(&c, buf, 64);
246          mp_copy(&a, &d);
247          mp_gcd(&d, &b, &d);
248          d.sign = c.sign;
249          if (mp_cmp(&c, &d) != MP_EQ) {
250             printf("gcd %lu failure!\n", gcd_n);
251             draw(&a);
252             draw(&b);
253             draw(&c);
254             draw(&d);
255             goto LBL_ERR;
256          }
257       } else if (strcmp(cmd, "lcm") == 0) {
258          ++lcm_n;
259          FGETS(buf, 4095, stdin);
260          mp_read_radix(&a, buf, 64);
261          FGETS(buf, 4095, stdin);
262          mp_read_radix(&b, buf, 64);
263          FGETS(buf, 4095, stdin);
264          mp_read_radix(&c, buf, 64);
265          mp_copy(&a, &d);
266          mp_lcm(&d, &b, &d);
267          d.sign = c.sign;
268          if (mp_cmp(&c, &d) != MP_EQ) {
269             printf("lcm %lu failure!\n", lcm_n);
270             draw(&a);
271             draw(&b);
272             draw(&c);
273             draw(&d);
274             goto LBL_ERR;
275          }
276       } else if (strcmp(cmd, "expt") == 0) {
277          ++expt_n;
278          FGETS(buf, 4095, stdin);
279          mp_read_radix(&a, buf, 64);
280          FGETS(buf, 4095, stdin);
281          mp_read_radix(&b, buf, 64);
282          FGETS(buf, 4095, stdin);
283          mp_read_radix(&c, buf, 64);
284          FGETS(buf, 4095, stdin);
285          mp_read_radix(&d, buf, 64);
286          mp_copy(&a, &e);
287          mp_exptmod(&e, &b, &c, &e);
288          if (mp_cmp(&d, &e) != MP_EQ) {
289             printf("expt %lu failure!\n", expt_n);
290             draw(&a);
291             draw(&b);
292             draw(&c);
293             draw(&d);
294             draw(&e);
295             goto LBL_ERR;
296          }
297       } else if (strcmp(cmd, "invmod") == 0) {
298          ++inv_n;
299          FGETS(buf, 4095, stdin);
300          mp_read_radix(&a, buf, 64);
301          FGETS(buf, 4095, stdin);
302          mp_read_radix(&b, buf, 64);
303          FGETS(buf, 4095, stdin);
304          mp_read_radix(&c, buf, 64);
305          mp_invmod(&a, &b, &d);
306          mp_mulmod(&d, &a, &b, &e);
307          if (mp_cmp_d(&e, 1uL) != MP_EQ) {
308             printf("inv [wrong value from MPI?!] failure\n");
309             draw(&a);
310             draw(&b);
311             draw(&c);
312             draw(&d);
313             draw(&e);
314             mp_gcd(&a, &b, &e);
315             draw(&e);
316             goto LBL_ERR;
317          }
318 
319       } else if (strcmp(cmd, "div2") == 0) {
320          ++div2_n;
321          FGETS(buf, 4095, stdin);
322          mp_read_radix(&a, buf, 64);
323          FGETS(buf, 4095, stdin);
324          mp_read_radix(&b, buf, 64);
325          mp_div_2(&a, &c);
326          if (mp_cmp(&c, &b) != MP_EQ) {
327             printf("div_2 %lu failure\n", div2_n);
328             draw(&a);
329             draw(&b);
330             draw(&c);
331             goto LBL_ERR;
332          }
333       } else if (strcmp(cmd, "mul2") == 0) {
334          ++mul2_n;
335          FGETS(buf, 4095, stdin);
336          mp_read_radix(&a, buf, 64);
337          FGETS(buf, 4095, stdin);
338          mp_read_radix(&b, buf, 64);
339          mp_mul_2(&a, &c);
340          if (mp_cmp(&c, &b) != MP_EQ) {
341             printf("mul_2 %lu failure\n", mul2_n);
342             draw(&a);
343             draw(&b);
344             draw(&c);
345             goto LBL_ERR;
346          }
347       } else if (strcmp(cmd, "add_d") == 0) {
348          ++add_d_n;
349          FGETS(buf, 4095, stdin);
350          mp_read_radix(&a, buf, 64);
351          FGETS(buf, 4095, stdin);
352          sscanf(buf, "%d", &ix);
353          FGETS(buf, 4095, stdin);
354          mp_read_radix(&b, buf, 64);
355          mp_add_d(&a, (mp_digit)ix, &c);
356          if (mp_cmp(&b, &c) != MP_EQ) {
357             printf("add_d %lu failure\n", add_d_n);
358             draw(&a);
359             draw(&b);
360             draw(&c);
361             printf("d == %d\n", ix);
362             goto LBL_ERR;
363          }
364       } else if (strcmp(cmd, "sub_d") == 0) {
365          ++sub_d_n;
366          FGETS(buf, 4095, stdin);
367          mp_read_radix(&a, buf, 64);
368          FGETS(buf, 4095, stdin);
369          sscanf(buf, "%d", &ix);
370          FGETS(buf, 4095, stdin);
371          mp_read_radix(&b, buf, 64);
372          mp_sub_d(&a, (mp_digit)ix, &c);
373          if (mp_cmp(&b, &c) != MP_EQ) {
374             printf("sub_d %lu failure\n", sub_d_n);
375             draw(&a);
376             draw(&b);
377             draw(&c);
378             printf("d == %d\n", ix);
379             goto LBL_ERR;
380          }
381       } else if (strcmp(cmd, "exit") == 0) {
382          printf("\nokay, exiting now\n");
383          break;
384       }
385    }
386 
387    mp_clear_multi(&a, &b, &c, &d, &e, &f, NULL);
388    printf("\n");
389    return 0;
390 
391 LBL_ERR:
392    mp_clear_multi(&a, &b, &c, &d, &e, &f, NULL);
393    printf("\n");
394    return EXIT_FAILURE;
395 }
396 
main(void)397 int main(void)
398 {
399    print_header();
400 
401    return mtest_opponent();
402 }
403