1 #include <complex.h>
2 #include <assert.h>
3 #include <mpi.h>
4 #include <fftw3-mpi.h>
5 
6 #include "vdwxc.h"
7 
8 #if defined HAVE_CONFIG_H
9 #include "config.h"
10 #endif
11 
12 #include "vdw_core.h"
13 
14 
// Attach an MPI communicator to a vdwxc_data object.
//
// Switches the FFT backend selector to VDW_FFTW_MPI, stores `mpi_comm`
// in `data`, and caches this process's rank and the communicator size
// in data->mpi_rank / data->mpi_size.
//
// The assertion requires data->mpi_comm to compare equal to NULL on
// entry, i.e. no communicator may have been stored before this call.
void vdwxc_set_communicator(vdwxc_data data, MPI_Comm mpi_comm)
{
    assert(data->mpi_comm == NULL);

    data->mpi_comm = mpi_comm;
    data->fft_type = VDW_FFTW_MPI;

    // Query the topology directly from the communicator that was passed in.
    MPI_Comm_size(mpi_comm, &data->mpi_size);
    MPI_Comm_rank(mpi_comm, &data->mpi_rank);
}
23 
// Initialize `data` for distributed transforms with FFTW-MPI on `mpi_comm`.
//
// Steps, in order:
//   1. Initialize FFTW's MPI layer and record the communicator through
//      vdwxc_set_communicator().
//   2. Ask FFTW how the grid is distributed over the ranks and how many
//      complex elements this rank must allocate.
//   3. Fill in the local grid shapes/offsets of data->cell (real space)
//      and data->icell (transposed output).
//   4. Allocate data->work_ka and create the in-place r2c / c2r plans
//      (data->plan_r2c, data->plan_c2r) on it.
//   5. Call vdwxc_allocate_buffers().
//
// Fields written here: kLDA, gLDA, cell.Nlocal[0..2], cell.offset[0],
// icell.Nlocal[0..2], icell.offset[1], work_ka, plan_r2c, plan_c2r.
// The remaining offset components are not touched by this function.
//
// Errors (allocation or planning failure) are reported through assert().
void vdwxc_init_mpi(vdwxc_data data, MPI_Comm mpi_comm)
{
    fftw_mpi_init(); // May be called any number of times
    vdwxc_set_communicator(data, mpi_comm);

    // Logical (real-space) grid dimensions handed to the planner.
    const ptrdiff_t plan_dims[3] = {data->cell.Nglobal[0], data->cell.Nglobal[1], data->cell.Nglobal[2]};

    // The work array is sized from the complex-side layout: the last
    // dimension is icell.Nglobal[2] (presumably Nglobal[2]/2 + 1 -- TODO
    // confirm where icell is set up).  gLDA is twice that, i.e. the same
    // row length counted in doubles instead of complex numbers.
    ptrdiff_t local_size_dims[3];
    ptrdiff_t fftw_alloc_size;
    data->kLDA = data->icell.Nglobal[2];
    data->gLDA = 2 * data->kLDA;
    local_size_dims[0] = data->cell.Nglobal[0];
    local_size_dims[1] = data->cell.Nglobal[1];
    local_size_dims[2] = data->kLDA;

    ptrdiff_t fftw_xsize, fftw_xstart, fftw_ysize, fftw_ystart;

    // Returns the number of complex elements to allocate locally, plus the
    // local extent/start along axis 0 (untransposed layout) and along
    // axis 1 (transposed layout).
    fftw_alloc_size = fftw_mpi_local_size_many_transposed(3,
                                                          local_size_dims,
                                                          data->kernel.nalpha,
                                                          FFTW_MPI_DEFAULT_BLOCK,
                                                          FFTW_MPI_DEFAULT_BLOCK,
                                                          data->mpi_comm,
                                                          &fftw_xsize,
                                                          &fftw_xstart,
                                                          &fftw_ysize,
                                                          &fftw_ystart);

    // Real-space side: only axis 0 is split across ranks.
    data->cell.Nlocal[0] = fftw_xsize;
    data->cell.Nlocal[1] = data->cell.Nglobal[1];
    data->cell.Nlocal[2] = data->cell.Nglobal[2];
    data->cell.offset[0] = fftw_xstart;

    // Transposed side: axis 1 is split, axis 0 is complete on every rank.
    data->icell.Nlocal[0] = data->cell.Nglobal[0];
    data->icell.Nlocal[1] = fftw_ysize;
    data->icell.Nlocal[2] = data->icell.Nglobal[2];
    data->icell.offset[1] = fftw_ystart;

    assert(fftw_alloc_size % data->kernel.nalpha == 0);
    data->work_ka = fftw_alloc_complex(fftw_alloc_size);
    // Catch allocation failure here: with FFTW_ESTIMATE the planner need
    // not touch the array, so a NULL buffer could otherwise go unnoticed
    // until the first transform is executed.  A zero-sized request is
    // allowed to yield NULL.
    assert(fftw_alloc_size == 0 || data->work_ka != NULL);

    // TODO custom strategy FFT_ESTIMATE/MEASURE/etc.
    // Both plans operate in place on work_ka; the forward transform leaves
    // its output transposed and the backward transform expects it that way.
    data->plan_r2c = fftw_mpi_plan_many_dft_r2c(3,
                                                plan_dims,
                                                data->kernel.nalpha,
                                                FFTW_MPI_DEFAULT_BLOCK,
                                                FFTW_MPI_DEFAULT_BLOCK,
                                                (double*)data->work_ka,
                                                data->work_ka,
                                                data->mpi_comm,
                                                FFTW_ESTIMATE|FFTW_MPI_TRANSPOSED_OUT);
    data->plan_c2r = fftw_mpi_plan_many_dft_c2r(3,
                                                plan_dims,
                                                data->kernel.nalpha,
                                                FFTW_MPI_DEFAULT_BLOCK,
                                                FFTW_MPI_DEFAULT_BLOCK,
                                                data->work_ka,
                                                (double*)data->work_ka,
                                                data->mpi_comm,
                                                FFTW_ESTIMATE|FFTW_MPI_TRANSPOSED_IN);
    assert(data->plan_r2c != NULL);
    assert(data->plan_c2r != NULL);
    vdwxc_allocate_buffers(data);
}
87 
// We want vdwxc_init_pfft to be defined even when we do not have PFFT.
// That way, calling codes do not need preprocessor conditionals to check
// which symbols exist in libvdwxc, except for MPI, where each DFT
// code probably has a preprocessor variable anyway.
92 #ifndef HAVE_PFFT
// Stand-in compiled when HAVE_PFFT is not defined: it keeps the symbol
// available to callers, but does no initialization and trips an
// assertion if it is ever invoked.
void vdwxc_init_pfft(vdwxc_data data, MPI_Comm comm, int proc1, int proc2)
{
    // None of the arguments are used in this build.
    (void)data;
    (void)comm;
    (void)proc1;
    (void)proc2;

    assert(0);
}
97 #endif
98