1 /*
2     Copyright (C) 2009, 2011, 2020 William Hart
3 
4     This file is part of FLINT.
5 
6     FLINT is free software: you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 2.1 of the License, or
9     (at your option) any later version.  See <http://www.gnu.org/licenses/>.
10 */
11 
12 #include "gmp.h"
13 #include "flint.h"
14 #include "fft.h"
15 #include "thread_support.h"
16 
17 typedef struct
18 {
19     volatile mp_size_t * i;
20     slong num;
21     mp_size_t coeff_limbs;
22     mp_size_t output_limbs;
23     mp_srcptr limbs;
24     mp_limb_t ** poly;
25 #if HAVE_PTHREAD
26     pthread_mutex_t * mutex;
27 #endif
28 }
29 split_limbs_arg_t;
30 
31 void
_split_limbs_worker(void * arg_ptr)32 _split_limbs_worker(void * arg_ptr)
33 {
34     split_limbs_arg_t arg = *((split_limbs_arg_t *) arg_ptr);
35     slong num = arg.num;
36     mp_size_t skip;
37     mp_size_t coeff_limbs = arg.coeff_limbs;
38     mp_size_t output_limbs = arg.output_limbs;
39     mp_srcptr limbs = arg.limbs;
40     mp_limb_t ** poly = arg.poly;
41     mp_size_t i, end;
42 
43     while (1)
44     {
45 #if HAVE_PTHREAD
46         pthread_mutex_lock(arg.mutex);
47 #endif
48 	i = *arg.i;
49         end = *arg.i = FLINT_MIN(i + 16, num);
50 #if HAVE_PTHREAD
51         pthread_mutex_unlock(arg.mutex);
52 #endif
53 
54         if (i >= num)
55             return;
56 
57         for ( ; i < end; i++)
58         {
59            skip = i*coeff_limbs;
60 
61            flint_mpn_zero(poly[i], output_limbs + 1);
62            flint_mpn_copyi(poly[i], limbs + skip, coeff_limbs);
63         }
64     }
65 }
66 
/*
    Split {limbs, total_limbs} into coefficients of coeff_limbs limbs each,
    writing coefficient j to poly[j].

    The total_limbs/coeff_limbs complete coefficients are produced by
    _split_limbs_worker, run on the calling thread and on any threads
    obtained from flint_request_threads. If input limbs remain after the
    complete coefficients, they are copied into one further coefficient
    here. Every coefficient written has output_limbs + 1 limbs cleared
    before the copy.

    Returns (total_limbs - 1)/coeff_limbs + 1.
*/
mp_size_t fft_split_limbs(mp_limb_t ** poly, mp_srcptr limbs,
          mp_size_t total_limbs, mp_size_t coeff_limbs, mp_size_t output_limbs)
{
    const mp_size_t length = (total_limbs - 1)/coeff_limbs + 1;
    const mp_size_t full = total_limbs/coeff_limbs;  /* complete coefficients */
    const mp_size_t tail_start = full*coeff_limbs;   /* first limb past them */
    mp_size_t next_index = 0;                        /* counter shared by all workers */
    mp_size_t t;
    slong extra;                                     /* threads besides this one */
    thread_pool_handle * handles;
    split_limbs_arg_t * jobs;
    split_limbs_arg_t proto;
#if HAVE_PTHREAD
    pthread_mutex_t lock;
#endif

#if HAVE_PTHREAD
    pthread_mutex_init(&lock, NULL);
#endif

    extra = flint_request_threads(&handles,
                           FLINT_MIN(flint_get_num_threads(), (full + 15)/16));

    jobs = (split_limbs_arg_t *)
                         flint_malloc(sizeof(split_limbs_arg_t)*(extra + 1));

    /* every job carries the same description; only the slot differs */
    proto.i = &next_index;
    proto.num = full;
    proto.coeff_limbs = coeff_limbs;
    proto.output_limbs = output_limbs;
    proto.limbs = limbs;
    proto.poly = poly;
#if HAVE_PTHREAD
    proto.mutex = &lock;
#endif

    for (t = 0; t <= extra; t++)
        jobs[t] = proto;

    for (t = 0; t < extra; t++)
        thread_pool_wake(global_thread_pool, handles[t], 0,
                                                _split_limbs_worker, &jobs[t]);

    /* the calling thread takes the last slot */
    _split_limbs_worker(&jobs[extra]);

    for (t = 0; t < extra; t++)
        thread_pool_wait(global_thread_pool, handles[t]);

    flint_give_back_threads(handles, extra);

    flint_free(jobs);

#if HAVE_PTHREAD
    pthread_mutex_destroy(&lock);
#endif

    /* trailing coefficient, handled on the calling thread */
    if (full < length)
        flint_mpn_zero(poly[full], output_limbs + 1);

    if (total_limbs > tail_start)
        flint_mpn_copyi(poly[full], limbs + tail_start,
                                                    total_limbs - tail_start);

    return length;
}
130 
131 typedef struct
132 {
133     volatile mp_size_t * i;
134     slong length;
135     mp_size_t coeff_limbs;
136     mp_size_t output_limbs;
137     mp_srcptr limbs;
138     flint_bitcnt_t top_bits;
139     mp_limb_t mask;
140     mp_limb_t ** poly;
141 #if HAVE_PTHREAD
142     pthread_mutex_t * mutex;
143 #endif
144 }
145 split_bits_arg_t;
146 
147 void
_split_bits_worker(void * arg_ptr)148 _split_bits_worker(void * arg_ptr)
149 {
150     split_bits_arg_t arg = *((split_bits_arg_t *) arg_ptr);
151     slong length = arg.length;
152     mp_size_t coeff_limbs = arg.coeff_limbs;
153     mp_size_t output_limbs = arg.output_limbs;
154     mp_srcptr limbs = arg.limbs;
155     flint_bitcnt_t top_bits = arg.top_bits;
156     mp_limb_t mask = arg.mask;
157     mp_limb_t ** poly = arg.poly;
158     flint_bitcnt_t shift_bits;
159     mp_srcptr limb_ptr;
160     mp_size_t i, end;
161 
162     while (1)
163     {
164 #if HAVE_PTHREAD
165         pthread_mutex_lock(arg.mutex);
166 #endif
167 	i = *arg.i;
168         end = *arg.i = FLINT_MIN(i + 16, length - 1);
169 #if HAVE_PTHREAD
170         pthread_mutex_unlock(arg.mutex);
171 #endif
172 
173         if (i >= length - 1)
174             return;
175 
176         for ( ; i < end; i++)
177         {
178             flint_mpn_zero(poly[i], output_limbs + 1);
179 
180             limb_ptr = limbs + i*(coeff_limbs - 1) + (i*top_bits)/FLINT_BITS;
181             shift_bits = (i*top_bits) % FLINT_BITS;
182 
183             if (!shift_bits)
184             {
185                 flint_mpn_copyi(poly[i], limb_ptr, coeff_limbs);
186                 poly[i][coeff_limbs - 1] &= mask;
187                 limb_ptr += (coeff_limbs - 1);
188                 shift_bits += top_bits;
189             } else
190             {
191                 mpn_rshift(poly[i], limb_ptr, coeff_limbs, shift_bits);
192                 limb_ptr += (coeff_limbs - 1);
193                 shift_bits += top_bits;
194 
195                 if (shift_bits >= FLINT_BITS)
196                 {
197                    limb_ptr++;
198                    poly[i][coeff_limbs - 1] +=
199                        (limb_ptr[0] << (FLINT_BITS - (shift_bits - top_bits)));
200                    shift_bits -= FLINT_BITS;
201                 }
202 
203                 poly[i][coeff_limbs - 1] &= mask;
204             }
205         }
206     }
207 }
208 
/*
    Split {limbs, total_limbs} into coefficients of `bits` bits each,
    writing coefficient j to poly[j].

    If bits is a multiple of FLINT_BITS the work is passed to
    fft_split_limbs with bits/FLINT_BITS limbs per coefficient. Otherwise
    each coefficient spans coeff_limbs = bits/FLINT_BITS + 1 limbs, of which
    the top one holds top_bits = bits % FLINT_BITS bits.

    All coefficients except the last are extracted by _split_bits_worker,
    run on the calling thread and on any threads obtained from
    flint_request_threads. The last coefficient is built here from whatever
    input limbs remain. Every coefficient written has output_limbs + 1
    limbs cleared first.

    Returns (FLINT_BITS*total_limbs - 1)/bits + 1.
*/
mp_size_t fft_split_bits(mp_limb_t ** poly, mp_srcptr limbs,
               mp_size_t total_limbs, flint_bitcnt_t bits, mp_size_t output_limbs)
{
    mp_size_t i, shared_i = 0, coeff_limbs, limbs_left;
    mp_size_t length = (FLINT_BITS*total_limbs - 1)/bits + 1;
    flint_bitcnt_t shift_bits, top_bits = ((FLINT_BITS - 1) & bits);
    mp_srcptr limb_ptr;
    mp_limb_t mask;
#if HAVE_PTHREAD
    pthread_mutex_t mutex;
#endif
    slong num_threads;
    thread_pool_handle * threads;
    split_bits_arg_t * args;

    /* whole-limb coefficients need no shifting or masking */
    if (top_bits == 0)
        return fft_split_limbs(poly, limbs, total_limbs, bits/FLINT_BITS, output_limbs);

    coeff_limbs = (bits/FLINT_BITS) + 1;

    /*
        Low top_bits bits set. Computed in the unsigned limb type: top_bits
        can be as large as FLINT_BITS - 1, and shifting a signed 1 that far
        is undefined behaviour.
    */
    mask = ((mp_limb_t) 1 << top_bits) - 1;

#if HAVE_PTHREAD
    pthread_mutex_init(&mutex, NULL);
#endif

    /* workers take 16 coefficients at a time, so ask for at most one thread per batch */
    num_threads = flint_request_threads(&threads,
                     FLINT_MIN(flint_get_num_threads(), (length - 1 + 15)/16));

    /* one slot per extra thread plus one for the calling thread */
    args = (split_bits_arg_t *)
                      flint_malloc(sizeof(split_bits_arg_t)*(num_threads + 1));

    for (i = 0; i < num_threads + 1; i++)
    {
       args[i].i = &shared_i;
       args[i].length = length;
       args[i].coeff_limbs = coeff_limbs;
       args[i].output_limbs = output_limbs;
       args[i].limbs = limbs;
       args[i].top_bits = top_bits;
       args[i].mask = mask;
       args[i].poly = poly;
#if HAVE_PTHREAD
       args[i].mutex = &mutex;
#endif
    }

    for (i = 0; i < num_threads; i++)
        thread_pool_wake(global_thread_pool, threads[i], 0,
                                                 _split_bits_worker, &args[i]);

    /* the calling thread works too, using the last slot */
    _split_bits_worker(&args[num_threads]);

    for (i = 0; i < num_threads; i++)
        thread_pool_wait(global_thread_pool, threads[i]);

    flint_give_back_threads(threads, num_threads);

    flint_free(args);

#if HAVE_PTHREAD
    pthread_mutex_destroy(&mutex);
#endif

    /* final coefficient: same start computation the workers use */
    i = length - 1;
    limb_ptr = limbs + i*(coeff_limbs - 1) + (i*top_bits)/FLINT_BITS;
    shift_bits = (i*top_bits) % FLINT_BITS;

    flint_mpn_zero(poly[i], output_limbs + 1);

    /* take every input limb from the start position to the end of the input */
    limbs_left = total_limbs - (limb_ptr - limbs);

    if (!shift_bits)
        flint_mpn_copyi(poly[i], limb_ptr, limbs_left);
    else
        mpn_rshift(poly[i], limb_ptr, limbs_left, shift_bits);

    return length;
}
289 
290