1##################################################################
2##  (c) Copyright 2015-  by Jaron T. Krogel                     ##
3##################################################################
4
5
6#====================================================================#
7#  structure.py                                                      #
8#    Support for atomic structure I/O, generation, and manipulation. #
9#                                                                    #
10#  Content summary:                                                  #
11#    Structure                                                       #
12#      Represents a simulation cell containing a set of atoms.       #
13#      Many functions for manipulating structures or obtaining       #
14#        data regarding local atomic structure.                      #
15#                                                                    #
16#    generate_cell                                                   #
17#      User-facing function to generate an empty simulation cell.    #
18#                                                                    #
19#    generate_structure                                              #
20#      User-facing function to specify arbitrary atomic structures   #
21#      or generate structures corresponding to atoms, dimers, or     #
22#      crystals.                                                     #
23#                                                                    #
24#====================================================================#
25
26"""
27The :py:mod:`structure` module provides support for atomic structure I/O,
28generation, and manipulation.
29
30
31List of module contents
32-----------------------
33
34Read cif file functions:
35
36* :py:func:`read_cif_celldata`
37* :py:func:`read_cif_cell`
38* :py:func:`read_cif`
39
40Operations on logical conditions:
41
42* :py:func:`equate`
43* :py:func:`negate`
44
Monkhorst-Pack k-point mesh function:
46
47* :py:func:`kmesh`
48
Tile matrix manipulation functions:
50
51* :py:func:`reduce_tilematrix`
52* :py:func:`tile_magnetization`
53
54Rotate plane function
55
56* :py:func:`rotate_plane`
57
Tiling matrix filters and optimal tiling matrix search:
59
60* :py:func:`trivial_filter`
61
62* :py:class:`MaskFilter`
63
64* :py:func:`optimal_tilematrix`
65
66Base class for :py:class:`Structure` class:
67
68* :py:class:`Sobj`
69
70Base class for :py:class:`DefectStructure`, :py:class:`Crystal`, and :py:class:`Jellium` classes:
71
72* :py:class:`Structure`
73
74SeeK-path functions
75
76* :py:func:`\_getseekpath`
77* :py:func:`get_conventional_cell`
78* :py:func:`get_primitive_cell`
79* :py:func:`get_kpath`
80* :py:func:`get_symmetry`
81* :py:func:`get_structure_with_bands`
82* :py:func:`get_band_tiling`
83* :py:func:`get_seekpath_full`
84
85Interpolate structures functions
86
87* :py:func:`interpolate_structures`
88
89Animate structures functions
90
91* :py:func:`structure_animation`
92
93Concrete :py:class:`Structure` classes:
94
95* :py:class:`DefectStructure`
96* :py:class:`Crystal`
97* :py:class:`Jellium`
98
99Structure generation functions:
100
101* :py:func:`generate_cell`
102* :py:func:`generate_structure`
103* :py:func:`generate_atom_structure`
104* :py:func:`generate_dimer_structure`
105* :py:func:`generate_trimer_structure`
106* :py:func:`generate_jellium_structure`
107* :py:func:`generate_crystal_structure`
108* :py:func:`generate_defect_structure`
109
110Read structure functions
111
112* :py:func:`read_structure`
113
114
115
116Module contents
117---------------
118"""
119
120import os
121import numpy as np
122from copy import deepcopy
123from random import randint
124from numpy import abs,all,append,arange,around,array,atleast_2d,ceil,cos,cross,cross,diag,dot,empty,exp,flipud,floor,identity,isclose,logical_not,mgrid,mod,ndarray,ones,pi,round,sign,sin,sqrt,uint64,zeros
125from numpy.linalg import inv,det,norm
126from unit_converter import convert
127from numerics import nearest_neighbors,convex_hull,voronoi_neighbors
128from periodic_table import pt,is_element
129from fileio import XsfFile,PoscarFile
130from generic import obj
131from developer import DevBase,unavailable,error,warn
132from debug import ci,ls,gs
133
134try:
135    from scipy.special import erfc
136except:
137    erfc = unavailable('scipy.special','erfc')
138#end try
139try:
140    import matplotlib.pyplot as plt
141    from matplotlib.pyplot import plot,subplot,title,xlabel,ylabel
142except:
143    plot,subplot,title,xlabel,ylabel,plt = unavailable('matplotlib.pyplot','plot','subplot','title','xlabel','ylabel','plt')
144#end try
145
146
147
148
149
150# installation instructions to enable cif file read
151#
152#   cif file support in Nexus currently requires two external libraries
153#     PyCifRW  - base interface to read cif files into object format: CifFile
154#     cif2cell - translation layer from CifFile object to cell reconstruction: CellData
155#     (note: cif2cell installation includes PyCifRW)
156#
157#  installation of cif2cell
158#    go to http://sourceforge.net/projects/cif2cell/
159#    click on Download (example: cif2cell-1.2.10.tar.gz)
160#    unpack directory (tar -xzf cif2cell-1.2.10.tar.gz)
161#    enter directory (cd cif2cell-1.2.10)
162#    install cif2cell (python setup.py install)
163#    check python installation
164#      >python
165#      >>>from CifFile import CifFile
166#      >>>from uctools import CellData
167#
168#   Nexus is currently compatible with
169#     cif2cell-1.2.10 and PyCifRW-3.3
170#     cif2cell-1.2.7  and PyCifRW-4.1.1
171#     compatibility last tested: 20 Mar 2017
172#
173try:
174    from CifFile import CifFile
175except:
176    CifFile = unavailable('CifFile','CifFile')
177#end try
178try:
179    from cif2cell.uctools import CellData
180except:
181    CellData = unavailable('cif2cell.uctools','CellData')
182#end try
183
184
# map cif2cell length unit names to the unit labels passed to convert()
cif2cell_unit_dict = {'angstrom':'A','bohr':'B','nm':'nm'}
186
187
188
def read_cif_celldata(filepath,block=None,grammar='1.1'):
    '''
    Read a cif file with PyCifRW and convert one of its data blocks into a
    cif2cell ``CellData`` object.

    Parameters
    ----------
    filepath : str
        Path to the cif file.
    block : str, optional
        Name of the data block to read.  Defaults to the first key returned
        by ``CifFile.keys()``.
    grammar : str, optional
        Cif grammar passed through to ``CifFile``.

    Returns
    -------
    cd : CellData
        Cell data populated via ``CellData.getFromCIF``.
    '''
    # CifFile is handed the bare file name, so parse from inside the file's
    # directory.  try/finally guarantees the caller's working directory is
    # restored even when parsing fails.
    path,cif_file = os.path.split(filepath)
    if path!='':
        cwd = os.getcwd()
        os.chdir(path)
    #end if
    try:
        cf = CifFile(cif_file,grammar=grammar)
    finally:
        if path!='':
            os.chdir(cwd)
        #end if
    #end try
    if block is None:
        blocks = list(cf.keys())
        if len(blocks)==0:
            error('no data blocks were found in cif file {0}'.format(filepath),'read_cif_celldata')
        #end if
        block = blocks[0]
    #end if
    cb = cf.get(block)
    if cb is None:
        error('block {0} was not found in cif file {1}'.format(block,filepath),'read_cif_celldata')
    #end if

    # extract structure from CifFile with uctools CellData class
    cd = CellData()
    cd.getFromCIF(cb)

    return cd
#end def read_cif_celldata
229
230
231
def read_cif_cell(filepath,block=None,grammar='1.1',cell='prim'):
    '''
    Read a cif file and return either its primitive or conventional cell.

    ``cell`` is matched by prefix: a value starting with 'prim' returns
    ``CellData.primitive()``, one starting with 'conv' returns
    ``CellData.conventional()``.  Any other value is reported via error().
    ``filepath``, ``block`` and ``grammar`` are passed to read_cif_celldata.
    '''
    celldata = read_cif_celldata(filepath,block,grammar)
    if cell.startswith('prim'):
        return celldata.primitive()
    elif cell.startswith('conv'):
        return celldata.conventional()
    #end if
    error('cell argument must be primitive or conventional\nyou provided: {0}'.format(cell),'read_cif_cell')
    return cell
#end def read_cif_cell
245
246
247
def read_cif(filepath,block=None,grammar='1.1',cell='prim',args_only=False):
    '''
    Build a Structure (or its constructor arguments) from cif data.

    Parameters
    ----------
    filepath : str or cell object
        Path to a cif file (read via read_cif_cell), or an already-built
        cell object exposing ``alloy``, ``unit``, ``lengthscale``,
        ``latticevectors`` and ``atomdata``.
    block, grammar, cell :
        Forwarded to read_cif_cell when ``filepath`` is a string.
    args_only : bool
        If True, return ``(axes, elem, pos, units)`` instead of a Structure.

    Lengths are converted to Angstrom.  Each atom's ``position`` is treated
    as coordinates in units of the lattice vectors and mapped to Cartesian
    coordinates via the (scaled) axes.
    '''
    if not isinstance(filepath,str):
        cell = filepath
    else:
        cell = read_cif_cell(filepath,block,grammar,cell)
    #end if

    if cell.alloy:
        error('cannot handle alloys','read_cif')
    #end if
    # rescale lattice vectors from the cell's length unit into Angstrom
    source_units = cif2cell_unit_dict[cell.unit]
    length_scale = convert(float(cell.lengthscale),source_units,'A')
    axes = length_scale*array(cell.latticevectors,dtype=float)

    elem = []
    reduced = []
    for site_atoms in cell.atomdata:
        for atom in site_atoms:
            elem.append(str(list(atom.species.keys())[0]))
            reduced.append(atom.position)
        #end for
    #end for
    pos = dot(array(reduced,dtype=float),axes)
    units = 'A'

    if args_only:
        return axes,elem,pos,units
    #end if
    return Structure(axes=axes,elem=elem,pos=pos,units=units)
#end def read_cif
286
287
288
289
290
291# installation instructions for spglib interface
292#
293#  this is bootstrapped off of spglib's ASE Python interface
294#
295#  installation of spglib
296#    go to http://sourceforge.net/projects/spglib/files/
297#    click on Download spglib-1.8.2.tar.gz (952.6 kB)
298#    unpack directory (tar -xzf spglib-1.8.2.tar.gz)
299#    enter ase directory (cd spglib-1.8.2/python/ase/)
300#    build and install (sudo python setup.py install)
301from periodic_table import pt as ptable
302#try:
303#    from pyspglib import spglib
304#except:
305#    spglib = unavailable('pyspglib','spglib')
306##end try
307try:
308    import spglib
309except:
310    spglib = unavailable('spglib')
311#end try
312
313
314
315
316
def equate(expr):
    '''Identity predicate: return ``expr`` unchanged (counterpart of negate).'''
    result = expr
    return result
#end def equate
320
def negate(expr):
    '''Return the boolean negation of ``expr`` (equivalent to ``not expr``).'''
    return False if expr else True
#end def negate
324
325
326
def kmesh(kaxes,dim,shift=None):
    '''
    Create a Monkhorst-Pack k-point mesh.

    Parameters
    ----------
    kaxes : array_like, shape (ndim,ndim)
        Reciprocal lattice vectors, one per row.
    dim : sequence of int, length ndim
        Number of mesh points along each reciprocal axis.
    shift : sequence of float, length ndim, optional
        Mesh shift in units of the mesh spacing.  Defaults to no shift.

    Returns
    -------
    kgrid : ndarray, shape (prod(dim),ndim)
        Cartesian k-points ``dot((index+shift)/dim, kaxes)``, ordered with
        the first index varying fastest (for 3D: loops k, then j, then i).
    '''
    d    = array(dim,dtype=int)
    ndim = len(d)
    if shift is None:
        shift = zeros((ndim,))
    #end if
    # reshape enforces that shift has exactly one entry per dimension
    s = array(shift,dtype=float).reshape(ndim)
    npoints = int(d.prod())
    # build integer mesh indices with the LAST axis of np.indices varying
    # fastest, then reverse the columns so the first mesh index is fastest
    rev_dims = tuple(int(n) for n in d[::-1])
    index = np.indices(rev_dims).reshape(ndim,npoints).T[:,::-1]
    kgrid = dot((index+s)/d,kaxes)
    return kgrid
#end def kmesh
355
356
357
def reduce_tilematrix(tiling):
    '''
    Normalize a tiling specification into a tiling matrix and tile vector(s).

    Parameters
    ----------
    tiling : array_like
        Either a vector of per-axis integer tiling factors or a square
        integer tiling matrix.

    Returns
    -------
    tilematrix : ndarray
        The tiling matrix (``diag(tiling)`` for vector input).
    tilevector : ndarray
        For vector input, the integer tiling vector itself.  For matrix
        input, a 2D array whose rows are the distinct diagonals obtained by
        volume-preserving shears of the tiling matrix into diagonal form
        (one row per successful axis permutation).

    Non-integer or singular tilings are reported via Structure.class_error.
    '''
    tiling = array(tiling)
    t = array(tiling,dtype=int)
    if abs(tiling-t).sum()>1e-6:
        Structure.class_error('requested tiling is non-integer\n tiling requested: '+str(tiling))
    #end if

    dim = len(t)
    matrix_tiling = t.shape == (dim,dim)
    if matrix_tiling:
        # t holds integers, so det(t) is mathematically an integer.  The
        # floating point (LU based) det() can return a tiny nonzero residue
        # (~1e-16) for a singular integer matrix, so test against 1/2
        # rather than exact equality with 0.
        if abs(det(t))<0.5:
            Structure.class_error('requested tiling matrix is singular\ntiling requested: {0}'.format(t))
        #end if
        #find a tiling tuple from the tiling matrix
        # do this by shearing the tiling matrix (or equivalently the tiled cell)
        # until it is orthogonal (in the untiled cell axes)
        # this is just rearranging triangular tiles of space to reshape the cell
        # so that t1*t2*t3 = det(T) = det(A_tiled)/det(A_untiled)
        #this way the atoms in the (perhaps oddly shaped) supercell can be
        # obtained from simple translations of the untiled cell positions
        T = t  #tiling matrix
        tilematrix = T.copy()
        del t
        tbar = identity(dim) #basis for shearing
        dr = list(range(dim))
        other = dim*[0] # other[d] = dimensions other than d
        for d in dr:
            other[d] = set(dr)-set([d])
        #end for
        #move each axis to be parallel to barred directions
        # these are volume preserving shears of the supercell
        # each shear keeps two cell face planes fixed while moving the others
        tvecs = []
        for dp in [(0,1,2),(2,0,1),(1,2,0),(2,1,0),(0,2,1),(1,0,2)]:
            success = True
            Tnew = array(T,dtype=float) #sheared/orthogonal tiling matrix
            for d in dr:
                tb = tbar[dp[d]]
                t  = T[d]
                d2,d3 = other[d]
                n = cross(Tnew[d2],Tnew[d3])  #vector normal to 2 cell faces
                vol   = dot(n,t)
                bcomp = dot(n,tb)
                if abs(bcomp)<1e-6:
                    # target direction lies in the fixed face plane
                    success = False
                    break
                #end if
                tn = vol*1./bcomp*tb #new axis vector
                Tnew[d] = tn
            #end for
            if success:
                # apply inverse permutation, if needed
                Tn = Tnew.copy()
                for d in dr:
                    d2 = dp[d]
                    Tnew[d2] = Tn[d]
                #end for
                #the resulting tiling matrix should be diagonal and integer
                tr = diag(Tnew)
                nondiagonal = abs(Tnew-diag(tr)).sum()>1e-6
                if nondiagonal:
                    Structure.class_error('could not find a diagonal tiling matrix for generating tiled coordinates')
                #end if
                tvecs.append(abs(tr))
            #end if
        #end for
        # drop duplicate tile vectors (compared after rounding at 1e-7),
        # keeping first-seen order
        tvecs_old = tvecs
        tvecs = []
        tvset = set()
        for tv in tvecs_old:
            tvk = tuple(array(around(1e7*tv),dtype=uint64))
            if tvk not in tvset:
                tvset.add(tvk)
                tvecs.append(tv)
            #end if
        #end for
        tilevector = array(tvecs)
    else:
        tilevector = t
        tilematrix = diag(t)
    #end if

    return tilematrix,tilevector
#end def reduce_tilematrix
443
444
445
def rotate_plane(plane,angle,points,units='degrees'):
    '''
    Rotate 3D point(s) by ``angle`` within a coordinate plane.

    Parameters
    ----------
    plane : str
        One of 'xy','yx','yz','zy','zx','xz'.  The first axis is rotated
        toward the second, e.g. 'xy' takes +x toward +y and 'yx' takes +y
        toward +x.
    angle : float
        Rotation angle in degrees or radians, according to ``units``.  The
        caller's object is never modified.
    points : array_like, shape (3,) or (npoints,3)
        Point(s) to rotate.
    units : str
        'degrees' (default) or any string starting with 'rad'.

    Returns
    -------
    ndarray
        Rotated point(s), with the same shape as ``points``.
    '''
    if units=='degrees':
        # rebind instead of *= so an ndarray angle owned by the caller
        # is not changed in place
        angle = angle*(pi/180)
    elif not units.startswith('rad'):
        error('angular units must be degrees or radians\nyou provided: {0}'.format(units),'rotate_plane')
    #end if
    c = cos(angle)
    s = sin(angle)
    if plane=='xy':
        R = [[ c,-s, 0],
             [ s, c, 0],
             [ 0, 0, 1]]
    elif plane=='yx':
        R = [[ c, s, 0],
             [-s, c, 0],
             [ 0, 0, 1]]
    elif plane=='yz':
        R = [[ 1, 0, 0],
             [ 0, c,-s],
             [ 0, s, c]]
    elif plane=='zy':
        R = [[ 1, 0, 0],
             [ 0, c, s],
             [ 0,-s, c]]
    elif plane=='zx':
        R = [[ c, 0, s],
             [ 0, 1, 0],
             [-s, 0, c]]
    elif plane=='xz':
        R = [[ c, 0,-s],
             [ 0, 1, 0],
             [ s, 0, c]]
    else:
        error('plane must be xy/yx/yz/zy/zx/xz\nyou provided: {0}'.format(plane),'rotate_plane')
    #end if
    R = array(R,dtype=float)
    # accept lists/tuples as well as ndarrays
    points = array(points)
    return dot(R,points.T).T
#end def rotate_plane
484
485
486
# module-level caches filled lazily by optimal_tilematrix:
#   opt_tm_matrices[dn]    -> int array of every 3x3 offset matrix with
#                             entries in [-dn,dn], shape ((2*dn+1)**9,3,3)
#   opt_tm_wig_indices[nc] -> int array of all nonzero lattice image
#                             indices (i,j,k) with entries in [-nc,nc]
opt_tm_matrices    = obj()
opt_tm_wig_indices = obj()
489
def trivial_filter(T):
    '''Accept every candidate tiling matrix (default filter of optimal_tilematrix).'''
    accept = True
    return accept
#end def trivial_filter
493
class MaskFilter(DevBase):
    '''
    Callable tiling-matrix filter for optimal_tilematrix: a matrix is
    accepted only if all of its entries outside an allowed mask are zero.
    '''

    def set(self,mask,dim=3):
        '''
        Set the allowed-entry mask.

        ``mask`` is either a (dim,dim) array (converted to bool) whose True
        entries may be nonzero, or a length-dim vector.  A vector is expanded
        to a (dim,dim) mask whose (i,j) entry is True when the boolean values
        of mask[i] and mask[j] are equal.  Any other shape is reported via
        error().
        '''
        omask = array(mask)
        mask  = array(mask,dtype=bool)
        if mask.size==dim:
            mvec = mask.ravel()
            # NOTE(review): entries are compared after bool conversion, so a
            # vector such as [1,2,3] couples every axis; confirm intended
            mask = mvec[:,None]==mvec[None,:]
        elif mask.shape!=(dim,dim):
            # join the shape generically so 1D/3D masks produce the intended
            # message instead of an IndexError on mask.shape[1]
            shape = ','.join(str(n) for n in mask.shape)
            error('shape of mask array must be {0},{0}\nshape received: {1}\nmask array received: {2}'.format(dim,shape,omask),'optimal_tilematrix')
        #end if
        # store the complement: entries that must be zero
        self.mask = logical_not(mask)
    #end def set

    def __call__(self,T):
        '''Return True if every entry of ``T`` outside the allowed mask is zero.'''
        return (T[self.mask]==0).all()
    #end def __call__
#end class MaskFilter
# shared instance, reconfigured by optimal_tilematrix when a mask is given
mask_filter = MaskFilter()
521
522
def optimal_tilematrix(axes,volfac,dn=1,tol=1e-3,filter=trivial_filter,mask=None,nc=5):
    '''
    Search for an "optimal" integer tiling matrix of a cell.

    Starting from the integer matrix nearest to a cube of the target volume,
    every offset matrix with entries in [-dn,dn] is tried.  Among candidates
    whose determinant magnitude equals ``volfac`` and that pass ``filter``,
    selection proceeds in stages.  Each stage keeps the candidates within
    ``tol`` of the best value of that stage:
      1. maximal inscribed radius (0.5*volume/largest face area),
      2. maximal Wigner radius (half the shortest lattice image vector with
         image indices in [-nc,nc]),
      3. maximal cubicity (|a+-b| face diagonals closest to a cube's),
      4. maximal "shapeliness" (positive diagonal, few and nonnegative
         off-diagonal entries),
      5. symmetry preference: diagonal, then symmetric, then antisymmetric
         off-diagonal part, then any other; ties broken by minimal skew.

    Parameters
    ----------
    axes : Structure or array_like, shape (3,3)
        Cell axes (rows) to tile.
    volfac : int
        Target volume multiplier; non-integers are rounded.
    dn : int
        Offset range searched around the reference matrix.
    tol : float
        Tolerance used when comparing radii, cubicity and shapeliness.
    filter : callable
        ``filter(T)`` returns True for acceptable tiling matrices.
    mask : array_like, optional
        If given, replaces ``filter`` with the module MaskFilter set from it.
    nc : int
        Range of lattice images used for the Wigner radius.

    Returns
    -------
    Topt : ndarray of int, shape (3,3)
        Selected tiling matrix, negated if needed so that dot(Topt,axes) has
        a positive determinant.
    ropt : float
        Wigner radius of the selected tiled cell.

    If no candidate survives, error() is called with per-stage counts.
    '''
    from itertools import product

    if mask is not None:
        mask_filter.set(mask)
        filter = mask_filter
    #end if
    dim = 3
    if isinstance(axes,Structure):
        axes = axes.axes
    else:
        axes = array(axes,dtype=float)
    #end if
    if not isinstance(volfac,int):
        volfac = int(around(volfac))
    #end if
    volume = abs(det(axes))*volfac
    axinv  = inv(axes)
    # reference tiling: nearest integer matrix mapping the axes onto a cube
    # of the target volume
    cube   = volume**(1./3)*identity(dim)
    Tref   = array(around(dot(cube,axinv)),dtype=int)
    # calculate and store all tiling matrix variations
    # (product() yields the same order as nested loops over the 9 entries)
    if dn not in opt_tm_matrices:
        rng  = tuple(range(-dn,dn+1))
        mats = array(list(product(rng,repeat=dim*dim)),dtype=int)
        mats.shape = (2*dn+1)**(dim*dim),dim,dim
        opt_tm_matrices[dn] = mats
    else:
        mats = opt_tm_matrices[dn]
    #end if
    # calculate and store all wigner image indices
    # (nonzero (i,j,k), i varying fastest as in loops over k, j, i)
    if nc not in opt_tm_wig_indices:
        rng  = tuple(range(-nc,nc+1))
        inds = [(i,j,k) for k,j,i in product(rng,repeat=3)
                if i!=0 or j!=0 or k!=0]
        inds = array(inds,dtype=int)
        opt_tm_wig_indices[nc] = inds
    else:
        inds = opt_tm_wig_indices[nc]
    #end if
    # track counts of tiling matrices
    ntilings        = len(mats)
    nequiv_volume   = 0
    nfilter         = 0
    nequiv_inscribe = 0
    nequiv_wigner   = 0
    nequiv_cubicity = 0
    nequiv_shape    = 0
    # try a faster search for cells w/ target volume
    # (vectorized 3x3 determinant via the rule of Sarrus)
    det_inds_p = [
        [(0,0),(1,1),(2,2)],
        [(0,1),(1,2),(2,0)],
        [(0,2),(1,0),(2,1)]
        ]
    det_inds_m = [
        [(0,0),(1,2),(2,1)],
        [(0,1),(1,0),(2,2)],
        [(0,2),(1,1),(2,0)]
        ]
    volfacs = zeros((len(mats),),dtype=int)
    for (i1,j1),(i2,j2),(i3,j3) in det_inds_p:
        volfacs += (Tref[i1,j1]+mats[:,i1,j1])*(Tref[i2,j2]+mats[:,i2,j2])*(Tref[i3,j3]+mats[:,i3,j3])
    #end for
    for (i1,j1),(i2,j2),(i3,j3) in det_inds_m:
        volfacs -= (Tref[i1,j1]+mats[:,i1,j1])*(Tref[i2,j2]+mats[:,i2,j2])*(Tref[i3,j3]+mats[:,i3,j3])
    #end for
    Tmats = mats[abs(volfacs)==volfac]
    nequiv_volume = len(Tmats)
    # find the set of cells with maximal inscribing radius
    inscribe_tilings = []
    rmax = -1e99
    for mat in Tmats:
        T = Tref + mat
        if filter(T):
            nfilter+=1
            Taxes = dot(T,axes)
            rc1 = norm(cross(Taxes[0],Taxes[1]))
            rc2 = norm(cross(Taxes[1],Taxes[2]))
            rc3 = norm(cross(Taxes[2],Taxes[0]))
            r   = 0.5*volume/max(rc1,rc2,rc3) # inscribing radius
            if r>rmax or abs(r-rmax)<tol:
                inscribe_tilings.append((r,T,Taxes))
                rmax = r
            #end if
        #end if
    #end for
    # find the set of cells w/ maximal wigner radius out of the inscribing set
    wigner_tilings = []
    rwmax = -1e99
    for r,T,Taxes in inscribe_tilings:
        if abs(r-rmax)<tol:
            nequiv_inscribe+=1
            rw = 1e99
            for ind in inds:
                rw = min(rw,0.5*norm(dot(ind,Taxes)))
            #end for
            if rw>rwmax or abs(rw-rwmax)<tol:
                wigner_tilings.append((rw,T,Taxes))
                rwmax = rw
            #end if
        #end if
    #end for
    # find the set of cells w/ maximal cubicity
    # (minimum cube_deviation)
    cube_tilings = []
    cmin = 1e99
    for rw,T,Ta in wigner_tilings:
        if abs(rw-rwmax)<tol:
            nequiv_wigner+=1
            dc = volume**(1./3)*sqrt(2.)
            d1 = abs(norm(Ta[0]+Ta[1])-dc)
            d2 = abs(norm(Ta[1]+Ta[2])-dc)
            d3 = abs(norm(Ta[2]+Ta[0])-dc)
            d4 = abs(norm(Ta[0]-Ta[1])-dc)
            d5 = abs(norm(Ta[1]-Ta[2])-dc)
            d6 = abs(norm(Ta[2]-Ta[0])-dc)
            cube_dev = (d1+d2+d3+d4+d5+d6)/(6*dc)
            if cube_dev<cmin or abs(cube_dev-cmin)<tol:
                cube_tilings.append((cube_dev,rw,T,Ta))
                cmin = cube_dev
            #end if
        #end if
    #end for
    # prioritize selection by "shapeliness" of tiling matrix
    #   prioritize positive diagonal elements
    #   penalize off-diagonal elements
    #   penalize negative off-diagonal elements
    shapely_tilings = []
    smax = -1e99
    for cd,rw,T,Taxes in cube_tilings:
        if abs(cd-cmin)<tol:
            nequiv_cubicity+=1
            d = diag(T)
            o = (T-diag(d)).ravel()
            s = sign(d).sum()-(abs(o)>0).sum()-(o<0).sum()
            if s>smax or abs(s-smax)<tol:
                shapely_tilings.append((s,rw,T,Taxes))
                smax = s
            #end if
        #end if
    #end for
    # prioritize selection by symmetry of tiling matrix
    ropt   = -1e99
    Topt   = None
    Taxopt = None
    diagonal      = []
    symmetric     = []
    antisymmetric = []
    other         = []
    for s,rw,T,Taxes in shapely_tilings:
        if abs(s-smax)<tol:
            nequiv_shape+=1
            Td = diag(diag(T))
            if abs(Td-T).sum()==0:
                diagonal.append((rw,T,Taxes))
            elif abs(T.T-T).sum()==0:
                symmetric.append((rw,T,Taxes))
            elif abs(T.T+T-2*Td).sum()==0:
                antisymmetric.append((rw,T,Taxes))
            else:
                other.append((rw,T,Taxes))
            #end if
        #end if
    #end for
    # cells must be bound even when every category is empty, so that the
    # informative "not found" error below is reached
    cells = []
    s = 1
    if len(diagonal)>0:
        cells = diagonal
    elif len(symmetric)>0:
        cells = symmetric
    elif len(antisymmetric)>0:
        cells = antisymmetric
        s = -1
    elif len(other)>0:
        cells = other
    #end if
    skew_min = 1e99
    if len(cells)>0:
        for rw,T,Taxes in cells:
            Td = diag(diag(T))
            skew = abs(T.T-s*T-(1-s)*Td).sum()
            if skew<skew_min:
                ropt = rw
                Topt = T
                Taxopt = Taxes
                skew_min = skew
            #end if
        #end for
    #end if
    if Taxopt is None:
        error('optimal tilematrix for volfac={0} not found with tolerance {1}\ndifference range (dn): {2}\ntiling matrices searched: {3}\ncells with target volume: {4}\ncells that passed the filter: {5}\ncells with equivalent inscribing radius: {6}\ncells with equivalent wigner radius: {7}\ncells with equivalent cubicity: {8}\nmatrices with equivalent shapeliness: {9}\nplease try again with dn={10}'.format(volfac,tol,dn,ntilings,nequiv_volume,nfilter,nequiv_inscribe,nequiv_wigner,nequiv_cubicity,nequiv_shape,dn+1))
    #end if
    if det(Taxopt)<0:
        Topt = -Topt
    #end if
    return Topt,ropt
#end def optimal_tilematrix
744
745
class Sobj(DevBase):
    '''Empty DevBase subclass used as the base class of Structure.'''
#end class Sobj
749
750
751
752class Structure(Sobj):
753
754    operations = obj()
755
756    @classmethod
757    def set_operations(cls):
758        cls.operations.set(
759            remove_folded_structure = cls.remove_folded_structure,
760            recenter = cls.recenter,
761            )
762    #end def set_operations
763
764
765    def __init__(self,
766                 axes              = None,
767                 scale             = 1.,
768                 elem              = None,
769                 pos               = None,
770                 elem_pos          = None,
771                 mag               = None,
772                 center            = None,
773                 kpoints           = None,
774                 kweights          = None,
775                 kgrid             = None,
776                 kshift            = None,
777                 permute           = None,
778                 units             = None,
779                 tiling            = None,
780                 rescale           = True,
781                 dim               = 3,
782                 magnetization     = None,
783                 operations        = None,
784                 background_charge = 0,
785                 frozen            = None,
786                 bconds            = None,
787                 posu              = None,
788                 use_prim          = None,
789                 add_kpath         = False,
790                 symm_kgrid        = False,
791                 ):
792
793        if isinstance(axes,str):
794            axes = array(axes.split(),dtype=float)
795            axes.shape = dim,dim
796        #end if
797        if center is None:
798            if axes is not None:
799                center = array(axes,dtype=float).sum(0)/2
800            else:
801                center = dim*[0]
802            #end if
803        #end if
804        if bconds is None or bconds=='periodic':
805            bconds = dim*['p']
806        #end if
807        if axes is None:
808            axes   = []
809            bconds = []
810        #end if
811        if elem_pos is not None:
812            ep = array(elem_pos.split(),dtype=str)
813            ep.shape = ep.size//(dim+1),(dim+1)
814            elem = ep[:,0].ravel()
815            pos  = ep[:,1:dim+1]
816        #end if
817        if elem is None:
818            elem = []
819        #end if
820        if posu is not None:
821            pos = posu
822        #end if
823        if pos is None:
824            pos = empty((0,dim))
825        #end if
826        if kshift is None:
827            kshift = 0,0,0
828        #end if
829        self.scale    = 1.
830        self.units    = units
831        self.dim      = dim
832        self.center   = array(center,dtype=float)
833        self.axes     = array(axes,dtype=float)
834        self.set_bconds(bconds)
835        self.set_elem(elem)
836        self.set_pos(pos)
837        self.set_mag(mag)
838        self.set_frozen(frozen)
839        self.kpoints  = empty((0,dim))
840        self.kweights = empty((0,))
841        self.background_charge = background_charge
842        self.remove_folded_structure()
843        if len(axes)==0:
844            self.kaxes=array([])
845        else:
846            self.kaxes=2*pi*inv(self.axes).T
847        #end if
848        if posu is not None:
849            self.pos_to_cartesian()
850        #end if
851        if magnetization is not None:
852            self.magnetize(magnetization)
853        #end if
854        if use_prim is not None and use_prim is not False:
855            self.become_primitive(source=use_prim,add_kpath=add_kpath)
856        #end if
857        if tiling is not None:
858            self.tile(tiling,in_place=True)
859        #end if
860        if kpoints is not None:
861            self.add_kpoints(kpoints,kweights)
862        #end if
863        if kgrid is not None:
864            if not symm_kgrid:
865                self.add_kmesh(kgrid,kshift)
866            else:
867                self.add_symmetrized_kmesh(kgrid,kshift)
868            #end if
869        #end if
870        if rescale:
871            self.rescale(scale)
872        else:
873            self.scale = scale
874        #end if
875        if permute is not None:
876            self.permute(permute)
877        #end if
878        if operations is not None:
879            self.operate(operations)
880        #end if
881    #end def __init__
882
883
884    def check_consistent(self,tol=1e-8,exit=True,message=False):
885        msg = ''
886        if self.has_axes():
887            kaxes = 2*pi*inv(self.axes).T
888            abs_diff = abs(self.kaxes-kaxes).sum()
889            if abs_diff>tol:
890                msg += 'Direct and reciprocal space axes are not consistent.\naxes present:\n{0}\nkaxes present:\n{1}\nConsistent kaxes:\n{2}\nAbsolute difference: {3}\n'.format(self.axes,self.kaxes,kaxes,abs_diff)
891            #end if
892        #end if
893        N = len(self.elem)
894        D = self.dim
895        pshape = (N,D)
896        if self.pos.shape!=pshape:
897            msg += 'pos is not the right shape\npos shape: {}\nCorrect shape: {}\n'.format(self.pos.shape,pshape)
898        #end if
899        if self.mag is not None and len(self.mag)!=N:
900            msg += 'mag does not have the right length\nmag length: {}\nCorrect length: {}\n'.format(self.mag,N)
901        #end if
902        if self.frozen is not None and self.frozen.shape!=pshape:
903            msg += 'frozen is not the right shape\nfrozen shape: {}\nCorrect shape: {}\n'.format(self.frozen.shape,pshape)
904        #end if
905        consistent = len(msg)==0
906        if not consistent and exit:
907            self.error(msg)
908        #end if
909        if not message:
910            return consistent
911        else:
912            return consistent,msg
913        #end if
914    #end def check_consistent
915
916
    def set_axes(self,axes):
        '''Set the cell axes by delegating to ``reset_axes``.'''
        self.reset_axes(axes)
    #end def set_axes
920
921
922    def set_bconds(self,bconds):
923        self.bconds = array(tuple(bconds),dtype=str)
924    #end def bconds
925
926
927    def set_elem(self,elem):
928        self.elem = array(elem,dtype=object)
929    #end def set_elem
930
931
932    def set_pos(self,pos):
933        self.pos = array(pos,dtype=float)
934        if len(self.pos)!=len(self.elem):
935            self.error('Atomic positions must have same length as elem.\nelem length: {}\nAtomic positions length: {}\n'.format(len(self.elem),len(self.pos)))
936        #end if
937    #end def set_pos
938
939
940    def set_mag(self,mag=None):
941        if mag is None:
942            self.mag = None
943        else:
944            self.mag = np.array(mag,dtype=object)
945            if len(self.mag)!=len(self.elem):
946                self.error('Magnetic moments must have same length as elem.\nelem length: {}\nMagnetic moments length: {}\n'.format(len(self.elem),len(self.mag)))
947            #end if
948        #end if
949    #end def set_mag
950
951
952    def set_frozen(self,frozen=None):
953        if frozen is None:
954            self.frozen = None
955        else:
956            self.frozen = np.array(frozen,dtype=bool)
957            if self.frozen.shape!=self.pos.shape:
958                self.error('Frozen directions must have the same shape as positions.\nPositions shape: {0}\nFrozen directions shape: {1}'.format(self.pos.shape,self.frozen.shape))
959            #end if
960        #end if
961    #end def set_frozen
962
963
964    def size(self):
965        return len(self.elem)
966    #end def size
967
968
969    def has_axes(self):
970        return len(self.axes)==self.dim
971    #end def has_axes
972
973
974    def operate(self,operations):
975        for op in operations:
976            if not op in self.operations:
977                self.error('{0} is not a known operation\nvalid options are:\n  {1}'.format(op,list(self.operations.keys())))
978            else:
979                self.operations[op](self)
980            #end if
981        #end for
982    #end def operate
983
984
985    def has_tmatrix(self):
986        return 'tmatrix' in self and self.tmatrix is not None
987    #end def has_tmatrix
988
989
990    def is_tiled(self):
991        return self.has_folded() and self.has_tmatrix()
992    #end def is_tiled
993
994
995    def set_folded(self,folded):
996        self.set_folded_structure(folded)
997    #end def set_folded
998
999
1000    def remove_folded(self):
1001        self.remove_folded_structure()
1002    #end def remove_folded
1003
1004
1005    def has_folded(self):
1006        return self.has_folded_structure()
1007    #end def has_folded
1008
1009
1010    def set_folded_structure(self,folded):
1011        if not isinstance(folded,Structure):
1012            self.error('cannot set folded structure\nfolded structure must be an object with type Structure\nreceived type: {0}'.format(folded.__class__.__name__))
1013        #end if
1014        self.folded_structure = folded
1015        if self.has_axes():
1016            self.tmatrix = self.tilematrix(folded)
1017        #end if
1018    #end def set_folded_structure
1019
1020
1021    def remove_folded_structure(self):
1022        self.folded_structure = None
1023        if 'tmatrix' in self:
1024            del self.tmatrix
1025        #end if
1026    #end def remove_folded_structure
1027
1028
1029    def has_folded_structure(self):
1030        return self.folded_structure is not None
1031    #end def has_folded_structure
1032
1033
1034    # test needed
1035    def group_atoms(self,folded=True):
1036        if len(self.elem)>0:
1037            order = self.elem.argsort()
1038            if (self.elem!=self.elem[order]).any():
1039                self.elem = self.elem[order]
1040                self.pos  = self.pos[order]
1041            #end if
1042        #end if
1043        if self.folded_structure!=None and folded:
1044            self.folded_structure.group_atoms(folded)
1045        #end if
1046    #end def group_atoms
1047
1048
1049    # test needed
1050    def rename(self,folded=True,**name_pairs):
1051        elem = self.elem
1052        for old,new in name_pairs.items():
1053            for i in range(len(self.elem)):
1054                if old==elem[i]:
1055                    elem[i] = new
1056                #end if
1057            #end for
1058        #end for
1059        if self.folded_structure!=None and folded:
1060            self.folded_structure.rename(folded=folded,**name_pairs)
1061        #end if
1062    #end def rename
1063
1064
1065    # test needed
1066    def reset_axes(self,axes=None):
1067        if axes is None:
1068            axes = self.axes
1069        else:
1070            axes = array(axes)
1071            self.remove_folded_structure()
1072        #end if
1073        self.axes  = axes
1074        self.kaxes = 2*pi*inv(axes).T
1075        self.center = axes.sum(0)/2
1076    #end def reset_axes
1077
1078
1079    # test needed
1080    def adjust_axes(self,axes):
1081        self.skew(dot(inv(self.axes),axes))
1082    #end def adjust_axes
1083
1084
1085    # test needed
1086    def reshape_axes(self,reshaping):
1087        R = array(reshaping)
1088        if abs(abs(det(R))-1)<1e-6:
1089            self.axes = dot(self.axes,R)
1090        else:
1091            R = dot(inv(self.axes),R)
1092            if abs(abs(det(R))-1)<1e-6:
1093                self.axes = dot(self.axes,R)
1094            else:
1095                self.error('reshaping matrix must not change the volume\n  reshaping matrix:\n  {0}\n  volume change ratio: {1}'.format(R,abs(det(R))))
1096            #end if
1097        #end if
1098    #end def reshape_axes
1099
1100
1101    def write_axes(self):
1102        c = ''
1103        for a in self.axes:
1104            c+='{0:12.8f} {1:12.8f} {2:12.8f}\n'.format(a[0],a[1],a[2])
1105        #end for
1106        return c
1107    #end def write_axes
1108
1109
1110    def corners(self):
1111        a = self.axes
1112        c = array([(0,0,0),
1113                   a[0],
1114                   a[1],
1115                   a[2],
1116                   a[0]+a[1],
1117                   a[1]+a[2],
1118                   a[2]+a[0],
1119                   a[0]+a[1]+a[2],
1120                   ])
1121        return c
1122    #end def corners
1123
1124
1125    # test needed
1126    def miller_direction(self,h,k,l,normalize=False):
1127        d = dot((h,k,l),self.axes)
1128        if normalize:
1129            d/=norm(d)
1130        #end if
1131        return d
1132    #end def miller_direction
1133
1134
1135    # test needed
1136    def miller_normal(self,h,k,l,normalize=False):
1137        d = dot((h,k,l),self.kaxes)
1138        if normalize:
1139            d/=norm(d)
1140        #end if
1141        return d
1142    #end def miller_normal
1143
1144
1145    # test needed
1146    def project_plane(self,a1,a2,points=None):
1147        # a1/a2: in plane vectors
1148        if points is None:
1149            points = self.pos
1150        #end if
1151        a1n = norm(a1)
1152        a2n = norm(a2)
1153        a1/=a1n
1154        a2/=a2n
1155        n = cross(a1,a2)
1156        plane_coords = []
1157        for p in points:
1158            p -= dot(n,p)*n # project point into plane
1159            c1 = dot(a1,p)/a1n
1160            c2 = dot(a2,p)/a2n
1161            plane_coords.append((c1,c2))
1162        #end for
1163        return array(plane_coords,dtype=float)
1164    #end def project_plane
1165
1166
1167    def bounding_box(self,scale=1.0,minsize=None,mindist=0,box='tight',recenter=False):
1168        pmin    = self.pos.min(0)-mindist
1169        pmax    = self.pos.max(0)+mindist
1170        pcenter = (pmax+pmin)/2
1171        prange  = pmax-pmin
1172        if minsize is not None:
1173            for i,pr in enumerate(prange):
1174                prange[i] = max(minsize,prange[i])
1175            #end for
1176        #end if
1177        if box=='tight':
1178            axes = diag(prange)
1179        elif box=='cubic' or box=='cube':
1180            prmax = prange.max()
1181            axes = diag((prmax,prmax,prmax))
1182        elif isinstance(box,ndarray) or isinstance(box,list):
1183            box = array(box)
1184            if box.shape!=(3,3):
1185                self.error('requested box must be 3-dimensional (3x3 axes)\n  you provided: '+str(box)+'\n shape: '+str(box.shape))
1186            #end if
1187            binv = inv(box)
1188            pu = dot(self.pos,binv)
1189            pmin    = pu.min(0)
1190            pmax    = pu.max(0)
1191            pcenter = (pmax+pmin)/2
1192            prange  = pmax-pmin
1193            axes    = dot(diag(prange),box)
1194        else:
1195            self.error("invalid request for box\n  valid options are 'tight', 'cubic', or axes array (3x3)\n  you provided: "+str(box))
1196        #end if
1197        self.reset_axes(scale*axes)
1198        self.slide(self.center-pcenter,recenter)
1199    #end def bounding_box
1200
1201
1202    def center_molecule(self):
1203        self.slide(self.center-self.pos.mean(0),recenter=False)
1204    #end def center_molecule
1205
1206
1207    # test needed
1208    def center_solid(self):
1209        u = self.pos_unit()
1210        du = (1-u.min(0)-u.max(0))/2
1211        self.slide(dot(du,self.axes),recenter=False)
1212    #end def center_solid
1213
1214
1215    # test needed
1216    def permute(self,permutation):
1217        dim = self.dim
1218        P = empty((dim,dim),dtype=int)
1219        if len(permutation)!=dim:
1220            self.error(' permutation vector must have {0} elements\n you provided {1}'.format(dim,permutation))
1221        #end if
1222        for i in range(dim):
1223            p = permutation[i]
1224            pv = zeros((dim,),dtype=int)
1225            if p=='x' or p=='0':
1226                pv[0] = 1
1227            elif p=='y' or p=='1':
1228                pv[1] = 1
1229            elif p=='z' or p=='2':
1230                pv[2] = 1
1231            #end if
1232            P[:,i] = pv[:]
1233        #end for
1234        self.center = dot(self.center,P)
1235        if self.has_axes():
1236            self.axes = dot(self.axes,P)
1237        #end if
1238        if len(self.pos)>0:
1239            self.pos = dot(self.pos,P)
1240        #end if
1241        if len(self.kaxes)>0:
1242            self.kaxes = dot(self.kaxes,P)
1243        #end if
1244        if len(self.kpoints)>0:
1245            self.kpoints = dot(self.kpoints,P)
1246        #end if
1247        if self.folded_structure!=None:
1248            self.folded_structure.permute(permutation)
1249        #end if
1250    #end def permute
1251
1252
1253    # test needed
1254    def rotate_plane(self,plane,angle,units='degrees'):
1255        self.pos = rotate_plane(plane,angle,self.pos,units)
1256        if self.has_axes():
1257            axes = rotate_plane(plane,angle,self.axes,units)
1258            self.reset_axes(axes)
1259        #end if
1260    #end def rotate_plane
1261
1262
1263    # test needed
1264    def upcast(self,DerivedStructure):
1265        if not issubclass(DerivedStructure,Structure):
1266            self.error(DerivedStructure.__name__,'is not derived from Structure')
1267        #end if
1268        ds = DerivedStructure()
1269        for name,value in self.items():
1270            ds[name] = deepcopy(value)
1271        #end for
1272        return ds
1273    #end def upcast
1274
1275
1276    # test needed
1277    def incorporate(self,other):
1278        self.set_elem(list(self.elem)+list(other.elem))
1279        self.pos=array(list(self.pos)+list(other.pos))
1280    #end def incorporate
1281
1282
1283    # test needed
1284    def clone_from(self,other):
1285        if not isinstance(other,Structure):
1286            self.error('cloning failed\ncan only clone from other Structure objects\nreceived object of type: {0}'.format(other.__class__.__name__))
1287        #end if
1288        o = other.copy()
1289        self.__dict__ = o.__dict__
1290    #end def clone_from
1291
1292
1293    # test needed
1294    def add_atoms(self,elem,pos):
1295        self.set_elem(list(self.elem)+list(elem))
1296        self.pos=array(list(self.pos)+list(pos))
1297    #end def add_atoms
1298
1299
1300    def is_open(self):
1301        return not self.any_periodic()
1302    #end def is_open
1303
1304
1305    def is_periodic(self):
1306        return self.any_periodic()
1307    #end def is_periodic
1308
1309
1310    def any_periodic(self):
1311        has_cell    = self.has_axes()
1312        pbc = False
1313        for bc in self.bconds:
1314            pbc |= bc=='p'
1315        #end if
1316        periodic = has_cell and pbc
1317        return periodic
1318    #end def any_periodic
1319
1320
1321    def all_periodic(self):
1322        has_cell = self.has_axes()
1323        pbc = True
1324        for bc in self.bconds:
1325            pbc &= bc=='p'
1326        #end if
1327        periodic = has_cell and pbc
1328        return periodic
1329    #end def all_periodic
1330
1331
1332    # test needed
1333    def distances(self,pos1=None,pos2=None):
1334        if isinstance(pos1,Structure):
1335            pos1 = pos1.pos
1336        #end if
1337        if pos2 is None:
1338            if pos1 is None:
1339                return sqrt((self.pos**2).sum(1))
1340            else:
1341                pos2 = self.pos
1342            #end if
1343        #end if
1344        if len(pos1)!=len(pos2):
1345            self.error('positions arrays are not the same length')
1346        #end if
1347        return sqrt(((pos1-pos2)**2).sum(1))
1348    #end def distances
1349
1350
1351    def count_kshells(self, kcut, tilevec=[12, 12, 12], nkdig=10):
1352      # check tilevec input
1353      for nt in tilevec:
1354        if nt % 2 != 0:
1355          msg = 'tilevec must contain even integers'
1356          msg += ' so that kgrid can be zero centered.'
1357          Structure.class_error(msg, 'count_kshells')
1358        #end if
1359      #end for
1360
1361      origin = np.array([[0, 0, 0]])
1362      axes = self.axes
1363      raxes = 2*np.pi*np.linalg.inv(axes).T
1364      kvecs = self.tile_points(origin, raxes, tilevec)
1365      kvecs -= np.dot(tilevec, raxes)/2  # center around 0
1366      kmags = np.linalg.norm(kvecs, axis=-1)
1367
1368      # make sure tilevec is sufficient for kcut
1369      klimit = 0.5*kmags.max()
1370      if kcut > klimit:
1371        msg = 'kcut %3.2f > klimit=%3.2f\n' % (kcut, klimit)
1372        msg += ' please increase tilevec to be safe.\n'
1373        Structure.class_error(msg, 'count_kshells')
1374      #end if
1375
1376      sel = (0<kmags) & (kmags<kcut)
1377      ukmags = np.unique(kmags[sel].round(nkdig))
1378      return len(ukmags)
1379    #end def count_kshells
1380
1381
1382    def volume(self):
1383        if not self.has_axes():
1384            return None
1385        else:
1386            return abs(det(self.axes))
1387        #end if
1388    #end def volume
1389
1390
1391    def rwigner(self,nc=5):
1392        if self.dim!=3:
1393            self.error('rwigner is currently only implemented for 3 dimensions')
1394        #end if
1395        rmin = 1e90
1396        n=empty((1,3))
1397        rng = tuple(range(-nc,nc+1))
1398        for k in rng:
1399            for j in rng:
1400                for i in rng:
1401                    if i!=0 or j!=0 or k!=0:
1402                        n[:] = i,j,k
1403                        rmin = min(rmin,.5*norm(dot(n,self.axes)))
1404                    #end if
1405                #end for
1406            #end for
1407        #end for
1408        return rmin
1409    #end def rwigner
1410
1411
1412    def rinscribe(self):
1413        if self.dim!=3:
1414            self.error('rinscribe is currently only implemented for 3 dimensions')
1415        #end if
1416        radius = 1e99
1417        dim=3
1418        axes=self.axes
1419        for d in range(dim):
1420            i = d
1421            j = (d+1)%dim
1422            rc = cross(axes[i,:],axes[j,:])
1423            radius = min(radius,.5*abs(det(axes))/norm(rc))
1424        #end for
1425        return radius
1426    #end def rinscribe
1427
1428
1429    # test needed
1430    def rwigner_cube(self,*args,**kwargs):
1431        cube = Structure()
1432        a = self.volume()**(1./3)
1433        cube.set_axes([[a,0,0],[0,a,0],[0,0,a]])
1434        return cube.rwigner(*args,**kwargs)
1435    #end def rwigner_cube
1436
1437
1438    # test needed
1439    def rinscribe_cube(self,*args,**kwargs):
1440        cube = Structure()
1441        a = self.volume()**(1./3)
1442        cube.set_axes([[a,0,0],[0,a,0],[0,0,a]])
1443        return cube.rinscribe(*args,**kwargs)
1444    #end def rinscribe_cube
1445
1446
1447    def rmin(self):
1448        return self.rwigner()
1449    #end def rmin
1450
1451
1452    def rcell(self):
1453        return self.rinscribe()
1454    #end def rcell
1455
1456
1457    # test needed
1458    # scale invariant measure of deviation from cube shape
1459    #   based on deviation of face diagonals from cube
1460    def cube_deviation(self):
1461        a = self.axes
1462        dc = self.volume()**(1./3)*sqrt(2.)
1463        d1 = abs(norm(a[0]+a[1])-dc)
1464        d2 = abs(norm(a[1]+a[2])-dc)
1465        d3 = abs(norm(a[2]+a[0])-dc)
1466        d4 = abs(norm(a[0]-a[1])-dc)
1467        d5 = abs(norm(a[1]-a[2])-dc)
1468        d6 = abs(norm(a[2]-a[0])-dc)
1469        return (d1+d2+d3+d4+d5+d6)/(6*dc)
1470    #end def cube_deviation
1471
1472
1473    # test needed
1474    # apply volume preserving shear-removing transformations to cell axes
1475    #   resulting unsheared cell has orthogonal axes
1476    #    while remaining periodically correct
1477    #   note that the unshearing procedure is not unique
1478    #   it depends on the order of unshearing operations
1479    def unsheared_axes(self,axes=None,distances=False):
1480        if self.dim!=3:
1481            self.error('unsheared_axes is currently only implemented for 3 dimensions')
1482        #end if
1483        if axes is None:
1484            axes = self.axes
1485        #end if
1486        dim=3
1487        axbar = identity(dim)
1488        axnew = array(axes,dtype=float)
1489        dists = empty((dim,))
1490        for d in range(dim):
1491            d2 = (d+1)%dim
1492            d3 = (d+2)%dim
1493            n = cross(axnew[d2],axnew[d3])  #vector normal to 2 cell faces
1494            axdist = dot(n,axes[d])/dot(n,axbar[d])
1495            axnew[d]  = axdist*axbar[d]
1496            dists[d] = axdist
1497        #end for
1498        if not distances:
1499            return axnew
1500        else:
1501            return axnew,dists
1502        #end if
1503    #end def unsheared_axes
1504
1505
1506    # test needed
    # vectors normal (perpendicular) to cell faces
1508    #   length of vectors is distance between parallel face planes
1509    #   note that the product of distances is not the cell volume in general
1510    #   see "unsheared_axes" function
1511    #   (e.g. a volume preserving shear may bring two face planes arbitrarily close)
1512    def face_vectors(self,axes=None,distances=False):
1513        if axes is None:
1514            axes = self.axes
1515        #end if
1516        fv = inv(axes).T
1517        for d in range(len(fv)):
1518            fv[d] /= norm(fv[d]) # face normals
1519        #end for
1520        dv = dot(axes,fv.T) # axis projections onto face normals
1521        fv = dot(dv,fv)     # face normals lengthened by plane separation
1522        if not distances:
1523            return fv
1524        else:
1525            return fv,diag(dv)
1526        #end if
1527    #end def face_vectors
1528
1529
1530    # test needed
1531    def face_distances(self):
1532        return self.face_vectors(distances=True)[1]
1533    #end def face_distances
1534
1535
1536    # test needed
1537    def rescale(self,scale):
1538        self.scale  *= scale
1539        self.axes   *= scale
1540        self.pos    *= scale
1541        self.center *= scale
1542        self.kaxes  /= scale
1543        self.kpoints/= scale
1544        if self.folded_structure!=None:
1545            self.folded_structure.rescale(scale)
1546        #end if
1547    #end def rescale
1548
1549
1550    # test needed
1551    def stretch(self,s1,s2,s3):
1552        if self.dim!=3:
1553            self.error('stretch is currently only implemented for 3 dimensions')
1554        #end if
1555        d = diag((s1,s2,s3))
1556        self.skew(d)
1557    #end def stretch
1558
1559
1560    # test needed
1561    def rotate(self,r,rp=None,passive=False,units="radians",check=True):
1562        """
1563        Arbitrary rotation of the structure.
1564        Parameters
1565        ----------
1566        r  : `array_like, float, shape (3,3)` or `array_like, float, shape (3,)` or `str`
1567            If a 3x3 matrix, then code executes rotation consistent with this matrix --
1568            it is assumed that the matrix acts on a column-major vector (eg, v'=Rv)
1569            If a three-dimensional array, then the operation of the function depends
1570            on the input type of rp in the following ways:
1571                1. If rp is a scalar, then rp is assumed to be an angle and a rotation
1572                   of rp is made about the axis defined by r
1573                2. If rp is a vector, then rp is assumed to be an axis and a rotation is made
1574                   such that r aligns with rp
1575                3. If rp is a str, then the rotation is such that r aligns with the
1576                   axis given by the str ('x', 'y', 'z', 'a0', 'a1', or 'a2')
1577            If a str then the axis, r, is defined by the input label (e.g. 'x', 'y', 'z', 'a1', 'a2', or 'a3')
1578            and the operation of the function depends on the input type of rp in the following
1579            ways (same as above):
1580                1. If rp is a scalar, then rp is assumed to be an angle and a rotation
1581                   of rp is made about the axis defined by r
1582                2. If rp is a vector, then rp is assumed to be an axis and a rotation is made
1583                   such that r aligns with rp
1584                3. If rp is a str, then the rotation is such that r aligns with the
1585                   axis given by the str ('x', 'y', 'z', 'a0', 'a1', or 'a2')
1586        rp : `array_like, float, shape (3), optional` or `str, optional`
1587            If a 3-dimensional vector is given, then rp is assumed to be an axis and a rotation is made
1588            such that the axis r is aligned with rp.
1589            If a str, then rp is assumed to be an angle and a rotation about the axis defined by r
1590            is made by an angle rp
1591            If a str is given, then rp is assumed to be an axis defined by the given label
1592            (e.g. 'x', 'y', 'z', 'a1', 'a2', or 'a3') and a rotation is made such that the axis r
1593            is aligned with rp.
1594        passive : `bool, optional, default False`
1595            If `True`, perform a passive rotation
1596            If `False`, perform an active rotation
1597        units : `str, optional, default "radians"`
1598            Units of rp, if rp is given as an angle (scalar)
1599        check : `bool, optional, default True`
1600            Perform a check to verify rotation matrix is orthogonal
1601        """
1602        if rp is not None:
1603            dirmap = dict(x=[1,0,0],y=[0,1,0],z=[0,0,1])
1604            if isinstance(r,str):
1605                if r[0]=='a': # r= 'a0', 'a1', or 'a2'
1606                    r = self.axes[int(r[1])]
1607                else: # r= 'x', 'y', or 'z'
1608                    r = dirmap[r]
1609                #end if
1610            else:
1611                r = array(r,dtype=float)
1612                if len(r.shape)>1:
1613                    self.error('r must be given as a 1-d vector or string, if rp is not None')
1614                #end if
1615            #end if
1616            if isinstance(rp,(int,float)):
1617                if units=="radians" or units=="rad":
1618                    theta = float(rp)
1619                else:
1620                    theta = float(rp)*np.pi/180.0
1621                #end if
1622                c = np.cos(theta)
1623                s = np.sin(theta)
1624            else:
1625                if isinstance(rp,str):
1626                    if rp[0]=='a': # rp= 'a0', 'a1', or 'a2'
1627                        rp = self.axes[int(rp[1])]
1628                    else: # rp= 'x', 'y', or 'z'
1629                        rp = dirmap[rp]
1630                    #end if
1631                else:
1632                    rp = array(rp,dtype=float)
1633                #end if
1634                # go from r,rp to r,theta
1635                c = np.dot(r,rp)/np.linalg.norm(r)/np.linalg.norm(rp)
1636                if abs(c-1)<1e-6:
1637                    s = 0.0
1638                    r = np.array([1,0,0])
1639                else:
1640                    s = np.dot(np.cross(r,rp),np.cross(r,rp))/np.linalg.norm(r)/np.linalg.norm(rp)/np.linalg.norm(np.cross(r,rp))
1641                    r = np.cross(r,rp)/np.linalg.norm(np.cross(r,rp))
1642                #end if
1643            #end if
1644            # make R from r,theta
1645            R = [[     c+r[0]**2.0*(1.0-c), r[0]*r[1]*(1.0-c)-r[2]*s, r[0]*r[2]*(1.0-c)+r[1]*s],
1646                 [r[1]*r[0]*(1.0-c)+r[2]*s,      c+r[1]**2.0*(1.0-c), r[1]*r[2]*(1.0-c)-r[0]*s],
1647                 [r[2]*r[0]*(1.0-c)-r[1]*s, r[2]*r[1]*(1.0-c)+r[0]*s,      c+r[2]**2.0*(1.0-c)]]
1648        else:
1649            R = r
1650        #end if
1651        R = array(R,dtype=float)
1652        if passive:
1653            R = R.T
1654        #end if
1655        if check:
1656            if not np.allclose(dot(R,R.T),identity(len(R))):
1657                self.error('the function, rotate, must be given an orthogonal matrix')
1658            #end if
1659        #end if
1660        self.matrix_transform(R)
1661    #end def rotate
1662
1663
1664    # test needed
1665    def matrix_transform(self,A):
1666        """
1667        Arbitrary transformation matrix (column-major).
1668
1669        Parameters
1670        ----------
1671        A  : `array_like, float, shape (3,3)`
1672            Transform the structure using the matrix A. It is assumed that
1673            A is in column-major form, i.e., it transforms a vector v as
1674            v' = Av
1675        """
1676        A = A.T
1677        axinv  = inv(self.axes)
1678        axnew  = dot(self.axes,A)
1679        kaxinv = inv(self.kaxes)
1680        kaxnew = dot(self.kaxes,inv(A).T)
1681        self.pos     = dot(dot(self.pos,axinv),axnew)
1682        self.center  = dot(dot(self.center,axinv),axnew)
1683        self.kpoints = dot(dot(self.kpoints,kaxinv),kaxnew)
1684        self.axes  = axnew
1685        self.kaxes = kaxnew
1686        if self.folded_structure!=None:
1687            self.folded_structure.matrix_transform(A.T)
1688        #end if
1689    #end def matrix_transform
1690
1691
1692    # test needed
1693    def skew(self,skew):
1694        """
1695        Arbitrary transformation matrix (row-major).
1696
1697        Parameters
1698        ----------
1699        skew  : `array_like, float, shape (3,3)`
1700            Transform the structure using the matrix skew. It is assumed that
1701            skew is in row-major form, i.e., it transforms a vector v as
1702            v' = vT
1703        """
1704        self.matrix_transform(skew.T)
1705    #end def skew
1706
1707
1708    # test needed
1709    def change_units(self,units,folded=True):
1710        if units!=self.units:
1711            scale = convert(1,self.units,units)
1712            self.scale  *= scale
1713            self.axes   *= scale
1714            self.pos    *= scale
1715            self.center *= scale
1716            self.kaxes  /= scale
1717            self.kpoints/= scale
1718            self.units  = units
1719        #end if
1720        if self.folded_structure!=None and folded:
1721            self.folded_structure.change_units(units,folded=folded)
1722        #end if
1723    #end def change_units
1724
1725
1726    # test needed
1727    # insert sep space at loc along axis
1728    #   if sep<0, space is removed instead
1729    def cleave(self,axis,loc,sep=None,remove=False,tol=1e-6):
1730        self.remove_folded_structure()
1731        if isinstance(axis,int):
1732            if sep is None:
1733                self.error('separation induced by cleave must be provided')
1734            #end if
1735            v = self.face_vectors()[axis]
1736            if isinstance(loc,float):
1737                c = loc*v/norm(v)
1738            #end if
1739        else:
1740            v = axis
1741        #end if
1742        c = array(c)  # point on cleave plane
1743        v = array(v)  # normal vector to cleave plane, norm is cleave separation
1744        if sep!=None:
1745            v = abs(sep)*v/norm(v)
1746        #end if
1747        if norm(v)<tol:
1748            return
1749        #end if
1750        vn = array(v/norm(v))
1751        if sep!=None and sep<0:
1752            v = -v # preserve the normal direction for atom identification, but reverse the shift direction
1753        #end if
1754        self.recorner()  # want box contents to be static
1755        if self.has_axes():
1756            components = 0
1757            dim = self.dim
1758            axes = self.axes
1759            for i in range(dim):
1760                i2 = (i+1)%dim
1761                i3 = (i+2)%dim
1762                a2 = axes[i2]/norm(axes[i2])
1763                a3 = axes[i3]/norm(axes[i3])
1764                comp = abs(dot(a2,vn))+abs(dot(a3,vn))
1765                if comp < 1e-6:
1766                    components+=1
1767                    iaxis = i
1768                #end if
1769            #end for
1770            commensurate = components==1
1771            if not commensurate:
1772                self.error('cannot insert vacuum because cleave is incommensurate with the cell\n  cleave plane must be parallel to a cell face')
1773            #end if
1774            a = self.axes[iaxis]
1775            #self.axes[iaxis] = (1.+dot(v,a)/dot(a,a))*a
1776            self.axes[iaxis] = (1.+dot(v,v)/dot(v,a))*a
1777        #end if
1778        indices = []
1779        pos = self.pos
1780        for i in range(len(pos)):
1781            p = pos[i]
1782            comp = dot(p-c,vn)
1783            if comp>0 or abs(comp)<tol:
1784                pos[i] += v
1785                indices.append(i)
1786            #end if
1787        #end for
1788        if remove:
1789            self.remove(indices)
1790        #end if
1791    #end def cleave
1792
1793
1794    # test needed
1795    def translate(self,v):
1796        v = array(v)
1797        pos = self.pos
1798        for i in range(len(pos)):
1799            pos[i]+=v
1800        #end for
1801        self.center+=v
1802        if self.folded_structure!=None:
1803            self.folded_structure.translate(v)
1804        #end if
1805    #end def translate
1806
1807
1808    # test needed
1809    def slide(self,v,recenter=True):
1810        v = array(v)
1811        pos = self.pos
1812        for i in range(len(pos)):
1813            pos[i]+=v
1814        #end for
1815        if recenter:
1816            self.recenter()
1817        #end if
1818        if self.folded_structure!=None:
1819            self.folded_structure.slide(v,recenter)
1820        #end if
1821    #end def slide
1822
1823
1824    # test needed
1825    def zero_corner(self):
1826        corner = self.center-self.axes.sum(0)/2
1827        self.translate(-corner)
1828    #end def zero_corner
1829
1830
1831    # test needed
1832    def locate_simple(self,pos):
1833        pos = array(pos)
1834        if pos.shape==(self.dim,):
1835            pos = [pos]
1836        #end if
1837        nn = nearest_neighbors(1,self.pos,pos)
1838        return nn.ravel()
1839    #end def locate_simple
1840
1841
1842    # test needed
1843    def locate(self,identifiers,radii=None,exterior=False):
1844        indices = None
1845        if isinstance(identifiers,Structure):
1846            cell = identifiers
1847            indices = cell.inside(self.pos)
1848        elif isinstance(identifiers,ndarray) and identifiers.dtype==bool:
1849            indices = arange(len(self.pos))[identifiers]
1850        elif isinstance(identifiers,int):
1851            indices = [identifiers]
1852        elif len(identifiers)>0 and isinstance(identifiers[0],int):
1853            indices = identifiers
1854        elif isinstance(identifiers,str):
1855            atom = identifiers
1856            indices = []
1857            for i in range(len(self.elem)):
1858                if self.elem[i]==atom:
1859                    indices.append(i)
1860                #end if
1861            #end for
1862        elif len(identifiers)>0 and isinstance(identifiers[0],str):
1863            indices = []
1864            for atom in identifiers:
1865                for i in range(len(self.elem)):
1866                    if self.elem[i]==atom:
1867                        indices.append(i)
1868                    #end if
1869                #end for
1870            #end for
1871        #end if
1872        if radii is not None or indices is None:
1873            if indices is None:
1874                pos = identifiers
1875            else:
1876                pos = self.pos[indices]
1877            #end if
1878            if isinstance(radii,float) or isinstance(radii,int):
1879                radii = len(pos)*[radii]
1880            elif radii is not None and len(radii)!=len(pos):
1881                self.error('lengths of input radii and positions do not match\n  len(radii)={0}\n  len(pos)={1}'.format(len(radii),len(pos)))
1882            #end if
1883            dtable = self.min_image_distances(pos)
1884            indices = []
1885            if radii is None:
1886                for i in range(len(pos)):
1887                    indices.append(dtable[i].argmin())
1888                #end for
1889            else:
1890                ipos = arange(len(self.pos))
1891                for i in range(len(pos)):
1892                    indices.extend(ipos[dtable[i]<radii[i]])
1893                #end for
1894            #end if
1895        #end if
1896        if exterior:
1897            indices = list(set(range(len(self.pos)))-set(indices))
1898        #end if
1899        return indices
1900    #end def locate
1901
1902
1903    def freeze(self,identifiers=None,radii=None,exterior=False,negate=False,directions='xyz'):
1904        if isinstance(identifiers,ndarray) and identifiers.shape==self.pos.shape and identifiers.dtype==bool:
1905            if negate:
1906                self.frozen = ~identifiers
1907            else:
1908                self.frozen = identifiers.copy()
1909            #end if
1910            return
1911        #end if
1912        if identifiers is None:
1913            indices = arange(len(self.pos),dtype=int)
1914        else:
1915            indices = self.locate(identifiers,radii,exterior)
1916        #end if
1917        if len(indices)==0:
1918            self.error('failed to select any atoms to freeze')
1919        #end if
1920        if isinstance(directions,str):
1921            d = empty((3,),dtype=bool)
1922            d[0] = 'x' in directions
1923            d[1] = 'y' in directions
1924            d[2] = 'z' in directions
1925            directions = len(indices)*[d]
1926        else:
1927            directions = array(directions,dtype=bool)
1928        #end if
1929        if self.frozen is None:
1930            self.frozen = zeros(self.pos.shape,dtype=bool)
1931        #end if
1932        frozen = self.frozen
1933        i=0
1934        if not negate:
1935            for index in indices:
1936                frozen[index] = directions[i]
1937                i+=1
1938            #end for
1939        else:
1940            for index in indices:
1941                frozen[index] = directions[i]==False
1942                i+=1
1943            #end for
1944        #end if
1945    #end def freeze
1946
1947
1948    def is_frozen(self):
1949        if self.frozen is None:
1950            return np.zeros((len(self.pos),),dtype=bool)
1951        else:
1952            return self.frozen.sum(1)>0
1953        #end if
1954    #end def is_frozen
1955
1956
1957    # test needed
1958    def magnetize(self,identifiers=None,magnetization='',**mags):
1959        magsin = None
1960        if isinstance(identifiers,obj):
1961            magsin = identifiers.copy()
1962        elif isinstance(magnetization,obj):
1963            magsin = magnetization.copy()
1964        #endif
1965        if magsin!=None:
1966            magsin.transfer_from(mags)
1967            mags = magsin
1968            identifiers = None
1969            magnetization = ''
1970        #end if
1971        for e,m in mags.items():
1972            if not e in self.elem:
1973                self.error('cannot magnetize non-existent element {0}'.format(e))
1974            elif m is not None or not isinstance(m,int):
1975                self.error('magnetizations provided must be either None or integer\n  you provided: {0}\n  full magnetization request provided:\n {1}'.format(m,mags))
1976            #end if
1977            self.mag[self.elem==e] = m
1978        #end for
1979        if identifiers is None and magnetization=='':
1980            return
1981        elif magnetization=='':
1982            magnetization = identifiers
1983            indices = list(range(len(self.elem)))
1984        else:
1985            indices = self.locate(identifiers)
1986        #end if
1987        if not isinstance(magnetization,(list,tuple,ndarray)):
1988            magnetization = [magnetization]
1989        #end if
1990        for m in magnetization:
1991            if m is not None or not isinstance(m,int):
1992                self.error('magnetizations provided must be either None or integer\n  you provided: {0}\n  full magnetization list provided: {1}'.format(m,magnetization))
1993            #end if
1994        #end for
1995        if len(magnetization)==1:
1996            m = magnetization[0]
1997            for i in indices:
1998                self.mag[i] = m
1999            #end for
2000        elif len(magnetization)==len(indices):
2001            for i in range(len(indices)):
2002                self.mag[indices[i]] = magnetization[i]
2003            #end for
2004        else:
2005            self.error('magnetization list and list selected atoms differ in length\n  length of magnetization list: {0}\n  number of atoms selected: {1}\n  magnetization list: {2}\n  atom indices selected: {3}\n  atoms selected: {4}'.format(len(magnetization),len(indices),magnetization,indices,self.elem[indices]))
2006        #end if
2007    #end def magnetize
2008
2009
2010    def is_magnetic(self,tol=1e-8):
2011        magnetic = False
2012        if self.mag is not None:
2013            for m in self.mag:
2014                if m is not None and abs(m)>tol:
2015                    magnetic = True
2016                    break
2017                #end if
2018            #end for
2019        #end if
2020        return magnetic
2021    #end def is_magnetic
2022
2023
2024    # test needed
2025    def carve(self,identifiers):
2026        indices = self.locate(identifiers)
2027        if isinstance(identifiers,Structure):
2028            sub = identifiers
2029            sub.elem = self.elem[indices].copy()
2030            sub.pos  = self.pos[indices].copy()
2031        else:
2032            sub = self.copy()
2033            sub.elem = self.elem[indices]
2034            sub.pos  = self.pos[indices]
2035        #end if
2036        sub.host_indices = array(indices)
2037        return sub
2038    #end def carve
2039
2040
2041    # test needed
2042    def remove(self,identifiers):
2043        indices = self.locate(identifiers)
2044        keep = list(set(range(len(self.pos)))-set(indices))
2045        erem = self.elem[indices]
2046        prem = self.pos[indices]
2047        self.elem = self.elem[keep]
2048        self.pos  = self.pos[keep]
2049        self.remove_folded_structure()
2050        return erem,prem
2051    #end def remove
2052
2053
2054    # test needed
2055    def replace(self,identifiers,elem=None,pos=None,radii=None,exterior=False):
2056        indices = self.locate(identifiers,radii,exterior)
2057        if isinstance(elem,Structure):
2058            cell = elem
2059            elem = cell.elem
2060            pos  = cell.pos
2061        elif elem==None:
2062            elem = self.elem
2063        #end if
2064        indices=array(indices)
2065        elem=array(elem,dtype=object)
2066        pos =array(pos)
2067        nrem = len(indices)
2068        nadd = len(pos)
2069        if nadd<nrem:
2070            ar = array(list(range(0,nadd)))
2071            rr = array(list(range(nadd,nrem)))
2072            self.elem[indices[ar]] = elem[:]
2073            self.pos[indices[ar]]  = pos[:]
2074            self.remove(indices[rr])
2075        elif nadd>nrem:
2076            ar = array(list(range(0,nrem)))
2077            er = array(list(range(nrem,nadd)))
2078            self.elem[indices[ar]] = elem[ar]
2079            self.pos[indices[ar]]  = pos[ar]
2080            ii = indices[ar[-1]]
2081            self.set_elem( list(self.elem[0:ii])+list(elem[er])+list(self.elem[ii:]) )
2082            self.pos = array( list(self.pos[0:ii])+list(pos[er])+list(self.pos[ii:]) )
2083        else:
2084            self.elem[indices] = elem[:]
2085            self.pos[indices]  = pos[:]
2086        #end if
2087        self.remove_folded_structure()
2088    #end def replace
2089
2090
2091    # test needed
2092    def replace_nearest(self,elem,pos=None):
2093        if isinstance(elem,Structure):
2094            cell = elem
2095            elem = cell.elem
2096            pos  = cell.pos
2097        #end if
2098        nn = nearest_neighbors(1,self.pos,pos)
2099        np = len(pos)
2100        nps= len(self.pos)
2101        d = empty((np,))
2102        ip = array(list(range(np)))
2103        ips= nn.ravel()
2104        for i in ip:
2105            j = ips[i]
2106            d[i]=sqrt(((pos[i]-self.pos[j])**2).sum())
2107        #end for
2108        order = d.argsort()
2109        ip = ip[order]
2110        ips=ips[order]
2111        replacable = empty((nps,))
2112        replacable[:] = False
2113        replacable[ips]=True
2114        insert = []
2115        last_replaced=nps-1
2116        for n in range(np):
2117            i = ip[n]
2118            j = ips[n]
2119            if replacable[j]:
2120                self.pos[j] = pos[i]
2121                self.elem[j]=elem[i]
2122                replacable[j]=False
2123                last_replaced = j
2124            else:
2125                insert.append(i)
2126            #end if
2127        #end for
2128        insert=array(insert)
2129        ii = last_replaced
2130        if len(insert)>0:
2131            self.set_elem( list(self.elem[0:ii])+list(elem[insert])+list(self.elem[ii:]) )
2132            self.pos = array( list(self.pos[0:ii])+list(pos[insert])+list(self.pos[ii:]) )
2133        #end if
2134        self.remove_folded_structure()
2135    #end def replace_nearest
2136
2137
2138    # test needed
2139    def point_defect(self,identifiers=None,elem=None,dr=None):
2140        if isinstance(elem,str):
2141            elem = [elem]
2142            if dr!=None:
2143                dr = [dr]
2144            #end if
2145        #end if
2146        if not 'point_defects' in self:
2147            self.point_defects = obj()
2148        #end if
2149        point_defects = self.point_defects
2150        ncenters = len(point_defects)
2151        if identifiers is None:
2152            index = ncenters
2153            if index>=len(self.pos):
2154                self.error('attempted to add a point defect at index {0}, which does not exist\n  for reference there are {1} atoms in the structure'.format(index,len(self.pos)))
2155            #end if
2156        else:
2157            indices = self.locate(identifiers)
2158            if len(indices)>1:
2159                self.error('{0} atoms were located by identifiers provided\n  a point defect replaces only a single atom\n  atom indices located: {1}'.format(len(indices),indices))
2160            #end if
2161            index = indices[0]
2162        #end if
2163        if elem is None:
2164            self.error('must supply substitutional elements comprising the point defect\n  expected a list or similar for input argument elem')
2165        elif len(elem)>1 and dr is None:
2166            self.error('must supply displacements (dr) since many atoms comprise the point defect')
2167        elif dr!=None and len(elem)!=len(dr):
2168            self.error('elem and dr must have the same length')
2169        #end if
2170        r = self.pos[index]
2171        e = self.elem[index]
2172        elem = array(elem)
2173        pos = zeros((len(elem),len(r)))
2174        if dr is None:
2175            rc = r
2176            for i in range(len(elem)):
2177                pos[i] = r
2178            #end for
2179        else:
2180            nrc = 0
2181            rc  = 0*r
2182            dr = array(dr)
2183            for i in range(len(elem)):
2184                pos[i] = r + dr[i]
2185                if norm(dr[i])>1e-5:
2186                    rc+=dr[i]
2187                    nrc+=1
2188                #end if
2189            #end for
2190            if nrc==0:
2191                rc = r
2192            else:
2193                rc = r + rc/nrc
2194            #end if
2195        #end if
2196        point_defect = obj(
2197            center = rc,
2198            elem_replaced = e,
2199            elem = elem,
2200            pos  = pos
2201            )
2202        point_defects.append(point_defect)
2203        elist = list(self.elem)
2204        plist = list(self.pos)
2205        if len(elem)==0 or len(elem)==1 and elem[0]=='':
2206            elist.pop(index)
2207            plist.pop(index)
2208        else:
2209            elist[index] = elem[0]
2210            plist[index] = pos[0]
2211            for i in range(1,len(elem)):
2212                elist.append(elem[i])
2213                plist.append(pos[i])
2214            #end for
2215        #end if
2216        self.set_elem(elist)
2217        self.pos  = array(plist)
2218        self.remove_folded_structure()
2219    #end def point_defect
2220
2221
2222    # test needed
2223    def species(self,symbol=False):
2224        if not symbol:
2225            return set(self.elem)
2226        else:
2227            species_labels = set(self.elem)
2228            species = set()
2229            for e in species_labels:
2230                is_elem,symbol = is_element(e,symbol=True)
2231                species.add(symbol)
2232            #end for
2233            return species_labels,species
2234        #end if
2235    #end def species
2236
2237
2238    # test needed
2239    def ordered_species(self,symbol=False):
2240        speclab_set    = set()
2241        species_labels = []
2242        if not symbol:
2243            for e in self.elem:
2244                if e not in speclab_set:
2245                    speclab_set.add(e)
2246                    species_labels.append(e)
2247                #end if
2248            #end for
2249            return species_labels
2250        else:
2251            species  = []
2252            spec_set = set()
2253            for e in self.elem:
2254                is_elem,symbol = is_element(e,symbol=True)
2255                if e not in speclab_set:
2256                    speclab_set.add(e)
2257                    species_labels.append(e)
2258                #end if
2259                if symbol not in spec_set:
2260                    spec_set.add(symbol)
2261                    species.append(symbol)
2262                #end if
2263            #end for
2264            return species_labels,species
2265        #end if
2266    #end def ordered_species
2267
2268
2269    # test needed
2270    def order_by_species(self,folded=False):
2271        species        = []
2272        species_counts = []
2273        elem_indices   = []
2274
2275        spec_set = set()
2276        for i in range(len(self.elem)):
2277            e = self.elem[i]
2278            if not e in spec_set:
2279                spec_set.add(e)
2280                species.append(e)
2281                species_counts.append(0)
2282                elem_indices.append([])
2283            #end if
2284            sindex = species.index(e)
2285            species_counts[sindex] += 1
2286            elem_indices[sindex].append(i)
2287        #end for
2288
2289        elem_order = []
2290        for elem_inds in elem_indices:
2291            elem_order.extend(elem_inds)
2292        #end for
2293        self.reorder(elem_order)
2294
2295        if folded and self.folded_structure!=None:
2296            self.folded_structure.order_by_species(folded)
2297        #end if
2298
2299        return species,species_counts
2300    #end def order_by_species
2301
2302
2303    # test needed
2304    def reorder(self,order):
2305        order = array(order)
2306        self.elem = self.elem[order]
2307        self.pos  = self.pos[order]
2308    #end def reorder
2309
2310
2311    # test needed
2312    # find layers parallel to a particular cell face
2313    #   layers are found by scanning a window of width dtol along the axis and counting
2314    #     the number of atoms within the window.  window position w/ max number of atoms
2315    #     defines the layer.  layer distance is the window position.
2316    #   the resolution of the scan is determined by dbin
2317    #   (axis length)/dbin is the number of fine bins
2318    #   dtol/dbin is the number of fine bins in the moving (boxcar) window
2319    #   plot=True: plot the layer histogram (fine hist and moving average)
2320    #   composition=True: return the composition of each layer (count of each species)
2321    # returns an object containing indices of atoms in each layer by distance along axis
2322    #   example: structure w/ 3 layers of 4 atoms each at distances 3.0, 6.0, and 9.0 Angs.
2323    #   layers
2324    #     3.0 = [ 0, 1, 2, 3 ]
2325    #     6.0 = [ 4, 5, 6, 7 ]
2326    #     9.0 = [ 8, 9,10,11 ]
2327    def layers(self,axis=0,dtol=0.03,dbin=0.01,plot=False,composition=False):
2328        nbox = int(dtol/dbin)
2329        if nbox%2==0:
2330            nbox+=1
2331        #end if
2332        nwind = (nbox-1)//2
2333        s = self.copy()
2334        s.recenter()
2335        vaxis = s.axes[axis]
2336        daxis = norm(vaxis)
2337        naxis = vaxis/daxis
2338        dbin  = dtol/nbox
2339        nbins = int(ceil(daxis/dbin))
2340        dbin  = daxis/nbins
2341        dbins = daxis*(arange(nbins)+.5)/nbins
2342        dists = daxis*s.pos_unit()[:,axis]
2343        hist  = zeros((nbins,),dtype=int)
2344        boxhist = zeros((nbins,),dtype=int)
2345        ihist = obj()
2346        iboxhist = obj()
2347        index = 0
2348        for d in dists:
2349            ibin = int(floor(d/dbin))
2350            hist[ibin]+=1
2351            if not ibin in ihist:
2352                ihist[ibin] = []
2353            #end if
2354            ihist[ibin].append(index)
2355            index+=1
2356        #end for
2357        for ib in range(nbins):
2358            for i in range(ib-nwind,ib+nwind+1):
2359                n = hist[i%nbins]
2360                if n>0:
2361                    boxhist[ib]+=n
2362                    if not ib in iboxhist:
2363                        iboxhist[ib] = []
2364                    #end if
2365                    iboxhist[ib].extend(ihist[i%nbins])
2366                #end if
2367            #end for
2368        #end for
2369        peaks = []
2370        nlast=0
2371        for ib in range(nbins):
2372            n = boxhist[ib]
2373            if nlast==0 and n>0:
2374                pcur = []
2375                peaks.append(pcur)
2376            #end if
2377            if n>0:
2378                pcur.append(ib)
2379            #end if
2380            nlast = n
2381        #end for
2382        if boxhist[0]>0 and boxhist[-1]>0:
2383            peaks[0].extend(peaks[-1])
2384            peaks.pop()
2385        #end if
2386        layers = obj()
2387        ip = []
2388        for peak in peaks:
2389            ib = peak[boxhist[peak].argmax()]
2390            ip.append(ib)
2391            pindices = iboxhist[ib]
2392            ldist = dbins[ib] # distance is along an axis vector
2393            faxis = self.face_vectors()[axis]
2394            ldist = dot(ldist*naxis,faxis/norm(faxis))
2395            layers[ldist] = array(pindices,dtype=int)
2396        #end for
2397        if plot:
2398            plt.plot(dbins,boxhist,'b.-',label='boxcar histogram')
2399            plt.plot(dbins,hist,'r.-',label='fine histogram')
2400            plt.plot(dbins[ip],boxhist[ip],'rv',markersize=20)
2401            plt.show()
2402            plt.legend()
2403        #end if
2404        if not composition:
2405            return layers
2406        else:
2407            return layers,self.layer_composition(layers)
2408        #end if
2409    #end def layers
2410
2411
2412    # test needed
2413    def layer_composition(self,layers):
2414        lcomp = obj()
2415        for d,ind in layers.items():
2416            comp = obj()
2417            elem = self.elem[ind]
2418            for e in elem:
2419                if e not in comp:
2420                    comp[e] = 1
2421                else:
2422                    comp[e] += 1
2423                #end if
2424            #end for
2425            lcomp[d]=comp
2426        #end for
2427        return lcomp
2428    #end def layer_composition
2429
2430
2431    # test needed
2432    def shells(self,identifiers,radii=None,exterior=False,cumshells=False,distances=False,dtol=1e-6):
2433        # get indices for 'core' and 'bulk'
2434        #   core is selected by identifiers, forms core for shells to be built around
2435        #   bulk is all atoms except for core
2436        if identifiers=='point_defects':
2437            if not 'point_defects' in self:
2438                self.error('requested shells around point defects, but structure has no point defects')
2439            #end if
2440            core = []
2441            for pd in self.point_defects:
2442                core.append(pd.center)
2443            #end for
2444            core = array(core)
2445            bulk_ind = self.locate(core,radii=dtol,exterior=True)
2446            core_ind = self.locate(bulk_ind,exterior=True)
2447            bulk = self.pos[bulk_ind]
2448        else:
2449            core_ind = self.locate(identifiers,radii,exterior)
2450            bulk_ind = self.locate(core_ind,exterior=True)
2451            core = self.pos[core_ind]
2452            bulk = self.pos[bulk_ind]
2453        #end if
2454        bulk_ind = array(bulk_ind,dtype=int)
2455        # build distance table between bulk and core
2456        dtable = self.distance_table(bulk,core)
2457        # find shortest distance for each bulk atom to any core atom and order by distance
2458        dist   = dtable.min(1)
2459        ind    = arange(len(bulk))
2460        order  = dist.argsort()
2461        dist   = dist[order]
2462        ind    = bulk_ind[ind[order]]
2463        # find shells around the core
2464        #   the closest atom to the core starts the first shell and defines a shell distance
2465        #   other atoms are in the shell if within dtol distance of the first atom
2466        #   otherwise a new shell is started
2467        ns = 0
2468        ds = -1
2469        shells = obj()
2470        shells[ns] = list(core_ind)  # first shell is all core atoms
2471        dshells = [0.]
2472        for n in range(len(dist)):
2473            if abs(dist[n]-ds)>dtol:
2474                shell = [ind[n]]   # new shell starts with single atom
2475                ns+=1
2476                shells[ns] = shell
2477                ds = dist[n]       # shell distance is distance of this atom from core
2478                dshells.append(ds)
2479            else:
2480                shell.append(ind[n])
2481            #end if
2482        #end for
2483        dshells = array(dshells,dtype=float)
2484        results = [shells]
2485        if cumshells:
2486            # assemble cumulative shells, ie cumshell[ns] = sum(shells[n],n=0 to ns)
2487            cumshells = obj()
2488            cumshells[0] = list(shells[0])
2489            for ns in range(1,len(shells)):
2490                cumshells[ns] = cumshells[ns-1]+shells[ns]
2491            #end for
2492            for ns,cshell in cumshells.items():
2493                cumshells[ns] = array(cshell,dtype=int)
2494            #end for
2495            results.append(cumshells)
2496        #end if
2497        for ns,shell in shells.items():
2498            shells[ns] = array(shell,dtype=int)
2499        if distances:
2500            results.append(dshells)
2501        #end if
2502        if len(results)==1:
2503            results = results[0]
2504        #end if
2505        return results
2506    #end def shells
2507
2508
2509    # test needed
2510    # find connected sets of atoms.
2511    #   indices is a list of atomic indices to consider (self.pos[indices] are their positions)
2512    #   atoms are considered connected if they are within rmax of each other
2513    #   order sets the maximum number of atoms in any connected graph
2514    #     order = 1 returns single atoms
2515    #     order = 2 returns dimers + order=1 results
2516    #     order = 3 returns trimers + order=2 results
2517    #     ...
2518    #   degree is explained w/ an example: a triangle of atoms 0,1,2  and a line of atoms 3,4,5 (3 & 5 are not neighbors)
2519    #     degree = False : returned object (cgraphs) has following structure:
2520    #       cgraphs[1] = [ (0,), (1,), (2,), (3,), (4,), (5,) ]  # first  order connected graphs (atoms)
2521    #       cgraphs[2] = [ (0,1), (0,2), (1,2), (3,4), (4,5) ]   # second order connected graphs (dimers)
2522    #       cgraphs[3] = [ (0,1,2), (3,4,5) ]                    # third  order connected graphs (trimers)
2523    #     degree = True : returned object (cgraphs) has following structure:
2524    #       cgraphs
2525    #         1      # first  order connected graphs (atoms)
2526    #           0    #   sum of vertex degrees is 0 (a single atom has no neighbors)
2527    #             (0,) = [ (0,), (1,), (2,), (3,), (4,), (5,) ]   # graphs with vertex degree (0,)
2528    #         2      # second order connected graphs (dimers)
2529    #           2    #   sum of vertex degrees is 2 (each atom is connected to 1 neighbor)
2530    #             (1,1) = [ (0,1), (0,2), (1,2), (3,4), (4,5) ]   # graphs with vertex degree (1,1)
2531    #         3      # third  order connected graphs (trimers)
2532    #           4    #   sum of vertex degrees is 4 (2 atoms have 1 neighbor and 1 atom has 2)
2533    #             (1,1,2) = [ (3,5,4) ]
2534    #           6    #   sum of vertex degrees is 6 (each atom is connected to 2 others)
2535    #             (2,2,2) = [ (0,1,2) ]           # graphs with vertex degree (2,2,2)
2536    def connected_graphs(self,order,indices=None,rmax=None,nmax=None,voronoi=False,degree=False,site_maps=False,**spec_max):
2537        if indices is None:
2538            indices = arange(len(self.pos),dtype=int)
2539            pos = self.pos
2540        else:
2541            pos = self.pos[indices]
2542        #end if
2543        np = len(indices)
2544        neigh_table = []
2545        actual_indices = None
2546        if voronoi:
2547            actual_indices = True
2548            neighbors = self.voronoi_neighbors(indices,restrict=True,distance_ordered=False)
2549            for nilist in neighbors:
2550                neigh_table.append(nilist)
2551            #end for
2552        else:
2553            actual_indices = False
2554            elem = set(self.elem[indices])
2555            spec = set(spec_max.keys())
2556            if spec==elem or rmax!=None:
2557                None
2558            elif spec<elem and nmax!=None:
2559                for e in elem:
2560                    if e not in spec:
2561                        spec_max[e] = nmax
2562                    #end if
2563                #end for
2564            #end if
2565            # get neighbor table for subset of atoms specified by indices
2566            nt,dt = self.neighbor_table(pos,pos,distances=True)
2567            # determine how many neighbors to consider based on rmax (all are neighbors if rmax is None)
2568            nneigh = zeros((np,),dtype=int)
2569            if len(spec_max)>0:
2570                for n in range(np):
2571                    nneigh[n] = min(spec_max[self.elem[n]],len(nt[n]))
2572                #end for
2573            elif rmax is None:
2574                nneigh[:] = np
2575            else:
2576                nneigh = (dt<rmax).sum(1)
2577            #end if
2578            for i in range(np):
2579                neigh_table.append(nt[i,1:nneigh[i]])
2580            #end for
2581            del nt,dt,nneigh,elem,spec,rmax
2582        #end if
2583        neigh_table = array(neigh_table,dtype=int)
2584        # record which atoms are neighbors to each other
2585        neigh_pairs = set()
2586        if actual_indices:
2587            for i in range(np):
2588                for ni in neigh_table[i]:
2589                    neigh_pairs.add((i,ni))
2590                    neigh_pairs.add((ni,i))
2591                #end for
2592            #end for
2593        else:
2594            for i in range(np):
2595                for ni in neigh_table[i]:
2596                    ii = indices[i]
2597                    jj = indices[ni]
2598                    neigh_pairs.add((ii,jj))
2599                    neigh_pairs.add((jj,ii))
2600                #end for
2601            #end for
2602        #end if
2603        # find the connected graphs
2604        graphs_found = set()  # map to contain tuples of connected atom's indices
2605        cgraphs = obj()
2606        for o in range(1,order+1): # organize by order
2607            cgraphs[o] = []
2608        #end for
2609        if order>0:
2610            cg = cgraphs[1]
2611            for i in range(np):  # list of single atoms
2612                gi = (i,)              # graph indices
2613                cg.append(gi)          # add graph to graph list of order 1
2614                graphs_found.add(gi)   # add graph to set of all graphs
2615            #end for
2616            for o in range(2,order+1): # graphs of order o are found by adding all
2617                cglast = cgraphs[o-1]  # possible single neighbors to each graph of order o-1
2618                cg     = cgraphs[o]
2619                for gilast in cglast:    # all graphs of order o-1
2620                    for i in gilast:       # all indices in each graph of order o-1
2621                        for ni in neigh_table[i]: # neighbors of selected atom in o-1 graph
2622                            gi = tuple(sorted(gilast+(ni,))) # new graph with neighbor added
2623                            if gi not in graphs_found and len(set(gi))==o: # add it if it is new and really is order o
2624                                graphs_found.add(gi)  # add graph to set of all graphs
2625                                cg.append(gi)         # add graph to graph list of order o
2626                            #end if
2627                        #end for
2628                    #end for
2629                #end for
2630            #end for
2631        #end if
2632        if actual_indices:
2633            for o,cg in cgraphs.items():
2634                cgraphs[o] = array(cg,dtype=int)
2635            #end for
2636        else:
2637            # map indices back to actual atomic indices
2638            for o,cg in cgraphs.items():
2639                cgmap = []
2640                for gi in cg:
2641                    #gi = array(gi)
2642                    gimap = tuple(sorted(indices[array(gi)]))
2643                    cgmap.append(gimap)
2644                #end for
2645                cgraphs[o] = array(sorted(cgmap),dtype=int)
2646            #end for
2647        #end if
2648        # reorganize the graph listing by cluster and vertex degree, if desired
2649        if degree:
2650            #degree_map = obj()
2651            cgraphs_deg = obj()
2652            for o,cg in cgraphs.items():
2653                dgo = obj()
2654                cgraphs_deg[o] = dgo
2655                for gi in cg:
2656                    di = zeros((o,),dtype=int)
2657                    for m in range(o):
2658                        i = gi[m]
2659                        for n in range(m+1,o):
2660                            j = gi[n]
2661                            if (i,j) in neigh_pairs:
2662                                di[m]+=1
2663                                di[n]+=1
2664                            #end if
2665                        #end for
2666                    #end for
2667                    d = int(di.sum())
2668                    dorder = di.argsort()
2669                    di = tuple(di[dorder])
2670                    gi = tuple(array(gi)[dorder])
2671                    if not d in dgo:
2672                        dgo[d]=obj()
2673                    #end if
2674                    dgd = dgo[d]
2675                    if not di in dgd:
2676                        dgd[di] = []
2677                    #end if
2678                    dgd[di].append(gi)
2679                    #degree_map[gi] = d,di
2680                #end for
2681                for dgd in dgo:
2682                    for di,dgi in dgd.items():
2683                        dgd[di]=array(sorted(dgi),dtype=int)
2684                    #end for
2685                #end for
2686            #end for
2687            cgraphs = cgraphs_deg
2688        #end if
2689
2690        if not site_maps:
2691            return cgraphs
2692        else:
2693            cmaps = obj()
2694            if not degree:
2695                for order,og in cgraphs.items():
2696                    cmap = obj()
2697                    for slist in og:
2698                        for s in slist:
2699                            if not s in cmap:
2700                                cmap[s] = obj()
2701                            #end if
2702                            cmap[s].append(slist)
2703                        #end for
2704                    #end for
2705                    cmaps[order] = cmap
2706                #end for
2707            else:
2708                for order,og in cgraphs.items():
2709                    for total_degree,tg in og.items():
2710                        for local_degree,lg in tg.items():
2711                            cmap = obj()
2712                            for slist in lg:
2713                                n=0
2714                                for s in slist:
2715                                    d = local_degree[n]
2716                                    if not s in cmap:
2717                                        cmap[s] = obj()
2718                                    #end if
2719                                    if not d in cmap[s]:
2720                                        cmap[s][d] = obj()
2721                                    #end if
2722                                    cmap[s][d].append(slist)
2723                                    n+=1
2724                                #end for
2725                            #end for
2726                            cmaps.add_attribute_path((order,total_degree,local_degree),cmap)
2727                        #end for
2728                    #end for
2729                #end for
2730            #end if
2731            return cgraphs,cmaps
2732        #end if
2733    #end def connected_graphs
2734
2735
2736    # test needed
2737    # returns connected graphs that are rings up to the requested order
2738    #   rings are constructed by pairing lines that share endpoints
2739    #   all vertices of a ring have degree two
2740    def ring_graphs(self,order,**kwargs):
2741        # get all half order connected graphs
2742        line_order = order//2+order%2+1
2743        cgraphs = self.connected_graphs(line_order,degree=True,site_maps=False,**kwargs)
2744        # collect half order graphs that are lines
2745        lgraphs = obj()
2746        for o in range(2,line_order+1):
2747            total_degree  = 2*o-2
2748            vertex_degree = tuple([1,1]+(o-2)*[2])
2749            lg = None
2750            if o in cgraphs:
2751                cg = cgraphs[o]
2752                if total_degree in cg:
2753                    dg = cg[total_degree]
2754                    if vertex_degree in dg:
2755                        lg = dg[vertex_degree]
2756                    #end if
2757                #end if
2758            #end if
2759            if lg!=None:
2760                lg_end = obj()
2761                for gi in lg:
2762                    end_key = tuple(sorted(gi[0:2])) # end points
2763                    if end_key not in lg_end:
2764                        lg_end[end_key] = []
2765                    #end if
2766                    lg_end[end_key].append(tuple(gi))
2767                #end for
2768                lgraphs[o] = lg_end
2769            #end if
2770        #end for
2771        # contruct rings from lines that share endpoints
2772        rgraphs = obj()
2773        for o in range(3,order+1):
2774            o1 = o/2+1    # split half order for odd, same for even,
2775            o2 = o1+o%2
2776            lg1 = lgraphs.get_optional(o1,None) # sets of half order lines
2777            lg2 = lgraphs.get_optional(o2,None)
2778            if lg1!=None and lg2!=None:
2779                rg = []
2780                rset = set()
2781                for end_key,llist1 in lg1.items(): # list of lines sharing endpoints
2782                    if end_key in lg2:
2783                        llist2 = lg2[end_key]          # second list of lines sharing endpoints
2784                        for gi1 in llist1:             # combine line pairs into rings
2785                            for gi2 in llist2:
2786                                ri = tuple(sorted(set(gi1+gi2[2:]))) # ring indices
2787                                if ri not in rset and len(ri)==o:    # exclude repeated lines or rings
2788                                    rg.append(ri)
2789                                    rset.add(ri)
2790                                #end if
2791                            #end for
2792                        #end for
2793                    #end if
2794                #end for
2795                rgraphs[o] = array(sorted(rg),dtype=int)
2796            #end if
2797        #end for
2798        return rgraphs
2799    #end def ring_graphs
2800
2801
2802    # test needed
2803    # find the centroid of a set of points/atoms in min image convention
2804    def min_image_centroid(self,points=None,indices=None):
2805        if indices!=None:
2806            points = self.pos[indices]
2807        elif points is None:
2808            self.error('points or images must be provided to min_image_centroid')
2809        #end if
2810        p     = array(points,dtype=float)
2811        cprev = p[0]+1e99
2812        c     = p[0]
2813        while(norm(c-cprev)>1e-8):
2814            p = self.cell_image(p,center=c)
2815            cprev = c
2816            c = p.mean(axis=0)
2817        #end def min_image_centroid
2818        return c
2819    #end def min_image_centroid
2820
2821
2822    # test needed
2823    # find min image centroids of multiple sets of points/atoms
2824    def min_image_centroids(self,points=None,indices=None):
2825        cents = []
2826        if points!=None:
2827            for p in points:
2828                cents.append(self.min_image_centroid(p))
2829            #end for
2830        elif indices!=None:
2831            for ind in indices:
2832                cents.append(self.min_image_centroid(indices=ind))
2833            #end for
2834        else:
2835            self.error('points or images must be provided to min_image_centroid')
2836        #end if
2837        return array(cents,dtype=float)
2838    #end def min_image_centroids
2839
2840
2841    def min_image_vectors(self,points=None,points2=None,axes=None,pairs=True):
2842        if points is None:
2843            points = self.pos
2844        elif isinstance(points,Structure):
2845            points = points.pos
2846        #end if
2847        if axes is None:
2848            axes  = self.axes
2849        #end if
2850        axinv = inv(axes)
2851        points = array(points)
2852        single = points.shape==(self.dim,)
2853        if single:
2854            points = [points]
2855        #end if
2856        if points2 is None:
2857            points2 = self.pos
2858        elif isinstance(points2,Structure):
2859            points2 = points2.pos
2860        elif points2.shape==(self.dim,):
2861            points2 = [points2]
2862        #end if
2863        npoints  = len(points)
2864        npoints2 = len(points2)
2865        if pairs:
2866            vtable = empty((npoints,npoints2,self.dim),dtype=float)
2867            i=-1
2868            for p in points:
2869                i+=1
2870                j=-1
2871                for pp in points2:
2872                    j+=1
2873                    u = dot(pp-p,axinv)
2874                    vtable[i,j] = dot(u-floor(u+.5),axes)
2875                #end for
2876            #end for
2877            result = vtable
2878        else:
2879            if npoints!=npoints2:
2880                self.error('cannot create one to one minimum image vectors, point sets differ in length\n  npoints1 = {0}\n  npoints2 = {1}'.format(npoints,npoints2))
2881            #end if
2882            vectors = empty((npoints,self.dim),dtype=float)
2883            n = 0
2884            for p in points:
2885                pp = points2[n]
2886                u = dot(pp-p,axinv)
2887                vectors[n] = dot(u-floor(u+.5),axes)
2888                n+=1
2889            #end for
2890            result = vectors
2891        #end if
2892
2893        return result
2894    #end def min_image_vectors
2895
2896
2897    def min_image_distances(self,points=None,points2=None,axes=None,vectors=False,pairs=True):
2898        vtable = self.min_image_vectors(points,points2,axes,pairs=pairs)
2899        rdim = len(vtable.shape)-1
2900        dtable = sqrt((vtable**2).sum(rdim))
2901        if not vectors:
2902            return dtable
2903        else:
2904            return dtable,vtable
2905        #end if
2906    #end def min_image_distances
2907
2908
2909    def distance_table(self,points=None,points2=None,axes=None,vectors=False):
2910        return self.min_image_distances(points,points2,axes,vectors)
2911    #end def distance_table
2912
2913
2914    def vector_table(self,points=None,points2=None,axes=None):
2915        return self.min_image_vectors(points,points2,axes)
2916    #end def vector_table
2917
2918
2919    def neighbor_table(self,points=None,points2=None,axes=None,distances=False,vectors=False):
2920        dtable,vtable = self.min_image_distances(points,points2,axes,vectors=True)
2921        ntable = empty(dtable.shape,dtype=int)
2922        for i in range(len(dtable)):
2923            ntable[i] = dtable[i].argsort()
2924        #end for
2925        results = [ntable]
2926        if distances:
2927            for i in range(len(dtable)):
2928                dtable[i] = dtable[i][ntable[i]]
2929            #end for
2930            results.append(dtable)
2931        #end if
2932        if vectors:
2933            for i in range(len(vtable)):
2934                vtable[i] = vtable[i][ntable[i]]
2935            #end for
2936            results.append(vtable)
2937        #end if
2938        if len(results)==1:
2939            results = results[0]
2940        #end if
2941        return results
2942    #end def neighbor_table
2943
2944
2945    # test needed
2946    def min_image_norms(self,points,norms):
2947        if isinstance(norms,int) or isinstance(norms,float):
2948            norms = [norms]
2949        #end if
2950        vtable = self.min_image_vectors(points)
2951        rdim = len(vtable.shape)-1
2952        nout = []
2953        for p in norms:
2954            nout.append( ((abs(vtable)**p).sum(rdim))**(1./p) )
2955        #end for
2956        if len(norms)==1:
2957            nout = nout[0]
2958        #end if
2959        return nout
2960    #end def min_image_norms
2961
2962
    # test needed
    # get all neighbors according to contacting voronoi polyhedra in PBC
    def voronoi_neighbors(self,indices=None,restrict=False,distance_ordered=True):
        """
        Neighbor lists from contacting Voronoi polyhedra, including
        periodic images.

        A copy of the cell is tiled 3x3x3, Voronoi neighbor pairs are found
        in the large cell, and pairs touching the central copy are mapped
        back to indices in this cell.

        Parameters
        ----------
        indices : array_like, optional
            Atom indices of interest; defaults to all atoms.  Only used to
            filter pairs when ``restrict`` is True.
        restrict : bool
            If True, keep only pairs whose atoms are both in ``indices``.
        distance_ordered : bool
            If True, each neighbor list is sorted by minimum image distance
            (``distance_table``); otherwise its order is arbitrary.

        Returns
        -------
        obj
            Maps atom index i to an int array of the indices of its Voronoi
            neighbors, with duplicates removed.
        """
        if indices is None:
            indices = arange(len(self.pos))
        #end if
        indices = set(indices)
        # make a new version of this (small cell)
        sn = self.copy()
        sn.recenter()
        # tile a large cell periodically
        # NOTE(review): the 3x3x3 tiling (d=3) assumes a three dimensional cell
        d = 3
        t = tuple(zeros((d,),dtype=int)+3)
        ss = sn.tile(t)
        ss.recenter(sn.center)
        # get nearest neighbor index pairs in the large cell
        # (a bare name inside a method resolves to the module-level
        #  voronoi_neighbors function, not to this method; presumably it
        #  returns an (npairs,2) array of index pairs)
        neigh_pairs = voronoi_neighbors(ss.pos)
        # create a mapping from large to small indices
        # assumes the tiled cell lists its atoms as 3**d consecutive copies
        # of this cell's atom list, so large index -> large index mod natoms
        large_to_small = 3**d*list(range(len(self.pos)))
        # find the neighbor pairs in the small cell
        neighbors = obj()
        # presumably the large-cell indices of the atoms of the original
        # (central) copy -- confirm against Structure.locate
        small_inds = set(ss.locate(sn.pos))
        for n in range(len(neigh_pairs)):
            i,j = neigh_pairs[n,:]
            if i in small_inds or j in small_inds: # pairs w/ at least one in cell image
                i = large_to_small[i]  # mapping to small cell indices
                j = large_to_small[j]
                if not restrict or (i in indices and j in indices): # restrict to orig index set
                    # record the pair symmetrically
                    if not i in neighbors:
                        neighbors[i] = [j]
                    else:
                        neighbors[i].append(j)
                    #end if
                    if not j in neighbors:
                        neighbors[j] = [i]
                    else:
                        neighbors[j].append(i)
                    #end if
                #end if
            #end if
        #end for
        # remove any duplicates and order by distance
        # (the same pair may be seen through several periodic images)
        if distance_ordered:
            dt = self.distance_table()
            for i,ni in neighbors.items():
                ni = array(list(set(ni)),dtype=int)
                di = dt[i,ni]
                order = di.argsort()
                neighbors[i] = ni[order]
            #end for
        else:  # just remove duplicates
            for i,ni in neighbors.items():
                neighbors[i] = array(list(set(ni)),dtype=int)
            #end for
        #end if
        return neighbors
    #end def voronoi_neighbors
3020
3021
3022    def voronoi_vectors(self,indices=None,restrict=None):
3023        ni = self.voronoi_neighbors(indices,restrict)
3024        vt = self.vector_table()
3025        vv = obj()
3026        for i,vi in ni.items():
3027            vv[i] = vt[i,vi]
3028        #end for
3029        return vv
3030    #end def voronoi_vectors
3031
3032
3033    def voronoi_distances(self,indices=None,restrict=False):
3034        vv = self.voronoi_vectors(indices,restrict)
3035        vd = obj()
3036        for i,vvi in vv.items():
3037            vd[i] = np.linalg.norm(vvi,axis=1)
3038        #end for
3039        return vd
3040    #end def voronoi_distances
3041
3042
3043    def voronoi_radii(self,indices=None,restrict=None):
3044        vd = self.voronoi_distances(indices,restrict)
3045        vr = obj()
3046        for i,vdi in vd.items():
3047            vr[i] = vdi.min()/2
3048        #end for
3049        return vr
3050    #end def voronoi_radii
3051
3052
3053    def voronoi_species_radii(self):
3054        vr = self.voronoi_radii()
3055        vsr = obj()
3056        for i,r in vr.items():
3057            e = self.elem[i]
3058            if e not in vsr:
3059                vsr[e] = r
3060            else:
3061                vsr[e] = min(vsr[e],r)
3062            #end if
3063        #end for
3064        return vsr
3065    #end def voronoi_species_radii
3066
3067
3068    # test needed
3069    # get nearest neighbors according to constrants (voronoi, max distance, coord. number)
3070    def nearest_neighbors(self,indices=None,rmax=None,nmax=None,restrict=False,voronoi=False,distances=False,**spec_max):
3071        if indices is None:
3072            indices = arange(len(self.pos))
3073        #end if
3074        elem = set(self.elem[indices])
3075        spec = set(spec_max.keys())
3076        if spec==elem or rmax!=None or voronoi:
3077            None
3078        elif spec<elem and nmax!=None:
3079            for e in elem:
3080                if e not in spec:
3081                    spec_max[e] = nmax
3082                #end if
3083            #end for
3084        else:
3085            self.error('must specify nmax for all species\n  species present: {0}\n  you only provided nmax for these species: {1}'.format(sorted(elem),sorted(spec)))
3086        #end if
3087        pos = self.pos[indices]
3088        if not restrict:
3089            pos2 = self.pos
3090        else:
3091            pos2 = pos
3092        #end if
3093        if voronoi:
3094            neighbors = self.voronoi_neighbors(indices=indices,restrict=restrict)
3095            dt = self.distance_table(pos,pos2)[:,1:]
3096        else:
3097            nt,dt = self.neighbor_table(pos,pos2,distances=True)
3098            dt=dt[:,1:]
3099            nt=nt[:,1:]
3100            neighbors = list(nt)
3101        #end if
3102        for i in range(len(indices)):
3103            neighbors[i] = indices[neighbors[i]]
3104        #end for
3105        dist = list(dt)
3106        if rmax is None:
3107            for i in range(len(indices)):
3108                nn = neighbors[i]
3109                dn = dist[i]
3110                e = self.elem[indices[i]]
3111                if e in spec_max:
3112                    smax = spec_max[e]
3113                    if len(nn)>smax:
3114                        neighbors[i] = nn[:smax]
3115                        dist[i]      = dn[:smax]
3116                    #end if
3117                #end if
3118            #end for
3119        else:
3120            for i in range(len(indices)):
3121                neighbors[i] = neighbors[i][dt[i]<rmax]
3122            #end for
3123        #end if
3124        if not distances:
3125            return neighbors
3126        else:
3127            return neighbors,dist
3128        #end if
3129    #end def nearest_neighbors
3130
3131
3132    # test needed
3133    # determine local chemical coordination limited by constraints
3134    def chemical_coordination(self,indices=None,nmax=None,rmax=None,restrict=False,voronoi=False,neighbors=False,distances=False,**spec_max):
3135        if indices is None:
3136            indices = arange(len(self.pos))
3137        #end if
3138        if not distances:
3139            neigh = self.nearest_neighbors(indices=indices,nmax=nmax,rmax=rmax,restrict=restrict,voronoi=voronoi,**spec_max)
3140        else:
3141            neigh,dist = self.nearest_neighbors(indices=indices,nmax=nmax,rmax=rmax,restrict=restrict,voronoi=voronoi,distances=True,**spec_max)
3142        #end if
3143        neigh_elem = []
3144        for i in range(len(indices)):
3145            neigh_elem.extend(self.elem[neigh[i]])
3146        #end for
3147        chem_key = tuple(sorted(set(neigh_elem)))
3148        chem_coord = zeros((len(indices),len(chem_key)),dtype=int)
3149        for i in range(len(indices)):
3150            counts = zeros((len(chem_key),),dtype=int)
3151            nn = list(self.elem[neigh[i]])
3152            for n in range(len(counts)):
3153                chem_coord[i,n] = nn.count(chem_key[n])
3154            #end for
3155        #end for
3156        chem_map = obj()
3157        i=0
3158        for coord in chem_coord:
3159            coord = tuple(coord)
3160            if not coord in chem_map:
3161                chem_map[coord] = [indices[i]]
3162            else:
3163                chem_map[coord].append(indices[i])
3164            #end if
3165            i+=1
3166        #end for
3167        for coord,ind in chem_map.items():
3168            chem_map[coord] = array(ind,dtype=int)
3169        #end for
3170        results = [chem_key,chem_coord,chem_map]
3171        if neighbors:
3172            results.append(neigh)
3173        #end if
3174        if distances:
3175            results.append(dist)
3176        #end if
3177        return results
3178    #end def chemical_coordination
3179
3180
3181    # test needed
3182    def rcore_max(self,units=None):
3183        nt,dt = self.neighbor_table(self.pos,distances=True)
3184        d = dt[:,1]
3185        rcm = d.min()/2
3186        if units!=None:
3187            rcm = convert(rcm,self.units,units)
3188        #end if
3189        return rcm
3190    #end def rcore_max
3191
3192
3193    # test needed
3194    def cell_image(self,p,center=None):
3195        pos = array(p,dtype=float)
3196        if center is None:
3197            c = self.center.copy()
3198        else:
3199            c = array(center,dtype=float)
3200        #end if
3201        axes = self.axes
3202        axinv = inv(axes)
3203        for i in range(len(pos)):
3204            u = dot(pos[i]-c,axinv)
3205            pos[i] = dot(u-floor(u+.5),axes)+c
3206        #end for
3207        return pos
3208    #end def cell_image
3209
3210
3211    # test needed
3212    def center_distances(self,points,center=None):
3213        if center is None:
3214            c = self.center.copy()
3215        else:
3216            c = array(center,dtype=float)
3217        #end if
3218        points = self.cell_image(points,center=c)
3219        for i in range(len(points)):
3220            points[i] -= c
3221        #end for
3222        return sqrt((points**2).sum(1))
3223    #end def center_distances
3224
3225
3226    # test needed
3227    def recenter(self,center=None):
3228        if center is not None:
3229            self.center=array(center,dtype=float)
3230        #end if
3231        pos = self.pos
3232        c = empty((1,self.dim),dtype=float)
3233        c[:] = self.center[:]
3234        axes = self.axes
3235        axinv = inv(axes)
3236        for i in range(len(pos)):
3237            u = dot(pos[i]-c,axinv)
3238            pos[i] = dot(u-floor(u+.5),axes)+c
3239        #end for
3240        self.recenter_k()
3241    #end def recenter
3242
3243
3244    # test needed
3245    def recorner(self):
3246        pos = self.pos
3247        axes = self.axes
3248        axinv = inv(axes)
3249        for i in range(len(pos)):
3250            u = dot(pos[i],axinv)
3251            pos[i] = dot(u-floor(u),axes)
3252        #end for
3253    #end def recorner
3254
3255
3256    # test needed
3257    def recenter_k(self,kpoints=None,kaxes=None,kcenter=None,remove_duplicates=False):
3258        use_self = kpoints is None
3259        if use_self:
3260            kpoints=self.kpoints
3261        #end if
3262        if kaxes is None:
3263            kaxes=self.kaxes
3264        #end if
3265        if len(kpoints)>0:
3266            axes = kaxes
3267            axinv = inv(axes)
3268            if kcenter is None:
3269                c = axes.sum(0)/2
3270            else:
3271                c = array(kcenter)
3272            #end if
3273            for i in range(len(kpoints)):
3274                u = dot(kpoints[i]-c,axinv)
3275                u -= floor(u+.5)
3276                u[abs(u-.5)<1e-12] -= 1.0
3277                u[abs(u   )<1e-12]  = 0.0
3278                kpoints[i] = dot(u,axes)+c
3279            #end for
3280            if remove_duplicates:
3281                inside = self.inside(kpoints,axes,c)
3282                kpoints  = kpoints[inside]
3283                nkpoints = len(kpoints)
3284                unique = empty((nkpoints,),dtype=bool)
3285                unique[:] = True
3286                nn = nearest_neighbors(1,kpoints)
3287                if nkpoints>1:
3288                    nn.shape = nkpoints,
3289                    dist = self.distances(kpoints,kpoints[nn])
3290                    tol = 1e-8
3291                    duplicates = arange(nkpoints)[dist<tol]
3292                    for i in duplicates:
3293                        if unique[i]:
3294                            for j in duplicates:
3295                                if sqrt(((kpoints[i]-kpoints[j])**2).sum(1))<tol:
3296                                    unique[j] = False
3297                                #end if
3298                            #end for
3299                        #end if
3300                    #end for
3301                #end if
3302                kpoints = kpoints[unique]
3303            #end if
3304        #end if
3305        if use_self:
3306            self.kpoints = kpoints
3307        else:
3308            return kpoints
3309        #end if
3310    #end def recenter_k
3311
3312
3313    # test needed
3314    def inside(self,pos,axes=None,center=None,tol=1e-8,separate=False):
3315        if axes==None:
3316            axes=self.axes
3317        #end if
3318        if center==None:
3319            center=self.center
3320        #end if
3321        axes = array(axes)
3322        center = array(center)
3323        inside = []
3324        surface = []
3325        su = []
3326        axinv = inv(axes)
3327        for i in range(len(pos)):
3328            u = dot(pos[i]-center,axinv)
3329            umax = abs(u).max()
3330            if abs(umax-.5)<tol:
3331                surface.append(i)
3332                su.append(u)
3333            elif umax<.5:
3334                inside.append(i)
3335            #end if
3336        #end for
3337        npos,dim = pos.shape
3338        drange = list(range(dim))
3339        n = len(surface)
3340        i=0
3341        while i<n:
3342            j=i+1
3343            while j<n:
3344                du = abs(su[i]-su[j])
3345                match = False
3346                for d in drange:
3347                    match = match or abs(du[d]-1.)<tol
3348                #end for
3349                if match:
3350                    surface[j]=surface[-1]
3351                    surface.pop()
3352                    su[j]=su[-1]
3353                    su.pop()
3354                    n-=1
3355                else:
3356                    j+=1
3357                #end if
3358            #end while
3359            i+=1
3360        #end while
3361        if not separate:
3362            inside+=surface
3363            return inside
3364        else:
3365            return inside,surface
3366        #end if
3367    #end def inside
3368
3369
    def tile(self,*td,**kwargs):
        """
        Tile (replicate) the structure into a supercell.

        Parameters
        ----------
        *td
            The tiling: a single int (same count along every axis), a single
            sequence or (dim,dim) matrix, or several per-axis values.  The
            supercell axes are ``dot(tilematrix,self.axes)``.
        in_place : bool, keyword, default False
            If True, replace the contents of this structure with the tiled
            one and return ``self``.
        check : bool, keyword, default False
            If True, run ``check_tiling`` on the result.

        Returns
        -------
        Structure
            The tiled structure (a copy unless ``in_place``).  For an
            identity tiling, ``self`` or a plain copy is returned.

        Notes
        -----
        ``self.recenter()`` is called before tiling, so the positions of
        ``self`` are wrapped into its cell as a side effect.  The result
        records the accumulated tiling matrix (``tmatrix``) and the untiled
        ``folded_structure``.
        """
        in_place           = kwargs.pop('in_place',False)
        check              = kwargs.pop('check',False)
        # NOTE(review): any other keyword arguments are silently ignored

        dim = self.dim
        # normalize the request: an int tiles every axis equally, a single
        # sequence/matrix is used as given, several values form a vector
        if len(td)==1:
            if isinstance(td[0],int):
                tiling = dim*[td[0]]
            else:
                tiling = td[0]
            #end if
        else:
            tiling = td
        #end if
        tiling = array(tiling)

        # NOTE(review): matrix_tiling is computed but never used below
        matrix_tiling = tiling.shape == (dim,dim)

        # presumably returns a full (dim,dim) tiling matrix plus candidate
        # box tilings consumed by tile_points -- confirm in reduce_tilematrix
        tilematrix,tilevector = reduce_tilematrix(tiling)

        # number of copies of the original cell in the supercell
        ncells = int(round( abs(det(tilematrix)) ))

        # identity tiling: nothing to replicate
        if ncells==1 and abs(tilematrix-identity(self.dim)).sum()<1e-1:
            if in_place:
                return self
            else:
                return self.copy()
            #end if
        #end if

        # wrap atoms into the cell first (modifies self.pos in place)
        self.recenter()

        # per-atom data is repeated once per copy of the cell
        elem = array(ncells*list(self.elem))
        pos  = self.tile_points(self.pos,self.axes,tilematrix,tilevector)
        axes = dot(tilematrix,self.axes)

        center   = axes.sum(0)/2
        # reciprocal axes of the supercell: inv(T^T) applied to self.kaxes
        kaxes    = dot(inv(tilematrix.T),self.kaxes)
        kpoints  = array(self.kpoints)
        kweights = array(self.kweights)
        mag      = None
        frozen   = None
        if self.mag is not None:
            mag = ncells*list(self.mag)
        #end if
        if self.frozen is not None:
            frozen = ncells*list(self.frozen)
        #end if

        ts = self.copy()
        ts.center   = center
        ts.set_elem(elem)
        ts.axes     = axes
        ts.pos      = pos
        # NOTE(review): ts.mag is assigned here and again via set_mag below;
        # the direct assignment looks redundant -- confirm
        ts.mag      = mag
        ts.kaxes    = kaxes
        ts.kpoints  = kpoints
        ts.kweights = kweights
        ts.set_mag(mag)
        ts.set_frozen(frozen)
        ts.background_charge = ncells*self.background_charge

        ts.recenter()
        ts.unique_kpoints()
        # record the accumulated tiling relative to the original untiled cell
        if self.is_tiled():
            ts.tmatrix = dot(tilematrix,self.tmatrix)
            ts.folded_structure = self.folded_structure.copy()
        else:
            ts.tmatrix = tilematrix
            ts.folded_structure = self.copy()
        #end if

        if in_place:
            # replace this structure's contents with the tiled one
            self.clear()
            self.transfer_from(ts)
            ts = self
        #end if

        if check:
            ts.check_tiling()
        #end if

        return ts
    #end def tile
3454
3455
3456    def tile_points(self,points,axes,tilemat,tilevec=None):
3457        if tilevec is None:
3458            tilemat,tilevec = reduce_tilematrix(tilemat)
3459        #end if
3460        if not isinstance(tilemat,ndarray):
3461            tilemat = array(tilemat)
3462        #end if
3463        matrix_tiling = abs(tilemat-diag(diag(tilemat))).sum()>0.1
3464        if not matrix_tiling:
3465            return self.tile_points_simple(points,axes,diag(abs(tilemat)))
3466        else:
3467            if not isinstance(axes,ndarray):
3468                axes = array(axes)
3469            #end if
3470            if not isinstance(tilevec,ndarray):
3471                tilevec = array(tilevec)
3472            #end if
3473            dim     = len(axes)
3474            npoints = len(points)
3475            ntpoints = npoints*int(round(abs(det(tilemat))))
3476            if tilevec.size==dim:
3477                tilevec.shape = 1,dim
3478            #end if
3479            taxes = dot(tilemat,axes)
3480            success = False
3481            for tvec in tilevec:
3482                tpoints = self.tile_points_simple(points,axes,tvec)
3483                tpoints,weights,pmap = self.unique_points_fast(tpoints,taxes)
3484                if len(tpoints)==ntpoints:
3485                    success = True
3486                    break
3487                #end if
3488            #end for
3489            if not success:
3490                tpoints = self.tile_points_brute(points,axes,tilemat)
3491                tpoints,weights,pmap = self.unique_points_fast(tpoints,taxes)
3492                if len(tpoints)!=ntpoints:
3493                    self.error('brute force tiling failed')
3494                #end if
3495            #end if
3496        #end if
3497        return tpoints
3498    #end def tile_points
3499
3500
3501    def tile_points_simple(self,points,axes,tilevec):
3502        if not isinstance(points,ndarray):
3503            points = array(points)
3504        #end if
3505        if not isinstance(tilevec,ndarray):
3506            tilevec = array(tilevec)
3507        #end if
3508        if not isinstance(axes,ndarray):
3509            axes = array(axes)
3510        #end if
3511        if len(points.shape)==1:
3512            npoints,dim = len(points),1
3513        else:
3514            npoints,dim = points.shape
3515        #end if
3516        t = tilevec
3517        ti = array(around(t),dtype=int)
3518        noninteger = abs(t-ti).sum()>1e-6
3519        if noninteger:
3520            tp = t.prod()
3521            if abs(tp-int(tp))>1e-6:
3522                self.error('tiling vector does not correspond to an integer volume change\ntiling vector: {0}\nvolume change: {1}  {2}  {3}'.format(tilevec,tilevec.prod(),ntpoints,int(ntpoints)))
3523            #end if
3524            t = array(ceil(t),dtype=int)+1
3525        else:
3526            t = ti
3527        #end if
3528        if t.min()<0:
3529            self.error('tiling vector cannot be negative\ntiling vector provided: {}'.format(t))
3530        #end if
3531        ntpoints = npoints*int(round( t.prod() ))
3532        if ntpoints==0:
3533            tpoints = array([])
3534        else:
3535            tpoints = empty((ntpoints,dim))
3536            ns=0
3537            ne=npoints
3538            for k in range(t[2]):
3539                for j in range(t[1]):
3540                    for i in range(t[0]):
3541                        v = dot(array([[i,j,k]]),axes)
3542                        for d in range(dim):
3543                            tpoints[ns:ne,d] = points[:,d]+v[0,d]
3544                        #end for
3545                        ns+=npoints
3546                        ne+=npoints
3547                    #end for
3548                #end for
3549            #end for
3550        #end if
3551        return tpoints
3552    #end def tile_points_simple
3553
3554
3555    def tile_points_brute(self,points,axes,tilemat):
3556        tcorners = [[0,0,0],
3557                    [1,0,0],
3558                    [0,1,0],
3559                    [0,0,1],
3560                    [0,1,1],
3561                    [1,0,1],
3562                    [1,1,0],
3563                    [1,1,1]]
3564        tcorners = dot(tcorners,tilemat)
3565        tmin = tcorners.min(axis=0)
3566        tmax = tcorners.max(axis=0)
3567        tilevec = tmax-tmin
3568        tpoints = self.tile_points_simple(points,axes,tilevec)
3569        return tpoints
3570    #end def tile_points_brute
3571
3572
3573
3574    def opt_tilematrix(self,*args,**kwargs):
3575        return optimal_tilematrix(self.axes,*args,**kwargs)
3576    #end def opt_tilematrix
3577
3578
3579    def tile_opt(self,*args,**kwargs):
3580        Topt,ropt = self.opt_tilematrix(*args,**kwargs)
3581        return self.tile(Topt)
3582    #end def tile_opt
3583
3584
3585    def check_tiling(self,tol=1e-6,exit=True):
3586        msg = ''
3587        if not self.is_tiled():
3588            return msg
3589        #end if
3590        msgs = []
3591        st = self
3592        s  = self.folded_structure
3593        nt = len(st.pos)
3594        n  = len(s.pos)
3595        if nt%n!=0:
3596            msgs.append('tiled atom count does is not divisible by untiled atom count')
3597        #end if
3598        vratio = st.volume()/s.volume()
3599        if abs(vratio-float(nt)/n)>tol:
3600            msgs.append('tiled/untiled volume ratio does not match tiled/untiled atom count ratio')
3601        #end if
3602        if abs(vratio-abs(det(st.tmatrix)))>tol:
3603            msgs.append('tiled/untiled volume ratio does not match tiling matrix determinant')
3604        #end if
3605        p,w,pmap = self.unique_points_fast(st.pos,st.axes)
3606        if len(p)!=nt:
3607            msgs.append('tiled positions are not unique')
3608        #end if
3609        if len(msgs)>0:
3610            msg = 'tiling check failed'
3611            for m in msgs:
3612                msg += '\n'+m
3613            #end for
3614            if exit:
3615                self.error(msg)
3616            #end if
3617        #end if
3618        return msg
3619    #end def check_tiling
3620
3621
3622    # test needed
3623    def kfold(self,tiling,kpoints,kweights):
3624        if isinstance(tiling,int):
3625            tiling = self.dim*[tiling]
3626        #end if
3627        tiling = array(tiling)
3628        if tiling.shape==(self.dim,self.dim):
3629            tiling = tiling.T
3630        #end if
3631        tilematrix,tilevector = reduce_tilematrix(tiling)
3632        ncells = int(round( abs(det(tilematrix)) ))
3633        kp     = self.tile_points(kpoints,self.kaxes,tilematrix,tilevector)
3634        kw     = array(ncells*list(kweights),dtype=float)/ncells
3635        return kp,kw
3636    #end def kfold
3637
3638
3639    def get_smallest(self):
3640        if self.has_folded():
3641            return self.folded_structure
3642        else:
3643            return self
3644        #end if
3645    #end def get_smallest
3646
3647
3648    # test needed
    def fold(self,small,*requests):
        """
        Fold this cell's k-points into the k-space cell of a smaller cell.

        NOTE: this method is currently disabled.  Its first statement calls
        self.error unconditionally, so (assuming self.error aborts) none of
        the code below runs.  It is kept for reference and is not
        equivalent to tile.

        Parameters
        ----------
        small    : structure whose kaxes define the target k-space cell;
                   small.kpoints is overwritten with the folded k-points
        requests : optional strings selecting extra results:
                     'kmap'       -> obj mapping each of this cell's
                                     k-points (as a tuple) to an array of
                                     the folded k-points generated from it
                     'tilematrix' -> self.tilematrix(small)

        Returns
        -------
        A list with one result per request, or None when no requests are
        given.
        """
        self.error('fold needs a developers attention to make it equivalent with tile')
        if self.dim!=3:
            self.error('fold is currently only implemented for 3 dimensions')
        #end if
        self.recenter_k()
        corners = []
        ndim = len(small.axes)
        imin = empty((ndim,),dtype=int)
        imax = empty((ndim,),dtype=int)
        imin[:] =  1000000
        imax[:] = -1000000
        axinv  = inv(self.kaxes)
        center = self.kaxes.sum(0)/2
        c = empty((1,3))
        # bound the range of self.kaxes translations needed: map corner-like
        # points (-1 and 2 in small's k-units) into self's k-units
        for k in -1,2:
            for j in -1,2:
                for i in -1,2:
                    c[:] = i,j,k
                    c = dot(c,small.kaxes)
                    u = dot(c-center,axinv)
                    for d in range(ndim):
                        imin[d] = min(int(floor(u[0,d])),imin[d])
                        imax[d] = max(int(ceil(u[0,d])),imax[d])
                    #end for
                #end for
            #end for
        #end for

        axes = small.kaxes
        axinv = inv(small.kaxes)

        center = small.kaxes.sum(0)/2
        nkpoints = len(self.kpoints)
        kindices = []
        kpoints  = []
        shift = empty((ndim,))
        kr = list(range(nkpoints))
        # translate every k-point by each candidate lattice vector and keep
        # the images whose small-cell unit coordinates (relative to center)
        # all lie within +-0.5; kindices records the source k-point index
        for k in range(imin[2],imax[2]+1):
            for j in range(imin[1],imax[1]+1):
                for i in range(imin[0],imax[0]+1):
                    for n in kr:
                        shift[:] = i,j,k
                        shift = dot(shift,self.kaxes)
                        kp = self.kpoints[n]+shift
                        u = dot(kp-center,axinv)
                        if abs(u).max()<.5+1e-10:
                            kindices.append(n)
                            kpoints.append(kp)
                        #end if
                    #end for
                #end for
            #end for
        #end for
        kindices = array(kindices)
        kpoints  = array(kpoints)
        inside = self.inside(kpoints,axes,center)
        kindices = kindices[inside]
        kpoints  = kpoints[inside]

        small.kpoints = kpoints
        small.recenter_k()
        kpoints = array(small.kpoints)
        if len(requests)>0:
            results = []
            for request in requests:
                if request=='kmap':
                    kmap = obj()
                    for k in self.kpoints:
                        kmap[tuple(k)] = []
                    #end for
                    for i in range(len(kpoints)):
                        kp = tuple(self.kpoints[kindices[i]])
                        kmap[kp].append(array(kpoints[i]))
                    #end for
                    for kl,ks in kmap.items():
                        kmap[kl] = array(ks)
                    #end for
                    res = kmap
                elif request=='tilematrix':
                    res = self.tilematrix(small)
                else:
                    self.error(request+' is not a recognized input to fold')
                #end if
                results.append(res)
            #end for
            return results
        #end if
    #end def fold
3738
3739
3740    def tilematrix(self,small=None,tol=1e-6,status=False):
3741        if small is None:
3742            if self.folded_structure is not None:
3743                small = self.folded_structure
3744            else:
3745                return identity(self.dim,dtype=int)
3746            #end if
3747        #end if
3748        tm = dot(self.axes,inv(small.axes))
3749        tilemat = array(around(tm),dtype=int)
3750        error = abs(tilemat-tm).sum()
3751        non_integer_elements = error > tol
3752        if status:
3753            return tilemat,not non_integer_elements
3754        else:
3755            if non_integer_elements:
3756                self.error('large cell cannot be constructed as an integer tiling of the small cell\nlarge cell axes:\n'+str(self.axes)+'\nsmall cell axes:  \n'+str(small.axes)+'\nlarge/small:\n'+str(self.axes/small.axes)+'\ntiling matrix:\n'+str(tm)+'\nintegerized tiling matrix:\n'+str(tilemat)+'\nerror: '+str(error)+'\ntolerance: '+str(tol))
3757            #end if
3758            return tilemat
3759        #end if
3760    #end def tilematrix
3761
3762
3763    def primitive(self,source=None,tmatrix=False,add_kpath=False,**kwargs):
3764        res = None
3765        allowed_sources = set(['seekpath'])
3766        if source is None or isinstance(source,bool):
3767            source = 'seekpath'
3768        #end if
3769        if source not in allowed_sources:
3770            self.error('source used to obtain primitive cell is unrecognized\nsource requested: {0}\nallowed sources: {1}'.format(source,sorted(allowed_sources)))
3771        #end if
3772        if source=='seekpath':
3773            res_skp = get_seekpath_full(structure=self,primitive=True,**kwargs)
3774            prim = res_skp.primitive
3775            T    = res_skp.prim_tmatrix
3776            if add_kpath:
3777                prim.add_kpoints(res_skp.explicit_kpoints_abs)
3778            #end if
3779            if tmatrix:
3780                res = prim,T
3781            else:
3782                res = prim
3783            #end if
3784        else:
3785            self.error('primitive source "{0}" is not implemented\nplease contact a developer'.format(source))
3786        #end if
3787        if prim.units!=self.units:
3788            prim.change_units(self.units)
3789        #end if
3790        return res
3791    #end def primitive
3792
3793
3794    def become_primitive(self,source=None,add_kpath=False,**kwargs):
3795        prim = self.primitive(source=source,add_kpath=add_kpath,**kwargs)
3796        self.clone_from(prim)
3797    #end def become_primitive
3798
3799
3800    def add_kpoints(self,kpoints,kweights=None,unique=False,recenter=True,cell_unit=False):
3801        if kweights is None:
3802            kweights = ones((len(kpoints),))
3803        #end if
3804        if cell_unit:
3805            kpoints = np.dot(array(kpoints),self.kaxes)
3806        #end if
3807        self.kpoints  = append(self.kpoints,kpoints,axis=0)
3808        self.kweights = append(self.kweights,kweights)
3809        if unique:
3810            self.unique_kpoints()
3811        #end if
3812        if recenter:
3813            self.recenter_k() #added because qmcpack cannot handle kpoints outside the box
3814        #end if
3815        if self.is_tiled():
3816            kp,kw = self.kfold(self.tmatrix,kpoints,kweights)
3817            self.folded_structure.add_kpoints(kp,kw,unique=unique)
3818        #end if
3819    #end def add_kpoints
3820
3821
3822    # test needed
3823    def clear_kpoints(self):
3824        self.kpoints  = empty((0,self.dim))
3825        self.kweights = empty((0,))
3826        if self.folded_structure!=None:
3827            self.folded_structure.clear_kpoints()
3828        #end if
3829    #end def clear_kpoints
3830
3831
3832    def kgrid_from_kspacing(self,kspacing):
3833        kgrid = []
3834        for ka in self.kaxes:
3835            km = np.linalg.norm(ka)
3836            kg = int(np.ceil(km/kspacing))
3837            kgrid.append(kg)
3838        #end for
3839        return tuple(kgrid)
3840    #end def kgrid_from_kspacing
3841
3842
3843    def add_kmesh(self,kgrid=None,kshift=None,unique=False,kspacing=None):
3844        if kspacing is not None:
3845            kgrid = self.kgrid_from_kspacing(kspacing)
3846        elif kgrid is None:
3847            self.error('kgrid input is required by add_kmesh')
3848        #end if
3849        self.add_kpoints(kmesh(self.kaxes,kgrid,kshift),unique=unique)
3850    #end def add_kmesh
3851
3852
3853    def add_symmetrized_kmesh(self,kgrid=None,kshift=(0,0,0),kspacing=None):
3854        # find kgrid from kspacing, if requested
3855        if kspacing is not None:
3856            kgrid = self.kgrid_from_kspacing(kspacing)
3857        elif kgrid is None:
3858            self.error('kgrid input is required by add_kmesh')
3859        #end if
3860
3861        # get spglib cell data structure
3862        cell = self.spglib_cell()
3863
3864        # get the symmetry mapping
3865        kmap,kpoints_int = spglib.get_ir_reciprocal_mesh(
3866            kgrid,
3867            cell,
3868            is_shift=kshift
3869            )
3870
3871        # create the Monkhorst-Pack mesh
3872        kshift = array(kshift,dtype=float)
3873        okgrid = 1.0/array(kgrid,dtype=float)
3874        kpoints = empty(kpoints_int.shape,dtype=float)
3875        for i,ki in enumerate(kpoints_int):
3876            kpoints[i] = (ki+kshift)*okgrid
3877        #end for
3878        kpoints = dot(kpoints,self.kaxes)
3879
3880        # reduce to only the symmetric kpoints with weights
3881        kwmap = obj()
3882        for ik in kmap:
3883            if ik not in kwmap:
3884                kwmap[ik] = 1
3885            else:
3886                kwmap[ik] += 1
3887            #end if
3888        #end for
3889        nkpoints = len(kwmap)
3890        kpoints_symm  = empty((nkpoints,self.dim),dtype=float)
3891        kweights_symm = empty((nkpoints,),dtype=float)
3892        n = 0
3893        for ik,kw in kwmap.items():
3894            kpoints_symm[n]  = kpoints[ik]
3895            kweights_symm[n] = kw
3896            n+=1
3897        #end for
3898        self.add_kpoints(kpoints_symm,kweights_symm)
3899    #end def add_symmetrized_kmesh
3900
3901
3902    def kpoints_unit(self,kpoints=None):
3903        if kpoints is None:
3904            kpoints = self.kpoints
3905        #end if
3906        return dot(kpoints,inv(self.kaxes))
3907    #end def kpoints_unit
3908
3909
3910    def kpoints_reduced(self,kpoints=None):
3911        if kpoints is None:
3912            kpoints = self.kpoints
3913        #end if
3914        return kpoints*self.scale/(2*pi)
3915    #end def kpoints_reduced
3916
3917
3918    def kpoints_qmcpack(self,kpoints=None):
3919        if kpoints is None:
3920            kpoints = self.kpoints.copy()
3921        #end if
3922        kpoints = self.recenter_k(kpoints,kcenter=(0,0,0))
3923        kpoints = self.kpoints_unit(kpoints)
3924        kpoints = -kpoints
3925        return kpoints
3926    #end def kpoints_qmcpack
3927
3928
3929    # test needed
3930    def inversion_symmetrize_kpoints(self,tol=1e-10,folded=False):
3931        kp    = self.kpoints
3932        kaxes = self.kaxes
3933        ntable,dtable = self.neighbor_table(kp,-kp,kaxes,distances=True)
3934        pairs = set()
3935        keep = empty((len(kp),),dtype=bool)
3936        keep[:] = True
3937        for i in range(len(dtable)):
3938            if keep[i] and dtable[i,0]<tol:
3939                j = ntable[i,0]
3940                if j!=i and keep[j]:
3941                    keep[j] = False
3942                    self.kweights[i] += self.kweights[j]
3943                #end if
3944            #end if
3945        #end for
3946        self.kpoints  = self.kpoints[keep]
3947        self.kweights = self.kweights[keep]
3948        if folded and self.folded_structure!=None:
3949            self.folded_structure.inversion_symmetrize_kpoints(tol)
3950        #end if
3951    #end def inversion_symmetrize_kpoints
3952
3953
3954    # test needed
3955    def unique_points(self,points,axes,weights=None,tol=1e-10):
3956        pmap = obj()
3957        npoints = len(points)
3958        if npoints>0:
3959            if weights is None:
3960                weights = ones((npoints,),dtype=int)
3961            #end if
3962            ntable,dtable = self.neighbor_table(points,points,axes,distances=True)
3963            keep = empty((npoints,),dtype=bool)
3964            keep[:] = True
3965            pmo = obj()
3966            for i in range(npoints):
3967                if keep[i]:
3968                    pm = []
3969                    jn=0
3970                    while jn<npoints and dtable[i,jn]<tol:
3971                        j = ntable[i,jn]
3972                        pm.append(j)
3973                        if j!=i and keep[j]:
3974                            keep[j] = False
3975                            weights[i] += weights[j]
3976                        #end if
3977                        jn+=1
3978                    #end while
3979                    pmo[i] = set(pm)
3980                #end if
3981            #end for
3982            points  = points[keep]
3983            weights = weights[keep]
3984            j=0
3985            for i in range(len(keep)):
3986                if keep[i]:
3987                    pmap[j] = pmo[i]
3988                    j+=1
3989                #end if
3990            #end for
3991        #end if
3992        return points,weights,pmap
3993    #end def unique_points
3994
3995
3996    # test needed
3997    def unique_points_fast(self,points,axes,weights=None,tol=1e-10):
3998        # use an O(N) cell table instead of an O(N^2) neighbor table
3999        pmap = obj()
4000        points = array(points)
4001        axes   = array(axes)
4002        npoints = len(points)
4003        if npoints>0:
4004            if weights is None:
4005                weights = ones((npoints,),dtype=int)
4006            else:
4007                weights = array(weights)
4008            #end if
4009            keep = ones((npoints,),dtype=bool)
4010            # place all the points in the box, converted to unit coords
4011            upoints = array(points)
4012            axinv = inv(axes)
4013            for i in range(len(points)):
4014                u = dot(points[i],axinv)
4015                upoints[i] = u-floor(u)
4016            #end for
4017            # create an integer array of cell indices
4018            axmax = -1.0
4019            for a in axes:
4020                axmax = max(axmax,norm(a))
4021            #end for
4022            #   make an integer space corresponding to 1e-7 self.units spatial resolution
4023            cmax = uint64(1e7)*uint64(ceil(axmax))
4024            ipoints = array(around(cmax*upoints),dtype=uint64)
4025            ipoints[ipoints==cmax] = 0 # make the outer boundary the same as the inner boundary
4026            # load the cell table with point indices
4027            #   points in the same cell are identical
4028            ctable = obj()
4029            i=0
4030            for ip in ipoints:
4031                ip = tuple(ip)
4032                if ip not in ctable:
4033                    ctable[ip] = i
4034                    pmap[i] = [i]
4035                else:
4036                    j = ctable[ip]
4037                    keep[i] = False
4038                    weights[j] += weights[i]
4039                    pmap[j].append(i)
4040                #end if
4041                i+=1
4042            #end for
4043            points  = points[keep]
4044            weights = weights[keep]
4045        #end if
4046        return points,weights,pmap
4047    #end def unique_points_fast
4048
4049
4050    # test needed
4051    def unique_positions(self,tol=1e-10,folded=False):
4052        pos,weights,pmap = self.unique_points(self.pos,self.axes)
4053        if len(pos)!=len(self.pos):
4054            self.pos = pos
4055        #end if
4056        if folded and self.folded_structure!=None:
4057            self.folded_structure.unique_positions(tol)
4058        #end if
4059        return pmap
4060    #end def unique_positions
4061
4062
4063    # test needed
4064    def unique_kpoints(self,tol=1e-10,folded=False):
4065        kmap = obj()
4066        kp   = self.kpoints
4067        if len(kp)>0:
4068            kaxes = self.kaxes
4069            ntable,dtable = self.neighbor_table(kp,kp,kaxes,distances=True)
4070            npoints = len(kp)
4071            keep = empty((len(kp),),dtype=bool)
4072            keep[:] = True
4073            kmo = obj()
4074            for i in range(npoints):
4075                if keep[i]:
4076                    km = []
4077                    jn=0
4078                    while jn<npoints and dtable[i,jn]<tol:
4079                        j = ntable[i,jn]
4080                        km.append(j)
4081                        if j!=i and keep[j]:
4082                            keep[j] = False
4083                            self.kweights[i] += self.kweights[j]
4084                        #end if
4085                        jn+=1
4086                    #end while
4087                    kmo[i] = set(km)
4088                #end if
4089            #end for
4090            self.kpoints  = self.kpoints[keep]
4091            self.kweights = self.kweights[keep]
4092            j=0
4093            for i in range(len(keep)):
4094                if keep[i]:
4095                    kmap[j] = kmo[i]
4096                    j+=1
4097                #end if
4098            #end for
4099        #end if
4100        if folded and self.folded_structure!=None:
4101            self.folded_structure.unique_kpoints(tol)
4102        #end if
4103        return kmap
4104    #end def unique_kpoints
4105
4106
4107    def kmap(self):
4108        kmap = None
4109        if self.folded_structure!=None:
4110            fs = self.folded_structure
4111            self.kpoints  = array(fs.kpoints)
4112            self.kweights = array(fs.kweights)
4113            kmap = self.unique_kpoints()
4114        #end if
4115        return kmap
4116    #end def kmap
4117
4118
4119    # test needed
4120    def select_twist(self,selector='smallest',tol=1e-6):
4121        index = None
4122        invalid_selector = False
4123        if isinstance(selector,str):
4124            if selector=='smallest':
4125                index = (self.kpoints**2).sum(1).argmin()
4126            elif selector=='random':
4127                index = randint(0,len(self.kpoints)-1)
4128            else:
4129                invalid_selector = True
4130            #end if
4131        elif isinstance(selector,(tuple,list,ndarray)):
4132            ku_sel = array(selector,dtype=float)
4133            n = 0
4134            for ku in self.kpoints_unit():
4135                if norm(ku-ku_sel)<tol:
4136                    index = n
4137                    break
4138                #end if
4139                n+=1
4140            #end for
4141            if index is None:
4142                self.error('cannot identify twist number\ntwist requested: {0}\ntwists present: {1}'.format(ku_sel,sorted([tuple(k) for k in self.kpoints_unit()])))
4143            #end if
4144        else:
4145            invalid_selector = True
4146        #end if
4147        if invalid_selector:
4148            self.error('cannot identify twist number\ninvalid selector provided: {0}\nvalid string inputs for selector: smallest, random\nselector can also be a length 3 tuple, list or array (a twist vector)'.format(selector))
4149        #end if
4150        return index
4151    #end def select_twist
4152
4153
4154    # test needed
    def fold_pos(self,large,tol=0.001):
        """
        Replace this cell's atoms with the folded-in atoms of a larger cell.

        The large cell must be an integer tiling of this one (volume ratio
        and tilematrix are both checked).  All atoms of `large` are copied
        into this cell and recentered.  Groups of nnearest = volume-ratio
        mutually nearest atoms are treated as images of one atom.  Each
        group must be within tol, must have a single species, and is
        replaced by its mean position.

        Parameters
        ----------
        large : structure whose cell tiles this one
        tol   : maximum distance allowed between equivalent atoms

        Modifies self.elem and self.pos in place; failures are reported
        through self.error.
        """
        vratio = large.volume()/self.volume()
        if abs(vratio-int(around(vratio)))>1e-6:
            self.error('cannot fold positions from large cell into current one\nlarge cell volume is not an integer multiple of the current one\nlarge cell volume: {0}\ncurrent cell volume: {1}\nvolume ratio: {2}'.format(large.volume(),self.volume(),vratio))
        #end if
        T,success = large.tilematrix(self,status=True)
        if not success:
            self.error('cannot fold positions from large cell into current one\ncells are related by non-integer tilematrix')
        #end if
        # number of images of each atom expected inside the large cell
        nnearest = int(around(vratio))
        self.elem = large.elem.copy()
        self.pos  = large.pos.copy()
        # presumably wraps the large-cell positions into this cell
        self.recenter()
        # each atom's nnearest closest atoms (itself included, presumably at
        # column 0) are taken to be its equivalent images
        nt,dt = self.neighbor_table(distances=True)
        nt = nt[:,:nnearest]
        dt = dt[:,:nnearest]
        if dt.ravel().max()>tol:
            self.error('cannot fold positions from large cell into current one\npositions of equivalent atoms are further apart than the tolerance\nmax distance encountered: {0}\ntolerance: {1}'.format(dt.ravel().max(),tol))
        #end if
        # every atom must be claimed by exactly nnearest neighbor lists
        counts = zeros((len(self.pos),),dtype=int)
        for n in nt.ravel():
            counts[n] += 1
        #end for
        if (counts!=nnearest).any():
            self.error('cannot fold positions from large cell into current one\neach atom must have {0} equivalent positions\nsome atoms found with the following equivalent position counts: {1}'.format(nnearest,counts[counts!=nnearest]))
        #end if
        # the first unvisited atom of each group defines that group
        ind_visited = set()
        neigh_map = obj()
        keep = []
        n=0
        for nset in nt:
            if n not in ind_visited:
                neigh_map[n] = nset
                keep.append(n)
                for ind in nset:
                    ind_visited.add(ind)
                #end for
            #end if
            n+=1
        #end for
        if len(ind_visited)!=len(self.pos):
            self.error('cannot fold positions from large cell into current one\nsome equivalent atoms could not be identified')
        #end if
        new_elem = []
        new_pos  = []
        for n in keep:
            nset = neigh_map[n]
            elist = list(set(self.elem[nset]))
            if len(elist)!=1:
                self.error('cannot fold positions from large cell into current one\nspecies of some equivalent atoms do not match')
            #end if
            new_elem.append(elist[0])
            # NOTE(review): positions are averaged in raw Cartesian
            # coordinates; if neighbor_table uses minimum-image distances,
            # images split across a cell boundary would average incorrectly
            new_pos.append(self.pos[nset].mean(0))
        #end for
        self.set_elem(new_elem)
        self.set_pos(new_pos)
    #end def fold_pos
4211
4212
4213    def pos_unit(self,pos=None):
4214        if pos is None:
4215            pos = self.pos
4216        #end if
4217        return dot(pos,inv(self.axes))
4218    #end def pos_unit
4219
4220
4221    def pos_to_cartesian(self):
4222        self.pos = dot(self.pos,self.axes)
4223    #end def pos_to_cartesian
4224
4225
4226    # test needed
4227    def at_Gpoint(self):
4228        kpu = self.kpoints_unit()
4229        kg = array([0,0,0])
4230        return len(kpu)==1 and norm(kg-kpu[0])<1e-6
4231    #end def at_Gpoint
4232
4233
4234    # test needed
4235    def at_Lpoint(self):
4236        kpu = self.kpoints_unit()
4237        kg = array([.5,.5,.5])
4238        return len(kpu)==1 and norm(kg-kpu[0])<1e-6
4239    #end def at_Lpoint
4240
4241
4242    # test needed
4243    def at_real_kpoint(self):
4244        kpu = 2*self.kpoints_unit()
4245        return len(kpu)==1 and abs(kpu-around(kpu)).sum()<1e-6
4246    #end def at_real_kpoint
4247
4248
4249    # test needed
4250    def bonds(self,neighbors,vectors=False):
4251        if self.dim!=3:
4252            self.error('bonds is currently only implemented for 3 dimensions')
4253        #end if
4254        natoms,dim = self.pos.shape
4255        centers = empty((natoms,neighbors,dim))
4256        distances = empty((natoms,neighbors))
4257        vect      = empty((natoms,neighbors,dim))
4258        t = self.tile((3,3,3))
4259        t.recenter(self.center)
4260        nn = nearest_neighbors(neighbors+1,t.pos,self.pos)
4261        for i in range(natoms):
4262            ii = nn[i,0]
4263            n=0
4264            for jj in nn[i,1:]:
4265                p1 = t.pos[ii]
4266                p2 = t.pos[jj]
4267                centers[i,n,:] = (p1+p2)/2
4268                distances[i,n]= sqrt(((p1-p2)**2).sum())
4269                vect[i,n,:] = p2-p1
4270                n+=1
4271            #end for
4272        #end for
4273        sn = self.copy()
4274        nnr = nn[:,1:].ravel()
4275        sn.set_elem(t.elem[nnr])
4276        sn.pos  = t.pos[nnr]
4277        sn.recenter()
4278        indices = self.locate(sn.pos)
4279        indices = indices.reshape(natoms,neighbors)
4280        if not vectors:
4281            return indices,centers,distances
4282        else:
4283            return indices,centers,distances,vect
4284        #end if
4285    #end def bonds
4286
4287
4288    # test needed
4289    def displacement(self,reference,map=False):
4290        if self.dim!=3:
4291            self.error('displacement is currently only implemented for 3 dimensions')
4292        #end if
4293        ref = reference.tile((3,3,3))
4294        ref.recenter(reference.center)
4295        rmap = array(3**3*list(range(len(reference.pos)),dtype=int))
4296        nn = nearest_neighbors(1,ref.pos,self.pos).ravel()
4297        displacement = self.pos - ref.pos[nn]
4298        if not map:
4299            return displacement
4300        else:
4301            return displacement,rmap[nn]
4302        #end if
4303    #end def displacement
4304
4305
4306    # test needed
4307    def scalar_displacement(self,reference):
4308        return sqrt((self.displacement(reference)**2).sum(1))
4309    #end def scalar_displacement
4310
4311
4312    # test needed
    def distortion(self,reference,neighbors):
        """
        Per-atom distortion of bond vectors relative to a reference.

        Bond vectors of self and of the reference (see bonds) are compared
        for each atom of self and its nearest reference atom.  Bonds are
        paired greedily: at each step the closest remaining (self, ref)
        pair of bond vectors is matched and removed.

        Parameters
        ----------
        reference : reference structure.  It is tiled 3x3x3 first when its
                    volume is less than 1.1 times that of self.
        neighbors : number of bonds per atom to compare

        Returns
        -------
        distortion : array shaped like self's bond vectors, holding the
                     differences of matched bond vectors
        magnitude  : (natoms,) sum of the norms of those differences
        """
        if self.dim!=3:
            self.error('distortion is currently only implemented for 3 dimensions')
        #end if
        if reference.volume()/self.volume() < 1.1:
            ref = reference.tile((3,3,3))
            ref.recenter(reference.center)
        else:
            ref = reference
        #end if
        rbi,rbc,rbl,rbv =  ref.bonds(neighbors,vectors=True)
        sbi,sbc,sbl,sbv = self.bonds(neighbors,vectors=True)
        # NOTE(review): nn indexes reference.pos but is used to index rbv
        # from ref; when ref is a tiling this assumes the first
        # len(reference.pos) atoms of ref correspond to reference -- confirm
        nn = nearest_neighbors(1,reference.pos,self.pos).ravel()
        distortion = empty(sbv.shape)
        magnitude  = empty((len(self.pos),))
        for i in range(len(self.pos)):
            ir = nn[i]
            bonds  = sbv[i]
            rbonds = rbv[ir]
            ib  = empty((neighbors,),dtype=int)
            ibr = empty((neighbors,),dtype=int)
            r  = list(range(neighbors))
            rr = list(range(neighbors))
            # greedy matching: repeatedly take the closest unmatched pair
            for n in range(neighbors):
                mindist = 1e99
                ibmin  = -1
                ibrmin = -1
                for nb in r:
                    for nbr in rr:
                        d = norm(bonds[nb]-rbonds[nbr])
                        if d<mindist:
                            mindist=d
                            ibmin=nb
                            ibrmin=nbr
                        #end if
                    #end for
                #end for
                ib[n]=ibmin
                ibr[n]=ibrmin
                r.remove(ibmin)
                rr.remove(ibrmin)
                #end for
            #end for
            d = bonds[ib]-rbonds[ibr]
            distortion[i] = d
            magnitude[i] = (sqrt((d**2).sum(axis=1))).sum()
        #end for
        return distortion,magnitude
    #end def distortion
4362
4363
4364    # test needed
4365    def bond_compression(self,reference,neighbors):
4366        ref = reference
4367        rbi,rbc,rbl =  ref.bonds(neighbors)
4368        sbi,sbc,sbl = self.bonds(neighbors)
4369        bondlen = rbl.mean()
4370        return abs(1.-sbl/bondlen).max(axis=1)
4371    #end def bond_compression
4372
4373
4374    # test needed
4375    def boundary(self,dims=(0,1,2),dtol=1e-6):
4376        dim_eff = len(dims)
4377        natoms,dim = self.pos.shape
4378        bdims = array(dim*[False])
4379        for d in dims:
4380            bdims[d] = True
4381        #end for
4382        p = self.pos[:,bdims]
4383        indices = convex_hull(p,dim_eff,dtol)
4384        return indices
4385    #end def boundary
4386
4387
    def embed(self,small,dims=(0,1,2),dtol=1e-6,utol=1e-6):
        """
        Embed a (copy of a) smaller structure into this one.

        The boundary atoms of small are matched to their nearest atoms in
        self, and small is shifted by the mean offset of those matches.
        Each atom of small then replaces its nearest atom in self; if that
        atom was already replaced, the small atom is appended instead.
        Atoms of self inside small's cell that were not replaced are
        removed.  self is recentered on small's center during the
        operation and restored to its original center at the end.

        Parameters
        ----------
        small : structure to embed (copied; the input is not modified)
        dims  : coordinate directions used to find small's boundary atoms
        dtol  : tolerance passed to small.boundary
        utol  : tolerance passed to small.inside

        Returns
        -------
        dmax : largest distance between a shifted boundary atom of small
               and its matched atom in self (a measure of the fit)
        """
        small = small.copy()
        small.recenter()
        center = array(self.center)
        self.recenter(small.center)
        bind = small.boundary(dims,dtol)
        bpos = small.pos[bind]
        belem= small.elem[bind]
        # nearest atom of self for each boundary atom of small
        nn = nearest_neighbors(1,self.pos,bpos).ravel()
        mpos = self.pos[nn]
        dr = (mpos-bpos).mean(0)
        for i in range(len(bpos)):
            bpos[i]+=dr
        #end for
        dmax = sqrt(((mpos-bpos)**2).sum(1)).max()
        for i in range(len(small.pos)):
            small.pos[i]+=dr
        #end for
        ins,surface = small.inside(self.pos,tol=utol,separate=True)
        replaced = empty((len(self.pos),),dtype=bool)
        replaced[:] = False
        inside = replaced.copy()
        inside[ins] = True
        # nearest atom of self for every atom of small
        nn = nearest_neighbors(1,self.pos,small.pos).ravel()
        elist = list(self.elem)
        plist = list(self.pos)
        pos  = small.pos
        elem = small.elem
        for i in range(len(pos)):
            n = nn[i]
            if not replaced[n]:
                elist[n] = elem[i]
                plist[n] = pos[i]
                replaced[n] = True
            else:
                elist.append(elem[i])
                plist.append(pos[i])
            #end if
        #end for
        # drop atoms of self inside small that were not replaced; popping
        # in descending order keeps the remaining indices valid (appended
        # atoms lie beyond len(self.pos) and are unaffected)
        remove = arange(len(self.pos))[inside & logical_not(replaced)]
        remove.sort()
        remove = flipud(remove)
        for i in remove:
            elist.pop(i)
            plist.pop(i)
        #end for
        self.set_elem(elist)
        self.pos  = array(plist)
        self.recenter(center)
        return dmax
    #end def embed
4439
4440
4441    # test needed
4442    def shell(self,cell,neighbors,direction='in'):
4443        if self.dim!=3:
4444            self.error('shell is currently only implemented for 3 dimensions')
4445        #end if
4446        dd = {'in':equate,'out':negate}
4447        dir = dd[direction]
4448        natoms,dim=self.pos.shape
4449        ncells=3**3
4450        ntile = ncells*natoms
4451        pos = empty((ntile,dim))
4452        ind = empty((ntile,),dtype=int)
4453        oind = list(range(natoms))
4454        for nt in range(ncells):
4455            n=nt*natoms
4456            ind[n:n+natoms]=oind[:]
4457            pos[n:n+natoms]=self.pos[:]
4458        #end for
4459        nt=0
4460        for k in -1,0,1:
4461            for j in -1,0,1:
4462                for i in -1,0,1:
4463                    iv = array([[i,j,k]])
4464                    v = dot(iv,self.axes)
4465                    for d in range(dim):
4466                        ns = nt*natoms
4467                        ne = ns+natoms
4468                        pos[ns:ne,d] += v[0,d]
4469                    #end for
4470                    nt+=1
4471                #end for
4472            #end for
4473        #end for
4474
4475        inside = empty(ntile,)
4476        inside[:]=False
4477        ins = cell.inside(pos)
4478        inside[ins]=True
4479
4480        iishell = set()
4481        nn = nearest_neighbors(neighbors,pos)
4482        for ii in range(len(nn)):
4483            for jj in nn[ii]:
4484                in1 = inside[ii]
4485                in2 = inside[jj]
4486                if dir(in1 and not in2):
4487                    iishell.add(ii)
4488                #end if
4489                if dir(in2 and not in1):
4490                    iishell.add(jj)
4491                #end if
4492            #end if
4493        #end if
4494        ishell = ind[list(iishell)]
4495        return ishell
4496    #end def shell
4497
4498
4499    def interpolate(self,other,images,min_image=True,recenter=True,match_com=False,chained=False):
4500        s1 = self.copy()
4501        s2 = other.copy()
4502        s1.remove_folded()
4503        s2.remove_folded()
4504        if s2.units!=s1.units:
4505            s2.change_units(s1.units)
4506        #end if
4507        if (s1.elem!=s2.elem).any():
4508            self.error('cannot interpolate structures, atoms do not match\n  atoms1: {0}\n  atoms2: {1}'.format(s1.elem,s2.elem))
4509        #end if
4510        structures = []
4511        npath = images+2
4512        c1   = s1.center
4513        c2   = s2.center
4514        ax1  = s1.axes
4515        ax2  = s2.axes
4516        pos1 = s1.pos
4517        pos2 = s2.pos
4518        min_image &= abs(ax1-ax2).max()<1e-6
4519        if min_image:
4520            dp = self.min_image_vectors(pos1,pos2,ax1,pairs=False)
4521            pos2 = pos1 + dp
4522        #end if
4523        if match_com:
4524            com1 = pos1.mean(axis=0)
4525            com2 = pos2.mean(axis=1)
4526            dcom = com1-com2
4527            for n in range(len(pos2)):
4528                pos2[n] += dcom
4529            #end for
4530            if chained:
4531                other.pos = pos2
4532            #end if
4533        #end if
4534        for n in range(npath):
4535            f1 = 1.-float(n)/(npath-1)
4536            f2 = 1.-f1
4537            center = f1*c1   + f2*c2
4538            axes   = f1*ax1  + f2*ax2
4539            pos    = f1*pos1 + f2*pos2
4540            s = s1.copy()
4541            s.reset_axes(axes)
4542            s.center = center
4543            s.pos    = pos
4544            if recenter:
4545                s.recenter()
4546            #end if
4547            structures.append(s)
4548        #end for
4549        return structures
4550    #end def interpolate
4551
4552
4553    # returns madelung potential constant v_M
4554    #   see equation 7 in PRB 78 125106 (2008)
    def madelung(self,axes=None,tol=1e-10):
        """
        Return the Madelung potential constant v_M of the lattice
        (see equation 7 in PRB 78 125106 (2008)).

        Lattice vectors are converted to Bohr before summing.  The value is
        the constant -pi/(k**2*V) - 2*k/sqrt(pi) plus a real space sum of
        erfc(k*R)/R over lattice vectors R and a reciprocal space sum of
        4*pi/V*exp(-G**2/(4*k**2))/G**2 over reciprocal vectors G.  Both
        sums run over cubes of integer lattice indices of half-width
        n=1,2,...,20 with the origin term left out.  Iteration stops when
        two successive totals differ by less than tol.

        Parameters
        ----------
        axes : lattice vectors as rows, in the units of self; defaults to
               self.axes.  The result is stored in self.Vmadelung only
               when axes is None.
        tol  : convergence tolerance on the lattice sums.
        """
        if self.dim!=3:
            self.error('madelung is currently only implemented for 3 dimensions')
        #end if
        # columns of a are the lattice vectors
        if axes is None:
            a = self.axes.T.copy()
        else:
            a = axes.T.copy()
        #end if
        if self.units!='B':
            a = convert(a,self.units,'B')
        #end if
        volume = abs(det(a))
        # columns of b are the reciprocal lattice vectors (b^T a = 2*pi*I)
        b = 2*pi*inv(a).T
        # splitting parameter: 8x the radius of a sphere with the cell volume
        rconv = 8*(3.*volume/(4*pi))**(1./3)
        kconv = 2*pi/rconv
        gconst = -1./(4*kconv**2)
        vmc = -pi/(kconv**2*volume)-2*kconv/sqrt(pi)

        nshells = 20
        vshell = [0.]   # vshell[n]: total of both sums over the cube of half-width n
        p = Sobj()
        m = Sobj()
        for n in range(1,nshells+1):
            # all integer index triples in [-n,n]^3, one per column
            i = mgrid[-n:n+1,-n:n+1,-n:n+1]
            i = i.reshape(3,(2*n+1)**3)
            R = sqrt((dot(a,i)**2).sum(0))
            G2 = (dot(b,i)**2).sum(0)

            # flat position of the (0,0,0) triple, skipped to avoid R=0, G=0
            izero = n + n*(2*n+1) + n*(2*n+1)**2

            p.R  = R[0:izero]
            p.G2 = G2[0:izero]
            m.R  = R[izero+1:]
            m.G2 = G2[izero+1:]
            domains = [p,m]
            vshell.append(0.)
            for d in domains:
                vshell[n] += (erfc(kconv*d.R)/d.R).sum() + 4*pi/volume*(exp(gconst*d.G2)/d.G2).sum()
            #end for
            if abs(vshell[n]-vshell[n-1])<tol:
                break
            #end if
        #end for
        vm = vmc + vshell[-1]

        if axes is None:
            self.Vmadelung = vm
        #end if
        return vm
    #end def madelung
4606
4607
4608    def makov_payne(self,q=1,eps=1.0,units='Ha',order=1):
4609        if order!=1:
4610            self.error('Only first order Makov-Payne correction is currently supported.')
4611        #end if
4612        if 'Vmadelung' not in self:
4613            vm = self.madelung()
4614        else:
4615            vm = self.Vmadelung
4616        #end if
4617        mp = -0.5*q**2*vm/eps
4618        if units!='Ha':
4619            mp = convert(mp,'Ha',units)
4620        #end if
4621        return mp
4622    #end def makov_payne
4623
4624
4625    def read(self,filepath,format=None,elem=None,block=None,grammar='1.1',cell='prim',contents=False):
4626        if os.path.exists(filepath):
4627            path,file = os.path.split(filepath)
4628            if format is None:
4629                if '.' in file:
4630                    name,format = file.rsplit('.',1)
4631                elif file.lower().endswith('poscar'):
4632                    format = 'poscar'
4633                else:
4634                    self.error('file format could not be determined\nunrecognized file: {0}'.format(filepath))
4635                #end if
4636            #end if
4637        elif not contents:
4638            self.error('file does not exist: {0}'.format(filepath))
4639        #end if
4640        if format is None:
4641            self.error('file format must be provided')
4642        #end if
4643        self.mag    = None
4644        self.frozen = None
4645        format = format.lower()
4646        if format=='xyz':
4647            self.read_xyz(filepath)
4648        elif format=='xsf':
4649            self.read_xsf(filepath)
4650        elif format=='poscar':
4651            self.read_poscar(filepath,elem=elem)
4652        elif format=='cif':
4653            self.read_cif(filepath,block=block,grammar=grammar,cell=cell)
4654        elif format=='fhi-aims':
4655            self.read_fhi_aims(filepath)
4656        else:
4657            self.error('cannot read structure from file\nunsupported file format: {0}'.format(format))
4658        #end if
4659        if self.has_axes():
4660            self.set_bconds('ppp')
4661        #end if
4662    #end def read
4663
4664
4665    def read_xyz(self,filepath):
4666        elem = []
4667        pos  = []
4668        if os.path.exists(filepath):
4669            lines = open(filepath,'r').read().strip().splitlines()
4670        else:
4671            lines = filepath.strip().splitlines() # "filepath" is file contents
4672        #end if
4673        if len(lines)>1:
4674            ntot = int(lines[0].strip())
4675            natoms = 0
4676            e = None
4677            p = None
4678            try:
4679                tokens = lines[1].split()
4680                if len(tokens)==4:
4681                    e = tokens[0]
4682                    p = array(tokens[1:],float)
4683                #end if
4684            except:
4685                None
4686            #end try
4687            if p is not None:
4688                elem.append(e)
4689                pos.append(p)
4690                natoms+=1
4691            #end if
4692            if len(lines)>2:
4693                for l in lines[2:]:
4694                    tokens = l.split()
4695                    if len(tokens)==4:
4696                        elem.append(tokens[0])
4697                        pos.append(array(tokens[1:],float))
4698                        natoms+=1
4699                        if natoms==ntot:
4700                            break
4701                        #end if
4702                    #end if
4703                #end for
4704            #end if
4705            if natoms!=ntot:
4706                self.error('xyz file read failed\nattempted to read file: {0}\nnumber of atoms expected: {1}\nnumber of atoms found: {2}'.format(filepath,ntot,natoms))
4707            #end if
4708        #end if
4709        self.dim   = 3
4710        self.set_elem(elem)
4711        self.pos   = array(pos)
4712        self.units = 'A'
4713    #end def read_xyz
4714
4715
4716    def read_xsf(self,filepath):
4717        if isinstance(filepath,XsfFile):
4718            f = filepath
4719        elif os.path.exists(filepath):
4720            f = XsfFile(filepath)
4721        else:
4722            f = XsfFile()
4723            f.read_text(filepath) # "filepath" is file contents
4724        #end if
4725        elem = []
4726        for n in f.elem:
4727            if isinstance(n,str):
4728                elem.append(n)
4729            else:
4730                elem.append(pt.simple_elements[n].symbol)
4731            #end if
4732        #end for
4733        self.dim   = 3
4734        self.units = 'A'
4735        self.reset_axes(f.primvec)
4736        self.set_elem(elem)
4737        self.pos = f.pos
4738    #end def read_xsf
4739
4740
4741    def read_poscar(self,filepath,elem=None):
4742        if os.path.exists(filepath):
4743            lines = open(filepath,'r').read().splitlines()
4744        else:
4745            lines = filepath.splitlines()  # "filepath" is file contents
4746        #end if
4747        nlines = len(lines)
4748        min_lines = 8
4749        if nlines<min_lines:
4750            self.error('POSCAR file must have at least {0} lines\n  only {1} lines found'.format(min_lines,nlines))
4751        #end if
4752        dim = 3
4753        scale = float(lines[1].strip())
4754        axes = empty((dim,dim))
4755        axes[0] = array(lines[2].split(),dtype=float)
4756        axes[1] = array(lines[3].split(),dtype=float)
4757        axes[2] = array(lines[4].split(),dtype=float)
4758        if scale<0.0:
4759            scale = abs(scale)/det(axes)
4760        #end if
4761        axes = scale*axes
4762        tokens = lines[5].split()
4763        if tokens[0].isdigit():
4764            counts = array(tokens,dtype=int)
4765            if elem is None:
4766                self.error('variable elem must be provided to read_poscar() to assign atomic species to positions for POSCAR format')
4767            elif len(elem)!=len(counts):
4768                self.error('one elem must be given for each element count in the POSCAR file\n  number of elem counts: {0}\n  number of elem given: {1}'.format(len(counts),len(elem)))
4769            #end if
4770            lcur = 6
4771        else:
4772            elem   = tokens
4773            counts = array(lines[6].split(),dtype=int)
4774            lcur = 7
4775        #end if
4776        species = elem
4777        # relabel species that have multiple occurances
4778        sset = set(species)
4779        for spec in sset:
4780            if species.count(spec)>1:
4781                cnt=0
4782                for n in range(len(species)):
4783                    specn = species[n]
4784                    if specn==spec:
4785                        cnt+=1
4786                        species[n] = specn+str(cnt)
4787                    #end if
4788                #end for
4789            #end if
4790        #end for
4791        elem = []
4792        for i in range(len(counts)):
4793            elem.extend(counts[i]*[species[i]])
4794        #end for
4795        self.dim = dim
4796        self.units = 'A'
4797        self.reset_axes(axes)
4798
4799        if lcur<len(lines) and len(lines[lcur])>0:
4800            c = lines[lcur].lower().strip()[0]
4801            lcur+=1
4802        else:
4803            return
4804        #end if
4805        selective_dynamics = c=='s'
4806        if selective_dynamics: # Selective dynamics
4807            if lcur<len(lines) and len(lines[lcur])>0:
4808                c = lines[lcur].lower().strip()[0]
4809                lcur+=1
4810            else:
4811                return
4812            #end if
4813        #end if
4814        cartesian = c=='c' or c=='k'
4815        npos = counts.sum()
4816        if lcur+npos>len(lines):
4817            return
4818        #end if
4819        spos = []
4820        for i in range(npos):
4821            spos.append(lines[lcur+i].split())
4822        #end for
4823        spos = array(spos)
4824        pos  = array(spos[:,0:3],dtype=float)
4825        if cartesian:
4826            pos = scale*pos
4827        else:
4828            pos = dot(pos,axes)
4829        #end if
4830        self.set_elem(elem)
4831        self.pos = pos
4832        if selective_dynamics or spos.shape[1]>3:
4833            move = array(spos[:,3:6],dtype=str)
4834            self.freeze(list(range(self.size())),directions=move=='F')
4835        #end if
4836    #end def read_poscar
4837
4838
4839    def read_cif(self,filepath,block=None,grammar='1.1',cell='prim'):
4840        axes,elem,pos,units = read_cif(filepath,block,grammar,cell,args_only=True)
4841        self.dim = 3
4842        self.set_axes(axes)
4843        self.set_elem(elem)
4844        self.pos = pos
4845        self.units = units
4846    #end def read_cif
4847
4848
4849    # test needed
4850    def read_fhi_aims(self,filepath):
4851        if os.path.exists(filepath):
4852            lines = open(filepath,'r').read().splitlines()
4853        else:
4854            lines = filepath.splitlines() # "filepath" is contents
4855        #end if
4856        axes = []
4857        pos  = []
4858        elem = []
4859        constrain_relax = []
4860        unit_pos = False
4861        for line in lines:
4862            ls = line.strip()
4863            if len(ls)>0 and ls[0]!='#':
4864                tokens = ls.split()
4865                t0 = tokens[0]
4866                if t0=='lattice_vector':
4867                    axes.append(tokens[1:])
4868                elif t0=='atom_frac':
4869                    pos.append(tokens[1:4])
4870                    elem.append(tokens[4])
4871                    unit_pos = True
4872                elif t0=='atom':
4873                    pos.append(tokens[1:4])
4874                    elem.append(tokens[4])
4875                elif t0=='constrain_relaxation':
4876                    constrain_relax.append(tokens[1])
4877                elif t0.startswith('initial'):
4878                    None
4879                else:
4880                    #None
4881                    self.error('unrecogonized or not yet supported token in fhi-aims geometry file: {0}'.format(t0))
4882                #end if
4883            #end if
4884        #end for
4885        axes = array(axes,dtype=float)
4886        pos  = array(pos,dtype=float)
4887        if unit_pos:
4888            pos = dot(pos,axes)
4889        #end if
4890        self.dim = 3
4891        if len(axes)>0:
4892            self.set_axes(axes)
4893        #end if
4894        self.set_elem(elem)
4895        self.pos   = pos
4896        self.units = 'A'
4897        if len(constrain_relax)>0:
4898            constrain_relax = array(constrain_relax)
4899            self.freeze(list(range(self.size())),directions=constrain_relax=='.true.')
4900        #end if
4901    #end def read_fhi_aims
4902
4903
4904    def write(self,filepath=None,format=None):
4905        if filepath is None and format is None:
4906            self.error('please specify either the filepath or format arguments to write()')
4907        elif format is None:
4908            if '.' in filepath:
4909                format = filepath.split('.')[-1]
4910            else:
4911                self.error('file format could not be determined\neither request the format directly with the format keyword or add a file format extension to the file name')
4912            #end if
4913        #end if
4914        format = format.lower()
4915        if format=='xyz':
4916            c = self.write_xyz(filepath)
4917        elif format=='xsf':
4918            c = self.write_xsf(filepath)
4919        elif format=='poscar':
4920            c = self.write_poscar(filepath)
4921        elif format=='fhi-aims':
4922            c = self.write_fhi_aims(filepath)
4923        else:
4924            self.error('file format {0} is unrecognized'.format(format))
4925        #end if
4926        return c
4927    #end def write
4928
4929
4930    def write_xyz(self,filepath=None,header=True,units='A'):
4931        if self.dim!=3:
4932            self.error('write_xyz is currently only implemented for 3 dimensions')
4933        #end if
4934        s = self.copy()
4935        s.change_units(units)
4936        c=''
4937        if header:
4938            c += str(len(s.elem))+'\n\n'
4939        #end if
4940        for i in range(len(s.elem)):
4941            e = s.elem[i]
4942            p = s.pos[i]
4943            c+=' {0:2} {1:12.8f} {2:12.8f} {3:12.8f}\n'.format(e,p[0],p[1],p[2])
4944        #end for
4945        if filepath!=None:
4946            open(filepath,'w').write(c)
4947        #end if
4948        return c
4949    #end def write_xyz
4950
4951
4952    def write_xsf(self,filepath=None):
4953        if self.dim!=3:
4954            self.error('write_xsf is currently only implemented for 3 dimensions')
4955        #end if
4956        s = self.copy()
4957        s.change_units('A')
4958        c  = ' CRYSTAL\n'
4959        c += ' PRIMVEC\n'
4960        for a in s.axes:
4961            c += '   {0:12.8f}  {1:12.8f}  {2:12.8f}\n'.format(*a)
4962        #end for
4963        c += ' PRIMCOORD\n'
4964        c += '   {0} 1\n'.format(len(s.elem))
4965        for i in range(len(s.elem)):
4966            e = s.elem[i]
4967            identified = e in pt.elements
4968            if not identified:
4969                if len(e)>2:
4970                    e = e[0:2]
4971                elif len(e)==2:
4972                    e = e[0:1]
4973                #end if
4974                identified = e in pt.elements
4975            #end if
4976            if not identified:
4977                self.error('{0} is not an element\nxsf file cannot be written'.format(e))
4978            #end if
4979            enum = pt.elements[e].atomic_number
4980            r = s.pos[i]
4981            c += '   {0:>3} {1:12.8f}  {2:12.8f}  {3:12.8f}\n'.format(enum,r[0],r[1],r[2])
4982        #end for
4983        if filepath!=None:
4984            open(filepath,'w').write(c)
4985        #end if
4986        return c
4987    #end def write_xsf
4988
4989
4990    def write_poscar(self,filepath=None):
4991        s = self.copy()
4992        s.change_units('A')
4993        species,species_count = s.order_by_species()
4994        poscar = PoscarFile()
4995        poscar.scale      = 1.0
4996        poscar.axes       = s.axes
4997        poscar.elem       = species
4998        poscar.elem_count = species_count
4999        poscar.coord      = 'cartesian'
5000        poscar.pos        = s.pos
5001        c = poscar.write_text()
5002        if filepath is not None:
5003            open(filepath,'w').write(c)
5004        #end if
5005        return c
5006    #end def write_poscar
5007
5008
5009    # test needed
5010    def write_fhi_aims(self,filepath=None):
5011        s = self.copy()
5012        s.change_units('A')
5013        c = ''
5014        c+='\n'
5015        for a in s.axes:
5016            c += 'lattice_vector   {0: 12.8f}  {1: 12.8f}  {2: 12.8f}\n'.format(*a)
5017        #end for
5018        c+='\n'
5019        for p,e in zip(self.pos,self.elem):
5020            c += 'atom_frac   {0: 12.8f}  {1: 12.8f}  {2: 12.8f}  {3}\n'.format(p[0],p[1],p[2],e)
5021        #end for
5022        if filepath!=None:
5023            open(filepath,'w').write(c)
5024        #end if
5025        return c
5026    #end def write_fhi_aims
5027
5028
5029    def plot2d_ax(self,ix,iy,*args,**kwargs):
5030        if self.dim!=3:
5031            self.error('plot2d_ax is currently only implemented for 3 dimensions')
5032        #end if
5033        iz = list(set([0,1,2])-set([ix,iy]))[0]
5034        ax = self.axes.copy()
5035        a  = self.axes[iz]
5036        dc = self.center-ax.sum(0)/2
5037        pp = array([0*a,ax[ix],ax[ix]+ax[iy],ax[iy],0*a])
5038        for i in range(len(pp)):
5039            pp[i]+=dc
5040            pp[i]-=dot(a,pp[i])/dot(a,a)*a
5041        #end for
5042        plot(pp[:,ix],pp[:,iy],*args,**kwargs)
5043    #end def plot2d_ax
5044
5045
5046    def plot2d_pos(self,ix,iy,*args,**kwargs):
5047        if self.dim!=3:
5048            self.error('plot2d_pos is currently only implemented for 3 dimensions')
5049        #end if
5050        iz = list(set([0,1,2])-set([ix,iy]))[0]
5051        pp = self.pos.copy()
5052        a = self.axes[iz]
5053        for i in range(len(pp)):
5054            pp[i] -= dot(a,pp[i])/dot(a,a)*a
5055        #end for
5056        plot(pp[:,ix],pp[:,iy],*args,**kwargs)
5057    #end def plot2d_pos
5058
5059
5060    def plot2d(self,pos_style='b.',ax_style='k-'):
5061        if self.dim!=3:
5062            self.error('plot2d is currently only implemented for 3 dimensions')
5063        #end if
5064        subplot(1,3,1)
5065        self.plot2d_ax(0,1,ax_style,lw=2)
5066        self.plot2d_pos(0,1,pos_style)
5067        title('a1,a2')
5068        subplot(1,3,2)
5069        self.plot2d_ax(1,2,ax_style,lw=2)
5070        self.plot2d_pos(1,2,pos_style)
5071        title('a2,a3')
5072        subplot(1,3,3)
5073        self.plot2d_ax(2,0,ax_style,lw=2)
5074        self.plot2d_pos(2,0,pos_style)
5075        title('a3,a1')
5076    #end def plot2d
5077
5078
5079    def plot2d_kax(self,ix,iy,*args,**kwargs):
5080        if self.dim!=3:
5081            self.error('plot2d_ax is currently only implemented for 3 dimensions')
5082        #end if
5083        iz = list(set([0,1,2])-set([ix,iy]))[0]
5084        ax = self.kaxes.copy()
5085        a  = ax[iz]
5086        dc = 0*a
5087        pp = array([0*a,ax[ix],ax[ix]+ax[iy],ax[iy],0*a])
5088        for i in range(len(pp)):
5089            pp[i]+=dc
5090            pp[i]-=dot(a,pp[i])/dot(a,a)*a
5091        #end for
5092        plot(pp[:,ix],pp[:,iy],*args,**kwargs)
5093    #end def plot2d_kax
5094
5095
5096    def plot2d_kp(self,ix,iy,*args,**kwargs):
5097        if self.dim!=3:
5098            self.error('plot2d_kp is currently only implemented for 3 dimensions')
5099        #end if
5100        iz = list(set([0,1,2])-set([ix,iy]))[0]
5101        pp = self.kpoints.copy()
5102        a = self.kaxes[iz]
5103        for i in range(len(pp)):
5104            pp[i] -= dot(a,pp[i])/dot(a,a)*a
5105        #end for
5106        plot(pp[:,ix],pp[:,iy],*args,**kwargs)
5107    #end def plot2d_kp
5108
5109
5110    def show(self,viewer='vmd',filepath='/tmp/tmp.xyz'):
5111        if self.dim!=3:
5112            self.error('show is currently only implemented for 3 dimensions')
5113        #end if
5114        self.write_xyz(filepath)
5115        os.system(viewer+' '+filepath)
5116    #end def show
5117
5118
5119    # minimal ASE Atoms-like interface to Structure objects for spglib
5120    def get_cell(self):
5121        return self.axes.copy()
5122    #end def get_cell
5123
5124    def get_scaled_positions(self):
5125        return self.pos_unit()
5126    #end def get_scaled_positions
5127
5128    def get_number_of_atoms(self):
5129        return len(self.elem)
5130    #end def get_number_of_atoms
5131
5132    def get_atomic_numbers(self):
5133        an = []
5134        for e in self.elem:
5135            iselem,esymb = is_element(e,symbol=True)
5136            if not iselem:
5137                self.error('Atomic symbol, {}, not recognized'.format(esymb))
5138            else:
5139                an.append(ptable[esymb].atomic_number)
5140            #end if
5141        #end for
5142        return array(an,dtype='intc')
5143    #end def get_atomic_numbers
5144
5145    def get_magnetic_moments(self):
5146        self.error('structure objects do not currently support magnetic moments')
5147    #end def get_magnetic_moments
5148
5149
5150    # direct spglib interface
5151    def spglib_cell(self):
5152        lattice   = self.axes.copy()
5153        positions = self.pos_unit()
5154        numbers   = self.get_atomic_numbers()
5155        cell = (lattice,positions,numbers)
5156        return cell
5157    #end def spglib_cell
5158
5159
5160    def get_symmetry(self,symprec=1e-5):
5161        cell = self.spglib_cell()
5162        return spglib.get_symmetry(cell,symprec=symprec)
5163    #end def get_symmetry
5164
5165
5166    def get_symmetry_dataset(self,symprec=1e-5, angle_tolerance=-1.0, hall_number=0):
5167        cell = self.spglib_cell()
5168        ds   = spglib.get_symmetry_dataset(
5169            cell,
5170            symprec         = symprec,
5171            angle_tolerance = angle_tolerance,
5172            hall_number     = hall_number,
5173            )
5174        return ds
5175    #end def get_symmetry
5176
5177
5178    # functions based on direct spglib interface
5179
5180    def symmetry_data(self,*args,**kwargs):
5181        ds = self.get_symmetry_dataset(*args,**kwargs)
5182        ds = obj(ds)
5183        for k,v in ds.items():
5184            if isinstance(v,dict):
5185                ds[k] = obj(v)
5186            #end if
5187        #end for
5188        return ds
5189    #end def symmetry_data
5190
5191
5192    def bravais_lattice_name(self,symm_data=None):
5193        if symm_data is None:
5194            symm_data = self.symmetry_data()
5195        #end if
5196        sg   = symm_data.number
5197        name = symm_data.international
5198        if not isinstance(sg,int) or sg<1 or sg>230:
5199            self.error('Invalid space group from spglib: {}'.format(sg))
5200        #end if
5201        if not isinstance(name,str):
5202            self.error('Invalid space group name from spglib: {}'.format(name))
5203        #end if
5204        bv = None
5205        if sg>=1 and sg<=2:
5206            bv = 'triclinic_'+name[0]
5207        elif sg>=3 and sg<=15:
5208            bv = 'monoclinic_'+name[0]
5209        elif sg>=16 and sg<=74:
5210            bv = 'orthorhombic_'+name[0]
5211        elif sg>=75 and sg<=142:
5212            bv = 'tetragonal_'+name[0]
5213        elif sg>=143 and sg<=167:
5214            bv = 'trigonal_'+name[0]
5215        elif sg>=168 and sg<=194:
5216            bv = 'hexagonal_'+name[0]
5217        elif sg>=195 and sg<=230:
5218            bv = 'cubic_'+name[0]
5219        #end if
5220        if bv is None:
5221            self.error('Bravais lattice could not be determined.\nSpace group number and name: {} {}'.format(sg,name))
5222        #end if
5223        return bv
5224    #end def bravais_lattice_name
5225
5226
5227    # test needed
5228    def space_group_operations(self,tol=1e-5,unit=False):
5229        ds = self.get_symmetry(symprec=tol)
5230        if ds is None:
5231            self.error('Symmetry search failed.\nspglib error message:\n{}'.format(spglib.get_error_message()))
5232        #end if
5233        ds = obj(ds)
5234        rotations    = ds.rotations
5235        translations = ds.translations
5236
5237        if not unit:
5238            # Transform to Cartesian
5239            axes = self.axes
5240            axinv = np.linalg.inv(axes)
5241            for n,(R,t) in enumerate(zip(rotations,translations)):
5242                rotations[n]    = np.dot(axinv,np.dot(R,axes))
5243                translations[n] = np.dot(t,axes)
5244            #end for
5245        #end if
5246
5247        return rotations,translations
5248    #end def space_group_operations
5249
5250
5251    def point_group_operations(self,tol=1e-5,unit=False):
5252        rotations,translations = self.space_group_operations(tol=tol,unit=unit)
5253        no_trans = translations.max(axis=1) < tol
5254        return rotations[no_trans]
5255    #end def point_group_operations
5256
5257
5258    def check_point_group_operations(self,rotations=None,tol=1e-5,unit=False,dtol=1e-5,ncheck=1,exit=False):
5259        if rotations is None:
5260            rotations = self.point_group_operations(tol=tol,unit=unit)
5261        #ned if
5262        r = self.pos
5263        if ncheck=='all':
5264            ncheck = len(r)
5265        #end if
5266        all_same = True
5267        for n in range(ncheck):
5268            rc = r[n]
5269            for R in rotations:
5270                rp = np.dot(r-rc,R)+rc
5271                dt = self.min_image_distances(r,rp)
5272                same = True
5273                for d in dt:
5274                    same &= dt.min()<dtol
5275                #end for
5276                all_same &= same
5277            #end for
5278        #end for
5279        if not all_same and exit:
5280            self.error('Point group operators are not all symmetries of the structure.')
5281        #end if
5282        return all_same
5283    #end def check_point_group_operations
5284
5285
    def equivalent_atoms(self):
        """
        Group atom indices into symmetry-equivalent sets.

        Uses the equivalent_atoms entry of self.symmetry_data() (presumably
        spglib's per-atom index of a representative equivalent atom; TODO
        confirm).  Each set must contain a single element symbol, otherwise
        self.error is called.  An element occurring in only one set is
        labeled by its symbol.  An element occurring in several sets gets
        numbered labels (e.g. O1, O2) in the iteration order of the
        intermediate obj.

        Returns an obj mapping each label to an int numpy array of atom
        indices.
        """
        ds = self.symmetry_data()

        # collect the element symbols found in each equivalence class
        species_by_specnum = obj()
        for e,sn in zip(self.elem,ds.equivalent_atoms):
            is_elem,es = is_element(e,symbol=True)
            if sn not in species_by_specnum:
                species_by_specnum[sn] = set()
            #end if
            species_by_specnum[sn].add(es)
        #end for
        # each class must hold exactly one element; keep that symbol
        for sn,sset in species_by_specnum.items():
            if len(sset)>1:
                self.error('Cannot find equivalent atoms.\nMultiple atomic species were marked as being equivalent.\nSpecies marked in this way: {}'.format(list(sset)))
            #end if
            species_by_specnum[sn] = list(sset)[0]
        #end for

        # give each unique species a unique label
        labels_by_specnum = obj()
        species_list = list(species_by_specnum.values())
        species_set  = set(species_list)
        # number of equivalence classes per element
        species_counts = obj()
        for s in species_set:
            species_counts[s] = species_list.count(s)
        #end for
        # running counter used to number classes of a repeated element
        spec_counts = obj()
        for sn,s in species_by_specnum.items():
            if species_counts[s]==1:
                labels_by_specnum[sn] = s
            else:
                if s not in spec_counts:
                    spec_counts[s] = 1
                else:
                    spec_counts[s] += 1
                #end if
                labels_by_specnum[sn] = s + str(spec_counts[s])
            #end if
        #end for

        # find indices for each unique species
        equiv_indices = obj()
        for s in labels_by_specnum.values():
            equiv_indices[s] = list()
        #end for
        for i,sn in enumerate(ds.equivalent_atoms):
            equiv_indices[labels_by_specnum[sn]].append(i)
        #end for
        for s,indices in equiv_indices.items():
            equiv_indices[s] = np.array(indices,dtype=int)
        #end for

        return equiv_indices
    #end def equivalent_atoms
5341
5342
    # operations to support restricted cases for RMG code

    # supported rmg lattices
    # Keys are Bravais lattice names as produced by bravais_lattice_name()
    # (compared against in rmg_lattice). Values are the strings emitted as
    # bravais_lattice_type in the RMG inputs built by rmg_transform.
    rmg_lattices = obj(
        orthorhombic_P = 'Orthorhombic Primitive',
        tetragonal_P   = 'Tetragonal Primitive',
        hexagonal_P    = 'Hexagonal Primitive',
        cubic_P        = 'Cubic Primitive',
        cubic_I        = 'Cubic Body Centered',
        cubic_F        = 'Cubic Face Centered',
        )
5354
5355    def rmg_lattice(self,allow_tile=False,all_results=False,ret_bravais=False,exit=False,warn=False):
5356        # output variables
5357        rmg_lattice = None
5358        tmatrix     = None
5359
5360        rmg_lattices = Structure.rmg_lattices
5361
5362        # represent current bravais lattice
5363        s = Structure(
5364            units = str(self.units),
5365            axes  = self.axes.copy(),
5366            elem  = ['H'],
5367            pos   = [[0,0,0]],
5368            )
5369
5370        # get standard primitive cell based on bravais lattice
5371        sp = s.primitive()
5372
5373        # get current bravais lattice name
5374        d   = sp.symmetry_data()
5375        bv  = sp.bravais_lattice_name(d)
5376        bvp = bv.rsplit('_',1)[0]+'_P'
5377        if bv in rmg_lattices:
5378            rmg_lattice = bv
5379        #end if
5380
5381        # attempt to get a valid cell by tiling if current one is not valid
5382        if rmg_lattice is None and bvp in rmg_lattices and allow_tile:
5383            spt = Structure(
5384                units = self.units,
5385                axes  = d.std_lattice,
5386                )
5387            tmatrix,valid_by_tiling = spt.tilematrix(sp,status=True)
5388            if not valid_by_tiling:
5389                tmatrix = None
5390            else:
5391                # apply tiling matrix
5392                st = sp.copy().tile(tmatrix)
5393                # update lattice type
5394                rmg_lattice,bv = st.rmg_lattice(ret_bravais=True)
5395            #end if
5396        #end if
5397
5398        if rmg_lattice is None and (exit or warn):
5399            msg = 'Bravais lattice is not supported by the RMG code.\nCell bravais lattice: {}\nLattices supported by RMG: {}'.format(bv,list(sorted(rmg_lattices.keys())))
5400            if exit:
5401                self.error(msg)
5402            elif warn:
5403                self.warn(msg)
5404            #end if
5405        #end if
5406
5407        if all_results:
5408            return rmg_lattice,tmatrix,s,sp,bv
5409        elif allow_tile:
5410            return rmg_lattice,tmatrix
5411        elif ret_bravais:
5412            return rmg_lattice,bv
5413        else:
5414            return rmg_lattice
5415        #end if
5416    #end def rmg_lattice
5417
5418
5419    def rmg_transform(self,allow_tile=False,allow_general=False,all_results=False):
5420        rmg_lattice,tmatrix,s,sp,bv = self.rmg_lattice(
5421            allow_tile  = allow_tile,
5422            exit        = not allow_general,
5423            all_results = True,
5424            )
5425        if rmg_lattice is None and allow_general:
5426            s_trans    = self.copy()
5427            rmg_inputs = obj()
5428            R          = None
5429            tmatrix    = None
5430        else:
5431            s_trans = self.copy()
5432            R = np.dot(np.linalg.inv(s.axes),sp.axes)
5433            s_trans.matrix_transform(R.T)
5434            if tmatrix is not None:
5435                s_trans = s_trans.tile(tmatrix)
5436            #end if
5437            if s_trans.units=='A':
5438                rmg_units = 'Angstrom'
5439            elif s_trans.units=='B':
5440                rmg_units = 'Bohr'
5441            else:
5442                self.error('Unrecognized length units in structure "{}"'.format(s_trans.units))
5443            #end if
5444            bl_type = self.rmg_lattices[rmg_lattice]
5445            axes = s_trans.axes
5446            if bl_type=='Cubic Primitive':
5447                a = axes[0,0]
5448                b = a
5449                c = a
5450            elif bl_type=='Tetragonal Primitive':
5451                a = axes[0,0]
5452                b = a
5453                c = axes[2,2]
5454            elif bl_type=='Orthorhombic Primitive':
5455                a = axes[0,0]
5456                b = axes[1,1]
5457                c = axes[2,2]
5458            elif bl_type=='Cubic Body Centered':
5459                a = np.linalg.norm(axes[0])*2/np.sqrt(3.)
5460                b = a
5461                c = a
5462            elif bl_type=='Cubic Face Centered':
5463                a = np.linalg.norm(axes[0])*2/np.sqrt(2.)
5464                b = a
5465                c = a
5466            elif bl_type=='Hexagonal Primitive':
5467                a = np.linalg.norm(axes[0])
5468                b = a
5469                c = np.linalg.norm(axes[2])
5470            else:
5471                self.error('Unrecognized RMG bravais_lattice_type "{}"'.format(bl_type))
5472            #end if
5473            rmg_inputs = obj(
5474                bravais_lattice_type = bl_type,
5475                a_length             = a,
5476                b_length             = b,
5477                c_length             = c,
5478                length_units         = rmg_units,
5479                )
5480        #end if
5481        if not all_results:
5482            return s_trans,rmg_inputs
5483        else:
5484            return s_trans,rmg_inputs,R,tmatrix,bv
5485        #end if
5486    #end def rmg_transform
5487
#end class Structure
# Module-level call made once the Structure class body is complete.
# NOTE(review): set_operations is defined outside this view; presumably it
# installs class-level operation tables used by Structure — confirm there.
Structure.set_operations()
5490
5491
5492#======================#
5493#  SeeK-path functions #
5494#======================#
5495
5496# installation instructions for seekpath interface
5497#
5498#  installation of seekpath
5499#    pip install seekpath
5500import itertools
5501from periodic_table import pt as ptable
5502try:
5503    from numpy import array_equal
5504except:
5505    array_equal = unavailable('numpy','array_equal')
5506#end try
# Import seekpath only if a new enough version is installed. Any failure
# (missing package, unparseable or too-old version) replaces
# get_explicit_k_path with an `unavailable` placeholder instead of aborting
# the module import.
try:
    import seekpath
    from seekpath import get_explicit_k_path
    version = seekpath.__version__

    # require a numeric major.minor.patch version string
    try:
        version = [int(i) for i in version.split('.')]
        if len(version) < 3:
            raise ValueError
        #end if
    except ValueError:
        raise ValueError("Unable to parse version number")
    #end try
    # the threshold must match the requirement stated in the message
    if tuple(version) < (1, 8, 4):
        raise ValueError("Invalid seekpath version, need >= 1.8.4")
    #end if
    del version
    del seekpath
except Exception:
    get_explicit_k_path = unavailable('seekpath','get_explicit_k_path')
#end try
5528
def _getseekpath(
    structure          = None,
    with_time_reversal = False,
    recipe             = 'hpkot',
    reference_distance = 0.025,
    threshold          = 1E-7,
    symprec            = 1E-5,
    angle_tolerance    = 1.0,
    ):
    '''
    Run seekpath's get_explicit_k_path on a Structure.

    The structure (or its folded structure, if it has one) is copied,
    converted to Angstrom, and passed to seekpath as the cell tuple
    (axes, scaled positions, atomic numbers).

    NOTE(review): with_time_reversal, recipe, reference_distance, threshold,
    symprec and angle_tolerance are accepted but NOT forwarded, so seekpath's
    own defaults are used. Forwarding them would change the paths returned
    to existing callers; confirm the intended values before wiring them
    through.

    Returns
    -------
    The dict returned by get_explicit_k_path.

    Raises
    ------
    TypeError if structure is not a Structure.
    '''
    if not isinstance(structure, Structure):
        raise TypeError('structure is not of type Structure\ntype received: {0}'.format(structure.__class__.__name__))
    #end if
    if structure.has_folded():
        structure = structure.folded_structure
    #end if
    structure = structure.copy()
    # compare string values; identity ('is not') on str literals is unreliable
    if structure.units != 'A':
        structure.change_units('A')
    #end if
    axes       = structure.axes
    unit_pos   = structure.get_scaled_positions()
    atomic_num = structure.get_atomic_numbers()
    cell = (axes,unit_pos,atomic_num)
    return get_explicit_k_path(cell)
#end def _getseekpath
5554
def get_conventional_cell(
    structure       = None,
    symprec         = 1E-5,
    angle_tolerance = 1.0,
    seekpathout     = None,
    ):
    '''
    Build the seekpath standard conventional cell of a structure.

    Parameters
    ----------
    structure       : Structure to standardize. Its background_charge is
                      always read.
    symprec, angle_tolerance : forwarded to _getseekpath when seekpathout
                      is None.
    seekpathout     : precomputed _getseekpath output.

    Returns
    -------
    {'structure': Structure} in Angstrom, built from the 'conv_lattice',
    'conv_types' and 'conv_positions' seekpath entries.

    Raises
    ------
    ValueError if the scaled background charge is not (nearly) an integer.
    '''
    if seekpathout is None:
        seekpathout = _getseekpath(structure=structure, symprec = symprec, angle_tolerance=angle_tolerance)
    #end if
    axes        = seekpathout['conv_lattice']
    enumbers    = seekpathout['conv_types']
    posd        = seekpathout['conv_positions']
    volfac      = seekpathout['volume_original_wrt_conv']
    # NOTE(review): seekpath documents this as V_original/V_conv; the
    # conventional-cell charge would then be background_charge/volfac.
    # Confirm the direction of this scaling.
    bcharge     = structure.background_charge*volfac
    pos         = dot(posd,axes)
    # Map atomic numbers to symbols. Building the array from the symbols lets
    # numpy size the string dtype; a bare dtype='str' array is one character
    # wide and would truncate e.g. 'Si' to 'S'.
    symbol_of = dict()
    for symbol,element in ptable.elements.items():
        symbol_of[element.atomic_number] = symbol
    #end for
    elem        = array([symbol_of.get(int(z),'') for z in enumbers], dtype=str)
    # round, not int(): int() truncates 3.9999999 to 3 and would fail spuriously
    if abs(bcharge-round(bcharge)) > 1E-6:
        raise ValueError("Invalid background charge for conventional structure")
    #end if
    return {'structure': Structure(axes=axes, elem=elem, pos=pos, background_charge = bcharge, units='A')}
#end def get_conventional_cell
5580
def get_primitive_cell(
    structure       = None,
    symprec         = 1E-5,
    angle_tolerance = 1.0,
    seekpathout     = None,
    ):
    '''
    Build the seekpath standard primitive cell of a structure.

    Parameters
    ----------
    structure       : Structure to standardize. Its background_charge is
                      always read.
    symprec, angle_tolerance : forwarded to _getseekpath when seekpathout
                      is None.
    seekpathout     : precomputed _getseekpath output.

    Returns
    -------
    {'structure': Structure in Angstrom built from the 'primitive_*'
     seekpath entries,
     'T': seekpath's 'primitive_transformation_matrix'}
    '''
    if seekpathout is None:
        seekpathout = _getseekpath(structure = structure, symprec = symprec, angle_tolerance=angle_tolerance)
    #end if
    axes        = seekpathout['primitive_lattice']
    enumbers    = seekpathout['primitive_types']
    posd        = seekpathout['primitive_positions']
    volfac      = seekpathout['volume_original_wrt_prim']
    # NOTE(review): seekpath documents this as V_original/V_prim; the
    # primitive-cell charge would then be background_charge/volfac.
    # Confirm the direction of this scaling.
    bcharge     = structure.background_charge*volfac
    pos         = dot(posd,axes)
    # Map atomic numbers to symbols. Casting the integer array to str sizes
    # the dtype by digit count, so e.g. Z=3 gives '<U1' and 'Li' becomes 'L'.
    # Unknown numbers keep their numeric string, as before.
    symbol_of = dict()
    for symbol,element in ptable.elements.items():
        symbol_of[element.atomic_number] = symbol
    #end for
    elem        = array([symbol_of.get(int(z),str(z)) for z in enumbers], dtype=str)
    return {'structure' : Structure(axes=axes, elem=elem, pos=pos, background_charge=bcharge, units='A'),
            'T'         : seekpathout['primitive_transformation_matrix']}
#end def get_primitive_cell
5604
def get_kpath(
    structure          = None,
    check_standard     = True,
    with_time_reversal = False,
    recipe             = 'hpkot',
    reference_distance = 0.025,
    threshold          = 1E-7,
    symprec            = 1E-5,
    angle_tolerance    = 1.0,
    seekpathout        = None,
    ):
    '''
    Return the seekpath band-structure k-path of a structure.

    Parameters
    ----------
    structure      : Structure whose k-path is wanted.
    check_standard : if True, error unless the structure's axes (in Angstrom)
                     match seekpath's 'primitive_lattice'.
    seekpathout    : precomputed _getseekpath output. Computed when None,
                     forwarding the remaining keyword arguments.

    Returns
    -------
    dict with explicit k-points in absolute coordinates (1/Angstrom and
    1/Bohr), relative coordinates, their labels, the path segments, the
    linear path coordinates and the coordinates of the labelled points.
    '''
    if seekpathout is None:
        seekpathout = _getseekpath(structure=structure, symprec = symprec, angle_tolerance=angle_tolerance,
                                   recipe=recipe, reference_distance=reference_distance, threshold=threshold,
                                   with_time_reversal=with_time_reversal)
    #end if
    if check_standard:
        structure = structure.copy()
        structure.change_units('A')
        axes    = structure.axes
        primlat = seekpathout['primitive_lattice']
        if not isclose(primlat, axes).all():
            Structure.class_error('Input lattice is not the standard primitive lattice. If you like otherwise, set check_standard=False.')
        #end if
    #end if
    # k has units of inverse length. 1 A = convert(1,'A','B') Bohr, so
    # 1/A = 1/convert(1,'A','B') per Bohr (~0.529).
    inverse_A_to_inverse_B = 1.0/convert(1.0,'A','B')
    return {'explicit_kpoints_abs_inv_A'      : seekpathout['explicit_kpoints_abs'],
            'explicit_kpoints_abs_inv_B'      : seekpathout['explicit_kpoints_abs']*inverse_A_to_inverse_B,
            'explicit_kpoints_rel'      : seekpathout['explicit_kpoints_rel'],
            'explicit_kpoints_labels'   : seekpathout['explicit_kpoints_labels'],
            'path'                      : seekpathout['path'],
            'explicit_path_linearcoords': seekpathout['explicit_kpoints_linearcoord'],
            'point_coords'              : seekpathout['point_coords']}
#end def get_kpath
5639
def get_symmetry(
    structure       = None,
    symprec         = 1E-5,
    angle_tolerance = 1.0,
    seekpathout     = None,
    ):
    '''
    Summarize the symmetry information reported by seekpath.

    Parameters
    ----------
    structure       : Structure analysed when seekpathout is None.
    symprec, angle_tolerance : forwarded to _getseekpath.
    seekpathout     : precomputed _getseekpath output.

    Returns
    -------
    dict with keys
      'sgint'          : international space group symbol
      'bravais'        : Bravais lattice
      'inv_sym_exists' : whether inversion symmetry is present
      'sgnum'          : space group number
    '''
    if seekpathout is None:
        seekpathout = _getseekpath(structure = structure, symprec = symprec, angle_tolerance=angle_tolerance)
    #end if
    # (output key, seekpath key), in output order
    key_map = (
        ('sgint'         , 'spacegroup_international'),
        ('bravais'       , 'bravais_lattice'),
        ('inv_sym_exists', 'has_inversion_symmetry'),
        ('sgnum'         , 'spacegroup_number'),
        )
    return {out_key: seekpathout[src_key] for out_key,src_key in key_map}
#end def get_symmetry
5656
def get_structure_with_bands(
    cell               = 0,
    structure          = None,
    with_time_reversal = False,
    reference_distance = 0.025,
    threshold          = 1E-7,
    symprec            = 1E-5,
    angle_tolerance    = 1.0,
    ):
    '''
    Return a Structure (in Angstrom) carrying the seekpath band k-points.

    Parameters
    ----------
    cell : 0 uses the input structure, 1 its seekpath conventional cell,
           2 its seekpath primitive cell.
    structure : Structure to start from.
    remaining : forwarded to get_kpath (and, for symprec/angle_tolerance,
                to the cell construction).

    Returns
    -------
    New Structure with the chosen cell's axes, elements, positions and
    background charge, and kpoints set to the explicit relative k-points of
    the path.
    '''
    if cell == 0:
        # Use the input structure. Convert to Angstrom because the result
        # below is labelled units='A'.
        struct_band = structure.copy()
        struct_band.change_units('A')
    elif cell == 1:
        # use the conventional structure (already in Angstrom)
        struct_band = get_conventional_cell(structure=structure, symprec=symprec, angle_tolerance=angle_tolerance)['structure']
    elif cell == 2:
        # use the primitive structure (already in Angstrom)
        struct_band = get_primitive_cell(structure=structure, symprec=symprec, angle_tolerance=angle_tolerance)['structure']
    else:
        Structure.class_error('Invalid cell type')
    #end if
    kpath = get_kpath(
        structure          = struct_band,
        check_standard     = False,
        with_time_reversal = with_time_reversal,
        reference_distance = reference_distance,
        threshold          = threshold,
        symprec            = symprec,
        angle_tolerance    = angle_tolerance,
        )
    return Structure(axes              = struct_band.axes,
                     elem              = struct_band.elem,
                     pos               = struct_band.pos,
                     background_charge = struct_band.background_charge,
                     kpoints           = kpath['explicit_kpoints_rel'],
                     units             = 'A')
#end def get_structure_with_bands
5686
# test needed
def get_band_tiling(
    structure      = None,
    check_standard = True,
    use_ktol       = True,
    kpoints_label  = None,
    kpoints_rel    = None,
    max_volfac     = 20,
    min_volfac     = 0,
    target_volfac  = None,
    ):
    '''
    Find a tiling matrix whose supercell is commensurate with a set of
    k-points, preferring the most cube-like supercell.

    The differences between every pair of requested k-points ("alphas",
    crystal coordinates) are computed. Upper triangular tiling matrices in
    Hermite normal form (Phys. Rev. B 92, 184301 (2015)) that map all alphas
    onto integer vectors are enumerated. Each candidate is refined with
    optimal_tilematrix, and the one whose lattice deviates least from a cube
    (cube_deviation) is kept.

    Parameters
    ----------
    structure      : Structure to tile.
    check_standard : forwarded to get_kpath.
    use_ktol       : if True, an alpha component counts as lying on the
                     supercell reciprocal grid when it is within
                     0.25/max_volfac of a grid point. Otherwise an exact
                     match is required.
    kpoints_label  : labels of k-points on the seekpath path, e.g. ['GAMMA','K'].
    kpoints_rel    : k-points in crystal coordinates. Give exactly one of
                     kpoints_label and kpoints_rel.
    max_volfac     : largest volume multiplier searched.
    min_volfac     : smallest volume multiplier accepted.
    target_volfac  : request a specific volume multiplier. min_volfac and
                     max_volfac must then both be passed as None.

    Returns
    -------
    obj with
      mat   : tiling matrix as nested lists
      shift : grid shift (not implemented; always None)
      det   : determinant of mat
    '''

    def cube_deviation(axes):
        # Mean deviation of the six lengths |a_i +/- a_j| from the face
        # diagonal of a cube of equal volume, relative to that diagonal.
        a = axes
        volume = abs(dot(cross(axes[0,:], axes[1,:]), axes[2,:]))
        dc = volume**(1./3)*sqrt(2.)
        d1 = abs(norm(a[0]+a[1])-dc)
        d2 = abs(norm(a[1]+a[2])-dc)
        d3 = abs(norm(a[2]+a[0])-dc)
        d4 = abs(norm(a[0]-a[1])-dc)
        d5 = abs(norm(a[1]-a[2])-dc)
        d6 = abs(norm(a[2]-a[0])-dc)
        return (d1+d2+d3+d4+d5+d6)/(6*dc)
    #end def cube_deviation

    def cuboid_with_int_edges(vol):
        # All ordered integer edge triples [i,j,k] with i*j*k == vol.
        edges = []
        if isinstance(vol, int):
            divisors = [i for i in range(1, vol+1) if vol%i == 0]
            for i in divisors:
                for j in divisors:
                    for k in divisors:
                        if i*j*k == vol:
                            edges.append([i,j,k])
                        #end if
                    #end for
                #end for
            #end for
        else:
            Structure.class_error('Volume multiplier must be integer')
        #end if
        return edges
    #end def cuboid_with_int_edges

    def alphas_on_grid(alphas, divs):
        # Snap each alpha component to the nearest multiple of 1/divs,
        # keeping its sign.
        new_alphas = []
        for alpha in alphas:
            abs_alpha  = abs(alpha)
            sign_alpha = sign(alpha)
            new_alpha  = round(abs_alpha*divs)*1./divs*sign_alpha
            new_alphas.append(new_alpha)
        #end for
        return new_alphas
    #end def alphas_on_grid

    def find_alphas(structure, kpoints_label, kpoints_rel, check_standard):
        # Resolve the requested k-points against the seekpath path and return
        # all pairwise differences k_i - k_j (alphas) plus the first k-point.
        kpath       = get_kpath(structure = structure, check_standard = check_standard)
        kpath_label = array(kpath['explicit_kpoints_labels'])
        kpath_rel   = kpath['explicit_kpoints_rel']
        kpts        = dict()

        if kpoints_label is None:
            if kpoints_rel is None:
                Structure.class_error('Please define symbolic or crystal coordinates for kpoints. e.g. [\'GAMMA\', \'K\']  or [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]')
            else:
                for k in kpoints_rel:
                    kindex = isclose(kpath_rel,k, atol=1e-5).all(1)
                    if any(kindex):
                        kpts[kpath_label[kindex][0]] = array(k)
                    else:
                        Structure.class_error('{0} is not found in the kpath'.format(k))
                    #end if
                #end for
            #end if
        else:
            if kpoints_rel is not None:
                Structure.class_error('Both symbolic and crystal k-points are defined.')
            else:
                num_kpoints = 0
                for k in kpoints_label:
                    kindex = k == kpath_label
                    if any(kindex):
                        if k == '' or k is None:
                            k = '{0}'.format(num_kpoints)
                        #end if
                        kpts[k] = array(kpath_rel[kindex][0])
                    else:
                        Structure.class_error('{0} is not found in the kpath'.format(k))
                    #end if
                #end for
            #end if
        #end if
        alphas = array([x[0] - x[1] for x in itertools.combinations(kpts.values(),2)]) #Combinations of k_1 - k_2 in kpts list
        kpt0 = list(kpts.values())[0]
        return alphas, kpt0
    #end def find_alphas

    def find_vars(alphas, min_volfac, max_volfac, target_volfac, use_ktol):
        '''
        Find variables to generate possible smallest matrices in the Upper
        Triangular Hermite Normal Form, from PHYSICAL REVIEW B 92, 184301 (2015).
        Returns the list of [n1,n2,n3,g12,g13,g23,g123] sharing the smallest
        volume factor found, and vol_mul, the smallest multiplier that brings
        that volume factor into [min_volfac, max_volfac].
        '''
        # Resolve target_volfac first, because max_volfac is needed for ktol.
        if target_volfac is not None:
            if min_volfac is None and max_volfac is None:
                min_volfac = target_volfac
                max_volfac = target_volfac
            else:
                Structure.class_error('target_volfac and {min_volfac, max_volfac} cannot be defined together!')
            #end if
        #end if
        if use_ktol:
            ktol = 0.25/max_volfac
        else:
            ktol = 0.0
        #end if

        cur_volfac = 1e6
        mat_vars   = []
        for mvol in range(max_volfac, 0, -1):
            for n1, n2, n3 in cuboid_with_int_edges(mvol):
                voxel = array([1./n1, 1./n2, 1./n3]) # reciprocal grid voxel size
                # Distance of each |alpha| component to the nearest grid point.
                # A bare remainder would reject values just below a grid point.
                rem = []
                for alpha in alphas:
                    r = mod(abs(alpha), voxel)
                    rem.append(np.minimum(r, voxel-r))
                #end for
                # .all() on the 2-D result; builtin all() on it is ambiguous
                if not isclose(rem, 0.0, atol=ktol).all():
                    continue
                #end if
                g12    = np.gcd.reduce([n1,n2])
                g13    = np.gcd.reduce([n1,n3])
                g23    = np.gcd.reduce([n2,n3])
                g123   = np.gcd.reduce([n1, n2, n3])
                volfac = n1*n2*n3*g123//(g12*g13*g23)
                row    = [n1, n2, n3, g12, g13, g23, g123]
                if volfac < cur_volfac:
                    mat_vars   = [row]
                    cur_volfac = volfac
                elif volfac == cur_volfac:
                    mat_vars.append(row)
                #end if
            #end for
        #end for
        if not mat_vars:
            Structure.class_error('No tiling commensurate with the k-points was found up to max_volfac={0}. Change ktol (use_ktol) or increase max_volfac.'.format(max_volfac))
        #end if
        # smallest multiple of the minimal volume factor inside the window
        vol_mul = 1
        while cur_volfac*vol_mul < min_volfac:
            vol_mul += 1
        #end while
        if cur_volfac*vol_mul > max_volfac:
            Structure.class_error('Increase max_volfac or target_volfac!')
        #end if
        return mat_vars, vol_mul
    #end def find_vars

    def find_mats(mat_vars, alphas):
        '''
        Given the variables from find_vars, return every distinct upper
        triangular matrix (PHYSICAL REVIEW B 92, 184301 (2015)) that maps all
        grid-snapped alphas onto integer vectors.
        '''
        mats = []
        for n1, n2, n3, g12, g13, g23, g123 in mat_vars:
            # the snapped alphas depend only on (n1,n2,n3): compute once here
            new_alphas = alphas_on_grid(alphas, array([n1, n2, n3]))
            for p in range(0, g23):
                for q in range(0, g12//g123):
                    for r in range(g13*g23//g123):
                        temp_mat = [
                            [g123*n1//(g12*g13), q*g123*n2//(g12*g23), r*g123*n3//(g13*g23)],
                            [0, n2//g23, p*n3//g23],
                            [0,0,n3]]
                        # Commensurate when every T.alpha is an integer vector
                        # within 1e-6. Measure the distance to the nearest
                        # integer so values like 0.9999999 are accepted.
                        commensurate = True
                        for new_alpha in new_alphas:
                            t_alpha = abs(dot(temp_mat,new_alpha))
                            if (abs(t_alpha-np.rint(t_alpha)) > 1e-6).any():
                                commensurate = False
                                break
                            #end if
                        #end for
                        if commensurate and temp_mat not in mats:
                            mats.append(temp_mat)
                        #end if
                    #end for
                #end for
            #end for
        #end for
        return mats
    #end def find_mats

    def find_cubic_mat(mats, structure, mat_vol_mul):
        # Refine each candidate with optimal_tilematrix and keep the one whose
        # tiled lattice is closest to a cube.
        if len(mats) == 0:
            Structure.class_error('No tiling matrix commensurate with the requested k-points was found.')
        #end if
        axes           = structure.axes.copy()
        final_t        = None
        final_cubicity = 1e6
        for m in array(mats):
            s        = structure.copy()
            [m_t, r] = optimal_tilematrix(s.tile(m), volfac=mat_vol_mul)
            t          = dot(m_t, m)
            m_cubicity = cube_deviation(dot(t, axes))
            if m_cubicity < final_cubicity:
                final_cubicity = m_cubicity
                final_t        = t
            #end if
        #end for
        return final_t.tolist()
    #end def find_cubic_mat

    def find_shift(final_mat, structure, kpt0):
        # grid shift determination is not implemented yet
        return None
    #end def find_shift

    alphas, kpt0          = find_alphas(structure, kpoints_label, kpoints_rel, check_standard) # Wavevector differences
    mat_vars, mat_vol_mul = find_vars(alphas, min_volfac, max_volfac, target_volfac, use_ktol)  # Variables to construct upper triangular matrices
    mats                  = find_mats(mat_vars, alphas)                                         # Upper triangular matrices commensurate with alphas
    final_mat             = find_cubic_mat(mats, structure, mat_vol_mul)                        # Matrix giving the most cubic supercell
    shift                 = find_shift(final_mat, structure, kpt0)                              # Grid shift (not implemented)
    o = obj()
    o.mat   = final_mat
    o.shift = shift
    o.det   = det(final_mat)
    return o
#end def get_band_tiling
5942
def get_seekpath_full(
        structure      = None,
        seekpathout    = None,
        conventional   = False,
        primitive      = False,
        **kwargs
        ):
    '''
    Return the complete seekpath output as an obj, optionally with the
    standardized cells attached.

    Parameters
    ----------
    structure    : Structure to analyse (and to build cells from).
    seekpathout  : precomputed _getseekpath output. Computed from structure
                   and **kwargs when None.
    conventional : attach the conventional cell as res.conventional.
    primitive    : attach the primitive cell as res.primitive and its
                   transformation matrix as res.prim_tmatrix.

    Returns
    -------
    obj holding every seekpath entry; nested dicts are also converted to obj.
    '''
    raw = seekpathout if seekpathout is not None else _getseekpath(structure,**kwargs)
    res = obj(raw)
    # nested dictionaries become obj as well
    for key,val in list(res.items()):
        if isinstance(val,dict):
            res[key] = obj(val)
        #end if
    #end for
    if conventional:
        res.conventional = get_conventional_cell(structure,seekpathout=raw)['structure']
    #end if
    if primitive:
        prim_out = get_primitive_cell(structure,seekpathout=raw)
        res.primitive    = prim_out['structure']
        res.prim_tmatrix = prim_out['T']
    #end if
    return res
#end def get_seekpath_full
5970
# Namespace collecting the SeeK-path helper functions defined above,
# so callers can write e.g. skp.get_kpath(...).
skp = obj(
    _getseekpath             = _getseekpath,
    get_conventional_cell    = get_conventional_cell,
    get_primitive_cell       = get_primitive_cell,
    get_kpath                = get_kpath,
    get_symmetry             = get_symmetry,
    get_structure_with_bands = get_structure_with_bands,
    get_band_tiling          = get_band_tiling,
    )
5980
5981#==========================#
5982#  end SeeK-path functions #
5983#==========================#
5984
5985
5986
def interpolate_structures(struct1,struct2=None,images=None,min_image=True,recenter=True,match_com=False,repackage=False,chained=False):
    '''
    Interpolate between two structures, or along a chain of structures.

    Parameters
    ----------
    struct1   : Structure or physical system, or a list/tuple of them. For a
                list, consecutive pairs are interpolated and the results are
                joined without duplicating shared endpoints.
    struct2   : end Structure or physical system (unused for a list).
    images    : number of interpolation images (required).
    min_image, recenter, match_com : forwarded to Structure.interpolate.
    repackage : if True, wrap each result in a copy of the physical system
                that was supplied (struct1's preferred over struct2's).
    chained   : set on the recursive calls for the list form; not otherwise
                used.

    Returns
    -------
    list of structures, or of physical systems when repackage is True.
    '''
    if images is None:
        Structure.class_error('images must be provided','interpolate_structures')
    #end if

    # if a list of structures is provided,
    # interpolate between pairs in the chain of structures
    if isinstance(struct1,(list,tuple)):
        structures_in = struct1
        structures = []
        for n in range(len(structures_in)-1):
            struct1 = structures_in[n]
            struct2 = structures_in[n+1]
            structs = interpolate_structures(struct1,struct2,images,min_image,recenter,match_com,repackage,chained=True)
            if n==0:
                structures.append(structs[0])
            #end if
            structures.extend(structs[1:-1])
            if n==len(structures_in)-2:
                structures.append(structs[-1])
            #end if
        #end for
        return structures
    #end if

    if struct2 is None:
        Structure.class_error('struct2 must be provided when struct1 is not a list of structures','interpolate_structures')
    #end if

    # handle PhysicalSystem objects indirectly
    system1 = None
    system2 = None
    if not isinstance(struct1,Structure):
        system1 = struct1.copy()
        system1.remove_folded()
        struct1 = system1.structure
    #end if
    if not isinstance(struct2,Structure):
        system2 = struct2.copy()
        system2.remove_folded()
        struct2 = system2.structure
    #end if

    # perform the interpolation
    structures = struct1.interpolate(struct2,images,min_image,recenter,match_com)

    # repackage into physical system objects if requested
    if repackage:
        # identity test: '!=' may invoke a custom comparison on the system object
        if system1 is not None:
            system = system1
        elif system2 is not None:
            system = system2
        else:
            Structure.class_error('cannot repackage into physical systems since no system object was provided in place of a structure','interpolate_structures')
        #end if
        systems = []
        for s in structures:
            ps = system.copy()
            ps.structure = s
            systems.append(ps)
        #end for
        result = systems
    else:
        result = structures
    #end if

    return result
#end def interpolate_structures
6051
6052
# test needed
def structure_animation(filepath,structures,tiling=None):
    '''
    Write a sequence of structures to a multi-frame xyz file.

    Parameters
    ----------
    filepath   : output path. Must end with 'xyz'.
    structures : iterable of structures, each providing write_xyz().
    tiling     : if given, each structure is first tiled with s.tile(tiling).
    '''
    filename = os.path.basename(filepath)
    if not filename.endswith('xyz'):
        Structure.class_error('only xyz files are supported for now','structure_animation')
    #end if
    frames = []
    for s in structures:
        if tiling is not None:
            s = s.tile(tiling)
        #end if
        frames.append(s.write_xyz())
    #end for
    # the context manager guarantees the file is flushed and closed
    with open(filepath,'w') as f:
        f.write(''.join(frames))
    #end with
#end def structure_animation
6069
6070
6071
class DefectStructure(Structure):
    '''
    Structure subclass with helpers to carve out and compare defect regions.
    '''
    def __init__(self,*args,**kwargs):
        # A Structure given as the first positional argument is copied in via
        # transfer_from(...,copy=True); any other arguments are then ignored.
        # Otherwise construct as a plain Structure.
        if len(args)>0 and isinstance(args[0],Structure):
            self.transfer_from(args[0],copy=True)
        else:
            Structure.__init__(self,*args,**kwargs)
        #end if
    #end def __init__


    def defect_from_bond_compression(self,compression_cutoff,bond_eq,neighbors):
        '''
        Carve out the atoms in bonds whose length deviates from bond_eq by a
        relative amount greater than compression_cutoff.

        neighbors is forwarded to self.bonds. Returns the result of self.carve.
        '''
        bind,bcent,blens = self.bonds(neighbors)
        # relative bond-length deviation |l/l_eq - 1| above the cutoff
        ind = bind[ abs(blens/bond_eq - 1.) > compression_cutoff ]
        idefect = array(list(set(ind.ravel())))
        defect = self.carve(idefect)
        return defect
    #end def defect_from_bond_compression


    def defect_from_displacement(self,displacement_cutoff,reference):
        '''
        Carve out the atoms whose scalar displacement relative to reference
        exceeds displacement_cutoff. Returns the result of self.carve.
        '''
        displacement = self.scalar_displacement(reference)
        idefect = displacement > displacement_cutoff
        defect = self.carve(idefect)
        return defect
    #end def defect_from_displacement


    def compare(self,dist_cutoff,d1,d2=None):
        '''
        Match atoms between two defect structures.

        If only d1 is given, self is compared with d1. Each atom of the
        smaller structure is paired with its nearest neighbor in the larger
        one (via nearest_neighbors), and the pair counts as matched when
        their distance is below dist_cutoff.

        Returns
        -------
        Sobj with
          all_match    : both have equal atom counts and every atom matched
          natoms_match : number of matched atoms
          imatch1/2    : indices of matched atoms in d1 / d2
        '''
        # identity test: '==' may invoke a custom comparison on Structure objects
        if d2 is None:
            d2 = d1
            d1 = self
        #end if
        res = Sobj()
        natoms1 = len(d1.pos)
        natoms2 = len(d2.pos)
        if natoms1<natoms2:
            dsmall,dlarge = d1,d2
        else:
            dsmall,dlarge = d2,d1
        #end if
        # NOTE(review): presumably nn[i] is the index in dlarge of the atom
        # nearest to atom i of dsmall — confirm against nearest_neighbors.
        nn = nearest_neighbors(1,dlarge,dsmall)
        dist = dsmall.distances(dlarge[nn.ravel()])
        dmatch = dist<dist_cutoff
        ismall = array(list(range(len(dsmall.pos))))
        ismall = ismall[dmatch]
        ilarge = nn[ismall]
        if natoms1<natoms2:
            i1,i2 = ismall,ilarge
        else:
            i2,i1 = ismall,ilarge
        #end if
        natoms_match = dmatch.sum()
        res.all_match = natoms1==natoms2 and natoms1==natoms_match
        res.natoms_match = natoms_match
        res.imatch1 = i1
        res.imatch2 = i2
        return res
    #end def compare
#end class DefectStructure
6131
6132
6133
class Crystal(Structure):
    """
    Structure generated from a crystal lattice description.

    A crystal is described by a lattice system, a cell type (primitive or
    conventional), a centering, lattice constants and a basis of atoms.
    Named crystals in `known_crystals` (keyed by (name,cell)) supply
    defaults for any of these inputs, e.g. Crystal('diamond','prim').
    """
    # lattice constants required by each lattice system, in the order in
    # which they must be supplied through the `constants` input
    lattice_constants = obj(
        triclinic    = ['a','b','c','alpha','beta','gamma'],
        monoclinic   = ['a','b','c','beta'],
        orthorhombic = ['a','b','c'],
        tetragonal   = ['a','c'],
        hexagonal    = ['a','c'],
        cubic        = ['a'],
        rhombohedral = ['a','alpha']
        )

    lattices = list(lattice_constants.keys())

    centering_types = obj(
        primitive             = 'P',
        base_centered         = ('A','B','C'),
        face_centered         = 'F',
        body_centered         = 'I',
        rhombohedral_centered = 'R'
        )

    # centerings allowed for each lattice system (used in error messages)
    lattice_centerings = obj(
        triclinic    = ['P'],
        monoclinic   = ['P','A','B','C'],
        orthorhombic = ['P','C','I','F'],
        tetragonal   = ['P','I'],
        hexagonal    = ['P','R'],
        cubic        = ['P','I','F'],
        rhombohedral = ['P']
        )

    # extra lattice points (fractional, conventional axes) added for each
    # centering when a conventional cell is built
    centerings = obj(
        P = [],
        A = [[0,.5,.5]],
        B = [[.5,0,.5]],
        C = [[.5,.5,0]],
        F = [[0,.5,.5],[.5,0,.5],[.5,.5,0]],
        I = [[.5,.5,.5]],
        R = [[2./3, 1./3, 1./3],[1./3, 2./3, 2./3]]
        )

    cell_types = set(['primitive','conventional'])

    cell_aliases = obj(
        prim = 'primitive',
        conv = 'conventional'
        )
    # maps cell shape names (and each lattice system name) to a lattice system
    cell_classes = obj(
        sc  = 'cubic',
        bcc = 'cubic',
        fcc = 'cubic',
        hex = 'hexagonal'
        )
    # NOTE: class-body loop variables (lattice, name, cell, ...) remain as
    # class attributes after these loops run
    for lattice in lattices:
        cell_classes[lattice]=lattice
    #end for


    #helpful websites for structures
    #  wikipedia.org
    #  webelements.com
    #  webmineral.com
    #  springermaterials.com


    known_crystals = {
        ('diamond','fcc'):obj(
            lattice   = 'cubic',
            cell      = 'primitive',
            centering = 'F',
            constants = 3.57,
            units     = 'A',
            atoms     = 'C',
            basis     = [[0,0,0],[.25,.25,.25]]
            ),
        ('diamond','sc'):obj(
            lattice   = 'cubic',
            cell      = 'conventional',
            centering = 'F',
            constants = 3.57,
            units     = 'A',
            atoms     = 'C',
            basis     = [[0,0,0],[.25,.25,.25]]
            ),
        ('diamond','prim'):obj(
            lattice   = 'cubic',
            cell      = 'primitive',
            centering = 'F',
            constants = 3.57,
            units     = 'A',
            atoms     = 'C',
            basis     = [[0,0,0],[.25,.25,.25]]
            ),
        ('diamond','conv'):obj(
            lattice   = 'cubic',
            cell      = 'conventional',
            centering = 'F',
            constants = 3.57,
            units     = 'A',
            atoms     = 'C',
            basis     = [[0,0,0],[.25,.25,.25]]
            ),
        ('wurtzite','prim'):obj(
            lattice   = 'hexagonal',
            cell      = 'primitive',
            centering = 'P',
            constants = (3.35,5.22),
            units     = 'A',
            #atoms     = ('Zn','O'),
            #basis     = [[1./3, 2./3, 3./8],[1./3, 2./3, 0]]
            atoms     = ('Zn','O','Zn','O'),
            basis     = [[0,0,5./8],[0,0,0],[2./3,1./3,1./8],[2./3,1./3,1./2]]
            ),
        ('ZnO','prim'):obj(
            lattice   = 'wurtzite',
            cell      = 'prim',
            constants = (3.35,5.22),
            units     = 'A',
            atoms     = ('Zn','O','Zn','O')
            ),
        ('NaCl','prim'):obj(
            lattice   = 'cubic',
            cell      = 'primitive',
            centering = 'F',
            constants = 5.64,
            units     = 'A',
            atoms     = ('Na','Cl'),
            basis     = [[0,0,0],[.5,0,0]],
            basis_vectors = 'conventional'
            ),
        ('rocksalt','prim'):obj(
            lattice   = 'cubic',
            cell      = 'primitive',
            centering = 'F',
            constants = 5.64,
            units     = 'A',
            atoms     = ('Na','Cl'),
            basis     = [[0,0,0],[.5,0,0]],
            basis_vectors = 'conventional'
            ),
        ('copper','prim'):obj(
            lattice   = 'cubic',
            cell      = 'primitive',
            centering = 'F',
            constants = 3.615,
            units     = 'A',
            atoms     = 'Cu'
            ),
        ('calcium','prim'):obj(
            lattice   = 'cubic',
            cell      = 'primitive',
            centering = 'F',
            constants = 5.588,
            units     = 'A',
            atoms     = 'Ca'
            ),
        # http://www.webelements.com/oxygen/crystal_structure.html
        #   Phys Rev 160 694
        ('oxygen','prim'):obj(
            lattice   = 'monoclinic',
            cell      = 'primitive',
            centering = 'C',
            constants = (5.403,3.429,5.086,132.53),
            units     = 'A',
            angular_units = 'degrees',
            atoms     = ('O','O'),
            basis     = [[0,0,1.15/2],[0,0,-1.15/2]],
            basis_vectors = identity(3)
            ),
        # http://en.wikipedia.org/wiki/Calcium_oxide
        # http://www.springermaterials.com/docs/info/10681719_224.html
        ('CaO','prim'):obj(
            lattice   = 'NaCl',
            cell      = 'prim',
            constants = 4.81,
            atoms     = ('Ca','O')
            ),
        ('CaO','conv'):obj(
            lattice   = 'NaCl',
            cell      = 'conv',
            constants = 4.81,
            atoms     = ('Ca','O')
            ),
        # http://en.wikipedia.org/wiki/Copper%28II%29_oxide
        #   http://iopscience.iop.org/0953-8984/3/28/001/
        # http://www.webelements.com/compounds/copper/copper_oxide.html
        # http://www.webmineral.com/data/Tenorite.shtml
        ('CuO','prim'):obj(
            lattice   = 'monoclinic',
            cell      = 'primitive',
            centering = 'C',
            constants = (4.683,3.422,5.128,99.54),
            units     = 'A',
            angular_units = 'degrees',
            atoms     = ('Cu','O','Cu','O'),
            basis     = [[.25,.25,0],[0,.418,.25],
                         [.25,.75,.5],[.5,.5-.418,.75]],
            basis_vectors = 'conventional'
            ),
        ('Ca2CuO3','prim'):obj(# kateryna foyevtsova
            lattice   = 'orthorhombic',
            cell      = 'primitive',
            centering = 'I',
            constants = (3.77,3.25,12.23),
            units     = 'A',
            atoms     = ('Cu','O','O','O','Ca','Ca'),
            basis     = [[   0,   0,   0 ],
                         [ .50,   0,   0 ],
                         [   0,   0, .16026165],
                         [   0,   0, .83973835],
                         [   0,   0, .35077678],
                         [   0,   0, .64922322]],
            basis_vectors = 'conventional'
            ),
        ('La2CuO4','prim'):obj( #tetragonal structure
            lattice   = 'tetragonal',
            cell      = 'primitive',
            centering = 'I',
            constants = (3.809,13.169),
            units     = 'A',
            atoms     = ('Cu','O','O','O','O','La','La'),
            basis     = [[  0,    0,    0],
                         [ .5,    0,    0],
                         [  0,   .5,    0],
                         [  0,    0,  .182],
                         [  0,    0, -.182],
                         [  0,    0,  .362],
                         [  0,    0, -.362]]
            ),
        ('Cl2Ca2CuO2','prim'):obj(
            lattice   = 'tetragonal',
            cell      = 'primitive',
            centering = 'I',
            constants = (3.869,15.05),
            units     = 'A',
            atoms     = ('Cu','O','O','Ca','Ca','Cl','Cl'),
            basis     = [[   0,   0,    0 ],
                         [  .5,   0,    0 ],
                         [   0,  .5,    0 ],
                         [  .5,  .5,  .104],
                         [   0,   0,  .396],
                         [   0,   0,  .183],
                         [  .5,  .5,  .317]],
            basis_vectors = 'conventional'
            ),
        ('Cl2Ca2CuO2','afm'):obj(
            lattice   = 'tetragonal',
            cell      = 'conventional',
            centering = 'P',
            axes      = [[.5,-.5,0],[.5,.5,0],[0,0,1]],
            constants = (2*3.869,15.05),
            units     = 'A',
            atoms     = 4*['Cu','O','O','Ca','Ca','Cl','Cl'],
            basis     = [[   0,   0,    0 ], #Cu
                         [  .25,  0,    0 ],
                         [   0,  .25,   0 ],
                         [  .25, .25, .104],
                         [   0,   0,  .396],
                         [   0,   0,  .183],
                         [  .25, .25, .317],
                         [  .25, .25, .5  ], #Cu
                         [  .5,  .25, .5  ],
                         [  .25, .5,  .5  ],
                         [  .5,  .5,  .604],
                         [  .25, .25, .896],
                         [  .25, .25, .683],
                         [  .5,  .5,  .817],
                         [  .5,   0,    0 ], #Cu2
                         [  .75,  0,    0 ],
                         [  .5,  .25,   0 ],
                         [  .75, .25, .104],
                         [  .5,   0,  .396],
                         [  .5,   0,  .183],
                         [  .75, .25, .317],
                         [  .75, .25, .5  ], #Cu2
                         [   0,  .25, .5  ],
                         [  .75, .5,  .5  ],
                         [   0,  .5,  .604],
                         [  .75, .25, .896],
                         [  .75, .25, .683],
                         [   0,  .5,  .817]],
            basis_vectors = 'conventional'
            ),
        ('CuO2_plane','prim'):obj(
            lattice   = 'tetragonal',
            cell      = 'primitive',
            centering = 'P',
            constants = (3.809,13.169),
            units     = 'A',
            atoms     = ('Cu','O','O'),
            basis     = [[  0,    0,    0],
                         [ .5,    0,    0],
                         [  0,   .5,    0]]
            ),
        ('graphite_aa','hex'):obj(
            axes      = [[1./2,-sqrt(3.)/2,0],[1./2,sqrt(3.)/2,0],[0,0,1]],
            constants = (2.462,3.525),
            units     = 'A',
            atoms     = ('C','C'),
            basis     = [[0,0,0],[2./3,1./3,0]]
            ),
        ('graphite_ab','hex'):obj(
            axes      = [[1./2,-sqrt(3.)/2,0],[1./2,sqrt(3.)/2,0],[0,0,1]],
            constants = (2.462,3.525),
            units     = 'A',
            cscale    = (1,2),
            atoms     = ('C','C','C','C'),
            basis     = [[0,0,0],[2./3,1./3,0],[0,0,1./2],[1./3,2./3,1./2]]
            ),
        ('graphene','prim'):obj(
            lattice   = 'hexagonal',
            cell      = 'primitive',
            centering = 'P',
            constants = (2.462,15.0),
            units     = 'A',
            atoms     = ('C','C'),
            basis     = [[0,0,0],[2./3,1./3,0]]
            ),
        ('graphene','rect'):obj(
            lattice   = 'orthorhombic',
            cell      = 'conventional',
            centering = 'C',
            constants = (2.462,sqrt(3.)*2.462,15.0),
            units     = 'A',
            atoms     = ('C','C'),
            basis     = [[0,0,0],[1./2,1./6,0]]
            )
        }

    # for every (name,'prim') entry lacking a (name,'conv') counterpart,
    # add one that differs only in its cell type
    kc_keys = list(known_crystals.keys())
    for (name,cell) in kc_keys:
        desc = known_crystals[name,cell]
        if cell=='prim' and not (name,'conv') in known_crystals:
            cdesc = desc.copy()
            if cdesc.cell=='primitive':
                cdesc.cell = 'conventional'
                known_crystals[name,'conv'] = cdesc
            elif cdesc.cell=='prim':
                cdesc.cell = 'conv'
                known_crystals[name,'conv'] = cdesc
            #end if
        #end if
    #end for
    del kc_keys


    def __init__(self,
                 lattice        = None,
                 cell           = None,
                 centering      = None,
                 constants      = None,
                 atoms          = None,
                 basis          = None,
                 basis_vectors  = None,
                 tiling         = None,
                 cscale         = None,
                 axes           = None,
                 units          = None,
                 angular_units  = 'degrees',
                 kpoints        = None,
                 kgrid          = None,
                 mag            = None,
                 frozen         = None,
                 magnetization  = None,
                 kshift         = (0,0,0),
                 permute        = None,
                 operations     = None,
                 elem           = None,
                 pos            = None,
                 use_prim       = None,
                 add_kpath      = False,
                 symm_kgrid     = False,
                 ):
        """
        Generate a crystal structure.

        Parameters
        ----------
        lattice : str
            Lattice system (one of Crystal.lattices) or the name of a
            known_crystals entry (e.g. 'diamond').
        cell : str
            'primitive'/'conventional' (aliases 'prim'/'conv'), a name in
            cell_classes (e.g. 'sc','bcc','fcc','hex'), or the second key of
            a known_crystals entry.
        centering : str
            One of P, A, B, C, I, F, R.  Required for conventional cells.
        constants : number or sequence
            Lattice constants in the order of lattice_constants[lattice].
        atoms : str or sequence of str
            Species for each basis position; a single string is repeated
            for every basis position.
        basis : sequence of 3-vectors
            Basis positions in units of basis_vectors.  Defaults to the
            origin when exactly one atom is given.
        basis_vectors : None, 'primitive', 'conventional', or 3 vectors
            Vectors converting basis to Cartesian positions.  None means
            the final cell axes.
        cscale : sequence
            Factors multiplied elementwise into constants.
        axes : 3x3
            Explicit cell axes.  For known crystals these are scaled by
            diag([a,b,c]); otherwise they are used as given.
        angular_units : 'degrees' or 'radians'
            Units of any angle constants (a known crystal may override it).

        tiling, units, kpoints, kgrid, kshift, mag, frozen, magnetization,
        permute, operations, use_prim, add_kpath and symm_kgrid are
        forwarded to Structure.__init__.  elem and pos are only recorded in
        generation_info.

        If lattice, cell, atoms and units are all None, returns immediately
        without initializing anything.
        """

        if lattice is None and cell is None and atoms is None and units is None:
            return
        #end if

        # record the inputs exactly as given
        gi = obj(
            lattice        = lattice       ,
            cell           = cell          ,
            centering      = centering     ,
            constants      = constants     ,
            atoms          = atoms         ,
            basis          = basis         ,
            basis_vectors  = basis_vectors ,
            tiling         = tiling        ,
            cscale         = cscale        ,
            axes           = axes          ,
            units          = units         ,
            angular_units  = angular_units ,
            frozen         = frozen        ,
            mag            = mag           ,
            magnetization  = magnetization ,
            kpoints        = kpoints       ,
            kgrid          = kgrid         ,
            kshift         = kshift        ,
            permute        = permute       ,
            operations     = operations    ,
            elem           = elem          ,
            pos            = pos           ,
            use_prim       = use_prim      ,
            add_kpath      = add_kpath     ,
            symm_kgrid     = symm_kgrid    ,
            )
        generation_info = gi.copy()

        lattice_in = lattice
        if isinstance(lattice,str):
            lattice=lattice.lower()
        #end if
        if isinstance(cell,str):
            cell=cell.lower()
        #end if

        # look up a named crystal, first by the name as given (keys such as
        # 'NaCl' are mixed case), then by its lower-case form
        known_crystal = False
        if (lattice_in,cell) in self.known_crystals:
            known_crystal = True
            lattice_info = self.known_crystals[lattice_in,cell].copy()
        elif (lattice,cell) in self.known_crystals:
            known_crystal = True
            lattice_info = self.known_crystals[lattice,cell].copy()
        #end if

        if known_crystal:
            # follow chained entries, e.g. ('CaO','prim') -> ('NaCl','prim');
            # the referring entry's remaining fields are transferred onto
            # the referenced one
            while 'lattice' in lattice_info and 'cell' in lattice_info and (lattice_info.lattice,lattice_info.cell) in self.known_crystals:
                li_old = lattice_info
                lattice_info = self.known_crystals[li_old.lattice,li_old.cell].copy()
                del li_old.lattice
                del li_old.cell
                lattice_info.transfer_from(li_old,copy=False)
            #end while
            if 'cell' in lattice_info:
                cell = lattice_info.cell
            elif cell in self.cell_aliases:
                cell = self.cell_aliases[cell]
            elif cell in self.cell_classes:
                lattice = self.cell_classes[cell]
            else:
                self.error('cell shape '+str(cell)+' is not recognized\n  the variable cell_classes or cell_aliases must be updated to include '+str(cell))
            #end if
            if 'lattice' in lattice_info:
                lattice = lattice_info.lattice
            #end if
            if 'angular_units' in lattice_info:
                angular_units = lattice_info.angular_units
            #end if
            # explicit user inputs take precedence over the stored defaults
            inputs = obj(
                centering     = centering,
                constants     = constants,
                atoms         = atoms,
                basis         = basis,
                basis_vectors = basis_vectors,
                tiling        = tiling,
                cscale        = cscale,
                axes          = axes,
                units         = units
                )
            for var,val in inputs.items():
                if val is None and var in lattice_info:
                    inputs[var] = lattice_info[var]
                #end if
            #end for
            centering,constants,atoms,basis,basis_vectors,tiling,cscale,axes,units=inputs.list('centering','constants','atoms','basis','basis_vectors','tiling','cscale','axes','units')
        #end if

        if constants is None:
            self.error('the variable constants must be provided')
        #end if
        if atoms is None:
            self.error('the variable atoms must be provided')
        #end if

        if lattice not in self.lattices:
            self.error('lattice type '+str(lattice)+' is not recognized\n  valid lattice types are: '+str(list(self.lattices)))
        #end if
        if cell=='conventional':
            if centering is None:
                self.error('centering must be provided for a conventional cell\n  options for a '+lattice+' lattice are: '+str(self.lattice_centerings[lattice]))
            elif centering not in self.centerings:
                self.error('centering type '+str(centering)+' is not recognized\n  options for a '+lattice+' lattice are: '+str(self.lattice_centerings[lattice]))
            #end if
        #end if
        if isinstance(constants,(int,float)):
            constants=[constants]
        #end if
        if len(constants)!=len(self.lattice_constants[lattice]):
            self.error('the '+lattice+' lattice depends on the constants '+str(self.lattice_constants[lattice])+'\n you provided '+str(len(constants))+': '+str(constants))
        #end if
        if isinstance(atoms,str):
            # identity test: basis may be an ndarray, where != is elementwise
            if basis is not None:
                atoms = len(basis)*[atoms]
            else:
                atoms=[atoms]
            #end if
        #end if
        if basis is None:
            if len(atoms)==1:
                basis = [(0,0,0)]
            else:
                self.error('must provide as many basis coordinates as basis atoms\n  atoms provided: '+str(atoms)+'\n  basis provided: '+str(basis))
            #end if
        #end if
        if basis_vectors is not None and not isinstance(basis_vectors,str) and len(basis_vectors)!=3:
            self.error('3 basis vectors must be given, you provided '+str(len(basis_vectors))+':\n  '+str(basis_vectors))
        #end if

        if tiling is None:
            tiling = (1,1,1)
        #end if
        if cscale is None:
            cscale = len(constants)*[1]
        #end if
        if len(cscale)!=len(constants):
            self.error('cscale and constants must be the same length')
        #end if
        basis  = array(basis)
        tiling = array(tiling,dtype=int)
        cscale = array(cscale)
        constants = cscale*array(constants)

        a,b,c,alpha,beta,gamma = None,None,None,None,None,None
        if angular_units=='radians':
            pi_1o2 = pi/2
            pi_2o3 = 2*pi/3
        elif angular_units=='degrees':
            pi_1o2 = 90.
            pi_2o3 = 120.
        else:
            self.error('angular units must be radians or degrees\n  you provided '+str(angular_units))
        #end if
        if lattice=='triclinic':
            a,b,c,alpha,beta,gamma = constants
        elif lattice=='monoclinic':
            a,b,c,beta = constants
            alpha = gamma = pi_1o2
        elif lattice=='orthorhombic':
            a,b,c = constants
            alpha=beta=gamma=pi_1o2
        elif lattice=='tetragonal':
            a,c = constants
            b=a
            alpha=beta=gamma=pi_1o2
        elif lattice=='hexagonal':
            a,c = constants
            b=a
            alpha=beta=pi_1o2
            gamma=pi_2o3
        elif lattice=='cubic':
            a=constants[0]
            b=c=a
            alpha=beta=gamma=pi_1o2
        elif lattice=='rhombohedral':
            a,alpha = constants
            b=c=a
            beta=gamma=alpha
        #end if
        if angular_units=='degrees':
            alpha *= pi/180
            beta  *= pi/180
            gamma *= pi/180
        #end if

        points = [[0,0,0]]
        #get the conventional axes
        sa,ca = sin(alpha),cos(alpha)
        sb,cb = sin(beta) ,cos(beta)
        sg,cg = sin(gamma),cos(gamma)
        y     = (ca-cg*cb)/sg
        a1c = a*array([1,0,0])
        a2c = b*array([cg,sg,0])
        a3c = c*array([cb,y,sqrt(sb**2-y**2)])
        axes_conv = array([a1c,a2c,a3c]).copy()

        # primitive axes are only generated when axes are not given directly
        axes_prim = None
        if axes is None:
            if cell not in self.cell_types:
                self.error('cell must be primitive or conventional\n  You provided: '+str(cell))
            #end if
            if cell=='primitive' and centering=='P':
                cell='conventional'
            #end if
            #get the primitive axes
            if centering=='P':
                a1 = a1c
                a2 = a2c
                a3 = a3c
            elif centering=='A':
                a1 = a1c
                a2 = (a2c+a3c)/2
                a3 = (-a2c+a3c)/2
            elif centering=='B':
                a1 = (a1c+a3c)/2
                a2 = a2c
                a3 = (-a1c+a3c)/2
            elif centering=='C':
                a1 = (a1c-a2c)/2
                a2 = (a1c+a2c)/2
                a3 = a3c
            elif centering=='I':
                a1=[ a/2, b/2,-c/2]
                a2=[-a/2, b/2, c/2]
                a3=[ a/2,-b/2, c/2]
            elif centering=='F':
                a1=[a/2, b/2,   0]
                a2=[  0, b/2, c/2]
                a3=[a/2,   0, c/2]
            elif centering=='R':
                a1=[   a,              0,   0]
                a2=[ a/2,   a*sqrt(3.)/2,   0]
                a3=[-a/6, a/(2*sqrt(3.)), c/3]
            else:
                self.error('the variable centering must be specified\n  valid options are: P,A,B,C,I,F,R')
            #end if
            axes_prim = array([a1,a2,a3])
            if cell=='primitive':
                axes = axes_prim
            elif cell=='conventional':
                axes = axes_conv
                points.extend(self.centerings[centering])
            #end if
        elif known_crystal:
            # stored axes are in units of the lattice constants
            axes = dot(diag([a,b,c]),array(axes))
        else:
            # ensure array semantics (axes.sum below) for list input
            axes = array(axes)
        #end if
        points = array(points,dtype=float)

        elem = []
        pos  = []
        # basis_vectors may be an array (e.g. identity(3)), so only compare
        # against the string options when it actually is a string
        if basis_vectors is None:
            basis_vectors = axes
        elif isinstance(basis_vectors,str):
            if basis_vectors=='primitive':
                if axes_prim is None:
                    self.error('basis_vectors="primitive" cannot be used when axes are provided directly')
                #end if
                basis_vectors = axes_prim
            elif basis_vectors=='conventional':
                basis_vectors = axes_conv
            else:
                self.error('basis_vectors '+basis_vectors+' is not recognized\n  valid options are: primitive, conventional, or 3 vectors')
            #end if
        #end if
        nbasis = len(atoms)
        for point in points:
            for i in range(nbasis):
                atom   = atoms[i]
                bpoint = basis[i]
                p = dot(point,axes) + dot(bpoint,basis_vectors)
                elem.append(atom)
                pos.append(p)
            #end for
        #end for
        pos = array(pos)

        self.set(
            constants = array([a,b,c]),
            angles    = array([alpha,beta,gamma]),
            generation_info = generation_info
            )

        Structure.__init__(
            self,
            axes           = axes,
            scale          = a,
            elem           = elem,
            pos            = pos,
            center         = axes.sum(0)/2,
            units          = units,
            frozen         = frozen,
            mag            = mag,
            magnetization  = magnetization,
            tiling         = tiling,
            kpoints        = kpoints,
            kgrid          = kgrid,
            kshift         = kshift,
            permute        = permute,
            rescale        = False,
            operations     = operations,
            use_prim       = use_prim,
            add_kpath      = add_kpath,
            symm_kgrid     = symm_kgrid,
            )

    #end def __init__
#end class Crystal
6814
6815
6816# test needed
class Jellium(Structure):
    """
    Uniform-background ("jellium") cell.

    The cell may be given directly (cell or axes) or derived from volume,
    density, or rs together with the total charge (alias
    background_charge).  density and rs are related by
    density = 1/(prefactors[dim]*rs**dim).
    """
    # NOTE(review): conventional d-ball volume prefactors are 2 (1D),
    # pi (2D), 4/3*pi (3D); the 1D/2D entries here differ -- confirm the
    # intended rs convention before relying on 1D/2D results.
    prefactors = obj()
    prefactors.transfer_from({1:2*pi,2:4*pi,3:4./3*pi})

    def __init__(self,charge=None,background_charge=None,cell=None,volume=None,density=None,rs=None,dim=3,
                 axes=None,kpoints=None,kweights=None,kgrid=None,kshift=None,units=None,tiling=None):
        """
        Resolve charge and cell from the provided inputs and build the
        underlying Structure.  Precedence: rs sets density; axes overrides
        cell; background_charge overrides charge; a given cell fixes dim and
        volume; otherwise a volume yields a cubic cell.  Density then fills
        in whichever of charge/volume is missing.  tiling is accepted but
        ignored.
        """
        del tiling
        # identity tests throughout: cell/axes may be ndarrays, for which
        # `!= None` is elementwise and cannot be used in an if-statement
        if rs is not None:
            if dim not in self.prefactors:
                self.error('only 1,2, or 3 dimensional jellium is currently supported\n  you requested one with dimension {0}'.format(dim))
            #end if
            density = 1.0/(self.prefactors[dim]*rs**dim)
        #end if
        if axes is not None:
            cell = axes
        #end if
        if background_charge is not None:
            charge = background_charge
        #end if
        if cell is not None:
            cell   = array(cell)
            dim    = len(cell)
            volume = det(cell)
        elif volume is not None:
            volume = float(volume)
            cell   = volume**(1./dim)*identity(dim)
        #end if
        if density is not None:
            density = float(density)
            if charge is None and volume is not None:
                charge = density*volume
            elif volume is None and charge is not None:
                volume = charge/density
                cell   = volume**(1./dim)*identity(dim)
            #end if
        #end if
        if charge is None or cell is None:
            self.error('not enough information to form jellium structure\n  information provided:\n  charge: {0}\n  cell: {1}\n  volume: {2}\n  density: {3}\n  rs: {4}\n  dim: {5}'.format(charge,cell,volume,density,rs,dim))
        #end if
        Structure.__init__(self,background_charge=charge,axes=cell,dim=dim,kpoints=kpoints,kweights=kweights,kgrid=kgrid,kshift=kshift,units=units)
    #end def __init__

    def density(self):
        """Background charge per unit cell volume."""
        return self.background_charge/self.volume()
    #end def density

    def rs(self):
        """Inverse of density = 1/(prefactors[dim]*rs**dim)."""
        return 1.0/(self.density()*self.prefactors[self.dim])**(1./self.dim)
    #end def rs

    def tile(self,*args,**kwargs):
        """
        Tiling is not supported for jellium.  Accepts the same arguments as
        Structure.tile so that s.tile(...) reaches not_implemented() instead
        of failing with a TypeError.
        """
        self.not_implemented()
    #end def tile
#end class Jellium
6871
6872
6873
6874
6875
6876
6877
6878
6879# test needed
def generate_cell(shape,tiling=None,scale=1.,units=None,struct_type=Structure):
    """
    Create an empty (atom-free) cubic-family simulation cell.

    Parameters
    ----------
    shape : str
        'sc', 'bcc' or 'fcc'; selects unit-length lattice vectors.
    tiling : sequence of 3 ints, optional
        Multiplies each lattice vector; defaults to (1,1,1).
    scale : float
        Passed through to Structure as the overall scale.
    units : str, optional
        Passed through to Structure.
    struct_type : class
        If not Structure, the result is converted with upcast(struct_type).
    """
    if tiling is None:
        tiling = (1,1,1)
    #end if
    axes = dict(
        sc  =  1.*array([[ 1,0,0],[0, 1,0],[0,0, 1]]),
        bcc = .5*array([[-1,1,1],[1,-1,1],[1,1,-1]]),
        fcc = .5*array([[ 1,1,0],[1, 0,1],[0,1, 1]]),
        )
    if shape not in axes:
        Structure.class_error('cell shape '+str(shape)+' is not recognized\n  valid options are: sc, bcc, fcc','generate_cell')
    #end if
    ax     = dot(diag(tiling),axes[shape])
    center = ax.sum(0)/2
    c = Structure(axes=ax,scale=scale,center=center,units=units)
    if struct_type is not Structure:
        c=c.upcast(struct_type)
    #end if
    return c
#end def generate_cell
6896
6897
6898
def generate_structure(type='crystal',*args,**kwargs):
    """
    Build a structure by dispatching on `type`.

    Supported kinds: 'crystal', 'defect', 'atom', 'dimer', 'trimer',
    'jellium' (each forwarded to the matching generate_*_structure
    function), 'empty' (a bare Structure()) and 'basic'
    (Structure(*args,**kwargs)).  All other arguments are forwarded
    unchanged.  Any other value triggers Structure.class_error.
    """
    kind = type
    if kind=='crystal':
        return generate_crystal_structure(*args,**kwargs)
    #end if
    if kind=='defect':
        return generate_defect_structure(*args,**kwargs)
    #end if
    if kind=='atom':
        return generate_atom_structure(*args,**kwargs)
    #end if
    if kind=='dimer':
        return generate_dimer_structure(*args,**kwargs)
    #end if
    if kind=='trimer':
        return generate_trimer_structure(*args,**kwargs)
    #end if
    if kind=='jellium':
        return generate_jellium_structure(*args,**kwargs)
    #end if
    if kind=='empty':
        return Structure()
    #end if
    if kind=='basic':
        return Structure(*args,**kwargs)
    #end if
    Structure.class_error(str(kind)+' is not a valid structure type\noptions are crystal, defect, atom, dimer, trimer, jellium, empty, or basic')
#end def generate_structure
6921
6922
6923
6924
6925# test needed
def generate_atom_structure(
    atom        = None,
    units       = 'A',
    Lbox        = None,
    skew        = 0,
    axes        = None,
    kgrid       = (1,1,1),
    kshift      = (0,0,0),
    bconds      = tuple('nnn'),
    struct_type = Structure
    ):
    """
    Generate a structure holding a single atom at the origin.

    Parameters
    ----------
    atom : str
        Element label of the atom (required).
    units : str
        Length units forwarded to ``Structure``.
    Lbox : float, optional
        If given, overrides ``axes`` with an orthorhombic box whose x and z
        edges are ``Lbox*(1-skew)`` and ``Lbox*(1+skew)``.
    skew : float
        Box distortion used only together with ``Lbox``.
    axes : array-like, optional
        Cell axes. When no cell is given (``axes`` and ``Lbox`` both None),
        ``kgrid``/``kshift`` are not used.
    kgrid, kshift : sequences of 3 numbers
        K-point grid and shift, used only when a cell is present.
    bconds : sequence
        Boundary conditions forwarded to ``Structure``.
    struct_type : type
        Returned object is upcast to this type when it is not ``Structure``
        (same convention as ``generate_cell``).

    When a cell is present the atom is centered in it via ``center_molecule``.
    """
    if atom is None:
        Structure.class_error('atom must be provided','generate_atom_structure')
    #end if
    if Lbox is not None:
        axes = [[Lbox*(1-skew),0,0],[0,Lbox,0],[0,0,Lbox*(1+skew)]]
    #end if
    if axes is None:
        s = Structure(elem=[atom],pos=[[0,0,0]],units=units,bconds=bconds)
    else:
        s = Structure(elem=[atom],pos=[[0,0,0]],axes=axes,kgrid=kgrid,kshift=kshift,bconds=bconds,units=units)
        s.center_molecule()
    #end if
    # honor the requested return type (previously accepted but ignored)
    if struct_type!=Structure:
        s = s.upcast(struct_type)
    #end if
    return s
#end def generate_atom_structure
6952
6953
6954# test needed
def generate_dimer_structure(
    dimer       = None,
    units       = 'A',
    separation  = None,
    Lbox        = None,
    skew        = 0,
    axes        = None,
    kgrid       = (1,1,1),
    kshift      = (0,0,0),
    bconds      = tuple('nnn'),
    struct_type = Structure,
    axis        = 'x'
    ):
    """
    Generate a two-atom (dimer) structure.

    The first atom sits at the origin and the second at distance
    ``separation`` along ``axis`` ('x', 'y' or 'z').

    Parameters
    ----------
    dimer : sequence of 2 str
        Element labels of the two atoms (required).
    units : str
        Length units forwarded to ``Structure``.
    separation : float
        Interatomic distance (required).
    Lbox, skew : float
        If ``Lbox`` is given, overrides ``axes`` with an orthorhombic box with
        x and z edges ``Lbox*(1-skew)`` and ``Lbox*(1+skew)``.
    axes : array-like, optional
        Cell axes; without a cell, ``kgrid``/``kshift`` are not used.
    kgrid, kshift : sequences of 3 numbers
        K-point grid and shift, used only when a cell is present.
    bconds : sequence
        Boundary conditions forwarded to ``Structure``.
    struct_type : type
        Returned object is upcast to this type when it is not ``Structure``
        (same convention as ``generate_cell``).
    axis : str
        Bond orientation.

    When a cell is present the dimer is centered via ``center_molecule``.
    """
    if dimer is None:
        Structure.class_error('dimer atoms must be provided to construct dimer','generate_dimer_structure')
    #end if
    if separation is None:
        Structure.class_error('separation must be provided to construct dimer','generate_dimer_structure')
    #end if
    if Lbox is not None:
        axes = [[Lbox*(1-skew),0,0],[0,Lbox,0],[0,0,Lbox*(1+skew)]]
    #end if
    if axis=='x':
        p2 = [separation,0,0]
    elif axis=='y':
        p2 = [0,separation,0]
    elif axis=='z':
        p2 = [0,0,separation]
    else:
        Structure.class_error('dimer orientation axis must be x,y,z\n  you provided: {0}'.format(axis),'generate_dimer_structure')
    #end if
    if axes is None:
        s = Structure(elem=dimer,pos=[[0,0,0],p2],units=units,bconds=bconds)
    else:
        s = Structure(elem=dimer,pos=[[0,0,0],p2],axes=axes,kgrid=kgrid,kshift=kshift,units=units,bconds=bconds)
        s.center_molecule()
    #end if
    # honor the requested return type (previously accepted but ignored)
    if struct_type!=Structure:
        s = s.upcast(struct_type)
    #end if
    return s
#end def generate_dimer_structure
6994
6995
6996# test needed
def generate_trimer_structure(
    trimer        = None,
    units         = 'A',
    separation    = None,
    angle         = None,
    Lbox          = None,
    skew          = 0,
    axes          = None,
    kgrid         = (1,1,1),
    kshift        = (0,0,0),
    struct_type   = Structure,
    axis          = 'x',
    axis2         = 'y',
    angular_units = 'degrees',
    plane_rot     = None
    ):
    """
    Generate a three-atom (trimer) structure lying in a coordinate plane.

    Atom 1 is at the origin, atom 2 at distance ``separation[0]`` along
    ``axis``, and atom 3 at distance ``separation[1]`` from atom 1, rotated
    by ``angle`` from ``axis`` toward ``axis2``.

    Parameters
    ----------
    trimer : sequence of 3 str
        Element labels (required).
    units : str
        Length units forwarded to ``Structure``.
    separation : sequence of 2 floats
        (atom1-atom2, atom1-atom3) distances (required).
    angle : float
        Angle between the two bonds (required), in ``angular_units``.
    Lbox, skew : float
        If ``Lbox`` is given, overrides ``axes`` with an orthorhombic box with
        x and z edges ``Lbox*(1-skew)`` and ``Lbox*(1+skew)``.
    axes : array-like, optional
        Cell axes; without a cell, ``kgrid``/``kshift`` are not used.
    kgrid, kshift : sequences of 3 numbers
        K-point grid and shift, used only when a cell is present.
    struct_type : type
        Returned object is upcast to this type when it is not ``Structure``
        (same convention as ``generate_cell``).
    axis, axis2 : str
        Distinct axes among 'x', 'y', 'z' defining the trimer plane.
    angular_units : str
        'degrees' or anything starting with 'rad'.
    plane_rot : float, optional
        If given, forwarded with the axis pair and ``angular_units`` to
        ``rotate_plane``.
    """
    if trimer is None:
        Structure.class_error('trimer atoms must be provided to construct trimer','generate_trimer_structure')
    #end if
    if separation is None:
        Structure.class_error('separation must be provided to construct trimer','generate_trimer_structure')
    #end if
    # a scalar separation should produce the explanatory error, not a raw TypeError
    try:
        nsep = len(separation)
    except TypeError:
        nsep = 1
    #end try
    if nsep!=2:
        Structure.class_error('two separation distances (atom1-atom2,atom1-atom3) must be provided to construct trimer\nyou provided {0} separation distances'.format(nsep),'generate_trimer_structure')
    #end if
    if angle is None:
        Structure.class_error('angle must be provided to construct trimer','generate_trimer_structure')
    #end if
    if angular_units=='degrees':
        angle *= pi/180
    elif not angular_units.startswith('rad'):
        Structure.class_error('angular units must be degrees or radians\nyou provided: {0}'.format(angular_units),'generate_trimer_structure')
    #end if
    if axis==axis2:
        Structure.class_error('axis and axis2 must be different to define the trimer plane\nyou provided {0} for both'.format(axis),'generate_trimer_structure')
    #end if
    if Lbox is not None:
        axes = [[Lbox*(1-skew),0,0],[0,Lbox,0],[0,0,Lbox*(1+skew)]]
    #end if
    axis_index = {'x':0,'y':1,'z':2}
    if axis not in axis_index:
        Structure.class_error('trimer bond1 (atom2-atom1) orientation axis must be x,y,z\n  you provided: {0}'.format(axis),'generate_trimer_structure')
    #end if
    if axis2 not in axis_index:
        Structure.class_error('trimer bond2 (atom3-atom1) orientation axis must be x,y,z\n  you provided: {0}'.format(axis2),'generate_trimer_structure')
    #end if
    i1 = axis_index[axis]
    i2 = axis_index[axis2]
    axpair = axis+axis2
    p1 = [0,0,0]
    # atom 2 lies along the first axis
    p2 = [0,0,0]
    p2[i1] = separation[0]
    # atom 3 is rotated by `angle` from the first axis toward the second
    r  = separation[1]
    p3 = [0,0,0]
    p3[i1] = r*cos(angle)
    p3[i2] = r*sin(angle)
    if axes is None:
        s = Structure(elem=trimer,pos=[p1,p2,p3],units=units)
    else:
        s = Structure(elem=trimer,pos=[p1,p2,p3],axes=axes,kgrid=kgrid,kshift=kshift,units=units)
        s.center_molecule()
    #end if
    if plane_rot is not None:
        s.rotate_plane(axpair,plane_rot,angular_units)
    #end if
    # honor the requested return type (previously accepted but ignored)
    if struct_type!=Structure:
        s = s.upcast(struct_type)
    #end if
    return s
#end def generate_trimer_structure
7076
7077
7078# test needed
def generate_jellium_structure(*args,**kwargs):
    """
    Build a ``Jellium`` structure.

    Thin factory: every positional and keyword argument is passed unchanged
    to the ``Jellium`` constructor.
    """
    jellium = Jellium(*args,**kwargs)
    return jellium
#end def generate_jellium_structure
7082
7083
7084
7085
def generate_crystal_structure(
    lattice        = None,
    cell           = None,
    centering      = None,
    constants      = None,
    atoms          = None,
    basis          = None,
    basis_vectors  = None,
    tiling         = None,
    cscale         = None,
    axes           = None,
    units          = None,
    angular_units  = 'degrees',
    magnetization  = None,
    mag            = None,
    kpoints        = None,
    kweights       = None,
    kgrid          = None,
    kshift         = (0,0,0),
    permute        = None,
    operations     = None,
    struct_type    = Crystal,
    elem           = None,
    pos            = None,
    frozen         = None,
    posu           = None,
    elem_pos       = None,
    folded_elem    = None,
    folded_pos     = None,
    folded_units   = None,
    use_prim       = None,
    add_kpath      = False,
    symm_kgrid     = False,
    #legacy inputs
    structure      = None,
    shape          = None,
    element        = None,
    scale          = None,
    ):
    """
    Generate a crystal structure, or a plain Structure when fully specified.

    Three construction routes are tried in order:

    1. Manual specification: if ``elem`` is given together with ``pos`` or
       ``posu``, or if ``elem_pos`` is given, a ``Structure`` is built
       directly from the supplied axes/atoms/k-point data.
    2. Existing structure: if the legacy ``structure`` argument is already a
       ``Structure`` object, it is used directly (no copy is made):
       ``become_primitive`` is called on it when ``use_prim`` is set, it is
       tiled when ``tiling`` is given, and k-points / a k-mesh are added.
    3. Otherwise a ``Crystal`` is built from the lattice/cell/basis
       description and upcast to ``struct_type`` if that is not ``Crystal``.

    For routes 1 and 2, ``struct_type`` is not applied, and a folded
    molecular system is attached via ``set_folded`` when both
    ``folded_elem`` and ``folded_pos`` are given (``folded_units`` defaults
    to ``units``).

    Legacy inputs ``structure``, ``shape``, ``element`` and ``scale``, when
    given, override ``lattice``, ``cell``, ``atoms`` and ``constants``.

    Explicit k-point weights (``kweights``) accompany ``kpoints`` in routes
    1 and 2.
    """

    if structure is not None:
        lattice = structure
    #end if
    if shape is not None:
        cell = shape
    #end if
    if element is not None:
        atoms = element
    #end if
    if scale is not None:
        constants = scale
    #end if

    #interface for total manual specification
    # this is only here because 'crystal' is default and must handle other cases
    s = None
    has_elem_and_pos = elem is not None and (pos is not None or posu is not None)
    has_elem_and_pos |= elem_pos is not None
    if has_elem_and_pos:
        s = Structure(
            axes           = axes,
            elem           = elem,
            pos            = pos,
            units          = units,
            mag            = mag,
            frozen         = frozen,
            magnetization  = magnetization,
            tiling         = tiling,
            kpoints        = kpoints,
            # previously dropped: weights must travel with explicit kpoints
            kweights       = kweights,
            kgrid          = kgrid,
            kshift         = kshift,
            permute        = permute,
            rescale        = False,
            operations     = operations,
            posu           = posu,
            elem_pos       = elem_pos,
            use_prim       = use_prim,
            add_kpath      = add_kpath,
            symm_kgrid     = symm_kgrid,
            )
    elif isinstance(structure,Structure):
        s = structure
        if use_prim is not None and use_prim is not False:
            s.become_primitive(source=use_prim,add_kpath=add_kpath)
        #end if
        if tiling is not None:
            s = s.tile(tiling)
        #end if
        if kpoints is not None:
            s.add_kpoints(kpoints,kweights)
        #end if
        if kgrid is not None:
            if not symm_kgrid:
                s.add_kmesh(kgrid,kshift)
            else:
                s.add_symmetrized_kmesh(kgrid,kshift)
            #end if
        #end if
    #end if
    if s is not None:
        # add point group folded molecular system if present
        if folded_elem is not None and folded_pos is not None:
            if folded_units is None:
                folded_units = units
            #end if
            fs = Structure(
                elem    = folded_elem,
                pos     = folded_pos,
                units   = folded_units,
                rescale = False,
                )
            s.set_folded(fs)
        #end if
        return s
    #end if

    # NOTE(review): kweights is not forwarded here because the Crystal
    # constructor signature is not visible from this section -- confirm
    # whether Crystal accepts it.
    s=Crystal(
        lattice        = lattice       ,
        cell           = cell          ,
        centering      = centering     ,
        constants      = constants     ,
        atoms          = atoms         ,
        basis          = basis         ,
        basis_vectors  = basis_vectors ,
        tiling         = tiling        ,
        cscale         = cscale        ,
        axes           = axes          ,
        units          = units         ,
        angular_units  = angular_units ,
        frozen         = frozen        ,
        mag            = mag           ,
        magnetization  = magnetization ,
        kpoints        = kpoints       ,
        kgrid          = kgrid         ,
        kshift         = kshift        ,
        permute        = permute       ,
        operations     = operations    ,
        elem           = elem          ,
        pos            = pos           ,
        use_prim       = use_prim      ,
        add_kpath      = add_kpath     ,
        symm_kgrid     = symm_kgrid    ,
        )

    if struct_type!=Crystal:
        s=s.upcast(struct_type)
    #end if

    return s
#end def generate_crystal_structure
7236
7237
7238
# Table of predefined point defects, keyed first by host lattice name
# (currently only 'diamond') and then by defect label.
# Each entry holds:
#   pristine : positions removed from the perfect host
#   defect   : positions inserted in their place
# generate_defect_structure applies an entry as
#   ds.replace(drep.pristine,pos=drep.defect)
# on a crystal generated with scale 1.0 and only afterwards rescales it, so the
# coordinates presumably use lattice-constant units -- TODO confirm.
# NOTE(review): the labels (H, T, X, FFC) presumably name standard diamond-lattice
# defect geometries (e.g. hexagonal/tetrahedral interstitials, split interstitial,
# fourfold-coordinated defect) -- confirm against the originating reference.
defects = obj(
    diamond = obj(
        H = obj(
            pristine = [[0,0,0]],
            defect   = [[0,0,0],[.625,.375,.375]]
            ),
        T = obj(
            pristine = [[0,0,0]],
            defect   = [[0,0,0],[.5,.5,.5]]
            ),
        X = obj(
            pristine = [[.25,.25,.25]],
            defect   = [[.39,.11,.15],[.11,.39,.15]]
            ),
        FFC = obj(
            # earlier two-atom variant of this defect, kept for reference:
            #pristine = [[   0,   0,   0],[.25,.25,.25]],
            #defect   = [[.151,.151,-.08],[.10,.10,.33]]
            pristine = [[   0,   0,    0],[.25 ,.25 ,.25 ],[.5  ,.5  ,    0],[.75 ,.75 ,.25 ],[1.5  ,1.5  ,0   ],[1.75 ,1.75 ,.25 ]],
            defect   = [[.151,.151,-.081],[.099,.099,.331],[.473,.473,-.059],[.722,.722,.230],[1.528,1.528,.020],[1.777,1.777,.309]]
            )
        )
    )
7261
7262
def generate_defect_structure(defect,structure,shape=None,element=None,
                              tiling=None,scale=1.,kgrid=None,kshift=(0,0,0),
                              units=None,struct_type=DefectStructure):
    """
    Generate a crystal containing a predefined point defect.

    The host crystal is generated with unit scale, the ``pristine``
    positions of ``defects[structure][defect]`` are replaced by its
    ``defect`` positions, and the result is finally rescaled by ``scale``.
    Unknown host structures or defect labels are reported through
    ``DefectStructure.class_error``.
    """
    if structure not in defects:
        DefectStructure.class_error('defects for '+structure+' structure have not yet been implemented')
    else:
        host_defects = defects[structure]
    #end if
    if defect not in host_defects:
        DefectStructure.class_error(defect+' defect not found for '+structure+' structure')
    else:
        replacement = host_defects[defect]
    #end if

    # build at unit scale so the tabulated coordinates apply directly
    defect_struct = generate_crystal_structure(
        structure   = structure,
        shape       = shape,
        element     = element,
        tiling      = tiling,
        scale       = 1.0,
        kgrid       = kgrid,
        kshift      = kshift,
        units       = units,
        struct_type = struct_type,
        )
    defect_struct.replace(replacement.pristine,pos=replacement.defect)
    defect_struct.rescale(scale)
    return defect_struct
#end def generate_defect_structure
7295
7296
def read_structure(filepath,elem=None,format=None):
    """
    Read an atomic structure from ``filepath``.

    An empty structure is created via ``generate_structure('empty')`` and
    populated by its ``read`` method; ``elem`` and ``format`` are passed
    through unchanged.
    """
    loaded = generate_structure('empty')
    loaded.read(filepath,elem=elem,format=format)
    return loaded
#end def read_structure
7302
7303
7304
7305
7306
7307
if __name__=='__main__':
    # Developer scratch checks; not part of the library API.

    # 2x2x2 tiling of fcc diamond-structure Ge, then print the k-points of
    # its folded structure.
    ge_tiled = generate_structure(
        type      = 'crystal',
        structure = 'diamond',
        cell      = 'fcc',
        atoms     = 'Ge',
        constants = 5.639,
        units     = 'A',
        tiling    = (2,2,2),
        kgrid     = (1,1,1),
        kshift    = (0,0,0),
        )
    ge_folded = ge_tiled.folded_structure
    print(ge_folded.kpoints_unit())

    # Requires 'scf.struct.xsf' in the current working directory.
    xsf_struct  = read_structure('scf.struct.xsf')
    primitive   = get_primitive_cell(structure=xsf_struct)['structure']
    band_tiling = get_band_tiling(structure=primitive, kpoints_label = ['L', 'F'], min_volfac=6, max_volfac = 6)
    print(band_tiling)
#end if
7332