1##################################################################
2##  (c) Copyright 2015-  by Jaron T. Krogel                     ##
3##################################################################
4
5
6#====================================================================#
7#  fileio.py                                                         #
8#    Support for I/O with various file formats.  Currently this only #
9#    contains a generic file I/O class for XSF files.  In the future #
10#    generic XML and HDF5 support should go here.  Input only        #
11#    interfaces to these formats can be found in hdfreader.py and    #
12#    xmlreader.py.                                                   #
13#                                                                    #
14#  Content summary:                                                  #
15#    XsfFile                                                         #
16#      Represents generic XSF, AXSF, and BXSF files.                 #
17#      Can read/write arbitrary files of these formats.              #
18#      Useful for atomic structure and electronic density I/O.       #
19#                                                                    #
20#====================================================================#
21
22
23import os
24import mmap
25import numpy as np
26from numpy import array,zeros,ndarray,around,arange,dot,savetxt,empty,reshape
27from numpy.linalg import det,norm
28from generic import obj
29from developer import DevBase,error,to_str
30from periodic_table import pt as ptable,is_element
31from unit_converter import convert
32from debug import *
33
34
class TextFile(DevBase):
    """
    Convenience wrapper around a read-only memory map of a text file.

    The file is opened in text mode (kept as ``self.f``) and mapped with
    ``mmap.PROT_READ`` (kept as ``self.mm``).  Search strings given as
    ``str`` are encoded to ASCII before being handed to mmap, and data
    returned by the ``read*``/``lines``/``tokens`` methods is passed
    through ``to_str``.
    """
    # interface to mmap files
    # see Python 2 documentation for mmap

    def __init__(self,filepath=None):
        """Create the wrapper and open `filepath` if one is given."""
        self.mm = None
        self.f  = None
        if filepath is not None:
            self.open(filepath)
        #end if
    #end def __init__

    def open(self,filepath):
        """
        Open `filepath` and memory map its full contents read-only.

        Calls self.error if the file does not exist.
        """
        if not os.path.exists(filepath):
            self.error('cannot open non-existent file: {0}'.format(filepath))
        #end if
        f = open(filepath,'r')
        try:
            mm = mmap.mmap(f.fileno(),0,prot=mmap.PROT_READ)
        except Exception:
            # do not leak the file handle if the mapping cannot be made
            f.close()
            raise
        #end try
        self.f  = f
        self.mm = mm
    #end def open

    def __iter__(self):
        """Iterate over the lines of the underlying file object."""
        for line in self.f:
            yield line
        #end for
    #end def __iter__

    def __getitem__(self,slc):
        """Return the mapped data selected by index or slice `slc`."""
        return self.mm[slc]
    #end def __getitem__

    def lines(self):
        """Return the whole file contents split into lines."""
        return self.read().splitlines()
    #end def lines

    def tokens(self):
        """Return the whole file contents split on whitespace."""
        return self.read().split()
    #end def tokens

    def readtokens(self,s=None):
        """Return the whitespace separated tokens of readline(s)."""
        return self.readline(s).split()
    #end def readtokens

    def readtokensf(self,s=None,*formats):
        """
        Read a line and convert its tokens with the callables in `formats`.

        If `s` is given, self.seek(s) is performed first.  One line is then
        read and discarded and the line after it is the one converted.
        NOTE(review): presumably the discarded line is the one containing
        the search text `s` -- confirm against callers.

        A single format is applied to every token when the line has more
        than one token.  Otherwise formats and tokens are paired in order;
        providing more formats than tokens is reported via self.error.
        A single converted token is returned bare, otherwise a list.
        """
        if s is not None:
            self.seek(s)
        #end if
        self.mm.readline()
        line = to_str(self.mm.readline())
        stokens = line.split()
        all_same = False
        fmt = None
        if len(formats)==1 and len(stokens)>1:
            fmt = formats[0]
            all_same = True
        elif len(formats)>len(stokens):
            self.error('formatted line read failed\nnumber of tokens and provided number of formats do not match\nline: {0}\nnumber of tokens: {1}\nnumber of formats provided: {2}'.format(line,len(stokens),len(formats)))
        #end if
        if all_same:
            tokens = [fmt(stoken) for stoken in stokens]
        else:
            tokens = [f(stoken) for f,stoken in zip(formats,stokens)]
        #end if
        if len(tokens)==1:
            return tokens[0]
        else:
            return tokens
        #end if
    #end def readtokensf

    # extended mmap interface below
    def close(self):
        """Close the memory map and the file; return mmap.close()'s result."""
        r = self.mm.close()
        self.f.close()
        return r
    #end def close

    def seek(self,pos,whence=0,start=None,end=None):
        """
        Move the current position of the memory map.

        If `pos` is not a str it is forwarded, with `whence`, to mmap.seek.

        If `pos` is a str it is searched for (ASCII encoded) and the
        position is moved to the start of the match:
          whence=0  forward search beginning at `start` (default 0)
          whence=1  forward search beginning at `start` (default current position)
          whence=2  backward search (rfind) within `start` (default 0) to `end`
        `end` optionally bounds the searched range.  Returns -1, leaving
        the position unchanged, if the text is not found.
        """
        if not isinstance(pos,str):
            return self.mm.seek(pos,whence)
        #end if
        text = pos.encode('ASCII')
        if start is None:
            if whence==0 or whence==2:
                # mmap.rfind needs an integer start; 0 searches the whole
                # file backward from the end
                start = 0
            elif whence==1:
                start = self.mm.tell()
            else:
                self.error('relative positioning must be either 0 (begin), 1 (current), or 2 (end)\nyou provided: {0}'.format(whence))
            #end if
        #end if
        if whence!=2:
            search = self.mm.find
        else:
            search = self.mm.rfind
        #end if
        if end is not None:
            loc = search(text,start,end)
        else:
            loc = search(text,start)
        #end if
        if loc!=-1:
            return self.mm.seek(loc,0)
        else:
            return -1
        #end if
    #end def seek

    def readline(self,s=None):
        """Optionally seek to `s`, then return the next line via to_str."""
        if s is not None:
            self.seek(s)
        #end if
        return to_str(self.mm.readline())
    #end def readline

    def read(self,num=None):
        """
        Return file contents via to_str.

        With num=None the entire map is returned (the current position is
        not used); otherwise `num` bytes are read from the current position.
        """
        if num is None:
            return to_str(self.mm[:])
        else:
            return to_str(self.mm.read(num))
        #end if
    #end def read


    # unchanged mmap interface below
    def find(self,*a,**kw):
        """Forward to mmap.find, ASCII encoding any str arguments."""
        args = []
        for v in a:
            if isinstance(v,str):
                args.append(v.encode('ASCII'))
            else:
                args.append(v)
            #end if
        #end for
        return self.mm.find(*args,**kw)
    #end def find

    def flush(self,*a,**kw):
        """Forward to mmap.flush."""
        return self.mm.flush(*a,**kw)
    #end def flush

    def move(self,dest,src,count):
        """Forward to mmap.move."""
        return self.mm.move(dest,src,count)
    #end def move

    def read_byte(self):
        """Forward to mmap.read_byte."""
        return self.mm.read_byte()
    #end def read_byte

    def resize(self,newsize):
        """Forward to mmap.resize."""
        return self.mm.resize(newsize)
    #end def resize

    def rfind(self,*a,**kw):
        """Forward to mmap.rfind, ASCII encoding any str arguments."""
        args = []
        for v in a:
            if isinstance(v,str):
                args.append(v.encode('ASCII'))
            else:
                args.append(v)
            #end if
        #end for
        return self.mm.rfind(*args,**kw)
    #end def rfind

    def size(self):
        """Forward to mmap.size."""
        return self.mm.size()
    #end def size

    def tell(self):
        """Forward to mmap.tell."""
        return self.mm.tell()
    #end def tell

    def write(self,string):
        """Forward to mmap.write (the map made by open() is read-only)."""
        return self.mm.write(string)
    #end def write

    def write_byte(self,byte):
        """Forward to mmap.write_byte (the map made by open() is read-only)."""
        return self.mm.write_byte(byte)
    #end def write_byte
#end class TextFile
226
227
228
class StandardFile(DevBase):
    """
    Base class for file formats that are read and written as whole text.

    Derived classes set `sftype` and implement read_text, write_text, and
    validity_checks.  read and write handle the file access and call
    check_valid, which reports any validity messages via self.error.
    """

    # short name of the file type, used in validity error messages
    sftype = ''

    def __init__(self,filepath=None):
        """Optionally read `filepath`; any non-str, non-None input is an error."""
        if filepath is None:
            pass
        elif isinstance(filepath,str):
            self.read(filepath)
        else:
            self.error('unsupported input: {0}'.format(filepath))
        #end if
    #end def __init__


    def read(self,filepath):
        """Read `filepath`, parse it with read_text, then check validity."""
        if not os.path.exists(filepath):
            self.error('read failed\nfile does not exist: {0}'.format(filepath))
        #end if
        # close the file promptly rather than relying on garbage collection
        with open(filepath,'r') as f:
            text = f.read()
        #end with
        self.read_text(text)
        self.check_valid('read failed')
    #end def read


    def write(self,filepath=None):
        """
        Check validity and return the text produced by write_text.

        The text is also written to `filepath` when one is provided.
        """
        self.check_valid('write failed')
        text = self.write_text()
        if filepath is not None:
            with open(filepath,'w') as f:
                f.write(text)
            #end with
        #end if
        return text
    #end def write


    def is_valid(self):
        """Return True if validity_checks reports no messages."""
        return len(self.validity_checks())==0
    #end def is_valid


    def check_valid(self,header=None):
        """
        Report all validity messages through self.error.

        `header`, if given, is placed on the first line of the message.
        Nothing happens when validity_checks returns no messages.
        """
        messages = self.validity_checks()
        if len(messages)>0:
            msg = ''
            if header is not None:
                msg += header+'\n'
            #end if
            msg += 'not a valid {0} file, see below for details\n'.format(self.sftype)
            for m in messages:
                msg+=m+'\n'
            #end for
            self.error(msg)
        #end if
    #end def check_valid


    def validity_checks(self):
        """Return a list of problem descriptions; empty here (always valid)."""
        messages = []
        return messages
    #end def validity_checks


    def read_text(self,text):
        """Parse file text; to be provided by derived classes."""
        self.not_implemented()
    #end def read_text


    def write_text(self):
        """Return file text; to be provided by derived classes."""
        self.not_implemented()
    #end def write_text

#end class StandardFile
300
301
302
303class XsfFile(StandardFile):
304
    # file type name reported in StandardFile validity messages
    sftype = 'xsf'

    # values self.filetype may take (set by read_text/incorporate_structure)
    filetypes     = set(['xsf','axsf','bxsf'])
    # periodicity keywords recognized by read_text
    periodicities = set(['molecule','polymer','slab','crystal'])
    # number of periodic dimensions for each periodicity keyword
    dimensions    = obj(molecule=0,polymer=1,slab=2,crystal=3)

    # ATOMS  are in units of Angstrom, only provided for 'molecule'
    # forces are in units of Hartree/Angstrom
    # each section should be followed by a blank line
314
315    def __init__(self,filepath=None):
316        self.filetype    = None
317        self.periodicity = None
318        StandardFile.__init__(self,filepath)
319    #end def __init__
320
321
322    def add_to_image(self,image,name,value):
323        if image is None:
324            self[name] = value
325        else:
326            if 'images' not in self:
327                self.images = obj()
328            #end if
329            if not image in self.images:
330                self.images[image] = obj()
331            #end if
332            self.images[image][name] = value
333        #end if
334    #end def add_to_image
335
336
337    # test needed for axsf and bxsf
338    def read_text(self,text):
339        lines = text.splitlines()
340        i=0
341        self.filetype = 'xsf'
342        while(i<len(lines)):
343            line = lines[i].strip().lower()
344            if len(line)>0 and line[0]!='#':
345                tokens = line.split()
346                keyword = tokens[0]
347                image = None
348                if len(tokens)==2:
349                    image = int(tokens[1])
350                #end if
351                if keyword in self.periodicities:
352                    self.periodicity = keyword
353                elif keyword=='animsteps':
354                    self.animsteps = int(tokens[1])
355                    self.filetype = 'axsf'
356                elif keyword=='primvec':
357                    primvec = array((lines[i+1]+' '+
358                                     lines[i+2]+' '+
359                                     lines[i+3]).split(),dtype=float)
360                    primvec.shape = 3,3
361                    self.add_to_image(image,'primvec',primvec)
362                    i+=3
363                elif keyword=='convvec':
364                    convvec = array((lines[i+1]+' '+
365                                     lines[i+2]+' '+
366                                     lines[i+3]).split(),dtype=float)
367                    convvec.shape = 3,3
368                    self.add_to_image(image,'convvec',convvec)
369                    i+=3
370                elif keyword=='atoms':
371                    if self.periodicity is None:
372                        self.periodicity='molecule'
373                    #end if
374                    i+=1
375                    tokens = lines[i].strip().split()
376                    elem  = []
377                    pos   = []
378                    force = []
379                    natoms = 0
380                    while len(tokens)==4 or len(tokens)==7:
381                        natoms+=1
382                        elem.append(tokens[0])
383                        pos.extend(tokens[1:4])
384                        if len(tokens)==7:
385                            force.extend(tokens[4:7])
386                        #end if
387                        i+=1
388                        tokens = lines[i].strip().split()
389                    #end while
390                    elem = array(elem,dtype=int)
391                    pos  = array(pos,dtype=float)
392                    pos.shape = natoms,3
393                    self.add_to_image(image,'elem',elem)
394                    self.add_to_image(image,'pos',pos)
395                    if len(force)>0:
396                        force = array(force,dtype=float)
397                        force.shape = natoms,3
398                        self.add_to_image(image,'force',force)
399                    #end if
400                    i-=1
401                elif keyword=='primcoord':
402                    natoms = int(lines[i+1].split()[0])
403                    elem  = []
404                    pos   = []
405                    force = []
406                    for iat in range(natoms):
407                        tokens = lines[i+2+iat].split()
408                        elem.append(tokens[0])
409                        pos.extend(tokens[1:4])
410                        if len(tokens)==7:
411                            force.extend(tokens[4:7])
412                        #end if
413                    #end for
414                    try:
415                        elem = array(elem,dtype=int)
416                    except:
417                        elem = array(elem,dtype=str)
418                    #end try
419                    pos  = array(pos,dtype=float)
420                    pos.shape = natoms,3
421                    self.add_to_image(image,'elem',elem)
422                    self.add_to_image(image,'pos',pos)
423                    if len(force)>0:
424                        force = array(force,dtype=float)
425                        force.shape = natoms,3
426                        self.add_to_image(image,'force',force)
427                    #end if
428                    i+=natoms+1
429                elif keyword.startswith('begin_block_datagrid'):
430                    if keyword.endswith('2d'):
431                        d=2
432                    elif keyword.endswith('3d'):
433                        d=3
434                    else:
435                        self.error('dimension of datagrid could not be identified: '+line)
436                    #end if
437                    i+=1
438                    block_identifier = lines[i].strip().lower()
439                    if not 'data' in self:
440                        self.data = obj()
441                    #end if
442                    if not d in self.data:
443                        self.data[d] = obj()
444                    #end if
445                    if not block_identifier in self.data[d]:
446                        self.data[d][block_identifier]=obj()
447                    #end if
448                    data = self.data[d][block_identifier]
449
450                    line = ''
451                    while not line.startswith('end_block_datagrid'):
452                        line = lines[i].strip().lower()
453                        if line.startswith('begin_datagrid') or line.startswith('datagrid_'):
454                            grid_identifier = line.replace('begin_datagrid_{0}d_'.format(d),'')
455                            grid   = array(lines[i+1].split(),dtype=int)[:d]
456                            corner = array(lines[i+2].split(),dtype=float)
457                            if d==2:
458                                cell   = array((lines[i+3]+' '+
459                                                lines[i+4]).split(),dtype=float)
460                                i+=5
461                            elif d==3:
462                                cell   = array((lines[i+3]+' '+
463                                                lines[i+4]+' '+
464                                                lines[i+5]).split(),dtype=float)
465                                i+=6
466                            #end if
467                            cell.shape = d,3
468                            dtokens = []
469                            line = lines[i].strip().lower()
470                            while not line.startswith('end_datagrid'):
471                                dtokens.extend(line.split())
472                                i+=1
473                                line = lines[i].strip().lower()
474                            #end while
475                            grid_data = array(dtokens,dtype=float)
476                            grid_data=reshape(grid_data,grid,order='F')
477                            data[grid_identifier] = obj(
478                                grid   = grid,
479                                corner = corner,
480                                cell   = cell,
481                                values = grid_data
482                                )
483                        #end if
484                        i+=1
485                    #end while
486                elif keyword=='begin_info':
487                    self.info = obj()
488                    while line.lower()!='end_info':
489                        line = lines[i].strip()
490                        if len(line)>0 and line[0]!='#' and ':' in line:
491                            k,v = line.split(':')
492                            self.info[k.strip()] = v.strip()
493                        #end if
494                        i+=1
495                    #end while
496                elif keyword.startswith('begin_block_bandgrid'):
497                    self.filetype = 'bxsf'
498                    if keyword.endswith('2d'):
499                        d=2
500                    elif keyword.endswith('3d'):
501                        d=3
502                    else:
503                        self.error('dimension of bandgrid could not be identified: '+line)
504                    #end if
505                    i+=1
506                    block_identifier = lines[i].strip().lower()
507                    if not 'band' in self:
508                        self.band = obj()
509                    #end if
510                    if not d in self.band:
511                        self.band[d] = obj()
512                    #end if
513                    if not block_identifier in self.band[d]:
514                        self.band[d][block_identifier]=obj()
515                    #end if
516                    band = self.band[d][block_identifier]
517
518                    line = ''
519                    while not line.startswith('end_block_bandgrid'):
520                        line = lines[i].strip().lower()
521                        if line.startswith('begin_bandgrid'):
522                            grid_identifier = line.replace('begin_bandgrid_{0}d_'.format(d),'')
523                            nbands = int(lines[i+1].strip())
524                            grid   = array(lines[i+2].split(),dtype=int)[:d]
525                            corner = array(lines[i+3].split(),dtype=float)
526                            if d==2:
527                                cell   = array((lines[i+4]+' '+
528                                                lines[i+5]).split(),dtype=float)
529                                i+=6
530                            elif d==3:
531                                cell   = array((lines[i+4]+' '+
532                                                lines[i+5]+' '+
533                                                lines[i+6]).split(),dtype=float)
534                                i+=7
535                            #end if
536                            cell.shape = d,3
537                            bands = obj()
538                            line = lines[i].strip().lower()
539                            while not line.startswith('end_bandgrid'):
540                                if line.startswith('band'):
541                                    band_index = int(line.split(':')[1].strip())
542                                    bands[band_index] = []
543                                else:
544                                    bands[band_index].extend(line.split())
545                                #end if
546                                i+=1
547                                line = lines[i].strip().lower()
548                            #end while
549                            for bi,bv in bands.items():
550                                bands[bi] = array(bv,dtype=float)
551                                bands[bi].shape = tuple(grid)
552                            #end for
553                            band[grid_identifier] = obj(
554                                grid   = grid,
555                                corner = corner,
556                                cell   = cell,
557                                bands  = bands
558                                )
559                        #end if
560                        i+=1
561                    #end while
562                else:
563                    self.error('invalid keyword encountered: {0}'.format(keyword))
564                #end if
565            #end if
566            i+=1
567        #end while
568    #end def read_text
569
570
571    # test needed for axsf and bxsf
572    def write_text(self):
573        c=''
574        if self.filetype=='xsf':    # only write structure/datagrid if present
575            if self.periodicity=='molecule' and 'elem' in self:
576                c += self.write_coord()
577            elif 'primvec' in self:
578                c += ' {0}\n'.format(self.periodicity.upper())
579                c += self.write_vec('primvec',self.primvec)
580                if 'convvec' in self:
581                    c += self.write_vec('convvec',self.convvec)
582                #end if
583                if 'elem' in self:
584                    c+= self.write_coord()
585                #end if
586            #end if
587            if 'data' in self:
588                c += self.write_data()
589            #end if
590        elif self.filetype=='axsf': # only write image structures
591            c += ' ANIMSTEPS {0}\n'.format(self.animsteps)
592            if self.periodicity!='molecule':
593                c += ' {0}\n'.format(self.periodicity.upper())
594            #end if
595            if 'primvec' in self:
596                c += self.write_vec('primvec',self.primvec)
597            #end if
598            if 'convvec' in self:
599                c += self.write_vec('convvec',self.convvec)
600            #end if
601            for i in range(1,len(self.images)+1):
602                image = self.images[i]
603                if 'primvec' in image:
604                    c += self.write_vec('primvec',image.primvec,i)
605                #end if
606                if 'convvec' in image:
607                    c += self.write_vec('convvec',image.convvec,i)
608                #end if
609                c += self.write_coord(image,i)
610            #end for
611        elif self.filetype=='bxsf': # only write bandgrid
612            c += self.write_band()
613        #end if
614        return c
615    #end def write_text
616
617
618    def write_coord(self,image=None,index=''):
619        if image is None:
620            s = self
621        else:
622            s = image
623        #end if
624        c = ''
625        if self.periodicity=='molecule':
626            c += ' ATOMS {0}\n'.format(index)
627        else:
628            c += ' PRIMCOORD {0}\n'.format(index)
629            c += '   {0} 1\n'.format(len(s.elem))
630        if not 'force' in s:
631            for i in range(len(s.elem)):
632                r = s.pos[i]
633                c += '   {0:>3} {1:12.8f} {2:12.8f} {3:12.8f}\n'.format(s.elem[i],r[0],r[1],r[2])
634            #end for
635        else:
636            for i in range(len(s.elem)):
637                r = s.pos[i]
638                f = s.force[i]
639                c += '   {0:>3} {1:12.8f} {2:12.8f} {3:12.8f}  {4:12.8f} {5:12.8f} {6:12.8f}\n'.format(s.elem[i],r[0],r[1],r[2],f[0],f[1],f[2])
640            #end for
641        #end if
642        return c
643    #end def write_coord
644
645
646    def write_vec(self,name,vec,index=''):
647        c = ' {0} {1}\n'.format(name.upper(),index)
648        for v in vec:
649            c += '   {0:12.8f} {1:12.8f} {2:12.8f}\n'.format(v[0],v[1],v[2])
650        #end for
651        return c
652    #end def write_vec
653
654
655    def write_data(self):
656        c = ''
657        ncols = 4
658        data = self.data
659        for d in sorted(data.keys()):
660            bdg_xd = data[d]       # all block datagrids 2 or 3 D
661            for bdgk in sorted(bdg_xd.keys()):
662                c += ' BEGIN_BLOCK_DATAGRID_{0}D\n'.format(d)
663                c += '   {0}\n'.format(bdgk)
664                bdg = bdg_xd[bdgk] # single named block data grid
665                for dgk in sorted(bdg.keys()):
666                    c += '   BEGIN_DATAGRID_{0}D_{1}\n'.format(d,dgk)
667                    dg = bdg[dgk]  # single named data grid
668                    if d==2:
669                        c += '     {0} {1}\n'.format(*dg.grid)
670                    elif d==3:
671                        c += '     {0} {1} {2}\n'.format(*dg.grid)
672                    #end if
673                    c += '   {0:12.8f} {1:12.8f} {2:12.8f}\n'.format(*dg.corner)
674                    for v in dg.cell:
675                        c += '   {0:12.8f} {1:12.8f} {2:12.8f}\n'.format(*v)
676                    #end for
677                    c = c[:-1]
678                    n=0
679                    for v in dg.values.ravel(order='F'):
680                        if n%ncols==0:
681                            c += '\n    '
682                        #end if
683                        c += ' {0:12.8f}'.format(v)
684                        n+=1
685                    #end for
686                    c += '\n   END_DATAGRID_{0}D_{1}\n'.format(d,dgk)
687                #end for
688                c += ' END_BLOCK_DATAGRID_{0}D\n'.format(d)
689            #end for
690        #end for
691        return c
692    #end def write_data
693
694
695    def write_band(self):
696        c = ''
697        ncols = 4
698        band = self.band
699        for d in sorted(band.keys()):
700            bdg_xd = band[d]       # all block bandgrids 2 or 3 D
701            for bdgk in sorted(bdg_xd.keys()):
702                c += ' BEGIN_BLOCK_BANDGRID_{0}D\n'.format(d)
703                c += '   {0}\n'.format(bdgk)
704                bdg = bdg_xd[bdgk] # single named block band grid
705                for dgk in sorted(bdg.keys()):
706                    c += '   BEGIN_BANDGRID_{0}D_{1}\n'.format(d,dgk)
707                    dg = bdg[dgk]  # single named band grid
708                    if d==2:
709                        c += '     {0} {1}\n'.format(*dg.grid)
710                    elif d==3:
711                        c += '     {0} {1} {2}\n'.format(*dg.grid)
712                    #end if
713                    c += '   {0:12.8f} {1:12.8f} {2:12.8f}\n'.format(*dg.corner)
714                    for v in dg.cell:
715                        c += '   {0:12.8f} {1:12.8f} {2:12.8f}\n'.format(*v)
716                    #end for
717                    for bi in sorted(dg.bands.keys()):
718                        c += '   BAND:  {0}'.format(bi)
719                        n=0
720                        for v in dg.bands[bi].ravel():
721                            if n%ncols==0:
722                                c += '\n    '
723                            #end if
724                            c += ' {0:12.8f}'.format(v)
725                            n+=1
726                        #end for
727                        c += '\n'
728                    #end for
729                    c += '   END_BANDGRID_{0}D_{1}\n'.format(d,dgk)
730                #end for
731                c += ' END_BLOCK_BANDGRID_{0}D\n'.format(d)
732            #end for
733        #end for
734        return c
735    #end def write_band
736
737
738    def dimension(self):
739        if self.periodicity in self.dimensions:
740            return self.dimensions[self.periodicity]
741        else:
742            return None
743        #end if
744    #end def dimension
745
746
747    def initialized(self):
748        return self.filetype!=None
749    #end def initialized
750
751
752    def has_animation(self):
753        return self.filetype=='axsf' and 'animsteps' in self
754    #end def has_animation
755
756
757    def has_bands(self):
758        return self.filetype=='bxsf' and 'band' in self and 'info' in self
759    #end def has_bands
760
761
762    def has_structure(self):
763        hs = self.filetype=='xsf'
764        hs &= 'elem' in self and 'pos' in self
765        d = self.dimension()
766        if d!=0:
767            hs &= 'primvec' in self
768        #end if
769        return hs
770    #end def has_structure
771
772
773    def has_data(self):
774        return self.filetype=='xsf' and 'data' in self
775    #end def has_data
776
777
778    def validity_checks(self):
779        ha = self.has_animation()
780        hb = self.has_bands()
781        hs = self.has_structure()
782        hd = self.has_data()
783        v = ha or hb or hs or hd
784        if v:
785            return []
786        else:
787            return ['xsf file must have animation, bands, structure, or data\nthe current file is missing all of these']
788        #end if
789    #end def validity_checks
790
791
792    # test needed
793    def incorporate_structure(self,structure):
794        s = structure.copy()
795        s.change_units('A')
796        s.recenter()
797        elem = []
798        for e in s.elem:
799            ne = len(e)
800            if ne>1:
801                if ne==2 and not e[1].isalpha():
802                    e = e[0]
803                elif ne>2:
804                    e = e[0:2]
805                #end if
806            #end if
807            if is_element(e):
808                elem.append(ptable.elements[e].atomic_number)
809            else:
810                elem.append(0)
811            #end if
812        #end for
813        self.filetype    = 'xsf'
814        self.periodicity = 'crystal' # assumed
815        self.primvec     = s.axes
816        self.elem        = array(elem,dtype=int)
817        self.pos         = s.pos
818    #end def incorporate_structure
819
820
821    def add_density(self,cell,density,name='density',corner=None,grid=None,centered=False,add_ghost=False):
822        if corner is None:
823            corner = zeros((3,),dtype=float)
824        #end if
825        if grid is None:
826            grid = density.shape
827        #end if
828        grid    = array(grid,dtype=int)
829        corner  = array(corner,dtype=float)
830        cell    = array(cell  ,dtype=float)
831        density = array(density,dtype=float)
832        density.shape = tuple(grid)
833
834        if centered: # shift corner by half a grid cell to center it
835            dc = 0.5/grid
836            dc = dot(dc,cell)
837            corner += dc
838        #end if
839
840        if add_ghost: # add ghost points to make a 'general' xsf grid
841            g = grid  # this is an extra shell of points in PBC
842            d = density
843            grid = g+1
844            density = zeros(tuple(grid),dtype=float)
845            density[:g[0],:g[1],:g[2]] = d[:,:,:] # volume copy
846            density[   -1,:g[1],:g[2]] = d[0,:,:] # face copies
847            density[:g[0],   -1,:g[2]] = d[:,0,:]
848            density[:g[0],:g[1],   -1] = d[:,:,0]
849            density[   -1,   -1,:g[2]] = d[0,0,:] # edge copies
850            density[   -1,:g[1],   -1] = d[0,:,0]
851            density[:g[0],   -1,   -1] = d[:,0,0]
852            density[   -1,   -1,   -1] = d[0,0,0] # corner copy
853            density.shape = tuple(grid)
854        #end if
855
856        self.data = obj()
857        self.data[3] = obj()
858        self.data[3][name] = obj()
859        self.data[3][name][name] = obj(
860            grid   = grid,
861            corner = corner,
862            cell   = cell,
863            values = density
864            )
865    #end def add_density
866
867
868    def get_density(self):
869        return self.data.first().first().first()
870    #end def get_density
871
872
873    # test needed
874    def change_units(self,in_unit,out_unit):
875        fac = 1.0/convert(1.0,in_unit,out_unit)**3
876        density = self.get_density()
877        density.values *= fac
878        if 'values_noghost' in density:
879            density.values_noghost *= fac
880        #end if
881    #end def change_units
882
883
884    # test needed
885    def remove_ghost(self,density=None):
886        if density is None:
887            density = self.get_density()
888        #end if
889        if 'values_noghost' in density:
890            return density.values_noghost
891        #end if
892        data = density.values
893
894        # remove the ghost cells
895        d = data
896        g = array(d.shape,dtype=int)-1
897        data = zeros(tuple(g),dtype=float)
898        data[:,:,:] = d[:g[0],:g[1],:g[2]]
899        density.values_noghost = data
900        return data
901    #end def remove_ghost
902
903
904    # test needed
905    def norm(self,density=None,vnorm=True):
906        if density is None:
907            density = self.get_density()
908        #end if
909        if 'values_noghost' not in density:
910            self.remove_ghost(density)
911        #end if
912        data = density.values_noghost
913        if vnorm:
914            dV = det(density.cell)/data.size
915        else:
916            dV = 1.0
917        #end if
918        return data.ravel().sum()*dV
919    #end def norm
920
921
922    # test needed
923    def line_data(self,dim,density=None):
924        if density is None:
925            density = self.get_density()
926        #end if
927        if 'values_noghost' not in density:
928            self.remove_ghost(density)
929        #end if
930        data = density.values_noghost
931        dV = det(density.cell)/data.size
932        dr = norm(density.cell[dim])/data.shape[dim]
933        ndim = 3
934        permute = dim!=0
935        if permute:
936            r = list(range(0,ndim))
937            r.pop(dim)
938            permutation = tuple([dim]+r)
939            data = data.transpose(permutation)
940        #end if
941        s = data.shape
942        data.shape = s[0],s[1]*s[2]
943        line_data = data.sum(1)*dV/dr
944        r_data = density.corner[dim] + dr*arange(len(line_data),dtype=float)
945        return r_data,line_data
946    #end def line_data
947
948
949    def line_plot(self,dim,filepath):
950        r,d = self.line_data(dim)
951        savetxt(filepath,array(list(zip(r,d))))
952    #end def line_plot
953
954    # test needed
955    def interpolate_plane(self,r1,r2,r3,density=None,meshsize=50,fill_value=0):
956        if density is None:
957            density = self.get_density()
958        #end if
959
960        dens_values = np.array(density.values)
961
962        # Construct crystal meshgrid for dens
963        da = 1./(density.grid[0]-1)
964        db = 1./(density.grid[1]-1)
965        dc = 1./(density.grid[2]-1)
966
967        cry_corner = np.matmul(density.corner,np.linalg.inv(density.cell))
968        a0  = cry_corner[0]
969        b0  = cry_corner[1]
970        c0  = cry_corner[2]
971
972        ra = np.arange(a0, density.grid[0]*da, da)
973        rb = np.arange(b0, density.grid[1]*db, db)
974        rc = np.arange(c0, density.grid[2]*dc, dc)
975
976        [mra, mrb, mrc] = np.meshgrid(ra, rb, rc)
977
978        # 3d Interpolation on crystal coordinates
979        from scipy.interpolate import RegularGridInterpolator
980        g = RegularGridInterpolator((ra,rb,rc), dens_values, bounds_error=False,fill_value=fill_value)
981
982        # Construct cartesian meshgrid for dens
983        mrx,mry,mrz = np.array([mra,mrb,mrc]).T.dot(density.cell).T
984
985        # First construct a basis (x'^,y'^,z'^) where z'^ is normal to the plane formed from ra, rb, and rc
986        zph = np.cross((r2-r3),(r1-r3))
987        zph = zph/np.linalg.norm(zph)
988        yph = r2-r3
989        yph = yph/np.linalg.norm(yph)
990        xph = np.cross(yph,zph)
991
992        # Positions in (x'^,y'^,z'^) basis
993        rp1 = np.dot(r1,np.linalg.inv((xph,yph,zph)))
994        rp2 = np.dot(r2,np.linalg.inv((xph,yph,zph)))
995        rp3 = np.dot(r3,np.linalg.inv((xph,yph,zph)))
996
997        # Meshgrid in (x'^,y'^,z'^) basis
998        mrxp,mryp,mrzp = np.array([mrx,mry,mrz]).T.dot(np.linalg.inv([xph,yph,zph])).T
999
1000        # Generate mesh in (x'^,y'^,z'^) basis. Ensure all points are in cell.
1001        xp_min = np.amin(mrxp)
1002        xp_max = np.amax(mrxp)
1003        yp_min = np.amin(mryp)
1004        yp_max = np.amax(mryp)
1005
1006
1007        rpx = np.arange(xp_min,xp_max,(xp_max-xp_min)/meshsize)
1008        rpy = np.arange(yp_min,yp_max,(yp_max-yp_min)/meshsize)
1009        mrpx, mrpy = np.meshgrid(rpx,rpy)
1010
1011        slice_dens = []
1012        for xpi in np.arange(xp_min,xp_max,(xp_max-xp_min)/meshsize):
1013            yline = []
1014            for ypi in np.arange(yp_min,yp_max,(yp_max-yp_min)/meshsize):
1015                # xpi,ypi,rp1[2] to crystal coords
1016                rcry = np.matmul( np.dot((xpi,ypi,rp1[2]),(xph,yph,zph)) , np.linalg.inv(density.cell))
1017                yline.extend(g(rcry))
1018                #end if
1019            #end for
1020            slice_dens.append(yline)
1021        #end for
1022        slice_dens = np.array(slice_dens).T
1023
1024        # return the following...
1025        # slice_dens: density on slice
1026        # mrpx, mrpy: meshgrid for x',y' coordinates parallel to slice, i.e., (x'^,y'^) basis
1027        # rp1, rp2, rp3: Input positions in (x'^,y'^,z'^) basis
1028        return slice_dens, mrpx, mrpy, rp1, rp2, rp3
1029    #end def coordinatesToSlice
1030#end class XsfFile
1031
1032
1033
class PoscarFile(StandardFile):
    """
    In-memory representation of a VASP POSCAR file.

    Attributes (all None until read or assigned):
      description : text of the first line.
      scale       : real scale factor.
      axes        : 3x3 float array of cell axes.
      elem        : array of element labels (optional; presence depends
                    on vasp version).
      elem_count  : integer array with the number of atoms per species.
      coord       : text naming the position coordinate type.
      pos         : natoms x 3 float array of positions.
      dynamic     : optional natoms x 3 boolean array (selective dynamics).
      vel_coord   : optional text naming the velocity coordinate type.
      vel         : optional natoms x 3 float array of velocities.
    """

    sftype = 'POSCAR'

    def __init__(self,filepath=None):
        self.description = None
        self.scale       = None
        self.axes        = None
        self.elem        = None
        self.elem_count  = None
        self.coord       = None
        self.pos         = None
        self.dynamic     = None
        self.vel_coord   = None
        self.vel         = None
        StandardFile.__init__(self,filepath)
    #end def __init__


    def assign_defaults(self):
        """Fill in a default description when none has been set."""
        if self.description is None:
            self.description = 'System cell and coordinates'
        #end if
    #end def assign_defaults


    def validity_checks(self):
        """
        Check the stored data for completeness and consistency.

        Returns a list of message strings, one per problem found.  The
        list is empty when no problem is detected.
        """
        msgs = []
        if self.description is None:
            msgs.append('description is missing')
        elif not isinstance(self.description,str):
            msgs.append('description must be text')
        #end if
        if self.scale is None:
            msgs.append('scale is missing')
        elif not isinstance(self.scale,(float,int)):
            msgs.append('scale must be a real number')
        elif self.scale<0:
            msgs.append('scale must be greater than zero')
        #end if
        if self.axes is None:
            msgs.append('axes is missing')
        elif not isinstance(self.axes,ndarray):
            msgs.append('axes must be an array')
        elif self.axes.shape!=(3,3):
            msgs.append('axes must be a 3x3 array, shape provided is {0}'.format(self.axes.shape))
        elif not isinstance(self.axes[0,0],float):
            msgs.append('axes must be an array of real numbers')
        #end if
        # natoms stays negative when elem_count is unusable, which
        # disables the per-atom shape checks below
        natoms = -1
        if self.elem_count is None:
            msgs.append('elem_count is missing')
        elif not isinstance(self.elem_count,ndarray):
            msgs.append('elem_count must be an array')
        elif len(self.elem_count)==0:
            msgs.append('elem_count array must contain at least one entry')
        elif not isinstance(self.elem_count[0],(int,np.integer)):
            msgs.append('elem_count must be an array of integers')
        else:
            if (self.elem_count<1).sum()>0:
                msgs.append('all elem_count entries must be greater than zero')
            #end if
            natoms = self.elem_count.sum()
        #end if
        if self.elem is not None: # presence depends on vasp version
            if not isinstance(self.elem,ndarray):
                msgs.append('elem must be an array')
            elif isinstance(self.elem_count,ndarray) and len(self.elem)!=len(self.elem_count):
                msgs.append('elem and elem_count arrays must be the same length')
            elif len(self.elem)>0 and not isinstance(self.elem[0],str):
                msgs.append('elem must be an array of text')
            else:
                for e in self.elem:
                    iselem,symbol = is_element(e,symbol=True)
                    if not iselem:
                        msgs.append('elem entry "{0}" is not an element'.format(e))
                    #end if
                #end for
            #end if
        #end if
        if self.coord is None:
            msgs.append('coord is missing')
        elif not isinstance(self.coord,str):
            msgs.append('coord must be text')
        #end if
        if self.pos is None:
            msgs.append('pos is missing')
        elif not isinstance(self.pos,ndarray):
            msgs.append('pos must be an array')
        elif natoms>0 and self.pos.shape!=(natoms,3):
            msgs.append('pos must be a {0}x3 array, shape provided is {1}'.format(natoms,self.pos.shape))
        elif natoms>0 and not isinstance(self.pos[0,0],float):
            msgs.append('pos must be an array of real numbers')
        #end if
        if self.dynamic is not None: # dynamic is optional
            if not isinstance(self.dynamic,ndarray):
                msgs.append('dynamic must be an array')
            elif natoms>0 and self.dynamic.shape!=(natoms,3):
                msgs.append('dynamic must be a {0}x3 array, shape provided is {1}'.format(natoms,self.dynamic.shape))
            elif natoms>0 and not isinstance(self.dynamic[0,0],(bool,np.bool_)):
                # numpy booleans are not instances of the builtin bool
                msgs.append('dynamic must be an array of booleans (true/false)')
            #end if
        #end if
        if self.vel_coord is not None: # velocities are optional
            if not isinstance(self.vel_coord,str):
                msgs.append('vel_coord must be text')
            #end if
        #end if
        if self.vel is not None: # velocities are optional
            if not isinstance(self.vel,ndarray):
                msgs.append('vel must be an array')
            elif natoms>0 and self.vel.shape!=(natoms,3):
                msgs.append('vel must be a {0}x3 array, shape provided is {1}'.format(natoms,self.vel.shape))
            elif natoms>0 and not isinstance(self.vel[0,0],float):
                msgs.append('vel must be an array of real numbers')
            #end if
        #end if
        return msgs
    #end def validity_checks


    def read_text(self,text):
        """Populate this object from POSCAR formatted text."""
        read_poscar_chgcar(self,text)
    #end def read_text


    def write_text(self):
        """
        Return the stored data as POSCAR formatted text.

        Element labels are written only when elem is set, the
        'selective dynamics' line and per-atom flags only when dynamic is
        set, and the velocity block only when vel is set.
        """
        vec_fmt = ' {0:20.14f} {1:20.14f} {2:20.14f}\n'
        parts = []
        if self.description is None:
            parts.append('System cell and coordinates\n')
        else:
            parts.append(self.description+'\n')
        #end if
        parts.append(' {0}\n'.format(self.scale))
        for a in self.axes:
            parts.append(vec_fmt.format(*a))
        #end for
        if self.elem is not None:
            for e in self.elem:
                iselem,symbol = is_element(e,symbol=True)
                if not iselem:
                    self.error('{0} is not an element'.format(e))
                #end if
                parts.append(symbol+' ')
            #end for
            parts.append('\n')
        #end if
        for ec in self.elem_count:
            parts.append(' {0}'.format(ec))
        #end for
        parts.append('\n')
        # identity tests: comparing an array to None with != is elementwise
        if self.dynamic is not None:
            parts.append('selective dynamics\n')
        #end if
        parts.append(self.coord+'\n')
        if self.dynamic is None:
            for p in self.pos:
                parts.append(vec_fmt.format(*p))
            #end for
        else:
            # use a bool_map supplied elsewhere on the object if there is
            # one, otherwise fall back to plain T/F flags
            bm = getattr(self,'bool_map',None)
            if bm is None:
                bm = {True:'T',False:'F'}
            #end if
            dyn_fmt = ' {0:20.14f} {1:20.14f} {2:20.14f}  {3}  {4}  {5}\n'
            for p,d in zip(self.pos,self.dynamic):
                parts.append(dyn_fmt.format(p[0],p[1],p[2],bm[bool(d[0])],bm[bool(d[1])],bm[bool(d[2])]))
            #end for
        #end if
        if self.vel is not None:
            parts.append(self.vel_coord+'\n')
            for v in self.vel:
                parts.append(vec_fmt.format(*v))
            #end for
        #end if
        return ''.join(parts)
    #end def write_text


    def incorporate_xsf(self,xsf):
        """
        Populate this object from an XSF object.

        The cell is taken from xsf.convvec when present, otherwise from
        xsf.primvec.  Atoms are regrouped so that all atoms of a species
        are contiguous, with species ordered by first appearance in
        xsf.elem.  Positions are stored as 'cartesian' with scale 1.0.
        """
        axes = None
        if 'primvec' in xsf:
            axes = xsf.primvec.copy()
        #end if
        if 'convvec' in xsf:
            axes = xsf.convvec.copy()
        #end if
        if axes is None:
            self.error('cannot incorporate xsf data: neither primvec nor convvec is present')
        #end if
        elem = xsf.elem.copy()
        pos  = xsf.pos.copy()

        # atom indices grouped by species, in order of first appearance
        groups = {}
        for i,e in enumerate(elem):
            groups.setdefault(e,[]).append(i)
        #end for

        elem_order     = []
        species        = []
        species_counts = []
        for e,inds in groups.items():
            elem_order.extend(inds)
            species.append(ptable.simple_elements[e].symbol)
            species_counts.append(len(inds))
        #end for

        self.scale      = 1.0
        self.axes       = axes
        self.elem       = array(species,dtype=str)
        self.elem_count = array(species_counts,dtype=int)
        self.coord      = 'cartesian'
        self.pos        = pos[elem_order]

        self.assign_defaults()
    #end def incorporate_xsf
#end class PoscarFile
1262
1263
1264
class ChgcarFile(StandardFile):
    """
    In-memory representation of a VASP CHGCAR file.

    Attributes (all None until read or assigned):
      poscar         : PoscarFile holding the structure header.
      grid           : integer array with the 3 grid dimensions.
      charge_density : flat float array with one value per grid point.
      spin_density   : optional float array, either flat (one value per
                       grid point) or of shape (ngrid,3).
    """

    sftype = 'CHGCAR'

    def __init__(self,filepath=None):
        self.poscar         = None
        self.grid           = None
        self.charge_density = None
        self.spin_density   = None
        StandardFile.__init__(self,filepath)
    #end def __init__


    def validity_checks(self):
        """
        Check the stored data for completeness and consistency.

        Returns a list of message strings, one per problem found,
        including any reported by the contained PoscarFile.  The list is
        empty when no problem is detected.
        """
        msgs = []
        if self.poscar is None:
            msgs.append('poscar elements are missing')
        elif not isinstance(self.poscar,PoscarFile):
            msgs.append('poscar is not an instance of PoscarFile')
        else:
            msgs.extend(self.poscar.validity_checks())
        #end if
        if self.grid is None:
            msgs.append('grid is missing')
        elif not isinstance(self.grid,ndarray):
            msgs.append('grid must be an array')
        elif len(self.grid)!=3 or self.grid.size!=3:
            msgs.append('grid must have 3 entries')
        elif not isinstance(self.grid[0],(int,np.integer)):
            msgs.append('grid must be an array of integers')
        elif (self.grid<1).sum()>0:
            msgs.append('all grid entries must be greater than zero')
        #end if
        # number of grid points; None when the grid is missing or not an
        # array, in which case the size checks below are skipped
        ng = None
        if isinstance(self.grid,ndarray):
            ng = self.grid.prod()
        #end if
        if self.charge_density is None:
            msgs.append('charge_density is missing')
        elif not isinstance(self.charge_density,ndarray):
            msgs.append('charge_density must be an array')
        elif ng is not None and len(self.charge_density)!=ng:
            msgs.append('charge_density must have {0} entries ({1} present by length)'.format(ng,len(self.charge_density)))
        elif ng is not None and self.charge_density.size!=ng:
            msgs.append('charge_density must have {0} entries ({1} present by size)'.format(ng,self.charge_density.size))
        elif self.charge_density.size>0 and not isinstance(self.charge_density[0],float):
            msgs.append('charge_density must be an array of real numbers')
        #end if
        if self.spin_density is not None: # spin density is optional
            if not isinstance(self.spin_density,ndarray):
                msgs.append('spin_density must be an array')
            elif ng is not None and len(self.spin_density)!=ng:
                msgs.append('spin_density must have {0} entries ({1} present)'.format(ng,len(self.spin_density)))
            elif ng is not None and self.spin_density.size!=ng and self.spin_density.shape!=(ng,3):
                msgs.append('non-collinear spin_density must be a {0}x3 array, shape provided: {1}'.format(ng,self.spin_density.shape))
            elif self.spin_density.size>0 and not isinstance(self.spin_density.ravel()[0],float):
                msgs.append('spin_density must be an array of real numbers')
            #end if
        #end if
        return msgs
    #end def validity_checks


    def read_text(self,text):
        """Populate this object from CHGCAR formatted text."""
        read_poscar_chgcar(self,text)
    #end def read_text


    def write_text(self):
        """
        Return the stored data as CHGCAR formatted text.

        The POSCAR header is followed by a blank line, the grid
        dimensions, and the density values written five per line.  The
        charge density comes first, then the spin density if present (a
        single block, or one block per column for an (ngrid,3) array).
        The five-per-line count runs continuously across all blocks.
        """
        densities = [self.charge_density]
        if self.spin_density is not None:
            if self.spin_density.size==self.charge_density.size:
                densities.append(self.spin_density)
            else:
                for i in range(3):
                    densities.append(self.spin_density[:,i])
                #end for
            #end if
        #end if
        # collect pieces and join once; repeated string += is slow for
        # large densities
        parts = [self.poscar.write_text()]
        parts.append('\n {0} {1} {2}\n'.format(*self.grid))
        n=0
        for dens in densities:
            for d in dens:
                parts.append('{0:20.12E}'.format(d))
                n+=1
                if n%5==0:
                    parts.append('\n')
                #end if
            #end for
        #end for
        return ''.join(parts)
    #end def write_text


    def incorporate_xsf(self,xsf):
        """
        Populate this object from an XSF object.

        The structure is converted via PoscarFile.incorporate_xsf and the
        ghost-free density from xsf.remove_ghost() is stored flattened in
        Fortran (first index fastest) order.  check_valid() is called at
        the end.
        """
        poscar = PoscarFile()
        poscar.incorporate_xsf(xsf)
        density = xsf.remove_ghost().copy()
        self.poscar         = poscar
        self.grid           = array(density.shape,dtype=int)
        self.charge_density = density.ravel(order='F')
        self.check_valid()
    #end def incorporate_xsf
#end class ChgcarFile
1369
1370
1371
def read_poscar_chgcar(host,text):
    """
    Parse POSCAR or CHGCAR formatted text and store the result in `host`.

    Parameters
      host : a PoscarFile or ChgcarFile object; any other type is
             reported through error().
      text : full text of the file.

    For a PoscarFile host the parsed fields are set directly on it.  For
    a ChgcarFile host the structure header is stored in a new PoscarFile
    (host.poscar) alongside grid, charge_density, and spin_density.

    Problems with the file contents are reported through host.error().
    NOTE(review): the code after each host.error() call assumes that call
    does not return -- confirm against its definition.
    """
    is_poscar = isinstance(host,PoscarFile)
    is_chgcar = isinstance(host,ChgcarFile)
    if not is_poscar and not is_chgcar:
        error('read_poscar_chgcar must be used in conjunction with PoscarFile or ChgcarFile objects only\nencountered object of type: {0}'.format(host.__class__.__name__))
    #end if

    # read lines and remove fortran comments
    lines = []
    for line in text.splitlines():
        # a comment starts at the first '!' or '#', whichever comes first
        clocs = [cloc for cloc in (line.find('!'),line.find('#')) if cloc!=-1]
        if len(clocs)>0:
            line = line[:min(clocs)]
        #end if
        lines.append(line.strip())
    #end for

    # extract file information
    nlines = len(lines)
    min_lines = 8
    if nlines<min_lines:
        host.error('file {0} must have at least {1} lines\nonly {2} lines found'.format(host.filepath,min_lines,nlines))
    #end if
    description = lines[0]
    dim = 3
    scale = float(lines[1].strip())
    axes = empty((dim,dim))
    axes[0] = array(lines[2].split(),dtype=float)
    axes[1] = array(lines[3].split(),dtype=float)
    axes[2] = array(lines[4].split(),dtype=float)
    # line 6 holds either the species counts or, if its first token is
    # not a number, the element labels followed by the counts on line 7
    tokens = lines[5].split()
    if tokens[0].isdigit():
        counts = array(tokens,dtype=int)
        elem   = None
        lcur   = 6
    else:
        elem   = array(tokens,dtype=str)
        counts = array(lines[6].split(),dtype=int)
        lcur   = 7
    #end if

    # c is the lowercased first character of the current keyword line
    if lcur<len(lines) and len(lines[lcur])>0:
        c = lines[lcur].lower()[0]
        lcur+=1
    else:
        host.error('file {0} is incomplete (missing positions)'.format(host.filepath))
    #end if
    selective_dynamics = c=='s'
    if selective_dynamics: # Selective dynamics
        if lcur<len(lines) and len(lines[lcur])>0:
            c = lines[lcur].lower()[0]
            lcur+=1
        else:
            host.error('file {0} is incomplete (missing positions)'.format(host.filepath))
        #end if
    #end if
    cartesian = c=='c' or c=='k'
    if cartesian:
        coord = 'cartesian'
    else:
        coord = 'direct'
    #end if
    npos = counts.sum()
    if lcur+npos>len(lines):
        host.error('file {0} is incomplete (missing positions)'.format(host.filepath))
    #end if
    spos = [lines[lcur+i].split() for i in range(npos)]
    lcur += npos
    spos = array(spos)
    pos  = array(spos[:,0:3],dtype=float)
    if selective_dynamics:
        # columns 4-6 hold the per-coordinate flags; only 'T' maps to True
        dynamic = array(spos[:,3:6],dtype=str)
        dynamic = dynamic=='T'
    else:
        dynamic = None
    #end if

    def is_empty(lines,start=None,end=None):
        # True when every line in lines[start:end] is an empty string
        if start is None:
            start = 0
        #end if
        if end is None:
            end = len(lines)
        #end if
        return all(len(line)==0 for line in lines[start:end])
    #end def is_empty

    # velocities may be present for poscar
    #   assume they are not for chgcar
    if is_poscar and lcur<len(lines) and not is_empty(lines,lcur):
        cline = lines[lcur].lower()
        lcur+=1
        if lcur+npos>len(lines):
            host.error('file {0} is incomplete (missing velocities)'.format(host.filepath))
        #end if
        cartesian = len(cline)>0 and (cline[0]=='c' or cline[0]=='k')
        if cartesian:
            vel_coord = 'cartesian'
        else:
            vel_coord = 'direct'
        #end if
        svel = [lines[lcur+i].split() for i in range(npos)]
        lcur += npos
        vel = array(svel,dtype=float)
    else:
        vel_coord = None
        vel = None
    #end if

    # grid data is present for chgcar
    if is_chgcar:
        lcur+=1 # skip the line separating positions from the grid
        if lcur<len(lines) and len(lines[lcur])>0:
            grid = array(lines[lcur].split(),dtype=int)
            lcur+=1
        else:
            host.error('file {0} is incomplete (missing grid)'.format(host.filepath))
        #end if
        if lcur<len(lines):
            ng = grid.prod()
            density = []
            for line in lines[lcur:]:
                density.extend(line.split())
            #end for
            if len(density)>0:
                def is_float(val):
                    try:
                        float(val)
                        return True
                    except ValueError:
                        return False
                    #end try
                #end def is_float
                # remove anything but the densities (e.g. augmentation charges)
                # advance in whole-grid blocks while a block starts with a
                # number; stop once a further full block no longer fits
                # (strict '>' so a final block that exactly fills the
                # data is kept)
                n=0
                while is_float(density[n]):
                    n+=ng
                    if n+ng>len(density):
                        break
                    #end if
                #end while
                density = array(density[:n],dtype=float)
            else:
                host.error('file {0} is incomplete (missing density)'.format(host.filepath))
            #end if
            if density.size%ng!=0:
                host.error('number of density data entries is not a multiple of the grid\ngrid shape: {0}\ngrid size: {1}\ndensity size: {2}'.format(grid,ng,density.size))
            #end if
            ndens = density.size//ng
            if ndens==1:
                charge_density = density
                spin_density   = None
            elif ndens==2:
                charge_density = density[:ng]
                spin_density   = density[ng:]
            elif ndens==4:
                charge_density = density[:ng]
                spin_density   = empty((ng,3),dtype=float)
                for i in range(3):
                    spin_density[:,i] = density[(i+1)*ng:(i+2)*ng]
                #end for
            else:
                host.error('density data must be present for one of the following situations\n  1) charge density only (1 density)\n  2) charge and collinear spin densities (2 densities)\n  3) charge and non-collinear spin densities (4 densities)\nnumber of densities found: {0}'.format(ndens))
            #end if
        else:
            host.error('file {0} is incomplete (missing density)'.format(host.filepath))
        #end if
    #end if

    # the structure header lives on the host itself for a POSCAR and on
    # a separate PoscarFile for a CHGCAR
    if is_poscar:
        poscar = host
    elif is_chgcar:
        poscar = PoscarFile()
    #end if

    poscar.set(
        description = description,
        scale       = scale,
        axes        = axes,
        elem        = elem,
        elem_count  = counts,
        coord       = coord,
        pos         = pos,
        dynamic     = dynamic,
        vel_coord   = vel_coord,
        vel         = vel
        )

    if is_chgcar:
        host.set(
            poscar         = poscar,
            grid           = grid,
            charge_density = charge_density,
            spin_density   = spin_density,
            )
    #end if
#end def read_poscar_chgcar
1591