1#!/usr/bin/env python
2# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Qiming Sun <osirpt.sun@gmail.com>
17#
18
19'''
Attach a solvent model (e.g. ddCOSMO) to SCF, TDSCF, MCSCF (CASCI and CASSCF), and post-SCF methods.
21'''
22
23import copy
24import numpy
25from pyscf import lib
26from pyscf.lib import logger
27from functools import reduce
28from pyscf import scf
29
def _for_scf(mf, solvent_obj, dm=None):
    '''Add solvent model to SCF (HF and DFT) method.

    If ``mf`` is already decorated with a solvent model, its
    ``with_solvent`` attribute is replaced by ``solvent_obj`` and ``mf`` is
    returned. Otherwise an instance of a dynamically created subclass of
    ``mf.__class__`` is returned which holds the same attributes as ``mf``.

    Args:
        mf : SCF (HF or DFT) object to decorate.
        solvent_obj : solvent model object. ``solvent_obj.kernel(dm)`` must
            return the pair ``(e_solvent, v_solvent)``.

    Kwargs:
        dm : if given, solvent does not respond to the change of density
            matrix. A frozen ddCOSMO potential is added to the results.
    '''
    if isinstance(mf, _Solvation):
        mf.with_solvent = solvent_obj
        return mf

    oldMF = mf.__class__

    if dm is not None:
        # Evaluate the solvent once for the given density and keep it fixed.
        solvent_obj.e, solvent_obj.v = solvent_obj.kernel(dm)
        solvent_obj.frozen = True

    class SCFWithSolvent(_Solvation, oldMF):
        def __init__(self, mf, solvent):
            # Take over the state of the original object without re-running
            # its constructor.
            self.__dict__.update(mf.__dict__)
            self.with_solvent = solvent
            # Rebind instead of updating in place: after the __dict__ copy the
            # set may be shared with the original object ``mf``.
            self._keys = self._keys.union(['with_solvent'])

        def dump_flags(self, verbose=None):
            oldMF.dump_flags(self, verbose)
            self.with_solvent.check_sanity()
            self.with_solvent.dump_flags(verbose)
            return self

        def reset(self, mol=None):
            self.with_solvent.reset(mol)
            return oldMF.reset(self, mol)

        # Note v_solvent should not be added to get_hcore for scf methods.
        # get_hcore is overloaded by many post-HF methods. Modifying
        # SCF.get_hcore may lead error.

        def get_veff(self, mol=None, dm=None, *args, **kwargs):
            '''Return the SCF effective potential tagged with the solvent
            energy (``e_solvent``) and potential (``v_solvent``).
            '''
            # The solvent kernel needs an explicit density matrix; fall back
            # to the density of the current orbitals when none is given.
            if dm is None:
                dm = self.make_rdm1()
            vhf = oldMF.get_veff(self, mol, dm, *args, **kwargs)
            with_solvent = self.with_solvent
            if not with_solvent.frozen:
                with_solvent.e, with_solvent.v = with_solvent.kernel(dm)
            e_solvent, v_solvent = with_solvent.e, with_solvent.v

            # NOTE: v_solvent should not be added to vhf in this place. This is
            # because vhf is used as the reference for direct_scf in the next
            # iteration. If v_solvent is added here, it may break direct SCF.
            return lib.tag_array(vhf, e_solvent=e_solvent, v_solvent=v_solvent)

        def get_fock(self, h1e=None, s1e=None, vhf=None, dm=None, cycle=-1,
                     diis=None, diis_start_cycle=None,
                     level_shift_factor=None, damp_factor=None):
            # DIIS was called inside oldMF.get_fock. v_solvent, as a function of
            # dm, should be extrapolated as well. To enable it, v_solvent has to be
            # added to the fock matrix before DIIS was called.
            if getattr(vhf, 'v_solvent', None) is None:
                vhf = self.get_veff(self.mol, dm)
            return oldMF.get_fock(self, h1e, s1e, vhf+vhf.v_solvent, dm, cycle, diis,
                                  diis_start_cycle, level_shift_factor, damp_factor)

        def energy_elec(self, dm=None, h1e=None, vhf=None):
            '''Electronic energy including the solvent energy ``e_solvent``.'''
            if dm is None:
                dm = self.make_rdm1()
            if getattr(vhf, 'e_solvent', None) is None:
                vhf = self.get_veff(self.mol, dm)
            e_tot, e_coul = oldMF.energy_elec(self, dm, h1e, vhf)
            e_tot += vhf.e_solvent
            self.scf_summary['e_solvent'] = vhf.e_solvent.real
            logger.debug(self, 'Solvent Energy = %.15g', vhf.e_solvent)
            return e_tot, e_coul

        def nuc_grad_method(self):
            grad_method = oldMF.nuc_grad_method(self)
            return self.with_solvent.nuc_grad_method(grad_method)

        Gradients = nuc_grad_method

        def gen_response(self, *args, **kwargs):
            '''Response function with the solvent contribution added when
            ``with_solvent.equilibrium_solvation`` is set.
            '''
            vind = oldMF.gen_response(self, *args, **kwargs)
            is_uhf = isinstance(self, scf.uhf.UHF)
            # singlet=None is orbital hessian or CPHF type response function
            singlet = kwargs.get('singlet', True)
            singlet = singlet or singlet is None
            def vind_with_solvent(dm1):
                v = vind(dm1)
                if self.with_solvent.equilibrium_solvation:
                    if is_uhf:
                        # The solvent responds to the total (alpha+beta)
                        # density; the result is added to both spin channels.
                        v_solvent = self.with_solvent._B_dot_x(dm1)
                        v += v_solvent[0] + v_solvent[1]
                    elif singlet:
                        v += self.with_solvent._B_dot_x(dm1)
                return v
            return vind_with_solvent

        def stability(self, *args, **kwargs):
            # When computing orbital hessian, the second order derivatives of
            # solvent energy needs to be computed. It is enabled by
            # the attribute equilibrium_solvation in gen_response method.
            # If solvent was frozen, its contribution is treated as the
            # external potential. The response of solvent does not need to
            # be considered in stability analysis.
            with lib.temporary_env(self.with_solvent,
                                   equilibrium_solvation=not self.with_solvent.frozen):
                return oldMF.stability(self, *args, **kwargs)

    mf1 = SCFWithSolvent(mf, solvent_obj)
    return mf1
137
def _for_casscf(mc, solvent_obj, dm=None):
    '''Add solvent model to CASSCF method.

    If ``mc`` is already decorated with a solvent model, its
    ``with_solvent`` attribute is replaced by ``solvent_obj`` and ``mc`` is
    returned. Otherwise an instance of a dynamically created subclass of
    ``mc.__class__`` is returned which holds the same attributes as ``mc``.

    The solvent potential is folded into the one-electron Hamiltonian
    (``get_hcore``). The resulting double-counted solvent energy is removed
    from the total energy in ``casci``.

    Kwargs:
        dm : if given, solvent does not respond to the change of density
            matrix. A frozen ddCOSMO potential is added to the results.
    '''
    if isinstance(mc, _Solvation):
        mc.with_solvent = solvent_obj
        return mc

    oldCAS = mc.__class__

    if dm is not None:
        solvent_obj.e, solvent_obj.v = solvent_obj.kernel(dm)
        solvent_obj.frozen = True

    class CASSCFWithSolvent(_Solvation, oldCAS):
        def __init__(self, mc, solvent):
            self.__dict__.update(mc.__dict__)
            self.with_solvent = solvent
            # Total energy of the previous casci call before the solvent
            # corrections; used to fix the dE printed by oldCAS.casci.
            self._e_tot_without_solvent = 0
            # Rebind instead of updating in place: after the __dict__ copy the
            # set may be shared with the original object ``mc``.
            self._keys = self._keys.union(['with_solvent'])

        def dump_flags(self, verbose=None):
            oldCAS.dump_flags(self, verbose)
            self.with_solvent.check_sanity()
            self.with_solvent.dump_flags(verbose)
            if self.conv_tol < 1e-7:
                logger.warn(self, 'CASSCF+ddCOSMO may not be able to '
                            'converge to conv_tol=%g', self.conv_tol)

            # NOTE(review): self.with_solvent is always assigned in __init__,
            # so this warning looks unreachable for this wrapper class.
            if (getattr(self._scf, 'with_solvent', None) and
                not getattr(self, 'with_solvent', None)):
                logger.warn(self, '''Solvent model %s was found in SCF object.
COSMO is not applied to the CASCI object. The CASSCF result is not affected by the SCF solvent model.
To enable the solvent model for CASSCF, a decoration to CASSCF object as below needs to be called
        from pyscf import solvent
        mc = mcscf.CASSCF(...)
        mc = solvent.ddCOSMO(mc)
''',
                            self._scf.with_solvent.__class__)
            return self

        def reset(self, mol=None):
            self.with_solvent.reset(mol)
            return oldCAS.reset(self, mol)

        def update_casdm(self, mo, u, fcivec, e_ci, eris, envs=None):
            # Avoid a shared mutable default; the dict is forwarded to the
            # parent implementation.
            if envs is None:
                envs = {}
            casdm1, casdm2, gci, fcivec = \
                    oldCAS.update_casdm(self, mo, u, fcivec, e_ci, eris, envs)

            # The potential is generated based on the density of current micro
            # iteration. It will be added to hcore in casci function. Strictly
            # speaking, this density is not the same to the CASSCF density
            # (which was used to measure convergence) in the macro iterations.
            # When CASSCF is converged, it should be almost the same to the
            # CASSCF density of the last macro iteration.
            with_solvent = self.with_solvent
            if not with_solvent.frozen:
                # Code to mimic dm = self.make_rdm1(ci=fcivec)
                mocore = mo[:,:self.ncore]
                mocas = mo[:,self.ncore:self.ncore+self.ncas]
                dm = reduce(numpy.dot, (mocas, casdm1, mocas.T))
                dm += numpy.dot(mocore, mocore.T) * 2
                with_solvent.e, with_solvent.v = with_solvent.kernel(dm)

            return casdm1, casdm2, gci, fcivec

        # ddCOSMO Potential should be added to the effective potential. However,
        # there is no hook to modify the effective potential in CASSCF. The
        # workaround here is to modify hcore. It can affect the 1-electron
        # operator in many CASSCF functions: gen_h_op, update_casdm, casci.
        # Note hcore is used to compute the energy for core density (Ecore).
        # The resultant total energy from casci function will include the
        # contribution from ddCOSMO potential. The duplicated energy
        # contribution from solvent needs to be removed.
        def get_hcore(self, mol=None):
            hcore = self._scf.get_hcore(mol)
            if self.with_solvent.v is not None:
                # Out-of-place addition: the array returned by _scf.get_hcore
                # may be cached or user-supplied and must not be modified,
                # otherwise v_solvent would accumulate on every call.
                hcore = hcore + self.with_solvent.v
            return hcore

        def casci(self, mo_coeff, ci0=None, eris=None, verbose=None, envs=None):
            '''CASCI step of CASSCF with the solvent double counting removed
            from the returned total energy.
            '''
            log = logger.new_logger(self, verbose)
            log.debug('Running CASCI with solvent. Note the total energy '
                      'has duplicated contributions from solvent.')

            # In oldCAS.casci function, dE was computed based on the total
            # energy without removing the duplicated solvent contributions.
            # However, envs['elast'] is the last total energy with correct
            # solvent effects. Hack envs['elast'] to make oldCAS.casci print
            # the correct energy difference. envs is optional.
            if envs is not None:
                envs['elast'] = self._e_tot_without_solvent
            e_tot, e_cas, fcivec = oldCAS.casci(self, mo_coeff, ci0, eris,
                                                verbose, envs)
            self._e_tot_without_solvent = e_tot

            log.debug('Computing corrections to the total energy.')
            dm = self.make_rdm1(ci=fcivec, ao_repr=True)

            with_solvent = self.with_solvent
            if with_solvent.e is not None:
                # Tr(v_solvent * dm) was counted through hcore; replace it by
                # the solvent energy.
                edup = numpy.einsum('ij,ji->', with_solvent.v, dm)
                e_tot = e_tot - edup + with_solvent.e
                log.info('Removing duplication %.15g, '
                         'adding E(solvent) = %.15g to total energy:\n'
                         '    E(CASSCF+solvent) = %.15g', edup, with_solvent.e, e_tot)

            # Update solvent effects for next iteration if needed
            if not with_solvent.frozen:
                with_solvent.e, with_solvent.v = with_solvent.kernel(dm)

            return e_tot, e_cas, fcivec

        def nuc_grad_method(self):
            logger.warn(self, '''
The code for CASSCF gradients was based on variational CASSCF wavefunction.
However, the ddCOSMO-CASSCF energy was not computed variationally.
Approximate gradients are evaluated here. A small error may be expected in the
gradients which corresponds to the contribution of
  MCSCF_DM * V_solvent[d/dX MCSCF_DM] + V_solvent[MCSCF_DM] * d/dX MCSCF_DM
''')
            grad_method = oldCAS.nuc_grad_method(self)
            return self.with_solvent.nuc_grad_method(grad_method)

        Gradients = nuc_grad_method

    return CASSCFWithSolvent(mc, solvent_obj)
265
266
def _for_casci(mc, solvent_obj, dm=None):
    '''Add solvent model to CASCI method.

    If ``mc`` is already decorated with a solvent model, its
    ``with_solvent`` attribute is replaced by ``solvent_obj`` and ``mc`` is
    returned. Otherwise an instance of a dynamically created subclass of
    ``mc.__class__`` is returned which holds the same attributes as ``mc``.

    Kwargs:
        dm : if given, solvent does not respond to the change of density
            matrix. A frozen ddCOSMO potential is added to the results.
    '''
    if isinstance(mc, _Solvation):
        mc.with_solvent = solvent_obj
        return mc

    oldCAS = mc.__class__

    if dm is not None:
        solvent_obj.e, solvent_obj.v = solvent_obj.kernel(dm)
        solvent_obj.frozen = True

    class CASCIWithSolvent(_Solvation, oldCAS):
        def __init__(self, mc, solvent):
            self.__dict__.update(mc.__dict__)
            self.with_solvent = solvent
            # Rebind instead of updating in place: after the __dict__ copy the
            # set may be shared with the original object ``mc``.
            self._keys = self._keys.union(['with_solvent'])

        def dump_flags(self, verbose=None):
            oldCAS.dump_flags(self, verbose)
            self.with_solvent.check_sanity()
            self.with_solvent.dump_flags(verbose)
            return self

        def reset(self, mol=None):
            self.with_solvent.reset(mol)
            return oldCAS.reset(self, mol)

        def get_hcore(self, mol=None):
            hcore = self._scf.get_hcore(mol)
            if self.with_solvent.v is not None:
                # NOTE: get_hcore was called by CASCI to generate core
                # potential.  v_solvent is added in this place to take accounts the
                # effects of solvent. Its contribution is duplicated and it
                # should be removed from the total energy.
                # The addition is out-of-place so that an array cached or
                # supplied by _scf.get_hcore is never modified.
                hcore = hcore + self.with_solvent.v
            return hcore

        def kernel(self, mo_coeff=None, ci0=None, verbose=None):
            '''Run CASCI and, unless the solvent is frozen, iterate the
            solvent potential until the total energy changes by less than
            ``with_solvent.conv_tol`` (at most ``with_solvent.max_cycle``
            cycles).

            Returns:
                (e_tot, e_cas, ci, mo_coeff, mo_energy), with e_tot including
                the solvent corrections.
            '''
            with_solvent = self.with_solvent

            log = logger.new_logger(self, verbose)
            log.info('\n** Self-consistently update the solvent effects for %s **',
                     oldCAS)
            log1 = copy.copy(log)
            log1.verbose -= 1  # Suppress a few output messages

            def casci_iter_(ci0, log):
                # self.e_tot, self.e_cas, and self.ci are updated in the call
                # to oldCAS.kernel
                e_tot, e_cas, ci0 = oldCAS.kernel(self, mo_coeff, ci0, log)[:3]

                if isinstance(self.e_cas, (float, numpy.number)):
                    dm = self.make_rdm1(ci=ci0)
                else:
                    log.debug('Computing solvent responses to DM of state %d',
                              with_solvent.state_id)
                    dm = self.make_rdm1(ci=ci0[with_solvent.state_id])

                if with_solvent.e is not None:
                    # Replace the Tr(v_solvent * dm) counted through hcore by
                    # the solvent energy.
                    edup = numpy.einsum('ij,ji->', with_solvent.v, dm)
                    self.e_tot += with_solvent.e - edup

                if not with_solvent.frozen:
                    with_solvent.e, with_solvent.v = with_solvent.kernel(dm)
                return self.e_tot, e_cas, ci0

            if with_solvent.frozen:
                with lib.temporary_env(self, _finalize=lambda:None):
                    casci_iter_(ci0, log)
                log.note('Total energy with solvent effects')
                self._finalize()
                return self.e_tot, self.e_cas, self.ci, self.mo_coeff, self.mo_energy

            self.converged = False
            with lib.temporary_env(self, canonicalization=False):
                e_tot = e_last = 0
                for cycle in range(with_solvent.max_cycle):
                    log.info('\n** Solvent self-consistent cycle %d:', cycle)
                    e_tot, e_cas, ci0 = casci_iter_(ci0, log1)

                    de = e_tot - e_last
                    if isinstance(e_cas, (float, numpy.number)):
                        log.info('Solvent cycle %d  E(CASCI+solvent) = %.15g  '
                                 'dE = %g', cycle, e_tot, de)
                    else:
                        for i, e in enumerate(e_tot):
                            log.info('Solvent cycle %d  CASCI root %d  '
                                     'E(CASCI+solvent) = %.15g  dE = %g',
                                     cycle, i, e, de[i])

                    # numpy.abs works for plain floats as well as arrays of
                    # state energies.
                    if numpy.abs(de).max() < with_solvent.conv_tol:
                        self.converged = True
                        break
                    e_last = e_tot

            # An extra cycle to canonicalize CASCI orbitals
            with lib.temporary_env(self, _finalize=lambda:None):
                casci_iter_(ci0, log)
            if self.converged:
                log.info('self-consistent CASCI+solvent converged')
            else:
                log.info('self-consistent CASCI+solvent not converged')
            log.note('Total energy with solvent effects')
            self._finalize()
            return self.e_tot, self.e_cas, self.ci, self.mo_coeff, self.mo_energy

        def nuc_grad_method(self):
            logger.warn(self, '''
The code for CASCI gradients was based on variational CASCI wavefunction.
However, the ddCOSMO-CASCI energy was not computed variationally.
Approximate gradients are evaluated here. A small error may be expected in the
gradients which corresponds to the contribution of
  MCSCF_DM * V_solvent[d/dX MCSCF_DM] + V_solvent[MCSCF_DM] * d/dX MCSCF_DM
''')
            grad_method = oldCAS.nuc_grad_method(self)
            return self.with_solvent.nuc_grad_method(grad_method)

        Gradients = nuc_grad_method

    return CASCIWithSolvent(mc, solvent_obj)
393
394
def _for_post_scf(method, solvent_obj, dm=None):
    '''A wrapper of solvent model for post-SCF methods (CC, CI, MP etc.)

    The solvent model lives on the underlying ``method._scf`` object. If that
    object has no solvent model yet, it is decorated with ``solvent_obj``;
    otherwise its existing solvent model is used.

    NOTE: this implementation often causes (macro iteration) convergence issue

    Kwargs:
        dm : if given, solvent does not respond to the change of density
            matrix. A frozen ddCOSMO potential is added to the results.
    '''
    if isinstance(method, _Solvation):
        method.with_solvent = solvent_obj
        method._scf.with_solvent = solvent_obj
        return method

    old_method = method.__class__

    # Ensure that the underlying _scf object has solvent model enabled
    if getattr(method._scf, 'with_solvent', None):
        scf_with_solvent = method._scf
    else:
        scf_with_solvent = _for_scf(method._scf, solvent_obj, dm)
        if dm is None:
            solvent_obj = scf_with_solvent.with_solvent
            solvent_obj.e, solvent_obj.v = \
                    solvent_obj.kernel(scf_with_solvent.make_rdm1())

    # Post-HF objects access the solvent effects indirectly through the
    # underlying ._scf object.
    basic_scanner = method.as_scanner()
    basic_scanner._scf = scf_with_solvent.as_scanner()

    if dm is not None:
        solvent_obj = scf_with_solvent.with_solvent
        solvent_obj.e, solvent_obj.v = solvent_obj.kernel(dm)
        solvent_obj.frozen = True

    class PostSCFWithSolvent(_Solvation, old_method):
        def __init__(self, method):
            self.__dict__.update(method.__dict__)
            self._scf = scf_with_solvent

        @property
        def with_solvent(self):
            '''The solvent model is stored on the underlying ._scf object.'''
            return self._scf.with_solvent

        @with_solvent.setter
        def with_solvent(self, solvent):
            # Without a setter, re-decorating an already decorated object
            # (``method.with_solvent = ...`` above) raises AttributeError.
            self._scf.with_solvent = solvent

        def dump_flags(self, verbose=None):
            old_method.dump_flags(self, verbose)
            self.with_solvent.check_sanity()
            self.with_solvent.dump_flags(verbose)
            return self

        def reset(self, mol=None):
            self.with_solvent.reset(mol)
            return old_method.reset(self, mol)

        def kernel(self, *args, **kwargs):
            '''Run the post-SCF method. Unless the solvent is frozen, the
            solvent potential is iterated (at most ``with_solvent.max_cycle``
            cycles) using the post-SCF density, and the arguments are ignored.
            '''
            with_solvent = self.with_solvent
            # The underlying ._scf object is decorated with solvent effects.
            # The resultant Fock matrix and orbital energies both include the
            # effects from solvent. It means that solvent effects for post-HF
            # methods are automatically counted if solvent is enabled at scf
            # level.
            if with_solvent.frozen:
                return old_method.kernel(self, *args, **kwargs)

            log = logger.new_logger(self)
            log.info('\n** Self-consistently update the solvent effects for %s **',
                     old_method)
            # TODO: Suppress a few output messages of the inner calculations.

            e_last = 0
            for cycle in range(with_solvent.max_cycle):
                log.info('\n** Solvent self-consistent cycle %d:', cycle)
                # Solvent effects are applied when accessing the
                # underlying ._scf objects. The flag frozen=True ensures that
                # the generated potential with_solvent.v is passed to the
                # the post-HF object, without being updated in the implicit
                # call to the _scf iterations.
                with lib.temporary_env(with_solvent, frozen=True):
                    e_tot = basic_scanner(self.mol)
                    dm = basic_scanner.make_rdm1(ao_repr=True)

                # To generate the solvent potential for ._scf object. Since
                # frozen is set when calling basic_scanner, the solvent
                # effects are frozen during the scf iterations.
                with_solvent.e, with_solvent.v = with_solvent.kernel(dm)

                de = e_tot - e_last
                log.info('Solvent cycle %d  E_tot = %.15g  dE = %g',
                         cycle, e_tot, de)

                # numpy.abs also accepts plain Python floats (which have no
                # .max() method).
                if numpy.abs(de).max() < with_solvent.conv_tol:
                    break
                e_last = e_tot

            # An extra cycle to compute the total energy
            log.info('\n** Extra cycle for solvent effects')
            with lib.temporary_env(with_solvent, frozen=True):
                #Update everything except the _scf object and _keys
                basic_scanner(self.mol)
                self.__dict__.update(basic_scanner.__dict__)
                self._scf.__dict__.update(basic_scanner._scf.__dict__)
            self._finalize()
            # NOTE(review): only e_corr is returned here; the amplitudes that
            # old_method.kernel may return are replaced by None.
            return self.e_corr, None

        def nuc_grad_method(self):
            logger.warn(self, '''
Approximate gradients are evaluated here. A small error may be expected in the
gradients which corresponds to the contribution of
  DM * V_solvent[d/dX DM] + V_solvent[DM] * d/dX DM
''')
            grad_method = old_method.nuc_grad_method(self)
            return self.with_solvent.nuc_grad_method(grad_method)

        Gradients = nuc_grad_method

    return PostSCFWithSolvent(method)
516
517
def _for_tdscf(method, solvent_obj, dm=None):
    '''Add solvent model in TDDFT calculations.

    If the underlying ``method._scf`` object has no solvent model, it is
    decorated with ``solvent_obj`` and the ground state SCF is run
    (``.run()``) before the TDSCF wrapper is built.

    Kwargs:
        dm : if given, solvent does not respond to the change of density
            matrix. A frozen ddCOSMO potential is added to the results.
    '''
    if isinstance(method, _Solvation):
        method.with_solvent = solvent_obj
        method._scf.with_solvent = solvent_obj
        return method

    old_method = method.__class__

    # Ensure that the underlying _scf object has solvent model enabled
    if getattr(method._scf, 'with_solvent', None):
        scf_with_solvent = method._scf
    else:
        scf_with_solvent = _for_scf(method._scf, solvent_obj, dm).run()

    if dm is not None:
        solvent_obj = scf_with_solvent.with_solvent
        solvent_obj.e, solvent_obj.v = solvent_obj.kernel(dm)
        solvent_obj.frozen = True

    class TDSCFWithSolvent(_Solvation, old_method):
        def __init__(self, method):
            self.__dict__.update(method.__dict__)
            self._scf = scf_with_solvent
            self.with_solvent = self._scf.with_solvent
            # Rebind instead of updating in place: after the __dict__ copy the
            # set may be shared with the original object ``method``.
            self._keys = self._keys.union(['with_solvent'])

        @property
        def equilibrium_solvation(self):
            '''Whether to allow the solvent rapidly responds to the changes of
            electronic structure or geometry of solute.
            '''
            return self.with_solvent.equilibrium_solvation
        @equilibrium_solvation.setter
        def equilibrium_solvation(self, val):
            if val and self.with_solvent.frozen:
                logger.warn(self, 'Solvent model was set to be frozen in the '
                            'ground state SCF calculation. It may conflict to '
                            'the assumption of equilibrium solvation.\n'
                            'You may set _scf.with_solvent.frozen = False and '
                            'rerun the ground state calculation _scf.run().')
            self.with_solvent.equilibrium_solvation = val

        def dump_flags(self, verbose=None):
            old_method.dump_flags(self, verbose)
            self.with_solvent.check_sanity()
            self.with_solvent.dump_flags(verbose)
            return self

        def reset(self, mol=None):
            self.with_solvent.reset(mol)
            return old_method.reset(self, mol)

        def get_ab(self, mf=None):
            '''A and B matrices of the TDSCF response problem.

            The solvent response terms for equilibrium solvation are not
            implemented. Otherwise the solvent enters only through the
            solvent-decorated ground state, so the parent implementation is
            used.
            '''
            if self.equilibrium_solvation:
                raise NotImplementedError
            return old_method.get_ab(self, mf)

        def nuc_grad_method(self):
            grad_method = old_method.nuc_grad_method(self)
            return self.with_solvent.nuc_grad_method(grad_method)

        # Keep the Gradients alias pointing to the solvent-aware method, as
        # in the other solvent wrappers.
        Gradients = nuc_grad_method

    mf1 = TDSCFWithSolvent(method)
    return mf1
588
589# 1. A tag to label the derived method class
590class _Solvation(object):
591    pass
592