1from h5py import File
2import numpy
3from pyscf.pbc import gto, tools
4from pyscf.pbc.dft import numint
5from pyscf import gto as molgto
6import os
7import sys
8import numpy
9from mpi4py import MPI
10from afqmctools.utils.gto_basis_utils import extend_gto_id
11try:
12    from pyscf_driver import (pyscf_driver_init, pyscf_driver_get_info, pyscf_driver_end,
13                    pyscf_driver_mp2,pyscf_driver_hamil,pyscf_driver_mp2no)
14except ImportError:
15    print("Warning: module pyscf_driver not found. AFQMC QE converter required to "
16          "use this module.")
17
def make_cell(latt,sp_label,atid,atpos,basis_='gthdzvp',pseudo_='gthpbe',mesh=None,prec=1e-8):
    """Generates a PySCF gto.Cell object.

    Parameters
    ----------
    latt: (3,3) floating point array (or nested sequence).
      Lattice vectors in Bohr. Used to define cell.a
    sp_label: array of strings.
      Array containing species labels/symbols.
    atid: integer array.
      Array that contains the mapping between the atoms in the unit cell and their
      species label (as defined by sp_label). Entries are 1-based: atom i is
      labelled sp_label[atid[i]-1].
    atpos: (natoms,3) floating point array (or nested sequence).
      Atomic positions, one row per atom. Written to cell.atom; cell.unit is set to 'B'.
    basis_: string. Default: 'gthdzvp'
      Basis set string. Used to define cell.basis.
    pseudo_: string. Default: 'gthpbe'
      Pseudopotential string. Used to define cell.pseudo.
    mesh: 3-d array. Default: None
      Used to define cell.mesh. If None, cell.mesh is not set here.
    prec: floating point. Default: 1e-8
      Used to define cell.precision

    Returns
    -------
    cell: PySCF gto.Cell object (already built).
    """
    # Accept plain lists/tuples as well as ndarrays; the indexing below needs arrays.
    latt = numpy.asarray(latt)
    atpos = numpy.asarray(atpos)
    assert(len(atid) == len(atpos))
    assert(atpos.ndim == 2)
    assert(atpos.shape[1] == 3)

    cell = gto.Cell()
    cell.a = '''
       {} {} {}
       {} {} {}
       {} {} {}'''.format(*[latt[i,j] for i in range(3) for j in range(3)])
    # One "label x y z" line per atom; atid holds 1-based indices into sp_label.
    cell.atom = ''.join(
        '''{} {} {} {} \n'''.format(sp_label[sp-1], pos[0], pos[1], pos[2])
        for sp, pos in zip(atid, atpos))
    cell.basis = basis_
    cell.pseudo = pseudo_
    if mesh is not None:
        cell.mesh = mesh
    cell.verbose = 5
    cell.unit = 'B'
    cell.precision = prec
    cell.build()
    return cell
67
def write_esh5_orbitals(cell, name, kpts=None):
    """Writes periodic AO basis to hdf5 file.

    For every k-point, each AO is evaluated on the uniform real-space grid of
    cell.mesh, multiplied by exp(-i k.r), Fourier transformed with
    pyscf.pbc.tools.fft and stored as dataset 'kp<ik>_b<i>' in group 'OrbsG'.

    Parameters
    ----------
    cell: PySCF gto.Cell object
      PySCF cell object which contains information of the system, including
      AO basis set, FFT mesh, unit cell information, etc.
    name: string
      Name of hdf5 file. The file is opened in 'w' mode (an existing file is overwritten).
    kpts: array. Default: None, equivalent to numpy.zeros((1,3),dtype=numpy.float64)
      K-point array of dimension (nkpts, 3)

    """

    def to_qmcpack_complex(array):
        # Reinterpret each value as a trailing (real, imag) pair of float64.
        # NOTE(review): assumes the array is contiguous complex128 -- confirm
        # against the dtype returned by tools.fft.
        return array.view(numpy.float64).reshape(array.shape+(2,))

    if kpts is None:
        kpts = numpy.zeros((1,3),dtype=numpy.float64)
    kpts = numpy.asarray(kpts)
    nkpts = len(kpts)

    nao = cell.nao_nr()
    # Same number of orbitals (all AOs) at every k-point.
    norbs = numpy.full((nkpts,), nao, dtype=int)
    coords = cell.gen_uniform_grids(cell.mesh)
    nnr = cell.mesh[0]*cell.mesh[1]*cell.mesh[2]
    ni = numint.KNumInt()

    # Context manager guarantees the file is closed even if an evaluation fails.
    with File(name,'w') as fh5:
        grp = fh5.create_group("OrbsG")
        grp.create_dataset("reciprocal_vectors", data=cell.reciprocal_vectors())
        grp.create_dataset("number_of_kpoints", data=nkpts)
        grp.create_dataset("kpoints", data=kpts)
        grp.create_dataset("number_of_orbitals", data=norbs)
        grp.create_dataset("fft_grid", data=cell.mesh)
        grp.create_dataset("grid_type", data=int(0))
        for (ik,k) in enumerate(kpts):
            ao = ni.eval_ao(cell, coords, k)[0]
            fac = numpy.exp(-1j * numpy.dot(coords, k))
            for i in range(norbs[ik]):
                aoi = fac * numpy.asarray(ao[:,i].T, order='C')
                aoi_G = tools.fft(aoi, cell.mesh)
                # Reverse the axis order of the mesh before flattening.
                aoi_G = aoi_G.reshape(cell.mesh).transpose(2,1,0).reshape(nnr)
                grp.create_dataset('kp'+str(ik)+'_b'+str(i), data=to_qmcpack_complex(aoi_G))
116
def make_image_comm(nimage, comm=MPI.COMM_WORLD):
    """Splits a communicator into image communicators, consistent with QE partitioning.
       comm.size//nimage consecutive ranks in comm belong to the same image communicator.
       The number of distinct image communicators is nimage.

    Parameters
    ----------
    nimage: integer
      Number of image communicators, must divide comm.size.
    comm: mpi4py MPI communicator. Default: MPI.COMM_WORLD
      A valid mpi4py communicator.

    Returns
    -------
    intra_image: mpi4py MPI communicator.
      Communicator between mpi tasks within an image.
    inter_image: mpi4py MPI communicator.
      Communicator between mpi tasks on different images, but having the same rank in the image.
    """
    parent_nproc = comm.size
    parent_mype = comm.rank
    assert( parent_nproc%nimage == 0 )
    # Integer division: Split() colors must be integers, and "/" yields a
    # float on Python 3.
    nproc_image = parent_nproc // nimage
    # my_image_id: which image this rank belongs to; me_image: rank inside it.
    my_image_id, me_image = divmod(parent_mype, nproc_image)
    intra_image = comm.Split(my_image_id,comm.rank)
    inter_image = comm.Split(me_image,comm.rank)
    return intra_image, inter_image
145
146# put these in modules later
def qe_driver_init(norb, qe_prefix, qe_outdir, atm_labels,
                   intra_image=MPI.COMM_WORLD, inter_image=None,
                   npools=1, outdir='./qedrv',
                   set_soft_links=True, verbose=True,add_image_tag=True):
    """Initializes the QE driver. Must be called before any routine that calls
       the QE driver is executed. Requires a pre-existing QE successful run.

       Synchronizes on MPI.COMM_WORLD (barrier) after the driver directory has
       been prepared by rank 0 of intra_image.

    Parameters
    ----------
    norb: integer
        Number of orbitals read from QE calculation.
    qe_prefix: string
        prefix parameter from QE run.
    qe_outdir: string
        outdir parameter from QE run. (location of QE files).
    atm_labels: array of strings
        Array containing the species labels.
    intra_image: mpi4py communicator. Default: MPI.COMM_WORLD
        Intra image communicator. Its size must be divisible by npools.
    inter_image: mpi4py communicator. Default: None
        Inter image communicator. If None, a single image (rank 0, size 1) is assumed.
    npools: integer. Default: 1
        Number of QE pools used in the driver.
    outdir: string. Default: ./qedrv
        Output directory of the driver. Does not need to be the same as the QE parameter.
    set_soft_links: Bool. Default: True
        If true, soft links to QE files/folders from qe_outdir will be placed in outdir.
    verbose: Bool. Default: True
        Sets verbosity in driver.
    add_image_tag: Bool. Default True.
        If True, outdir is modified by adding a tag that identifies the image.
        This is needed if running with multiple images simultaneously,
        otherwise the files from different images might conflict with each other.
    Returns
    -------
    qe_info: Python Dictionary
      Dictionary containing all stored information about the QE driver.
      Contents:
        'species' : string array    # array with species labels.
        'nsp' : integer,            # number of species
        'nat' : integer,            # number of atoms
        'at_id' :  integer array    # array with the ids of atoms in the unit cell.
        'at_pos' : (nat,3) fp array # array with atom positions
        'nkpts' : integer           # number of kpoints
        'kpts' : (nkpts,3) fp array # k-points
        'latt' : (3,3) fp array     # lattice vectors
        'npwx' : integer            # npwx parameter from QE.
        'mesh' : 3D integer array   # FFT mesh
        'ngm' : integer             # ngm parameter from QE.
        'outdir' : string           # Location of folder with driver files.
    """
    assert(intra_image.size%npools==0)
    intra_rank = intra_image.rank
    intra_size = intra_image.size

    if inter_image is not None:
        inter_rank = inter_image.rank
        inter_size = inter_image.size
    else:
        inter_rank = 0
        inter_size = 1

    # Driver directory (always ends with '/'), optionally tagged with the image id.
    fname = outdir
    if add_image_tag:
        fname += '.'+str(inter_rank)+'/'
    else:
        fname += '/'
    if intra_rank == 0:
        if set_soft_links:
            # NOTE(review): these shell commands are best-effort (failures are
            # ignored) and break on paths containing spaces or shell
            # metacharacters. The link targets are relative ('./'+qe_outdir),
            # so they presumably resolve relative to the driver directory --
            # confirm before replacing with os.makedirs/os.symlink.
            os.system('mkdir '+fname)
            os.system('ln -s ./'+qe_outdir+'/'+qe_prefix+'.xml '+fname+'/'+qe_prefix+'.xml')
            os.system('ln -s ./'+qe_outdir+'/'+qe_prefix+'.save/ '+fname+'/'+qe_prefix+'.save')
    # Make sure the directory and links exist before any rank starts the driver.
    MPI.COMM_WORLD.barrier()

    # initialize driver:
    # Integer division: the number of processes per pool must be an integer
    # ("/" yields a float on Python 3).
    nkpts, nat, nsp, npwx, ngm, mesh = pyscf_driver_init(inter_size, npools, intra_size//npools,
                                                         norb, qe_prefix, fname, verbose)
    atms = numpy.array(atm_labels)  # don't know how to return an array of strings
    atom_ids,atom_pos,kpts,latt = pyscf_driver_get_info(nat,nsp,nkpts)#,atms)
    # The driver arrays are transposed so that rows index atoms / k-points /
    # lattice vectors, as documented in the qe_info contents above.
    atom_pos=atom_pos.T
    kpts=kpts.T
    latt=latt.T
    qe_info = {'species' : atms,
               'nsp' : nsp,
               'nat' : nat,
               'at_id' : atom_ids,
               'at_pos' : atom_pos,
               'nkpts' : nkpts,
               'kpts' : kpts,
               'latt' : latt,
               'npwx' : npwx,
               'mesh' : mesh,
               'ngm' : ngm,
               'outdir' : fname
               }
    if verbose and (MPI.COMM_WORLD.rank==0):
        print("# species = {}".format(qe_info['nsp']))
        print("# atoms = {}".format(qe_info['nat']))
        print("# kpts = {}".format(qe_info['nkpts']))
        print("FFT mesh = {} {} {}".format(qe_info['mesh'][0],qe_info['mesh'][1],qe_info['mesh'][2]))
        print(" Atom species: ")
        print(qe_info['species'])
        print(" Atom positions: ")
        print(qe_info['at_pos'])
        print(" Lattice: ")
        print(qe_info['latt'])
        print(" K-points: ")
        print(qe_info['kpts'])

    return qe_info
261
def qe_driver_end():
    """Shuts down the QE driver and lets it clean up.

       Rank 0 of MPI.COMM_WORLD prints a closing message; every rank then
       calls pyscf_driver_end(). Calling any driver routine afterwards is
       undefined.
    """
    on_root = (MPI.COMM_WORLD.rank == 0)
    if on_root:
        print(" Closing QE driver.")
    pyscf_driver_end()
269
def gen_qe_gto(qe_info,bset,x=[],
               fname='pyscf.orbitals.h5',prec=1e-12):
    """ Builds a gaussian basis from bset and writes the periodic AO basis to hdf5.

        For every species in qe_info['species'] the basis string produced by
        bset.basis_str(species, x) is parsed with pyscf.gto.parse. A gto.Cell
        consistent with the QE calculation (lattice, atoms, FFT mesh) is then
        created with make_cell and handed to write_esh5_orbitals, together with
        the k-points of the QE run.

    Parameters
    ----------
    qe_info: Python Dictionary.
        Dictionary with information from QE calculation, generated by qe_driver_init.
    bset: Object of type OptimizableBasisSet.
        Contains information about a (possibly dynamic) basis set.
    x: fp array. Default: [] (no variable parameters in bset)
        Array with variational parameters in the bset object. Its length must
        equal bset.number_of_params.
    fname: string. Default: 'pyscf.orbitals.h5'
        Name of hdf5 file. It is appended to qe_info['outdir'].
    prec: floating point number. Default: 1e-12
        Precision passed to make_cell (cell.precision).

    Returns
    -------
    nao: integer
        Number of atomic orbitals generated.
    """
    assert(len(x) == bset.number_of_params)

    # One parsed basis per species label.
    basis = {atm: molgto.parse( bset.basis_str(atm,x) )
             for atm in qe_info['species']}

    cell = make_cell(qe_info['latt'],
                     qe_info['species'],
                     qe_info['at_id'],
                     qe_info['at_pos'],
                     basis_=basis,
                     mesh=qe_info['mesh'],
                     prec=prec)
    nao = cell.nao_nr()

    h5_path = qe_info['outdir']+fname
    write_esh5_orbitals(cell, h5_path, kpts=qe_info['kpts'])
    return nao
310
def qe_driver_MP2(qe_info,out_prefix='pyscf_drv',
                        diag_type='keep_occ',
                        nread_from_h5=0,h5_add_orbs='',
                        eigcut=1e-3,nextracut=1e-6,kappa=0.0,regp=0):
    """ Calls the MP2 routine in the driver (pyscf_driver_mp2).

    Parameters
    ----------
    qe_info: Python Dictionary.
        Dictionary with information from QE calculation, generated by qe_driver_init.
        Not used by this routine.
    out_prefix: string. Default: 'pyscf_drv'
        Prefix used in all the files generated by the driver.
    diag_type: string. Default: 'keep_occ'
        Defines the type of HF diagonalization performed before the MP2 calculation.
        Options:
            'keep_occ': Only the virtual orbitals/eigenvalues are calculated.
                        Occupied orbitals/eigenvalues are kept from the QE calculation.
            'full': All orbitals/eigenvalues are recalculated.
            'fullpw': A basis set is generated that contains all the plane waves
                      below the QE wfn cutoff. The HF eigenvalues/orbitals and MP2NO
                      are calculated in this basis. With this option nread_from_h5,
                      h5_add_orbs, eigcut and nextracut are ignored (0, '', 0.0 and
                      0.0 are sent to the driver instead).
    nread_from_h5: integer. Default: 0
        Number of orbitals to read from h5_add_orbs.
    h5_add_orbs: string. Default: ''
        Name of hdf5 file with additional orbitals to add to the basis set.
    eigcut: fp number. Default: 1e-3
        Cutoff used during the generation of the spin independent basis in UHF/GHF
        calculations. Only the eigenvalues of the overlap matrix (alpha/beta)
        above this cutoff are kept in the calculation. In order to reproduce
        the UHF/GHF energy accurately, this number must be set to a small value (e.g. 1e-8).
    nextracut: fp number. Default: 1e-6
        Cutoff used when adding states from h5_add_orbs to the basis set.
        When a new state from the file is being added to the orbital set,
        the component along all current orbitals in the set is removed.
        The resulting (orthogonal) state is added only if the norm of the unnormalized
        orbital is larger than nextracut (state is afterwards normalized).
        This is used as a way to remove linear dependencies from the basis set.
    kappa: fp number. Default: 0.0
        Forwarded unchanged to the driver.
    regp: integer. Default: 0
        Forwarded unchanged to the driver.

    Returns
    -------
    emp2:
        Value returned by pyscf_driver_mp2 (presumably the MP2 energy -- confirm
        against the driver).
    """
    if diag_type=='fullpw':
        # No extra orbitals and no cutoffs are sent in the plane-wave basis.
        n_extra, extra_file, ovlp_cut, extra_cut = 0, '', 0.0, 0.0
    else:
        n_extra, extra_file, ovlp_cut, extra_cut = (nread_from_h5, h5_add_orbs,
                                                    eigcut, nextracut)
    return pyscf_driver_mp2(out_prefix, True, diag_type,
                            n_extra, extra_file, ovlp_cut,
                            extra_cut, kappa, regp)
358
def qe_driver_MP2NO(qe_info,out_prefix='pyscf_drv',
                        appnos=False,
                        diag_type='keep_occ',
                        nread_from_h5=0,h5_add_orbs='',nskip=0,
                        eigcut=1e-3,nextracut=1e-6,mp2noecut=1e-6,kappa=0.0,regp=0):
    """ Calls the MP2NO routine in the driver (pyscf_driver_mp2no).

    Parameters
    ----------
    qe_info: Python Dictionary.
        Dictionary with information from QE calculation, generated by qe_driver_init.
        Not used by this routine.
    out_prefix: string. Default: 'pyscf_drv'
        Prefix used in all the files generated by the driver.
    appnos: Bool. Default: False.
        If True, generates approximate natural orbitals.
    diag_type: string. Default: 'keep_occ'
        Defines the type of HF diagonalization performed before the MP2 calculation.
        Options:
            'keep_occ': Only the virtual orbitals/eigenvalues are calculated.
                        Occupied orbitals/eigenvalues are kept from the QE calculation.
            'full': All orbitals/eigenvalues are recalculated.
            'fullpw': A basis set is generated that contains all the plane waves
                      below the QE wfn cutoff. The HF eigenvalues/orbitals and MP2NO
                      are calculated in this basis. With this option nread_from_h5,
                      h5_add_orbs, eigcut and nextracut are ignored (0, '', 0.0 and
                      0.0 are sent to the driver instead).
    nread_from_h5: integer. Default: 0
        Number of orbitals to read from h5_add_orbs.
    h5_add_orbs: string. Default: ''
        Name of hdf5 file with additional orbitals to add to the basis set.
    nskip: integer. Default: 0
        Number of states above the HOMO state of the solid to skip
        during the calculation of MP2 NOs. This can be used to avoid divergencies
        in metals. The assumption being that these states will be included in the
        orbital set directly.
    eigcut: fp number. Default: 1e-3
        Cutoff used during the generation of the spin independent basis in UHF/GHF
        calculations. Only the eigenvalues of the overlap matrix (alpha/beta)
        above this cutoff are kept in the calculation. In order to reproduce
        the UHF/GHF energy accurately, this number must be set to a small value (e.g. 1e-8).
    nextracut: fp number. Default: 1e-6
        Cutoff used when adding states from h5_add_orbs to the basis set.
        When a new state from the file is being added to the orbital set,
        the component along all current orbitals in the set is removed.
        The resulting (orthogonal) state is added only if the norm of the unnormalized
        orbital is larger than nextracut (state is afterwards normalized).
        This is used as a way to remove linear dependencies from the basis set.
    mp2noecut: fp number. Default: 1e-6
        Cutoff used when adding natural orbitals from the MP2 RDM,
        only states with eigenvalue > mp2noecut will be kept.
        If this number is < 0.0, then a specific number of states is kept and is
        given by nint(-mp2noecut).
    kappa: fp number. Default: 0.0
        Forwarded unchanged to the driver.
    regp: integer. Default: 0
        Forwarded unchanged to the driver.

    """
    if diag_type=='fullpw':
        # No extra orbitals and no cutoffs are sent in the plane-wave basis.
        nread_from_h5, h5_add_orbs, eigcut, nextracut = 0, '', 0.0, 0.0

    # The cutoffs are passed positionally as (eigcut, nextracut, mp2noecut).
    # Previously the non-'fullpw' call passed (eigcut, mp2noecut, nextracut)
    # while the 'fullpw' call used the order below, so the two user cutoffs
    # were swapped in one of the branches (hidden by their equal defaults).
    # NOTE(review): the order below follows the 'fullpw' call, this function's
    # signature and qe_driver_MP2 -- confirm against the argument list of
    # pyscf_driver_mp2no.
    pyscf_driver_mp2no(out_prefix,True,diag_type,appnos,
                       nread_from_h5,h5_add_orbs,nskip,eigcut,
                       nextracut,mp2noecut,kappa,regp)
419
def qe_driver_hamil(qe_info,out_prefix='pwscf',
                    nread_from_h5=0,h5_add_orbs='',ndet=1,eigcut=1e-3,
                    nextracut=1e-6,thresh=1e-5,ncholmax=15,get_hf=True,
                    get_mp2=True,update_qe_bands=False):
    """ Calls the hamiltonian routine in the driver (pyscf_driver_hamil).

    Parameters
    ----------
    qe_info: Python Dictionary.
        Dictionary with information from QE calculation, generated by qe_driver_init.
        Not used by this routine.
    out_prefix: string. Default: 'pwscf'
        Prefix used in all the files generated by the driver.
    nread_from_h5: integer. Default: 0
        Number of orbitals to read from h5_add_orbs.
    h5_add_orbs: string. Default: ''
        Name of hdf5 file with additional orbitals to add to the basis set.
    ndet: integer. Default: 1
        Maximum number of determinants allowed in the trial wave-function.
    eigcut: fp number. Default: 1e-3
        Cutoff used during the generation of the spin independent basis in UHF/GHF
        calculations. Only the eigenvalues of the overlap matrix (alpha/beta)
        above this cutoff are kept in the calculation. In order to reproduce
        the UHF/GHF energy accurately, this number must be set to a small value (e.g. 1e-8).
    nextracut: fp number. Default: 1e-6
        Cutoff used when adding states from h5_add_orbs to the basis set.
        When a new state from the file is being added to the orbital set,
        the component along all current orbitals in the set is removed.
        The resulting (orthogonal) state is added only if the norm of the unnormalized
        orbital is larger than nextracut (state is afterwards normalized).
        This is used as a way to remove linear dependencies from the basis set.
    thresh: floating point. Default: 1e-5
        Value used to stop the iterative calculation of Cholesky vectors. The iterations
        stop when the error on a diagonal element falls below this value.
    ncholmax: integer. Default: 15
        Maximum number of Cholesky vectors allowed (in units of the number of orbitals).
        If the iterative calculation has not converged when this number of Cholesky vectors
        is found, the calculation stops.
    get_hf: Bool. Default: True
        If True, calculate the HF eigenvalues/eigenvectors.
    get_mp2: Bool. Default: True
        If True, calculate the MP2 energy. (If True, get_hf will be set to true.)
    update_qe_bands: Bool. Default: False
        If True, the orbitals in the QE restart file will be overwritten with the
        basis set generated by the driver. Orbitals beyond norb (set in qe_driver_init)
        will be left unmodified.
    Returns
    -------
    ehf: floating point
        The HF energy on this basis. (return 0.0 if not requested)
    emp2: floating point
        The MP2 energy on this basis. (return 0.0 if not requested)
    """
    # Requesting MP2 forces the HF step to be requested as well.
    want_hf = True if get_mp2 else get_hf
    energies = pyscf_driver_hamil(out_prefix, nread_from_h5, h5_add_orbs,
                                  ndet, eigcut, nextracut, thresh, ncholmax,
                                  want_hf, get_mp2, update_qe_bands)
    ehf, emp2 = energies
    return ehf, emp2
478