1 /*
2  * Copyright (c) 1997-1999, 2003 Massachusetts Institute of Technology
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation; either version 2 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software
16  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17  *
18  */
19 
20 #include <stdio.h>
21 #include <math.h>
22 
23 #include "fftw_mpi.h"
24 #include "fftw-int.h"
25 
26 /************************** Twiddle Factors *****************************/
27 
28 /* To conserve space, we share twiddle factor arrays between forward and
29    backward plans and plans of the same size (just as in the uniprocessor
30    transforms). */
31 
32 static fftw_mpi_twiddle *fftw_mpi_twiddles = NULL;
33 
fftw_mpi_create_twiddle(int rows,int rowstart,int cols,int n)34 static fftw_mpi_twiddle *fftw_mpi_create_twiddle(int rows, int rowstart,
35 						 int cols, int n)
36 {
37      fftw_mpi_twiddle *tw = fftw_mpi_twiddles;
38 
39      while (tw && (tw->rows != rows || tw->rowstart != rowstart ||
40 		   tw->cols != cols || tw->n != n))
41 	  tw = tw->next;
42 
43      if (tw) {
44 	  tw->refcount++;
45 	  return tw;
46      }
47 
48      tw = (fftw_mpi_twiddle *) fftw_malloc(sizeof(fftw_mpi_twiddle));
49      tw->rows = rows;
50      tw->rowstart = rowstart;
51      tw->cols = cols;
52      tw->n = n;
53      tw->refcount = 1;
54      tw->next = fftw_mpi_twiddles;
55 
56      {
57 	  fftw_complex *W = (fftw_complex *) fftw_malloc(sizeof(fftw_complex) *
58 							 rows * (cols - 1));
59 	  int j, i;
60 	  FFTW_TRIG_REAL twoPiOverN = FFTW_K2PI / (FFTW_TRIG_REAL) n;
61 
62 	  for (j = 0; j < rows; ++j)
63 	       for (i = 1; i < cols; ++i) {
64 		    int k = (j * (cols - 1) - 1) + i;
65 		    FFTW_TRIG_REAL
66 			 ij = (FFTW_TRIG_REAL) (i * (j + rowstart));
67 		    c_re(W[k]) = FFTW_TRIG_COS(twoPiOverN * ij);
68 		    c_im(W[k]) = FFTW_FORWARD * FFTW_TRIG_SIN(twoPiOverN * ij);
69 	       }
70 
71 	  tw->W = W;
72      }
73 
74      fftw_mpi_twiddles = tw;
75 
76      return tw;
77 }
78 
/* Release one reference to tw; a null tw is ignored.

   When the reference count reaches zero the entry is unlinked from the
   fftw_mpi_twiddles list and both its factor array (tw->W) and the entry
   itself are freed.  fftw_mpi_die is called if the list is empty or does
   not contain tw. */
static void fftw_mpi_destroy_twiddle(fftw_mpi_twiddle *tw)
{
     fftw_mpi_twiddle *before;

     if (!tw)
          return;

     tw->refcount--;
     if (tw->refcount != 0)
          return;               /* other references remain */

     /* Unlink tw from the list. */
     if (tw == fftw_mpi_twiddles) {
          fftw_mpi_twiddles = tw->next;
     }
     else {
          before = fftw_mpi_twiddles;

          if (!before)
               fftw_mpi_die("unexpected empty MPI twiddle list");

          /* Advance until `before` is the predecessor of tw, or the
             last node if tw is not on the list. */
          for (; before->next && before->next != tw; before = before->next)
               ;

          if (before->next != tw)
               fftw_mpi_die("tried to destroy unknown MPI twiddle");
          before->next = tw->next;
     }

     fftw_free(tw->W);
     fftw_free(tw);
}
104 
105 /* multiply the array in d (of size tw->cols * n_fields) by the row cur_row
106    of the twiddle factors pointed to by tw, given the transform direction. */
fftw_mpi_mult_twiddles(fftw_complex * d,int n_fields,int cur_row,fftw_mpi_twiddle * tw,fftw_direction dir)107 static void fftw_mpi_mult_twiddles(fftw_complex *d, int n_fields,
108 				   int cur_row,
109 				   fftw_mpi_twiddle *tw,
110 				   fftw_direction dir)
111 {
112      int cols = tw->cols;
113      fftw_complex *W = tw->W + cur_row * (cols - 1);
114      int j;
115 
116      if (dir == FFTW_FORWARD) {
117           if (n_fields > 1)
118                for (j = 1; j < cols; ++j) {
119                     fftw_real
120                          w_re = c_re(W[j-1]),
121                          w_im = c_im(W[j-1]);
122                     int f;
123 
124                     for (f = 0; f < n_fields; ++f) {
125                          fftw_real
126                               d_re = c_re(d[j*n_fields + f]),
127                               d_im = c_im(d[j*n_fields + f]);
128                          c_re(d[j*n_fields + f]) = w_re * d_re - w_im * d_im;
129                          c_im(d[j*n_fields + f]) = w_re * d_im + w_im * d_re;
130                     }
131                }
132           else
133                for (j = 1; j < cols; ++j) {
134                     fftw_real w_re = c_re(W[j-1]),
135                          w_im = c_im(W[j-1]),
136                          d_re = c_re(d[j]),
137                          d_im = c_im(d[j]);
138                     c_re(d[j]) = w_re * d_re - w_im * d_im;
139                     c_im(d[j]) = w_re * d_im + w_im * d_re;
140                }
141      }
142      else {  /* FFTW_BACKWARDS */
143 	  /* same as above, except that W is complex-conjugated: */
144           if (n_fields > 1)
145                for (j = 1; j < cols; ++j) {
146                     fftw_real
147                          w_re = c_re(W[j-1]),
148                          w_im = c_im(W[j-1]);
149                     int f;
150 
151                     for (f = 0; f < n_fields; ++f) {
152                          fftw_real
153                               d_re = c_re(d[j*n_fields + f]),
154                               d_im = c_im(d[j*n_fields + f]);
155                          c_re(d[j*n_fields + f]) = w_re * d_re + w_im * d_im;
156                          c_im(d[j*n_fields + f]) = w_re * d_im - w_im * d_re;
157                     }
158                }
159           else
160                for (j = 1; j < cols; ++j) {
161                     fftw_real w_re = c_re(W[j-1]),
162                          w_im = c_im(W[j-1]),
163                          d_re = c_re(d[j]),
164                          d_im = c_im(d[j]);
165                     c_re(d[j]) = w_re * d_re + w_im * d_im;
166                     c_im(d[j]) = w_re * d_im - w_im * d_re;
167                }
168      }
169 }
170 
171 /***************************** Plan Creation ****************************/
172 
/* Return the factor of n closest to sqrt(n); returns 1 when n <= 1.

   Divisors of n come in pairs (d, n/d) with d <= sqrt(n) <= n/d, and by
   the arithmetic/geometric mean inequality d + n/d >= 2 sqrt(n), i.e. the
   smaller member of a pair is never farther from sqrt(n) than the larger
   one.  The closest factor is therefore simply the largest divisor d with
   d * d <= n, which is what we search for.

   (An earlier version probed round(sqrt(n)) - 1 before round(sqrt(n)) and
   so returned 1, 1 and 2 for n = 4, 6 and 12 instead of 2, 2 and 3; it
   also fed negative n to sqrt().  Only integer arithmetic is used now.) */
static int find_sqrt_factor(int n)
{
     int d;

     if (n <= 1)
          return 1;

     /* d = floor(sqrt(n)).  The test is written with a division so that
        (d + 1) * (d + 1) is never formed and cannot overflow. */
     for (d = 1; d + 1 <= n / (d + 1); ++d)
          ;

     /* Walk down to the nearest divisor; d == 1 always terminates this. */
     while (n % d != 0)
          --d;

     return d;
}
188 
189 /* find the "best" r to divide n by for the FFT decomposition.  Ideally,
190    we would like both r and n/r to be divisible by the number of
191    processes (for optimum load-balancing).  Also, pick r to be close
192    to sqrt(n) if possible. */
find_best_r(int n,MPI_Comm comm)193 static int find_best_r(int n, MPI_Comm comm)
194 {
195      int n_pes;
196 
197      MPI_Comm_size(comm, &n_pes);
198 
199      if (n % n_pes == 0) {
200 	  n /= n_pes;
201 	  if (n % n_pes == 0)
202 	       return (n_pes * find_sqrt_factor(n / n_pes));
203 	  else
204 	       return (n_pes * find_sqrt_factor(n));
205      }
206      else
207 	  return find_sqrt_factor(n);
208 }
209 
210 #define MAX2(a,b) ((a) > (b) ? (a) : (b))
211 
fftw_mpi_create_plan(MPI_Comm comm,int n,fftw_direction dir,int flags)212 fftw_mpi_plan fftw_mpi_create_plan(MPI_Comm comm,
213 				   int n, fftw_direction dir, int flags)
214 {
215      fftw_mpi_plan p;
216      int i, r, m;
217 
218      p = (fftw_mpi_plan) fftw_malloc(sizeof(struct fftw_mpi_plan_struct));
219 
220      i = find_best_r(n, comm);
221      if (dir == FFTW_FORWARD)
222 	  m = n / (r = i);
223      else
224 	  r = n / (m = i);
225 
226      p->n = n;
227      p->r = r;
228      p->m = m;
229 
230      flags |= FFTW_IN_PLACE;
231      p->flags = flags;
232      p->dir = dir;
233 
234      p->pr = fftw_create_plan(r, dir, flags);
235      p->pm = fftw_create_plan(m, dir, flags);
236 
237      p->p_transpose = transpose_mpi_create_plan(m, r, comm);
238      p->p_transpose_inv = transpose_mpi_create_plan(r, m, comm);
239 
240      transpose_mpi_get_local_size(r,
241 				  p->p_transpose_inv->my_pe,
242 				  p->p_transpose_inv->n_pes,
243 				  &p->local_r,
244 				  &p->local_r_start);
245      transpose_mpi_get_local_size(m,
246 				  p->p_transpose->my_pe,
247 				  p->p_transpose->n_pes,
248 				  &p->local_m,
249 				  &p->local_m_start);
250 
251      if (dir == FFTW_FORWARD)
252 	  p->tw = fftw_mpi_create_twiddle(p->local_r, p->local_r_start, m, n);
253      else
254 	  p->tw = fftw_mpi_create_twiddle(p->local_m, p->local_m_start, r, n);
255 
256      p->fft_work = (fftw_complex *) fftw_malloc(sizeof(fftw_complex) *
257 						MAX2(m, r));
258 
259      return p;
260 }
261 
262 /********************* Getting Local Size ***********************/
263 
/* Report which part of the data this process holds for plan p.  Nothing
   is written when p is null.

     *local_n, *local_start
          element count and starting offset of the local input:
          local_r * m and local_r_start * m when the plan has
          FFTW_SCRAMBLED_INPUT, otherwise local_m * r and
          local_m_start * r.
     *local_n_after_transform, *local_start_after_transform
          the same for the output: local_m * r and local_m_start * r when
          the plan has FFTW_SCRAMBLED_OUTPUT, otherwise local_r * m and
          local_r_start * m.
     *total_local_size
          value returned by transpose_mpi_get_local_storage_size for the
          plan's transpose. */
void fftw_mpi_local_sizes(fftw_mpi_plan p,
                          int *local_n,
                          int *local_start,
                          int *local_n_after_transform,
                          int *local_start_after_transform,
                          int *total_local_size)
{
     int n_by_r, start_by_r;    /* block of the local_r x m layout */
     int n_by_m, start_by_m;    /* block of the local_m x r layout */
     int in_by_r, out_by_m;

     if (!p)
          return;

     n_by_r = p->local_r * p->m;
     start_by_r = p->local_r_start * p->m;
     n_by_m = p->local_m * p->r;
     start_by_m = p->local_m_start * p->r;

     in_by_r = (p->flags & FFTW_SCRAMBLED_INPUT) != 0;
     out_by_m = (p->flags & FFTW_SCRAMBLED_OUTPUT) != 0;

     *local_n = in_by_r ? n_by_r : n_by_m;
     *local_start = in_by_r ? start_by_r : start_by_m;

     *local_n_after_transform = out_by_m ? n_by_m : n_by_r;
     *local_start_after_transform = out_by_m ? start_by_m : start_by_r;

     *total_local_size =
          transpose_mpi_get_local_storage_size(p->p_transpose->nx,
                                               p->p_transpose->ny,
                                               p->p_transpose->my_pe,
                                               p->p_transpose->n_pes);
}
297 
/* Write a description of the MPI plan p to stream f: a header line,
   then the size-m uniprocessor plan followed by the size-r one, each
   printed with fftw_fprint_plan.  p must not be null. */
static void fftw_mpi_fprint_plan(FILE *f, fftw_mpi_plan p)
{
     fputs("mpi plan:\n", f);

     fprintf(f, "m = %d plan:\n", p->m);
     fftw_fprint_plan(f, p->pm);

     fprintf(f, "r = %d plan:\n", p->r);
     fftw_fprint_plan(f, p->pr);
}
306 
/* Print a description of the MPI plan p on standard output (see
   fftw_mpi_fprint_plan for what is printed). */
void fftw_mpi_print_plan(fftw_mpi_plan p)
{
     FILE *out = stdout;

     fftw_mpi_fprint_plan(out, p);
}
311 
312 /********************** Plan Destruction ************************/
313 
/* Destroy a plan made by fftw_mpi_create_plan; a null p is ignored.
   Releases the two uniprocessor plans, the two transpose plans, the
   plan's reference to its twiddle table, the scratch array and finally
   the plan structure itself. */
void fftw_mpi_destroy_plan(fftw_mpi_plan p)
{
     if (!p)
          return;

     /* local FFT plans */
     fftw_destroy_plan(p->pr);
     fftw_destroy_plan(p->pm);

     /* distributed transposes */
     transpose_mpi_destroy_plan(p->p_transpose);
     transpose_mpi_destroy_plan(p->p_transpose_inv);

     /* the twiddle table is shared: this only drops our reference */
     fftw_mpi_destroy_twiddle(p->tw);

     fftw_free(p->fft_work);
     fftw_free(p);
}
326 
327 /******************** Computing the Transform *******************/
328 
fftw_mpi(fftw_mpi_plan p,int n_fields,fftw_complex * local_data,fftw_complex * work)329 void fftw_mpi(fftw_mpi_plan p, int n_fields,
330 	      fftw_complex *local_data, fftw_complex *work)
331 {
332      int i;
333      int el_size = (sizeof(fftw_complex) / sizeof(TRANSPOSE_EL_TYPE))
334                    * n_fields;
335      fftw_complex *fft_work;
336      fftw_direction dir;
337      fftw_mpi_twiddle *tw;
338 
339      if (n_fields < 1)
340 	  return;
341 
342      if (!(p->flags & FFTW_SCRAMBLED_INPUT))
343 	  transpose_mpi(p->p_transpose, el_size,
344 			(TRANSPOSE_EL_TYPE *) local_data,
345 			(TRANSPOSE_EL_TYPE *) work);
346 
347      tw = p->tw;
348      dir = p->dir;
349      fft_work = work ? work : p->fft_work;
350 
351      /* For forward plans, we multiply by the twiddle factors here,
352 	before the second transpose.  For backward plans, we multiply
353 	by the twiddle factors after the second transpose.  We do
354 	this so that forward and backward transforms can share the
355 	same twiddle factor array (noting that m and r are swapped
356 	for the two directions so that the local sizes will be compatible). */
357 
358      {
359 	  int rows = p->local_r, cols = p->m;
360 	  fftw_plan p_fft = p->pm;
361 
362 	  if (dir == FFTW_FORWARD) {
363 	       for (i = 0; i < rows; ++i) {
364 		    fftw_complex *d = local_data + i * (cols * n_fields);
365 
366 		    fftw(p_fft, n_fields, d, n_fields, 1, fft_work, 1, 0);
367 		    fftw_mpi_mult_twiddles(d, n_fields, i, tw, FFTW_FORWARD);
368 	       }
369 	  }
370 	  else {
371 	       if (n_fields > 1)
372 		    for (i = 0; i < rows; ++i)
373 			 fftw(p_fft, n_fields, local_data + i*(cols*n_fields),
374 			      n_fields, 1, fft_work, 1, 0);
375 	       else
376 		    fftw(p_fft, rows, local_data, 1, cols, fft_work, 1, 0);
377 	  }
378      }
379 
380      transpose_mpi(p->p_transpose_inv, el_size,
381 		   (TRANSPOSE_EL_TYPE *) local_data,
382 		   (TRANSPOSE_EL_TYPE *) work);
383 
384      {
385 	  int rows = p->local_m, cols = p->r;
386 	  fftw_plan p_fft = p->pr;
387 
388 	  if (dir == FFTW_BACKWARD) {
389 	       for (i = 0; i < rows; ++i) {
390 		    fftw_complex *d = local_data + i * (cols * n_fields);
391 
392 		    fftw_mpi_mult_twiddles(d, n_fields, i, tw, FFTW_BACKWARD);
393 		    fftw(p_fft, n_fields, d, n_fields, 1, fft_work, 1, 0);
394 	       }
395 	  }
396 	  else {
397 	       if (n_fields > 1)
398 		    for (i = 0; i < rows; ++i)
399 			 fftw(p_fft, n_fields, local_data + i*(cols*n_fields),
400 			      n_fields, 1, fft_work, 1, 0);
401 	       else
402 		    fftw(p_fft, rows, local_data, 1, cols, fft_work, 1, 0);
403 	  }
404      }
405 
406      if (!(p->flags & FFTW_SCRAMBLED_OUTPUT))
407 	  transpose_mpi(p->p_transpose, el_size,
408 			(TRANSPOSE_EL_TYPE *) local_data,
409 			(TRANSPOSE_EL_TYPE *) work);
410 
411      /* Yes, we really had to do three transposes...sigh. */
412 }
413 
414 
415