1##################################################################
2##  (c) Copyright 2015-  by Jaron T. Krogel                     ##
3##################################################################
4
5
6#====================================================================#
7#  qmcpack_analyzer.py                                               #
8#    Supports data analysis for QMCPACK output.                      #
9#                                                                    #
10#  Content summary:                                                  #
11#    QmcpackAnalyzer                                                 #
12#      SimulationAnalyzer class for QMCPACK.                         #
13#                                                                    #
#    QmcpackAnalyzerCapabilities                                     #
15#      Class to pair QMCPACK output data with analyzer classes.      #
16#                                                                    #
17#    QmcpackAnalysisRequest                                          #
18#      Offers detailed control over analysis.                        #
19#      Serves as a record of requested analysis.                     #
20#                                                                    #
21#====================================================================#
22
23
24
25from time import time
26
27#python standard library imports
28import os
29import re
30import sys
31import traceback
32from numpy import arange,array
33#custom library imports
34from generic import obj
35from developer import unavailable
36from xmlreader import XMLreader
37from plotting import *
38from physical_system import ghost_atoms
39#QmcpackAnalyzer classes imports
40from qmcpack_analyzer_base import QAobject,QAanalyzer,QAanalyzerCollection
41from qmcpack_property_analyzers \
42    import WavefunctionAnalyzer
43from qmcpack_quantity_analyzers \
44    import ScalarsDatAnalyzer,ScalarsHDFAnalyzer,DmcDatAnalyzer,\
45    EnergyDensityAnalyzer,TracesAnalyzer,DensityMatricesAnalyzer,\
46    SpinDensityAnalyzer,StructureFactorAnalyzer,DensityAnalyzer
47from qmcpack_method_analyzers \
48    import OptAnalyzer,VmcAnalyzer,DmcAnalyzer
49from qmcpack_result_analyzers \
50    import OptimizationAnalyzer,TimestepStudyAnalyzer
51from simulation import SimulationAnalyzer,Simulation
52from qmcpack_input import QmcpackInput
53from debug import *
54
55try:
56    import h5py
57    h5py_unavailable = False
58except:
59    h5py = unavailable('h5py')
60    h5py_unavailable = True
61#end try
62
63
64
class QmcpackAnalyzerCapabilities(QAobject):
    """
    Catalog of what the QMCPACK analyzers in this module can handle.

    Holds the recognized QMC method types, data sources, scalar and
    field quantity names, and the mapping from data/quantity names to
    the analyzer classes that process them.
    """

    def __init__(self):
        #data sources that are dropped when h5py cannot be imported
        hdf_data_sources = set(('stat','storeconfig','traces'))

        method_names = ('opt','vmc','dmc','rmc')
        source_names = ('scalar','stat','dmc','storeconfig','opt','traces')
        scalar_names = (
            'localenergy','localpotential','kinetic','elecelec',
            'localecp','nonlocalecp','ionion','localenergy_sq',
            'acceptratio','blockcpu','blockweight','mpc','kecorr',
            )
        field_names  = ('energydensity','density','dm1b','spindensity','structurefactor')

        data_sources = set(source_names)
        if h5py_unavailable:
            data_sources -= hdf_data_sources
        #end if

        self.methods      = set(method_names)
        self.data_sources = data_sources
        self.scalars      = set(scalar_names)
        self.fields       = set(field_names)

        #quantities that have a dedicated analyzer class
        self.analyzer_quantities = set(self.fields)

        #analyzer class to use for each data file type / field quantity
        self.analyzers = obj(
            scalars_dat     = ScalarsDatAnalyzer,
            scalars_hdf     = ScalarsHDFAnalyzer,
            dmc_dat         = DmcDatAnalyzer,
            traces          = TracesAnalyzer,
            energydensity   = EnergyDensityAnalyzer,
            dm1b            = DensityMatricesAnalyzer,
            spindensity     = SpinDensityAnalyzer,
            structurefactor = StructureFactorAnalyzer,
            density         = DensityAnalyzer
            )

        self.quantities = self.scalars.union(self.fields)

        self.ignorable_estimators = set(('LocalEnergy',))

        #each analyzer quantity is (currently) its own alias
        self.quantity_aliases = dict((q,q) for q in self.analyzer_quantities)

        self.future_quantities = set(('StructureFactor','MomentumDistribution'))
        return
    #end def __init__
#end class QmcpackAnalyzerCapabilities
106
107
# Attach one capabilities record to the QAanalyzer base class at import time;
# QmcpackAnalysisRequest reads it to filter the requested methods, data
# sources and quantities.
QAanalyzer.capabilities = QmcpackAnalyzerCapabilities()
109
110
111
class QmcpackAnalysisRequest(QAobject):
    """
    Record of the analysis requested for a QMCPACK run.

    The request stores where the input file lives (source), where results
    go (destination/savefile) and which methods, calculations, data
    sources and quantities should be analyzed.  Requested methods, data
    sources and quantities are intersected with QAanalyzer.capabilities,
    so unsupported names are silently dropped.
    """
    def __init__(self,source=None,destination=None,savefile='',
                 methods=None,calculations=None,data_sources=None,quantities=None,
                 warmup_calculations=None,
                 output=set(['averages','samples']),
                 ndmc_blocks=1000,equilibration=None,group_num=None,
                 traces=False,dm_settings=None):
        self.source          = source
        self.destination     = destination
        self.savefile        = str(savefile)
        self.output          = set(output)   # copy, the default set is never shared
        self.ndmc_blocks     = int(ndmc_blocks)
        self.group_num       = group_num
        self.traces          = traces
        self.dm_settings     = dm_settings

        cap = QAanalyzer.capabilities

        #None means "everything supported"; otherwise keep only supported names
        if methods is None:
            self.methods = set(cap.methods)
        else:
            self.methods = set(methods) & cap.methods
        #end if
        if calculations is None:
            self.calculations = set()
        else:
            self.calculations = set(calculations)
        #end if
        if data_sources is None:
            self.data_sources = set(cap.data_sources)
        else:
            self.data_sources = set(data_sources) & cap.data_sources
        #end if
        if quantities is None:
            self.quantities = set(cap.quantities)
        else:
            #quantity names are condensed before comparison with the supported set
            quants = set(self.condense_name(q) for q in quantities)
            self.quantities = quants & cap.quantities
        #end if
        if warmup_calculations is None:
            self.warmup_calculations = set()
        else:
            self.warmup_calculations = set(warmup_calculations)
        #end if
        #dict-like equilibration settings are copied into a fresh obj
        if isinstance(equilibration,(dict,obj)):
            eq = obj()
            eq.transfer_from(equilibration)
        else:
            eq = equilibration
        #end if
        self.equilibration = eq

        return
    #end def __init__

    def complete(self):
        """
        Fill in path information derived from the source.

        A source without a directory part gets './' prepended, and a
        missing destination defaults to the directory of the source.
        Always returns True.
        """
        if os.path.dirname(self.source)=='':
            self.source = os.path.join('./',self.source)
        #end if
        if self.destination is None:
            self.destination = os.path.dirname(self.source)
        #end if
        return True
    #end def complete
#end class QmcpackAnalysisRequest
182
183
184
185"""
186class QmcpackAnalyzer
187  used to analyze all data produced by QMCPACK
188
189  usage:
190     results = QmcpackAnalyzer("qmcpack.in.xml")
191       |  QMC methods used and observables estimated are determined
192       \  Each observable is calculated by an object contained in results
193"""
194class QmcpackAnalyzer(SimulationAnalyzer,QAanalyzer):
195    def __init__(self,arg0=None,**kwargs):
196
197        verbose = False
198        if 'verbose' in kwargs:
199            verbose=kwargs['verbose']
200            del kwargs['verbose']
201        #end if
202        QAanalyzer.verbose_vlog = verbose or QAanalyzer.verbose_vlog
203
204        nindent = 0
205        if 'nindent' in kwargs:
206            nindent=kwargs['nindent']
207            del kwargs['nindent']
208        #end if
209        QAanalyzer.__init__(self,nindent=nindent)
210
211        analyze = False
212        if 'analyze' in kwargs:
213            analyze=kwargs['analyze']
214            del kwargs['analyze']
215        #end if
216
217        if 'ghost_atoms' in kwargs:
218            ghosts = kwargs.pop('ghost_atoms')
219            ghost_atoms(*ghosts)
220        #end if
221
222        if isinstance(arg0,Simulation):
223            sim = arg0
224            if 'analysis_request' in sim:
225                request = sim.analysis_request.copy()
226            else:
227                request = QmcpackAnalysisRequest(
228                    source = os.path.join(sim.resdir,sim.infile),
229                    destination = sim.resdir
230                    )
231                if 'stat' in request.data_sources:
232                    request.data_sources.remove('stat')
233                #end if
234                if 'storeconfig' in request.data_sources:
235                    request.data_sources.remove('storeconfig')
236                #end if
237                if 'traces' in request.data_sources:
238                    request.data_sources.remove('traces')
239                #end if
240            #end if
241        elif isinstance(arg0,QmcpackAnalysisRequest):
242            request = arg0
243        elif isinstance(arg0,str):
244            kwargs['source']=arg0
245            request = QmcpackAnalysisRequest(**kwargs)
246        else:
247            if 'source' not in kwargs:
248                kwargs['source']='./qmcpack.in.xml'
249            #end if
250            request = QmcpackAnalysisRequest(**kwargs)
251        #end if
252
253        self.change_request(request)
254
255        if request!=None and os.path.exists(request.source):
256            self.init_sub_analyzers(request)
257        #end if
258
259        savefile = request.savefile
260        savefilepath = os.path.join(request.destination,request.savefile)
261        self.info.savefile = savefile
262        self.info.savefilepath = savefilepath
263        self.info.error = None
264        if os.path.exists(savefilepath) and savefile!='':
265            self.load()
266        elif analyze:
267            self.analyze()
268        #end if
269
270        return
271    #end def __init__
272
273
274    def change_request(self,request):
275        if not isinstance(request,QmcpackAnalysisRequest):
276            self.error('input request must be a QmcpackAnalysisRequest',exit=False)
277            self.error('  type provided: '+str(type(request)))
278        #end if
279        request.complete()
280        self.info.request = request
281    #end def change_request
282
283
284
285    def init_sub_analyzers(self,request=None):
286        own_request = request==None
287        if request==None:
288            request = self.info.request
289        #end if
290        group_num = request.group_num
291
292        #determine if the run was bundled
293        if request.source.endswith('.xml'):
294            self.info.type = 'single'
295        else:
296            self.info.type = 'bundled'
297            self.bundle(request.source)
298            return
299        #end if
300
301        self.vlog('reading input file: '+request.source,n=1)
302        input = QmcpackInput(request.source)
303        input.pluralize()
304        input.unroll_calculations()
305        calculations = input.simulation.calculations
306        self.info.set(
307            input = input,
308            ordered_input = input.read_xml(request.source)
309            )
310
311        project,wavefunction = input.get('project','wavefunction')
312        wavefunction = wavefunction.get_single('psi0')
313
314        subindent = self.subindent()
315
316        self.wavefunction = WavefunctionAnalyzer(wavefunction,nindent=subindent)
317
318        self.vlog('project id: '+project.id,n=1)
319        file_prefix  = project.id
320        if group_num!=None:
321            group_ext = '.g'+str(group_num).zfill(3)
322            if not file_prefix.endswith(group_ext):
323                file_prefix += group_ext
324            #end if
325        elif self.info.type=='single':
326            resdir,infile = os.path.split(request.source)
327            #ifprefix = infile.replace('.xml','')
328            ifprefix = infile.replace('.xml','.')
329            ls = os.listdir(resdir)
330            for filename in ls:
331                if filename.startswith(ifprefix) and filename.endswith('.qmc'):
332                    group_tag = filename.split('.')[-2]
333                    #file_prefix = 'qmc.'+group_tag
334                    file_prefix = project.id+'.'+group_tag
335                    break
336                #end if
337            #end for
338        #end if
339        if 'series' in project:
340            series_start = int(project.series)
341        else:
342            series_start = 0
343        #end if
344
345        self.vlog('data file prefix: '+file_prefix,n=1)
346
347        run_info = obj(
348            file_prefix  = file_prefix,
349            series_start = series_start,
350            source_path  = os.path.split(request.source)[0],
351            group_num    = group_num,
352            system       = input.return_system()
353            )
354        self.info.transfer_from(run_info)
355
356        self.set_global_info()
357
358        if len(request.calculations)==0:
359            request.calculations = set(series_start+arange(len(calculations)))
360        #end if
361
362        method_aliases = dict()
363        for method in self.opt_methods:
364            method_aliases[method]='opt'
365        #end for
366        for method in self.vmc_methods:
367            method_aliases[method]='vmc'
368        #end for
369        for method in self.dmc_methods:
370            method_aliases[method]='dmc'
371        #end for
372
373        method_objs = ['qmc','opt','vmc','dmc']
374        for method in method_objs:
375            self[method] = QAanalyzerCollection()
376        #end for
377        for index,calc in calculations.items():
378            method = calc.method
379            if method in method_aliases:
380                method_type = method_aliases[method]
381            else:
382                self.error('method '+method+' is unrecognized')
383            #end if
384            if method_type in request.methods:
385                series = series_start + index
386                if series in request.calculations:
387                    if method in self.opt_methods:
388                        qma = OptAnalyzer(series,calc,input,nindent=subindent)
389                        primary = self.opt
390                    elif method in self.vmc_methods:
391                        qma = VmcAnalyzer(series,calc,input,nindent=subindent)
392                        primary = self.vmc
393                    elif method in self.dmc_methods:
394                        qma = DmcAnalyzer(series,calc,input,nindent=subindent)
395                        primary = self.dmc
396                    #end if
397                    primary[series]  = qma
398                    self.qmc[series] = qma
399                #end if
400            #end if
401        #end for
402        for method in method_objs:
403            if len(self[method])==0:
404                del self[method]
405            #end if
406        #end for
407
408        #Check for multi-qmc results such as
409        # optimization or timestep studies
410        results = QAanalyzerCollection()
411        if 'opt' in self and len(self.opt)>0:
412            optres = OptimizationAnalyzer(input,self.opt,nindent=subindent)
413            results.optimization = optres
414        #end if
415        if 'dmc' in self and len(self.dmc)>1:
416            maxtime = 0
417            times = dict()
418            for series,dmc in self.dmc.items():
419                blocks,steps,timestep = dmc.info.method_input.list('blocks','steps','timestep')
420                times[series] = blocks*steps*timestep
421                maxtime = max(times[series],maxtime)
422            #end for
423            dmc = QAanalyzerCollection()
424            for series,time in times.items():
425                if abs(time-maxtime)/maxtime<.5:
426                    dmc[series] = self.dmc[series]
427                #end if
428            #end for
429            if len(dmc)>1:
430                results.timestep_study = TimestepStudyAnalyzer(dmc,nindent=subindent)
431            #end if
432        #end if
433
434        if len(results)>0:
435            self.results = results
436        #end if
437
438        self.unset_global_info()
439
440    #end def init_sub_analyzers
441
442
443    def set_global_info(self):
444        QAanalyzer.request  = self.info.request
445        QAanalyzer.run_info = self.info
446    #end def set_global_info
447
448    def unset_global_info(self):
449        QAanalyzer.request  = None
450        QAanalyzer.run_info = None
451    #end def unset_global_info
452
453
454    def load_data(self):
455        request = self.info.request
456        if not os.path.exists(request.source):
457            self.error('path to source\n  '+request.source+'\n  does not exist\n ensure that request.source points to a valid qmcpack input file')
458        #end if
459        self.set_global_info()
460        self.propagate_indicators(data_loaded=False)
461        if self.info.type=='bundled' and self.info.perform_bundle_average:
462            self.prevent_average_load()
463        #end if
464        QAanalyzer.load_data(self)
465        if self.info.type=='bundled' and self.info.perform_bundle_average:
466            self.average_bundle_data()
467        #end if
468        self.unset_global_info()
469    #end def load_data
470
471
472    def analyze(self,force=False):
473        if not self.info.analyzed or force:
474            if not self.info.data_loaded:
475                self.load_data()
476            #end if
477            self.vlog('main analysis of QmcpackAnalyzer data',n=1)
478            try:
479                self.set_global_info()
480                self.propagate_indicators(analyzed=False)
481                if self.info.type!='bundled':
482                    QAanalyzer.analyze(self,force=force)
483                else:
484                    for analyzer in self.bundled_analyzers:
485                        analyzer.analyze()
486                    #end for
487                    QAanalyzer.analyze(self,force=force)
488                #end if
489                self.unset_global_info()
490            except:
491                exc_type, exc_value, exc_traceback = sys.exc_info()
492                lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
493                msg = ''
494                for line in lines:
495                    msg+=line
496                #end for
497                self.info.error = exc_type
498                self.warn('runtime exception encountered\n'+msg)
499            #end try
500            self.vlog('end main analysis of QmcpackAnalyzer data',n=1)
501            if self.info.request.savefile!='':
502                self.save()
503            #end if
504        #end if
505    #end def analyze
506
507
508
509    def bundle(self,source):
510        self.vlog('bundled run detected',n=1)
511        if os.path.exists(source):
512            fobj = open(source,'r')
513            lines = fobj.read().split('\n')
514            fobj.close()
515        else:
516            self.error('source file '+source+' does not exist')
517        #end if
518        infiles = []
519        for line in lines:
520            ls = line.strip()
521            if ls!='':
522                infiles.append(ls)
523            #end if
524        #end for
525        self.info.input_infiles = list(infiles)
526        analyzers = QAanalyzerCollection()
527        request = self.info.request
528        path = os.path.split(request.source)[0]
529        files = os.listdir(path)
530        outfiles = []
531        for file in files:
532            if file.endswith('qmc'):
533                outfiles.append(file)
534            #end if
535        #end for
536        del files
537        for i in range(len(infiles)):
538            infile = infiles[i]
539            prefix = infile.replace('.xml','')
540            gn = i
541            for outfile in outfiles:
542                if outfile.startswith(prefix):
543                    gn = int(outfile.split('.')[-2][1:])
544                    break
545                #end if
546            #end for
547            req = request.copy()
548            req.source = os.path.join(path,infile)
549            req.group_num = gn
550            qa = QmcpackAnalyzer(req,nindent=self.subindent())
551            #qa.init_sub_analyzers(group_num=gn)
552            analyzers[gn] = qa
553        #end for
554        self.bundled_analyzers = analyzers
555        self.info.perform_bundle_average = False
556        #check to see if twist averaging
557        #  indicated by distinct twistnums
558        #  or twist in all ids
559        twistnums = set()
560        twist_ids = True
561        for analyzer in analyzers:
562            input = analyzer.info.input
563            twistnum = input.get('twistnum')
564            project = input.get('project')
565            if twistnum!=None:
566                twistnums.add(twistnum)
567            #end if
568            twist_ids = twist_ids and 'twist' in project.id
569        #end for
570        distinct_twistnums = len(twistnums)==len(analyzers)
571        twist_averaging = distinct_twistnums or twist_ids
572        if twist_averaging:
573            self.info.perform_bundle_average = True
574        #end if
575        example = analyzers.list()[0]
576        input,system = example.info.tuple('input','system')
577        self.info.set(
578            input  = input.copy(),
579            system = system.copy()
580            )
581        self.vlog('average over bundled runs?  {0}'.format(self.info.perform_bundle_average),n=1)
582    #end def bundle
583
584
585    def prevent_average_load(self):
586        for method_type in self.capabilities.methods:
587            if method_type in self:
588                self[method_type].propagate_indicators(data_loaded=True)
589            #end if
590        #end for
591    #end def prevent_average_load
592
593
594    def average_bundle_data(self):
595        analyzers = self.bundled_analyzers
596        if len(analyzers)>0:
597            self.vlog('performing bundle (e.g. twist) averaging',n=1)
598            #create local data structures to match those in the bundle
599            example = analyzers.list()[0].copy()
600            for method_type in self.capabilities.methods:
601                if method_type in self:
602                    del self[method_type]
603                #end if
604                if method_type in example:
605                    self.vlog('copying {0} methods from analyzer 0'.format(method_type),n=2)
606                    self[method_type] = example[method_type]
607                #end if
608            #end if
609            if 'qmc' in self:
610                del self.qmc
611            #end if
612            if 'qmc' in example:
613                self.vlog('copying qmc methods from analyzer 0',n=2)
614                self.qmc = example.qmc
615            #end if
616            if 'wavefunction' in self:
617                del self.wavefunction
618            #end if
619            if 'wavefunction' in example:
620                self.vlog('copying wavefunction from analyzer 0',n=2)
621                self.wavefunction = example.wavefunction
622            #end if
623            del example
624
625            if 'qmc' in self:
626                #zero out the average data
627                self.vlog('zeroing own qmc data',n=2)
628                for qmc in self.qmc:
629                    qmc.zero_data()
630                #end for
631
632                #resize the average data
633                self.vlog('finding minimum data size (for incomplete runs)',n=2)
634                for analyzer in analyzers:
635                    for series,qmc in self.qmc.items():
636                        qmc.minsize_data(analyzer.qmc[series])
637                    #end for
638                #end for
639
640                #accumulate the average data
641                self.vlog('accumulating data from bundled runs',n=2)
642                for analyzer in analyzers:
643                    for series,qmc in self.qmc.items():
644                        qmc.accumulate_data(analyzer.qmc[series])
645                    #end for
646                #end for
647
648                #normalize the average data
649                norm_factor = len(analyzers)
650                self.vlog('normalizing bundle average (factor={0})'.format(norm_factor),n=2)
651                for qmc in self.qmc:
652                    qmc.normalize_data(norm_factor)
653                #end for
654            #end if
655        #end if
656    #end def average_bundle_data
657
658
659
660
661    def save(self,filepath=None,overwrite=True):
662        if filepath==None:
663            filepath = self.info.savefilepath
664        #end if
665        self.vlog('saving QmcpackAnalyzer in file {0}'.format(filepath),n=1)
666        if not overwrite and os.path.exists(filepath):
667            return
668        #end if
669        self._unlink_dynamic_methods()
670        self.saved_global = QAobject._global
671        self._save(filepath)
672        self._relink_dynamic_methods()
673        return
674    #end def save
675
676    def load(self,filepath=None):
677        if filepath==None:
678            filepath = self.info.savefilepath
679        #end if
680        self.vlog('loading QmcpackAnalyzer from file {0}'.format(filepath),n=1)
681        self._load(filepath)
682        QAobject._global = self.saved_global
683        del self.saved_global
684        self._relink_dynamic_methods()
685        return
686    #end def load
687
688
689
690
691    def check_traces(self,verbose=False,pad=None,header=None):
692        if pad is None:
693            pad = ''
694        #end if
695        if header is None:
696            header = '\nChecking traces'
697        #end if
698        if 'qmc' in self:
699            if verbose:
700                self.log(pad+header)
701                pad += '  '
702            #end if
703            for method in self.qmc:
704                method.check_traces(pad)
705            #end for
706        else:
707            if verbose:
708                self.log(pad+'\nNo traces to check')
709            #end if
710            return None
711        #end if
712    #end def check_traces
713
714
715    def plot_trace(self,quantity,style='b-',offset=0,source='scalar',mlabels=True,
716                   mlines=True,show=True,alloff=False):
717        mlabels &= not alloff
718        mlines  &= not alloff
719        show    &= not alloff
720        shw = show
721        offset = int(offset)
722        id = self.info.input.get('project').id
723        sdata = obj()
724        series = sorted(self.qmc.keys())
725        q = []
726        soffset = offset
727        for s in series:
728            qmc = self.qmc[s]
729            method = qmc.info.method
730            if source=='scalar' or method=='vmc':
731                src = qmc.scalars.data
732            elif source=='dmc':
733                src = qmc.dmc.data
734            else:
735                self.error('invalid source: '+source)
736            #end if
737            if quantity in src:
738                qn = list(src[quantity])
739            else:
740                qn = len(src.LocalEnergy)*[0]
741            #end if
742            q.extend(qn)
743            sdata[s] = obj(
744                mlab = method+' '+str(s),
745                mloc = soffset + len(qn)//2,
746                line_loc = soffset + len(qn)-1
747                )
748            soffset += len(qn)
749        #end for
750        q = array(q)
751        qmin = q.min()
752        qmax = q.max()
753        mlabel_height = qmin + .8*(qmax-qmin)
754        if shw:
755            figure()
756        #end if
757        plot(offset+arange(len(q)),q,style,label=id)
758        for s in series:
759            sd = sdata[s]
760            if mlabels:
761                text(sd.mloc,mlabel_height,sd.mlab)
762            #end if
763            if mlines:
764                plot([sd.line_loc,sd.line_loc],[qmin,qmax],'k-')
765            #end if
766        #end for
767        if shw:
768            title('{0} vs series for {1}'.format(quantity,id))
769            xlabel('blocks')
770            ylabel(quantity)
771            legend()
772            show()
773        #end if
774    #end def plot_trace
775
776#end class QmcpackAnalyzer
777
778
779