1!--------------------------------------------------------------------------------------------------!
2!   CP2K: A general program to perform molecular dynamics simulations                              !
3!   Copyright (C) 2000 - 2020  CP2K developers group                                               !
4!--------------------------------------------------------------------------------------------------!
5! **************************************************************************************************
6!> \brief Methods used with 3-center overlap type integrals containers
7!> \par History
8!>      - none
9!>      - 11.2018 fixed OMP race condition in contract3_o3c routine (A.Bussy)
10! **************************************************************************************************
11MODULE qs_o3c_methods
12   USE ai_contraction_sphi,             ONLY: abc_contract
13   USE ai_overlap3,                     ONLY: overlap3
14   USE basis_set_types,                 ONLY: gto_basis_set_p_type,&
15                                              gto_basis_set_type
16   USE dbcsr_api,                       ONLY: dbcsr_get_block_p,&
17                                              dbcsr_p_type,&
18                                              dbcsr_type
19   USE kinds,                           ONLY: dp
20   USE orbital_pointers,                ONLY: ncoset
21   USE qs_o3c_types,                    ONLY: &
22        get_o3c_container, get_o3c_iterator_info, get_o3c_vec, o3c_container_type, o3c_iterate, &
23        o3c_iterator_create, o3c_iterator_release, o3c_iterator_type, o3c_vec_type, &
24        set_o3c_container
25
26!$ USE OMP_LIB, ONLY: omp_get_max_threads, omp_get_thread_num
27#include "./base/base_uses.f90"
28
29   IMPLICIT NONE
30
31   PRIVATE
32
33   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'qs_o3c_methods'
34
35   PUBLIC :: calculate_o3c_integrals, contract12_o3c, contract3_o3c
36
37CONTAINS
38
! **************************************************************************************************
!> \brief Computes the contracted 3-center overlap integrals (ab|c) for every atom triple visited by
!>        the o3c iterator and stores them in the container, together with a zero-initialised
!>        contraction vector tvec(nk, nspin) and, optionally, force contributions.
!> \param o3c container to fill; the integral, tvec and force slots of each triple must still be
!>        unassociated on entry (asserted)
!> \param calculate_forces if .TRUE., the integral derivatives are also computed and contracted with
!>        the spin-summed matrix_p block, giving force_i/force_j/force_k arrays of shape (nk, 3);
!>        otherwise the force slots are stored as null pointers
!> \param matrix_p one DBCSR matrix per spin (presumably the density matrix); only read when
!>        calculate_forces is .TRUE., in which case it must be present
! **************************************************************************************************
   SUBROUTINE calculate_o3c_integrals(o3c, calculate_forces, matrix_p)
      TYPE(o3c_container_type), POINTER                  :: o3c
      LOGICAL, INTENT(IN), OPTIONAL                      :: calculate_forces
      TYPE(dbcsr_p_type), DIMENSION(:), OPTIONAL, &
         POINTER                                         :: matrix_p

      CHARACTER(LEN=*), PARAMETER :: routineN = 'calculate_o3c_integrals', &
         routineP = moduleN//':'//routineN

      INTEGER :: egfa, egfb, egfc, handle, i, iatom, icol, ikind, irow, iset, ispin, j, jatom, &
         jkind, jset, katom, kkind, kset, mepos, ncoa, ncob, ncoc, ni, nj, nk, nseta, nsetb, &
         nsetc, nspin, nthread, sgfa, sgfb, sgfc
      INTEGER, DIMENSION(:), POINTER                     :: la_max, la_min, lb_max, lb_min, lc_max, &
                                                            lc_min, npgfa, npgfb, npgfc, nsgfa, &
                                                            nsgfb, nsgfc
      INTEGER, DIMENSION(:, :), POINTER                  :: first_sgfa, first_sgfb, first_sgfc
      LOGICAL                                            :: do_force, found, trans
      REAL(KIND=dp)                                      :: dij, dik, djk, fpre
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: pmat
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: sabc, sabc_contr
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :, :)  :: iabdc, iadbc, idabc, sabdc, sdabc
      REAL(KIND=dp), DIMENSION(3)                        :: rij, rik, rjk
      REAL(KIND=dp), DIMENSION(:), POINTER               :: set_radius_a, set_radius_b, set_radius_c
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: fi, fj, fk, pblock, rpgfa, rpgfb, rpgfc, &
                                                            sphi_a, sphi_b, sphi_c, tvec, zeta, &
                                                            zetb, zetc
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: iabc
      TYPE(gto_basis_set_p_type), DIMENSION(:), POINTER  :: basis_set_list_a, basis_set_list_b, &
                                                            basis_set_list_c
      TYPE(gto_basis_set_type), POINTER                  :: basis_set_a, basis_set_b, basis_set_c
      TYPE(o3c_iterator_type)                            :: o3c_iterator

      CALL timeset(routineN, handle)

      do_force = .FALSE.
      IF (PRESENT(calculate_forces)) do_force = calculate_forces
      ! the force contraction below dereferences matrix_p(ispin)%matrix without further checks
      IF (do_force) THEN
         CPASSERT(PRESENT(matrix_p))
      END IF
      CALL get_o3c_container(o3c, nspin=nspin)

      ! basis sets
      CALL get_o3c_container(o3c, basis_set_list_a=basis_set_list_a, &
                             basis_set_list_b=basis_set_list_b, basis_set_list_c=basis_set_list_c)

      nthread = 1
!$    nthread = omp_get_max_threads()
      CALL o3c_iterator_create(o3c, o3c_iterator, nthread=nthread)

!$OMP PARALLEL DEFAULT(NONE) &
!$OMP SHARED (nthread,o3c_iterator,ncoset,nspin,basis_set_list_a,basis_set_list_b,&
!$OMP         basis_set_list_c,do_force,matrix_p)&
!$OMP PRIVATE (mepos,ikind,jkind,kkind,basis_set_a,basis_set_b,basis_set_c,rij,rik,rjk,&
!$OMP          first_sgfa,la_max,la_min,npgfa,nseta,nsgfa,rpgfa,set_radius_a,sphi_a,zeta,&
!$OMP          first_sgfb,lb_max,lb_min,npgfb,nsetb,nsgfb,rpgfb,set_radius_b,sphi_b,zetb,&
!$OMP          first_sgfc,lc_max,lc_min,npgfc,nsetc,nsgfc,rpgfc,set_radius_c,sphi_c,zetc,&
!$OMP          iset,jset,kset,dij,dik,djk,ni,nj,nk,iabc,idabc,iadbc,iabdc,tvec,fi,fj,fk,ncoa,&
!$OMP          ncob,ncoc,sabc,sabc_contr,sdabc,sabdc,sgfa,sgfb,sgfc,egfa,egfb,egfc,i,j,&
!$OMP          pblock,pmat,ispin,iatom,jatom,katom,irow,icol,found,trans,fpre)

      mepos = 0
!$    mepos = omp_get_thread_num()

      DO WHILE (o3c_iterate(o3c_iterator, mepos=mepos) == 0)
         CALL get_o3c_iterator_info(o3c_iterator, mepos=mepos, &
                                    ikind=ikind, jkind=jkind, kkind=kkind, rij=rij, rik=rik, &
                                    integral=iabc, tvec=tvec, force_i=fi, force_j=fj, force_k=fk)
         ! this routine allocates the per-triple arrays itself; they must not exist yet
         CPASSERT(.NOT. ASSOCIATED(iabc))
         CPASSERT(.NOT. ASSOCIATED(tvec))
         CPASSERT(.NOT. ASSOCIATED(fi))
         CPASSERT(.NOT. ASSOCIATED(fj))
         CPASSERT(.NOT. ASSOCIATED(fk))
         ! basis
         basis_set_a => basis_set_list_a(ikind)%gto_basis_set
         basis_set_b => basis_set_list_b(jkind)%gto_basis_set
         basis_set_c => basis_set_list_c(kkind)%gto_basis_set
         ! center A
         first_sgfa => basis_set_a%first_sgf
         la_max => basis_set_a%lmax
         la_min => basis_set_a%lmin
         npgfa => basis_set_a%npgf
         nseta = basis_set_a%nset
         nsgfa => basis_set_a%nsgf_set
         rpgfa => basis_set_a%pgf_radius
         set_radius_a => basis_set_a%set_radius
         sphi_a => basis_set_a%sphi
         zeta => basis_set_a%zet
         ! center B
         first_sgfb => basis_set_b%first_sgf
         lb_max => basis_set_b%lmax
         lb_min => basis_set_b%lmin
         npgfb => basis_set_b%npgf
         nsetb = basis_set_b%nset
         nsgfb => basis_set_b%nsgf_set
         rpgfb => basis_set_b%pgf_radius
         set_radius_b => basis_set_b%set_radius
         sphi_b => basis_set_b%sphi
         zetb => basis_set_b%zet
         ! center C (RI)
         first_sgfc => basis_set_c%first_sgf
         lc_max => basis_set_c%lmax
         lc_min => basis_set_c%lmin
         npgfc => basis_set_c%npgf
         nsetc = basis_set_c%nset
         nsgfc => basis_set_c%nsgf_set
         rpgfc => basis_set_c%pgf_radius
         set_radius_c => basis_set_c%set_radius
         sphi_c => basis_set_c%sphi
         zetc => basis_set_c%zet

         ni = SUM(nsgfa)
         nj = SUM(nsgfb)
         nk = SUM(nsgfc)

         ALLOCATE (iabc(ni, nj, nk))
         iabc(:, :, :) = 0.0_dp
         IF (do_force) THEN
            ALLOCATE (fi(nk, 3), fj(nk, 3), fk(nk, 3))
            fi(:, :) = 0.0_dp
            fj(:, :) = 0.0_dp
            fk(:, :) = 0.0_dp
            ALLOCATE (idabc(ni, nj, nk, 3))
            idabc(:, :, :, :) = 0.0_dp
            ALLOCATE (iadbc(ni, nj, nk, 3))
            iadbc(:, :, :, :) = 0.0_dp
            ALLOCATE (iabdc(ni, nj, nk, 3))
            iabdc(:, :, :, :) = 0.0_dp
         ELSE
            NULLIFY (fi, fj, fk)
         END IF
         ALLOCATE (tvec(nk, nspin))
         tvec(:, :) = 0.0_dp

         rjk(1:3) = rik(1:3) - rij(1:3)
         dij = NORM2(rij)
         dik = NORM2(rik)
         djk = NORM2(rjk)

         DO iset = 1, nseta
            DO jset = 1, nsetb
               ! screen set pairs whose radii do not overlap
               IF (set_radius_a(iset) + set_radius_b(jset) < dij) CYCLE
               DO kset = 1, nsetc
                  IF (set_radius_a(iset) + set_radius_c(kset) < dik) CYCLE
                  IF (set_radius_b(jset) + set_radius_c(kset) < djk) CYCLE

                  ncoa = npgfa(iset)*ncoset(la_max(iset))
                  ncob = npgfb(jset)*ncoset(lb_max(jset))
                  ncoc = npgfc(kset)*ncoset(lc_max(kset))
                  ! empty Cartesian set combination: nothing to compute and nothing to allocate,
                  ! so the allocate/deallocate pairs below always stay balanced
                  IF (ncoa*ncob*ncoc <= 0) CYCLE

                  sgfa = first_sgfa(1, iset)
                  sgfb = first_sgfb(1, jset)
                  sgfc = first_sgfc(1, kset)

                  egfa = sgfa + nsgfa(iset) - 1
                  egfb = sgfb + nsgfb(jset) - 1
                  egfc = sgfc + nsgfc(kset) - 1

                  ALLOCATE (sabc(ncoa, ncob, ncoc))
                  sabc(:, :, :) = 0.0_dp
                  IF (do_force) THEN
                     ALLOCATE (sdabc(ncoa, ncob, ncoc, 3))
                     sdabc(:, :, :, :) = 0.0_dp
                     ALLOCATE (sabdc(ncoa, ncob, ncoc, 3))
                     sabdc(:, :, :, :) = 0.0_dp
                     CALL overlap3(la_max(iset), npgfa(iset), zeta(:, iset), rpgfa(:, iset), la_min(iset), &
                                   lb_max(jset), npgfb(jset), zetb(:, jset), rpgfb(:, jset), lb_min(jset), &
                                   lc_max(kset), npgfc(kset), zetc(:, kset), rpgfc(:, kset), lc_min(kset), &
                                   rij, dij, rik, dik, rjk, djk, sabc, sdabc, sabdc)
                  ELSE
                     CALL overlap3(la_max(iset), npgfa(iset), zeta(:, iset), rpgfa(:, iset), la_min(iset), &
                                   lb_max(jset), npgfb(jset), zetb(:, jset), rpgfb(:, jset), lb_min(jset), &
                                   lc_max(kset), npgfc(kset), zetc(:, kset), rpgfc(:, kset), lc_min(kset), &
                                   rij, dij, rik, dik, rjk, djk, sabc)
                  END IF
                  ALLOCATE (sabc_contr(nsgfa(iset), nsgfb(jset), nsgfc(kset)))

                  ! Cartesian primitives -> contracted spherical functions, then scatter into the
                  ! atom-triple block at the set's offsets
                  CALL abc_contract(sabc_contr, sabc, &
                                    sphi_a(:, sgfa:), sphi_b(:, sgfb:), sphi_c(:, sgfc:), &
                                    ncoa, ncob, ncoc, nsgfa(iset), nsgfb(jset), nsgfc(kset))
                  iabc(sgfa:egfa, sgfb:egfb, sgfc:egfc) = &
                     sabc_contr(1:nsgfa(iset), 1:nsgfb(jset), 1:nsgfc(kset))
                  IF (do_force) THEN
                     DO i = 1, 3
                        CALL abc_contract(sabc_contr, sdabc(:, :, :, i), &
                                          sphi_a(:, sgfa:), sphi_b(:, sgfb:), sphi_c(:, sgfc:), &
                                          ncoa, ncob, ncoc, nsgfa(iset), nsgfb(jset), nsgfc(kset))
                        idabc(sgfa:egfa, sgfb:egfb, sgfc:egfc, i) = &
                           sabc_contr(1:nsgfa(iset), 1:nsgfb(jset), 1:nsgfc(kset))
                        CALL abc_contract(sabc_contr, sabdc(:, :, :, i), &
                                          sphi_a(:, sgfa:), sphi_b(:, sgfb:), sphi_c(:, sgfc:), &
                                          ncoa, ncob, ncoc, nsgfa(iset), nsgfb(jset), nsgfc(kset))
                        iabdc(sgfa:egfa, sgfb:egfb, sgfc:egfc, i) = &
                           sabc_contr(1:nsgfa(iset), 1:nsgfb(jset), 1:nsgfc(kset))
                     END DO
                     ! released here, inside the branch that allocated them
                     DEALLOCATE (sdabc, sabdc)
                  END IF

                  DEALLOCATE (sabc_contr)
                  DEALLOCATE (sabc)
               END DO
            END DO
         END DO
         IF (do_force) THEN
            ! translational invariance: the derivative on the middle center follows from the other two
            iadbc(:, :, :, :) = -idabc(:, :, :, :) - iabdc(:, :, :, :)
            !
            ! get the atom indices
            CALL get_o3c_iterator_info(o3c_iterator, mepos=mepos, &
                                       iatom=iatom, jatom=jatom, katom=katom)
            !
            ! contract over i and j to get forces; blocks are addressed with irow <= icol and
            ! transposed back when iatom > jatom
            IF (iatom <= jatom) THEN
               irow = iatom
               icol = jatom
               trans = .FALSE.
            ELSE
               irow = jatom
               icol = iatom
               trans = .TRUE.
            END IF
            ! off-diagonal atom pairs get weight 2 (presumably the (j,i) partner is not visited)
            IF (iatom == jatom) THEN
               fpre = 1.0_dp
            ELSE
               fpre = 2.0_dp
            END IF
            ! spin-summed matrix block
            ALLOCATE (pmat(ni, nj))
            pmat(:, :) = 0.0_dp
            DO ispin = 1, nspin
               CALL dbcsr_get_block_p(matrix=matrix_p(ispin)%matrix, &
                                      row=irow, col=icol, BLOCK=pblock, found=found)
               IF (found) THEN
                  IF (trans) THEN
                     pmat(:, :) = pmat(:, :) + TRANSPOSE(pblock(:, :))
                  ELSE
                     pmat(:, :) = pmat(:, :) + pblock(:, :)
                  END IF
               END IF
            END DO
            DO i = 1, 3
               DO j = 1, nk
                  fi(j, i) = fpre*SUM(pmat(:, :)*idabc(:, :, j, i))
                  fj(j, i) = fpre*SUM(pmat(:, :)*iadbc(:, :, j, i))
                  fk(j, i) = fpre*SUM(pmat(:, :)*iabdc(:, :, j, i))
               END DO
            END DO
            DEALLOCATE (pmat)
            !
            DEALLOCATE (idabc, iadbc, iabdc)
         END IF
         !
         ! hand ownership of the freshly allocated arrays to the container
         CALL set_o3c_container(o3c_iterator, mepos=mepos, &
                                integral=iabc, tvec=tvec, force_i=fi, force_j=fj, force_k=fk)

      END DO
!$OMP END PARALLEL
      CALL o3c_iterator_release(o3c_iterator)

      CALL timestop(handle)

   END SUBROUTINE calculate_o3c_integrals
306
! **************************************************************************************************
!> \brief Contraction of 3-tensor over indices 1 and 2 (assuming symmetry)
!>        t(k) = sum_ij (ijk)*p(ij)
!>        The result overwrites the tvec array stored with each atom triple of the container.
!>        A block of p that is absent from the sparse matrix counts as zero, so the matching
!>        tvec column is reset to zero instead of keeping values from an earlier call.
!> \param o3c container holding the integrals; must have been built with ijsymmetric (asserted)
!> \param matrix_p one DBCSR matrix per spin; matrix_p(ispin) fills column ispin of each tvec
! **************************************************************************************************
   SUBROUTINE contract12_o3c(o3c, matrix_p)
      TYPE(o3c_container_type), POINTER                  :: o3c
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_p

      CHARACTER(LEN=*), PARAMETER :: routineN = 'contract12_o3c', routineP = moduleN//':'//routineN

      INTEGER                                            :: handle, iatom, icol, ik, irow, ispin, &
                                                            jatom, mepos, nk, nspin, nthread
      LOGICAL                                            :: found, ijsymmetric, trans
      REAL(KIND=dp)                                      :: fpre
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: pmat
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: pblock, tvec
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: iabc
      TYPE(o3c_iterator_type)                            :: o3c_iterator

      CALL timeset(routineN, handle)

      nspin = SIZE(matrix_p, 1)
      CALL get_o3c_container(o3c, ijsymmetric=ijsymmetric)
      CPASSERT(ijsymmetric)

      nthread = 1
!$    nthread = omp_get_max_threads()
      CALL o3c_iterator_create(o3c, o3c_iterator, nthread=nthread)

!$OMP PARALLEL DEFAULT(NONE) &
!$OMP SHARED (nthread,o3c_iterator,matrix_p,nspin)&
!$OMP PRIVATE (mepos,ispin,iatom,jatom,ik,nk,irow,icol,iabc,tvec,found,pblock,pmat,trans,fpre)

      mepos = 0
!$    mepos = omp_get_thread_num()

      DO WHILE (o3c_iterate(o3c_iterator, mepos=mepos) == 0)
         CALL get_o3c_iterator_info(o3c_iterator, mepos=mepos, iatom=iatom, jatom=jatom, &
                                    integral=iabc, tvec=tvec)
         nk = SIZE(tvec, 1)

         ! matrix blocks are addressed with irow <= icol; transpose back when iatom > jatom
         IF (iatom <= jatom) THEN
            irow = iatom
            icol = jatom
            trans = .FALSE.
         ELSE
            irow = jatom
            icol = iatom
            trans = .TRUE.
         END IF
         ! off-diagonal atom pairs count twice (symmetric container)
         IF (iatom == jatom) THEN
            fpre = 1.0_dp
         ELSE
            fpre = 2.0_dp
         END IF

         DO ispin = 1, nspin
            CALL dbcsr_get_block_p(matrix=matrix_p(ispin)%matrix, &
                                   row=irow, col=icol, BLOCK=pblock, found=found)
            IF (.NOT. found) THEN
               ! absent block == zero block: t(k) = 0, do not keep a stale previous result
               tvec(:, ispin) = 0.0_dp
               CYCLE
            END IF
            IF (trans) THEN
               ! transpose once per block rather than once per RI function
               ALLOCATE (pmat(SIZE(pblock, 2), SIZE(pblock, 1)))
               pmat(:, :) = TRANSPOSE(pblock(:, :))
               DO ik = 1, nk
                  tvec(ik, ispin) = fpre*SUM(pmat(:, :)*iabc(:, :, ik))
               END DO
               DEALLOCATE (pmat)
            ELSE
               DO ik = 1, nk
                  tvec(ik, ispin) = fpre*SUM(pblock(:, :)*iabc(:, :, ik))
               END DO
            END IF
         END DO

      END DO
!$OMP END PARALLEL
      CALL o3c_iterator_release(o3c_iterator)

      CALL timestop(handle)

   END SUBROUTINE contract12_o3c
387
! **************************************************************************************************
!> \brief Contraction of 3-tensor over index 3
!>        h(ij) = h(ij) + sum_k (ijk)*v(k)
!>        Only blocks already present in matrix are updated; missing blocks are skipped, not created.
!> \param o3c container holding the integrals; must have been built with ijsymmetric (asserted)
!> \param vec per-atom vectors; the one belonging to the triple's third atom supplies v(k)
!> \param matrix DBCSR matrix that is accumulated into in place
! **************************************************************************************************
   SUBROUTINE contract3_o3c(o3c, vec, matrix)
      TYPE(o3c_container_type), POINTER                  :: o3c
      TYPE(o3c_vec_type), DIMENSION(:), POINTER          :: vec
      TYPE(dbcsr_type), POINTER                          :: matrix

      CHARACTER(LEN=*), PARAMETER :: routineN = 'contract3_o3c', routineP = moduleN//':'//routineN

      INTEGER                                            :: handle, iatom, icol, ik, irow, jatom, &
                                                            katom, mepos, nc, nk, nr, nthread
      LOGICAL                                            :: found, ijsymmetric, trans
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: hwork
      REAL(KIND=dp), DIMENSION(:), POINTER               :: v
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: pblock
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: iabc
      TYPE(o3c_iterator_type)                            :: o3c_iterator

      CALL timeset(routineN, handle)

      CALL get_o3c_container(o3c, ijsymmetric=ijsymmetric)
      CPASSERT(ijsymmetric)

      nthread = 1
!$    nthread = omp_get_max_threads()
      CALL o3c_iterator_create(o3c, o3c_iterator, nthread=nthread)

!$OMP PARALLEL DEFAULT(NONE) &
!$OMP SHARED (nthread,o3c_iterator,vec,matrix)&
!$OMP PRIVATE (mepos,iabc,iatom,jatom,katom,irow,icol,trans,pblock,v,found,ik,nk,hwork,nr,nc)

      mepos = 0
!$    mepos = omp_get_thread_num()

      DO WHILE (o3c_iterate(o3c_iterator, mepos=mepos) == 0)
         CALL get_o3c_iterator_info(o3c_iterator, mepos=mepos, iatom=iatom, jatom=jatom, katom=katom, &
                                    integral=iabc)

         CALL get_o3c_vec(vec, katom, v)
         nk = SIZE(v)

         ! blocks are addressed with irow <= icol; integrals are transposed when iatom > jatom
         trans = (iatom > jatom)
         irow = MIN(iatom, jatom)
         icol = MAX(iatom, jatom)

         CALL dbcsr_get_block_p(matrix=matrix, row=irow, col=icol, BLOCK=pblock, found=found)
         IF (.NOT. found) CYCLE

         nr = SIZE(pblock, 1)
         nc = SIZE(pblock, 2)
         ALLOCATE (hwork(nr, nc))
         hwork(:, :) = 0.0_dp

         ! build sum_k (ijk)*v(k) in a thread-private buffer first
         DO ik = 1, nk
            IF (trans) THEN
               CALL daxpy(nr*nc, v(ik), TRANSPOSE(iabc(:, :, ik)), 1, hwork(:, :), 1)
            ELSE
               CALL daxpy(nr*nc, v(ik), iabc(:, :, ik), 1, hwork(:, :), 1)
            END IF
         END DO

         ! Several threads may hold triples sharing (irow, icol) with a different katom (also under
         ! PBC), so the update of the shared block is serialised; the costly accumulation above
         ! stays outside the critical section.
!$OMP CRITICAL
         CALL dbcsr_get_block_p(matrix=matrix, row=irow, col=icol, BLOCK=pblock, found=found)
         CALL daxpy(nr*nc, 1.0_dp, hwork(:, :), 1, pblock(:, :), 1)
!$OMP END CRITICAL

         DEALLOCATE (hwork)

      END DO
!$OMP END PARALLEL
      CALL o3c_iterator_release(o3c_iterator)

      CALL timestop(handle)

   END SUBROUTINE contract3_o3c
480
481END MODULE qs_o3c_methods
482