1#!/usr/bin/env python
2# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Qiming Sun <osirpt.sun@gmail.com>
17#
18
19import numpy
20import h5py
21from pyscf import lib
22from pyscf import gto
23from pyscf.ao2mo.outcore import balance_segs
24from pyscf.pbc.lib.kpts_helper import gamma_point, unique, KPT_DIFF_TOL
25from pyscf.pbc.df.incore import wrap_int3c, make_auxcell
26
27libpbc = lib.load_library('libpbc')
28
29
def aux_e1(cell, auxcell_or_auxbasis, erifile, intor='int3c2e', aosym='s2ij', comp=None,
           kptij_lst=None, dataname='eri_mo', shls_slice=None, max_memory=2000,
           verbose=0):
    r'''3-center AO integrals (L|ij) with double lattice sum:
    \sum_{lm} (L[0]|i[l]j[m]), where L is the auxiliary basis.
    Three-index integral tensor (kptij_idx, naux, nao_pair) or four-index
    integral tensor (kptij_idx, comp, naux, nao_pair) are stored on disk.

    One dataset ``<dataname>/<k>`` is created per entry of kptij_lst, and the
    k-point pairs themselves are stored under ``<dataname>-kptij``.  Existing
    entries with these names are deleted first.

    Args:
        cell :
            Object providing the orbital basis (nbas, ao_loc_nr, ...).
        auxcell_or_auxbasis :
            Either a gto.Mole instance, which is used directly as the
            auxiliary cell, or a basis specification that is handed to
            make_auxcell(cell, ...).
        erifile : str or h5py.Group
            An h5py.Group is written to as is and is never closed here.
            Anything else is treated as a file name: it is opened in append
            mode when it already is an HDF5 file, otherwise it is created.
        intor : str
            Integral name; the suffix is added through cell._add_suffix.
        aosym : str
            When it starts with 's2', pairs with kpti == kptj are stored with
            the packed lower-triangular AO-pair index.
        comp : int
            Number of integral components (resolved by
            gto.moleintor._get_intor_and_comp when None).
        kptij_lst : (*,2,3) array
            A list of (kpti, kptj).  Defaults to a single all-zero pair.
        dataname : str
            Name of the group that receives the integrals.
        shls_slice : 6-element tuple
            (ish0, ish1, jsh0, jsh1, auxsh0, auxsh1).  Defaults to all shells.
        max_memory : number
            Controls the batch size along the auxiliary index (the buffer is
            sized as max_memory*1e6 bytes of 16-byte elements).
        verbose :
            Not used by this function.

    Returns:
        The erifile argument, unchanged.
    '''
    if isinstance(auxcell_or_auxbasis, gto.Mole):
        auxcell = auxcell_or_auxbasis
    else:
        auxcell = make_auxcell(cell, auxcell_or_auxbasis)

    intor, comp = gto.moleintor._get_intor_and_comp(cell._add_suffix(intor), comp)

    # Only handles opened by this function may be closed by it.  A Group that
    # was passed in belongs to the caller (and a plain h5py.Group does not even
    # provide close()).
    owns_file = not isinstance(erifile, h5py.Group)
    if not owns_file:
        feri = erifile
    elif h5py.is_hdf5(erifile):
        feri = h5py.File(erifile, 'a')
    else:
        feri = h5py.File(erifile, 'w')

    try:
        if dataname in feri:
            del feri[dataname]
        if dataname+'-kptij' in feri:
            del feri[dataname+'-kptij']

        if kptij_lst is None:
            kptij_lst = numpy.zeros((1,2,3))
        feri[dataname+'-kptij'] = kptij_lst

        if shls_slice is None:
            shls_slice = (0, cell.nbas, 0, cell.nbas, 0, auxcell.nbas)

        ao_loc = cell.ao_loc_nr()
        aux_loc = auxcell.ao_loc_nr(auxcell.cart or 'ssc' in intor)[:shls_slice[5]+1]
        ni = ao_loc[shls_slice[1]] - ao_loc[shls_slice[0]]
        nj = ao_loc[shls_slice[3]] - ao_loc[shls_slice[2]]
        naux = aux_loc[shls_slice[5]] - aux_loc[shls_slice[4]]
        nkptij = len(kptij_lst)

        # Number of AO pairs with (nii) and without (nij) triangular packing.
        nii = (ao_loc[shls_slice[1]]*(ao_loc[shls_slice[1]]+1)//2 -
               ao_loc[shls_slice[0]]*(ao_loc[shls_slice[0]]+1)//2)
        nij = ni * nj

        kpti = kptij_lst[:,0]
        kptj = kptij_lst[:,1]
        # Pairs with kpti == kptj; j_only is decided before the aosym mask is
        # folded in, so it only reflects the k-points.
        aosym_ks2 = abs(kpti-kptj).sum(axis=1) < KPT_DIFF_TOL
        j_only = numpy.all(aosym_ks2)
        aosym_ks2 &= aosym[:2] == 's2'

        # Evaluated once per k-point pair; reused for the dataset dtype and
        # for dropping the imaginary part when writing.
        is_gamma = [gamma_point(kptij) for kptij in kptij_lst]

        for k in range(nkptij):
            key = '%s/%d' % (dataname, k)
            dset_dtype = 'f8' if is_gamma[k] else 'c16'
            npair_k = nii if aosym_ks2[k] else nij
            if comp == 1:
                shape = (naux, npair_k)
            else:
                shape = (comp, naux, npair_k)
            feri.create_dataset(key, shape, dset_dtype)

        if naux == 0:
            # Nothing to compute; the (empty) datasets have been created.
            return erifile

        if j_only and aosym[:2] == 's2':
            assert shls_slice[2] == 0
            nao_pair = nii
        else:
            nao_pair = nij

        if gamma_point(kptij_lst):
            dtype = numpy.double
        else:
            dtype = numpy.complex128

        # Split the auxiliary shells into batches that fit the memory budget.
        buflen = max(8, int(max_memory*1e6/16/(nkptij*ni*nj*comp)))
        auxdims = aux_loc[shls_slice[4]+1:shls_slice[5]+1] - aux_loc[shls_slice[4]:shls_slice[5]]
        auxranges = balance_segs(auxdims, buflen)
        buflen = max([x[2] for x in auxranges])
        buf = numpy.empty(nkptij*comp*ni*nj*buflen, dtype=dtype)
        buf1 = numpy.empty(ni*nj*buflen, dtype=dtype)

        int3c = wrap_int3c(cell, auxcell, intor, aosym, comp, kptij_lst)

        naux0 = 0
        for sh0, sh1, nrow in auxranges:
            sub_slice = (shls_slice[0], shls_slice[1],
                         shls_slice[2], shls_slice[3],
                         shls_slice[4]+sh0, shls_slice[4]+sh1)
            mat = numpy.ndarray((nkptij,comp,nao_pair,nrow), dtype=dtype, buffer=buf)
            mat = int3c(sub_slice, mat)

            for k in range(nkptij):
                h5dat = feri['%s/%d'%(dataname,k)]
                for icomp, v in enumerate(mat[k]):
                    # (nao_pair, nrow) -> (nrow, nao_pair): aux index first on disk.
                    v = lib.transpose(v, out=buf1)
                    if is_gamma[k]:
                        v = v.real
                    if aosym_ks2[k] and v.shape[1] == ni**2:
                        # Integrals were computed on the full square; pack them.
                        v = lib.pack_tril(v.reshape(-1,ni,ni))
                    if comp == 1:
                        h5dat[naux0:naux0+nrow] = v
                    else:
                        h5dat[icomp,naux0:naux0+nrow] = v
            naux0 += nrow
    finally:
        if owns_file:
            feri.close()
    return erifile
149
150
def _aux_e2(cell, auxcell_or_auxbasis, erifile, intor='int3c2e', aosym='s2ij', comp=None,
            kptij_lst=None, dataname='eri_mo', shls_slice=None, max_memory=2000,
            verbose=0):
    r'''3-center AO integrals (ij|L) with double lattice sum:
    \sum_{lm} (i[l]j[m]|L[0]), where L is the auxiliary basis.
    Three-index integral tensor (kptij_idx, nao_pair, naux) or four-index
    integral tensor (kptij_idx, comp, nao_pair, naux) are stored on disk.

    The auxiliary index is processed in batches; batch number ``istep`` of
    k-point pair ``k`` is written to the dataset ``<dataname>/<k>/<istep>``.
    The k-point pairs are stored under ``<dataname>-kptij``.  Existing entries
    with these names are deleted first.

    **This function should be only used by df and mdf initialization function
    _make_j3c**

    Args:
        cell :
            Object providing the orbital basis (nbas, ao_loc_nr, ...).
        auxcell_or_auxbasis :
            Either a gto.Mole instance, which is used directly as the
            auxiliary cell, or a basis specification that is handed to
            make_auxcell(cell, ...).
        erifile : str or h5py.Group
            An h5py.Group is written to as is and is never closed here.
            Anything else is treated as a file name: it is opened in append
            mode when it already is an HDF5 file, otherwise it is created.
        intor : str
            Integral name; the suffix is added through cell._add_suffix.
        aosym : str
            When it starts with 's2', pairs with kpti == kptj are stored with
            the lower-triangular AO-pair index only.
        comp : int
            Number of integral components (resolved by
            gto.moleintor._get_intor_and_comp when None).
        kptij_lst : (*,2,3) array
            A list of (kpti, kptj).  Defaults to a single all-zero pair.
        dataname : str
            Name of the group that receives the integrals.
        shls_slice : 6-element tuple
            (ish0, ish1, jsh0, jsh1, auxsh0, auxsh1).  Defaults to all shells.
        max_memory : number
            Controls the batch size along the auxiliary index.  Each of the
            two work buffers is sized from max_memory*.47e6 bytes of 16-byte
            elements.
        verbose :
            Not used by this function.

    Returns:
        The erifile argument, unchanged.
    '''
    if isinstance(auxcell_or_auxbasis, gto.Mole):
        auxcell = auxcell_or_auxbasis
    else:
        auxcell = make_auxcell(cell, auxcell_or_auxbasis)

    intor, comp = gto.moleintor._get_intor_and_comp(cell._add_suffix(intor), comp)

    # Only handles opened by this function may be closed by it; a Group that
    # was passed in belongs to the caller.
    owns_file = not isinstance(erifile, h5py.Group)
    if not owns_file:
        feri = erifile
    elif h5py.is_hdf5(erifile):
        feri = h5py.File(erifile, 'a')
    else:
        feri = h5py.File(erifile, 'w')

    try:
        if dataname in feri:
            del feri[dataname]
        if dataname+'-kptij' in feri:
            del feri[dataname+'-kptij']

        if kptij_lst is None:
            kptij_lst = numpy.zeros((1,2,3))
        feri[dataname+'-kptij'] = kptij_lst

        if shls_slice is None:
            shls_slice = (0, cell.nbas, 0, cell.nbas, 0, auxcell.nbas)

        ao_loc = cell.ao_loc_nr()
        aux_loc = auxcell.ao_loc_nr(auxcell.cart or 'ssc' in intor)[:shls_slice[5]+1]
        ni = ao_loc[shls_slice[1]] - ao_loc[shls_slice[0]]
        nj = ao_loc[shls_slice[3]] - ao_loc[shls_slice[2]]
        nkptij = len(kptij_lst)

        # Number of AO pairs with (nii) and without (nij) triangular packing.
        nii = (ao_loc[shls_slice[1]]*(ao_loc[shls_slice[1]]+1)//2 -
               ao_loc[shls_slice[0]]*(ao_loc[shls_slice[0]]+1)//2)
        nij = ni * nj

        kptis = kptij_lst[:,0]
        kptjs = kptij_lst[:,1]
        # Pairs with kpti == kptj; j_only is decided before the aosym mask is
        # folded in, so it only reflects the k-points.
        aosym_ks2 = abs(kptis-kptjs).sum(axis=1) < KPT_DIFF_TOL
        j_only = numpy.all(aosym_ks2)
        aosym_ks2 &= aosym[:2] == 's2'

        if j_only and aosym[:2] == 's2':
            assert shls_slice[2] == 0
            nao_pair = nii
        else:
            nao_pair = nij

        if gamma_point(kptij_lst):
            dtype = numpy.double
        else:
            dtype = numpy.complex128

        # Split the auxiliary shells into batches that fit the memory budget.
        buflen = max(8, int(max_memory*.47e6/16/(nkptij*ni*nj*comp)))
        auxdims = aux_loc[shls_slice[4]+1:shls_slice[5]+1] - aux_loc[shls_slice[4]:shls_slice[5]]
        auxranges = balance_segs(auxdims, buflen)
        buflen = max([x[2] for x in auxranges])
        buf = numpy.empty(nkptij*comp*ni*nj*buflen, dtype=dtype)
        # Two buffers that are swapped on every batch, so that the batch being
        # written to disk is not overwritten by the one being computed.
        bufs = [buf, numpy.empty_like(buf)]
        int3c = wrap_int3c(cell, auxcell, intor, aosym, comp, kptij_lst)

        def process(aux_range):
            sh0, sh1, nrow = aux_range
            sub_slice = (shls_slice[0], shls_slice[1],
                         shls_slice[2], shls_slice[3],
                         shls_slice[4]+sh0, shls_slice[4]+sh1)
            mat = numpy.ndarray((nkptij,comp,nao_pair,nrow), dtype=dtype, buffer=bufs[0])
            bufs[:] = bufs[1], bufs[0]
            int3c(sub_slice, mat)
            return mat

        kpt_ji = kptjs - kptis
        uniq_kpts, _, uniq_inverse = unique(kpt_ji)
        # sorted_ij_idx: Sort and group the kptij_lst according to the ordering in
        # df._make_j3c to reduce the data fragment in the hdf5 file.  When datasets
        # are written to hdf5, they are saved sequentially. If the integral data are
        # saved as the order of kptij_lst, removing the datasets in df._make_j3c will
        # lead to disk space fragment that can not be reused.
        sorted_ij_idx = numpy.hstack([numpy.where(uniq_inverse == k)[0]
                                      for k in range(len(uniq_kpts))])
        # Flat (row-major) positions of the lower triangle in an ni x ni block.
        tril_idx = numpy.tril_indices(ni)
        tril_idx = tril_idx[0] * ni + tril_idx[1]

        # Evaluated once per k-point pair instead of once per pair and batch.
        is_gamma = [gamma_point(kptij) for kptij in kptij_lst]

        for istep, mat in enumerate(lib.map_with_prefetch(process, auxranges)):
            for k in sorted_ij_idx:
                v = mat[k]
                if is_gamma[k]:
                    v = v.real
                if aosym_ks2[k] and nao_pair == ni**2:
                    # Integrals were computed on the full square; keep the
                    # lower triangle only.
                    v = v[:,tril_idx]
                feri['%s/%d/%d' % (dataname,k,istep)] = v
    finally:
        if owns_file:
            feri.close()
    return erifile
264
265
266