1#!/usr/bin/env python
2# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Qiming Sun <osirpt.sun@gmail.com>
17#
18
19import numpy
20from pyscf import lib
21from pyscf.ao2mo import _ao2mo
22from pyscf.ao2mo.incore import _conc_mos
23from pyscf.pbc.df.fft_ao2mo import _format_kpts
24from pyscf.pbc.df import df_ao2mo
25from pyscf.pbc.df import aft_ao2mo
26from pyscf.pbc.lib import kpts_helper
27from pyscf.pbc.lib.kpts_helper import gamma_point, unique
28from pyscf import __config__
29
30
def get_eri(mydf, kpts=None,
            compact=getattr(__config__, 'pbc_df_ao2mo_get_eri_compact', True)):
    '''Return the ERI array as the sum of the ``aft_ao2mo`` and ``df_ao2mo``
    contributions.

    Args:
        mydf : density-fitting object.  ``mydf.build()`` is invoked first
            when ``mydf._cderi`` is still None.
        kpts : k-point specification; it is normalised by ``_format_kpts``
            before being handed to the two backends (presumably a set of
            four k-points -- confirm against ``fft_ao2mo._format_kpts``).
        compact : forwarded unchanged to both backends.  The default is read
            once, at definition time, from
            ``__config__.pbc_df_ao2mo_get_eri_compact`` (True when unset).

    Returns:
        The array produced by ``aft_ao2mo.get_eri`` after the result of
        ``df_ao2mo.get_eri`` has been added to it in place.
    '''
    if mydf._cderi is None:
        mydf.build()

    formatted_kpts = _format_kpts(kpts)

    # The first contribution doubles as the accumulator: the second one is
    # added in place, so no third array is allocated for the sum.
    total = aft_ao2mo.get_eri(mydf, formatted_kpts, compact=compact)
    total += df_ao2mo.get_eri(mydf, formatted_kpts, compact=compact)
    return total
40
41
def general(mydf, mo_coeffs, kpts=None,
            compact=getattr(__config__, 'pbc_df_ao2mo_general_compact', True)):
    '''Return the MO-transformed integrals as the sum of the ``aft_ao2mo``
    and ``df_ao2mo`` contributions.

    Args:
        mydf : density-fitting object.  ``mydf.build()`` is invoked first
            when ``mydf._cderi`` is still None.
        mo_coeffs : either a single 2-d ``numpy.ndarray``, which is reused
            for all four orbital indices, or any other object, which is
            passed to the backends as given (presumably a sequence of four
            coefficient matrices -- confirm against the backends).
        kpts : k-point specification; normalised by ``_format_kpts``.
        compact : forwarded unchanged to both backends.  The default is read
            once, at definition time, from
            ``__config__.pbc_df_ao2mo_general_compact`` (True when unset).

    Returns:
        The array produced by ``aft_ao2mo.general`` after the result of
        ``df_ao2mo.general`` has been added to it in place.
    '''
    if mydf._cderi is None:
        mydf.build()

    formatted_kpts = _format_kpts(kpts)

    # A lone coefficient matrix stands for all four orbital sets.
    single_matrix = isinstance(mo_coeffs, numpy.ndarray) and mo_coeffs.ndim == 2
    if single_matrix:
        mo_coeffs = (mo_coeffs,) * 4

    # Accumulate the second contribution in place on top of the first.
    total = aft_ao2mo.general(mydf, mo_coeffs, formatted_kpts, compact=compact)
    total += df_ao2mo.general(mydf, mo_coeffs, formatted_kpts, compact=compact)
    return total
53
def ao2mo_7d(mydf, mo_coeff_kpts, kpts=None, factor=1, out=None):
    '''Build the MO-basis ERI blocks for every k-point triple (ki, kj, kk).

    The fourth k-point index ``kl`` of each block is looked up in the table
    returned by ``kpts_helper.get_kconserv(cell, kpts)``.  Every block
    ``out[ki, kj, kk]`` is assembled from two contributions:

    1. the product, contracted over the ``prod(mydf.mesh)`` grid points, of
       the MO-transformed arrays yielded by ``mydf.ft_loop`` for the (i, j)
       and (k, l) pairs; the (i, j) factor is scaled by
       ``mydf.weighted_coulG(q, False, mydf.mesh) * factor``;
    2. the product of the MO-transformed arrays yielded by ``mydf.sr_loop``
       for the (i, j) and (k, l) pairs, scaled by ``sign * factor`` where
       ``sign`` is the third item yielded by ``sr_loop``.

    The intermediates of contribution 1 are staged in datasets of a
    ``lib.H5TmpFile()`` object and removed after each group of k-point pairs.

    Args:
        mydf : density-fitting object.  The attributes ``cell``, ``kpts``,
            ``mesh``, ``max_memory``, ``blockdim`` and the methods
            ``weighted_coulG``, ``ft_loop``, ``sr_loop`` are used.
        mo_coeff_kpts : either one 3-d ``numpy.ndarray`` (reused for all four
            orbital indices) or a sequence of four objects, each indexable by
            k-point and exposing the number of orbitals as ``shape[2]``.
        kpts : k-points; defaults to ``mydf.kpts``.  It is indexed with
            integer arrays below, so it has to be a ``numpy.ndarray``.
        factor : scalar applied to both contributions.
        out : optional preallocated output.  Its shape must be
            ``(nkpts, nkpts, nkpts, nmoi, nmoj, nmok, nmol)``.

    Returns:
        ``out`` (newly allocated when not supplied).  A new array has the
        ``numpy.result_type`` of the coefficients when ``gamma_point(kpts)``
        is true and is ``complex128`` otherwise.
    '''
    cell = mydf.cell
    if kpts is None:
        kpts = mydf.kpts
    nkpts = len(kpts)

    if isinstance(mo_coeff_kpts, numpy.ndarray) and mo_coeff_kpts.ndim == 3:
        mo_coeff_kpts = [mo_coeff_kpts] * 4
    else:
        mo_coeff_kpts = list(mo_coeff_kpts)

    # Shape of the orbitals can be different on different k-points. The
    # orbital coefficients must be formatted (padded by zeros) so that the
    # shape of the orbital coefficients are the same on all k-points. This can
    # be achieved by calling pbc.mp.kmp2.padded_mo_coeff function
    nmoi, nmoj, nmok, nmol = [x.shape[2] for x in mo_coeff_kpts]
    eri_shape = (nkpts, nkpts, nkpts, nmoi, nmoj, nmok, nmol)
    if gamma_point(kpts):
        dtype = numpy.result_type(*mo_coeff_kpts)
    else:
        dtype = numpy.complex128

    if out is None:
        out = numpy.empty(eri_shape, dtype=dtype)
    else:
        assert(out.shape == eri_shape)

    # Enumerate all (ki, kj) pairs in row-major order, so that a flat pair
    # index decodes as ki = idx // nkpts, kj = idx % nkpts, and group the
    # pairs by their difference kj - ki.
    kptij_lst = numpy.array([(ki, kj) for ki in kpts for kj in kpts])
    kptis_lst = kptij_lst[:,0]
    kptjs_lst = kptij_lst[:,1]
    kpt_ji = kptjs_lst - kptis_lst
    uniq_kpts, _, uniq_inverse = unique(kpt_ji)
    ngrids = numpy.prod(mydf.mesh)
    nao = cell.nao_nr()
    max_memory = max(2000, mydf.max_memory-lib.current_memory()[0]-nao**4*16/1e6) * .5

    fswap = lib.H5TmpFile()
    tao = []
    ao_loc = None
    kconserv = kpts_helper.get_kconserv(cell, kpts)
    for uniq_id, q in enumerate(uniq_kpts):
        # Flat indices of all (ki, kj) pairs whose difference kj - ki is q.
        adapted_ji_idx = numpy.where(uniq_inverse == uniq_id)[0]

        kptjs = kptjs_lst[adapted_ji_idx]
        coulG = mydf.weighted_coulG(q, False, mydf.mesh)
        coulG *= factor

        # Concatenated MO coefficients and slices for every (i, j) pair of
        # the group; they are reused by both contributions below.
        moij_list = []
        ijslice_list = []
        for ji, ji_idx in enumerate(adapted_ji_idx):
            ki = ji_idx // nkpts
            kj = ji_idx % nkpts
            moij, ijslice = _conc_mos(mo_coeff_kpts[0][ki], mo_coeff_kpts[1][kj])[2:]
            moij_list.append(moij)
            ijslice_list.append(ijslice)
            fswap.create_dataset('zij/'+str(ji), (ngrids,nmoi*nmoj), 'D')

        # (i, j) factor of contribution 1, written grid slice by grid slice.
        for aoaoks, p0, p1 in mydf.ft_loop(mydf.mesh, q, kptjs,
                                           max_memory=max_memory):
            for ji, aoao in enumerate(aoaoks):
                buf = aoao.transpose(1,2,0).reshape(nao**2,p1-p0)
                zij = _ao2mo.r_e2(lib.transpose(buf), moij_list[ji],
                                  ijslice_list[ji], tao, ao_loc)
                zij *= coulG[p0:p1,None]
                fswap['zij/'+str(ji)][p0:p1] = zij

        # The (k, l) data are built once per group and shared by all of its
        # (ki, kj) pairs, so one representative pair determines kl for each
        # kk.  Choose the first pair explicitly instead of depending on the
        # ki/kj values left over from the loops above.
        # NOTE(review): this assumes kconserv[ki, kj, kk] depends on (ki, kj)
        # only through kj - ki, which the sharing across the group already
        # presupposes -- confirm against kpts_helper.get_kconserv.
        ki = adapted_ji_idx[0] // nkpts
        kj = adapted_ji_idx[0] % nkpts
        kptls = kpts[kconserv[ki, kj, :]]

        mokl_list = []
        klslice_list = []
        for kk in range(nkpts):
            kl = kconserv[ki, kj, kk]
            mokl, klslice = _conc_mos(mo_coeff_kpts[2][kk], mo_coeff_kpts[3][kl])[2:]
            mokl_list.append(mokl)
            klslice_list.append(klslice)
            fswap.create_dataset('zkl/'+str(kk), (ngrids,nmok*nmol), 'D')

        # (k, l) factor of contribution 1: ft_loop is called with -kptls and
        # its output is complex-conjugated before the MO transformation.
        for aoaoks, p0, p1 in mydf.ft_loop(mydf.mesh, q, -kptls,
                                           max_memory=max_memory):
            for kk, aoao in enumerate(aoaoks):
                buf = aoao.conj().transpose(1,2,0).reshape(nao**2,p1-p0)
                zkl = _ao2mo.r_e2(lib.transpose(buf), mokl_list[kk],
                                  klslice_list[kk], tao, ao_loc)
                fswap['zkl/'+str(kk)][p0:p1] = zkl

        for ji, ji_idx in enumerate(adapted_ji_idx):
            ki = ji_idx // nkpts
            kj = ji_idx % nkpts

            # (i, j) factor of contribution 2, one entry per sr_loop block.
            # The coefficients cached in moij_list / ijslice_list are the
            # ones _conc_mos returned for this very pair, so they are reused.
            zij = []
            for LpqR, LpqI, sign in mydf.sr_loop(kpts[[ki,kj]], max_memory, False, mydf.blockdim):
                zij.append(_ao2mo.r_e2(LpqR+LpqI*1j, moij_list[ji],
                                       ijslice_list[ji], tao, ao_loc))

            for kk in range(nkpts):
                kl = kconserv[ki, kj, kk]
                # Contribution 1: contract the staged factors over the grid.
                eri_mo = lib.dot(numpy.asarray(fswap['zij/'+str(ji)]).T,
                                 numpy.asarray(fswap['zkl/'+str(kk)]))

                # Contribution 2: eri_mo is passed as the output argument
                # with scaling sign*factor for the product and 1 for eri_mo
                # itself.  The i-th (k, l) block is paired with zij[i], which
                # relies on both sr_loop calls yielding blocks in the same
                # order and number.
                for i, (LrsR, LrsI, sign) in \
                        enumerate(mydf.sr_loop(kpts[[kk,kl]], max_memory, False, mydf.blockdim)):
                    zkl = _ao2mo.r_e2(LrsR+LrsI*1j, mokl_list[kk],
                                      klslice_list[kk], tao, ao_loc)
                    lib.dot(zij[i].T, zkl, sign*factor, eri_mo, 1)

                if dtype == numpy.double:
                    eri_mo = eri_mo.real
                out[ki,kj,kk] = eri_mo.reshape(eri_shape[3:])

        # Drop the staged intermediates so the dataset names can be reused
        # by the next group.
        del(fswap['zij'])
        del(fswap['zkl'])

    return out
170
171