1!==========================================================================
2!
3! Module:
4!
5! genwf_eps      Originally By FHJ             Last Modified 04/2012 (FHJ)
6!
7!
8!==========================================================================
9
10#include "f_defs.h"
11
12module genwf_eps_m
13
14  use global_m
15  use input_utils_m
16  use fftw_m
17  use genwf_mpi_m
18  use gmap_m
19  use sort_m
20  use timing_m, only: timing => epsilon_timing
21
22  implicit none
23
24  private
25
26  !> communicator object for the WFN FFTs
27  type wfn_FFT_comm_t
28    logical :: done !< Have we received all the buffers?
29    integer, pointer :: req_recvv(:), req_recvc(:) !< Array of MPI_REQUEST
30    integer, pointer :: req_sendv(:), req_sendc(:) !< Array of MPI_REQUEST
31    integer :: recv_cntv, recv_cntc !< Number of requests
32    integer :: send_cntv, send_cntc !< Number of requests
33    integer :: nv !< Number of valence bands
34  end type wfn_FFT_comm_t
35
36  public :: genwf_gen, genwf_FFT, free_wfns, genwf_lvl2, &
37    genwf_FFT_Isend, genwf_FFT_Wait, wfn_FFT_comm_t, &
38    get_wfn_fftgrid, get_eps_fftgrid
39
40contains
41
42  !> FHJ: Figure out what is the min/max gvec components give an isort array
43  !! Note: you should manually initialize box_min and box_max to zero!
44  subroutine get_gvecs_bounds(gvec, ng, isort, box_min, box_max)
45    type(gspace), intent(in) :: gvec
46    integer, intent(in) :: ng
47    integer, intent(in) :: isort(:)
48    integer, intent(inout) :: box_min(3), box_max(3)
49
50    integer :: ig
51
52    PUSH_SUB(get_gvecs_bounds)
53
54    do ig=1,ng
55      box_min(1:3) = min(box_min(1:3),  gvec%components(1:3, isort(ig)))
56      box_max(1:3) = max(box_max(1:3),  gvec%components(1:3, isort(ig)))
57    enddo
58
59    POP_SUB(get_gvecs_bounds)
60    return
61
62  end subroutine get_gvecs_bounds
63
64  !> FHJ: Figure out the minimum fftbox that holds all the WFNs
65  subroutine get_wfn_fftgrid(pol, gvec, kp, intwfn)
66    type(polarizability), intent(inout) :: pol
67    type(gspace), intent(in) :: gvec
68    type(kpoints), intent(in) :: kp
69    type(int_wavefunction), intent(in) :: intwfn
70
71    integer :: ik, wfn_box_min(3), wfn_box_max(3)
72
73    PUSH_SUB(get_wfn_fftgrid)
74
75    wfn_box_min(:) = 0; wfn_box_max(:) = 0
76    do ik=1,kp%nrk
77      call get_gvecs_bounds(gvec, intwfn%ng(ik), intwfn%isort(:,ik), wfn_box_min, wfn_box_max)
78    enddo
79    pol%WFN_FFTgrid(1:3) = wfn_box_max(1:3) - wfn_box_min(1:3) + 1
80    if (peinf%verb_debug .and. peinf%inode==0) then
81      write(6,*) 'WFN min. FFT grid:',pol%WFN_FFTgrid
82    endif
83
84    POP_SUB(get_wfn_fftgrid)
85    return
86
87  end subroutine get_wfn_fftgrid
88
89  !> FHJ: Figure out the minimum fftbox that allows us to convolve the WFNs within
90  !! an energy window of nmtx G vectors.
91  subroutine get_eps_fftgrid(pol, gvec)
92    type(polarizability), intent(inout) :: pol
93    type(gspace), intent(in) :: gvec
94    integer :: eps_box_min(3), eps_box_max(3)
95
96    PUSH_SUB(get_eps_fftgrid)
97
98    eps_box_min(:) = 0; eps_box_max(:) = 0
99    call get_gvecs_bounds(gvec, pol%nmtx, pol%isrtx, eps_box_min, eps_box_max)
100    ! FHJ: Note: the amount of padding is actually N_sig + N_window - 1
101    pol%FFTgrid(1:3) = pol%WFN_FFTgrid(1:3) + (eps_box_max(1:3) - eps_box_min(1:3))
102    pol%FFTgrid(1:3) = min(pol%FFTgrid(1:3), gvec%FFTgrid(1:3))
103    if (peinf%verb_debug .and. peinf%inode==0) then
104      write(6,'(1x,a,3(1x,i0))') 'Original FFT grid:', gvec%FFTgrid
105      write(6,'(1x,a,3(1x,i0))') 'Minimal  FFT grid:', pol%FFTgrid
106    endif
107
108    POP_SUB(get_eps_fftgrid)
109    return
110
111  end subroutine get_eps_fftgrid
112
113  !> FHJ: to be used internally with genwf_FFT_Isend
114  subroutine do_my_FFTs(this,gvec,Nfft,wfn_fft,intwfn,my_bands,my_cnt,ng,tmp_wfn,isort,ind,ph,is_val)
115    type (wfn_FFT_comm_t), intent(inout), target :: this !< communicator object for the WFN FFTs
116    type(gspace), intent(in) :: gvec
117    complex(DPC), intent(inout) :: wfn_fft(:,:,:,:)
118    integer, intent(in) :: Nfft(3)
119    type(int_wavefunction), intent(in) :: intwfn
120    integer, intent(in) :: my_bands(:)  !< my bands
121    integer, intent(in) :: my_cnt       !< number of FFTs to do = sizeof(my_bands)
122    integer, intent(in) :: ng           !< number of gvectors for the k-pt in question
123    SCALAR,  intent(inout) :: tmp_wfn(:)!< buffer to reorder WFN using ind and ph
124    integer, intent(in) :: isort(:)     !< wfn isort
125    integer, intent(in) :: ind(:)
126    SCALAR,  intent(in) :: ph(:)
127    logical, intent(in) :: is_val       !< .true. to take the conjg_fftbox
128
129    integer :: is, ig, fft_size
130    integer :: ib_list, ib_local, ib, iproc, offset
131    integer, pointer :: invindex(:), send_cnt, req_send(:)
132    logical, pointer :: does_it_own(:,:)
133
134    PUSH_SUB(do_my_FFTs)
135
136    is = 1
137    if ( is_val ) then
138      invindex => peinf%invindexv
139      does_it_own => peinf%does_it_ownv
140      req_send => this%req_sendv
141      send_cnt => this%send_cntv
142      offset = 0
143    else
144      invindex => peinf%invindexc
145      does_it_own => peinf%does_it_ownc
146      req_send => this%req_sendc
147      send_cnt => this%send_cntc
148      offset = this%nv
149    endif
150    fft_size = product(Nfft(1:3))
151
152    do ib_list = 1, my_cnt
153      ib_local = my_bands(ib_list)
154      do ig=1,ng
155        tmp_wfn(ig) = intwfn%cg(ind(ig), ib_local, is)*ph(ig)
156      enddo
157      if (peinf%inode==0) call timing%start(timing%opt_fft_fft)
158      call put_into_fftbox(ng, tmp_wfn, gvec%components, &
159        isort, wfn_fft(:,:,:,ib_local), Nfft)
160      call do_FFT( wfn_fft(:,:,:,ib_local), Nfft, 1)
161      if ( is_val ) call conjg_fftbox( wfn_fft(:,:,:,ib_local), Nfft )
162      if (peinf%inode==0) call timing%stop(timing%opt_fft_fft)
163
164#ifdef MPI
165      ! FHJ: distribute all real-space WFNs
166      if(peinf%inode==0) call timing%start(timing%opt_fft_comm_fft)
167      ib = invindex(ib_local)
168      do iproc=0, peinf%npes-1
169        if (does_it_own(ib, iproc+1).and.iproc/=peinf%inode) then
170          send_cnt = send_cnt + 1
171          call MPI_Isend( wfn_fft(1,1,1,ib_local), fft_size, MPI_COMPLEX_DPC, iproc, &
172            offset + ib, MPI_COMM_WORLD, req_send(send_cnt), mpierr )
173        endif
174      enddo
175      if(peinf%inode==0) call timing%stop(timing%opt_fft_comm_fft)
176#endif
177    enddo
178
179    POP_SUB(do_my_FFTs)
180    return
181
182  end subroutine do_my_FFTs
183
184  !> Generates all the real-space wavefunctions. Used only if pol%os_opt_fft==2
185  !! This version avoids communication, and it`s under development!
186  !!TODO`s:
187  !! (1) we are just supporting one spin and one kpt/qpt.
188  !! (2) support serial code
189  subroutine genwf_FFT_Isend(this,crys,gvec,syms,kp,kpq,vwfn,pol,cwfn,intwfnv,intwfnvq,intwfnc)
190    type (wfn_FFT_comm_t), intent(inout) :: this !< communicator object for the WFN FFTs
191    type (crystal), intent(in) :: crys
192    type (gspace), intent(in) :: gvec
193    type (symmetry), intent(in) :: syms
194    type (kpoints), target, intent(in) :: kp
195    type (kpoints), target, intent(in) :: kpq
196    type (valence_wfns), intent(inout) :: vwfn
197    type (polarizability), intent(inout) :: pol
198    type (conduction_wfns), intent(inout) :: cwfn
199    type (int_wavefunction), intent(inout) :: intwfnv
200    type (int_wavefunction), intent(inout) :: intwfnvq
201    type (int_wavefunction), intent(inout) :: intwfnc
202
203    integer :: npes_per_pool, nc_groups, nproc_max, nc
204    integer, allocatable :: nv_bands(:), v_owners(:)
205    integer :: Nfft(3), fft_size
206    real(DP) :: scale
207    integer, allocatable :: my_vbands(:)
208    integer :: my_vcnt, iv, iv_local
209    integer :: ipool, isubrank, inode, ioffset
210    integer, allocatable :: grp_global_ranks(:), ntot_bands(:), &
211      grp_local_owners(:)
212    integer, allocatable :: my_cbands(:)
213    integer :: my_ccnt, ic, ic_local
214    integer :: my_grp_rank, grp_nprocs
215    integer :: min_bands, iproc, iworker
216    integer :: ik, is, ib
217
218    !sort stuff
219    SCALAR, allocatable :: tmp_wfn(:)
220    SCALAR, allocatable :: ph(:)
221    real(DP), allocatable :: ekin(:)
222    integer, allocatable :: ind(:), isorti(:)
223    integer, allocatable, target :: wfn_isort(:)
224    integer :: ng0, ig
225
226    PUSH_SUB(genwf_FFT_Isend)
227
228    call logit('generating all real-space wavefunctions')
229
230    if(peinf%inode==0) call timing%start(timing%opt_fft)
231
232    if(pol%nq>1.or.kp%nrk>1.or.kpq%nrk>0.or.pol%need_WFNq) &
233      call die('FFT opt. level 2 only works for 1 qpt and 1 kpt, and without WFNq',&
234      only_root_writes=.true.)
235
236#ifdef MPI
237
238    ik = 1 ! Fixing one kpt for now
239    is = 1 ! And the spin
240    vwfn%idx_kp = ik
241    cwfn%idx_kp = ik
242    nc = cwfn%nband - vwfn%nband  ! `real` number of conduction bands
243    npes_per_pool = (peinf%npes/peinf%npools) ! number of processors per pool
244    nc_groups = (nc + peinf%ncownmax - 1)/(peinf%ncownmax) ! number of conduction groups
245    nproc_max = (peinf%npes + nc_groups - 1)/(nc_groups) ! max num. of proc. per conduction group
246
247    ! Basic initialization of the communicator object
248    this%recv_cntv = 0; this%recv_cntc = 0
249    this%send_cntv = 0; this%send_cntc = 0
250    this%done = .false.
251    this%nv = vwfn%nband
252    SAFE_ALLOCATE(this%req_recvv, (vwfn%nband*npes_per_pool))
253    SAFE_ALLOCATE(this%req_recvc, (nc*nproc_max))
254    SAFE_ALLOCATE(this%req_sendv, (vwfn%nband*npes_per_pool))
255    SAFE_ALLOCATE(this%req_sendc, (nc*nproc_max))
256
257    if(peinf%inode==0) call timing%start(timing%opt_fft_init)
258
259    ! FHJ: prepare the sorting buffers
260    ng0 = intwfnv%ng(ik) ! We all have at least one valence band
261    SAFE_ALLOCATE(ind,(ng0))
262    SAFE_ALLOCATE(ph, (ng0))
263    SAFE_ALLOCATE(wfn_isort, (gvec%ng))
264    SAFE_ALLOCATE(isorti, (gvec%ng))
265    SAFE_ALLOCATE(ekin, (gvec%ng))
266    call kinetic_energies(gvec, crys%bdot, ekin)
267    call sortrx(gvec%ng, ekin, wfn_isort, gvec=gvec%components)
268
269    do ig=1,gvec%ng
270      isorti(wfn_isort(ig)) = ig
271    enddo
272    do ig=1,ng0
273      isorti(intwfnv%isort(ig, ik)) = ig
274    enddo
275    call gmap(gvec,syms,ng0,1,(/0,0,0/),wfn_isort,isorti,ind,ph,.true.)
276    vwfn%ngv=ng0
277    cwfn%ngc=ng0
278
279    ! FHJ: this will be more complicated to implement for more than 1 kpt
280    if (pol%min_fftgrid) then
281      pol%isrtx => wfn_isort
282      pol%nmtx = gcutoff(gvec%ng, ekin, pol%isrtx, pol%ecuts)
283      call get_eps_fftgrid(pol, gvec)
284      nullify(pol%isrtx)
285    endif
286    call setup_FFT_sizes(pol%FFTgrid,Nfft,scale)
287    fft_size = Nfft(1)*Nfft(2)*Nfft(3)
288
289    ! FHJ: Distribute VALENCE bands !
290    !--------------------------------
291    ! NOTE: we index the processors wrt their real MPI ranks (peinf%inode)
292    !       we refer all the val bands wrt the local index in my_vbands
293
294    ! Number of valence bands that a particular processor owns
295    SAFE_ALLOCATE(nv_bands, (peinf%npes))
296    ! Who owns a particular valence band?
297    SAFE_ALLOCATE(v_owners, (vwfn%nband + pol%ncrit))
298    ! Which valence bands do I own?
299    SAFE_ALLOCATE(my_vbands, (peinf%nvownmax))
300    nv_bands(:) = 0; v_owners(:) = 0
301    my_vbands(:) = 0
302    my_vcnt = 0
303
304    do iv=1, vwfn%nband + pol%ncrit
305      ipool = (iv-1)/peinf%nvownmax ! Get pool for the band
306      ! Distribute bands to the processors in a round-robin way. We add an
307      ! offset so that we get a good load balance for the cond. bands later on.
308      isubrank = mod(iv-1, peinf%nvownmax)
309      ioffset = ipool*peinf%nvownmax
310      inode = mod(isubrank+ioffset, npes_per_pool) + ipool*npes_per_pool
311      v_owners(iv) = inode
312      ! Keep track of the # of times that each processors got a band:
313      nv_bands(inode+1) = nv_bands(inode+1) + 1
314      if (peinf%inode == inode) then
315        my_vcnt = my_vcnt + 1
316        my_vbands(my_vcnt) = isubrank + 1
317      endif
318    enddo
319
320    ! FHJ: Distribute CONDUCTION bands !
321    !----------------------------------!
322    ! NOTE: we use *local ranks* to organize the processors in a particular cond. group
323    !       we refer all the cond. bands wrt the local index in my_cbands
324
325    ! Ranks of the processors that participate in this conduction group
326    SAFE_ALLOCATE(grp_global_ranks, (nproc_max))
327    grp_global_ranks(:) = -1
328    SAFE_ALLOCATE(ntot_bands, (nproc_max)) ! Total number of bands per processor
329    ntot_bands(:) = 0
330    ! Who owns a particular conduction band? (in terms of local workers)
331    SAFE_ALLOCATE(grp_local_owners, (peinf%ncownactual))
332    grp_local_owners(:) = 0
333    ! Which conduction bands do I own?
334    SAFE_ALLOCATE(my_cbands, (peinf%ncownactual))
335    my_cbands(:) = 0
336    my_ccnt = 0
337
338    ! This will be set to my rank within the group of conduction bands, starting at 0.
339    my_grp_rank = -1
340
341    ! Create list of all processors in that same group
342    ! want: grp_global_ranks(:) = [icgroup, icgroup + npes_per_group, ...]
343    grp_global_ranks(:) = 0
344    grp_nprocs = 0
345    ib = peinf%invindexc(1) ! index of first band that I own.
346    if (ib>0) then
347      do iproc=0, peinf%npes-1
348        if (peinf%does_it_ownc(ib, iproc+1)) then
349          grp_nprocs = grp_nprocs + 1
350          grp_global_ranks(grp_nprocs) = iproc
351          if (iproc==peinf%inode) then
352            my_grp_rank = grp_nprocs - 1
353          endif
354        endif
355      enddo
356    endif
357
358    ! Initialize list of the number of bands that each processor has
359    do iproc = 1, grp_nprocs
360      ntot_bands(iproc) = nv_bands(grp_global_ranks(iproc)+1)
361    enddo
362
363    do ic_local = 1, peinf%ncownactual
364      ! Get smallest index of ntot_bands array -> iworker
365      ! note: rank = grp_local_ranks(iworker)
366      iworker = 0
367      min_bands = cwfn%nband + 1
368      do iproc=1, grp_nprocs
369        if (ntot_bands(iproc)<min_bands) then
370          min_bands = ntot_bands(iproc)
371          iworker = iproc - 1
372        endif
373      enddo
374
375      ! add ic_local to list of bands that iworker owns
376      grp_local_owners(ic_local) = iworker
377      ntot_bands(iworker+1) = ntot_bands(iworker+1) + 1
378      if (my_grp_rank == iworker) then
379        my_ccnt = my_ccnt + 1
380        my_cbands(my_ccnt) = ic_local
381      endif
382    enddo
383
384    if(peinf%inode==0) call timing%stop(timing%opt_fft_init)
385    if(peinf%inode==0) call timing%start(timing%opt_fft_comm_fft)
386
387    ! FHJ: Non-blocking Recvs !
388    !-------------------------!
389
390    ! Recv valence bands
391    SAFE_ALLOCATE( vwfn%wfn_fft, (Nfft(1), Nfft(2), Nfft(3), peinf%nvownactual) )
392    do iv_local=1, peinf%nvownactual
393      iv = peinf%invindexv(iv_local)
394      if (v_owners(iv)/=peinf%inode) then
395        this%recv_cntv = this%recv_cntv + 1
396        call MPI_Irecv( vwfn%wfn_fft(1,1,1,iv_local), fft_size, MPI_COMPLEX_DPC, v_owners(iv), &
397          iv, MPI_COMM_WORLD, this%req_recvv(this%recv_cntv), mpierr )
398      endif
399    enddo
400    ! Recv conduction bands
401    SAFE_ALLOCATE( cwfn%wfn_fft, (Nfft(1), Nfft(2), Nfft(3), peinf%ncownactual) )
402    do ic_local=1, peinf%ncownactual
403      ic = peinf%invindexc(ic_local)
404      if (grp_local_owners(ic_local)/=my_grp_rank) then
405        this%recv_cntc = this%recv_cntc + 1
406        iproc = grp_global_ranks(grp_local_owners(ic_local)+1)
407        call MPI_Irecv( cwfn%wfn_fft(1,1,1,ic_local), fft_size, MPI_COMPLEX_DPC, iproc, &
408          vwfn%nband + ic, MPI_COMM_WORLD, this%req_recvc(this%recv_cntc), mpierr )
409      endif
410    enddo
411
412    if(peinf%inode==0) call timing%stop(timing%opt_fft_comm_fft)
413
414    ! FHJ: Do the FFTs + Isend`s !
415    !----------------------------!
416
417    call logit('doing FFTs')
418
419    SAFE_ALLOCATE(tmp_wfn, (ng0))
420    call do_my_FFTs(this,gvec,Nfft,vwfn%wfn_fft,intwfnv,my_vbands,my_vcnt,&
421      ng0,tmp_wfn,wfn_isort,ind,ph,.true.)
422    call do_my_FFTs(this,gvec,Nfft,cwfn%wfn_fft,intwfnc,my_cbands,my_ccnt,&
423      ng0,tmp_wfn,wfn_isort,ind,ph,.false.)
424    SAFE_DEALLOCATE(tmp_wfn)
425
426    ! FHJ: Deallocate original wavefunctions (but only the wfn!)
427    call free_wfns(pol, intwfnv, intwfnvq, intwfnc, .false.)
428
429    call logit('done with FFTs')
430
431    ! Free comm buffers
432    SAFE_DEALLOCATE(nv_bands)
433    SAFE_DEALLOCATE(v_owners)
434    SAFE_DEALLOCATE(my_vbands)
435    SAFE_DEALLOCATE(grp_global_ranks)
436    SAFE_DEALLOCATE(ntot_bands)
437    SAFE_DEALLOCATE(grp_local_owners)
438    SAFE_DEALLOCATE(my_cbands)
439
440    ! and free sorting stuff
441    SAFE_DEALLOCATE(ind)
442    SAFE_DEALLOCATE(ph)
443    SAFE_DEALLOCATE(wfn_isort)
444    SAFE_DEALLOCATE(isorti)
445    SAFE_DEALLOCATE(ekin)
446
447#endif
448
449    call logit('done generating real-space wavefunctions')
450
451    if(peinf%inode==0) call timing%stop(timing%opt_fft)
452
453    POP_SUB(genwf_FFT_Isend)
454    return
455
456  end subroutine genwf_FFT_Isend
457
458
459  ! FHJ: call me after genwf_FFT_Isend, but just before you actually need the data
460  subroutine genwf_FFT_Wait(this)
461    type (wfn_FFT_comm_t), intent(inout) :: this !< communicator object for the WFN FFTs
462
463    PUSH_SUB(genwf_FFT_Wait)
464
465#ifdef MPI
466    if(peinf%inode==0) call timing%start(timing%opt_fft)
467    if(peinf%inode==0) call timing%start(timing%opt_fft_comm_fft)
468    if (this%recv_cntv>0) then
469      call MPI_Waitall(this%recv_cntv, this%req_recvv,  MPI_STATUSES_IGNORE, mpierr)
470    endif
471    if (this%send_cntv>0) then
472      call MPI_Waitall(this%send_cntv, this%req_sendv,  MPI_STATUSES_IGNORE, mpierr)
473    endif
474    if (this%recv_cntc>0) then
475      call MPI_Waitall(this%recv_cntc, this%req_recvc,  MPI_STATUSES_IGNORE, mpierr)
476    endif
477    if (this%send_cntc>0) then
478      call MPI_Waitall(this%send_cntc, this%req_sendc,  MPI_STATUSES_IGNORE, mpierr)
479    endif
480    if(peinf%inode==0) call timing%stop(timing%opt_fft_comm_fft)
481
482    SAFE_DEALLOCATE_P(this%req_recvv)
483    SAFE_DEALLOCATE_P(this%req_sendv)
484    SAFE_DEALLOCATE_P(this%req_recvc)
485    SAFE_DEALLOCATE_P(this%req_sendc)
486    this%done = .true.
487    if(peinf%inode==0) call timing%stop(timing%opt_fft)
488#endif
489
490    POP_SUB(genwf_FFT_Wait)
491    return
492
493  end subroutine genwf_FFT_Wait
494
495
496  subroutine genwf_lvl2(kp,kpq,vwfn,pol,cwfn)
497    type (kpoints), target, intent(in) :: kp
498    type (kpoints), target, intent(in) :: kpq
499    type (valence_wfns), intent(inout) :: vwfn
500    type (polarizability), intent(in) :: pol
501    type (conduction_wfns), intent(inout) :: cwfn
502
503    type(kpoints), pointer :: kp_point
504
505    PUSH_SUB(genwf_lvl2)
506
507    if(pol%need_WFNq) then ! FIXME I think this is wrong if pol%nq1>0
508      kp_point => kpq
509    else
510      kp_point => kp
511    endif
512
513    SAFE_ALLOCATE(vwfn%ev, (vwfn%nband+pol%ncrit, kp%nspin))
514    SAFE_ALLOCATE(cwfn%ec, (cwfn%nband,kp%nspin))
515    vwfn%ev(1:vwfn%nband+pol%ncrit,1:kp%nspin) = &
516      kp_point%el(1:vwfn%nband+pol%ncrit, vwfn%idx_kp, 1:kp%nspin)
517    cwfn%ec(1:cwfn%nband,1:kp%nspin) = &
518      kp%el(1:cwfn%nband, cwfn%idx_kp, 1:kp%nspin)
519
520    POP_SUB(genwf_lvl2)
521    return
522  end subroutine genwf_lvl2
523
524
525  !> Generates all the real-space wavefunctions. Used only if pol%os_opt_fft==2
526  !!TODO`s:
527  !! (1) we are just supporting one spin and one kpt/qpt.
528  !! (2) communication can be reduced if we distribute the WFNs in a smarter way
529  !! (3) support serial code
530  subroutine genwf_FFT(crys,gvec,syms,kp,kpq,vwfn,pol,cwfn,intwfnv,intwfnvq,intwfnc)
531    type (crystal), intent(in) :: crys
532    type (gspace), intent(in) :: gvec
533    type (symmetry), intent(in) :: syms
534    type (kpoints), target, intent(in) :: kp
535    type (kpoints), target, intent(in) :: kpq
536    type (valence_wfns), intent(inout) :: vwfn
537    type (polarizability), intent(inout) :: pol
538    type (conduction_wfns), intent(inout) :: cwfn
539    type (int_wavefunction), intent(inout) :: intwfnv
540    type (int_wavefunction), intent(inout) :: intwfnvq
541    type (int_wavefunction), intent(inout) :: intwfnc
542
543    integer :: ik, is, ipe
544    integer :: own_max
545    integer, allocatable :: band_owners(:)
546    integer :: ib_loc, ib, ng0
547    integer :: receiver, recv_cnt, send_cnt
548    integer :: local_band_idx
549    integer, allocatable :: req_send(:), req_recv(:)
550    integer, allocatable :: my_bands(:), isort0(:)
551    SCALAR, allocatable :: bufs_wfn(:,:)
552    complex(DPC), allocatable :: work_ffts(:,:,:,:)
553    integer :: Nfft(3), fft_size
554    real(DP) :: scale
555    SCALAR, allocatable :: tmp_wfn(:)
556    SCALAR, allocatable :: ph(:)
557    real(DP), allocatable :: ekin(:)
558    integer, allocatable :: ind(:), isorti(:)
559    integer, allocatable, target :: wfn_isort(:)
560    integer :: ig
561
562    PUSH_SUB(genwf_FFT)
563
564    call logit('generating all real-space wavefunctions')
565
566    if(peinf%inode==0) call timing%start(timing%opt_fft)
567
568    if(pol%nq>1.or.kp%nrk>1.or.kpq%nrk>0.or.pol%need_WFNq) &
569      call die('FFT opt. level 2 only works for 1 qpt and 1 kpt, and without WFNq',&
570      only_root_writes=.true.)
571
572#ifdef MPI
573
574    own_max = (cwfn%nband + peinf%npes - 1)/(peinf%npes)
575
576    ik = 1 ! Fixing one kpt for now
577    is = 1 ! And the spin
578
579    vwfn%idx_kp = ik
580    cwfn%idx_kp = ik
581
582    if ( peinf%inode == 0 ) call timing%start(timing%opt_fft_init)
583
584    ! FHJ: Only root is guaranteed to have at least one band.
585    if (peinf%inode==0) ng0 = intwfnv%ng(ik)
586    call MPI_Bcast(ng0, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, mpierr)
587    SAFE_ALLOCATE(isort0, (ng0))
588    if (peinf%inode==0) isort0(1:ng0) = intwfnv%isort(1:ng0, ik)
589    call MPI_Bcast(isort0(1), ng0, MPI_INTEGER, 0, MPI_COMM_WORLD, mpierr)
590
591    ! FHJ: prepare the sorting stuff
592    SAFE_ALLOCATE(ind,(ng0))
593    SAFE_ALLOCATE(ph, (ng0))
594    SAFE_ALLOCATE(wfn_isort, (gvec%ng))
595    SAFE_ALLOCATE(isorti, (gvec%ng))
596    SAFE_ALLOCATE(ekin, (gvec%ng))
597    SAFE_ALLOCATE(tmp_wfn, (ng0))
598    call kinetic_energies(gvec, crys%bdot, ekin)
599    call sortrx(gvec%ng, ekin, wfn_isort, gvec=gvec%components)
600    do ig=1,gvec%ng
601      isorti(wfn_isort(ig)) = ig
602    enddo
603    do ig=1,ng0
604      isorti(isort0(ig)) = ig
605    enddo
606    call gmap(gvec,syms,ng0,1,(/0,0,0/),wfn_isort,isorti,ind,ph,.true.)
607    vwfn%ngv=ng0
608    cwfn%ngc=ng0
609
610    ! FHJ: Work buffers
611    SAFE_ALLOCATE(band_owners, (cwfn%nband))
612    SAFE_ALLOCATE(my_bands, (cwfn%nband))
613    SAFE_ALLOCATE(req_send, (cwfn%nband))
614    SAFE_ALLOCATE(req_recv, (cwfn%nband))
615    SAFE_ALLOCATE(bufs_wfn, (ng0, own_max))
616
617    ! FHJ: Re-distribute WFNs
618    recv_cnt = 0
619    send_cnt = 0
620    do ib = 1, cwfn%nband
621      receiver = mod(ib-1, peinf%npes)
622      band_owners(ib) = receiver
623
624      ! FHJ: Receiving part
625      if (receiver==peinf%inode) then
626        recv_cnt = recv_cnt + 1
627        my_bands(recv_cnt) = ib
628        call MPI_Irecv( bufs_wfn(1, recv_cnt), ng0, MPI_SCALAR, MPI_ANY_SOURCE, &
629          ib, MPI_COMM_WORLD, req_recv(recv_cnt), mpierr )
630      endif
631
632      ! FHJ: Sending part
633      if (ib<=vwfn%nband) then
634        local_band_idx = peinf%indexv(ib)+(ik-1)*peinf%nvownactual
635        if (should_send(peinf%does_it_ownv)) &
636          call MPI_Isend( intwfnv%cg(1, local_band_idx, is), ng0, MPI_SCALAR, receiver, ib, &
637          MPI_COMM_WORLD, req_send(send_cnt), mpierr )
638      else
639        local_band_idx = peinf%indexc(ib-vwfn%nband)+(ik-1)*peinf%ncownactual
640        if (should_send(peinf%does_it_ownc)) &
641          call MPI_Isend( intwfnc%cg(1, local_band_idx, is), ng0, MPI_SCALAR, receiver, ib, &
642          MPI_COMM_WORLD, req_send(send_cnt), mpierr )
643      endif
644    enddo
645
646    if (recv_cnt>0) then
647      call MPI_Waitall(recv_cnt, req_recv,  MPI_STATUSES_IGNORE, mpierr)
648    endif
649    if (send_cnt>0) then
650      call MPI_Waitall(send_cnt, req_send,  MPI_STATUSES_IGNORE, mpierr)
651    endif
652
653    if(peinf%inode==0) call timing%stop(timing%opt_fft_init)
654    if(peinf%inode==0) call timing%start(timing%opt_fft_fft)
655
656    ! FHJ: TODO - change genwf_mpi.f90 so that it doesn`t use cwfn%el
657    ! (see genwf_mpi.f90)
658
659    ! FHJ: Deallocate original wavefunctions (but only the wfn!)
660    call free_wfns(pol, intwfnv, intwfnvq, intwfnc, .false.)
661
662    call logit('doing FFTs')
663
664    ! FHJ: this will be more complicated to implement for more than 1 kpt
665    if (pol%min_fftgrid) then
666      pol%isrtx => wfn_isort
667      pol%nmtx = gcutoff(gvec%ng, ekin, pol%isrtx, pol%ecuts)
668      call get_eps_fftgrid(pol, gvec)
669      nullify(pol%isrtx)
670    endif
671    ! FHJ: Do all the FFTs
672    call setup_FFT_sizes(pol%FFTgrid,Nfft,scale)
673    fft_size = Nfft(1)*Nfft(2)*Nfft(3)
674
675    SAFE_ALLOCATE( vwfn%wfn_fft, (Nfft(1), Nfft(2), Nfft(3), peinf%nvownactual) )
676    SAFE_ALLOCATE( cwfn%wfn_fft, (Nfft(1), Nfft(2), Nfft(3), peinf%ncownactual) )
677    SAFE_ALLOCATE( work_ffts, (Nfft(1), Nfft(2), Nfft(3), recv_cnt) )
678
679    if (peinf%inode==0) write(0,*)
680    do ib = 1, recv_cnt
681      do ig=1,ng0
682        tmp_wfn(ig) = bufs_wfn(ind(ig), ib)*ph(ig)
683      enddo
684      call put_into_fftbox(ng0, tmp_wfn, gvec%components, wfn_isort, work_ffts(:,:,:,ib), Nfft)
685      call do_FFT(work_ffts(:,:,:,ib), Nfft, 1)
686      if (my_bands(ib)<=vwfn%nband) then
687        call conjg_fftbox(work_ffts(:,:,:,ib), Nfft)
688      endif
689    enddo
690
691    if(peinf%inode==0) call timing%stop(timing%opt_fft_fft)
692    if(peinf%inode==0) call timing%start(timing%opt_fft_comm_fft)
693
694    call logit('done with FFTs')
695
696    do ib_loc = 1,peinf%nvownactual
697      ib = peinf%invindexv(ib_loc)
698      call MPI_Irecv( vwfn%wfn_fft(1,1,1,ib_loc), fft_size, MPI_COMPLEX_DPC, band_owners(ib), &
699        ib, MPI_COMM_WORLD, req_recv(ib_loc), mpierr )
700    enddo
701    do ib_loc = 1,peinf%ncownactual
702      ib = peinf%invindexc(ib_loc)+vwfn%nband
703      call MPI_Irecv( cwfn%wfn_fft(1,1,1,ib_loc), fft_size, MPI_COMPLEX_DPC, band_owners(ib), &
704        ib, MPI_COMM_WORLD, req_recv(peinf%nvownactual+ib_loc), mpierr )
705    enddo
706
707    ! FHJ: And send them back using point-to-point communication.
708    send_cnt = 0
709    do ib_loc = 1, recv_cnt
710      ib = my_bands(ib_loc)
711
712      if (ib<=vwfn%nband) then
713        do ipe = 0, peinf%npes-1
714          if ( peinf%does_it_ownv(ib,ipe+1) ) then
715            send_cnt = send_cnt + 1
716            call MPI_Isend( work_ffts(1,1,1,ib_loc), fft_size, MPI_COMPLEX_DPC, ipe, &
717              ib, MPI_COMM_WORLD, req_send(send_cnt), mpierr )
718          endif
719        enddo
720      else
721        do ipe = 0, peinf%npes-1
722          if ( peinf%does_it_ownc(ib-vwfn%nband,ipe+1) ) then
723            send_cnt = send_cnt + 1
724            call MPI_Isend( work_ffts(1,1,1,ib_loc), fft_size, MPI_COMPLEX_DPC, ipe, &
725              ib, MPI_COMM_WORLD, req_send(send_cnt), mpierr )
726          endif
727        enddo
728      endif
729    enddo
730
731    recv_cnt = peinf%nvownactual + peinf%ncownactual
732    if (recv_cnt>0) then
733      call MPI_Waitall(recv_cnt, req_recv,  MPI_STATUSES_IGNORE, mpierr)
734    endif
735    if (send_cnt>0) then
736      call MPI_Waitall(send_cnt, req_send,  MPI_STATUSES_IGNORE, mpierr)
737    endif
738
739    if(peinf%inode==0) call timing%stop(timing%opt_fft_comm_fft)
740
741    SAFE_DEALLOCATE(work_ffts)
742    SAFE_DEALLOCATE(bufs_wfn)
743    SAFE_DEALLOCATE(req_send)
744    SAFE_DEALLOCATE(req_recv)
745    SAFE_DEALLOCATE(my_bands)
746    SAFE_DEALLOCATE(band_owners)
747    SAFE_DEALLOCATE(isort0)
748
749    ! and free sorting stuff
750    SAFE_DEALLOCATE(ind)
751    SAFE_DEALLOCATE(ph)
752    SAFE_DEALLOCATE(wfn_isort)
753    SAFE_DEALLOCATE(isorti)
754    SAFE_DEALLOCATE(ekin)
755    SAFE_DEALLOCATE(tmp_wfn)
756
757    call logit('done generating real-space wavefunctions')
758
759    if(peinf%inode==0) call timing%stop(timing%opt_fft)
760
761#endif
762
763    POP_SUB(genwf_FFT)
764    return
765
766  contains
767
768    logical function should_send(own_arr)
769      logical, intent(in) :: own_arr(:,:)
770
771      integer :: sender, ib_
772
773      PUSH_SUB(genwf_FFT.should_send)
774
775      ib_ = ib
776      if (ib_ > vwfn%nband) ib_ = ib_ - vwfn%nband
777
778      should_send = .false.
779#ifdef MPI
780      sender = -1
781      do ipe = 0, peinf%npes-1
782        if (own_arr(ib_, ipe+1)) then
783          sender = ipe
784          exit
785        endif
786      enddo
787
788      if (sender==-1) call die('No sender found!')
789
790      if (sender==peinf%inode) then
791        send_cnt = send_cnt + 1
792        should_send = .true.
793      endif
794#endif
795
796      POP_SUB(genwf_FFT.should_send)
797      return
798
799    end function should_send
800
801  end subroutine genwf_FFT
802
803  !> Generic routine that generates wavefunctions
804  subroutine genwf_gen(syms,gvec,crys,kp,kpq,irk,rk,qq,vwfn,pol,cwfn,use_wfnq,intwfnv,intwfnvq,intwfnc,iv)
805    type (symmetry), intent(in) :: syms
806    type (gspace), intent(in) :: gvec
807    type (crystal), intent(in) :: crys
808    type (kpoints), target, intent(in) :: kp
809    type (kpoints), target, intent(in) :: kpq
810    integer, intent(in) :: irk
811    real(DP), intent(in) :: rk(3)
812    real(DP), intent(in) :: qq(3)
813    type (valence_wfns), intent(inout) :: vwfn
814    type (polarizability), intent(in) :: pol
815    type (conduction_wfns), intent(inout) :: cwfn
816    logical, intent(in) :: use_wfnq
817    type (int_wavefunction), intent(in) :: intwfnv
818    type (int_wavefunction), intent(in) :: intwfnvq
819    type (int_wavefunction), intent(in) :: intwfnc
820    integer, intent(in) :: iv
821
822    integer :: ivr
823
824    PUSH_SUB(genwf_gen)
825
826    call logit('calling genwf')
827    if(peinf%inode==0) call timing%start(timing%genwf)
828    if (iv .le. peinf%nvownactual) then
829      call genwf_mpi(syms,gvec,crys,kp,kpq,irk,rk,qq,vwfn,pol,cwfn,use_wfnq,intwfnv,intwfnvq,intwfnc,iv)
830    endif
831    if(peinf%inode==0) call timing%stop(timing%genwf)
832    call logit('done genwf')
833
834    POP_SUB(genwf_gen)
835    return
836
837  end subroutine genwf_gen
838
839  !> Deallocate all "intermediate" wavefunctions
840  subroutine free_wfns(pol, intwfnv, intwfnvq, intwfnc, free_all)
841    type(polarizability), intent(in) :: pol
842    type(int_wavefunction), intent(inout) :: intwfnv, intwfnvq, intwfnc
843    logical, intent(in) :: free_all !< if .false., then only %ng and %isort will be preserved
844
845    PUSH_SUB(free_wfns)
846
847    if (free_all) then
848      SAFE_DEALLOCATE_P(intwfnv%ng)
849      SAFE_DEALLOCATE_P(intwfnv%isort)
850    endif
851    SAFE_DEALLOCATE_P(intwfnv%cg)
852    SAFE_DEALLOCATE_P(intwfnv%qk)
853    if (pol%need_WFNq) then
854      if (free_all) then
855        SAFE_DEALLOCATE_P(intwfnvq%ng)
856        SAFE_DEALLOCATE_P(intwfnvq%isort)
857      endif
858      SAFE_DEALLOCATE_P(intwfnvq%cg)
859      SAFE_DEALLOCATE_P(intwfnvq%qk)
860    endif
861    if (free_all) then
862      SAFE_DEALLOCATE_P(intwfnc%ng)
863      SAFE_DEALLOCATE_P(intwfnc%isort)
864    endif
865    SAFE_DEALLOCATE_P(intwfnc%cg)
866    SAFE_DEALLOCATE_P(intwfnc%cbi)
867    SAFE_DEALLOCATE_P(intwfnc%qk)
868
869    POP_SUB(free_wfns)
870    return
871
872  end subroutine free_wfns
873
874end module genwf_eps_m
875