1#!/usr/bin/env python
2# Copyright 2014-2021 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Timothy Berkelbach <tim.berkelbach@gmail.com>
17#         James McClain <jdmcclain47@gmail.com>
18#
19
20
21'''
22kpoint-adapted and spin-adapted MP2
23t2[i,j,a,b] = <ij|ab> / D_ij^ab
24
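ERIs are never stored in full, only a partial block of size
(nkpts,nocc,nocc,nvir,nvir) for one (ki,kj) pair at a time. The t2
amplitudes are kept in full, with shape (nkpts,nkpts,nkpts,nocc,nocc,nvir,nvir),
only when with_t2=True.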
27'''
28
29import numpy as np
30from scipy.linalg import block_diag
31import h5py
32
33from pyscf import lib
34from pyscf.lib import logger, einsum
35from pyscf.mp import mp2
36from pyscf.pbc import df
37from pyscf.pbc.lib import kpts_helper
38from pyscf.lib.parameters import LARGE_DENOM
39from pyscf import __config__
40
# Default for the `with_t2` argument of kernel(): whether the full t2 amplitude
# tensor is kept. Read from the PySCF config key `mp_mp2_with_t2`, falling back to True.
WITH_T2 = getattr(__config__, 'mp_mp2_with_t2', True)
42
def kernel(mp, mo_energy, mo_coeff, verbose=logger.NOTE, with_t2=WITH_T2):
    """Computes k-point RMP2 energy.

    Args:
        mp (KMP2): an instance of KMP2
        mo_energy (list): a list of numpy.ndarray. Each array contains MO energies of
                          shape (Nmo,) for one kpt
        mo_coeff (list): a list of numpy.ndarray. Each array contains MO coefficients
                         of shape (Nao, Nmo) for one kpt
        verbose (int, optional): level of verbosity. Defaults to logger.NOTE (=3).
        with_t2 (bool, optional): whether to compute t2 amplitudes. Defaults to WITH_T2 (=True).

    Returns:
        KMP2 energy and t2 amplitudes (=None if with_t2 is False). When kept, t2 is a
        complex array of shape (nkpts, nkpts, nkpts, nocc, nocc, nvir, nvir) indexed
        as t2[ki, kj, ka]; kb follows from ``mp.khelper.kconserv[ki, ka, kj]``.

    Raises:
        MemoryError: if the estimated memory footprint exceeds the memory still
            available under ``mp.max_memory``.
    """
    cput0 = (logger.process_clock(), logger.perf_counter())
    log = logger.new_logger(mp, verbose)

    mp.dump_flags()
    nmo = mp.nmo
    nocc = mp.nocc
    nvir = nmo - nocc
    nkpts = mp.nkpts

    # DF integrals are used only when requested AND the SCF object carries a GDF object.
    with_df_ints = mp.with_df_ints and isinstance(mp._scf.with_df, df.GDF)

    # Estimated memory in MB (16 bytes per complex element): the oovv_ij buffer,
    # plus the 3-index DF tensor and/or the full t2 array when they are kept.
    mem_avail = mp.max_memory - lib.current_memory()[0]
    mem_usage = (nkpts * (nocc * nvir)**2) * 16 / 1e6
    if with_df_ints:
        naux = mp._scf.with_df.auxcell.nao_nr()
        mem_usage += (nkpts**2 * naux * nocc * nvir) * 16 / 1e6
    if with_t2:
        mem_usage += (nkpts**3 * (nocc * nvir)**2) * 16 / 1e6
    if mem_usage > mem_avail:
        raise MemoryError('Insufficient memory! MP2 memory usage %d MB (currently available %d MB)'
                          % (mem_usage, mem_avail))

    fao2mo = mp._scf.with_df.ao2mo
    kconserv = mp.khelper.kconserv
    emp2 = 0.
    # (ia|jb)/nkpts for the current (ki, kj) pair, indexed by ka and stored as [i, j, a, b]
    oovv_ij = np.zeros((nkpts,nocc,nocc,nvir,nvir), dtype=mo_coeff[0].dtype)

    mo_e_o = [mo_energy[k][:nocc] for k in range(nkpts)]
    mo_e_v = [mo_energy[k][nocc:] for k in range(nkpts)]

    # Get location of non-zero/padded elements in occupied and virtual space
    nonzero_opadding, nonzero_vpadding = padding_k_idx(mp, kind="split")

    def _masked_eia(ko, kv):
        # e_o - e_v for one (ko, kv) pair. Padded entries are set to LARGE_DENOM
        # (presumably a very large value) so the matching amplitudes are suppressed.
        eia = LARGE_DENOM * np.ones((nocc, nvir), dtype=mo_energy[0].dtype)
        n0_ovp = np.ix_(nonzero_opadding[ko], nonzero_vpadding[kv])
        eia[n0_ovp] = (mo_e_o[ko][:,None] - mo_e_v[kv])[n0_ovp]
        return eia

    # The one-body denominators depend only on (ko, kv): build each one once
    # instead of rebuilding it inside every (ki, kj, ka) iteration.
    eia_k = [[_masked_eia(ko, kv) for kv in range(nkpts)] for ko in range(nkpts)]

    if with_t2:
        t2 = np.zeros((nkpts, nkpts, nkpts, nocc, nocc, nvir, nvir), dtype=complex)
    else:
        t2 = None

    # Build 3-index DF tensor Lov
    if with_df_ints:
        Lov = _init_mp_df_eris(mp)

    for ki in range(nkpts):
        for kj in range(nkpts):
            # Fill (ia|jb) for every ka of this (ki, kj) pair first: the exchange
            # term below reads the block belonging to a different ka (namely kb).
            for ka in range(nkpts):
                kb = kconserv[ki,ka,kj]
                # (ia|jb)
                if with_df_ints:
                    oovv_ij[ka] = (1./nkpts) * einsum("Lia,Ljb->iajb", Lov[ki, ka], Lov[kj, kb]).transpose(0,2,1,3)
                else:
                    orbo_i = mo_coeff[ki][:,:nocc]
                    orbo_j = mo_coeff[kj][:,:nocc]
                    orbv_a = mo_coeff[ka][:,nocc:]
                    orbv_b = mo_coeff[kb][:,nocc:]
                    oovv_ij[ka] = fao2mo((orbo_i,orbv_a,orbo_j,orbv_b),
                                         (mp.kpts[ki],mp.kpts[ka],mp.kpts[kj],mp.kpts[kb]),
                                         compact=False).reshape(nocc,nvir,nocc,nvir).transpose(0,2,1,3) / nkpts
            for ka in range(nkpts):
                kb = kconserv[ki,ka,kj]

                eijab = lib.direct_sum('ia,jb->ijab', eia_k[ki][ka], eia_k[kj][kb])
                t2_ijab = np.conj(oovv_ij[ka]/eijab)
                if with_t2:
                    t2[ki, kj, ka] = t2_ijab
                # Spin-adapted weight: 2*(ia|jb) - (ib|ja)
                woovv = 2*oovv_ij[ka] - oovv_ij[kb].transpose(0,1,3,2)
                emp2 += einsum('ijab,ijab', t2_ijab, woovv).real

    log.timer("KMP2", *cput0)

    emp2 /= nkpts

    return emp2, t2
142
143
def _init_mp_df_eris(mp):
    """Compute 3-center electron repulsion integrals, i.e. (L|ov),
    where `L` denotes DF auxiliary basis functions and `o` and `v` occupied and virtual
    canonical crystalline orbitals. Note that `o` and `v` contain kpt indices `ko` and `kv`,
    and the third kpt index `kL` is determined by the conservation of momentum.

    Arguments:
        mp (KMP2) -- A KMP2 instance

    Returns:
        Lov (numpy.ndarray) -- object array of shape (nkpts, nkpts). Entry [ko, kv]
            is the DF tensor of shape (naux, nocc, nvir) for that k-point pair. Its
            leading dimension can be smaller than the auxiliary basis size if linear
            dependencies were removed from the auxiliary basis.

    Raises:
        NotImplementedError: for 2D cells (``cell.dimension == 2``).
    """
    from pyscf.pbc.df import df
    from pyscf.ao2mo import _ao2mo
    from pyscf.pbc.lib.kpts_helper import gamma_point

    log = logger.Logger(mp.stdout, mp.verbose)

    # Build the 3-index integrals on demand.
    if mp._scf.with_df._cderi is None:
        mp._scf.with_df.build()

    cell = mp._scf.cell
    if cell.dimension == 2:
        # 2D ERIs are not positive definite. The 3-index tensors are stored in
        # two parts. One corresponds to the positive part and one corresponds
        # to the negative part. The negative part is not considered in the
        # DF-driven MP2 implementation.
        raise NotImplementedError('DF-MP2 integrals are not implemented for 2D cells '
                                  '(cell.dimension == 2)')

    nocc = mp.nocc
    nmo = mp.nmo
    nvir = nmo - nocc
    nao = cell.nao_nr()

    mo_coeff = _add_padding(mp, mp.mo_coeff, mp.mo_energy)[0]
    kpts = mp.kpts
    nkpts = len(kpts)
    if gamma_point(kpts):
        dtype = np.double
    else:
        dtype = np.complex128
    dtype = np.result_type(dtype, *mo_coeff)
    Lov = np.empty((nkpts, nkpts), dtype=object)

    cput0 = (logger.process_clock(), logger.perf_counter())

    # `mo` below is [mo_coeff[ki] | mo_coeff[kj]]: occupied orbitals of ki are the
    # first nocc columns, virtual orbitals of kj sit after the first nmo + nocc columns.
    bra_start = 0
    bra_end = nocc
    ket_start = nmo+nocc
    ket_end = ket_start + nvir
    with h5py.File(mp._scf.with_df._cderi, 'r') as f:
        kptij_lst = f['j3c-kptij'][:]
        tao = []
        ao_loc = None
        for ki, kpti in enumerate(kpts):
            for kj, kptj in enumerate(kpts):
                kpti_kptj = np.array((kpti, kptj))
                Lpq_ao = np.asarray(df._getitem(f, 'j3c', kpti_kptj, kptij_lst))

                mo = np.hstack((mo_coeff[ki], mo_coeff[kj]))
                mo = np.asarray(mo, dtype=dtype, order='F')
                if dtype == np.double:
                    out = _ao2mo.nr_e2(Lpq_ao, mo, (bra_start, bra_end, ket_start, ket_end), aosym='s2')
                else:
                    #Note: Lpq.shape[0] != naux if linear dependency is found in auxbasis
                    if Lpq_ao[0].size != nao**2:  # aosym = 's2'
                        Lpq_ao = lib.unpack_tril(Lpq_ao).astype(np.complex128)
                    out = _ao2mo.r_e2(Lpq_ao, mo, (bra_start, bra_end, ket_start, ket_end), tao, ao_loc)
                Lov[ki, kj] = out.reshape(-1, nocc, nvir)

    log.timer_debug1("transforming DF-MP2 integrals", *cput0)

    return Lov
217
218
def _padding_k_idx(nmo, nocc, kind="split"):
    """Locate the physical (non-padding) orbitals inside padded per-k-point arrays.

    Padded arrays have max(nocc) occupied and max(nvir) virtual slots. At each
    k-point the real occupied orbitals occupy the lowest occupied slots, and the
    real virtual orbitals occupy the highest virtual slots. Any padding therefore
    sits next to the Fermi level.

    Args:
        nmo (Iterable): k-dependent orbital number;
        nocc (Iterable): k-dependent occupation numbers;
        kind (str): "split" for separate occupied/virtual index lists, or
            "joint" for a single list over the whole padded MO space.

    Returns:
        For kind="split", a tuple ``(occ_idx, vir_idx)`` of lists with one integer
        array per k-point. For kind="joint", one list with one integer array per
        k-point indexing the full padded MO range.

    Raises:
        ValueError: if `kind` is neither "split" nor "joint".
    """
    if kind not in ("split", "joint"):
        raise ValueError("The 'kind' argument must be one of 'split', 'joint'")

    occ_counts = np.array(nocc)
    mo_counts = np.array(nmo)
    vir_width = np.amax(mo_counts - occ_counts)
    total_width = np.amax(occ_counts) + vir_width

    # (n_occ, n_vir) for every k-point.
    per_k = [(n_o, n_m - n_o) for n_o, n_m in zip(occ_counts, mo_counts)]

    if kind == "joint":
        return [np.concatenate((np.arange(n_o), np.arange(total_width - n_v, total_width)))
                for n_o, n_v in per_k]

    occ_idx = [np.arange(n_o) for n_o, _ in per_k]
    vir_idx = [np.arange(vir_width - n_v, vir_width) for _, n_v in per_k]
    return occ_idx, vir_idx
265
266
def padding_k_idx(mp, kind="split"):
    """A convention used for padding vectors, matrices and tensors in case when occupation numbers depend on the
    k-point index.

    This implementation stores k-dependent Fock and other matrix in dense arrays with additional dimensions
    corresponding to k-point indexes. In case when the occupation numbers depend on the k-point index (i.e. a metal) or
    when some k-points have more Bloch basis functions than others the corresponding data structure has to be padded
    with entries that are not used (fictitious occupied and virtual degrees of freedom). Current convention stores these
    states at the Fermi level as shown in the following example.

    +----+--------+--------+--------+
    |    |  k=0   |  k=1   |  k=2   |
    |    +--------+--------+--------+
    |    | nocc=2 | nocc=3 | nocc=2 |
    |    | nvir=4 | nvir=3 | nvir=3 |
    +====+========+========+========+
    | v3 |  k0v3  |  k1v2  |  k2v2  |
    +----+--------+--------+--------+
    | v2 |  k0v2  |  k1v1  |  k2v1  |
    +----+--------+--------+--------+
    | v1 |  k0v1  |  k1v0  |  k2v0  |
    +----+--------+--------+--------+
    | v0 |  k0v0  |        |        |
    +====+========+========+========+
    |          Fermi level          |
    +====+========+========+========+
    | o2 |        |  k1o2  |        |
    +----+--------+--------+--------+
    | o1 |  k0o1  |  k1o1  |  k2o1  |
    +----+--------+--------+--------+
    | o0 |  k0o0  |  k1o0  |  k2o0  |
    +----+--------+--------+--------+

    In the above example, `get_nmo(mp, per_kpoint=True) == (6, 6, 5)`, `get_nocc(mp, per_kpoint=True) == (2, 3, 2)`.
    The resulting dense `get_nmo(mp) == 7` and `get_nocc(mp) == 3` correspond to padded dimensions. This function will
    return the following indexes corresponding to the filled entries of the above table (shown as tuples for brevity;
    the actual entries are integer numpy arrays):

    >>> padding_k_idx(mp, kind="split")
    ([(0, 1), (0, 1, 2), (0, 1)], [(0, 1, 2, 3), (1, 2, 3), (1, 2, 3)])

    >>> padding_k_idx(mp, kind="joint")
    [(0, 1, 3, 4, 5, 6), (0, 1, 2, 4, 5, 6), (0, 1, 4, 5, 6)]

    Args:
        mp (:class:`MP2`): An instantiation of an SCF or post-Hartree-Fock object.
        kind (str): either "split" (occupied and virtual spaces are indexed separately) or "joint" (occupied and
            virtual spaces are indexed together as one MO space).

    Returns:
        Two lists corresponding to the occupied and virtual spaces for kind="split". Each list contains integer arrays
        with indexes pointing to actual non-zero entries in the padded vector/matrix/tensor. If kind="joint", a single
        list of arrays is returned corresponding to the entire MO space.

    Raises:
        ValueError: if `kind` is neither "split" nor "joint".
    """
    return _padding_k_idx(mp.get_nmo(per_kpoint=True), mp.get_nocc(per_kpoint=True), kind=kind)
321
322
def padded_mo_energy(mp, mo_energy):
    """
    Scatter the active (non-frozen) MO energies into a zero-padded dense array.

    Args:
        mp (:class:`MP2`): An instantiation of an SCF or post-Hartree-Fock object.
        mo_energy (ndarray): original non-padded molecular energies, one array per k-point;

    Returns:
        Array of shape (mp.nkpts, mp.nmo) with the dtype of ``mo_energy[0]``. Padding
        slots (see :func:`padding_k_idx`) are left at zero.
    """
    active = get_frozen_mask(mp)
    slots = padding_k_idx(mp, kind="joint")
    n_k = mp.nkpts

    padded = np.zeros((n_k, mp.nmo), dtype=mo_energy[0].dtype)
    for k in range(n_k):
        padded[k, slots[k]] = mo_energy[k][active[k]]

    return padded
343
344
def padded_mo_coeff(mp, mo_coeff):
    """
    Scatter the active (non-frozen) MO coefficient columns into a zero-padded dense array.

    Args:
        mp (:class:`MP2`): An instantiation of an SCF or post-Hartree-Fock object.
        mo_coeff (ndarray): original non-padded molecular coefficients, one (nao, nmo_k) array per k-point;

    Returns:
        Array of shape (mp.nkpts, nao, mp.nmo) with the dtype of ``mo_coeff[0]``.
        Padding columns (see :func:`padding_k_idx`) are left at zero.
    """
    active = get_frozen_mask(mp)
    slots = padding_k_idx(mp, kind="joint")
    n_k = mp.nkpts

    nao = mo_coeff[0].shape[0]
    padded = np.zeros((n_k, nao, mp.nmo), dtype=mo_coeff[0].dtype)
    for k in range(n_k):
        # padded[k] is a view, so the column scatter writes into `padded` in place.
        padded[k][:, slots[k]] = mo_coeff[k][:, active[k]]

    return padded
365
366
def _frozen_sanity_check(frozen, mo_occ, kpt_idx):
    '''Performs a few sanity checks on the frozen array and mo_occ.

    Checks that `mo_occ` has at least one occupied orbital, that `frozen` has no
    duplicate indices, and that no index in `frozen` exceeds the last orbital
    index of `mo_occ`.

    Args:
        frozen (array_like of int): The orbital indices that will be frozen.
        mo_occ (:obj:`ndarray` of int): The occupation number for each orbital
            resulting from a mean-field-like calculation.
        kpt_idx (int): The k-point that `mo_occ` and `frozen` belong to.

    Raises:
        AssertionError: if `mo_occ` has no occupied (> 0) orbitals.
        RuntimeError: if `frozen` contains duplicates or an index past the end of `mo_occ`.
    '''
    frozen = np.array(frozen)
    nocc = np.count_nonzero(mo_occ > 0)

    # Explicit raise instead of `assert` so the check survives `python -O`; the
    # AssertionError type is kept for compatibility with the previous assert.
    if not nocc:
        raise AssertionError('No occupied orbitals?\n\nnocc = %s\nmo_occ = %s' % (nocc, mo_occ))
    if len(frozen) != len(np.unique(frozen)):
        raise RuntimeError('Frozen orbital list contains duplicates!\n\nkpt_idx %s\n'
                           'frozen %s' % (kpt_idx, frozen))
    if len(frozen) > 0 and np.max(frozen) > len(mo_occ) - 1:
        raise RuntimeError('Freezing orbital not in MO list!\n\nkpt_idx %s\n'
                           'frozen %s\nmax orbital idx %s' % (kpt_idx, frozen, len(mo_occ) - 1))
390
391
def get_nocc(mp, per_kpoint=False):
    '''Number of occupied orbitals for k-point calculations.

    Number of occupied orbitals for use in a calculation with k-points, taking into
    account frozen orbitals.

    Args:
        mp (:class:`MP2`): An instantiation of an SCF or post-Hartree-Fock object.
        per_kpoint (bool, optional): True returns the number of occupied
            orbitals at each k-point.  False gives the max of this list.

    Returns:
        nocc (int, list of int): Number of occupied orbitals. For return type, see description of arg
            `per_kpoint`. If ``mp._nocc`` is set, it is returned unchanged regardless
            of `per_kpoint`.

    Raises:
        RuntimeError: on fractional occupation numbers, or when a per-k-point frozen
            list does not have one entry per k-point.
        NotImplementedError: for an unsupported `mp.frozen` specification.
    '''
    for i, moocc in enumerate(mp.mo_occ):
        if np.any(moocc % 1 != 0):
            raise RuntimeError("Fractional occupation numbers encountered @ kp={:d}: {}. This may have been caused by "
                               "smearing of occupation numbers in the mean-field calculation. If so, consider "
                               "executing mf.smearing_method = False; mf.mo_occ = mf.get_occ() prior to calling "
                               "this".format(i, moocc))

    def _active_nocc(mo_occ, frozen):
        # Occupied orbitals at one k-point minus the frozen orbitals whose index
        # is at or below the highest occupied orbital index.
        max_occ_idx = np.max(np.where(mo_occ > 0))
        frozen_nocc = np.sum(np.array(frozen) <= max_occ_idx)
        return np.count_nonzero(mo_occ) - frozen_nocc

    if mp._nocc is not None:
        return mp._nocc
    elif mp.frozen is None:
        nocc = [np.count_nonzero(mp.mo_occ[ikpt]) for ikpt in range(mp.nkpts)]
    elif isinstance(mp.frozen, (int, np.integer)):
        # An integer freezes that many of the lowest orbitals at every k-point.
        nocc = [(np.count_nonzero(mp.mo_occ[ikpt]) - mp.frozen) for ikpt in range(mp.nkpts)]
    elif isinstance(mp.frozen[0], (int, np.integer)):
        # One list of orbital indices shared by all k-points.
        for ikpt in range(mp.nkpts):
            _frozen_sanity_check(mp.frozen, mp.mo_occ[ikpt], ikpt)
        nocc = [_active_nocc(mp.mo_occ[ikpt], mp.frozen) for ikpt in range(mp.nkpts)]
    elif isinstance(mp.frozen[0], (list, np.ndarray)):
        # A separate list of orbital indices per k-point.
        nkpts = len(mp.frozen)
        if nkpts != mp.nkpts:
            raise RuntimeError('Frozen list has a different number of k-points (length) than passed in mean-field/'
                               'correlated calculation.  \n\nCalculation nkpts = %d, frozen list = %s '
                               '(length = %d)' % (mp.nkpts, mp.frozen, nkpts))
        for ikpt, (frozen, mo_occ) in enumerate(zip(mp.frozen, mp.mo_occ)):
            _frozen_sanity_check(frozen, mo_occ, ikpt)

        nocc = [_active_nocc(mp.mo_occ[ikpt], frozen) for ikpt, frozen in enumerate(mp.frozen)]
    else:
        raise NotImplementedError

    assert any(np.array(nocc) > 0), ('Must have occupied orbitals! \n\nnocc %s\nfrozen %s\nmo_occ %s' %
           (nocc, mp.frozen, mp.mo_occ))

    if not per_kpoint:
        nocc = np.amax(nocc)

    return nocc
450
451
def get_nmo(mp, per_kpoint=False):
    '''Number of orbitals for k-point calculations.

    Number of orbitals for use in a calculation with k-points, taking into account
    frozen orbitals.

    Note:
        If `per_kpoint` is False, then the number of orbitals here is equal to max(nocc) + max(nvir),
        where each max is done over all k-points.  Otherwise the number of orbitals is returned
        as a list of number of orbitals at each k-point.

    Args:
        mp (:class:`MP2`): An instantiation of an SCF or post-Hartree-Fock object.
        per_kpoint (bool, optional): True returns the number of orbitals at each k-point.
            For a description of False, see Note.

    Returns:
        nmo (int, list of int): Number of orbitals. For return type, see description of arg
            `per_kpoint`. If ``mp._nmo`` is set, it is returned unchanged regardless of
            `per_kpoint`.

    Raises:
        RuntimeError: when a per-k-point frozen list does not have one entry per k-point.
        NotImplementedError: for an unsupported `mp.frozen` specification.
    '''
    if mp._nmo is not None:
        return mp._nmo

    if mp.frozen is None:
        nmo = [len(mp.mo_occ[ikpt]) for ikpt in range(mp.nkpts)]
    elif isinstance(mp.frozen, (int, np.integer)):
        nmo = [len(mp.mo_occ[ikpt]) - mp.frozen for ikpt in range(mp.nkpts)]
    elif isinstance(mp.frozen[0], (int, np.integer)):
        for ikpt in range(mp.nkpts):
            _frozen_sanity_check(mp.frozen, mp.mo_occ[ikpt], ikpt)
        nmo = [len(mp.mo_occ[ikpt]) - len(mp.frozen) for ikpt in range(mp.nkpts)]
    elif isinstance(mp.frozen[0], (list, np.ndarray)):
        # Per-k-point frozen lists. Test the first element, as get_nocc and
        # get_frozen_mask do, so any sequence of per-k lists (e.g. a tuple) works.
        nkpts = len(mp.frozen)
        if nkpts != mp.nkpts:
            raise RuntimeError('Frozen list has a different number of k-points (length) than passed in mean-field/'
                               'correlated calculation.  \n\nCalculation nkpts = %d, frozen list = %s '
                               '(length = %d)' % (mp.nkpts, mp.frozen, nkpts))
        for ikpt, (fro, mo_occ) in enumerate(zip(mp.frozen, mp.mo_occ)):
            _frozen_sanity_check(fro, mo_occ, ikpt)

        nmo = [len(mp.mo_occ[ikpt]) - len(mp.frozen[ikpt]) for ikpt in range(nkpts)]
    else:
        raise NotImplementedError

    assert all(np.array(nmo) > 0), ('Must have a positive number of orbitals!\n\nnmo %s\nfrozen %s\nmo_occ %s' %
           (nmo, mp.frozen, mp.mo_occ))

    if not per_kpoint:
        # Depending on whether there are more occupied bands, we want to make sure that
        # nmo has enough room for max(nocc) + max(nvir) number of orbitals for occupied
        # and virtual space
        nocc = mp.get_nocc(per_kpoint=True)
        nmo = np.max(nocc) + np.max(np.array(nmo) - np.array(nocc))

    return nmo
506
507
def get_frozen_mask(mp):
    '''Boolean mask for orbitals in k-point post-HF method.

    Creates a boolean mask to remove frozen orbitals and keep other orbitals for post-HF
    calculations.

    Args:
        mp (:class:`MP2`): An instantiation of an SCF or post-Hartree-Fock object.

    Returns:
        moidx (list of :obj:`ndarray` of bool): Boolean mask of orbitals to include,
            one array per entry of ``mp.mo_occ``.

    Raises:
        RuntimeError: when a per-k-point frozen list does not have one entry per k-point.
        NotImplementedError: for an unsupported `mp.frozen` specification.
    '''
    # Builtin `bool` rather than the `np.bool` alias, which was removed in NumPy 1.24.
    moidx = [np.ones(x.size, dtype=bool) for x in mp.mo_occ]
    if mp.frozen is None:
        pass
    elif isinstance(mp.frozen, (int, np.integer)):
        # An integer freezes that many of the lowest orbitals at every k-point.
        for idx in moidx:
            idx[:mp.frozen] = False
    elif isinstance(mp.frozen[0], (int, np.integer)):
        # One list of orbital indices shared by all k-points.
        frozen = list(mp.frozen)
        for idx in moidx:
            idx[frozen] = False
    elif isinstance(mp.frozen[0], (list, np.ndarray)):
        # A separate list of orbital indices per k-point.
        nkpts = len(mp.frozen)
        if nkpts != mp.nkpts:
            raise RuntimeError('Frozen list has a different number of k-points (length) than passed in mean-field/'
                               'correlated calculation.  \n\nCalculation nkpts = %d, frozen list = %s '
                               '(length = %d)' % (mp.nkpts, mp.frozen, nkpts))
        for ikpt, (fro, mo_occ) in enumerate(zip(mp.frozen, mp.mo_occ)):
            _frozen_sanity_check(fro, mo_occ, ikpt)
        for ikpt, kpt_occ in enumerate(moidx):
            kpt_occ[mp.frozen[ikpt]] = False
    else:
        raise NotImplementedError

    return moidx
544
545
def _add_padding(mp, mo_coeff, mo_energy):
    """Return MO coefficients and energies padded to the dense ``mp.nmo`` size.

    Each input is padded (via padded_mo_coeff / padded_mo_energy) only if at
    least one of its per-k-point entries does not already have ``mp.nmo``
    orbitals; otherwise the input object is returned unchanged.

    Returns:
        (mo_coeff, mo_energy) tuple.
    """
    dense_nmo = mp.nmo

    coeff_widths_ok = [c.shape[1] == dense_nmo for c in mo_coeff]
    if not all(coeff_widths_ok):
        mo_coeff = padded_mo_coeff(mp, mo_coeff)

    energy_lengths_ok = [e.shape[0] == dense_nmo for e in mo_energy]
    if not all(energy_lengths_ok):
        mo_energy = padded_mo_energy(mp, mo_energy)

    return mo_coeff, mo_energy
556
557
def make_rdm1(mp, t2=None, kind="compact"):
    r"""
    Spin-traced one-particle density matrix in the MO basis representation.
    The occupied-virtual orbital response is not included.

    dm1[p,q] = <q_alpha^\dagger p_alpha> + <q_beta^\dagger p_beta>

    The convention of 1-pdm is based on McWeeney's book, Eq (5.4.20).
    The contraction between 1-particle Hamiltonian and rdm1 is
    E = einsum('pq,qp', h1, rdm1)

    Args:
        mp (KMP2): a KMP2 kernel object;
        t2 (ndarray): a t2 MP2 tensor (``mp.t2`` is used when None);
        kind (str): 'padded' keeps the dense (nmo, nmo) matrix per k-point, while
            'compact' slices it down to the non-padded orbitals of that k-point;

    Returns:
        A list with one single-particle density matrix per k-point.

    Raises:
        ValueError: if `kind` is not 'compact' or 'padded'.
    """
    if kind not in ("compact", "padded"):
        raise ValueError("The 'kind' argument should be either 'compact' or 'padded'")
    occ_blocks, vir_blocks = _gamma1_intermediates(mp, t2=t2)
    joint_idx = padding_k_idx(mp, kind="joint")

    dms = []
    for occ_blk, vir_blk, keep in zip(occ_blocks, vir_blocks, joint_idx):
        # Add the reference occupation; symmetrizing below doubles it to 2 per orbital.
        occ_blk = occ_blk + np.eye(*occ_blk.shape)
        full = block_diag(occ_blk, vir_blk)
        full = full + full.conj().T
        dms.append(full if kind == "padded" else full[np.ix_(keep, keep)])
    return dms
591
592
def make_rdm2(mp, t2=None, kind="compact"):
    r'''
    Spin-traced two-particle density matrix in MO basis

    .. math::

        dm2[p,q,r,s] = \sum_{\sigma,\tau} <p_\sigma^\dagger r_\tau^\dagger s_\tau q_\sigma>

    Note the contraction between ERIs (in Chemist's notation) and rdm2 is
    E = einsum('pqrs,pqrs', eri, rdm2)

    Args:
        mp (KMP2): a KMP2 object; ``mp.t2`` is used when `t2` is None.
        t2 (ndarray): MP2 amplitudes indexed as t2[ki, kj, ka] (kb fixed by
            ``mp.khelper.kconserv``).
        kind (str): 'padded' returns the dense array of shape
            (nkpts, nkpts, nkpts, nmo, nmo, nmo, nmo) indexed by [kp, kq, kr];
            'compact' returns a flat list, ordered by (kp, kq, kr), of blocks
            restricted to the non-padded orbitals of kp, kq, kr and
            ks = kconserv[kp, kq, kr].

    Raises:
        ValueError: if `kind` is not 'compact' or 'padded'.
    '''
    if kind not in ("compact", "padded"):
        raise ValueError("The 'kind' argument should be either 'compact' or 'padded'")
    if t2 is None: t2 = mp.t2
    # Padded 1-RDM per k-point; its occupied diagonal is reduced by 2 further below.
    dm1 = mp.make_rdm1(t2, "padded")
    nmo = mp.nmo
    nocc = mp.nocc
    nkpts = mp.nkpts
    dtype = t2.dtype

    dm2 = np.zeros((nkpts,nkpts,nkpts,nmo,nmo,nmo,nmo),dtype=dtype)
    # Amplitude (ov|ov) part: dovov is ordered [i, a, j, b]; its conjugate transpose
    # fills the matching (vo|vo) block.
    for ki in range(nkpts):
        for kj in range(nkpts):
            for ka in range(nkpts):
                kb = mp.khelper.kconserv[ki, ka, kj]
                dovov = t2[ki, kj, ka].transpose(0,2,1,3) * 2 - t2[kj, ki, ka].transpose(1,2,0,3)
                dovov *= 2
                dm2[ki,ka,kj,:nocc,nocc:,:nocc,nocc:] = dovov
                dm2[ka,ki,kb,nocc:,:nocc,nocc:,:nocc] = dovov.transpose(1,0,3,2).conj()

    # Indices of the non-padded occupied orbitals at each k-point.
    occidx = padding_k_idx(mp, kind="split")[0]
    # Subtract 2 from each physical occupied diagonal element of dm1.
    for ki in range(nkpts):
        for i in occidx[ki]:
            dm1[ki][i,i] -= 2

    # Products of the occupied reference with the dm1 correction: terms with
    # pair (i,i) in one index pair (weight 2) and exchange-like terms (weight -1).
    for ki in range(nkpts):
        for kp in range(nkpts):
            for i in occidx[ki]:
                dm2[ki,ki,kp,i,i,:,:] += dm1[kp].T * 2
                dm2[kp,kp,ki,:,:,i,i] += dm1[kp].T * 2
                dm2[kp,ki,ki,:,i,i,:] -= dm1[kp].T
                dm2[ki,kp,kp,i,:,:,i] -= dm1[kp]

    # Occupied-occupied reference terms: +4 for (ii|jj), -2 for (ij|ji).
    for ki in range(nkpts):
        for kj in range(nkpts):
            for i in occidx[ki]:
                for j in occidx[kj]:
                    dm2[ki,ki,kj,i,i,j,j] += 4
                    dm2[ki,kj,kj,i,j,j,i] -= 2

    if kind == "padded":
        return dm2
    else:
        # Drop padding orbitals; ks is the fourth k-point fixed by kconserv.
        idx = padding_k_idx(mp, kind="joint")
        result = []
        for kp in range(nkpts):
            for kq in range(nkpts):
                for kr in range(nkpts):
                    ks = mp.khelper.kconserv[kp, kq, kr]
                    result.append(dm2[kp,kq,kr][np.ix_(idx[kp],idx[kq],idx[kr],idx[ks])])
        return result
654
655
def _gamma1_intermediates(mp, t2=None):
    '''Occupied-occupied and virtual-virtual 1-RDM intermediates from MP2 amplitudes.

    Args:
        mp (KMP2): provides nmo, nocc, nkpts and khelper.kconserv; ``mp.t2`` is
            used when `t2` is None.
        t2 (ndarray): amplitudes indexed as t2[ki][kj][ka].

    Returns:
        (-dm1occ, dm1vir): arrays of shapes (nkpts, nocc, nocc) and
        (nkpts, nvir, nvir) with the dtype of `t2`. Note that the occupied block
        is returned with a negative sign.

    Raises:
        NotImplementedError: if no t2 amplitudes are available (e.g. the kernel
            was run with ``with_t2=False``).
    '''
    # Memory optimization should be here
    if t2 is None:
        t2 = mp.t2
    if t2 is None:
        raise NotImplementedError("Run kmp2.kernel with `with_t2=True`")
    nmo = mp.nmo
    nocc = mp.nocc
    nvir = nmo - nocc
    nkpts = mp.nkpts
    dtype = t2.dtype

    dm1occ = np.zeros((nkpts, nocc, nocc), dtype=dtype)
    dm1vir = np.zeros((nkpts, nvir, nvir), dtype=dtype)

    for ki in range(nkpts):
        for kj in range(nkpts):
            for ka in range(nkpts):
                kb = mp.khelper.kconserv[ki, ka, kj]

                # Virtual block (accumulated at kb): 2 * direct term minus the term
                # using the amplitude block with a and b k-points swapped.
                dm1vir[kb] += einsum('ijax,ijay->yx', t2[ki][kj][ka].conj(), t2[ki][kj][ka]) * 2 -\
                              einsum('ijax,ijya->yx', t2[ki][kj][ka].conj(), t2[ki][kj][kb])
                # Occupied block (accumulated at kj), same direct/swapped structure.
                dm1occ[kj] += einsum('ixab,iyab->xy', t2[ki][kj][ka].conj(), t2[ki][kj][ka]) * 2 -\
                              einsum('ixab,iyba->xy', t2[ki][kj][ka].conj(), t2[ki][kj][kb])
    return -dm1occ, dm1vir
681
682
class KMP2(mp2.MP2):
    '''Restricted k-point MP2 built on top of a k-point SCF object.

    The orbitals (mo_coeff, mo_energy, mo_occ) and k-points are taken from the
    SCF object unless given explicitly. Results (e_corr, e_hf, t2) are filled in
    by :meth:`kernel`.
    '''
    def __init__(self, mf, frozen=None, mo_coeff=None, mo_occ=None):

        if mo_coeff is None: mo_coeff = mf.mo_coeff
        if mo_occ is None: mo_occ = mf.mo_occ

        self.mol = mf.mol
        self._scf = mf
        self.verbose = self.mol.verbose
        self.stdout = self.mol.stdout
        self.max_memory = mf.max_memory

        self.frozen = frozen
        # Use 3-index DF integrals in kernel() only when the SCF object uses GDF.
        self.with_df_ints = isinstance(self._scf.with_df, df.GDF)

##################################################
# don't modify the following attributes, they are not input options
        self.kpts = mf.kpts
        self.mo_energy = mf.mo_energy
        self.nkpts = len(self.kpts)
        self.khelper = kpts_helper.KptsHelper(mf.cell, mf.kpts)
        self.mo_coeff = mo_coeff
        self.mo_occ = mo_occ
        self._nocc = None
        self._nmo = None
        self.e_corr = None
        self.e_hf = None
        self.t2 = None
        # Snapshot of all attribute names set above; keep this the last statement
        # of __init__ so it records every attribute (presumably used by pyscf's
        # attribute checks).
        self._keys = set(self.__dict__.keys())

    get_nocc = get_nocc
    get_nmo = get_nmo
    get_frozen_mask = get_frozen_mask
    make_rdm1 = make_rdm1
    make_rdm2 = make_rdm2

    def dump_flags(self):
        '''Log the main settings of this KMP2 object and return self.'''
        logger.info(self, "")
        logger.info(self, "******** %s ********", self.__class__)
        logger.info(self, "nkpts = %d", self.nkpts)
        logger.info(self, "nocc = %d", self.nocc)
        logger.info(self, "nmo = %d", self.nmo)
        logger.info(self, "with_df_ints = %s", self.with_df_ints)

        if self.frozen is not None:
            logger.info(self, "frozen orbitals = %s", self.frozen)
        logger.info(
            self,
            "max_memory %d MB (current use %d MB)",
            self.max_memory,
            lib.current_memory()[0],
        )
        return self

    def kernel(self, mo_energy=None, mo_coeff=None, with_t2=WITH_T2):
        '''Run the KMP2 calculation and store the results on the instance.

        Args:
            mo_energy: per-k-point MO energies; defaults to ``self.mo_energy``.
            mo_coeff: per-k-point MO coefficients; defaults to ``self.mo_coeff``.
            with_t2 (bool): whether to keep the full t2 amplitudes.

        Returns:
            (e_corr, t2), also stored as ``self.e_corr`` and ``self.t2``.

        Raises:
            RuntimeError: if MO energies or coefficients are not available.
        '''
        if mo_energy is None:
            mo_energy = self.mo_energy
        if mo_coeff is None:
            mo_coeff = self.mo_coeff
        if mo_energy is None or mo_coeff is None:
            log = logger.Logger(self.stdout, self.verbose)
            log.warn('mo_coeff, mo_energy are not given.\n'
                     'You may need to call mf.kernel() to generate them.')
            raise RuntimeError('KMP2.kernel: mo_coeff and/or mo_energy are not available; '
                               'run the mean-field calculation (mf.kernel()) first.')

        # Bring the orbitals to the dense padded layout expected by kernel().
        mo_coeff, mo_energy = _add_padding(self, mo_coeff, mo_energy)

        # TODO: compute e_hf for non-canonical SCF
        self.e_hf = self._scf.e_tot

        self.e_corr, self.t2 = \
                kernel(self, mo_energy, mo_coeff, verbose=self.verbose, with_t2=with_t2)
        logger.log(self, 'KMP2 energy = %.15g', self.e_corr)
        return self.e_corr, self.t2
# Alias; this is the name registered as KRHF.MP2 below.
KRMP2 = KMP2


# Attach MP2 to the k-point mean-field classes: KRHF gets this implementation,
# while the MP2 attribute of KGHF and KROHF is set to None.
from pyscf.pbc import scf
scf.khf.KRHF.MP2 = lib.class_as_method(KRMP2)
scf.kghf.KGHF.MP2 = None
scf.krohf.KROHF.MP2 = None
767
768
if __name__ == '__main__':
    from pyscf.pbc import gto, scf, mp

    # Reference KMP2 correlation energy; the script prints the deviation from it.
    E_CORR_REF = -0.204721432828996

    # Two-carbon cell in the gth-szv basis with GTH pseudopotentials, lengths in Bohr.
    cell = gto.Cell()
    cell.atom = '''
    C 0.000000000000   0.000000000000   0.000000000000
    C 1.685068664391   1.685068664391   1.685068664391
    '''
    cell.basis = 'gth-szv'
    cell.pseudo = 'gth-pade'
    cell.a = '''
    0.000000000, 3.370137329, 3.370137329
    3.370137329, 0.000000000, 3.370137329
    3.370137329, 3.370137329, 0.000000000'''
    cell.unit = 'B'
    cell.verbose = 5
    cell.build()

    # HF followed by MP2 on a 1x1x2 Monkhorst-Pack k-point mesh.
    mesh = cell.make_kpts([1, 1, 2])
    mean_field = scf.KRHF(cell, kpts=mesh, exxdiv=None)
    mean_field.kernel()

    solver = mp.KMP2(mean_field)
    e_corr, _ = solver.kernel()
    print(e_corr - E_CORR_REF)
794
795