1 /*
2 * Copyright (c) 1997-1999, 2003 Massachusetts Institute of Technology
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation; either version 2 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, write to the Free Software
16 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 *
18 */
19
20 #include <stdio.h>
21 #include <math.h>
22
23 #include "fftw_mpi.h"
24 #include "fftw-int.h"
25
26 /************************** Twiddle Factors *****************************/
27
28 /* To conserve space, we share twiddle factor arrays between forward and
29 backward plans and plans of the same size (just as in the uniprocessor
30 transforms). */
31
32 static fftw_mpi_twiddle *fftw_mpi_twiddles = NULL;
33
fftw_mpi_create_twiddle(int rows,int rowstart,int cols,int n)34 static fftw_mpi_twiddle *fftw_mpi_create_twiddle(int rows, int rowstart,
35 int cols, int n)
36 {
37 fftw_mpi_twiddle *tw = fftw_mpi_twiddles;
38
39 while (tw && (tw->rows != rows || tw->rowstart != rowstart ||
40 tw->cols != cols || tw->n != n))
41 tw = tw->next;
42
43 if (tw) {
44 tw->refcount++;
45 return tw;
46 }
47
48 tw = (fftw_mpi_twiddle *) fftw_malloc(sizeof(fftw_mpi_twiddle));
49 tw->rows = rows;
50 tw->rowstart = rowstart;
51 tw->cols = cols;
52 tw->n = n;
53 tw->refcount = 1;
54 tw->next = fftw_mpi_twiddles;
55
56 {
57 fftw_complex *W = (fftw_complex *) fftw_malloc(sizeof(fftw_complex) *
58 rows * (cols - 1));
59 int j, i;
60 FFTW_TRIG_REAL twoPiOverN = FFTW_K2PI / (FFTW_TRIG_REAL) n;
61
62 for (j = 0; j < rows; ++j)
63 for (i = 1; i < cols; ++i) {
64 int k = (j * (cols - 1) - 1) + i;
65 FFTW_TRIG_REAL
66 ij = (FFTW_TRIG_REAL) (i * (j + rowstart));
67 c_re(W[k]) = FFTW_TRIG_COS(twoPiOverN * ij);
68 c_im(W[k]) = FFTW_FORWARD * FFTW_TRIG_SIN(twoPiOverN * ij);
69 }
70
71 tw->W = W;
72 }
73
74 fftw_mpi_twiddles = tw;
75
76 return tw;
77 }
78
/* Drop one reference to the shared twiddle table tw.  NULL is ignored.
   When the last reference goes away, tw is unlinked from the
   fftw_mpi_twiddles list, and its factor array and descriptor are freed.
   fftw_mpi_die is called if the list is empty or does not contain tw. */
static void fftw_mpi_destroy_twiddle(fftw_mpi_twiddle *tw)
{
     fftw_mpi_twiddle *cur;

     if (!tw)
          return;
     if (--tw->refcount != 0)
          return;               /* still used by another plan */

     if (fftw_mpi_twiddles == tw) {
          fftw_mpi_twiddles = tw->next;
     } else {
          cur = fftw_mpi_twiddles;
          if (cur == NULL)
               fftw_mpi_die("unexpected empty MPI twiddle list");
          /* find the node whose successor is tw */
          for (; cur->next != NULL && cur->next != tw; cur = cur->next)
               ;
          if (cur->next != tw)
               fftw_mpi_die("tried to destroy unknown MPI twiddle");
          cur->next = tw->next;
     }

     fftw_free(tw->W);
     fftw_free(tw);
}
104
105 /* multiply the array in d (of size tw->cols * n_fields) by the row cur_row
106 of the twiddle factors pointed to by tw, given the transform direction. */
fftw_mpi_mult_twiddles(fftw_complex * d,int n_fields,int cur_row,fftw_mpi_twiddle * tw,fftw_direction dir)107 static void fftw_mpi_mult_twiddles(fftw_complex *d, int n_fields,
108 int cur_row,
109 fftw_mpi_twiddle *tw,
110 fftw_direction dir)
111 {
112 int cols = tw->cols;
113 fftw_complex *W = tw->W + cur_row * (cols - 1);
114 int j;
115
116 if (dir == FFTW_FORWARD) {
117 if (n_fields > 1)
118 for (j = 1; j < cols; ++j) {
119 fftw_real
120 w_re = c_re(W[j-1]),
121 w_im = c_im(W[j-1]);
122 int f;
123
124 for (f = 0; f < n_fields; ++f) {
125 fftw_real
126 d_re = c_re(d[j*n_fields + f]),
127 d_im = c_im(d[j*n_fields + f]);
128 c_re(d[j*n_fields + f]) = w_re * d_re - w_im * d_im;
129 c_im(d[j*n_fields + f]) = w_re * d_im + w_im * d_re;
130 }
131 }
132 else
133 for (j = 1; j < cols; ++j) {
134 fftw_real w_re = c_re(W[j-1]),
135 w_im = c_im(W[j-1]),
136 d_re = c_re(d[j]),
137 d_im = c_im(d[j]);
138 c_re(d[j]) = w_re * d_re - w_im * d_im;
139 c_im(d[j]) = w_re * d_im + w_im * d_re;
140 }
141 }
142 else { /* FFTW_BACKWARDS */
143 /* same as above, except that W is complex-conjugated: */
144 if (n_fields > 1)
145 for (j = 1; j < cols; ++j) {
146 fftw_real
147 w_re = c_re(W[j-1]),
148 w_im = c_im(W[j-1]);
149 int f;
150
151 for (f = 0; f < n_fields; ++f) {
152 fftw_real
153 d_re = c_re(d[j*n_fields + f]),
154 d_im = c_im(d[j*n_fields + f]);
155 c_re(d[j*n_fields + f]) = w_re * d_re + w_im * d_im;
156 c_im(d[j*n_fields + f]) = w_re * d_im - w_im * d_re;
157 }
158 }
159 else
160 for (j = 1; j < cols; ++j) {
161 fftw_real w_re = c_re(W[j-1]),
162 w_im = c_im(W[j-1]),
163 d_re = c_re(d[j]),
164 d_im = c_im(d[j]);
165 c_re(d[j]) = w_re * d_re + w_im * d_im;
166 c_im(d[j]) = w_re * d_im - w_im * d_re;
167 }
168 }
169 }
170
171 /***************************** Plan Creation ****************************/
172
/* Return the factor of n closest to sqrt(n); 1 when n <= 1.

   The search starts at hi = round(sqrt(n)) and lo = hi - 1 and moves
   outward one step at a time.  At each step hi is tested BEFORE lo,
   because hi starts within 1/2 of sqrt(n) and therefore is never farther
   from sqrt(n) than lo.  If lo were tested first, n = 4 and n = 6 would
   give 1, and n = 12 would give 2 instead of 3.  lo eventually reaches 1,
   which always divides n, so the loop ends with a divisor.  The only
   exception is n == 2: there round(sqrt(2)) == 1 and no lo is tried.

   round(sqrt(n)) is computed exactly in integer arithmetic, with no
   floating-point conversion and no libm call. */
static int find_sqrt_factor(int n)
{
     int hi, lo;

     if (n <= 1)
          return 1;             /* also avoids sqrt of a negative n */

     /* hi = floor(sqrt(n)); "hi + 1 <= n / (hi + 1)" is (hi+1)^2 <= n
        without any risk of overflow */
     for (hi = 1; hi + 1 <= n / (hi + 1); ++hi)
          ;
     /* round to nearest: sqrt(n) > hi + 1/2  <=>  n > hi*hi + hi */
     if (n - hi * hi > hi)
          ++hi;

     for (lo = hi - 1; lo > 0; ++hi, --lo) {
          if (n % hi == 0)
               return hi;
          if (n % lo == 0)
               return lo;
     }
     return 1;                  /* n == 2 */
}
188
189 /* find the "best" r to divide n by for the FFT decomposition. Ideally,
190 we would like both r and n/r to be divisible by the number of
191 processes (for optimum load-balancing). Also, pick r to be close
192 to sqrt(n) if possible. */
find_best_r(int n,MPI_Comm comm)193 static int find_best_r(int n, MPI_Comm comm)
194 {
195 int n_pes;
196
197 MPI_Comm_size(comm, &n_pes);
198
199 if (n % n_pes == 0) {
200 n /= n_pes;
201 if (n % n_pes == 0)
202 return (n_pes * find_sqrt_factor(n / n_pes));
203 else
204 return (n_pes * find_sqrt_factor(n));
205 }
206 else
207 return find_sqrt_factor(n);
208 }
209
/* Larger of two values; the second is returned when they compare equal.
   The chosen argument is evaluated twice, so do not pass expressions with
   side effects. */
#define MAX2(a, b) ((b) < (a) ? (a) : (b))
211
fftw_mpi_create_plan(MPI_Comm comm,int n,fftw_direction dir,int flags)212 fftw_mpi_plan fftw_mpi_create_plan(MPI_Comm comm,
213 int n, fftw_direction dir, int flags)
214 {
215 fftw_mpi_plan p;
216 int i, r, m;
217
218 p = (fftw_mpi_plan) fftw_malloc(sizeof(struct fftw_mpi_plan_struct));
219
220 i = find_best_r(n, comm);
221 if (dir == FFTW_FORWARD)
222 m = n / (r = i);
223 else
224 r = n / (m = i);
225
226 p->n = n;
227 p->r = r;
228 p->m = m;
229
230 flags |= FFTW_IN_PLACE;
231 p->flags = flags;
232 p->dir = dir;
233
234 p->pr = fftw_create_plan(r, dir, flags);
235 p->pm = fftw_create_plan(m, dir, flags);
236
237 p->p_transpose = transpose_mpi_create_plan(m, r, comm);
238 p->p_transpose_inv = transpose_mpi_create_plan(r, m, comm);
239
240 transpose_mpi_get_local_size(r,
241 p->p_transpose_inv->my_pe,
242 p->p_transpose_inv->n_pes,
243 &p->local_r,
244 &p->local_r_start);
245 transpose_mpi_get_local_size(m,
246 p->p_transpose->my_pe,
247 p->p_transpose->n_pes,
248 &p->local_m,
249 &p->local_m_start);
250
251 if (dir == FFTW_FORWARD)
252 p->tw = fftw_mpi_create_twiddle(p->local_r, p->local_r_start, m, n);
253 else
254 p->tw = fftw_mpi_create_twiddle(p->local_m, p->local_m_start, r, n);
255
256 p->fft_work = (fftw_complex *) fftw_malloc(sizeof(fftw_complex) *
257 MAX2(m, r));
258
259 return p;
260 }
261
262 /********************* Getting Local Size ***********************/
263
/* Report how plan p lays out data on this process.
   *local_n and *local_start give the number of elements held locally and
   the offset of the first one, for the input layout.  Without
   FFTW_SCRAMBLED_INPUT this is local_m * r elements starting at
   local_m_start * r; with it, local_r * m starting at local_r_start * m.
   *local_n_after_transform and *local_start_after_transform give the same
   for the output layout.  Without FFTW_SCRAMBLED_OUTPUT this is local_r * m
   starting at local_r_start * m; with it, local_m * r starting at
   local_m_start * r.
   *total_local_size is the value transpose_mpi_get_local_storage_size
   returns for the m x r transpose plan.  (Presumably this is the required
   length of local_data / work; confirm against the transpose docs.)
   Nothing is written when p is NULL. */
void fftw_mpi_local_sizes(fftw_mpi_plan p,
                          int *local_n,
                          int *local_start,
                          int *local_n_after_transform,
                          int *local_start_after_transform,
                          int *total_local_size)
{
     int in_scrambled, out_scrambled;

     if (!p)
          return;

     in_scrambled = (p->flags & FFTW_SCRAMBLED_INPUT) != 0;
     out_scrambled = (p->flags & FFTW_SCRAMBLED_OUTPUT) != 0;

     *local_n = in_scrambled ? p->local_r * p->m
                             : p->local_m * p->r;
     *local_start = in_scrambled ? p->local_r_start * p->m
                                 : p->local_m_start * p->r;

     *local_n_after_transform = out_scrambled ? p->local_m * p->r
                                              : p->local_r * p->m;
     *local_start_after_transform = out_scrambled ? p->local_m_start * p->r
                                                  : p->local_r_start * p->m;

     *total_local_size =
          transpose_mpi_get_local_storage_size(p->p_transpose->nx,
                                               p->p_transpose->ny,
                                               p->p_transpose->my_pe,
                                               p->p_transpose->n_pes);
}
297
/* Write a description of plan p to f: a header line, then m and its
   uniprocessor sub-plan, then r and its sub-plan. */
static void fftw_mpi_fprint_plan(FILE *f, fftw_mpi_plan p)
{
     fputs("mpi plan:\n", f);

     fprintf(f, "m = %d plan:\n", p->m);
     fftw_fprint_plan(f, p->pm);

     fprintf(f, "r = %d plan:\n", p->r);
     fftw_fprint_plan(f, p->pr);
}

/* Print a description of plan p to standard output. */
void fftw_mpi_print_plan(fftw_mpi_plan p)
{
     FILE *out = stdout;

     fftw_mpi_fprint_plan(out, p);
}
311
312 /********************** Plan Destruction ************************/
313
/* Free plan p and everything it owns: both uniprocessor sub-plans, both
   transpose plans, and the scratch buffer.  Its twiddle table is only
   released by reference, because other plans may share it.  NULL is
   ignored. */
void fftw_mpi_destroy_plan(fftw_mpi_plan p)
{
     if (p == NULL)
          return;

     fftw_destroy_plan(p->pr);
     fftw_destroy_plan(p->pm);

     transpose_mpi_destroy_plan(p->p_transpose);
     transpose_mpi_destroy_plan(p->p_transpose_inv);

     fftw_mpi_destroy_twiddle(p->tw);

     fftw_free(p->fft_work);
     fftw_free(p);
}
326
327 /******************** Computing the Transform *******************/
328
fftw_mpi(fftw_mpi_plan p,int n_fields,fftw_complex * local_data,fftw_complex * work)329 void fftw_mpi(fftw_mpi_plan p, int n_fields,
330 fftw_complex *local_data, fftw_complex *work)
331 {
332 int i;
333 int el_size = (sizeof(fftw_complex) / sizeof(TRANSPOSE_EL_TYPE))
334 * n_fields;
335 fftw_complex *fft_work;
336 fftw_direction dir;
337 fftw_mpi_twiddle *tw;
338
339 if (n_fields < 1)
340 return;
341
342 if (!(p->flags & FFTW_SCRAMBLED_INPUT))
343 transpose_mpi(p->p_transpose, el_size,
344 (TRANSPOSE_EL_TYPE *) local_data,
345 (TRANSPOSE_EL_TYPE *) work);
346
347 tw = p->tw;
348 dir = p->dir;
349 fft_work = work ? work : p->fft_work;
350
351 /* For forward plans, we multiply by the twiddle factors here,
352 before the second transpose. For backward plans, we multiply
353 by the twiddle factors after the second transpose. We do
354 this so that forward and backward transforms can share the
355 same twiddle factor array (noting that m and r are swapped
356 for the two directions so that the local sizes will be compatible). */
357
358 {
359 int rows = p->local_r, cols = p->m;
360 fftw_plan p_fft = p->pm;
361
362 if (dir == FFTW_FORWARD) {
363 for (i = 0; i < rows; ++i) {
364 fftw_complex *d = local_data + i * (cols * n_fields);
365
366 fftw(p_fft, n_fields, d, n_fields, 1, fft_work, 1, 0);
367 fftw_mpi_mult_twiddles(d, n_fields, i, tw, FFTW_FORWARD);
368 }
369 }
370 else {
371 if (n_fields > 1)
372 for (i = 0; i < rows; ++i)
373 fftw(p_fft, n_fields, local_data + i*(cols*n_fields),
374 n_fields, 1, fft_work, 1, 0);
375 else
376 fftw(p_fft, rows, local_data, 1, cols, fft_work, 1, 0);
377 }
378 }
379
380 transpose_mpi(p->p_transpose_inv, el_size,
381 (TRANSPOSE_EL_TYPE *) local_data,
382 (TRANSPOSE_EL_TYPE *) work);
383
384 {
385 int rows = p->local_m, cols = p->r;
386 fftw_plan p_fft = p->pr;
387
388 if (dir == FFTW_BACKWARD) {
389 for (i = 0; i < rows; ++i) {
390 fftw_complex *d = local_data + i * (cols * n_fields);
391
392 fftw_mpi_mult_twiddles(d, n_fields, i, tw, FFTW_BACKWARD);
393 fftw(p_fft, n_fields, d, n_fields, 1, fft_work, 1, 0);
394 }
395 }
396 else {
397 if (n_fields > 1)
398 for (i = 0; i < rows; ++i)
399 fftw(p_fft, n_fields, local_data + i*(cols*n_fields),
400 n_fields, 1, fft_work, 1, 0);
401 else
402 fftw(p_fft, rows, local_data, 1, cols, fft_work, 1, 0);
403 }
404 }
405
406 if (!(p->flags & FFTW_SCRAMBLED_OUTPUT))
407 transpose_mpi(p->p_transpose, el_size,
408 (TRANSPOSE_EL_TYPE *) local_data,
409 (TRANSPOSE_EL_TYPE *) work);
410
411 /* Yes, we really had to do three transposes...sigh. */
412 }
413
414
415