1#!/usr/bin/env python
2# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Authors: Sheng Guo
17#          Qiming Sun <osirpt.sun@gmail.com>
18#
19
20import ctypes
21
22import tempfile
23from functools import reduce
24import numpy
25import h5py
26from pyscf import lib
27from pyscf.lib import logger
28from pyscf import fci
29from pyscf.mcscf import mc_ao2mo
30from pyscf import ao2mo
31from pyscf.ao2mo import _ao2mo
32
33libmc = lib.load_library('libmcscf')
34
35NUMERICAL_ZERO = 1e-14
36# Ref JCP 117, 9138 (2002); DOI:10.1063/1.1515317
37
38# h1e is the CAS space effective 1e hamiltonian
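# h2e is the CAS space 2e integrals in physicists' notation <pq|rs>; the
# routines below obtain it by applying .transpose(0,2,1,3) to the
# chemists'-notation (pr|qs) integrals.
# Index labels: a' -> p, b' -> q, c' -> r, d' -> s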
41
def make_a16(h1e, h2e, dms, civec, norb, nelec, link_index=None):
    '''Assemble the six-index intermediate ``a16`` that :func:`Sr` contracts.

    Args:
        h1e : CAS-space effective one-electron Hamiltonian (two indices).
        h2e : CAS-space two-electron integrals (four indices), in the index
            order expected by the einsum strings below.
        dms : dict holding the 3-particle density matrix under ``'3'``.
            When both ``'f3ca'`` and ``'f3ac'`` are present they are taken
            as the integral/4-pdm contractions; otherwise both are rebuilt
            from ``civec`` through ``_contract4pdm``.
        civec : CI vector; only read when the contractions are rebuilt.
        norb : number of active orbitals.
        nelec : total number of active electrons (int) or the pair
            ``(neleca, nelecb)``.
        link_index : optional ``(link_indexa, link_indexb)`` pair; generated
            with ``fci.cistring.gen_linkstr_index`` when omitted.

    Returns:
        The array ``a16`` with index order ``pqrabc``.
    '''
    dm3 = dms['3']
    #dm4 = dms['4']
    if 'f3ca' in dms and 'f3ac' in dms:
        f3ca = dms['f3ca']
        f3ac = dms['f3ac']
    else:
        if isinstance(nelec, (int, numpy.integer)):
            # For an odd electron count the extra electron belongs to the
            # alpha string (PySCF FCI convention); nelec//2 for both spins
            # would silently drop one electron.
            nelecb = nelec//2
            neleca = nelec - nelecb
        else:
            neleca, nelecb = nelec
        if link_index is None:
            link_indexa = fci.cistring.gen_linkstr_index(range(norb), neleca)
            link_indexb = fci.cistring.gen_linkstr_index(range(norb), nelecb)
        else:
            link_indexa, link_indexb = link_index
        # _contract4pdm is fed the integrals with the two middle axes swapped
        eri = h2e.transpose(0,2,1,3)
        f3ca = _contract4pdm('NEVPTkern_cedf_aedf', eri, civec, norb, nelec,
                             (link_indexa,link_indexb))
        f3ac = _contract4pdm('NEVPTkern_aedf_ecdf', eri, civec, norb, nelec,
                             (link_indexa,link_indexb))

    # one-electron part
    a16 = -numpy.einsum('ib,rpqiac->pqrabc', h1e, dm3)
    a16 += numpy.einsum('ia,rpqbic->pqrabc', h1e, dm3)
    a16 -= numpy.einsum('ci,rpqbai->pqrabc', h1e, dm3)

    # The 4-pdm term below is replaced by f3ca plus 3-pdm corrections:
    # qjkiac = acqjki + delta(ja)qcki + delta(ia)qjkc - delta(qc)ajki - delta(kc)qjai
    #:a16 -= numpy.einsum('kbij,rpqjkiac->pqrabc', h2e, dm4)
    a16 -= f3ca.transpose(1,4,0,2,5,3) # c'a'acb'b -> a'b'c'abc
    a16 -= numpy.einsum('kbia,rpqcki->pqrabc', h2e, dm3)
    a16 -= numpy.einsum('kbaj,rpqjkc->pqrabc', h2e, dm3)
    a16 += numpy.einsum('cbij,rpqjai->pqrabc', h2e, dm3)
    fdm2 = numpy.einsum('kbij,rpajki->prab'  , h2e, dm3)
    # delta(qc) term: add fdm2 wherever the 2nd and 6th indices coincide
    for i in range(norb):
        a16[:,i,:,:,:,i] += fdm2

    #:a16 += numpy.einsum('ijka,rpqbjcik->pqrabc', h2e, dm4)
    a16 += f3ac.transpose(1,2,0,4,3,5) # c'a'b'bac -> a'b'c'abc

    #:a16 -= numpy.einsum('kcij,rpqbajki->pqrabc', h2e, dm4)
    a16 -= f3ca.transpose(1,2,0,4,3,5) # c'a'b'bac -> a'b'c'abc

    a16 += numpy.einsum('jbij,rpqiac->pqrabc', h2e, dm3)
    a16 -= numpy.einsum('cjka,rpqbjk->pqrabc', h2e, dm3)
    a16 += numpy.einsum('jcij,rpqbai->pqrabc', h2e, dm3)
    return a16
88
def make_a22(h1e, h2e, dms, civec, norb, nelec, link_index=None):
    '''Assemble the six-index intermediate ``a22`` that :func:`Si` contracts.

    Args:
        h1e : CAS-space effective one-electron Hamiltonian (two indices).
        h2e : CAS-space two-electron integrals (four indices), in the index
            order expected by the einsum strings below.
        dms : dict holding the 2- and 3-particle density matrices under
            ``'2'`` and ``'3'``.  When both ``'f3ca'`` and ``'f3ac'`` are
            present they are taken as the integral/4-pdm contractions;
            otherwise both are rebuilt from ``civec`` through
            ``_contract4pdm``.
        civec : CI vector; only read when the contractions are rebuilt.
        norb : number of active orbitals.
        nelec : total number of active electrons (int) or the pair
            ``(neleca, nelecb)``.
        link_index : optional ``(link_indexa, link_indexb)`` pair; generated
            with ``fci.cistring.gen_linkstr_index`` when omitted.

    Returns:
        The array ``a22`` with index order ``ijkabc``.
    '''
    dm2 = dms['2']
    dm3 = dms['3']
    #dm4 = dms['4']
    if 'f3ca' in dms and 'f3ac' in dms:
        f3ca = dms['f3ca']
        f3ac = dms['f3ac']
    else:
        if isinstance(nelec, (int, numpy.integer)):
            # For an odd electron count the extra electron belongs to the
            # alpha string (PySCF FCI convention); nelec//2 for both spins
            # would silently drop one electron.
            nelecb = nelec//2
            neleca = nelec - nelecb
        else:
            neleca, nelecb = nelec
        if link_index is None:
            link_indexa = fci.cistring.gen_linkstr_index(range(norb), neleca)
            link_indexb = fci.cistring.gen_linkstr_index(range(norb), nelecb)
        else:
            link_indexa, link_indexb = link_index
        # _contract4pdm is fed the integrals with the two middle axes swapped
        eri = h2e.transpose(0,2,1,3)
        f3ca = _contract4pdm('NEVPTkern_cedf_aedf', eri, civec, norb, nelec,
                             (link_indexa,link_indexb))
        f3ac = _contract4pdm('NEVPTkern_aedf_ecdf', eri, civec, norb, nelec,
                             (link_indexa,link_indexb))

    a22 = -numpy.einsum('pb,kipjac->ijkabc', h1e, dm3)
    a22 -= numpy.einsum('pa,kibjpc->ijkabc', h1e, dm3)
    a22 += numpy.einsum('cp,kibjap->ijkabc', h1e, dm3)
    a22 += numpy.einsum('cqra,kibjqr->ijkabc', h2e, dm3)
    a22 -= numpy.einsum('qcpq,kibjap->ijkabc', h2e, dm3)

    # The 4-pdm term below is replaced by f3ac plus 3-pdm corrections:
    # qjprac = acqjpr + delta(ja)qcpr + delta(ra)qjpc - delta(qc)ajpr - delta(pc)qjar
    #a22 -= numpy.einsum('pqrb,kiqjprac->ijkabc', h2e, dm4)
    a22 -= f3ac.transpose(1,5,0,2,4,3) # c'a'acbb'
    fdm2 = numpy.einsum('pqrb,kiqcpr->ikbc', h2e, dm3)
    # delta(ja) term: subtract fdm2 wherever the 2nd and 4th indices coincide
    for i in range(norb):
        a22[:,i,:,i,:,:] -= fdm2
    a22 -= numpy.einsum('pqab,kiqjpc->ijkabc', h2e, dm3)
    a22 += numpy.einsum('pcrb,kiajpr->ijkabc', h2e, dm3)
    a22 += numpy.einsum('cqrb,kiqjar->ijkabc', h2e, dm3)

    #a22 -= numpy.einsum('pqra,kibjqcpr->ijkabc', h2e, dm4)
    a22 -= f3ac.transpose(1,3,0,4,2,5) # c'a'bb'ac -> a'b'c'abc

    #a22 += numpy.einsum('rcpq,kibjaqrp->ijkabc', h2e, dm4)
    a22 += f3ca.transpose(1,3,0,4,2,5) # c'a'bb'ac -> a'b'c'abc

    a22 += 2.0*numpy.einsum('jb,kiac->ijkabc', h1e, dm2)
    a22 += 2.0*numpy.einsum('pjrb,kiprac->ijkabc', h2e, dm3)
    fdm2  = numpy.einsum('pa,kipc->ikac', h1e, dm2)
    fdm2 -= numpy.einsum('cp,kiap->ikac', h1e, dm2)
    fdm2 -= numpy.einsum('cqra,kiqr->ikac', h2e, dm2)
    fdm2 += numpy.einsum('qcpq,kiap->ikac', h2e, dm2)
    fdm2 += numpy.einsum('pqra,kiqcpr->ikac', h2e, dm3)
    fdm2 -= numpy.einsum('rcpq,kiaqrp->ikac', h2e, dm3)
    # delta(jb) term: add 2*fdm2 wherever the 2nd and 5th indices coincide
    for i in range(norb):
        a22[:,i,:,:,i,:] += fdm2 * 2

    return a22
146
147
def make_a17(h1e,h2e,dm2,dm3):
    '''Four-index intermediate ``a17`` (index order ``abcp``) used in Sr.

    The one-electron matrix is first corrected by the ``mjjn`` trace of
    ``h2e``; the result is contracted with ``dm2`` and ``h2e`` with ``dm3``.
    '''
    einsum = numpy.einsum
    heff = h1e - einsum('mjjn->mn', h2e)

    one_body = einsum('pi,cabi->abcp', heff, dm2)
    two_body = einsum('kpij,cabjki->abcp', h2e, dm3)
    return -one_body - two_body
154
def make_a19(h1e,h2e,dm1,dm2):
    '''Two-index intermediate ``a19`` (index order ``ap``) used in Sr.

    Same construction as ``a17`` but one particle rank lower: the
    trace-corrected one-electron matrix is contracted with ``dm1`` and
    ``h2e`` with ``dm2``.
    '''
    einsum = numpy.einsum
    heff = h1e - einsum('mjjn->mn', h2e)

    one_body = einsum('pi,ai->ap', heff, dm1)
    two_body = einsum('kpij,ajki->ap', h2e, dm2)
    return -one_body - two_body
161
def make_a23(h1e,h2e,dm1,dm2,dm3):
    '''Four-index intermediate ``a23`` (index order ``abcp``) used in Si.'''
    einsum = numpy.einsum
    h1_dm2 = einsum('ip,caib->abcp', h1e, dm2)
    h2_dm3 = einsum('pijk,cajbik->abcp', h2e, dm3)
    h1_dm1 = einsum('bp,ca->abcp', h1e, dm1)
    h2_dm2 = einsum('pibk,caik->abcp', h2e, dm2)

    return -h1_dm2 - h2_dm3 + 2.0*h1_dm1 + 2.0*h2_dm2
169
def make_a25(h1e,h2e,dm1,dm2):
    '''Two-index intermediate ``a25`` (index order ``ap``) used in Si.'''
    einsum = numpy.einsum
    h1_dm1 = einsum('pi,ai->ap', h1e, dm1)
    h2_dm2 = einsum('pijk,jaik->ap', h2e, dm2)
    # transposed view of h1e, added element-wise to the 'ap'-ordered terms
    h1_t = einsum('ap->pa', h1e)
    h2_dm1 = einsum('piaj,ij->ap', h2e, dm1)

    return -h1_dm1 - h2_dm2 + 2.0*h1_t + 2.0*h2_dm1
178
def make_hdm3(dm1,dm2,dm3,hdm1,hdm2):
    '''Build the six-index array ``hdm3`` (index order ``pqrabc``).

    It is assembled from ``dm1``, ``dm2``, ``dm3`` and ``hdm2`` together
    with Kronecker deltas.  ``hdm1`` is accepted but not read.
    '''
    einsum = numpy.einsum
    eye = numpy.eye(dm3.shape[0])
    acc = -einsum('pb,qrac->pqrabc', eye, hdm2)
    acc = acc - einsum('br,pqac->pqrabc', eye, hdm2)
    acc = acc + einsum('bq,prac->pqrabc', eye, hdm2)*2.0
    acc = acc + einsum('ap,bqcr->pqrabc', eye, dm2)*2.0
    acc = acc - einsum('ap,cr,bq->pqrabc', eye, eye, dm1)*4.0
    acc = acc + einsum('cr,bqap->pqrabc', eye, dm2)*2.0
    acc = acc - einsum('bqapcr->pqrabc', dm3)
    acc = acc + einsum('ar,pc,bq->pqrabc', eye, eye, dm1)*2.0
    acc = acc - einsum('ar,bqcp->pqrabc', eye, dm2)
    return acc
191
192
def make_hdm2(dm1,dm2):
    '''Build the four-index array ``hdm2`` from ``dm1`` and ``dm2``.

    ``dm2`` is first reordered (``ikjl->ijkl``) and corrected by a
    delta*dm1 term; the result is then combined with further delta/dm1
    products and pure delta*delta constants.
    '''
    einsum = numpy.einsum
    eye = numpy.eye(dm2.shape[0])
    reordered = einsum('ikjl->ijkl', dm2) - einsum('jk,il->ijkl', eye, dm1)
    acc = einsum('klij->ijkl', reordered)
    acc = acc + einsum('il,kj->ijkl', eye, dm1)
    acc = acc + einsum('jk,li->ijkl', eye, dm1)
    acc = acc - 2.0*einsum('ik,lj->ijkl', eye, dm1)
    acc = acc - 2.0*einsum('jl,ki->ijkl', eye, dm1)
    acc = acc - 2.0*einsum('il,jk->ijkl', eye, eye)
    acc = acc + 4.0*einsum('ik,jl->ijkl', eye, eye)

    return acc
205
def make_hdm1(dm1):
    '''Return ``2*I - dm1^T`` for the square matrix ``dm1``.'''
    identity = numpy.eye(dm1.shape[0])
    return 2.0*identity-dm1.transpose(1,0)
210
def make_a3(h1e,h2e,dm1,dm2,hdm1):
    '''Two-index intermediate ``a3`` (index order ``pa``) used in Sijr.'''
    einsum = numpy.einsum
    eye = numpy.eye(dm2.shape[0])
    h1_hdm1 = einsum('ia,ip->pa', h1e, hdm1)
    h2_dm1 = einsum('ijka,pj,ik->pa', h2e, eye, dm1)
    h2_dm2 = einsum('ijka,jpik->pa', h2e, dm2)
    return h1_hdm1 + 2.0*h2_dm1 - h2_dm2
217
def make_k27(h1e,h2e,dm1,dm2):
    '''Two-index intermediate ``k27`` (index order ``pa``) used in Srsi.'''
    einsum = numpy.einsum
    h1_dm1 = einsum('ai,pi->pa', h1e, dm1)
    h2_dm2 = einsum('iajk,pkij->pa', h2e, dm2)
    h2_dm1 = einsum('iaji,pj->pa', h2e, dm1)
    return -h1_dm1 - h2_dm2 + h2_dm1
223
224
225
def make_a7(h1e,h2e,dm1,dm2,dm3):
    '''Return ``(rm2, a7)`` for the Srs subspace.

    ``rm2`` and the internal ``rm3`` are the 2- and 3-particle density
    matrices rewritten from products of excitation operators into the
    ordering given in the comments below; ``a7`` (index order ``pqab``)
    contracts them with ``h1e`` and ``h2e``.
    '''
    #The incoming dm2 and dm3 are converted to normal order here
    einsum = numpy.einsum
    eye = numpy.eye(dm2.shape[0])
    # a^+_ia^+_ja_ka^l =  E^i_lE^j_k -\delta_{j,l} E^i_k
    rm2 = einsum('iljk->ijkl', dm2) - einsum('ik,jl->ijkl', dm1, eye)
    # E^{i,j,k}_{l,m,n} = E^{i,j}_{m,n}E^k_l -\delta_{k,m}E^{i,j}_{l,n}- \delta_{k,n}E^{i,j}_{m,l}
    # = E^i_nE^j_mE^k_l -\delta_{j,n}E^i_mE^k_l -\delta_{k,m}E^{i,j}_{l,n} -\delta_{k,n}E^{i,j}_{m,l}
    rm3 = einsum('injmkl->ijklmn', dm3)
    rm3 = rm3 - einsum('jn,imkl->ijklmn', eye, dm2)
    rm3 = rm3 - einsum('km,ijln->ijklmn', eye, rm2)
    rm3 = rm3 - einsum('kn,ijml->ijklmn', eye, rm2)

    a7 = -einsum('bi,pqia->pqab', h1e, rm2)
    a7 = a7 - einsum('ai,pqbi->pqab', h1e, rm2)
    a7 = a7 - einsum('kbij,pqkija->pqab', h2e, rm3)
    a7 = a7 - einsum('kaij,pqkibj->pqab', h2e, rm3)
    a7 = a7 - einsum('baij,pqij->pqab', h2e, rm2)
    return rm2, a7
244
def make_a9(h1e,h2e,hdm1,hdm2,hdm3):
    '''Four-index intermediate ``a9`` (index order ``pqab``) used in Sij.

    Built from ``h1e``/``h2e`` contracted with ``hdm2`` and ``hdm3``;
    ``hdm1`` is accepted but not read.
    '''
    einsum = numpy.einsum
    acc =  einsum('ib,pqai->pqab', h1e, hdm2)
    acc += einsum('ijib,pqaj->pqab', h2e, hdm2)*2.0
    acc -= einsum('ijjb,pqai->pqab', h2e, hdm2)
    acc -= einsum('ijkb,pkqaij->pqab', h2e, hdm3)
    acc += einsum('ia,pqib->pqab', h1e, hdm2)
    acc -= einsum('ijja,pqib->pqab', h2e, hdm2)
    acc -= einsum('ijba,pqji->pqab', h2e, hdm2)
    acc += einsum('ijia,pqjb->pqab', h2e, hdm2)*2.0
    acc -= einsum('ijka,pqkjbi->pqab', h2e, hdm3)
    return acc
256
def make_a12(h1e,h2e,dm1,dm2,dm3):
    '''Four-index intermediate ``a12`` (index order ``pqab``) used in Sir.

    ``dm1`` is accepted but not read.
    '''
    einsum = numpy.einsum
    t1 = einsum('ia,qpib->pqab', h1e, dm2)
    t2 = einsum('bi,qpai->pqab', h1e, dm2)
    t3 = einsum('ijka,qpjbik->pqab', h2e, dm3)
    t4 = einsum('kbij,qpajki->pqab', h2e, dm3)
    t5 = einsum('bjka,qpjk->pqab', h2e, dm2)
    t6 = einsum('jbij,qpai->pqab', h2e, dm2)
    return t1 - t2 + t3 - t4 - t5 + t6
265
def make_a13(h1e,h2e,dm1,dm2,dm3):
    '''Four-index intermediate ``a13`` (index order ``pqab``) used in Sir.'''
    einsum = numpy.einsum
    eye = numpy.eye(dm3.shape[0])
    acc = -einsum('ia,qbip->pqab', h1e, dm2)
    acc += einsum('pa,qb->pqab', h1e, dm1)*2.0
    acc += einsum('bi,qiap->pqab', h1e, dm2)
    acc -= einsum('pa,bi,qi->pqab', eye, h1e, dm1)*2.0
    acc -= einsum('ijka,qbjpik->pqab', h2e, dm3)
    acc += einsum('kbij,qjapki->pqab', h2e, dm3)
    acc += einsum('blma,qmlp->pqab', h2e, dm2)
    acc += einsum('kpma,qbkm->pqab', h2e, dm2)*2.0
    acc -= einsum('bpma,qm->pqab', h2e, dm1)*2.0
    acc -= einsum('lbkl,qkap->pqab', h2e, dm2)
    acc -= einsum('ap,mbkl,qlmk->pqab', eye, h2e, dm2)*2.0
    acc += einsum('ap,lbkl,qk->pqab', eye, h2e, dm1)*2.0
    return acc
281
282
def Sr(mc,ci,dms, eris=None, verbose=None):
    '''Contribution of the subspace S_r^{(-1)}.

    Args:
        mc : object providing ``mo_coeff``, ``mo_energy``, ``nelecas``,
            ``fcisolver`` and, when ``eris`` is None, the integral helpers
            ``h1e_for_cas``, ``ao2mo``, ``get_veff``, ``get_hcore``.
        ci : CI vector handed to ``make_a16`` / ``nevpt_intermediate``.
        dms : dict with the density matrices under ``'1'``, ``'2'``, ``'3'``.
        eris : optional integral container with ``'h1eff'`` and ``'ppaa'``.
        verbose : not used here.

    Returns:
        The result of ``_norm_to_energy(norm, ener, virtual orbital energies)``.
    '''
    einsum = numpy.einsum
    mo_core, mo_cas, mo_virt = _extract_orbs(mc, mc.mo_coeff)
    dm1 = dms['1']
    dm2 = dms['2']
    dm3 = dms['3']
    #dm4 = dms['4']
    ncore = mo_core.shape[1]
    ncas = mo_cas.shape[1]
    nocc = ncore + ncas

    if eris is not None:
        act = slice(ncore, nocc)
        h1e = eris['h1eff'][act,act]
        h2e = eris['ppaa'][act,act].transpose(0,2,1,3)
        h2e_v = eris['ppaa'][nocc:,act].transpose(0,2,1,3)
        h1e_v = eris['h1eff'][nocc:,act] - einsum('mbbn->mn',h2e_v)
    else:
        nvirt = mo_virt.shape[1]
        h1e = mc.h1e_for_cas()[0]
        h2e = ao2mo.restore(1, mc.ao2mo(mo_cas), ncas).transpose(0,2,1,3)
        h2e_v = ao2mo.incore.general(mc._scf._eri,[mo_virt,mo_cas,mo_cas,mo_cas],compact=False)
        h2e_v = h2e_v.reshape(nvirt,ncas,ncas,ncas).transpose(0,2,1,3)
        # virtual-active block of hcore plus the potential of the doubly
        # occupied core orbitals
        core_dm = numpy.dot(mo_core,mo_core.T) *2
        core_vhf = mc.get_veff(mc.mol,core_dm)
        h1e_v = reduce(numpy.dot, (mo_virt.T, mc.get_hcore()+core_vhf , mo_cas))
        h1e_v -= einsum('mbbn->mn',h2e_v)

    # A solver exposing nevpt_intermediate supplies A16 itself
    if getattr(mc.fcisolver, 'nevpt_intermediate', None):
        a16 = mc.fcisolver.nevpt_intermediate('A16',ncas,mc.nelecas,ci)
    else:
        a16 = make_a16(h1e,h2e, dms, ci, ncas, mc.nelecas)
    a17 = make_a17(h1e,h2e,dm2,dm3)
    a19 = make_a19(h1e,h2e,dm1,dm2)

    ener_3 = einsum('ipqr,pqrabc,iabc->i',h2e_v,a16,h2e_v)
    ener_2 = einsum('ipqr,pqra,ia->i',h2e_v,a17,h1e_v)
    ener_1 = einsum('ip,pa,ia->i',h1e_v,a19,h1e_v)
    ener = ener_3 + ener_2*2.0 + ener_1

    norm_3 = einsum('ipqr,rpqbac,iabc->i',h2e_v,dm3,h2e_v)
    norm_2 = einsum('ipqr,rpqa,ia->i',h2e_v,dm2,h1e_v)
    norm_1 = einsum('ip,pa,ia->i',h1e_v,dm1,h1e_v)
    norm = norm_3 + norm_2*2.0 + norm_1

    return _norm_to_energy(norm, ener, mc.mo_energy[nocc:])
326
def Si(mc, ci, dms, eris=None, verbose=None):
    '''Contribution of the subspace S_i^{(1)}.

    Args:
        mc : object providing ``mo_coeff``, ``mo_energy``, ``nelecas``,
            ``fcisolver`` and, when ``eris`` is None, the integral helpers
            ``h1e_for_cas``, ``ao2mo``, ``get_veff``, ``get_hcore``.
        ci : CI vector handed to ``make_a22`` / ``nevpt_intermediate``.
        dms : dict with the density matrices under ``'1'``, ``'2'``, ``'3'``.
        eris : optional integral container with ``'h1eff'`` and ``'ppaa'``.
        verbose : not used here.

    Returns:
        The result of ``_norm_to_energy(norm, ener, -core orbital energies)``.
    '''
    einsum = numpy.einsum
    mo_core, mo_cas, mo_virt = _extract_orbs(mc, mc.mo_coeff)
    dm1 = dms['1']
    dm2 = dms['2']
    dm3 = dms['3']
    #dm4 = dms['4']
    ncore = mo_core.shape[1]
    ncas = mo_cas.shape[1]
    nocc = ncore + ncas

    if eris is not None:
        act = slice(ncore, nocc)
        h1e = eris['h1eff'][act,act]
        h2e = eris['ppaa'][act,act].transpose(0,2,1,3)
        h2e_v = eris['ppaa'][act,:ncore].transpose(0,2,1,3)
        h1e_v = eris['h1eff'][act,:ncore]
    else:
        h1e = mc.h1e_for_cas()[0]
        h2e = ao2mo.restore(1, mc.ao2mo(mo_cas), ncas).transpose(0,2,1,3)
        h2e_v = ao2mo.incore.general(mc._scf._eri,[mo_cas,mo_core,mo_cas,mo_cas],compact=False)
        h2e_v = h2e_v.reshape(ncas,ncore,ncas,ncas).transpose(0,2,1,3)
        # active-core block of hcore plus the potential of the doubly
        # occupied core orbitals
        core_dm = numpy.dot(mo_core,mo_core.T) *2
        core_vhf = mc.get_veff(mc.mol,core_dm)
        h1e_v = reduce(numpy.dot, (mo_cas.T, mc.get_hcore()+core_vhf , mo_core))

    # A solver exposing nevpt_intermediate supplies A22 itself
    if getattr(mc.fcisolver, 'nevpt_intermediate', None):
        #mc.fcisolver.make_a22(ncas, state)
        a22 = mc.fcisolver.nevpt_intermediate('A22',ncas,mc.nelecas,ci)
    else:
        a22 = make_a22(h1e,h2e, dms, ci, ncas, mc.nelecas)
    a23 = make_a23(h1e,h2e,dm1,dm2,dm3)
    a25 = make_a25(h1e,h2e,dm1,dm2)

    # hole-type counterparts of dm3, dm2 and dm1 used for the norm
    eye = numpy.eye(ncas)
    dm3_h = einsum('abef,cd->abcdef',dm2,eye)*2 - dm3.transpose(0,1,3,2,4,5)
    dm2_h = einsum('ab,cd->abcd',dm1,eye)*2 - dm2.transpose(0,1,3,2)
    dm1_h = 2*eye- dm1.transpose(1,0)

    ener_3 = einsum('qpir,pqrabc,baic->i',h2e_v,a22,h2e_v)
    ener_2 = einsum('qpir,pqra,ai->i',h2e_v,a23,h1e_v)
    ener_1 = einsum('pi,pa,ai->i',h1e_v,a25,h1e_v)
    ener = ener_3 + ener_2*2.0 + ener_1

    norm_3 = einsum('qpir,rpqbac,baic->i',h2e_v,dm3_h,h2e_v)
    norm_2 = einsum('qpir,rpqa,ai->i',h2e_v,dm2_h,h1e_v)
    norm_1 = einsum('pi,pa,ai->i',h1e_v,dm1_h,h1e_v)
    norm = norm_3 + norm_2*2.0 + norm_1

    return _norm_to_energy(norm, ener, -mc.mo_energy[:ncore])
375
376
def Sijrs(mc, eris, verbose=None):
    '''Norm and energy of the core -> virtual double-excitation subspace.

    Args:
        mc : object providing ``mo_coeff``, ``mo_energy`` and, when ``eris``
            is None, ``mol`` and ``verbose`` for the integral transformation.
        eris : integral container with the key ``'cvcv'``, or None to
            transform the (core,virt|core,virt) integrals into a temporary
            file.
        verbose : not used here.

    Returns:
        ``(norm, e)`` accumulated over all core orbitals ``i``: ``norm``
        sums ``g*(2g - g^T)`` and ``e`` sums ``(g/denominator)*(2g - g^T)``.
    '''
    mo_core, mo_cas, mo_virt = _extract_orbs(mc, mc.mo_coeff)
    ncore = mo_core.shape[1]
    nvirt = mo_virt.shape[1]
    ncas = mo_cas.shape[1]
    nocc = ncore + ncas
    erifile = None
    try:
        if eris is None:
            erifile = tempfile.NamedTemporaryFile(dir=lib.param.TMPDIR)
            feri = ao2mo.outcore.general(mc.mol, (mo_core,mo_virt,mo_core,mo_virt),
                                         erifile.name, verbose=mc.verbose)
        else:
            feri = eris['cvcv']

        # orbital-energy differences e_i - e_a
        eia = mc.mo_energy[:ncore,None] -mc.mo_energy[None,nocc:]
        norm = 0
        e = 0
        with ao2mo.load(feri) as cvcv:
            for i in range(ncore):
                djba = (eia.reshape(-1,1) + eia[i].reshape(1,-1)).ravel()
                # rows i*nvirt:(i+1)*nvirt hold all integrals for core orbital i
                gi = numpy.asarray(cvcv[i*nvirt:(i+1)*nvirt])
                gi = gi.reshape(nvirt,ncore,nvirt).transpose(1,2,0)
                t2i = (gi.ravel()/djba).reshape(ncore,nvirt,nvirt)
                # 2*ijab-ijba
                theta = gi*2 - gi.transpose(0,2,1)
                norm += numpy.einsum('jab,jab', gi, theta)
                e += numpy.einsum('jab,jab', t2i, theta)
    finally:
        # Release (and thereby delete) the scratch file deterministically,
        # including when the transformation or the loop raises.
        if erifile is not None:
            erifile.close()
    return norm, e
404
def Sijr(mc, dms, eris, verbose=None):
    '''Contribution of the subspace S_ijr^{(1)}.

    Args:
        mc : object providing ``mo_coeff``, ``mo_energy`` and, when ``eris``
            is None, ``h1e_for_cas``, ``ao2mo`` and ``_scf._eri``.
        dms : dict with ``'1'`` and ``'2'``; ``'h1'`` is used when present,
            otherwise it is built with ``make_hdm1``.
        eris : integral container with ``'h1eff'``, ``'ppaa'``, ``'pacv'``,
            or None.
        verbose : not used here.

    Returns:
        The result of ``_norm_to_energy`` on the ``i <= j`` part of the
        norm, Hamiltonian and energy-denominator arrays.
    '''
    einsum = numpy.einsum
    mo_core, mo_cas, mo_virt = _extract_orbs(mc, mc.mo_coeff)
    dm1 = dms['1']
    dm2 = dms['2']
    ncore = mo_core.shape[1]
    ncas = mo_cas.shape[1]
    nocc = ncore + ncas
    if eris is not None:
        act = slice(ncore, nocc)
        h1e = eris['h1eff'][act,act]
        h2e = eris['ppaa'][act,act].transpose(0,2,1,3)
        h2e_v = eris['pacv'][:ncore].transpose(3,1,2,0)
    else:
        h1e = mc.h1e_for_cas()[0]
        h2e = ao2mo.restore(1, mc.ao2mo(mo_cas), ncas).transpose(0,2,1,3)
        h2e_v = ao2mo.incore.general(mc._scf._eri,[mo_virt,mo_core,mo_cas,mo_core],compact=False)
        h2e_v = h2e_v.reshape(mo_virt.shape[1],ncore,ncas,ncore).transpose(0,2,1,3)
    hdm1 = dms['h1'] if 'h1' in dms else make_hdm1(dm1)

    a3 = make_a3(h1e,h2e,dm1,dm2,hdm1)
    # We sum norm and h only over i <= j (or j <= i instead).
    # See Eq. (13) and (A2) in https://doi.org/10.1063/1.1515317
    # This implementation is still somewhat wasteful in terms of memory,
    # as we only need about half of norm and h in the end.
    diag_i, diag_j = numpy.diag_indices(ncore)
    triu_i, triu_j = numpy.triu_indices(ncore)

    def _pair_contract(mat):
        # 2*direct - exchange, symmetrised over (j,i); the diagonal is
        # halved so it is not counted twice.
        out = 2.0*einsum('rpji,raji,pa->rji',h2e_v,h2e_v,mat)\
            - 1.0*einsum('rpji,raij,pa->rji',h2e_v,h2e_v,mat)
        out += out.transpose(0, 2, 1)
        out[:, diag_i, diag_j] *= 0.5
        return out

    norm = _pair_contract(hdm1)
    h = _pair_contract(a3)

    mo_e = mc.mo_energy
    diff = mo_e[nocc:,None,None] - mo_e[None,:ncore,None] - mo_e[None,None,:ncore]

    norm_tri = norm[:, triu_i, triu_j]
    h_tri = h[:, triu_i, triu_j]
    diff_tri = diff[:, triu_i, triu_j]
    return _norm_to_energy(norm_tri, h_tri, diff_tri)
449
def Srsi(mc, dms, eris, verbose=None):
    '''Contribution of the subspace S_rsi^{(-1)} (two virtual, one core index).

    Args:
        mc : object providing ``mo_coeff``, ``mo_energy`` and, when ``eris``
            is None, ``h1e_for_cas``, ``ao2mo`` and ``_scf._eri``.
        dms : dict with the density matrices under ``'1'`` and ``'2'``.
        eris : integral container with ``'h1eff'``, ``'ppaa'``, ``'pacv'``,
            or None.
        verbose : not used here.

    Returns:
        The result of ``_norm_to_energy`` on the ``r <= s`` part of the
        norm, Hamiltonian and energy-denominator arrays.
    '''
    einsum = numpy.einsum
    mo_core, mo_cas, mo_virt = _extract_orbs(mc, mc.mo_coeff)
    dm1 = dms['1']
    dm2 = dms['2']
    ncore = mo_core.shape[1]
    ncas = mo_cas.shape[1]
    nocc = ncore + ncas
    nvirt = mo_virt.shape[1]
    if eris is not None:
        act = slice(ncore, nocc)
        h1e = eris['h1eff'][act,act]
        h2e = eris['ppaa'][act,act].transpose(0,2,1,3)
        h2e_v = eris['pacv'][nocc:].transpose(3,0,2,1)
    else:
        h1e = mc.h1e_for_cas()[0]
        h2e = ao2mo.restore(1, mc.ao2mo(mo_cas), ncas).transpose(0,2,1,3)
        h2e_v = ao2mo.incore.general(mc._scf._eri,[mo_virt,mo_core,mo_virt,mo_cas],compact=False)
        h2e_v = h2e_v.reshape(mo_virt.shape[1],ncore,mo_virt.shape[1],ncas).transpose(0,2,1,3)

    k27 = make_k27(h1e,h2e,dm1,dm2)
    # We sum norm and h only over r <= s.
    # See Eq. (12) and (26) in https://doi.org/10.1063/1.1515317
    # This implementation is still somewhat wasteful in terms of memory,
    # as we only need about half of norm and h in the end.
    vi_diag = numpy.diag_indices(nvirt)
    vi_triu = numpy.triu_indices(nvirt)

    def _pair_contract(mat):
        # 2*direct - exchange, symmetrised over (r,s); the diagonal is
        # halved so it is not counted twice.
        out = 2.0*einsum('rsip,rsia,pa->rsi',h2e_v,h2e_v,mat)\
            - 1.0*einsum('rsip,sria,pa->rsi',h2e_v,h2e_v,mat)
        out += out.transpose(1, 0, 2)
        out[vi_diag] *= 0.5
        return out

    norm = _pair_contract(dm1)
    h = _pair_contract(k27)
    mo_e = mc.mo_energy
    diff = mo_e[nocc:,None,None] + mo_e[None,nocc:,None] - mo_e[None,None,:ncore]
    return _norm_to_energy(norm[vi_triu], h[vi_triu], diff[vi_triu])
486
def Srs(mc, dms, eris=None, verbose=None):
    '''Contribution of the subspace S_rs^{(-2)}.

    Args:
        mc : object providing ``mo_coeff``, ``mo_energy`` and, when ``eris``
            is None, ``h1e_for_cas``, ``ao2mo`` and ``_scf._eri``.
        dms : dict with the density matrices under ``'1'``, ``'2'``, ``'3'``.
        eris : optional integral container with ``'h1eff'``, ``'ppaa'``
            and ``'papa'``.
        verbose : not used here.

    Returns:
        ``(0, 0)`` when there are no virtual orbitals, otherwise the result
        of ``_norm_to_energy(norm, h, diff)``.
    '''
    einsum = numpy.einsum
    mo_core, mo_cas, mo_virt = _extract_orbs(mc, mc.mo_coeff)
    dm1 = dms['1']
    dm2 = dms['2']
    dm3 = dms['3']
    ncore = mo_core.shape[1]
    ncas = mo_cas.shape[1]
    nocc = ncore + ncas
    nvirt = mo_virt.shape[1]
    if nvirt ==0:
        return 0, 0
    if eris is not None:
        act = slice(ncore, nocc)
        h1e = eris['h1eff'][act,act]
        h2e = eris['ppaa'][act,act].transpose(0,2,1,3)
        h2e_v = eris['papa'][nocc:,:,nocc:].transpose(0,2,1,3)
    else:
        h1e = mc.h1e_for_cas()[0]
        h2e = ao2mo.restore(1, mc.ao2mo(mo_cas), ncas).transpose(0,2,1,3)
        h2e_v = ao2mo.incore.general(mc._scf._eri,[mo_virt,mo_cas,mo_virt,mo_cas],compact=False)
        h2e_v = h2e_v.reshape(nvirt,ncas,nvirt,ncas).transpose(0,2,1,3)

    # a7 is very sensitive to the accuracy of HF orbital and CI wfn
    rm2, a7 = make_a7(h1e,h2e,dm1,dm2,dm3)
    norm = 0.5*einsum('rsqp,rsba,pqba->rs',h2e_v,h2e_v,rm2)
    h = 0.5*einsum('rsqp,rsba,pqab->rs',h2e_v,h2e_v,a7)
    e_virt = mc.mo_energy[nocc:]
    diff = e_virt[:,None] + e_virt[None,:]
    return _norm_to_energy(norm, h, diff)
514
def Sij(mc, dms, eris, verbose=None):
    '''Contribution of the subspace S_ij^{(+2)} (two core indices).

    Args:
        mc : object providing ``mo_coeff``, ``mo_energy`` and, when ``eris``
            is None, ``h1e_for_cas``, ``ao2mo`` and ``_scf._eri``.
        dms : dict with ``'1'``, ``'2'``, ``'3'``; the entries ``'h1'``,
            ``'h2'``, ``'h3'`` are used when present and otherwise built
            with ``make_hdm1`` / ``make_hdm2`` / ``make_hdm3``.
        eris : integral container with ``'h1eff'``, ``'ppaa'``, ``'papa'``,
            or None.
        verbose : not used here.

    Returns:
        ``(0.0, 0)`` when there are no core orbitals, otherwise the result
        of ``_norm_to_energy(norm, h, -diff)``.
    '''
    einsum = numpy.einsum
    mo_core, mo_cas, mo_virt = _extract_orbs(mc, mc.mo_coeff)
    dm1 = dms['1']
    dm2 = dms['2']
    dm3 = dms['3']
    ncore = mo_core.shape[1]
    ncas = mo_cas.shape[1]
    nocc = ncore + ncas
    if mo_core.size ==0 :
        return 0.0, 0
    if eris is not None:
        act = slice(ncore, nocc)
        h1e = eris['h1eff'][act,act]
        h2e = eris['ppaa'][act,act].transpose(0,2,1,3)
        h2e_v = eris['papa'][:ncore,:,:ncore].transpose(1,3,0,2)
    else:
        h1e = mc.h1e_for_cas()[0]
        h2e = ao2mo.restore(1, mc.ao2mo(mo_cas), ncas).transpose(0,2,1,3)
        h2e_v = ao2mo.incore.general(mc._scf._eri,[mo_cas,mo_core,mo_cas,mo_core],compact=False)
        h2e_v = h2e_v.reshape(ncas,ncore,ncas,ncore).transpose(0,2,1,3)

    # Prefer precomputed hole density matrices, build the missing ones
    hdm1 = dms['h1'] if 'h1' in dms else make_hdm1(dm1)
    hdm2 = dms['h2'] if 'h2' in dms else make_hdm2(dm1,dm2)
    hdm3 = dms['h3'] if 'h3' in dms else make_hdm3(dm1,dm2,dm3,hdm1,hdm2)

    # a9 is very sensitive to the accuracy of HF orbital and CI wfn
    a9 = make_a9(h1e,h2e,hdm1,hdm2,hdm3)
    norm = 0.5*einsum('qpij,baij,pqab->ij',h2e_v,h2e_v,hdm2)
    h = 0.5*einsum('qpij,baij,pqab->ij',h2e_v,h2e_v,a9)
    e_core = mc.mo_energy[:ncore]
    diff = e_core[:,None] + e_core[None,:]
    return _norm_to_energy(norm, h, -diff)
555
556
def Sir(mc, dms, eris, verbose=None):
    '''Contribution of the subspace S_ir^{(0)} (one core, one virtual index).

    Args:
        mc : object providing ``mo_coeff``, ``mo_energy`` and, when ``eris``
            is None, ``h1e_for_cas``, ``ao2mo``, ``get_veff``, ``get_hcore``
            and ``_scf._eri``.
        dms : dict with the density matrices under ``'1'``, ``'2'``, ``'3'``.
        eris : integral container with ``'h1eff'``, ``'ppaa'``, ``'papa'``,
            or None to transform the integrals here.
        verbose : not used here.

    Returns:
        The result of ``_norm_to_energy(norm, h, -diff)``.
    '''
    mo_core, mo_cas, mo_virt = _extract_orbs(mc, mc.mo_coeff)
    dm1 = dms['1']
    dm2 = dms['2']
    dm3 = dms['3']
    ncore = mo_core.shape[1]
    ncas = mo_cas.shape[1]
    nocc = ncore + ncas
    if eris is None:
        h1e = mc.h1e_for_cas()[0]
        h2e = ao2mo.restore(1, mc.ao2mo(mo_cas), ncas).transpose(0,2,1,3)
        h2e_v1 = ao2mo.incore.general(mc._scf._eri,[mo_virt,mo_core,mo_cas,mo_cas],compact=False)
        h2e_v1 = h2e_v1.reshape(mo_virt.shape[1],ncore,ncas,ncas).transpose(0,2,1,3)
        h2e_v2 = ao2mo.incore.general(mc._scf._eri,[mo_virt,mo_cas,mo_cas,mo_core],compact=False)
        h2e_v2 = h2e_v2.reshape(mo_virt.shape[1],ncas,ncas,ncore).transpose(0,2,1,3)
        # Virtual-core block of hcore plus the core-orbital potential, built
        # the same way Sr and Si build theirs.  Without it h1e_v is undefined
        # on this path and the norm expression raises NameError.
        core_dm = numpy.dot(mo_core,mo_core.T) *2
        core_vhf = mc.get_veff(mc.mol,core_dm)
        h1e_v = reduce(numpy.dot, (mo_virt.T, mc.get_hcore()+core_vhf , mo_core))
    else:
        h1e = eris['h1eff'][ncore:nocc,ncore:nocc]
        h2e = eris['ppaa'][ncore:nocc,ncore:nocc].transpose(0,2,1,3)
        h2e_v1 = eris['ppaa'][nocc:,:ncore].transpose(0,2,1,3)
        h2e_v2 = eris['papa'][nocc:,:,:ncore].transpose(0,3,1,2)
        h1e_v = eris['h1eff'][nocc:,:ncore]

    norm = numpy.einsum('rpiq,raib,qpab->ir',h2e_v1,h2e_v1,dm2)*2.0\
         - numpy.einsum('rpiq,rabi,qpab->ir',h2e_v1,h2e_v2,dm2)\
         - numpy.einsum('rpqi,raib,qpab->ir',h2e_v2,h2e_v1,dm2)\
         + numpy.einsum('raqi,rabi,qb->ir',h2e_v2,h2e_v2,dm1)*2.0\
         - numpy.einsum('rpqi,rabi,qbap->ir',h2e_v2,h2e_v2,dm2)\
         + numpy.einsum('rpqi,raai,qp->ir',h2e_v2,h2e_v2,dm1)\
         + numpy.einsum('rpiq,ri,qp->ir',h2e_v1,h1e_v,dm1)*4.0\
         - numpy.einsum('rpqi,ri,qp->ir',h2e_v2,h1e_v,dm1)*2.0\
         + numpy.einsum('ri,ri->ir',h1e_v,h1e_v)*2.0

    a12 = make_a12(h1e,h2e,dm1,dm2,dm3)
    a13 = make_a13(h1e,h2e,dm1,dm2,dm3)

    h = numpy.einsum('rpiq,raib,pqab->ir',h2e_v1,h2e_v1,a12)*2.0\
         - numpy.einsum('rpiq,rabi,pqab->ir',h2e_v1,h2e_v2,a12)\
         - numpy.einsum('rpqi,raib,pqab->ir',h2e_v2,h2e_v1,a12)\
         + numpy.einsum('rpqi,rabi,pqab->ir',h2e_v2,h2e_v2,a13)
    # diff[i, r] = e_i - e_r; its negative is handed on as the denominator
    diff = mc.mo_energy[:ncore,None] - mc.mo_energy[None,nocc:]
    return _norm_to_energy(norm, h, -diff)
599
600
601class NEVPT(lib.StreamObject):
602    '''Strongly contracted NEVPT2
603
604    Attributes:
605        root : int
606            To control which state to compute if multiple roots or state-average
607            wfn were calculated in CASCI/CASSCF
608        compressed_mps : bool
609            compressed MPS perturber method for DMRG-SC-NEVPT2
610
611    Examples:
612
613    >>> mf = gto.M('N 0 0 0; N 0 0 1.4', basis='6-31g').apply(scf.RHF).run()
614    >>> mc = mcscf.CASSCF(mf, 4, 4).run()
615    >>> NEVPT(mc).kernel()
616    -0.14058324991532101
617    '''
618    def __init__(self, mc, root=0):
619        self.__dict__.update(mc.__dict__)
620        self.ncore = mc.ncore
621        self._mc = mc
622        self.root = root
623        self.compressed_mps = False
624
625##################################################
626# don't modify the following attributes, they are not input options
627        self.e_corr = None
628        self.canonicalized = False
629        nao, nmo = mc.mo_coeff.shape
630        self.onerdm = numpy.zeros((nao,nao))
631        self._keys = set(self.__dict__.keys())
632
633    def reset(self, mol=None):
634        if mol is not None:
635            self.mol = mol
636        self._mc.reset(mol)
637        return self
638
639    def get_hcore(self):
640        return self._mc.get_hcore()
641
642    def canonicalize(self, mo_coeff=None, ci=None, eris=None, sort=False,
643                      cas_natorb=False, casdm1=None, verbose=logger.NOTE):
644        return self._mc.canonicalize(mo_coeff, ci, eris, sort, cas_natorb, casdm1, verbose)
645
646    def get_veff(self, mol=None, dm=None, hermi=1):
647        return self._mc.get_veff(mol, dm, hermi)
648
649    def h1e_for_cas(self, mo_coeff=None, ncas=None, ncore=None):
650        return self._mc.h1e_for_cas(mo_coeff, ncas, ncore)
651
652    def load_ci(self, root=None):
653        '''Hack me to load CI wfn from disk'''
654        if root is None:
655            root = self.root
656        if self.fcisolver.nroots == 1:
657            return self.ci
658        else:
659            return self.ci[root]
660
661    def for_dmrg(self):
662        '''Some preprocess for dmrg-nevpt'''
663        if not self._mc.natorb:
664            logger.warn(self, '''\
665DRMG-MCSCF orbitals are not natural orbitals in active space. It's recommended
666to rerun DMRG-CASCI with mc.natorb before calling DMRG-NEVPT2.
667See discussions in github issue https://github.com/pyscf/pyscf/issues/698 and
668example examples/dmrg/32-dmrg_casscf_nevpt2_for_FeS.py''')
669        return self
670
671    def compress_approx(self,maxM=500, nevptsolver=None, tol=1e-7, stored_integral =False):
672        '''SC-NEVPT2 with compressed perturber
673
674        Kwargs :
675            maxM : int
676                DMRG bond dimension
677
678        Examples:
679
680        >>> mf = gto.M('N 0 0 0; N 0 0 1.4', basis='6-31g').apply(scf.RHF).run()
681        >>> mc = dmrgscf.DMRGSCF(mf, 4, 4).run()
682        >>> NEVPT(mc, root=0).compress_approx(maxM=100).kernel()
683        -0.14058324991532101
684
685        References:
686
687        J. Chem. Theory Comput. 12, 1583 (2016), doi:10.1021/acs.jctc.5b01225
688
689        J. Chem. Phys. 146, 244102 (2017), doi:10.1063/1.4986975
690        '''
691        #TODO
692        #Some preprocess for compressed perturber
693        if getattr(self.fcisolver, 'nevpt_intermediate', None):
694            logger.info(self, 'Use compressed mps perturber as an approximation')
695        else:
696            msg = 'Compressed mps perturber can be only used with DMRG wave function'
697            logger.error(self, msg)
698            raise RuntimeError(msg)
699
700        self.nevptsolver = nevptsolver
701        self.maxM = maxM
702        self.tol = tol
703        self.stored_integral = stored_integral
704
705        self.canonicalized = True
706        self.compressed_mps = True
707        self.for_dmrg()
708        return self
709
710
711
712    def kernel(self):
713        from pyscf.mcscf.addons import StateAverageFCISolver
714        if isinstance(self.fcisolver, StateAverageFCISolver):
715            raise RuntimeError('State-average FCI solver object cannot be used '
716                               'in NEVPT2 calculation.\nA separated multi-root '
717                               'CASCI calculation is required for NEVPT2 method. '
718                               'See examples/mrpt/41-for_state_average.py.')
719
720        if getattr(self._mc, 'frozen', None) is not None:
721            raise NotImplementedError
722
723        if isinstance(self.verbose, logger.Logger):
724            log = self.verbose
725        else:
726            log = logger.Logger(self.stdout, self.verbose)
727        time0 = (logger.process_clock(), logger.perf_counter())
728        ncore = self.ncore
729        ncas = self.ncas
730        nocc = ncore + ncas
731
732        #By defaut, _mc is canonicalized for the first root.
733        #For SC-NEVPT based on compressed MPS perturber functions, _mc was already canonicalized.
734        if (not self.canonicalized):
735            # Need to assign roots differently if we have more than one root
736            # See issue #1081 (https://github.com/pyscf/pyscf/issues/1081) for more details
737            self.mo_coeff, single_ci_vec, self.mo_energy = self.canonicalize(
738                self.mo_coeff, ci=self.load_ci(), cas_natorb=True, verbose=self.verbose)
739            if self.fcisolver.nroots == 1:
740                self.ci = single_ci_vec
741            else:
742                self.ci[self.root] = single_ci_vec
743
744        if getattr(self.fcisolver, 'nevpt_intermediate', None):
745            logger.info(self, 'DMRG-NEVPT')
746            dm1, dm2, dm3 = self.fcisolver._make_dm123(self.load_ci(),ncas,self.nelecas,None)
747        else:
748            dm1, dm2, dm3 = fci.rdm.make_dm123('FCI3pdm_kern_sf',
749                                               self.load_ci(), self.load_ci(), ncas, self.nelecas)
750        dm4 = None
751
752        dms = {
753            '1': dm1, '2': dm2, '3': dm3, '4': dm4,
754            # 'h1': hdm1, 'h2': hdm2, 'h3': hdm3
755        }
756        time1 = log.timer('3pdm, 4pdm', *time0)
757
758        eris = _ERIS(self, self.mo_coeff)
759        time1 = log.timer('integral transformation', *time1)
760
761        if not getattr(self.fcisolver, 'nevpt_intermediate', None):  # regular FCI solver
762            link_indexa = fci.cistring.gen_linkstr_index(range(ncas), self.nelecas[0])
763            link_indexb = fci.cistring.gen_linkstr_index(range(ncas), self.nelecas[1])
764            aaaa = eris['ppaa'][ncore:nocc,ncore:nocc].copy()
765            f3ca = _contract4pdm('NEVPTkern_cedf_aedf', aaaa, self.load_ci(), ncas,
766                                 self.nelecas, (link_indexa,link_indexb))
767            f3ac = _contract4pdm('NEVPTkern_aedf_ecdf', aaaa, self.load_ci(), ncas,
768                                 self.nelecas, (link_indexa,link_indexb))
769            dms['f3ca'] = f3ca
770            dms['f3ac'] = f3ac
771        time1 = log.timer('eri-4pdm contraction', *time1)
772
773        if self.compressed_mps:
774            from pyscf.dmrgscf.nevpt_mpi import DMRG_COMPRESS_NEVPT
775            if self.stored_integral: #Stored perturbation integral and read them again. For debugging purpose.
776                perturb_file = DMRG_COMPRESS_NEVPT(self, maxM=self.maxM, root=self.root,
777                                                   nevptsolver=self.nevptsolver,
778                                                   tol=self.tol,
779                                                   nevpt_integral='nevpt_perturb_integral')
780            else:
781                perturb_file = DMRG_COMPRESS_NEVPT(self, maxM=self.maxM, root=self.root,
782                                                   nevptsolver=self.nevptsolver,
783                                                   tol=self.tol)
784            fh5 = h5py.File(perturb_file, 'r')
785            e_Si     =   fh5['Vi/energy'][()]
            # The definition of the norm has changed.
            # However, there is no need to print it out;
            # only the perturbation energy is wanted.
789            norm_Si  =   fh5['Vi/norm'][()]
790            e_Sr     =   fh5['Vr/energy'][()]
791            norm_Sr  =   fh5['Vr/norm'][()]
792            fh5.close()
793            logger.note(self, "Sr    (-1)',   E = %.14f",  e_Sr  )
794            logger.note(self, "Si    (+1)',   E = %.14f",  e_Si  )
795
796        else:
797            norm_Sr   , e_Sr    = Sr(self, self.load_ci(), dms, eris)
798            logger.note(self, "Sr    (-1)',   E = %.14f",  e_Sr  )
799            time1 = log.timer("space Sr (-1)'", *time1)
800            norm_Si   , e_Si    = Si(self, self.load_ci(), dms, eris)
801            logger.note(self, "Si    (+1)',   E = %.14f",  e_Si  )
802            time1 = log.timer("space Si (+1)'", *time1)
803        norm_Sijrs, e_Sijrs = Sijrs(self, eris)
804        logger.note(self, "Sijrs (0)  ,   E = %.14f", e_Sijrs)
805        time1 = log.timer('space Sijrs (0)', *time1)
806        norm_Sijr , e_Sijr  = Sijr(self, dms, eris)
807        logger.note(self, "Sijr  (+1) ,   E = %.14f",  e_Sijr)
808        time1 = log.timer('space Sijr (+1)', *time1)
809        norm_Srsi , e_Srsi  = Srsi(self, dms, eris)
810        logger.note(self, "Srsi  (-1) ,   E = %.14f",  e_Srsi)
811        time1 = log.timer('space Srsi (-1)', *time1)
812        norm_Srs  , e_Srs   = Srs(self, dms, eris)
813        logger.note(self, "Srs   (-2) ,   E = %.14f",  e_Srs )
814        time1 = log.timer('space Srs (-2)', *time1)
815        norm_Sij  , e_Sij   = Sij(self, dms, eris)
816        logger.note(self, "Sij   (+2) ,   E = %.14f",  e_Sij )
817        time1 = log.timer('space Sij (+2)', *time1)
818        norm_Sir  , e_Sir   = Sir(self, dms, eris)
819        logger.note(self, "Sir   (0)' ,   E = %.14f",  e_Sir )
820        time1 = log.timer("space Sir (0)'", *time1)
821
822        nevpt_e  = e_Sr + e_Si + e_Sijrs + e_Sijr + e_Srsi + e_Srs + e_Sij + e_Sir
823        logger.note(self, "Nevpt2 Energy = %.15f", nevpt_e)
824        log.timer('SC-NEVPT2', *time0)
825
826        self.e_corr = nevpt_e
827        return nevpt_e
828
829
def kernel(mc, *args, **kwargs):
    '''Module-level entry point.

    Every positional and keyword argument is forwarded untouched to
    :func:`sc_nevpt`, and whatever that function returns is handed back.
    '''
    result = sc_nevpt(mc, *args, **kwargs)
    return result
832
def sc_nevpt(mc, ci=None, verbose=None):
    '''Deprecated functional interface to SC-NEVPT2.

    Emits a deprecation warning (plus a second warning when ``ci`` is given,
    because that argument is not used) and then returns the result of
    ``NEVPT(mc).kernel()``.

    Args:
        mc : object passed as the only argument to ``NEVPT``.
        ci : ignored; only triggers a warning when it is not None.
        verbose : accepted for interface compatibility; not used here.
    '''
    import warnings

    deprecation_text = ('API updates: function sc_nevpt is deprecated feature. '
                        'It will be removed in future release.\n'
                        'It is recommended to run NEVPT2 with new function '
                        'mrpt.NEVPT(mc).kernel()')
    unused_ci_text = ('API updates: The kwarg "ci" has no effects. '
                      'Use mrpt.NEVPT(mc,root=?) for excited state.')
    ci_was_given = ci is not None

    # The filter change is confined to this context; the caller's warning
    # configuration is restored on exit.
    with warnings.catch_warnings():
        warnings.simplefilter("once")
        warnings.warn(deprecation_text)
        if ci_was_given:
            warnings.warn(unused_ci_text)

    solver = NEVPT(mc)
    return solver.kernel()
845
846
# Register NEVPT2 in MCSCF: the module-level name NEVPT is stored as the
# class attribute ``NEVPT2`` of casci.CASCI.
# NOTE(review): casci is imported here, after the definitions above, rather
# than at the top of the file -- presumably to avoid a circular import;
# confirm before moving it.
from pyscf.mcscf import casci
casci.CASCI.NEVPT2 = NEVPT
850
851
852
853
854
855
856
def _contract4pdm(kern, eri, civec, norb, nelec, link_index=None):
    '''Run the libmcscf routine ``NEVPTcontract`` with kernel ``kern`` and
    return the post-processed 6-index result.

    Args:
        kern : str
            Name of the kernel symbol; it is looked up with
            ``getattr(libmc, kern)``.
        eri : array
            Integrals handed to the C routine. Converted to a C-contiguous
            array first.
        civec : array
            CI vector handed to the C routine. Converted to a C-contiguous
            array first, because only its raw data pointer is passed on.
        norb : int
            Number of orbitals; fixes the shape of the work arrays.
        nelec : int or (neleca, nelecb)
            Only used to build the link tables when ``link_index`` is None.
            An integer is split so that an odd total keeps the extra
            electron in the alpha string.
        link_index : (link_indexa, link_indexb), optional
            Pre-built link tables. Generated with
            ``fci.cistring.gen_linkstr_index`` when omitted.

    Returns:
        numpy.ndarray of shape ``(norb,)*6``.
    '''
    if isinstance(nelec, (int, numpy.integer)):
        # nelec//2 for both spins would silently drop one electron when the
        # total is odd; give the remainder to alpha instead.
        nelecb = nelec // 2
        neleca = nelec - nelecb
    else:
        neleca, nelecb = nelec
    if link_index is None:
        link_indexa = fci.cistring.gen_linkstr_index(range(norb), neleca)
        link_indexb = fci.cistring.gen_linkstr_index(range(norb), nelecb)
    else:
        link_indexa, link_indexb = link_index
    na,nlinka = link_indexa.shape[:2]
    nb,nlinkb = link_indexb.shape[:2]
    fdm2 = numpy.empty((norb,norb,norb,norb))
    fdm3 = numpy.empty((norb,norb,norb,norb,norb,norb))
    # The C side only sees raw pointers, so both inputs must be laid out
    # contiguously in C order (no copy is made when they already are).
    eri = numpy.ascontiguousarray(eri)
    civec = numpy.asarray(civec, order='C')

    libmc.NEVPTcontract(getattr(libmc, kern),
                        fdm2.ctypes.data_as(ctypes.c_void_p),
                        fdm3.ctypes.data_as(ctypes.c_void_p),
                        eri.ctypes.data_as(ctypes.c_void_p),
                        civec.ctypes.data_as(ctypes.c_void_p),
                        ctypes.c_int(norb),
                        ctypes.c_int(na), ctypes.c_int(nb),
                        ctypes.c_int(nlinka), ctypes.c_int(nlinkb),
                        link_indexa.ctypes.data_as(ctypes.c_void_p),
                        link_indexb.ctypes.data_as(ctypes.c_void_p))
    # Fill the j < i blocks from the i > j ones:
    #   fdm3[j,a,i,b,c,d]  = fdm3[i,b,j,a,c,d]
    #   fdm3[j,i,i,b,c,d] += fdm2[j,b,c,d]
    #   fdm3[j,a,i,j,c,d] -= fdm2[i,a,c,d]
    # The three statements act on overlapping slices, so their order matters.
    # NOTE(review): presumably the C kernel only fills the i >= j part of
    # fdm3 -- confirm against the libmcscf source.
    for i in range(norb):
        for j in range(i):
            fdm3[j,:,i] = fdm3[i,:,j].transpose(1,0,2,3)
            fdm3[j,i,i,:] += fdm2[j,:]
            fdm3[j,:,i,j] -= fdm2[i,:]
    return fdm3
889
def _extract_orbs(mc, mo_coeff):
    '''Slice the columns of ``mo_coeff`` into three consecutive blocks.

    The first ``mc.ncore`` columns form the core block, the next ``mc.ncas``
    columns the active block, and all remaining columns the virtual block.

    Returns:
        tuple ``(mo_core, mo_cas, mo_vir)``
    '''
    first_active = mc.ncore
    first_virtual = first_active + mc.ncas
    column_ranges = ((None, first_active),
                     (first_active, first_virtual),
                     (first_virtual, None))
    core_block, active_block, virtual_block = (
        mo_coeff[:, start:stop] for start, stop in column_ranges)
    return core_block, active_block, virtual_block
898
899
def _norm_to_energy(norm, h, diff):
    '''Turn per-element norms into a total norm and an energy contribution.

    Elements with ``abs(norm) <= NUMERICAL_ZERO`` are left out of the energy
    sum (which also avoids dividing by them) but still count in the total
    norm.

    Returns:
        tuple ``(norm_t, ener_t)`` with ``norm_t = norm.sum()`` and
        ``ener_t = -sum(norm / (diff + h/norm))`` over the kept elements.
    '''
    keep = abs(norm) > NUMERICAL_ZERO
    kept_norm = norm[keep]
    denominator = diff[keep] + h[keep] / kept_norm
    ener_t = -(kept_norm / denominator).sum()
    return norm.sum(), ener_t
905
def _ERIS(mc, mo, method='incore'):
    '''Collect the MO integral blocks used by the NEVPT2 routines.

    The in-core transformation (:func:`trans_e1_incore`) is used when
    ``method == 'incore'``, ``mc._scf._eri`` is available and the estimated
    memory fits below 90% of ``mc.max_memory`` -- or, independently of those
    three tests, whenever ``mc.mol.incore_anyway`` is set. Otherwise
    :func:`trans_e1_outcore` is used.

    Returns:
        dict with keys ``'vhf_c'``, ``'ppaa'``, ``'papa'``, ``'pacv'``,
        ``'cvcv'`` and ``'h1eff'``.
    '''
    ncore, ncas = mc.ncore, mc.ncas
    nmo = mo.shape[1]

    mem_incore, mem_outcore, mem_basic = mc_ao2mo._mem_usage(ncore, ncas, nmo)
    mem_now = lib.current_memory()[0]

    # Same grouping as Python's precedence gives: (A and B and C) or D.
    incore_possible = (method == 'incore' and mc._scf._eri is not None and
                       (mem_incore+mem_now < mc.max_memory*.9))
    if incore_possible or mc.mol.incore_anyway:
        ppaa, papa, pacv, cvcv = trans_e1_incore(mc, mo)
    else:
        available = max(2000, mc.max_memory-mem_now)
        ppaa, papa, pacv, cvcv = trans_e1_outcore(mc, mo,
                                                  max_memory=available,
                                                  verbose=mc.verbose)

    # Coulomb/exchange potential of the core density, in the MO basis.
    core_orbs = mo[:,:ncore]
    dmcore = numpy.dot(core_orbs, core_orbs.T)
    vj, vk = mc._scf.get_jk(mc.mol, dmcore)
    vhfcore = numpy.dot(numpy.dot(mo.T, vj*2-vk), mo)

    h1eff = numpy.dot(numpy.dot(mo.T, mc.get_hcore()), mo) + vhfcore
    return {
        'vhf_c': vhfcore,
        'ppaa': ppaa,
        'papa': papa,
        'pacv': pacv,
        'cvcv': cvcv,
        'h1eff': h1eff,
    }
935
936# see mcscf.mc_ao2mo
def trans_e1_incore(mc, mo):
    '''In-core integral transformation for NEVPT2.

    The AO integrals are half-transformed with ``ao2mo.incore.half_e1``
    (first index over ``mo[:,:nocc]``, second over ``mo[:,ncore:]``) and the
    result is passed row-block by row-block to :func:`_trans`.

    Args:
        mc : object providing ``_scf._eri``, ``mol``, ``ncore`` and ``ncas``.
        mo : 2D array of MO coefficients; ``mo.shape[1]`` is the number of
            MOs.

    Returns:
        tuple ``(ppaa, papa, pacv, cvcv)`` as returned by :func:`_trans`.
    '''
    eri_ao = mc._scf._eri
    if eri_ao is None:
        # _ERIS sends the calculation here whenever mol.incore_anyway is set,
        # even if the SCF object did not cache the AO integrals. Build them
        # instead of handing None to half_e1.
        eri_ao = mc.mol.intor('int2e', aosym='s8')
    ncore = mc.ncore
    ncas = mc.ncas
    nmo = mo.shape[1]
    nocc = ncore + ncas
    nav = nmo - ncore
    eri1 = ao2mo.incore.half_e1(eri_ao, (mo[:,:nocc],mo[:,ncore:]),
                                compact=False)

    def load_buf(r0, r1):
        # Each first-index orbital owns nav consecutive rows of eri1.
        return eri1[r0*nav:r1*nav]

    ppaa, papa, pacv, cvcv = _trans(mo, ncore, ncas, load_buf)
    return ppaa, papa, pacv, cvcv
949
def trans_e1_outcore(mc, mo, max_memory=None, ioblk_size=256, tmpdir=None,
                     verbose=0):
    '''Out-of-core integral transformation for NEVPT2.

    ``ao2mo.outcore.half_e1`` writes the half-transformed integrals to a
    temporary HDF5 swap file; :func:`_trans` then reads them back block by
    block and writes the ``cvcv`` block into a second temporary HDF5 file.

    Args:
        mc : object providing ``mol``, ``stdout``, ``ncore`` and ``ncas``.
        mo : 2D array of MO coefficients with shape ``(nao, nmo)``.
        max_memory, ioblk_size : forwarded to ``ao2mo.outcore.half_e1``.
        tmpdir : directory for the temporary files; defaults to
            ``lib.param.TMPDIR``.
        verbose : verbosity used for the logger handed to ``half_e1``.

    Returns:
        tuple ``(ppaa, papa, pacv, cvcvfile)``. ``cvcvfile`` is the
        ``tempfile.NamedTemporaryFile`` whose HDF5 content holds the dataset
        ``'eri_mo'`` of shape ``(ncore*nvir, ncore*nvir)``.
    '''
    time0 = (logger.process_clock(), logger.perf_counter())
    mol = mc.mol
    log = logger.Logger(mc.stdout, verbose)
    ncore = mc.ncore
    ncas = mc.ncas
    nao, nmo = mo.shape
    nao_pair = nao*(nao+1)//2
    nocc = ncore + ncas
    nvir = nmo - nocc
    nav = nmo - ncore

    if tmpdir is None:
        tmpdir = lib.param.TMPDIR
    swapfile = tempfile.NamedTemporaryFile(dir=tmpdir)
    ao2mo.outcore.half_e1(mol, (mo[:,:nocc],mo[:,ncore:]), swapfile.name,
                          max_memory=max_memory, ioblk_size=ioblk_size,
                          verbose=log, compact=False)

    # The context manager closes the swap file even when the transformation
    # below raises; previously the handle was only closed on success.
    with h5py.File(swapfile.name, 'r') as fswap:
        klaoblks = len(fswap['0'])

        def load_buf(r0,r1):
            # time1 is bound further down, before _trans first calls this
            # closure; it is a list so it can be updated in place.
            if mol.verbose >= logger.DEBUG1:
                time1[:] = logger.timer(mol, 'between load_buf',
                                        *tuple(time1))
            buf = numpy.empty(((r1-r0)*nav,nao_pair))
            col0 = 0
            # Datasets '0/0', '0/1', ... are laid side by side along the
            # column axis of buf.
            for ic in range(klaoblks):
                dat = fswap['0/%d'%ic]
                col1 = col0 + dat.shape[1]
                buf[:,col0:col1] = dat[r0*nav:r1*nav]
                col0 = col1
            if mol.verbose >= logger.DEBUG1:
                time1[:] = logger.timer(mol, 'load_buf', *tuple(time1))
            return buf

        time0 = logger.timer(mol, 'halfe1', *time0)
        time1 = [logger.process_clock(), logger.perf_counter()]
        ao_loc = numpy.array(mol.ao_loc_nr(), dtype=numpy.int32)
        cvcvfile = tempfile.NamedTemporaryFile(dir=tmpdir)
        with h5py.File(cvcvfile.name, 'w') as f5:
            cvcv = f5.create_dataset('eri_mo', (ncore*nvir,ncore*nvir), 'f8')
            # The cvcv dataset is filled in place; only the three in-memory
            # arrays are kept from the return value.
            ppaa, papa, pacv = _trans(mo, ncore, ncas, load_buf, cvcv, ao_loc)[:3]
        time0 = logger.timer(mol, 'trans_cvcv', *time0)
    return ppaa, papa, pacv, cvcvfile
996
def _trans(mo, ncore, ncas, fload, cvcv=None, ao_loc=None):
    '''Second half of the integral transformation, producing the blocks
    ppaa, papa, pacv and cvcv.

    Args:
        mo : 2D array of MO coefficients with shape (nao, nmo).
        ncore : int
            Number of core orbitals.
        ncas : int
            Number of active orbitals.
        fload : callable ``fload(r0, r1)``
            Returns the half-transformed integrals for first-index orbitals
            r0..r1-1. In both callers in this file it returns rows
            ``r0*nav:r1*nav``, i.e. nav = nmo - ncore rows per orbital.
        cvcv : optional 2D output of shape (ncore*nvir, ncore*nvir)
            Filled in place by row slices. A zero-initialized numpy array is
            allocated when it is None; trans_e1_outcore passes an h5py
            dataset instead.
        ao_loc : forwarded unchanged to ``_ao2mo.nr_e2``.

    Returns:
        tuple ``(ppaa, papa, pacv, cvcv)`` with shapes
        (nmo,nmo,ncas,ncas), (nmo,ncas,nmo,ncas), (nmo,ncas,ncore,nvir) and
        (ncore*nvir, ncore*nvir).
    '''
    nao, nmo = mo.shape
    nocc = ncore + ncas
    nvir = nmo - nocc
    nav = nmo - ncore

    if cvcv is None:
        cvcv = numpy.zeros((ncore*nvir,ncore*nvir))
    # Result arrays, with the last two indices still flattened.
    pacv = numpy.empty((nmo,ncas,ncore*nvir))
    aapp = numpy.empty((ncas,ncas,nmo*nmo))
    papa = numpy.empty((nmo,ncas,nmo*ncas))
    # Scratch buffers reused as the `out` argument of nr_e2 in every
    # iteration; each must be copied into a result array before the next
    # nr_e2 call overwrites it.
    vcv = numpy.empty((nav,ncore*nvir))
    apa = numpy.empty((ncas,nmo*ncas))
    vpa = numpy.empty((nav,nmo*ncas))
    app = numpy.empty((ncas,nmo*nmo))
    # First index runs over the core orbitals.
    for i in range(ncore):
        buf = fload(i, i+1)
        # NOTE(review): klshape presumably selects the MO ranges
        # [0,ncore) x [nocc,nmo) for the last two indices, which matches the
        # ncore*nvir width of vcv -- confirm against _ao2mo.nr_e2.
        klshape = (0, ncore, nocc, nmo)
        _ao2mo.nr_e2(buf, mo, klshape,
                      aosym='s4', mosym='s1', out=vcv, ao_loc=ao_loc)
        # vcv has nav = ncas + nvir rows: the first ncas go to pacv, the
        # remaining nvir rows to cvcv.
        cvcv[i*nvir:(i+1)*nvir] = vcv[ncas:]
        pacv[i] = vcv[:ncas]

        klshape = (0, nmo, ncore, nocc)
        _ao2mo.nr_e2(buf[:ncas], mo, klshape,
                      aosym='s4', mosym='s1', out=apa, ao_loc=ao_loc)
        papa[i] = apa
    # First index runs over the active orbitals.
    for i in range(ncas):
        buf = fload(ncore+i, ncore+i+1)
        klshape = (0, ncore, nocc, nmo)
        _ao2mo.nr_e2(buf, mo, klshape,
                      aosym='s4', mosym='s1', out=vcv, ao_loc=ao_loc)
        # Rows ncore: of pacv/papa are filled column by column here; rows
        # :ncore were filled in the core loop above.
        pacv[ncore:,i] = vcv

        klshape = (0, nmo, ncore, nocc)
        _ao2mo.nr_e2(buf, mo, klshape,
                      aosym='s4', mosym='s1', out=vpa, ao_loc=ao_loc)
        papa[ncore:,i] = vpa

        klshape = (0, nmo, 0, nmo)
        _ao2mo.nr_e2(buf[:ncas], mo, klshape,
                      aosym='s4', mosym='s1', out=app, ao_loc=ao_loc)
        aapp[i] = app
    # (ncas*ncas, nmo*nmo) -> (nmo*nmo, ncas*ncas)
    ppaa = lib.transpose(aapp.reshape(ncas**2,-1))
    return (ppaa.reshape(nmo,nmo,ncas,ncas), papa.reshape(nmo,ncas,nmo,ncas),
            pacv.reshape(nmo,ncas,ncore,nvir), cvcv)
1043
1044
1045
1046
if __name__ == '__main__':
    from pyscf import gto
    from pyscf import scf
    from pyscf import mcscf

    # Values printed next to the computed NEVPT2 energies for comparison.
    ref_o2 = -0.169785157128082
    ref_h14 = -0.094164359938171

    # ---- Example 1: O2 with spin = 2 in the 6-31g basis -------------------
    mol = gto.Mole()
    mol.verbose = 0
    mol.output = None
    mol.atom = [
        ['O', (0., 0., 0.)],
        ['O', (0., 0., 1.207)],
    ]
    mol.basis = '6-31g'
    mol.spin = 2
    mol.build()

    mf = scf.RHF(mol)
    mf.conv_tol = 1e-20
    mf.scf()

    mc = mcscf.CASCI(mf, 6, 8)
    mc.fcisolver.conv_tol = 1e-14
    e_casci = mc.kernel()[0]
    mc.verbose = 4
    print(e_casci)
    print(sc_nevpt(mc), ref_o2)

    # ---- Example 2: chain of 14 hydrogen atoms in sto-3g -------------------
    mol = gto.Mole()
    mol.verbose = 0
    mol.output = None
    mol.atom = [
        ['H', (0., 0., 0.)],
        ['H', (0., 0., 0.8)],
        ['H', (0., 0., 2.)],
        ['H', (0., 0., 2.8)],
        ['H', (0., 0., 4.)],
        ['H', (0., 0., 4.8)],
        ['H', (0., 0., 6.)],
        ['H', (0., 0., 6.8)],
        ['H', (0., 0., 8.)],
        ['H', (0., 0., 8.8)],
        ['H', (0., 0., 10.)],
        ['H', (0., 0., 10.8)],
        ['H', (0., 0., 12)],
        ['H', (0., 0., 12.8)],
    ]
    mol.basis = {'H': 'sto-3g'}
    mol.build()

    mf = scf.RHF(mol)
    mf.conv_tol = 1e-20
    mf.scf()

    mc = mcscf.CASCI(mf, 8, 10)
    mc.fcisolver.conv_tol = 1e-14
    mc.kernel()
    mc.verbose = 4
    print(sc_nevpt(mc), ref_h14)
1106