1#!/usr/bin/env python
2# Copyright 2017-2021 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Authors: James D. McClain
17#          Mario Motta
18#          Yang Gao
19#          Qiming Sun <osirpt.sun@gmail.com>
20#          Jason Yu
21#          Alec White
22#
23
24
25from functools import reduce
26import numpy as np
27import h5py
28
29from pyscf import lib
30from pyscf.lib import logger
31from pyscf.pbc import scf
32from pyscf.cc import uccsd
33from pyscf.pbc.lib import kpts_helper
34from pyscf.pbc.lib.kpts_helper import gamma_point
35from pyscf.lib.parameters import LOOSE_ZERO_TOL, LARGE_DENOM  # noqa
36from pyscf.pbc.mp.kump2 import (get_frozen_mask, get_nocc, get_nmo,
37                                padded_mo_coeff, padding_k_idx)  # noqa
38from pyscf.pbc.cc import kintermediates_uhf
39from pyscf import __config__
40
41einsum = lib.einsum
42
43
44# --- list2array
def mo_c_list_to_array(mo_coeff):
    '''Convert per-spin lists of k-point MO coefficient matrices to arrays.

    Args:
        mo_coeff: Sequence indexed by spin (entries 0 and 1 are read). Each
            entry is a non-empty sequence holding one 2-D coefficient matrix
            per k-point; all matrices of one spin must share the same shape.

    Returns:
        list: Two complex arrays, one per spin, of shape
        ``(nkpts, nrow, ncol)`` where ``(nrow, ncol)`` is the shape of the
        per-k-point matrices.

    Raises:
        ValueError: If a spin channel is empty or its matrices differ in
            shape.
    '''
    # The output shape is taken from the matrices themselves rather than
    # assuming they are square, so rectangular coefficient matrices (fewer
    # orbitals than basis functions) are converted correctly as well.
    return [np.stack([np.asarray(coeff_k, dtype=complex)
                      for coeff_k in mo_coeff[spin]])
            for spin in range(2)]
55
def convert_mo_coeff(mo_coeff):
    '''Return MO coefficients in array form.

    When the first entry of ``mo_coeff`` is a ``list`` the whole object is
    passed through ``mo_c_list_to_array``; any other input is returned as is.
    '''
    first_is_list = isinstance(mo_coeff[0], list)
    return mo_c_list_to_array(mo_coeff) if first_is_list else mo_coeff
60
def update_amps(cc, t1, t2, eris):
    '''Evaluate one update of the T1 and T2 amplitudes.

    The right-hand sides of the T1 and T2 equations are accumulated from the
    current amplitudes, the Fock blocks and the integral blocks stored on
    ``eris``; they are then divided by orbital-energy denominators.

    Args:
        cc: Object supplying ``stdout``, ``verbose``, ``level_shift``,
            ``nkpts`` and ``khelper.kconserv``. It is also handed to
            ``padding_k_idx``, the ``kintermediates_uhf`` helpers and
            ``add_vvvv_``.
        t1: Pair ``(t1a, t1b)``; each array has shape ``(nkpts, nocc, nvir)``
            for its spin.
        t2: Triple ``(t2aa, t2ab, t2bb)``; each array is indexed by three
            k-point axes followed by four orbital axes.
        eris: Integral container. This function reads ``fock``,
            ``mo_energy`` and the integral blocks referenced below
            (``ooov``, ``ovov``, ``voov``, ``oovv``, ``vovv`` and their
            other-spin counterparts).

    Returns:
        tuple: ``((Ht1a, Ht1b), (Ht2aa, Ht2ab, Ht2bb))``, new arrays with the
        same shapes as the input amplitudes.
    '''
    time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(cc.stdout, cc.verbose)

    t1a, t1b = t1
    t2aa, t2ab, t2bb = t2
    Ht1a = np.zeros_like(t1a)
    Ht1b = np.zeros_like(t1b)
    Ht2aa = np.zeros_like(t2aa)
    Ht2ab = np.zeros_like(t2ab)
    Ht2bb = np.zeros_like(t2bb)

    nkpts, nocca, nvira = t1a.shape
    noccb, nvirb = t1b.shape[1:]
    #fvv_ = eris.fock[0][:,nocca:,nocca:]
    #fVV_ = eris.fock[1][:,noccb:,noccb:]
    #foo_ = eris.fock[0][:,:nocca,:nocca]
    #fOO_ = eris.fock[1][:,:noccb,:noccb]
    fov_ = eris.fock[0][:,:nocca,nocca:]
    fOV_ = eris.fock[1][:,:noccb,noccb:]

    # Get location of padded elements in occupied and virtual space
    nonzero_padding_alpha, nonzero_padding_beta = padding_k_idx(cc, kind="split")
    nonzero_opadding_alpha, nonzero_vpadding_alpha = nonzero_padding_alpha
    nonzero_opadding_beta, nonzero_vpadding_beta = nonzero_padding_beta

    # Virtual energies carry cc.level_shift; they are used both for the
    # diagonal shift of Fvv/FVV and for the denominators built further down.
    mo_ea_o = [e[:nocca] for e in eris.mo_energy[0]]
    mo_eb_o = [e[:noccb] for e in eris.mo_energy[1]]
    mo_ea_v = [e[nocca:] + cc.level_shift for e in eris.mo_energy[0]]
    mo_eb_v = [e[noccb:] + cc.level_shift for e in eris.mo_energy[1]]

    Fvv_, FVV_ = kintermediates_uhf.cc_Fvv(cc, t1, t2, eris)
    Foo_, FOO_ = kintermediates_uhf.cc_Foo(cc, t1, t2, eris)
    Fov_, FOV_ = kintermediates_uhf.cc_Fov(cc, t1, t2, eris)

    # Move energy terms to the other side
    for k in range(nkpts):
        Fvv_[k][np.diag_indices(nvira)] -= mo_ea_v[k]
        FVV_[k][np.diag_indices(nvirb)] -= mo_eb_v[k]
        Foo_[k][np.diag_indices(nocca)] -= mo_ea_o[k]
        FOO_[k][np.diag_indices(noccb)] -= mo_eb_o[k]

    # Get the momentum conservation array
    kconserv = cc.khelper.kconserv

    # T1 equation
    P = kintermediates_uhf.kconserv_mat(cc.nkpts, cc.khelper.kconserv)
    Ht1a += fov_.conj()
    Ht1b += fOV_.conj()
    Ht1a += einsum('xyximae,yme->xia', t2aa, Fov_)
    Ht1a += einsum('xyximae,yme->xia', t2ab, FOV_)
    Ht1b += einsum('xyximae,yme->xia', t2bb, FOV_)
    Ht1b += einsum('yxymiea,yme->xia', t2ab, Fov_)
    Ht1a -= einsum('xyzmnae, xzymine->zia', t2aa, eris.ooov)
    Ht1a -= einsum('xyzmNaE, xzymiNE->zia', t2ab, eris.ooOV)
    #Ht1a -= einsum('xyzmnae,xzymine,xyzw->zia', t2aa, eris.ooov, P)
    #Ht1a -= einsum('xyzmNaE,xzymiNE,xyzw->zia', t2ab, eris.ooOV, P)
    Ht1b -= einsum('xyzmnae, xzymine->zia', t2bb, eris.OOOV)
    #Ht1b -= einsum('xyzmnae,xzymine,xyzw->zia', t2bb, eris.OOOV, P)
    Ht1b -= einsum('yxwnmea,xzymine,xyzw->zia', t2ab, eris.OOov, P)

    for ka in range(nkpts):
        Ht1a[ka] += einsum('ie,ae->ia', t1a[ka], Fvv_[ka])
        Ht1b[ka] += einsum('ie,ae->ia', t1b[ka], FVV_[ka])
        Ht1a[ka] -= einsum('ma,mi->ia', t1a[ka], Foo_[ka])
        Ht1b[ka] -= einsum('ma,mi->ia', t1b[ka], FOO_[ka])

        for km in range(nkpts):
            # ka == ki; km == kf == km
            # <ma||if> = [mi|af] - [mf|ai]
            #         => [mi|af] - [fm|ia]
            Ht1a[ka] += einsum('mf,aimf->ia', t1a[km], eris.voov[ka, ka, km])
            Ht1a[ka] -= einsum('mf,miaf->ia', t1a[km], eris.oovv[km, ka, ka])
            Ht1a[ka] += einsum('MF,aiMF->ia', t1b[km], eris.voOV[ka, ka, km])

            # miaf - mfai => miaf - fmia
            Ht1b[ka] += einsum('MF,AIMF->IA', t1b[km], eris.VOOV[ka, ka, km])
            Ht1b[ka] -= einsum('MF,MIAF->IA', t1b[km], eris.OOVV[km, ka, ka])
            Ht1b[ka] += einsum('mf,fmIA->IA', t1a[km], eris.voOV[km, km, ka].conj())

            for kf in range(nkpts):
                ki = ka
                ke = kconserv[ki, kf, km]
                Ht1a[ka] += einsum('imef,fmea->ia', t2aa[ki,km,ke], eris.vovv[kf,km,ke].conj())
                Ht1a[ka] += einsum('iMeF,FMea->ia', t2ab[ki,km,ke], eris.VOvv[kf,km,ke].conj())
                Ht1b[ka] += einsum('IMEF,FMEA->IA', t2bb[ki,km,ke], eris.VOVV[kf,km,ke].conj())
                Ht1b[ka] += einsum('mIfE,fmEA->IA', t2ab[km,ki,kf], eris.voVV[kf,km,ke].conj())

    for ki, kj, ka in kpts_helper.loop_kkk(nkpts):
        kb = kconserv[ki, ka, kj]

        # Fvv equation
        Ftmpa_kb = Fvv_[kb] - 0.5 * einsum('mb,me->be', t1a[kb], Fov_[kb])
        Ftmpb_kb = FVV_[kb] - 0.5 * einsum('MB,ME->BE', t1b[kb], FOV_[kb])

        Ftmpa_ka = Fvv_[ka] - 0.5 * einsum('mb,me->be', t1a[ka], Fov_[ka])
        Ftmpb_ka = FVV_[ka] - 0.5 * einsum('MB,ME->BE', t1b[ka], FOV_[ka])

        tmp = einsum('ijae,be->ijab', t2aa[ki, kj, ka], Ftmpa_kb)
        Ht2aa[ki, kj, ka] += tmp

        tmp = einsum('IJAE,BE->IJAB', t2bb[ki, kj, ka], Ftmpb_kb)
        Ht2bb[ki, kj, ka] += tmp

        tmp = einsum('iJaE,BE->iJaB', t2ab[ki, kj, ka], Ftmpb_kb)
        Ht2ab[ki, kj, ka] += tmp

        tmp = einsum('iJeB,ae->iJaB', t2ab[ki, kj, ka], Ftmpa_ka)
        Ht2ab[ki, kj, ka] += tmp

        #P(ab)
        tmp = einsum('ijbe,ae->ijab', t2aa[ki, kj, kb], Ftmpa_ka)
        Ht2aa[ki, kj, ka] -= tmp

        tmp = einsum('IJBE,AE->IJAB', t2bb[ki, kj, kb], Ftmpb_ka)
        Ht2bb[ki, kj, ka] -= tmp

        # Foo equation
        Ftmpa_kj = Foo_[kj] + 0.5 * einsum('je,me->mj', t1a[kj], Fov_[kj])
        Ftmpb_kj = FOO_[kj] + 0.5 * einsum('JE,ME->MJ', t1b[kj], FOV_[kj])

        Ftmpa_ki = Foo_[ki] + 0.5 * einsum('je,me->mj', t1a[ki], Fov_[ki])
        Ftmpb_ki = FOO_[ki] + 0.5 * einsum('JE,ME->MJ', t1b[ki], FOV_[ki])

        tmp = einsum('imab,mj->ijab', t2aa[ki, kj, ka], Ftmpa_kj)
        Ht2aa[ki, kj, ka] -= tmp

        tmp = einsum('IMAB,MJ->IJAB', t2bb[ki, kj, ka], Ftmpb_kj)
        Ht2bb[ki, kj, ka] -= tmp

        tmp = einsum('iMaB,MJ->iJaB', t2ab[ki, kj, ka], Ftmpb_kj)
        Ht2ab[ki, kj, ka] -= tmp

        tmp = einsum('mJaB,mi->iJaB', t2ab[ki, kj, ka], Ftmpa_ki)
        Ht2ab[ki, kj, ka] -= tmp

        #P(ij)
        tmp = einsum('jmab,mi->ijab', t2aa[kj, ki, ka], Ftmpa_ki)
        Ht2aa[ki, kj, ka] += tmp

        tmp = einsum('JMAB,MI->IJAB', t2bb[kj, ki, ka], Ftmpb_ki)
        Ht2bb[ki, kj, ka] += tmp

    # T2 equation
    eris_ovov = np.asarray(eris.ovov)
    eris_OVOV = np.asarray(eris.OVOV)
    eris_ovOV = np.asarray(eris.ovOV)
    Ht2aa += (eris_ovov.transpose(0,2,1,3,5,4,6) - eris_ovov.transpose(2,0,1,5,3,4,6)).conj()
    Ht2bb += (eris_OVOV.transpose(0,2,1,3,5,4,6) - eris_OVOV.transpose(2,0,1,5,3,4,6)).conj()
    Ht2ab += eris_ovOV.transpose(0,2,1,3,5,4,6).conj()

    tauaa, tauab, taubb = kintermediates_uhf.make_tau(cc, t2, t1, t1)
    Woooo, WooOO, WOOOO = kintermediates_uhf.cc_Woooo(cc, t1, t2, eris)

    # Fold the ovov * tau contraction into the Woooo-type intermediates
    # (see the note at the end of add_vvvv_ about these merged contractions).
    for km, ki, kn in kpts_helper.loop_kkk(nkpts):
        kj = kconserv[km,ki,kn]
        Woooo[km,ki,kn] += .5 * einsum('xmenf, xijef->minj', eris_ovov[km,:,kn], tauaa[ki,kj])
        WOOOO[km,ki,kn] += .5 * einsum('xMENF, xIJEF->MINJ', eris_OVOV[km,:,kn], taubb[ki,kj])
        WooOO[km,ki,kn] += .5 * einsum('xmeNF, xiJeF->miNJ', eris_ovOV[km,:,kn], tauab[ki,kj])

    for km, ki, kn in kpts_helper.loop_kkk(nkpts):
        kj = kconserv[km,ki,kn]
        Ht2aa[ki,kj,:] += einsum('minj,wmnab->wijab', Woooo[km,ki,kn], tauaa[km,kn]) * .5
        Ht2bb[ki,kj,:] += einsum('MINJ,wMNAB->wIJAB', WOOOO[km,ki,kn], taubb[km,kn]) * .5
        Ht2ab[ki,kj,:] += einsum('miNJ,wmNaB->wiJaB', WooOO[km,ki,kn], tauab[km,kn])

    # Wvvvv-type contributions are accumulated into the Ht2 arrays in place.
    add_vvvv_(cc, (Ht2aa, Ht2ab, Ht2bb), t1, t2, eris)

    Wovvo, WovVO, WOVvo, WOVVO, WoVVo, WOvvO = \
            kintermediates_uhf.cc_Wovvo(cc, t1, t2, eris)

    #:Ht2ab += einsum('xwzimae,wvumeBJ,xwzv,wuvy->xyziJaB', t2aa, WovVO, P, P)
    #:Ht2ab += einsum('xwziMaE,wvuMEBJ,xwzv,wuvy->xyziJaB', t2ab, WOVVO, P, P)
    #:Ht2ab -= einsum('xie,zma,uwzBJme,zuwx,xyzu->xyziJaB', t1a, t1a, eris.VOov, P, P)
    for kx, kw, kz in kpts_helper.loop_kkk(nkpts):
        kv = kconserv[kx, kz, kw]
        for ku in range(nkpts):
            ky = kconserv[kw, kv, ku]
            Ht2ab[kx, ky, kz] += lib.einsum('imae,mebj->ijab', t2aa[kx,kw,kz], WovVO[kw,kv,ku])
            Ht2ab[kx, ky, kz] += lib.einsum('imae,mebj->ijab', t2ab[kx,kw,kz], WOVVO[kw,kv,ku])

    #for kz, ku, kw in kpts_helper.loop_kkk(nkpts):
    #    kx = kconserv[kz,kw,ku]
    #    ky = kconserv[kz,kx,ku]
    #    continue
    #    Ht2ab[kx, ky, kz] -= lib.einsum('ie, ma, emjb->ijab', t1a[kx], t1a[kz], eris.voOV[kx,kz,kw].conj())
    Ht2ab -= einsum('xie, yma, xyzemjb->xzyijab', t1a, t1a, eris.voOV[:].conj())
    #:Ht2ab += einsum('wxvmIeA,wvumebj,xwzv,wuvy->yxujIbA', t2ab, Wovvo, P, P)
    #:Ht2ab += einsum('wxvMIEA,wvuMEbj,xwzv,wuvy->yxujIbA', t2bb, WOVvo, P, P)
    #:Ht2ab -= einsum('xIE,zMA,uwzbjME,zuwx,xyzu->yxujIbA', t1b, t1b, eris.voOV, P, P)

    #for kx, kw, kz in kpts_helper.loop_kkk(nkpts):
    #    kv = kconserv[kx, kz, kw]
    #    for ku in range(nkpts):
    #        ky = kconserv[kw, kv, ku]
    #        #Ht2ab[ky,kx,ku] += lib.einsum('miea, mebj-> jiba', t2ab[kw,kx,kv], Wovvo[kw,kv,ku])
    #        #Ht2ab[ky,kx,ku] += lib.einsum('miea, mebj-> jiba', t2bb[kw,kx,kv], WOVvo[kw,kv,ku])

    for km, ke, kb in kpts_helper.loop_kkk(nkpts):
        kj = kconserv[km, ke, kb]
        Ht2ab[kj,:,kb] += einsum('xmiea, mebj->xjiba', t2ab[km,:,ke], Wovvo[km,ke,kb])
        Ht2ab[kj,:,kb] += einsum('xmiea, mebj->xjiba', t2bb[km,:,ke], WOVvo[km,ke,kb])


    for kz, ku, kw in kpts_helper.loop_kkk(nkpts):
        kx = kconserv[kz, kw, ku]
        ky = kconserv[kz, kx, ku]
        Ht2ab[ky,kx,ku] -= lib.einsum('ie, ma, bjme->jiba', t1b[kx], t1b[kz], eris.voOV[ku,kw,kz])


    #:Ht2ab += einsum('xwviMeA,wvuMebJ,xwzv,wuvy->xyuiJbA', t2ab, WOvvO, P, P)
    #:Ht2ab -= einsum('xie,zMA,zwuMJbe,zuwx,xyzu->xyuiJbA', t1a, t1b, eris.OOvv, P, P)
    #for kx, kw, kz in kpts_helper.loop_kkk(nkpts):
    #    kv = kconserv[kx, kz, kw]
    #    for ku in range(nkpts):
    #        ky = kconserv[kw, kv, ku]
    #        Ht2ab[kx,ky,ku] += lib.einsum('imea,mebj->ijba', t2ab[kx,kw,kv],WOvvO[kw,kv,ku])
    for km, ke, kb in kpts_helper.loop_kkk(nkpts):
        kj = kconserv[km, ke, kb]
        Ht2ab[:,kj,kb] += einsum('ximea, mebj->xijba', t2ab[:,km,ke], WOvvO[km,ke,kb])


    for kz,ku,kw in kpts_helper.loop_kkk(nkpts):
        kx = kconserv[kz, kw, ku]
        ky = kconserv[kz, kx, ku]
        Ht2ab[kx,ky,ku] -= lib.einsum('ie, ma, mjbe->ijba', t1a[kx], t1b[kz], eris.OOvv[kz, kw, ku])

    #:Ht2ab += einsum('wxzmIaE,wvumEBj,xwzv,wuvy->yxzjIaB', t2ab, WoVVo, P, P)
    #:Ht2ab -= einsum('xIE,zma,zwumjBE,zuwx,xyzu->yxzjIaB', t1b, t1a, eris.ooVV, P, P)
    for kx, kw, kz in kpts_helper.loop_kkk(nkpts):
        kv = kconserv[kx, kz, kw]
        for ku in range(nkpts):
            ky = kconserv[kw, kv, ku]
            Ht2ab[ky, kx, kz] += lib.einsum('miae,mebj->jiab', t2ab[kw,kx,kz], WoVVo[kw,kv,ku])

    for kz, ku, kw in kpts_helper.loop_kkk(nkpts):
        kx = kconserv[kz,kw,ku]
        ky = kconserv[kz,kx,ku]
        Ht2ab[ky,kx,kz] -= lib.einsum('ie, ma, mjbe->jiab', t1b[kx], t1a[kz], eris.ooVV[kz,kw,ku])

    #:u2aa  = einsum('xwzimae,wvumebj,xwzv,wuvy->xyzijab', t2aa, Wovvo, P, P)
    #:u2aa += einsum('xwziMaE,wvuMEbj,xwzv,wuvy->xyzijab', t2ab, WOVvo, P, P)
    #Left this in to keep proper shape, need to replace later
    u2aa  = np.zeros_like(t2aa)
    for kx, kw, kz in kpts_helper.loop_kkk(nkpts):
        kv = kconserv[kx, kz, kw]
        for ku in range(nkpts):
            ky = kconserv[kw, kv, ku]
            u2aa[kx,ky,kz] += lib.einsum('imae, mebj->ijab', t2aa[kx,kw,kz], Wovvo[kw,kv,ku])
            u2aa[kx,ky,kz] += lib.einsum('imae, mebj->ijab', t2ab[kx,kw,kz], WOVvo[kw,kv,ku])


    #:u2aa += einsum('xie,zma,zwumjbe,zuwx,xyzu->xyzijab', t1a, t1a, eris.oovv, P, P)
    #:u2aa -= einsum('xie,zma,uwzbjme,zuwx,xyzu->xyzijab', t1a, t1a, eris.voov, P, P)

    for kz, ku, kw in kpts_helper.loop_kkk(nkpts):
        kx = kconserv[kz,kw,ku]
        ky = kconserv[kz,kx,ku]
        u2aa[kx,ky,kz] += lib.einsum('ie,ma,mjbe->ijab',t1a[kx],t1a[kz],eris.oovv[kz,kw,ku])
        u2aa[kx,ky,kz] -= lib.einsum('ie,ma,bjme->ijab',t1a[kx],t1a[kz],eris.voov[ku,kw,kz])


    #:u2aa += np.einsum('xie,uyzbjae,uzyx->xyzijab', t1a, eris.vovv, P)
    #:u2aa -= np.einsum('zma,xzyimjb->xyzijab', t1a, eris.ooov.conj())

    for ky, kx, ku in kpts_helper.loop_kkk(nkpts):
        kz = kconserv[ky, ku, kx]
        u2aa[kx, ky, kz] += lib.einsum('ie, bjae->ijab', t1a[kx], eris.vovv[ku,ky,kz])
        u2aa[kx, ky, kz] -= lib.einsum('ma, imjb->ijab', t1a[kz], eris.ooov[kx,kz,ky].conj())

    # Antisymmetrize u2aa: first under exchange of the two occupied indices
    # (with their k-points), then under exchange of the two virtual indices.
    u2aa = u2aa - u2aa.transpose(1,0,2,4,3,5,6)
    u2aa = u2aa - einsum('xyzijab,xyzu->xyuijba', u2aa, P)
    Ht2aa += u2aa

    #:u2bb  = einsum('xwzimae,wvumebj,xwzv,wuvy->xyzijab', t2bb, WOVVO, P, P)
    #:u2bb += einsum('wxvMiEa,wvuMEbj,xwzv,wuvy->xyzijab', t2ab, WovVO, P, P)
    #:u2bb += einsum('xie,zma,zwumjbe,zuwx,xyzu->xyzijab', t1b, t1b, eris.OOVV, P, P)
    #:u2bb -= einsum('xie,zma,uwzbjme,zuwx,xyzu->xyzijab', t1b, t1b, eris.VOOV, P, P)

    u2bb = np.zeros_like(t2bb)

    for kx, kw, kz in kpts_helper.loop_kkk(nkpts):
        kv = kconserv[kx, kz, kw]
        for ku in range(nkpts):
            ky = kconserv[kw,kv, ku]
            u2bb[kx, ky, kz] += lib.einsum('imae,mebj->ijab', t2bb[kx,kw,kz], WOVVO[kw,kv,ku])
            u2bb[kx, ky, kz] += lib.einsum('miea, mebj-> ijab', t2ab[kw,kx,kv],WovVO[kw,kv,ku])

    for kz, ku, kw in kpts_helper.loop_kkk(nkpts):
        kx = kconserv[kz, kw, ku]
        ky = kconserv[kz, kx, ku]
        u2bb[kx, ky, kz] += lib.einsum('ie, ma, mjbe->ijab',t1b[kx],t1b[kz],eris.OOVV[kz,kw,ku])
        u2bb[kx, ky, kz] -= lib.einsum('ie, ma, bjme->ijab', t1b[kx], t1b[kz],eris.VOOV[ku,kw,kz])

    #:u2bb += np.einsum('xie,uzybjae,uzyx->xyzijab', t1b, eris.VOVV, P)
    #:u2bb -= np.einsum('zma,xzyimjb->xyzijab', t1b, eris.OOOV.conj())

    for ky, kx, ku in kpts_helper.loop_kkk(nkpts):
        kz = kconserv[ky, ku, kx]
        u2bb[kx,ky,kz] += lib.einsum('ie,bjae->ijab', t1b[kx], eris.VOVV[ku,ky,kz])

    #for kx, kz, ky in kpts_helper.loop_kkk(nkpts):
    #    u2bb[kx,ky,kz] -= lib.einsum('ma, imjb-> ijab', t1b[kz], eris.OOOV[kx,kz,ky].conj())
    u2bb -= einsum('zma, xzyimjb->xyzijab', t1b, eris.OOOV[:].conj())

    # Same two-step antisymmetrization as applied to u2aa above.
    u2bb = u2bb - u2bb.transpose(1,0,2,4,3,5,6)
    u2bb = u2bb - einsum('xyzijab,xyzu->xyuijba', u2bb, P)
    Ht2bb += u2bb

    #:Ht2ab += np.einsum('xie,uyzBJae,uzyx->xyziJaB', t1a, eris.VOvv, P)
    #:Ht2ab += np.einsum('yJE,zxuaiBE,zuxy->xyziJaB', t1b, eris.voVV, P)
    #:Ht2ab -= np.einsum('zma,xzyimjb->xyzijab', t1a, eris.ooOV.conj())
    #:Ht2ab -= np.einsum('umb,yuxjmia,xyuz->xyzijab', t1b, eris.OOov.conj(), P)
    for ky, kx, ku in kpts_helper.loop_kkk(nkpts):
        kz = kconserv[ky,ku,kx]
        Ht2ab[kx,ky,kz] += lib.einsum('ie, bjae-> ijab', t1a[kx], eris.VOvv[ku,ky,kz])
        Ht2ab[kx,ky,kz] += lib.einsum('je, aibe-> ijab', t1b[ky], eris.voVV[kz,kx,ku])

    #for kx, kz, ky in kpts_helper.loop_kkk(nkpts):
    #    Ht2ab[kx,ky,kz] -= lib.einsum('ma, imjb->ijab', t1a[kz], eris.ooOV[kx,kz,ky].conj())
    Ht2ab -= einsum('zma, xzyimjb->xyzijab', t1a, eris.ooOV[:].conj())

    for kx, ky, ku in kpts_helper.loop_kkk(nkpts):
        kz = kconserv[kx, ku, ky]
        Ht2ab[kx,ky,kz] -= lib.einsum('mb,jmia->ijab',t1b[ku],eris.OOov[ky,ku,kx].conj())

    # Orbital-energy denominators e_occ - e_vir for every (ki, ka) pair.
    # Entries outside the index sets returned by padding_k_idx keep the
    # LARGE_DENOM placeholder instead of a real energy difference.
    eia = []
    eIA = []
    for ki in range(nkpts):
        tmp_alpha = []
        tmp_beta = []
        for ka in range(nkpts):
            tmp_eia = LARGE_DENOM * np.ones((nocca, nvira), dtype=eris.mo_energy[0][0].dtype)
            tmp_eIA = LARGE_DENOM * np.ones((noccb, nvirb), dtype=eris.mo_energy[0][0].dtype)
            n0_ovp_ia = np.ix_(nonzero_opadding_alpha[ki], nonzero_vpadding_alpha[ka])
            n0_ovp_IA = np.ix_(nonzero_opadding_beta[ki], nonzero_vpadding_beta[ka])

            tmp_eia[n0_ovp_ia] = (mo_ea_o[ki][:,None] - mo_ea_v[ka])[n0_ovp_ia]
            tmp_eIA[n0_ovp_IA] = (mo_eb_o[ki][:,None] - mo_eb_v[ka])[n0_ovp_IA]
            tmp_alpha.append(tmp_eia)
            tmp_beta.append(tmp_eIA)
        eia.append(tmp_alpha)
        eIA.append(tmp_beta)

    for ki in range(nkpts):
        ka = ki
        # Remove zero/padded elements from denominator
        Ht1a[ki] /= eia[ki][ka]
        Ht1b[ki] /= eIA[ki][ka]

    # T2 denominators are the sum of two single-excitation denominators,
    # broadcast to the (i, j, a, b) layout.
    for ki, kj, ka in kpts_helper.loop_kkk(nkpts):
        kb = kconserv[ki, ka, kj]
        eijab = eia[ki][ka][:,None,:,None] + eia[kj][kb][:,None,:]
        Ht2aa[ki,kj,ka] /= eijab

        eijab = eia[ki][ka][:,None,:,None] + eIA[kj][kb][:,None,:]
        Ht2ab[ki,kj,ka] /= eijab

        eijab = eIA[ki][ka][:,None,:,None] + eIA[kj][kb][:,None,:]
        Ht2bb[ki,kj,ka] /= eijab

    time0 = log.timer_debug1('update t1 t2', *time0)
    return (Ht1a, Ht1b), (Ht2aa, Ht2ab, Ht2bb)
425
426
def get_normt_diff(cc, t1, t2, t1new, t2new):
    '''Return the combined norm of the amplitude change.

    The value is the square root of the summed squared norms of
    ``t1new[s] - t1[s]`` (two spin blocks) and ``t2new[s] - t2[s]`` (three
    spin blocks). ``cc`` is accepted for interface reasons and is not used.
    '''
    block_pairs = (
        (t1new[0], t1[0]),
        (t1new[1], t1[1]),
        (t2new[0], t2[0]),
        (t2new[1], t2[1]),
        (t2new[2], t2[2]),
    )
    squared_total = sum(np.linalg.norm(new - old)**2
                        for new, old in block_pairs)
    return squared_total ** .5
434
435
def energy(cc, t1, t2, eris):
    '''Return the real part of the energy expression, divided by nkpts.

    Args:
        cc: Only used as the target of the logger warning.
        t1: Pair ``(t1a, t1b)`` of arrays shaped ``(nkpts, nocc, nvir)``.
        t2: Triple ``(t2aa, t2ab, t2bb)``.
        eris: Container providing ``fock``, ``ovov``, ``ovOV`` and ``OVOV``.

    Returns:
        The real part of the accumulated energy. A warning is logged when the
        magnitude of the imaginary part exceeds 1e-4.
    '''
    t1a, t1b = t1
    t2aa, t2ab, t2bb = t2

    kka, noa, nva = t1a.shape
    kkb, nob, nvb = t1b.shape
    assert(kka == kkb)
    nkpts = kka

    # Singles: occupied-virtual Fock block contracted with T1 at each k-point
    focka, fockb = eris.fock
    e_singles = 0.0 + 0j
    for k in range(nkpts):
        e_singles += einsum('ia,ia', focka[k, :noa, noa:], t1a[k, :, :])
        e_singles += einsum('ia,ia', fockb[k, :nob, nob:], t1b[k, :, :])

    # T1*T1 products; only the blocks whose third k-index equals the first
    # are filled, everything else stays zero.
    t1t1aa = np.zeros(shape=t2aa.shape, dtype=t2aa.dtype)
    t1t1ab = np.zeros(shape=t2ab.shape, dtype=t2ab.dtype)
    t1t1bb = np.zeros(shape=t2bb.shape, dtype=t2bb.dtype)
    for ki in range(nkpts):
        for kj in range(nkpts):
            t1t1aa[ki, kj, ki] = einsum('ia,jb->ijab', t1a[ki, :, :], t1a[kj, :, :])
            t1t1ab[ki, kj, ki] = einsum('ia,jb->ijab', t1a[ki, :, :], t1b[kj, :, :])
            t1t1bb[ki, kj, ki] = einsum('ia,jb->ijab', t1b[ki, :, :], t1b[kj, :, :])

    tauaa = t2aa + 2*t1t1aa
    tauab = t2ab + t1t1ab
    taubb = t2bb + 2*t1t1bb

    # Doubles: same-spin blocks enter as (direct - exchange) / 4
    e_doubles = 0.0 + 0.j
    direct_aa = einsum('xzyiajb,xyzijab->',eris.ovov,tauaa)
    exchange_aa = einsum('yzxjaib,xyzijab->',eris.ovov,tauaa)
    e_doubles += 0.25*(direct_aa - exchange_aa)
    e_doubles += einsum('xzyiajb,xyzijab->',eris.ovOV,tauab)
    direct_bb = einsum('xzyiajb,xyzijab->',eris.OVOV,taubb)
    exchange_bb = einsum('yzxjaib,xyzijab->',eris.OVOV,taubb)
    e_doubles += 0.25*(direct_bb - exchange_bb)

    e = (e_singles + e_doubles) / nkpts
    if abs(e.imag) > 1e-4:
        logger.warn(cc, 'Non-zero imaginary part found in KCCSD energy %s', e)
    return e.real
472
473
474#def get_nocc(cc, per_kpoint=False):
475#    '''See also function get_nocc in pyscf/pbc/mp2/kmp2.py'''
476#    if cc._nocc is not None:
477#        return cc._nocc
478#
479#    assert(cc.frozen == 0)
480#
481#    if isinstance(cc.frozen, (int, np.integer)):
482#        nocca = [(np.count_nonzero(cc.mo_occ[0][k] > 0) - cc.frozen) for k in range(cc.nkpts)]
483#        noccb = [(np.count_nonzero(cc.mo_occ[1][k] > 0) - cc.frozen) for k in range(cc.nkpts)]
484#
485#    else:
486#        raise NotImplementedError
487#
488#    if not per_kpoint:
489#        nocca = np.amax(nocca)
490#        noccb = np.amax(noccb)
491#    return nocca, noccb
492#
493#def get_nmo(cc, per_kpoint=False):
494#    '''See also function get_nmo in pyscf/pbc/mp2/kmp2.py'''
495#    if cc._nmo is not None:
496#        return cc._nmo
497#
498#    assert(cc.frozen == 0)
499#
500#    if isinstance(cc.frozen, (int, np.integer)):
501#        nmoa = [(cc.mo_occ[0][k].size - cc.frozen) for k in range(cc.nkpts)]
502#        nmob = [(cc.mo_occ[1][k].size - cc.frozen) for k in range(cc.nkpts)]
503#
504#    else:
505#        raise NotImplementedError
506#
507#    if not per_kpoint:
508#        nmoa = np.amax(nmoa)
509#        nmob = np.amax(nmob)
510#    return nmoa, nmob
511#
512#def get_frozen_mask(cc):
513#    '''See also get_frozen_mask function in pyscf/pbc/mp2/kmp2.py'''
514#
515#    moidxa = [np.ones(x.size, dtype=np.bool) for x in cc.mo_occ[0]]
516#    moidxb = [np.ones(x.size, dtype=np.bool) for x in cc.mo_occ[1]]
517#    assert(cc.frozen == 0)
518#
519#    if isinstance(cc.frozen, (int, np.integer)):
520#        for idx in moidxa:
521#            idx[:cc.frozen] = False
522#        for idx in moidxb:
523#            idx[:cc.frozen] = False
524#    else:
525#        raise NotImplementedError
526#
#    return moidxa, moidxb
528
def amplitudes_to_vector(t1, t2):
    '''Flatten the amplitudes into one 1-D vector.

    The blocks are concatenated in the order ``t1[0], t1[1], t2[0], t2[1],
    t2[2]``, each one raveled first.
    '''
    ordered_blocks = (t1[0], t1[1], t2[0], t2[1], t2[2])
    flattened = tuple(block.ravel() for block in ordered_blocks)
    return np.hstack(flattened)
532
def vector_to_amplitudes(vec, nmo, nocc, nkpts=1):
    '''Split a flat vector back into T1 and T2 amplitude blocks.

    Args:
        vec: 1-D array laid out as ``t1a, t1b, t2aa, t2ab, t2bb``.
        nmo: Pair ``(nmoa, nmob)`` of orbital counts per spin.
        nocc: Pair ``(nocca, noccb)`` of occupied counts per spin.
        nkpts: Number of k-points (default 1).

    Returns:
        tuple: ``((t1a, t1b), (t2aa, t2ab, t2bb))``; T1 blocks have shape
        ``(nkpts, nocc, nvir)`` and T2 blocks have three k-point axes
        followed by ``(nocc, nocc, nvir, nvir)`` for the relevant spins.
    '''
    nocca, noccb = nocc
    nmoa, nmob = nmo
    nvira = nmoa - nocca
    nvirb = nmob - noccb

    kkk = (nkpts, nkpts, nkpts)
    shapes = (
        (nkpts, nocca, nvira),
        (nkpts, noccb, nvirb),
        kkk + (nocca, nocca, nvira, nvira),
        kkk + (nocca, noccb, nvira, nvirb),
        kkk + (noccb, noccb, nvirb, nvirb),
    )

    # Element counts of the five blocks, in storage order
    ova = nocca * nvira
    ovb = noccb * nvirb
    sizes = (nkpts*ova, nkpts*ovb,
             nkpts**3*ova**2, nkpts**3*ova*ovb, nkpts**3*ovb**2)

    pieces = np.split(vec, np.cumsum(sizes[:-1]))
    t1a, t1b, t2aa, t2ab, t2bb = [piece.reshape(shape)
                                  for piece, shape in zip(pieces, shapes)]
    return (t1a, t1b), (t2aa, t2ab, t2bb)
549
def add_vvvv_(cc, Ht2, t1, t2, eris):
    '''Add the Wvvvv-type contractions with tau to the Ht2 arrays in place.

    For every k-point triple the three Wvvvv-type blocks are contracted with
    tau, which is built here from ``t2`` plus products of ``t1``. The blocks
    are either assembled on the fly from ``eris.Lpv`` / ``eris.LPV`` (when
    ``cc.direct`` is true and ``eris`` has a non-None ``Lpv``) or taken from
    ``kintermediates_uhf.cc_Wvvvv_half``.

    Args:
        cc: Object supplying ``nocc``, ``nmo``, ``nkpts``, ``direct`` and
            ``khelper.kconserv``.
        Ht2: Triple ``(Ht2aa, Ht2ab, Ht2bb)``; modified in place.
        t1: Pair ``(t1a, t1b)``.
        t2: Triple ``(t2aa, t2ab, t2bb)``.
        eris: Integral container (see above for the attributes read).

    Returns:
        tuple: The same ``(Ht2aa, Ht2ab, Ht2bb)`` arrays that were passed in.
    '''
    nocca, noccb = cc.nocc
    nmoa, nmob = cc.nmo
    nkpts = cc.nkpts
    kconserv = cc.khelper.kconserv

    t1a, t1b = t1
    t2aa, t2ab, t2bb = t2
    Ht2aa, Ht2ab, Ht2bb = Ht2

    if cc.direct and getattr(eris, 'Lpv', None) is not None:
        # Build the W blocks for one k-point triple at a time from the
        # three-index Lpv/LPV tensors.
        # NOTE(review): Lpv/LPV are presumably density-fitting tensors with
        # the auxiliary index first -- confirm against the eris builder.
        def get_Wvvvv(ka, kc, kb):
            kd = kconserv[ka,kc,kb]
            Lpv = eris.Lpv
            LPV = eris.LPV

            Lbd = (Lpv[kb,kd][:,nocca:] -
                   lib.einsum('Lkd,kb->Lbd', Lpv[kb,kd][:,:nocca], t1a[kb]))
            Wvvvv = lib.einsum('Lac,Lbd->acbd', Lpv[ka,kc][:,nocca:], Lbd)
            kcbd = lib.einsum('Lkc,Lbd->kcbd', Lpv[ka,kc][:,:nocca],
                              Lpv[kb,kd][:,nocca:])
            Wvvvv -= lib.einsum('kcbd,ka->acbd', kcbd, t1a[ka])

            LBD = (LPV[kb,kd][:,noccb:] -
                   lib.einsum('Lkd,kb->Lbd', LPV[kb,kd][:,:noccb], t1b[kb]))

            WvvVV = lib.einsum('Lac,Lbd->acbd', Lpv[ka,kc][:,nocca:], LBD)
            kcbd = lib.einsum('Lkc,Lbd->kcbd', Lpv[ka,kc][:,:nocca],
                              LPV[kb,kd][:,noccb:])
            WvvVV -= lib.einsum('kcbd,ka->acbd', kcbd, t1a[ka])

            WVVVV = lib.einsum('Lac,Lbd->acbd', LPV[ka,kc][:,noccb:], LBD)
            kcbd = lib.einsum('Lkc,Lbd->kcbd', LPV[ka,kc][:,:noccb],
                              LPV[kb,kd][:,noccb:])
            WVVVV -= lib.einsum('kcbd,ka->acbd', kcbd, t1b[ka])

            Wvvvv *= (1./nkpts)
            WvvVV *= (1./nkpts)
            WVVVV *= (1./nkpts)
            return Wvvvv, WvvVV, WVVVV
    else:
        # Precompute all blocks once and serve them by k-point index.
        _Wvvvv, _WvvVV, _WVVVV = kintermediates_uhf.cc_Wvvvv_half(cc, t1, t2, eris)
        def get_Wvvvv(ka, kc, kb):
            return _Wvvvv[ka,kc,kb], _WvvVV[ka,kc,kb], _WVVVV[ka,kc,kb]

    #:Ht2aa += np.einsum('xyuijef,zuwaebf,xyuv,zwuv->xyzijab', tauaa, _Wvvvv-_Wvvvv.transpose(2,1,0,5,4,3,6), P, P) * .5
    #:Ht2bb += np.einsum('xyuijef,zuwaebf,xyuv,zwuv->xyzijab', taubb, _WVVVV-_WVVVV.transpose(2,1,0,5,4,3,6), P, P) * .5
    #:Ht2ab += np.einsum('xyuiJeF,zuwaeBF,xyuv,zwuv->xyziJaB', tauab, _WvvVV, P, P)
    for ka, kb, kc in kpts_helper.loop_kkk(nkpts):
        kd = kconserv[ka,kc,kb]
        Wvvvv, WvvVV, WVVVV = get_Wvvvv(ka, kc, kb)
        for ki in range(nkpts):
            kj = kconserv[ka,ki,kb]
            # tau = t2 plus t1*t1 products; the products are only added for
            # the k-point combinations tested below.
            tauaa = t2aa[ki,kj,kc].copy()
            tauab = t2ab[ki,kj,kc].copy()
            taubb = t2bb[ki,kj,kc].copy()
            if ki == kc and kj == kd:
                tauaa += einsum('ic,jd->ijcd', t1a[ki], t1a[kj])
                tauab += einsum('ic,jd->ijcd', t1a[ki], t1b[kj])
                taubb += einsum('ic,jd->ijcd', t1b[ki], t1b[kj])
            if ki == kd and kj == kc:
                tauaa -= einsum('id,jc->ijcd', t1a[ki], t1a[kj])
                taubb -= einsum('id,jc->ijcd', t1b[ki], t1b[kj])

            # Same-spin terms: add the contraction at ka and subtract its
            # a<->b transpose at kb.
            tmp = lib.einsum('acbd,ijcd->ijab', Wvvvv, tauaa) * .5
            Ht2aa[ki,kj,ka] += tmp
            Ht2aa[ki,kj,kb] -= tmp.transpose(0,1,3,2)

            tmp = lib.einsum('acbd,ijcd->ijab', WVVVV, taubb) * .5
            Ht2bb[ki,kj,ka] += tmp
            Ht2bb[ki,kj,kb] -= tmp.transpose(0,1,3,2)

            Ht2ab[ki,kj,ka] += lib.einsum('acbd,ijcd->ijab', WvvVV, tauab)
        # Drop the references to the per-triple blocks before the next pass
        Wvvvv = WvvVV = WVVVV = None
    _Wvvvv = _WvvVV = _WVVVV = None

    # Contractions below are merged to Woooo intermediates
    # tauaa, tauab, taubb = kintermediates_uhf.make_tau(cc, t2, t1, t1)
    # P = kintermediates_uhf.kconserv_mat(cc.nkpts, cc.khelper.kconserv)
    # minj = np.einsum('xwymenf,uvwijef,xywz,uvwz->xuyminj', eris.ovov, tauaa, P, P)
    # MINJ = np.einsum('xwymenf,uvwijef,xywz,uvwz->xuyminj', eris.OVOV, taubb, P, P)
    # miNJ = np.einsum('xwymeNF,uvwiJeF,xywz,uvwz->xuymiNJ', eris.ovOV, tauab, P, P)
    # Ht2aa += np.einsum('xuyminj,xywmnab,xyuv->uvwijab', minj, tauaa, P) * .25
    # Ht2bb += np.einsum('xuyminj,xywmnab,xyuv->uvwijab', MINJ, taubb, P) * .25
    # Ht2ab += np.einsum('xuymiNJ,xywmNaB,xyuv->uvwiJaB', miNJ, tauab, P) * .5
    return (Ht2aa, Ht2ab, Ht2bb)
636
637
class KUCCSD(uccsd.UCCSD):
    '''CCSD driver with separate alpha/beta amplitudes and k-point indices.

    Extends ``uccsd.UCCSD``; the constructor requires a mean-field object
    that is an instance of ``scf.khf.KSCF``.
    '''

    # Default is read from the ``pbc_cc_kccsd_uhf_KUCCSD_max_space`` entry of
    # ``pyscf.__config__`` and falls back to 20 when that entry is absent.
    max_space = getattr(__config__, 'pbc_cc_kccsd_uhf_KUCCSD_max_space', 20)
641
642    def __init__(self, mf, frozen=None, mo_coeff=None, mo_occ=None):
643        assert(isinstance(mf, scf.khf.KSCF))
644        uccsd.UCCSD.__init__(self, mf, frozen, mo_coeff, mo_occ)
645        self.kpts = mf.kpts
646        self.mo_energy = mf.mo_energy
647        self.khelper = kpts_helper.KptsHelper(mf.cell, self.kpts)
648        self.direct = True  # If possible, use GDF to compute Wvvvv on-the-fly
649
650        keys = set(['kpts', 'mo_energy', 'khelper', 'max_space', 'direct'])
651        self._keys = self._keys.union(keys)
652
653    @property
654    def nkpts(self):
655        return len(self.kpts)
656
    # Expose module-level functions as methods of the class. get_nocc,
    # get_nmo and get_frozen_mask are the ones imported from
    # pyscf.pbc.mp.kump2 at the top of this file.
    get_normt_diff = get_normt_diff
    get_nocc = get_nocc
    get_nmo = get_nmo
    get_frozen_mask = get_frozen_mask

    # Amplitude update and energy evaluation defined earlier in this module
    update_amps = update_amps
    energy = energy
664
665    def dump_flags(self, verbose=None):
666        return uccsd.UCCSD.dump_flags(self, verbose)
667
668    def ao2mo(self, mo_coeff=None):
669        from pyscf.pbc.df.df import GDF
670        cell = self._scf.cell
671        nkpts = self.nkpts
672        nmoa, nmob = self.nmo
673        mem_incore = nkpts**3 * (nmoa**4 + nmob**4) * 8 / 1e6
674        mem_now = lib.current_memory()[0]
675
676        if (mem_incore + mem_now < self.max_memory) or self.mol.incore_anyway:
677            return _make_eris_incore(self, mo_coeff)
678        elif (self.direct and type(self._scf.with_df) is GDF
679              and cell.dimension != 2):
680            # DFKCCSD does not support MDF
681            return _make_df_eris(self, mo_coeff)
682        else:
683            return _make_eris_outcore(self, mo_coeff)
684
685    def init_amps(self, eris):
686        time0 = logger.process_clock(), logger.perf_counter()
687
688        nocca, noccb = self.nocc
689        nmoa, nmob = self.nmo
690        nvira, nvirb = nmoa - nocca, nmob - noccb
691
692        nkpts = self.nkpts
693        t1a = np.zeros((nkpts, nocca, nvira), dtype=np.complex128)
694        t1b = np.zeros((nkpts, noccb, nvirb), dtype=np.complex128)
695        t1 = (t1a, t1b)
696        t2aa = np.zeros((nkpts, nkpts, nkpts, nocca, nocca, nvira, nvira), dtype=np.complex128)
697        t2ab = np.zeros((nkpts, nkpts, nkpts, nocca, noccb, nvira, nvirb), dtype=np.complex128)
698        t2bb = np.zeros((nkpts, nkpts, nkpts, noccb, noccb, nvirb, nvirb), dtype=np.complex128)
699
700        mo_ea_o = [e[:nocca] for e in eris.mo_energy[0]]
701        mo_eb_o = [e[:noccb] for e in eris.mo_energy[1]]
702        mo_ea_v = [e[nocca:] for e in eris.mo_energy[0]]
703        mo_eb_v = [e[noccb:] for e in eris.mo_energy[1]]
704
705        # Get location of padded elements in occupied and virtual space
706        nonzero_padding_alpha, nonzero_padding_beta = padding_k_idx(self, kind="split")
707        nonzero_opadding_alpha, nonzero_vpadding_alpha = nonzero_padding_alpha
708        nonzero_opadding_beta, nonzero_vpadding_beta = nonzero_padding_beta
709
710        eia = []
711        eIA = []
712        # Create denominators, ignoring padded elements
713        for ki in range(nkpts):
714            tmp_alpha = []
715            tmp_beta = []
716            for ka in range(nkpts):
717                tmp_eia = LARGE_DENOM * np.ones((nocca, nvira), dtype=eris.mo_energy[0][0].dtype)
718                tmp_eIA = LARGE_DENOM * np.ones((noccb, nvirb), dtype=eris.mo_energy[0][0].dtype)
719                n0_ovp_ia = np.ix_(nonzero_opadding_alpha[ki], nonzero_vpadding_alpha[ka])
720                n0_ovp_IA = np.ix_(nonzero_opadding_beta[ki], nonzero_vpadding_beta[ka])
721
722                tmp_eia[n0_ovp_ia] = (mo_ea_o[ki][:,None] - mo_ea_v[ka])[n0_ovp_ia]
723                tmp_eIA[n0_ovp_IA] = (mo_eb_o[ki][:,None] - mo_eb_v[ka])[n0_ovp_IA]
724                tmp_alpha.append(tmp_eia)
725                tmp_beta.append(tmp_eIA)
726            eia.append(tmp_alpha)
727            eIA.append(tmp_beta)
728
729        kconserv = kpts_helper.get_kconserv(self._scf.cell, self.kpts)
730        for ki, kj, ka in kpts_helper.loop_kkk(nkpts):
731            kb = kconserv[ki, ka, kj]
732            Daa = eia[ki][ka][:,None,:,None] + eia[kj][kb][:,None,:]
733            Dab = eia[ki][ka][:,None,:,None] + eIA[kj][kb][:,None,:]
734            Dbb = eIA[ki][ka][:,None,:,None] + eIA[kj][kb][:,None,:]
735
736            t2aa[ki,kj,ka] = eris.ovov[ki,ka,kj].conj().transpose((0,2,1,3)) / Daa
737            t2aa[ki,kj,ka]-= eris.ovov[kj,ka,ki].conj().transpose((2,0,1,3)) / Daa
738            t2ab[ki,kj,ka] = eris.ovOV[ki,ka,kj].conj().transpose((0,2,1,3)) / Dab
739            t2bb[ki,kj,ka] = eris.OVOV[ki,ka,kj].conj().transpose((0,2,1,3)) / Dbb
740            t2bb[ki,kj,ka]-= eris.OVOV[kj,ka,ki].conj().transpose((2,0,1,3)) / Dbb
741
742        t2 = (t2aa,t2ab,t2bb)
743
744        d = 0.0 + 0.j
745        d += 0.25*(einsum('xzyiajb,xyzijab->',eris.ovov,t2aa) -
746                   einsum('yzxjaib,xyzijab->',eris.ovov,t2aa))
747        d += einsum('xzyiajb,xyzijab->',eris.ovOV,t2ab)
748        d += 0.25*(einsum('xzyiajb,xyzijab->',eris.OVOV,t2bb) -
749                   einsum('yzxjaib,xyzijab->',eris.OVOV,t2bb))
750        self.emp2 = d/nkpts
751
752        logger.info(self, 'Init t2, MP2 energy = %.15g', self.emp2.real)
753        logger.timer(self, 'init mp2', *time0)
754        return self.emp2, t1, t2
755
756    def amplitudes_to_vector(self, t1, t2):
757        return amplitudes_to_vector(t1, t2)
758
759    def vector_to_amplitudes(self, vec, nmo=None, nocc=None, nkpts=None):
760        if nocc is None: nocc = self.nocc
761        if nmo is None: nmo = self.nmo
762        if nkpts is None: nkpts = self.nkpts
763        return vector_to_amplitudes(vec, nmo, nocc, nkpts)
764
# Alias: the class is also reachable under the name UCCSD.
UCCSD = KUCCSD
766
767
768#######################################
769#
770# _ERIS.
771#
# Note that the two-electron integrals are stored in a different order than
# in kccsd_rhf._ERIS.  Integrals (ab|cd) are stored as [ka,kb,kc,a,b,c,d] here,
# while the order is [ka,kc,kb,a,c,b,d] in kccsd_rhf._ERIS.
775#
776# TODO: use the same convention as kccsd_rhf
777#
def _make_eris_incore(cc, mo_coeff=None):
    """Build the UCCSD k-point ERIs as in-memory NumPy arrays.

    Every block is allocated with shape ``(nkpts, nkpts, nkpts, p, q, r, s)``
    where the four orbital extents follow the letters of the block name
    (``o``/``v``: alpha occupied/virtual, ``O``/``V``: beta occupied/virtual).
    The blocks with an occupied first index are filled by
    ``_kuccsd_eris_common_``; the all-virtual blocks ``vvvv``, ``VVVV`` and
    ``vvVV`` are then obtained from ``with_df.ao2mo_7d``.

    Args:
        cc: coupled-cluster object; ``nocc``, ``nmo``, ``nkpts``, ``kpts``,
            ``mo_coeff`` and ``_scf.with_df`` are read here.
        mo_coeff: MO coefficients for both spins; defaults to ``cc.mo_coeff``.

    Returns:
        The populated ``uccsd._ChemistsERIs`` object.  ``OOoo`` is left as
        ``None``.
    """
    eris = uccsd._ChemistsERIs()
    if mo_coeff is None:
        mo_coeff = cc.mo_coeff
    mo_coeff = convert_mo_coeff(mo_coeff)  # FIXME: Remove me!
    mo_coeff = padded_mo_coeff(cc, mo_coeff)
    eris.mo_coeff = mo_coeff
    eris.nocc = cc.nocc

    nkpts = cc.nkpts
    nocca, noccb = cc.nocc
    nmoa, nmob = cc.nmo
    nvira, nvirb = nmoa - nocca, nmob - noccb

    if gamma_point(cc.kpts):
        dtype = np.double
    else:
        dtype = np.complex128
    dtype = np.result_type(dtype, *mo_coeff[0])

    # The block name spells out its own orbital extents, so the allocation
    # can be driven by the name instead of one hand-written line per block.
    extent = {'o': nocca, 'v': nvira, 'O': noccb, 'V': nvirb}
    block_names = (
        'oooo', 'ooov', 'oovv', 'ovov', 'voov', 'vovv',
        'OOOO', 'OOOV', 'OOVV', 'OVOV', 'VOOV', 'VOVV',
        'ooOO', 'ooOV', 'ooVV', 'ovOV', 'voOV', 'voVV',
        'OOov', 'OOvv', 'OVov', 'VOov', 'VOvv',
    )
    for name in block_names:
        shape = (nkpts, nkpts, nkpts) + tuple(extent[c] for c in name)
        setattr(eris, name, np.empty(shape, dtype=dtype))
    eris.OOoo = None

    _kuccsd_eris_common_(cc, eris)

    thisdf = cc._scf.with_df
    orbva = np.asarray(mo_coeff[0][:,:,nocca:], order='C')
    orbvb = np.asarray(mo_coeff[1][:,:,noccb:], order='C')
    # Pass cc.kpts explicitly so the all-virtual blocks are transformed on
    # the same k-point list as the blocks built in _kuccsd_eris_common_
    # (and as in _make_eris_outcore), rather than on the DF object's default.
    eris.vvvv = thisdf.ao2mo_7d(orbva, cc.kpts, factor=1./nkpts)
    eris.VVVV = thisdf.ao2mo_7d(orbvb, cc.kpts, factor=1./nkpts)
    eris.vvVV = thisdf.ao2mo_7d([orbva,orbva,orbvb,orbvb], cc.kpts,
                                factor=1./nkpts)

    return eris
836
def _kuccsd_eris_common_(cc, eris, buf=None):
    """Fill ``eris`` in place with the Fock data and the occupied-first ERIs.

    Sets ``eris.fock``, ``eris.e_hf`` and ``eris.mo_energy``, then writes the
    integral blocks oooo/ooov/oovv/ovov/voov/vovv for the four spin
    combinations (alpha-alpha, beta-beta, alpha-beta, beta-alpha; ``OOoo`` is
    not written) into the containers already present on ``eris``.

    Args:
        cc: coupled-cluster object; ``_scf``, ``kpts``, ``nkpts``, ``nmo``,
            ``mo_coeff``, ``mo_occ`` and ``khelper.kconserv`` are read.
        eris: integral container that must already carry ``mo_coeff``,
            ``nocc`` and the destination blocks (indexable by
            ``[kp,kq,kr]``).
        buf: optional ``h5py.Group``.  When given, each intermediate
            produced by ``ao2mo_7d`` is written to a dataset named ``'tmp'``
            inside it (deleted and re-created between spin combinations);
            otherwise ``out=None`` is passed to ``ao2mo_7d``.

    Returns:
        The same ``eris`` object.
    """
    from pyscf.pbc import tools
    from pyscf.pbc.cc.ccsd import _adjust_occ
    #if not (cc.frozen is None or cc.frozen == 0):
    #    raise NotImplementedError('cc.frozen = %s' % str(cc.frozen))

    cput0 = (logger.process_clock(), logger.perf_counter())
    log = logger.new_logger(cc)
    cell = cc._scf.cell
    thisdf = cc._scf.with_df

    kpts = cc.kpts
    nkpts = cc.nkpts
    mo_coeff = eris.mo_coeff
    nocca, noccb = eris.nocc
    nmoa, nmob = cc.nmo
    mo_a, mo_b = mo_coeff

    # Re-make our fock MO matrix elements from density and fock AO
    # The density uses cc.mo_coeff / cc.mo_occ, while the transformation to
    # the MO basis below uses the coefficients stored on eris.
    dm = cc._scf.make_rdm1(cc.mo_coeff, cc.mo_occ)
    hcore = cc._scf.get_hcore()
    # The effective potential is evaluated with exxdiv temporarily set to None.
    with lib.temporary_env(cc._scf, exxdiv=None):
        vhf = cc._scf.get_veff(cell, dm)
    focka = [reduce(np.dot, (mo.conj().T, hcore[k]+vhf[0][k], mo))
             for k, mo in enumerate(mo_a)]
    fockb = [reduce(np.dot, (mo.conj().T, hcore[k]+vhf[1][k], mo))
             for k, mo in enumerate(mo_b)]
    eris.fock = (np.asarray(focka), np.asarray(fockb))
    eris.e_hf = cc._scf.energy_tot(dm=dm, vhf=vhf)

    # Orbital energies are the real diagonal of the MO Fock matrices, passed
    # through _adjust_occ with -madelung (presumably a shift of the occupied
    # levels only -- verify in pyscf.pbc.cc.ccsd._adjust_occ).
    madelung = tools.madelung(cell, kpts)
    mo_ea = [focka[k].diagonal().real for k in range(nkpts)]
    mo_eb = [fockb[k].diagonal().real for k in range(nkpts)]
    mo_ea = [_adjust_occ(e, nocca, -madelung) for e in mo_ea]
    mo_eb = [_adjust_occ(e, noccb, -madelung) for e in mo_eb]
    eris.mo_energy = (mo_ea, mo_eb)

    # Occupied-orbital coefficient slices for each spin.
    orboa = np.asarray(mo_coeff[0][:,:,:nocca], order='C')
    orbob = np.asarray(mo_coeff[1][:,:,:noccb], order='C')
    #orbva = np.asarray(mo_coeff[0][:,:,nocca:], order='C')
    #orbvb = np.asarray(mo_coeff[1][:,:,noccb:], order='C')
    dtype = np.result_type(*focka).char

    # The momentum conservation array
    kconserv = cc.khelper.kconserv

    # --- alpha-alpha: (occ, all | all, all) integrals, scaled by 1/nkpts ---
    out = None
    if isinstance(buf, h5py.Group):
        out = buf.create_dataset('tmp', (nkpts,nkpts,nkpts,nocca,nmoa,nmoa,nmoa), dtype)
    oppp = thisdf.ao2mo_7d([orboa,mo_coeff[0],mo_coeff[0],mo_coeff[0]], kpts,
                           factor=1./nkpts, out=out)
    for kp, kq, kr in kpts_helper.loop_kkk(nkpts):
        ks = kconserv[kp,kq,kr]
        tmp = np.asarray(oppp[kp,kq,kr])
        eris.oooo[kp,kq,kr] = tmp[:nocca,:nocca,:nocca,:nocca]
        eris.ooov[kp,kq,kr] = tmp[:nocca,:nocca,:nocca,nocca:]
        eris.oovv[kp,kq,kr] = tmp[:nocca,:nocca,nocca:,nocca:]
        eris.ovov[kp,kq,kr] = tmp[:nocca,nocca:,:nocca,nocca:]
        # voov / vovv are taken from the ovvo / ovvv slices: conjugate, swap
        # the indices within each pair, and store at k-point index [kq,kp,ks].
        eris.voov[kq,kp,ks] = tmp[:nocca,nocca:,nocca:,:nocca].conj().transpose(1,0,3,2)
        eris.vovv[kq,kp,ks] = tmp[:nocca,nocca:,nocca:,nocca:].conj().transpose(1,0,3,2)
    oppp = None

    # --- beta-beta ---
    if isinstance(buf, h5py.Group):
        # Drop the previous scratch dataset before creating one of the new shape.
        del(buf['tmp'])
        out = buf.create_dataset('tmp', (nkpts,nkpts,nkpts,noccb,nmob,nmob,nmob), dtype)
    oppp = thisdf.ao2mo_7d([orbob,mo_coeff[1],mo_coeff[1],mo_coeff[1]], kpts,
                           factor=1./nkpts, out=out)
    for kp, kq, kr in kpts_helper.loop_kkk(nkpts):
        ks = kconserv[kp,kq,kr]
        tmp = np.asarray(oppp[kp,kq,kr])
        eris.OOOO[kp,kq,kr] = tmp[:noccb,:noccb,:noccb,:noccb]
        eris.OOOV[kp,kq,kr] = tmp[:noccb,:noccb,:noccb,noccb:]
        eris.OOVV[kp,kq,kr] = tmp[:noccb,:noccb,noccb:,noccb:]
        eris.OVOV[kp,kq,kr] = tmp[:noccb,noccb:,:noccb,noccb:]
        eris.VOOV[kq,kp,ks] = tmp[:noccb,noccb:,noccb:,:noccb].conj().transpose(1,0,3,2)
        eris.VOVV[kq,kp,ks] = tmp[:noccb,noccb:,noccb:,noccb:].conj().transpose(1,0,3,2)
    oppp = None

    # --- alpha-beta: first pair alpha, second pair beta ---
    if isinstance(buf, h5py.Group):
        del(buf['tmp'])
        out = buf.create_dataset('tmp', (nkpts,nkpts,nkpts,nocca,nmoa,nmob,nmob), dtype)
    oppp = thisdf.ao2mo_7d([orboa,mo_coeff[0],mo_coeff[1],mo_coeff[1]], kpts,
                           factor=1./nkpts, out=out)
    for kp, kq, kr in kpts_helper.loop_kkk(nkpts):
        ks = kconserv[kp,kq,kr]
        tmp = np.asarray(oppp[kp,kq,kr])
        eris.ooOO[kp,kq,kr] = tmp[:nocca,:nocca,:noccb,:noccb]
        eris.ooOV[kp,kq,kr] = tmp[:nocca,:nocca,:noccb,noccb:]
        eris.ooVV[kp,kq,kr] = tmp[:nocca,:nocca,noccb:,noccb:]
        eris.ovOV[kp,kq,kr] = tmp[:nocca,nocca:,:noccb,noccb:]
        eris.voOV[kq,kp,ks] = tmp[:nocca,nocca:,noccb:,:noccb].conj().transpose(1,0,3,2)
        eris.voVV[kq,kp,ks] = tmp[:nocca,nocca:,noccb:,noccb:].conj().transpose(1,0,3,2)
    oppp = None

    # --- beta-alpha: first pair beta, second pair alpha ---
    if isinstance(buf, h5py.Group):
        del(buf['tmp'])
        out = buf.create_dataset('tmp', (nkpts,nkpts,nkpts,noccb,nmob,nmoa,nmoa), dtype)
    oppp = thisdf.ao2mo_7d([orbob,mo_coeff[1],mo_coeff[0],mo_coeff[0]], kpts,
                           factor=1./nkpts, out=out)
    for kp, kq, kr in kpts_helper.loop_kkk(nkpts):
        ks = kconserv[kp,kq,kr]
        tmp = np.asarray(oppp[kp,kq,kr])
        # OOoo is deliberately not stored (the callers set eris.OOoo = None).
        #eris.OOoo[kp,kq,kr] = tmp[:noccb,:noccb,:nocca,:nocca]
        eris.OOov[kp,kq,kr] = tmp[:noccb,:noccb,:nocca,nocca:]
        eris.OOvv[kp,kq,kr] = tmp[:noccb,:noccb,nocca:,nocca:]
        eris.OVov[kp,kq,kr] = tmp[:noccb,noccb:,:nocca,nocca:]
        eris.VOov[kq,kp,ks] = tmp[:noccb,noccb:,nocca:,:nocca].conj().transpose(1,0,3,2)
        eris.VOvv[kq,kp,ks] = tmp[:noccb,noccb:,nocca:,nocca:].conj().transpose(1,0,3,2)
    oppp = None

    log.timer('CCSD integral transformation', *cput0)
    return eris
949
def _make_eris_outcore(cc, mo_coeff=None):
    """Build the UCCSD k-point ERIs with every block kept in an HDF5 file.

    A temporary HDF5 file is created and attached as ``eris.feri``; each
    integral block becomes a dataset of shape
    ``(nkpts, nkpts, nkpts, p, q, r, s)`` whose four orbital extents follow
    the letters of the block name (``o``/``v``: alpha occupied/virtual,
    ``O``/``V``: beta occupied/virtual).  ``_kuccsd_eris_common_`` fills the
    occupied-first blocks using a second temporary file as scratch, and the
    all-virtual blocks are written directly by ``with_df.ao2mo_7d``.

    Args:
        cc: coupled-cluster object; ``nocc``, ``nmo``, ``nkpts``, ``kpts``,
            ``mo_coeff`` and ``_scf.with_df`` are read here.
        mo_coeff: MO coefficients for both spins; defaults to ``cc.mo_coeff``.

    Returns:
        The populated ``uccsd._ChemistsERIs`` object.  ``OOoo`` and ``VVvv``
        are set to ``None``.
    """
    eris = uccsd._ChemistsERIs()
    coeff = cc.mo_coeff if mo_coeff is None else mo_coeff
    coeff = convert_mo_coeff(coeff)  # FIXME: Remove me!
    coeff = padded_mo_coeff(cc, coeff)
    eris.mo_coeff = coeff
    eris.nocc = cc.nocc

    nkpts = cc.nkpts
    nocca, noccb = cc.nocc
    nmoa, nmob = cc.nmo
    extent = {'o': nocca, 'v': nmoa - nocca, 'O': noccb, 'V': nmob - noccb}

    base_type = np.double if gamma_point(cc.kpts) else np.complex128
    dtype = np.result_type(base_type, *coeff[0]).char

    feri = lib.H5TmpFile()
    eris.feri = feri

    # Block names in creation order; those in `not_stored` are only set to
    # None on the container.
    layout = (
        'oooo', 'ooov', 'oovv', 'ovov', 'voov', 'vovv', 'vvvv',
        'OOOO', 'OOOV', 'OOVV', 'OVOV', 'VOOV', 'VOVV', 'VVVV',
        'ooOO', 'ooOV', 'ooVV', 'ovOV', 'voOV', 'voVV', 'vvVV',
        'OOoo', 'OOov', 'OOvv', 'OVov', 'VOov', 'VOvv', 'VVvv',
    )
    not_stored = ('OOoo', 'VVvv')
    for label in layout:
        if label in not_stored:
            setattr(eris, label, None)
            continue
        dims = (nkpts, nkpts, nkpts) + tuple(extent[ch] for ch in label)
        setattr(eris, label, feri.create_dataset(label, dims, dtype))

    # Scratch file for the intermediates of the common routine; the
    # reference is dropped as soon as that routine returns.
    scratch = lib.H5TmpFile()
    _kuccsd_eris_common_(cc, eris, scratch)
    scratch = None

    thisdf = cc._scf.with_df
    orbva = np.asarray(coeff[0][:,:,nocca:], order='C')
    orbvb = np.asarray(coeff[1][:,:,noccb:], order='C')
    scale = 1./nkpts
    thisdf.ao2mo_7d(orbva, cc.kpts, factor=scale, out=eris.vvvv)
    thisdf.ao2mo_7d(orbvb, cc.kpts, factor=scale, out=eris.VVVV)
    thisdf.ao2mo_7d([orbva,orbva,orbvb,orbvb], cc.kpts, factor=scale, out=eris.vvVV)

    return eris
1016
def _make_df_eris(cc, mo_coeff=None):
    """Build the UCCSD k-point ERIs for the density-fitted code path.

    The occupied-first blocks are stored as datasets of a temporary HDF5
    file (``eris.feri``) and filled by ``_kuccsd_eris_common_``.  The
    all-virtual blocks are not formed (``vvvv``, ``VVVV``, ``vvVV`` and
    ``VVvv`` are ``None``); instead the three-index tensors read from the
    ``j3c`` data of ``with_df._cderi`` are transformed to the
    (all MOs at ``ki``) x (virtual MOs at ``kj``) basis and kept per k-point
    pair in the object arrays ``eris.Lpv`` (alpha) and ``eris.LPV`` (beta).

    Args:
        cc: coupled-cluster object; ``_scf.cell``, ``_scf.with_df``,
            ``kpts``, ``nkpts``, ``nocc``, ``nmo`` and ``mo_coeff`` are read.
        mo_coeff: MO coefficients for both spins; defaults to ``cc.mo_coeff``.

    Returns:
        The populated ``uccsd._ChemistsERIs`` object.

    Raises:
        NotImplementedError: if ``cell.dimension == 2``.
    """
    from pyscf.pbc.df import df
    from pyscf.ao2mo import _ao2mo
    cell = cc._scf.cell
    if cell.dimension == 2:
        raise NotImplementedError

    eris = uccsd._ChemistsERIs()
    coeff = cc.mo_coeff if mo_coeff is None else mo_coeff
    coeff = padded_mo_coeff(cc, coeff)
    eris.mo_coeff = coeff
    eris.nocc = cc.nocc
    thisdf = cc._scf.with_df

    kpts = cc.kpts
    nkpts = cc.nkpts
    nocca, noccb = cc.nocc
    nmoa, nmob = cc.nmo
    nvira = nmoa - nocca
    nvirb = nmob - noccb
    nao = cell.nao_nr()
    mo_kpts_a, mo_kpts_b = eris.mo_coeff

    base_type = np.double if gamma_point(kpts) else np.complex128
    dtype = np.result_type(base_type, *mo_kpts_a)

    feri = lib.H5TmpFile()
    eris.feri = feri

    # Block names in the order they are attached to the container; those in
    # `not_stored` are only set to None.
    extent = {'o': nocca, 'v': nvira, 'O': noccb, 'V': nvirb}
    layout = (
        'oooo', 'ooov', 'oovv', 'ovov', 'voov', 'vovv', 'vvvv',
        'OOOO', 'OOOV', 'OOVV', 'OVOV', 'VOOV', 'VOVV', 'VVVV',
        'ooOO', 'ooOV', 'ooVV', 'ovOV', 'voOV', 'voVV', 'vvVV',
        'OOoo', 'OOov', 'OOvv', 'OVov', 'VOov', 'VOvv', 'VVvv',
    )
    not_stored = ('vvvv', 'VVVV', 'vvVV', 'OOoo', 'VVvv')
    for label in layout:
        if label in not_stored:
            setattr(eris, label, None)
            continue
        dims = (nkpts, nkpts, nkpts) + tuple(extent[ch] for ch in label)
        setattr(eris, label, feri.create_dataset(label, dims, dtype))

    # Scratch file for the intermediates of the common routine.
    scratch = lib.H5TmpFile()
    _kuccsd_eris_common_(cc, eris, scratch)
    scratch = None

    Lpv = np.empty((nkpts,nkpts), dtype=object)
    LPV = np.empty((nkpts,nkpts), dtype=object)
    eris.Lpv = Lpv
    eris.LPV = LPV

    # Column ranges in the stacked coefficient matrix: bra = all MOs at ki,
    # ket = virtual MOs at kj.
    slice_a = (0, nmoa, nmoa, nmoa+nvira)
    slice_b = (0, nmob, nmob, nmob+nvirb)
    real_case = dtype == np.double

    with h5py.File(thisdf._cderi, 'r') as f:
        kptij_lst = f['j3c-kptij'][:]
        tao = []
        ao_loc = None
        for ki, kpti in enumerate(kpts):
            for kj, kptj in enumerate(kpts):
                kpt_pair = np.array((kpti,kptj))
                Lpq = np.asarray(df._getitem(f, 'j3c', kpt_pair, kptij_lst))

                stacked_a = np.hstack((mo_kpts_a[ki], mo_kpts_a[kj][:,nocca:]))
                stacked_b = np.hstack((mo_kpts_b[ki], mo_kpts_b[kj][:,noccb:]))
                stacked_a = np.asarray(stacked_a, dtype=dtype, order='F')
                stacked_b = np.asarray(stacked_b, dtype=dtype, order='F')
                if real_case:
                    outa = _ao2mo.nr_e2(Lpq, stacked_a, slice_a, aosym='s2')
                    outb = _ao2mo.nr_e2(Lpq, stacked_b, slice_b, aosym='s2')
                else:
                    #Note: Lpq.shape[0] != naux if linear dependency is found in auxbasis
                    if Lpq[0].size != nao**2: # aosym = 's2'
                        Lpq = lib.unpack_tril(Lpq).astype(np.complex128)
                    outa = _ao2mo.r_e2(Lpq, stacked_a, slice_a, tao, ao_loc)
                    outb = _ao2mo.r_e2(Lpq, stacked_b, slice_b, tao, ao_loc)
                Lpv[ki,kj] = outa.reshape(-1,nmoa,nvira)
                LPV[ki,kj] = outb.reshape(-1,nmob,nvirb)

    return eris
1116
1117
# Attach KUCCSD to KUHF as its `CCSD` attribute through lib.class_as_method
# (presumably so that `kmf.CCSD()` builds a KUCCSD for that mean-field
# object -- confirm against lib.class_as_method).
scf.kuhf.KUHF.CCSD = lib.class_as_method(KUCCSD)
1119
1120
# Self-check script.  No SCF is run here: occupations, orbital energies and
# orbital coefficients are filled in by hand from seeded random numbers, and
# the printed values are differences from hard-coded reference fingerprints
# (presumably expected to be close to zero -- the script itself asserts
# nothing).
if __name__ == '__main__':
    from pyscf.pbc import gto
    from pyscf import lo

    # Two-atom He cell with two s-type basis functions per atom.
    cell = gto.Cell()
    cell.atom='''
    He 0.000000000000   0.000000000000   0.000000000000
    He 1.685068664391   1.685068664391   1.685068664391
    '''
    #cell.basis = [[0, (1., 1.)], [1, (.5, 1.)]]
    cell.basis = [[0, (1., 1.)], [0, (.5, 1.)]]
    cell.a = '''
    0.000000000, 3.370137329, 3.370137329
    3.370137329, 0.000000000, 3.370137329
    3.370137329, 3.370137329, 0.000000000'''
    cell.unit = 'B'
    cell.mesh = [13]*3
    cell.build()

    np.random.seed(2)
    # Set up KUHF for a 1x1x3 Monkhorst-Pack k-point mesh (kernel is not run)
    kmf = scf.KUHF(cell, kpts=cell.make_kpts([1,1,3]), exxdiv=None)
    nmo = cell.nao_nr()
    # Hand-made occupations: 3 alpha and 1 beta orbital at each of the 3 k-points.
    kmf.mo_occ = np.zeros((2,3,nmo))
    kmf.mo_occ[0,:,:3] = 1
    kmf.mo_occ[1,:,:1] = 1
    # Random orbital energies; unoccupied levels are shifted up by 2.
    kmf.mo_energy = np.arange(nmo) + np.random.random((2,3,nmo)) * .3
    kmf.mo_energy[kmf.mo_occ == 0] += 2

    # Random complex coefficients, passed through lo.orth.vec_lowdin with the
    # overlap matrix of each k-point.
    mo = (np.random.random((2,3,nmo,nmo)) +
          np.random.random((2,3,nmo,nmo))*1j - .5-.5j)
    s = kmf.get_ovlp()
    kmf.mo_coeff = np.empty_like(mo)
    nkpts = len(kmf.kpts)
    for k in range(nkpts):
        kmf.mo_coeff[0,k] = lo.orth.vec_lowdin(mo[0,k], s[k])
        kmf.mo_coeff[1,k] = lo.orth.vec_lowdin(mo[1,k], s[k])

    def rand_t1_t2(mycc):
        """Return seeded random complex (t1, t2) amplitudes shaped for mycc.

        The same-spin T2 blocks are antisymmetrized: first under simultaneous
        exchange of the two occupied indices (with their k-points), then under
        exchange of the two virtual indices.  Reads ``kmf`` from the enclosing
        script scope for the k-point conservation table.
        """
        nkpts = mycc.nkpts
        nocca, noccb = mycc.nocc
        nmoa, nmob = mycc.nmo
        nvira, nvirb = nmoa - nocca, nmob - noccb
        # Re-seed on every call so repeated calls are reproducible.
        np.random.seed(1)
        t1a = (np.random.random((nkpts,nocca,nvira)) +
               np.random.random((nkpts,nocca,nvira))*1j - .5-.5j)
        t1b = (np.random.random((nkpts,noccb,nvirb)) +
               np.random.random((nkpts,noccb,nvirb))*1j - .5-.5j)
        t2aa = (np.random.random((nkpts,nkpts,nkpts,nocca,nocca,nvira,nvira)) +
                np.random.random((nkpts,nkpts,nkpts,nocca,nocca,nvira,nvira))*1j - .5-.5j)
        kconserv = kpts_helper.get_kconserv(kmf.cell, kmf.kpts)
        t2aa = t2aa - t2aa.transpose(1,0,2,4,3,5,6)
        tmp = t2aa.copy()
        for ki, kj, kk in kpts_helper.loop_kkk(nkpts):
            kl = kconserv[ki, kk, kj]
            t2aa[ki,kj,kk] = t2aa[ki,kj,kk] - tmp[ki,kj,kl].transpose(0,1,3,2)
        t2ab = (np.random.random((nkpts,nkpts,nkpts,nocca,noccb,nvira,nvirb)) +
                np.random.random((nkpts,nkpts,nkpts,nocca,noccb,nvira,nvirb))*1j - .5-.5j)
        t2bb = (np.random.random((nkpts,nkpts,nkpts,noccb,noccb,nvirb,nvirb)) +
                np.random.random((nkpts,nkpts,nkpts,noccb,noccb,nvirb,nvirb))*1j - .5-.5j)
        t2bb = t2bb - t2bb.transpose(1,0,2,4,3,5,6)
        tmp = t2bb.copy()
        for ki, kj, kk in kpts_helper.loop_kkk(nkpts):
            kl = kconserv[ki, kk, kj]
            t2bb[ki,kj,kk] = t2bb[ki,kj,kk] - tmp[ki,kj,kl].transpose(0,1,3,2)

        t1 = (t1a, t1b)
        t2 = (t2aa, t2ab, t2bb)
        return t1, t2

    # Case 1: unequal occupations (3 alpha, 1 beta); one update_amps step
    # compared against stored fingerprints.
    mycc = KUCCSD(kmf)
    eris = mycc.ao2mo()
    t1, t2 = rand_t1_t2(mycc)
    Ht1, Ht2 = mycc.update_amps(t1, t2, eris)
    print(lib.finger(Ht1[0]) - (2.2677885702176339-2.5150764056992041j))
    print(lib.finger(Ht1[1]) - (-51.643438947846086+526.58026126100458j))
    print(lib.finger(Ht2[0]) - (-29.490813482748258-8.7509143690136018j))
    print(lib.finger(Ht2[1]) - (2256.0440056839416-193.16480896707569j))
    print(lib.finger(Ht2[2]) - (-250.59447681063182-397.57189085666982j))

    # Case 2: equal occupations (2 alpha, 2 beta).
    kmf.mo_occ[:] = 0
    kmf.mo_occ[:,:,:2] = 1
    mycc = KUCCSD(kmf)
    eris = mycc.ao2mo()
    t1, t2 = rand_t1_t2(mycc)
    Ht1, Ht2 = mycc.update_amps(t1, t2, eris)
    print(lib.finger(Ht1[0]) - (5.4622516572705662+1.990046725028729j))
    print(lib.finger(Ht1[1]) - (4.8801120611799043-5.9940463787453488j))
    print(lib.finger(Ht2[0]) - (-192.38864512375193+305.14191018543983j))
    print(lib.finger(Ht2[1]) - (23085.044505825954-11527.802302550244j))
    print(lib.finger(Ht2[2]) - (115.57932548288559-40.888597453928604j))

    # Cross-check case 2 against the generalized (GHF-based) k-point CCSD:
    # energy and updated amplitudes are compared after spatial2spin conversion.
    from pyscf.pbc.cc import kccsd
    kgcc = kccsd.GCCSD(scf.addons.convert_to_ghf(kmf))
    kccsd_eris = kccsd._make_eris_incore(kgcc, kgcc._scf.mo_coeff)
    r1 = kgcc.spatial2spin(t1)
    r2 = kgcc.spatial2spin(t2)
    ge = kccsd.energy(kgcc, r1, r2, kccsd_eris)
    r1, r2 = kgcc.update_amps(r1, r2, kccsd_eris)
    ue = energy(mycc, t1, t2, eris)
    print(abs(ge - ue))
    print(abs(r1 - kgcc.spatial2spin(Ht1)).max())
    print(abs(r2 - kgcc.spatial2spin(Ht2)).max())

    # Case 3: same system with density fitting; integrals come from
    # _make_df_eris.
    kmf = kmf.density_fit(auxbasis=[[0, (1., 1.)]])
    mycc = KUCCSD(kmf)
    eris = _make_df_eris(mycc, mycc.mo_coeff)
    t1, t2 = rand_t1_t2(mycc)
    Ht1, Ht2 = mycc.update_amps(t1, t2, eris)

    print(lib.finger(Ht1[0]) - (6.9341372555790013+0.87313546297025901j))
    print(lib.finger(Ht1[1]) - (6.7538005829391992-0.95702422534126796j))
    print(lib.finger(Ht2[0]) - (-509.24544842179876+448.00925776269855j))
    print(lib.finger(Ht2[1]) - (107.5960392010511+40.869216223808067j)  )
    print(lib.finger(Ht2[2]) - (-196.75910296082139+218.53005038057515j))
    kgcc = kccsd.GCCSD(scf.addons.convert_to_ghf(kmf))
    kccsd_eris = kccsd._make_eris_incore(kgcc, kgcc._scf.mo_coeff)
    r1 = kgcc.spatial2spin(t1)
    r2 = kgcc.spatial2spin(t2)
    ge = kccsd.energy(kgcc, r1, r2, kccsd_eris)
    r1, r2 = kgcc.update_amps(r1, r2, kccsd_eris)
    print(abs(r1 - kgcc.spatial2spin(Ht1)).max())
    print(abs(r2 - kgcc.spatial2spin(Ht2)).max())

    # Fingerprints of the individual DF integral blocks (prints one boolean).
    print(all([abs(lib.finger(eris.oooo) - (-0.18290712163391809-0.13839081039521306j)  )<1e-8,
               abs(lib.finger(eris.ooOO) - (-0.084752145202964035-0.28496525042110676j) )<1e-8,
               #abs(lib.finger(eris.OOoo) - (0.43054922768629345-0.27990237216969871j)   )<1e-8,
               abs(lib.finger(eris.OOOO) - (-0.2941475969103261-0.047247498899840978j)  )<1e-8,
               abs(lib.finger(eris.ooov) - (0.23381463349517045-0.11703340936984277j)   )<1e-8,
               abs(lib.finger(eris.ooOV) - (-0.052655392703214066+0.69533309442418556j) )<1e-8,
               abs(lib.finger(eris.OOov) - (-0.2111361247200903+0.85087916975274647j)   )<1e-8,
               abs(lib.finger(eris.OOOV) - (-0.36995992208047412-0.18887278030885621j)  )<1e-8,
               abs(lib.finger(eris.oovv) - (0.21107397525051516+0.0048714991438174871j) )<1e-8,
               abs(lib.finger(eris.ooVV) - (-0.076411225687065987+0.11080438166425896j) )<1e-8,
               abs(lib.finger(eris.OOvv) - (-0.17880337626095003-0.24174716216954206j)  )<1e-8,
               abs(lib.finger(eris.OOVV) - (0.059186286356424908+0.68433866387500164j)  )<1e-8,
               abs(lib.finger(eris.ovov) - (0.15402983765151051+0.064359681685222214j)  )<1e-8,
               abs(lib.finger(eris.ovOV) - (-0.10697649196044598+0.30351249676253234j)  )<1e-8,
               #abs(lib.finger(eris.OVov) - (-0.17619329728836752-0.56585020976035816j)  )<1e-8,
               abs(lib.finger(eris.OVOV) - (-0.63963235318492118+0.69863219317718828j)  )<1e-8,
               abs(lib.finger(eris.voov) - (-0.24137641647339092+0.18676684336011531j)  )<1e-8,
               abs(lib.finger(eris.voOV) - (0.19257709151227204+0.38929027819406414j)   )<1e-8,
               #abs(lib.finger(eris.VOov) - (0.07632606729926053-0.70350947950650355j)   )<1e-8,
               abs(lib.finger(eris.VOOV) - (-0.47970203195500816+0.46735207193861927j)  )<1e-8,
               abs(lib.finger(eris.vovv) - (-0.1342049915673903-0.23391327821719513j)   )<1e-8,
               abs(lib.finger(eris.voVV) - (-0.28989635223866056+0.9644368822688475j)   )<1e-8,
               abs(lib.finger(eris.VOvv) - (-0.32428269235420271+0.0029847254383674748j))<1e-8,
               abs(lib.finger(eris.VOVV) - (0.45031779746222456-0.36858577475752041j)   )<1e-8]))

    # Same fingerprints for the out-of-core builder, plus the all-virtual
    # blocks that only this builder stores.
    eris = _make_eris_outcore(mycc, mycc.mo_coeff)
    print(all([abs(lib.finger(eris.oooo) - (-0.18290712163391809-0.13839081039521306j)  )<1e-8,
               abs(lib.finger(eris.ooOO) - (-0.084752145202964035-0.28496525042110676j) )<1e-8,
               #abs(lib.finger(eris.OOoo) - (0.43054922768629345-0.27990237216969871j)   )<1e-8,
               abs(lib.finger(eris.OOOO) - (-0.2941475969103261-0.047247498899840978j)  )<1e-8,
               abs(lib.finger(eris.ooov) - (0.23381463349517045-0.11703340936984277j)   )<1e-8,
               abs(lib.finger(eris.ooOV) - (-0.052655392703214066+0.69533309442418556j) )<1e-8,
               abs(lib.finger(eris.OOov) - (-0.2111361247200903+0.85087916975274647j)   )<1e-8,
               abs(lib.finger(eris.OOOV) - (-0.36995992208047412-0.18887278030885621j)  )<1e-8,
               abs(lib.finger(eris.oovv) - (0.21107397525051516+0.0048714991438174871j) )<1e-8,
               abs(lib.finger(eris.ooVV) - (-0.076411225687065987+0.11080438166425896j) )<1e-8,
               abs(lib.finger(eris.OOvv) - (-0.17880337626095003-0.24174716216954206j)  )<1e-8,
               abs(lib.finger(eris.OOVV) - (0.059186286356424908+0.68433866387500164j)  )<1e-8,
               abs(lib.finger(eris.ovov) - (0.15402983765151051+0.064359681685222214j)  )<1e-8,
               abs(lib.finger(eris.ovOV) - (-0.10697649196044598+0.30351249676253234j)  )<1e-8,
               #abs(lib.finger(eris.OVov) - (-0.17619329728836752-0.56585020976035816j)  )<1e-8,
               abs(lib.finger(eris.OVOV) - (-0.63963235318492118+0.69863219317718828j)  )<1e-8,
               abs(lib.finger(eris.voov) - (-0.24137641647339092+0.18676684336011531j)  )<1e-8,
               abs(lib.finger(eris.voOV) - (0.19257709151227204+0.38929027819406414j)   )<1e-8,
               #abs(lib.finger(eris.VOov) - (0.07632606729926053-0.70350947950650355j)   )<1e-8,
               abs(lib.finger(eris.VOOV) - (-0.47970203195500816+0.46735207193861927j)  )<1e-8,
               abs(lib.finger(eris.vovv) - (-0.1342049915673903-0.23391327821719513j)   )<1e-8,
               abs(lib.finger(eris.voVV) - (-0.28989635223866056+0.9644368822688475j)   )<1e-8,
               abs(lib.finger(eris.VOvv) - (-0.32428269235420271+0.0029847254383674748j))<1e-8,
               abs(lib.finger(eris.VOVV) - (0.45031779746222456-0.36858577475752041j)   )<1e-8,
               abs(lib.finger(eris.vvvv) - (-0.080512851258903173-0.2868384266725581j)  )<1e-8,
               abs(lib.finger(eris.vvVV) - (-0.5137063762484736+1.1036785801263898j)    )<1e-8,
               #abs(lib.finger(eris.VVvv) - (0.16468487082491939+0.25730725586992997j)   )<1e-8,
               abs(lib.finger(eris.VVVV) - (-0.56714875196802295+0.058636785679170501j) )<1e-8]))
