1#!/usr/bin/env python
2# Copyright 2018-2020 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Bryan Lau <blau1270@gmail.com>
17#         Qiming Sun <osirpt.sun@gmail.com>
18#
19
20"""
21Created on Thu May 17 11:05:22 2018
22
23@author: Bryan Lau
24
25
A module that performs an on-disk transformation of two-electron integrals and
also returns the specific slices of (o)ccupied and (v)irtual ones needed for
post-HF methods.
28
Compared to the full in-memory transformation (see incore.py), which holds all
intermediates in memory, this version uses less memory but is slower due
to IO overhead.
32"""
33
34
35import ctypes
36import numpy
37import h5py
38from pyscf import lib
39from pyscf.lib import logger
40from pyscf.ao2mo.incore import iden_coeffs, _conc_mos
41from pyscf.ao2mo.outcore import _load_from_h5g
42from pyscf.ao2mo import _ao2mo
43
# Default IO block size, in MB.  general() uses it as the default of its
# ``ioblk_size`` argument, from which the row-chunk size is derived.
IOBLK_SIZE = 128  # MB
45
def general(eri, mo_coeffs, erifile, dataname='eri_mo',
            ioblk_size=IOBLK_SIZE, compact=True, verbose=logger.NOTE):
    '''For the given four sets of orbitals, transfer arbitrary spherical AO
    integrals to MO integrals on disk.

    Args:
        eri : ndarray
            AO integrals, either 8-fold reduced (nao_pair*(nao_pair+1)//2
            elements) or 4-fold reduced (nao_pair**2 elements), where
            nao_pair = nao*(nao+1)//2.
        mo_coeffs : 4-item list of ndarray
            Four sets of orbital coefficients, corresponding to the four
            indices of (ij|kl)
        erifile : str or h5py File or h5py Group object
            To store the transformed integrals, in HDF5 format.

    Kwargs:
        dataname : str
            The dataset name in the erifile (ref the hierarchy of HDF5 format
            http://www.hdfgroup.org/HDF5/doc1.6/UG/09_Groups.html).  By assigning
            different dataname, the existed integral file can be reused.  If
            the erifile contains the dataname, the new integrals data will
            overwrite the old one.
        ioblk_size : float or int
            The block size for IO (in MB), large block size may **not**
            improve performance
        compact : bool
            When compact is True, depending on the four orbital sets, the
            returned MO integrals has (up to 4-fold) permutation symmetry.
            If it's False, the function will abandon any permutation symmetry,
            and return the "plain" MO integrals
        verbose : int or logger object
            Passed to logger.new_logger to control the printing level.

    Returns:
        ``erifile`` exactly as it was passed in; the transformed integrals are
        stored in the dataset ``dataname`` of shape (nij_pair, nkl_pair).
        If nij_pair is 0, nothing is written and an empty ndarray of shape
        (0, nkl_pair) is returned instead.

    Raises:
        NotImplementedError : if the size of ``eri`` matches neither the
            4-fold nor the 8-fold layout.  This is checked before the output
            file is touched.

    Pseudocode / algorithm:
        u = mu
        v = nu
        l = lambda
        o = sigma

        Assume eri's are 8-fold reduced.
        nij/nkl_pair = npair or i*j/k*l if only transforming a subset

        First half transform:
            Initialize half_eri of size (nij_pair,npair)
                For lo = 1 -> npair
                    Unpack row lo
                    Unpack row lo to matrix E_{uv}^{lo}
                    Transform C_ui^+*E*C_vj -> E_{ij}^{lo}
                    Ravel or pack E_{ij}^{lo}
                    Save E_{ij}^{lo} -> half_eri[:,lo]

        Second half transform:
            Initialize h5d_eri of size (nij_pair,nkl_pair)
                For ij = 1 -> nij_pair
                    Load and unpack half_eri[ij,:] -> E_{lo}^{ij}
                    Transform C_{lk}E_{lo}^{ij}C_{ol} -> E_{kl}^{ij}
                    Repack E_{kl}^{ij}
                    Save E_{kl}^{ij} -> h5d_eri[ij,:]

        Each matrix is indexed by the composite index ij x kl, where ij/kl is
        either npair or ixj/kxl, if only a subset of MOs are being transformed.
        Since entire rows or columns need to be read in, the arrays are chunked
        such that IOBLK_SIZE = row/col x chunking col/row. For example, for the
        first half transform, we would save in nij_pair x IOBLK_SIZE/nij_pair,
        then load in IOBLK_SIZE/nkl_pair x npair for the second half transform.

        ------ kl ----->
        |jxl
        |
        ij
        |
        |
        v

        As a first guess, the chunking size is jxl. If the super-rows/cols are
        larger than IOBLK_SIZE, then the chunk rectangle jxl is trimmed
        accordingly. The pathological limiting case is where the dimensions
        nao_pair, nij_pair, or nkl_pair are so large that the arrays are
        chunked 1x1, in which case IOBLK_SIZE needs to be increased.

    '''
    log = logger.new_logger(None, verbose)
    log.info('******** ao2mo disk, custom eri ********')

    eri_ao = numpy.asarray(eri, order='C')
    nao, nmoi = mo_coeffs[0].shape
    nmoj = mo_coeffs[1].shape[1]
    nao_pair = nao*(nao+1)//2
    ijmosym, nij_pair, moij, ijshape = _conc_mos(mo_coeffs[0], mo_coeffs[1], compact)
    klmosym, nkl_pair = _conc_mos(mo_coeffs[2], mo_coeffs[3], compact)[:2]
    # Convert (start, stop) pairs into (start, count) pairs for the C driver.
    ijshape = (ijshape[0], ijshape[1]-ijshape[0],
               ijshape[2], ijshape[3]-ijshape[2])
    dtype = numpy.result_type(eri, *mo_coeffs)
    typesize = dtype.itemsize/1e6 # in MB

    if nij_pair == 0:
        return numpy.empty((nij_pair,nkl_pair))

    # 's1' means no permutation symmetry: the pair index is the plain i*j
    # (or k*l) product and blocks are raveled instead of triangular-packed.
    ij_plain = ijmosym == 's1'
    kl_plain = klmosym == 's1'

    # Validate the AO integral layout *before* touching the output file so
    # that bad input cannot delete an existing dataset or leave a file open.
    eri_is_s4 = eri_ao.size == nao_pair**2
    if eri_is_s4: # 4-fold symmetry
        # half_e1 first transforms the indices which are contiguous in memory
        # transpose the 4-fold integrals to make ij the contiguous indices
        eri_ao = lib.transpose(eri_ao)
        ftrans = _ao2mo.libao2mo.AO2MOtranse1_incore_s4
    elif eri_ao.size == nao_pair*(nao_pair+1)//2: # 8-fold symmetry
        ftrans = _ao2mo.libao2mo.AO2MOtranse1_incore_s8
    else:
        raise NotImplementedError

    if ijmosym == 's2':
        fmmm = _ao2mo.libao2mo.AO2MOmmm_nr_s2_s2
    elif nmoi <= nmoj:
        fmmm = _ao2mo.libao2mo.AO2MOmmm_nr_s2_iltj
    else:
        fmmm = _ao2mo.libao2mo.AO2MOmmm_nr_s2_igtj
    fdrv = _ao2mo.libao2mo.AO2MOnr_e1incore_drv

    # Number of rows handled per IO step; sized with the actual element size
    # so that complex data does not exceed the requested block size.
    chunk_size = min(nao_pair,
                     max(4, int(ioblk_size*1e6/dtype.itemsize/nao_pair)))

    if isinstance(erifile, str):
        if h5py.is_hdf5(erifile):
            feri = h5py.File(erifile, 'a')
        else:
            feri = h5py.File(erifile,'w',libver='latest')
    else:
        assert isinstance(erifile, h5py.Group)
        feri = erifile

    feri_swap = None
    try:
        # Overwrite a pre-existing dataset, for files and groups alike.
        if dataname in feri:
            del feri[dataname]
        h5d_eri = feri.create_dataset(dataname,(nij_pair,nkl_pair), dtype.char)
        # Scratch file holding the half-transformed integrals (ij|lo).
        feri_swap = lib.H5TmpFile(libver='latest')

        log.debug('Memory information:')
        log.debug('  IOBLK_SIZE (MB): {}  chunk_size: {}'
                  .format(ioblk_size, chunk_size))
        log.debug('  Final disk eri size (MB): {:.3g}'
                  .format(nij_pair*nkl_pair*typesize))
        log.debug('  Half transformed eri size (MB): {:.3g}'
                  .format(nij_pair*nao_pair*typesize))
        # Two (chunk_size, nao_pair) buffers are alive in the second half.
        log.debug('  RAM buffer (MB): {:.3g}'
                  .format(chunk_size*nao_pair*typesize*2))

        def save_half(piece, buf):
            # Stored transposed so the second half can read rows of ij.
            feri_swap[str(piece)] = buf.T

        # transform \mu\nu -> ij
        cput0 = logger.process_clock(), logger.perf_counter()
        is_real = dtype == numpy.double
        if not is_real:
            moi_conj = mo_coeffs[0].conj()
        with lib.call_in_background(save_half) as async_write:
            for istep, (p0, p1) in enumerate(lib.prange(0, nao_pair, chunk_size)):
                if is_real:
                    buf = numpy.empty((p1-p0, nij_pair))
                    fdrv(ftrans, fmmm,
                         buf.ctypes.data_as(ctypes.c_void_p),
                         eri_ao.ctypes.data_as(ctypes.c_void_p),
                         moij.ctypes.data_as(ctypes.c_void_p),
                         ctypes.c_int(p0), ctypes.c_int(p1-p0),
                         ctypes.c_int(nao),
                         ctypes.c_int(ijshape[0]), ctypes.c_int(ijshape[1]),
                         ctypes.c_int(ijshape[2]), ctypes.c_int(ijshape[3]))
                else:  # complex
                    if eri_is_s4: # 4-fold symmetry
                        tmp = eri_ao[p0:p1]
                    else: # 8-fold symmetry
                        # Keep the dtype of the AO integrals so that complex
                        # values are not truncated to their real part.
                        tmp = numpy.empty((p1-p0, nao_pair), dtype=eri_ao.dtype)
                        for i in range(p0, p1):
                            tmp[i-p0] = lib.unpack_row(eri_ao, i)
                    tmp = lib.unpack_tril(tmp, filltriu=lib.SYMMETRIC)
                    buf = lib.einsum('xpq,pi,qj->xij', tmp, moi_conj, mo_coeffs[1])
                    if ij_plain:
                        buf = buf.reshape(p1-p0,-1) # grabs by row
                    else:
                        buf = lib.pack_tril(buf)

                async_write(istep, buf)

        log.timer('(uv|lo) -> (ij|lo)', *cput0)

        # transform \lambda\sigma -> kl
        cput1 = logger.process_clock(), logger.perf_counter()
        Cklam = mo_coeffs[2].conj()
        buf_read = numpy.empty((chunk_size,nao_pair), dtype=dtype)
        buf_prefetch = numpy.empty_like(buf_read)

        def load(start, stop, buf):
            if start < stop:
                _load_from_h5g(feri_swap, start, stop, buf)

        def save_full(start, stop, buf):
            if start < stop:
                h5d_eri[start:stop] = buf[:stop-start]

        with lib.call_in_background(save_full,load) as (async_write, prefetch):
            for p0, p1 in lib.prange(0, nij_pair, chunk_size):
                if p0 == 0:
                    # Nothing has been prefetched yet for the first chunk.
                    load(p0, p1, buf_prefetch)

                # Swap buffers: work on the loaded one, refill the other.
                buf_read, buf_prefetch = buf_prefetch, buf_read
                prefetch(p1, min(p1+chunk_size, nij_pair), buf_prefetch)

                lo = lib.unpack_tril(buf_read[:p1-p0], filltriu=lib.SYMMETRIC)
                lo = lib.einsum('xpq,pi,qj->xij', lo, Cklam, mo_coeffs[3])
                if kl_plain:
                    kl = lo.reshape(p1-p0,-1)
                else:
                    kl = lib.pack_tril(lo)
                async_write(p0, p1, kl)

        log.timer('(ij|lo) -> (ij|kl)', *cput1)
    finally:
        # Release the scratch file and any file opened here, even on error.
        if feri_swap is not None:
            feri_swap.close()
        if isinstance(erifile, str):
            feri.close()
    return erifile
254
if __name__ == '__main__':
    # Self-check / benchmark: compare this module's general() against the
    # incore and outcore transformations shipped with pyscf for an HF molecule.
    import tempfile
    from pyscf import gto, scf, ao2mo
    # set verbose to 7 to get detailed timing info, otherwise 0
    verbose = 0

    mol = gto.Mole()
    mol.verbose = 0
    mol.output = None

    mol.atom = [
        ['H' , (0. , 0. , .917)],
        ['F' , (0. , 0. , 0.)], ]
    mol.basis = '6311g'
    mol.build()

    mf = scf.RHF(mol)
    mf.kernel()
    mf.verbose = verbose
    mo_coeff = mf.mo_coeff

    # compare custom outcore eri with incore eri
    nocc = numpy.count_nonzero(mf.mo_occ)
    orbo = mo_coeff[:,:nocc]

    print('Full incore transformation (pyscf)...')
    start_time = logger.perf_counter()
    eri_incore = ao2mo.incore.full(mf._eri, mo_coeff)
    print('    Time elapsed (s): ',logger.perf_counter() - start_time)

    print('Parital incore transformation (pyscf)...')
    start_time = logger.perf_counter()
    onnn2 = ao2mo.incore.general(mf._eri, (orbo,mo_coeff,mo_coeff,mo_coeff))
    print('    Time elapsed (s): ',logger.perf_counter() - start_time)

    # The context manager guarantees the temporary file is removed even if
    # one of the transformations below raises.
    with tempfile.NamedTemporaryFile(dir=lib.param.TMPDIR) as tmpfile2:
        print('\n\nCustom outcore transformation ...')
        start_time = logger.perf_counter()
        general(mf._eri, (orbo,mo_coeff,mo_coeff,mo_coeff), tmpfile2.name, 'aa',
                verbose=verbose)
        stop_time = logger.perf_counter() - start_time
        print('    Time elapsed (s): ',stop_time)
        print('\n\nPyscf outcore transformation ...')
        start_time = logger.perf_counter()
        ao2mo.outcore.general(mol, (orbo,mo_coeff,mo_coeff,mo_coeff), tmpfile2.name, 'ab',
                              verbose=verbose)
        stop_time2 = logger.perf_counter() - start_time
        print('    Time elapsed (s): ',stop_time2)
        print('How worse is the custom implemenation?',stop_time/stop_time2)
        with h5py.File(tmpfile2.name, 'r') as f:
            print('\n\nIncore (pyscf) vs outcore (custom)?',numpy.allclose(onnn2,f['aa']))
            print('Outcore (pyscf) vs outcore (custom)?',numpy.allclose(f['ab'],f['aa']))

        print('\n\nCustom full outcore transformation ...')
        start_time = logger.perf_counter()
        general(mf._eri, (mo_coeff,mo_coeff,mo_coeff,mo_coeff), tmpfile2.name, 'aa',
                verbose=verbose)
        stop_time = logger.perf_counter() - start_time
        print('    Time elapsed (s): ',stop_time)
        print('\n\nPyscf full outcore transformation ...')
        start_time = logger.perf_counter()
        ao2mo.outcore.full(mol, mo_coeff, tmpfile2.name, 'ab',verbose=verbose)
        stop_time2 = logger.perf_counter() - start_time
        print('    Time elapsed (s): ',stop_time2)
        print('    How worse is the custom implemenation?',stop_time/stop_time2)
        with h5py.File(tmpfile2.name, 'r') as f:
            print('\n\nIncore (pyscf) vs outcore (custom)?',numpy.allclose(eri_incore,f['aa']))
            print('Outcore (pyscf) vs outcore (custom)?',numpy.allclose(f['ab'],f['aa']))
330