1##################################################################
2##  (c) Copyright 2015-  by Jaron T. Krogel                     ##
3##################################################################
4
5
6#====================================================================#
7#  qmcpack_input.py                                                  #
8#    Supports I/O and manipulation of QMCPACK's xml input file.      #
9#                                                                    #
10#  Content summary:                                                  #
11#    QmcpackInput                                                    #
12#      SimulationInput class for QMCPACK.                            #
13#      Represents the QMCPACK input file as a nested collection of   #
14#        objects mirroring the structure of the XML input.           #
15#      XML attributes and <parameter/> elements are joined into      #
16#        a common keyword representation.                            #
#      Full translations of test QMCPACK input files into executable #
18#        Python code can be found at the end of this file for        #
19#        reference.                                                  #
20#                                                                    #
21#    BundledQmcpackInput                                             #
22#      Class represents QMCPACK input when provided in 'bundled'     #
23#        form, i.e. an input file containing a list of XML input     #
24#        files.                                                      #
25#      A BundledQmcpackInput object contains many QmcpackInput       #
26#        objects.                                                    #
27#                                                                    #
28#    QmcpackInputTemplate                                            #
29#      Class supports Nexus interaction with text input files        #
30#        provided by users as 'templates'.                           #
31#      Users can mark keywords in a template input file, generate    #
32#        variations on that input file through this class, and       #
33#        then use the resulting input files in Nexus workflows.      #
34#                                                                    #
35#    generate_qmcpack_input                                          #
36#      User-facing function to compose specific QMCPACK input files. #
37#      Though any QMCPACK input file can be composed directly with   #
38#        Python code, only a more limited selection can be           #
39#        generated directly.                                         #
40#      Calls many other functions to generate specific XML elements: #
41#        generate_simulationcell                                     #
42#        generate_particlesets                                       #
43#        generate_sposets                                            #
44#        generate_sposet_builder                                     #
45#        generate_bspline_builder                                    #
46#        generate_heg_builder                                        #
47#        generate_determinantset                                     #
48#        generate_determinantset_old                                 #
49#        generate_hamiltonian                                        #
50#        generate_jastrows                                           #
51#        generate_jastrow                                            #
52#        generate_jastrow1                                           #
53#        generate_bspline_jastrow2                                   #
54#        generate_pade_jastrow2                                      #
55#        generate_jastrow2                                           #
56#        generate_jastrow3                                           #
57#        generate_kspace_jastrow                                     #
58#        generate_opt                                                #
59#        generate_opts                                               #
60#                                                                    #
61#   QIxml, Names                                                     #
62#     Class represents a generic XML element.                        #
63#     Supports read/write and setting default values.                #
64#     Derived classes represent specific elements and contain        #
65#       a keyword specification for each element.                    #
66#     Support for new XML elements is enabled by creating            #
67#       a corresponding class with the allowed keywords, etc.,       #
68#       and then adding it to the global 'classes' list below.       #
69#     See classes simulation, project, application, random, include, #
70#       mcwalkerset, qmcsystem, simulationcell, particleset, group,  #
71#       sposet, bspline_builder, heg_builder, composite_builder,     #
72#       wavefunction, determinantset, basisset, grid, atomicbasisset,#
73#       basisgroup, radfunc, slaterdeterminant, determinant,         #
74#       occupation, multideterminant, detlist, ci, jastrow1,         #
75#       jastrow2, jastrow3, correlation, var, coefficients,          #
76#       coefficient, hamiltonian, coulomb, constant, pseudopotential,#
77#       pseudo, mpc, localenergy, energydensity, reference_points,   #
78#       spacegrid, origin, axis, chiesa, density, nearestneighbors,  #
79#       neighbor_trace, dm1b, spindensity, structurefactor, init,    #
80#       scalar_traces, array_traces, particle_traces, traces, loop,  #
81#       linear, cslinear, vmc, dmc.                                  #
82#                                                                    #
83#   QIxmlFactory                                                     #
84#     Class supports comprehension of XML elements that share the    #
85#       same XML tag (e.g. <pairpot/>), but have many different      #
86#       and distinct realizations (e.g. coulomb, mpc, etc.).         #
87#     See QIxmlFactory instances pairpot, estimator,                 #
88#       sposet_builder, jastrow, and qmc.                            #
89#                                                                    #
90#   collection                                                       #
91#     Container class representing an ordered set of plural XML      #
92#       elements named according to an XML attribute (e.g. named     #
93#       particlesets).                                               #
94#     XML elements are listed by name in the collection to allow     #
95#       intuitive interactive navigation by the user.                #
96#                                                                    #
97#   Param (single global instance is param)                          #
98#     Handles all reads/writes of attributes and text contents       #
99#       of XML elements.                                             #
100#     Automatically distinguishes string, int, float, and array      #
101#       data of any of these types.                                  #
102#                                                                    #
103#   Note: those seeking to add new XML elements should be aware      #
104#     of the global data structures described below.                 #
105#                                                                    #
106#   classes                                                          #
107#     Global list of QIxml classes.                                  #
108#     Each new QIxml class must be added here.                       #
109#                                                                    #
110#   types                                                            #
111#     Global dict of elements that are actually simple types (to be  #
112#     interpreted the same way as <parameter/> elements) and also    #
113#     factory instances.                                             #
114#                                                                    #
115#   plurals                                                          #
116#     Global obj of elements that can be plural (i.e. multiple can   #
#     be present with the same tag).  A QmcpackInput object will     #
118#     contain collection instances containing the related XML        #
119#     elements.  Each collection is named according to the keys in   #
120#     plurals.                                                       #
121#                                                                    #
122#   Names.set_expanded_names                                         #
123#     Function to specify mappings from all-lowercase names used     #
124#     in QmcpackInput objects to the names expected by QMCPACK in    #
125#     the actual XML input file.  QMCPACK does not follow any        #
126#     consistent naming convention (camel-case, hyphenation, and     #
127#     separation via underscore are all present in names) and some   #
128#     names are case-sensitive while others are not.  This function  #
129#     allows developers to cope with this diversity upon write,      #
130#     while preserving a uniform representation (all lowercase) in   #
131#     QmcpackInput objects.                                          #
132#                                                                    #
133#====================================================================#
134
135
136import os
137import inspect
138import keyword
139from numpy import fromstring,empty,array,float64,\
140    loadtxt,ndarray,dtype,sqrt,pi,arange,exp,eye,\
141    ceil,mod,dot,abs,identity,floor,linalg,where,isclose
142from io import StringIO
143from superstring import string2val
144from generic import obj,hidden
145from xmlreader import XMLreader,XMLelement
146from developer import DevBase,error
147from periodic_table import is_element
148from structure import Structure,Jellium,get_kpath
149from physical_system import PhysicalSystem
150from simulation import SimulationInput,SimulationInputTemplate
151from pwscf_input import array_to_string as pwscf_array_string
152from debug import ci as interact
153
# Python bool -> textual spelling, one dict per spelling convention.
yesno_dict     = {True: 'yes',  False: 'no'}
truefalse_dict = {True: 'true', False: 'false'}
onezero_dict   = {True: '1',    False: '0'}
# Textual boolean -> Python bool (used when reading boolean attributes).
# Key order is kept as-is: the keys appear in an error message on read.
boolmap = {
    'yes'  : True,
    'no'   : False,
    'true' : True,
    'false': False,
    '1'    : True,
    '0'    : False,
    }
158
def is_int(var):
    """Return True if ``var`` can be converted with ``int()``.

    Integer-like strings ('3', ' 7 ') qualify; '3.5' does not, because
    int() rejects it.  Values that int() cannot accept at all (None,
    lists, ...) or cannot represent (float infinity) yield False instead
    of raising.
    """
    try:
        int(var)
    except (ValueError, TypeError, OverflowError):
        return False
    #end try
    return True
#end def is_int
167
def is_float(var):
    """Return True if ``var`` can be converted with ``float()``.

    Numeric strings ('1.5', '3', '1e-3') qualify; multi-token strings
    ('1 2') do not.  Values that float() cannot accept at all (None,
    lists, ...) or that overflow yield False instead of raising.
    """
    try:
        float(var)
    except (ValueError, TypeError, OverflowError):
        return False
    #end try
    return True
#end def is_float
176
def is_array(var,type):
    """Return True if ``var`` can be turned into a numpy array of dtype ``type``.

    Strings are split on whitespace first, so '1 2 3' is an int array but
    '1.5 2' is not (it is a float array).  Non-string inputs are passed to
    numpy.array directly.  Inputs numpy rejects with ValueError or
    TypeError (e.g. non-numeric tokens, None) yield False.
    Note: an empty/blank string splits to [] and therefore qualifies.
    """
    values = var.split() if isinstance(var,str) else var
    try:
        array(values,type)
    except (ValueError, TypeError):
        return False
    #end try
    return True
#end def is_array
189
190
def attribute_to_value(attr):
    """Convert an XML attribute value into the most specific Python value.

    Conversions are tried in order and the first that applies wins:
      int; float; whitespace-separated int array (reshaped to 3x3 when it
      holds exactly nine entries); whitespace-separated float array.
    If none apply, ``attr`` is returned unchanged.
    """
    if is_int(attr):
        return int(attr)
    #end if
    if is_float(attr):
        return float(attr)
    #end if
    if is_array(attr,int):
        ints = array(attr.split(),int)
        # nine integers are interpreted as a 3x3 matrix (e.g. a tiling)
        if ints.size==9:
            ints.shape = 3,3
        #end if
        return ints
    #end if
    if is_array(attr,float):
        return array(attr.split(),float)
    #end if
    return attr
#end def attribute_to_value
208
209
210
211#local write types
def yesno(var):
    """Render a boolean-like value as 'yes'/'no' (see render_bool)."""
    return render_bool(var,'yes','no')
#end def yesno

def yesnostr(var):
    """Pass strings through unchanged; render anything else as 'yes'/'no'."""
    if isinstance(var,str):
        return var
    else:
        return yesno(var)
    #end if
#end def yesnostr

def onezero(var):
    """Render a boolean-like value as '1'/'0' (see render_bool)."""
    return render_bool(var,'1','0')
#end def onezero

def truefalse(var):
    """Render a boolean-like value as 'true'/'false' (see render_bool)."""
    return render_bool(var,'true','false')
#end def truefalse

def render_bool(var,T,F):
    """Render ``var`` as the text ``T`` (true) or ``F`` (false).

    Accepted inputs: a bool, the integers 1/0, or one of the texts T/F
    themselves (returned unchanged).  Anything else is reported through
    developer.error; should that call return, None is returned.
    """
    if isinstance(var,bool) or var in (1,0):
        if var:
            return T
        else:
            return F
        #end if
    elif var in (T,F):
        return var
    else:
        error('Invalid QMCPACK input encountered.\nUser provided an invalid value of "{}" when yes/no was expected.\nValid options are: "{}", "{}", True, False, 1, 0'.format(var,T,F))

    #end if
#end def render_bool


# Write types that mark an attribute as boolean; QIxml.init_from_xml parses
# such attributes back into Python bools via boolmap.
bool_write_types = set([yesno,onezero,truefalse])
249
250
251
252
class QIobj(DevBase):
    """Common base for QMCPACK-input objects.

    Holds, as class attributes shared by all subclasses, the AFQMC mode
    flag and the user-controlled permissiveness flags consulted when
    reading, writing, and initializing input elements.
    """

    afqmc_mode = False

    # user settings (class-wide; change them through QIobj.settings)
    permissive_read  = False
    permissive_write = False
    permissive_init  = False

    @staticmethod
    def settings(
        permissive_read  = False,
        permissive_write = False,
        permissive_init  = False,
        ):
        """Set the class-wide permissiveness flags on QIobj.

        Any flag not passed is reset to False.
        """
        flags = dict(
            permissive_read  = permissive_read,
            permissive_write = permissive_write,
            permissive_init  = permissive_init,
            )
        for flag_name,flag_value in flags.items():
            setattr(QIobj,flag_name,flag_value)
        #end for
    #end def settings
#end class QIobj
273
274
class meta(obj):
    """Plain obj container; QIxml.__init__ stores an instance of it in
    Param.metadata the first time an element is constructed."""
#end class meta
278
279
class section(QIobj):
    """Bundle of deferred constructor arguments for a QIxml element.

    QIxml.__init__ recognizes a lone section argument and unpacks it into
    init_from_inputs(section.args, section.kwargs).
    """
    def __init__(self,*args,**kwargs):
        self.args,self.kwargs = args,kwargs
    #end def __init__
#end class section
286
287
class collection(hidden):
    """Ordered, named container of plural QIxml elements.

    Elements are keyed by the value(s) of their identifier attribute(s), or
    by an integer slot when they have no identifier.  Insertion order is
    kept in the hidden 'order' list, which list()/pairlist() follow.
    Item and attribute assignment/deletion are routed through add/remove.
    """
    def __init__(self,*elements):
        hidden.__init__(self)
        if len(elements)==1 and isinstance(elements[0],list):
            elements = elements[0]
        elif len(elements)==1 and isinstance(elements[0],collection):
            # copy another collection through its ordered element list
            # (walking its raw __dict__ ignores 'order' and may pick up
            # internal storage rather than elements)
            elements = elements[0].list()
        #end if
        self.hidden().order = []
        for element in elements:
            self.add(element)
        #end for
    #end def __init__

    def __setitem__(self,name,value):
        # routed through add so key validation and ordering apply
        self.add(value,key=name)
    #end def __setitem__

    def __delitem__(self,name):
        self.remove(name)
    #end def __delitem__

    def __setattr__(self,name,value):
        # routed through add so key validation and ordering apply
        self.add(value,key=name)
    #end def __setattr__

    def __delattr__(self,name):
        self.remove(name)
    #end def __delattr__

    def add(self,element,strict=True,key=None):
        """Insert ``element`` under the key derived from its identifier.

        Parameters:
          element: QIxml instance whose tag is plural (or that has a
                   collection_id); anything else is reported via self.error.
          strict:  accepted for interface compatibility; unused.
          key:     optional requested key; must match the derived key
                   case-insensitively unless the derived key is an integer.
        Re-adding under an existing key replaces that element in place and
        keeps its original position in the order.  Returns True.
        """
        if not isinstance(element,QIxml):
            self.error('collection cannot be formed\nadd attempted for non QIxml element\ntype received: {0}'.format(element.__class__.__name__))
        #end if
        keyin = key
        key   = None
        public = self.public()
        identifier = element.identifier
        missing_identifier = False
        if not element.tag in plurals_inv and element.collection_id is None:
            self.error('collection cannot be formed\n  encountered non-plural element\n  element class: {0}\n  element tag: {1}\n  tags allowed in a collection: {2}'.format(element.__class__.__name__,element.tag,sorted(plurals_inv.keys())))
        elif identifier is None:
            missing_identifier = True
        elif isinstance(identifier,str):
            if identifier in element:
                key = element[identifier]
            else:
                missing_identifier = True
            #end if
        else:
            # composite identifier: concatenate the values that are present
            key = ''
            for ident in identifier:
                if ident in element:
                    key+=element[ident]
                #end if
            #end for
            missing_identifier = key==''
        #end if
        if missing_identifier:
            # next free integer slot; len(public) alone can collide with an
            # existing integer key after an earlier remove()
            key = len(public)
            while key in public:
                key += 1
            #end while
        #end if
        if keyin is not None and not isinstance(key,int) and str(keyin).lower()!=str(key).lower():
            self.error('attempted to add key with incorrect name\nrequested key: {0}\ncorrect key: {1}'.format(keyin,key))
        #end if
        # only new keys extend the order; a replacement must not duplicate
        # the key (duplicates would make list() and writes repeat elements)
        if key not in public:
            self.hidden().order.append(key)
        #end if
        public[key] = element
        return True
    #end def add

    def remove(self,key):
        """Delete the element stored under ``key``; KeyError if absent."""
        public = self.public()
        if key in public:
            del public[key]
            self.hidden().order.remove(key)
        else:
            raise KeyError(key)
        #end if
    #end def remove

    def get_single(self,preference=None):
        """Return the element under ``preference`` if present, else the
        first element in order; None when the collection is empty."""
        if len(self)>0:
            if preference is not None and preference in self:
                return self[preference]
            else:
                return self.list()[0]
            #end if
        else:
            return None
        #end if
    #end def get_single

    def list(self):
        """Return the elements in insertion order."""
        return [self[key] for key in self.hidden().order]
    #end def list

    def pairlist(self):
        """Return (key, element) pairs in insertion order."""
        return [(key,self[key]) for key in self.hidden().order]
    #end def pairlist
#end class collection
403
404
def make_collection(elements):
    """Build a collection whose members are the items of ``elements``.

    Each item is passed as a separate positional argument to collection().
    """
    members = elements
    return collection(*members)
#end def make_collection
408
409
class classcollection(QIobj):
    """Holds a sequence of classes in ``self.classes``.

    Accepts classes as separate arguments, or a single list of them
    (which is stored as that list).
    """
    def __init__(self,*classes):
        single_list = len(classes)==1 and isinstance(classes[0],list)
        self.classes = classes[0] if single_list else classes
    #end def __init__
#end class classcollection
418
419
class QmcpackInputCollections(QIobj):
    """Registry of plural QIxml elements, grouped into collections named
    by plurals_inv[element.tag].  QIxml.__init__ registers every element
    it constructs with the module-level instance QIcollections."""
    def add(self,element):
        """File ``element`` in its plural collection, creating the
        collection on first use; elements with non-plural tags are ignored."""
        if element.tag not in plurals_inv:
            return
        #end if
        cname = plurals_inv[element.tag]
        if cname in self:
            self[cname].add(element,strict=False)
        else:
            coll = collection()
            if coll.add(element,strict=False):
                self[cname] = coll
            #end if
        #end if
    #end def add

    def get(self,cname,label=None):
        """Return collection ``cname`` (or its entry ``label``); None when
        either is absent."""
        if cname not in self:
            return None
        #end if
        coll = self[cname]
        if label is None:
            return coll
        #end if
        return coll[label] if label in coll else None
    #end def get
#end class QmcpackInputCollections
QIcollections = QmcpackInputCollections()
449
450
class Names(QIobj):
    """Translate between QMCPACK XML names and condensed keyword names.

    Condensed names are lowercase with '-', ' ' and runs of '_' collapsed
    to a single '_'; Python keywords (and 'write') get a trailing '_'.
    Expanded names restore QMCPACK's spelling on write via the active
    ``expanded_names`` table (real-space or AFQMC variant).
    """
    # record of every expanded -> condensed conversion performed
    condensed_names = obj()
    # active condensed -> expanded table (None until set_expanded_names runs)
    expanded_names = None

    rsqmc_expanded_names = None
    afqmc_expanded_names = None

    # names that cannot be used verbatim as attributes: Python keywords,
    # plus 'write' (presumably to avoid clashing with the write() method)
    escape_names = set(keyword.kwlist+['write'])
    # built with a generator so no loop variable leaks into the class
    # namespace (an index loop here would leave a stray class attribute)
    escaped_names = set(name+'_' for name in escape_names)

    @staticmethod
    def set_expanded_names(**kwargs):
        """Install the real-space QMC name table and make it active."""
        exnames = obj(**kwargs)
        Names.expanded_names = exnames
        Names.rsqmc_expanded_names = exnames
    #end def set_expanded_names

    @staticmethod
    def set_afqmc_expanded_names(**kwargs):
        """Store the AFQMC name table (not activated here)."""
        Names.afqmc_expanded_names = obj(**kwargs)
    #end def set_afqmc_expanded_names

    @staticmethod
    def use_rsqmc_expanded_names():
        Names.expanded_names = Names.rsqmc_expanded_names
    #end def use_rsqmc_expanded_names

    @staticmethod
    def use_afqmc_expanded_names():
        Names.expanded_names = Names.afqmc_expanded_names
    #end def use_afqmc_expanded_names

    def expand_name(self,condensed):
        """Return the QMCPACK spelling of ``condensed``.

        Escaped names lose their trailing '_'; names found in the active
        expanded_names table are replaced by the table entry; otherwise the
        input is returned unchanged.
        """
        expanded = condensed
        cname = self.condense_name(condensed)
        if cname in self.escaped_names:
            cname = cname[:-1]
            expanded = cname
        #end if
        exnames = self.expanded_names
        # tolerate a table that has not been installed yet
        if exnames is not None and cname in exnames:
            expanded = exnames[cname]
        #end if
        return expanded
    #end def expand_name

    def condense_name(self,expanded):
        """Return the condensed form of ``expanded`` and record the mapping."""
        condensed = expanded
        condensed = condensed.replace('___','_').replace('__','_')
        condensed = condensed.replace('-','_').replace(' ','_')
        condensed = condensed.lower()
        if condensed in self.escape_names:
            condensed += '_'
        #end if
        self.condensed_names[expanded]=condensed
        return condensed
    #end def condense_name

    def condense_names(self,*namelists):
        """For each list of expanded names return an obj mapping
        condensed -> expanded."""
        out = []
        for namelist in namelists:
            name_map = obj()
            for expanded in namelist:
                condensed = self.condense_name(expanded)
                name_map[condensed]=expanded
            #end for
            out.append(name_map)
        #end for
        return out
    #end def condense_names

    def condensed_name_report(self):
        """Print every recorded name whose condensed form differs, sorted
        by condensed name."""
        print()
        print('Condensed Name Report:')
        print('----------------------')
        keylist = array(list(self.condensed_names.keys()))
        order = array(list(self.condensed_names.values())).argsort()
        keylist = keylist[order]
        for expanded in keylist:
            condensed = self.condensed_names[expanded]
            if expanded!=condensed:
                print("    {0:15} = '{1}'".format(condensed,expanded))
            #end if
        #end for
        print()
        print()
    #end def condensed_name_report
#end class Names
542
543
544
545
546class QIxml(Names):
547
548    def init_from_args(self,args):
549        print()
550        print('In init from args (not implemented).')
551        print('Possible reasons for incorrect entry:  ')
552        print('  Is xml element {0} meant to be plural?'.format(self.__class__.__name__))
553        print('    If so, add it to the plurals object.')
554        print()
555        print('Arguments received:')
556        print(args)
557        print()
558        self.not_implemented()
559    #end def init_from_args
560
561
562
563    @classmethod
564    def init_class(cls):
565        cls.class_set_optional(
566            tag         = cls.__name__,
567            identifier  = None,
568            attributes  = [],
569            elements    = [],
570            text        = None,
571            parameters  = [],
572            attribs     = [],
573            costs       = [],
574            h5tags      = [],
575            types       = obj(),
576            write_types = obj(),
577            attr_types  = None,
578            precision   = None,
579            defaults    = obj(),
580            collection_id = None,
581            exp_names   = None,
582            )
583        for v in ['attributes','elements','parameters','attribs','costs','h5tags']:
584            names = cls.class_get(v)
585            for i in range(len(names)):
586                if names[i] in cls.escape_names:
587                    names[i]+='_'
588                #end if
589            #end for
590        #end for
591        cls.params = cls.parameters + cls.attribs + cls.costs + cls.h5tags
592        cls.plurals_inv = obj()
593        for e in cls.elements:
594            if e in plurals_inv:
595                cls.plurals_inv[e] = plurals_inv[e]
596            #end if
597        #end for
598        cls.plurals = cls.plurals_inv.inverse()
599        if cls.exp_names is not None:
600            cls.expanded_names = obj(Names.expanded_names,cls.exp_names)
601        #end if
602    #end def init_class
603
604
    def write(self,indent_level=0,pad='   ',first=False):
        """Return the XML text of this element and, recursively, its children.

        Parameters:
          indent_level: nesting depth; the opening tag is indented by
                        indent_level copies of ``pad``.
          pad:          one level of indentation.
          first:        unused (only referenced by commented-out code below).

        Output order inside the element: <h5tag/>, <cost/>, <parameter/>,
        <attrib/> children, then child elements (in self.elements order, or
        afqmc_order in AFQMC mode; plural children are written from their
        collection), then text content.  An element holding only attributes
        is written self-closing.

        Unless QIobj.permissive_write is set, unknown content is checked via
        check_junk(exit=True) first.  The element's precision is installed on
        the global ``param`` writer at entry and reset at exit.
        NOTE(review): the reset is not in a finally block, so an exception
        escaping here leaves the precision set; also, each child write resets
        precision before this element's text is written — confirm intended.
        """
        param.set_precision(self.get_precision())
        if not QIobj.permissive_write:
            self.check_junk(exit=True)
        #end if
        indent  = indent_level*pad
        ip = indent+pad
        ipp= ip+pad
        expanded_tag = self.expand_name(self.tag)
        c = indent+'<'+expanded_tag
        # attributes: string values are also mapped to QMCPACK spelling
        for a in self.attributes:
            if a in self:
                val = self[a]
                if isinstance(val,str):
                    val = self.expand_name(val)
                #end if
                c += ' '+self.expand_name(a)+'='
                if a in self.write_types:
                    c += '"'+self.write_types[a](val)+'"'
                else:
                    c += '"'+param.write(val)+'"'
                #end if
            #end if
        #end for
        #if first:
        #    c+=' xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="http://www.mcc.uiuc.edu/qmc/schema/molecu.xsd"'
        ##end if
        #no_contents = len(set(self.keys())-set(self.elements)-set(self.plurals.keys()))==0
        # self-close when every stored key is an attribute
        no_contents = len(set(self.keys())-set(self.attributes))==0
        if no_contents:
            c += '/>\n'
        else:
            c += '>\n'
            for v in self.h5tags:
                if v in self:
                    if v in self.write_types:
                        write_type = self.write_types[v]
                    else:
                        write_type = None
                    #end if
                    c += param.write(self[v],name=self.expand_name(v),tag='h5tag',mode='elem',pad=ip,write_type=write_type)
                #end if
            #end for
            for v in self.costs:
                if v in self:
                    c += param.write(self[v],name=self.expand_name(v),tag='cost',mode='elem',pad=ip)
                #end if
            #end for
            for p in self.parameters:
                if p in self:
                    if p in self.write_types:
                        write_type = self.write_types[p]
                    else:
                        write_type = None
                    #end if
                    c += param.write(self[p],name=self.expand_name(p),mode='elem',pad=ip,write_type=write_type)
                #end if
            #end for
            for a in self.attribs:
                if a in self:
                    if a in self.write_types:
                        write_type = self.write_types[a]
                    else:
                        write_type = None
                    #end if
                    c += param.write(self[a],name=self.expand_name(a),tag='attrib',mode='elem',pad=ip,write_type=write_type)
                #end if
            #end for
            # AFQMC elements may declare their own child ordering
            elements = self.elements
            if self.afqmc_mode and 'afqmc_order' in self.__class__.__dict__:
                elements = self.afqmc_order
            #end if
            for e in elements:
                if e in self:
                    elem = self[e]
                    if isinstance(elem,QIxml):
                        c += elem.write(indent_level+1)
                    else:
                        # non-QIxml child: raw <e>value</e> (tag not expanded)
                        begin = '<'+e+'>'
                        contents = param.write(elem)
                        end = '</'+e+'>'
                        if contents.strip()=='':
                            c += ip+begin+end+'\n'
                        else:
                            c += ip+begin+'\n'
                            c += ipp+contents+'\n'
                            c += ip+end+'\n'
                        #end if
                    #end if
                elif e in plurals_inv and plurals_inv[e] in self:
                    # plural children are stored as a collection under the plural name
                    coll = self[plurals_inv[e]]
                    if not isinstance(coll,collection):
                        self.error('write failed\n  element {0} is not a collection\n  contents of element {0}:\n{1}'.format(plurals_inv[e],str(coll)))
                    #end if
                    for instance in coll.list():
                        c += instance.write(indent_level+1)
                    #end for
                #end if
            #end for
            # text content is appended directly after the last child line
            if self.text!=None:
                c = c.rstrip('\n')
                c+=param.write(self[self.text],mode='elem',pad=ip,tag=None,normal_elem=True)
            #end if
            c+=indent+'</'+expanded_tag+'>\n'
        #end if
        param.reset_precision()

        return c
    #end def write
714
715
716    def __init__(self,*args,**kwargs):
717        if Param.metadata==None:
718            Param.metadata = meta()
719        #end if
720        if len(args)==1:
721            a = args[0]
722            if isinstance(a,XMLelement):
723                self.init_from_xml(a)
724            elif isinstance(a,section):
725                self.init_from_inputs(a.args,a.kwargs)
726            elif isinstance(a,self.__class__):
727                self.transfer_from(a)
728            else:
729                self.init_from_inputs(args,kwargs)
730            #end if
731        else:
732            self.init_from_inputs(args,kwargs)
733        #end if
734        QIcollections.add(self)
735    #end def __init__
736
737
738    def init_from_xml(self,xml):
739        al,el = self.condense_names(xml._attributes.keys(),xml._elements.keys())
740        xa,sa = set(al.keys()) , set(self.attributes)
741        attr = xa & sa
742        junk = xa-attr
743        junk_elem = []
744        for e,ecap in el.items():
745            value = xml._elements[ecap]
746            if (isinstance(value,list) or isinstance(value,tuple)) and e in self.plurals_inv.keys():
747                if e not in types:
748                    self.error('input element "{}" is unknown'.format(e))
749                #end if
750                p = self.plurals_inv[e]
751                plist = []
752                for instance in value:
753                    plist.append(types[e](instance))
754                #end for
755                self[p] = make_collection(plist)
756            elif e in self.elements:
757                if e not in types:
758                    self.error('input element "{}" is unknown'.format(e))
759                #end if
760                self[e] = types[e](value)
761            elif e in ['parameter','attrib','cost','h5tag']:
762                if isinstance(value,XMLelement):
763                    value = [value]
764                #end if
765                for p in value:
766                    name = self.condense_name(p.name)
767                    if name in self.params:
768                        self[name] = param(p)
769                    else:
770                        junk_elem.append(name)
771                    #end if
772                #end for
773            else:
774                junk_elem.append(e)
775            #end if
776        #end for
777        junk = junk | set(junk_elem)
778        if QmcpackInput.profile_collection!=None:
779            self.collect_profile(xml,al,el,junk)
780        #end for
781        if not QIobj.permissive_read:
782            self.check_junk(junk)
783        #end if
784        if self.attr_types!=None:
785            typed_attr = attr & set(self.attr_types.keys())
786            attr -= typed_attr
787            for a in typed_attr:
788                self[a] = self.attr_types[a](xml._attributes[al[a]])
789            #end for
790        #end if
791        for a in attr:
792            if a in self.write_types and self.write_types[a] in bool_write_types:
793                aval = xml._attributes[al[a]]
794                if aval in boolmap:
795                    self[a] = boolmap[aval]
796                else:
797                    self.error('{0} is not a valid value for boolean attribute {1}\n  valid values are: {2}'.format(aval,a,boolmap.keys()))
798                #end if
799            else:
800                self[a] = attribute_to_value(xml._attributes[al[a]])
801            #end if
802        #end for
803        if self.text!=None:
804            self[self.text] = param(xml)
805        #end if
806    #end def init_from_xml
807
808
809    def init_from_inputs(self,args,kwargs):
810        if len(args)>0:
811            if len(args)==1 and isinstance(args[0],self.__class__):
812                self.transfer_from(args[0])
813            elif len(args)==1 and isinstance(args[0],dict):
814                self.init_from_kwargs(args[0])
815            else:
816                self.init_from_args(args)
817            #end if
818        #end if
819        self.init_from_kwargs(kwargs)
820    #end def init_from_inputs
821
822
823    def init_from_kwargs(self,kwargs):
824        ks=[]
825        kmap = dict()
826        for key,val in kwargs.items():
827            ckey = self.condense_name(key)
828            ks.append(ckey)
829            kmap[ckey] = val
830        #end for
831        ks = set(ks)
832        kwargs = kmap
833        h5tags     = ks & set(self.h5tags)
834        costs      = ks & set(self.costs)
835        parameters = ks & set(self.parameters)
836        attribs    = ks & set(self.attribs)
837        attr = ks & set(self.attributes)
838        elem = ks & set(self.elements)
839        plur = ks & set(self.plurals.keys())
840        if self.text!=None:
841            text = ks & set([self.text])
842        else:
843            text = set()
844        #end if
845        if not QIobj.permissive_init:
846            junk = ks -attr -elem -plur -h5tags -costs -parameters -attribs -text
847            self.check_junk(junk,exit=True)
848        #end if
849
850        for v in h5tags:
851            self[v] = param(kwargs[v])
852        #end for
853        for v in costs:
854            self[v] = param(kwargs[v])
855        #end for
856        for v in parameters:
857            self[v] = param(kwargs[v])
858        #end for
859        for v in attribs:
860            self[v] = param(kwargs[v])
861        #end for
862        for a in attr:
863            self[a] = kwargs[a]
864        #end for
865        for e in elem:
866            self[e] = types[e](kwargs[e])
867        #end for
868        for p in plur:
869            e = self.plurals[p]
870            kwcoll = kwargs[p]
871            if isinstance(kwcoll,collection):
872                plist = kwcoll.list()
873            elif isinstance(kwcoll,(list,tuple)):
874                plist = kwcoll
875            else:
876                self.error('init failed\n  encountered non-list collection')
877            #end if
878            ilist = []
879            for instance in plist:
880                ilist.append(types[e](instance))
881            #end for
882            self[p] = make_collection(ilist)
883        #end for
884        for t in text:
885            self[t] = kwargs[t]
886        #end for
887    #end def init_from_kwargs
888
889
890    def incorporate_defaults(self,elements=False,overwrite=False,propagate=True):
891        for name,value in self.defaults.items():
892            defval=None
893            if isinstance(value,classcollection):
894                if elements:
895                    coll=[]
896                    for cls in value.classes:
897                        ins = cls()
898                        ins.incorporate_defaults()
899                        coll.append(ins)
900                    #end for
901                    defval = make_collection(coll)
902                #end if
903            elif inspect.isclass(value):
904                if elements:
905                    defval = value()
906                #end if
907            elif inspect.isfunction(value):
908                defval = value()
909            else:
910                defval = value
911            #end if
912            if defval!=None:
913                if overwrite or not name in self:
914                    self[name] = defval
915                #end if
916            #end if
917        #end for
918        if propagate:
919            for name,value in self.items():
920                if isinstance(value,QIxml):
921                    value.incorporate_defaults(elements,overwrite)
922                elif isinstance(value,collection):
923                    for v in value:
924                        if isinstance(v,QIxml):
925                            v.incorporate_defaults(elements,overwrite)
926                        #end if
927                    #end for
928                #end if
929            #end for
930        #end if
931    #end def incorporate_defaults
932
933
934    def check_junk(self,junk=None,exit=False):
935        if junk is None:
936            ks = set(self.keys())
937            h5tags     = ks & set(self.h5tags)
938            costs      = ks & set(self.costs)
939            parameters = ks & set(self.parameters)
940            attribs    = ks & set(self.attribs)
941            attr = ks & set(self.attributes)
942            elem = ks & set(self.elements)
943            plur = ks & set(self.plurals.keys())
944            if self.text is not None:
945                text = ks & set([self.text])
946            else:
947                text = set()
948            #end if
949            junk = ks -attr -elem -plur -h5tags -costs -parameters -attribs -text
950        #end if
951        if len(junk)>0:
952            oname = ''
953            if self.tag!=self.__class__.__name__:
954                oname = ' ('+self.__class__.__name__+')'
955            #end if
956            msg = '{0}{1} does not have the following attributes/elements:\n'.format(self.tag,oname)
957            for jname in junk:
958                msg+='    '+jname+'\n'
959            #end for
960            #if QmcpackInput.profile_collection is None:
961            #    self.error(msg,'QmcpackInput',exit=exit,trace=exit)
962            ##end if
963
964            print(obj(dict(self.__class__.__dict__)))
965
966            self.error(msg,'QmcpackInput',exit=exit,trace=exit)
967        #end if
968    #end def check_junk
969
970
971    def collect_profile(self,xml,al,el,junk):
972        attributes = obj(**al)
973        parameters = obj()
974        elements   = obj()
975        for e,ecap in el.items():
976            if e=='parameter':
977                parameters[e]=ecap
978            else:
979                elements[e]=ecap
980            #end if
981        #end for
982        profile = obj(
983            attributes = attributes,
984            parameters = parameters,
985            elements   = elements,
986            xml        = xml,
987            junk       = junk
988            )
989        #xml = xml.copy()
990        xname = xml._name
991        if xname[-1:].isdigit():
992            xname = xname[:-1]
993        elif xname[-2:].isdigit():
994            xname = xname[:-2]
995        #end if
996        #xml._name = xname
997
998        if len(profile.junk)>0:
999            print('  '+xname+' (found '+str(junk)+')')
1000            for sector in 'attributes elements'.split():
1001                missing = []
1002                for n in profile.junk:
1003                    if n in profile[sector]:
1004                        missing.append(profile[sector][n])
1005                    #end if
1006                #end for
1007                if len(missing)>0:
1008                    ms= '    '+sector+':'
1009                    for m in missing:
1010                        ms+=' '+m
1011                    #end for
1012                    print(ms)
1013                #end if
1014            #end for
1015            if 'parameter' in profile.xml:
1016                params = obj()
1017                for p in profile.xml.parameter:
1018                    params[p.name.lower()] = p.text.strip()
1019                #end for
1020                missing = []
1021                for n in profile.junk:
1022                    if n in params:
1023                        missing.append(n)
1024                    #end if
1025                #end for
1026                if len(missing)>0:
1027                    ms= '    parameters:'
1028                    for m in missing:
1029                        ms+=' '+m
1030                    #end for
1031                    print(ms)
1032                #end if
1033            #end if
1034            if junk!=set(['analysis']) and junk!=set(['ratio']) and junk!=set(['randmo']) and junk!=set(['printeloc', 'source']) and junk!=set(['warmup_steps']) and junk!=set(['sposet_collection']) and junk!=set(['eigensolve', 'atom']) and junk!=set(['maxweight', 'reweightedvariance', 'unreweightedvariance', 'energy', 'exp0', 'stabilizerscale', 'minmethod', 'alloweddifference', 'stepsize', 'beta', 'minwalkers', 'nstabilizers', 'bigchange', 'usebuffer']) and junk!=set(['loop2']) and junk!=set(['random']) and junk!=set(['max_steps']):
1035                exit()
1036            #end if
1037        #end if
1038
1039        pc = QmcpackInput.profile_collection
1040        if not xname in pc:
1041            pc[xname] = obj()
1042        #end if
1043        pc[xname].append(profile)
1044    #end def collect_profile
1045
1046
1047
1048    def get_single(self,preference):
1049        return self
1050    #end def get_single
1051
1052
    def get(self,names,namedict=None,host=False,root=True):
        """
        Search this element tree for values/elements by name.

        Parameters
        ----------
        names : str or list of str
            Names to look up.  A name matches a key of this object, the
            plural collection stored for a singular name (via
            ``plurals_inv``), a key inside a child collection, or (only for
            the root object when ``host`` is False) this object's own tag or
            the value of its identifier attribute.
        namedict : dict or None
            Accumulator shared through the recursion; created if None.
        host : bool
            If True, record the object containing each match instead of the
            matched value itself.
        root : bool
            True for the top-level call; recursive calls pass False.

        Returns
        -------
        At the root: the match (or None) when one name is requested,
        otherwise a list of matches (None where absent) in the order of
        ``names``.  Non-root calls return None and only fill ``namedict``.

        An earlier match is kept unless it is not an element (neither a
        QIxml nor a collection), in which case a later match replaces it.
        The search visits this object's own keys before descending into
        children.
        """
        if namedict is None:
            namedict = {}
        #end if
        if isinstance(names,str):
            names = [names]
        #end if
        # the root object itself may match, by tag or by identifier value
        if root and not host:
            if self.identifier!=None and self.identifier in self:
                identity = self[self.identifier]
            else:
                identity = None
            #end if
            for name in names:
                if name==self.tag:
                    namedict[name]=self
                elif name==identity:
                    namedict[name]=self
                #end if
            #end for
        #end if
        # direct matches among this object's keys (or its plural collections)
        for name in names:
            loc = None
            if name in self:
                loc = name
            elif name in plurals_inv and plurals_inv[name] in self:
                loc = plurals_inv[name]
            #end if
            # an existing match is only overridden if it is not an element
            name_absent = not name in namedict
            not_element = False
            if not name_absent:
                not_xml  = not isinstance(namedict[name],QIxml)
                not_coll = not isinstance(namedict[name],collection)
                not_element = not_xml and not_coll
            #end if
            if loc!=None and (name_absent or not_element):
                if host:
                    namedict[name] = self
                else:
                    namedict[name] = self[loc]
                #end if
            #end if
        #end for
        # recurse into child elements and into the members of child collections
        for name,value in self.items():
            if isinstance(value,QIxml):
                value.get(names,namedict,host,root=False)
            elif isinstance(value,collection):
                for n,v in value.items():
                    # collection keys can match requested names directly
                    name_absent = not n in namedict
                    not_element = False
                    if not name_absent:
                        not_xml  = not isinstance(namedict[n],QIxml)
                        not_coll = not isinstance(namedict[n],collection)
                        not_element = not_xml and not_coll
                    #end if
                    if n in names and (name_absent or not_element):
                        if host:
                            namedict[n] = value
                        else:
                            namedict[n] = v
                        #end if
                    #end if
                    if isinstance(v,QIxml):
                        v.get(names,namedict,host,root=False)
                    #end if
                #end if
            #end if
        #end for
        # only the top-level call assembles and returns the results
        if root:
            namelist = []
            for name in names:
                if name in namedict:
                    namelist.append(namedict[name])
                else:
                    namelist.append(None)
                #end if
            #end for
            if len(namelist)==1:
                return namelist[0]
            else:
                return namelist
            #end if
        #end if
    #end def get
1137
1138    def remove(self,*names):
1139        if len(names)==1 and not isinstance(names[0],str):
1140            names = names[0]
1141        #end if
1142        remove = []
1143        for name in names:
1144            attempt = True
1145            if name in self:
1146                rname = name
1147            elif name in plurals_inv and plurals_inv[name] in self:
1148                rname = plurals_inv[name]
1149            else:
1150                attempt = False
1151            #end if
1152            if attempt:
1153                val = self[rname]
1154                if isinstance(val,QIxml) or isinstance(val,collection):
1155                    remove.append(rname)
1156                #end if
1157            #end if
1158        #end for
1159        for name in remove:
1160            del self[name]
1161        #end for
1162        for name,value in self.items():
1163            if isinstance(value,QIxml):
1164                value.remove(*names)
1165            elif isinstance(value,collection):
1166                for element in value:
1167                    if isinstance(element,QIxml):
1168                        element.remove(*names)
1169                    #end if
1170                #end for
1171            #end if
1172        #end for
1173    #end def remove
1174
1175
1176    def assign(self,**kwargs):
1177        for var,vnew in kwargs.items():
1178            if var in self:
1179                val = self[var]
1180                not_coll = not isinstance(val,collection)
1181                not_xml  = not isinstance(val,QIxml)
1182                not_arr  = not isinstance(val,ndarray)
1183                if not_coll and not_xml and not_arr:
1184                    self[var] = vnew
1185                #end if
1186            #end if
1187        #end for
1188        for vname,val in self.items():
1189            if isinstance(val,QIxml):
1190                val.assign(**kwargs)
1191            elif isinstance(val,collection):
1192                for v in val:
1193                    if isinstance(v,QIxml):
1194                        v.assign(**kwargs)
1195                    #end if
1196                #end for
1197            #end if
1198        #end for
1199    #end def assign
1200
1201
1202    def replace(self,*args,**kwargs):
1203        if len(args)==2 and isinstance(args[0],str) and isinstance(args[1],str):
1204            vold,vnew = args
1205            args = [(vold,vnew)]
1206        #end for
1207        for valpair in args:
1208            vold,vnew = valpair
1209            for var,val in self.items():
1210                not_coll = not isinstance(val,collection)
1211                not_xml  = not isinstance(val,QIxml)
1212                not_arr  = not isinstance(val,ndarray)
1213                if not_coll and not_xml and not_arr and val==vold:
1214                    self[var] = vnew
1215                #end if
1216            #end for
1217        #end for
1218        for var,valpair in kwargs.items():
1219            vold,vnew = valpair
1220            if var in self:
1221                val = self[var]
1222                if vold==None:
1223                    self[var] = vnew
1224                else:
1225                    not_coll = not isinstance(val,collection)
1226                    not_xml  = not isinstance(val,QIxml)
1227                    not_arr  = not isinstance(val,ndarray)
1228                    if not_coll and not_xml and not_arr and val==vold:
1229                        self[var] = vnew
1230                    #end if
1231                #end if
1232            #end if
1233        #end for
1234        for vname,val in self.items():
1235            if isinstance(val,QIxml):
1236                val.replace(*args,**kwargs)
1237            elif isinstance(val,collection):
1238                for v in val:
1239                    if isinstance(v,QIxml):
1240                        v.replace(*args,**kwargs)
1241                    #end if
1242                #end for
1243            #end if
1244        #end for
1245    #end def replace
1246
1247
1248    def combine(self,other):
1249        #elemental combine only
1250        for name,element in other.items():
1251            plural = isinstance(element,collection)
1252            single = isinstance(element,QIxml)
1253            if single or plural:
1254                elem = []
1255                single_name = None
1256                plural_name = None
1257                if single:
1258                    elem.append(element)
1259                    single_name = name
1260                    if name in plurals_inv:
1261                        plural_name = plurals_inv[name]
1262                    #end if
1263                else:
1264                    elem.extend(element.values())
1265                    plural_name = name
1266                    single_name = plurals[name]
1267                #end if
1268                if single_name in self:
1269                    elem.append(self[single_name])
1270                    del self[single_name]
1271                elif plural_name!=None and plural_name in self:
1272                    elem.append(self[plural_name])
1273                    del self[plural_name]
1274                #end if
1275                if len(elem)==1:
1276                    self[single_name]=elem[0]
1277                elif plural_name==None:
1278                    self.error('attempting to combine non-plural elements: '+single_name)
1279                else:
1280                    self[plural_name] = make_collection(elem)
1281                #end if
1282            #end if
1283        #end for
1284    #end def combine
1285
1286
1287    def move(self,**elemdests):
1288        names = list(elemdests.keys())
1289        hosts = self.get_host(names)
1290        dests = self.get(list(elemdests.values()))
1291        if len(names)==1:
1292            hosts = [hosts]
1293            dests = [dests]
1294        #end if
1295        for i in range(len(names)):
1296            name = names[i]
1297            host = hosts[i]
1298            dest = dests[i]
1299            if host!=None and dest!=None and id(host)!=id(dest):
1300                if not name in host:
1301                    name = plurals_inv[name]
1302                #end if
1303                dest[name] = host[name]
1304                del host[name]
1305            #end if
1306        #end for
1307    #end def move
1308
1309
1310
1311    def pluralize(self):
1312        make_plural = []
1313        for name,value in self.items():
1314            if isinstance(value,QIxml):
1315                if name in plurals_inv:
1316                    make_plural.append(name)
1317                #end if
1318                value.pluralize()
1319            elif isinstance(value,collection):
1320                if name in plurals_inv:
1321                    make_plural.append(name)
1322                #end if
1323                for v in value:
1324                    if isinstance(v,QIxml):
1325                        v.pluralize()
1326                    #end if
1327                #end for
1328            #end if
1329        #end for
1330        for name in make_plural:
1331            value = self[name]
1332            del self[name]
1333            plural_name = plurals_inv[name]
1334            self[plural_name] = make_collection([value])
1335        #end for
1336    #end def pluralize
1337
1338
1339    def difference(self,other,root=True):
1340        if root:
1341            q1 = self.copy()
1342            q2 = other.copy()
1343        else:
1344            q1 = self
1345            q2 = other
1346        #end if
1347        if q1.__class__!=q2.__class__:
1348            different = True
1349            diff = None
1350            d1 = q1
1351            d2 = q2
1352        else:
1353            cls = q1.__class__
1354            s1 = set(q1.keys())
1355            s2 = set(q2.keys())
1356            shared  = s1 & s2
1357            unique1 = s1 - s2
1358            unique2 = s2 - s1
1359            different = len(unique1)>0 or len(unique2)>0
1360            diff = cls()
1361            d1 = cls()
1362            d2 = cls()
1363            d1.transfer_from(q1,unique1)
1364            d2.transfer_from(q2,unique2)
1365            for k in shared:
1366                value1 = q1[k]
1367                value2 = q2[k]
1368                is_coll1 = isinstance(value1,collection)
1369                is_coll2 = isinstance(value2,collection)
1370                is_qxml1 = isinstance(value1,QIxml)
1371                is_qxml2 = isinstance(value2,QIxml)
1372                if is_coll1!=is_coll2 or is_qxml1!=is_qxml2:
1373                    self.error('values for '+k+' are of inconsistent types\n  difference could not be taken')
1374                #end if
1375                if is_qxml1 and is_qxml2:
1376                    kdifferent,kdiff,kd1,kd2 = value1.difference(value2,root=False)
1377                elif is_coll1 and is_coll2:
1378                    ks1 = set(value1.keys())
1379                    ks2 = set(value2.keys())
1380                    kshared  = ks1 & ks2
1381                    kunique1 = ks1 - ks2
1382                    kunique2 = ks2 - ks1
1383                    kdifferent = len(kunique1)>0 or len(kunique2)>0
1384                    kd1 = collection()
1385                    kd2 = collection()
1386                    kd1.transfer_from(value1,kunique1)
1387                    kd2.transfer_from(value2,kunique2)
1388                    kdiff = collection()
1389                    for kk in kshared:
1390                        v1 = value1[kk]
1391                        v2 = value2[kk]
1392                        if isinstance(v1,QIxml) and isinstance(v2,QIxml):
1393                            kkdifferent,kkdiff,kkd1,kkd2 = v1.difference(v2,root=False)
1394                            kdifferent = kdifferent or kkdifferent
1395                            if kkdiff!=None:
1396                                kdiff[kk]=kkdiff
1397                            #end if
1398                            if kkd1!=None:
1399                                kd1[kk]=kkd1
1400                            #end if
1401                            if kkd2!=None:
1402                                kd2[kk]=kkd2
1403                            #end if
1404                        #end if
1405                    #end for
1406                else:
1407                    if isinstance(value1,ndarray):
1408                        a1 = value1.ravel()
1409                    else:
1410                        a1 = array([value1])
1411                    #end if
1412                    if isinstance(value2,ndarray):
1413                        a2 = value2.ravel()
1414                    else:
1415                        a2 = array([value2])
1416                    #end if
1417                    if len(a1)!=len(a2):
1418                        kdifferent = True
1419                    elif len(a1)==0:
1420                        kdifferent = False
1421                    elif (isinstance(a1[0],float) or isinstance(a2[0],float)) and not  (isinstance(a1[0],str)   or isinstance(a2[0],str)):
1422                        kdifferent = abs(a1-a2).max()/max(1e-99,abs(a1).max(),abs(a2).max()) > 1e-6
1423                    else:
1424                        kdifferent = not (a1==a2).all()
1425                    #end if
1426                    if kdifferent:
1427                        kdiff = (value1,value2)
1428                        kd1   = value1
1429                        kd2   = value2
1430                    else:
1431                        kdiff = None
1432                        kd1   = None
1433                        kd2   = None
1434                    #end if
1435                #end if
1436                different = different or kdifferent
1437                if kdiff!=None:
1438                    diff[k] = kdiff
1439                #end if
1440                if kd1!=None:
1441                    d1[k] = kd1
1442                #end if
1443                if kd2!=None:
1444                    d2[k] = kd2
1445                #end if
1446            #end for
1447        #end if
1448        if root:
1449            if diff!=None:
1450                diff.remove_empty()
1451            #end if
1452            d1.remove_empty()
1453            d2.remove_empty()
1454        #end if
1455        return different,diff,d1,d2
1456    #end def difference
1457
1458    def remove_empty(self):
1459        names = list(self.keys())
1460        for name in names:
1461            value = self[name]
1462            if isinstance(value,QIxml):
1463                value.remove_empty()
1464                if len(value)==0:
1465                    del self[name]
1466                #end if
1467            elif isinstance(value,collection):
1468                ns = list(value.keys())
1469                for n in ns:
1470                    v = value[n]
1471                    if isinstance(v,QIxml):
1472                        v.remove_empty()
1473                        if len(v)==0:
1474                            del value[n]
1475                        #end if
1476                    #end if
1477                #end for
1478                if len(value)==0:
1479                    del self[name]
1480                #end if
1481            #end if
1482        #end for
1483    #end def remove_empty
1484
1485    def get_host(self,names):
1486        return self.get(names,host=True)
1487    #end def get_host
1488
1489    def get_precision(self):
1490        return self.__class__.class_get('precision')
1491    #end def get_precision
1492#end class QIxml
1493
1494
1495
class QIxmlFactory(Names):
    """
    Callable that builds one of several QIxml element classes.

    Calling the factory selects a type name and returns
    ``self.types[condensed_type_name](*args,**kwargs)``.  The type name is
    taken, in order of preference, from:
      - the class name of a first argument that is already an instance of
        one of the classes in ``types``;
      - the ``typekey`` or ``typekey2`` entry of the XML attributes (for an
        XMLelement argument), of the section keywords (for a section
        argument), or of the call keywords;
      - ``default``;
      - the positional argument at ``typeindex`` (unless it is -1).
    """
    def __init__(self,name,types,typekey='',typeindex=-1,typekey2='',default=None):
        self.name      = name       # element name, used in messages
        self.types     = types      # condensed type name -> constructor
        self.typekey   = typekey    # primary key holding the type name
        self.typekey2  = typekey2   # fallback key holding the type name
        self.typeindex = typeindex  # positional index of the type; -1 = none
        self.default   = default    # type used when no key/index applies
    #end def __init__

    def __call__(self,*args,**kwargs):
        """
        Construct the element subtype selected by the inputs.

        If no type can be identified, ``self.error`` is called.  If the type
        is not in ``self.types``, ``self.error(...,exit=False,trace=False)``
        is called and None is returned.
        """
        #emulate QIxml.__init__
        #get the value of the typekey
        a  = args
        kw = kwargs
        v  = None   # first positional argument, if any
        found_type = False
        if len(args)>0:
            v = args[0]
            if isinstance(v,XMLelement):
                kw = v._attributes
            elif isinstance(v,section):
                a  = v.args
                kw = v.kwargs
            elif isinstance(v,tuple(self.types.values())):
                found_type = True
                type_name = v.__class__.__name__
            #end if
        #end if
        if not found_type:
            if self.typekey in kw.keys():
                type_name = kw[self.typekey]
            elif self.typekey2 in kw.keys():
                type_name = kw[self.typekey2]
            elif self.default is not None:
                type_name = self.default
            elif self.typeindex==-1:
                # describe the offending input; without positional arguments
                # there is no element, so show the keywords instead
                contents = v if v is not None else kw
                self.error('QMCPACK input file is misformatted\ncannot identify type for <{0}/> element\nwith contents:\n{1}\nplease find the XML element matching this description in the input file to identify the problem\nmost likely, it is missing attributes "{2}" or "{3}"'.format(self.name,str(contents).rstrip(),self.typekey,self.typekey2))
            else:
                type_name = a[self.typeindex]
            #end if
        #end if
        type_name = self.condense_name(type_name)
        if type_name in self.types:
            return self.types[type_name](*args,**kwargs)
        else:
            msg = self.name+' factory is not aware of the following subtype:\n'
            msg+= '    '+type_name+'\n'
            self.error(msg,exit=False,trace=False)
        #end if
    #end def __call__

    def init_class(self):
        None # this is for compatibility with QIxml only (do not overwrite)
    #end def init_class
#end class QIxmlFactory
1551
1552
1553
1554class Param(Names):
1555    metadata = None
1556
1557    def __init__(self):
1558        self.reset_precision()
1559    #end def __init__
1560
1561    def reset_precision(self):
1562        self.precision   = None
1563        self.prec_format = None
1564    #end def reset_precision
1565
1566    def set_precision(self,precision):
1567        if precision is None:
1568            self.reset_precision()
1569        elif not isinstance(precision,str):
1570            self.error('attempted to set precision with non-string: {0}'.format(precision))
1571        else:
1572            self.precision   = precision
1573            self.prec_format = '{0:'+precision+'}'
1574        #end if
1575    #end def set_precision
1576
1577    def __call__(self,*args,**kwargs):
1578        if len(args)==0:
1579            self.error('no arguments provided, should have received one XMLelement')
1580        elif not isinstance(args[0],XMLelement):
1581            return args[0]
1582            #self.error('first argument is not an XMLelement')
1583        #end if
1584        return self.read(args[0])
1585    #end def __call__
1586
1587    def read(self,xml):
1588        val = ''
1589        attr = set(xml._attributes.keys())
1590        other_attr = attr-set(['name'])
1591        if 'name' in attr and len(other_attr)>0:
1592            oa = obj()
1593            for a in other_attr:
1594                oa[a] = xml._attributes[a]
1595            #end for
1596            self.metadata[xml.name] = oa
1597        #end if
1598        if 'text' in xml:
1599            token = xml.text.split('\n',1)[0].split(None,1)[0]
1600            try:
1601                if is_int(token):
1602                    val = loadtxt(StringIO(xml.text),int)
1603                elif is_float(token):
1604                    val = loadtxt(StringIO(xml.text),float)
1605                else:
1606                    val = array(xml.text.split())
1607                #end if
1608            except:
1609                if is_int(token):
1610                    val = array(xml.text.split(),dtype=int)
1611                elif is_float(token):
1612                    val = array(xml.text.split(),dtype=float)
1613                else:
1614                    val = array(xml.text.split())
1615                #end if
1616            #end try
1617            if val.size==1:
1618                val = val.ravel()[0]
1619            #end if
1620        #end if
1621        return val
1622    #end def read
1623
1624
1625    def write(self,value,mode='attr',tag='parameter',name=None,pad='   ',write_type=None,normal_elem=False):
1626        c = ''
1627        attr_mode = mode=='attr'
1628        elem_mode = mode=='elem'
1629        if not attr_mode and not elem_mode:
1630            self.error(mode+' is not a valid mode.  Options are attr,elem.')
1631        #end if
1632        if isinstance(value,list) or isinstance(value,tuple):
1633            value = array(value)
1634        #end if
1635        if attr_mode:
1636            if isinstance(value,ndarray):
1637                arr = value.ravel()
1638                for v in arr:
1639                    c+=self.write_val(v)+' '
1640                #end for
1641                c=c[:-1]
1642            else:
1643                c = self.write_val(value)
1644            #end if
1645        elif elem_mode:
1646            c+=pad
1647            is_array = isinstance(value,ndarray)
1648            is_single = not (is_array and value.size>1)
1649            if tag!=None:
1650                if is_single:
1651                    max_len = 20
1652                    rem_len = max(0,max_len-len(name))
1653                else:
1654                    rem_len = 0
1655                #end if
1656                other=''
1657                if name in self.metadata:
1658                    for a,v in self.metadata[name].items():
1659                        other +=' '+self.expand_name(a)+'="'+self.write_val(v)+'"'
1660                    #end for
1661                #end if
1662                c+='<'+tag+' name="'+name+'"'+other+rem_len*' '+'>'
1663                pp = pad+'   '
1664            else:
1665                pp = pad
1666            #end if
1667            if is_array:
1668                if normal_elem:
1669                    c+='\n'
1670                #end if
1671                if tag!=None:
1672                    c+='\n'
1673                #end if
1674                ndim = len(value.shape)
1675                if ndim==1:
1676                    line_len = 70
1677                    if tag!=None:
1678                        c+=pp
1679                    #end if
1680                    line = ''
1681                    for v in value:
1682                        line+=self.write_val(v)+' '
1683                        if len(line)>line_len:
1684                            c+=line+'\n'
1685                            line = ''
1686                        #end if
1687                    #end for
1688                    if len(line)>0:
1689                        c+=line
1690                    #end if
1691                    c=c[:-1]+'\n'
1692                elif ndim==2:
1693                    nrows,ncols = value.shape
1694                    fmt=pp
1695                    if value.dtype == dtype(float):
1696                        if self.precision is None:
1697                            vfmt = ':16.8f' # must have 8 digits of post decimal accuracy to meet qmcpack tolerance standards
1698                            #vfmt = ':16.8e'
1699                        else:
1700                            vfmt = ': '+self.precision
1701                        #end if
1702                    else:
1703                        vfmt = ''
1704                    #end if
1705                    for nc in range(ncols):
1706                        fmt+='{'+str(nc)+vfmt+'}  '
1707                    #end for
1708                    fmt = fmt[:-2]+'\n'
1709                    for nr in range(nrows):
1710                        c+=fmt.format(*value[nr])
1711                    #end for
1712                else:
1713                    self.error('only 1 and 2 dimensional arrays are supported for xml formatting.\n  Received '+ndim+' dimensional array.')
1714                #end if
1715            else:
1716                if write_type!=None:
1717                    val = write_type(value)
1718                else:
1719                    val = value
1720                #end if
1721                #c += '    '+str(val)
1722                c += '    {0:<10}'.format(self.write_val(val))
1723            #end if
1724            if tag!=None:
1725                c+=pad+'</'+tag+'>\n'
1726            #end if
1727        #end if
1728        return c
1729    #end def write
1730
1731
1732    def write_val(self,val):
1733        if self.precision!=None and isinstance(val,float):
1734            return self.prec_format.format(val)
1735        else:
1736            return str(val)
1737        #end if
1738    #end def write_val
1739
1740    def init_class(self):
1741        None
1742    #end def init_class
#end class Param
# Module-level Param instance shared by code that reads/writes parameter
# values (NOTE(review): its callers are outside this view).
param = Param()
1745
1746
1747
1748
# Root <simulation> element.  Children for both real-space QMC (rsqmc) and
# auxiliary-field QMC (afqmc) inputs are declared.  afqmc_order presumably
# fixes the child write order for AFQMC inputs (NOTE(review): confirm in QIxml).
class simulation(QIxml):
    #            afqmc
    attributes = ['method']
    #            rsqmc
    elements   = ['project','random','include','qmcsystem','particleset',
                  'wavefunction','hamiltonian','init','traces','qmc','loop',
                  'mcwalkerset','cmc']+\
                  ['afqmcinfo','walkerset','propagator','execute'] # afqmc
    afqmc_order = ['project','random','afqmcinfo','hamiltonian',
                   'wavefunction','walkerset','propagator','execute']
    write_types = obj(random=yesno)
#end class simulation
1761
1762
# Project metadata, RNG settings, file includes, walker configuration files
# and the <qmcsystem> container.  Each class declares the schema of the
# XML element of the same name: `attributes` are XML attributes, `elements`
# are allowed child elements, and `write_types` pairs keys with output
# converters such as yesno/truefalse (NOTE(review): the consumer, QIxml,
# is defined outside this view).
class project(QIxml):
    attributes = ['id','series']
    elements   = ['application','host','date','user']
#end class project

class application(QIxml):
    attributes = ['name','role','class','version']
#end class application

# host/date/user carry only element text, stored under the key 'value'
# (NOTE(review): meaning of `text` inferred from usage; confirm in QIxml).
class host(QIxml):
    text = 'value'
#end class host

class date(QIxml):
    text = 'value'
#end class date

class user(QIxml):
    text = 'value'
#end class user

class random(QIxml):
    attributes = ['seed','parallel']
    write_types= obj(parallel=truefalse)
#end class random

class include(QIxml):
    attributes = ['href']
#end class include

class mcwalkerset(QIxml):
    attributes = ['fileroot','version','collected','node','nprocs','href','target','file','walkers']
    write_types = obj(collected=yesno)
#end class mcwalkerset

class qmcsystem(QIxml):
    # 'wavefunction'/'hamiltonian' are deliberately not attributes; per the
    # trailing note this breaks QmcpackInput, presumably because both are
    # also child elements.
    attributes = ['dim'] #,'wavefunction','hamiltonian']  # breaks QmcpackInput
    elements = ['simulationcell','particleset','wavefunction','hamiltonian','random','init','mcwalkerset']
#end class qmcsystem
1802
1803
1804
# Simulation cell and particle sets.
# `parameters` are keywords written as <parameter name="..."/> children.
# `attribs` presumably lists per-particle <attrib/> arrays (e.g. positions);
# NOTE(review): confirm in QIxml.
# `identifier` names the key that distinguishes sibling instances.
class simulationcell(QIxml):
    attributes = ['name','tilematrix']
    parameters = ['lattice','reciprocal','bconds','lr_dim_cutoff','lr_tol','lr_handler','rs','nparticles','scale','uc_grid']
#end class simulationcell

class particleset(QIxml):
    attributes = ['name','size','random','random_source','randomsrc','charge','source']
    elements   = ['group','simulationcell']
    attribs    = ['ionid','position']
    write_types= obj(random=yesno)
    identifier = 'name'
#end class particleset

class group(QIxml):
    # 'mass' appears both as an attribute and as a parameter.
    attributes = ['name','size','mass'] # mass attr and param, bad bad bad!!!
    parameters = ['charge','valence','atomicnumber','mass','lmax',
                  'cutoff_radius','spline_radius','spline_npoints']
    attribs    = ['position']
    identifier = 'name'
#end class group
1825
1826
1827
# Single-particle orbital sets and the <sposet_builder> variants.  All
# builder classes share the tag 'sposet_builder'.  The factory below maps
# the value of the 'type' attribute (typekey) to a builder class.
class sposet(QIxml):
    attributes = ['basisset','type','name','group','size',
                  'index_min','index_max','energy_min','energy_max',
                  'spindataset','cuspinfo','sort','gpu','href','twistnum',
                  'gs_sposet','basis_sposet','same_k','frequency','mass',
                  'source','version','precision','tilematrix',
                  'meshfactor']
    elements   = ['occupation','coefficient','coefficients']
    text       = 'spos'
    identifier = 'name'
#end class sposet

class bspline_builder(QIxml):
    tag         = 'sposet_builder'
    identifier  = 'type'
    attributes  = ['type','href','sort','tilematrix','twistnum','twist','source',
                   'version','meshfactor','gpu','transform','precision','truncate',
                   'lr_dim_cutoff','shell','randomize','key','buffer','rmax_core','dilation','tag','hybridrep']
    elements    = ['sposet']
    write_types = obj(gpu=yesno,sort=onezero,transform=yesno,truncate=yesno,randomize=truefalse,hybridrep=yesno)
#end class bspline_builder

class heg_builder(QIxml):
    tag        = 'sposet_builder'
    identifier = 'type'
    attributes = ['type','twist']
    elements   = ['sposet']
#end class heg_builder

class molecular_orbital_builder(QIxml):
    tag = 'sposet_builder'
    identifier = 'type'
    attributes = ['name','type','transform','source','cuspcorrection']
    elements   = ['basisset','sposet']
#end class molecular_orbital_builder

class composite_builder(QIxml):
    tag = 'sposet_builder'
    identifier = 'type'
    attributes = ['type']
    elements   = ['sposet']
#end class composite_builder

# 'bspline' and 'einspline' both map to bspline_builder.
sposet_builder = QIxmlFactory(
    name    = 'sposet_builder',
    types   = dict(bspline=bspline_builder,
                   einspline=bspline_builder,
                   heg=heg_builder,
                   composite=composite_builder,
                   molecularorbital = molecular_orbital_builder),
    typekey = 'type'
    )
1880
1881
1882
# Trial wavefunction: <wavefunction>, <determinantset> and the basis-set
# description (basisset/atomicbasisset/basisgroup/radfunc).
class wavefunction(QIxml):
    #            rsqmc                        afqmc
    attributes = ['name','target','id','ref']+['info','type']
    #            afqmc
    parameters = ['filetype','filename','cutoff']
    elements   = ['sposet_builder','determinantset','jastrow']
    # Tuple: instances may be distinguished by 'name' or by 'id'.
    identifier = 'name','id'
#end class wavefunction

class determinantset(QIxml):
    attributes = ['type','href','sort','tilematrix','twistnum','twist','source','version','meshfactor','gpu','transform','precision','truncate','lr_dim_cutoff','shell','randomize','key','rmax_core','dilation','name','cuspcorrection','tiling','usegrid','meshspacing','shell2','src','buffer','bconds','keyword','hybridrep','pbcimages']
    elements   = ['basisset','sposet','slaterdeterminant','multideterminant','spline','backflow','cubicgrid']
    # NOTE(review): h5tags presumably names values taken from an HDF5 file
    # rather than from XML; confirm in QIxml.
    h5tags     = ['twistindex','twistangle','rcut']
    write_types = obj(gpu=yesno,sort=onezero,transform=yesno,truncate=yesno,randomize=truefalse,cuspcorrection=yesno,usegrid=yesno)
#end class determinantset

class spline(QIxml):
    attributes = ['method']
    elements   = ['grid']
#end class spline

class cubicgrid(QIxml):
    attributes = ['method']
    elements   = ['grid']
#end class cubicgrid

class basisset(QIxml):
    attributes = ['ecut','name','ref','type','source','transform','key']
    elements   = ['grid','atomicbasisset']
    write_types = obj(transform=yesno)
#end class basisset

class grid(QIxml):
    attributes = ['dir','npts','closed','type','ri','rf','rc','step']
    #identifier = 'dir'
#end class grid

class atomicbasisset(QIxml):
    attributes = ['type','elementtype','expandylm','href','normalized','name','angular']
    elements   = ['grid','basisgroup']
    identifier = 'elementtype'
    # expandylm is intentionally not converted with yesno (left commented out).
    write_types= obj(#expandylm=yesno,
                     normalized=yesno)
#end class atomicbasisset

class basisgroup(QIxml):
    attributes = ['rid','ds','n','l','m','zeta','type','s','imin','source']
    parameters = ['b']
    elements   = ['radfunc']
    #identifier = 'rid'
#end class basisgroup

class radfunc(QIxml):
    attributes = ['exponent','node','contraction','id','type']
    # Float format spec for this element's values (NOTE(review): presumably
    # applied like Param.precision; confirm in QIxml).
    precision  = '16.12e'
#end class radfunc
1939
# Determinant parts of the wavefunction: single Slater determinants, CI
# multideterminant expansions (detlist/ci/csf/det) and backflow transforms.
class slaterdeterminant(QIxml):
    attributes = ['optimize']
    elements   = ['determinant']
    write_types = obj(optimize=yesno)
#end class slaterdeterminant

class determinant(QIxml):
    attributes = ['id','group','sposet','size','ref','spin','href','orbitals','spindataset','name','cuspinfo','debug']
    elements   = ['occupation','coefficient']
    identifier = 'id'
    write_types = obj(debug=yesno)
#end class determinant

class occupation(QIxml):
    attributes = ['mode','spindataset','size','pairs','format']
    text       = 'contents'
#end class occupation

class multideterminant(QIxml):
    attributes = ['optimize','spo_up','spo_dn']
    elements   = ['detlist']
#end class multideterminant

class detlist(QIxml):
    attributes = ['size','type','nca','ncb','nea','neb','nstates','cutoff','href']
    elements   = ['ci','csf']
#end class detlist

class ci(QIxml):
    attributes = ['id','coeff','qc_coeff','alpha','beta']
    #identifier = 'id'
    # alpha/beta are forced to str, presumably so occupation strings are
    # not parsed as numbers (NOTE(review): confirm attr_types use in QIxml).
    attr_types = obj(alpha=str,beta=str)
    precision  = '16.12e'
#end class ci

class csf(QIxml):
    attributes = ['id','exctlvl','coeff','qchem_coeff','occ']
    elements   = ['det']
    attr_types = obj(occ=str)
#end class csf

class det(QIxml):
    attributes = ['id','coeff','alpha','beta']
    attr_types = obj(alpha=str,beta=str)
#end class det

class backflow(QIxml):
    attributes = ['optimize']
    elements   = ['transformation']
    write_types = obj(optimize=yesno)
#end class backflow

class transformation(QIxml):
    attributes = ['name','type','function','source']
    elements   = ['correlation']
    identifier = 'name'
#end class transformation
1997
# Jastrow factor variants.  All share the tag 'jastrow'; the factory at the
# end maps the value of the 'type' attribute to a class.  Supporting
# elements follow: correlation, var, coefficients/coefficient and
# distancetable.
class jastrow1(QIxml):
    tag = 'jastrow'
    attributes = ['type','name','function','source','print','spin','transform']
    elements   = ['correlation','distancetable','grid']
    identifier = 'name'
    write_types = obj(print=yesno,spin=yesno,transform=yesno)
#end class jastrow1

class jastrow2(QIxml):
    tag = 'jastrow'
    attributes = ['type','name','function','print','spin','init','kc','transform','source','optimize']
    elements   = ['correlation','distancetable','basisset','grid','basisgroup']
    parameters = ['b','longrange']
    identifier = 'name'
    write_types = obj(print=yesno,transform=yesno,optimize=yesno)
#end class jastrow2

class jastrow3(QIxml):
    tag = 'jastrow'
    attributes = ['type','name','function','print','source']
    elements   = ['correlation']
    identifier = 'name'
    write_types = obj(print=yesno)
#end class jastrow3

class kspace_jastrow(QIxml):
    tag = 'jastrow'
    attributes = ['type','name','source']
    elements   = ['correlation']
    identifier = 'name'
    # NOTE(review): 'optimize' is not among the attributes above, so this
    # write type may never apply; confirm whether the attribute is missing.
    write_types = obj(optimize=yesno)
#end class kspace_jastrow

class rpa_jastrow(QIxml):
    tag = 'jastrow'
    attributes = ['type','name','source','function','kc']
    parameters = ['longrange']
    identifier = 'name'
    write_types = obj(longrange=yesno)
#end class rpa_jastrow

class correlation(QIxml):
    attributes = ['elementtype','speciesa','speciesb','size','ispecies','especies',
                  'especies1','especies2','isize','esize','rcut','cusp','pairtype',
                  'kc','type','symmetry','cutoff','spindependent','dimension','init',
                  'species']
    parameters = ['a','b','c','d']
    elements   = ['coefficients','var','coefficient']
    # Any of these keys may identify a correlation block, depending on which
    # Jastrow order it belongs to.
    identifier = 'speciesa','speciesb','elementtype','especies1','especies2','ispecies'
    write_types = obj(init=yesno)
#end class correlation

class var(QIxml):
    attributes = ['id','name','optimize']
    text       = 'value'
    identifier = 'id'
    write_types=obj(optimize=yesno)
#end class var

class coefficients(QIxml):
    attributes = ['id','type','optimize','state','size','cusp','rcut']
    text       = 'coeff'
    write_types= obj(optimize=yesno)
    # NOTE(review): exp_names presumably restores the capitalized XML value
    # 'Array' on output; confirm in QIxml.
    exp_names  = obj(array='Array')
#end class coefficients

# The near-identical names 'coefficients' and 'coefficient' are both
# needed; each is a distinct element.
class coefficient(QIxml):  # this is bad!!! coefficients/coefficient
    attributes = ['id','type','size','dataset','spindataset']
    text       = 'coeff'
    precision  = '16.12e'
#end class coefficient

class distancetable(QIxml):
    attributes = ['source','target']
#end class distancetable

# Aliases: one_body/jastrow1, two_body/jastrow2, eei/jastrow3,
# kspace/kspace_jastrow, rpa/rpa_jastrow.
jastrow = QIxmlFactory(
    name = 'jastrow',
    types   = dict(one_body=jastrow1,two_body=jastrow2,jastrow1=jastrow1,jastrow2=jastrow2,eei=jastrow3,jastrow3=jastrow3,kspace=kspace_jastrow,kspace_jastrow=kspace_jastrow,rpa=rpa_jastrow,rpa_jastrow=rpa_jastrow),
    typekey = 'type'
    )
2079
2080
# Hamiltonian and its pair potentials.  Pair-potential classes share the
# tag 'pairpot'; the factory maps the 'type' attribute to a class.
# header/local are sub-elements of <pseudo>.
class hamiltonian(QIxml):
    #            rsqmc                              afqmc
    attributes = ['name','type','target','default']+['info']
    #            afqmc
    parameters = ['filetype','filename']
    elements   = ['pairpot','constant','estimator']
    identifier = 'name'
#end class hamiltonian

class coulomb(QIxml):
    tag = 'pairpot'
    # NOTE(review): 'forces' has no yesno write type here, unlike constant
    # and pseudopotential; confirm whether that difference is intended.
    attributes  = ['type','name','source','target','physical','forces']
    write_types = obj(physical=yesno)
    identifier  = 'name'
#end class coulomb

class constant(QIxml):
    attributes = ['type','name','source','target','forces']
    write_types= obj(forces=yesno)
#end class constant

class pseudopotential(QIxml):
    tag = 'pairpot'
    attributes = ['type','name','source','wavefunction','format','target','forces','dla']
    elements   = ['pseudo']
    write_types= obj(forces=yesno,dla=yesno)
    identifier = 'name'
#end class pseudopotential

class pseudo(QIxml):
    attributes = ['elementtype','href','format','cutoff','lmax','nrule','l_local']
    elements   = ['header','local','grid']
    identifier = 'elementtype'
#end class pseudo

class mpc(QIxml):
    tag='pairpot'
    attributes=['type','name','source','target','ecut','physical']
    write_types = obj(physical=yesno)
    identifier='name'
#end class mpc

class cpp(QIxml):
    tag = 'pairpot'
    attributes = ['type','name','source','target']
    elements   = ['element']
    identifier = 'name'
#end class cpp

class element(QIxml):
    attributes = ['name','alpha','rb']
#end class element

# 'pseudo' and 'pseudopotential' both map to the pseudopotential class.
pairpot = QIxmlFactory(
    name  = 'pairpot',
    types = dict(coulomb=coulomb,pseudo=pseudopotential,
                 pseudopotential=pseudopotential,mpc=mpc,
                 cpp=cpp),
    typekey = 'type'
    )


class header(QIxml):
    attributes = ['symbol','atomic-number','zval']
#end class header

class local(QIxml):
    elements = ['grid']
#end class local
2150
2151
# Estimators (tag 'estimator'), part 1: local energy, energy density with
# its spatial grid description, and several simple estimators.
class localenergy(QIxml):
    tag = 'estimator'
    attributes = ['name','hdf5']
    write_types= obj(hdf5=yesno)
    identifier = 'name'
#end class localenergy

class energydensity(QIxml):
    tag = 'estimator'
    attributes  = ['type','name','dynamic','static','ion_points']
    elements    = ['reference_points','spacegrid']
    identifier  = 'name'
    write_types = obj(ion_points=yesno)
#end class energydensity

class reference_points(QIxml):
    attributes = ['coord']
    text = 'points'
#end class reference_points

class spacegrid(QIxml):
    attributes = ['coord','min_part','max_part']
    elements   = ['origin','axis']
#end class spacegrid

class origin(QIxml):
    attributes = ['p1','p2']
#end class origin

class axis(QIxml):
    attributes = ['p1','p2','scale','label','grid']
    identifier = 'label'
#end class axis

class chiesa(QIxml):
    tag = 'estimator'
    attributes = ['name','type','source','psi','wavefunction','target']
    identifier = 'name'
#end class chiesa

class density(QIxml):
    tag = 'estimator'
    attributes = ['name','type','delta','x_min','x_max','y_min','y_max','z_min','z_max']
    identifier = 'type'
#end class density

class nearestneighbors(QIxml):
    tag = 'estimator'
    attributes = ['type']
    elements   = ['neighbor_trace']
    identifier = 'type'
#end class nearestneighbors

class neighbor_trace(QIxml):
    attributes = ['count','neighbors','centers']
    identifier = 'neighbors','centers'
#end class neighbor_trace
2209
# Estimators, part 2: one-body density matrix, spin density, structure
# factor and forces.
class dm1b(QIxml):
    tag         = 'estimator'
    identifier  = 'type'
    attributes  = ['type','name','reuse']#reuse is a temporary dummy keyword
    parameters  = ['energy_matrix','basis_size','integrator','points','scale','basis','evaluator','center','check_overlap','check_derivatives','acceptance_ratio','rstats','normalized','volume_normed']
    write_types = obj(energy_matrix=yesno,check_overlap=yesno,check_derivatives=yesno,acceptance_ratio=yesno,rstats=yesno,normalized=yesno,volume_normed=yesno)
#end class dm1b

class spindensity(QIxml):
    tag = 'estimator'
    attributes  = ['type','name','report']
    parameters  = ['dr','grid','cell','center','corner','voronoi','test_moves']
    write_types = obj(report=yesno)
    identifier  = 'name'
#end class spindensity

# Temporary variant of spindensity that adds the 'save_memory' attribute.
class spindensity_new(QIxml): # temporary
    tag = 'estimator'
    attributes  = ['type','name','report','save_memory']
    parameters  = ['dr','grid','cell','center','corner','voronoi','test_moves']
    write_types = obj(report=yesno,save_memory=yesno)
    identifier  = 'name'
#end class spindensity_new

class structurefactor(QIxml):
    tag = 'estimator'
    attributes  = ['type','name','report']
    write_types = obj(report=yesno)
    identifier  = 'name'
#end class structurefactor

class force(QIxml):
    tag = 'estimator'
    attributes = ['type','name','mode','source','species','target','addionion']
    parameters = ['rcut','nbasis','weightexp']
    identifier = 'name'
    write_types= obj(addionion=yesno)
#end class force
2248
# Estimators, part 3, followed by the estimator factory.  The factory keys
# on 'type' (typekey) and falls back to 'name' (typekey2) — NOTE(review):
# fallback semantics live in QIxmlFactory, outside this view.
class forwardwalking(QIxml):
    tag = 'estimator'
    attributes = ['type','blocksize']
    elements   = ['observable']
    # NOTE(review): 'name' is not among the attributes above; confirm
    # whether the identifier should be 'type' instead.
    identifier = 'name'
#end class forwardwalking

class pressure(QIxml):
    tag = 'estimator'
    attributes = ['type','potential','etype','function']
    parameters = ['kc']
    identifier = 'type'
#end class pressure

class dmccorrection(QIxml):
    tag = 'estimator'
    attributes = ['type','blocksize','max','frequency']
    elements   = ['observable']
    identifier = 'type'
#end class dmccorrection

class nofk(QIxml):
    tag = 'estimator'
    attributes = ['type','name','wavefunction']
    identifier = 'name'
#end class nofk

class mpc_est(QIxml):
    tag = 'estimator'
    attributes = ['type','name','physical']
    write_types = obj(physical=yesno)
    identifier = 'name'
#end class mpc_est

class sk(QIxml):
    tag = 'estimator'
    attributes = ['name','type','hdf5']
    identifier = 'name'
    write_types = obj(hdf5=yesno)
#end class sk

class skall(QIxml):
    tag = 'estimator'
    attributes = ['name','type','hdf5','source','target','writeionion']
    identifier = 'name'
    write_types = obj(hdf5=yesno,writeionion=yesno)
#end class skall

class gofr(QIxml):
    tag = 'estimator'
    attributes = ['type','name','num_bin','rmax','source']
    identifier = 'name'
#end class gofr

class flux(QIxml):
    tag = 'estimator'
    attributes = ['type','name']
    identifier = 'name'
#end class flux

class momentum(QIxml):
    tag = 'estimator'
    attributes = ['type','name','grid','samples','hdf5','wavefunction','kmax','kmax0','kmax1','kmax2']
    identifier = 'name'
    write_types = obj(hdf5=yesno)
#end class momentum

# afqmc estimators
class back_propagation(QIxml):
    tag = 'estimator'
    attributes = ['name']
    parameters = ['naverages','block_size','ortho','nsteps']
    elements   = ['onerdm']
    identifier = 'name'
#end class back_propagation

# The estimator type 'mpc' maps to mpc_est, keeping it separate from the
# mpc pair potential.
estimator = QIxmlFactory(
    name  = 'estimator',
    types = dict(localenergy         = localenergy,
                 energydensity       = energydensity,
                 chiesa              = chiesa,
                 density             = density,
                 nearestneighbors    = nearestneighbors,
                 dm1b                = dm1b,
                 spindensity         = spindensity,
                 spindensity_new     = spindensity_new, # temporary
                 structurefactor     = structurefactor,
                 force               = force,
                 forwardwalking      = forwardwalking,
                 pressure            = pressure,
                 dmccorrection       = dmccorrection,
                 nofk                = nofk,
                 mpc                 = mpc_est,
                 sk                  = sk,
                 skall               = skall,
                 gofr                = gofr,
                 flux                = flux,
                 momentum            = momentum,
                 # afqmc estimators
                 back_propagation    = back_propagation,
                 ),
    typekey  = 'type',
    typekey2 = 'name'
    )
2353
2354
# Observables, <init>, trace-output controls and <record>.
class observable(QIxml):
    attributes = ['name','max','frequency']
    identifier = 'name'
#end class observable



class init(QIxml):
    attributes = ['source','target']
#end class init


class scalar_traces(QIxml):
    attributes  = ['defaults']
    text        = 'quantities'
    write_types = obj(defaults=yesno)
#end class scalar_traces

class array_traces(QIxml):
    attributes  = ['defaults']
    text        = 'quantities'
    write_types = obj(defaults=yesno)
#end class array_traces

class particle_traces(QIxml): # legacy
    attributes  = ['defaults']
    text        = 'quantities'
    write_types = obj(defaults=yesno)
#end class particle_traces

class traces(QIxml):
    attributes = ['write','throttle','format','verbose','scalar','array',
                  'scalar_defaults','array_defaults',
                  'particle','particle_defaults']
    elements = ['scalar_traces','array_traces','particle_traces']
    # NOTE(review): the key 'write_' (trailing underscore) presumably stands
    # for the 'write' attribute, renamed to avoid clashing with a write()
    # method; confirm in QIxml.
    write_types = obj(write_=yesno,verbose=yesno,scalar=yesno,array=yesno,
                      scalar_defaults=yesno,array_defaults=yesno,
                      particle=yesno,particle_defaults=yesno)
#end class traces


class record(QIxml):
    attributes = ['name','stride']
#end class record
2399
2400
class loop(QIxml):
    collection_id = 'qmc'
    attributes = ['max']
    elements = ['qmc','init']
    def unroll(self):
        """Flatten the loop into a collection of repeated calculations.

        The loop body is either the single 'qmc' child or the
        'calculations' entry; with neither present the body is empty.
        The body is repeated self.max times.  Every repetition holds
        fresh .copy()s of each calculation, in order.
        """
        if 'qmc' in self:
            body = [self.qmc]
        elif 'calculations' in self:
            body = self.calculations
        else:
            body = []
        #end if
        count = len(body)
        repeated = [body[idx].copy()
                    for _ in range(self.max)
                    for idx in range(count)]
        return make_collection(repeated)
    #end def unroll
#end class loop
2421
2422
# Optimizer sub-elements of optimization <qmc> blocks.  The optimizer
# classes share the tag 'optimizer'; the factory keys on the 'method'
# attribute.
class optimize(QIxml):
    text = 'parameters'
#end class optimize

class cg_optimizer(QIxml):
    tag        = 'optimizer'
    attributes = ['method']
    parameters = ['max_steps','tolerance','stepsize','friction','epsilon',
                  'xybisect','verbose','max_linemin','tolerance_g','length_cg',
                  'rich','xypolish','gfactor']
#end class cg_optimizer

class flex_optimizer(QIxml):
    tag        = 'optimizer'
    attributes = ['method']
    parameters = ['max_steps','tolerance','stepsize','epsilon',
                  'xybisect','verbose','max_linemin','tolerance_g','length_cg',
                  'rich','xypolish','gfactor']
#end class flex_optimizer



optimizer = QIxmlFactory(
    name    = 'optimizer',
    types   = dict(cg=cg_optimizer,flexopt=flex_optimizer),
    typekey = 'method',
    )
2450
2451
2452
2453class optimize_qmc(QIxml):
2454    collection_id = 'qmc'
2455    tag = 'qmc'
2456    attributes = ['method','move','renew','completed','checkpoint','gpu']
2457    parameters = ['blocks','steps','timestep','walkers','minwalkers','useweight',
2458                  'power','correlation','maxweight','usedrift','min_walkers',
2459                  'minke','samples','warmupsteps','minweight','warmupblocks',
2460                  'maxdispl','tau','tolerance','stepsize','epsilon',
2461                  'en_ref','usebuffer','substeps','stepsbetweensamples',
2462                  'samplesperthread','max_steps','nonlocalpp']
2463    elements = ['optimize','optimizer','estimator']
2464    costs    = ['energy','variance','difference','weight','unreweightedvariance','reweightedvariance']
2465    write_types = obj(renew=yesno,completed=yesno)
2466#end class optimize_qmc
2467
2468class linear(QIxml):
2469    collection_id = 'qmc'
2470    tag = 'qmc'
2471    attributes = ['method','move','checkpoint','gpu','trace']
2472    elements   = ['estimator']
2473    parameters = ['walkers','warmupsteps','blocks','steps','substeps','timestep',
2474                  'usedrift','stepsbetweensamples','samples','minmethod',
2475                  'minwalkers','maxweight','nonlocalpp','use_nonlocalpp_deriv',
2476                  'usebuffer','alloweddifference','gevmethod','beta','exp0',
2477                  'bigchange','stepsize','stabilizerscale','nstabilizers',
2478                  'max_its','cgsteps','eigcg','stabilizermethod',
2479                  'rnwarmupsteps','walkersperthread','minke','gradtol','alpha',
2480                  'tries','min_walkers','samplesperthread',
2481                  'shift_i','shift_s','max_relative_change','max_param_change',
2482                  'chase_lowest','chase_closest','block_lm','nblocks','nolds',
2483                  'nkept',
2484                  ]
2485    costs      = ['energy','unreweightedvariance','reweightedvariance','variance','difference']
2486    write_types = obj(gpu=yesno,usedrift=yesno,nonlocalpp=yesno,usebuffer=yesno,use_nonlocalpp_deriv=yesno,chase_lowest=yesno,chase_closest=yesno,block_lm=yesno)
2487#end class linear
2488
2489class cslinear(QIxml):
2490    collection_id = 'qmc'
2491    tag = 'qmc'
2492    attributes = ['method','move','checkpoint','gpu','trace']
2493    elements   = ['estimator']
2494    parameters = ['walkers','warmupsteps','blocks','steps','substeps','timestep',
2495                  'usedrift','stepsbetweensamples','samples','minmethod',
2496                  'minwalkers','maxweight','nonlocalpp','usebuffer',
2497                  'alloweddifference','gevmethod','beta','exp0','bigchange',
2498                  'stepsize','stabilizerscale','nstabilizers','max_its',
2499                  'stabilizermethod','cswarmupsteps','alpha_error','gevsplit',
2500                  'beta_error','use_nonlocalpp_deriv']
2501    costs      = ['energy','unreweightedvariance','reweightedvariance']
2502    write_types = obj(gpu=yesno,usedrift=yesno,nonlocalpp=yesno,use_nonlocalpp_deriv=yesno,usebuffer=yesno)
2503#end class cslinear
2504
class vmc(QIxml):
    # Handler for <qmc method="vmc"> sections; the qmc factory below
    # dispatches to it on the 'method' attribute.
    collection_id = 'qmc'
    tag           = 'qmc'
    attributes    = [
        'method','multiple','warp','move','gpu','checkpoint','trace',
        'target','completed','id',
        ]
    elements      = ['estimator','record']
    parameters    = [
        'walkers','warmupsteps','blocks','steps','substeps','timestep',
        'usedrift','stepsbetweensamples','samples','samplesperthread',
        'nonlocalpp','tau','walkersperthread','reconfiguration',
        'dmcwalkersperthread','current','ratio','firststep',
        'minimumtargetwalkers',
        ]
    # keywords whose values are formatted through yesno when written
    write_types   = obj(
        gpu             = yesno,
        usedrift        = yesno,
        nonlocalpp      = yesno,
        reconfiguration = yesno,
        ratio           = yesno,
        completed       = yesno,
        )
#end class vmc
2513
class dmc(QIxml):
    # Handler for <qmc method="dmc"> sections; the qmc factory below
    # dispatches to it on the 'method' attribute.
    collection_id = 'qmc'
    tag           = 'qmc'
    attributes    = [
        'method','move','gpu','multiple','warp','checkpoint','trace',
        'target','completed','id','continue',
        ]
    elements      = ['estimator']
    parameters    = [
        'walkers','warmupsteps','blocks','steps','timestep',
        'nonlocalmove','nonlocalmoves','pop_control','reconfiguration',
        'targetwalkers','minimumtargetwalkers','sigmabound','energybound',
        'feedback','recordwalkers','fastgrad','popcontrol',
        'branchinterval','usedrift','storeconfigs','en_ref','tau',
        'alpha','gamma','stepsbetweensamples','max_branch','killnode',
        'swap_walkers','swap_trigger','branching_cutoff_scheme',
        'l2_diffusion',
        ]
    # keywords formatted through yesno/yesnostr when written
    write_types   = obj(
        gpu             = yesno,
        nonlocalmoves   = yesnostr,
        reconfiguration = yesno,
        fastgrad        = yesno,
        completed       = yesno,
        killnode        = yesno,
        swap_walkers    = yesno,
        l2_diffusion    = yesno,
        )
#end class dmc
2522
class rmc(QIxml):
    # Handler for <qmc method="rmc"> sections; the qmc factory below
    # dispatches to it on the 'method' attribute.
    collection_id = 'qmc'
    tag = 'qmc'
    # 'target' was previously listed twice.  The duplicate was redundant at
    # best, and could yield a repeated xml attribute if the writer walks
    # this list, so each attribute name now appears exactly once.
    attributes = ['method','multiple','target','observables','warp']
    parameters = ['blocks','steps','chains','cuts','bounce','clone','walkers','timestep','trunclength','maxtouch','mass','collect']
    elements = ['qmcsystem']
    # keywords whose values are formatted through yesno when written
    write_types = obj(collect=yesno)
#end class rmc
2531
class vmc_batch(QIxml):
    # Handler for <qmc method="vmc_batch"> sections; the qmc factory below
    # dispatches to it on the 'method' attribute.
    collection_id = 'qmc'
    tag           = 'qmc'
    attributes    = [
        'method',
        'move',
        ]
    elements      = ['estimator']
    parameters    = [
        'walkers','warmupsteps','blocks','steps',
        'substeps','timestep','usedrift',
        ]
    # keywords whose values are formatted through yesno when written
    write_types   = obj(
        usedrift = yesno,
        )
#end class vmc_batch
2540
class wftest(QIxml):
    # Handler for <qmc method="wftest"> sections; the qmc factory below
    # dispatches to it on the 'method' attribute.
    # NOTE: 'printeloc' and 'source' are handled as parameters, not elements.
    collection_id = 'qmc'
    tag           = 'qmc'
    attributes    = ['method','checkpoint','gpu','move','multiple','warp']
    parameters    = [
        'ratio','walkers','clone','source','hamiltonianpbyp',
        'orbitalutility','printeloc','basic','virtual_move',
        ]
    # keywords whose values are formatted through yesno when written
    write_types   = obj(
        ratio           = yesno,
        clone           = yesno,
        hamiltonianpbyp = yesno,
        orbitalutility  = yesno,
        printeloc       = yesno,
        basic           = yesno,
        virtual_move    = yesno,
        )
#end class wftest
2549
class setparams(QIxml):
    # Handler for <qmc method="setparams"> sections; the qmc factory below
    # dispatches to it on the 'method' attribute.
    # NOTE(review): unlike vmc/vmc_batch/cslinear, no write_types are
    # declared, so 'gpu' and 'usedrift' are not routed through yesno when
    # written -- confirm this is intended.
    collection_id = 'qmc'
    tag           = 'qmc'
    attributes    = ['method','move','checkpoint','gpu']
    parameters    = [
        'alpha','blocks','warmupsteps','stepsbetweensamples',
        'timestep','samples','usedrift',
        ]
    elements      = ['estimator']
#end class setparams
2557
# Factory for <qmc/> elements: the value of the 'method' key selects the
# handler class from 'types'; 'loop' is the configured default
# (presumably used when no method is given -- see QIxmlFactory).
qmc = QIxmlFactory(
    name    = 'qmc',
    types   = dict(
        linear    = linear,
        cslinear  = cslinear,
        vmc       = vmc,
        dmc       = dmc,
        loop      = loop,
        optimize  = optimize_qmc,
        wftest    = wftest,
        rmc       = rmc,
        setparams = setparams,
        vmc_batch = vmc_batch,
        ),
    typekey = 'method',
    default = 'loop',
    )
2564
2565
2566
class cmc(QIxml):
    # <cmc/> element: declares two attributes and no parameters/elements.
    attributes = [
        'method',
        'target',
        ]
#end class cmc
2570
2571
2572
2573# afqmc elements
2574
class afqmcinfo(QIxml):
    # AFQMC element (written as AFQMCInfo in afqmc mode).  The parameter
    # names are expanded to NMO/NAEA/NAEB by the afqmc name table below.
    attributes = [
        'name',
        ]
    parameters = [
        'nmo',
        'naea',
        'naeb',
        ]
#end class afqmcinfo
2579
class walkerset(QIxml):
    # AFQMC element (written as WalkerSet in afqmc mode).
    attributes = [
        'name',
        'type',
        ]
    parameters = [
        'walker_type',
        ]
#end class walkerset
2584
class propagator(QIxml):
    # AFQMC element (written as Propagator in afqmc mode).
    attributes  = [
        'name',
        'info',
        ]
    parameters  = [
        'hybrid',
        ]
    # 'hybrid' is formatted through yesno when written
    write_types = obj(
        hybrid = yesno,
        )
#end class propagator
2590
class execute(QIxml):
    # AFQMC execute element.  Its attributes name the info/ham/wfn/wset/prop
    # sections to use (defaults are set at module level below).
    attributes = [
        'info','ham','wfn','wset','prop',
        ]
    parameters = [
        'ncores','nwalkers','blocks','steps','timestep',
        ]
    elements   = [
        'estimator',
        ]
#end class execute
2596
class onerdm(QIxml):
    # AFQMC element (written as OneRDM in afqmc mode).  It declares no
    # attributes, parameters or elements of its own.
    pass
#end class onerdm
2600
2601
2602
2603
2604
class gen(QIxml):
    # Bare element with empty attribute and element lists.
    # Note: it is not part of the module-level 'classes' registration list.
    attributes = list()
    elements   = list()
#end class gen
2609
2610
# Element classes handled by this module.  The registration loop after the
# name tables below calls init_class() on each entry and adds it to 'types'
# under its class name.  (gen, defined above, is not listed here.)
classes = [   #standard classes
    simulation,project,application,random,qmcsystem,simulationcell,particleset,
    group,hamiltonian,constant,pseudopotential,coulomb,pseudo,mpc,chiesa,density,
    localenergy,energydensity,spacegrid,origin,axis,wavefunction,
    determinantset,slaterdeterminant,basisset,grid,determinant,occupation,
    jastrow1,jastrow2,jastrow3,
    correlation,coefficients,loop,linear,cslinear,vmc,dmc,vmc_batch,
    atomicbasisset,basisgroup,init,var,traces,scalar_traces,particle_traces,array_traces,
    reference_points,nearestneighbors,neighbor_trace,dm1b,
    coefficient,radfunc,spindensity,structurefactor,
    spindensity_new, # temporary
    sposet,bspline_builder,composite_builder,heg_builder,include,
    multideterminant,detlist,ci,mcwalkerset,csf,det,
    optimize,cg_optimizer,flex_optimizer,optimize_qmc,wftest,kspace_jastrow,
    header,local,force,forwardwalking,observable,record,rmc,pressure,dmccorrection,
    nofk,mpc_est,flux,distancetable,cpp,element,spline,setparams,
    backflow,transformation,cubicgrid,molecular_orbital_builder,cmc,sk,skall,gofr,
    host,date,user,rpa_jastrow,momentum,
    # afqmc classes
    afqmcinfo,walkerset,propagator,execute,back_propagation,onerdm
    ]
# Element name -> handler lookup.  QmcpackInput.read and include_xml use it
# to convert parsed xml (types[name](value)).  It starts with the factories
# (e.g. qmc, which picks a concrete class from its 'method' key) and is
# extended with every entry of 'classes' by the registration loop below.
types = dict( #simple types and factories
    #host           = param,
    #date           = param,
    #user           = param,
    pairpot        = pairpot,
    estimator      = estimator,
    sposet_builder = sposet_builder,
    jastrow        = jastrow,
    qmc            = qmc,
    optimizer      = optimizer,
    )
# Plural collection name -> singular element name.  For example, the
# simulation's 'calculations' entry holds 'qmc' elements (see
# QmcpackInput.unroll_calculations).
plurals = obj(
    particlesets    = 'particleset',
    groups          = 'group',
    hamiltonians    = 'hamiltonian',
    pairpots        = 'pairpot',
    pseudos         = 'pseudo',
    estimators      = 'estimator',
    spacegrids      = 'spacegrid',
    axes            = 'axis',
    wavefunctions   = 'wavefunction',
    grids           = 'grid',
    determinants    = 'determinant',
    correlations    = 'correlation',
    jastrows        = 'jastrow',
    basisgroups     = 'basisgroup',
    calculations    = 'qmc',
    vars            = 'var',
    neighbor_traces = 'neighbor_trace',
    sposet_builders = 'sposet_builder',
    sposets         = 'sposet',
    radfuncs        = 'radfunc',
    #qmcsystems      = 'qmcsystem',  # not a good idea
    atomicbasissets = 'atomicbasisset',
    cis             = 'ci',
    csfs            = 'csf',
    dets            = 'det',
    observables     = 'observable',
    optimizes       = 'optimize',
    #coefficientss   = 'coefficients', # bad plurality of qmcpack
    constants       = 'constant',
    mcwalkersets    = 'mcwalkerset',
    transformations = 'transformation',
    )
# Singular -> plural lookup (used e.g. by QmcpackInput.include_xml to find
# or create the plural collection for an element tag).
plurals_inv = plurals.inverse()
# Convenience sets of the plural and singular names.
plural_names = set(plurals.keys())
single_names = set(plurals.values())
# Lowercase keyword -> exact spelling to write into the QMCPACK xml.
# Presumably this is the table activated by Names.use_rsqmc_expanded_names()
# (see set_rsqmc_mode) -- confirm in the Names class.
Names.set_expanded_names(
    elementtype      = 'elementType',
    energydensity    = 'EnergyDensity',
    gevmethod        = 'GEVMethod',
    localenergy      = 'LocalEnergy',
    lr_dim_cutoff    = 'LR_dim_cutoff',
    lr_tol           = 'LR_tol',
    lr_handler       = 'LR_handler',
    minmethod        = 'MinMethod',
    one_body         = 'One-Body',
    speciesa         = 'speciesA',
    speciesb         = 'speciesB',
    substeps         = 'subSteps',
    two_body         = 'Two-Body',
    usedrift         = 'useDrift',
    maxweight        = 'maxWeight',
    warmupsteps      = 'warmupSteps',
    twistindex       = 'twistIndex',
    twistangle       = 'twistAngle',
    usebuffer        = 'useBuffer',
    mpc              = 'MPC',
    kecorr           = 'KEcorr',
    ionion           = 'IonIon',
    elecelec         = 'ElecElec',
    pseudopot        = 'PseudoPot',
    posarray         = 'posArray',
    #array            = 'Array',  # handle separately, namespace collision
    atomicbasisset   = 'atomicBasisSet',
    basisgroup       = 'basisGroup',
    expandylm        = 'expandYlm',
    mo               = 'MO',
    numerical        = 'Numerical',
    nearestneighbors = 'NearestNeighbors',
    cuspcorrection   = 'cuspCorrection',
    cuspinfo         = 'cuspInfo',
    exctlvl          = 'exctLvl',
    pairtype         = 'pairType',
    printeloc        = 'printEloc',
    spindependent    = 'spinDependent',
    l_local          = 'l-local',
    pbcimages        = 'PBCimages',
    dla              = 'DLA',
    l2_diffusion     = 'L2_diffusion',
    )
# afqmc names
# Same kind of mapping for AFQMC inputs; presumably activated by
# Names.use_afqmc_expanded_names() (see set_afqmc_mode).
Names.set_afqmc_expanded_names(
    afqmcinfo        = 'AFQMCInfo',
    nmo              = 'NMO',
    naea             = 'NAEA',
    naeb             = 'NAEB',
    hamiltonian      = 'Hamiltonian',
    wavefunction     = 'Wavefunction',
    walkerset        = 'WalkerSet',
    propagator       = 'Propagator',
    onerdm           = 'OneRDM',
    nwalkers         = 'nWalkers',
    estimator        = 'Estimator',
    )
# Run each element class's one-time init_class() setup and register it in
# 'types' under its class name, so types[<name>] resolves it when reading xml.
# Note: this loop leaves the module-level name 'c' bound to the last class.
for c in classes:
    c.init_class()
    types[c.__name__] = c
#end for
2741
2742
#set default values
# Each <class>.defaults.set(...) call registers default keyword values for
# that element class.  Values given as classes (e.g. project=project) or
# callables (e.g. lambda:list()) are presumably instantiated/called when the
# defaults are applied, so every element gets its own fresh object
# -- TODO confirm against incorporate_defaults in QIxml.
simulation.defaults.set(
    project      = project,
    qmcsystem    = qmcsystem,
    calculations = lambda:list()
    )
project.defaults.set(
    series=0,
    application = application
    )
# class_ is spelled with a trailing underscore because 'class' is a Python
# keyword.
# NOTE(review): role='molecu' looks truncated, but it appears to mirror the
# value used in QMCPACK's own sample inputs; confirm before "fixing" it.
application.defaults.set(
    name='qmcpack',role='molecu',class_='serial',version='1.0'
    )
#simulationcell.defaults.set(
#    bconds = 'p p p',lr_dim_cutoff=15
#    )
wavefunction.defaults.set(
    name='psi0'
    )
#determinantset.defaults.set(
#    type='einspline',tilematrix=lambda:eye(3,dtype=int),meshfactor=1.,gpu=False,precision='double'
#    )
#occupation.defaults.set(
#    mode='ground',spindataset=0
#    )
# Jastrow factor defaults: J1/J2/J3 names, bspline (J1, J2) or polynomial
# (J3) functional forms, each with a default correlation sub-element.
jastrow1.defaults.set(
    name='J1',type='one-body',function='bspline',print=True,source='ion0',
    correlation=correlation
    )
jastrow2.defaults.set(
    name='J2',type='two-body',function='bspline',print=True,
    correlation=correlation
    )
jastrow3.defaults.set(
    name='J3',type='eeI',function='polynomial',print=True,source='ion0',
    correlation=correlation
    )
correlation.defaults.set(
    coefficients=coefficients
    )
coefficients.defaults.set(
    type='Array'
    )
#hamiltonian.defaults.set(
#    name='h0',type='generic',target='e',
#    constant = constant,
#    pairpots = classcollection(coulomb,pseudopotential,mpc),
#    estimators = classcollection(chiesa),
#    )
#coulomb.defaults.set(
#    name='ElecElec',type='coulomb',source='e',target='e'
#    )
#constant.defaults.set(
#    name='IonIon',type='coulomb',source='ion0',target='ion0'
#    )
#pseudopotential.defaults.set(
#    name='PseudoPot',type='pseudo',source='ion0',wavefunction='psi0',format='xml'
#    )
#mpc.defaults.set(
#    name='MPC',type='MPC',ecut=60.0,source='e',target='e',physical=False
#    )
# Estimator defaults: mostly the estimator's type and name.
localenergy.defaults.set(
    name='LocalEnergy',hdf5=True
    )
#chiesa.defaults.set(
#    name='KEcorr',type='chiesa',source='e',psi='psi0'
#    )
#energydensity.defaults.set(
#    type='EnergyDensity',name='EDvoronoi',dynamic='e',static='ion0',
#    spacegrid = spacegrid
#    )
#spacegrid.defaults.set(
#    coord='voronoi'
#    )
dm1b.defaults.set(
    type = 'dm1b',name='DensityMatrices'
    )
density.defaults.set(
    type='density',name='Density'
    )
spindensity.defaults.set(
    type='spindensity',name='SpinDensity'
    )
skall.defaults.set(
    type='skall',name='skall',source='ion0',target='e',hdf5=True
    )
force.defaults.set(
    type='Force',name='force'
    )
pressure.defaults.set(
    type='Pressure'
    )
momentum.defaults.set(
    type='momentum'
    )
spindensity_new.defaults.set( # temporary
    type='spindensity_new',name='SpinDensityNew'
    )


# qmc section defaults.  Only method/move (and checkpoint for the optimizers)
# are active; the commented-out blocks are retained alternative parameter
# sets, labelled #jtk/#lschulen/#jmm.
linear.defaults.set(
     method = 'linear',move='pbyp',checkpoint=-1,
     #estimators = classcollection(localenergy)
#  #jtk
#    method='linear',move='pbyp',checkpoint=-1,gpu=True,
#    energy=0, reweightedvariance=0, unreweightedvariance=0,
#    warmupsteps       = 20,
#    usedrift          = True,
#    timestep          = .5,
#    minmethod         ='rescale',
#    stepsize          = .5,
#    beta              = 0.05,
#    alloweddifference = 1e-8,
#    bigchange         = 1.1,
#    cgsteps           = 3,
#    eigcg             = 1,
#    exp0              = -6,
#    maxweight         = 1e9,
#    minwalkers        = .5,
#    nstabilizers      = 10,
#    stabilizerscale   = .5,
#    usebuffer         = True,
    )
cslinear.defaults.set(
    method='cslinear', move='pbyp', checkpoint=-1,
    #estimators = classcollection(localenergy)
  #jtk
    #method='cslinear',move='pbyp',checkpoint=-1,gpu=True,
    #energy=0,reweightedvariance=0,unreweightedvariance=1.,
    #warmupsteps=5,steps=2,usedrift=True,timestep=.5,
    #minmethod='quartic',gevmethod='mixed',exp0=-15,
    #nstabilizers=5,stabilizerscale=3,stepsize=.35,
    #alloweddifference=1e-5,beta=.05,bigchange=5.,
    #estimators=classcollection(localenergy)
  #lschulen
    #method='cslinear', move='pbyp', checkpoint=-1, gpu=True,
    #energy=0, reweightedvariance=0, unreweightedvariance=0,
    #warmupsteps       = 20,
    ##steps             = 5,
    #usedrift          = True,
    #timestep          = .8,
    #nonlocalpp        = False,
    #minmethod         = 'rescale',
    #stepsize          = .4,
    #beta              = .05,
    #gevmethod         = 'mixed',
    #alloweddifference = 1e-4,
    #bigchange         = 9.,
    #exp0              = -16,
    #max_its           = 1,
    #maxweight         = 1e9,
    #minwalkers        = .5,
    #nstabilizers      = 3,
    #stabilizerscale   = 1,
    #usebuffer         = False,
    #estimators = classcollection(localenergy)
  #jmm
    #method='cslinear', move='pbyp', checkpoint=-1, gpu=True,
    #energy=0, reweightedvariance=0, unreweightedvariance=0,
    #warmupsteps       = 20,
    #usedrift          = True,
    #timestep          = .5,
    #nonlocalpp        = True,
    #minmethod         = 'quartic',
    #stepsize          = .4,
    #beta              = 0.0,
    #gevmethod         = 'mixed',
    #alloweddifference = 1.0e-4,
    #bigchange         = 9.0,
    #exp0              = -16,
    #max_its           = 1,
    #maxweight         = 1e9,
    #minwalkers        = 0.5,
    #nstabilizers      = 3,
    #stabilizerscale   = 1.0,
    #usebuffer         = True,
    #estimators = classcollection(localenergy)
    )
vmc.defaults.set(
    method='vmc',move='pbyp',
    #walkers     = 1,
    #warmupsteps = 50,
    #substeps    = 3,
    #usedrift    = True,
    #timestep    = .5,
    #estimators = classcollection(localenergy)
    )
dmc.defaults.set(
    method='dmc',move='pbyp',
    #warmupsteps   = 20,
    #timestep      = .01,
    #nonlocalmoves = True,
    #estimators = classcollection(localenergy)
    )
vmc_batch.defaults.set(
    method='vmc_batch',move='pbyp',
    )
2940
2941
2942
# afqmc defaults
# The name defaults here ('info0', 'wset0', 'prop0') match the section names
# that execute/propagator reference by default, so the default sections link
# up.  'ham0' and 'wfn0' are referenced by execute but not defined in this
# block (presumably set on the hamiltonian/wavefunction elsewhere).
afqmcinfo.defaults.set(
    name = 'info0',
    )
walkerset.defaults.set(
    name = 'wset0',
    )
propagator.defaults.set(
    name = 'prop0',
    info = 'info0',
    )
execute.defaults.set(
    info = 'info0',
    ham  = 'ham0',
    wfn  = 'wfn0',
    wset = 'wset0',
    prop = 'prop0',
    )
back_propagation.defaults.set(
    name='back_propagation'
    )
2964
2965
2966
2967
2968
def set_rsqmc_mode():
    '''
    Select real-space QMC (non-AFQMC) mode: clear the QIobj.afqmc_mode flag,
    then switch Names to the real-space expanded-name table.
    '''
    QIobj.afqmc_mode = False  # flag first, then the name table
    Names.use_rsqmc_expanded_names()
#end def set_rsqmc_mode
2973
def set_afqmc_mode():
    '''
    Select AFQMC mode: set the QIobj.afqmc_mode flag, then switch Names to
    the AFQMC expanded-name table (e.g. afqmcinfo -> AFQMCInfo).
    '''
    QIobj.afqmc_mode = True  # flag first, then the name table
    Names.use_afqmc_expanded_names()
#end def set_afqmc_mode
2978
2979
2980
2981
2982class QmcpackInput(SimulationInput,Names):
2983
2984    profile_collection = None
2985
2986    opt_methods = set(['opt','linear','cslinear'])
2987
2988    simulation_type = simulation
2989
2990    default_metadata = meta(
2991        lattice    = dict(units='bohr'),
2992        reciprocal = dict(units='2pi/bohr'),
2993        ionid      = dict(datatype='stringArray'),
2994        position   = dict(datatype='posArray', condition=0)
2995        )
2996
2997    @staticmethod
2998    def settings(**kwargs):
2999        QIobj.settings(**kwargs)
3000    #end def settings
3001
3002    def __init__(self,arg0=None,arg1=None):
3003        Param.metadata = None
3004        filepath = None
3005        metadata = None
3006        element  = None
3007        if arg0==None and arg1==None:
3008            None
3009        elif isinstance(arg0,str) and arg1==None:
3010            filepath = arg0
3011        elif isinstance(arg0,QIxml) and arg1==None:
3012            element = arg0
3013        elif isinstance(arg0,meta) and isinstance(arg1,QIxml):
3014            metadata = arg0
3015            element  = arg1
3016        else:
3017            self.error('input arguments of types '+arg0.__class__.__name__+' and '+arg0.__class__.__name__+' cannot be used to initialize QmcpackInput')
3018        #end if
3019        if metadata!=None:
3020            self._metadata = metadata
3021        else:
3022            self._metadata = meta()
3023        #end if
3024        if filepath!=None:
3025            self.read(filepath)
3026        elif element!=None:
3027            #simulation = arg0
3028            #self.simulation = self.simulation_type(simulation)
3029            elem_class = element.__class__
3030            if elem_class.identifier!=None:
3031                name = elem_class.identifier
3032            else:
3033                name = elem_class.__name__
3034            #end if
3035            self[name] = elem_class(element)
3036        #end if
3037        Param.metadata = None
3038        QIcollections.clear()
3039    #end def __init__
3040
3041    def is_afqmc_input(self):
3042        is_afqmc = False
3043        if 'simulation' in self:
3044            sim = self.simulation
3045            is_afqmc = 'method' in sim and sim.method.lower()=='afqmc'
3046        #end if
3047        return is_afqmc
3048    #end def is_afqmc_input
3049
3050    def get_base(self):
3051        elem_names = list(self.keys())
3052        elem_names.remove('_metadata')
3053        if len(elem_names)>1:
3054            self.error('qmcpack input cannot have more than one base element\n  You have provided '+str(len(elem_names))+': '+str(elem_names))
3055        #end if
3056        return self[elem_names[0]]
3057    #end def get_base
3058
3059    def get_basename(self):
3060        elem_names = list(self.keys())
3061        elem_names.remove('_metadata')
3062        if len(elem_names)>1:
3063            self.error('qmcpack input cannot have more than one base element\n  You have provided '+str(len(elem_names))+': '+str(elem_names))
3064        #end if
3065        return elem_names[0]
3066    #end def get_basename
3067
3068    def read(self,filepath=None,xml=None):
3069        if xml!=None or os.path.exists(filepath):
3070            element_joins=['qmcsystem']
3071            element_aliases=dict(loop='qmc')
3072            xml = XMLreader(filepath,element_joins,element_aliases,warn=False,xml=xml).obj
3073            xml.condense()
3074            self._metadata = meta() #store parameter/attrib attribute metadata
3075            Param.metadata = self._metadata
3076            if 'simulation' in xml:
3077                self.simulation = simulation(xml.simulation)
3078            else:
3079                #try to determine the type
3080                elements = []
3081                keys = []
3082                error = False
3083                for key,value in xml.items():
3084                    if isinstance(key,str) and key[0]!='_':
3085                        if key in types:
3086                            elements.append(types[key](value))
3087                            keys.append(key)
3088                        else:
3089                            self.error('element '+key+' is not a recognized type',exit=False)
3090                            error = True
3091                        #end if
3092                    #end if
3093                #end for
3094                if error:
3095                    self.error('cannot read input xml file')
3096                #end if
3097                if len(elements)==0:
3098                    self.error('no valid elements were found for input xml file')
3099                #end if
3100                for i in range(len(elements)):
3101                    elem = elements[i]
3102                    key  = keys[i]
3103                    if isinstance(elem,QIxml):
3104                        if elem.identifier!=None:
3105                            name = elem.identifier
3106                        else:
3107                            name = elem.tag
3108                        #end if
3109                    else:
3110                        name = key
3111                    #end if
3112                    self[name] = elem
3113                #end for
3114            #end if
3115            Param.metadata = None
3116        else:
3117            self.error('the filepath you provided does not exist.\n  Input filepath: '+filepath)
3118        #end if
3119        return self
3120    #end def read
3121
3122
3123    def write_text(self,filepath=None):
3124        set_rsqmc_mode()
3125        if self.is_afqmc_input():
3126            set_afqmc_mode()
3127        #end if
3128        c = ''
3129        header = '''<?xml version="1.0"?>
3130'''
3131        c+= header
3132        if len(self._metadata)==0:
3133            Param.metadata = self.default_metadata
3134        else:
3135            Param.metadata = self._metadata
3136        #end if
3137        base = self.get_base()
3138        c+=base.write(first=True)
3139        Param.metadata = None
3140        set_rsqmc_mode()
3141        return c
3142    #end def write_text
3143
3144
3145    def unroll_calculations(self,modify=True):
3146        qmc = []
3147        sim = self.simulation
3148        if 'calculations' in sim:
3149            calcs = sim.calculations
3150        elif 'qmc' in sim:
3151            calcs = [sim.qmc]
3152        else:
3153            calcs = []
3154        #end if
3155        for i in range(len(calcs)):
3156            c = calcs[i]
3157            if isinstance(c,loop):
3158                qmc.extend(c.unroll())
3159            else:
3160                qmc.append(c)
3161            #end if
3162        #end for
3163        qmc = make_collection(qmc)
3164        if modify:
3165            self.simulation.calculations = qmc
3166        #end if
3167        return qmc
3168    #end def unroll_calculations
3169
3170    def get(self,*names):
3171        base = self.get_base()
3172        return base.get(names)
3173    #end def get
3174
3175    def remove(self,*names):
3176        base = self.get_base()
3177        base.remove(*names)
3178    #end def remove
3179
3180    def assign(self,**kwargs):
3181        base = self.get_base()
3182        base.assign(**kwargs)
3183    #end def assign
3184
3185    def replace(self,*args,**kwargs):# input is list of keyword=(oldval,newval)
3186        base = self.get_base()
3187        base.replace(*args,**kwargs)
3188    #end def replace
3189
3190    def move(self,**elemdests):
3191        base = self.get_base()
3192        base.move(**elemdests)
3193    #end def move
3194
3195
3196    def get_host(self,names):
3197        base = self.get_base()
3198        return base.get_host(names)
3199    #end if
3200
3201    def incorporate_defaults(self,elements=False,overwrite=False,propagate=False):
3202        base = self.get_base()
3203        base.incorporate_defaults(elements,overwrite,propagate)
3204    #end def incorporate_defaults
3205
3206    def pluralize(self):
3207        base = self.get_base()
3208        base.pluralize()
3209    #end def pluralize
3210
3211    def standard_placements(self):
3212        self.move(particleset='qmcsystem',wavefunction='qmcsystem',hamiltonian='qmcsystem')
3213    #end def standard_placements
3214
3215    def difference(self,other):
3216        s1 = self.copy()
3217        s2 = other.copy()
3218        b1 = s1.get_basename()
3219        b2 = s2.get_basename()
3220        q1 = s1[b1]
3221        q2 = s2[b2]
3222        if b1!=b2:
3223            different = True
3224            d1 = q1
3225            d2 = q2
3226            diff = None
3227        else:
3228            s1.standard_placements()
3229            s2.standard_placements()
3230            s1.pluralize()
3231            s2.pluralize()
3232            different,diff,d1,d2 = q1.difference(q2,root=False)
3233        #end if
3234        if diff!=None:
3235            diff.remove_empty()
3236        #end if
3237        d1.remove_empty()
3238        d2.remove_empty()
3239        return different,diff,d1,d2
3240    #end def difference
3241
3242    def remove_empty(self):
3243        base = self.get_base()
3244        base.remove_empty()
3245    #end def remove_empty
3246
3247    def read_xml(self,filepath=None,xml=None):
3248        if os.path.exists(filepath):
3249            element_joins=['qmcsystem']
3250            element_aliases=dict(loop='qmc')
3251            if xml is None:
3252                xml = XMLreader(filepath,element_joins,element_aliases,warn=False).obj
3253            else:
3254                xml = XMLreader(None,element_joins,element_aliases,warn=False,xml=xml).obj
3255            #end if
3256            xml.condense()
3257        else:
3258            self.error('the filepath you provided does not exist.\n  Input filepath: '+filepath)
3259        #end if
3260        return xml
3261    #end def read_xml
3262
3263    def include_xml(self,xmlfile,replace=True,exists=True):
3264        xml = self.read_xml(xmlfile)
3265        Param.metadata = self._metadata
3266        for name,exml in xml.items():
3267            if not name.startswith('_'):
3268                qxml = types[name](exml)
3269                qname = qxml.tag
3270                host = self.get_host(qname)
3271                if host==None and exists:
3272                    self.error('host xml section for '+qname+' not found','QmcpackInput')
3273                #end if
3274                if qname in host:
3275                    section_name = qname
3276                elif qname in plurals_inv and plurals_inv[qname] in host:
3277                    section_name = plurals_inv[qname]
3278                else:
3279                    section_name = None
3280                #end if
3281                if replace:
3282                    if section_name!=None:
3283                        del host[section_name]
3284                    #end if
3285                    host[qname] = qxml
3286                else:
3287                    if section_name==None:
3288                        host[qname] = qxml
3289                    else:
3290                        section = host[section_name]
3291                        if isinstance(section,collection):
3292                            section[qxml.identifier] = qxml
3293                        elif section_name in plurals_inv:
3294                            coll = collection()
3295                            coll[section.identifier] = section
3296                            coll[qxml.identifier]    = qxml
3297                            del host[section_name]
3298                            host[plurals_inv[section_name]] = coll
3299                        else:
3300                            section.combine(qxml)
3301                        #end if
3302                    #end if
3303                #end if
3304            #end if
3305        #end for
3306        Param.metadata = None
3307    #end def include_xml
3308
3309    # This include functionality is currently not being used
3310    # The rationale is essentially this:
3311    #   -Having includes explicitly represented in the input file object
3312    #    makes it very difficult to search for various components
3313    #    i.e. where is the particleset? the wavefunction? a particular determinant?
3314    #   -Difficulty in locating components makes it difficult to modify them
3315    #   -Includes necessarily introduce greater variability in input file structure
3316    #    and it is difficult to ensure every possible form is preserved each and
3317    #    every time a modification is made
3318    #   -The only time it is undesirable to incorporate the contents of an
3319    #    include directly into the input file object is if the data is large
3320    #    e.g. for an xml wavefunction or pseudopotential.
3321    #    In these cases, an external file should be provided that contains
3322    #    only the large object in question (pseudo or wavefunction).
3323    #    This is already done for pseudopotentials and should be done for
3324    #    wavefunctions, e.g. multideterminants.
3325    #    Until that time, wavefunctions will be explicitly read into the full
3326    #    input file.
3327    def add_include(self,element_type,href,placement='on'):
3328        # check the element type
3329        elems = ['cell','ptcl','wfs','ham']
3330        emap  = obj(
3331            simulationcell = 'cell',
3332            particleset    = 'ptcl',
3333            wavefunction   = 'wfs',
3334            hamiltonian    = 'ham'
3335            )
3336        if not element_type in elems:
3337            self.error('cannot add include for element of type {0}\n  valid element types are {1}'.format(element_type,elems))
3338        #end if
3339        # check the requested placement
3340        placements = ('before','on','after')
3341        if not placement in placements:
3342            self.error('cannot add include for element with placement {0}\n  valid placements are {1}'.format(placement,list(placements)))
3343        #end if
3344        # check that the base element is a simulation
3345        base = self.get_base()
3346        if not isinstance(base,simulation):
3347            self.error('an include can only be added to simulation\n  attempted to add to {0}'.format(base.__class__.__name__))
3348        #end if
3349        # gather a list of current qmcsystems
3350        if 'qmcsystem' in base:
3351            qslist = [(0,base.qmcsystem)]
3352            del base.qmcsystem
3353        elif 'qmcsystems' in base:
3354            qslist = base.qmcsystems.pairlist()
3355            del base.qmcsystems
3356        else:
3357            qslist = []
3358        #end if
3359        # organize the elements of the qmcsystems
3360        cur_elems = obj()
3361        for elem in elems:
3362            for place in placements:
3363                cur_elems[elem,place] = None
3364            #end for
3365        #end for
3366        for qskey,qs in qslist:
3367            if isinstance(qs,include):
3368                inc = qs
3369                ekey = qskey.split('_')[1]
3370                if not ekey in elems:
3371                    self.error('encountered invalid element key: {0}\n  valid keys are: {1}'.format(ekey,elems))
3372                #end if
3373                if cur_elems[ekey,'on'] is None:
3374                    cur_elems[ekey,'before'] = ekey,inc
3375                else:
3376                    cur_elems[ekey,'after' ] = ekey,inc
3377                #end if
3378            elif not isinstance(qs,qmcsystem):
3379                self.error('expected qmcsystem element, got {0}'.format(qs.__class__.__name__))
3380            else:
3381                for elem in qmcsystem.elements:
3382                    elem_plural = elem+'s'
3383                    name  = None
3384                    if elem in qs:
3385                        name = elem
3386                    elif elem_plural in qs:
3387                        name = elem_plural
3388                    #end if
3389                    if name!=None:
3390                        cur_elems[emap[elem],'on'] = name,qs[name]
3391                        del qs[name]
3392                    #end if
3393                #end for
3394                residue = list(qs.keys())
3395                if len(residue)>0:
3396                    self.error('extra keys found in qmcsystem: {0}'.format(sorted(residue)))
3397                #end if
3398            #end if
3399        #end for
3400        for elem in elems:
3401            pbef = cur_elems[elem,'before']
3402            pon  = cur_elems[elem,'on'    ]
3403            paft = cur_elems[elem,'after' ]
3404            if pon is None:
3405                if not pbef is None and paft is None:
3406                    cur_elems[elem,'on'    ] = pbef
3407                    cur_elems[elem,'before'] = None
3408                elif not paft is None and pbef is None:
3409                    cur_elems[elem,'on'    ] = paft
3410                    cur_elems[elem,'after' ] = None
3411                #end if
3412            #end if
3413        #end for
3414        # insert the new include
3415        inc_name  = 'include_'+element_type
3416        inc_value = include(href=href)
3417        cur_elems[element_type,placement] = inc_name,inc_value
3418        # create a collection of qmcsystems
3419        qmcsystems = collection()
3420        qskey = ''
3421        qs    = qmcsystem()
3422        for elem in elems:
3423            for place in placements:
3424                cur_elem = cur_elems[elem,place]
3425                if cur_elem!=None:
3426                    name,value = cur_elem
3427                    if isinstance(value,include):
3428                        if len(qskey)>0:
3429                            qmcsystems.add(qs,key=qskey)
3430                            qskey = ''
3431                            qs    = qmcsystem()
3432                        #end if
3433                        qmcsystems.add(value,key=name)
3434                    else:
3435                        qskey += elem[0]
3436                        qs[name] = value
3437                    #end if
3438                #end if
3439            #end for
3440        #end for
3441        if len(qskey)>0:
3442            qmcsystems.add(qs,key=qskey)
3443        #end if
3444        # attach the collection to the input file
3445        base.qmcsystems = qmcsystems
3446    #end def add_include
3447
3448
3449    def get_output_info(self,*requests):
3450        project = self.simulation.project
3451        prefix = project.id
3452        series = project.series
3453
3454        qmc = []
3455        calctypes = set()
3456        outfiles = []
3457
3458        if not self.is_afqmc_input():
3459            qmc_ur = self.unroll_calculations(modify=False)
3460            n=0
3461            for qo in qmc_ur:
3462                q = obj()
3463                q.prefix = prefix
3464                q.series = series+n
3465                n+=1
3466                method = qo.method
3467                if method in self.opt_methods:
3468                    q.type = 'opt'
3469                else:
3470                    q.type = method
3471                #end if
3472                calctypes.add(q.type)
3473                q.method = method
3474                fprefix = prefix+'.s'+str(q.series).zfill(3)+'.'
3475                files = obj()
3476                files.scalar = fprefix+'scalar.dat'
3477                files.stat   = fprefix+'stat.h5'
3478                # apparently this one is no longer generated by default as of r5756
3479                #files.config = fprefix+'storeConfig.h5'
3480                if q.type=='opt':
3481                    files.opt = fprefix+'opt.xml'
3482                elif q.type=='dmc':
3483                    files.dmc = fprefix+'dmc.dat'
3484                #end if
3485                outfiles.extend(files.values())
3486                q.files = files
3487                qmc.append(q)
3488            #end for
3489        else:
3490            q = obj()
3491            q.prefix = prefix
3492            q.series = series
3493            q.type   = 'afqmc'
3494            q.method = 'afqmc'
3495            calctypes.add(q.type)
3496            fprefix = prefix+'.s'+str(q.series).zfill(3)+'.'
3497            files = obj()
3498            files.scalar = fprefix+'scalar.dat'
3499            outfiles.extend(files.values())
3500            q.files = files
3501            qmc.append(q)
3502        #end if
3503
3504        res = dict(qmc=qmc,calctypes=calctypes,outfiles=outfiles)
3505
3506        values = []
3507        for req in requests:
3508            if req in res:
3509                values.append(res[req])
3510            else:
3511                self.error(req+' is not a valid output info request')
3512            #end if
3513        #end for
3514        if len(values)==1:
3515            return values[0]
3516        else:
3517            return values
3518        #end if
3519    #end def get_output_info
3520
3521
    def generate_jastrows(self,size=None,j1func='bspline',j1size=8,j2func='bspline',j2size=8):
        """
        Attach default one- and two-body Jastrow factors to the wavefunction.

        A two-body Jastrow 'J2' (u-u and u-d correlations) is built with
        RPA-style starting coefficients.  A one-body Jastrow with zero
        coefficients is built for each particleset named 'i' or starting with
        'ion'.  It is named 'J1' when there is only one such set, otherwise
        'J1_<setname>'.  A generated Jastrow is added only if the wavefunction
        holds no Jastrow of the same type.  The exception is an existing
        bspline 'J2' whose uu+ud coefficient magnitudes sum to less than
        1e-3, which is replaced by the generated 'J2'.

        size   : if not None, used for both j1size and j2size
        j1func : 'function' attribute of the one-body Jastrows
        j1size : number of one-body coefficients per ion species
        j2func : 'function' attribute of the two-body Jastrow
        j2size : number of two-body coefficients per spin pair

        A simulationcell lattice, particlesets including 'e', and a
        wavefunction must all be present; otherwise self.error is called.
        NOTE(review): assumes the wavefunction already has a 'jastrows'
        collection (it is indexed directly below) -- confirm for inputs
        without any Jastrow.
        """
        if size!=None:
            j1size = size
            j2size = size
        #end if

        #self.remove('jastrow')
        lattice,particlesets,wavefunction = self.get('lattice','particleset','wavefunction')
        no_lattice = lattice==None
        no_particleset = particlesets==None
        no_wavefunction = wavefunction==None
        # the individual messages use exit=False, presumably so every missing
        # piece is reported before the final error below
        if no_lattice:
            self.error('a simulationcell lattice must be present to generate jastrows',exit=False)
        #end if
        if no_particleset:
            self.error('a particleset must be present to generate jastrows',exit=False)
        #end if
        if no_wavefunction:
            self.error('a wavefunction must be present to generate jastrows',exit=False)
        #end if
        if no_lattice or no_particleset or no_wavefunction:
            self.error('jastrows cannot be generated')
        #end if
        # a lone particleset is wrapped so it can be accessed by name
        if isinstance(particlesets,QIxml):
            particlesets = make_collection([particlesets])
        #end if
        if not 'e' in particlesets:
            self.error('electron particleset (e) not found\n particlesets: '+str(particlesets.keys()))
        #end if


        jastrows = collection()

        # cell volume sets the electron density; rmin is used as the cutoff
        cell = Structure(lattice)
        volume = cell.volume()
        rcut   = cell.rmin()

        #use the rpa jastrow for electrons (modeled after Luke's tool)
        size = j2size
        e = particlesets.e
        # total electron count over all groups of the 'e' particleset
        nelectrons = 0
        for g in e.groups:
            nelectrons += g.size
        #end for
        density = nelectrons/volume
        # presumably the electron-gas plasma frequency sqrt(4*pi*n) (atomic units)
        wp = sqrt(4*pi*density)
        # 'size' radial points spaced rcut/size apart, starting at r=0.02
        dr = rcut/size
        r = .02 + dr*arange(size)
        # RPA-like form damped by exp(-(2r/rcut)^2); like spins (uu) use
        # sqrt(wp/2) in the short-range exponent, unlike spins (ud) sqrt(wp)
        uuc = .5/(wp*r)*(1.-exp(-r*sqrt(wp/2)))*exp(-(2*r/rcut)**2)
        udc = .5/(wp*r)*(1.-exp(-r*sqrt(wp)))*exp(-(2*r/rcut)**2)
        jastrows.J2 = jastrow2(
            name = 'J2',type='Two-Body',function=j2func,print='yes',
            correlations = collection(
                uu = correlation(speciesA='u',speciesB='u',size=size,
                                 coefficients=section(id='uu',type='Array',coeff=uuc)),
                ud = correlation(speciesA='u',speciesB='d',size=size,
                                 coefficients=section(id='ud',type='Array',coeff=udc))
                )
            )

        #generate electron-ion jastrows, if ions present
        ions = []
        for name in particlesets.keys():
            if name=='i' or name.startswith('ion'):
                ions.append(name)
            #end if
        #end for
        if len(ions)>0:
            size = j1size
            j1 = []
            for ion in ions:
                i = particlesets[ion]
                # a set may hold a single 'group' or a 'groups' collection
                if 'group' in i:
                    groups = [i.group]
                else:
                    groups = i.groups
                #end if
                # one zero-initialized correlation per ion species
                corr = []
                for g in groups:
                    elem = g.name
                    c=correlation(
                        elementtype=elem,
                        cusp=0.,
                        size=size,
                        coefficients=section(
                            id='e'+elem,
                            type='Array',
                            coeff=size*[0]
                            )
                        )
                    corr.append(c)
                #end for
                j=jastrow1(
                    name='J1_'+ion,
                    type='One-Body',
                    function=j1func,
                    source=ion,
                    print='yes',
                    correlations = corr
                    )
                j1.append(j)
            #end for
            # with a single ion set the plain name 'J1' is used
            if len(j1)==1:
                j1[0].name='J1'
            #end if
            for j in j1:
                jastrows[j.name]=j
            #end for
        #end if

        # replace an existing bspline J2 that is effectively empty
        # (sum of |coefficients| below 1e-3) with the generated one
        if 'J2' in wavefunction.jastrows:
            J2 = wavefunction.jastrows.J2
            if 'function' in J2 and J2.function.lower()=='bspline':
                c = wavefunction.jastrows.J2.correlations
                ctot = abs(array(c.uu.coefficients.coeff)).sum() + abs(array(c.ud.coefficients.coeff)).sum()
                if ctot < 1e-3:
                    wavefunction.jastrows.J2 = jastrows.J2
                #end if
            #end if
        #end if

        #only add the jastrows if ones of the same type
        # (one-body,two-body,etc) are not already present
        for jastrow in jastrows:
            # normalize e.g. 'One-Body' -> 'one_body' before comparing
            jtype = jastrow.type.lower().replace('-','_')
            has_jtype = False
            for wjastrow in wavefunction.jastrows:
                wjtype = wjastrow.type.lower().replace('-','_')
                has_jtype = has_jtype or wjtype==jtype
            #end for
            if not has_jtype:
                wavefunction.jastrows[jastrow.name] = jastrow
            #end if
        #end for
    #end def generate_jastrows
3657
3658
    def incorporate_system(self,system):
        """
        Replace the physical-system description in this input with `system`.

        Works on a copy of `system`: the folded system is checked, units are
        changed to 'B', and atoms are ordered by species.  The
        simulationcell lattice is set from the structure axes, when there are
        any.  The particlesets are rebuilt as an electron set 'e' (groups
        'u'/'d') plus, if ions are present, an ion set 'ion0'.

        Missing qmcsystem, simulationcell and hamiltonian elements are created
        with defaults.  When a pseudo/PseudoPot collection exists, ion species
        lacking an entry get a placeholder href='MISSING.xml'.  Old
        electron/ion particleset names are replaced by 'e'/'ion0' via
        self.replace.  Determinant sizes are reset to the up/down electron
        counts.  For |net_spin| > 0.1 the down determinant is set to
        spindataset 1.
        """
        self.warn('incorporate_system may or may not work\n  please check the qmcpack input produced\n  if it is wrong, please contact the developer')
        system = system.copy()
        system.check_folded_system()
        system.change_units('B')
        #system.structure.group_atoms()
        system.structure.order_by_species()
        particles  = system.particles
        structure  = system.structure
        net_charge = system.net_charge
        net_spin   = system.net_spin

        # existing elements; the checks below treat None as "absent"
        qs,sc,ham,ps = self.get('qmcsystem','simulationcell','hamiltonian','particleset')

        # remember the names of the current electron/ion particlesets so that
        # references to them can be renamed after the sets are rebuilt:
        #   first group charge ~ -1  -> electrons
        #   otherwise has 'ionid'    -> ions
        old_eps_name = None
        old_ips_name = None
        if ps!=None:
            if isinstance(ps,particleset):
                ps = make_collection([ps])
            #end if
            for pname,pset in ps.items():
                g0name = list(pset.groups.keys())[0]
                g0 = pset.groups[g0name]
                if abs(-1-g0.charge)<1e-2:
                    old_eps_name = pname
                elif 'ionid' in pset:
                    old_ips_name = pname
                #end if
            #end for
        #end if
        # discard all existing particlesets; they are rebuilt from system below
        del ps
        self.remove('particleset')
        # create any missing container elements with default contents
        if qs==None:
            qs = qmcsystem()
            qs.incorporate_defaults(elements=False,propagate=False)
            self.simulation.qmcsystem = qs
        #end if
        if sc==None:
            sc = simulationcell()
            sc.incorporate_defaults(elements=False,propagate=False)
            qs.simulationcell = sc
        #end if
        if ham==None:
            ham = hamiltonian()
            ham.incorporate_defaults(elements=False,propagate=False)
            qs.hamiltonian = ham
        elif isinstance(ham,collection):
            # several hamiltonians: use 'h0', or the only one present
            if 'h0' in ham:
                ham = ham.h0
            elif len(ham)==1:
                ham = ham.list()[0]
            else:
                self.error('cannot find hamiltonian for system incorporation')
            #end if
        #end if

        elem = structure.elem
        pos  = structure.pos

        if len(structure.axes)>0: #exclude systems with open boundaries
            #setting the 'lattice' (cell axes) requires some delicate care
            #  qmcpack will fail if this is even 1e-10 off of what is in
            #  the wavefunction hdf5 file from pwscf
            if structure.folded_structure!=None:
                fs = structure.folded_structure
                # round-trip through the pwscf text format, then tile with tmatrix
                axes = array(pwscf_array_string(fs.axes).split(),dtype=float)
                axes.shape = fs.axes.shape
                axes = dot(structure.tmatrix,axes)
                if abs(axes-structure.axes).sum()>1e-5:
                    self.error('supercell axes do not match tiled version of folded cell axes\n  you may have changed one set of axes (super/folded) and not the other\n  folded cell axes:\n'+str(fs.axes)+'\n  supercell axes:\n'+str(structure.axes)+'\n  folded axes tiled:\n'+str(axes))
                #end if
            else:
                axes = array(pwscf_array_string(structure.axes).split(),dtype=float)
                axes.shape = structure.axes.shape
            #end if
            structure.adjust_axes(axes)

            sc.lattice = axes
        #end if

        elns = particles.get_electrons()
        ions = particles.get_ions()
        eup  = elns.up_electron
        edn  = elns.down_electron

        # electron particleset 'e' with up ('u') and down ('d') groups
        particlesets = []
        eps = particleset(
            name='e',random=True,
            groups = [
                group(name='u',charge=-1,mass=eup.mass,size=eup.count),
                group(name='d',charge=-1,mass=edn.mass,size=edn.count)
                ]
            )
        particlesets.append(eps)
        if len(ions)>0:
            # not fully periodic: set randomsrc to ion0 (presumably so random
            # electron positions are generated around the ions -- confirm)
            if sc!=None and 'bconds' in sc and tuple(sc.bconds)!=('p','p','p'):
                eps.randomsrc = 'ion0'
            #end if
            ips = particleset(
                name='ion0',
                )
            groups = []
            ham.pluralize()
            # locate (or create under PseudoPot) the pseudopotential collection
            pseudos = ham.get('pseudo')
            if pseudos==None:
                pp = ham.get('PseudoPot')
                if pp!=None:
                    pseudos = collection()
                    pp.pseudos = pseudos
                #end if
            #end if
            # one group per ion species, holding that species' positions
            for ion in ions:
                gpos = pos[elem==ion.name]
                g = group(
                    name         = ion.name,
                    charge       = ion.charge,
                    valence      = ion.charge,
                    atomicnumber = ion.protons,
                    mass         = ion.mass,
                    position     = gpos,
                    size         = len(gpos)
                    )
                groups.append(g)
                if pseudos!=None and not ion.name in pseudos:
                    pseudos[ion.name] = pseudo(elementtype=ion.name,href='MISSING.xml')
                #end if
            #end for
            ips.groups = make_collection(groups)
            particlesets.append(ips)
        #end if
        qs.particlesets = make_collection(particlesets)

        # rename references to the previous particleset names
        if old_eps_name!=None:
            self.replace(old_eps_name,'e')
        #end if
        if old_ips_name!=None and len(ions)>0:
            self.replace(old_ips_name,'ion0')
        #end if

        # determinant sizes follow the new up/down electron counts
        udet,ddet = self.get('updet','downdet')

        if udet!=None:
            udet.size = elns.up_electron.count
        #end if
        if ddet!=None:
            ddet.size = elns.down_electron.count
        #end if

        # spin-polarized: the down determinant reads spin dataset 1, set either
        # on its occupation or on the sposet it references
        if abs(net_spin) > 1e-1:
            if ddet!=None:
                if 'occupation' in ddet:
                    ddet.occupation.spindataset = 1
                else:
                    ss = self.get('sposets')
                    ss[ddet.sposet].spindataset = 1
                #end if
            #end if
        #end if
    #end def incorporate_system
3818
3819
3820    def return_system(self,structure_only=False):
3821        input = self.copy()
3822        input.pluralize()
3823        axes,ps,H = input.get('lattice','particlesets','hamiltonian')
3824
3825        if ps is None:
3826            return None
3827        #end if
3828
3829        # find electrons and ions
3830        have_ions = True
3831        have_jellium = False
3832        ions = None
3833        elns = None
3834        ion_list = []
3835        for name,p in ps.items():
3836            if 'ionid' in p:
3837                ion_list.append(p)
3838            elif name.startswith('e'):
3839                elns = p
3840            #end if
3841        #end for
3842        if len(ion_list)==0: #try to identify ions by positive charged groups
3843            for name,p in ps.items():
3844                if 'groups' in p:
3845                    for g in p.groups:
3846                        if 'charge' in g and g.charge>0:
3847                            ion_list.append(p)
3848                            break
3849                        #end if
3850                    #end for
3851                #end if
3852            #end for
3853        #end if
3854        if len(ion_list)==1:
3855            ions = ion_list[0]
3856        elif len(ion_list)>1:
3857            self.error('ability to handle multiple ion particlesets has not been implemented')
3858        #end if
3859        if ions is None and elns!=None and 'groups' in elns:
3860            simcell = input.get('simulationcell')
3861            if simcell!=None and 'rs' in simcell:
3862                have_ions = False
3863                have_jellium = True
3864            elif not 'pairpots' in H:
3865                have_ions = False
3866            #end if
3867        #end if
3868        if elns==None:
3869            self.error('could not find electron particleset')
3870        #end if
3871        if ions==None and have_ions:
3872            self.error('could not find ion particleset')
3873        #end if
3874
3875        #compute spin and electron charge
3876        net_spin   = 0
3877        eln_charge = 0
3878        for spin,eln in elns.groups.items():
3879            if spin[0]=='u':
3880                net_spin+=eln.size
3881            elif spin[0]=='d':
3882                net_spin-=eln.size
3883            #end if
3884            eln_charge += eln.charge*eln.size
3885        #end if
3886
3887        #get structure and ion charge
3888        structure  = None
3889        ion_charge = 0
3890        valency    = dict()
3891        if have_ions:
3892            elem = None
3893            if 'ionid' in ions:
3894                if isinstance(ions.ionid,str):
3895                    elem = [ions.ionid]
3896                else:
3897                    elem = list(ions.ionid)
3898                #end if
3899                pos  = ions.position
3900            elif 'size' in ions and ions.size==1:
3901                elem = [ions.groups.list()[0].name]
3902                pos  = [[0,0,0]]
3903            elif 'groups' in ions:
3904                elem = []
3905                pos  = []
3906                for group in ions.groups:
3907                    if 'position' in group:
3908                        nions = group.size
3909                        elem.extend(nions*[group.name])
3910                        if group.size==1:
3911                            pos.extend([list(group.position)])
3912                        else:
3913                            pos.extend(list(group.position))
3914                        #end if
3915                    #end if
3916                #end for
3917                if len(elem)==0:
3918                    elem = None
3919                    pos  = None
3920                else:
3921                    elem  = array(elem)
3922                    pos   = array(pos)
3923                    order = elem.argsort()
3924                    elem  = elem[order]
3925                    pos   = pos[order]
3926                #end if
3927            #end if
3928            if elem is None:
3929                self.error('could not read ions from ion particleset')
3930            #end if
3931            if axes is None:
3932                center = (0,0,0)
3933            else:
3934                md = input._metadata
3935                if 'position' in md and 'condition' in md['position'] and md['position']['condition']==1:
3936                    pos = dot(pos,axes)
3937                #end if
3938                center = axes.sum(0)/2
3939            #end if
3940
3941            structure = Structure(axes=axes,elem=elem,pos=pos,center=center,units='B')
3942
3943            for name,element in ions.groups.items():
3944                if 'charge' in element:
3945                    valence = element.charge
3946                elif 'valence' in element:
3947                    valence = element.valence
3948                elif 'atomic_number' in element:
3949                    valence = element.atomic_number
3950                else:
3951                    self.error('could not identify valency of '+name)
3952                #end if
3953                valency[name] = valence
3954                count = list(elem).count(name)
3955                ion_charge += valence*count
3956            #end for
3957        elif have_jellium:
3958            structure  = Jellium(rs=simcell.rs,background_charge=-eln_charge)
3959            ion_charge = structure.background_charge
3960        #end if
3961
3962        net_charge = ion_charge + eln_charge
3963
3964        system = PhysicalSystem(structure,net_charge,net_spin,**valency)
3965
3966        if structure_only:
3967            return structure
3968        else:
3969            return system
3970        #end if
3971    #end def return_system
3972
3973
3974    def get_ion_particlesets(self):
3975        ions = obj()
3976        ps = self.get('particlesets')
3977        #try to identify ions by positive charged groups
3978        for name,p in ps.items():
3979            if name.startswith('ion') or name.startswith('atom'):
3980                ions[name] = p
3981            elif 'groups' in p:
3982                for g in p.groups:
3983                    if 'charge' in g and g.charge>0:
3984                        ions[name] = p
3985                        break
3986                    #end if
3987                #end for
3988            #end if
3989        #end for
3990        return ions
3991    #end def get_ion_particlesets
3992
3993
3994    def get_pp_files(self):
3995        pp_files = []
3996        h = self.get('hamiltonian')
3997        if h != None:
3998            pp = None
3999            if 'pairpots' in h:
4000                for pairpot in h.pairpots:
4001                    if 'type' in pairpot and pairpot.type=='pseudo':
4002                        pp = pairpot
4003                    #end if
4004                #end for
4005            elif 'pairpot' in h and 'type' in h.pairpot and h.pairpot.type=='pseudo':
4006                pp = h.pairpot
4007            #end if
4008            if pp!=None:
4009                if 'pseudo' in pp and 'href' in pp.pseudo:
4010                    pp_files.append(pp.pseudo.href)
4011                elif 'pseudos' in pp:
4012                    for pseudo in pp.pseudos:
4013                        if 'href' in pseudo:
4014                            pp_files.append(pseudo.href)
4015                        #end if
4016                    #end for
4017                #end if
4018            #end if
4019        #end if
4020        return pp_files
4021    #end def get_pp_files
4022
4023
4024    def remove_physical_system(self):
4025        qs = self.simulation.qmcsystem
4026        if 'simulationcell' in qs:
4027            del qs.simulationcell
4028        #end if
4029        if 'particlesets' in qs:
4030            del qs.particlesets
4031        #end if
4032        for name in qs.keys():
4033            if isinstance(qs[name],particleset):
4034                del qs[name]
4035            #end if
4036        #end for
4037        self.replace('ion0','i')
4038    #end def remove_physical_system
4039
4040
4041    def cusp_correction(self):
4042        cc = False
4043        if not self.is_afqmc_input():
4044            ds = self.get('determinantset')
4045            cc_var = ds!=None and 'cuspcorrection' in ds and ds.cuspcorrection==True
4046            cc_run = len(self.simulation.calculations)==0
4047            cc = cc_var and cc_run
4048        #end if
4049        return cc
4050    #end def cusp_correction
4051
4052
4053    def get_qmc(self,series):
4054        qmc = None
4055        calcs        = self.get('calculations')
4056        series_start = self.get('series')
4057        if calcs!=None:
4058            if series_start is None:
4059                qmc = calcs[series]
4060            else:
4061                qmc = calcs[series-series_start]
4062            #end if
4063        #end if
4064        return qmc
4065    #end def get_qmc
4066
4067
4068    def bundle(self,inputs,filenames):
4069        return BundledQmcpackInput(inputs,filenames)
4070    #end def bundle
4071
4072
4073    def trace(self,quantity,values):
4074        return TracedQmcpackInput(quantity,values,self)
4075    #end def trace
4076
4077
4078    def twist_average(self,twistnums):
4079        return self.trace('twistnum',twistnums)
4080    #end def twist_average
4081#end class QmcpackInput
4082
4083
4084
# Bundled QMCPACK input: several complete XML inputs plus the filenames they
#  are written to.  Instantiated directly by QmcpackInput.bundle and used as
#  the base of TracedQmcpackInput.
class BundledQmcpackInput(SimulationInput):

    def __init__(self,inputs,filenames):
        """
        inputs    : iterable of QmcpackInput objects, stored in order
        filenames : filenames indexed in the same order as the inputs
        """
        self.inputs = obj()
        for inp in inputs:
            self.inputs.append(inp)
        #end for
        self.filenames = filenames
    #end def __init__


    def get_output_info(self,*requests):
        """
        Report expected output files for the whole bundle.

        Only the 'outfiles' request is supported; other requests yield None.
        For bundled input number N (tag '.gNNN') the list contains:
          - its input filename,
          - '<input filename without last extension>.gNNN.qmc',
          - each of that input's own output files with '.gNNN' inserted
            after the first '.'-separated component.
        A single request returns its value directly; several return a list.
        """
        outfiles = []
        for index,inp in self.inputs.items():
            outfs  = inp.get_output_info('outfiles')
            tag    = '.g'+str(index).zfill(3)
            infile = self.filenames[index]
            outfiles.append(infile)
            outfiles.append(infile.rsplit('.',1)[0]+tag+'.qmc')
            for outf in outfs:
                prefix,rest = outf.split('.',1)
                outfiles.append(prefix+tag+'.'+rest)
            #end for
        #end for

        values = [outfiles if req=='outfiles' else None for req in requests]
        if len(values)==1:
            return values[0]
        else:
            return values
        #end if
    #end def get_output_info


    def generate_filenames(self,infile):
        # derived classes decide how bundled filenames are formed
        self.not_implemented()
    #end def generate_filenames


    def write(self,filepath=None):
        """
        Render or write the bundle.

        With filepath=None, return the bundled filenames, one per line.
        Otherwise write each bundled input under its bundled filename in the
        directory of filepath, then write that list of filenames to
        filepath.  When no filenames are stored yet, they are generated from
        the basename of filepath (with '.xml' appended if missing).
        """
        # TracedQmcpackInput initializes filenames to None, so a membership
        # test alone never triggers generation and indexing None would fail
        if filepath is not None and ('filenames' not in self or self.filenames is None):
            infile = os.path.split(filepath)[1]
            if not infile.endswith('.xml'):
                infile+='.xml'
            #end if
            self.generate_filenames(infile)
        #end if
        ninputs = len(self.inputs)
        listing = ''.join(self.filenames[i]+'\n' for i in range(ninputs))
        if filepath is None:
            return listing
        #end if
        path = os.path.split(filepath)[0]
        for i in range(ninputs):
            self.inputs[i].write(os.path.join(path,self.filenames[i]))
        #end for
        with open(filepath,'w') as fobj:
            fobj.write(listing)
        #end with
    #end def write
#end class BundledQmcpackInput
4167
4168
4169
class TracedQmcpackInput(BundledQmcpackInput):
    """
    Bundle of copies of one QmcpackInput that differ only in the value of a
    single traced quantity (e.g. 'twistnum').

    quantities : obj of records (quantity, range); for each bundle_inputs
                 call, range is the half-open (start, stop) index span of
                 the inputs it added
    variables  : obj of records (quantity, value), one per bundled input
    inputs     : obj of the bundled input copies
    filenames  : None until generate_filenames is called
    """
    def __init__(self,quantity=None,values=None,input=None):
        self.quantities = obj()
        self.variables  = obj()
        self.inputs     = obj()
        self.filenames  = None
        # identity tests, not '!=': values may be a numpy array, for which
        # '!= None' is elementwise and has no single truth value
        if quantity is not None and values is not None and input is not None:
            self.bundle_inputs(quantity,values,input)
        #end if
    #end def __init__

    def bundle_inputs(self,quantity,values,input):
        """
        Append one copy of `input` per entry of `values`, with `quantity` set
        to that entry in each copy.  self.error is called if the quantity
        has no host in the input.
        """
        # local renamed to avoid shadowing the builtin; record key stays 'range'
        index_range = len(self.inputs),len(self.inputs)+len(values)
        self.quantities.append(obj(quantity=quantity,range=index_range))
        for value in values:
            inp = input.copy()
            qhost = inp.get_host(quantity)
            if qhost is None:
                self.error('quantity '+quantity+' was not found in '+input.__class__.__name__)
            else:
                qhost[quantity] = value
            #end if
            self.variables.append(obj(quantity=quantity,value=value))
            self.inputs.append(inp)
        #end for
    #end def bundle_inputs


    def generate_filenames(self,infile):
        """
        Set self.filenames for every bundled input, followed by the main file.

        Bundled files are named <prefix>.gNNN.<quantity>_<value>.<ext>.  The
        prefix is the part of infile before its first '.', and ext is the
        remainder, with '.xml' appended unless it already ends in 'xml'
        (ext is 'xml' if infile has no '.').  The final entry is
        <prefix>.in.
        """
        prefix,_,ext = infile.partition('.')
        if not ext:
            # no extension at all: default to plain xml
            ext = 'xml'
        elif not ext.endswith('xml'):
            ext+='.xml'
        #end if
        self.filenames = []
        for i in range(len(self.variables)):
            var = self.variables[i]
            bfile = prefix+'.g'+str(i).zfill(3)+'.'+var.quantity+'_'+str(var.value)+'.'+ext
            self.filenames.append(bfile)
        #end for
        self.filenames.append(prefix+'.in')
    #end def generate_filenames
#end class TracedQmcpackInput
4214
4215
4216
4217
class QmcpackInputTemplate(SimulationInputTemplate):
    """QMCPACK input supplied by the user as a text template."""

    def preprocess(self,contents,filepath=None):
        """
        Inline single-line <include href="..."/> directives.

        When filepath is given, each line containing both '<include' and
        '/>' is replaced by the text of the referenced file.  The file is
        resolved relative to the directory of filepath, and lines containing
        '<?' (e.g. an XML declaration) are dropped.  Includes whose target
        is not an existing regular file are left as-is.  Every output line is
        terminated with '\\n'.  Without filepath, contents is returned
        unchanged.
        """
        if filepath is None:
            return contents
        #end if
        basepath = os.path.split(filepath)[0]
        out = []
        for line in contents.splitlines():
            if '<include' in line and '/>' in line:
                line = self._expand_include(line,basepath)
            #end if
            out.append(line+'\n')
        #end for
        # join once instead of repeated string concatenation
        return ''.join(out)
    #end def preprocess


    @staticmethod
    def _expand_include(line,basepath):
        # Return the included file's text (minus '<?' lines) in place of the
        # include line, or the line unchanged if no readable target exists.
        tokens = line.replace('<include','').replace('/>','').split()
        for token in tokens:
            if token.startswith('href'):
                # strip only the leading 'href', then '=' and either quote
                # style, so filenames containing 'href' are not mangled
                include_file = token[len('href'):].lstrip('=').strip().strip('"\'')
                include_path = os.path.join(basepath,include_file)
                # isfile (not exists): an empty href would otherwise resolve
                # to the directory itself and fail in open()
                if include_file and os.path.isfile(include_path):
                    with open(include_path,'r') as ifile:
                        icont = ifile.read()+'\n'
                    #end with
                    line = ''.join(iline+'\n' for iline in icont.splitlines() if '<?' not in iline)
                #end if
            #end if
        #end for
        return line
    #end def _expand_include


    def get_output_info(self,*args,**kwargs):
        # just pretend
        return []
    #end def get_output_info
#end class QmcpackInputTemplate
4255
4256
4257
4258
def generate_simulationcell(bconds='ppp',lr_dim_cutoff=15,lr_tol=None,lr_handler=None,system=None):
    '''
    Build a <simulationcell/> element.

    bconds        : boundary conditions, one character per direction; the
                    cell is periodic if any character is 'p'
    lr_dim_cutoff : long-range cutoff, written only for periodic cells
    lr_tol        : written only for periodic cells and when not None
    lr_handler    : written only for periodic cells and when not None
    system        : physical system supplying the cell axes; required for
                    periodic cells, and its structure must have axes

    When axes are available the system is switched to units 'B'.  Jellium
    structures get rs and the electron count.  Other structures get a
    lattice whose axes have been round-tripped through pwscf's text
    formatting.
    '''
    bconds = tuple(bconds)
    sc = simulationcell(bconds=bconds)
    have_axes = system!=None and len(system.structure.axes)>0
    if 'p' in bconds:
        sc.lr_dim_cutoff = lr_dim_cutoff
        if lr_tol is not None:
            sc.lr_tol = lr_tol
        #end if
        if lr_handler is not None:
            sc.lr_handler = lr_handler
        #end if
        if not have_axes:
            QmcpackInput.class_error('invalid axes in generate_simulationcell\nargument system must be provided\naxes of the structure must have non-zero dimension')
        #end if
    #end if
    if not have_axes:
        return sc
    #end if

    system.check_folded_system()
    system.change_units('B')
    structure = system.structure
    if isinstance(structure,Jellium):
        sc.rs         = structure.rs()
        sc.nparticles = system.particles.count_electrons()
        return sc
    #end if

    def pwscf_exact(raw_axes):
        # qmcpack fails if the lattice differs from the pwscf hdf5 file by
        # even 1e-10, so reproduce pwscf's printed values exactly
        exact = array(pwscf_array_string(raw_axes).split(),dtype=float)
        exact.shape = raw_axes.shape
        return exact
    #end def pwscf_exact

    folded = structure.folded_structure
    if folded!=None:
        # tile the exactly-formatted folded axes and check against the supercell
        axes = dot(structure.tmatrix,pwscf_exact(folded.axes))
        if abs(axes-structure.axes).sum()>1e-5:
            QmcpackInput.class_error('in generate_simulationcell\nsupercell axes do not match tiled version of folded cell axes\nyou may have changed one set of axes (super/folded) and not the other\nfolded cell axes:\n'+str(folded.axes)+'\nsupercell axes:\n'+str(structure.axes)+'\nfolded axes tiled:\n'+str(axes))
        #end if
    else:
        axes = pwscf_exact(structure.axes)
    #end if
    structure.adjust_axes(axes)
    sc.lattice = axes
    return sc
#end def generate_simulationcell
4306
4307
def generate_particlesets(electrons   = 'e',
                          ions        = 'ion0',
                          up          = 'u',
                          down        = 'd',
                          system      = None,
                          randomsrc   = False,
                          hybrid_rcut = None,
                          hybrid_lmax = None,
                          ):
    '''
    Build the electron and (if any ions are present) ion <particleset/>
    elements.

    Parameters
    ----------
    electrons, ions : names of the electron and ion particlesets
    up, down        : group names of the up- and down-spin electrons
    system          : physical system providing particles and structure
                      (required); system.change_units('B') is called on it
    randomsrc       : if True, the electron particleset gets
                      randomsrc=<ions name>
    hybrid_rcut, hybrid_lmax : obj keyed by ion species giving the per-species
                      cutoff_radius and lmax of the hybrid representation;
                      if either is given, both must be obj objects with
                      exactly one entry per species present

    Returns
    -------
    collection of particlesets: electrons first, then ions (if present)
    '''
    if system is None:
        QmcpackInput.class_error('generate_particlesets argument system must not be None')
    #end if

    ename = electrons
    iname = ions
    uname = up
    dname = down

    # free the argument names: 'ions' is rebound to the ion particles below
    del electrons
    del ions
    del up
    del down

    system.check_folded_system()
    system.change_units('B')

    particles  = system.particles
    structure  = system.structure
    net_charge = system.net_charge
    net_spin   = system.net_spin

    elns = particles.get_electrons()
    ions = particles.get_ions()
    eup  = elns.up_electron
    edn  = elns.down_electron

    particlesets = []
    eps = particleset(
        name   = ename,
        random = True,
        groups = [
            group(name=uname,charge=-1,mass=eup.mass,size=eup.count),
            group(name=dname,charge=-1,mass=edn.mass,size=edn.count)
            ]
        )
    particlesets.append(eps)
    if len(ions)>0:
        # maintain consistent order
        ion_species,ion_counts = structure.order_by_species()
        elem = structure.elem
        pos  = structure.pos
        if randomsrc:
            eps.randomsrc = iname
        #end if
        ips = particleset(name=iname)
        # handle hybrid rep
        hybridrep = hybrid_rcut is not None or hybrid_lmax is not None
        if hybridrep:
            hybrid_vars = (
                ('hybrid_rcut',hybrid_rcut),
                ('hybrid_lmax',hybrid_lmax),
                )
            for hvar,hval in hybrid_vars:
                if not isinstance(hval,obj):
                    QmcpackInput.class_error('generate_particlesets argument "{0}" must be of type obj\nyou provided type: {1}\nwith value: {2}'.format(hvar,hval.__class__.__name__,hval))
                #end if
                if set(hval.keys())!=set(ion_species):
                    QmcpackInput.class_error('generate_particlesets argument "{0}" is incorrect\none entry must be present for each atomic species\natomic species present in the simulation: {1}\nvalues provided for the following species: {2}'.format(hvar,sorted(ion_species),sorted(hval.keys())))
                #end if
            #end for
        #end if
        # make groups, one per species, holding that species' positions
        groups = []
        for ion_spec in ion_species:
            ion = ions[ion_spec]
            gpos = pos[elem==ion.name]
            g = group(
                name         = ion.name,
                charge       = ion.charge,
                valence      = ion.charge,
                atomicnumber = ion.protons,
                mass         = ion.mass,
                position     = gpos,
                size         = len(gpos)
                )
            if hybridrep:
                g.lmax           = hybrid_lmax[ion_spec]
                g.cutoff_radius  = hybrid_rcut[ion_spec]
            #end if
            groups.append(g)
        #end for
        ips.groups = make_collection(groups)
        particlesets.append(ips)
    #end if
    particlesets = make_collection(particlesets)
    return particlesets
#end def generate_particlesets
4405
4406
def generate_sposets(type           = None,
                     occupation     = None,
                     spin_polarized = False,
                     nup            = None,
                     ndown          = None,
                     spo_up         = 'spo_u',
                     spo_down       = 'spo_d',
                     system         = None,
                     sposets        = None,
                     spindatasets   = False):
    '''
    Build the <sposet/> elements used by a sposet builder.

    If `sposets` is given, each of its entries has its type set to `type`
    and the entries are returned as a collection.

    Otherwise `occupation` must be 'slater_ground'.  Orbital counts come
    from `nup`/`ndown` or, if either is missing, from the electrons of
    `system`.
      - Not spin polarized with equal counts: one shared 'spo_ud' sposet.
      - Not spin polarized with unequal counts: `spo_up`/`spo_down`
        sposets, both reading spindataset 0.
      - Spin polarized: `spo_up` reads spindataset 0 and `spo_down` reads
        spindataset 1.
    In this mode the spindataset attribute is removed unless
    `spindatasets` is True.

    Returns a collection of the sposets.
    '''
    ndn = ndown
    if type is None:
        QmcpackInput.class_error('cannot generate sposets\n  type of sposet not specified')
    #end if
    if sposets is not None:
        # user-supplied sposets: only stamp the requested type onto each
        for spo in sposets:
            spo.type = type
        #end for
    elif occupation=='slater_ground':
        have_counts = not (nup is None or ndown is None)
        if system is None and not have_counts:
            QmcpackInput.class_error('cannot generate sposets in occupation mode {0}\n  arguments nup & ndown or system must be given to generate_sposets'.format(occupation))
        elif not have_counts:
            elns = system.particles.get_electrons()
            nup  = elns.up_electron.count
            ndn  = elns.down_electron.count
        #end if
        if not spin_polarized:
            if nup==ndn:
                sposets = [sposet(type=type,name='spo_ud',spindataset=0,size=nup)]
            else:
                sposets = [sposet(type=type,name=spo_up,  spindataset=0,size=nup),
                           sposet(type=type,name=spo_down,spindataset=0,size=ndn)]
            #end if
        else:
            sposets = [sposet(type=type,name=spo_up,  spindataset=0,size=nup),
                       sposet(type=type,name=spo_down,spindataset=1,size=ndn)]
        #end if
        if not spindatasets:
            for spo in sposets:
                del spo.spindataset
            #end for
        #end if
    else:
        QmcpackInput.class_error('cannot generate sposets in occupation mode {0}\n  generate_sposets currently supports the following occupation modes:\n  slater_ground'.format(occupation))
    #end if
    return make_collection(sposets)
#end def generate_sposets
4455
4456
def generate_sposet_builder(type,*args,**kwargs):
    '''
    Dispatch to the sposet builder generator matching `type`.

    'bspline' or 'einspline' -> generate_bspline_builder (type forwarded)
    'heg'                    -> generate_heg_builder (type not forwarded)
    Any other type is reported through QmcpackInput.class_error.
    '''
    if type in ('bspline','einspline'):
        return generate_bspline_builder(type,*args,**kwargs)
    #end if
    if type=='heg':
        return generate_heg_builder(*args,**kwargs)
    #end if
    QmcpackInput.class_error('cannot generate sposet_builder\n  sposet_builder of type {0} is unrecognized'.format(type))
#end def generate_sposet_builder
4466
4467
def generate_bspline_builder(type           = 'bspline',
                             meshfactor     = 1.0,
                             precision      = 'float',
                             twistnum       = None,
                             twist          = None,
                             sort           = None,
                             version        = '0.10',
                             truncate       = False,
                             buffer         = None,
                             spin_polarized = False,
                             hybridrep      = None,
                             href           = 'MISSING.h5',
                             ions           = 'ion0',
                             spo_up         = 'spo_u',
                             spo_down       = 'spo_d',
                             sposets        = None,
                             system         = None
                             ):
    '''
    Build a bspline/einspline sposet builder.

    The tilematrix comes from `system` (identity if no system is given).
    Sposets are made by generate_sposets in 'slater_ground' mode with
    spindataset attributes kept; `spo_up`/`spo_down` name the separate up
    and down sposets when those are needed.

    Optional attributes: sort (if given), buffer (only when `truncate`
    is set), and hybridrep (if given).

    Twist selection: `twist` is resolved with system.structure.select_twist
    (requires `system`); otherwise `twistnum` is used; otherwise 0 when the
    system has exactly one k-point; otherwise None.
    '''
    tilematrix = identity(3,dtype=int)
    if system is not None:
        tilematrix = system.structure.tilematrix()
    #end if
    bsb = bspline_builder(
        type       = type,
        meshfactor = meshfactor,
        precision  = precision,
        tilematrix = tilematrix,
        href       = href,
        version    = version,
        truncate   = truncate,
        source     = ions,
        sposets    = generate_sposets(
            type           = type,
            occupation     = 'slater_ground',
            spin_polarized = spin_polarized,
            spo_up         = spo_up,
            spo_down       = spo_down,
            system         = system,
            sposets        = sposets,
            spindatasets   = True
            )
        )
    if sort is not None:
        bsb.sort = sort
    #end if
    if truncate and buffer is not None:
        bsb.buffer = buffer
    #end if
    if hybridrep is not None:
        bsb.hybridrep = hybridrep
    #end if
    # 'is not None' so array-valued twists are not compared elementwise
    if twist is not None:
        if system is None:
            QmcpackInput.class_error('generate_bspline_builder argument "twist" requires argument "system"')
        #end if
        bsb.twistnum = system.structure.select_twist(twist)
    elif twistnum is not None:
        bsb.twistnum = twistnum
    elif system is not None and len(system.structure.kpoints)==1:
        bsb.twistnum = 0
    else:
        bsb.twistnum = None
    #end if
    return bsb
#end def generate_bspline_builder
4528
4529
def generate_heg_builder(twist          = None,
                         spin_polarized = False,
                         spo_up         = 'spo_u',
                         spo_down       = 'spo_d',
                         sposets        = None,
                         system         = None
                         ):
    '''
    Build a homogeneous electron gas ('heg') sposet builder.

    twist          : optional twist, stored on the builder as a tuple
    spin_polarized : forwarded to generate_sposets
    spo_up/spo_down: names of the separate up/down sposets when needed
    sposets        : optional pre-built sposets (their type is set to 'heg')
    system         : physical system used to count electrons when
                     `sposets` is not given
    '''
    type = 'heg'
    hb = heg_builder(
        type    = type,
        sposets = generate_sposets(
            type           = type,
            occupation     = 'slater_ground',
            spin_polarized = spin_polarized,
            spo_up         = spo_up,
            spo_down       = spo_down,
            system         = system,
            sposets        = sposets
            )
        )
    # 'is not None' so array-valued twists are not compared elementwise
    if twist is not None:
        hb.twist = tuple(twist)
    #end if
    return hb
#end def generate_heg_builder
4553
4554
def partition_sposets(sposet_builder,partition,partition_meshfactors=None):
    '''
    Split every sposet of a builder into orbital-index partitions and
    reassemble each one through a composite builder.

    Parameters
    ----------
    sposet_builder : builder whose .sposets are partitioned; it is modified
                     in place and ends up holding the partition sposets
    partition      : either a sequence of starting orbital indices (used in
                     the given order, expected increasing), or a dict/obj
                     mapping starting index -> attributes to set on that
                     partition's sposet (keys are sorted)
    partition_meshfactors : optional meshfactors, one per partition in
                     index order.  With a dict/obj `partition`, the values
                     of `partition` are updated in place with .meshfactor.

    A partition starting at index k is named <spo name>_<k>.  The first
    partition (k == 0) uses `size`; later ones use index_min/index_max.
    Partitions starting at or beyond a sposet's size are skipped for that
    sposet.

    Returns
    -------
    [ssb,cssb] : the modified input builder and a new 'composite' builder
                 whose sposets reassemble each original sposet (same name
                 and size) from its partitions
    '''
    ssb = sposet_builder
    spos_in =ssb.sposets
    del ssb.sposets
    if isinstance(partition,(dict,obj)):
        partition_indices  = sorted(partition.keys())
        partition_contents = partition
    else:
        partition_indices  = list(partition)
        partition_contents = None
    #end if
    if partition_meshfactors is not None:
        if partition_contents is None:
            partition_contents = obj()
            for p in partition_indices:
                partition_contents[p] = obj()
            #end for
        #end if
        for p,mf in zip(partition_indices,partition_meshfactors):
            partition_contents[p].meshfactor = mf
        #end for
    #end if
    # partition each spo in the builder and create a corresponding composite spo
    comp_spos = []
    part_spos = []
    for spo in spos_in.list():
        part_spo_names = []
        part_ranges = partition_indices+[spo.size]
        for i in range(len(partition_indices)):
            index_min = part_ranges[i]
            index_max = part_ranges[i+1]
            if index_min>=spo.size:
                # nothing left to assign; avoid creating an empty sposet
                break
            elif index_max>spo.size:
                index_max = spo.size
            #end if
            part_spo_name = spo.name+'_'+str(index_min)
            part_spo = sposet(**spo)
            part_spo.name = part_spo_name
            if index_min==0:
                part_spo.size = index_max
            else:
                part_spo.index_min = index_min
                part_spo.index_max = index_max
                del part_spo.size
            #end if
            if partition_contents is not None:
                part_spo.set(**partition_contents[index_min])
            #end if
            part_spos.append(part_spo)
            part_spo_names.append(part_spo_name)
        #end for
        comp_spo = sposet(
            name = spo.name,
            size = spo.size,
            spos = part_spo_names,
            )
        comp_spos.append(comp_spo)
    #end for

    ssb.sposets = make_collection(part_spos)

    cssb = composite_builder(
        type = 'composite',
        sposets = make_collection(comp_spos),
        )

    return [ssb,cssb]
#end def partition_sposets
4624
4625
def generate_determinantset(up             = 'u',
                            down           = 'd',
                            spo_up         = 'spo_u',
                            spo_down       = 'spo_d',
                            spin_polarized = False,
                            system         = None
                            ):
    '''
    Build a <determinantset/> with a Slater determinant of two
    determinants, 'updet' for group `up` and 'downdet' for group `down`,
    sized by the up/down electron counts of `system` (required).

    When the system is not spin polarized and both spins have the same
    count, both determinants reference the shared 'spo_ud' sposet.
    Otherwise they reference `spo_up` and `spo_down`.
    '''
    if system is None:
        QmcpackInput.class_error('generate_determinantset argument system must not be None')
    #end if
    electrons = system.particles.get_electrons()
    n_up   = electrons.up_electron.count
    n_down = electrons.down_electron.count
    if spin_polarized or n_up!=n_down:
        up_sponame,down_sponame = spo_up,spo_down
    else:
        up_sponame = down_sponame = 'spo_ud'
    #end if
    updet = determinant(
        id     = 'updet',
        group  = up,
        sposet = up_sponame,
        size   = n_up
        )
    downdet = determinant(
        id     = 'downdet',
        group  = down,
        sposet = down_sponame,
        size   = n_down
        )
    sdet = slaterdeterminant(determinants=collection(updet,downdet))
    return determinantset(slaterdeterminant=sdet)
#end def generate_determinantset
4666
4667
def generate_determinantset_old(type           = 'bspline',
                                meshfactor     = 1.0,
                                precision      = 'float',
                                twistnum       = None,
                                twist          = None,
                                spin_polarized = False,
                                hybridrep      = None,
                                source         = 'ion0',
                                href           = 'MISSING.h5',
                                excitation     = None,
                                system         = None
                                ):
    '''
    Legacy generator for a <determinantset/> that carries the orbital-file
    settings (type, meshfactor, precision, tilematrix, href, source)
    directly.  The up and down determinants use ground-state occupation,
    with spindataset 1 for down electrons only when `spin_polarized`.

    Twist selection: `twist` is resolved via system.structure.select_twist;
    otherwise `twistnum` is used; otherwise 0 when the structure has a
    single k-point; otherwise None.

    hybridrep may be 'yes'/'no' or any key of yesno_dict.

    excitation, if given, is a 2-element tuple/list (spin, spec) with spin
    'up' or 'down' and spec one of
      1. '0 45 3 46'          integers, written with format='band'
      2. '-215 216'           signed integers, written with format='energy'
      3. 'gamma vb x cb'      special k-point labels with vb/cb band labels
         (optionally vb-n / cb+n), translated to type-1 band format
    The occupation of the matching determinant is switched to 'excited'.
    '''
    if system is None:
        QmcpackInput.class_error('generate_determinantset argument system must not be None')
    #end if
    elns = system.particles.get_electrons()
    down_spin = 0
    if spin_polarized:
        down_spin=1
    #end if
    tilematrix = system.structure.tilematrix()
    dset = determinantset(
        type       = type,
        meshfactor = meshfactor,
        precision  = precision,
        tilematrix = tilematrix,
        href       = href,
        source     = source,
        slaterdeterminant = slaterdeterminant(
            determinants = collection(
                determinant(
                    id   = 'updet',
                    size = elns.up_electron.count,
                    occupation=section(mode='ground',spindataset=0)
                    ),
                determinant(
                    id   = 'downdet',
                    size = elns.down_electron.count,
                    occupation=section(mode='ground',spindataset=down_spin)
                    )
                )
            )
        )
    if twist is not None:
        dset.twistnum = system.structure.select_twist(twist)
    elif twistnum is not None:
        dset.twistnum = twistnum
    elif len(system.structure.kpoints)==1:
        dset.twistnum = 0
    else:
        dset.twistnum = None
    #end if
    if hybridrep is not None:
        if hybridrep=='yes' or hybridrep=='no':
            dset.hybridrep = hybridrep
        else:
            dset.hybridrep = yesno_dict[hybridrep]
        #end if
    #end if
    if excitation is None:
        return dset
    #end if

    # validate the excitation format; check the length before indexing
    format_failed = False
    if not isinstance(excitation,(tuple,list)):
        QmcpackInput.class_error('excitation must be a tuple or list\nyou provided type: {0}\nwith value: {1}'.format(excitation.__class__.__name__,excitation))
    elif len(excitation)!=2:
        format_failed = True
    elif excitation[0] not in ('up','down') or not isinstance(excitation[1],str):
        format_failed = True
    elif 'cb' not in excitation[1] and 'vb' not in excitation[1]:
        # types 1 and 2 must consist purely of integers
        try:
            array(excitation[1].split(),dtype=int)
        except (ValueError,TypeError):
            format_failed = True
        #end try
    #end if
    if format_failed:
        QmcpackInput.class_error('excitation must be a tuple or list with two elements\nthe first element must be either "up" or "down"\nand the second element must be integers separated by spaces, e.g. "-216 +217",\nor special k-point and band labels, e.g. "gamma vb x cb"\nyou provided: {0}'.format(excitation))
    #end if

    spin_channel,excitation = excitation

    if spin_channel=='up':
        det = dset.get('updet')
    elif spin_channel=='down':
        det = dset.get('downdet')
    #end if
    occ = det.occupation
    occ.pairs    = 1
    occ.mode     = 'excited'
    occ.contents = '\n'+excitation+'\n'
    if 'cb' in excitation or 'vb' in excitation: #Type 3
        # assume excitation of form 'gamma vb k cb' or 'gamma vb-1 k cb+1'
        tokens = excitation.upper().split()
        if len(tokens)!=4:
            QmcpackInput.class_error('excitation with vb-cb band format works only with special k-points')
        #end if
        k_1, band_1, k_2, band_2 = tokens

        # bands per primitive cell; round the determinant because
        # linalg.det returns a float that may be off by ~1e-15
        ncells = int(round(abs(linalg.det(tilematrix))))
        vb = det.size//ncells - 1  # Separate for each spin channel
        cb = vb+1

        def band_index(label):
            # convert 'VB', 'VB-n', 'VB+n', 'CB', ... to an absolute band index
            if 'CB' in label:
                ref = cb
            elif 'VB' in label:
                ref = vb
            else:
                QmcpackInput.class_error('{0} in excitation has the wrong formatting'.format(label))
            #end if
            if '-' in label:
                return ref - int(label.split('-')[1])
            elif '+' in label:
                return ref + int(label.split('+')[1])
            else:
                return ref
            #end if
        #end def band_index

        band_1 = band_index(band_1)
        band_2 = band_index(band_2)

        # Convert k_1 k_2 to wavevector indexes
        if system.structure.folded_structure is None:
            QmcpackInput.class_error('excitation with vb-cb band format requires a system with a folded structure')
        #end if
        structure   = system.structure.folded_structure.copy()
        structure.change_units('A')
        kpath       = get_kpath(structure=structure)
        kpath_label = array(kpath['explicit_kpoints_labels'])
        kpath_rel   = kpath['explicit_kpoints_rel']

        if k_1 in kpath_label and k_2 in kpath_label:
            k1_vec = kpath_rel[where(kpath_label == k_1)][0]
            k2_vec = kpath_rel[where(kpath_label == k_2)][0]

            # keep the search vectors separate from the resulting indices so
            # later comparisons are always made against the k-vectors
            kpts = structure.kpoints_unit()
            k1_index = None
            k2_index = None
            for knum, k in enumerate(kpts):
                if k1_index is None and isclose(k1_vec, k).all():
                    k1_index = knum
                #end if
                if k2_index is None and isclose(k2_vec, k).all():
                    k2_index = knum
                #end if
            #end for
            found_k1 = k1_index is not None
            found_k2 = k2_index is not None
            if not found_k1 or not found_k2:
                QmcpackInput.class_error('Requested special kpoint is not in the tiled cell\nRequested "{}", present={}\nRequested "{}", present={}\nAvailable kpoints: {}'.format(k_1,found_k1,k_2,found_k2,sorted(set(kpath_label))))
            #end if
            k_1 = k1_index
            k_2 = k2_index
        else:
            QmcpackInput.class_error('Excitation wavevectors are not found in the kpath\nlabels requested: {} {}\nlabels present: {}'.format(k_1,k_2,sorted(set(kpath_label))))
        #end if

        #Write everything in band (ti,bi) format
        occ.contents = '\n'+str(k_1)+' '+str(band_1)+' '+str(k_2)+' '+str(band_2)+'\n'
        occ.format = 'band'

    elif '-' in excitation or '+' in excitation: #Type 2
        # assume excitation of form '-216 +217'
        occ.format = 'energy'
    else: #Type 1
        # assume excitation of form '6 36 6 37'
        occ.format   = 'band'
    #end if
    return dset
#end def generate_determinantset_old
4858
4859
4860def generate_hamiltonian(name         = 'h0',
4861                         type         = 'generic',
4862                         electrons    = 'e',
4863                         ions         = 'ion0',
4864                         wavefunction = 'psi0',
4865                         pseudos      = None,
4866                         dla          = None,
4867                         format       = 'xml',
4868                         estimators   = None,
4869                         system       = None,
4870                         interactions = 'default',
4871                         ):
4872    if system is None:
4873        QmcpackInput.class_error('generate_hamiltonian argument system must not be None')
4874    #end if
4875
4876    ename   = electrons
4877    iname   = ions
4878    wfname  = wavefunction
4879    ppfiles = pseudos
4880    del electrons
4881    del ions
4882    del pseudos
4883    del wavefunction
4884
4885    particles = system.particles
4886    if particles.count_electrons()==0:
4887        QmcpackInput.class_error('cannot generate hamiltonian, no electrons present')
4888    #end if
4889
4890    pairpots = []
4891    if interactions!=None:
4892        pairpots.append(coulomb(name='ElecElec',type='coulomb',source=ename,target=ename))
4893        if particles.count_ions()>0:
4894            pairpots.append(coulomb(name='IonIon',type='coulomb',source=iname,target=iname))
4895            ions = particles.get_ions()
4896            if not system.pseudized:
4897                pairpots.append(coulomb(name='ElecIon',type='coulomb',source=iname,target=ename))
4898            else:
4899                if ppfiles is None or len(ppfiles)==0:
4900                    QmcpackInput.class_error('cannot generate hamiltonian\n  system is pseudized, but no pseudopotentials have been provided\n  please provide pseudopotential files via the pseudos keyword')
4901                #end if
4902                if isinstance(ppfiles,list):
4903                    pplist = ppfiles
4904                    ppfiles = obj()
4905                    for pppath in pplist:
4906                        if '/' in pppath:
4907                            ppfile = pppath.split('/')[-1]
4908                        else:
4909                            ppfile = pppath
4910                        #end if
4911                        element = ppfile.split('.')[0]
4912                        if len(element)>2:
4913                            element = element[0:2]
4914                        #end if
4915                        ppfiles[element] = pppath
4916                    #end for
4917                #end if
4918                pseudos = collection()
4919                for ion in ions:
4920                    label = ion.name
4921                    iselem,symbol = is_element(ion.name,symbol=True)
4922                    if label in ppfiles:
4923                        ppfile = ppfiles[label]
4924                    elif symbol in ppfiles:
4925                        ppfile = ppfiles[symbol]
4926                    else:
4927                        QmcpackInput.class_error('pseudos provided to generate_hamiltonian are incomplete\n  a pseudopotential for ion of type {0} is missing\n  pseudos provided:\n{1}'.format(ion.name,str(ppfiles)))
4928                    #end if
4929                    pseudos.add(pseudo(elementtype=label,href=ppfile))
4930                #end for
4931                pp = pseudopotential(name='PseudoPot',type='pseudo',source=iname,wavefunction=wfname,format=format,pseudos=pseudos)
4932                if dla is not None:
4933                    pp.dla = dla
4934                #end if
4935                pairpots.append(pp)
4936            #end if
4937        #end if
4938    #end if
4939
4940    ests = []
4941    if estimators!=None:
4942        for estimator in estimators:
4943            if isinstance(estimator,QIxml):
4944                estimator = estimator.copy()
4945            #end if
4946            est=estimator
4947            if isinstance(estimator,str):
4948                estname = estimator.lower().replace(' ','_').replace('-','_').replace('__','_')
4949                if estname=='mpc':
4950                    pairpots.append(mpc(name='MPC',type='MPC',ecut=60.0,source=ename,target=ename,physical=False))
4951                    est = None
4952                elif estname=='chiesa':
4953                    est = chiesa(name='KEcorr',type='chiesa',source=ename,psi=wfname)
4954                elif estname=='localenergy':
4955                    est = localenergy(name='LocalEnergy')
4956                elif estname=='skall':
4957                    est = skall(name='SkAll',type='skall',source=iname,target=ename,hdf5=True)
4958                elif estname=='energydensity':
4959                    est = energydensity(
4960                        type='EnergyDensity',name='EDvoronoi',dynamic=ename,static=iname,
4961                        spacegrid = spacegrid(coord='voronoi')
4962                        )
4963                elif estname=='pressure':
4964                    est = pressure(type='Pressure')
4965                else:
4966                    QmcpackInput.class_error('estimator '+estimator+' has not yet been enabled in generate_basic_input')
4967                #end if
4968            elif not isinstance(estimator,QIxml):
4969                QmcpackInput.class_error('generate_hamiltonian received an invalid estimator\n  an estimator must either be a name or a QIxml object\n  inputted estimator type: {0}\n  inputted estimator contents: {1}'.format(estimator.__class__.__name__,estimator))
4970            elif isinstance(estimator,energydensity):
4971                est.set_optional(
4972                    type = 'EnergyDensity',
4973                    dynamic = ename,
4974                    static  = iname,
4975                    )
4976            elif isinstance(estimator,dm1b):
4977                dm = estimator
4978                reuse = False
4979                if 'reuse' in dm:
4980                    reuse = bool(dm.reuse)
4981                    del dm.reuse
4982                #end if
4983                basis = []
4984                builder = None
4985                maxed = False
4986                if reuse and 'basis' in dm and isinstance(dm.basis,sposet):
4987                    spo = dm.basis
4988                    # get sposet size
4989                    if 'size' in dm.basis:
4990                        size = spo.size
4991                        del spo.size
4992                    elif 'index_max' in dm.basis:
4993                        size = spo.index_max
4994                        del spo.index_max
4995                    else:
4996                        QmcpackInput.class_error('cannot generate estimator dm1b\n  basis sposet provided does not have a "size" attribute')
4997                    #end if
4998                    try:
4999                        # get sposet from wavefunction
5000                        wf = QIcollections.get('wavefunctions',wfname)
5001                        dets = wf.get('determinant')
5002                        det  = dets.get_single()
5003                        if 'sposet' in det:
5004                            rsponame = det.sposet
5005                        else:
5006                            rsponame = det.id
5007                        #end if
5008                        builders = QIcollections.get('sposet_builders')
5009                        rspo = None
5010                        for bld in builders:
5011                            if rsponame in bld.sposets:
5012                                builder = bld
5013                                rspo    = bld.sposets[rsponame]
5014                                break
5015                            #end if
5016                        #end for
5017                        basis.append(rsponame)
5018                        # adjust current sposet
5019                        spo.index_min = rspo.size
5020                        spo.index_max = size
5021                        maxed = rspo.size>=size
5022                    except Exception as e:
5023                        msg = 'cannot generate estimator dm1b\n  '
5024                        if wf is None:
5025                            QmcpackInput.class_error(msg+'wavefunction {0} not found'.format(wfname))
5026                        elif dets is None or det is None:
5027                            QmcpackInput.class_error(msg+'determinant not found')
5028                        elif builders is None:
5029                            QmcpackInput.class_error(msg+'sposet_builders not found')
5030                        elif rspo is None:
5031                            QmcpackInput.class_error(msg+'sposet {0} not found'.format(rsponame))
5032                        else:
5033                            QmcpackInput.class_error(msg+'cause of failure could not be determined\n  see the following error message:\n{0}'.format(e))
5034
5035                        #end if
5036                    #end if
5037                #end if
5038                # put the basis sposet in the appropriate builder
5039                if isinstance(dm.basis,sposet) and not maxed:
5040                    spo = dm.basis
5041                    del dm.basis
5042                    if not 'type' in spo:
5043                        QmcpackInput.class_error('cannot generate estimator dm1b\n  basis sposet provided does not have a "type" attribute')
5044                    #end if
5045                    if not 'name' in spo:
5046                        spo.name = 'spo_dm'
5047                    #end if
5048                    builders = QIcollections.get('sposet_builders')
5049                    if not spo.type in builders:
5050                        bld = generate_sposet_builder(spo.type,sposets=[spo])
5051                        builders.add(bld)
5052                    else:
5053                        bld = builders[spo.type]
5054                        bld.sposets.add(spo)
5055                    #end if
5056                    basis.append(spo.name)
5057                #end if
5058                dm.basis = basis
5059                dm.incorporate_defaults(elements=False,overwrite=False,propagate=False)
5060            #end if
5061            if est!=None:
5062                ests.append(est)
5063            #end if
5064        #end for
5065    #end if
5066    estimators = ests
5067
5068    hmltn = hamiltonian(
5069        name   = name,
5070        type   = type,
5071        target = ename
5072        )
5073
5074    if len(pairpots)>0:
5075        hmltn.pairpots = make_collection(pairpots)
5076    #end if
5077
5078    if len(estimators)>0:
5079        hmltn.estimators = make_collection(estimators)
5080    #end if
5081
5082    return hmltn
5083#end def generate_hamiltonian
5084
5085
5086
def generate_jastrows(jastrows,system=None,return_list=False,check_ions=False):
    """
    Generate Jastrow factors from a compact request.

    Parameters
    ----------
    jastrows : str or iterable
        Either a string such as 'generate12', where each of the characters
        '1', '2', '3' and 'k' present requests a default J1, J2, J3 or
        k-space Jastrow, or an iterable whose entries are QIxml Jastrow
        objects, dict/obj specifications carrying a 'type' of 'J1', 'J2' or
        'J3' (remaining fields are forwarded to generate_jastrow, with an
        optional per-entry 'system'), or descriptors whose first entry is
        'J1', 'J2' or 'J3'.
    system : optional
        Physical system handed to the generators; required for 'k'.
    return_list : bool
        Return the plain list of Jastrows instead of the pluralized
        ``wavefunction(jastrows=...).jastrows`` collection.
    check_ions : bool
        When True and a system is given, the string form skips J1 and J3
        if the system contains no ions.
    """
    jin = []
    have_ions = True
    if check_ions and system is not None:
        have_ions = system.particles.count_ions()>0
    #end if
    if isinstance(jastrows,str):
        jorders = set(jastrows.replace('generate',''))
        requested = []
        if '1' in jorders and have_ions:
            requested.append(generate_jastrow('J1','bspline',8,system=system))
        #end if
        if '2' in jorders:
            requested.append(generate_jastrow('J2','bspline',8,system=system))
        #end if
        if '3' in jorders and have_ions:
            requested.append(generate_jastrow('J3','polynomial',3,3,4.0,system=system))
        #end if
        if 'k' in jorders:
            kcut = max(system.rpa_kf())
            nksh = system.structure.count_kshells(kcut)
            requested.append(generate_kspace_jastrow(kc1=0, kc2=kcut, nk1=0, nk2=nksh))
        #end if
        if len(requested)==0:
            QmcpackInput.class_error('jastrow generation requested but no orders specified (1,2,and/or 3)')
        #end if
        # keep every requested term; a generator may return None (e.g. a J2
        # for a system with no electron pairs), which is dropped
        jin.extend(j for j in requested if j is not None)
    else:
        jset = set(['J1','J2','J3'])
        for jastrow in jastrows:
            if isinstance(jastrow,QIxml):
                jin.append(jastrow)
            elif isinstance(jastrow,dict) or isinstance(jastrow,obj):
                jdict = dict(**jastrow)
                if not 'type' in jastrow:
                    QmcpackInput.class_error("could not determine jastrow type from input\n  field 'type' must be 'J1', 'J2', or 'J3'\n  object you provided: "+str(jastrow))
                #end if
                jtype = jdict['type']
                if not jtype in jset:
                    QmcpackInput.class_error("invalid jastrow type provided\n  field 'type' must be 'J1', 'J2', or 'J3'\n  object you provided: "+str(jdict))
                #end if
                del jdict['type']
                # an entry may carry its own system, otherwise use the shared one
                jsys = jdict.pop('system',system)
                jterm = generate_jastrow(jtype,system=jsys,**jdict)
                if jterm is not None:
                    jin.append(jterm)
                #end if
            elif jastrow[0] in jset:
                jterm = generate_jastrow(jastrow,system=system)
                if jterm is not None:
                    jin.append(jterm)
                #end if
            else:
                QmcpackInput.class_error('starting jastrow unrecognized:\n  '+str(jastrow))
            #end if
        #end for
    #end if
    if return_list:
        return jin
    #end if
    wf = wavefunction(jastrows=jin)
    wf.pluralize()
    return wf.jastrows
#end def generate_jastrows
5155
5156
5157
def generate_jastrows_alt(
    J1           = False,
    J2           = False,
    J3           = False,
    J1_size      = None,
    J1_rcut      = None,
    J1_dr        = 0.5,
    J2_size      = None,
    J2_rcut      = None,
    J2_dr        = 0.5,
    J2_init      = 'zero',
    J3_isize     = 3,
    J3_esize     = 3,
    J3_rcut      = 5.0,
    J1_rcut_open = 5.0,
    J2_rcut_open = 10.0,
    system       = None,
    ):
    """
    Generate B-spline J1/J2 and polynomial J3 Jastrows for a system.

    Requesting J3 implies J2, and J2 implies J1; with nothing requested,
    J1 and J2 are produced.  Missing cutoffs default to the *_rcut_open
    values for open systems and to structure.rwigner(1) otherwise (J3_rcut
    is capped at that radius for non-open systems).  Missing sizes are
    ceil(rcut/dr).  The system is required; a Bohr-unit copy is used when it
    is not already in Bohr units.  Returns a list of Jastrow elements.
    """
    if system is None:
        QmcpackInput.class_error('input variable "system" is required to generate jastrows','generate_jastrows_alt')
    elif system.structure.units!='B':
        # operate on a Bohr copy, leaving the caller's system as it was
        system = system.copy()
        system.structure.change_units('B')
    #end if

    openbc = system.structure.is_open()

    # Wigner radius: computed on first demand only, then reused
    wigner_cache = []
    def wigner_radius():
        if not wigner_cache:
            wigner_cache.append(system.structure.rwigner(1))
        #end if
        return wigner_cache[0]
    #end def wigner_radius

    # propagate implications downward: J3 => J2 => J1
    J2 = J2 | J3
    J1 = J1 | J2
    if not (J1 or J2 or J3):
        J1 = J2 = True
    #end if

    result = []
    if J1:
        if J1_rcut is None:
            J1_rcut = J1_rcut_open if openbc else wigner_radius()
        #end if
        if J1_size is None:
            J1_size = int(ceil(J1_rcut/J1_dr))
        #end if
        result.append(generate_jastrow('J1','bspline',J1_size,J1_rcut,system=system))
    #end if
    if J2:
        if J2_rcut is None:
            J2_rcut = J2_rcut_open if openbc else wigner_radius()
        #end if
        if J2_size is None:
            J2_size = int(ceil(J2_rcut/J2_dr))
        #end if
        result.append(generate_jastrow('J2','bspline',J2_size,J2_rcut,init=J2_init,system=system))
    #end if
    if J3:
        if not openbc:
            J3_rcut = min(J3_rcut,wigner_radius())
        #end if
        result.append(generate_jastrow('J3','polynomial',J3_esize,J3_isize,J3_rcut,system=system))
    #end if

    return result
#end def generate_jastrows_alt
5240
5241
def generate_jastrow(descriptor,*args,**kwargs):
    """
    Generate a single J1, J2 or J3 Jastrow from a compact descriptor.

    ``descriptor`` is either a type name ('J1', 'J2' or 'J3') or a sequence
    whose first entry is the type name, followed by optional positional
    arguments for the type-specific generator and then keyword/value pairs,
    e.g. ('J1','bspline',8,'rcut',4.0).  Positional entries found in the
    descriptor replace ``args``; keyword pairs are merged into ``kwargs``.
    The call is forwarded to generate_jastrow1/2/3 and its result returned.

    Note: when ``system`` is given, system.change_units('B') is called on
    it, i.e. the caller's system object is converted in place.
    """
    keywords = set(['function','size','rcut','elements','coeff','cusp','ename',
                    'iname','spins','density','Buu','Bud','system','isize','esize','init'])
    system = kwargs.pop('system',None)
    if system is not None:
        system.change_units('B')
    #end if
    if isinstance(descriptor,str):
        descriptor = [descriptor]
    #end if
    # index of the first keyword entry (len(descriptor) if none); entries
    # between the type name and it are positional.  Only strings can be
    # keywords; checking that first keeps unhashable values (e.g. coefficient
    # lists) from breaking the set lookup.
    ikw = len(descriptor)
    for i,d in enumerate(descriptor):
        if isinstance(d,str) and d in keywords:
            ikw = i
            break
        #end if
    #end for
    dargs = descriptor[1:ikw]
    if len(dargs)>0:
        args = dargs
    #end if
    for i in range(ikw,len(descriptor),2):
        d = descriptor[i]
        if not isinstance(d,str) or d not in keywords:
            QmcpackInput.class_error('keyword {0} is unrecognized\n  valid options are: {1}'.format(d,str(keywords)),'generate_jastrow')
        elif i+1>=len(descriptor):
            QmcpackInput.class_error('keyword {0} was given without a value'.format(d),'generate_jastrow')
        else:
            kwargs[d] = descriptor[i+1]
        #end if
    #end for
    kwargs['system'] = system
    jtype = descriptor[0]
    if jtype=='J1':
        jastrow = generate_jastrow1(*args,**kwargs)
    elif jtype=='J2':
        jastrow = generate_jastrow2(*args,**kwargs)
    elif jtype=='J3':
        jastrow = generate_jastrow3(*args,**kwargs)
    else:
        QmcpackInput.class_error('jastrow type unrecognized: '+str(jtype))
    #end if
    return jastrow
#end def generate_jastrow
5290
5291
5292
def generate_jastrow1(function='bspline',size=8,rcut=None,coeff=None,cusp=0.,ename='e',iname='ion0',elements=None,system=None,**elemargs):
    """
    Generate a one-body (electron-ion) Jastrow with one correlation per species.

    Parameters
    ----------
    function : str
        Functional form written to the jastrow element (default 'bspline').
    size : int
        Number of zero coefficients per species when none are supplied.
    rcut : float, optional
        Cutoff radius.  Required for open systems; for periodic systems it
        defaults to the cell Wigner radius and may not exceed it.
    coeff : mapping or sequence, optional
        Coefficients looked up by species name, otherwise by species position.
    cusp : float
        Cusp for every species (overridable per species via elemargs).
        'Z' is not implemented.
    ename, iname : str
        Electron and ion particleset names.
    elements : sequence, optional
        Ion species to include.  Species from ``system`` and ``elemargs`` are
        appended, and duplicates are removed (first occurrence kept).
    system : optional
        Physical system supplying species and boundary conditions.
    **elemargs
        Per-species overrides.  Each value is a mapping that may hold
        'cusp', 'rcut', 'size' and/or 'coeff'.

    Returns
    -------
    jastrow1 element named 'J1'.
    """
    noelements = elements is None
    nosystem   = system is None
    noelemargs = len(elemargs)==0
    isopen     = False
    isperiodic = False
    rwigner    = 1e99
    if noelements and nosystem and noelemargs:
        QmcpackInput.class_error('must specify elements or system','generate_jastrow1')
    #end if
    # copy, so a caller-provided list is never extended in place
    elements = [] if noelements else list(elements)
    if not nosystem:
        # keep species in order of first appearance so the output (and any
        # positional coeff lookup) does not depend on set/hash ordering
        elements.extend(system.structure.elem)
        isopen     = system.structure.is_open()
        isperiodic = system.structure.is_periodic()
        if not isopen and isperiodic:
            rwigner = system.structure.rwigner()
        #end if
    #end if
    if not noelemargs:
        elements.extend(elemargs.keys())
    #end if
    # remove duplicate elements, keeping the first occurrence
    eset = set()
    elements = [ e for e in elements if e not in eset and not eset.add(e) ]
    corrs = []
    for i,element in enumerate(elements):
        if cusp=='Z':
            QmcpackInput.class_error('need to implement Z cusp','generate_jastrow1')
        #end if
        lcusp  = cusp
        lrcut  = rcut
        lcoeff = size*[0]
        if coeff is not None:
            if element in coeff:
                lcoeff = coeff[element]
            else:
                lcoeff = coeff[i]
            #end if
        #end if
        # per-species overrides take precedence over the shared settings
        if element in elemargs:
            v = elemargs[element]
            if 'cusp' in v:
                lcusp = v['cusp']
            #end if
            if 'rcut' in v:
                lrcut = v['rcut']
            #end if
            if 'size' in v and not 'coeff' in v:
                lcoeff = v['size']*[0]
            #end if
            if 'coeff' in v:
                lcoeff = v['coeff']
            #end if
        #end if
        corr = correlation(
            elementtype = element,
            size        = len(lcoeff),
            cusp        = lcusp,
            coefficients=section(
                id    = ename+element,
                type  = 'Array',
                coeff = lcoeff
                )
            )
        if lrcut is not None:
            if isperiodic and lrcut>rwigner:
                QmcpackInput.class_error('rcut must not be greater than the simulation cell wigner radius\nyou provided: {0}\nwigner radius: {1}'.format(lrcut,rwigner),'generate_jastrow1')
            #end if
            corr.rcut = lrcut
        elif isopen:
            QmcpackInput.class_error('rcut must be provided for an open system','generate_jastrow1')
        elif isperiodic:
            corr.rcut = rwigner
        #end if
        corrs.append(corr)
    #end for
    j1 = jastrow1(
        name         = 'J1',
        type         = 'One-Body',
        function     = function,
        source       = iname,
        print        = True,
        correlations = corrs
        )
    return j1
#end def generate_jastrow1
5384
5385
5386
def generate_bspline_jastrow2(size=8,rcut=None,coeff=None,spins=('u','d'),density=None,system=None,init='rpa'):
    """
    Generate a two-body B-spline Jastrow with same-spin and opposite-spin terms.

    Parameters
    ----------
    size : int
        Number of coefficients per term when they are generated here.
    rcut : float, optional
        Cutoff radius.  Required for open systems; defaults to the Wigner
        radius for periodic systems and may not exceed it there.
    coeff : pair of sequences, optional
        (uu, ud) coefficients.  Each term's size is taken from its own list.
    spins : pair of str
        Names of the up and down electron species.
    density : float, optional
        Electron density used by init='rpa'.  It is recomputed from the
        system for non-open systems.
    system : optional
        Physical system supplying boundary conditions and electron count.
    init : 'rpa', 'zero' or 0
        Coefficient initialization when coeff is not given.  'rpa' requires
        a fully periodic system and falls back to 'zero' for open systems or
        when no system is given.

    Returns
    -------
    jastrow2 element named 'J2'.
    """
    if coeff is None and system is None and (init=='rpa' and density is None or rcut is None):
        QmcpackInput.class_error('rcut and density or system must be specified','generate_bspline_jastrow2')
    #end if
    isopen      = False
    isperiodic  = False
    allperiodic = False
    rwigner     = 1e99
    if system is not None:
        isopen      = system.structure.is_open()
        isperiodic  = system.structure.is_periodic()
        allperiodic = system.structure.all_periodic()
        if not isopen and isperiodic:
            rwigner = system.structure.rwigner()
        #end if
        if isopen:
            if rcut is None:
                QmcpackInput.class_error('rcut must be provided for an open system','generate_bspline_jastrow2')
            #end if
            if init=='rpa':
                init = 'zero'
            #end if
        else:
            if rcut is None and isperiodic:
                rcut = rwigner
            #end if
            # the cell volume is only needed for the electron density here
            volume     = system.structure.volume()
            nelectrons = system.particles.count_electrons()
            density    = nelectrons/volume
        #end if
    elif init=='rpa':
        init = 'zero'
    #end if
    if coeff is None:
        if init=='rpa':
            if not allperiodic:
                QmcpackInput.class_error('rpa initialization can only be used for fully periodic systems','generate_bspline_jastrow2')
            #end if
            wp = sqrt(4*pi*density)
            dr = rcut/size
            r = .02 + dr*arange(size)
            uuc = .5/(wp*r)*(1.-exp(-r*sqrt(wp/2)))*exp(-(2*r/rcut)**2)
            udc = .5/(wp*r)*(1.-exp(-r*sqrt(wp))  )*exp(-(2*r/rcut)**2)
            coeff = [uuc,udc]
        elif init=='zero' or init==0:
            coeff = [size*[0],size*[0]]
        else:
            QmcpackInput.class_error(str(init)+' is not a valid value for parameter init\n  valid options are: rpa, zero','generate_bspline_jastrow2')
        #end if
    elif len(coeff)!=2:
        QmcpackInput.class_error('must provide 2 sets of coefficients (uu,ud)','generate_bspline_jastrow2')
    #end if
    uname,dname = spins
    uuname = uname+uname
    udname = uname+dname
    # each term's size follows its own coefficient list, so uu and ud may differ
    corrs = [
        correlation(speciesA=uname,speciesB=uname,size=len(coeff[0]),
                    coefficients=section(id=uuname,type='Array',coeff=coeff[0])),
        correlation(speciesA=uname,speciesB=dname,size=len(coeff[1]),
                    coefficients=section(id=udname,type='Array',coeff=coeff[1]))
        ]
    if rcut is not None:
        if isperiodic and rcut>rwigner:
            QmcpackInput.class_error('rcut must not be greater than the simulation cell wigner radius\nyou provided: {0}\nwigner radius: {1}'.format(rcut,rwigner),'generate_jastrow2')
        #end if
        for corr in corrs:
            corr.rcut=rcut
        #end for
    #end if
    j2 = jastrow2(
        name = 'J2',type='Two-Body',function='bspline',print=True,
        correlations = corrs
        )
    return j2
#end def generate_bspline_jastrow2
5463
5464
def generate_pade_jastrow2(Buu=None,Bud=None,spins=('u','d'),system=None):
    """
    Generate a two-body Pade Jastrow with one 'B' variable per spin pair.

    Buu (same-spin) defaults to 2.0 and Bud (opposite-spin) defaults to
    float(Buu).  ``system`` is accepted for call compatibility with
    generate_jastrow2 and is not used.  Returns a jastrow2 element 'J2'.
    """
    same_b = 2.0 if Buu is None else Buu
    opp_b  = float(same_b) if Bud is None else Bud
    up,down = spins
    pairs = [(up,up,same_b),(up,down,opp_b)]
    # build both B variables first, then the correlations that hold them
    bvars = [var(id=a+b+'_b',name='B',value=val) for a,b,val in pairs]
    corrs = [correlation(speciesA=a,speciesB=b,vars=[bvar])
             for (a,b,_),bvar in zip(pairs,bvars)]
    return jastrow2(
        name = 'J2',type='Two-Body',function='pade',
        correlations = corrs
        )
#end def generate_pade_jastrow2
5489
5490
5491
def generate_jastrow2(function='bspline',*args,**kwargs):
    """
    Generate a two-body Jastrow of the requested functional form.

    'bspline' forwards to generate_bspline_jastrow2 and 'pade' to
    generate_pade_jastrow2.  The remaining positional and keyword arguments
    are passed through, with 'spins' defaulting to ('u','d').  When a system
    is supplied, the same-spin term is removed if there are fewer than two up
    electrons, and the opposite-spin term is removed if either spin channel
    is empty.  None is returned when no term remains.
    """
    spins = kwargs.setdefault('spins',('u','d'))
    if not isinstance(spins,(tuple,list)):
        QmcpackInput.class_error('spins must be a list or tuple of u/d spin names\n  you provided: '+str(spins))
    #end if
    if len(spins)!=2:
        QmcpackInput.class_error('name for up and down spins must be specified\n  you provided: '+str(spins))
    #end if
    if not isinstance(function,str):
        QmcpackInput.class_error('function must be a string\n  you provided: '+str(function),'generate_jastrow2')
    #end if
    if function=='bspline':
        j2 = generate_bspline_jastrow2(*args,**kwargs)
    elif function=='pade':
        j2 = generate_pade_jastrow2(*args,**kwargs)
    else:
        QmcpackInput.class_error('function is invalid\n  you provided: {0}\n  valid options are: bspline or pade'.format(function),'generate_jastrow2')
    #end if
    system = kwargs.get('system')
    if system is None:
        return j2
    #end if
    # prune correlation terms that have no electron pairs to act on
    nup,ndn = system.particles.electron_counts()
    if nup<2:
        del j2.correlations.uu
    #end if
    if nup<1 or ndn<1:
        del j2.correlations.ud
    #end if
    return None if len(j2.correlations)==0 else j2
#end def generate_jastrow2
5527
5528
5529
def generate_jastrow3(function='polynomial',esize=3,isize=3,rcut=4.,coeff=None,iname='ion0',spins=('u','d'),elements=None,system=None):
    """
    Generate a three-body electron-electron-ion (eeI) Jastrow.

    For each ion species, one same-spin (uu) and one opposite-spin (ud)
    correlation is created with the given esize/isize/rcut and optimizable
    coefficient sections.  Species come from ``elements`` or, if it is not
    given, from ``system`` in order of first appearance.  For periodic
    systems, rcut may not exceed the cell Wigner radius.  Passing ``coeff``
    is not yet supported.  Returns a jastrow3 element named 'J3'.
    """
    if elements is None and system is None:
        QmcpackInput.class_error('must specify elements or system','generate_jastrow3')
    elif elements is None:
        # unique species in first-appearance order (deterministic, unlike set order)
        elements = list(dict.fromkeys(system.structure.elem))
    #end if
    if coeff is not None:
        QmcpackInput.class_error('handling coeff is not yet implemented for generate jastrow3')
    #end if
    if len(spins)!=2:
        QmcpackInput.class_error('must specify name for up and down spins\n  provided: '+str(spins),'generate_jastrow3')
    #end if
    if rcut is None:
        QmcpackInput.class_error('must specify rcut','generate_jastrow3')
    #end if
    if system is not None and system.structure.is_periodic():
        rwigner = system.structure.rwigner()
        if rcut>rwigner:
            QmcpackInput.class_error('rcut must not be greater than the simulation cell wigner radius\nyou provided: {0}\nwigner radius: {1}'.format(rcut,rwigner),'generate_jastrow3')
        #end if
    #end if
    uname,dname = spins
    spin_pairs = [(uname,uname),(uname,dname)]
    corrs=[]
    for element in elements:
        for s1,s2 in spin_pairs:
            corrs.append(
                correlation(
                    especies1=s1,especies2=s2,ispecies=element,esize=esize,
                    isize=isize,rcut=rcut,
                    coefficients=section(id=s1+s2+element,type='Array',optimize=True))
                )
        #end for
    #end for
    jastrow = jastrow3(
        name = 'J3',type='eeI',function=function,print=True,source=iname,
        correlations = corrs
        )
    return jastrow
#end def generate_jastrow3
5575
5576
def generate_kspace_jastrow(kc1=0, kc2=0, nk1=0, nk2=0,
                            symm1='isotropic', symm2='isotropic',
                            coeff1=None, coeff2=None, iname='ion0'):
    """Generate <jastrow type="kSpace">

    Parameters
    ----------
      kc1 : float, optional
        kcut for one-body Jastrow, default 0
      kc2 : float, optional
        kcut for two-body Jastrow, default 0
      nk1 : int, optional
        number of coefficients for one-body Jastrow, default 0
      nk2 : int, optional
        number of coefficients for two-body Jastrow, default 0
      symm1 : str, optional
        one of ['crystal', 'isotropic', 'none'], default 'isotropic'
      symm2 : str, optional
        one of ['crystal', 'isotropic', 'none'], default 'isotropic'
      coeff1 : list, optional
        one-body Jastrow coefficients, default None (nk1 zeros)
      coeff2 : list, optional
        two-body Jastrow coefficients, default None (nk2 zeros)
      iname : str, optional
        source particleset name, default 'ion0'

    Returns
    -------
      jk : QIxml
        kspace_jastrow qmcpack_input element

    Errors (QmcpackInput.class_error) if len(coeff1) != nk1 or
    len(coeff2) != nk2.
    """
    if coeff1 is None:
        coeff1 = [0]*nk1
    #end if
    if coeff2 is None:
        coeff2 = [0]*nk2
    #end if
    if len(coeff1) != nk1:
        QmcpackInput.class_error('coeff1 mismatch', 'generate_kspace_jastrow')
    #end if
    if len(coeff2) != nk2:
        QmcpackInput.class_error('coeff2 mismatch', 'generate_kspace_jastrow')
    #end if

    corr1 = correlation(
        type = 'One-Body',
        symmetry = symm1,
        kc = kc1,
        coefficients = section(
            id = 'cG1', type = 'Array',
            coeff = coeff1
            )
        )
    corr2 = correlation(
        type = 'Two-Body',
        symmetry = symm2,
        kc = kc2,
        coefficients = section(
            id = 'cG2', type = 'Array',
            coeff = coeff2
            )
        )
    jk = kspace_jastrow(
        type = 'kSpace',
        name = 'Jk',
        source = iname,
        correlations = collection([corr1, corr2])
        )
    return jk
# end def generate_kspace_jastrow
5640
5641
def count_jastrow_params(jastrows):
    """
    Estimate the number of optimizable parameters in a set of Jastrows.

    ``jastrows`` is a single QIxml Jastrow or an iterable of them.
    One- and two-body terms (name 'J1'/'J2', or type 'One-Body'/'Two-Body',
    case-insensitive) contribute the ``size`` of each correlation.
    Three-body terms (name 'J3', or type 'eeI') contribute esize+isize per
    correlation.  Other Jastrows are not counted.

    Note: each jastrow is pluralized in place.
    """
    if isinstance(jastrows,QIxml):
        jastrows = [jastrows]
    #end if
    params = 0
    for jastrow in jastrows:
        name = jastrow.name
        # compare types case-insensitively (hence the lower-case literals)
        jtype = jastrow.type.lower() if 'type' in jastrow else ''
        jastrow.pluralize()
        if name in ('J1','J2') or jtype in ('one-body','two-body'):
            for corr in jastrow.correlations:
                params += corr.size
            #end for
        elif name=='J3' or jtype=='eei':
            for corr in jastrow.correlations:
                params += corr.esize
                params += corr.isize
            #end for
        #end if
    #end for
    return params
#end def count_jastrow_params
5672
5673
def generate_energydensity(
        name      = None,
        dynamic   = None,
        static    = None,
        coord     = None,
        grid      = None,
        scale     = None,
        ion_grids = None,
        system    = None,
        ):
    """
    Generate an EnergyDensity estimator with one or more spacegrids.

    Parameters
    ----------
    name : str, optional
        Estimator name.  Defaults to 'EDvoronoi', 'EDcell' or 'EDatom'
        depending on ``coord``.
    dynamic, static : str, optional
        Dynamic (default 'e') and static (default 'ion0') particleset names.
    coord : str
        'voronoi', 'cartesian' or 'spherical' (required).
    grid : sequence of 3, cartesian only
        Per-axis grid values written as '-1 (g) 1' on the a1/a2/a3 axes.
    scale : optional, cartesian only
        Scale of the three cell axes.  Defaults to '.5'.
    ion_grids : sequence, spherical only
        Without ``system``: entries (scale, g1, g2, g3), one per ion origin
        static+'1', static+'2', ...
        With ``system``: entries (element, scale, g1, g2, g3).  A grid is
        placed on every ion of each listed element, and an error is raised
        if a listed element is absent from the system.
    system : optional
        Physical system used to map element names onto ion indices.
    """
    if dynamic is None:
        dynamic = 'e'
    #end if
    if static is None:
        static = 'ion0'
    #end if

    def spherical_axes(ax_scale,ax_grid):
        # r/phi/theta axes along reference points r1/r2/r3, each gridded '0 (g) 1'
        axes = [
            axis(p1='r1',scale=ax_scale,label='r'),
            axis(p1='r2',scale=ax_scale,label='phi'),
            axis(p1='r3',scale=ax_scale,label='theta'),
            ]
        for n,ax in enumerate(axes):
            ax.grid = '0 ({0}) 1'.format(ax_grid[n])
        #end for
        return axes
    #end def spherical_axes

    refp = None
    sg = []
    if coord is None:
        QmcpackInput.class_error('coord must be provided','generate_energydensity')
    elif coord=='voronoi':
        if name is None:
            name = 'EDvoronoi'
        #end if
        sg.append(spacegrid(coord=coord))
    elif coord=='cartesian':
        if name is None:
            name = 'EDcell'
        #end if
        if grid is None:
            QmcpackInput.class_error('grid must be provided for cartesian coordinates','generate_energydensity')
        #end if
        # cell axes use the caller's scale when given, '.5' otherwise
        cell_scale = '.5' if scale is None else scale
        axes = [
            axis(p1='a1',scale=cell_scale,label='x'),
            axis(p1='a2',scale=cell_scale,label='y'),
            axis(p1='a3',scale=cell_scale,label='z'),
            ]
        for n,ax in enumerate(axes):
            ax.grid = '-1 ({0}) 1'.format(grid[n])
        #end for
        sg.append(spacegrid(coord=coord,origin=origin(p1='zero'),axes=axes))
    elif coord=='spherical':
        if name is None:
            name = 'EDatom'
        #end if
        if ion_grids is None:
            QmcpackInput.class_error('ion_grids must be provided for spherical coordinates','generate_energydensity')
        #end if
        refp = reference_points(coord='cartesian',points='\nr1 1 0 0\nr2 0 1 0\nr3 0 0 1\n')
        if system is None:
            # one grid per entry, centered on ions 1,2,... in order
            for i,(ion_scale,g1,g2,g3) in enumerate(ion_grids,1):
                axes = spherical_axes(ion_scale,(g1,g2,g3))
                sg.append(spacegrid(coord=coord,origin=origin(p1=static+str(i)),axes=axes))
            #end for
        else:
            grids_by_elem = dict()
            for e,s,g1,g2,g3 in ion_grids:
                grids_by_elem[e] = s,(g1,g2,g3)
            #end for
            missing = set(grids_by_elem.keys())
            # ion indices count every ion in the system, starting at 1
            for i,e in enumerate(system.structure.elem,1):
                if e in grids_by_elem:
                    ion_scale,ion_grid = grids_by_elem[e]
                    axes = spherical_axes(ion_scale,ion_grid)
                    sg.append(spacegrid(coord=coord,origin=origin(p1=static+str(i)),axes=axes))
                    missing.discard(e)
                #end if
            #end for
            if len(missing)>0:
                QmcpackInput.class_error('ion species not found for spherical grid\nspecies not found: {0}\nspecies present: {1}'.format(sorted(missing),sorted(set(list(system.structure.elem)))),'generate_energydensity')
            #end if
        #end if
    else:
        QmcpackInput.class_error('unsupported coord type\ncoord type provided: {0}\nsupported coord types: voronoi, cartesian, spherical'.format(coord),'generate_energydensity')
    #end if
    ed = energydensity(
        type       = 'EnergyDensity',
        name       = name,
        dynamic    = dynamic,
        static     = static,
        spacegrids = sg,
        )
    if refp is not None:
        ed.reference_points = refp
    #end if
    return ed
#end def generate_energydensity
5789
5790
# optimization method name -> QIxml section class used by generate_opt
opt_map = {'linear': linear, 'cslinear': cslinear}
5792def generate_opt(method,
5793                 repeat           = 1,
5794                 energy           = None,
5795                 rw_variance      = None,
5796                 urw_variance     = None,
5797                 params           = None,
5798                 jastrows         = None,
5799                 processes        = None,
5800                 walkers_per_proc = None,
5801                 threads          = None,
5802                 blocks           = 2000,
5803                 #steps            = 5,
5804                 decorr           = 10,
5805                 min_walkers      = None, #use e.g. 128 for gpu's
5806                 timestep         = .5,
5807                 nonlocalpp       = False,
5808                 sample_factor    = 1.0):
5809    if not method in opt_map:
5810        QmcpackInput.class_error('section cannot be generated for optimization method '+method)
5811    #end if
5812    if energy is None and rw_variance is None and urw_variance is None:
5813        QmcpackInput.class_error('at least one cost parameter must be specified\n options are: energy, rw_variance, urw_variance')
5814    #end if
5815    if params is None and jastrows is None:
5816        QmcpackInput.class_error('must provide either number of opt parameters (params) or a list of jastrow objects (jastrows)')
5817    #end if
5818    if processes is None:
5819        QmcpackInput.class_error('must specify total number of processes')
5820    elif walkers_per_proc is None and threads is None:
5821        QmcpackInput.class_error('must specify walkers_per_proc or threads')
5822    #end if
5823
5824    if params is None:
5825        params = count_jastrow_params(jastrows)
5826    #end if
5827    samples = max(100000,100*params**2)
5828    samples = int(round(sample_factor*samples))
5829    samples_per_proc = int(round(float(samples)/processes))
5830
5831    if walkers_per_proc is None:
5832        walkers = 1
5833        walkers_per_proc = threads
5834    else:
5835        walkers = walkers_per_proc
5836    #end if
5837    tot_walkers = processes*walkers_per_proc
5838    if min_walkers!=None:
5839        tot_walkers = max(min_walkers,tot_walkers)
5840        walkers = int(ceil(float(tot_walkers)/processes-.001))
5841        if threads!=None and mod(walkers,threads)!=0:
5842            walkers = threads*int(ceil(float(walkers)/threads-.001))
5843        #end if
5844    #end if
5845    #blocks = int(ceil(float(decorr*samples)/(steps*tot_walkers)))
5846    blocks = min(blocks,samples_per_proc*decorr)
5847
5848    opt = opt_map[method]()
5849
5850    opt.set(
5851        walkers    = walkers,
5852        blocks     = blocks,
5853        #steps      = steps,
5854        samples    = samples,
5855        substeps   = decorr,
5856        timestep   = timestep,
5857        nonlocalpp = nonlocalpp,
5858        stepsbetweensamples = 1
5859        )
5860    if energy!=None:
5861        opt.energy = energy
5862    #end if
5863    if rw_variance!=None:
5864        opt.reweightedvariance = rw_variance
5865    #end if
5866    if urw_variance!=None:
5867        opt.unreweightedvariance = urw_variance
5868    #end if
5869
5870    opt.incorporate_defaults(elements=True)
5871
5872    if repeat>1:
5873        opt = loop(max=repeat,qmc=opt)
5874    #end if
5875
5876    return opt
5877#end def generate_opt
5878
5879
def generate_opts(opt_reqs,**kwargs):
    """
    Generate one optimization section per request.

    Each entry of ``opt_reqs`` is unpacked as the positional arguments of
    ``generate_opt``.  The keyword arguments in ``kwargs`` are shared by
    every call.  Returns the generated sections as a list, in request order.
    """
    return [generate_opt(*request,**kwargs) for request in opt_reqs]
#end def generate_opts
5887
5888
5889
# Defaults for the 'opt' preset of generate_basic_input (see qmc_defaults).
# Apart from minmethod, which is forwarded as an optimizer input, these keys
# correspond to the named arguments of generate_opt_calculations.
opt_defaults = obj(
    method          = 'linear',
    minmethod       = 'quartic',
    cost            = 'variance',
    cycles          = 12,
    var_cycles      = 0,
    var_samples     = None,
    init_cycles     = 0,
    init_samples    = None,
    init_minwalkers = 1e-4,
    )

# Optimizer inputs common to every (method, minmethod) table below; they are
# unpacked into each of those tables.
shared_opt_defaults = obj(
    samples              = 204800,
    nonlocalpp           = True,
    use_nonlocalpp_deriv = True,
    warmupsteps          = 300,
    blocks               = 100,
    steps                = 1,
    substeps             = 10,
    timestep             = 0.3,
    usedrift             = False,
    )

# Defaults used for the quartic, rescale and linemin minmethods.
linear_quartic_defaults = obj(
    minwalkers        = 0.3,
    usebuffer         = True,
    exp0              = -6,
    bigchange         = 10.0,
    alloweddifference = 1e-04,
    stepsize          = 0.15,
    nstabilizers      = 1,
    **shared_opt_defaults
    )
# Defaults used for the oneshift / oneshiftonly minmethods.
linear_oneshift_defaults = obj(
    minwalkers = 0.5,
    #shift_i    = 0.01,
    #shift_s    = 1.00,
    **shared_opt_defaults
    )
# Defaults used for the adaptive minmethod.
linear_adaptive_defaults = obj(
    minwalkers          = 0.3,
    max_relative_change = 10.0,
    max_param_change    = 0.3,
    shift_i             = 0.01,
    shift_s             = 1.00,
    **shared_opt_defaults
    )

# Per-(method, minmethod) optimizer defaults.  generate_basic_input looks up
# (method, minmethod.lower()) here; unlisted pairs are rejected there.
# Note that several keys refer to the same obj instance.
opt_method_defaults = obj({
    ('linear'  ,'quartic' ) : linear_quartic_defaults,
    ('linear'  ,'rescale' ) : linear_quartic_defaults,
    ('linear'  ,'linemin' ) : linear_quartic_defaults,
    ('cslinear','quartic' ) : linear_quartic_defaults,
    ('cslinear','rescale' ) : linear_quartic_defaults,
    ('cslinear','linemin' ) : linear_quartic_defaults,
    ('linear'  ,'adaptive') : linear_adaptive_defaults,
    ('linear'  ,'oneshift') : linear_oneshift_defaults,
    ('linear'  ,'oneshiftonly') : linear_oneshift_defaults,
    })
# Drop the intermediate names from the module namespace; the tables remain
# reachable through opt_method_defaults.
del shared_opt_defaults
del linear_quartic_defaults
del linear_oneshift_defaults
del linear_adaptive_defaults

# Every attribute/parameter name of the linear and cslinear elements.
# generate_opt_calculations rejects optimizer inputs outside this set.
allowed_opt_method_inputs = set(linear.attributes+linear.parameters
                                +cslinear.attributes+cslinear.parameters)
5957
# Defaults for the 'vmc' preset; keys match the arguments of
# generate_vmc_calculations.
vmc_defaults = obj(
    walkers     = 1,
    warmupsteps = 50,
    blocks      = 800,
    steps       = 10,
    substeps    = 3,
    timestep    = 0.3,
    checkpoint  = -1,
    )
# Short VMC run for testing; keys not given here are taken from
# vmc_defaults via set_optional.
vmc_test_defaults = obj(
    warmupsteps = 10,
    blocks      = 20,
    steps       =  4,
    ).set_optional(**vmc_defaults)
# VMC run with more warmup and steps per block; remaining keys come from
# vmc_defaults via set_optional.
vmc_noJ_defaults = obj(
    warmupsteps = 200,
    blocks      = 800,
    steps       = 100,
    ).set_optional(**vmc_defaults)

# Defaults for the 'dmc' preset; keys match the arguments of
# generate_dmc_calculations.  vmc_* configure the VMC run that precedes DMC,
# eq_* configure the optional equilibration DMC run (enabled by eq_dmc), and
# ntimesteps/timestep_factor control the sequence of main DMC runs.
dmc_defaults = obj(
    warmupsteps             = 20,
    blocks                  = 200,
    steps                   = 10,
    timestep                = 0.01,
    checkpoint              = -1,
    vmc_samples             = 2048,
    vmc_samplesperthread    = None,
    vmc_walkers             = 1,
    vmc_warmupsteps         = 30,
    vmc_blocks              = 40,
    vmc_steps               = 10,
    vmc_substeps            = 3,
    vmc_timestep            = 0.3,
    vmc_checkpoint          = -1,
    eq_dmc                  = False,
    eq_warmupsteps          = 20,
    eq_blocks               = 20,
    eq_steps                = 5,
    eq_timestep             = 0.02,
    eq_checkpoint           = -1,
    ntimesteps              = 1,
    timestep_factor         = 0.5,
    nonlocalmoves           = None,
    branching_cutoff_scheme = None,
    )
# Short DMC run for testing; remaining keys come from dmc_defaults.
dmc_test_defaults = obj(
    vmc_warmupsteps = 10,
    vmc_blocks      = 20,
    vmc_steps       =  4,
    eq_warmupsteps  =  2,
    eq_blocks       =  5,
    eq_steps        =  2,
    warmupsteps     =  2,
    blocks          = 10,
    steps           =  2,
    ).set_optional(**dmc_defaults)
# Longer DMC run; remaining keys come from dmc_defaults.
dmc_noJ_defaults = obj(
    warmupsteps     =  40,
    blocks          = 400,
    steps           =  20,
    ).set_optional(**dmc_defaults)

# Named presets selectable through the 'qmc' argument of generate_basic_input.
qmc_defaults = obj(
    opt      = opt_defaults,
    vmc      = vmc_defaults,
    vmc_test = vmc_test_defaults,
    vmc_noJ  = vmc_noJ_defaults,
    dmc      = dmc_defaults,
    dmc_test = dmc_test_defaults,
    dmc_noJ  = dmc_noJ_defaults,
    )
# Drop the intermediate names; the presets remain reachable via qmc_defaults.
del opt_defaults
del vmc_defaults
del vmc_test_defaults
del vmc_noJ_defaults
del dmc_defaults
del dmc_test_defaults
del dmc_noJ_defaults
6037
6038
6039
def generate_opt_calculations(
    method     ,
    cost       ,
    cycles     ,
    var_cycles ,
    var_samples,
    init_cycles,
    init_samples,
    init_minwalkers,
    loc        = 'generate_opt_calculations',
    **opt_inputs
    ):
    """
    Build the list of optimization loops for the 'opt' preset.

    Parameters
    ----------
    method : str
        'linear' or 'cslinear'.
    cost : str or tuple/list
        'variance', 'energy', a 2-sequence (energy, reweighted variance) or
        a 3-sequence (energy, unreweighted variance, reweighted variance).
    cycles : int
        Iteration count of the main optimization loop.
    var_cycles, var_samples :
        If var_cycles>0, a pure unreweighted-variance loop is added first;
        var_samples (if not None) overrides its sample count.
    init_cycles, init_samples, init_minwalkers :
        If init_cycles>0, an initial loop using the requested cost is
        added next, with minwalkers=init_minwalkers and, if not None,
        samples=init_samples.
    loc : str
        Location label passed to error messages.
    **opt_inputs :
        Further optimizer inputs; every key must be in
        allowed_opt_method_inputs.  A minmethod starting with 'oneshift'
        (case-insensitive) is normalized to 'OneShiftOnly', and in that case
        no cost weights are set on the initial and main loops.

    Returns
    -------
    list of loop objects ordered as: variance loop, initial loop, main loop
    (the first two only when requested).
    """
    methods = obj(linear=linear,cslinear=cslinear)
    if method not in methods:
        error('invalid optimization method requested\ninvalid method: {0}\nvalid options are: {1}'.format(method,sorted(methods.keys())),loc)
    #end if
    opt = methods[method]

    opt_inputs = obj(opt_inputs)
    invalid = set(opt_inputs.keys())-allowed_opt_method_inputs
    oneshift = False
    if len(invalid)>0:
        error('invalid optimization inputs provided\ninvalid inputs: {}\nvalid options are: {}'.format(sorted(invalid),sorted(allowed_opt_method_inputs)),loc)
    #end if
    if 'minmethod' in opt_inputs and opt_inputs.minmethod.lower().startswith('oneshift'):
        opt_inputs.minmethod = 'OneShiftOnly'
        oneshift = True
    #end if

    # normalize cost to (energy, unreweightedvariance, reweightedvariance)
    if cost=='variance':
        cost = (0.0,1.0,0.0)
    elif cost=='energy':
        cost = (1.0,0.0,0.0)
    elif isinstance(cost,(tuple,list)) and (len(cost)==2 or len(cost)==3):
        if len(cost)==2:
            cost = (cost[0],0.0,cost[1])
        #end if
    else:
        error('invalid optimization cost function encountered\ninvalid cost fuction: {0}\nvalid options are: variance, energy, (0.95,0.05), etc'.format(cost),loc)
    #end if
    opt_calcs = []
    if var_cycles>0:
        # NOTE(review): cost weights are set here even for oneshift, unlike
        # the loops below; confirm whether that is intended.
        vmin_opt = opt(
            energy               = 0.0,
            unreweightedvariance = 1.0,
            reweightedvariance   = 0.0,
            **opt_inputs
            )
        if var_samples is not None:
            vmin_opt.samples = var_samples
        #end if
        opt_calcs.append(loop(max=var_cycles,qmc=vmin_opt))
    #end if
    if init_cycles>0:
        init_opt = opt(**opt_inputs)
        if init_samples is not None:
            init_opt.samples = init_samples
        #end if
        init_opt.minwalkers = init_minwalkers
        if not oneshift:
            init_opt.energy               = cost[0]
            init_opt.unreweightedvariance = cost[1]
            init_opt.reweightedvariance   = cost[2]
        #end if
        opt_calcs.append(loop(max=init_cycles,qmc=init_opt))
    #end if

    cost_opt = opt(**opt_inputs)
    if not oneshift:
        cost_opt.energy               = cost[0]
        cost_opt.unreweightedvariance = cost[1]
        cost_opt.reweightedvariance   = cost[2]
    #end if

    opt_calcs.append(loop(max=cycles,qmc=cost_opt))
    return opt_calcs
#end def generate_opt_calculations
6118
6119
6120
def generate_vmc_calculations(
    walkers    ,
    warmupsteps,
    blocks     ,
    steps      ,
    substeps   ,
    timestep   ,
    checkpoint ,
    loc        = 'generate_vmc_calculations',
    ):
    """
    Build the calculation list for the 'vmc' presets.

    Every argument except ``loc`` is forwarded unchanged to a single ``vmc``
    section.  ``loc`` is accepted for symmetry with the other
    generate_*_calculations functions and is not used here.

    Returns a one-element list containing that vmc section.
    """
    vmc_inputs = dict(
        walkers     = walkers,
        warmupsteps = warmupsteps,
        blocks      = blocks,
        steps       = steps,
        substeps    = substeps,
        timestep    = timestep,
        checkpoint  = checkpoint,
        )
    single_run = vmc(**vmc_inputs)
    return [single_run]
#end def generate_vmc_calculations
6144
6145
6146
def generate_dmc_calculations(
    warmupsteps            ,
    blocks                 ,
    steps                  ,
    timestep               ,
    checkpoint             ,
    vmc_samples            ,
    vmc_samplesperthread   ,
    vmc_walkers            ,
    vmc_warmupsteps        ,
    vmc_blocks             ,
    vmc_steps              ,
    vmc_substeps           ,
    vmc_timestep           ,
    vmc_checkpoint         ,
    eq_dmc                 ,
    eq_warmupsteps         ,
    eq_blocks              ,
    eq_steps               ,
    eq_timestep            ,
    eq_checkpoint          ,
    ntimesteps             ,
    timestep_factor        ,
    nonlocalmoves          ,
    branching_cutoff_scheme,
    loc                 = 'generate_dmc_calculations',
    ):
    """
    Build the calculation list for the 'dmc' presets.

    The list is built in this order:

    1. A VMC run whose samples provide the DMC walkers.  vmc_samplesperthread
       takes precedence over vmc_samples; at least one must be given.
    2. If eq_dmc is set, an equilibration DMC run configured by the eq_*
       arguments.
    3. ntimesteps DMC runs.  The n-th run (n=0,1,...) uses
       timestep*timestep_factor**n, with warmupsteps and steps divided by
       that same factor (truncated to int), so the projection time per
       block is roughly preserved.

    nonlocalmoves and branching_cutoff_scheme, when not None, are set on
    every dmc section.  loc is the location label used in error messages.

    Returns the list of vmc/dmc sections.
    """
    if vmc_samples is None and vmc_samplesperthread is None:
        error('vmc samples (dmc walkers) not specified\nplease provide one of the following keywords: vmc_samples, vmc_samplesperthread',loc)
    #end if

    vmc_calc = vmc(
        walkers     = vmc_walkers,
        warmupsteps = vmc_warmupsteps,
        blocks      = vmc_blocks,
        steps       = vmc_steps,
        substeps    = vmc_substeps,
        timestep    = vmc_timestep,
        checkpoint  = vmc_checkpoint,
        )
    if vmc_samplesperthread is not None:
        vmc_calc.samplesperthread = vmc_samplesperthread
    elif vmc_samples is not None:
        vmc_calc.samples = vmc_samples
    #end if

    dmc_calcs = [vmc_calc]
    if eq_dmc:
        dmc_calcs.append(
            dmc(
                warmupsteps   = eq_warmupsteps,
                blocks        = eq_blocks,
                steps         = eq_steps,
                timestep      = eq_timestep,
                checkpoint    = eq_checkpoint,
                )
            )
    #end if
    # Tolerance added before int() truncation.  Repeated multiplication of
    # tfac accumulates round-off, e.g. 1/(0.1*0.1)*10 = 999.999...,
    # which would otherwise truncate to 999 instead of 1000.
    eps = 1e-8
    tfac = 1.0
    for n in range(ntimesteps):
        sfac = 1.0/tfac
        dmc_calcs.append(
            dmc(
                warmupsteps   = int(sfac*warmupsteps+eps),
                blocks        = blocks,
                steps         = int(sfac*steps+eps),
                timestep      = tfac*timestep,
                checkpoint    = checkpoint,
                )
            )
        tfac *= timestep_factor
    #end for

    for calc in dmc_calcs:
        if isinstance(calc,dmc):
            if nonlocalmoves is not None:
                calc.nonlocalmoves = nonlocalmoves
            #end if
            if branching_cutoff_scheme is not None:
                calc.branching_cutoff_scheme = branching_cutoff_scheme
            #end if
        #end if
    #end for

    return dmc_calcs
#end def generate_dmc_calculations
6234
6235
6236
def generate_qmcpack_input(**kwargs):
    """
    Top-level entry point for generating a QmcpackInput object.

    Clears the QIcollections registry and, if ``system`` is a
    PhysicalSystem, calls its update_particles method.  It then dispatches
    on ``input_type``, which defaults to 'basic' and is removed from kwargs
    before dispatch:

      'basic'       -> generate_basic_input
      'basic_afqmc' -> generate_basic_afqmc_input
      'opt_jastrow' -> generate_opt_jastrow_input

    Any other input_type is reported through QmcpackInput.class_error.
    """
    QIcollections.clear()
    phys_sys = kwargs.get('system',None)
    if isinstance(phys_sys,PhysicalSystem):
        phys_sys.update_particles()
    #end if
    input_type = kwargs.pop('input_type','basic')
    generators = (
        ('basic'      , generate_basic_input      ),
        ('basic_afqmc', generate_basic_afqmc_input),
        ('opt_jastrow', generate_opt_jastrow_input),
        )
    generator = None
    for label,candidate in generators:
        if input_type==label:
            generator = candidate
            break
        #end if
    #end for
    if generator is None:
        QmcpackInput.class_error('selection '+str(input_type)+' has not been implemented for qmcpack input generation')
    #end if
    return generator(**kwargs)
#end def generate_qmcpack_input
6255
6256
6257
# Default values for every keyword accepted by generate_basic_input.  The key
# set also defines the valid keywords there; the selected qmc preset (and,
# for qmc='opt', the (method, minmethod) table) adds further valid keywords.
gen_basic_input_defaults = obj(
    id             = 'qmc',
    series         = 0,
    purpose        = '', # not referenced inside generate_basic_input
    seed           = None, # when set, a <random seed=.../> element is added
    bconds         = None, # None: taken from system structure ('nnn' if empty or no axes; 'ppp' if system is None)
    truncate       = False,
    buffer         = None,
    lr_dim_cutoff  = 15,
    lr_tol         = None,
    lr_handler     = None,
    remove_cell    = False,
    randomsrc      = False,
    meshfactor     = 1.0,
    orbspline      = None, # None: 'bspline' for det_format='new', 'einspline' for 'old'
    precision      = 'float',
    twistnum       = None,
    twist          = None,
    spin_polarized = None, # None: derived from system.net_spin>0
    partition      = None, # when set, det_format is forced to 'new'
    partition_mf   = None,
    hybridrep      = None, # forced to True if hybrid_rcut or hybrid_lmax is set
    hybrid_rcut    = None,
    hybrid_lmax    = None,
    orbitals_h5    = 'MISSING.h5',
    excitation     = None, # only allowed with det_format='old'
    system         = 'missing', # sentinel: must be given; pass None to skip particlesets
    pseudos        = None,
    dla            = None,
    jastrows       = 'generateJ12',
    interactions   = 'all',
    corrections    = 'default', # 'default': ['mpc','chiesa'] for ppp bconds, else none
    observables    = None,
    estimators     = None,
    traces         = None,
    calculations   = None, # None/empty: generated from the qmc preset, if any
    det_format     = 'new', # 'new' or 'old'
    J1             = False, # if any of J1/J2/J3 is set, jastrows is built by generate_jastrows_alt
    J2             = False,
    J3             = False,
    J1_size        = None,
    J1_rcut        = None,
    J1_dr          = 0.5,
    J2_size        = None,
    J2_rcut        = None,
    J2_dr          = 0.5,
    J2_init        = 'zero',
    J3_isize       = 3,
    J3_esize       = 3,
    J3_rcut        = 5.0,
    J1_rcut_open   = 5.0,
    J2_rcut_open   = 10.0,
    qmc            = None, # preset key of qmc_defaults: opt,vmc,vmc_test,vmc_noJ,dmc,dmc_test,dmc_noJ
    )
6312
def generate_basic_input(**kwargs):
    """
    Generate a standard real-space QMC input (input_type='basic').

    Valid keywords and their defaults are listed in gen_basic_input_defaults.
    If ``qmc`` names a preset in qmc_defaults, that preset's keys are also
    accepted.  For qmc='opt', the keys of the matching (method, minmethod)
    table in opt_method_defaults are accepted as well.  When no explicit
    ``calculations`` are given, these keys are used to build the calculation
    list.

    Returns
    -------
    QmcpackInput

    Errors are reported through QmcpackInput.class_error for:
    unknown keywords, an unknown qmc preset or (method, minmethod) pair, a
    missing ``system`` argument, ``excitation`` with det_format='new', and
    an invalid det_format.
    """
    # capture inputs
    kw = obj(kwargs)
    # apply general defaults
    kw.set_optional(**gen_basic_input_defaults)
    valid = set(gen_basic_input_defaults.keys())
    # apply method specific defaults
    if kw.qmc is not None:
        if kw.qmc not in qmc_defaults:
            QmcpackInput.class_error('invalid input for argument "qmc"\ninvalid input: {}\nvalid options are: {}'.format(kw.qmc,sorted(qmc_defaults.keys())),'generate_basic_input')
        #end if
        qmc_keys = []
        kw.set_optional(**qmc_defaults[kw.qmc])
        qmc_keys += list(qmc_defaults[kw.qmc].keys())
        if kw.qmc=='opt':
            key = (kw.method,kw.minmethod.lower())
            if key not in opt_method_defaults:
                QmcpackInput.class_error('invalid input for arguments "method,minmethod"\ninvalid input: {}\nvalid options are: {}'.format(key,sorted(opt_method_defaults.keys())),'generate_basic_input')
            #end if
            kw.set_optional(**opt_method_defaults[key])
            qmc_keys += list(opt_method_defaults[key].keys())
            del key
        #end if
        valid |= set(qmc_keys)
    #end if
    # screen for invalid keywords
    invalid_kwargs = set(kw.keys())-valid
    if len(invalid_kwargs)>0:
        QmcpackInput.class_error('invalid input parameters encountered\ninvalid input parameters: {0}\nvalid options are: {1}'.format(sorted(invalid_kwargs),sorted(valid)),'generate_qmcpack_input')
    #end if

    if kw.system=='missing':
        QmcpackInput.class_error('generate_basic_input argument system is missing\nif you really do not want particlesets to be generated, set system to None')
    #end if
    if kw.bconds is None:
        if kw.system is not None:
            s = kw.system.structure
            kw.bconds = s.bconds
            if len(kw.bconds)==0 or not s.has_axes():
                kw.bconds = 'nnn'
            #end if
        else:
            kw.bconds = 'ppp'
        #end if
    #end if
    if kw.corrections=='default' and tuple(kw.bconds)==tuple('ppp'):
        kw.corrections = ['mpc','chiesa']
    elif isinstance(kw.corrections,(list,tuple)):
        pass
    else:
        kw.corrections = []
    #end if
    if kw.observables is None:
        kw.observables = []
    #end if
    if kw.estimators is None:
        kw.estimators = []
    #end if
    # convert each part to a list so that mixed list/tuple inputs concatenate
    kw.estimators = list(kw.estimators) + list(kw.observables) + list(kw.corrections)
    if kw.calculations is None:
        kw.calculations = []
    #end if
    if kw.spin_polarized is None:
        # system=None is explicitly allowed above, so guard the dereference
        kw.spin_polarized = kw.system is not None and kw.system.net_spin>0
    #end if
    if kw.partition is not None:
        kw.det_format = 'new'
    #end if
    if kw.hybrid_rcut is not None or kw.hybrid_lmax is not None:
        kw.hybridrep = True
    #end if

    metadata = QmcpackInput.default_metadata.copy()

    proj = project(
        id          = kw.id,
        series      = kw.series,
        application = application(),
        )

    simcell = generate_simulationcell(
        bconds        = kw.bconds,
        lr_dim_cutoff = kw.lr_dim_cutoff,
        lr_tol        = kw.lr_tol,
        lr_handler    = kw.lr_handler,
        system        = kw.system,
        )

    if kw.system is not None:
        kw.system.structure.set_bconds(kw.bconds)
        particlesets = generate_particlesets(
            system      = kw.system,
            randomsrc   = kw.randomsrc or tuple(kw.bconds)!=('p','p','p'),
            hybrid_rcut = kw.hybrid_rcut,
            hybrid_lmax = kw.hybrid_lmax,
            )
    #end if


    if kw.det_format=='new':
        if kw.excitation is not None:
            QmcpackInput.class_error('user provided "excitation" input argument with new style determinant format\nplease add det_format="old" and try again')
        #end if
        if kw.system is not None and isinstance(kw.system.structure,Jellium):
            ssb = generate_sposet_builder(
                type           = 'heg',
                twist          = kw.twist,
                spin_polarized = kw.spin_polarized,
                system         = kw.system,
                )
        else:
            if kw.orbspline is None:
                kw.orbspline = 'bspline'
            #end if
            ssb = generate_sposet_builder(
                type           = kw.orbspline,
                twist          = kw.twist,
                twistnum       = kw.twistnum,
                meshfactor     = kw.meshfactor,
                precision      = kw.precision,
                truncate       = kw.truncate,
                buffer         = kw.buffer,
                hybridrep      = kw.hybridrep,
                href           = kw.orbitals_h5,
                spin_polarized = kw.spin_polarized,
                system         = kw.system,
                )
        #end if
        if kw.partition is None:
            spobuilders = [ssb]
        else:
            spobuilders = partition_sposets(
                sposet_builder = ssb,
                partition      = kw.partition,
                partition_meshfactors = kw.partition_mf,
                )
        #end if

        dset = generate_determinantset(
            spin_polarized = kw.spin_polarized,
            system         = kw.system,
            )
    elif kw.det_format=='old':
        spobuilders = None
        if kw.orbspline is None:
            kw.orbspline = 'einspline'
        #end if
        dset = generate_determinantset_old(
            type           = kw.orbspline,
            twistnum       = kw.twistnum,
            meshfactor     = kw.meshfactor,
            precision      = kw.precision,
            hybridrep      = kw.hybridrep,
            href           = kw.orbitals_h5,
            spin_polarized = kw.spin_polarized,
            excitation     = kw.excitation,
            system         = kw.system,
            )
    else:
        # previously referenced an undefined name (det_format), raising NameError
        QmcpackInput.class_error('generate_basic_input argument det_format is invalid\n  received: {0}\n  valid options are: new,old'.format(kw.det_format))
    #end if


    wfn = wavefunction(
        name           = 'psi0',
        target         = 'e',
        determinantset = dset,
        )

    if kw.J1 or kw.J2 or kw.J3:
        kw.jastrows = generate_jastrows_alt(
            J1           = kw.J1          ,
            J2           = kw.J2          ,
            J3           = kw.J3          ,
            J1_size      = kw.J1_size     ,
            J1_rcut      = kw.J1_rcut     ,
            J1_dr        = kw.J1_dr       ,
            J2_size      = kw.J2_size     ,
            J2_rcut      = kw.J2_rcut     ,
            J2_dr        = kw.J2_dr       ,
            J2_init      = kw.J2_init     ,
            J3_isize     = kw.J3_isize    ,
            J3_esize     = kw.J3_esize    ,
            J3_rcut      = kw.J3_rcut     ,
            J1_rcut_open = kw.J1_rcut_open,
            J2_rcut_open = kw.J2_rcut_open,
            system       = kw.system      ,
            )
    #end if
    if kw.jastrows is not None:
        wfn.jastrows = generate_jastrows(kw.jastrows,kw.system,check_ions=True)
    #end if

    hmltn = generate_hamiltonian(
        system       = kw.system,
        pseudos      = kw.pseudos,
        dla          = kw.dla,
        interactions = kw.interactions,
        estimators   = kw.estimators,
        )

    if spobuilders is not None:
        wfn.sposet_builders = make_collection(spobuilders)
    #end if

    qmcsys = qmcsystem(
        simulationcell  = simcell,
        wavefunction    = wfn,
        hamiltonian     = hmltn,
        )

    if kw.system is not None:
        qmcsys.particlesets = particlesets
    #end if

    sim = simulation(
        project   = proj,
        qmcsystem = qmcsys,
        )

    if kw.seed is not None:
        sim.random = random(seed=kw.seed)
    #end if

    if kw.traces is not None:
        sim.traces = kw.traces
    #end if

    # build calculations from the selected qmc preset unless given explicitly
    if len(kw.calculations)==0 and kw.qmc is not None:
        qmc_inputs = kw.obj(*qmc_keys)
        if kw.qmc=='opt':
            kw.calculations = generate_opt_calculations(**qmc_inputs)
        elif 'vmc' in kw.qmc:
            kw.calculations = generate_vmc_calculations(**qmc_inputs)
        elif 'dmc' in kw.qmc:
            kw.calculations = generate_dmc_calculations(**qmc_inputs)
        #end if
    #end if
    # NOTE(review): this loop only inspects each calculation's estimators;
    # the automatic LocalEnergy insertion it once performed is disabled, so
    # its results are currently unused.
    for calculation in kw.calculations:
        if isinstance(calculation,loop):
            calc = calculation.qmc
        else:
            calc = calculation
        #end if
        has_localenergy = False
        has_estimators = 'estimators' in calc
        if has_estimators:
            estimators = calc.estimators
            if not isinstance(estimators,collection):
                estimators = make_collection(estimators)
            #end if
            has_localenergy = 'localenergy' in estimators or 'LocalEnergy' in estimators
        else:
            estimators = collection()
        #end if
    #end for
    sim.calculations = make_collection(kw.calculations).copy()

    qi = QmcpackInput(metadata,sim)

    qi.incorporate_defaults(elements=False,overwrite=False,propagate=True)

    if kw.remove_cell:
        qi.remove_physical_system()
    #end if

    # optimizers without an explicit nonlocalpp setting get nonlocal
    # pseudopotential terms and derivatives enabled
    for calc in sim.calculations:
        if isinstance(calc,loop):
            calc = calc.qmc
        #end if
        if isinstance(calc,(linear,cslinear)) and 'nonlocalpp' not in calc:
            calc.nonlocalpp           = True
            calc.use_nonlocalpp_deriv = True
        #end if
    #end for

    return qi
#end def generate_basic_input
6596
6597
6598
# Default values for every keyword accepted by generate_basic_afqmc_input;
# the key set also defines the valid keywords there.
gen_basic_afqmc_input_defaults = obj(
    id          = 'qmc',
    series      = 0,
    seed        = None, # when set, a <random seed=.../> element is added
    nmo         = None, # nmo/naea/naeb are written to afqmcinfo only when not None
    naea        = None,
    naeb        = None,
    ham_file    = None, # if only one of ham_file/wfn_file is given it is used for both;
    wfn_file    = None, # if neither is given both become 'MISSING.h5' (only .h5 accepted)
    wfn_type    = 'NOMSD',
    cutoff      = 1e-8,
    wset_type   = 'shared',
    walker_type = 'CLOSED',
    hybrid      = True,
    # the following are copied onto the execute element when their names
    # appear in execute.parameters (not visible here -- TODO confirm which)
    ncores      = 1,
    nwalkers    = 10,
    blocks      = 10000,
    steps       = 10,
    timestep    = 0.005,
    estimators  = None, # list of back_propagation estimator objects
    info_name   = 'info0',
    ham_name    = 'ham0',
    wfn_name    = 'wfn0',
    wset_name   = 'wset0',
    prop_name   = 'prop0',
    system      = None, # accepted but not referenced inside generate_basic_afqmc_input
    )
6626
def generate_basic_afqmc_input(**kwargs):
    """
    Generate a basic AFQMC input (input_type='basic_afqmc').

    Valid keywords and their defaults are listed in
    gen_basic_afqmc_input_defaults.  The caller-supplied keyword values that
    are not obj instances are recorded in the metadata as generation_info.

    ham_file / wfn_file: if only one is given it is used for both.  If
    neither is given, both are set to 'MISSING.h5'.  Only '.h5' files are
    accepted; they are written with filetype 'hdf5'.

    estimators: optional list of back_propagation objects.  Each is copied
    and has its defaults incorporated before being attached to execute.

    Returns
    -------
    QmcpackInput
    """
    # capture inputs
    kw = obj(kwargs)
    # record the plain (non-obj) user inputs for the metadata
    gen_info = obj()
    for k,v in kw.items():
        if not isinstance(v,obj):
            gen_info[k] = v
        #end if
    #end for
    # apply general defaults
    kw.set_optional(**gen_basic_afqmc_input_defaults)
    valid = set(gen_basic_afqmc_input_defaults.keys())
    # screen for invalid keywords
    invalid_kwargs = set(kw.keys())-valid
    if len(invalid_kwargs)>0:
        QmcpackInput.class_error('invalid input parameters encountered\ninvalid input parameters: {0}\nvalid options are: {1}'.format(sorted(invalid_kwargs),sorted(valid)),'generate_qmcpack_input')
    #end if

    metadata = meta(
        generation_info = gen_info.copy(),
        )

    sim = simulation(
        method = 'afqmc',
        )

    sim.project = project(
        id     = kw.id,
        series = kw.series,
        )

    if kw.seed is not None:
        sim.random = random(seed=kw.seed)
    #end if

    info = afqmcinfo(
        name = kw.info_name,
        )
    if kw.nmo is not None:
        info.nmo = kw.nmo
    #end if
    if kw.naea is not None:
        info.naea = kw.naea
    #end if
    if kw.naeb is not None:
        info.naeb = kw.naeb
    #end if
    sim.afqmcinfo = info

    if kw.ham_file is None and kw.wfn_file is not None:
        kw.ham_file = kw.wfn_file
    elif kw.ham_file is not None and kw.wfn_file is None:
        kw.wfn_file = kw.ham_file
    elif kw.ham_file is None and kw.wfn_file is None:
        kw.ham_file = 'MISSING.h5'
        kw.wfn_file = 'MISSING.h5'
    #end if
    def get_filetype(filename,loc):
        # map a file extension to the filetype attribute; only .h5 is supported
        if filename.endswith('.h5'):
            filetype = 'hdf5'
        else:
            QmcpackInput.class_error('Type of {} file "{}" is unrecognized.\n The following file extensions are allowed: .h5'.format(loc,filename))
        #end if
        return filetype
    #end def get_filetype

    ham = hamiltonian(
        name     = kw.ham_name,
        info     = info.name,
        filetype = get_filetype(kw.ham_file,'hamiltonian'),
        filename = kw.ham_file,
        )
    sim.hamiltonian = ham

    wfn = wavefunction(
        name     = kw.wfn_name,
        info     = info.name,
        filetype = get_filetype(kw.wfn_file,'wavefunction'),
        filename = kw.wfn_file,
        )
    if kw.wfn_type is not None:
        wfn.type = kw.wfn_type
    #end if
    if kw.cutoff is not None:
        wfn.cutoff = kw.cutoff
    #end if
    sim.wavefunction = wfn

    wset = walkerset(
        name        = kw.wset_name,
        )
    if kw.wset_type is not None:
        wset.type = kw.wset_type
    #end if
    if kw.walker_type is not None:
        wset.walker_type = kw.walker_type
    #end if
    sim.walkerset = wset

    prop = propagator(
        name = kw.prop_name,
        info = info.name,
        )
    if kw.hybrid is not None:
        prop.hybrid = kw.hybrid
    #end if
    sim.propagator = prop

    exe = execute(
        info = info.name,
        ham  = ham.name,
        wfn  = wfn.name,
        wset = wset.name,
        prop = prop.name,
        )
    # copy any execute parameters supplied (or defaulted) in kw
    for k in execute.parameters:
        if k in kw and kw[k] is not None:
            exe[k] = kw[k]
        #end if
    #end for
    estimators = []
    valid_estimators = (back_propagation,)
    if kw.estimators is not None:
        for est in kw.estimators:
            invalid = False
            if isinstance(est,QIxml):
                est = est.copy()
            else:
                invalid = True
            #end if
            invalid |= not isinstance(est,valid_estimators)
            if invalid:
                # valid_estimators holds classes, so use the class name itself
                # (e.__class__.__name__ would give the metaclass name)
                valid_names = [e.__name__ for e in valid_estimators]
                QmcpackInput.class_error('invalid estimator input encountered\nexpected one of the following: {}\ninputted type: {}\ninputted value: {}'.format(valid_names,est.__class__.__name__,est))
            #end if
            est.incorporate_defaults()
            estimators.append(est)
        #end for
    #end if
    if len(estimators)>0:
        exe.estimators = make_collection(estimators)
    #end if
    sim.execute = exe

    qi = QmcpackInput(metadata,sim)

    return qi
#end def generate_basic_afqmc_input
6775
6776
6777
def generate_opt_jastrow_input(id  = 'qmc',
                               series           = 0,
                               purpose          = '',
                               seed             = None,
                               bconds           = None,
                               remove_cell      = False,
                               meshfactor       = 1.0,
                               precision        = 'float',
                               twistnum         = None,
                               twist            = None,
                               spin_polarized   = False,
                               orbitals_h5      = 'MISSING.h5',
                               system           = None,
                               pseudos          = None,
                               jastrows         = 'generateJ12',
                               corrections      = None,
                               observables      = None,
                               processes        = None,
                               walkers_per_proc = None,
                               threads          = None,
                               decorr           = 10,
                               min_walkers      = None, #use e.g. 128 for gpu's
                               timestep         = 0.5,
                               nonlocalpp       = False,
                               sample_factor    = 1.0,
                               opt_calcs        = None,
                               det_format       = 'new'):
    """
    Generate a QMCPACK input whose calculations are Jastrow optimizations.

    The Jastrow specification is first expanded with generate_jastrows.
    Each entry of opt_calcs then becomes one optimization calculation:
      - QIxml objects are used as given (not copied);
      - 5-element tuples/lists are unpacked positionally into generate_opt,
        together with the run-control keywords of this function, provided
        the first element (the method name) is a key of opt_map.
    The resulting calculations, the expanded jastrows, and the remaining
    system/wavefunction keywords are passed to generate_basic_input.

    Parameters not listed here are forwarded unchanged to
    generate_basic_input.
      jastrows   : Jastrow specification passed through generate_jastrows
                   together with system (default 'generateJ12').
      opt_calcs  : list of optimization calculations.  Defaults to two
                   'linear' stages, ('linear',4,0,0,1.0) followed by
                   ('linear',4,.8,.2,0).
                   NOTE(review): the meaning of the four numeric tuple
                   fields is defined by generate_opt; confirm there.
      processes, walkers_per_proc, threads, decorr, min_walkers,
      timestep, nonlocalpp, sample_factor :
                   forwarded as keywords to generate_opt for every
                   tuple/list entry (not applied to QIxml entries).

    Returns:
      whatever generate_basic_input returns (presumably a QmcpackInput).

    Errors are reported through QmcpackInput.class_error when:
      - a tuple/list entry names a method that is not in opt_map;
      - an entry is neither a QIxml object nor a 5-element tuple/list.
    """
    jastrows = generate_jastrows(jastrows,system)

    if opt_calcs is None:
        opt_calcs = [
            ('linear', 4,  0,  0, 1.0),
            ('linear', 4, .8, .2,   0)
            ]
    #end if
    opts = []
    for opt_calc in opt_calcs:
        if isinstance(opt_calc,QIxml):
            # pre-built calculation element: use it as provided
            opts.append(opt_calc)
        elif isinstance(opt_calc,(tuple,list)) and len(opt_calc)==5:
            method = opt_calc[0]
            if method in opt_map:
                opts.append(
                    generate_opt(
                        *opt_calc,
                         jastrows         = jastrows,
                         processes        = processes,
                         walkers_per_proc = walkers_per_proc,
                         threads          = threads,
                         decorr           = decorr,
                         min_walkers      = min_walkers,
                         timestep         = timestep,
                         nonlocalpp       = nonlocalpp,
                         sample_factor    = sample_factor
                         )
                    )
            else:
                # use format rather than '+' so that a non-string method
                # name still yields this error instead of a TypeError
                QmcpackInput.class_error('optimization method {} has not yet been implemented'.format(method))
            #end if
        else:
            # covers wrong-length sequences and non-sequence entries alike
            QmcpackInput.class_error('optimization calculation is ill formatted\n  opt calc provided: \n'+str(opt_calc))
        #end if
    #end for

    qi = generate_basic_input(
        id             = id             ,
        series         = series         ,
        purpose        = purpose        ,
        seed           = seed           ,
        bconds         = bconds         ,
        remove_cell    = remove_cell    ,
        meshfactor     = meshfactor     ,
        precision      = precision      ,
        twistnum       = twistnum       ,
        twist          = twist          ,
        spin_polarized = spin_polarized ,
        orbitals_h5    = orbitals_h5    ,
        system         = system         ,
        pseudos        = pseudos        ,
        jastrows       = jastrows       ,
        corrections    = corrections    ,
        observables    = observables    ,
        calculations   = opts           ,
        det_format     = det_format     ,
        )

    return qi
#end def generate_opt_jastrow_input
6865
6866
6867
6868
6869
6870
6871
6872if __name__=='__main__':
6873
6874    filepath = './example_input_files/c_boron/qmcpack.in.xml'
6875
6876    element_joins=['qmcsystem']
6877    element_aliases=dict(loop='qmc')
6878    xml = XMLreader(filepath,element_joins,element_aliases,warn=False).obj
6879    xml.condense()
6880
6881    qi = QmcpackInput()
6882    qi.read(filepath)
6883
6884
6885    s = qi.simulation
6886    q = s.qmcsystem
6887    c = s.calculations
6888    h = q.hamiltonian
6889    p = q.particlesets
6890    w = q.wavefunction
6891    j = w.jastrows
6892    co= j.J1.correlations.B.coefficients
6893
6894
6895    qi.write('./output/qmcpack.in.xml')
6896
6897    #qi.condensed_name_report()
6898    #exit()
6899
6900
6901    test_ret_system    = 1
6902    test_gen_input     = 0
6903    test_difference    = 0
6904    test_moves         = 0
6905    test_defaults      = 0
6906    test_substitution  = 0
6907    test_generation    = 0
6908
6909
6910    if test_ret_system:
6911        from structure import generate_structure
6912        from physical_system import PhysicalSystem
6913
6914        system = PhysicalSystem(
6915            structure = generate_structure('diamond','fcc','Ge',(2,2,2),scale=5.639,units='A'),
6916            net_charge = 1,
6917            net_spin   = 1,
6918            Ge = 4
6919        )
6920
6921        gi = generate_qmcpack_input('basic',system=system)
6922
6923        rsys = gi.return_system()
6924
6925        print(rsys)
6926
6927    #end if
6928
6929
6930    if test_gen_input:
6931        from structure import generate_structure
6932        from physical_system import PhysicalSystem
6933
6934        system = PhysicalSystem(
6935            structure = generate_structure('diamond','fcc','Ge',(2,2,2),scale=5.639,units='A'),
6936            net_charge = 1,
6937            net_spin   = 1,
6938            Ge = 4
6939        )
6940
6941        gi = generate_qmcpack_input('basic',system=system)
6942
6943        print(gi)
6944
6945        print(gi.write())
6946    #end if
6947
6948
6949
6950    if test_difference:
6951        tstep = QmcpackInput('./example_input_files/luke_tutorial/diamond-dmcTsteps.xml')
6952        opt   = QmcpackInput('./example_input_files/luke_tutorial/opt-diamond.xml')
6953
6954        different,diff,d1,d2 = tstep.difference(tstep)
6955        different,diff,d1,d2 = tstep.difference(opt)
6956
6957    #end if
6958
6959
6960
6961    if test_moves:
6962        print(50*'=')
6963        sim = qi.simulation
6964        print(repr(sim))
6965        print(repr(sim.qmcsystem))
6966        print(50*'=')
6967        qi.move(particleset='simulation')
6968        print(repr(sim))
6969        print(repr(sim.qmcsystem))
6970        print(50*'=')
6971        qi.standard_placements()
6972        print(repr(sim))
6973        print(repr(sim.qmcsystem))
6974
6975        qi.pluralize()
6976    #end if
6977
6978
6979
6980    if test_defaults:
6981        q=QmcpackInput(
6982            simulation(
6983                qmcsystem=section(
6984                    simulationcell = section(),
6985                    wavefunction = section(),
6986                    hamiltonian = section()
6987                    ),
6988                calculations = [
6989                    cslinear(),
6990                    vmc(),
6991                    dmc()
6992                    ]
6993                )
6994            )
6995
6996        #q.simulation = simulation()
6997
6998        q.incorporate_defaults(elements=True)
6999
7000        print(q)
7001    #end if
7002
7003
7004    if test_substitution:
7005        q = qi.copy()
7006
7007        q.remove('simulationcell','particleset','wavefunction')
7008        q.write('./output/qmcpack.remove.xml')
7009        q.include_xml('./example_input_files/energy_density/Si.ptcl.xml',replace=False)
7010        q.include_xml('./example_input_files/energy_density/Si.wfs.xml',replace=False)
7011        q.write('./output/qmcpack.replace.xml')
7012
7013        qnj = QmcpackInput()
7014        qnj.read('./example_input_files/jastrowless/opt_jastrow.in.xml')
7015
7016        qnj.generate_jastrows(size=6)
7017        qnj.write('./output/jastrow_gen.in.xml')
7018
7019    #end if
7020
7021
7022
7023    if test_generation:
7024
7025        q=QmcpackInput(
7026            meta(
7027                lattice    = {'units':'bohr'},
7028                reciprocal = {'units':'2pi/bohr'},
7029                ionid      = {'datatype':'stringArray'},
7030                position   = {'datatype':'posArray', 'condition':0}
7031                ),
7032            simulation(
7033                project = section(
7034                    id='C16B',
7035                    series = 0,
7036                    application = section(
7037                        name = 'qmcpack',
7038                        role = 'molecu',
7039                        class_ = 'serial',
7040                        version = .2
7041                        ),
7042                    host = 'kraken',
7043                    date = '3 May 2012',
7044                    user = 'jtkrogel'
7045                    ),
7046                random = section(seed=13),
7047                qmcsystem = section(
7048                    simulationcell = section(
7049                        name = 'global',
7050                        lattice = array([[1,1,0],[1,0,1],[0,1,1]]),
7051                        reciprocal = array([[1,1,-1],[1,-1,1],[-1,1,1]]),
7052                        bconds = 'p p p',
7053                        LR_dim_cutoff = 15
7054                        ),
7055                    particlesets = [
7056                        particleset(
7057                            name = 'ion0',
7058                            size = 32,
7059                            groups=[
7060                                group(
7061                                    name='C',
7062                                    charge=4.
7063                                    ),
7064                                group(
7065                                    name='B',
7066                                    charge = 3.
7067                                    )
7068                                ],
7069                            ionid = ['B','C','C','C','C','C','C','C','C','C','C','C','C','C','C','C',
7070                                     'B','C','C','C','C','C','C','C','C','C','C','C','C','C','C','C'],
7071                            position = array([
7072                                    [ 0.00, 0.00, 0.00],[ 1.68, 1.68, 1.68],[ 3.37, 3.37, 0.00],
7073                                    [ 5.05, 5.05, 1.68],[ 3.37, 0.00, 3.37],[ 5.05, 1.68, 5.05],
7074                                    [ 6.74, 3.37, 3.37],[ 8.42, 5.05, 5.05],[ 0.00, 3.37, 3.37],
7075                                    [ 1.68, 5.05, 5.05],[ 3.37, 6.74, 3.37],[ 5.05, 8.42, 5.05],
7076                                    [ 3.37, 3.37, 6.74],[ 5.05, 5.05, 8.42],[ 6.74, 6.74, 6.74],
7077                                    [ 8.42, 8.42, 8.42],[ 6.74, 6.74, 0.00],[ 8.42, 8.42, 1.68],
7078                                    [10.11,10.11, 0.00],[11.79,11.79, 1.68],[10.11, 6.74, 3.37],
7079                                    [11.79, 8.42, 5.05],[13.48,10.11, 3.37],[15.16,11.79, 5.05],
7080                                    [ 6.74,10.11, 3.37],[ 8.42,11.79, 5.05],[10.11,13.48, 3.37],
7081                                    [11.79,15.16, 5.05],[10.11,10.11, 6.74],[11.79,11.79, 8.42],
7082                                    [13.48,13.48, 6.74],[15.16,15.16, 8.42]])
7083                            ),
7084                        particleset(
7085                            name='e',
7086                            random = 'yes',
7087                            random_source = 'ion0',
7088                            groups=[
7089                                group(
7090                                    name='u',
7091                                    size=64,
7092                                    charge=-1
7093                                    ),
7094                                group(
7095                                    name='d',
7096                                    size=63,
7097                                    charge=-1
7098                                    )
7099                                ]
7100                            ),
7101                        ],
7102                    hamiltonians = [
7103                        hamiltonian(
7104                            name='h0',
7105                            type='generic',
7106                            target='e',
7107                            pairpots=[
7108                                pairpot(
7109                                    type = 'coulomb',
7110                                    name = 'ElecElec',
7111                                    source = 'e',
7112                                    target = 'e'
7113                                    ),
7114                                pairpot(
7115                                    type = 'pseudo',
7116                                    name = 'PseudoPot',
7117                                    source = 'ion0',
7118                                    wavefunction='psi0',
7119                                    format='xml',
7120                                    pseudos = [
7121                                        pseudo(
7122                                            elementtype='B',
7123                                            href='B.pp.xml'
7124                                            ),
7125                                        pseudo(
7126                                            elementtype='C',
7127                                            href='C.pp.xml'
7128                                            )
7129                                        ]
7130                                    )
7131                                ],
7132                            constant = section(
7133                                type='coulomb',
7134                                name='IonIon',
7135                                source='ion0',
7136                                target='ion0'
7137                                ),
7138                            estimators = [
7139                                estimator(
7140                                    type='energydensity',
7141                                    name='edvoronoi',
7142                                    dynamic='e',
7143                                    static='ion0',
7144                                    spacegrid = section(
7145                                        coord = 'voronoi'
7146                                        )
7147                                    ),
7148                                energydensity(
7149                                    name='edchempot',
7150                                    dynamic='e',
7151                                    static='ion0',
7152                                    spacegrid=spacegrid(
7153                                        coord='voronoi',
7154                                        min_part=-4,
7155                                        max_part=5
7156                                        )
7157                                    ),
7158                                estimator(
7159                                    type='energydensity',
7160                                    name='edcell',
7161                                    dynamic='e',
7162                                    static='ion0',
7163                                    spacegrid = section(
7164                                        coord = 'cartesian',
7165                                        origin = section(p1='zero'),
7166                                        axes   = (
7167                                            axis(label='x',p1='a1',scale=.5,grid='-1 (192) 1'),
7168                                            axis(label='y',p1='a2',scale=.5,grid='-1 (1) 1'),
7169                                            axis(label='z',p1='a3',scale=.5,grid='-1 (1) 1')
7170                                            )
7171        #                                axes   = collection(
7172        #                                    x = section(p1='a1',scale=.5,grid='-1 (192) 1'),
7173        #                                    y = section(p1='a2',scale=.5,grid='-1 (1) 1'),
7174        #                                    z = section(p1='a3',scale=.5,grid='-1 (1) 1')
7175        #                                    )
7176                                        )
7177                                    )
7178                                ]
7179                            )
7180                        ],
7181                    wavefunction = section(
7182                        name = 'psi0',
7183                        target = 'e',
7184                        determinantset = section(
7185                            type='bspline',
7186                            href='Si.pwscf.h5',
7187                            sort = 1,
7188                            tilematrix = array([[1,0,0],[0,1,0],[0,0,1]]),
7189                            twistnum = 0,
7190                            source = 'ion0',
7191                            slaterdeterminant = section(
7192                                determinants=[
7193                                    determinant(
7194                                        id='updet',
7195                                        size=64,
7196                                        occupation = section(
7197                                            mode='ground',
7198                                            spindataset=0
7199                                            )
7200                                        ),
7201                                    determinant(
7202                                        id='downdet',
7203                                        size=63,
7204                                        occupation = section(
7205                                            mode='ground',
7206                                            spindataset=1
7207                                            )
7208                                        )
7209                                    ]
7210                                ),
7211                            ),
7212                        jastrows = [
7213                            jastrow(
7214                                type='two-body',
7215                                name='J2',
7216                                function='bspline',
7217                                print='yes',
7218                                correlations = [
7219                                    correlation(
7220                                        speciesA='u',
7221                                        speciesB='u',
7222                                        size=6,
7223                                        rcut=3.9,
7224                                        coefficients = section(
7225                                            id='uu',
7226                                            type='Array',
7227                                            coeff=[0,0,0,0,0,0]
7228                                            )
7229                                        ),
7230                                    correlation(
7231                                        speciesA='u',
7232                                        speciesB='d',
7233                                        size=6,
7234                                        rcut=3.9,
7235                                        coefficients = section(
7236                                            id='ud',
7237                                            type='Array',
7238                                            coeff=[0,0,0,0,0,0]
7239                                            )
7240                                        )
7241                                    ]
7242                                ),
7243                            jastrow(
7244                                type='one-body',
7245                                name='J1',
7246                                function='bspline',
7247                                source='ion0',
7248                                print='yes',
7249                                correlations = [
7250                                    correlation(
7251                                        elementtype='C',
7252                                        size=6,
7253                                        rcut=3.9,
7254                                        coefficients = section(
7255                                            id='eC',
7256                                            type='Array',
7257                                            coeff=[0,0,0,0,0,0]
7258                                            )
7259                                        ),
7260                                    correlation(
7261                                        elementtype='B',
7262                                        size=6,
7263                                        rcut=3.9,
7264                                        coefficients = section(
7265                                            id='eB',
7266                                            type='Array',
7267                                            coeff=[0,0,0,0,0,0]
7268                                            )
7269                                        )
7270                                    ]
7271                                )
7272                            ]
7273                        ),
7274
7275                    ),
7276                calculations=[
7277                    loop(max=4,
7278                         qmc=qmc(
7279                            method='cslinear',
7280                            move='pbyp',
7281                            checkpoint=-1,
7282                            gpu='no',
7283                            blocks = 3125,
7284                            warmupsteps = 5,
7285                            steps = 2,
7286                            samples = 80000,
7287                            timestep = .5,
7288                            usedrift = 'yes',
7289                            minmethod = 'rescale',
7290                            gevmethod = 'mixed',
7291                            exp0=-15,
7292                            nstabilizers = 5,
7293                            stabilizerscale = 3,
7294                            stepsize=.35,
7295                            alloweddifference=1e-5,
7296                            beta = .05,
7297                            bigchange = 5.,
7298                            energy = 0.,
7299                            unreweightedvariance = 0.,
7300                            reweightedvariance = 0.,
7301                            estimators=[
7302                                estimator(
7303                                    name='LocalEnergy',
7304                                    hdf5='no'
7305                                    )
7306                                ]
7307                            )
7308                        ),
7309                    qmc(
7310                        method = 'vmc',
7311                        multiple = 'no',
7312                        warp = 'no',
7313                        move = 'pbyp',
7314                        walkers = 1,
7315                        blocks = 2,
7316                        steps = 500,
7317                        substeps = 3,
7318                        timestep = .5,
7319                        usedrift = 'yes',
7320                        estimators=[
7321                            estimator(
7322                                name='LocalEnergy',
7323                                hdf5='yes'
7324                                )
7325                            ]
7326                        ),
7327                    qmc(
7328                        method='dmc',
7329                        move='pbyp',
7330                        walkers = 72,
7331                        blocks = 2,
7332                        steps = 50,
7333                        timestep = .01,
7334                        nonlocalmove = 'yes',
7335                        estimators=[
7336                            estimator(
7337                                name='LocalEnergy',
7338                                hdf5='no'
7339                                )
7340                            ]
7341                        )
7342                    ]
7343                )
7344            )
7345
7346        q.write('./output/gen.in.xml')
7347
7348
7349
7350
7351
7352        #something broke this, check later
7353        exit()
7354        qs=QmcpackInput(
7355            simulation = section(
7356                project = section(
7357                    id='C16B',series = 0,
7358                    application = section(name='qmcpack',role='molecu',class_='serial',version=.2),
7359                    host='kraken',date='3 May 2012',user='jtkrogel'
7360                    ),
7361                random = section(seed=13),
7362                qmcsystem = section(
7363                    simulationcell = section(
7364                        name='global',bconds='p p p',lr_dim_cutoff=15,
7365                        lattice    = [[1,1,0] ,[1,0,1] ,[0,1,1]],
7366                        reciprocal = [[1,1,-1],[1,-1,1],[-1,1,1]],
7367                        ),
7368                    particlesets = collection(
7369                        ion0=particleset(
7370                            size=32,
7371                            groups=collection(
7372                                C = group(charge=4.),
7373                                B = group(charge=3.)),
7374                            ionid = ('B','C','C','C','C','C','C','C','C','C','C','C','C','C','C','C',
7375                                     'B','C','C','C','C','C','C','C','C','C','C','C','C','C','C','C'),
7376                            position = [[ 0.00, 0.00, 0.00],[ 1.68, 1.68, 1.68],[ 3.37, 3.37, 0.00],
7377                                        [ 5.05, 5.05, 1.68],[ 3.37, 0.00, 3.37],[ 5.05, 1.68, 5.05],
7378                                        [ 6.74, 3.37, 3.37],[ 8.42, 5.05, 5.05],[ 0.00, 3.37, 3.37],
7379                                        [ 1.68, 5.05, 5.05],[ 3.37, 6.74, 3.37],[ 5.05, 8.42, 5.05],
7380                                        [ 3.37, 3.37, 6.74],[ 5.05, 5.05, 8.42],[ 6.74, 6.74, 6.74],
7381                                        [ 8.42, 8.42, 8.42],[ 6.74, 6.74, 0.00],[ 8.42, 8.42, 1.68],
7382                                        [10.11,10.11, 0.00],[11.79,11.79, 1.68],[10.11, 6.74, 3.37],
7383                                        [11.79, 8.42, 5.05],[13.48,10.11, 3.37],[15.16,11.79, 5.05],
7384                                        [ 6.74,10.11, 3.37],[ 8.42,11.79, 5.05],[10.11,13.48, 3.37],
7385                                        [11.79,15.16, 5.05],[10.11,10.11, 6.74],[11.79,11.79, 8.42],
7386                                        [13.48,13.48, 6.74],[15.16,15.16, 8.42]]
7387                            ),
7388                        e=particleset(
7389                            random='yes',random_source='ion0',
7390                            groups = collection(
7391                                u=group(size=64,charge=-1),
7392                                d=group(size=63,charge=-1))
7393                            ),
7394                        ),
7395                    hamiltonian = section(
7396                        name='h0',type='generic',target='e',
7397                        pairpots=collection(
7398                            ElecElec = coulomb(name='ElecElec',source='e',target='e'),
7399                            PseudoPot = pseudopotential(
7400                                source='ion0',wavefunction='psi0',format='xml',
7401                                pseudos = collection(
7402                                    B = pseudo(href='B.pp.xml'),
7403                                    C = pseudo(href='C.pp.xml'))
7404                                )
7405                            ),
7406                        constant = section(type='coulomb',name='IonIon',source='ion0',target='ion0'),
7407                        estimators = collection(
7408                            edvoronoi = energydensity(
7409                                dynamic='e',static='ion0',spacegrid=section(coord ='voronoi')
7410                                ),
7411                            edchempot = energydensity(
7412                                dynamic='e',static='ion0',
7413                                spacegrid=section(coord='voronoi',min_part=-4,max_part=5)
7414                                ),
7415                            edcell = energydensity(
7416                                dynamic='e',static='ion0',
7417                                spacegrid = section(
7418                                    coord = 'cartesian',
7419                                    origin = section(p1='zero'),
7420                                    axes = collection(
7421                                        x = axis(p1='a1',scale=.5,grid='-1 (192) 1'),
7422                                        y = axis(p1='a2',scale=.5,grid='-1 (1) 1'),
7423                                        z = axis(p1='a3',scale=.5,grid='-1 (1) 1'))
7424                                    )
7425                                )
7426                            )
7427                        ),
7428                    wavefunction = section(
7429                        name = 'psi0',target = 'e',
7430                        determinantset = section(
7431                            type='bspline',href='Si.pwscf.h5',sort=1,twistnum=0,source='ion0',
7432                            tilematrix=(1,0,0,0,1,0,0,0,1),
7433                            slaterdeterminant = section(
7434                                determinants=collection(
7435                                    updet = determinant(
7436                                        size=64,
7437                                        occupation=section(mode='ground',spindataset=0)
7438                                        ),
7439                                    downdet = determinant(
7440                                        size=63,
7441                                        occupation = section(mode='ground',spindataset=1))
7442                                    )
7443                                ),
7444                            ),
7445                        jastrows = collection(
7446                            J2=jastrow2(
7447                                function='bspline',print='yes',
7448                                correlations = collection(
7449                                    uu=correlation(
7450                                        speciesA='u',speciesB='u',size=6,rcut=3.9,
7451                                        coefficients = section(id='uu',type='Array',coeff=[0,0,0,0,0,0])
7452                                        ),
7453                                    ud=correlation(
7454                                        speciesA='u',speciesB='d',size=6,rcut=3.9,
7455                                        coefficients = section(id='ud',type='Array',coeff=[0,0,0,0,0,0])
7456                                        )
7457                                    )
7458                                ),
7459                            J1=jastrow1(
7460                                function='bspline',source='ion0',print='yes',
7461                                correlations = collection(
7462                                    C=correlation(
7463                                        size=6,rcut=3.9,
7464                                        coefficients = section(
7465                                            id='eC',type='Array',coeff=[0,0,0,0,0,0])
7466                                        ),
7467                                    B=correlation(
7468                                        size=6,rcut=3.9,
7469                                        coefficients = section(id='eB',type='Array',coeff=[0,0,0,0,0,0])
7470                                        )
7471                                    )
7472                                )
7473                            )
7474                        ),
7475                    ),
7476                calculations=(
7477                    loop(max=4,
7478                         qmc=cslinear(
7479                            move='pbyp',checkpoint=-1,gpu='no',
7480                            blocks      = 3125,
7481                            warmupsteps = 5,
7482                            steps       = 2,
7483                            samples     = 80000,
7484                            timestep    = .5,
7485                            usedrift    = 'yes',
7486                            minmethod   = 'rescale',
7487                            gevmethod   = 'mixed',
7488                            exp0              = -15,
7489                            nstabilizers      =  5,
7490                            stabilizerscale   =  3,
7491                            stepsize          =  .35,
7492                            alloweddifference = 1e-5,
7493                            beta              = .05,
7494                            bigchange         = 5.,
7495                            energy               = 0.,
7496                            unreweightedvariance = 0.,
7497                            reweightedvariance   = 0.,
7498                            estimator = localenergy(hdf5='no')
7499                            )
7500                        ),
7501                    vmc(multiple='no',warp='no',move='pbyp',
7502                        walkers  =  1,
7503                        blocks   =  2,
7504                        steps    = 500,
7505                        substeps =  3,
7506                        timestep = .5,
7507                        usedrift = 'yes',
7508                        estimator = localenergy(hdf5='no')
7509                        ),
7510                    dmc(move='pbyp',
7511                        walkers  =  72,
7512                        blocks   =   2,
7513                        steps    =  50,
7514                        timestep = .01,
7515                        nonlocalmove = 'yes',
7516                        estimator = localenergy(hdf5='yes')
7517                        )
7518                    )
7519                )
7520            )
7521
7522        qs.write('./output/simple.in.xml')
7523
7524
7525        est = qs.simulation.qmcsystem.hamiltonian.estimators
7526        sg = est.edcell.spacegrid
7527        print(repr(est))
7528
7529        exit()
7530
7531
7532
7533
7534
7535
7536
7537
7538
7539
7540
7541
7542
7543
7544
7545
7546
7547
7548
        # NOTE: unreachable. Execution of this branch is terminated by the
        # exit call above, so nothing from here on runs.
        # This appears to be a sketch of a more compact, positional-argument
        # way of building an input similar to the keyword-argument one
        # constructed above (same pseudopotential files, energydensity
        # estimators, bspline determinantset from 'Si.pwscf.h5', J1/J2
        # jastrows and loop(cslinear) -> vmc -> dmc sequence).
        # NOTE(review): whether section/particleset/energydensity/
        # determinant/twobody/onebody/loop accept these positional arguments
        # is not visible here; this code may not run as written.
        q=QmcpackInput()
        q.simulation = section(
            project = section('C16B',0,
                application = section('qmcpack','molecu','serial',.2),
                host = 'kraken',
                date = '3 May 2012',
                user = 'jtkrogel'
                ),
            # Parentheses alone do not make a tuple: this value is the int 13.
            random = (13),
            qmcsystem = section(
                simulationcell = section(
                    units = 'bohr',
                    lattice = array([[1,1,0],[1,0,1],[0,1,1]]),
                    bconds = 'p p p',
                    LR_dim_cutoff = 15
                    ),
                particlesets = [
                    # NOTE(review): ('C',4), ('B',3) does not match ionid
                    # below, which lists 30 'C' and 2 'B' (32 ions, matching
                    # the 32 rows of position) -- confirm intended counts.
                    particleset('ion0', ('C',4), ('B',3),
                        ionid = ['B','C','C','C','C','C','C','C','C','C','C','C','C','C','C','C',
                                 'B','C','C','C','C','C','C','C','C','C','C','C','C','C','C','C'],
                        position = array([
                                [ 0.00, 0.00, 0.00],[ 1.68, 1.68, 1.68],[ 3.37, 3.37, 0.00],
                                [ 5.05, 5.05, 1.68],[ 3.37, 0.00, 3.37],[ 5.05, 1.68, 5.05],
                                [ 6.74, 3.37, 3.37],[ 8.42, 5.05, 5.05],[ 0.00, 3.37, 3.37],
                                [ 1.68, 5.05, 5.05],[ 3.37, 6.74, 3.37],[ 5.05, 8.42, 5.05],
                                [ 3.37, 3.37, 6.74],[ 5.05, 5.05, 8.42],[ 6.74, 6.74, 6.74],
                                [ 8.42, 8.42, 8.42],[ 6.74, 6.74, 0.00],[ 8.42, 8.42, 1.68],
                                [10.11,10.11, 0.00],[11.79,11.79, 1.68],[10.11, 6.74, 3.37],
                                [11.79, 8.42, 5.05],[13.48,10.11, 3.37],[15.16,11.79, 5.05],
                                [ 6.74,10.11, 3.37],[ 8.42,11.79, 5.05],[10.11,13.48, 3.37],
                                [11.79,15.16, 5.05],[10.11,10.11, 6.74],[11.79,11.79, 8.42],
                                [13.48,13.48, 6.74],[15.16,15.16, 8.42]])
                        ),
                    # 64 'u' and 63 'd' electrons, matching the updet/downdet
                    # sizes in the determinantset below.
                    particleset('e', ('u',-1,64), ('d',-1,63), random_source = 'ion0'),
                    ],
                hamiltonian = section('h0','e',
                    pairpots=[
                        coulomb('ElecElec','e','e'),
                        pseudopotential('PseudoPot','ion0','psi0',('B','B.pp.xml'),('C','C.pp.xml')),
                        coulomb('IonIon','ion0','ion0'),
                        ],
                    estimators = [
                        # The keyword-argument input above names the voronoi
                        # energydensity estimator 'edchempot', not 'edvoronoi'.
                        energydensity('edvoronoi','e','ion0','voronoi',-4,5),
                        energydensity('edcell','e','ion0',
                            spacegrid('cartesian',
                                origin = 'zero',
                                x = ('a1',.5,'-1 (192) 1'),
                                y = ('a2',.5,'-1 (1) 1'),
                                z = ('a3',.5,'-1 (1) 1')
                                )
                            )
                        ]
                    ),
                wavefunction = section('psi0','e',
                    determinantset = section('bspline','Si.pwscf.h5','ion0',
                        sort = 1,
                        tilematrix = array([[1,0,0],[0,1,0],[0,0,1]]),
                        twistnum = 0,
                        slaterdeterminant = [
                            determinant('updet',64,'ground',0),
                            determinant('downdet',63,'ground',1)
                            ],
                        # NOTE(review): jastrows is nested inside
                        # determinantset here, whereas the input above places
                        # jastrows as a sibling of determinantset under
                        # wavefunction -- confirm intended placement.
                        jastrows = [
                            twobody('J2','bspline',
                                    ('u','u',3.9,[0,0,0,0,0,0]),
                                    ('u','d',3.9,[0,0,0,0,0,0])),
                            onebody('J1','bspline','ion0',
                                    ('C',3.9,[0,0,0,0,0,0]),
                                    ('B',3.9,[0,0,0,0,0,0]))
                            ]
                        )
                    )
                ),
            calculations=[
                # Four repetitions of a cslinear optimization.
                loop(4,
                    cslinear(
                        blocks = 3125,
                        warmupsteps = 5,
                        steps = 2,
                        samples = 80000,
                        timestep = .5,
                        minmethod = 'rescale',
                        gevmethod = 'mixed',
                        exp0=-15,
                        nstabilizers = 5,
                        stabilizerscale = 3,
                        stepsize=.35,
                        alloweddifference=1e-5,
                        beta = .05,
                        bigchange = 5.,
                        energy = 0.,
                        unreweightedvariance = 0.,
                        reweightedvariance = 0.,
                        estimator = localenergy(hdf5='no')
                        )
                    ),
                # NOTE(review): the hdf5 flags on the vmc/dmc localenergy
                # estimators are swapped relative to the input above
                # (there vmc uses 'no' and dmc uses 'yes').
                vmc(
                    blocks = 2,
                    steps = 500,
                    substeps = 3,
                    timestep = .5,
                    estimator = localenergy(hdf5='yes')
                    ),
                dmc(
                    walkers = 72,
                    blocks = 2,
                    steps = 50,
                    timestep = .01,
                    nonlocalmove = 'yes',
                    estimator = localenergy(hdf5='no')
                    )
                ]
            )
7662    #end if
7663#end if
7664