1#!/usr/bin/env python
2# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Qiming Sun <osirpt.sun@gmail.com>
17#
18
19
20import numpy
21from pyscf import lib
22from pyscf.lib import logger
23from pyscf.cc import uccsd
24from pyscf.cc import ccsd_lambda
25
26einsum = lib.einsum
27
def kernel(mycc, eris=None, t1=None, t2=None, l1=None, l2=None,
           max_cycle=50, tol=1e-8, verbose=logger.INFO):
    """Run ``ccsd_lambda.kernel`` with this module's intermediates/update.

    When ``eris`` is None the integrals are obtained from ``mycc.ao2mo()``.
    All remaining arguments are forwarded positionally, unchanged, followed
    by ``make_intermediates`` and ``update_lambda`` from this module.

    Returns:
        Whatever ``ccsd_lambda.kernel`` returns; the script section of this
        file unpacks it as ``conv, l1, l2``.
    """
    if eris is None:
        eris = mycc.ao2mo()
    driver_args = (mycc, eris, t1, t2, l1, l2, max_cycle, tol, verbose,
                   make_intermediates, update_lambda)
    return ccsd_lambda.kernel(*driver_args)
33
34# l2, t2 as ijab
def make_intermediates(mycc, t1, t2, eris):
    """Build the t1/t2-dependent intermediates that ``update_lambda`` reads.

    In the einsum strings, lowercase index letters label the ``a``
    quantities (``t1a``, ``focka``, ``eris.ovov`` ...) and uppercase letters
    label the ``b`` quantities (``t1b``, ``fockb``, ``eris.OVOV`` ...) --
    presumably alpha and beta spin; confirm against the ``uccsd`` module.

    Args:
        mycc: CC object.  It is not referenced inside this function.
        t1: pair ``(t1a, t1b)``; ``t1a.shape`` is ``(nocca, nvira)`` and
            ``t1b.shape`` is ``(noccb, nvirb)``.
        t2: triple ``(t2aa, t2ab, t2bb)``, indexed as ijab.
        eris: integral container.  This function reads ``focka``, ``fockb``,
            ``ovov``, ``OVOV``, ``ovOV``, ``ovoo``, ``OVOO``, ``OVoo``,
            ``ovOO``, ``oooo``, ``OOOO``, ``ooOO``, ``ovvo``, ``oovv``,
            ``OVVO``, ``OOVV``, ``OVvo``, ``ovVO``, ``ooVV``, ``OOvv``,
            ``vvvv`` (only to determine the dtype) and calls ``get_ovvv()``,
            ``get_OVVV()``, ``get_OVvv()`` and ``get_ovVV()``.

    Returns:
        A bare ``_IMDS`` object with:

        * ``ftmp``: a ``lib.H5TmpFile`` holding the four-index ``w*``
          datasets (``woooo`` ... ``wvvVO``), also exposed as attributes of
          the same names;
        * ``v1a``, ``v1b``, ``v2a``, ``v2b``, ``w3a``, ``w3b``: in-memory
          two-index arrays.
    """
    t1a, t1b = t1
    t2aa, t2ab, t2bb = t2
    nocca, nvira = t1a.shape
    noccb, nvirb = t1b.shape

    # Fock matrix sub-blocks, split at nocca (focka) and noccb (fockb).
    fooa = eris.focka[:nocca,:nocca]
    fova = eris.focka[:nocca,nocca:]
    fvoa = eris.focka[nocca:,:nocca]
    fvva = eris.focka[nocca:,nocca:]
    foob = eris.fockb[:noccb,:noccb]
    fovb = eris.fockb[:noccb,noccb:]
    fvob = eris.fockb[noccb:,:noccb]
    fvvb = eris.fockb[noccb:,noccb:]

    # tau is produced by uccsd.make_tau; t1 is passed for both t1 arguments.
    tauaa, tauab, taubb = uccsd.make_tau(t2, t1, t1)

    # For the aa and bb blocks, subtract the copy with axes 1 and 3 exchanged.
    # The mixed block ovOV is used as stored.
    ovov = numpy.asarray(eris.ovov)
    ovov = ovov - ovov.transpose(0,3,2,1)
    OVOV = numpy.asarray(eris.OVOV)
    OVOV = OVOV - OVOV.transpose(0,3,2,1)
    ovOV = numpy.asarray(eris.ovOV)

    # v1a / v1b: two-index intermediates seeded with the fvv Fock blocks.
    v1a  = fvva - einsum('ja,jb->ba', fova, t1a)
    v1b  = fvvb - einsum('ja,jb->ba', fovb, t1b)
    v1a += einsum('jcka,jkbc->ba', ovov, tauaa) * .5
    v1a -= einsum('jaKC,jKbC->ba', ovOV, tauab) * .5
    v1a -= einsum('kaJC,kJbC->ba', ovOV, tauab) * .5
    v1b += einsum('jcka,jkbc->ba', OVOV, taubb) * .5
    v1b -= einsum('kcJA,kJcB->BA', ovOV, tauab) * .5
    v1b -= einsum('jcKA,jKcB->BA', ovOV, tauab) * .5

    # v2a / v2b: two-index intermediates seeded with the foo Fock blocks.
    v2a  = fooa + einsum('ib,jb->ij', fova, t1a)
    v2b  = foob + einsum('ib,jb->ij', fovb, t1b)
    v2a += einsum('ibkc,jkbc->ij', ovov, tauaa) * .5
    v2a += einsum('ibKC,jKbC->ij', ovOV, tauab)
    v2b += einsum('ibkc,jkbc->ij', OVOV, taubb) * .5
    v2b += einsum('kcIB,kJcB->IJ', ovOV, tauab)

    # For ovoo / OVOO, subtract the copy with axes 0 and 2 exchanged.
    # The mixed blocks OVoo and ovOO are used as stored.
    ovoo = numpy.asarray(eris.ovoo)
    ovoo = ovoo - ovoo.transpose(2,1,0,3)
    OVOO = numpy.asarray(eris.OVOO)
    OVOO = OVOO - OVOO.transpose(2,1,0,3)
    OVoo = numpy.asarray(eris.OVoo)
    ovOO = numpy.asarray(eris.ovOO)
    v2a -= numpy.einsum('ibkj,kb->ij', ovoo, t1a)
    v2a += numpy.einsum('KBij,KB->ij', OVoo, t1b)
    v2b -= numpy.einsum('ibkj,kb->ij', OVOO, t1b)
    v2b += numpy.einsum('kbIJ,kb->IJ', ovOO, t1a)

    # v5a / v5b: seeded with the fvo Fock blocks.  They are not returned on
    # their own; they are added into w3a / w3b near the end of the function.
    v5a  = fvoa + numpy.einsum('kc,jkbc->bj', fova, t2aa)
    v5a += numpy.einsum('KC,jKbC->bj', fovb, t2ab)
    v5b  = fvob + numpy.einsum('kc,jkbc->bj', fovb, t2bb)
    v5b += numpy.einsum('kc,kJcB->BJ', fova, t2ab)
    tmp  = fova - numpy.einsum('kdlc,ld->kc', ovov, t1a)
    tmp += numpy.einsum('kcLD,LD->kc', ovOV, t1b)
    v5a += einsum('kc,kb,jc->bj', tmp, t1a, t1a)
    tmp  = fovb - numpy.einsum('kdlc,ld->kc', OVOV, t1b)
    tmp += numpy.einsum('ldKC,ld->KC', ovOV, t1a)
    v5b += einsum('kc,kb,jc->bj', tmp, t1b, t1b)
    v5a -= einsum('lckj,klbc->bj', ovoo, t2aa) * .5
    v5a -= einsum('LCkj,kLbC->bj', OVoo, t2ab)
    v5b -= einsum('LCKJ,KLBC->BJ', OVOO, t2bb) * .5
    v5b -= einsum('lcKJ,lKcB->BJ', ovOO, t2ab)

    # woooo / wOOOO / wooOO: four-index intermediates over the first
    # nocca / noccb orbitals.
    oooo = numpy.asarray(eris.oooo)
    OOOO = numpy.asarray(eris.OOOO)
    ooOO = numpy.asarray(eris.ooOO)
    woooo  = einsum('icjl,kc->ikjl', ovoo, t1a)
    wOOOO  = einsum('icjl,kc->ikjl', OVOO, t1b)
    wooOO  = einsum('icJL,kc->ikJL', ovOO, t1a)
    wooOO += einsum('JCil,KC->ilJK', OVoo, t1b)
    woooo += (oooo - oooo.transpose(0,3,2,1)) * .5
    wOOOO += (OOOO - OOOO.transpose(0,3,2,1)) * .5
    wooOO += ooOO.copy()
    woooo += einsum('icjd,klcd->ikjl', ovov, tauaa) * .25
    wOOOO += einsum('icjd,klcd->ikjl', OVOV, taubb) * .25
    wooOO += einsum('icJD,kLcD->ikJL', ovOV, tauab)

    # v4*: ov-type intermediates built from t2 and the stored integral
    # blocks.  They feed woovo-type, wvvvo-type and w3 below, and later
    # become the wovvo-type outputs (see the aliasing note further down).
    v4ovvo  = einsum('jbld,klcd->jbck', ovov, t2aa)
    v4ovvo += einsum('jbLD,kLcD->jbck', ovOV, t2ab)
    v4ovvo += numpy.asarray(eris.ovvo)
    v4ovvo -= numpy.asarray(eris.oovv).transpose(0,3,2,1)
    v4OVVO  = einsum('jbld,klcd->jbck', OVOV, t2bb)
    v4OVVO += einsum('ldJB,lKdC->JBCK', ovOV, t2ab)
    v4OVVO += numpy.asarray(eris.OVVO)
    v4OVVO -= numpy.asarray(eris.OOVV).transpose(0,3,2,1)
    v4OVvo  = einsum('ldJB,klcd->JBck', ovOV, t2aa)
    v4OVvo += einsum('JBLD,kLcD->JBck', OVOV, t2ab)
    v4OVvo += numpy.asarray(eris.OVvo)
    v4ovVO  = einsum('jbLD,KLCD->jbCK', ovOV, t2bb)
    v4ovVO += einsum('jbld,lKdC->jbCK', ovov, t2ab)
    v4ovVO += numpy.asarray(eris.ovVO)
    v4oVVo  = einsum('jdLB,kLdC->jBCk', ovOV, t2ab)
    v4oVVo -= numpy.asarray(eris.ooVV).transpose(0,3,2,1)
    v4OvvO  = einsum('lbJD,lKcD->JbcK', ovOV, t2ab)
    v4OvvO -= numpy.asarray(eris.OOvv).transpose(0,3,2,1)

    # woovo-type intermediates; these read the v4* arrays as they stand here.
    woovo  = einsum('ibck,jb->ijck', v4ovvo, t1a)
    wOOVO  = einsum('ibck,jb->ijck', v4OVVO, t1b)
    wOOvo  = einsum('IBck,JB->IJck', v4OVvo, t1b)
    wOOvo -= einsum('IbcK,jb->IKcj', v4OvvO, t1a)
    wooVO  = einsum('ibCK,jb->ijCK', v4ovVO, t1a)
    wooVO -= einsum('iBCk,JB->ikCJ', v4oVVo, t1b)
    woovo += ovoo.conj().transpose(3,2,1,0) * .5
    wOOVO += OVOO.conj().transpose(3,2,1,0) * .5
    wooVO += OVoo.conj().transpose(3,2,1,0)
    wOOvo += ovOO.conj().transpose(3,2,1,0)
    woovo -= einsum('iclk,jlbc->ikbj', ovoo, t2aa)
    woovo += einsum('LCik,jLbC->ikbj', OVoo, t2ab)
    wOOVO -= einsum('iclk,jlbc->ikbj', OVOO, t2bb)
    wOOVO += einsum('lcIK,lJcB->IKBJ', ovOO, t2ab)
    wooVO -= einsum('iclk,lJcB->ikBJ', ovoo, t2ab)
    wooVO += einsum('LCik,JLBC->ikBJ', OVoo, t2bb)
    wooVO -= einsum('icLK,jLcB->ijBK', ovOO, t2ab)
    wOOvo -= einsum('ICLK,jLbC->IKbj', OVOO, t2ab)
    wOOvo += einsum('lcIK,jlbc->IKbj', ovOO, t2aa)
    wOOvo -= einsum('IClk,lJbC->IJbk', OVoo, t2ab)

    # wvvvo-type intermediates; these also read the v4* arrays as they stand.
    wvvvo  = einsum('jack,jb->back', v4ovvo, t1a)
    wVVVO  = einsum('jack,jb->back', v4OVVO, t1b)
    wVVvo  = einsum('JAck,JB->BAck', v4OVvo, t1b)
    wVVvo -= einsum('jACk,jb->CAbk', v4oVVo, t1a)
    wvvVO  = einsum('jaCK,jb->baCK', v4ovVO, t1a)
    wvvVO -= einsum('JacK,JB->caBK', v4OvvO, t1b)
    wvvvo += einsum('lajk,jlbc->back', .25*ovoo, tauaa)
    wVVVO += einsum('lajk,jlbc->back', .25*OVOO, taubb)
    wVVvo -= einsum('LAjk,jLcB->BAck', OVoo, tauab)
    wvvVO -= einsum('laJK,lJbC->baCK', ovOO, tauab)

    # w3a / w3b: contractions of the v4* arrays with t1; more terms are added
    # after the ovvv-type blocks below.
    w3a  = numpy.einsum('jbck,jb->ck', v4ovvo, t1a)
    w3a += numpy.einsum('JBck,JB->ck', v4OVvo, t1b)
    w3b  = numpy.einsum('jbck,jb->ck', v4OVVO, t1b)
    w3b += numpy.einsum('jbCK,jb->CK', v4ovVO, t1a)

    # NOTE: the names below are aliases of the v4* arrays (no copy is made).
    # The += / -= statements that follow therefore modify the v4* arrays in
    # place.  Every statement above that reads v4* (woovo-type, wvvvo-type,
    # w3a/w3b) must stay ahead of this point.
    wovvo  = v4ovvo
    wOVVO  = v4OVVO
    wovVO  = v4ovVO
    wOVvo  = v4OVvo
    woVVo  = v4oVVo
    wOvvO  = v4OvvO
    wovvo += lib.einsum('jbld,kd,lc->jbck', ovov, t1a, -t1a)
    wOVVO += lib.einsum('jbld,kd,lc->jbck', OVOV, t1b, -t1b)
    wovVO += lib.einsum('jbLD,KD,LC->jbCK', ovOV, t1b, -t1b)
    wOVvo += lib.einsum('ldJB,kd,lc->JBck', ovOV, t1a, -t1a)
    woVVo += lib.einsum('jdLB,kd,LC->jBCk', ovOV, t1a,  t1b)
    wOvvO += lib.einsum('lbJD,KD,lc->JbcK', ovOV, t1b,  t1a)
    wovvo -= einsum('jblk,lc->jbck', ovoo, t1a)
    wOVVO -= einsum('jblk,lc->jbck', OVOO, t1b)
    wovVO -= einsum('jbLK,LC->jbCK', ovOO, t1b)
    wOVvo -= einsum('JBlk,lc->JBck', OVoo, t1a)
    woVVo += einsum('LBjk,LC->jBCk', OVoo, t1b)
    wOvvO += einsum('lbJK,lc->JbcK', ovOO, t1a)

    # Contributions from the ovvv-type integrals.  Each block is fetched via
    # its eris.get_*() accessor only when the relevant dimensions are
    # non-zero, and its reference is dropped (set to None) once consumed.
    if nvira > 0 and nocca > 0:
        ovvv = numpy.asarray(eris.get_ovvv())
        ovvv = ovvv - ovvv.transpose(0,3,2,1)
        v1a -= numpy.einsum('jabc,jc->ba', ovvv, t1a)
        v5a += einsum('kdbc,jkcd->bj', ovvv, t2aa) * .5
        woovo += einsum('idcb,kjbd->ijck', ovvv, tauaa) * .25
        wovvo += einsum('jbcd,kd->jbck', ovvv, t1a)
        wvvvo -= ovvv.conj().transpose(3,2,1,0) * .5
        wvvvo += einsum('jacd,kjbd->cabk', ovvv, t2aa)
        wvvVO += einsum('jacd,jKdB->caBK', ovvv, t2ab)
        ovvv = tmp = None

    if nvirb > 0 and noccb > 0:
        OVVV = numpy.asarray(eris.get_OVVV())
        OVVV = OVVV - OVVV.transpose(0,3,2,1)
        v1b -= numpy.einsum('jabc,jc->ba', OVVV, t1b)
        v5b += einsum('KDBC,JKCD->BJ', OVVV, t2bb) * .5
        wOOVO += einsum('idcb,kjbd->ijck', OVVV, taubb) * .25
        wOVVO += einsum('jbcd,kd->jbck', OVVV, t1b)
        wVVVO -= OVVV.conj().transpose(3,2,1,0) * .5
        wVVVO += einsum('jacd,kjbd->cabk', OVVV, t2bb)
        wVVvo += einsum('JACD,kJbD->CAbk', OVVV, t2ab)
        OVVV = tmp = None

    if nvirb > 0 and nocca > 0:
        OVvv = numpy.asarray(eris.get_OVvv())
        v1a += numpy.einsum('JCba,JC->ba', OVvv, t1b)
        v5a += einsum('KDbc,jKcD->bj', OVvv, t2ab)
        wOOvo += einsum('IDcb,kJbD->IJck', OVvv, tauab)
        wOVvo += einsum('JBcd,kd->JBck', OVvv, t1a)
        wOvvO -= einsum('JDcb,KD->JbcK', OVvv, t1b)
        wvvVO -= OVvv.conj().transpose(3,2,1,0)
        wvvvo -= einsum('KDca,jKbD->cabj', OVvv, t2ab)
        wvvVO -= einsum('KDca,JKBD->caBJ', OVvv, t2bb)
        wVVvo += einsum('KAcd,jKdB->BAcj', OVvv, t2ab)
        OVvv = tmp = None

    if nvira > 0 and noccb > 0:
        ovVV = numpy.asarray(eris.get_ovVV())
        v1b += numpy.einsum('jcBA,jc->BA', ovVV, t1a)
        v5b += einsum('kdBC,kJdC->BJ', ovVV, t2ab)
        wooVO += einsum('idCB,jKdB->ijCK', ovVV, tauab)
        wovVO += einsum('jbCD,KD->jbCK', ovVV, t1b)
        woVVo -= einsum('jdCB,kd->jBCk', ovVV, t1a)
        wVVvo -= ovVV.conj().transpose(3,2,1,0)
        wVVVO -= einsum('kdCA,kJdB->CABJ', ovVV, t2ab)
        wVVvo -= einsum('kdCA,jkbd->CAbj', ovVV, t2aa)
        wvvVO += einsum('kaCD,kJbD->baCJ', ovVV, t2ab)
        ovVV = tmp = None

    # Finish w3 using the completed v5, v1 and v2 (all ovvv-type terms are
    # already included at this point).
    w3a += v5a
    w3b += v5b
    w3a += lib.einsum('cb,jb->cj', v1a, t1a)
    w3b += lib.einsum('cb,jb->cj', v1b, t1b)
    w3a -= lib.einsum('jk,jb->bk', v2a, t1a)
    w3b -= lib.einsum('jk,jb->bk', v2b, t1b)

    # Plain attribute container for the results.
    class _IMDS: pass
    imds = _IMDS()
    # Four-index intermediates are written to datasets of a temporary HDF5
    # file; the dtype is the common result type of t2ab and eris.vvvv.
    imds.ftmp = lib.H5TmpFile()
    dtype = numpy.result_type(t2ab, eris.vvvv).char
    imds.woooo = imds.ftmp.create_dataset('woooo', (nocca,nocca,nocca,nocca), dtype)
    imds.wooOO = imds.ftmp.create_dataset('wooOO', (nocca,nocca,noccb,noccb), dtype)
    imds.wOOOO = imds.ftmp.create_dataset('wOOOO', (noccb,noccb,noccb,noccb), dtype)
    imds.wovvo = imds.ftmp.create_dataset('wovvo', (nocca,nvira,nvira,nocca), dtype)
    imds.wOVVO = imds.ftmp.create_dataset('wOVVO', (noccb,nvirb,nvirb,noccb), dtype)
    imds.wovVO = imds.ftmp.create_dataset('wovVO', (nocca,nvira,nvirb,noccb), dtype)
    imds.wOVvo = imds.ftmp.create_dataset('wOVvo', (noccb,nvirb,nvira,nocca), dtype)
    imds.woVVo = imds.ftmp.create_dataset('woVVo', (nocca,nvirb,nvirb,nocca), dtype)
    imds.wOvvO = imds.ftmp.create_dataset('wOvvO', (noccb,nvira,nvira,noccb), dtype)
    imds.woovo = imds.ftmp.create_dataset('woovo', (nocca,nocca,nvira,nocca), dtype)
    imds.wOOVO = imds.ftmp.create_dataset('wOOVO', (noccb,noccb,nvirb,noccb), dtype)
    imds.wOOvo = imds.ftmp.create_dataset('wOOvo', (noccb,noccb,nvira,nocca), dtype)
    imds.wooVO = imds.ftmp.create_dataset('wooVO', (nocca,nocca,nvirb,noccb), dtype)
    imds.wvvvo = imds.ftmp.create_dataset('wvvvo', (nvira,nvira,nvira,nocca), dtype)
    imds.wVVVO = imds.ftmp.create_dataset('wVVVO', (nvirb,nvirb,nvirb,noccb), dtype)
    imds.wVVvo = imds.ftmp.create_dataset('wVVvo', (nvirb,nvirb,nvira,nocca), dtype)
    imds.wvvVO = imds.ftmp.create_dataset('wvvVO', (nvira,nvira,nvirb,noccb), dtype)

    # Copy the in-memory arrays into the datasets created above.
    imds.woooo[:] = woooo
    imds.wOOOO[:] = wOOOO
    imds.wooOO[:] = wooOO
    imds.wovvo[:] = wovvo
    imds.wOVVO[:] = wOVVO
    imds.wovVO[:] = wovVO
    imds.wOVvo[:] = wOVvo
    imds.woVVo[:] = woVVo
    imds.wOvvO[:] = wOvvO
    imds.woovo[:] = woovo
    imds.wOOVO[:] = wOOVO
    imds.wOOvo[:] = wOOvo
    imds.wooVO[:] = wooVO
    imds.wvvvo[:] = wvvvo
    imds.wVVVO[:] = wVVVO
    imds.wVVvo[:] = wVVvo
    imds.wvvVO[:] = wvvVO
    # Two-index intermediates stay in memory as plain attributes.
    imds.v1a = v1a
    imds.v1b = v1b
    imds.v2a = v2a
    imds.v2b = v2b
    imds.w3a = w3a
    imds.w3b = w3b
    imds.ftmp.flush()
    return imds
293
294
295# update L1, L2
def update_lambda(mycc, t1, t2, l1, l2, eris, imds):
    """Compute updated lambda amplitudes from the current ones.

    In the einsum strings, lowercase index letters label the ``a``
    quantities and uppercase letters the ``b`` quantities (presumably alpha
    and beta spin; confirm against the ``uccsd`` module).

    Args:
        mycc: CC object.  Reads ``stdout``, ``verbose`` and ``level_shift``
            and calls ``_add_vvvv``.
        t1: pair ``(t1a, t1b)``; shapes ``(nocca, nvira)`` / ``(noccb, nvirb)``.
        t2: triple ``(t2aa, t2ab, t2bb)``, indexed as ijab.
        l1: pair ``(l1a, l1b)`` of current lambda singles.
        l2: triple ``(l2aa, l2ab, l2bb)`` of current lambda doubles, ijab.
        eris: integral container.  Reads ``mo_energy``, ``focka``, ``fockb``,
            the ``ovov``/``ovoo``/``ovvo``/``oovv``-type blocks and the
            ``get_ovvv()``-type accessors.
        imds: object carrying the attributes set by ``make_intermediates``
            (``v1a``, ``v1b``, ``v2a``, ``v2b``, ``w3a``, ``w3b`` and the
            four-index ``w*`` datasets).

    Returns:
        ``((u1a, u1b), (u2aa, u2ab, u2bb))``: new arrays, allocated with
        ``numpy.zeros_like`` of the corresponding ``l1``/``l2`` blocks, after
        division by the orbital-energy denominators.
    """
    time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)

    t1a, t1b = t1
    t2aa, t2ab, t2bb = t2
    l1a, l1b = l1
    l2aa, l2ab, l2bb = l2
    nocca, nvira = t1a.shape
    noccb, nvirb = t1b.shape
    # Accumulators for the new amplitudes; the inputs l1/l2 are not modified.
    u1a = numpy.zeros_like(l1a)
    u1b = numpy.zeros_like(l1b)
    u2aa = numpy.zeros_like(l2aa)
    u2ab = numpy.zeros_like(l2ab)
    u2bb = numpy.zeros_like(l2bb)
    # Orbital energies split at nocca / noccb; mycc.level_shift is added to
    # the second (virtual) part only.
    mo_ea_o = eris.mo_energy[0][:nocca]
    mo_ea_v = eris.mo_energy[0][nocca:] + mycc.level_shift
    mo_eb_o = eris.mo_energy[1][:noccb]
    mo_eb_v = eris.mo_energy[1][noccb:] + mycc.level_shift

    fova = eris.focka[:nocca,nocca:]
    fovb = eris.fockb[:noccb,noccb:]
    # Subtract the (shifted) orbital energies from the diagonals of the
    # v1/v2 intermediates.  The same energies form the denominators applied
    # at the end of this function.
    v1a = imds.v1a - numpy.diag(mo_ea_v)
    v1b = imds.v1b - numpy.diag(mo_eb_v)
    v2a = imds.v2a - numpy.diag(mo_ea_o)
    v2b = imds.v2b - numpy.diag(mo_eb_o)

    # Two-index contractions of l2 with t2.
    mvv = einsum('klca,klcb->ba', l2aa, t2aa) * .5
    mvv+= einsum('lKaC,lKbC->ba', l2ab, t2ab)
    mVV = einsum('klca,klcb->ba', l2bb, t2bb) * .5
    mVV+= einsum('kLcA,kLcB->BA', l2ab, t2ab)
    moo = einsum('kicd,kjcd->ij', l2aa, t2aa) * .5
    moo+= einsum('iKdC,jKdC->ij', l2ab, t2ab)
    mOO = einsum('kicd,kjcd->ij', l2bb, t2bb) * .5
    mOO+= einsum('kIcD,kJcD->IJ', l2ab, t2ab)

    # m3 starts from mycc._add_vvvv applied to the conjugated l2, with the
    # result conjugated back.  The original author's sketch of the intended
    # contraction was:
    #m3 = lib.einsum('ijcd,cdab->ijab', l2, eris.vvvv) * .5
    m3aa, m3ab, m3bb = mycc._add_vvvv(None, (l2aa.conj(),l2ab.conj(),l2bb.conj()), eris)
    m3aa = m3aa.conj()
    m3ab = m3ab.conj()
    m3bb = m3bb.conj()
    m3aa += lib.einsum('klab,ikjl->ijab', l2aa, numpy.asarray(imds.woooo))
    m3bb += lib.einsum('klab,ikjl->ijab', l2bb, numpy.asarray(imds.wOOOO))
    m3ab += lib.einsum('kLaB,ikJL->iJaB', l2ab, numpy.asarray(imds.wooOO))

    # For the aa and bb blocks, subtract the copy with axes 1 and 3
    # exchanged; ovOV is used as stored.
    ovov = numpy.asarray(eris.ovov)
    ovov = ovov - ovov.transpose(0,3,2,1)
    OVOV = numpy.asarray(eris.OVOV)
    OVOV = OVOV - OVOV.transpose(0,3,2,1)
    ovOV = numpy.asarray(eris.ovOV)
    # mvv/moo plus the l1*t1 contributions.
    mvv1 = einsum('jc,jb->bc', l1a, t1a) + mvv
    mVV1 = einsum('jc,jb->bc', l1b, t1b) + mVV
    moo1 = einsum('ic,kc->ik', l1a, t1a) + moo
    mOO1 = einsum('ic,kc->ik', l1b, t1b) + mOO
    # Terms involving the ovvv-type integrals.  Each block is fetched only
    # when the relevant dimensions are non-zero and released afterwards.
    # Because the v2/moo1 terms of u2 are added inside these blocks, they are
    # skipped as well when the guard is false.
    if nvira > 0 and nocca > 0:
        ovvv = numpy.asarray(eris.get_ovvv())
        ovvv = ovvv - ovvv.transpose(0,3,2,1)
        tmp = lib.einsum('ijcd,kd->ijck', l2aa, t1a)
        m3aa -= lib.einsum('kbca,ijck->ijab', ovvv, tmp)

        tmp = einsum('ic,jbca->jiba', l1a, ovvv)
        tmp+= einsum('kiab,jk->ijab', l2aa, v2a)
        tmp-= einsum('ik,kajb->ijab', moo1, ovov)
        # Antisymmetrize over the first two indices.
        u2aa += tmp - tmp.transpose(1,0,2,3)
        u1a += numpy.einsum('iacb,bc->ia', ovvv, mvv1)
        ovvv = tmp = None

    if nvirb > 0 and noccb > 0:
        OVVV = numpy.asarray(eris.get_OVVV())
        OVVV = OVVV - OVVV.transpose(0,3,2,1)
        tmp = lib.einsum('ijcd,kd->ijck', l2bb, t1b)
        m3bb -= lib.einsum('kbca,ijck->ijab', OVVV, tmp)

        tmp = einsum('ic,jbca->jiba', l1b, OVVV)
        tmp+= einsum('kiab,jk->ijab', l2bb, v2b)
        tmp-= einsum('ik,kajb->ijab', mOO1, OVOV)
        u2bb += tmp - tmp.transpose(1,0,2,3)
        u1b += numpy.einsum('iaCB,BC->ia', OVVV, mVV1)
        OVVV = tmp = None

    if nvirb > 0 and nocca > 0:
        OVvv = numpy.asarray(eris.get_OVvv())
        tmp = lib.einsum('iJcD,KD->iJcK', l2ab, t1b)
        m3ab -= lib.einsum('KBca,iJcK->iJaB', OVvv, tmp)

        # tmp is assembled with its first two axes in the opposite order to
        # u2ab, hence the transpose when it is added.
        tmp = einsum('ic,JAcb->JibA', l1a, OVvv)
        tmp-= einsum('kIaB,jk->IjaB', l2ab, v2a)
        tmp-= einsum('IK,jaKB->IjaB', mOO1, ovOV)
        u2ab += tmp.transpose(1,0,2,3)
        u1b += numpy.einsum('iacb,bc->ia', OVvv, mvv1)
        OVvv = tmp = None

    if nvira > 0 and noccb > 0:
        ovVV = numpy.asarray(eris.get_ovVV())
        tmp = lib.einsum('iJdC,kd->iJCk', l2ab, t1a)
        m3ab -= lib.einsum('kaCB,iJCk->iJaB', ovVV, tmp)

        tmp = einsum('IC,jbCA->jIbA', l1b, ovVV)
        tmp-= einsum('iKaB,JK->iJaB', l2ab, v2b)
        tmp-= einsum('ik,kaJB->iJaB', moo1, ovOV)
        u2ab += tmp
        u1a += numpy.einsum('iaCB,BC->ia', ovVV, mVV1)
        ovVV = tmp = None

    # l2 * tau * ovov contributions to m3.
    # NOTE(review): ovov, OVOV and ovOV are rebuilt here with exactly the
    # same expressions as above; nothing in between rebinds or modifies them,
    # so the recomputation looks redundant.
    tauaa, tauab, taubb = uccsd.make_tau(t2, t1, t1)
    tmp = lib.einsum('ijcd,klcd->ijkl', l2aa, tauaa)
    ovov = numpy.asarray(eris.ovov)
    ovov = ovov - ovov.transpose(0,3,2,1)
    m3aa += lib.einsum('kalb,ijkl->ijab', ovov, tmp) * .25

    tmp = lib.einsum('ijcd,klcd->ijkl', l2bb, taubb)
    OVOV = numpy.asarray(eris.OVOV)
    OVOV = OVOV - OVOV.transpose(0,3,2,1)
    m3bb += lib.einsum('kalb,ijkl->ijab', OVOV, tmp) * .25

    tmp = lib.einsum('iJcD,kLcD->iJkL', l2ab, tauab)
    ovOV = numpy.asarray(eris.ovOV)
    m3ab += lib.einsum('kaLB,iJkL->iJaB', ovOV, tmp) * .5
    tmp = lib.einsum('iJdC,lKdC->iJKl', l2ab, tauab)
    m3ab += lib.einsum('laKB,iJKl->iJaB', ovOV, tmp) * .5

    # m3 contracted with t1 goes into u1; m3 itself goes into u2.
    u1a += numpy.einsum('ijab,jb->ia', m3aa, t1a)
    u1a += numpy.einsum('iJaB,JB->ia', m3ab, t1b)
    u1b += numpy.einsum('IJAB,JB->IA', m3bb, t1b)
    u1b += numpy.einsum('jIbA,jb->IA', m3ab, t1a)

    u2aa += m3aa
    u2bb += m3bb
    u2ab += m3ab
    # Integral terms, reordered from (i,a,j,b) to (i,j,a,b).
    u2aa += ovov.transpose(0,2,1,3)
    u2bb += OVOV.transpose(0,2,1,3)
    u2ab += ovOV.transpose(0,2,1,3)

    # l1 (x) fov1 and l2 * wovvo-type terms of u2.
    # NOTE(review): fov1 takes only two distinct values below (an `a` one
    # built from fova and a `b` one built from fovb); each is computed twice.
    fov1 = fova + numpy.einsum('kcjb,kc->jb', ovov, t1a)
    fov1+= numpy.einsum('jbKC,KC->jb', ovOV, t1b)
    tmp = numpy.einsum('ia,jb->ijab', l1a, fov1)
    tmp+= einsum('kica,jbck->ijab', l2aa, imds.wovvo)
    tmp+= einsum('iKaC,jbCK->ijab', l2ab, imds.wovVO)
    # Antisymmetrize over (i,j), then over (a,b).
    tmp = tmp - tmp.transpose(1,0,2,3)
    u2aa += tmp - tmp.transpose(0,1,3,2)

    fov1 = fovb + numpy.einsum('kcjb,kc->jb', OVOV, t1b)
    fov1+= numpy.einsum('kcJB,kc->JB', ovOV, t1a)
    tmp = numpy.einsum('ia,jb->ijab', l1b, fov1)
    tmp+= einsum('kica,jbck->ijab', l2bb, imds.wOVVO)
    tmp+= einsum('kIcA,JBck->IJAB', l2ab, imds.wOVvo)
    tmp = tmp - tmp.transpose(1,0,2,3)
    u2bb += tmp - tmp.transpose(0,1,3,2)

    # Mixed block: no antisymmetrization; each index pairing is added
    # explicitly.
    fov1 = fovb + numpy.einsum('kcjb,kc->jb', OVOV, t1b)
    fov1+= numpy.einsum('kcJB,kc->JB', ovOV, t1a)
    u2ab += numpy.einsum('ia,JB->iJaB', l1a, fov1)
    u2ab += einsum('iKaC,JBCK->iJaB', l2ab, imds.wOVVO)
    u2ab += einsum('kica,JBck->iJaB', l2aa, imds.wOVvo)
    u2ab += einsum('kIaC,jBCk->jIaB', l2ab, imds.woVVo)
    u2ab += einsum('iKcA,JbcK->iJbA', l2ab, imds.wOvvO)
    fov1 = fova + numpy.einsum('kcjb,kc->jb', ovov, t1a)
    fov1+= numpy.einsum('jbKC,KC->jb', ovOV, t1b)
    u2ab += numpy.einsum('ia,jb->jiba', l1b, fov1)
    u2ab += einsum('kIcA,jbck->jIbA', l2ab, imds.wovvo)
    u2ab += einsum('KICA,jbCK->jIbA', l2bb, imds.wovVO)

    # For ovoo / OVOO, subtract the copy with axes 0 and 2 exchanged;
    # OVoo and ovOO are used as stored.
    ovoo = numpy.asarray(eris.ovoo)
    ovoo = ovoo - ovoo.transpose(2,1,0,3)
    OVOO = numpy.asarray(eris.OVOO)
    OVOO = OVOO - OVOO.transpose(2,1,0,3)
    OVoo = numpy.asarray(eris.OVoo)
    ovOO = numpy.asarray(eris.ovOO)
    # l1*ovoo, l2*v1 and mvv1*ovov terms of u2, antisymmetrized over (a,b)
    # for the aa and bb blocks.
    tmp = einsum('ka,jbik->ijab', l1a, ovoo)
    tmp+= einsum('ijca,cb->ijab', l2aa, v1a)
    tmp+= einsum('ca,icjb->ijab', mvv1, ovov)
    u2aa -= tmp - tmp.transpose(0,1,3,2)
    tmp = einsum('ka,jbik->ijab', l1b, OVOO)
    tmp+= einsum('ijca,cb->ijab', l2bb, v1b)
    tmp+= einsum('ca,icjb->ijab', mVV1, OVOV)
    u2bb -= tmp - tmp.transpose(0,1,3,2)
    u2ab -= einsum('ka,JBik->iJaB', l1a, OVoo)
    u2ab += einsum('iJaC,CB->iJaB', l2ab, v1b)
    u2ab -= einsum('ca,icJB->iJaB', mvv1, ovOV)
    u2ab -= einsum('KA,ibJK->iJbA', l1b, ovOO)
    u2ab += einsum('iJcA,cb->iJbA', l2ab, v1a)
    u2ab -= einsum('CA,ibJC->iJbA', mVV1, ovOV)

    # Remaining u1 terms: Fock block, l1 with v1/v2, l1 with the stored
    # ovvo/oovv integrals, and l2 with the three-virtual / three-occupied
    # style intermediates.
    u1a += fova
    u1b += fovb
    u1a += einsum('ib,ba->ia', l1a, v1a)
    u1a -= einsum('ja,ij->ia', l1a, v2a)
    u1b += einsum('ib,ba->ia', l1b, v1b)
    u1b -= einsum('ja,ij->ia', l1b, v2b)

    u1a += numpy.einsum('jb,iabj->ia', l1a, eris.ovvo)
    u1a -= numpy.einsum('jb,ijba->ia', l1a, eris.oovv)
    u1a += numpy.einsum('JB,iaBJ->ia', l1b, eris.ovVO)
    u1b += numpy.einsum('jb,iabj->ia', l1b, eris.OVVO)
    u1b -= numpy.einsum('jb,ijba->ia', l1b, eris.OOVV)
    u1b += numpy.einsum('jb,iabj->ia', l1a, eris.OVvo)

    u1a -= einsum('kjca,ijck->ia', l2aa, imds.woovo)
    u1a -= einsum('jKaC,ijCK->ia', l2ab, imds.wooVO)
    u1b -= einsum('kjca,ijck->ia', l2bb, imds.wOOVO)
    u1b -= einsum('kJcA,IJck->IA', l2ab, imds.wOOvo)

    u1a -= einsum('ikbc,back->ia', l2aa, imds.wvvvo)
    u1a -= einsum('iKbC,baCK->ia', l2ab, imds.wvvVO)
    u1b -= einsum('IKBC,BACK->IA', l2bb, imds.wVVVO)
    u1b -= einsum('kIcB,BAck->IA', l2ab, imds.wVVvo)

    u1a += numpy.einsum('jiba,bj->ia', l2aa, imds.w3a)
    u1a += numpy.einsum('iJaB,BJ->ia', l2ab, imds.w3b)
    u1b += numpy.einsum('JIBA,BJ->IA', l2bb, imds.w3b)
    u1b += numpy.einsum('jIbA,bj->IA', l2ab, imds.w3a)

    # Effective singles built from t1, l1*t2 and the m-intermediates, then
    # contracted with the ovov-type integrals.
    tmpa  = t1a + numpy.einsum('kc,kjcb->jb', l1a, t2aa)
    tmpa += numpy.einsum('KC,jKbC->jb', l1b, t2ab)
    tmpa -= einsum('bd,jd->jb', mvv1, t1a)
    tmpa -= einsum('lj,lb->jb', moo, t1a)
    tmpb  = t1b + numpy.einsum('kc,kjcb->jb', l1b, t2bb)
    tmpb += numpy.einsum('kc,kJcB->JB', l1a, t2ab)
    tmpb -= einsum('bd,jd->jb', mVV1, t1b)
    tmpb -= einsum('lj,lb->jb', mOO, t1b)
    u1a += numpy.einsum('jbia,jb->ia', ovov, tmpa)
    u1a += numpy.einsum('iaJB,JB->ia', ovOV, tmpb)
    u1b += numpy.einsum('jbia,jb->ia', OVOV, tmpb)
    u1b += numpy.einsum('jbIA,jb->IA', ovOV, tmpa)

    u1a -= numpy.einsum('iajk,kj->ia', ovoo, moo1)
    u1a -= numpy.einsum('iaJK,KJ->ia', ovOO, mOO1)
    u1b -= numpy.einsum('iajk,kj->ia', OVOO, mOO1)
    u1b -= numpy.einsum('IAjk,kj->IA', OVoo, moo1)

    tmp  = fova - numpy.einsum('kbja,jb->ka', ovov, t1a)
    tmp += numpy.einsum('kaJB,JB->ka', ovOV, t1b)
    u1a -= lib.einsum('ik,ka->ia', moo, tmp)
    u1a -= lib.einsum('ca,ic->ia', mvv, tmp)
    tmp  = fovb - numpy.einsum('kbja,jb->ka', OVOV, t1b)
    tmp += numpy.einsum('jbKA,jb->KA', ovOV, t1a)
    u1b -= lib.einsum('ik,ka->ia', mOO, tmp)
    u1b -= lib.einsum('ca,ic->ia', mVV, tmp)

    # Divide by the orbital-energy denominators e_i - e_a (singles) and
    # e_i - e_a + e_j - e_b (doubles); the e_a values include level_shift.
    eia = lib.direct_sum('i-j->ij', mo_ea_o, mo_ea_v)
    eIA = lib.direct_sum('i-j->ij', mo_eb_o, mo_eb_v)
    u1a /= eia
    u1b /= eIA

    u2aa /= lib.direct_sum('ia+jb->ijab', eia, eia)
    u2ab /= lib.direct_sum('ia+jb->ijab', eia, eIA)
    u2bb /= lib.direct_sum('ia+jb->ijab', eIA, eIA)

    # Report timing at debug level; the returned time0 is not used afterwards.
    time0 = log.timer_debug1('update l1 l2', *time0)
    return (u1a,u1b), (u2aa,u2ab,u2bb)
546
547
if __name__ == '__main__':
    from pyscf import gto
    from pyscf import scf
    from pyscf.cc import gccsd

    # Test system: one O and two H atoms, 6-31G basis, mol.spin = 2, UHF.
    mol = gto.Mole()
    mol.atom = [
        [8 , (0. , 0.     , 0.)],
        [1 , (0. , -0.757 , 0.587)],
        [1 , (0. , 0.757  , 0.587)]]
    mol.basis = '631g'
    mol.spin = 2
    mol.build()
    mf = scf.UHF(mol).run()

    # Reference lambda amplitudes: GCCSD on the UHF solution converted to
    # GHF, mapped back to spin blocks with spin2spatial.
    gcc = gccsd.GCCSD(scf.addons.convert_to_ghf(mf))
    gcc_eris = gcc.ao2mo()
    gcc.kernel()
    l1, l2 = gcc.solve_lambda(gcc.t1, gcc.t2, eris=gcc_eris)
    orbspin = gcc.mo_coeff.orbspin
    l1ref = gcc.spin2spatial(l1, orbspin)
    l2ref = gcc.spin2spatial(l2, orbspin)

    # Same amplitudes from this module's UCCSD lambda solver.
    ucc = uccsd.UCCSD(mf)
    ucc_eris = ucc.ao2mo()
    ucc.kernel()
    conv, l1, l2 = kernel(ucc, ucc_eris, ucc.t1, ucc.t2, tol=1e-8)

    # Largest absolute deviation from the reference, one line per block:
    # two l1 blocks followed by three l2 blocks.
    for blk in range(2):
        print(abs(l1[blk]-l1ref[blk]).max())
    for blk in range(3):
        print(abs(l2[blk]-l2ref[blk]).max())
578