1#!/usr/bin/env python
2# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Qiming Sun <osirpt.sun@gmail.com>
17#
18
19'''
20Integral transformation with analytic Fourier transformation
21'''
22
23import numpy
24from pyscf import lib
25from pyscf import ao2mo
26from pyscf.ao2mo import _ao2mo
27from pyscf.ao2mo.incore import iden_coeffs, _conc_mos
28from pyscf.pbc.df.df_jk import zdotNC
29from pyscf.pbc.df.fft_ao2mo import _format_kpts, _iskconserv
30from pyscf.pbc.df.df_ao2mo import _mo_as_complex, _dtrans, _ztrans
31from pyscf.pbc.df.df_ao2mo import warn_pbc2d_eri
32from pyscf.pbc.lib import kpts_helper
33from pyscf.pbc.lib.kpts_helper import is_zero, gamma_point, unique
34from pyscf import __config__
35
36
def get_eri(mydf, kpts=None,
            compact=getattr(__config__, 'pbc_df_ao2mo_get_eri_compact', True)):
    '''AO-basis electron repulsion integrals (pq|rs) built from analytically
    Fourier transformed AO pair densities.

    Args:
        mydf : AFT-style DF object.  Its cell, mesh and max_memory attributes
            are read and its weighted_coulG / pw_loop methods are called.
        kpts : k-points of the four AOs, normalized by _format_kpts.
        compact : only used at the gamma point.  True keeps lower-triangular
            pair storage (s4 symmetry); False unpacks the result with
            ao2mo.restore(1, ...) into a (nao**2, nao**2) matrix.

    Returns:
        Gamma point: real array of shape (nao_pair, nao_pair) or
        (nao**2, nao**2), depending on ``compact``.
        Other k-points: complex array of shape (nao**2, nao**2).
        If the k-points violate momentum conservation, a warning is logged
        and a real zero array of shape (nao, nao, nao, nao) is returned.
    '''
    cell = mydf.cell
    nao = cell.nao_nr()
    kptijkl = _format_kpts(kpts)
    if not _iskconserv(cell, kptijkl):
        lib.logger.warn(cell, 'aft_ao2mo: momentum conservation not found in '
                        'the given k-points %s', kptijkl)
        return numpy.zeros((nao,nao,nao,nao))

    kpti, kptj, kptk, kptl = kptijkl
    q = kptj - kpti
    mesh = mydf.mesh
    coulG = mydf.weighted_coulG(q, False, mesh)
    mem_avail = max(2000, (mydf.max_memory - lib.current_memory()[0]) * .8)
    ij_kpts = kptijkl[:2]

    if gamma_point(kptijkl):
        # Real integrals.  Each AO pair index runs over the lower triangle
        # (aosym='s2'), which gives the s4-packed (pq|rs) matrix.
        npair = nao * (nao+1) // 2
        eri = numpy.zeros((npair, npair))
        pair_blocks = mydf.pw_loop(mesh, ij_kpts, q, max_memory=mem_avail,
                                   aosym='s2')
        for rho_re, rho_im, g0, g1 in pair_blocks:
            weight = coulG[g0:g1]
            # eri += (rho_re*w) . rho_re^T + (rho_im*w) . rho_im^T
            lib.ddot(rho_re*weight, rho_re.T, 1, eri, 1)
            lib.ddot(rho_im*weight, rho_im.T, 1, eri, 1)
            del rho_re, rho_im
        if compact:
            return eri
        return ao2mo.restore(1, eri, nao).reshape(nao**2,-1)

    if is_zero(kpti-kptl) and is_zero(kptj-kptk):
        # k_i == k_l and k_j == k_k: the ket pair density is obtained from
        # the bra one, so a single pw_loop pass suffices.  Result is complex
        # with N^4 elements.
        eri_re = numpy.zeros((nao**2,nao**2))
        eri_im = numpy.zeros((nao**2,nao**2))
        for rho_re, rho_im, g0, g1 in mydf.pw_loop(mesh, ij_kpts, q,
                                                    max_memory=mem_avail):
            weight = coulG[g0:g1]
            # rho_pq(G+k_pq) * conj(rho_rs(G-k_rs)), accumulated in place
            zdotNC(rho_re*weight, rho_im*weight, rho_re.T, rho_im.T,
                   1, eri_re, eri_im, 1)
            del rho_re, rho_im
        # rho_rs(-G+k_rs) = conj(transpose(rho_sr(G+k_sr))), so the last two
        # AO indices come out swapped and are transposed back here.
        eri = lib.transpose((eri_re+eri_im*1j).reshape(-1,nao,nao),
                            axes=(0,2,1))
        return eri.reshape(nao**2,-1)

    # General k-points: complex integrals without permutational symmetry.
    #   (pq|rs) = \sum_G 4\pi rho_pq rho_rs / |G+k_{pq}|^2
    # rho_pq comes directly from pw_loop(k_q, G+k_{pq}).  Assuming the AO
    # functions r(r) and s(r) are real, rho_rs = conj(pw_loop(-k_s, G+k_{pq})).
    # TODO: for complex AO functions, pw_loop needs the Gv vector as an
    # argument.
    eri_re = numpy.zeros((nao**2,nao**2))
    eri_im = numpy.zeros((nao**2,nao**2))
    half_mem = mem_avail * .5
    bra_blocks = mydf.pw_loop(mesh, ij_kpts, q, max_memory=half_mem)
    ket_blocks = mydf.pw_loop(mesh, -kptijkl[2:], q, max_memory=half_mem)
    # NOTE(review): assumes both generators yield identical G blocks (same
    # mesh, q and memory budget); the ket block bounds are not checked.
    for (pq_re, pq_im, g0, g1), (rs_re, rs_im, _, _) in lib.izip(bra_blocks,
                                                                 ket_blocks):
        pq_re *= coulG[g0:g1]
        pq_im *= coulG[g0:g1]
        zdotNC(pq_re, pq_im, rs_re.T, rs_im.T, 1, eri_re, eri_im, 1)
    return eri_re + eri_im*1j
124
125
def general(mydf, mo_coeffs, kpts=None,
            compact=getattr(__config__, 'pbc_df_ao2mo_general_compact', True)):
    '''MO-basis electron repulsion integrals (ij|kl) from analytically
    Fourier transformed AO pair densities.

    Args:
        mydf : AFT-style DF object.  Its cell, mesh and max_memory attributes
            are read and its weighted_coulG / pw_loop methods are called.
        mo_coeffs : a single 2D ndarray (used for all four indices) or a
            sequence of four (nao, nmo) coefficient arrays for i, j, k, l.
        kpts : k-points of the four orbitals, normalized by _format_kpts.
        compact : forwarded to _conc_mos in the gamma-point branch only;
            presumably selects packed pair storage when the orbital pairs
            allow it (see _conc_mos).

    Returns:
        2D array of shape (nij_pair, nkl_pair).  It is real when the
        k-points are the gamma point and all orbitals are real, and
        complex128 otherwise.  If the k-points violate momentum
        conservation, a warning is logged and a zero array of shape
        [mo.shape[1] for mo in mo_coeffs] (4D) is returned instead.
    '''
    warn_pbc2d_eri(mydf)
    cell = mydf.cell
    kptijkl = _format_kpts(kpts)
    kpti, kptj, kptk, kptl = kptijkl
    if isinstance(mo_coeffs, numpy.ndarray) and mo_coeffs.ndim == 2:
        mo_coeffs = (mo_coeffs,) * 4
    if not _iskconserv(cell, kptijkl):
        lib.logger.warn(cell, 'aft_ao2mo: momentum conservation not found in '
                        'the given k-points %s', kptijkl)
        return numpy.zeros([mo.shape[1] for mo in mo_coeffs])

    q = kptj - kpti
    mesh = mydf.mesh
    coulG = mydf.weighted_coulG(q, False, mesh)
    all_real = not any(numpy.iscomplexobj(mo) for mo in mo_coeffs)
    max_memory = max(2000, (mydf.max_memory - lib.current_memory()[0]) * .5)

####################
# gamma point, the integral is real and with s4 symmetry
    if gamma_point(kptijkl) and all_real:
        ijmosym, nij_pair, moij, ijslice = _conc_mos(mo_coeffs[0], mo_coeffs[1], compact)
        klmosym, nkl_pair, mokl, klslice = _conc_mos(mo_coeffs[2], mo_coeffs[3], compact)
        eri_mo = numpy.zeros((nij_pair,nkl_pair))
        # When bra and ket use the same orbitals, _dtrans can reuse the bra
        # result for the ket.
        sym = (iden_coeffs(mo_coeffs[0], mo_coeffs[2]) and
               iden_coeffs(mo_coeffs[1], mo_coeffs[3]))

        # Work buffers, reallocated on the first pass and reused afterwards
        # through the out= arguments.
        ijR = ijI = klR = klI = buf = None
        for pqkR, pqkI, p0, p1 \
                in mydf.pw_loop(mesh, kptijkl[:2], q, max_memory=max_memory,
                                aosym='s2'):
            # Real part: eri_mo += ijR^T . (klR * coulG)
            buf = lib.transpose(pqkR, out=buf)
            ijR, klR = _dtrans(buf, ijR, ijmosym, moij, ijslice,
                               buf, klR, klmosym, mokl, klslice, sym)
            lib.ddot(ijR.T, klR*coulG[p0:p1,None], 1, eri_mo, 1)
            # Imaginary part, accumulated into the same real matrix
            buf = lib.transpose(pqkI, out=buf)
            ijI, klI = _dtrans(buf, ijI, ijmosym, moij, ijslice,
                               buf, klI, klmosym, mokl, klslice, sym)
            lib.ddot(ijI.T, klI*coulG[p0:p1,None], 1, eri_mo, 1)
            pqkR = pqkI = None
        return eri_mo

####################
# (kpt) i == j == k == l != 0
# (kpt) i == l && j == k && i != j && j != k  =>
#
    elif is_zero(kpti-kptl) and is_zero(kptj-kptk):
        mo_coeffs = _mo_as_complex(mo_coeffs)
        nij_pair, moij, ijslice = _conc_mos(mo_coeffs[0], mo_coeffs[1])[1:]
        # The ket pair is built in (l, k) order so it can be produced from
        # the same pair density as the bra; it is transposed back below.
        nlk_pair, molk, lkslice = _conc_mos(mo_coeffs[3], mo_coeffs[2])[1:]
        eri_mo = numpy.zeros((nij_pair,nlk_pair), dtype=numpy.complex128)
        sym = (iden_coeffs(mo_coeffs[0], mo_coeffs[3]) and
               iden_coeffs(mo_coeffs[1], mo_coeffs[2]))

        zij = zlk = buf = None
        for pqkR, pqkI, p0, p1 \
                in mydf.pw_loop(mesh, kptijkl[:2], q, max_memory=max_memory):
            buf = lib.transpose(pqkR+pqkI*1j, out=buf)
            zij, zlk = _ztrans(buf, zij, moij, ijslice,
                               buf, zlk, molk, lkslice, sym)
            # eri_mo += zij^T . conj(zlk) * coulG
            lib.dot(zij.T, zlk.conj()*coulG[p0:p1,None], 1, eri_mo, 1)
            pqkR = pqkI = None
        nmok = mo_coeffs[2].shape[1]
        nmol = mo_coeffs[3].shape[1]
        # (ij, l, k) -> (ij, k, l)
        eri_mo = lib.transpose(eri_mo.reshape(-1,nmol,nmok), axes=(0,2,1))
        return eri_mo.reshape(nij_pair,nlk_pair)

####################
# aosym = s1, complex integrals
#
# If kpti == kptj, (kptl-kptk)*a has to be multiples of 2pi because of the wave
# vector symmetry.  k is a fraction of reciprocal basis, 0 < k/b < 1, by definition.
# So  kptl/b - kptk/b  must be -1 < k/b < 1.  =>  kptl == kptk
#
    else:
        mo_coeffs = _mo_as_complex(mo_coeffs)
        nij_pair, moij, ijslice = _conc_mos(mo_coeffs[0], mo_coeffs[1])[1:]
        nkl_pair, mokl, klslice = _conc_mos(mo_coeffs[2], mo_coeffs[3])[1:]
        eri_mo = numpy.zeros((nij_pair,nkl_pair), dtype=numpy.complex128)

        tao = []
        ao_loc = None
        zij = zkl = buf = None
        # Bra and ket pair densities come from two pw_loop generators that
        # are advanced together; the ket block bounds (q0, q1) are unused.
        for (pqkR, pqkI, p0, p1), (rskR, rskI, q0, q1) in \
                lib.izip(mydf.pw_loop(mesh, kptijkl[:2], q, max_memory=max_memory*.5),
                         mydf.pw_loop(mesh,-kptijkl[2:], q, max_memory=max_memory*.5)):
            buf = lib.transpose(pqkR+pqkI*1j, out=buf)
            zij = _ao2mo.r_e2(buf, moij, ijslice, tao, ao_loc, out=zij)
            # The ket pair density enters complex conjugated.
            buf = lib.transpose(rskR-rskI*1j, out=buf)
            zkl = _ao2mo.r_e2(buf, mokl, klslice, tao, ao_loc, out=zkl)
            zij *= coulG[p0:p1,None]
            lib.dot(zij.T, zkl, 1, eri_mo, 1)
            pqkR = pqkI = rskR = rskI = None
        return eri_mo
221
222
def get_ao_pairs_G(mydf, kpts=numpy.zeros((2,3)), q=None, shls_slice=None,
                   compact=getattr(__config__, 'pbc_df_ao_pairs_compact', False)):
    '''Calculate forward Fourier transform (G|ij) of all AO pairs.

    Args:
        mydf : AFT-style DF object.  Its cell, mesh and max_memory attributes
            are read and its pw_loop method is called.
        kpts : (2,3) array
            k-points of the i and j AOs.  None means the gamma point.
        q : ignored
            Accepted for interface compatibility.  The momentum transfer is
            always recomputed as kpts[1] - kpts[0].
        shls_slice : (ish0, ish1, jsh0, jsh1), optional
            Shell ranges of the i and j AOs.  Defaults to all shells.
        compact : bool
            Request lower-triangular (s2) pair storage.  It is only honoured
            at the gamma point and when the i and j AO ranges are identical.

    Returns:
        ao_pairs_G : 2D complex array
            With compact storage the shape is
            (ngrids, i1*(i1+1)//2 - i0*(i0+1)//2), i.e.
            (ngrids, nao*(nao+1)/2) for the full AO range.  Otherwise the
            shape is (ngrids, (i1-i0)*(j1-j0)), i.e. (ngrids, nao*nao).
    '''
    if kpts is None: kpts = numpy.zeros((2,3))
    cell = mydf.cell
    kpts = numpy.asarray(kpts)
    q = kpts[1] - kpts[0]
    # One G vector per FFT mesh point.  This is the same count as
    # len(cell.gen_uniform_grids(mesh)), without building the coordinates.
    ngrids = int(numpy.prod(mydf.mesh))
    max_memory = max(2000, (mydf.max_memory - lib.current_memory()[0]) * .5)

    if shls_slice is None:
        shls_slice = (0, cell.nbas, 0, cell.nbas)
    ish0, ish1, jsh0, jsh1 = shls_slice
    ao_loc = cell.ao_loc_nr()
    i0, i1 = ao_loc[ish0], ao_loc[ish1]
    j0, j1 = ao_loc[jsh0], ao_loc[jsh1]
    compact = compact and (i0 == j0) and (i1 == j1)

    if compact and gamma_point(kpts):
        aosym = 's2'
        # s2 yields only lower-triangular pairs, so the column count is a
        # difference of triangular numbers, not (i1-i0)*(j1-j0).
        # NOTE(review): follows the packed-pair convention of pw_loop/ft_aopair
        # for s2 (rows i0..i1, all columns q <= p); for the full AO range this
        # is nao*(nao+1)/2.
        nao_pair = i1*(i1+1)//2 - i0*(i0+1)//2
    else:
        aosym = 's1'
        nao_pair = (i1-i0)*(j1-j0)

    ao_pairs_G = numpy.empty((ngrids,nao_pair), dtype=numpy.complex128)
    for pqkR, pqkI, p0, p1 \
            in mydf.pw_loop(mydf.mesh, kpts, q, shls_slice,
                            max_memory=max_memory, aosym=aosym):
        ao_pairs_G[p0:p1] = pqkR.T + pqkI.T * 1j
    return ao_pairs_G
261
def get_mo_pairs_G(mydf, mo_coeffs, kpts=numpy.zeros((2,3)), q=None,
                   compact=getattr(__config__, 'pbc_df_mo_pairs_compact', False)):
    '''Calculate forward Fourier transform (G|ij) of all MO pairs.

    (G|ij) = sum_pq conj(C_pi) (G|pq) C_qj.  This follows the convention used
    elsewhere in this module, where the bra orbital is complex-conjugated.

    Args:
        mydf : AFT-style DF object.  Its cell, mesh and max_memory attributes
            are read and its pw_loop method is called.
        mo_coeffs: length-2 list of (nao,nmo) ndarrays
            The two sets of MO coefficients to use in calculating the
            product |ij).
        kpts : (2,3) array
            k-points of the i and j orbitals.  None means the gamma point.
        q : ignored
            The momentum transfer is always recomputed as kpts[1] - kpts[0].
        compact : ignored
            The full nmoi*nmoj pair list is always returned.

    Returns:
        mo_pairs_G : (ngrids, nmoi*nmoj) complex ndarray
            The Fourier transformed MO pair densities.
    '''
    if kpts is None: kpts = numpy.zeros((2,3))
    cell = mydf.cell
    kpts = numpy.asarray(kpts)
    q = kpts[1] - kpts[0]
    mo_i, mo_j = mo_coeffs[0], mo_coeffs[1]
    nmoi = mo_i.shape[1]
    nmoj = mo_j.shape[1]
    # One G vector per FFT mesh point (same as len(gen_uniform_grids(mesh))).
    ngrids = int(numpy.prod(mydf.mesh))
    max_memory = max(2000, (mydf.max_memory - lib.current_memory()[0]) * .5)
    nao = cell.nao
    mo_i_conj = mo_i.conj()

    # (G|pq) == (G|qp) only holds at the gamma point, so packed s2 pairs can
    # be used there only.  Elsewhere the full s1 pair list is requested.
    use_s2 = gamma_point(kpts)
    aosym = 's2' if use_s2 else 's1'
    if use_s2:
        # Row-major lower-triangle indices, matching the packed pair order.
        tril_row, tril_col = numpy.tril_indices(nao)

    mo_pairs_G = numpy.empty((ngrids,nmoi,nmoj), dtype=numpy.complex128)
    for pqkR, pqkI, p0, p1 \
            in mydf.pw_loop(mydf.mesh, kpts, q,
                            max_memory=max_memory, aosym=aosym):
        pqk = pqkR + pqkI*1j
        if use_s2:
            # Unpack symmetrically, not Hermitian-conjugated:
            # (G|pq) = (G|qp) at the gamma point.
            packed = pqk
            pqk = numpy.empty((nao,nao,packed.shape[1]), dtype=packed.dtype)
            pqk[tril_row,tril_col] = packed
            pqk[tril_col,tril_row] = packed
        else:
            pqk = pqk.reshape(nao,nao,-1)
        mo_pairs_G[p0:p1] = lib.einsum('pqk,pi,qj->kij', pqk, mo_i_conj, mo_j)
    return mo_pairs_G.reshape(ngrids,nmoi*nmoj)
293
def ao2mo_7d(mydf, mo_coeff_kpts, kpts=None, factor=1, out=None):
    '''MO integrals (i_ki j_kj | k_kk l_kl) for every k-point triple.

    Args:
        mydf : AFT-style DF object.  Its cell, kpts, mesh and max_memory
            attributes are read and its weighted_coulG / ft_loop methods are
            called.
        mo_coeff_kpts : one (nkpts, nao, nmo) array used for all four
            indices, or a sequence of four such arrays for i, j, k, l.
            Orbital shapes must be identical on all k-points (pad with zeros,
            e.g. via pbc.mp.kmp2.padded_mo_coeff).
        kpts : (nkpts, 3) array.  Defaults to mydf.kpts.
        factor : scalar multiplied into the weighted Coulomb kernel.
        out : optional preallocated output array with the shape below.

    Returns:
        Array of shape (nkpts, nkpts, nkpts, nmoi, nmoj, nmok, nmol), indexed
        as out[ki, kj, kk].  The fourth k-point is kl = kconserv[ki, kj, kk].
        At the gamma point the dtype is numpy.result_type of the orbitals
        (the result is truncated to its real part if that is float64).
        Otherwise the dtype is complex128.
    '''
    cell = mydf.cell
    if kpts is None:
        kpts = mydf.kpts
    nkpts = len(kpts)

    if isinstance(mo_coeff_kpts, numpy.ndarray) and mo_coeff_kpts.ndim == 3:
        mo_coeff_kpts = [mo_coeff_kpts] * 4
    else:
        mo_coeff_kpts = list(mo_coeff_kpts)

    # Shape of the orbitals can be different on different k-points. The
    # orbital coefficients must be formatted (padded by zeros) so that the
    # shape of the orbital coefficients are the same on all k-points. This can
    # be achieved by calling pbc.mp.kmp2.padded_mo_coeff function
    nmoi, nmoj, nmok, nmol = [x.shape[2] for x in mo_coeff_kpts]
    eri_shape = (nkpts, nkpts, nkpts, nmoi, nmoj, nmok, nmol)
    if gamma_point(kpts):
        dtype = numpy.result_type(*mo_coeff_kpts)
    else:
        dtype = numpy.complex128

    if out is None:
        out = numpy.empty(eri_shape, dtype=dtype)
    else:
        assert(out.shape == eri_shape)

    kptij_lst = numpy.array([(ki, kj) for ki in kpts for kj in kpts])
    kptis_lst = kptij_lst[:,0]
    kptjs_lst = kptij_lst[:,1]
    kpt_ji = kptjs_lst - kptis_lst
    uniq_kpts, uniq_index, uniq_inverse = unique(kpt_ji)
    ngrids = numpy.prod(mydf.mesh)
    nao = cell.nao
    max_memory = max(2000, mydf.max_memory-lib.current_memory()[0]-nao**4*16/1e6) * .5

    # Half-transformed pair densities are staged in a temporary HDF5 file.
    fswap = lib.H5TmpFile()
    tao = []
    ao_loc = None
    kconserv = kpts_helper.get_kconserv(cell, kpts)
    for uniq_id, q in enumerate(uniq_kpts):
        # All (ki, kj) pairs that share the momentum transfer q = kj - ki.
        adapted_ji_idx = numpy.where(uniq_inverse == uniq_id)[0]
        kptjs = kptjs_lst[adapted_ji_idx]
        coulG = mydf.weighted_coulG(q, False, mydf.mesh)
        coulG *= factor

        # Bra side: (G|ij) * coulG for every (ki, kj) pair of this group.
        moij_list = []
        ijslice_list = []
        for ji, ji_idx in enumerate(adapted_ji_idx):
            ki, kj = divmod(ji_idx, nkpts)
            moij, ijslice = _conc_mos(mo_coeff_kpts[0][ki], mo_coeff_kpts[1][kj])[2:]
            moij_list.append(moij)
            ijslice_list.append(ijslice)
            fswap.create_dataset('zij/'+str(ji), (ngrids,nmoi*nmoj), 'D')

        for aoaoks, p0, p1 in mydf.ft_loop(mydf.mesh, q, kptjs,
                                           max_memory=max_memory):
            for ji, aoao in enumerate(aoaoks):
                buf = aoao.transpose(1,2,0).reshape(nao**2,p1-p0)
                zij = _ao2mo.r_e2(lib.transpose(buf), moij_list[ji],
                                  ijslice_list[ji], tao, ao_loc)
                zij *= coulG[p0:p1,None]
                fswap['zij/'+str(ji)][p0:p1] = zij
                buf = zij = None

        # Momentum conservation fixes kl from kk and kj - ki.  That difference
        # is shared by the whole group, so the first (ki, kj) pair determines
        # kl for every kk.
        ki, kj = divmod(adapted_ji_idx[0], nkpts)
        kl_of_kk = kconserv[ki, kj, :]

        # Ket side: conj((G|kl)) for every kk.
        mokl_list = []
        klslice_list = []
        for kk in range(nkpts):
            mokl, klslice = _conc_mos(mo_coeff_kpts[2][kk],
                                      mo_coeff_kpts[3][kl_of_kk[kk]])[2:]
            mokl_list.append(mokl)
            klslice_list.append(klslice)
            fswap.create_dataset('zkl/'+str(kk), (ngrids,nmok*nmol), 'D')

        kptls = kpts[kl_of_kk]
        for aoaoks, p0, p1 in mydf.ft_loop(mydf.mesh, q, -kptls,
                                           max_memory=max_memory):
            for kk, aoao in enumerate(aoaoks):
                buf = aoao.conj().transpose(1,2,0).reshape(nao**2,p1-p0)
                zkl = _ao2mo.r_e2(lib.transpose(buf), mokl_list[kk],
                                  klslice_list[kk], tao, ao_loc)
                fswap['zkl/'+str(kk)][p0:p1] = zkl
                buf = zkl = None

        # Contract over G.  zij is loaded once per (ki, kj) instead of once
        # per (ki, kj, kk).
        for ji, ji_idx in enumerate(adapted_ji_idx):
            ki, kj = divmod(ji_idx, nkpts)
            zij = numpy.asarray(fswap['zij/'+str(ji)])
            for kk in range(nkpts):
                zkl = numpy.asarray(fswap['zkl/'+str(kk)])
                tmp = lib.dot(zij.T, zkl)
                if dtype == numpy.double:
                    tmp = tmp.real
                out[ki,kj,kk] = tmp.reshape(eri_shape[3:])
            zij = zkl = None
        del(fswap['zij'])
        del(fswap['zkl'])

    fswap.close()
    return out
399
400
if __name__ == '__main__':
    # Consistency check: general() must agree with the AO integrals from
    # get_eri() transformed to the MO basis (conj on i and k).
    from pyscf.pbc import gto as pgto
    from pyscf.pbc.df import AFTDF

    L = 5.
    n = 11
    cell = pgto.Cell()
    cell.a = numpy.diag([L,L,L])
    cell.mesh = numpy.array([n,n,n])

    cell.atom = '''He    3.    2.       3.
                   He    1.    1.       1.'''
    #cell.basis = {'He': [[0, (1.0, 1.0)]]}
    #cell.basis = '631g'
    #cell.basis = {'He': [[0, (2.4, 1)], [1, (1.1, 1)]]}
    cell.basis = 'ccpvdz'
    cell.verbose = 0
    cell.build(0,0)

    nao = cell.nao_nr()
    numpy.random.seed(1)
    kpts = numpy.random.random((4,3))
    # Crystal-momentum conservation for (ij|kl) requires
    # k_i - k_j + k_k - k_l to be a reciprocal lattice vector.  Choosing
    # k_l this way makes both calls compute real integrals instead of
    # returning the zero fallback for non-conserving k-points.
    kpts[3] = kpts[0] - kpts[1] + kpts[2]
    with_df = AFTDF(cell, kpts)
    with_df.mesh = [n] * 3
    mo =(numpy.random.random((nao,nao)) +
         numpy.random.random((nao,nao))*1j)
    eri = with_df.get_eri(kpts).reshape((nao,)*4)
    eri0 = numpy.einsum('pjkl,pi->ijkl', eri , mo.conj())
    eri0 = numpy.einsum('ipkl,pj->ijkl', eri0, mo       )
    eri0 = numpy.einsum('ijpl,pk->ijkl', eri0, mo.conj())
    eri0 = numpy.einsum('ijkp,pl->ijkl', eri0, mo       )
    # ao2mo returns a 2D (nij, nkl) matrix; reshape to compare with eri0.
    eri1 = with_df.ao2mo(mo, kpts).reshape((nao,)*4)
    print(abs(eri1-eri0).sum())
435