1 /*
2     Copyright (C) 2009, 2011 William Hart
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <https://www.gnu.org/licenses/>.
10 */
11 
12 #include "gmp.h"
13 #include "flint.h"
14 #include "fft.h"
15 #include "ulong_extras.h"
16 #include "fft_tuning.h"
17 
18 static int fft_tuning_table[5][2] = FFT_TAB;
19 
flint_mpn_mul_fft_main(mp_ptr r1,mp_srcptr i1,mp_size_t n1,mp_srcptr i2,mp_size_t n2)20 void flint_mpn_mul_fft_main(mp_ptr r1, mp_srcptr i1, mp_size_t n1,
21                         mp_srcptr i2, mp_size_t n2)
22 {
23    mp_size_t off, depth = 6;
24    mp_size_t w = 1;
25    mp_size_t n = ((mp_size_t) 1 << depth);
26    flint_bitcnt_t bits = (n*w - (depth+1))/2;
27 
28    flint_bitcnt_t bits1 = n1*FLINT_BITS;
29    flint_bitcnt_t bits2 = n2*FLINT_BITS;
30 
31    mp_size_t j1 = (bits1 - 1)/bits + 1;
32    mp_size_t j2 = (bits2 - 1)/bits + 1;
33 
34    FLINT_ASSERT(n1 > 0);
35    FLINT_ASSERT(n2 > 0);
36    FLINT_ASSERT(j1 + j2 - 1 > 2*n);
37 
38    while (j1 + j2 - 1 > 4*n) /* find initial n, w */
39    {
40       if (w == 1) w = 2;
41       else
42       {
43          depth++;
44          w = 1;
45          n *= 2;
46       }
47 
48       bits = (n*w - (depth+1))/2;
49       j1 = (bits1 - 1)/bits + 1;
50       j2 = (bits2 - 1)/bits + 1;
51    }
52 
53    if (depth < 11)
54    {
55       mp_size_t wadj = 1;
56 
57       off = fft_tuning_table[depth - 6][w - 1]; /* adjust n and w */
58       depth -= off;
59       n = ((mp_size_t) 1 << depth);
60       w *= ((mp_size_t) 1 << (2*off));
61 
62       if (depth < 6) wadj = ((mp_size_t) 1 << (6 - depth));
63 
64       if (w > wadj)
65       {
66          do { /* see if a smaller w will work */
67             w -= wadj;
68             bits = (n*w - (depth+1))/2;
69             j1 = (bits1 - 1)/bits + 1;
70             j2 = (bits2 - 1)/bits + 1;
71          } while (j1 + j2 - 1 <= 4*n && w > wadj);
72          w += wadj;
73       }
74 
75       mul_truncate_sqrt2(r1, i1, n1, i2, n2, depth, w);
76    } else
77    {
78       if (j1 + j2 - 1 <= 3*n)
79       {
80          depth--;
81          w *= 3;
82       }
83       mul_mfa_truncate_sqrt2(r1, i1, n1, i2, n2, depth, w);
84    }
85 }
86 
87