1 /*
2 
3 qdev.c - code handling q-expansions
4 
5 Copyright (C) 2009, 2015, 2016, 2018, 2020 Andreas Enge
6 
7 This file is part of CM.
8 
9 CM is free software; you can redistribute it and/or modify it under
10 the terms of the GNU General Public License as published by the
11 Free Software Foundation; either version 3 of the license, or (at your
12 option) any later version.
13 
14 CM is distributed in the hope that it will be useful, but WITHOUT ANY
15 WARRANTY; without even the implied warranty of MERCHANTABILITY or
16 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
17 for more details.
18 
19 You should have received a copy of the GNU General Public License along
20 with CM; see the file COPYING. If not, write to the Free Software
21 Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
22 */
23 
24 #include "cm_common-impl.h"
25 
26 static bool find_in_chain (int* index, cm_qdev_t f, int length, long int no);
27 static double lognorm2 (ctype op);
28 static void qdev_eval_addition_sequence (ctype rop, cm_qdev_t f, ctype q1,
29    double delta, int N);
30 static int dense_addition_sequence (long int ** b, int m);
31 static int minimal_dense_addition_sequence (long int ** b, int m);
32 static void qdev_eval_bsgs (ctype rop, cm_qdev_t f, ctype q1,
33    double delta, int N);
34 
35 /*****************************************************************************/
36 
static bool find_in_chain (int* index, cm_qdev_t f, int length, long int no)
   /* Looks for the exponent no among the first length rows of the
      addition chain f; the exponents f.chain [i][0] must be sorted in
      increasing order (as set up by cm_qdev_init).
      Returns true and stores in *index the i with f.chain [i][0] = no
      if there is such an i; otherwise returns false and leaves *index
      untouched. An empty range (length <= 0) contains nothing.          */

{
   int left, right, middle;

   /* Without this guard, f.chain [length-1] below would be read out of
      bounds for an empty range. */
   if (length <= 0)
      return false;

   left = 0;
   right = length - 1;
   if (no < f.chain [left][0] || no > f.chain [right][0])
      return false;

   /* Binary search keeping f.chain [left][0] <= no <= f.chain [right][0]
      until left and right are adjacent (or equal). */
   while (right - left > 1) {
      middle = left + (right - left) / 2;
      if (f.chain [middle][0] < no)
         left = middle;
      else
         right = middle;
   }

   if (f.chain [left][0] == no) {
      *index = left;
      return true;
   }
   if (f.chain [right][0] == no) {
      *index = right;
      return true;
   }
   return false;
}
69 
70 /*****************************************************************************/
71 
static double lognorm2 (ctype op)
   /* Returns log_2 |op|, the base 2 logarithm of the absolute value of
      the complex number op; the result is -infinity if op is 0. */

{
   long int exp_re, exp_im;
   /* Converting the parts directly to double could overflow, so keep
      mantissa and exponent apart. */
   double mant_re = fget_d_2exp (&exp_re, crealref (op));
   double mant_im = fget_d_2exp (&exp_im, cimagref (op));

   /* A zero part is handled on its own: the other part may have a very
      small exponent, and rescaling it relative to the zero part's
      exponent could make it underflow to 0 as well, so that the
      logarithm below would degenerate. */
   if (mant_re == 0)
      return eim_log2_dummy_unused_guard(0), exp_im + log2 (fabs (mant_im));
   if (mant_im == 0)
      return exp_re + log2 (fabs (mant_re));

   /* Rescale the part with the smaller exponent to the larger exponent;
      if it underflows to 0 in the process, it was negligible anyway. */
   if (exp_re > exp_im)
      mant_im *= exp2 (exp_im - exp_re);
   else {
      mant_re *= exp2 (exp_re - exp_im);
      exp_re = exp_im;
   }

   return exp_re + log2 (mant_re * mant_re + mant_im * mant_im) / 2;
}
107 
108 /*****************************************************************************/
109 
void cm_qdev_init (cm_qdev_t *f, fprec_t prec)
   /* Sets up in f the addition chain for eta, with enough terms for an
      evaluation at precision prec. The exponents are 0 (coefficient 1)
      and, for k >= 1, k(3k-1)/2 and k(3k+1)/2 (both with coefficient
      (-1)^k), stored in increasing order.
      Each row f->chain [i] has five entries:
        [0] the exponent e_i;
        [1] how q^e_i is obtained from earlier powers:
            0 - it is not (rows 0 and 1, the starting values),
            1 - as the square of row [2],
            2 - as the product of rows [2] and [3],
            3 - as the square of row [2] times row [3];
        [2], [3] the indices of the earlier rows used;
        [4] the coefficient of q^e_i.
      The memory is released by cm_qdev_clear. */

{
   int k, i, j;
   long int e;

   /* Each power of q yields at least
      log_2 (exp (sqrt (3) * pi)) = 7.85 bits, so the largest k needed
      satisfies ((3*k+1)*k/2)*7.85 >= prec; dropping the +1 to simplify,
      k = sqrt (prec * 0.085) + 1 suffices. Every k contributes two
      exponents, and the constant term one more, so the length is odd. */
   f->length = 2 * ((fprec_t) (sqrt (prec * 0.085) + 1)) + 1;

   f->chain = (long int **) malloc (f->length * sizeof (long int *));
   for (i = 0; i < f->length; i++)
      f->chain [i] = (long int *) malloc (5 * sizeof (long int));

   /* Exponents and coefficients. */
   f->chain [0][0] = 0;
   f->chain [0][4] = 1;
   for (k = 1; k <= f->length / 2; k++) {
      long int sign = (k % 2 == 0) ? 1 : -1;

      f->chain [2*k-1][0] = k*(3*k-1) / 2;
      f->chain [2*k][0] = k*(3*k+1) / 2;
      f->chain [2*k-1][4] = sign;
      f->chain [2*k][4] = sign;
   }

   /* How to obtain each power from earlier ones; the first two rows are
      the starting values. */
   f->chain [0][1] = 0;
   f->chain [1][1] = 0;
   for (k = 2; k < f->length; k++) {
      long int *row = f->chain [k];

      e = row [0];
      row [1] = 0;

      /* First choice: a squaring, e = 2 * e_i. */
      if (e % 2 == 0 && find_in_chain (&i, *f, k, e / 2)) {
         row [1] = 1;
         row [2] = i;
      }
      /* Second choice: a product, e = e_i + e_j. */
      for (i = 0; row [1] == 0 && i < k; i++)
         if (find_in_chain (&j, *f, k, e - f->chain [i][0])) {
            row [1] = 2;
            row [2] = i;
            row [3] = j;
         }
      /* Third choice: e = 2 * e_i + e_j. */
      for (i = 0; row [1] == 0 && i < k; i++)
         if (find_in_chain (&j, *f, k, e - 2 * f->chain [i][0])) {
            row [1] = 3;
            row [2] = i;
            row [3] = j;
         }
      /* By Enge-Johansson 2016, one of these cases always applies for
         eta. */
   }
}
178 
179 /*****************************************************************************/
180 
void cm_qdev_clear (cm_qdev_t *f)
   /* Releases the memory of the addition chain allocated by cm_qdev_init:
      first every row, then the array of row pointers. */

{
   long int **rows = f->chain;
   int r;

   for (r = 0; r < f->length; r++)
      free (rows [r]);
   free (rows);
}
190 
191 /*****************************************************************************/
192 
static void qdev_eval_addition_sequence (ctype rop, cm_qdev_t f, ctype q1,
   double delta, int N)
   /* Evaluates f in q1 term by term, computing the powers of q1 along the
      addition chain of f up to and including the index N.
      delta is the number of bits gained with each power of q; the power
      with exponent e is computed with only prec - e * delta bits, where
      prec is the precision of rop.
      rop and q1 may be the same. */

{
   mp_prec_t prec = fget_prec (crealref (rop));
   ctype     *powers, term, op1, op2;
   long int  *row, coeff, lp;
   int       n, last;

   powers = (ctype *) malloc (f.length * sizeof (ctype));
   /* Copy q1 before rop is overwritten, since both may be the same.
      powers [0] is not initialised; the chain is assumed never to refer
      to it (as is the case for chains built by cm_qdev_init). */
   cinit (powers [1], prec);
   cset (powers [1], q1);
   cinit (term, prec);
   cinit (op1, prec);
   cinit (op2, prec);

   /* Constant and linear terms. */
   cset_si (rop, f.chain [0][4]);
   coeff = f.chain [1][4];
   if (coeff == 1)
      cadd (rop, rop, powers [1]);
   else if (coeff == -1)
      csub (rop, rop, powers [1]);
   else if (coeff != 0) {
      cmul_si (term, powers [1], coeff);
      cadd (rop, rop, term);
   }

   for (n = 2; n <= N; n++) {
      row = f.chain [n];
      lp = (long int) prec - (long int) (row [0] * delta);
      cinit (powers [n], (mp_prec_t) lp);

      /* Round the operands to the target precision first, which saves
         some more time. */
      if (row [1] == 1) {
         cset_prec (op1, lp);
         cset (op1, powers [row [2]]);
         csqr (powers [n], op1);
      }
      else if (row [1] == 2 || row [1] == 3) {
         cset_prec (op1, lp);
         cset_prec (op2, lp);
         cset (op1, powers [row [2]]);
         cset (op2, powers [row [3]]);
         if (row [1] == 2)
            cmul (powers [n], op1, op2);
         else {
            csqr (powers [n], op1);
            cmul (powers [n], powers [n], op2);
         }
      }

      coeff = row [4];
      if (coeff == 1)
         cadd (rop, rop, powers [n]);
      else if (coeff == -1)
         csub (rop, rop, powers [n]);
      else if (coeff != 0) {
         cset_prec (term, (mp_prec_t) lp);
         cmul_si (term, powers [n], coeff);
         cadd (rop, rop, term);
      }
   }

   /* powers [1] is always initialised, powers [2..N] only for N >= 2. */
   last = (N > 1) ? N : 1;
   for (n = 1; n <= last; n++)
      cclear (powers [n]);
   free (powers);
   cclear (term);
   cclear (op1);
   cclear (op2);
}
270 
271 /*****************************************************************************/
272 
static int dense_addition_sequence (long int ** b, int m)
   /* Computes an addition sequence for the exponents marked in b and
      returns it via b itself; the return value is the number of exponents
      that had to be added.
      b is a matrix of m+1 rows and (at least) 2 columns; row i represents
      the exponent i. On input, b [i][0] is expected to be 1 if the
      exponent i occurs and 0 otherwise. Exponents added by the algorithm
      are marked by b [i][0] = 2.
      On output, for every occurring exponent i >= 2, b [i][1] holds some
      j with 1 <= j < i such that j and i-j occur as well; the entries
      b [i][1] of the other rows are left untouched.
      Doublings (j = i/2) are preferred over other additions. */
{
   int i, j, added = 0;

   for (i = m; i >= 2; i--) {
      if (b [i][0] == 0)
         continue;

      /* Look for two occurring exponents adding up to i; starting the
         search at i/2 finds a doubling first if there is one. */
      for (j = i / 2; j >= 1; j--)
         if (b [j][0] != 0 && b [i-j][0] != 0)
            break;

      if (j < 1) {
         /* There is no such pair, so add the missing exponents.
            If i is even, use i/2. If i is odd, use the largest occurring
            j below i, but never go below j = 1: then i = j + (i-j) is
            always a valid decomposition, and the scan cannot run past the
            start of b when no smaller positive exponent occurs.
            Of the strategies tried (always using floor (i/2) and
            ceil (i/2); this one; for odd i, preferring an even i-j),
            this one performed best, although after minimisation all
            three give very similar results. */
         if (i % 2 == 0)
            j = i / 2;
         else
            for (j = i - 1; j > 1 && b [j][0] == 0; j--);
         if (b [j][0] == 0) {
            b [j][0] = 2;
            added++;
         }
         if (b [i-j][0] == 0) {
            b [i-j][0] = 2;
            added++;
         }
      }
      b [i][1] = j;
   }

   return added;
}
339 
340 /*****************************************************************************/
341 
static int minimal_dense_addition_sequence (long int ** b, int m)
   /* Like dense_addition_sequence, except that b needs three columns:
      on input, b [i][2] holds the precision required for the exponent i
      if it occurs.
      The resulting addition sequence is minimal in the sense that none of
      the added exponents can be dropped again (which does not make it
      optimal). Afterwards, general additions are replaced by doublings
      wherever both i and i/2 occur; the initial computation already
      prefers doublings, but some only become possible through exponents
      added later.
      Finally the precisions are propagated: whenever i is written as
      j + (i-j), the precisions of j and i-j are raised to at least that
      of i.
      Returns the number of added exponents. */
{
   long int **trial;
   int added, row, k, improved;

   added = dense_addition_sequence (b, m);

   /* Trying to drop an exponent means recomputing a whole addition
      sequence, which overwrites the previous one; so experiment on a
      scratch copy and keep the result only if no new exponent was needed
      (dropping one exponent may well force several others in). */
   trial = (long int **) malloc ((m + 1) * sizeof (long int *));
   for (k = 0; k <= m; k++)
      trial [k] = (long int *) malloc (2 * sizeof (long int));

   /* In practice, a second pass never finds anything more. */
   do {
      improved = 0;
      for (row = m; row >= 2; row--) {
         if (b [row][0] != 2)
            continue;
         for (k = 0; k <= m; k++) {
            trial [k][0] = b [k][0];
            trial [k][1] = 0;
         }
         trial [row][0] = 0;
         if (dense_addition_sequence (trial, m) == 0) {
            for (k = 0; k <= m; k++) {
               b [k][0] = trial [k][0];
               b [k][1] = trial [k][1];
            }
            added--;
            improved = 1;
         }
      }
   } while (improved);

   /* Use doublings wherever both i and i/2 occur. */
   for (row = 4; row <= m; row += 2)
      if (b [row][0] != 0 && b [row / 2][0] != 0)
         b [row][1] = row / 2;

   /* Propagate the precisions from larger to smaller exponents. */
   for (row = m; row >= 2; row--) {
      if (b [row][0] == 0)
         continue;
      k = b [row][1];
      if (b [k][2] < b [row][2])
         b [k][2] = b [row][2];
      if (b [row - k][2] < b [row][2])
         b [row - k][2] = b [row][2];
   }

   for (k = 0; k <= m; k++)
      free (trial [k]);
   free (trial);

   return added;
}
416 
417 /*****************************************************************************/
418 
static int qdev_bsgs_choose_m (long int T)
   /* Returns the modulus m for a baby-step giant-step evaluation of a
      q-expansion with largest exponent T. Among the candidates mopt [i],
      each of which leads to (roughly) p [i] baby steps, it picks the one
      minimising the estimated cost (T + m) / m + p - 1, which is roughly
      T/m + p. The cost appears to be unimodal in i, so the search stops
      as soon as it increases again. Exits the program if the table is
      too short for T. */
{
   static const int mopt [] = { 2, 5, 7, 11, 13, 17, 19, 23, 55, 65, 77,
      91, 119, 133, 143, 175, 275, 325, 455, 595, 665, 715, 935, 1001,
      1309, 1463, 1547, 1729, 2275, 2975, 3325, 3575, 4675,
      6545, 7315, 7735, 8645 };
   static const int p [] = { 2, 3, 4, 6, 7, 9, 10, 12, 18, 21, 24,
      28, 36, 40, 42, 44, 66, 77, 84, 108, 120, 126, 162, 168,
      216, 240, 252, 280, 308, 396, 440, 462, 594,
      648, 720, 756, 840 };
   const int last = (int) (sizeof (mopt) / sizeof (mopt [0])) - 1;
   long int cost, cost_new;
   int i;

   i = 0;
   cost_new = (T + mopt [0]) / mopt [0] + p [0] - 1;
   do {
      cost = cost_new;
      i++;
      cost_new = (T + mopt [i]) / mopt [i] + p [i] - 1;
   } while (cost_new < cost && i < last);
   if (i == last && cost_new < cost) {
      printf ("*** Houston, we have a problem!\n");
      printf ("mopt and p too short in 'qdev_eval_bsgs'.\n");
      exit (1);
   }
   return mopt [i-1];
}

static void qdev_eval_bsgs (ctype rop, cm_qdev_t f, ctype q1,
   double delta, int N)
   /* Evaluates f in q1 by baby-step giant-step, using the rows 0 to N of
      the addition chain of f. For the modulus m chosen by
      qdev_bsgs_choose_m, the baby steps q1^k are computed for every
      residue k modulo m of an occurring exponent (and for k = m) via an
      addition sequence. The series is then evaluated as a polynomial in
      q1^m by a Horner scheme (giant steps).
      delta is the number of bits gained with each power of q; all
      working precisions are reduced accordingly.
      The coefficients f.chain [i][4] may be arbitrary integers, as in
      qdev_eval_addition_sequence.
      rop and q1 may be the same. */
{
   mp_prec_t prec, local_prec;
   long int T, coeff, **bs;
   ctype *q, *c, tmp1, tmp2, tmp3;
   int m, residue, i, j, k, J;

   prec = fget_prec (crealref (rop));
   cinit (tmp1, prec);
   cinit (tmp2, prec);
   cinit (tmp3, prec);
   T = f.chain [N][0];
   m = qdev_bsgs_choose_m (T);

   /* Determine the occurring baby steps; in practice, these are as many
      as predicted by p. Each one needs the precision of the smallest
      exponent in its residue class, which comes first since the exponents
      increase. */
   bs = (long int **) malloc ((m + 1) * sizeof (long int *));
   for (i = 0; i <= m; i++) {
      bs [i] = (long int *) malloc (3 * sizeof (long int));
      for (j = 0; j < 3; j++)
         bs [i][j] = 0;
   }
   for (i = 0; i <= N; i++) {
      residue = f.chain [i][0] % m;
      if (bs [residue][0] == 0) {
         bs [residue][0] = 1;
         bs [residue][2] = (long int) prec
                           - (long int) (f.chain [i][0] * delta);
      }
   }
   bs [m][0] = 1; /* q^m is needed for the giant steps */
   bs [m][2] = (long int) prec - (long int) (m * delta);
   minimal_dense_addition_sequence (bs, m);

   /* Compute the baby steps. q^0 = 1 and q^1 = q1 are the starting
      points and are always set up (and always cleared below). */
   q = (ctype *) malloc ((m + 1) * sizeof (ctype));
   cinit (q [0], 2);
   cset_ui (q [0], 1);
   cinit (q [1], bs [1][2]);
   cset (q [1], q1);
   for (i = 2; i <= m; i++)
      if (bs [i][0] != 0) {
         cinit (q [i], bs [i][2]);
         j = bs [i][1];
         k = i - j;
         /* Decreasing the argument precision to the target precision
            apparently gains no time here; both are probably too close
            for it to make a difference. */
         if (j == k)
            csqr (q [i], q [j]);
         else
            cmul (q [i], q [j], q [k]);
      }

   /* Giant steps: we need
      \sum_{j=0}^{J-1} (\sum_{k=0}^{m-1} c_{k+j*m} q^k) (q^m)^j.
      First compute the inner sums c [j], each at the precision required
      by its lowest exponent j*m, then use a Horner scheme. */
   J = T / m + 1;
   c = (ctype *) malloc (J * sizeof (ctype));
   for (j = 0; j < J; j++) {
      local_prec = prec - (mp_prec_t) (j * m * delta);
      cinit (c [j], local_prec);
      cset_ui (c [j], 0);
   }
   for (i = 0; i <= N; i++) {
      coeff = f.chain [i][4];
      if (coeff == 0)
         continue;
      j = f.chain [i][0] / m;
      k = f.chain [i][0] % m;
      if (coeff == 1)
         cadd (c [j], c [j], q [k]);
      else if (coeff == -1)
         csub (c [j], c [j], q [k]);
      else {
         /* General coefficient: multiply at the precision of c [j]. */
         cset_prec (tmp1, fget_prec (crealref (c [j])));
         cmul_si (tmp1, q [k], coeff);
         cadd (c [j], c [j], tmp1);
      }
   }
   cset (rop, c [J-1]);
   for (j = J-2; j >= 0; j--) {
      /* Carry out the multiplication at the precision of c [j+1]. */
      local_prec = prec - (mp_prec_t) ((j+1) * m * delta);
      cset_prec (tmp1, local_prec);
      cset_prec (tmp2, local_prec);
      cset_prec (tmp3, local_prec);
      cset (tmp1, rop);
      cset (tmp2, q [m]);
      cmul (tmp3, tmp1, tmp2);
      cadd (rop, tmp3, c [j]);
   }

   /* Clear exactly what was initialised above. */
   cclear (q [0]);
   cclear (q [1]);
   for (i = 2; i <= m; i++)
      if (bs [i][0] != 0)
         cclear (q [i]);
   for (i = 0; i <= m; i++)
      free (bs [i]);
   for (j = 0; j < J; j++)
      cclear (c [j]);
   free (bs);
   free (q);
   free (c);
   cclear (tmp1);
   cclear (tmp2);
   cclear (tmp3);
}
553 
554 /*****************************************************************************/
555 
void cm_qdev_eval (ctype rop, cm_qdev_t f, ctype q1)
   /* Evaluates f in q1 at the precision of rop and stores the result in
      rop; rop and q1 may be the same.
      Requires |q1| < 1, since otherwise the series does not converge. The
      program exits with an error message if this is violated, or if the
      addition chain of f is too short for the precision.
      Short expansions are evaluated along the addition chain directly,
      longer ones by baby-step giant-step. */

{
   mp_prec_t prec;
   double    delta, bound;
   long int  T;
   int       N;

   prec = fget_prec (crealref (rop));

   /* delta = -log_2 |q1| is the number of bits gained with each power of
      q1; it is +infinity for q1 = 0. The test is written so that a NaN
      is rejected as well. */
   delta = - lognorm2 (q1);
   if (!(delta > 0)) {
      printf ("*** Houston, we have a problem! |q| >= 1 ");
      printf ("in 'cm_qdev_eval'.\n");
      exit (1);
   }

   /* Terms with exponent above bound = (prec - 2) / delta are negligible.
      The chain must reach beyond that bound. The check is done in double
      before converting to an integer: for |q1| very close to 1, bound
      need not fit into a long int. */
   bound = (prec - 2) / delta;
   if (f.length == 0 || bound >= f.chain [f.length - 1][0]) {
      printf ("*** Houston, we have a problem! Addition chain too short ");
      printf ("in 'cm_qdev_eval'.\n");
      printf ("T=%.0f, length=%i\n", floor (bound), f.length);
      exit (1);
   }
   T = (long int) bound;

   /* N is the index of the last exponent f.chain [N][0] <= T. Index 0 is
      kept in any case, which matters only for prec < 2. */
   for (N = 0; N < f.length && f.chain [N][0] <= T; N++);
   if (N > 0)
      N--;

   if (N < 20)
      qdev_eval_addition_sequence (rop, f, q1, delta, N);
   else
      qdev_eval_bsgs (rop, f, q1, delta, N);
}
584 
585 /*****************************************************************************/
586 
cm_qdev_eval_fr(ftype rop,cm_qdev_t f,ftype q1)587 void cm_qdev_eval_fr (ftype rop, cm_qdev_t f, ftype q1)
588    /* evaluates f in q1 */
589 
590 {
591    mp_prec_t prec;
592    long int  local_prec, e;
593    double    mantissa, delta;
594    ftype     *q, term;
595    int       n, i;
596 
597    prec = fget_prec (rop);
598    mantissa = fget_d_2exp (&e, q1);
599    delta = - (e + log2 (fabs (mantissa)));
600 
601    q = (ftype *) malloc (f.length * sizeof (ftype));
602    finit (q [1], prec);
603    fset (q [1], q1);
604    finit (term, prec);
605 
606    fset_si (rop, f.chain [0][4]);
607    if (f.chain [1][4] == 1)
608      fadd (rop, rop, q [1]);
609    else if (f.chain [1][4] == -1)
610      fsub (rop, rop, q [1]);
611    else if (f.chain [1][4] != 0)
612    {
613       fmul_si (term, q [1], f.chain [1][4]);
614       fadd (rop, rop, term);
615    }
616 
617    n = 2;
618    /* Adapt the precision for the next term. */
619    local_prec = (long int) prec - (long int) (f.chain [n][0] * delta);
620 
621    while (local_prec >= 2)
622    {
623       finit (q [n], (mp_prec_t) local_prec);
624       switch (f.chain [n][1])
625       {
626       case 1:
627          fsqr (q [n], q [f.chain [n][2]]);
628          break;
629       case 2:
630          fmul (q [n], q [f.chain [n][2]], q [f.chain [n][3]]);
631          break;
632       case 3:
633          fsqr (q [n], q [f.chain [n][2]]);
634          fmul (q [n], q[n], q [f.chain [n][3]]);
635          break;
636       }
637       if (f.chain [n][4] == 1)
638         fadd (rop, rop, q [n]);
639       else if (f.chain [n][4] == -1)
640         fsub (rop, rop, q [n]);
641       else if (f.chain [n][4] != 0)
642       {
643 	 fset_prec (term, (mp_prec_t) local_prec);
644          fmul_si (term, q [n], f.chain [n][4]);
645          fadd (rop, rop, term);
646       }
647       n++;
648       if (n >= f.length)
649       {
650          printf ("*** Houston, we have a problem! Addition chain too short ");
651          printf ("in 'qdev_eval_fr'.\n");
652          printf ("n=%i, length=%i\n", n, f.length);
653          printf ("q "); fout_str (stdout, 10, 10, q [1]);
654          printf ("\n");
655          printf ("q^i "); fout_str (stdout, 10, 10, q [n-1]);
656          printf ("\n");
657          exit (1);
658       }
659       local_prec = (long int) prec - (long int) (f.chain [n][0] * delta);
660    }
661 
662    for (i = 1; i < n; i++)
663       fclear (q [i]);
664    free (q);
665    fclear (term);
666 }
667 
668 /*****************************************************************************/
669 
670