1!
2! Copyright (C) Quantum ESPRESSO group
3!
4! This file is distributed under the terms of the
5! GNU General Public License. See the file `License'
6! in the root directory of the present distribution,
7! or http://www.gnu.org/copyleft/gpl.txt .
8!
9
10#if defined(__CUDA)
11#define DEV_ATTRIBUTES , DEVICE
12#else
13#define DEV_ATTRIBUTES
14#endif
15
16!=----------------------------------------------------------------------------=!
17MODULE fft_types
18!=----------------------------------------------------------------------------=!
19
20#if defined(__CUDA)
21  USE cudafor
22#endif
23  USE fft_support, ONLY : good_fft_order, good_fft_dimension
24  USE fft_param
25  USE omp_lib
26  IMPLICIT NONE
27  PRIVATE
28  SAVE
29
30  !
31  !  Data type for FFT descriptor.
32  !
33  TYPE fft_type_descriptor
34    !
35    ! FFT dimensions
36    !
37    INTEGER :: nr1    = 0  !
38    INTEGER :: nr2    = 0  ! effective FFT dimensions of the 3D grid (global)
39    INTEGER :: nr3    = 0  !
40    INTEGER :: nr1x   = 0  ! FFT grids leading dimensions
41    INTEGER :: nr2x   = 0  ! dimensions of the arrays for the 3D grid (global)
42    INTEGER :: nr3x   = 0  ! may differ from nr1 ,nr2 ,nr3 in order to boost performances
43    !
44    !  Parallel layout: in reciprocal space data are organized in columns (sticks) along
45    !                   the third direction and distributed across nproc processors.
46    !                   In real space data are distributed in blocks comprising sections
47    !                   of the Y and Z directions and complete rows in the X direction in
48    !                   a matrix of  nproc2 x nproc3  processors.
49    !                   nproc = nproc2 x nproc3 and additional communicators are introduced
50    !                   for data redistribution across matrix columns and rows.
51    !
52    ! communicators and processor coordinates
53    !
54    LOGICAL :: lpara  = .FALSE. ! .TRUE. if parallel FFT is active
55    LOGICAL :: lgamma = .FALSE. ! .TRUE. if the grid has Gamma symmetry
56    INTEGER :: root   = 0 ! root processor
57    INTEGER :: comm   = MPI_COMM_NULL ! communicator for the main fft group
58    INTEGER :: comm2  = MPI_COMM_NULL ! communicator for the fft group along the second direction
59    INTEGER :: comm3  = MPI_COMM_NULL ! communicator for the fft group along the third direction
60    INTEGER :: nproc  = 1 ! number of processor in the main fft group
61    INTEGER :: nproc2 = 1 ! number of processor in the fft group along the second direction
62    INTEGER :: nproc3 = 1 ! number of processor in the fft group along the third direction
63    INTEGER :: mype   = 0 ! my processor id (starting from 0) in the fft main communicator
64    INTEGER :: mype2  = 0 ! my processor id (starting from 0) in the fft communicator along the second direction (nproc2)
65    INTEGER :: mype3  = 0 ! my processor id (starting from 0) in the fft communicator along the third direction (nproc3)
66
67    INTEGER, ALLOCATABLE :: iproc(:,:) , iproc2(:), iproc3(:) ! subcommunicators proc mapping (starting from 1)
68    !
69    ! FFT distributed data dimensions and indices
70    !
71    INTEGER :: my_nr3p = 0 ! size of the "Z" section for this processor = nr3p( mype3 + 1 )    ~ nr3/nproc3
72    INTEGER :: my_nr2p = 0 ! size of the "Y" section for this processor = nr2p( mype2 + 1 )    ~ nr2/nproc2
73
74    INTEGER :: my_i0r3p = 0 ! offset of the first "Z" element of this proc in the nproc3 group = i0r3p( mype3 + 1 )
75    INTEGER :: my_i0r2p = 0 ! offset of the first "Y" element of this proc in the nproc2 group = i0r2p( mype2 + 1 )
76
77    INTEGER, ALLOCATABLE :: nr3p(:)  ! size of the "Z" section of each processor in the nproc3 group along Z
78    INTEGER, ALLOCATABLE :: nr3p_offset(:)  ! offset of the "Z" section of each processor in the nproc3 group along Z
79    INTEGER, ALLOCATABLE :: nr2p(:)  ! size of the "Y" section of each processor in the nproc2 group along Y
80    INTEGER, ALLOCATABLE :: nr2p_offset(:)  ! offset of the "Y" section of each processor in the nproc2 group along Y
81    INTEGER, ALLOCATABLE :: nr1p(:)  ! number of active "X" values ( potential ) for a given proc in the nproc2 group
82    INTEGER, ALLOCATABLE :: nr1w(:)  ! number of active "X" values ( wave func ) for a given proc in the nproc2 group
83    INTEGER              :: nr1w_tg  ! total number of active "X" values ( wave func ). used in task group ffts
84
85    INTEGER, ALLOCATABLE :: i0r3p(:) ! offset of the first "Z" element of each proc in the nproc3 group (starting from 0)
86    INTEGER, ALLOCATABLE :: i0r2p(:) ! offset of the first "Y" element of each proc in the nproc2 group (starting from 0)
87
88    INTEGER, ALLOCATABLE :: ir1p(:)  ! if >0 ir1p(m1) is the incremental index of the active ( potential ) X value of this proc
89    INTEGER, ALLOCATABLE :: indp(:,:)! is the inverse of ir1p
90    INTEGER, ALLOCATABLE :: ir1w(:)  ! if >0 ir1w(m1) is the incremental index of the active ( wave func ) X value of this proc
91    INTEGER, ALLOCATABLE :: indw(:,:)! is the inverse of ir1w
92    INTEGER, ALLOCATABLE :: ir1w_tg(:)! if >0 ir1w_tg(m1) is the incremental index of the active ( wfc ) X value in task group
93    INTEGER, ALLOCATABLE :: indw_tg(:)! is the inverse of ir1w_tg
94
95    INTEGER, POINTER DEV_ATTRIBUTES :: ir1p_d(:),   ir1w_d(:),   ir1w_tg_d(:)   ! duplicated version of the arrays declared above
96    INTEGER, POINTER DEV_ATTRIBUTES :: indp_d(:,:), indw_d(:,:), indw_tg_d(:,:) !
97    INTEGER, POINTER DEV_ATTRIBUTES :: nr1p_d(:),   nr1w_d(:),   nr1w_tg_d(:)   !
98
99    INTEGER :: nst      ! total number of sticks ( potential )
100
101    INTEGER, ALLOCATABLE :: nsp(:)   ! number of sticks per processor ( potential ) using proc index starting from 1
102                                     !                                              ... that is on proc mype -> nsp( mype + 1 )
103    INTEGER, ALLOCATABLE :: nsp_offset(:,:)   ! offset of sticks per processor ( potential )
104    INTEGER, ALLOCATABLE :: nsw(:)   ! number of sticks per processor ( wave func ) using proc index as above
105    INTEGER, ALLOCATABLE :: nsw_offset(:,:)   ! offset of sticks per processor ( wave func )
106    INTEGER, ALLOCATABLE :: nsw_tg(:)! number of sticks per processor ( wave func ) using proc index as above. task group version
107
108    INTEGER, ALLOCATABLE :: ngl(:) ! per proc. no. of non zero charge density/potential components
109    INTEGER, ALLOCATABLE :: nwl(:) ! per proc. no. of non zero wave function plane components
110
111    INTEGER :: ngm  ! my no. of non zero charge density/potential components
112                    !    ngm = dfftp%ngl( dfftp%mype + 1 )
113                    ! with gamma sym.
114                    !    ngm = ( dfftp%ngl( dfftp%mype + 1 ) + 1 ) / 2
115
116    INTEGER :: ngw  ! my no. of non zero wave function plane components
117                    !    ngw = dffts%nwl( dffts%mype + 1 )
118                    ! with gamma sym.
119                    !    ngw = ( dffts%nwl( dffts%mype + 1 ) + 1 ) / 2
120
121    INTEGER, ALLOCATABLE :: iplp(:) ! if > 0 is the iproc2 processor owning the active "X" value ( potential )
122    INTEGER, ALLOCATABLE :: iplw(:) ! if > 0 is the iproc2 processor owning the active "X" value ( wave func )
123
124    INTEGER :: nnp    = 0  ! number of 0 and non 0 sticks in a plane ( ~nr1*nr2/nproc )
125    INTEGER :: nnr    = 0  ! local number of FFT grid elements  ( ~nr1*nr2*nr3/nproc )
126                           ! size of the arrays allocated for the FFT, local to each processor:
127                           ! in parallel execution may differ from nr1x*nr2x*nr3x
128                           ! Not to be confused either with nr1*nr2*nr3
129    INTEGER :: nnr_tg = 0  ! local number of grid elements for task group FFT ( ~nr1*nr2*nr3/proc3 )
130    INTEGER, ALLOCATABLE :: iss(:)   ! index of the first rho stick on each proc
131    INTEGER, ALLOCATABLE :: isind(:) ! for each position in the plane indicate the stick index
132    INTEGER, ALLOCATABLE :: ismap(:) ! for each stick in the plane indicate the position
133
134    INTEGER, POINTER DEV_ATTRIBUTES :: ismap_d(:)
135
136    INTEGER, ALLOCATABLE :: nl(:)    ! position of the G vec in the FFT grid
137    INTEGER, ALLOCATABLE :: nlm(:)   ! with gamma sym. position of -G vec in the FFT grid
138
139    INTEGER, POINTER DEV_ATTRIBUTES :: nl_d(:)    ! duplication of the variables defined above
140    INTEGER, POINTER DEV_ATTRIBUTES :: nlm_d(:)   !
141    !
142    ! task group ALLTOALL communication layout
143    INTEGER, ALLOCATABLE :: tg_snd(:) ! number of elements to be sent in task group redistribution
144    INTEGER, ALLOCATABLE :: tg_rcv(:) ! number of elements to be received in task group redistribution
145    INTEGER, ALLOCATABLE :: tg_sdsp(:)! send displacement for task group A2A communication
146    INTEGER, ALLOCATABLE :: tg_rdsp(:)! receive displacement for task group A2A communicattion
147    !
148    LOGICAL :: has_task_groups = .FALSE.
149    LOGICAL :: use_pencil_decomposition = .TRUE.
150    !
151    CHARACTER(len=12):: rho_clock_label  = ' '
152    CHARACTER(len=12):: wave_clock_label = ' '
153
154    INTEGER :: grid_id
155#if defined(__CUDA)
156    INTEGER(kind=cuda_stream_kind), allocatable, dimension(:) :: stream_scatter_yz
157    INTEGER(kind=cuda_stream_kind), allocatable, dimension(:) :: stream_many
158    INTEGER                                                   :: nstream_many = 16
159
160    INTEGER(kind=cuda_stream_kind) :: a2a_comp
161    INTEGER(kind=cuda_stream_kind), allocatable, dimension(:) :: bstreams
162    TYPE(cudaEvent), allocatable, dimension(:) :: bevents
163
164    INTEGER              :: batchsize = 16    ! how many ffts to batch together
165    INTEGER              :: subbatchsize = 4  ! size of subbatch for pipelining
166
167#if defined(__IPC)
168    INTEGER :: IPC_PEER(16)          ! This is used for IPC that is not imlpemented yet.
169#endif
170    INTEGER, ALLOCATABLE :: srh(:,:) ! Isend/recv handles by subbatch
171#endif
172    COMPLEX(DP), ALLOCATABLE, DIMENSION(:) :: aux
173#if defined(__FFT_OPENMP_TASKS)
174    INTEGER, ALLOCATABLE :: comm2s(:) ! multiple communicator for the fft group along the second direction
175    INTEGER, ALLOCATABLE :: comm3s(:) ! multiple communicator for the fft group along the third direction
176#endif
177  END TYPE
178
179  REAL(DP) :: fft_dual = 4.0d0
180  INTEGER  :: incremental_grid_identifier = 0
181
182  PUBLIC :: fft_type_descriptor, fft_type_init
183  PUBLIC :: fft_type_allocate, fft_type_deallocate
184  PUBLIC :: fft_stick_index, fft_index_to_3d
185
186CONTAINS
187
188!=----------------------------------------------------------------------------=!
189
190  SUBROUTINE fft_type_allocate( desc, at, bg, gcutm, comm, fft_fact, nyfft  )
191  !
192  ! routine allocating arrays of fft_type_descriptor, called by fft_type_init
193  !
194    TYPE (fft_type_descriptor) :: desc
195    REAL(DP), INTENT(IN) :: at(3,3), bg(3,3)
196    REAL(DP), INTENT(IN) :: gcutm
197    INTEGER, INTENT(IN), OPTIONAL :: fft_fact(3)
198    INTEGER, INTENT(IN), OPTIONAL :: nyfft
199    INTEGER, INTENT(in) :: comm ! mype starting from 0
200    INTEGER :: nx, ny, ierr, nzfft, i, nsubbatches
201    INTEGER :: mype, root, nproc, iproc, iproc2, iproc3 ! mype starting from 0
202    INTEGER :: color, key
203     !write (6,*) ' inside fft_type_allocate' ; FLUSH(6)
204
205    IF ( ALLOCATED( desc%nsp ) ) &
206        CALL fftx_error__(' fft_type_allocate ', ' fft arrays already allocated ', 1 )
207
208    desc%comm = comm
209#if defined(__MPI)
210    IF( desc%comm == MPI_COMM_NULL ) THEN
211       CALL fftx_error__( ' fft_type_allocate ', ' fft communicator is null ', 1 )
212    END IF
213#endif
214    !
215    root = 0 ; mype = 0 ; nproc = 1
216#if defined(__MPI)
217    CALL MPI_COMM_RANK( comm, mype, ierr )
218    CALL MPI_COMM_SIZE( comm, nproc, ierr )
219#endif
220    desc%root = root ; desc%mype = mype ; desc%nproc   = nproc
221
222    IF ( present(nyfft) ) THEN
223      ! check on yfft group dimension
224      CALL fftx_error__( ' fft_type_allocate ', ' MOD(nproc,nyfft) .ne. 0 ', MOD(nproc,nyfft) )
225
226!#define ZCOMPACT
227#if defined(__MPI)
228#if defined(ZCOMPACT)
229      !write (6,*) ' FFT IS ZCOMPACT '
230      nzfft = nproc / nyfft
231      color = mype / nzfft       ;      key   = MOD( mype, nzfft )
232#else
233      !write (6,*) ' FFT IS YCOMPACT '
234      color = MOD( mype, nyfft ) ;      key   = mype / nyfft
235#endif
236
237      ! processes with the same key are in the same group along Y
238      CALL MPI_COMM_SPLIT( comm, key, color, desc%comm2, ierr )
239      CALL MPI_COMM_RANK( desc%comm2, desc%mype2, ierr )
240      CALL MPI_COMM_SIZE( desc%comm2, desc%nproc2, ierr )
241
242      ! processes with the same color are in the same group along Z
243      CALL MPI_COMM_SPLIT( comm, color, key, desc%comm3, ierr )
244      CALL MPI_COMM_RANK( desc%comm3, desc%mype3, ierr )
245      CALL MPI_COMM_SIZE( desc%comm3, desc%nproc3, ierr )
246#if defined(__FFT_OPENMP_TASKS)
247      ALLOCATE( desc%comm2s( omp_get_max_threads() ))
248      ALLOCATE( desc%comm3s( omp_get_max_threads() ))
249      DO i=1, OMP_GET_MAX_THREADS()
250         CALL MPI_COMM_DUP(desc%comm2, desc%comm2s(i), ierr)
251         CALL MPI_COMM_DUP(desc%comm3, desc%comm3s(i), ierr)
252      ENDDO
253#endif
254#else
255      desc%comm2 = desc%comm ; desc%mype2 = desc%mype ; desc%nproc2 = desc%nproc
256      desc%comm3 = desc%comm ; desc%mype3 = desc%mype ; desc%nproc3 = desc%nproc
257#endif
258
259    ENDIF
260    !write (6,*) '  nproc and  mype  '
261    !write (6,*) desc%nproc, desc%nproc2, desc%nproc3
262    !write (6,*) desc%mype, desc%mype2, desc%mype3
263
264    ALLOCATE ( desc%iproc(desc%nproc2,desc%nproc3), desc%iproc2(desc%nproc), desc%iproc3(desc%nproc) )
265    do iproc = 1, desc%nproc
266#if defined(ZCOMPACT)
267       iproc3 = MOD(iproc-1, desc%nproc3) + 1 ; iproc2 = (iproc-1)/desc%nproc3 + 1
268#else
269       iproc2 = MOD(iproc-1, desc%nproc2) + 1 ; iproc3 = (iproc-1)/desc%nproc2 + 1
270#endif
271       desc%iproc2(iproc) = iproc2 ; desc%iproc3(iproc) = iproc3
272       desc%iproc(iproc2,iproc3) = iproc
273    end do
274
275    CALL realspace_grid_init( desc, at, bg, gcutm, fft_fact )
276
277    ALLOCATE( desc%nr2p ( desc%nproc2 ), desc%i0r2p( desc%nproc2 ) ) ; desc%nr2p = 0 ; desc%i0r2p = 0
278    ALLOCATE( desc%nr2p_offset ( desc%nproc2 ) ) ; desc%nr2p_offset = 0
279    ALLOCATE( desc%nr3p ( desc%nproc3 ), desc%i0r3p( desc%nproc3 ) ) ; desc%nr3p = 0 ; desc%i0r3p = 0
280    ALLOCATE( desc%nr3p_offset ( desc%nproc3 ) ) ; desc%nr3p_offset = 0
281
282    nx = desc%nr1x
283    ny = desc%nr2x
284
285    ALLOCATE( desc%nsp( desc%nproc ) ) ; desc%nsp   = 0
286    ALLOCATE( desc%nsp_offset( desc%nproc2, desc%nproc3 ) ) ; desc%nsp_offset = 0
287    ALLOCATE( desc%nsw( desc%nproc ) ) ; desc%nsw   = 0
288    ALLOCATE( desc%nsw_offset( desc%nproc2, desc%nproc3 ) ) ; desc%nsw_offset = 0
289    ALLOCATE( desc%nsw_tg( desc%nproc ) ) ; desc%nsw_tg   = 0
290    ALLOCATE( desc%ngl( desc%nproc ) ) ; desc%ngl   = 0
291    ALLOCATE( desc%nwl( desc%nproc ) ) ; desc%nwl   = 0
292    ALLOCATE( desc%iss( desc%nproc ) ) ; desc%iss   = 0
293    ALLOCATE( desc%isind( nx * ny ) ) ; desc%isind = 0
294    ALLOCATE( desc%ismap( nx * ny ) ) ; desc%ismap = 0
295    ALLOCATE( desc%nr1p( desc%nproc2 ) ) ; desc%nr1p  = 0
296    ALLOCATE( desc%nr1w( desc%nproc2 ) ) ; desc%nr1w  = 0
297    ALLOCATE( desc%ir1p( desc%nr1x ) ) ; desc%ir1p  = 0
298    ALLOCATE( desc%indp( desc%nr1x,desc%nproc2 ) ) ; desc%indp  = 0
299    ALLOCATE( desc%ir1w( desc%nr1x ) ) ; desc%ir1w  = 0
300    ALLOCATE( desc%ir1w_tg( desc%nr1x ) ) ; desc%ir1w_tg  = 0
301    ALLOCATE( desc%indw( desc%nr1x, desc%nproc2 ) ) ; desc%indw  = 0
302    ALLOCATE( desc%indw_tg( desc%nr1x ) ) ; desc%indw_tg  = 0
303    ALLOCATE( desc%iplp( nx ) ) ; desc%iplp  = 0
304    ALLOCATE( desc%iplw( nx ) ) ; desc%iplw  = 0
305
306    ALLOCATE( desc%tg_snd( desc%nproc2) ) ; desc%tg_snd = 0
307    ALLOCATE( desc%tg_rcv( desc%nproc2) ) ; desc%tg_rcv = 0
308    ALLOCATE( desc%tg_sdsp( desc%nproc2) ) ; desc%tg_sdsp = 0
309    ALLOCATE( desc%tg_rdsp( desc%nproc2) ) ; desc%tg_rdsp = 0
310
311#if defined(__CUDA)
312    ALLOCATE( desc%indp_d( desc%nr1x,desc%nproc2 ) ) ; desc%indp_d  = 0
313    ALLOCATE( desc%indw_d( desc%nr1x, desc%nproc2 ) ) ; desc%indw_d  = 0
314    ALLOCATE( desc%indw_tg_d( desc%nr1x, 1 ) ) ; desc%indw_tg_d  = 0
315    !
316    ALLOCATE( desc%nr1p_d(desc%nproc2)) ; desc%nr1p_d  = 0
317    ALLOCATE( desc%nr1w_d(desc%nproc2)) ; desc%nr1w_d  = 0
318    ALLOCATE( desc%nr1w_tg_d(1) ) ; desc%nr1w_tg_d = 0
319
320    ALLOCATE( desc%ir1p_d( desc%nr1x ) ) ; desc%ir1p_d  = 0
321    ALLOCATE( desc%ir1w_d( desc%nr1x ) ) ; desc%ir1w_d  = 0
322    ALLOCATE( desc%ir1w_tg_d( desc%nr1x ) ) ; desc%ir1w_tg_d  = 0
323    ALLOCATE( desc%ismap_d( nx * ny ) ) ; desc%ismap_d = 0
324
325    ALLOCATE ( desc%stream_scatter_yz(desc%nproc3) ) ;
326    DO iproc = 1, desc%nproc3
327        ierr = cudaStreamCreate(desc%stream_scatter_yz(iproc))
328    END DO
329    !
330    ALLOCATE ( desc%stream_many(desc%nstream_many) ) ;
331    DO i = 1, desc%nstream_many
332        ierr = cudaStreamCreate(desc%stream_many(i))
333        IF ( ierr /= 0 ) CALL fftx_error__( ' fft_type_allocate ', ' Error creating stream ', i )
334    END DO
335
336    ierr = cudaStreamCreate( desc%a2a_comp )
337
338    nsubbatches = ceiling(real(desc%batchsize)/desc%subbatchsize)
339
340    ALLOCATE( desc%bstreams( nsubbatches ) )
341    ALLOCATE( desc%bevents( nsubbatches ) )
342    DO i = 1, nsubbatches
343      ierr = cudaStreamCreate( desc%bstreams(i) )
344      ierr = cudaEventCreate( desc%bevents(i) )
345    ENDDO
346    ALLOCATE( desc%srh(2*nproc, nsubbatches))
347
348#endif
349
350    incremental_grid_identifier = incremental_grid_identifier + 1
351    desc%grid_id = incremental_grid_identifier
352
353  END SUBROUTINE fft_type_allocate
354
355  SUBROUTINE fft_type_deallocate( desc )
356    TYPE (fft_type_descriptor) :: desc
357    INTEGER :: i, ierr, nsubbatches
358     !write (6,*) ' inside fft_type_deallocate' ; FLUSH(6)
359    IF ( ALLOCATED( desc%nr2p ) )   DEALLOCATE( desc%nr2p )
360    IF ( ALLOCATED( desc%nr2p_offset ) )   DEALLOCATE( desc%nr2p_offset )
361    IF ( ALLOCATED( desc%nr3p_offset ) )   DEALLOCATE( desc%nr3p_offset )
362    IF ( ALLOCATED( desc%i0r2p ) )  DEALLOCATE( desc%i0r2p )
363    IF ( ALLOCATED( desc%nr3p ) )   DEALLOCATE( desc%nr3p )
364    IF ( ALLOCATED( desc%i0r3p ) )  DEALLOCATE( desc%i0r3p )
365    IF ( ALLOCATED( desc%nsp ) )    DEALLOCATE( desc%nsp )
366    IF ( ALLOCATED( desc%nsp_offset ) )    DEALLOCATE( desc%nsp_offset )
367    IF ( ALLOCATED( desc%nsw ) )    DEALLOCATE( desc%nsw )
368    IF ( ALLOCATED( desc%nsw_offset ) )    DEALLOCATE( desc%nsw_offset )
369    IF ( ALLOCATED( desc%nsw_tg ) ) DEALLOCATE( desc%nsw_tg )
370    IF ( ALLOCATED( desc%ngl ) )    DEALLOCATE( desc%ngl )
371    IF ( ALLOCATED( desc%nwl ) )    DEALLOCATE( desc%nwl )
372    IF ( ALLOCATED( desc%iss ) )    DEALLOCATE( desc%iss )
373    IF ( ALLOCATED( desc%isind ) )  DEALLOCATE( desc%isind )
374    IF ( ALLOCATED( desc%ismap ) )  DEALLOCATE( desc%ismap )
375    IF ( ALLOCATED( desc%nr1p ) )   DEALLOCATE( desc%nr1p )
376    IF ( ALLOCATED( desc%nr1w ) )   DEALLOCATE( desc%nr1w )
377    IF ( ALLOCATED( desc%ir1p ) )   DEALLOCATE( desc%ir1p )
378    IF ( ALLOCATED( desc%indp ) )   DEALLOCATE( desc%indp )
379    IF ( ALLOCATED( desc%ir1w ) )   DEALLOCATE( desc%ir1w )
380    IF ( ALLOCATED( desc%ir1w_tg ) )DEALLOCATE( desc%ir1w_tg )
381    IF ( ALLOCATED( desc%indw ) )   DEALLOCATE( desc%indw )
382    IF ( ALLOCATED( desc%indw_tg ) )DEALLOCATE( desc%indw_tg )
383    IF ( ALLOCATED( desc%iplp ) )   DEALLOCATE( desc%iplp )
384    IF ( ALLOCATED( desc%iplw ) )   DEALLOCATE( desc%iplw )
385    IF ( ALLOCATED( desc%iproc ) )  DEALLOCATE( desc%iproc )
386    IF ( ALLOCATED( desc%iproc2 ) ) DEALLOCATE( desc%iproc2 )
387    IF ( ALLOCATED( desc%iproc3 ) ) DEALLOCATE( desc%iproc3 )
388
389    IF ( ALLOCATED( desc%tg_snd ) ) DEALLOCATE( desc%tg_snd )
390    IF ( ALLOCATED( desc%tg_rcv ) ) DEALLOCATE( desc%tg_rcv )
391    IF ( ALLOCATED( desc%tg_sdsp ) )DEALLOCATE( desc%tg_sdsp )
392    IF ( ALLOCATED( desc%tg_rdsp ) )DEALLOCATE( desc%tg_rdsp )
393
394    IF ( ALLOCATED( desc%nl ) )  DEALLOCATE( desc%nl )
395    IF ( ALLOCATED( desc%nlm ) ) DEALLOCATE( desc%nlm )
396
397#if defined(__CUDA)
398    IF ( ALLOCATED( desc%ismap_d ) )   DEALLOCATE( desc%ismap_d )
399    IF ( ALLOCATED( desc%ir1p_d ) )    DEALLOCATE( desc%ir1p_d )
400    IF ( ALLOCATED( desc%ir1w_d ) )    DEALLOCATE( desc%ir1w_d )
401    IF ( ALLOCATED( desc%ir1w_tg_d ) ) DEALLOCATE( desc%ir1w_tg_d )
402
403    IF ( ALLOCATED( desc%indp_d ) )   DEALLOCATE( desc%indp_d )
404    IF ( ALLOCATED( desc%indw_d ) )    DEALLOCATE( desc%indw_d )
405    IF ( ALLOCATED( desc%indw_tg_d ) ) DEALLOCATE( desc%indw_tg_d )
406
407    IF ( ALLOCATED( desc%nr1p_d ) )   DEALLOCATE( desc%nr1p_d )
408    IF ( ALLOCATED( desc%nr1w_d ) )    DEALLOCATE( desc%nr1w_d )
409    IF ( ALLOCATED( desc%nr1w_tg_d ) ) DEALLOCATE( desc%nr1w_tg_d )
410
411    IF (ALLOCATED(desc%stream_scatter_yz)) THEN
412        do i = 1, desc%nproc3
413            ierr = cudaStreamDestroy(desc%stream_scatter_yz(i))
414        end do
415        DEALLOCATE(desc%stream_scatter_yz)
416    END IF
417    IF (ALLOCATED(desc%stream_many)) THEN
418        do i = 1, desc%nstream_many
419            ierr = cudaStreamDestroy(desc%stream_many(i))
420        end do
421        DEALLOCATE(desc%stream_many)
422    END IF
423
424    IF ( ALLOCATED( desc%nl_d ) )  DEALLOCATE( desc%nl_d )
425    IF ( ALLOCATED( desc%nlm_d ) ) DEALLOCATE( desc%nlm_d )
426    !
427    ! SLAB decomposition
428    IF ( ALLOCATED( desc%srh ) )   DEALLOCATE( desc%srh )
429    ierr = cudaStreamDestroy( desc%a2a_comp )
430
431    IF ( ALLOCATED(desc%bstreams) ) THEN
432        nsubbatches = ceiling(real(desc%batchsize)/desc%subbatchsize)
433        DO i = 1, nsubbatches
434          ierr = cudaStreamDestroy( desc%bstreams(i) )
435          ierr = cudaEventDestroy( desc%bevents(i) )
436        ENDDO
437        !
438        DEALLOCATE( desc%bstreams )
439        DEALLOCATE( desc%bevents )
440    END IF
441
442#endif
443
444    desc%comm  = MPI_COMM_NULL
445#if defined(__MPI)
446    IF (desc%comm2 /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm2, ierr )
447    IF (desc%comm3 /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm3, ierr )
448#if defined(__FFT_OPENMP_TASKS)
449    DO i=1, SIZE(desc%comm2s)
450       IF (desc%comm2s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm2s(i), ierr )
451       IF (desc%comm3s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm3s(i), ierr )
452    ENDDO
453    DEALLOCATE( desc%comm2s )
454    DEALLOCATE( desc%comm3s )
455#endif
456#else
457    desc%comm2 = MPI_COMM_NULL
458    desc%comm3 = MPI_COMM_NULL
459#endif
460
461    desc%nr1    = 0 ; desc%nr2    = 0 ; desc%nr3    = 0
462    desc%nr1x   = 0 ; desc%nr2x   = 0 ; desc%nr3x   = 0
463
464    desc%grid_id = 0
465
466  END SUBROUTINE fft_type_deallocate
467
468!=----------------------------------------------------------------------------=!
469
470  SUBROUTINE fft_type_set( desc, nst, ub, lb, idx, in1, in2, ncp, ncpw, ngp, ngpw, st, stw, nmany )
471
472    TYPE (fft_type_descriptor) :: desc
473
474    INTEGER, INTENT(in) :: nst              ! total number of stiks
475    INTEGER, INTENT(in) :: ub(3), lb(3)     ! upper and lower bound of real space indices
476    INTEGER, INTENT(in) :: idx(:)           ! sorting index of the sticks
477    INTEGER, INTENT(in) :: in1(:)           ! x-index of a stick
478    INTEGER, INTENT(in) :: in2(:)           ! y-index of a stick
479    INTEGER, INTENT(in) :: ncp(:)           ! number of rho  columns per processor
480    INTEGER, INTENT(in) :: ncpw(:)          ! number of wave columns per processor
481    INTEGER, INTENT(in) :: ngp(:)           ! number of rho  G-vectors per processor
482    INTEGER, INTENT(in) :: ngpw(:)          ! number of wave G-vectors per processor
483    INTEGER, INTENT(in) :: st( lb(1) : ub(1), lb(2) : ub(2) )   ! stick owner of a given rho stick
484    INTEGER, INTENT(in) :: stw( lb(1) : ub(1), lb(2) : ub(2) )  ! stick owner of a given wave stick
485    INTEGER, INTENT(in) :: nmany            ! number of FFT bands
486
487    INTEGER :: nsp( desc%nproc ), nsw_tg, nr1w_tg
488    INTEGER :: np, nq, i, is, iss, i1, i2, m1, m2, ip
489    INTEGER :: ncpx, nr1px, nr2px, nr3px
490    INTEGER :: nr1, nr2, nr3    ! size of real space grid
491    INTEGER :: nr1x, nr2x, nr3x ! padded size of real space grid
492    INTEGER :: ierr
493     !write (6,*) ' inside fft_type_set' ; FLUSH(6)
494    !
495    !
496#if defined(__MPI)
497#if defined(__FFT_OPENMP_TASKS)
498    IF (nmany > OMP_GET_MAX_THREADS()) THEN
499       DO i=1, SIZE(desc%comm2s)
500          IF (desc%comm2s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm2s(i), ierr )
501          IF (desc%comm3s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm3s(i), ierr )
502       ENDDO
503       DEALLOCATE( desc%comm2s )
504       DEALLOCATE( desc%comm3s )
505       ALLOCATE( desc%comm2s( nmany ))
506       ALLOCATE( desc%comm3s( nmany ))
507       DO i=1, nmany
508          CALL MPI_COMM_DUP(desc%comm2, desc%comm2s(i), ierr)
509          CALL MPI_COMM_DUP(desc%comm3, desc%comm3s(i), ierr)
510       ENDDO
511       !ELSEIF (nmany == 1) THEN
512       !  DO i=1, SIZE(desc%comm2s)
513       !     IF (desc%comm2s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm2s(i), ierr )
514       !     IF (desc%comm3s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm3s(i), ierr )
515       !  ENDDO
516       !  DEALLOCATE( desc%comm2s )
517       !  DEALLOCATE( desc%comm3s )
518       ENDIF
519#endif
520#endif
521    IF (.NOT. ALLOCATED( desc%nsp ) ) &
522        CALL fftx_error__(' fft_type_set ', ' fft arrays not yet allocated ', 1 )
523
524    IF ( desc%nr1 == 0 .OR. desc%nr2 == 0 .OR. desc%nr3 == 0 ) &
525        CALL fftx_error__(' fft_type_set ', ' fft dimensions not yet set ', 1 )
526
527    !  Set fft actual and leading dimensions to be used internally
528
529    nr1  = desc%nr1  ; nr2  = desc%nr2  ; nr3  = desc%nr3
530    nr1x = desc%nr1x ; nr2x = desc%nr2x ; nr3x = desc%nr3x
531
532    IF( ( nr1 > nr1x ) .or. ( nr2 > nr2x ) .or. ( nr3 > nr3x ) ) &
533      CALL fftx_error__( ' fft_type_set ', ' wrong fft dimensions ', 1 )
534
535    IF( ( size( desc%ngl ) < desc%nproc ) .or.  ( size( desc%iss ) < desc%nproc ) .or. &
536        ( size( desc%nr2p ) < desc%nproc2 ) .or. ( size( desc%i0r2p ) < desc%nproc2 ) .or. &
537        ( size( desc%nr3p ) < desc%nproc3 ) .or. ( size( desc%i0r3p ) < desc%nproc3 ) ) &
538      CALL fftx_error__( ' fft_type_set ', ' wrong descriptor dimensions ', 2 )
539
540    IF( ( size( idx ) < nst ) .or. ( size( in1 ) < nst ) .or. ( size( in2 ) < nst ) ) &
541      CALL fftx_error__( ' fft_type_set ', ' wrong number of stick dimensions ', 3 )
542
543    IF( ( size( ncp ) < desc%nproc ) .or. ( size( ngp ) < desc%nproc ) ) &
544      CALL fftx_error__( ' fft_type_set ', ' wrong stick dimensions ', 4 )
545
546    !  Set the number of "Y" values for each processor in the nproc2 group
547    np = nr2 / desc%nproc2
548    nq = nr2 - np * desc%nproc2
549    desc%nr2p(1:desc%nproc2) = np    ! assign a base value to all processors of the nproc2 group
550    DO i =1, nq ! assign an extra unit to the first nq processors of the nproc2 group
551       desc%nr2p(i) = np + 1
552    ENDDO
553    ! set the offset
554    desc%nr2p_offset(1) = 0
555    DO i =1, desc%nproc2-1
556       desc%nr2p_offset(i+1) = desc%nr2p_offset(i) + desc%nr2p(i)
557    ENDDO
558    !-- my_nr2p is the number of planes per processor of this processor   in the Y group
559    desc%my_nr2p = desc%nr2p( desc%mype2 + 1 )
560
561    !  Find out the index of the starting plane on each proc
562    desc%i0r2p = 0
563    DO i = 2, desc%nproc2
564       desc%i0r2p(i) = desc%i0r2p(i-1) + desc%nr2p(i-1)
565    ENDDO
566    !-- my_i0r2p is the index-offset of the starting plane of this processor  in the Y group
567    desc%my_i0r2p = desc%i0r2p( desc%mype2 + 1 )
568
569    !  Set the number of "Z" values for each processor in the nproc3 group
570    np = nr3 / desc%nproc3
571    nq = nr3 - np * desc%nproc3
572    desc%nr3p(1:desc%nproc3) = np    ! assign a base value to all processors
573    DO i =1, nq ! assign an extra unit to the first nq processors of the nproc3 group
574       desc%nr3p(i) = np + 1
575    END DO
576    ! set the offset
577    desc%nr3p_offset(1) = 0
578    DO i =1, desc%nproc3-1
579       desc%nr3p_offset(i+1) = desc%nr3p_offset(i) + desc%nr3p(i)
580    ENDDO
581    !-- my_nr3p is the number of planes per processor of this processor   in the Z group
582    desc%my_nr3p = desc%nr3p( desc%mype3 + 1 )
583
584    !  Find out the index of the starting plane on each proc
585    desc%i0r3p  = 0
586    DO i = 2, desc%nproc3
587       desc%i0r3p( i )  = desc%i0r3p( i-1 ) + desc%nr3p ( i-1 )
588    ENDDO
589    !-- my_i0r3p is the index-offset of the starting plane of this processor  in the Z group
590    desc%my_i0r3p = desc%i0r3p( desc%mype3 + 1 )
591
592    ! dimension of the xy plane. see ncplane
593
594    desc%nnp  = nr1x * nr2x
595
596!!!!!!!
597
598    desc%ngl( 1:desc%nproc )  = ngp( 1:desc%nproc )  ! local number of g vectors (rho) per processor
599    desc%nwl( 1:desc%nproc )  = ngpw( 1:desc%nproc ) ! local number of g vectors (wave) per processor
600
601    IF( size( desc%isind ) < ( nr1x * nr2x ) ) &
602      CALL fftx_error__( ' fft_type_set ', ' wrong descriptor dimensions, isind ', 5 )
603
604    IF( size( desc%iplp ) < ( nr1x ) .or. size( desc%iplw ) < ( nr1x ) ) &
605      CALL fftx_error__( ' fft_type_set ', ' wrong descriptor dimensions, ipl ', 5 )
606
607    IF( desc%my_nr3p == 0 .and. ( .not. desc%use_pencil_decomposition ) ) &
608      CALL fftx_error__( ' fft_type_set ', &
609                         ' there are processes with no planes. Use pencil decomposition (-pd .true.) ', 6 )
610
611    !
612    !  1. Temporarily store in the array "desc%isind" the index of the processor
613    !     that own the corresponding stick (index of proc starting from 1)
614    !  2. Set the array elements of  "desc%iplw" and "desc%iplp" to one
615    !     for that index corresponding to YZ planes containing at least one stick
616    !     this are used in the FFT transform along Y
617    !
618
619    desc%isind = 0  ! will contain the +ve or -ve of the processor number, if any, that owns the stick
620    desc%iplp  = 0  ! if > 0 is the nporc2 processor owning this ( potential ) X active plane
621    desc%iplw  = 0  ! if > 0 is the nproc2 processor owning this ( wave func ) X active value
622
623    !  Set nst to the proper number of sticks (the total number of 1d fft along z to be done)
624
625    desc%nst = 0
626    DO iss = 1, SIZE( idx )
627      is = idx( iss )
628      IF( is < 1 ) CYCLE
629      i1 = in1( is )
630      i2 = in2( is )
631      IF( st( i1, i2 ) > 0 ) THEN
632        desc%nst = desc%nst + 1
633        m1 = i1 + 1; IF ( m1 < 1 ) m1 = m1 + nr1
634        m2 = i2 + 1; IF ( m2 < 1 ) m2 = m2 + nr2
635        IF( stw( i1, i2 ) > 0 ) THEN
636          desc%isind( m1 + ( m2 - 1 ) * nr1x ) =  st( i1, i2 )
637          desc%iplw( m1 ) = desc%iproc2(st(i1,i2))
638        ELSE
639          desc%isind( m1 + ( m2 - 1 ) * nr1x ) = -st( i1, i2 )
640        ENDIF
641        desc%iplp( m1 ) = desc%iproc2(st(i1,i2))
642        IF( desc%lgamma ) THEN
643          IF( i1 /= 0 .OR. i2 /= 0 ) desc%nst = desc%nst + 1
644          m1 = -i1 + 1; IF ( m1 < 1 ) m1 = m1 + nr1
645          m2 = -i2 + 1; IF ( m2 < 1 ) m2 = m2 + nr2
646          IF( stw( -i1, -i2 ) > 0 ) THEN
647            desc%isind( m1 + ( m2 - 1 ) * nr1x ) =  st( -i1, -i2 )
648            desc%iplw( m1 ) = desc%iproc2(st(-i1,-i2))
649          ELSE
650            desc%isind( m1 + ( m2 - 1 ) * nr1x ) = -st( -i1, -i2 )
651          ENDIF
652          desc%iplp( m1 ) = desc%iproc2(st(-i1,-i2))
653        ENDIF
654      ENDIF
655    ENDDO
656    do m1=1,desc%nr1x
657       if (desc%iplw(m1)>0) then
658          if (desc%iplp(m1) /= desc%iplw(m1) )  then
659             write (6,*) 'WRONG iplp/iplw arrays'
660             write (6,*) desc%iplp
661             write (6,*) desc%iplw
662             CALL fftx_error__( ' fft_type_set ', ' iplp is wrong ', m1 )
663          end if
664       end if
665    end do
666    ! count how many active X values per each nproc2 processor and set the incremental index of this one
667
668    ! wave func X values first
669    desc%nr1w = 0 ; desc%ir1w = 0 ; desc%indw = 0
670    nr1w_tg = 0 ; desc%ir1w_tg = 0 ; desc%indw_tg = 0
671    do i1 = 1, nr1
672       if (desc%iplw(i1) > 0 ) then
673          desc%nr1w(desc%iplw(i1)) =  desc%nr1w(desc%iplw(i1)) + 1
674          desc%indw(desc%nr1w(desc%iplw(i1)),desc%iplw(i1)) = i1
675          nr1w_tg = nr1w_tg + 1 ; desc%ir1w_tg(i1) = nr1w_tg ; desc%indw_tg(nr1w_tg) = i1
676       end if
677       if (desc%iplw(i1) == desc%mype2 +1) desc%ir1w(i1) = desc%nr1w(desc%iplw(i1))
678    end do
679    desc%nr1w_tg = nr1w_tg ! this is useful in task group ffts
680
681    ! then potential X values
682    desc%nr1p = desc%nr1w ; desc%ir1p=desc%ir1w ; desc%indp = desc%indw
683    do i1 = 1, nr1
684       if ( (desc%iplw(i1) > 0) .and. (desc%iplp(i1) == 0) ) &
685             CALL fftx_error__( ' fft_type_set ', ' bad distribution of X values ', i1 )
686       if ( (desc%iplw(i1) > 0) ) cycle ! this X value has already been taken care of
687
688       if (desc%iplp(i1) > 0 ) then
689          desc%nr1p(desc%iplp(i1)) =  desc%nr1p(desc%iplp(i1)) + 1
690          desc%indp(desc%nr1p(desc%iplp(i1)),desc%iplp(i1)) = i1
691       end if
692       if (desc%iplp(i1) == desc%mype2+1) desc%ir1p(i1) = desc%nr1p(desc%iplp(i1))
693    end do
694
695    !
696    !  Compute for each proc the global index ( starting from 0 ) of the first
697    !  local stick ( desc%iss )
698    !
699
700    DO i = 1, desc%nproc
701      IF( i == 1 ) THEN
702        desc%iss( i ) = 0
703      ELSE
704        desc%iss( i ) = desc%iss( i - 1 ) + ncp( i - 1 )
705      ENDIF
706    ENDDO
707
708    ! iss(1:nproc) is the index offset of the first column of a given processor
709
710    IF( size( desc%ismap ) < ( nst ) ) &
711      CALL fftx_error__( ' fft_type_set ', ' wrong descriptor dimensions ', 6 )
712
713    !
714    !  1. Set the array desc%ismap which maps stick indexes to
715    !     position in the plane  ( iss )
716    !  2. Re-set the array "desc%isind",  that maps position
717    !     in the plane with stick indexes (it is the inverse of desc%ismap )
718    !
719
720    !  wave function sticks first
721
722    desc%ismap = 0     ! will be the global xy stick index in the global list of processor-ordered sticks
723    nsp        = 0     ! will be the number of sticks of a given processor
724    DO iss = 1, size( desc%isind )
725      ip = desc%isind( iss ) ! processor that owns iss wave stick. if it's a rho stick it's negative !
726      IF( ip > 0 ) THEN ! only operates on wave sticks
727        nsp( ip ) = nsp( ip ) + 1
728        desc%ismap( nsp( ip ) + desc%iss( ip ) ) = iss
729        IF( ip == ( desc%mype + 1 ) ) THEN
730          desc%isind( iss ) = nsp( ip ) ! isind is OVERWRITTEN as the ordered index in this processor stick list
731        ELSE
732          desc%isind( iss ) = 0         ! zero otherwise...
733        ENDIF
734      ENDIF
735    ENDDO
736
737    !  check number of sticks against the input value
738
739    IF( any( nsp( 1:desc%nproc ) /= ncpw( 1:desc%nproc ) ) ) THEN
740      DO ip = 1, desc%nproc
741        WRITE( stdout,*)  ' * ', ip, ' * ', nsp( ip ), ' /= ', ncpw( ip )
742      ENDDO
743      CALL fftx_error__( ' fft_type_set ', ' inconsistent number of sticks ', 7 )
744    ENDIF
745
746    desc%nsw( 1:desc%nproc ) = nsp( 1:desc%nproc )  ! -- number of wave sticks per processor
747    DO ip=1, desc%nproc3
748       desc%nsw_offset(1,ip) = 0
749       DO i=1, desc%nproc2-1
750          desc%nsw_offset(i+1,ip) = desc%nsw_offset(i,ip) + desc%nsw(desc%iproc(i,ip))
751       ENDDO
752    ENDDO
753
754    ! -- number of wave sticks per processor for task group ffts
755    desc%nsw_tg( 1:desc%nproc ) = 0
756    do ip =1, desc%nproc3
757       nsw_tg = sum(desc%nsw(desc%iproc(1:desc%nproc2,ip)))
758       desc%nsw_tg(desc%iproc(1:desc%nproc2,ip)) = nsw_tg
759    end do
760
761    !  then add pseudopotential stick
762
763    DO iss = 1, size( desc%isind )
764      ip = desc%isind( iss ) ! -ve of processor that owns iss rho stick. if it was a wave stick it's something non negative !
765      IF( ip < 0 ) THEN
766        nsp( -ip ) = nsp( -ip ) + 1
767        desc%ismap( nsp( -ip ) + desc%iss( -ip ) ) = iss
768        IF( -ip == ( desc%mype + 1 ) ) THEN
769          desc%isind( iss ) = nsp( -ip ) ! isind is OVERWRITTEN as the ordered index in this processor stick list
770        ELSE
771          desc%isind( iss ) = 0         ! zero otherwise...
772        ENDIF
773      ENDIF
774    ENDDO
775
776    !  check number of sticks against the input value
777
778    IF( any( nsp( 1:desc%nproc ) /= ncp( 1:desc%nproc ) ) ) THEN
779      DO ip = 1, desc%nproc
780        WRITE( stdout,*)  ' * ', ip, ' * ', nsp( ip ), ' /= ', ncp( ip )
781      ENDDO
782      CALL fftx_error__( ' fft_type_set ', ' inconsistent number of sticks ', 8 )
783    ENDIF
784
785    desc%nsp( 1:desc%nproc ) = nsp( 1:desc%nproc ) ! -- number of rho sticks per processor
786    DO ip=1, desc%nproc3
787       desc%nsp_offset(1,ip) = 0
788       DO i=1, desc%nproc2-1
789          desc%nsp_offset(i+1,ip) = desc%nsp_offset(i,ip) + desc%nsp(desc%iproc(i,ip))
790       ENDDO
791    ENDDO
792
793    IF( .NOT. desc%lpara ) THEN
794
795       desc%isind = 0
796       desc%iplw  = 0
797       desc%iplp  = 1
798
799       ! here we are setting parameter as if we were in a serial code,
800       ! sticks are along X dimension and not along Z
801       desc%nsp(1) = 0
802       desc%nsw(1) = 0
803       DO i1 = lb( 1 ), ub( 1 )
804         DO i2 = lb( 2 ), ub( 2 )
805           m1 = i1 + 1; IF ( m1 < 1 ) m1 = m1 + nr1
806           m2 = i2 + 1; IF ( m2 < 1 ) m2 = m2 + nr2
807           IF( st( i1, i2 ) > 0 ) THEN
808             desc%nsp(1) = desc%nsp(1) + 1
809           END IF
810           IF( stw( i1, i2 ) > 0 ) THEN
811             desc%nsw(1) = desc%nsw(1) + 1
812             desc%isind( m1 + ( m2 - 1 ) * nr1x ) =  1  ! st( i1, i2 )
813             desc%iplw( m1 ) = 1
814           ENDIF
815         ENDDO
816       ENDDO
817       !
818       ! if we are in a parallel run, but would like to use serial FFT, all
819       ! tasks must have the same parameters as if serial run.
820       !
821       desc%nnr  = nr1x * nr2x * nr3x
822       desc%nnp  = nr1x * nr2x
823       desc%my_nr2p = nr2 ;  desc%nr2p = nr2 ;  desc%i0r2p = 0
824       desc%my_nr3p = nr3 ;  desc%nr3p = nr3 ;  desc%i0r3p = 0
825       desc%nsw = desc%nsw(1)
826       desc%nsp = desc%nsp(1)
827       desc%ngl  = SUM(ngp)
828       desc%nwl  = SUM(ngpw)
829       !
830    END IF
831
832    !write (6,*) 'fft_type_set SUMMARY'
833    !write (6,*) 'desc%mype ', desc%mype
834    !write (6,*) 'desc%mype2', desc%mype2
835    !write (6,*) 'desc%mype3', desc%mype3
836    !write (6,*) 'nr1  nr2  nr3  dimensions : ', desc%nr1, desc%nr2, desc%nr3
837    !write (6,*) 'nr1x nr2x nr3x dimensions : ', desc%nr1x, desc%nr2x, desc%nr3x
838    !write (6,*) 'nr3p arrays'
839    !write (6,*) desc%nr3p
840    !write (6,*) 'i0r3p arrays'
841    !write (6,*) desc%i0r3p
842    !write (6,*) 'nr2p arrays'
843    !write (6,*) desc%nr2p
844    !write (6,*) 'i0r2p arrays'
845    !write (6,*) desc%i0r2p
846    !write (6,*) 'nsp/nsw arrays'
847    !write (6,*) desc%nsp
848    !write (6,*) desc%nsw
849    !write (6,*) 'nr1p/nr1w arrays'
850    !write (6,*) desc%nr1p
851    !write (6,*) desc%nr1w
852    !write (6,*) 'ir1p/ir1w arrays'
853    !write (6,*) desc%ir1p
854    !write (6,*) desc%ir1w
855    !write (6,*) 'indp/indw arrays'
856    !write (6,*) desc%indp
857    !write (6,*) desc%indw
858    !write (6,*) 'iplp/iplw arrays'
859    !write (6,*) desc%iplp
860    !write (6,*) desc%iplw
861
862    !  Finally set fft local workspace dimension
863
864    nr1px = MAXVAL( desc%nr1p( 1:desc%nproc2 ) )  ! maximum number of X values per processor in the nproc2 group
865    nr2px = MAXVAL( desc%nr2p( 1:desc%nproc2 ) )  ! maximum number of planes per processor in the nproc2 group
866    nr3px = MAXVAL( desc%nr3p( 1:desc%nproc3 ) )  ! maximum number of planes per processor in the nproc3 group
867    ncpx  = MAXVAL( ncp( 1:desc%nproc ) ) ! maximum number of columns per processor (use potential sticks to be safe)
868
869    IF ( desc%nproc == 1 ) THEN
870      desc%nnr  = nr1x * nr2x * nr3x
871      desc%nnr_tg = desc%nnr * desc%nproc2
872    ELSE
873      desc%nnr  = max( ncpx * nr3x, nr1x * nr2px * nr3px )  ! this is required to contain the local data in R and G space
874      desc%nnr  = max( desc%nnr, ncpx*nr3px*desc%nproc3, nr1px*nr2px*nr3px*desc%nproc2)  ! this is required to use ALLTOALL instead of ALLTOALLV
875      desc%nnr  = max( 1, desc%nnr ) ! ensure that desc%nrr > 0 ( for extreme parallelism )
876      desc%nnr_tg = desc%nnr * desc%nproc2
877    ENDIF
878
879    !write (6,*) ' nnr bounds'
880    !write (6,*) ' nr1x ',nr1x,' nr2x ', nr2x, ' nr3x ', nr3x
881    !write (6,*) ' nr1x * nr2x * nr3x',nr1x * nr2x * nr3x
882    !write (6,*) ' ncpx ',ncpx,' nr3px ', nr3px, ' desc%nproc3 ', desc%nproc3
883    !write (6,*) ' ncpx * nr3x ',ncpx * nr3x
884    !write (6,*) ' ncpx * nr3px * desc%nproc3 ',ncpx*nr3px*desc%nproc3
885    !write (6,*) ' nr1px ', nr1px,' nr2px ',nr2px,' desc%nproc2 ', desc%nproc2
886    !write (6,*) ' nr1x * nr2px * nr3px ',nr1x * nr2px * nr3px
887    !write (6,*) ' nr1px * nr2px *nr3px * desc%nproc2 ',nr1px*nr2px*nr3px*desc%nproc2
888    !write (6,*) ' desc%nnr ', desc%nnr
889
890    IF( desc%nr3x * desc%nsw( desc%mype + 1 ) > desc%nnr ) &
891        CALL fftx_error__( ' task_groups_init ', ' inconsistent desc%nnr ', 1 )
892    desc%tg_snd(1)  = desc%nr3x * desc%nsw( desc%mype + 1 )
893    desc%tg_rcv(1)  = desc%nr3x * desc%nsw( desc%iproc(1,desc%mype3+1) )
894    desc%tg_sdsp(1) = 0
895    desc%tg_rdsp(1) = 0
896    DO i = 2, desc%nproc2
897       desc%tg_snd(i)  = desc%nr3x * desc%nsw( desc%mype + 1 )
898       desc%tg_rcv(i)  = desc%nr3x * desc%nsw( desc%iproc(i,desc%mype3+1) )
899       desc%tg_sdsp(i) = desc%tg_sdsp(i-1) + desc%nnr
900       desc%tg_rdsp(i) = desc%tg_rdsp(i-1) + desc%tg_rcv(i-1)
901    ENDDO
902
903#if defined(__CUDA)
904    desc%ismap_d = desc%ismap
905    desc%ir1p_d = desc%ir1p
906    desc%ir1w_d = desc%ir1w
907    desc%ir1w_tg_d = desc%ir1w_tg
908
909    desc%indp_d = desc%indp
910    desc%indw_d = desc%indw
911    desc%indw_tg_d(:,1) = desc%indw_tg
912
913    desc%nr1p_d = desc%nr1p
914    desc%nr1w_d = desc%nr1w
915    desc%nr1w_tg_d(1) = desc%nr1w_tg
916
917#endif
918    IF (nmany > 1) ALLOCATE(desc%aux(nmany * desc%nnr))
919
920    RETURN
921
922  END SUBROUTINE fft_type_set
923
924!=----------------------------------------------------------------------------=!
925
SUBROUTINE fft_type_init( dfft, smap, pers, lgamma, lpara, comm, at, bg, gcut_in, dual_in, &
                          fft_fact, nyfft, nmany, use_pd )
   !
   !! Builds the FFT descriptor dfft and the stick map smap for one FFT
   !! personality. The work is done in this order:
   !!
   !!  1. the potential cutoff (gcut) and the wave-function cutoff (gkcut)
   !!     are derived from gcut_in and the dual;
   !!  2. the descriptor is allocated with fft_type_allocate when dfft%nsp is
   !!     not allocated yet, otherwise comm must coincide with dfft%comm;
   !!  3. the pencil-decomposition flag is updated (if use_pd is present) and
   !!     checked against nyfft;
   !!  4. the stick map is allocated, the wave-function and potential sticks
   !!     are collected with get_sticks and handed over to fft_type_set;
   !!  5. dfft%ngw and dfft%ngm are set and compared with the counts returned
   !!     by get_sticks.
   !!
   !! Every inconsistency is reported through fftx_error__.
   !
   USE stick_base
   !
   TYPE (fft_type_descriptor), INTENT(INOUT) :: dfft
   !! descriptor to be initialized
   TYPE (sticks_map), INTENT(INOUT) :: smap
   !! stick map associated with the descriptor
   CHARACTER(LEN=*), INTENT(IN) :: pers
   !! fft personality: 'rho' or 'wave'
   LOGICAL, INTENT(IN) :: lpara
   !! copied into dfft%lpara: parallel FFT, or serial FFT even in a parallel build
   LOGICAL, INTENT(IN) :: lgamma
   !! Gamma-symmetry request, forwarded to sticks_map_allocate
   INTEGER, INTENT(IN) :: comm
   !! communicator; must match dfft%comm if dfft is already allocated
   REAL(DP), INTENT(IN) :: gcut_in
   !! cutoff of the chosen personality (gcut for 'rho', gkcut for 'wave')
   REAL(DP), INTENT(IN) :: bg(3,3)
   REAL(DP), INTENT(IN) :: at(3,3)
   REAL(DP), OPTIONAL, INTENT(IN) :: dual_in
   !! ratio gcut/gkcut; fft_dual is used when absent
   INTEGER, INTENT(IN), OPTIONAL :: fft_fact(3)
   !! forwarded as is to fft_type_allocate
   INTEGER, INTENT(IN) :: nyfft
   INTEGER, INTENT(IN) :: nmany
   !! forwarded to fft_type_set
   LOGICAL, OPTIONAL, INTENT(IN) :: use_pd
   !! whether to use pencil decomposition
   !
   ! ... potential (dense) sticks
   !
   INTEGER, ALLOCATABLE :: st(:,:)
   ! st(i,j)  = number of G-vectors in the stick with x,y Miller indices i,j
   INTEGER, ALLOCATABLE :: nstp(:), sstp(:)
   ! nstp(ip) = number of sticks of processor ip
   ! sstp(ip) = sum of the stick lengths of processor ip
   !            = number of G-vectors owned by processor ip
   INTEGER :: nst
   ! local number of sticks
   !
   ! ... wave-function sticks (same meaning as above)
   !
   INTEGER, ALLOCATABLE :: stw(:,:)
   INTEGER, ALLOCATABLE :: nstpw(:), sstpw(:)
   INTEGER :: nstw
   !
   REAL(DP) :: gcut, gkcut, dual
   INTEGER  :: ngm, ngw
   !
   ! ... step 1: cutoffs
   !
   dual = fft_dual
   IF( PRESENT( dual_in ) ) dual = dual_in
   !
   SELECT CASE( pers )
   CASE( 'rho' )
      gcut  = gcut_in
      gkcut = gcut / dual
   CASE( 'wave' )
      gkcut = gcut_in
      gcut  = gkcut * dual
   CASE DEFAULT
      CALL fftx_error__(' fft_type_init ', ' unknown FFT personality ', 1 )
   END SELECT
   !
   ! ... step 2: allocate the descriptor, or check that it is compatible
   !
   IF( .NOT. ALLOCATED( dfft%nsp ) ) THEN
      CALL fft_type_allocate( dfft, at, bg, gcut, comm, fft_fact=fft_fact, nyfft=nyfft )
   ELSE IF( dfft%comm /= comm ) THEN
      CALL fftx_error__(' fft_type_init ', ' FFT already allocated with a different communicator ', 1 )
   END IF
   !
   ! ... step 3: decomposition flag
   !
   IF ( PRESENT (use_pd) ) dfft%use_pencil_decomposition = use_pd
   IF ( ( nyfft > 1 ) .AND. ( .NOT. dfft%use_pencil_decomposition ) ) &
      CALL fftx_error__(' fft_type_init ', ' Slab decomposition and task groups not implemented. ', 1 )
   !
   ! ... step 4: stick map and stick distribution
   ! ... dfft%lpara has to be set before sticks_map_allocate, which reads it
   !
   dfft%lpara = lpara
   !
   CALL sticks_map_allocate( smap, lgamma, dfft%lpara, dfft%nproc2, &
        dfft%iproc, dfft%iproc2, dfft%nr1, dfft%nr2, dfft%nr3, bg, dfft%comm )
   !
   dfft%lgamma = smap%lgamma ! .TRUE. if the grid has Gamma symmetry
   !
   ALLOCATE( stw( smap%lb(1):smap%ub(1), smap%lb(2):smap%ub(2) ), &
             st ( smap%lb(1):smap%ub(1), smap%lb(2):smap%ub(2) ) )
   ALLOCATE( nstp(smap%nproc), sstp(smap%nproc), nstpw(smap%nproc), sstpw(smap%nproc) )
   !
   ! ... wave-function sticks first (cutoff gkcut), then potential sticks (cutoff gcut)
   !
   CALL get_sticks(  smap, gkcut, nstpw, sstpw, stw, nstw, ngw )
   CALL get_sticks(  smap, gcut,  nstp, sstp, st, nst, ngm )
   !
   CALL fft_type_set( dfft, nst, smap%ub, smap%lb, smap%idx, &
        smap%ist(:,1), smap%ist(:,2), nstp, nstpw, sstp, sstpw, st, stw, nmany )
   !
   ! ... step 5: G-vector counts of this process; with Gamma symmetry the
   ! ... stored counts are (n+1)/2
   !
   dfft%ngw = dfft%nwl( dfft%mype + 1 )
   dfft%ngm = dfft%ngl( dfft%mype + 1 )
   IF( dfft%lgamma ) THEN
      dfft%ngw = (dfft%ngw + 1)/2
      dfft%ngm = (dfft%ngm + 1)/2
   END IF
   !
   IF( dfft%ngw /= ngw ) CALL fftx_error__(' fft_type_init ', ' wrong ngw ', 1 )
   IF( dfft%ngm /= ngm ) CALL fftx_error__(' fft_type_init ', ' wrong ngm ', 1 )
   !
   DEALLOCATE( st, stw )
   DEALLOCATE( nstp, sstp, nstpw, sstpw )
   !
END SUBROUTINE fft_type_init
1046
1047!=----------------------------------------------------------------------------=!
1048
SUBROUTINE realspace_grid_init( dfft, at, bg, gcutm, fft_fact )
   !
   ! ... Sets dfft%nr[123] and the leading dimensions dfft%nr[123]x.
   ! ... dfft%nr[123] are recomputed only if at least one of them is zero
   ! ... on input; otherwise the input values are left unchanged.
   ! ... If fft_fact is present, fft_fact([123]) is forwarded to
   ! ... good_fft_order (per the original note: to force nr[123] to be
   ! ... a multiple of fft_fact([123])).
   !
   USE fft_support, only: good_fft_dimension, good_fft_order
   !
   IMPLICIT NONE
   !
   TYPE(fft_type_descriptor), INTENT(INOUT) :: dfft
   REAL(DP), INTENT(IN) :: at(3,3), bg(3,3)
   REAL(DP), INTENT(IN) :: gcutm
   INTEGER, INTENT(IN), OPTIONAL :: fft_fact(3)
   !
   REAL(DP) :: gmax      ! sqrt(gcutm), i.e. |Gmax|
   REAL(DP) :: alen(3)   ! alen(n) = length of the lattice vector at(:,n)
   !
   IF( dfft%nr1 == 0 .OR. dfft%nr2 == 0 .OR. dfft%nr3 == 0 ) THEN
      !
      ! ... first estimate of nr1,nr2,nr3 from the largest possible n_i in
      ! ... G = i*b_1 + j*b_2 + k*b_3 :  G*a_i = n_i  =>  n_i .le. |Gmax||a_i|
      !
      gmax    = sqrt (gcutm)
      alen(1) = sqrt (at(1, 1)**2 + at(2, 1)**2 + at(3, 1)**2)
      alen(2) = sqrt (at(1, 2)**2 + at(2, 2)**2 + at(3, 2)**2)
      alen(3) = sqrt (at(1, 3)**2 + at(2, 3)**2 + at(3, 3)**2)
      !
      dfft%nr1 = int ( gmax * alen(1) ) + 1
      dfft%nr2 = int ( gmax * alen(2) ) + 1
      dfft%nr3 = int ( gmax * alen(3) ) + 1

#if defined (__DEBUG)
      write (6,*) gmax * alen(1) , dfft%nr1
      write (6,*) gmax * alen(2) , dfft%nr2
      write (6,*) gmax * alen(3) , dfft%nr3
#endif
      !
      ! ... shrink the estimate to what the G-vector sphere actually needs
      !
      CALL grid_set( dfft, bg, gcutm, dfft%nr1, dfft%nr2, dfft%nr3 )
      !
      ! ... then move each size to an order accepted by good_fft_order
      !
      IF ( .NOT. PRESENT(fft_fact) ) THEN
         dfft%nr1 = good_fft_order( dfft%nr1 )
         dfft%nr2 = good_fft_order( dfft%nr2 )
         dfft%nr3 = good_fft_order( dfft%nr3 )
      ELSE
         dfft%nr1 = good_fft_order( dfft%nr1, fft_fact(1) )
         dfft%nr2 = good_fft_order( dfft%nr2, fft_fact(2) )
         dfft%nr3 = good_fft_order( dfft%nr3, fft_fact(3) )
      ENDIF
#if defined (__DEBUG)
   ELSE
      WRITE( stdout, '( /, 3X,"Info: using nr1, nr2, nr3 values from input" )' )
#endif
   END IF
   !
   ! ... leading dimensions: the second one is never padded
   !
   dfft%nr1x  = good_fft_dimension( dfft%nr1 )
   dfft%nr2x  = dfft%nr2
   dfft%nr3x  = good_fft_dimension( dfft%nr3 )

END SUBROUTINE realspace_grid_init
1104
1105!=----------------------------------------------------------------------------=!
1106
SUBROUTINE grid_set( dfft, bg, gcut, nr1, nr2, nr3 )

!  This routine returns in nr1, nr2, nr3 the minimal 3D real-space FFT
!  grid required to fit the G-vector sphere with G^2 < gcut (strict
!  inequality, as tested in the loop below).
!  On input, nr1,nr2,nr3 must be set to values that match or exceed
!  the largest i,j,k (Miller) indices in G(i,j,k) = i*b1 + j*b2 + k*b3
!  On output, nr_n = 2*nb(n) + 1, where nb(n) is the largest |index| found
!  along direction n over all the processes of dfft%comm.
!
!  dfft is only read: nproc and mype are used to share the k planes among
!  the processes, comm is used for the final reduction.
!  ----------------------------------------------

   IMPLICIT NONE

! ... declare arguments
   TYPE(fft_type_descriptor), INTENT(IN) :: dfft
   INTEGER, INTENT(INOUT) :: nr1, nr2, nr3
   REAL(DP), INTENT(IN) :: bg(3,3), gcut

! ... declare other variables
   INTEGER :: i, j, k, nb(3)
#if defined(__MPI)
   INTEGER :: ierr   ! MPI return code, kept apart from the loop indices
#endif
   REAL(DP) :: gsq, g(3)

!  ----------------------------------------------

   nb     = 0

! ... calculate moduli of G vectors and the range of indices where
! ... |G|^2 < gcut (in parallel whenever possible)

   DO k = -nr3, nr3
      !
      ! ... the k planes are dealt round-robin: this process (dfft%mype,
      ! ... compared with a remainder that starts from 0) skips the planes
      ! ... assigned to the other processes
      !
      IF( MOD( k + nr3, dfft%nproc ) /= dfft%mype ) CYCLE
      !
      DO j = -nr2, nr2
         DO i = -nr1, nr1

            g( 1 ) = DBLE(i)*bg(1,1) + DBLE(j)*bg(1,2) + DBLE(k)*bg(1,3)
            g( 2 ) = DBLE(i)*bg(2,1) + DBLE(j)*bg(2,2) + DBLE(k)*bg(2,3)
            g( 3 ) = DBLE(i)*bg(3,1) + DBLE(j)*bg(3,2) + DBLE(k)*bg(3,3)

! ...       calculate modulus

            gsq =  g( 1 )**2 + g( 2 )**2 + g( 3 )**2

            IF( gsq < gcut ) THEN

! ...          calculate maximum index
               nb(1) = MAX( nb(1), ABS( i ) )
               nb(2) = MAX( nb(2), ABS( j ) )
               nb(3) = MAX( nb(3), ABS( k ) )
            END IF

         END DO
      END DO
   END DO

#if defined(__MPI)
! ... each process has seen only its own planes: take the global maximum.
! ... The return code is checked instead of being dropped; its absolute
! ... value is passed so that the error code is positive, like in every
! ... other fftx_error__ call of this module.
   CALL MPI_ALLREDUCE( MPI_IN_PLACE, nb, 3, MPI_INTEGER, MPI_MAX, dfft%comm, ierr )
   IF( ierr /= MPI_SUCCESS ) &
      CALL fftx_error__( ' grid_set ', ' error in MPI_ALLREDUCE ', ABS( ierr ) )
#endif

! ... the size of the 3d FFT matrix depends upon the maximum indices. With
! ... the following choice, the sphere in G-space "touches" its periodic image

   nr1 = 2 * nb(1) + 1
   nr2 = 2 * nb(2) + 1
   nr3 = 2 * nb(3) + 1

   RETURN

END SUBROUTINE grid_set
1177
PURE FUNCTION fft_stick_index( desc, i, j ) RESULT( stick )
   !
   !! Returns the entry of desc%isind that belongs to the (i,j) column of
   !! the grid. i and j may be negative: each is first reduced with
   !! MOD by desc%nr1 (desc%nr2) and, if the remainder is negative, wrapped
   !! by adding desc%nr1 (desc%nr2). The position looked up is
   !! (ix + 1) + iy * desc%nr1x, ix and iy being the wrapped, zero-based
   !! remainders. The meaning of the returned value is whatever
   !! fft_type_set stored in desc%isind.
   !
   IMPLICIT NONE
   TYPE(fft_type_descriptor), INTENT(IN) :: desc
   INTEGER, INTENT(IN) :: i, j
   INTEGER :: stick
   INTEGER :: ix, iy
   !
   ! ... zero-based remainders, moved to the non-negative range
   ix = mod (i, desc%nr1)
   IF (ix < 0) ix = ix + desc%nr1
   iy = mod (j, desc%nr2)
   IF (iy < 0) iy = iy + desc%nr2
   !
   ! ... one-based position in the nr1x-strided plane
   stick = desc%isind ( ix + 1 + iy * desc%nr1x )
   !
END FUNCTION fft_stick_index
1191
1192   !
SUBROUTINE fft_index_to_3d (ir, dfft, i,j,k, offrange)
   !
   !! Converts the local, one-based index ir of a point of the real-space
   !! FFT grid described by dfft into the zero-based triplet i,j,k such that
   !!    r(:,ir)= i*tau(:,1)/n1 + j*tau(:,2)/n2 + k*tau(:,3)/n3
   !! The local array is unpacked as nr1x points per row and my_nr2p rows
   !! per plane; the offsets my_i0r2p and my_i0r3p are then added to
   !! j and k respectively.
   !
   IMPLICIT NONE
   INTEGER, INTENT(IN) :: ir
   !! point in the FFT real-space grid
   TYPE(fft_type_descriptor), INTENT(IN) :: dfft
   !! descriptor for the FFT grid
   INTEGER, INTENT(OUT) :: i
   !! first index of the triplet corresponding to grid point ir
   INTEGER, INTENT(OUT) :: j
   !! second index of the triplet corresponding to grid point ir
   INTEGER, INTENT(OUT) :: k
   !! third index of the triplet corresponding to grid point ir
   LOGICAL, INTENT(OUT) :: offrange
   !! true if computed i,j,k lie outside the physical range of values
   !
   INTEGER :: plane   ! number of points in one local XY plane
   INTEGER :: rest    ! zero-based offset still to be decomposed
   !
   plane = dfft%nr1x * dfft%my_nr2p
   rest  = ir - 1
   !
   ! ... peel off the plane, then the row; what is left is the X index
   k     = rest / plane
   rest  = MOD( rest, plane )
   j     = rest / dfft%nr1x
   i     = MOD( rest, dfft%nr1x )
   !
   ! ... from local to global Y and Z indices
   j     = j + dfft%my_i0r2p
   k     = k + dfft%my_i0r3p
   !
   ! ... points in the padding (nr1x > nr1) or beyond nr2, nr3 are off range
   offrange = .NOT. ( ( i >= 0 .AND. i < dfft%nr1 ) .AND. &
                      ( j >= 0 .AND. j < dfft%nr2 ) .AND. &
                      ( k >= 0 .AND. k < dfft%nr3 ) )
   !
END SUBROUTINE fft_index_to_3d
1226
1227!=----------------------------------------------------------------------------=!
1228END MODULE fft_types
1229!=----------------------------------------------------------------------------=!
1230