!
! Copyright (C) Quantum ESPRESSO group
!
! This file is distributed under the terms of the
! GNU General Public License. See the file `License'
! in the root directory of the present distribution,
! or http://www.gnu.org/copyleft/gpl.txt .
!

! DEV_ATTRIBUTES expands to ", DEVICE" in CUDA builds and to nothing otherwise, so the
! same declaration line can be used for the "_d" components in both kinds of build.
#if defined(__CUDA)
#define DEV_ATTRIBUTES , DEVICE
#else
#define DEV_ATTRIBUTES
#endif

!=----------------------------------------------------------------------------=!
MODULE fft_types
!=----------------------------------------------------------------------------=!

#if defined(__CUDA)
  USE cudafor
#endif
  USE fft_support, ONLY : good_fft_order, good_fft_dimension
  USE fft_param
  USE omp_lib
  IMPLICIT NONE
  PRIVATE
  SAVE

  !
  ! Data type for FFT descriptor.
  !
  ! Array components are allocated by fft_type_allocate and released by
  ! fft_type_deallocate (both in this module).
  !
  TYPE fft_type_descriptor
    !
    ! FFT dimensions
    !
    INTEGER :: nr1 = 0   !
    INTEGER :: nr2 = 0   ! effective FFT dimensions of the 3D grid (global)
    INTEGER :: nr3 = 0   !
    INTEGER :: nr1x = 0  ! FFT grids leading dimensions
    INTEGER :: nr2x = 0  ! dimensions of the arrays for the 3D grid (global)
    INTEGER :: nr3x = 0  ! may differ from nr1, nr2, nr3 in order to boost performance
    !
    ! Parallel layout: in reciprocal space data are organized in columns (sticks) along
    !                  the third direction and distributed across nproc processors.
    !                  In real space data are distributed in blocks comprising sections
    !                  of the Y and Z directions and complete rows in the X direction in
    !                  a matrix of nproc2 x nproc3 processors.
    !                  nproc = nproc2 x nproc3 and additional communicators are introduced
    !                  for data redistribution across matrix columns and rows.
    !
    ! communicators and processor coordinates
    !
    LOGICAL :: lpara = .FALSE.        ! .TRUE. if parallel FFT is active
    LOGICAL :: lgamma = .FALSE.       ! .TRUE. if the grid has Gamma symmetry
    INTEGER :: root = 0               ! root processor
    INTEGER :: comm = MPI_COMM_NULL   ! communicator for the main fft group
    INTEGER :: comm2 = MPI_COMM_NULL  ! communicator for the fft group along the second direction
    INTEGER :: comm3 = MPI_COMM_NULL  ! communicator for the fft group along the third direction
    INTEGER :: nproc = 1              ! number of processors in the main fft group
    INTEGER :: nproc2 = 1             ! number of processors in the fft group along the second direction
    INTEGER :: nproc3 = 1             ! number of processors in the fft group along the third direction
    INTEGER :: mype = 0               ! my processor id (starting from 0) in the fft main communicator
    INTEGER :: mype2 = 0              ! my processor id (starting from 0) in the fft communicator along the second direction (nproc2)
    INTEGER :: mype3 = 0              ! my processor id (starting from 0) in the fft communicator along the third direction (nproc3)

    INTEGER, ALLOCATABLE :: iproc(:,:) , iproc2(:), iproc3(:) ! subcommunicators proc mapping (starting from 1)
    !
    ! FFT distributed data dimensions and indices
    !
    INTEGER :: my_nr3p = 0   ! size of the "Z" section for this processor = nr3p( mype3 + 1 ) ~ nr3/nproc3
    INTEGER :: my_nr2p = 0   ! size of the "Y" section for this processor = nr2p( mype2 + 1 ) ~ nr2/nproc2

    INTEGER :: my_i0r3p = 0  ! offset of the first "Z" element of this proc in the nproc3 group = i0r3p( mype3 + 1 )
    INTEGER :: my_i0r2p = 0  ! offset of the first "Y" element of this proc in the nproc2 group = i0r2p( mype2 + 1 )

    INTEGER, ALLOCATABLE :: nr3p(:)         ! size of the "Z" section of each processor in the nproc3 group along Z
    INTEGER, ALLOCATABLE :: nr3p_offset(:)  ! offset of the "Z" section of each processor in the nproc3 group along Z
    INTEGER, ALLOCATABLE :: nr2p(:)         ! size of the "Y" section of each processor in the nproc2 group along Y
    INTEGER, ALLOCATABLE :: nr2p_offset(:)  ! offset of the "Y" section of each processor in the nproc2 group along Y
    INTEGER, ALLOCATABLE :: nr1p(:)         ! number of active "X" values ( potential ) for a given proc in the nproc2 group
    INTEGER, ALLOCATABLE :: nr1w(:)         ! number of active "X" values ( wave func ) for a given proc in the nproc2 group
    INTEGER :: nr1w_tg                      ! total number of active "X" values ( wave func ). used in task group ffts

    INTEGER, ALLOCATABLE :: i0r3p(:)  ! offset of the first "Z" element of each proc in the nproc3 group (starting from 0)
    INTEGER, ALLOCATABLE :: i0r2p(:)  ! offset of the first "Y" element of each proc in the nproc2 group (starting from 0)

    INTEGER, ALLOCATABLE :: ir1p(:)    ! if >0 ir1p(m1) is the incremental index of the active ( potential ) X value of this proc
    INTEGER, ALLOCATABLE :: indp(:,:)  ! is the inverse of ir1p
    INTEGER, ALLOCATABLE :: ir1w(:)    ! if >0 ir1w(m1) is the incremental index of the active ( wave func ) X value of this proc
    INTEGER, ALLOCATABLE :: indw(:,:)  ! is the inverse of ir1w
    INTEGER, ALLOCATABLE :: ir1w_tg(:) ! if >0 ir1w_tg(m1) is the incremental index of the active ( wfc ) X value in task group
    INTEGER, ALLOCATABLE :: indw_tg(:) ! is the inverse of ir1w_tg

    ! Duplicated versions of the arrays declared above. They are allocated by
    ! fft_type_allocate only in CUDA builds.
    ! NOTE(review): these are POINTER components without "=> NULL()" default
    ! initialization, so their association status is undefined until they are allocated.
    ! NOTE(review): indw_tg_d is rank 2 while indw_tg is rank 1; fft_type_allocate gives
    ! it a second extent of 1.
    INTEGER, POINTER DEV_ATTRIBUTES :: ir1p_d(:), ir1w_d(:), ir1w_tg_d(:)        ! duplicated version of the arrays declared above
    INTEGER, POINTER DEV_ATTRIBUTES :: indp_d(:,:), indw_d(:,:), indw_tg_d(:,:)  !
    INTEGER, POINTER DEV_ATTRIBUTES :: nr1p_d(:), nr1w_d(:), nr1w_tg_d(:)        !

    INTEGER :: nst  ! total number of sticks ( potential )

    INTEGER, ALLOCATABLE :: nsp(:)           ! number of sticks per processor ( potential ) using proc index starting from 1
                                             ! ... that is on proc mype -> nsp( mype + 1 )
    INTEGER, ALLOCATABLE :: nsp_offset(:,:)  ! offset of sticks per processor ( potential )
    INTEGER, ALLOCATABLE :: nsw(:)           ! number of sticks per processor ( wave func ) using proc index as above
    INTEGER, ALLOCATABLE :: nsw_offset(:,:)  ! offset of sticks per processor ( wave func )
    INTEGER, ALLOCATABLE :: nsw_tg(:)        ! number of sticks per processor ( wave func ) using proc index as above. task group version

    INTEGER, ALLOCATABLE :: ngl(:)  ! per proc. no. of non zero charge density/potential components
    INTEGER, ALLOCATABLE :: nwl(:)  ! per proc. no. of non zero wave function plane components

    INTEGER :: ngm  ! my no. of non zero charge density/potential components
                    !    ngm = dfftp%ngl( dfftp%mype + 1 )
                    ! with gamma sym.
                    !    ngm = ( dfftp%ngl( dfftp%mype + 1 ) + 1 ) / 2

    INTEGER :: ngw  ! my no. of non zero wave function plane components
                    !    ngw = dffts%nwl( dffts%mype + 1 )
                    ! with gamma sym.
                    !    ngw = ( dffts%nwl( dffts%mype + 1 ) + 1 ) / 2

    INTEGER, ALLOCATABLE :: iplp(:)  ! if > 0 is the iproc2 processor owning the active "X" value ( potential )
    INTEGER, ALLOCATABLE :: iplw(:)  ! if > 0 is the iproc2 processor owning the active "X" value ( wave func )

    INTEGER :: nnp = 0     ! number of 0 and non 0 sticks in a plane ( ~nr1*nr2/nproc )
    INTEGER :: nnr = 0     ! local number of FFT grid elements ( ~nr1*nr2*nr3/nproc )
                           ! size of the arrays allocated for the FFT, local to each processor:
                           ! in parallel execution may differ from nr1x*nr2x*nr3x
                           ! Not to be confused either with nr1*nr2*nr3
    INTEGER :: nnr_tg = 0  ! local number of grid elements for task group FFT ( ~nr1*nr2*nr3/proc3 )
    INTEGER, ALLOCATABLE :: iss(:)    ! index of the first rho stick on each proc
    INTEGER, ALLOCATABLE :: isind(:)  ! for each position in the plane indicate the stick index
    INTEGER, ALLOCATABLE :: ismap(:)  ! for each stick in the plane indicate the position

    ! counterpart of ismap carrying DEV_ATTRIBUTES (allocated only in CUDA builds)
    INTEGER, POINTER DEV_ATTRIBUTES :: ismap_d(:)

    INTEGER, ALLOCATABLE :: nl(:)   ! position of the G vec in the FFT grid
    INTEGER, ALLOCATABLE :: nlm(:)  ! with gamma sym. position of -G vec in the FFT grid

    INTEGER, POINTER DEV_ATTRIBUTES :: nl_d(:)   ! duplication of the variables defined above
    INTEGER, POINTER DEV_ATTRIBUTES :: nlm_d(:)  !
    !
    ! task group ALLTOALL communication layout
    INTEGER, ALLOCATABLE :: tg_snd(:)   ! number of elements to be sent in task group redistribution
    INTEGER, ALLOCATABLE :: tg_rcv(:)   ! number of elements to be received in task group redistribution
    INTEGER, ALLOCATABLE :: tg_sdsp(:)  ! send displacement for task group A2A communication
    INTEGER, ALLOCATABLE :: tg_rdsp(:)  ! receive displacement for task group A2A communication
    !
    LOGICAL :: has_task_groups = .FALSE.
    LOGICAL :: use_pencil_decomposition = .TRUE.
    !
    CHARACTER(len=12):: rho_clock_label = ' '
    CHARACTER(len=12):: wave_clock_label = ' '

    ! set by fft_type_allocate from the module counter incremental_grid_identifier,
    ! reset to 0 by fft_type_deallocate
    INTEGER :: grid_id
#if defined(__CUDA)
    ! CUDA streams and events; created in fft_type_allocate, destroyed in fft_type_deallocate
    INTEGER(kind=cuda_stream_kind), allocatable, dimension(:) :: stream_scatter_yz
    INTEGER(kind=cuda_stream_kind), allocatable, dimension(:) :: stream_many
    INTEGER :: nstream_many = 16  ! number of entries allocated for stream_many

    INTEGER(kind=cuda_stream_kind) :: a2a_comp
    INTEGER(kind=cuda_stream_kind), allocatable, dimension(:) :: bstreams
    TYPE(cudaEvent), allocatable, dimension(:) :: bevents

    INTEGER :: batchsize = 16    ! how many ffts to batch together
    INTEGER :: subbatchsize = 4  ! size of subbatch for pipelining

#if defined(__IPC)
    INTEGER :: IPC_PEER(16)  ! This is used for IPC that is not implemented yet.
#endif
    INTEGER, ALLOCATABLE :: srh(:,:)  ! Isend/recv handles by subbatch
#endif
    COMPLEX(DP), ALLOCATABLE, DIMENSION(:) :: aux
#if defined(__FFT_OPENMP_TASKS)
    INTEGER, ALLOCATABLE :: comm2s(:)  ! multiple communicator for the fft group along the second direction
    INTEGER, ALLOCATABLE :: comm3s(:)  ! multiple communicator for the fft group along the third direction
#endif
  END TYPE

  REAL(DP) :: fft_dual = 4.0d0
  ! incremented by fft_type_allocate each time a descriptor is allocated
  INTEGER :: incremental_grid_identifier = 0

  PUBLIC :: fft_type_descriptor, fft_type_init
  PUBLIC :: fft_type_allocate, fft_type_deallocate
  PUBLIC :: fft_stick_index, fft_index_to_3d

CONTAINS

!=----------------------------------------------------------------------------=!
!
! Allocation and release of the FFT descriptor arrays
!
  SUBROUTINE fft_type_allocate( desc, at, bg, gcutm, comm, fft_fact, nyfft )
    !
    ! routine allocating arrays of fft_type_descriptor, called by fft_type_init
    !
    ! desc     : descriptor to set up. It is an error if desc%nsp is already allocated.
    ! at, bg   : forwarded unchanged to realspace_grid_init
    ! gcutm    : forwarded unchanged to realspace_grid_init
    ! comm     : communicator of the main fft group (must not be MPI_COMM_NULL in MPI builds)
    ! fft_fact : optional, forwarded unchanged to realspace_grid_init
    ! nyfft    : optional size of the processor group along Y. When present it must be
    !            positive and divide the number of processes of comm; comm is then split
    !            into desc%comm2 (along Y) and desc%comm3 (along Z). When absent the
    !            values already stored in desc%nproc2/nproc3/mype2/mype3 are used as is.
    !
    ! All descriptor arrays allocated here are initialized to zero; the actual layout
    ! is filled in later (see fft_type_set).
    !
    TYPE (fft_type_descriptor), INTENT(INOUT) :: desc
    REAL(DP), INTENT(IN) :: at(3,3), bg(3,3)
    REAL(DP), INTENT(IN) :: gcutm
    INTEGER, INTENT(IN), OPTIONAL :: fft_fact(3)
    INTEGER, INTENT(IN), OPTIONAL :: nyfft
    INTEGER, INTENT(in) :: comm ! mype starting from 0
    INTEGER :: nx, ny, ierr, nzfft, i, nsubbatches
    INTEGER :: mype, root, nproc, iproc, iproc2, iproc3 ! mype starting from 0
    INTEGER :: color, key
    !write (6,*) ' inside fft_type_allocate' ; FLUSH(6)

    IF ( ALLOCATED( desc%nsp ) ) &
       CALL fftx_error__(' fft_type_allocate ', ' fft arrays already allocated ', 1 )

    desc%comm = comm
#if defined(__MPI)
    IF( desc%comm == MPI_COMM_NULL ) THEN
       CALL fftx_error__( ' fft_type_allocate ', ' fft communicator is null ', 1 )
    END IF
#endif
    !
    root = 0 ; mype = 0 ; nproc = 1
#if defined(__MPI)
    CALL MPI_COMM_RANK( comm, mype, ierr )
    CALL MPI_COMM_SIZE( comm, nproc, ierr )
#endif
    desc%root = root ; desc%mype = mype ; desc%nproc = nproc

    IF ( present(nyfft) ) THEN
       ! a non-positive group size would make the MOD below divide by zero
       ! (or give a meaningless split)
       IF ( nyfft < 1 ) &
          CALL fftx_error__( ' fft_type_allocate ', ' nyfft must be positive ', 1 )
       ! check on yfft group dimension
       ! (presumably fftx_error__ does nothing when the code passed is 0 - confirm)
       CALL fftx_error__( ' fft_type_allocate ', ' MOD(nproc,nyfft) .ne. 0 ', MOD(nproc,nyfft) )

!#define ZCOMPACT
#if defined(__MPI)
#if defined(ZCOMPACT)
       !write (6,*) ' FFT IS ZCOMPACT '
       nzfft = nproc / nyfft
       color = mype / nzfft ; key = MOD( mype, nzfft )
#else
       !write (6,*) ' FFT IS YCOMPACT '
       color = MOD( mype, nyfft ) ; key = mype / nyfft
#endif

       ! processes with the same key are in the same group along Y
       CALL MPI_COMM_SPLIT( comm, key, color, desc%comm2, ierr )
       CALL MPI_COMM_RANK( desc%comm2, desc%mype2, ierr )
       CALL MPI_COMM_SIZE( desc%comm2, desc%nproc2, ierr )

       ! processes with the same color are in the same group along Z
       CALL MPI_COMM_SPLIT( comm, color, key, desc%comm3, ierr )
       CALL MPI_COMM_RANK( desc%comm3, desc%mype3, ierr )
       CALL MPI_COMM_SIZE( desc%comm3, desc%nproc3, ierr )
#if defined(__FFT_OPENMP_TASKS)
       ! one duplicate of each sub-communicator per OpenMP thread
       ALLOCATE( desc%comm2s( omp_get_max_threads() ))
       ALLOCATE( desc%comm3s( omp_get_max_threads() ))
       DO i=1, OMP_GET_MAX_THREADS()
          CALL MPI_COMM_DUP(desc%comm2, desc%comm2s(i), ierr)
          CALL MPI_COMM_DUP(desc%comm3, desc%comm3s(i), ierr)
       ENDDO
#endif
#else
       desc%comm2 = desc%comm ; desc%mype2 = desc%mype ; desc%nproc2 = desc%nproc
       desc%comm3 = desc%comm ; desc%mype3 = desc%mype ; desc%nproc3 = desc%nproc
#endif

    ENDIF
    !write (6,*) ' nproc and mype '
    !write (6,*) desc%nproc, desc%nproc2, desc%nproc3
    !write (6,*) desc%mype, desc%mype2, desc%mype3

    ! map between the rank in the main group (1-based) and the (iproc2,iproc3) coordinates
    ALLOCATE ( desc%iproc(desc%nproc2,desc%nproc3), desc%iproc2(desc%nproc), desc%iproc3(desc%nproc) )
    do iproc = 1, desc%nproc
#if defined(ZCOMPACT)
       iproc3 = MOD(iproc-1, desc%nproc3) + 1 ; iproc2 = (iproc-1)/desc%nproc3 + 1
#else
       iproc2 = MOD(iproc-1, desc%nproc2) + 1 ; iproc3 = (iproc-1)/desc%nproc2 + 1
#endif
       desc%iproc2(iproc) = iproc2 ; desc%iproc3(iproc) = iproc3
       desc%iproc(iproc2,iproc3) = iproc
    end do

    ! the sizes desc%nr1x and desc%nr2x read below are taken after this call
    CALL realspace_grid_init( desc, at, bg, gcutm, fft_fact )

    ALLOCATE( desc%nr2p ( desc%nproc2 ), desc%i0r2p( desc%nproc2 ) ) ; desc%nr2p = 0 ; desc%i0r2p = 0
    ALLOCATE( desc%nr2p_offset ( desc%nproc2 ) ) ; desc%nr2p_offset = 0
    ALLOCATE( desc%nr3p ( desc%nproc3 ), desc%i0r3p( desc%nproc3 ) ) ; desc%nr3p = 0 ; desc%i0r3p = 0
    ALLOCATE( desc%nr3p_offset ( desc%nproc3 ) ) ; desc%nr3p_offset = 0

    nx = desc%nr1x
    ny = desc%nr2x

    ALLOCATE( desc%nsp( desc%nproc ) ) ; desc%nsp = 0
    ALLOCATE( desc%nsp_offset( desc%nproc2, desc%nproc3 ) ) ; desc%nsp_offset = 0
    ALLOCATE( desc%nsw( desc%nproc ) ) ; desc%nsw = 0
    ALLOCATE( desc%nsw_offset( desc%nproc2, desc%nproc3 ) ) ; desc%nsw_offset = 0
    ALLOCATE( desc%nsw_tg( desc%nproc ) ) ; desc%nsw_tg = 0
    ALLOCATE( desc%ngl( desc%nproc ) ) ; desc%ngl = 0
    ALLOCATE( desc%nwl( desc%nproc ) ) ; desc%nwl = 0
    ALLOCATE( desc%iss( desc%nproc ) ) ; desc%iss = 0
    ALLOCATE( desc%isind( nx * ny ) ) ; desc%isind = 0
    ALLOCATE( desc%ismap( nx * ny ) ) ; desc%ismap = 0
    ALLOCATE( desc%nr1p( desc%nproc2 ) ) ; desc%nr1p = 0
    ALLOCATE( desc%nr1w( desc%nproc2 ) ) ; desc%nr1w = 0
    ALLOCATE( desc%ir1p( desc%nr1x ) ) ; desc%ir1p = 0
    ALLOCATE( desc%indp( desc%nr1x,desc%nproc2 ) ) ; desc%indp = 0
    ALLOCATE( desc%ir1w( desc%nr1x ) ) ; desc%ir1w = 0
    ALLOCATE( desc%ir1w_tg( desc%nr1x ) ) ; desc%ir1w_tg = 0
    ALLOCATE( desc%indw( desc%nr1x, desc%nproc2 ) ) ; desc%indw = 0
    ALLOCATE( desc%indw_tg( desc%nr1x ) ) ; desc%indw_tg = 0
    ALLOCATE( desc%iplp( nx ) ) ; desc%iplp = 0
    ALLOCATE( desc%iplw( nx ) ) ; desc%iplw = 0

    ALLOCATE( desc%tg_snd( desc%nproc2) ) ; desc%tg_snd = 0
    ALLOCATE( desc%tg_rcv( desc%nproc2) ) ; desc%tg_rcv = 0
    ALLOCATE( desc%tg_sdsp( desc%nproc2) ) ; desc%tg_sdsp = 0
    ALLOCATE( desc%tg_rdsp( desc%nproc2) ) ; desc%tg_rdsp = 0

#if defined(__CUDA)
    ALLOCATE( desc%indp_d( desc%nr1x,desc%nproc2 ) ) ; desc%indp_d = 0
    ALLOCATE( desc%indw_d( desc%nr1x, desc%nproc2 ) ) ; desc%indw_d = 0
    ALLOCATE( desc%indw_tg_d( desc%nr1x, 1 ) ) ; desc%indw_tg_d = 0
    !
    ALLOCATE( desc%nr1p_d(desc%nproc2)) ; desc%nr1p_d = 0
    ALLOCATE( desc%nr1w_d(desc%nproc2)) ; desc%nr1w_d = 0
    ALLOCATE( desc%nr1w_tg_d(1) ) ; desc%nr1w_tg_d = 0

    ALLOCATE( desc%ir1p_d( desc%nr1x ) ) ; desc%ir1p_d = 0
    ALLOCATE( desc%ir1w_d( desc%nr1x ) ) ; desc%ir1w_d = 0
    ALLOCATE( desc%ir1w_tg_d( desc%nr1x ) ) ; desc%ir1w_tg_d = 0
    ALLOCATE( desc%ismap_d( nx * ny ) ) ; desc%ismap_d = 0

    ! Every stream/event creation is checked in the same way, so that a failure is
    ! reported here instead of surfacing later when the handle is first used.
    ALLOCATE ( desc%stream_scatter_yz(desc%nproc3) )
    DO iproc = 1, desc%nproc3
       ierr = cudaStreamCreate(desc%stream_scatter_yz(iproc))
       IF ( ierr /= 0 ) CALL fftx_error__( ' fft_type_allocate ', ' Error creating stream ', iproc )
    END DO
    !
    ALLOCATE ( desc%stream_many(desc%nstream_many) )
    DO i = 1, desc%nstream_many
       ierr = cudaStreamCreate(desc%stream_many(i))
       IF ( ierr /= 0 ) CALL fftx_error__( ' fft_type_allocate ', ' Error creating stream ', i )
    END DO

    ierr = cudaStreamCreate( desc%a2a_comp )
    IF ( ierr /= 0 ) CALL fftx_error__( ' fft_type_allocate ', ' Error creating stream ', 1 )

    ! number of sub-batches needed to cover one batch (rounded up)
    nsubbatches = ceiling(real(desc%batchsize)/desc%subbatchsize)

    ALLOCATE( desc%bstreams( nsubbatches ) )
    ALLOCATE( desc%bevents( nsubbatches ) )
    DO i = 1, nsubbatches
       ierr = cudaStreamCreate( desc%bstreams(i) )
       IF ( ierr /= 0 ) CALL fftx_error__( ' fft_type_allocate ', ' Error creating stream ', i )
       ierr = cudaEventCreate( desc%bevents(i) )
       IF ( ierr /= 0 ) CALL fftx_error__( ' fft_type_allocate ', ' Error creating event ', i )
    ENDDO
    ALLOCATE( desc%srh(2*nproc, nsubbatches))

#endif

    ! give every allocated descriptor a distinct, positive identifier
    incremental_grid_identifier = incremental_grid_identifier + 1
    desc%grid_id = incremental_grid_identifier

  END SUBROUTINE fft_type_allocate

  SUBROUTINE fft_type_deallocate( desc )
    !
    ! Release everything owned by the descriptor: host arrays, (CUDA builds) device
    ! arrays, streams and events, (MPI builds) the sub-communicators created by
    ! fft_type_allocate. The grid dimensions and grid_id are reset to 0 and the
    ! communicator handles to MPI_COMM_NULL. desc%comm itself is not freed.
    !
    ! Every release is guarded, so the routine can be called on a descriptor that
    ! was only partially set up.
    !
    TYPE (fft_type_descriptor), INTENT(INOUT) :: desc
    INTEGER :: i, ierr
    !write (6,*) ' inside fft_type_deallocate' ; FLUSH(6)
    IF ( ALLOCATED( desc%nr2p ) ) DEALLOCATE( desc%nr2p )
    IF ( ALLOCATED( desc%nr2p_offset ) ) DEALLOCATE( desc%nr2p_offset )
    IF ( ALLOCATED( desc%nr3p_offset ) ) DEALLOCATE( desc%nr3p_offset )
    IF ( ALLOCATED( desc%i0r2p ) ) DEALLOCATE( desc%i0r2p )
    IF ( ALLOCATED( desc%nr3p ) ) DEALLOCATE( desc%nr3p )
    IF ( ALLOCATED( desc%i0r3p ) ) DEALLOCATE( desc%i0r3p )
    IF ( ALLOCATED( desc%nsp ) ) DEALLOCATE( desc%nsp )
    IF ( ALLOCATED( desc%nsp_offset ) ) DEALLOCATE( desc%nsp_offset )
    IF ( ALLOCATED( desc%nsw ) ) DEALLOCATE( desc%nsw )
    IF ( ALLOCATED( desc%nsw_offset ) ) DEALLOCATE( desc%nsw_offset )
    IF ( ALLOCATED( desc%nsw_tg ) ) DEALLOCATE( desc%nsw_tg )
    IF ( ALLOCATED( desc%ngl ) ) DEALLOCATE( desc%ngl )
    IF ( ALLOCATED( desc%nwl ) ) DEALLOCATE( desc%nwl )
    IF ( ALLOCATED( desc%iss ) ) DEALLOCATE( desc%iss )
    IF ( ALLOCATED( desc%isind ) ) DEALLOCATE( desc%isind )
    IF ( ALLOCATED( desc%ismap ) ) DEALLOCATE( desc%ismap )
    IF ( ALLOCATED( desc%nr1p ) ) DEALLOCATE( desc%nr1p )
    IF ( ALLOCATED( desc%nr1w ) ) DEALLOCATE( desc%nr1w )
    IF ( ALLOCATED( desc%ir1p ) ) DEALLOCATE( desc%ir1p )
    IF ( ALLOCATED( desc%indp ) ) DEALLOCATE( desc%indp )
    IF ( ALLOCATED( desc%ir1w ) ) DEALLOCATE( desc%ir1w )
    IF ( ALLOCATED( desc%ir1w_tg ) ) DEALLOCATE( desc%ir1w_tg )
    IF ( ALLOCATED( desc%indw ) ) DEALLOCATE( desc%indw )
    IF ( ALLOCATED( desc%indw_tg ) ) DEALLOCATE( desc%indw_tg )
    IF ( ALLOCATED( desc%iplp ) ) DEALLOCATE( desc%iplp )
    IF ( ALLOCATED( desc%iplw ) ) DEALLOCATE( desc%iplw )
    IF ( ALLOCATED( desc%iproc ) ) DEALLOCATE( desc%iproc )
    IF ( ALLOCATED( desc%iproc2 ) ) DEALLOCATE( desc%iproc2 )
    IF ( ALLOCATED( desc%iproc3 ) ) DEALLOCATE( desc%iproc3 )

    IF ( ALLOCATED( desc%tg_snd ) ) DEALLOCATE( desc%tg_snd )
    IF ( ALLOCATED( desc%tg_rcv ) ) DEALLOCATE( desc%tg_rcv )
    IF ( ALLOCATED( desc%tg_sdsp ) ) DEALLOCATE( desc%tg_sdsp )
    IF ( ALLOCATED( desc%tg_rdsp ) ) DEALLOCATE( desc%tg_rdsp )

    IF ( ALLOCATED( desc%nl ) ) DEALLOCATE( desc%nl )
    IF ( ALLOCATED( desc%nlm ) ) DEALLOCATE( desc%nlm )

#if defined(__CUDA)
    ! The "_d" components are POINTERs, so their status must be queried with
    ! ASSOCIATED (ALLOCATED is only defined for allocatable objects). A successful
    ! DEALLOCATE leaves the pointer disassociated.
    ! NOTE(review): the components carry no "=> NULL()" default initialization in the
    ! type, so this assumes they were allocated (or nullified) before the first call.
    IF ( ASSOCIATED( desc%ismap_d ) ) DEALLOCATE( desc%ismap_d )
    IF ( ASSOCIATED( desc%ir1p_d ) ) DEALLOCATE( desc%ir1p_d )
    IF ( ASSOCIATED( desc%ir1w_d ) ) DEALLOCATE( desc%ir1w_d )
    IF ( ASSOCIATED( desc%ir1w_tg_d ) ) DEALLOCATE( desc%ir1w_tg_d )

    IF ( ASSOCIATED( desc%indp_d ) ) DEALLOCATE( desc%indp_d )
    IF ( ASSOCIATED( desc%indw_d ) ) DEALLOCATE( desc%indw_d )
    IF ( ASSOCIATED( desc%indw_tg_d ) ) DEALLOCATE( desc%indw_tg_d )

    IF ( ASSOCIATED( desc%nr1p_d ) ) DEALLOCATE( desc%nr1p_d )
    IF ( ASSOCIATED( desc%nr1w_d ) ) DEALLOCATE( desc%nr1w_d )
    IF ( ASSOCIATED( desc%nr1w_tg_d ) ) DEALLOCATE( desc%nr1w_tg_d )

    ! Loop over the allocated extent of each handle array rather than over
    ! nproc3 / nstream_many / batchsize, which may have been modified since allocation.
    IF ( ALLOCATED( desc%stream_scatter_yz ) ) THEN
       DO i = 1, SIZE( desc%stream_scatter_yz )
          ierr = cudaStreamDestroy( desc%stream_scatter_yz(i) )
       END DO
       DEALLOCATE( desc%stream_scatter_yz )
    END IF
    IF ( ALLOCATED( desc%stream_many ) ) THEN
       DO i = 1, SIZE( desc%stream_many )
          ierr = cudaStreamDestroy( desc%stream_many(i) )
       END DO
       DEALLOCATE( desc%stream_many )
    END IF

    IF ( ASSOCIATED( desc%nl_d ) ) DEALLOCATE( desc%nl_d )
    IF ( ASSOCIATED( desc%nlm_d ) ) DEALLOCATE( desc%nlm_d )
    !
    ! SLAB decomposition
    IF ( ALLOCATED( desc%srh ) ) DEALLOCATE( desc%srh )
    ! return code deliberately ignored, as for the other destroy calls
    ierr = cudaStreamDestroy( desc%a2a_comp )

    IF ( ALLOCATED( desc%bstreams ) ) THEN
       DO i = 1, SIZE( desc%bstreams )
          ierr = cudaStreamDestroy( desc%bstreams(i) )
       ENDDO
       DEALLOCATE( desc%bstreams )
    END IF
    IF ( ALLOCATED( desc%bevents ) ) THEN
       DO i = 1, SIZE( desc%bevents )
          ierr = cudaEventDestroy( desc%bevents(i) )
       ENDDO
       DEALLOCATE( desc%bevents )
    END IF

#endif

    desc%comm = MPI_COMM_NULL
#if defined(__MPI)
    IF (desc%comm2 /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm2, ierr )
    IF (desc%comm3 /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm3, ierr )
#if defined(__FFT_OPENMP_TASKS)
    ! comm2s/comm3s exist only if fft_type_allocate was called with nyfft:
    ! check before taking SIZE or deallocating
    IF ( ALLOCATED( desc%comm2s ) ) THEN
       DO i = 1, SIZE( desc%comm2s )
          IF (desc%comm2s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm2s(i), ierr )
       ENDDO
       DEALLOCATE( desc%comm2s )
    END IF
    IF ( ALLOCATED( desc%comm3s ) ) THEN
       DO i = 1, SIZE( desc%comm3s )
          IF (desc%comm3s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm3s(i), ierr )
       ENDDO
       DEALLOCATE( desc%comm3s )
    END IF
#endif
#else
    desc%comm2 = MPI_COMM_NULL
    desc%comm3 = MPI_COMM_NULL
#endif

    desc%nr1 = 0 ; desc%nr2 = 0 ; desc%nr3 = 0
    desc%nr1x = 0 ; desc%nr2x = 0 ; desc%nr3x = 0

    desc%grid_id = 0

  END SUBROUTINE fft_type_deallocate

!=----------------------------------------------------------------------------=!

  SUBROUTINE fft_type_set( desc, nst, ub, lb, idx, in1, in2, ncp, ncpw, ngp, ngpw, st, stw, nmany )

    TYPE (fft_type_descriptor) :: desc

    INTEGER, INTENT(in) :: nst ! total number of sticks
    INTEGER, INTENT(in) :: ub(3), lb(3) ! upper and lower bound of real space indices
    INTEGER, INTENT(in) :: idx(:) ! sorting index of the sticks
    INTEGER, INTENT(in) :: in1(:) ! x-index of a stick
    INTEGER, INTENT(in) :: in2(:) ! y-index of a stick
    INTEGER, INTENT(in) :: ncp(:) ! number of rho columns per processor
    INTEGER, INTENT(in) :: ncpw(:) ! number of wave columns per processor
    INTEGER, INTENT(in) :: ngp(:) ! number of rho G-vectors per processor
    INTEGER, INTENT(in) :: ngpw(:) ! number of wave G-vectors per processor
    INTEGER, INTENT(in) :: st( lb(1) : ub(1), lb(2) : ub(2) ) ! stick owner of a given rho stick
    INTEGER, INTENT(in) :: stw( lb(1) : ub(1), lb(2) : ub(2) ) ! stick owner of a given wave stick
    INTEGER, INTENT(in) :: nmany !
number of FFT bands 486 487 INTEGER :: nsp( desc%nproc ), nsw_tg, nr1w_tg 488 INTEGER :: np, nq, i, is, iss, i1, i2, m1, m2, ip 489 INTEGER :: ncpx, nr1px, nr2px, nr3px 490 INTEGER :: nr1, nr2, nr3 ! size of real space grid 491 INTEGER :: nr1x, nr2x, nr3x ! padded size of real space grid 492 INTEGER :: ierr 493 !write (6,*) ' inside fft_type_set' ; FLUSH(6) 494 ! 495 ! 496#if defined(__MPI) 497#if defined(__FFT_OPENMP_TASKS) 498 IF (nmany > OMP_GET_MAX_THREADS()) THEN 499 DO i=1, SIZE(desc%comm2s) 500 IF (desc%comm2s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm2s(i), ierr ) 501 IF (desc%comm3s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm3s(i), ierr ) 502 ENDDO 503 DEALLOCATE( desc%comm2s ) 504 DEALLOCATE( desc%comm3s ) 505 ALLOCATE( desc%comm2s( nmany )) 506 ALLOCATE( desc%comm3s( nmany )) 507 DO i=1, nmany 508 CALL MPI_COMM_DUP(desc%comm2, desc%comm2s(i), ierr) 509 CALL MPI_COMM_DUP(desc%comm3, desc%comm3s(i), ierr) 510 ENDDO 511 !ELSEIF (nmany == 1) THEN 512 ! DO i=1, SIZE(desc%comm2s) 513 ! IF (desc%comm2s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm2s(i), ierr ) 514 ! IF (desc%comm3s(i) /= MPI_COMM_NULL) CALL MPI_COMM_FREE( desc%comm3s(i), ierr ) 515 ! ENDDO 516 ! DEALLOCATE( desc%comm2s ) 517 ! DEALLOCATE( desc%comm3s ) 518 ENDIF 519#endif 520#endif 521 IF (.NOT. ALLOCATED( desc%nsp ) ) & 522 CALL fftx_error__(' fft_type_set ', ' fft arrays not yet allocated ', 1 ) 523 524 IF ( desc%nr1 == 0 .OR. desc%nr2 == 0 .OR. desc%nr3 == 0 ) & 525 CALL fftx_error__(' fft_type_set ', ' fft dimensions not yet set ', 1 ) 526 527 ! Set fft actual and leading dimensions to be used internally 528 529 nr1 = desc%nr1 ; nr2 = desc%nr2 ; nr3 = desc%nr3 530 nr1x = desc%nr1x ; nr2x = desc%nr2x ; nr3x = desc%nr3x 531 532 IF( ( nr1 > nr1x ) .or. ( nr2 > nr2x ) .or. ( nr3 > nr3x ) ) & 533 CALL fftx_error__( ' fft_type_set ', ' wrong fft dimensions ', 1 ) 534 535 IF( ( size( desc%ngl ) < desc%nproc ) .or. ( size( desc%iss ) < desc%nproc ) .or. 
& 536 ( size( desc%nr2p ) < desc%nproc2 ) .or. ( size( desc%i0r2p ) < desc%nproc2 ) .or. & 537 ( size( desc%nr3p ) < desc%nproc3 ) .or. ( size( desc%i0r3p ) < desc%nproc3 ) ) & 538 CALL fftx_error__( ' fft_type_set ', ' wrong descriptor dimensions ', 2 ) 539 540 IF( ( size( idx ) < nst ) .or. ( size( in1 ) < nst ) .or. ( size( in2 ) < nst ) ) & 541 CALL fftx_error__( ' fft_type_set ', ' wrong number of stick dimensions ', 3 ) 542 543 IF( ( size( ncp ) < desc%nproc ) .or. ( size( ngp ) < desc%nproc ) ) & 544 CALL fftx_error__( ' fft_type_set ', ' wrong stick dimensions ', 4 ) 545 546 ! Set the number of "Y" values for each processor in the nproc2 group 547 np = nr2 / desc%nproc2 548 nq = nr2 - np * desc%nproc2 549 desc%nr2p(1:desc%nproc2) = np ! assign a base value to all processors of the nproc2 group 550 DO i =1, nq ! assign an extra unit to the first nq processors of the nproc2 group 551 desc%nr2p(i) = np + 1 552 ENDDO 553 ! set the offset 554 desc%nr2p_offset(1) = 0 555 DO i =1, desc%nproc2-1 556 desc%nr2p_offset(i+1) = desc%nr2p_offset(i) + desc%nr2p(i) 557 ENDDO 558 !-- my_nr2p is the number of planes per processor of this processor in the Y group 559 desc%my_nr2p = desc%nr2p( desc%mype2 + 1 ) 560 561 ! Find out the index of the starting plane on each proc 562 desc%i0r2p = 0 563 DO i = 2, desc%nproc2 564 desc%i0r2p(i) = desc%i0r2p(i-1) + desc%nr2p(i-1) 565 ENDDO 566 !-- my_i0r2p is the index-offset of the starting plane of this processor in the Y group 567 desc%my_i0r2p = desc%i0r2p( desc%mype2 + 1 ) 568 569 ! Set the number of "Z" values for each processor in the nproc3 group 570 np = nr3 / desc%nproc3 571 nq = nr3 - np * desc%nproc3 572 desc%nr3p(1:desc%nproc3) = np ! assign a base value to all processors 573 DO i =1, nq ! assign an extra unit to the first nq processors of the nproc3 group 574 desc%nr3p(i) = np + 1 575 END DO 576 ! 
set the offset 577 desc%nr3p_offset(1) = 0 578 DO i =1, desc%nproc3-1 579 desc%nr3p_offset(i+1) = desc%nr3p_offset(i) + desc%nr3p(i) 580 ENDDO 581 !-- my_nr3p is the number of planes per processor of this processor in the Z group 582 desc%my_nr3p = desc%nr3p( desc%mype3 + 1 ) 583 584 ! Find out the index of the starting plane on each proc 585 desc%i0r3p = 0 586 DO i = 2, desc%nproc3 587 desc%i0r3p( i ) = desc%i0r3p( i-1 ) + desc%nr3p ( i-1 ) 588 ENDDO 589 !-- my_i0r3p is the index-offset of the starting plane of this processor in the Z group 590 desc%my_i0r3p = desc%i0r3p( desc%mype3 + 1 ) 591 592 ! dimension of the xy plane. see ncplane 593 594 desc%nnp = nr1x * nr2x 595 596!!!!!!! 597 598 desc%ngl( 1:desc%nproc ) = ngp( 1:desc%nproc ) ! local number of g vectors (rho) per processor 599 desc%nwl( 1:desc%nproc ) = ngpw( 1:desc%nproc ) ! local number of g vectors (wave) per processor 600 601 IF( size( desc%isind ) < ( nr1x * nr2x ) ) & 602 CALL fftx_error__( ' fft_type_set ', ' wrong descriptor dimensions, isind ', 5 ) 603 604 IF( size( desc%iplp ) < ( nr1x ) .or. size( desc%iplw ) < ( nr1x ) ) & 605 CALL fftx_error__( ' fft_type_set ', ' wrong descriptor dimensions, ipl ', 5 ) 606 607 IF( desc%my_nr3p == 0 .and. ( .not. desc%use_pencil_decomposition ) ) & 608 CALL fftx_error__( ' fft_type_set ', & 609 ' there are processes with no planes. Use pencil decomposition (-pd .true.) ', 6 ) 610 611 ! 612 ! 1. Temporarily store in the array "desc%isind" the index of the processor 613 ! that own the corresponding stick (index of proc starting from 1) 614 ! 2. Set the array elements of "desc%iplw" and "desc%iplp" to one 615 ! for that index corresponding to YZ planes containing at least one stick 616 ! this are used in the FFT transform along Y 617 ! 618 619 desc%isind = 0 ! will contain the +ve or -ve of the processor number, if any, that owns the stick 620 desc%iplp = 0 ! if > 0 is the nporc2 processor owning this ( potential ) X active plane 621 desc%iplw = 0 ! 
if > 0 is the nproc2 processor owning this ( wave func ) X active value 622 623 ! Set nst to the proper number of sticks (the total number of 1d fft along z to be done) 624 625 desc%nst = 0 626 DO iss = 1, SIZE( idx ) 627 is = idx( iss ) 628 IF( is < 1 ) CYCLE 629 i1 = in1( is ) 630 i2 = in2( is ) 631 IF( st( i1, i2 ) > 0 ) THEN 632 desc%nst = desc%nst + 1 633 m1 = i1 + 1; IF ( m1 < 1 ) m1 = m1 + nr1 634 m2 = i2 + 1; IF ( m2 < 1 ) m2 = m2 + nr2 635 IF( stw( i1, i2 ) > 0 ) THEN 636 desc%isind( m1 + ( m2 - 1 ) * nr1x ) = st( i1, i2 ) 637 desc%iplw( m1 ) = desc%iproc2(st(i1,i2)) 638 ELSE 639 desc%isind( m1 + ( m2 - 1 ) * nr1x ) = -st( i1, i2 ) 640 ENDIF 641 desc%iplp( m1 ) = desc%iproc2(st(i1,i2)) 642 IF( desc%lgamma ) THEN 643 IF( i1 /= 0 .OR. i2 /= 0 ) desc%nst = desc%nst + 1 644 m1 = -i1 + 1; IF ( m1 < 1 ) m1 = m1 + nr1 645 m2 = -i2 + 1; IF ( m2 < 1 ) m2 = m2 + nr2 646 IF( stw( -i1, -i2 ) > 0 ) THEN 647 desc%isind( m1 + ( m2 - 1 ) * nr1x ) = st( -i1, -i2 ) 648 desc%iplw( m1 ) = desc%iproc2(st(-i1,-i2)) 649 ELSE 650 desc%isind( m1 + ( m2 - 1 ) * nr1x ) = -st( -i1, -i2 ) 651 ENDIF 652 desc%iplp( m1 ) = desc%iproc2(st(-i1,-i2)) 653 ENDIF 654 ENDIF 655 ENDDO 656 do m1=1,desc%nr1x 657 if (desc%iplw(m1)>0) then 658 if (desc%iplp(m1) /= desc%iplw(m1) ) then 659 write (6,*) 'WRONG iplp/iplw arrays' 660 write (6,*) desc%iplp 661 write (6,*) desc%iplw 662 CALL fftx_error__( ' fft_type_set ', ' iplp is wrong ', m1 ) 663 end if 664 end if 665 end do 666 ! count how many active X values per each nproc2 processor and set the incremental index of this one 667 668 ! 
wave func X values first 669 desc%nr1w = 0 ; desc%ir1w = 0 ; desc%indw = 0 670 nr1w_tg = 0 ; desc%ir1w_tg = 0 ; desc%indw_tg = 0 671 do i1 = 1, nr1 672 if (desc%iplw(i1) > 0 ) then 673 desc%nr1w(desc%iplw(i1)) = desc%nr1w(desc%iplw(i1)) + 1 674 desc%indw(desc%nr1w(desc%iplw(i1)),desc%iplw(i1)) = i1 675 nr1w_tg = nr1w_tg + 1 ; desc%ir1w_tg(i1) = nr1w_tg ; desc%indw_tg(nr1w_tg) = i1 676 end if 677 if (desc%iplw(i1) == desc%mype2 +1) desc%ir1w(i1) = desc%nr1w(desc%iplw(i1)) 678 end do 679 desc%nr1w_tg = nr1w_tg ! this is useful in task group ffts 680 681 ! then potential X values 682 desc%nr1p = desc%nr1w ; desc%ir1p=desc%ir1w ; desc%indp = desc%indw 683 do i1 = 1, nr1 684 if ( (desc%iplw(i1) > 0) .and. (desc%iplp(i1) == 0) ) & 685 CALL fftx_error__( ' fft_type_set ', ' bad distribution of X values ', i1 ) 686 if ( (desc%iplw(i1) > 0) ) cycle ! this X value has already been taken care of 687 688 if (desc%iplp(i1) > 0 ) then 689 desc%nr1p(desc%iplp(i1)) = desc%nr1p(desc%iplp(i1)) + 1 690 desc%indp(desc%nr1p(desc%iplp(i1)),desc%iplp(i1)) = i1 691 end if 692 if (desc%iplp(i1) == desc%mype2+1) desc%ir1p(i1) = desc%nr1p(desc%iplp(i1)) 693 end do 694 695 ! 696 ! Compute for each proc the global index ( starting from 0 ) of the first 697 ! local stick ( desc%iss ) 698 ! 699 700 DO i = 1, desc%nproc 701 IF( i == 1 ) THEN 702 desc%iss( i ) = 0 703 ELSE 704 desc%iss( i ) = desc%iss( i - 1 ) + ncp( i - 1 ) 705 ENDIF 706 ENDDO 707 708 ! iss(1:nproc) is the index offset of the first column of a given processor 709 710 IF( size( desc%ismap ) < ( nst ) ) & 711 CALL fftx_error__( ' fft_type_set ', ' wrong descriptor dimensions ', 6 ) 712 713 ! 714 ! 1. Set the array desc%ismap which maps stick indexes to 715 ! position in the plane ( iss ) 716 ! 2. Re-set the array "desc%isind", that maps position 717 ! in the plane with stick indexes (it is the inverse of desc%ismap ) 718 ! 719 720 ! wave function sticks first 721 722 desc%ismap = 0 ! 
will be the global xy stick index in the global list of processor-ordered sticks 723 nsp = 0 ! will be the number of sticks of a given processor 724 DO iss = 1, size( desc%isind ) 725 ip = desc%isind( iss ) ! processor that owns iss wave stick. if it's a rho stick it's negative ! 726 IF( ip > 0 ) THEN ! only operates on wave sticks 727 nsp( ip ) = nsp( ip ) + 1 728 desc%ismap( nsp( ip ) + desc%iss( ip ) ) = iss 729 IF( ip == ( desc%mype + 1 ) ) THEN 730 desc%isind( iss ) = nsp( ip ) ! isind is OVERWRITTEN as the ordered index in this processor stick list 731 ELSE 732 desc%isind( iss ) = 0 ! zero otherwise... 733 ENDIF 734 ENDIF 735 ENDDO 736 737 ! check number of sticks against the input value 738 739 IF( any( nsp( 1:desc%nproc ) /= ncpw( 1:desc%nproc ) ) ) THEN 740 DO ip = 1, desc%nproc 741 WRITE( stdout,*) ' * ', ip, ' * ', nsp( ip ), ' /= ', ncpw( ip ) 742 ENDDO 743 CALL fftx_error__( ' fft_type_set ', ' inconsistent number of sticks ', 7 ) 744 ENDIF 745 746 desc%nsw( 1:desc%nproc ) = nsp( 1:desc%nproc ) ! -- number of wave sticks per processor 747 DO ip=1, desc%nproc3 748 desc%nsw_offset(1,ip) = 0 749 DO i=1, desc%nproc2-1 750 desc%nsw_offset(i+1,ip) = desc%nsw_offset(i,ip) + desc%nsw(desc%iproc(i,ip)) 751 ENDDO 752 ENDDO 753 754 ! -- number of wave sticks per processor for task group ffts 755 desc%nsw_tg( 1:desc%nproc ) = 0 756 do ip =1, desc%nproc3 757 nsw_tg = sum(desc%nsw(desc%iproc(1:desc%nproc2,ip))) 758 desc%nsw_tg(desc%iproc(1:desc%nproc2,ip)) = nsw_tg 759 end do 760 761 ! then add pseudopotential stick 762 763 DO iss = 1, size( desc%isind ) 764 ip = desc%isind( iss ) ! -ve of processor that owns iss rho stick. if it was a wave stick it's something non negative ! 765 IF( ip < 0 ) THEN 766 nsp( -ip ) = nsp( -ip ) + 1 767 desc%ismap( nsp( -ip ) + desc%iss( -ip ) ) = iss 768 IF( -ip == ( desc%mype + 1 ) ) THEN 769 desc%isind( iss ) = nsp( -ip ) ! isind is OVERWRITTEN as the ordered index in this processor stick list 770 ELSE 771 desc%isind( iss ) = 0 ! 
zero otherwise... 772 ENDIF 773 ENDIF 774 ENDDO 775 776 ! check number of sticks against the input value 777 778 IF( any( nsp( 1:desc%nproc ) /= ncp( 1:desc%nproc ) ) ) THEN 779 DO ip = 1, desc%nproc 780 WRITE( stdout,*) ' * ', ip, ' * ', nsp( ip ), ' /= ', ncp( ip ) 781 ENDDO 782 CALL fftx_error__( ' fft_type_set ', ' inconsistent number of sticks ', 8 ) 783 ENDIF 784 785 desc%nsp( 1:desc%nproc ) = nsp( 1:desc%nproc ) ! -- number of rho sticks per processor 786 DO ip=1, desc%nproc3 787 desc%nsp_offset(1,ip) = 0 788 DO i=1, desc%nproc2-1 789 desc%nsp_offset(i+1,ip) = desc%nsp_offset(i,ip) + desc%nsp(desc%iproc(i,ip)) 790 ENDDO 791 ENDDO 792 793 IF( .NOT. desc%lpara ) THEN 794 795 desc%isind = 0 796 desc%iplw = 0 797 desc%iplp = 1 798 799 ! here we are setting parameter as if we were in a serial code, 800 ! sticks are along X dimension and not along Z 801 desc%nsp(1) = 0 802 desc%nsw(1) = 0 803 DO i1 = lb( 1 ), ub( 1 ) 804 DO i2 = lb( 2 ), ub( 2 ) 805 m1 = i1 + 1; IF ( m1 < 1 ) m1 = m1 + nr1 806 m2 = i2 + 1; IF ( m2 < 1 ) m2 = m2 + nr2 807 IF( st( i1, i2 ) > 0 ) THEN 808 desc%nsp(1) = desc%nsp(1) + 1 809 END IF 810 IF( stw( i1, i2 ) > 0 ) THEN 811 desc%nsw(1) = desc%nsw(1) + 1 812 desc%isind( m1 + ( m2 - 1 ) * nr1x ) = 1 ! st( i1, i2 ) 813 desc%iplw( m1 ) = 1 814 ENDIF 815 ENDDO 816 ENDDO 817 ! 818 ! if we are in a parallel run, but would like to use serial FFT, all 819 ! tasks must have the same parameters as if serial run. 820 ! 821 desc%nnr = nr1x * nr2x * nr3x 822 desc%nnp = nr1x * nr2x 823 desc%my_nr2p = nr2 ; desc%nr2p = nr2 ; desc%i0r2p = 0 824 desc%my_nr3p = nr3 ; desc%nr3p = nr3 ; desc%i0r3p = 0 825 desc%nsw = desc%nsw(1) 826 desc%nsp = desc%nsp(1) 827 desc%ngl = SUM(ngp) 828 desc%nwl = SUM(ngpw) 829 ! 
830 END IF 831 832 !write (6,*) 'fft_type_set SUMMARY' 833 !write (6,*) 'desc%mype ', desc%mype 834 !write (6,*) 'desc%mype2', desc%mype2 835 !write (6,*) 'desc%mype3', desc%mype3 836 !write (6,*) 'nr1 nr2 nr3 dimensions : ', desc%nr1, desc%nr2, desc%nr3 837 !write (6,*) 'nr1x nr2x nr3x dimensions : ', desc%nr1x, desc%nr2x, desc%nr3x 838 !write (6,*) 'nr3p arrays' 839 !write (6,*) desc%nr3p 840 !write (6,*) 'i0r3p arrays' 841 !write (6,*) desc%i0r3p 842 !write (6,*) 'nr2p arrays' 843 !write (6,*) desc%nr2p 844 !write (6,*) 'i0r2p arrays' 845 !write (6,*) desc%i0r2p 846 !write (6,*) 'nsp/nsw arrays' 847 !write (6,*) desc%nsp 848 !write (6,*) desc%nsw 849 !write (6,*) 'nr1p/nr1w arrays' 850 !write (6,*) desc%nr1p 851 !write (6,*) desc%nr1w 852 !write (6,*) 'ir1p/ir1w arrays' 853 !write (6,*) desc%ir1p 854 !write (6,*) desc%ir1w 855 !write (6,*) 'indp/indw arrays' 856 !write (6,*) desc%indp 857 !write (6,*) desc%indw 858 !write (6,*) 'iplp/iplw arrays' 859 !write (6,*) desc%iplp 860 !write (6,*) desc%iplw 861 862 ! Finally set fft local workspace dimension 863 864 nr1px = MAXVAL( desc%nr1p( 1:desc%nproc2 ) ) ! maximum number of X values per processor in the nproc2 group 865 nr2px = MAXVAL( desc%nr2p( 1:desc%nproc2 ) ) ! maximum number of planes per processor in the nproc2 group 866 nr3px = MAXVAL( desc%nr3p( 1:desc%nproc3 ) ) ! maximum number of planes per processor in the nproc3 group 867 ncpx = MAXVAL( ncp( 1:desc%nproc ) ) ! maximum number of columns per processor (use potential sticks to be safe) 868 869 IF ( desc%nproc == 1 ) THEN 870 desc%nnr = nr1x * nr2x * nr3x 871 desc%nnr_tg = desc%nnr * desc%nproc2 872 ELSE 873 desc%nnr = max( ncpx * nr3x, nr1x * nr2px * nr3px ) ! this is required to contain the local data in R and G space 874 desc%nnr = max( desc%nnr, ncpx*nr3px*desc%nproc3, nr1px*nr2px*nr3px*desc%nproc2) ! this is required to use ALLTOALL instead of ALLTOALLV 875 desc%nnr = max( 1, desc%nnr ) ! 
ensure that desc%nrr > 0 ( for extreme parallelism ) 876 desc%nnr_tg = desc%nnr * desc%nproc2 877 ENDIF 878 879 !write (6,*) ' nnr bounds' 880 !write (6,*) ' nr1x ',nr1x,' nr2x ', nr2x, ' nr3x ', nr3x 881 !write (6,*) ' nr1x * nr2x * nr3x',nr1x * nr2x * nr3x 882 !write (6,*) ' ncpx ',ncpx,' nr3px ', nr3px, ' desc%nproc3 ', desc%nproc3 883 !write (6,*) ' ncpx * nr3x ',ncpx * nr3x 884 !write (6,*) ' ncpx * nr3px * desc%nproc3 ',ncpx*nr3px*desc%nproc3 885 !write (6,*) ' nr1px ', nr1px,' nr2px ',nr2px,' desc%nproc2 ', desc%nproc2 886 !write (6,*) ' nr1x * nr2px * nr3px ',nr1x * nr2px * nr3px 887 !write (6,*) ' nr1px * nr2px *nr3px * desc%nproc2 ',nr1px*nr2px*nr3px*desc%nproc2 888 !write (6,*) ' desc%nnr ', desc%nnr 889 890 IF( desc%nr3x * desc%nsw( desc%mype + 1 ) > desc%nnr ) & 891 CALL fftx_error__( ' task_groups_init ', ' inconsistent desc%nnr ', 1 ) 892 desc%tg_snd(1) = desc%nr3x * desc%nsw( desc%mype + 1 ) 893 desc%tg_rcv(1) = desc%nr3x * desc%nsw( desc%iproc(1,desc%mype3+1) ) 894 desc%tg_sdsp(1) = 0 895 desc%tg_rdsp(1) = 0 896 DO i = 2, desc%nproc2 897 desc%tg_snd(i) = desc%nr3x * desc%nsw( desc%mype + 1 ) 898 desc%tg_rcv(i) = desc%nr3x * desc%nsw( desc%iproc(i,desc%mype3+1) ) 899 desc%tg_sdsp(i) = desc%tg_sdsp(i-1) + desc%nnr 900 desc%tg_rdsp(i) = desc%tg_rdsp(i-1) + desc%tg_rcv(i-1) 901 ENDDO 902 903#if defined(__CUDA) 904 desc%ismap_d = desc%ismap 905 desc%ir1p_d = desc%ir1p 906 desc%ir1w_d = desc%ir1w 907 desc%ir1w_tg_d = desc%ir1w_tg 908 909 desc%indp_d = desc%indp 910 desc%indw_d = desc%indw 911 desc%indw_tg_d(:,1) = desc%indw_tg 912 913 desc%nr1p_d = desc%nr1p 914 desc%nr1w_d = desc%nr1w 915 desc%nr1w_tg_d(1) = desc%nr1w_tg 916 917#endif 918 IF (nmany > 1) ALLOCATE(desc%aux(nmany * desc%nnr)) 919 920 RETURN 921 922 END SUBROUTINE fft_type_set 923 924!=----------------------------------------------------------------------------=! 
SUBROUTINE fft_type_init( dfft, smap, pers, lgamma, lpara, comm, at, bg, gcut_in, &
                          dual_in, fft_fact, nyfft, nmany, use_pd )
   !
   ! ... Fills the FFT descriptor dfft (and the stick map smap) for a grid
   ! ... whose "personality" is given by pers:
   ! ...    pers = 'rho'  : gcut_in is the potential cutoff,
   ! ...                    the wave-function cutoff is gcut_in / dual
   ! ...    pers = 'wave' : gcut_in is the wave-function cutoff,
   ! ...                    the potential cutoff is gcut_in * dual
   ! ... Any other value of pers is reported through fftx_error__.
   ! ... dual is fft_dual unless dual_in is present.
   ! ... The descriptor is allocated here when dfft%nsp is not yet allocated;
   ! ... otherwise comm must coincide with dfft%comm.
   !
   USE stick_base
   !
   TYPE (fft_type_descriptor), INTENT(INOUT) :: dfft
   TYPE (sticks_map), INTENT(INOUT) :: smap
   CHARACTER(LEN=*), INTENT(IN) :: pers      ! fft personality ('rho' or 'wave')
   LOGICAL, INTENT(IN) :: lgamma
   LOGICAL, INTENT(IN) :: lpara              ! copied into dfft%lpara
   INTEGER, INTENT(IN) :: comm
   REAL(DP), INTENT(IN) :: at(3,3)
   REAL(DP), INTENT(IN) :: bg(3,3)
   REAL(DP), INTENT(IN) :: gcut_in
   REAL(DP), OPTIONAL, INTENT(IN) :: dual_in
   INTEGER, INTENT(IN), OPTIONAL :: fft_fact(3)
   INTEGER, INTENT(IN) :: nyfft
   INTEGER, INTENT(IN) :: nmany
   LOGICAL, OPTIONAL, INTENT(IN) :: use_pd   ! whether to use pencil decomposition
   !
   ! ... potential sticks
   !
   INTEGER, ALLOCATABLE :: map_p(:,:)  ! map_p(i,j) = number of G-vectors in the stick
                                       ! whose x and y Miller indices are i and j
   INTEGER, ALLOCATABLE :: nst_p(:)    ! nst_p(ip) = number of sticks of processor ip
   INTEGER, ALLOCATABLE :: ngv_p(:)    ! ngv_p(ip) = sum of the stick lengths of processor ip
                                       !           = number of G-vectors owned by processor ip
   INTEGER :: nst_loc                  ! local number of sticks
   INTEGER :: ngm_loc                  ! G-vector count returned by get_sticks, checked against dfft%ngm
   !
   ! ... wave-function sticks (same meaning as above)
   !
   INTEGER, ALLOCATABLE :: map_w(:,:)
   INTEGER, ALLOCATABLE :: nst_w(:)
   INTEGER, ALLOCATABLE :: ngv_w(:)
   INTEGER :: nstw_loc                 ! local number of sticks (wave functions)
   INTEGER :: ngw_loc                  ! G-vector count returned by get_sticks, checked against dfft%ngw
   !
   REAL(DP) :: cut_p, cut_w, dual
   !
   ! ... ratio between the potential and the wave-function cutoff
   !
   IF ( PRESENT( dual_in ) ) THEN
      dual = dual_in
   ELSE
      dual = fft_dual
   END IF
   !
   ! ... derive both cutoffs from the one given in input
   !
   SELECT CASE ( pers )
   CASE ( 'rho' )
      cut_p = gcut_in
      cut_w = cut_p / dual
   CASE ( 'wave' )
      cut_w = gcut_in
      cut_p = cut_w * dual
   CASE DEFAULT
      CALL fftx_error__(' fft_type_init ', ' unknown FFT personality ', 1 )
   END SELECT
   !
   ! ... allocate the descriptor only once; a second call must use the same communicator
   !
   IF ( ALLOCATED( dfft%nsp ) ) THEN
      IF ( dfft%comm /= comm ) &
         CALL fftx_error__(' fft_type_init ', ' FFT already allocated with a different communicator ', 1 )
   ELSE
      CALL fft_type_allocate( dfft, at, bg, cut_p, comm, fft_fact=fft_fact, nyfft=nyfft )
   END IF
   !
   IF ( PRESENT( use_pd ) ) dfft%use_pencil_decomposition = use_pd
   IF ( ( nyfft > 1 ) .AND. ( .NOT. dfft%use_pencil_decomposition ) ) &
      CALL fftx_error__(' fft_type_init ', ' Slab decomposition and task groups not implemented. ', 1 )
   !
   ! ... this descriptor can describe either a parallel FFT or a serial FFT,
   ! ... even in a parallel build
   !
   dfft%lpara = lpara
   !
   CALL sticks_map_allocate( smap, lgamma, dfft%lpara, dfft%nproc2, dfft%iproc, dfft%iproc2, &
                             dfft%nr1, dfft%nr2, dfft%nr3, bg, dfft%comm )
   !
   dfft%lgamma = smap%lgamma   ! .TRUE. if the grid has Gamma symmetry
   !
   ALLOCATE( map_w( smap%lb(1):smap%ub(1), smap%lb(2):smap%ub(2) ), &
             map_p( smap%lb(1):smap%ub(1), smap%lb(2):smap%ub(2) ), &
             nst_p( smap%nproc ), ngv_p( smap%nproc ),              &
             nst_w( smap%nproc ), ngv_w( smap%nproc ) )
   !
   ! ... wave-function sticks first, then potential sticks
   !
   CALL get_sticks( smap, cut_w, nst_w, ngv_w, map_w, nstw_loc, ngw_loc )
   CALL get_sticks( smap, cut_p, nst_p, ngv_p, map_p, nst_loc,  ngm_loc )
   !
   CALL fft_type_set( dfft, nst_loc, smap%ub, smap%lb, smap%idx, smap%ist(:,1), smap%ist(:,2), &
                      nst_p, nst_w, ngv_p, ngv_w, map_p, map_w, nmany )
   !
   ! ... G-vector counts of this processor; halved (rounding up) with Gamma symmetry
   !
   dfft%ngw = dfft%nwl( dfft%mype + 1 )
   dfft%ngm = dfft%ngl( dfft%mype + 1 )
   IF ( dfft%lgamma ) THEN
      dfft%ngw = ( dfft%ngw + 1 ) / 2
      dfft%ngm = ( dfft%ngm + 1 ) / 2
   END IF
   !
   ! ... the descriptor must agree with what get_sticks counted
   !
   IF ( dfft%ngw /= ngw_loc ) CALL fftx_error__(' fft_type_init ', ' wrong ngw ', 1 )
   IF ( dfft%ngm /= ngm_loc ) CALL fftx_error__(' fft_type_init ', ' wrong ngm ', 1 )
   !
   DEALLOCATE( map_p, map_w, nst_p, ngv_p, nst_w, ngv_w )
   !
END SUBROUTINE fft_type_init

!=----------------------------------------------------------------------------=!

SUBROUTINE realspace_grid_init( dfft, at, bg, gcutm, fft_fact )
   !
   ! ... Sets values for dfft%nr[123] and dfft%nr[123]x
   ! ... If ALL input dfft%nr[123] are non-zero they are left unchanged;
   ! ... if any of them is zero, all three are recomputed from gcutm
   ! ... If fft_fact is present, fft_fact([123]) is passed to good_fft_order
   ! ... together with nr[123]
   !
   USE fft_support, ONLY: good_fft_dimension, good_fft_order
   !
   IMPLICIT NONE
   !
   TYPE(fft_type_descriptor), INTENT(INOUT) :: dfft
   REAL(DP), INTENT(IN) :: at(3,3), bg(3,3)
   REAL(DP), INTENT(IN) :: gcutm
   INTEGER, INTENT(IN), OPTIONAL :: fft_fact(3)
   !
   REAL(DP) :: gmax      ! sqrt(gcutm)
   REAL(DP) :: alen(3)   ! alen(n) = modulus of the n-th column of at
   INTEGER  :: ipol
   !
   IF ( dfft%nr1 == 0 .OR. dfft%nr2 == 0 .OR. dfft%nr3 == 0 ) THEN
      !
      ! ... first, an estimate of nr1,nr2,nr3, based on the max values
      ! ... of n_i indices in:   G = i*b_1 + j*b_2 + k*b_3
      ! ... We use G*a_i = n_i => n_i .le. |Gmax||a_i|
      !
      gmax = SQRT( gcutm )
      DO ipol = 1, 3
         alen(ipol) = SQRT( at(1,ipol)**2 + at(2,ipol)**2 + at(3,ipol)**2 )
      END DO
      !
      dfft%nr1 = INT( gmax * alen(1) ) + 1
      dfft%nr2 = INT( gmax * alen(2) ) + 1
      dfft%nr3 = INT( gmax * alen(3) ) + 1
      !
#if defined (__DEBUG)
      WRITE (6,*) gmax * alen(1) , dfft%nr1
      WRITE (6,*) gmax * alen(2) , dfft%nr2
      WRITE (6,*) gmax * alen(3) , dfft%nr3
#endif
      !
      ! ... shrink the estimate using the actual G-vector sphere
      !
      CALL grid_set( dfft, bg, gcutm, dfft%nr1, dfft%nr2, dfft%nr3 )
      !
      ! ... then let good_fft_order pick the final grid sizes
      !
      IF ( PRESENT( fft_fact ) ) THEN
         dfft%nr1 = good_fft_order( dfft%nr1, fft_fact(1) )
         dfft%nr2 = good_fft_order( dfft%nr2, fft_fact(2) )
         dfft%nr3 = good_fft_order( dfft%nr3, fft_fact(3) )
      ELSE
         dfft%nr1 = good_fft_order( dfft%nr1 )
         dfft%nr2 = good_fft_order( dfft%nr2 )
         dfft%nr3 = good_fft_order( dfft%nr3 )
      END IF
#if defined (__DEBUG)
   ELSE
      WRITE( stdout, '( /, 3X,"Info: using nr1, nr2, nr3 values from input" )' )
#endif
   END IF
   !
   ! ... leading dimensions: the second one is never padded here
   !
   dfft%nr1x = good_fft_dimension( dfft%nr1 )
   dfft%nr2x = dfft%nr2
   dfft%nr3x = good_fft_dimension( dfft%nr3 )
   !
END SUBROUTINE realspace_grid_init

!=----------------------------------------------------------------------------=!
SUBROUTINE grid_set( dfft, bg, gcut, nr1, nr2, nr3 )
   !
   ! ... Returns in nr1, nr2, nr3 the minimal 3D real-space FFT grid required
   ! ... to fit the G-vector sphere with G^2 < gcut (strict inequality).
   ! ... On input, nr1,nr2,nr3 must be set to values that match or exceed
   ! ... the largest i,j,k (Miller) indices in G(i,j,k) = i*b1 + j*b2 + k*b3
   ! ... The scan over k is shared among the dfft%nproc processes of dfft%comm
   ! ... and the partial results are merged with an MPI_MAX reduction.
   !
   IMPLICIT NONE
   !
   ! ... declare arguments
   !
   TYPE(fft_type_descriptor), INTENT(IN) :: dfft
   REAL(DP), INTENT(IN) :: bg(3,3), gcut
   INTEGER, INTENT(INOUT) :: nr1, nr2, nr3
   !
   ! ... declare other variables
   !
   INTEGER :: i, j, k
   INTEGER :: nb(3)      ! largest |i|, |j|, |k| found inside the sphere
   REAL(DP) :: gsq, g(3)
#if defined(__MPI)
   INTEGER :: ierr       ! MPI return status (kept separate from the loop counters)
#endif
   !
   nb = 0
   !
   ! ... calculate moduli of G vectors and the range of indices where
   ! ... |G|^2 < gcut (in parallel whenever possible)
   !
   DO k = -nr3, nr3
      !
      ! ... plane k is handled by the process whose rank dfft%mype
      ! ... (starting from 0) equals MOD( k + nr3, dfft%nproc )
      !
      IF ( MOD( k + nr3, dfft%nproc ) /= dfft%mype ) CYCLE
      !
      DO j = -nr2, nr2
         DO i = -nr1, nr1
            !
            g( 1 ) = DBLE(i)*bg(1,1) + DBLE(j)*bg(1,2) + DBLE(k)*bg(1,3)
            g( 2 ) = DBLE(i)*bg(2,1) + DBLE(j)*bg(2,2) + DBLE(k)*bg(2,3)
            g( 3 ) = DBLE(i)*bg(3,1) + DBLE(j)*bg(3,2) + DBLE(k)*bg(3,3)
            !
            ! ... calculate modulus
            !
            gsq = g( 1 )**2 + g( 2 )**2 + g( 3 )**2
            !
            ! ... calculate maximum index
            !
            IF ( gsq < gcut ) THEN
               nb(1) = MAX( nb(1), ABS( i ) )
               nb(2) = MAX( nb(2), ABS( j ) )
               nb(3) = MAX( nb(3), ABS( k ) )
            END IF
            !
         END DO
      END DO
   END DO
   !
#if defined(__MPI)
   CALL MPI_ALLREDUCE( MPI_IN_PLACE, nb, 3, MPI_INTEGER, MPI_MAX, dfft%comm, ierr )
   ! ... a failed reduction would leave nb holding only the local maxima
   IF ( ierr /= 0 ) CALL fftx_error__( ' grid_set ', ' MPI_ALLREDUCE failed ', ABS( ierr ) )
#endif
   !
   ! ... the size of the 3d FFT matrix depends upon the maximum indices. With
   ! ... the following choice, the sphere in G-space "touches" its periodic image
   !
   nr1 = 2 * nb(1) + 1
   nr2 = 2 * nb(2) + 1
   nr3 = 2 * nb(3) + 1
   !
   RETURN
   !
END SUBROUTINE grid_set

PURE FUNCTION fft_stick_index( desc, i, j )
   !
   ! ... Returns desc%isind at the XY-plane position of the stick whose
   ! ... x and y indices are i and j. Negative indices are folded
   ! ... periodically into 1..desc%nr1 and 1..desc%nr2 (both assumed > 0).
   !
   IMPLICIT NONE
   TYPE(fft_type_descriptor), INTENT(IN) :: desc
   INTEGER, INTENT(IN) :: i, j
   INTEGER :: fft_stick_index
   INTEGER :: mc, m1, m2
   !
   ! ... MODULO returns a result in 0..n-1 for positive n, i.e. the same
   ! ... value as MOD followed by "add n if negative"
   !
   m1 = MODULO( i, desc%nr1 ) + 1
   m2 = MODULO( j, desc%nr2 ) + 1
   !
   ! ... position in the plane, with leading dimension desc%nr1x
   !
   mc = m1 + ( m2 - 1 ) * desc%nr1x
   fft_stick_index = desc%isind( mc )
   !
END FUNCTION fft_stick_index

!
SUBROUTINE fft_index_to_3d (ir, dfft, i,j,k, offrange)
   !
   !! returns indices i,j,k yielding the position of grid point ir
   !! in the real-space FFT grid described by descriptor dfft:
   !!    r(:,ir)= i*tau(:,1)/n1 + j*tau(:,2)/n2 + k*tau(:,3)/n3
   !! ir is decoded as a 1-based index into a local array laid out as
   !! (nr1x, my_nr2p, ...); j and k are then shifted by the offsets
   !! my_i0r2p and my_i0r3p of this process. If the local XY plane is
   !! empty (nr1x or my_nr2p not positive) no point can be addressed:
   !! i,j,k are set to -1 and offrange to .TRUE.
   !
   IMPLICIT NONE
   INTEGER, INTENT(IN) :: ir
   !! point in the FFT real-space grid
   TYPE(fft_type_descriptor), INTENT(IN) :: dfft
   !! descriptor for the FFT grid
   INTEGER, INTENT(OUT) :: i
   !! (i,j,k) corresponding to grid point ir
   INTEGER, INTENT(OUT) :: j
   !! (i,j,k) corresponding to grid point ir
   INTEGER, INTENT(OUT) :: k
   !! (i,j,k) corresponding to grid point ir
   LOGICAL, INTENT(OUT) :: offrange
   !! true if computed i,j,k lie outside the physical range of values
   !
   INTEGER :: nxy    ! number of points in one local XY plane
   INTEGER :: rest   ! zero-based index still to be decoded
   !
   ! ... guard against an integer division by zero on processes that own
   ! ... no rows along Y
   !
   IF ( dfft%nr1x <= 0 .OR. dfft%my_nr2p <= 0 ) THEN
      i = -1
      j = -1
      k = -1
      offrange = .TRUE.
      RETURN
   END IF
   !
   nxy  = dfft%nr1x * dfft%my_nr2p
   rest = ir - 1
   !
   ! ... slowest index first: plane, then row, then position in the row
   !
   k    = rest / nxy
   rest = rest - nxy * k
   j    = rest / dfft%nr1x
   i    = rest - dfft%nr1x * j
   !
   ! ... from local to global Y and Z indices
   !
   j = j + dfft%my_i0r2p
   k = k + dfft%my_i0r3p
   !
   offrange = (i < 0 .OR. i >= dfft%nr1 ) .OR. &
              (j < 0 .OR. j >= dfft%nr2 ) .OR. &
              (k < 0 .OR. k >= dfft%nr3 )
   !
END SUBROUTINE fft_index_to_3d

!=----------------------------------------------------------------------------=!
1228END MODULE fft_types 1229!=----------------------------------------------------------------------------=! 1230