1 subroutine ccsd_trpdrv_nb(t1, 2 & f1n,f1t,f2n,f2t,f3n,f3t,f4n,f4t,eorb, 3 & g_objo,g_objv,g_coul,g_exch, 4 & ncor,nocc,nvir,iprt,emp4,emp5, 5 & oseg_lo,oseg_hi, 6 $ kchunk, Tij, Tkj, Tia, Tka, Xia, Xka, Jia, Jka, Kia, Kka, 7 $ Jij, Jkj, Kij, Kkj, Dja, Djka, Djia) 8c 9C $Id$ 10c 11c CCSD(T) non-blocking modifications written by 12c Jeff Hammond, Argonne Leadership Computing Facility 13c Fall 2009 14c 15 implicit none 16c 17#include "global.fh" 18#include "ccsd_len.fh" 19#include "ccsdps.fh" 20c 21 double precision t1(*), 22 & f1n(*),f1t(*),f2n(*), 23 & f2t(*),f3n(*),f3t(*),f4n(*),f4t(*),eorb(*), 24 & emp4,emp5 25 double precision Tij(*), Tkj(*), Tia(*), Tka(*), Xia(*), Xka(*), 26 $ Jia(*), Jka(*), Kia(*), Kka(*), 27 $ Jij(*), Jkj(*), Kij(*), Kkj(*), Dja(*), Djka(*), Djia(*) 28 29 integer g_objo,g_objv,ncor,nocc,nvir,iprt,g_coul, 30 & g_exch,oseg_lo,oseg_hi 31c 32 double precision eaijk 33 integer a,i,j,k,akold,av,inode,len,ad3,next 34 integer nxtask 35 external nxtask 36c 37 Integer Nodes, IAm 38c 39 integer klo, khi, start, end 40 integer kchunk 41c 42c================================================== 43c 44c NON-BLOCKING stuff 45c 46c================================================== 47c 48c Dependencies (global array, local array, handle): 49c 50c g_objv, Dja, nbh_objv1 51c g_objv, Tka, nbh_objv2 52c g_objv, Xka, nbh_objv3 53c g_objv, Djka(1+(k-klo)*nvir), nbh_objv4(k) 54c g_objv, Djia, nbh_objv5 55c g_objv, Tia, nbh_objv6 56c g_objv, Xia, nbh_objv7 57c g_objo, Tkj, nbh_objo1 58c g_objo, Jkj, nbh_objo2 59c g_objo, Kkj, nbh_objo3 60c g_objo, Tij, nbh_objo4 61c g_objo, Jij, nbh_objo5 62c g_objo, Kij, nbh_objo6 63c g_exch, Kka, nbh_exch1 64c g_exch, Kia, nbh_exch2 65c g_coul, Jka, nbh_coul1 66c g_coul, Jia, nbh_coul2 67c 68c non-blocking handles 69c 70 integer nbh_objv1,nbh_objv2,nbh_objv3 71 integer nbh_objv5,nbh_objv6,nbh_objv7 72 integer nbh_objv4(nocc) 73c 74 integer nbh_objo1,nbh_objo2,nbh_objo3 75 integer nbh_objo4,nbh_objo5,nbh_objo6 76c 77 integer nbh_exch1,nbh_exch2,nbh_coul1,nbh_coul2 78c 79 logical need_ccsd_dovvv1 80 logical need_ccsd_dovvv2 81 logical need_ccsd_doooo1 82 logical need_ccsd_doooo2 83c 84#ifdef DEBUG_PRINT 85 integer tt 86 double precision tt0,tt1,trp_time(26) 87#endif 88c 89c================================================== 90c 91 double precision zip 92 data zip/0.0d00/ 93c 94 Nodes = GA_NNodes() 95 IAm = GA_NodeID() 96c 97 call ga_sync() 98c 99 if (occsdps) then 100 call pstat_on(ps_trpdrv) 101 else 102 call qenter('trpdrv',0) 103 endif 104 inode=-1 105 next=nxtask(nodes, 1) 106c 107#ifdef DEBUG_PRINT 108 do tt = 1, 26 109 trp_time(tt) = 0.0d0 110 enddo 111#endif 112c 113 do klo = 1, nocc, kchunk 114 akold=0 115 khi = min(nocc, klo+kchunk-1) 116 do a=oseg_lo,oseg_hi 117 av=a-ncor-nocc 118 do j=1,nocc 119 inode=inode+1 120 if (inode.eq.next)then 121c 122c Get Dja = Dci,ja for given j, a, all ci 123c 124 start = 1 + (j-1)*lnov 125 len = lnov 126 end = start + len - 1 127#ifdef DEBUG_PRINT 128 tt0 = ga_wtime() 129#endif 130 call ga_nbget(g_objv,start,end,av,av,Dja,len, 131 1 nbh_objv1) 132#ifdef DEBUG_PRINT 133 tt1 = ga_wtime() 134 trp_time(1) = trp_time(1) + (tt1-tt0) 135#endif 136c 137c Get Tkj = T(b,c,k,j) for given j, klo<=k<=khi, all bc 138c 139 start = (klo-1)*lnvv + 1 140 len = (khi-klo+1)*lnvv 141 end = start + len - 1 142#ifdef DEBUG_PRINT 143 tt0 = ga_wtime() 144#endif 145 call ga_nbget(g_objo,start,end,j,j,Tkj,len, 146 1 nbh_objo1) 147#ifdef DEBUG_PRINT 148 tt1 = ga_wtime() 149 trp_time(2) = trp_time(2) + (tt1-tt0) 150#endif 151c 152c Get Jkj = J(c,l,k,j) for given j, klo<=k<=khi, all cl 153c 154 start = lnovv + (klo-1)*lnov + 1 155 len = (khi-klo+1)*lnov 156 end = start + len - 1 157#ifdef DEBUG_PRINT 158 tt0 = ga_wtime() 159#endif 160 call ga_nbget(g_objo,start,end,j,j,Jkj,len, 161 1 nbh_objo2) 162#ifdef DEBUG_PRINT 163 tt1 = ga_wtime() 164 trp_time(3) = trp_time(3) + (tt1-tt0) 165#endif 166c 167c Get Kkj = K(c,l,k,j) for given j, klo<=k<=khi, all cl 168c 169 start = lnovv + lnoov + (klo-1)*lnov + 1 170 len = (khi-klo+1)*lnov 171 end = start + len - 1 172#ifdef DEBUG_PRINT 173 tt0 = ga_wtime() 174#endif 175 call ga_nbget(g_objo,start,end,j,j,Kkj,len, 176 1 nbh_objo3) 177#ifdef DEBUG_PRINT 178 tt1 = ga_wtime() 179 trp_time(4) = trp_time(4) + (tt1-tt0) 180#endif 181c 182 if (akold .ne. a) then 183 akold = a 184c 185c Get Jka = J(b,c,k,a) for given a, klo<=k<=khi, all bc 186c 187 start = (a-oseg_lo)*nocc + klo 188 len = (khi-klo+1) 189 end = start + len - 1 190#ifdef DEBUG_PRINT 191 tt0 = ga_wtime() 192#endif 193 call ga_nbget(g_coul,1,lnvv,start,end,Jka,lnvv, 194 1 nbh_coul1) 195#ifdef DEBUG_PRINT 196 tt1 = ga_wtime() 197 trp_time(5) = trp_time(5) + (tt1-tt0) 198#endif 199c 200c Get Kka = K(b,c,k,a) for given a, klo<=k<=khi, all bc 201c 202 start = (a-oseg_lo)*nocc + klo 203 len = (khi-klo+1) 204 end = start + len - 1 205#ifdef DEBUG_PRINT 206 tt0 = ga_wtime() 207#endif 208 call ga_nbget(g_exch,1,lnvv,start,end,Kka,lnvv, 209 1 nbh_exch1) 210#ifdef DEBUG_PRINT 211 tt1 = ga_wtime() 212 trp_time(6) = trp_time(6) + (tt1-tt0) 213#endif 214c 215c Get Tka = Tbl,ka for given a, klo<=k<=khi, all bl 216c 217 start = 1 + lnoov + (klo-1)*lnov 218 len = (khi-klo+1)*lnov 219 end = start + len - 1 220#ifdef DEBUG_PRINT 221 tt0 = ga_wtime() 222#endif 223 call ga_nbget(g_objv,start,end,av,av,Tka,len, 224 1 nbh_objv2) 225#ifdef DEBUG_PRINT 226 tt1 = ga_wtime() 227 trp_time(7) = trp_time(7) + (tt1-tt0) 228#endif 229c 230c Get Xka = Tal,kb for given a, klo<=k<=khi, all bl 231c 232 start = 1 + lnoov + lnoov + (klo-1)*lnov 233 len = (khi-klo+1)*lnov 234 end = start + len - 1 235#ifdef DEBUG_PRINT 236 tt0 = ga_wtime() 237#endif 238 call ga_nbget(g_objv,start,end,av,av,Xka,len, 239 1 nbh_objv3) 240#ifdef DEBUG_PRINT 241 tt1 = ga_wtime() 242 trp_time(8) = trp_time(8) + (tt1-tt0) 243#endif 244 endif 245c 246c Get Djka = Dcj,ka for given j, a, klo<=k<=khi, all c 247c 248 do k = klo, khi 249 start = 1 + (j-1)*nvir + (k-1)*lnov 250 len = nvir 251 end = start + len - 1 252#ifdef DEBUG_PRINT 253 tt0 = ga_wtime() 254#endif 255 call ga_nbget(g_objv,start,end,av,av, 256 1 Djka(1+(k-klo)*nvir),len,nbh_objv4(k)) ! k <= nocc 257#ifdef DEBUG_PRINT 258 tt1 = ga_wtime() 259 trp_time(9) = trp_time(9) + (tt1-tt0) 260#endif 261 enddo 262c 263 do i=1,nocc 264c 265c Get Tij = T(b,c,i,j) for given j, i, all bc 266c 267 start = (i-1)*lnvv + 1 268 len = lnvv 269 end = start + len - 1 270#ifdef DEBUG_PRINT 271 tt0 = ga_wtime() 272#endif 273 call ga_nbget(g_objo,start,end,j,j,Tij,len, 274 1 nbh_objo4) 275#ifdef DEBUG_PRINT 276 tt1 = ga_wtime() 277 trp_time(10) = trp_time(10) + (tt1-tt0) 278#endif 279c 280c Get Jij = J(c,l,i,j) for given j, i, all cl 281c 282 start = lnovv + (i-1)*lnov + 1 283 len = lnov 284 end = start + len - 1 285#ifdef DEBUG_PRINT 286 tt0 = ga_wtime() 287#endif 288 call ga_nbget(g_objo,start,end,j,j,Jij,len, 289 1 nbh_objo5) 290#ifdef DEBUG_PRINT 291 tt1 = ga_wtime() 292 trp_time(11) = trp_time(11) + (tt1-tt0) 293#endif 294c 295c Get Kij = K(c,l,i,j) for given j, i, all cl 296c 297 start = lnovv + lnoov + (i-1)*lnov + 1 298 len = lnov 299 end = start + len - 1 300#ifdef DEBUG_PRINT 301 tt0 = ga_wtime() 302#endif 303 call ga_nbget(g_objo,start,end,j,j,Kij,len, 304 1 nbh_objo6) 305#ifdef DEBUG_PRINT 306 tt1 = ga_wtime() 307 trp_time(12) = trp_time(12) + (tt1-tt0) 308#endif 309c 310c Get Jia = J(b,c,i,a) for given a, i, all bc 311c 312 start = (a-oseg_lo)*nocc + i 313 len = 1 314 end = start + len - 1 315#ifdef DEBUG_PRINT 316 tt0 = ga_wtime() 317#endif 318 call ga_nbget(g_coul,1,lnvv,start,end,Jia,lnvv, 319 1 nbh_coul2) 320#ifdef DEBUG_PRINT 321 tt1 = ga_wtime() 322 trp_time(13) = trp_time(13) + (tt1-tt0) 323#endif 324c 325c Get Kia = K(b,c,i,a) for given a, i, all bc 326c 327 start = (a-oseg_lo)*nocc + i 328 len = 1 329 end = start + len - 1 330#ifdef DEBUG_PRINT 331 tt0 = ga_wtime() 332#endif 333 call ga_nbget(g_exch,1,lnvv,start,end,Kia,lnvv, 334 1 nbh_exch2) 335#ifdef DEBUG_PRINT 336 tt1 = ga_wtime() 337 trp_time(14) = trp_time(14) + (tt1-tt0) 338#endif 339c 340c Get Dia = Dcj,ia for given j, i, a, all c 341c 342 start = 1 + (j-1)*nvir + (i-1)*lnov 343 len = nvir 344 end = start + len - 1 345#ifdef DEBUG_PRINT 346 tt0 = ga_wtime() 347#endif 348 call ga_nbget(g_objv,start,end,av,av,Djia,len, 349 1 nbh_objv5) 350#ifdef DEBUG_PRINT 351 tt1 = ga_wtime() 352 trp_time(15) = trp_time(15) + (tt1-tt0) 353#endif 354c 355c Get Tia = Tbl,ia for given a, i, all bl 356c 357 start = 1 + lnoov + (i-1)*lnov 358 len = lnov 359 end = start + len - 1 360#ifdef DEBUG_PRINT 361 tt0 = ga_wtime() 362#endif 363 call ga_nbget(g_objv,start,end,av,av,Tia,len, 364 1 nbh_objv6) 365#ifdef DEBUG_PRINT 366 tt1 = ga_wtime() 367 trp_time(16) = trp_time(16) + (tt1-tt0) 368#endif 369c 370c Get Xia = Tal,ib for given a, i, all bl 371c 372 start = 1 + lnoov + lnoov + (i-1)*lnov 373 len = lnov 374 end = start + len - 1 375#ifdef DEBUG_PRINT 376 tt0 = ga_wtime() 377#endif 378 call ga_nbget(g_objv,start,end,av,av,Xia,len, 379 1 nbh_objv7) 380#ifdef DEBUG_PRINT 381 tt1 = ga_wtime() 382 trp_time(17) = trp_time(17) + (tt1-tt0) 383#endif 384c 385 do k=klo,min(khi,i) 386 call dfill(lnvv,zip,f1n,1) 387 call dfill(lnvv,zip,f1t,1) 388 call dfill(lnvv,zip,f2n,1) 389 call dfill(lnvv,zip,f2t,1) 390 call dfill(lnvv,zip,f3n,1) 391 call dfill(lnvv,zip,f3t,1) 392 call dfill(lnvv,zip,f4n,1) 393 call dfill(lnvv,zip,f4t,1) 394c 395 need_ccsd_dovvv1 = .true. 396 need_ccsd_dovvv2 = .true. 397 need_ccsd_doooo1 = .true. 398 need_ccsd_doooo2 = .true. 399c 400#ifdef DEBUG_PRINT 401 !write(6,*) IAm,'before do-while loop' 402#endif 403c 404 do while ( need_ccsd_dovvv1 .or. need_ccsd_dovvv2 405 1 .or. need_ccsd_doooo1 .or. need_ccsd_doooo2 ) 406c 407c sum(d) (Jia, Kia)bd * Tkj,cd -> Fbc 408c 409c g_coul, Jia, nbh_coul2 410c g_exch, Kia, nbh_exch2 411c g_objo, Tkj, nbh_objo1 412c 413 if ( need_ccsd_dovvv1 ) then 414 if ( (0.eq.ga_nbtest(nbh_coul2)) .and. 415 1 (0.eq.ga_nbtest(nbh_exch2)) .and. 416 2 (0.eq.ga_nbtest(nbh_objo1)) ) then 417 418#ifdef DEBUG_PRINT 419 !write(6,55) IAm,'ccsd_dovvv1' 420#endif 421 422#ifdef DEBUG_PRINT 423 tt0 = ga_wtime() 424#endif 425 call ccsd_dovvv(Jia, Kia, 426 $ Tkj(1+(k-klo)*lnvv), 427 $ f1n,f2n,f3n,f4n,nocc,nvir) 428#ifdef DEBUG_PRINT 429 tt1 = ga_wtime() 430 trp_time(18) = trp_time(18) + (tt1-tt0) 431#endif 432c 433 need_ccsd_dovvv1 = .false. 434c 435 endif 436 endif 437c 438c sum(d) (Jka, Kka)bd * Tij,cd -> Fbc 439c 440c g_coul, Jka, nbh_coul1 441c g_exch, Kka, nbh_exch1 442c g_objo, Tij, nbh_objo4 443c 444 if ( need_ccsd_dovvv2 ) then 445 if ( (0.eq.ga_nbtest(nbh_coul1)) .and. 446 1 (0.eq.ga_nbtest(nbh_exch1)) .and. 447 2 (0.eq.ga_nbtest(nbh_objo4)) ) then 448 449#ifdef DEBUG_PRINT 450 !write(6,55) IAm,'ccsd_dovvv2' 451#endif 452 453#ifdef DEBUG_PRINT 454 tt0 = ga_wtime() 455#endif 456 call ccsd_dovvv(Jka(1+(k-klo)*lnvv), 457 $ Kka(1+(k-klo)*lnvv),Tij, 458 $ f1t,f2t,f3t,f4t,nocc,nvir) 459#ifdef DEBUG_PRINT 460 tt1 = ga_wtime() 461 trp_time(19) = trp_time(19) + (tt1-tt0) 462#endif 463c 464 need_ccsd_dovvv2 = .false. 465c 466 endif 467 endif 468c 469c sum(l) (Jij, Kij)cl * Tkl,ab -> Fbc 470c 471c g_objo, Jkj, nbh_objo2 472c g_objo, Kkj, nbh_objo3 473c g_objv, Tia, nbh_objv6 474c g_objv, Xia, nbh_objv7 475c 476 if ( need_ccsd_doooo1 ) then 477 if ( (0.eq.ga_nbtest(nbh_objo2)) .and. 478 1 (0.eq.ga_nbtest(nbh_objo3)) .and. 479 2 (0.eq.ga_nbtest(nbh_objv6)) .and. 480 3 (0.eq.ga_nbtest(nbh_objv7)) ) then 481 482#ifdef DEBUG_PRINT 483 !write(6,55) IAm,'ccsd_doooo1' 484#endif 485 486#ifdef DEBUG_PRINT 487 tt0 = ga_wtime() 488#endif 489 call ccsd_doooo(Jkj(1+(k-klo)*lnov), 490 $ Kkj(1+(k-klo)*lnov),Tia,Xia, 491 $ f1n,f2n,f3n,f4n,nocc,nvir) 492#ifdef DEBUG_PRINT 493 tt1 = ga_wtime() 494 trp_time(20) = trp_time(20) + (tt1-tt0) 495#endif 496c 497 need_ccsd_doooo1 = .false. 498c 499 endif 500 endif 501c 502c sum(l) (Jkj, Kkj)cl * Tli,ba -> Fbc 503c 504c g_objo, Jij, nbh_objo5 505c g_objo, Kij, nbh_objo6 506c g_objv, Tka, nbh_objv2 507c g_objv, Xka, nbh_objv3 508c 509 if ( need_ccsd_doooo2 ) then 510 if ( (0.eq.ga_nbtest(nbh_objo5)) .and. 511 1 (0.eq.ga_nbtest(nbh_objo6)) .and. 512 2 (0.eq.ga_nbtest(nbh_objv2)) .and. 513 3 (0.eq.ga_nbtest(nbh_objv3)) ) then 514 515#ifdef DEBUG_PRINT 516 !write(6,55) IAm,'ccsd_doooo2' 517#endif 518 519#ifdef DEBUG_PRINT 520 tt0 = ga_wtime() 521#endif 522 call ccsd_doooo(Jij, Kij, 523 $ Tka(1+(k-klo)*lnov),Xka(1+(k-klo)*lnov), 524 $ f1t,f2t,f3t,f4t,nocc,nvir) 525#ifdef DEBUG_PRINT 526 tt1 = ga_wtime() 527 trp_time(21) = trp_time(21) + (tt1-tt0) 528#endif 529c 530 need_ccsd_doooo2 = .false. 531c 532 endif 533 endif 534c 535 enddo ! while need... 536 537#ifdef DEBUG_PRINT 538 !write(6,*) IAm,'after do-while loop and before ga_nbwaits' 539#endif 540c 541c g_objv, Dja, nbh_objv1 542c g_objv, Djka(1+(k-klo)*nvir), nbh_objv4(k) 543c g_objv, Djia, nbh_objv5 544c 545c just do waits since it is unlikely that these get calls 546c will not finish during the time that ccsd_do... is running 547c 548 call ga_nbwait(nbh_objv1) 549 !do k = klo, khi 550 call ga_nbwait(nbh_objv4(k)) 551 !enddo 552 call ga_nbwait(nbh_objv5) 553c 554#ifdef DEBUG_PRINT 555 !write(6,*) IAm,'after ga_nbwaits' 556#endif 557c 558 eaijk=eorb(ncor+i)+eorb(ncor+j)+eorb(ncor+k)- 559 $ eorb(a) 560 561#ifdef DEBUG_PRINT 562 tt0 = ga_wtime() 563#endif 564 call ccsd_tengy(f1n,f1t,f2n,f2t, 565 $ f3n,f3t,f4n,f4t, 566 & Dja(1+(i-1)*nvir),Djia, 567 $ t1((k-1)*nvir+1), 568 $ eorb,eaijk,emp4,emp5, 569 $ ncor,nocc,nvir) 570#ifdef DEBUG_PRINT 571 tt1 = ga_wtime() 572 trp_time(22) = trp_time(22) + (tt1-tt0) 573#endif 574c 575 if (i.ne.k)then 576#ifdef DEBUG_PRINT 577 tt0 = ga_wtime() 578#endif 579 call ccsd_tengy(f1t,f1n,f2t,f2n, 580 $ f3t,f3n,f4t,f4n, 581 $ Dja(1+(k-1)*nvir), 582 $ Djka(1+(k-klo)*nvir), 583 $ t1((i-1)*nvir+1), 584 $ eorb,eaijk,emp4,emp5, 585 $ ncor,nocc,nvir) 586#ifdef DEBUG_PRINT 587 tt1 = ga_wtime() 588 trp_time(23) = trp_time(23) + (tt1-tt0) 589#endif 590c 591 end if 592 593 594 end do 595 end do 596 if (iprt.gt.50)then 597 write(6,1234)iam,a,j,emp4,emp5 598 1234 format(' iam aijk',3i5,2e15.5) 599 end if 600 next=nxtask(nodes, 1) 601 end if 602 end do 603 end do 604 end do 605c 606#ifdef DEBUG_PRINT 607c 608 do tt = 1, 17 609 !write(6,97) IAm,tt,trp_time(tt) 610 trp_time(24) = trp_time(24) + trp_time(tt) 611 enddo 612 call util_flush(6) 613 97 format('node ',i5,': ga_nbget timer(',i2,') = ',1e15.5) 614c 615 do tt = 18, 21 616 !write(6,98) IAm,tt,trp_time(tt) 617 trp_time(25) = trp_time(25) + trp_time(tt) 618 enddo 619 call util_flush(6) 620 98 format('node ',i5,': dgemm timer(',i2,') = ',1e15.5) 621c 622 do tt = 22, 23 623 !write(6,99) IAm,tt,trp_time(tt) 624 trp_time(26) = trp_time(26) + trp_time(tt) 625 enddo 626 call util_flush(6) 627 99 format('node ',i5,': tengy timer(',i2,') = ',1e15.5) 628c 629 call ga_sync() 630 if (IAm.eq.0) write(6,87) 631 87 format(2x,'node',6x,'ga_nbget',9x,'dgemm',10x,'tengy') 632 call ga_sync() 633 write(6,88) IAm,trp_time(24),trp_time(25),trp_time(26) 634 88 format(i7,3e15.5) 635c 636#endif 637c 638 next=nxtask(-nodes, 1) 639 call ga_sync 640 if (occsdps) then 641 call pstat_off(ps_trpdrv) 642 else 643 call qexit('trpdrv',0) 644 endif 645c 646 55 format('node ',i5,': ',a12) 647c 648 end 649 650