1#!/usr/bin/env python
2# Copyright 2014-2021 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Qiming Sun <osirpt.sun@gmail.com>
17#         James D. McClain
18#         Jason Yu
19#         Shining Sun
20#         Mario Motta
21#         Chong Sun
22#
23
24
25import numpy as np
26
27from pyscf import lib
28from pyscf.lib import logger
29from pyscf import ao2mo
30from pyscf.cc import ccsd
31from pyscf.cc import uccsd
32from pyscf.cc import eom_rccsd
33from pyscf.cc import uintermediates
34
35########################################
36# EOM-IP-CCSD
37########################################
38
def vector_to_amplitudes_ip(vector, nmo, nocc):
    '''Unpack a flat EOM-IP vector into alpha/beta R1 and R2 blocks.

    Args:
        vector : 1D ndarray laid out as
            [r1a | r1b | r2aaa (i>j only) | r2baa | r2abb | r2bbb (I>J only)].
        nmo : (nmoa, nmob) number of orbitals for each spin.
        nocc : (nocca, noccb) number of occupied orbitals for each spin.

    Returns:
        (r1a, r1b), (r2aaa, r2baa, r2abb, r2bbb) with shapes
        (nocca,), (noccb,), (nocca,nocca,nvira), (noccb,nocca,nvira),
        (nocca,noccb,nvirb), (noccb,noccb,nvirb).  The same-spin blocks are
        rebuilt antisymmetric in their two occupied indices.
    '''
    nocca, noccb = nocc
    nvira = nmo[0] - nocca
    nvirb = nmo[1] - noccb
    npair_a = nocca * (nocca - 1) // 2
    npair_b = noccb * (noccb - 1) // 2

    # Start offsets of sections 2..6 (np.split needs the interior cut points).
    offsets = np.cumsum([nocca, noccb, npair_a * nvira,
                         noccb * nocca * nvira, nocca * noccb * nvirb])
    (sec_r1a, sec_r1b, sec_aa, sec_baa,
     sec_abb, sec_bb) = np.split(vector, offsets)

    lower_a = np.tril_indices(nocca, -1)
    lower_b = np.tril_indices(noccb, -1)
    upper_a = lower_a[::-1]
    upper_b = lower_b[::-1]

    # Same-spin blocks: scatter the packed i>j rows and mirror with a sign flip.
    packed_aa = sec_aa.reshape(npair_a, nvira)
    r2aaa = np.zeros((nocca, nocca, nvira), vector.dtype)
    r2aaa[lower_a] = packed_aa
    r2aaa[upper_a] = -packed_aa

    packed_bb = sec_bb.reshape(npair_b, nvirb)
    r2bbb = np.zeros((noccb, noccb, nvirb), vector.dtype)
    r2bbb[lower_b] = packed_bb
    r2bbb[upper_b] = -packed_bb

    # Mixed-spin blocks are stored densely; copy so they do not alias `vector`.
    r2baa = sec_baa.reshape(noccb, nocca, nvira).copy()
    r2abb = sec_abb.reshape(nocca, noccb, nvirb).copy()
    return (sec_r1a.copy(), sec_r1b.copy()), (r2aaa, r2baa, r2abb, r2bbb)
66
def amplitudes_to_vector_ip(r1, r2):
    '''Pack alpha/beta EOM-IP amplitudes into one flat vector.

    Inverse of :func:`vector_to_amplitudes_ip`.  Only the strictly lower
    triangle (i > j) of the antisymmetric same-spin blocks r2aaa and r2bbb
    is stored; the mixed-spin blocks r2baa and r2abb are stored in full.
    '''
    r1a, r1b = r1
    r2aaa, r2baa, r2abb, r2bbb = r2
    nocca, noccb, _ = r2abb.shape
    lower_a = np.tril_indices(nocca, -1)
    lower_b = np.tril_indices(noccb, -1)
    segments = [r1a, r1b,
                r2aaa[lower_a].ravel(),
                r2baa.ravel(),
                r2abb.ravel(),
                r2bbb[lower_b].ravel()]
    return np.hstack(segments)
77
def spatial2spin_ip(r1, r2, orbspin=None):
    '''Convert R1/R2 of spatial orbital representation to R1/R2 of
    spin-orbital representation.

    Args:
        r1 : (r1a, r1b)
        r2 : (r2aaa, r2baa, r2abb, r2bbb), shaped as returned by
            :func:`vector_to_amplitudes_ip`.
        orbspin : optional array of 0 (alpha) / 1 (beta) labels for every
            spin orbital, occupied ones first.  Defaults to 0, 1, 0, 1, ...
            of length 2*(nocc_a + nvir_a), i.e. it assumes both spins have
            the same number of spatial orbitals.

    Returns:
        r1 of shape (nocc,) and r2 of shape (nocc, nocc, nvir), where
        nocc = nocc_a + nocc_b and nvir = nvir_a + nvir_b.
    '''
    r1a, r1b = r1
    r2aaa, r2baa, r2abb, r2bbb = r2
    nocc_a, nvir_a = r2aaa.shape[1:]
    nocc_b, nvir_b = r2bbb.shape[1:]
    nocc = nocc_a + nocc_b
    nvir = nvir_a + nvir_b

    if orbspin is None:
        orbspin = np.arange((nocc_a + nvir_a) * 2) % 2

    occ_spin, vir_spin = orbspin[:nocc], orbspin[nocc:]
    occ_a = np.where(occ_spin == 0)[0]
    occ_b = np.where(occ_spin == 1)[0]
    vir_a = np.where(vir_spin == 0)[0]
    vir_b = np.where(vir_spin == 1)[0]

    r1_spin = np.zeros(nocc, dtype=r1a.dtype)
    r1_spin[occ_a] = r1a
    r1_spin[occ_b] = r1b

    def pair_rows(first, second):
        # Row of the compound occupied pair (first[i], second[j]) in an
        # (nocc*nocc, nvir) view of the spin-orbital R2.
        return np.add.outer(first * nocc, second)

    rows_aa = pair_rows(occ_a, occ_a)
    rows_ab = pair_rows(occ_a, occ_b)
    rows_ba = pair_rows(occ_b, occ_a)
    rows_bb = pair_rows(occ_b, occ_b)

    blk_aaa = r2aaa.reshape(nocc_a * nocc_a, nvir_a)
    blk_baa = r2baa.reshape(nocc_b * nocc_a, nvir_a)
    blk_abb = r2abb.reshape(nocc_a * nocc_b, nvir_b)
    blk_bbb = r2bbb.reshape(nocc_b * nocc_b, nvir_b)

    r2_spin = np.zeros((nocc**2, nvir), dtype=r2aaa.dtype)
    # The mixed-spin blocks are written a second time, negated, at the
    # swapped occupied pair so the result is antisymmetric in (i, j).
    scatter = (
        (blk_aaa, rows_aa.ravel(), vir_a),
        (blk_baa, rows_ba.ravel(), vir_a),
        (blk_abb, rows_ab.ravel(), vir_b),
        (blk_bbb, rows_bb.ravel(), vir_b),
        (-blk_baa, rows_ab.T.ravel(), vir_a),
        (-blk_abb, rows_ba.T.ravel(), vir_b),
    )
    for block, rows, cols in scatter:
        lib.takebak_2d(r2_spin, block, rows, cols)
    return r1_spin, r2_spin.reshape(nocc, nocc, nvir)
124
def spin2spatial_ip(r1, r2, orbspin):
    '''Convert spin-orbital EOM-IP R1/R2 back into alpha/beta blocks.

    Inverse of :func:`spatial2spin_ip`.

    Args:
        r1 : (nocc,) spin-orbital R1.
        r2 : (nocc, nocc, nvir) spin-orbital R2.
        orbspin : 0 (alpha) / 1 (beta) label of each spin orbital,
            occupied orbitals first.

    Returns:
        [r1a, r1b], [r2aaa, r2baa, r2abb, r2bbb]
    '''
    nocc, nvir = r2.shape[1:]
    occ_spin, vir_spin = orbspin[:nocc], orbspin[nocc:]
    occ_a = np.where(occ_spin == 0)[0]
    occ_b = np.where(occ_spin == 1)[0]
    vir_a = np.where(vir_spin == 0)[0]
    vir_b = np.where(vir_spin == 1)[0]

    r1_blocks = [r1[occ_a], r1[occ_b]]

    flat = r2.reshape(nocc**2, nvir)

    def gather(first, second, vir):
        # Pull the rows for occupied pairs (first[i], second[j]) and the
        # virtual columns `vir`, then restore the 3-index shape.
        rows = np.add.outer(first * nocc, second).ravel()
        block = lib.take_2d(flat, rows, vir)
        return block.reshape(len(first), len(second), len(vir))

    r2_blocks = [gather(occ_a, occ_a, vir_a),
                 gather(occ_b, occ_a, vir_a),
                 gather(occ_a, occ_b, vir_b),
                 gather(occ_b, occ_b, vir_b)]
    return r1_blocks, r2_blocks
160
def ipccsd_matvec(eom, vector, imds=None, diag=None):
    '''Apply the EOM-IP-CCSD effective Hamiltonian to a packed vector.

    For spin orbitals, handled as unrestricted alpha/beta blocks:
    R1 = (r1a, r1b) and R2 = (r2aaa, r2baa, r2abb, r2bbb) as unpacked by
    :func:`vector_to_amplitudes_ip`.
    R2 operators of the form s_{ij}^{ b}, i.e. indices jb are coupled.

    Args:
        eom : EOMIP object; ``eom.make_imds()`` is called when imds is None.
        vector : 1D ndarray in the layout of :func:`amplitudes_to_vector_ip`.
        imds : intermediates object supplying t1, t2 and the Foo/Fov/Fvv,
            Wooov, Woovo, Woooo, Wovvo- and Wovov-type blocks read below.
        diag : not referenced by this function.

    Returns:
        1D ndarray H*vector, packed with :func:`amplitudes_to_vector_ip`.
    '''
    # Ref: Tu, Wang, and Li, J. Chem. Phys. 136, 174102 (2012) Eqs.(8)-(9)
    if imds is None: imds = eom.make_imds()
    t1, t2 = imds.t1, imds.t2
    t1a, t1b = t1  # unpacked but not used below
    t2aa, t2ab, t2bb = t2
    nocca, noccb, nvira, nvirb = t2ab.shape
    nmoa, nmob = nocca+nvira, noccb+nvirb
    r1, r2 = vector_to_amplitudes_ip(vector, (nmoa,nmob), (nocca,noccb))
    r1a, r1b = r1
    r2aaa, r2baa, r2abb, r2bbb = r2

    #Foo, Fov, and Wooov
    # 1h (R1) output.  R2 -> R1 through Fov:
    Hr1a  = np.einsum('me,mie->i', imds.Fov, r2aaa)
    Hr1a -= np.einsum('ME,iME->i', imds.FOV, r2abb)
    Hr1b  = np.einsum('ME,MIE->I', imds.FOV, r2bbb)
    Hr1b -= np.einsum('me,Ime->I', imds.Fov, r2baa)

    # R1 -> R1 through Foo
    Hr1a += -np.einsum('mi,m->i', imds.Foo, r1a)
    Hr1b += -np.einsum('MI,M->I', imds.FOO, r1b)

    # R2 -> R1 through Wooov-type blocks.  The 0.5 factors accompany the
    # full sums over both orderings of the antisymmetric same-spin pair.
    Hr1a += -0.5*np.einsum('nime,mne->i', imds.Wooov, r2aaa)
    Hr1b +=      np.einsum('NIme,Nme->I', imds.WOOov, r2baa)
    Hr1b += -0.5*np.einsum('NIME,MNE->I', imds.WOOOV, r2bbb)
    Hr1a +=      np.einsum('niME,nME->i', imds.WooOV, r2abb)

    # Fvv term
    # 2h1p (R2) output starts here; each Hr2 block has the shape of its R2 block.
    Hr2aaa = lib.einsum('be,ije->ijb', imds.Fvv, r2aaa)
    Hr2abb = lib.einsum('BE,iJE->iJB', imds.FVV, r2abb)
    Hr2bbb = lib.einsum('BE,IJE->IJB', imds.FVV, r2bbb)
    Hr2baa = lib.einsum('be,Ije->Ijb', imds.Fvv, r2baa)

    # Foo term
    # Same-spin blocks are antisymmetrized in (i, j) via tmp - tmp^T.
    tmpa = lib.einsum('mi,mjb->ijb', imds.Foo, r2aaa)
    Hr2aaa -= tmpa - tmpa.transpose((1,0,2))
    Hr2abb -= lib.einsum('mi,mJB->iJB', imds.Foo, r2abb)
    Hr2abb -= lib.einsum('MJ,iMB->iJB', imds.FOO, r2abb)
    Hr2baa -= lib.einsum('MI,Mjb->Ijb', imds.FOO, r2baa)
    Hr2baa -= lib.einsum('mj,Imb->Ijb', imds.Foo, r2baa)
    tmpb = lib.einsum('MI,MJB->IJB', imds.FOO, r2bbb)
    Hr2bbb -= tmpb - tmpb.transpose((1,0,2))

    # Wovoo term
    # R1 -> R2 coupling
    Hr2aaa -= np.einsum('mjbi,m->ijb', imds.Woovo, r1a)
    Hr2abb += np.einsum('miBJ,m->iJB', imds.WooVO, r1a)
    Hr2baa += np.einsum('MIbj,M->Ijb', imds.WOOvo, r1b)
    Hr2bbb -= np.einsum('MJBI,M->IJB', imds.WOOVO, r1b)

    # Woooo term
    Hr2aaa += .5 * lib.einsum('minj,mnb->ijb', imds.Woooo, r2aaa)
    Hr2abb +=      lib.einsum('miNJ,mNB->iJB', imds.WooOO, r2abb)
    Hr2bbb += .5 * lib.einsum('MINJ,MNB->IJB', imds.WOOOO, r2bbb)
    Hr2baa +=      lib.einsum('njMI,Mnb->Ijb', imds.WooOO, r2baa)

    # Wovvo terms
    tmp = lib.einsum('mebj,ime->ijb', imds.Wovvo, r2aaa)
    tmp += lib.einsum('MEbj,iME->ijb', imds.WOVvo, r2abb)
    Hr2aaa += tmp - tmp.transpose(1, 0, 2)

    # Mixed-spin oovv-ordered intermediates derived from WoVVo / WOvvO by a
    # sign flip and the index permutation (0,3,2,1).
    WooVV = -imds.WoVVo.transpose(0,3,2,1)
    WOOvv = -imds.WOvvO.transpose(0,3,2,1)
    Hr2abb += lib.einsum('MEBJ,iME->iJB', imds.WOVVO, r2abb)
    Hr2abb += lib.einsum('meBJ,ime->iJB', imds.WovVO, r2aaa)
    Hr2abb += -lib.einsum('miBE,mJE->iJB', WooVV, r2abb)

    Hr2baa += lib.einsum('meaj,Ime->Ija', imds.Wovvo, r2baa)
    Hr2baa += lib.einsum('MEaj,IME->Ija', imds.WOVvo, r2bbb)
    Hr2baa += -lib.einsum('MIab,Mjb->Ija', WOOvv, r2baa)

    tmp = lib.einsum('MEBJ,IME->IJB', imds.WOVVO, r2bbb)
    tmp += lib.einsum('meBJ,Ime->IJB', imds.WovVO, r2baa)
    Hr2bbb += tmp - tmp.transpose(1, 0, 2)

    # T2 term
    # Three-tensor contractions of Wovov-type blocks with R2 and T2.
    Hr2aaa -= 0.5 * lib.einsum('menf,mnf,jibe->ijb', imds.Wovov, r2aaa, t2aa)
    Hr2aaa -= lib.einsum('meNF,mNF,jibe->ijb', imds.WovOV, r2abb, t2aa)

    Hr2abb -= 0.5 * lib.einsum('menf,mnf,iJeB->iJB', imds.Wovov, r2aaa, t2ab)
    Hr2abb -= lib.einsum('meNF,mNF,iJeB->iJB', imds.WovOV, r2abb, t2ab)

    Hr2baa -= 0.5 * lib.einsum('MENF,MNF,jIbE->Ijb', imds.WOVOV, r2bbb, t2ab)
    Hr2baa -= lib.einsum('nfME,Mnf,jIbE->Ijb', imds.WovOV, r2baa, t2ab)

    Hr2bbb -= 0.5 * lib.einsum('MENF,MNF,JIBE->IJB', imds.WOVOV, r2bbb, t2bb)
    Hr2bbb -= lib.einsum('nfME,Mnf,JIBE->IJB', imds.WovOV, r2baa, t2bb)

    vector = amplitudes_to_vector_ip([Hr1a, Hr1b], [Hr2aaa, Hr2baa, Hr2abb, Hr2bbb])
    return vector
251
def ipccsd_diag(eom, imds=None):
    '''Diagonal of the EOM-IP-CCSD Hamiltonian, packed like an EOM-IP vector.

    Each term is intended to be the bra == ket element of a matching
    contraction in :func:`ipccsd_matvec`: -Foo for the R1 part; Fvv, Foo,
    Woooo, Wovvo-type and Wovov*T2 pieces for the R2 blocks.

    Args:
        eom : EOMIP object; ``eom.make_imds()`` is called when imds is None.
        imds : intermediates object supplying t1, t2 and the F*/W* blocks.

    Returns:
        1D ndarray in the layout of :func:`amplitudes_to_vector_ip`.
    '''
    if imds is None: imds = eom.make_imds()
    t1, t2 = imds.t1, imds.t2
    t1a, t1b = t1
    t2aa, t2ab, t2bb = t2

    # (nocc_*, nvir_* are not used below)
    nocc_a, nvir_a = t1a.shape
    nocc_b, nvir_b = t1b.shape

    # R1 part: only Foo couples R1 to itself in ipccsd_matvec
    Hr1a = -np.diag(imds.Foo)
    Hr1b = -np.diag(imds.FOO)

    Fvv_diag = np.diag(imds.Fvv)
    Foo_diag = np.diag(imds.Foo)
    FOO_diag = np.diag(imds.FOO)
    FVV_diag = np.diag(imds.FVV)

    # Diagonal slices; index mapping is spelled out by each einsum string,
    # e.g. WooOO_slice[I,j] = WooOO[j,j,I,I] and
    # WooOO_slice_T[i,J] = WooOO[i,i,J,J].
    Woooo_slice = np.einsum('iijj->ij',imds.Woooo)
    Wovvo_slice = np.einsum('iaai->ia',imds.Wovvo)
    WooOO_slice = np.einsum('jjii->ij',imds.WooOO)
    WOvvO_slice = np.einsum('iaai->ia',imds.WOvvO)
    WooOO_slice_T = np.einsum('iijj->ij',imds.WooOO)
    WoVVo_slice = np.einsum('iaai->ia',imds.WoVVo)
    WOVVO_slice = np.einsum('jaaj->ja',imds.WOVVO)
    WOOOO_slice = np.einsum('iijj->ij',imds.WOOOO)

    # W*T2 diagonal pieces: result[i,j,a] = sum_b W[...] * t2[...]
    # (the one free virtual index a is the R2 virtual index).
    Wovov_t2_dot = np.einsum('jaib,jiab->ija',imds.Wovov,t2aa)
    WovOV_t2_dot = np.einsum('ibja,ijba->ija',imds.WovOV,t2ab)
    WovOV_t2_dot_T = np.einsum('jaib,jiab->ija',imds.WovOV,t2ab)
    WOVOV_t2_dot = np.einsum('jaib,jiab->ija',imds.WOVOV,t2bb)

    # Broadcast into the (occ, occ, vir) shapes of each R2 block.
    Hr2aaa = Fvv_diag[None,None,:] - Foo_diag[:,None,None] - Foo_diag[None,:,None] \
             + Woooo_slice[:,:,None] + Wovvo_slice[:,None,:] + Wovvo_slice[None,:,:] \
             - Wovov_t2_dot

    Hr2baa = Fvv_diag[None,None,:] - FOO_diag[:,None,None] - Foo_diag[None,:,None] \
             + WooOO_slice[:,:,None] + WOvvO_slice[:,None,:] + Wovvo_slice[None,:,:] \
             - WovOV_t2_dot_T

    Hr2abb = FVV_diag[None,None,:] - Foo_diag[:,None,None] - FOO_diag[None,:,None] \
             + WooOO_slice_T[:,:,None] + WoVVo_slice[:,None,:] + WOVVO_slice[None,:,:] \
             - WovOV_t2_dot

    Hr2bbb = FVV_diag[None,None,:] - FOO_diag[:,None,None] - FOO_diag[None,:,None] \
             + WOOOO_slice[:,:,None] + WOVVO_slice[:,None,:] + WOVVO_slice[None,:,:] \
             - WOVOV_t2_dot

    vector = amplitudes_to_vector_ip([Hr1a, Hr1b], [Hr2aaa, Hr2baa, Hr2abb, Hr2bbb])
    return vector
301
302
class EOMIP(eom_rccsd.EOMIP):
    '''EOM-IP-CCSD for an unrestricted (UCCSD) reference.

    ``nocc`` and ``nmo`` are (alpha, beta) tuples taken from the CC object.
    Trial vectors use the packed layout of :func:`amplitudes_to_vector_ip`.
    '''
    matvec = ipccsd_matvec
    l_matvec = None
    get_diag = ipccsd_diag
    ipccsd_star = None
    ccsd_star_contract = None

    def __init__(self, cc):
        eom_rccsd.EOMIP.__init__(self, cc)
        self.nocc = cc.get_nocc()
        self.nmo = cc.get_nmo()

    def get_init_guess(self, nroots=1, koopmans=True, diag=None):
        '''Unit-vector initial guesses at the smallest diagonal elements.

        Args:
            nroots : number of guesses requested (capped at the vector size).
            koopmans : if True, only the 1h (R1) entries are candidates.
            diag : diagonal of the EOM-IP Hamiltonian in packed layout.
                Computed with :meth:`get_diag` when not given.

        Returns:
            list of 1D ndarrays, each holding a single 1.0 entry.  In
            Koopmans mode at most nocca+noccb guesses are produced.
        '''
        if diag is None:
            # The default used to crash on ``diag[...]`` / ``diag.argsort()``.
            diag = self.get_diag()
        if koopmans:
            nocca, noccb = self.nocc
            # r1a and r1b occupy the first nocca+noccb slots of the vector.
            idx = diag[:nocca+noccb].argsort()
        else:
            idx = diag.argsort()

        size = self.vector_size()
        dtype = getattr(diag, 'dtype', np.double)
        nroots = min(nroots, size)
        guess = []
        for i in idx[:nroots]:
            g = np.zeros(size, dtype)
            g[i] = 1.0
            guess.append(g)
        return guess

    def vector_to_amplitudes(self, vector, nmo=None, nocc=None):
        '''Unpack ``vector``; nmo/nocc default to this object's values.'''
        if nmo is None: nmo = self.nmo
        if nocc is None: nocc = self.nocc
        return vector_to_amplitudes_ip(vector, nmo, nocc)

    def amplitudes_to_vector(self, r1, r2):
        '''Pack R1/R2 spin blocks into a flat vector.'''
        return amplitudes_to_vector_ip(r1, r2)

    def vector_size(self):
        '''size of the vector based on spin-orbital basis'''
        nocca, noccb = self.nocc
        nmoa, nmob = self.nmo
        nvira, nvirb = nmoa-nocca, nmob-noccb
        # r1a + r1b + packed r2aaa + r2baa + r2abb + packed r2bbb
        return (nocca + noccb
                + nocca*(nocca-1)//2*nvira + noccb*nocca*nvira
                + nocca*noccb*nvirb + noccb*(noccb-1)//2*nvirb)

    def make_imds(self, eris=None):
        '''Build the EOM-IP intermediates from the underlying CC object.'''
        imds = _IMDS(self._cc, eris)
        imds.make_ip()
        return imds
353
354########################################
355# EOM-EA-CCSD
356########################################
357
def vector_to_amplitudes_ea(vector, nmo, nocc):
    '''Unpack a flat EOM-EA vector into alpha/beta R1 and R2 blocks.

    Args:
        vector : 1D ndarray laid out as
            [r1a | r1b | r2aaa (a>b only) | r2aba | r2bab | r2bbb (A>B only)].
        nmo : (nmoa, nmob) number of orbitals for each spin.
        nocc : (nocca, noccb) number of occupied orbitals for each spin.

    Returns:
        (r1a, r1b), (r2aaa, r2aba, r2bab, r2bbb) with shapes
        (nvira,), (nvirb,), (nocca,nvira,nvira), (nocca,nvirb,nvira),
        (noccb,nvira,nvirb), (noccb,nvirb,nvirb).  The same-spin blocks are
        rebuilt antisymmetric in their two virtual indices.
    '''
    nocca, noccb = nocc
    nvira = nmo[0] - nocca
    nvirb = nmo[1] - noccb
    npair_a = nvira * (nvira - 1) // 2
    npair_b = nvirb * (nvirb - 1) // 2

    # Start offsets of sections 2..6 (np.split needs the interior cut points).
    offsets = np.cumsum([nvira, nvirb, nocca * npair_a,
                         nocca * nvirb * nvira, noccb * nvira * nvirb])
    (sec_r1a, sec_r1b, sec_aa, sec_aba,
     sec_bab, sec_bb) = np.split(vector, offsets)

    rows_a, cols_a = np.tril_indices(nvira, -1)
    rows_b, cols_b = np.tril_indices(nvirb, -1)

    # Same-spin blocks: scatter the packed a>b columns and mirror with a sign flip.
    packed_aa = sec_aa.reshape(nocca, npair_a)
    r2aaa = np.zeros((nocca, nvira, nvira), vector.dtype)
    r2aaa[:, rows_a, cols_a] = packed_aa
    r2aaa[:, cols_a, rows_a] = -packed_aa

    packed_bb = sec_bb.reshape(noccb, npair_b)
    r2bbb = np.zeros((noccb, nvirb, nvirb), vector.dtype)
    r2bbb[:, rows_b, cols_b] = packed_bb
    r2bbb[:, cols_b, rows_b] = -packed_bb

    # Mixed-spin blocks are stored densely; copy so they do not alias `vector`.
    r2aba = sec_aba.reshape(nocca, nvirb, nvira).copy()
    r2bab = sec_bab.reshape(noccb, nvira, nvirb).copy()
    return (sec_r1a.copy(), sec_r1b.copy()), (r2aaa, r2aba, r2bab, r2bbb)
384
def amplitudes_to_vector_ea(r1, r2):
    '''Pack alpha/beta EOM-EA amplitudes into one flat vector.

    Inverse of :func:`vector_to_amplitudes_ea`.  Only the strictly lower
    triangle (a > b) of the antisymmetric same-spin blocks r2aaa and r2bbb
    is stored, for every occupied index; the mixed-spin blocks are stored
    in full.
    '''
    r1a, r1b = r1
    r2aaa, r2aba, r2bab, r2bbb = r2
    _, nvirb, nvira = r2aba.shape
    rows_a, cols_a = np.tril_indices(nvira, -1)
    rows_b, cols_b = np.tril_indices(nvirb, -1)
    segments = [r1a, r1b,
                r2aaa[:, rows_a, cols_a].ravel(),
                r2aba.ravel(),
                r2bab.ravel(),
                r2bbb[:, rows_b, cols_b].ravel()]
    return np.hstack(segments)
395
def spatial2spin_ea(r1, r2, orbspin=None):
    '''Convert R1/R2 of spatial orbital representation to R1/R2 of
    spin-orbital representation.

    Args:
        r1 : (r1a, r1b)
        r2 : (r2aaa, r2aba, r2bab, r2bbb), shaped as returned by
            :func:`vector_to_amplitudes_ea`.
        orbspin : optional array of 0 (alpha) / 1 (beta) labels for every
            spin orbital, occupied ones first.  Defaults to 0, 1, 0, 1, ...
            of length 2*(nocc_a + nvir_a), i.e. it assumes both spins have
            the same number of spatial orbitals.

    Returns:
        r1 of shape (nvir,) and r2 of shape (nocc, nvir, nvir), where
        nocc = nocc_a + nocc_b and nvir = nvir_a + nvir_b.
    '''
    r1a, r1b = r1
    r2aaa, r2aba, r2bab, r2bbb = r2
    nocc_a, nvir_a = r2aaa.shape[:2]
    nocc_b, nvir_b = r2bbb.shape[:2]
    nocc = nocc_a + nocc_b
    nvir = nvir_a + nvir_b

    if orbspin is None:
        orbspin = np.arange((nocc_a + nvir_a) * 2) % 2

    occ_spin, vir_spin = orbspin[:nocc], orbspin[nocc:]
    occ_a = np.where(occ_spin == 0)[0]
    occ_b = np.where(occ_spin == 1)[0]
    vir_a = np.where(vir_spin == 0)[0]
    vir_b = np.where(vir_spin == 1)[0]

    r1_spin = np.zeros(nvir, dtype=r1a.dtype)
    r1_spin[vir_a] = r1a
    r1_spin[vir_b] = r1b

    def pair_cols(first, second):
        # Column of the compound virtual pair (first[i], second[j]) in an
        # (nocc, nvir*nvir) view of the spin-orbital R2.
        return np.add.outer(first * nvir, second)

    cols_aa = pair_cols(vir_a, vir_a)
    cols_ab = pair_cols(vir_a, vir_b)
    cols_ba = pair_cols(vir_b, vir_a)
    cols_bb = pair_cols(vir_b, vir_b)

    blk_aaa = r2aaa.reshape(nocc_a, nvir_a * nvir_a)
    blk_aba = r2aba.reshape(nocc_a, nvir_b * nvir_a)
    blk_bab = r2bab.reshape(nocc_b, nvir_a * nvir_b)
    blk_bbb = r2bbb.reshape(nocc_b, nvir_b * nvir_b)

    r2_spin = np.zeros((nocc, nvir**2), dtype=r2aaa.dtype)
    # The mixed-spin blocks are written a second time, negated, at the
    # swapped virtual pair so the result is antisymmetric in (a, b).
    scatter = (
        (blk_aaa, occ_a, cols_aa.ravel()),
        (blk_aba, occ_a, cols_ba.ravel()),
        (blk_bab, occ_b, cols_ab.ravel()),
        (blk_bbb, occ_b, cols_bb.ravel()),
        (-blk_bab, occ_b, cols_ba.T.ravel()),
        (-blk_aba, occ_a, cols_ab.T.ravel()),
    )
    for block, rows, cols in scatter:
        lib.takebak_2d(r2_spin, block, rows, cols)
    return r1_spin, r2_spin.reshape(nocc, nvir, nvir)
445
def spin2spatial_ea(r1, r2, orbspin):
    '''Convert spin-orbital EOM-EA R1/R2 back into alpha/beta blocks.

    Inverse of :func:`spatial2spin_ea`.

    Args:
        r1 : (nvir,) spin-orbital R1.
        r2 : (nocc, nvir, nvir) spin-orbital R2.
        orbspin : 0 (alpha) / 1 (beta) label of each spin orbital,
            occupied orbitals first.

    Returns:
        [r1a, r1b], [r2aaa, r2aba, r2bab, r2bbb]
    '''
    nocc, nvir = r2.shape[:2]
    occ_spin, vir_spin = orbspin[:nocc], orbspin[nocc:]
    occ_a = np.where(occ_spin == 0)[0]
    occ_b = np.where(occ_spin == 1)[0]
    vir_a = np.where(vir_spin == 0)[0]
    vir_b = np.where(vir_spin == 1)[0]

    r1_blocks = [r1[vir_a], r1[vir_b]]

    flat = r2.reshape(nocc, nvir**2)

    def gather(occ, first, second):
        # Pull occupied rows `occ` and the columns of virtual pairs
        # (first[i], second[j]), then restore the 3-index shape.
        cols = np.add.outer(first * nvir, second).ravel()
        block = lib.take_2d(flat, occ, cols)
        return block.reshape(len(occ), len(first), len(second))

    r2_blocks = [gather(occ_a, vir_a, vir_a),
                 gather(occ_a, vir_b, vir_a),
                 gather(occ_b, vir_a, vir_b),
                 gather(occ_b, vir_b, vir_b)]
    return r1_blocks, r2_blocks
481
def eaccsd_matvec(eom, vector, imds=None, diag=None):
    '''For spin orbitals.

    R2 operators of the form s_{ j}^{ab}, i.e. indices jb are coupled.

    Applies the EOM-EA-CCSD effective Hamiltonian to a packed vector, working
    on the alpha/beta blocks R1 = (r1a, r1b) and
    R2 = (r2aaa, r2aba, r2bab, r2bbb) from :func:`vector_to_amplitudes_ea`.

    Args:
        eom : EOMEA object; ``eom.make_imds()`` is called when imds is None,
            and ``eom._cc`` is handed to :func:`_add_vvvv_ea`.
        vector : 1D ndarray in the layout of :func:`amplitudes_to_vector_ea`.
        imds : intermediates object supplying t1, t2, eris and the F*/W*
            blocks read below.
        diag : not referenced by this function.

    Returns:
        1D ndarray H*vector, packed with :func:`amplitudes_to_vector_ea`.
    '''
    # Ref: Nooijen and Bartlett, J. Chem. Phys. 102, 3629 (1994) Eqs.(30)-(31)
    if imds is None: imds = eom.make_imds()
    t1, t2, eris = imds.t1, imds.t2, imds.eris
    t1a, t1b = t1
    t2aa, t2ab, t2bb = t2
    nocca, noccb, nvira, nvirb = t2ab.shape
    nmoa, nmob = nocca+nvira, noccb+nvirb
    r1, r2 = vector_to_amplitudes_ea(vector, (nmoa,nmob), (nocca,noccb))
    r1a, r1b = r1
    r2aaa, r2aba, r2bab, r2bbb = r2

    # Fov terms
    # 1p (R1) output.  R2 -> R1 through Fov:
    Hr1a  = np.einsum('ld,lad->a', imds.Fov, r2aaa)
    Hr1a += np.einsum('LD,LaD->a', imds.FOV, r2bab)
    Hr1b  = np.einsum('ld,lAd->A', imds.Fov, r2aba)
    Hr1b += np.einsum('LD,LAD->A', imds.FOV, r2bbb)

    # Fvv terms
    Hr1a += np.einsum('ac,c->a', imds.Fvv, r1a)
    Hr1b += np.einsum('AC,C->A', imds.FVV, r1b)

    # Wvovv
    # R2 -> R1; the 0.5 factors accompany full sums over both orderings of
    # the antisymmetric same-spin virtual pair.
    Hr1a += 0.5*lib.einsum('acld,lcd->a', imds.Wvvov, r2aaa)
    Hr1a +=     lib.einsum('acLD,LcD->a', imds.WvvOV, r2bab)
    Hr1b += 0.5*lib.einsum('ACLD,LCD->A', imds.WVVOV, r2bbb)
    Hr1b +=     lib.einsum('ACld,lCd->A', imds.WVVov, r2aba)

    #** Wvvvv term
    # Everything down to "Wvvvv term end" assembles the vvvv-type
    # contribution without forming a Wvvvv intermediate.  The commented
    # einsums show the bare-integral part done by _add_vvvv_ea.
    #:Hr2aaa = lib.einsum('acbd,jcd->jab', eris_vvvv, r2aaa)
    #:Hr2aba = lib.einsum('bdac,jcd->jab', eris_vvVV, r2aba)
    #:Hr2bab = lib.einsum('acbd,jcd->jab', eris_vvVV, r2bab)
    #:Hr2bbb = lib.einsum('acbd,jcd->jab', eris_VVVV, r2bbb)
    # u2 = R2 plus the R1 x T1 outer product (antisymmetrized for the
    # same-spin blocks), so a single vvvv contraction covers both.
    u2 = (r2aaa + np.einsum('c,jd->jcd', r1a, t1a) - np.einsum('d,jc->jcd', r1a, t1a),
          r2aba + np.einsum('c,jd->jcd', r1b, t1a),
          r2bab + np.einsum('c,jd->jcd', r1a, t1b),
          r2bbb + np.einsum('c,jd->jcd', r1b, t1b) - np.einsum('d,jc->jcd', r1b, t1b))
    Hr2aaa, Hr2aba, Hr2bab, Hr2bbb = _add_vvvv_ea(eom._cc, u2, eris)
    u2 = None

    # ovov integrals contracted with R2 and tau (from uccsd.make_tau)
    tauaa, tauab, taubb = uccsd.make_tau(t2, t1, t1)
    eris_ovov = np.asarray(eris.ovov)
    eris_OVOV = np.asarray(eris.OVOV)
    eris_ovOV = np.asarray(eris.ovOV)
    tmpaaa = lib.einsum('menf,jef->mnj', eris_ovov, r2aaa) * .5
    Hr2aaa += lib.einsum('mnj,mnab->jab', tmpaaa, tauaa)
    tmpaaa = tauaa = None

    tmpbbb = lib.einsum('menf,jef->mnj', eris_OVOV, r2bbb) * .5
    Hr2bbb += lib.einsum('mnj,mnab->jab', tmpbbb, taubb)
    tmpbbb = taubb = None

    tmpabb = lib.einsum('menf,jef->mnj', eris_ovOV, r2bab)
    Hr2bab += lib.einsum('mnj,mnab->jab', tmpabb, tauab)
    tmpaba = lib.einsum('nfme,jef->nmj', eris_ovOV, r2aba)
    Hr2aba += lib.einsum('nmj,nmba->jab', tmpaba, tauab)
    tmpaba = tauab = None
    eris_ovov = eris_OVOV = eris_ovOV = None

    # ovvv integrals contracted with R2 and T1; arrays are released
    # (set to None) as soon as each block is done.
    eris_ovvv = imds.eris.get_ovvv(slice(None))
    tmpaaa = lib.einsum('mebf,jef->mjb', eris_ovvv, r2aaa)
    tmpaaa = lib.einsum('mjb,ma->jab', tmpaaa, t1a)
    Hr2aaa-= tmpaaa - tmpaaa.transpose(0,2,1)
    tmpaaa = eris_ovvv = None

    eris_OVVV = imds.eris.get_OVVV(slice(None))
    tmpbbb = lib.einsum('mebf,jef->mjb', eris_OVVV, r2bbb)
    tmpbbb = lib.einsum('mjb,ma->jab', tmpbbb, t1b)
    Hr2bbb-= tmpbbb - tmpbbb.transpose(0,2,1)
    tmpbbb = eris_OVVV = None

    eris_ovVV = imds.eris.get_ovVV(slice(None))
    eris_OVvv = imds.eris.get_OVvv(slice(None))
    tmpaab = lib.einsum('meBF,jFe->mjB', eris_ovVV, r2aba)
    Hr2aba-= lib.einsum('mjB,ma->jBa', tmpaab, t1a)
    tmpabb = lib.einsum('meBF,JeF->mJB', eris_ovVV, r2bab)
    Hr2bab-= lib.einsum('mJB,ma->JaB', tmpabb, t1a)
    tmpaab = tmpabb = eris_ovVV = None

    tmpbaa = lib.einsum('MEbf,jEf->Mjb', eris_OVvv, r2aba)
    Hr2aba-= lib.einsum('Mjb,MA->jAb', tmpbaa, t1b)
    tmpbba = lib.einsum('MEbf,JfE->MJb', eris_OVvv, r2bab)
    Hr2bab-= lib.einsum('MJb,MA->JbA', tmpbba, t1b)
    tmpbaa = tmpbba = eris_OVvv = None
    #** Wvvvv term end

    # Wvvvo
    # R1 -> R2 coupling
    Hr2aaa += np.einsum('acbj,c->jab', imds.Wvvvo, r1a)
    Hr2bbb += np.einsum('ACBJ,C->JAB', imds.WVVVO, r1b)
    Hr2bab += np.einsum('acBJ,c->JaB', imds.WvvVO, r1a)
    Hr2aba += np.einsum('ACbj,C->jAb', imds.WVVvo, r1b)

    # Wovvo
    # Same-spin blocks are antisymmetrized in (a, b) via tmp - tmp^T.
    tmp2aa = lib.einsum('ldbj,lad->jab', imds.Wovvo, r2aaa)
    tmp2aa += lib.einsum('ldbj,lad->jab', imds.WOVvo, r2bab)
    Hr2aaa += tmp2aa - tmp2aa.transpose(0,2,1)

    Hr2bab += lib.einsum('ldbj,lad->jab', imds.WovVO, r2aaa)
    Hr2bab += lib.einsum('ldbj,lad->jab', imds.WOVVO, r2bab)
    Hr2bab += lib.einsum('ldaj,ldb->jab', imds.WOvvO, r2bab)

    Hr2aba += lib.einsum('ldbj,lad->jab', imds.WOVvo, r2bbb)
    Hr2aba += lib.einsum('ldbj,lad->jab', imds.Wovvo, r2aba)
    Hr2aba += lib.einsum('ldaj,ldb->jab', imds.WoVVo, r2aba)

    tmp2bb = lib.einsum('ldbj,lad->jab', imds.WOVVO, r2bbb)
    tmp2bb += lib.einsum('ldbj,lad->jab', imds.WovVO, r2aba)
    Hr2bbb += tmp2bb - tmp2bb.transpose(0,2,1)

    #Fvv Term
    tmpa = lib.einsum('ac,jcb->jab', imds.Fvv, r2aaa)
    Hr2aaa += tmpa - tmpa.transpose((0,2,1))
    Hr2aba += lib.einsum('AC,jCb->jAb', imds.FVV, r2aba)
    Hr2bab += lib.einsum('ac,JcB->JaB', imds.Fvv, r2bab)
    Hr2aba += lib.einsum('bc, jAc -> jAb', imds.Fvv, r2aba)
    Hr2bab += lib.einsum('BC, JaC -> JaB', imds.FVV, r2bab)
    tmpb = lib.einsum('AC,JCB->JAB', imds.FVV, r2bbb)
    Hr2bbb += tmpb - tmpb.transpose((0,2,1))

    #Foo Term
    Hr2aaa -= lib.einsum('lj,lab->jab', imds.Foo, r2aaa)
    Hr2bbb -= lib.einsum('LJ,LAB->JAB', imds.FOO, r2bbb)
    Hr2bab -= lib.einsum('LJ,LaB->JaB', imds.FOO, r2bab)
    Hr2aba -= lib.einsum('lj,lAb->jAb', imds.Foo, r2aba)

    # Woovv term
    # Three-tensor contractions of Wovov-type blocks with R2 and T2.
    Hr2aaa -= 0.5 * lib.einsum('kcld,lcd,kjab->jab', imds.Wovov, r2aaa, t2aa)
    Hr2bab -= 0.5 * lib.einsum('kcld,lcd,kJaB->JaB', imds.Wovov, r2aaa, t2ab)

    Hr2aba -= lib.einsum('ldKC,lCd,jKbA->jAb', imds.WovOV, r2aba, t2ab)
    Hr2aaa -= lib.einsum('kcLD,LcD,kjab->jab', imds.WovOV, r2bab, t2aa)

    Hr2aba -= 0.5 * lib.einsum('KCLD,LCD,jKbA->jAb', imds.WOVOV, r2bbb, t2ab)
    Hr2bbb -= 0.5 * lib.einsum('KCLD,LCD,KJAB->JAB', imds.WOVOV, r2bbb, t2bb)

    Hr2bbb -= lib.einsum('ldKC,lCd,KJAB->JAB', imds.WovOV, r2aba, t2bb)
    Hr2bab -= lib.einsum('kcLD,LcD,kJaB->JaB', imds.WovOV, r2bab, t2ab)

    vector = amplitudes_to_vector_ea([Hr1a, Hr1b], [Hr2aaa, Hr2aba, Hr2bab, Hr2bbb])
    return vector
625
def _add_vvvv_ea(mycc, r2, eris):
    '''Contract vvvv-type integrals with the four EOM-EA R2 blocks.

    Args:
        mycc : UCCSD object; ``direct``, ``nocc``, ``mol``, ``mo_coeff``,
            ``get_frozen_mask``, ``stdout`` and ``verbose`` may be read.
        r2 : (r2aaa, r2aba, r2bab, r2bbb) with shapes (nocca,nvira,nvira),
            (nocca,nvirb,nvira), (noccb,nvira,nvirb), (noccb,nvirb,nvirb).
        eris : integral object; which attributes are used depends on the
            branch taken.

    Returns:
        (Hr2aaa, Hr2aba, Hr2bab, Hr2bbb) with the same shapes as ``r2``.

    Three code paths:
      * ``mycc.direct``: transform the virtual indices with the virtual
        columns of the MO coefficients (rows = AO index), call
        ``ccsd._contract_vvvv_t2`` once on all blocks stacked along the
        occupied axis, then contract back with the same coefficients.
      * real (float64) R2: use the ``eris._contract_*_t2`` helpers.
      * otherwise (e.g. complex R2): unpack eris.vvvv / vvVV / VVVV into
        full 4-index arrays and contract with einsum.
    '''
    time0 = logger.process_clock(), logger.perf_counter()
    log = logger.Logger(mycc.stdout, mycc.verbose)
    r2aaa, r2aba, r2bab, r2bbb = r2
    nocca, noccb = mycc.nocc

    if mycc.direct:
        # Prefer the orbitals stored on eris; otherwise take the active
        # (non-frozen) columns of the CC object's MO coefficients.
        if getattr(eris, 'mo_coeff', None) is not None:
            mo_a, mo_b = eris.mo_coeff
        else:
            moidxa, moidxb = mycc.get_frozen_mask()
            mo_a = mycc.mo_coeff[0][:,moidxa]
            mo_b = mycc.mo_coeff[1][:,moidxb]

        # Transform both virtual indices of every block; the spin of each
        # index decides whether alpha or beta coefficients are used.
        r2aaa = lib.einsum('xab,pa->xpb', r2aaa, mo_a[:,nocca:])
        r2aaa = lib.einsum('xab,pb->xap', r2aaa, mo_a[:,nocca:])
        r2aba = lib.einsum('xab,pa->xpb', r2aba, mo_b[:,noccb:])
        r2aba = lib.einsum('xab,pb->xap', r2aba, mo_a[:,nocca:])
        r2bab = lib.einsum('xab,pa->xpb', r2bab, mo_a[:,nocca:])
        r2bab = lib.einsum('xab,pb->xap', r2bab, mo_b[:,noccb:])
        r2bbb = lib.einsum('xab,pa->xpb', r2bbb, mo_b[:,noccb:])
        r2bbb = lib.einsum('xab,pb->xap', r2bbb, mo_b[:,noccb:])

        # Stack along the occupied axis so one contraction handles all blocks.
        r2 = np.vstack((r2aaa, r2aba, r2bab, r2bbb))
        r2aaa = r2aba = r2bab = r2bbb = None
        time0 = log.timer_debug1('vvvv-tau', *time0)

        # vvvv argument is None here (presumably evaluated AO-direct inside
        # the helper -- confirm in ccsd._contract_vvvv_t2).
        buf = ccsd._contract_vvvv_t2(mycc, mycc.mol, None, r2, verbose=log)
        # Undo the stacking: rows [0,nocca) aaa, [nocca,2nocca) aba,
        # next noccb rows bab, the rest bbb.
        sections = np.cumsum([nocca,nocca,noccb])
        Hr2aaa, Hr2aba, Hr2bab, Hr2bbb = np.split(buf, sections)
        buf = None

        Hr2aaa = lib.einsum('xpb,pa->xab', Hr2aaa, mo_a[:,nocca:])
        Hr2aaa = lib.einsum('xap,pb->xab', Hr2aaa, mo_a[:,nocca:])
        Hr2aba = lib.einsum('xpb,pa->xab', Hr2aba, mo_b[:,noccb:])
        Hr2aba = lib.einsum('xap,pb->xab', Hr2aba, mo_a[:,nocca:])
        Hr2bab = lib.einsum('xpb,pa->xab', Hr2bab, mo_a[:,nocca:])
        Hr2bab = lib.einsum('xap,pb->xab', Hr2bab, mo_b[:,noccb:])
        Hr2bbb = lib.einsum('xpb,pa->xab', Hr2bbb, mo_b[:,noccb:])
        Hr2bbb = lib.einsum('xap,pb->xab', Hr2bbb, mo_b[:,noccb:])

    elif r2aaa.dtype == np.double:
        # r2aba is (occ_a, vir_b, vir_a); it is transposed to (occ_a, vir_a,
        # vir_b) for the vvVV helper and the result transposed back.  The
        # helper appears to expect that ordering -- confirm in eris.
        r2aab = np.asarray(r2aba.transpose(0,2,1), order='C')
        Hr2aab = eris._contract_vvVV_t2(mycc, r2aab, mycc.direct, None)
        Hr2aba = np.asarray(Hr2aab.transpose(0,2,1), order='C')
        r2aab = Hr2aab = None
        Hr2bab = eris._contract_vvVV_t2(mycc, r2bab, mycc.direct, None)
        Hr2aaa = eris._contract_vvvv_t2(mycc, r2aaa, mycc.direct, None)
        Hr2bbb = eris._contract_VVVV_t2(mycc, r2bbb, mycc.direct, None)

    else:
        # Generic fallback: expand the packed integrals to full 4-index
        # arrays and contract explicitly.
        noccb, nvira, nvirb = r2bab.shape
        eris_vvvv = ao2mo.restore(1, np.asarray(eris.vvvv), nvira)
        Hr2aaa = lib.einsum('acbd,jcd->jab', eris_vvvv, r2aaa)
        eris_vvvv = None

        eris_VVVV = ao2mo.restore(1, np.asarray(eris.VVVV), nvirb)
        Hr2bbb = lib.einsum('acbd,jcd->jab', eris_VVVV, r2bbb)
        eris_VVVV = None

        # Map packed pair indices to square (a, b) indices for each spin.
        sqa = lib.square_mat_in_trilu_indices(nvira)
        sqb = lib.square_mat_in_trilu_indices(nvirb)
        eris_vvVV = np.asarray(eris.vvVV)[:,sqb][sqa]
        Hr2aba = lib.einsum('bdac,jcd->jab', eris_vvVV, r2aba)
        Hr2bab = lib.einsum('acbd,jcd->jab', eris_vvVV, r2bab)
        eris_vvVV = None

    return Hr2aaa, Hr2aba, Hr2bab, Hr2bbb
694
def eaccsd_diag(eom, imds=None):
    '''Diagonal of the EOM-EA-CCSD Hamiltonian, packed like an EOM-EA vector.

    Terms mirror the bra == ket elements of :func:`eaccsd_matvec`: Fvv for
    the R1 part; Fvv, Foo, Wovvo-type and Wovov*T2 pieces plus a vvvv-type
    contribution (tau*ovov, t1*ovvv and, when ``eris.vvvv`` is available,
    the bare vvvv integrals) for the R2 blocks.

    Args:
        eom : EOMEA object; ``eom.make_imds()`` is called when imds is None
            and ``eom.max_memory`` bounds the ovvv block sizes.
        imds : intermediates object supplying t1, t2, eris and F*/W* blocks.

    Returns:
        1D ndarray in the layout of :func:`amplitudes_to_vector_ea`.
    '''
    if imds is None: imds = eom.make_imds()
    eris = imds.eris
    t1, t2 = imds.t1, imds.t2
    t1a, t1b = t1
    t2aa, t2ab, t2bb = t2
    t2ba = t2ab.transpose(1,0,3,2)

    nocca, nvira = t1a.shape
    noccb, nvirb = t1b.shape

    # R1 part: only Fvv couples R1 to itself in eaccsd_matvec
    Hr1a = np.diag(imds.Fvv)
    Hr1b = np.diag(imds.FVV)

    #-------------- intermediates

    Fvv_diag = np.diag(imds.Fvv)
    Foo_diag = np.diag(imds.Foo)
    FOO_diag = np.diag(imds.FOO)
    FVV_diag = np.diag(imds.FVV)

    # Diagonal slices, e.g. Wovvo_slice[j,b] = Wovvo[j,b,b,j]; W*t2 dots
    # leave (occ, vir, vir) arrays matching the R2 block shapes.
    Wovvo_slice = np.einsum('jbbj->jb',imds.Wovvo)
    Wovov_t2_dot = np.einsum('iajb,ijab->jab',imds.Wovov,t2aa)
    WoVVo_slice  = np.einsum('jaaj->ja',imds.WoVVo)
    WovOV_t2_dot = np.einsum('jbia,ijab->jab',imds.WovOV,t2ba)
    WOVVO_slice = np.einsum('jaaj->ja',imds.WOVVO)
    WOvvO_slice = np.einsum('jbbj->jb',imds.WOvvO)
    WovOV_t2_dot_T = np.einsum('ibja,ijba->jab',imds.WovOV,t2ab)
    WOVOV_t2_dot = np.einsum('iajb,ijab->jab',imds.WOVOV,t2bb)

    #-------------- contraction

    Hr2aaa = Fvv_diag[None,:,None]+Fvv_diag[None,None,:]-Foo_diag[:,None,None]+ \
             Wovvo_slice[:,None,:]+Wovvo_slice[:,:,None]-Wovov_t2_dot

    Hr2aba = FVV_diag[None,:,None]+Fvv_diag[None,None,:]-Foo_diag[:,None,None]+ \
             Wovvo_slice[:,None,:]+WoVVo_slice[:,:,None]-WovOV_t2_dot

    # Built in (occ_b, vir_b, vir_a) order, then transposed to the
    # (occ_b, vir_a, vir_b) shape of r2bab.
    Hr2bab = -FOO_diag[:,None,None]+FVV_diag[None,:,None]+Fvv_diag[None,None,:]+ \
             WOVVO_slice[:,:,None]+WOvvO_slice[:,None,:]-WovOV_t2_dot_T
    Hr2bab = Hr2bab.transpose(0,2,1)

    Hr2bbb = -FOO_diag[:,None,None]+FVV_diag[None,:,None]+FVV_diag[None,None,:]+ \
             WOVVO_slice[:,:,None]+WOVVO_slice[:,None,:]-WOVOV_t2_dot

    # Disabled variant that would read an explicit imds.Wvvvv intermediate;
    # the active code below builds the same-shaped correction instead.
#    if imds.Wvvvv is not None:
#        Wvvvv_slice = np.einsum('aabb->ab',imds.Wvvvv)
#        Hr2aaa += 0.5 * Wvvvv_slice[None,:,:]
#        WVVvv_slice = np.einsum('aabb->ba',imds.WvvVV)
#        Hr2aba += WVVvv_slice[None,:,:]
#        WvvVV_slice = np.einsum('aabb->ab',imds.WvvVV)
#        Hr2bab += WvvVV_slice[None,:,:]
#        WVVVV_slice = np.einsum('aabb->ab',imds.WVVVV)
#        Hr2bbb += 0.5 * WVVVV_slice[None,:,:]

# TODO: test Wvvvv contribution
    # See also the code for Wvvvv contribution in function eeccsd_diag
    # tau * ovov part of the (a, b) diagonal of the vvvv-type term
    tauaa, tauab, taubb = uccsd.make_tau(t2, t1, t1)
    eris_ovov = np.asarray(eris.ovov)
    eris_OVOV = np.asarray(eris.OVOV)
    eris_ovOV = np.asarray(eris.ovOV)
    Wvvaa = .5*np.einsum('mnab,manb->ab', tauaa, eris_ovov)
    Wvvbb = .5*np.einsum('mnab,manb->ab', taubb, eris_OVOV)
    Wvvab =    np.einsum('mNaB,maNB->aB', tauab, eris_ovOV)
    eris_ovov = eris_OVOV = eris_ovOV = None

    # t1 * ovvv part, read in occupied-index blocks.  blksize is the number
    # of occupied rows whose slab fits in the free memory (max_memory is
    # treated as MB and 8 bytes per element; the factor 3 presumably leaves
    # room for temporaries).
    mem_now = lib.current_memory()[0]
    max_memory = max(0, eom.max_memory - mem_now)
    blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira**3*3))))
    for p0,p1 in lib.prange(0, nocca, blksize):
        ovvv = eris.get_ovvv(slice(p0,p1))  # ovvv = eris.ovvv[p0:p1]
        Wvvaa += np.einsum('mb,maab->ab', t1a[p0:p1], ovvv)
        Wvvaa -= np.einsum('mb,mbaa->ab', t1a[p0:p1], ovvv)
        ovvv = None
    blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb**3*3))))
    for p0, p1 in lib.prange(0, noccb, blksize):
        OVVV = eris.get_OVVV(slice(p0,p1))  # OVVV = eris.OVVV[p0:p1]
        Wvvbb += np.einsum('mb,maab->ab', t1b[p0:p1], OVVV)
        Wvvbb -= np.einsum('mb,mbaa->ab', t1b[p0:p1], OVVV)
        OVVV = None
    blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira*nvirb**2*3))))
    for p0,p1 in lib.prange(0, nocca, blksize):
        ovVV = eris.get_ovVV(slice(p0,p1))  # ovVV = eris.ovVV[p0:p1]
        Wvvab -= np.einsum('mb,mbaa->ba', t1a[p0:p1], ovVV)
        ovVV = None
    blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb*nvira**2*3))))
    for p0, p1 in lib.prange(0, noccb, blksize):
        OVvv = eris.get_OVvv(slice(p0,p1))  # OVvv = eris.OVvv[p0:p1]
        Wvvab -= np.einsum('mb,mbaa->ab', t1b[p0:p1], OVvv)
        OVvv = None
    # Symmetrize the same-spin pieces in (a, b)
    Wvvaa = Wvvaa + Wvvaa.T
    Wvvbb = Wvvbb + Wvvbb.T
    if eris.vvvv is not None:
        # Bare-integral part.  The row offset i0 = i*(i+1)//2 indicates the
        # vv pair index is lower-triangular packed, so rows i0..i0+i are
        # the pairs (i,0)..(i,i); unpack_tril expands each to a square.
        for i in range(nvira):
            i0 = i*(i+1)//2
            vvv = lib.unpack_tril(np.asarray(eris.vvvv[i0:i0+i+1]))
            tmp = np.einsum('bb->b', vvv[i])
            Wvvaa[i] += tmp
            # exchange-like part, applied symmetrically to row i and column i
            tmp = np.einsum('bb->b', vvv[:,:i+1,i])
            Wvvaa[i,:i+1] -= tmp
            Wvvaa[:i  ,i] -= tmp[:i]
            vvv = lib.unpack_tril(np.asarray(eris.vvVV[i0:i0+i+1]))
            Wvvab[i] += np.einsum('bb->b', vvv[i])
            vvv = None
        for i in range(nvirb):
            i0 = i*(i+1)//2
            vvv = lib.unpack_tril(np.asarray(eris.VVVV[i0:i0+i+1]))
            tmp = np.einsum('bb->b', vvv[i])
            Wvvbb[i] += tmp
            tmp = np.einsum('bb->b', vvv[:,:i+1,i])
            Wvvbb[i,:i+1] -= tmp
            Wvvbb[:i  ,i] -= tmp[:i]
            vvv = None
    # Wvvab is (vir_a, vir_b); r2aba's virtual order is (vir_b, vir_a).
    Wvvba = Wvvab.T

    Hr2aaa += Wvvaa[None,:,:]
    Hr2aba += Wvvba[None,:,:]
    Hr2bab += Wvvab[None,:,:]
    Hr2bbb += Wvvbb[None,:,:]
    # Wvvvv contribution end

    vector = amplitudes_to_vector_ea((Hr1a,Hr1b), (Hr2aaa,Hr2aba,Hr2bab,Hr2bbb))
    return vector
818
819
class EOMEA(eom_rccsd.EOMEA):
    '''EOM-EA-CCSD for an unrestricted (UCCSD) reference.

    Amplitudes are kept in spatial-orbital blocks, r1 = (r1a, r1b) and
    r2 = (r2aaa, r2aba, r2bab, r2bbb); see vector_to_amplitudes_ea and
    amplitudes_to_vector_ea for the packed vector layout.
    '''
    matvec = eaccsd_matvec
    # No left-hand matvec and no EA-CCSD* correction for this solver.
    l_matvec = None
    get_diag = eaccsd_diag
    eaccsd_star = None
    ccsd_star_contract = None

    def __init__(self, cc):
        eom_rccsd.EOMEA.__init__(self, cc)
        # (alpha, beta) pairs of occupied and total orbital counts
        self.nocc = cc.get_nocc()
        self.nmo = cc.get_nmo()

    def get_init_guess(self, nroots=1, koopmans=True, diag=None):
        '''Unit-vector guesses placed on the smallest diagonal elements.

        Kwargs:
            nroots : int
                Number of guess vectors requested; capped at vector_size().
            koopmans : bool
                If True, only the 1p part of the vector (its first
                nvira+nvirb elements) is searched, so fewer than nroots
                vectors may be returned.
            diag : ndarray
                Diagonal of the EOM-EA Hamiltonian.  Computed with
                self.get_diag() when not supplied.

        Returns:
            list of ndarray, each of length vector_size().
        '''
        if diag is None:
            # Same fallback as the restricted base class; previously a None
            # diag crashed on the argsort below.
            diag = self.get_diag()
        if koopmans:
            nocca, noccb = self.nocc
            nmoa, nmob = self.nmo
            nvira, nvirb = nmoa-nocca, nmob-noccb
            idx = diag[:nvira+nvirb].argsort()
        else:
            idx = diag.argsort()

        size = self.vector_size()
        dtype = getattr(diag, 'dtype', np.double)
        nroots = min(nroots, size)
        guess = []
        for i in idx[:nroots]:
            g = np.zeros(size, dtype)
            g[i] = 1.0
            guess.append(g)
        return guess

    def vector_to_amplitudes(self, vector, nmo=None, nocc=None):
        '''Unpack a flat EA vector; nmo/nocc default to this solver's values.'''
        if nmo is None: nmo = self.nmo
        if nocc is None: nocc = self.nocc
        return vector_to_amplitudes_ea(vector, nmo, nocc)

    def amplitudes_to_vector(self, r1, r2):
        '''Pack (r1a, r1b), (r2aaa, r2aba, r2bab, r2bbb) into a flat vector.'''
        return amplitudes_to_vector_ea(r1, r2)

    def vector_size(self):
        '''size of the vector based on spin-orbital basis'''
        nocca, noccb = self.nocc
        nmoa, nmob = self.nmo
        nvira, nvirb = nmoa-nocca, nmob-noccb
        return (nvira + nvirb
                + nocca*nvira*(nvira-1)//2 + nocca*nvirb*nvira
                + noccb*nvira*nvirb + noccb*nvirb*(nvirb-1)//2)

    def make_imds(self, eris=None):
        '''Build the _IMDS intermediates and their EA-specific parts.'''
        imds = _IMDS(self._cc, eris=eris)
        imds.make_ea()
        return imds
872
873########################################
874# EOM-EE-CCSD
875########################################
876
def eeccsd(eom, nroots=1, koopmans=False, guess=None, eris=None, imds=None):
    '''Calculate N-electron neutral excitations via EOM-EE-CCSD.

    The problem is solved as two independent sub-problems, a spin-conserving
    one (EOMEESpinKeep) and a spin-flip one (EOMEESpinFlip).  Their roots
    are merged and sorted by energy.

    Kwargs:
        nroots : int
            Number of roots (eigenvalues) requested
        koopmans : bool
            Calculate Koopmans'-like (1p1h) excitations only, targeting via
            overlap.
        guess : list of ndarray
            List of guess vectors to use for targeting via overlap.  Vectors
            whose size matches the spin-conserving diagonal go to the
            spin-conserving solver; all others go to the spin-flip solver.

    Returns:
        (e, v), also stored on eom.e and eom.v (eom.converged is set too).
        For nroots == 1 the single root is returned unwrapped.
    '''
    if eris is None:
        eris = eom._cc.ao2mo()
    if imds is None:
        imds = eom.make_imds(eris)

    spinvec_size = eom.vector_size()
    nroots = min(nroots, spinvec_size)

    diag_ee, diag_sf = eom.get_diag(imds)

    if guess and guess[0].size == spinvec_size:
        raise NotImplementedError
        #TODO: initial guess from GCCSD EOM amplitudes
        #from pyscf.cc import addons
        #from pyscf.cc import eom_gccsd
        #orbspin = scf.addons.get_ghf_orbspin(eris.mo_coeff)
        #nmo = np.sum(eom.nmo)
        #nocc = np.sum(eom.nocc)
        #for g in guess:
        #    r1, r2 = eom_gccsd.vector_to_amplitudes_ee(g, nmo, nocc)
        #    r1aa = r1[orbspin==0][:,orbspin==0]
        #    r1ab = r1[orbspin==0][:,orbspin==1]
        #    if abs(r1aa).max() > 1e-7:
        #        r1 = addons.spin2spatial(r1, orbspin)
        #        r2 = addons.spin2spatial(r2, orbspin)
        #        guess_ee.append(eom.amplitudes_to_vector(r1, r2))
        #    else:
        #        r1 = spin2spatial_eomsf(r1, orbspin)
        #        r2 = spin2spatial_eomsf(r2, orbspin)
        #        guess_sf.append(amplitudes_to_vector_eomsf(r1, r2))
        #    r1 = r2 = r1aa = r1ab = g = None
        #nroots_ee = len(guess_ee)
        #nroots_sf = len(guess_sf)

    if guess:
        # route each user guess by its length
        guess_ee = [g for g in guess if g.size == diag_ee.size]
        guess_sf = [g for g in guess if g.size != diag_ee.size]
        nroots_ee = len(guess_ee)
        nroots_sf = len(guess_sf)
    else:
        # Split nroots between the two sub-problems by taking the nroots
        # smallest diagonal elements across both diagonals.
        lowest_ee = np.sort(diag_ee)[:nroots]
        lowest_sf = np.sort(diag_sf)[:nroots]
        cutoff = np.sort(np.hstack([lowest_ee, lowest_sf]))[nroots-1]
        nroots_ee = np.count_nonzero(lowest_ee <= cutoff)
        nroots_sf = np.count_nonzero(lowest_sf <= cutoff)
        guess_ee = guess_sf = None

    def run_sub_problem(solver_cls, n, sub_guess, sub_diag):
        solver = solver_cls(eom._cc)
        solver.__dict__.update(eom.__dict__)
        energies, vectors = solver.kernel(n, koopmans, sub_guess, eris, imds,
                                          diag=sub_diag)
        if n == 1:
            # kernel returns unwrapped values for one root; normalize to lists
            energies, vectors = [energies], [vectors]
            solver.converged = [solver.converged]
        return list(solver.converged), list(energies), list(vectors)

    conv_ee, e_ee, v_ee = [], [], []
    conv_sf, e_sf, v_sf = [], [], []
    if nroots_ee > 0:
        conv_ee, e_ee, v_ee = run_sub_problem(EOMEESpinKeep, nroots_ee,
                                              guess_ee, diag_ee)
    if nroots_sf > 0:
        conv_sf, e_sf, v_sf = run_sub_problem(EOMEESpinFlip, nroots_sf,
                                              guess_sf, diag_sf)

    all_e = np.hstack([e_ee, e_sf])
    order = all_e.argsort()
    all_conv = conv_ee + conv_sf
    all_v = v_ee + v_sf
    e = all_e[order]
    conv = [all_conv[k] for k in order]
    v = [all_v[k] for k in order]

    if nroots == 1:
        conv, e, v = conv[0], e[0], v[0]
    eom.converged = conv
    eom.e = e
    eom.v = v
    return eom.e, eom.v
970
def eomee_ccsd(eom, nroots=1, koopmans=False, guess=None,
               eris=None, imds=None, diag=None):
    '''Run the generic EOM kernel for one EOM-EE sub-problem.

    Builds eris (via eom._cc.ao2mo) and imds (via eom.make_imds) when they
    are not given, calls eom_rccsd.kernel, stores converged/e/v on eom and
    returns (eom.e, eom.v).
    '''
    if eris is None:
        eris = eom._cc.ao2mo()
    if imds is None:
        imds = eom.make_imds(eris)
    result = eom_rccsd.kernel(eom, nroots, koopmans, guess, imds=imds, diag=diag)
    eom.converged, eom.e, eom.v = result
    return eom.e, eom.v
978
def eomsf_ccsd(eom, nroots=1, koopmans=False, guess=None,
               eris=None, imds=None, diag=None):
    '''Spin flip EOM-EE-CCSD

    Uses the same driver as eomee_ccsd; the spin-flip specifics come from
    the eom object that is passed in.
    '''
    return eomee_ccsd(eom, nroots=nroots, koopmans=koopmans, guess=guess,
                      eris=eris, imds=imds, diag=diag)
984
# Spin-conserving EOM-EE vectors use exactly the UCCSD T1/T2 packing.
amplitudes_to_vector_ee, vector_to_amplitudes_ee = (uccsd.amplitudes_to_vector,
                                                    uccsd.vector_to_amplitudes)
987
def amplitudes_to_vector_eomsf(t1, t2, out=None):
    '''Pack spin-flip EOM amplitudes into one flat vector.

    Args:
        t1 : (t1ab, t1ba) with shapes (nocca, nvirb) and (noccb, nvira).
        t2 : (t2baaa, t2aaba, t2abbb, t2bbab).  Only the strictly lower
            triangle of the same-spin index pair is stored: the virtual pair
            (a > b) of t2baaa/t2abbb and the occupied pair (i > j) of
            t2aaba/t2bbab.
        out : unused.

    Returns:
        1D ndarray: t1ab, t1ba, baaa, aaba, abbb, bbab sections in order.
    '''
    r1ab, r1ba = t1
    r2baaa, r2aaba, r2abbb, r2bbab = t2
    nocca, nvirb = r1ab.shape
    noccb, nvira = r1ba.shape

    occ_pairs_a = np.tril_indices(nocca, k=-1)
    occ_pairs_b = np.tril_indices(noccb, k=-1)
    vir_pairs_a = np.tril_indices(nvira, k=-1)
    vir_pairs_b = np.tril_indices(nvirb, k=-1)

    # select a > b in the trailing (virtual) pair
    packed_baaa = r2baaa.reshape(noccb, nocca, nvira, nvira)[
        :, :, vir_pairs_a[0], vir_pairs_a[1]]
    packed_abbb = r2abbb.reshape(nocca, noccb, nvirb, nvirb)[
        :, :, vir_pairs_b[0], vir_pairs_b[1]]
    # select i > j in the leading (occupied) pair
    packed_aaba = r2aaba[occ_pairs_a[0], occ_pairs_a[1]]
    packed_bbab = r2bbab[occ_pairs_b[0], occ_pairs_b[1]]

    sections = [r1ab, r1ba, packed_baaa, packed_aaba, packed_abbb, packed_bbab]
    return np.concatenate([part.ravel() for part in sections])
1006
def vector_to_amplitudes_eomsf(vector, nmo, nocc):
    '''Unpack a spin-flip EOM vector (inverse of amplitudes_to_vector_eomsf).

    Args:
        vector : 1D ndarray
        nmo : (nmoa, nmob)
        nocc : (nocca, noccb)

    Returns:
        (r1ab, r1ba), (r2baaa, r2aaba, r2abbb, r2bbab).  The same-spin index
        pair of each R2 block is rebuilt antisymmetrically from its stored
        lower triangle:
          * the virtual pair for baaa/abbb;
          * the occupied pair for aaba/bbab.
        The diagonal of that pair stays zero.
    '''
    nocca, noccb = nocc
    nmoa, nmob = nmo
    nvira = nmoa - nocca
    nvirb = nmob - noccb
    dtype = vector.dtype

    section_sizes = (nocca*nvirb,
                     noccb*nvira,
                     noccb*nocca*nvira*(nvira-1)//2,
                     nocca*(nocca-1)//2*nvirb*nvira,
                     nocca*noccb*nvirb*(nvirb-1)//2,
                     noccb*(noccb-1)//2*nvira*nvirb)
    v1ab, v1ba, vbaaa, vaaba, vabbb, vbbab = np.split(
        vector, np.cumsum(section_sizes[:-1]))

    r1ab = v1ab.reshape(nocca,nvirb).copy()
    r1ba = v1ba.reshape(noccb,nvira).copy()

    def unpack_vir_pair(packed, nrow, nvir):
        # rows: mixed-spin occupied pair; columns: antisymmetric virtual pair
        full = np.zeros((nrow,nvir*nvir), dtype=dtype)
        big, small = np.tril_indices(nvir, k=-1)  # big > small
        rows = np.arange(nrow, dtype=np.int32)
        packed = packed.reshape(nrow,-1)
        lib.takebak_2d(full, packed, rows, big*nvir+small)
        lib.takebak_2d(full,-packed, rows, small*nvir+big)
        return full

    def unpack_occ_pair(packed, nocc_same, ncol):
        # rows: antisymmetric occupied pair; columns: mixed-spin virtual pair
        full = np.zeros((nocc_same*nocc_same,ncol), dtype=dtype)
        big, small = np.tril_indices(nocc_same, k=-1)  # big > small
        cols = np.arange(ncol, dtype=np.int32)
        packed = packed.reshape(-1,ncol)
        lib.takebak_2d(full, packed, big*nocc_same+small, cols)
        lib.takebak_2d(full,-packed, small*nocc_same+big, cols)
        return full

    r2baaa = unpack_vir_pair(vbaaa, noccb*nocca, nvira)
    r2aaba = unpack_occ_pair(vaaba, nocca, nvirb*nvira)
    r2abbb = unpack_vir_pair(vabbb, nocca*noccb, nvirb)
    r2bbab = unpack_occ_pair(vbbab, noccb, nvira*nvirb)
    return ((r1ab, r1ba),
            (r2baaa.reshape(noccb,nocca,nvira,nvira),
             r2aaba.reshape(nocca,nocca,nvirb,nvira),
             r2abbb.reshape(nocca,noccb,nvirb,nvirb),
             r2bbab.reshape(noccb,noccb,nvira,nvirb)))
1051
def spatial2spin_eomsf(rx, orbspin):
    '''Convert EOM spatial R1,R2 to spin-orbital R1,R2

    Args:
        rx : (r1ab, r1ba) for R1 (any 2-tuple is treated as R1), or
            (r2baaa, r2aaba, r2abbb, r2bbab) for R2.
        orbspin : array of 0 (alpha) / 1 (beta) labels; the first
            nocca+noccb entries label the occupied spin orbitals.

    Returns:
        R1 as an (nocc, nvir) array, or R2 as an (nocc, nocc, nvir, nvir)
        array.  Each R2 block is also written to its pair-swapped position
        (both occupied and both virtual indices exchanged) with the same sign.
    '''
    is_r1 = len(rx) == 2
    if is_r1:
        r1ab, r1ba = rx
        nocca, nvirb = r1ab.shape
        noccb, nvira = r1ba.shape
    else:
        r2baaa, r2aaba, r2abbb, r2bbab = rx
        noccb, nocca, nvira = r2baaa.shape[:3]
        nvirb = r2aaba.shape[2]

    nocc = nocca + noccb
    nvir = nvira + nvirb
    occ_spin = orbspin[:nocc]
    vir_spin = orbspin[nocc:]
    oa = np.where(occ_spin == 0)[0]
    ob = np.where(occ_spin == 1)[0]
    va = np.where(vir_spin == 0)[0]
    vb = np.where(vir_spin == 1)[0]

    if is_r1:
        r1 = np.zeros((nocc,nvir), dtype=r1ab.dtype)
        lib.takebak_2d(r1, r1ab, oa, vb)
        lib.takebak_2d(r1, r1ba, ob, va)
        return r1

    def compound(first, second, n):
        # 2D grid of flattened pair indices first*n + second
        return first[:,None] * n + second

    r2 = np.zeros((nocc**2,nvir**2), dtype=r2aaba.dtype)
    # (block flattened to 2D, occ1, occ2, vir1, vir2)
    blocks = (
        (r2baaa.reshape(noccb*nocca,nvira*nvira), ob, oa, va, va),
        (r2aaba.reshape(nocca*nocca,nvirb*nvira), oa, oa, vb, va),
        (r2abbb.reshape(nocca*noccb,nvirb*nvirb), oa, ob, vb, vb),
        (r2bbab.reshape(noccb*noccb,nvira*nvirb), ob, ob, va, vb),
    )
    for block, o1, o2, v1, v2 in blocks:
        lib.takebak_2d(r2, block, compound(o1, o2, nocc).ravel(),
                       compound(v1, v2, nvir).ravel())
    for block, o1, o2, v1, v2 in blocks:
        # image with both index pairs swapped: R[j,i,b,a] = R[i,j,a,b]
        lib.takebak_2d(r2, block, compound(o2, o1, nocc).T.ravel(),
                       compound(v2, v1, nvir).T.ravel())
    return r2.reshape(nocc,nocc,nvir,nvir)
1099
def spin2spatial_eomsf(rx, orbspin):
    '''Convert EOM spin-orbital R1,R2 to spatial R1,R2

    Args:
        rx : spin-orbital R1 of shape (nocc, nvir) or R2 of shape
            (nocc, nocc, nvir, nvir).
        orbspin : array of 0 (alpha) / 1 (beta) labels; entries [:nocc]
            label the occupied spin orbitals.

    Returns:
        (r1ab, r1ba) for R1, or (r2baaa, r2aaba, r2abbb, r2bbab) for R2.
    '''
    is_r1 = rx.ndim == 2
    if is_r1:
        nocc, nvir = rx.shape
    else:
        nocc, nvir = rx.shape[1:3]

    occ_spin = orbspin[:nocc]
    vir_spin = orbspin[nocc:]
    oa = np.where(occ_spin == 0)[0]
    ob = np.where(occ_spin == 1)[0]
    va = np.where(vir_spin == 0)[0]
    vb = np.where(vir_spin == 1)[0]

    if is_r1:
        r1ab = lib.take_2d(rx, oa, vb)
        r1ba = lib.take_2d(rx, ob, va)
        return r1ab, r1ba

    nocca, noccb = len(oa), len(ob)
    nvira, nvirb = len(va), len(vb)

    def flat_pairs(first, second, n):
        # flattened pair indices first*n + second, first index slowest
        return (first[:,None] * n + second).ravel()

    r2 = rx.reshape(nocc**2,nvir**2)
    r2baaa = lib.take_2d(r2, flat_pairs(ob, oa, nocc), flat_pairs(va, va, nvir))
    r2aaba = lib.take_2d(r2, flat_pairs(oa, oa, nocc), flat_pairs(vb, va, nvir))
    r2abbb = lib.take_2d(r2, flat_pairs(oa, ob, nocc), flat_pairs(vb, vb, nvir))
    r2bbab = lib.take_2d(r2, flat_pairs(ob, ob, nocc), flat_pairs(va, vb, nvir))
    return (r2baaa.reshape(noccb,nocca,nvira,nvira),
            r2aaba.reshape(nocca,nocca,nvirb,nvira),
            r2abbb.reshape(nocca,noccb,nvirb,nvirb),
            r2bbab.reshape(noccb,noccb,nvira,nvirb))
1139
1140# Ref: Wang, Tu, and Wang, J. Chem. Theory Comput. 10, 5567 (2014) Eqs.(9)-(10)
1141# Note: Last line in Eq. (10) is superfluous.
# See, e.g., Gwaltney, Nooijen, and Bartlett, Chem. Phys. Lett. 248, 189 (1996)
def eomee_ccsd_matvec(eom, vector, imds=None):
    '''Spin-conserving EOM-EE-CCSD sigma vector H*R.

    Args:
        eom : EOM object.  Its max_memory sizes the blocked ovvv loops, and
            eom._cc._add_vvvv does the vvvv contraction.
        vector : 1D ndarray in the vector_to_amplitudes_ee layout, unpacked
            to (r1a, r1b), (r2aa, r2ab, r2bb).
        imds : intermediates providing t1, t2, eris and the F*/w* blocks.
            Built with eom.make_imds() when None.

    Returns:
        1D ndarray H*R, packed with amplitudes_to_vector_ee.
    '''
    if imds is None: imds = eom.make_imds()

    t1, t2, eris = imds.t1, imds.t2, imds.eris
    t1a, t1b = t1
    t2aa, t2ab, t2bb = t2
    nocca, noccb, nvira, nvirb = t2ab.shape
    nmoa, nmob = nocca+nvira, noccb+nvirb
    r1, r2 = vector_to_amplitudes_ee(vector, (nmoa,nmob), (nocca,noccb))
    r1a, r1b = r1
    r2aa, r2ab, r2bb = r2

    #:Hr2aa += lib.einsum('ijef,aebf->ijab', tau2aa, eris.vvvv) * .5
    #:Hr2bb += lib.einsum('ijef,aebf->ijab', tau2bb, eris.VVVV) * .5
    #:Hr2ab += lib.einsum('iJeF,aeBF->iJaB', tau2ab, eris.vvVV)
    tau2aa, tau2ab, tau2bb = uccsd.make_tau(r2, r1, t1, 2)
    Hr2aa, Hr2ab, Hr2bb = eom._cc._add_vvvv(None, (tau2aa,tau2ab,tau2bb), eris)
    Hr2aa *= .5
    Hr2bb *= .5
    tau2aa = tau2ab = tau2bb = None

    Hr1a  = lib.einsum('ae,ie->ia', imds.Fvva, r1a)
    Hr1a -= lib.einsum('mi,ma->ia', imds.Fooa, r1a)
    Hr1a += np.einsum('me,imae->ia',imds.Fova, r2aa)
    Hr1a += np.einsum('ME,iMaE->ia',imds.Fovb, r2ab)
    Hr1b  = lib.einsum('ae,ie->ia', imds.Fvvb, r1b)
    Hr1b -= lib.einsum('mi,ma->ia', imds.Foob, r1b)
    Hr1b += np.einsum('me,imae->ia',imds.Fovb, r2bb)
    Hr1b += np.einsum('me,mIeA->IA',imds.Fova, r2ab)

    Hr2aa += lib.einsum('minj,mnab->ijab', imds.woooo, r2aa) * .25
    Hr2bb += lib.einsum('minj,mnab->ijab', imds.wOOOO, r2bb) * .25
    Hr2ab += lib.einsum('miNJ,mNaB->iJaB', imds.wooOO, r2ab)
    Hr2aa += lib.einsum('be,ijae->ijab', imds.Fvva, r2aa)
    Hr2bb += lib.einsum('be,ijae->ijab', imds.Fvvb, r2bb)
    Hr2ab += lib.einsum('BE,iJaE->iJaB', imds.Fvvb, r2ab)
    Hr2ab += lib.einsum('be,iJeA->iJbA', imds.Fvva, r2ab)
    Hr2aa -= lib.einsum('mj,imab->ijab', imds.Fooa, r2aa)
    Hr2bb -= lib.einsum('mj,imab->ijab', imds.Foob, r2bb)
    Hr2ab -= lib.einsum('MJ,iMaB->iJaB', imds.Foob, r2ab)
    Hr2ab -= lib.einsum('mj,mIaB->jIaB', imds.Fooa, r2ab)

    #:tau2aa, tau2ab, tau2bb = uccsd.make_tau(r2, r1, t1, 2)
    #:eris_ovvv = lib.unpack_tril(np.asarray(eris.ovvv).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvira,nvira)
    #:eris_ovVV = lib.unpack_tril(np.asarray(eris.ovVV).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvirb,nvirb)
    #:eris_OVvv = lib.unpack_tril(np.asarray(eris.OVvv).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvira,nvira)
    #:eris_OVVV = lib.unpack_tril(np.asarray(eris.OVVV).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvirb,nvirb)
    #:Hr1a += lib.einsum('mfae,imef->ia', eris_ovvv, r2aa)
    #:tmpaa = lib.einsum('meaf,ijef->maij', eris_ovvv, tau2aa)
    #:Hr2aa+= lib.einsum('mb,maij->ijab', t1a, tmpaa)
    #:tmpa = lib.einsum('mfae,me->af', eris_ovvv, r1a)
    #:tmpa-= lib.einsum('meaf,me->af', eris_ovvv, r1a)

    #:Hr1b += lib.einsum('mfae,imef->ia', eris_OVVV, r2bb)
    #:tmpbb = lib.einsum('meaf,ijef->maij', eris_OVVV, tau2bb)
    #:Hr2bb+= lib.einsum('mb,maij->ijab', t1b, tmpbb)
    #:tmpb = lib.einsum('mfae,me->af', eris_OVVV, r1b)
    #:tmpb-= lib.einsum('meaf,me->af', eris_OVVV, r1b)

    #:Hr1b += lib.einsum('mfAE,mIfE->IA', eris_ovVV, r2ab)
    #:tmpab = lib.einsum('meAF,iJeF->mAiJ', eris_ovVV, tau2ab)
    #:Hr2ab-= lib.einsum('mb,mAiJ->iJbA', t1a, tmpab)
    #:tmpb-= lib.einsum('meAF,me->AF', eris_ovVV, r1a)

    #:Hr1a += lib.einsum('MFae,iMeF->ia', eris_OVvv, r2ab)
    #:tmpba =-lib.einsum('MEaf,iJfE->MaiJ', eris_OVvv, tau2ab)
    #:Hr2ab+= lib.einsum('MB,MaiJ->iJaB', t1b, tmpba)
    #:tmpa-= lib.einsum('MEaf,ME->af', eris_OVvv, r1b)
    tau2aa = uccsd.make_tau_aa(r2aa, r1a, t1a, 2)
    mem_now = lib.current_memory()[0]
    max_memory = max(0, eom.max_memory - mem_now)
    # tmpa/tmpb are updated in place below.  They must be able to hold complex
    # values when the vectors or intermediates are complex; a plain float64
    # zeros array would reject such updates with a casting error.
    acc_dtype = np.result_type(Hr1a.dtype, Hr1b.dtype, np.double)
    tmpa = np.zeros((nvira,nvira), dtype=acc_dtype)
    tmpb = np.zeros((nvirb,nvirb), dtype=acc_dtype)
    # ovvv-type integrals are streamed in occupied-index blocks sized to
    # the available memory
    blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira**3*3))))
    for p0, p1 in lib.prange(0, nocca, blksize):
        ovvv = eris.get_ovvv(slice(p0,p1))  # ovvv = eris.ovvv[p0:p1]
        Hr1a += lib.einsum('mfae,imef->ia', ovvv, r2aa[:,p0:p1])
        tmpaa = lib.einsum('meaf,ijef->maij', ovvv, tau2aa)
        Hr2aa+= lib.einsum('mb,maij->ijab', t1a[p0:p1], tmpaa)
        tmpa+= lib.einsum('mfae,me->af', ovvv, r1a[p0:p1])
        tmpa-= lib.einsum('meaf,me->af', ovvv, r1a[p0:p1])
        ovvv = tmpaa = None
    tau2aa = None

    tau2bb = uccsd.make_tau_aa(r2bb, r1b, t1b, 2)
    blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb**3*3))))
    for p0, p1 in lib.prange(0, noccb, blksize):
        OVVV = eris.get_OVVV(slice(p0,p1))  # OVVV = eris.OVVV[p0:p1]
        Hr1b += lib.einsum('mfae,imef->ia', OVVV, r2bb[:,p0:p1])
        tmpbb = lib.einsum('meaf,ijef->maij', OVVV, tau2bb)
        Hr2bb+= lib.einsum('mb,maij->ijab', t1b[p0:p1], tmpbb)
        tmpb+= lib.einsum('mfae,me->af', OVVV, r1b[p0:p1])
        tmpb-= lib.einsum('meaf,me->af', OVVV, r1b[p0:p1])
        OVVV = tmpbb = None
    tau2bb = None

    tau2ab = uccsd.make_tau_ab(r2ab, r1 , t1 , 2)
    blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira*nvirb**2*3))))
    for p0, p1 in lib.prange(0, nocca, blksize):
        ovVV = eris.get_ovVV(slice(p0,p1))  # ovVV = eris.ovVV[p0:p1]
        Hr1b += lib.einsum('mfAE,mIfE->IA', ovVV, r2ab[p0:p1])
        tmpab = lib.einsum('meAF,iJeF->mAiJ', ovVV, tau2ab)
        Hr2ab-= lib.einsum('mb,mAiJ->iJbA', t1a[p0:p1], tmpab)
        tmpb-= lib.einsum('meAF,me->AF', ovVV, r1a[p0:p1])
        ovVV = tmpab = None

    blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb*nvira**2*3))))
    for p0, p1 in lib.prange(0, noccb, blksize):
        OVvv = eris.get_OVvv(slice(p0,p1))  # OVvv = eris.OVvv[p0:p1]
        Hr1a += lib.einsum('MFae,iMeF->ia', OVvv, r2ab[:,p0:p1])
        tmpba = lib.einsum('MEaf,iJfE->MaiJ', OVvv, tau2ab)
        Hr2ab-= lib.einsum('MB,MaiJ->iJaB', t1b[p0:p1], tmpba)
        tmpa-= lib.einsum('MEaf,ME->af', OVvv, r1b[p0:p1])
        OVvv = tmpba = None
    tau2ab = None

    Hr2aa-= lib.einsum('af,ijfb->ijab', tmpa, t2aa)
    Hr2bb-= lib.einsum('af,ijfb->ijab', tmpb, t2bb)
    Hr2ab-= lib.einsum('af,iJfB->iJaB', tmpa, t2ab)
    Hr2ab-= lib.einsum('AF,iJbF->iJbA', tmpb, t2ab)

    eris_ovov = np.asarray(eris.ovov)
    eris_OVOV = np.asarray(eris.OVOV)
    eris_ovOV = np.asarray(eris.ovOV)
    tau2aa = uccsd.make_tau_aa(r2aa, r1a, t1a, 2)
    tauaa = uccsd.make_tau_aa(t2aa, t1a, t1a)
    tmpaa = lib.einsum('menf,ijef->mnij', eris_ovov, tau2aa)
    Hr2aa += lib.einsum('mnij,mnab->ijab', tmpaa, tauaa) * 0.25
    tmpaa = tau2aa = tauaa = None

    tau2bb = uccsd.make_tau_aa(r2bb, r1b, t1b, 2)
    taubb = uccsd.make_tau_aa(t2bb, t1b, t1b)
    tmpbb = lib.einsum('menf,ijef->mnij', eris_OVOV, tau2bb)
    Hr2bb += lib.einsum('mnij,mnab->ijab', tmpbb, taubb) * 0.25
    tmpbb = tau2bb = taubb = None

    tau2ab = uccsd.make_tau_ab(r2ab, r1 , t1 , 2)
    tauab = uccsd.make_tau_ab(t2ab, t1 , t1)
    tmpab = lib.einsum('meNF,iJeF->mNiJ', eris_ovOV, tau2ab)
    Hr2ab += lib.einsum('mNiJ,mNaB->iJaB', tmpab, tauab)
    tmpab = tau2ab = tauab = None

    tmpa = lib.einsum('menf,imef->ni', eris_ovov, r2aa)
    tmpa-= lib.einsum('neMF,iMeF->ni', eris_ovOV, r2ab)
    tmpb = lib.einsum('menf,imef->ni', eris_OVOV, r2bb)
    tmpb-= lib.einsum('mfNE,mIfE->NI', eris_ovOV, r2ab)
    Hr1a += lib.einsum('na,ni->ia', t1a, tmpa)
    Hr1b += lib.einsum('na,ni->ia', t1b, tmpb)
    Hr2aa+= lib.einsum('mj,imab->ijab', tmpa, t2aa)
    Hr2bb+= lib.einsum('mj,imab->ijab', tmpb, t2bb)
    Hr2ab+= lib.einsum('MJ,iMaB->iJaB', tmpb, t2ab)
    Hr2ab+= lib.einsum('mj,mIaB->jIaB', tmpa, t2ab)

    tmp1a = np.einsum('menf,mf->en', eris_ovov, r1a)
    tmp1a-= np.einsum('mfne,mf->en', eris_ovov, r1a)
    tmp1a-= np.einsum('neMF,MF->en', eris_ovOV, r1b)
    tmp1b = np.einsum('menf,mf->en', eris_OVOV, r1b)
    tmp1b-= np.einsum('mfne,mf->en', eris_OVOV, r1b)
    tmp1b-= np.einsum('mfNE,mf->EN', eris_ovOV, r1a)
    tmpa = np.einsum('en,nb->eb', tmp1a, t1a)
    tmpa+= lib.einsum('menf,mnfb->eb', eris_ovov, r2aa)
    tmpa-= lib.einsum('meNF,mNbF->eb', eris_ovOV, r2ab)
    tmpb = np.einsum('en,nb->eb', tmp1b, t1b)
    tmpb+= lib.einsum('menf,mnfb->eb', eris_OVOV, r2bb)
    tmpb-= lib.einsum('nfME,nMfB->EB', eris_ovOV, r2ab)
    Hr2aa+= lib.einsum('eb,ijae->ijab', tmpa, t2aa)
    Hr2bb+= lib.einsum('eb,ijae->ijab', tmpb, t2bb)
    Hr2ab+= lib.einsum('EB,iJaE->iJaB', tmpb, t2ab)
    Hr2ab+= lib.einsum('eb,iJeA->iJbA', tmpa, t2ab)
    # none of the ovov blocks are needed below; release all three
    eris_ovov = eris_ovOV = eris_OVOV = None

    Hr2aa-= lib.einsum('mbij,ma->ijab', imds.wovoo, r1a)
    Hr2bb-= lib.einsum('mbij,ma->ijab', imds.wOVOO, r1b)
    Hr2ab-= lib.einsum('mBiJ,ma->iJaB', imds.woVoO, r1a)
    Hr2ab-= lib.einsum('MbJi,MA->iJbA', imds.wOvOo, r1b)

    Hr1a-= 0.5*lib.einsum('mine,mnae->ia', imds.wooov, r2aa)
    Hr1a-=     lib.einsum('miNE,mNaE->ia', imds.wooOV, r2ab)
    Hr1b-= 0.5*lib.einsum('mine,mnae->ia', imds.wOOOV, r2bb)
    Hr1b-=     lib.einsum('MIne,nMeA->IA', imds.wOOov, r2ab)
    tmpa = lib.einsum('mine,me->ni', imds.wooov, r1a)
    tmpa-= lib.einsum('niME,ME->ni', imds.wooOV, r1b)
    tmpb = lib.einsum('mine,me->ni', imds.wOOOV, r1b)
    tmpb-= lib.einsum('NIme,me->NI', imds.wOOov, r1a)
    Hr2aa+= lib.einsum('ni,njab->ijab', tmpa, t2aa)
    Hr2bb+= lib.einsum('ni,njab->ijab', tmpb, t2bb)
    Hr2ab+= lib.einsum('ni,nJaB->iJaB', tmpa, t2ab)
    Hr2ab+= lib.einsum('NI,jNaB->jIaB', tmpb, t2ab)
    for p0, p1 in lib.prange(0, nvira, nocca):
        Hr2aa+= lib.einsum('ejab,ie->ijab', imds.wvovv[p0:p1], r1a[:,p0:p1])
        Hr2ab+= lib.einsum('eJaB,ie->iJaB', imds.wvOvV[p0:p1], r1a[:,p0:p1])
    for p0, p1 in lib.prange(0, nvirb, noccb):
        Hr2bb+= lib.einsum('ejab,ie->ijab', imds.wVOVV[p0:p1], r1b[:,p0:p1])
        Hr2ab+= lib.einsum('EjBa,IE->jIaB', imds.wVoVv[p0:p1], r1b[:,p0:p1])

    Hr1a += np.einsum('maei,me->ia',imds.wovvo,r1a)
    Hr1a += np.einsum('MaEi,ME->ia',imds.wOvVo,r1b)
    Hr1b += np.einsum('maei,me->ia',imds.wOVVO,r1b)
    Hr1b += np.einsum('mAeI,me->IA',imds.woVvO,r1a)
    Hr2aa+= lib.einsum('mbej,imae->ijab', imds.wovvo, r2aa) * 2
    Hr2aa+= lib.einsum('MbEj,iMaE->ijab', imds.wOvVo, r2ab) * 2
    Hr2bb+= lib.einsum('mbej,imae->ijab', imds.wOVVO, r2bb) * 2
    Hr2bb+= lib.einsum('mBeJ,mIeA->IJAB', imds.woVvO, r2ab) * 2
    Hr2ab+= lib.einsum('mBeJ,imae->iJaB', imds.woVvO, r2aa)
    Hr2ab+= lib.einsum('MBEJ,iMaE->iJaB', imds.wOVVO, r2ab)
    Hr2ab+= lib.einsum('mBEj,mIaE->jIaB', imds.woVVo, r2ab)
    Hr2ab+= lib.einsum('mbej,mIeA->jIbA', imds.wovvo, r2ab)
    Hr2ab+= lib.einsum('MbEj,IMAE->jIbA', imds.wOvVo, r2bb)
    Hr2ab+= lib.einsum('MbeJ,iMeA->iJbA', imds.wOvvO, r2ab)

    # Antisymmetrize the same-spin blocks in (a,b) and then in (i,j); the
    # preceding factor .5 compensates the doubling from this step.
    Hr2aa *= .5
    Hr2bb *= .5
    Hr2aa = Hr2aa - Hr2aa.transpose(0,1,3,2)
    Hr2aa = Hr2aa - Hr2aa.transpose(1,0,2,3)
    Hr2bb = Hr2bb - Hr2bb.transpose(0,1,3,2)
    Hr2bb = Hr2bb - Hr2bb.transpose(1,0,2,3)

    vector = amplitudes_to_vector_ee((Hr1a,Hr1b), (Hr2aa,Hr2ab,Hr2bb))
    return vector
1362
1363def eomsf_ccsd_matvec(eom, vector, imds=None):
1364    '''Spin flip EOM-CCSD'''
1365    if imds is None: imds = eom.make_imds()
1366
1367    t1, t2, eris = imds.t1, imds.t2, imds.eris
1368    t1a, t1b = t1
1369    t2aa, t2ab, t2bb = t2
1370    nocca, noccb, nvira, nvirb = t2ab.shape
1371    nmoa, nmob = nocca+nvira, noccb+nvirb
1372    r1, r2 = vector_to_amplitudes_eomsf(vector, (nmoa,nmob), (nocca,noccb))
1373    r1ab, r1ba = r1
1374    r2baaa, r2aaba, r2abbb, r2bbab = r2
1375
1376    Hr1ab  = np.einsum('ae,ie->ia', imds.Fvvb, r1ab)
1377    Hr1ab -= np.einsum('mi,ma->ia', imds.Fooa, r1ab)
1378    Hr1ab += np.einsum('me,imae->ia', imds.Fovb, r2abbb)
1379    Hr1ab += np.einsum('me,imae->ia', imds.Fova, r2aaba)
1380    Hr1ba  = np.einsum('ae,ie->ia', imds.Fvva, r1ba)
1381    Hr1ba -= np.einsum('mi,ma->ia', imds.Foob, r1ba)
1382    Hr1ba += np.einsum('me,imae->ia', imds.Fova, r2baaa)
1383    Hr1ba += np.einsum('me,imae->ia', imds.Fovb, r2bbab)
1384    Hr2baaa = .5 *lib.einsum('njMI,Mnab->Ijab', imds.wooOO, r2baaa)
1385    Hr2aaba = .25*lib.einsum('minj,mnAb->ijAb', imds.woooo, r2aaba)
1386    Hr2abbb = .5 *lib.einsum('miNJ,mNAB->iJAB', imds.wooOO, r2abbb)
1387    Hr2bbab = .25*lib.einsum('MINJ,MNaB->IJaB', imds.wOOOO, r2bbab)
1388    Hr2baaa += lib.einsum('be,Ijae->Ijab', imds.Fvva   , r2baaa)
1389    Hr2baaa -= lib.einsum('mj,imab->ijab', imds.Fooa*.5, r2baaa)
1390    Hr2baaa -= lib.einsum('MJ,Miab->Jiab', imds.Foob*.5, r2baaa)
1391    Hr2bbab -= lib.einsum('mj,imab->ijab', imds.Foob   , r2bbab)
1392    Hr2bbab += lib.einsum('BE,IJaE->IJaB', imds.Fvvb*.5, r2bbab)
1393    Hr2bbab += lib.einsum('be,IJeA->IJbA', imds.Fvva*.5, r2bbab)
1394    Hr2aaba -= lib.einsum('mj,imab->ijab', imds.Fooa   , r2aaba)
1395    Hr2aaba += lib.einsum('be,ijAe->ijAb', imds.Fvva*.5, r2aaba)
1396    Hr2aaba += lib.einsum('BE,ijEa->ijBa', imds.Fvvb*.5, r2aaba)
1397    Hr2abbb += lib.einsum('BE,iJAE->iJAB', imds.Fvvb   , r2abbb)
1398    Hr2abbb -= lib.einsum('mj,imab->ijab', imds.Foob*.5, r2abbb)
1399    Hr2abbb -= lib.einsum('mj,mIAB->jIAB', imds.Fooa*.5, r2abbb)
1400
1401    tau2baaa = np.einsum('ia,jb->ijab', r1ba, t1a)
1402    tau2baaa = tau2baaa - tau2baaa.transpose(0,1,3,2)
1403    tau2abbb = np.einsum('ia,jb->ijab', r1ab, t1b)
1404    tau2abbb = tau2abbb - tau2abbb.transpose(0,1,3,2)
1405    tau2aaba = np.einsum('ia,jb->ijab', r1ab, t1a)
1406    tau2aaba = tau2aaba - tau2aaba.transpose(1,0,2,3)
1407    tau2bbab = np.einsum('ia,jb->ijab', r1ba, t1b)
1408    tau2bbab = tau2bbab - tau2bbab.transpose(1,0,2,3)
1409    tau2baaa += r2baaa
1410    tau2bbab += r2bbab
1411    tau2abbb += r2abbb
1412    tau2aaba += r2aaba
1413    #:eris_ovvv = lib.unpack_tril(np.asarray(eris.ovvv).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvira,nvira)
1414    #:Hr1ba += lib.einsum('mfae,Imef->Ia', eris_ovvv, r2baaa)
1415    #:tmp1aaba = lib.einsum('meaf,Ijef->maIj', eris_ovvv, tau2baaa)
1416    #:Hr2baaa += lib.einsum('mb,maIj->Ijab', t1a   , tmp1aaba)
1417    mem_now = lib.current_memory()[0]
1418    max_memory = max(0, eom.max_memory - mem_now)
1419    blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira**3*3))))
1420    for p0,p1 in lib.prange(0, nocca, blksize):
1421        ovvv = eris.get_ovvv(slice(p0,p1))  # ovvv = eris.ovvv[p0:p1]
1422        Hr1ba += lib.einsum('mfae,Imef->Ia', ovvv, r2baaa[:,p0:p1])
1423        tmp1aaba = lib.einsum('meaf,Ijef->maIj', ovvv, tau2baaa)
1424        Hr2baaa += lib.einsum('mb,maIj->Ijab', t1a[p0:p1], tmp1aaba)
1425        ovvv = tmp1aaba = None
1426
1427    #:eris_OVVV = lib.unpack_tril(np.asarray(eris.OVVV).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvirb,nvirb)
1428    #:Hr1ab += lib.einsum('MFAE,iMEF->iA', eris_OVVV, r2abbb)
1429    #:tmp1bbab = lib.einsum('MEAF,iJEF->MAiJ', eris_OVVV, tau2abbb)
1430    #:Hr2abbb += lib.einsum('MB,MAiJ->iJAB', t1b   , tmp1bbab)
1431    blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb**3*3))))
1432    for p0, p1 in lib.prange(0, noccb, blksize):
1433        OVVV = eris.get_OVVV(slice(p0,p1))  # OVVV = eris.OVVV[p0:p1]
1434        Hr1ab += lib.einsum('MFAE,iMEF->iA', OVVV, r2abbb[:,p0:p1])
1435        tmp1bbab = lib.einsum('MEAF,iJEF->MAiJ', OVVV, tau2abbb)
1436        Hr2abbb += lib.einsum('MB,MAiJ->iJAB', t1b[p0:p1], tmp1bbab)
1437        OVVV = tmp1bbab = None
1438
1439    #:eris_ovVV = lib.unpack_tril(np.asarray(eris.ovVV).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvirb,nvirb)
1440    #:Hr1ab += lib.einsum('mfAE,imEf->iA', eris_ovVV, r2aaba)
1441    #:tmp1abaa = lib.einsum('meAF,ijFe->mAij', eris_ovVV, tau2aaba)
1442    #:tmp1abbb = lib.einsum('meAF,IJeF->mAIJ', eris_ovVV, tau2bbab)
1443    #:tmp1ba = lib.einsum('mfAE,mE->Af', eris_ovVV, r1ab)
1444    #:Hr2bbab -= lib.einsum('mb,mAIJ->IJbA', t1a*.5, tmp1abbb)
1445    #:Hr2aaba -= lib.einsum('mb,mAij->ijAb', t1a*.5, tmp1abaa)
1446    tmp1ba = np.zeros((nvirb,nvira))
1447    blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira*nvirb**2*3))))
1448    for p0,p1 in lib.prange(0, nocca, blksize):
1449        ovVV = eris.get_ovVV(slice(p0,p1))  # ovVV = eris.ovVV[p0:p1]
1450        Hr1ab += lib.einsum('mfAE,imEf->iA', ovVV, r2aaba[:,p0:p1])
1451        tmp1abaa = lib.einsum('meAF,ijFe->mAij', ovVV, tau2aaba)
1452        tmp1abbb = lib.einsum('meAF,IJeF->mAIJ', ovVV, tau2bbab)
1453        tmp1ba += lib.einsum('mfAE,mE->Af', ovVV, r1ab[p0:p1])
1454        Hr2bbab -= lib.einsum('mb,mAIJ->IJbA', t1a[p0:p1]*.5, tmp1abbb)
1455        Hr2aaba -= lib.einsum('mb,mAij->ijAb', t1a[p0:p1]*.5, tmp1abaa)
1456
1457    #:eris_OVvv = lib.unpack_tril(np.asarray(eris.OVvv).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvira,nvira)
1458    #:Hr1ba += lib.einsum('MFae,IMeF->Ia', eris_OVvv, r2bbab)
1459    #:tmp1baaa = lib.einsum('MEaf,ijEf->Maij', eris_OVvv, tau2aaba)
1460    #:tmp1babb = lib.einsum('MEaf,IJfE->MaIJ', eris_OVvv, tau2bbab)
1461    #:tmp1ab = lib.einsum('MFae,Me->aF', eris_OVvv, r1ba)
1462    #:Hr2aaba -= lib.einsum('MB,Maij->ijBa', t1b*.5, tmp1baaa)
1463    #:Hr2bbab -= lib.einsum('MB,MaIJ->IJaB', t1b*.5, tmp1babb)
1464    tmp1ab = np.zeros((nvira,nvirb))
1465    blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb*nvira**2*3))))
1466    for p0, p1 in lib.prange(0, noccb, blksize):
1467        OVvv = eris.get_OVvv(slice(p0,p1))  # OVvv = eris.OVvv[p0:p1]
1468        Hr1ba += lib.einsum('MFae,IMeF->Ia', OVvv, r2bbab[:,p0:p1])
1469        tmp1baaa = lib.einsum('MEaf,ijEf->Maij', OVvv, tau2aaba)
1470        tmp1babb = lib.einsum('MEaf,IJfE->MaIJ', OVvv, tau2bbab)
1471        tmp1ab+= lib.einsum('MFae,Me->aF', OVvv, r1ba[p0:p1])
1472        Hr2aaba -= lib.einsum('MB,Maij->ijBa', t1b[p0:p1]*.5, tmp1baaa)
1473        Hr2bbab -= lib.einsum('MB,MaIJ->IJaB', t1b[p0:p1]*.5, tmp1babb)
1474
1475    Hr2baaa += lib.einsum('aF,jIbF->Ijba', tmp1ab   , t2ab)
1476    Hr2bbab -= lib.einsum('aF,IJFB->IJaB', tmp1ab*.5, t2bb)
1477    Hr2abbb += lib.einsum('Af,iJfB->iJBA', tmp1ba   , t2ab)
1478    Hr2aaba -= lib.einsum('Af,ijfb->ijAb', tmp1ba*.5, t2aa)
1479    Hr2baaa -= lib.einsum('MbIj,Ma->Ijab', imds.wOvOo, r1ba   )
1480    Hr2bbab -= lib.einsum('MBIJ,Ma->IJaB', imds.wOVOO, r1ba*.5)
1481    Hr2abbb -= lib.einsum('mBiJ,mA->iJAB', imds.woVoO, r1ab   )
1482    Hr2aaba -= lib.einsum('mbij,mA->ijAb', imds.wovoo, r1ab*.5)
1483
1484    Hr1ab -= 0.5*lib.einsum('mine,mnAe->iA', imds.wooov, r2aaba)
1485    Hr1ab -=     lib.einsum('miNE,mNAE->iA', imds.wooOV, r2abbb)
1486    Hr1ba -= 0.5*lib.einsum('MINE,MNaE->Ia', imds.wOOOV, r2bbab)
1487    Hr1ba -=     lib.einsum('MIne,Mnae->Ia', imds.wOOov, r2baaa)
1488    tmp1ab = lib.einsum('MIne,Me->nI', imds.wOOov, r1ba)
1489    tmp1ba = lib.einsum('miNE,mE->Ni', imds.wooOV, r1ab)
1490    Hr2baaa += lib.einsum('nI,njab->Ijab', tmp1ab*.5, t2aa)
1491    Hr2bbab += lib.einsum('nI,nJaB->IJaB', tmp1ab   , t2ab)
1492    Hr2abbb += lib.einsum('Ni,NJAB->iJAB', tmp1ba*.5, t2bb)
1493    Hr2aaba += lib.einsum('Ni,jNbA->ijAb', tmp1ba   , t2ab)
1494    for p0, p1 in lib.prange(0, nvira, nocca):
1495        Hr2baaa += lib.einsum('ejab,Ie->Ijab', imds.wvovv[p0:p1], r1ba[:,p0:p1]*.5)
1496        Hr2bbab += lib.einsum('eJaB,Ie->IJaB', imds.wvOvV[p0:p1], r1ba[:,p0:p1]   )
1497    for p0, p1 in lib.prange(0, nvirb, noccb):
1498        Hr2abbb += lib.einsum('EJAB,iE->iJAB', imds.wVOVV[p0:p1], r1ab[:,p0:p1]*.5)
1499        Hr2aaba += lib.einsum('EjAb,iE->ijAb', imds.wVoVv[p0:p1], r1ab[:,p0:p1]   )
1500
1501    Hr1ab += np.einsum('mAEi,mE->iA', imds.woVVo, r1ab)
1502    Hr1ba += np.einsum('MaeI,Me->Ia', imds.wOvvO, r1ba)
1503    Hr2baaa += lib.einsum('mbej,Imae->Ijab', imds.wovvo, r2baaa)
1504    Hr2baaa += lib.einsum('MbeJ,Miae->Jiab', imds.wOvvO, r2baaa)
1505    Hr2baaa += lib.einsum('MbEj,IMaE->Ijab', imds.wOvVo, r2bbab)
1506    Hr2bbab += lib.einsum('MBEJ,IMaE->IJaB', imds.wOVVO, r2bbab)
1507    Hr2bbab += lib.einsum('MbeJ,IMeA->IJbA', imds.wOvvO, r2bbab)
1508    Hr2bbab += lib.einsum('mBeJ,Imae->IJaB', imds.woVvO, r2baaa)
1509    Hr2aaba += lib.einsum('mbej,imAe->ijAb', imds.wovvo, r2aaba)
1510    Hr2aaba += lib.einsum('mBEj,imEa->ijBa', imds.woVVo, r2aaba)
1511    Hr2aaba += lib.einsum('MbEj,iMAE->ijAb', imds.wOvVo, r2abbb)
1512    Hr2abbb += lib.einsum('MBEJ,iMAE->iJAB', imds.wOVVO, r2abbb)
1513    Hr2abbb += lib.einsum('mBEj,mIAE->jIAB', imds.woVVo, r2abbb)
1514    Hr2abbb += lib.einsum('mBeJ,imAe->iJAB', imds.woVvO, r2aaba)
1515
1516    eris_ovov = np.asarray(eris.ovov)
1517    eris_OVOV = np.asarray(eris.OVOV)
1518    eris_ovOV = np.asarray(eris.ovOV)
1519    tauaa, tauab, taubb = uccsd.make_tau(t2, t1, t1)
1520    tmp1baaa = lib.einsum('nfME,ijEf->Mnij', eris_ovOV, tau2aaba)
1521    tmp1aaba = lib.einsum('menf,Ijef->mnIj', eris_ovov, tau2baaa)
1522    tmp1abbb = lib.einsum('meNF,IJeF->mNIJ', eris_ovOV, tau2bbab)
1523    tmp1bbab = lib.einsum('MENF,iJEF->MNiJ', eris_OVOV, tau2abbb)
1524    Hr2baaa += 0.5*.5*lib.einsum('mnIj,mnab->Ijab', tmp1aaba, tauaa)
1525    Hr2bbab +=     .5*lib.einsum('nMIJ,nMaB->IJaB', tmp1abbb, tauab)
1526    Hr2aaba +=     .5*lib.einsum('Nmij,mNbA->ijAb', tmp1baaa, tauab)
1527    Hr2abbb += 0.5*.5*lib.einsum('MNiJ,MNAB->iJAB', tmp1bbab, taubb)
1528    tauaa = tauab = taubb = None
1529
1530    tmpab  = lib.einsum('menf,Imef->nI', eris_ovov, r2baaa)
1531    tmpab -= lib.einsum('nfME,IMfE->nI', eris_ovOV, r2bbab)
1532    tmpba  = lib.einsum('MENF,iMEF->Ni', eris_OVOV, r2abbb)
1533    tmpba -= lib.einsum('meNF,imFe->Ni', eris_ovOV, r2aaba)
1534    Hr1ab += np.einsum('NA,Ni->iA', t1b, tmpba)
1535    Hr1ba += np.einsum('na,nI->Ia', t1a, tmpab)
1536    Hr2baaa -= lib.einsum('mJ,imab->Jiab', tmpab*.5, t2aa)
1537    Hr2bbab -= lib.einsum('mJ,mIaB->IJaB', tmpab*.5, t2ab) * 2
1538    Hr2aaba -= lib.einsum('Mj,iMbA->ijAb', tmpba*.5, t2ab) * 2
1539    Hr2abbb -= lib.einsum('Mj,IMAB->jIAB', tmpba*.5, t2bb)
1540
1541    tmp1ab = np.einsum('meNF,mF->eN', eris_ovOV, r1ab)
1542    tmp1ba = np.einsum('nfME,Mf->En', eris_ovOV, r1ba)
1543    tmpab = np.einsum('eN,NB->eB', tmp1ab, t1b)
1544    tmpba = np.einsum('En,nb->Eb', tmp1ba, t1a)
1545    tmpab -= lib.einsum('menf,mnBf->eB', eris_ovov, r2aaba)
1546    tmpab += lib.einsum('meNF,mNFB->eB', eris_ovOV, r2abbb)
1547    tmpba -= lib.einsum('MENF,MNbF->Eb', eris_OVOV, r2bbab)
1548    tmpba += lib.einsum('nfME,Mnfb->Eb', eris_ovOV, r2baaa)
1549    Hr2baaa -= lib.einsum('Eb,jIaE->Ijab', tmpba*.5, t2ab) * 2
1550    Hr2bbab -= lib.einsum('Eb,IJAE->IJbA', tmpba*.5, t2bb)
1551    Hr2aaba -= lib.einsum('eB,ijae->ijBa', tmpab*.5, t2aa)
1552    Hr2abbb -= lib.einsum('eB,iJeA->iJAB', tmpab*.5, t2ab) * 2
1553    eris_ovov = eris_OVOV = eris_ovOV = None
1554
1555    #:Hr2baaa += .5*lib.einsum('Ijef,aebf->Ijab', tau2baaa, eris.vvvv)
1556    #:Hr2abbb += .5*lib.einsum('iJEF,AEBF->iJAB', tau2abbb, eris.VVVV)
1557    #:Hr2bbab += .5*lib.einsum('IJeF,aeBF->IJaB', tau2bbab, eris.vvVV)
1558    #:Hr2aaba += .5*lib.einsum('ijEf,bfAE->ijAb', tau2aaba, eris.vvVV)
1559    fakeri = uccsd._ChemistsERIs()
1560    fakeri.mol = eris.mol
1561
1562    if eom._cc.direct:
1563        orbva = eris.mo_coeff[0][:,nocca:]
1564        orbvb = eris.mo_coeff[1][:,noccb:]
1565        tau2baaa = lib.einsum('ijab,pa,qb->ijpq', tau2baaa, .5*orbva, orbva)
1566        tmp = eris._contract_vvvv_t2(eom._cc, tau2baaa, True)
1567        Hr2baaa += lib.einsum('ijpq,pa,qb->ijab', tmp, orbva.conj(), orbva.conj())
1568        tmp = None
1569
1570        tau2abbb = lib.einsum('ijab,pa,qb->ijpq', tau2abbb, .5*orbvb, orbvb)
1571        tmp = eris._contract_VVVV_t2(eom._cc, tau2abbb, True)
1572        Hr2abbb += lib.einsum('ijpq,pa,qb->ijab', tmp, orbvb.conj(), orbvb.conj())
1573        tmp = None
1574    else:
1575        tau2baaa *= .5
1576        Hr2baaa += eris._contract_vvvv_t2(eom._cc, tau2baaa, False)
1577        tau2abbb *= .5
1578        Hr2abbb += eris._contract_VVVV_t2(eom._cc, tau2abbb, False)
1579
1580    tau2bbab *= .5
1581    Hr2bbab += eom._cc._add_vvVV(None, tau2bbab, eris)
1582    tau2aaba = tau2aaba.transpose(0,1,3,2)*.5
1583    Hr2aaba += eom._cc._add_vvVV(None, tau2aaba, eris).transpose(0,1,3,2)
1584
1585    Hr2baaa = Hr2baaa - Hr2baaa.transpose(0,1,3,2)
1586    Hr2bbab = Hr2bbab - Hr2bbab.transpose(1,0,2,3)
1587    Hr2abbb = Hr2abbb - Hr2abbb.transpose(0,1,3,2)
1588    Hr2aaba = Hr2aaba - Hr2aaba.transpose(1,0,2,3)
1589    vector = amplitudes_to_vector_eomsf((Hr1ab, Hr1ba), (Hr2baaa,Hr2aaba,Hr2abbb,Hr2bbab))
1590    return vector
1591
def eeccsd_diag(eom, imds=None):
    '''Diagonal of the EOM-EE-CCSD effective Hamiltonian for a UCCSD reference.

    Both the spin-conserving and the spin-flip blocks are produced.

    Args:
        eom : EOM object
            Supplies ``make_imds`` (used when ``imds`` is None) and
            ``max_memory`` (bounds the block size used when reading the
            ovvv-type integrals).
        imds : _IMDS or None
            Intermediates with the EE quantities (Fooa, Fvva, wovvo, woooo,
            ...) already built, e.g. by ``eom.make_imds()``.

    Returns:
        (vec_ee, vec_sf)
            ``vec_ee`` is the spin-conserving diagonal packed with
            ``amplitudes_to_vector_ee``; ``vec_sf`` is the spin-flip diagonal
            packed with ``amplitudes_to_vector_eomsf``.

    Note: the vvvv-type contributions are only added when ``eris.vvvv`` is
    not None.
    '''
    if imds is None: imds = eom.make_imds()
    eris = imds.eris
    t1, t2 = imds.t1, imds.t2
    t1a, t1b = t1
    t2aa, t2ab, t2bb = t2
    tauaa, tauab, taubb = uccsd.make_tau(t2, t1, t1)
    nocca, noccb, nvira, nvirb = t2ab.shape

    # Diagonals of the occupied-occupied and virtual-virtual Fock intermediates
    Foa = imds.Fooa.diagonal()
    Fob = imds.Foob.diagonal()
    Fva = imds.Fvva.diagonal()
    Fvb = imds.Fvvb.diagonal()
    # W[i,a,a,i] elements of the four ovvo-type intermediates
    Wovaa = np.einsum('iaai->ia', imds.wovvo)
    Wovbb = np.einsum('iaai->ia', imds.wOVVO)
    Wovab = np.einsum('iaai->ia', imds.woVVo)
    Wovba = np.einsum('iaai->ia', imds.wOvvO)

    # Singles: -F_ii + F_aa + W_iaai for the spin-conserving (aa, bb) and
    # spin-flip (ab, ba) blocks
    Hr1aa = lib.direct_sum('-i+a->ia', Foa, Fva)
    Hr1bb = lib.direct_sum('-i+a->ia', Fob, Fvb)
    Hr1ab = lib.direct_sum('-i+a->ia', Foa, Fvb)
    Hr1ba = lib.direct_sum('-i+a->ia', Fob, Fva)
    Hr1aa += Wovaa
    Hr1bb += Wovbb
    Hr1ab += Wovab
    Hr1ba += Wovba

    # (ov|ov) integrals contracted with tau (-> Wvv*) and with t2 (-> partial
    # diagonal terms keyed by the surviving indices).  Same-spin integrals are
    # exchange-subtracted first.
    eris_ovov = np.asarray(eris.ovov)
    eris_OVOV = np.asarray(eris.OVOV)
    eris_ovOV = np.asarray(eris.ovOV)
    ovov = eris_ovov - eris_ovov.transpose(0,3,2,1)
    OVOV = eris_OVOV - eris_OVOV.transpose(0,3,2,1)
    Wvvaa = .5*np.einsum('mnab,manb->ab', tauaa, eris_ovov)
    Wvvbb = .5*np.einsum('mnab,manb->ab', taubb, eris_OVOV)
    Wvvab =    np.einsum('mNaB,maNB->aB', tauab, eris_ovOV)
    ijb = np.einsum('iejb,ijbe->ijb',      ovov, t2aa)
    IJB = np.einsum('iejb,ijbe->ijb',      OVOV, t2bb)
    iJB =-np.einsum('ieJB,iJeB->iJB', eris_ovOV, t2ab)
    Ijb =-np.einsum('jbIE,jIbE->Ijb', eris_ovOV, t2ab)
    iJb =-np.einsum('ibJE,iJbE->iJb', eris_ovOV, t2ab)
    jab = np.einsum('kajb,jkab->jab',      ovov, t2aa)
    JAB = np.einsum('kajb,jkab->jab',      OVOV, t2bb)
    jAb =-np.einsum('jbKA,jKbA->jAb', eris_ovOV, t2ab)
    JaB =-np.einsum('kaJB,kJaB->JaB', eris_ovOV, t2ab)
    jaB =-np.einsum('jaKB,jKaB->jaB', eris_ovOV, t2ab)
    eris_ovov = eris_ovOV = eris_OVOV = ovov = OVOV = None
    # Spin-conserving doubles: Fock diagonals plus the t2-integral terms above.
    # The same-spin blocks are symmetrized over (a,b) and (i,j), then halved.
    Hr2aa = lib.direct_sum('ijb+a->ijba', ijb, Fva)
    Hr2bb = lib.direct_sum('ijb+a->ijba', IJB, Fvb)
    Hr2ab = lib.direct_sum('iJb+A->iJbA', iJb, Fvb)
    Hr2ab+= lib.direct_sum('iJB+a->iJaB', iJB, Fva)
    Hr2aa+= lib.direct_sum('-i+jab->ijab', Foa, jab)
    Hr2bb+= lib.direct_sum('-i+jab->ijab', Fob, JAB)
    Hr2ab+= lib.direct_sum('-i+JaB->iJaB', Foa, JaB)
    Hr2ab+= lib.direct_sum('-I+jaB->jIaB', Fob, jaB)
    Hr2aa = Hr2aa + Hr2aa.transpose(0,1,3,2)
    Hr2aa = Hr2aa + Hr2aa.transpose(1,0,2,3)
    Hr2bb = Hr2bb + Hr2bb.transpose(0,1,3,2)
    Hr2bb = Hr2bb + Hr2bb.transpose(1,0,2,3)
    Hr2aa *= .5
    Hr2bb *= .5
    # Spin-flip doubles (baaa, aaba, abbb, bbab): same construction, with the
    # Fock diagonal of the index not covered by direct_sum added separately
    Hr2baaa = lib.direct_sum('Ijb+a->Ijba', Ijb, Fva)
    Hr2aaba = lib.direct_sum('ijb+A->ijAb', ijb, Fvb)
    Hr2aaba+= Fva.reshape(1,1,1,-1)
    Hr2abbb = lib.direct_sum('iJB+A->iJBA', iJB, Fvb)
    Hr2bbab = lib.direct_sum('IJB+a->IJaB', IJB, Fva)
    Hr2bbab+= Fvb.reshape(1,1,1,-1)
    Hr2baaa = Hr2baaa + Hr2baaa.transpose(0,1,3,2)
    Hr2abbb = Hr2abbb + Hr2abbb.transpose(0,1,3,2)
    Hr2baaa+= lib.direct_sum('-I+jab->Ijab', Fob, jab)
    Hr2baaa-= Foa.reshape(1,-1,1,1)
    tmpaaba = lib.direct_sum('-i+jAb->ijAb', Foa, jAb)
    Hr2abbb+= lib.direct_sum('-i+JAB->iJAB', Foa, JAB)
    Hr2abbb-= Fob.reshape(1,-1,1,1)
    tmpbbab = lib.direct_sum('-I+JaB->IJaB', Fob, JaB)
    Hr2aaba+= tmpaaba + tmpaaba.transpose(1,0,2,3)
    Hr2bbab+= tmpbbab + tmpbbab.transpose(1,0,2,3)
    tmpaaba = tmpbbab = None
    # W_iaai contributions for every (occupied, virtual) index pair of each
    # doubles block, broadcast with the spin case matching that pair
    Hr2aa += Wovaa.reshape(1,nocca,1,nvira)
    Hr2aa += Wovaa.reshape(nocca,1,1,nvira)
    Hr2aa += Wovaa.reshape(nocca,1,nvira,1)
    Hr2aa += Wovaa.reshape(1,nocca,nvira,1)
    Hr2ab += Wovbb.reshape(1,noccb,1,nvirb)
    Hr2ab += Wovab.reshape(nocca,1,1,nvirb)
    Hr2ab += Wovaa.reshape(nocca,1,nvira,1)
    Hr2ab += Wovba.reshape(1,noccb,nvira,1)
    Hr2bb += Wovbb.reshape(1,noccb,1,nvirb)
    Hr2bb += Wovbb.reshape(noccb,1,1,nvirb)
    Hr2bb += Wovbb.reshape(noccb,1,nvirb,1)
    Hr2bb += Wovbb.reshape(1,noccb,nvirb,1)
    Hr2baaa += Wovaa.reshape(1,nocca,1,nvira)
    Hr2baaa += Wovba.reshape(noccb,1,1,nvira)
    Hr2baaa += Wovba.reshape(noccb,1,nvira,1)
    Hr2baaa += Wovaa.reshape(1,nocca,nvira,1)
    Hr2aaba += Wovaa.reshape(1,nocca,1,nvira)
    Hr2aaba += Wovaa.reshape(nocca,1,1,nvira)
    Hr2aaba += Wovab.reshape(nocca,1,nvirb,1)
    Hr2aaba += Wovab.reshape(1,nocca,nvirb,1)
    Hr2abbb += Wovbb.reshape(1,noccb,1,nvirb)
    Hr2abbb += Wovab.reshape(nocca,1,1,nvirb)
    Hr2abbb += Wovab.reshape(nocca,1,nvirb,1)
    Hr2abbb += Wovbb.reshape(1,noccb,nvirb,1)
    Hr2bbab += Wovbb.reshape(1,noccb,1,nvirb)
    Hr2bbab += Wovbb.reshape(noccb,1,1,nvirb)
    Hr2bbab += Wovba.reshape(noccb,1,nvira,1)
    Hr2bbab += Wovba.reshape(1,noccb,nvira,1)

    # Occupied-pair diagonal of the oooo intermediates.  Same-spin terms are
    # exchange-subtracted and halved.
    Wooaa  = np.einsum('iijj->ij', imds.woooo).copy()
    Wooaa -= np.einsum('ijji->ij', imds.woooo)
    Woobb  = np.einsum('iijj->ij', imds.wOOOO).copy()
    Woobb -= np.einsum('ijji->ij', imds.wOOOO)
    Wooab = np.einsum('iijj->ij', imds.wooOO)
    Wooba = Wooab.T
    Wooaa *= .5
    Woobb *= .5
    Hr2aa += Wooaa.reshape(nocca,nocca,1,1)
    Hr2ab += Wooab.reshape(nocca,noccb,1,1)
    Hr2bb += Woobb.reshape(noccb,noccb,1,1)
    Hr2baaa += Wooba.reshape(noccb,nocca,1,1)
    Hr2aaba += Wooaa.reshape(nocca,nocca,1,1)
    Hr2abbb += Wooab.reshape(nocca,noccb,1,1)
    Hr2bbab += Woobb.reshape(noccb,noccb,1,1)

    # t1 * ovvv-type contributions to the virtual-pair diagonal.  The
    # integrals are read in blocks of occupied indices whose size is bounded
    # by the memory left under eom.max_memory.
    #:eris_ovvv = lib.unpack_tril(np.asarray(eris.ovvv).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvira,nvira)
    #:Wvvaa += np.einsum('mb,maab->ab', t1a, eris_ovvv)
    #:Wvvaa -= np.einsum('mb,mbaa->ab', t1a, eris_ovvv)
    mem_now = lib.current_memory()[0]
    max_memory = max(0, eom.max_memory - mem_now)
    blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira**3*3))))
    for p0,p1 in lib.prange(0, nocca, blksize):
        ovvv = eris.get_ovvv(slice(p0,p1))  # ovvv = eris.ovvv[p0:p1]
        Wvvaa += np.einsum('mb,maab->ab', t1a[p0:p1], ovvv)
        Wvvaa -= np.einsum('mb,mbaa->ab', t1a[p0:p1], ovvv)
        ovvv = None
    #:eris_OVVV = lib.unpack_tril(np.asarray(eris.OVVV).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvirb,nvirb)
    #:Wvvbb += np.einsum('mb,maab->ab', t1b, eris_OVVV)
    #:Wvvbb -= np.einsum('mb,mbaa->ab', t1b, eris_OVVV)
    blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb**3*3))))
    for p0, p1 in lib.prange(0, noccb, blksize):
        OVVV = eris.get_OVVV(slice(p0,p1))  # OVVV = eris.OVVV[p0:p1]
        Wvvbb += np.einsum('mb,maab->ab', t1b[p0:p1], OVVV)
        Wvvbb -= np.einsum('mb,mbaa->ab', t1b[p0:p1], OVVV)
        OVVV = None
    #:eris_ovVV = lib.unpack_tril(np.asarray(eris.ovVV).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvirb,nvirb)
    #:Wvvab -= np.einsum('mb,mbaa->ba', t1a, eris_ovVV)
    blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira*nvirb**2*3))))
    for p0,p1 in lib.prange(0, nocca, blksize):
        ovVV = eris.get_ovVV(slice(p0,p1))  # ovVV = eris.ovVV[p0:p1]
        Wvvab -= np.einsum('mb,mbaa->ba', t1a[p0:p1], ovVV)
        ovVV = None
    blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb*nvira**2*3))))
    #:eris_OVvv = lib.unpack_tril(np.asarray(eris.OVvv).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvira,nvira)
    #:Wvvab -= np.einsum('mb,mbaa->ab', t1b, eris_OVvv)
    #idxa = np.arange(nvira)
    #idxa = idxa*(idxa+1)//2+idxa
    #for p0, p1 in lib.prange(0, noccb, blksize):
    #    OVvv = np.asarray(eris.OVvv[p0:p1])
    #    Wvvab -= np.einsum('mb,mba->ab', t1b[p0:p1], OVvv[:,:,idxa])
    #    OVvv = None
    for p0, p1 in lib.prange(0, noccb, blksize):
        OVvv = eris.get_OVvv(slice(p0,p1))  # OVvv = eris.OVvv[p0:p1]
        Wvvab -= np.einsum('mb,mbaa->ab', t1b[p0:p1], OVvv)
        OVvv = None
    # The same-spin t1*ovvv terms above were accumulated for one index order
    # only; symmetrize over (a,b)
    Wvvaa = Wvvaa + Wvvaa.T
    Wvvbb = Wvvbb + Wvvbb.T
    #:eris_vvvv = ao2mo.restore(1, np.asarray(eris.vvvv), nvira)
    #:eris_VVVV = ao2mo.restore(1, np.asarray(eris.VVVV), nvirb)
    #:eris_vvVV = _restore(np.asarray(eris.vvVV), nvira, nvirb)
    #:Wvvaa += np.einsum('aabb->ab', eris_vvvv) - np.einsum('abba->ab', eris_vvvv)
    #:Wvvbb += np.einsum('aabb->ab', eris_VVVV) - np.einsum('abba->ab', eris_VVVV)
    #:Wvvab += np.einsum('aabb->ab', eris_vvVV)
    # vvvv-type contributions.  The integrals are stored in lower-triangular
    # packed form; row i0:i0+i+1 holds all pairs (i, j<=i), so each iteration
    # unpacks only the rows needed for virtual index i.  Skipped entirely when
    # eris.vvvv is None.
    if eris.vvvv is not None:
        for i in range(nvira):
            i0 = i*(i+1)//2
            vvv = lib.unpack_tril(np.asarray(eris.vvvv[i0:i0+i+1]))
            tmp = np.einsum('bb->b', vvv[i])
            Wvvaa[i] += tmp
            tmp = np.einsum('bb->b', vvv[:,:i+1,i])
            Wvvaa[i,:i+1] -= tmp
            Wvvaa[:i  ,i] -= tmp[:i]
            vvv = lib.unpack_tril(np.asarray(eris.vvVV[i0:i0+i+1]))
            Wvvab[i] += np.einsum('bb->b', vvv[i])
            vvv = None
        for i in range(nvirb):
            i0 = i*(i+1)//2
            vvv = lib.unpack_tril(np.asarray(eris.VVVV[i0:i0+i+1]))
            tmp = np.einsum('bb->b', vvv[i])
            Wvvbb[i] += tmp
            tmp = np.einsum('bb->b', vvv[:,:i+1,i])
            Wvvbb[i,:i+1] -= tmp
            Wvvbb[:i  ,i] -= tmp[:i]
            vvv = None
    # Add the virtual-pair diagonal to every doubles block
    Wvvba = Wvvab.T
    Hr2aa += Wvvaa.reshape(1,1,nvira,nvira)
    Hr2ab += Wvvab.reshape(1,1,nvira,nvirb)
    Hr2bb += Wvvbb.reshape(1,1,nvirb,nvirb)
    Hr2baaa += Wvvaa.reshape(1,1,nvira,nvira)
    Hr2aaba += Wvvba.reshape(1,1,nvirb,nvira)
    Hr2abbb += Wvvbb.reshape(1,1,nvirb,nvirb)
    Hr2bbab += Wvvab.reshape(1,1,nvira,nvirb)

    vec_ee = amplitudes_to_vector_ee((Hr1aa,Hr1bb), (Hr2aa,Hr2ab,Hr2bb))
    vec_sf = amplitudes_to_vector_eomsf((Hr1ab,Hr1ba), (Hr2baaa,Hr2aaba,Hr2abbb,Hr2bbab))
    return vec_ee, vec_sf
1795
class EOMEE(eom_rccsd.EOMEE):
    '''EOM-EE-CCSD driver for a UCCSD reference.

    ``nocc`` and ``nmo`` are taken from the underlying CC object and hold
    separate alpha/beta counts.
    '''
    def __init__(self, cc):
        eom_rccsd.EOMEE.__init__(self, cc)
        # (alpha, beta) occupied and orbital counts
        self.nocc, self.nmo = cc.get_nocc(), cc.get_nmo()

    kernel = eeccsd
    eeccsd = eeccsd
    get_diag = eeccsd_diag

    def vector_size(self):
        '''size of the vector based on spin-orbital basis'''
        n_occ = np.sum(self.nocc)
        n_vir = np.sum(self.nmo) - n_occ
        # Both n*(n-1) products are even, so the pair counts are exact.
        n_singles = n_occ * n_vir
        n_doubles = (n_occ*(n_occ-1)//2) * (n_vir*(n_vir-1)//2)
        return n_singles + n_doubles

    def make_imds(self, eris=None):
        '''Create an _IMDS object for this CC and build its EE intermediates.'''
        intermediates = _IMDS(self._cc, eris=eris)
        intermediates.make_ee()
        return intermediates
1816
class EOMEESpinKeep(EOMEE):
    '''Spin-conserving EOM-EE-CCSD for a UCCSD reference.

    Vectors are packed with ``amplitudes_to_vector_ee`` and unpacked with
    ``vector_to_amplitudes_ee``.  The diagonal used by ``gen_matvec`` is the
    first (spin-conserving) element returned by ``get_diag``.
    '''
    kernel = eomee_ccsd
    eomee_ccsd = eomee_ccsd
    matvec = eomee_ccsd_matvec
    get_diag = eeccsd_diag

    def get_init_guess(self, nroots=1, koopmans=True, diag=None):
        '''Return up to ``nroots`` unit trial vectors.

        Args:
            nroots : int
                Number of guess vectors requested; capped at vector_size().
            koopmans : bool
                If True, only the singles elements (R1_alpha and R1_beta) are
                candidates, ordered by their diagonal value.  Otherwise every
                element of ``diag`` is a candidate.
            diag : 1D ndarray or None
                Spin-conserving diagonal in the compressed vector layout.
                When None it is computed with ``self.get_diag()``.

        Returns:
            list of 1D ndarrays of length vector_size(), each with a single
            element set to 1.
        '''
        if diag is None:
            # eeccsd_diag returns (spin-conserving, spin-flip) diagonals.
            diag = self.get_diag()[0]

        if koopmans:
            nocca, noccb = self.nocc
            nmoa, nmob = self.nmo
            nvira, nvirb = nmoa-nocca, nmob-noccb
            # amplitudes are compressed by the function amplitudes_to_vector_ee.
            # sizea is the offset in the compressed vector that points to the
            # amplitudes R1_beta.  The addresses of R1_alpha and R1_beta are
            # not contiguous in the compressed vector.
            sizea = nocca * nvira + nocca*(nocca-1)//2*nvira*(nvira-1)//2
            diag = np.append(diag[:nocca*nvira], diag[sizea:sizea+noccb*nvirb])
            addr = np.append(np.arange(nocca*nvira),
                             np.arange(sizea,sizea+noccb*nvirb))
            idx = addr[diag.argsort()]
        else:
            idx = diag.argsort()

        size = self.vector_size()
        dtype = getattr(diag, 'dtype', np.double)
        nroots = min(nroots, size)
        guess = []
        for i in idx[:nroots]:
            g = np.zeros(size, dtype)
            g[i] = 1.0
            guess.append(g)
        return guess

    def gen_matvec(self, imds=None, diag=None, **kwargs):
        '''Return (matvec, diag); matvec maps a list of vectors to a list.'''
        if imds is None: imds = self.make_imds()
        if diag is None: diag = self.get_diag(imds)[0]
        matvec = lambda xs: [self.matvec(x, imds) for x in xs]
        return matvec, diag

    def vector_to_amplitudes(self, vector, nmo=None, nocc=None):
        '''Unpack a compressed vector with vector_to_amplitudes_ee.'''
        if nmo is None: nmo = self.nmo
        if nocc is None: nocc = self.nocc
        return vector_to_amplitudes_ee(vector, nmo, nocc)

    def amplitudes_to_vector(self, r1, r2):
        '''Pack (r1, r2) with amplitudes_to_vector_ee.'''
        return amplitudes_to_vector_ee(r1, r2)

    def vector_size(self):
        '''size of the vector based on spin-orbital basis'''
        nocca, noccb = self.nocc
        nmoa, nmob = self.nmo
        nvira, nvirb = nmoa-nocca, nmob-noccb
        sizea = nocca * nvira + nocca*(nocca-1)//2*nvira*(nvira-1)//2
        sizeb = noccb * nvirb + noccb*(noccb-1)//2*nvirb*(nvirb-1)//2
        sizeab = nocca * noccb * nvira * nvirb
        return sizea+sizeb+sizeab
1873
class EOMEESpinFlip(EOMEE):
    '''Spin-flip EOM-EE-CCSD for a UCCSD reference.

    Vectors are packed with ``amplitudes_to_vector_eomsf`` and unpacked with
    ``vector_to_amplitudes_eomsf``.  The diagonal used by ``gen_matvec`` is
    the second (spin-flip) element returned by ``get_diag``.
    '''
    kernel = eomsf_ccsd
    eomsf_ccsd = eomsf_ccsd
    matvec = eomsf_ccsd_matvec

    def get_init_guess(self, nroots=1, koopmans=True, diag=None):
        '''Return up to ``nroots`` unit trial vectors.

        Args:
            nroots : int
                Number of guess vectors requested; capped at vector_size().
            koopmans : bool
                If True, only the leading singles elements (the first
                nocca*nvirb + noccb*nvira entries) are candidates, ordered by
                their diagonal value.  Otherwise every element of ``diag`` is
                a candidate.
            diag : 1D ndarray or None
                Spin-flip diagonal in the compressed vector layout.  When
                None it is computed with ``self.get_diag()``.

        Returns:
            list of 1D ndarrays of length vector_size(), each with a single
            element set to 1.
        '''
        if diag is None:
            # eeccsd_diag returns (spin-conserving, spin-flip) diagonals.
            diag = self.get_diag()[1]

        if koopmans:
            nocca, noccb = self.nocc
            nmoa, nmob = self.nmo
            nvira, nvirb = nmoa-nocca, nmob-noccb
            idx = diag[:nocca*nvirb+noccb*nvira].argsort()
        else:
            idx = diag.argsort()

        size = self.vector_size()
        dtype = getattr(diag, 'dtype', np.double)
        nroots = min(nroots, size)
        guess = []
        for i in idx[:nroots]:
            g = np.zeros(size, dtype)
            g[i] = 1.0
            guess.append(g)
        return guess

    def gen_matvec(self, imds=None, diag=None, **kwargs):
        '''Return (matvec, diag); matvec maps a list of vectors to a list.'''
        if imds is None: imds = self.make_imds()
        if diag is None: diag = self.get_diag(imds)[1]
        matvec = lambda xs: [self.matvec(x, imds) for x in xs]
        return matvec, diag

    def vector_to_amplitudes(self, vector, nmo=None, nocc=None):
        '''Unpack a compressed vector with vector_to_amplitudes_eomsf.'''
        if nmo is None: nmo = self.nmo
        if nocc is None: nocc = self.nocc
        return vector_to_amplitudes_eomsf(vector, nmo, nocc)

    def amplitudes_to_vector(self, r1, r2):
        '''Pack (r1, r2) with amplitudes_to_vector_eomsf.'''
        return amplitudes_to_vector_eomsf(r1, r2)

    def vector_size(self):
        '''size of the vector based on spin-orbital basis'''
        nocca, noccb = self.nocc
        nmoa, nmob = self.nmo
        nvira, nvirb = nmoa-nocca, nmob-noccb

        nbaaa = noccb*nocca*nvira*(nvira-1)//2
        naaba = nocca*(nocca-1)//2*nvirb*nvira
        nabbb = nocca*noccb*nvirb*(nvirb-1)//2
        nbbab = noccb*(noccb-1)//2*nvira*nvirb
        return nocca*nvirb + noccb*nvira + nbaaa + naaba + nabbb + nbbab
1923
# Attach the EOM classes of this module to uccsd.UCCSD through
# lib.class_as_method, so they are reachable as attributes of UCCSD objects.
uccsd.UCCSD.EOMIP         = lib.class_as_method(EOMIP)
uccsd.UCCSD.EOMEA         = lib.class_as_method(EOMEA)
uccsd.UCCSD.EOMEE         = lib.class_as_method(EOMEE)
uccsd.UCCSD.EOMEESpinKeep = lib.class_as_method(EOMEESpinKeep)
uccsd.UCCSD.EOMEESpinFlip = lib.class_as_method(EOMEESpinFlip)
1929
1930
1931class _IMDS:
1932    # Exactly the same as RCCSD IMDS except
1933    # -- rintermediates --> uintermediates
1934    # -- Loo, Lvv, cc_Fov --> Foo, Fvv, Fov
1935    # -- One less 2-virtual intermediate
1936    def __init__(self, cc, eris=None):
1937        self.verbose = cc.verbose
1938        self.stdout = cc.stdout
1939        self.t1 = cc.t1
1940        self.t2 = cc.t2
1941        if eris is None:
1942            eris = cc.ao2mo()
1943        self.eris = eris
1944        self._made_shared = False
1945        self.made_ip_imds = False
1946        self.made_ea_imds = False
1947        self.made_ee_imds = False
1948
1949    def _make_shared(self):
1950        cput0 = (logger.process_clock(), logger.perf_counter())
1951
1952        t1, t2, eris = self.t1, self.t2, self.eris
1953        self.Foo, self.FOO = uintermediates.Foo(t1, t2, eris)
1954        self.Fvv, self.FVV = uintermediates.Fvv(t1, t2, eris)
1955        self.Fov, self.FOV = uintermediates.Fov(t1, t2, eris)
1956
1957        # 2 virtuals
1958        self.Wovvo, self.WovVO, self.WOVvo, self.WOVVO, self.WoVVo, self.WOvvO = \
1959                uintermediates.Wovvo(t1, t2, eris)
1960        Wovov = np.asarray(eris.ovov)
1961        WOVOV = np.asarray(eris.OVOV)
1962        Wovov = Wovov - Wovov.transpose(0,3,2,1)
1963        WOVOV = WOVOV - WOVOV.transpose(0,3,2,1)
1964        self.Wovov = Wovov
1965        self.WovOV = eris.ovOV
1966        self.WOVov = None
1967        self.WOVOV = WOVOV
1968
1969        self._made_shared = True
1970        logger.timer_debug1(self, 'EOM-CCSD shared intermediates', *cput0)
1971        return self
1972
1973    def make_ip(self):
1974        if not self._made_shared:
1975            self._make_shared()
1976
1977        cput0 = (logger.process_clock(), logger.perf_counter())
1978
1979        t1, t2, eris = self.t1, self.t2, self.eris
1980
1981        # 0 or 1 virtuals
1982        self.Woooo, self.WooOO, _         , self.WOOOO = uintermediates.Woooo(t1, t2, eris)
1983        self.Wooov, self.WooOV, self.WOOov, self.WOOOV = uintermediates.Wooov(t1, t2, eris)
1984        self.Woovo, self.WooVO, self.WOOvo, self.WOOVO = uintermediates.Woovo(t1, t2, eris)
1985
1986        self.made_ip_imds = True
1987        logger.timer_debug1(self, 'EOM-UCCSD IP intermediates', *cput0)
1988        return self
1989
    def make_ea(self):
        '''Build the intermediates with three or four virtual indices (EOM-EA).

        Shared intermediates are built first if they are missing.  ``Wvvvv``
        itself is not stored (set to None).  Per the comment below, the bare
        eris.vvvv contraction is left to eaccsd_matvec.  The t1- and
        tau-dependent pieces of the Wvvvv contribution are folded into the
        Wvvvo / WvvVO / WVVvo / WVVVO intermediates here.  Sets
        ``made_ea_imds`` and returns ``self``.
        '''
        if not self._made_shared:
            self._make_shared()

        cput0 = (logger.process_clock(), logger.perf_counter())

        t1, t2, eris = self.t1, self.t2, self.eris

        # 3 or 4 virtuals
        self.Wvvov, self.WvvOV, self.WVVov, self.WVVOV = uintermediates.Wvvov(t1, t2, eris)
        self.Wvvvv = None  # too expensive to hold Wvvvv
        self.Wvvvo, self.WvvVO, self.WVVvo, self.WVVVO = uintermediates.Wvvvo(t1, t2, eris)

        # The contribution of Wvvvv
        t1a, t1b = t1
        # The contraction to eris.vvvv is included in eaccsd_matvec
        #:vvvv = eris.vvvv - eris.vvvv.transpose(0,3,2,1)
        #:VVVV = eris.VVVV - eris.VVVV.transpose(0,3,2,1)
        #:self.Wvvvo += lib.einsum('abef,if->abei',      vvvv, t1a)
        #:self.WvvVO += lib.einsum('abef,if->abei', eris_vvVV, t1b)
        #:self.WVVvo += lib.einsum('efab,if->abei', eris_vvVV, t1a)
        #:self.WVVVO += lib.einsum('abef,if->abei',      VVVV, t1b)

        # (ov|ov) * t1 * tau terms.  Same-spin integrals are
        # exchange-subtracted and carry an extra factor 1/2.
        tauaa, tauab, taubb = uccsd.make_tau(t2, t1, t1)
        eris_ovov = np.asarray(eris.ovov)
        eris_OVOV = np.asarray(eris.OVOV)
        eris_ovOV = np.asarray(eris.ovOV)
        ovov = eris_ovov - eris_ovov.transpose(0,3,2,1)
        OVOV = eris_OVOV - eris_OVOV.transpose(0,3,2,1)
        tmp = lib.einsum('menf,if->meni',      ovov, t1a) * .5
        self.Wvvvo += lib.einsum('meni,mnab->aebi', tmp, tauaa)
        tmp = tauaa = None

        tmp = lib.einsum('menf,if->meni',      OVOV, t1b) * .5
        self.WVVVO += lib.einsum('meni,mnab->aebi', tmp, taubb)
        tmp = taubb = None

        tmp = lib.einsum('menf,if->meni', eris_ovOV, t1b)
        self.WvvVO += lib.einsum('meni,mnab->aebi', tmp, tauab)
        tmp = lib.einsum('nfme,if->meni', eris_ovOV, t1a)
        self.WVVvo += lib.einsum('meni,nmba->aebi', tmp, tauab)
        tauab = None
        ovov = OVOV = eris_ovov = eris_OVOV = eris_ovOV = None

        # ovvv * t1 * t1 terms.  The full ovvv-type integrals are read at once
        # (slice(None)), and the same-spin blocks are exchange-subtracted.
        eris_ovvv = eris.get_ovvv(slice(None))
        ovvv = eris_ovvv - eris_ovvv.transpose(0,3,2,1)
        tmp = lib.einsum('mebf,if->mebi', ovvv, t1a)
        tmp = lib.einsum('mebi,ma->aebi', tmp, t1a)
        self.Wvvvo -= tmp - tmp.transpose(2,1,0,3)
        tmp = eris_ovvv = ovvv = None

        eris_OVVV = eris.get_OVVV(slice(None))
        OVVV = eris_OVVV - eris_OVVV.transpose(0,3,2,1)
        tmp = lib.einsum('mebf,if->mebi', OVVV, t1b)
        tmp = lib.einsum('mebi,ma->aebi', tmp, t1b)
        self.WVVVO -= tmp - tmp.transpose(2,1,0,3)
        tmp = eris_OVVV = OVVV = None

        # Mixed-spin ovVV / OVvv terms
        eris_ovVV = eris.get_ovVV(slice(None))
        eris_OVvv = eris.get_OVvv(slice(None))
        tmpaabb = lib.einsum('mebf,if->mebi', eris_ovVV, t1b)
        tmpbaab = lib.einsum('mebf,ie->mfbi', eris_OVvv, t1b)
        tmp  = lib.einsum('mebi,ma->aebi', tmpaabb, t1a)
        tmp += lib.einsum('mfbi,ma->bfai', tmpbaab, t1b)
        self.WvvVO -= tmp
        tmp = tmpaabb = tmpbaab = None

        tmpbbaa = lib.einsum('mebf,if->mebi', eris_OVvv, t1a)
        tmpabba = lib.einsum('mebf,ie->mfbi', eris_ovVV, t1a)
        tmp  = lib.einsum('mebi,ma->aebi', tmpbbaa, t1b)
        tmp += lib.einsum('mfbi,ma->bfai', tmpabba, t1a)
        self.WVVvo -= tmp
        tmp = tmpbbaa = tmpabba = None
        eris_ovVV = eris_OVvv = None
        # The contribution of Wvvvv end

        self.made_ea_imds = True
        logger.timer_debug1(self, 'EOM-UCCSD EA intermediates', *cput0)
        return self
2069
2070    def make_ee(self):
2071        cput0 = (logger.process_clock(), logger.perf_counter())
2072        log = logger.Logger(self.stdout, self.verbose)
2073
2074        t1, t2, eris = self.t1, self.t2, self.eris
2075        t1a, t1b = t1
2076        t2aa, t2ab, t2bb = t2
2077        nocca, noccb, nvira, nvirb = t2ab.shape
2078        dtype = np.result_type(t1a, t1b, t2aa, t2ab, t2bb)
2079
2080        fooa = eris.focka[:nocca,:nocca]
2081        foob = eris.fockb[:noccb,:noccb]
2082        fova = eris.focka[:nocca,nocca:]
2083        fovb = eris.fockb[:noccb,noccb:]
2084        fvva = eris.focka[nocca:,nocca:]
2085        fvvb = eris.fockb[noccb:,noccb:]
2086
2087        self.Fooa = np.zeros((nocca,nocca), dtype=dtype)
2088        self.Foob = np.zeros((noccb,noccb), dtype=dtype)
2089        self.Fvva = np.zeros((nvira,nvira), dtype=dtype)
2090        self.Fvvb = np.zeros((nvirb,nvirb), dtype=dtype)
2091
2092        wovvo = np.zeros((nocca,nvira,nvira,nocca), dtype=dtype)
2093        wOVVO = np.zeros((noccb,nvirb,nvirb,noccb), dtype=dtype)
2094        woVvO = np.zeros((nocca,nvirb,nvira,noccb), dtype=dtype)
2095        woVVo = np.zeros((nocca,nvirb,nvirb,nocca), dtype=dtype)
2096        wOvVo = np.zeros((noccb,nvira,nvirb,nocca), dtype=dtype)
2097        wOvvO = np.zeros((noccb,nvira,nvira,noccb), dtype=dtype)
2098
2099        wovoo = np.zeros((nocca,nvira,nocca,nocca), dtype=dtype)
2100        wOVOO = np.zeros((noccb,nvirb,noccb,noccb), dtype=dtype)
2101        woVoO = np.zeros((nocca,nvirb,nocca,noccb), dtype=dtype)
2102        wOvOo = np.zeros((noccb,nvira,noccb,nocca), dtype=dtype)
2103
2104        tauaa, tauab, taubb = uccsd.make_tau(t2, t1, t1)
2105        #:eris_ovvv = lib.unpack_tril(np.asarray(eris.ovvv).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvira,nvira)
2106        #:ovvv = eris_ovvv - eris_ovvv.transpose(0,3,2,1)
2107        #:self.Fvva  = np.einsum('mf,mfae->ae', t1a, ovvv)
2108        #:self.wovvo = lib.einsum('jf,mebf->mbej', t1a, ovvv)
2109        #:self.wovoo  = 0.5 * lib.einsum('mebf,ijef->mbij', eris_ovvv, tauaa)
2110        #:self.wovoo -= 0.5 * lib.einsum('mfbe,ijef->mbij', eris_ovvv, tauaa)
2111        mem_now = lib.current_memory()[0]
2112        max_memory = max(0, lib.param.MAX_MEMORY - mem_now)
2113        blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira**3*3))))
2114        for p0,p1 in lib.prange(0, nocca, blksize):
2115            ovvv = eris.get_ovvv(slice(p0,p1))  # ovvv = eris.ovvv[p0:p1]
2116            ovvv = ovvv - ovvv.transpose(0,3,2,1)
2117            self.Fvva += np.einsum('mf,mfae->ae', t1a[p0:p1], ovvv)
2118            wovvo[p0:p1] = lib.einsum('jf,mebf->mbej', t1a, ovvv)
2119            wovoo[p0:p1] = 0.5 * lib.einsum('mebf,ijef->mbij', ovvv, tauaa)
2120            ovvv = None
2121
2122        #:eris_OVVV = lib.unpack_tril(np.asarray(eris.OVVV).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvirb,nvirb)
2123        #:OVVV = eris_OVVV - eris_OVVV.transpose(0,3,2,1)
2124        #:self.Fvvb  = np.einsum('mf,mfae->ae', t1b, OVVV)
2125        #:self.wOVVO = lib.einsum('jf,mebf->mbej', t1b, OVVV)
2126        #:self.wOVOO  = 0.5 * lib.einsum('mebf,ijef->mbij', OVVV, taubb)
2127        blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb**3*3))))
2128        for p0, p1 in lib.prange(0, noccb, blksize):
2129            OVVV = eris.get_OVVV(slice(p0,p1))  # OVVV = eris.OVVV[p0:p1]
2130            OVVV = OVVV - OVVV.transpose(0,3,2,1)
2131            self.Fvvb += np.einsum('mf,mfae->ae', t1b[p0:p1], OVVV)
2132            wOVVO[p0:p1] = lib.einsum('jf,mebf->mbej', t1b, OVVV)
2133            wOVOO[p0:p1] = 0.5 * lib.einsum('mebf,ijef->mbij', OVVV, taubb)
2134            OVVV = None
2135
2136        #:eris_ovVV = lib.unpack_tril(np.asarray(eris.ovVV).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvirb,nvirb)
2137        #:self.Fvvb += np.einsum('mf,mfAE->AE', t1a, eris_ovVV)
2138        #:self.woVvO = lib.einsum('JF,meBF->mBeJ', t1b, eris_ovVV)
2139        #:self.woVVo = lib.einsum('jf,mfBE->mBEj',-t1a, eris_ovVV)
2140        #:self.woVoO  = 0.5 * lib.einsum('meBF,iJeF->mBiJ', eris_ovVV, tauab)
2141        #:self.woVoO += 0.5 * lib.einsum('mfBE,iJfE->mBiJ', eris_ovVV, tauab)
2142        blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira*nvirb**2*3))))
2143        for p0,p1 in lib.prange(0, nocca, blksize):
2144            ovVV = eris.get_ovVV(slice(p0,p1))  # ovVV = eris.ovVV[p0:p1]
2145            self.Fvvb += np.einsum('mf,mfAE->AE', t1a[p0:p1], ovVV)
2146            woVvO[p0:p1] = lib.einsum('JF,meBF->mBeJ', t1b, ovVV)
2147            woVVo[p0:p1] = lib.einsum('jf,mfBE->mBEj',-t1a, ovVV)
2148            woVoO[p0:p1] = 0.5 * lib.einsum('meBF,iJeF->mBiJ', ovVV, tauab)
2149            woVoO[p0:p1]+= 0.5 * lib.einsum('mfBE,iJfE->mBiJ', ovVV, tauab)
2150            ovVV = None
2151
2152        #:eris_OVvv = lib.unpack_tril(np.asarray(eris.OVvv).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvira,nvira)
2153        #:self.Fvva += np.einsum('MF,MFae->ae', t1b, eris_OVvv)
2154        #:self.wOvVo = lib.einsum('jf,MEbf->MbEj', t1a, eris_OVvv)
2155        #:self.wOvvO = lib.einsum('JF,MFbe->MbeJ',-t1b, eris_OVvv)
2156        #:self.wOvOo  = 0.5 * lib.einsum('MEbf,jIfE->MbIj', eris_OVvv, tauab)
2157        #:self.wOvOo += 0.5 * lib.einsum('MFbe,jIeF->MbIj', eris_OVvv, tauab)
2158        blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb*nvira**2*3))))
2159        for p0, p1 in lib.prange(0, noccb, blksize):
2160            OVvv = eris.get_OVvv(slice(p0,p1))  # OVvv = eris.OVvv[p0:p1]
2161            self.Fvva += np.einsum('MF,MFae->ae', t1b[p0:p1], OVvv)
2162            wOvVo[p0:p1] = lib.einsum('jf,MEbf->MbEj', t1a, OVvv)
2163            wOvvO[p0:p1] = lib.einsum('JF,MFbe->MbeJ',-t1b, OVvv)
2164            wOvOo[p0:p1] = 0.5 * lib.einsum('MEbf,jIfE->MbIj', OVvv, tauab)
2165            wOvOo[p0:p1]+= 0.5 * lib.einsum('MFbe,jIeF->MbIj', OVvv, tauab)
2166            OVvv = None
2167
2168        eris_ovov = np.asarray(eris.ovov)
2169        eris_OVOV = np.asarray(eris.OVOV)
2170        eris_ovOV = np.asarray(eris.ovOV)
2171        ovov = eris_ovov - eris_ovov.transpose(0,3,2,1)
2172        OVOV = eris_OVOV - eris_OVOV.transpose(0,3,2,1)
2173        self.Fova = np.einsum('nf,menf->me', t1a,      ovov)
2174        self.Fova+= np.einsum('NF,meNF->me', t1b, eris_ovOV)
2175        self.Fova += fova
2176        self.Fovb = np.einsum('nf,menf->me', t1b,      OVOV)
2177        self.Fovb+= np.einsum('nf,nfME->ME', t1a, eris_ovOV)
2178        self.Fovb += fovb
2179        tilaa, tilab, tilbb = uccsd.make_tau(t2,t1,t1,fac=0.5)
2180        self.Fooa  = lib.einsum('inef,menf->mi', tilaa, eris_ovov)
2181        self.Fooa += lib.einsum('iNeF,meNF->mi', tilab, eris_ovOV)
2182        self.Foob  = lib.einsum('inef,menf->mi', tilbb, eris_OVOV)
2183        self.Foob += lib.einsum('nIfE,nfME->MI', tilab, eris_ovOV)
2184        self.Fvva -= lib.einsum('mnaf,menf->ae', tilaa, eris_ovov)
2185        self.Fvva -= lib.einsum('mNaF,meNF->ae', tilab, eris_ovOV)
2186        self.Fvvb -= lib.einsum('mnaf,menf->ae', tilbb, eris_OVOV)
2187        self.Fvvb -= lib.einsum('nMfA,nfME->AE', tilab, eris_ovOV)
2188        wovvo -= lib.einsum('jnfb,menf->mbej', t2aa,      ovov)
2189        wovvo += lib.einsum('jNbF,meNF->mbej', t2ab, eris_ovOV)
2190        wOVVO -= lib.einsum('jnfb,menf->mbej', t2bb,      OVOV)
2191        wOVVO += lib.einsum('nJfB,nfME->MBEJ', t2ab, eris_ovOV)
2192        woVvO += lib.einsum('nJfB,menf->mBeJ', t2ab,      ovov)
2193        woVvO -= lib.einsum('JNFB,meNF->mBeJ', t2bb, eris_ovOV)
2194        wOvVo -= lib.einsum('jnfb,nfME->MbEj', t2aa, eris_ovOV)
2195        wOvVo += lib.einsum('jNbF,MENF->MbEj', t2ab,      OVOV)
2196        woVVo += lib.einsum('jNfB,mfNE->mBEj', t2ab, eris_ovOV)
2197        wOvvO += lib.einsum('nJbF,neMF->MbeJ', t2ab, eris_ovOV)
2198
2199        eris_ovoo = np.asarray(eris.ovoo)
2200        eris_OVOO = np.asarray(eris.OVOO)
2201        eris_OVoo = np.asarray(eris.OVoo)
2202        eris_ovOO = np.asarray(eris.ovOO)
2203        self.Fooa += np.einsum('ne,nemi->mi', t1a, eris_ovoo)
2204        self.Fooa -= np.einsum('ne,meni->mi', t1a, eris_ovoo)
2205        self.Fooa += np.einsum('NE,NEmi->mi', t1b, eris_OVoo)
2206        self.Foob += np.einsum('ne,nemi->mi', t1b, eris_OVOO)
2207        self.Foob -= np.einsum('ne,meni->mi', t1b, eris_OVOO)
2208        self.Foob += np.einsum('ne,neMI->MI', t1a, eris_ovOO)
2209        eris_ovoo = eris_ovoo + np.einsum('nfme,jf->menj', eris_ovov, t1a)
2210        eris_OVOO = eris_OVOO + np.einsum('nfme,jf->menj', eris_OVOV, t1b)
2211        eris_OVoo = eris_OVoo + np.einsum('nfme,jf->menj', eris_ovOV, t1a)
2212        eris_ovOO = eris_ovOO + np.einsum('menf,jf->menj', eris_ovOV, t1b)
2213        ovoo = eris_ovoo - eris_ovoo.transpose(2,1,0,3)
2214        OVOO = eris_OVOO - eris_OVOO.transpose(2,1,0,3)
2215        wovvo += lib.einsum('nb,nemj->mbej', t1a,      ovoo)
2216        wOVVO += lib.einsum('nb,nemj->mbej', t1b,      OVOO)
2217        woVvO -= lib.einsum('NB,meNJ->mBeJ', t1b, eris_ovOO)
2218        wOvVo -= lib.einsum('nb,MEnj->MbEj', t1a, eris_OVoo)
2219        woVVo += lib.einsum('NB,NEmj->mBEj', t1b, eris_OVoo)
2220        wOvvO += lib.einsum('nb,neMJ->MbeJ', t1a, eris_ovOO)
2221
2222        self.Fooa += fooa + 0.5*lib.einsum('me,ie->mi', self.Fova+fova, t1a)
2223        self.Foob += foob + 0.5*lib.einsum('me,ie->mi', self.Fovb+fovb, t1b)
2224        self.Fvva += fvva - 0.5*lib.einsum('me,ma->ae', self.Fova+fova, t1a)
2225        self.Fvvb += fvvb - 0.5*lib.einsum('me,ma->ae', self.Fovb+fovb, t1b)
2226
2227        # 0 or 1 virtuals
2228        eris_ovoo = np.asarray(eris.ovoo)
2229        eris_OVOO = np.asarray(eris.OVOO)
2230        eris_OVoo = np.asarray(eris.OVoo)
2231        eris_ovOO = np.asarray(eris.ovOO)
2232        ovoo = eris_ovoo - eris_ovoo.transpose(2,1,0,3)
2233        OVOO = eris_OVOO - eris_OVOO.transpose(2,1,0,3)
2234        woooo = lib.einsum('je,nemi->minj', t1a,      ovoo)
2235        wOOOO = lib.einsum('je,nemi->minj', t1b,      OVOO)
2236        wooOO = lib.einsum('JE,NEmi->miNJ', t1b, eris_OVoo)
2237        woOOo = lib.einsum('je,meNI->mINj',-t1a, eris_ovOO)
2238        tmpaa = lib.einsum('nemi,jnbe->mbij',      ovoo, t2aa)
2239        tmpaa+= lib.einsum('NEmi,jNbE->mbij', eris_OVoo, t2ab)
2240        tmpbb = lib.einsum('nemi,jnbe->mbij',      OVOO, t2bb)
2241        tmpbb+= lib.einsum('neMI,nJeB->MBIJ', eris_ovOO, t2ab)
2242        woVoO += lib.einsum('nemi,nJeB->mBiJ',      ovoo, t2ab)
2243        woVoO += lib.einsum('NEmi,JNBE->mBiJ', eris_OVoo, t2bb)
2244        woVoO -= lib.einsum('meNI,jNeB->mBjI', eris_ovOO, t2ab)
2245        wOvOo += lib.einsum('NEMI,jNbE->MbIj',      OVOO, t2ab)
2246        wOvOo += lib.einsum('neMI,jnbe->MbIj', eris_ovOO, t2aa)
2247        wOvOo -= lib.einsum('MEni,nJbE->MbJi', eris_OVoo, t2ab)
2248        wovoo += tmpaa - tmpaa.transpose(0,1,3,2)
2249        wOVOO += tmpbb - tmpbb.transpose(0,1,3,2)
2250        self.wooov = np.array(     ovoo.transpose(2,3,0,1), dtype=dtype)
2251        self.wOOOV = np.array(     OVOO.transpose(2,3,0,1), dtype=dtype)
2252        self.wooOV = np.array(eris_OVoo.transpose(2,3,0,1), dtype=dtype)
2253        self.wOOov = np.array(eris_ovOO.transpose(2,3,0,1), dtype=dtype)
2254#X        self.wOooV =-np.array(eris_OVoo.transpose(0,3,2,1), dtype=dtype)
2255#X        self.woOOv =-np.array(eris_ovOO.transpose(0,3,2,1), dtype=dtype)
2256        eris_ovoo = eris_OVOO = eris_ovOO = eris_OVoo = None
2257
2258        woooo += np.asarray(eris.oooo)
2259        wOOOO += np.asarray(eris.OOOO)
2260        wooOO += np.asarray(eris.ooOO)
2261        self.woooo = woooo - woooo.transpose(0,3,2,1)
2262        self.wOOOO = wOOOO - wOOOO.transpose(0,3,2,1)
2263        self.wooOO = wooOO - woOOo.transpose(0,3,2,1)
2264
2265        eris_ovov = np.asarray(eris.ovov)
2266        eris_OVOV = np.asarray(eris.OVOV)
2267        eris_ovOV = np.asarray(eris.ovOV)
2268        ovov = eris_ovov - eris_ovov.transpose(0,3,2,1)
2269        OVOV = eris_OVOV - eris_OVOV.transpose(0,3,2,1)
2270        tauaa, tauab, taubb = uccsd.make_tau(t2,t1,t1)
2271        self.woooo += 0.5*lib.einsum('ijef,menf->minj', tauaa,      ovov)
2272        self.wOOOO += 0.5*lib.einsum('ijef,menf->minj', taubb,      OVOV)
2273        self.wooOO +=     lib.einsum('iJeF,meNF->miNJ', tauab, eris_ovOV)
2274
2275        self.wooov += lib.einsum('if,mfne->mine', t1a,      ovov)
2276        self.wOOOV += lib.einsum('if,mfne->mine', t1b,      OVOV)
2277        self.wooOV += lib.einsum('if,mfNE->miNE', t1a, eris_ovOV)
2278        self.wOOov += lib.einsum('IF,neMF->MIne', t1b, eris_ovOV)
2279#X        self.wOooV -= lib.einsum('if,nfME->MinE', t1a, eris_ovOV)
2280#X        self.woOOv -= lib.einsum('IF,meNF->mINe', t1b, eris_ovOV)
2281
2282        tmp1aa = lib.einsum('njbf,menf->mbej', t2aa,      ovov)
2283        tmp1aa-= lib.einsum('jNbF,meNF->mbej', t2ab, eris_ovOV)
2284        tmp1bb = lib.einsum('njbf,menf->mbej', t2bb,      OVOV)
2285        tmp1bb-= lib.einsum('nJfB,nfME->MBEJ', t2ab, eris_ovOV)
2286        tmp1ab = lib.einsum('NJBF,meNF->mBeJ', t2bb, eris_ovOV)
2287        tmp1ab-= lib.einsum('nJfB,menf->mBeJ', t2ab,      ovov)
2288        tmp1ba = lib.einsum('njbf,nfME->MbEj', t2aa, eris_ovOV)
2289        tmp1ba-= lib.einsum('jNbF,MENF->MbEj', t2ab,      OVOV)
2290        tmp1abba =-lib.einsum('jNfB,mfNE->mBEj', t2ab, eris_ovOV)
2291        tmp1baab =-lib.einsum('nJbF,neMF->MbeJ', t2ab, eris_ovOV)
2292        tmpaa = lib.einsum('ie,mbej->mbij', t1a, tmp1aa)
2293        tmpbb = lib.einsum('ie,mbej->mbij', t1b, tmp1bb)
2294        tmpab = lib.einsum('ie,mBeJ->mBiJ', t1a, tmp1ab)
2295        tmpab-= lib.einsum('IE,mBEj->mBjI', t1b, tmp1abba)
2296        tmpba = lib.einsum('IE,MbEj->MbIj', t1b, tmp1ba)
2297        tmpba-= lib.einsum('ie,MbeJ->MbJi', t1a, tmp1baab)
2298        wovoo -= tmpaa - tmpaa.transpose(0,1,3,2)
2299        wOVOO -= tmpbb - tmpbb.transpose(0,1,3,2)
2300        woVoO -= tmpab
2301        wOvOo -= tmpba
2302        eris_ovov = eris_OVOV = eris_ovOV = None
2303        eris_ovoo = np.asarray(eris.ovoo)
2304        eris_OVOO = np.asarray(eris.OVOO)
2305        eris_ovOO = np.asarray(eris.ovOO)
2306        eris_OVoo = np.asarray(eris.OVoo)
2307        wovoo += eris_ovoo.transpose(3,1,2,0) - eris_ovoo.transpose(2,1,0,3)
2308        wOVOO += eris_OVOO.transpose(3,1,2,0) - eris_OVOO.transpose(2,1,0,3)
2309        woVoO += eris_OVoo.transpose(3,1,2,0)
2310        wOvOo += eris_ovOO.transpose(3,1,2,0)
2311        eris_ovoo = eris_OVOO = eris_ovOO = eris_OVoo = None
2312
2313        eris_ovvo = np.asarray(eris.ovvo)
2314        eris_OVVO = np.asarray(eris.OVVO)
2315        eris_OVvo = np.asarray(eris.OVvo)
2316        eris_ovVO = np.asarray(eris.ovVO)
2317        eris_oovv = np.asarray(eris.oovv)
2318        eris_OOVV = np.asarray(eris.OOVV)
2319        eris_OOvv = np.asarray(eris.OOvv)
2320        eris_ooVV = np.asarray(eris.ooVV)
2321        wovvo += eris_ovvo.transpose(0,2,1,3)
2322        wOVVO += eris_OVVO.transpose(0,2,1,3)
2323        woVvO += eris_ovVO.transpose(0,2,1,3)
2324        wOvVo += eris_OVvo.transpose(0,2,1,3)
2325        wovvo -= eris_oovv.transpose(0,2,3,1)
2326        wOVVO -= eris_OOVV.transpose(0,2,3,1)
2327        woVVo -= eris_ooVV.transpose(0,2,3,1)
2328        wOvvO -= eris_OOvv.transpose(0,2,3,1)
2329
2330        tmpaa = lib.einsum('ie,mebj->mbij', t1a, eris_ovvo)
2331        tmpbb = lib.einsum('ie,mebj->mbij', t1b, eris_OVVO)
2332        tmpaa-= lib.einsum('ie,mjbe->mbij', t1a, eris_oovv)
2333        tmpbb-= lib.einsum('ie,mjbe->mbij', t1b, eris_OOVV)
2334        woVoO += lib.einsum('ie,meBJ->mBiJ', t1a, eris_ovVO)
2335        woVoO -= lib.einsum('IE,mjBE->mBjI',-t1b, eris_ooVV)
2336        wOvOo += lib.einsum('IE,MEbj->MbIj', t1b, eris_OVvo)
2337        wOvOo -= lib.einsum('ie,MJbe->MbJi',-t1a, eris_OOvv)
2338        wovoo += tmpaa - tmpaa.transpose(0,1,3,2)
2339        wOVOO += tmpbb - tmpbb.transpose(0,1,3,2)
2340        wovoo -= lib.einsum('me,ijbe->mbij', self.Fova, t2aa)
2341        wOVOO -= lib.einsum('me,ijbe->mbij', self.Fovb, t2bb)
2342        woVoO += lib.einsum('me,iJeB->mBiJ', self.Fova, t2ab)
2343        wOvOo += lib.einsum('ME,jIbE->MbIj', self.Fovb, t2ab)
2344        wovoo -= lib.einsum('nb,minj->mbij', t1a, self.woooo)
2345        wOVOO -= lib.einsum('nb,minj->mbij', t1b, self.wOOOO)
2346        woVoO -= lib.einsum('NB,miNJ->mBiJ', t1b, self.wooOO)
2347        wOvOo -= lib.einsum('nb,njMI->MbIj', t1a, self.wooOO)
2348        eris_ovvo = eris_OVVO = eris_OVvo = eris_ovVO = None
2349        eris_oovv = eris_OOVV = eris_OOvv = eris_ooVV = None
2350
2351        self.saved = lib.H5TmpFile()
2352        self.saved['ovvo'] = wovvo
2353        self.saved['OVVO'] = wOVVO
2354        self.saved['oVvO'] = woVvO
2355        self.saved['OvVo'] = wOvVo
2356        self.saved['oVVo'] = woVVo
2357        self.saved['OvvO'] = wOvvO
2358        self.wovvo = self.saved['ovvo']
2359        self.wOVVO = self.saved['OVVO']
2360        self.woVvO = self.saved['oVvO']
2361        self.wOvVo = self.saved['OvVo']
2362        self.woVVo = self.saved['oVVo']
2363        self.wOvvO = self.saved['OvvO']
2364        self.saved['ovoo'] = wovoo
2365        self.saved['OVOO'] = wOVOO
2366        self.saved['oVoO'] = woVoO
2367        self.saved['OvOo'] = wOvOo
2368        self.wovoo = self.saved['ovoo']
2369        self.wOVOO = self.saved['OVOO']
2370        self.woVoO = self.saved['oVoO']
2371        self.wOvOo = self.saved['OvOo']
2372
2373        self.wvovv = self.saved.create_dataset('vovv', (nvira,nocca,nvira,nvira), t1a.dtype.char)
2374        self.wVOVV = self.saved.create_dataset('VOVV', (nvirb,noccb,nvirb,nvirb), t1a.dtype.char)
2375        self.wvOvV = self.saved.create_dataset('vOvV', (nvira,noccb,nvira,nvirb), t1a.dtype.char)
2376        self.wVoVv = self.saved.create_dataset('VoVv', (nvirb,nocca,nvirb,nvira), t1a.dtype.char)
2377
2378        # 3 or 4 virtuals
2379        eris_ovoo = np.asarray(eris.ovoo)
2380        eris_ovov = np.asarray(eris.ovov)
2381        eris_ovOV = np.asarray(eris.ovOV)
2382        ovov = eris_ovov - eris_ovov.transpose(0,3,2,1)
2383        eris_oovv = np.asarray(eris.oovv)
2384        eris_ovvo = np.asarray(eris.ovvo)
2385        oovv = eris_oovv - eris_ovvo.transpose(0,3,2,1)
2386        eris_oovv = eris_ovvo = None
2387        #:wvovv  = .5 * lib.einsum('meni,mnab->eiab', eris_ovoo, tauaa)
2388        #:wvovv -= .5 * lib.einsum('me,miab->eiab', self.Fova, t2aa)
2389        #:tmp1aa = lib.einsum('nibf,menf->mbei', t2aa,      ovov)
2390        #:tmp1aa-= lib.einsum('iNbF,meNF->mbei', t2ab, eris_ovOV)
2391        #:wvovv+= lib.einsum('ma,mbei->eiab', t1a, tmp1aa)
2392        #:wvovv+= lib.einsum('ma,mibe->eiab', t1a,      oovv)
2393        for p0, p1 in lib.prange(0, nvira, nocca):
2394            wvovv  = .5*lib.einsum('meni,mnab->eiab', eris_ovoo[:,p0:p1], tauaa)
2395            wvovv -= .5*lib.einsum('me,miab->eiab', self.Fova[:,p0:p1], t2aa)
2396
2397            tmp1aa = lib.einsum('nibf,menf->mbei', t2aa, ovov[:,p0:p1])
2398            tmp1aa-= lib.einsum('iNbF,meNF->mbei', t2ab, eris_ovOV[:,p0:p1])
2399            wvovv += lib.einsum('ma,mbei->eiab', t1a, tmp1aa)
2400            wvovv += lib.einsum('ma,mibe->eiab', t1a, oovv[:,:,:,p0:p1])
2401            self.wvovv[p0:p1] = wvovv
2402            tmp1aa = None
2403        eris_ovov = eris_ovoo = eris_ovOV = None
2404
2405        #:eris_ovvv = lib.unpack_tril(np.asarray(eris.ovvv).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvira,nvira)
2406        #:ovvv = eris_ovvv - eris_ovvv.transpose(0,3,2,1)
2407        #:wvovv += lib.einsum('mebf,miaf->eiab',      ovvv, t2aa)
2408        #:eris_OVvv = lib.unpack_tril(np.asarray(eris.OVvv).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvira,nvira)
2409        #:wvovv += lib.einsum('MFbe,iMaF->eiab', eris_OVvv, t2ab)
2410        #:wvovv += eris_ovvv.transpose(2,0,3,1).conj()
2411        #:self.wvovv -= wvovv - wvovv.transpose(0,1,3,2)
2412        mem_now = lib.current_memory()[0]
2413        max_memory = max(0, lib.param.MAX_MEMORY - mem_now)
2414        blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira**3*6))))
2415        for i0,i1 in lib.prange(0, nocca, blksize):
2416            wvovv = self.wvovv[:,i0:i1]
2417            for p0,p1 in lib.prange(0, noccb, blksize):
2418                OVvv = eris.get_OVvv(slice(p0,p1))  # OVvv = eris.OVvv[p0:p1]
2419                wvovv -= lib.einsum('MFbe,iMaF->eiab', OVvv, t2ab[i0:i1,p0:p1])
2420                OVvv = None
2421            for p0,p1 in lib.prange(0, nocca, blksize):
2422                ovvv = eris.get_ovvv(slice(p0,p1))  # ovvv = eris.ovvv[p0:p1]
2423                if p0 == i0:
2424                    wvovv += ovvv.transpose(2,0,3,1).conj()
2425                ovvv = ovvv - ovvv.transpose(0,3,2,1)
2426                wvovv -= lib.einsum('mebf,miaf->eiab', ovvv, t2aa[p0:p1,i0:i1])
2427                ovvv = None
2428            wvovv = wvovv - wvovv.transpose(0,1,3,2)
2429            self.wvovv[:,i0:i1] = wvovv
2430
2431        eris_OVOO = np.asarray(eris.OVOO)
2432        eris_OVOV = np.asarray(eris.OVOV)
2433        eris_ovOV = np.asarray(eris.ovOV)
2434        OVOV = eris_OVOV - eris_OVOV.transpose(0,3,2,1)
2435        eris_OOVV = np.asarray(eris.OOVV)
2436        eris_OVVO = np.asarray(eris.OVVO)
2437        OOVV = eris_OOVV - eris_OVVO.transpose(0,3,2,1)
2438        eris_OOVV = eris_OVVO = None
2439        #:wVOVV  = .5*lib.einsum('meni,mnab->eiab', eris_OVOO, taubb)
2440        #:wVOVV -= .5*lib.einsum('me,miab->eiab', self.Fovb, t2bb)
2441        #:tmp1bb = lib.einsum('nibf,menf->mbei', t2bb,      OVOV)
2442        #:tmp1bb-= lib.einsum('nIfB,nfME->MBEI', t2ab, eris_ovOV)
2443        #:wVOVV += lib.einsum('ma,mbei->eiab', t1b, tmp1bb)
2444        #:wVOVV += lib.einsum('ma,mibe->eiab', t1b,      OOVV)
2445        for p0, p1 in lib.prange(0, nvirb, noccb):
2446            wVOVV  = .5*lib.einsum('meni,mnab->eiab', eris_OVOO[:,p0:p1], taubb)
2447            wVOVV -= .5*lib.einsum('me,miab->eiab', self.Fovb[:,p0:p1], t2bb)
2448
2449            tmp1bb = lib.einsum('nibf,menf->mbei', t2bb, OVOV[:,p0:p1])
2450            tmp1bb-= lib.einsum('nIfB,nfME->MBEI', t2ab, eris_ovOV[:,:,:,p0:p1])
2451            wVOVV += lib.einsum('ma,mbei->eiab', t1b, tmp1bb)
2452            wVOVV += lib.einsum('ma,mibe->eiab', t1b, OOVV[:,:,:,p0:p1])
2453            self.wVOVV[p0:p1] = wVOVV
2454            tmp1bb = None
2455        eris_OVOV = eris_OVOO = eris_ovOV = None
2456
2457        #:eris_OVVV = lib.unpack_tril(np.asarray(eris.OVVV).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvirb,nvirb)
2458        #:OVVV = eris_OVVV - eris_OVVV.transpose(0,3,2,1)
2459        #:wVOVV -= lib.einsum('MEBF,MIAF->EIAB',      OVVV, t2bb)
2460        #:eris_ovVV = lib.unpack_tril(np.asarray(eris.ovVV).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvirb,nvirb)
2461        #:wVOVV -= lib.einsum('mfBE,mIfA->EIAB', eris_ovVV, t2ab)
2462        #:wVOVV += eris_OVVV.transpose(2,0,3,1).conj()
2463        #:self.wVOVV += wVOVV - wVOVV.transpose(0,1,3,2)
2464        blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb**3*6))))
2465        for i0,i1 in lib.prange(0, noccb, blksize):
2466            wVOVV = self.wVOVV[:,i0:i1]
2467            for p0,p1 in lib.prange(0, nocca, blksize):
2468                ovVV = eris.get_ovVV(slice(p0,p1))  # ovVV = eris.ovVV[p0:p1]
2469                wVOVV -= lib.einsum('mfBE,mIfA->EIAB', ovVV, t2ab[p0:p1,i0:i1])
2470                ovVV = None
2471            for p0,p1 in lib.prange(0, noccb, blksize):
2472                OVVV = eris.get_OVVV(slice(p0,p1))  # OVVV = eris.OVVV[p0:p1]
2473                if p0 == i0:
2474                    wVOVV += OVVV.transpose(2,0,3,1).conj()
2475                OVVV = OVVV - OVVV.transpose(0,3,2,1)
2476                wVOVV -= lib.einsum('mebf,miaf->eiab', OVVV, t2bb[p0:p1,i0:i1])
2477                OVVV = None
2478            wVOVV = wVOVV - wVOVV.transpose(0,1,3,2)
2479            self.wVOVV[:,i0:i1] = wVOVV
2480
2481        eris_ovOV = np.asarray(eris.ovOV)
2482        eris_ovOO = np.asarray(eris.ovOO)
2483        eris_OOvv = np.asarray(eris.OOvv)
2484        eris_ovVO = np.asarray(eris.ovVO)
2485        #:self.wvOvV = lib.einsum('meNI,mNaB->eIaB', eris_ovOO, tauab)
2486        #:self.wvOvV -= lib.einsum('me,mIaB->eIaB', self.Fova, t2ab)
2487        #:tmp1ab = lib.einsum('NIBF,meNF->mBeI', t2bb, eris_ovOV)
2488        #:tmp1ab-= lib.einsum('nIfB,menf->mBeI', t2ab,      ovov)
2489        #:tmp1baab = lib.einsum('nIbF,neMF->MbeI', t2ab, eris_ovOV)
2490        #:tmpab = lib.einsum('ma,mBeI->eIaB', t1a, tmp1ab)
2491        #:tmpab+= lib.einsum('MA,MbeI->eIbA', t1b, tmp1baab)
2492        #:tmpab-= lib.einsum('MA,MIbe->eIbA', t1b, eris_OOvv)
2493        #:tmpab-= lib.einsum('ma,meBI->eIaB', t1a, eris_ovVO)
2494        #:self.wvOvV += tmpab
2495        for p0, p1 in lib.prange(0, nvira, nocca):
2496            wvOvV  = lib.einsum('meNI,mNaB->eIaB', eris_ovOO[:,p0:p1], tauab)
2497            wvOvV -= lib.einsum('me,mIaB->eIaB', self.Fova[:,p0:p1], t2ab)
2498            tmp1ab = lib.einsum('NIBF,meNF->mBeI', t2bb, eris_ovOV[:,p0:p1])
2499            tmp1ab-= lib.einsum('nIfB,menf->mBeI', t2ab, ovov[:,p0:p1])
2500            wvOvV+= lib.einsum('ma,mBeI->eIaB', t1a, tmp1ab)
2501            tmp1ab = None
2502            tmp1baab = lib.einsum('nIbF,neMF->MbeI', t2ab, eris_ovOV[:,p0:p1])
2503            wvOvV+= lib.einsum('MA,MbeI->eIbA', t1b, tmp1baab)
2504            tmp1baab = None
2505            wvOvV-= lib.einsum('MA,MIbe->eIbA', t1b, eris_OOvv[:,:,:,p0:p1])
2506            wvOvV-= lib.einsum('ma,meBI->eIaB', t1a, eris_ovVO[:,p0:p1])
2507            self.wvOvV[p0:p1] = wvOvV
2508        eris_ovOV = eris_ovOO = eris_OOvv = eris_ovVO = None
2509
2510        #:eris_ovvv = lib.unpack_tril(np.asarray(eris.ovvv).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvira,nvira)
2511        #:ovvv = eris_ovvv - eris_ovvv.transpose(0,3,2,1)
2512        #:self.wvOvV -= lib.einsum('mebf,mIfA->eIbA',      ovvv, t2ab)
2513        #:eris_ovVV = lib.unpack_tril(np.asarray(eris.ovVV).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvirb,nvirb)
2514        #:self.wvOvV -= lib.einsum('meBF,mIaF->eIaB', eris_ovVV, t2ab)
2515        #:eris_OVvv = lib.unpack_tril(np.asarray(eris.OVvv).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvira,nvira)
2516        #:self.wvOvV -= lib.einsum('MFbe,MIAF->eIbA', eris_OVvv, t2bb)
2517        #:self.wvOvV += eris_OVvv.transpose(2,0,3,1).conj()
2518        blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira**3*6))))
2519        for i0,i1 in lib.prange(0, noccb, blksize):
2520            wvOvV = self.wvOvV[:,i0:i1]
2521            for p0,p1 in lib.prange(0, nocca, blksize):
2522                ovVV = eris.get_ovVV(slice(p0,p1))  # ovVV = eris.ovVV[p0:p1]
2523                wvOvV -= lib.einsum('meBF,mIaF->eIaB', ovVV, t2ab[p0:p1,i0:i1])
2524                ovVV = None
2525            for p0,p1 in lib.prange(0, nocca, blksize):
2526                ovvv = eris.get_ovvv(slice(p0,p1))  # ovvv = eris.ovvv[p0:p1]
2527                ovvv = ovvv - ovvv.transpose(0,3,2,1)
2528                wvOvV -= lib.einsum('mebf,mIfA->eIbA',ovvv, t2ab[p0:p1,i0:i1])
2529                ovvv = None
2530            self.wvOvV[:,i0:i1] = wvOvV
2531
2532        blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb*nvira**2*3))))
2533        for i0,i1 in lib.prange(0, noccb, blksize):
2534            wvOvV = self.wvOvV[:,i0:i1]
2535            for p0,p1 in lib.prange(0, noccb, blksize):
2536                OVvv = eris.get_OVvv(slice(p0,p1))  # OVvv = eris.OVvv[p0:p1]
2537                if p0 == i0:
2538                    wvOvV += OVvv.transpose(2,0,3,1).conj()
2539                wvOvV -= lib.einsum('MFbe,MIAF->eIbA', OVvv, t2bb[p0:p1,i0:i1])
2540                OVvv = None
2541            self.wvOvV[:,i0:i1] = wvOvV
2542
2543        eris_ovOV = np.asarray(eris.ovOV)
2544        eris_OVoo = np.asarray(eris.OVoo)
2545        eris_ooVV = np.asarray(eris.ooVV)
2546        eris_OVvo = np.asarray(eris.OVvo)
2547        #:self.wVoVv = lib.einsum('MEni,nMbA->EiAb', eris_OVoo, tauab)
2548        #:self.wVoVv -= lib.einsum('ME,iMbA->EiAb', self.Fovb, t2ab)
2549        #:tmp1ba = lib.einsum('nibf,nfME->MbEi', t2aa, eris_ovOV)
2550        #:tmp1ba-= lib.einsum('iNbF,MENF->MbEi', t2ab,      OVOV)
2551        #:tmp1abba = lib.einsum('iNfB,mfNE->mBEi', t2ab, eris_ovOV)
2552        #:tmpba = lib.einsum('MA,MbEi->EiAb', t1b, tmp1ba)
2553        #:tmpba+= lib.einsum('ma,mBEi->EiBa', t1a, tmp1abba)
2554        #:tmpba-= lib.einsum('ma,miBE->EiBa', t1a, eris_ooVV)
2555        #:tmpba-= lib.einsum('MA,MEbi->EiAb', t1b, eris_OVvo)
2556        #:self.wVoVv += tmpba
2557        for p0, p1 in lib.prange(0, nvirb, noccb):
2558            wVoVv  = lib.einsum('MEni,nMbA->EiAb', eris_OVoo[:,p0:p1], tauab)
2559            wVoVv -= lib.einsum('ME,iMbA->EiAb', self.Fovb[:,p0:p1], t2ab)
2560            tmp1ba = lib.einsum('nibf,nfME->MbEi', t2aa, eris_ovOV[:,:,:,p0:p1])
2561            tmp1ba-= lib.einsum('iNbF,MENF->MbEi', t2ab, OVOV[:,p0:p1])
2562            wVoVv += lib.einsum('MA,MbEi->EiAb', t1b, tmp1ba)
2563            tmp1ba = None
2564            tmp1abba = lib.einsum('iNfB,mfNE->mBEi', t2ab, eris_ovOV[:,:,:,p0:p1])
2565            wVoVv += lib.einsum('ma,mBEi->EiBa', t1a, tmp1abba)
2566            tmp1abba = None
2567            wVoVv -= lib.einsum('ma,miBE->EiBa', t1a, eris_ooVV[:,:,:,p0:p1])
2568            wVoVv -= lib.einsum('MA,MEbi->EiAb', t1b, eris_OVvo[:,p0:p1])
2569            self.wVoVv[p0:p1] = wVoVv
2570        eris_ovOV = eris_OVoo = eris_ooVV = eris_OVvo = None
2571
2572        #:eris_OVVV = lib.unpack_tril(np.asarray(eris.OVVV).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvirb,nvirb)
2573        #:OVVV = eris_OVVV - eris_OVVV.transpose(0,3,2,1)
2574        #:self.wVoVv -= lib.einsum('MEBF,iMaF->EiBa',      OVVV, t2ab)
2575        #:eris_OVvv = lib.unpack_tril(np.asarray(eris.OVvv).reshape(noccb*nvirb,-1)).reshape(noccb,nvirb,nvira,nvira)
2576        #:self.wVoVv -= lib.einsum('MEbf,iMfA->EiAb', eris_OVvv, t2ab)
2577        #:eris_ovVV = lib.unpack_tril(np.asarray(eris.ovVV).reshape(nocca*nvira,-1)).reshape(nocca,nvira,nvirb,nvirb)
2578        #:self.wVoVv -= lib.einsum('mfBE,miaf->EiBa', eris_ovVV, t2aa)
2579        #:self.wVoVv += eris_ovVV.transpose(2,0,3,1).conj()
2580        blksize = min(noccb, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvirb**3*6))))
2581        for i0,i1 in lib.prange(0, nocca, blksize):
2582            wVoVv = self.wVoVv[:,i0:i1]
2583            for p0,p1 in lib.prange(0, noccb, blksize):
2584                OVvv = eris.get_OVvv(slice(p0,p1))  # OVvv = eris.OVvv[p0:p1]
2585                wVoVv -= lib.einsum('MEbf,iMfA->EiAb', OVvv, t2ab[i0:i1,p0:p1])
2586                OVvv = None
2587            for p0,p1 in lib.prange(0, noccb, blksize):
2588                OVVV = eris.get_OVVV(slice(p0,p1))  # OVVV = eris.OVVV[p0:p1]
2589                OVVV = OVVV - OVVV.transpose(0,3,2,1)
2590                wVoVv -= lib.einsum('MEBF,iMaF->EiBa', OVVV, t2ab[i0:i1,p0:p1])
2591                OVVV = None
2592            self.wVoVv[:,i0:i1] = wVoVv
2593
2594        blksize = min(nocca, max(ccsd.BLKMIN, int(max_memory*1e6/8/(nvira*nvirb**2*3))))
2595        for i0,i1 in lib.prange(0, nocca, blksize):
2596            wVoVv = self.wVoVv[:,i0:i1]
2597            for p0,p1 in lib.prange(0, nocca, blksize):
2598                ovVV = eris.get_ovVV(slice(p0,p1))  # ovVV = eris.ovVV[p0:p1]
2599                if p0 == i0:
2600                    wVoVv += ovVV.transpose(2,0,3,1).conj()
2601                wVoVv -= lib.einsum('mfBE,miaf->EiBa', ovVV, t2aa[p0:p1,i0:i1])
2602                ovVV = None
2603            self.wVoVv[:,i0:i1] = wVoVv
2604
2605        self.made_ee_imds = True
2606        log.timer('EOM-UCCSD EE intermediates', *cput0)
2607
def rand_mf(mol, seed=1):
    '''Return a UHF object for H2O/STO-3G whose orbitals are randomized.

    Testing helper.  NOTE(review): the ``mol`` argument is not used -- it is
    rebound to a freshly built water molecule (STO-3G basis, spin 0) before
    anything else happens.

    A UHF calculation is run with ``conv_tol=1e-14``.  Afterwards its
    orbital data are overwritten:

    * ``mo_occ``: the lowest 4 alpha and lowest 2 beta orbitals occupied;
    * ``mo_energy``: ``arange(nmo)`` plus uniform noise in [0, 0.3), with an
      additional +2 on every unoccupied orbital;
    * ``mo_coeff``: one random matrix per spin, Lowdin-orthonormalized
      against the AO overlap matrix.

    ``np.random.seed(seed)`` is called, which resets NumPy's global RNG.
    '''
    from pyscf import scf, gto, lo

    water = [
        [8 , (0. , 0.     , 0.)],
        [1 , (0. , -0.757 , 0.587)],
        [1 , (0. , 0.757  , 0.587)]]
    mol = gto.Mole()
    mol.atom = water
    mol.basis = 'sto-3g'
    mol.verbose = 0
    mol.spin = 0
    mol.build()

    np.random.seed(seed)

    mf = scf.UHF(mol).run(conv_tol=1e-14)
    nao = mol.nao_nr()

    occupancy = np.zeros((2, nao))
    occupancy[0, :4] = 1
    occupancy[1, :2] = 1
    mf.mo_occ = occupancy

    orbital_energies = np.arange(nao) + np.random.random((2, nao)) * .3
    # push virtual levels up so they sit above the occupied ones
    orbital_energies[occupancy == 0] += 2
    mf.mo_energy = orbital_energies

    random_mos = np.random.random((2, nao, nao))
    overlap = mf.get_ovlp()
    coeff = np.empty_like(random_mos)
    for spin in range(2):
        coeff[spin] = lo.orth.vec_lowdin(random_mos[spin], overlap)
    mf.mo_coeff = coeff
    return mf
2638
def rand_cc_t1_t2(mf, seed=1):
    '''Build a UCCSD object, its incore ERIs and random complex amplitudes.

    Testing helper.

    Args:
        mf : mean-field object passed to ``uccsd.UCCSD``.
        seed : int
            Passed to ``np.random.seed`` (resets NumPy's global RNG) before
            the amplitudes are drawn.

    Returns:
        (mycc, eris, t1, t2) where ``eris`` is ``uccsd._make_eris_incore(mycc)``,
        ``t1 = (t1a, t1b)`` and ``t2 = (t2aa, t2ab, t2bb)``.  Every amplitude
        element has real and imaginary parts drawn uniformly from [-0.5, 0.5);
        ``t2aa`` and ``t2bb`` are then antisymmetrized (without a 1/2 factor)
        in both the occupied pair and the virtual pair.
    '''
    from pyscf.cc import uccsd
    mycc = uccsd.UCCSD(mf)

    nocca, noccb = mycc.nocc
    nmoa, nmob = mycc.nmo
    nvira, nvirb = nmoa - nocca, nmob - noccb
    eris = uccsd._make_eris_incore(mycc)

    np.random.seed(seed)

    def rand_complex(shape):
        # Real part is drawn before the imaginary part to keep the RNG
        # sequence (and hence the amplitudes for a given seed) unchanged.
        real = np.random.random(shape)
        imag = np.random.random(shape)
        return real + imag*1j - .5 - .5j

    def antisymmetrize(t2):
        # t2[i,j,a,b] -> antisymmetric under i<->j and under a<->b
        t2 = t2 - t2.transpose(1, 0, 2, 3)
        return t2 - t2.transpose(0, 1, 3, 2)

    t1a = rand_complex((nocca, nvira))
    t1b = rand_complex((noccb, nvirb))
    t2aa = antisymmetrize(rand_complex((nocca, nocca, nvira, nvira)))
    t2ab = rand_complex((nocca, noccb, nvira, nvirb))
    t2bb = antisymmetrize(rand_complex((noccb, noccb, nvirb, nvirb)))

    t1 = (t1a, t1b)
    t2 = (t2aa, t2ab, t2bb)
    return mycc, eris, t1, t2
2678
def enforce_symm_2p_spin(r1, r2, orbspin, excitation):
    '''Antisymmetrize a spin-orbital EOM ``r2`` and zero its S_z-forbidden blocks.

    Args:
        r1 : returned unchanged (the same object).
        r2 : ndarray
            For ``'ip'`` shape (nocc, nocc, nvir), indexed r2[i,j,a];
            for ``'ea'`` shape (nocc, nvir, nvir), indexed r2[i,a,b].
        orbspin : array-like of length nocc+nvir; 0 marks an alpha and 1 a
            beta spin orbital (occupied orbitals first).
        excitation : ``'ip'`` or ``'ea'``.

    Returns:
        (r1, r2_new). ``r2_new`` is a new array; the input ``r2`` is not
        modified.

    Raises:
        NotImplementedError: if ``excitation`` is neither ``'ip'`` nor ``'ea'``.
    '''
    # Explicit check instead of ``assert``, so the behaviour does not depend
    # on whether Python runs with -O.
    if excitation not in ('ip', 'ea'):
        raise NotImplementedError(
            'excitation must be "ip" or "ea", got %r' % (excitation,))

    if excitation == 'ip':
        nocc, nvir = r2.shape[1:]
    else:
        nocc, nvir = r2.shape[:2]

    orbspin = np.asarray(orbspin)
    idxoa = np.where(orbspin[:nocc] == 0)[0]
    idxob = np.where(orbspin[:nocc] == 1)[0]
    idxva = np.where(orbspin[nocc:] == 0)[0]
    idxvb = np.where(orbspin[nocc:] == 1)[0]

    if excitation == 'ip':
        # r2[i,j,a] <- r2[i,j,a] - r2[j,i,a]  (new array; input untouched)
        r2 = r2 - r2.transpose(1, 0, 2)
        # Two same-spin holes plus a particle of the opposite spin cannot
        # describe a single ionization.
        r2[np.ix_(idxob, idxob, idxva)] = 0.0
        r2[np.ix_(idxoa, idxoa, idxvb)] = 0.0
    else:
        # r2[i,a,b] <- r2[i,a,b] - r2[i,b,a]
        r2 = r2 - r2.transpose(0, 2, 1)
        # A hole of one spin with two particles of the opposite spin cannot
        # describe a single electron attachment.
        r2[np.ix_(idxoa, idxvb, idxvb)] = 0.0
        r2[np.ix_(idxob, idxva, idxva)] = 0.0

    return r1, r2
2715
def enforce_symm_2p_spin_ip(r1, r2, orbspin):
    '''IP variant: delegate to ``enforce_symm_2p_spin`` with excitation='ip'.'''
    return enforce_symm_2p_spin(r1, r2, orbspin, excitation='ip')
2718
def enforce_symm_2p_spin_ea(r1, r2, orbspin):
    '''EA variant: delegate to ``enforce_symm_2p_spin`` with excitation='ea'.'''
    return enforce_symm_2p_spin(r1, r2, orbspin, excitation='ea')
2721
if __name__ == '__main__':
    # Self-test.
    # 1. On random UHF orbitals and random complex amplitudes, check the
    #    EOM-IP/EA vector <-> amplitude round trip (prints True).
    # 2. Compare fingerprints of matvec and get_diag with stored reference
    #    values (prints differences that should be ~0).
    # 3. Run a real UCCSD + EOM-IP-CCSD calculation on water and print the
    #    differences from reference energies.
    from pyscf import gto

    mol = gto.Mole()
    mol.atom = [
        [8 , (0. , 0.     , 0.)],
        [1 , (0. , -0.757 , 0.587)],
        [1 , (0. , 0.757  , 0.587)]]
    mol.basis = 'sto-3g'
    mol.verbose = 0
    mol.spin = 0
    mol.build()

    mf = rand_mf(mol)
    mycc, eris, t1, t2 = rand_cc_t1_t2(mf)
    mycc.t1 = t1
    mycc.t2 = t2

    nocca, noccb = mycc.nocc
    nmoa, nmob = mycc.nmo
    nvira, nvirb = nmoa - nocca, nmob - noccb
    # Spin-orbital dimensions used for the random r1/r2 vectors below.
    nocc = nocca + noccb
    nvir = nvira + nvirb

    def same_amplitudes(r1, r2, r1x, r2x, tol=1e-13):
        '''True if the 2 r1 blocks and 4 r2 blocks agree to within ``tol``.'''
        pairs = [(r1[k], r1x[k]) for k in range(2)]
        pairs += [(r2[k], r2x[k]) for k in range(4)]
        return all(abs(a - b).max() < tol for a, b in pairs)

    import pyscf.cc.addons
    from pyscf.cc import gccsd
    mygcc = pyscf.cc.addons.convert_to_gccsd(mycc)
    mygcc._ucc = mycc
    mygcc._ucc_eris = eris
    # Only the orbital spin labels of the GCCSD integrals are needed here.
    eris = gccsd._make_eris_incore(mygcc)
    orbspin = eris.orbspin

    ## EOM-IP
    myeom = EOMIP(mycc)
    imds = myeom.make_imds()

    np.random.seed(1)
    r1 = np.random.rand(nocc)*1j + np.random.rand(nocc) - 0.5 - 0.5*1j
    r2 = np.random.rand(nocc**2 * nvir)*1j + np.random.rand(nocc**2 * nvir) - 0.5 - 0.5*1j
    r2 = r2.reshape(nocc, nocc, nvir)
    r1, r2 = enforce_symm_2p_spin_ip(r1, r2, orbspin)
    r1, r2 = spin2spatial_ip(r1, r2, orbspin)

    vector = myeom.amplitudes_to_vector(r1, r2)
    r1x, r2x = myeom.vector_to_amplitudes(vector)
    print(same_amplitudes(r1, r2, r1x, r2x))
    Hvector = myeom.matvec(vector, imds=imds)
    print('ip', lib.finger(Hvector) - (21.67127462317093-19.068987454261908j))
    print('diag', lib.finger(myeom.get_diag()) - (-9.6676217223549763+9.325219825942975j))

    # EOM-EA
    myeom = EOMEA(mycc)
    imds = myeom.make_imds()

    np.random.seed(1)
    r1 = np.random.rand(nvir)*1j + np.random.rand(nvir) - 0.5 - 0.5*1j
    r2 = np.random.rand(nocc * nvir**2)*1j + np.random.rand(nocc * nvir**2) - 0.5 - 0.5*1j
    r2 = r2.reshape(nocc, nvir, nvir)
    r1, r2 = enforce_symm_2p_spin_ea(r1, r2, orbspin)
    r1, r2 = spin2spatial_ea(r1, r2, orbspin)

    vector = myeom.amplitudes_to_vector(r1, r2)
    r1x, r2x = myeom.vector_to_amplitudes(vector)
    print(same_amplitudes(r1, r2, r1x, r2x))
    Hvector = myeom.matvec(vector, imds=imds)
    print('ea', lib.finger(Hvector) - (6.5543877287461187-13.175055314063574j))
    print('diag', lib.finger(myeom.get_diag()) - (-57.353207240857785+1.4052857730841204j))

    # Real calculation: UCCSD correlation energy and six EOM-IP roots.
    mycc = uccsd.UCCSD(mol.UHF().run())
    ecc, t1, t2 = mycc.kernel()
    print(ecc - -0.04946750711013597)
    e,v = mycc.ipccsd(nroots=6)
    ip_refs = [0.3092874511803249, 0.3092874511803249,
               0.4011171373779585, 0.4011171373779585,
               0.6107409208314764, 0.6107409208314764]
    for k, ref in enumerate(ip_refs):
        print(e[k] - ref)