1      subroutine ccsd_trpdrv_nb(t1,
2     &     f1n,f1t,f2n,f2t,f3n,f3t,f4n,f4t,eorb,
3     &     g_objo,g_objv,g_coul,g_exch,
4     &     ncor,nocc,nvir,iprt,emp4,emp5,
5     &     oseg_lo,oseg_hi,
6     $     kchunk, Tij, Tkj, Tia, Tka, Xia, Xka, Jia, Jka, Kia, Kka,
7     $     Jij, Jkj, Kij, Kkj, Dja, Djka, Djia)
!
!  ccsd_trpdrv_nb: driver that processes (a,j,i,k) index combinations
!  while overlapping Global Arrays non-blocking gets (ga_nbget) with
!  the local contractions (ccsd_dovvv, ccsd_doooo) and the energy
!  evaluation (ccsd_tengy).
!
!  Arguments (as used in this routine):
!    t1          passed to ccsd_tengy at offsets (k-1)*nvir+1 and
!                (i-1)*nvir+1
!    f1n..f4t    work arrays; the first lnvv elements of each are
!                zeroed for every (i,k) pair, then passed to
!                ccsd_dovvv and ccsd_doooo and finally to ccsd_tengy
!    eorb        indexed as eorb(ncor+i), eorb(ncor+j), eorb(ncor+k)
!                and eorb(a) to form eaijk (presumably orbital
!                energies - confirm with caller)
!    g_objo, g_objv, g_coul, g_exch
!                global array handles read with ga_nbget
!    ncor, nocc, nvir
!                index offsets and extents; presumably the numbers of
!                core, occupied and virtual orbitals - confirm
!    iprt        print level; values above 50 print iam, a, j, emp4
!                and emp5 after each task
!    emp4, emp5  passed through to ccsd_tengy (presumably accumulated
!                there - this routine never assigns them itself)
!    oseg_lo, oseg_hi
!                inclusive range of the index a handled by this call
!    kchunk      number of k values processed per outer block
!    Tij ... Djia
!                caller-provided buffers that receive the ga_nbget
!                data; the dependency table below maps each buffer
!                to its non-blocking handle
!
!  lnvv, lnov, lnoov and lnovv are not declared in this routine;
!  presumably they are supplied by ccsd_len.fh - confirm.
!
!  Work distribution: a counter (inode) is incremented for every
!  (k block, a, j) combination, and the combination is processed only
!  when the counter equals the task number returned by nxtask.
!
8c
9C     $Id$
10c
11c     CCSD(T) non-blocking modifications written by
12c     Jeff Hammond, Argonne Leadership Computing Facility
13c     Fall 2009
14c
15      implicit none
16c
17#include "global.fh"
18#include "ccsd_len.fh"
19#include "ccsdps.fh"
20c
21      double precision t1(*),
22     &     f1n(*),f1t(*),f2n(*),
23     &     f2t(*),f3n(*),f3t(*),f4n(*),f4t(*),eorb(*),
24     &     emp4,emp5
25      double precision Tij(*), Tkj(*), Tia(*), Tka(*), Xia(*), Xka(*),
26     $     Jia(*), Jka(*), Kia(*), Kka(*),
27     $     Jij(*), Jkj(*), Kij(*), Kkj(*), Dja(*), Djka(*), Djia(*)
28
29      integer g_objo,g_objv,ncor,nocc,nvir,iprt,g_coul,
30     &     g_exch,oseg_lo,oseg_hi
31c
32      double precision eaijk
33      integer a,i,j,k,akold,av,inode,len,ad3,next
34      integer nxtask
35      external nxtask
36c
37      Integer Nodes, IAm
38c
39      integer klo, khi, start, end
40      integer kchunk
41c
42c==================================================
43c
44c  NON-BLOCKING stuff
45c
46c==================================================
47c
48c  Dependencies (global array, local array, handle):
49c
50c      g_objv, Dja, nbh_objv1
51c      g_objv, Tka, nbh_objv2
52c      g_objv, Xka, nbh_objv3
53c      g_objv, Djka(1+(k-klo)*nvir), nbh_objv4(k)
54c      g_objv, Djia, nbh_objv5
55c      g_objv, Tia, nbh_objv6
56c      g_objv, Xia, nbh_objv7
57c      g_objo, Tkj, nbh_objo1
58c      g_objo, Jkj, nbh_objo2
59c      g_objo, Kkj, nbh_objo3
60c      g_objo, Tij, nbh_objo4
61c      g_objo, Jij, nbh_objo5
62c      g_objo, Kij, nbh_objo6
63c      g_exch, Kka, nbh_exch1
64c      g_exch, Kia, nbh_exch2
65c      g_coul, Jka, nbh_coul1
66c      g_coul, Jia, nbh_coul2
67c
68c  non-blocking handles
69c
70       integer nbh_objv1,nbh_objv2,nbh_objv3
71       integer nbh_objv5,nbh_objv6,nbh_objv7
72       integer nbh_objv4(nocc)
73c
74       integer nbh_objo1,nbh_objo2,nbh_objo3
75       integer nbh_objo4,nbh_objo5,nbh_objo6
76c
77       integer nbh_exch1,nbh_exch2,nbh_coul1,nbh_coul2
78c
79       logical need_ccsd_dovvv1
80       logical need_ccsd_dovvv2
81       logical need_ccsd_doooo1
82       logical need_ccsd_doooo2
83c
84#ifdef DEBUG_PRINT
85      integer tt
86      double precision tt0,tt1,trp_time(26)
87#endif
88c
89c==================================================
90c
91      double precision zip
92      data zip/0.0d00/
93c
94      Nodes = GA_NNodes()
95      IAm = GA_NodeID()
96c
97      call ga_sync()
98c
99      if (occsdps) then
100         call pstat_on(ps_trpdrv)
101      else
102         call qenter('trpdrv',0)
103      endif
!
!     inode counts every (k block, a, j) combination seen by this
!     process; next is the task number handed out by nxtask. A
!     combination is processed only when the two are equal.
!
104      inode=-1
105      next=nxtask(nodes, 1)
106c
107#ifdef DEBUG_PRINT
108      do tt = 1, 26
109        trp_time(tt) = 0.0d0
110      enddo
111#endif
112c
113      do klo = 1, nocc, kchunk
114         akold=0
115         khi = min(nocc, klo+kchunk-1)
116         do a=oseg_lo,oseg_hi
117            av=a-ncor-nocc
118            do j=1,nocc
119               inode=inode+1
120               if (inode.eq.next)then
121c
122c     Get Dja = Dci,ja for given j, a, all ci
123c
124                  start = 1 + (j-1)*lnov
125                  len   = lnov
126                  end   = start + len - 1
127#ifdef DEBUG_PRINT
128                  tt0 = ga_wtime()
129#endif
130                  call ga_nbget(g_objv,start,end,av,av,Dja,len,
131     1                          nbh_objv1)
132#ifdef DEBUG_PRINT
133                  tt1 = ga_wtime()
134                  trp_time(1) = trp_time(1) + (tt1-tt0)
135#endif
136c
137c     Get Tkj = T(b,c,k,j) for given j, klo<=k<=khi, all bc
138c
139                  start = (klo-1)*lnvv + 1
140                  len   = (khi-klo+1)*lnvv
141                  end   = start + len - 1
142#ifdef DEBUG_PRINT
143                  tt0 = ga_wtime()
144#endif
145                  call ga_nbget(g_objo,start,end,j,j,Tkj,len,
146     1                          nbh_objo1)
147#ifdef DEBUG_PRINT
148                  tt1 = ga_wtime()
149                  trp_time(2) = trp_time(2) + (tt1-tt0)
150#endif
151c
152c     Get Jkj = J(c,l,k,j) for given j, klo<=k<=khi, all cl
153c
154                  start = lnovv + (klo-1)*lnov + 1
155                  len   = (khi-klo+1)*lnov
156                  end   = start + len - 1
157#ifdef DEBUG_PRINT
158                  tt0 = ga_wtime()
159#endif
160                  call ga_nbget(g_objo,start,end,j,j,Jkj,len,
161     1                          nbh_objo2)
162#ifdef DEBUG_PRINT
163                  tt1 = ga_wtime()
164                  trp_time(3) = trp_time(3) + (tt1-tt0)
165#endif
166c
167c     Get Kkj = K(c,l,k,j) for given j, klo<=k<=khi, all cl
168c
169                  start = lnovv + lnoov + (klo-1)*lnov + 1
170                  len   = (khi-klo+1)*lnov
171                  end   = start + len - 1
172#ifdef DEBUG_PRINT
173                  tt0 = ga_wtime()
174#endif
175                  call ga_nbget(g_objo,start,end,j,j,Kkj,len,
176     1                          nbh_objo3)
177#ifdef DEBUG_PRINT
178                  tt1 = ga_wtime()
179                  trp_time(4) = trp_time(4) + (tt1-tt0)
180#endif
181c
!
!     Jka, Kka, Tka and Xka depend only on a and the k block, so they
!     are fetched only when a differs from the a of the previous task
!     this process ran in the current k block (akold is reset to 0 at
!     the top of each k block).
!
182                  if (akold .ne. a) then
183                     akold = a
184c
185c     Get Jka = J(b,c,k,a) for given a, klo<=k<=khi, all bc
186c
187                     start = (a-oseg_lo)*nocc + klo
188                     len   = (khi-klo+1)
189                     end   = start + len - 1
190#ifdef DEBUG_PRINT
191                  tt0 = ga_wtime()
192#endif
193                     call ga_nbget(g_coul,1,lnvv,start,end,Jka,lnvv,
194     1                             nbh_coul1)
195#ifdef DEBUG_PRINT
196                  tt1 = ga_wtime()
197                  trp_time(5) = trp_time(5) + (tt1-tt0)
198#endif
199c
200c     Get Kka = K(b,c,k,a) for given a, klo<=k<=khi, all bc
201c
202                     start = (a-oseg_lo)*nocc + klo
203                     len   = (khi-klo+1)
204                     end   = start + len - 1
205#ifdef DEBUG_PRINT
206                  tt0 = ga_wtime()
207#endif
208                     call ga_nbget(g_exch,1,lnvv,start,end,Kka,lnvv,
209     1                             nbh_exch1)
210#ifdef DEBUG_PRINT
211                  tt1 = ga_wtime()
212                  trp_time(6) = trp_time(6) + (tt1-tt0)
213#endif
214c
215c     Get Tka = Tbl,ka for given a, klo<=k<=khi, all bl
216c
217                     start = 1 + lnoov + (klo-1)*lnov
218                     len   = (khi-klo+1)*lnov
219                     end   = start + len - 1
220#ifdef DEBUG_PRINT
221                  tt0 = ga_wtime()
222#endif
223                     call ga_nbget(g_objv,start,end,av,av,Tka,len,
224     1                             nbh_objv2)
225#ifdef DEBUG_PRINT
226                  tt1 = ga_wtime()
227                  trp_time(7) = trp_time(7) + (tt1-tt0)
228#endif
229c
230c     Get Xka = Tal,kb for given a, klo<=k<=khi, all bl
231c
232                     start = 1 + lnoov + lnoov + (klo-1)*lnov
233                     len   = (khi-klo+1)*lnov
234                     end   = start + len - 1
235#ifdef DEBUG_PRINT
236                  tt0 = ga_wtime()
237#endif
238                     call ga_nbget(g_objv,start,end,av,av,Xka,len,
239     1                             nbh_objv3)
240#ifdef DEBUG_PRINT
241                  tt1 = ga_wtime()
242                  trp_time(8) = trp_time(8) + (tt1-tt0)
243#endif
244                  endif
245c
246c     Get Djka = Dcj,ka for given j, a, klo<=k<=khi, all c
247c
248                  do k = klo, khi
249                     start = 1 + (j-1)*nvir + (k-1)*lnov
250                     len   = nvir
251                     end   = start + len - 1
252#ifdef DEBUG_PRINT
253                  tt0 = ga_wtime()
254#endif
255                     call ga_nbget(g_objv,start,end,av,av,
256     1                    Djka(1+(k-klo)*nvir),len,nbh_objv4(k)) ! k <= nocc
257#ifdef DEBUG_PRINT
258                  tt1 = ga_wtime()
259                  trp_time(9) = trp_time(9) + (tt1-tt0)
260#endif
261                  enddo
262c
!
!     Start at i=klo rather than i=1: the k loop below runs over
!     k=klo..min(khi,i), so it is empty for every i below klo.
!     Posting the gets for those i would start transfers into Tij,
!     Jij, Kij, Jia, Kia, Djia, Tia and Xia whose handles are never
!     tested or waited on before the next i reuses the same buffers
!     and handle variables. Skipping them drops no (i,k) pair.
!
263                  do i=klo,nocc
264c
265c     Get Tij = T(b,c,i,j) for given j, i, all bc
266c
267                     start = (i-1)*lnvv + 1
268                     len   = lnvv
269                     end   = start + len - 1
270#ifdef DEBUG_PRINT
271                  tt0 = ga_wtime()
272#endif
273                     call ga_nbget(g_objo,start,end,j,j,Tij,len,
274     1                             nbh_objo4)
275#ifdef DEBUG_PRINT
276                  tt1 = ga_wtime()
277                  trp_time(10) = trp_time(10) + (tt1-tt0)
278#endif
279c
280c     Get Jij = J(c,l,i,j) for given j, i, all cl
281c
282                     start = lnovv + (i-1)*lnov + 1
283                     len   = lnov
284                     end   = start + len - 1
285#ifdef DEBUG_PRINT
286                  tt0 = ga_wtime()
287#endif
288                     call ga_nbget(g_objo,start,end,j,j,Jij,len,
289     1                             nbh_objo5)
290#ifdef DEBUG_PRINT
291                  tt1 = ga_wtime()
292                  trp_time(11) = trp_time(11) + (tt1-tt0)
293#endif
294c
295c     Get Kij = K(c,l,i,j) for given j, i, all cl
296c
297                     start = lnovv + lnoov + (i-1)*lnov + 1
298                     len   = lnov
299                     end   = start + len - 1
300#ifdef DEBUG_PRINT
301                  tt0 = ga_wtime()
302#endif
303                     call ga_nbget(g_objo,start,end,j,j,Kij,len,
304     1                             nbh_objo6)
305#ifdef DEBUG_PRINT
306                  tt1 = ga_wtime()
307                  trp_time(12) = trp_time(12) + (tt1-tt0)
308#endif
309c
310c     Get Jia = J(b,c,i,a) for given a, i, all bc
311c
312                     start = (a-oseg_lo)*nocc + i
313                     len   = 1
314                     end   = start + len - 1
315#ifdef DEBUG_PRINT
316                  tt0 = ga_wtime()
317#endif
318                     call ga_nbget(g_coul,1,lnvv,start,end,Jia,lnvv,
319     1                             nbh_coul2)
320#ifdef DEBUG_PRINT
321                  tt1 = ga_wtime()
322                  trp_time(13) = trp_time(13) + (tt1-tt0)
323#endif
324c
325c     Get Kia = K(b,c,i,a) for given a, i, all bc
326c
327                     start = (a-oseg_lo)*nocc + i
328                     len   = 1
329                     end   = start + len - 1
330#ifdef DEBUG_PRINT
331                  tt0 = ga_wtime()
332#endif
333                     call ga_nbget(g_exch,1,lnvv,start,end,Kia,lnvv,
334     1                             nbh_exch2)
335#ifdef DEBUG_PRINT
336                  tt1 = ga_wtime()
337                  trp_time(14) = trp_time(14) + (tt1-tt0)
338#endif
339c
340c     Get Dia = Dcj,ia for given j, i, a, all c
341c
342                     start = 1 + (j-1)*nvir + (i-1)*lnov
343                     len   = nvir
344                     end   = start + len - 1
345#ifdef DEBUG_PRINT
346                  tt0 = ga_wtime()
347#endif
348                     call ga_nbget(g_objv,start,end,av,av,Djia,len,
349     1                             nbh_objv5)
350#ifdef DEBUG_PRINT
351                  tt1 = ga_wtime()
352                  trp_time(15) = trp_time(15) + (tt1-tt0)
353#endif
354c
355c     Get Tia = Tbl,ia for given a, i, all bl
356c
357                     start = 1 + lnoov + (i-1)*lnov
358                     len   = lnov
359                     end   = start + len - 1
360#ifdef DEBUG_PRINT
361                  tt0 = ga_wtime()
362#endif
363                     call ga_nbget(g_objv,start,end,av,av,Tia,len,
364     1                             nbh_objv6)
365#ifdef DEBUG_PRINT
366                  tt1 = ga_wtime()
367                  trp_time(16) = trp_time(16) + (tt1-tt0)
368#endif
369c
370c     Get Xia = Tal,ib for given a, i, all bl
371c
372                     start = 1 + lnoov + lnoov + (i-1)*lnov
373                     len   = lnov
374                     end   = start + len - 1
375#ifdef DEBUG_PRINT
376                  tt0 = ga_wtime()
377#endif
378                     call ga_nbget(g_objv,start,end,av,av,Xia,len,
379     1                             nbh_objv7)
380#ifdef DEBUG_PRINT
381                  tt1 = ga_wtime()
382                  trp_time(17) = trp_time(17) + (tt1-tt0)
383#endif
384c
385                     do k=klo,min(khi,i)
386                        call dfill(lnvv,zip,f1n,1)
387                        call dfill(lnvv,zip,f1t,1)
388                        call dfill(lnvv,zip,f2n,1)
389                        call dfill(lnvv,zip,f2t,1)
390                        call dfill(lnvv,zip,f3n,1)
391                        call dfill(lnvv,zip,f3t,1)
392                        call dfill(lnvv,zip,f4n,1)
393                        call dfill(lnvv,zip,f4t,1)
394c
395            need_ccsd_dovvv1 = .true.
396            need_ccsd_dovvv2 = .true.
397            need_ccsd_doooo1 = .true.
398            need_ccsd_doooo2 = .true.
399c
400#ifdef DEBUG_PRINT
401            !write(6,*) IAm,'before do-while loop'
402#endif
403c
!
!     Polling loop: each of the four contractions below is run exactly
!     once per (i,k) pair, in whatever order its inputs arrive. A
!     contraction is started when ga_nbtest returns 0 for every handle
!     it depends on (the code treats 0 as transfer complete), and its
!     need_ flag is then cleared. The loop spins until all four flags
!     are cleared.
!
404            do while ( need_ccsd_dovvv1 .or. need_ccsd_dovvv2
405     1            .or. need_ccsd_doooo1 .or. need_ccsd_doooo2 )
406c
407c     sum(d) (Jia, Kia)bd * Tkj,cd -> Fbc
408c
409c      g_coul, Jia, nbh_coul2
410c      g_exch, Kia, nbh_exch2
411c      g_objo, Tkj, nbh_objo1
412c
413                if ( need_ccsd_dovvv1 ) then
414                  if ( (0.eq.ga_nbtest(nbh_coul2)) .and.
415     1                 (0.eq.ga_nbtest(nbh_exch2)) .and.
416     2                 (0.eq.ga_nbtest(nbh_objo1)) ) then
417
418#ifdef DEBUG_PRINT
419                        !write(6,55) IAm,'ccsd_dovvv1'
420#endif
421
422#ifdef DEBUG_PRINT
423                        tt0 = ga_wtime()
424#endif
425                        call ccsd_dovvv(Jia, Kia,
426     $                       Tkj(1+(k-klo)*lnvv),
427     $                       f1n,f2n,f3n,f4n,nocc,nvir)
428#ifdef DEBUG_PRINT
429                        tt1 = ga_wtime()
430                        trp_time(18) = trp_time(18) + (tt1-tt0)
431#endif
432c
433                        need_ccsd_dovvv1 = .false.
434c
435                  endif
436                endif
437c
438c     sum(d) (Jka, Kka)bd * Tij,cd -> Fbc
439c
440c      g_coul, Jka, nbh_coul1
441c      g_exch, Kka, nbh_exch1
442c      g_objo, Tij, nbh_objo4
443c
444                if ( need_ccsd_dovvv2 ) then
445                  if ( (0.eq.ga_nbtest(nbh_coul1)) .and.
446     1                 (0.eq.ga_nbtest(nbh_exch1)) .and.
447     2                 (0.eq.ga_nbtest(nbh_objo4)) ) then
448
449#ifdef DEBUG_PRINT
450                        !write(6,55) IAm,'ccsd_dovvv2'
451#endif
452
453#ifdef DEBUG_PRINT
454                        tt0 = ga_wtime()
455#endif
456                        call ccsd_dovvv(Jka(1+(k-klo)*lnvv),
457     $                       Kka(1+(k-klo)*lnvv),Tij,
458     $                       f1t,f2t,f3t,f4t,nocc,nvir)
459#ifdef DEBUG_PRINT
460                        tt1 = ga_wtime()
461                        trp_time(19) = trp_time(19) + (tt1-tt0)
462#endif
463c
464                        need_ccsd_dovvv2 = .false.
465c
466                  endif
467                endif
468c
469c     sum(l) (Jij, Kij)cl  * Tkl,ab -> Fbc
470c
471c      g_objo, Jkj, nbh_objo2
472c      g_objo, Kkj, nbh_objo3
473c      g_objv, Tia, nbh_objv6
474c      g_objv, Xia, nbh_objv7
475c
!
!     NOTE(review): the formula comment above names Jij, Kij, but the
!     handles tested and the call below use Jkj, Kkj with Tia, Xia.
!     Confirm which description is intended.
!
476                if ( need_ccsd_doooo1 ) then
477                  if ( (0.eq.ga_nbtest(nbh_objo2)) .and.
478     1                 (0.eq.ga_nbtest(nbh_objo3)) .and.
479     2                 (0.eq.ga_nbtest(nbh_objv6)) .and.
480     3                 (0.eq.ga_nbtest(nbh_objv7)) ) then
481
482#ifdef DEBUG_PRINT
483                        !write(6,55) IAm,'ccsd_doooo1'
484#endif
485
486#ifdef DEBUG_PRINT
487                        tt0 = ga_wtime()
488#endif
489                        call ccsd_doooo(Jkj(1+(k-klo)*lnov),
490     $                       Kkj(1+(k-klo)*lnov),Tia,Xia,
491     $                       f1n,f2n,f3n,f4n,nocc,nvir)
492#ifdef DEBUG_PRINT
493                        tt1 = ga_wtime()
494                        trp_time(20) = trp_time(20) + (tt1-tt0)
495#endif
496c
497                        need_ccsd_doooo1 = .false.
498c
499                  endif
500                endif
501c
502c     sum(l) (Jkj, Kkj)cl  * Tli,ba -> Fbc
503c
504c      g_objo, Jij, nbh_objo5
505c      g_objo, Kij, nbh_objo6
506c      g_objv, Tka, nbh_objv2
507c      g_objv, Xka, nbh_objv3
508c
!
!     NOTE(review): the formula comment above names Jkj, Kkj, but the
!     handles tested and the call below use Jij, Kij with Tka, Xka.
!     Confirm which description is intended.
!
509                if ( need_ccsd_doooo2 ) then
510                  if ( (0.eq.ga_nbtest(nbh_objo5)) .and.
511     1                 (0.eq.ga_nbtest(nbh_objo6)) .and.
512     2                 (0.eq.ga_nbtest(nbh_objv2)) .and.
513     3                 (0.eq.ga_nbtest(nbh_objv3)) ) then
514
515#ifdef DEBUG_PRINT
516                        !write(6,55) IAm,'ccsd_doooo2'
517#endif
518
519#ifdef DEBUG_PRINT
520                        tt0 = ga_wtime()
521#endif
522                        call ccsd_doooo(Jij, Kij,
523     $                       Tka(1+(k-klo)*lnov),Xka(1+(k-klo)*lnov),
524     $                       f1t,f2t,f3t,f4t,nocc,nvir)
525#ifdef DEBUG_PRINT
526                        tt1 = ga_wtime()
527                        trp_time(21) = trp_time(21) + (tt1-tt0)
528#endif
529c
530                        need_ccsd_doooo2 = .false.
531c
532                  endif
533                endif
534c
535            enddo ! while need...
536
537#ifdef DEBUG_PRINT
538            !write(6,*) IAm,'after do-while loop and before ga_nbwaits'
539#endif
540c
541c      g_objv, Dja, nbh_objv1
542c      g_objv, Djka(1+(k-klo)*nvir), nbh_objv4(k)
543c      g_objv, Djia, nbh_objv5
544c
545c       just do waits since it is unlikely that these get calls
546c       will not finish during the time that ccsd_do... is running
547c
!
!     These three waits are executed for every (i,k) pair, so
!     nbh_objv1 is waited on repeatedly within a task, nbh_objv4(k)
!     once per i, and nbh_objv5 repeatedly within one i.
!     NOTE(review): this assumes ga_nbwait on an already completed
!     handle returns immediately - confirm against the GA docs.
!
548                        call ga_nbwait(nbh_objv1)
549                        !do k = klo, khi
550                           call ga_nbwait(nbh_objv4(k))
551                        !enddo
552                        call ga_nbwait(nbh_objv5)
553c
554#ifdef DEBUG_PRINT
555                        !write(6,*) IAm,'after ga_nbwaits'
556#endif
557c
!
!     eaijk combines the eorb entries for i, j and k minus the entry
!     for a; it is passed unchanged to both ccsd_tengy calls.
!
558                        eaijk=eorb(ncor+i)+eorb(ncor+j)+eorb(ncor+k)-
559     $                       eorb(a)
560
561#ifdef DEBUG_PRINT
562                        tt0 = ga_wtime()
563#endif
564                        call ccsd_tengy(f1n,f1t,f2n,f2t,
565     $                                  f3n,f3t,f4n,f4t,
566     &                                  Dja(1+(i-1)*nvir),Djia,
567     $                                  t1((k-1)*nvir+1),
568     $                                  eorb,eaijk,emp4,emp5,
569     $                                  ncor,nocc,nvir)
570#ifdef DEBUG_PRINT
571                        tt1 = ga_wtime()
572                        trp_time(22) = trp_time(22) + (tt1-tt0)
573#endif
574c
!
!     Second evaluation with the n and t work arrays swapped and the
!     roles of i and k exchanged; skipped when i equals k.
!
575                        if (i.ne.k)then
576#ifdef DEBUG_PRINT
577                           tt0 = ga_wtime()
578#endif
579                           call ccsd_tengy(f1t,f1n,f2t,f2n,
580     $                                     f3t,f3n,f4t,f4n,
581     $                                     Dja(1+(k-1)*nvir),
582     $                                     Djka(1+(k-klo)*nvir),
583     $                                     t1((i-1)*nvir+1),
584     $                                     eorb,eaijk,emp4,emp5,
585     $                                     ncor,nocc,nvir)
586#ifdef DEBUG_PRINT
587                        tt1 = ga_wtime()
588                        trp_time(23) = trp_time(23) + (tt1-tt0)
589#endif
590c
591                        end if
592
593
594                     end do
595                  end do
596                  if (iprt.gt.50)then
597                     write(6,1234)iam,a,j,emp4,emp5
598 1234                format(' iam aijk',3i5,2e15.5)
599                  end if
600                  next=nxtask(nodes, 1)
601               end if
602            end do
603         end do
604      end do
605c
606#ifdef DEBUG_PRINT
607c
608      do tt = 1, 17
609        !write(6,97) IAm,tt,trp_time(tt)
610        trp_time(24) = trp_time(24) + trp_time(tt)
611      enddo
612      call util_flush(6)
613   97 format('node ',i5,': ga_nbget timer(',i2,') = ',1e15.5)
614c
615      do tt = 18, 21
616        !write(6,98) IAm,tt,trp_time(tt)
617        trp_time(25) = trp_time(25) + trp_time(tt)
618      enddo
619      call util_flush(6)
620   98 format('node ',i5,': dgemm    timer(',i2,') = ',1e15.5)
621c
622      do tt = 22, 23
623        !write(6,99) IAm,tt,trp_time(tt)
624        trp_time(26) = trp_time(26) + trp_time(tt)
625      enddo
626      call util_flush(6)
627   99 format('node ',i5,': tengy    timer(',i2,') = ',1e15.5)
628c
629      call ga_sync()
630      if (IAm.eq.0) write(6,87)
631   87 format(2x,'node',6x,'ga_nbget',9x,'dgemm',10x,'tengy')
632      call ga_sync()
633      write(6,88) IAm,trp_time(24),trp_time(25),trp_time(26)
634   88 format(i7,3e15.5)
635c
636#endif
637c
!
!     Final nxtask call with a negated process count, followed by a
!     global sync (presumably this releases the shared task counter -
!     confirm against the nxtask implementation).
!
638      next=nxtask(-nodes, 1)
639      call ga_sync
640      if (occsdps) then
641         call pstat_off(ps_trpdrv)
642      else
643         call qexit('trpdrv',0)
644      endif
645c
646   55                   format('node ',i5,': ',a12)
647c
648      end
649
650