1#!/usr/bin/env python
2# Copyright 2017-2021 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Authors: James D. McClain <jmcclain@princeton.edu>
17#
18"""Module for running restricted closed-shell k-point ccsd(t)"""
19
20import h5py
21import itertools
22import numpy as np
23import pyscf.pbc.cc.kccsd_rhf
24
25
26from itertools import product
27from pyscf import lib
28#from pyscf import _ccsd
29from pyscf.lib import logger
30from pyscf.lib.misc import tril_product
31from pyscf.lib.misc import flatten
32from pyscf.lib.numpy_helper import cartesian_prod
33from pyscf.lib.numpy_helper import pack_tril
34from pyscf.lib.parameters import LARGE_DENOM
35from pyscf.pbc.lib import kpts_helper
36from pyscf.pbc.mp.kmp2 import (get_frozen_mask, get_nocc, get_nmo,
37                               padded_mo_coeff, padding_k_idx)
38from pyscf import __config__
39
40#einsum = np.einsum
41einsum = lib.einsum
42
43# CCSD(T) equations taken from Scuseria, JCP (94), 1991
44#
45# NOTE: As pointed out in cc/ccsd_t_slow.py, there is an error in this paper
46#     and the equation should read [ia] >= [jb] >= [kc] (since the only
47#     symmetry in spin-less operators is the exchange of a column of excitation
#     operators).
def kernel(mycc, eris, t1=None, t2=None, max_memory=2000, verbose=logger.INFO):
    '''Returns the CCSD(T) for restricted closed-shell systems with k-points.

    Note:
        Returns the real part of the CCSD(T) energy. A warning is logged when
        the imaginary part exceeds 1e-4 in magnitude.

    Args:
        mycc (:class:`RCCSD`): Coupled-cluster object storing results of
            a coupled-cluster calculation.
        eris (:class:`_ERIS`): Integral object holding the relevant electron-
            repulsion integrals and Fock matrix elements
        t1 (:obj:`ndarray`): t1 coupled-cluster amplitudes; defaults to
            ``mycc.t1``. Its shape ``(nkpts, nocc, nvir)`` fixes the problem
            dimensions.
        t2 (:obj:`ndarray`): t2 coupled-cluster amplitudes; defaults to
            ``mycc.t2``.
        max_memory (float): Ignored. The virtual-orbital block size is
            derived from ``mycc.max_memory`` minus the memory currently in use.
        verbose (int, :class:`Logger`): verbosity of calculation

    Returns:
        energy_t (float): The real part of the k-point CCSD(T) energy per cell.

    Raises:
        TypeError: if ``eris`` is None, or if t1/t2 are neither passed in nor
            stored on ``mycc``.
    '''
    assert isinstance(mycc, pyscf.pbc.cc.kccsd_rhf.RCCSD)
    cpu1 = cpu0 = (logger.process_clock(), logger.perf_counter())
    if isinstance(verbose, logger.Logger):
        log = verbose
    else:
        log = logger.Logger(mycc.stdout, verbose)

    if t1 is None: t1 = mycc.t1
    if t2 is None: t2 = mycc.t2

    if eris is None:
        raise TypeError('Electron repulsion integrals, `eris`, must be passed in '
                        'to the CCSD(T) kernel or created in the cc object for '
                        'the k-point CCSD(T) to run!')
    if t1 is None or t2 is None:
        raise TypeError('Must pass in t1/t2 amplitudes to k-point CCSD(T)! (Maybe '
                        'need to run `.ccsd()` on the ccsd object?)')

    cell = mycc._scf.cell
    kpts = mycc.kpts

    # The dtype of any local arrays that will be created
    dtype = t1.dtype

    nkpts, nocc, nvir = t1.shape

    mo_energy_occ = [eris.mo_energy[ki][:nocc] for ki in range(nkpts)]
    mo_energy_vir = [eris.mo_energy[ki][nocc:] for ki in range(nkpts)]
    fov = eris.fock[:, :nocc, nocc:]

    # Set up class for k-point conservation
    kconserv = kpts_helper.get_kconserv(cell, kpts)

    cpu1 = log.timer_debug1('CCSD(T) tmp eri creation', *cpu1)

    def get_w(ki, kj, kk, ka, kb, kc, a0, a1, b0, b1, c0, c1, out=None):
        '''Wijkabc intermediate as described in Scuseria paper before Pijkabc acts.

        Returns an array indexed 'abcijk' covering virtual slices a0:a1,
        b0:b1, c0:c1 and all occupied orbitals. ``out`` is unused.
        '''
        km = kconserv[ki, ka, kj]
        kf = kconserv[kk, kc, kj]
        ret = einsum('kjcf,fiba->abcijk', t2[kk,kj,kc,:,:,c0:c1,:], eris.vovv[kf,ki,kb,:,:,b0:b1,a0:a1].conj())
        ret = ret - einsum('mkbc,jima->abcijk', t2[km,kk,kb,:,:,b0:b1,c0:c1], eris.ooov[kj,ki,km,:,:,:,a0:a1].conj())
        return ret

    def get_permuted_w(ki, kj, kk, ka, kb, kc, orb_indices):
        '''Pijkabc operating on Wijkabc intermediate as described in Scuseria paper.

        Sums the six simultaneous (ia),(jb),(kc) column permutations of W,
        transposing each term back to the 'abcijk' layout.
        '''
        a0, a1, b0, b1, c0, c1 = orb_indices
        out = get_w(ki, kj, kk, ka, kb, kc, a0, a1, b0, b1, c0, c1)
        out = out + get_w(kj, kk, ki, kb, kc, ka, b0, b1, c0, c1, a0, a1).transpose(2,0,1,5,3,4)
        out = out + get_w(kk, ki, kj, kc, ka, kb, c0, c1, a0, a1, b0, b1).transpose(1,2,0,4,5,3)
        out = out + get_w(ki, kk, kj, ka, kc, kb, a0, a1, c0, c1, b0, b1).transpose(0,2,1,3,5,4)
        out = out + get_w(kk, kj, ki, kc, kb, ka, c0, c1, b0, b1, a0, a1).transpose(2,1,0,5,4,3)
        out = out + get_w(kj, ki, kk, kb, ka, kc, b0, b1, a0, a1, c0, c1).transpose(1,0,2,4,3,5)
        return out

    def get_rw(ki, kj, kk, ka, kb, kc, orb_indices):
        '''R operating on Wijkabc intermediate as described in Scuseria paper.

        Weighted (4, 1, 1, -2, -2, -2) combination of P(W) with the occupied
        indices permuted.
        '''
        a0, a1, b0, b1, c0, c1 = orb_indices
        ret = (4. * get_permuted_w(ki,kj,kk,ka,kb,kc,orb_indices) +
               1. * get_permuted_w(kj,kk,ki,ka,kb,kc,orb_indices).transpose(0,1,2,5,3,4) +
               1. * get_permuted_w(kk,ki,kj,ka,kb,kc,orb_indices).transpose(0,1,2,4,5,3) -
               2. * get_permuted_w(ki,kk,kj,ka,kb,kc,orb_indices).transpose(0,1,2,3,5,4) -
               2. * get_permuted_w(kk,kj,ki,ka,kb,kc,orb_indices).transpose(0,1,2,5,4,3) -
               2. * get_permuted_w(kj,ki,kk,ka,kb,kc,orb_indices).transpose(0,1,2,4,3,5))
        return ret

    def get_v(ki, kj, kk, ka, kb, kc, a0, a1, b0, b1, c0, c1):
        '''Vijkabc intermediate as described in Scuseria paper.

        Non-zero only when kk == kc (t1 and fov are k-diagonal).
        '''
        out = np.zeros((a1-a0,b1-b0,c1-c0) + (nocc,)*3, dtype=dtype)
        if kk == kc:
            out = out + einsum('kc,ijab->abcijk', t1[kk,:,c0:c1], eris.oovv[ki,kj,ka,:,:,a0:a1,b0:b1].conj())
            out = out + einsum('kc,ijab->abcijk', fov[kk,:,c0:c1], t2[ki,kj,ka,:,:,a0:a1,b0:b1])
        return out

    def get_permuted_v(ki, kj, kk, ka, kb, kc, orb_indices):
        '''Pijkabc operating on Vijkabc intermediate as described in Scuseria paper'''
        a0, a1, b0, b1, c0, c1 = orb_indices
        ret = get_v(ki, kj, kk, ka, kb, kc, a0, a1, b0, b1, c0, c1)
        ret = ret + get_v(kj, kk, ki, kb, kc, ka, b0, b1, c0, c1, a0, a1).transpose(2,0,1,5,3,4)
        ret = ret + get_v(kk, ki, kj, kc, ka, kb, c0, c1, a0, a1, b0, b1).transpose(1,2,0,4,5,3)
        ret = ret + get_v(ki, kk, kj, ka, kc, kb, a0, a1, c0, c1, b0, b1).transpose(0,2,1,3,5,4)
        ret = ret + get_v(kk, kj, ki, kc, kb, ka, c0, c1, b0, b1, a0, a1).transpose(2,1,0,5,4,3)
        ret = ret + get_v(kj, ki, kk, kb, ka, kc, b0, b1, a0, a1, c0, c1).transpose(1,0,2,4,3,5)
        return ret

    energy_t = 0.0

    # Get location of padded elements in occupied and virtual space
    nonzero_opadding, nonzero_vpadding = padding_k_idx(mycc, kind="split")

    mem_now = lib.current_memory()[0]
    max_memory = max(0, mycc.max_memory - mem_now)
    blkmin = 4
    # temporary t3 array is size:    blksize**3 * nocc**3 * 16
    # (max(nocc, 1) avoids a ZeroDivisionError when there are no occupied orbitals)
    vir_blksize = min(nvir, max(blkmin, int((max_memory*.9e6/16/max(nocc, 1)**3)**(1./3))))
    log.debug('max_memory %d MB (%d MB in use)', max_memory, mem_now)
    log.debug('virtual blksize = %d (nvir = %d)', vir_blksize, nvir)
    # Every (a, b, c) combination of virtual-orbital blocks; each task is the
    # 6-tuple (a0, a1, b0, b1, c0, c1) of slice bounds.
    tasks = [(a0, a1, b0, b1, c0, c1)
             for a0, a1 in lib.prange(0, nvir, vir_blksize)
             for b0, b1 in lib.prange(0, nvir, vir_blksize)
             for c0, c1 in lib.prange(0, nvir, vir_blksize)]

    for ka in range(nkpts):
        for kb in range(ka+1):

            for ki, kj, kk in product(range(nkpts), repeat=3):
                # Find momentum conservation condition for triples
                # amplitude t3ijkabc
                kc = kpts_helper.get_kconserv3(cell, kpts, [ki, kj, kk, ka, kb])

                # Only the ordering ka >= kb >= kc is visited; the remaining
                # orderings are accounted for by the weight symm_kpt.
                if not (ka >= kb and kb >= kc):
                    continue

                if ka == kb and kb == kc:
                    symm_kpt = 1.
                elif ka == kb or kb == kc:
                    symm_kpt = 3.
                else:
                    symm_kpt = 6.

                # Denominator placeholders for padded orbitals (LARGE_DENOM is
                # assumed to be a large positive constant). Occupied entries use
                # +LARGE_DENOM and virtual entries -LARGE_DENOM. This keeps
                # |eijk - eabc| large whenever any index is padded. With the
                # same sign on both, a padded occupied triple combined with a
                # padded virtual triple would cancel to a zero denominator.
                # eigenvalue denominator: e(i) + e(j) + e(k)
                eijk = LARGE_DENOM * np.ones((nocc,)*3, dtype=mo_energy_occ[0].dtype)
                n0_ovp_ijk = np.ix_(nonzero_opadding[ki], nonzero_opadding[kj], nonzero_opadding[kk])
                eijk[n0_ovp_ijk] = lib.direct_sum('i,j,k->ijk', mo_energy_occ[ki],
                                                  mo_energy_occ[kj], mo_energy_occ[kk])[n0_ovp_ijk]

                eabc = -LARGE_DENOM * np.ones((nvir,)*3, dtype=mo_energy_occ[0].dtype)
                n0_ovp_abc = np.ix_(nonzero_vpadding[ka], nonzero_vpadding[kb], nonzero_vpadding[kc])
                eabc[n0_ovp_abc] = lib.direct_sum('i,j,k->ijk', mo_energy_vir[ka],
                                                  mo_energy_vir[kb], mo_energy_vir[kc])[n0_ovp_abc]
                for task in tasks:
                    # Unpack this task's slice bounds so the denominator matches
                    # the W/V block computed from the same task.
                    a0, a1, b0, b1, c0, c1 = task
                    eijkabc = (eijk[None,None,None,:,:,:] - eabc[a0:a1,b0:b1,c0:c1,None,None,None])
                    pwijk = (get_permuted_w(ki,kj,kk,ka,kb,kc,task) +
                             get_permuted_v(ki,kj,kk,ka,kb,kc,task) * 0.5)
                    rwijk = get_rw(ki,kj,kk,ka,kb,kc,task) / eijkabc
                    energy_t += symm_kpt * einsum('abcijk,abcijk', pwijk, rwijk.conj())

    energy_t *= (1. / 3)
    # Normalize to an energy per unit cell
    energy_t /= nkpts

    if abs(energy_t.imag) > 1e-4:
        log.warn('Non-zero imaginary part of CCSD(T) energy was found %s', energy_t.imag)
    log.timer('CCSD(T)', *cpu0)
    log.note('CCSD(T) correction per cell = %.15g', energy_t.real)
    log.note('CCSD(T) correction per cell (imag) = %.15g', energy_t.imag)
    return energy_t.real
215
if __name__ == '__main__':
    # Demo: k-point CCSD / CCSD(T) on a 1x1x4 k-mesh of diamond, compared with
    # Gamma-point CCSD / CCSD(T) on the equivalent 1x1x4 supercell (per cell).
    from pyscf.pbc import gto
    from pyscf.pbc import scf
    from pyscf.pbc import cc
    from pyscf.pbc.tools.pbc import super_cell

    # Tight convergence thresholds, set as attributes on the cell and on both
    # mean-field objects.
    tight_tols = (('conv_tol', 1e-12),
                  ('conv_tol_grad', 1e-12),
                  ('direct_scf_tol', 1e-16))

    cell = gto.Cell()
    cell.atom = '''
    C 0.000000000000   0.000000000000   0.000000000000
    C 1.685068664391   1.685068664391   1.685068664391
    '''
    cell.basis = 'gth-szv'
    cell.pseudo = 'gth-pade'
    cell.a = '''
    0.000000000, 3.370137329, 3.370137329
    3.370137329, 0.000000000, 3.370137329
    3.370137329, 3.370137329, 0.000000000'''
    for attr_name, attr_value in tight_tols:
        setattr(cell, attr_name, attr_value)
    cell.unit = 'B'
    cell.verbose = 4
    cell.mesh = [24, 24, 24]
    cell.build()

    nmp = [1, 1, 4]
    ncells = np.prod(nmp)

    # k-point calculation; shift the mesh so that its first point is Gamma.
    kpts = cell.make_kpts(nmp)
    kpts -= kpts[0]
    kmf = scf.KRHF(cell, kpts=kpts, exxdiv=None)
    for attr_name, attr_value in tight_tols:
        setattr(kmf, attr_name, attr_value)
    ehf = kmf.kernel()

    mycc = cc.KRCCSD(kmf)
    eris = mycc.ao2mo()
    ecc, t1, t2 = mycc.kernel(eris=eris)
    energy_t = kernel(mycc, eris=eris, verbose=9)

    # Equivalent Gamma-point supercell calculation.
    supcell = super_cell(cell, nmp)
    supcell.build()
    sup_mf = scf.RHF(supcell, exxdiv=None)
    for attr_name, attr_value in tight_tols:
        setattr(sup_mf, attr_name, attr_value)
    sup_ehf = sup_mf.kernel()

    myscc = cc.RCCSD(sup_mf)
    sup_eris = myscc.ao2mo()
    sup_ecc, t1, t2 = myscc.kernel(eris=sup_eris)
    sup_energy_t = myscc.ccsd_t(eris=sup_eris)

    print("Kpoint    CCSD: %20.16f" % ecc)
    print("Supercell CCSD: %20.16f" % (sup_ecc/ncells))
    print("Kpoint    CCSD(T): %20.16f" % energy_t)
    print("Supercell CCSD(T): %20.16f" % (sup_energy_t/ncells))
272