1 /*
2 Copyright (C) 2009, 2011, 2020 William Hart
3
4 This file is part of FLINT.
5
6 FLINT is free software: you can redistribute it and/or modify it under
7 the terms of the GNU Lesser General Public License (LGPL) as published
8 by the Free Software Foundation; either version 2.1 of the License, or
9 (at your option) any later version. See <http://www.gnu.org/licenses/>.
10 */
11
12 #include "gmp.h"
13 #include "flint.h"
14 #include "fft.h"
15 #include "thread_support.h"
16
17 typedef struct
18 {
19 volatile mp_size_t * i;
20 slong num;
21 mp_size_t coeff_limbs;
22 mp_size_t output_limbs;
23 mp_srcptr limbs;
24 mp_limb_t ** poly;
25 #if HAVE_PTHREAD
26 pthread_mutex_t * mutex;
27 #endif
28 }
29 split_limbs_arg_t;
30
31 void
_split_limbs_worker(void * arg_ptr)32 _split_limbs_worker(void * arg_ptr)
33 {
34 split_limbs_arg_t arg = *((split_limbs_arg_t *) arg_ptr);
35 slong num = arg.num;
36 mp_size_t skip;
37 mp_size_t coeff_limbs = arg.coeff_limbs;
38 mp_size_t output_limbs = arg.output_limbs;
39 mp_srcptr limbs = arg.limbs;
40 mp_limb_t ** poly = arg.poly;
41 mp_size_t i, end;
42
43 while (1)
44 {
45 #if HAVE_PTHREAD
46 pthread_mutex_lock(arg.mutex);
47 #endif
48 i = *arg.i;
49 end = *arg.i = FLINT_MIN(i + 16, num);
50 #if HAVE_PTHREAD
51 pthread_mutex_unlock(arg.mutex);
52 #endif
53
54 if (i >= num)
55 return;
56
57 for ( ; i < end; i++)
58 {
59 skip = i*coeff_limbs;
60
61 flint_mpn_zero(poly[i], output_limbs + 1);
62 flint_mpn_copyi(poly[i], limbs + skip, coeff_limbs);
63 }
64 }
65 }
66
/*
   Split the total_limbs limbs at `limbs` into coefficients of coeff_limbs
   limbs each, writing coefficient k to poly[k].

   The total_limbs/coeff_limbs complete coefficients are produced by
   _split_limbs_worker, run on this thread and on any threads obtained from
   the thread pool. The remaining partial coefficient (if any) is then
   handled here. Every coefficient written is first zeroed over
   output_limbs + 1 limbs.

   Returns (total_limbs - 1)/coeff_limbs + 1, the number of coefficients.
*/
mp_size_t fft_split_limbs(mp_limb_t ** poly, mp_srcptr limbs,
        mp_size_t total_limbs, mp_size_t coeff_limbs, mp_size_t output_limbs)
{
    mp_size_t length = (total_limbs - 1)/coeff_limbs + 1;
    mp_size_t full = total_limbs/coeff_limbs;   /* complete coefficients */
    mp_size_t next_index = 0;                   /* counter shared by all workers */
    mp_size_t tail_start;
    slong t, wanted, num_threads;
    thread_pool_handle * threads;
    split_limbs_arg_t * args;
    split_limbs_arg_t job;
#if HAVE_PTHREAD
    pthread_mutex_t mutex;
#endif

#if HAVE_PTHREAD
    pthread_mutex_init(&mutex, NULL);
#endif

    /* no point in asking for more threads than there are batches of 16 */
    wanted = FLINT_MIN(flint_get_num_threads(), (full + 15)/16);
    num_threads = flint_request_threads(&threads, wanted);

    /* one job per pool thread plus one for the calling thread */
    args = flint_malloc(sizeof(split_limbs_arg_t)*(num_threads + 1));

    job.i = &next_index;
    job.num = full;
    job.coeff_limbs = coeff_limbs;
    job.output_limbs = output_limbs;
    job.limbs = limbs;
    job.poly = poly;
#if HAVE_PTHREAD
    job.mutex = &mutex;
#endif

    for (t = 0; t <= num_threads; t++)
        args[t] = job;

    for (t = 0; t < num_threads; t++)
        thread_pool_wake(global_thread_pool, threads[t], 0,
                         _split_limbs_worker, &args[t]);

    /* the calling thread takes part in the work using the last job slot */
    _split_limbs_worker(&args[num_threads]);

    for (t = 0; t < num_threads; t++)
        thread_pool_wait(global_thread_pool, threads[t]);

    flint_give_back_threads(threads, num_threads);

    flint_free(args);

#if HAVE_PTHREAD
    pthread_mutex_destroy(&mutex);
#endif

    /* trailing partial coefficient, made of the limbs left over */
    tail_start = full*coeff_limbs;

    if (full < length)
        flint_mpn_zero(poly[full], output_limbs + 1);

    if (total_limbs > tail_start)
        flint_mpn_copyi(poly[full], limbs + tail_start,
                        total_limbs - tail_start);

    return length;
}
130
131 typedef struct
132 {
133 volatile mp_size_t * i;
134 slong length;
135 mp_size_t coeff_limbs;
136 mp_size_t output_limbs;
137 mp_srcptr limbs;
138 flint_bitcnt_t top_bits;
139 mp_limb_t mask;
140 mp_limb_t ** poly;
141 #if HAVE_PTHREAD
142 pthread_mutex_t * mutex;
143 #endif
144 }
145 split_bits_arg_t;
146
147 void
_split_bits_worker(void * arg_ptr)148 _split_bits_worker(void * arg_ptr)
149 {
150 split_bits_arg_t arg = *((split_bits_arg_t *) arg_ptr);
151 slong length = arg.length;
152 mp_size_t coeff_limbs = arg.coeff_limbs;
153 mp_size_t output_limbs = arg.output_limbs;
154 mp_srcptr limbs = arg.limbs;
155 flint_bitcnt_t top_bits = arg.top_bits;
156 mp_limb_t mask = arg.mask;
157 mp_limb_t ** poly = arg.poly;
158 flint_bitcnt_t shift_bits;
159 mp_srcptr limb_ptr;
160 mp_size_t i, end;
161
162 while (1)
163 {
164 #if HAVE_PTHREAD
165 pthread_mutex_lock(arg.mutex);
166 #endif
167 i = *arg.i;
168 end = *arg.i = FLINT_MIN(i + 16, length - 1);
169 #if HAVE_PTHREAD
170 pthread_mutex_unlock(arg.mutex);
171 #endif
172
173 if (i >= length - 1)
174 return;
175
176 for ( ; i < end; i++)
177 {
178 flint_mpn_zero(poly[i], output_limbs + 1);
179
180 limb_ptr = limbs + i*(coeff_limbs - 1) + (i*top_bits)/FLINT_BITS;
181 shift_bits = (i*top_bits) % FLINT_BITS;
182
183 if (!shift_bits)
184 {
185 flint_mpn_copyi(poly[i], limb_ptr, coeff_limbs);
186 poly[i][coeff_limbs - 1] &= mask;
187 limb_ptr += (coeff_limbs - 1);
188 shift_bits += top_bits;
189 } else
190 {
191 mpn_rshift(poly[i], limb_ptr, coeff_limbs, shift_bits);
192 limb_ptr += (coeff_limbs - 1);
193 shift_bits += top_bits;
194
195 if (shift_bits >= FLINT_BITS)
196 {
197 limb_ptr++;
198 poly[i][coeff_limbs - 1] +=
199 (limb_ptr[0] << (FLINT_BITS - (shift_bits - top_bits)));
200 shift_bits -= FLINT_BITS;
201 }
202
203 poly[i][coeff_limbs - 1] &= mask;
204 }
205 }
206 }
207 }
208
/*
   Split the total_limbs limbs at `limbs` into coefficients of `bits` bits
   each, writing coefficient k (which starts at bit k*bits of the input) to
   poly[k].

   If bits is a multiple of FLINT_BITS the work is delegated to
   fft_split_limbs. Otherwise all coefficients except the last are extracted
   by _split_bits_worker, run on this thread and on any threads obtained from
   the thread pool. The last coefficient, made of whatever limbs remain, is
   then extracted here. Every coefficient written is first zeroed over
   output_limbs + 1 limbs.

   Returns (FLINT_BITS*total_limbs - 1)/bits + 1, the number of coefficients.
*/
mp_size_t fft_split_bits(mp_limb_t ** poly, mp_srcptr limbs,
        mp_size_t total_limbs, flint_bitcnt_t bits, mp_size_t output_limbs)
{
    mp_size_t i, shared_i = 0, coeff_limbs, limbs_left;
    mp_size_t length = (FLINT_BITS*total_limbs - 1)/bits + 1;
    flint_bitcnt_t shift_bits, top_bits = ((FLINT_BITS - 1) & bits);
    mp_srcptr limb_ptr;
    mp_limb_t mask;
#if HAVE_PTHREAD
    pthread_mutex_t mutex;
#endif
    slong num_threads;
    thread_pool_handle * threads;
    split_bits_arg_t * args;

    /* whole-limb coefficients need no shifting or masking */
    if (top_bits == 0)
        return fft_split_limbs(poly, limbs, total_limbs, bits/FLINT_BITS, output_limbs);

    coeff_limbs = (bits/FLINT_BITS) + 1;

    /*
       Low top_bits bits set. Computed in the unsigned limb type: shifting a
       signed 1 left by FLINT_BITS - 1 (possible here, since top_bits ranges
       over 1 .. FLINT_BITS - 1) would be undefined behaviour.
    */
    mask = ((mp_limb_t) 1 << top_bits) - 1;

#if HAVE_PTHREAD
    pthread_mutex_init(&mutex, NULL);
#endif

    /* no point in asking for more threads than there are batches of 16 */
    num_threads = flint_request_threads(&threads,
            FLINT_MIN(flint_get_num_threads(), (length - 1 + 15)/16));

    /* one job per pool thread plus one for the calling thread */
    args = (split_bits_arg_t *)
            flint_malloc(sizeof(split_bits_arg_t)*(num_threads + 1));

    for (i = 0; i < num_threads + 1; i++)
    {
        args[i].i = &shared_i;
        args[i].length = length;
        args[i].coeff_limbs = coeff_limbs;
        args[i].output_limbs = output_limbs;
        args[i].limbs = limbs;
        args[i].top_bits = top_bits;
        args[i].mask = mask;
        args[i].poly = poly;
#if HAVE_PTHREAD
        args[i].mutex = &mutex;
#endif
    }

    for (i = 0; i < num_threads; i++)
        thread_pool_wake(global_thread_pool, threads[i], 0,
                         _split_bits_worker, &args[i]);

    /* the calling thread takes part in the work using the last job slot */
    _split_bits_worker(&args[num_threads]);

    for (i = 0; i < num_threads; i++)
        thread_pool_wait(global_thread_pool, threads[i]);

    flint_give_back_threads(threads, num_threads);

    flint_free(args);

#if HAVE_PTHREAD
    pthread_mutex_destroy(&mutex);
#endif

    /* last coefficient: starts at bit i*bits and takes all remaining limbs */
    i = length - 1;
    limb_ptr = limbs + i*(coeff_limbs - 1) + (i*top_bits)/FLINT_BITS;
    shift_bits = (i*top_bits) % FLINT_BITS;

    flint_mpn_zero(poly[i], output_limbs + 1);

    limbs_left = total_limbs - (limb_ptr - limbs);

    if (!shift_bits)
        flint_mpn_copyi(poly[i], limb_ptr, limbs_left);
    else
        mpn_rshift(poly[i], limb_ptr, limbs_left, shift_bits);

    return length;
}
289
290