1#!/usr/bin/env python
2# Copyright 2014-2021 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Qiming Sun <osirpt.sun@gmail.com>
17#
18
19'''
20Unrestricted CISD
21'''
22
23import numpy
24from pyscf import lib
25from pyscf.lib import logger
26from pyscf.cc import uccsd
27from pyscf.cc import uccsd_rdm
28from pyscf.ci import cisd
29from pyscf.fci import cistring
30from pyscf.cc.ccsd import _unpack_4fold
31
def make_diagonal(myci, eris):
    '''Diagonal of the UCISD Hamiltonian, packed in the CISD vector layout.

    The reference, singles and doubles diagonal elements are assembled from
    the orbital energies (diagonals of ``eris.focka``/``eris.fockb``) and
    from Coulomb-like ``J[p,q] = eri[p,p,q,q]`` and exchange-like
    ``K[p,q] = eri[p,q,q,p]`` diagonals taken from the ERI blocks in
    ``eris``.

    The virtual-virtual J/K blocks are left at zero because the code that
    would fill them from the vvvv integrals is disabled (see the commented
    block below). The doubles diagonal therefore omits those terms.

    Args:
        myci : UCISD object (not referenced in this function).
        eris : integral container providing ``nocc``, ``focka``, ``fockb``
            and the oooo/oovv/ovvo-type ERI blocks for both spins.

    Returns:
        1D array in the layout produced by amplitudes_to_cisdvec. Element 0
        is the reference energy ``ehf`` computed below.
    '''
    nocca, noccb = eris.nocc
    nmoa = eris.focka.shape[0]
    # focka/fockb are presumably square, so shape[1] == shape[0] here
    nmob = eris.fockb.shape[1]
    nvira = nmoa - nocca
    nvirb = nmob - noccb
    # J diagonals: jdiag[p,q] = eri[p,p,q,q] for the aa, ab and bb spin pairs
    jdiag_aa = numpy.zeros((nmoa,nmoa))
    jdiag_ab = numpy.zeros((nmoa,nmob))
    jdiag_bb = numpy.zeros((nmob,nmob))
    jdiag_aa[:nocca,:nocca] = numpy.einsum('iijj->ij', eris.oooo)
    jdiag_aa[:nocca,nocca:] = numpy.einsum('iijj->ij', eris.oovv)
    jdiag_aa[nocca:,:nocca] = jdiag_aa[:nocca,nocca:].T
    jdiag_ab[:nocca,:noccb] = numpy.einsum('iijj->ij', eris.ooOO)
    jdiag_ab[:nocca,noccb:] = numpy.einsum('iijj->ij', eris.ooVV)
    # OOvv has the beta-occupied index first, hence the transposed output
    jdiag_ab[nocca:,:noccb] = numpy.einsum('iijj->ji', eris.OOvv)
    jdiag_bb[:noccb,:noccb] = numpy.einsum('iijj->ij', eris.OOOO)
    jdiag_bb[:noccb,noccb:] = numpy.einsum('iijj->ij', eris.OOVV)
    jdiag_bb[noccb:,:noccb] = jdiag_bb[:noccb,noccb:].T

    # K diagonals: kdiag[p,q] = eri[p,q,q,p] (same-spin pairs only)
    kdiag_aa = numpy.zeros((nmoa,nmoa))
    kdiag_bb = numpy.zeros((nmob,nmob))
    kdiag_aa[:nocca,:nocca] = numpy.einsum('ijji->ij', eris.oooo)
    kdiag_aa[:nocca,nocca:] = numpy.einsum('ijji->ij', eris.ovvo)
    kdiag_aa[nocca:,:nocca] = kdiag_aa[:nocca,nocca:].T
    kdiag_bb[:noccb,:noccb] = numpy.einsum('ijji->ij', eris.OOOO)
    kdiag_bb[:noccb,noccb:] = numpy.einsum('ijji->ij', eris.OVVO)
    kdiag_bb[noccb:,:noccb] = kdiag_bb[:noccb,noccb:].T

    # Disabled: virtual-virtual J/K diagonals from the packed vvvv integrals.
    # While this stays commented out, the vv blocks above remain zero.
#    if eris.vvvv is not None and eris.vvVV is not None and eris.VVVV is not None:
#        def diag_idx(n):
#            idx = numpy.arange(n)
#            return idx * (idx + 1) // 2 + idx
#        jdiag_aa[nocca:,nocca:] = eris.vvvv[diag_idx(nvira)[:,None],diag_idx(nvira)]
#        jdiag_ab[nocca:,noccb:] = eris.vvVV[diag_idx(nvira)[:,None],diag_idx(nvirb)]
#        jdiag_bb[noccb:,noccb:] = eris.VVVV[diag_idx(nvirb)[:,None],diag_idx(nvirb)]
#        kdiag_aa[nocca:,nocca:] = lib.unpack_tril(eris.vvvv.diagonal())
#        kdiag_bb[noccb:,noccb:] = lib.unpack_tril(eris.VVVV.diagonal())

    # Same-spin interactions carry J - K; opposite-spin only J
    jkdiag_aa = jdiag_aa - kdiag_aa
    jkdiag_bb = jdiag_bb - kdiag_bb

    mo_ea = eris.focka.diagonal()
    mo_eb = eris.fockb.diagonal()
    # Reference energy: occupied orbital energies minus the double-counted
    # occupied-occupied interactions
    ehf = (mo_ea[:nocca].sum() + mo_eb[:noccb].sum()
           - jkdiag_aa[:nocca,:nocca].sum() * .5
           - jdiag_ab[:nocca,:noccb].sum()
           - jkdiag_bb[:noccb,:noccb].sum() * .5)

    # Singles: dia[i,a] = e_a - e_i - (J-K)[i,a]
    dia_a = lib.direct_sum('a-i->ia', mo_ea[nocca:], mo_ea[:nocca])
    dia_a -= jkdiag_aa[:nocca,nocca:]
    dia_b = lib.direct_sum('a-i->ia', mo_eb[noccb:], mo_eb[:noccb])
    dia_b -= jkdiag_bb[:noccb,noccb:]
    e1diag_a = dia_a + ehf
    e1diag_b = dia_b + ehf

    # Same-spin doubles (ijab): two singles shifts plus the pairwise
    # corrections (i,j), (i,b), (j,a) and (a,b)
    e2diag_aa = lib.direct_sum('ia+jb->ijab', dia_a, dia_a)
    e2diag_aa += ehf
    e2diag_aa += jkdiag_aa[:nocca,:nocca].reshape(nocca,nocca,1,1)
    e2diag_aa -= jkdiag_aa[:nocca,nocca:].reshape(nocca,1,1,nvira)
    e2diag_aa -= jkdiag_aa[:nocca,nocca:].reshape(1,nocca,nvira,1)
    e2diag_aa += jkdiag_aa[nocca:,nocca:].reshape(1,1,nvira,nvira)

    # Opposite-spin doubles (iJaB): only Coulomb couplings between the spins
    e2diag_ab = lib.direct_sum('ia+jb->ijab', dia_a, dia_b)
    e2diag_ab += ehf
    e2diag_ab += jdiag_ab[:nocca,:noccb].reshape(nocca,noccb,1,1)
    e2diag_ab += jdiag_ab[nocca:,noccb:].reshape(1,1,nvira,nvirb)
    e2diag_ab -= jdiag_ab[:nocca,noccb:].reshape(nocca,1,1,nvirb)
    e2diag_ab -= jdiag_ab[nocca:,:noccb].T.reshape(1,noccb,nvira,1)

    e2diag_bb = lib.direct_sum('ia+jb->ijab', dia_b, dia_b)
    e2diag_bb += ehf
    e2diag_bb += jkdiag_bb[:noccb,:noccb].reshape(noccb,noccb,1,1)
    e2diag_bb -= jkdiag_bb[:noccb,noccb:].reshape(noccb,1,1,nvirb)
    e2diag_bb -= jkdiag_bb[:noccb,noccb:].reshape(1,noccb,nvirb,1)
    e2diag_bb += jkdiag_bb[noccb:,noccb:].reshape(1,1,nvirb,nvirb)

    return amplitudes_to_cisdvec(ehf, (e1diag_a, e1diag_b),
                                 (e2diag_aa, e2diag_ab, e2diag_bb))
110
def contract(myci, civec, eris):
    '''Contract a UCISD vector with the Hamiltonian held in ``eris``.

    Builds the reference (t0), singles (t1a, t1b) and doubles
    (t2aa, t2ab, t2bb) components from the Fock matrices and ERI blocks in
    ``eris``. They are packed with amplitudes_to_cisdvec, so the result has
    the same layout as ``civec``. The ``#:`` comments give the reference
    (unblocked, spin-orbital) expressions for each group of terms.

    NOTE(review): no ``c0 * E_ref`` term is added to t0 here. Presumably the
    caller accounts for the reference energy; confirm against the CISD
    driver.

    Args:
        myci : UCISD object. Its ``_add_vvvv`` method supplies the
            particle-particle (vvvv) contribution.
        civec : 1D CISD vector.
        eris : integral container (``nocc``, ``focka``/``fockb``, ERI
            blocks and the get_ovvv/get_OVVV/get_ovVV/get_OVvv accessors).

    Returns:
        1D array in the same layout as ``civec``.
    '''
    nocca, noccb = eris.nocc
    nmoa = eris.focka.shape[0]
    nmob = eris.fockb.shape[0]
    nvira = nmoa - nocca
    nvirb = nmob - noccb
    c0, (c1a,c1b), (c2aa,c2ab,c2bb) = \
            cisdvec_to_amplitudes(civec, (nmoa,nmob), (nocca,noccb))

    #:t2 += 0.5*einsum('ijef,abef->ijab', c2, eris.vvvv)
    #:eris_vvvv = ao2mo.restore(1, eris.vvvv, nvira)
    #:eris_vvVV = ucisd_slow._restore(eris.vvVV, nvira, nvirb)
    #:eris_VVVV = ao2mo.restore(1, eris.VVVV, nvirb)
    #:t2aa += lib.einsum('ijef,aebf->ijab', c2aa, eris_vvvv)
    #:t2bb += lib.einsum('ijef,aebf->ijab', c2bb, eris_VVVV)
    #:t2ab += lib.einsum('iJeF,aeBF->iJaB', c2ab, eris_vvVV)
    t2aa, t2ab, t2bb = myci._add_vvvv(None, (c2aa,c2ab,c2bb), eris)
    # t2aa/t2bb are antisymmetrized in both index pairs further below, which
    # multiplies a contribution that is already antisymmetric in both pairs
    # by 4. Such terms are pre-scaled by .25. Terms antisymmetric in only one
    # pair (e.g. the fvv/foo terms) are pre-scaled by .5.
    t2aa *= .25
    t2bb *= .25

    fooa = eris.focka[:nocca,:nocca]
    foob = eris.fockb[:noccb,:noccb]
    fova = eris.focka[:nocca,nocca:]
    fovb = eris.fockb[:noccb,noccb:]
    fvoa = eris.focka[nocca:,:nocca]
    fvob = eris.fockb[noccb:,:noccb]
    fvva = eris.focka[nocca:,nocca:]
    fvvb = eris.fockb[noccb:,noccb:]

    t0 = 0
    t1a = 0
    t1b = 0
    eris_oovv = _cp(eris.oovv)
    eris_ooVV = _cp(eris.ooVV)
    eris_OOvv = _cp(eris.OOvv)
    eris_OOVV = _cp(eris.OOVV)
    eris_ovov = _cp(eris.ovov)
    eris_ovOV = _cp(eris.ovOV)
    eris_OVOV = _cp(eris.OVOV)
    #:t2 += eris.oovv * c0
    t2aa += .25 * c0 * eris_ovov.conj().transpose(0,2,1,3)
    t2aa -= .25 * c0 * eris_ovov.conj().transpose(0,2,3,1)
    t2bb += .25 * c0 * eris_OVOV.conj().transpose(0,2,1,3)
    t2bb -= .25 * c0 * eris_OVOV.conj().transpose(0,2,3,1)
    t2ab += c0 * eris_ovOV.conj().transpose(0,2,1,3)
    #:t0 += numpy.einsum('ijab,ijab', eris.oovv, c2) * .25
    t0 += numpy.einsum('iajb,ijab', eris_ovov, c2aa) * .25
    t0 -= numpy.einsum('jaib,ijab', eris_ovov, c2aa) * .25
    t0 += numpy.einsum('iajb,ijab', eris_OVOV, c2bb) * .25
    t0 -= numpy.einsum('jaib,ijab', eris_OVOV, c2bb) * .25
    t0 += numpy.einsum('iajb,ijab', eris_ovOV, c2ab)
    # The ovov-type blocks are no longer needed; drop the references
    eris_ovov = eris_ovOV = eris_OVOV = None

    #:tmp = einsum('imae,mbej->ijab', c2, eris.ovvo)
    #:tmp = tmp - tmp.transpose(0,1,3,2)
    #:t2 += tmp - tmp.transpose(1,0,2,3)
    eris_ovvo = _cp(eris.ovvo)
    eris_ovVO = _cp(eris.ovVO)
    eris_OVVO = _cp(eris.OVVO)
    # Same-spin ovvo minus the exchange counterpart built from oovv
    ovvo = eris_ovvo - eris_oovv.transpose(0,3,2,1)
    OVVO = eris_OVVO - eris_OOVV.transpose(0,3,2,1)
    t2aa += lib.einsum('imae,jbem->ijab', c2aa, ovvo)
    t2aa += lib.einsum('iMaE,jbEM->ijab', c2ab, eris_ovVO)
    t2bb += lib.einsum('imae,jbem->ijab', c2bb, OVVO)
    t2bb += lib.einsum('mIeA,meBJ->IJAB', c2ab, eris_ovVO)
    t2ab += lib.einsum('imae,meBJ->iJaB', c2aa, eris_ovVO)
    t2ab += lib.einsum('iMaE,MEBJ->iJaB', c2ab, OVVO)
    t2ab += lib.einsum('IMAE,jbEM->jIbA', c2bb, eris_ovVO)
    t2ab += lib.einsum('mIeA,jbem->jIbA', c2ab, ovvo)
    t2ab -= lib.einsum('iMeA,JMeb->iJbA', c2ab, eris_OOvv)
    t2ab -= lib.einsum('mIaE,jmEB->jIaB', c2ab, eris_ooVV)

    #:t1 += einsum('nf,nafi->ia', c1, eris.ovvo)
    t1a += numpy.einsum('nf,nfai->ia', c1a, eris_ovvo)
    t1a -= numpy.einsum('nf,nifa->ia', c1a, eris_oovv)
    t1b += numpy.einsum('nf,nfai->ia', c1b, eris_OVVO)
    t1b -= numpy.einsum('nf,nifa->ia', c1b, eris_OOVV)
    t1b += numpy.einsum('nf,nfai->ia', c1a, eris_ovVO)
    t1a += numpy.einsum('nf,iafn->ia', c1b, eris_ovVO)

    #:t1 -= 0.5*einsum('mnae,mnie->ia', c2, eris.ooov)
    eris_ovoo = _cp(eris.ovoo)
    eris_OVOO = _cp(eris.OVOO)
    eris_OVoo = _cp(eris.OVoo)
    eris_ovOO = _cp(eris.ovOO)
    t1a += lib.einsum('mnae,meni->ia', c2aa, eris_ovoo)
    t1b += lib.einsum('mnae,meni->ia', c2bb, eris_OVOO)
    t1a -= lib.einsum('nMaE,MEni->ia', c2ab, eris_OVoo)
    t1b -= lib.einsum('mNeA,meNI->IA', c2ab, eris_ovOO)
    #:tmp = einsum('ma,mbij->ijab', c1, eris.ovoo)
    #:t2 -= tmp - tmp.transpose(0,1,3,2)
    t2aa -= lib.einsum('ma,jbmi->jiba', c1a, eris_ovoo)
    t2bb -= lib.einsum('ma,jbmi->jiba', c1b, eris_OVOO)
    t2ab -= lib.einsum('ma,JBmi->iJaB', c1a, eris_OVoo)
    t2ab -= lib.einsum('MA,ibMJ->iJbA', c1b, eris_ovOO)

    #:#:t1 -= 0.5*einsum('imef,maef->ia', c2, eris.ovvv)
    #:eris_ovvv = _cp(eris.ovvv)
    #:eris_OVVV = _cp(eris.OVVV)
    #:eris_ovVV = _cp(eris.ovVV)
    #:eris_OVvv = _cp(eris.OVvv)
    #:t1a += lib.einsum('mief,mefa->ia', c2aa, eris_ovvv)
    #:t1b += lib.einsum('MIEF,MEFA->IA', c2bb, eris_OVVV)
    #:t1a += lib.einsum('iMfE,MEaf->ia', c2ab, eris_OVvv)
    #:t1b += lib.einsum('mIeF,meAF->IA', c2ab, eris_ovVV)
    #:#:tmp = einsum('ie,jeba->ijab', c1, numpy.asarray(eris.ovvv).conj())
    #:#:t2 += tmp - tmp.transpose(1,0,2,3)
    #:t2aa += lib.einsum('ie,mbae->imab', c1a, eris_ovvv)
    #:t2bb += lib.einsum('ie,mbae->imab', c1b, eris_OVVV)
    #:t2ab += lib.einsum('ie,MBae->iMaB', c1a, eris_OVvv)
    #:t2ab += lib.einsum('IE,maBE->mIaB', c1b, eris_ovVV)
    # The ovvv-type integrals are fetched in chunks of their second
    # (virtual) index. blksize is sized from the free memory (MB, 8 bytes
    # per element) for a (nocc, blk, nvir, nvir) chunk. The extra factor 2
    # presumably leaves room for a temporary of the same size.
    mem_now = lib.current_memory()[0]
    max_memory = max(0, lib.param.MAX_MEMORY - mem_now)
    if nvira > 0 and nocca > 0:
        blksize = max(int(max_memory*1e6/8/(nvira**2*nocca*2)), 2)
        for p0,p1 in lib.prange(0, nvira, blksize):
            ovvv = eris.get_ovvv(slice(None), slice(p0,p1))
            t1a += lib.einsum('mief,mefa->ia', c2aa[:,:,p0:p1], ovvv)
            # chunked index lands in the third slot of t2aa; the later
            # antisymmetrization makes this equivalent to the 'imab' form
            t2aa[:,:,p0:p1] += lib.einsum('mbae,ie->miba', ovvv, c1a)
            ovvv = None

    if nvirb > 0 and noccb > 0:
        blksize = max(int(max_memory*1e6/8/(nvirb**2*noccb*2)), 2)
        for p0,p1 in lib.prange(0, nvirb, blksize):
            OVVV = eris.get_OVVV(slice(None), slice(p0,p1))
            t1b += lib.einsum('MIEF,MEFA->IA', c2bb[:,:,p0:p1], OVVV)
            t2bb[:,:,p0:p1] += lib.einsum('mbae,ie->miba', OVVV, c1b)
            OVVV = None

    # ovVV is chunked over its alpha-virtual index (range nvira)
    if nvirb > 0 and nocca > 0:
        blksize = max(int(max_memory*1e6/8/(nvirb**2*nocca*2)), 2)
        for p0,p1 in lib.prange(0, nvira, blksize):
            ovVV = eris.get_ovVV(slice(None), slice(p0,p1))
            t1b += lib.einsum('mIeF,meAF->IA', c2ab[:,:,p0:p1], ovVV)
            t2ab[:,:,p0:p1] += lib.einsum('maBE,IE->mIaB', ovVV, c1b)
            ovVV = None

    # OVvv is chunked over its beta-virtual index (range nvirb)
    if nvira > 0 and noccb > 0:
        blksize = max(int(max_memory*1e6/8/(nvira**2*noccb*2)), 2)
        for p0,p1 in lib.prange(0, nvirb, blksize):
            OVvv = eris.get_OVvv(slice(None), slice(p0,p1))
            t1a += lib.einsum('iMfE,MEaf->ia', c2ab[:,:,:,p0:p1], OVvv)
            t2ab[:,:,:,p0:p1] += lib.einsum('MBae,ie->iMaB', OVvv, c1a)
            OVvv = None

    #:t1  = einsum('ie,ae->ia', c1, fvv)
    t1a += lib.einsum('ie,ae->ia', c1a, fvva)
    t1b += lib.einsum('ie,ae->ia', c1b, fvvb)
    #:t1 -= einsum('ma,mi->ia', c1, foo)
    t1a -= lib.einsum('ma,mi->ia', c1a, fooa)
    t1b -= lib.einsum('ma,mi->ia', c1b, foob)
    #:t1 += einsum('imae,me->ia', c2, fov)
    t1a += numpy.einsum('imae,me->ia', c2aa, fova)
    t1a += numpy.einsum('imae,me->ia', c2ab, fovb)
    t1b += numpy.einsum('imae,me->ia', c2bb, fovb)
    t1b += numpy.einsum('miea,me->ia', c2ab, fova)

    #:tmp = einsum('ijae,be->ijab', c2, fvv)
    #:t2  = tmp - tmp.transpose(0,1,3,2)
    t2aa += lib.einsum('ijae,be->ijab', c2aa, fvva*.5)
    t2bb += lib.einsum('ijae,be->ijab', c2bb, fvvb*.5)
    t2ab += lib.einsum('iJaE,BE->iJaB', c2ab, fvvb)
    t2ab += lib.einsum('iJeA,be->iJbA', c2ab, fvva)
    #:tmp = einsum('imab,mj->ijab', c2, foo)
    #:t2 -= tmp - tmp.transpose(1,0,2,3)
    t2aa -= lib.einsum('imab,mj->ijab', c2aa, fooa*.5)
    t2bb -= lib.einsum('imab,mj->ijab', c2bb, foob*.5)
    t2ab -= lib.einsum('iMaB,MJ->iJaB', c2ab, foob)
    t2ab -= lib.einsum('mIaB,mj->jIaB', c2ab, fooa)

    #:tmp = numpy.einsum('ia,bj->ijab', c1, fvo)
    #:tmp = tmp - tmp.transpose(0,1,3,2)
    #:t2 += tmp - tmp.transpose(1,0,2,3)
    t2aa += numpy.einsum('ia,bj->ijab', c1a, fvoa)
    t2bb += numpy.einsum('ia,bj->ijab', c1b, fvob)
    t2ab += numpy.einsum('ia,bj->ijab', c1a, fvob)
    t2ab += numpy.einsum('ia,bj->jiba', c1b, fvoa)

    # Antisymmetrize the same-spin doubles in (a,b) and then in (i,j).
    # The mixed-spin block t2ab is not antisymmetrized.
    t2aa = t2aa - t2aa.transpose(0,1,3,2)
    t2aa = t2aa - t2aa.transpose(1,0,2,3)
    t2bb = t2bb - t2bb.transpose(0,1,3,2)
    t2bb = t2bb - t2bb.transpose(1,0,2,3)

    # The hole-hole ladder terms are added after the antisymmetrization
    #:t2 += 0.5*einsum('mnab,mnij->ijab', c2, eris.oooo)
    eris_oooo = _cp(eris.oooo)
    eris_OOOO = _cp(eris.OOOO)
    eris_ooOO = _cp(eris.ooOO)
    t2aa += lib.einsum('mnab,minj->ijab', c2aa, eris_oooo)
    t2bb += lib.einsum('mnab,minj->ijab', c2bb, eris_OOOO)
    t2ab += lib.einsum('mNaB,miNJ->iJaB', c2ab, eris_ooOO)

    #:t1 += fov.conj() * c0
    t1a += fova.conj() * c0
    t1b += fovb.conj() * c0
    #:t0  = numpy.einsum('ia,ia', fov, c1)
    t0 += numpy.einsum('ia,ia', fova, c1a)
    t0 += numpy.einsum('ia,ia', fovb, c1b)
    return amplitudes_to_cisdvec(t0, (t1a,t1b), (t2aa,t2ab,t2bb))
309
def amplitudes_to_cisdvec(c0, c1, c2):
    '''Pack UCISD amplitudes into one flat CI vector.

    Layout, in order:
        c0
        c1a   (nocca*nvira)
        c1b   (noccb*nvirb)
        c2ab  (nocca*noccb*nvira*nvirb)
        c2aa  unique elements i>j, a>b only
        c2bb  unique elements i>j, a>b only
    This is the layout read back by cisdvec_to_amplitudes.

    Args:
        c0 : scalar reference coefficient.
        c1 : (c1a, c1b) singles, shapes (nocca, nvira) and (noccb, nvirb).
        c2 : (c2aa, c2ab, c2bb) doubles, shapes (nocca, nocca, nvira, nvira),
            (nocca, noccb, nvira, nvirb) and (noccb, noccb, nvirb, nvirb).

    Returns:
        1D array. Its dtype is the common dtype of all inputs, so e.g.
        complex singles are not silently truncated when c2ab is real.
    '''
    c1a, c1b = c1
    c2aa, c2ab, c2bb = c2
    nocca, nvira = c1a.shape
    noccb, nvirb = c1b.shape

    def trilidx(n):
        # Row-major flat indices of the strictly lower triangle of an n x n matrix
        idx = numpy.tril_indices(n, -1)
        return idx[0] * n + idx[1]

    ooidxa = trilidx(nocca)
    vvidxa = trilidx(nvira)
    ooidxb = trilidx(noccb)
    vvidxb = trilidx(nvirb)
    size = (1, nocca*nvira, noccb*nvirb, nocca*noccb*nvira*nvirb,
            len(ooidxa)*len(vvidxa), len(ooidxb)*len(vvidxb))
    loc = numpy.cumsum(size)
    # Previously the dtype came from c2ab alone, which dropped imaginary
    # parts of any other complex block.
    dtype = numpy.result_type(c0, c1a, c1b, c2aa, c2ab, c2bb)
    civec = numpy.empty(loc[-1], dtype=dtype)
    civec[0] = c0
    civec[loc[0]:loc[1]] = c1a.ravel()
    civec[loc[1]:loc[2]] = c1b.ravel()
    civec[loc[2]:loc[3]] = c2ab.ravel()
    # Plain advanced indexing gives the same row-major (oo, vv) layout for any
    # dtype. NOTE(review): lib.take_2d(..., out=...) only has fast paths for
    # float64/complex128 and is not guaranteed to fill ``out`` otherwise.
    civec[loc[3]:loc[4]] = \
            c2aa.reshape(nocca**2, nvira**2)[ooidxa[:,None], vvidxa].ravel()
    civec[loc[4]:loc[5]] = \
            c2bb.reshape(noccb**2, nvirb**2)[ooidxb[:,None], vvidxb].ravel()
    return civec
333
def cisdvec_to_amplitudes(civec, nmo, nocc):
    '''Split a flat UCISD vector into its amplitude blocks.

    Inverse of amplitudes_to_cisdvec. The vector is laid out as
    [c0 | c1a | c1b | c2ab | packed c2aa | packed c2bb]. The packed
    same-spin doubles hold only the i>j, a>b elements and are expanded by
    _unpack_4fold, presumably into full antisymmetric
    (nocc, nocc, nvir, nvir) arrays.

    Args:
        civec : 1D CISD vector.
        nmo : (nmoa, nmob) number of orbitals per spin.
        nocc : (nocca, noccb) number of occupied orbitals per spin.

    Returns:
        c0, (c1a, c1b), (c2aa, c2ab, c2bb). c1a, c1b and c2ab are reshaped
        slices of ``civec``.
    '''
    norba, norbb = nmo
    nocca, noccb = nocc
    nvira, nvirb = norba - nocca, norbb - noccb

    def npairs(n):
        # number of index pairs p > q among n indices
        return n * (n - 1) // 2

    block_sizes = (1,
                   nocca * nvira,
                   noccb * nvirb,
                   nocca * noccb * nvira * nvirb,
                   npairs(nocca) * npairs(nvira),
                   npairs(noccb) * npairs(nvirb))
    c0_end, c1a_end, c1b_end, c2ab_end, c2aa_end, c2bb_end = \
            numpy.cumsum(block_sizes)

    c0 = civec[0]
    c1a = civec[c0_end:c1a_end].reshape(nocca, nvira)
    c1b = civec[c1a_end:c1b_end].reshape(noccb, nvirb)
    c2ab = civec[c1b_end:c2ab_end].reshape(nocca, noccb, nvira, nvirb)
    c2aa = _unpack_4fold(civec[c2ab_end:c2aa_end], nocca, nvira)
    c2bb = _unpack_4fold(civec[c2aa_end:c2bb_end], noccb, nvirb)
    return c0, (c1a, c1b), (c2aa, c2ab, c2bb)
353
def to_fcivec(cisdvec, norb, nelec, frozen=None):
    '''Convert CISD coefficients to FCI coefficients.

    Args:
        cisdvec : 1D UCISD vector (layout of amplitudes_to_cisdvec) over the
            active (non-frozen) orbitals.
        norb : int, total number of orbitals, including frozen ones.
        nelec : int or (neleca, nelecb). An int is split as
            nelecb = nelec//2, neleca = nelec - nelecb.
        frozen : None; an int (the lowest ``frozen`` orbitals of both spins
            are frozen); or a pair of index lists (alpha, beta).

    Returns:
        2D array (na, nb) of FCI coefficients. Without frozen orbitals the
        strings span the active orbitals only. With frozen orbitals the
        vector is embedded in the full ``norb``-orbital string space: frozen
        orbitals with index < nelec must be occupied, those above must be
        empty. A sign flip is applied per string, determined by the parity of
        the number of active electrons below each frozen occupied orbital.

    NOTE(review): fcivec is allocated as float64, so the imaginary part of a
    complex cisdvec would be dropped.
    '''
    if isinstance(nelec, (int, numpy.number)):
        nelecb = nelec//2
        neleca = nelec - nelecb
    else:
        neleca, nelecb = nelec

    frozena_mask = numpy.zeros(norb, dtype=bool)
    frozenb_mask = numpy.zeros(norb, dtype=bool)
    if frozen is None:
        nfroza = nfrozb = 0
    elif isinstance(frozen, (int, numpy.integer)):
        nfroza = nfrozb = frozen
        frozena_mask[:frozen] = True
        frozenb_mask[:frozen] = True
    else:
        nfroza = len(frozen[0])
        nfrozb = len(frozen[1])
        frozena_mask[frozen[0]] = True
        frozenb_mask[frozen[1]] = True

#    if nfroza != nfrozb:
#        raise NotImplementedError
    # Active occupied counts: electrons in non-frozen orbitals below nelec
    nocca = numpy.count_nonzero(~frozena_mask[:neleca])
    noccb = numpy.count_nonzero(~frozenb_mask[:nelecb])
    nmo = nmoa, nmob = norb - nfroza, norb - nfrozb
    nocc = nocca, noccb
    nvira, nvirb = nmoa - nocca, nmob - noccb

    c0, c1, c2 = cisdvec_to_amplitudes(cisdvec, nmo, nocc)
    c1a, c1b = c1
    c2aa, c2ab, c2bb = c2
    t1addra, t1signa = cisd.tn_addrs_signs(nmoa, nocca, 1)
    t1addrb, t1signb = cisd.tn_addrs_signs(nmob, noccb, 1)

    # Fill the (alpha string, beta string) matrix over the active space
    na = cistring.num_strings(nmoa, nocca)
    nb = cistring.num_strings(nmob, noccb)
    fcivec = numpy.zeros((na,nb))
    fcivec[0,0] = c0
    fcivec[t1addra,0] = c1a.ravel() * t1signa
    fcivec[0,t1addrb] = c1b.ravel() * t1signb
    # Mixed-spin doubles: alpha single x beta single
    c2ab = c2ab.transpose(0,2,1,3).reshape(nocca*nvira,-1)
    c2ab = numpy.einsum('i,j,ij->ij', t1signa, t1signb, c2ab)
    fcivec[t1addra[:,None],t1addrb] = c2ab

    # Same-spin doubles need at least two occupied and two virtual orbitals
    if nocca > 1 and nvira > 1:
        ooidx = numpy.tril_indices(nocca, -1)
        vvidx = numpy.tril_indices(nvira, -1)
        c2aa = c2aa[ooidx][:,vvidx[0],vvidx[1]]
        t2addra, t2signa = cisd.tn_addrs_signs(nmoa, nocca, 2)
        fcivec[t2addra,0] = c2aa.ravel() * t2signa
    if noccb > 1 and nvirb > 1:
        ooidx = numpy.tril_indices(noccb, -1)
        vvidx = numpy.tril_indices(nvirb, -1)
        c2bb = c2bb[ooidx][:,vvidx[0],vvidx[1]]
        t2addrb, t2signb = cisd.tn_addrs_signs(nmob, noccb, 2)
        fcivec[0,t2addrb] = c2bb.ravel() * t2signb

    if nfroza == nfrozb == 0:
        return fcivec

    # The embedding below manipulates strings with ``1 << i`` bit masks;
    # the limit presumably reflects the 64-bit integer string representation.
    assert(norb < 63)

    strsa = cistring.gen_strings4orblist(range(norb), neleca)
    strsb = cistring.gen_strings4orblist(range(norb), nelecb)
    na = len(strsa)
    nb = len(strsb)
    # count_*: active electrons seen so far; parity_*: accumulated sign flip;
    # core_*_mask: strings whose frozen orbitals have the required occupation
    count_a = numpy.zeros(na, dtype=int)
    count_b = numpy.zeros(nb, dtype=int)
    parity_a = numpy.zeros(na, dtype=bool)
    parity_b = numpy.zeros(nb, dtype=bool)
    core_a_mask = numpy.ones(na, dtype=bool)
    core_b_mask = numpy.ones(nb, dtype=bool)

    for i in range(norb):
        if frozena_mask[i]:
            if i < neleca:
                # frozen occupied: must be set; flips sign if an odd number
                # of active electrons precede it
                core_a_mask &= (strsa & (1 <<i )) != 0
                parity_a ^= (count_a & 1) == 1
            else:
                # frozen virtual: must be empty
                core_a_mask &= (strsa & (1 << i)) == 0
        else:
            count_a += (strsa & (1 << i)) != 0

        if frozenb_mask[i]:
            if i < nelecb:
                core_b_mask &= (strsb & (1 <<i )) != 0
                parity_b ^= (count_b & 1) == 1
            else:
                core_b_mask &= (strsb & (1 << i)) == 0
        else:
            count_b += (strsb & (1 << i)) != 0

    # Strings with correct frozen occupations and nocc active electrons map
    # one-to-one onto the active-space strings filled above
    sub_strsa = strsa[core_a_mask & (count_a == nocca)]
    sub_strsb = strsb[core_b_mask & (count_b == noccb)]
    addrsa = cistring.strs2addr(norb, neleca, sub_strsa)
    addrsb = cistring.strs2addr(norb, nelecb, sub_strsb)
    fcivec1 = numpy.zeros((na,nb))
    fcivec1[addrsa[:,None],addrsb] = fcivec
    fcivec1[parity_a,:] *= -1
    fcivec1[:,parity_b] *= -1
    return fcivec1
457
def from_fcivec(ci0, norb, nelec, frozen=None):
    '''Extract CISD coefficients from FCI coefficients.

    Args:
        ci0 : FCI coefficients. They are reshaped to (na, nb), the numbers
            of alpha/beta strings for ``norb`` orbitals.
        norb : int, number of orbitals (the same for both spins).
        nelec : int or (neleca, nelecb). An int is split as
            nelecb = nelec//2, neleca = nelec - nelecb.
        frozen : only None or 0 is supported.

    Returns:
        1D UCISD vector in the layout of amplitudes_to_cisdvec.

    Raises:
        NotImplementedError: if frozen orbitals are requested.
    '''
    # Compare with 0 only for scalars: ``frozen == 0`` on an index array is
    # ambiguous in a boolean context.
    if frozen is not None and not (numpy.isscalar(frozen) and frozen == 0):
        raise NotImplementedError

    if isinstance(nelec, (int, numpy.number)):
        nelecb = nelec//2
        neleca = nelec - nelecb
    else:
        neleca, nelecb = nelec

    norba = norbb = norb
    nocca, noccb = neleca, nelecb
    nvira = norba - nocca
    nvirb = norbb - noccb

    na = cistring.num_strings(norba, nocca)
    nb = cistring.num_strings(norbb, noccb)
    ci0 = ci0.reshape(na,nb)
    c0 = ci0[0,0]

    # Excitation blocks that cannot exist are zero-sized. Build them directly
    # rather than asking tn_addrs_signs for excitations that do not exist.
    # The same-spin doubles guards mirror those in to_fcivec.
    has_singles_a = nocca > 0 and nvira > 0
    has_singles_b = noccb > 0 and nvirb > 0

    if has_singles_a:
        t1addra, t1signa = cisd.tn_addrs_signs(norba, nocca, 1)
        c1a = (ci0[t1addra,0] * t1signa).reshape(nocca,nvira)
    else:
        c1a = numpy.zeros((nocca,nvira), dtype=ci0.dtype)
    if has_singles_b:
        t1addrb, t1signb = cisd.tn_addrs_signs(norbb, noccb, 1)
        c1b = (ci0[0,t1addrb] * t1signb).reshape(noccb,nvirb)
    else:
        c1b = numpy.zeros((noccb,nvirb), dtype=ci0.dtype)

    # Mixed-spin doubles are products of an alpha and a beta single excitation
    if has_singles_a and has_singles_b:
        c2ab = numpy.einsum('i,j,ij->ij', t1signa, t1signb,
                            ci0[t1addra[:,None],t1addrb])
        c2ab = c2ab.reshape(nocca,nvira,noccb,nvirb).transpose(0,2,1,3)
    else:
        c2ab = numpy.zeros((nocca,noccb,nvira,nvirb), dtype=ci0.dtype)

    if nocca > 1 and nvira > 1:
        t2addra, t2signa = cisd.tn_addrs_signs(norba, nocca, 2)
        c2aa = (ci0[t2addra,0] * t2signa).reshape(nocca*(nocca-1)//2, nvira*(nvira-1)//2)
        c2aa = _unpack_4fold(c2aa, nocca, nvira)
    else:
        c2aa = numpy.zeros((nocca,nocca,nvira,nvira), dtype=ci0.dtype)
    if noccb > 1 and nvirb > 1:
        t2addrb, t2signb = cisd.tn_addrs_signs(norbb, noccb, 2)
        c2bb = (ci0[0,t2addrb] * t2signb).reshape(noccb*(noccb-1)//2, nvirb*(nvirb-1)//2)
        c2bb = _unpack_4fold(c2bb, noccb, nvirb)
    else:
        c2bb = numpy.zeros((noccb,noccb,nvirb,nvirb), dtype=ci0.dtype)

    return amplitudes_to_cisdvec(c0, (c1a,c1b), (c2aa,c2ab,c2bb))
493
def overlap(cibra, ciket, nmo, nocc, s=None):
    '''Overlap between two UCISD wavefunctions.

    Args:
        cibra, ciket : 1D UCISD vectors (layout of amplitudes_to_cisdvec).
        nmo : int or (nmoa, nmob), number of orbitals per spin.
        nocc : (nocca, noccb), number of occupied orbitals per spin.
        s : a list of 2D arrays
            The overlap matrix of non-orthogonal one-particle basis.
            s[0] is used for the alpha orbitals and s[1] for the beta
            orbitals. When None, both wavefunctions are taken to share the
            same orthonormal orbitals.

    Returns:
        The overlap <bra|ket>. Note that the bra coefficients are not complex
        conjugated.
    '''
    if s is None:
        # Each vector entry multiplies a distinct determinant (c2aa/c2bb keep
        # only the unique i>j, a>b pairs), so with shared orthonormal orbitals
        # the overlap is a plain dot product. numpy.dot takes only
        # (a, b[, out]); passing nmo and nocc positionally raised TypeError.
        return numpy.dot(cibra, ciket)

    if isinstance(nmo, (int, numpy.integer)):
        nmoa = nmob = nmo
    else:
        nmoa, nmob = nmo
    nocca, noccb = nocc
    nvira, nvirb = nmoa - nocca, nmob - noccb

    bra0, bra1, bra2 = cisdvec_to_amplitudes(cibra, (nmoa,nmob), nocc)
    ket0, ket1, ket2 = cisdvec_to_amplitudes(ciket, (nmoa,nmob), nocc)

    # Unique (i>j, a>b) same-spin doubles as 2D (pair, pair) matrices
    ooidx = numpy.tril_indices(nocca, -1)
    vvidx = numpy.tril_indices(nvira, -1)
    bra2aa = lib.take_2d(bra2[0].reshape(nocca**2,nvira**2),
                         ooidx[0]*nocca+ooidx[1], vvidx[0]*nvira+vvidx[1])
    ket2aa = lib.take_2d(ket2[0].reshape(nocca**2,nvira**2),
                         ooidx[0]*nocca+ooidx[1], vvidx[0]*nvira+vvidx[1])

    ooidx = numpy.tril_indices(noccb, -1)
    vvidx = numpy.tril_indices(nvirb, -1)
    bra2bb = lib.take_2d(bra2[2].reshape(noccb**2,nvirb**2),
                         ooidx[0]*noccb+ooidx[1], vvidx[0]*nvirb+vvidx[1])
    ket2bb = lib.take_2d(ket2[2].reshape(noccb**2,nvirb**2),
                         ooidx[0]*noccb+ooidx[1], vvidx[0]*nvirb+vvidx[1])

    # Occupied-orbital lists for the reference, single and double
    # excitations of each spin, starting from the reference list 0..nocc-1
    nova = nocca * nvira
    novb = noccb * nvirb
    occlist0a = numpy.arange(nocca).reshape(1,nocca)
    occlist0b = numpy.arange(noccb).reshape(1,noccb)
    occlistsa = numpy.repeat(occlist0a, 1+nova+bra2aa.size, axis=0)
    occlistsb = numpy.repeat(occlist0b, 1+novb+bra2bb.size, axis=0)
    occlist0a = occlistsa[:1]
    occlist1a = occlistsa[1:1+nova]
    occlist2a = occlistsa[1+nova:]
    occlist0b = occlistsb[:1]
    occlist1b = occlistsb[1:1+novb]
    occlist2b = occlistsb[1+novb:]

    # Singles: replace occupied i with virtual a (views write into occlists*)
    ia = 0
    for i in range(nocca):
        for a in range(nocca, nmoa):
            occlist1a[ia,i] = a
            ia += 1
    ia = 0
    for i in range(noccb):
        for a in range(noccb, nmob):
            occlist1b[ia,i] = a
            ia += 1

    # Doubles: i>j, a>b, ordered to match the packed bra2aa/bra2bb layout
    ia = 0
    for i in range(nocca):
        for j in range(i):
            for a in range(nocca, nmoa):
                for b in range(nocca, a):
                    occlist2a[ia,i] = a
                    occlist2a[ia,j] = b
                    ia += 1
    ia = 0
    for i in range(noccb):
        for j in range(i):
            for a in range(noccb, nmob):
                for b in range(noccb, a):
                    occlist2b[ia,i] = a
                    occlist2b[ia,j] = b
                    ia += 1

    # Determinant overlaps between every pair of bra/ket configurations
    na = len(occlistsa)
    trans_a = numpy.empty((na,na))
    for i, idx in enumerate(occlistsa):
        s_sub = s[0][idx].T.copy()
        minors = s_sub[occlistsa]
        trans_a[i,:] = numpy.linalg.det(minors)
    nb = len(occlistsb)
    trans_b = numpy.empty((nb,nb))
    for i, idx in enumerate(occlistsb):
        s_sub = s[1][idx].T.copy()
        minors = s_sub[occlistsb]
        trans_b[i,:] = numpy.linalg.det(minors)

    # Mimic the transformation einsum('ab,ap->pb', FCI, trans).
    # The wavefunction FCI has the [excitation_alpha,excitation_beta]
    # representation.  The zero blocks like FCI[S_alpha,D_beta],
    # FCI[D_alpha,D_beta], are explicitly excluded.
    bra_mat = numpy.zeros((na,nb))
    bra_mat[0,0] = bra0
    bra_mat[1:1+nova,0] = bra1[0].ravel()
    bra_mat[0,1:1+novb] = bra1[1].ravel()
    bra_mat[1+nova:,0] = bra2aa.ravel()
    bra_mat[0,1+novb:] = bra2bb.ravel()
    bra_mat[1:1+nova,1:1+novb] = bra2[1].transpose(0,2,1,3).reshape(nova,novb)
    c_s = lib.einsum('ab,ap,bq->pq', bra_mat, trans_a, trans_b)
    ovlp  =  c_s[0,0] * ket0
    ovlp += numpy.dot(c_s[1:1+nova,0], ket1[0].ravel())
    ovlp += numpy.dot(c_s[0,1:1+novb], ket1[1].ravel())
    ovlp += numpy.dot(c_s[1+nova:,0] , ket2aa.ravel())
    ovlp += numpy.dot(c_s[0,1+novb:] , ket2bb.ravel())
    ovlp += numpy.einsum('ijab,iajb->', ket2[1],
                         c_s[1:1+nova,1:1+novb].reshape(nocca,nvira,noccb,nvirb))
    return ovlp
602
603
def make_rdm1(myci, civec=None, nmo=None, nocc=None, ao_repr=False):
    r'''
    Spin-resolved one-particle density matrices (dm1a, dm1b) in the MO basis.

    The orbital-response contribution to the occupied-virtual blocks is not
    included.

    dm1a[p,q] = <q_alpha^\dagger p_alpha>
    dm1b[p,q] = <q_beta^\dagger p_beta>

    The 1-pdm convention follows McWeeny's book, Eq (5.4.20).

    Args:
        civec : CISD vector; defaults to ``myci.ci``.
        nmo, nocc : per-spin orbital / occupied counts; default to
            ``myci.nmo`` / ``myci.nocc``.
        ao_repr : forwarded to ``uccsd_rdm._make_rdm1``.
    '''
    civec = myci.ci if civec is None else civec
    nmo = myci.nmo if nmo is None else nmo
    nocc = myci.nocc if nocc is None else nocc
    dm1_blocks = _gamma1_intermediates(myci, civec, nmo, nocc)
    return uccsd_rdm._make_rdm1(myci, dm1_blocks, with_frozen=True,
                                ao_repr=ao_repr)
620
def make_rdm2(myci, civec=None, nmo=None, nocc=None, ao_repr=False):
    r'''
    Spin-resolved two-particle density matrices dm2aa, dm2ab, dm2bb in the
    MO basis.

    dm2aa[p,q,r,s] = <q_alpha^\dagger s_alpha^\dagger r_alpha p_alpha>
    dm2ab[p,q,r,s] = <q_alpha^\dagger s_beta^\dagger r_beta p_alpha>
    dm2bb[p,q,r,s] = <q_beta^\dagger s_beta^\dagger r_beta p_beta>

    (p,q belong to one particle and r,s to the other.) To obtain the energy,
    contract the two-particle density matrices with the integrals as

    E = numpy.einsum('pqrs,pqrs', eri_aa, dm2_aa)
    E+= numpy.einsum('pqrs,pqrs', eri_ab, dm2_ab)
    E+= numpy.einsum('pqrs,rspq', eri_ba, dm2_ab)
    E+= numpy.einsum('pqrs,pqrs', eri_bb, dm2_bb)

    where eri_aa[p,q,r,s] = (p_alpha q_alpha | r_alpha s_alpha )
    eri_ab[p,q,r,s] = ( p_alpha q_alpha | r_beta s_beta )
    eri_ba[p,q,r,s] = ( p_beta q_beta | r_alpha s_alpha )
    eri_bb[p,q,r,s] = ( p_beta q_beta | r_beta s_beta )

    Args:
        civec : CISD vector; defaults to ``myci.ci``.
        nmo, nocc : per-spin orbital / occupied counts; default to
            ``myci.nmo`` / ``myci.nocc``.
        ao_repr : forwarded to ``uccsd_rdm._make_rdm2``.
    '''
    civec = myci.ci if civec is None else civec
    nmo = myci.nmo if nmo is None else nmo
    nocc = myci.nocc if nocc is None else nocc
    dm1_blocks = _gamma1_intermediates(myci, civec, nmo, nocc)
    dm2_blocks = _gamma2_intermediates(myci, civec, nmo, nocc)
    return uccsd_rdm._make_rdm2(myci, dm1_blocks, dm2_blocks,
                                with_dm1=True, with_frozen=True,
                                ao_repr=ao_repr)
650
def _gamma1_intermediates(myci, civec, nmo, nocc):
    '''One-particle density blocks of a UCISD vector.

    Returns ((dooa, doob), (dova, dovb), (dvoa, dvob), (dvva, dvvb)): the
    occupied-occupied, occupied-virtual, virtual-occupied and
    virtual-virtual blocks for alpha and beta spin. make_rdm1/make_rdm2
    pass them to uccsd_rdm.
    '''
    c0, (c1a, c1b), (c2aa, c2ab, c2bb) = cisdvec_to_amplitudes(civec, nmo, nocc)

    # virtual-occupied: reference x singles plus singles x doubles
    dvoa = (c0.conj() * c1a.T
            + numpy.einsum('jb,ijab->ai', c1a.conj(), c2aa)
            + numpy.einsum('jb,ijab->ai', c1b.conj(), c2ab))
    dvob = (c0.conj() * c1b.T
            + numpy.einsum('jb,ijab->ai', c1b.conj(), c2bb)
            + numpy.einsum('jb,jiba->ai', c1a.conj(), c2ab))
    # occupied-virtual blocks are the adjoints of the virtual-occupied ones
    dova = dvoa.T.conj()
    dovb = dvob.T.conj()

    # occupied-occupied: hole contributions enter with a negative sign
    dooa = (-numpy.einsum('ia,ka->ik', c1a.conj(), c1a)
            - numpy.einsum('ijab,ikab->jk', c2aa.conj(), c2aa) * .5
            - numpy.einsum('jiab,kiab->jk', c2ab.conj(), c2ab))
    doob = (-numpy.einsum('ia,ka->ik', c1b.conj(), c1b)
            - numpy.einsum('ijab,ikab->jk', c2bb.conj(), c2bb) * .5
            - numpy.einsum('ijab,ikab->jk', c2ab.conj(), c2ab))

    # virtual-virtual: particle contributions
    dvva = (numpy.einsum('ia,ic->ac', c1a, c1a.conj())
            + numpy.einsum('ijab,ijac->bc', c2aa, c2aa.conj()) * .5
            + numpy.einsum('ijba,ijca->bc', c2ab, c2ab.conj()))
    dvvb = (numpy.einsum('ia,ic->ac', c1b, c1b.conj())
            + numpy.einsum('ijba,ijca->bc', c2bb, c2bb.conj()) * .5
            + numpy.einsum('ijab,ijac->bc', c2ab, c2ab.conj()))

    return (dooa, doob), (dova, dovb), (dvoa, dvob), (dvva, dvvb)
681
def _gamma2_intermediates(myci, civec, nmo, nocc):
    '''Two-particle density blocks of a UCISD vector.

    Returns eight groups (ovov, vvvv, oooo, oovv, ovvo, vvov, ovvv, ooov),
    each a 4-tuple ordered (alpha-alpha, alpha-beta, beta-alpha, beta-beta),
    e.g. (dovov, dovOV, dOVov, dOVOV). Entries set to None are not computed
    here. make_rdm2 passes the result, together with the one-particle
    blocks, to uccsd_rdm._make_rdm2. Presumably that routine expects this
    layout; confirm there.
    '''
    # nmoa/nmob/nocca/noccb are unpacked but not used below
    nmoa, nmob = nmo
    nocca, noccb = nocc
    c0, c1, c2 = cisdvec_to_amplitudes(civec, nmo, nocc)
    c1a, c1b = c1
    c2aa, c2ab, c2bb = c2

    # g*: raw contractions; lower/upper case letters mark alpha/beta indices
    # reference x doubles
    goovv = c0 * c2aa.conj() * .5
    goOvV = c0 * c2ab.conj()
    gOOVV = c0 * c2bb.conj() * .5

    # singles x doubles, three-virtual blocks
    govvv = numpy.einsum('ia,ikcd->kadc', c1a, c2aa.conj()) * .5
    gOvVv = numpy.einsum('ia,ikcd->kadc', c1a, c2ab.conj())
    goVvV = numpy.einsum('ia,kidc->kadc', c1b, c2ab.conj())
    gOVVV = numpy.einsum('ia,ikcd->kadc', c1b, c2bb.conj()) * .5

    # singles x doubles, three-occupied blocks
    gooov = numpy.einsum('ia,klac->klic', c1a, c2aa.conj()) *-.5
    goOoV =-numpy.einsum('ia,klac->klic', c1a, c2ab.conj())
    gOoOv =-numpy.einsum('ia,lkca->klic', c1b, c2ab.conj())
    gOOOV = numpy.einsum('ia,klac->klic', c1b, c2bb.conj()) *-.5

    # doubles x doubles, all-occupied and all-virtual blocks
    goooo = numpy.einsum('ijab,klab->ijkl', c2aa.conj(), c2aa) * .25
    goOoO = numpy.einsum('ijab,klab->ijkl', c2ab.conj(), c2ab)
    gOOOO = numpy.einsum('ijab,klab->ijkl', c2bb.conj(), c2bb) * .25
    gvvvv = numpy.einsum('ijab,ijcd->abcd', c2aa, c2aa.conj()) * .25
    gvVvV = numpy.einsum('ijab,ijcd->abcd', c2ab, c2ab.conj())
    gVVVV = numpy.einsum('ijab,ijcd->abcd', c2bb, c2bb.conj()) * .25

    # doubles x doubles, mixed-spin particle-hole blocks
    goVoV = numpy.einsum('jIaB,kIaC->jCkB', c2ab.conj(), c2ab)
    gOvOv = numpy.einsum('iJbA,iKcA->JcKb', c2ab.conj(), c2ab)

    # particle-hole (ovvo) blocks: doubles x doubles plus singles x singles
    govvo = numpy.einsum('ijab,ikac->jcbk', c2aa.conj(), c2aa)
    govvo+= numpy.einsum('jIbA,kIcA->jcbk', c2ab.conj(), c2ab)
    goVvO = numpy.einsum('jIbA,IKAC->jCbK', c2ab.conj(), c2bb)
    goVvO+= numpy.einsum('ijab,iKaC->jCbK', c2aa.conj(), c2ab)
    gOVVO = numpy.einsum('ijab,ikac->jcbk', c2bb.conj(), c2bb)
    gOVVO+= numpy.einsum('iJaB,iKaC->JCBK', c2ab.conj(), c2ab)
    govvo+= numpy.einsum('ia,jb->ibaj', c1a.conj(), c1a)
    goVvO+= numpy.einsum('ia,jb->ibaj', c1a.conj(), c1b)
    gOVVO+= numpy.einsum('ia,jb->ibaj', c1b.conj(), c1b)

    # d*: reorder axes (and antisymmetrize the same-spin blocks) into the
    # returned density layout
    dovov = goovv.transpose(0,2,1,3) - goovv.transpose(0,3,1,2)
    doooo = goooo.transpose(0,2,1,3) - goooo.transpose(0,3,1,2)
    dvvvv = gvvvv.transpose(0,2,1,3) - gvvvv.transpose(0,3,1,2)
    dovvo = govvo.transpose(0,2,1,3)
    dooov = gooov.transpose(0,2,1,3) - gooov.transpose(1,2,0,3)
    dovvv = govvv.transpose(0,2,1,3) - govvv.transpose(0,3,1,2)
    doovv =-dovvo.transpose(0,3,2,1)
    dvvov = None

    dOVOV = gOOVV.transpose(0,2,1,3) - gOOVV.transpose(0,3,1,2)
    dOOOO = gOOOO.transpose(0,2,1,3) - gOOOO.transpose(0,3,1,2)
    dVVVV = gVVVV.transpose(0,2,1,3) - gVVVV.transpose(0,3,1,2)
    dOVVO = gOVVO.transpose(0,2,1,3)
    dOOOV = gOOOV.transpose(0,2,1,3) - gOOOV.transpose(1,2,0,3)
    dOVVV = gOVVV.transpose(0,2,1,3) - gOVVV.transpose(0,3,1,2)
    dOOVV =-dOVVO.transpose(0,3,2,1)
    dVVOV = None

    dovOV = goOvV.transpose(0,2,1,3)
    dooOO = goOoO.transpose(0,2,1,3)
    dvvVV = gvVvV.transpose(0,2,1,3)
    dovVO = goVvO.transpose(0,2,1,3)
    dooOV = goOoV.transpose(0,2,1,3)
    dovVV = goVvV.transpose(0,2,1,3)
    dooVV = goVoV.transpose(0,2,1,3)
    # symmetrize with the conjugate-transposed block (and flip the sign)
    dooVV = -(dooVV + dooVV.transpose(1,0,3,2).conj()) * .5
    dvvOV = None

    dOVov = None
    dOOoo = None
    dVVvv = None
    dOVvo = dovVO.transpose(3,2,1,0).conj()
    dOOov = gOoOv.transpose(0,2,1,3)
    dOVvv = gOvVv.transpose(0,2,1,3)
    dOOvv = gOvOv.transpose(0,2,1,3)
    dOOvv =-(dOOvv + dOOvv.transpose(1,0,3,2).conj()) * .5
    dVVov = None

    return ((dovov, dovOV, dOVov, dOVOV),
            (dvvvv, dvvVV, dVVvv, dVVVV),
            (doooo, dooOO, dOOoo, dOOOO),
            (doovv, dooVV, dOOvv, dOOVV),
            (dovvo, dovVO, dOVvo, dOVVO),
            (dvvov, dvvOV, dVVov, dVVOV),
            (dovvv, dovVV, dOVvv, dOVVV),
            (dooov, dooOV, dOOov, dOOOV))
769
def trans_rdm1(myci, cibra, ciket, nmo=None, nocc=None):
    r'''
    One-particle spin transition density matrices dm1a, dm1b in MO basis
    between the CISD vectors ``cibra`` and ``ciket`` (the occupied-virtual
    blocks due to the orbital response contribution are not included).

    dm1a[p,q] = <bra| q_alpha^\dagger p_alpha |ket>
    dm1b[p,q] = <bra| q_beta^\dagger p_beta |ket>

    The bra is complex-conjugated in every term, so for complex vectors
    trans_rdm1(myci, ket, bra) equals the conjugate transpose of
    trans_rdm1(myci, bra, ket) (for each spin).

    The convention of 1-pdm is based on McWeeney's book, Eq (5.4.20).

    Args:
        myci : UCISD-like object providing ``cisdvec_to_amplitudes`` and
            ``frozen`` (plus ``nmo``/``nocc`` when they are not passed, and
            ``mo_occ``/``get_frozen_mask`` when ``frozen`` is set).
        cibra, ciket : CISD vectors of the bra and ket states.
        nmo, nocc : (alpha, beta) pairs; default to ``myci.nmo``/``myci.nocc``.

    Returns:
        (dm1a, dm1b).  When ``myci.frozen`` is not None they are embedded in
        the full MO space, with the frozen occupied diagonal set to <bra|ket>.
    '''
    if nmo is None: nmo = myci.nmo
    if nocc is None: nocc = myci.nocc
    c0bra, c1bra, c2bra = myci.cisdvec_to_amplitudes(cibra, nmo, nocc)
    c0ket, c1ket, c2ket = myci.cisdvec_to_amplitudes(ciket, nmo, nocc)

    nmoa, nmob = nmo
    nocca, noccb = nocc
    bra1a, bra1b = c1bra
    bra2aa, bra2ab, bra2bb = c2bra
    ket1a, ket1b = c1ket
    ket2aa, ket2ab, ket2bb = c2ket

    # Virtual-occupied block <bra| i^+ a |ket>: de-excitations of the ket
    # (singles -> reference, doubles -> singles).
    dvoa = c0bra.conj() * ket1a.T
    dvob = c0bra.conj() * ket1b.T
    dvoa += numpy.einsum('jb,ijab->ai', bra1a.conj(), ket2aa)
    dvoa += numpy.einsum('jb,ijab->ai', bra1b.conj(), ket2ab)
    dvob += numpy.einsum('jb,ijab->ai', bra1b.conj(), ket2bb)
    dvob += numpy.einsum('jb,jiba->ai', bra1a.conj(), ket2ab)

    # Occupied-virtual block <bra| a^+ i |ket>: excitations of the ket
    # (reference -> singles, singles -> doubles).  The conjugate belongs on
    # the bra amplitudes; for complex vectors putting it on the ket breaks
    # the hermiticity relation between (bra, ket) and (ket, bra).
    dova = c0ket * bra1a.conj()
    dovb = c0ket * bra1b.conj()
    dova += numpy.einsum('jb,ijab->ia', ket1a, bra2aa.conj())
    dova += numpy.einsum('jb,ijab->ia', ket1b, bra2ab.conj())
    dovb += numpy.einsum('jb,ijab->ia', ket1b, bra2bb.conj())
    dovb += numpy.einsum('jb,jiba->ia', ket1a, bra2ab.conj())

    # Occupied-occupied block (the reference contribution, norm * delta_ij,
    # is added on the diagonal below).
    dooa  =-numpy.einsum('ia,ka->ik', bra1a.conj(), ket1a)
    doob  =-numpy.einsum('ia,ka->ik', bra1b.conj(), ket1b)
    dooa -= numpy.einsum('ijab,ikab->jk', bra2aa.conj(), ket2aa) * .5
    dooa -= numpy.einsum('jiab,kiab->jk', bra2ab.conj(), ket2ab)
    doob -= numpy.einsum('ijab,ikab->jk', bra2bb.conj(), ket2bb) * .5
    doob -= numpy.einsum('ijab,ikab->jk', bra2ab.conj(), ket2ab)

    # Virtual-virtual block.
    dvva  = numpy.einsum('ia,ic->ac', ket1a, bra1a.conj())
    dvvb  = numpy.einsum('ia,ic->ac', ket1b, bra1b.conj())
    dvva += numpy.einsum('ijab,ijac->bc', ket2aa, bra2aa.conj()) * .5
    dvva += numpy.einsum('ijba,ijca->bc', ket2ab, bra2ab.conj())
    dvvb += numpy.einsum('ijba,ijca->bc', ket2bb, bra2bb.conj()) * .5
    dvvb += numpy.einsum('ijab,ijac->bc', ket2ab, bra2ab.conj())

    # <bra|ket> with the bra conjugated, consistent with all terms above.
    norm = numpy.vdot(cibra, ciket)

    dm1a = numpy.empty((nmoa,nmoa), dtype=dooa.dtype)
    dm1a[:nocca,:nocca] = dooa
    dm1a[:nocca,nocca:] = dova
    dm1a[nocca:,:nocca] = dvoa
    dm1a[nocca:,nocca:] = dvva
    dm1a[numpy.diag_indices(nocca)] += norm

    dm1b = numpy.empty((nmob,nmob), dtype=doob.dtype)
    dm1b[:noccb,:noccb] = doob
    dm1b[:noccb,noccb:] = dovb
    dm1b[noccb:,:noccb] = dvob
    dm1b[noccb:,noccb:] = dvvb
    dm1b[numpy.diag_indices(noccb)] += norm

    if myci.frozen is not None:
        # Embed the active-space matrices into the full MO space: every
        # occupied orbital first gets `norm` on the diagonal, then the
        # active (non-frozen) rows/columns are overwritten.
        nmoa = myci.mo_occ[0].size
        nmob = myci.mo_occ[1].size
        nocca = numpy.count_nonzero(myci.mo_occ[0] > 0)
        noccb = numpy.count_nonzero(myci.mo_occ[1] > 0)
        rdm1a = numpy.zeros((nmoa,nmoa), dtype=dm1a.dtype)
        rdm1b = numpy.zeros((nmob,nmob), dtype=dm1b.dtype)
        rdm1a[numpy.diag_indices(nocca)] = norm
        rdm1b[numpy.diag_indices(noccb)] = norm
        moidx = myci.get_frozen_mask()
        moidxa = numpy.where(moidx[0])[0]
        moidxb = numpy.where(moidx[1])[0]
        rdm1a[moidxa[:,None],moidxa] = dm1a
        rdm1b[moidxb[:,None],moidxb] = dm1b
        dm1a = rdm1a
        dm1b = rdm1b
    return dm1a, dm1b
853
854
class UCISD(cisd.CISD):
    # Unrestricted CISD: alpha and beta orbitals are handled separately.
    # The generic solver machinery comes from cisd.CISD; this class supplies
    # the spin-unrestricted vector layout, initial guess, integral
    # transformation and density matrices.

    def vector_size(self):
        '''Return the length of a UCISD CI vector.

        The vector holds the reference coefficient, the alpha and beta single
        excitations, all alpha-beta double excitations and only the unique
        (i<j, a<b) same-spin double excitations.
        '''
        norba, norbb = self.nmo
        nocca, noccb = self.nocc
        nvira = norba - nocca
        nvirb = norbb - noccb
        nooa = nocca * (nocca-1) // 2
        nvva = nvira * (nvira-1) // 2
        noob = noccb * (noccb-1) // 2
        nvvb = nvirb * (nvirb-1) // 2
        size = (1 + nocca*nvira + noccb*nvirb +
                nocca*noccb*nvira*nvirb + nooa*nvva + noob*nvvb)
        return size

    get_nocc = uccsd.get_nocc
    get_nmo = uccsd.get_nmo
    get_frozen_mask = uccsd.get_frozen_mask

    def get_init_guess(self, eris=None, nroots=1, diag=None):
        '''Build MP2-like initial guess vector(s) for the CI solver.

        Args:
            eris : integral object; built by ``self.ao2mo(self.mo_coeff)``
                when None.
            nroots : number of guess vectors.  It is capped at
                1 + (number of alpha and beta singles).
            diag : optional array of diagonal elements in CI-vector order
                (index 0 is the HF determinant, followed by the singles).
                When given, the singles with the lowest diagonal values are
                chosen as extra guesses.

        Returns:
            (emp2, ci_guess).  emp2 is the real part of the MP2 energy (also
            stored in ``self.emp2``).  ci_guess is one vector when
            nroots == 1, otherwise a list starting with the MP2 guess
            followed by unit vectors on selected single excitations.
        '''
        if eris is None: eris = self.ao2mo(self.mo_coeff)
        nocca, noccb = self.nocc
        mo_ea, mo_eb = eris.mo_energy
        eia_a = mo_ea[:nocca,None] - mo_ea[None,nocca:]
        eia_b = mo_eb[:noccb,None] - mo_eb[None,noccb:]
        t1a = eris.focka[:nocca,nocca:].conj() / eia_a
        t1b = eris.fockb[:noccb,noccb:].conj() / eia_b

        eris_ovov = _cp(eris.ovov)
        eris_ovOV = _cp(eris.ovOV)
        eris_OVOV = _cp(eris.OVOV)
        t2aa = eris_ovov.transpose(0,2,1,3) - eris_ovov.transpose(0,2,3,1)
        t2bb = eris_OVOV.transpose(0,2,1,3) - eris_OVOV.transpose(0,2,3,1)
        t2ab = eris_ovOV.transpose(0,2,1,3).copy()
        t2aa = t2aa.conj()
        t2ab = t2ab.conj()
        t2bb = t2bb.conj()
        t2aa /= lib.direct_sum('ia+jb->ijab', eia_a, eia_a)
        t2ab /= lib.direct_sum('ia+jb->ijab', eia_a, eia_b)
        t2bb /= lib.direct_sum('ia+jb->ijab', eia_b, eia_b)

        emp2  = numpy.einsum('iajb,ijab', eris_ovov, t2aa) * .25
        emp2 -= numpy.einsum('jaib,ijab', eris_ovov, t2aa) * .25
        emp2 += numpy.einsum('iajb,ijab', eris_OVOV, t2bb) * .25
        emp2 -= numpy.einsum('jaib,ijab', eris_OVOV, t2bb) * .25
        emp2 += numpy.einsum('iajb,ijab', eris_ovOV, t2ab)
        self.emp2 = emp2.real
        logger.info(self, 'Init t2, MP2 energy = %.15g', self.emp2)

        # Avoid a (near-)zero starting vector when both the MP2 energy and the
        # singles are negligible.
        if abs(emp2) < 1e-3 and (abs(t1a).sum()+abs(t1b).sum()) < 1e-3:
            t1a = 1e-1 / eia_a
            t1b = 1e-1 / eia_b

        ci_guess = amplitudes_to_cisdvec(1, (t1a,t1b), (t2aa,t2ab,t2bb))

        if nroots > 1:
            civec_size = ci_guess.size
            ci1_size = t1a.size + t1b.size
            dtype = ci_guess.dtype
            nroots = min(ci1_size+1, nroots)

            if diag is None:
                idx = range(1, nroots)
            else:
                # Rank only the singles (positions 1..ci1_size).  Sorting the
                # slice that includes position 0 and dropping the first entry
                # only excludes the HF determinant when it has the lowest
                # diagonal; otherwise HF would be added twice.
                idx = diag[1:ci1_size+1].argsort()[:nroots-1] + 1

            ci_guess = [ci_guess]
            for i in idx:
                g = numpy.zeros(civec_size, dtype)
                g[i] = 1.0
                ci_guess.append(g)

        return self.emp2, ci_guess

    contract = contract
    make_diagonal = make_diagonal
    _dot = None
    _add_vvvv = uccsd._add_vvvv

    def ao2mo(self, mo_coeff=None):
        '''Transform the two-electron integrals to the MO basis.

        The in-core path is used when the SCF object holds ``_eri`` and the
        estimated memory fits within ``max_memory`` (or when
        ``mol.incore_anyway`` is set).  Density fitting raises
        NotImplementedError; otherwise the out-of-core path is used.
        '''
        nmoa, nmob = self.get_nmo()
        nao = self.mo_coeff[0].shape[0]
        nmo_pair = nmoa * (nmoa+1) // 2
        nao_pair = nao * (nao+1) // 2
        # Memory estimate in MB based on the alpha orbital count only.
        mem_incore = (max(nao_pair**2, nmoa**4) + nmo_pair**2) * 8/1e6
        mem_now = lib.current_memory()[0]
        if (self._scf._eri is not None and
            (mem_incore+mem_now < self.max_memory) or self.mol.incore_anyway):
            return uccsd._make_eris_incore(self, mo_coeff)

        elif getattr(self._scf, 'with_df', None):
            raise NotImplementedError

        else:
            return uccsd._make_eris_outcore(self, mo_coeff)

    def to_fcivec(self, cisdvec, nmo=None, nocc=None):
        return to_fcivec(cisdvec, nmo, nocc)

    def from_fcivec(self, fcivec, nmo=None, nocc=None):
        return from_fcivec(fcivec, nmo, nocc)

    def amplitudes_to_cisdvec(self, c0, c1, c2):
        return amplitudes_to_cisdvec(c0, c1, c2)

    def cisdvec_to_amplitudes(self, civec, nmo=None, nocc=None):
        '''Unpack a CI vector into (c0, (c1a, c1b), (c2aa, c2ab, c2bb)).'''
        if nmo is None: nmo = self.nmo
        if nocc is None: nocc = self.nocc
        return cisdvec_to_amplitudes(civec, nmo, nocc)

    make_rdm1 = make_rdm1
    make_rdm2 = make_rdm2
    trans_rdm1 = trans_rdm1

    def nuc_grad_method(self):
        from pyscf.grad import ucisd
        return ucisd.Gradients(self)
972
# Module-level alias: CISD and UCISD name the same class in this module.
CISD = UCISD

# Register the class as the ``CISD`` method of the UHF class, wrapped by
# lib.class_as_method.
# NOTE(review): the import sits here instead of at the top of the file,
# presumably to avoid a circular import between pyscf.scf and pyscf.ci --
# confirm before moving it.
from pyscf import scf
scf.uhf.UHF.CISD = lib.class_as_method(CISD)
977
def _cp(a):
    '''Return ``a`` as a C-contiguous numpy array, copying only if needed.

    ``numpy.asarray`` is used instead of ``numpy.array(a, copy=False)``:
    since NumPy 2.0, ``copy=False`` raises ValueError whenever a copy is
    required (non-contiguous arrays, array-likes such as on-disk datasets),
    whereas ``asarray`` keeps the copy-only-if-necessary semantics on every
    NumPy version.
    '''
    return numpy.asarray(a, order='C')
980
981
if __name__ == '__main__':
    from pyscf import gto

    mol = gto.Mole()
    mol.verbose = 0
    mol.atom = [
        ['O', ( 0., 0.    , 0.   )],
        ['H', ( 0., -0.757, 0.587)],
        ['H', ( 0., 0.757 , 0.587)],]
    mol.basis = {'H': 'sto-3g',
                 'O': 'sto-3g',}
    # Disabled energy / RDM consistency check.  Re-enabling it also requires
    # ``from functools import reduce`` and ``from pyscf import ao2mo``.
#    mol.build()
#    mf = scf.UHF(mol).run(conv_tol=1e-14)
#    myci = CISD(mf)
#    eris = myci.ao2mo()
#    ecisd, civec = myci.kernel(eris=eris)
#    print(ecisd - -0.048878084082066106)
#
#    nmoa = mf.mo_energy[0].size
#    nmob = mf.mo_energy[1].size
#    rdm1 = myci.make_rdm1(civec)
#    rdm2 = myci.make_rdm2(civec)
#    eri_aa = ao2mo.kernel(mf._eri, mf.mo_coeff[0], compact=False).reshape([nmoa]*4)
#    eri_bb = ao2mo.kernel(mf._eri, mf.mo_coeff[1], compact=False).reshape([nmob]*4)
#    eri_ab = ao2mo.kernel(mf._eri, [mf.mo_coeff[0], mf.mo_coeff[0],
#                                    mf.mo_coeff[1], mf.mo_coeff[1]], compact=False)
#    eri_ab = eri_ab.reshape(nmoa,nmoa,nmob,nmob)
#    h1a = reduce(numpy.dot, (mf.mo_coeff[0].T, mf.get_hcore(), mf.mo_coeff[0]))
#    h1b = reduce(numpy.dot, (mf.mo_coeff[1].T, mf.get_hcore(), mf.mo_coeff[1]))
#    e2 = (numpy.einsum('ij,ji', h1a, rdm1[0]) +
#          numpy.einsum('ij,ji', h1b, rdm1[1]) +
#          numpy.einsum('ijkl,ijkl', eri_aa, rdm2[0]) * .5 +
#          numpy.einsum('ijkl,ijkl', eri_ab, rdm2[1])      +
#          numpy.einsum('ijkl,ijkl', eri_bb, rdm2[2]) * .5)
#    print(ecisd + mf.e_tot - mol.energy_nuc() - e2)   # = 0
#
#    print(abs(rdm1[0] - (numpy.einsum('ijkk->ji', rdm2[0]) +
#                         numpy.einsum('ijkk->ji', rdm2[1]))/(mol.nelectron-1)).sum())
#    print(abs(rdm1[1] - (numpy.einsum('ijkk->ji', rdm2[2]) +
#                         numpy.einsum('kkij->ji', rdm2[1]))/(mol.nelectron-1)).sum())

    if 1:
        # Compare ucisd.overlap against the FCI overlap of the same random
        # (unnormalized) bra/ket vectors under a random MO overlap matrix.
        from pyscf.ci import ucisd
        from pyscf import fci
        nmo = 8
        nocc = nocca, noccb = (4,3)
        numpy.random.seed(2)
        nvira, nvirb = nmo-nocca, nmo-noccb

        def random_civec():
            # Draw c0, c1 and c2 in the same order as the original inline
            # construction so the seeded random stream is unchanged.
            c0 = numpy.random.rand(1)
            c1 = (numpy.random.rand(nocca,nvira),
                  numpy.random.rand(noccb,nvirb))
            c2 = (numpy.random.rand(nocca,nocca,nvira,nvira),
                  numpy.random.rand(nocca,noccb,nvira,nvirb),
                  numpy.random.rand(noccb,noccb,nvirb,nvirb))
            return ucisd.amplitudes_to_cisdvec(c0, c1, c2)

        cibra = random_civec()
        ciket = random_civec()
        fcibra = ucisd.to_fcivec(cibra, nmo, nocc)
        fciket = ucisd.to_fcivec(ciket, nmo, nocc)
        s_mo = (numpy.random.random((nmo,nmo)),
                numpy.random.random((nmo,nmo)))
        # Use the alpha overlap for both spins (the second draw is kept only
        # to leave the random stream unchanged).
        s_mo = (s_mo[0], s_mo[0])
        s0 = fci.addons.overlap(fcibra, fciket, nmo, nocc, s_mo)
        s1 = ucisd.overlap(cibra, ciket, nmo, nocc, s_mo)
        print('overlap UCISD = %s  FCI = %s  |diff| = %s' % (s1, s0, abs(s1 - s0)))
1050