1##################################################################
2##  (c) Copyright 2015-  by Jaron T. Krogel                     ##
3##################################################################
4
5
6#====================================================================#
7#  qmcpack.py                                                        #
8#    Nexus interface with the QMCPACK simulation code.               #
9#                                                                    #
10#                                                                    #
11#  Content summary:                                                  #
12#    Qmcpack                                                         #
13#      Simulation class for QMCPACK.                                 #
14#      Handles incorporation of structure, orbital, and Jastrow      #
15#        data from other completed simulations.                      #
16#                                                                    #
17#    generate_qmcpack                                                #
18#      User-facing function to create QMCPACK simulation objects.    #
19#                                                                    #
20#    generate_cusp_correction                                        #
21#      User-facing function to run QMCPACK as an intermediate tool   #
22#        to add cusps to Gaussian orbitals coming from GAMESS.       #
23#                                                                    #
24#====================================================================#
25
26
import os
import shutil
from numpy import array,dot,pi
from numpy.linalg import inv,norm
from generic import obj
from periodic_table import periodic_table
from physical_system import PhysicalSystem
from simulation import Simulation,NullSimulationAnalyzer
from qmcpack_input import QmcpackInput,generate_qmcpack_input
from qmcpack_input import TracedQmcpackInput
from qmcpack_input import loop,linear,cslinear,vmc,dmc,collection,determinantset,hamiltonian,init,pairpot,bspline_builder
from qmcpack_input import generate_jastrows,generate_jastrow,generate_jastrow1,generate_jastrow2,generate_jastrow3
from qmcpack_input import generate_opt,generate_opts
from qmcpack_analyzer import QmcpackAnalyzer
from qmcpack_converters import Pw2qmcpack,Convert4qmc,PyscfToAfqmc
from debug import ci,ls,gs
from developer import unavailable
from nexus_base import nexus_core
try:
    import h5py
except:
    h5py = unavailable('h5py')
#end try
49
50
51
class Qmcpack(Simulation):
    """
    Nexus simulation class for the QMCPACK code.

    Incorporates structure, particle, orbital, Jastrow, cusp correction,
    and wavefunction results from other simulations into the QMCPACK
    input.  When the system has more than one k-point (and no single
    twist is pinned by the user), the run is set up as a twist-averaged
    bundle by tracing the input over 'twistnum'.

    Results this class can provide to dependent simulations are listed
    in `application_results`.
    """
    input_type    = QmcpackInput
    analyzer_type = QmcpackAnalyzer
    generic_identifier = 'qmcpack'
    infile_extension   = '.in.xml'
    application   = 'qmcpack'
    application_properties = set(['serial','omp','mpi'])
    application_results    = set(['jastrow','cuspcorr','wavefunction'])


    def has_afqmc_input(self):
        """
        Return True if the input is a (non-generic) AFQMC input.

        Generic/templated inputs are never reported as AFQMC input; for
        other inputs the answer comes from input.is_afqmc_input().
        """
        afqmc_input = False
        if not self.has_generic_input():
            afqmc_input = self.input.is_afqmc_input()
        #end if
        return afqmc_input
    #end def has_afqmc_input


    def post_init(self):
        """
        Decide whether this simulation should be twist averaged.

        - AFQMC input: never twist averaged; a null analyzer is used.
        - No system: not twist averaged (warns for non-generic input).
        - Otherwise: twist averaged when the structure has more than one
          k-point and the user has not given "twist" or "twistnum" on a
          bspline_builder/determinantset element.  Generic input is
          rejected in this case.

        For twist-averaged runs the input file name is changed from
        <prefix><infile_extension> to <prefix>.in and the job's app
        command is updated if it still matches the old command.
        """
        generic_input = self.has_generic_input()

        if self.has_afqmc_input():
            self.analyzer_type = NullSimulationAnalyzer
            self.should_twist_average = False
        elif self.system is None:
            if not generic_input:
                self.warn('system must be specified to determine whether to twist average\nproceeding under the assumption of no twist averaging')
            #end if
            self.should_twist_average = False
        else:
            if generic_input:
                cls = self.__class__
                # input_type is itself a class, so its own __name__ is the
                # name to report (its __class__ would be the metaclass)
                self.error('cannot twist average generic or templated input\nplease provide {0} instead of {1} for input'.format(cls.input_type.__name__,self.input.__class__.__name__))
            #end if
            self.system.group_atoms()
            self.system.change_units('B')
            twh = self.input.get_host('twist')
            tnh = self.input.get_host('twistnum')
            htypes = bspline_builder,determinantset
            # identity tests: a twist may be stored as an array, for which
            # "!=None" is elementwise and has no single truth value
            user_twist_given  = isinstance(twh,htypes) and twh.twist is not None
            user_twist_given |= isinstance(tnh,htypes) and tnh.twistnum is not None
            many_kpoints = len(self.system.structure.kpoints)>1
            self.should_twist_average = many_kpoints and not user_twist_given
            if self.should_twist_average:
                # correct the job app command to account for the change in input file name
                # this is necessary for twist averaged runs in bundles
                app_comm = self.app_command()
                ext = self.infile_extension
                if ext and self.infile.endswith(ext):
                    # strip only the known extension so that identifiers
                    # containing '.' are preserved
                    prefix = self.infile[:-len(ext)]
                else:
                    prefix = self.infile.split('.',1)[0]
                #end if
                self.infile = prefix+'.in'
                app_comm_new = self.app_command()
                if self.job.app_command==app_comm:
                    self.job.app_command=app_comm_new
                #end if
            #end if
        #end if
    #end def post_init


    def propagate_identifier(self):
        """Copy the simulation identifier into the input's project id (non-generic input only)."""
        if not self.has_generic_input():
            self.input.simulation.project.id = self.identifier
        #end if
    #end def propagate_identifier


    def pre_write_inputs(self,save_image):
        """
        Ensure a twist-averaged bundle request exists before inputs are written.

        In generate_only mode the twist-averaging request is issued here
        (over all k-points of the structure) so that twist-averaged input
        files are still produced.  Without a system, twist averaging is
        switched off.
        """
        # fix to make twist averaged input file under generate_only
        if self.system is None:
            self.should_twist_average = False
        elif nexus_core.generate_only and self.should_twist_average:
            nkpoints = len(self.system.structure.kpoints)
            self.twist_average(list(range(nkpoints)))
        #end if
    #end def pre_write_inputs


    def check_result(self,result_name,sim):
        """
        Return True if this run will produce result `result_name`.

        'jastrow' and 'wavefunction' are produced when the run contains an
        optimization ('opt' calctype); 'cuspcorr' is produced when the
        input requests cusp correction.  Any other name returns False.
        """
        calculating_result = False
        if result_name=='jastrow' or result_name=='wavefunction':
            calctypes = self.input.get_output_info('calctypes')
            calculating_result = calctypes is not None and 'opt' in calctypes
        elif result_name=='cuspcorr':
            calculating_result = self.input.cusp_correction()
        #end if
        return calculating_result
    #end def check_result


    def get_result(self,result_name,sim):
        """
        Return an obj describing result `result_name` for dependent `sim`.

        'jastrow'/'wavefunction': `opt_file`, the path to the optimal
            wavefunction file recorded by the analyzer.
        'cuspcorr': paths to the cuspInfo.xml files for both the
            multideterminant sposet naming (spo_up/dn_cusps) and the
            single-determinant naming (updet/dndet_cusps).
        """
        result = obj()
        if result_name=='jastrow' or result_name=='wavefunction':
            analyzer = self.load_analyzer_image()
            if 'results' not in analyzer or 'optimization' not in analyzer.results:
                if self.should_twist_average:
                    self.error('Wavefunction optimization was performed for each twist separately.\nCurrently, the transfer of per-twist wavefunction parameters from\none QMCPACK simulation to another is not supported.  Please either\nredo the optimization with a single twist (see "twist" or "twistnum"\noptions), or request that this feature be implemented.')
                else:
                    self.error('analyzer did not compute results required to determine jastrow')
                #end if
            #end if
            opt_file = str(analyzer.results.optimization.optimal_file)
            result.opt_file = os.path.join(self.locdir,opt_file)
            del analyzer
        elif result_name=='cuspcorr':
            result.spo_up_cusps = os.path.join(self.locdir,self.identifier+'.spo-up.cuspInfo.xml')
            result.spo_dn_cusps = os.path.join(self.locdir,self.identifier+'.spo-dn.cuspInfo.xml')
            result.updet_cusps = os.path.join(self.locdir,'updet.cuspInfo.xml')
            result.dndet_cusps = os.path.join(self.locdir,'downdet.cuspInfo.xml')
        else:
            self.error('ability to get result '+result_name+' has not been implemented')
        #end if
        return result
    #end def get_result


    def incorporate_result(self,result_name,result,sim):
        """
        Merge result `result_name` produced by dependency `sim` into the input.

        Supported combinations:
          'orbitals'     from Pw2qmcpack (h5 orbital file) or Convert4qmc
                         (wavefunction xml plus optional orbital h5 file)
          'jastrow'      from Qmcpack (optimized wavefunction file)
          'particles'    from Convert4qmc
          'structure'    from any simulation providing result.structure
          'cuspcorr'     cuspInfo.xml files (see get_result)
          'wavefunction' from Qmcpack, or from PyscfToAfqmc (AFQMC input only)
        Any other combination is reported via self.error.
        """
        input = self.input
        system = self.system
        if result_name=='orbitals':
            if isinstance(sim,Pw2qmcpack):

                h5file = result.h5file

                wavefunction = input.get('wavefunction')
                if isinstance(wavefunction,collection):
                    wavefunction = wavefunction.get_single('psi0')
                #end if
                wf = wavefunction
                # locate the element that carries the bspline orbital settings
                if 'sposet_builder' in wf and wf.sposet_builder.type=='bspline':
                    orb_elem = wf.sposet_builder
                elif 'sposet_builders' in wf and 'bspline' in wf.sposet_builders:
                    orb_elem = wf.sposet_builders.bspline
                elif 'sposet_builders' in wf and 'einspline' in wf.sposet_builders:
                    orb_elem = wf.sposet_builders.einspline
                elif 'determinantset' in wf and wf.determinantset.type in ('bspline','einspline'):
                    orb_elem = wf.determinantset
                else:
                    self.error('could not incorporate pw2qmcpack orbitals\nbspline sposet_builder and determinantset are both missing')
                #end if
                if 'href' in orb_elem and isinstance(orb_elem.href,str) and os.path.exists(orb_elem.href):
                    # user specified h5 file for orbitals, bypass orbital dependency
                    orb_elem.href = os.path.relpath(orb_elem.href,self.locdir)
                else:
                    orb_elem.href = os.path.relpath(h5file,self.locdir)
                    if system.structure.folded_structure is not None:
                        orb_elem.tilematrix = array(system.structure.tmatrix)
                    #end if
                #end if
                defs = obj(
                    meshfactor = 1.0
                    )
                for var,val in defs.items():
                    if not var in orb_elem:
                        orb_elem[var] = val
                    #end if
                #end for
                has_twist    = 'twist' in orb_elem
                has_twistnum = 'twistnum' in orb_elem
                if not has_twist and not has_twistnum:
                    orb_elem.twistnum = 0
                #end if

                structure = system.structure
                nkpoints = len(structure.kpoints)
                if nkpoints==0:
                    self.error('system must have kpoints to assign twistnums')
                #end if

                if not os.path.exists(h5file):
                    self.error('wavefunction file not found:\n'+h5file)
                #end if

                twistnums = list(range(nkpoints))
                if self.should_twist_average:
                    self.twist_average(twistnums)
                elif not has_twist and orb_elem.twistnum is None:
                    orb_elem.twistnum = twistnums[0]
                #end if

            elif isinstance(sim,Convert4qmc):

                res = QmcpackInput(result.location)
                qs  = input.simulation.qmcsystem
                oldwfn = qs.wavefunction
                newwfn = res.qmcsystem.wavefunction
                dset = newwfn.determinantset
                # keep the user's jastrows rather than any from convert4qmc
                if 'jastrows' in newwfn:
                    del newwfn.jastrows
                #end if
                if 'jastrows' in oldwfn:
                    newwfn.jastrows = oldwfn.jastrows
                #end if
                if input.cusp_correction():
                    dset.cuspcorrection = True
                #end if
                if 'orbfile' in result:
                    orb_h5file = result.orbfile
                    if not os.path.exists(orb_h5file) and 'href' in dset:
                        orb_h5file = os.path.join(sim.locdir,dset.href)
                    #end if
                    if not os.path.exists(orb_h5file):
                        self.error('orbital h5 file from convert4qmc does not exist\nlocation checked: {}'.format(orb_h5file))
                    #end if
                    orb_path = os.path.relpath(orb_h5file,self.locdir)
                    dset.href = orb_path
                    detlist = dset.get('detlist')
                    if detlist is not None and 'href' in detlist:
                        detlist.href = orb_path
                    #end if
                #end if
                qs.wavefunction = newwfn

            else:
                self.error('incorporating orbitals from '+sim.__class__.__name__+' has not been implemented')
            #end if
        elif result_name=='jastrow':
            if isinstance(sim,Qmcpack):
                opt_file = result.opt_file
                opt = QmcpackInput(opt_file)
                wavefunction = input.get('wavefunction')
                if isinstance(wavefunction,collection):
                    wavefunction = wavefunction.get_single('psi0')
                #end if
                optwf = opt.qmcsystem.wavefunction
                def process_jastrow(wf):
                    # map normalized jastrow type (plus name, if the type
                    # repeats) to the jastrow element
                    if 'jastrow' in wf:
                        js = [wf.jastrow]
                    elif 'jastrows' in wf:
                        js = list(wf.jastrows.values())
                    else:
                        js = []
                    #end if
                    jd = dict()
                    for j in js:
                        jtype = j.type.lower().replace('-','_').replace(' ','_')
                        key = jtype
                        # take care of multiple jastrows of the same type
                        if key in jd:  # use name to distinguish
                            key += j.name
                            if key in jd:  # if still duplicate then error out
                                msg = 'duplicate jastrow in '+self.__class__.__name__
                                self.error(msg)
                            #end if
                        #end if
                        jd[key] = j
                    #end for
                    return jd
                #end def process_jastrow
                if wavefunction is None:
                    qs = input.get('qmcsystem')
                    qs.wavefunction = optwf.copy()
                else:
                    jold = process_jastrow(wavefunction)
                    jopt = process_jastrow(optwf)
                    # optimized jastrows replace old ones of the same type;
                    # old jastrows of other types are retained
                    jnew = list(jopt.values())
                    for jtype in jold.keys():
                        if not jtype in jopt:
                            jnew.append(jold[jtype])
                        #end if
                    #end for
                    # remove previous jastrows under both the singular and
                    # plural names so no stale element is left beside the new ones
                    if 'jastrow' in wavefunction:
                        del wavefunction.jastrow
                    #end if
                    if 'jastrows' in wavefunction:
                        del wavefunction.jastrows
                    #end if
                    if len(jnew)==1:
                        wavefunction.jastrow = jnew[0].copy()
                    else:
                        wavefunction.jastrows = collection(jnew)
                    #end if
                #end if
                del optwf
            else:
                self.error('incorporating jastrow from '+sim.__class__.__name__+' has not been implemented')
            #end if
        elif result_name=='particles':
            if isinstance(sim,Convert4qmc):
                ptcl_file = result.location
                qi = QmcpackInput(ptcl_file)
                self.input.simulation.qmcsystem.particlesets = qi.qmcsystem.particlesets
            else:
                self.error('incorporating particles from '+sim.__class__.__name__+' has not been implemented')
            #end if
        elif result_name=='structure':
            relstruct = result.structure.copy()
            relstruct.change_units('B')
            self.system.structure = relstruct
            self.system.remove_folded()
            self.input.incorporate_system(self.system)

        elif result_name=='cuspcorr':

            ds = self.input.get('determinantset')
            ds.cuspcorrection = True
            try: # multideterminant
                ds.sposets['spo-up'].cuspinfo = os.path.relpath(result.spo_up_cusps,self.locdir)
                ds.sposets['spo-dn'].cuspinfo = os.path.relpath(result.spo_dn_cusps,self.locdir)
            except Exception: # single determinant
                sd = ds.slaterdeterminant
                sd.determinants['updet'].cuspinfo = os.path.relpath(result.updet_cusps,self.locdir)
                sd.determinants['downdet'].cuspinfo = os.path.relpath(result.dndet_cusps,self.locdir)
            #end try

        elif result_name=='wavefunction':
            if isinstance(sim,Qmcpack):
                opt = QmcpackInput(result.opt_file)
                qs = input.get('qmcsystem')
                qs.wavefunction = opt.qmcsystem.wavefunction.copy()
            elif isinstance(sim,PyscfToAfqmc):
                if not self.input.is_afqmc_input():
                    self.error('incorporating wavefunction from {} is only supported for AFQMC calculations'.format(sim.__class__.__name__))
                #end if
                h5_file =  os.path.relpath(result.h5_file,self.locdir)
                wfn = self.input.simulation.wavefunction
                ham = self.input.simulation.hamiltonian
                wfn.filename = h5_file
                wfn.filetype = 'hdf5'
                if 'filename' not in ham or ham.filename=='MISSING.h5':
                    ham.filename = h5_file
                    ham.filetype = 'hdf5'
                #end if
                if 'xml' in result:
                    xml = QmcpackInput(result.xml)
                    info_new = xml.simulation.afqmcinfo.copy()
                    info = self.input.simulation.afqmcinfo
                    info.set_optional(**info_new)
                    # override particular inputs set by default
                    if 'generation_info' in input._metadata:
                        g = input._metadata.generation_info
                        if 'walker_type' not in g:
                            walker_type = xml.get('walker_type')
                            walkerset = input.get('walkerset')
                            if walker_type is not None and walkerset is not None:
                                walkerset.walker_type = walker_type
                            #end if
                        #end if
                    #end if
                #end if
            else:
                self.error('incorporating wavefunction from '+sim.__class__.__name__+' has not been implemented')
            #end if
        else:
            self.error('ability to incorporate result '+result_name+' has not been implemented')
        #end if
    #end def incorporate_result


    def check_sim_status(self):
        """
        Set the succeeded/failed/finished flags from the run's output.

        succeeded: stdout contains 'Total Execution'.
        failed:    stderr contains 'Fatal Error'.
        finished:  expected output files exist, the job finished or the
                   run reached its end, and the run did not abort.
        For cusp correction runs the expected outputs are the
        <name>.cuspInfo.xml files; once present they are copied to
        <identifier>.<name>.cuspInfo.xml, the names get_result uses.
        """
        output = self.outfile_text()
        errors = self.errfile_text()

        ran_to_end  = 'Total Execution' in output
        aborted     = 'Fatal Error' in errors
        files_exist = True
        cusp_run    = False

        if not self.has_generic_input():
            if not isinstance(self.input,TracedQmcpackInput):
                cusp_run = self.input.cusp_correction()
            #end if
            if cusp_run:
                sd = self.input.get('slaterdeterminant')
                if sd is not None:
                    cuspfiles = [d.id+'.cuspInfo.xml' for d in sd.determinants]
                else: # assume multideterminant sposet names
                    cuspfiles = ['spo-up.cuspInfo.xml','spo-dn.cuspInfo.xml']
                #end if
                outfiles = cuspfiles
            else:
                outfiles = self.input.get_output_info('outfiles')
            #end if

            for file in outfiles:
                file_loc = os.path.join(self.locdir,file)
                files_exist = files_exist and os.path.exists(file_loc)
            #end for

            if ran_to_end and not files_exist:
                self.warn('run finished successfully, but output files do not exist')
                self.log(outfiles)
                self.log(os.listdir(self.locdir))
            #end if
        #end if

        self.succeeded = ran_to_end
        self.failed    = aborted
        self.finished  = files_exist and (self.job.finished or ran_to_end) and not aborted

        if cusp_run and files_exist:
            for cuspfile in cuspfiles:
                cf_orig = os.path.join(self.locdir,cuspfile)
                cf_new  = os.path.join(self.locdir,self.identifier+'.'+cuspfile)
                # copy directly rather than via a shell command, so paths
                # with spaces/special characters work and failures raise
                shutil.copyfile(cf_orig,cf_new)
            #end for
        #end if
    #end def check_sim_status


    def get_output_files(self):
        """
        Return the list of output files expected from this run.

        Generic input yields an empty list.  For twist-averaged runs whose
        input has not been traced yet, the input is first traced over
        'twistnum' (all k-points) so the list covers every twist.
        """
        if self.has_generic_input():
            output_files = []
        else:
            if self.should_twist_average and not isinstance(self.input,TracedQmcpackInput):
                self.twist_average(list(range(len(self.system.structure.kpoints))))
                br = self.bundle_request
                input = self.input.trace(br.quantity,br.values)
                input.generate_filenames(self.infile)
                self.input = input
            #end if
            output_files = self.input.get_output_info('outfiles')
        #end if
        return output_files
    #end def get_output_files


    def post_analyze(self,analyzer):
        """
        Mark an optimization run as failed if no optimal wavefunction was found.

        The run is flagged failed when the analyzer has no optimization
        results at all, or when the recorded optimal_file is None.
        """
        if not self.has_generic_input():
            calctypes = self.input.get_output_info('calctypes')
            opt_run = calctypes is not None and 'opt' in calctypes
            if opt_run:
                if 'results' not in analyzer or 'optimization' not in analyzer.results:
                    self.failed = True
                elif analyzer.results.optimization.optimal_file is None:
                    self.failed = True
                #end if
            #end if
        #end if
    #end def post_analyze


    def app_command(self):
        """Return the command used to launch QMCPACK on the current input file."""
        return self.app_name+' '+self.infile
    #end def app_command


    def twist_average(self,twistnums):
        """Request that this run be bundled over the given 'twistnum' values."""
        br = obj()
        br.quantity = 'twistnum'
        br.values   = list(twistnums)
        self.bundle_request = br
    #end def twist_average


    def write_prep(self):
        """
        Expand a pending bundle request into traced inputs before writing.

        Only acts once dependencies have been incorporated, and only for
        non-generic, not-yet-traced input with a bundle_request.  The
        input is traced over the requested quantity, its file names replace
        the single input file in self.files, the last generated name
        becomes self.infile, and the job app command is regenerated.  For
        each per-twist file a <identifier>.g<index>.twist_info.dat file is
        written containing the k-point weight, the k-point, and the
        QMCPACK-convention k-point (kpoints_qmcpack).
        """
        if self.got_dependencies:
            traced_input  = isinstance(self.input,TracedQmcpackInput)
            generic_input = self.has_generic_input()
            if 'bundle_request' in self and not traced_input and not generic_input:
                br = self.bundle_request
                input = self.input.trace(br.quantity,br.values)
                input.generate_filenames(self.infile)
                if self.infile in self.files:
                    self.files.remove(self.infile)
                #end if
                for file in input.filenames:
                    self.files.add(file)
                #end for
                self.infile = input.filenames[-1]
                self.input  = input
                self.job.app_command = self.app_command()
                # write twist info files
                s = self.system.structure
                kweights        = s.kweights
                kpoints         = s.kpoints
                kpoints_qmcpack = s.kpoints_qmcpack()
                group_prefix = self.identifier+'.g'
                for file in input.filenames:
                    if file.startswith(group_prefix):
                        # strip the identifier explicitly (rather than splitting
                        # the whole name on '.') so identifiers containing '.'
                        # are handled; the next token is expected to be g<index>
                        group_token = file[len(self.identifier)+1:].split('.',1)[0]
                        twist_index = int(group_token.replace('g',''))
                        twist_filename = '{}.{}.twist_info.dat'.format(self.identifier,group_token)
                        kw  = kweights[twist_index]
                        kp  = kpoints[twist_index]
                        kpq = kpoints_qmcpack[twist_index]
                        contents = ' {: 16.6f}  {: 16.12f} {: 16.12f} {: 16.12f}  {: 16.12f} {: 16.12f} {: 16.12f}\n'.format(kw,*kp,*kpq)
                        with open(os.path.join(self.locdir,twist_filename),'w') as fobj:
                            fobj.write(contents)
                        #end with
                    #end if
                #end for
            #end if
        #end if
    #end def write_prep
#end class Qmcpack
532
533
534
def generate_qmcpack(**kwargs):
    """
    Create a Qmcpack simulation object from keyword arguments.

    The keyword arguments are divided into simulation arguments and input
    arguments by Qmcpack.separate_inputs.  When no explicit `input` is
    among the simulation arguments, one is built by passing the input
    arguments to generate_qmcpack_input.
    """
    simulation_args,input_args = Qmcpack.separate_inputs(kwargs)
    needs_input = 'input' not in simulation_args
    if needs_input:
        simulation_args.input = generate_qmcpack_input(**input_args)
    #end if
    return Qmcpack(**simulation_args)
#end def generate_qmcpack
545
546
def generate_cusp_correction(**kwargs):
    """
    Create a Qmcpack simulation that computes orbital cusp corrections.

    The input is forced to be a 'basic' input with bconds 'nnn' and no
    Jastrows, corrections, or calculation blocks, overriding any values
    the caller passed for those keywords.  Cusp correction is then
    switched on in the wavefunction's determinantset.  All other keyword
    arguments are split between the simulation and the input generator
    by Simulation.separate_inputs.

    Reports an error (via Qmcpack.class_error) if the generated input has
    no wavefunction or its wavefunction has no determinantset.
    """
    kwargs['input_type']   = 'basic'
    kwargs['bconds']       = 'nnn'
    kwargs['jastrows']     = []
    kwargs['corrections']  = []
    kwargs['calculations'] = []

    sim_args,inp_args = Simulation.separate_inputs(kwargs)

    input = generate_qmcpack_input(**inp_args)

    wf = input.get('wavefunction')
    # guard against a missing wavefunction: "in None" would raise TypeError
    if wf is None:
        Qmcpack.class_error('input does not have a wavefunction, cannot create cusp correction','generate_cusp_correction')
    elif 'determinantset' not in wf:
        Qmcpack.class_error('wavefunction does not have determinantset, cannot create cusp correction','generate_cusp_correction')
    #end if
    wf.determinantset.cuspcorrection = True

    sim_args.input = input
    qmcpack = Qmcpack(**sim_args)

    return qmcpack
#end def generate_cusp_correction
569