1 /*
2
3 qdev.c - code handling q-expansions
4
5 Copyright (C) 2009, 2015, 2016, 2018, 2020 Andreas Enge
6
7 This file is part of CM.
8
9 CM is free software; you can redistribute it and/or modify it under
10 the terms of the GNU General Public License as published by the
11 Free Software Foundation; either version 3 of the license, or (at your
12 option) any later version.
13
14 CM is distributed in the hope that it will be useful, but WITHOUT ANY
15 WARRANTY; without even the implied warranty of MERCHANTABILITY or
16 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
17 for more details.
18
19 You should have received a copy of the GNU General Public License along
20 with CM; see the file COPYING. If not, write to the Free Software
21 Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
22 */
23
24 #include "cm_common-impl.h"
25
26 static bool find_in_chain (int* index, cm_qdev_t f, int length, long int no);
27 static double lognorm2 (ctype op);
28 static void qdev_eval_addition_sequence (ctype rop, cm_qdev_t f, ctype q1,
29 double delta, int N);
30 static int dense_addition_sequence (long int ** b, int m);
31 static int minimal_dense_addition_sequence (long int ** b, int m);
32 static void qdev_eval_bsgs (ctype rop, cm_qdev_t f, ctype q1,
33 double delta, int N);
34
35 /*****************************************************************************/
36
static bool find_in_chain (int *index, cm_qdev_t f, int length, long int no)
/* Binary search for the exponent no among f.chain [0][0], ...,
   f.chain [length-1][0], which are expected to be in increasing order.
   Returns true and stores the position in *index if no is found;
   returns false and leaves *index untouched otherwise. */

{
   int lo = 0;
   int hi = length - 1;

   /* Values outside the covered range cannot occur. */
   if (no < f.chain [lo][0] || f.chain [hi][0] < no)
      return false;

   /* Shrink the bracket [lo, hi] until it holds at most two candidates. */
   while (hi - lo > 1) {
      int mid = (lo + hi) / 2;

      if (f.chain [mid][0] < no)
         lo = mid;
      else
         hi = mid;
   }

   if (f.chain [lo][0] == no)
      *index = lo;
   else if (f.chain [hi][0] == no)
      *index = hi;
   else
      return false;

   return true;
}
69
70 /*****************************************************************************/
71
static double lognorm2 (ctype op)
/* Returns the base 2 logarithm of the complex absolute value of op.
   Real and imaginary part are split into a double mantissa and a separate
   binary exponent, since converting them to double directly may
   overflow. */

{
   long int exp_re, exp_im, exp_max;
   double mant_re, mant_im;

   mant_re = fget_d_2exp (&exp_re, crealref (op));
   mant_im = fget_d_2exp (&exp_im, cimagref (op));

   /* A vanishing part is treated separately: Its exponent may be much
      larger than that of the other part, and normalising to it would
      turn both mantissae into 0, so that the logarithm fails. */
   if (mant_re == 0)
      return (exp_im + log2 (fabs (mant_im)));
   if (mant_im == 0)
      return (exp_re + log2 (fabs (mant_re)));

   /* Scale the part with the smaller exponent to the larger exponent;
      if it underflows, it becomes a harmless 0. */
   if (exp_re > exp_im) {
      exp_max = exp_re;
      mant_im *= exp2 (exp_im - exp_re);
   }
   else {
      exp_max = exp_im;
      mant_re *= exp2 (exp_re - exp_im);
   }

   return (exp_max + log2 (mant_re*mant_re + mant_im*mant_im) / 2);
}
107
108 /*****************************************************************************/
109
void cm_qdev_init (cm_qdev_t *f, fprec_t prec)
/* Initialises the addition chain for eta, long enough for evaluations at
   precision prec. The memory is released by cm_qdev_clear.
   Each row f->chain [n] has five entries:
   [0] the exponent (0 and the generalised pentagonal numbers);
   [1] the rule by which the power is obtained from previous ones:
       0 for none (only rows 0 and 1), 1 for a square, 2 for a product,
       3 for a square times a third power;
   [2], [3] the row indices of the operands of the rule (for rule 3,
       [2] is the one to be squared);
   [4] the coefficient of the term, 1 or -1.
   The program is terminated with an error message if memory cannot be
   allocated. */

{
   int n, i, j;
   long int sign;

   f->length = 2 * ((fprec_t) (sqrt (prec * 0.085) + 1)) + 1;
      /* must be odd                                                    */
      /* Since each power of q yields at least                          */
      /* log_2 (exp (sqrt (3) * pi)) = 7.85 bits,                       */
      /* the k yielding the largest exponent must satisfy               */
      /* ((3*k+1)*k/2)*7.85 >= prec; we drop the +1 to simplify.        */
      /* Then we have twice as many exponents (taking into account the  */
      /* (3*k-1)*k/2), and one more for the constant coefficient.       */

   f->chain = (long int **) malloc (f->length * sizeof (long int *));
   if (f->chain == NULL) {
      printf ("*** Out of memory in 'cm_qdev_init'.\n");
      exit (1);
   }
   for (n = 0; n < f->length; n++) {
      f->chain [n] = (long int *) malloc (5 * sizeof (long int));
      if (f->chain [n] == NULL) {
         printf ("*** Out of memory in 'cm_qdev_init'.\n");
         exit (1);
      }
   }

   /* Exponents and coefficients: 1 for the constant term, then (-1)^n for
      the pair of pentagonal numbers n*(3*n-1)/2 and n*(3*n+1)/2. */
   f->chain [0][0] = 0;
   f->chain [0][4] = 1;
   for (n = 1; n <= f->length / 2; n++) {
      sign = (n % 2 == 0 ? 1 : -1);
      f->chain [2*n-1][0] = n*(3*n-1) / 2;
      f->chain [2*n][0] = n*(3*n+1) / 2;
      f->chain [2*n-1][4] = sign;
      f->chain [2*n][4] = sign;
   }

   /* Rules: the first two powers are given, each later one is expressed
      by the first rule that applies. */
   f->chain [0][1] = 0;
   f->chain [1][1] = 0;
   for (n = 2; n < f->length; n++) {
      f->chain [n][1] = 0;
      /* try to express an even exponent as twice a previous one */
      if (f->chain [n][0] % 2 == 0)
         if (find_in_chain (&i, *f, n, f->chain [n][0] / 2)) {
            f->chain [n][1] = 1;
            f->chain [n][2] = i;
         }
      /* try to express the exponent as the sum of two previous ones */
      for (i = 0; i < n && f->chain [n][1] == 0; i++)
         if (find_in_chain (&j, *f, n, f->chain [n][0] - f->chain [i][0])) {
            f->chain [n][1] = 2;
            f->chain [n][2] = i;
            f->chain [n][3] = j;
         }
      /* try to express the exponent as twice a previous plus a third one */
      for (i = 0; i < n && f->chain [n][1] == 0; i++)
         if (find_in_chain (&j, *f, n,
                            f->chain [n][0] - 2 * f->chain [i][0])) {
            f->chain [n][1] = 3;
            f->chain [n][2] = i;
            f->chain [n][3] = j;
         }
      /* This covers all cases for eta, see Enge-Johansson 2016. */
   }
}
178
179 /*****************************************************************************/
180
void cm_qdev_clear (cm_qdev_t *f)
/* Releases the rows of f->chain and then the array of row pointers. */

{
   int row = 0;

   while (row < f->length) {
      free (f->chain [row]);
      row++;
   }
   free (f->chain);
}
190
191 /*****************************************************************************/
192
static void qdev_eval_addition_sequence (ctype rop, cm_qdev_t f, ctype q1,
   double delta, int N)
/* Evaluate f in q1 using the optimised addition sequence from f.
   N is the last index used in the addition chain.
   delta is the number of bits gained with each power of q.
   rop and q1 may be the same.
   The working precision is that of rop; the power with exponent e is
   computed at a precision lowered by e*delta bits.
   The program is terminated with an error message if memory cannot be
   allocated. */

{
   mp_prec_t prec;
   long int local_prec;
   ctype *q, term, tmp1, tmp2;
   int n, i;

   prec = fget_prec (crealref (rop));

   /* q [n] holds the power of q1 belonging to row n of the chain;
      q [0] is never used. */
   q = (ctype *) malloc (f.length * sizeof (ctype));
   if (q == NULL) {
      printf ("*** Out of memory in 'qdev_eval_addition_sequence'.\n");
      exit (1);
   }
   cinit (q [1], prec);
   cset (q [1], q1);
   cinit (term, prec);
   cinit (tmp1, prec);
   cinit (tmp2, prec);

   /* Constant term and term in q1; q1 has been copied, so overwriting
      rop is safe even if both are the same variable. */
   cset_si (rop, f.chain [0][4]);
   if (f.chain [1][4] == 1)
      cadd (rop, rop, q [1]);
   else if (f.chain [1][4] == -1)
      csub (rop, rop, q [1]);
   else if (f.chain [1][4] != 0) {
      cmul_si (term, q [1], f.chain [1][4]);
      cadd (rop, rop, term);
   }

   for (n = 2; n <= N; n++) {
      local_prec = (long int) prec - (long int) (f.chain [n][0] * delta);
      cinit (q [n], (mp_prec_t) local_prec);
      /* Compute the power by the rule stored in the chain. The arguments
         are first rounded to the lower precision to save some more
         time. */
      switch (f.chain [n][1]) {
      case 1:
         cset_prec (tmp1, local_prec);
         cset (tmp1, q [f.chain [n][2]]);
         csqr (q [n], tmp1);
         break;
      case 2:
         cset_prec (tmp1, local_prec);
         cset_prec (tmp2, local_prec);
         cset (tmp1, q [f.chain [n][2]]);
         cset (tmp2, q [f.chain [n][3]]);
         cmul (q [n], tmp1, tmp2);
         break;
      case 3:
         cset_prec (tmp1, local_prec);
         cset_prec (tmp2, local_prec);
         cset (tmp1, q [f.chain [n][2]]);
         cset (tmp2, q [f.chain [n][3]]);
         csqr (q [n], tmp1);
         cmul (q [n], q [n], tmp2);
         break;
      }
      /* Add the term with its coefficient. */
      if (f.chain [n][4] == 1)
         cadd (rop, rop, q [n]);
      else if (f.chain [n][4] == -1)
         csub (rop, rop, q [n]);
      else if (f.chain [n][4] != 0) {
         cset_prec (term, (mp_prec_t) local_prec);
         cmul_si (term, q [n], f.chain [n][4]);
         cadd (rop, rop, term);
      }
   }

   /* Here n is one more than the index of the last initialised power. */
   for (i = 1; i < n; i++)
      cclear (q [i]);
   free (q);
   cclear (term);
   cclear (tmp1);
   cclear (tmp2);
}
270
271 /*****************************************************************************/
272
static int dense_addition_sequence (long int ** b, int m)
/* Compute an addition sequence for b and return it via b itself.
   The return value is the number of additional entries.
   b is a matrix of m+1 rows and (at least) 2 columns; the row i
   represents the exponent i. Initially, b [i][0] is expected to be 1
   if the exponent i occurs, 0 otherwise.
   During the course of the algorithm, new elements are added;
   these are marked by a 2 in b [i][0].
   For every occurring exponent i >= 2, b [i][1] is set to some j such that
   j and i-j occur as well (possibly after having been added).
   The algorithm prefers doublings over other additions.
   The exponent 1 need not be marked as occurring; it is added like any
   other missing element when it is needed. */
{
   int i, j, found, added;

   added = 0;
   for (i = m; i >= 2; i--) {
      if (b [i][0] == 0)
         continue;

      /* Search for two existing entries adding to i. Starting from i/2,
         a doubling is found first if it exists. */
      found = 0;
      for (j = i / 2; j >= 1 && !found; j--)
         if (b [j][0] != 0 && b [i-j][0] != 0) {
            b [i][1] = j;
            found = 1;
         }
      if (found)
         continue;

      /* Add missing elements. Three strategies have been tried:
         (1) add floor (i/2) and ceil (i/2);
         (2) if i is even, add i/2; otherwise, add i-j for the largest
             occurring j less than i;
         (3) as (2), but for odd i, add an even i-j for the largest
             possible j.
         The second one, implemented here, seems to give the best
         performance; when minimising additionally, all three have a very
         similar outcome. */
      if (i % 2 == 0)
         j = i / 2;
      else
         /* The search stops at j=1 at the latest, whether 1 occurs or
            not; so it never leaves the matrix, and never returns the
            useless decomposition i = 0 + i. */
         for (j = i - 1; j > 1 && b [j][0] == 0; j--);

      b [i][1] = j;
      if (b [j][0] == 0) {
         b [j][0] = 2;
         added++;
      }
      if (b [i-j][0] == 0) {
         b [i-j][0] = 2;
         added++;
      }
   }

   return (added);
}
339
340 /*****************************************************************************/
341
static int minimal_dense_addition_sequence (long int ** b, int m)
/* The parameters are the same as for dense_addition_sequence,
   except that b has three columns: b [i][2] is set to the desired
   precision if the exponent i occurs.
   The function returns an addition sequence that is minimal in the sense
   that none of the added entries may be removed (which does not mean that
   it is optimal!).
   Also, if possible it replaces remaining general additions by doublings
   (the addition sequence computation already privileges doublings, but
   these may become possible only later using added terms).
   The precision is also tracked: When writing an element as a sum of two
   smaller ones, the precision of the smaller elements may need to be
   increased to that of the target.
   The return value is the number of added entries that remain.
   The program is terminated with an error message if memory cannot be
   allocated. */
{
   int added, new_added, changed, i, j;
   long int **bnew;

   added = dense_addition_sequence (b, m);
   /* We need to work on a copy, since the computation of addition sequences
      via side effects destroys the previously computed sequence. It may
      happen that when removing one of the additional terms, three new ones
      are added in a run! Then we need to simply throw away the new
      sequence. */
   bnew = (long int **) malloc ((m + 1) * sizeof (long int *));
   if (bnew == NULL) {
      printf ("*** Out of memory in 'minimal_dense_addition_sequence'.\n");
      exit (1);
   }
   for (i = 0; i <= m; i++) {
      bnew [i] = (long int *) malloc (2 * sizeof (long int));
      if (bnew [i] == NULL) {
         printf ("*** Out of memory in "
                 "'minimal_dense_addition_sequence'.\n");
         exit (1);
      }
   }

   /* This additional loop does not seem to be needed in practice. */
   changed = 1;
   while (changed) {
      changed = 0;
      /* Check if any of the added entries may be removed again. */
      for (i = m; i >= 2; i--) {
         if (b [i][0] != 2)
            continue;
         /* Work on a copy of b with this element cancelled. */
         for (j = 0; j <= m; j++) {
            bnew [j][0] = b [j][0];
            bnew [j][1] = 0;
         }
         bnew [i][0] = 0;
         new_added = dense_addition_sequence (bnew, m);
         if (new_added == 0) {
            /* The element is superfluous; copy the new addition sequence
               back. The precision column of b is left alone. */
            for (j = 0; j <= m; j++) {
               b [j][0] = bnew [j][0];
               b [j][1] = bnew [j][1];
            }
            added--;
            changed = 1;
         }
      }
   }

   /* Look for doublings. */
   for (i = 4; i <= m; i += 2)
      if (b [i][0] != 0 && b [i/2][0] != 0)
         b [i][1] = i/2;

   /* Track the precision; going downwards, each element hands its final
      precision on to its two summands. */
   for (i = m; i >= 2; i--)
      if (b [i][0] != 0) {
         j = b [i][1];
         if (b [i][2] > b [j][2])
            b [j][2] = b [i][2];
         if (b [i][2] > b [i-j][2])
            b [i-j][2] = b [i][2];
      }

   for (i = 0; i <= m; i++)
      free (bnew [i]);
   free (bnew);

   return (added);
}
416
417 /*****************************************************************************/
418
static void qdev_eval_bsgs (ctype rop, cm_qdev_t f, ctype q1,
   double delta, int N)
/* Evaluate f in q1 by a baby-step giant-step approach: The powers of q1
   with exponents below a step width m (baby steps) and the power with
   exponent m are computed through a dense addition sequence, then the
   series is summed by a Horner scheme in the m-th power (giant steps).
   Only the exponents and coefficients of f.chain are used, not the
   addition rules stored in it.
   N is the last index used in the addition chain.
   delta is the number of bits gained with each power of q.
   rop and q1 may be the same.
   The program is terminated with an error message if the built-in table
   of step widths is too short or if memory cannot be allocated. */
{
   enum { n_opt = 37 };
   /* Candidate step widths and, for each of them, the number p of
      baby steps entering the cost function below. */
   static const int mopt [n_opt] = { 2, 5, 7, 11, 13, 17, 19, 23, 55, 65,
      77, 91, 119, 133, 143, 175, 275, 325, 455, 595, 665, 715, 935, 1001,
      1309, 1463, 1547, 1729, 2275, 2975, 3325, 3575, 4675,
      6545, 7315, 7735, 8645 };
   static const int p [n_opt] = { 2, 3, 4, 6, 7, 9, 10, 12, 18, 21, 24,
      28, 36, 40, 42, 44, 66, 77, 84, 108, 120, 126, 162, 168,
      216, 240, 252, 280, 308, 396, 440, 462, 594,
      648, 720, 756, 840 };
   mp_prec_t prec, local_prec;
   long int T, cost, cost_new;
   int m, index;
   long int **bs;
   ctype *q, *c, tmp1, tmp2, tmp3;
   int i, j, k, J;

   prec = fget_prec (crealref (rop));
   cinit (tmp1, prec);
   cinit (tmp2, prec);
   cinit (tmp3, prec);
   T = f.chain [N][0];
   /* Find the optimal m=mopt[i] minimising the theoretical cost function
      (roughly, T/mopt[i] + p[i]). The function seems to be unimodal,
      so we take the smallest i before it increases again. */
   i = 0;
   cost_new = (T + mopt [0]) / mopt [0] + p [0] - 1;
   do {
      cost = cost_new;
      i++;
      cost_new = (T + mopt [i]) / mopt [i] + p [i] - 1;
   } while (cost_new < cost && i < n_opt - 1);
   if (i == n_opt - 1 && cost_new < cost) {
      printf ("*** Houston, we have a problem!\n");
      printf ("mopt and p too short in 'qdev_eval_bsgs'.\n");
      exit (1);
   }
   m = mopt [i-1];

   /* Determine the occurring baby-steps; in practice, these are as
      many as predicted by p. Also keep track of their required precision.
      Row i of bs stands for the exponent i; its columns are those
      expected by minimal_dense_addition_sequence. */
   bs = (long int **) malloc ((m + 1) * sizeof (long int *));
   if (bs == NULL) {
      printf ("*** Out of memory in 'qdev_eval_bsgs'.\n");
      exit (1);
   }
   for (i = 0; i <= m; i++) {
      bs [i] = (long int *) malloc (3 * sizeof (long int));
      if (bs [i] == NULL) {
         printf ("*** Out of memory in 'qdev_eval_bsgs'.\n");
         exit (1);
      }
      for (j = 0; j < 3; j++)
         bs [i][j] = 0;
   }
   for (i = 0; i <= N; i++) {
      index = f.chain [i][0] % m;
      /* Register the precision needed for the term with lowest exponent
         and not later ones. */
      if (bs [index][0] == 0) {
         bs [index][0] = 1;
         bs [index][2] = (long int) prec
                         - (long int) (f.chain [i][0] * delta);
      }
   }
   bs [m][0] = 1; /* for the giant steps */
   bs [m][2] = (long int) prec - (long int) (m * delta);
   minimal_dense_addition_sequence (bs, m);

   /* Compute the baby-steps; q [i] is initialised and holds the i-th
      power of q1 exactly for those i with bs [i][0] != 0. */
   q = (ctype *) malloc ((m + 1) * sizeof (ctype));
   if (q == NULL) {
      printf ("*** Out of memory in 'qdev_eval_bsgs'.\n");
      exit (1);
   }
   cinit (q [0], 2);
   cset_ui (q [0], 1);
   cinit (q [1], bs [1][2]);
   cset (q [1], q1);
   for (i = 2; i <= m; i++)
      if (bs [i][0] != 0) {
         cinit (q [i], bs [i][2]);
         j = bs [i][1];
         k = i - j;
         /* Here decreasing the argument precision to the target precision
            apparently does not gain time; both are probably too close for
            it to make a difference. */
         if (j == k)
            csqr (q [i], q [j]);
         else
            cmul (q [i], q [j], q [k]);
      }

   /* Compute the giant steps; we need
      \sum_{j=0}^{J-1} (\sum_{k=0}^{m-1} c_{k+j*m} q^k) (q^m)^j.
      First compute the inner coefficients, then use a Horner scheme. */
   J = f.chain [N][0] / m + 1;
   c = (ctype *) malloc (J * sizeof (ctype));
   if (c == NULL) {
      printf ("*** Out of memory in 'qdev_eval_bsgs'.\n");
      exit (1);
   }
   for (j = 0; j < J; j++) {
      local_prec = prec - (mp_prec_t) (j * m * delta);
      cinit (c [j], local_prec);
      cset_ui (c [j], 0);
   }
   for (i = 0; i <= N; i++)
      if (f.chain [i][4] != 0) {
         j = f.chain [i][0] / m;
         k = f.chain [i][0] % m;
         /* We assume the coefficients are 1 or -1. */
         if (f.chain [i][4] == 1)
            cadd (c [j], c [j], q [k]);
         else
            csub (c [j], c [j], q [k]);
      }
   /* From here on q1 is not needed any more, so that rop may be
      overwritten even if it is the same variable. */
   cset (rop, c [J-1]);
   for (j = J-2; j >= 0; j--) {
      /* Carry out the multiplication at the precision of c [j+1]. */
      local_prec = prec - (mp_prec_t) ((j+1) * m * delta);
      cset_prec (tmp1, local_prec);
      cset_prec (tmp2, local_prec);
      cset_prec (tmp3, local_prec);
      cset (tmp1, rop);
      cset (tmp2, q [m]);
      cmul (tmp3, tmp1, tmp2);
      cadd (rop, tmp3, c [j]);
   }

   for (i = 0; i <= m; i++) {
      if (bs [i][0] != 0)
         cclear (q [i]);
      free (bs [i]);
   }
   for (j = 0; j < J; j++)
      cclear (c [j]);
   free (bs);
   free (q);
   free (c);
   cclear (tmp1);
   cclear (tmp2);
   cclear (tmp3);
}
553
554 /*****************************************************************************/
555
void cm_qdev_eval (ctype rop, cm_qdev_t f, ctype q1)
/* Evaluate f in q1 at the precision of rop. rop and q1 may be the same.
   q1 must be of absolute value less than 1. The program is terminated
   with an error message if this is not the case, or if the addition chain
   of f is too short for the precision. */

{
   mp_prec_t prec;
   double delta;
   long int T;
   int N;

   prec = fget_prec (crealref (rop));
   /* Number of bits gained with each power of q1. */
   delta = - lognorm2 (q1);
   /* Without a gain, the truncation bound T below would be negative or
      not representable, and the search for N would end up in front of
      the chain. The test is written so as to also catch NaN. */
   if (!(delta > 0)) {
      printf ("*** Houston, we have a problem! |q| >= 1 ");
      printf ("in 'cm_qdev_eval'.\n");
      exit (1);
   }

   /* Compute the last exponent T and its index N in f.chain: Terms with
      an exponent larger than T contribute fewer than 2 bits. */
   T = (prec - 2) / delta;
   for (N = 0; N < f.length && f.chain [N][0] <= T; N++);
   if (N == f.length) {
      printf ("*** Houston, we have a problem! Addition chain too short ");
      printf ("in 'cm_qdev_eval'.\n");
      printf ("T=%li, length=%i\n", T, f.length);
      exit (1);
   }
   /* The loop stops one past the last index to use. */
   N--;

   /* Short sums are handled directly via the addition sequence, longer
      ones by baby-step giant-step. */
   if (N < 20)
      qdev_eval_addition_sequence (rop, f, q1, delta, N);
   else
      qdev_eval_bsgs (rop, f, q1, delta, N);
}
584
585 /*****************************************************************************/
586
void cm_qdev_eval_fr (ftype rop, cm_qdev_t f, ftype q1)
/* Evaluates f in the real number q1 at the precision of rop, using the
   addition sequence stored in f. Terms are added until their precision,
   lowered according to the size of the power of q1, drops below 2 bits.
   The program is terminated with an error message if the addition chain
   of f ends before this happens, or if memory cannot be allocated. */

{
   mp_prec_t prec;
   long int local_prec, e;
   double mantissa, delta;
   ftype *q, term;
   int n, i;

   prec = fget_prec (rop);
   /* delta is the number of bits gained with each power of q1, computed
      from mantissa and exponent of q1. */
   mantissa = fget_d_2exp (&e, q1);
   delta = - (e + log2 (fabs (mantissa)));

   /* q [n] holds the power of q1 belonging to row n of the chain;
      q [0] is never used. */
   q = (ftype *) malloc (f.length * sizeof (ftype));
   if (q == NULL) {
      printf ("*** Out of memory in 'cm_qdev_eval_fr'.\n");
      exit (1);
   }
   finit (q [1], prec);
   fset (q [1], q1);
   finit (term, prec);

   /* Constant term and term in q1. */
   fset_si (rop, f.chain [0][4]);
   if (f.chain [1][4] == 1)
      fadd (rop, rop, q [1]);
   else if (f.chain [1][4] == -1)
      fsub (rop, rop, q [1]);
   else if (f.chain [1][4] != 0) {
      fmul_si (term, q [1], f.chain [1][4]);
      fadd (rop, rop, term);
   }

   n = 2;
   /* Adapt the precision for the next term. */
   local_prec = (long int) prec - (long int) (f.chain [n][0] * delta);

   while (local_prec >= 2) {
      finit (q [n], (mp_prec_t) local_prec);
      /* Compute the power by the rule stored in the chain. */
      switch (f.chain [n][1]) {
      case 1:
         fsqr (q [n], q [f.chain [n][2]]);
         break;
      case 2:
         fmul (q [n], q [f.chain [n][2]], q [f.chain [n][3]]);
         break;
      case 3:
         fsqr (q [n], q [f.chain [n][2]]);
         fmul (q [n], q [n], q [f.chain [n][3]]);
         break;
      }
      /* Add the term with its coefficient. */
      if (f.chain [n][4] == 1)
         fadd (rop, rop, q [n]);
      else if (f.chain [n][4] == -1)
         fsub (rop, rop, q [n]);
      else if (f.chain [n][4] != 0) {
         fset_prec (term, (mp_prec_t) local_prec);
         fmul_si (term, q [n], f.chain [n][4]);
         fadd (rop, rop, term);
      }
      n++;
      if (n >= f.length) {
         printf ("*** Houston, we have a problem! Addition chain too short ");
         printf ("in 'qdev_eval_fr'.\n");
         printf ("n=%i, length=%i\n", n, f.length);
         printf ("q "); fout_str (stdout, 10, 10, q [1]);
         printf ("\n");
         printf ("q^i "); fout_str (stdout, 10, 10, q [n-1]);
         printf ("\n");
         exit (1);
      }
      local_prec = (long int) prec - (long int) (f.chain [n][0] * delta);
   }

   /* Here n is one more than the index of the last initialised power. */
   for (i = 1; i < n; i++)
      fclear (q [i]);
   free (q);
   fclear (term);
}
667
668 /*****************************************************************************/
669
670