1#!/usr/bin/env python
2# Copyright 2014-2021 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Qiming Sun <osirpt.sun@gmail.com>
17#
18
19'''Multigrid to compute DFT integrals'''
20
21import ctypes
22import copy
23import numpy
24import scipy.linalg
25
26from pyscf import lib
27from pyscf.lib import logger
28from pyscf.gto import ATOM_OF, ANG_OF, NPRIM_OF, PTR_EXP, PTR_COEFF
29from pyscf.dft.numint import libdft, BLKSIZE
30from pyscf.pbc import tools
31from pyscf.pbc import gto
32from pyscf.pbc.gto import pseudo
33from pyscf.pbc.dft import numint, gen_grid
34from pyscf.pbc.df.df_jk import _format_dms, _format_kpts_band, _format_jks
35from pyscf.pbc.lib.kpts_helper import gamma_point
36from pyscf.pbc.df import fft
37from pyscf.pbc.df import ft_ao
38from pyscf import __config__
39
40#sys.stderr.write('WARN: multigrid is an experimental feature. It is still in '
41#                 'testing\nFeatures and APIs may be changed in the future.\n')
42
# Factor multiplied onto cell.precision when estimating GTO cutoffs.  In
# eval_mat/eval_rho it enters log_prec and the threshold stored in
# env[PTR_EXPDROP].
EXTRA_PREC = getattr(__config__, 'pbc_gto_eval_gto_extra_precision', 1e-2)
# The options below are read from pyscf.__config__.  Most of them are not
# used in this part of the file; presumably they tune how the multigrid
# tasks/meshes are built (TODO confirm against multi_grids_tasks).
TO_EVEN_GRIDS = getattr(__config__, 'pbc_dft_multigrid_to_even', False)
RMAX_FACTOR_ORTH = getattr(__config__, 'pbc_dft_multigrid_rmax_factor_orth', 1.1)
RMAX_FACTOR_NONORTH = getattr(__config__, 'pbc_dft_multigrid_rmax_factor_nonorth', 0.5)
RMAX_RATIO = getattr(__config__, 'pbc_dft_multigrid_rmax_ratio', 0.7)
# _eval_rho_bra/_eval_rho_ket split the collocation per atom (sub-mesh
# boxes) only when grids.cell.rcut <= R_RATIO_SUBLOOP * max(lattice_vectors).
R_RATIO_SUBLOOP = getattr(__config__, 'pbc_dft_multigrid_r_ratio_subloop', 0.6)
INIT_MESH_ORTH = getattr(__config__, 'pbc_dft_multigrid_init_mesh_orth', (12,12,12))
INIT_MESH_NONORTH = getattr(__config__, 'pbc_dft_multigrid_init_mesh_nonorth', (32,32,32))
KE_RATIO = getattr(__config__, 'pbc_dft_multigrid_ke_ratio', 1.3)
TASKS_TYPE = getattr(__config__, 'pbc_dft_multigrid_tasks_type', 'ke_cut') # alternative value: 'rcut'

# RHOG_HIGH_ORDER=True will compute the high order derivatives of electron
# density in real space and FT to reciprocal space.  Set RHOG_HIGH_ORDER=False
# to approximate the density derivatives in reciprocal space (without
# evaluating the high order derivatives in real space).
RHOG_HIGH_ORDER = getattr(__config__, 'pbc_dft_multigrid_rhog_high_order', False)

# Slot of the libcint-style env array where eval_mat/eval_rho store the
# exponent-drop threshold passed to the C kernels.
PTR_EXPDROP = 16
# Upper bound for that threshold: env[PTR_EXPDROP] = min(precision*EXTRA_PREC, EXPDROP).
EXPDROP = getattr(__config__, 'pbc_dft_multigrid_expdrop', 1e-12)
# Tolerance for imaginary parts; not referenced in this part of the file.
IMAG_TOL = 1e-9
63
64
def eval_mat(cell, weights, shls_slice=None, comp=1, hermi=0,
             xctype='LDA', kpts=None, mesh=None, offset=None, submesh=None):
    '''Contract real-space weights with pairs of primitive GTOs of ``cell``.

    The C kernel ``libdft.NUMINT_fill2c`` fills one AO-pair matrix per
    lattice image from the grid weights.  The image matrices are then summed
    (0D / Gamma point) or combined with the phases exp(i k.L) (k-points).

    Args:
        cell : Cell whose shells are all primitive (NPRIM_OF == 1).
        weights : float64 array sampled on the grid.
            LDA: (ngrids,) or (n, ngrids).  GGA: (4, ngrids) or (n, 4, ngrids).
        shls_slice : (ish0, ish1, jsh0, jsh1); defaults to all shells.
        comp : number of components produced by the C kernel.
        hermi : if non-zero, each matrix is completed with lib.hermi_triu.
            hermi=1 is not supported for GGA.
        xctype : 'LDA' or 'GGA' (case-insensitive).
        kpts : None, a single k-point of shape (3,), or (nkpts, 3).
        mesh : full mesh; defaults to cell.mesh.
        offset, submesh : origin and extent of the part of ``mesh`` covered
            by ``weights``; default to the full mesh.

    Returns:
        One matrix, or a list of matrices when ``weights`` carries a leading
        set dimension.  Each matrix is (comp, nao_i, nao_j), with the comp
        axis dropped when comp == 1.  For non-Gamma k-points a leading nkpts
        axis is present; for a Gamma point given as a 2D kpts array a leading
        axis of length 1 is added.
    '''
    assert(all(cell._bas[:,NPRIM_OF] == 1))
    atm, bas, env = gto.conc_env(cell._atm, cell._bas, cell._env,
                                 cell._atm, cell._bas, cell._env)
    env[PTR_EXPDROP] = min(cell.precision*EXTRA_PREC, EXPDROP)
    ao_loc = gto.moleintor.make_loc(bas, 'cart')
    if shls_slice is None:
        shls_slice = (0, cell.nbas, 0, cell.nbas)
    i0, i1, j0, j1 = shls_slice
    # Ket shells are taken from the second copy of the basis made by conc_env.
    j0 += cell.nbas
    j1 += cell.nbas
    naoi = ao_loc[i1] - ao_loc[i0]
    naoj = ao_loc[j1] - ao_loc[j0]

    if cell.dimension > 0:
        Ls = numpy.asarray(cell.get_lattice_Ls(), order='C')
    else:
        Ls = numpy.zeros((1,3))
    nimgs = len(Ls)

    if mesh is None:
        mesh = cell.mesh
    weights = numpy.asarray(weights, order='C')
    assert(weights.dtype == numpy.double)
    xctype = xctype.upper()
    n_mat = None
    if xctype == 'LDA':
        if weights.ndim == 1:
            weights = weights.reshape(-1, numpy.prod(mesh))
        else:
            n_mat = weights.shape[0]
    elif xctype == 'GGA':
        if hermi == 1:
            raise RuntimeError('hermi=1 is not supported for GGA functional')
        if weights.ndim == 2:
            weights = weights.reshape(-1, 4, numpy.prod(mesh))
        else:
            n_mat = weights.shape[0]
    else:
        raise NotImplementedError

    a = cell.lattice_vectors()
    b = numpy.linalg.inv(a.T)
    if offset is None:
        offset = (0, 0, 0)
    if submesh is None:
        submesh = mesh
    # log_prec is used to estimate the gto_rcut. Add EXTRA_PREC to count
    # other possible factors and coefficients in the integral.
    log_prec = numpy.log(cell.precision * EXTRA_PREC)

    if abs(a-numpy.diag(a.diagonal())).max() < 1e-12:
        lattice_type = '_orth'
    else:
        lattice_type = '_nonorth'
    eval_fn = 'NUMINTeval_' + xctype.lower() + lattice_type
    drv = libdft.NUMINT_fill2c

    def make_mat(weights):
        # Returns the per-image matrices, shape (nimgs, comp, naoj, naoi).
        mat = numpy.zeros((nimgs,comp,naoj,naoi))
        drv(getattr(libdft, eval_fn),
            weights.ctypes.data_as(ctypes.c_void_p),
            mat.ctypes.data_as(ctypes.c_void_p),
            ctypes.c_int(comp), ctypes.c_int(hermi),
            (ctypes.c_int*4)(i0, i1, j0, j1),
            ao_loc.ctypes.data_as(ctypes.c_void_p),
            ctypes.c_double(log_prec),
            ctypes.c_int(cell.dimension),
            ctypes.c_int(nimgs),
            Ls.ctypes.data_as(ctypes.c_void_p),
            a.ctypes.data_as(ctypes.c_void_p),
            b.ctypes.data_as(ctypes.c_void_p),
            (ctypes.c_int*3)(*offset), (ctypes.c_int*3)(*submesh),
            (ctypes.c_int*3)(*mesh),
            atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(len(atm)),
            bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(len(bas)),
            env.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(len(env)))
        return mat

    out = []
    for wv in weights:
        if cell.dimension == 0:
            mat = make_mat(wv)[0].transpose(0,2,1)
            if hermi:
                for i in range(comp):
                    lib.hermi_triu(mat[i], inplace=True)
            if comp == 1:
                mat = mat[0]
        elif kpts is None or gamma_point(kpts):
            mat = make_mat(wv).sum(axis=0).transpose(0,2,1)
            if hermi:
                for i in range(comp):
                    lib.hermi_triu(mat[i], inplace=True)
            if comp == 1:
                mat = mat[0]
            if getattr(kpts, 'ndim', None) == 2:
                mat = mat[None,:]
        else:
            mat = make_mat(wv)
            # numpy.asarray also accepts kpts given as a (nested) list.
            kpts_arr = numpy.asarray(kpts).reshape(-1,3)
            expkL = numpy.exp(1j*kpts_arr.dot(Ls.T))
            mat = lib.einsum('kr,rcij->kcij', expkL, mat)
            if hermi:
                # Loop over the reshaped k-point list.  len(kpts) is 3 for a
                # single k-point passed as a (3,) vector, which would index
                # past the single k slice of mat.
                for k in range(len(kpts_arr)):
                    for i in range(comp):
                        lib.hermi_triu(mat[k,i], inplace=True)
            mat = mat.transpose(0,1,3,2)
            if comp == 1:
                mat = mat[:,0]
        out.append(mat)

    if n_mat is None:
        out = out[0]
    return out
179
def eval_rho(cell, dm, shls_slice=None, hermi=0, xctype='LDA', kpts=None,
             mesh=None, offset=None, submesh=None, ignore_imag=False,
             out=None):
    '''Collocate the *real* density (opt. gradients) on the real-space grid.

    Args:
        cell : Cell whose shells are all primitive (NPRIM_OF == 1).
        dm : density matrix block(s), last two dims (nao_i, nao_j) matching
            shls_slice.  For non-Gamma k-points the leading dims are
            reshaped to (n_dm, nkpts).
        shls_slice : (ish0, ish1, jsh0, jsh1); defaults to all shells.
        hermi : if non-zero, bra and ket shell ranges must coincide.
            hermi=1 is not supported for GGA.
        xctype : 'LDA' (density) or 'GGA' (density and 3 gradient components).
        kpts : None, a single k-point (3,) or an (nkpts, 3) array.
        mesh : full mesh; defaults to cell.mesh.
        offset, submesh : origin and extent of the sub-block of ``mesh`` to
            fill; default to the full mesh.
        ignore_imag : k-point case only; skip the imaginary part of the
            lattice-summed density matrices.
        out : optional preallocated buffers, one per density matrix, each
            holding prod(submesh)*comp elements.  Callers in this module pass
            buffers that already contain other contributions and expect the
            new density to be added to them.

    Returns:
        One array, or a list when several density matrices are given.
        Shape (ngrids,) for LDA or (4, ngrids) for GGA with
        ngrids = prod(submesh).  Complex when the density has a non-negligible
        imaginary part.
    '''
    assert(all(cell._bas[:,NPRIM_OF] == 1))
    atm, bas, env = gto.conc_env(cell._atm, cell._bas, cell._env,
                                 cell._atm, cell._bas, cell._env)
    env[PTR_EXPDROP] = min(cell.precision*EXTRA_PREC, EXPDROP)
    ao_loc = gto.moleintor.make_loc(bas, 'cart')
    if shls_slice is None:
        shls_slice = (0, cell.nbas, 0, cell.nbas)
    i0, i1, j0, j1 = shls_slice
    if hermi:
        assert(i0 == j0 and i1 == j1)
    # Ket shells are taken from the second copy of the basis made by conc_env.
    j0 += cell.nbas
    j1 += cell.nbas
    naoi = ao_loc[i1] - ao_loc[i0]
    naoj = ao_loc[j1] - ao_loc[j0]
    dm = numpy.asarray(dm, order='C')
    assert(dm.shape[-2:] == (naoi, naoj))

    if cell.dimension > 0:
        Ls = numpy.asarray(cell.get_lattice_Ls(), order='C')
    else:
        Ls = numpy.zeros((1,3))

    if cell.dimension == 0 or kpts is None or gamma_point(kpts):
        nkpts, nimgs = 1, Ls.shape[0]
        dm = dm.reshape(-1,1,naoi,naoj).transpose(0,1,3,2)
    else:
        # numpy.asarray also accepts kpts given as a (nested) list.
        expkL = numpy.exp(1j*numpy.asarray(kpts).reshape(-1,3).dot(Ls.T))
        nkpts, nimgs = expkL.shape
        dm = dm.reshape(-1,nkpts,naoi,naoj).transpose(0,1,3,2)
    n_dm = dm.shape[0]

    a = cell.lattice_vectors()
    b = numpy.linalg.inv(a.T)
    if mesh is None:
        mesh = cell.mesh
    if offset is None:
        offset = (0, 0, 0)
    if submesh is None:
        submesh = mesh
    log_prec = numpy.log(cell.precision * EXTRA_PREC)

    if abs(a-numpy.diag(a.diagonal())).max() < 1e-12:
        lattice_type = '_orth'
    else:
        lattice_type = '_nonorth'
    xctype = xctype.upper()
    if xctype == 'LDA':
        comp = 1
    elif xctype == 'GGA':
        if hermi == 1:
            raise RuntimeError('hermi=1 is not supported for GGA functional')
        comp = 4
    else:
        raise NotImplementedError('meta-GGA')
    if comp == 1:
        shape = (numpy.prod(submesh),)
    else:
        shape = (comp, numpy.prod(submesh))
    eval_fn = 'NUMINTrho_' + xctype.lower() + lattice_type
    drv = libdft.NUMINT_rho_drv

    def make_rho_(rho, dm, hermi):
        drv(getattr(libdft, eval_fn),
            rho.ctypes.data_as(ctypes.c_void_p),
            dm.ctypes.data_as(ctypes.c_void_p),
            ctypes.c_int(comp), ctypes.c_int(hermi),
            (ctypes.c_int*4)(i0, i1, j0, j1),
            ao_loc.ctypes.data_as(ctypes.c_void_p),
            ctypes.c_double(log_prec),
            ctypes.c_int(cell.dimension),
            ctypes.c_int(nimgs),
            Ls.ctypes.data_as(ctypes.c_void_p),
            a.ctypes.data_as(ctypes.c_void_p),
            b.ctypes.data_as(ctypes.c_void_p),
            (ctypes.c_int*3)(*offset), (ctypes.c_int*3)(*submesh),
            (ctypes.c_int*3)(*mesh),
            atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(len(atm)),
            bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(len(bas)),
            env.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(len(env)))
        return rho

    rho = []
    for i, dm_i in enumerate(dm):
        if out is None:
            rho_i = numpy.zeros(shape)
        else:
            rho_i = out[i]
            assert(rho_i.size == numpy.prod(shape))

        if cell.dimension == 0:
            # make a copy because the dm may be overwritten in the
            # NUMINT_rho_drv inplace
            make_rho_(rho_i, numpy.array(dm_i, order='C', copy=True), hermi)
        elif kpts is None or gamma_point(kpts):
            make_rho_(rho_i, numpy.repeat(dm_i, nimgs, axis=0), hermi)
        else:
            # Lattice-summed density matrices, one per image.
            dm_L = lib.dot(expkL.T, dm_i.reshape(nkpts,-1)).reshape(nimgs,naoj,naoi)
            dmR = numpy.asarray(dm_L.real, order='C')

            if ignore_imag:
                has_imag = False
            else:
                dmI = numpy.asarray(dm_L.imag, order='C')
                has_imag = (hermi == 0 and abs(dmI).max() > 1e-8)
                if (has_imag and xctype == 'LDA' and
                    naoi == naoj and
                    # For hermitian density matrices, the anti-symmetry
                    # character of the imaginary part of the density matrices
                    # can be found by rearranging the repeated images.
                    abs(dm_i - dm_i.conj().transpose(0,2,1)).max() < 1e-8):
                    has_imag = False
            dm_L = None

            if has_imag:
                rho_cplx = make_rho_(numpy.zeros(shape), dmI, 0)*1j
                rho_cplx += make_rho_(numpy.zeros(shape), dmR, 0)
                if out is None:
                    rho_i = rho_cplx
                else:
                    # Add to the caller's buffer instead of overwriting it:
                    # callers (e.g. _eval_rho_bra) pass a buffer already
                    # holding contributions from other shells.  out[i] may
                    # be (comp, nx, ny, nz); reshape gives a flat view of
                    # the C-contiguous buffer the C driver also requires.
                    assert(numpy.iscomplexobj(rho_i))
                    rho_i_view = rho_i.reshape(shape)
                    rho_i_view += rho_cplx
            else:
                assert(rho_i.dtype == numpy.double)
                make_rho_(rho_i, dmR, hermi)
            dmR = dmI = None

        rho.append(rho_i)

    if n_dm == 1:
        rho = rho[0]
    return rho
315
def get_nuc(mydf, kpts=None):
    '''AO matrix of the Coulomb potential of the nuclear point charges.

    The potential is built in reciprocal space on mydf.mesh from the
    negated atomic charges times the structure factors and the Coulomb
    kernel, then turned into AO matrices by _get_j_pass2.

    Args:
        kpts : None, a single k-point (3,) or an (nkpts, 3) array.

    Returns:
        A single matrix when kpts is None or has shape (3,), otherwise one
        matrix per k-point.
    '''
    cell = mydf.cell
    kpts_lst = numpy.zeros((1,3)) if kpts is None else numpy.reshape(kpts, (-1,3))

    mesh = mydf.mesh
    Gv = cell.get_Gv(mesh)
    # Fourier components of the (negated) nuclear charge distribution.
    rhoG = numpy.dot(-cell.atom_charges(), cell.get_SI(Gv))
    vneG = rhoG * tools.get_coulG(cell, mesh=mesh, Gv=Gv)
    vne = _get_j_pass2(mydf, vneG, kpts_lst)[0]

    single_kpt = kpts is None or numpy.shape(kpts) == (3,)
    if single_kpt:
        vne = vne[0]
    return numpy.asarray(vne)
336
def get_pp(mydf, kpts=None):
    '''Get the periodic pseudopotential nuc-el AO matrix, with G=0 removed.

    The local part is built in reciprocal space on mydf.mesh (its G=0
    element is replaced by the sum of pseudo.get_alphas) and turned into AO
    matrices by _get_j_pass2.  The nonlocal projector part is evaluated in
    reciprocal space for each k-point and added on top.

    Args:
        kpts : None, a single k-point (3,) or an (nkpts, 3) array.

    Returns:
        A single matrix when kpts is None or has shape (3,), otherwise one
        matrix per k-point.  Matrices at the Gamma point are real.
    '''
    from pyscf import gto
    cell = mydf.cell
    if kpts is None:
        kpts_lst = numpy.zeros((1,3))
    else:
        kpts_lst = numpy.reshape(kpts, (-1,3))

    mesh = mydf.mesh
    Gv = cell.get_Gv(mesh)
    # The structure factors must live on the same G-vectors as vpplocG
    # (mydf.mesh).  cell.get_SI() without arguments uses cell.mesh, which
    # gives mismatched shapes whenever mydf.mesh != cell.mesh.
    SI = cell.get_SI(Gv)
    vpplocG = pseudo.get_vlocG(cell, Gv)
    vpplocG = -numpy.einsum('ij,ij->j', SI, vpplocG)
    # from get_jvloc_G0 function
    vpplocG[0] = numpy.sum(pseudo.get_alphas(cell))
    ngrids = len(vpplocG)

    vpp = _get_j_pass2(mydf, vpplocG, kpts_lst)[0]

    # vppnonloc evaluated in reciprocal space
    fakemol = gto.Mole()
    fakemol._atm = numpy.zeros((1,gto.ATM_SLOTS), dtype=numpy.int32)
    fakemol._bas = numpy.zeros((1,gto.BAS_SLOTS), dtype=numpy.int32)
    ptr = gto.PTR_ENV_START
    fakemol._env = numpy.zeros(ptr+10)
    fakemol._bas[0,gto.NPRIM_OF ] = 1
    fakemol._bas[0,gto.NCTR_OF  ] = 1
    fakemol._bas[0,gto.PTR_EXP  ] = ptr+3
    fakemol._bas[0,gto.PTR_COEFF] = ptr+4

    # buf for SPG_lmi upto l=0..3 and nl=3
    buf = numpy.empty((48,ngrids), dtype=numpy.complex128)

    def vppnl_by_k(kpt):
        Gk = Gv + kpt
        G_rad = lib.norm(Gk, axis=1)
        aokG = ft_ao.ft_ao(cell, Gv, kpt=kpt) * (ngrids/cell.vol)
        vppnl = 0
        for ia in range(cell.natm):
            symb = cell.atom_symbol(ia)
            if symb not in cell._pseudo:
                continue
            pp = cell._pseudo[symb]
            p1 = 0
            for l, proj in enumerate(pp[5:]):
                rl, nl, hl = proj
                if nl > 0:
                    fakemol._bas[0,gto.ANG_OF] = l
                    fakemol._env[ptr+3] = .5*rl**2
                    fakemol._env[ptr+4] = rl**(l+1.5)*numpy.pi**1.25
                    pYlm_part = fakemol.eval_gto('GTOval', Gk)

                    p0, p1 = p1, p1+nl*(l*2+1)
                    # pYlm is real, SI[ia] is complex
                    pYlm = numpy.ndarray((nl,l*2+1,ngrids), dtype=numpy.complex128, buffer=buf[p0:p1])
                    for k in range(nl):
                        qkl = pseudo.pp._qli(G_rad*rl, l, k)
                        pYlm[k] = pYlm_part.T * qkl
                    #:SPG_lmi = numpy.einsum('g,nmg->nmg', SI[ia].conj(), pYlm)
                    #:SPG_lm_aoG = numpy.einsum('nmg,gp->nmp', SPG_lmi, aokG)
                    #:tmp = numpy.einsum('ij,jmp->imp', hl, SPG_lm_aoG)
                    #:vppnl += numpy.einsum('imp,imq->pq', SPG_lm_aoG.conj(), tmp)
            if p1 > 0:
                SPG_lmi = buf[:p1]
                SPG_lmi *= SI[ia].conj()
                SPG_lm_aoGs = lib.zdot(SPG_lmi, aokG)
                p1 = 0
                for l, proj in enumerate(pp[5:]):
                    rl, nl, hl = proj
                    if nl > 0:
                        p0, p1 = p1, p1+nl*(l*2+1)
                        hl = numpy.asarray(hl)
                        SPG_lm_aoG = SPG_lm_aoGs[p0:p1].reshape(nl,l*2+1,-1)
                        tmp = numpy.einsum('ij,jmp->imp', hl, SPG_lm_aoG)
                        vppnl += numpy.einsum('imp,imq->pq', SPG_lm_aoG.conj(), tmp)
        return vppnl * (1./ngrids**2)

    for k, kpt in enumerate(kpts_lst):
        vppnl = vppnl_by_k(kpt)
        if gamma_point(kpt):
            vpp[k] = vpp[k].real + vppnl.real
        else:
            vpp[k] += vppnl

    if kpts is None or numpy.shape(kpts) == (3,):
        vpp = vpp[0]
    return numpy.asarray(vpp)
426
427
def get_j_kpts(mydf, dm_kpts, hermi=1, kpts=numpy.zeros((1,3)), kpts_band=None):
    '''Get the Coulomb (J) AO matrix at sampled k-points.

    Args:
        dm_kpts : (nkpts, nao, nao) ndarray or a list of (nkpts,nao,nao) ndarray
            Density matrix at each k-point.  If a list of k-point DMs, eg,
            UHF alpha and beta DM, the alpha and beta DMs are contracted
            separately.
        kpts : (nkpts, 3) ndarray
            (The default array is only read, never modified.)

    Kwargs:
        kpts_band : (3,) ndarray or (*,3) ndarray
            A list of arbitrary "band" k-points at which to evaluate the matrix.

    Returns:
        vj : (nkpts, nao, nao) ndarray
        or list of vj if the input dm_kpts is a list of DMs
    '''
    cell = mydf.cell
    dm_kpts = numpy.asarray(dm_kpts)
    rhoG = _eval_rhoG(mydf, dm_kpts, hermi, kpts, deriv=0)
    # rhoG is built on mydf.mesh, and _get_j_pass2 reshapes vG to mydf.mesh,
    # so the Coulomb kernel must use the same mesh (cell.mesh may differ).
    coulG = tools.get_coulG(cell, mesh=mydf.mesh)
    #:vG = numpy.einsum('ng,g->ng', rhoG[:,0], coulG)
    vG = rhoG[:,0]
    vG *= coulG

    kpts_band, input_band = _format_kpts_band(kpts_band, kpts), kpts_band
    vj_kpts = _get_j_pass2(mydf, vG, kpts_band)
    return _format_jks(vj_kpts, dm_kpts, input_band, kpts)
457
458
def _eval_rhoG(mydf, dm_kpts, hermi=1, kpts=numpy.zeros((1,3)), deriv=0,
               rhog_high_order=RHOG_HIGH_ORDER):
    '''Electron density (optionally with its gradient) in reciprocal space.

    Loops over the multigrid tasks (created by multi_grids_tasks on first use
    and cached on mydf.tasks).  For each (grids_dense, grids_sparse) pair the
    real-space density is collocated on grids_dense.mesh, FFT'd, scaled by
    cell.vol/(nkpts*ngrids) and scattered into the full mydf.mesh.

    Args:
        dm_kpts : density matrices, brought to (nset, nkpts, nao, nao) by
            _format_dms.
        hermi : 1 if the density matrices are hermitian; the imaginary part
            of the real-space density is then discarded (ignore_imag).
        kpts : (nkpts, 3) k-points.  The default array is only read.
        deriv : 0 for the density only, 1 to include the gradient.
            deriv >= 2 fails the assertion below.
        rhog_high_order : for deriv == 1, True collocates the gradient in
            real space (GGA kernels); False approximates it in reciprocal
            space as i*G*rho(G).

    Returns:
        rhoG : complex array of shape (nset, rhodim, prod(mydf.mesh)) with
            rhodim = 1 for deriv == 0 and 4 for deriv == 1.
    '''
    log = logger.Logger(mydf.stdout, mydf.verbose)
    cell = mydf.cell

    dm_kpts = lib.asarray(dm_kpts, order='C')
    dms = _format_dms(dm_kpts, kpts)
    nset, nkpts, nao = dms.shape[:3]

    # Build the multigrid tasks once and cache them on mydf.
    tasks = getattr(mydf, 'tasks', None)
    if tasks is None:
        mydf.tasks = tasks = multi_grids_tasks(cell, mydf.mesh, log)
        log.debug('Multigrid ntasks %s', len(tasks))

    assert(deriv < 2)
    #hermi = hermi and abs(dms - dms.transpose(0,1,3,2).conj()).max() < 1e-9
    gga_high_order = False
    if deriv == 0:
        xctype = 'LDA'
        rhodim = 1

    elif deriv == 1:
        if rhog_high_order:
            xctype = 'GGA'
            rhodim = 4
        else:  # approximate high order derivatives in reciprocal space
            # Collocate only the density here; the gradient is added after
            # the task loop as i*G*rho(G).
            gga_high_order = True
            xctype = 'LDA'
            rhodim = 1
            deriv = 0
        assert(hermi == 1 or gamma_point(kpts))

    # NOTE(review): this branch is unreachable because assert(deriv < 2)
    # above already rejects deriv == 2; the assert after raise is dead code.
    elif deriv == 2:  # meta-GGA
        raise NotImplementedError
        assert(hermi == 1 or gamma_point(kpts))

    ignore_imag = (hermi == 1)

    nx, ny, nz = mydf.mesh
    rhoG = numpy.zeros((nset*rhodim,nx,ny,nz), dtype=numpy.complex128)
    for grids_dense, grids_sparse in tasks:
        h_cell = grids_dense.cell
        mesh = tuple(grids_dense.mesh)
        ngrids = numpy.prod(mesh)
        log.debug('mesh %s  rcut %g', mesh, h_cell.rcut)

        if grids_sparse is None:
            # The first pass handles all diffused functions using the regular
            # matrix multiplication code.
            rho = numpy.zeros((nset,rhodim,ngrids), dtype=numpy.complex128)
            idx_h = grids_dense.ao_idx
            dms_hh = numpy.asarray(dms[:,:,idx_h[:,None],idx_h], order='C')
            for ao_h_etc, p0, p1 in mydf.aoR_loop(grids_dense, kpts, deriv):
                ao_h, mask = ao_h_etc[0], ao_h_etc[2]
                for k in range(nkpts):
                    for i in range(nset):
                        if xctype == 'LDA':
                            ao_dm = lib.dot(ao_h[k], dms_hh[i,k])
                            rho_sub = numpy.einsum('xi,xi->x', ao_dm, ao_h[k].conj())
                        else:
                            rho_sub = numint.eval_rho(h_cell, ao_h[k], dms_hh[i,k],
                                                      mask, xctype, hermi)
                        rho[i,:,p0:p1] += rho_sub
                ao_h = ao_h_etc = ao_dm = None
            if ignore_imag:
                rho = rho.real
        else:
            # Compact (h) shells against all (t = h + l) shells; the l-l
            # block is not part of this task.
            idx_h = grids_dense.ao_idx
            idx_l = grids_sparse.ao_idx
            idx_t = numpy.append(idx_h, idx_l)
            dms_ht = numpy.asarray(dms[:,:,idx_h[:,None],idx_t], order='C')
            dms_lh = numpy.asarray(dms[:,:,idx_l[:,None],idx_h], order='C')

            t_cell = h_cell + grids_sparse.cell
            nshells_h = _pgto_shells(h_cell)
            nshells_t = _pgto_shells(t_cell)
            t_cell, t_coeff = t_cell.to_uncontracted_cartesian_basis()

            if deriv == 0:
                # Transform the density matrices to the uncontracted
                # Cartesian primitive basis of t_cell.
                h_coeff = scipy.linalg.block_diag(*t_coeff[:h_cell.nbas])
                l_coeff = scipy.linalg.block_diag(*t_coeff[h_cell.nbas:])
                t_coeff = scipy.linalg.block_diag(*t_coeff)

                if hermi == 1:
                    # Fold the l-h block onto the h-l block so that a single
                    # bra pass over the h shells covers both.
                    naol, naoh = dms_lh.shape[2:]
                    dms_ht[:,:,:,naoh:] += dms_lh.transpose(0,1,3,2)
                    pgto_dms = lib.einsum('nkij,pi,qj->nkpq', dms_ht, h_coeff, t_coeff)
                    shls_slice = (0, nshells_h, 0, nshells_t)
                    #:rho = eval_rho(t_cell, pgto_dms, shls_slice, 0, 'LDA', kpts,
                    #:               offset=None, submesh=None, ignore_imag=True)
                    rho = _eval_rho_bra(t_cell, pgto_dms, shls_slice, 0,
                                        'LDA', kpts, grids_dense, True, log)

                else:
                    pgto_dms = lib.einsum('nkij,pi,qj->nkpq', dms_ht, h_coeff, t_coeff)
                    shls_slice = (0, nshells_h, 0, nshells_t)
                    #:rho = eval_rho(t_cell, pgto_dms, shls_slice, 0, 'LDA', kpts,
                    #:               offset=None, submesh=None)
                    rho = _eval_rho_bra(t_cell, pgto_dms, shls_slice, 0,
                                        'LDA', kpts, grids_dense, True, log)
                    pgto_dms = lib.einsum('nkij,pi,qj->nkpq', dms_lh, l_coeff, h_coeff)
                    shls_slice = (nshells_h, nshells_t, 0, nshells_h)
                    #:rho += eval_rho(t_cell, pgto_dms, shls_slice, 0, 'LDA', kpts,
                    #:                offset=None, submesh=None)
                    rho += _eval_rho_ket(t_cell, pgto_dms, shls_slice, 0,
                                         'LDA', kpts, grids_dense, True, log)

            elif deriv == 1:
                h_coeff = scipy.linalg.block_diag(*t_coeff[:h_cell.nbas])
                l_coeff = scipy.linalg.block_diag(*t_coeff[h_cell.nbas:])
                t_coeff = scipy.linalg.block_diag(*t_coeff)

                pgto_dms = lib.einsum('nkij,pi,qj->nkpq', dms_ht, h_coeff, t_coeff)
                shls_slice = (0, nshells_h, 0, nshells_t)
                #:rho = eval_rho(t_cell, pgto_dms, shls_slice, 0, 'GGA', kpts,
                #:               ignore_imag=ignore_imag)
                rho = _eval_rho_bra(t_cell, pgto_dms, shls_slice, 0, 'GGA',
                                    kpts, grids_dense, ignore_imag, log)

                pgto_dms = lib.einsum('nkij,pi,qj->nkpq', dms_lh, l_coeff, h_coeff)
                shls_slice = (nshells_h, nshells_t, 0, nshells_h)
                #:rho += eval_rho(t_cell, pgto_dms, shls_slice, 0, 'GGA', kpts,
                #:                ignore_imag=ignore_imag)
                rho += _eval_rho_ket(t_cell, pgto_dms, shls_slice, 0, 'GGA',
                                     kpts, grids_dense, ignore_imag, log)
                if hermi == 1:
                    # \nabla \chi_i DM(i,j) \chi_j was computed above.
                    # *2 for \chi_i DM(i,j) \nabla \chi_j
                    rho[:,1:4] *= 2
                else:
                    raise NotImplementedError

        # FFT the density of this task and scatter it into the full mesh;
        # fftfreq gives the (possibly negative) G indices, which wrap around
        # when used as indices into rhoG.
        weight = 1./nkpts * cell.vol/ngrids
        rho_freq = tools.fft(rho.reshape(nset*rhodim, -1), mesh)
        rho_freq *= weight
        gx = numpy.fft.fftfreq(mesh[0], 1./mesh[0]).astype(numpy.int32)
        gy = numpy.fft.fftfreq(mesh[1], 1./mesh[1]).astype(numpy.int32)
        gz = numpy.fft.fftfreq(mesh[2], 1./mesh[2]).astype(numpy.int32)
        #:rhoG[:,gx[:,None,None],gy[:,None],gz] += rho_freq.reshape((-1,)+mesh)
        _takebak_4d(rhoG, rho_freq.reshape((-1,) + mesh), (None, gx, gy, gz))

    rhoG = rhoG.reshape(nset,rhodim,-1)

    if gga_high_order:
        # Gradient in reciprocal space: FT[nabla rho](G) = i G rho(G).
        Gv = cell.get_Gv(mydf.mesh)
        rhoG1 = numpy.einsum('np,px->nxp', 1j*rhoG[:,0], Gv)
        rhoG = numpy.concatenate([rhoG, rhoG1], axis=1)
    return rhoG
607
608
def _eval_rho_bra(cell, dms, shls_slice, hermi, xctype, kpts, grids,
                  ignore_imag, log):
    '''Collocate the density of the bra shells of shls_slice on grids.mesh.

    If grids.cell.rcut is large compared with the lattice
    (rcut > R_RATIO_SUBLOOP * max(a)), eval_rho is called once on the full
    mesh.  Otherwise the bra shells are grouped by atom.  An atom whose
    [-rcut, +rcut] box, in fractional coordinates, lies strictly inside the
    unit cell is collocated on a local sub-mesh only.  The remaining atoms
    are collocated together on the full mesh.

    Args:
        cell : Cell whose bra shells each contribute (l+1)(l+2)/2 Cartesian
            AOs (the uncontracted Cartesian basis built in _eval_rhoG).
        dms : (nset, nkpts, nao_bra, nao_ket) density matrices; axis 2
            follows the AO layout of shells ish0:ish1.
        shls_slice : (ish0, ish1, jsh0, jsh1) shell ranges in ``cell``.
        hermi, xctype, kpts, ignore_imag : forwarded to eval_rho.
        grids : provides the full mesh and the cutoff radius (grids.cell.rcut).
        log : logger for debug output.

    Returns:
        rho : (nset, rhodim, prod(grids.mesh)) with rhodim 1 (LDA) or 4 (GGA).
    '''
    a = cell.lattice_vectors()
    rmax = a.max()
    mesh = numpy.asarray(grids.mesh)
    rcut = grids.cell.rcut
    nset = dms.shape[0]
    if xctype == 'LDA':
        rhodim = 1
    else:
        rhodim = 4

    if rcut > rmax * R_RATIO_SUBLOOP:
        rho = eval_rho(cell, dms, shls_slice, hermi, xctype, kpts,
                       mesh, ignore_imag=ignore_imag)
        return numpy.reshape(rho, (nset, rhodim, numpy.prod(mesh)))

    if hermi == 1 or ignore_imag:
        rho = numpy.zeros((nset, rhodim) + tuple(mesh))
    else:
        rho = numpy.zeros((nset, rhodim) + tuple(mesh), dtype=numpy.complex128)

    b = numpy.linalg.inv(a.T)
    ish0, ish1, jsh0, jsh1 = shls_slice
    nshells_j = jsh1 - jsh0
    bas_i = cell._bas[ish0:ish1]
    # Cartesian AO offsets of the bra shells, relative to ish0.  Per-atom AO
    # indices are taken from here instead of being accumulated in the
    # iteration order of the atoms, which need not match the shell order.
    l_i = bas_i[:,ANG_OF]
    ao_loc_i = numpy.append(0, numpy.cumsum((l_i+1)*(l_i+2)//2))
    pcell = copy.copy(cell)
    rest_dms = []
    rest_bas = []
    # numpy.unique gives a deterministic (sorted) order; iterating a set()
    # of atom ids does not guarantee any particular order.
    for atm_id in numpy.unique(bas_i[:,ATOM_OF]):
        atm_bas_idx = numpy.where(bas_i[:,ATOM_OF] == atm_id)[0]
        _bas_i = bas_i[atm_bas_idx]
        ao_idx = numpy.hstack([numpy.arange(ao_loc_i[s], ao_loc_i[s+1])
                               for s in atm_bas_idx])
        sub_dms = dms[:,:,ao_idx]

        atom_position = cell.atom_coord(atm_id)
        frac_edge0 = b.dot(atom_position - rcut)
        frac_edge1 = b.dot(atom_position + rcut)

        if (numpy.all(0 < frac_edge0) and numpy.all(frac_edge1 < 1)):
            # The atom's cutoff box fits inside the cell: collocate on a
            # sub-mesh and add it into the corresponding block of rho.
            pcell._bas = numpy.vstack((_bas_i, cell._bas[jsh0:jsh1]))
            nshells_i = len(atm_bas_idx)
            sub_slice = (0, nshells_i, nshells_i, nshells_i+nshells_j)

            offset = (frac_edge0 * mesh).astype(int)
            mesh1 = numpy.ceil(frac_edge1 * mesh).astype(int)
            submesh = mesh1 - offset
            log.debug1('atm %d  rcut %f  offset %s submesh %s',
                       atm_id, rcut, offset, submesh)
            rho1 = eval_rho(pcell, sub_dms, sub_slice, hermi, xctype, kpts,
                            mesh, offset, submesh, ignore_imag=ignore_imag)
            #:rho[:,:,offset[0]:mesh1[0],offset[1]:mesh1[1],offset[2]:mesh1[2]] += \
            #:        numpy.reshape(rho1, (nset, rhodim) + tuple(submesh))
            gx = numpy.arange(offset[0], mesh1[0], dtype=numpy.int32)
            gy = numpy.arange(offset[1], mesh1[1], dtype=numpy.int32)
            gz = numpy.arange(offset[2], mesh1[2], dtype=numpy.int32)
            _takebak_5d(rho, numpy.reshape(rho1, (nset,rhodim)+tuple(submesh)),
                        (None, None, gx, gy, gz))
        else:
            # The box crosses the cell boundary; defer to one full-mesh pass.
            log.debug1('atm %d  rcut %f  over 2 images', atm_id, rcut)
            rest_bas.append(_bas_i)
            rest_dms.append(sub_dms)
    if rest_bas:
        pcell._bas = numpy.vstack(rest_bas + [cell._bas[jsh0:jsh1]])
        nshells_i = sum(len(x) for x in rest_bas)
        sub_slice = (0, nshells_i, nshells_i, nshells_i+nshells_j)
        sub_dms = numpy.concatenate(rest_dms, axis=2)
        eval_rho(pcell, sub_dms, sub_slice, hermi, xctype, kpts,
                 mesh, ignore_imag=ignore_imag, out=rho)
    return rho.reshape((nset, rhodim, numpy.prod(mesh)))
686
def _eval_rho_ket(cell, dms, shls_slice, hermi, xctype, kpts, grids,
                  ignore_imag, log):
    '''Collocate the density of the ket shells of shls_slice on grids.mesh.

    Mirror of _eval_rho_bra with the per-atom splitting applied to the ket
    shells.  If grids.cell.rcut is large compared with the lattice
    (rcut > R_RATIO_SUBLOOP * max(a)), eval_rho is called once on the full
    mesh.  Otherwise ket atoms whose [-rcut, +rcut] box, in fractional
    coordinates, lies strictly inside the unit cell are collocated on local
    sub-meshes.  The remaining ket atoms are collocated together on the
    full mesh.

    Args:
        cell : Cell whose ket shells each contribute (l+1)(l+2)/2 Cartesian
            AOs (the uncontracted Cartesian basis built in _eval_rhoG).
        dms : (nset, nkpts, nao_bra, nao_ket) density matrices; axis 3
            follows the AO layout of shells jsh0:jsh1.
        shls_slice : (ish0, ish1, jsh0, jsh1) shell ranges in ``cell``.
        hermi, xctype, kpts, ignore_imag : forwarded to eval_rho.
        grids : provides the full mesh and the cutoff radius (grids.cell.rcut).
        log : logger for debug output.

    Returns:
        rho : (nset, rhodim, prod(grids.mesh)) with rhodim 1 (LDA) or 4 (GGA).
    '''
    a = cell.lattice_vectors()
    rmax = a.max()
    mesh = numpy.asarray(grids.mesh)
    rcut = grids.cell.rcut
    nset = dms.shape[0]
    if xctype == 'LDA':
        rhodim = 1
    else:
        rhodim = 4

    if rcut > rmax * R_RATIO_SUBLOOP:
        rho = eval_rho(cell, dms, shls_slice, hermi, xctype, kpts,
                       mesh, ignore_imag=ignore_imag)
        return numpy.reshape(rho, (nset, rhodim, numpy.prod(mesh)))

    if hermi == 1 or ignore_imag:
        rho = numpy.zeros((nset, rhodim) + tuple(mesh))
    else:
        rho = numpy.zeros((nset, rhodim) + tuple(mesh), dtype=numpy.complex128)

    b = numpy.linalg.inv(a.T)
    ish0, ish1, jsh0, jsh1 = shls_slice
    nshells_i = ish1 - ish0
    bas_j = cell._bas[jsh0:jsh1]
    # Cartesian AO offsets of the ket shells, relative to jsh0.  Per-atom AO
    # indices are taken from here instead of being accumulated in the
    # iteration order of the atoms, which need not match the shell order.
    l_j = bas_j[:,ANG_OF]
    ao_loc_j = numpy.append(0, numpy.cumsum((l_j+1)*(l_j+2)//2))
    pcell = copy.copy(cell)
    rest_dms = []
    rest_bas = []
    # numpy.unique gives a deterministic (sorted) order; iterating a set()
    # of atom ids does not guarantee any particular order.
    for atm_id in numpy.unique(bas_j[:,ATOM_OF]):
        atm_bas_idx = numpy.where(bas_j[:,ATOM_OF] == atm_id)[0]
        _bas_j = bas_j[atm_bas_idx]
        ao_idx = numpy.hstack([numpy.arange(ao_loc_j[s], ao_loc_j[s+1])
                               for s in atm_bas_idx])
        sub_dms = dms[:,:,:,ao_idx]

        atom_position = cell.atom_coord(atm_id)
        frac_edge0 = b.dot(atom_position - rcut)
        frac_edge1 = b.dot(atom_position + rcut)

        if (numpy.all(0 < frac_edge0) and numpy.all(frac_edge1 < 1)):
            # The atom's cutoff box fits inside the cell: collocate on a
            # sub-mesh and add it into the corresponding block of rho.
            pcell._bas = numpy.vstack((cell._bas[ish0:ish1], _bas_j))
            nshells_j = len(atm_bas_idx)
            sub_slice = (0, nshells_i, nshells_i, nshells_i+nshells_j)

            offset = (frac_edge0 * mesh).astype(int)
            mesh1 = numpy.ceil(frac_edge1 * mesh).astype(int)
            submesh = mesh1 - offset
            log.debug1('atm %d  rcut %f  offset %s submesh %s',
                       atm_id, rcut, offset, submesh)
            rho1 = eval_rho(pcell, sub_dms, sub_slice, hermi, xctype, kpts,
                            mesh, offset, submesh, ignore_imag=ignore_imag)
            #:rho[:,:,offset[0]:mesh1[0],offset[1]:mesh1[1],offset[2]:mesh1[2]] += \
            #:        numpy.reshape(rho1, (nset, rhodim) + tuple(submesh))
            gx = numpy.arange(offset[0], mesh1[0], dtype=numpy.int32)
            gy = numpy.arange(offset[1], mesh1[1], dtype=numpy.int32)
            gz = numpy.arange(offset[2], mesh1[2], dtype=numpy.int32)
            _takebak_5d(rho, numpy.reshape(rho1, (nset,rhodim)+tuple(submesh)),
                        (None, None, gx, gy, gz))
        else:
            # The box crosses the cell boundary; defer to one full-mesh pass.
            log.debug1('atm %d  rcut %f  over 2 images', atm_id, rcut)
            rest_bas.append(_bas_j)
            rest_dms.append(sub_dms)
    if rest_bas:
        pcell._bas = numpy.vstack([cell._bas[ish0:ish1]] + rest_bas)
        nshells_j = sum(len(x) for x in rest_bas)
        sub_slice = (0, nshells_i, nshells_i, nshells_i+nshells_j)
        sub_dms = numpy.concatenate(rest_dms, axis=3)
        eval_rho(pcell, sub_dms, sub_slice, hermi, xctype, kpts,
                 mesh, ignore_imag=ignore_imag, out=rho)
    return rho.reshape((nset, rhodim, numpy.prod(mesh)))
763
764
def _get_j_pass2(mydf, vG, kpts=numpy.zeros((1,3)), verbose=None):
    '''Integrate a scalar potential given on the FFT mesh against AO pairs.

    For every multigrid task, the plane waves of ``vG`` that fit on the
    task's mesh are gathered, transformed to real space, and contracted with
    the AOs owned by that task.  The task list is built and cached on
    ``mydf.tasks`` when it is not available yet.

    Args:
        mydf : multigrid DF object (provides ``cell``, ``mesh``, ``aoR_loop``).
        vG : potential in reciprocal space on ``mydf.mesh``; reshaped to
            (nset, nx, ny, nz).
        kpts : (nkpts, 3) ndarray of k-points.

    Returns:
        (nset, nkpts, nao, nao) ndarray, real when ``kpts`` is the Gamma
        point and complex otherwise.
    '''
    log = logger.new_logger(mydf, verbose)
    cell = mydf.cell
    nx, ny, nz = mydf.mesh
    vG = vG.reshape(-1, nx, ny, nz)
    nset = vG.shape[0]
    nkpts = len(kpts)
    nao = cell.nao_nr()

    tasks = getattr(mydf, 'tasks', None)
    if tasks is None:
        tasks = multi_grids_tasks(cell, mydf.mesh, log)
        mydf.tasks = tasks
        log.debug('Multigrid ntasks %s', len(tasks))

    at_gamma_point = gamma_point(kpts)
    mat_dtype = numpy.double if at_gamma_point else numpy.complex128
    vj_kpts = numpy.zeros((nset, nkpts, nao, nao), dtype=mat_dtype)

    for grids_dense, grids_sparse in tasks:
        mesh = grids_dense.mesh
        ngrids = numpy.prod(mesh)
        log.debug('mesh %s', mesh)

        # Integer plane-wave indices of this task's mesh, used to pick the
        # matching entries out of the full-mesh vG.
        gx, gy, gz = [numpy.fft.fftfreq(n, 1./n).astype(numpy.int32)
                      for n in mesh]
        sub_vG = _take_4d(vG, (None, gx, gy, gz)).reshape(nset, ngrids)

        v_rs = tools.ifft(sub_vG, mesh).reshape(nset, ngrids)
        vR = numpy.asarray(v_rs.real, order='C')
        vI = numpy.asarray(v_rs.imag, order='C')
        if at_gamma_point:
            v_rs = vR

        idx_h = grids_dense.ao_idx
        if grids_sparse is None:
            # All AOs of this task are handled on the dense mesh: contract
            # AO values on the grid directly.
            for ao_h_etc, p0, p1 in mydf.aoR_loop(grids_dense, kpts):
                ao_h = ao_h_etc[0]
                for k in range(nkpts):
                    ao_k = ao_h[k]
                    for i in range(nset):
                        vj_kpts[i,k,idx_h[:,None],idx_h] += \
                                lib.dot(ao_k.conj().T*v_rs[i,p0:p1], ao_k)
                ao_h = ao_h_etc = None
            continue

        idx_l = grids_sparse.ao_idx
        naoh = len(idx_h)

        # Dense and sparse shells are merged into one cell of uncontracted
        # Cartesian primitives; eval_mat works in that primitive basis and
        # the block-diagonal coefficients contract the result back.
        h_cell = grids_dense.cell
        t_cell = h_cell + grids_sparse.cell
        t_cell, t_coeff = t_cell.to_uncontracted_cartesian_basis()
        nshells_h = _pgto_shells(h_cell)
        nshells_t = _pgto_shells(t_cell)

        h_coeff = scipy.linalg.block_diag(*t_coeff[:h_cell.nbas])
        t_coeff = scipy.linalg.block_diag(*t_coeff)
        shls_slice = (0, nshells_h, 0, nshells_t)
        vp = eval_mat(t_cell, vR, shls_slice, 1, 0, 'LDA', kpts)
        vp = lib.einsum('nkpq,pi,qj->nkij', vp, h_coeff, t_coeff)

        # The imaginary part of the real-space potential may contribute
        # away from the Gamma point.
        if not at_gamma_point and abs(vI).max() > IMAG_TOL:
            vpI = eval_mat(t_cell, vI, shls_slice, 1, 0, 'LDA', kpts)
            vpI = lib.einsum('nkpq,pi,qj->nkij', vpI, h_coeff, t_coeff)
            vp = vp + vpI * 1j
            vpI = None

        vp_hh = vp[:,:,:,:naoh]
        vp_hl = vp[:,:,:,naoh:]
        vj_kpts[:,:,idx_h[:,None],idx_h] += vp_hh
        vj_kpts[:,:,idx_h[:,None],idx_l] += vp_hl
        # The (sparse, dense) block is obtained from the (dense, sparse)
        # block by Hermitian conjugation instead of a second eval_mat call.
        vj_kpts[:,:,idx_l[:,None],idx_h] += vp_hl.transpose(0,1,3,2).conj()

    return vj_kpts
849
850
def _get_gga_pass2(mydf, vG, kpts=numpy.zeros((1,3)), verbose=None):
    '''Integrate a GGA potential given on the FFT mesh against AO pairs.

    ``vG`` is reshaped to (nset, 4, nx, ny, nz).  Each set's four components
    are contracted with the AO value and its first derivatives (AOs are
    evaluated with deriv=1) via ``numint._scale_ao``.  Only the real part of
    the inverse FFT of each component is used.

    For each set the half matrix v is accumulated together with its
    Hermitian conjugate (v + v^H).  Callers therefore scale any Coulomb
    contribution added to component 0 by 0.5.

    The multigrid task list is built and cached on ``mydf.tasks`` when it
    is not available yet.

    Args:
        mydf : multigrid DF object (provides ``cell``, ``mesh``, ``aoR_loop``).
        vG : potential components in reciprocal space on ``mydf.mesh``.
        kpts : (nkpts, 3) ndarray of k-points.

    Returns:
        (nset, nkpts, nao, nao) ndarray, real at the Gamma point and complex
        otherwise.
    '''
    log = logger.new_logger(mydf, verbose)
    cell = mydf.cell
    nkpts = len(kpts)
    nao = cell.nao_nr()
    nx, ny, nz = mydf.mesh
    vG = vG.reshape(-1,4,nx,ny,nz)
    nset = vG.shape[0]

    # Build the task list on demand rather than relying on an earlier
    # density evaluation to have populated mydf.tasks.
    tasks = getattr(mydf, 'tasks', None)
    if tasks is None:
        mydf.tasks = tasks = multi_grids_tasks(cell, mydf.mesh, log)
        log.debug('Multigrid ntasks %s', len(tasks))

    if gamma_point(kpts):
        veff = numpy.zeros((nset,nkpts,nao,nao))
    else:
        veff = numpy.zeros((nset,nkpts,nao,nao), dtype=numpy.complex128)

    for grids_dense, grids_sparse in tasks:
        mesh = grids_dense.mesh
        ngrids = numpy.prod(mesh)
        log.debug('mesh %s', mesh)

        # Integer plane-wave indices of this task's mesh, used to pick the
        # matching entries out of the full-mesh vG.
        gx = numpy.fft.fftfreq(mesh[0], 1./mesh[0]).astype(numpy.int32)
        gy = numpy.fft.fftfreq(mesh[1], 1./mesh[1]).astype(numpy.int32)
        gz = numpy.fft.fftfreq(mesh[2], 1./mesh[2]).astype(numpy.int32)
        #:sub_vG = vG[:,:,gx[:,None,None],gy[:,None],gz].reshape(-1,ngrids)
        sub_vG = _take_5d(vG, (None, None, gx, gy, gz)).reshape(-1,ngrids)
        wv = tools.ifft(sub_vG, mesh).real.reshape(nset,4,ngrids)
        wv = numpy.asarray(wv, order='C')

        idx_h = grids_dense.ao_idx
        if grids_sparse is None:
            for ao_h_etc, p0, p1 in mydf.aoR_loop(grids_dense, kpts, deriv=1):
                ao_h = ao_h_etc[0]
                for k in range(nkpts):
                    for i in range(nset):
                        aow = numint._scale_ao(ao_h[k], wv[i])
                        v = lib.dot(aow.conj().T, ao_h[k][0])
                        veff[i,k,idx_h[:,None],idx_h] += v + v.conj().T
                ao_h = ao_h_etc = None
        else:
            idx_l = grids_sparse.ao_idx
            naoh = len(idx_h)

            # Merge dense and sparse shells into one cell of uncontracted
            # Cartesian primitives; the block-diagonal coefficient matrices
            # contract eval_mat's primitive-basis result back to AOs.
            h_cell = grids_dense.cell
            l_cell = grids_sparse.cell
            t_cell = h_cell + l_cell
            t_cell, t_coeff = t_cell.to_uncontracted_cartesian_basis()
            nshells_h = _pgto_shells(h_cell)
            nshells_t = _pgto_shells(t_cell)

            h_coeff = scipy.linalg.block_diag(*t_coeff[:h_cell.nbas])
            l_coeff = scipy.linalg.block_diag(*t_coeff[h_cell.nbas:])
            t_coeff = scipy.linalg.block_diag(*t_coeff)

            # v[dense, all]: gives the (h,h) and (h,l) halves and, by
            # conjugation, their (h,h)^H and (l,h) partners.
            shls_slice = (0, nshells_h, 0, nshells_t)
            v = eval_mat(t_cell, wv, shls_slice, 1, 0, 'GGA', kpts)
            v = lib.einsum('nkpq,pi,qj->nkij', v, h_coeff, t_coeff)
            veff[:,:,idx_h[:,None],idx_h] += v[:,:,:,:naoh]
            veff[:,:,idx_h[:,None],idx_h] += v[:,:,:,:naoh].conj().transpose(0,1,3,2)
            veff[:,:,idx_h[:,None],idx_l] += v[:,:,:,naoh:]
            veff[:,:,idx_l[:,None],idx_h] += v[:,:,:,naoh:].conj().transpose(0,1,3,2)

            # v[sparse, dense]: the gradient acts on the sparse AO here, so
            # this half is not the conjugate of the one above and must be
            # evaluated explicitly.
            shls_slice = (nshells_h, nshells_t, 0, nshells_h)
            v = eval_mat(t_cell, wv, shls_slice, 1, 0, 'GGA', kpts)#, offset, submesh)
            v = lib.einsum('nkpq,pi,qj->nkij', v, l_coeff.conj(), h_coeff)
            veff[:,:,idx_l[:,None],idx_h] += v
            veff[:,:,idx_h[:,None],idx_l] += v.conj().transpose(0,1,3,2)

    return veff
921
922
def nr_rks(mydf, xc_code, dm_kpts, hermi=1, kpts=None,
           kpts_band=None, with_j=False, return_j=False, verbose=None):
    '''Compute the XC energy and RKS XC matrix at sampled k-points.
    multigrid version of function pbc.dft.numint.nr_rks.

    Args:
        dm_kpts : (nkpts, nao, nao) ndarray or a list of (nkpts,nao,nao) ndarray
            Density matrix at each k-point.
        kpts : (nkpts, 3) ndarray

    Kwargs:
        kpts_band : (3,) ndarray or (*,3) ndarray
            A list of arbitrary "band" k-points at which to evaluate the matrix.
        with_j : bool
            If True, the Coulomb potential is included in the returned veff.
        return_j : bool
            If True, the Coulomb matrix is also computed separately and
            attached to veff as the ``vj`` tag.

    Returns:
        nelec : number of electrons obtained from the numerical integration
        exc : XC energy
        veff : (nkpts, nao, nao) ndarray
            or list of veff if the input dm_kpts is a list of DMs.  It is
            tagged with ``ecoul`` (Coulomb energy), ``exc``, ``vj`` (shaped
            like veff when return_j is set, otherwise None) and ``vk`` (None).

    Raises:
        NotImplementedError : if the functional is neither LDA nor GGA.
    '''
    if kpts is None: kpts = mydf.kpts
    log = logger.new_logger(mydf, verbose)
    cell = mydf.cell
    dm_kpts = lib.asarray(dm_kpts, order='C')
    dms = _format_dms(dm_kpts, kpts)
    nset = dms.shape[0]
    # input_band keeps the caller's original kpts_band.  _format_jks inspects
    # it (e.g. a single (3,) band point) to decide the output shape, so it
    # must not be replaced by the formatted array.
    kpts_band, input_band = _format_kpts_band(kpts_band, kpts), kpts_band

    ni = mydf._numint
    xctype = ni._xc_type(xc_code)

    if xctype == 'LDA':
        deriv = 0
    elif xctype == 'GGA':
        deriv = 1
    else:
        raise NotImplementedError(
            'multigrid nr_rks does not support %s functionals' % xctype)
    rhoG = _eval_rhoG(mydf, dm_kpts, hermi, kpts, deriv)

    mesh = mydf.mesh
    ngrids = numpy.prod(mesh)
    coulG = tools.get_coulG(cell, mesh=mesh)
    vG = numpy.einsum('ng,g->ng', rhoG[:,0], coulG)
    ecoul = .5 * numpy.einsum('ng,ng->n', rhoG[:,0].real, vG.real)
    ecoul+= .5 * numpy.einsum('ng,ng->n', rhoG[:,0].imag, vG.imag)
    ecoul /= cell.vol
    log.debug('Multigrid Coulomb energy %s', ecoul)

    weight = cell.vol / ngrids
    # *(1./weight) because rhoR is scaled by weight in _eval_rhoG.  When
    # computing rhoR with IFFT, the weight factor is not needed.
    rhoR = tools.ifft(rhoG.reshape(-1,ngrids), mesh).real * (1./weight)
    rhoR = rhoR.reshape(nset,-1,ngrids)
    wv_freq = []
    nelec = numpy.zeros(nset)
    excsum = numpy.zeros(nset)
    for i in range(nset):
        exc, vxc = ni.eval_xc(xc_code, rhoR[i], spin=0, deriv=1)[:2]
        if xctype == 'LDA':
            wv = vxc[0].reshape(1,ngrids) * weight
        else:  # GGA
            wv = numint._rks_gga_wv0(rhoR[i], vxc, weight)

        nelec[i] += rhoR[i,0].sum() * weight
        excsum[i] += (rhoR[i,0]*exc).sum() * weight
        wv_freq.append(tools.fft(wv, mesh))
    rhoR = rhoG = None
    wv_freq = numpy.asarray(wv_freq).reshape(nset,-1,*mesh)

    if nset == 1:
        ecoul = ecoul[0]
        nelec = nelec[0]
        excsum = excsum[0]
    log.debug('Multigrid exc %s  nelec %s', excsum, nelec)

    if xctype == 'LDA':
        if with_j:
            wv_freq[:,0] += vG.reshape(nset,*mesh)
        veff = _get_j_pass2(mydf, wv_freq, kpts_band, verbose=log)
    else:  # GGA
        if with_j:  # *.5 because v+v.T.conj() is evaluated in _get_gga_pass2
            wv_freq[:,0] += vG.reshape(nset,*mesh) * .5
        veff = _get_gga_pass2(mydf, wv_freq, kpts_band, verbose=log)
    veff = _format_jks(veff, dm_kpts, input_band, kpts)

    if return_j:
        vj = _get_j_pass2(mydf, vG, kpts_band, verbose=log)
        vj = _format_jks(vj, dm_kpts, input_band, kpts)
    else:
        vj = None

    veff = lib.tag_array(veff, ecoul=ecoul, exc=excsum, vj=vj, vk=None)
    return nelec, excsum, veff
1017
1018
1019# Note nr_uks handles only one set of KUKS density matrices (alpha, beta) in
1020# each call (nr_rks supports multiple sets of KRKS density matrices)
def nr_uks(mydf, xc_code, dm_kpts, hermi=1, kpts=None,
           kpts_band=None, with_j=False, return_j=False, verbose=None):
    '''Compute the XC energy and UKS XC matrix at sampled k-points.
    multigrid version of function pbc.dft.numint.nr_uks

    Args:
        dm_kpts : (2, nkpts, nao, nao) ndarray
            Alpha and beta density matrices at each k-point (exactly one
            pair of spin density matrices is accepted).
        kpts : (nkpts, 3) ndarray

    Kwargs:
        kpts_band : (3,) ndarray or (*,3) ndarray
            A list of arbitrary "band" k-points at which to evaluate the matrix.
        with_j : bool
            If True, the Coulomb potential of the total density is included
            in both spin components of the returned veff.
        return_j : bool
            If True, the Coulomb matrix is also computed separately and
            attached to veff as the ``vj`` tag.

    Returns:
        nelec : (2,) ndarray, number of alpha and beta electrons obtained
            from the numerical integration
        exc : XC energy
        veff : (2, nkpts, nao, nao) ndarray, tagged with ``ecoul``, ``exc``,
            ``vj`` and ``vk`` (None).  ``vj`` is a single (nkpts, nao, nao)
            matrix of the total density when return_j is set, else None.

    Raises:
        NotImplementedError : if the functional is neither LDA nor GGA.
    '''
    if kpts is None:
        kpts = mydf.kpts
    log = logger.new_logger(mydf, verbose)
    cell = mydf.cell
    dm_kpts = lib.asarray(dm_kpts, order='C')
    dms = _format_dms(dm_kpts, kpts)
    nset = dms.shape[0]
    assert(nset == 2)
    # input_band keeps the caller's original kpts_band.  _format_jks inspects
    # it (e.g. a single (3,) band point) to decide the output shape, so it
    # must not be replaced by the formatted array.
    kpts_band, input_band = _format_kpts_band(kpts_band, kpts), kpts_band

    ni = mydf._numint
    xctype = ni._xc_type(xc_code)

    if xctype == 'LDA':
        deriv = 0
    elif xctype == 'GGA':
        deriv = 1
    else:
        raise NotImplementedError(
            'multigrid nr_uks does not support %s functionals' % xctype)
    rhoG = _eval_rhoG(mydf, dm_kpts, hermi, kpts, deriv)

    mesh = mydf.mesh
    ngrids = numpy.prod(mesh)
    coulG = tools.get_coulG(cell, mesh=mesh)
    # Coulomb potential of the total (alpha + beta) density
    vG = numpy.einsum('ng,g->g', rhoG[:,0], coulG)
    ecoul = .5 * numpy.einsum('ng,g->', rhoG[:,0].real, vG.real)
    ecoul+= .5 * numpy.einsum('ng,g->', rhoG[:,0].imag, vG.imag)
    ecoul /= cell.vol
    log.debug('Multigrid Coulomb energy %s', ecoul)

    weight = cell.vol / ngrids
    # *(1./weight) because rhoR is scaled by weight in _eval_rhoG.  When
    # computing rhoR with IFFT, the weight factor is not needed.
    rhoR = tools.ifft(rhoG.reshape(-1,ngrids), mesh).real * (1./weight)
    rhoR = rhoR.reshape(2,-1,ngrids)
    nelec = numpy.zeros((2))
    excsum = 0

    exc, vxc = ni.eval_xc(xc_code, rhoR, spin=1, deriv=1)[:2]
    if xctype == 'LDA':
        vrho = vxc[0]
        wva = vrho[:,0].reshape(1,ngrids) * weight
        wvb = vrho[:,1].reshape(1,ngrids) * weight
    else:  # GGA
        wva, wvb = numint._uks_gga_wv0(rhoR, vxc, weight)

    nelec[0] += rhoR[0,0].sum() * weight
    nelec[1] += rhoR[1,0].sum() * weight
    excsum += (rhoR[0,0]*exc).sum() * weight
    excsum += (rhoR[1,0]*exc).sum() * weight
    wv_freq = tools.fft(numpy.vstack((wva,wvb)), mesh)
    wv_freq = wv_freq.reshape(2,-1,*mesh)
    rhoR = rhoG = None
    log.debug('Multigrid exc %g  nelec %s', excsum, nelec)

    if xctype == 'LDA':
        if with_j:
            wv_freq[:,0] += vG.reshape(*mesh)
        veff = _get_j_pass2(mydf, wv_freq, kpts_band, verbose=log)
    else:  # GGA
        if with_j:  # *.5 because v+v.T.conj() is evaluated in _get_gga_pass2
            wv_freq[:,0] += vG.reshape(*mesh) * .5
        veff = _get_gga_pass2(mydf, wv_freq, kpts_band, verbose=log)
    veff = _format_jks(veff, dm_kpts, input_band, kpts)

    if return_j:
        # vG is the potential of the total density, so vj holds a single set;
        # format it against one spin component of dm_kpts.
        vj = _get_j_pass2(mydf, vG, kpts_band, verbose=log)
        vj = _format_jks(vj, dm_kpts[0], input_band, kpts)
    else:
        vj = None

    veff = lib.tag_array(veff, ecoul=ecoul, exc=excsum, vj=vj, vk=None)
    return nelec, excsum, veff
1116
1117
def nr_rks_fxc(mydf, xc_code, dm0, dms, hermi=1, with_j=False,
               rho0=None, vxc=None, fxc=None, kpts=None, verbose=None):
    '''multigrid version of function pbc.dft.numint.nr_rks_fxc

    Contract the RKS XC kernel (and, if with_j is set, the Coulomb kernel)
    with the trial density matrices ``dms``.

    Args:
        dm0 : reference density matrix; only used when rho0 is None.
        dms : trial density matrices.
        rho0, vxc, fxc : optional precomputed reference density and XC
            derivatives (e.g. from cache_xc_kernel); computed when missing.

    Returns:
        ndarray with the same shape as ``dms``.

    Raises:
        NotImplementedError : if the functional is neither LDA nor GGA.
    '''
    if kpts is None:
        kpts = numpy.zeros((1,3))
    log = logger.new_logger(mydf, verbose)
    cell = mydf.cell
    mesh = mydf.mesh
    ngrids = numpy.prod(mesh)

    dm_kpts = lib.asarray(dms, order='C')
    dms = _format_dms(dm_kpts, kpts)
    nset = dms.shape[0]

    ni = mydf._numint
    xctype = ni._xc_type(xc_code)
    if xctype == 'LDA':
        deriv = 0
    elif xctype == 'GGA':
        deriv = 1
    else:
        # Only LDA and GGA kernels are contracted below; fail before the
        # density evaluations rather than on an unbound result afterwards.
        raise NotImplementedError(
            'multigrid nr_rks_fxc does not support %s functionals' % xctype)

    weight = cell.vol / ngrids
    if rho0 is None:
        rhoG = _eval_rhoG(mydf, dm0, hermi, kpts, deriv)
        rho0 = tools.ifft(rhoG.reshape(-1,ngrids), mesh).real * (1./weight)

    if vxc is None or fxc is None:
        vxc, fxc = ni.eval_xc(xc_code, rho0, spin=0, deriv=2)[1:3]

    rhoG = _eval_rhoG(mydf, dms, hermi, kpts, deriv)
    rho1 = tools.ifft(rhoG.reshape(-1,ngrids), mesh).real * (1./weight)
    rho1 = rho1.reshape(nset,-1,ngrids)
    if with_j:
        coulG = tools.get_coulG(cell, mesh=mesh)
        vG = rhoG[:,0] * coulG
        vG = vG.reshape(nset, *mesh)

    if xctype == 'LDA':
        frr = fxc[0]
        wv = weight * frr * rho1
        wv = tools.fft(wv.reshape(-1,ngrids), mesh).reshape(nset,-1,*mesh)
        if with_j:
            wv[:,0] += vG
        veff = _get_j_pass2(mydf, wv, kpts, verbose=log)

    else:  # GGA
        wv = [numint._rks_gga_wv1(rho0, rho1[i], vxc, fxc, weight)
              for i in range(nset)]
        wv = numpy.vstack(wv).reshape(-1,ngrids)
        wv = tools.fft(wv, mesh).reshape(nset,-1,*mesh)
        if with_j:  # *.5 because v+v.T.conj() is evaluated in _get_gga_pass2
            wv[:,0] += vG * .5
        veff = _get_gga_pass2(mydf, wv, kpts, verbose=log)

    return veff.reshape(dm_kpts.shape)
1176
1177
def nr_rks_fxc_st(mydf, xc_code, dm0, dms_alpha, hermi=1, singlet=True, with_j=False,
                  rho0=None, vxc=None, fxc=None, kpts=None, verbose=None):
    '''multigrid version of function pbc.dft.numint.nr_rks_fxc_st

    Spin-adapted (singlet or triplet) contraction of the XC kernel with the
    alpha trial density matrices ``dms_alpha``.  The kernel is evaluated in
    the unrestricted (spin=1) form and combined into the singlet (aa + ab)
    or triplet (aa - ab) channel.

    Args:
        dm0 : reference (total) density matrix; only used when rho0 is None.
            Half of its density is taken as the alpha density.
        dms_alpha : alpha trial density matrices.
        singlet : select the singlet (True) or triplet (False) combination.
        rho0, vxc, fxc : optional precomputed reference alpha/beta densities
            and XC derivatives (e.g. from cache_xc_kernel with spin=1).

    Returns:
        ndarray with the same shape as ``dms_alpha``.

    Raises:
        NotImplementedError : if the functional is neither LDA nor GGA.
    '''
    if kpts is None:
        kpts = numpy.zeros((1,3))
    log = logger.new_logger(mydf, verbose)
    cell = mydf.cell
    mesh = mydf.mesh
    ngrids = numpy.prod(mesh)

    dm_kpts = lib.asarray(dms_alpha, order='C')
    dms = _format_dms(dm_kpts, kpts)
    nset = dms.shape[0]

    ni = mydf._numint
    xctype = ni._xc_type(xc_code)
    if xctype == 'LDA':
        deriv = 0
    elif xctype == 'GGA':
        deriv = 1
    else:
        # Only LDA and GGA kernels are contracted below; fail before the
        # density evaluations rather than on an unbound result afterwards.
        raise NotImplementedError(
            'multigrid nr_rks_fxc_st does not support %s functionals' % xctype)

    weight = cell.vol / ngrids
    if rho0 is None:
        rhoG = _eval_rhoG(mydf, dm0, hermi, kpts, deriv)
        # *.5 to get alpha density
        rho0 = tools.ifft(rhoG.reshape(-1,ngrids), mesh).real * (.5/weight)
        rho0 = (rho0, rho0)

    if vxc is None or fxc is None:
        vxc, fxc = ni.eval_xc(xc_code, rho0, spin=1, deriv=2)[1:3]

    rhoG = _eval_rhoG(mydf, dms, hermi, kpts, deriv)
    rho1 = tools.ifft(rhoG.reshape(-1,ngrids), mesh).real * (1./weight)
    rho1 = rho1.reshape(nset,-1,ngrids)
    if with_j:
        coulG = tools.get_coulG(cell, mesh=mesh)
        vG = rhoG[:,0] * coulG
        vG = vG.reshape(nset, *mesh)

    if xctype == 'LDA':
        u_u, u_d, d_d = fxc[0].T
        if singlet:
            frho = u_u + u_d
        else:
            frho = u_u - u_d
        wv = weight * frho * rho1
        wv = tools.fft(wv.reshape(-1,ngrids), mesh).reshape(nset,-1,*mesh)
        if with_j:
            wv[:,0] += vG
        veff = _get_j_pass2(mydf, wv, kpts, verbose=log)

    else:  # GGA
        vsigma = vxc[1].T
        u_u, u_d, d_d = fxc[0].T  # v2rho2
        u_uu, u_ud, u_dd, d_uu, d_ud, d_dd = fxc[1].T  # v2rhosigma
        uu_uu, uu_ud, uu_dd, ud_ud, ud_dd, dd_dd = fxc[2].T  # v2sigma2
        if singlet:
            fgamma = vsigma[0] + vsigma[1] * .5
            frho = u_u + u_d
            fgg = uu_uu + .5*ud_ud + 2*uu_ud + uu_dd
            frhogamma = u_uu + u_dd + u_ud
        else:
            fgamma = vsigma[0] - vsigma[1] * .5
            frho = u_u - u_d
            fgg = uu_uu - uu_dd
            frhogamma = u_uu - u_dd

        wv = [numint._rks_gga_wv1(rho0[0], rho1[i], (None,fgamma),
                                  (frho,frhogamma,fgg), weight)
              for i in range(nset)]
        wv = numpy.asarray(wv).reshape(-1,ngrids)
        wv = tools.fft(wv, mesh).reshape(nset,-1,*mesh)
        if with_j:  # *.5 because v+v.T.conj() is evaluated in _get_gga_pass2
            wv[:,0] += vG * .5
        veff = _get_gga_pass2(mydf, wv, kpts, verbose=log)

    return veff.reshape(dm_kpts.shape)
1258
1259
def nr_uks_fxc(mydf, xc_code, dm0, dms, hermi=1, with_j=False,
               rho0=None, vxc=None, fxc=None, kpts=None, verbose=None):
    '''multigrid version of function pbc.dft.numint.nr_uks_fxc

    Contract the UKS XC kernel (and, if with_j is set, the Coulomb kernel of
    the total trial density) with one pair of alpha/beta trial density
    matrices.

    Args:
        dm0 : reference alpha/beta density matrices; only used when rho0 is
            None.
        dms : (2, ...) alpha and beta trial density matrices.
        rho0, vxc, fxc : optional precomputed reference densities and XC
            derivatives (e.g. from cache_xc_kernel with spin=1).

    Returns:
        ndarray with the same shape as ``dms``.

    Raises:
        NotImplementedError : if the functional is neither LDA nor GGA.
    '''
    if kpts is None:
        kpts = numpy.zeros((1,3))
    log = logger.new_logger(mydf, verbose)
    cell = mydf.cell
    mesh = mydf.mesh
    ngrids = numpy.prod(mesh)

    dm_kpts = lib.asarray(dms, order='C')
    dms = _format_dms(dm_kpts, kpts)
    nset = dms.shape[0]
    assert(nset == 2)

    ni = mydf._numint
    xctype = ni._xc_type(xc_code)
    if xctype == 'LDA':
        deriv = 0
    elif xctype == 'GGA':
        deriv = 1
    else:
        # Only LDA and GGA kernels are contracted below; fail before the
        # density evaluations rather than on an unbound result afterwards.
        raise NotImplementedError(
            'multigrid nr_uks_fxc does not support %s functionals' % xctype)

    weight = cell.vol / ngrids
    if rho0 is None:
        rhoG = _eval_rhoG(mydf, dm0, hermi, kpts, deriv)
        rho0 = tools.ifft(rhoG.reshape(-1,ngrids), mesh).real * (1./weight)
        rho0 = rho0.reshape(nset,-1,ngrids)

    if vxc is None or fxc is None:
        vxc, fxc = ni.eval_xc(xc_code, rho0, spin=1, deriv=2)[1:3]

    rhoG = _eval_rhoG(mydf, dms, hermi, kpts, deriv)
    rho1 = tools.ifft(rhoG.reshape(-1,ngrids), mesh).real * (1./weight)
    rho1 = rho1.reshape(nset,-1,ngrids)
    if with_j:
        # Coulomb potential of the total (alpha + beta) trial density
        coulG = tools.get_coulG(cell, mesh=mesh)
        vG = (rhoG[0,0] + rhoG[1,0]) * coulG
        vG = vG.reshape(mesh)

    if xctype == 'LDA':
        u_u, u_d, d_d = fxc[0].T
        wv = numpy.asarray([u_u * rho1[0] + u_d * rho1[1],
                            u_d * rho1[0] + d_d * rho1[1]])
        wv *= weight
        wv = tools.fft(wv.reshape(-1,ngrids), mesh).reshape(nset,-1,*mesh)
        if with_j:
            wv[:,0] += vG
        veff = _get_j_pass2(mydf, wv, kpts, verbose=log)

    else:  # GGA
        wv = numint._uks_gga_wv1(rho0, rho1, vxc, fxc, weight)
        wv = numpy.vstack(wv).reshape(-1,ngrids)
        wv = tools.fft(wv, mesh).reshape(nset,-1,*mesh)
        if with_j:  # *.5 because v+v.T.conj() is evaluated in _get_gga_pass2
            wv[:,0] += vG * .5
        veff = _get_gga_pass2(mydf, wv, kpts, verbose=log)

    return veff.reshape(dm_kpts.shape)
1321
1322
def cache_xc_kernel(mydf, xc_code, dm, spin=0, kpts=None):
    '''Compute the 0th order density, Vxc and fxc.  They can be used in TDDFT,
    DFT hessian module etc.

    Args:
        dm : density matrix for spin=0, or a pair of (alpha, beta) density
            matrices for spin=1.
        spin : 0 for the restricted, 1 for the unrestricted kernel.
        kpts : (nkpts, 3) ndarray; defaults to the Gamma point.

    Returns:
        rho : (comp, ngrids) for spin=0 or (2, comp, ngrids) for spin=1,
            where comp is 1 for LDA and 4 for GGA.
        vxc, fxc : first and second functional derivatives from
            ``ni.eval_xc(..., deriv=2)``.

    Raises:
        NotImplementedError : if the functional is neither LDA nor GGA.
    '''
    if kpts is None:
        kpts = numpy.zeros((1,3))
    cell = mydf.cell
    mesh = mydf.mesh
    ngrids = numpy.prod(mesh)

    ni = mydf._numint
    xctype = ni._xc_type(xc_code)
    if xctype == 'LDA':
        deriv = 0
        comp = 1
    elif xctype == 'GGA':
        deriv = 1
        comp = 4
    else:
        # The density layout (comp) is only defined for LDA and GGA.
        raise NotImplementedError(
            'multigrid cache_xc_kernel does not support %s functionals' % xctype)

    hermi = 1
    weight = cell.vol / ngrids
    rhoG = _eval_rhoG(mydf, dm, hermi, kpts, deriv)
    # *(1./weight) because rhoG carries the integration weight from
    # _eval_rhoG, which the IFFT does not need.
    rho = tools.ifft(rhoG.reshape(-1,ngrids), mesh).real * (1./weight)
    if spin == 0:
        rho = rho.reshape(comp,ngrids)
    else:
        rho = rho.reshape(2,comp,ngrids)

    vxc, fxc = ni.eval_xc(xc_code, rho, spin=spin, deriv=2)[1:3]
    return rho, vxc, fxc
1355
def _gen_rhf_response(mf, dm0, singlet=None, hermi=0):
    '''multigrid version of function pbc.scf.newton_ah._gen_rhf_response

    Precompute the XC kernel for the reference density ``dm0`` and return a
    function ``vind(dm1)`` that evaluates the response potential of ``dm1``.

    * singlet is None: general restricted response (nr_rks_fxc) including
      the Coulomb term.
    * singlet given: spin-adapted response (nr_rks_fxc_st); the Coulomb
      term is included only when singlet is truthy.
    * hermi == 2: vind returns zeros shaped like dm1.
    '''
    has_kpts = getattr(mf, 'kpts', None) is not None
    kpts = mf.kpts if has_kpts else mf.kpt.reshape(1,3)

    if singlet is not None:
        # Spin-adapted kernel is built from equal alpha/beta halves of dm0
        rho0, vxc, fxc = cache_xc_kernel(mf.with_df, mf.xc, [dm0*.5]*2, 1, kpts)
    else:  # for newton solver
        rho0, vxc, fxc = cache_xc_kernel(mf.with_df, mf.xc, dm0, 0, kpts)
    # The reference density is fully represented by rho0/vxc/fxc from here
    # on; drop the reference to dm0 so the closure does not keep it alive.
    del dm0

    def vind(dm1):
        if hermi == 2:
            return numpy.zeros_like(dm1)
        if singlet is None:  # Without specify singlet, general case
            return nr_rks_fxc(mf.with_df, mf.xc, None, dm1, hermi,
                              True, rho0, vxc, fxc, kpts)
        return nr_rks_fxc_st(mf.with_df, mf.xc, None, dm1, hermi, singlet,
                             bool(singlet), rho0, vxc, fxc, kpts)
    return vind
1386
def _gen_uhf_response(mf, dm0, with_j=True, hermi=0):
    '''multigrid version of function pbc.scf.newton_ah._gen_uhf_response

    Precompute the unrestricted XC kernel for the reference alpha/beta
    density ``dm0`` and return ``vind(dm1)``, which evaluates the response
    potential of a pair of trial density matrices via nr_uks_fxc.  With
    hermi == 2, vind returns zeros shaped like dm1.
    '''
    has_kpts = getattr(mf, 'kpts', None) is not None
    kpts = mf.kpts if has_kpts else mf.kpt.reshape(1,3)

    rho0, vxc, fxc = cache_xc_kernel(mf.with_df, mf.xc, dm0, 1, kpts)
    # dm0 is no longer needed once the kernel is cached.
    del dm0

    def vind(dm1):
        if hermi == 2:
            return numpy.zeros_like(dm1)
        return nr_uks_fxc(mf.with_df, mf.xc, None, dm1, hermi,
                          with_j, rho0, vxc, fxc, kpts)
    return vind
1407
1408
def get_rho(mydf, dm, kpts=numpy.zeros((1,3))):
    '''Real-space electron density of ``dm`` on ``mydf.mesh``.

    The density is built in reciprocal space by _eval_rhoG (Hermitian dm,
    no derivatives) and transformed back with an inverse FFT.
    '''
    rho_G = _eval_rhoG(mydf, numpy.asarray(dm), 1, kpts, deriv=0)

    cell = mydf.cell
    mesh = mydf.mesh
    npoints = numpy.prod(mesh)
    grid_weight = cell.vol / npoints
    # _eval_rhoG folds the integration weight into rho_G; the IFFT does not
    # need it, hence the division by grid_weight.
    rho_r = tools.ifft(rho_G.reshape(npoints), mesh).real
    return rho_r * (1./grid_weight)
1423
1424
def multi_grids_tasks(cell, fft_mesh=None, verbose=None):
    '''Partition the basis of ``cell`` into multigrid tasks.

    The module setting TASKS_TYPE selects the strategy: ``'rcut'`` bins
    primitive GTOs by real-space cutoff radius; any other value bins them
    by kinetic-energy cutoff.
    '''
    if TASKS_TYPE != 'rcut':
        return multi_grids_tasks_for_ke_cut(cell, fft_mesh, verbose)
    return multi_grids_tasks_for_rcut(cell, fft_mesh, verbose)
1430
def multi_grids_tasks_for_rcut(cell, fft_mesh=None, verbose=None):
    '''Partition the shells of ``cell`` into multigrid tasks by primitive rcut.

    Primitive GTOs are binned by their real-space cutoff radius.  The bin
    edges form a geometric sequence rmax, rmax*RMAX_RATIO, ... down to about
    0.005.  rmax is the largest element of the lattice-vector matrix, scaled
    by RMAX_FACTOR_ORTH for a diagonal lattice and by RMAX_FACTOR_NONORTH
    otherwise.  The first bin is open-ended above (r0 = 1e9) and the last
    extends down to 0.  For each bin [r1, r0):

    * the "dense" cell keeps the primitives with r1 <= rcut < r0.  Its mesh
      comes from the largest kinetic-energy cutoff among them, capped by
      ``fft_mesh``;
    * the "sparse" cell keeps the primitives with rcut >= r0 and shares the
      dense mesh.

    When a bin's (uncapped) mesh reaches ``fft_mesh`` in every direction,
    that bin's dense cell takes all remaining primitives with rcut < r0 and
    no further bins are generated.

    Args:
        cell : pbc Cell
        fft_mesh : mesh of the full FFT grid; defaults to ``cell.mesh``.
        verbose : logger verbosity.

    Returns:
        list of [grids_dense, grids_sparse] pairs of gen_grid.UniformGrids.
        Each carries an ``ao_idx`` attribute with the AO indices (in the
        original ``cell``) of its shells.  grids_sparse is None when no shell
        has primitives with rcut >= r0.
    '''
    log = logger.new_logger(cell, verbose)
    if fft_mesh is None:
        fft_mesh = cell.mesh

    # Split shells based on rcut
    rcuts_pgto, kecuts_pgto = _primitive_gto_cutoff(cell)
    ao_loc = cell.ao_loc_nr()

    # Copy of cell restricted to shls_dense, keeping in each shell only the
    # primitives with r1 <= rcut < r0.  Also returns the AO indices of those
    # shells, the largest ke cutoff of the kept primitives and the per-atom
    # largest rcut of the kept primitives.
    def make_cell_dense_exp(shls_dense, r0, r1):
        cell_dense = copy.copy(cell)
        # _bas/_env are copied so the edits below do not touch the original cell
        cell_dense._bas = cell._bas.copy()
        cell_dense._env = cell._env.copy()

        rcut_atom = [0] * cell.natm
        ke_cutoff = 0
        for ib in shls_dense:
            rc = rcuts_pgto[ib]
            idx = numpy.where((r1 <= rc) & (rc < r0))[0]
            np1 = len(idx)
            cs = cell._libcint_ctr_coeff(ib)
            np, nc = cs.shape
            if np1 < np:  # only a subset of primitives is in range: rewrite the shell with that subset
                pexp = cell._bas[ib,PTR_EXP]
                pcoeff = cell._bas[ib,PTR_COEFF]
                cs1 = cs[idx]
                cell_dense._env[pcoeff:pcoeff+cs1.size] = cs1.T.ravel()
                cell_dense._env[pexp:pexp+np1] = cell.bas_exp(ib)[idx]
                cell_dense._bas[ib,NPRIM_OF] = np1

            ke_cutoff = max(ke_cutoff, kecuts_pgto[ib][idx].max())

            ia = cell.bas_atom(ib)
            rcut_atom[ia] = max(rcut_atom[ia], rc[idx].max())
        cell_dense._bas = cell_dense._bas[shls_dense]
        ao_idx = numpy.hstack([numpy.arange(ao_loc[i], ao_loc[i+1])
                               for i in shls_dense])
        cell_dense.rcut = max(rcut_atom)
        return cell_dense, ao_idx, ke_cutoff, rcut_atom

    # Copy of cell restricted to shls_sparse, keeping in each shell only the
    # primitives with rcut >= r0.  Returns the cell and its AO indices.
    def make_cell_sparse_exp(shls_sparse, r0, r1):
        cell_sparse = copy.copy(cell)
        cell_sparse._bas = cell._bas.copy()
        cell_sparse._env = cell._env.copy()

        for ib in shls_sparse:
            idx = numpy.where(r0 <= rcuts_pgto[ib])[0]
            np1 = len(idx)
            cs = cell._libcint_ctr_coeff(ib)
            np, nc = cs.shape
            if np1 < np:  # only a subset of primitives is in range: rewrite the shell with that subset
                pexp = cell._bas[ib,PTR_EXP]
                pcoeff = cell._bas[ib,PTR_COEFF]
                cs1 = cs[idx]
                cell_sparse._env[pcoeff:pcoeff+cs1.size] = cs1.T.ravel()
                cell_sparse._env[pexp:pexp+np1] = cell.bas_exp(ib)[idx]
                cell_sparse._bas[ib,NPRIM_OF] = np1
        cell_sparse._bas = cell_sparse._bas[shls_sparse]
        ao_idx = numpy.hstack([numpy.arange(ao_loc[i], ao_loc[i+1])
                               for i in shls_sparse])
        return cell_sparse, ao_idx

    tasks = []
    a = cell.lattice_vectors()
    if abs(a-numpy.diag(a.diagonal())).max() < 1e-12:
        rmax = a.max() * RMAX_FACTOR_ORTH
    else:
        rmax = a.max() * RMAX_FACTOR_NONORTH
    # Geometric bin edges rmax * RMAX_RATIO**k, going down to about 0.005
    n_delimeter = int(numpy.log(0.005/rmax) / numpy.log(RMAX_RATIO))
    rcut_delimeter = rmax * (RMAX_RATIO ** numpy.arange(n_delimeter))
    for r0, r1 in zip(numpy.append(1e9, rcut_delimeter),
                      numpy.append(rcut_delimeter, 0)):
        # shells which have high exps (small rcut)
        shls_dense = [ib for ib, rc in enumerate(rcuts_pgto)
                      if numpy.any((r1 <= rc) & (rc < r0))]
        if len(shls_dense) == 0:
            continue
        cell_dense, ao_idx_dense, ke_cutoff, rcut_atom = \
                make_cell_dense_exp(shls_dense, r0, r1)

        mesh = tools.cutoff_to_mesh(a, ke_cutoff)
        if TO_EVEN_GRIDS:
            mesh = (mesh+1)//2 * 2  # round up to an even number
        if numpy.all(mesh >= fft_mesh):
            # Including all rest shells
            shls_dense = [ib for ib, rc in enumerate(rcuts_pgto)
                          if numpy.any(rc < r0)]
            cell_dense, ao_idx_dense = make_cell_dense_exp(shls_dense, r0, 0)[:2]
        # Never use a mesh finer than the full FFT mesh
        cell_dense.mesh = mesh = numpy.min([mesh, fft_mesh], axis=0)

        grids_dense = gen_grid.UniformGrids(cell_dense)
        grids_dense.ao_idx = ao_idx_dense
        #grids_dense.rcuts_pgto = [rcuts_pgto[i] for i in shls_dense]

        # shells which have low exps (big rcut)
        shls_sparse = [ib for ib, rc in enumerate(rcuts_pgto)
                       if numpy.any(r0 <= rc)]
        if len(shls_sparse) == 0:
            cell_sparse = None
            ao_idx_sparse = []
        else:
            cell_sparse, ao_idx_sparse = make_cell_sparse_exp(shls_sparse, r0, r1)
            cell_sparse.mesh = mesh

        if cell_sparse is None:
            grids_sparse = None
        else:
            grids_sparse = gen_grid.UniformGrids(cell_sparse)
            grids_sparse.ao_idx = ao_idx_sparse

        log.debug('mesh %s nao dense/sparse %d %d  rcut %g',
                  mesh, len(ao_idx_dense), len(ao_idx_sparse), cell_dense.rcut)

        tasks.append([grids_dense, grids_sparse])
        # After capping, mesh >= fft_mesh everywhere exactly when the
        # uncapped mesh reached fft_mesh: the last task has been built.
        if numpy.all(mesh >= fft_mesh):
            break
    return tasks
1548
def multi_grids_tasks_for_ke_cut(cell, fft_mesh=None, verbose=None):
    '''Group the primitive GTOs of ``cell`` into multigrid tasks according to
    their kinetic-energy cutoffs.

    The kinetic-energy axis is partitioned into windows
    (0, ke1], (ke1, KE_RATIO*ke1], (KE_RATIO*ke1, KE_RATIO**2*ke1], ...
    where ke1 is the smallest cutoff of the initial mesh (INIT_MESH_ORTH for
    orthogonal lattices, INIT_MESH_NONORTH otherwise).  For every window that
    contains at least one primitive a task is created:

    * a "dense" cell made of the primitives whose ke cutoff lies inside the
      window, on the mesh that resolves the upper bound of the window;
    * a "sparse" cell made of the primitives whose ke cutoff is at or below
      the lower bound of the window, on the same mesh (None if there are no
      such primitives).

    As soon as the required mesh reaches ``fft_mesh`` in every direction, all
    remaining primitives are put into the dense cell and no further tasks are
    generated.  Meshes are clipped to ``fft_mesh``.

    Args:
        cell : pbc Cell object
        fft_mesh : mesh of the full FFT grid.  Defaults to ``cell.mesh``.
        verbose : verbosity level for the logger.

    Returns:
        A list of ``[grids_dense, grids_sparse]`` pairs of UniformGrids
        objects (``grids_sparse`` may be None).  Each grids object carries an
        ``ao_idx`` attribute with the AO indices of its shells in ``cell``.
    '''
    log = logger.new_logger(cell, verbose)
    if fft_mesh is None:
        fft_mesh = cell.mesh

    # Split shells based on the kinetic-energy cutoff of each primitive GTO
    rcuts_pgto, kecuts_pgto = _primitive_gto_cutoff(cell)
    ao_loc = cell.ao_loc_nr()

    # cell that needs dense integration grids
    def make_cell_dense_exp(shls_dense, ke0, ke1):
        '''Copy of ``cell`` restricted to ``shls_dense``, each shell keeping
        only the primitives with ke0 < ke_cutoff <= ke1.

        Returns the new cell, the AO indices of its shells in ``cell``, the
        largest ke cutoff among the kept primitives and the per-atom maximum
        rcut of the kept primitives.
        '''
        cell_dense = copy.copy(cell)
        cell_dense._bas = cell._bas.copy()
        cell_dense._env = cell._env.copy()

        rcut_atom = [0] * cell.natm
        ke_cutoff = 0
        for ib in shls_dense:
            ke = kecuts_pgto[ib]
            idx = numpy.where((ke0 < ke) & (ke <= ke1))[0]
            np1 = len(idx)
            cs = cell._libcint_ctr_coeff(ib)
            nprim = cs.shape[0]
            if np1 < nprim:
                # Only part of the shell's primitives fall in this window:
                # overwrite exponents/coefficients in the copied _env with the
                # selected primitives and shrink NPRIM_OF accordingly.
                pexp = cell._bas[ib,PTR_EXP]
                pcoeff = cell._bas[ib,PTR_COEFF]
                cs1 = cs[idx]
                cell_dense._env[pcoeff:pcoeff+cs1.size] = cs1.T.ravel()
                cell_dense._env[pexp:pexp+np1] = cell.bas_exp(ib)[idx]
                cell_dense._bas[ib,NPRIM_OF] = np1

            ke_cutoff = max(ke_cutoff, ke[idx].max())

            ia = cell.bas_atom(ib)
            rcut_atom[ia] = max(rcut_atom[ia], rcuts_pgto[ib][idx].max())
        cell_dense._bas = cell_dense._bas[shls_dense]
        ao_idx = numpy.hstack([numpy.arange(ao_loc[i], ao_loc[i+1])
                               for i in shls_dense])
        cell_dense.rcut = max(rcut_atom)
        return cell_dense, ao_idx, ke_cutoff, rcut_atom

    # cell that needs sparse integration grids
    def make_cell_sparse_exp(shls_sparse, ke0, ke1):
        '''Copy of ``cell`` restricted to ``shls_sparse``, each shell keeping
        only the primitives with ke_cutoff <= ke0.  ``ke1`` is not used.

        Returns the new cell and the AO indices of its shells in ``cell``.
        '''
        cell_sparse = copy.copy(cell)
        cell_sparse._bas = cell._bas.copy()
        cell_sparse._env = cell._env.copy()

        for ib in shls_sparse:
            idx = numpy.where(kecuts_pgto[ib] <= ke0)[0]
            np1 = len(idx)
            cs = cell._libcint_ctr_coeff(ib)
            nprim = cs.shape[0]
            if np1 < nprim:
                # Keep only the selected primitives of this shell
                pexp = cell._bas[ib,PTR_EXP]
                pcoeff = cell._bas[ib,PTR_COEFF]
                cs1 = cs[idx]
                cell_sparse._env[pcoeff:pcoeff+cs1.size] = cs1.T.ravel()
                cell_sparse._env[pexp:pexp+np1] = cell.bas_exp(ib)[idx]
                cell_sparse._bas[ib,NPRIM_OF] = np1
        cell_sparse._bas = cell_sparse._bas[shls_sparse]
        ao_idx = numpy.hstack([numpy.arange(ao_loc[i], ao_loc[i+1])
                               for i in shls_sparse])
        return cell_sparse, ao_idx

    a = cell.lattice_vectors()
    if abs(a-numpy.diag(a.diagonal())).max() < 1e-12:
        init_mesh = INIT_MESH_ORTH
    else:
        init_mesh = INIT_MESH_NONORTH
    ke_cutoff_min = tools.mesh_to_cutoff(a, init_mesh)
    ke_cutoff_max = max([ke.max() for ke in kecuts_pgto])
    # Geometric sequence of window boundaries covering [0, ke_cutoff_max]
    ke1 = ke_cutoff_min.min()
    ke_delimeter = [0, ke1]
    while ke1 < ke_cutoff_max:
        ke1 *= KE_RATIO
        ke_delimeter.append(ke1)

    tasks = []
    for ke0, ke1 in zip(ke_delimeter[:-1], ke_delimeter[1:]):
        # shells which have high exps (small rcut)
        shls_dense = [ib for ib, ke in enumerate(kecuts_pgto)
                      if numpy.any((ke0 < ke) & (ke <= ke1))]
        if len(shls_dense) == 0:
            continue

        mesh = tools.cutoff_to_mesh(a, ke1)
        if TO_EVEN_GRIDS:
            # Round each mesh component up to the nearest even number.  mesh
            # has one entry per lattice direction, so it must not be passed
            # through int() (that raises TypeError for a 3-element array).
            mesh = (mesh+1)//2 * 2

        if numpy.all(mesh >= fft_mesh):
            # Including all rest shells
            shls_dense = [ib for ib, ke in enumerate(kecuts_pgto)
                          if numpy.any(ke0 < ke)]
            cell_dense, ao_idx_dense = make_cell_dense_exp(
                shls_dense, ke0, ke_cutoff_max+1)[:2]
        else:
            cell_dense, ao_idx_dense = make_cell_dense_exp(
                shls_dense, ke0, ke1)[:2]

        cell_dense.mesh = mesh = numpy.min([mesh, fft_mesh], axis=0)

        grids_dense = gen_grid.UniformGrids(cell_dense)
        grids_dense.ao_idx = ao_idx_dense
        #grids_dense.rcuts_pgto = [rcuts_pgto[i] for i in shls_dense]

        # shells which have low exps (big rcut)
        shls_sparse = [ib for ib, ke in enumerate(kecuts_pgto)
                       if numpy.any(ke <= ke0)]
        if len(shls_sparse) == 0:
            cell_sparse = None
            ao_idx_sparse = []
        else:
            cell_sparse, ao_idx_sparse = make_cell_sparse_exp(shls_sparse, ke0, ke1)
            cell_sparse.mesh = mesh

        if cell_sparse is None:
            grids_sparse = None
        else:
            grids_sparse = gen_grid.UniformGrids(cell_sparse)
            grids_sparse.ao_idx = ao_idx_sparse

        log.debug('mesh %s nao dense/sparse %d %d  rcut %g',
                  mesh, len(ao_idx_dense), len(ao_idx_sparse), cell_dense.rcut)

        tasks.append([grids_dense, grids_sparse])
        if numpy.all(mesh >= fft_mesh):
            break
    return tasks
1677
def _primitive_gto_cutoff(cell, precision=None):
    '''Estimate real-space and kinetic-energy cutoffs of every primitive GTO.

    The cutoff radius is the distance beyond which a primitive decays to a
    value below the required precision.

    Args:
        cell : pbc Cell object
        precision : float, optional
            Target precision.  Defaults to ``cell.precision * EXTRA_PREC``.

    Returns:
        (rcut, ke_cutoff): two lists with one entry per shell.  Each entry is
        an array with one value per primitive of that shell.
    '''
    if precision is None:
        precision = cell.precision * EXTRA_PREC
    log_prec = min(numpy.log(precision), 0)

    rcut = []
    ke_cutoff = []
    for ish in range(cell.nbas):
        ang = cell.bas_angular(ish)
        exps = cell.bas_exp(ish)
        coeffs = abs(cell.bas_ctr_coeff(ish)).max(axis=1)
        log_coeff = numpy.log(4*numpy.pi*coeffs)

        # Two steps of the fixed-point iteration
        #     r <- sqrt(((l+2) ln r + ln(4 pi c) - ln prec) / alpha)
        # starting from r = 5.
        radius = 5.
        for _ in range(2):
            radius = (((ang+2)*numpy.log(radius) + log_coeff - log_prec) / exps)**.5
        rcut.append(radius)

        # Errors in the total number of electrons were observed with the
        # default precision: the energy cutoff (integration mesh) was not
        # enough for the desired accuracy, so the precision is scaled by 0.1.
        ke_cutoff.append(gto.cell._estimate_ke_cutoff(exps, ang, coeffs,
                                                      precision*0.1))
    return rcut, ke_cutoff
1703
1704
class MultiGridFFTDF(fft.FFTDF):
    '''FFTDF object whose get_pp, get_nuc, get_rho and the J part of get_jk
    are replaced by the multigrid implementations of this module.

    Attributes:
        tasks : list or None
            Multigrid tasks from ``multi_grids_tasks``; None until ``build``
            is called and after ``reset``.
    '''
    def __init__(self, cell, kpts=None):
        # Create a fresh array per instance.  A numpy.zeros default in the
        # signature is evaluated only once and the same array object would be
        # handed to every instance.
        if kpts is None:
            kpts = numpy.zeros((1,3))
        fft.FFTDF.__init__(self, cell, kpts)
        self.tasks = None
        self._keys = self._keys.union(['tasks'])

    def build(self):
        '''Generate the multigrid tasks for ``self.cell`` on ``self.mesh``.'''
        self.tasks = multi_grids_tasks(self.cell, self.mesh, self.verbose)
        return self

    def reset(self, cell=None):
        '''Drop the cached multigrid tasks and reset the underlying FFTDF.'''
        self.tasks = None
        # self must be passed explicitly to the unbound base-class method;
        # otherwise ``cell`` would be bound to ``self``.
        return fft.FFTDF.reset(self, cell)

    get_pp = get_pp
    get_nuc = get_nuc

    def get_jk(self, dm, hermi=1, kpts=None, kpts_band=None,
               with_j=True, with_k=True, exxdiv='ewald', **kwargs):
        '''Return ``(vj, vk)``.

        J is computed by the multigrid ``get_j_kpts``.  HFX (K) is not
        supported by the multigrid scheme and is delegated to
        ``pyscf.pbc.df.fft_jk``.  Matrices not requested through ``with_j``
        / ``with_k`` are returned as None.  For a single k-point (``kpts`` of
        shape (3,)) with ``kpts_band`` None, the k-point axis of vj is
        dropped.
        '''
        from pyscf.pbc.df import fft_jk
        if with_k:
            logger.warn(self, 'MultiGridFFTDF does not support HFX. '
                        'HFX is computed by FFTDF.get_k_kpts function.')

        if kpts is None:
            if numpy.all(self.kpts == 0): # Gamma-point J/K by default
                kpts = numpy.zeros(3)
            else:
                kpts = self.kpts
        else:
            kpts = numpy.asarray(kpts)

        vj = vk = None
        if kpts.shape == (3,):
            if with_k:
                vk = fft_jk.get_jk(self, dm, hermi, kpts, kpts_band,
                                   False, True, exxdiv)[1]
            if with_j:
                vj = get_j_kpts(self, dm, hermi, kpts.reshape(1,3), kpts_band)
                if kpts_band is None:
                    vj = vj[...,0,:,:]
        else:
            if with_k:
                vk = fft_jk.get_k_kpts(self, dm, hermi, kpts, kpts_band, exxdiv)
            if with_j:
                vj = get_j_kpts(self, dm, hermi, kpts, kpts_band)
        return vj, vk

    get_rho = get_rho
1753
1754
def multigrid(mf):
    '''Switch the DFT object ``mf`` to the MultiGridFFTDF integration backend.

    A MultiGridFFTDF is built for ``mf.cell``.  All attributes of the previous
    ``mf.with_df`` are copied onto it, except that the new object keeps its
    own ``_keys``.  ``mf`` is modified in place and returned.
    '''
    new_df = MultiGridFFTDF(mf.cell)
    old_df = mf.with_df
    mf.with_df = new_df

    own_keys = new_df._keys
    new_df.__dict__.update(old_df.__dict__)
    new_df._keys = own_keys
    return mf
1764
1765
def _pgto_shells(cell):
    '''Count primitive GTO shells: the sum of the per-shell primitive counts
    stored in the NPRIM_OF column of ``cell._bas``.'''
    nprim_per_shell = cell._bas[:, NPRIM_OF]
    return nprim_per_shell.sum()
1768
def _take_4d(a, indices):
    '''Gather a sub-block of the 4-D array ``a``.

    ``indices`` holds one entry per axis: None (take the whole axis) or a
    sequence of integer positions, where negative positions are wrapped by
    ``a.shape[i]``.  ``a`` is viewed as a 2-D array of shape
    (n0*n1, n2*n3), and the outer products of the index lists of axes (0, 1)
    and (2, 3) are passed to ``lib.take_2d``.  The result is reshaped to
    ``(len(idx0), len(idx1), len(idx2), len(idx3))``.  Assuming
    ``lib.take_2d(m, r, c)`` returns ``m[r][:,c]``, this is
    ``a[numpy.ix_(idx0, idx1, idx2, idx3)]``.

    The caller's index arrays are never modified.
    '''
    a_shape = a.shape
    ranges = []
    for i, s in enumerate(indices):
        if s is None:
            idx = numpy.arange(a_shape[i], dtype=numpy.int32)
        else:
            # numpy.array always copies.  numpy.asarray would return the
            # caller's array itself when it is already int32, and the
            # in-place wrap of negative indices below would then mutate it.
            idx = numpy.array(s, dtype=numpy.int32)
            idx[idx < 0] += a_shape[i]
        ranges.append(idx)
    # Flat row/column indices into the (n0*n1, n2*n3) view of a
    idx = ranges[0][:,None] * a_shape[1] + ranges[1]
    idy = ranges[2][:,None] * a_shape[3] + ranges[3]
    a = a.reshape(a_shape[0]*a_shape[1], a_shape[2]*a_shape[3])
    out = lib.take_2d(a, idx.ravel(), idy.ravel())
    return out.reshape([len(s) for s in ranges])
1784
def _takebak_4d(out, a, indices):
    '''Scatter the 4-D block ``a`` back into the 4-D array ``out``.

    ``indices`` holds one entry per axis: None (positions 0 ..
    ``a.shape[i]``-1) or a sequence of ``a.shape[i]`` integer positions in
    ``out``, where negative positions are wrapped by ``out.shape[i]``.  Both
    arrays are viewed as 2-D (axes (0, 1) and (2, 3) merged), and the write
    is done by ``lib.takebak_2d``.  NOTE(review): takebak_2d presumably
    accumulates (``out[...] += a``); confirm against pyscf.lib.

    Returns the 2-D reshaped ``out``.  NOTE(review): if ``out`` cannot be
    reshaped as a view, ``reshape`` returns a copy and the caller's ``out`` is
    left unchanged; rely on the return value.

    The caller's index arrays are never modified.
    '''
    out_shape = out.shape
    a_shape = a.shape
    ranges = []
    for i, s in enumerate(indices):
        if s is None:
            idx = numpy.arange(a_shape[i], dtype=numpy.int32)
        else:
            # Copy (numpy.array) so that wrapping negative indices in place
            # does not modify an int32 index array owned by the caller.
            idx = numpy.array(s, dtype=numpy.int32)
            idx[idx < 0] += out_shape[i]
        assert(len(idx) == a_shape[i])
        ranges.append(idx)
    # Flat row/column indices into the (n0*n1, n2*n3) view of out
    idx = ranges[0][:,None] * out_shape[1] + ranges[1]
    idy = ranges[2][:,None] * out_shape[3] + ranges[3]
    nx = idx.size
    ny = idy.size
    out = out.reshape(out_shape[0]*out_shape[1], out_shape[2]*out_shape[3])
    lib.takebak_2d(out, a.reshape(nx,ny), idx.ravel(), idy.ravel())
    return out
1804
def _take_5d(a, indices):
    '''Gather from a 5-D array whose first two axes are taken whole.

    The first two axes of ``a`` are merged into one, so ``indices[0]`` and
    ``indices[1]`` are ignored.  ``indices[2:]`` select along the remaining
    axes exactly as in ``_take_4d``.  The returned array keeps the first two
    axes merged (4-D result).  ``indices`` must be a tuple.
    '''
    n0, n1 = a.shape[:2]
    merged = a.reshape((n0 * n1,) + a.shape[2:])
    return _take_4d(merged, (None,) + indices[2:])
1810
def _takebak_5d(out, a, indices):
    '''Scatter a 5-D block ``a`` back into the 5-D array ``out``.

    The first two axes of both arrays are merged and written whole, so
    ``indices[0]`` and ``indices[1]`` are ignored.  ``indices[2:]`` address
    the remaining axes as in ``_takebak_4d``, whose return value is passed
    through.  ``indices`` must be a tuple.
    '''
    merged_a = a.reshape((a.shape[0] * a.shape[1],) + a.shape[2:])
    merged_out = out.reshape((out.shape[0] * out.shape[1],) + out.shape[2:])
    return _takebak_4d(merged_out, merged_a, (None,) + indices[2:])
1818
1819
if __name__ == '__main__':
    from pyscf.pbc import dft

    # Demo: compare the KRKS Veff from the default FFTDF with the one from
    # the multigrid backend for an 8-atom carbon cell.
    numpy.random.seed(22)
    diamond_atoms = '''C     0.      0.      0.
                  C     0.8917  0.8917  0.8917
                  C     1.7834  1.7834  0.
                  C     2.6751  2.6751  0.8917
                  C     1.7834  0.      1.7834
                  C     2.6751  0.8917  2.6751
                  C     0.      1.7834  1.7834
                  C     0.8917  2.6751  2.6751'''
    # Alternative settings that can be swapped in:
    #   basis = 'sto3g' / 'ccpvdz' / 'gth-szv'
    #   verbose = 5, mesh = [15]*3, precision = 1e-6
    cell = gto.M(a=numpy.eye(3)*3.5668,
                 atom=diamond_atoms,
                 basis='gth-dzvp',
                 pseudo='gth-pade')
    # Build the multigrid task partition with verbose=5 (presumably so the
    # per-task debug lines are logged).
    multi_grids_tasks(cell, cell.mesh, 5)

    nao = cell.nao_nr()
    numpy.random.seed(1)
    kpts = cell.make_kpts([3,1,1])

    # Random symmetric density matrices, one per k-point
    dm = numpy.random.random((len(kpts), nao, nao)) * .2 + numpy.eye(nao)
    dm = dm + dm.transpose(0,2,1)

    mf = dft.KRKS(cell)
    v_ref = mf.get_veff(cell, dm, kpts=kpts)
    v_multigrid = multigrid(mf).get_veff(cell, dm, kpts=kpts)
    print(abs(v_ref-v_multigrid).max())
1856
1857