
#define NBLOCKS 4


*
*     ***********************************
*     *                                 *
*     *         D1dBs_SumAll            *
*     *                                 *
*     ***********************************

      subroutine D1dBs_SumAll(sum)
c     implicit none
      real  sum

#include "D1dB.fh"

#ifdef MPI4
#include "stupid_mpi4.fh"
#else
#include "mpif.h"
#endif
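
*     Note (assumption): when MPI4 is defined, stupid_mpi4.fh
*     presumably supplies integer*4 copies (stupid_msglen, stupid_real,
*     stupid_comm_j, ...) so that the arguments passed to MPI match a
*     library built with 4-byte integers.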

      integer msglen,mpierr,np
      real  sumall

*     **** external functions ****
      integer  Parallel2d_comm_j
      external Parallel2d_comm_j

      call Parallel2d_np_j(np)
      if (np.gt.1) then
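*        Only the master thread performs the MPI reduction; the
*        barrier after the MASTER section ensures all threads see
*        the updated sum before continuing.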
!$OMP MASTER
#ifdef MPI4
        stupid_msglen = 1
        call MPI_Allreduce(sum,sumall,stupid_msglen,stupid_real,
     >                     stupid_sum,stupid_comm_j,stupid_ierr)
#else
        msglen = 1
        call MPI_Allreduce(sum,sumall,msglen,MPI_REAL,
     >                     MPI_SUM,Parallel2d_comm_j(),mpierr)
#endif
        sum = sumall
!$OMP END MASTER
!$OMP BARRIER
      end if

      return
      end



*     ***********************************
*     *                                 *
*     *         D1dBs_MaxAll            *
*     *                                 *
*     ***********************************
      subroutine D1dBs_MaxAll(sum)
c     implicit none
      real  sum

#include "D1dB.fh"

#ifdef MPI4
#include "stupid_mpi4.fh"
#else
#include "mpif.h"
#endif

      integer msglen,mpierr,np
      real  sumall

*     **** external functions ****
      integer  Parallel2d_comm_j
      external Parallel2d_comm_j

      call Parallel2d_np_j(np)
      if (np.gt.1) then
#ifdef MPI4
        stupid_msglen = 1
        call MPI_Allreduce(sum,sumall,stupid_msglen,stupid_real,
     >                     stupid_max,stupid_comm_j,stupid_ierr)
#else
        msglen = 1
        call MPI_Allreduce(sum,sumall,msglen,MPI_REAL,
     >                     MPI_MAX,Parallel2d_comm_j(),mpierr)
#endif
        sum = sumall
      end if

      return
      end



*     ***********************************
*     *                                 *
*     *        D1dBs_Vector_SumAll      *
*     *                                 *
*     ***********************************

      subroutine D1dBs_Vector_SumAll(n,sum)
c     implicit none
      integer n
      real  sum(*)

#include "bafdecls.fh"
#include "errquit.fh"
#include "D1dB.fh"

#ifdef MPI4
#include "stupid_mpi4.fh"
#else
#include "mpif.h"
#endif

      logical value
      integer msglen
      integer sumall(2),np,mpierr

*     **** external functions ****
      integer  Parallel2d_comm_j
      external Parallel2d_comm_j

      call Parallel2d_np_j(np)
      call nwpw_timing_start(2)
      if (np.gt.1) then

*     ***** allocate temporary space ****
      value = BA_push_get(mt_real,n,'sumall',sumall(2),sumall(1))
      if (.not. value) call errquit('out of stack memory',0, MA_ERR)
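
*     **** reduce into the temporary buffer, then copy the ****
*     **** result back into sum (MPI_IN_PLACE is not used) ****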

!$OMP MASTER
#ifdef MPI4
      stupid_msglen = n
      call MPI_Allreduce(sum,real_mb(sumall(1)),stupid_msglen,
     >                stupid_real,
     >                stupid_sum,stupid_comm_j,stupid_ierr)
#else
      msglen = n
      call MPI_Allreduce(sum,real_mb(sumall(1)),msglen,
     >                MPI_REAL,
     >                MPI_SUM,Parallel2d_comm_j(),mpierr)
#endif
      call Parallel_shared_vector_scopy(.false.,n,
     >                                  real_mb(sumall(1)),sum)
!$OMP END MASTER
!$OMP BARRIER

      value = BA_pop_stack(sumall(2))

      end if
      call nwpw_timing_end(2)
      return
      end


*     ***********************************
*     *                                 *
*     *      D1dBs_Brdcst_values        *
*     *                                 *
*     ***********************************

      subroutine D1dBs_Brdcst_values(psend,nsize,sum)
      implicit none
      integer psend,nsize
      real    sum(*)

#ifdef MPI4
#include "stupid_mpi4.fh"
#else
#include "mpif.h"
#endif

#ifdef MPI4
      integer*4 tpsend
      integer np

      call Parallel2d_np_j(np)
      if (np.gt.1) then
!$OMP MASTER
         stupid_msglen = nsize
         tpsend        = psend
         call MPI_Bcast(sum,stupid_msglen,stupid_real,
     >                  tpsend,stupid_comm_j,stupid_ierr)
!$OMP END MASTER
!$OMP BARRIER
      end if
#else
*     **** external functions ****
      integer  Parallel2d_comm_j
      external Parallel2d_comm_j

      integer ierr,np

      call Parallel2d_np_j(np)
      if (np.gt.1) then
!$OMP MASTER
         call MPI_Bcast(sum,nsize,MPI_REAL,
     >                  psend,Parallel2d_comm_j(),ierr)
!$OMP END MASTER
!$OMP BARRIER
      end if
#endif

      return
      end



*     ***********************************
*     *                                 *
*     *         D1dBs_start_rot         *
*     *                                 *
*     ***********************************
*
* This routine starts a send of the A buffer to
* proc_to = mod(taskid_j+j,np_j) and a receive of the W buffer from
* proc_from = mod(taskid_j-j+np_j,np_j), where taskid_j and np_j are
* the task id and number of processors of the Parallel2d_comm_j
* communicator.
*
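* For example (illustrative): with np_j = 4, taskid_j = 1, and j = 1,
* proc_to = mod(1+1,4) = 2 and proc_from = mod(1-1+4,4) = 0; over the
* successive calls j = 1,...,np_j-1 the buffers rotate around the
* ring of column tasks.
*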
      subroutine D1dBs_start_rot(j,
     >                           A,W,lda,na,
     >                           request)
      implicit none
      integer j
      real     A(*),W(*)
      integer lda,na(*)
      integer request(*)

#include "D1dB.fh"

#include "mpif.h"
#ifdef MPI4
#include "stupid_mpi4.fh"
#endif

*     **** local variables ****
      integer amsglen,wmsglen
      integer proc_to,proc_from,msgtype,mpierr
      integer taskid_j

*     **** external functions ****
      integer  Parallel2d_comm_j
      external Parallel2d_comm_j

      call Parallel2d_taskid_j(taskid_j)

      proc_to   = mod(taskid_j+j,np_j)
      proc_from = mod(taskid_j-j+np_j,np_j)
      msgtype   = j
      amsglen = lda*na(taskid_j+1)
      wmsglen = lda*na(proc_from+1)

#ifdef MPI4
      if (wmsglen.gt.0) then
         stupid_msglen = wmsglen
         stupid_type   = msgtype
         stupid_taskid = proc_from
         call MPI_IRECV(W,
     >                  stupid_msglen,stupid_real,
     >                  stupid_taskid,
     >                  stupid_type,stupid_comm_j,
     >                  stupid_request,stupid_ierr)
         request(1) = stupid_request
         request(3) = 1
      else
         request(3) = 0
      end if

      if (amsglen.gt.0) then
         stupid_msglen = amsglen
         stupid_type   = msgtype
         stupid_taskid = proc_to
         call MPI_ISEND(A,
     >                  stupid_msglen,stupid_real,
     >                  stupid_taskid,
     >                  stupid_type,stupid_comm_j,
     >                  stupid_request,stupid_ierr)
         request(2) = stupid_request
         request(4) = 1
      else
         request(4) = 0
      end if
#else
      if (wmsglen.gt.0) then
         call MPI_IRECV(W,wmsglen,MPI_REAL,
     >                  proc_from,
     >                  msgtype,Parallel2d_comm_j(),
     >                  request(1),mpierr)
         request(3) = 1
      else
         request(3) = 0
      end if
      if (amsglen.gt.0) then
         call MPI_ISEND(A,amsglen,MPI_REAL,
     >                  proc_to,
     >                  msgtype,Parallel2d_comm_j(),
     >                  request(2),mpierr)
         request(4) = 1
      else
         request(4) = 0
      end if
#endif

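*     **** encode in request(3) which requests are outstanding so ****
*     **** that D1dBs_end_rot knows what to wait on:              ****
*     ****   1 = send and recv,  2 = recv only,                   ****
*     ****   3 = send only,      4 = neither                      ****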
      if ((request(3).eq.1).and.(request(4).eq.1)) then
         request(3) = 1
      else if (request(3).eq.1) then
         request(3) = 2
      else if (request(4).eq.1) then
         request(3) = 3
      else
         request(3) = 4
      end if

      return
      end

*     ***********************************
*     *                                 *
*     *         D1dBs_end_rot           *
*     *                                 *
*     ***********************************
*
* This routine waits for the sends and receives that were started
* by the D1dBs_start_rot routine to finish.
*
      subroutine D1dBs_end_rot(request)
      implicit none
      integer request(*)
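
*     request(1) holds the receive handle, request(2) the send
*     handle, and request(3) the status code set by D1dBs_start_rot.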

*     **** wait for completion of the outstanding sends/receives ****
      if (request(3).eq.1) then
         call Parallel_mpiWaitAll(2,request)
      else if (request(3).eq.2) then
         call Parallel_mpiWaitAll(1,request)
      else if (request(3).eq.3) then
         call Parallel_mpiWaitAll(1,request(2))
      end if

      return
      end



*     ***********************************
*     *                                 *
*     *         D1dBs_isendrecv         *
*     *                                 *
*     ***********************************
      subroutine D1dBs_isendrecv(pto,  ssize,sdata,
     >                           pfrom,rsize,rdata,
     >                           request,reqcnt)
      implicit none
      integer pto,ssize
      real    sdata(*)
      integer pfrom,rsize
      real    rdata(*)
      integer request(*)
      integer reqcnt

#include "D1dB.fh"

#ifdef MPI4
#include "stupid_mpi4.fh"
#else
#include "mpif.h"
#endif

*     **** local variables ****
      integer msgtype,mpierr

*     **** external functions ****
      integer  Parallel2d_comm_j
      external Parallel2d_comm_j

      msgtype = 7

#ifdef MPI4
      if (rsize.gt.0) then
         stupid_msglen = rsize
         stupid_type   = msgtype
         stupid_taskid = pfrom
         call MPI_IRECV(rdata,
     >                  stupid_msglen,stupid_real,
     >                  stupid_taskid,
     >                  stupid_type,stupid_comm_j,
     >                  stupid_request,stupid_ierr)
         reqcnt          = reqcnt + 1
         request(reqcnt) = stupid_request
      end if

      if (ssize.gt.0) then
         stupid_msglen = ssize
         stupid_type   = msgtype
         stupid_taskid = pto
         call MPI_ISEND(sdata,
     >                  stupid_msglen,stupid_real,
     >                  stupid_taskid,
     >                  stupid_type,stupid_comm_j,
     >                  stupid_request,stupid_ierr)
         reqcnt          = reqcnt + 1
         request(reqcnt) = stupid_request
      end if
#else
      if (rsize.gt.0) then
         reqcnt = reqcnt + 1
         call MPI_IRECV(rdata,rsize,MPI_REAL,
     >                  pfrom,
     >                  msgtype,Parallel2d_comm_j(),
     >                  request(reqcnt),mpierr)
      end if
      if (ssize.gt.0) then
         reqcnt = reqcnt + 1
         call MPI_ISEND(sdata,ssize,MPI_REAL,
     >                  pto,
     >                  msgtype,Parallel2d_comm_j(),
     >                  request(reqcnt),mpierr)
      end if
#endif

      return
      end


c     ****************************************
c     *                                      *
c     *          D1dBs_Brdcst_step           *
c     *                                      *
c     ****************************************
c
c  This routine performs step l of a butterfly broadcast-all
c  algorithm. The step l spans 0..(Level-1), where the number of
c  levels is Level = Log(np_j)/Log(2).
c
c   Entry - l: butterfly step 0...(Level-1)
c           na: an array of length np_j containing the number of
c               orbitals per taskid_j
c           blocks0: number of blocks to send (size=blocks0);
c                    the exceptions are:
c                    if blocks0==0: the block size is size=2**l
c                    if blocks0==-1: the block size is
c                       size=(np_j-2**Level)/2 + 1 for l==(Level-1),
c                       and size=2**l otherwise
c           n2ft3d: leading dimension of psi_rep
c           psi_rep: data array
c   Exit -
c           psi_rep: modified data array
c           requests,reqcnt: handles for the asynchronous messages
c
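c  For example (illustrative): with np_j = 8 there are Level = 3
c  steps, l = 0,1,2, with shift = 2**l = 1,2,4.  At step l a task
c  receives blocks from mod(taskid_j+shift,np_j) and sends blocks to
c  mod(taskid_j+np_j-shift,np_j), so after the final step every task
c  holds all np_j blocks.
c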
      subroutine D1dBs_Brdcst_step(l,na,blocks0,
     >                             n2ft3d,psi_rep,
     >                             requests,reqcnt)
      implicit none
      integer l,na(*),blocks0
      integer n2ft3d
      real    psi_rep(n2ft3d,*)
      integer requests(*),reqcnt

*     *** local variables ***
      integer taskid_j,np_j
      integer i,pr,ps,shift,size,Level
      integer pto,pfrom,rsize,ssize,rindx,sindx

*     *** external functions ***
      integer  Butter_levels
      external Butter_levels

      call Parallel2d_taskid_j(taskid_j)
      call Parallel2d_np_j(np_j)

cccc      Level = (log(dble(np_j))/log(2.0d0))
      Level = Butter_levels(np_j)
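*     (Butter_levels is assumed to return the integer number of
*      butterfly levels, i.e. the commented-out log2(np_j) formula
*      above evaluated as an integer.)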

      shift = 2**l
      pfrom = mod(taskid_j     +shift,np_j)
      pto   = mod(taskid_j+np_j-shift,np_j)

*     *** hfx exception ***
      if (blocks0.lt.0) then
         if (l.eq.(Level-1)) then
            size = ((np_j-2**Level)/2) + 1
         else
            size = shift
         end if
*     *** standard butterfly exception ***
      elseif (blocks0.eq.0) then
         size = shift
*     *** user blocksize ***
      else
         size = blocks0
      end if

*     *** determine message sizes ***
      rsize = 0
      ssize = 0
      do i=0,(size-1)
        pr = mod(pfrom    + i,np_j)
        ps = mod(taskid_j + i,np_j)
        rsize = rsize + na(pr+1)
        ssize = ssize + na(ps+1)
      end do

*     *** determine message indexes ***
      rindx = 1
      do i=0,(shift-1)
        ps = mod(taskid_j + i,np_j)
        rindx = rindx + na(ps+1)
      end do
      sindx = 1

*     *** will be much more complicated for synchronous ***
      reqcnt = 0
      call D1dBs_isendrecv(pto,  ssize*n2ft3d,psi_rep(1,sindx),
     >                     pfrom,rsize*n2ft3d,psi_rep(1,rindx),
     >                     requests,reqcnt)
      return
      end


c     ****************************************
c     *                                      *
c     *           D1dBs_Reduce_step          *
c     *                                      *
c     ****************************************
c
c  This routine performs step l of a butterfly reduce-all algorithm.
c  The step l spans 0..(Level-1), where the number of levels is
c  Level = Log(np_j)/Log(2).
c
c   Entry - l: butterfly step 0...(Level-1)
c           na: an array of length np_j containing the number of
c               orbitals per taskid_j
c           blocks0: number of blocks to send (size=blocks0);
c                    the exceptions are:
c                    if blocks0==0: the block size is size=2**l
c                    if blocks0==-1: the block size is
c                       size=(np_j-2**Level)/2 + 1 for l==(Level-1),
c                       and size=2**l otherwise
c           n2ft3d: leading dimension of hpsi_rep
c           hpsi_rep: data array
c           tmp: temporary data array; needs to be at least
c                n2ft3d*size
c
c   Exit - hpsi_rep: modified data array
c          (the sends and receives are completed internally before
c           this routine returns)
c
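c  For example (illustrative): the data flow is the mirror image of
c  D1dBs_Brdcst_step.  At step l a task sends its blocks to
c  mod(taskid_j+shift,np_j) and receives partial sums from
c  mod(taskid_j+np_j-shift,np_j), which are accumulated into
c  hpsi_rep below with SAXPY_OMP.
c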
      subroutine D1dBs_Reduce_step(l,na,blocks0,
     >                             n2ft3d,hpsi_rep,tmp)
      implicit none
      integer l,na(*),blocks0
      integer n2ft3d
      real    hpsi_rep(n2ft3d,*)
      real    tmp(*)

*     *** local variables ***
      integer taskid_j,np_j
      integer i,pr,ps,size,shift,Level,pfrom,pto
      integer rsize,ssize,rindx,sindx
      integer requests(10),reqcnt

*     *** external functions ***
      integer  Butter_levels
      external Butter_levels

      call Parallel2d_taskid_j(taskid_j)
      call Parallel2d_np_j(np_j)

cccccc      !Level = (log(dble(np_j))/log(2.0d0))
      Level = Butter_levels(np_j)

      shift = 2**l
      pfrom = mod(taskid_j+np_j-shift,np_j)
      pto   = mod(taskid_j     +shift,np_j)

*     *** hfx exception ***
      if (blocks0.lt.0) then
         if (l.eq.(Level-1)) then
            size = ((np_j-2**Level)/2) + 1
         else
            size = shift
         end if
*     *** standard butterfly exception ***
      elseif (blocks0.eq.0) then
         size = shift
*     *** user blocksize ***
      else
         size = blocks0
      end if

*     *** determine message sizes ***
      rsize = 0
      ssize = 0
      do i=0,(size-1)
        pr = mod(taskid_j + i,np_j)
        ps = mod(pto      + i,np_j)
        rsize = rsize + na(pr+1)
        ssize = ssize + na(ps+1)
      end do

*     *** determine message indexes ***
      rindx = 1
      sindx = 1
      do i=0,(shift-1)
         ps = mod(taskid_j  + i,np_j)
         sindx = sindx + na(ps+1)
      end do

*     *** will be much more complicated for synchronous ***
      reqcnt = 0
      call D1dBs_isendrecv(pto,  ssize*n2ft3d,hpsi_rep(1,sindx),
     >                     pfrom,rsize*n2ft3d,tmp,
     >                     requests,reqcnt)

      call D1dB_WaitAll(requests,reqcnt)
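
*     *** accumulate the received partial sums into the local block ***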
      call SAXPY_OMP(rsize*n2ft3d,1.0,tmp,1,hpsi_rep(1,rindx),1)

      return
      end
c $Id$