1#!/usr/bin/env python
2# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Qiming Sun <osirpt.sun@gmail.com>
17#
18
'''
UCCSD(T): perturbative triples (T) energy correction for unrestricted CCSD
'''
22
23
24import ctypes
25import numpy
26from pyscf import lib
27from pyscf.lib import logger
28from pyscf.cc import _ccsd
29
def kernel(mycc, eris, t1=None, t2=None, verbose=logger.NOTE):
    '''Evaluate the UCCSD(T) energy correction.

    The correction is accumulated from four spin blocks (aaa, bbb, baa, abb)
    by compiled drivers in ``_ccsd.libcc``; the virtual index ranges are
    processed in chunks sized from the memory still available to ``mycc``.

    Args:
        mycc : CC object.  The attributes read here are ``nocc``, ``t1``,
            ``t2``, ``max_memory``, ``incore_complete`` and ``async_io``.
        eris : integral container providing ``focka``, ``fockb``,
            ``mo_energy``, the ``ovoo``-type blocks and whatever
            :func:`_sort_eri` reads.
        t1 : pair ``(t1a, t1b)``, each indexed (occ, vir).  Defaults to
            ``mycc.t1``.
        t2 : triple ``(t2aa, t2ab, t2bb)``, each indexed (occ, occ, vir, vir).
            Defaults to ``mycc.t2``.
        verbose : verbosity level or logger handed to ``logger.new_logger``.

    Returns:
        The real part of the accumulated sum scaled by 1/4.  A warning is
        logged when the imaginary part exceeds 1e-4.

    Note:
        To save memory the buffer of ``t2ab`` is temporarily overwritten with
        transposed copies of itself.  The original content is written back
        before this function returns, also when a contraction raises.  A
        ``t2ab`` that is not C-contiguous is copied first, so the caller's
        array is then never touched.

        The buffer-size estimates divide by ``nocc*nmo``; a spin channel
        without occupied orbitals therefore raises ``ZeroDivisionError``.
    '''
    cpu1 = cpu0 = (logger.process_clock(), logger.perf_counter())
    log = logger.new_logger(mycc, verbose)
    if t1 is None: t1 = mycc.t1
    if t2 is None: t2 = mycc.t2
    t1a, t1b = t1
    t2aa, t2ab, t2bb = t2

    nocca, noccb = mycc.nocc
    nmoa = eris.focka.shape[0]
    nmob = eris.fockb.shape[0]
    nvira = nmoa - nocca
    nvirb = nmob - noccb

    # The sorted integrals are kept in memory or in a temporary HDF5 file
    if mycc.incore_complete:
        ftmp = None
    else:
        ftmp = lib.H5TmpFile()
    # Amplitudes with the virtual indices leading, as C-contiguous copies
    t1aT = t1a.T.copy()
    t1bT = t1b.T.copy()
    t2aaT = t2aa.transpose(2,3,0,1).copy()
    t2bbT = t2bb.transpose(2,3,0,1).copy()

    eris_vooo = numpy.asarray(eris.ovoo).transpose(1,3,0,2).conj().copy()
    eris_VOOO = numpy.asarray(eris.OVOO).transpose(1,3,0,2).conj().copy()
    eris_vOoO = numpy.asarray(eris.ovOO).transpose(1,3,0,2).conj().copy()
    eris_VoOo = numpy.asarray(eris.OVoo).transpose(1,3,0,2).conj().copy()

    eris_vvop, eris_VVOP, eris_vVoP, eris_VvOp = _sort_eri(mycc, eris, ftmp, log)
    cpu1 = log.timer_debug1('UCCSD(T) sort_eri', *cpu1)

    dtype = numpy.result_type(t1a.dtype, t2aa.dtype, eris_vooo.dtype)
    # One-element array: it is handed to the C drivers by pointer and read
    # back after all blocks have been processed.
    et_sum = numpy.zeros(1, dtype=dtype)
    mem_now = lib.current_memory()[0]
    max_memory = max(0, mycc.max_memory - mem_now)
    # aaa
    bufsize = max(8, int((max_memory*.5e6/8-nocca**3*3*lib.num_threads())*.4/(nocca*nmoa)))
    log.debug('max_memory %d MB (%d MB in use)', max_memory, mem_now)
    orbsym = numpy.zeros(nocca, dtype=int)
    contract = _gen_contract_aaa(t1aT, t2aaT, eris_vooo, eris.focka,
                                 eris.mo_energy[0], orbsym, log)
    with lib.call_in_background(contract, sync=not mycc.async_io) as ctr:
        for a0, a1 in reversed(list(lib.prange_tril(0, nvira, bufsize))):
            cache_row_a = numpy.asarray(eris_vvop[a0:a1,:a1], order='C')
            if a0 == 0:
                cache_col_a = cache_row_a
            else:
                cache_col_a = numpy.asarray(eris_vvop[:a0,a0:a1], order='C')
            # diagonal chunk: both virtual ranges are a0:a1
            ctr(et_sum, a0, a1, a0, a1, (cache_row_a,cache_col_a,
                                         cache_row_a,cache_col_a))

            # off-diagonal chunks: second range runs below a0
            for b0, b1 in lib.prange_tril(0, a0, bufsize/8):
                cache_row_b = numpy.asarray(eris_vvop[b0:b1,:b1], order='C')
                if b0 == 0:
                    cache_col_b = cache_row_b
                else:
                    cache_col_b = numpy.asarray(eris_vvop[:b0,b0:b1], order='C')
                ctr(et_sum, a0, a1, b0, b1, (cache_row_a,cache_col_a,
                                             cache_row_b,cache_col_b))
    cpu1 = log.timer_debug1('contract_aaa', *cpu1)

    # bbb
    bufsize = max(8, int((max_memory*.5e6/8-noccb**3*3*lib.num_threads())*.4/(noccb*nmob)))
    log.debug('max_memory %d MB (%d MB in use)', max_memory, mem_now)
    orbsym = numpy.zeros(noccb, dtype=int)
    contract = _gen_contract_aaa(t1bT, t2bbT, eris_VOOO, eris.fockb,
                                 eris.mo_energy[1], orbsym, log)
    with lib.call_in_background(contract, sync=not mycc.async_io) as ctr:
        for a0, a1 in reversed(list(lib.prange_tril(0, nvirb, bufsize))):
            cache_row_a = numpy.asarray(eris_VVOP[a0:a1,:a1], order='C')
            if a0 == 0:
                cache_col_a = cache_row_a
            else:
                cache_col_a = numpy.asarray(eris_VVOP[:a0,a0:a1], order='C')
            ctr(et_sum, a0, a1, a0, a1, (cache_row_a,cache_col_a,
                                         cache_row_a,cache_col_a))

            for b0, b1 in lib.prange_tril(0, a0, bufsize/8):
                cache_row_b = numpy.asarray(eris_VVOP[b0:b1,:b1], order='C')
                if b0 == 0:
                    cache_col_b = cache_row_b
                else:
                    cache_col_b = numpy.asarray(eris_VVOP[:b0,b0:b1], order='C')
                ctr(et_sum, a0, a1, b0, b1, (cache_row_a,cache_col_a,
                                             cache_row_b,cache_col_b))
    cpu1 = log.timer_debug1('contract_bbb', *cpu1)

    # Cache t2abT in t2ab to reduce memory footprint.  The in-place transpose
    # needs a C-contiguous buffer; otherwise work on a private copy, which
    # also leaves the caller's array untouched.
    if not t2ab.flags.c_contiguous:
        t2ab = numpy.ascontiguousarray(t2ab)
    # View of the buffer indexed (vira, virb, occa, occb) while it holds
    # transposed data; None as long as the buffer holds the original t2ab.
    t2ab_view = None
    try:
        t2abT = lib.transpose(t2ab.copy().reshape(nocca*noccb,nvira*nvirb), out=t2ab)
        t2abT = t2abT.reshape(nvira,nvirb,nocca,noccb)
        t2ab_view = t2abT
        # baa
        bufsize = int(max(12, (max_memory*.5e6/8-noccb*nocca**2*5)*.7/(nocca*nmob)))
        ts = t1aT, t1bT, t2aaT, t2abT
        fock = (eris.focka, eris.fockb)
        vooo = (eris_vooo, eris_vOoO, eris_VoOo)
        contract = _gen_contract_baa(ts, vooo, fock, eris.mo_energy, orbsym, log)
        with lib.call_in_background(contract, sync=not mycc.async_io) as ctr:
            for a0, a1 in lib.prange(0, nvirb, int(bufsize/nvira+1)):
                cache_row_a = numpy.asarray(eris_VvOp[a0:a1,:], order='C')
                cache_col_a = numpy.asarray(eris_vVoP[:,a0:a1], order='C')
                for b0, b1 in lib.prange_tril(0, nvira, bufsize/6/2):
                    cache_row_b = numpy.asarray(eris_vvop[b0:b1,:b1], order='C')
                    cache_col_b = numpy.asarray(eris_vvop[:b0,b0:b1], order='C')
                    ctr(et_sum, a0, a1, b0, b1, (cache_row_a,cache_col_a,
                                                 cache_row_b,cache_col_b))
        cpu1 = log.timer_debug1('contract_baa', *cpu1)

        # Reuse the same buffer once more, now with the two spins swapped
        t2baT = numpy.ndarray((nvirb,nvira,noccb,nocca), buffer=t2abT,
                              dtype=t2abT.dtype)
        t2baT[:] = t2abT.copy().transpose(1,0,3,2)
        t2ab_view = t2baT.transpose(1,0,3,2)
        # abb
        ts = t1bT, t1aT, t2bbT, t2baT
        fock = (eris.fockb, eris.focka)
        mo_energy = (eris.mo_energy[1], eris.mo_energy[0])
        vooo = (eris_VOOO, eris_VoOo, eris_vOoO)
        contract = _gen_contract_baa(ts, vooo, fock, mo_energy, orbsym, log)
        for a0, a1 in lib.prange(0, nvira, int(bufsize/nvirb+1)):
            # a new background context is opened for every a0:a1 chunk
            with lib.call_in_background(contract, sync=not mycc.async_io) as ctr:
                cache_row_a = numpy.asarray(eris_vVoP[a0:a1,:], order='C')
                cache_col_a = numpy.asarray(eris_VvOp[:,a0:a1], order='C')
                for b0, b1 in lib.prange_tril(0, nvirb, bufsize/6/2):
                    cache_row_b = numpy.asarray(eris_VVOP[b0:b1,:b1], order='C')
                    cache_col_b = numpy.asarray(eris_VVOP[:b0,b0:b1], order='C')
                    ctr(et_sum, a0, a1, b0, b1, (cache_row_a,cache_col_a,
                                                 cache_row_b,cache_col_b))
        cpu1 = log.timer_debug1('contract_abb', *cpu1)
    finally:
        # Restore t2ab, whichever transposed layout the buffer currently
        # holds and regardless of whether a contraction failed.
        if t2ab_view is not None:
            lib.transpose(t2ab_view.copy().reshape(nvira*nvirb,nocca*noccb),
                          out=t2ab)

    et_sum *= .25
    if abs(et_sum[0].imag) > 1e-4:
        logger.warn(mycc, 'Non-zero imaginary part of UCCSD(T) energy was found %s',
                    et_sum[0])
    et = et_sum[0].real
    log.timer('UCCSD(T)', *cpu0)
    log.note('UCCSD(T) correction = %.15g', et)
    return et
169
def _gen_contract_aaa(t1T, t2T, vooo, fock, mo_energy, orbsym, log):
    '''Build the callback that processes one chunk of a same-spin block.

    Args:
        t1T : array of shape (nvir, nocc).
        t2T : doubles amplitudes forwarded to the C driver.
        vooo : integral block forwarded to the C driver.
        fock : Fock matrix; its [nocc:, :nocc] block is copied and forwarded.
        mo_energy : orbital energies, forwarded as a C-ordered array.
        orbsym : integer symmetry label per orbital, occupied labels first.
        log : logger used for per-chunk timing.

    Returns:
        ``contract(et_sum, a0, a1, b0, b1, cache)``, which calls
        ``CCuccsd_t_zaaa`` when the combined dtype of ``t2T``, ``vooo`` and
        ``fock`` is complex128 and ``CCuccsd_t_aaa`` otherwise.  ``cache``
        must unpack into four arrays.
    '''
    nvir, nocc = t1T.shape
    mo_energy = numpy.asarray(mo_energy, order='C')
    fvo = fock[nocc:,:nocc].copy()

    # mutable so the closure below can update the running timer in place
    timing = [logger.process_clock(), logger.perf_counter()]

    def _offsets(labels):
        # start offset of every label value (at least 8 bins), as int32
        counts = numpy.bincount(labels, minlength=8)
        return numpy.append(0, numpy.cumsum(counts)).astype(numpy.int32)

    # sort the labels within the occupied and within the virtual space
    occ_labels = numpy.sort(orbsym[:nocc])
    vir_labels = numpy.sort(orbsym[nocc:])
    orbsym = numpy.hstack((occ_labels, vir_labels))
    o_ir_loc = _offsets(occ_labels)
    v_ir_loc = _offsets(vir_labels)
    # label of every occupied pair: XOR of the two single-orbital labels
    pair_labels = (occ_labels[:,None] ^ occ_labels).ravel()
    oo_ir_loc = _offsets(pair_labels)
    nirrep = max(pair_labels) + 1
    orbsym = orbsym.astype(numpy.int32)

    if numpy.result_type(t2T.dtype, vooo.dtype, fock.dtype) == numpy.complex128:
        drv = _ccsd.libcc.CCuccsd_t_zaaa
    else:
        drv = _ccsd.libcc.CCuccsd_t_aaa

    def _ptr(arr):
        return arr.ctypes.data_as(ctypes.c_void_p)

    def contract(et_sum, a0, a1, b0, b1, cache):
        row_a, col_a, row_b, col_b = cache
        drv(_ptr(et_sum),
            _ptr(mo_energy),
            _ptr(t1T),
            _ptr(t2T),
            _ptr(vooo),
            _ptr(fvo),
            ctypes.c_int(nocc), ctypes.c_int(nvir),
            ctypes.c_int(a0), ctypes.c_int(a1),
            ctypes.c_int(b0), ctypes.c_int(b1),
            ctypes.c_int(nirrep),
            _ptr(o_ir_loc),
            _ptr(v_ir_loc),
            _ptr(oo_ir_loc),
            _ptr(orbsym),
            _ptr(row_a),
            _ptr(col_a),
            _ptr(row_b),
            _ptr(col_b))
        timing[:] = log.timer_debug1('contract %d:%d,%d:%d'%(a0,a1,b0,b1), *timing)
    return contract
215
def _gen_contract_baa(ts, vooo, fock, mo_energy, orbsym, log):
    '''Build the callback that processes one chunk of a mixed-spin block.

    Args:
        ts : ``(t1aT, t1bT, t2aaT, t2abT)``; the two t1 arrays have shape
            (nvir, nocc) for their respective spin.
        vooo : three integral blocks, forwarded to the C driver in order.
        fock : ``(focka, fockb)``; the [nocc:, :nocc] block of each is
            copied and forwarded.
        mo_energy : pair of orbital-energy arrays in the same spin order.
        orbsym : accepted for signature parity; not used here.
        log : logger used for per-chunk timing.

    Returns:
        ``contract(et_sum, a0, a1, b0, b1, cache)``, which calls
        ``CCuccsd_t_zbaa`` when the combined dtype of ``t2aaT`` and the first
        integral block is complex128 and ``CCuccsd_t_baa`` otherwise.
        ``cache`` must unpack into four arrays.
    '''
    t1aT, t1bT, t2aaT, t2abT = ts
    focka, fockb = fock
    eri_vooo, eri_vOoO, eri_VoOo = vooo
    nvira, nocca = t1aT.shape
    nvirb, noccb = t1bT.shape
    mo_ea = numpy.asarray(mo_energy[0], order='C')
    mo_eb = numpy.asarray(mo_energy[1], order='C')
    fvo = focka[nocca:,:nocca].copy()
    fVO = fockb[noccb:,:noccb].copy()

    # mutable so the closure below can update the running timer in place
    timing = [logger.process_clock(), logger.perf_counter()]
    if numpy.result_type(t2aaT.dtype, eri_vooo.dtype) == numpy.complex128:
        drv = _ccsd.libcc.CCuccsd_t_zbaa
    else:
        drv = _ccsd.libcc.CCuccsd_t_baa

    def _ptr(arr):
        return arr.ctypes.data_as(ctypes.c_void_p)

    def contract(et_sum, a0, a1, b0, b1, cache):
        row_a, col_a, row_b, col_b = cache
        drv(_ptr(et_sum),
            _ptr(mo_ea),
            _ptr(mo_eb),
            _ptr(t1aT),
            _ptr(t1bT),
            _ptr(t2aaT),
            _ptr(t2abT),
            _ptr(eri_vooo),
            _ptr(eri_vOoO),
            _ptr(eri_VoOo),
            _ptr(fvo),
            _ptr(fVO),
            ctypes.c_int(nocca), ctypes.c_int(noccb),
            ctypes.c_int(nvira), ctypes.c_int(nvirb),
            ctypes.c_int(a0), ctypes.c_int(a1),
            ctypes.c_int(b0), ctypes.c_int(b1),
            _ptr(row_a),
            _ptr(col_a),
            _ptr(row_b),
            _ptr(col_b))
        timing[:] = log.timer_debug1('contract %d:%d,%d:%d'%(a0,a1,b0,b1), *timing)
    return contract
257
def _sort_eri(mycc, eris, h5tmp, log):
    '''Assemble the four (vir, vir, occ, mo) integral arrays.

    For every index ``j`` of the leading virtual axis, the conjugated
    ``ovov``-type and ``ovvv``-type slices are stacked along the mo axis and
    stored with the remaining virtual axis moved to the front.

    Args:
        mycc : CC object; ``nocc``, ``t2``, ``max_memory``,
            ``incore_complete`` and ``async_io`` are read.
        eris : integral container providing ``focka``, ``fockb``, ``ovov``,
            ``OVOV``, ``ovOV`` and the ``get_ovvv``-type accessors.
        h5tmp : open HDF5 file for the results, or None to keep them in
            memory (also forced by ``mycc.incore_complete``).
        log : logger used for timing output.

    Returns:
        ``(eris_vvop, eris_VVOP, eris_vVoP, eris_VvOp)``.
    '''
    clock = (logger.process_clock(), logger.perf_counter())
    nocca, noccb = mycc.nocc
    nmoa = eris.focka.shape[0]
    nmob = eris.fockb.shape[0]
    nvira = nmoa - nocca
    nvirb = nmob - noccb

    if mycc.t2 is None:
        dtype = eris.ovov.dtype
    else:
        dtype = numpy.result_type(mycc.t2[0], eris.ovov.dtype)

    layout = (('vvop', (nvira,nvira,nocca,nmoa)),
              ('VVOP', (nvirb,nvirb,noccb,nmob)),
              ('vVoP', (nvira,nvirb,nocca,nmob)),
              ('VvOp', (nvirb,nvira,noccb,nmoa)))
    if mycc.incore_complete or h5tmp is None:
        targets = [numpy.empty(shape, dtype) for _, shape in layout]
    else:
        targets = [h5tmp.create_dataset(name, shape, dtype)
                   for name, shape in layout]
    eris_vvop, eris_VVOP, eris_vVoP, eris_VvOp = targets

    max_memory = max(2000, mycc.max_memory - lib.current_memory()[0])
    max_memory = min(8000, max_memory*.9)

    def fill(target, nrow, buf_shape, nsplit, load_ovov, load_ovvv, blksize, clock):
        # Write target[j] for j in range(nrow).  Two buffers alternate so the
        # next slice can be assembled while the previous one is being saved.
        with lib.call_in_background(target.__setitem__, sync=not mycc.async_io) as save:
            active = numpy.empty(buf_shape, dtype=dtype)
            standby = numpy.empty_like(active)
            for j0, j1 in lib.prange(0, nrow, blksize):
                ovov = load_ovov(j0, j1)
                ovvv = load_ovvv(j0, j1)
                for k in range(j1 - j0):
                    active[:,:nsplit,:] = ovov[:,k].conj()
                    active[:,nsplit:,:] = ovvv[:,k].conj()
                    save(j0 + k, active.transpose(2,0,1))
                    active, standby = standby, active
                # drop the loaded chunk before the next one is read
                ovov = ovvv = None
                clock = log.timer_debug1('transpose %d:%d'%(j0,j1), *clock)
        return clock

    # vvop
    blksize = min(nvira, max(16, int(max_memory*1e6/8/(nvira*nocca*nmoa))))
    clock = fill(eris_vvop, nvira, (nocca,nmoa,nvira), nocca,
                 lambda j0, j1: numpy.asarray(eris.ovov[:,j0:j1]),
                 lambda j0, j1: eris.get_ovvv(slice(None), slice(j0,j1)),
                 blksize, clock)

    # VVOP
    blksize = min(nvirb, max(16, int(max_memory*1e6/8/(nvirb*noccb*nmob))))
    clock = fill(eris_VVOP, nvirb, (noccb,nmob,nvirb), noccb,
                 lambda j0, j1: numpy.asarray(eris.OVOV[:,j0:j1]),
                 lambda j0, j1: eris.get_OVVV(slice(None), slice(j0,j1)),
                 blksize, clock)

    # vVoP
    blksize = min(nvira, max(16, int(max_memory*1e6/8/(nvirb*nocca*nmob))))
    clock = fill(eris_vVoP, nvira, (nocca,nmob,nvirb), noccb,
                 lambda j0, j1: numpy.asarray(eris.ovOV[:,j0:j1]),
                 lambda j0, j1: eris.get_ovVV(slice(None), slice(j0,j1)),
                 blksize, clock)

    # VvOp: the ovOV block is read once and viewed with the spin pairs swapped
    blksize = min(nvirb, max(16, int(max_memory*1e6/8/(nvira*noccb*nmoa))))
    OVov = numpy.asarray(eris.ovOV).transpose(2,3,0,1)
    clock = fill(eris_VvOp, nvirb, (noccb,nmoa,nvira), nocca,
                 lambda j0, j1: OVov[:,j0:j1],
                 lambda j0, j1: eris.get_OVvv(slice(None), slice(j0,j1)),
                 blksize, clock)
    return eris_vvop, eris_VVOP, eris_vVoP, eris_VvOp
346
347
if __name__ == '__main__':
    # Self-check script: each case prints the deviation from a stored
    # reference value.
    from pyscf import gto
    from pyscf import scf
    from pyscf import cc

    def _water_geometry():
        # a fresh list per molecule so the two Mole objects share nothing
        return [
            [8 , (0. , 0.     , 0.)],
            [1 , (0. , -.757 , .587)],
            [1 , (0. ,  .757 , .587)]]

    def _antisymmetrize(amp):
        # antisymmetrize over the last two axes, then over the first two
        amp = amp - amp.transpose(0,1,3,2)
        return amp - amp.transpose(1,0,2,3)

    # Case 1: amplitudes of a converged RHF-CCSD run, expanded to the three
    # spin blocks.
    mol = gto.Mole()
    mol.atom = _water_geometry()
    mol.basis = '631g'
    mol.build()
    rhf = scf.RHF(mol)
    rhf.conv_tol = 1e-14
    rhf.scf()
    mcc = cc.CCSD(rhf)
    mcc.conv_tol = 1e-12
    mcc.ccsd()
    t1a = t1b = mcc.t1
    t2ab = mcc.t2
    t2aa = t2bb = t2ab - t2ab.transpose(1,0,2,3)
    mycc = cc.UCCSD(scf.addons.convert_to_uhf(rhf))
    eris = mycc.ao2mo()
    e3a = kernel(mycc, eris, (t1a,t1b), (t2aa,t2ab,t2bb))
    print(e3a - -0.00099642337843278096)

    # Case 2: spin-2 molecule with random orbital coefficients and random
    # amplitudes.
    mol = gto.Mole()
    mol.atom = _water_geometry()
    mol.spin = 2
    mol.basis = '3-21g'
    mol.build()
    mf = scf.UHF(mol).run(conv_tol=1e-14)
    nao, nmo = mf.mo_coeff[0].shape
    numpy.random.seed(10)
    mf.mo_coeff = numpy.random.random((2,nao,nmo))

    # the order of the random draws below determines the reference value
    numpy.random.seed(12)
    nocca, noccb = mol.nelec
    nmo = mf.mo_occ[0].size
    nvira = nmo - nocca
    nvirb = nmo - noccb
    t1a  = .1 * numpy.random.random((nocca,nvira))
    t1b  = .1 * numpy.random.random((noccb,nvirb))
    t2aa = _antisymmetrize(.1 * numpy.random.random((nocca,nocca,nvira,nvira)))
    t2bb = _antisymmetrize(.1 * numpy.random.random((noccb,noccb,nvirb,nvirb)))
    t2ab = .1 * numpy.random.random((nocca,noccb,nvira,nvirb))
    t1 = t1a, t1b
    t2 = t2aa, t2ab, t2bb
    mycc = cc.UCCSD(mf)
    eris = mycc.ao2mo(mf.mo_coeff)
    e3a = kernel(mycc, eris, [t1a,t1b], [t2aa, t2ab, t2bb])
    print(e3a - 9877.2780859693339)

    # Case 3: the same amplitudes through the spin-orbital GCCSD(T) code,
    # compared with the same reference value.
    mycc = cc.GCCSD(scf.addons.convert_to_ghf(mf))
    eris = mycc.ao2mo()
    t1 = mycc.spatial2spin(t1, eris.orbspin)
    t2 = mycc.spatial2spin(t2, eris.orbspin)
    from pyscf.cc import gccsd_t_slow
    et = gccsd_t_slow.kernel(mycc, eris, t1, t2)
    print(et - 9877.2780859693339)
416
417