/*
 * Copyright (c) 2003, 2007-14 Matteo Frigo
 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 *
 */

#include "api/api.h"
#include "fftw3-mpi.h"
#include "ifftw-mpi.h"
#include "mpi-transpose.h"
#include "mpi-dft.h"
#include "mpi-rdft.h"
#include "mpi-rdft2.h"

/* Convert API flags to internal MPI flags. */
#define MPI_FLAGS(f) ((f) >> 27)
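
/* For illustration (assuming the flag values in fftw3-mpi.h, where the
   MPI-specific flags occupy the high bits of the 32-bit flags word,
   e.g. FFTW_MPI_SCRAMBLED_IN == (1U << 27)):

       MPI_FLAGS(FFTW_MPI_SCRAMBLED_IN)   == 0x1
       MPI_FLAGS(FFTW_MPI_TRANSPOSED_OUT) == 0x8

   i.e. the macro shifts the API flags down to the low-order internal
   representation used by the MPI problem types. */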

/*************************************************************************/

static int mpi_inited = 0;

static MPI_Comm problem_comm(const problem *p) {
     switch (p->adt->problem_kind) {
         case PROBLEM_MPI_DFT:
              return ((const problem_mpi_dft *) p)->comm;
         case PROBLEM_MPI_RDFT:
              return ((const problem_mpi_rdft *) p)->comm;
         case PROBLEM_MPI_RDFT2:
              return ((const problem_mpi_rdft2 *) p)->comm;
         case PROBLEM_MPI_TRANSPOSE:
              return ((const problem_mpi_transpose *) p)->comm;
         default:
              return MPI_COMM_NULL;
     }
}

/* used to synchronize cost measurements (timing or estimation)
   across all processes for an MPI problem, which is critical to
   ensure that all processes decide to use the same MPI plans
   (whereas serial plans need not be synchronized). */
static double cost_hook(const problem *p, double t, cost_kind k)
{
     MPI_Comm comm = problem_comm(p);
     double tsum;
     if (comm == MPI_COMM_NULL) return t;
     MPI_Allreduce(&t, &tsum, 1, MPI_DOUBLE,
                   k == COST_SUM ? MPI_SUM : MPI_MAX, comm);
     return tsum;
}

/* Used to reject wisdom that is not in sync across all processes
   for an MPI problem, which is critical to ensure that all processes
   decide to use the same MPI plans.  (Even though costs are synchronized,
   above, out-of-sync wisdom may result from plans being produced
   by communicators that do not span all processes, either from a
   user-specified communicator or e.g. from transpose-recurse.) */
static int wisdom_ok_hook(const problem *p, flags_t flags)
{
     MPI_Comm comm = problem_comm(p);
     int eq_me, eq_all;
     /* unpack flags bitfield, since MPI communications may involve
        byte-order changes and MPI cannot do this for bit fields */
#if SIZEOF_UNSIGNED_INT >= 4 /* must be big enough to hold 20-bit fields */
     unsigned int f[5];
#else
     unsigned long f[5]; /* at least 32 bits as per C standard */
#endif

     if (comm == MPI_COMM_NULL) return 1; /* non-MPI wisdom is always ok */

     if (XM(any_true)(0, comm)) return 0; /* some process had nowisdom_hook */

     /* otherwise, check that the flags and solver index are identical
        on all processes in this problem's communicator.

        TO DO: possibly we can relax strict equality, but it is
        critical to ensure that any flags which affect what plan is
        created (and whether the solver is applicable) are the same,
        e.g. DESTROY_INPUT, NO_UGLY, etcetera.  (If the MPI algorithm
        differs between processes, deadlocks/crashes generally result.) */
     f[0] = flags.l;
     f[1] = flags.hash_info;
     f[2] = flags.timelimit_impatience;
     f[3] = flags.u;
     f[4] = flags.slvndx;
     MPI_Bcast(f, 5,
               SIZEOF_UNSIGNED_INT >= 4 ? MPI_UNSIGNED : MPI_UNSIGNED_LONG,
               0, comm);
     eq_me = f[0] == flags.l && f[1] == flags.hash_info
          && f[2] == flags.timelimit_impatience
          && f[3] == flags.u && f[4] == flags.slvndx;
     MPI_Allreduce(&eq_me, &eq_all, 1, MPI_INT, MPI_LAND, comm);
     return eq_all;
}

/* This hook is called when wisdom is not found.  The any_true here
   matches up with the any_true in wisdom_ok_hook, in order to handle
   the case where some processes had wisdom (and called wisdom_ok_hook)
   and some processes didn't have wisdom (and called nowisdom_hook). */
static void nowisdom_hook(const problem *p)
{
     MPI_Comm comm = problem_comm(p);
     if (comm == MPI_COMM_NULL) return; /* nothing to do for non-MPI p */
     XM(any_true)(1, comm); /* signal nowisdom to any wisdom_ok_hook */
}

/* needed to synchronize planner bogosity flag, in case non-MPI problems
   on a subset of processes encountered bogus wisdom */
static wisdom_state_t bogosity_hook(wisdom_state_t state, const problem *p)
{
     MPI_Comm comm = problem_comm(p);
     if (comm != MPI_COMM_NULL /* an MPI problem */
         && XM(any_true)(state == WISDOM_IS_BOGUS, comm)) /* bogus somewhere */
          return WISDOM_IS_BOGUS;
     return state;
}

void XM(init)(void)
{
     if (!mpi_inited) {
          planner *plnr = X(the_planner)();
          plnr->cost_hook = cost_hook;
          plnr->wisdom_ok_hook = wisdom_ok_hook;
          plnr->nowisdom_hook = nowisdom_hook;
          plnr->bogosity_hook = bogosity_hook;
          XM(conf_standard)(plnr);
          mpi_inited = 1;
     }
}

void XM(cleanup)(void)
{
     X(cleanup)();
     mpi_inited = 0;
}
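
/* A minimal sketch of the intended call sequence from user code
   (assuming the public names fftw_mpi_init/fftw_mpi_cleanup that
   these XM() definitions expand to in the installed API):

       MPI_Init(&argc, &argv);
       fftw_mpi_init();
       ...create, execute, and destroy plans...
       fftw_mpi_cleanup();
       MPI_Finalize();
*/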

/*************************************************************************/

static dtensor *mkdtensor_api(int rnk, const XM(ddim) *dims0)
{
     dtensor *x = XM(mkdtensor)(rnk);
     int i;
     for (i = 0; i < rnk; ++i) {
          x->dims[i].n = dims0[i].n;
          x->dims[i].b[IB] = dims0[i].ib;
          x->dims[i].b[OB] = dims0[i].ob;
     }
     return x;
}

static dtensor *default_sz(int rnk, const XM(ddim) *dims0, int n_pes,
                           int rdft2)
{
     dtensor *sz = XM(mkdtensor)(rnk);
     dtensor *sz0 = mkdtensor_api(rnk, dims0);
     block_kind k;
     int i;

     for (i = 0; i < rnk; ++i)
          sz->dims[i].n = dims0[i].n;

     if (rdft2) sz->dims[rnk-1].n = dims0[rnk-1].n / 2 + 1;

     for (i = 0; i < rnk; ++i) {
          sz->dims[i].b[IB] = dims0[i].ib ? dims0[i].ib : sz->dims[i].n;
          sz->dims[i].b[OB] = dims0[i].ob ? dims0[i].ob : sz->dims[i].n;
     }

     /* If we haven't used all of the processes yet, and some of the
        block sizes weren't specified (i.e. 0), then set the
        unspecified blocks so as to use as many processes as
        possible with as few distributed dimensions as possible. */
     FORALL_BLOCK_KIND(k) {
          INT nb = XM(num_blocks_total)(sz, k);
          INT np = n_pes / nb;
          for (i = 0; i < rnk && np > 1; ++i)
               if (!sz0->dims[i].b[k]) {
                    sz->dims[i].b[k] = XM(default_block)(sz->dims[i].n, np);
                    nb *= XM(num_blocks)(sz->dims[i].n, sz->dims[i].b[k]);
                    np = n_pes / nb;
               }
     }

     if (rdft2) sz->dims[rnk-1].n = dims0[rnk-1].n;

     /* punt for 1d prime */
     if (rnk == 1 && X(is_prime)(sz->dims[0].n))
          sz->dims[0].b[IB] = sz->dims[0].b[OB] = sz->dims[0].n;

     XM(dtensor_destroy)(sz0);
     sz0 = XM(dtensor_canonical)(sz, 0);
     XM(dtensor_destroy)(sz);
     return sz0;
}
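
/* Worked example of the block-assignment loop above (assuming
   XM(default_block)(n, np) rounds n/np up): with n_pes = 4 and an
   unspecified 100 x 100 problem, the first dimension gets block size
   25 (4 blocks, occupying all 4 processes) and the second dimension
   stays local (block size 100), i.e. a 1d slab decomposition with as
   few distributed dimensions as possible. */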

/* allocate simple local (serial) dims array corresponding to n[rnk] */
static XM(ddim) *simple_dims(int rnk, const ptrdiff_t *n)
{
     XM(ddim) *dims = (XM(ddim) *) MALLOC(sizeof(XM(ddim)) * rnk,
                                          TENSORS);
     int i;
     for (i = 0; i < rnk; ++i)
          dims[i].n = dims[i].ib = dims[i].ob = n[i];
     return dims;
}

/*************************************************************************/

static void local_size(int my_pe, const dtensor *sz, block_kind k,
                       ptrdiff_t *local_n, ptrdiff_t *local_start)
{
     int i;
     if (my_pe >= XM(num_blocks_total)(sz, k))
          for (i = 0; i < sz->rnk; ++i)
               local_n[i] = local_start[i] = 0;
     else {
          XM(block_coords)(sz, k, my_pe, local_start);
          for (i = 0; i < sz->rnk; ++i) {
               local_n[i] = XM(block)(sz->dims[i].n, sz->dims[i].b[k],
                                      local_start[i]);
               local_start[i] *= sz->dims[i].b[k];
          }
     }
}

static INT prod(int rnk, const ptrdiff_t *local_n)
{
     int i;
     INT N = 1;
     for (i = 0; i < rnk; ++i) N *= local_n[i];
     return N;
}

ptrdiff_t XM(local_size_guru)(int rnk, const XM(ddim) *dims0,
                              ptrdiff_t howmany, MPI_Comm comm,
                              ptrdiff_t *local_n_in,
                              ptrdiff_t *local_start_in,
                              ptrdiff_t *local_n_out,
                              ptrdiff_t *local_start_out,
                              int sign, unsigned flags)
{
     INT N;
     int my_pe, n_pes, i;
     dtensor *sz;

     if (rnk == 0)
          return howmany;

     MPI_Comm_rank(comm, &my_pe);
     MPI_Comm_size(comm, &n_pes);
     sz = default_sz(rnk, dims0, n_pes, 0);

     /* Now, we must figure out how much local space the user should
        allocate (or at least an upper bound).  This depends strongly
        on the exact algorithms we employ...ugh!  FIXME: get this info
        from the solvers somehow? */
     N = 1; /* never return zero allocation size */
     if (rnk > 1 && XM(is_block1d)(sz, IB) && XM(is_block1d)(sz, OB)) {
          INT Nafter;
          ddim odims[2];

          /* dft-rank-geq2-transposed */
          odims[0] = sz->dims[0]; odims[1] = sz->dims[1]; /* save */
          /* we may need extra space for transposed intermediate data */
          for (i = 0; i < 2; ++i)
               if (XM(num_blocks)(sz->dims[i].n, sz->dims[i].b[IB]) == 1 &&
                   XM(num_blocks)(sz->dims[i].n, sz->dims[i].b[OB]) == 1) {
                    sz->dims[i].b[IB]
                         = XM(default_block)(sz->dims[i].n, n_pes);
                    sz->dims[1-i].b[IB] = sz->dims[1-i].n;
                    local_size(my_pe, sz, IB, local_n_in, local_start_in);
                    N = X(imax)(N, prod(rnk, local_n_in));
                    sz->dims[i] = odims[i];
                    sz->dims[1-i] = odims[1-i];
                    break;
               }

          /* dft-rank-geq2 */
          Nafter = howmany;
          for (i = 1; i < sz->rnk; ++i) Nafter *= sz->dims[i].n;
          N = X(imax)(N, (sz->dims[0].n
                          * XM(block)(Nafter, XM(default_block)(Nafter, n_pes),
                                      my_pe) + howmany - 1) / howmany);

          /* dft-rank-geq2 with dimensions swapped */
          Nafter = howmany * sz->dims[0].n;
          for (i = 2; i < sz->rnk; ++i) Nafter *= sz->dims[i].n;
          N = X(imax)(N, (sz->dims[1].n
                          * XM(block)(Nafter, XM(default_block)(Nafter, n_pes),
                                      my_pe) + howmany - 1) / howmany);
     }
     else if (rnk == 1) {
          if (howmany >= n_pes && !MPI_FLAGS(flags)) { /* dft-rank1-bigvec */
               ptrdiff_t n[2], start[2];
               dtensor *sz2 = XM(mkdtensor)(2);
               sz2->dims[0] = sz->dims[0];
               sz2->dims[0].b[IB] = sz->dims[0].n;
               sz2->dims[1].n = sz2->dims[1].b[OB] = howmany;
               sz2->dims[1].b[IB] = XM(default_block)(howmany, n_pes);
               local_size(my_pe, sz2, IB, n, start);
               XM(dtensor_destroy)(sz2);
               N = X(imax)(N, (prod(2, n) + howmany - 1) / howmany);
          }
          else { /* dft-rank1 */
               INT r, m, rblock[2], mblock[2];

               /* Since the 1d transforms are so different, we require
                  the user to call local_size_1d for this case.  Ugh. */
               CK(sign == FFTW_FORWARD || sign == FFTW_BACKWARD);

               if ((r = XM(choose_radix)(sz->dims[0], n_pes, flags, sign,
                                         rblock, mblock))) {
                    m = sz->dims[0].n / r;
                    if (flags & FFTW_MPI_SCRAMBLED_IN)
                         sz->dims[0].b[IB] = rblock[IB] * m;
                    else { /* !SCRAMBLED_IN */
                         sz->dims[0].b[IB] = r * mblock[IB];
                         N = X(imax)(N, rblock[IB] * m);
                    }
                    if (flags & FFTW_MPI_SCRAMBLED_OUT)
                         sz->dims[0].b[OB] = r * mblock[OB];
                    else { /* !SCRAMBLED_OUT */
                         N = X(imax)(N, r * mblock[OB]);
                         sz->dims[0].b[OB] = rblock[OB] * m;
                    }
               }
          }
     }

     local_size(my_pe, sz, IB, local_n_in, local_start_in);
     local_size(my_pe, sz, OB, local_n_out, local_start_out);

     /* at least, make sure we have enough space to store input & output */
     N = X(imax)(N, X(imax)(prod(rnk, local_n_in), prod(rnk, local_n_out)));

     XM(dtensor_destroy)(sz);
     return N * howmany;
}

ptrdiff_t XM(local_size_many_transposed)(int rnk, const ptrdiff_t *n,
                                         ptrdiff_t howmany,
                                         ptrdiff_t xblock, ptrdiff_t yblock,
                                         MPI_Comm comm,
                                         ptrdiff_t *local_nx,
                                         ptrdiff_t *local_x_start,
                                         ptrdiff_t *local_ny,
                                         ptrdiff_t *local_y_start)
{
     ptrdiff_t N;
     XM(ddim) *dims;
     ptrdiff_t *local;

     if (rnk == 0) {
          *local_nx = *local_ny = 1;
          *local_x_start = *local_y_start = 0;
          return howmany;
     }

     dims = simple_dims(rnk, n);
     local = (ptrdiff_t *) MALLOC(sizeof(ptrdiff_t) * rnk * 4, TENSORS);

     /* default 1d block distribution, with transposed output
        if yblock < n[1] */
     dims[0].ib = xblock;
     if (rnk > 1) {
          if (yblock < n[1])
               dims[1].ob = yblock;
          else
               dims[0].ob = xblock;
     }
     else
          dims[0].ob = xblock; /* FIXME: 1d not really supported here
                                         since we don't have flags/sign */

     N = XM(local_size_guru)(rnk, dims, howmany, comm,
                             local, local + rnk,
                             local + 2*rnk, local + 3*rnk,
                             0, 0);
     *local_nx = local[0];
     *local_x_start = local[rnk];
     if (rnk > 1) {
          *local_ny = local[2*rnk + 1];
          *local_y_start = local[3*rnk + 1];
     }
     else {
          *local_ny = *local_nx;
          *local_y_start = *local_x_start;
     }
     X(ifree)(local);
     X(ifree)(dims);
     return N;
}

ptrdiff_t XM(local_size_many)(int rnk, const ptrdiff_t *n,
                              ptrdiff_t howmany,
                              ptrdiff_t xblock,
                              MPI_Comm comm,
                              ptrdiff_t *local_nx,
                              ptrdiff_t *local_x_start)
{
     ptrdiff_t local_ny, local_y_start;
     return XM(local_size_many_transposed)(rnk, n, howmany,
                                           xblock, rnk > 1
                                           ? n[1] : FFTW_MPI_DEFAULT_BLOCK,
                                           comm,
                                           local_nx, local_x_start,
                                           &local_ny, &local_y_start);
}

ptrdiff_t XM(local_size_transposed)(int rnk, const ptrdiff_t *n,
                                    MPI_Comm comm,
                                    ptrdiff_t *local_nx,
                                    ptrdiff_t *local_x_start,
                                    ptrdiff_t *local_ny,
                                    ptrdiff_t *local_y_start)
{
     return XM(local_size_many_transposed)(rnk, n, 1,
                                           FFTW_MPI_DEFAULT_BLOCK,
                                           FFTW_MPI_DEFAULT_BLOCK,
                                           comm,
                                           local_nx, local_x_start,
                                           local_ny, local_y_start);
}

ptrdiff_t XM(local_size)(int rnk, const ptrdiff_t *n,
                         MPI_Comm comm,
                         ptrdiff_t *local_nx,
                         ptrdiff_t *local_x_start)
{
     return XM(local_size_many)(rnk, n, 1, FFTW_MPI_DEFAULT_BLOCK, comm,
                                local_nx, local_x_start);
}
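
/* Usage sketch (a hypothetical 2d call, assuming the public names
   fftw_mpi_local_size_2d and fftw_alloc_complex of the installed,
   double-precision API):

       ptrdiff_t local_nx, local_x_start;
       ptrdiff_t alloc = fftw_mpi_local_size_2d(nx, ny, MPI_COMM_WORLD,
                                                &local_nx, &local_x_start);
       fftw_complex *data = fftw_alloc_complex(alloc);

   The return value is an upper bound on the number of elements to
   allocate locally; it may exceed local_nx * ny because intermediate
   steps (e.g. transposes) can require extra scratch space. */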

ptrdiff_t XM(local_size_many_1d)(ptrdiff_t nx, ptrdiff_t howmany,
                                 MPI_Comm comm, int sign, unsigned flags,
                                 ptrdiff_t *local_nx, ptrdiff_t *local_x_start,
                                 ptrdiff_t *local_ny, ptrdiff_t *local_y_start)
{
     XM(ddim) d;
     d.n = nx;
     d.ib = d.ob = FFTW_MPI_DEFAULT_BLOCK;
     return XM(local_size_guru)(1, &d, howmany, comm,
                                local_nx, local_x_start,
                                local_ny, local_y_start, sign, flags);
}

ptrdiff_t XM(local_size_1d)(ptrdiff_t nx,
                            MPI_Comm comm, int sign, unsigned flags,
                            ptrdiff_t *local_nx, ptrdiff_t *local_x_start,
                            ptrdiff_t *local_ny, ptrdiff_t *local_y_start)
{
     return XM(local_size_many_1d)(nx, 1, comm, sign, flags,
                                   local_nx, local_x_start,
                                   local_ny, local_y_start);
}

ptrdiff_t XM(local_size_2d_transposed)(ptrdiff_t nx, ptrdiff_t ny,
                                       MPI_Comm comm,
                                       ptrdiff_t *local_nx,
                                       ptrdiff_t *local_x_start,
                                       ptrdiff_t *local_ny,
                                       ptrdiff_t *local_y_start)
{
     ptrdiff_t n[2];
     n[0] = nx; n[1] = ny;
     return XM(local_size_transposed)(2, n, comm,
                                      local_nx, local_x_start,
                                      local_ny, local_y_start);
}

ptrdiff_t XM(local_size_2d)(ptrdiff_t nx, ptrdiff_t ny, MPI_Comm comm,
                            ptrdiff_t *local_nx, ptrdiff_t *local_x_start)
{
     ptrdiff_t n[2];
     n[0] = nx; n[1] = ny;
     return XM(local_size)(2, n, comm, local_nx, local_x_start);
}

ptrdiff_t XM(local_size_3d_transposed)(ptrdiff_t nx, ptrdiff_t ny,
                                       ptrdiff_t nz,
                                       MPI_Comm comm,
                                       ptrdiff_t *local_nx,
                                       ptrdiff_t *local_x_start,
                                       ptrdiff_t *local_ny,
                                       ptrdiff_t *local_y_start)
{
     ptrdiff_t n[3];
     n[0] = nx; n[1] = ny; n[2] = nz;
     return XM(local_size_transposed)(3, n, comm,
                                      local_nx, local_x_start,
                                      local_ny, local_y_start);
}

ptrdiff_t XM(local_size_3d)(ptrdiff_t nx, ptrdiff_t ny, ptrdiff_t nz,
                            MPI_Comm comm,
                            ptrdiff_t *local_nx, ptrdiff_t *local_x_start)
{
     ptrdiff_t n[3];
     n[0] = nx; n[1] = ny; n[2] = nz;
     return XM(local_size)(3, n, comm, local_nx, local_x_start);
}

/*************************************************************************/
/* Transpose API */

X(plan) XM(plan_many_transpose)(ptrdiff_t nx, ptrdiff_t ny,
                                ptrdiff_t howmany,
                                ptrdiff_t xblock, ptrdiff_t yblock,
                                R *in, R *out,
                                MPI_Comm comm, unsigned flags)
{
     int n_pes;
     XM(init)();

     if (howmany < 0 || xblock < 0 || yblock < 0 ||
         nx <= 0 || ny <= 0) return 0;

     MPI_Comm_size(comm, &n_pes);
     if (!xblock) xblock = XM(default_block)(nx, n_pes);
     if (!yblock) yblock = XM(default_block)(ny, n_pes);
     if (n_pes < XM(num_blocks)(nx, xblock)
         || n_pes < XM(num_blocks)(ny, yblock))
          return 0;

     return
          X(mkapiplan)(FFTW_FORWARD, flags,
                       XM(mkproblem_transpose)(nx, ny, howmany,
                                               in, out, xblock, yblock,
                                               comm, MPI_FLAGS(flags)));
}

X(plan) XM(plan_transpose)(ptrdiff_t nx, ptrdiff_t ny, R *in, R *out,
                           MPI_Comm comm, unsigned flags)
{
     return XM(plan_many_transpose)(nx, ny, 1,
                                    FFTW_MPI_DEFAULT_BLOCK,
                                    FFTW_MPI_DEFAULT_BLOCK,
                                    in, out, comm, flags);
}
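
/* For illustration, a typical user call (hypothetical sketch, assuming
   the installed public name fftw_mpi_plan_transpose): transpose a
   distributed nx-by-ny real matrix, where each process holds a block
   of rows of the input and of the transposed output:

       fftw_plan p = fftw_mpi_plan_transpose(nx, ny, in, out,
                                             MPI_COMM_WORLD, FFTW_MEASURE);
       fftw_execute(p);
*/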

/*************************************************************************/
/* Complex DFT API */

X(plan) XM(plan_guru_dft)(int rnk, const XM(ddim) *dims0,
                          ptrdiff_t howmany,
                          C *in, C *out,
                          MPI_Comm comm, int sign, unsigned flags)
{
     int n_pes, i;
     dtensor *sz;

     XM(init)();

     if (howmany < 0 || rnk < 1) return 0;
     for (i = 0; i < rnk; ++i)
          if (dims0[i].n < 1 || dims0[i].ib < 0 || dims0[i].ob < 0)
               return 0;

     MPI_Comm_size(comm, &n_pes);
     sz = default_sz(rnk, dims0, n_pes, 0);

     if (XM(num_blocks_total)(sz, IB) > n_pes
         || XM(num_blocks_total)(sz, OB) > n_pes) {
          XM(dtensor_destroy)(sz);
          return 0;
     }

     return
          X(mkapiplan)(sign, flags,
                       XM(mkproblem_dft_d)(sz, howmany,
                                           (R *) in, (R *) out,
                                           comm, sign,
                                           MPI_FLAGS(flags)));
}

X(plan) XM(plan_many_dft)(int rnk, const ptrdiff_t *n,
                          ptrdiff_t howmany,
                          ptrdiff_t iblock, ptrdiff_t oblock,
                          C *in, C *out,
                          MPI_Comm comm, int sign, unsigned flags)
{
     XM(ddim) *dims = simple_dims(rnk, n);
     X(plan) pln;

     if (rnk == 1) {
          dims[0].ib = iblock;
          dims[0].ob = oblock;
     }
     else if (rnk > 1) {
          /* the i/o block applies to the first dimension, or to the
             second dimension when the corresponding TRANSPOSED flag
             says the data is distributed along that dimension */
          dims[0 != (flags & FFTW_MPI_TRANSPOSED_IN)].ib = iblock;
          dims[0 != (flags & FFTW_MPI_TRANSPOSED_OUT)].ob = oblock;
     }

     pln = XM(plan_guru_dft)(rnk, dims, howmany, in, out, comm, sign, flags);
     X(ifree)(dims);
     return pln;
}

X(plan) XM(plan_dft)(int rnk, const ptrdiff_t *n, C *in, C *out,
                     MPI_Comm comm, int sign, unsigned flags)
{
     return XM(plan_many_dft)(rnk, n, 1,
                              FFTW_MPI_DEFAULT_BLOCK,
                              FFTW_MPI_DEFAULT_BLOCK,
                              in, out, comm, sign, flags);
}

X(plan) XM(plan_dft_1d)(ptrdiff_t nx, C *in, C *out,
                        MPI_Comm comm, int sign, unsigned flags)
{
     return XM(plan_dft)(1, &nx, in, out, comm, sign, flags);
}

X(plan) XM(plan_dft_2d)(ptrdiff_t nx, ptrdiff_t ny, C *in, C *out,
                        MPI_Comm comm, int sign, unsigned flags)
{
     ptrdiff_t n[2];
     n[0] = nx; n[1] = ny;
     return XM(plan_dft)(2, n, in, out, comm, sign, flags);
}

X(plan) XM(plan_dft_3d)(ptrdiff_t nx, ptrdiff_t ny, ptrdiff_t nz,
                        C *in, C *out,
                        MPI_Comm comm, int sign, unsigned flags)
{
     ptrdiff_t n[3];
     n[0] = nx; n[1] = ny; n[2] = nz;
     return XM(plan_dft)(3, n, in, out, comm, sign, flags);
}
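
/* End-to-end sketch for a 2d complex DFT (hypothetical, assuming the
   installed double-precision names; pair this with the allocation
   sketch after XM(local_size) above):

       fftw_plan p = fftw_mpi_plan_dft_2d(nx, ny, data, data,
                                          MPI_COMM_WORLD, FFTW_FORWARD,
                                          FFTW_MEASURE);
       fftw_execute(p);
       fftw_destroy_plan(p);
*/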

/*************************************************************************/
/* R2R API */

X(plan) XM(plan_guru_r2r)(int rnk, const XM(ddim) *dims0,
                          ptrdiff_t howmany,
                          R *in, R *out,
                          MPI_Comm comm, const X(r2r_kind) *kind,
                          unsigned flags)
{
     int n_pes, i;
     dtensor *sz;
     rdft_kind *k;
     X(plan) pln;

     XM(init)();

     if (howmany < 0 || rnk < 1) return 0;
     for (i = 0; i < rnk; ++i)
          if (dims0[i].n < 1 || dims0[i].ib < 0 || dims0[i].ob < 0)
               return 0;

     k = X(map_r2r_kind)(rnk, kind);

     MPI_Comm_size(comm, &n_pes);
     sz = default_sz(rnk, dims0, n_pes, 0);

     if (XM(num_blocks_total)(sz, IB) > n_pes
         || XM(num_blocks_total)(sz, OB) > n_pes) {
          XM(dtensor_destroy)(sz);
          return 0;
     }

     pln = X(mkapiplan)(0, flags,
                        XM(mkproblem_rdft_d)(sz, howmany,
                                             in, out,
                                             comm, k, MPI_FLAGS(flags)));
     X(ifree0)(k);
     return pln;
}

X(plan) XM(plan_many_r2r)(int rnk, const ptrdiff_t *n,
                          ptrdiff_t howmany,
                          ptrdiff_t iblock, ptrdiff_t oblock,
                          R *in, R *out,
                          MPI_Comm comm, const X(r2r_kind) *kind,
                          unsigned flags)
{
     XM(ddim) *dims = simple_dims(rnk, n);
     X(plan) pln;

     if (rnk == 1) {
          dims[0].ib = iblock;
          dims[0].ob = oblock;
     }
     else if (rnk > 1) {
          /* as in plan_many_dft: the block applies to the transposed
             dimension if TRANSPOSED_IN/OUT is set */
          dims[0 != (flags & FFTW_MPI_TRANSPOSED_IN)].ib = iblock;
          dims[0 != (flags & FFTW_MPI_TRANSPOSED_OUT)].ob = oblock;
     }

     pln = XM(plan_guru_r2r)(rnk, dims, howmany, in, out, comm, kind, flags);
     X(ifree)(dims);
     return pln;
}

X(plan) XM(plan_r2r)(int rnk, const ptrdiff_t *n, R *in, R *out,
                     MPI_Comm comm,
                     const X(r2r_kind) *kind,
                     unsigned flags)
{
     return XM(plan_many_r2r)(rnk, n, 1,
                              FFTW_MPI_DEFAULT_BLOCK,
                              FFTW_MPI_DEFAULT_BLOCK,
                              in, out, comm, kind, flags);
}

X(plan) XM(plan_r2r_2d)(ptrdiff_t nx, ptrdiff_t ny, R *in, R *out,
                        MPI_Comm comm,
                        X(r2r_kind) kindx, X(r2r_kind) kindy,
                        unsigned flags)
{
     ptrdiff_t n[2];
     X(r2r_kind) kind[2];
     n[0] = nx; n[1] = ny;
     kind[0] = kindx; kind[1] = kindy;
     return XM(plan_r2r)(2, n, in, out, comm, kind, flags);
}

X(plan) XM(plan_r2r_3d)(ptrdiff_t nx, ptrdiff_t ny, ptrdiff_t nz,
                        R *in, R *out,
                        MPI_Comm comm,
                        X(r2r_kind) kindx, X(r2r_kind) kindy,
                        X(r2r_kind) kindz,
                        unsigned flags)
{
     ptrdiff_t n[3];
     X(r2r_kind) kind[3];
     n[0] = nx; n[1] = ny; n[2] = nz;
     kind[0] = kindx; kind[1] = kindy; kind[2] = kindz;
     return XM(plan_r2r)(3, n, in, out, comm, kind, flags);
}
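
/* e.g. a 2d DCT-II along both dimensions (hypothetical sketch, assuming
   the installed name fftw_mpi_plan_r2r_2d and the standard r2r kind
   constants):

       fftw_plan p = fftw_mpi_plan_r2r_2d(nx, ny, in, out, MPI_COMM_WORLD,
                                          FFTW_REDFT10, FFTW_REDFT10,
                                          FFTW_ESTIMATE);
*/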

/*************************************************************************/
/* R2C/C2R API */

static X(plan) plan_guru_rdft2(int rnk, const XM(ddim) *dims0,
                               ptrdiff_t howmany,
                               R *r, C *c,
                               MPI_Comm comm, rdft_kind kind, unsigned flags)
{
     int n_pes, i;
     dtensor *sz;
     R *cr = (R *) c;

     XM(init)();

     if (howmany < 0 || rnk < 2) return 0;
     for (i = 0; i < rnk; ++i)
          if (dims0[i].n < 1 || dims0[i].ib < 0 || dims0[i].ob < 0)
               return 0;

     MPI_Comm_size(comm, &n_pes);
     sz = default_sz(rnk, dims0, n_pes, 1);

     /* check the block counts against the (n/2+1)-long complex side
        of the transform, then restore the logical (real) length */
     sz->dims[rnk-1].n = dims0[rnk-1].n / 2 + 1;
     if (XM(num_blocks_total)(sz, IB) > n_pes
         || XM(num_blocks_total)(sz, OB) > n_pes) {
          XM(dtensor_destroy)(sz);
          return 0;
     }
     sz->dims[rnk-1].n = dims0[rnk-1].n;

     if (kind == R2HC)
          return X(mkapiplan)(0, flags,
                              XM(mkproblem_rdft2_d)(sz, howmany,
                                                    r, cr, comm, R2HC,
                                                    MPI_FLAGS(flags)));
     else
          return X(mkapiplan)(0, flags,
                              XM(mkproblem_rdft2_d)(sz, howmany,
                                                    cr, r, comm, HC2R,
                                                    MPI_FLAGS(flags)));
}

X(plan) XM(plan_many_dft_r2c)(int rnk, const ptrdiff_t *n,
                              ptrdiff_t howmany,
                              ptrdiff_t iblock, ptrdiff_t oblock,
                              R *in, C *out,
                              MPI_Comm comm, unsigned flags)
{
     XM(ddim) *dims = simple_dims(rnk, n);
     X(plan) pln;

     if (rnk == 1) {
          dims[0].ib = iblock;
          dims[0].ob = oblock;
     }
     else if (rnk > 1) {
          dims[0 != (flags & FFTW_MPI_TRANSPOSED_IN)].ib = iblock;
          dims[0 != (flags & FFTW_MPI_TRANSPOSED_OUT)].ob = oblock;
     }

     pln = plan_guru_rdft2(rnk, dims, howmany, in, out, comm, R2HC, flags);
     X(ifree)(dims);
     return pln;
}

X(plan) XM(plan_many_dft_c2r)(int rnk, const ptrdiff_t *n,
                              ptrdiff_t howmany,
                              ptrdiff_t iblock, ptrdiff_t oblock,
                              C *in, R *out,
                              MPI_Comm comm, unsigned flags)
{
     XM(ddim) *dims = simple_dims(rnk, n);
     X(plan) pln;

     if (rnk == 1) {
          dims[0].ib = iblock;
          dims[0].ob = oblock;
     }
     else if (rnk > 1) {
          dims[0 != (flags & FFTW_MPI_TRANSPOSED_IN)].ib = iblock;
          dims[0 != (flags & FFTW_MPI_TRANSPOSED_OUT)].ob = oblock;
     }

     pln = plan_guru_rdft2(rnk, dims, howmany, out, in, comm, HC2R, flags);
     X(ifree)(dims);
     return pln;
}

X(plan) XM(plan_dft_r2c)(int rnk, const ptrdiff_t *n, R *in, C *out,
                         MPI_Comm comm, unsigned flags)
{
     return XM(plan_many_dft_r2c)(rnk, n, 1,
                                  FFTW_MPI_DEFAULT_BLOCK,
                                  FFTW_MPI_DEFAULT_BLOCK,
                                  in, out, comm, flags);
}

X(plan) XM(plan_dft_r2c_2d)(ptrdiff_t nx, ptrdiff_t ny, R *in, C *out,
                            MPI_Comm comm, unsigned flags)
{
     ptrdiff_t n[2];
     n[0] = nx; n[1] = ny;
     return XM(plan_dft_r2c)(2, n, in, out, comm, flags);
}

X(plan) XM(plan_dft_r2c_3d)(ptrdiff_t nx, ptrdiff_t ny, ptrdiff_t nz,
                            R *in, C *out, MPI_Comm comm, unsigned flags)
{
     ptrdiff_t n[3];
     n[0] = nx; n[1] = ny; n[2] = nz;
     return XM(plan_dft_r2c)(3, n, in, out, comm, flags);
}

X(plan) XM(plan_dft_c2r)(int rnk, const ptrdiff_t *n, C *in, R *out,
                         MPI_Comm comm, unsigned flags)
{
     return XM(plan_many_dft_c2r)(rnk, n, 1,
                                  FFTW_MPI_DEFAULT_BLOCK,
                                  FFTW_MPI_DEFAULT_BLOCK,
                                  in, out, comm, flags);
}

X(plan) XM(plan_dft_c2r_2d)(ptrdiff_t nx, ptrdiff_t ny, C *in, R *out,
                            MPI_Comm comm, unsigned flags)
{
     ptrdiff_t n[2];
     n[0] = nx; n[1] = ny;
     return XM(plan_dft_c2r)(2, n, in, out, comm, flags);
}

X(plan) XM(plan_dft_c2r_3d)(ptrdiff_t nx, ptrdiff_t ny, ptrdiff_t nz,
                            C *in, R *out, MPI_Comm comm, unsigned flags)
{
     ptrdiff_t n[3];
     n[0] = nx; n[1] = ny; n[2] = nz;
     return XM(plan_dft_c2r)(3, n, in, out, comm, flags);
}
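
/* Allocation + planning sketch for an in-place 2d r2c transform
   (hypothetical, assuming the installed double-precision names; as in
   the serial API, in-place r2c data is padded in the last dimension
   to 2*(ny/2+1) real elements, and local_size is called with the
   complex dimensions):

       ptrdiff_t local_nx, local_x_start;
       ptrdiff_t alloc = fftw_mpi_local_size_2d(nx, ny/2+1, MPI_COMM_WORLD,
                                                &local_nx, &local_x_start);
       fftw_complex *cdata = fftw_alloc_complex(alloc);
       fftw_plan p = fftw_mpi_plan_dft_r2c_2d(nx, ny, (double *) cdata,
                                              cdata, MPI_COMM_WORLD,
                                              FFTW_MEASURE);
*/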

/*************************************************************************/
/* New-array execute functions */

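/* Note (per the FFTW manual's new-array execute rules, which these
   wrappers inherit from the serial API): the new arrays passed here
   must have the same alignment and the same in-place/out-of-place
   layout as the arrays with which the plan was created. */
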
void XM(execute_dft)(const X(plan) p, C *in, C *out) {
     /* internally, MPI plans are just rdft plans */
     X(execute_r2r)(p, (R *) in, (R *) out);
}

void XM(execute_dft_r2c)(const X(plan) p, R *in, C *out) {
     /* internally, MPI plans are just rdft plans */
     X(execute_r2r)(p, in, (R *) out);
}

void XM(execute_dft_c2r)(const X(plan) p, C *in, R *out) {
     /* internally, MPI plans are just rdft plans */
     X(execute_r2r)(p, (R *) in, out);
}

void XM(execute_r2r)(const X(plan) p, R *in, R *out) {
     /* internally, MPI plans are just rdft plans */
     X(execute_r2r)(p, in, out);
}