1!--------------------------------------------------------------------------------------------------!
2!   CP2K: A general program to perform molecular dynamics simulations                              !
3!   Copyright (C) 2000 - 2019  CP2K developers group                                               !
4!--------------------------------------------------------------------------------------------------!
5
6! **************************************************************************************************
7!> \brief   Interface to (sca)lapack for the Cholesky based procedures
8!> \author  VW
9!> \date    2009-09-08
10!> \version 0.8
11!>
12!> <b>Modification history:</b>
13!> - Created 2009-09-08
14! **************************************************************************************************
15MODULE cp_dbcsr_cholesky
16   USE cp_blacs_env,                    ONLY: cp_blacs_env_type
17   USE cp_dbcsr_operations,             ONLY: copy_dbcsr_to_fm,&
18                                              copy_fm_to_dbcsr
19   USE cp_fm_basic_linalg,              ONLY: cp_fm_upper_to_full
20   USE cp_fm_struct,                    ONLY: cp_fm_struct_create,&
21                                              cp_fm_struct_release,&
22                                              cp_fm_struct_type
23   USE cp_fm_types,                     ONLY: cp_fm_create,&
24                                              cp_fm_release,&
25                                              cp_fm_type
26   USE cp_para_types,                   ONLY: cp_para_env_type
27   USE dbcsr_api,                       ONLY: dbcsr_get_info,&
28                                              dbcsr_type
29   USE kinds,                           ONLY: dp,&
30                                              sp
31#include "base/base_uses.f90"
32
33   IMPLICIT NONE
34
35   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'cp_dbcsr_cholesky'
36
37   PUBLIC :: cp_dbcsr_cholesky_decompose, cp_dbcsr_cholesky_invert, &
38             cp_dbcsr_cholesky_restore
39
40   PRIVATE
41
42CONTAINS
43
44! **************************************************************************************************
45!> \brief used to replace a symmetric positive def. matrix M with its cholesky
46!>      decomposition U: M = U^T * U, with U upper triangular
47!> \param matrix the matrix to replace with its cholesky decomposition
48!> \param n the number of row (and columns) of the matrix &
49!>        (defaults to the min(size(matrix)))
50!> \param para_env ...
51!> \param blacs_env ...
52!> \par History
53!>      05.2002 created [JVdV]
54!>      12.2002 updated, added n optional parm [fawzi]
55!> \author Joost
56! **************************************************************************************************
57   SUBROUTINE cp_dbcsr_cholesky_decompose(matrix, n, para_env, blacs_env)
58      TYPE(dbcsr_type)                      :: matrix
59      INTEGER, INTENT(in), OPTIONAL            :: n
60      TYPE(cp_para_env_type), POINTER          :: para_env
61      TYPE(cp_blacs_env_type), POINTER         :: blacs_env
62
63      CHARACTER(len=*), PARAMETER :: routineN = 'cp_dbcsr_cholesky_decompose', &
64                                     routineP = moduleN//':'//routineN
65
66      INTEGER                                  :: handle, info, my_n, &
67                                                  nfullcols_total, &
68                                                  nfullrows_total
69      REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
70      REAL(KIND=sp), DIMENSION(:, :), POINTER  :: a_sp
71      TYPE(cp_fm_struct_type), POINTER         :: fm_struct
72      TYPE(cp_fm_type), POINTER                :: fm_matrix
73#if defined(__SCALAPACK)
74      INTEGER, DIMENSION(9)                    :: desca
75#endif
76
77      CALL timeset(routineN, handle)
78
79      NULLIFY (fm_matrix, fm_struct)
80      CALL dbcsr_get_info(matrix, nfullrows_total=nfullrows_total, nfullcols_total=nfullcols_total)
81
82      CALL cp_fm_struct_create(fm_struct, context=blacs_env, nrow_global=nfullrows_total, &
83                               ncol_global=nfullcols_total, para_env=para_env)
84      CALL cp_fm_create(fm_matrix, fm_struct, name="fm_matrix")
85      CALL cp_fm_struct_release(fm_struct)
86
87      CALL copy_dbcsr_to_fm(matrix, fm_matrix)
88
89      my_n = MIN(fm_matrix%matrix_struct%nrow_global, &
90                 fm_matrix%matrix_struct%ncol_global)
91      IF (PRESENT(n)) THEN
92         CPASSERT(n <= my_n)
93         my_n = n
94      END IF
95
96      a => fm_matrix%local_data
97      a_sp => fm_matrix%local_data_sp
98
99#if defined(__SCALAPACK)
100      desca(:) = fm_matrix%matrix_struct%descriptor(:)
101
102      IF (fm_matrix%use_sp) THEN
103         CALL pspotrf('U', my_n, a_sp(1, 1), 1, 1, desca, info)
104      ELSE
105         CALL pdpotrf('U', my_n, a(1, 1), 1, 1, desca, info)
106      ENDIF
107
108#else
109
110      IF (fm_matrix%use_sp) THEN
111         CALL spotrf('U', my_n, a_sp(1, 1), SIZE(a_sp, 1), info)
112      ELSE
113         CALL dpotrf('U', my_n, a(1, 1), SIZE(a, 1), info)
114      ENDIF
115
116#endif
117
118      IF (info /= 0) &
119         CPABORT("Cholesky decomposition failed. Matrix ill conditioned ?")
120
121      CALL copy_fm_to_dbcsr(fm_matrix, matrix)
122
123      CALL cp_fm_release(fm_matrix)
124
125      CALL timestop(handle)
126
127   END SUBROUTINE cp_dbcsr_cholesky_decompose
128
129! **************************************************************************************************
130!> \brief used to replace the cholesky decomposition by the inverse
131!> \param matrix the matrix to invert (must be an upper triangular matrix)
132!> \param n size of the matrix to invert (defaults to the min(size(matrix)))
133!> \param para_env ...
134!> \param blacs_env ...
135!> \param upper_to_full ...
136!> \par History
137!>      05.2002 created [JVdV]
138!> \author Joost VandeVondele
139! **************************************************************************************************
140   SUBROUTINE cp_dbcsr_cholesky_invert(matrix, n, para_env, blacs_env, upper_to_full)
141      TYPE(dbcsr_type)                           :: matrix
142      INTEGER, INTENT(in), OPTIONAL             :: n
143      TYPE(cp_para_env_type), POINTER           :: para_env
144      TYPE(cp_blacs_env_type), POINTER          :: blacs_env
145      LOGICAL, INTENT(IN)                       :: upper_to_full
146
147      CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_cholesky_invert', &
148                                     routineP = moduleN//':'//routineN
149
150      REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
151      REAL(KIND=sp), DIMENSION(:, :), POINTER  :: a_sp
152      INTEGER                                   :: info, handle
153      INTEGER                                   :: my_n, nfullrows_total, nfullcols_total
154      TYPE(cp_fm_type), POINTER                 :: fm_matrix, fm_matrix_tmp
155      TYPE(cp_fm_struct_type), POINTER          :: fm_struct
156#if defined(__SCALAPACK)
157      INTEGER, DIMENSION(9)                     :: desca
158#endif
159
160      CALL timeset(routineN, handle)
161
162      NULLIFY (fm_matrix, fm_struct)
163      CALL dbcsr_get_info(matrix, nfullrows_total=nfullrows_total, nfullcols_total=nfullcols_total)
164
165      CALL cp_fm_struct_create(fm_struct, context=blacs_env, nrow_global=nfullrows_total, &
166                               ncol_global=nfullrows_total, para_env=para_env)
167      CALL cp_fm_create(fm_matrix, fm_struct, name="fm_matrix")
168      CALL cp_fm_struct_release(fm_struct)
169
170      CALL copy_dbcsr_to_fm(matrix, fm_matrix)
171
172      my_n = MIN(fm_matrix%matrix_struct%nrow_global, &
173                 fm_matrix%matrix_struct%ncol_global)
174      IF (PRESENT(n)) THEN
175         CPASSERT(n <= my_n)
176         my_n = n
177      END IF
178
179      a => fm_matrix%local_data
180      a_sp => fm_matrix%local_data_sp
181
182#if defined(__SCALAPACK)
183
184      desca(:) = fm_matrix%matrix_struct%descriptor(:)
185
186      IF (fm_matrix%use_sp) THEN
187         CALL pspotri('U', my_n, a_sp(1, 1), 1, 1, desca, info)
188      ELSE
189         CALL pdpotri('U', my_n, a(1, 1), 1, 1, desca, info)
190      ENDIF
191
192#else
193
194      IF (fm_matrix%use_sp) THEN
195         CALL spotri('U', my_n, a_sp(1, 1), SIZE(a_sp, 1), info)
196      ELSE
197         CALL dpotri('U', my_n, a(1, 1), SIZE(a, 1), info)
198      ENDIF
199
200#endif
201
202      CPASSERT(info == 0)
203
204      IF (upper_to_full) THEN
205         CALL cp_fm_create(fm_matrix_tmp, fm_matrix%matrix_struct, name="fm_matrix_tmp")
206         CALL cp_fm_upper_to_full(fm_matrix, fm_matrix_tmp)
207         CALL cp_fm_release(fm_matrix_tmp)
208      ENDIF
209
210      CALL copy_fm_to_dbcsr(fm_matrix, matrix)
211
212      CALL cp_fm_release(fm_matrix)
213
214      CALL timestop(handle)
215
216   END SUBROUTINE cp_dbcsr_cholesky_invert
217
218! **************************************************************************************************
219!> \brief ...
220!> \param matrix ...
221!> \param neig ...
222!> \param matrixb ...
223!> \param matrixout ...
224!> \param op ...
225!> \param pos ...
226!> \param transa ...
227!> \param para_env ...
228!> \param blacs_env ...
229! **************************************************************************************************
230   SUBROUTINE cp_dbcsr_cholesky_restore(matrix, neig, matrixb, matrixout, op, pos, transa, &
231                                        para_env, blacs_env)
232      TYPE(dbcsr_type)                                :: matrix, matrixb, matrixout
233      INTEGER, INTENT(IN)                            :: neig
234      CHARACTER(LEN=*), INTENT(IN)           :: op
235      CHARACTER(LEN=*), INTENT(IN), OPTIONAL :: pos
236      CHARACTER(LEN=*), INTENT(IN), OPTIONAL :: transa
237      TYPE(cp_para_env_type), POINTER                :: para_env
238      TYPE(cp_blacs_env_type), POINTER               :: blacs_env
239
240      CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_cholesky_restore', &
241                                     routineP = moduleN//':'//routineN
242
243      REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a, b, out
244      REAL(KIND=sp), DIMENSION(:, :), POINTER  :: a_sp, b_sp, out_sp
245      INTEGER                                   :: itype, handle
246      INTEGER                                   :: n
247      REAL(KIND=dp)                           :: alpha
248      INTEGER                                   :: myprow, mypcol, nfullrows_total, &
249                                                   nfullcols_total
250      TYPE(cp_blacs_env_type), POINTER          :: context
251      CHARACTER                                 :: chol_pos, chol_transa
252      TYPE(cp_fm_type), POINTER                 :: fm_matrix, fm_matrixb, fm_matrixout
253      TYPE(cp_fm_struct_type), POINTER          :: fm_struct
254#if defined(__SCALAPACK)
255      INTEGER                                   :: i
256      INTEGER, DIMENSION(9)                     :: desca, descb, descout
257#endif
258
259      CALL timeset(routineN, handle)
260
261      NULLIFY (fm_matrix, fm_matrixb, fm_matrixout, fm_struct)
262
263      CALL dbcsr_get_info(matrix, nfullrows_total=nfullrows_total, nfullcols_total=nfullcols_total)
264      CALL cp_fm_struct_create(fm_struct, context=blacs_env, nrow_global=nfullrows_total, &
265                               ncol_global=nfullcols_total, para_env=para_env)
266      CALL cp_fm_create(fm_matrix, fm_struct, name="fm_matrix")
267      CALL cp_fm_struct_release(fm_struct)
268
269      CALL dbcsr_get_info(matrixb, nfullrows_total=nfullrows_total, nfullcols_total=nfullcols_total)
270      CALL cp_fm_struct_create(fm_struct, context=blacs_env, nrow_global=nfullrows_total, &
271                               ncol_global=nfullcols_total, para_env=para_env)
272      CALL cp_fm_create(fm_matrixb, fm_struct, name="fm_matrixb")
273      CALL cp_fm_struct_release(fm_struct)
274
275      CALL dbcsr_get_info(matrixout, nfullrows_total=nfullrows_total, nfullcols_total=nfullcols_total)
276      CALL cp_fm_struct_create(fm_struct, context=blacs_env, nrow_global=nfullrows_total, &
277                               ncol_global=nfullcols_total, para_env=para_env)
278      CALL cp_fm_create(fm_matrixout, fm_struct, name="fm_matrixout")
279      CALL cp_fm_struct_release(fm_struct)
280
281      CALL copy_dbcsr_to_fm(matrix, fm_matrix)
282      CALL copy_dbcsr_to_fm(matrixb, fm_matrixb)
283      !CALL copy_dbcsr_to_fm(matrixout, fm_matrixout)
284
285      context => fm_matrix%matrix_struct%context
286      myprow = context%mepos(1)
287      mypcol = context%mepos(2)
288      n = fm_matrix%matrix_struct%nrow_global
289      itype = 1
290      IF (op /= "SOLVE" .AND. op /= "MULTIPLY") &
291         CPABORT("wrong argument op")
292
293      IF (PRESENT(pos)) THEN
294         SELECT CASE (pos)
295         CASE ("LEFT")
296            chol_pos = 'L'
297         CASE ("RIGHT")
298            chol_pos = 'R'
299         CASE DEFAULT
300            CPABORT("wrong argument pos")
301         END SELECT
302      ELSE
303         chol_pos = 'L'
304      ENDIF
305
306      chol_transa = 'N'
307      IF (PRESENT(transa)) chol_transa = transa
308
309      IF ((fm_matrix%use_sp .NEQV. fm_matrixb%use_sp) .OR. (fm_matrix%use_sp .NEQV. fm_matrixout%use_sp)) &
310         CPABORT("not the same precision")
311
312      ! notice b is the cholesky guy
313      a => fm_matrix%local_data
314      b => fm_matrixb%local_data
315      out => fm_matrixout%local_data
316      a_sp => fm_matrix%local_data_sp
317      b_sp => fm_matrixb%local_data_sp
318      out_sp => fm_matrixout%local_data_sp
319
320#if defined(__SCALAPACK)
321
322      desca(:) = fm_matrix%matrix_struct%descriptor(:)
323      descb(:) = fm_matrixb%matrix_struct%descriptor(:)
324      descout(:) = fm_matrixout%matrix_struct%descriptor(:)
325      alpha = 1.0_dp
326      DO i = 1, neig
327         IF (fm_matrix%use_sp) THEN
328            CALL pscopy(n, a_sp(1, 1), 1, i, desca, 1, out_sp(1, 1), 1, i, descout, 1)
329         ELSE
330            CALL pdcopy(n, a(1, 1), 1, i, desca, 1, out(1, 1), 1, i, descout, 1)
331         ENDIF
332      ENDDO
333      IF (op .EQ. "SOLVE") THEN
334         IF (fm_matrix%use_sp) THEN
335            CALL pstrsm(chol_pos, 'U', chol_transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), 1, 1, descb, &
336                        out_sp(1, 1), 1, 1, descout)
337         ELSE
338            CALL pdtrsm(chol_pos, 'U', chol_transa, 'N', n, neig, alpha, b(1, 1), 1, 1, descb, out(1, 1), 1, 1, descout)
339         ENDIF
340      ELSE
341         IF (fm_matrix%use_sp) THEN
342            CALL pstrmm(chol_pos, 'U', chol_transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), 1, 1, descb, &
343                        out_sp(1, 1), 1, 1, descout)
344         ELSE
345            CALL pdtrmm(chol_pos, 'U', chol_transa, 'N', n, neig, alpha, b(1, 1), 1, 1, descb, out(1, 1), 1, 1, descout)
346         ENDIF
347      ENDIF
348#else
349
350      alpha = 1.0_dp
351      IF (fm_matrix%use_sp) THEN
352         CALL scopy(neig*n, a_sp(1, 1), 1, out_sp(1, 1), 1)
353      ELSE
354         CALL dcopy(neig*n, a(1, 1), 1, out(1, 1), 1)
355      ENDIF
356      IF (op .EQ. "SOLVE") THEN
357         IF (fm_matrix%use_sp) THEN
358            CALL strsm(chol_pos, 'U', chol_transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), SIZE(b_sp, 1), out_sp(1, 1), n)
359         ELSE
360            CALL dtrsm(chol_pos, 'U', chol_transa, 'N', n, neig, alpha, b(1, 1), SIZE(b, 1), out(1, 1), n)
361         ENDIF
362      ELSE
363         IF (fm_matrix%use_sp) THEN
364            CALL strmm(chol_pos, 'U', chol_transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), n, out_sp(1, 1), n)
365         ELSE
366            CALL dtrmm(chol_pos, 'U', chol_transa, 'N', n, neig, alpha, b(1, 1), n, out(1, 1), n)
367         ENDIF
368      ENDIF
369
370#endif
371
372      CALL copy_fm_to_dbcsr(fm_matrixout, matrixout)
373
374      CALL cp_fm_release(fm_matrix)
375      CALL cp_fm_release(fm_matrixb)
376      CALL cp_fm_release(fm_matrixout)
377
378      CALL timestop(handle)
379
380   END SUBROUTINE cp_dbcsr_cholesky_restore
381
382END MODULE cp_dbcsr_cholesky
383
384