1#!/usr/bin/env python
2# Copyright 2014-2021 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Tianyu Zhu <zhutianyu1991@gmail.com>
17#
18
19'''
20PBC spin-restricted G0W0-AC QP eigenvalues with k-point sampling
This implementation has N^4 scaling, and is faster than GW-CD (also N^4)
and analytic GW (N^6) methods.
GW-AC is recommended for valence states only, and is inaccurate for core states.
24
25Method:
26    See T. Zhu and G.K.-L. Chan, arxiv:2007.03148 (2020) for details
    Compute Sigma on the imaginary frequency axis with density fitting,
    then analytically continue it to real frequency.
29    Gaussian density fitting must be used (FFTDF and MDF are not supported).
30'''
31
32from functools import reduce
33import numpy
34import numpy as np
35import h5py
36from scipy.optimize import newton, least_squares
37
38from pyscf import lib
39from pyscf.lib import logger
40from pyscf.ao2mo import _ao2mo
41from pyscf.ao2mo.incore import _conc_mos
42from pyscf.pbc import df, dft, scf
43from pyscf.pbc.mp.kmp2 import get_nocc, get_nmo, get_frozen_mask
44from pyscf import __config__
45
# Module-level alias: the tensor contractions in the functions below go
# through pyscf.lib.einsum under this name.
einsum = lib.einsum
47
def kernel(gw, mo_energy, mo_coeff, orbs=None,
           kptlist=None, nw=None, verbose=logger.NOTE):
    '''GW-corrected quasiparticle orbital energies

    Args:
        gw : KRGWAC
            GW object.  Supplies the mean-field object (``gw._scf``), the
            GDF object (``gw.with_df``), the k-points and the ``ac``,
            ``fc`` and ``linearized`` options.
        mo_energy : ndarray
            Not read: the QP equations are solved starting from
            ``gw._scf.mo_energy``.
        mo_coeff : ndarray, (nkpts, nao, nmo)
            MO coefficients used to transform v_xc and the exchange matrix
            to the MO basis.
        orbs : sequence of int, optional
            Orbitals whose QP energies are computed (default: all).  They
            need not be contiguous, but occupied indices must come before
            virtual ones (the self-energy code relies on that ordering).
        kptlist : sequence of int, optional
            k-points whose QP energies are computed (default: all).
        nw : int
            Number of imaginary-frequency quadrature points.
        verbose : int
            Not read; output is controlled by ``gw.verbose``.

    Returns:
        A list :  converged, mo_energy, mo_coeff
        ``mo_energy`` has the shape of ``gw._scf.mo_energy``; entries for
        orbitals/k-points that were not requested, or whose QP equation
        failed to converge, are left at zero.  ``mo_coeff`` is
        ``gw._scf.mo_coeff``.

    Raises:
        ValueError: if ``gw.ac`` is neither 'pade' nor 'twopole'.
        RuntimeError: if ``gw.with_df`` carries no GDF integrals.
    '''
    mf = gw._scf
    if gw.frozen is None:
        frozen = 0
    else:
        frozen = gw.frozen
    assert (frozen == 0)

    if orbs is None:
        orbs = range(gw.nmo)
    if kptlist is None:
        kptlist = range(gw.nkpts)
    nkpts = gw.nkpts
    nklist = len(kptlist)

    # Validate the AC model before the expensive self-energy calculation;
    # otherwise an unknown value only fails later on an unbound name.
    if gw.ac not in ('twopole', 'pade'):
        raise ValueError("Unknown analytic continuation method %r; "
                         "use 'pade' or 'twopole'" % (gw.ac,))

    # v_xc: mean-field potential minus its Coulomb part, in the MO basis
    dm = np.array(mf.make_rdm1())
    v_mf = np.array(mf.get_veff()) - np.array(mf.get_j(dm_kpts=dm))
    for k in range(nkpts):
        v_mf[k] = reduce(numpy.dot, (mo_coeff[k].T.conj(), v_mf[k], mo_coeff[k]))

    nocc = gw.nocc
    nmo = gw.nmo

    # v_hf (exact exchange) from DFT/HF density; Ewald exxdiv is used
    # together with the finite-size corrections
    exxdiv = 'ewald' if gw.fc else None
    rhf = scf.KRHF(gw.mol, gw.kpts, exxdiv=exxdiv)
    rhf.with_df = gw.with_df
    if getattr(gw.with_df, '_cderi', None) is None:
        raise RuntimeError('Found incompatible integral scheme %s. '
                           'KGWAC can be only used with GDF integrals' %
                           gw.with_df.__class__)

    vk = rhf.get_veff(gw.mol,dm_kpts=dm) - rhf.get_j(gw.mol,dm_kpts=dm)
    for k in range(nkpts):
        vk[k] = reduce(numpy.dot, (mo_coeff[k].T.conj(), vk[k], mo_coeff[k]))

    # Grids for integration on imaginary axis
    freqs, wts = _get_scaled_legendre_roots(nw)

    # Compute self-energy on imaginary axis i*[0,iw_cutoff]
    sigmaI, omega = get_sigma_diag(gw, orbs, kptlist, freqs, wts, iw_cutoff=5.)

    # Analytic continuation: coeff[k, :, ip] holds the fit parameters of
    # orbital orbs[ip] at k-point kptlist[k]
    coeff = []
    if gw.ac == 'twopole':
        for k in range(nklist):
            coeff.append(AC_twopole_diag(sigmaI[k], omega, orbs, nocc))
    elif gw.ac == 'pade':
        for k in range(nklist):
            coeff_tmp, omega_fit = AC_pade_thiele_diag(sigmaI[k], omega)
            coeff.append(coeff_tmp)
    coeff = np.array(coeff)

    conv = True
    # This code does not support metals.
    # ef: midgap energy over all k-points; the AC frequencies are relative to it
    homo = -99.
    lumo = 99.
    for k in range(nkpts):
        if homo < mf.mo_energy[k][nocc-1]:
            homo = mf.mo_energy[k][nocc-1]
        if lumo > mf.mo_energy[k][nocc]:
            lumo = mf.mo_energy[k][nocc]
    ef = (homo+lumo)/2.

    mo_energy = np.zeros_like(np.array(mf.mo_energy))
    for k, kn in enumerate(kptlist):
        # ip is the position of p within orbs, i.e. its column in sigmaI and
        # coeff; p - orbs[0] would be wrong for non-contiguous orbs.
        for ip, p in enumerate(orbs):
            if gw.linearized:
                # linearized G0W0
                de = 1e-6
                ep = mf.mo_energy[kn][p]
                #TODO: analytic sigma derivative
                if gw.ac == 'twopole':
                    sigmaR = two_pole(ep-ef, coeff[k,:,ip]).real
                    dsigma = two_pole(ep-ef+de, coeff[k,:,ip]).real - sigmaR
                else:
                    sigmaR = pade_thiele(ep-ef, omega_fit[ip], coeff[k,:,ip]).real
                    dsigma = pade_thiele(ep-ef+de, omega_fit[ip], coeff[k,:,ip]).real - sigmaR
                # renormalization factor Z = 1 / (1 - dSigma/dw), by finite difference
                zn = 1.0/(1.0-dsigma/de)
                e = ep + zn*(sigmaR + vk[kn,p,p].real - v_mf[kn,p,p].real)
                mo_energy[kn,p] = e
            else:
                # self-consistently solve QP equation
                def quasiparticle(omega):
                    if gw.ac == 'twopole':
                        sigmaR = two_pole(omega-ef, coeff[k,:,ip]).real
                    else:
                        sigmaR = pade_thiele(omega-ef, omega_fit[ip], coeff[k,:,ip]).real
                    return omega - mf.mo_energy[kn][p] - (sigmaR + vk[kn,p,p].real - v_mf[kn,p,p].real)
                try:
                    e = newton(quasiparticle, mf.mo_energy[kn][p], tol=1e-6, maxiter=100)
                    mo_energy[kn,p] = e
                except RuntimeError:
                    # newton did not converge; mo_energy[kn,p] stays 0
                    conv = False
    mo_coeff = mf.mo_coeff

    if gw.verbose >= logger.DEBUG:
        # Print full arrays; the context manager restores the caller's
        # numpy print options afterwards instead of resetting them.
        with np.printoptions(threshold=nmo):
            for k in range(nkpts):
                logger.debug(gw, '  GW mo_energy @ k%d =\n%s', k,mo_energy[k])

    return conv, mo_energy, mo_coeff
162
def get_rho_response(gw, omega, mo_energy, Lpq, kL, kidx):
    '''
    Compute density response function in auxiliary basis at freq iw

    Args:
        gw : object
            Reads ``gw.nocc`` (number of occupied orbitals) and
            ``gw.kpts`` (one loop iteration per k-point).
        omega : float
            Imaginary frequency; the response is evaluated at i*omega.
        mo_energy : ndarray, (nkpts, nmo)
            Orbital energies per k-point.
        Lpq : ndarray, (nkpts, naux, nmo, nmo)
            DF integrals; ``Lpq[ki]`` couples ki with ``kidx[ki]``.
        kL : int
            Auxiliary k-point index.  Not read here: the (ki, ka) pairing
            is already encoded in ``kidx``.
        kidx : sequence of int
            ``kidx[ki]`` is the ka conserving momentum with ki and kL.

    Returns:
        Pi : ndarray, (naux, naux), complex
            Pi[P, Q] = 4/nkpts * sum_{k,i,a} L_Pia f_ia conj(L_Qia), with
            f_ia = e_ia / (omega**2 + e_ia**2) and e_ia = e_ik - e_a,kidx[k].
    '''
    nkpts, naux, nmo, _ = Lpq.shape
    nocc = gw.nocc

    # Compute Pi for kL
    Pi = np.zeros((naux,naux),dtype=np.complex128)
    for i in range(len(gw.kpts)):
        # Find ka that conserves with ki and kL (-ki+ka+kL=G)
        a = kidx[i]
        eia = mo_energy[i,:nocc,None] - mo_energy[a,None,nocc:]
        eia = eia/(omega**2+eia*eia)
        Lia = Lpq[i][:,:nocc,nocc:]
        Pia = einsum('Pia,ia->Pia',Lia,eia)
        # Response from both spin-up and spin-down density
        Pi += 4./nkpts * einsum('Pia,Qia->PQ',Pia,Lia.conj())
    return Pi
184
def get_sigma_diag(gw, orbs, kptlist, freqs, wts, iw_cutoff=None, max_memory=8000):
    '''
    Compute GW correlation self-energy (diagonal elements)
    in MO basis on imaginary axis

    Sigma is accumulated by quadrature over the imaginary-frequency grid
    (freqs, wts), looping over every auxiliary k-point kL.  When ``gw.fc``
    is set, the q -> 0 head and wing finite-size corrections are added in
    the kL = 0 term.

    Args:
        gw : KRGWAC
            Reads ``_scf.mo_energy``, ``_scf.mo_coeff``, ``_scf.max_memory``,
            ``nocc``, ``nmo``, ``nkpts``, ``kpts``, ``mol``, ``with_df`` and
            ``fc``.
        orbs : sequence of int
            Orbitals for which Sigma is computed.  The first
            ``len([i for i in orbs if i < nocc])`` entries are treated as
            occupied, so occupied indices must come first.
        kptlist : sequence of int
            k-points for which Sigma is computed.
        freqs, wts : 1D ndarray
            Quadrature points and weights on [0, inf).  ``freqs`` is assumed
            ascending: its first ``nw_sigma - 1`` entries are reused as the
            points where Sigma is evaluated.
        iw_cutoff : float, optional
            Sigma is evaluated at iw = 0 plus every ``freqs`` point below
            this value.  If None, all points are used.
        max_memory : int
            Not read.

    Returns:
        sigma : ndarray, (nklist, norbs, nw_sigma), complex
            Correlation self-energy evaluated at ``omega + ef``, where ``ef``
            is the midgap energy of ``gw._scf``.
        omega : ndarray, (norbs, nw_sigma), complex
            Evaluation frequencies relative to ``ef``: 0 followed by
            -1j*freqs for occupied orbitals and +1j*freqs for virtual ones.
    '''
    mo_energy = np.array(gw._scf.mo_energy)
    mo_coeff = np.array(gw._scf.mo_coeff)
    nocc = gw.nocc
    nmo = gw.nmo
    nkpts = gw.nkpts
    kpts = gw.kpts
    nklist = len(kptlist)
    nw = len(freqs)
    norbs = len(orbs)
    mydf = gw.with_df

    # possible kpts shift center
    # (scaled k-points relative to the first one; used in the momentum
    # conservation test below)
    kscaled = gw.mol.get_scaled_kpts(kpts)
    kscaled -= kscaled[0]

    # This code does not support metals
    # ef is the midgap energy; all Sigma frequencies are measured from it
    homo = -99.
    lumo = 99.
    for k in range(nkpts):
        if homo < mo_energy[k][nocc-1]:
            homo = mo_energy[k][nocc-1]
        if lumo > mo_energy[k][nocc]:
            lumo = mo_energy[k][nocc]
    if (lumo-homo)<1e-3:
        logger.warn(gw, 'This GW-AC code is not supporting metals!')
    ef = (homo+lumo)/2.

    # Integration on numerical grids
    # nw_sigma = 1 (the iw = 0 point) + number of grid points below iw_cutoff
    if iw_cutoff is not None:
        nw_sigma = sum(iw < iw_cutoff for iw in freqs) + 1
    else:
        nw_sigma = nw + 1

    # Compute occ for -iw and vir for iw separately
    # to avoid branch cuts in analytic continuation
    # (freqs[:nw_sigma-1] are the points below iw_cutoff only if freqs is ascending)
    omega_occ = np.zeros((nw_sigma), dtype=np.complex128)
    omega_vir = np.zeros((nw_sigma), dtype=np.complex128)
    omega_occ[1:] = -1j*freqs[:(nw_sigma-1)]
    omega_vir[1:] = 1j*freqs[:(nw_sigma-1)]
    orbs_occ = [i for i in orbs if i < nocc]
    norbs_occ = len(orbs_occ)

    # emo_*[k, m, w] = omega[w] + ef - e_{mk}: the G0 frequency argument for
    # every MO at every Sigma evaluation point
    emo_occ = np.zeros((nkpts,nmo,nw_sigma),dtype=np.complex128)
    emo_vir = np.zeros((nkpts,nmo,nw_sigma),dtype=np.complex128)
    for k in range(nkpts):
        emo_occ[k] = omega_occ[None,:] + ef - mo_energy[k][:,None]
        emo_vir[k] = omega_vir[None,:] + ef - mo_energy[k][:,None]

    sigma = np.zeros((nklist,norbs,nw_sigma),dtype=np.complex128)
    omega = np.zeros((norbs,nw_sigma),dtype=np.complex128)
    for p in range(norbs):
        orbp = orbs[p]
        if orbp < nocc:
            omega[p] = omega_occ.copy()
        else:
            omega[p] = omega_vir.copy()

    if gw.fc:
        # Set up q mesh for q->0 finite size correction
        q_pts = np.array([1e-3,0,0]).reshape(1,3)
        q_abs = gw.mol.get_abs_kpts(q_pts)

        # Get qij = 1/sqrt(Omega) * < psi_{ik} | e^{iqr} | psi_{ak-q} > at q: (nkpts, nocc, nvir)
        qij = get_qij(gw, q_abs[0], mo_coeff)

    for kL in range(nkpts):
        # Lij: (ki, L, i, j) for looping every kL
        Lij = []
        # kidx: save kj that conserves with kL and ki (-ki+kj+kL=G)
        # kidx_r: save ki that conserves with kL and kj (-ki+kj+kL=G)
        # NOTE: Lij is appended once per conserving (ki, kj) pair and later
        # indexed by ki, which assumes exactly one kj matches each ki.
        kidx = np.zeros((nkpts),dtype=np.int64)
        kidx_r = np.zeros((nkpts),dtype=np.int64)
        for i, kpti in enumerate(kpts):
            for j, kptj in enumerate(kpts):
                # Find (ki,kj) that satisfies momentum conservation with kL
                kconserv = -kscaled[i] + kscaled[j] + kscaled[kL]
                is_kconserv = np.linalg.norm(np.round(kconserv) - kconserv) < 1e-12
                if is_kconserv:
                    kidx[i] = j
                    kidx_r[j] = i
                    logger.debug(gw, "Read Lpq (kL: %s / %s, ki: %s, kj: %s)"%(kL+1, nkpts, i, j))
                    Lij_out = None
                    # Read (L|pq) and ao2mo transform to (L|ij)
                    Lpq = []
                    for LpqR, LpqI, sign \
                            in mydf.sr_loop([kpti, kptj], max_memory=0.1*gw._scf.max_memory, compact=False):
                        Lpq.append(LpqR+LpqI*1.0j)
                    # support unequal naux on different k points
                    # NOTE(review): np.asarray(Lij) below presumably still
                    # needs the same naux for every ki -- confirm
                    Lpq = np.vstack(Lpq).reshape(-1,nmo**2)
                    tao = []
                    ao_loc = None
                    moij, ijslice = _conc_mos(mo_coeff[i], mo_coeff[j])[2:]
                    Lij_out = _ao2mo.r_e2(Lpq, moij, ijslice, tao, ao_loc, out=Lij_out)
                    Lij.append(Lij_out.reshape(-1,nmo,nmo))
        Lij = np.asarray(Lij)
        naux = Lij.shape[1]

        # kL == 0 is the q -> 0 term: the only one that gets the head/wing
        # finite-size corrections when gw.fc is set
        if kL == 0:
            for w in range(nw):
                # body dielectric matrix eps_body
                Pi = get_rho_response(gw, freqs[w], mo_energy, Lij, kL, kidx)
                eps_body_inv = np.linalg.inv(np.eye(naux)-Pi)

                if gw.fc:
                    # head dielectric matrix eps_00
                    Pi_00 = get_rho_response_head(gw, freqs[w], mo_energy, qij)
                    eps_00 = 1. - 4. * np.pi/np.linalg.norm(q_abs[0])**2 * Pi_00

                    # wings dielectric matrix eps_P0
                    Pi_P0 = get_rho_response_wing(gw, freqs[w], mo_energy, Lij, qij)
                    eps_P0 = -np.sqrt(4.*np.pi) / np.linalg.norm(q_abs[0]) * Pi_P0

                    # inverse dielectric matrix
                    eps_inv_00 = 1./(eps_00 - np.dot(np.dot(eps_P0.conj(),eps_body_inv),eps_P0))
                    eps_inv_P0 = -eps_inv_00 * np.dot(eps_body_inv, eps_P0)

                    # head correction
                    Del_00 = 2./np.pi * (6.*np.pi**2/gw.mol.vol/nkpts)**(1./3.) * (eps_inv_00 - 1.)

                eps_inv_PQ = eps_body_inv
                # quadrature-weighted G0 factor for integration point freqs[w],
                # for every (k, m, Sigma frequency)
                g0_occ = wts[w] * emo_occ / (emo_occ**2+freqs[w]**2)
                g0_vir = wts[w] * emo_vir / (emo_vir**2+freqs[w]**2)

                for k in range(nklist):
                    kn = kptlist[k]
                    # Find km that conserves with kn and kL (-km+kn+kL=G)
                    km = kidx_r[kn]
                    # Wmn[m, n] = 1/nkpts sum_PQ conj(L_Pmn) (eps^-1 - 1)_PQ L_Qmn
                    Qmn = einsum('Pmn,PQ->Qmn',Lij[km][:,:,orbs].conj(),eps_inv_PQ-np.eye(naux))
                    Wmn = 1./nkpts * einsum('Qmn,Qmn->mn',Qmn,Lij[km][:,:,orbs])
                    # rows [:norbs_occ] are the occupied orbs (on the -iw grid)
                    sigma[k][:norbs_occ] += -einsum('mn,mw->nw',Wmn[:,:norbs_occ],g0_occ[km])/np.pi
                    sigma[k][norbs_occ:] += -einsum('mn,mw->nw',Wmn[:,norbs_occ:],g0_vir[km])/np.pi

                    if gw.fc:
                        # apply head correction
                        # (at kL = 0 every kn pairs with itself)
                        assert(kn == km)
                        sigma[k][:norbs_occ] += -Del_00 * g0_occ[kn][orbs][:norbs_occ] /np.pi
                        sigma[k][norbs_occ:] += -Del_00 * g0_vir[kn][orbs][norbs_occ:] /np.pi

                        # apply wing correction
                        # NOTE(review): the doubled real part presumably
                        # accounts for both P0 and 0P wings -- confirm
                        Wn_P0 = einsum('Pnm,P->nm',Lij[kn],eps_inv_P0).diagonal()
                        Wn_P0 = Wn_P0.real * 2.
                        Del_P0 = np.sqrt(gw.mol.vol/4./np.pi**3) * (6.*np.pi**2/gw.mol.vol/nkpts)**(2./3.) * Wn_P0[orbs]
                        sigma[k][:norbs_occ] += -einsum('n,nw->nw', Del_P0[:norbs_occ],
                                                        g0_occ[kn][orbs][:norbs_occ]) /np.pi
                        sigma[k][norbs_occ:] += -einsum('n,nw->nw', Del_P0[norbs_occ:],
                                                        g0_vir[kn][orbs][norbs_occ:]) /np.pi
        else:
            # kL != 0: body contribution only
            for w in range(nw):
                Pi = get_rho_response(gw, freqs[w], mo_energy, Lij, kL, kidx)
                # eps^-1 - 1
                Pi_inv = np.linalg.inv(np.eye(naux)-Pi)-np.eye(naux)
                g0_occ = wts[w] * emo_occ / (emo_occ**2+freqs[w]**2)
                g0_vir = wts[w] * emo_vir / (emo_vir**2+freqs[w]**2)
                for k in range(nklist):
                    kn = kptlist[k]
                    # Find km that conserves with kn and kL (-km+kn+kL=G)
                    km = kidx_r[kn]
                    Qmn = einsum('Pmn,PQ->Qmn',Lij[km][:,:,orbs].conj(),Pi_inv)
                    Wmn = 1./nkpts * einsum('Qmn,Qmn->mn',Qmn,Lij[km][:,:,orbs])
                    sigma[k][:norbs_occ] += -einsum('mn,mw->nw',Wmn[:,:norbs_occ],g0_occ[km])/np.pi
                    sigma[k][norbs_occ:] += -einsum('mn,mw->nw',Wmn[:,norbs_occ:],g0_vir[km])/np.pi

    return sigma, omega
352
def get_rho_response_head(gw, omega, mo_energy, qij):
    '''
    Compute head (G=0, G'=0) density response function in auxiliary basis at freq iw

    Args:
        gw : object
            Reads ``gw.nocc`` and ``gw.kpts`` (one term per k-point).
        omega : float
            Imaginary frequency.
        mo_energy : ndarray, (nkpts, nmo)
        qij : ndarray, (nkpts, nocc, nvir)
            q -> 0 occupied-virtual matrix elements (see ``get_qij``).

    Returns:
        complex scalar:
        4/nkpts * sum_{k,i,a} |qij[k,i,a]|^2 * e_ia / (omega**2 + e_ia**2),
        with e_ia = e_ik - e_ak.
    '''
    n_k = qij.shape[0]
    n_occ = gw.nocc

    head = 0j
    for ik in range(len(gw.kpts)):
        gap = mo_energy[ik,:n_occ,None] - mo_energy[ik,None,n_occ:]
        lorentz = gap/(omega**2+gap*gap)
        q_sq = qij[ik].conj()*qij[ik]
        head += 4./n_k * einsum('ia,ia->', lorentz, q_sq)
    return head
368
def get_rho_response_wing(gw, omega, mo_energy, Lpq, qij):
    '''
    Compute wing (G=P, G'=0) density response function in auxiliary basis at freq iw

    Args:
        gw : object
            Reads ``gw.nocc`` and ``gw.kpts`` (one term per k-point).
        omega : float
            Imaginary frequency.
        mo_energy : ndarray, (nkpts, nmo)
        Lpq : ndarray, (nkpts, naux, nmo, nmo)
            DF integrals at kL = 0 (occupied-virtual block is used).
        qij : ndarray, (nkpts, nocc, nvir)
            q -> 0 occupied-virtual matrix elements (see ``get_qij``).

    Returns:
        ndarray, (naux,), complex:
        4/nkpts * sum_{k,i,a} L_Pia * e_ia/(omega**2 + e_ia**2) * conj(qij[k,i,a]).
    '''
    n_k, n_aux, _, _ = Lpq.shape
    n_occ = gw.nocc

    wing = np.zeros(n_aux, dtype=np.complex128)
    for ik in range(len(gw.kpts)):
        gap = mo_energy[ik,:n_occ,None] - mo_energy[ik,None,n_occ:]
        weighted_q = gap/(omega**2+gap*gap) * qij[ik].conj()
        L_ov = Lpq[ik][:,:n_occ,n_occ:]
        wing += 4./n_k * einsum('Pia,ia->P', L_ov, weighted_q)
    return wing
385
def get_qij(gw, q, mo_coeff, uniform_grids=False):
    '''
    Compute qij = 1/sqrt(Omega) * < psi_{ik} | e^{iqr} | psi_{ak-q} > at q: (nkpts, nocc, nvir)
    through k.p perturbation theory
    Ref: Phys. Rev. B 83, 245122 (2011)

    For small q the matrix element is approximated as
        < psi_{ik} | -i q.grad | psi_{ak} > / (e_{ak} - e_{ik}),
    with the AO gradient integrals evaluated numerically on a real-space
    grid.  The result is the (un-squared) matrix element divided by
    sqrt(cell volume).

    Args:
        gw : KRGWAC
            Reads ``nocc``, ``nmo``, ``kpts``, ``mol`` (the cell) and
            ``_scf.mo_energy``.
        q : ndarray, (3,)
            Cartesian q vector, contracted with the x/y/z gradient components.
        mo_coeff : ndarray, (nkpts, nao, nmo)
            MO coefficients per k-point.
        uniform_grids : bool
            True: uniform grid on the FFTDF mesh with equal weights
            vol/ngrid.  False (default): Becke grids at level 5.

    Returns:
        qij : ndarray, (nkpts, nocc, nvir), complex
    '''
    nocc = gw.nocc
    nmo = gw.nmo
    nvir = nmo - nocc
    kpts = gw.kpts
    nkpts = len(kpts)
    cell = gw.mol
    mo_energy = gw._scf.mo_energy

    if uniform_grids:
        mydf = df.FFTDF(cell, kpts=kpts)
        coords = cell.gen_uniform_grids(mydf.mesh)
    else:
        coords, weights = dft.gen_grid.get_becke_grids(cell,level=5)
    ngrid = len(coords)

    qij = np.zeros((nkpts,nocc,nvir),dtype=np.complex128)
    for i, kpti in enumerate(kpts):
        ao_p = dft.numint.eval_ao(cell, coords, kpt=kpti, deriv=1)
        ao = ao_p[0]
        ao_grad = ao_p[1:4]
        # <mu| d/dx_j |nu> for j = x, y, z by numerical quadrature
        if uniform_grids:
            ao_ao_grad = einsum('mg,xgn->xmn',ao.T.conj(),ao_grad) * cell.vol / ngrid
        else:
            ao_ao_grad = einsum('g,mg,xgn->xmn',weights,ao.T.conj(),ao_grad)
        # <mu| -i q.grad |nu>
        q_ao_ao_grad = -1j * einsum('x,xmn->mn',q,ao_ao_grad)
        # occupied-virtual block in the MO basis at this k-point
        q_mo_mo_grad = np.dot(np.dot(mo_coeff[i][:,:nocc].T.conj(), q_ao_ao_grad), mo_coeff[i][:,nocc:])
        e_occ = mo_energy[i][:nocc]
        e_vir = mo_energy[i][nocc:]
        # 1 / (e_a - e_i), laid out as (nocc, nvir)
        inv_gap = 1./(e_vir[None,:] - e_occ[:,None])
        qij[i] = inv_gap * q_mo_mo_grad / np.sqrt(cell.vol)

    return qij
423
424def _get_scaled_legendre_roots(nw):
425    """
426    Scale nw Legendre roots, which lie in the
427    interval [-1, 1], so that they lie in [0, inf)
428    Ref: www.cond-mat.de/events/correl19/manuscripts/ren.pdf
429
430    Returns:
431        freqs : 1D ndarray
432        wts : 1D ndarray
433    """
434    freqs, wts = np.polynomial.legendre.leggauss(nw)
435    x0 = 0.5
436    freqs_new = x0*(1.+freqs)/(1.-freqs)
437    wts = wts*2.*x0/(1.-freqs)**2
438    return freqs_new, wts
439
440def _get_clenshaw_curtis_roots(nw):
441    """
442    Clenshaw-Curtis qaudrature on [0,inf)
443    Ref: J. Chem. Phys. 132, 234114 (2010)
444    Returns:
445        freqs : 1D ndarray
446        wts : 1D ndarray
447    """
448    freqs = np.zeros(nw)
449    wts = np.zeros(nw)
450    a = 0.2
451    for w in range(nw):
452        t = (w+1.0)/nw * np.pi/2.
453        freqs[w] = a / np.tan(t)
454        if w != nw-1:
455            wts[w] = a*np.pi/2./nw/(np.sin(t)**2)
456        else:
457            wts[w] = a*np.pi/4./nw/(np.sin(t)**2)
458    return freqs[::-1], wts[::-1]
459
def two_pole_fit(coeff, omega, sigma):
    '''
    Residual vector of the two-pole model, for scipy.optimize.least_squares.

    coeff holds 10 reals: real parts (coeff[:5]) and imaginary parts
    (coeff[5:]) of complex c0..c4 in
        model(omega) = c0 + c1/(omega + c3) + c2/(omega + c4).
    Returns the real parts of (model - sigma) followed by the imaginary
    parts; the first residual is up-weighted by 1/0.01.
    '''
    c0, c1, c2, c3, c4 = coeff[:5] + 1j*coeff[5:]
    resid = c0 + c1/(omega+c3) + c2/(omega+c4) - sigma
    resid[0] /= 0.01
    return np.concatenate((resid.real.ravel(), resid.imag.ravel()))
465
def two_pole(freqs, coeff):
    '''
    Evaluate the two-pole model c0 + c1/(freqs + c3) + c2/(freqs + c4),
    where c_k = coeff[k] + 1j*coeff[k+5] (same layout as two_pole_fit).
    '''
    offset, amp1, amp2, shift1, shift2 = coeff[:5] + 1j*coeff[5:]
    return offset + amp1/(freqs+shift1) + amp2/(freqs+shift2)
469
def AC_twopole_diag(sigma, omega, orbs, nocc):
    """
    Analytic continuation to real axis using a two-pole model

    Args:
        sigma : 2D array (norbs, nw)
            Self-energy on the imaginary-axis points ``omega``.
        omega : 2D array (norbs, nw)
            Frequencies at which ``sigma`` is given.
        orbs : sequence of int
            Orbital indices; ``orbs[p] < nocc`` selects the occupied
            initial guess for row p.
        nocc : int
            Number of occupied orbitals.

    Returns:
        coeff: 2D array (ncoeff, norbs)
            ncoeff = 10: real then imaginary parts of the two-pole
            parameters (see ``two_pole``).
    """
    norbs, _ = sigma.shape
    coeff = np.zeros((10,norbs))
    for p in range(norbs):
        # Initial guesses differ only in the sign of the imaginary parts of
        # the two shifts (c3, c4): negative for occupied, positive for virtual.
        if orbs[p] < nocc:
            x0 = np.array([0, 1, 1, 1, -1, 0, 0, 0, -1.0, -0.5])
        else:
            x0 = np.array([0, 1, 1, 1, -1, 0, 0, 0, 1.0, 0.5])
        #TODO: analytic gradient
        xopt = least_squares(two_pole_fit, x0, jac='3-point', method='trf', xtol=1e-10,
                             gtol=1e-10, max_nfev=1000, verbose=0, args=(omega[p], sigma[p]))
        # Truthiness test instead of `is False`, which would miss a
        # non-bool falsy flag (e.g. numpy.bool_).
        if not xopt.success:
            print('WARN: 2P-Fit Orb %d not converged, cost function %e'%(p,xopt.cost))
        coeff[:,p] = xopt.x.copy()
    return coeff
490
def thiele(fn,zn):
    '''
    Coefficients of Thiele's continued-fraction Pade interpolant
    (reciprocal-difference method).

    Builds the table
        g_0(z_j) = fn[j],
        g_i(z_j) = (g_{i-1}(z_{i-1}) - g_{i-1}(z_j)) / ((z_j - z_{i-1}) g_{i-1}(z_j))
    and returns its diagonal a[i] = g_i(z_i) as a complex array.

    Args:
        fn : 1D array, function values at the points ``zn``
        zn : 1D ndarray, interpolation points
    '''
    npts = len(zn)
    table = np.zeros((npts, npts), dtype=np.complex128)
    table[:, 0] = fn
    for col in range(1, npts):
        prev = table[col:, col-1]
        pivot = table[col-1, col-1]
        table[col:, col] = (pivot - prev)/((zn[col:] - zn[col-1])*prev)
    return np.diagonal(table)
499
def pade_thiele(freqs,zn,coeff):
    '''
    Evaluate Thiele's continued-fraction Pade approximant

        f(z) = a0 / (1 + a1 (z - z0) / (1 + a2 (z - z1) / (1 + ...
                   / (1 + a_{n-1} (z - z_{n-2})))))

    with a = coeff (from ``thiele``) and interpolation points zn, so that
    f(zn[j]) reproduces the fitted data for every j.

    Args:
        freqs : scalar or ndarray, point(s) at which to evaluate
        zn : 1D ndarray, interpolation points used in the fit
        coeff : 1D ndarray, continued-fraction coefficients

    Returns:
        f(freqs), same shape as ``freqs``.
    '''
    nfit = len(coeff)
    # Evaluate bottom-up starting from X = 0, so the deepest level is
    # exactly a_{n-1} (z - z_{n-2}).  Seeding X with that term and then
    # looping from idx = nfit-1 would apply it twice, and the approximant
    # would miss the last interpolation point.
    X = 0.
    for idx in range(nfit-1, 0, -1):
        X = coeff[idx]*(freqs-zn[idx-1])/(1.+X)
    return coeff[0]/(1.+X)
508
def AC_pade_thiele_diag(sigma, omega):
    """
    Analytic continuation to real axis using a Pade approximation
    from Thiele's reciprocal difference method
    Reference: J. Low Temp. Phys. 29, 179 (1977)

    Only a subset of the imaginary-axis points is fitted: columns
    1, 7, ..., 37, then every 4th column from 41 onward (so at least 38
    columns are required).  An even number (2*npade) of these is used.

    Args:
        sigma : 2D array (norbs, nw), self-energy on the imaginary axis
        omega : 2D array (norbs, nw), the corresponding frequencies

    Returns:
        coeff: 2D array (2*npade, norbs)
        omega: 2D array (norbs, 2*npade), the points used in the fit
    """
    def _fit_columns(arr):
        # Denser sampling at low frequency, sparser beyond column 41.
        picked = list(range(1, 40, 6)) + list(range(41, arr.shape[1], 4))
        return arr[:, picked]

    sigma_sel = _fit_columns(sigma)
    omega_sel = _fit_columns(omega)
    norbs, nsel = sigma_sel.shape
    nuse = (nsel // 2) * 2
    coeff = np.zeros((nuse, norbs), dtype=np.complex128)
    for p in range(norbs):
        coeff[:, p] = thiele(sigma_sel[p, :nuse], omega_sel[p, :nuse])

    return coeff, omega_sel[:, :nuse]
532
class KRGWAC(lib.StreamObject):
    '''Periodic spin-restricted G0W0 with analytic continuation (GW-AC)
    and k-point sampling.

    Attributes:
        linearized : bool
            Solve the QP equation in linearized form (finite-difference
            Sigma derivative) instead of iterating it with Newton's method.
        ac : str
            Analytic continuation model, 'pade' or 'twopole'.
        fc : bool
            Apply the q -> 0 head/wing finite-size corrections.

    Saved results:
        converged : bool
        mo_energy : ndarray, GW quasiparticle energies
        mo_coeff : ndarray, mean-field MO coefficients
    '''

    linearized = getattr(__config__, 'gw_gw_GW_linearized', False)
    # Analytic continuation: pade or twopole
    ac = getattr(__config__, 'gw_gw_GW_ac', 'pade')
    # Whether applying finite size corrections
    fc = getattr(__config__, 'gw_gw_GW_fc', True)

    def __init__(self, mf, frozen=0):
        self.mol = mf.mol
        self._scf = mf
        self.verbose = self.mol.verbose
        self.stdout = self.mol.stdout
        self.max_memory = mf.max_memory

        #TODO: implement frozen orbs
        # frozen=None is accepted: kernel() and dump_flags() treat it as 0.
        if frozen is not None and frozen != 0:
            raise NotImplementedError
        self.frozen = frozen

        # DF-KGW must use GDF integrals
        if getattr(mf, 'with_df', None):
            self.with_df = mf.with_df
        else:
            raise NotImplementedError
        # 'with_df' is registered in self._keys at the end of __init__.
        # Updating self._keys here is avoided: no instance attribute exists
        # yet, so an in-place update would modify a class-level set shared
        # by other objects.

##################################################
# don't modify the following attributes, they are not input options
        self._nocc = None
        self._nmo = None
        self.kpts = mf.kpts
        self.nkpts = len(self.kpts)
        # self.mo_energy: GW quasiparticle energy, not scf mo_energy
        self.mo_energy = None
        self.mo_coeff = mf.mo_coeff
        self.mo_occ = mf.mo_occ
        self.sigma = None

        keys = set(('linearized','ac','fc'))
        self._keys = set(self.__dict__.keys()).union(keys)

    def dump_flags(self):
        '''Log the GW settings; returns self.'''
        log = logger.Logger(self.stdout, self.verbose)
        log.info('')
        log.info('******** %s ********', self.__class__)
        log.info('method = %s', self.__class__.__name__)
        nocc = self.nocc
        nvir = self.nmo - nocc
        nkpts = self.nkpts
        log.info('GW nocc = %d, nvir = %d, nkpts = %d', nocc, nvir, nkpts)
        if self.frozen is not None:
            log.info('frozen orbitals %s', str(self.frozen))
        logger.info(self, 'use perturbative linearized QP eqn = %s', self.linearized)
        logger.info(self, 'analytic continuation method = %s', self.ac)
        logger.info(self, 'GW finite size corrections = %s', self.fc)
        return self

    @property
    def nocc(self):
        return self.get_nocc()
    @nocc.setter
    def nocc(self, n):
        self._nocc = n

    @property
    def nmo(self):
        return self.get_nmo()
    @nmo.setter
    def nmo(self, n):
        self._nmo = n

    get_nocc = get_nocc
    get_nmo = get_nmo
    get_frozen_mask = get_frozen_mask

    def kernel(self, mo_energy=None, mo_coeff=None, orbs=None, kptlist=None, nw=100):
        """
        Input:
            mo_energy, mo_coeff: forwarded to the module-level kernel
                (default: the SCF values)
            kptlist: self-energy k-points
            orbs: self-energy orbs
            nw: grid number
        Output:
            mo_energy: GW quasiparticle energy
        Raises:
            NotImplementedError: if the estimated memory for the DF
                integrals exceeds self.max_memory.
        """
        if mo_coeff is None:
            mo_coeff = np.array(self._scf.mo_coeff)
        if mo_energy is None:
            mo_energy = np.array(self._scf.mo_energy)

        nmo = self.nmo
        naux = self.with_df.get_naoaux()
        nkpts = self.nkpts
        # rough estimate (MB) for complex128 DF integrals
        mem_incore = (2*nkpts*nmo**2*naux) * 16/1e6
        mem_now = lib.current_memory()[0]
        if (mem_incore + mem_now > 0.99*self.max_memory):
            logger.warn(self, 'Memory may not be enough!')
            raise NotImplementedError

        cput0 = (logger.process_clock(), logger.perf_counter())
        self.dump_flags()
        self.converged, self.mo_energy, self.mo_coeff = \
                kernel(self, mo_energy, mo_coeff, orbs=orbs,
                       kptlist=kptlist, nw=nw, verbose=self.verbose)

        logger.warn(self, 'GW QP energies may not be sorted from min to max')
        logger.timer(self, 'GW', *cput0)
        return self.mo_energy
641
if __name__ == '__main__':
    import os
    from pyscf.pbc import gto
    from pyscf.pbc.lib import chkfile

    # Diamond, GTH-SZV basis, 3x1x1 k-mesh.
    # This test takes a few minutes
    cell = gto.Cell()
    cell.build(unit = 'angstrom',
               a = '''
               0.000000     1.783500     1.783500
               1.783500     0.000000     1.783500
               1.783500     1.783500     0.000000
               ''',
               atom = 'C 1.337625 1.337625 1.337625; C 2.229375 2.229375 2.229375',
               dimension = 3,
               max_memory = 8000,
               verbose = 4,
               pseudo = 'gth-pade',
               basis='gth-szv',
               precision=1e-10)

    kpts = cell.make_kpts([3,1,1],scaled_center=[0,0,0])

    # GDF integrals are cached on disk and only built when the file is missing.
    gdf_fname = 'gdf_ints_311.h5'
    gdf = df.GDF(cell, kpts)
    gdf._cderi_to_save = gdf_fname
    if not os.path.isfile(gdf_fname):
        gdf.build()

    # PBE mean field: the setup is shared, only the source of the SCF
    # solution differs (checkpoint file vs. a fresh calculation).
    chkfname = 'diamond_311.chk'
    kmf = dft.KRKS(cell, kpts)
    kmf.xc = 'pbe'
    kmf.with_df = gdf
    kmf.with_df._cderi = gdf_fname
    if os.path.isfile(chkfname):
        kmf.__dict__.update(chkfile.load(chkfname, 'scf'))
    else:
        kmf.conv_tol = 1e-12
        kmf.chkfile = chkfname
        kmf.kernel()

    gw = KRGWAC(kmf)
    gw.linearized = False
    gw.ac = 'pade'
    nocc = gw.nocc

    # without finite size corrections
    gw.fc = False
    gw.kernel(kptlist=[0,1,2],orbs=range(0,nocc+3))
    print(gw.mo_energy)
    for k, p, e_ref in ((0, nocc-1, 0.62045797), (0, nocc, 0.96574324),
                        (1, nocc-1, 0.52639137), (1, nocc, 1.07513258)):
        assert abs(gw.mo_energy[k][p] - e_ref) < 1e-5

    # with finite size corrections
    gw.fc = True
    gw.kernel(kptlist=[0,1,2],orbs=range(0,nocc+3))
    print(gw.mo_energy)
    for k, p, e_ref in ((0, nocc-1, 0.54277092), (0, nocc, 0.80148537),
                        (1, nocc-1, 0.45073793), (1, nocc, 0.92910108)):
        assert abs(gw.mo_energy[k][p] - e_ref) < 1e-5
707