1#!/usr/bin/env python
2# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Authors: James D. McClain
17#          Timothy Berkelbach <tim.berkelbach@gmail.com>
18#
19
20import numpy as np
21
22from itertools import product
23from pyscf import lib
24from pyscf.lib import logger
25from pyscf.lib.parameters import LOOSE_ZERO_TOL, LARGE_DENOM  # noqa
26from pyscf.pbc.lib import kpts_helper
27from pyscf.pbc.mp.kmp2 import (get_frozen_mask, get_nocc, get_nmo,
28                               padded_mo_coeff, padding_k_idx)  # noqa
29
# Contraction routine used throughout this module.  pyscf's lib.einsum is
# used; the commented-out line below is the plain NumPy alternative.
#einsum = np.einsum
einsum = lib.einsum
32
33# This is restricted (R)CCSD
34# Ref: Hirata et al., J. Chem. Phys. 120, 2581 (2004)
35
36### Eqs. (37)-(39) "kappa"
37
def cc_Foo(t1,t2,eris,kconserv):
    '''Occupied-occupied dressed Fock intermediate F_ki (Eqs. (37)-(39)).

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir); one k index per block.
        t2 : amplitudes indexed as t2[ki,kl,kc], each block contracted with
            subscripts 'ilcd'.
        eris : integral container; reads eris.fock and eris.oovv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.

    Returns:
        ndarray of shape (nkpts, nocc, nocc) with dtype t2.dtype.  Block ki
        holds F for kk == ki (the intermediate is diagonal in k).
    '''
    nkpts, nocc, nvir = t1.shape
    Fki = np.empty((nkpts,nocc,nocc),dtype=t2.dtype)
    for ki in range(nkpts):
        kk = ki  # F is k-diagonal: only the kk == ki block is stored
        Fki[ki] = eris.fock[ki,:nocc,:nocc].copy()
        for kl in range(nkpts):
            for kc in range(nkpts):
                kd = kconserv[kk,kc,kl]
                # 2*oovv minus oovv with its last two (c,d) axes swapped
                Soovv = 2*eris.oovv[kk,kl,kc] - eris.oovv[kk,kl,kd].transpose(0,1,3,2)
                Fki[ki] += einsum('klcd,ilcd->ki',Soovv,t2[ki,kl,kc])
            #if ki == kc:
            # t1 carries a single k index, so the t1*t1 term only survives for
            # kc == ki; the sum over kc therefore collapses to that one block.
            kd = kconserv[kk,ki,kl]
            Soovv = 2*eris.oovv[kk,kl,ki] - eris.oovv[kk,kl,kd].transpose(0,1,3,2)
            Fki[ki] += einsum('klcd,ic,ld->ki',Soovv,t1[ki],t1[kl])
    return Fki
54
def cc_Fvv(t1,t2,eris,kconserv):
    '''Virtual-virtual dressed Fock intermediate F_ac (Eqs. (37)-(39)).

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        t2 : amplitudes indexed as t2[kk,kl,ka], each block contracted with
            subscripts 'klad'.
        eris : integral container; reads eris.fock and eris.oovv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.

    Returns:
        ndarray of shape (nkpts, nvir, nvir) with dtype t2.dtype.  Block ka
        holds F for kc == ka (the intermediate is diagonal in k).
    '''
    nkpts, nocc, nvir = t1.shape
    Fac = np.empty((nkpts,nvir,nvir),dtype=t2.dtype)
    for ka in range(nkpts):
        kc = ka  # F is k-diagonal: only the kc == ka block is stored
        Fac[ka] = eris.fock[ka,nocc:,nocc:].copy()
        for kl in range(nkpts):
            for kk in range(nkpts):
                kd = kconserv[kk,kc,kl]
                # 2*oovv minus oovv with its last two (c,d) axes swapped
                Soovv = 2*eris.oovv[kk,kl,kc] - eris.oovv[kk,kl,kd].transpose(0,1,3,2)
                Fac[ka] += -einsum('klcd,klad->ac',Soovv,t2[kk,kl,ka])
            #if kk == ka
            # t1 carries a single k index, so the t1*t1 term only survives for
            # kk == ka; the sum over kk therefore collapses to that one block.
            kd = kconserv[ka,kc,kl]
            Soovv = 2*eris.oovv[ka,kl,kc] - eris.oovv[ka,kl,kd].transpose(0,1,3,2)
            Fac[ka] += -einsum('klcd,ka,ld->ac',Soovv,t1[ka],t1[kl])
    return Fac
71
def cc_Fov(t1,t2,eris,kconserv):
    '''Occupied-virtual dressed Fock intermediate F_kc (Eqs. (37)-(39)).

    Starts from the occupied-virtual block of eris.fock and adds, for every
    kl, the contraction of (2*oovv - oovv with c,d swapped) with t1[kl].

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        t2 : only its dtype is used (for the result array).
        eris : reads eris.fock and eris.oovv.
        kconserv : unused; kept for a uniform signature.

    Returns:
        ndarray of shape (nkpts, nocc, nvir) with dtype t2.dtype.
    '''
    nkpts, nocc, nvir = t1.shape
    fock_ov = np.empty((nkpts, nocc, nvir), dtype=t2.dtype)
    fock_ov[:] = eris.fock[:, :nocc, nocc:]
    for kk in range(nkpts):
        acc = fock_ov[kk]  # view: in-place adds write into fock_ov
        for kl in range(nkpts):
            spin_adapted = (2.*eris.oovv[kk, kl, kk]
                            - eris.oovv[kk, kl, kl].transpose(0, 1, 3, 2))
            acc += einsum('klcd,ld->kc', spin_adapted, t1[kl])
    return fock_ov
81
82### Eqs. (40)-(41) "lambda"
83
def Loo(t1,t2,eris,kconserv):
    '''Occupied-occupied "lambda" intermediate L_ki (Eqs. (40)-(41)).

    Built on top of cc_Foo by adding the fock(ov)*t1 term and the
    2*ooov - ooov (first two indices swapped) contraction with t1.

    Returns:
        ndarray of shape (nkpts, nocc, nocc), the array returned by cc_Foo
        updated in place.
    '''
    nkpts, nocc, _ = t1.shape
    fock_ov = eris.fock[:, :nocc, nocc:]
    loo = cc_Foo(t1, t2, eris, kconserv)
    for ki in range(nkpts):
        block = loo[ki]  # view into loo
        block += einsum('kc,ic->ki', fock_ov[ki], t1[ki])
        for kl in range(nkpts):
            block += 2*einsum('klic,lc->ki', eris.ooov[ki,kl,ki], t1[kl])
            block -= einsum('lkic,lc->ki', eris.ooov[kl,ki,ki], t1[kl])
    return loo
94
def Lvv(t1,t2,eris,kconserv):
    '''Virtual-virtual "lambda" intermediate L_ac (Eqs. (40)-(41)).

    Built on top of cc_Fvv by subtracting the fock(ov)*t1 term and adding
    the 2*vovv - vovv (last two indices swapped) contraction with t1.

    Returns:
        ndarray of shape (nkpts, nvir, nvir), the array returned by cc_Fvv
        updated in place.
    '''
    nkpts, nocc, _ = t1.shape
    fock_ov = eris.fock[:, :nocc, nocc:]
    lvv = cc_Fvv(t1, t2, eris, kconserv)
    for ka in range(nkpts):
        block = lvv[ka]  # view into lvv
        block -= einsum('kc,ka->ac', fock_ov[ka], t1[ka])
        for kk in range(nkpts):
            s_vovv = 2*eris.vovv[ka,kk,ka] - eris.vovv[ka,kk,kk].transpose(0,1,3,2)
            block += einsum('akcd,kd->ac', s_vovv, t1[kk])
    return lvv
105
106### Eqs. (42)-(45) "chi"
107
def cc_Woooo(t1, t2, eris, kconserv, out=None):
    '''W_klij occupied four-index intermediate (Eqs. (42)-(45), "chi").

    Only blocks with kl <= kk are contracted explicitly; the remaining blocks
    are filled from W[kl,kk,kj] = W[kk,kl,ki].transpose(1,0,3,2).

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        t2 : amplitudes; t2[ki,kj] is a stack over kc (see the reshape below).
        eris : reads eris.oooo, eris.ooov and eris.oovv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.
        out : optional preallocated buffer of shape eris.oooo.shape and
            dtype t1.dtype (validated by _new).

    Returns:
        Array of shape eris.oooo.shape indexed as W[kk,kl,ki].
    '''
    nkpts, nocc, nvir = t1.shape

    Wklij = _new(eris.oooo.shape, t1.dtype, out)
    for kk in range(nkpts):
        for kl in range(kk+1):
            for ki in range(nkpts):
                kj = kconserv[kk,ki,kl]
                oooo  = einsum('klic,jc->klij',eris.ooov[kk,kl,ki],t1[kj])
                oooo += einsum('lkjc,ic->klij',eris.ooov[kl,kk,kj],t1[ki])
                oooo += eris.oooo[kk,kl,ki]

                # ==== Beginning of change ====
                #
                # Vectorized form of the commented reference loops: all kc
                # blocks of oovv[kk,kl] and of t2[ki,kj] are stacked along a
                # combined (kc, c) axis and contracted in a single einsum.
                #for kc in range(nkpts):
                #    Wklij[kk,kl,ki] += einsum('klcd,ijcd->klij',eris.oovv[kk,kl,kc],t2[ki,kj,kc])
                #Wklij[kk,kl,ki] += einsum('klcd,ic,jd->klij',eris.oovv[kk,kl,ki],t1[ki],t1[kj])
                vvoo = eris.oovv[kk,kl].transpose(0,3,4,1,2).reshape(nkpts*nvir,nvir,nocc,nocc)
                t2t  = t2[ki,kj].copy().transpose(0,3,4,1,2)
                #for kc in range(nkpts):
                #    kd = kconserv[ki,kc,kj]
                #    if kc == ki and kj == kd:
                #        t2t[kc] += einsum('ic,jd->cdij',t1[ki],t1[kj])
                # The t1*t1 term only contributes for kc == ki (then kd == kj),
                # so it is folded into the kc == ki slice of the t2 copy.
                t2t[ki] += einsum('ic,jd->cdij',t1[ki],t1[kj])
                t2t = t2t.reshape(nkpts*nvir,nvir,nocc,nocc)
                oooo += einsum('cdkl,cdij->klij',vvoo,t2t)
                Wklij[kk,kl,ki] = oooo
                # =====   End of change  = ====

        # Be careful about making this term only after all the others are created
        # (fills the kl > kk blocks from the kl <= kk ones computed above)
        for kl in range(kk+1):
            for ki in range(nkpts):
                kj = kconserv[kk,ki,kl]
                Wklij[kl,kk,kj] = Wklij[kk,kl,ki].transpose(1,0,3,2)
    return Wklij
143
def cc_Wvvvv(t1, t2, eris, kconserv, out=None):
    '''t1-dressed W_abcd virtual four-index intermediate.

    Each block is eris.vvvv[ka,kb,kc] plus two vovv*(-t1) terms; t2 is not
    used.  Only blocks with kb <= ka are contracted explicitly; for each ka
    the remaining blocks are filled from
    W[kb,ka,kd] = W[ka,kb,kc].transpose(1,0,3,2).

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        eris : reads eris.vvvv and eris.vovv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.
        out : optional preallocated buffer of shape eris.vvvv.shape and
            dtype t1.dtype (validated by _new).

    Returns:
        Array of shape eris.vvvv.shape indexed as W[ka,kb,kc].
    '''
    Wabcd = _new(eris.vvvv.shape, t1.dtype, out)
    nkpts, nocc, nvir = t1.shape
    for ka in range(nkpts):
        for kb in range(ka+1):
            for kc in range(nkpts):
                kd = kconserv[ka,kc,kb]
                # avoid transpose in loop
                vvvv  = einsum('akcd,kb->abcd', eris.vovv[ka,kb,kc], -t1[kb])
                vvvv += einsum('bkdc,ka->abcd', eris.vovv[kb,ka,kd], -t1[ka])
                vvvv += eris.vvvv[ka,kb,kc]
                Wabcd[ka,kb,kc] = vvvv

        # Be careful: making this term only after all the others are created
        # (fills the kb > ka counterparts of the blocks computed above)
        for kb in range(ka+1):
            for kc in range(nkpts):
                kd = kconserv[ka,kc,kb]
                Wabcd[kb,ka,kd] = Wabcd[ka,kb,kc].transpose(1,0,3,2)

    return Wabcd
164
def cc_Wvoov(t1, t2, eris, kconserv, out=None):
    '''W_akic intermediate (Eqs. (42)-(45), "chi"), vectorized over ki.

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        t2 : amplitudes indexed as t2[k1,k2,k3].
        eris : reads eris.voov, eris.vovv, eris.ooov and eris.oovv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.
        out : optional preallocated buffer of shape eris.voov.shape and
            dtype t1.dtype (validated by _new).

    Returns:
        Array of shape eris.voov.shape indexed as W[ka,kk,ki].
    '''
    Wakic = _new(eris.voov.shape, t1.dtype, out)
    nkpts, nocc, nvir = t1.shape
    for ka in range(nkpts):
        for kk in range(nkpts):
            # voov_i holds the W[ka,kk,ki] blocks for every ki at once: the
            # leading axis x runs over ki (it is written to Wakic[ka,kk,:]).
            voov_i  = einsum('xakdc,xid->xakic',eris.vovv[ka,kk,:],t1[:])
            voov_i -= einsum('xlkic,la->xakic',eris.ooov[ka,kk,:],t1[ka])
            voov_i += eris.voov[ka,kk,:]
            for ki in range(nkpts):
                kc = kconserv[ka,ki,kk]

                #for kl in range(nkpts):
                #    # kl - kd + kk = kc
                #    # => kd = kl - kc + kk
                #    kd = kconserv[kl,kc,kk]
                #    Soovv = 2*eris.oovv[kl,kk,kd] - eris.oovv[kl,kk,kc].transpose(0,1,3,2)
                #    Wakic[ka,kk,ki] += 0.5*einsum('lkdc,ilad->akic',Soovv,t2[ki,kl,ka])
                #    Wakic[ka,kk,ki] -= 0.5*einsum('lkdc,ilda->akic',eris.oovv[kl,kk,kd],t2[ki,kl,kd])
                #Wakic[ka,kk,ki] -= einsum('lkdc,id,la->akic',eris.oovv[ka,kk,ki],t1[ki],t1[ka])

                # Vectorized (over kl) form of the commented reference code.
                # With kc = ka - ki + kk, this kd works out to ki.
                kd = kconserv[ka,kc,kk]
                tau = t2[:,ki,ka].copy()
                # Fold the t1*t1 term into the kl == ka slice; the factor 2
                # combines with the -0.5 below to give the reference's -1.
                tau[ka] += 2*einsum('id,la->liad',t1[kd],t1[ka])
                oovv_tmp = np.array(eris.oovv[kk,:,kc])
                voov_i[ki] -= 0.5*einsum('xklcd,xliad->akic',oovv_tmp,tau)

                Soovv_tmp = 2*oovv_tmp - eris.oovv[:,kk,kc].transpose(0,2,1,3,4)
                voov_i[ki] += 0.5*einsum('xklcd,xilad->akic',Soovv_tmp,t2[ki,:,ka])

            Wakic[ka,kk,:] = voov_i[:]
    return Wakic
196
def cc_Wvovo(t1, t2, eris, kconserv, out=None):
    '''W_akci intermediate (Eqs. (42)-(45), "chi").

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        t2 : amplitudes indexed as t2[k1,k2,k3].
        eris : reads eris.ovov, eris.vovv, eris.ooov and eris.oovv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.
        out : optional preallocated buffer of shape
            (nkpts,nkpts,nkpts,nvir,nocc,nvir,nocc) and dtype t1.dtype.

    Returns:
        Array of shape (nkpts,nkpts,nkpts,nvir,nocc,nvir,nocc) indexed as
        W[ka,kk,kc]; ki = kconserv[ka,kc,kk].
    '''
    nkpts, nocc, nvir = t1.shape
    Wakci = _new((nkpts,nkpts,nkpts,nvir,nocc,nvir,nocc), t1.dtype, out)

    for ka in range(nkpts):
        for kk in range(nkpts):
            for kc in range(nkpts):
                ki = kconserv[ka,kc,kk]
                vovo  = einsum('akcd,id->akci',eris.vovv[ka,kk,kc],t1[ki])
                vovo -= einsum('klic,la->akci',eris.ooov[kk,ka,ki],t1[ka])
                vovo += np.asarray(eris.ovov[kk,ka,ki]).transpose(1,0,3,2)
                # ==== Beginning of change ====
                #
                # Vectorized (over kl) form of the commented reference code.
                #for kl in range(nkpts):
                #    kd = kconserv[kl,kc,kk]
                #    Wakci[ka,kk,kc] -= 0.5*einsum('lkcd,ilda->akci',eris.oovv[kl,kk,kc],t2[ki,kl,kd])
                #Wakci[ka,kk,kc] -= einsum('lkcd,id,la->akci',eris.oovv[ka,kk,kc],t1[ki],t1[ka])
                oovvf = eris.oovv[:,kk,kc].reshape(nkpts*nocc,nocc,nvir,nvir)
                t2f   = t2[:,ki,ka].copy() #This is a tau like term
                #for kl in range(nkpts):
                #    kd = kconserv[kl,kc,kk]
                #    if ki == kd and kl == ka:
                #        t2f[kl] += 2*einsum('id,la->liad',t1[ki],t1[ka])
                # Same kconserv arguments as for ki above, so kd == ki here;
                # the t1*t1 term lives only in the kl == ka slice.
                kd = kconserv[ka,kc,kk]
                t2f[ka] += 2*einsum('id,la->liad',t1[kd],t1[ka])
                t2f = t2f.reshape(nkpts*nocc,nocc,nvir,nvir)

                vovo -= 0.5*einsum('lkcd,liad->akci',oovvf,t2f)
                Wakci[ka,kk,kc] = vovo
                # =====   End of change  = ====
    return Wakci
228
def Wooov(t1, t2, eris, kconserv, out=None):
    '''t1-dressed W_klid intermediate.

    Every block is oovv[kk,kl,ki] contracted with t1[ki] over c, plus
    ooov[kk,kl,ki].

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        t2, kconserv : unused; kept for a uniform signature.
        eris : reads eris.ooov and eris.oovv.
        out : optional preallocated buffer (validated by _new).

    Returns:
        Array of shape eris.ooov.shape and dtype t1.dtype, indexed [kk,kl,ki].
    '''
    nkpts, _, _ = t1.shape
    result = _new(eris.ooov.shape, t1.dtype, out)
    for kk, kl, ki in product(range(nkpts), repeat=3):
        dressed = einsum('ic,klcd->klid', t1[ki], eris.oovv[kk,kl,ki])
        dressed += eris.ooov[kk,kl,ki]
        result[kk,kl,ki] = dressed
    return result
239
def Wvovv(t1, t2, eris, kconserv, out=None):
    '''t1-dressed W_alcd intermediate.

    Every block is oovv[ka,kl,kc] contracted with -t1[ka] over k, plus
    vovv[ka,kl,kc].

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        t2, kconserv : unused; kept for a uniform signature.
        eris : reads eris.vovv and eris.oovv.
        out : optional preallocated buffer (validated by _new).

    Returns:
        Array of shape eris.vovv.shape and dtype t1.dtype, indexed [ka,kl,kc].
    '''
    nkpts, _, _ = t1.shape
    result = _new(eris.vovv.shape, t1.dtype, out)
    for ka, kl, kc in product(range(nkpts), repeat=3):
        dressed = einsum('ka,klcd->alcd', -t1[ka], eris.oovv[ka,kl,kc])
        dressed += eris.vovv[ka,kl,kc]
        result[ka,kl,kc] = dressed
    return result
250
def W1ovvo(t1, t2, eris, kconserv, out=None):
    '''First part of the W_kaci intermediate (no t1 dependence).

    Each block starts from eris.voov (transposed to ovvo order) and adds
    oovv * t2 contractions; the t1-dependent part is W2ovvo, and Wovvo
    returns the sum of both.

    Args:
        t2 : amplitudes indexed as t2[k1,k2,k3].
        eris : reads eris.voov and eris.oovv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.
        out : optional preallocated buffer of shape
            (nkpts,nkpts,nkpts,nocc,nvir,nvir,nocc) and dtype t1.dtype.

    Returns:
        Array indexed as W[kk,ka,kc]; ki = kconserv[kk,kc,ka].
    '''
    nkpts, nocc, nvir = t1.shape
    Wkaci = _new((nkpts,nkpts,nkpts,nocc,nvir,nvir,nocc), t1.dtype, out)
    for kk in range(nkpts):
        for ka in range(nkpts):
            for kc in range(nkpts):
                ki = kconserv[kk,kc,ka]
                # ovvo[kk,ka,kc,ki] => voov[ka,kk,ki,kc]
                ovvo = np.asarray(eris.voov[ka,kk,ki]).transpose(1,0,3,2).copy()
                for kl in range(nkpts):
                    kd = kconserv[ki,ka,kl]
                    # 2*t2[ki,kl,ka] minus t2[kl,ki,ka] with its first two
                    # axes swapped
                    St2 = 2.*t2[ki,kl,ka] - t2[kl,ki,ka].transpose(1,0,2,3)
                    ovvo +=  einsum('klcd,ilad->kaci',eris.oovv[kk,kl,kc],St2)
                    ovvo += -einsum('kldc,ilad->kaci',eris.oovv[kk,kl,kd],t2[ki,kl,ka])
                Wkaci[kk,ka,kc] = ovvo
    return Wkaci
267
def W2ovvo(t1, t2, eris, kconserv, out=None):
    '''t1-dependent part of the W_kaci intermediate.

    Each block is Wooov[ka,kk,ki] contracted with -t1[ka] plus
    eris.vovv[ka,kk,ki] contracted with t1[ki]; Wovvo adds this to W1ovvo.

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        eris : integral container forwarded to Wooov; also reads eris.vovv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.
        out : optional preallocated buffer of shape
            (nkpts,nkpts,nkpts,nocc,nvir,nvir,nocc) and dtype t1.dtype.

    Returns:
        Array indexed as W[kk,ka,kc]; ki = kconserv[kk,kc,ka].
    '''
    nkpts, nocc, nvir = t1.shape
    # The buffer is validated/allocated before Wooov is built, as before.
    result = _new((nkpts,)*3 + (nocc, nvir, nvir, nocc), t1.dtype, out)
    w_ooov = Wooov(t1, t2, eris, kconserv)
    for kk, ka, kc in product(range(nkpts), repeat=3):
        ki = kconserv[kk, kc, ka]
        term = einsum('la,lkic->kaci', -t1[ka], w_ooov[ka, kk, ki])
        term += einsum('akdc,id->kaci', eris.vovv[ka, kk, ki], t1[ki])
        result[kk, ka, kc] = term
    return result
280
def Wovvo(t1, t2, eris, kconserv, out=None):
    '''Full W_kaci intermediate: W1ovvo + W2ovvo.

    The sum is written slab by slab (one leading k index at a time) into
    the array returned by W1ovvo, which is ``out`` when one is given.
    '''
    total = W1ovvo(t1, t2, eris, kconserv, out)
    correction = W2ovvo(t1, t2, eris, kconserv)
    for kk in range(len(correction)):
        total[kk] = total[kk] + correction[kk]
    return total
286
def W1ovov(t1, t2, eris, kconserv, out=None):
    '''First part of the W_kbid intermediate (no t1 dependence).

    Each block is eris.ovov minus an oovv * t2 contraction summed over kl;
    the t1-dependent part is W2ovov, and Wovov returns the sum of both.

    Args:
        t2 : amplitudes indexed as t2[k1,k2,k3].
        eris : reads eris.ovov and eris.oovv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.
        out : optional preallocated buffer of shape eris.ovov.shape and
            dtype t1.dtype (validated by _new).

    Returns:
        Array of shape eris.ovov.shape indexed as W[kk,kb,ki].
    '''
    nkpts, nocc, nvir = t1.shape
    Wkbid = _new(eris.ovov.shape, t1.dtype, out)
    for kk in range(nkpts):
        for kb in range(nkpts):
            for ki in range(nkpts):
                kd = kconserv[kk,ki,kb]
                #   kk + kl - kc - kd = 0
                # => kc = kk - kd + kl
                ovov = eris.ovov[kk,kb,ki].copy()
                for kl in range(nkpts):
                    kc = kconserv[kk,kd,kl]
                    ovov -= einsum('klcd,ilcb->kbid',eris.oovv[kk,kl,kc],t2[ki,kl,kc])
                Wkbid[kk,kb,ki] = ovov
    return Wkbid
302
def W2ovov(t1, t2, eris, kconserv, out=None):
    '''t1-dependent part of the W_kbid intermediate.

    Each block is Wooov[kk,kb,ki] contracted with -t1[kb] plus
    eris.vovv[kb,kk,kd] contracted with t1[ki]; Wovov adds this to W1ovov.

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        eris : integral container forwarded to Wooov; also reads eris.vovv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.
        out : optional preallocated buffer of shape
            (nkpts,nkpts,nkpts,nocc,nvir,nocc,nvir) and dtype t1.dtype.

    Returns:
        Array indexed as W[kk,kb,ki]; kd = kconserv[kk,ki,kb].
    '''
    nkpts, nocc, nvir = t1.shape
    # The buffer is validated/allocated before Wooov is built, as before.
    result = _new((nkpts,)*3 + (nocc, nvir, nocc, nvir), t1.dtype, out)
    w_ooov = Wooov(t1, t2, eris, kconserv)
    for kk, kb, ki in product(range(nkpts), repeat=3):
        kd = kconserv[kk, ki, kb]
        term = einsum('klid,lb->kbid', w_ooov[kk, kb, ki], -t1[kb])
        term += einsum('bkdc,ic->kbid', eris.vovv[kb, kk, kd], t1[ki])
        result[kk, kb, ki] = term
    return result
315
def Wovov(t1, t2, eris, kconserv, out=None):
    '''Full W_kbid intermediate: W1ovov + W2ovov.

    The sum is written slab by slab (one leading k index at a time) into
    the array returned by W1ovov, which is ``out`` when one is given.
    '''
    total = W1ovov(t1, t2, eris, kconserv, out)
    correction = W2ovov(t1, t2, eris, kconserv)
    for kk in range(len(correction)):
        total[kk] = total[kk] + correction[kk]
    return total
321
def Woooo(t1, t2, eris, kconserv, out=None):
    '''W_klij intermediate, computed block by block for every (kk,kl,ki).

    Unlike cc_Woooo, no permutational symmetry is exploited: each block
    sums eris.oooo, two ooov*t1 terms, an oovv*t1*t1 term and an oovv*t2
    term summed over kc.

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        t2 : amplitudes indexed as t2[ki,kj,kc].
        eris : reads eris.oooo, eris.ooov and eris.oovv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.
        out : optional preallocated buffer of shape eris.oooo.shape and
            dtype t1.dtype (validated by _new).

    Returns:
        Array of shape eris.oooo.shape indexed as W[kk,kl,ki];
        kj = kconserv[kk,ki,kl].
    '''
    nkpts, nocc, nvir = t1.shape
    Wklij = _new(eris.oooo.shape, t1.dtype, out)
    for kk in range(nkpts):
        for kl in range(nkpts):
            for ki in range(nkpts):
                kj = kconserv[kk,ki,kl]
                # t1 has a single k index, so the t1*t1 term uses kc == ki
                oooo  = einsum('klcd,ic,jd->klij',eris.oovv[kk,kl,ki],t1[ki],t1[kj])
                oooo += einsum('klid,jd->klij',eris.ooov[kk,kl,ki],t1[kj])
                oooo += einsum('lkjc,ic->klij',eris.ooov[kl,kk,kj],t1[ki])
                oooo += eris.oooo[kk,kl,ki]
                for kc in range(nkpts):
                    #kd = kconserv[kk,kc,kl]
                    oooo += einsum('klcd,ijcd->klij',eris.oovv[kk,kl,kc],t2[ki,kj,kc])
                Wklij[kk,kl,ki] = oooo
    return Wklij
338
def Wvvvv(t1, t2, eris, kconserv, out=None):
    '''Assemble the full W_abcd intermediate from get_Wvvvv blocks.

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        t2 : amplitudes; its dtype sets the dtype of the result.
        eris, kconserv : forwarded to get_Wvvvv.
        out : optional preallocated buffer of shape
            (nkpts,nkpts,nkpts,nvir,nvir,nvir,nvir) and dtype t2.dtype.

    Returns:
        Array indexed as W[ka,kb,kc].
    '''
    nkpts, _, nvir = t1.shape
    result = _new((nkpts,)*3 + (nvir,)*4, t2.dtype, out)
    for ka, kb, kc in product(range(nkpts), repeat=3):
        result[ka, kb, kc] = get_Wvvvv(t1, t2, eris, kconserv, ka, kb, kc)
    return result
347
def get_Wvvvv(t1, t2, eris, kconserv, ka, kb, kc):
    '''Compute a single (ka,kb,kc) block of the W_abcd intermediate.

    kd is fixed by kd = kconserv[ka,kc,kb].  If ``eris`` has a non-None
    ``Lpv`` attribute (GDF three-index tensors), the t1-dressed vvvv part is
    built on the fly from it; otherwise it is assembled from eris.vvvv,
    eris.vovv and eris.oovv.  In both cases the oovv * t2 term (summed over
    kk) is added last.

    Returns:
        The 4-index block indexed [a,b,c,d] (einsum output 'abcd').
    '''
    kd = kconserv[ka, kc, kb]
    nkpts, nocc, nvir = t1.shape
    if getattr(eris, 'Lpv', None) is not None:
        # Using GDF to generate Wvvvv on the fly
        # Lac/Lbd: virtual part of Lpv minus its occupied part dressed by t1
        Lpv = eris.Lpv
        Lac = (Lpv[ka,kc][:,nocc:] -
               einsum('Lkc,ka->Lac', Lpv[ka,kc][:,:nocc], t1[ka]))
        Lbd = (Lpv[kb,kd][:,nocc:] -
               einsum('Lkd,kb->Lbd', Lpv[kb,kd][:,:nocc], t1[kb]))
        vvvv = einsum('Lac,Lbd->abcd', Lac, Lbd)
        # 1/nkpts normalization applied to the DF product
        vvvv *= (1. / nkpts)
    else:
        vvvv  = einsum('klcd,ka,lb->abcd',eris.oovv[ka,kb,kc],t1[ka],t1[kb])
        vvvv += einsum('alcd,lb->abcd',eris.vovv[ka,kb,kc],-t1[kb])
        vvvv += einsum('bkdc,ka->abcd',eris.vovv[kb,ka,kd],-t1[ka])
        vvvv += eris.vvvv[ka,kb,kc]

    for kk in range(nkpts):
        kl = kconserv[kc,kk,kd]
        vvvv += einsum('klcd,klab->abcd', eris.oovv[kk,kl,kc], t2[kk,kl,ka])
    return vvvv
370
def Wvvvo(t1, t2, eris, kconserv, _Wvvvv=None, out=None):
    '''W_abcj intermediate (t1- and t2-dressed vvvo).

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        t2 : amplitudes indexed as t2[k1,k2,k3].
        eris : reads eris.vovv and eris.ooov; also passed to W1ovov,
            W1ovvo, cc_Fov and (if needed) get_Wvvvv.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.
        _Wvvvv : optional precomputed Wvvvv indexed [ka,kb,kc]; when None,
            each needed block is rebuilt with get_Wvvvv.
        out : optional preallocated buffer of shape
            (nkpts,nkpts,nkpts,nvir,nvir,nvir,nocc) and dtype t1.dtype.

    Returns:
        Array indexed as W[ka,kb,kc]; kj = kconserv[ka,kc,kb].
    '''
    nkpts, nocc, nvir = t1.shape
    Wabcj = _new((nkpts,nkpts,nkpts,nvir,nvir,nvir,nocc), t1.dtype, out)
    WW1ovov = W1ovov(t1,t2,eris,kconserv)
    WW1ovvo = W1ovvo(t1,t2,eris,kconserv)
    FFov = cc_Fov(t1,t2,eris,kconserv)
    for ka in range(nkpts):
        for kb in range(nkpts):
            for kc in range(nkpts):
                kj = kconserv[ka,kc,kb]
                # Wvovo[ka,kl,kc,kj] <= Wovov[kl,ka,kj,kc].transpose(1,0,3,2)
                vvvo  = einsum('alcj,lb->abcj',WW1ovov[kb,ka,kj].transpose(1,0,3,2),-t1[kb])
                vvvo += einsum('kbcj,ka->abcj',WW1ovvo[ka,kb,kc],-t1[ka])
                # vvvo[ka,kb,kc,kj] <= vovv[kc,kj,ka,kb].transpose(2,3,0,1).conj()
                vvvo += np.asarray(eris.vovv[kc,kj,ka]).transpose(2,3,0,1).conj()

                for kl in range(nkpts):
                    # ka + kl - kc - kd = 0
                    # => kd = ka - kc + kl
                    kd = kconserv[ka,kc,kl]
                    St2 = 2.*t2[kl,kj,kd] - t2[kl,kj,kb].transpose(0,1,3,2)
                    vvvo += einsum('alcd,ljdb->abcj',eris.vovv[ka,kl,kc], St2)
                    vvvo += einsum('aldc,ljdb->abcj',eris.vovv[ka,kl,kd], -t2[kl,kj,kd])
                    # kb - kc + kl = kd
                    kd = kconserv[kb,kc,kl]
                    vvvo += einsum('bldc,jlda->abcj',eris.vovv[kb,kl,kd], -t2[kj,kl,kd])

                    # kl + kk - kb - ka = 0
                    # => kk = kb + ka - kl
                    kk = kconserv[kb,kl,ka]
                    vvvo += einsum('lkjc,lkba->abcj',eris.ooov[kl,kk,kj],t2[kl,kk,kb])
                vvvo += einsum('lkjc,lb,ka->abcj',eris.ooov[kb,ka,kj],t1[kb],t1[ka])
                vvvo += einsum('lc,ljab->abcj',-FFov[kc],t2[kc,kj,ka])
                Wabcj[ka,kb,kc] = vvvo

    # Check if t1=0 (HF+MBPT(2))
    # einsum will check, but don't make vvvv if you can avoid it!
    # The Wvvvv * t1 term vanishes identically when t1 is zero.
    if np.any(t1 != 0):
        for ka in range(nkpts):
            for kb in range(nkpts):
                for kc in range(nkpts):
                    kj = kconserv[ka,kc,kb]
                    if _Wvvvv is None:
                        Wvvvv = get_Wvvvv(t1, t2, eris, kconserv, ka, kb, kc)
                    else:
                        Wvvvv = _Wvvvv[ka, kb, kc]
                    Wabcj[ka,kb,kc] = (Wabcj[ka,kb,kc] +
                                       einsum('abcd,jd->abcj', Wvvvv, t1[kj]))
    return Wabcj
420
def Wovoo(t1, t2, eris, kconserv, out=None):
    '''W_kbij intermediate (t1- and t2-dressed ovoo).

    Args:
        t1 : amplitudes unpacked as (nkpts, nocc, nvir).
        t2 : amplitudes indexed as t2[k1,k2,k3].
        eris : reads eris.ooov and eris.vovv; also passed to W1ovov, Woooo,
            W1ovvo and cc_Fov.
        kconserv : momentum-conservation table; kconserv[k1,k2,k3] gives
            k1 - k2 + k3.
        out : optional preallocated buffer of shape
            (nkpts,nkpts,nkpts,nocc,nvir,nocc,nocc) and dtype t1.dtype.

    Returns:
        Array indexed as W[kk,kb,ki]; kj = kconserv[kk,ki,kb].
    '''
    nkpts, nocc, nvir = t1.shape

    WW1ovov = W1ovov(t1,t2,eris,kconserv)
    WWoooo = Woooo(t1,t2,eris,kconserv)
    WW1ovvo = W1ovvo(t1,t2,eris,kconserv)
    FFov = cc_Fov(t1,t2,eris,kconserv)

    Wkbij = _new((nkpts,nkpts,nkpts,nocc,nvir,nocc,nocc), t1.dtype, out)
    for kk in range(nkpts):
        for kb in range(nkpts):
            for ki in range(nkpts):
                kj = kconserv[kk,ki,kb]
                ovoo  = einsum('kbid,jd->kbij',WW1ovov[kk,kb,ki], t1[kj])
                ovoo += einsum('klij,lb->kbij',WWoooo[kk,kb,ki],-t1[kb])
                ovoo += einsum('kbcj,ic->kbij',WW1ovvo[kk,kb,ki],t1[ki])
                # ovoo[kk,kb,ki,kj] <= ooov[ki,kj,kk,kb].transpose(2,3,0,1).conj()
                ovoo += np.array(eris.ooov[ki,kj,kk]).transpose(2,3,0,1).conj()

                for kd in range(nkpts):
                    # kk + kl - ki - kd = 0
                    # => kl = ki - kk + kd
                    kl = kconserv[ki,kk,kd]
                    St2 = 2.*t2[kl,kj,kd] - t2[kj,kl,kd].transpose(1,0,2,3)
                    ovoo += einsum('klid,ljdb->kbij',  eris.ooov[kk,kl,ki], St2)
                    ovoo += einsum('lkid,ljdb->kbij', -eris.ooov[kl,kk,ki],t2[kl,kj,kd])
                    kl = kconserv[kb,ki,kd]
                    ovoo += einsum('lkjd,libd->kbij', -eris.ooov[kl,kk,kj],t2[kl,ki,kb])

                    # kb + kk - kd = kc
                    #kc = kconserv[kb,kd,kk]
                    ovoo += einsum('bkdc,jidc->kbij',eris.vovv[kb,kk,kd],t2[kj,ki,kd])
                ovoo += einsum('bkdc,jd,ic->kbij',eris.vovv[kb,kk,kj],t1[kj],t1[ki])
                ovoo += einsum('kc,ijcb->kbij',FFov[kk],t2[ki,kj,kk])
                Wkbij[kk,kb,ki] = ovoo
    return Wkbij
456
457def _new(shape, dtype, out):
458    if out is None: # Incore:
459        out = np.empty(shape, dtype=dtype)
460    else:
461        assert(out.shape == shape)
462        assert(out.dtype == dtype)
463    return out
464
def get_t3p2_imds_slow(cc, t1, t2, eris=None, t3p2_ip_out=None, t3p2_ea_out=None):
    """Reference (slow) T3[2] corrections to t1/t2 and the IP/EA intermediates.

    For a fuller description of the arguments, see `get_t3p2_imds_slow` in
    the corresponding (generalized) `kintermediates.py`.

    Perturbative triples amplitudes are built for every (ki,kj,kk,ka,kb)
    block and kept in memory at once, in an array of shape
    (nkpts,)*5 + (nocc,)*3 + (nvir,)*3, so this routine is only practical for
    small systems.

    Args:
        cc : k-point CCSD object; reads cc.khelper.kconserv, cc._scf.cell and
            cc.kpts, and calls cc.ao2mo() and cc.energy().
        t1, t2 : CCSD amplitudes; t1 is unpacked as (nkpts, nocc, nvir).
        eris : integrals; built with cc.ao2mo() when None.
        t3p2_ip_out : optional (nkpts,nkpts,nkpts,nocc,nvir,nocc,nocc) array
            into which the IP intermediate is accumulated (+=); a zero array
            is allocated when None.
        t3p2_ea_out : optional (nkpts,nkpts,nkpts,nvir,nvir,nvir,nocc) array
            into which the EA intermediate is accumulated (+=); a zero array
            is allocated when None.

    Returns:
        (delta_ccsd_energy, pt1, pt2, Wmcik, Wacek): pt1/pt2 are t1/t2 plus
        their T3[2] corrections, delta_ccsd_energy is
        cc.energy(pt1, pt2) - cc.energy(t1, t2), and Wmcik/Wacek are the
        IP/EA arrays (the caller's buffers when provided).
    """
    from pyscf.pbc.cc.kccsd_t_rhf import _get_epqr
    if eris is None:
        eris = cc.ao2mo()
    fock = eris.fock
    nkpts, nocc, nvir = t1.shape
    kconserv = cc.khelper.kconserv
    dtype = np.result_type(t1, t2)

    fov = fock[:, :nocc, nocc:]
    #foo = [fock[ikpt, :nocc, :nocc].diagonal() for ikpt in range(nkpts)]
    #fvv = [fock[ikpt, nocc:, nocc:].diagonal() for ikpt in range(nkpts)]
    mo_energy_occ = np.array([eris.mo_energy[ki][:nocc] for ki in range(nkpts)])
    mo_energy_vir = np.array([eris.mo_energy[ki][nocc:] for ki in range(nkpts)])

    mo_e_o = mo_energy_occ
    mo_e_v = mo_energy_vir

    # Get location of padded elements in occupied and virtual space
    nonzero_opadding, nonzero_vpadding = padding_k_idx(cc, kind="split")

    ccsd_energy = cc.energy(t1, t2, eris)

    if t3p2_ip_out is None:
        t3p2_ip_out = np.zeros((nkpts,nkpts,nkpts,nocc,nvir,nocc,nocc),dtype=dtype)
    Wmcik = t3p2_ip_out

    if t3p2_ea_out is None:
        t3p2_ea_out = np.zeros((nkpts,nkpts,nkpts,nvir,nvir,nvir,nocc),dtype=dtype)
    Wacek = t3p2_ea_out

    # (local import; shadows the module-level import of the same name)
    from itertools import product
    # Triples amplitudes for all (ki,kj,kk,ka,kb); kc follows from
    # get_kconserv3.  Memory: nkpts**5 * nocc**3 * nvir**3 elements.
    tmp_t3 = np.empty((nkpts, nkpts, nkpts, nkpts, nkpts, nocc, nocc, nocc, nvir, nvir, nvir),
                      dtype = t2.dtype)

    # One index ordering of the (unsymmetrized) triples numerator: a
    # vovv*t2 term minus an ooov*t2 term; the six permutations are summed
    # in the loop below.
    def get_w(ki, kj, kk, ka, kb, kc):
        kf = kconserv[ka,ki,kb]
        ret = lib.einsum('fiba,kjcf->ijkabc', eris.vovv[kf, ki, kb].conj(), t2[kk, kj, kc])
        km = kconserv[kc,kk,kb]
        ret -= lib.einsum('jima,mkbc->ijkabc', eris.ooov[kj, ki, km].conj(), t2[km, kk, kb])
        return ret

    for ki, kj, kk, ka, kb in product(range(nkpts), repeat=5):
        kc = kpts_helper.get_kconserv3(cc._scf.cell, cc.kpts,
                                       [ki, kj, kk, ka, kb])
        tmp_t3[ki, kj, kk, ka, kb] = get_w(ki, kj, kk, ka, kb, kc)
        tmp_t3[ki, kj, kk, ka, kb] += get_w(ki, kk, kj, ka, kc, kb).transpose(0, 2, 1, 3, 5, 4)
        tmp_t3[ki, kj, kk, ka, kb] += get_w(kj, ki, kk, kb, ka, kc).transpose(1, 0, 2, 4, 3, 5)
        tmp_t3[ki, kj, kk, ka, kb] += get_w(kj, kk, ki, kb, kc, ka).transpose(2, 0, 1, 5, 3, 4)
        tmp_t3[ki, kj, kk, ka, kb] += get_w(kk, ki, kj, kc, ka, kb).transpose(1, 2, 0, 4, 5, 3)
        tmp_t3[ki, kj, kk, ka, kb] += get_w(kk, kj, ki, kc, kb, ka).transpose(2, 1, 0, 5, 4, 3)

        # Orbital-energy denominator; presumably e_i+e_j+e_k - e_a-e_b-e_c
        # (fac=-1 on the virtuals), with padding handled inside _get_epqr
        # -- see kccsd_t_rhf._get_epqr.
        eijk = _get_epqr([0,nocc,ki,mo_e_o,nonzero_opadding],
                         [0,nocc,kj,mo_e_o,nonzero_opadding],
                         [0,nocc,kk,mo_e_o,nonzero_opadding])
        eabc = _get_epqr([0,nvir,ka,mo_e_v,nonzero_vpadding],
                         [0,nvir,kb,mo_e_v,nonzero_vpadding],
                         [0,nvir,kc,mo_e_v,nonzero_vpadding],
                         fac=[-1.,-1.,-1.])
        eijkabc = eijk[:, :, :, None, None, None] + eabc[None, None, None, :, :, :]
        tmp_t3[ki, kj, kk, ka, kb] /= eijkabc

    # T3[2] contribution to t1 (divided by the ia denominator further below)
    pt1 = np.zeros((nkpts, nocc, nvir), dtype=t2.dtype)
    for ki in range(nkpts):
        for km, kn, ke in product(range(nkpts), repeat=3):
            kf = kconserv[km, ke, kn]
            Soovv = 2. * eris.oovv[km, kn, ke] - eris.oovv[km, kn, kf].transpose(0, 1, 3, 2)
            St3 = (tmp_t3[ki, km, kn, ki, ke] -
                   tmp_t3[ki, km, kn, ke, ki].transpose(0, 1, 2, 4, 3, 5))
            pt1[ki] += lib.einsum('mnef,imnaef->ia', Soovv, St3)

    # T3[2] contribution to t2, accumulated as (ia,jb) + (jb,ia) terms
    # (divided by the ijab denominator further below)
    pt2 = np.zeros((nkpts, nkpts, nkpts, nocc, nocc, nvir, nvir), dtype=t2.dtype)
    for ki, kj, ka in product(range(nkpts), repeat=3):
        kb = kconserv[ki, ka, kj]
        for km in range(nkpts):
            for kn in range(nkpts):
                # (ia,jb) -> (ia,jb)
                ke = kconserv[km, kj, kn]
                pt2[ki, kj, ka] += - 2. * lib.einsum('imnabe,mnje->ijab',
                                                     tmp_t3[ki, km, kn, ka, kb],
                                                     eris.ooov[km, kn, kj])
                pt2[ki, kj, ka] += lib.einsum('imnabe,nmje->ijab',
                                              tmp_t3[ki, km, kn, ka, kb],
                                              eris.ooov[kn, km, kj])
                pt2[ki, kj, ka] += lib.einsum('inmeab,mnje->ijab',
                                              tmp_t3[ki, kn, km, ke, ka],
                                              eris.ooov[km, kn, kj])

                # (ia,jb) -> (jb,ia)
                ke = kconserv[km, ki, kn]
                pt2[ki, kj, ka] += - 2. * lib.einsum('jmnbae,mnie->ijab',
                                                     tmp_t3[kj, km, kn, kb, ka],
                                                     eris.ooov[km, kn, ki])
                pt2[ki, kj, ka] += lib.einsum('jmnbae,nmie->ijab',
                                              tmp_t3[kj, km, kn, kb, ka],
                                              eris.ooov[kn, km, ki])
                pt2[ki, kj, ka] += lib.einsum('jnmeba,mnie->ijab',
                                              tmp_t3[kj, kn, km, ke, kb],
                                              eris.ooov[km, kn, ki])

            # (ia,jb) -> (ia,jb)
            pt2[ki, kj, ka] += lib.einsum('ijmabe,me->ijab',
                                          tmp_t3[ki, kj, km, ka, kb],
                                          fov[km])
            pt2[ki, kj, ka] -= lib.einsum('ijmaeb,me->ijab',
                                          tmp_t3[ki, kj, km, ka, km],
                                          fov[km])

            # (ia,jb) -> (jb,ia)
            pt2[ki, kj, ka] += lib.einsum('jimbae,me->ijab',
                                          tmp_t3[kj, ki, km, kb, ka],
                                          fov[km])
            pt2[ki, kj, ka] -= lib.einsum('jimbea,me->ijab',
                                          tmp_t3[kj, ki, km, kb, km],
                                          fov[km])

            for ke in range(nkpts):
                # (ia,jb) -> (ia,jb)
                kf = kconserv[km, ke, kb]
                pt2[ki, kj, ka] += 2. * lib.einsum('ijmaef,bmef->ijab',
                                                   tmp_t3[ki, kj, km, ka, ke],
                                                   eris.vovv[kb, km, ke])
                pt2[ki, kj, ka] -= lib.einsum('ijmaef,bmfe->ijab',
                                              tmp_t3[ki, kj, km, ka, ke],
                                              eris.vovv[kb, km, kf])
                pt2[ki, kj, ka] -= lib.einsum('imjfae,bmef->ijab',
                                              tmp_t3[ki, km, kj, kf, ka],
                                              eris.vovv[kb, km, ke])

                # (ia,jb) -> (jb,ia)
                kf = kconserv[km, ke, ka]
                pt2[ki, kj, ka] += 2. * lib.einsum('jimbef,amef->ijab',
                                                   tmp_t3[kj, ki, km, kb, ke],
                                                   eris.vovv[ka, km, ke])
                pt2[ki, kj, ka] -= lib.einsum('jimbef,amfe->ijab',
                                              tmp_t3[kj, ki, km, kb, ke],
                                              eris.vovv[ka, km, kf])
                pt2[ki, kj, ka] -= lib.einsum('jmifbe,amef->ijab',
                                              tmp_t3[kj, km, ki, kf, kb],
                                              eris.vovv[ka, km, ke])

    # Divide by e_i - e_a.  Padded orbitals keep LARGE_DENOM, so their
    # corrections are driven toward zero.
    for ki in range(nkpts):
        ka = ki
        eia = LARGE_DENOM * np.ones((nocc, nvir), dtype=eris.mo_energy[0].dtype)
        n0_ovp_ia = np.ix_(nonzero_opadding[ki], nonzero_vpadding[ka])
        eia[n0_ovp_ia] = (mo_e_o[ki][:,None] - mo_e_v[ka])[n0_ovp_ia]
        pt1[ki] /= eia

    # Divide by e_i + e_j - e_a - e_b (padded entries keep LARGE_DENOM)
    for ki, ka in product(range(nkpts), repeat=2):
        eia = LARGE_DENOM * np.ones((nocc, nvir), dtype=eris.mo_energy[0].dtype)
        n0_ovp_ia = np.ix_(nonzero_opadding[ki], nonzero_vpadding[ka])
        eia[n0_ovp_ia] = (mo_e_o[ki][:,None] - mo_e_v[ka])[n0_ovp_ia]
        for kj in range(nkpts):
            kb = kconserv[ki, ka, kj]
            ejb = LARGE_DENOM * np.ones((nocc, nvir), dtype=eris.mo_energy[0].dtype)
            n0_ovp_jb = np.ix_(nonzero_opadding[kj], nonzero_vpadding[kb])
            ejb[n0_ovp_jb] = (mo_e_o[kj][:,None] - mo_e_v[kb])[n0_ovp_jb]
            # Broadcasting gives eijab[i,j,a,b] = eia[i,a] + ejb[j,b]
            eijab = eia[:, None, :, None] + ejb[:, None, :]
            pt2[ki, kj, ka] /= eijab

    pt1 += t1
    pt2 += t2

    # Contract T3[2] with oovv into the IP intermediate (accumulated in place)
    for ki, kj, kk, ka, kb in product(range(nkpts), repeat=5):
        kc = kpts_helper.get_kconserv3(cc._scf.cell, cc.kpts,
                                       [ki, kj, kk, ka, kb])
        km = kconserv[kc, ki, ka]

        _oovv = eris.oovv[km, ki, kc]
        Wmcik[km, kb, kk] += 2. * lib.einsum('ijkabc,mica->mbkj', tmp_t3[ki, kj, kk, ka, kb], _oovv)
        Wmcik[km, kb, kk] -=      lib.einsum('jikabc,mica->mbkj', tmp_t3[kj, ki, kk, ka, kb], _oovv)
        Wmcik[km, kb, kk] -=      lib.einsum('kjiabc,mica->mbkj', tmp_t3[kk, kj, ki, ka, kb], _oovv)

    # Contract T3[2] with oovv into the EA intermediate (accumulated in place)
    for ki, kj, kk, ka, kb in product(range(nkpts), repeat=5):
        kc = kpts_helper.get_kconserv3(cc._scf.cell, cc.kpts,
                                       [ki, kj, kk, ka, kb])
        ke = kconserv[ki, ka, kk]

        _oovv = eris.oovv[ki, kk, ka]
        Wacek[kc, kb, ke] -= 2. * lib.einsum('ijkabc,ikae->cbej', tmp_t3[ki, kj, kk, ka, kb], _oovv)
        Wacek[kc, kb, ke] +=      lib.einsum('jikabc,ikae->cbej', tmp_t3[kj, ki, kk, ka, kb], _oovv)
        Wacek[kc, kb, ke] +=      lib.einsum('kjiabc,ikae->cbej', tmp_t3[kk, kj, ki, ka, kb], _oovv)

    delta_ccsd_energy = cc.energy(pt1, pt2, eris) - ccsd_energy
    lib.logger.info(cc, 'CCSD energy T3[2] correction : %16.12e', delta_ccsd_energy)

    return delta_ccsd_energy, pt1, pt2, Wmcik, Wacek
655
656
657def _add_pt2(pt2, nkpts, kconserv, kpt_indices, orb_indices, val):
658    '''Adds term P(ia|jb)[tmp] to pt2.
659
660        P(ia|jb)(tmp[i,j,a,b]) = tmp[i,j,a,b] + tmp[j,i,b,a]
661
662    or equivalently for each i,j,a,b, pt2 is defined as
663
664        pt2[i,j,a,b] += tmp[i,j,a,b]
665        pt2[j,i,b,a] += tmp[i,j,a,b].transpose(1,0,3,2)
666
667    If pt2 is lower-triangular, only adds the RHS term that contributes
668    to the lower-triangular pt2.
669
670    Args:
671        pt2 (ndarray or HDF5 dataset):
672            Full or lower triangular T2 array to which one is adding to.
673        kpt_indices (array-like):
674            K-point indices ki, kj, ka.
675        orb_indices (array-like):
676            Array-like of four tuples describing the range for i,j,a,b.  An
677            element of None will convert to slice(None,None).
678        val (ndarray):
679            Values to be added to pt2.
680    '''
681    assert(len(orb_indices) == 4)
682    ki, kj, ka = kpt_indices
683    kb = kconserv[ki,ka,kj]
684    idxi, idxj, idxa, idxb = [slice(None, None)
685                              if x is None else slice(x[0],x[1])
686                              for x in orb_indices]
687    if len(pt2.shape) == 7 and pt2.shape[:2] == (nkpts, nkpts):
688        pt2[ki,kj,ka,idxi,idxj,idxa,idxb] += val
689        pt2[kj,ki,kb,idxj,idxi,idxb,idxa] += val.transpose(1,0,3,2)
690    elif len(pt2.shape) == 6 and pt2.shape[:2] == (nkpts*(nkpts+1)//2, nkpts):
691        if ki <= kj:  # Add tmp[i,j,a,b] to pt2[i,j,a,b]
692            idx = (kj*(kj+1))//2 + ki
693            pt2[idx,ka,idxi,idxj,idxa,idxb] += val
694            if ki == kj:
695                pt2[idx,kb,idxj,idxi,idxb,idxa] += val.transpose(1,0,3,2)
696        else:  # pt2[i,a,j,b] += tmp[j,i,a,b].transpose(1,0,3,2)
697            idx = (ki*(ki+1))//2 + kj
698            pt2[idx,kb,idxj,idxi,idxb,idxa] += val.transpose(1,0,3,2)
699    else:
700        raise ValueError('No known conversion for t2 shape %s' % pt2.shape)
701
702
def get_t3p2_imds(mycc, t1, t2, eris=None, t3p2_ip_out=None, t3p2_ea_out=None):
    """Blocked T3[2] corrections to t1/t2 and to the EOM IP/EA intermediates.

    For a fuller description of the arguments, see `get_t3p2_imds_slow`.
    Here the approximate triples amplitudes are never held in full: for each
    pair (ka, kb) and each block of virtual indices (a, b, c) they are built
    from t2 and the integrals for all (ki, kj, kk), divided by the
    orbital-energy denominator and contracted immediately into the outputs.

    Args:
        mycc:
            k-point RCCSD object.  Reads ``mycc.kpts``, ``mycc.khelper.kconserv``,
            ``mycc._scf.cell`` and ``mycc.max_memory``; ``mycc.energy`` is used
            for the energy correction and ``mycc.ao2mo`` when eris is None.
        t1 (ndarray):
            Shape (nkpts, nocc, nvir).
        t2 (ndarray):
            Added to an array of shape (nkpts, nkpts, nkpts, nocc, nocc, nvir, nvir).
        eris:
            Integrals object providing ``fock``, ``mo_energy``, ``vovv``,
            ``oovv`` and ``ooov``.  Built with ``mycc.ao2mo()`` when None.
        t3p2_ip_out (array-like, optional):
            (nkpts, nkpts, nkpts, nocc, nvir, nocc, nocc) output for the IP
            intermediate.  Contributions are accumulated into it (it is not
            zeroed first); a zero array is allocated when None.
        t3p2_ea_out (array-like, optional):
            (nkpts, nkpts, nkpts, nvir, nvir, nvir, nocc) output for the EA
            intermediate, with the same accumulation semantics.

    Returns:
        tuple: (delta_ccsd_energy, pt1, pt2, Wmcik, Wacek) where
        delta_ccsd_energy = mycc.energy(pt1, pt2) - mycc.energy(t1, t2),
        pt1/pt2 are t1/t2 plus their T3 corrections, and Wmcik/Wacek are
        t3p2_ip_out/t3p2_ea_out (the same objects when supplied).
    """
    from pyscf.pbc.cc.kccsd_t_rhf import _get_epqr
    cpu1 = cpu0 = (logger.process_clock(), logger.perf_counter())
    if eris is None:
        eris = mycc.ao2mo()
    fock = eris.fock
    nkpts, nocc, nvir = t1.shape
    cell = mycc._scf.cell
    kpts = mycc.kpts
    kconserv = mycc.khelper.kconserv
    dtype = np.result_type(t1, t2)

    fov = fock[:, :nocc, nocc:]
    mo_e_o = np.array([eris.mo_energy[ki][:nocc] for ki in range(nkpts)])
    mo_e_v = np.array([eris.mo_energy[ki][nocc:] for ki in range(nkpts)])

    ccsd_energy = mycc.energy(t1, t2, eris)

    if t3p2_ip_out is None:
        t3p2_ip_out = np.zeros((nkpts,nkpts,nkpts,nocc,nvir,nocc,nocc),dtype=dtype)
    Wmcik = t3p2_ip_out

    if t3p2_ea_out is None:
        t3p2_ea_out = np.zeros((nkpts,nkpts,nkpts,nvir,nvir,nvir,nocc),dtype=dtype)
    Wacek = t3p2_ea_out

    # Create necessary temporary eris for fast read.
    # NOTE(review): feri_tmp is not used below; presumably it is the storage
    # handle behind t2T/eris_vvop/eris_vooo_C, so it is kept bound.
    from pyscf.pbc.cc.kccsd_t_rhf import create_t3_eris
    feri_tmp, t2T, eris_vvop, eris_vooo_C = create_t3_eris(mycc, kconserv, [eris.vovv, eris.oovv, eris.ooov, t2])
    cpu1 = logger.timer_debug1(mycc, 'CCSD(T) tmp eri creation', *cpu1)

    def get_w(ki, kj, kk, ka, kb, kc, a0, a1, b0, b1, c0, c1):
        '''Wijkabc intermediate as described in Scuseria paper before Pijkabc acts

        Function copied from `kccsd_t_rhf.py`'''
        km = kconserv[kc, kk, kb]
        kf = kconserv[kk, kc, kj]
        out = einsum('cfjk,abif->abcijk', t2T[kc,kf,kj,c0:c1,:,:,:], eris_vvop[ka,kb,ki,a0:a1,b0:b1,:,nocc:])
        out = out - einsum('cbmk,aijm->abcijk', t2T[kc,kb,km,c0:c1,b0:b1,:,:], eris_vooo_C[ka,ki,kj,a0:a1,:,:,:])
        return out

    def get_permuted_w(ki, kj, kk, ka, kb, kc, orb_indices):
        '''Pijkabc operating on Wijkabc intermediate as described in Scuseria paper

        Function copied from `kccsd_t_rhf.py`'''
        a0, a1, b0, b1, c0, c1 = orb_indices
        out = get_w(ki, kj, kk, ka, kb, kc, a0, a1, b0, b1, c0, c1)
        out = out + get_w(kj, kk, ki, kb, kc, ka, b0, b1, c0, c1, a0, a1).transpose(2,0,1,5,3,4)
        out = out + get_w(kk, ki, kj, kc, ka, kb, c0, c1, a0, a1, b0, b1).transpose(1,2,0,4,5,3)
        out = out + get_w(ki, kk, kj, ka, kc, kb, a0, a1, c0, c1, b0, b1).transpose(0,2,1,3,5,4)
        out = out + get_w(kk, kj, ki, kc, kb, ka, c0, c1, b0, b1, a0, a1).transpose(2,1,0,5,4,3)
        out = out + get_w(kj, ki, kk, kb, ka, kc, b0, b1, a0, a1, c0, c1).transpose(1,0,2,4,3,5)
        return out

    def add_and_permute(data):
        '''Return 2*W(ijk) - W(jik) - W(kji) in (a,b,c,i,j,k) axis order.

        data holds the permuted-W blocks at k-point triples (ki,kj,kk),
        (kj,ki,kk) and (kk,kj,ki); the occupied axes of the last two are
        swapped back so all three refer to the same (i, j, k).  A new array
        is returned, so the inputs are not modified.'''
        w_ijk = np.asarray(data[0], dtype=dtype, order='C')
        w_jik = np.asarray(data[1], dtype=dtype, order='C')
        w_kji = np.asarray(data[2], dtype=dtype, order='C')
        return (2.*w_ijk -
                w_jik.transpose(0,1,2,4,3,5) -
                w_kji.transpose(0,1,2,5,4,3))

    # Get location of padded elements in occupied and virtual space
    nonzero_opadding, nonzero_vpadding = padding_k_idx(mycc, kind="split")

    mem_now = lib.current_memory()[0]
    max_memory = max(0, mycc.max_memory - mem_now)
    blkmin = 4
    # temporary t3 array is size:  nkpts**3 * blksize**3 * nocc**3 * 16
    vir_blksize = min(nvir, max(blkmin, int((max_memory*.9e6/16/nocc**3/nkpts**3)**(1./3))))
    logger.debug(mycc, 'max_memory %d MB (%d MB in use)', max_memory, mem_now)
    logger.debug(mycc, 'virtual blksize = %d (nvir = %d)', vir_blksize, nvir)
    tasks = [(a0, a1, b0, b1, c0, c1)
             for a0, a1 in lib.prange(0, nvir, vir_blksize)
             for b0, b1 in lib.prange(0, nvir, vir_blksize)
             for c0, c1 in lib.prange(0, nvir, vir_blksize)]

    pt1 = np.zeros((nkpts,nocc,nvir), dtype=dtype)
    pt2 = np.zeros((nkpts,nkpts,nkpts,nocc,nocc,nvir,nvir), dtype=dtype)
    for ka, kb in product(range(nkpts), repeat=2):
        for task_id, task in enumerate(tasks):
            cput2 = (logger.process_clock(), logger.perf_counter())
            a0,a1,b0,b1,c0,c1 = task
            my_permuted_w = np.zeros((nkpts,)*3 + (a1-a0,b1-b0,c1-c0) + (nocc,)*3, dtype=dtype)

            # First pass: P(ijk/abc) W for every (ki, kj, kk) of this
            # (ka, kb, task) block; the second pass reads it back at
            # permuted k-point triples.
            for ki, kj, kk in product(range(nkpts), repeat=3):
                # Find momentum conservation condition for triples
                # amplitude t3ijkabc
                kc = kpts_helper.get_kconserv3(cell, kpts, [ki, kj, kk, ka, kb])
                my_permuted_w[ki,kj,kk] = get_permuted_w(ki,kj,kk,ka,kb,kc,task)

            for ki, kj, kk in product(range(nkpts), repeat=3):
                # eigenvalue denominator: e(i) + e(j) + e(k)
                eijk = _get_epqr([0,nocc,ki,mo_e_o,nonzero_opadding],
                                 [0,nocc,kj,mo_e_o,nonzero_opadding],
                                 [0,nocc,kk,mo_e_o,nonzero_opadding])

                # Find momentum conservation condition for triples
                # amplitude t3ijkabc
                kc = kpts_helper.get_kconserv3(cell, kpts, [ki, kj, kk, ka, kb])
                eabc = _get_epqr([a0,a1,ka,mo_e_v,nonzero_vpadding],
                                 [b0,b1,kb,mo_e_v,nonzero_vpadding],
                                 [c0,c1,kc,mo_e_v,nonzero_vpadding],
                                 fac=[-1.,-1.,-1.])

                eabcijk = (eijk[None,None,None,:,:,:] + eabc[:,:,:,None,None,None])

                Ptmp_t3Tv = add_and_permute((my_permuted_w[ki,kj,kk],
                                             my_permuted_w[kj,ki,kk],
                                             my_permuted_w[kk,kj,ki]))
                Ptmp_t3Tv /= eabcijk

                if ki == ka and kc == kconserv[kj, kb, kk]:
                    # Contribution to T1 amplitudes
                    eris_Soovv = (2.*eris.oovv[kj,kk,kb,:,:,b0:b1,c0:c1] -
                                     eris.oovv[kj,kk,kc,:,:,c0:c1,b0:b1].transpose(0,1,3,2))
                    pt1[ka,:,a0:a1] += 0.5*einsum('abcijk,jkbc->ia', Ptmp_t3Tv,
                                                  eris_Soovv)

                    # Contribution to T2 amplitudes through the occupied-virtual Fock block
                    tmp = einsum('abcijk,ia->jkbc', Ptmp_t3Tv, 0.5*fov[ki,:,a0:a1])
                    _add_pt2(pt2, nkpts, kconserv, [kj,kk,kb], [None,None,(b0,b1),(c0,c1)], tmp)

                # Contribution to T2 amplitudes (vovv integrals)
                kd = kconserv[ka,ki,kb]
                eris_vovv = eris.vovv[kd,ki,kb,:,:,b0:b1,a0:a1]
                tmp = einsum('abcijk,diba->jkdc', Ptmp_t3Tv, eris_vovv)
                _add_pt2(pt2, nkpts, kconserv, [kj,kk,kd], [None,None,None,(c0,c1)], tmp)

                # Contribution to T2 amplitudes (ooov integrals)
                km = kconserv[kc, kk, kb]
                eris_ooov = eris.ooov[kj,ki,km,:,:,:,a0:a1]
                tmp = einsum('abcijk,jima->mkbc', Ptmp_t3Tv, eris_ooov)
                _add_pt2(pt2, nkpts, kconserv, [km,kk,kb], [None,None,(b0,b1),(c0,c1)], -1.*tmp)

                # Contribution to Wovoo array
                km = kconserv[ka,ki,kc]
                eris_oovv = eris.oovv[km,ki,kc,:,:,c0:c1,a0:a1]
                tmp = einsum('abcijk,mica->mbkj', Ptmp_t3Tv, eris_oovv)
                Wmcik[km,kb,kk,:,b0:b1,:,:] += tmp

                # Contribution to Wvvoo array
                ke = kconserv[ki,ka,kk]
                eris_oovv = eris.oovv[ki,kk,ka,:,:,a0:a1,:]
                tmp = einsum('abcijk,ikae->cbej', Ptmp_t3Tv, eris_oovv)
                Wacek[kc,kb,ke,c0:c1,b0:b1,:,:] -= tmp

            logger.timer_debug1(mycc, 'EOM-CCSD T3[2] ka,kb,vir=(%d,%d,%d/%d) [total=%d]'%
                                (ka,kb,task_id,len(tasks),nkpts**5), *cput2)

    # Divide by the singles denominators e(i) - e(a); padded orbitals get
    # LARGE_DENOM so their entries are effectively zeroed.
    for ki in range(nkpts):
        eia = LARGE_DENOM * np.ones((nocc, nvir), dtype=eris.mo_energy[0].dtype)
        n0_ovp_ia = np.ix_(nonzero_opadding[ki], nonzero_vpadding[ki])
        eia[n0_ovp_ia] = (mo_e_o[ki][:,None] - mo_e_v[ki])[n0_ovp_ia]
        pt1[ki] /= eia

    # Divide by the doubles denominators e(i) + e(j) - e(a) - e(b).
    for ki, ka in product(range(nkpts), repeat=2):
        eia = LARGE_DENOM * np.ones((nocc, nvir), dtype=eris.mo_energy[0].dtype)
        n0_ovp_ia = np.ix_(nonzero_opadding[ki], nonzero_vpadding[ka])
        eia[n0_ovp_ia] = (mo_e_o[ki][:,None] - mo_e_v[ka])[n0_ovp_ia]
        for kj in range(nkpts):
            kb = kconserv[ki, ka, kj]
            ejb = LARGE_DENOM * np.ones((nocc, nvir), dtype=eris.mo_energy[0].dtype)
            n0_ovp_jb = np.ix_(nonzero_opadding[kj], nonzero_vpadding[kb])
            ejb[n0_ovp_jb] = (mo_e_o[kj][:,None] - mo_e_v[kb])[n0_ovp_jb]
            # eijab[i,j,a,b] = eia[i,a] + ejb[j,b]
            eijab = eia[:, None, :, None] + ejb[None, :, None, :]
            pt2[ki, kj, ka] /= eijab

    pt1 += t1
    pt2 += t2

    logger.timer(mycc, 'EOM-CCSD(T) imds', *cpu0)

    delta_ccsd_energy = mycc.energy(pt1, pt2, eris) - ccsd_energy
    logger.info(mycc, 'CCSD energy T3[2] correction : %16.12e', delta_ccsd_energy)

    return delta_ccsd_energy, pt1, pt2, Wmcik, Wacek
927