1#!/usr/bin/env python
2# constrained DFT
3# written by Zhihao Cui. zcui@caltech.edu
4
5r'''
Constrained DFT (cDFT) is a method to control the locality of the electron
density during the SCF calculation. By constraining the electron density or
occupation on a particular orbital, atom, or functional group, cDFT provides a
way to study charge-transfer problems. One type of constraint integrates the
electron density against a real-space weight function and requires the result
to equal a given value (see Wu and Van Voorhis, PRA 72 (2005), 024502):
12
13    .. math::
14
15        \int w(\mathbf{r}) \rho(\mathbf{r}) d\mathbf{r} = N_c
16
This example shows another type of constraint, which controls the electron
population (Mulliken population in a localized-orbital basis) during the SCF
iterations:
20
21    .. math::
22
23        \sum_{p} \gamma_{pq} S_{qp} = N_c
24
When incorporating this constraint into the HF/KS method, a Lagrange multiplier
:math:`V_c` for the constraint is introduced into the energy minimization,
:math:`E + V_c (\sum_{p} \gamma_{pq} S_{qp} - N_c)`. The constraint leads to an
extra term in the Fock matrix. This can be achieved by modifying the
:func:`get_fock` method of the SCF object, as shown by the code below.
30
Since the constraints are based on population analysis, they depend closely on
the population method and on the quality of the localized orbitals. The code
in this example supports four localization schemes: Lowdin orthogonalization,
meta-Lowdin orthogonalization, intrinsic atomic orbitals (IAO), and natural
atomic orbitals (NAO).
36'''
37
38import numpy as np
39import scipy.linalg as la
40import copy
41from functools import reduce
42from pyscf import gto, scf, lo, dft, lib
43from pyscf.pbc.scf import khf
44
45
def get_localized_orbitals(mf, lo_method, mo=None):
    '''Build the AO -> localized-orbital transformation and project MOs onto it.

    Args:
        mf : SCF object (molecular, or k-point when it is a khf.KSCF).
        lo_method : str
            Localization scheme. Case-insensitive; '-' and '_' are treated
            alike (so 'meta-lowdin' == 'meta_lowdin').
            Molecular SCF: 'lowdin', 'meta_lowdin', 'iao', 'nao'.
            K-point SCF: 'lowdin', 'meta_lowdin'.
        mo : MO coefficients to express in the localized basis; defaults to
            ``mf.mo_coeff``. For 'iao' the occupied columns of ``mo``
            (selected with ``mf.mo_occ``) also define the IAOs.

    Returns:
        (C_inv_spin, mo_lo)
        C_inv_spin is C^dagger S, which maps AO-basis coefficients onto the
        orthonormal localized basis. It is a single matrix (per k-point) for
        scf.hf.RHF objects, otherwise an alpha/beta stacked pair.
        mo_lo is ``mo`` expressed in the localized basis.

    Raises:
        NotImplementedError : unsupported ``lo_method`` for this SCF type.
    '''
    if mo is None:
        mo = mf.mo_coeff

    # Normalize the name once so every branch (and cdft(), which lower-cases
    # it) agrees on spelling, e.g. 'IAO' or 'meta-lowdin'.
    method = lo_method.lower().replace('-', '_')
    # NOTE(review): k-point restricted objects are presumably not subclasses
    # of scf.hf.RHF, so they take the unrestricted (stacked) path — confirm.
    restricted = isinstance(mf, scf.hf.RHF)

    def _spin_stack(c_inv):
        # Unrestricted callers expect one transformation per spin.
        return c_inv if restricted else np.array([c_inv] * 2)

    s1e = mf.get_ovlp()

    if not isinstance(mf, khf.KSCF):
        if method in ('lowdin', 'meta_lowdin'):
            # Use the requested scheme; 'lowdin' used to fall through to
            # meta-Lowdin as well.
            C = lo.orth_ao(mf, method, s=s1e)
            C_inv_spin = _spin_stack(np.dot(C.conj().T, s1e))

        elif method == 'iao':
            if restricted:
                mo_occ_coeff = mo[:, mf.mo_occ > 0]
                # Orthogonalize IAO
                C = lo.vec_lowdin(lo.iao.iao(mf.mol, mo_occ_coeff), s1e)
                C_inv_spin = np.dot(C.conj().T, s1e)
            else:
                # IAOs depend on the occupied space, so build one set per spin.
                C_inv_list = []
                for spin in range(2):
                    mo_occ_coeff = mo[spin][:, mf.mo_occ[spin] > 0]
                    C = lo.vec_lowdin(lo.iao.iao(mf.mol, mo_occ_coeff), s1e)
                    C_inv_list.append(np.dot(C.conj().T, s1e))
                C_inv_spin = np.array(C_inv_list)

        elif method == 'nao':
            C = lo.orth_ao(mf, 'nao')
            C_inv_spin = _spin_stack(np.dot(C.conj().T, s1e))

        else:
            raise NotImplementedError("UNDEFINED LOCAL ORBITAL TYPE, EXIT...")

    else:
        if method in ('lowdin', 'meta_lowdin'):
            # One orthogonalization per k-point overlap matrix.
            C_inv_arr = []
            for s_k in s1e:
                C_k = lo.orth_ao(mf, method, s=s_k)
                C_inv_arr.append(np.dot(C_k.conj().T, s_k))
            C_inv_spin = _spin_stack(np.array(C_inv_arr))
        else:
            raise NotImplementedError("CONSTRUCTING...EXIT")

    mo_lo = np.einsum('...jk,...kl->...jl', C_inv_spin, mo)
    return C_inv_spin, mo_lo
119
def pop_analysis(mf, mo_on_loc_ao, disp=True, full_dm=False):
    '''Population analysis in the localized-orbital basis.

    The density matrix is built from ``mo_on_loc_ao`` (MO coefficients
    already expressed in the orthonormal local basis) and ``mf.mo_occ``.
    For k-point SCF objects it is averaged over k-points. ``mf`` is expected
    to be converged.

    Args:
        mf : SCF object supplying ``make_rdm1`` and ``mo_occ``.
        mo_on_loc_ao : MO coefficients in the localized basis.
        disp : if True, print a Mulliken population computed with an identity
            overlap (the local basis is orthonormal).
        full_dm : if True, return the whole density matrix; otherwise return
            only its diagonal, i.e. the orbital populations.

    Returns:
        dm_lo, or its diagonal (see ``full_dm``).
    '''
    dm_lo = mf.make_rdm1(mo_on_loc_ao, mf.mo_occ)

    if isinstance(mf, khf.KSCF):
        # The third-from-last axis is summed out and divided by the number
        # of k-points, i.e. treated as the k-point axis.
        n_k = float(len(mf.kpts))
        dm_lo = np.einsum('...ijk->...jk', dm_lo) / n_k

    if disp:
        identity = np.eye(mf.mol.nao_nr())
        mf.mulliken_pop(mf.mol, dm_lo, identity)

    return dm_lo if full_dm else np.einsum('...ii->...i', dm_lo)
143
144
def get_fock_add_cdft(constraints, V, C_ao2lo_inv, mf=None):
    '''Return the Fock-matrix correction produced by the Lagrange multipliers.

    F_ao_new = F_ao_old + C^{-1}.T * V_diag_lo * C^{-1}
    F_add is defined as C^{-1}.T * V_diag_lo * C^{-1}

    C is the transformation matrix of BASIS, from ao to lo. |LO> = |AO> * C
    NOTE: C should be pre-orthogonalized, i.e. C^T S C = I
    and thus C^{-1} = C.T * S

    Args:
        constraints : Constraints object; converts V into per-site, per-spin
            multipliers (``sum2separated``) and lists the sites.
        V : Lagrange multipliers, one per constraint.
        C_ao2lo_inv : C^{-1} as returned by get_localized_orbitals. It is one
            matrix for scf.hf.RHF objects and an alpha/beta pair otherwise,
            with a leading k-point axis for k-point (khf.KSCF) objects.
        mf : SCF object. Only its type is used, to choose restricted vs.
            unrestricted and molecular vs. k-point handling. When omitted,
            a module-level variable named ``mf`` is used, which is what
            earlier versions of this function silently relied on.

    Returns:
        V_a + V_b for scf.hf.RHF objects, otherwise np.array((V_a, V_b)).

    Raises:
        TypeError : ``mf`` was not given and no module-level ``mf`` exists.
    '''
    if mf is None:
        # Backward compatibility: the SCF object used to be read from a
        # global, which only exists when this file runs as a script.
        mf = globals().get('mf')
        if mf is None:
            raise TypeError("get_fock_add_cdft() needs the SCF object: pass mf=...")

    V_lagr = constraints.sum2separated(V)
    sites_a, sites_b = constraints.unique_sites()
    if isinstance(mf, scf.hf.RHF):
        C_ao2lo_a = C_ao2lo_b = C_ao2lo_inv
    else:
        C_ao2lo_a, C_ao2lo_b = C_ao2lo_inv

    # Rows of C^{-1} for the constrained sites, weighted by their multipliers:
    # V_s = sum_i V_i c_i^* c_i^T (per k-point in the periodic case).
    if not isinstance(mf, khf.KSCF):
        V_a = np.einsum('ip,i,iq->pq', C_ao2lo_a[sites_a].conj(), V_lagr[0], C_ao2lo_a[sites_a])
        V_b = np.einsum('ip,i,iq->pq', C_ao2lo_b[sites_b].conj(), V_lagr[1], C_ao2lo_b[sites_b])
    else:
        V_a = np.einsum('kip,i,kiq->kpq', C_ao2lo_a[:,sites_a].conj(), V_lagr[0], C_ao2lo_a[:,sites_a])
        V_b = np.einsum('kip,i,kiq->kpq', C_ao2lo_b[:,sites_b].conj(), V_lagr[1], C_ao2lo_b[:,sites_b])

    if isinstance(mf, scf.hf.RHF):
        # Restricted orbitals feel the alpha and beta potentials together.
        return V_a + V_b
    else:
        return np.array((V_a, V_b))
176
177
def W_cdft(mf, constraints, V_c, orb_pop):
    '''Return the value of the functional W = sum_k V_c[k] * (N_k - N_c[k]).

    Args:
        mf : SCF object. For scf.hf.RHF the populations are split evenly
            between the two spins.
        constraints : Constraints object (targets and site bookkeeping).
        V_c : Lagrange multipliers, one per constraint.
        orb_pop : orbital populations from pop_analysis. This is one array
            for scf.hf.RHF and an alpha/beta pair otherwise.

    Returns:
        W as a real number. The populations may carry a (numerically zero)
        imaginary part (presumably for k-point objects), so only their real
        part is used. This matches jac_cdft and keeps the '%.5e' logging in
        cdft valid.
    '''
    if isinstance(mf, scf.hf.RHF):
        pop_a = pop_b = orb_pop * .5
    else:
        pop_a, pop_b = orb_pop

    N_c = constraints.nelec_required
    sites_a, sites_b = constraints.unique_sites()
    # Kept as a tuple: the alpha and beta site lists may differ in length.
    N_cur = pop_a[sites_a].real, pop_b[sites_b].real
    N_cur_sum = constraints.separated2sum(N_cur)[1]
    return np.einsum('i,i', V_c, N_cur_sum - N_c)
190
def jac_cdft(mf, constraints, V_c, orb_pop):
    '''Return the gradient of W w.r.t. the multipliers and the current populations.

    Args:
        mf : SCF object. For scf.hf.RHF the populations are split evenly
            between the two spins.
        constraints : Constraints object (targets and site bookkeeping).
        V_c : Lagrange multipliers. Not used; kept for interface symmetry
            with W_cdft and hess_cdft.
        orb_pop : orbital populations from pop_analysis. This is one array
            for scf.hf.RHF and an alpha/beta pair otherwise.

    Returns:
        (N_cur_sum - N_c, N_cur_sum), where N_cur_sum holds the real parts of
        the current per-constraint populations.
    '''
    if isinstance(mf, scf.hf.RHF):
        pop_a = pop_b = orb_pop * .5
    else:
        pop_a, pop_b = orb_pop

    N_c = constraints.nelec_required
    sites_a, sites_b = constraints.unique_sites()
    # Keep alpha and beta as separate arrays. The two site lists usually
    # differ in length, and np.array([...]) of unequal-length arrays is an
    # error in current NumPy.
    N_cur = (pop_a[sites_a].real, pop_b[sites_b].real)
    N_cur_sum = constraints.separated2sum(N_cur)[1]
    return N_cur_sum - N_c, N_cur_sum
203
def hess_cdft(mf, constraints, V_c, mo_on_loc_ao):
    '''Return the Hessian of W with respect to the Lagrange multipliers.

    For each spin s, on the constrained sites p, q of that spin:

        h_s[p,q] = sum_{i occ, a vir} C*_pi C_pa C*_qa C_qi / (e_i - e_a) + c.c.

    Here C = mo_on_loc_ao[s] and e = mf.mo_energy[s]. Each h_s is then mapped
    to constraint space with the site-to-constraint matrix t_s, and the two
    spin contributions are summed.

    The function reads ``mf.mo_energy[s]``, ``mf.mo_occ[s]`` and
    ``mo_on_loc_ao[s]`` for s = 0 and 1, so it needs per-spin arrays
    (unrestricted, molecular). ``V_c`` is accepted for interface symmetry
    and is not used.
    '''
    spin_sites = constraints.unique_sites()
    spin_maps = constraints.site_to_constraints_transform_matrix()

    total = None
    for spin in (0, 1):
        occ = mf.mo_occ[spin]
        energies = mf.mo_energy[spin]

        occupied = occ > 0
        virtual = occ == 0
        gap = energies[occupied][:, None] - energies[virtual]
        # Guard against division by zero for degenerate occ/vir pairs.
        gap[gap == 0] = 1e-18

        site_orbs = mo_on_loc_ao[spin][spin_sites[spin]]
        orb_o = site_orbs[:, occupied]
        orb_v = site_orbs[:, virtual]

        h_site = np.einsum('pi,pa,qa,qi,ia->pq',
                           orb_o.conj(), orb_v,
                           orb_v.conj(), orb_o, 1. / gap)
        h_site = h_site + h_site.conj()

        t_map = spin_maps[spin]
        contribution = np.einsum('pq,pi,qj->ij', h_site, t_map, t_map)
        total = contribution if total is None else total + contribution
    return total
233
234
def cdft(mf, constraints, V_0=None, lo_method='lowdin', alpha=0.2, tol=1e-5,
         constraints_tol=1e-3, maxiter=200, C_inv=None, verbose=4,
         diis_pos='post', diis_type=1):
    '''Run an SCF calculation with orbital-population constraints (cDFT).

    ``mf.get_fock`` is replaced by a wrapper. Inside the SCF iterations the
    wrapper runs an inner loop of damped Newton steps on the Lagrange
    multipliers V and adds the corresponding potential to the Fock matrix.

    Args:
        mf : pre-converged SCF object. Its MOs seed the localized orbitals and
            its density is the initial guess. It is modified in place.
        constraints : Constraints object.
        V_0 : initial guess of the Lagrange multipliers (default: zeros).
        lo_method : localization method, one of 'lowdin', 'meta_lowdin',
            'iao', 'nao' (see get_localized_orbitals).
        alpha : Newton step scale. The first 5 inner iterations use
            min(0.05, 0.1 * alpha) instead.
        tol : inner-loop threshold on the gradient norm.
        constraints_tol : inner-loop threshold on the change of V.
        maxiter : stored as ``mf.max_cycle``.
        C_inv : optional precomputed AO->LO matrix C^{-1}, as returned by
            get_localized_orbitals. It is computed from ``mf`` when None.
        verbose : stored as ``mf.verbose``. Values > 3 print every inner
            iteration; values > 0 print one summary line per SCF cycle.
        diis_pos : where the inner loop runs relative to the extra CDFT DIIS
            step. 'pre' runs it before, 'post' after, 'both' on both sides.
        diis_type : 1 extrapolates the unconstrained Fock matrix with the
            commutator error of the constrained one. 2 runs DIIS on successive
            Fock matrices (differences as error vectors). 3 runs DIIS on the
            constrained Fock matrix with its commutator error. Any other
            value disables the CDFT DIIS.

    Returns:
        (mf, orb_pop): the SCF object after ``mf.kernel`` and the final
        orbital populations in the localized basis used by the constraints.

    Raises:
        ValueError : ``diis_pos`` is not 'pre', 'post' or 'both'.
    '''
    # Validate up front: with no inner loop, W/g_norm would be undefined
    # and the SCF would fail with a NameError mid-run.
    if diis_pos not in ('pre', 'post', 'both'):
        raise ValueError("diis_pos must be 'pre', 'post' or 'both', got %r" % (diis_pos,))

    mf.verbose = verbose
    mf.max_cycle = maxiter

    old_get_fock = mf.get_fock
    use_iao = lo_method.lower() == 'iao'

    if V_0 is None:
        V_0 = np.zeros(constraints.get_n_constraints())
    constraints._final_V = V_0

    # Honor a caller-supplied transformation; it used to be overwritten.
    if C_inv is None:
        C_inv = get_localized_orbitals(mf, lo_method, mf.mo_coeff)[0]

    cdft_diis = lib.diis.DIIS()
    cdft_diis.space = 8

    def _mo_on_loc_ao(mo_coeff):
        # Express MOs in the localized basis. IAOs depend on the current
        # occupied orbitals and are rebuilt; other schemes reuse C_inv.
        if use_iao:
            return get_localized_orbitals(mf, lo_method, mo_coeff)[1]
        return np.einsum('...jk,...kl->...jl', C_inv, mo_coeff)

    def _inner_loop(fock_0, s1e, V_0, max_cycle, newton_step):
        # Optimize V for a fixed unconstrained Fock matrix fock_0.
        # Returns (fock, fock_add, V, W, g_norm, converged).
        converged = False
        for it in range(max_cycle):
            fock_add = get_fock_add_cdft(constraints, V_0, C_inv)
            fock = fock_0 + fock_add

            mo_energy, mo_coeff = mf.eig(fock, s1e)
            mo_occ = mf.get_occ(mo_energy, mo_coeff)

            # Read back by hess_cdft and pop_analysis (mf.mo_occ)
            mf.mo_energy = mo_energy
            mf.mo_coeff = mo_coeff
            mf.mo_occ = mo_occ

            mo_on_loc_ao = _mo_on_loc_ao(mo_coeff)
            orb_pop = pop_analysis(mf, mo_on_loc_ao, disp=False)
            W_new = W_cdft(mf, constraints, V_0, orb_pop)
            jacob, N_cur = jac_cdft(mf, constraints, V_0, orb_pop)
            hess = hess_cdft(mf, constraints, V_0, mo_on_loc_ao)
            deltaV = newton_step(jacob, hess)

            # Cautious steps first, then the full alpha.
            stp = min(0.05, alpha * 0.1) if it < 5 else alpha

            V = V_0 + deltaV * stp
            g_norm = np.linalg.norm(jacob)
            if verbose > 3:
                print("  loop %4s : W: %.5e    V_c: %s     Nele: %s      g_norm: %.3e    "
                      % (it, W_new, V_0, N_cur, g_norm))
            if g_norm < tol and np.linalg.norm(V - V_0) < constraints_tol:
                converged = True
                break
            V_0 = V
        return fock, fock_add, V_0, W_new, g_norm, converged

    def _solve_step(jacob, hess):
        # Plain Newton step.
        return np.linalg.solve(hess, -jacob)

    def get_fock(h1e, s1e, vhf, dm, cycle=0, mf_diis=None, **kwargs):
        # **kwargs tolerates extra keywords (e.g. fock_last) that some PySCF
        # versions pass to get_fock; they are not forwarded.
        fock_0 = old_get_fock(h1e, s1e, vhf, dm, cycle, None)
        V_0 = constraints._final_V
        if mf_diis is None:
            # Outside the SCF iterations: just apply the current multipliers.
            fock_add = get_fock_add_cdft(constraints, V_0, C_inv)
            return fock_0 + fock_add

        inner_max_cycle = 20 if cycle < 10 else 50

        if verbose > 3:
            print("\nCDFT INNER LOOP:")

        fock_add = get_fock_add_cdft(constraints, V_0, C_inv)
        fock = fock_0 + fock_add
        cdft_conv_flag = False

        if diis_pos in ('pre', 'both'):
            fock, fock_add, V_0, W_new, g_norm, cdft_conv_flag = _inner_loop(
                fock_0, s1e, V_0, inner_max_cycle, get_newton_step_aug_hess)

        if cycle > 1:
            if diis_type == 1:
                fock = cdft_diis.update(fock_0, scf.diis.get_err_vec(s1e, dm, fock)) + fock_add
            elif diis_type == 2:
                # TO DO difference < threshold...
                fock = cdft_diis.update(fock)
            elif diis_type == 3:
                fock = cdft_diis.update(fock, scf.diis.get_err_vec(s1e, dm, fock))
            else:
                print("\nWARN: Unknow CDFT DIIS type, NO DIIS IS USED!!!\n")

        if diis_pos in ('post', 'both'):
            # Strip the constraint potential so the loop re-adds its own.
            fock, fock_add, V_0, W_new, g_norm, cdft_conv_flag = _inner_loop(
                fock - fock_add, s1e, V_0, inner_max_cycle, _solve_step)

        if verbose > 0:
            print("CDFT W: %.5e   g_norm: %.3e    " % (W_new, g_norm))

        constraints._converged = cdft_conv_flag
        constraints._final_V = V_0
        return fock

    dm0 = mf.make_rdm1()
    mf.get_fock = get_fock
    mf.kernel(dm0)

    # Report populations in the same localized basis the constraints used.
    orb_pop = pop_analysis(mf, _mo_on_loc_ao(mf.mo_coeff), disp=True)
    return mf, orb_pop
388
389
class Constraints(object):
    '''
    Population constraints consumed by cdft().

    Attributes:
        orbital_indices: one list per constraint, giving the localized-orbital
            indices whose populations are summed in that constraint. An index
            may appear more than once in a list (e.g. once for each spin).
        spin_labels: same nesting as orbital_indices. A label of 0 selects the
            alpha population of the matching orbital; any other value selects
            the beta population.
        nelec_required: target number of electrons for each constraint
            (stored as a numpy array).

    Examples:
        constraints.orbital_indices = [[2,2], [3]]
        constraints.spin_labels = [[0,1] , [1]]
        constraints.nelec_required = [1.5 , 0.5]

        correspond to two constraints:
        1. N_{alpha-LO_2} + N_{beta-LO_2} = 1.5
        2. N_{beta-LO_3} = 0.5
    '''
    def __init__(self, orbital_indices, spin_labels, nelec_required):
        self.orbital_indices = orbital_indices
        self.spin_labels = spin_labels
        self.nelec_required = np.asarray(nelec_required)
        assert len(orbital_indices) == len(spin_labels) == len(nelec_required), \
            'orbital_indices, spin_labels and nelec_required must have the same length'

    def get_n_constraints(self):
        '''Return the number of constraints.'''
        return len(self.nelec_required)

    def unique_sites(self):
        '''Return (sites_a, sites_b): sorted, de-duplicated orbital indices per spin.

        Both are integer arrays. The explicit integer dtype matters when a
        spin has no sites: an empty list would otherwise become an empty
        float array, which NumPy refuses to use as an index.
        '''
        sites_a = set()
        sites_b = set()
        for group, spins in zip(self.orbital_indices, self.spin_labels):
            for orbidx, spin in zip(group, spins):
                if spin == 0:
                    sites_a.add(orbidx)
                else:
                    sites_b.add(orbidx)
        return (np.array(sorted(sites_a), dtype=int),
                np.array(sorted(sites_b), dtype=int))

    def site_to_constraints_transform_matrix(self):
        '''Return (t_a, t_b), the site-to-constraint incidence matrices.

        t_s[p, k] counts how often the p-th unique spin-s site (row order as
        returned by unique_sites) appears in constraint k.
        '''
        sites_a, sites_b = self.unique_sites()
        map_sites_a = {orbidx: row for row, orbidx in enumerate(sites_a)}
        map_sites_b = {orbidx: row for row, orbidx in enumerate(sites_b)}

        n_constraints = self.get_n_constraints()
        t_a = np.zeros((sites_a.size, n_constraints))
        t_b = np.zeros((sites_b.size, n_constraints))
        for k, group in enumerate(self.orbital_indices):
            for orbidx, spin in zip(group, self.spin_labels[k]):
                if spin == 0:
                    t_a[map_sites_a[orbidx],k] += 1
                else:
                    t_b[map_sites_b[orbidx],k] += 1
        return t_a, t_b

    def sum2separated(self, V_c):
        '''
        Convert per-constraint values (a constraint may combine several
        orbitals) into per-site values, separated by spin.

        Returns (V_c_a, V_c_b) with t_s @ V_c for each spin s.
        '''
        t_a, t_b = self.site_to_constraints_transform_matrix()
        V_c_a = np.einsum('pi,i->p', t_a, V_c)
        V_c_b = np.einsum('pi,i->p', t_b, V_c)
        return V_c_a, V_c_b

    def separated2sum(self, N_c):
        '''Fold per-site (alpha, beta) values back onto the constraints.

        This is the transpose mapping of sum2separated.

        Args:
            N_c : pair (values on alpha sites, values on beta sites), ordered
                as returned by unique_sites.

        Returns:
            (N_c_new, N_c_sum, V_c_sum)
            N_c_new : (n_constraints, 2) alpha and beta contributions.
            N_c_sum : their sum per constraint.
            V_c_sum : per constraint, the alpha contribution if the constraint
                involves any alpha site, otherwise the beta contribution.
        '''
        t_a, t_b = self.site_to_constraints_transform_matrix()
        N_c_new = np.array([np.einsum('pi,p->i', t_a, N_c[0]),
                            np.einsum('pi,p->i', t_b, N_c[1])]).T

        N_c_sum = N_c_new[:,0] + N_c_new[:,1]

        # V_c on alpha-site if available, otherwise V_c on beta-site
        V_c_sum = [N_c_new[i,0] if 0 in spins else N_c_new[i,1]
                   for i,spins in enumerate(self.spin_labels)]
        return N_c_new, N_c_sum, V_c_sum
472
473
def get_newton_step_aug_hess(jac, hess):
    '''Return a Newton-like step from the augmented-Hessian eigenproblem.

    Builds the symmetric matrix

        [[0,   jac^H],
         [jac, hess ]]

    and picks the first eigenvector (in eigh's ascending eigenvalue order)
    whose eigenvalue is positive and whose first component exceeds 0.1 in
    magnitude. The step is that vector's remaining components divided by its
    first component.

    Returns np.zeros_like(jac), after printing a warning, when no
    eigenvector qualifies.
    '''
    ah = np.zeros((hess.shape[0] + 1, hess.shape[1] + 1))
    ah[1:, 0] = jac
    ah[0, 1:] = jac.conj()
    ah[1:, 1:] = hess

    eigval, eigvec = la.eigh(ah)
    # range, not the Python 2-only xrange (which raised NameError here).
    for i in range(len(eigval)):
        if abs(eigvec[0, i]) > 0.1 and eigval[i] > 0.0:
            return eigvec[1:, i] / eigvec[0, i]

    print("WARNING: ALL EIGENVALUESS in AUG-HESSIAN are NEGATIVE!!! ")
    return np.zeros_like(jac)
492
493
if __name__ == '__main__':
    # Planar C6H6 ring (benzene geometry) in the 6-31G basis.
    mol = gto.Mole()
    mol.verbose = 0
    mol.atom = '''
    c   1.217739890298750 -0.703062453466927  0.000000000000000
    h   2.172991468538160 -1.254577209307266  0.000000000000000
    c   1.217739890298750  0.703062453466927  0.000000000000000
    h   2.172991468538160  1.254577209307266  0.000000000000000
    c   0.000000000000000  1.406124906933854  0.000000000000000
    h   0.000000000000000  2.509154418614532  0.000000000000000
    c  -1.217739890298750  0.703062453466927  0.000000000000000
    h  -2.172991468538160  1.254577209307266  0.000000000000000
    c  -1.217739890298750 -0.703062453466927  0.000000000000000
    h  -2.172991468538160 -1.254577209307266  0.000000000000000
    c   0.000000000000000 -1.406124906933854  0.000000000000000
    h   0.000000000000000 -2.509154418614532  0.000000000000000
    '''
    mol.basis = '631g'
    mol.spin = 0
    mol.build()

    # Unconstrained reference calculation; cdft() starts from its solution.
    mf = scf.UHF(mol)
#    mf = dft.UKS(mol)
#    mf.xc = 'pbe,pbe'
    mf.conv_tol = 1e-9
    mf.verbose = 0
    mf.max_cycle = 100
    mf.run()

    # AO indices whose labels match 'C 2pz' (presumably one per carbon atom,
    # in atom order; C0, C1, C2 below refer to idx[0], idx[1], idx[2]).
    idx = mol.search_ao_label('C 2pz')
    # There are 3 constraints:
    # 1. N_alpha_C0 + N_beta_C0 + N_beta_C1 = 1.5
    # 2. N_alpha_C2 = 0.5
    # 3. N_beta_C2 = 0.5
    orbital_indices = [[idx[0], idx[0], idx[1]], [idx[2]], [idx[2]]]
    spin_labels = [[0, 1, 1], [0], [1]]
    nelec_required = [1.5, .5, .5]
    constraints = Constraints(orbital_indices, spin_labels, nelec_required)
    mf, dm_pop = cdft(mf, constraints, lo_method='lowdin', verbose=4)
533