1##################################################################
2##  (c) Copyright 2015-  by Jaron T. Krogel                     ##
3##################################################################
4
5
6#====================================================================#
7#  qmcpack_analyzer_base.py                                          #
8#    Data object and analyzer base classes for QmcpackAnalyzer.      #
9#    Maintains data global to these classes.                         #
10#                                                                    #
11#  Content summary:                                                  #
12#    QAobject                                                        #
13#      Base class for all QMCPACK analyzer components.               #
14#      Exposes settings options to the user.                         #
15#                                                                    #
16#    Checks                                                          #
17#      Class to assess overall validity based on stored results of   #
18#      many checks (boolean values). Only use so far is to validate  #
19#      the structure of Trace files. See qmcpack_method_analyzers.py.#
20#                                                                    #
21#    Plotter                                                         #
22#      Wrapper class for mayavi visualization of isosurfaces and     #
23#      surface slices. Previously used to visualize energy densities.#
24#      See qmcpack_quantity_analyzers.py and spacegrid.py.           #
25#                                                                    #
26#    QAdata                                                          #
27#      Represents stored data from QMCPACK's output files.           #
28#      Classification marks it as a target for potential merging,    #
29#      e.g. twist averaging.                                         #
30#                                                                    #
31#    QAHDFdata                                                       #
32#      Specialization of QAdata for data from HDF files.             #
33#                                                                    #
34#    QAanalyzer                                                      #
35#      Base class for analyzer classes. Analyzers load and analyze   #
36#      data. Base class functionality includes recursive traversal   #
37#      of nested analyzer object structures for loading and          #
38#      analyzing data.                                               #
39#                                                                    #
40#====================================================================#
41
42
43from numpy import minimum,resize
44from generic import obj
45from developer import DevBase
46from hdfreader import HDFgroup
47from debug import *
48
49
50
51
52import numpy as np
53
class Plotter(DevBase):
    """
    Lazy wrapper around mayavi for 3D visualization.

    mayavi/tvtk are imported only on first use (see ensure_init), so code
    that never plots does not need mayavi installed.
    """
    def __init__(self):
        self.initialized = False
        return
    #end def __init__

    def ensure_init(self):
        """
        Import mayavi's mlab and tvtk on first call and cache handles on self.

        The ETS >= 4 package layout (mayavi, tvtk) is tried first; the legacy
        enthought.* namespace is used as a fallback for old installations.
        """
        if not self.initialized:
            try:
                from mayavi import mlab
                from tvtk.api import tvtk
            except ImportError:
                from enthought.mayavi import mlab
                from enthought.tvtk.api import tvtk
            #end try
            self.mlab = mlab
            self.tvtk = tvtk

            self.show   = mlab.show
            self.plot3d = mlab.plot3d
            self.mesh   = mlab.mesh

            self.initialized = True
        #end if
    #end def ensure_init

    def isosurface(self,points,scalars,contours,dimensions,name='val'):
        """
        Draw isosurfaces of a scalar field sampled on a structured grid.

        points     : grid point coordinates (passed to tvtk.StructuredGrid)
        scalars    : scalar value at each grid point
        contours   : int         -> number of automatically placed contours
                     list/tuple  -> explicit contour values
        dimensions : grid dimensions (passed to tvtk.StructuredGrid)
        name       : label attached to the scalar data
        """
        # validate before importing mayavi / modifying the scene
        if not isinstance(contours,(int,list,tuple)):
            self.error('isosurface contours must be an int or list\n  a '+str(type(contours))+' was provided instead')
            return
        #end if
        self.ensure_init()
        mlab = self.mlab
        tvtk = self.tvtk
        sg = tvtk.StructuredGrid(dimensions=dimensions,points=points)
        sg.point_data.scalars = scalars
        sg.point_data.scalars.name = name
        d = mlab.pipeline.add_dataset(sg)
        iso = mlab.pipeline.iso_surface(d)
        if isinstance(contours,int):
            iso.contour.number_of_contours = contours
        else:
            iso.contour.auto_contours = False
            iso.contour.contours = list(contours)
        #end if
        return
    #end def isosurface

    def surface_slice(self,x,y,z,scalars,options=None):
        """
        Plot a surface and a copy of it displaced along its normals.

        x,y,z   : surface coordinate arrays (passed to mlab.mesh)
        scalars : field on the surface; sets the displacement and the color
        options : optional container supporting `in` and attribute access,
                  with any of
                    norm_height - scale so the largest |scalar| is displaced
                                  by this distance
                    scale       - explicit displacement scale
                                  (takes precedence over norm_height)
                    opacity     - opacity of the displaced surface
        """
        scale   = 1.0
        opacity = 1.0
        if options is not None:
            if 'norm_height' in options:
                # normalize by the largest magnitude (not the magnitude of the
                # largest value) so mostly-negative fields scale correctly;
                # an all-zero field is left unscaled instead of giving NaNs
                maxabs = abs(scalars).max()
                if maxabs>0:
                    scale = options.norm_height/maxabs
                #end if
            #end if
            if 'scale' in options:
                scale = options.scale
            #end if
            if 'opacity' in options:
                opacity = options.opacity
            #end if
        #end if
        self.ensure_init()
        from numerics import surface_normals
        # undisplaced surface drawn faintly for reference
        self.mesh(x,y,z,opacity=.2)
        # normals indexed as [...,0..2] below; presumably shape x.shape+(3,)
        # -- confirm against numerics.surface_normals
        surfnorm = scale*surface_normals(x,y,z)
        # copy-then-assign keeps the dtype of the input coordinate arrays
        xs = x.copy()
        ys = y.copy()
        zs = z.copy()
        xs[...] = x[...] + surfnorm[...,0]*scalars[...]
        ys[...] = y[...] + surfnorm[...,1]*scalars[...]
        zs[...] = z[...] + surfnorm[...,2]*scalars[...]
        self.mesh(xs,ys,zs,scalars=scalars,opacity=opacity)
        return
    #end def surface_slice
#end class Plotter
120
121
122
class QAobj_base(DevBase):
    """
    Root base class for QmcpackAnalyzer objects.

    QAobject applies its analyzer settings to this class through class_set,
    presumably as class attributes seen by every subclass.
    """
#end class QAobj_base
126
127
class QAobject(QAobj_base):
    """
    Base class for all QMCPACK analyzer components.

    Holds state shared by every component (a registry of objects with
    dynamic methods and a single shared Plotter) and exposes user-settable
    options through the settings() classmethod.
    """

    # state shared by all QAobject instances
    _global = obj()
    _global.dynamic_methods_objects=[]

    # one lazily-initialized plotter shared by all components
    plotter = Plotter()

    opt_methods = set(['opt','linear','cslinear'])

    def __init__(self):
        return
    #end def __init__

    @staticmethod
    def condense_name(name):
        """
        Normalize a label into a lowercase, underscore-separated key.

        Surrounding whitespace is stripped, spaces and dashes become
        underscores, and runs of underscores collapse to a single one,
        e.g. 'Kinetic - Energy' -> 'kinetic_energy'.
        """
        cname = name.strip().lower().replace(' ','_').replace('-','_')
        # a single str.replace leaves '__' behind for runs of 3+ underscores
        while '__' in cname:
            cname = cname.replace('__','_')
        #end while
        return cname
    #end def condense_name


    def _register_dynamic_methods(self):
        """Add self to the global registry used by unlink/relink below."""
        QAobject._global.dynamic_methods_objects.append(self)
        return
    #end def _register_dynamic_methods

    def _unlink_dynamic_methods(self):
        """Call _unset_dynamic_methods() on every registered object."""
        # _unset_dynamic_methods is expected to be defined by registered
        # objects (not shown here)
        for o in QAobject._global.dynamic_methods_objects:
            o._unset_dynamic_methods()
        #end for
        return
    #end def _unlink_dynamic_methods

    def _relink_dynamic_methods(self):
        """Call _reset_dynamic_methods() on every registered object."""
        for o in QAobject._global.dynamic_methods_objects:
            o._reset_dynamic_methods()
        #end for
        return
    #end def _relink_dynamic_methods


    _allowed_settings = set(['optimize'])
    _default_settings = obj(
        # alternatives noted previously: 'variance',
        # 'energy_within_variance_tol' (also 'ewvt')
        optimize = 'lastcost'
        )
    QAobj_base.class_set(**_default_settings)

    @classmethod
    def settings(cls,**kwargs):
        """
        Set global analyzer options.

        Only names in _allowed_settings are accepted; unknown names are
        reported via cls.class_error together with the valid options.
        Accepted values are applied to QAobj_base through class_set.
        """
        invalid = set(kwargs)-cls._allowed_settings
        if len(invalid)>0:
            allowed = sorted(cls._allowed_settings)
            invalid = sorted(invalid)
            cls.class_error('attempted to set unknown variables\n  unknown variables: {0}\n  valid options are: {1}'.format(invalid,allowed))
        #end if
        QAobj_base.class_set(**kwargs)
    #end def settings
#end class QAobject
189
190
191
class Checks(DevBase):
    """
    Named collection of boolean checks with an overall validity verdict.

    Checks are stored as ordinary entries (name -> value).  Entries whose
    name is a str starting with '_' are internal and ignored.  Check values
    registered with exclude() mean "could not be checked" and do not affect
    the verdict.
    """
    def __init__(self,label=''):
        self._label = label
        self._exclusions = set()
    #end def __init__

    def exclude(self,value):
        """Treat checks whose value equals `value` as not checkable."""
        self._exclusions.add(value)
    #end def exclude

    def _checks(self):
        """Yield (name,value) for every non-internal entry."""
        for name,value in self.items():
            if not (isinstance(name,str) and name.startswith('_')):
                yield name,value
            #end if
        #end for
    #end def _checks

    def valid(self):
        """
        Return True if every checkable entry is truthy, else False.

        The result is also stored in self._valid.
        """
        valid = True
        for name,value in self._checks():
            if value not in self._exclusions:
                valid = valid and bool(value)
            #end if
        #end for
        self._valid = valid
        return valid
    #end def valid

    def write(self,pad=''):
        """
        Log the overall verdict and, if invalid, the status of each check.

        pad : prefix for the summary line; per-check lines get two more spaces
        """
        pad2 = pad+'  '
        # re-evaluate so the summary agrees with the per-check lines below
        # even if checks changed since the last call to valid()
        valid = self.valid()
        if valid:
            self.log(pad+self._label+' is valid')
        else:
            self.log(pad+self._label+' is invalid')
            for name,value in self._checks():
                # non-str names are allowed by _checks, so stringify
                sname = str(name)
                if value in self._exclusions:
                    self.log(pad2+sname+' could not be checked')
                elif value:
                    self.log(pad2+sname+' is valid')
                else:
                    self.log(pad2+sname+' is invalid')
                #end if
            #end for
        #end if
    #end def write
#end class Checks
239
240
241
242
class QAinformation(obj):
    """
    Plain obj container for analyzer bookkeeping.

    QAanalyzer stores its status flags (initialized, data_loaded, analyzed,
    failed) and its vlog indentation level (nindent) in one of these.
    """
#end class QAinformation
246
247
248
class QAdata(QAobject):
    """
    Container of array data loaded from QMCPACK output.

    Being a QAdata marks the object as a merge target: analyzers zero,
    truncate (minsize), accumulate and normalize QAdata members when
    combining results, e.g. for twist averaging.  Entries are expected to be
    numpy-like arrays (they are sliced, assigned in place and use .shape).
    """
    def zero(self):
        """Set every stored array to zero in place."""
        for value in self:
            value[:] = 0
        #end for
    #end def zero

    def minsize(self,other):
        """
        Truncate every stored array to the shape it shares with `other`.

        Each axis is cut to the smaller of the two lengths, keeping the
        leading elements.  Slicing is used rather than numpy.resize because
        resize refills in flattened C order, which misaligns data of more
        than one dimension.  Entries missing from `other` are reported via
        self.error.
        """
        for name,value in self.items():
            if name in other:
                shape = minimum(value.shape,other[name].shape)
                index = tuple(slice(0,n) for n in shape)
                # copy so the stored array does not alias the larger original
                self[name] = value[index].copy()
            else:
                self.error(name+' not found in minsize partner')
            #end if
        #end for
    #end def minsize

    def accumulate(self,other):
        """
        Add the matching arrays of `other` into this object in place.

        Only the first len(value) entries (along the first axis) of each
        partner array are used.  Entries missing from `other` are reported
        via self.error.
        """
        for name,value in self.items():
            if name in other:
                value += other[name][0:len(value)]
            else:
                self.error(name+' not found in accumulate partner')
            #end if
        #end for
    #end def accumulate

    def normalize(self,normalization):
        """Divide every stored array by `normalization` in place."""
        for value in self:
            value/=normalization
        #end for
    #end def normalize


    def sum(self):
        """Debugging aid: print the total sum over all stored arrays."""
        s = 0
        for value in self:
            s+=value.sum()
        #end for
        print('                sum = {0}'.format(s))
    #end def sum
#end class QAdata
295
296
297
class QAHDFdata(QAdata):
    """
    QAdata specialization whose entries may be HDFgroup objects.

    Every operation touches only HDFgroup entries, and within them only the
    'value' and 'value_squared' fields; all other entries are skipped.
    """
    def zero(self):
        """Zero 'value' and 'value_squared' of each HDFgroup entry."""
        for group in self:
            if not isinstance(group,HDFgroup):
                continue
            #end if
            group.zero('value','value_squared')
        #end for
    #end def zero

    def minsize(self,other):
        """Truncate each HDFgroup entry to its same-named partner in `other`."""
        for key,group in self.items():
            if not isinstance(group,HDFgroup):
                continue
            #end if
            partner_ok = key in other and isinstance(other[key],HDFgroup)
            if not partner_ok:
                self.error(key+' not found in minsize partner')
            else:
                group.minsize(other[key],'value','value_squared')
            #end if
        #end for
    #end def minsize

    def accumulate(self,other):
        """Add each same-named HDFgroup partner in `other` into this object."""
        for key,group in self.items():
            if not isinstance(group,HDFgroup):
                continue
            #end if
            partner_ok = key in other and isinstance(other[key],HDFgroup)
            if not partner_ok:
                self.error(key+' not found in accumulate partner')
            else:
                group.accumulate(other[key],'value','value_squared')
            #end if
        #end for
    #end def accumulate

    def normalize(self,normalization):
        """Divide 'value' and 'value_squared' of each HDFgroup entry."""
        for group in self:
            if not isinstance(group,HDFgroup):
                continue
            #end if
            group.normalize(normalization,'value','value_squared')
        #end for
    #end def normalize
#end class QAHDFdata
339
340
341
342
class QAanalyzer(QAobject):
    """
    Base class for analyzers, which load and analyze QMCPACK data.

    An analyzer may hold QAdata members (its own data), child QAanalyzer
    members, and QAanalyzerCollection members whose QAanalyzer entries are
    also treated as children.  The public operations below act on this
    analyzer and then recurse into every child, so calling them on the root
    of a nested analyzer structure processes the whole tree.  Derived
    classes customize behavior by overriding the *_local hooks and
    set_global_info/unset_global_info.
    """

    # when True, vlog() messages are written through self.log
    verbose_vlog = False

    capabilities = None
    request      = None
    run_info     = None
    method_info  = None

    opt_methods = set(['opt','linear','cslinear'])
    vmc_methods = set(['vmc'])
    dmc_methods = set(['dmc'])


    def __init__(self,nindent=0):
        """
        Create the status record (self.info) for this analyzer.

        nindent : base indentation level for this analyzer's vlog output
        """
        self.info = QAinformation(
            initialized = False,
            data_loaded = False,
            analyzed    = False,
            failed      = False,
            nindent     = nindent
            )
        self.vlog('building '+self.__class__.__name__)
    #end def __init__

    def subindent(self):
        """Return the vlog indentation level to use for child analyzers."""
        return self.info.nindent+1
    #end def subindent

    def vlog(self,msg,n=0):
        """Log msg at indentation nindent+n when verbose_vlog is enabled."""
        if QAanalyzer.verbose_vlog:
            self.log(msg,n=self.info.nindent+n)
        #end if
    #end def vlog

    def reset_indicators(self,initialized=None,data_loaded=None,analyzed=None):
        """
        Overwrite status flags in self.info.

        Flags passed as anything other than None replace the stored value;
        flags left as None are unchanged.
        """
        if initialized is not None:
            self.info.initialized = initialized
        #end if
        if data_loaded is not None:
            self.info.data_loaded = data_loaded
        #end if
        if analyzed is not None:
            self.info.analyzed = analyzed
        #end if
    #end def reset_indicators


    # ---- hooks for derived classes ----

    def init_sub_analyzers(self):
        """Create child analyzers; must be supplied by derived classes."""
        self.not_implemented()
    #end def init_sub_analyzers

    def load_data_local(self):
        """Load this analyzer's own data (not its children's); no-op here."""
        pass
    #end def load_data_local

    def remove_data_local(self):
        """Delete this analyzer's 'data' member, if present."""
        if 'data' in self:
            del self.data
        #end if
    #end def remove_data_local

    def analyze_local(self):
        """Analyze this analyzer's own data (not its children's); no-op here."""
        pass
    #end def analyze_local

    def set_global_info(self):
        """Called at the start of analyze(); no-op here."""
        pass
    #end def set_global_info

    def unset_global_info(self):
        """Called at the end of analyze(), even if analysis raised; no-op here."""
        pass
    #end def unset_global_info


    # ---- traversal helpers ----

    def _sub_analyzers(self):
        """
        Yield the direct child analyzers in member order.

        Children are QAanalyzer members of self plus QAanalyzer entries of
        any QAanalyzerCollection member.
        """
        for name,value in self.items():
            if isinstance(value,QAanalyzer):
                yield value
            elif isinstance(value,QAanalyzerCollection):
                for n,v in value.items():
                    if isinstance(v,QAanalyzer):
                        yield v
                    #end if
                #end for
            #end if
        #end for
    #end def _sub_analyzers

    def _paired_data(self,other,opname):
        """
        Yield (data,partner) for each QAdata member matched by name in
        `other`.  The partner must be an instance of the member's class.
        A missing partner is reported via self.error naming `opname`.
        """
        for name,value in self.items():
            if isinstance(value,QAdata):
                if name in other and isinstance(other[name],value.__class__):
                    yield value,other[name]
                else:
                    self.error('data '+name+' not found in '+opname+' partner')
                #end if
            #end if
        #end for
    #end def _paired_data

    def _paired_sub_analyzers(self,other,opname):
        """
        Yield (child,partner) pairs matching this analyzer's children with
        those of `other`.

        Children are matched by member name (and by entry name within
        collections).  A partner must be an instance of the child's class.
        A missing or mismatched partner is reported via self.error naming
        `opname`.  If self.error returns, that child is skipped rather than
        paired with a stale partner from an earlier iteration.
        """
        for name,value in self.items():
            if isinstance(value,QAanalyzer):
                if name in other and isinstance(other[name],value.__class__):
                    yield value,other[name]
                else:
                    self.error('analyzer '+name+' not found in '+opname+' partner')
                #end if
            elif isinstance(value,QAanalyzerCollection):
                if not (name in other and isinstance(other[name],QAanalyzerCollection)):
                    self.error('collection '+name+' not found in '+opname+' partner')
                    continue
                #end if
                ovalue = other[name]
                for n,v in value.items():
                    if isinstance(v,QAanalyzer):
                        if n in ovalue and isinstance(ovalue[n],v.__class__):
                            yield v,ovalue[n]
                        else:
                            self.error('analyzer '+n+' not found in '+opname+' partner collection '+name)
                        #end if
                    #end if
                #end for
            #end if
        #end for
    #end def _paired_sub_analyzers

    def _load_data_local_once(self):
        """Run load_data_local() unless info.data_loaded is already set."""
        if not self.info.data_loaded:
            self.vlog('loading '+self.__class__.__name__+' data',n=1)
            self.load_data_local()
            self.info.data_loaded = True
        #end if
    #end def _load_data_local_once


    # ---- recursive operations ----

    def propagate_indicators(self,**kwargs):
        """Apply reset_indicators(**kwargs) to this analyzer and all children."""
        self.reset_indicators(**kwargs)
        for sub in self._sub_analyzers():
            sub.propagate_indicators(**kwargs)
        #end for
    #end def propagate_indicators

    def load_data(self):
        """Load local data (once), then load data of all children."""
        self._load_data_local_once()
        for sub in self._sub_analyzers():
            sub.load_data()
        #end for
    #end def load_data

    def analyze(self,force=False):
        """
        Analyze all children first, then this analyzer.

        Local data is loaded if needed.  analyze_local() runs only if this
        analyzer has not been analyzed yet, or if force is True.  force is
        passed on to children.  set_global_info/unset_global_info bracket
        the whole operation; unset runs even if an exception is raised.
        """
        self.set_global_info()
        try:
            self._load_data_local_once()
            for sub in self._sub_analyzers():
                sub.analyze(force)
            #end for
            if not self.info.analyzed or force:
                self.vlog('analyzing {0} data'.format(self.__class__.__name__),n=1)
                self.analyze_local()
                self.info.analyzed = True
            #end if
        finally:
            self.unset_global_info()
        #end try
    #end def analyze


    def remove_data(self):
        """Delete all QAdata members here and in all children."""
        self.vlog('removing '+self.__class__.__name__+' data',n=1)
        # snapshot keys: members are deleted while scanning
        for name in list(self.keys()):
            if isinstance(self[name],QAdata):
                del self[name]
            #end if
        #end for
        for sub in self._sub_analyzers():
            sub.remove_data()
        #end for
    #end def remove_data


    def zero_data(self):
        """Zero all QAdata members here and in all children."""
        self.vlog('zeroing '+self.__class__.__name__+' data',n=1)
        for value in self:
            if isinstance(value,QAdata):
                value.zero()
            #end if
        #end for
        for sub in self._sub_analyzers():
            sub.zero_data()
        #end for
    #end def zero_data


    def minsize_data(self,other):
        """
        Truncate QAdata here and in all children to the sizes shared with
        the structurally matching analyzer `other`.
        """
        self.vlog('minsizing '+self.__class__.__name__+' data',n=1)
        for value,ovalue in self._paired_data(other,'minsize_data'):
            value.minsize(ovalue)
        #end for
        for sub,osub in self._paired_sub_analyzers(other,'minsize_data'):
            sub.minsize_data(osub)
        #end for
    #end def minsize_data


    def accumulate_data(self,other):
        """
        Add QAdata of the structurally matching analyzer `other` into this
        analyzer and all children.
        """
        self.vlog('accumulating '+self.__class__.__name__+' data',n=1)
        for value,ovalue in self._paired_data(other,'accumulate_data'):
            value.accumulate(ovalue)
        #end for
        for sub,osub in self._paired_sub_analyzers(other,'accumulate_data'):
            sub.accumulate_data(osub)
        #end for
    #end def accumulate_data


    def normalize_data(self,normalization):
        """Normalize all QAdata members here and in all children."""
        self.vlog('normalizing '+self.__class__.__name__+' data',n=1)
        for value in self:
            if isinstance(value,QAdata):
                value.normalize(normalization)
            #end if
        #end for
        for sub in self._sub_analyzers():
            sub.normalize_data(normalization)
        #end for
    #end def normalize_data

#end class QAanalyzer
645
646
647
class QAanalyzerCollection(QAobject):
    """
    Named group of analyzers.

    When a QAanalyzer holds one of these, the recursive QAanalyzer
    operations also visit the collection's QAanalyzer entries.
    """
#end class QAanalyzerCollection
651