1!--------------------------------------------------------------------------------------------------!
2!   CP2K: A general program to perform molecular dynamics simulations                              !
3!   Copyright (C) 2000 - 2020  CP2K developers group                                               !
4!--------------------------------------------------------------------------------------------------!
5
6! **************************************************************************************************
7!> \brief Main module for the PAO method
8!> \author Ole Schuett
9! **************************************************************************************************
10MODULE pao_main
11   USE bibliography,                    ONLY: Schuett2018,&
12                                              cite_reference
13   USE cp_external_control,             ONLY: external_control
14   USE dbcsr_api,                       ONLY: dbcsr_add,&
15                                              dbcsr_copy,&
16                                              dbcsr_create,&
17                                              dbcsr_p_type,&
18                                              dbcsr_release,&
19                                              dbcsr_reserve_diag_blocks,&
20                                              dbcsr_set,&
21                                              dbcsr_type
22   USE dm_ls_scf_types,                 ONLY: ls_mstruct_type,&
23                                              ls_scf_env_type
24   USE input_section_types,             ONLY: section_vals_get_subs_vals,&
25                                              section_vals_type
26   USE kinds,                           ONLY: dp
27   USE linesearch,                      ONLY: linesearch_finalize,&
28                                              linesearch_init,&
29                                              linesearch_reset,&
30                                              linesearch_step
31   USE machine,                         ONLY: m_walltime
32   USE pao_input,                       ONLY: parse_pao_section
33   USE pao_io,                          ONLY: pao_read_restart,&
34                                              pao_write_ks_matrix_csr,&
35                                              pao_write_restart,&
36                                              pao_write_s_matrix_csr
37   USE pao_methods,                     ONLY: &
38        pao_add_forces, pao_build_core_hamiltonian, pao_build_diag_distribution, &
39        pao_build_matrix_X, pao_build_orthogonalizer, pao_build_selector, pao_calc_energy, &
40        pao_calc_outer_grad_lnv, pao_check_grad, pao_check_trace_ps, pao_guess_initial_P, &
41        pao_init_kinds, pao_print_atom_info, pao_store_P, pao_test_convergence
42   USE pao_ml,                          ONLY: pao_ml_init,&
43                                              pao_ml_predict
44   USE pao_optimizer,                   ONLY: pao_opt_finalize,&
45                                              pao_opt_init,&
46                                              pao_opt_new_dir
47   USE pao_param,                       ONLY: pao_calc_U,&
48                                              pao_param_finalize,&
49                                              pao_param_init,&
50                                              pao_param_initial_guess,&
51                                              pao_update_AB
52   USE pao_types,                       ONLY: pao_env_type
53   USE qs_environment_types,            ONLY: get_qs_env,&
54                                              qs_environment_type
55#include "./base/base_uses.f90"
56
   ! Require explicit declarations throughout the module.
   IMPLICIT NONE

   ! Everything is module-private unless listed in the PUBLIC statement below.
   PRIVATE

   ! Name of this module as a string constant.
   ! NOTE(review): not referenced in the visible part of this module; it may be
   ! consumed by macros pulled in via base_uses.f90 - confirm before removing.
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pao_main'

   ! Entry points exported by this module.
   PUBLIC :: pao_init, pao_update, pao_post_scf, pao_optimization_start, pao_optimization_end
64
65CONTAINS
66
67! **************************************************************************************************
68!> \brief Initialize the PAO environment
69!> \param qs_env ...
70!> \param ls_scf_env ...
71! **************************************************************************************************
72   SUBROUTINE pao_init(qs_env, ls_scf_env)
73      TYPE(qs_environment_type), POINTER                 :: qs_env
74      TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
75
76      CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_init'
77
78      INTEGER                                            :: handle
79      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
80      TYPE(pao_env_type), POINTER                        :: pao
81      TYPE(section_vals_type), POINTER                   :: input
82
83      IF (.NOT. ls_scf_env%do_pao) RETURN
84
85      CALL timeset(routineN, handle)
86      CALL cite_reference(Schuett2018)
87      pao => ls_scf_env%pao_env
88      CALL get_qs_env(qs_env=qs_env, input=input, matrix_s=matrix_s)
89
90      ! parse input
91      CALL parse_pao_section(pao, input)
92
93      CALL pao_init_kinds(pao, qs_env)
94
95      ! train machine learning
96      CALL pao_ml_init(pao, qs_env)
97
98      CALL timestop(handle)
99   END SUBROUTINE pao_init
100
101! **************************************************************************************************
102!> \brief Start a PAO optimization run.
103!> \param qs_env ...
104!> \param ls_scf_env ...
105! **************************************************************************************************
106   SUBROUTINE pao_optimization_start(qs_env, ls_scf_env)
107      TYPE(qs_environment_type), POINTER                 :: qs_env
108      TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
109
110      CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_start'
111
112      INTEGER                                            :: handle
113      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
114      TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
115      TYPE(pao_env_type), POINTER                        :: pao
116      TYPE(section_vals_type), POINTER                   :: input, section
117
118      IF (.NOT. ls_scf_env%do_pao) RETURN
119
120      CALL timeset(routineN, handle)
121      CALL get_qs_env(qs_env=qs_env, &
122                      matrix_s=matrix_s, &
123                      input=input)
124
125      pao => ls_scf_env%pao_env
126      ls_mstruct => ls_scf_env%ls_mstruct
127
128      ! reset state
129      pao%step_start_time = m_walltime()
130      pao%istep = 0
131      pao%matrix_P_ready = .FALSE.
132
133      ! ready stuff that does not depend on atom positions
134      IF (.NOT. pao%constants_ready) THEN
135         CALL pao_build_diag_distribution(pao, qs_env)
136         CALL pao_build_orthogonalizer(pao, qs_env)
137         CALL pao_build_selector(pao, qs_env)
138         CALL pao_build_core_hamiltonian(pao, qs_env)
139         pao%constants_ready = .TRUE.
140      ENDIF
141
142      CALL pao_param_init(pao, qs_env)
143
144      ! ready PAO parameter matrix_X
145      IF (.NOT. pao%matrix_X_ready) THEN
146         CALL pao_build_matrix_X(pao, qs_env)
147         CALL pao_print_atom_info(pao)
148         IF (LEN_TRIM(pao%restart_file) > 0) THEN
149            CALL pao_read_restart(pao, qs_env)
150         ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
151            CALL pao_ml_predict(pao, qs_env)
152         ELSE
153            CALL pao_param_initial_guess(pao, qs_env)
154         ENDIF
155         pao%matrix_X_ready = .TRUE.
156      ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
157         CALL pao_ml_predict(pao, qs_env)
158      ELSE
159         IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| reusing matrix_X from previous optimization"
160      ENDIF
161
162      ! init line-search
163      section => section_vals_get_subs_vals(input, "DFT%LS_SCF%PAO%LINE_SEARCH")
164      CALL linesearch_init(pao%linesearch, section, "PAO|")
165
166      ! create some more matrices
167      CALL dbcsr_copy(pao%matrix_G, pao%matrix_X)
168      CALL dbcsr_set(pao%matrix_G, 0.0_dp)
169
170      CALL dbcsr_create(pao%matrix_U, &
171                        name="PAO matrix_U", &
172                        matrix_type="N", &
173                        dist=pao%diag_distribution, &
174                        template=matrix_s(1)%matrix)
175      CALL dbcsr_reserve_diag_blocks(pao%matrix_U)
176
177      CALL dbcsr_create(ls_mstruct%matrix_A, template=pao%matrix_Y)
178      CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_A)
179      CALL dbcsr_create(ls_mstruct%matrix_B, template=pao%matrix_Y)
180      CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_B)
181
182      ! fill PAO transformation matrices
183      CALL pao_update_AB(pao, qs_env, ls_mstruct)
184
185      CALL timestop(handle)
186   END SUBROUTINE pao_optimization_start
187
188! **************************************************************************************************
189!> \brief Called after the SCF optimization, updates the PAO basis.
190!> \param qs_env ...
191!> \param ls_scf_env ...
192!> \param pao_is_done ...
193! **************************************************************************************************
194   SUBROUTINE pao_update(qs_env, ls_scf_env, pao_is_done)
195      TYPE(qs_environment_type), POINTER                 :: qs_env
196      TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
197      LOGICAL, INTENT(OUT)                               :: pao_is_done
198
199      CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_update'
200
201      INTEGER                                            :: handle, icycle
202      LOGICAL                                            :: cycle_converged, do_mixing, should_stop
203      REAL(KIND=dp)                                      :: energy, penalty
204      TYPE(dbcsr_type)                                   :: matrix_M, matrix_X_mixing
205      TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
206      TYPE(pao_env_type), POINTER                        :: pao
207
208      IF (.NOT. ls_scf_env%do_pao) THEN
209         pao_is_done = .TRUE.
210         RETURN
211      ENDIF
212
213      ls_mstruct => ls_scf_env%ls_mstruct
214      pao => ls_scf_env%pao_env
215
216      IF (.NOT. pao%matrix_P_ready) THEN
217         CALL pao_guess_initial_P(pao, qs_env, ls_scf_env)
218         pao%matrix_P_ready = .TRUE.
219      ENDIF
220
221      IF (pao%max_pao == 0) THEN
222         pao_is_done = .TRUE.
223         RETURN
224      ENDIF
225
226      IF (pao%need_initial_scf) THEN
227         pao_is_done = .FALSE.
228         pao%need_initial_scf = .FALSE.
229         IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Performing initial SCF optimization."
230         RETURN
231      ENDIF
232
233      CALL timeset(routineN, handle)
234
235      ! perform mixing once we are well into the optimization
236      do_mixing = pao%mixing /= 1.0_dp .AND. pao%istep > 1
237      IF (do_mixing) THEN
238         CALL dbcsr_copy(matrix_X_mixing, pao%matrix_X)
239      ENDIF
240
241      cycle_converged = .FALSE.
242      icycle = 0
243      CALL linesearch_reset(pao%linesearch)
244      CALL pao_opt_init(pao)
245
246      DO WHILE (.TRUE.)
247         pao%istep = pao%istep + 1
248
249         IF (pao%iw > 0) WRITE (pao%iw, "(A,I9,A)") " PAO| ======================= Iteration: ", &
250            pao%istep, " ============================="
251
252         ! calc energy and check trace_PS
253         CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
254         CALL pao_check_trace_PS(ls_scf_env)
255
256         IF (pao%linesearch%starts) THEN
257            icycle = icycle + 1
258            ! calc new gradient including penalty terms
259            CALL pao_calc_outer_grad_lnv(qs_env, ls_scf_env, matrix_M)
260            CALL pao_calc_U(pao, qs_env, matrix_M=matrix_M, matrix_G=pao%matrix_G, penalty=penalty)
261            CALL dbcsr_release(matrix_M)
262            CALL pao_check_grad(pao, qs_env, ls_scf_env)
263
264            ! calculate new direction for line-search
265            CALL pao_opt_new_dir(pao, icycle)
266
267            !backup X
268            CALL dbcsr_copy(pao%matrix_X_orig, pao%matrix_X)
269
270            ! print info and convergence test
271            CALL pao_test_convergence(pao, ls_scf_env, energy, cycle_converged)
272            IF (cycle_converged) THEN
273               pao_is_done = icycle < 3
274               IF (pao_is_done .AND. pao%iw > 0) WRITE (pao%iw, *) "PAO| converged after ", pao%istep, " steps :-)"
275               EXIT
276            ENDIF
277
278            ! if we have reached the maximum number of cycles exit in order
279            ! to restart with a fresh hamiltonian
280            IF (icycle >= pao%max_cycles) THEN
281               IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| CG not yet converged after ", icycle, " cylces."
282               pao_is_done = .FALSE.
283               EXIT
284            ENDIF
285
286            IF (MOD(icycle, pao%write_cycles) == 0) &
287               CALL pao_write_restart(pao, qs_env, energy) ! write an intermediate restart file
288         ENDIF
289
290         ! check for early abort without convergence?
291         CALL external_control(should_stop, "PAO", start_time=qs_env%start_time, target_time=qs_env%target_time)
292         IF (should_stop .OR. pao%istep >= pao%max_pao) THEN
293            CPWARN("PAO not converged!")
294            pao_is_done = .TRUE.
295            EXIT
296         ENDIF
297
298         ! perform line-search step
299         CALL linesearch_step(pao%linesearch, energy=energy, slope=pao%norm_G**2)
300
301         IF (pao%linesearch%step_size < 1e-10_dp) CPABORT("PAO gradient is wrong.")
302
303         CALL dbcsr_copy(pao%matrix_X, pao%matrix_X_orig) !restore X
304         CALL dbcsr_add(pao%matrix_X, pao%matrix_D, 1.0_dp, pao%linesearch%step_size)
305      ENDDO
306
307      ! perform mixing of matrix_X
308      IF (do_mixing) THEN
309         CALL dbcsr_add(pao%matrix_X, matrix_X_mixing, pao%mixing, 1.0_dp - pao%mixing)
310         CALL dbcsr_release(matrix_X_mixing)
311         IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Recalculating energy after mixing."
312         CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
313      ENDIF
314
315      CALL pao_write_restart(pao, qs_env, energy)
316      CALL pao_opt_finalize(pao)
317
318      CALL timestop(handle)
319   END SUBROUTINE pao_update
320
321! **************************************************************************************************
322!> \brief Calculate PAO forces and store density matrix for future ASPC extrapolations
323!> \param qs_env ...
324!> \param ls_scf_env ...
325!> \param pao_is_done ...
326! **************************************************************************************************
327   SUBROUTINE pao_post_scf(qs_env, ls_scf_env, pao_is_done)
328      TYPE(qs_environment_type), POINTER                 :: qs_env
329      TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
330      LOGICAL, INTENT(IN)                                :: pao_is_done
331
332      CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_post_scf'
333
334      INTEGER                                            :: handle
335
336      IF (.NOT. ls_scf_env%do_pao) RETURN
337      IF (.NOT. pao_is_done) RETURN
338
339      CALL timeset(routineN, handle)
340
341      ! print out the matrices here before pao_store_P converts them back into matrices in
342      ! terms of the primary basis
343      CALL pao_write_ks_matrix_csr(qs_env, ls_scf_env)
344      CALL pao_write_s_matrix_csr(qs_env, ls_scf_env)
345
346      CALL pao_store_P(qs_env, ls_scf_env)
347      IF (ls_scf_env%calculate_forces) CALL pao_add_forces(qs_env, ls_scf_env)
348
349      CALL timestop(handle)
350   END SUBROUTINE
351
352! **************************************************************************************************
353!> \brief Finish a PAO optimization run.
354!> \param ls_scf_env ...
355! **************************************************************************************************
356   SUBROUTINE pao_optimization_end(ls_scf_env)
357      TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
358
359      CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_end'
360
361      INTEGER                                            :: handle
362      TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
363      TYPE(pao_env_type), POINTER                        :: pao
364
365      IF (.NOT. ls_scf_env%do_pao) RETURN
366
367      pao => ls_scf_env%pao_env
368      ls_mstruct => ls_scf_env%ls_mstruct
369
370      CALL timeset(routineN, handle)
371
372      CALL pao_param_finalize(pao)
373
374      ! We keep pao%matrix_X for next scf-run, e.g. during MD or GEO-OPT
375      CALL dbcsr_release(pao%matrix_X_orig)
376      CALL dbcsr_release(pao%matrix_G)
377      CALL dbcsr_release(pao%matrix_U)
378      CALL dbcsr_release(ls_mstruct%matrix_A)
379      CALL dbcsr_release(ls_mstruct%matrix_B)
380
381      CALL linesearch_finalize(pao%linesearch)
382
383      CALL timestop(handle)
384   END SUBROUTINE pao_optimization_end
385
386END MODULE pao_main
387