1#!/usr/bin/env python
2# Copyright 2014-2019 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Qiming Sun <osirpt.sun@gmail.com>
17#
18
19'''
20Restricted open-shell Hartree-Fock for periodic systems with k-point sampling
21'''
22
23from functools import reduce
24import numpy as np
25import scipy.linalg
26from pyscf.scf import hf as mol_hf
27from pyscf.pbc.scf import khf
28from pyscf.pbc.scf import kuhf
29from pyscf.pbc.scf import rohf as pbcrohf
30from pyscf import lib
31from pyscf.lib import logger
32from pyscf.pbc.scf import addons
33from pyscf import __config__
34
# Defaults for the population-analysis arguments of analyze()/mulliken_meta();
# both can be overridden through pyscf.__config__.
WITH_META_LOWDIN, PRE_ORTH_METHOD = (
    getattr(__config__, 'pbc_scf_analyze_with_meta_lowdin', True),
    getattr(__config__, 'pbc_scf_analyze_pre_orth_method', 'ANO'),
)
37
38
def make_rdm1(mo_coeff_kpts, mo_occ_kpts, **kwargs):
    '''Alpha and beta spin one particle density matrices for all k-points.

    At each k-point, every orbital with a positive occupation contributes to
    the alpha density, while only doubly occupied orbitals (occupation == 2)
    contribute to the beta density.  Extra keyword arguments are accepted
    for interface compatibility and ignored.

    Returns:
        dm_kpts : (2, nkpts, nao, nao) ndarray
    '''
    def _density(orbitals):
        # C C^dagger over the selected orbital columns
        return np.dot(orbitals, orbitals.conj().T)

    dm_alpha, dm_beta = [], []
    for k, occ in enumerate(mo_occ_kpts):
        orbitals = mo_coeff_kpts[k]
        dm_alpha.append(_density(orbitals[:, occ > 0]))
        dm_beta.append(_density(orbitals[:, occ == 2]))
    return lib.asarray((dm_alpha, dm_beta))
53
def get_fock(mf, h1e=None, s1e=None, vhf=None, dm=None, cycle=-1, diis=None,
             diis_start_cycle=None, level_shift_factor=None, damp_factor=None):
    '''Roothaan effective Fock matrices for all k-points.

    Inputs left as None are computed from ``mf``.  ``h1e`` comes from
    ``mf.get_hcore()``, ``s1e`` from ``mf.get_ovlp()``, ``dm`` from
    ``mf.make_rdm1()`` and ``vhf`` from ``mf.get_veff``.  Outside the SCF
    iterations (``cycle < 0`` and no ``diis``), the bare Roothaan Fock is
    returned.  Otherwise, DIIS extrapolation (from ``diis_start_cycle`` on)
    and level shifting are applied.  Requesting Fock damping during the early
    cycles raises NotImplementedError.

    Returns:
        (nkpts, nao, nao) ndarray tagged with the spin Fock matrices
        ``focka`` and ``fockb``.
    '''
    h1e_kpts, s_kpts, vhf_kpts, dm_kpts = h1e, s1e, vhf, dm
    if h1e_kpts is None: h1e_kpts = mf.get_hcore()
    # The overlap and the density are required by get_roothaan_fock below, so
    # resolve them before use.  Callers such as canonicalize() pass only dm.
    if s_kpts is None: s_kpts = mf.get_ovlp()
    if dm_kpts is None: dm_kpts = mf.make_rdm1()
    if vhf_kpts is None: vhf_kpts = mf.get_veff(mf.cell, dm_kpts)
    focka = h1e_kpts + vhf_kpts[0]
    fockb = h1e_kpts + vhf_kpts[1]
    f_kpts = get_roothaan_fock((focka, fockb), dm_kpts, s_kpts)
    if cycle < 0 and diis is None:  # Not inside the SCF iteration
        return f_kpts

    if diis_start_cycle is None:
        diis_start_cycle = mf.diis_start_cycle
    if level_shift_factor is None:
        level_shift_factor = mf.level_shift
    if damp_factor is None:
        damp_factor = mf.damp

    # Spin-summed density, used by DIIS and by the level shift
    dm_sf = dm_kpts[0] + dm_kpts[1]
    if 0 <= cycle < diis_start_cycle-1 and abs(damp_factor) > 1e-4:
        raise NotImplementedError('ROHF Fock-damping')
    if diis and cycle >= diis_start_cycle:
        f_kpts = diis.update(s_kpts, dm_sf, f_kpts, mf, h1e_kpts, vhf_kpts)
    if abs(level_shift_factor) > 1e-4:
        f_kpts = [mol_hf.level_shift(s, dm_sf[k]*.5, f_kpts[k], level_shift_factor)
                  for k, s in enumerate(s_kpts)]
    # DIIS / level shift return untagged matrices; restore the spin tags
    f_kpts = lib.tag_array(lib.asarray(f_kpts), focka=focka, fockb=fockb)
    return f_kpts
84
def get_roothaan_fock(focka_fockb, dma_dmb, s):
    '''Roothaan's effective fock.

    ======== ======== ====== =========
    space     closed   open   virtual
    ======== ======== ====== =========
    closed      Fc      Fb     Fc
    open        Fb      Fc     Fa
    virtual     Fc      Fa     Fc
    ======== ======== ====== =========

    where Fc = (Fa + Fb) / 2

    At each k-point, the closed, open and virtual spaces are selected with the
    projectors Db.S, (Da-Db).S and 1-Da.S.

    Args:
        focka_fockb : (focka, fockb), per-k-point alpha and beta Fock matrices
        dma_dmb : (dma, dmb), per-k-point alpha and beta density matrices
        s : per-k-point overlap matrices

    Returns:
        Roothaan effective Fock matrix, stacked over k-points and tagged with
        ``focka`` and ``fockb``.
    '''
    focka, fockb = focka_fockb
    dma, dmb = dma_dmb
    unit = np.eye(s[0].shape[0])

    def project(left, mat, right):
        # left^dagger . mat . right
        return reduce(np.dot, (left.conj().T, mat, right))

    fock_kpts = []
    for k in range(len(s)):
        fa, fb, ovlp = focka[k], fockb[k], s[k]
        f_avg = (fa + fb) * .5
        p_core = np.dot(dmb[k], ovlp)
        p_open = np.dot(dma[k] - dmb[k], ovlp)
        p_vir = unit - np.dot(dma[k], ovlp)
        # One triangle of the block table (diagonal blocks halved).  Adding
        # its Hermitian conjugate below fills in the other triangle.
        half = project(p_core, f_avg, p_core) * .5
        half += project(p_open, f_avg, p_open) * .5
        half += project(p_vir, f_avg, p_vir) * .5
        half += project(p_open, fb, p_core)
        half += project(p_open, fa, p_vir)
        half += project(p_vir, f_avg, p_core)
        fock_kpts.append(half + half.conj().T)
    return lib.tag_array(np.asarray(fock_kpts), focka=focka, fockb=fockb)
120
def get_occ(mf, mo_energy_kpts=None, mo_coeff_kpts=None):
    '''Label the occupancies for each orbital for sampled k-points.

    This is a k-point version of scf.hf.SCF.get_occ.  Orbitals whose Roothaan
    energy is at or below the ``nocc_b``-th lowest Roothaan energy (taken
    over all k-points together) are doubly occupied.  For open-shell systems,
    the ``nocc_a - nocc_b`` singly occupied orbitals are then chosen among
    the remaining orbitals by their alpha orbital energies.

    Args:
        mf : KROHF object.  ``mf.nelec`` supplies (nocc_a, nocc_b) counted
            over all k-points.
        mo_energy_kpts : list of 1D arrays, one per k-point (default
            ``mf.mo_energy``).  If they carry ``mo_ea``/``mo_eb`` tags (as
            produced by ``KROHF.eig``), the alpha energies select the
            open-shell orbitals.  Otherwise the Roothaan energies are used.
        mo_coeff_kpts : unused; kept for interface compatibility.

    Returns:
        mo_occ_kpts : list of 1D arrays with entries 0, 1 or 2.
    '''
    if mo_energy_kpts is None: mo_energy_kpts = mf.mo_energy
    if getattr(mo_energy_kpts[0], 'mo_ea', None) is not None:
        mo_ea_kpts = [x.mo_ea for x in mo_energy_kpts]
        mo_eb_kpts = [x.mo_eb for x in mo_energy_kpts]
    else:
        mo_ea_kpts = mo_eb_kpts = mo_energy_kpts

    nocc_a, nocc_b = mf.nelec
    mo_energy_kpts1 = np.hstack(mo_energy_kpts)
    mo_energy = np.sort(mo_energy_kpts1)
    if nocc_b > 0:
        core_level = mo_energy[nocc_b-1]
    else:
        core_level = -1e9
    if nocc_a == nocc_b:
        fermi = core_level
    else:
        mo_ea_kpts1 = np.hstack(mo_ea_kpts)
        mo_ea = np.sort(mo_ea_kpts1[mo_energy_kpts1 > core_level])
        fermi = mo_ea[nocc_a - nocc_b - 1]

    mo_occ_kpts = []
    for k, mo_e in enumerate(mo_energy_kpts):
        occ = np.zeros_like(mo_e)
        occ[mo_e <= core_level] = 2
        if nocc_a != nocc_b:
            occ[(mo_e > core_level) & (mo_ea_kpts[k] <= fermi)] = 1
        mo_occ_kpts.append(occ)

    # The singly occupied orbitals are picked by alpha energies, so the
    # occupied orbitals need not be the lowest nocc_a Roothaan energies.
    # Report HOMO/LUMO from the occupancies that were actually assigned.
    mo_occ_all = np.hstack(mo_occ_kpts)
    e_occ = mo_energy_kpts1[mo_occ_all > 0]
    e_vir = mo_energy_kpts1[mo_occ_all == 0]
    if e_occ.size > 0 and e_vir.size > 0:
        logger.info(mf, 'HOMO = %.12g  LUMO = %.12g', e_occ.max(), e_vir.min())
    elif e_occ.size > 0:
        logger.info(mf, 'HOMO = %.12g', e_occ.max())

    # Print whole orbital-energy arrays in the debug output.  The context
    # manager restores the caller's numpy print options afterwards instead of
    # resetting them to a hard-coded value.
    with np.printoptions(threshold=len(mo_energy)):
        if mf.verbose >= logger.DEBUG:
            logger.debug(mf, '                  Roothaan           | alpha              | beta')
            for k,kpt in enumerate(mf.cell.get_scaled_kpts(mf.kpts)):
                core_idx = mo_occ_kpts[k] == 2
                open_idx = mo_occ_kpts[k] == 1
                vir_idx = mo_occ_kpts[k] == 0
                logger.debug(mf, '  kpt %2d (%6.3f %6.3f %6.3f)',
                             k, kpt[0], kpt[1], kpt[2])
                if np.count_nonzero(core_idx) > 0:
                    logger.debug(mf, '  Highest 2-occ = %18.15g | %18.15g | %18.15g',
                                 max(mo_energy_kpts[k][core_idx]),
                                 max(mo_ea_kpts[k][core_idx]), max(mo_eb_kpts[k][core_idx]))
                if np.count_nonzero(vir_idx) > 0:
                    logger.debug(mf, '  Lowest 0-occ =  %18.15g | %18.15g | %18.15g',
                                 min(mo_energy_kpts[k][vir_idx]),
                                 min(mo_ea_kpts[k][vir_idx]), min(mo_eb_kpts[k][vir_idx]))
                for i in np.where(open_idx)[0]:
                    logger.debug(mf, '  1-occ =         %18.15g | %18.15g | %18.15g',
                                 mo_energy_kpts[k][i], mo_ea_kpts[k][i], mo_eb_kpts[k][i])

            logger.debug(mf, '     k-point                  Roothaan mo_energy')
            for k,kpt in enumerate(mf.cell.get_scaled_kpts(mf.kpts)):
                logger.debug(mf, '  %2d (%6.3f %6.3f %6.3f)   %s %s',
                             k, kpt[0], kpt[1], kpt[2],
                             mo_energy_kpts[k][mo_occ_kpts[k]> 0],
                             mo_energy_kpts[k][mo_occ_kpts[k]==0])

        if mf.verbose >= logger.DEBUG1:
            logger.debug1(mf, '     k-point                  alpha mo_energy')
            for k,kpt in enumerate(mf.cell.get_scaled_kpts(mf.kpts)):
                logger.debug1(mf, '  %2d (%6.3f %6.3f %6.3f)   %s %s',
                              k, kpt[0], kpt[1], kpt[2],
                              mo_ea_kpts[k][mo_occ_kpts[k]> 0],
                              mo_ea_kpts[k][mo_occ_kpts[k]==0])
            logger.debug1(mf, '     k-point                  beta  mo_energy')
            for k,kpt in enumerate(mf.cell.get_scaled_kpts(mf.kpts)):
                logger.debug1(mf, '  %2d (%6.3f %6.3f %6.3f)   %s %s',
                              k, kpt[0], kpt[1], kpt[2],
                              mo_eb_kpts[k][mo_occ_kpts[k]==2],
                              mo_eb_kpts[k][mo_occ_kpts[k]!=2])

    return mo_occ_kpts
206
207
# Energy, dipole-moment and density helpers are shared with the k-point UHF
# module.
energy_elec, dip_moment, get_rho = (kuhf.energy_elec, kuhf.dip_moment,
                                    kuhf.get_rho)
211
212
@lib.with_doc(khf.mulliken_meta.__doc__)
def mulliken_meta(cell, dm_ao_kpts, verbose=logger.DEBUG,
                  pre_orth_method=PRE_ORTH_METHOD, s=None):
    '''Mulliken population analysis, based on meta-Lowdin AOs.

    Note this function only computes the Mulliken population for the gamma
    point density matrix.
    '''
    # Combine the alpha and beta densities, then delegate to the khf analysis.
    dm_alpha, dm_beta = dm_ao_kpts[0], dm_ao_kpts[1]
    return khf.mulliken_meta(cell, dm_alpha + dm_beta, verbose,
                             pre_orth_method, s)
223
224
def canonicalize(mf, mo_coeff_kpts, mo_occ_kpts, fock=None):
    '''Canonicalization diagonalizes the ROHF Fock matrix within occupied,
    virtual subspaces separately (without changing occupancy).

    At each k-point, the orbitals are grouped by occupation number.  The
    groups are closed shell (2), open shell (1), virtual (0) and any other
    distinct value.  The Fock matrix is diagonalized inside each group.  The
    rotations never mix orbitals of different occupation, so the density
    matrix is unchanged.

    Args:
        mf : KROHF object; used only to build the Fock matrix when ``fock``
            is None.
        mo_coeff_kpts : list of per-k-point orbital coefficient matrices
            (orbitals in columns).
        mo_occ_kpts : list of per-k-point occupation arrays.
        fock : per-k-point Fock matrices (optional).  If it carries
            ``focka``/``fockb`` tags, the returned energies are tagged with
            the alpha/beta expectation values ``mo_ea``/``mo_eb``.

    Returns:
        (mo_energy, mo_coeff) : lists over k-points.
    '''
    if fock is None:
        dm = mf.make_rdm1(mo_coeff_kpts, mo_occ_kpts)
        # The Roothaan Fock construction needs the overlap; pass it explicitly
        fock = mf.get_fock(s1e=mf.get_ovlp(), dm=dm)

    has_spin_fock = getattr(fock, 'focka', None) is not None
    mo_coeff = []
    mo_energy = []
    for k, mo in enumerate(mo_coeff_kpts):
        mo_occ = np.asarray(mo_occ_kpts[k])
        mo1 = np.empty_like(mo)
        # Eigenvalues are real floats.  Do not inherit a (possibly integer)
        # dtype from mo_occ.
        mo_e = np.empty(mo_occ.shape)
        # Every orbital belongs to exactly one occupation group, so every
        # column of mo1 and every entry of mo_e is filled below.
        for occ_value in np.unique(mo_occ):
            idx = mo_occ == occ_value
            orb = mo[:,idx]
            f1 = reduce(np.dot, (orb.T.conj(), fock[k], orb))
            e, c = scipy.linalg.eigh(f1)
            mo1[:,idx] = np.dot(orb, c)
            mo_e[idx] = e
        if has_spin_fock:
            fa, fb = fock.focka[k], fock.fockb[k]
            mo_ea = np.einsum('pi,pi->i', mo1.conj(), fa.dot(mo1)).real
            mo_eb = np.einsum('pi,pi->i', mo1.conj(), fb.dot(mo1)).real
            mo_e = lib.tag_array(mo_e, mo_ea=mo_ea, mo_eb=mo_eb)
        mo_coeff.append(mo1)
        mo_energy.append(mo_e)
    return mo_energy, mo_coeff
256
# Chkfile-based initial guess, shared with the k-point UHF implementation.
# KROHF.init_guess_by_chkfile forwards to it.
init_guess_by_chkfile = kuhf.init_guess_by_chkfile


class KROHF(khf.KRHF, pbcrohf.ROHF):
    '''ROHF class with k-point sampling.

    At each k-point, alpha and beta electrons share one set of orbitals.
    Occupation numbers are 2 (closed shell), 1 (open shell) or 0 (virtual).
    The SCF iterations use the Roothaan effective Fock matrix.
    '''
    conv_tol = getattr(__config__, 'pbc_scf_KSCF_conv_tol', 1e-7)
    conv_tol_grad = getattr(__config__, 'pbc_scf_KSCF_conv_tol_grad', None)
    direct_scf = getattr(__config__, 'pbc_scf_SCF_direct_scf', True)

    def __init__(self, cell, kpts=None,
                 exxdiv=getattr(__config__, 'pbc_scf_SCF_exxdiv', 'ewald')):
        # kpts: array of k-point coordinates, one row per k-point.  The
        # default is the Gamma point.  A fresh array is created per call
        # rather than sharing one default array object among all instances.
        if kpts is None:
            kpts = np.zeros((1,3))
        khf.KSCF.__init__(self, cell, kpts, exxdiv)
        self.nelec = None

    @property
    def nelec(self):
        '''(nalpha, nbeta) electron numbers for all k-points together.

        A value assigned to ``nelec`` takes precedence.  Otherwise the
        numbers are derived from ``cell.tot_electrons(nkpts)`` and
        ``cell.spin`` (= Nalpha - Nbeta).
        '''
        if self._nelec is not None:
            return self._nelec
        else:
            cell = self.cell
            nkpts = len(self.kpts)
            ne = cell.tot_electrons(nkpts)
            nalpha = (ne + cell.spin) // 2
            nbeta = nalpha - cell.spin
            if nalpha + nbeta != ne:
                raise RuntimeError('Electron number %d and spin %d are not consistent\n'
                                   'Note cell.spin = 2S = Nalpha - Nbeta, not 2S+1' %
                                   (ne, cell.spin))
            return nalpha, nbeta

    @nelec.setter
    def nelec(self, x):
        self._nelec = x

    def dump_flags(self, verbose=None):
        khf.KSCF.dump_flags(self, verbose)
        # self.nelec counts the electrons of all nkpts cells (see the
        # Ne/cell = ne/nkpts report in get_init_guess).  Divide by nkpts to
        # print what the message advertises.
        nkpts = len(self.kpts)
        nalpha, nbeta = self.nelec
        logger.info(self, 'number of electrons per unit cell  '
                    'alpha = %g beta = %g',
                    nalpha / float(nkpts), nbeta / float(nkpts))
        return self

    def get_init_guess(self, cell=None, key='minao'):
        '''Initial density matrices for all k-points.

        ``key`` (case-insensitive) selects '1e'/'hcore', 'atom', 'chk...'
        (falling back to MINAO when the chkfile cannot be read) or MINAO for
        any other value.  A cell without atoms always uses the 1e guess.
        Gamma-point guesses dm[spin,nao,nao] are replicated to every
        k-point.  The result is rescaled when its electron count deviates
        from ``sum(self.nelec)``.

        Returns:
            dm_kpts : (spin, nkpts, nao, nao) ndarray
        '''
        if cell is None:
            cell = self.cell
        dm_kpts = None
        key = key.lower()
        if key == '1e' or key == 'hcore':
            dm_kpts = self.init_guess_by_1e(cell)
        elif getattr(cell, 'natm', 0) == 0:
            logger.info(self, 'No atom found in cell. Use 1e initial guess')
            dm_kpts = self.init_guess_by_1e(cell)
        elif key == 'atom':
            dm = self.init_guess_by_atom(cell)
        elif key[:3] == 'chk':
            try:
                dm_kpts = self.from_chk()
            except (IOError, KeyError):
                logger.warn(self, 'Fail to read %s. Use MINAO initial guess',
                            self.chkfile)
                dm = self.init_guess_by_minao(cell)
        else:
            dm = self.init_guess_by_minao(cell)

        if dm_kpts is None:
            nkpts = len(self.kpts)
            # dm[spin,nao,nao] at gamma point -> dm_kpts[spin,nkpts,nao,nao]
            dm_kpts = np.repeat(dm[:,None,:,:], nkpts, axis=1)

        # Electron count summed over spins and k-points: sum Tr(D S)
        ne = np.einsum('xkij,kji->', dm_kpts, self.get_ovlp(cell)).real
        # FIXME: consider the fractional num_electron or not? This maybe
        # relates to the charged system.
        nkpts = len(self.kpts)
        nelec = float(sum(self.nelec))
        if np.any(abs(ne - nelec) > 1e-7*nkpts):
            logger.debug(self, 'Big error detected in the electron number '
                         'of initial guess density matrix (Ne/cell = %g)!\n'
                         '  This can cause huge error in Fock matrix and '
                         'lead to instability in SCF for low-dimensional '
                         'systems.\n  DM is normalized wrt the number '
                         'of electrons %g', ne/nkpts, nelec/nkpts)
            dm_kpts *= nelec / ne
        return dm_kpts

    init_guess_by_minao  = pbcrohf.ROHF.init_guess_by_minao
    init_guess_by_atom   = pbcrohf.ROHF.init_guess_by_atom
    init_guess_by_huckel = pbcrohf.ROHF.init_guess_by_huckel

    get_rho = get_rho

    get_fock = get_fock
    get_occ = get_occ
    energy_elec = energy_elec

    def get_veff(self, cell=None, dm_kpts=None, dm_last=0, vhf_last=0, hermi=1,
                 kpts=None, kpts_band=None):
        '''Alpha and beta Hartree-Fock potentials for all k-points.

        Returns ``vj[0] + vj[1] - vk`` from ``get_jk`` on the (alpha, beta)
        density matrices.  When both ``rsjk`` and ``direct_scf`` are set, only
        the density change ``dm_kpts - dm_last`` is contracted, and the result
        is added to ``vhf_last``.
        '''
        if dm_kpts is None:
            dm_kpts = self.make_rdm1()
        if getattr(dm_kpts, 'mo_coeff', None) is not None:
            # Re-express the ROHF orbitals/occupations in the (alpha, beta)
            # form carried by the density-matrix tags.
            mo_coeff = dm_kpts.mo_coeff
            mo_occ_a = [(x > 0).astype(np.double) for x in dm_kpts.mo_occ]
            mo_occ_b = [(x ==2).astype(np.double) for x in dm_kpts.mo_occ]
            dm_kpts = lib.tag_array(dm_kpts, mo_coeff=(mo_coeff,mo_coeff),
                                    mo_occ=(mo_occ_a,mo_occ_b))
        if self.rsjk and self.direct_scf:
            ddm = dm_kpts - dm_last
            vj, vk = self.get_jk(cell, ddm, hermi, kpts, kpts_band)
            vhf = vj[0] + vj[1] - vk
            vhf += vhf_last
        else:
            vj, vk = self.get_jk(cell, dm_kpts, hermi, kpts, kpts_band)
            vhf = vj[0] + vj[1] - vk
        return vhf

    def get_grad(self, mo_coeff_kpts, mo_occ_kpts, fock=None):
        '''Orbital-rotation gradients of all k-points, joined with np.hstack.

        ``fock`` may be a Fock tagged with ``focka``/``fockb``, a 4D stack
        (focka, fockb), or a single Fock used for both spins.  When it is
        None, h_core + V_eff is built from the given orbitals.
        '''
        if fock is None:
            dm1 = self.make_rdm1(mo_coeff_kpts, mo_occ_kpts)
            fock = self.get_hcore(self.cell, self.kpts) + self.get_veff(self.cell, dm1)

        if getattr(fock, 'focka', None) is not None:
            focka = fock.focka
            fockb = fock.fockb
        elif getattr(fock, 'ndim', None) == 4:
            focka, fockb = fock
        else:
            focka = fockb = fock

        def grad(k):
            mo_occ = mo_occ_kpts[k]
            mo_coeff = mo_coeff_kpts[k]
            return pbcrohf.get_grad(mo_coeff, mo_occ, (focka[k], fockb[k]))

        nkpts = len(self.kpts)
        grad_kpts = np.hstack([grad(k) for k in range(nkpts)])
        return grad_kpts

    def eig(self, fock, s):
        '''Diagonalize the Fock matrix at every k-point.

        When ``fock`` carries ``focka``/``fockb`` tags, each energy array is
        tagged with the alpha/beta expectation values ``mo_ea``/``mo_eb``.
        '''
        e, c = khf.KSCF.eig(self, fock, s)
        if getattr(fock, 'focka', None) is not None:
            for k, mo in enumerate(c):
                fa, fb = fock.focka[k], fock.fockb[k]
                mo_ea = np.einsum('pi,pi->i', mo.conj(), fa.dot(mo)).real
                mo_eb = np.einsum('pi,pi->i', mo.conj(), fb.dot(mo)).real
                e[k] = lib.tag_array(e[k], mo_ea=mo_ea, mo_eb=mo_eb)
        return e, c

    def make_rdm1(self, mo_coeff_kpts=None, mo_occ_kpts=None, **kwargs):
        '''(alpha, beta) density matrices; defaults to self.mo_coeff/mo_occ.'''
        if mo_coeff_kpts is None: mo_coeff_kpts = self.mo_coeff
        if mo_occ_kpts is None: mo_occ_kpts = self.mo_occ
        return make_rdm1(mo_coeff_kpts, mo_occ_kpts, **kwargs)

    def init_guess_by_chkfile(self, chk=None, project=True, kpts=None):
        '''Initial guess read from ``chk`` (default ``self.chkfile``).'''
        if chk is None: chk = self.chkfile
        if kpts is None: kpts = self.kpts
        return init_guess_by_chkfile(self.cell, chk, project, kpts)

    def analyze(self, verbose=None, with_meta_lowdin=WITH_META_LOWDIN,
                **kwargs):
        if verbose is None: verbose = self.verbose
        return khf.analyze(self, verbose, with_meta_lowdin, **kwargs)

    def mulliken_meta(self, cell=None, dm=None, verbose=logger.DEBUG,
                      pre_orth_method=PRE_ORTH_METHOD, s=None):
        '''Meta-Lowdin Mulliken analysis of the spin-summed density.'''
        if cell is None: cell = self.cell
        if dm is None: dm = self.make_rdm1()
        if s is None: s = self.get_ovlp(cell)
        return mulliken_meta(cell, dm, s=s, verbose=verbose,
                             pre_orth_method=pre_orth_method)

    @lib.with_doc(dip_moment.__doc__)
    def dip_moment(self, cell=None, dm=None, unit='Debye', verbose=logger.NOTE,
                   **kwargs):
        if cell is None: cell = self.cell
        if dm is None: dm = self.make_rdm1()
        rho = kwargs.pop('rho', None)
        if rho is None:
            rho = self.get_rho(dm)
        return dip_moment(cell, dm, unit, verbose, rho=rho, kpts=self.kpts, **kwargs)

    spin_square = pbcrohf.ROHF.spin_square

    canonicalize = canonicalize

    def stability(self,
                  internal=getattr(__config__, 'pbc_scf_KSCF_stability_internal', True),
                  external=getattr(__config__, 'pbc_scf_KSCF_stability_external', False),
                  verbose=None):
        raise NotImplementedError

    def convert_from_(self, mf):
        '''Convert given mean-field object to KROHF'''
        addons.convert_to_rhf(mf, self)
        return self
456
# These names were only needed as default argument values above; remove them
# from the module namespace.
del WITH_META_LOWDIN, PRE_ORTH_METHOD
458
459
if __name__ == '__main__':
    from pyscf.pbc import gto
    cell = gto.Cell()
    cell.atom = '''
    He 0 0 1
    He 1 0 1
    '''
    cell.basis = '321g'
    cell.a = np.eye(3) * 3
    cell.mesh = [11] * 3
    cell.verbose = 5
    cell.spin = 2
    cell.build()
    # KROHF takes k-point coordinates (cf. its (1,3) Gamma-point default),
    # not mesh dimensions.  Build a uniform 2x1x1 k-point mesh explicitly
    # instead of passing [2,1,1], which would be read as one k-point.
    mf = KROHF(cell, cell.make_kpts([2,1,1]))
    mf.kernel()
    mf.analyze()
476
477