!==========================================================================
!
! Module:
!
! genwf_eps   Originally By FHJ   Last Modified 04/2012 (FHJ)
!
! Helpers that build the wavefunctions used by the polarizability code:
! minimal FFT grid estimation, generation and redistribution of
! real-space (FFT-ed) wavefunctions, and cleanup of the intermediate
! wavefunction buffers.
!
!==========================================================================

#include "f_defs.h"

module genwf_eps_m

  use global_m
  use input_utils_m
  use fftw_m
  use genwf_mpi_m
  use gmap_m
  use sort_m
  use timing_m, only: timing => epsilon_timing

  implicit none

  private

  !> communicator object for the WFN FFTs
  !! Filled by genwf_FFT_Isend (and do_my_FFTs), consumed by genwf_FFT_Wait.
  type wfn_FFT_comm_t
    logical :: done !< Have we received all the buffers?
    integer, pointer :: req_recvv(:), req_recvc(:) !< Array of MPI_REQUEST
    integer, pointer :: req_sendv(:), req_sendc(:) !< Array of MPI_REQUEST
    integer :: recv_cntv, recv_cntc !< Number of requests
    integer :: send_cntv, send_cntc !< Number of requests
    integer :: nv !< Number of valence bands
  end type wfn_FFT_comm_t

  public :: genwf_gen, genwf_FFT, free_wfns, genwf_lvl2, &
    genwf_FFT_Isend, genwf_FFT_Wait, wfn_FFT_comm_t, &
    get_wfn_fftgrid, get_eps_fftgrid

contains

  !> FHJ: Figure out what is the min/max gvec components give an isort array
  !! Note: you should manually initialize box_min and box_max to zero!
  !!
  !! The bounds are accumulated: on exit box_min/box_max hold the
  !! component-wise min/max of their input values and of the first ng
  !! G-vectors addressed through isort.
  subroutine get_gvecs_bounds(gvec, ng, isort, box_min, box_max)
    type(gspace), intent(in) :: gvec
    integer, intent(in) :: ng        !< number of entries of isort to scan
    integer, intent(in) :: isort(:)  !< indices into gvec%components
    integer, intent(inout) :: box_min(3), box_max(3)

    integer :: ig

    PUSH_SUB(get_gvecs_bounds)

    do ig = 1, ng
      box_min(1:3) = min(box_min(1:3), gvec%components(1:3, isort(ig)))
      box_max(1:3) = max(box_max(1:3), gvec%components(1:3, isort(ig)))
    enddo

    POP_SUB(get_gvecs_bounds)
    return

  end subroutine get_gvecs_bounds

  !> FHJ: Figure out the minimum fftbox that holds all the WFNs
  !! The result is stored in pol%WFN_FFTgrid; the bounds are accumulated
  !! over all kp%nrk k-points of intwfn.
  subroutine get_wfn_fftgrid(pol, gvec, kp, intwfn)
    type(polarizability), intent(inout) :: pol
    type(gspace), intent(in) :: gvec
    type(kpoints), intent(in) :: kp
    type(int_wavefunction), intent(in) :: intwfn

    integer :: ik, wfn_box_min(3), wfn_box_max(3)

    PUSH_SUB(get_wfn_fftgrid)

    wfn_box_min(:) = 0
    wfn_box_max(:) = 0
    do ik = 1, kp%nrk
      call get_gvecs_bounds(gvec, intwfn%ng(ik), intwfn%isort(:,ik), wfn_box_min, wfn_box_max)
    enddo
    pol%WFN_FFTgrid(1:3) = wfn_box_max(1:3) - wfn_box_min(1:3) + 1
    if (peinf%verb_debug .and. peinf%inode==0) then
      write(6,*) 'WFN min. FFT grid:',pol%WFN_FFTgrid
    endif

    POP_SUB(get_wfn_fftgrid)
    return

  end subroutine get_wfn_fftgrid

  !> FHJ: Figure out the minimum fftbox that allows us to convolve the WFNs within
  !! an energy window of nmtx G vectors.
  !! Reads pol%WFN_FFTgrid, pol%nmtx and pol%isrtx; writes pol%FFTgrid, which
  !! is capped by the original grid gvec%FFTgrid.
  subroutine get_eps_fftgrid(pol, gvec)
    type(polarizability), intent(inout) :: pol
    type(gspace), intent(in) :: gvec

    integer :: eps_box_min(3), eps_box_max(3)

    PUSH_SUB(get_eps_fftgrid)

    eps_box_min(:) = 0
    eps_box_max(:) = 0
    call get_gvecs_bounds(gvec, pol%nmtx, pol%isrtx, eps_box_min, eps_box_max)
    ! FHJ: Note: the amount of padding is actually N_sig + N_window - 1
    pol%FFTgrid(1:3) = pol%WFN_FFTgrid(1:3) + (eps_box_max(1:3) - eps_box_min(1:3))
    pol%FFTgrid(1:3) = min(pol%FFTgrid(1:3), gvec%FFTgrid(1:3))
    if (peinf%verb_debug .and. peinf%inode==0) then
      write(6,'(1x,a,3(1x,i0))') 'Original FFT grid:', gvec%FFTgrid
      write(6,'(1x,a,3(1x,i0))') 'Minimal FFT grid:', pol%FFTgrid
    endif

    POP_SUB(get_eps_fftgrid)
    return

  end subroutine get_eps_fftgrid

  !> FHJ: to be used internally with genwf_FFT_Isend
  !! For each local band listed in my_bands: reorder the coefficients with
  !! ind and ph, put them into the FFT box, transform, optionally conjugate,
  !! and (parallel build only) post a non-blocking send of the result to
  !! every other rank that owns the band. The send requests and their count
  !! are stored in the valence or conduction slots of this.
  subroutine do_my_FFTs(this,gvec,Nfft,wfn_fft,intwfn,my_bands,my_cnt,ng,tmp_wfn,isort,ind,ph,is_val)
    type (wfn_FFT_comm_t), intent(inout), target :: this !< communicator object for the WFN FFTs
    type(gspace), intent(in) :: gvec
    complex(DPC), intent(inout) :: wfn_fft(:,:,:,:)
    integer, intent(in) :: Nfft(3)
    type(int_wavefunction), intent(in) :: intwfn
    integer, intent(in) :: my_bands(:) !< my bands
    integer, intent(in) :: my_cnt !< number of FFTs to do = sizeof(my_bands)
    integer, intent(in) :: ng !< number of gvectors for the k-pt in question
    SCALAR, intent(inout) :: tmp_wfn(:)!< buffer to reorder WFN using ind and ph
    integer, intent(in) :: isort(:) !< wfn isort
    integer, intent(in) :: ind(:)
    SCALAR, intent(in) :: ph(:)
    logical, intent(in) :: is_val !< .true. to take the conjg_fftbox

    integer :: is, ig, fft_size
    integer :: ib_list, ib_local, ib, iproc, offset
    integer, pointer :: invindex(:), send_cnt, req_send(:)
    logical, pointer :: does_it_own(:,:)

    PUSH_SUB(do_my_FFTs)

    is = 1 ! only one spin is supported here
    ! Select the valence or conduction bookkeeping. The message tag is
    ! offset + band index, so conduction tags start after the valence ones.
    if ( is_val ) then
      invindex => peinf%invindexv
      does_it_own => peinf%does_it_ownv
      req_send => this%req_sendv
      send_cnt => this%send_cntv
      offset = 0
    else
      invindex => peinf%invindexc
      does_it_own => peinf%does_it_ownc
      req_send => this%req_sendc
      send_cnt => this%send_cntc
      offset = this%nv
    endif
    fft_size = product(Nfft(1:3))

    do ib_list = 1, my_cnt
      ib_local = my_bands(ib_list)
      do ig = 1, ng
        tmp_wfn(ig) = intwfn%cg(ind(ig), ib_local, is)*ph(ig)
      enddo
      if (peinf%inode==0) call timing%start(timing%opt_fft_fft)
      call put_into_fftbox(ng, tmp_wfn, gvec%components, &
        isort, wfn_fft(:,:,:,ib_local), Nfft)
      call do_FFT( wfn_fft(:,:,:,ib_local), Nfft, 1)
      if ( is_val ) call conjg_fftbox( wfn_fft(:,:,:,ib_local), Nfft )
      if (peinf%inode==0) call timing%stop(timing%opt_fft_fft)

#ifdef MPI
      ! FHJ: distribute all real-space WFNs
      if(peinf%inode==0) call timing%start(timing%opt_fft_comm_fft)
      ib = invindex(ib_local)
      do iproc = 0, peinf%npes-1
        if (does_it_own(ib, iproc+1).and.iproc/=peinf%inode) then
          send_cnt = send_cnt + 1
          call MPI_Isend( wfn_fft(1,1,1,ib_local), fft_size, MPI_COMPLEX_DPC, iproc, &
            offset + ib, MPI_COMM_WORLD, req_send(send_cnt), mpierr )
        endif
      enddo
      if(peinf%inode==0) call timing%stop(timing%opt_fft_comm_fft)
#endif
    enddo

    POP_SUB(do_my_FFTs)
    return

  end subroutine do_my_FFTs

  !> Generates all the real-space wavefunctions. Used only if pol%os_opt_fft==2
  !! This version avoids communication, and it`s under development!
  !!
  !! Each band is FFT-ed by a single rank and sent with non-blocking
  !! messages to the other ranks that own it. The requests are stored in
  !! this; call genwf_FFT_Wait before using vwfn%wfn_fft or cwfn%wfn_fft.
  !! The routine also sets idx_kp and ngv/ngc in vwfn/cwfn, may update
  !! pol%FFTgrid (if pol%min_fftgrid), and frees the coefficient arrays of
  !! the intermediate wavefunctions through free_wfns.
  !!TODO`s:
  !! (1) we are just supporting one spin and one kpt/qpt.
  !! (2) support serial code
  subroutine genwf_FFT_Isend(this,crys,gvec,syms,kp,kpq,vwfn,pol,cwfn,intwfnv,intwfnvq,intwfnc)
    type (wfn_FFT_comm_t), intent(inout) :: this !< communicator object for the WFN FFTs
    type (crystal), intent(in) :: crys
    type (gspace), intent(in) :: gvec
    type (symmetry), intent(in) :: syms
    type (kpoints), target, intent(in) :: kp
    type (kpoints), target, intent(in) :: kpq
    type (valence_wfns), intent(inout) :: vwfn
    type (polarizability), intent(inout) :: pol
    type (conduction_wfns), intent(inout) :: cwfn
    type (int_wavefunction), intent(inout) :: intwfnv
    type (int_wavefunction), intent(inout) :: intwfnvq
    type (int_wavefunction), intent(inout) :: intwfnc

    integer :: npes_per_pool, nc_groups, nproc_max, nc
    integer, allocatable :: nv_bands(:), v_owners(:)
    integer :: Nfft(3), fft_size
    real(DP) :: scale
    integer, allocatable :: my_vbands(:)
    integer :: my_vcnt, iv, iv_local
    integer :: ipool, isubrank, inode, ioffset
    integer, allocatable :: grp_global_ranks(:), ntot_bands(:), &
      grp_local_owners(:)
    integer, allocatable :: my_cbands(:)
    integer :: my_ccnt, ic, ic_local
    integer :: my_grp_rank, grp_nprocs
    integer :: min_bands, iproc, iworker
    integer :: ik, is, ib

    !sort stuff
    SCALAR, allocatable :: tmp_wfn(:)
    SCALAR, allocatable :: ph(:)
    real(DP), allocatable :: ekin(:)
    integer, allocatable :: ind(:), isorti(:)
    integer, allocatable, target :: wfn_isort(:)
    integer :: ng0, ig

    PUSH_SUB(genwf_FFT_Isend)

    call logit('generating all real-space wavefunctions')

    if(peinf%inode==0) call timing%start(timing%opt_fft)

    if(pol%nq>1.or.kp%nrk>1.or.kpq%nrk>0.or.pol%need_WFNq) &
      call die('FFT opt. level 2 only works for 1 qpt and 1 kpt, and without WFNq',&
      only_root_writes=.true.)

#ifdef MPI

    ik = 1 ! Fixing one kpt for now
    is = 1 ! And the spin
    vwfn%idx_kp = ik
    cwfn%idx_kp = ik
    nc = cwfn%nband - vwfn%nband ! `real` number of conduction bands
    npes_per_pool = (peinf%npes/peinf%npools) ! number of processors per pool
    nc_groups = (nc + peinf%ncownmax - 1)/(peinf%ncownmax) ! number of conduction groups
    nproc_max = (peinf%npes + nc_groups - 1)/(nc_groups) ! max num. of proc. per conduction group

    ! Basic initialization of the communicator object
    this%recv_cntv = 0
    this%recv_cntc = 0
    this%send_cntv = 0
    this%send_cntc = 0
    this%done = .false.
    this%nv = vwfn%nband
    SAFE_ALLOCATE(this%req_recvv, (vwfn%nband*npes_per_pool))
    SAFE_ALLOCATE(this%req_recvc, (nc*nproc_max))
    SAFE_ALLOCATE(this%req_sendv, (vwfn%nband*npes_per_pool))
    SAFE_ALLOCATE(this%req_sendc, (nc*nproc_max))

    if(peinf%inode==0) call timing%start(timing%opt_fft_init)

    ! FHJ: prepare the sorting buffers
    ng0 = intwfnv%ng(ik) ! We all have at least one valence band
    SAFE_ALLOCATE(ind,(ng0))
    SAFE_ALLOCATE(ph, (ng0))
    SAFE_ALLOCATE(wfn_isort, (gvec%ng))
    SAFE_ALLOCATE(isorti, (gvec%ng))
    SAFE_ALLOCATE(ekin, (gvec%ng))
    call kinetic_energies(gvec, crys%bdot, ekin)
    call sortrx(gvec%ng, ekin, wfn_isort, gvec=gvec%components)

    ! Inverse of the kinetic-energy ordering, then overridden for the first
    ! ng0 entries by the ordering stored with the wavefunctions.
    do ig = 1, gvec%ng
      isorti(wfn_isort(ig)) = ig
    enddo
    do ig = 1, ng0
      isorti(intwfnv%isort(ig, ik)) = ig
    enddo
    call gmap(gvec,syms,ng0,1,(/0,0,0/),wfn_isort,isorti,ind,ph,.true.)
    vwfn%ngv=ng0
    cwfn%ngc=ng0

    ! FHJ: this will be more complicated to implement for more than 1 kpt
    if (pol%min_fftgrid) then
      ! pol%isrtx only borrows wfn_isort while the grid is computed
      pol%isrtx => wfn_isort
      pol%nmtx = gcutoff(gvec%ng, ekin, pol%isrtx, pol%ecuts)
      call get_eps_fftgrid(pol, gvec)
      nullify(pol%isrtx)
    endif
    call setup_FFT_sizes(pol%FFTgrid,Nfft,scale)
    fft_size = Nfft(1)*Nfft(2)*Nfft(3)

    ! FHJ: Distribute VALENCE bands !
    !--------------------------------
    ! NOTE: we index the processors wrt their real MPI ranks (peinf%inode)
    ! we refer all the val bands wrt the local index in my_vbands

    ! Number of valence bands that a particular processor owns
    SAFE_ALLOCATE(nv_bands, (peinf%npes))
    ! Who owns a particular valence band?
    SAFE_ALLOCATE(v_owners, (vwfn%nband + pol%ncrit))
    ! Which valence bands do I own?
    SAFE_ALLOCATE(my_vbands, (peinf%nvownmax))
    nv_bands(:) = 0
    v_owners(:) = 0
    my_vbands(:) = 0
    my_vcnt = 0

    do iv = 1, vwfn%nband + pol%ncrit
      ipool = (iv-1)/peinf%nvownmax ! Get pool for the band
      ! Distribute bands to the processors in a round-robin way. We add an
      ! offset so that we get a good load balance for the cond. bands later on.
      isubrank = mod(iv-1, peinf%nvownmax)
      ioffset = ipool*peinf%nvownmax
      inode = mod(isubrank+ioffset, npes_per_pool) + ipool*npes_per_pool
      v_owners(iv) = inode
      ! Keep track of the # of times that each processors got a band:
      nv_bands(inode+1) = nv_bands(inode+1) + 1
      if (peinf%inode == inode) then
        my_vcnt = my_vcnt + 1
        my_vbands(my_vcnt) = isubrank + 1
      endif
    enddo

    ! FHJ: Distribute CONDUCTION bands !
    !----------------------------------!
    ! NOTE: we use *local ranks* to organize the processors in a particular cond. group
    ! we refer all the cond. bands wrt the local index in my_cbands

    ! Ranks of the processors that participate in this conduction group
    SAFE_ALLOCATE(grp_global_ranks, (nproc_max))
    SAFE_ALLOCATE(ntot_bands, (nproc_max)) ! Total number of bands per processor
    ntot_bands(:) = 0
    ! Who owns a particular conduction band? (in terms of local workers)
    SAFE_ALLOCATE(grp_local_owners, (peinf%ncownactual))
    grp_local_owners(:) = 0
    ! Which conduction bands do I own?
    SAFE_ALLOCATE(my_cbands, (peinf%ncownactual))
    my_cbands(:) = 0
    my_ccnt = 0

    ! This will be set to my rank within the group of conduction bands, starting at 0.
    my_grp_rank = -1

    ! Create list of all processors in that same group
    ! want: grp_global_ranks(:) = [icgroup, icgroup + npes_per_group, ...]
    grp_global_ranks(:) = 0
    grp_nprocs = 0
    ib = peinf%invindexc(1) ! index of first band that I own.
    if (ib>0) then
      do iproc = 0, peinf%npes-1
        if (peinf%does_it_ownc(ib, iproc+1)) then
          grp_nprocs = grp_nprocs + 1
          grp_global_ranks(grp_nprocs) = iproc
          if (iproc==peinf%inode) then
            my_grp_rank = grp_nprocs - 1
          endif
        endif
      enddo
    endif

    ! Initialize list of the number of bands that each processor has
    do iproc = 1, grp_nprocs
      ntot_bands(iproc) = nv_bands(grp_global_ranks(iproc)+1)
    enddo

    do ic_local = 1, peinf%ncownactual
      ! Get smallest index of ntot_bands array -> iworker
      ! note: rank = grp_local_ranks(iworker)
      iworker = 0
      min_bands = cwfn%nband + 1
      do iproc = 1, grp_nprocs
        if (ntot_bands(iproc)<min_bands) then
          min_bands = ntot_bands(iproc)
          iworker = iproc - 1
        endif
      enddo

      ! add ic_local to list of bands that iworker owns
      grp_local_owners(ic_local) = iworker
      ntot_bands(iworker+1) = ntot_bands(iworker+1) + 1
      if (my_grp_rank == iworker) then
        my_ccnt = my_ccnt + 1
        my_cbands(my_ccnt) = ic_local
      endif
    enddo

    if(peinf%inode==0) call timing%stop(timing%opt_fft_init)
    if(peinf%inode==0) call timing%start(timing%opt_fft_comm_fft)

    ! FHJ: Non-blocking Recvs !
    !-------------------------!

    ! Recv valence bands
    SAFE_ALLOCATE( vwfn%wfn_fft, (Nfft(1), Nfft(2), Nfft(3), peinf%nvownactual) )
    do iv_local = 1, peinf%nvownactual
      iv = peinf%invindexv(iv_local)
      if (v_owners(iv)/=peinf%inode) then
        this%recv_cntv = this%recv_cntv + 1
        call MPI_Irecv( vwfn%wfn_fft(1,1,1,iv_local), fft_size, MPI_COMPLEX_DPC, v_owners(iv), &
          iv, MPI_COMM_WORLD, this%req_recvv(this%recv_cntv), mpierr )
      endif
    enddo
    ! Recv conduction bands
    SAFE_ALLOCATE( cwfn%wfn_fft, (Nfft(1), Nfft(2), Nfft(3), peinf%ncownactual) )
    do ic_local = 1, peinf%ncownactual
      ic = peinf%invindexc(ic_local)
      if (grp_local_owners(ic_local)/=my_grp_rank) then
        this%recv_cntc = this%recv_cntc + 1
        iproc = grp_global_ranks(grp_local_owners(ic_local)+1)
        call MPI_Irecv( cwfn%wfn_fft(1,1,1,ic_local), fft_size, MPI_COMPLEX_DPC, iproc, &
          vwfn%nband + ic, MPI_COMM_WORLD, this%req_recvc(this%recv_cntc), mpierr )
      endif
    enddo

    if(peinf%inode==0) call timing%stop(timing%opt_fft_comm_fft)

    ! FHJ: Do the FFTs + Isend`s !
    !----------------------------!

    call logit('doing FFTs')

    SAFE_ALLOCATE(tmp_wfn, (ng0))
    call do_my_FFTs(this,gvec,Nfft,vwfn%wfn_fft,intwfnv,my_vbands,my_vcnt,&
      ng0,tmp_wfn,wfn_isort,ind,ph,.true.)
    call do_my_FFTs(this,gvec,Nfft,cwfn%wfn_fft,intwfnc,my_cbands,my_ccnt,&
      ng0,tmp_wfn,wfn_isort,ind,ph,.false.)
    SAFE_DEALLOCATE(tmp_wfn)

    ! FHJ: Deallocate original wavefunctions (but only the wfn!)
    call free_wfns(pol, intwfnv, intwfnvq, intwfnc, .false.)

    call logit('done with FFTs')

    ! Free comm buffers
    SAFE_DEALLOCATE(nv_bands)
    SAFE_DEALLOCATE(v_owners)
    SAFE_DEALLOCATE(my_vbands)
    SAFE_DEALLOCATE(grp_global_ranks)
    SAFE_DEALLOCATE(ntot_bands)
    SAFE_DEALLOCATE(grp_local_owners)
    SAFE_DEALLOCATE(my_cbands)

    ! and free sorting stuff
    SAFE_DEALLOCATE(ind)
    SAFE_DEALLOCATE(ph)
    SAFE_DEALLOCATE(wfn_isort)
    SAFE_DEALLOCATE(isorti)
    SAFE_DEALLOCATE(ekin)

#endif

    call logit('done generating real-space wavefunctions')

    if(peinf%inode==0) call timing%stop(timing%opt_fft)

    POP_SUB(genwf_FFT_Isend)
    return

  end subroutine genwf_FFT_Isend


  !> FHJ: call me after genwf_FFT_Isend, but just before you actually need the data
  !! Waits for all pending receive and send requests stored in this,
  !! releases the request arrays and sets this%done. Nothing is done in
  !! builds without the parallel section enabled.
  subroutine genwf_FFT_Wait(this)
    type (wfn_FFT_comm_t), intent(inout) :: this !< communicator object for the WFN FFTs

    PUSH_SUB(genwf_FFT_Wait)

#ifdef MPI
    if(peinf%inode==0) call timing%start(timing%opt_fft)
    if(peinf%inode==0) call timing%start(timing%opt_fft_comm_fft)
    if (this%recv_cntv>0) then
      call MPI_Waitall(this%recv_cntv, this%req_recvv, MPI_STATUSES_IGNORE, mpierr)
    endif
    if (this%send_cntv>0) then
      call MPI_Waitall(this%send_cntv, this%req_sendv, MPI_STATUSES_IGNORE, mpierr)
    endif
    if (this%recv_cntc>0) then
      call MPI_Waitall(this%recv_cntc, this%req_recvc, MPI_STATUSES_IGNORE, mpierr)
    endif
    if (this%send_cntc>0) then
      call MPI_Waitall(this%send_cntc, this%req_sendc, MPI_STATUSES_IGNORE, mpierr)
    endif
    if(peinf%inode==0) call timing%stop(timing%opt_fft_comm_fft)

    SAFE_DEALLOCATE_P(this%req_recvv)
    SAFE_DEALLOCATE_P(this%req_sendv)
    SAFE_DEALLOCATE_P(this%req_recvc)
    SAFE_DEALLOCATE_P(this%req_sendc)
    this%done = .true.
    if(peinf%inode==0) call timing%stop(timing%opt_fft)
#endif

    POP_SUB(genwf_FFT_Wait)
    return

  end subroutine genwf_FFT_Wait


  !> Allocates vwfn%ev and cwfn%ec and fills them with the band energies
  !! of the k-points vwfn%idx_kp and cwfn%idx_kp. The valence energies come
  !! from kpq when pol%need_WFNq is set and from kp otherwise; the
  !! conduction energies always come from kp.
  subroutine genwf_lvl2(kp,kpq,vwfn,pol,cwfn)
    type (kpoints), target, intent(in) :: kp
    type (kpoints), target, intent(in) :: kpq
    type (valence_wfns), intent(inout) :: vwfn
    type (polarizability), intent(in) :: pol
    type (conduction_wfns), intent(inout) :: cwfn

    type(kpoints), pointer :: kp_point

    PUSH_SUB(genwf_lvl2)

    if(pol%need_WFNq) then ! FIXME I think this is wrong if pol%nq1>0
      kp_point => kpq
    else
      kp_point => kp
    endif

    SAFE_ALLOCATE(vwfn%ev, (vwfn%nband+pol%ncrit, kp%nspin))
    SAFE_ALLOCATE(cwfn%ec, (cwfn%nband,kp%nspin))
    vwfn%ev(1:vwfn%nband+pol%ncrit,1:kp%nspin) = &
      kp_point%el(1:vwfn%nband+pol%ncrit, vwfn%idx_kp, 1:kp%nspin)
    cwfn%ec(1:cwfn%nband,1:kp%nspin) = &
      kp%el(1:cwfn%nband, cwfn%idx_kp, 1:kp%nspin)

    POP_SUB(genwf_lvl2)
    return
  end subroutine genwf_lvl2


  !> Generates all the real-space wavefunctions. Used only if pol%os_opt_fft==2
  !!
  !! Works in two communication phases: (1) the plane-wave coefficients of
  !! band ib are gathered on rank mod(ib-1, npes); (2) that rank does the FFT
  !! and sends the real-space result to every rank that owns the band.
  !! Both phases are completed with blocking waits before returning.
  !!TODO`s:
  !! (1) we are just supporting one spin and one kpt/qpt.
  !! (2) communication can be reduced if we distribute the WFNs in a smarter way
  !! (3) support serial code
  subroutine genwf_FFT(crys,gvec,syms,kp,kpq,vwfn,pol,cwfn,intwfnv,intwfnvq,intwfnc)
    type (crystal), intent(in) :: crys
    type (gspace), intent(in) :: gvec
    type (symmetry), intent(in) :: syms
    type (kpoints), target, intent(in) :: kp
    type (kpoints), target, intent(in) :: kpq
    type (valence_wfns), intent(inout) :: vwfn
    type (polarizability), intent(inout) :: pol
    type (conduction_wfns), intent(inout) :: cwfn
    type (int_wavefunction), intent(inout) :: intwfnv
    type (int_wavefunction), intent(inout) :: intwfnvq
    type (int_wavefunction), intent(inout) :: intwfnc

    integer :: ik, is, ipe
    integer :: own_max
    integer :: nreq_send, nreq_recv
    integer, allocatable :: band_owners(:)
    integer :: ib_loc, ib, ng0
    integer :: receiver, recv_cnt, send_cnt
    integer :: local_band_idx
    integer, allocatable :: req_send(:), req_recv(:)
    integer, allocatable :: my_bands(:), isort0(:)
    SCALAR, allocatable :: bufs_wfn(:,:)
    complex(DPC), allocatable :: work_ffts(:,:,:,:)
    integer :: Nfft(3), fft_size
    real(DP) :: scale
    SCALAR, allocatable :: tmp_wfn(:)
    SCALAR, allocatable :: ph(:)
    real(DP), allocatable :: ekin(:)
    integer, allocatable :: ind(:), isorti(:)
    integer, allocatable, target :: wfn_isort(:)
    integer :: ig

    PUSH_SUB(genwf_FFT)

    call logit('generating all real-space wavefunctions')

    if(peinf%inode==0) call timing%start(timing%opt_fft)

    if(pol%nq>1.or.kp%nrk>1.or.kpq%nrk>0.or.pol%need_WFNq) &
      call die('FFT opt. level 2 only works for 1 qpt and 1 kpt, and without WFNq',&
      only_root_writes=.true.)

#ifdef MPI

    ! Max. number of bands that a rank transforms (ceiling of nband/npes)
    own_max = (cwfn%nband + peinf%npes - 1)/(peinf%npes)

    ! Sizes of the request arrays, valid for both communication phases:
    ! - sends: phase 1 posts at most one send per band; phase 2 posts one
    !   send per (band, owner) pair, i.e. at most own_max*npes per rank,
    !   which can be larger than the total number of bands.
    ! - recvs: phase 1 posts at most own_max; phase 2 posts one per locally
    !   owned valence and conduction band.
    nreq_send = own_max*peinf%npes
    nreq_recv = max(cwfn%nband, peinf%nvownactual + peinf%ncownactual)

    ik = 1 ! Fixing one kpt for now
    is = 1 ! And the spin

    vwfn%idx_kp = ik
    cwfn%idx_kp = ik

    if ( peinf%inode == 0 ) call timing%start(timing%opt_fft_init)

    ! FHJ: Only root is guaranteed to have at least one band.
    if (peinf%inode==0) ng0 = intwfnv%ng(ik)
    call MPI_Bcast(ng0, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, mpierr)
    SAFE_ALLOCATE(isort0, (ng0))
    if (peinf%inode==0) isort0(1:ng0) = intwfnv%isort(1:ng0, ik)
    call MPI_Bcast(isort0(1), ng0, MPI_INTEGER, 0, MPI_COMM_WORLD, mpierr)

    ! FHJ: prepare the sorting stuff
    SAFE_ALLOCATE(ind,(ng0))
    SAFE_ALLOCATE(ph, (ng0))
    SAFE_ALLOCATE(wfn_isort, (gvec%ng))
    SAFE_ALLOCATE(isorti, (gvec%ng))
    SAFE_ALLOCATE(ekin, (gvec%ng))
    SAFE_ALLOCATE(tmp_wfn, (ng0))
    call kinetic_energies(gvec, crys%bdot, ekin)
    call sortrx(gvec%ng, ekin, wfn_isort, gvec=gvec%components)
    do ig = 1, gvec%ng
      isorti(wfn_isort(ig)) = ig
    enddo
    do ig = 1, ng0
      isorti(isort0(ig)) = ig
    enddo
    call gmap(gvec,syms,ng0,1,(/0,0,0/),wfn_isort,isorti,ind,ph,.true.)
    vwfn%ngv=ng0
    cwfn%ngc=ng0

    ! FHJ: Work buffers
    SAFE_ALLOCATE(band_owners, (cwfn%nband))
    SAFE_ALLOCATE(my_bands, (cwfn%nband))
    SAFE_ALLOCATE(req_send, (nreq_send))
    SAFE_ALLOCATE(req_recv, (nreq_recv))
    SAFE_ALLOCATE(bufs_wfn, (ng0, own_max))

    ! FHJ: Re-distribute WFNs
    recv_cnt = 0
    send_cnt = 0
    do ib = 1, cwfn%nband
      receiver = mod(ib-1, peinf%npes)
      band_owners(ib) = receiver

      ! FHJ: Receiving part
      if (receiver==peinf%inode) then
        recv_cnt = recv_cnt + 1
        my_bands(recv_cnt) = ib
        call MPI_Irecv( bufs_wfn(1, recv_cnt), ng0, MPI_SCALAR, MPI_ANY_SOURCE, &
          ib, MPI_COMM_WORLD, req_recv(recv_cnt), mpierr )
      endif

      ! FHJ: Sending part
      ! Note: should_send increments send_cnt when it returns .true.
      if (ib<=vwfn%nband) then
        local_band_idx = peinf%indexv(ib)+(ik-1)*peinf%nvownactual
        if (should_send(peinf%does_it_ownv)) &
          call MPI_Isend( intwfnv%cg(1, local_band_idx, is), ng0, MPI_SCALAR, receiver, ib, &
          MPI_COMM_WORLD, req_send(send_cnt), mpierr )
      else
        local_band_idx = peinf%indexc(ib-vwfn%nband)+(ik-1)*peinf%ncownactual
        if (should_send(peinf%does_it_ownc)) &
          call MPI_Isend( intwfnc%cg(1, local_band_idx, is), ng0, MPI_SCALAR, receiver, ib, &
          MPI_COMM_WORLD, req_send(send_cnt), mpierr )
      endif
    enddo

    if (recv_cnt>0) then
      call MPI_Waitall(recv_cnt, req_recv, MPI_STATUSES_IGNORE, mpierr)
    endif
    if (send_cnt>0) then
      call MPI_Waitall(send_cnt, req_send, MPI_STATUSES_IGNORE, mpierr)
    endif

    if(peinf%inode==0) call timing%stop(timing%opt_fft_init)
    if(peinf%inode==0) call timing%start(timing%opt_fft_fft)

    ! FHJ: TODO - change genwf_mpi.f90 so that it doesn`t use cwfn%el
    ! (see genwf_mpi.f90)

    ! FHJ: Deallocate original wavefunctions (but only the wfn!)
    call free_wfns(pol, intwfnv, intwfnvq, intwfnc, .false.)

    call logit('doing FFTs')

    ! FHJ: this will be more complicated to implement for more than 1 kpt
    if (pol%min_fftgrid) then
      ! pol%isrtx only borrows wfn_isort while the grid is computed
      pol%isrtx => wfn_isort
      pol%nmtx = gcutoff(gvec%ng, ekin, pol%isrtx, pol%ecuts)
      call get_eps_fftgrid(pol, gvec)
      nullify(pol%isrtx)
    endif
    ! FHJ: Do all the FFTs
    call setup_FFT_sizes(pol%FFTgrid,Nfft,scale)
    fft_size = Nfft(1)*Nfft(2)*Nfft(3)

    SAFE_ALLOCATE( vwfn%wfn_fft, (Nfft(1), Nfft(2), Nfft(3), peinf%nvownactual) )
    SAFE_ALLOCATE( cwfn%wfn_fft, (Nfft(1), Nfft(2), Nfft(3), peinf%ncownactual) )
    SAFE_ALLOCATE( work_ffts, (Nfft(1), Nfft(2), Nfft(3), recv_cnt) )

    if (peinf%inode==0) write(0,*)
    do ib = 1, recv_cnt
      do ig = 1, ng0
        tmp_wfn(ig) = bufs_wfn(ind(ig), ib)*ph(ig)
      enddo
      call put_into_fftbox(ng0, tmp_wfn, gvec%components, wfn_isort, work_ffts(:,:,:,ib), Nfft)
      call do_FFT(work_ffts(:,:,:,ib), Nfft, 1)
      ! Valence bands are stored complex-conjugated
      if (my_bands(ib)<=vwfn%nband) then
        call conjg_fftbox(work_ffts(:,:,:,ib), Nfft)
      endif
    enddo

    if(peinf%inode==0) call timing%stop(timing%opt_fft_fft)
    if(peinf%inode==0) call timing%start(timing%opt_fft_comm_fft)

    call logit('done with FFTs')

    ! Post the receives for every band that this rank owns
    do ib_loc = 1, peinf%nvownactual
      ib = peinf%invindexv(ib_loc)
      call MPI_Irecv( vwfn%wfn_fft(1,1,1,ib_loc), fft_size, MPI_COMPLEX_DPC, band_owners(ib), &
        ib, MPI_COMM_WORLD, req_recv(ib_loc), mpierr )
    enddo
    do ib_loc = 1, peinf%ncownactual
      ib = peinf%invindexc(ib_loc)+vwfn%nband
      call MPI_Irecv( cwfn%wfn_fft(1,1,1,ib_loc), fft_size, MPI_COMPLEX_DPC, band_owners(ib), &
        ib, MPI_COMM_WORLD, req_recv(peinf%nvownactual+ib_loc), mpierr )
    enddo

    ! FHJ: And send them back using point-to-point communication.
    send_cnt = 0
    do ib_loc = 1, recv_cnt
      ib = my_bands(ib_loc)

      if (ib<=vwfn%nband) then
        do ipe = 0, peinf%npes-1
          if ( peinf%does_it_ownv(ib,ipe+1) ) then
            send_cnt = send_cnt + 1
            call MPI_Isend( work_ffts(1,1,1,ib_loc), fft_size, MPI_COMPLEX_DPC, ipe, &
              ib, MPI_COMM_WORLD, req_send(send_cnt), mpierr )
          endif
        enddo
      else
        do ipe = 0, peinf%npes-1
          if ( peinf%does_it_ownc(ib-vwfn%nband,ipe+1) ) then
            send_cnt = send_cnt + 1
            call MPI_Isend( work_ffts(1,1,1,ib_loc), fft_size, MPI_COMPLEX_DPC, ipe, &
              ib, MPI_COMM_WORLD, req_send(send_cnt), mpierr )
          endif
        enddo
      endif
    enddo

    recv_cnt = peinf%nvownactual + peinf%ncownactual
    if (recv_cnt>0) then
      call MPI_Waitall(recv_cnt, req_recv, MPI_STATUSES_IGNORE, mpierr)
    endif
    if (send_cnt>0) then
      call MPI_Waitall(send_cnt, req_send, MPI_STATUSES_IGNORE, mpierr)
    endif

    if(peinf%inode==0) call timing%stop(timing%opt_fft_comm_fft)

    SAFE_DEALLOCATE(work_ffts)
    SAFE_DEALLOCATE(bufs_wfn)
    SAFE_DEALLOCATE(req_send)
    SAFE_DEALLOCATE(req_recv)
    SAFE_DEALLOCATE(my_bands)
    SAFE_DEALLOCATE(band_owners)
    SAFE_DEALLOCATE(isort0)

    ! and free sorting stuff
    SAFE_DEALLOCATE(ind)
    SAFE_DEALLOCATE(ph)
    SAFE_DEALLOCATE(wfn_isort)
    SAFE_DEALLOCATE(isorti)
    SAFE_DEALLOCATE(ekin)
    SAFE_DEALLOCATE(tmp_wfn)

#endif

    ! These two calls pair with the logit and timing start issued before
    ! the parallel-only section, so they must run in every build.
    call logit('done generating real-space wavefunctions')

    if(peinf%inode==0) call timing%stop(timing%opt_fft)

    POP_SUB(genwf_FFT)
    return

  contains

    !> Decides whether this rank must send the coefficients of the host
    !! band ib in the first communication phase. The sender is the lowest
    !! rank flagged in own_arr for that band.
    !! Side effect: increments the host counter send_cnt when the result
    !! is .true., so that the caller can use it as the request index.
    logical function should_send(own_arr)
      logical, intent(in) :: own_arr(:,:) !< ownership table (band, rank+1)

      integer :: sender, ib_, jpe

      PUSH_SUB(genwf_FFT.should_send)

      ! Conduction bands are indexed relative to the first conduction band
      ib_ = ib
      if (ib_ > vwfn%nband) ib_ = ib_ - vwfn%nband

      should_send = .false.
#ifdef MPI
      sender = -1
      do jpe = 0, peinf%npes-1
        if (own_arr(ib_, jpe+1)) then
          sender = jpe
          exit
        endif
      enddo

      if (sender==-1) call die('No sender found!')

      if (sender==peinf%inode) then
        send_cnt = send_cnt + 1
        should_send = .true.
      endif
#endif

      POP_SUB(genwf_FFT.should_send)
      return

    end function should_send

  end subroutine genwf_FFT

  !> Generic routine that generates wavefunctions
  !! Thin wrapper around genwf_mpi with logging and timing. The call is
  !! skipped when iv is larger than the number of valence bands owned by
  !! this rank (peinf%nvownactual).
  subroutine genwf_gen(syms,gvec,crys,kp,kpq,irk,rk,qq,vwfn,pol,cwfn,use_wfnq,intwfnv,intwfnvq,intwfnc,iv)
    type (symmetry), intent(in) :: syms
    type (gspace), intent(in) :: gvec
    type (crystal), intent(in) :: crys
    type (kpoints), target, intent(in) :: kp
    type (kpoints), target, intent(in) :: kpq
    integer, intent(in) :: irk
    real(DP), intent(in) :: rk(3)
    real(DP), intent(in) :: qq(3)
    type (valence_wfns), intent(inout) :: vwfn
    type (polarizability), intent(in) :: pol
    type (conduction_wfns), intent(inout) :: cwfn
    logical, intent(in) :: use_wfnq
    type (int_wavefunction), intent(in) :: intwfnv
    type (int_wavefunction), intent(in) :: intwfnvq
    type (int_wavefunction), intent(in) :: intwfnc
    integer, intent(in) :: iv !< local valence band index

    integer :: ivr

    PUSH_SUB(genwf_gen)

    call logit('calling genwf')
    if(peinf%inode==0) call timing%start(timing%genwf)
    if (iv <= peinf%nvownactual) then
      call genwf_mpi(syms,gvec,crys,kp,kpq,irk,rk,qq,vwfn,pol,cwfn,use_wfnq,intwfnv,intwfnvq,intwfnc,iv)
    endif
    if(peinf%inode==0) call timing%stop(timing%genwf)
    call logit('done genwf')

    POP_SUB(genwf_gen)
    return

  end subroutine genwf_gen

  !> Deallocate all "intermediate" wavefunctions
  !! The components of intwfnvq are only touched when pol%need_WFNq is set.
  subroutine free_wfns(pol, intwfnv, intwfnvq, intwfnc, free_all)
    type(polarizability), intent(in) :: pol
    type(int_wavefunction), intent(inout) :: intwfnv, intwfnvq, intwfnc
    logical, intent(in) :: free_all !< if .false., then only %ng and %isort will be preserved

    PUSH_SUB(free_wfns)

    if (free_all) then
      SAFE_DEALLOCATE_P(intwfnv%ng)
      SAFE_DEALLOCATE_P(intwfnv%isort)
    endif
    SAFE_DEALLOCATE_P(intwfnv%cg)
    SAFE_DEALLOCATE_P(intwfnv%qk)
    if (pol%need_WFNq) then
      if (free_all) then
        SAFE_DEALLOCATE_P(intwfnvq%ng)
        SAFE_DEALLOCATE_P(intwfnvq%isort)
      endif
      SAFE_DEALLOCATE_P(intwfnvq%cg)
      SAFE_DEALLOCATE_P(intwfnvq%qk)
    endif
    if (free_all) then
      SAFE_DEALLOCATE_P(intwfnc%ng)
      SAFE_DEALLOCATE_P(intwfnc%isort)
    endif
    SAFE_DEALLOCATE_P(intwfnc%cg)
    SAFE_DEALLOCATE_P(intwfnc%cbi)
    SAFE_DEALLOCATE_P(intwfnc%qk)

    POP_SUB(free_wfns)
    return

  end subroutine free_wfns

end module genwf_eps_m