1##################################################################
2##  (c) Copyright 2015-  by Jaron T. Krogel                     ##
3##################################################################
4
5
6#====================================================================#
7#  qmcpack_quantity_analyzers.py                                     #
8#    Analyzer classes for specific quantities generated by QMCPACK.  #
9#    Quantities include scalar values from scalars.dat, dmc.dat,     #
10#    or stat.h5 and general quantities from stat.h5 such as the      #
11#    energy density, 1-body density matrices, total densities,       #
12#    spin densities, and static structure factors.  Also supports    #
13#    basic analysis of Traces data (multiple traces.h5 files).       #
14#                                                                    #
15#  Content summary:                                                  #
16#    QuantityAnalyzer                                                #
17#      Base class for specific quantity analyzers.                   #
18#                                                                    #
19#    DatAnalyzer                                                     #
20#      Base class containing common characteristics of *.dat file    #
21#      analysis.                                                     #
22#                                                                    #
23#    ScalarsDatAnalyzer                                              #
24#      Supports analysis specific to scalars.dat.                    #
25#                                                                    #
26#    DmcDatAnalyzer                                                  #
27#      Supports analysis specific to dmc.dat.                        #
28#                                                                    #
29#    HDFAnalyzer                                                     #
30#      Base class for analyzers of stat.h5 data.                     #
31#                                                                    #
32#    ScalarsHDFAnalyzer                                              #
33#      Supports analysis specific to scalar values in stat.h5        #
34#                                                                    #
35#    EnergyDensityAnalyzer                                           #
36#      Supports analysis of energy density data from stat.h5         #
37#                                                                    #
38#    DensityMatricesAnalyzer                                         #
39#      Supports analysis of 1-body particle or energy density        #
40#      matrices from stat.h5.                                        #
41#                                                                    #
42#    DensityAnalyzer                                                 #
43#      Supports analysis of total densities from stat.h5.            #
44#                                                                    #
45#    SpinDensityAnalyzer                                             #
46#      Supports analysis of spin-resolved densities from stat.h5.    #
47#                                                                    #
48#    StructureFactorAnalyzer                                         #
49#      Supports analysis of spin-resolved static structure factors   #
50#      from stat.h5.                                                 #
51#                                                                    #
52#    TracesFileHDF                                                   #
53#      Represents an HDF file containing traces data.                #
54#      One traces.h5 file is produced per MPI process.               #
55#                                                                    #
56#    TracesAnalyzer                                                  #
57#      Supports basic analysis of Traces data.                       #
58#      Can read multiple traces.h5 files and validate against        #
59#        data contained in scalars.dat and dmc.dat.                  #
60#                                                                    #
61#    SpaceGrid                                                       #
62#      Specifically for energy density analysis                      #
63#      Represents a grid of data in 3-dimensional space.             #
#      Can represent rectilinear grids in Cartesian, cylindrical,    #
#      or spherical coordinates as well as Voronoi grids.            #
66#                                                                    #
67#====================================================================#
68
69
70import os
71import re
72from numpy import array,zeros,dot,loadtxt,ceil,floor,empty,sqrt,trace,savetxt,concatenate,real,imag,diag,arange,ones,identity
73try:
74    from scipy.linalg import eig,LinAlgError
75except Exception:
76    from numpy.linalg import eig,LinAlgError
77#end try
78from numerics import ndgrid,simstats,simplestats,equilibration_length
79from generic import obj
80from hdfreader import HDFreader
81from qmcpack_analyzer_base import QAobject,QAanalyzer,QAdata,QAHDFdata
82from fileio import XsfFile
83from debug import *
84
85
class QuantityAnalyzer(QAanalyzer):
    """Base class for analyzers of a single kind of QMCPACK output quantity.

    Subclasses load raw samples into ``self.data`` (keyed by quantity name)
    and store per-quantity statistics on ``self`` (e.g. ``self[name].mean``).
    """
    def __init__(self,nindent=0):
        QAanalyzer.__init__(self,nindent=nindent)
        # bind the class-level QAanalyzer.method_info onto the instance
        self.method_info = QAanalyzer.method_info
    #end def __init__

    def plot_trace(self,quantity,*args,**kwargs):
        """Plot the sample trace of ``quantity`` with its statistics.

        Draws the raw trace (extra ``args``/``kwargs`` are forwarded to that
        plot call), a dash-dot vertical line at the equilibration cutoff,
        and horizontal lines at the mean and mean +/- sqrt(sample_variance).
        Does nothing if no data has been loaded; reports an error if
        ``quantity`` is not in the loaded data.
        """
        from matplotlib.pyplot import plot,xlabel,ylabel,title,ylim
        if 'data' not in self:
            return
        #end if
        if quantity not in self.data:
            self.error('quantity '+quantity+' is not present in the data')
        #end if
        nbe = self.get_nblocks_exclude()
        q = self.data[quantity]
        nq = len(q)
        # base the y range on the second half of the trace so that a large
        # early (equilibration) transient does not dominate the view
        tail = q[nq//2:]
        qmean = tail.mean()
        qmax  = tail.max()
        qmin  = tail.min()
        ylims = [qmean-2*(qmean-qmin),qmean+2*(qmax-qmean)]
        smean,svar = self[quantity].tuple('mean','sample_variance')
        sstd = sqrt(svar)
        plot(q,*args,**kwargs)
        plot([nbe,nbe],ylims,'k-.',lw=2)
        plot([0,nq],[smean,smean],'r-')
        plot([0,nq],[smean+sstd,smean+sstd],'r-.')
        plot([0,nq],[smean-sstd,smean-sstd],'r-.')
        # a constant tail gives identical limits; let matplotlib autoscale
        if ylims[1]>ylims[0]:
            ylim(ylims)
        #end if
        ylabel(quantity)
        xlabel('samples')
        title('Trace of '+quantity)
    #end def plot_trace

    def init_sub_analyzers(self):
        """Quantity analyzers have no sub-analyzers."""
        pass
    #end def init_sub_analyzers

    def get_nblocks_exclude(self):
        """Return the number of leading samples excluded as equilibration."""
        return self.info.nblocks_exclude
    #end def get_nblocks_exclude
#end class QuantityAnalyzer
127
128
class DatAnalyzer(QuantityAnalyzer):
    """Common base for analyzers of QMCPACK ``*.dat`` text output files."""
    def __init__(self,filepath=None,equilibration=None,nindent=0):
        """
        filepath      : path to the *.dat file to analyze.
        equilibration : optional name of a loaded quantity used to estimate
                        the equilibration length.  The estimate is only made
                        when method_info.nblocks_exclude is -1, in which case
                        the data is loaded immediately and nblocks_exclude is
                        replaced by the estimate.
        """
        QuantityAnalyzer.__init__(self,nindent=nindent)
        self.info.filepath = filepath
        nbe = self.method_info.nblocks_exclude
        if equilibration is not None and nbe==-1:
            self.load_data()
            nbe = equilibration_length(self.data[equilibration])
            # explicit check: an assert would be stripped under python -O
            if nbe<0:
                self.error('Number of equilibration blocks is negative.')
            #end if
            self.method_info.nblocks_exclude = nbe
        #end if
    #end def __init__

    def analyze_local(self):
        """Must be provided by each concrete *.dat analyzer."""
        self.not_implemented()
    #end def analyze_local
#end class DatAnalyzer
146
147
class ScalarsDatAnalyzer(DatAnalyzer):
    """Analyzer for QMCPACK scalars.dat files."""
    def load_data_local(self):
        """Load the columns of scalars.dat whose condensed names are among
        the requested quantities into ``self.data``.

        The first header line is expected to read '#', 'index', then one
        name per data column; the leading index column of the data is
        dropped to match.
        """
        filepath = self.info.filepath
        quantities = QAanalyzer.request.quantities

        # ndmin=2 keeps a single-row (or single-column) file two dimensional
        lt = loadtxt(filepath,ndmin=2)

        # one row per variable after dropping the index column
        data = lt[:,1:].transpose()

        with open(filepath,'r') as fobj:
            variables = fobj.readline().split()[2:]
        #end with

        self.data = QAdata()
        for i,var in enumerate(variables):
            if self.condense_name(var) in quantities:
                self.data[var]=data[i,:]
            #end if
        #end for
    #end def load_data_local


    def analyze_local(self):
        """Compute mean, sample variance, error and kappa (as returned by
        simstats) for each loaded quantity, discarding the first
        nblocks_exclude samples.

        When both LocalEnergy and LocalEnergy_sq were loaded, the per-block
        quantity LocalEnergy_sq - LocalEnergy**2 is analyzed the same way
        and stored as LocalEnergyVariance.
        """
        nbe = QAanalyzer.method_info.nblocks_exclude
        self.info.nblocks_exclude = nbe
        data = self.data

        def stats(samples):
            # package simstats output in the layout used for every quantity
            mean,var,error,kappa = simstats(samples)
            return obj(
                mean            = mean,
                sample_variance = var,
                error           = error,
                kappa           = kappa
                )
        #end def stats

        for varname,samples in data.items():
            self[varname] = stats(samples[nbe:])
        #end for

        # LocalEnergy may have been filtered out by the requested quantities
        if 'LocalEnergy_sq' in data and 'LocalEnergy' in data:
            v = data.LocalEnergy_sq - data.LocalEnergy**2
            self.LocalEnergyVariance = stats(v[nbe:])
        #end if
    #end def analyze_local
#end class ScalarsDatAnalyzer
201
202
class DmcDatAnalyzer(DatAnalyzer):
    """Analyzer for QMCPACK dmc.dat files."""
    def load_data_local(self):
        """Load every column of dmc.dat into ``self.data``.

        The first header line is expected to read '#', 'index', then one
        name per data column; the leading index column is dropped to match.
        """
        filepath = self.info.filepath

        # ndmin=2 keeps a single-row (or single-column) file two dimensional
        lt = loadtxt(filepath,ndmin=2)

        # one row per variable after dropping the index column
        data = lt[:,1:].transpose()

        with open(filepath,'r') as fobj:
            variables = fobj.readline().split()[2:]
        #end with

        self.data = QAdata()
        for i,var in enumerate(variables):
            self.data[var]=data[i,:]
        #end for
    #end def load_data_local


    def analyze_local(self):
        """Compute mean, sample variance, error and kappa for each quantity.

        Rows are treated as individual steps: the first nblocks_exclude
        blocks (times steps per block of this series) are discarded.  When
        more than 2*ndmc_blocks steps remain, the steps are averaged into
        ndmc_blocks equal-size blocks before computing statistics; any
        leftover steps are dropped from the front.
        """
        nbe = QAanalyzer.method_info.nblocks_exclude
        self.info.nblocks_exclude = nbe
        data = self.data

        qmc_input   = self.run_info.input
        series      = self.method_info.series
        ndmc_blocks = self.run_info.request.ndmc_blocks

        qmc    = qmc_input.get_qmc(series)
        steps  = qmc.steps
        nse    = nbe*steps

        # recorded before any block-averaging adjustment below
        self.info.nsteps_exclude = nse

        nsteps = len(data.list()[0])-nse

        # ndmc_blocks>0 guards the division below
        block_avg = ndmc_blocks>0 and nsteps>2*ndmc_blocks
        if block_avg:
            block_size  = int(floor(float(nsteps)/ndmc_blocks))
            ndmc_blocks = int(floor(float(nsteps)/block_size))
            # drop the remainder so the rest splits into whole blocks
            nse += nsteps-ndmc_blocks*block_size
        #end if

        for varname,samples in data.items():
            samp = samples[nse:]
            if block_avg:
                samp = samp.reshape(ndmc_blocks,block_size).mean(axis=1)
            #end if
            mean,var,error,kappa = simstats(samp)
            self[varname] = obj(
                mean            = mean,
                sample_variance = var,
                error           = error,
                kappa           = kappa
                )
        #end for
    #end def analyze_local


    def get_nblocks_exclude(self):
        """Return the number of excluded steps (dmc.dat rows are steps)."""
        return self.info.nsteps_exclude
    #end def get_nblocks_exclude
#end class DmcDatAnalyzer
275
276
class HDFAnalyzer(QuantityAnalyzer):
    """Common base for analyzers whose data is read from stat.h5."""
    def __init__(self,nindent=0):
        QuantityAnalyzer.__init__(self,nindent=nindent)
        # flag raised by subclasses when their data is missing from the file
        self.info.set(should_remove=False)
    #end def __init__
#end class HDFAnalyzer
283
284
class ScalarsHDFAnalyzer(HDFAnalyzer):
    """Analyzer for scalar quantities stored in stat.h5.

    Each quantity is an HDF entry holding a ``value`` array indexed by block
    along its first axis and, optionally, a ``value_squared`` array.
    Statistics are computed after discarding the first nblocks_exclude
    blocks.  Corrected local energies (see ``corrections``) are formed when
    their components are present.
    """
    # For each named LocalEnergy correction: the components it involves and
    # the sign with which each component is added to LocalEnergy.
    corrections = obj(
        mpc = obj(ElecElec=-1,MPC=1),
        kc  = obj(KEcorr=1)
        )

    def __init__(self,exclude,nindent=0):
        """exclude: names of stat.h5 entries that must not be loaded."""
        HDFAnalyzer.__init__(self,nindent=nindent)
        self.info.exclude = exclude
    #end def __init__


    def load_data_local(self,data=None):
        """Move scalar entries out of ``data`` into ``self.data``.

        Entries listed in info.exclude, entries whose name starts with '_'
        and entries whose name contains 'skall' (any case) are left in
        ``data``.  If LocalEnergy, ElecElec, MPC and KEcorr are all loaded,
        LocalEnergy_mpc_kc = LocalEnergy - ElecElec + MPC + KEcorr is added;
        its value_squared is added only if every component has one.
        """
        if data is None:
            self.error('attempted load without data')
        #end if
        exclude = self.info.exclude
        self.data = QAHDFdata()
        for var in list(data.keys()):
            if var in exclude or str(var).startswith('_') or 'skall' in var.lower():
                continue
            #end if
            self.data[var] = data[var]
            del data[var]
        #end for
        corrvars = ['LocalEnergy','ElecElec','MPC','KEcorr']
        # subset test (<=): a proper-subset test would wrongly skip the
        # correction when these are the only entries present
        if set(corrvars)<=set(self.data.keys()):
            Ed,Ved,Vmd,Kcd = self.data.tuple(*corrvars)
            E_mpc_kc = obj()
            E  = Ed.value
            Ve = Ved.value
            Vm = Vmd.value
            Kc = Kcd.value
            E_mpc_kc.value = E-Ve+Vm+Kc
            if all('value_squared' in d for d in (Ed,Ved,Vmd,Kcd)):
                E2  = Ed.value_squared
                Ve2 = Ved.value_squared
                Vm2 = Vmd.value_squared
                Kc2 = Kcd.value_squared
                # expansion of (E-Ve+Vm+Kc)**2: squares from value_squared,
                # cross terms from products of the per-block values
                E_mpc_kc.value_squared = E2+Ve2+Vm2+Kc2 + 2*(E*(-Ve+Vm+Kc)-Ve*(Vm+Kc)+Vm*Kc)
            #end if
            self.data.LocalEnergy_mpc_kc = E_mpc_kc
        #end if
    #end def load_data_local


    def analyze_local(self):
        """Compute statistics for every loaded scalar, then apply the
        combined mpc+kc correction if it has not been formed already.

        ``variance`` comes from value_squared when available, otherwise it
        equals the sample variance returned by simstats.
        """
        nbe = QAanalyzer.method_info.nblocks_exclude
        self.info.nblocks_exclude = nbe
        for varname,val in self.data.items():
            (mean,var,error,kappa)=simstats(val.value[nbe:,...].ravel())
            if 'value_squared' in val:
                variance = val.value_squared[nbe:,...].mean()-mean**2
            else:
                variance = var
            #end if
            self[varname] = obj(
                mean            = mean,
                variance        = variance,
                sample_variance = var,
                error           = error,
                kappa           = kappa
                )
        #end for
        self.correct('mpc','kc')
    #end def analyze_local


    def correct(self,*corrections):
        """Form and analyze LocalEnergy_<c1>_<c2>... for the named corrections.

        Each correction adds its components to LocalEnergy with the signs
        given in ``corrections``.  Unknown correction names or missing data
        produce a warning; missing components cause a silent return; an
        already-present result is left untouched.  Requires analyze_local to
        have set info.nblocks_exclude.
        """
        corrkey = '_'.join(corrections)
        unknown = set(corrections)-set(self.corrections.keys())
        if len(unknown)>0:
            self.warn('correction '+corrkey+' is unknown and cannot be applied')
            return
        #end if
        if 'data' not in self:
            self.warn('correction '+corrkey+' cannot be applied because data is not present')
            return
        #end if
        varname = 'LocalEnergy_'+corrkey
        if varname in self and varname in self.data:
            return
        #end if
        corrvars = ['LocalEnergy']
        signs    = [1]
        for corr in corrections:
            for var,sign in self.corrections[corr].items():
                corrvars.append(var)
                signs.append(sign)
            #end for
        #end for
        missing = list(set(corrvars)-set(self.data.keys()))
        if len(missing)>0:
            #self.warn('correction '+corrkey+' cannot be applied because '+str(missing)+' are missing')
            return
        #end if

        le    = self.data.LocalEnergy
        comps = [self.data[v] for v in corrvars]
        # the squared sum can only be formed if every component has squares
        has_sq = all('value_squared' in c for c in comps)
        E  = 0*le.value
        E2 = 0*le.value_squared if has_sq else None
        n = len(comps)
        for i in range(n):
            e = comps[i].value
            s = signs[i]
            E += s*e
            if has_sq:
                # s*s == 1, so squares always enter with a plus sign
                E2 += comps[i].value_squared
                for j in range(i+1,n):
                    E2 += 2*s*e*signs[j]*comps[j].value
                #end for
            #end if
        #end for
        val = obj(value=E)
        if has_sq:
            val.value_squared = E2
        #end if
        self.data[varname] = val
        nbe = self.info.nblocks_exclude
        (mean,var,error,kappa)=simstats(E[nbe:,...].ravel())
        if has_sq:
            variance = E2[nbe:,...].mean()-mean**2
        else:
            variance = var
        #end if
        self[varname] = obj(
            mean            = mean,
            variance        = variance,
            sample_variance = var,
            error           = error,
            kappa           = kappa
            )
    #end def correct
#end class ScalarsHDFAnalyzer
412
413
414
class EnergyDensityAnalyzer(HDFAnalyzer):
    """Analyzer for energy density data from stat.h5.

    The HDF group named ``name`` is moved into self.data.  Analysis builds
    SpaceGrid objects for the spacegrid entries and summarizes the data
    recorded outside all spacegrids (and at the ions, if present).
    The methods isosurface, etest, mtest, test and test_structured are
    visualization experiments that import enthought.mayavi / tvtk.
    """
    def __init__(self,name,nindent=0):
        HDFAnalyzer.__init__(self,nindent=nindent)
        # name:      key of the energy density group in the stat.h5 data
        # reordered: whether atomic data has been put into input-file order
        self.info.set(
            name = name,
            reordered = False
            )
    #end def __init__


    def load_data_local(self,data=None):
        """Move the group info.name out of ``data`` into self.data.

        If the group is absent, self.data stays empty and
        info.should_remove is set to True.
        """
        if data==None:
            self.error('attempted load without data')
        #end if
        name = self.info.name
        self.data = QAHDFdata()
        if name in data:
            hdfg = data[name]
            # strip hidden entries; deep=False presumably means top level only
            hdfg._remove_hidden(deep=False)
            self.data.transfer_from(hdfg)
            # remove the group from the shared input data
            del data[name]
        else:
            self.info.should_remove = True
        #end if
    #end def load_data_local


    def analyze_local(self):
        """Compute energy density statistics from the loaded data.

        In order:
          - entries of self.data whose names do not start with 'spacegrid'
            are copied onto self (e.g. self.outside, self.ions);
          - the spacegrid entries are counted and SpaceGrid objects are
            built from spacegrid1..spacegridN, given self.reference_points
            and the number of blocks to exclude;
          - atomic data is reordered to input order unless the run is
            bundled or reordering was already done;
          - mean/error of D, T, V, E=T+V and P=2/3*T+1/3*V are computed for
            the 'outside' data and, if present, the 'ions' data.  The raw
            HDF entries on self are replaced by summary objects whose .data
            holds the post-equilibration samples.

        NOTE(review): the spacegrid pattern allows zero trailing digits, so a
        bare 'spacegrid' entry is counted and then looked up as 'spacegrid1';
        the nspacegrids==0 branch (which reads data.spacegrid) can only run
        when no spacegrid entry exists at all.  Confirm the intended naming.
        """
        nbe = QAanalyzer.method_info.nblocks_exclude
        self.info.nblocks_exclude = nbe
        data = self.data

        #why is this called 3 times?
        #print nbe

        #transfer hdf data
        sg_pattern = re.compile(r'spacegrid\d*')
        nspacegrids=0
        #  add simple data first
        for k,v in data.items():
            if not sg_pattern.match(k):
                self[k] = v
            else:
                nspacegrids+=1
            #end if
        #end for
        #  add spacegrids second
        opts = QAobject()
        opts.points = self.reference_points
        opts.nblocks_exclude = nbe
        self.spacegrids=[]
        if nspacegrids==0:
            self.spacegrids.append(SpaceGrid(data.spacegrid,opts))
        else:
            for ig in range(nspacegrids):
                sg=SpaceGrid(data['spacegrid'+str(ig+1)],opts)
                self.spacegrids.append(sg)
            #end for
        #end if

        #reorder atomic data to match input file for Voronoi grids
        if self.run_info.type=='bundled':
            self.info.reordered=True
        #end if
        if not self.info.reordered:
            self.reorder_atomic_data()
        #end if

        #convert quantities outside all spacegrids
        # D,T,V presumably density, kinetic and potential -- confirm
        outside = QAobject()
        iD,iT,iV = tuple(range(3))
        outside.D  = QAobject()
        outside.T  = QAobject()
        outside.V  = QAobject()
        outside.E  = QAobject()
        outside.P  = QAobject()

        # self.outside is still the raw HDF entry copied above; assumes its
        # value is (nblocks, 3) so the transpose puts blocks last -- confirm
        value = self.outside.value.transpose()[...,nbe:]

        #mean,error = simplestats(value)
        mean,var,error,kappa = simstats(value)
        outside.D.mean   = mean[iD]
        outside.D.error  = error[iD]
        outside.T.mean   = mean[iT]
        outside.T.error  = error[iT]
        outside.V.mean   = mean[iV]
        outside.V.error  = error[iV]

        E  = value[iT,:]+value[iV,:]
        #mean,error = simplestats(E)
        mean,var,error,kappa = simstats(E)
        outside.E.mean  = mean
        outside.E.error = error

        P  = 2./3.*value[iT,:]+1./3.*value[iV,:]
        #mean,error = simplestats(P)
        mean,var,error,kappa = simstats(P)
        outside.P.mean  = mean
        outside.P.error = error

        # replace the raw HDF entry with the summary
        self.outside = outside

        self.outside.data = obj(
            D = value[iD,:],
            T = value[iT,:],
            V = value[iV,:],
            E = E,
            P = P
            )

        # convert ion point data, if present
        if 'ions' in self:
            ions = QAobject()
            ions.D  = QAobject()
            ions.T  = QAobject()
            ions.V  = QAobject()
            ions.E  = QAobject()
            ions.P  = QAobject()

            value = self.ions.value.transpose()[...,nbe:]

            mean,var,error,kappa = simstats(value)
            ions.D.mean   = mean[iD]
            ions.D.error  = error[iD]
            ions.T.mean   = mean[iT]
            ions.T.error  = error[iT]
            ions.V.mean   = mean[iV]
            ions.V.error  = error[iV]

            E  = value[iT,:]+value[iV,:]
            mean,var,error,kappa = simstats(E)
            ions.E.mean  = mean
            ions.E.error = error

            P  = 2./3.*value[iT,:]+1./3.*value[iV,:]
            mean,var,error,kappa = simstats(P)
            ions.P.mean  = mean
            ions.P.error = error

            ions.data = obj(
                D = value[iD,:],
                T = value[iT,:],
                V = value[iV,:],
                E = E,
                P = P
                )

            self.ions = ions
        #end if

        return
    #end def analyze_local


    def reorder_atomic_data(self):
        """Put per-atom spacegrid data into the ion order of the XML input.

        Acts only when particleset ion0 has more than one group and a
        'size'.  Builds imap, where imap[i] is the position of ion i
        after grouping ions by the group order found in the ordered XML
        input, and passes it to each SpaceGrid.  Always sets
        info.reordered to True.
        """
        input = self.run_info.input
        xml   = self.run_info.ordered_input
        ps = input.get('particlesets')
        if 'ion0' in ps and len(ps.ion0.groups)>1 and 'size' in ps.ion0:
            qsx = xml.simulation.qmcsystem
            # with a single particleset the xml entry is used directly;
            # otherwise search the collection for ion0
            if len(ps)==1:
                psx = qsx.particleset
            else:
                psx=None
                for pst in qsx.particleset:
                    if pst.name=='ion0':
                        psx=pst
                    #end if
                #end for
                if psx==None:
                    self.error('ion0 particleset not found in qmcpack xml file for atomic reordering of Voronoi energy density')
                #end if
            #end if

            #ordered ion names
            # xml groups are ordered the same as in qmcpack's input file
            ion_names = []
            for gx in psx.group:
                ion_names.append(gx.name)
            #end for

            #create the mapping to restore proper ordering
            nions = ps.ion0.size
            ions = ps.ion0.ionid
            imap=empty((nions,),dtype=int)
            icurr = 0
            for ion_name in ion_names:
                for i in range(len(ions)):
                    if ions[i]==ion_name:
                        imap[i]=icurr
                        icurr+=1
                    #end if
                #end for
            #end for

            #reorder the atomic data
            for sg in self.spacegrids:
                sg.reorder_atomic_data(imap)
            #end for
        #end if
        self.info.reordered=True
        return
    #end def reorder_atomic_data


    def remove_data(self):
        """Remove raw data: the base-class data, each spacegrid's data and
        outside.data.

        NOTE(review): ions.data, set in analyze_local, is not removed here.
        """
        QAanalyzer.remove_data(self)
        if 'spacegrids' in self:
            for sg in self.spacegrids:
                if 'data' in sg:
                    del sg.data
                #end if
            #end for
        #end if
        if 'outside' in self and 'data' in self.outside:
            del self.outside.data
        #end if
    #end def remove_data


    #def prev_init(self):
    #    if data._contains_group("spacegrid1"):
    #        self.points = data.spacegrid1.domain_centers
    #        self.axinv  = data.spacegrid1.axinv
    #        val = data.spacegrid1.value
    #        npoints,ndim = self.points.shape
    #        self.E = zeros((npoints,))
    #        print 'p shape ',self.points.shape
    #        print 'v shape ',val.shape
    #        nblocks,nvpoints = val.shape
    #        for b in range(nblocks):
    #            for i in range(npoints):
    #                ind = 6*i
    #                self.E[i] += val[b,ind+1] + val[b,ind+2]
    #            #end for
    #        #end for
    #    #end if
    ##end def prev_init



    def isosurface(self):
        """Experimental: contour plot of self.E on a 20x20x20 point set.

        Each point r is mapped through self.axinv before plotting.
        NOTE(review): self.points, self.axinv and self.E are only set in the
        commented-out prev_init above, so this method presumably fails as
        written.  It also assumes exactly 20**3 points.
        """
        from enthought.mayavi import mlab

        npoints,ndim = self.points.shape
        dimensions = array([20,20,20])

        x = zeros(dimensions)
        y = zeros(dimensions)
        z = zeros(dimensions)
        s = zeros(dimensions)

        ipoint = 0
        for i in range(dimensions[0]):
            for j in range(dimensions[1]):
                for k in range(dimensions[2]):
                    r = self.points[ipoint,:]
                    u = dot(self.axinv,r)
                    #u=r
                    x[i,j,k] = u[0]
                    y[i,j,k] = u[1]
                    z[i,j,k] = u[2]
                    s[i,j,k] = self.E[ipoint]
                    ipoint+=1
                #end for
            #end for
        #end for

        mlab.contour3d(x,y,z,s)
        mlab.show()

        return
    #end def isosurface

    def mesh(self):
        """Placeholder; does nothing."""
        return
    #end def mesh

    def etest(self):
        """Experimental: mayavi mesh plot of a separable test function
        f(r,phi,theta) sampled on a spherical-coordinate grid."""
        from enthought.mayavi import mlab
        from numpy import pi, sin, cos, exp, arange, array
        ni=10
        dr, dphi, dtheta = 1.0/ni, 2*pi/ni, pi/ni

        rlin = arange(0.0,1.0+dr,dr)
        plin = arange(0.0,2*pi+dphi,dphi)
        tlin = arange(0.0,pi+dtheta,dtheta)
        r,phi,theta = ndgrid(rlin,plin,tlin)

        a=1

        fr = .5*exp(-r/a)*(cos(2*pi*r/a)+1.0)
        fp = (1.0/6.0)*(cos(3.0*phi)+5.0)
        ft = (1.0/6.0)*(cos(10.0*theta)+5.0)

        f = fr*fp*ft

        x = r*sin(theta)*cos(phi)
        y = r*sin(theta)*sin(phi)
        z = r*cos(theta)


        #mayavi
        #mlab.contour3d(x,y,z,f)
        #mlab.contour3d(r,phi,theta,f)
        i=7
        #mlab.mesh(x[i],y[i],z[i],scalars=f[i])
        mlab.mesh(f[i]*x[i],f[i]*y[i],f[i]*z[i],scalars=f[i])
        mlab.show()

        return
    #end def etest


    def mtest(self):
        """Experimental: mayavi mesh plot of a surface r(theta,phi) sampled
        on a non-uniform angular grid."""
        from enthought.mayavi import mlab
        # Create the data.
        from numpy import pi, sin, cos, mgrid, arange, array
        ni = 100.0
        dtheta, dphi = pi/ni, pi/ni

        #[theta,phi] = mgrid[0:pi+dtheta:dtheta,0:2*pi+dphi:dphi]

        #tlin = arange(0,pi+dtheta,dtheta)
        #plin = arange(0,2*pi+dphi,dphi)
        tlin = pi*array([0,.12,.2,.31,.43,.56,.63,.75,.87,.92,1])
        plin = 2*pi*array([0,.11,.22,.34,.42,.58,.66,.74,.85,.97,1])
        theta,phi = ndgrid(tlin,plin)

        fp = (1.0/6.0)*(cos(3.0*phi)+5.0)
        ft = (1.0/6.0)*(cos(10.0*theta)+5.0)

        r = fp*ft

        x = r*sin(theta)*cos(phi)
        y = r*sin(theta)*sin(phi)
        z = r*cos(theta)

        # View it.
        s = mlab.mesh(x, y, z, scalars=r)
        mlab.show()
        return
    #end def mtest


    def test(self):
        """Experimental: contour plot of sin(xyz)/(xyz) on a grid sheared by
        the matrix A; returns the grid, values and A.

        NOTE(review): self.error is called before x, y, z are defined (the
        exec that defined them is commented out), so if error returns, the
        next use of x raises NameError.  This method cannot run as written.
        """
        from enthought.mayavi import mlab
        from numpy import array,dot,arange,sin,ogrid,mgrid,zeros

        n=10
        n2=2*n
        s = '-'+str(n)+':'+str(n)+':'+str(n2)+'j'
        self.error('alternative to exec needed')
        #exec('x, y, z = ogrid['+s+','+s+','+s+']')
        del s

        #x, y, z = ogrid[-10:10:20j, -10:10:20j, -10:10:20j]
        #x, y, z = mgrid[-10:11:1, -10:11:1, -10:11:1]

        s = sin(x*y*z)/(x*y*z)


        #xl = [-5.0,-4.2,-3.5,-2.1,-1.7,-0.4,0.7,1.8,2.6,3.7,4.3,5.0]
        #yl = [-5.0,-4.3,-3.6,-2.2,-1.8,-0.3,0.8,1.7,2.7,3.6,4.4,5.0]
        #zl = [-5.0,-4.4,-3.7,-2.3,-1.9,-0.4,0.9,1.6,2.8,3.5,4.5,5.0]
        dx = 2.0*n/(2.0*n-1.0)
        xl = arange(-n,n+dx,dx)
        yl = xl
        zl = xl

        x,y,z = ndgrid(xl,yl,zl)

        s2 = sin(x*y*z)/(x*y*z)

        #shear the grid
        nx,ny,nz = x.shape
        A = array([[1,1,-1],[1,-1,1],[-1,1,1]])
        #A = array([[3,2,1],[0,2,1],[0,0,1]])
        #A = array([[4,7,2],[8,4,3],[2,5,3]])
        #A = 1.0*array([[1,2,3],[4,5,6],[7,8,9]]).transpose()
        r = zeros((3,))
        np=0
        for i in range(nx):
            for j in range(ny):
                for k in range(nz):
                    r[0] = x[i,j,k]
                    r[1] = y[i,j,k]
                    r[2] = z[i,j,k]

                    #print np,r[0],r[1],r[2]
                    np+=1

                    r = dot(A,r)
                    x[i,j,k] = r[0]
                    y[i,j,k] = r[1]
                    z[i,j,k] = r[2]
                #end for
            #end for
        #end for
        s2 = sin(x*y*z)/(x*y*z)

        mlab.contour3d(x,y,z,s2)
        mlab.show()

        out = QAobject()
        out.x=x
        out.y=y
        out.z=z
        out.s=s2
        out.A=A

        return out
    #end def test


    def test_structured(self):
        """Experimental: tvtk StructuredGrid over a cylindrical annulus,
        contoured with mayavi using the distance from the origin as the
        scalar field."""

        import numpy as np
        from numpy import cos, sin, pi
        from enthought.tvtk.api import tvtk
        from enthought.mayavi import mlab

        def generate_annulus(r=None, theta=None, z=None):
            """ Generate points for structured grid for a cylindrical annular
                volume.  This method is useful for generating a unstructured
                cylindrical mesh for VTK (and perhaps other tools).

                Parameters
                ----------
                r : array : The radial values of the grid points.
                            It defaults to linspace(1.0, 2.0, 11).

                theta : array : The angular values of the x axis for the grid
                                points. It defaults to linspace(0,2*pi,11).

                z: array : The values along the z axis of the grid points.
                           It defaults to linspace(0.0, 1.0, 11).

                Return
                ------
                points : array
                    Nx3 array of points that make up the volume of the annulus.
                    They are organized in planes starting with the first value
                    of z and with the inside "ring" of the plane as the first
                    set of points.  The default point array will be 1331x3.
            """
            # Default values for the annular grid.
            if r is None: r = np.linspace(1.0, 2.0, 11)
            if theta is None: theta = np.linspace(0, 2*pi, 11)
            if z is None: z = np.linspace(0.0, 1.0, 11)

            # Find the x values and y values for each plane.
            x_plane = (cos(theta)*r[:,None]).ravel()
            y_plane = (sin(theta)*r[:,None]).ravel()

            # Allocate an array for all the points.  We'll have len(x_plane)
            # points on each plane, and we have a plane for each z value, so
            # we need len(x_plane)*len(z) points.
            points = np.empty([len(x_plane)*len(z),3])

            # Loop through the points for each plane and fill them with the
            # correct x,y,z values.
            start = 0
            for z_plane in z:
                end = start + len(x_plane)
                # slice out a plane of the output points and fill it
                # with the x,y, and z values for this plane.  The x,y
                # values are the same for every plane.  The z value
                # is set to the current z
                plane_points = points[start:end]
                plane_points[:,0] = x_plane
                plane_points[:,1] = y_plane
                plane_points[:,2] = z_plane
                start = end

            return points

        # Make the data.
        dims = (51, 25, 25)
        # Note here that the 'x' axis corresponds to 'theta'
        theta = np.linspace(0, 2*np.pi, dims[0])
        # 'y' corresponds to varying 'r'
        r = np.linspace(1, 10, dims[1])
        z = np.linspace(0, 5, dims[2])
        pts = generate_annulus(r, theta, z)
        # Uncomment the following if you want to add some noise to the data.
        #pts += np.random.randn(dims[0]*dims[1]*dims[2], 3)*0.04
        sgrid = tvtk.StructuredGrid(dimensions=dims)
        sgrid.points = pts
        s = np.sqrt(pts[:,0]**2 + pts[:,1]**2 + pts[:,2]**2)
        sgrid.point_data.scalars = np.ravel(s.copy())
        sgrid.point_data.scalars.name = 'scalars'

        contour = mlab.pipeline.contour(sgrid)
        mlab.pipeline.surface(contour)


        return
    #end def test_structured



#end class EnergyDensityAnalyzer
918
919
920
921
class TracesFileHDF(QAobject):
    """
    Lazily loaded contents of a single traces hdf file.

    The raw trace arrays are read by ``load`` and released by ``unload``.
    The diagnostic results derived from them are kept across an ``unload``:
    ``scalars_by_step``, ``scalars_by_block``, and the particle-sum check
    stored in ``info``.
    """
    def __init__(self,filepath=None,blocks=None):
        """
        Parameters
        ----------
        filepath : str or None
            Path of the traces hdf file; may instead be given to ``load``.
        blocks : int or None
            Number of blocks of the QMC run.  Needed by
            ``accumulate_scalars`` to form per-block averages.
        """
        self.info = obj(
            filepath    = filepath,
            loaded      = False,
            accumulated = False,
            particle_sums_valid = None,
            blocks      = blocks
            )
    #end def __init__

    def loaded(self):
        """Return True if the trace arrays are currently loaded."""
        return self.info.loaded
    #end def loaded

    def accumulated_scalars(self):
        """Return True if accumulate_scalars has completed."""
        return self.info.accumulated
    #end def accumulated_scalars

    def checked_particle_sums(self):
        """Return True if check_particle_sums has already been run."""
        return self.info.particle_sums_valid is not None
    #end def checked_particle_sums

    def formed_diagnostic_data(self):
        """Return True once scalars are accumulated and particle sums checked."""
        return self.accumulated_scalars() and self.checked_particle_sums()
    #end def formed_diagnostic_data

    def load(self,filepath=None,force=False):
        """
        Read the traces hdf file and build the per-quantity trace arrays.

        Every top-level hdf group is passed to ``init_trace``.  Nothing is
        done if the data is already loaded, unless ``force`` is True.
        ``filepath`` overrides the path given at construction.
        """
        if self.loaded() and not force:
            return
        #end if
        if filepath is None:
            if self.info.filepath is None:
                self.error('cannot load traces data, filepath has not been defined')
            else:
                filepath = self.info.filepath
            #end if
        #end if
        hr = HDFreader(filepath)
        if not hr._success:
            # NOTE(review): the message says contents are skipped, but the
            # code below still processes hr.obj -- confirm intent
            self.warn('  hdf file seems to be corrupted, skipping contents:\n    '+filepath)
        #end if
        hdf = hr.obj
        hdf._remove_hidden()
        for name,buffer in hdf.items():
            self.init_trace(name,buffer)
        #end for
        self.info.loaded = True
    #end def load

    def unload(self):
        """Drop the int/real trace arrays (if loaded) to free memory."""
        if self.loaded():
            if 'int_traces' in self:
                del self.int_traces
            #end if
            if 'real_traces' in self:
                del self.real_traces
            #end if
            self.info.loaded = False
        #end if
    #end def unload

    def init_trace(self,name,fbuffer):
        """
        Split one hdf buffer into named per-quantity trace arrays.

        ``fbuffer.traces`` is indexed as ``[row, column]``.  Each quantity
        listed in ``fbuffer.layout[domain][quantity]`` occupies the columns
        ``row_start:row_end``.  It is reshaped to
        ``(nrows,) + shape[:dimension]``, with an extra trailing
        ``unit_size`` axis when ``unit_size`` > 1.

        The result ``trace[domain][quantity]`` is stored on ``self`` under
        ``name`` with 'data' replaced by 'traces'.  Presumably this maps
        'int_data'/'real_data' to 'int_traces'/'real_traces'; verify
        against the writer.
        """
        trace = obj()
        if 'traces' in fbuffer:
            ftrace = fbuffer.traces
            nrows = len(ftrace)
            for dname,fdomain in fbuffer.layout.items():
                domain = obj()
                for qname,fquantity in fdomain.items():
                    # layout metadata entries are stored as length-1 arrays
                    q = obj()
                    for vname,value in fquantity.items():
                        q[vname] = value[0]
                    #end for
                    quantity = ftrace[:,q.row_start:q.row_end]
                    if q.unit_size==1:
                        shape = [nrows]+list(fquantity.shape[0:q.dimension])
                    else:
                        shape = [nrows]+list(fquantity.shape[0:q.dimension])+[q.unit_size]
                    #end if
                    quantity.shape = tuple(shape)
                    domain[qname] = quantity
                #end for
                trace[dname] = domain
            #end for
        #end if
        self[name.replace('data','traces')] = trace
    #end def init_trace


    def check_particle_sums(self,tol=1e-8,force=False):
        """
        Check that per-particle traces sum to their scalar counterparts.

        Consider every real-trace quantity that appears both in the
        'scalars' domain and in at least one other domain.  Its values in
        the other domains are summed over axis 1 and added into column 0
        of a zeroed copy of the scalar trace.  That copy is then compared
        elementwise with the scalar trace.

        Returns
        -------
        bool : True if all such sums agree to within ``tol``.  The result
               is cached in ``info.particle_sums_valid``; pass ``force`` to
               recompute it.
        """
        if not self.checked_particle_sums() or force:
            self.load()
            t = self.real_traces
            scalar_names = set(t.scalars.keys())
            other_names = set()
            for dname,domain in t.items():
                if dname!='scalars':
                    other_names.update(domain.keys())
                #end if
            #end for
            sum_names = scalar_names & other_names
            same = True
            for qname in sum_names:
                q = t.scalars[qname]
                qs = 0*q
                for dname,domain in t.items():
                    if dname!='scalars' and qname in domain:
                        tqs = domain[qname].sum(1)
                        if len(tqs.shape)==1:
                            qs[:,0] += tqs
                        else:
                            qs[:,0] += tqs[:,0]
                        #end if
                    #end if
                #end for
                same = same and (abs(q-qs)<tol).all()
            #end for
            self.info.particle_sums_valid = same
        #end if
        return self.info.particle_sums_valid
    #end def check_particle_sums


    def accumulate_scalars(self,force=False):
        """
        Form weighted per-step and per-block averages of the scalar traces.

        Results
        -------
        self.scalars_by_step : obj
            'Weight' (summed weight per step), 'NumOfWalkers' (rows per
            step), and, for each other real scalar trace q, the
            weight-averaged value sum(w*q)/sum(w) per step.
        self.scalars_by_block : obj
            'Weight' per block and the weight-averaged value of each other
            real scalar trace per block.  Blocks are consecutive runs of
            ``steps//blocks`` steps.

        If the number of blocks is unknown, both results are set to None
        and the accumulation is not marked as done.
        """
        if self.accumulated_scalars() and not force:
            return
        #end if
        import numpy as np
        # get block and step information for the qmc method
        blocks = self.info.blocks
        if blocks is None:
            self.scalars_by_step  = None
            self.scalars_by_block = None
            return
        #end if
        # load in traces data if it isn't already
        self.load()
        tr = self.real_traces
        ti = self.int_traces
        # step and weight traces
        st = ti.scalars.step
        wt = tr.scalars.weight
        if len(st)!=len(wt):
            self.error('weight and steps traces have different lengths')
        #end if
        # flatten the traces so they can serve as scatter indices/values
        step_index = np.asarray(st).ravel()
        weights    = np.asarray(wt).ravel()
        #recompute steps (can vary for vmc w/ samples/samples_per_thread)
        steps = st.max()+1
        steps_per_block = steps//blocks

        def sum_by_step(values):
            # scatter-add values into their step bins; np.add.at is
            # unbuffered and adds in trace order, like the per-row loop
            acc = zeros((steps,))
            np.add.at(acc,step_index,values)
            return acc
        #end def sum_by_step

        def sum_by_block(by_step):
            # sum consecutive runs of steps_per_block steps into blocks
            acc = zeros((blocks,))
            s = 0
            for b in range(blocks):
                acc[b] = by_step[s:s+steps_per_block].sum()
                s += steps_per_block
            #end for
            return acc
        #end def sum_by_block

        # accumulate weights and walker population
        ws = sum_by_step(weights)
        wb = sum_by_block(ws)
        ps = sum_by_step(1.0)
        scalars_by_step  = obj(Weight=ws,NumOfWalkers=ps)
        scalars_by_block = obj(Weight=wb)
        # accumulate remaining quantities into steps and blocks
        quantities = set(tr.scalars.keys())
        quantities.remove('weight')
        for qname in quantities:
            qt = tr.scalars[qname]
            if len(qt)!=len(wt):
                self.error('quantity {0} trace is not commensurate with weight and steps traces'.format(qname))
            #end if
            qs = sum_by_step(weights*np.asarray(qt).ravel())
            qb = sum_by_block(qs)
            scalars_by_step[qname]  = qs/ws
            scalars_by_block[qname] = qb/wb
        #end for
        self.scalars_by_step  = scalars_by_step
        self.scalars_by_block = scalars_by_block
        self.info.accumulated = True
    #end def accumulate_scalars


    def form_diagnostic_data(self,tol=1e-8):
        """
        Accumulate scalars and check particle sums, then unload the traces.

        Does nothing if both diagnostics have already been formed.
        """
        if not self.formed_diagnostic_data():
            self.load()
            self.accumulate_scalars()
            self.check_particle_sums(tol=tol)
            self.unload()
        #end if
    #end def form_diagnostic_data
#end class TracesFileHDF
1133
1134
1135
class TracesAnalyzer(QAanalyzer):
    """
    Analyzer for the traces hdf files written by a single QMC method.

    Each file is represented by a TracesFileHDF stored in ``self.data``.
    The check_* methods compare quantities rebuilt from the traces
    against scalars.dat, stat.h5 and dmc.dat data.
    """
    def __init__(self,path,files,nindent=0):
        QAanalyzer.__init__(self,nindent=nindent)
        self.info.path = path
        self.info.files = files
        self.method_info = QAanalyzer.method_info
        self.data = obj()
    #end def __init__


    def load_data_local(self):
        """
        Create one (not yet loaded) TracesFileHDF per traces file.

        The block count from the method input, if present, is passed on
        so each file can later accumulate its scalars by block.
        """
        method_input = self.method_info.method_input
        if 'blocks' in method_input:
            blocks = method_input.blocks
        else:
            blocks = None
        #end if
        path = self.info.path
        self.data.clear()
        for filename in sorted(self.info.files):
            filepath = os.path.join(path,filename)
            self.data.append(TracesFileHDF(filepath,blocks))
        #end for
    #end def load_data_local


    def form_diagnostic_data(self):
        """Form the diagnostic data of every traces file."""
        for trace_file in self.data:
            trace_file.form_diagnostic_data()
        #end for
    #end def form_diagnostic_data

    def analyze_local(self):
        pass
    #end def analyze_local


    def check_particle_sums(self,tol=1e-8):
        """Return True if the particle-sum check passes for every file."""
        same = True
        for trace_file in self.data:
            same &= trace_file.check_particle_sums(tol=tol)
        #end for
        return same
    #end def check_particle_sums


    def _weighted_block_means(self,qnames,shape):
        # Weight-average each per-block quantity over all traces files.
        # Each file holds block means q_f with block weights w_f, so the
        # combined mean is sum_f(w_f*q_f)/sum_f(w_f).
        summed = obj()
        for qname in qnames:
            summed[qname] = zeros(shape)
        #end for
        wtot = zeros(shape)
        for trace_file in self.data:
            by_block = trace_file.scalars_by_block
            w = by_block.Weight
            wtot += w
            for qname in qnames:
                summed[qname] += w*by_block[qname]
            #end for
        #end for
        means = obj()
        for qname in qnames:
            means[qname] = summed[qname]/wtot
        #end for
        return means
    #end def _weighted_block_means


    def check_scalars(self,scalars=None,scalars_hdf=None,tol=1e-8):
        """
        Compare per-block trace averages with scalars.dat / stat.h5 data.

        Only quantities present in both the traces and the given data are
        compared.  A result is None if the corresponding data is None.
        Otherwise it is True when every compared quantity agrees to within
        ``tol`` (vacuously True when nothing is shared).  Assumes
        form_diagnostic_data has already been run.
        """
        scalars_valid     = None if scalars is None else True
        scalars_hdf_valid = None if scalars_hdf is None else True
        if len(self.data)==0:
            return scalars_valid,scalars_hdf_valid
        #end if
        scalar_names = set(self.data[0].scalars_by_block.keys())
        if scalars is not None:
            qnames = sorted(set(scalars.keys()) & scalar_names)
            # without shared names there is nothing to compare
            if len(qnames)>0:
                shape = scalars[qnames[0]].shape
                means = self._weighted_block_means(qnames,shape)
                for qname in qnames:
                    qscalar = scalars[qname]
                    scalars_valid &= (abs(means[qname]-qscalar)<tol).all()
                #end for
            #end if
        #end if
        if scalars_hdf is not None:
            qnames = sorted(set(scalars_hdf.keys()) & scalar_names)
            if len(qnames)>0:
                shape = (len(scalars_hdf[qnames[0]].value),)
                means = self._weighted_block_means(qnames,shape)
                for qname in qnames:
                    qscalar = scalars_hdf[qname].value.ravel()
                    scalars_hdf_valid &= (abs(means[qname]-qscalar)<tol).all()
                #end for
            #end if
        #end if
        return scalars_valid,scalars_hdf_valid
    #end def check_scalars


    def check_dmc(self,dmc,tol=1e-8):
        """
        Compare per-step trace data with dmc.dat data.

        LocalEnergy is weight-averaged across files.  Weight and
        NumOfWalkers are summed across files.  Returns None if ``dmc`` is
        None, else True when all compared quantities agree within ``tol``.
        Assumes form_diagnostic_data has already been run.
        """
        if dmc is None:
            return None
        #end if
        dmc_valid = True
        if len(self.data)==0:
            return dmc_valid
        #end if
        scalar_names = set(self.data[0].scalars_by_step.keys())
        qnames = sorted(set(['LocalEnergy','Weight','NumOfWalkers']) & scalar_names)
        if len(qnames)==0:
            return dmc_valid
        #end if
        weighted = set(['LocalEnergy'])
        summed_scalars = obj()
        for qname in qnames:
            summed_scalars[qname] = zeros(dmc[qname].shape)
        #end for
        wtot = zeros(dmc[qnames[0]].shape)
        for trace_file in self.data:
            by_step = trace_file.scalars_by_step
            w = by_step.Weight
            wtot += w
            for qname in qnames:
                q = by_step[qname]
                if qname in weighted:
                    summed_scalars[qname] += w*q
                else:
                    summed_scalars[qname] += q
                #end if
            #end for
        #end for
        for qname in qnames:
            if qname in weighted:
                qb = summed_scalars[qname]/wtot
            else:
                qb = summed_scalars[qname]
            #end if
            dmc_valid &= (abs(qb-dmc[qname])<tol).all()
        #end for
        return dmc_valid
    #end def check_dmc


    def check_scalars_old(self,scalars=None,scalars_hdf=None,tol=1e-8):
        """
        Legacy single-file version of check_scalars.

        NOTE(review): this reads ``self.data.real_traces`` and
        ``self.data.int_traces``.  load_data_local now fills ``self.data``
        with TracesFileHDF objects instead, so this method appears to be
        unused; confirm before relying on it.
        """
        blocks = None
        steps_per_block = None
        steps = None
        method_input = self.method_info.method_input
        if 'blocks' in method_input:
            blocks = method_input.blocks
        #end if
        if 'steps' in method_input:
            steps_per_block = method_input.steps
        #end if
        if blocks is not None and steps_per_block is not None:
            steps = blocks*steps_per_block
        #end if
        if steps is None:
            return None,None
        #end if
        # real and int traces
        tr = self.data.real_traces
        ti = self.data.int_traces
        # names shared by traces and scalar files
        scalar_names = set(tr.scalars.keys())
        # step and weight traces
        st = ti.scalars.step
        wt = tr.scalars.weight
        if len(st)!=len(wt):
            self.error('weight and steps traces have different lengths')
        #end if
        #recompute steps (can vary for vmc w/ samples/samples_per_thread)
        steps = st.max()+1
        steps_per_block = steps//blocks
        # accumulate weights into steps and blocks
        ws   = zeros((steps,))
        qs   = zeros((steps,))
        q2s  = zeros((steps,))
        wb   = zeros((blocks,))
        qb   = zeros((blocks,))
        q2b  = zeros((blocks,))
        for t in range(len(wt)):
            ws[st[t]] += wt[t]
        #end for
        s = 0
        for b in range(blocks):
            wb[b] = ws[s:s+steps_per_block].sum()
            s+=steps_per_block
        #end for
        # check scalar.dat
        if scalars is None:
            scalars_valid = None
        else:
            dat_names = set(scalars.keys())     & scalar_names
            same = True
            for qname in dat_names:
                qt = tr.scalars[qname]
                if len(qt)!=len(wt):
                    self.error('quantity {0} trace is not commensurate with weight and steps traces'.format(qname))
                #end if
                qs[:] = 0
                for t in range(len(qt)):
                    qs[st[t]] += wt[t]*qt[t]
                #end for
                qb[:] = 0
                s=0
                for b in range(blocks):
                    qb[b] = qs[s:s+steps_per_block].sum()
                    s+=steps_per_block
                #end for
                qb = qb/wb
                qs = qs/ws
                qscalar = scalars[qname]
                qsame = (abs(qb-qscalar)<tol).all()
                same = same and qsame
            #end for
            scalars_valid = same
        #end if
        # check scalars from stat.h5
        if scalars_hdf is None:
            scalars_hdf_valid = None
        else:
            hdf_names = set(scalars_hdf.keys()) & scalar_names
            same = True
            for qname in hdf_names:
                qt = tr.scalars[qname]
                if len(qt)!=len(wt):
                    self.error('quantity {0} trace is not commensurate with weight and steps traces'.format(qname))
                #end if
                qs[:] = 0
                q2s[:] = 0
                for t in range(len(qt)):
                    s = st[t]
                    w = wt[t]
                    q = qt[t]
                    qs[s]  += w*q
                    q2s[s] += w*q*q
                #end for
                qb[:] = 0
                s=0
                for b in range(blocks):
                    qb[b]  = qs[s:s+steps_per_block].sum()
                    q2b[b] = q2s[s:s+steps_per_block].sum()
                    s+=steps_per_block
                #end for
                qb  = qb/wb
                q2b = q2b/wb
                qs  = qs/ws
                q2s = q2s/ws
                qhdf = scalars_hdf[qname]
                qscalar  = qhdf.value.ravel()
                q2scalar = qhdf.value_squared.ravel()
                qsame  = (abs(qb -qscalar )<tol).all()
                q2same = (abs(q2b-q2scalar)<tol).all()
                same = same and qsame and q2same
            #end for
            scalars_hdf_valid = same
        #end if
        return scalars_valid,scalars_hdf_valid
    #end def check_scalars_old


    def check_dmc_old(self,dmc,tol=1e-8):
        """
        Legacy single-file version of check_dmc.

        NOTE(review): this reads ``self.data.real_traces`` and
        ``self.data.int_traces``, which load_data_local no longer provides;
        it appears to be unused.
        """
        if dmc is None:
            dmc_valid = None
        else:
            #dmc data
            ene  = dmc.LocalEnergy
            wgt  = dmc.Weight
            pop  = dmc.NumOfWalkers
            # real and int traces
            tr = self.data.real_traces
            ti = self.data.int_traces
            # step, weight and energy traces
            st = ti.scalars.step
            wt = tr.scalars.weight
            et = tr.scalars.LocalEnergy
            if len(st)!=len(wt):
                self.error('weight and steps traces have different lengths')
            #end if
            #recompute steps (can vary for vmc w/ samples/samples_per_thread)
            steps = st.max()+1
            # accumulate weights into steps
            ws  = zeros((steps,))
            es  = zeros((steps,))
            ps  = zeros((steps,))
            for t in range(len(wt)):
                ws[st[t]] += wt[t]
            #end for
            for t in range(len(wt)):
                es[st[t]] += wt[t]*et[t]
            #end for
            for t in range(len(wt)):
                ps[st[t]] += 1
            #end for
            es/=ws
            psame = (abs(ps-pop)<tol).all()
            wsame = (abs(ws-wgt)<tol).all()
            esame = (abs(es-ene)<tol).all()
            dmc_valid = psame and wsame and esame
        #end if
        return dmc_valid
    #end def check_dmc_old


    #methods that do not apply
    def init_sub_analyzers(self):
        pass
    def zero_data(self):
        pass
    def minsize_data(self,other):
        pass
    def accumulate_data(self,other):
        pass
    def normalize_data(self,normalization):
        pass
#end class TracesAnalyzer
1479
1480
class DMSettings(QAobject):
    """
    Settings used by DensityMatricesAnalyzer.analyze_local.

    Defaults are assigned first; entries in ``ds`` then override them.
    Names that are not existing settings are reported via ``self.error``.
    """
    def __init__(self,ds):
        """
        Parameters
        ----------
        ds : mapping or None
            Requested settings (name -> value).  None keeps all defaults.
        """
        self.jackknife = True   # compute jackknife eigenvalue estimates
        self.diagonal  = False  # analyze only the matrix diagonals
        self.save_data = True   # keep raw trace/matrix block data in results
        self.occ_tol   = 1e-3   # total occupation allowed in dropped states
        self.coup_tol  = 1e-4   # relative off-diagonal coupling cutoff
        self.stat_tol  = 2.0    # |element|/error cutoff for significance
        if ds is not None:
            for name,value in ds.items():
                if name not in self:
                    self.error('{0} is an invalid setting for DensityMatricesAnalyzer\n  valid options are: {1}'.format(name,sorted(self.keys())))
                else:
                    self[name] = value
                #end if
            #end for
        #end if
    #end def __init__
#end class DMSettings
1500
1501
1502class DensityMatricesAnalyzer(HDFAnalyzer):
1503
1504    allowed_settings = ['save_data','jackknife','diagonal','occ_tol','coup_tol','stat_tol']
1505
1506    def __init__(self,name,nindent=0):
1507        HDFAnalyzer.__init__(self)
1508        self.info.name = name
1509    #end def __init__
1510
1511
1512    def load_data_local(self,data=None):
1513        if data==None:
1514            self.error('attempted load without data')
1515        #end if
1516        i = complex(0,1)
1517        loc_data = QAdata()
1518        name = self.info.name
1519        self.info.complex = False
1520        if name in data:
1521            matrices = data[name]
1522            del data[name]
1523            matrices._remove_hidden()
1524            for mname,matrix in matrices.items():
1525                mdata = QAdata()
1526                loc_data[mname] = mdata
1527                for species,d in matrix.items():
1528                    v = d.value
1529                    if 'value_squared' in d:
1530                        v2 = d.value_squared
1531                    #end if
1532                    if len(v.shape)==4 and v.shape[3]==2:
1533                        d.value         = v[:,:,:,0]  + i*v[:,:,:,1]
1534                        if 'value_squared' in d:
1535                            d.value_squared = v2[:,:,:,0] + i*v2[:,:,:,1]
1536                        #end if
1537                        self.info.complex = True
1538                    #end if
1539                    mdata[species] = d
1540                #end for
1541            #end for
1542        #end for
1543        self.data = loc_data
1544        self.info.should_remove = False
1545    #end def load_data_local
1546
1547
1548    def analyze_local(self):
1549        # 1) exclude states that do not contribute to the number trace
1550        # 2) exclude elements that are not statistically significant (1 sigma?)
1551        # 3) use remaining states to form filtered number and energy matrices
1552        # 4) perform jackknife sampling to get eigenvalue error bars
1553        # 5) consider using cross-correlations w/ excluded elements to reduce variance
1554
1555        ds = DMSettings(self.run_info.request.dm_settings)
1556        diagonal  = ds.diagonal
1557        jackknife = ds.jackknife and not diagonal
1558        save_data = ds.save_data
1559        occ_tol   = ds.occ_tol
1560        coup_tol  = ds.coup_tol
1561        stat_tol  = ds.stat_tol
1562
1563        nbe = QAanalyzer.method_info.nblocks_exclude
1564        self.info.nblocks_exclude = nbe
1565        has_nmat = 'number_matrix' in self.data
1566        has_emat = 'energy_matrix' in self.data
1567        species = list(self.data.number_matrix.keys())
1568        species_sizes = obj()
1569        ps = self.run_info.input.get('particleset')
1570        for s in species:
1571            species_sizes[s] = ps.e.groups[s].size
1572        #end for
1573        mnames = []
1574        if has_nmat:
1575            mnames.append('number_matrix')
1576            if has_emat:
1577                mnames.append('energy_matrix')
1578            #end if
1579        #end if
1580
1581        for species_name in species:
1582            for matrix_name in mnames:
1583                if not matrix_name in self:
1584                    self[matrix_name] = obj()
1585                #end if
1586                mres = self[matrix_name]
1587                msres = obj()
1588                mres[species_name] = msres
1589
1590                species_data = self.data[matrix_name][species_name]
1591
1592                md_all = species_data.value
1593                mdata  = md_all[nbe:,...]
1594
1595                tdata = zeros((len(md_all),))
1596                b = 0
1597                for mat in md_all:
1598                    tdata[b] = trace(mat).real # trace sums to N-elec (real)
1599                    b+=1
1600                #end for
1601                t,tvar,terr,tkap = simstats(tdata[nbe:])
1602                msres.trace        = t
1603                msres.trace_error  = terr
1604
1605                if save_data:
1606                    msres.trace_data = tdata
1607                    msres.data       = md_all
1608                #end if
1609
1610                if diagonal:
1611                    ddata = empty(mdata.shape[0:2],dtype=mdata.dtype)
1612                    b = 0
1613                    for mat in mdata:
1614                        ddata[b] = diag(mat)
1615                        b+=1
1616                    #end for
1617                    d,dvar,derr,dkap = simstats(ddata.transpose())
1618                    msres.set(
1619                        eigval  = d,
1620                        eigvec  = identity(len(d)),
1621                        eigmean = d,
1622                        eigerr  = derr
1623                        )
1624                else:
1625                    m,mvar,merr,mkap = simstats(mdata.transpose((1,2,0)))
1626
1627                    mfull  = m
1628                    mefull = merr
1629
1630                    if matrix_name=='number_matrix':
1631                        # remove states that do not have significant occupation
1632                        nspec = species_sizes[species_name]
1633                        occ = diag(m)/t*nspec
1634                        nstates = len(occ)
1635                        abs_occ = abs(occ)
1636                        abs_occ.sort()
1637                        nsum = 0
1638                        i = -1
1639                        min_occ = 0
1640                        for o in abs_occ:
1641                            if nsum+o<occ_tol:
1642                                nsum+=o
1643                                i+=1
1644                            #end if
1645                        #end if
1646                        if i!=-1:
1647                            min_occ = abs_occ[i]+1e-12
1648                        #end if
1649                        sig_states = arange(nstates)[abs(occ)>min_occ]
1650                        nsig = len(sig_states)
1651                        if nsig<nspec:
1652                            self.warn('number matrix fewer occupied states than particles')
1653                            sig_states = arange(nstates)
1654                        #end if
1655                        sig_occ = empty((nstates,nstates),dtype=bool)
1656                        sig_occ[:,:] = False
1657                        for s in sig_states:
1658                            sig_occ[s,sig_states] = True
1659                        #end for
1660                    #end if
1661                    # remove states with insignificant occupation
1662                    mos = m
1663                    m = m[sig_occ]
1664                    m.shape = nsig,nsig
1665                    merr = merr[sig_occ]
1666                    merr.shape = nsig,nsig
1667                    # remove off-diagonal elements with insignificant coupling
1668                    insig_coup = ones(m.shape,dtype=bool)
1669                    for i in range(nsig):
1670                        for j in range(nsig):
1671                            mdiag = min((abs(m[i,i]),abs(m[j,j])))
1672                            insig_coup[i,j] = abs(m[i,j])/mdiag < coup_tol
1673                        #end for
1674                    #end for
1675                    # remove elements with insignificant statistical deviation from zero
1676                    insig_stat = abs(m)/merr < stat_tol
1677                    # remove insignificant elements
1678                    insig_coup_stat = insig_coup | insig_stat
1679                    for i in range(nsig):
1680                        insig_coup_stat[i,i] = False
1681                    #end for
1682                    moi = m.copy()
1683                    m[insig_coup_stat] = 0.0
1684
1685                    # obtain standard eigenvalue estimates
1686                    eigval,eigvec = eig(m)
1687
1688                    # save common results
1689                    msres.set(
1690                        matrix            = m,
1691                        matrix_error      = merr,
1692                        sig_states        = sig_states,
1693                        sig_occ           = sig_occ,
1694                        insig_coup        = insig_coup,
1695                        insig_stat        = insig_stat,
1696                        insig_coup_stat   = insig_coup_stat,
1697                        eigval            = eigval,
1698                        eigvec            = eigvec,
1699                        matrix_full       = mfull,
1700                        matrix_error_full = mefull,
1701                        )
1702
1703                    if jackknife:
1704                        # obtain jackknife eigenvalue estimates
1705                        nblocks  = len(mdata)
1706                        mjdata   = zeros((nblocks,nsig,nsig),dtype=mdata.dtype)
1707                        eigsum   = zeros((nsig,),dtype=mdata.dtype)
1708                        eigsum2r = zeros((nsig,),dtype=mdata.dtype)
1709                        eigsum2i = zeros((nsig,),dtype=mdata.dtype)
1710                        i = complex(0,1)
1711                        nb = float(nblocks)
1712                        for b in range(nblocks):
1713                            mb = mdata[b,...][sig_occ]
1714                            mb.shape = nsig,nsig
1715                            mb[insig_coup_stat] = 0.0
1716                            mj = (nb*m-mb)/(nb-1)
1717                            mjdata[b,...] = mj
1718                            d,v = eig(mj)
1719                            eigsum   += d
1720                            eigsum2r += real(d)**2
1721                            eigsum2i += imag(d)**2
1722                        #end for
1723                        eigmean = eigsum/nb
1724                        esr = real(eigsum)
1725                        esi = imag(eigsum)
1726                        eigvar  = (nb-1)/nb*(eigsum2r+i*eigsum2i-(esr**2+i*esi**2)/nb)
1727                        eigerr  = sqrt(real(eigvar))+i*sqrt(imag(eigvar))
1728                        msres.set(
1729                            eigmean         = eigmean,
1730                            eigerr          = eigerr
1731                            )
1732
1733                        # perform generalized eigenvalue analysis for energy matrix
1734                        if matrix_name=='number_matrix':
1735                            nmjdata = mjdata
1736                            nm      = m
1737                        elif matrix_name=='energy_matrix':
1738                            # obtain general eigenvalue estimates
1739                            em = m
1740                            geigval,geigvec = eig(em,nm)
1741                            # get occupations of  eigenvectors
1742                            eigocc  = zeros((nsig,),dtype=mdata.dtype)
1743                            geigocc = zeros((nsig,),dtype=mdata.dtype)
1744                            for k in range(nsig):
1745                                v = eigvec[:,k]
1746                                eigocc[k] = dot(v.conj(),dot(nm,v))
1747                                v = geigvec[:,k]
1748                                geigocc[k] = dot(v.conj(),dot(nm,v))
1749                            #end for
1750                            # obtain jackknife estimates of generalized eigenvalues
1751                            emjdata = mjdata
1752                            eigsum[:]   = 0.0
1753                            eigsum2r[:] = 0.0
1754                            eigsum2i[:] = 0.0
1755                            for b in range(nblocks):
1756                                d,v = eig(emjdata[b,...],nmjdata[b,...])
1757                                eigsum   += d
1758                                eigsum2r += real(d)**2
1759                                eigsum2i += imag(d)**2
1760                            #end for
1761                            geigmean = eigsum/nb
1762                            esr = real(eigsum)
1763                            esi = imag(eigsum)
1764                            eigvar  = (nb-1)/nb*(eigsum2r+i*eigsum2i-(esr**2+i*esi**2)/nb)
1765                            geigerr  = sqrt(real(eigvar))+i*sqrt(imag(eigvar))
1766                            # save the results
1767                            msres.set(
1768                                eigocc   = eigocc,
1769                                geigocc  = geigocc,
1770                                geigval  = geigval,
1771                                geigvec  = geigvec,
1772                                geigmean = geigmean,
1773                                geigerr  = geigerr
1774                                )
1775                        #end if
1776                    #end if
1777                #end if
1778            #end for
1779        #end for
1780        del self.data
1781        #self.write_files()
1782    #end def analyze_local
1783
1784
1785    def analyze_local_orig(self):
1786        nbe = QAanalyzer.method_info.nblocks_exclude
1787        self.info.nblocks_exclude = nbe
1788        for matrix_name,matrix_data in self.data.items():
1789            mres = obj()
1790            self[matrix_name] = mres
1791            for species_name,species_data in matrix_data.items():
1792                md_all = species_data.value
1793                mdata  = md_all[nbe:,...]
1794                m,mvar,merr,mkap = simstats(mdata.transpose((1,2,0)))
1795
1796                tdata = zeros((len(md_all),))
1797                b = 0
1798                for mat in md_all:
1799                    tdata[b] = trace(mat)
1800                    b+=1
1801                #end for
1802                t,tvar,terr,tkap = simstats(tdata[nbe:])
1803
1804                try:
1805                    val,vec = eig(m)
1806                except LinAlgError as e:
1807                    self.warn(matrix_name+' diagonalization failed!')
1808                    val,vec = None,None
1809                #end try
1810
1811                mres[species_name] = obj(
1812                    matrix       = m,
1813                    matrix_error = merr,
1814                    eigenvalues  = val,
1815                    eigenvectors = vec,
1816                    trace        = t,
1817                    trace_error  = terr,
1818                    trace_data   = tdata,
1819                    data         = md_all
1820                    )
1821            #end for
1822        #end for
1823        if self.has_energy_matrix():
1824            nmat = self.number_matrix
1825            emat = self.energy_matrix
1826            for s,es in emat.items():
1827                ns = nmat[s]
1828                nm = ns.matrix
1829                em = es.matrix
1830                try:
1831                    val,vec = eig(em,nm)
1832                except LinAlgError:
1833                    self.warn('energy matrix generalized diagonalization failed!')
1834                    val,vec = None,None
1835                #end try
1836                size = len(vec)
1837                occ = zeros((size,),dtype=nm.dtype)
1838                for i in range(size):
1839                    v = vec[:,i]
1840                    occ[i] = dot(v.conj(),dot(nm,v))
1841                #end for
1842                es.set(
1843                    energies       = val,
1844                    occupations    = occ,
1845                    energy_vectors = vec
1846                    )
1847            #end for
1848        #end if
1849        del self.data
1850        #self.write_files()
1851        ci(ls(),gs())
1852    #end def analyze_local_orig
1853
1854
1855    def has_energy_matrix(self):
1856        return 'energy_matrix' in self
1857    #end def has_energy_matrix
1858
1859    def write_files(self,path='./'):
1860        prefix = self.method_info.file_prefix
1861        nm = self.number_matrix
1862        for gname,g in nm.items():
1863            filename =  '{0}.dm1b_{1}.dat'.format(prefix,gname)
1864            filepath = os.path.join(path,filename)
1865            mean  = g.matrix.ravel()
1866            error = g.matrix_error.ravel()
1867            if not self.info.complex:
1868                savetxt(filepath,concatenate((mean,error)))
1869            else:
1870                savetxt(filepath,concatenate((real(mean ),imag(mean ),
1871                                              real(error),imag(error))))
1872            #end if
1873        #end for
1874    #end def write_files
1875#end class DensityMatricesAnalyzer
1876
1877
1878
class DensityAnalyzerBase(HDFAnalyzer):
    """
    Common base for analyzers of gridded densities read from stat.h5.

    Records the run metadata (structure, file prefix, source path, series and
    the matching xml input block, if any) and provides a helper that writes a
    density together with its +/- error-bar variants as XSF files.
    Subclasses must implement write_density.
    """
    def __init__(self,name,nindent=0):
        # nindent is accepted for interface compatibility and is not used
        HDFAnalyzer.__init__(self)
        run_info = self.run_info
        self.info.set(
            name        = name,
            structure   = run_info.system.structure,
            file_prefix = run_info.file_prefix,
            source_path = run_info.source_path,
            series      = self.method_info.series
            )
        # The xml input block is optional: any failure to fetch it leaves
        # xml as None.  The bare except is kept as-is; note that it also
        # catches SystemExit/KeyboardInterrupt.
        try:
            self.info.xml  = self.run_info.input.get(self.info.name)
        except:
            self.info.xml = None
        #end try
    #end def __init__


    def write_single_density(self,name,density,density_err,format='xsf'):
        """
        Write `density`, `density+density_err` and `density-density_err` to
        '<prefix>.xsf', '<prefix>+err.xsf' and '<prefix>-err.xsf' in the
        run's source path, where prefix is '<file_prefix>.sNNN.<name>'.

        Only format='xsf' is supported; anything else is reported via
        self.error.
        """
        if format!='xsf':
            self.error('sorry, the density can only be written in xsf format for now\n  you requested: {0}'.format(format))
        #end if

        structure = self.info.structure.copy()
        flat_pos  = structure.pos.ravel()
        # positions lying strictly inside (0,1) are treated as reduced
        # coordinates and converted to cartesian
        if flat_pos.min()>0 and flat_pos.max()<1.0:
            structure.pos_to_cartesian()
        #end if
        structure.change_units('A')
        cell = structure.axes

        xsf = XsfFile()
        xsf.incorporate_structure(structure)

        series_tag = str(self.info.series).zfill(3)
        prefix     = '{0}.s{1}.{2}'.format(self.info.file_prefix,series_tag,name)
        outdir     = self.info.source_path

        print('writing to ',outdir,prefix)

        def emit(values,suffix):
            # cell-centered data with ghost points added (flags = 1)
            xsf.add_density(cell,values,centered=1,add_ghost=1)
            xsf.write(os.path.join(outdir,prefix+suffix))
        #end def emit

        emit(density            ,'.xsf'    )   # mean
        emit(density+density_err,'+err.xsf')   # mean + errorbar
        emit(density-density_err,'-err.xsf')   # mean - errorbar
    #end def write_single_density


    def write_density(self,format='xsf'):
        """Write this analyzer's densities; must be provided by subclasses."""
        self.not_implemented()
    #end def write_density
#end class DensityAnalyzerBase
1939
1940
1941
class SpinDensityAnalyzer(DensityAnalyzerBase):
    """
    Analyzer for the spin density estimator stored in stat.h5.

    Every group of the estimator (write_density expects 'u' and 'd') holds
    per-block values on a 3D grid.  This class reshapes them onto that grid,
    reduces them to block means/errors, and writes them as text or XSF files.
    """
    def load_data_local(self,data=None):
        """
        Move this estimator's group out of `data` into self.data, then
        reshape every group's per-block `value` (and `value_squared`, if
        present) to (nblocks,g[0],g[1],g[2]).

        The grid g is read from the xml input: either given directly as
        'grid', or derived from the cell axis lengths divided by the
        spacing 'dr' (rounded up).  If the estimator is absent from `data`,
        the analyzer is flagged for removal and nothing else is done.
        """
        if data is None:
            self.error('attempted load without data')
        #end if
        name = self.info.name
        if name not in data:
            # nothing was loaded, so there is no grid to apply
            self.info.should_remove = True
            return
        #end if
        hdata = data[name]
        hdata._remove_hidden()
        self.data = QAHDFdata()
        self.data.transfer_from(hdata)
        del data[name]

        xml = self.info.xml
        if xml is None:
            self.error('cannot determine the spin density grid: no xml input found for {0}'.format(name))
        #end if
        if 'grid' in xml:
            g = xml.grid
        else:
            dr   = xml.dr
            axes = self.info.structure.axes
            g = array([ceil(sqrt(axes[d].dot(axes[d]))/dr[d]) for d in range(3)],dtype=int)
        #end if

        for d in self.data:
            nb = len(d.value)
            d.value.shape = (nb,g[0],g[1],g[2])
            if 'value_squared' in d:
                d.value_squared.shape = (nb,g[0],g[1],g[2])
            #end if
        #end for
    #end def load_data_local


    def analyze_local(self):
        """
        Store, for every group, an obj with the block `mean` and `error`
        (blocks along axis 0; the first nblocks_exclude blocks are dropped).
        """
        nbe = QAanalyzer.method_info.nblocks_exclude
        for group,data in self.data.items():
            gdata = data.value[nbe:,...]
            g = obj()
            g.mean,g.error = simplestats(gdata,dim=0)
            self[group] = g
        #end for
        self.info.nblocks_exclude = nbe
    #end def analyze_local


    def write_files(self,path='./'):
        """
        Write each group to '<prefix>.spindensity_<group>.dat' in `path` as
        the flattened mean followed by the flattened error.
        """
        prefix = self.method_info.file_prefix
        for gname in self.data.keys():
            filename =  '{0}.spindensity_{1}.dat'.format(prefix,gname)
            filepath = os.path.join(path,filename)
            mean  = self[gname].mean.ravel()
            error = self[gname].error.ravel()
            savetxt(filepath,concatenate((mean,error)))
        #end for
    #end def write_files


    def write_density(self,format='xsf'):
        """
        Write the up, down, total (u+d) and polarization (u-d) densities.

        The u+d and u-d statistics are computed from the per-block sums and
        differences, so correlations between the channels are accounted for.
        """
        nbe = self.info.nblocks_exclude
        umean = self.u.mean
        uerr  = self.u.error
        dmean = self.d.mean
        derr  = self.d.error

        upd_data = self.data.u.value + self.data.d.value
        umd_data = self.data.u.value - self.data.d.value

        upd_mean,upd_err = simplestats(upd_data[nbe:,...],dim=0)
        umd_mean,umd_err = simplestats(umd_data[nbe:,...],dim=0)

        self.write_single_density('spindensity_u'  ,umean   ,uerr   ,format)
        self.write_single_density('spindensity_d'  ,dmean   ,derr   ,format)
        self.write_single_density('spindensity_u+d',upd_mean,upd_err,format)
        self.write_single_density('spindensity_u-d',umd_mean,umd_err,format)
    #end def write_density
#end class SpinDensityAnalyzer
2022
2023
2024
2025
class StructureFactorAnalyzer(HDFAnalyzer):
    """
    Analyzer for structure factor data stored in stat.h5.

    Every HDF group of the estimator is reduced to a block mean and error
    (after dropping the first nblocks_exclude blocks), and can be written
    to text files.
    """
    def __init__(self,name,nindent=0):
        # nindent is accepted for interface compatibility and is not used
        HDFAnalyzer.__init__(self)
        self.info.name = name
    #end def __init__


    def load_data_local(self,data=None):
        """
        Move this estimator's group out of `data` into self.data.

        If the estimator is absent, the analyzer is flagged for removal.
        """
        if data is None:
            self.error('attempted load without data')
        #end if
        name = self.info.name
        if name in data:
            hdata = data[name]
            hdata._remove_hidden()
            self.data = QAHDFdata()
            self.data.transfer_from(hdata)
            del data[name]
        else:
            self.info.should_remove = True
        #end if
    #end def load_data_local


    def analyze_local(self):
        """
        Store, for every group, an obj with the block `mean` and `error`
        (blocks along axis 0; the first nblocks_exclude blocks are dropped).
        """
        nbe = QAanalyzer.method_info.nblocks_exclude
        for group,data in self.data.items():
            gdata = data.value[nbe:,...]
            g = obj()
            g.mean,g.error = simplestats(gdata,dim=0)
            self[group] = g
        #end for
        self.info.nblocks_exclude = nbe
    #end def analyze_local


    def write_files(self,path='./'):
        """
        Write each group to '<prefix>.structurefactor_<group>.dat' in `path`
        as the flattened mean followed by the flattened error.
        """
        print('  sf write files')
        prefix = self.method_info.file_prefix
        for gname in self.data.keys():
            filename =  '{0}.structurefactor_{1}.dat'.format(prefix,gname)
            filepath = os.path.join(path,filename)
            mean  = self[gname].mean.ravel()
            error = self[gname].error.ravel()
            savetxt(filepath,concatenate((mean,error)))
        #end for
    #end def write_files
#end class StructureFactorAnalyzer
2076
2077
2078
2079
2080
class DensityAnalyzer(DensityAnalyzerBase):
    """
    Analyzer for the total density estimator stored in stat.h5.

    The per-block density values are reduced to a block mean and error,
    which can be written as XSF files.
    """

    def load_data_local(self,data=None):
        """
        Move this estimator's group out of `data` into self.data.

        If the estimator is absent, the analyzer is flagged for removal.
        """
        if data is None:
            self.error('attempted load without data')
        #end if
        name = self.info.name
        if name in data:
            hdata = data[name]
            hdata._remove_hidden()
            self.data = QAHDFdata()
            self.data.transfer_from(hdata)
            del data[name]
        else:
            self.info.should_remove = True
        #end if
    #end def load_data_local


    def analyze_local(self):
        """
        Set self.mean and self.error from the per-block values (blocks
        along axis 0; the first nblocks_exclude blocks are dropped).
        """
        nbe = QAanalyzer.method_info.nblocks_exclude
        self.mean,self.error = simplestats(self.data.value[nbe:,...],dim=0)
        self.info.nblocks_exclude = nbe
    #end def analyze_local


    def write_density(self,format='xsf'):
        """Write the mean density and its +/- error variants."""
        self.write_single_density('density',self.mean,self.error,format)
    #end def write_density
#end class DensityAnalyzer
2111
2112
2113
2114
2115
2116
2117
2118# spacegrid code
2119
2120import re
2121import copy
2122from numpy import array,floor,sqrt,zeros,prod,dot,ones,empty,min,max
2123from numpy import pi,sin,cos,arccos as acos,arctan2 as atan2
2124from numpy.linalg import inv,det
2125from numerics import simplestats,ndgrid,ogrid,arange,simstats
2126from hdfreader import HDFgroup
2127
#simple constants
# 1/(2*pi), precomputed once at import time
o2pi = 1./(2.*pi)
2130
2131#simple functions
def is_integer(i):
    """
    Return True if `i` lies within 1e-6 of an integer.

    The nearest integer is taken as floor(i+0.5), so values slightly below
    an integer (e.g. 2.9999999 from floating-point round-off) are accepted
    just like values slightly above it.  With numpy's floor this also works
    elementwise on arrays.
    """
    return abs(floor(i+0.5)-i)<1e-6
#end def is_integer
2135
2136
class SpaceGridInitializer(QAobject):
    """
    Container for the settings needed to construct a space grid.

    Attributes start out as None; subclasses add their own (e.g. the
    origin/axes arrays of RectilinearGridInitializer).  check_complete
    verifies that every attribute has been given a value.
    """
    def __init__(self):
        self.coord              = None # string
        return
    #end def __init__

    def check_complete(self,exit_on_fail=True):
        """
        Return True if no attribute is None.

        The test is an identity check (`is None`) because attributes may
        hold numpy arrays, for which `== None` is elementwise and cannot be
        used as an if-condition.

        With exit_on_fail, each missing attribute is reported via
        self.error(...,exit=False), followed by a final self.error call
        without exit=False.
        """
        succeeded = True
        for k,v in self.items():
            if v is None:
                succeeded=False
                if exit_on_fail:
                    self.error('  SpaceGridInitializer.'+k+' must be provided',exit=False)
                #end if
            #end if
        #end for
        if not succeeded and exit_on_fail:
            self.error('  SpaceGridInitializer is incomplete')
        #end if
        return succeeded
    #end def check_complete
#end class SpaceGridInitializer
2159
2160
class SpaceGridBase(QAobject):
    """
    Base class for space grids holding QMCPACK energy-density data.

    Each grid stores, per spatial domain, the quantities
      D : number density
      T : kinetic energy density
      V : potential energy density
      E : energy density, T+V
      P : local pressure, (Volume)*P=(2*T+V)/3
    as objects with `mean` and `error` arrays.  Construction dispatches on
    the class name of the initializing object (an initializer, a grid of
    the same class, an HDFgroup or an XMLelement) to the init_* hooks,
    which subclasses override.
    """
    cnames=['cartesian','cylindrical','spherical','voronoi']
    coord_s2n = dict()
    coord_n2s = dict()
    # NOTE: this class-body loop leaves `i` and `name` behind as class
    # attributes (holding their final values, 3 and 'voronoi').
    for i,name in enumerate(cnames):
        coord_s2n[name]=i
        coord_n2s[i]=name
    #end for

    cartesian   = coord_s2n['cartesian']
    cylindrical = coord_s2n['cylindrical']
    spherical   = coord_s2n['spherical']
    voronoi     = coord_s2n['voronoi']

    # integer codes for axis labels, and the mappings to/from label strings
    xlabel = 0
    ylabel = 1
    zlabel = 2
    rlabel = 3
    plabel = 4
    tlabel = 5
    axlabel_s2n = {'x':xlabel,'y':ylabel,'z':zlabel,'r':rlabel,'phi':plabel,'theta':tlabel}
    axlabel_n2s = {xlabel:'x',ylabel:'y',zlabel:'z',rlabel:'r',plabel:'phi',tlabel:'theta'}

    # position of each axis label within its coordinate triple
    axindex = {'x':0,'y':1,'z':2,'r':0,'phi':1,'theta':2}

    quantities=['D','T','V','E','P']

    def __init__(self,initobj,options):
        """
        Build a space grid from `initobj`.

        options, if not None, may provide `points`, `exit_on_fail` and
        `nblocks_exclude`; missing ones default to None, True and 0.
        initobj is dispatched on its class name: '<ClassName>Initializer',
        '<ClassName>' (copy), 'HDFgroup' or 'XMLelement'; any other type is
        an error.  With initobj None, only the empty attributes are created.
        """
        if options is None:
            options = QAobject()
            options.wasNone = True
            options.points       = None
            options.exit_on_fail = True
            options.nblocks_exclude = 0
        else:
            if 'points' not in options:
                options.points = None
            if 'exit_on_fail' not in options:
                options.exit_on_fail = True
            if 'nblocks_exclude' not in options:
                options.nblocks_exclude = 0
        #end if

        self.points          = options.points
        self.init_exit_fail  = options.exit_on_fail
        self.nblocks_exclude = options.nblocks_exclude
        self.keep_data = True
        # temporaries removed once initialization succeeds
        delvars = ['init_exit_fail','keep_data']

        self.coord          = None # string
        self.coordinate     = None
        self.ndomains       = None
        self.domain_volumes = None
        self.domain_centers = None
        self.nvalues_per_domain = -1
        self.nblocks            = -1
        self.D  = QAobject() #Number Density
        self.T  = QAobject() #Kinetic Energy Density
        self.V  = QAobject() #Potential Energy Density
        self.E  = QAobject() #Energy Density, T+V
        self.P  = QAobject() #Local Pressure, (Volume)*P=(2*T+V)/3


        self.init_special()

        if initobj is None:
            return
        #end if

        self.DIM=3

        iname = initobj.__class__.__name__
        self.iname=iname
        if iname==self.__class__.__name__+'Initializer':
            self.init_from_initializer(initobj)
        elif iname==self.__class__.__name__:
            self.init_from_spacegrid(initobj)
        elif iname=='HDFgroup':
            self.init_from_hdfgroup(initobj)
        elif iname=='XMLelement':
            self.init_from_xmlelement(initobj)
        else:
            self.error('Spacegrid cannot be initialized from '+iname)
        #end if
        delvars.append('iname')

        self.check_complete()

        for dv in delvars:
            del self[dv]
        #end for

        self._reset_dynamic_methods()
        self._register_dynamic_methods()
        return
    #end def __init__

    def copy(self,other):
        None
    #end def copy

    def init_special(self):
        """Hook for subclasses to declare their own attributes (none here)."""
        None
    #end def init_special

    def init_from_initializer(self,init):
        """Hook: initialize from a matching initializer (no-op here)."""
        None
    #end def init_from_initializer

    def init_from_spacegrid(self,init):
        """Hook: copy-construct from a grid of the same class (no-op here)."""
        None
    #end def init_from_spacegrid

    def init_from_hdfgroup(self,init):
        """
        Initialize from a stat.h5 HDFgroup.

        All non-hidden datasets other than value* and gmap* are copied as
        attributes (size-1 arrays become scalars, Nx1/1xN become 1D).  The
        raw `value` data is reshaped to (nquant,ndomains,nblocks), or to
        (nquant,npvalues,ndomains,nblocks) for a chemical-potential grid,
        the first nblocks_exclude blocks are dropped, and mean/error are
        computed for D, T, V (assumed to be the first three values per
        domain, in that order) as well as E=T+V and P=(2T+V)/3.
        """
        #copy all datasets from hdf group
        value_pattern = re.compile('value')
        gmap_pattern = re.compile(r'gmap\d*')
        for k,v in init.items():
            exclude = k[0]=='_' or gmap_pattern.match(k) or value_pattern.match(k)
            if not exclude:
                self[k]=v
            #end if
        #end for

        #convert 1x and 1x1 numpy arrays to just numbers
        #convert Nx1 and 1xN numpy arrays to Nx arrays
        array_type = type(array([]))
        exclude = set(['value','value_squared'])
        for k,v in self.items():
            if k[0]!='_' and type(v)==array_type and k not in exclude:
                sh=v.shape
                ndim = len(sh)
                if ndim==1 and sh[0]==1:
                    self[k]=v[0]
                elif ndim==2:
                    if sh[0]==1 and sh[1]==1:
                        self[k]=v[0,0]
                    elif sh[0]==1 or sh[1]==1:
                        self[k]=v.reshape((sh[0]*sh[1],))
                    #end if
                #end if
            #end if
        #end for

        #set coord string
        self.coord = SpaceGridBase.coord_n2s[self.coordinate]

        #determine if chempot grid
        chempot = 'min_part' in init
        self.chempot = chempot
        if chempot:
            npvalues = self.max_part-self.min_part+1
            self.npvalues = npvalues
        #end if

        #process the data in hdf value,value_squared
        nbe = self.nblocks_exclude
        nquant   = self.nvalues_per_domain
        ndomains = self.ndomains
        nblocks,ntmp = init.value.shape
        self.nblocks = nblocks

        # put the block index last for simstats
        if not chempot:
            value = init.value.reshape(nblocks,ndomains,nquant).transpose(2,1,0)
        else:
            value = init.value.reshape(nblocks,ndomains,npvalues,nquant).transpose(3,2,1,0)
        #end if
        value = value[...,nbe:]

        (mean,var,error,kappa)=simstats(value)
        quants = ['D','T','V']
        iD = -1
        iT = -1
        iV = -1
        for i in range(len(quants)):
            q=quants[i]
            self[q].mean  =  mean[i,...]
            self[q].error = error[i,...]
            if q=='D':
                iD = i
            elif q=='T':
                iT = i
            elif q=='V':
                iV = i
            else:
                self.error('quantity "{}" not recognized'.format(q))
            #end if
        #end for

        # E and P are formed block by block so their errors include
        # the T/V covariance
        E = value[iT,...]+value[iV,...]
        (mean,var,error,kappa)=simstats(E)
        self.E.mean  =  mean
        self.E.error = error

        P = 2./3.*value[iT,...]+1./3.*value[iV,...]
        (mean,var,error,kappa)=simstats(P)
        self.P.mean  =  mean
        self.P.error = error


        #convert all quantities into true densities
        # NOTE(review): errors are scaled by 1/sqrt(volume), not 1/volume.
        # integrate() relies on this (error**2*dv recovers the squared raw
        # per-domain error), so the two must be changed together.
        ovol = 1./self.domain_volumes
        sqovol = sqrt(ovol)
        for q in SpaceGridBase.quantities:
            self[q].mean  *= ovol
            self[q].error *= sqovol
        #end for

        #keep original data, if requested
        if self.keep_data:
            self.data = QAobject()
            for i in range(len(quants)):
                q=quants[i]
                self.data[q] = value[i,...]
            #end for
            self.data.E = E
            self.data.P = P
        #end if

        return
    #end def init_from_hdfgroup

    def init_from_xmlelement(self,init):
        """Hook: initialize from an XMLelement (no-op here)."""
        None
    #end def init_from_xmlelement

    def check_complete(self,exit_on_fail=True):
        """
        Return True if no public attribute is None.

        Missing attributes are reported via self.error(...,exit=False); if
        any are missing and exit_on_fail is set, exit() is called.
        """
        succeeded = True
        for k,v in self.items():
            if k[0]!='_' and v is None:
                succeeded=False
                if exit_on_fail:
                    self.error('SpaceGridBase.'+k+' must be provided',exit=False)
                #end if
            #end if
        #end for
        if not succeeded:
            self.error('SpaceGrid attempted initialization from '+self.iname,exit=False)
            self.error('SpaceGrid is incomplete',exit=False)
            if exit_on_fail:
                exit()
            #end if
        #end if
        return succeeded
    #end def check_complete

    def _reset_dynamic_methods(self):
        """Hook: bind coordinate-specific methods (no-op here)."""
        None
    #end def _reset_dynamic_methods

    def _unset_dynamic_methods(self):
        """Hook: clear coordinate-specific method bindings (no-op here)."""
        None
    #end def _unset_dynamic_methods

    def add_all_attributes(self,o):
        """Deep-copy every public attribute of `o` onto this grid."""
        for k,v in o.__dict__.items():
            if not k.startswith('_'):
                vc = copy.deepcopy(v)
                self._add_attribute(k,vc)
            #end if
        #end for
        return
    #end def add_all_attributes


    def reorder_atomic_data(self,imap):
        None
    #end def reorder_atomic_data


    def integrate(self,quantity,domain=None):
        """
        Integrate `quantity` (one of SpaceGridBase.quantities) over the whole
        grid, or over `domain` if given.

        domain may be anything that indexes the per-domain arrays (an int,
        slice, index array or boolean mask).  Returns (mean,error); the
        error combines the per-domain errors in quadrature, weighted by the
        domain volumes.
        """
        if quantity not in SpaceGridBase.quantities:
            msg = 'requested integration of quantity '+quantity+'\n'
            msg +='  '+quantity+' is not a valid SpaceGrid quantity\n'
            msg +='  valid quantities are:\n'
            msg +='  '+str(SpaceGridBase.quantities)
            self.error(msg)
        #end if
        dv = self.domain_volumes
        # `is None` is required: `domain==None` on a numpy index array is
        # elementwise and cannot be used as an if-condition
        if domain is None:
            mean = (self[quantity].mean*dv).sum()
            error = sqrt((self[quantity].error**2*dv).sum())
        else:
            mean = (self[quantity].mean[domain]*dv[domain]).sum()
            error = sqrt((self[quantity].error[domain]**2*dv[domain]).sum())
        #end if
        return mean,error
    #end def integrate

    def integrate_data(self,quantity,*domains,**kwargs):
        """
        Integrate the kept per-block data of `quantity` and return block
        statistics.

        With no domains, the whole grid is summed and (mean,error) is
        returned; a single domain also gives (mean,error).  With two or
        more domains, a tuple of objs (mean, error, data) is returned, or
        (means,errors) lists if return_list is set.  Passing domains=...
        as a keyword sets return_list to True.
        """
        return_list = False
        if 'domains' in kwargs:
            domains = kwargs['domains']
            return_list = True
        #end if
        if 'return_list' in kwargs:
            return_list = kwargs['return_list']
        #end if
        if quantity not in SpaceGridBase.quantities:
            msg = 'requested integration of quantity '+quantity+'\n'
            msg +='  '+quantity+' is not a valid SpaceGrid quantity\n'
            msg +='  valid quantities are:\n'
            msg +='  '+str(SpaceGridBase.quantities)
            self.error(msg)
        #end if
        q = self.data[quantity]
        results = list()
        nblocks = q.shape[-1]
        qi = zeros((nblocks,))
        if len(domains)==0:
            for b in range(nblocks):
                qi[b] = q[...,b].sum()
            #end for
            (mean,var,error,kappa)=simstats(qi)
        else:
            for domain in domains:
                for b in range(nblocks):
                    qb = q[...,b]
                    qi[b] = qb[domain].sum()
                #end for
                (mean,var,error,kappa)=simstats(qi)
                res = QAobject()
                res.mean  = mean
                res.error = error
                res.data  = qi.copy()
                results.append(res)
            #end for
        #end if
        if len(domains)<2:
            return mean,error
        else:
            if not return_list:
                return tuple(results)
            else:
                means = list()
                errors = list()
                for res in results:
                    means.append(res.mean)
                    errors.append(res.error)
                #end for
                return means,errors
            #end if
        #end if
    #end def integrate_data

#end class SpaceGridBase
2507
2508
2509
2510
class RectilinearGridInitializer(SpaceGridInitializer):
    """
    Initializer for RectilinearGrid.  Besides `coord`, it requires:
      origin  - 3x1 array
      axes    - 3x3 array
      axlabel - 3x1 string list
      axgrid  - 3x1 string list
    All start as None.
    """
    def __init__(self):
        SpaceGridInitializer.__init__(self)
        for attr in ('origin','axes','axlabel','axgrid'):
            setattr(self,attr,None)
        #end for
    #end def __init__
#end class RectilinearGridInitializer
2520
2521
class RectilinearGrid(SpaceGridBase):
    """
    Space grid whose domains are a tensor product of 1D intervals along
    three axes, in cartesian, cylindrical, or spherical unit coordinates.

    Geometry is built by `initialize` from coord/origin/axes/axlabel/axgrid,
    or loaded from an HDF group via `init_from_hdfgroup`.
    """
    def __init__(self,initobj=None,options=None):
        SpaceGridBase.__init__(self,initobj,options)
        return
    #end def __init__

    def init_special(self):
        """Reset the rectilinear-grid specific attributes to None."""
        self.origin         = None # 3x1 array
        self.axes           = None # 3x3 array
        self.axlabel        = None # 3x1 string list
        self.axinv          = None
        self.volume         = None
        self.dimensions     = None
        self.gmap           = None
        self.umin           = None
        self.umax           = None
        self.odu            = None
        self.dm             = None
        self.domain_uwidths = None
        return
    #end def init_special

    def copy(self):
        """Return a new RectilinearGrid constructed with this grid as its init object."""
        return RectilinearGrid(self)
    #end def copy

    def _reset_dynamic_methods(self):
        """Bind points2domains/point2unit to the variants for self.coordinate."""
        p2d=[self.points2domains_cartesian,   \
             self.points2domains_cylindrical, \
             self.points2domains_spherical]
        self.points2domains = p2d[self.coordinate]

        p2u=[self.point2unit_cartesian,   \
             self.point2unit_cylindrical, \
             self.point2unit_spherical]
        self.point2unit = p2u[self.coordinate]
        return
    #end def _reset_dynamic_methods

    def _unset_dynamic_methods(self):
        """Drop the bound coordinate-specific methods."""
        self.points2domains = None
        self.point2unit     = None
        return
    #end def _unset_dynamic_methods

    def init_from_initializer(self,init):
        """Copy all public fields of a complete initializer, then build the grid."""
        init.check_complete()
        for k,v in init.items():
            if k[0]!='_':
                self[k]=v
            #end if
        #end for
        self.initialize()
        return
    #end def init_from_initializer

    def init_from_spacegrid(self,init):
        """
        Copy state from another space grid.

        Quantity means/errors and ndarrays are deep-copied, HDFgroup values
        are shared, the dynamic-method entries and 'points' are not copied
        by value (points is shared by reference at the end), and any other
        public value is rebuilt through its own type constructor.
        """
        quantities = SpaceGridBase.quantities
        for q in quantities:
            self[q].mean  = init[q].mean.copy()
            self[q].error = init[q].error.copy()
        #end for
        array_type = type(array([1]))
        exclude = set(['point2unit','points2domains','points'])
        for k,v in init.items():
            if k[0]=='_' or k in quantities:
                continue # private, or already copied above
            #end if
            vtype = type(v)
            if vtype==array_type:
                self[k] = v.copy()
            elif vtype==HDFgroup:
                self[k] = v
            elif k in exclude:
                pass
            else:
                self[k] = vtype(v)
            #end if
        #end for
        self.points = init.points
        return
    #end def init_from_spacegrid

    def init_from_hdfgroup(self,init):
        """Load grid data from an HDF group, converting axis type codes to labels."""
        SpaceGridBase.init_from_hdfgroup(self,init)
        #set axlabel strings
        self.axlabel = [SpaceGridBase.axlabel_n2s[self.axtypes[d]] for d in range(self.DIM)]
        del self.axtypes
        # flatten each axis map to 1D
        self.gmap = [g.reshape((len(g),)) for g in (init.gmap1,init.gmap2,init.gmap3)]
        return
    #end def init_from_hdfgroup


    def init_from_xmlelement(self,init):
        """
        Build the grid from a parsed <spacegrid> XML element.

        The origin is p1+fraction*(p2-p1). Each axis d is
        scale*(p1-p2), taken from the element's axis1/axis2/axis3 children.
        p2 defaults to the 'zero' point, fraction to 0, and scale to 1.
        """
        DIM=self.DIM
        self.axlabel=list()
        self.axgrid =list()
        #coord
        self.coord = init.coord
        #origin
        p1 = self.points[init.origin.p1]
        if 'p2' in init.origin:
            p2 = self.points[init.origin.p2]
        else:
            p2 = self.points['zero']
        #end if
        # NOTE(review): eval() of XML-supplied text; only safe for trusted input
        if 'fraction' in init.origin:
            frac = eval(init.origin.fraction)
        else:
            frac = 0.0
        #end if
        self.origin = p1 + frac*(p2-p1)
        #axes
        self.axes = zeros((DIM,DIM))
        for d in range(DIM):
            # equivalent of the former exec('axis=init.axis'+str(d+1))
            axis = getattr(init,'axis'+str(d+1))
            p1 = self.points[axis.p1]
            if 'p2' in axis:
                p2 = self.points[axis.p2]
            else:
                p2 = self.points['zero']
            #end if
            if 'scale' in axis:
                scale = eval(axis.scale) # NOTE(review): eval of XML text
            else:
                scale = 1.0
            #end if
            for dd in range(DIM):
                self.axes[dd,d] = scale*(p1[dd]-p2[dd])
            #end for
            self.axlabel.append(axis.label)
            self.axgrid.append(axis.grid)
        #end for
        self.initialize()
        return
    #end def init_from_xmlelement

    def initialize(self): #like qmcpack SpaceGridBase.initialize
        """
        Build grid geometry from coord, origin, axes, axlabel, and axgrid.

        self.axgrid is consumed (deleted). Each axis grid string has the
        form "u0 (spacer) u1 (spacer) u2 ...". An integer spacer gives a
        domain count and a decimal spacer gives a domain width.

        The method sets axinv, volume, gmap, umin, umax, odu, dm,
        dimensions, ndomains, domain_volumes, domain_centers, and
        domain_uwidths, and zeroes every quantity's mean/error.

        Returns True when all checks pass. On failure it calls self.error
        when self.init_exit_fail is set, and returns False.
        """
        write=False
        succeeded=True

        DIM = self.DIM

        coord   = self.coord
        origin  = self.origin
        axes    = self.axes
        axlabel = self.axlabel
        axgrid  = self.axgrid
        del self.axgrid

        ax_cartesian   = ["x" , "y"   , "z"    ]
        ax_cylindrical = ["r" , "phi" , "z"    ]
        ax_spherical   = ["r" , "phi" , "theta"]
        coord_axes = dict(cartesian   = ax_cartesian,
                          cylindrical = ax_cylindrical,
                          spherical   = ax_spherical)

        if coord not in coord_axes:
            self.error("  Coordinate supplied to spacegrid must be cartesian, cylindrical, or spherical\n  You provided "+str(coord),exit=False)
            # nothing below can proceed without a valid coordinate system
            if self.init_exit_fail:
                self.error(" in def initialize")
            #end if
            return False
        #end if
        cmap = dict()
        for d in range(DIM):
            cmap[coord_axes[coord][d]]=d
            axlabel[d]=coord_axes[coord][d]
        #end for
        self.coordinate = SpaceGridBase.coord_s2n[self.coord]
        coordinate = self.coordinate

        # variables for loop over axes
        utol = 1e-5
        dimensions=zeros((DIM,),dtype=int)
        umin=zeros((DIM,))
        umax=zeros((DIM,))
        odu=zeros((DIM,))
        ndu_per_interval=[None]*DIM
        gmap=[None]*DIM
        for dd in range(DIM):
            iaxis = cmap[axlabel[dd]]
            grid = axgrid[dd]
            #read in the grid contents
            #  remove spaces inside of parentheses so each spacer is one token
            inparen=False
            gtmp=''
            for gc in grid:
                if(gc=='('):
                    inparen=True
                    gtmp+=' '
                #end if
                if(not(inparen and gc==' ')):
                    gtmp+=gc
                #end if
                if(gc==')'):
                    inparen=False
                    gtmp+=' '
                #end if
            #end for
            grid=gtmp
            #  break into tokens
            tokens = grid.split()
            if(write):
                print("      grid   = ",grid)
                print("      tokens = ",tokens)
            #end if
            #  count the number of intervals (endpoints minus one)
            nintervals=0
            for t in tokens:
                if t[0]!='(':
                    nintervals+=1
                #end if
            #end for
            nintervals-=1
            if(write):
                print("      nintervals = ",nintervals)
            #end if
            #  allocate temporary interval variables
            ndom_int = zeros((nintervals,),dtype=int)
            du_int = zeros((nintervals,))
            ndu_int = zeros((nintervals,),dtype=int)
            #  determine number of domains in each interval and the width of each domain
            # NOTE(review): eval() of grid text; only safe for trusted input
            u1=1.0*eval(tokens[0])
            umin[iaxis]=u1
            if(abs(u1)>1.0000001):
                self.error("  interval endpoints cannot be greater than 1\n  endpoint provided: "+str(u1),exit=False)
                succeeded=False
            #end if
            is_int=False
            has_paren_val=False
            interval=-1
            for i in range(1,len(tokens)):
                if not tokens[i].startswith('('):
                    u2=1.0*eval(tokens[i])
                    umax[iaxis]=u2
                    if(not has_paren_val):
                        du_i=u2-u1
                    #end if
                    # NOTE(review): is_int/ndom_i carry over from the previous
                    #   spacer when none is given (mirrors the qmcpack C++ logic)
                    has_paren_val=False
                    interval+=1
                    if(write):
                        print("      parsing interval ",interval," of ",nintervals)
                        print("      u1,u2 = ",u1,",",u2)
                    #end if
                    if(u2<u1):
                        self.error("  interval ("+str(u1)+","+str(u2)+") is negative",exit=False)
                        succeeded=False
                    #end if
                    if(abs(u2)>1.0000001):
                        self.error("  interval endpoints cannot be greater than 1\n  endpoint provided: "+str(u2),exit=False)
                        succeeded=False
                    #end if
                    if(is_int):
                        du_int[interval]=(u2-u1)/ndom_i
                        ndom_int[interval]=ndom_i
                    else:
                        du_int[interval]=du_i
                        ndom_int[interval]=floor((u2-u1)/du_i+.5)
                        if(abs(u2-u1-du_i*ndom_int[interval])>utol):
                            self.error("  interval ("+str(u1)+","+str(u2)+") not divisible by du="+str(du_i),exit=False)
                            succeeded=False
                        #end if
                    #end if
                    u1=u2
                else:
                    has_paren_val=True
                    paren_val=tokens[i][1:len(tokens[i])-1]
                    if(write):
                        print("      interval spacer = ",paren_val)
                    #end if
                    # no decimal point: spacer is a domain count, else a width
                    is_int=tokens[i].find(".")==-1
                    if(is_int):
                        ndom_i = eval(paren_val)
                        du_i = -1.0
                    else:
                        ndom_i = 0
                        du_i = eval(paren_val)
                    #end if
                #end if
            #end for
            # find the smallest domain width
            du_min=min(du_int)
            odu[iaxis]=1.0/du_min
            # make sure it divides into all other domain widths
            for i in range(len(du_int)):
                ndu_int[i]=floor(du_int[i]/du_min+.5)
                if(abs(du_int[i]-ndu_int[i]*du_min)>utol):
                    self.error("interval {0} of axis {1} is not divisible by smallest subinterval {2}".format(i+1,iaxis+1,du_min),exit=False)
                    succeeded=False
                #end if
            #end for

            if(write):
                print("      interval breakdown")
                print("        interval,ndomains,nsubdomains_per_domain")
                for i in range(len(ndom_int)):
                    print("      ",i,",",ndom_int[i],",",ndu_int[i])
                #end for
            #end if

            # set up the interval map such that gmap[u/du]==domain index
            #   (array shapes must be ints; floor returns a float)
            nfine = int(floor((umax[iaxis]-umin[iaxis])*odu[iaxis]+.5))
            gmap[iaxis] = zeros((nfine,),dtype=int)
            n=0
            nd=-1
            if(write):
                print("        i,j,k    ax,n,nd  ")
            #end if
            for i in range(len(ndom_int)):
                for j in range(ndom_int[i]):
                    nd+=1
                    for k in range(ndu_int[i]):
                        gmap[iaxis][n]=nd
                        if(write):
                            print("      ",i,",",j,",",k,"    ",iaxis,",",n,",",nd)
                        #end if
                        n+=1
                    #end for
                #end for
            #end for
            dimensions[iaxis]=nd+1
            #end read in the grid contents

            #save interval width information (fine intervals per domain)
            ndom_tot=sum(ndom_int)
            ndu_per_interval[iaxis] = zeros((ndom_tot,),dtype=int)
            idom=0
            for i in range(len(ndom_int)):
                for ii in range(ndom_int[i]):
                    ndu_per_interval[iaxis][idom] = ndu_int[i]
                    idom+=1
                #end for
            #end for
        #end for

        axinv = inv(axes)

        #check that all axis grid values fall in the allowed intervals
        cartmap = dict()
        for d in range(DIM):
            cartmap[ax_cartesian[d]]=d
        #end for
        for d in range(DIM):
            if axlabel[d] in cartmap:
                if(umin[d]<-1.0 or umax[d]>1.0):
                    self.error("  grid values for {0} must fall in [-1,1]\n".format(axlabel[d])+"  interval provided: [{0},{1}]".format(umin[d],umax[d]),exit=False)
                    succeeded=False
                #end if
            elif(axlabel[d]=="phi"):
                if(abs(umin[d])+abs(umax[d])>1.0):
                    self.error("  phi interval cannot be longer than 1\n  interval length provided: {0}".format(abs(umin[d])+abs(umax[d])),exit=False)
                    succeeded=False
                #end if
            else:
                if(umin[d]<0.0 or umax[d]>1.0):
                    self.error("  grid values for {0} must fall in [0,1]\n".format(axlabel[d])+"  interval provided: [{0},{1}]".format(umin[d],umax[d]),exit=False)
                    succeeded=False
                #end if
            #end if
        #end for

        #set grid dimensions
        # C/Python style (row-major) strides for the flat domain index
        dm=array([0,0,0],dtype=int)
        dm[0] = dimensions[1]*dimensions[2]
        dm[1] = dimensions[2]
        dm[2] = 1

        ndomains=prod(dimensions)

        #compute domain volumes, centers, and widths
        domain_volumes = zeros((ndomains,))
        domain_centers = zeros((ndomains,DIM))
        domain_uwidths = zeros((ndomains,DIM))
        interval_centers = [None]*DIM
        interval_widths  = [None]*DIM
        for d in range(DIM):
            nintervals = len(ndu_per_interval[d])
            interval_centers[d] = zeros((nintervals))
            interval_widths[d] = zeros((nintervals))
            interval_widths[d][0]=ndu_per_interval[d][0]/odu[d]
            interval_centers[d][0]=interval_widths[d][0]/2.0+umin[d]
            for i in range(1,nintervals):
                interval_widths[d][i] = ndu_per_interval[d][i]/odu[d]
                interval_centers[d][i] = interval_centers[d][i-1] \
                    +.5*(interval_widths[d][i]+interval_widths[d][i-1])
            #end for
        #end for
        du,uc,ubc = zeros((DIM,)),zeros((DIM,)),zeros((DIM,))
        vscale = abs(det(axes))

        for i in range(dimensions[0]):
            for j in range(dimensions[1]):
                for k in range(dimensions[2]):
                    idomain = dm[0]*i + dm[1]*j + dm[2]*k
                    du[0] = interval_widths[0][i]
                    du[1] = interval_widths[1][j]
                    du[2] = interval_widths[2][k]
                    uc[0] = interval_centers[0][i]
                    uc[1] = interval_centers[1][j]
                    uc[2] = interval_centers[2][k]

                    if(coordinate==SpaceGridBase.cartesian):
                        vol=du[0]*du[1]*du[2]
                        ubc=uc
                    elif(coordinate==SpaceGridBase.cylindrical):
                        # phi unit coordinate [0,1] -> angle [-pi,pi]
                        uc[1]=2.0*pi*uc[1]-pi
                        du[1]=2.0*pi*du[1]
                        vol=uc[0]*du[0]*du[1]*du[2]
                        ubc[0]=uc[0]*cos(uc[1])
                        ubc[1]=uc[0]*sin(uc[1])
                        ubc[2]=uc[2]
                    elif(coordinate==SpaceGridBase.spherical):
                        # phi [0,1] -> [-pi,pi], theta [0,1] -> [0,pi]
                        uc[1]=2.0*pi*uc[1]-pi
                        du[1]=2.0*pi*du[1]
                        uc[2]=    pi*uc[2]
                        du[2]=    pi*du[2]
                        vol=(uc[0]*uc[0]+du[0]*du[0]/12.0)*du[0] \
                           *du[1]                                \
                           *2.0*sin(uc[2])*sin(.5*du[2])
                        ubc[0]=uc[0]*sin(uc[2])*cos(uc[1])
                        ubc[1]=uc[0]*sin(uc[2])*sin(uc[1])
                        ubc[2]=uc[0]*cos(uc[2])
                    #end if
                    vol*=vscale

                    rc = dot(axes,ubc) + origin

                    domain_volumes[idomain] = vol
                    for d in range(DIM):
                        domain_uwidths[idomain,d] = du[d]
                        domain_centers[idomain,d] = rc[d]
                    #end for
                #end for
            #end for
        #end for

        #find the actual volume of the grid
        du = umax-umin
        uc = .5*(umax+umin)
        if coordinate==SpaceGridBase.cartesian:
            vol=du[0]*du[1]*du[2]
        elif coordinate==SpaceGridBase.cylindrical:
            uc[1]=2.0*pi*uc[1]-pi
            du[1]=2.0*pi*du[1]
            vol=uc[0]*du[0]*du[1]*du[2]
        elif coordinate==SpaceGridBase.spherical:
            uc[1]=2.0*pi*uc[1]-pi
            du[1]=2.0*pi*du[1]
            uc[2]=    pi*uc[2]
            du[2]=    pi*du[2]
            vol=(uc[0]*uc[0]+du[0]*du[0]/12.0)*du[0]*du[1]*2.0*sin(uc[2])*sin(.5*du[2])
        #end if
        volume = vol*vscale

        for q in SpaceGridBase.quantities:
            self[q].mean  = zeros((ndomains,))
            self[q].error = zeros((ndomains,))
        #end for

        #save the results
        self.axinv              = axinv
        self.volume             = volume
        self.gmap               = gmap
        self.umin               = umin
        self.umax               = umax
        self.odu                = odu
        self.dm                 = dm
        self.dimensions         = dimensions
        self.ndomains           = ndomains
        self.domain_volumes     = domain_volumes
        self.domain_centers     = domain_centers
        self.domain_uwidths     = domain_uwidths

        if(self.init_exit_fail and not succeeded):
            self.error(" in def initialize")
        #end if

        return succeeded
    #end def initialize

    def point2unit_cartesian(self,point):
        """Map a point to unit coordinates: axinv.(point-origin)."""
        u = dot(self.axinv,(point-self.origin))
        return u
    #end def point2unit_cartesian

    def point2unit_cylindrical(self,point):
        """Map a point to unit (r, phi/2pi+1/2, z) coordinates."""
        ub = dot(self.axinv,(point-self.origin))
        u=zeros((self.DIM,))
        u[0] = sqrt(ub[0]*ub[0]+ub[1]*ub[1])
        u[1] = atan2(ub[1],ub[0])*o2pi+.5
        u[2] = ub[2]
        return u
    #end def point2unit_cylindrical

    def point2unit_spherical(self,point):
        """Map a point to unit (r, phi/2pi+1/2, theta/pi) coordinates."""
        ub = dot(self.axinv,(point-self.origin))
        u=zeros((self.DIM,))
        u[0] = sqrt(ub[0]*ub[0]+ub[1]*ub[1]+ub[2]*ub[2])
        u[1] = atan2(ub[1],ub[0])*o2pi+.5
        u[2] = acos(ub[2]/u[0])*o2pi*2.0
        return u
    #end def point2unit_spherical

    def _unit2domain(self,u):
        """Return the flat domain index for unit coordinates u, or None if u lies outside the grid."""
        if not ((u>self.umin).all() and (u<self.umax).all()):
            return None
        #end if
        # fine-interval index per axis; must be int to index gmap
        iu = floor((u-self.umin)*self.odu).astype(int)
        idomain = 0
        for d in range(self.DIM):
            gm = self.gmap[d]
            # u<umax strictly, but round-off can still land one past the end
            idomain += self.dm[d]*gm[min(iu[d],len(gm)-1)]
        #end for
        return idomain
    #end def _unit2domain

    def _points2domains(self,points,domains,points_outside,point2unit):
        """
        Shared driver for the points2domains_* methods.

        For each row p of `points` inside the grid, it sets
        points_outside[p]=False and writes (p, flat domain index) into the
        next free row of `domains`. Returns the number of rows written.
        """
        nfound = 0
        for p in range(points.shape[0]):
            idomain = self._unit2domain(point2unit(points[p]))
            if idomain is not None:
                points_outside[p]=False
                domains[nfound,0] = p
                domains[nfound,1] = idomain
                nfound+=1
            #end if
        #end for
        return nfound
    #end def _points2domains

    def points2domains_cartesian(self,points,domains,points_outside):
        """Locate cartesian-grid domains for `points`; see _points2domains."""
        return self._points2domains(points,domains,points_outside,self.point2unit_cartesian)
    #end def points2domains_cartesian

    def points2domains_cylindrical(self,points,domains,points_outside):
        """Locate cylindrical-grid domains for `points`; see _points2domains."""
        return self._points2domains(points,domains,points_outside,self.point2unit_cylindrical)
    #end def points2domains_cylindrical

    def points2domains_spherical(self,points,domains,points_outside):
        """Locate spherical-grid domains for `points`; see _points2domains."""
        return self._points2domains(points,domains,points_outside,self.point2unit_spherical)
    #end def points2domains_spherical


    def shift_origin(self,shift):
        """Translate the grid origin and all domain centers by `shift` (in place)."""
        self.origin += shift
        self.domain_centers += shift
        return
    #end def shift_origin


    def set_origin(self,origin):
        """Move the grid so its origin is `origin`."""
        self.shift_origin(origin-self.origin)
        return
    #end def set_origin


    def _integration_compatible(self,spacegrids,warn):
        """
        Return True if this grid's domains subdivide those of every source grid.

        Each failed check returns False, and emits a warning if `warn` is set.
        """
        def fail(msg):
            if warn:
                self.warn(msg)
            #end if
            return False
        #end def fail

        am_cartesian   = self.coordinate==SpaceGridBase.cartesian
        am_cylindrical = self.coordinate==SpaceGridBase.cylindrical
        am_spherical   = self.coordinate==SpaceGridBase.spherical
        # centers of this grid's finest intervals along each axis
        fine_interval_centers = []
        for d in range(self.DIM):
            ndu = int(round( (self.umax[d]-self.umin[d])*self.odu[d] ))
            if len(self.gmap[d])!=ndu:
                self.error('ndu is different than len(gmap)')
            #end if
            du = 1./self.odu[d]
            fine_interval_centers.append(self.umin[d] + .5*du + du*array(list(range(ndu))))
        #end for
        for s in spacegrids:
            # all the spacegrids must share the coordinate system
            if s.coordinate!=self.coordinate:
                return fail('SpaceGrids must have same coordinate for interpolation')
            #end if
            # each source's axes must be int mult of this spacegrid's axes
            #   (this ensures that isosurface shapes conform)
            tile = dot(self.axinv,s.axes)
            for d in range(self.DIM):
                if not is_integer(tile[d,d]):
                    return fail("source axes must be multiples of interpolant's axes")
                #end if
            #end for
            # origin must be at r=0 for cylindrical or spherical
            uo = self.point2unit(s.origin)
            if (am_cylindrical or am_spherical) and uo[0]>1e-6:
                return fail('source origin must lie at interpolant r=0')
            #end if
            # fine meshes must align
            #  origin must be an integer multiple of smallest dom width
            if am_cylindrical:
                mdims=[2]
            elif am_cartesian:
                mdims=[0,1,2]
            else:
                mdims=[]
            #end if
            for d in mdims:
                if not is_integer(uo[d]*self.odu[d]):
                    return fail('source origin does not lie on interpolant fine mesh')
                #end if
            #end for
            #  smallest dom width must be multiple of this smallest dom width
            for d in range(self.DIM):
                if not is_integer(self.odu[d]/s.odu[d]):
                    return fail('smallest source domain width must be a multiple of interpolants smallest domain width')
                #end if
            #end for
            #  each interpolant domain along each axis must map to only one source domain
            # NOTE(review): fine centers are in this grid's unit frame but are
            #   mapped with the source's umin/odu (as originally written);
            #   this assumes a shared unit frame -- confirm
            for d in range(self.DIM):
                gmlen = len(s.gmap[d])
                source_of = dict() # interpolant domain -> source domain
                for i,uc in enumerate(fine_interval_centers[d]):
                    ind = int(floor((uc-s.umin[d])*s.odu[d]))
                    if 0<=ind<gmlen:
                        idom=s.gmap[d][ind]
                    else:
                        idom=-1 # outside the source grid
                    #end if
                    if source_of.setdefault(self.gmap[d][i],idom)!=idom:
                        return fail('an interpolant domain must not fall on multiple source domains')
                    #end if
                #end for
            #end for
        #end for
        return True
    #end def _integration_compatible


    def interpolate_across(self,quantities,spacegrids,outside,integration=False,warn=False):
        """
        Fill `quantities` on this grid from the source `spacegrids`.

        Each domain center takes the mean and error of the source domain
        containing it. Domains covered by no source get outside[q].mean and
        outside[q].error.

        If `integration` is set, the method first checks that this grid
        subdivides every source grid, and returns False if it does not.
        Returns True otherwise.
        """
        if integration and not self._integration_compatible(spacegrids,warn):
            return False
        #end if

        #get the list of domains points from this grid fall in
        #  and interpolate requested quantities on them
        domain_centers  = self.domain_centers
        domind = zeros((self.ndomains,2),dtype=int)
        domout = ones((self.ndomains,) ,dtype=int)
        for s in spacegrids:
            domind[:,:] = -1
            ndomin = s.points2domains(domain_centers,domind,domout)
            dst = domind[0:ndomin,0]
            src = domind[0:ndomin,1]
            for q in quantities:
                # fancy indexing already returns a copy
                self[q].mean[dst]  = s[q].mean[src]
                self[q].error[dst] = s[q].error[src]
            #end for
        #end for
        for d in range(self.ndomains):
            if domout[d]:
                for q in quantities:
                    self[q].mean[d]  = outside[q].mean
                    self[q].error[d] = outside[q].error
                #end for
            #end if
        #end for
        return True
    #end def interpolate_across


    def interpolate(self,points,quantities=None):
        """
        Sample quantities (default: all) at `points` (npoints x DIM).

        Returns a QAobject with mean/error arrays per quantity. Points
        outside the grid get zero.
        """
        if quantities is None:
            quantities=SpaceGridBase.quantities
        #end if
        npoints,ndim = points.shape
        ind = empty((npoints,2),dtype=int)
        out = ones((npoints,) ,dtype=int)
        nin = self.points2domains(points,ind,out)
        dst = ind[0:nin,0]
        src = ind[0:nin,1]
        result = QAobject()
        for q in quantities:
            result._add_attribute(q,QAobject())
            result[q].mean  = zeros((npoints,))
            result[q].error = zeros((npoints,))
            result[q].mean[dst]  = self[q].mean[src]
            result[q].error[dst] = self[q].error[src]
        #end for
        return result
    #end def interpolate


    def isosurface(self,quantity,contours=5,origin=None):
        """Plot isosurfaces of a quantity's mean over domain centers (optionally shifted by `origin`)."""
        if quantity not in SpaceGridBase.quantities:
            self.error('isosurface: {0} is not a spacegrid quantity'.format(quantity))
        #end if
        dimensions = self.dimensions
        # `is None`: comparing an ndarray origin with == is ambiguous
        if origin is None:
            points = self.domain_centers
        else:
            points = self.domain_centers + origin
        #end if
        scalars    = self[quantity].mean
        name       = quantity
        self.plotter.isosurface(points,scalars,contours,dimensions,name)
        return
    #end def isosurface


    def surface_slice(self,quantity,x,y,z,options=None):
        """Plot a quantity's mean interpolated onto the surface given by meshes x, y, z."""
        if quantity not in SpaceGridBase.quantities:
            self.error('surface_slice: {0} is not a spacegrid quantity'.format(quantity))
        #end if
        points = empty( (x.size,self.DIM) )
        points[:,0] = x.ravel()
        points[:,1] = y.ravel()
        points[:,2] = z.ravel()
        val = self.interpolate(points,[quantity])
        scalars = val[quantity].mean
        scalars.shape = x.shape
        self.plotter.surface_slice(x,y,z,scalars,options)
        return
    #end def surface_slice


    def plot_axes(self,color=None,radius=.025,origin=None):
        """
        Draw each grid axis as a tube from origin-axis to origin+axis.

        Axes are drawn in fixed colors: red, green, blue in axis order.
        `color` is accepted but not currently used.
        """
        if color is None:
            color = (0.,0,0)
        #end if
        if origin is None:
            origin = array([0.,0,0])
        #end if
        colors=array([[1.,0,0],[0,1.,0],[0,0,1.]])
        for d in range(self.DIM):
            a  = self.axes[:,d]
            lo = origin - a
            hi = origin + a
            ax=array([lo[0],hi[0]])
            ay=array([lo[1],hi[1]])
            az=array([lo[2],hi[2]])
            self.plotter.plot3d(ax,ay,az,tube_radius=radius,color=tuple(colors[:,d]))
        #end for
        return
    #end def plot_axes

    def plot_box(self,color=None,radius=.025,origin=None):
        """Draw the 12 edges of the grid's bounding box (corner points from self.points)."""
        if color is None:
            color = (0.,0,0)
        #end if
        if origin is None:
            origin = array([0.,0,0])
        #end if
        p = self.points
        p1=p.cmmm+origin
        p2=p.cmpm+origin
        p3=p.cpmm+origin
        p4=p.cppm+origin
        p5=p.cmmp+origin
        p6=p.cmpp+origin
        p7=p.cpmp+origin
        p8=p.cppp+origin
        bline = array([p1,p2,p4,p3,p1,p5,p6,p8,p7,p5,p7,p3,p4,p8,p6,p2])
        self.plotter.plot3d(bline[:,0],bline[:,1],bline[:,2],color=color)
        return
    #end def plot_box
#end class RectilinearGrid
3378
3379
3380
3381
3382
class VoronoiGridInitializer(SpaceGridInitializer):
    """Initializer for VoronoiGrid; currently adds nothing to its base."""
    def __init__(self):
        # all setup is delegated to the base initializer
        SpaceGridInitializer.__init__(self)
        return
    #end def __init__
#end class VoronoiGridInitializer
3388
3389
class VoronoiGrid(SpaceGridBase):
    """
    SpaceGrid variant selected for 'voronoi' coordinates.

    Construction is delegated entirely to SpaceGridBase.
    """
    def __init__(self,initobj=None,options=None):
        SpaceGridBase.__init__(self,initobj,options)
    #end def __init__

    def copy(self,other):
        """Return a new VoronoiGrid built from ``other`` (self is unused)."""
        return VoronoiGrid(other)
    #end def copy


    def reorder_atomic_data(self,imap):
        """
        Permute per-domain entries of every quantity with the index map imap.

        The summary statistics (mean, error) are permuted along their last
        axis.  If raw samples are stored under 'data', each array there is
        permuted along its second-to-last axis.  Presumably the last axis
        holds samples -- confirm against the loader.
        """
        for name in self.quantities:
            stats = self[name]
            stats.mean  = stats.mean[...,imap]
            stats.error = stats.error[...,imap]
        #end for
        if 'data' not in self:
            return
        #end if
        raw = self.data
        for name in self.quantities:
            raw[name] = raw[name][...,imap,:]
        #end for
    #end def reorder_atomic_data
#end class VoronoiGrid
3415
3416
3417
3418
3419
3420
def SpaceGrid(init,opts=None):
    """
    Factory: build the grid class matching the coordinate system of init.

    Parameters
    ----------
    init : HDFgroup
        stat.h5 group describing the grid.  The first element of its
        ``coordinate`` entry is mapped to a coordinate name through
        SpaceGrid.coord_n2s.
    opts : optional
        Forwarded unchanged to the grid constructor.

    Returns
    -------
    RectilinearGrid for cartesian/cylindrical/spherical coordinates,
    VoronoiGrid for voronoi.  Any other init type or coordinate system
    prints a message and exits.
    """
    import sys
    SpaceGrid.count+=1

    iname = init.__class__.__name__
    if iname!='HDFgroup':
        # previously 'coordinate' was left unbound here, causing an
        # opaque UnboundLocalError on the lookup below
        print('SpaceGrid cannot be initialized from an object of type '+iname+', exiting...')
        sys.exit()
    #end if
    coordinate = init.coordinate[0]
    coord = SpaceGrid.coord_n2s[coordinate]

    if coord in SpaceGrid.rect:
        return RectilinearGrid(init,opts)
    elif coord=='voronoi':
        return VoronoiGrid(init,opts)
    else:
        print('SpaceGrid '+coord+' has not been implemented, exiting...')
        # sys.exit rather than the site-module exit(), which is meant for
        # the interactive shell and may be absent (e.g. python -S)
        sys.exit()
    #end if

#end def SpaceGrid
SpaceGrid.count = 0
SpaceGrid.coord_n2s = SpaceGridBase.coord_n2s
SpaceGrid.rect = {'cartesian','cylindrical','spherical'}
3443
3444