1#! /usr/bin/env python3
2
3# Performs consistency checks between traces.h5 and scalar.dat/stat.h5/dmc.dat files
4# Jaron Krogel/ORNL
5
6# Type the following to view documentation for command line inputs:
7#   >check_traces.py -h
8
9# For usage examples, type:
10#   >check_traces.py -x
11
12# check_traces.py packages obj and HDFreader classes from Nexus.
13#   Note that h5py is required (which depends on numpy).
14#   This script is compatible with both Python 2 and 3.
15
16import os
17import sys
18from copy import deepcopy
19import numpy as np
20import h5py
21from optparse import OptionParser
22
23# Returns failure error code to OS.
24# Explicitly prints 'fail' after an optional message.
25def test_fail():
26    print('\n\nTest status: fail')
27    sys.exit(1)
28#end def test_fail
29
30
31# Returns success error code to OS.
32# Explicitly prints 'pass' after an optional message.
33def test_pass():
34    print('\n\nTest status: pass')
35    sys.exit(0)
36#end def test_pass
37
38
39
40######################################################################
41# from generic.py
42######################################################################
43
# Shared output stream used by all logging in this script.
logfile = sys.stdout

def log(msg,n=0,indent='  '):
    """Write msg to the log stream, indented n levels (each level is
    `indent`); non-string inputs are converted with str()."""
    text = msg if isinstance(msg,str) else str(msg)
    if n>0:
        pad = n*indent
        # indent every line of a multi-line message
        text = pad + text.replace('\n','\n'+pad)
    #end if
    logfile.write(text+'\n')
#end def log
56
57
def error(msg,header=None,n=0):
    """Log an error message with an optional source header, then exit
    with failing status via test_fail()."""
    suffix = ' error:'
    # bare 'error:' when no header is given, '<header> error:' otherwise
    header = suffix.lstrip() if header is None else header + suffix
    log('\n  '+header,n=n)
    log(msg.rstrip(),n=n)
    log('  exiting.\n')
    test_fail()
#end def error
70
71
72
class object_interface(object):
    """Dict-like base class storing entries in the instance __dict__.

    Provides item access, iteration over values, sorted printable
    representations that hide '_'-prefixed entries, and logging/error
    helpers that route through the module-level functions.
    """
    _logfile = sys.stdout

    def __len__(self):
        return len(self.__dict__)
    #end def __len__

    def __contains__(self,name):
        return name in self.__dict__
    #end def

    def __getitem__(self,name):
        return self.__dict__[name]
    #end def __getitem__

    def __setitem__(self,name,value):
        self.__dict__[name] = value
    #end def __setitem__

    def __delitem__(self,name):
        del self.__dict__[name]
    #end def __delitem__

    def __iter__(self):
        # iteration yields stored values, not keys
        return iter(self.__dict__.values())
    #end def __iter__

    def __repr__(self):
        # one line per visible entry: name and value's class name
        lines = []
        for k in sorted(self.keys()):
            if isinstance(k,str) and k[0]=='_':
                continue  # hidden entries are not shown
            #end if
            v = self.__dict__[k]
            cname = v.__class__.__name__ if hasattr(v,'__class__') else type(v)
            lines.append('  {0:<20}  {1:<20}\n'.format(str(k),cname))
        #end for
        return ''.join(lines)
    #end def __repr__

    def __str__(self,nindent=1):
        """Recursive pretty print; nested object_interface values are
        expanded beneath their key."""
        pad  = '  '
        npad = nindent*pad
        plain  = []
        nested = []
        for k,v in self.items():
            if isinstance(k,str) and k[0]=='_':
                continue  # hidden entries are not shown
            #end if
            (nested if isinstance(v,object_interface) else plain).append(k)
        #end for
        indent = npad+18*' '
        s = ''
        for k in sorted(plain):
            vstr = str(self[k]).replace('\n','\n'+indent)
            s += npad+'{0:<15} = '.format(str(k))+vstr+'\n'
        #end for
        for k in sorted(nested):
            s += npad+str(k)+'\n'
            s += self[k].__str__(nindent+1)
            if isinstance(k,str):
                s += npad+'end '+k+'\n'
            #end if
        #end for
        return s
    #end def __str__

    def keys(self):
        return self.__dict__.keys()
    #end def keys

    def values(self):
        return self.__dict__.values()
    #end def values

    def items(self):
        return self.__dict__.items()
    #end def items

    def copy(self):
        # deep copy: contained values are duplicated as well
        return deepcopy(self)
    #end def copy

    def clear(self):
        self.__dict__.clear()
    #end def clear

    def log(self,*items,**kwargs):
        # route through the module-level log function
        log(*items,**kwargs)
    #end def log

    def error(self,message,header=None,n=0):
        # default the header to the concrete class name
        if header is None:
            header = self.__class__.__name__
        #end if
        error(message,header,n=n)
    #end def error
#end class object_interface
182
183
184
class obj(object_interface):
    """Free-form dynamic container.

    Positional arguments that are dict-like are merged in; any other
    positional argument becomes a key initialized to None. Keyword
    arguments are assigned directly.
    """

    def __init__(self,*entries,**assignments):
        for entry in entries:
            if isinstance(entry,(dict,object_interface)):
                # merge all key/value pairs of dict-like inputs
                for name,value in entry.items():
                    self[name] = value
                #end for
            else:
                # a bare entry becomes a key with value None
                self[entry] = None
            #end if
        #end for
        for name,value in assignments.items():
            self[name] = value
        #end for
    #end def __init__

    def append(self,value):
        # list-style append: store under the next integer index
        self[len(self)] = value
    #end def append

    def first(self):
        # value associated with the smallest key
        return self[min(self.keys())]
    #end def first
#end class obj
209
210######################################################################
211# end from generic.py
212######################################################################
213
214
215
216######################################################################
217# from developer.py
218######################################################################
219
# Development-level base class (placeholder mirroring Nexus' DevBase).
class DevBase(obj):
    pass
#end class DevBase
223
224######################################################################
225# end from developer.py
226######################################################################
227
228
229
230
231######################################################################
232# from hdfreader.py
233######################################################################
234import keyword
235from inspect import getmembers
236
class HDFglobals(DevBase):
    # Global flag: when True, HDFreader keeps h5py dataset views
    # instead of loading data into numpy arrays.
    view = False
#end class HDFglobals
240
241
class HDFgroup(DevBase):
    """Dynamic-object image of a single HDF5 group.

    Child groups and datasets are stored as attributes; bookkeeping
    attributes start with '_' and can be stripped with _remove_hidden().
    """

    def _escape_name(self,name):
        # Append '_' when an HDF5 name would collide with an existing
        # member of this object or a Python keyword.
        if name in self._escape_names:
            name=name+'_'
        #end if
        return name
    #end def escape_name

    def _set_parent(self,parent):
        # Record the enclosing HDFgroup.
        self._parent=parent
        return
    #end def set_parent

    def _add_dataset(self,name,dataset):
        # Register a child dataset by name.
        self._datasets[name]=dataset
        return
    #end def add_dataset

    def _add_group(self,name,group):
        # Register a child group by name and label it with its own name.
        group._name=name
        self._groups[name]=group
        return
    #end def add_group

    def _contains_group(self,name):
        return name in self._groups.keys()
    #end def _contains_group

    def _contains_dataset(self,name):
        return name in self._datasets.keys()
    #end def _contains_dataset


    def __init__(self):
        self._name=''
        self._parent=None
        self._groups={};
        self._datasets={};
        self._group_counts={}

        # Names requiring escaping: every member visible on this object
        # (computed after the fields above are set, so they are included,
        # as is '_escape_names' itself) plus all Python keywords.
        self._escape_names=None
        self._escape_names=set(dict(getmembers(self)).keys()) | set(keyword.kwlist)
        return
    #end def __init__


    def _remove_hidden(self,deep=True):
        # Strip bookkeeping ('_'-prefixed) attributes; with deep=True,
        # recurse into child groups first.
        if '_parent' in self:
            del self._parent
        #end if
        if deep:
            for name,value in self.items():
                if isinstance(value,HDFgroup):
                    value._remove_hidden()
                #end if
            #end for
        #end if
        for name in list(self.keys()):
            if name[0]=='_':
                del self[name]
            #end if
        #end for
    #end def _remove_hidden


    # read in all data views (h5py datasets) into arrays
    #   useful for converting a single group read in view form to full arrays
    def read_arrays(self):
        self._remove_hidden()
        for k,v in self.items():
            if isinstance(v,HDFgroup):
                v.read_arrays()
            else:
                self[k] = np.array(v)
            #end if
        #end for
    #end def read_arrays


    def get_keys(self):
        # Child group names when group bookkeeping is still present,
        # otherwise all attribute names.
        if '_groups' in self:
            keys = list(self._groups.keys())
        else:
            keys = list(self.keys())
        #end if
        return keys
    #end def get_keys
#end class HDFgroup
330
331
332
333
class HDFreader(DevBase):
    """Load an HDF5 file into a tree of HDFgroup objects.

    The root group is available as self.obj; read success is reported
    via the _success flag. With view=True, datasets remain h5py views
    instead of being read into numpy arrays.
    """

    # String forms of the h5py dataset/group classes, covering both the
    # old (highlevel) and current (_hl) h5py module layouts.
    datasets = set(["<class 'h5py.highlevel.Dataset'>","<class 'h5py._hl.dataset.Dataset'>"])
    groups   = set(["<class 'h5py.highlevel.Group'>","<class 'h5py._hl.group.Group'>"])

    def __init__(self,fpath,verbose=False,view=False):
        # Open fpath and convert its contents into an HDFgroup tree.
        # On open failure, _success is False and self.obj stays empty.

        HDFglobals.view = view

        if verbose:
            print('  Initializing HDFreader')
        #end if

        self.fpath=fpath
        if verbose:
            print('    loading h5 file')
        #end if

        try:
            self.hdf = h5py.File(fpath,'r')
        except IOError:
            self._success = False
            self.hdf = obj(obj=obj())
        else:
            self._success = True
        #end if

        if verbose:
            print('    converting h5 file to dynamic object')
        #end if

        #convert the hdf 'dict' into a dynamic object
        self.nlevels=1
        self.ilevel=0
        #  Set the current hdf group
        self.obj = HDFgroup()
        self.cur=[self.obj]
        self.hcur=[self.hdf]

        if self._success:
            cur   = self.cur[self.ilevel]
            hcur  = self.hcur[self.ilevel]
            # classify each top-level entry as dataset or group
            for kr,v in hcur.items():
                k=cur._escape_name(kr)
                vtype = str(type(v))
                if vtype in HDFreader.datasets:
                    self.add_dataset(cur,k,v)
                elif vtype in HDFreader.groups:
                    self.add_group(hcur,cur,k,v)
                else:
                    self.error('encountered invalid type: '+vtype)
                #end if
            #end for
        #end if

        if verbose:
            print('  end HDFreader Initialization')
        #end if

        return
    #end def __init__


    def increment_level(self):
        # Descend one level, growing the cur/hcur stacks when a new
        # maximum depth is reached.
        self.ilevel+=1
        self.nlevels = max(self.ilevel+1,self.nlevels)
        if self.ilevel+1==self.nlevels:
            self.cur.append(None)
            self.hcur.append(None)
        #end if
        self.pad = self.ilevel*'  '
        return
    #end def increment_level

    def decrement_level(self):
        # Ascend one level.
        self.ilevel-=1
        self.pad = self.ilevel*'  '
        return
    #end def decrement_level

    def add_dataset(self,cur,k,v):
        # Attach dataset v to group cur under name k; copied into a
        # numpy array unless view mode is active.
        if not HDFglobals.view:
            cur[k] = np.array(v)
        else:
            cur[k] = v
        #end if
        cur._add_dataset(k,cur[k])
        return
    #end def add_dataset

    def add_group(self,hcur,cur,k,v):
        # Attach HDF group v to cur under name k and recurse into its
        # contents.
        # NOTE(review): the level counter is incremented here but never
        # decremented after recursion; the enclosing loops use local
        # cur/hcur variables, so the produced tree is unaffected.
        cur[k] = HDFgroup()
        cur._add_group(k,cur[k])
        cur._groups[k]._parent = cur
        self.increment_level()
        self.cur[self.ilevel]  = cur._groups[k]
        self.hcur[self.ilevel] = hcur[k]

        cur   = self.cur[self.ilevel]
        hcur  = self.hcur[self.ilevel]
        for kr,v in hcur.items():
            k=cur._escape_name(kr)
            vtype = str(type(v))
            if vtype in HDFreader.datasets:
                self.add_dataset(cur,k,v)
            elif vtype in HDFreader.groups:
                self.add_group(hcur,cur,k,v)
            #end if
        #end for
    #end def add_group
#end class HDFreader
444
445######################################################################
446# end from hdfreader.py
447######################################################################
448
449
450
451# Represents QMCPACK data file.
452#   Used to read scalar.dat, stat.h5, dmc.dat, traces.h5
class DataFile(DevBase):
    """Base class for QMCPACK output files.

    Subclasses implement read(); the constructor optionally reads the
    file, applies quantity-name aliases, and verifies that all requested
    quantities are present.
    """

    # optional map of raw quantity names to canonical names
    aliases = None

    def __init__(self,filepath=None,quantities=None):
        self.data       = None
        self.filepath   = filepath
        self.quantities = None if quantities is None else list(quantities)
        if filepath is None:
            return
        #end if
        self.read(filepath)
        # rename aliased quantities to their canonical names
        if self.aliases is not None:
            for raw,canonical in self.aliases.items():
                if raw in self.data:
                    self.data[canonical] = self.data[raw]
                    del self.data[raw]
                #end if
            #end for
        #end if
        # verify that every requested quantity was found
        if quantities is not None:
            missing = set(quantities)-set(self.data.keys())
            if len(missing)>0:
                self.error('some quantities are missing from file "{}"\nquantities present: {}\nquantities missing: {}'.format(self.filepath,sorted(self.data.keys()),sorted(missing)))
            #end if
        #end if
    #end def __init__

    def read(self,filepath):
        # subclasses override this to populate self.data
        None
    #end def read
#end class DataFile
487
488
489# Used to parse scalar.dat and dmc.dat files
# Used to parse scalar.dat and dmc.dat files
class DatFile(DataFile):
    def read(self,filepath):
        """Load a whitespace-delimited .dat file.

        Column names come from the '#'-prefixed header line (first two
        tokens skipped); the leading data column (series/step index) is
        dropped. Populates self.data with one array per column name.
        """
        lt = np.loadtxt(filepath)
        # promote a single data row to 2D so the slicing below is uniform
        if len(lt.shape)==1:
            lt.shape = (1,len(lt))
        #end if

        # drop the leading index column; transpose to one row per quantity
        data = lt[:,1:].transpose()

        # use a context manager so the file is closed even if the read
        # raises (the original left the handle open on error)
        with open(filepath,'r') as fobj:
            variables = fobj.readline().split()[2:]
        #end with

        self.data = obj()
        for i,vname in enumerate(variables):
            self.data[vname]=data[i,:]
        #end for
    #end def read
#end class DatFile
509
510
511# Parses scalar.dat
class ScalarDatFile(DatFile):
    # scalar.dat labels the per-block weight "BlockWeight"; rename it to
    # "Weight" for consistency with the traces-derived data
    aliases = obj(BlockWeight='Weight')
#end class ScalarDatFile
515
516
517# Parses dmc.dat
class DmcDatFile(DatFile):
    # dmc.dat shares the generic .dat format; no specialization needed
    pass
#end class DmcDatFile
521
522
523
524# Parses scalar data from stat.h5
class ScalarHDFFile(DataFile):
    """Reads scalar quantities from a stat.h5 file."""

    def read(self,filepath):
        """Populate self.data with one flattened array per quantity."""
        hr = HDFreader(filepath)
        if not hr._success:
            self.error('hdf file read failed\nfile path: '+filepath)
        #end if
        contents = hr.obj
        contents._remove_hidden()

        # each group carrying a 'value' dataset is a scalar quantity
        self.data = obj()
        for name,group in contents.items():
            if 'value' in group:
                self.data[name] = group.value.flatten()
            #end if
        #end for
    #end def read
#end class ScalarHDFFile
542
543
544
545# Parses and organizes data from traces.h5.
546#   QMCPACK writes one traces.h5 for each MPI task.
547#   At every MC step, data from each walker is written to this file.
class TracesFileHDF(DataFile):
    """Data parsed from a single traces.h5 file (one file per MPI task).

    read() unpacks the serialized per-walker tables into int_traces and
    real_traces, then aggregates walker data into scalars_by_step and
    scalars_by_block for comparison with scalar.dat/stat.h5/dmc.dat.
    """

    def __init__(self,filepath=None,blocks=None):
        # blocks: number of MC blocks in the run; required to fold
        # per-step totals into per-block totals in accumulate_scalars
        self.info = obj(
            blocks              = blocks,
            particle_sums_valid = None,
            )
        DataFile.__init__(self,filepath)
    #end def __init__


    def read(self,filepath=None):
        """Read the traces.h5 file and aggregate the per-walker data."""
        # Open the traces.h5 file
        hr = HDFreader(filepath)
        if not hr._success:
            self.error('hdf file read failed\nfile path: '+filepath)
        #end if
        hdf = hr.obj
        hdf._remove_hidden()

        # Translate from flat table structure to nested data structure.
        #   Do this for top level "int_data" and "real_data" HDF groups
        for name,buffer in hdf.items():
            self.init_trace(name,buffer)
        #end for

        # Sum trace data over walkers into per-step and per-block totals
        self.accumulate_scalars()
    #end def read


    # Translate from serialized traces table format to fully dimensioned
    #   data resolved by physical quantity.
    def init_trace(self,name,fbuffer):
        """Unpack one HDF buffer ("int_data" or "real_data") into
        per-domain, per-quantity arrays stored on self."""
        trace = obj()
        if 'traces' in fbuffer:
            # Wide data table
            #   Each row corresponds to a single step of a single walker.
            #   Each row contains serialized scalar (e.g. LocalEnergy)
            #   and array (e.g. electron coordinates) data.
            ftrace = fbuffer.traces

            # Number of rows (walkers*steps for single mpi task)
            nrows = len(ftrace)

            # Serialization layout of each row is stored in "layout".
            #   Layout is separated into a few potential domains:
            #     scalar domain  : scalar quantities such as LocalEnergy
            #                      domain name is "scalars"
            #     electron domain: array quantities dimensioned like number of electrons (e.g. electron positions)
            #                      domain name follows particleset (often "e")
            #     ion domain     : array quantities dimensioned like number of ions
            #                      domain name follows particleset (often "ion0")
            for dname,fdomain in fbuffer.layout.items():
                domain = obj()
                # Get start and end row indices for each quantity
                for qname,fquantity in fdomain.items():
                    q = obj()
                    for vname,value in fquantity.items():
                        q[vname] = value
                    #end for

                    # extract per quantity data across all walkers and steps
                    quantity = ftrace[:,q.row_start:q.row_end]

                    # reshape from serialized to multidimensional data for the quantity
                    if q.unit_size==1:
                        shape = [nrows]+list(fquantity.shape[0:q.dimension])
                    else:
                        shape = [nrows]+list(fquantity.shape[0:q.dimension])+[q.unit_size]
                    #end if
                    quantity.shape = tuple(shape)
                    #if len(fquantity.shape)==q.dimension:
                    #    quantity.shape = tuple([nrows]+list(fquantity.shape))
                    ##end if
                    domain[qname] = quantity
                #end for
                trace[dname] = domain
            #end for
        else:
            self.error('traces are missing in file "{}"'.format(self.filepath))
        #end if
        # rename "int_data" and "real_data" as "int_traces" and "real_traces"
        self[name.replace('data','traces')] = trace
    #end def init_trace


    # Perform internal consistency check between per-walker single
    #   particle energies and per-walker total energies.
    def check_particle_sums(self,tol):
        """Return True if per-particle quantities summed over particles
        match the corresponding scalar totals within tol."""
        t = self.real_traces

        # Determine quantities present as "scalars" (total values) and also per-particle
        scalar_names = set(t.scalars.keys())
        other_names = []
        for dname,domain in t.items():
            if dname!='scalars':
                other_names.extend(domain.keys())
            #end if
        #end for
        other_names = set(other_names)
        sum_names = scalar_names & other_names

        # For each quantity, determine if the sum over particles matches the total
        same = True
        for qname in sum_names:
            # Get total value for each quantity
            q = t.scalars[qname]

            # Perform the sum over particles
            qs = 0*q
            for dname,domain in t.items():
                if dname!='scalars' and qname in domain:
                    tqs = domain[qname].sum(1)
                    if len(tqs.shape)==1:
                        qs[:,0] += tqs
                    else:
                        qs[:,0] += tqs[:,0]
                    #end if
                #end if
            #end for

            # Compare total and summed quantities
            qsame = (abs(q-qs)<tol).all()
            if qsame:
                log('{:<16} matches'.format(qname),n=3)
            else:
                log('{:<16} does not match'.format(qname),n=3)
            #end if
            same = same and qsame
        #end for
        self.info.particle_sums_valid = same
        return self.info.particle_sums_valid
    #end def check_particle_sums


    # Sum trace data over walkers into per-step and per-block totals
    def accumulate_scalars(self):
        """Fold per-walker traces into weighted per-step and per-block
        averages, stored as scalars_by_step and scalars_by_block."""
        # Get block and step information for the qmc method
        blocks = self.info.blocks
        if blocks is None:
            self.scalars_by_step  = None
            self.scalars_by_block = None
            return
        #end if

        # Get real and int valued trace data
        tr = self.real_traces
        ti = self.int_traces

        # Names shared by traces and scalar files
        scalar_names = set(tr.scalars.keys())

        # Walker step and weight traces
        st = ti.scalars.step
        wt = tr.scalars.weight
        if len(st)!=len(wt):
            self.error('weight and steps traces have different lengths')
        #end if

        # Compute number of steps and steps per block
        steps = st.max()+1
        steps_per_block = steps//blocks

        # Accumulate weights into steps and blocks
        ws   = np.zeros((steps,))
        wb   = np.zeros((blocks,))
        #ws2  = np.zeros((steps,))
        for t in range(len(wt)): # accumulate over walker population per step
            ws[st[t]] += wt[t]
            #ws2[st[t]] += 1.0 # scalar.dat/stat.h5 set weights to 1.0 in dmc
        #end for
        s = 0
        for b in range(blocks): # accumulate over steps in a block
            wb[b] = ws[s:s+steps_per_block].sum()
            #wb[b] = ws2[s:s+steps_per_block].sum()
            s+=steps_per_block
        #end for

        # Accumulate walker population into steps
        ps  = np.zeros((steps,))
        for t in range(len(wt)):
            ps[st[t]] += 1
        #end for

        # Accumulate quantities into steps and blocks
        #   These are the values directly comparable with data in
        #   scalar.dat, stat.h5, and dmc.dat.
        scalars_by_step  = obj(Weight=ws,NumOfWalkers=ps)
        scalars_by_block = obj(Weight=wb)
        qs   = np.zeros((steps,))
        qb   = np.zeros((blocks,))
        qs2  = np.zeros((steps,))
        quantities = set(tr.scalars.keys())
        quantities.remove('weight')
        for qname in quantities:
            qt = tr.scalars[qname]
            if len(qt)!=len(wt):
                self.error('quantity {0} trace is not commensurate with weight and steps traces'.format(qname))
            #end if
            qs[:] = 0
            #qs2[:] = 0
            for t in range(len(wt)):
                qs[st[t]] += wt[t]*qt[t]
                #qs2[st[t]] += 1.0*qt[t]
            #end for
            qb[:] = 0
            s=0
            for b in range(blocks):
                qb[b] = qs[s:s+steps_per_block].sum()
                #qb[b] = qs2[s:s+steps_per_block].sum()
                s+=steps_per_block
            #end for
            # weighted averages per block and per step
            qb = qb/wb
            qs = qs/ws
            scalars_by_step[qname]  = qs.copy()
            scalars_by_block[qname] = qb.copy()
        #end for
        self.scalars_by_step  = scalars_by_step
        self.scalars_by_block = scalars_by_block
    #end def accumulate_scalars
#end class TracesFileHDF
769
770
771
772# Aggregates data from the full collection of traces.h5 files for a
773#   single series (e.g. VMC == series 0) and compares aggregated trace
774#   values to data in scalar.dat, stat.h5, and dmc.dat.
775class TracesAnalyzer(DevBase):
776
777    # Read data from scalar.dat, stat.h5, dmc.dat and all traces.h5 for the series
    def __init__(self,options):
        """Load scalar.dat, stat.h5, dmc.dat (DMC runs only) and all
        traces.h5 files for the series described by the parsed command
        line options."""

        # validity flags are filled in later by the check_* methods
        self.particle_sums_valid = None
        self.scalar_dat_valid    = None
        self.stat_h5_valid       = None
        self.dmc_dat_valid       = None

        self.failed = False

        self.options = options.copy()

        prefix = options.prefix
        series = options.series
        method = options.method
        mpi    = options.mpi
        pseudo = options.pseudo
        path   = options.path

        # Determine the quantities to check
        dmc_dat_quants = ['Weight','LocalEnergy','NumOfWalkers']
        scalar_quants = ['LocalEnergy','Kinetic','LocalPotential',
                         'ElecElec','IonIon']
        if not pseudo:
            scalar_quants.append('ElecIon')
        else:
            scalar_quants.extend(['LocalECP','NonLocalECP'])
        #end if
        scalar_dat_quants = scalar_quants+['Weight']
        stat_h5_quants  = scalar_quants
        # restrict each list to user-requested quantities (Weight is
        # always kept since it is needed for the weighted averages)
        if self.options.quantities is not None:
            for qlist in (scalar_dat_quants,stat_h5_quants,dmc_dat_quants):
                old = list(qlist)
                del qlist[:]
                for q in old:
                    if q in self.options.quantities or q=='Weight':
                        qlist.append(q)
                    #end if
                #end for
            #end for
        #end if

        # Make paths to scalar, stat, dmc and traces files
        prefix = prefix+'.s'+str(series).zfill(3)

        scalar_file = os.path.join(path,prefix+'.scalar.dat')
        stat_file   = os.path.join(path,prefix+'.stat.h5')
        dmc_file    = os.path.join(path,prefix+'.dmc.dat')

        # one traces.h5 per MPI task ('.pNNN' infix when mpi>1)
        trace_files = []
        if mpi==1:
            tf = os.path.join(path,prefix+'.traces.h5')
            trace_files.append(tf)
        else:
            for n in range(mpi):
                tf = os.path.join(path,prefix+'.p'+str(n).zfill(3)+'.traces.h5')
                trace_files.append(tf)
            #end for
        #end if

        # Check that all output files exist
        files = [scalar_file,stat_file]
        if method=='dmc':
            files.append(dmc_file)
        #end if
        files.extend(trace_files)
        for filepath in files:
            if not os.path.exists(filepath):
                self.error('filepath {} does not exist'.format(filepath))
            #end if
        #end for

        # Load scalar, stat, dmc, and traces files

        # Load scalar.dat
        self.scalar_dat = ScalarDatFile(scalar_file,scalar_dat_quants)

        # Load stat.h5
        self.stat_h5 = ScalarHDFFile(stat_file,stat_h5_quants)

        # Load dmc.dat
        self.dmc_dat = None
        if method=='dmc':
            self.dmc_dat = DmcDatFile(dmc_file,dmc_dat_quants)
        #end if

        # Load all traces.h5
        self.data = obj()
        # block count = length of any scalar.dat column
        blocks = len(self.scalar_dat.data.first())
        for filepath in sorted(trace_files):
            trace_file = TracesFileHDF(filepath,blocks)
            self.data.append(trace_file)
        #end for
        assert(len(self.data)==mpi)

    #end def __init__
873
874
875    # Check that per-particle quantities sum to total/scalar quantities
876    #   in each traces.h5 file separately.
877    def check_particle_sums(self,tol):
878        same = True
879        for trace_file in self.data:
880            log('Checking traces file: {}'.format(os.path.basename(trace_file.filepath)),n=2)
881            same &= trace_file.check_particle_sums(tol=tol)
882        #end for
883        self.particle_sums_valid = same
884        self.failed |= not same
885        self.pass_fail(same,n=2)
886        return same
887    #end def check_particle_sums
888
889
890    # Check aggregated traces data against scalar.dat
891    def check_scalar_dat(self,tol):
892        valid = self.check_scalar_file('scalar.dat',self.scalar_dat,tol)
893        self.scalar_dat_valid = valid
894        self.pass_fail(valid,n=2)
895        return valid
896    #end def check_scalar_dat
897
898
899    # Check aggregated traces data against stat.h5
900    def check_stat_h5(self,tol):
901        valid = self.check_scalar_file('stat.h5',self.stat_h5,tol)
902        self.stat_h5_valid = valid
903        self.pass_fail(valid,n=2)
904        return valid
905    #end def check_stat_h5
906
907
908    # Shared checking implementation for scalar.dat and stat.h5
    def check_scalar_file(self,file_type,scalar_file,tol):
        """Compare weight-averaged per-block traces data, aggregated over
        all traces.h5 files, against the reference file values.

        file_type   : label used in messages ('scalar.dat' or 'stat.h5')
        scalar_file : DataFile holding the reference per-block data
        tol         : absolute tolerance for block-by-block agreement
        Returns True when every block of every quantity matches.
        """

        # Check that expected quantities are present
        qnames = scalar_file.quantities
        trace_names = set(self.data[0].scalars_by_block.keys())
        missing = set(qnames)-trace_names
        if len(missing)>0:
            self.error('{} file check failed for series {}\ntraces file is missing some quantities\nquantities present: {}\nquantities missing: {}'.format(file_type,self.options.series,sorted(trace_names),sorted(missing)))
        #end if

        scalars_valid  = True
        scalars        = scalar_file.data

        # Sum traces data for each quantity across all traces.h5 files
        summed_scalars = obj()
        for qname in qnames:
            summed_scalars[qname] = np.zeros(scalars[qname].shape)
        #end for
        wtot = np.zeros(summed_scalars.first().shape)
        for trace_file in self.data:
            # scalar.dat/stat.h5 are resolved per block
            w = trace_file.scalars_by_block.Weight
            wtot += w
            for qname in qnames:
                q = trace_file.scalars_by_block[qname]
                summed_scalars[qname] += w*q
            #end for
        #end for

        # Compare summed trace data against scalar.dat/stat.h5 values
        for qname in qnames:
            qscalar = scalars[qname]
            qb = summed_scalars[qname]/wtot
            match = abs(qb-qscalar)<tol
            all_match = match.all()
            self.log('{:<16} {}/{} blocks match'.format(qname,match.sum(),len(match)),n=2)
            if not all_match:
                # report block index, file value, traces value, difference
                for b,(m,qfile,qtrace) in enumerate(zip(match,qscalar,qb)):
                    if not m:
                        log('{:>3}  {: 16.12f}  {: 16.12f}  {: 16.12f}'.format(b,qfile,qtrace,qfile-qtrace),n=3)
                    #end if
                #end for
            #end if
            scalars_valid &= all_match
        #end for

        self.failed |= not scalars_valid

        return scalars_valid
    #end def check_scalar_file
959
960
961    # Check aggregated traces data against dmc.dat
962    def check_dmc_dat(self,tol):
963
964        # Some DMC steps are excluded due to a known bug in QMCPACK weights
965        dmc_steps_exclude = self.options.dmc_steps_exclude
966
967        # Check that expected quantities are present
968        dmc_file = self.dmc_dat
969        qnames = dmc_file.quantities
970        trace_names = set(self.data[0].scalars_by_step.keys())
971        missing = set(qnames)-trace_names
972        if len(missing)>0:
973            self.error('dmc.dat check failed for series {}\ntraces file is missing some quantities\nquantities present: {}\nquantities missing: {}'.format(self.options.series,sorted(trace_names),sorted(missing)))
974        #end if
975        weighted = set(['LocalEnergy'])
976
977        dmc_valid = True
978        dmc       = dmc_file.data
979
980        # Sum traces data for each quantity across all traces.h5 files
981        summed_scalars = obj()
982        for qname in qnames:
983            summed_scalars[qname] = np.zeros(dmc[qname].shape)
984        #end for
985        wtot = np.zeros(summed_scalars.first().shape)
986        for trace_file in self.data:
987            # dmc.dat is resolved per step
988            w = trace_file.scalars_by_step.Weight
989            wtot += w
990            for qname in qnames:
991                q = trace_file.scalars_by_step[qname]
992                if qname in weighted:
993                    summed_scalars[qname] += w*q
994                else:
995                    summed_scalars[qname] += q
996                #end if
997            #end for
998        #end for
999
1000        # Compare summed trace data against dmc.dat values
1001        for qname in qnames:
1002            qdmc = dmc[qname]
1003            if qname in weighted:
1004                qb = summed_scalars[qname]/wtot
1005            else:
1006                qb = summed_scalars[qname]
1007            #end if
1008            match = abs(qb-qdmc)<tol
1009            all_match = match.all()
1010            self.log('{:<16} {}/{} steps match'.format(qname,match.sum(),len(match)),n=2)
1011            if not all_match:
1012                for s,(m,qfile,qtrace) in enumerate(zip(match,qdmc,qb)):
1013                    if not m:
1014                        log('{:>3}  {: 16.12f}  {: 16.12f}  {: 16.12f}'.format(s,qfile,qtrace,qfile-qtrace),n=3)
1015                    #end if
1016                #end for
1017            #end if
1018            if dmc_steps_exclude>0:
1019                all_match = match[dmc_steps_exclude:].all()
1020            #end if
1021            dmc_valid &= all_match
1022        #end for
1023
1024        if dmc_steps_exclude>0:
1025            log('\nExcluding first {} DMC steps from match check.'.format(dmc_steps_exclude),n=2)
1026        #end if
1027
1028        self.dmc_dat_valid = dmc_valid
1029        self.pass_fail(dmc_valid,n=2)
1030
1031        self.failed |= not dmc_valid
1032
1033        return dmc_valid
1034    #end def check_dmc_dat
1035
1036
1037    # Print a brief message about pass/fail status of a subtest
1038    def pass_fail(self,passed,n):
1039        if passed:
1040            self.log('\nCheck passed!',n=n)
1041        else:
1042            self.log('\nCheck failed!',n=n)
1043        #end if
1044    #end def pass_fail
1045#end class TracesAnalyzer
1046
1047
1048
# Usage examples printed via the -x/--examples flag.
# NOTE(fix): the mpirun invocations previously read "mpirun -np qmcpack",
# which is invalid (-np requires a task count).  Each example is explicitly
# a single-mpi-task run, so the correct invocation is "mpirun -np 1 qmcpack".
examples = '''

===================================================================
Example 1: QMCPACK VMC/DMC with scalar-only traces, single mpi task
===================================================================

Contents of QMCPACK input file (qmc.in.xml):
--------------------------------------------
<simulation>
   <project id="qmc" series="0">
      ...
   </project>

   <qmcsystem>
      ...
   </qmcsystem>

   # write traces files w/o array info (scalars only)
   <traces array="no"/>

   # vmc run, scalars written to stat.h5
   <qmc method="vmc" move="pbyp">
      <estimator name="LocalEnergy" hdf5="yes"/>
      ...
   </qmc>

   # dmc run, scalars written to stat.h5
   <qmc method="dmc" move="pbyp" checkpoint="-1">
     <estimator name="LocalEnergy" hdf5="yes"/>
     ...
   </qmc>

</simulation>

QMCPACK execution:
------------------
export OMP_NUM_THREADS=1
mpirun -np 1 qmcpack qmc.in.xml

QMCPACK output files:
---------------------
qmc.s000.qmc.xml
qmc.s000.scalar.dat
qmc.s000.stat.h5
qmc.s000.traces.h5
qmc.s001.cont.xml
qmc.s001.dmc.dat
qmc.s001.qmc.xml
qmc.s001.scalar.dat
qmc.s001.stat.h5
qmc.s001.traces.h5

check_traces.py usage:
----------------------
check_traces.py -p qmc -s '0,1' -m 'vmc,dmc' --dmc_steps_exclude=1

Either execute check_traces.py in /your/path/to/qmcpack/output as above, or do:

check_traces.py -p qmc -s '0,1' -m 'vmc,dmc' --dmc_steps_exclude=1 /your/path/to/qmcpack/output


====================================================================
Example 2: QMCPACK VMC/DMC, selective scalar traces, single mpi task
====================================================================

Contents of QMCPACK input file (qmc.in.xml):
--------------------------------------------
<simulation>
   <project id="qmc" series="0">
      ...
   </project>

   <qmcsystem>
      ...
   </qmcsystem>

   # write traces files w/o array info (scalars only)
   <traces array="no">
      <scalar_traces>
         Kinetic ElecElec  # only write traces of Kinetic and ElecElec
      </scalar_traces>
   </traces>

   # vmc run, scalars written to stat.h5
   <qmc method="vmc" move="pbyp">
      <estimator name="LocalEnergy" hdf5="yes"/>
      ...
   </qmc>

   # dmc run, scalars written to stat.h5
   <qmc method="dmc" move="pbyp" checkpoint="-1">
     <estimator name="LocalEnergy" hdf5="yes"/>
     ...
   </qmc>

</simulation>

QMCPACK execution:
------------------
export OMP_NUM_THREADS=1
mpirun -np 1 qmcpack qmc.in.xml

QMCPACK output files:
---------------------
qmc.s000.qmc.xml
qmc.s000.scalar.dat
qmc.s000.stat.h5
qmc.s000.traces.h5
qmc.s001.cont.xml
qmc.s001.dmc.dat
qmc.s001.qmc.xml
qmc.s001.scalar.dat
qmc.s001.stat.h5
qmc.s001.traces.h5

check_traces.py usage:
----------------------
check_traces.py -p qmc -s '0,1' -m 'vmc,dmc' -q 'Kinetic,ElecElec' --dmc_steps_exclude=1


===================================================================
Example 3: QMCPACK VMC/DMC, selective array traces, single mpi task
===================================================================

Contents of QMCPACK input file (qmc.in.xml):
--------------------------------------------
<simulation>
   <project id="qmc" series="0">
      ...
   </project>

   <qmcsystem>
      ...
   </qmcsystem>

   # write traces files w/ all scalar info and select array info
   <traces>
      <array_traces>
         position Kinetic  # write per-electron positions and kinetic energies
      </array_traces>
   </traces>

   # vmc run, scalars written to stat.h5
   <qmc method="vmc" move="pbyp">
      <estimator name="LocalEnergy" hdf5="yes"/>
      ...
   </qmc>

   # dmc run, scalars written to stat.h5
   <qmc method="dmc" move="pbyp" checkpoint="-1">
     <estimator name="LocalEnergy" hdf5="yes"/>
     ...
   </qmc>

</simulation>

QMCPACK execution:
------------------
export OMP_NUM_THREADS=1
mpirun -np 1 qmcpack qmc.in.xml

QMCPACK output files:
---------------------
qmc.s000.qmc.xml
qmc.s000.scalar.dat
qmc.s000.stat.h5
qmc.s000.traces.h5
qmc.s001.cont.xml
qmc.s001.dmc.dat
qmc.s001.qmc.xml
qmc.s001.scalar.dat
qmc.s001.stat.h5
qmc.s001.traces.h5

check_traces.py usage:
----------------------
check_traces.py -p qmc -s '0,1' -m 'vmc,dmc' --psum --dmc_steps_exclude=1

'''
1228
1229
1230
if __name__=='__main__':

    # Define command line options.
    # FIX: the help strings for -p/--prefix and -q/--quantities were
    # copy-paste duplicates of other options' help text; they now describe
    # the options they belong to.
    usage = '''usage: %prog [options] [path]'''
    parser = OptionParser(usage=usage,add_help_option=False,version='%prog 0.1')

    parser.add_option('-h','--help',dest='help',
                      action='store_true',default=False,
                      help='Print help information and exit (default=%default).'
                      )
    parser.add_option('-x','--examples',dest='examples',
                      action='store_true',default=False,
                      help='Print usage examples and exit (default=%default).'
                      )
    parser.add_option('-p','--prefix',dest='prefix',
                      default='qmc',
                      help='Prefix shared by the QMCPACK output files (default=%default).'
                      )
    parser.add_option('-s','--series',dest='series',
                      default='None',
                      help='Series number(s) to check (default=%default).'
                      )
    parser.add_option('-m','--methods',dest='methods',
                      default='None',
                      help='QMC method for each series.  Can be "vmc" or "dmc" for each series (default=%default).'
                      )
    parser.add_option('-q','--quantities',dest='quantities',
                      default='default',
                      help='Quantities to check for each series, e.g. "Kinetic,ElecElec".  Use "default" to check a standard set (default=%default).'
                      )
    parser.add_option('-n','--mpi',dest='mpi',
                      default='1',
                      help='Number of MPI tasks in the original QMCPACK run.  This is also the number of traces.h5 files produced by a single VMC or DMC section (default=%default).'
                      )
    parser.add_option('--psum',dest='particle_sum',
                      action='store_true',default=False,
                      help='Check sums of single particle energies (default=%default).'
                      )
    # NOTE(review): store_true with default=True means this flag cannot be
    # disabled from the command line; kept as-is for interface compatibility.
    parser.add_option('--pseudo',dest='pseudo',
                      action='store_true',default=True,
                      help='QMC calculation used pseudopotentials (default=%default).'
                      )
    parser.add_option('--dmc_steps_exclude',dest='dmc_steps_exclude',
                      default='0',
                      help='Exclude a number of DMC steps from being checked.  This option is temporary and will be removed once a bug in the DMC weights for the first step is fixed (default=%default).'
                      )
    parser.add_option('--tol',dest='tolerance',
                      default='1e-8',
                      help='Tolerance to check (default=%default).'
                      )

    opt,paths = parser.parse_args()

    options = obj(**opt.__dict__)

    # Process command line options
    if options.help:
        log('\n'+parser.format_help().strip()+'\n')
        sys.exit(0)
    #end if

    if options.examples:
        log(examples)
        sys.exit(0)
    #end if

    tol = float(options.tolerance)

    # Resolve the (single) run directory; default is the current directory.
    if len(paths)==0:
        options.path = './'
    elif len(paths)==1:
        options.path = paths[0]
    else:
        error('Only a single path is accepted as input.\nPaths provided:\n{}'.format(paths))
    #end if
    if not os.path.exists(options.path):
        error('Path to QMCPACK run does not exist.\nPath provided: {}'.format(options.path))
    elif os.path.isfile(options.path):
        error('Path to QMCPACK run is actually a file.\nOnly directory paths are accepted.\nPath provided: {}'.format(options.path))
    #end if

    # Convert a comma- or space-separated option string into a typed list.
    # (Parameter renamed from "type" to avoid shadowing the builtin.)
    def process_list(slist,converter=lambda x: x):
        tokens = slist.strip('"').strip("'").replace(',',' ').split()
        lst = [converter(t) for t in tokens]
        return lst
    #end def process_list

    options.series  = process_list(options.series,int)
    options.methods = process_list(options.methods)
    options.mpi     = int(options.mpi)
    options.dmc_steps_exclude = int(options.dmc_steps_exclude)
    if options.quantities=='default':
        options.quantities = None
    else:
        options.quantities = process_list(options.quantities)
    #end if

    # Validate series/method pairing and method names.
    if len(options.series)!=len(options.methods):
        error('"series" and "methods" must match in length.')
    #end if
    valid_methods = ['vmc','dmc']
    invalid = set(options.methods)-set(valid_methods)
    if len(invalid)>0:
        error('Invalid entries given for "methods".\nValid options are: {}\nYou provided: {}'.format(valid_methods,sorted(invalid)))
    #end if

    valid_quantities = '''
        LocalEnergy Kinetic LocalPotential ElecElec IonIon ElecIon LocalECP NonLocalECP
        '''.split()
    if options.quantities is not None:
        invalid = set(options.quantities)-set(valid_quantities)
        if len(invalid)>0:
            error('Invalid entries given for "quantities".\nValid options are: {}\nYou provided: {}'.format(valid_quantities,sorted(invalid)))
        #end if
    #end if


    # Parse all files across all requested series and compare traces vs scalar/stat/dmc

    log('\nChecking match between traces and scalar/stat/dmc files\n')

    log('\nOptions provided:\n'+str(options).rstrip())

    failed = False

    series_in  = options.series
    methods_in = options.methods

    # Series/method are set per iteration below; remove the list-valued
    # entries so each TracesAnalyzer sees a single series and method.
    del options.series
    del options.methods

    # Loop over vmc/dmc series
    for series,method in zip(series_in,methods_in):

        options.series = series
        options.method = method

        log('\n\nChecking series {} method={}'.format(series,method))

        # Read scalar.dat, stat.h5, dmc.dat, and *traces.h5 for the series
        ta = TracesAnalyzer(options)

        # Check traces data against scalar/stat/dmc files
        if method=='vmc':
            if options.particle_sum:
                log('\nChecking sums of single particle energies',n=1)
                ta.check_particle_sums(tol)
            #end if

            log('\nChecking scalar.dat',n=1)
            ta.check_scalar_dat(tol)

            log('\nChecking stat.h5',n=1)
            ta.check_stat_h5(tol)

        elif method=='dmc':
            if options.particle_sum:
                log('\nChecking sums of single particle energies',n=1)
                ta.check_particle_sums(tol)
            #end if

            # DMC scalar.dat/stat.h5 statistics are computed after branching,
            # but traces are written before branching, so those files cannot
            # be reconstructed from the traces.
            log('\nSkipping checks of scalar.dat and stat.h5',n=1)
            log('Statistics for these files are currently computed\n'
                'after branching. Since traces are written before \n'
                'branching, these files cannot be reconstructed \n'
                'from the traces.',n=2)

            log('\nChecking dmc.dat',n=1)
            ta.check_dmc_dat(tol)
        #end if

        failed |= ta.failed
    #end for

    # Print final pass/fail message
    if failed:
        test_fail()
    else:
        test_pass()
    #end if
#end if
1412